Example #1
    def test_swap_is_not_persisted_in_class(self):
        opt = self._opt()
        dictionary = DictionaryAgent(opt)

        CustomFFN = type('CustomFFN', (TransformerFFN,), {})
        wrapped_class = TransformerGeneratorModel.with_components(
            encoder=TransformerEncoder.with_components(
                layer=TransformerEncoderLayer.with_components(
                    feedforward=CustomFFN)))
        model = wrapped_class(opt=opt, dictionary=dictionary)
        assert (
            model.swappables.encoder.swappables.layer.swappables.feedforward ==
            CustomFFN)  # type: ignore

        another_model = TransformerGeneratorModel(opt, dictionary)
        assert another_model.swappables != model.swappables
        assert issubclass(another_model.swappables.encoder,
                          TransformerEncoder)  # type: ignore

        wrapped_class.swap_components(
            encoder=TransformerEncoder.with_components(
                layer=TransformerEncoderLayer.with_components(
                    feedforward=TransformerFFN)))
        one_more_model = wrapped_class(opt=opt, dictionary=dictionary)
        assert (
            one_more_model.swappables.encoder.swappables.layer.swappables.feedforward
            == TransformerFFN)  # type: ignore
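Note: `with_components` returns a new wrapped class and leaves the original untouched, while `swap_components` mutates the wrapped class in place; the test above verifies both behaviors. Below is a minimal sketch of the same pattern outside a test, assuming ParlAI's standard `parlai.agents.transformer.modules` layout; `opt` and `dictionary` must come from a real agent setup.

    # Sketch only: assumes the standard ParlAI module paths.
    from parlai.agents.transformer.modules import (
        TransformerEncoder,
        TransformerEncoderLayer,
        TransformerFFN,
        TransformerGeneratorModel,
    )

    class CustomFFN(TransformerFFN):
        """Hypothetical drop-in replacement for the default feedforward block."""

    # Returns a new wrapped class; TransformerGeneratorModel itself is unchanged.
    WrappedModel = TransformerGeneratorModel.with_components(
        encoder=TransformerEncoder.with_components(
            layer=TransformerEncoderLayer.with_components(feedforward=CustomFFN)))
    # model = WrappedModel(opt=opt, dictionary=dictionary)

    # Mutates WrappedModel in place, restoring the stock feedforward.
    WrappedModel.swap_components(
        encoder=TransformerEncoder.with_components(
            layer=TransformerEncoderLayer.with_components(feedforward=TransformerFFN)))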
Example #2
 def _build_context_encoder(self):
     """
     Build the context (i.e. dialogue history) encoder.
     """
     if self.opt.get("share_encoder"):
         self.context_encoder = self.label_encoder
     else:
         if (self.opt["load_context_encoder_from"] is None
                 and self.opt["context_encoder_embedding_type"] == "fasttext_cc"):
             embeddings = load_fasttext_embeddings(
                 self.dictionary, self.opt["embedding_size"],
                 self.opt["datapath"])
         else:
             embeddings = nn.Embedding(len(self.dictionary),
                                       self.opt["embedding_size"])
         self.context_encoder = TransformerEncoder(
             opt=self.opt,
             embedding=embeddings,
             vocabulary_size=len(self.dictionary),
             padding_idx=self.dictionary.tok2ind[
                 self.dictionary.null_token],
             embeddings_scale=False,
             output_scaling=1.0,
         )
         if self.opt.get("load_context_encoder_from") is not None:
             self._load_context_encoder_state()
Example #3
    def _build_text_encoder(self, n_layers_text):
        """
        Build the text (candidate) encoder.

        :param n_layers_text:
            how many layers the transformer will have
        """
        self.embeddings = nn.Embedding(len(self.dictionary),
                                       self.opt['embedding_size'])
        if (self.opt.get('load_encoder_from') is None
                and self.opt['embedding_type'] == 'fasttext_cc'):
            self.embeddings = load_fasttext_embeddings(
                self.dictionary, self.opt['embedding_size'],
                self.opt['datapath'])

        self.text_encoder = TransformerEncoder(
            opt=self.opt,
            embedding=self.embeddings,
            vocabulary_size=len(self.dictionary),
            padding_idx=self.dictionary.tok2ind[self.dictionary.null_token],
            embeddings_scale=False,
            output_scaling=1.0,
        )
        if self.opt.get('load_encoder_from') is not None:
            self._load_text_encoder_state()

        self.additional_layer = LinearWrapper(
            self.opt['embedding_size'],
            self.opt['hidden_dim'],
            dropout=self.opt['additional_layer_dropout'],
        )
Example #4
 def __init__(self, opt: ModelOption, tokenizer: Tokenizer,
              embedding: nn.Embedding):
     super(TransformerMemNet, self).__init__()
     self.encoder = TransformerEncoder(
         n_heads=opt.n_heads,
         n_layers=opt.n_layers,
         embedding_size=embedding.weight.shape[1],
         vocabulary_size=embedding.weight.shape[0],
         ffn_size=opt.ffn_size,
         embedding=embedding,
         dropout=opt.dropout,
         padding_idx=tokenizer.dictionary[tokenizer.dictionary.null_token],
         n_positions=opt.n_positions,
         reduction_type='none')
     self.decoder = TransformerDecoder(
         n_heads=opt.n_heads,
         n_layers=opt.n_layers,
         embedding_size=embedding.weight.shape[1],
         ffn_size=opt.ffn_size,
         vocabulary_size=embedding.weight.shape[0],
         embedding=embedding,
         dropout=opt.dropout,
         n_positions=opt.n_positions,
         padding_idx=tokenizer.dictionary[tokenizer.dictionary.null_token])
     self.encoder = ContextKnowledgeEncoder(self.encoder)
     self.decoder = ContextKnowledgeDecoder(self.decoder)
     # project hidden states to vocabulary logits; the in-features of 300
     # appears to assume a fixed 300-dim embedding/hidden size
     self.linear = nn.Linear(300, embedding.weight.shape[0])
Example #5
 def __init__(self, transformer, opt, dictionary, agenttype):
   super().__init__()
   # The transformer takes care of most of the work, but other modules
   # expect us to have embeddings available
   self.embeddings = transformer.embeddings
   self.embed_dim = transformer.embeddings.embedding_dim
   self.transformer = transformer
   self.knowledge_transformer = TransformerEncoder(
       embedding=self.embeddings,
       n_heads=opt['n_heads'],
       n_layers=opt['n_layers_knowledge'],
       embedding_size=opt['embedding_size'],
       ffn_size=opt['ffn_size'],
       vocabulary_size=len(dictionary),
       padding_idx=transformer.padding_idx,
       learn_positional_embeddings=opt['learn_positional_embeddings'],
       embeddings_scale=opt['embeddings_scale'],
       reduction_type=transformer.reduction_type,
       n_positions=transformer.n_positions,
       activation=opt['activation'],
       variant=opt['variant'],
       output_scaling=opt['output_scaling'],
   )
   self.agenttype = agenttype
   if self.agenttype == 'agent':
     self.intent_head = ClassificationHead(opt['embedding_size'])
     self.name_head = MultiTokenClassificationHead(opt['embedding_size'],
                                                   self.embeddings,
                                                   opt.get('name_vec_len'))
     self.reservation_transformer = TransformerEncoder(
         embedding=self.embeddings,
         n_heads=opt['n_heads'],
         n_layers=opt['n_layers_knowledge'],
         embedding_size=opt['embedding_size'],
         ffn_size=opt['ffn_size'],
         vocabulary_size=len(dictionary),
         padding_idx=transformer.padding_idx,
         learn_positional_embeddings=opt['learn_positional_embeddings'],
         embeddings_scale=opt['embeddings_scale'],
         reduction_type=transformer.reduction_type,
         n_positions=transformer.n_positions,
         activation=opt['activation'],
         variant=opt['variant'],
         output_scaling=opt['output_scaling'],
     )
     self.know_use_project = nn.Linear(opt['embedding_size'],
                                       opt['embedding_size'])
Example #6
 @classmethod
 def build_encoder(
     cls, opt, dictionary, embedding=None, padding_idx=None, reduction_type='mean'
 ):
     return TransformerEncoder(
         opt=opt,
         embedding=embedding,
         vocabulary_size=len(dictionary),
         padding_idx=padding_idx,
         reduction_type=reduction_type,
     )
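A hedged usage sketch for this opt-driven classmethod; the parser setup follows ParlAI's usual pattern, and `MyRankerModule` is a hypothetical host class for the `build_encoder` above.

    # Sketch only: assumes ParlAI's standard parser and dictionary classes.
    from parlai.core.params import ParlaiParser
    from parlai.core.dict import DictionaryAgent

    parser = ParlaiParser(add_parlai_args=True, add_model_args=True)
    opt = parser.parse_args([])  # defaults; real runs pass CLI arguments
    dictionary = DictionaryAgent(opt)
    # encoder = MyRankerModule.build_encoder(opt, dictionary, reduction_type='mean')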
Example #7
    def _build_text_encoder(self, n_layers_text):
        """
        Build the text (candidate) encoder.

        :param n_layers_text:
            how many layers the transformer will have
        """
        self.embeddings = nn.Embedding(len(self.dictionary),
                                       self.opt['embedding_size'])
        if (self.opt.get('load_encoder_from') is None
                and self.opt['embedding_type'] == 'fasttext_cc'):
            self.embeddings = load_fasttext_embeddings(
                self.dictionary, self.opt['embedding_size'],
                self.opt['datapath'])

        self.text_encoder = TransformerEncoder(
            n_heads=self.opt['n_heads'],
            n_layers=self.opt['n_layers'],
            embedding_size=self.opt['embedding_size'],
            ffn_size=self.opt['ffn_size'],
            vocabulary_size=len(self.dictionary),
            embedding=self.embeddings,
            dropout=self.opt['dropout'],
            attention_dropout=self.opt['attention_dropout'],
            relu_dropout=self.opt['relu_dropout'],
            padding_idx=self.dictionary.tok2ind[self.dictionary.null_token],
            learn_positional_embeddings=self.opt['learn_positional_embeddings'],
            embeddings_scale=False,
            n_positions=self.opt['n_positions'],
            activation=self.opt['activation'],
            variant=self.opt['variant'],
            n_segments=self.opt['n_segments'],
        )
        if self.opt.get('load_encoder_from') is not None:
            self._load_text_encoder_state()

        self.additional_layer = LinearWrapper(
            self.opt['embedding_size'],
            self.opt['hidden_dim'],
            dropout=self.opt['additional_layer_dropout'],
        )
Example #8
 def build_encoder(self, opt, embeddings):
     return TransformerEncoder(
         opt=opt,
         embedding=embeddings,
         vocabulary_size=self.vocab_size,
         padding_idx=self.pad_idx,
         dropout=0.0,
         n_positions=1024,
         n_segments=0,
         activation='relu',
         variant='aiayn',
         output_scaling=1.0,
     )
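Note the two construction styles in this listing: Examples #6 and #8 pass an `opt` and let the encoder read the remaining hyperparameters from it, while Examples #4, #7, #9, and #10 spell every hyperparameter out; which signature is available depends on the ParlAI version. A side-by-side sketch, with `opt`, `embeddings`, `vocab_size`, and `pad_idx` assumed to exist:

    # Opt-driven style: unspecified settings fall back to the values in `opt`.
    encoder = TransformerEncoder(
        opt=opt,
        embedding=embeddings,
        vocabulary_size=vocab_size,
        padding_idx=pad_idx,
    )

    # Explicit-kwargs style: every hyperparameter passed individually.
    encoder = TransformerEncoder(
        n_heads=opt['n_heads'],
        n_layers=opt['n_layers'],
        embedding_size=opt['embedding_size'],
        ffn_size=opt['ffn_size'],
        vocabulary_size=vocab_size,
        embedding=embeddings,
        padding_idx=pad_idx,
    )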
Example #9
 def _build_context_encoder(self):
     """
     Build the context (i.e. dialogue history) encoder.
     """
     if self.opt.get("share_encoder"):
         self.context_encoder = self.label_encoder
     else:
         if (self.opt["load_context_encoder_from"] is None
                 and self.opt["context_encoder_embedding_type"] == "fasttext_cc"):
             embeddings = load_fasttext_embeddings(
                 self.dictionary, self.opt["embedding_size"],
                 self.opt["datapath"])
         else:
             embeddings = nn.Embedding(len(self.dictionary),
                                       self.opt["embedding_size"])
         self.context_encoder = TransformerEncoder(
             n_heads=self.opt["n_heads"],
             n_layers=self.opt["n_layers"],
             embedding_size=self.opt["embedding_size"],
             ffn_size=self.opt["ffn_size"],
             vocabulary_size=len(self.dictionary),
             embedding=embeddings,
             dropout=self.opt["dropout"],
             attention_dropout=self.opt["attention_dropout"],
             relu_dropout=self.opt["relu_dropout"],
             padding_idx=self.dictionary.tok2ind[
                 self.dictionary.null_token],
              learn_positional_embeddings=self.opt["learn_positional_embeddings"],
             embeddings_scale=False,
             n_positions=self.opt["n_positions"],
             activation=self.opt["activation"],
             variant=self.opt["variant"],
             n_segments=self.opt["n_segments"],
         )
         if self.opt.get("load_context_encoder_from") is not None:
             self._load_context_encoder_state()
Example #10
 def build_encoder(self, opt, embeddings):
     return TransformerEncoder(
         n_heads=opt['n_heads'],
         n_layers=opt['n_layers'],
         embedding_size=opt['embedding_size'],
         ffn_size=opt['ffn_size'],
         vocabulary_size=self.vocab_size,
         embedding=embeddings,
         attention_dropout=opt['attention_dropout'],
         relu_dropout=opt['relu_dropout'],
         padding_idx=self.pad_idx,
         learn_positional_embeddings=opt.get('learn_positional_embeddings', False),
         embeddings_scale=opt['embeddings_scale'],
     )
Example #11
 def build_model(self, states=None):
     wrapped_class = TransformerGeneratorModel.with_components(
         encoder=TransformerEncoder.with_components(
             layer=TransformerEncoderLayer.with_components(
                 self_attention=MultiHeadAttention,
                 feedforward=TransformerFFN)),
         decoder=TransformerDecoder.with_components(
             layer=TransformerDecoderLayer.with_components(
                 encoder_attention=MultiHeadAttention,
                 self_attention=MultiHeadAttention,
                 feedforward=TransformerFFN,
             )),
     )
     return wrapped_class(opt=self.opt, dictionary=self.dict)
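This wiring restates the default components explicitly; replacing any single entry customizes that one subcomponent while the rest keep their defaults. A hedged variant, with `MyAttention` purely hypothetical:

    # Hypothetical: swap only the encoder self-attention, keep every other default.
    wrapped_class = TransformerGeneratorModel.with_components(
        encoder=TransformerEncoder.with_components(
            layer=TransformerEncoderLayer.with_components(
                self_attention=MyAttention)))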
Example #12
 def test_swap_encoder_attention(self):
     CustomFFN = type('CustomFFN', (TransformerFFN,), {})
     CustomFFN.forward = MagicMock()
     wrapped_class = TransformerGeneratorModel.with_components(
         encoder=TransformerEncoder.with_components(
             layer=TransformerEncoderLayer.with_components(
                 feedforward=CustomFFN)))
     opt = self._opt()
     CustomFFN.forward.assert_not_called()
     model = wrapped_class(opt=opt, dictionary=DictionaryAgent(opt))
     assert isinstance(model, TransformerGeneratorModel)  # type: ignore
     try:
         model(torch.zeros(1, 1).long(),
               ys=torch.zeros(1, 1).long())  # type: ignore
     except TypeError:
         pass
     finally:
         CustomFFN.forward.assert_called()
Example #13
class TransresnetMultimodalModel(TransresnetModel):
    """
    Extension of Transresnet to incorporate dialogue history and multimodality.
    """
    @staticmethod
    def add_cmdline_args(argparser):
        """
        Override to include model-specific args.
        """
        TransresnetModel.add_cmdline_args(argparser)
        agent = argparser.add_argument_group(
            "TransresnetMultimodal task arguments")
        agent.add_argument(
            "--context-encoder-embedding-type",
            type=str,
            default=None,
            choices=[None, "fasttext_cc"],
            help="Specify if using pretrained embeddings",
        )
        agent.add_argument(
            "--load-context-encoder-from",
            type=str,
            default=None,
            help="Specify if using a pretrained transformer encoder",
        )
        agent.add_argument(
            "--share-encoder",
            type="bool",
            default=False,
            help="Whether to share the text encoder for the "
            "labels and the dialogue history",
        )
        agent.add_argument("--num-layers-multimodal-encoder",
                           type=int,
                           default=1)
        agent.add_argument(
            "--multimodal",
            type="bool",
            default=False,
            help="If true, feed a query term into a separate "
            "transformer prior to computing final rank "
            "scores",
        )
        agent.add_argument(
            "--multimodal-combo",
            type=str,
            choices=["concat", "sum"],
            default="sum",
            help="How to combine the encoding for the "
            "multi-modal transformer",
        )
        agent.add_argument(
            "--encode-image",
            type="bool",
            default=True,
            help="Whether to include the image encoding when "
            "retrieving a candidate response",
        )
        agent.add_argument(
            "--encode-dialogue-history",
            type="bool",
            default=True,
            help="Whether to include the dialogue history "
            "encoding when retrieving a candidate response",
        )
        agent.add_argument(
            "--encode-personality",
            type="bool",
            default=True,
            help="Whether to include the personality encoding "
            "when retrieving a candidate response",
        )

    def __init__(self, opt, personalities_list, dictionary):
        super().__init__(opt, personalities_list, dictionary)
        self.hidden_dim = self.opt["hidden_dim"]
        self.share_encoder = opt.get("share_encoder")
        nlayers_mm = (opt["num_layers_all"] if opt["num_layers_all"] != -1 else
                      opt["num_layers_multimodal_encoder"])

        # blank encoding (for concat)
        self.blank_encoding = torch.Tensor(
            opt["hidden_dim"]).fill_(0).detach_()
        if self.use_cuda:
            self.blank_encoding = self.blank_encoding.cuda()

        # Encoders
        self.encode_image = opt.get("encode_image", True)
        self.encode_personality = opt.get("encode_personality", True)
        self.encode_dialogue_history = opt.get("encode_dialogue_history", True)
        assert any([
            self.encode_dialogue_history, self.encode_image,
            self.encode_personality
        ])

        # Transformer 2
        self._build_multimodal_encoder(nlayers_mm)

        # Label Encoder
        self.label_encoder = self.text_encoder

        # Context encoder
        self._build_context_encoder()

    def _build_multimodal_encoder(self, n_layers_mm):
        """
        Build the multimodal encoder.

        :param n_layers_mm:
            number of layers for the transformer
        """
        self.multimodal = self.opt.get("multimodal")
        if self.multimodal:
            self.multimodal_combo = self.opt.get("multimodal_combo", "sum")
            nlayers_mm = (self.opt["num_layers_all"]
                          if self.opt["num_layers_all"] != -1 else
                          self.opt["num_layers_multimodal_encoder"])
            self.multimodal_encoder = MultimodalCombiner(
                n_heads=self.opt["n_heads"],
                n_layers=nlayers_mm,
                hidden_dim=self.opt["hidden_dim"],
                ffn_size=self.opt["embedding_size"] * 4,
                attention_dropout=self.opt["attention_dropout"],
                relu_dropout=self.opt["relu_dropout"],
                learn_positional_embeddings=self.opt.get(
                    "learn_positional_embeddings", False),
                reduction=True,
            )

    def _build_context_encoder(self):
        """
        Build the context (i.e. dialogue history) encoder.
        """
        if self.opt.get("share_encoder"):
            self.context_encoder = self.label_encoder
        else:
            if (self.opt["load_context_encoder_from"] is None
                    and self.opt["context_encoder_embedding_type"] == "fasttext_cc"):
                embeddings = load_fasttext_embeddings(
                    self.dictionary, self.opt["embedding_size"],
                    self.opt["datapath"])
            else:
                embeddings = nn.Embedding(len(self.dictionary),
                                          self.opt["embedding_size"])
            self.context_encoder = TransformerEncoder(
                n_heads=self.opt["n_heads"],
                n_layers=self.opt["n_layers"],
                embedding_size=self.opt["embedding_size"],
                ffn_size=self.opt["ffn_size"],
                vocabulary_size=len(self.dictionary),
                embedding=embeddings,
                dropout=self.opt["dropout"],
                attention_dropout=self.opt["attention_dropout"],
                relu_dropout=self.opt["relu_dropout"],
                padding_idx=self.dictionary.tok2ind[
                    self.dictionary.null_token],
                learn_positional_embeddings=self.opt["learn_positional_embeddings"],
                embeddings_scale=False,
                n_positions=self.opt["n_positions"],
                activation=self.opt["activation"],
                variant=self.opt["variant"],
                n_segments=self.opt["n_segments"],
            )
            if self.opt.get("load_context_encoder_from") is not None:
                self._load_context_encoder_state()

    def forward(
        self,
        image_features,
        personalities,
        dialogue_histories,
        labels,
        batchsize=None,
        personalities_tensor=None,
    ):
        """
        Model forward pass.

        :param image_features:
            list of tensors of image features, one per example
        :param personalities:
            list of personalities, one per example
        :param dialogue_histories:
            list of dialogue histories, one per example
        :param labels:
            list of response labels, one per example
        :param personalities_tensor:
            (optional) list of personality representations, usually a one-hot
            vector if specified

        :return:
            the batch loss, the number of correct examples, and the encoded
            context representation
        """
        # labels
        labels_encoded = self.forward_text_encoder(labels)
        # dialog history
        d_hist_encoded = self.forward_text_encoder(dialogue_histories,
                                                   dialogue_history=True,
                                                   batchsize=batchsize)
        # images
        img_encoded = self.forward_image(image_features)
        # personalities
        pers_encoded = self.forward_personality(personalities,
                                                personalities_tensor)
        total_encoded = self.get_rep(
            [img_encoded, d_hist_encoded, pers_encoded], batchsize=batchsize)
        loss, nb_ok = self.get_loss(total_encoded, labels_encoded)

        return loss, nb_ok, total_encoded

    def forward_personality(self, personalities, personalities_tensor):
        """
        Encode personalities.

        :param personalities:
            list of personalities, one per example
        :param personalities_tensor:
            (optional) list of personality representations, usually a one-hot
            vector if specified

        :return:
            encoded representation of the personalities
        """
        pers_encoded = None
        if not self.encode_personality:
            if self.multimodal and self.multimodal_combo == "concat":
                pers_encoded = self.blank_encoding
        else:
            pers_encoded = super().forward_personality(personalities,
                                                       personalities_tensor)
        return pers_encoded

    def forward_text_encoder(self,
                             texts,
                             dialogue_history=False,
                             batchsize=None):
        """
        Forward pass for a text encoder.

        :param texts:
            text to encode
        :param dialogue_history:
            flag that indicates whether the text is dialogue history; if False,
            text is a response candidate
        :param batchsize:
            size of the batch

        :return:
            encoded representation of the `texts`
        """
        texts_encoded = None
        if texts is None or (dialogue_history
                             and not self.encode_dialogue_history):
            if (self.multimodal and self.multimodal_combo == "concat"
                    and dialogue_history):
                texts_encoded = torch.stack(
                    [self.blank_encoding for _ in range(batchsize)])
        else:
            encoder = self.context_encoder if dialogue_history else self.label_encoder
            indexes, mask = self.captions_to_tensor(texts)
            texts_encoded = encoder(indexes)
            if self.text_encoder_frozen:
                texts_encoded = texts_encoded.detach()
            texts_encoded = self.additional_layer(texts_encoded)

        return texts_encoded

    def forward_image(self, image_features):
        """
        Encode image features.

        :param image_features:
            list of image features

        :return:
            encoded representation of the image features
        """
        img_encoded = None
        if image_features is None or not self.encode_image:
            if self.multimodal and self.multimodal_combo == "concat":
                img_encoded = self.blank_encoding
        else:
            img_encoded = super().forward_image(image_features)

        return img_encoded

    def get_rep(self, encodings, batchsize=None):
        """
        Get the multimodal representation of the encodings.

        :param encodings:
            list of encodings
        :param batchsize:
            size of batch

        :return:
            final multimodal representations
        """
        if not self.multimodal:
            rep = self.sum_encodings(encodings)
        else:
            if self.multimodal_combo == "sum":
                encodings = self.sum_encodings(encodings).unsqueeze(1)
            elif self.multimodal_combo == "concat":
                encodings = self.cat_encodings(encodings)
            all_one_mask = torch.ones(encodings.size()[:2])
            if self.use_cuda:
                all_one_mask = all_one_mask.cuda()
            rep = self.multimodal_encoder(encodings, all_one_mask)
        if rep is None:
            rep = torch.stack([self.blank_encoding for _ in range(batchsize)])
        return rep

    def choose_best_response(
        self,
        image_features,
        personalities,
        dialogue_histories,
        candidates,
        candidates_encoded=None,
        k=1,
        batchsize=None,
    ):
        """
        Choose the best response for each example.

        :param image_features:
            list of tensors of image features
        :param personalities:
            list of personalities
        :param dialogue_histories:
            list of dialogue histories, one per example
        :param candidates:
            list of candidates, one set per example
        :param candidates_encoded:
            optional; if specified, a fixed set of encoded candidates that is
            used for each example
        :param k:
            number of ranked candidates to return. if < 1, we return the ranks
            of all candidates in the set.

        :return:
            a set of ranked candidates for each example
        """
        self.eval()
        _, _, encoded = self.forward(image_features,
                                     personalities,
                                     dialogue_histories,
                                     None,
                                     batchsize=batchsize)
        encoded = encoded.detach()
        one_cand_set = True
        if candidates_encoded is None:
            one_cand_set = False
            candidates_encoded = [
                self.forward_text_encoder(c).detach() for c in candidates
            ]
        chosen = [
            self.choose_topk(
                idx if not one_cand_set else 0,
                encoded,
                candidates,
                candidates_encoded,
                one_cand_set,
                k,
            ) for idx in range(len(encoded))
        ]
        return chosen

    def choose_topk(self, idx, encoded, candidates, candidates_encoded,
                    one_cand_set, k):
        """
        Choose top k best responses for a single example.

        :param idx:
            idx of example in encoded
        :param encoded:
            full matrix of encoded representations (for the whole batch)
        :param candidates:
            list of candidates
        :param candidates_encoded:
            encoding of the candidates
        :param one_cand_set:
            true if there is one set of candidates for each example
        :param k:
            how many ranked responses to return

        :return:
            ranked list of k responses
        """
        encoding = encoded[idx:idx + 1, :]
        scores = torch.mm(
            candidates_encoded[idx]
            if not one_cand_set else candidates_encoded,
            encoding.transpose(0, 1),
        )
        if k >= 1:
            _, index_top = torch.topk(scores, k, dim=0)
        else:
            _, index_top = torch.topk(scores, scores.size(0), dim=0)
        return [
            candidates[idx][idx2] if not one_cand_set else candidates[idx2]
            for idx2 in index_top.unsqueeze(1)
        ]

    def get_loss(self, total_encoded, labels_encoded):
        """
        Compute loss over batch.

        :param total_encoded:
            encoding of the examples
        :param labels_encoded:
            encoding of the labels

        :return:
            total batch loss, and number of correct examples
        """
        loss = None
        num_correct = None
        if labels_encoded is not None:
            dot_products = total_encoded.mm(
                labels_encoded.t())  # batch_size * batch_size
            log_prob = torch.nn.functional.log_softmax(dot_products, dim=1)
            targets = torch.arange(0, len(total_encoded), dtype=torch.long)
            if self.use_cuda:
                targets = targets.cuda()
            loss = torch.nn.functional.nll_loss(log_prob, targets)
            num_correct = (log_prob.max(dim=1)[1] == targets).float().sum()
        return loss, num_correct

    def cat_encodings(self, tensors):
        """
        Concatenate non-`None` encodings.

        :param tensors:
            list tensors to concatenate

        :return:
            concatenated tensors
        """
        tensors = [t for t in tensors if t is not None]
        return torch.cat([t.unsqueeze(1) for t in tensors], dim=1)

    def _load_text_encoder_state(self):
        try:
            state_file = self.opt.get("load_encoder_from")
            with PathManager.open(state_file, 'rb') as f:
                model = torch.load(f)
            states = model["model"]
            self.text_encoder.load_state_dict(states)
        except Exception as e:
            print("WARNING: Cannot load transformer state; please make sure "
                  "specified file is a dictionary with the states in `model`. "
                  "Additionally, make sure that the appropriate options are "
                  "specified. Error: {}".format(e))

    def _load_context_encoder_state(self):
        try:
            state_file = self.opt.get("load_context_encoder_from")
            with PathManager.open(state_file, 'rb') as f:
                model = torch.load(f)
            states = model["model"]
            self.context_encoder.load_state_dict(states)
        except Exception as e:
            print("WARNING: Cannot load transformer state; please make sure "
                  "specified file is a dictionary with the states in `model`. "
                  "Additionally, make sure that the appropriate options are "
                  "specified. Error: {}".format(e))
Example #14
class TransresnetModel(nn.Module):
    """Actual model code for the Transresnet Agent."""
    @staticmethod
    def add_cmdline_args(argparser):
        """Add command line arguments."""
        Transformer.add_common_cmdline_args(argparser)
        agent = argparser.add_argument_group('TransresnetModel arguments')
        agent.add_argument(
            '--truncate',
            type=int,
            default=32,
            help='Max amount of tokens allowed in a text sequence',
        )
        agent.add_argument(
            '--image-features-dim',
            type=int,
            default=2048,
            help='dimensionality of image features',
        )
        agent.add_argument(
            '--embedding-type',
            type=str,
            default=None,
            choices=[None, 'fasttext_cc'],
            help='Specify if using pretrained embeddings',
        )
        agent.add_argument(
            '--load-encoder-from',
            type=str,
            default=None,
            help='Specify if using a pretrained transformer encoder',
        )
        agent.add_argument(
            '--hidden-dim',
            type=int,
            default=300,
            help='Hidden dimensionality of personality and image encoder',
        )
        agent.add_argument(
            '--num-layers-all',
            type=int,
            default=-1,
            help='If >= 1, number of layers for both the text '
            'and image encoders.',
        )
        agent.add_argument(
            '--num-layers-text-encoder',
            type=int,
            default=1,
            help='Number of layers for the text encoder',
        )
        agent.add_argument(
            '--num-layers-image-encoder',
            type=int,
            default=1,
            help='Number of layers for the image encoder',
        )
        agent.add_argument(
            '--no-cuda',
            dest='no_cuda',
            action='store_true',
            help='If True, perform ops on CPU only',
        )
        agent.add_argument(
            '--learningrate',
            type=float,
            default=0.0005,
            help='learning rate for optimizer',
        )
        agent.add_argument(
            '--additional-layer-dropout',
            type=float,
            default=0.2,
            help='dropout for additional linear layer',
        )
        argparser.set_params(ffn_size=1200,
                             attention_dropout=0.2,
                             relu_dropout=0.2,
                             n_positions=1000)

    def __init__(self, opt, personalities_list, dictionary):
        super().__init__()
        self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available()
        self.opt = opt
        self.dictionary = dictionary
        self.truncate_length = opt['truncate']
        if opt['num_layers_all'] != -1:
            n_layers_text = n_layers_img = opt['num_layers_all']
        else:
            n_layers_text = opt['num_layers_text_encoder']
            n_layers_img = opt['num_layers_image_encoder']
        self.text_encoder_frozen = False

        # Initialize personalities dictionary
        self._build_personality_dictionary(personalities_list)

        # Text encoder
        self._build_text_encoder(n_layers_text)

        # Image encoder
        self._build_image_encoder(n_layers_img)

        # Personality Encoder
        self._build_personality_encoder()

        # optimizer
        self.optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters()),
            self.opt['learningrate'],
        )

    def _build_personality_dictionary(self, personalities_list):
        """
        Build the personality dictionary mapping personality to id.

        :param personalities_list:
            list of personalities
        """
        self.personalities_list = personalities_list
        self.personality_to_id = {
            p: i
            for i, p in enumerate(personalities_list)
        }
        self.num_personalities = len(self.personalities_list) + 1

    def _build_text_encoder(self, n_layers_text):
        """
        Build the text (candidate) encoder.

        :param n_layers_text:
            how many layers the transformer will have
        """
        self.embeddings = nn.Embedding(len(self.dictionary),
                                       self.opt['embedding_size'])
        if (self.opt.get('load_encoder_from') is None
                and self.opt['embedding_type'] == 'fasttext_cc'):
            self.embeddings = load_fasttext_embeddings(
                self.dictionary, self.opt['embedding_size'],
                self.opt['datapath'])

        self.text_encoder = TransformerEncoder(
            n_heads=self.opt['n_heads'],
            n_layers=self.opt['n_layers'],
            embedding_size=self.opt['embedding_size'],
            ffn_size=self.opt['ffn_size'],
            vocabulary_size=len(self.dictionary),
            embedding=self.embeddings,
            dropout=self.opt['dropout'],
            attention_dropout=self.opt['attention_dropout'],
            relu_dropout=self.opt['relu_dropout'],
            padding_idx=self.dictionary.tok2ind[self.dictionary.null_token],
            learn_positional_embeddings=self.opt['learn_positional_embeddings'],
            embeddings_scale=False,
            n_positions=self.opt['n_positions'],
            activation=self.opt['activation'],
            variant=self.opt['variant'],
            n_segments=self.opt['n_segments'],
        )
        if self.opt.get('load_encoder_from') is not None:
            self._load_text_encoder_state()

        self.additional_layer = LinearWrapper(
            self.opt['embedding_size'],
            self.opt['hidden_dim'],
            dropout=self.opt['additional_layer_dropout'],
        )

    def _build_image_encoder(self, n_layers_img):
        """
        Build the image encoder mapping raw image features to the appropriate space.

        :param n_layers_img:
            number of feed-forward layers for the image encoder
        """
        image_layers = [
            nn.BatchNorm1d(self.opt['image_features_dim']),
            nn.Dropout(p=self.opt['dropout']),
            nn.Linear(self.opt['image_features_dim'], self.opt['hidden_dim']),
        ]
        for _ in range(n_layers_img - 1):
            image_layers += [
                nn.ReLU(),
                nn.Dropout(p=self.opt['dropout']),
                nn.Linear(self.opt['hidden_dim'], self.opt['hidden_dim']),
            ]
        self.image_encoder = nn.Sequential(*image_layers)

    def _build_personality_encoder(self):
        personality_layers = [
            nn.BatchNorm1d(self.num_personalities),
            nn.Dropout(p=self.opt['dropout']),
            nn.Linear(self.num_personalities, self.opt['hidden_dim']),
        ]
        self.personality_encoder = nn.Sequential(*personality_layers)

    def forward(self,
                image_features,
                personalities,
                captions,
                personalities_tensor=None):
        """
        Model forward pass.

        :param image_features:
            list of tensors of image features, one per example
        :param personalities:
            list of personalities, one per example
        :param captions:
            list of captions, one per example
        :param personalities_tensor:
            (optional) list of personality representations, usually a one-hot
            vector if specified

        :return:
            the encoded context and the encoded captions.
        """
        captions_encoded = None
        context_encoded = None
        img_encoded = None

        # encode captions
        if captions is not None:
            indexes, mask = self.captions_to_tensor(captions)
            captions_encoded = self.text_encoder(indexes)
            if self.text_encoder_frozen:
                captions_encoded = captions_encoded.detach()
            captions_encoded = self.additional_layer(captions_encoded)

        # encode personalities
        pers_encoded = self.forward_personality(personalities,
                                                personalities_tensor)

        # encode images
        img_encoded = self.forward_image(image_features)

        context_encoded = self.sum_encodings([pers_encoded, img_encoded])
        return context_encoded, captions_encoded

    def forward_personality(self, personalities, personalities_tensor):
        """
        Encode personalities.

        :param personalities:
            list of personalities, one per example
        :param personalities_tensor:
            (optional) list of personality representations, usually a one-hot
            vector if specified

        :return:
            encoded representation of the personalities
        """
        pers_encoded = None
        if personalities is not None:
            if personalities_tensor is not None:
                pers_feature = personalities_tensor
            else:
                res = torch.FloatTensor(len(personalities),
                                        self.num_personalities).fill_(0)
                p_to_i = self.personalities_to_index(personalities)
                for i, index in enumerate(p_to_i):
                    res[i, index] = 1  # no personality corresponds to 0
                if self.use_cuda:
                    res = res.cuda()
                pers_feature = res
            pers_encoded = self.personality_encoder(pers_feature)

        return pers_encoded

    def forward_image(self, image_features):
        """
        Encode image features.

        :param image_features:
            list of image features

        :return:
            encoded representation of the image features
        """
        img_encoded = None
        if image_features is not None:
            stacked = torch.stack(image_features)
            if self.use_cuda:
                stacked = stacked.cuda()
            img_encoded = self.image_encoder(stacked)

        return img_encoded

    def train_batch(self, image_features, personalities, captions):
        """
        Batch train on a set of examples.

        Uses captions from other examples as negatives during training.

        :param image_features:
            list of tensors of image features
        :param personalities:
            list of personalities
        :param captions:
            list of captions

        :return:
            the total loss, the number of correct examples, and the number of
            examples trained on
        """
        self.zero_grad()
        self.train()
        context_encoded, captions_encoded = self.forward(
            image_features, personalities, captions)
        loss, num_correct = self.evaluate_one_batch(context_encoded,
                                                    captions_encoded,
                                                    during_train=True)
        loss.backward()
        self.optimizer.step()

        # re-run forward pass to compute hits@1 metrics
        loss, num_correct, num_examples = self.eval_batch_of_100(
            context_encoded, captions_encoded)
        return loss, num_correct, num_examples

    def eval_batch(self, image_features, personalities, captions):
        """
        Evaluate performance of model on one batch.

        Batch is split into chunks of 100 to evaluate hits@1/100

        :param image_features:
            list of tensors of image features
        :param personalities:
            list of personalities
        :param captions:
            list of captions

        :return:
            the total loss, the number of correct examples, and the number of
            examples trained on
        """
        if personalities is None:
            personalities = [''] * len(image_features)
        if len(image_features) == 0:
            return 0, 0, 1
        self.eval()
        context_encoded, captions_encoded = self.forward(
            image_features, personalities, captions)
        loss, num_correct, num_examples = self.eval_batch_of_100(
            context_encoded, captions_encoded)
        return loss, num_correct, num_examples

    def choose_best_caption(self,
                            image_features,
                            personalities,
                            candidates,
                            candidates_encoded=None,
                            k=1):
        """
        Choose the best caption for each example.

        :param image_features:
            list of tensors of image features
        :param personalities:
            list of personalities
        :param candidates:
            list of candidates, one set per example
        :param candidates_encoded:
            optional; if specified, a fixed set of encoded candidates that is
            used for each example
        :param k:
            number of ranked candidates to return. if < 1, we return the ranks
            of all candidates in the set.

        :return:
            a set of ranked candidates for each example
        """
        self.eval()
        context_encoded, _ = self.forward(image_features, personalities, None)
        context_encoded = context_encoded.detach()
        one_cand_set = True
        if candidates_encoded is None:
            one_cand_set = False
            candidates_encoded = [
                self.forward(None, None, c)[1].detach() for c in candidates
            ]
        chosen = []
        for img_index in range(len(context_encoded)):
            context_encoding = context_encoded[img_index:img_index + 1, :]
            scores = torch.mm(
                candidates_encoded[img_index]
                if not one_cand_set else candidates_encoded,
                context_encoding.transpose(0, 1),
            )
            if k >= 1:
                _, index_top = torch.topk(scores, k, dim=0)
            else:
                _, index_top = torch.topk(scores, scores.size(0), dim=0)
            chosen.append([
                candidates[img_index][idx]
                if not one_cand_set else candidates[idx]
                for idx in index_top.unsqueeze(1)
            ])

        return chosen

    def eval_batch_of_100(self, context_encoded, captions_encoded):
        """
        Evaluate a batch of 100 examples.

        The captions of the other examples are used as negatives.

        :param context_encoded:
            the encoded context
        :param captions_encoded:
            the encoded captions

        :return:
            the total loss, the total number of correct examples, and the
            total number of examples evaluated.
        """
        total_loss = 0
        total_ok = 0
        num_examples = 0
        for i in range(0, len(context_encoded), 100):
            if i + 100 > len(context_encoded):
                break
            num_examples += 100
            loss, num_correct = self.evaluate_one_batch(
                context_encoded[i:i + 100, :], captions_encoded[i:i + 100, :])
            total_loss += loss.data.cpu().item()
            total_ok += num_correct.data.cpu().item()
        return total_loss, total_ok, num_examples

    def evaluate_one_batch(self,
                           context_encoded,
                           captions_encoded,
                           during_train=False):
        """
        Compute loss - and number of correct examples - for one batch.

        :param context_encoded:
            the encoded context
        :param captions_encoded:
            the encoded captions
        :param during_train:
            True if training, else False

        :return:
            the batch loss and the number of correct examples
        """
        if not during_train:
            self.zero_grad()
            self.eval()
        dot_products = context_encoded.mm(captions_encoded.t())
        log_prob = torch.nn.functional.log_softmax(dot_products, dim=1)
        targets = torch.arange(0, len(context_encoded), dtype=torch.long)
        if self.use_cuda:
            targets = targets.cuda()
        loss = torch.nn.functional.nll_loss(log_prob, targets)
        num_correct = (log_prob.max(dim=1)[1] == targets).float().sum()
        return loss, num_correct

    def freeze_text_encoder(self):
        """Freeze the text (candidate) encoder."""
        self.text_encoder_frozen = True

    def unfreeze_text_encoder(self):
        """Unfreeze the text (candidate) encoder."""
        self.text_encoder_frozen = False

    def sum_encodings(self, addends):
        """
        Add up a list of encodings, some of which may be `None`.

        :param addends:
            tensors to add

        :return:
            sum of non-`None` addends
        """
        addends = [a for a in addends if a is not None]
        return sum(addends) if len(addends) > 0 else None

    def personalities_to_index(self, personalities):
        """
        Map personalities to their index in the personality dictionary.

        :param personalities:
            list of personalities

        :return:
            list of personality ids
        """
        res = []
        for p in personalities:
            if p in self.personality_to_id:
                res.append(self.personality_to_id[p] + 1)
            else:
                res.append(0)
        return res

    def captions_to_tensor(self, captions):
        """
        Tokenize a list of sentences into a 2D index tensor and a mask.

        :param captions:
            list of sentences to tokenize

        :return:
            a (batchsize x max-caption-length) LongTensor of token indices,
            with captions truncated to `truncate_length`, and a same-sized
            float mask tensor
        """
        max_length = self.truncate_length
        indexes = []
        for c in captions:
            vec = self.dictionary.txt2vec(c)
            if len(vec) > max_length:
                vec = vec[:max_length]
            indexes.append(vec)
        longest = max([len(v) for v in indexes])
        res = torch.LongTensor(len(captions), longest).fill_(
            self.dictionary.tok2ind[self.dictionary.null_token])
        mask = torch.FloatTensor(len(captions), longest).fill_(0)
        for i, inds in enumerate(indexes):
            res[i, 0:len(inds)] = torch.LongTensor(inds)
            mask[i, 0:len(inds)] = torch.FloatTensor([1] * len(inds))
        if self.use_cuda:
            res = res.cuda()
            mask = mask.cuda()
        return res, mask

    def _load_text_encoder_state(self):
        try:
            state_file = self.opt.get('load_encoder_from')
            model = torch.load(state_file)
            states = model['model']
            self.text_encoder.load_state_dict(states)
        except Exception as e:
            print('WARNING: Cannot load transformer state; please make sure '
                  'specified file is a dictionary with the states in `model`. '
                  'Additionally, make sure that the appropriate options are '
                  'specified. Error: {}'.format(e))
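A short usage sketch for the freeze/unfreeze API above; `opt`, `personalities`, `dictionary`, `image_feats`, and `captions` are placeholders for a real setup.

    # Sketch only: freezing detaches caption encodings in forward(), so only the
    # additional linear layer (plus image/personality encoders) keeps training.
    model = TransresnetModel(opt, personalities, dictionary)
    model.freeze_text_encoder()
    loss, num_ok, num_examples = model.train_batch(image_feats, personalities, captions)
    model.unfreeze_text_encoder()  # resume fine-tuning the text encoder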