# Library imports needed by the classes below; project-local helpers
# (FusionFactory, GraphLayerFactory, ConcatMLP, get_text_enc, masked_softmax,
# get_aggregation_func) are assumed to be imported from elsewhere in this
# repository.
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_max_pool


class GraphCell(nn.Module):
    def __init__(self, config):
        super(GraphCell, self).__init__()
        fusion_factory = FusionFactory()
        graph_layer_factory = GraphLayerFactory()
        fusion_features_cfg = config['fusion']['obj_features_question']
        self.fusion_features = fusion_factory.create_fusion(
            fusion_features_cfg)
        graph_cfg = config['graph']
        kwargs = graph_cfg['kwargs']
        if graph_cfg['layer_specify_method'] == 'manual':
            self.manual = True
            self.gat1 = GATConv(2048, 256, heads=4)
            self.gat2 = GATConv(1024, 256, heads=4)
            self.gat3 = GATConv(1024, 512, heads=5, concat=False)
        else:
            self.manual = False
            graph_layer = graph_layer_factory.get_graph_layer(config['graph'])
            self.graph_hidden_list = nn.ModuleList([
                graph_layer(graph_cfg['input_dim'],
                            graph_cfg['graph_hidden_list'][0],
                            **kwargs)
            ])
            if len(graph_cfg['graph_hidden_list']) > 1:
                for length1, length2 in zip(
                        graph_cfg['graph_hidden_list'][:-1],
                        graph_cfg['graph_hidden_list'][1:]):
                    self.graph_hidden_list.append(
                        graph_layer(length1, length2, **kwargs))
            kwargs['concat'] = False
            self.last_layer = graph_layer(
                graph_cfg['graph_hidden_list'][-1] * kwargs['heads']
                if 'heads' in kwargs else graph_cfg['graph_hidden_list'][-1],
                graph_cfg['output_dim'],
                **kwargs)
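

# Purely illustrative: a hypothetical `config['graph']` for the non-manual
# branch above. The keys mirror what __init__ reads; the values and the fusion
# sub-config expected by FusionFactory are assumptions, not the project's
# actual settings. Note that when 'heads' is present in kwargs, the final
# layer's input dimension is graph_hidden_list[-1] * heads, matching the
# expression used for self.last_layer.
#
# example_graph_cfg = {
#     'layer_specify_method': 'factory',   # anything other than 'manual'
#     'input_dim': 2048,
#     'graph_hidden_list': [256, 256],
#     'output_dim': 2048,
#     'kwargs': {'heads': 4},
# }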


class AttentionNet(nn.Module):
    def __init__(self, config, word_vocabulary):
        super(AttentionNet, self).__init__()
        self.fusion_factory = FusionFactory()
        if config['attention_fusion_type'] == 'concat_mlp':
            self.attention_fusion = ConcatMLP(config['attention_fusion_mlp'])
        elif config['attention_fusion_type'] == 'block':
            self.attention_fusion = self.fusion_factory.create_fusion(
                config['attention_fusion_block'])
        else:
            raise ValueError('Unimplemented attention fusion')
        if config['final_fusion_type'] == 'concat_mlp':
            self.final_fusion = ConcatMLP(config['final_fusion_mlp'])
        elif config['final_fusion_type'] == 'block':
            self.final_fusion = self.fusion_factory.create_fusion(
                config['final_fusion_block'])
        else:
            raise ValueError('Unimplemented final fusion')
        self.buffer = None
        self.txt_enc = get_text_enc(config, word_vocabulary)
        self.q_linear0 = nn.Linear(
            config['q_att']['q_linear0']['input_dim'],
            config['q_att']['q_linear0']['output_dim'])
        self.q_linear1 = nn.Linear(
            config['q_att']['q_linear1']['input_dim'],
            config['q_att']['q_linear1']['output_dim'])
        self.obj_linear0 = nn.Linear(
            config['obj_att']['obj_linear0']['input_dim'],
            config['obj_att']['obj_linear0']['output_dim'])
        self.obj_linear1 = nn.Linear(
            config['obj_att']['obj_linear1']['input_dim'],
            config['obj_att']['obj_linear1']['output_dim'])
        self.log_softmax = nn.LogSoftmax(dim=1)

    def initialise_buffers(self):
        self.buffer = {}
        print('Buffer initialised. Model ready to visualise.')

    def forward(self, item):
        question_ids = item['question_ids']
        object_features_list = item['object_features_list']
        question_lengths = item['question_lengths']
        question_each_word_embedding = self.txt_enc.embedding(question_ids)
        question_features, question_final_feature = self.txt_enc.rnn(
            question_each_word_embedding)
        question_attentioned = self.self_attention_question(
            question_features, question_lengths)
        object_attentioned = self.compute_object_attention_with_question(
            question_attentioned, object_features_list)
        # Construct training vector
        x = self.final_fusion([question_attentioned, object_attentioned])
        x = self.log_softmax(x)
        return x

    def compute_object_attention_with_question(self,
                                               question_self_attentioned,
                                               object_features_list):
        batch_size = object_features_list.size(0)
        no_objects = object_features_list.size(1)
        q_expanded = question_self_attentioned.unsqueeze(1).expand(
            -1, no_objects, -1)
        fused = self.attention_fusion([
            q_expanded.contiguous().view(batch_size * no_objects, -1),
            object_features_list.contiguous().view(batch_size * no_objects, -1)
        ])
        fused = fused.view(batch_size, no_objects, -1)
        fused_att = self.obj_linear0(fused)
        fused_att = F.relu(fused_att)
        fused_att = self.obj_linear1(fused_att)
        fused_att = F.softmax(fused_att, dim=1)
        glimpses = torch.unbind(fused_att, dim=2)
        attentioned_glimpses = []
        for i, glimpse in enumerate(glimpses):
            glimpse = glimpse.unsqueeze(2).expand(
                -1, -1, object_features_list.size(-1))
            if self.buffer is not None:
                # .data is an attribute, not a method
                self.buffer['glimpse' + str(i)] = glimpse.data.cpu()
            attentioned_feature = object_features_list * glimpse
            attentioned_feature = torch.sum(attentioned_feature, dim=1)
            attentioned_glimpses.append(attentioned_feature)
        fused_attentioned = torch.cat(attentioned_glimpses, dim=1)
        return fused_attentioned

    def self_attention_question(self, question_features, question_lengths):
        q_att = self.q_linear0(question_features)
        q_att = torch.nn.functional.relu(q_att)
        q_att = self.q_linear1(q_att)
        # http://juditacs.github.io/2018/12/27/masked-attention.html
        # Compute attention weights such that the padded units give
        # 0 attention weights
        q_att = masked_softmax(q_att, question_lengths)
        # Glimpses contain attention values for each question_feature
        # DIM: BATCH_SIZE x NO_WORDS
        glimpses = torch.unbind(q_att, dim=2)
        attentioned_glimpses = []
        for glimpse in glimpses:
            glimpse = glimpse.unsqueeze(2).expand(
                -1, -1, question_features.size(-1))
            attentioned_feature = question_features * glimpse
            attentioned_feature = torch.sum(attentioned_feature, dim=1)
            attentioned_glimpses.append(attentioned_feature)
        question_attentioned = torch.cat(attentioned_glimpses, dim=1)
        return question_attentioned
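

# A minimal sketch of the masked softmax referenced in the comment above,
# following the linked blog post. The project's real `masked_softmax` helper
# lives elsewhere in the repository; this signature and behaviour are an
# assumption, shown only to make the attention step self-explanatory.
def masked_softmax_sketch(logits, lengths):
    # logits: (batch, max_len, n_glimpses); lengths: (batch,) valid word counts
    max_len = logits.size(1)
    # mask[b, t] is True for real tokens, False for padded positions
    mask = (torch.arange(max_len, device=logits.device)[None, :]
            < lengths[:, None]).unsqueeze(2)
    # padded positions get -inf logits, hence exactly zero attention weight
    return torch.softmax(logits.masked_fill(~mask, float('-inf')), dim=1)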


class MurelCell(nn.Module):
    def __init__(self, config):
        super(MurelCell, self).__init__()
        self.fusion_factory = FusionFactory()
        fusion_features_cfg = config['fusion']['obj_features_question']
        fusion_box_cfg = config['fusion']['box']
        fusion_fused_cfg = config['fusion']['obj_features_obj_features']
        if config['murel_attention']:
            self.murel_cell_attention = True
            self.murel_cell_attention_linear0 = nn.Linear(
                config['murel_cell_attention']['linear0']['input_dim'],
                config['murel_cell_attention']['linear0']['output_dim'])
            self.murel_cell_attention_linear1 = nn.Linear(
                config['murel_cell_attention']['linear1']['input_dim'],
                config['murel_cell_attention']['linear1']['output_dim'])
        else:
            self.murel_cell_attention = False
        self.buffer = None
        self.fusion_features = self.fusion_factory.create_fusion(
            fusion_features_cfg)
        self.fusion_box = self.fusion_factory.create_fusion(fusion_box_cfg)
        self.fusion_fused = self.fusion_factory.create_fusion(fusion_fused_cfg)
        self.pairwise_agg = get_aggregation_func(config['pairwise_agg'], dim=2)

    def initialise_buffers(self):
        self.buffer = {}

    def compute_relation_attention(self, relations):
        _, _, no_objects, no_feats = relations.size()
        r_att = self.murel_cell_attention_linear0(relations)
        r_att = torch.tanh(r_att)  # torch.nn.functional.tanh is deprecated
        r_att = self.murel_cell_attention_linear1(r_att)
        r_att = torch.softmax(r_att, dim=2)
        r_att = torch.squeeze(torch.unbind(r_att, dim=3)[0])
        r_att = r_att.unsqueeze(3).expand(-1, no_objects, no_objects, no_feats)
        r_att = relations * r_att
        r_att = torch.sum(r_att, dim=2)
        return r_att

    def pairwise(self, fused_features, bounding_boxes, batch_size, num_obj):
        relations = self.fusion_fused([
            fused_features.unsqueeze(2).expand(-1, -1, num_obj, -1)
            .contiguous().view(batch_size * num_obj * num_obj, -1),
            fused_features.unsqueeze(1).contiguous().expand(-1, num_obj, -1, -1)
            .contiguous().view(batch_size * num_obj * num_obj, -1)
        ]) + self.fusion_box([
            bounding_boxes.unsqueeze(2).expand(-1, -1, num_obj, -1)
            .contiguous().view(batch_size * num_obj * num_obj, -1),
            bounding_boxes.unsqueeze(1).contiguous().expand(-1, num_obj, -1, -1)
            .contiguous().view(batch_size * num_obj * num_obj, -1)
        ])
        # BS x 36 x 36 x 2048
        relations = relations.view(batch_size, num_obj, num_obj, -1)
        if self.murel_cell_attention:
            e_hat = self.compute_relation_attention(relations)
        else:
            e_hat = self.pairwise_agg(relations)  # BS x 36 x 2048
        res = fused_features + e_hat
        if self.buffer is not None:
            _, argmax = torch.max(res, dim=1)
            l2_norm = torch.norm(e_hat / res, dim=2)  # BS x 36
            self.buffer['i_hat'] = torch.max(l2_norm, dim=1)[1].data.cpu()  # BS
            self.buffer['relations'] = relations.data.cpu()
            self.buffer['argmax'] = argmax.data.cpu()
        return res

    def fuse_object_features_with_questions(self, object_features_list,
                                            question_embedding,
                                            batch_size, num_obj):
        res = self.fusion_features([
            question_embedding,
            object_features_list.contiguous().view(batch_size * num_obj, -1),
        ])
        res = res.view(batch_size, num_obj, -1)
        return res

    def forward(self, question_embedding, object_features_list,
                bounding_boxes, batch_size, num_obj):
        # Sensitive?
        fused_question_object = self.fuse_object_features_with_questions(
            object_features_list, question_embedding, batch_size, num_obj)
        pairwise_res = self.pairwise(fused_question_object, bounding_boxes,
                                     batch_size, num_obj)
        res = object_features_list + pairwise_res
        return res
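

# A minimal sketch of the aggregation helper used for `pairwise_agg` above and
# `pooling_agg` in MurelNet below, assuming the project's real
# `get_aggregation_func` lives elsewhere in the repository. It maps a name to a
# reduction over a fixed dimension; the set of supported names is an assumption.
def get_aggregation_func_sketch(name, dim):
    if name == 'max':
        return lambda x: torch.max(x, dim=dim)[0]
    if name == 'mean':
        return lambda x: torch.mean(x, dim=dim)
    if name == 'sum':
        return lambda x: torch.sum(x, dim=dim)
    if name == 'min':
        return lambda x: torch.min(x, dim=dim)[0]
    raise ValueError('Unimplemented aggregation: {}'.format(name))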


class MurelNet(nn.Module):
    def __init__(self, config, word_vocabulary):
        super(MurelNet, self).__init__()
        self.fusion_factory = FusionFactory()
        self.use_pairwise = config['use_pairwise']
        self.use_graph_module = config['use_graph_module']
        if config['use_pairwise']:
            self.murel_cell = MurelCell(config)
        if config['use_graph_module']:
            self.graph_module = GraphCell(config)
        self.buffer = None
        self.final_fusion = self.fusion_factory.create_fusion(
            config['fusion']['final_fusion'])
        self.unroll_steps = config['unroll_steps']
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.txt_enc = get_text_enc(config, word_vocabulary)
        self.linear0 = nn.Linear(config['q_att']['linear0']['input_dim'],
                                 config['q_att']['linear0']['output_dim'])
        self.linear1 = nn.Linear(config['q_att']['linear1']['input_dim'],
                                 config['q_att']['linear1']['output_dim'])
        self.pooling_agg = get_aggregation_func(config['pooling_agg'], dim=1)

    def initialise_buffers(self):
        self.buffer = {}
        self.murel_cell.initialise_buffers()

    def forward(self, item):
        question_ids = item['question_ids']
        object_features_list = item['object_features_list']
        bounding_boxes = item['bounding_boxes']
        question_lengths = item['question_lengths']
        if self.use_graph_module:
            graph_batch = item['graph_batch']
        # q_att
        question_each_word_embedding = self.txt_enc.embedding(question_ids)
        question_features, question_final_feature = self.txt_enc.rnn(
            question_each_word_embedding)
        q_att = self.linear0(question_features)
        q_att = torch.nn.functional.relu(q_att)
        q_att = self.linear1(q_att)
        # http://juditacs.github.io/2018/12/27/masked-attention.html
        # Compute attention weights such that the padded units
        # give 0 attention weights
        q_att = masked_softmax(q_att, question_lengths)
        # Glimpses contain attention values for each question_feature
        # DIM: BATCH_SIZE x NO_WORDS
        glimpses = torch.unbind(q_att, dim=2)
        attentioned_glimpses = []
        for glimpse in glimpses:
            glimpse = glimpse.unsqueeze(2).expand(
                -1, -1, question_features.size(-1))
            attentioned_feature = question_features * glimpse
            attentioned_feature = torch.sum(attentioned_feature, dim=1)
            attentioned_glimpses.append(attentioned_feature)
        question_attentioned = torch.cat(attentioned_glimpses, dim=1)
        batch_size, num_obj, _ = list(object_features_list.size())
        # Repeat the question outside the loop since it is reused
        # across unroll steps:
        # (BATCH_SIZE x QUES_DIM) -> (BATCH_SIZE x NUM_OBJ x QUES_DIM)
        question_attentioned_repeated = question_attentioned.unsqueeze(
            1).expand(-1, num_obj, -1).contiguous()
        # Flatten to (BATCH_SIZE * NUM_OBJ) x QUES_DIM
        question_attentioned_repeated = question_attentioned_repeated.view(
            batch_size * num_obj, -1)
        if self.use_pairwise:
            for i in range(self.unroll_steps):
                object_features_list = self.murel_cell(
                    question_attentioned_repeated, object_features_list,
                    bounding_boxes, batch_size, num_obj)
                if self.buffer is not None:
                    self.buffer[i] = deepcopy(self.murel_cell.buffer)
            pool = self.pooling_agg(object_features_list)
        if self.use_graph_module:
            object_features_list = object_features_list.contiguous()
            object_features_list = object_features_list.view(
                batch_size * num_obj, -1)
            for i in range(1):
                object_features_list = self.graph_module(
                    question_attentioned_repeated, object_features_list,
                    bounding_boxes, batch_size, num_obj, graph_batch)
            pool = global_max_pool(object_features_list, graph_batch.batch)
        scores = self.final_fusion([question_attentioned, pool])
        prob = self.log_softmax(scores)
        return prob
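

# Illustrative only (dummy shapes): the question-repeat reshape used in
# forward() above. A (batch, q_dim) question vector is tiled to one copy per
# object and flattened to (batch * num_obj, q_dim) so it can be fused row-wise
# with the flattened object features inside MurelCell / GraphCell.
#
#   q = torch.randn(2, 4800)                                    # q_dim assumed
#   q_rep = q.unsqueeze(1).expand(-1, 36, -1).contiguous().view(2 * 36, -1)
#   q_rep.shape  # torch.Size([72, 4800])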


class FRCNNConcat(nn.Module):
    def __init__(self, config, word_vocabulary):
        super(FRCNNConcat, self).__init__()
        self.fusion_factory = FusionFactory()
        self.agg_type = config['agg_type']
        self.q_self_attention = config['q_self_attention']
        if config['fusion_type'] == 'concat_mlp':
            self.fusion = ConcatMLP(config['fusion_mlp'])
        elif config['fusion_type'] == 'block':
            self.fusion = self.fusion_factory.create_fusion(
                config['fusion_block'])
        else:
            raise ValueError('Unimplemented fusion')
        self.txt_enc = get_text_enc(config, word_vocabulary)
        self.q_linear0 = nn.Linear(config['q_att']['q_linear0']['input_dim'],
                                   config['q_att']['q_linear0']['output_dim'])
        self.q_linear1 = nn.Linear(config['q_att']['q_linear1']['input_dim'],
                                   config['q_att']['q_linear1']['output_dim'])
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, item):
        question_ids = item['question_ids']
        object_features_list = item['object_features_list']
        question_lengths = item['question_lengths']
        question_each_word_embedding = self.txt_enc.embedding(question_ids)
        question_features, question_final_feature = self.txt_enc.rnn(
            question_each_word_embedding)
        if self.q_self_attention:
            question_attentioned = self.self_attention_question(
                question_features, question_lengths)
        else:
            question_attentioned = question_final_feature
        object_attentioned = self.process_butd_features(
            object_features_list, self.agg_type)
        # Construct training vector
        x = self.fusion([question_attentioned, object_attentioned])
        x = self.log_softmax(x)
        return x

    def process_butd_features(self, object_features, agg_type):
        if agg_type == 'mean':
            return torch.mean(object_features, dim=1)
        if agg_type == 'max':
            return torch.max(object_features, dim=1)[0]
        if agg_type == 'min':
            return torch.min(object_features, dim=1)[0]
        if agg_type == 'sum':
            return torch.sum(object_features, dim=1)
        raise ValueError('Unimplemented aggregation type')

    def self_attention_question(self, question_features, question_lengths):
        q_att = self.q_linear0(question_features)
        q_att = torch.nn.functional.relu(q_att)
        q_att = self.q_linear1(q_att)
        # http://juditacs.github.io/2018/12/27/masked-attention.html
        # Compute attention weights such that the padded units give
        # 0 attention weights
        q_att = masked_softmax(q_att, question_lengths)
        # Glimpses contain attention values for each question_feature
        # DIM: BATCH_SIZE x NO_WORDS
        glimpses = torch.unbind(q_att, dim=2)
        attentioned_glimpses = []
        for glimpse in glimpses:
            glimpse = glimpse.unsqueeze(2).expand(
                -1, -1, question_features.size(-1))
            attentioned_feature = question_features * glimpse
            attentioned_feature = torch.sum(attentioned_feature, dim=1)
            attentioned_glimpses.append(attentioned_feature)
        question_attentioned = torch.cat(attentioned_glimpses, dim=1)
        return question_attentioned
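

# Illustrative only (dummy shapes): every agg_type handled by
# process_butd_features reduces the object dimension of a
# (batch, num_obj, feat_dim) tensor down to (batch, feat_dim), e.g.
#
#   feats = torch.randn(2, 36, 2048)
#   torch.max(feats, dim=1)[0].shape   # torch.Size([2, 2048])
#   torch.mean(feats, dim=1).shape     # torch.Size([2, 2048])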