使用DisenQ向量化容器

[2]:
from EduNLP.I2V import DisenQ, get_pretrained_i2v

# 设置你的数据路径和输出路径
BASE_DIR = '../..' # "/your/own/base/path"

data_dir = f"{BASE_DIR}/static/test_data"
output_dir = f"{BASE_DIR}/examples/test_model/disenq"

使用I2V加载本地模型

[3]:
tokenizer_kwargs = {
    "tokenizer_config_dir": output_dir,
}
i2v = DisenQ('disenq', 'disenq', output_dir, tokenizer_kwargs=tokenizer_kwargs, device="cpu")
[4]:
test_items = [
    {"content": "10 米 的 (2/5) = 多少 米 的 (1/2),有 公 式"},
    {"content": "10 米 的 (2/5) = 多少 米 的 (1/2),有 公 式 , 如 图 , 若 $x,y$ 满 足 约 束 条 件 公 式"},
]

t_vec = i2v.infer_token_vector(test_items, key=lambda x: x["content"])
i_vec_k = i2v.infer_item_vector(test_items, key=lambda x: x["content"], vector_type="k")
i_vec_i = i2v.infer_item_vector(test_items, key=lambda x: x["content"], vector_type="i")

print(t_vec.shape) # == torch.Size([2, 23, 128])
print(i_vec_k.shape) # == torch.Size([2, 128])
print(i_vec_i.shape) # == torch.Size([2, 128])

t_vec = i2v.infer_token_vector(test_items[0], key=lambda x: x["content"])
i_vec_k = i2v.infer_item_vector(test_items[0], key=lambda x: x["content"], vector_type="k")
i_vec_i = i2v.infer_item_vector(test_items, key=lambda x: x["content"], vector_type="i")

print(t_vec.shape) # == torch.Size([1, 11, 128])
print(i_vec_k.shape) # == torch.Size([1, 128])
print(i_vec_i.shape) # == torch.Size([2, 128])
torch.Size([2, 23, 128])
torch.Size([2, 128])
torch.Size([2, 128])
torch.Size([1, 11, 128])
torch.Size([1, 128])
torch.Size([2, 128])

使用get_pretrained_i2v加载公开模型

[5]:
# 获取公开的预训练模型
pretrained_dir = f"{BASE_DIR}/examples/test_model/disenq"
i2v = get_pretrained_i2v("disenq_test_128", model_dir=pretrained_dir)
EduNLP, INFO model_dir: ..\..\examples\test_model\disenq\disenq_test_128
EduNLP, INFO Use pretrained t2v model disenq_test_128
downloader, INFO http://base.ustc.edu.cn/data/model_zoo/modelhub/disenq_public/1/disenq_test_128.zip is saved as ..\..\examples\test_model\disenq\disenq_test_128.zip
downloader, INFO file existed, skipped
[6]:
test_items = [
    "有 公 式 $\\FormFigureID{1}$ ,如 图 $\\FigureID{088f15ea-xxx}$",
    "已知 圆 $x^{2}+y^{2}-6 x=0$ ,过 点 (1,2) 的 直 线 被 该 圆 所 截 得 的 弦 的 长度 的 最小 值 为"
]

# 获得句表征和词表征
i_vec, t_vec = i2v(test_items)
print(i_vec[0].shape, i_vec[1].shape)
print(t_vec.shape)
print()

i_vec_k, t_vec = i2v(test_items, vector_type="k")
print(i_vec_k.shape)
print(t_vec.shape)
print()

# 获得指定表征
i_vec_k = i2v.infer_item_vector(test_items, vector_type="k")
i_vec_i = i2v.infer_item_vector(test_items, vector_type="i")
t_vec = i2v.infer_token_vector(test_items)

print(i_vec_k.shape)
print(i_vec_i.shape)
print(t_vec.shape)
torch.Size([2, 128]) torch.Size([2, 128])
torch.Size([2, 24, 128])

torch.Size([2, 128])
torch.Size([2, 24, 128])

torch.Size([2, 128])
torch.Size([2, 128])
torch.Size([2, 24, 128])