Example #1
File: model.py Project: zqyuan/r2c
    def __init__(self,
                 vocab: Vocabulary,
                 span_encoder: Seq2SeqEncoder,
                 reasoning_encoder: Seq2SeqEncoder,
                 input_dropout: float = 0.3,
                 hidden_dim_maxpool: int = 1024,
                 class_embs: bool = True,
                 reasoning_use_obj: bool = True,
                 reasoning_use_answer: bool = True,
                 reasoning_use_question: bool = True,
                 pool_reasoning: bool = True,
                 pool_answer: bool = True,
                 pool_question: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 ):
        super(AttentionQA, self).__init__(vocab)

        self.detector = SimpleDetector(pretrained=True, average_pool=True, semantic=class_embs, final_dim=512)
        ###################################################################################################

        self.rnn_input_dropout = TimeDistributed(InputVariationalDropout(input_dropout)) if input_dropout > 0 else None

        self.span_encoder = TimeDistributed(span_encoder)
        self.reasoning_encoder = TimeDistributed(reasoning_encoder)

        self.span_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=span_encoder.get_output_dim(),
        )

        self.obj_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=self.detector.final_dim,
        )

        self.reasoning_use_obj = reasoning_use_obj
        self.reasoning_use_answer = reasoning_use_answer
        self.reasoning_use_question = reasoning_use_question
        self.pool_reasoning = pool_reasoning
        self.pool_answer = pool_answer
        self.pool_question = pool_question
        dim = sum([d for d, to_pool in [(reasoning_encoder.get_output_dim(), self.pool_reasoning),
                                        (span_encoder.get_output_dim(), self.pool_answer),
                                        (span_encoder.get_output_dim(), self.pool_question)] if to_pool])

        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(dim, hidden_dim_maxpool),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(hidden_dim_maxpool, 1),
        )
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)
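
Note: the `dim` computation above gates each feature stream by its pooling flag; only pooled streams contribute to the input width of `final_mlp`. A minimal standalone sketch of that pattern, using 512 as a placeholder output dim for both encoders (an assumption, not a value from the project config):

# Sketch of the gated sum used for `dim` above; the 512s are placeholder dims.
reasoning_dim, span_dim = 512, 512
streams = [(reasoning_dim, True),   # pool_reasoning (default True)
           (span_dim, True),        # pool_answer    (default True)
           (span_dim, False)]       # pool_question  (default False)
dim = sum(d for d, to_pool in streams if to_pool)
print(dim)  # 1024 -> in_features of the first Linear in final_mlp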
Example #2
    def __init__(self,
                 vocab: Vocabulary,
                 option_encoder: Seq2SeqEncoder,
                 input_dropout: float = 0.3,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 ):
        super(LSTMBatchNormFreezeDetGlobalNoFinalImageFull, self).__init__(vocab)
        self.rnn_input_dropout = TimeDistributed(InputVariationalDropout(input_dropout)) if input_dropout > 0 else None
        self.detector = SimpleDetector(pretrained=True, average_pool=True, semantic=False, final_dim=512)

        # freeze everything related to conv net
        for submodule in self.detector.backbone.modules():
            # if isinstance(submodule, BatchNorm2d):
                # submodule.track_running_stats = False
            for p in submodule.parameters():
                p.requires_grad = False

        for submodule in self.detector.after_roi_align.modules():
            # if isinstance(submodule, BatchNorm2d):
                # submodule.track_running_stats = False
            for p in submodule.parameters():
                p.requires_grad = False

        self.image_BN = BatchNorm1d(512)

        self.option_encoder = TimeDistributed(option_encoder)
        self.option_BN = torch.nn.Sequential(
            BatchNorm1d(512)
        )
        self.query_BN = torch.nn.Sequential(
            BatchNorm1d(512)
        )
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Linear(1024, 512),
            torch.nn.ReLU(inplace=True),
        )
        self.final_BN = torch.nn.Sequential(
            BatchNorm1d(512)
        )
        self.final_mlp_linear = torch.nn.Sequential(
            torch.nn.Linear(512, 1)
        )
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)
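
The two freezing loops above visit every submodule, but since `Module.parameters()` already recurses into children, the same effect can be had with a single loop. A compact equivalent, as a sketch over any `torch.nn.Module`:

import torch

def freeze(module: torch.nn.Module) -> None:
    # Same effect as the per-submodule loops above: .parameters() is
    # recursive, so every nested parameter gets requires_grad = False.
    for p in module.parameters():
        p.requires_grad = False

# Hypothetical usage mirroring the constructor:
#   freeze(self.detector.backbone)
#   freeze(self.detector.after_roi_align)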
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 class_embs: bool = True,
                 bert_model_name: str = "bert-base-uncased",
                 cnn_loss_ratio: float = 0.0,
                 special_visual_initialize: bool = False,
                 text_only: bool = False,
                 visual_embedding_dim: int = 512,
                 hard_cap_seq_len: int = None,
                 cut_first: str = 'text',
                 embedding_strategy: str = 'plain',
                 random_initialize: bool = False,
                 training_head_type: str = "pretraining",
                 bypass_transformer: bool = False,
                 pretrained_detector: bool = True,
                 output_attention_weights: bool = False):
        super(VisualBERTDetector, self).__init__(vocab)

        from utils.detector import SimpleDetector
        self.detector = SimpleDetector(pretrained=pretrained_detector,
                                       average_pool=True,
                                       semantic=class_embs,
                                       final_dim=512)
        ##################################################################################################
        self.bert = TrainVisualBERTObjective.from_pretrained(
            bert_model_name,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(-1)),
            training_head_type=training_head_type,
            visual_embedding_dim=visual_embedding_dim,
            hard_cap_seq_len=hard_cap_seq_len,
            cut_first=cut_first,
            embedding_strategy=embedding_strategy,
            bypass_transformer=bypass_transformer,
            random_initialize=random_initialize,
            output_attention_weights=output_attention_weights)
        if special_visual_initialize:
            self.bert.bert.embeddings.special_intialize()

        self.training_head_type = training_head_type
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()
        self.cnn_loss_ratio = cnn_loss_ratio
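
`cnn_loss_ratio` is stored but not used in this constructor; presumably the forward pass mixes a detector-side loss into the BERT loss. A hedged sketch of that weighting (the names `bert_loss` and `cnn_loss` are assumptions, not from this excerpt):

import torch

cnn_loss_ratio = 0.0                # default: detector loss contributes nothing
bert_loss = torch.tensor(1.25)      # placeholder BERT-head loss
cnn_loss = torch.tensor(0.40)       # placeholder detector loss
total_loss = bert_loss + cnn_loss_ratio * cnn_loss
print(total_loss)                   # tensor(1.2500), identical to bert_loss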
Example #4
    def __init__(
        self,
        span_encoder: Seq2SeqEncoder,
        input_dropout: float = 0.3,
        class_embs: bool = True,
        initializer: InitializerApplicator = InitializerApplicator(),
        learned_omcs: dict = {},
    ):
        # VCR dataset becomes unpicklable due to VCR.vocab, but we don't need
        # to pass in vocab from the dataset anyway as the BERT embeddings are
        # pretrained and stored in h5 files per dataset instance. Just pass
        # a dummy vocab instance for init.
        vocab = Vocabulary()
        super(KeyValueAttentionTrunk, self).__init__(vocab)

        self.detector = SimpleDetector(pretrained=True,
                                       average_pool=True,
                                       semantic=class_embs,
                                       final_dim=512)

        self.rnn_input_dropout = TimeDistributed(
            InputVariationalDropout(
                input_dropout)) if input_dropout > 0 else None

        self.span_encoder = TimeDistributed(span_encoder)
        span_dim = span_encoder.get_output_dim()

        self.question_mlp = torch.nn.Sequential(
            # 2 (bidirectional) * 4 (num_answers) * dim -> dim
            torch.nn.Linear(8 * span_dim, span_dim),
            torch.nn.Tanh(),
        )
        self.answer_mlp = torch.nn.Sequential(
            # 2 (bidirectional) * dim -> 2 (key-value) * dim
            torch.nn.Linear(2 * span_dim, 2 * span_dim),
            torch.nn.Tanh(),
        )
        self.obj_mlp = torch.nn.Sequential(
            # obj_dim -> 2 (key-value) * dim
            torch.nn.Linear(self.detector.final_dim, 2 * span_dim),
            torch.nn.Tanh(),
        )

        self.span_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=span_encoder.get_output_dim(),
        )

        self.obj_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=self.detector.final_dim,
        )

        self.kv_transformer = KeyValueTransformer(
            dim=span_dim,
            num_heads=8,
            num_steps=4,
        )

        self.omcs_index = None
        if learned_omcs.get('enabled', False):
            use_sentence_embs = learned_omcs.get('use_sentence_embeddings',
                                                 True)
            omcs_embs, self.omcs_index = self.load_omcs(use_sentence_embs)
            # Let's replicate the OMCS embeddings to each device to attend over them
            # after FAISS lookup. We could also use faiss.search_and_reconstruct,
            # but that would prevent us from using quantized indices for faster
            # search, which we might need.
            self.register_buffer('omcs_embs', omcs_embs)
            self.omcs_mlp = torch.nn.Sequential(
                torch.nn.Linear(768, self.omcs_index.d),
            )
            self.k = learned_omcs.get('max_neighbors', 5)
            self.similarity_thresh = learned_omcs.get('similarity_thresh', 0.0)

        initializer(self)
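
In the OMCS branch, `omcs_mlp` projects 768-d BERT features down to the index dimensionality `omcs_index.d`, FAISS returns up to `k` neighbors, and matches below `similarity_thresh` are discarded; the embeddings themselves are then looked up from the replicated `omcs_embs` buffer. A minimal sketch of that retrieval, assuming an inner-product `IndexFlatIP` and random placeholder data (the real index and embeddings come from `load_omcs`):

import faiss
import numpy as np

d, n, k, similarity_thresh = 128, 1000, 5, 0.0       # placeholder sizes
omcs_embs = np.random.randn(n, d).astype('float32')  # stand-in OMCS embeddings
index = faiss.IndexFlatIP(d)                         # inner-product index (assumption)
index.add(omcs_embs)

queries = np.random.randn(4, d).astype('float32')    # e.g. projected span features
scores, ids = index.search(queries, k)               # (4, k) similarities and row ids
keep = scores >= similarity_thresh                   # mask out weak matches
neighbors = omcs_embs[ids]                           # (4, k, d) lookup by id, avoiding
                                                     # faiss.search_and_reconstruct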
Example #5
    def __init__(
            self,
            vocab: Vocabulary,
            span_encoder: Seq2SeqEncoder,
            reasoning_encoder: Seq2SeqEncoder,
            input_dropout: float = 0.3,
            hidden_dim_maxpool: int = 1024,
            class_embs: bool = True,
            reasoning_use_obj: bool = True,
            reasoning_use_answer: bool = True,
            reasoning_use_question: bool = True,
            pool_reasoning: bool = True,
            pool_answer: bool = True,
            pool_question: bool = False,
            reasoning_use_vision: bool = True,
            initializer: InitializerApplicator = InitializerApplicator(),
    ):

        super(AttentionQA, self).__init__(vocab)

        self.detector = SimpleDetector(pretrained=True,
                                       average_pool=True,
                                       semantic=class_embs,
                                       final_dim=512)
        self.extractor = SimpleExtractor(pretrained=True,
                                         num_classes=365,
                                         arch='resnet50')
        ###################################################################################################

        self.rnn_input_dropout = TimeDistributed(
            InputVariationalDropout(
                input_dropout)) if input_dropout > 0 else None

        self.span_encoder = TimeDistributed(span_encoder)
        self.reasoning_encoder = TimeDistributed(reasoning_encoder)

        # add scene classification visual feature and word embedding feature

        self.span_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=span_encoder.get_output_dim(),
        )

        self.obj_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=self.detector.final_dim,
        )

        self.reasoning_use_obj = reasoning_use_obj
        self.reasoning_use_answer = reasoning_use_answer
        self.reasoning_use_question = reasoning_use_question
        self.pool_reasoning = pool_reasoning
        self.pool_answer = pool_answer
        self.pool_question = pool_question
        self.reasoning_use_vision = reasoning_use_vision
        dim = sum([d for d, to_pool in [(reasoning_encoder.get_output_dim(), self.pool_reasoning),
                                        (span_encoder.get_output_dim(), self.pool_answer),
                                        (span_encoder.get_output_dim(), self.pool_question)] if to_pool])

        self.projection = torch.nn.Conv2d(2048,
                                          self.detector.final_dim,
                                          kernel_size=1,
                                          stride=2,
                                          padding=1,
                                          bias=True)

        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(dim, hidden_dim_maxpool),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(hidden_dim_maxpool, 1),
        )
        self._accuracy = CategoricalAccuracy()

        # I want to replace the CrossEntropyLoss with LSR

        # self._loss = LabelSmoothingLoss(size=4,smoothing=0.1)
        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)
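
On the commented-out `LabelSmoothingLoss`: in PyTorch 1.10 and later, label-smoothed cross-entropy needs no custom class. A drop-in sketch, with the 0.1 smoothing taken from the commented line:

import torch

loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)  # requires PyTorch >= 1.10

logits = torch.randn(8, 4)            # 8 examples, 4 answer choices
targets = torch.randint(0, 4, (8,))   # gold answer indices
print(loss_fn(logits, targets))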
Example #6
    def __init__(
        self,
        span_encoder: Seq2SeqEncoder,
        reasoning_encoder: Seq2SeqEncoder,
        input_dropout: float = 0.3,
        hidden_dim_maxpool: int = 1024,
        class_embs: bool = True,
        reasoning_use_obj: bool = True,
        reasoning_use_answer: bool = True,
        reasoning_use_question: bool = True,
        pool_reasoning: bool = True,
        pool_answer: bool = True,
        pool_question: bool = False,
        initializer: InitializerApplicator = InitializerApplicator(),
        learned_omcs: dict = {},
    ):
        # VCR dataset becomes unpicklable due to VCR.vocab, but we don't need
        # to pass in vocab from the dataset anyway as the BERT embeddings are
        # pretrained and stored in h5 files per dataset instance. Just pass
        # a dummy vocab instance for init.
        vocab = Vocabulary()
        super(AttentionQATrunk, self).__init__(vocab)

        self.detector = SimpleDetector(pretrained=True,
                                       average_pool=True,
                                       semantic=class_embs,
                                       final_dim=512)
        ###################################################################################################

        self.rnn_input_dropout = TimeDistributed(
            InputVariationalDropout(
                input_dropout)) if input_dropout > 0 else None

        self.span_encoder = TimeDistributed(span_encoder)
        self.reasoning_encoder = TimeDistributed(reasoning_encoder)

        self.span_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=span_encoder.get_output_dim(),
        )

        self.obj_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=self.detector.final_dim,
        )

        self.reasoning_use_obj = reasoning_use_obj
        self.reasoning_use_answer = reasoning_use_answer
        self.reasoning_use_question = reasoning_use_question
        self.pool_reasoning = pool_reasoning
        self.pool_answer = pool_answer
        self.pool_question = pool_question
        self.output_dim = sum([d for d, to_pool in [(reasoning_encoder.get_output_dim(), self.pool_reasoning),
                                                    (span_encoder.get_output_dim(), self.pool_answer),
                                                    (span_encoder.get_output_dim(), self.pool_question)] if to_pool])

        self.omcs_index = None
        if learned_omcs.get('enabled', False):
            use_sentence_embs = learned_omcs.get('use_sentence_embeddings',
                                                 True)
            omcs_embs, self.omcs_index = self.load_omcs(use_sentence_embs)
            # Let's replicate the OMCS embeddings to each device to attend over them
            # after FAISS lookup. We could also use faiss.search_and_reconstruct,
            # but that would prevent us from using quantized indices for faster
            # search, which we might need.
            self.register_buffer('omcs_embs', omcs_embs)
            self.omcs_mlp = torch.nn.Sequential(
                torch.nn.Linear(768, self.omcs_index.d),
            )
            self.k = learned_omcs.get('max_neighbors', 5)
            self.similarity_thresh = learned_omcs.get('similarity_thresh', 0.0)

        initializer(self)
Example #7
    def __init__(
            self,
            vocab: Vocabulary,
            span_encoder: Seq2SeqEncoder,
            reasoning_encoder: Seq2SeqEncoder,
            input_dropout: float = 0.1,
            hidden_dim_maxpool: int = 512,
            class_embs: bool = True,
            reasoning_use_obj: bool = True,
            reasoning_use_answer: bool = True,
            reasoning_use_question: bool = True,
            pool_reasoning: bool = True,
            pool_answer: bool = True,
            pool_question: bool = False,
            preload_path: str = "source_model.th",
            initializer: InitializerApplicator = InitializerApplicator(),
    ):
        super(AttentionQA, self).__init__(vocab)

        self.detector = SimpleDetector(pretrained=True,
                                       average_pool=True,
                                       semantic=class_embs,
                                       final_dim=512)
        ###################################################################################################

        self.rnn_input_dropout = TimeDistributed(
            InputVariationalDropout(
                input_dropout)) if input_dropout > 0 else None

        self.span_encoder = TimeDistributed(span_encoder)
        self.reasoning_encoder = TimeDistributed(reasoning_encoder)
        self.BiLSTM = TimeDistributed(MYLSTM(1280, 512, 256))
        self.source_encoder = TimeDistributed(source_LSTM(768, 256))

        self.span_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=span_encoder.get_output_dim(),
        )
        self.span_attention_2 = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=span_encoder.get_output_dim(),
        )

        self.obj_attention = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=self.detector.final_dim,
        )

        self.obj_attention_2 = BilinearMatrixAttention(
            matrix_1_dim=span_encoder.get_output_dim(),
            matrix_2_dim=self.detector.final_dim,
        )

        self._matrix_attention = DotProductMatrixAttention()
        #self._matrix_attention = MatrixAttention(similarity_function)

        self.reasoning_use_obj = reasoning_use_obj
        self.reasoning_use_answer = reasoning_use_answer
        self.reasoning_use_question = reasoning_use_question
        self.pool_reasoning = pool_reasoning
        self.pool_answer = pool_answer
        self.pool_question = pool_question
        dim = sum([d for d, to_pool in [(reasoning_encoder.get_output_dim(), self.pool_reasoning),
                                        (span_encoder.get_output_dim(), self.pool_answer),
                                        (span_encoder.get_output_dim(), self.pool_question)] if to_pool])

        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(dim, hidden_dim_maxpool),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(hidden_dim_maxpool, 1),
        )
        self.final_mlp_2 = torch.nn.Sequential(
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(dim, hidden_dim_maxpool),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(input_dropout, inplace=False),
            torch.nn.Linear(hidden_dim_maxpool, 1),
        )

        self.answer_BN = torch.nn.Sequential(BatchNorm1d(512))
        self.question_BN = torch.nn.Sequential(BatchNorm1d(512))
        self.source_answer_BN = torch.nn.Sequential(BatchNorm1d(512))
        self.source_question_BN = torch.nn.Sequential(BatchNorm1d(512))
        self.image_BN = BatchNorm1d(512)
        self.final_BN = torch.nn.Sequential(BatchNorm1d(512))
        self.final_mlp_linear = torch.nn.Sequential(torch.nn.Linear(512, 1))
        self.final_mlp_pool = torch.nn.Sequential(
            torch.nn.Linear(2560, 512),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(input_dropout, inplace=False),
        )

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)

        if preload_path is not None:
            logger.info("Preloading!")
            preload = torch.load(preload_path)
            own_state = self.state_dict()
            for name, param in preload.items():
                #if name[0:8] == "_encoder":
                #    suffix = "._module."+name[9:]
                #    logger.info("preload parameter {}".format("span_encoder"+suffix))
                #    own_state["span_encoder"+suffix].copy_(param)
                # newly introduced source_encoder
                if name[0:4] == "LSTM":
                    suffix = "._module" + name[4:]
                    logger.info("preload paramter {}".format("source_encoder" +
                                                             suffix))
                    own_state["source_encoder" + suffix].copy_(param)