# Source code for EduNLP.ModelZoo.disenqnet.disenqnet

# -*- coding: utf-8 -*-

import logging
import torch
from torch import nn
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import StepLR
import os
import json
from gensim.models import KeyedVectors

from .modules import TextEncoder, AttnModel, ConceptEstimator, MIEstimator, DisenEstimator
from .utils import get_mask
from ..utils import set_device


class QuestionEncoder(nn.Module):
    """
    Question encoder of DisenQNet.

    A text encoder produces word embeddings and a question-level hidden
    state; two separate attention heads then derive the concept
    representation (vk) and the individual representation (vi) from them.

    Parameters
    ----------
    vocab_size: int
        size of vocabulary
    hidden_dim: int
        size of word and question embedding
    dropout: float
        dropout rate
    wv: torch.Tensor
        Tensor of (vocab_size, hidden_dim) or None, initial word embedding, default = None
    """

    def __init__(self, vocab_size, hidden_dim, dropout, wv=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Shared text encoder followed by one attention head per representation.
        self.encoder = TextEncoder(vocab_size, hidden_dim, dropout, wv=wv)
        self.k_model = AttnModel(hidden_dim, dropout)
        self.i_model = AttnModel(hidden_dim, dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, input, length, get_vk=True, get_vi=True):
        """
        Encode a batch of questions.

        Parameters
        ----------
        input: Tensor of (batch_size, seq_len)
            word index
        length: Tensor of (batch_size)
            valid sequence length of each batch
        get_vk: bool
            whether to compute and return the concept representation
        get_vi: bool
            whether to compute and return the individual representation

        Returns
        -------
        (embed, k_hidden, i_hidden)
            - embed: Tensor of (batch_size, seq_len, hidden_dim), word embedding
              with padded positions zeroed
            - k_hidden: Tensor of (batch_size, hidden_dim), or None when get_vk is False
            - i_hidden: Tensor of (batch_size, hidden_dim), or None when get_vi is False
        """
        # word_embed: batch_size * seq_len * hidden_dim
        # question_state: batch_size * hidden_dim
        word_embed, question_state = self.encoder(input)

        # pad_mask: batch_size * seq_len, 0 means valid, 1 means pad
        pad_mask = get_mask(input.shape[1], length)
        # Zero the padded positions in place; the returned embedding is this same tensor.
        word_embed.masked_fill_(pad_mask.unsqueeze(-1), 0)

        # Dropout is applied to the attention inputs only, not to the returned embedding.
        query = self.dropout(question_state)
        memory = self.dropout(word_embed)

        # Each head yields batch_size * hidden_dim; its second output is discarded.
        k_hidden = self.k_model(query, memory, memory, pad_mask)[0] if get_vk else None
        i_hidden = self.i_model(query, memory, memory, pad_mask)[0] if get_vi else None
        return word_embed, k_hidden, i_hidden


class DisenQNet(object):
    """
    DisenQNet training and evaluation model

    Parameters
    ----------
    vocab_size: int
        size of vocabulary
    concept_size: int
        number of concept classes
    hidden_dim: int
        size of word and question embedding
    dropout: float
        dropout rate
    pos_weight: float
        positive sample weight in unbalanced multi-label concept classifier
    w_cp: float
        weight of concept loss
    w_mi: float
        weight of mutual information loss
    w_dis: float
        weight of disentangling loss
    wv: torch.Tensor
        Tensor of (vocab_size, hidden_dim) or None, initial word embedding, default = None
    device: str, defaults as 'cpu'
        Set device for model, examples 'cpu', 'cuda', 'cuda:0,2'
    """

    def __init__(self, vocab_size, concept_size, hidden_dim, dropout, pos_weight, w_cp, w_mi, w_dis, wv=None,
                 device="cpu"):
        super(DisenQNet, self).__init__()
        self.disen_q_net = QuestionEncoder(vocab_size, hidden_dim, dropout, wv)
        self.mi_estimator = MIEstimator(hidden_dim, hidden_dim * 2, dropout)
        self.concept_estimator = ConceptEstimator(hidden_dim, concept_size, pos_weight, dropout)
        self.disen_estimator = DisenEstimator(hidden_dim, dropout)
        self.w_cp = w_cp
        self.w_mi = w_mi
        self.w_dis = w_dis
        self.hidden_dim = hidden_dim
        # Hyper-parameters written by save_config and read back by from_config.
        self.params = {
            "vocab_size": vocab_size,
            "concept_size": concept_size,
            "hidden_dim": hidden_dim,
            "dropout": dropout,
            "pos_weight": pos_weight,
            "w_cp": w_cp,
            "w_mi": w_mi,
            "w_dis": w_dis,
        }
        # The order of this tuple defines the order of state dicts in a checkpoint.
        self.modules = (self.disen_q_net, self.mi_estimator, self.concept_estimator, self.disen_estimator)
        self.to(device)

    def train(self, train_data, test_data, epoch, lr, step_size, gamma, warm_up, n_adversarial, silent):
        """
        train DisenQNet

        Parameters
        ----------
        train_data: train dataloader, contains text, length, concept
            - text: Tensor of (batch_size, seq_len)
            - length: Tensor of (batch_size)
            - concept: Tensor of (batch_size, class_size)
        test_data: test dataloader, or None to skip evaluation on held-out data
        epoch: int
            number of epoch (run in addition to the warm_up epochs)
        lr: float
            initial learning rate
        step_size: int
            step_size for StepLR, period of learning rate decay
        gamma: float
            gamma for StepLR, multiplicative factor of learning rate decay
        warm_up: int
            number of epoch for warming up, without adversarial process for dis_loss
        n_adversarial: int
            ratio of disc/enc training for adversarial process
        silent: bool
            when True, suppress the progress messages
        """
        if not silent:
            print("Start training the disenQNet...")

        # optimizer & scheduler: the encoder side and the discriminator are optimized separately
        model_params = [
            *self.disen_q_net.parameters(),
            *self.mi_estimator.parameters(),
            *self.concept_estimator.parameters(),
        ]
        adv_params = list(self.disen_estimator.parameters())
        optimizer = Adam(model_params, lr=lr)
        adv_optimizer = Adam(adv_params, lr=lr)
        scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
        adv_scheduler = StepLR(adv_optimizer, step_size=step_size, gamma=gamma)

        # train
        total_epochs = warm_up + epoch
        for epoch_idx in range(1, total_epochs + 1):
            # warming_up: cp_loss & mi_loss only, ignore adversarial dis_loss
            warming_up = (epoch_idx <= warm_up)
            # eval() at the end of the previous epoch switched the modules to eval mode
            self.set_mode(True)
            for data in train_data:
                text, length, concept = data
                text, length, concept = text.to(self.device), length.to(self.device), concept.to(self.device)

                # WGAN-like adversarial training: min_enc max_disc dis_loss
                # train disc
                if not warming_up:
                    _, k_hidden, i_hidden = self.disen_q_net(text, length)
                    # stop gradient propagation to encoder
                    k_hidden, i_hidden = k_hidden.detach(), i_hidden.detach()
                    # max dis_loss
                    dis_loss = - self.disen_estimator(k_hidden, i_hidden)
                    dis_loss = n_adversarial * self.w_dis * dis_loss
                    adv_optimizer.zero_grad()
                    dis_loss.backward()
                    adv_optimizer.step()
                    # Lipschitz constrain for Disc of WGAN
                    self.disen_estimator.spectral_norm()

                # train enc
                embed, k_hidden, i_hidden = self.disen_q_net(text, length)
                hidden = torch.cat((k_hidden, i_hidden), dim=-1)
                # max mi
                mi_loss = - self.mi_estimator(embed, hidden, length)
                # min concept_loss
                cp_loss = self.concept_estimator(k_hidden, concept)
                loss = self.w_mi * mi_loss + self.w_cp * cp_loss
                if not warming_up:
                    # min dis
                    dis_loss = self.disen_estimator(k_hidden, i_hidden)
                    loss = loss + self.w_dis * dis_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # learning-rate decay only starts once warm-up is over
            if not warming_up:
                scheduler.step()
                adv_scheduler.step()

            # test
            train_loss = self.eval(train_data)
            if test_data is not None and not warming_up:
                test_loss = self.eval(test_data)
                if not silent:
                    print(f"[Epoch {epoch_idx:2d}] train loss: {train_loss:.4f}, eval loss: {test_loss:.4f}")
            elif not silent:
                print(f"[Epoch {epoch_idx:2d}] train loss: {train_loss:.4f}")
        return

    def inference(self, items: dict):
        """
        DisenQNet for i2v inference. Now not support for batch !

        Switches every sub-module to eval mode before encoding.

        Parameters
        ----------
        items: dict
            which contains content_idx and content_len
            - content_idx: Tensor of (batch_size, seq_len)
            - content_len: Tensor of (batch_size)

        Returns
        ---------
        embed: torch.Tensor
            Tensor of (batch_size, seq_len, hidden_dim)
        k_hidden: torch.Tensor
            Tensor of (batch_size, hidden_dim)
        i_hidden: torch.Tensor
            Tensor of (batch_size, hidden_dim)
        """
        self.set_mode(False)
        text, length = items["content_idx"].to(self.device), items["content_len"].to(self.device)
        embed, k_hidden, i_hidden = self.disen_q_net(text, length)
        return embed, k_hidden, i_hidden

    def eval(self, test_data):
        """
        eval DisenQNet

        Switches every sub-module to eval mode and runs without gradients.

        Parameters
        ----------
        test_data: iterable, dataset to evaluate, contains text, length, concept
            - text: Tensor of (batch_size, seq_len)
            - length: Tensor of (batch_size)
            - concept: Tensor of (batch_size, class_size)

        Returns
        ---------
        loss: float
            sample-weighted average loss over the dataset; NaN when the
            dataset yields no samples
        """
        total_size = 0
        total_loss = 0
        self.set_mode(False)
        with torch.no_grad():
            for data in test_data:
                text, length, concept = data
                text, length, concept = text.to(self.device), length.to(self.device), concept.to(self.device)
                embed, k_hidden, i_hidden = self.disen_q_net(text, length)
                hidden = torch.cat((k_hidden, i_hidden), dim=-1)
                mi_loss = - self.mi_estimator(embed, hidden, length)
                cp_loss = self.concept_estimator(k_hidden, concept)
                dis_loss = self.disen_estimator(k_hidden, i_hidden)
                loss = self.w_mi * mi_loss + self.w_cp * cp_loss + self.w_dis * dis_loss
                # weight each batch by its size so a short last batch does not skew the mean
                batch_size = text.size(0)
                total_size += batch_size
                total_loss += loss.item() * batch_size
        if total_size == 0:
            # The mean of nothing is undefined; report NaN instead of dividing by zero.
            return float("nan")
        return total_loss / total_size

    def save_pretrained(self, output_dir):
        """
        Save the model weights and configuration into a directory.

        Writes ``disen_q_net.th`` (list of state dicts, one per sub-module in
        ``self.modules`` order) and ``model_config.json``.

        Parameters
        ----------
        output_dir: str
            target directory, created if it does not exist
        """
        os.makedirs(output_dir, exist_ok=True)
        filepath = os.path.join(output_dir, "disen_q_net.th")
        config_path = os.path.join(output_dir, "model_config.json")
        state_dicts = [module.state_dict() for module in self.modules]
        torch.save(state_dicts, filepath)
        self.save_config(config_path)
        return

    def load(self, filepath):
        """
        Load sub-module weights from a checkpoint written by save_pretrained.

        Parameters
        ----------
        filepath: str
            path of the checkpoint holding a list of state dicts
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        state_dicts = torch.load(filepath, map_location='cpu')
        # State dicts are matched to sub-modules by position; zip stops at the shorter sequence.
        for module, state_dict in zip(self.modules, state_dicts):
            module.load_state_dict(state_dict)
        return

    def to(self, device):
        """
        Move every sub-module to the given device.

        Parameters
        ----------
        device: str
            device specification handed to set_device for each sub-module
        """
        for module in self.modules:
            set_device(module, device)
        # Input tensors are moved with self.device; anything other than exactly
        # "cpu" is recorded as plain "cuda".
        self.device = "cpu" if device == "cpu" else "cuda"
        return

    def set_mode(self, train):
        """
        Switch every sub-module between training and evaluation mode.

        Parameters
        ----------
        train: bool
            True for training mode, False for evaluation mode
        """
        for module in self.modules:
            if train:
                module.train()
            else:
                module.eval()
        return

    def save_config(self, config_path):
        """
        Write the model hyper-parameters (``self.params``) as JSON.

        Parameters
        ----------
        config_path: str
            path of the JSON file to write
        """
        with open(config_path, "w", encoding="utf-8") as wf:
            json.dump(self.params, wf, ensure_ascii=False, indent=2)

    @classmethod
    def from_config(cls, config_path):
        """
        Build a model from a JSON configuration file.

        Parameters
        ----------
        config_path: str
            path of the JSON file; an optional "wv_path" entry names a file
            with the initial word embedding

        Returns
        -------
        model: DisenQNet
            model created with the default device
        """
        with open(config_path, "r", encoding="utf-8") as rf:
            model_config = json.load(rf)
        wv = None
        if "wv_path" in model_config:
            # NOTE(review): torch.load documents `mmap` as a bool; the string "r" is
            # passed through unchanged here - confirm against the torch version in use.
            wv = torch.load(model_config["wv_path"], map_location='cpu', mmap="r")
        return cls(
            model_config["vocab_size"],
            model_config["concept_size"],
            model_config["hidden_dim"],
            model_config["dropout"],
            model_config["pos_weight"],
            model_config["w_cp"],
            model_config["w_mi"],
            model_config["w_dis"],
            wv=wv)

    @classmethod
    def from_pretrained(cls, model_dir):
        """
        Restore a model from a directory written by save_pretrained.

        Parameters
        ----------
        model_dir: str
            directory containing ``model_config.json`` and ``disen_q_net.th``

        Returns
        -------
        model: DisenQNet
            model with the stored weights loaded
        """
        config_path = os.path.join(model_dir, "model_config.json")
        model_path = os.path.join(model_dir, "disen_q_net.th")
        model = cls.from_config(config_path)
        model.load(model_path)
        return model