def __init__(self, feature_dim, output_dims, downsample_rate, args=None):
    """Build the scene-graph feature extractor.

    Args:
        feature_dim: number of channels of the backbone feature map.
        output_dims: per-level output dims; index 1 is used for object
            features and index 2 for relation features.
        downsample_rate: backbone stride; RoI coordinates are scaled by
            its inverse.
        args: optional config namespace; only ``rel_box_flag`` is read here.
    """
    super().__init__()
    self.pool_size = 7
    self.feature_dim = feature_dim
    self.output_dims = output_dims
    self.downsample_rate = downsample_rate
    self.args = args
    # Fix: `args` defaults to None, but the original read
    # `self.args.rel_box_flag` unconditionally, so calling with the default
    # raised AttributeError. Guard on None; callers that pass `args` are
    # unaffected.
    if self.args is not None and self.args.rel_box_flag:
        self.col_fuse = nn.Linear(128 * 4 * 4, output_dims[1])
    self.object_roi_pool = jacnn.PrRoIPool2D(
        self.pool_size, self.pool_size, 1.0 / downsample_rate)
    self.context_roi_pool = jacnn.PrRoIPool2D(
        self.pool_size, self.pool_size, 1.0 / downsample_rate)
    self.relation_roi_pool = jacnn.PrRoIPool2D(
        self.pool_size, self.pool_size, 1.0 / downsample_rate)
    if not DEBUG:
        self.context_feature_extract = nn.Conv2d(feature_dim, feature_dim, 1)
        self.relation_feature_extract = nn.Conv2d(
            feature_dim, feature_dim // 2 * 3, 1)
        self.object_feature_fuse = nn.Conv2d(
            feature_dim * 2, output_dims[1], 1)
        self.relation_feature_fuse = nn.Conv2d(
            feature_dim // 2 * 3 + output_dims[1] * 2, output_dims[2], 1)
        self.object_feature_fc = nn.Sequential(
            nn.ReLU(True),
            nn.Linear(output_dims[1] * self.pool_size ** 2, output_dims[1]))
        self.relation_feature_fc = nn.Sequential(
            nn.ReLU(True),
            nn.Linear(output_dims[2] * self.pool_size ** 2, output_dims[2]))
        self.reset_parameters()
    else:
        # DEBUG mode: replace the learned extractors with cheap stand-ins
        # (channel replication / identity) and larger 32x32 pools so the
        # rest of the pipeline keeps working without trained layers.
        def gen_replicate(n):
            def rep(x):
                return torch.cat([x for _ in range(n)], dim=1)
            return rep

        self.pool_size = 32
        self.object_roi_pool = jacnn.PrRoIPool2D(32, 32, 1.0 / downsample_rate)
        self.context_roi_pool = jacnn.PrRoIPool2D(32, 32, 1.0 / downsample_rate)
        self.relation_roi_pool = jacnn.PrRoIPool2D(32, 32, 1.0 / downsample_rate)
        self.context_feature_extract = gen_replicate(2)
        self.relation_feature_extract = gen_replicate(3)
        self.object_feature_fuse = jacnn.Identity()
        self.relation_feature_fuse = jacnn.Identity()
def __init__(self, vocab, configs):
    """Assemble the model: a ResNet-34 backbone with ``layer4`` removed,
    a scene-graph feature extractor, a differentiable reasoning module,
    and the scene-parsing / QA losses."""
    super().__init__()
    self.vocab = vocab

    import jactorch.models.vision.resnet as resnet
    # Replacing layer4 with Identity truncates the backbone at layer3.
    self.resnet = resnet.resnet34(
        pretrained=True, incl_gap=False, num_classes=None)
    self.resnet.layer4 = jacnn.Identity()

    import nscl.nn.scene_graph.scene_graph as sng
    # number of channels = 256; downsample rate = 16.
    self.scene_graph = sng.SceneGraph(256, configs.model.sg_dims, 16)

    import nscl.nn.reasoning_v1.quasi_symbolic as qs
    concept_embeddings = self._make_vse_concepts(
        configs.model.vse_large_scale, configs.model.vse_known_belong)
    self.reasoning = qs.DifferentiableReasoning(
        concept_embeddings,
        self.scene_graph.output_dims,
        configs.model.vse_hidden_dims)

    import nscl.nn.reasoning_v1.losses as vqa_losses
    self.scene_loss = vqa_losses.SceneParsingLoss(
        gdef.all_concepts,
        add_supervision=configs.train.scene_add_supervision)
    self.qa_loss = vqa_losses.QALoss(
        add_supervision=configs.train.qa_add_supervision)
def build(self):
    """Construct the submodules: an optional vision backbone, the answer
    MLP, a word embedding + bidirectional GRU text encoder, and — when
    ``use_lm`` is set — an auxiliary language-model decoding head."""
    num_answers = len(self.tools.answers)
    if self.use_vision:
        import jactorch.models.vision.resnet as resnet
        self.resnet = resnet.resnet34(
            pretrained=True, incl_gap=False, num_classes=None)
        self.resnet.layer4 = jacnn.Identity()
        # Visual features (256 ch) are joined with the bi-GRU state (2*128).
        self.mlp = jacnn.MLPLayer(256 + 128 * 2, num_answers, [512])
    else:
        self.mlp = jacnn.MLPLayer(128 * 2, num_answers, [256])

    padding_idx = self.tools.words['<NULL>']
    self.embedding = nn.Embedding(
        self.num_vocab, self.dim, padding_idx=padding_idx)
    self.gru = jacnn.GRULayer(
        self.dim, 128, 1,
        bidirectional=True, batch_first=True, dropout=0.1)
    self.loss_fn = F.nll_loss

    if self.use_lm:
        # Auxiliary LM head over the GRU states; per-token loss
        # (average='none') so the caller can mask/weight it.
        self.gru_dropout = nn.Dropout(0.1)
        self.decode = nn.Linear(128 * 2, self.num_vocab)
        self.decode.bias.data.zero_()
        self.decode_loss = jacnn.CrossEntropyLoss(average='none')
def __init__(self, vocab, configs, args=None):
    """CLEVRER reasoning model: ResNet-34 backbone (truncated at layer3),
    scene graph, differentiable reasoning module, and scene/QA losses.

    The relation-feature width (``output_dims[2]``) is widened or replaced
    according to the box/dynamic-feature flags before the reasoning module
    is constructed.
    """
    super().__init__()
    self.vocab = vocab
    self.args = args

    import jactorch.models.vision.resnet as resnet
    self.resnet = resnet.resnet34(
        pretrained=True, incl_gap=False, num_classes=None)
    self.resnet.layer4 = jacnn.Identity()

    import clevrer.models.scene_graph as sng
    # number of channels = 256; downsample rate = 16.
    self.scene_graph = sng.SceneGraph(
        256, configs.model.sg_dims, 16, args=configs)

    import clevrer.models.quasi_symbolic as qs
    # Alias the mutable dims list; edits below mutate it in place, so the
    # scene graph sees the adjusted widths too.
    dims = self.scene_graph.output_dims
    if configs.rel_box_flag:
        dims[2] = dims[2] * 2
    if configs.dynamic_ftr_flag and not self.args.box_only_for_collision_flag:
        dims[2] = dims[2] + dims[3] * 4
    elif configs.dynamic_ftr_flag and self.args.box_only_for_collision_flag:
        dims[2] = dims[3] * 4
    if self.args.box_iou_for_collision_flag:
        box_dim = 4
        dims[2] += int(dims[3] / box_dim)

    self.reasoning = qs.DifferentiableReasoning(
        self._make_vse_concepts(
            configs.model.vse_large_scale, configs.model.vse_known_belong),
        dims,
        configs.model.vse_hidden_dims,
        args=self.args)

    import clevrer.losses as vqa_losses
    self.scene_loss = vqa_losses.SceneParsingLoss(
        gdef.all_concepts_clevrer,
        add_supervision=configs.train.scene_add_supervision,
        args=self.args)
    self.qa_loss = vqa_losses.QALoss(
        add_supervision=configs.train.qa_add_supervision)
def __init__(self, configs, args=None):
    """CLEVRER v2 reasoning model.

    Like the v1 model, but relation-feature widths are derived from
    collision-segment sampling: the per-object trajectory feature
    (``output_dims[3]``) is interpreted as ``time_step`` boxes of 4 coords,
    split into ``smp_coll_frm_num`` segments of ``seg_frm_num`` frames each
    (remainder frames are dropped).
    """
    super().__init__()
    self.args = args
    configs.colli_ftr_type = args.colli_ftr_type

    import jactorch.models.vision.resnet as resnet
    self.resnet = resnet.resnet34(
        pretrained=True, incl_gap=False, num_classes=None)
    self.resnet.layer4 = jacnn.Identity()

    import clevrer.models.scene_graph as sng
    # number of channels = 256; downsample rate = 16.
    self.scene_graph = sng.SceneGraph(
        256, configs.model.sg_dims, 16, args=configs)

    import clevrer.models.quasi_symbolic_v2 as qs
    # Alias the mutable dims list; edits below mutate it in place.
    dims = self.scene_graph.output_dims
    ftr_dim = dims[3]
    box_dim = 4
    time_step = int(ftr_dim / box_dim)
    offset = time_step % self.args.smp_coll_frm_num
    seg_frm_num = int((time_step - offset) / self.args.smp_coll_frm_num)

    if configs.rel_box_flag:
        dims[2] = dims[2] * 2
    if configs.dynamic_ftr_flag and not self.args.box_only_for_collision_flag:
        dims[2] = dims[2] + seg_frm_num * 4 * box_dim
    elif configs.dynamic_ftr_flag and self.args.box_only_for_collision_flag:
        dims[2] = seg_frm_num * 4 * box_dim
    if self.args.box_iou_for_collision_flag:
        box_dim = 4
        dims[2] += seg_frm_num

    self.reasoning = qs.DifferentiableReasoning(
        self._make_vse_concepts(configs.model.vse_known_belong),
        dims,
        configs.model.vse_hidden_dims,
        args=self.args,
        seg_frm_num=seg_frm_num)

    import clevrer.losses_v2 as vqa_losses
    self.scene_loss = vqa_losses.SceneParsingLoss(
        gdef.all_concepts_clevrer,
        add_supervision=configs.train.scene_add_supervision,
        args=self.args)
    self.qa_loss = vqa_losses.QALoss(
        add_supervision=configs.train.qa_add_supervision, args=self.args)
def __init__(self, feature_dim, output_dims, downsample_rate,
             object_supervision=False,
             concatenative_pair_representation=True):
    """Scene-graph module with two object-discovery modes.

    With ``object_supervision`` True, object features come from RoI pooling
    over supplied boxes (or, in DEBUG mode, cheap replicate/identity
    stand-ins). Otherwise a fixed set of ``num_objects_upperbound`` learned
    query vectors is matched against the coordinate-augmented feature map —
    presumably via a softmax with ``temperature``; confirm in the forward
    pass.
    """
    super().__init__()
    self.pool_size = 7
    self.feature_dim = feature_dim
    self.output_dims = output_dims
    self.downsample_rate = downsample_rate
    self.object_supervision = object_supervision
    self.concatenative_pair_representation = concatenative_pair_representation

    if self.object_supervision:
        scale = 1.0 / downsample_rate
        self.object_roi_pool = jacnn.PrRoIPool2D(
            self.pool_size, self.pool_size, scale)
        self.context_roi_pool = jacnn.PrRoIPool2D(
            self.pool_size, self.pool_size, scale)
        self.relation_roi_pool = jacnn.PrRoIPool2D(
            self.pool_size, self.pool_size, scale)
        if not DEBUG:
            self.context_feature_extract = nn.Conv2d(
                feature_dim, feature_dim, 1)
            self.relation_feature_extract = nn.Conv2d(
                feature_dim, feature_dim // 2 * 3, 1)
            self.object_feature_fuse = nn.Conv2d(
                feature_dim * 2, output_dims[1], 1)
            self.relation_feature_fuse = nn.Conv2d(
                feature_dim // 2 * 3 + output_dims[1] * 2, output_dims[2], 1)
            self.object_feature_fc = nn.Sequential(
                nn.ReLU(True),
                nn.Linear(output_dims[1] * self.pool_size ** 2,
                          output_dims[1]))
            self.relation_feature_fc = nn.Sequential(
                nn.ReLU(True),
                nn.Linear(output_dims[2] * self.pool_size ** 2,
                          output_dims[2]))
            # Separate projections for the two members of an object pair.
            self.obj1_linear = nn.Linear(output_dims[1], output_dims[1])
            self.obj2_linear = nn.Linear(output_dims[1], output_dims[1])
            self.reset_parameters()
        else:
            # DEBUG stand-ins: channel replication / identity with 32x32
            # pools, so the downstream shapes still line up untrained.
            def gen_replicate(n):
                def rep(x):
                    return torch.cat([x] * n, dim=1)
                return rep

            self.pool_size = 32
            self.object_roi_pool = jacnn.PrRoIPool2D(32, 32, scale)
            self.context_roi_pool = jacnn.PrRoIPool2D(32, 32, scale)
            self.relation_roi_pool = jacnn.PrRoIPool2D(32, 32, scale)
            self.context_feature_extract = gen_replicate(2)
            self.relation_feature_extract = gen_replicate(3)
            self.object_feature_fuse = jacnn.Identity()
            self.relation_feature_fuse = jacnn.Identity()
    else:
        # Unsupervised mode: no boxes available, so discover up to a fixed
        # number of objects with learned query vectors.
        self.num_objects_upperbound = 11
        self.temperature = 6
        # +2 input channels: the feature map is augmented with a 2-channel
        # coordinate map before fusion — TODO confirm against forward().
        self.object_coord_fuse = nn.Sequential(
            nn.Conv2d(feature_dim + 2, feature_dim, kernel_size=1),
            nn.ReLU(True))
        self.query = nn.Parameter(
            torch.randn(self.num_objects_upperbound, feature_dim))
        self.object_features_layer = nn.Sequential(
            nn.Linear(feature_dim, output_dims[1]),
            nn.ReLU(True))
        self.obj1_linear = nn.Linear(output_dims[1], output_dims[1])
        self.obj2_linear = nn.Linear(output_dims[1], output_dims[1])
        self.reset_parameters()