Source code for EduNLP.Vector.embedding

# coding: utf-8
# 2021/7/12 @ tongshiwei

from typing import List
import torch
from .gensim_vec import W2V
from .const import PAD
from EduNLP.ModelZoo import pad_sequence, set_device


[docs]class Embedding(object): def __init__(self, w2v: (W2V, tuple, list, dict, None), freeze=True, device=None, **kwargs): if w2v is None: self.w2v = None elif isinstance(w2v, (tuple, list)): self.w2v = W2V(*w2v) elif isinstance(w2v, dict): self.w2v = W2V(**w2v) elif isinstance(w2v, W2V): self.w2v = w2v else: raise TypeError("w2v argument must be one of W2V, tuple, list, dict or None, now is %s" % type(w2v)) if self.w2v is not None: self.vocab_size = len(self.w2v) self.embedding_dim = self.w2v.vector_size else: self.vocab_size = kwargs["vocab_size"] self.embedding_dim = kwargs["embedding_dim"] self.embedding = torch.nn.Embedding(self.vocab_size, self.embedding_dim) self.pad_val = 0 if self.w2v is not None: self.embedding.from_pretrained(torch.Tensor(self.w2v.vectors), freeze) self.pad_val = self.w2v.constants[PAD] self.key_to_index = self.w2v.key_to_index if w2v is not None else lambda x: x if device is not None: self.set_device(device) def __call__(self, items: List[List[str]], indexing=True, padding=True, vectorization=True, *args, **kwargs) -> tuple: items, item_len = self.indexing(items, padding=padding, indexing=indexing) items = self.infer_token_vector(items, indexing=False)[0] if vectorization else items return items, item_len
[docs] def infer_token_vector(self, items: List[List[str]], indexing=True) -> tuple: items, item_len = self.indexing(items, padding=True, indexing=indexing) item_embedding = self.embedding(torch.LongTensor(items)) return item_embedding, item_len
[docs] def indexing(self, items: List[List[str]], padding=False, indexing=True) -> tuple: """ Parameters ---------- items: list of list of str(word/token) padding: bool whether padding the returned list with default pad_val to make all item in items have the same length indexing: bool Returns ------- token_idx: list of list of int the list of the tokens of each item token_len: list of int the list of the length of tokens of each item """ items_idx = [[self.key_to_index(word) for word in item] for item in items] if indexing else items item_len = [len(_idx) for _idx in items_idx] padded_items_idx = pad_sequence(items_idx, pad_val=self.pad_val) if padding is True else items_idx return padded_items_idx, item_len
[docs] def set_device(self, device): self.embedding = set_device(self.embedding, device) return self