# Source code for EduNLP.Vector.elmo_vec

from pathlib import PurePath
import os
import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data as tud
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from EduNLP.Pretrain import train_elmo, ElmoTokenizer
from EduNLP.ModelZoo.rnn import ElmoLM
from .meta import Vector
import json


class ElmoModel(Vector):
    """Vectorizer backed by a pretrained ELMo language model (``ElmoLM``).

    Items are presumably token-id sequences produced by the project's ELMo
    tokenizer (not shown here — confirm). A single item is a flat list; a batch
    is a list of lists. Batches are passed straight to ``torch.tensor``, so
    every item in a batch must already be padded to the same length.
    """

    def __init__(self, pretrained_model_path: str):
        """
        Parameters
        ----------
        pretrained_model_path: str
            Directory containing ``config.json`` (read for the ``vocab_size``,
            ``emb_dim`` and ``hid_dim`` keys) and ``weight.pt`` (a state dict
            loaded into ``ElmoLM``).
        """
        super(ElmoModel, self).__init__()
        with open(os.path.join(pretrained_model_path, 'config.json'), 'r', encoding='utf-8') as f:
            config = json.load(f)
        self.model = ElmoLM(vocab_size=config['vocab_size'], embedding_dim=config['emb_dim'],
                            hidden_size=config['hid_dim'])
        # map_location='cpu' lets weights saved from a GPU run load on a CPU-only
        # host; load_state_dict copies the values into the model's own
        # parameters, so the resulting model is the same either way.
        state_dict = torch.load(os.path.join(pretrained_model_path, 'weight.pt'), map_location='cpu')
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`infer_vector`."""
        return self.infer_vector(*args, **kwargs)

    def _run_model(self, items, kwargs):
        """Normalise ``items`` to a batch and run the language model on it.

        Returns ``(is_batch, lengths, length_tensor, forward_hiddens, backward_hiddens)``,
        where ``is_batch`` records whether the caller passed a batch and
        ``lengths`` is the caller-supplied ``lengths`` kwarg or, when absent or
        ``None``, the ``len`` of each item.
        """
        is_batch = isinstance(items[0], list)
        batch = items if is_batch else [items]
        lengths = kwargs.get('lengths')
        if lengths is None:
            # Computed lazily so an explicit ``lengths`` is never shadowed by work
            # that would be thrown away.
            lengths = [len(item) for item in batch]
        length_tensor = torch.tensor(lengths, dtype=torch.int64)
        # The two prediction outputs are not needed for embedding extraction.
        _, _, forward_hiddens, backward_hiddens = self.model(torch.tensor(batch), length_tensor)
        return is_batch, lengths, length_tensor, forward_hiddens, backward_hiddens

    def infer_vector(self, items, *args, **kwargs) -> torch.Tensor:
        """Return one vector per item by concatenating forward and backward hiddens.

        Parameters
        ----------
        items: list or list of lists
            A single sequence, or a batch of equally-padded sequences.
        lengths: list of int, optional (keyword)
            Unpadded length of each item; defaults to ``len`` of each item.

        Returns
        -------
        torch.Tensor
            For a single item: forward hidden at the last position concatenated
            with backward hidden at position 0. For a batch: for each row,
            forward hidden at ``length - 1`` concatenated with backward hidden at
            ``max(lengths) - length``. Presumably ``vector_size`` wide on the
            last dimension — confirm against ``ElmoLM``.
        """
        is_batch, lengths, length_tensor, forward_hiddens, backward_hiddens = self._run_model(items, kwargs)
        if not is_batch:
            return torch.cat((forward_hiddens[0, -1, :], backward_hiddens[0, 0, :]), dim=-1)
        rows = torch.arange(len(lengths))
        last_forward = forward_hiddens[rows, length_tensor - 1]
        # NOTE(review): indexing the backward hiddens at ``max(lengths) - length``
        # assumes a specific layout of ``backward_hiddens`` inside ElmoLM; the
        # token-level path (infer_tokens) instead flips ``bh[:length]``, which
        # implies a different layout. Confirm which one ElmoLM actually produces.
        first_backward = backward_hiddens[rows, max(lengths) - length_tensor]
        return torch.cat((last_forward, first_backward), dim=-1)

    def infer_tokens(self, items, *args, **kwargs) -> torch.Tensor:
        """Return per-token vectors: forward hidden concatenated with re-ordered backward hidden.

        Parameters
        ----------
        items: list or list of lists
            A single sequence, or a batch of equally-padded sequences.
        lengths: list of int, optional (keyword)
            Unpadded length of each item; defaults to ``len`` of each item.

        Returns
        -------
        torch.Tensor
            For a single item, the per-position concatenation of its forward
            hiddens with its backward hiddens flipped along the sequence axis.
            For a batch, the same per item (flipping only the first ``length``
            positions), stacked along a new leading dimension.
        """
        is_batch, lengths, _, forward_hiddens, backward_hiddens = self._run_model(items, kwargs)
        if not is_batch:
            return torch.cat((forward_hiddens[0], torch.flip(backward_hiddens, [1])[0]), dim=-1)
        per_item = []
        for fh, bh, lg in zip(forward_hiddens, backward_hiddens, lengths):
            # Reverse only the real (unpadded) positions; padding rows keep their place.
            _bh = torch.cat((torch.flip(bh[:lg], [0]), bh[lg:]), dim=0)
            per_item.append(torch.cat((fh, _bh), dim=-1))
        return torch.stack(tuple(per_item))

    @property
    def vector_size(self):
        """Width of a produced vector: forward plus backward hidden size."""
        return 2 * self.model.hidden_size