Example #1
    def _encode(self, src_token_ids, padding_mask, training=False):
        """Converts source sequences token ids into continuous representation, and 
    computes the Encoder-encoded sequences.

    Args:
      src_token_ids: int tensor of shape [batch_size, src_seq_len], token ids
        of source sequences.
      padding_mask: float tensor of shape [batch_size, 1, 1, src_seq_len], 
        populated with either 0 (for tokens to keep) or 1 (for tokens to be 
        masked). 
      training: bool scalar, True if in training mode.

    Returns:
      encoder_outputs: float tensor of shape [batch_size, src_seq_len, 
        hidden_size], the encoded source sequences to be used as reference. 
    """
        src_seq_len = tf.shape(src_token_ids)[1]

        # [batch_size, src_seq_len, hidden_size]
        src_token_embeddings = self._embedding_logits_layer(
            src_token_ids, 'embedding')

        # [src_seq_len, hidden_size]
        positional_encoding = utils.get_positional_encoding(
            src_seq_len, self._hidden_size)
        src_token_embeddings += positional_encoding
        src_token_embeddings = self._encoder_dropout_layer(
            src_token_embeddings, training)

        encoder_outputs = self._encoder(src_token_embeddings, padding_mask,
                                        training)
        return encoder_outputs
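
Examples #1, #6, and #7 all call a `utils.get_positional_encoding(seq_len, hidden_size)` helper whose body is not shown on this page. A minimal sketch, assuming it implements the standard sinusoidal table from "Attention Is All You Need" (even hidden_size, float32 output of shape [seq_len, hidden_size]):

import numpy as np
import tensorflow as tf

def get_positional_encoding(seq_len, hidden_size):
    # Sinusoidal positional encoding: even columns hold sines, odd columns hold
    # cosines, with wavelengths growing geometrically across the hidden dimension.
    positions = np.arange(seq_len)[:, np.newaxis]      # [seq_len, 1]
    dims = np.arange(hidden_size)[np.newaxis, :]       # [1, hidden_size]
    angle_rates = 1.0 / np.power(10000.0, (2 * (dims // 2)) / np.float32(hidden_size))
    angles = positions * angle_rates                   # [seq_len, hidden_size]
    table = np.zeros((seq_len, hidden_size), dtype=np.float32)
    table[:, 0::2] = np.sin(angles[:, 0::2])
    table[:, 1::2] = np.cos(angles[:, 1::2])
    return tf.constant(table)                          # [seq_len, hidden_size]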
Example #2
    def _get_final_embeddings(self, inputs, memories, training):
        """Computes the final embedding vectors coming off the top layer of
    TransformerXL model.

    Args:
      inputs: int tensor of shape [batch_size, q_seq_len], token ids of the
        input sequence segment.
      memories: float tensor of shape [batch_size, stack_size, c_seq_len,
        hidden_size], embeddings of the tokens from the previous sequence
        segment for each layer of the decoder stack.
      training: bool scalar, True if in training mode.

    Returns:
      embeddings: float tensor of shape [batch_size, q_seq_len, hidden_size],
        the final embeddings of inputs.
      new_memories: float tensor of shape [batch_size, stack_size, c_seq_len,
        hidden_size], the updated embedding vectors of the memory sequence
        segment.
    """
        m_seq_len = tf.shape(memories)[2]
        batch_size, q_seq_len = tf.unstack(tf.shape(inputs), axis=0)
        new_memories = []

        # [batch_size, q_seq_len, hidden_size]
        embeddings = self._embedding_layer(inputs, mode='embedding')

        # [1, 1, q_seq_len, q_seq_len + m_seq_len]
        attention_mask = utils.get_look_ahead_mask(q_seq_len, m_seq_len)

        # [batch_size, q_seq_len + m_seq_len, hidden_size]
        positional_encoding = utils.get_positional_encoding(
            batch_size, m_seq_len + q_seq_len, self._hidden_size, reverse=True)

        embeddings = self._embeddings_dropout_layer(embeddings,
                                                    training=training)
        positional_encoding = self._positional_encoding_dropout_layer(
            positional_encoding, training=training)

        for i in range(self._stack_size):
            new_memories.append(utils.cache_memory(memories[:, i], embeddings))

            if self._tie_biases:
                content_bias, position_bias = self.weights[:2]
            else:
                content_bias, position_bias = (self.weights[0][i],
                                               self.weights[1][i])

            embeddings = self._stack[i](embeddings,
                                        tf.concat([memories[:, i], embeddings],
                                                  axis=1),
                                        positional_encoding,
                                        attention_mask,
                                        content_bias,
                                        position_bias,
                                        training=training)
        new_memories = tf.stack(new_memories, axis=1)
        return embeddings, new_memories
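
The `utils.cache_memory` call above is what rolls the current segment's embeddings into the per-layer memory used by the next segment. Its body is not shown here; a plausible sketch of the standard Transformer-XL memory update (keep only the most recent positions and block gradients into the cache), assuming the memory length stays fixed:

import tensorflow as tf

def cache_memory(memory, embeddings):
    # memory:     [batch_size, m_seq_len, hidden_size], cached states of one layer.
    # embeddings: [batch_size, q_seq_len, hidden_size], states of the new segment.
    m_seq_len = tf.shape(memory)[1]
    new_memory = tf.concat([memory, embeddings], axis=1)
    # Keep the most recent m_seq_len positions; no gradients flow into the cache.
    return tf.stop_gradient(new_memory[:, -m_seq_len:])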
Example #3
    def beam_decode(self, src, guess, src_lang_idx, tgt_lang_idx, logit_mask):
        embed_dim = self.args.embed_dim
        max_len = src.size(1) + 51
        pos_embedding = ut.get_positional_encoding(embed_dim, max_len)
        word_embedding = F.normalize(self.word_embedding, dim=-1) if self.args.fix_norm else self.word_embedding
        logit_mask = logit_mask if self.logit_mask is None else self.logit_mask
        tgt_lang_embed = self.lang_embedding[tgt_lang_idx]

        encoder1_inputs = self.get_input(src, src_lang_idx, word_embedding, pos_embedding)
        encoder1_mask = (src == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
        encoder1_outputs = self.encoder1(encoder1_inputs, encoder1_mask)

        encoder2_inputs = self.get_input(guess, tgt_lang_idx, word_embedding, pos_embedding)
        encoder2_mask = (guess == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
        encoder2_outputs = self.encoder2(encoder2_inputs, encoder2_mask)

        def get_tgt_inp(tgt, time_step):
            word_embed = F.embedding(tgt.type(src.type()), word_embedding) * self.scale
            pos_embed = pos_embedding[time_step, :].reshape(1, 1, -1)
            return word_embed + tgt_lang_embed + pos_embed

        def logprob_fn(decoder_output):
            logits = self.logit_fn(decoder_output, word_embedding, logit_mask)
            return F.log_softmax(logits, dim=-1)

        # Following "Attention Is All You Need", we decode up to src_len + 50 tokens only.
        max_lengths = torch.sum(src != ac.PAD_ID, dim=-1).type(src.type()) + 50
        return self.decoder.beam_decode(
            encoder1_outputs, encoder1_mask, encoder2_outputs, encoder2_mask,
            get_tgt_inp, logprob_fn, ac.BOS_ID, ac.EOS_ID, max_lengths,
            beam_size=self.args.beam_size, alpha=self.args.beam_alpha)
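
Note the argument order: the PyTorch examples #3 and #5 call `ut.get_positional_encoding(embed_dim, max_len)`, the reverse of the TensorFlow helper sketched above. A minimal sketch of that variant, assuming an even embed_dim and a returned table of shape [max_len, embed_dim]:

import math
import torch

def get_positional_encoding(embed_dim, max_len):
    # Sinusoidal table of shape [max_len, embed_dim]; row t is the encoding
    # added to the token embedding at position t (embed_dim assumed even).
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # [max_len, 1]
    div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / embed_dim))
    table = torch.zeros(max_len, embed_dim)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table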
Example #4
    def __getitem__(self, idx):
        current_question = self.questions[idx]
        img_filename = os.path.join(self.img_dir,
                                    current_question['image_filename'])
        # image = Image.open(img_filename).convert('RGB')
        image_id = int(img_filename.rsplit('_', 1)[1][:-4])

        image = self.img_features_file[image_id]

        _, H, W = image.shape
        position_encoding = get_positional_encoding(H, W)

        image = np.concatenate((image, position_encoding), axis=0)

        question = utils.to_dictionary_indexes(self.dictionaries[0],
                                               current_question['question'])
        answer = utils.to_dictionary_indexes(self.dictionaries[1],
                                             current_question['answer'])
        # answer_class = self.dictionaries[2][answer.item()]

        question_type = get_ques_type(
            current_question['program'][-1]['function'])

        answer = (answer - 1)  # convert to zero based indexing

        return image, question, len(question), answer, question_type
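
In this dataset example the positional encoding is two-dimensional and is concatenated to the [C, H, W] image features along the channel axis. The helper is not shown; a minimal sketch under the assumption that it returns normalized (x, y) coordinate channels of shape [2, H, W]:

import numpy as np

def get_positional_encoding(H, W):
    # Two extra channels holding the x and y coordinate of every grid cell,
    # normalized to [-1, 1], so they stack onto [C, H, W] features along axis 0.
    ys = np.linspace(-1.0, 1.0, H, dtype=np.float32)
    xs = np.linspace(-1.0, 1.0, W, dtype=np.float32)
    grid_y, grid_x = np.meshgrid(ys, xs, indexing='ij')   # each [H, W]
    return np.stack([grid_x, grid_y], axis=0)             # [2, H, W]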
Example #5
    def forward(self, src, tgt, targets, src_lang_idx, tgt_lang_idx,
                logit_mask):
        embed_dim = self.args.embed_dim
        max_len = max(src.size(1), tgt.size(1))
        pos_embedding = ut.get_positional_encoding(embed_dim, max_len)
        word_embedding = F.normalize(
            self.word_embedding,
            dim=-1) if self.args.fix_norm else self.word_embedding

        encoder_inputs = self.get_input(src, src_lang_idx, word_embedding,
                                        pos_embedding)
        encoder_mask = (src == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

        decoder_inputs = self.get_input(tgt, tgt_lang_idx, word_embedding,
                                        pos_embedding)
        decoder_mask = torch.triu(torch.ones((tgt.size(-1), tgt.size(-1))),
                                  diagonal=1).type(tgt.type()) == 1
        decoder_mask = decoder_mask.unsqueeze(0).unsqueeze(1)
        decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                       encoder_outputs, encoder_mask)

        logit_mask = logit_mask if self.logit_mask is None else self.logit_mask
        logits = self.logit_fn(decoder_outputs, word_embedding, logit_mask)
        neglprobs = F.log_softmax(logits, -1) * logit_mask.type(
            logits.type()).reshape(1, -1)
        targets = targets.reshape(-1, 1)
        non_pad_mask = targets != ac.PAD_ID

        nll_loss = neglprobs.gather(dim=-1, index=targets)[non_pad_mask]
        smooth_loss = neglprobs.sum(dim=-1, keepdim=True)[non_pad_mask]

        # label smoothing: https://arxiv.org/pdf/1701.06548.pdf
        nll_loss = -(nll_loss.sum())
        smooth_loss = -(smooth_loss.sum())
        label_smoothing = self.args.label_smoothing
        if label_smoothing > 0:
            loss = (
                1.0 - label_smoothing
            ) * nll_loss + label_smoothing * smooth_loss / logit_mask.type(
                nll_loss.type()).sum()
        else:
            loss = nll_loss

        num_words = non_pad_mask.type(loss.type()).sum()
        opt_loss = loss / num_words
        return {
            'opt_loss': opt_loss,
            'loss': loss,
            'nll_loss': nll_loss,
            'num_words': num_words
        }
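
The loss at the end combines the usual negative log-likelihood with a uniform-over-vocabulary term, i.e. loss = (1 - eps) * nll + eps * smooth / V, where V is the number of unmasked vocabulary entries. A toy sanity check of that combination (random logits, logit mask ignored for brevity):

import torch
import torch.nn.functional as F

logits = torch.randn(5, 8)                     # 5 target positions, vocab of 8
targets = torch.randint(0, 8, (5, 1))
logprobs = F.log_softmax(logits, dim=-1)

nll_loss = -logprobs.gather(dim=-1, index=targets).sum()   # standard cross-entropy part
smooth_loss = -logprobs.sum()                               # sum over the whole vocabulary
eps = 0.1
loss = (1.0 - eps) * nll_loss + eps * smooth_loss / logits.size(-1)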
Example #6
    def _decode(self, tgt_token_ids, encoder_outputs, padding_mask):
        """Computes the estimated logits of target token ids, based on the encoded 
    source sequences. Note this function should be called in training mode only.

    Args:
      tgt_token_ids: int tensor of shape [batch_size, tgt_seq_len] token ids of 
        target sequences.
      encoder_outputs: float tensor of shape [batch_size, src_seq_len, 
        hidden_size], the encoded source sequences to be used as reference. 
      padding_mask: float tensor of shape [batch_size, 1, 1, src_seq_len], 
        populated with either 0 (for tokens to keep) or 1 (for tokens to be 
        masked). 

    Returns:
      logits: float tensor of shape [batch_size, tgt_seq_len, vocab_size].
    """
        tgt_seq_len = tf.shape(tgt_token_ids)[1]

        # [batch_size, tgt_seq_len, hidden_size]
        tgt_token_embeddings = self._embedding_logits_layer(
            tgt_token_ids, 'embedding')

        # [tgt_seq_len, hidden_size]
        positional_encoding = utils.get_positional_encoding(
            tgt_seq_len, self._hidden_size)
        tgt_token_embeddings += positional_encoding
        tgt_token_embeddings = self._decoder_dropout_layer(
            tgt_token_embeddings, training=True)

        look_ahead_mask = utils.get_look_ahead_mask(tgt_seq_len)

        decoder_outputs = self._decoder(tgt_token_embeddings,
                                        encoder_outputs,
                                        look_ahead_mask,
                                        padding_mask,
                                        training=True)

        logits = self._embedding_logits_layer(decoder_outputs, 'logits')
        return logits
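
Besides the padding mask, `_decode` builds a look-ahead mask so each target position attends only to itself and earlier positions (the Transformer-XL example uses a two-argument form that also covers the memory length). A minimal sketch of the single-argument helper used here, assuming 1.0 marks positions to mask, matching the padding-mask convention in the docstring:

import tensorflow as tf

def get_look_ahead_mask(seq_len):
    # Upper-triangular matrix of ones above the diagonal: entry [i, j] is 1.0
    # when j > i, i.e. position i must not attend to the future position j.
    mask = 1.0 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
    return mask[tf.newaxis, tf.newaxis, :, :]   # [1, 1, seq_len, seq_len]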
Example #7
    def _build_decoding_fn(self, max_decode_length):
        """Builds the decoding function that will be called in beam search.

    The function steps through the proposed token ids one at a time, and 
    generates the logits of next token id over the vocabulary.

    Args:
      max_decode_length: int scalar, the decoded sequences would not exceed
        `max_decode_length`.

    Returns:
      decoding_fn: a callable that outputs the logits of the next decoded token
        ids.
    """
        # [max_decode_length, hidden_size]
        timing_signal = utils.get_positional_encoding(max_decode_length,
                                                      self._hidden_size)
        timing_signal = tf.cast(timing_signal, 'float32')

        def decoding_fn(decoder_input, cache, **kwargs):
            """Computes the logits of the next decoded token ids.

      Args:
        decoder_input: int tensor of shape [batch_size * beam_width, 1], the 
          decoded tokens at index `i`.
        cache: dict of entries
          'encoder_outputs': tensor of shape 
            [batch_size * beam_width, src_seq_len, hidden_size],
          'padding_mask': tensor of shape
            [batch_size * beam_width, 1, 1, src_seq_len],

          and entries with keys 'layer_0',...,'layer_[decoder_num_layers - 1]'
          where the value associated with key 'layer_*' is a dict with entries
            'k': tensor of shape [batch_size * beam_width, seq_len, num_heads, 
              size_per_head],
            'v': tensor of shape [batch_size * beam_width, seq_len, num_heads, 
              size_per_head],
            'tgt_tgt_attention': tensor of shape [batch_size * beam_width, 
              num_heads, seq_len, seq_len],
            'tgt_src_attention': tensor of shape [batch_size * beam_width, 
              num_heads, seq_len, src_seq_len].
            Note `seq_len` is the running length of the growing decode sequence.
        kwargs: dict, storing the following additional keyword arguments.
          index -> int scalar tensor, the index of the `decoder_input` in the 
            decoded sequence.

      Returns:
        logits: float tensor of shape [batch_size * beam_width, vocab_size].
        cache: a dict with the same structure as the input `cache`, except that
          the shapes of the values of key `k`, `v`, `tgt_tgt_attention`, 
          `tgt_src_attention` are
          [batch_size * beam_width, seq_len + 1, num_heads, size_per_head],
          [batch_size * beam_width, seq_len + 1, num_heads, size_per_head],
          [batch_size * beam_width, num_heads, seq_len + 1, seq_len + 1],
          [batch_size * beam_width, num_heads, seq_len + 1, src_seq_len].
      """
            index = kwargs['index']
            # [batch_size * beam_width, 1, hidden_size]
            decoder_input = self._embedding_logits_layer(
                decoder_input, 'embedding')
            decoder_input += timing_signal[index:index + 1]

            decoder_outputs = self._decoder(decoder_input,
                                            cache['encoder_outputs'],
                                            tf.zeros((1, 1, 1, index + 1),
                                                     dtype='float32'),
                                            cache['padding_mask'],
                                            training=False,
                                            cache=cache)

            logits = self._embedding_logits_layer(decoder_outputs,
                                                  mode='logits')
            logits = tf.squeeze(logits, axis=1)
            return logits, cache

        return decoding_fn
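
The returned `decoding_fn` is driven one position at a time by beam search: it receives the token decoded at step `index` plus the running cache, and must hand back next-token logits together with the grown cache. A toy illustration of that calling convention with a stand-in function (not the model's decoder):

import tensorflow as tf

VOCAB_SIZE = 7

def dummy_decoding_fn(decoder_input, cache, **kwargs):
    # Same protocol as `decoding_fn` above: [batch * beam, 1] token ids in,
    # [batch * beam, vocab_size] logits and the (possibly updated) cache out.
    del kwargs  # `index` would pick the positional encoding in the real model
    batch = tf.shape(decoder_input)[0]
    return tf.random.normal([batch, VOCAB_SIZE]), cache

ids = tf.zeros([2, 1], dtype=tf.int32)          # start tokens for batch * beam = 2
cache = {}
for index in range(3):                          # greedy roll-out of three steps
    logits, cache = dummy_decoding_fn(ids[:, -1:], cache, index=index)
    next_id = tf.argmax(logits, axis=-1, output_type=tf.int32)[:, tf.newaxis]
    ids = tf.concat([ids, next_id], axis=1)     # [2, 1 + step]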
Example #8
    def forward(self, question, question_mask, image_feats):
        # Input unit - encoding question
        question = self.q_embed(question)
        lstm_seq, (h, _) = pack_and_rnn(question, question_mask.sum(1),
                                        self.q_enc)
        q_vec = torch.cat([h[0], h[1]], -1)
        # Process kb
        position_encoding = get_positional_encoding(
            image_feats.size(1), image_feats.size(2), self.pe_dim,
            image_feats.device).repeat(image_feats.size(0), 1, 1, 1)
        kb_batch = torch.cat([image_feats, position_encoding], 3)
        kb_batch = channels_last_conv(kb_batch, self.kb_process)
        # init values
        control = self.init_ctrl.unsqueeze(0).repeat(image_feats.size(0), 1)
        att_stack, stack_ptr, mem = self.nmn.get_init_values(
            image_feats.size(0), image_feats.device)
        module_logits = []
        question_attns = []
        image_attns = []
        for i in range(self.steps):
            # Controller and NMN
            control, module_logit, module_probs, qattn = self.controller(
                question, lstm_seq, q_vec, control, question_mask, i)
            question_attns.append(qattn)
            module_logits.append(module_logit)
            # module validity
            if self.cfg.MODEL.NMN.VALIDATE_MODULES:
                module_validity = stack_ptr.float() @ self.nmn.module_validity_mat.to(
                    stack_ptr.device)
                module_probs = module_probs * module_validity
                module_probs = module_probs / module_probs.sum(1).unsqueeze(1)
            # nmn
            att_stack, stack_ptr, mem = self.nmn(control, kb_batch,
                                                 module_probs, mem, att_stack,
                                                 stack_ptr)
            image_attns.append((att_stack * stack_ptr[:, None, None]).sum(-1))
        outputs = {
            "qattns": torch.stack(question_attns, 1),
            "iattns": torch.stack(image_attns, 1),
        }
        # output - two layer FC
        if self.cfg.MODEL.BUILD_VQA:
            output_logits = self.output_unit(torch.cat([q_vec, mem], 1))
            outputs["logits"] = output_logits
        # output for clevr-ref
        if self.cfg.MODEL.BUILD_LOC:
            att_last = self.nmn.get_stack_value(att_stack, stack_ptr)
            # first a linear layer (LOC_SCORES_POS_AFFINE)
            loc_scores = (torch.abs(self.output_loc_aff_w) * att_last +
                          self.output_loc_aff_b)
            loc_scores = loc_scores.view(
                -1, self.cfg.MODEL.H_FEAT * self.cfg.MODEL.W_FEAT)
            # one layer conv (BBOX_REG_AS_FCN)
            bbox_offset_fcn = channels_last_conv(kb_batch, self.loc_conv)
            N = bbox_offset_fcn.size(0)
            B = self.cfg.MODEL.H_FEAT * self.cfg.MODEL.W_FEAT
            # bbox_offset_fcn [N, B, 4] is used for training
            bbox_offset_fcn = bbox_offset_fcn.view(N, B, 4)
            # bbox_offset [N, 4] is only used for prediction
            bbox_offset_flat = bbox_offset_fcn.view(N * B, 4)
            slice_inds = (torch.arange(0, N, device=loc_scores.device) * B +
                          torch.argmax(loc_scores, dim=-1).long())
            bbox_offset = bbox_offset_flat[slice_inds]
            outputs["loc_scores"] = loc_scores
            outputs["bbox_offset"] = bbox_offset
            outputs["bbox_offset_fcn"] = bbox_offset_fcn

        outputs["module_logits"] = torch.stack(module_logits, 1)
        return outputs
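
Here the positional encoding is again two-dimensional, but channels-last and `pe_dim` wide, so it can be concatenated with image_feats of shape [N, H, W, C]. The helper itself is not shown; a plausible sinusoidal sketch, assuming pe_dim is divisible by 4 and the returned shape is [1, H, W, pe_dim] (repeated over the batch by the caller):

import torch

def get_positional_encoding(H, W, pe_dim, device):
    # pe_dim / 4 frequencies per axis, each contributing a sine and a cosine
    # channel, built from the normalized (x, y) location of every grid cell.
    assert pe_dim % 4 == 0
    ys = torch.linspace(-1.0, 1.0, H, device=device)
    xs = torch.linspace(-1.0, 1.0, W, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')       # each [H, W]
    freqs = torch.arange(pe_dim // 4, dtype=torch.float32, device=device)
    freqs = 1.0 / (10000.0 ** (freqs / max(pe_dim // 4, 1)))
    feats = []
    for grid in (grid_x, grid_y):
        angles = grid.unsqueeze(-1) * freqs                      # [H, W, pe_dim // 4]
        feats += [torch.sin(angles), torch.cos(angles)]
    return torch.cat(feats, dim=-1).unsqueeze(0)                 # [1, H, W, pe_dim]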