# Source code for EduNLP.ModelZoo.rnn.rnn

# coding: utf-8
# 2021/7/12 @ tongshiwei

import torch
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from baize.torch import load_net
import torch.nn.functional as F
import json


[docs]class LM(nn.Module): """ Parameters ---------- rnn_type:str Legal types including RNN, LSTM, GRU, BiLSTM vocab_size: int embedding_dim: int hidden_size: int num_layers bidirectional embedding model_params kwargs Examples -------- >>> import torch >>> seq_idx = torch.LongTensor([[1, 2, 3], [1, 2, 0], [3, 0, 0]]) >>> seq_len = torch.LongTensor([3, 2, 1]) >>> lm = LM("RNN", 4, 3, 2) >>> output, hn = lm(seq_idx, seq_len) >>> output.shape torch.Size([3, 3, 2]) >>> hn.shape torch.Size([1, 3, 2]) >>> lm = LM("RNN", 4, 3, 2, num_layers=2) >>> output, hn = lm(seq_idx, seq_len) >>> output.shape torch.Size([3, 3, 2]) >>> hn.shape torch.Size([2, 3, 2]) """ def __init__(self, rnn_type: str, vocab_size: int, embedding_dim: int, hidden_size: int, num_layers=1, bidirectional=False, embedding=None, model_params=None, **kwargs): super(LM, self).__init__() rnn_type = rnn_type.upper() self.embedding = torch.nn.Embedding(vocab_size, embedding_dim) if embedding is None else embedding self.c = False if rnn_type == "RNN": self.rnn = torch.nn.RNN( embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs ) elif rnn_type == "LSTM": self.rnn = torch.nn.LSTM( embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs ) self.c = True elif rnn_type == "GRU": self.rnn = torch.nn.GRU( embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs ) elif rnn_type == "BILSTM": bidirectional = True self.rnn = torch.nn.LSTM( embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs ) self.c = True else: raise TypeError("Unknown rnn_type %s" % rnn_type) self.num_layers = num_layers self.bidirectional = bidirectional if bidirectional is True: self.num_layers *= 2 self.hidden_size = hidden_size if model_params: load_net(model_params, self, allow_missing=True)
[docs] def forward(self, seq_idx, seq_len): """ Parameters ---------- seq_idx:Tensor a list of indices seq_len:Tensor length Returns -------- sequence a PackedSequence object """ seq = self.embedding(seq_idx) pack = pack_padded_sequence(seq, seq_len.cpu(), batch_first=True, enforce_sorted=False) h0 = torch.zeros(self.num_layers, seq.shape[0], self.hidden_size) if self.c is True: c0 = torch.zeros(self.num_layers, seq.shape[0], self.hidden_size).to(seq_idx.device) output, (hn, _) = self.rnn(pack, (h0, c0)) else: output, hn = self.rnn(pack, h0) output, _ = pad_packed_sequence(output, batch_first=True) return output, hn
class ElmoLM(nn.Module):
    """
    ELMo-style bidirectional language model.

    A two-layer ``LM("BiLSTM", ...)`` encodes the token sequence. Its output is split into the
    forward-direction and backward-direction halves. Each half goes through dropout and then a
    single shared ``nn.Linear(hidden_size, vocab_size)`` followed by a softmax over the vocabulary.

    Parameters
    ----------
    vocab_size: int
        Vocabulary size, used for both the embedding table and the prediction layer.
    embedding_dim: int
        Size of the token embeddings fed to the BiLSTM.
    hidden_size: int
        Hidden size of each LSTM direction.
    dropout_rate: float
        Dropout probability applied to each directional output.
    batch_first: bool
        Forwarded to the underlying ``torch.nn.LSTM`` constructor.
        NOTE(review): ``LM.forward`` pads its packed output with ``batch_first=True``, so outputs
        are batch first either way.
    """

    def __init__(self, vocab_size: int, embedding_dim: int, hidden_size: int, dropout_rate: float = 0.5,
                 batch_first=True):
        super(ElmoLM, self).__init__()
        self.LM_layer = LM("BiLSTM", vocab_size, embedding_dim, hidden_size, num_layers=2, batch_first=batch_first)
        # A single projection is shared by the forward and backward directions.
        self.pred_layer = nn.Linear(hidden_size, vocab_size)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, seq_idx, seq_len):
        """
        Run the BiLSTM and predict a vocabulary distribution for each position and direction.

        Parameters
        ----------
        seq_idx: Tensor, of shape (batch_size, sequence_length)
            Padded token indices.
        seq_len: Tensor, of shape (batch_size)
            Valid length of each sequence.

        Returns
        ----------
        pred_forward: of shape (batch_size, max_len, vocab_size)
            Softmax probabilities (not logits) computed from the forward-direction outputs.
        pred_backward: of shape (batch_size, max_len, vocab_size)
            Softmax probabilities (not logits) computed from the backward-direction outputs.
        forward_output: of shape (batch_size, max_len, hidden_size)
            Forward-direction BiLSTM outputs, after dropout.
        backward_output: of shape (batch_size, max_len, hidden_size)
            Backward-direction BiLSTM outputs, after dropout.

        ``max_len`` is the largest value in ``seq_len``. ``LM`` pads its packed output only up to
        that length.
        """
        lm_output, _ = self.LM_layer(seq_idx, seq_len)
        # torch concatenates the two directions on the last axis: [forward | backward].
        forward_output = lm_output[:, :, :self.hidden_size]
        backward_output = lm_output[:, :, self.hidden_size:]
        forward_output = self.dropout(forward_output)
        backward_output = self.dropout(backward_output)
        pred_forward = F.softmax(input=self.pred_layer(forward_output), dim=-1)
        pred_backward = F.softmax(input=self.pred_layer(backward_output), dim=-1)
        return pred_forward, pred_backward, forward_output, backward_output