Source code for EduNLP.Vector.rnn.rnn

# coding: utf-8
# 2021/7/12 @ tongshiwei

import torch
from ..gensim_vec import W2V
from ..embedding import Embedding
from ..meta import Vector
from EduNLP.ModelZoo import rnn, set_device
from baize.torch import save_params


class RNNModel(Vector):
    """
    Vectorizer that feeds (optionally indexed and padded) token sequences
    through an ``Embedding`` layer followed by an ``rnn.LM`` recurrent
    language model.

    Calling the model returns a pair ``(tokens, item)``: the first element is
    used as the token-level representation (see :meth:`infer_tokens`) and the
    second as the item-level representation (see :meth:`infer_vector`).

    Parameters
    ----------
    rnn_type:
        Name of the recurrent architecture, forwarded to ``rnn.LM``
        (the example below uses ``"BiLSTM"``).
    w2v: W2V, tuple, list, dict or None
        Pretrained embedding source, forwarded to ``Embedding``.
    hidden_size: int
        Hidden size forwarded to ``rnn.LM``.
    freeze_pretrained: bool
        Forwarded to ``Embedding``; also stored on the instance.
    model_params:
        Forwarded to ``rnn.LM`` unchanged.
    device:
        When not ``None``, the recurrent model is moved with
        :meth:`set_device`.
    kwargs:
        Passed in full to ``Embedding``; everything except ``vocab_size`` and
        ``embedding_dim`` is then passed on to ``rnn.LM``.

    Examples
    --------
    >>> model = RNNModel("BiLSTM", None, 2, vocab_size=4, embedding_dim=3)
    >>> seq_idx = [[1, 2, 3], [1, 2, 0], [3, 0, 0]]
    >>> output, hn = model(seq_idx, indexing=False, padding=False)
    >>> seq_idx = [[1, 2, 3], [1, 2], [3]]
    >>> output, hn = model(seq_idx, indexing=False, padding=True)
    >>> output.shape
    torch.Size([3, 3, 4])
    >>> hn.shape
    torch.Size([2, 3, 2])
    >>> tokens = model.infer_tokens(seq_idx, indexing=False)
    >>> tokens.shape
    torch.Size([3, 3, 4])
    >>> tokens = model.infer_tokens(seq_idx, agg="mean", indexing=False)
    >>> tokens.shape
    torch.Size([3, 4])
    >>> item = model.infer_vector(seq_idx, indexing=False)
    >>> item.shape
    torch.Size([3, 4])
    >>> item = model.infer_vector(seq_idx, agg="mean", indexing=False)
    >>> item.shape
    torch.Size([3, 2])
    >>> item = model.infer_vector(seq_idx, agg=None, indexing=False)
    >>> item.shape
    torch.Size([2, 3, 2])
    """

    def __init__(self, rnn_type, w2v: (W2V, tuple, list, dict, None), hidden_size,
                 freeze_pretrained=True, model_params=None, device=None, **kwargs):
        # Embedding receives the full kwargs (including vocab_size and
        # embedding_dim when given).
        self.embedding = Embedding(w2v, freeze_pretrained, **kwargs)
        # These two are passed to rnn.LM positionally below, so they must not
        # be forwarded a second time through **kwargs.
        for key in ("vocab_size", "embedding_dim"):
            kwargs.pop(key, None)
        self.rnn = rnn.LM(
            rnn_type,
            self.embedding.vocab_size,
            self.embedding.embedding_dim,
            hidden_size=hidden_size,
            embedding=self.embedding.embedding,
            model_params=model_params,
            **kwargs
        )
        self.bidirectional = self.rnn.rnn.bidirectional
        self.hidden_size = self.rnn.hidden_size
        self.freeze_pretrained = freeze_pretrained
        if device is not None:
            self.set_device(device)

    @staticmethod
    def _resolve_torch_function(name):
        """
        Look up an aggregation callable in the ``torch`` namespace.

        ``name`` may be dotted (e.g. ``"linalg.norm"``); each component is
        resolved with ``getattr`` so that no caller-supplied text is ever
        evaluated as Python code.

        Raises
        ------
        AttributeError
            If any component of ``name`` does not exist.
        """
        target = torch
        for attribute in str(name).split("."):
            target = getattr(target, attribute)
        return target

    def __call__(self, items, indexing=True, padding=True, **kwargs):
        """
        Run the items through the embedding preprocessing and the RNN.

        Parameters
        ----------
        items:
            Input sequences, handed to the embedding object.
        indexing, padding: bool
            Forwarded to the embedding object.
        kwargs:
            Accepted for interface compatibility; not used here.

        Returns
        -------
        tuple
            ``(tokens, item)`` exactly as returned by the underlying
            ``rnn.LM``.
        """
        seq_idx, seq_len = self.embedding(
            items, indexing=indexing, padding=padding, vectorization=False
        )
        tokens, item = self.rnn(torch.LongTensor(seq_idx), torch.LongTensor(seq_len))
        return tokens, item

    def infer_vector(self, items, agg: (int, str, None) = -1, indexing=True, padding=True,
                     *args, **kwargs) -> torch.Tensor:
        """
        Compute item-level vectors from the second output of the model.

        Parameters
        ----------
        items:
            Input sequences.
        agg: int, str or None
            * ``-1`` (default): concatenate all slices along the first axis
              for each item, giving one flat vector per item.
            * ``str``: name of a ``torch`` function (e.g. ``"mean"``), called
              as ``fn(vector, dim=0)``.
            * ``None``: return the raw tensor untouched.
        indexing, padding: bool
            Forwarded to ``__call__``.

        Returns
        -------
        torch.Tensor
        """
        vector = self(items, indexing=indexing, padding=padding, **kwargs)[1]
        if agg is None:
            return vector
        if agg == -1:
            # Per the class doctest the raw tensor is laid out as
            # (slices, batch, hidden) with the batch on axis 1. A plain
            # row-major reshape to (batch, -1) would mix values belonging to
            # different items, so bring the batch axis to the front first.
            return vector.transpose(0, 1).reshape(vector.shape[1], -1)
        return self._resolve_torch_function(agg)(vector, dim=0)

    def infer_tokens(self, items, agg=None, *args, **kwargs) -> torch.Tensor:
        """
        Compute token-level vectors from the first output of the model.

        Parameters
        ----------
        items:
            Input sequences.
        agg: str or None
            When given, name of a ``torch`` function (e.g. ``"mean"``),
            called as ``fn(tokens, dim=1)``; when ``None`` the raw tensor
            is returned.
        kwargs:
            Forwarded to ``__call__`` (e.g. ``indexing``, ``padding``).

        Returns
        -------
        torch.Tensor
        """
        tokens = self(items, **kwargs)[0]
        if agg is None:
            return tokens
        return self._resolve_torch_function(agg)(tokens, dim=1)

    @property
    def vector_size(self) -> int:
        """Hidden size, doubled unless ``self.bidirectional`` is ``False``."""
        return self.hidden_size * (1 if self.bidirectional is False else 2)

    def set_device(self, device):
        """Move the recurrent model to ``device`` via ``set_device``."""
        self.rnn = set_device(self.rnn, device)

    def save(self, filepath, save_embedding=False):
        """
        Save the recurrent model's parameters to ``filepath``.

        Parameters
        ----------
        filepath:
            Destination handed to ``save_params``.
        save_embedding: bool
            When ``True`` no selection pattern is applied; otherwise the
            pattern ``'^(?!.*embedding)'`` is passed as ``select``, which
            matches only names that do not contain ``embedding``.

        Returns
        -------
        filepath, unchanged.
        """
        save_params(
            filepath,
            self.rnn,
            select=None if save_embedding is True else '^(?!.*embedding)'
        )
        return filepath

    def freeze(self, *args, **kwargs):
        """Alias of :meth:`eval`; extra arguments are ignored."""
        return self.eval()

    @property
    def is_frozen(self):
        """``True`` when the recurrent model is not in training mode."""
        return not self.rnn.training

    def eval(self):
        """Put the recurrent model in evaluation mode and return ``self``."""
        self.rnn.eval()
        return self

    def train(self, mode=True):
        """Set the recurrent model's training mode and return ``self``."""
        self.rnn.train(mode)
        return self