Example #1
0
  def __init__(self, vocab, image_size=(64, 64), embedding_dim=64,
               gconv_dim=128, gconv_hidden_dim=512,
               gconv_pooling='avg', gconv_num_layers=5,
               refinement_dims=(1024, 512, 256, 128, 64),
               normalization='batch', activation='leakyrelu-0.2',
               mask_size=None, mlp_normalization='none', layout_noise_dim=0,
               **kwargs):
    """Construct the scene-graph encoder sub-modules of Sg2ImModel.

    Builds, in this order: object/predicate embedding tables, the first
    graph-convolution stage (``self.gconv``), an optional stack of further
    graph-conv layers (``self.gconv_net``), the box MLP (``self.box_net``)
    and an optional mask network (``self.mask_net``). Keep this order: it
    fixes the parameter registration order (and, presumably, which random
    draws each layer's initialisation consumes).

    Args:
      vocab: dict with at least 'object_idx_to_name' and 'pred_idx_to_name';
        only their lengths are used here, to size the embedding tables.
      image_size: stored unchanged as ``self.image_size``.
      embedding_dim: width of the object/predicate embedding vectors.
      gconv_dim: output width of the first graph-conv stage.
      gconv_hidden_dim: hidden width of the graph-conv layers and box MLP.
      gconv_pooling: pooling mode forwarded to the graph-conv layers.
      gconv_num_layers: 0 -> a single nn.Linear; >= 1 -> one GraphTripleConv;
        > 1 additionally builds a GraphTripleConvNet with the remaining
        ``gconv_num_layers - 1`` layers.
      refinement_dims, normalization, activation: accepted but not used in
        the visible portion of this constructor.
      mask_size: if a positive int, ``self.mask_net`` is built via
        ``self._build_mask_net``; otherwise it is None.
      mlp_normalization: forwarded as ``batch_norm`` to ``build_mlp`` and as
        ``mlp_normalization`` to the graph-conv layers.
      layout_noise_dim: stored unchanged as ``self.layout_noise_dim``.
      **kwargs: unexpected extra arguments; a warning is printed and they
        are otherwise ignored.
    """
    super(Sg2ImModel, self).__init__()

    # Older versions accepted additional arguments (vec_noise_dim, gconv_mode,
    # box_anchor, decouple_obj_predictions); extras are tolerated with a warning.
    if len(kwargs) > 0:
      print('WARNING: Model got unexpected kwargs ', kwargs)

    self.vocab = vocab
    self.image_size = image_size
    self.layout_noise_dim = layout_noise_dim

    num_objs = len(vocab['object_idx_to_name'])
    num_preds = len(vocab['pred_idx_to_name'])
    # One object row beyond the vocabulary size; presumably reserved for a
    # special object index -- TODO confirm with the caller.
    self.obj_embeddings = nn.Embedding(num_objs + 1, embedding_dim)
    self.pred_embeddings = nn.Embedding(num_preds, embedding_dim)

    # gconv_num_layers == 0: plain linear projection embedding_dim -> gconv_dim.
    # gconv_num_layers  > 0: GraphTripleConv embedding_dim -> gconv_dim.
    # A negative value builds neither, so self.gconv is never assigned.
    if gconv_num_layers == 0:
      self.gconv = nn.Linear(embedding_dim, gconv_dim)
    elif gconv_num_layers > 0:
      gconv_kwargs = {
        'input_dim': embedding_dim,
        'output_dim': gconv_dim,
        'hidden_dim': gconv_hidden_dim,
        'pooling': gconv_pooling,
        'mlp_normalization': mlp_normalization,
      }
      self.gconv = GraphTripleConv(**gconv_kwargs)

    # The remaining layers (if any) all operate at width gconv_dim.
    self.gconv_net = None
    if gconv_num_layers > 1:
      gconv_kwargs = {
        'input_dim': gconv_dim,
        'hidden_dim': gconv_hidden_dim,
        'pooling': gconv_pooling,
        'num_layers': gconv_num_layers - 1,
        'mlp_normalization': mlp_normalization,
      }
      self.gconv_net = GraphTripleConvNet(**gconv_kwargs)

    # MLP gconv_dim -> gconv_hidden_dim -> 4; presumably one box (4 values)
    # per object -- the coordinate format is not established here.
    box_net_dim = 4
    box_net_layers = [gconv_dim, gconv_hidden_dim, box_net_dim]
    self.box_net = build_mlp(box_net_layers, batch_norm=mlp_normalization)

    # Mask head is optional: only built for a positive mask_size.
    self.mask_net = None
    if mask_size is not None and mask_size > 0:
      self.mask_net = self._build_mask_net(num_objs, gconv_dim, mask_size)
Example #2
0
    def __init__(
            self,
            vocab,
            image_size=(64, 64),
            embedding_dim=64,
            gconv_dim=128,
            gconv_hidden_dim=512,
            gconv_pooling='avg',
            gconv_num_layers=5,
            refinement_dims=(1024, 512, 256, 128, 64),
            normalization='batch',
            activation='leakyrelu-0.2',
            mask_size=None,
            mlp_normalization='none',
            layout_noise_dim=0,
            sg_context_dim=0,  #None,
            sg_context_dim_d=0,  #None,
            gcnn_pooling='avg',
            triplet_box_net=False,
            triplet_mask_size=0,
            triplet_embedding_size=0,
            use_bbox_info=False,
            triplet_superbox_net=False,
            **kwargs):
        """Construct all sub-modules of the triplet-extended Sg2ImModel.

        Args:
            vocab: dict with at least 'object_idx_to_name' and
                'pred_idx_to_name'; only their lengths are used here.
            image_size: stored unchanged as ``self.image_size``.
            embedding_dim: width of the object/predicate embeddings.
            gconv_dim: output width of the graph-convolution stages.
            gconv_hidden_dim: hidden width of graph-conv layers and MLP heads.
            gconv_pooling: pooling mode forwarded to the graph-conv layers.
            gconv_num_layers: 0 -> nn.Linear; >= 1 -> GraphTripleConv; > 1
                also builds a GraphTripleConvNet with the remaining layers.
            refinement_dims: widths appended after the refinement network's
                input width (any sequence; converted to a tuple).
            normalization, activation: forwarded to RefinementNetwork.
            mask_size: a positive int enables ``self.mask_net``.
            mlp_normalization: forwarded to every MLP and graph-conv layer.
            layout_noise_dim: extra input width for the refinement network.
            sg_context_dim: a positive int builds ``sg_context_net`` /
                ``sg_context_net_d`` and widens the refinement input by this
                amount; ``None`` and 0 both disable the branch.
            sg_context_dim_d: output width of ``sg_context_net_d``.
            gcnn_pooling: stored as ``self.gcnn_pooling`` (unused here).
            triplet_box_net: if true, builds an MLP 3*gconv_dim -> 8.
            triplet_mask_size: a positive int enables ``triplet_mask_net``.
            triplet_embedding_size: a positive int enables
                ``triplet_embed_net`` with that output width.
            use_bbox_info: if true, ``box_net`` outputs 5 values instead of 4.
            triplet_superbox_net: if true, builds an MLP 3*gconv_dim -> 4.
            **kwargs: unexpected arguments; a warning is printed.

        After construction ``self.triplet_box_net`` and
        ``self.triplet_superbox_net`` hold the sub-network, or ``None`` when
        disabled -- not the boolean flags passed in.
        """
        super(Sg2ImModel, self).__init__()

        # We used to have some additional arguments:
        # vec_noise_dim, gconv_mode, box_anchor, decouple_obj_predictions
        if len(kwargs) > 0:
            print('WARNING: Model got unexpected kwargs ', kwargs)

        # sg_context_dim may be None (historical default) or 0 to disable the
        # scene-graph context branch. Evaluate the check once so the context
        # nets and the refinement input width always agree (previously the
        # refinement check did `None > 0` and raised TypeError).
        use_sg_context = sg_context_dim is not None and sg_context_dim > 0

        self.vocab = vocab
        self.image_size = image_size
        self.layout_noise_dim = layout_noise_dim
        self.sg_context_dim = sg_context_dim
        self.sg_context_dim_d = sg_context_dim_d
        self.gcnn_pooling = gcnn_pooling
        self.triplet_mask_size = triplet_mask_size
        self.triplet_embedding_size = triplet_embedding_size
        self.use_bbox_info = use_bbox_info

        num_objs = len(vocab['object_idx_to_name'])
        num_preds = len(vocab['pred_idx_to_name'])
        self.obj_embeddings = nn.Embedding(num_objs + 1, embedding_dim)
        self.pred_embeddings = nn.Embedding(num_preds, embedding_dim)

        if gconv_num_layers == 0:
            self.gconv = nn.Linear(embedding_dim, gconv_dim)
        elif gconv_num_layers > 0:
            gconv_kwargs = {
                'input_dim': embedding_dim,
                'output_dim': gconv_dim,
                'hidden_dim': gconv_hidden_dim,
                'pooling': gconv_pooling,
                'mlp_normalization': mlp_normalization,
            }
            self.gconv = GraphTripleConv(**gconv_kwargs)

        self.gconv_net = None
        if gconv_num_layers > 1:
            gconv_kwargs = {
                'input_dim': gconv_dim,
                'hidden_dim': gconv_hidden_dim,
                'pooling': gconv_pooling,
                'num_layers': gconv_num_layers - 1,
                'mlp_normalization': mlp_normalization,
            }
            self.gconv_net = GraphTripleConvNet(**gconv_kwargs)

        # use_bbox_info augments the box regressor with one extra output.
        box_net_dim = 4 + 1 if self.use_bbox_info else 4
        box_net_layers = [gconv_dim, gconv_hidden_dim, box_net_dim]
        self.box_net = build_mlp(box_net_layers, batch_norm=mlp_normalization)

        # Triplet heads: each consumes a concatenated triplet of three
        # gconv_dim vectors (3 * gconv_dim inputs) and stays None unless
        # enabled below.
        self.triplet_box_net = None
        self.triplet_embed_net = None
        self.triplet_mask_net = None
        self.triplet_superbox_net = None

        triplet_box_net_dim = 8
        if triplet_box_net:
            triplet_box_net_layers = [
                3 * gconv_dim, gconv_hidden_dim, triplet_box_net_dim
            ]
            self.triplet_box_net = build_mlp(triplet_box_net_layers,
                                             batch_norm=mlp_normalization)

        if self.triplet_embedding_size > 0:
            triplet_embed_layers = [
                3 * gconv_dim, gconv_hidden_dim, triplet_embedding_size
            ]
            self.triplet_embed_net = build_mlp(triplet_embed_layers,
                                               batch_norm=mlp_normalization)

        if self.triplet_mask_size > 0:
            self.triplet_mask_net = self._build_triplet_mask_net(
                num_objs, 3 * gconv_dim, self.triplet_mask_size)

        triplet_superbox_net_dim = 4
        if triplet_superbox_net:
            triplet_superbox_net_layers = [
                3 * gconv_dim, gconv_hidden_dim, triplet_superbox_net_dim
            ]
            self.triplet_superbox_net = build_mlp(triplet_superbox_net_layers,
                                                  batch_norm=mlp_normalization)

        self.mask_net = None
        if mask_size is not None and mask_size > 0:
            self.mask_net = self._build_mask_net(num_objs, gconv_dim,
                                                 mask_size)

        # Scene-graph context projections (single linear layers).
        self.sg_context_net = None
        self.sg_context_net_d = None
        if use_sg_context:
            self.sg_context_net = nn.Linear(gconv_dim, sg_context_dim)
            self.sg_context_net_d = nn.Linear(gconv_dim, sg_context_dim_d)

        rel_aux_layers = [2 * embedding_dim + 8, gconv_hidden_dim, num_preds]
        self.rel_aux_net = build_mlp(rel_aux_layers,
                                     batch_norm=mlp_normalization)

        # The refinement input is widened by sg_context_dim only when the
        # context branch exists.
        refinement_in_dim = gconv_dim + layout_noise_dim
        if use_sg_context:
            refinement_in_dim += sg_context_dim
        refinement_kwargs = {
            'dims': (refinement_in_dim, ) + tuple(refinement_dims),
            'normalization': normalization,
            'activation': activation,
        }
        self.refinement_net = RefinementNetwork(**refinement_kwargs)
Example #3
0
  def __init__(self, vocab, image_size=(64, 64), embedding_dim=64,
               gconv_dim=128, gconv_hidden_dim=512,
               gconv_pooling='avg', gconv_num_layers=5,
               refinement_dims=(1024, 512, 256, 128, 64),
               normalization='batch', activation='leakyrelu-0.2',
               mask_size=None, mlp_normalization='none', layout_noise_dim=0,
               sg_context_dim=0, #None,
               sg_context_dim_d=0, #None,
               gcnn_pooling='avg',
               triplet_box_net=False,
               triplet_mask_size=0,
               triplet_embedding_size=0,
               use_bbox_info=False,
               triplet_superbox_net=False,
               use_masked_sg=False,
               **kwargs):
    """Construct all sub-modules of the masked / triplet-extended Sg2ImModel.

    Args:
      vocab: dict with at least 'object_idx_to_name' and 'pred_idx_to_name';
        only their lengths are used here.
      image_size: stored unchanged as ``self.image_size``.
      embedding_dim: width of object/predicate/positional embeddings.
      gconv_dim: output width of the graph-convolution stages.
      gconv_hidden_dim: hidden width of graph-conv layers and MLP heads.
      gconv_pooling: pooling mode forwarded to the graph-conv layers.
      gconv_num_layers: 0 -> nn.Linear; >= 1 -> GraphTripleConv; > 1 also
        builds a GraphTripleConvNet with the remaining layers.
      refinement_dims: widths appended after the refinement network's input
        width (any sequence; converted to a tuple).
      normalization, activation: forwarded to RefinementNetwork.
      mask_size: a positive int enables ``self.mask_net``.
      mlp_normalization: forwarded to every MLP and graph-conv layer.
      layout_noise_dim: extra input width for the refinement network.
      sg_context_dim: a positive int builds ``sg_context_net`` /
        ``sg_context_net_d`` and widens the refinement input; ``None`` and 0
        both disable the branch.
      sg_context_dim_d: output width of ``sg_context_net_d``.
      gcnn_pooling: stored as ``self.gcnn_pooling`` (unused here).
      triplet_box_net: if true, builds an MLP 3*gconv_dim -> 8.
      triplet_mask_size: a positive int enables ``triplet_mask_net``.
      triplet_embedding_size: a positive int enables ``triplet_embed_net``.
      use_bbox_info: if true, ``box_net`` outputs 5 values instead of 4.
      triplet_superbox_net: if true, builds an MLP 3*gconv_dim -> 4.
      use_masked_sg: stored as ``self.use_masked_sg`` (unused here).
      **kwargs: unexpected arguments; a warning is printed.

    After construction ``self.triplet_box_net`` and
    ``self.triplet_superbox_net`` hold the sub-network, or ``None`` when
    disabled -- not the boolean flags passed in.
    """
    super(Sg2ImModel, self).__init__()

    # We used to have some additional arguments:
    # vec_noise_dim, gconv_mode, box_anchor, decouple_obj_predictions
    if len(kwargs) > 0:
      print('WARNING: Model got unexpected kwargs ', kwargs)

    # sg_context_dim may be None (historical default) or 0 to disable the
    # context branch; one check keeps the context nets and refinement input
    # width consistent (previously the refinement check did `None > 0`).
    use_sg_context = sg_context_dim is not None and sg_context_dim > 0

    self.vocab = vocab
    self.image_size = image_size
    self.layout_noise_dim = layout_noise_dim
    self.sg_context_dim = sg_context_dim
    self.sg_context_dim_d = sg_context_dim_d
    self.gcnn_pooling = gcnn_pooling
    self.triplet_mask_size = triplet_mask_size
    self.triplet_embedding_size = triplet_embedding_size
    self.use_bbox_info = use_bbox_info
    self.use_masked_sg = use_masked_sg
    # HACK: hard-coded masked-predicate index, to cope with vocabs that have
    # differing numbers of predicates. The original note says
    # vocab['idx_to_pred_name'][46] == 'none'; verify against the vocab in use
    # (alternative: vocab['pred_name_to_idx']['none']).
    self.mask_pred = 46
    self.embedding_dim = embedding_dim

    num_objs = len(vocab['object_idx_to_name'])
    num_preds = len(vocab['pred_idx_to_name'])

    self.obj_embeddings = nn.Embedding(num_objs + 1, embedding_dim)
    # One predicate row beyond the vocabulary, marked "MASK" by the author.
    self.pred_embeddings = nn.Embedding(num_preds + 1, embedding_dim)

    # Frozen embedding tables. Assigning `module.requires_grad = False` only
    # sets a plain attribute on the nn.Module and leaves its weights
    # trainable; requires_grad_(False) actually freezes every parameter.
    self.fr_obj_embeddings = nn.Embedding(num_objs + 1, embedding_dim)
    self.fr_pred_embeddings = nn.Embedding(num_preds + 1, embedding_dim)
    self.fr_obj_embeddings.requires_grad_(False)
    self.fr_pred_embeddings.requires_grad_(False)

    # Positional embeddings for bounding boxes (for spatio-semantic
    # retrieval): a linear map from 4 box values to embedding_dim.
    bbox_dim = 4
    self.positional_embeddings = nn.Linear(bbox_dim, embedding_dim)

    if gconv_num_layers == 0:
      self.gconv = nn.Linear(embedding_dim, gconv_dim)
    elif gconv_num_layers > 0:
      gconv_kwargs = {
        'input_dim': embedding_dim,
        'output_dim': gconv_dim,
        'hidden_dim': gconv_hidden_dim,
        'pooling': gconv_pooling,
        'mlp_normalization': mlp_normalization,
      }
      self.gconv = GraphTripleConv(**gconv_kwargs)

    self.gconv_net = None
    if gconv_num_layers > 1:
      gconv_kwargs = {
        'input_dim': gconv_dim,
        'hidden_dim': gconv_hidden_dim,
        'pooling': gconv_pooling,
        'num_layers': gconv_num_layers - 1,
        'mlp_normalization': mlp_normalization,
      }
      self.gconv_net = GraphTripleConvNet(**gconv_kwargs)

    # use_bbox_info augments the box regressor with one extra output.
    box_net_dim = 4 + 1 if self.use_bbox_info else 4
    box_net_layers = [gconv_dim, gconv_hidden_dim, box_net_dim]
    self.box_net = build_mlp(box_net_layers, batch_norm=mlp_normalization)

    # Triplet-related heads; each consumes a concatenated triplet of three
    # gconv_dim vectors (3 * gconv_dim inputs) and stays None unless enabled.
    self.triplet_box_net = None
    self.triplet_embed_net = None
    self.triplet_mask_net = None
    self.triplet_superbox_net = None
    self.pred_ground_net = None

    triplet_box_net_dim = 8
    if triplet_box_net:
      triplet_box_net_layers = [3*gconv_dim, gconv_hidden_dim, triplet_box_net_dim]
      self.triplet_box_net = build_mlp(triplet_box_net_layers, batch_norm=mlp_normalization)

    if self.triplet_embedding_size > 0:
      triplet_embed_layers = [3*gconv_dim, gconv_hidden_dim, triplet_embedding_size]
      self.triplet_embed_net = build_mlp(triplet_embed_layers, batch_norm=mlp_normalization)

    if self.triplet_mask_size > 0:
      self.triplet_mask_net = self._build_triplet_mask_net(num_objs, 3*gconv_dim, self.triplet_mask_size)

    triplet_superbox_net_dim = 4
    if triplet_superbox_net:
      triplet_superbox_net_layers = [3*gconv_dim, gconv_hidden_dim, triplet_superbox_net_dim]
      self.triplet_superbox_net = build_mlp(triplet_superbox_net_layers, batch_norm=mlp_normalization)

    self.mask_net = None
    if mask_size is not None and mask_size > 0:
      self.mask_net = self._build_mask_net(num_objs, gconv_dim, mask_size)

    # Scene-graph context projections (single linear layers).
    self.sg_context_net = None
    self.sg_context_net_d = None
    if use_sg_context:
      self.sg_context_net = nn.Linear(gconv_dim, sg_context_dim)
      self.sg_context_net_d = nn.Linear(gconv_dim, sg_context_dim_d)

    rel_aux_layers = [2 * embedding_dim + 8, gconv_hidden_dim, num_preds]
    self.rel_aux_net = build_mlp(rel_aux_layers, batch_norm=mlp_normalization)

    # subject prediction network
    subj_aux_layers = [2 * embedding_dim + 8, gconv_hidden_dim, num_objs]
    self.subj_aux_net = build_mlp(subj_aux_layers, batch_norm=mlp_normalization)

    # object prediction network
    obj_aux_layers = [2 * embedding_dim + 8, gconv_hidden_dim, num_objs]
    self.obj_aux_net = build_mlp(obj_aux_layers, batch_norm=mlp_normalization)

    # object class prediction network (single linear layer)
    self.obj_class_aux_net = nn.Linear(embedding_dim, num_objs)

    # relationship embedding network
    self.rel_embed_aux_net = nn.Linear(embedding_dim, embedding_dim)
    # relationship class prediction network
    self.rel_class_aux_net = nn.Linear(embedding_dim, num_preds)

    # predicate mask prediction network
    pred_mask_layers = [2 * embedding_dim, gconv_hidden_dim, num_preds]
    self.pred_mask_net = build_mlp(pred_mask_layers, batch_norm=mlp_normalization)

    # Predicate grounding head: gconv_dim -> 4 outputs.
    pred_ground_net_dim = 4
    pred_ground_net_layers = [gconv_dim, gconv_hidden_dim, pred_ground_net_dim]
    self.pred_ground_net = build_mlp(pred_ground_net_layers, batch_norm=mlp_normalization)

    # Triplet context MLP: 4*gconv_dim inputs -> 3*gconv_dim (triplet width).
    triplet_context_layers = [4*gconv_dim, gconv_hidden_dim, 3*gconv_dim]
    self.triplet_context_net = build_mlp(triplet_context_layers, batch_norm=mlp_normalization)

    # The refinement input is widened by sg_context_dim only when the context
    # branch exists.
    refinement_in_dim = gconv_dim + layout_noise_dim
    if use_sg_context:
      refinement_in_dim += sg_context_dim
    refinement_kwargs = {
      'dims': (refinement_in_dim,) + tuple(refinement_dims),
      'normalization': normalization,
      'activation': activation,
    }
    self.refinement_net = RefinementNetwork(**refinement_kwargs)
Example #4
0
    def __init__(self,
                 vocab,
                 image_size=(64, 64),
                 embedding_dim=64,
                 gconv_dim=128,
                 gconv_hidden_dim=512,
                 gconv_pooling='avg',
                 gconv_num_layers=5,
                 refinement_dims=(1024, 512, 256, 128, 64),
                 normalization='batch',
                 activation='leakyrelu-0.2',
                 mask_size=None,
                 mlp_normalization='none',
                 layout_noise_dim=0,
                 **kwargs):
        """Construct the sub-modules of ReflectionModel.

        Args:
            vocab: dict with at least 'object_idx_to_name' and
                'pred_idx_to_name'; only their lengths are used here.
            image_size: stored unchanged as ``self.image_size``.
            embedding_dim: width of the object/predicate embeddings.
            gconv_dim: output width of the graph-convolution stages.
            gconv_hidden_dim: hidden width of graph-conv layers and MLP heads.
            gconv_pooling: pooling mode forwarded to every graph-conv layer.
            gconv_num_layers: 0 -> nn.Linear; >= 1 -> GraphTripleConv; > 1
                also builds a GraphTripleConvNet with the remaining layers.
            refinement_dims: widths appended after the refinement network's
                input width (any sequence; converted to a tuple).
            normalization, activation: forwarded to RefinementNetwork.
            mask_size: a positive int enables ``self.mask_net``.
            mlp_normalization: forwarded to every MLP and graph-conv layer.
            layout_noise_dim: extra input width for the refinement network.
            **kwargs: unexpected arguments; a warning is printed.
        """
        super(ReflectionModel, self).__init__()

        if len(kwargs) > 0:
            print("WARNING: Model got unexpected kwargs ", kwargs)

        self.vocab = vocab
        self.image_size = image_size
        self.layout_noise_dim = layout_noise_dim

        num_objs = len(vocab['object_idx_to_name'])
        num_preds = len(vocab['pred_idx_to_name'])
        self.obj_embeddings = nn.Embedding(num_objs + 1, embedding_dim)
        self.pred_embeddings = nn.Embedding(num_preds, embedding_dim)

        if gconv_num_layers == 0:
            self.gconv = nn.Linear(embedding_dim, gconv_dim)
        elif gconv_num_layers > 0:
            gconv_kwargs = {
                'input_dim': embedding_dim,
                'output_dim': gconv_dim,
                'hidden_dim': gconv_hidden_dim,
                # Must be the pooling mode, as for gconv_net below; this
                # previously passed the integer `gconv_num_layers - 1`.
                'pooling': gconv_pooling,
                'mlp_normalization': mlp_normalization,
            }
            self.gconv = GraphTripleConv(**gconv_kwargs)

        self.gconv_net = None
        if gconv_num_layers > 1:
            gconv_kwargs = {
                'input_dim': gconv_dim,
                'hidden_dim': gconv_hidden_dim,
                'pooling': gconv_pooling,
                'num_layers': gconv_num_layers - 1,
                'mlp_normalization': mlp_normalization,
            }
            self.gconv_net = GraphTripleConvNet(**gconv_kwargs)

        # Bounding-box regressor: multilayer perceptron gconv_dim -> 4.
        box_net_dim = 4
        box_net_layers = [gconv_dim, gconv_hidden_dim, box_net_dim]
        self.box_net = build_mlp(box_net_layers, batch_norm=mlp_normalization)

        # Segmentation-mask head, only built for a positive mask_size.
        self.mask_net = None
        if mask_size is not None and mask_size > 0:
            self.mask_net = self._build_mask_net(num_objs, gconv_dim,
                                                 mask_size)

        # Auxiliary relation head: (2 * embedding_dim + 8) -> num_preds.
        # Presumably predicts a predicate from an object pair plus 8 extra
        # features (e.g. two 4-value boxes) -- TODO confirm against forward().
        rel_aux_layers = [2 * embedding_dim + 8, gconv_hidden_dim, num_preds]
        self.rel_aux_net = build_mlp(rel_aux_layers,
                                     batch_norm=mlp_normalization)

        # Cascaded refinement network.
        refinement_kwargs = {
            'dims': (gconv_dim + layout_noise_dim, ) + tuple(refinement_dims),
            'normalization': normalization,
            'activation': activation,
        }
        self.refinement_net = RefinementNetwork(**refinement_kwargs)
Example #5
0
File: model.py  Project: yanivbenny/sg2im
    def __init__(self,
                 vocab,
                 image_size=(64, 64),
                 embedding_dim=64,
                 gconv_dim=128,
                 gconv_hidden_dim=512,
                 gconv_pooling='avg',
                 gconv_num_layers=5,
                 refinement_dims=(1024, 512, 256, 128, 64),
                 normalization='batch',
                 activation='leakyrelu-0.2',
                 mask_size=None,
                 mlp_normalization='none',
                 layout_noise_dim=0,
                 **kwargs):
        """Construct the sub-modules of Sg2ImModel.

        Builds, in this order: embeddings, graph convolution, box MLP,
        optional mask network, auxiliary relation MLP and refinement network.
        Keep this order: it fixes the parameter registration order (and,
        presumably, which random draws each layer's initialisation consumes).

        Args:
            vocab: dict-like with 'object_idx_to_name' and 'pred_idx_to_name';
                only their lengths are used here.
            image_size: stored as ``self.image_size`` (commented as (H, W)).
            embedding_dim: width of the object/predicate embeddings.
            gconv_dim: output width of the graph-convolution stages.
            gconv_hidden_dim: hidden width of graph-conv layers and MLP heads.
            gconv_pooling: pooling mode forwarded to the graph-conv layers.
            gconv_num_layers: 0 -> nn.Linear; >= 1 -> GraphTripleConv; > 1
                also builds a GraphTripleConvNet with the remaining layers.
            refinement_dims: widths appended after the refinement network's
                input width; must be a tuple (concatenated with a tuple).
            normalization, activation: forwarded to RefinementNetwork.
            mask_size: a positive int enables ``self.mask_net``.
            mlp_normalization: forwarded to every MLP and graph-conv layer.
            layout_noise_dim: extra input width for the refinement network.
            **kwargs: unexpected arguments; a warning is printed.
        """
        super(Sg2ImModel, self).__init__()

        # We used to have some additional arguments:
        # vec_noise_dim, gconv_mode, box_anchor, decouple_obj_predictions
        if len(kwargs) > 0:
            print('WARNING: Model got unexpected kwargs ', kwargs)

        self.vocab = vocab  # dict-like; indexed by name-table keys below
        self.image_size = image_size  # (H, W)
        self.layout_noise_dim = layout_noise_dim  # scalar

        num_objs = len(vocab['object_idx_to_name'])
        num_preds = len(vocab['pred_idx_to_name'])
        self.obj_embeddings = nn.Embedding(num_objs + 1, embedding_dim)
        self.pred_embeddings = nn.Embedding(num_preds, embedding_dim)

        ## Graph convolution
        # Two stages: self.gconv (required) maps embedding_dim -> gconv_dim;
        # self.gconv_net (optional) stacks the remaining layers at a constant
        # width of gconv_dim. The split separates the layer that accepts raw
        # embeddings from the fixed-width layers that follow.
        if gconv_num_layers == 0:
            self.gconv = nn.Linear(embedding_dim, gconv_dim)
        elif gconv_num_layers > 0:
            gconv_kwargs = {
                'input_dim': embedding_dim,
                'output_dim': gconv_dim,
                'hidden_dim': gconv_hidden_dim,
                'pooling': gconv_pooling,
                'mlp_normalization': mlp_normalization,
            }
            self.gconv = GraphTripleConv(**gconv_kwargs)

        self.gconv_net = None
        if gconv_num_layers > 1:
            gconv_kwargs = {
                'input_dim': gconv_dim,
                'hidden_dim': gconv_hidden_dim,
                'pooling': gconv_pooling,
                'num_layers': gconv_num_layers - 1,
                'mlp_normalization': mlp_normalization,
            }
            self.gconv_net = GraphTripleConvNet(**gconv_kwargs)

        ## BBOX model: MLP gconv_dim -> gconv_hidden_dim -> 4
        box_net_dim = 4
        box_net_layers = [gconv_dim, gconv_hidden_dim, box_net_dim]
        self.box_net = build_mlp(box_net_layers, batch_norm=mlp_normalization)

        # MASK model (only for a positive mask_size)
        self.mask_net = None
        if mask_size is not None and mask_size > 0:
            self.mask_net = self._build_mask_net(num_objs, gconv_dim,
                                                 mask_size)

        # AUX model: (2 * embedding_dim + 8) -> num_preds. Presumably predicts
        # the relation between two nodes from their embeddings plus 8 extra
        # features (e.g. two 4-value boxes) -- TODO confirm against forward().
        rel_aux_layers = [2 * embedding_dim + 8, gconv_hidden_dim, num_preds]
        self.rel_aux_net = build_mlp(rel_aux_layers,
                                     batch_norm=mlp_normalization)

        # Refinement model; its input width is gconv_dim + layout_noise_dim.
        refinement_kwargs = {
            'dims': (gconv_dim + layout_noise_dim, ) + refinement_dims,
            'normalization': normalization,
            'activation': activation,
        }
        self.refinement_net = RefinementNetwork(**refinement_kwargs)
Example #6
0
    def __init__(self,
                 vocab,
                 image_size=(64, 64),
                 embedding_dim=64,
                 gconv_dim=128,
                 gconv_hidden_dim=512,
                 gconv_pooling='avg',
                 gconv_num_layers=5,
                 refinement_dims=(1024, 512, 256, 128, 64),
                 normalization='batch',
                 activation='leakyrelu-0.2',
                 mask_size=None,
                 mlp_normalization='none',
                 layout_noise_dim=0,
                 model_type=None,
                 **kwargs):
        """Construct Sg2ImModel with a selectable graph-convolution variant.

        Args:
            vocab: dict with at least 'object_idx_to_name' and
                'pred_idx_to_name'; only their lengths are used here.
            image_size: stored unchanged as ``self.image_size``.
            embedding_dim: width of the object/predicate embeddings.
            gconv_dim: output width of the graph-convolution stages.
            gconv_hidden_dim: hidden width of graph-conv layers and MLP heads.
            gconv_pooling: pooling mode forwarded to the graph-conv layers.
            gconv_num_layers: 0 -> nn.Linear; >= 1 -> the layer selected by
                ``model_type``; > 1 also builds a GraphTripleConvNet.
            refinement_dims: widths appended after the refinement network's
                input width (any sequence; converted to a tuple).
            normalization, activation: forwarded to RefinementNetwork.
            mask_size: a positive int enables ``self.mask_net``.
            mlp_normalization: forwarded to every MLP and graph-conv layer.
            layout_noise_dim: extra input width for the refinement network.
            model_type: selects the first graph-conv layer; one of
                'baseline', 'random-walk-baseline', 'rnn-baseline',
                'graphsage-maxpool', 'graphsage-lstm', 'graphsage-mean',
                'gat-baseline'. Required whenever gconv_num_layers > 0.
            **kwargs: unexpected arguments; a warning is printed.

        Raises:
            ValueError: if gconv_num_layers > 0 and model_type is not one of
                the supported names (including the default ``None``).
        """
        super(Sg2ImModel, self).__init__()

        # We used to have some additional arguments:
        # vec_noise_dim, gconv_mode, box_anchor, decouple_obj_predictions
        if len(kwargs) > 0:
            print('WARNING: Model got unexpected kwargs ', kwargs)

        self.vocab = vocab
        self.image_size = image_size
        self.layout_noise_dim = layout_noise_dim

        num_objs = len(vocab['object_idx_to_name'])
        num_preds = len(vocab['pred_idx_to_name'])
        self.obj_embeddings = nn.Embedding(num_objs + 1, embedding_dim)
        self.pred_embeddings = nn.Embedding(num_preds, embedding_dim)

        if gconv_num_layers == 0:
            self.gconv = nn.Linear(embedding_dim, gconv_dim)
        elif gconv_num_layers > 0:
            gconv_kwargs = {
                'input_dim': embedding_dim,
                'output_dim': gconv_dim,
                'hidden_dim': gconv_hidden_dim,
                'pooling': gconv_pooling,
                'mlp_normalization': mlp_normalization,
                'model_type': model_type,
            }
            # An if/elif chain (rather than a dict) so only the selected
            # class name is looked up.
            if model_type == 'baseline':
                model_constructor = GraphTripleConv
            elif model_type == 'random-walk-baseline':
                model_constructor = GraphTripleRandomWalkConv
            elif model_type == 'rnn-baseline':
                model_constructor = GraphTripleRnnConv
            elif model_type == 'graphsage-maxpool':
                model_constructor = GraphSageMaxPoolConv
            elif model_type == 'graphsage-lstm':
                model_constructor = GraphSageLSTMConv
            elif model_type == 'graphsage-mean':
                model_constructor = GraphSageMeanConv
            elif model_type == 'gat-baseline':
                model_constructor = GraphAttnConv
            else:
                # Previously fell through with model_constructor = None and
                # failed with "'NoneType' object is not callable".
                supported = ('baseline', 'random-walk-baseline',
                             'rnn-baseline', 'graphsage-maxpool',
                             'graphsage-lstm', 'graphsage-mean',
                             'gat-baseline')
                raise ValueError(
                    'Unsupported model_type %r; expected one of: %s'
                    % (model_type, ', '.join(supported)))
            print("gconv_kwargs", gconv_kwargs)
            print("model_type", model_type)
            self.gconv = model_constructor(**gconv_kwargs)

        self.gconv_net = None
        if gconv_num_layers > 1:
            gconv_kwargs = {
                'input_dim': gconv_dim,
                'hidden_dim': gconv_hidden_dim,
                'pooling': gconv_pooling,
                'num_layers': gconv_num_layers - 1,
                'mlp_normalization': mlp_normalization,
                'model_type': model_type,
            }
            self.gconv_net = GraphTripleConvNet(**gconv_kwargs)

        # Box regressor: MLP gconv_dim -> gconv_hidden_dim -> 4.
        box_net_dim = 4
        box_net_layers = [gconv_dim, gconv_hidden_dim, box_net_dim]
        self.box_net = build_mlp(box_net_layers, batch_norm=mlp_normalization)

        # Mask head, only built for a positive mask_size.
        self.mask_net = None
        if mask_size is not None and mask_size > 0:
            self.mask_net = self._build_mask_net(num_objs, gconv_dim,
                                                 mask_size)

        rel_aux_layers = [2 * embedding_dim + 8, gconv_hidden_dim, num_preds]
        self.rel_aux_net = build_mlp(rel_aux_layers,
                                     batch_norm=mlp_normalization)

        refinement_kwargs = {
            'dims': (gconv_dim + layout_noise_dim, ) + tuple(refinement_dims),
            'normalization': normalization,
            'activation': activation,
        }
        self.refinement_net = RefinementNetwork(**refinement_kwargs)