def forward(self, src, src_len, tgt, tgt_len, split='train'):
    n_batch = src.size(0)

    # Build padding masks for source and target: (B, S) and (B, T)
    src_pad_mask = get_pad_mask(src, self.pad).to(self.device)  # (B, S)
    tgt_pad_mask = get_pad_mask(tgt, self.pad).to(self.device)  # (B, T)
    pad_mask = (src_pad_mask, tgt_pad_mask)
    # Causal mask so the decoder cannot attend to future positions
    attn_mask = generate_square_subsequent_mask(self.max_tgt_len)

    src = self.embed(src.to(self.device))
    tgt = self.embed(tgt.to(self.device))

    # Permute to (seq_len, batch, hidden) for the encoder/decoder
    src = src.permute(1, 0, 2)
    tgt = tgt.permute(1, 0, 2)

    enc_output, attn = self.encoder(src, pad_mask=pad_mask)
    output, attn = self.decoder(tgt, enc_output, pad_mask=pad_mask, cross_mask=attn_mask)

    # Permute back to (batch, seq_len, hidden) and rescale
    output = output.permute(1, 0, 2) * (self.scale ** 0.5)
    logits = F.log_softmax(self.final(output), dim=-1)
    generations = torch.argmax(logits, dim=-1)
    return logits, generations
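# Note: get_pad_mask is not defined in these snippets. A minimal sketch of what it
# likely does, assuming the usual convention that padded positions are marked True
# (the (B, S)/(B, T) shape comments above and the `== pad_idx` masks in the batch
# functions below are the only evidence, so treat this as an assumption):
def get_pad_mask(seq, pad_idx):
    # seq: (B, L) token ids; returns a boolean mask of the same shape,
    # True wherever the token equals the padding index
    return seq == pad_idx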
def generate_lm_batch(batch, pad_idx, in_chat=False):
    if not in_chat:
        batch = torch.tensor(batch).T
    else:
        batch = pad_sequence(list(map(torch.tensor, batch)), padding_value=pad_idx)
    # shift by one position for language modelling: x is the input, y the target
    x = batch[:-1]
    y = batch[1:]
    x_mask = utils.generate_square_subsequent_mask(x.shape[0])
    x_pad_mask = (x == pad_idx).T
    return LMFeature(x=x, y=y, x_mask=x_mask, x_pad_mask=x_pad_mask)
def generate_lm_batch(batch, vocab, is_mlm, in_chat=False):
    pad_idx = vocab.stoi(utils.PAD)
    mask_idx = vocab.stoi(utils.MASK)
    if not in_chat:
        x_mlm = None
        x_mlm_pad_mask = None
        batch = torch.tensor(batch).T
    elif not is_mlm:
        x_mlm = None
        x_mlm_pad_mask = None
        batch = pad_sequence(list(map(torch.tensor, batch)), padding_value=pad_idx)
    else:
        x_mlm = []
        for v in batch:
            l = len(v)
            if l == 3:  # bos + w + eos: too short to mask
                x_mlm.append(v)
                continue
            mask = 1
            if l > 5:
                # mask: length of the span replaced by a single mask token;
                # 2 with prob 99/100, 0 with prob 1/100
                mask = min(2, random.randint(0, 99))
                if mask == 1:
                    mask = 2
            # never mask the bos/eos tokens
            start = random.randint(1, l - mask - 1)
            x_mlm.append(v[0:start] + [mask_idx] + v[start + mask:])
        batch = pad_sequence(list(map(torch.tensor, batch)), padding_value=pad_idx)
        x_mlm = pad_sequence(list(map(torch.tensor, x_mlm)), padding_value=pad_idx)
        x_mlm_pad_mask = (x_mlm == pad_idx).T
    x = batch[:-1]
    y = batch[1:]
    x_mask = utils.generate_square_subsequent_mask(x.shape[0])
    x_pad_mask = (x == pad_idx).T
    return LMFeature(x=x, y=y, x_mlm=x_mlm, x_mask=x_mask, x_pad_mask=x_pad_mask,
                     x_mlm_pad_mask=x_mlm_pad_mask)
def generate_batch(batch, pad_idx):
    context, segs, personas_no_tag, tags, resp, persona, lm = zip(*batch)
    fn = lambda x: list(map(torch.tensor, x))

    context_pad = pad_sequence(fn(context), padding_value=pad_idx)
    segs_pad = pad_sequence(fn(segs), padding_value=pad_idx)

    tags = itertools.chain(*tags)
    tags_pad = pad_sequence(fn(tags), padding_value=pad_idx)
    tags_pad = tags_pad.view(-1, 2, int(tags_pad.shape[1] / 2)).transpose(1, 0)

    resp_pad = pad_sequence(fn(resp), padding_value=pad_idx)

    src_pad_mask = (context_pad == pad_idx).T
    tgt_pad_mask = (resp_pad == pad_idx).T
    tgt_mask = utils.generate_square_subsequent_mask(resp_pad.shape[0])

    persona_pad = pad_sequence(fn(persona), padding_value=pad_idx)
    persona_pad_mask = (persona_pad == pad_idx).T

    # batch_size X n_persona X 2
    tmp = list(map(lambda x: pad_sequence(fn(x), padding_value=pad_idx),
                   personas_no_tag))
    # n_persona X batch_size X 2 --> 2 X n_persona X batch_size
    personas_no_tag_pad = pad_sequence(tmp, padding_value=pad_idx).permute(2, 0, 1)

    lm = generate_lm_batch(lm, pad_idx, in_chat=True)

    return ChatFeature(
        context=context_pad,
        segs=segs_pad,
        personas_no_tag=personas_no_tag_pad,
        tags=tags_pad,
        resp=resp_pad,
        persona=persona_pad,
        context_pad_mask=src_pad_mask,
        resp_mask=tgt_mask,
        resp_pad_mask=tgt_pad_mask,
        persona_pad_mask=persona_pad_mask,
        lm=lm,
    )
def generate_batch(batch, vocab, persona_vocab, is_mlm):
    pad_idx = vocab.stoi(utils.PAD)
    persona_pad_idx = pad_idx
    if persona_vocab is not None:
        persona_pad_idx = persona_vocab.stoi(utils.PAD)

    context, segs, personas_no_tag, tags, resp, persona, lm = zip(*batch)

    post = []
    cls_idx = vocab.stoi(utils.CLS)
    sep_idx = vocab.stoi(utils.SEP)
    for v in context:
        post_start = list(reversed(v)).index(sep_idx)
        if _USE_BERT_FEATURE:
            # cls_idx for bert-like sentence rep, xlnet is last token
            post.append([cls_idx] + v[-post_start:])
        else:
            post.append(v[-post_start:-1])

    fn = lambda x: list(map(torch.tensor, x))
    context_pad = pad_sequence(fn(context), padding_value=pad_idx)
    post_pad = pad_sequence(fn(post), padding_value=pad_idx)
    segs_pad = pad_sequence(fn(segs), padding_value=pad_idx)

    tags = itertools.chain(*tags)
    tags_pad = pad_sequence(fn(tags), padding_value=persona_pad_idx)
    tags_pad = tags_pad.view(-1, 2, int(tags_pad.shape[1] / 2)).transpose(1, 0)

    resp_pad = pad_sequence(fn(resp), padding_value=pad_idx)
    persona_pad = pad_sequence(fn(persona), padding_value=persona_pad_idx)

    context_pad_mask = (context_pad == pad_idx).T
    post_pad_mask = (post_pad == pad_idx).T
    tags_pad_mask = (tags_pad == persona_pad_idx).T
    resp_mask = utils.generate_square_subsequent_mask(resp_pad.shape[0])
    resp_pad_mask = (resp_pad == pad_idx).T
    persona_pad_mask = (persona_pad == persona_pad_idx).T

    # batch_size X n_persona X 2
    tmp = list(map(lambda x: pad_sequence(fn(x), padding_value=persona_pad_idx),
                   personas_no_tag))
    # n_persona X batch_size X 2 --> 2 X n_persona X batch_size
    personas_no_tag_pad = pad_sequence(tmp, padding_value=persona_pad_idx).permute(2, 0, 1)
    personas_no_tag_pad_mask = (personas_no_tag_pad == persona_pad_idx).T

    lm = generate_lm_batch(lm, vocab, is_mlm, in_chat=True)

    return ChatFeature(
        context=context_pad,
        post=post_pad,
        segs=segs_pad,
        personas_no_tag=personas_no_tag_pad,
        tags=tags_pad,
        resp=resp_pad,
        persona=persona_pad,
        context_pad_mask=context_pad_mask,
        post_pad_mask=post_pad_mask,
        resp_mask=resp_mask,
        resp_pad_mask=resp_pad_mask,
        persona_pad_mask=persona_pad_mask,
        personas_no_tag_pad_mask=personas_no_tag_pad_mask,
        tags_pad_mask=tags_pad_mask,
        lm=lm,
    )
def forward(
    self,
    input: torch.Tensor,
    input_mask: torch.Tensor,
    node_hidden: torch.Tensor,
    prev_action_hidden: torch.Tensor,
    prev_action_mask: torch.Tensor,
) -> torch.Tensor:
    """
    input: (batch, input_seq_len, hidden_dim)
    input_mask: (batch, input_seq_len)
    node_hidden: (batch, num_node, hidden_dim)
    prev_action_hidden: (batch, prev_action_len, hidden_dim)
    prev_action_mask: (batch, prev_action_len)

    output: (batch, input_seq_len, hidden_dim)
    """
    # calculate attention mask for decoding
    # this is the mask that prevents MultiheadAttention
    # from attending to future values
    input_seq_len = input.size(1)
    attn_mask = generate_square_subsequent_mask(input_seq_len).to(
        input.device)
    # (input_seq_len, input_seq_len)

    # add the positional encodings
    pos_encoded_input = self.pos_encoder(input)

    # self attention layer
    input_residual = pos_encoded_input
    # MultiheadAttention expects batch dim to be 1 for q, k, v
    # but 0 for key_padding_mask, so we need to transpose
    transposed_pos_encoded_input = pos_encoded_input.transpose(0, 1)
    input_attn, _ = self.self_attn(
        transposed_pos_encoded_input,
        transposed_pos_encoded_input,
        transposed_pos_encoded_input,
        key_padding_mask=input_mask == 0,
        attn_mask=attn_mask,
    )
    input_attn = input_attn.transpose(0, 1)
    input_attn *= input_mask.unsqueeze(-1)
    input_attn += input_residual
    # (batch, input_seq_len, hidden_dim)

    # calculate self attention for the nodes and previous action
    # strictly speaking, we should calculate attention masks for these
    # based on input_mask, but due to this bug:
    # https://github.com/pytorch/pytorch/issues/41508
    # it returns nan's if we apply attention masks. So let's just skip it.
    # It's OK, b/c we apply input_mask when we combine these.

    # apply layer norm to the input self attention output to calculate the query
    query = self.self_attn_layer_norm(input_attn).transpose(0, 1)
    # (input_seq_len, batch, hidden_dim)

    # self attention for the nodes
    # no key_padding_mask, since we use all the nodes
    # (batch * num_heads, input_seq_len, num_node)
    transposed_node_hidden = node_hidden.transpose(0, 1)
    node_attn, _ = self.node_attn(query, transposed_node_hidden,
                                  transposed_node_hidden)
    node_attn = node_attn.transpose(0, 1)
    # (batch, input_seq_len, hidden_dim)

    # self attention for the previous action
    # key_padding_mask is from prev_action_mask
    # (batch * num_heads, input_seq_len, prev_action_len)
    transposed_prev_action_hidden = prev_action_hidden.transpose(0, 1)
    prev_action_attn, _ = self.prev_action_attn(
        query,
        transposed_prev_action_hidden,
        transposed_prev_action_hidden,
        key_padding_mask=prev_action_mask == 0,
    )
    prev_action_attn = prev_action_attn.transpose(0, 1)
    # (batch, input_seq_len, hidden_dim)

    # combine self attention for the previous action and nodes with
    # input self attention
    combined_self_attn = self.combine_node_prev_action(
        torch.cat([prev_action_attn, node_attn], dim=-1))
    combined_self_attn *= input_mask.unsqueeze(-1)
    combined_self_attn += input_attn
    # (batch, input_seq_len, hidden_dim)

    # linear layer
    output = self.linear_layer_norm(combined_self_attn)
    output = self.linear_layers(output)
    output += combined_self_attn
    # (batch, input_seq_len, hidden_dim)

    return output
def test_generate_subsequent_mask(size):
    mask = generate_square_subsequent_mask(size)
    # assert that the sum of tril and triu is the original mask
    assert mask.equal(torch.tril(mask) + torch.triu(mask, diagonal=1))
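# Note: generate_square_subsequent_mask itself is not shown in these snippets. A
# minimal sketch following the standard PyTorch recipe (an assumption, not the
# verbatim implementation used above): 0.0 on and below the diagonal, -inf above,
# so position i can only attend to positions <= i. The test above relies on exactly
# this tril/triu structure.
import torch

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # (sz, sz) float mask: 0.0 where attention is allowed, -inf where it is blocked
    mask = torch.triu(torch.ones(sz, sz), diagonal=1)
    return mask.masked_fill(mask == 1, float('-inf'))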