def _encode(self, src_token_ids, padding_mask, training=False):
    """Converts source sequence token ids into continuous representations and
    computes the Encoder-encoded sequences.

    Args:
      src_token_ids: int tensor of shape [batch_size, src_seq_len], token ids
        of source sequences.
      padding_mask: float tensor of shape [batch_size, 1, 1, src_seq_len],
        populated with either 0 (for tokens to keep) or 1 (for tokens to be
        masked).
      training: bool scalar, True if in training mode.

    Returns:
      encoder_outputs: float tensor of shape [batch_size, src_seq_len,
        hidden_size], the encoded source sequences to be used as reference.
    """
    src_seq_len = tf.shape(src_token_ids)[1]

    # [batch_size, src_seq_len, hidden_size]
    src_token_embeddings = self._embedding_logits_layer(
        src_token_ids, 'embedding')

    # [src_seq_len, hidden_size]
    positional_encoding = utils.get_positional_encoding(
        src_seq_len, self._hidden_size)
    src_token_embeddings += positional_encoding
    src_token_embeddings = self._encoder_dropout_layer(
        src_token_embeddings, training)

    encoder_outputs = self._encoder(
        src_token_embeddings, padding_mask, training)
    return encoder_outputs
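# --- Illustration (not part of the model above) ---
# A minimal sketch of what `utils.get_positional_encoding(seq_len, hidden_size)` is
# presumably computing here: the fixed sinusoidal encoding of shape
# [seq_len, hidden_size] from "Attention Is All You Need". The real helper may differ
# in details (dtype, interleaving, odd hidden sizes); this sketch assumes an even
# hidden_size.
import numpy as np

def sinusoidal_positional_encoding(seq_len, hidden_size):
    positions = np.arange(seq_len)[:, np.newaxis]                       # [seq_len, 1]
    inv_freq = np.power(10000.0, np.arange(0, hidden_size, 2) / hidden_size)
    encoding = np.zeros((seq_len, hidden_size), dtype=np.float32)
    encoding[:, 0::2] = np.sin(positions / inv_freq)                    # even dims
    encoding[:, 1::2] = np.cos(positions / inv_freq)                    # odd dims
    return encoding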
def _get_final_embeddings(self, inputs, memories, training):
    """Computes the final embedding vectors coming off the top layer of the
    TransformerXL model.

    Args:
      inputs: int tensor of shape [batch_size, q_seq_len], token ids of the
        input sequence segment.
      memories: float tensor of shape [batch_size, stack_size, c_seq_len,
        hidden_size], embeddings of the tokens from the previous sequence
        segment for each layer of the decoder stack.
      training: bool scalar, True if in training mode.

    Returns:
      embeddings: float tensor of shape [batch_size, q_seq_len, hidden_size],
        the final embeddings of inputs.
      new_memories: float tensor of shape [batch_size, stack_size, c_seq_len,
        hidden_size], the updated embedding vectors of the memory sequence
        segment.
    """
    m_seq_len = tf.shape(memories)[2]
    batch_size, q_seq_len = tf.unstack(tf.shape(inputs), axis=0)
    new_memories = []

    # [batch_size, q_seq_len, hidden_size]
    embeddings = self._embedding_layer(inputs, mode='embedding')

    # [1, 1, q_seq_len, q_seq_len + m_seq_len]
    attention_mask = utils.get_look_ahead_mask(q_seq_len, m_seq_len)

    # [q_seq_len + m_seq_len, hidden_size]
    positional_encoding = utils.get_positional_encoding(
        batch_size, m_seq_len + q_seq_len, self._hidden_size, reverse=True)

    embeddings = self._embeddings_dropout_layer(embeddings, training=training)
    positional_encoding = self._positional_encoding_dropout_layer(
        positional_encoding, training=training)

    for i in range(self._stack_size):
        # Update the cached memory for layer `i` from its current inputs
        # before running the layer.
        new_memories.append(utils.cache_memory(memories[:, i], embeddings))

        if self._tie_biases:
            content_bias, position_bias = self.weights[:2]
        else:
            content_bias, position_bias = self.weights[0][i], self.weights[1][i]

        embeddings = self._stack[i](
            embeddings,
            tf.concat([memories[:, i], embeddings], axis=1),
            positional_encoding,
            attention_mask,
            content_bias,
            position_bias,
            training=training)

    new_memories = tf.stack(new_memories, axis=1)
    return embeddings, new_memories
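# --- Illustration (not part of the model above) ---
# A hedged sketch of what `utils.cache_memory(memory, embeddings)` plausibly does in a
# Transformer-XL setting: append the current segment's layer inputs to the cached
# memory, keep only the most recent m_seq_len positions, and block gradients through
# the cache. The actual helper in this repo may differ.
import tensorflow as tf

def cache_memory(memory, embeddings):
    # memory: [batch_size, m_seq_len, hidden_size]
    # embeddings: [batch_size, q_seq_len, hidden_size]
    q_seq_len = tf.shape(embeddings)[1]
    new_memory = tf.concat([memory, embeddings], axis=1)   # [batch, m + q, hidden]
    return tf.stop_gradient(new_memory[:, q_seq_len:])     # keep the last m positions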
def beam_decode(self, src, guess, src_lang_idx, tgt_lang_idx, logit_mask):
    """Runs beam search decoding, conditioning the decoder on two encoders:
    one over the source `src` and one over the second input `guess`."""
    embed_dim = self.args.embed_dim
    max_len = src.size(1) + 51
    pos_embedding = ut.get_positional_encoding(embed_dim, max_len)
    word_embedding = (F.normalize(self.word_embedding, dim=-1)
                      if self.args.fix_norm else self.word_embedding)
    logit_mask = logit_mask if self.logit_mask is None else self.logit_mask
    tgt_lang_embed = self.lang_embedding[tgt_lang_idx]

    # Encode the source sequence.
    encoder1_inputs = self.get_input(src, src_lang_idx, word_embedding,
                                     pos_embedding)
    encoder1_mask = (src == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
    encoder1_outputs = self.encoder1(encoder1_inputs, encoder1_mask)

    # Encode the second input sequence.
    encoder2_inputs = self.get_input(guess, tgt_lang_idx, word_embedding,
                                     pos_embedding)
    encoder2_mask = (guess == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
    encoder2_outputs = self.encoder2(encoder2_inputs, encoder2_mask)

    def get_tgt_inp(tgt, time_step):
        # Embed the current target token(s), then add language and positional
        # embeddings.
        word_embed = F.embedding(tgt.type(src.type()), word_embedding) * self.scale
        pos_embed = pos_embedding[time_step, :].reshape(1, 1, -1)
        return word_embed + tgt_lang_embed + pos_embed

    def logprob_fn(decoder_output):
        logits = self.logit_fn(decoder_output, word_embedding, logit_mask)
        return F.log_softmax(logits, dim=-1)

    # Following Attention Is All You Need, we decode up to src_len + 50 tokens only.
    max_lengths = torch.sum(src != ac.PAD_ID, dim=-1).type(src.type()) + 50
    return self.decoder.beam_decode(
        encoder1_outputs, encoder1_mask,
        encoder2_outputs, encoder2_mask,
        get_tgt_inp, logprob_fn,
        ac.BOS_ID, ac.EOS_ID, max_lengths,
        beam_size=self.args.beam_size, alpha=self.args.beam_alpha)
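# --- Illustration (not part of the model above) ---
# A hedged sketch of the per-step contract that `get_tgt_inp` / `logprob_fn` satisfy,
# shown with a toy greedy loop instead of this repo's `decoder.beam_decode`.
# `decoder_step` is a hypothetical stand-in for one decoder forward pass.
import torch

def greedy_decode(decoder_step, get_tgt_inp, logprob_fn, bos_id, eos_id, max_len):
    tgt = torch.full((1, 1), bos_id, dtype=torch.long)
    for t in range(max_len):
        inp = get_tgt_inp(tgt[:, -1:], t)       # embed the last token at position t
        out = decoder_step(inp)                 # [1, 1, embed_dim]
        next_tok = logprob_fn(out)[:, -1].argmax(-1, keepdim=True)
        tgt = torch.cat([tgt, next_tok], dim=1)
        if next_tok.item() == eos_id:
            break
    return tgt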
def __getitem__(self, idx):
    current_question = self.questions[idx]
    img_filename = os.path.join(self.img_dir,
                                current_question['image_filename'])
    # image = Image.open(img_filename).convert('RGB')

    # Look up the precomputed image features by the numeric id parsed from the
    # filename.
    image_id = int(img_filename.rsplit('_', 1)[1][:-4])
    image = self.img_features_file[image_id]

    # Append a positional-encoding map as extra channels of the feature map.
    _, H, W = image.shape
    position_encoding = get_positional_encoding(H, W)
    image = np.concatenate((image, position_encoding), axis=0)

    question = utils.to_dictionary_indexes(self.dictionaries[0],
                                           current_question['question'])
    answer = utils.to_dictionary_indexes(self.dictionaries[1],
                                         current_question['answer'])
    # answer_class = self.dictionaries[2][answer.item()]
    question_type = get_ques_type(current_question['program'][-1]['function'])

    answer = answer - 1  # convert to zero-based indexing
    return image, question, len(question), answer, question_type
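# --- Illustration (not part of the dataset class above) ---
# A hedged sketch of a `get_positional_encoding(H, W)` compatible with the
# concatenation above (the feature map is [C, H, W] and the encoding is appended along
# axis 0): two extra channels holding normalized y/x coordinates. The project's helper
# may use a different number of channels or a sinusoidal scheme instead.
import numpy as np

def coordinate_positional_encoding(H, W):
    ys = np.linspace(-1.0, 1.0, H, dtype=np.float32)
    xs = np.linspace(-1.0, 1.0, W, dtype=np.float32)
    grid_y, grid_x = np.meshgrid(ys, xs, indexing='ij')   # each of shape [H, W]
    return np.stack([grid_y, grid_x], axis=0)             # [2, H, W]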
def forward(self, src, tgt, targets, src_lang_idx, tgt_lang_idx, logit_mask):
    """Computes the label-smoothed cross-entropy loss for a batch of
    (src, tgt, targets) sequences."""
    embed_dim = self.args.embed_dim
    max_len = max(src.size(1), tgt.size(1))
    pos_embedding = ut.get_positional_encoding(embed_dim, max_len)
    word_embedding = (F.normalize(self.word_embedding, dim=-1)
                      if self.args.fix_norm else self.word_embedding)

    # Encoder.
    encoder_inputs = self.get_input(src, src_lang_idx, word_embedding,
                                    pos_embedding)
    encoder_mask = (src == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
    encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

    # Decoder with a causal (upper-triangular) mask.
    decoder_inputs = self.get_input(tgt, tgt_lang_idx, word_embedding,
                                    pos_embedding)
    decoder_mask = torch.triu(torch.ones((tgt.size(-1), tgt.size(-1))),
                              diagonal=1).type(tgt.type()) == 1
    decoder_mask = decoder_mask.unsqueeze(0).unsqueeze(1)
    decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                   encoder_outputs, encoder_mask)

    logit_mask = logit_mask if self.logit_mask is None else self.logit_mask
    logits = self.logit_fn(decoder_outputs, word_embedding, logit_mask)
    neglprobs = F.log_softmax(logits, -1) * logit_mask.type(
        logits.type()).reshape(1, -1)
    targets = targets.reshape(-1, 1)
    non_pad_mask = targets != ac.PAD_ID
    nll_loss = neglprobs.gather(dim=-1, index=targets)[non_pad_mask]
    smooth_loss = neglprobs.sum(dim=-1, keepdim=True)[non_pad_mask]

    # label smoothing: https://arxiv.org/pdf/1701.06548.pdf
    nll_loss = -(nll_loss.sum())
    smooth_loss = -(smooth_loss.sum())
    label_smoothing = self.args.label_smoothing
    if label_smoothing > 0:
        loss = ((1.0 - label_smoothing) * nll_loss +
                label_smoothing * smooth_loss /
                logit_mask.type(nll_loss.type()).sum())
    else:
        loss = nll_loss

    num_words = non_pad_mask.type(loss.type()).sum()
    opt_loss = loss / num_words
    return {
        'opt_loss': opt_loss,
        'loss': loss,
        'nll_loss': nll_loss,
        'num_words': num_words
    }
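# --- Illustration (not part of the model above) ---
# A toy, self-contained version of the label-smoothing combination used in `forward`
# (cf. https://arxiv.org/pdf/1701.06548.pdf): the loss mixes the NLL of the gold token
# with the summed NLL over the vocabulary, scaled by the (effective) vocabulary size.
# Numbers are illustrative only.
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]])        # [1, vocab=4]
target = torch.tensor([[0]])                          # gold token id
eps = 0.1                                             # label_smoothing

logprobs = F.log_softmax(logits, dim=-1)
nll_loss = -logprobs.gather(-1, target).sum()         # NLL of the gold token
smooth_loss = -logprobs.sum()                         # summed NLL over the vocab
loss = (1.0 - eps) * nll_loss + eps * smooth_loss / logits.size(-1)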
def _decode(self, tgt_token_ids, encoder_outputs, padding_mask):
    """Computes the estimated logits of target token ids, based on the encoded
    source sequences. Note this function should be called in training mode
    only.

    Args:
      tgt_token_ids: int tensor of shape [batch_size, tgt_seq_len], token ids
        of target sequences.
      encoder_outputs: float tensor of shape [batch_size, src_seq_len,
        hidden_size], the encoded source sequences to be used as reference.
      padding_mask: float tensor of shape [batch_size, 1, 1, src_seq_len],
        populated with either 0 (for tokens to keep) or 1 (for tokens to be
        masked).

    Returns:
      logits: float tensor of shape [batch_size, tgt_seq_len, vocab_size].
    """
    tgt_seq_len = tf.shape(tgt_token_ids)[1]

    # [batch_size, tgt_seq_len, hidden_size]
    tgt_token_embeddings = self._embedding_logits_layer(
        tgt_token_ids, 'embedding')

    # [tgt_seq_len, hidden_size]
    positional_encoding = utils.get_positional_encoding(
        tgt_seq_len, self._hidden_size)
    tgt_token_embeddings += positional_encoding
    tgt_token_embeddings = self._decoder_dropout_layer(
        tgt_token_embeddings, training=True)

    look_ahead_mask = utils.get_look_ahead_mask(tgt_seq_len)

    decoder_outputs = self._decoder(tgt_token_embeddings,
                                    encoder_outputs,
                                    look_ahead_mask,
                                    padding_mask,
                                    training=True)

    logits = self._embedding_logits_layer(decoder_outputs, 'logits')
    return logits
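# --- Illustration (not part of the model above) ---
# A minimal sketch of what `utils.get_look_ahead_mask(tgt_seq_len)` presumably returns,
# given the masking convention documented above (1 for positions to be masked, 0 for
# positions to keep): a [1, 1, tgt_seq_len, tgt_seq_len] float tensor that is 1
# strictly above the diagonal, so each query attends only to itself and earlier
# positions. The real helper may differ in details.
import tensorflow as tf

def look_ahead_mask(tgt_seq_len):
    causal = tf.linalg.band_part(tf.ones((tgt_seq_len, tgt_seq_len)), -1, 0)
    return (1.0 - causal)[tf.newaxis, tf.newaxis, :, :]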
def _build_decoding_fn(self, max_decode_length):
    """Builds the decoding function that will be called in beam search.

    The function steps through the proposed token ids one at a time, and
    generates the logits of the next token id over the vocabulary.

    Args:
      max_decode_length: int scalar, the decoded sequences would not exceed
        `max_decode_length`.

    Returns:
      decoding_fn: a callable that outputs the logits of the next decoded
        token ids.
    """
    # [max_decode_length, hidden_size]
    timing_signal = utils.get_positional_encoding(
        max_decode_length, self._hidden_size)
    timing_signal = tf.cast(timing_signal, 'float32')

    def decoding_fn(decoder_input, cache, **kwargs):
        """Computes the logits of the next decoded token ids.

        Args:
          decoder_input: int tensor of shape [batch_size * beam_width, 1], the
            decoded tokens at index `i`.
          cache: dict of entries
            'encoder_outputs': tensor of shape
              [batch_size * beam_width, src_seq_len, hidden_size],
            'padding_mask': tensor of shape
              [batch_size * beam_width, 1, 1, src_seq_len],
            and entries with keys 'layer_0', ...,
            'layer_[decoder_num_layers - 1]' where the value associated with
            key 'layer_*' is a dict with entries
              'k': tensor of shape
                [batch_size * beam_width, seq_len, num_heads, size_per_head],
              'v': tensor of shape
                [batch_size * beam_width, seq_len, num_heads, size_per_head],
              'tgt_tgt_attention': tensor of shape
                [batch_size * beam_width, num_heads, seq_len, seq_len],
              'tgt_src_attention': tensor of shape
                [batch_size * beam_width, num_heads, seq_len, src_seq_len].
            Note `seq_len` is the running length of the growing decode
            sequence.
          kwargs: dict, storing the following additional keyword arguments.
            index -> int scalar tensor, the index of the `decoder_input` in
              the decoded sequence.

        Returns:
          logits: float tensor of shape [batch_size * beam_width, vocab_size].
          cache: a dict with the same structure as the input `cache`, except
            that the shapes of the values of keys `k`, `v`,
            `tgt_tgt_attention`, `tgt_src_attention` are
            [batch_size * beam_width, seq_len + 1, num_heads, size_per_head],
            [batch_size * beam_width, seq_len + 1, num_heads, size_per_head],
            [batch_size * beam_width, num_heads, seq_len + 1, seq_len + 1],
            [batch_size * beam_width, num_heads, seq_len + 1, src_seq_len].
        """
        index = kwargs['index']

        # [batch_size * beam_width, 1, hidden_size]
        decoder_input = self._embedding_logits_layer(
            decoder_input, 'embedding')
        decoder_input += timing_signal[index:index + 1]

        decoder_outputs = self._decoder(decoder_input,
                                        cache['encoder_outputs'],
                                        tf.zeros((1, 1, 1, index + 1),
                                                 dtype='float32'),
                                        cache['padding_mask'],
                                        training=False,
                                        cache=cache)

        logits = self._embedding_logits_layer(decoder_outputs, mode='logits')
        logits = tf.squeeze(logits, axis=1)
        return logits, cache

    return decoding_fn
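# --- Illustration (not part of the model above) ---
# A hedged sketch of how the per-layer 'k'/'v' cache entries described in the docstring
# grow by one position per decoding step; `new_k` / `new_v` stand for the current
# step's key/value projections of shape
# [batch_size * beam_width, 1, num_heads, size_per_head]. The repo's decoder performs
# this kind of update internally when `cache=cache` is passed.
import tensorflow as tf

def append_to_kv_cache(layer_cache, new_k, new_v):
    layer_cache['k'] = tf.concat([layer_cache['k'], new_k], axis=1)
    layer_cache['v'] = tf.concat([layer_cache['v'], new_v], axis=1)
    return layer_cache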
def forward(self, question, question_mask, image_feats):
    # Input unit - encoding question
    question = self.q_embed(question)
    lstm_seq, (h, _) = pack_and_rnn(question, question_mask.sum(1), self.q_enc)
    q_vec = torch.cat([h[0], h[1]], -1)

    # Process kb
    position_encoding = get_positional_encoding(
        image_feats.size(1), image_feats.size(2), self.pe_dim,
        image_feats.device).repeat(image_feats.size(0), 1, 1, 1)
    kb_batch = torch.cat([image_feats, position_encoding], 3)
    kb_batch = channels_last_conv(kb_batch, self.kb_process)

    # init values
    control = self.init_ctrl.unsqueeze(0).repeat(image_feats.size(0), 1)
    att_stack, stack_ptr, mem = self.nmn.get_init_values(
        image_feats.size(0), image_feats.device)
    module_logits = []
    question_attns = []
    image_attns = []

    for i in range(self.steps):
        # Controller and NMN
        control, module_logit, module_probs, qattn = self.controller(
            question, lstm_seq, q_vec, control, question_mask, i)
        question_attns.append(qattn)
        module_logits.append(module_logit)
        # module validity
        if self.cfg.MODEL.NMN.VALIDATE_MODULES:
            module_validity = stack_ptr.float() @ self.nmn.module_validity_mat.to(
                stack_ptr.device)
            module_probs = module_probs * module_validity
            module_probs = module_probs / module_probs.sum(1).unsqueeze(1)
        # nmn
        att_stack, stack_ptr, mem = self.nmn(control, kb_batch, module_probs,
                                             mem, att_stack, stack_ptr)
        image_attns.append((att_stack * stack_ptr[:, None, None]).sum(-1))

    outputs = {
        "qattns": torch.stack(question_attns, 1),
        "iattns": torch.stack(image_attns, 1),
    }

    # output - two layer FC
    if self.cfg.MODEL.BUILD_VQA:
        output_logits = self.output_unit(torch.cat([q_vec, mem], 1))
        outputs["logits"] = output_logits

    # output for clevr-ref
    if self.cfg.MODEL.BUILD_LOC:
        att_last = self.nmn.get_stack_value(att_stack, stack_ptr)
        # first a linear layer (LOC_SCORES_POS_AFFINE)
        loc_scores = (torch.abs(self.output_loc_aff_w) * att_last +
                      self.output_loc_aff_b)
        loc_scores = loc_scores.view(
            -1, self.cfg.MODEL.H_FEAT * self.cfg.MODEL.W_FEAT)
        # one layer conv (BBOX_REG_AS_FCN)
        bbox_offset_fcn = channels_last_conv(kb_batch, self.loc_conv)
        N = bbox_offset_fcn.size(0)
        B = self.cfg.MODEL.H_FEAT * self.cfg.MODEL.W_FEAT
        # bbox_offset_fcn [N, B, 4] is used for training
        bbox_offset_fcn = bbox_offset_fcn.view(N, B, 4)
        # bbox_offset [N, 4] is only used for prediction
        bbox_offset_flat = bbox_offset_fcn.view(N * B, 4)
        slice_inds = (torch.arange(0, N, device=loc_scores.device) * B +
                      torch.argmax(loc_scores, dim=-1).long())
        bbox_offset = bbox_offset_flat[slice_inds]
        outputs["loc_scores"] = loc_scores
        outputs["bbox_offset"] = bbox_offset
        outputs["bbox_offset_fcn"] = bbox_offset_fcn

    outputs["module_logits"] = torch.stack(module_logits, 1)
    return outputs
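# --- Illustration (not part of the model above) ---
# A hedged sketch of what `channels_last_conv(x, module)` plausibly does, given that
# `kb_batch` above is laid out as [N, H, W, C] while torch conv modules expect
# [N, C, H, W]: permute to channels-first, apply the module, permute back. The repo's
# helper may differ.
import torch

def channels_last_conv(x, module):
    x = x.permute(0, 3, 1, 2)        # [N, H, W, C] -> [N, C, H, W]
    x = module(x)
    return x.permute(0, 2, 3, 1)     # [N, C', H, W] -> [N, H, W, C']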