def forward(self, tgt, memory_bank, memory_bank_t, step=None, **kwargs):
    """Decode, possibly stepwise."""
    if step == 0:
        self._init_cache(memory_bank)

    tgt_words = tgt[:, :, 0].transpose(0, 1)

    emb = self.embeddings(tgt, step=step)
    assert emb.dim() == 3  # len x batch x embedding_dim

    output = emb.transpose(0, 1).contiguous()
    src_memory_bank = memory_bank.transpose(0, 1).contiguous()
    # ---- second memory bank (tt) ----
    tt_memory_bank = memory_bank_t.transpose(0, 1).contiguous()

    pad_idx = self.embeddings.word_padding_idx
    src_lens = kwargs["memory_lengths"]
    src_max_len = self.state["src"].shape[0]
    src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
    # ---- padding mask for the second memory bank (tt) ----
    tt_lens = kwargs["memory_lengths_t"]
    tt_max_len = self.state["tt"].shape[0]
    tt_pad_mask = ~sequence_mask(tt_lens, tt_max_len).unsqueeze(1)

    tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

    with_align = kwargs.pop('with_align', False)
    attn_aligns = []

    for i, layer in enumerate(self.transformer_layers):
        layer_cache = self.state["cache"]["layer_{}".format(i)] \
            if step is not None else None
        output, attn, attn_align = layer(
            output,
            src_memory_bank,
            src_pad_mask,
            tt_memory_bank,
            tt_pad_mask,
            tgt_pad_mask,
            layer_cache=layer_cache,
            step=step,
            with_align=with_align)
        if attn_align is not None:
            attn_aligns.append(attn_align)

    output = self.layer_norm(output)
    dec_outs = output.transpose(0, 1).contiguous()
    attn = attn.transpose(0, 1).contiguous()

    attns = {"std": attn}
    if self._copy:
        attns["copy"] = attn
    if with_align:
        attns["align"] = attn_aligns[self.alignment_layer]  # `(B, Q, K)`
        # attns["align"] = torch.stack(attn_aligns, 0).mean(0)  # All avg

    # TODO change the way attns is returned dict => list or tuple (onnx)
    return dec_outs, attns
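# All of the snippets in this file rely on a `sequence_mask` helper. Below is a
# minimal sketch, assuming the usual OpenNMT-py semantics (True at valid
# positions, so callers negate it with `~` to obtain a padding mask). This is
# an illustration only, not necessarily the exact implementation used here.
import torch


def sequence_mask_sketch(lengths, max_len=None):
    """Return a [batch, max_len] bool mask, True where position < length."""
    if max_len is None:
        max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)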
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    # here we do not need to calculate the attentional output (align),
    # because the answer vector is already an averaged representation
    if source.dim() == 2:
        source = source.unsqueeze(1)

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(~mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    return c.squeeze(1), align_vectors
def forward(self, source, memory_bank, memory_lengths=None,
            memory_turns=None, coverage=None):
    # here we implement a hierarchical attention
    if source.dim() == 2:
        source = source.unsqueeze(1)

    batch, source_tl, source_wl, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)

    # word-level attention
    word_align = self.word_score(source,
                                 memory_bank.contiguous().view(batch, -1, dim))
    # reshape align (b, 1, tl * wl) -> (b * tl, 1, wl)
    word_align = word_align.view(batch * source_tl, 1, source_wl)

    if memory_lengths is not None:
        word_mask = sequence_mask_herd(memory_lengths.view(-1),
                                       max_len=word_align.size(-1))
        word_mask = word_mask.unsqueeze(1)  # Make it broadcastable.
        word_align.masked_fill_(~word_mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        word_align_vectors = F.softmax(
            word_align.view(batch * source_tl, source_wl), -1)
    else:
        word_align_vectors = sparsemax(
            word_align.view(batch * source_tl, source_wl), -1)

    # mask out sentences that are entirely padding
    sent_pad_mask = memory_lengths.view(-1).eq(0).unsqueeze(1)
    word_align_vectors = torch.mul(
        word_align_vectors,
        (~sent_pad_mask).type_as(word_align_vectors))
    word_align_vectors = word_align_vectors.view(
        batch * source_tl, target_l, source_wl)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    cw = torch.bmm(word_align_vectors,
                   memory_bank.view(batch * source_tl, source_wl, -1))
    cw = cw.view(batch, source_tl, -1)

    # concat_cw = torch.cat([cw, source.repeat(1, source_tl, 1)], 2).view(batch*source_tl, -1)
    # attn_hw = self.word_linear_out(concat_cw).view(batch, source_tl, -1)
    # attn_hw = torch.tanh(attn_hw)

    # turn-level attention
    turn_align = self.turn_score(source, cw)

    if memory_turns is not None:
        turn_mask = sequence_mask(memory_turns, max_len=turn_align.size(-1))
        turn_mask = turn_mask.unsqueeze(1)  # Make it broadcastable.
        turn_align.masked_fill_(~turn_mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        turn_align_vectors = F.softmax(
            turn_align.view(batch * target_l, source_tl), -1)
    else:
        turn_align_vectors = sparsemax(
            turn_align.view(batch * target_l, source_tl), -1)
    turn_align_vectors = turn_align_vectors.view(batch, target_l, source_tl)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    ct = torch.bmm(turn_align_vectors, cw)

    return ct.squeeze(1), None
def forward(self, src, lengths=None):
    """See :func:`EncoderBase.forward()`"""
    self._check_args(src, lengths)

    emb = self.embeddings(src)
    out = emb.transpose(0, 1).contiguous()

    '''
    indicators = []
    for idx, l in enumerate(lengths):
        src_item = src[:l, idx, 0]
        src_seq = str(l.item()) + '_' + ' '.join([self.vocab[i] for i in src_item])
        if src_seq not in self.reaction_atoms:
            print('error')
        reaction_atom_indicator = self.reaction_atoms[src_seq].tolist()
        reaction_atom_indicator.extend([0] * (out.shape[1] - l.item()))
        indicators.append(reaction_atom_indicator)
    indicators = np.array(indicators, dtype='float32')
    indicators = torch.from_numpy(indicators).float().cuda()
    indicators = indicators.unsqueeze(2)
    out = torch.cat((out, indicators), dim=2)
    out = self.linear(out)
    out = self.linear_dropout(out)
    '''

    mask = ~sequence_mask(lengths).unsqueeze(1)
    # Run the forward pass of every layer of the transformer.
    for layer in self.transformer:
        out = layer(out, mask)
    out = self.layer_norm(out)

    # User added to incorporate reaction atom indicator
    out = out.transpose(0, 1).contiguous()
    return emb, out, lengths
def forward(self, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
      memory_lengths (LongTensor): the source context lengths ``(batch,)``
      coverage (FloatTensor): None (not supported yet)

    Returns:
      FloatTensor: context vector ``(batch, dim)``, computed by attending
      over ``memory_bank`` with the module's own query ``self.source``.
    """
    batch, source_l, dim = memory_bank.size()

    source = self.source
    source = source.expand(batch, -1)
    source = source.unsqueeze(1)

    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(~mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)  # (batch, target_l, dim)
    c = c.mean(dim=1)  # (batch, dim)

    # Check output sizes
    batch_, dim_ = c.size()
    aeq(batch, batch_)
    aeq(dim, dim_)

    return c
def forward(self, src, grh, lengths=None):
    """See :func:`EncoderBase.forward()`"""
    self._check_args(src, lengths)
    assert src.size(0) == grh.size(-1), "srclen != grh_n"

    emb = self.embeddings(src)  # batch, srclen, dim

    out = emb.transpose(0, 1).contiguous()
    mask = ~sequence_mask(lengths).unsqueeze(1)
    # Run the forward pass of every layer of the transformer.
    out_list = [out]
    if self.aggregation == "dense":
        for i in range(len(self.transformer)):
            out = self.dense_linear[i](torch.cat(out_list, dim=-1))
            out = self.transformer[i](out, grh, mask)
            out_list.append(out)
        aggregate_out = torch.cat(out_list, dim=-1)
        aggregate_out = self.aggregate_layer_norm(aggregate_out)
        aggregate_out = self.aggregate_linear(aggregate_out)
        return emb, aggregate_out.transpose(0, 1).contiguous(), lengths
    else:
        for layer in self.transformer:
            out = layer(out, grh, mask)
            out_list.append(out)
        out = self.layer_norm(out)
        if self.aggregation == "jump":
            aggregate_out = torch.cat(out_list, dim=-1)
            aggregate_out = self.aggregate_layer_norm(aggregate_out)
            aggregate_out = self.aggregate_linear(aggregate_out)
            return emb, aggregate_out.transpose(0, 1).contiguous(), lengths
        return emb, out.transpose(0, 1).contiguous(), lengths
def forward(self, src, lengths=None, batch=None):
    """See :func:`EncoderBase.forward()`"""
    self._check_args(src, lengths)

    emb = self.embeddings(src)
    batch_size = emb.size()[1]
    seq_len = emb.size()[0]

    out = emb.transpose(0, 1).contiguous()
    mask = ~sequence_mask(lengths).unsqueeze(1)

    batch_geometric = get_embs_graph(batch, out)

    if self.decoder_dim != self.d_model:
        out = self.linear_before(out)

    for i, layer in enumerate(self.transformer):
        out, attns = layer(out, mask)
        out_gnn = out.view((batch_size * seq_len, self.d_model))
        memory_bank = self.gnns[i](out_gnn,
                                   batch_geometric.edge_index,
                                   edge_type=batch_geometric.y)
        out = self.rnn(memory_bank, out_gnn)
        out = out.view((batch_size, seq_len, self.d_model))

    out = self.layer_norm(out)
    if self.decoder_dim != self.d_model:
        out = self.linear(out)

    attn_weight = F.softmax(self.FC(out), dim=1)
    outputs2 = attn_weight * out
    outputs2 = self.FC2(outputs2.sum(dim=1))

    return emb, out.transpose(0, 1).contiguous(), lengths, outputs2
def forward(self, src, lengths=None, segment_count=None, padding_value=None):
    self._check_args(src, lengths)

    encoder_final, memory_bank, lengths = self.rnn_encoder(
        src, lengths, enforce_sorted=False)

    context = self.attn(memory_bank.transpose(0, 1), memory_lengths=lengths)
    segment_count_, dim_ = context.size()
    assert dim_ == self.hidden_size
    assert segment_count.sum() == segment_count_
    batch_ = segment_count.size()

    segment_representation = context.split(segment_count.tolist())
    segment_representation_padded = pad_sequence(segment_representation,
                                                 padding_value=padding_value)

    memory_bank, _ = self.content_selection_attn(
        segment_representation_padded.transpose(0, 1).contiguous(),
        segment_representation_padded.transpose(0, 1),
        memory_lengths=segment_count)

    _, batch, emb_dim = memory_bank.size()
    assert batch == batch_[0]
    assert emb_dim == self.hidden_size

    if segment_count is not None:
        # we avoid padding while mean pooling
        mask = sequence_mask(segment_count).float()
        mask = mask / segment_count.unsqueeze(1).float()
        mean = torch.bmm(mask.unsqueeze(1),
                         memory_bank.transpose(0, 1)).squeeze(1)
    else:
        mean = memory_bank.mean(0)
    mean = mean.expand(self.num_layers, batch, emb_dim)
    encoder_final = (mean, mean)

    return encoder_final, memory_bank, segment_count
def forward(self, src, lengths=None, batch=None):
    """See :func:`EncoderBase.forward()`"""
    self._check_args(src, lengths)

    emb = self.embeddings(src)
    batch_size = emb.size()[1]
    seq_len = emb.size()[0]

    out = emb.transpose(0, 1).contiguous()
    mask = ~sequence_mask(lengths).unsqueeze(1)

    if self.decoder_dim != self.d_model:
        out = self.linear_before(out)

    batch_geometric = get_embs_graph(batch, out)
    memory_bank = batch_geometric.x
    for layer in self.gnns:
        new_memory_bank = layer(memory_bank,
                                batch_geometric.edge_index,
                                edge_type=batch_geometric.y)
        memory_bank = self.rnn(new_memory_bank, memory_bank)
    memory_bank = memory_bank.view((batch_size, seq_len, self.d_model))

    for layer in self.transformer:
        out, attns = layer(out, mask)
    out = self.layer_norm(out)

    out = torch.cat([out, memory_bank], dim=2)
    out = self.final_linear(out)

    return emb, out.transpose(0, 1).contiguous(), lengths
def _compute_te_loss(self, target_attns):
    """
    :param target_attns: a tuple (stacked_target_attns, target_attns_lens,
        src_states_target_list)
    :return: target encoding loss
    """
    # stacked_target_attns: [b_size, max_sep_num, sample_size+1]
    # target_attns_lens: [b_size]
    # src_states_target_list: [b_size]
    stacked_target_attns, target_attns_lens, src_states_target_list = target_attns
    b_size, max_sep_num, cls_num = stacked_target_attns.size()
    device = stacked_target_attns.device

    gt_tensor = torch.Tensor(src_states_target_list).view(b_size, 1) \
        .repeat(1, max_sep_num).to(device)

    # class_dist_flat: [b_size * max_sep_num, sample_size+1]
    class_dist_flat = stacked_target_attns.view(-1, cls_num)
    log_dist_flat = torch.log(class_dist_flat + EPS)
    target_flat = gt_tensor.view(-1, 1)  # [b_size * max_sep_num, 1]
    losses_flat = -torch.gather(log_dist_flat, dim=1, index=target_flat.long())
    losses = losses_flat.view(b_size, max_sep_num)

    mask = sequence_mask(torch.Tensor(target_attns_lens)).to(device)
    losses = losses * mask.float()
    losses = losses.sum(dim=1)
    return losses
def _compute_orthogonal_loss(self, sep_states):
    """
    The orthogonal loss computation function.

    :param sep_states: a tuple (stacked_sep_states, sep_states_lens)
    :return: a [b_size] tensor with the orthogonal loss of each example
    """
    # stacked_sep_states: [b_size, max_sep_num, src_h_size]
    stacked_sep_states, sep_states_lens = sep_states
    b_size, max_sep_num, src_h_size = stacked_sep_states.size()
    b_size_ = len(sep_states_lens)
    aeq(b_size, b_size_)
    device = stacked_sep_states.device

    # obtain the mask
    # [b_size, max_sep_num]
    mask = sequence_mask(torch.Tensor(sep_states_lens)).to(device)
    mask = mask.float()
    # [b_size, 1, max_sep_num]
    mask = mask.unsqueeze(1)
    # [b_size, max_sep_num, max_sep_num]
    mask_2d = torch.bmm(mask.transpose(1, 2), mask)

    # compute the loss
    # [b_size, max_sep_num, max_sep_num]
    identity = torch.eye(max_sep_num).unsqueeze(0).repeat(b_size, 1, 1).to(device)
    # [b_size, max_sep_num, max_sep_num]
    orthogonal_loss_ = torch.bmm(stacked_sep_states,
                                 stacked_sep_states.transpose(1, 2)) - identity
    orthogonal_loss_ = orthogonal_loss_ * mask_2d
    # [b_size]
    orthogonal_loss = torch.norm(orthogonal_loss_.view(b_size, -1), p=2, dim=1)
    return orthogonal_loss
def _make_shard_state(self, batch, range_, result):
    mel, mel_lengths = batch.tgt
    return {
        "output": result["dec_out"],
        "target": mel[1:],
        "lengths": sequence_mask(mel_lengths - 1).transpose(0, 1)
                   .unsqueeze(-1).type(torch.FloatTensor).cuda()
    }
def __call__(self, tgt: torch.Tensor, memory_bank: torch.Tensor,
             step: Optional[int] = None, **kwargs):
    """Decode, possibly stepwise."""
    if step == 0:
        self._init_cache(memory_bank)

    tgt_words = tgt[:, :, 0].transpose(0, 1)

    emb = self.embeddings(tgt, step=step)
    assert emb.dim() == 3  # len x batch x embedding_dim

    output = emb.transpose(0, 1).contiguous()
    src_memory_bank = memory_bank.transpose(0, 1).contiguous()

    pad_idx = self.embeddings.word_padding_idx
    src_lens = kwargs["memory_lengths"]
    src_max_len = self.state["src"].shape[0]
    # Turbo add bool -> float
    src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
    tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

    with_align = kwargs.pop('with_align', False)
    if with_align:
        raise ValueError("with_align must be False")
    attn_aligns = []

    # It's Turbo's show time!
    for i, layer in enumerate(self.transformer_layers):
        layer_cache = self.state["cache"]["layer_{}".format(i)] \
            if step is not None else None
        output, attn, attn_align = layer(
            output,
            src_memory_bank,
            src_pad_mask,
            tgt_pad_mask,
            layer_cache=layer_cache,
            step=step,
            with_align=with_align)
        if attn_align is not None:
            attn_aligns.append(attn_align)
    # Turbo finished.

    output = self.layer_norm(output)
    dec_outs = output.transpose(0, 1).contiguous()
    attn = attn.transpose(0, 1).contiguous()

    attns = {"std": attn}
    if self._copy:
        attns["copy"] = attn
    if with_align:
        attns["align"] = attn_aligns[self.alignment_layer]  # `(B, Q, K)`
        # attns["align"] = torch.stack(attn_aligns, 0).mean(0)  # All avg

    # TODO(OpenNMT-py) change the way attns is returned dict => list or tuple (onnx)
    return dec_outs, attns
def ans_match(src_seq, ans_seq, src_lengths, ans_lengths):
    # src_lengths / ans_lengths are required for the padding masks below;
    # the original signature left them undefined.
    import torch.nn.functional as F

    BF_ans_mask = sequence_mask(ans_lengths)  # [batch, ans_seq_len]
    BF_src_mask = sequence_mask(src_lengths)  # [batch, src_seq_len]

    BF_src_outputs = src_seq.transpose(0, 1)  # [batch, src_seq_len, 2*hidden_size]
    BF_ans_outputs = ans_seq.transpose(0, 1)  # [batch, ans_seq_len, 2*hidden_size]

    # compute bi-att scores
    src_scores = BF_src_outputs.bmm(
        BF_ans_outputs.transpose(2, 1))  # [batch, src_seq_len, ans_seq_len]
    ans_scores = BF_ans_outputs.bmm(
        BF_src_outputs.transpose(2, 1))  # [batch, ans_seq_len, src_seq_len]

    # mask padding
    Expand_BF_ans_mask = BF_ans_mask.unsqueeze(1).expand(
        src_scores.size())  # [batch, src_seq_len, ans_seq_len]
    src_scores.data.masked_fill_(~(Expand_BF_ans_mask).bool(), -float('inf'))
    # src_scores = torch.ones(src_scores.shape).to(ans_seq.device)

    Expand_BF_src_mask = BF_src_mask.unsqueeze(1).expand(
        ans_scores.size())  # [batch, ans_seq_len, src_seq_len]
    ans_scores.data.masked_fill_(~(Expand_BF_src_mask).bool(), -float('inf'))

    # normalize with softmax
    src_alpha = F.softmax(src_scores, dim=2)  # [batch, src_seq_len, ans_seq_len]
    ans_alpha = F.softmax(ans_scores, dim=2)  # [batch, ans_seq_len, src_seq_len]

    # take the weighted average
    BF_src_matched_seq = src_alpha.bmm(
        BF_ans_outputs)  # [batch, src_seq_len, 2*hidden_size]
    src_matched_seq = BF_src_matched_seq.transpose(
        0, 1)  # [src_seq_len, batch, 2*hidden_size]

    BF_ans_matched_seq = ans_alpha.bmm(
        BF_src_outputs)  # [batch, ans_seq_len, 2*hidden_size]
    ans_matched_seq = BF_ans_matched_seq.transpose(
        0, 1)  # [ans_seq_len, batch, 2*hidden_size]

    return src_matched_seq, ans_matched_seq
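# Hedged usage sketch for ans_match (shapes follow the comments above; the
# tensor names here are illustrative, not from the original code):
#   src_seq: [src_seq_len, batch, 2*hidden_size]   source encoder outputs
#   ans_seq: [ans_seq_len, batch, 2*hidden_size]   answer encoder outputs
#   src_lengths, ans_lengths: [batch]              valid lengths for masking
#
#   src_matched, ans_matched = ans_match(src_seq, ans_seq,
#                                        src_lengths, ans_lengths)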
def forward(self, src, lengths=None):
    """See :func:`EncoderBase.forward()`"""
    self._check_args(src, lengths)
    # print("First:", lengths, lengths.dtype)

    emb = self.embeddings(src)
    # print("before", emb.size())
    if self.conv_net or self.max_pooling or self.highway_net or self.linear_mapping:
        emb = emb.permute(1, 0, 2)
    if self.conv_net or self.max_pooling:
        emb = emb.unsqueeze(1)
    if self.conv_net:
        emb = self.conv_net(emb)
    if self.max_pooling:
        emb = self.max_pooling(emb)
    if self.conv_net or self.max_pooling:
        emb = emb.squeeze(1)
    if self.highway_net:
        emb = self.highway_net(emb)
    if self.linear_mapping:
        emb = self.linear_mapping(emb)
    if self.conv_net or self.max_pooling or self.highway_net or self.linear_mapping:
        emb = emb.permute(1, 0, 2)
    # print(emb.size())

    if self.pos_encoding:
        emb = self.pos_encoding(emb)

    out = emb.transpose(0, 1).contiguous()
    # print("-----", out.size())
    # out = emb.contiguous()

    len_type = lengths.dtype
    if self.max_pooling:
        lengths = torch.ceil(lengths.float() / self.conv_pooling).to(dtype=len_type)
    if lengths.max().tolist() != out.size(1):
        # print("Error", lengths.max().tolist(), out.size(1))
        lengths = torch.tensor(emb.size(1) * [emb.size(0)], device=emb.device)

    mask = ~sequence_mask(lengths).unsqueeze(1)
    # print("Second:", lengths, lengths.dtype, mask.size(), out.size())

    for layer in self.transformer:
        out = layer(out, mask)
    out = self.layer_norm(out)
    # print(emb.size(), out.transpose(0, 1).contiguous().size(), lengths.size())

    return emb, out.transpose(0, 1).contiguous(), lengths
def __call__(self, batch, rl_forward, baseline_forward):
    """
    There's no better way for now than a for-loop...
    """
    rl_sentences, rl_log_probs, rl_attns = rl_forward
    baseline_sentences, baseline_log_probs, baseline_attns = baseline_forward

    device = batch.tgt.device
    rl_lengths = list()
    baseline_lengths = list()
    decoded_sequences = list()
    rl_scores = list()
    baseline_scores = list()

    for b in range(batch.batch_size):
        rl_candidate, rl_length = self.cleaner.clean_candidate_tokens(
            rl_sentences[:, b], batch.src_map[:, b],
            batch.src_ex_vocab[b], rl_attns[:, b])
        baseline_candidate, baseline_length = self.cleaner.clean_candidate_tokens(
            baseline_sentences[:, b], batch.src_map[:, b],
            batch.src_ex_vocab[b], baseline_attns[:, b])

        if rl_length == 0:
            rl_length = 1
        rl_lengths.append(rl_length)
        baseline_lengths.append(baseline_length)

        decoded_sequences.append((self.references[batch.indices[b].item()],
                                  " ".join(baseline_candidate),
                                  " ".join(rl_candidate)))

        rl_scores.append(self.metric(rl_candidate, batch.indices[b].item()))
        baseline_scores.append(self.metric(baseline_candidate,
                                           batch.indices[b].item()))

    rl_lengths = torch.LongTensor(rl_lengths).to(device)
    baseline_lengths = torch.LongTensor(baseline_lengths).to(device)

    mask = sequence_mask(rl_lengths, max_len=len(rl_sentences))
    sequences_scores = rl_log_probs.masked_fill(~mask.transpose(0, 1), 0)
    sequences_scores = sequences_scores.sum(dim=0) / rl_lengths.float()

    # we reward the model according to f1_score
    rl_rewards = torch.FloatTensor(rl_scores).to(device)
    baseline_rewards = torch.FloatTensor(baseline_scores).to(device)
    rewards = baseline_rewards - rl_rewards
    loss = (rewards * sequences_scores).mean()

    stats = self._stats(loss, baseline_rewards.mean(), rl_rewards.mean(),
                        baseline_lengths, rl_lengths, decoded_sequences)

    return loss, stats
def build_chunk_mask(lengths, ent_size):
    """
    [bsz, n_ents, n_ents]
    Filled with -inf where self-attention shouldn't attend,
    and zeros elsewhere.
    """
    ones = lengths // ent_size
    ones = sequence_mask(ones).unsqueeze(1).repeat(1, ones.max(), 1).to(lengths.device)
    mask = torch.full(ones.shape, float('-inf')).to(lengths.device)
    mask.masked_fill_(ones, 0)
    return mask
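# Hedged usage sketch for build_chunk_mask (values are illustrative only):
# with lengths = torch.tensor([8, 4]) and ent_size = 2, lengths // ent_size
# gives 4 and 2 entities, so the returned mask has shape [2, 4, 4]; columns
# past each example's valid entity count hold -inf, valid columns hold 0,
# ready to be added to pre-softmax self-attention scores.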
def forward(self, src_enc, tgt, src_lengths):
    dec_outputs = []
    dec_states = []
    dec_attns = []
    tgt_len = tgt.shape[0]
    src_mask = sequence_mask(src_lengths, max_len=src_enc.size(0)).transpose(0, 1)

    # Precompute all target-side embeddings
    tgt_embed = self.embedding(tgt.squeeze(-1))

    # Initialize decoder state
    s = torch.tanh(self.Ws(src_enc[0, :, :]))
    # dec_states += [s.clone().unsqueeze(0), ]

    # Recurrence
    for i in range(1, tgt_len):
        # Condition attention on current decoder state
        s_ = s.unsqueeze(0).expand(src_enc.shape[0], s.shape[0], s.shape[1])
        inpt = torch.cat([s_, src_enc], -1)
        ei = self.va(torch.tanh(self.Wa(inpt))).squeeze(-1)
        ei = ei.masked_fill(~src_mask, -float('inf'))
        ai = torch.exp(torch.log_softmax(ei, 0))  # softmax over source positions
        # print(ai.shape, src_enc.shape); sys.exit(0)
        ci = torch.sum(ai.unsqueeze(-1) * src_enc, 0)

        # Compute decoder output (single layer, no drop-out)
        # note: use decoder state s_(i-1), before the state update
        inpt = torch.cat([tgt_embed[i - 1, :, :], s, ci], 1)
        ti = self.Wo(inpt)

        # Store decoder state and output, attention distributions
        dec_states += [s.clone().unsqueeze(0)]
        dec_outputs += [ti.clone().unsqueeze(0)]
        dec_attns += [ai.clone().transpose(0, 1).unsqueeze(0)]

        # Update decoder state (GRU-style gating)
        inpt = torch.cat([tgt_embed[i - 1, :, :], s, ci], -1)
        zi = torch.sigmoid(self.Wz(inpt))  # update gate
        ri = torch.sigmoid(self.Wr(inpt))  # reset gate
        inpt = torch.cat([tgt_embed[i - 1, :, :], ri * s, ci], -1)
        ni = torch.tanh(self.Wn(inpt))  # proposal
        s = (1.0 - zi) * s + zi * ni  # new state

    dec_outputs = torch.cat(dec_outputs)  # (tgt_len - 1, batch_len, nhidden)
    dec_states = torch.cat(dec_states)    # (tgt_len - 1, batch_len, nhidden)
    dec_attns = torch.cat(dec_attns)      # (tgt_len - 1, batch_len, src_len)

    return dec_outputs, dec_states, dec_attns
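# For reference, the state update in the loop above follows the standard GRU
# form, conditioned on the previous target embedding and the context vector
# (reading directly from the code; the notation below is only a restatement):
#   z_i = sigmoid(Wz [y_{i-1}; s; c_i])          update gate
#   r_i = sigmoid(Wr [y_{i-1}; s; c_i])          reset gate
#   n_i = tanh(Wn [y_{i-1}; r_i * s; c_i])       proposal
#   s   = (1 - z_i) * s + z_i * n_i              new decoder state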
def forward(self, src, lengths=None):
    self._check_args(src, lengths)

    emb = self.embeddings(src)
    out = emb.transpose(0, 1).contiguous()
    mask = ~sequence_mask(lengths).unsqueeze(1)
    # Run the forward pass of every layer of the transformer.
    for layer in self.transformer:
        out = layer(out, mask)
    out = self.layer_norm(out)

    return emb, out.transpose(0, 1).contiguous(), lengths
def _make_shard_state(self, batch, range_, result):
    tgt, tgt_lengths = batch.tgt
    txt = batch.txt[0]
    return {
        "output": result["dec_out"],
        "target": tgt[1:],
        "tgt_lengths": sequence_mask(tgt_lengths).transpose(0, 1)
                       .unsqueeze(-1).type(torch.FloatTensor).cuda(),
        "txt_out": result["txt_out"],
        "txt": txt[1:]
    }
def forward(self, src, batch, lengths=None):
    """See :func:`EncoderBase.forward()`"""
    self._check_args(src, lengths)

    emb = self.embeddings(src)
    out = emb.transpose(0, 1).contiguous()
    mask = ~sequence_mask(lengths).unsqueeze(1)
    # Run the forward pass of every layer of the transformer.
    for i, layer in enumerate(self.transformer):
        out, at_self_attn = layer(out, mask)
        self.build_visualization(batch, i, at_self_attn)
    out = self.layer_norm(out)
    self.batch_count += 1

    return emb, out.transpose(0, 1).contiguous(), lengths
def forward(self, src, lengths=None):
    """See :func:`EncoderBase.forward()`"""
    self._check_args(src, lengths)

    emb = self.embeddings(src)  # src: [300, 13, 1], emb: [300, 13, 512]
    out = emb.transpose(0, 1).contiguous()  # [13, 300, 512]
    mask = ~sequence_mask(lengths).unsqueeze(1)
    # Run the forward pass of every layer of the transformer.
    for layer in self.transformer:
        out = layer(out, mask)
    out = self.layer_norm(out)

    return emb, out.transpose(0, 1).contiguous(), lengths
def forward(self, memory_bank, lengths):  # encoder_out
    # memory_bank: [maxlen, B, H]
    # lengths: [B, ]
    mask = sequence_mask(lengths).float()  # [B, maxlen]
    mask = mask / lengths.unsqueeze(1).float()  # [B, maxlen]
    # arg1: [B, 1, maxlen], arg2: [B, maxlen, H] ==> [B, H]
    mean = torch.bmm(mask.unsqueeze(1), memory_bank.transpose(0, 1)).squeeze(1)

    x = torch.tanh(self.fc1(mean))
    if self.dropout is not None:
        x = self.dropout(x)
    x = self.fc2(x)
    return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)
def get_transformer_encoder_attn(model, src, lengths=None):
    emb = model.embeddings(src)
    model._check_args(src, lengths)

    out = emb.transpose(0, 1).contiguous()
    mask = ~sequence_mask(lengths).unsqueeze(1)
    attn_matrices = list()
    # Run the forward pass of every layer of the transformer.
    for layer in model.transformer:
        out, attns = transformer_encoder_forward_with_attn(layer, out, mask)
        attn_matrices.append(attns.detach())

    return emb, out.transpose(0, 1).contiguous(), lengths, attn_matrices
def forward(self, tgt, memory_bank, step=None, emotion=None, **kwargs):
    """Decode, possibly stepwise."""
    if step == 0:
        self._init_cache(memory_bank)

    tgt_words = tgt[:, :, 0].transpose(0, 1)

    emb = self.embeddings(tgt, step=step)
    assert emb.dim() == 3  # len x batch x embedding_dim

    # add emotion embedding using a linear transformation
    if emotion is not None:
        batch_emotion_embedding = self.emo_embedding(emotion)  # (batch, emotion_emb_size)
        batch_emotion_embedding = batch_emotion_embedding.unsqueeze(0).repeat(
            emb.size(0), 1, 1)  # (len, batch, emotion_emb_size)
        emb = self.emo_mlp(
            torch.cat([emb, batch_emotion_embedding], dim=2))
        # emb: (len, batch, embedding_dim)

    output = emb.transpose(0, 1).contiguous()
    src_memory_bank = memory_bank.transpose(0, 1).contiguous()

    pad_idx = self.embeddings.word_padding_idx
    src_lens = kwargs["memory_lengths"]
    src_max_len = self.state["src"].shape[0]
    src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
    tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

    for i, layer in enumerate(self.transformer_layers):
        layer_cache = self.state["cache"]["layer_{}".format(i)] \
            if step is not None else None
        output, attn = layer(
            output,
            src_memory_bank,
            src_pad_mask,
            tgt_pad_mask,
            layer_cache=layer_cache,
            step=step)

    output = self.layer_norm(output)
    dec_outs = output.transpose(0, 1).contiguous()
    attn = attn.transpose(0, 1).contiguous()

    attns = {"std": attn}
    if self._copy:
        attns["copy"] = attn

    # TODO change the way attns is returned dict => list or tuple (onnx)
    return dec_outs, attns
def _make_shard_state(self, batch, range_, result):
    tgt, tgt_lengths = batch.tgt
    src_txt = batch.src_txt[0]
    tgt_txt = batch.tgt_txt[0]
    return {
        "output": result["dec_out"],
        "target": tgt[1:],
        "tgt_lengths": sequence_mask(tgt_lengths).transpose(0, 1)
                       .unsqueeze(-1).type(torch.FloatTensor).cuda(),
        "asr_dec_out": result["asr_dec_out"],
        "src_txt": src_txt[1:],
        "tgt_txt_out": result["tgt_txt_out"],
        "tgt_txt": tgt_txt[1:],
    }
def forward(self, src, lengths=None):
    """See :func:`EncoderBase.forward()`"""
    # import pdb; pdb.set_trace()
    self._check_args(src, lengths)

    emb = self.embeddings(src) if self.embeddings is not None else src
    emb = self.pe(self.noise(emb))
    out = emb.transpose(0, 1).contiguous()
    mask = ~sequence_mask(lengths).unsqueeze(1)  # [B, 1, T]
    # Run the forward pass of every layer of the transformer.
    for layer in self.transformer:
        out = layer(out, mask)
    out = self.layer_norm(out)

    return emb, out.transpose(0, 1).contiguous(), lengths
def forward(self, src, imgs=None, lengths=None):
    """See :func:`EncoderBase.forward()`"""
    self._check_args(src, lengths)

    visual_out = self.video_encoder(imgs)

    emb = self.embeddings(src)
    out = emb.transpose(0, 1).contiguous()
    mask = ~sequence_mask(lengths).unsqueeze(1)
    # Run the forward pass of every layer of the transformer.
    for layer in self.transformer:
        out = layer(out, mask, imgs=visual_out)
    out = self.layer_norm(out)

    return emb, (visual_out.transpose(0, 1).contiguous(),
                 out.transpose(0, 1).contiguous()), lengths
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
      memory_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): None (not supported yet)

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed vector `[tgt_len x batch x dim]`
      * Attention distributions for each query `[tgt_len x batch x src_len]`
    """
    if source.dim() == 2:
        source = source.unsqueeze(1)

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(~mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l,
                                              self.indim + self.outdim)
    attn_h = self.linear_out(concat_c).view(batch, target_l, self.outdim)

    return attn_h.squeeze(1), align_vectors.squeeze(1)
def forward(self, tgt, memory_bank, step=None, **kwargs):
    """Decode, possibly stepwise."""
    if step == 0:
        self._init_cache(memory_bank)

    tgt_words = tgt[:, :, 0].transpose(0, 1)

    emb = self.embeddings(tgt, step=step)
    if self.n_latent > 1:
        emb = emb + self.latent_embedding(
            kwargs["latent_input"].to(self.latent_embedding.weight.device))
    assert emb.dim() == 3  # len x batch x embedding_dim
    if self.n_segments > 0:
        emb = emb + self.segment_embedding(kwargs["segment_input"])
    assert emb.dim() == 3  # len x batch x embedding_dim

    output = emb.transpose(0, 1).contiguous()
    src_memory_bank = memory_bank.transpose(0, 1).contiguous()

    pad_idx = self.embeddings.word_padding_idx
    src_lens = kwargs["memory_lengths"]
    src_max_len = self.state["src"].shape[0]
    src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
    tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

    for i, layer in enumerate(self.transformer_layers):
        layer_cache = self.state["cache"]["layer_{}".format(i)] \
            if step is not None else None
        output, attn = layer(
            output,
            src_memory_bank,
            src_pad_mask,
            tgt_pad_mask,
            layer_cache=layer_cache,
            step=step)

    output = self.layer_norm(output)
    dec_outs = output.transpose(0, 1).contiguous()
    attn = attn.transpose(0, 1).contiguous()

    attns = {"std": attn}
    if self._copy:
        attns["copy"] = attn

    # TODO change the way attns is returned dict => list or tuple (onnx)
    return dec_outs, attns
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
      memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
      memory_lengths (LongTensor): the source context lengths ``(batch,)``
      coverage (FloatTensor): None (not supported yet)

    Returns:
      (FloatTensor, FloatTensor):

      * Computed vector ``(tgt_len, batch, dim)``
      * Attention distributions for each query ``(tgt_len, batch, src_len)``
    """
    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(~mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)

        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()

        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return attn_h, align_vectors