Example #1
class AttentionNet(nn.Module):
    def __init__(self, config, word_vocabulary):
        super(AttentionNet, self).__init__()

        self.fusion_factory = FusionFactory()
        if config['attention_fusion_type'] == 'concat_mlp':
            self.attention_fusion = ConcatMLP(config['attention_fusion_mlp'])
        elif config['attention_fusion_type'] == 'block':
            self.attention_fusion = self.fusion_factory.create_fusion(config['attention_fusion_block'])
        else:
            raise ValueError('Unimplemented attention fusion')

        if config['final_fusion_type'] == 'concat_mlp':
            self.final_fusion = ConcatMLP(config['final_fusion_mlp'])
        elif config['final_fusion_type'] == 'block':
            self.final_fusion = self.fusion_factory.create_fusion(config['final_fusion_block'])
        else:
            raise ValueError('Unimplemented final fusion')
        self.buffer = None
        self.txt_enc = get_text_enc(config, word_vocabulary)
        self.q_linear0 = nn.Linear(
                config['q_att']['q_linear0']['input_dim'],
                config['q_att']['q_linear0']['output_dim'])
        self.q_linear1 = nn.Linear(
                config['q_att']['q_linear1']['input_dim'],
                config['q_att']['q_linear1']['output_dim'])

        self.obj_linear0 = nn.Linear(
                config['obj_att']['obj_linear0']['input_dim'],
                config['obj_att']['obj_linear0']['output_dim'])
        self.obj_linear1 = nn.Linear(
                config['obj_att']['obj_linear1']['input_dim'],
                config['obj_att']['obj_linear1']['output_dim'])
        self.log_softmax = nn.LogSoftmax(dim=1)
Example #2
class GraphCell(nn.Module):
    def __init__(self, config):
        super(GraphCell, self).__init__()
        fusion_factory = FusionFactory()
        graph_layer_factory = GraphLayerFactory()
        fusion_features_cfg = config['fusion']['obj_features_question']
        self.fusion_features = fusion_factory.create_fusion(
            fusion_features_cfg)
        graph_cfg = config['graph']
        kwargs = graph_cfg['kwargs']
        if graph_cfg['layer_specify_method'] == 'manual':
            self.manual = True
            self.gat1 = GATConv(2048, 256, heads=4)
            self.gat2 = GATConv(1024, 256, heads=4)
            self.gat3 = GATConv(1024, 512, heads=5, concat=False)
        else:
            self.manual = False
            graph_layer = graph_layer_factory.get_graph_layer(config['graph'])

            self.graph_hidden_list = nn.ModuleList([
                graph_layer(graph_cfg['input_dim'],
                            graph_cfg['graph_hidden_list'][0], **kwargs)
            ])
            if len(graph_cfg['graph_hidden_list']) > 1:
                for length1, length2 in zip(
                        graph_cfg['graph_hidden_list'][:-1],
                        graph_cfg['graph_hidden_list'][1:]):
                    self.graph_hidden_list.append(
                        graph_layer(length1, length2, **kwargs))
            kwargs['concat'] = False
            self.last_layer = graph_layer(
                graph_cfg['graph_hidden_list'][-1] * kwargs['heads']
                if 'heads' in kwargs else graph_cfg['graph_hidden_list'][-1],
                graph_cfg['output_dim'], **kwargs)
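
A minimal, self-contained sketch of how the manual GAT stack above chains its dimensions (the layer sizes are taken from the example; the toy graph, the omitted nonlinearities, and the direct GATConv forward calls are illustrative assumptions, not the original project's forward pass):

import torch
from torch_geometric.nn import GATConv

# toy graph: 5 nodes with 2048-dim features and an assumed edge list
x = torch.randn(5, 2048)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]], dtype=torch.long)

gat1 = GATConv(2048, 256, heads=4)                # output dim 256 * 4 = 1024 (heads concatenated)
gat2 = GATConv(1024, 256, heads=4)                # output dim 256 * 4 = 1024
gat3 = GATConv(1024, 512, heads=5, concat=False)  # output dim 512 (heads averaged)

h = gat1(x, edge_index)
h = gat2(h, edge_index)
h = gat3(h, edge_index)
print(h.shape)  # torch.Size([5, 512])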
Example #3
class MurelNet(nn.Module):
    def __init__(self, config, word_vocabulary):
        super(MurelNet, self).__init__()
        self.fusion_factory = FusionFactory()
        self.use_pairwise = config['use_pairwise']
        self.use_graph_module = config['use_graph_module']

        if config['use_pairwise']:
            self.murel_cell = MurelCell(config)
        if config['use_graph_module']:
            self.graph_module = GraphCell(config)

        self.buffer = None
        self.final_fusion = self.fusion_factory.create_fusion(
            config['fusion']['final_fusion'])
        self.unroll_steps = config['unroll_steps']
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.txt_enc = get_text_enc(config, word_vocabulary)
        self.linear0 = nn.Linear(config['q_att']['linear0']['input_dim'],
                                 config['q_att']['linear0']['output_dim'])
        self.linear1 = nn.Linear(config['q_att']['linear1']['input_dim'],
                                 config['q_att']['linear1']['output_dim'])
        self.pooling_agg = get_aggregation_func(config['pooling_agg'], dim=1)
Example #4
class MurelCell(nn.Module):
    def __init__(self, config):
        super(MurelCell, self).__init__()
        self.fusion_factory = FusionFactory()
        fusion_features_cfg = config['fusion']['obj_features_question']
        fusion_box_cfg = config['fusion']['box']
        fusion_fused_cfg = config['fusion']['obj_features_obj_features']

        if config['murel_attention']:
            self.murel_cell_attention = True
            self.murel_cell_attention_linear0 = nn.Linear(
                config['murel_cell_attention']['linear0']['input_dim'],
                config['murel_cell_attention']['linear0']['output_dim'])
            self.murel_cell_attention_linear1 = nn.Linear(
                config['murel_cell_attention']['linear1']['input_dim'],
                config['murel_cell_attention']['linear1']['output_dim'])
        else:
            self.murel_cell_attention = False

        self.buffer = None
        self.fusion_features = self.fusion_factory.create_fusion(
            fusion_features_cfg)
        self.fusion_box = self.fusion_factory.create_fusion(fusion_box_cfg)
        self.fusion_fused = self.fusion_factory.create_fusion(fusion_fused_cfg)
        self.pairwise_agg = get_aggregation_func(config['pairwise_agg'], dim=2)
Example #5
class AttentionNet(nn.Module):
    def __init__(self, config, word_vocabulary):
        super(AttentionNet, self).__init__()

        self.fusion_factory = FusionFactory()
        if config['attention_fusion_type'] == 'concat_mlp':
            self.attention_fusion = ConcatMLP(config['attention_fusion_mlp'])
        elif config['attention_fusion_type'] == 'block':
            self.attention_fusion = self.fusion_factory.create_fusion(config['attention_fusion_block'])
        else:
            raise ValueError('Unimplemented attention fusion')

        if config['final_fusion_type'] == 'concat_mlp':
            self.final_fusion = ConcatMLP(config['final_fusion_mlp'])
        elif config['final_fusion_type'] == 'block':
            self.final_fusion = self.fusion_factory.create_fusion(config['final_fusion_block'])
        else:
            raise ValueError('Unimplemented final fusion')
        self.buffer = None
        self.txt_enc = get_text_enc(config, word_vocabulary)
        self.q_linear0 = nn.Linear(
                config['q_att']['q_linear0']['input_dim'],
                config['q_att']['q_linear0']['output_dim'])
        self.q_linear1 = nn.Linear(
                config['q_att']['q_linear1']['input_dim'],
                config['q_att']['q_linear1']['output_dim'])

        self.obj_linear0 = nn.Linear(
                config['obj_att']['obj_linear0']['input_dim'],
                config['obj_att']['obj_linear0']['output_dim'])
        self.obj_linear1 = nn.Linear(
                config['obj_att']['obj_linear1']['input_dim'],
                config['obj_att']['obj_linear1']['output_dim'])
        self.log_softmax = nn.LogSoftmax(dim=1)

    def initialise_buffers(self):
        self.buffer = {}
        print('Buffer initialised. Model ready to visualise.')

    def forward(self, item):
        question_ids = item['question_ids']
        object_features_list = item['object_features_list']
        question_lengths = item['question_lengths']

        question_each_word_embedding = self.txt_enc.embedding(question_ids)
        question_features, question_final_feature = self.txt_enc.rnn(
                question_each_word_embedding)

        question_attentioned = self.self_attention_question(
                question_features, question_lengths)

        object_attentioned = self.compute_object_attention_with_question(
                question_attentioned, object_features_list)

        # Construct training vector
        x = self.final_fusion([question_attentioned, object_attentioned])
        x = self.log_softmax(x)
        return x
 
    def compute_object_attention_with_question(self, question_self_attentioned, object_features_list):
        batch_size = object_features_list.size(0)
        no_objects = object_features_list.size(1)
        q_expanded = question_self_attentioned.unsqueeze(1).expand(-1, no_objects, -1)
        fused = self.attention_fusion(
                [
                 q_expanded.contiguous().view(batch_size * no_objects, -1),
                 object_features_list.contiguous().view(batch_size * no_objects, -1)
                ]
        )
        fused = fused.view(batch_size, no_objects, -1)
        fused_att = self.obj_linear0(fused)
        fused_att = F.relu(fused_att)
        fused_att = self.obj_linear1(fused_att)
        fused_att = F.softmax(fused_att, dim=1)

        glimpses = torch.unbind(fused_att, dim=2)
        attentioned_glimpses = []
        for i, glimpse in enumerate(glimpses):
            glimpse = glimpse.unsqueeze(2).expand(-1, -1, object_features_list.size(-1))
            if self.buffer is not None:
                self.buffer['glimpse' + str(i)] = glimpse.data.cpu()
            attentioned_feature = object_features_list * glimpse
            attentioned_feature = torch.sum(attentioned_feature, dim=1)
            attentioned_glimpses.append(attentioned_feature)
        fused_attentioned = torch.cat(attentioned_glimpses, dim=1)
        return fused_attentioned

    def self_attention_question(self, question_features, question_lengths):
        q_att = self.q_linear0(question_features)
        q_att = torch.nn.functional.relu(q_att)
        q_att = self.q_linear1(q_att)

        # http://juditacs.github.io/2018/12/27/masked-attention.html
        # Compute attention weights such that the padded units give
        # 0 attention weights
        q_att = masked_softmax(q_att, question_lengths)
        # Glimpses contain attention values for each question_feature
        # DIM: BATCH_SIZE x NO_WORDS
        glimpses = torch.unbind(q_att, dim=2)
        attentioned_glimpses = []
        for glimpse in glimpses:
            glimpse = glimpse.unsqueeze(2).expand(-1, -1, question_features.size(-1))
            attentioned_feature = question_features * glimpse
            attentioned_feature = torch.sum(attentioned_feature, dim=1)
            attentioned_glimpses.append(attentioned_feature)
        question_attentioned = torch.cat(attentioned_glimpses, dim=1)
        return question_attentioned   
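
Both self_attention_question and compute_object_attention_with_question follow the same glimpse pattern: two linear layers produce per-position attention logits, a softmax normalises over positions, and each glimpse yields one weighted sum of the features, with the glimpses concatenated at the end. A minimal, self-contained sketch of that pattern with dummy tensors (the batch size, sequence length, feature and glimpse dimensions are illustrative assumptions, not the project's configuration):

import torch
import torch.nn.functional as F

batch_size, no_words, feat_dim, n_glimpses = 2, 6, 8, 2
features = torch.randn(batch_size, no_words, feat_dim)
att_logits = torch.randn(batch_size, no_words, n_glimpses)  # stands in for linear1(relu(linear0(features)))

att = F.softmax(att_logits, dim=1)    # attention weights over positions, one column per glimpse
glimpses = torch.unbind(att, dim=2)   # n_glimpses tensors of shape (batch, no_words)
attentioned_glimpses = []
for glimpse in glimpses:
    glimpse = glimpse.unsqueeze(2).expand(-1, -1, feat_dim)
    attentioned_glimpses.append(torch.sum(features * glimpse, dim=1))  # (batch, feat_dim)
out = torch.cat(attentioned_glimpses, dim=1)
print(out.shape)  # torch.Size([2, 16]) == (batch, feat_dim * n_glimpses)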
Example #6
class MurelCell(nn.Module):
    def __init__(self, config):
        super(MurelCell, self).__init__()
        self.fusion_factory = FusionFactory()
        fusion_features_cfg = config['fusion']['obj_features_question']
        fusion_box_cfg = config['fusion']['box']
        fusion_fused_cfg = config['fusion']['obj_features_obj_features']

        if config['murel_attention']:
            self.murel_cell_attention = True
            self.murel_cell_attention_linear0 = nn.Linear(
                config['murel_cell_attention']['linear0']['input_dim'],
                config['murel_cell_attention']['linear0']['output_dim'])
            self.murel_cell_attention_linear1 = nn.Linear(
                config['murel_cell_attention']['linear1']['input_dim'],
                config['murel_cell_attention']['linear1']['output_dim'])
        else:
            self.murel_cell_attention = False

        self.buffer = None
        self.fusion_features = self.fusion_factory.create_fusion(
            fusion_features_cfg)
        self.fusion_box = self.fusion_factory.create_fusion(fusion_box_cfg)
        self.fusion_fused = self.fusion_factory.create_fusion(fusion_fused_cfg)
        self.pairwise_agg = get_aggregation_func(config['pairwise_agg'], dim=2)

    def initialise_buffers(self):
        self.buffer = {}

    def compute_relation_attention(self, relations):
        _, _, no_objects, no_feats = relations.size()
        r_att = self.murel_cell_attention_linear0(relations)
        r_att = torch.tanh(r_att)
        r_att = self.murel_cell_attention_linear1(r_att)
        r_att = torch.softmax(r_att, dim=2)
        # keep the first attention map along dim 3: (batch, num_obj, num_obj)
        r_att = torch.unbind(r_att, dim=3)[0]
        r_att = r_att.unsqueeze(3).expand(-1, no_objects, no_objects, no_feats)
        r_att = relations * r_att
        r_att = torch.sum(r_att, dim=2)
        return r_att

    def pairwise(self, fused_features, bounding_boxes, batch_size, num_obj):
        relations = self.fusion_fused(
                [
                        fused_features.unsqueeze(2).expand(-1, -1, num_obj, -1)
                        .contiguous().view(batch_size * num_obj * num_obj, -1),
                        fused_features.unsqueeze(1).contiguous().expand(-1, num_obj, -1, -1)
                        .contiguous().view(batch_size * num_obj * num_obj, -1)
                ]
                ) + \
                    self.fusion_box(
                [
                        bounding_boxes.unsqueeze(2).expand(-1, -1, num_obj, -1)
                        .contiguous().view(batch_size * num_obj * num_obj, -1),
                        bounding_boxes.unsqueeze(1).contiguous().expand(-1, num_obj, -1, -1)
                        .contiguous().view(batch_size * num_obj * num_obj, -1)
                ]
                )
        relations = relations.view(batch_size, num_obj, num_obj,
                                   -1)  # BS x num_obj x num_obj x feat_dim
        if self.murel_cell_attention:
            e_hat = self.compute_relation_attention(relations)
        else:
            e_hat = self.pairwise_agg(relations)  # BS x 36 x 2048

        res = fused_features + e_hat
        if self.buffer is not None:
            _, argmax = torch.max(res, dim=1)
            l2_norm = torch.norm(e_hat / res, dim=2)  # BS x 36
            self.buffer['i_hat'] = torch.max(l2_norm,
                                             dim=1)[1].data.cpu()  # BS
            self.buffer['relations'] = relations.data.cpu()
            self.buffer['argmax'] = argmax.data.cpu()
        return res

    def fuse_object_features_with_questions(self, object_features_list,
                                            question_embedding, batch_size,
                                            num_obj):

        res = self.fusion_features([
            question_embedding,
            object_features_list.contiguous().view(batch_size * num_obj, -1),
        ])
        res = res.view(batch_size, num_obj, -1)
        return res

    def forward(self, question_embedding, object_features_list, bounding_boxes,
                batch_size, num_obj):
        # Sensitive?
        fused_question_object = self.fuse_object_features_with_questions(
            object_features_list, question_embedding, batch_size, num_obj)

        pairwise_res = self.pairwise(fused_question_object, bounding_boxes,
                                     batch_size, num_obj)
        res = object_features_list + pairwise_res
        return res
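
The pairwise method above compares every object with every other object by broadcasting the per-object features along two different axes before fusing them. A minimal sketch of that broadcasting pattern with dummy tensors, using concatenation as a stand-in for the learned fusion modules (the real fusion_fused/fusion_box blocks are project-specific):

import torch

batch_size, num_obj, feat_dim = 2, 4, 8
fused_features = torch.randn(batch_size, num_obj, feat_dim)

# features of object i repeated along the j axis, and features of object j repeated along the i axis
obj_i = fused_features.unsqueeze(2).expand(-1, -1, num_obj, -1) \
    .contiguous().view(batch_size * num_obj * num_obj, -1)
obj_j = fused_features.unsqueeze(1).expand(-1, num_obj, -1, -1) \
    .contiguous().view(batch_size * num_obj * num_obj, -1)

relations = torch.cat([obj_i, obj_j], dim=1)                  # stand-in for self.fusion_fused([obj_i, obj_j])
relations = relations.view(batch_size, num_obj, num_obj, -1)  # (batch, num_obj, num_obj, relation_dim)
e_hat = torch.max(relations, dim=2)[0]                        # one possible pairwise_agg: max over the second object axis
print(e_hat.shape)  # torch.Size([2, 4, 16])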
Example #7
class MurelNet(nn.Module):
    def __init__(self, config, word_vocabulary):
        super(MurelNet, self).__init__()
        self.fusion_factory = FusionFactory()
        self.use_pairwise = config['use_pairwise']
        self.use_graph_module = config['use_graph_module']

        if config['use_pairwise']:
            self.murel_cell = MurelCell(config)
        if config['use_graph_module']:
            self.graph_module = GraphCell(config)

        self.buffer = None
        self.final_fusion = self.fusion_factory.create_fusion(
            config['fusion']['final_fusion'])
        self.unroll_steps = config['unroll_steps']
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.txt_enc = get_text_enc(config, word_vocabulary)
        self.linear0 = nn.Linear(config['q_att']['linear0']['input_dim'],
                                 config['q_att']['linear0']['output_dim'])
        self.linear1 = nn.Linear(config['q_att']['linear1']['input_dim'],
                                 config['q_att']['linear1']['output_dim'])
        self.pooling_agg = get_aggregation_func(config['pooling_agg'], dim=1)

    def initialise_buffers(self):
        self.buffer = {}
        self.murel_cell.initialise_buffers()

    def forward(self, item):
        question_ids = item['question_ids']
        object_features_list = item['object_features_list']
        bounding_boxes = item['bounding_boxes']
        question_lengths = item['question_lengths']
        if self.use_graph_module:
            graph_batch = item['graph_batch']

        # q_att
        question_each_word_embedding = self.txt_enc.embedding(question_ids)
        question_features, question_final_feature = self.txt_enc.rnn(
            question_each_word_embedding)

        q_att = self.linear0(question_features)
        q_att = torch.nn.functional.relu(q_att)
        q_att = self.linear1(q_att)

        # http://juditacs.github.io/2018/12/27/masked-attention.html
        # Compute attention weights such that the padded units
        # give 0 attention weights
        q_att = masked_softmax(q_att, question_lengths)
        # Glimpses contain attention values for each question_feature
        # DIM: BATCH_SIZE x NO_WORDS
        glimpses = torch.unbind(q_att, dim=2)
        attentioned_glimpses = []
        for glimpse in glimpses:
            glimpse = glimpse.unsqueeze(2).expand(-1, -1,
                                                  question_features.size(-1))
            attentioned_feature = question_features * glimpse
            attentioned_feature = torch.sum(attentioned_feature, dim=1)
            attentioned_glimpses.append(attentioned_feature)
        question_attentioned = torch.cat(attentioned_glimpses, dim=1)

        batch_size, num_obj, _ = list(object_features_list.size())
        # Resize question outside for loop as
        # it would be used repeatedly in multiple unroll steps
        # Reshape question
        # (BATCH_SIZE x QUES_DIM) TO (BATCH_SIZE x NUM_OBJ x QUES_DIM)
        question_attentioned_repeated = question_attentioned.unsqueeze(
            1).expand(-1, num_obj, -1).contiguous()

        # Reshape question to (BATCH_SIZE * NUM_OBJ x QUES_DIM)
        question_attentioned_repeated = question_attentioned_repeated.view(
            batch_size * num_obj, -1)

        if self.use_pairwise:
            for i in range(self.unroll_steps):
                object_features_list = self.murel_cell(
                    question_attentioned_repeated, object_features_list,
                    bounding_boxes, batch_size, num_obj)
                if self.buffer is not None:
                    self.buffer[i] = deepcopy(self.murel_cell.buffer)
            pool = self.pooling_agg(object_features_list)

        if self.use_graph_module:
            object_features_list = object_features_list.contiguous()
            object_features_list = object_features_list.view(
                batch_size * num_obj, -1)
            for i in range(1):
                object_features_list = self.graph_module(
                    question_attentioned_repeated, object_features_list,
                    bounding_boxes, batch_size, num_obj, graph_batch)
            pool = global_max_pool(object_features_list, graph_batch.batch)
        scores = self.final_fusion([question_attentioned, pool])
        prob = self.log_softmax(scores)
        return prob
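
The question-attention branch depends on a project-local masked_softmax helper that gives zero attention weight to padded word positions. Its implementation is not shown in these examples; the following is a minimal sketch of what such a helper typically looks like (an assumption based on the masking recipe the comments link to, not the repository's code):

import torch

def masked_softmax(logits, lengths):
    # logits: (batch, max_len, n_glimpses); lengths: (batch,) count of real (non-padded) words
    max_len = logits.size(1)
    mask = torch.arange(max_len, device=logits.device)[None, :] < lengths[:, None]
    logits = logits.masked_fill(~mask.unsqueeze(2), float('-inf'))
    return torch.softmax(logits, dim=1)  # padded positions receive exactly 0 weight

# usage with dummy values: weights sum to 1 over the word dimension
weights = masked_softmax(torch.randn(2, 5, 2), torch.tensor([5, 3]))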
Example #8
class FRCNNConcat(nn.Module):
    def __init__(self, config, word_vocabulary):
        super(FRCNNConcat, self).__init__()
        self.fusion_factory = FusionFactory()
        self.agg_type = config['agg_type']
        self.q_self_attention = config['q_self_attention']
        if config['fusion_type'] == 'concat_mlp':
            self.fusion = ConcatMLP(config['fusion_mlp'])
        elif config['fusion_type'] == 'block':
            self.fusion = self.fusion_factory.create_fusion(
                config['fusion_block'])
        else:
            raise ValueError('Unimplemented attention fusion')

        self.txt_enc = get_text_enc(config, word_vocabulary)
        self.q_linear0 = nn.Linear(config['q_att']['q_linear0']['input_dim'],
                                   config['q_att']['q_linear0']['output_dim'])
        self.q_linear1 = nn.Linear(config['q_att']['q_linear1']['input_dim'],
                                   config['q_att']['q_linear1']['output_dim'])
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, item):
        question_ids = item['question_ids']
        object_features_list = item['object_features_list']
        question_lengths = item['question_lengths']

        question_each_word_embedding = self.txt_enc.embedding(question_ids)
        question_features, question_final_feature = self.txt_enc.rnn(
            question_each_word_embedding)

        if self.q_self_attention:
            question_attentioned = self.self_attention_question(
                question_features, question_lengths)
        else:
            question_attentioned = question_final_feature

        object_attentioned = self.process_butd_features(
            object_features_list, self.agg_type)

        # Construct training vector
        x = self.fusion([question_attentioned, object_attentioned])
        x = self.log_softmax(x)
        return x

    def process_butd_features(self, object_features, agg_type):
        if agg_type == 'mean':
            return torch.mean(object_features, dim=1)
        if agg_type == 'max':
            return torch.max(object_features, dim=1)[0]
        if agg_type == 'min':
            return torch.min(object_features, dim=1)[0]
        if agg_type == 'sum':
            return torch.sum(object_features, dim=1)
        raise ValueError('Unimplemented aggregation type: ' + agg_type)

    def self_attention_question(self, question_features, question_lengths):
        q_att = self.q_linear0(question_features)
        q_att = torch.nn.functional.relu(q_att)
        q_att = self.q_linear1(q_att)

        # http://juditacs.github.io/2018/12/27/masked-attention.html
        # Compute attention weights such that the padded units give
        # 0 attention weights
        q_att = masked_softmax(q_att, question_lengths)
        # Glimpses contain attention values for each question_feature
        # DIM: BATCH_SIZE x NO_WORDS
        glimpses = torch.unbind(q_att, dim=2)
        attentioned_glimpses = []
        for glimpse in glimpses:
            glimpse = glimpse.unsqueeze(2).expand(-1, -1,
                                                  question_features.size(-1))
            attentioned_feature = question_features * glimpse
            attentioned_feature = torch.sum(attentioned_feature, dim=1)
            attentioned_glimpses.append(attentioned_feature)
        question_attentioned = torch.cat(attentioned_glimpses, dim=1)
        return question_attentioned
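
process_butd_features simply pools the per-object image features into one vector per image, controlled by agg_type. A small illustration with a dummy feature tensor (the (batch, num_objects, feature_dim) shape is an assumption for the sketch):

import torch

object_features = torch.randn(2, 36, 2048)  # (batch, num_objects, feature_dim)

pooled_mean = torch.mean(object_features, dim=1)
pooled_max = torch.max(object_features, dim=1)[0]   # torch.max returns (values, indices)
pooled_min = torch.min(object_features, dim=1)[0]
pooled_sum = torch.sum(object_features, dim=1)
print(pooled_mean.shape)  # torch.Size([2, 2048]) -- one pooled vector per image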