import numpy as np
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_pos, device):
        super(PositionalEncoding, self).__init__()
        self.device = device
        # Learnable positional embedding table: one d_model-dimensional vector per position
        self.pos_embedding = nn.Embedding(max_pos, d_model)

    def forward(self, inputs):
        seq_len = inputs.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=self.device)
        # [seq_len] -> [batch_size, seq_len]
        pos = pos.unsqueeze(0).expand_as(inputs)
        return self.pos_embedding(pos)
def get_attn_subsequence_mask(seq, device):
    # Causal (look-ahead) mask: 1 marks future positions that must not be attended to
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    subsequence_mask = subsequence_mask.to(device)
    return subsequence_mask
def get_attn_pad_mask(attention_mask):
    # Padding mask: True where attention_mask == 0, i.e. at padded positions
    batch_size, len_seq = attention_mask.size()
    attention_mask = attention_mask.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_seq]
    return attention_mask.expand(batch_size, len_seq, len_seq)  # [batch_size, len_seq, len_seq]
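# Illustration only (not part of the original model code): a minimal sketch of what the
# two mask helpers return for a toy batch, combined the same way Decoder.forward does below.
def _mask_demo():
    device = torch.device("cpu")
    ids = torch.tensor([[5, 8, 0]])    # [batch_size=1, seq_len=3] token ids; last token is padding
    attn = torch.tensor([[1, 1, 0]])   # attention_mask: 1 = real token, 0 = padding
    causal = get_attn_subsequence_mask(ids, device)  # 1 strictly above the diagonal
    pad = get_attn_pad_mask(attn)                    # True in the padding column
    return torch.gt(pad + causal, 0)                 # union of both masks, shape [1, 3, 3]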
class Decoder(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_k, d_v, vocab_size, max_pos, n_layers, device):
        super(Decoder, self).__init__()
        self.device = device
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_pos, device)
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, n_heads, d_ff, d_k, d_v) for _ in range(n_layers)]
        )

    def forward(self, inputs, attention_mask):
        # Token embedding plus positional embedding
        outputs = self.embedding(inputs) + self.pos_encoding(inputs)
        # Causal mask so each position can only attend to itself and earlier positions
        subsequence_mask = get_attn_subsequence_mask(inputs, self.device)
        if attention_mask is not None:
            # Merge the padding mask with the causal mask
            attention_mask = get_attn_pad_mask(attention_mask)
            attention_mask = torch.gt((attention_mask + subsequence_mask), 0)
        else:
            attention_mask = subsequence_mask.bool()
        # Run every decoder layer and collect its attention weights
        self_attns = []
        for layer in self.layers:
            # outputs: [batch_size, seq_len, d_model]
            # self_attn: [batch_size, n_heads, seq_len, seq_len]
            outputs, self_attn = layer(outputs, attention_mask)
            self_attns.append(self_attn)
        return outputs, self_attns
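A quick shape check can be run as below; the hyperparameter values are arbitrary, and the sketch assumes the DecoderLayer implementation from the rest of this post is already defined:

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    decoder = Decoder(d_model=64, n_heads=4, d_ff=256, d_k=16, d_v=16,
                      vocab_size=1000, max_pos=128, n_layers=2, device=device).to(device)
    inputs = torch.randint(0, 1000, (2, 10), device=device)              # [batch_size, seq_len] token ids
    attention_mask = torch.ones(2, 10, dtype=torch.long, device=device)  # no padding in this toy batch
    outputs, self_attns = decoder(inputs, attention_mask)
    print(outputs.shape)        # expected: [2, 10, 64]   -> [batch_size, seq_len, d_model]
    print(self_attns[0].shape)  # expected: [2, 4, 10, 10] -> [batch_size, n_heads, seq_len, seq_len]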