一、目录
- 代码
二、实现
- 代码
import torch.nn as nn
import math
import torch
from typing import Optional,Any
class SinusoidalPositionalEmbedding(nn.Embedding):
    """
    Fixed (non-trainable) sinusoidal positional embeddings of any length.

    The lookup table is computed deterministically by :meth:`get_embedding` and
    stored as a frozen ``weight`` parameter (``requires_grad`` is False).
    The row at ``padding_idx`` is all zeros and padding symbols are mapped to
    that row, so they are effectively ignored.
    The table is automatically rebuilt in :meth:`forward` if more positions
    are needed than it currently holds.
    """

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int):
        """
        Args:
            num_positions: number of rows in the initial lookup table.
            embedding_dim: size of each embedding vector.
            padding_idx: index of the padding symbol; :meth:`forward` does
                arithmetic on it, so it must be an integer there.
        """
        # make_weight() calls nn.Embedding.__init__ itself on first use.
        self.make_weight(num_positions, embedding_dim, padding_idx)

    def make_weight(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int]) -> None:
        """(Re)build the sinusoidal table and install it as the frozen weight."""
        weight = self.get_embedding(num_positions, embedding_dim, padding_idx)
        if not hasattr(self, "weight"):
            # Called from __init__: the parent module is not initialised yet,
            # so initialise it here with the precomputed table.
            super().__init__(num_positions, embedding_dim, padding_idx, _weight=weight)
        else:
            # Called from forward: keep the dtype and device of the existing param.
            weight = weight.to(dtype=self.weight.dtype, device=self.weight.device)
            self.weight = nn.Parameter(weight)
            # Keep the bookkeeping attribute in sync with the new table size.
            self.num_embeddings = num_positions
        self.weight.detach_()
        self.weight.requires_grad = False

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int]) -> torch.Tensor:
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Args:
            num_embeddings: number of positions (rows) to generate.
            embedding_dim: width of each row; an odd width gets one trailing
                zero column.
            padding_idx: row to zero out, or None to leave every row as is.

        Returns:
            Float tensor of shape ``(num_embeddings, embedding_dim)`` laid out
            as ``[sin | cos | optional zero pad]`` along the last dimension.
        """
        half_dim = embedding_dim // 2
        # Guard the denominator: with half_dim == 1 (embedding_dim 2 or 3) the
        # plain `half_dim - 1` is zero. The guarded value does not affect the
        # result there, because the only frequency index is 0 and exp(0) == 1.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        inv_freq = torch.exp(torch.arange(half_dim, dtype=torch.float) * -log_step)
        angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * inv_freq.unsqueeze(0)
        # The concatenation is already (num_embeddings, 2 * half_dim); no
        # reshape is needed, which also keeps zero-width cases valid.
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    @staticmethod
    def make_positions(tensor, padding_idx: int):
        """
        Replace non-padding symbols with their position numbers.

        Position numbers begin at padding_idx+1. Padding symbols are ignored
        (they map to ``padding_idx`` itself).
        """
        # The series of casts and type-conversions here are carefully
        # balanced to both work with ONNX export and XLA. In particular XLA
        # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
        # how to handle the dtype kwarg in cumsum.
        mask = tensor.ne(padding_idx).int()
        return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx

    def forward(
        self,
        input,
        incremental_state: Optional[Any] = None,
        timestep=None,
    ):
        """
        Look up positional embeddings for a batch of token ids.

        Input is expected to be of size [bsz x seqlen]. ``incremental_state``
        and ``timestep`` are accepted but not used by this implementation.
        """
        _, seq_len = input.shape[:2]
        # Positions start at padding_idx + 1, so this is the table size needed
        # to index the last position of the longest possible sequence.
        max_pos = self.padding_idx + 1 + seq_len
        if max_pos > self.weight.size(0):
            # expand embeddings if needed
            self.make_weight(max_pos, self.embedding_dim, self.padding_idx)
        positions = self.make_positions(input, self.padding_idx)
        return super().forward(positions)
if __name__ == '__main__':
    # Smoke test: embed a random batch of token ids and print the output shape.
    max_position_embeddings = 1024
    embed_dim = 712
    padding_idx = 102
    # Position numbers start at padding_idx + 1, so the table needs that many
    # extra rows on top of the maximum sequence length.
    table_size = max_position_embeddings + padding_idx + 1
    embed_positions = SinusoidalPositionalEmbedding(table_size, embed_dim, padding_idx)
    # Ids are drawn from [0, 10), so none of them equals padding_idx here.
    input_ids = torch.randint(0, 10, size=(2, 10))
    positions = embed_positions(input_ids)
    print(positions.shape)
文章来源地址:https://www.toymoban.com/news/detail-783355.html
文章来源:https://www.toymoban.com/news/detail-783355.html
到了这里,关于SinusoidalPositionalEmbedding/tensor2tensor中实现的绝对位置编码的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!