import torch
import numpy as np
import torch.nn as nn
from .modules import FeatureExtractor, ImageAE, MetaAE
from .util import SeqBatch
import json
import os
from pathlib import Path
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, PackedSequence
from torchvision.transforms.functional import to_tensor
from transformers.modeling_outputs import ModelOutput
from transformers import PretrainedConfig
from ..base_model import BaseModel
class QuesNetOutput(ModelOutput):
    """
    Output type of [`QuesNet`]

    Parameters
    ----------
    pack_embeded: output of the recurrent encoder as returned by the RNN for a
        packed input (the value that ``pad_packed_sequence`` is applied to in
        ``QuesNet.forward``)
    embeded: Tensor of (batch_size, seq_len, hidden_size), the padded
        (``batch_first=True``) version of ``pack_embeded``
    hidden: Tensor of (batch_size, hidden_size) or None, the attention-pooled
        representation of each sequence
    """
    # Packed RNN output (before padding).
    pack_embeded: torch.FloatTensor = None
    # Padded per-position RNN output.
    embeded: torch.FloatTensor = None
    # Pooled per-sequence representation.
    hidden: torch.FloatTensor = None
class QuesNet(BaseModel, FeatureExtractor):
    """Bidirectional-RNN encoder for questions made of words, images and meta labels.

    Each question becomes a sequence of ``emb_size`` vectors: word embeddings
    from ``self.we``, image codes from ``self.ie.enc`` and one meta code from
    ``self.me.enc``. The sequence is encoded by a bidirectional GRU/LSTM and
    pooled by a single self-attention layer followed by a max over positions.

    Parameters
    ----------
    _stoi: dict
        Vocabulary mappings. Must contain the key ``'word'`` (token -> index)
        and the key named by ``meta``. Despite the ``None`` default, a mapping
        is required.
    meta: str
        Key of the meta vocabulary inside ``_stoi`` and of the labels read
        from each item in :meth:`make_batch`.
    pretrained_embs: np.ndarray, optional
        Weights copied into the word embedding.
    pretrained_image: nn.Module, optional
        Pre-trained image auto-encoder whose weights are copied into ``self.ie``.
    pretrained_meta: nn.Module, optional
        Pre-trained meta auto-encoder whose weights are copied into ``self.me``.
    lambda_input: list, optional
        Scale factors applied to word, image and meta inputs, in that order.
        Defaults to ``[1., 1., 1.]``.
    feat_size: int
        Size of the pooled feature. Each RNN direction uses ``feat_size // 2``
        units, so ``feat_size`` should be even for the attention projections
        to match the RNN output width.
    emb_size: int
        Size of every input embedding.
    rnn_type: str
        ``'LSTM'`` or ``'GRU'``.
    layers: int
        Number of stacked RNN layers.

    Raises
    ------
    ValueError
        If ``rnn_type`` is neither ``'GRU'`` nor ``'LSTM'``.
    """
    base_model_prefix = 'quesnet'

    def __init__(self, _stoi=None, meta='know_name', pretrained_embs: np.ndarray = None,
                 pretrained_image: nn.Module = None, pretrained_meta: nn.Module = None,
                 lambda_input=None,
                 feat_size=256, emb_size=256, rnn_type='LSTM', layers=4, **kwargs):
        BaseModel.__init__(self)
        FeatureExtractor.__init__(self, feat_size=feat_size)
        self.feat_size = feat_size
        self.rnn_size = feat_size // 2
        self.emb_size = emb_size
        self.vocab_size = len(_stoi['word'])
        self.stoi = _stoi
        self.itos = {v: k for k, v in self.stoi['word'].items()}

        # Encoder - Word
        self.we = nn.Embedding(self.vocab_size, emb_size)
        if pretrained_embs is not None:
            self.load_emb(pretrained_embs)

        # Encoder - Image
        self.ie = ImageAE(emb_size)
        if pretrained_image is not None:
            self.load_img(pretrained_image)

        # Encoder - Meta
        self.meta = meta
        self.meta_size = len(_stoi[self.meta])
        self.me = MetaAE(self.meta_size, emb_size)
        if pretrained_meta is not None:
            self.load_meta(pretrained_meta)

        if lambda_input is None:
            lambda_input = [1., 1., 1.]
        self.lambda_input = lambda_input

        # Single-head self-attention used to pool the RNN outputs.
        self.proj_q = nn.Linear(feat_size, feat_size)
        self.proj_k = nn.Linear(feat_size, feat_size)
        self.proj_v = nn.Linear(feat_size, feat_size)
        self.dropout = nn.Dropout(0.2)

        self.rnn_type = rnn_type
        if rnn_type == 'GRU':
            self.rnn = nn.GRU(emb_size, self.rnn_size, layers,
                              bidirectional=True, batch_first=True)
            self.h0 = nn.Parameter(torch.rand(layers * 2, 1, self.rnn_size))
        elif rnn_type == 'LSTM':
            self.rnn = nn.LSTM(emb_size, self.rnn_size, layers,
                               bidirectional=True, batch_first=True)
            self.h0 = nn.Parameter(torch.rand(layers * 2, 1, self.rnn_size))
            self.c0 = nn.Parameter(torch.rand(layers * 2, 1, self.rnn_size))
        else:
            raise ValueError('quesnet only support GRU and LSTM now.')

        # Built explicitly (not from locals()) so that helper locals and the
        # pre-trained arrays/modules never end up in the config.
        self.config = PretrainedConfig.from_dict({
            "_stoi": _stoi,
            "meta": meta,
            "lambda_input": lambda_input,
            "feat_size": feat_size,
            "emb_size": emb_size,
            "rnn_type": rnn_type,
            "layers": layers,
            "architecture": 'quesnet',
        })

    def init_h(self, batch_size):
        """Expand the learned initial state(s) to ``batch_size``.

        Returns
        -------
        Tensor for a GRU, or a ``(h0, c0)`` tuple for an LSTM, each of shape
        ``(layers * 2, batch_size, rnn_size)``.
        """
        size = list(self.h0.size())
        size[1] = batch_size
        hidden = self.h0.expand(size).contiguous()
        if self.rnn_type == 'GRU':
            return hidden
        return hidden, self.c0.expand(size).contiguous()

    def load_emb(self, emb):
        """Copy a numpy embedding matrix into the word embedding weights."""
        self.we.weight.detach().copy_(torch.from_numpy(emb))

    def load_img(self, img_layer: nn.Module):
        """Copy the weights of a pre-trained image auto-encoder into ``self.ie``.

        Raises
        ------
        ValueError
            If ``img_layer.emb_size`` differs from this model's ``emb_size``.
        """
        # Uses self.emb_size rather than self.config: this method is called
        # from __init__ before the config has been built.
        if self.emb_size != img_layer.emb_size:
            raise ValueError("Unmatched pre-trained ImageAE and embedding size")
        self.ie.load_state_dict(img_layer.state_dict())

    def load_meta(self, meta_layer: nn.Module):
        """Copy the weights of a pre-trained meta auto-encoder into ``self.me``.

        ``load_state_dict`` itself rejects mismatched keys or shapes.
        """
        # NOTE(review): __init__ calls this method but it was not defined in
        # this class; confirm no base class was expected to provide it.
        self.me.load_state_dict(meta_layer.state_dict())

    def make_batch(self, data, device, pretrain=False):
        """Turn a list of questions into embedded sequence batches.

        Parameters
        ----------
        data: iterable
            Items exposing ``labels`` (mapping read with ``.get(self.meta)``),
            ``content`` (ints are word indices, anything else is passed to
            ``to_tensor`` as an image), ``answer`` (list) and
            ``false_options`` (at most three lists).
        device: torch device on which tensors are created.
        pretrain: bool
            When True return everything needed for pre-training, otherwise
            only the embedded input batch.

        Returns
        -------
        ``embs`` if ``pretrain`` is False, else the tuple
        ``(lembs, rembs, words, ims, metas, wmask, imask, mmask, embs,
        ans_input, ans_output, false_opt_input)``.
        """
        pad_idx = torch.tensor([0], device=device)
        # The embedding lookup is deterministic, so one pad vector is shared
        # by every position that needs it.
        pad_emb = self.we(pad_idx)
        w_word, w_img, w_meta = self.lambda_input[0], self.lambda_input[1], self.lambda_input[2]

        lembs = []
        rembs = []
        embs = []
        gt = []
        ans_input = []
        ans_output = []
        false_options = [[] for _ in range(3)]
        for q in data:
            meta = torch.zeros(self.meta_size, device=device)
            meta[q.labels.get(self.meta) or []] = 1
            meta_in = meta.unsqueeze(0)
            _lembs = [pad_emb, pad_emb, self.me.enc(meta_in) * w_meta]
            _rembs = [self.me.enc(meta_in) * w_meta]
            _embs = [pad_emb, pad_emb, self.me.enc(meta_in) * w_meta]
            # Targets carry an explicit kind tag; shape alone cannot tell a
            # word index from a meta vector when the meta vocabulary has size 1.
            _gt = [('word', pad_idx), ('meta', meta)]
            for w in q.content:
                if isinstance(w, int):
                    target = torch.tensor([w], device=device)
                    item = self.we(target) * w_word
                    kind = 'word'
                else:
                    target = to_tensor(w).to(device)
                    item = self.ie.enc(target.unsqueeze(0), detach_tensor=True) * w_img
                    kind = 'img'
                _lembs.append(item)
                _rembs.append(item)
                _embs.append(item)
                _gt.append((kind, target))
            _gt.append(('word', pad_idx))
            _rembs.extend([pad_emb, pad_emb])
            _embs.extend([pad_emb, pad_emb])

            lembs.append(torch.cat(_lembs, dim=0))
            rembs.append(torch.cat(_rembs, dim=0))
            embs.append(torch.cat(_embs, dim=0))
            gt.append(_gt)
            ans_input.append([0] + q.answer)
            ans_output.append(q.answer + [0])
            for i, fo in enumerate(q.false_options):
                false_options[i].append([0] + fo)

        lembs = SeqBatch(lembs, device=device)
        rembs = SeqBatch(rembs, device=device)
        embs = SeqBatch(embs, device=device)
        ans_input = SeqBatch(ans_input, device=device)
        ans_output = SeqBatch(ans_output, device=device)
        false_opt_input = [SeqBatch(foi, device=device) for foi in false_options]

        if not pretrain:
            # The masks and targets below are only consumed by pre-training.
            return embs

        length = sum(lembs.lens)
        masks = {kind: torch.zeros(length, device=device).byte()
                 for kind in ('word', 'img', 'meta')}
        targets = {'word': [], 'img': [], 'meta': []}
        for i, _gt in enumerate(gt):
            for j, (kind, v) in enumerate(_gt):
                ind = lembs.index((j, i))
                # Word indices are already (1,); images and meta vectors get
                # a leading batch dimension before concatenation.
                targets[kind].append((v if kind == 'word' else v.unsqueeze(0), ind))
                masks[kind][ind] = 1

        def _collect(pairs):
            """Concatenate targets ordered by their index in the packed batch."""
            if not pairs:
                return None
            ordered = sorted(pairs, key=lambda x: x[1])
            return torch.cat([x[0] for x in ordered], dim=0)

        words = _collect(targets['word'])
        ims = _collect(targets['img'])
        metas = _collect(targets['meta'])
        return (
            lembs, rembs, words, ims, metas, masks['word'], masks['img'], masks['meta'],
            embs, ans_input, ans_output, false_opt_input
        )

    def forward(self, inputs: SeqBatch):
        """Encode a batch of embedded sequences.

        Returns
        -------
        QuesNetOutput with the packed RNN output (``pack_embeded``), its padded
        form (``embeded``, ``(B, S, D)``) and the attention-pooled feature
        passed through ``inputs.invert`` (``hidden``).
        """
        packed = inputs.packed()
        h = self.init_h(int(packed.batch_sizes[0]))
        y, _ = self.rnn(packed, h)
        # hs: (B, S, D); lens: (B,) true length of every row of hs
        hs, lens = pad_packed_sequence(y, batch_first=True)

        # mask: (B, S), 1.0 on real positions and 0.0 on padding. Built
        # vectorised from the lengths, so it does not rely on lens[0] being
        # the maximum.
        positions = torch.arange(hs.size(1), device=hs.device)
        mask = (positions.unsqueeze(0) < lens.to(hs.device).unsqueeze(1)).float()

        q, k, v = self.proj_q(hs), self.proj_k(hs), self.proj_v(hs)
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))  # (B, S, S)
        # Push padded key positions to a large negative value before softmax.
        scores = scores - 1e9 * (1.0 - mask.unsqueeze(1))
        scores = self.dropout(F.softmax(scores, dim=-1))  # (B, S, S)
        h = (scores @ v).max(1)[0]  # (B, D)

        return QuesNetOutput(
            pack_embeded=y,
            embeded=hs,
            hidden=inputs.invert(h, 0)
        )

    @classmethod
    def from_config(cls, config_path, **kwargs):
        """Build a model from a JSON config file; ``kwargs`` override file values.

        ``meta`` and ``lambda_input`` are optional in the file and fall back
        to the constructor defaults.
        """
        with open(config_path, "r", encoding="utf-8") as rf:
            model_config = json.load(rf)
        model_config.update(kwargs)
        return cls(_stoi=model_config["_stoi"],
                   meta=model_config.get("meta", 'know_name'),
                   lambda_input=model_config.get("lambda_input"),
                   feat_size=model_config["feat_size"],
                   emb_size=model_config["emb_size"],
                   rnn_type=model_config["rnn_type"],
                   layers=model_config["layers"])
class QuesNetForPreTrainingOutput(ModelOutput):
    """
    Output type of [`QuesNetForPreTraining`]

    Parameters
    ----------
    loss: total pre-training loss (weighted sum of the high-level and
        low-level objectives)
    embeded: the ``embeded`` field of the wrapped QuesNet's output for the
        input batch
    hidden: Tensor of (batch_size, hidden_size) or None, the ``hidden`` field
        of the wrapped QuesNet's output
    """
    # Scalar training loss.
    loss: torch.FloatTensor = None
    # Padded per-position encoder output.
    embeded: torch.FloatTensor = None
    # Pooled per-sequence representation.
    hidden: torch.FloatTensor = None
class QuesNetForPreTraining(BaseModel):
    """QuesNet encoder plus the heads used for its pre-training objectives.

    Two losses are combined in :meth:`forward`:

    * a low-level loss that reconstructs each word / image / meta target from
      the forward half, the backward half and the concatenation of the
      encoder's bidirectional states (weighted by ``lambda_loss[0]``);
    * a high-level loss that decodes the answer from the pooled question
      feature and judges the true answer against false options (weighted by
      ``lambda_loss[1]``).

    Supports LSTM and GRU encoders via ``rnn_type``.

    Parameters
    ----------
    _stoi, pretrained_embs, pretrained_image, pretrained_meta, meta, emb_size,
    feat_size, rnn_type, lambda_input, layers
        Forwarded unchanged to :class:`QuesNet`.
    lambda_loss: list, optional
        ``[low_level_weight, high_level_weight]``; defaults to ``[1., 1.]``.
    """
    base_model_prefix = 'quesnet'

    def __init__(self, _stoi=None, pretrained_embs: np.ndarray = None, pretrained_image: nn.Module = None,
                 pretrained_meta: nn.Module = None,
                 meta='know_name',
                 emb_size=256,
                 feat_size=512,
                 rnn_type='LSTM',
                 lambda_input=None,
                 lambda_loss=None,
                 layers=4, **kwargs):
        BaseModel.__init__(self)
        self.quesnet = QuesNet(_stoi, meta=meta, pretrained_embs=pretrained_embs, pretrained_image=pretrained_image,
                               pretrained_meta=pretrained_meta, lambda_input=lambda_input, feat_size=feat_size,
                               emb_size=emb_size, rnn_type=rnn_type, layers=layers, **kwargs)
        self.vocab_size = self.quesnet.vocab_size
        self.feat_size = self.quesnet.feat_size
        self.emb_size = self.quesnet.emb_size
        self.rnn_size = self.quesnet.rnn_size

        if lambda_loss is None:
            lambda_loss = [1., 1.]
        self.lambda_loss = lambda_loss

        # Pre-training heads - HLM (low-level reconstruction).
        # Joint heads read both directions (feat_size); the l*/r* heads read
        # a single direction (rnn_size).
        self.woutput = nn.Linear(self.feat_size, self.vocab_size)
        self.ioutput = nn.Linear(self.feat_size, self.emb_size)
        self.moutput = nn.Linear(self.feat_size, self.emb_size)

        self.lwoutput = nn.Linear(self.rnn_size, self.vocab_size)
        self.lioutput = nn.Linear(self.rnn_size, self.emb_size)
        self.lmoutput = nn.Linear(self.rnn_size, self.emb_size)

        self.rwoutput = nn.Linear(self.rnn_size, self.vocab_size)
        self.rioutput = nn.Linear(self.rnn_size, self.emb_size)
        self.rmoutput = nn.Linear(self.rnn_size, self.emb_size)

        # Pre-training heads - QA (high-level answer decoding / judging).
        self.ans_decode = nn.GRU(self.emb_size, self.feat_size, layers,
                                 batch_first=True)
        self.ans_output = nn.Linear(self.feat_size, self.vocab_size)
        self.ans_judge = nn.Linear(self.feat_size, 1)

        self.dropout = nn.Dropout(0.2)

        # Built explicitly so that only serialisable hyper-parameters (and no
        # pre-trained arrays/modules) are stored.
        self.config = PretrainedConfig.from_dict({
            "_stoi": _stoi,
            "meta": meta,
            "emb_size": emb_size,
            "feat_size": feat_size,
            "rnn_type": rnn_type,
            "lambda_input": lambda_input,
            "lambda_loss": lambda_loss,
            "layers": layers,
            "architecture": 'quesnet',
        })

    def _low_level_loss(self, left_hid, right_hid, mask, heads, loss_fn, weight):
        """Mean of the joint/left/right reconstruction losses for one input kind.

        ``heads`` is ``(joint_head, left_head, right_head)`` and ``loss_fn``
        maps a prediction to a scalar loss. The result is scaled by ``weight``
        and by ``lambda_loss[0]``.
        """
        selected = mask.bool()
        left_fea = left_hid[selected]    # (k, rnn_size)
        right_fea = right_hid[selected]  # (k, rnn_size)
        joint_head, left_head, right_head = heads
        left_out = left_head(left_fea)
        right_out = right_head(right_fea)
        joint_out = joint_head(torch.cat([left_fea, right_fea], dim=1))
        loss = (loss_fn(joint_out) + loss_fn(left_out) + loss_fn(right_out)) * weight / 3
        return loss * self.lambda_loss[0]

    def forward(self, batch):
        """Compute the pre-training loss for one batch.

        Parameters
        ----------
        batch: tuple
            The 12-tuple returned by ``QuesNet.make_batch(..., pretrain=True)``.

        Returns
        -------
        QuesNetForPreTrainingOutput with the total ``loss`` and the encoder's
        ``embeded`` / ``hidden`` for the full input.
        """
        (left, right, words, ims, metas, wmask, imask, mmask,
         inputs, ans_input, ans_output, false_opt_input) = batch
        quesnet = self.quesnet

        # high-level loss
        outputs = quesnet(inputs)
        embeded = outputs.embeded
        h = outputs.hidden
        # The pooled question feature seeds every layer of the answer decoder.
        decoder_h0 = h.repeat(self.config.layers, 1, 1)

        def decode(seq):
            """Run the answer decoder over a SeqBatch and return its packed states."""
            packed = seq.packed()
            states, _ = self.ans_decode(
                PackedSequence(quesnet.we(packed.data), packed.batch_sizes), decoder_h0)
            return states.data

        states = decode(ans_input)
        floss = F.cross_entropy(self.ans_output(states), ans_output.packed().data)
        judged = self.ans_judge(states)
        floss = floss + F.binary_cross_entropy_with_logits(judged, torch.ones_like(judged))
        for false_opt in false_opt_input:
            judged = self.ans_judge(decode(false_opt))
            floss = floss + F.binary_cross_entropy_with_logits(judged, torch.zeros_like(judged))
        loss = floss * self.lambda_loss[1]

        # low-level loss
        # The first rnn_size columns are the forward direction, the rest the
        # backward direction.
        left_hid = quesnet(left).pack_embeded.data[:, :self.rnn_size]
        right_hid = quesnet(right).pack_embeded.data[:, self.rnn_size:]

        if words is not None:
            loss = loss + self._low_level_loss(
                left_hid, right_hid, wmask,
                (self.woutput, self.lwoutput, self.rwoutput),
                lambda pred: F.cross_entropy(pred, words),
                quesnet.lambda_input[0])

        if ims is not None:
            loss = loss + self._low_level_loss(
                left_hid, right_hid, imask,
                (self.ioutput, self.lioutput, self.rioutput),
                lambda pred: quesnet.ie.loss(ims, pred),
                quesnet.lambda_input[1])

        if metas is not None:
            loss = loss + self._low_level_loss(
                left_hid, right_hid, mmask,
                (self.moutput, self.lmoutput, self.rmoutput),
                lambda pred: quesnet.me.loss(metas, pred),
                quesnet.lambda_input[2])

        return QuesNetForPreTrainingOutput(
            loss=loss,
            embeded=embeded,
            hidden=h
        )

    @classmethod
    def from_config(cls, config_path, **kwargs):
        """Build a model from a JSON config file; ``kwargs`` override file values.

        The RNN type is read from ``rnn_type`` (the key this class saves).
        The legacy key ``rnn`` is used only when ``rnn_type`` is absent.
        """
        with open(config_path, "r", encoding="utf-8") as rf:
            model_config = json.load(rf)
        model_config.update(kwargs)
        rnn_type = model_config.get("rnn_type", model_config.get("rnn", 'LSTM'))
        return cls(
            _stoi=model_config["_stoi"],
            meta=model_config.get("meta", 'know_name'),
            emb_size=model_config["emb_size"],
            rnn_type=rnn_type, lambda_input=model_config["lambda_input"],
            lambda_loss=model_config["lambda_loss"], layers=model_config["layers"],
            feat_size=model_config["feat_size"]
        )