import torch
def sequence_mask(lengths, max_len=None):
    """Return a boolean mask covering the first ``lengths[...]`` positions of each entry.

    Same as ``tf.sequence_mask``: ``mask[..., j]`` is ``True`` exactly when
    ``j < lengths[...]``.

    Parameters
    ----------
    lengths : torch.Tensor
        Integer tensor of any shape (including 0-d) holding the number of
        valid positions per entry. A value larger than ``max_len`` simply
        produces a row that is entirely ``True``.
    max_len : int, optional
        Size of the last dimension of the returned mask. Defaults to the
        maximum value in ``lengths`` (0 when ``lengths`` is empty). An
        explicit ``0`` is honoured and yields an empty last dimension.

    Returns
    -------
    torch.Tensor
        Boolean tensor of shape ``lengths.shape + (max_len,)`` on the same
        device as ``lengths``.

    Examples
    --------
    >>> sequence_mask(torch.tensor([1, 3, 2]), 5)
    tensor([[ True, False, False, False, False],
            [ True,  True,  True, False, False],
            [ True,  True, False, False, False]])
    >>> sequence_mask(torch.tensor([[1, 3], [2, 0]]))
    tensor([[[ True, False, False],
             [ True,  True,  True]],
    <BLANKLINE>
            [[ True,  True, False],
             [False, False, False]]])
    """
    if max_len is None:
        # `lengths.max()` raises on an empty tensor, so fall back to width 0.
        max_len = int(lengths.max()) if lengths.numel() > 0 else 0
    # Positions 0..max_len-1 on the same device and with the same dtype as
    # `lengths`, so the comparison needs no cross-device copy.
    positions = torch.arange(max_len, device=lengths.device).type_as(lengths)
    # Broadcasting (max_len,) against lengths[..., None] produces
    # lengths.shape + (max_len,) directly, for any rank of `lengths`.
    return positions < lengths.unsqueeze(-1)
def gather_nd(params, indices):
    """Gather slices of ``params`` addressed by the last axis of ``indices``.

    PyTorch counterpart of ``tf.gather_nd`` (without ``batch_dims``). Let
    ``k = indices.shape[-1]``. Each length-``k`` index vector along the last
    axis of ``indices`` selects ``params[i_0, ..., i_{k-1}]``. That selection
    is a sub-tensor when ``k < params.ndim`` and a scalar when
    ``k == params.ndim``.

    Parameters
    ----------
    params : torch.Tensor
        Tensor to gather from.
    indices : torch.Tensor
        Integer tensor of shape ``(..., k)`` with ``1 <= k <= params.ndim``.
        Negative entries count from the end, as in ordinary tensor indexing.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``indices.shape[:-1] + params.shape[k:]``.

    Examples
    --------
    >>> gather_nd(
    ...     params=torch.tensor([[1, 2, 3],
    ...                          [4, 5, 6]]),
    ...     indices=torch.tensor([[1],
    ...                           [0]]))
    tensor([[4, 5, 6],
            [1, 2, 3]])
    >>> gather_nd(torch.tensor([[1, 2, 3], [4, 5, 6]]),
    ...           torch.tensor([[0, 1], [1, 2]]))
    tensor([2, 6])
    """
    depth = indices.shape[-1]
    out_shape = indices.shape[:-1] + params.shape[depth:]
    # Flatten the leading index dims to (N, depth). `reshape` (unlike `view`)
    # also handles non-contiguous `indices`. Move to params' device/long dtype
    # so advanced indexing accepts the index tensors.
    flat = indices.reshape(-1, depth).to(device=params.device, dtype=torch.long)
    # One (N,)-tensor per indexed axis gathers all N selections in a single
    # vectorised call, giving shape (N, *params.shape[depth:]). Unlike
    # concatenating per-row results, this also works when each selection is
    # a scalar (depth == params.ndim) and when N == 0.
    gathered = params[flat.unbind(dim=1)]
    return gathered.reshape(out_shape)