def forward(self, feed_dict):
        feed_dict = GView(feed_dict)
        monitors, outputs = {}, {}

        f_scene = self.resnet(
            feed_dict.image)  # [batch_size=32,n_channels=256,h=16,w=24]
        f_sng = self.scene_graph(
            f_scene, feed_dict.image,
            feed_dict.objects_mask if self.true_mask else None)

        programs = feed_dict.program_qsseq
        programs, buffers, answers = self.reasoning(f_sng,
                                                    programs,
                                                    fd=feed_dict)
        outputs['buffers'] = buffers
        outputs['answer'] = answers

        update_from_loss_module(monitors, outputs,
                                self.scene_graph.get_monitor())
        update_from_loss_module(monitors, outputs,
                                self.qa_loss(feed_dict, answers))
        canonize_monitors(monitors)

        if self.training:
            loss = monitors['loss/qa'] + monitors['loss/monet'] * self.loss_ratio
            return loss, monitors, outputs
        else:
            outputs['monitors'] = monitors
            outputs['buffers'] = buffers
            return outputs
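All of these forward methods follow the same convention: wrap the incoming feed_dict in a GView for attribute-style access, collect per-batch monitors, and return (loss, monitors, outputs) while training but only the output dict during evaluation. Below is a minimal, self-contained sketch of that convention; the GView stand-in and ToyQAModel are illustrative assumptions, not the library's actual classes.

import torch
import torch.nn as nn


class GView(dict):
    """Illustrative stand-in: a dict that also supports attribute access."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setattr__(self, key, value):
        self[key] = value


class ToyQAModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, feed_dict):
        feed_dict = GView(feed_dict)
        logits = self.fc(feed_dict.inputs)      # attribute-style access to the batch
        outputs = {'logits': logits}
        if self.training:
            loss = self.loss_fn(logits, feed_dict.labels)
            monitors = {'loss/qa': float(loss)}
            return loss, monitors, outputs      # training: (loss, monitors, outputs)
        outputs['monitors'] = {}
        return outputs                          # evaluation: outputs only


model = ToyQAModel()
batch = {'inputs': torch.randn(8, 4), 'labels': torch.randint(0, 2, (8,))}
loss, monitors, outputs = model(batch)          # modules default to training mode
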
Example 2
    def forward(self, feed_dict):
        feed_dict = GView(feed_dict)

        states = feed_dict.states.float()
        f = self.get_binary_relations(states)
        logits = self.pred(f).squeeze(dim=-1).view(states.size(0), -1)
        policy = F.softmax(logits, dim=-1).clamp(min=1e-20)

        if not self.training:
            return dict(policy=policy, logits=logits)

        pred_states = feed_dict.pred_states.float()
        f = self.get_binary_relations(pred_states, depth=args.pred_depth)
        f = self.pred_valid(f).squeeze(dim=-1).view(pred_states.size(0), -1)
        # Clamp to a small positive value so the loss cannot become NaN.
        valid = f[range(pred_states.size(0)),
                  feed_dict.pred_actions].clamp(min=1e-20)

        loss, monitors = self.loss(policy, feed_dict.actions,
                                   feed_dict.discount_rewards,
                                   feed_dict.entropy_beta)
        pred_loss = self.pred_loss(valid, feed_dict.valid)
        monitors['pred/accuracy'] = feed_dict.valid.eq(
            (valid > 0.5).float()).float().mean()
        loss = loss + args.pred_weight * pred_loss
        return loss, monitors, dict()
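The clamp(min=1e-20) above, applied to both the softmax policy and the selected valid probability, keeps the values strictly positive so that the loss computed from them cannot become NaN (e.g. through log(0)). A toy demonstration of the trick (the numbers are made up):

import torch

probs = torch.tensor([0.0, 0.3, 0.7])  # a zero probability would make log() return -inf
safe = probs.clamp(min=1e-20)           # same clamp as in the example above
print(safe.log())                        # finite everywhere: log(1e-20) is about -46.05
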
Example 3
    def forward(self, feed_dict, y_hat, additional_info=None):
        feed_dict = GView(feed_dict)
        # Pdb().set_trace()
        relations = feed_dict.relations.float()

        batch_size, nr = relations.size()[:2]

        #states = feed_dict.query.float()
        # @TODO : should we give x as input as well?
        if self.task_is_futoshiki:
            states = torch.stack([
                y_hat - feed_dict.target.float(),
                feed_dict.query[:, :, 1].float(),
                feed_dict.query[:, :, 2].float(),
            ], 2)
        elif self.task_is_sudoku:
            states = y_hat.transpose(1, 2) - torch.nn.functional.one_hot(
                feed_dict.target.long(), 10).float()
        else:
            states = (y_hat - feed_dict.target.float()).unsqueeze(2)
        #

        inp = [None for _ in range(self.latent_breadth + 1)]
        inp[1] = states
        inp[2] = relations

        depth = None
        feature = self.features(inp, depth=depth)[self.feature_axis]

        latent_z = self.pred(feature)
        return dict(latent_z=latent_z)
def main():
    initialize_dataset(args.dataset)
    build_symbolic_dataset = get_symbolic_dataset_builder(args.dataset)
    dataset = build_symbolic_dataset(args)
    dataloader = dataset.make_dataloader(32, False, False, nr_workers=4)
    meters = GroupMeters()

    for idx, feed_dict in tqdm_gofor(dataloader):
        feed_dict = GView(feed_dict)

        for i, (p, s, gt) in enumerate(
                zip(feed_dict.program_seq, feed_dict.scene, feed_dict.answer)):
            _, pred = execute_program(p, s)

            if pred[0] == 'error':
                raise pred[1]

            if pred[1] != gt:
                print(p)
                print(s)

                from IPython import embed
                embed()
                from sys import exit
                exit()

            meters.update('accuracy', pred[1] == gt)
        get_current_tqdm().set_description(
            meters.format_simple('Exec:', 'val', compressed=True))

    logger.critical(
        meters.format_simple('Symbolic execution test:',
                             'avg',
                             compressed=False))
def main():
    initialize_dataset(args.dataset)
    build_symbolic_dataset = get_symbolic_dataset_builder(args.dataset)
    dataset = build_symbolic_dataset(args)

    if args.nr_vis is None:
        args.nr_vis = min(100, len(dataset))

    if args.random:
        # NB: choice(..., size=..., replace=...) assumes `random` here is numpy.random.
        indices = random.choice(len(dataset), size=args.nr_vis, replace=False)
    else:
        indices = list(range(args.nr_vis))

    vis = HTMLTableVisualizer(args.data_vis_dir,
                              'Dataset: ' + args.dataset.upper())
    vis.begin_html()
    with vis.table('Metainfo', [
            HTMLTableColumnDesc('k', 'Key', 'text', {}, None),
            HTMLTableColumnDesc('v', 'Value', 'code', {}, None)
    ]):
        for k, v in args.__dict__.items():
            vis.row(k=k, v=v)

    with vis.table('Visualize', [
            HTMLTableColumnDesc('id', 'QuestionID', 'text', {}, None),
            HTMLTableColumnDesc('image', 'Image', 'figure', {'width': '100%'},
                                None),
            HTMLTableColumnDesc(
                'qa', 'QA', 'text', css=None, td_css={'width': '30%'}),
            HTMLTableColumnDesc(
                'p', 'Program', 'code', css=None, td_css={'width': '30%'})
    ]):
        for i in tqdm(indices):
            feed_dict = GView(dataset[i])
            image_filename = osp.join(args.data_image_root,
                                      feed_dict.image_filename)
            image = Image.open(image_filename)

            if 'objects' in feed_dict:
                fig, ax = vis_bboxes(image,
                                     feed_dict.objects,
                                     'object',
                                     add_text=False)
            else:
                fig, ax = vis_bboxes(image, [], 'object', add_text=False)
            _ = ax.set_title('object bounding box annotations')

            QA_string = """
                <p><b>Q</b>: {}</p>
                <p><b>A</b>: {}</p>
            """.format(feed_dict.question_raw, feed_dict.answer)
            P_string = '\n'.join([repr(x) for x in feed_dict.program_seq])

            vis.row(id=i, image=fig, qa=QA_string, p=P_string)
            plt.close()
    vis.end_html()

    logger.info(
        'Happy Holiday! You can find your result at "http://monday.csail.mit.edu/xiuming'
        + osp.realpath(args.data_vis_dir) + '".')
Example 6
    def forward(self, feed_dict):
        feed_dict = GView(feed_dict)
        states = None
        if args.is_path_task:
            states = feed_dict.states.float()
            relations = feed_dict.relations.float()
        elif args.is_sort_task:
            relations = feed_dict.states.float()

        def get_features(states, relations, depth=None):
            inp = [None for i in range(args.nlm_breadth + 1)]
            inp[1] = states
            inp[2] = relations
            features = self.features(inp, depth=depth)
            return features

        if args.model == 'memnet':
            f = self.feature(relations, states)
        else:
            f = get_features(states, relations)[self.feature_axis]
        if self.feature_axis == 2:  #sorting task
            f = meshgrid_exclude_self(f)

        logits = self.pred(f).squeeze(dim=-1).view(relations.size(0), -1)
        # Clamp the policy to a small positive value so the loss cannot become NaN.
        policy = F.softmax(logits, dim=-1).clamp(min=1e-20)

        if self.training:
            loss, monitors = self.loss(policy, feed_dict.actions,
                                       feed_dict.discount_rewards,
                                       feed_dict.entropy_beta)
            return loss, monitors, dict()
        else:
            return dict(policy=policy, logits=logits)
Example 7
    def add_concept(self, feed_dict):
        feed_dict = GView(feed_dict)
        depth = feed_dict.depth
        depth = F.tanh(depth) * 0.5
        inp = torch.cat((feed_dict.image, depth.unsqueeze(1)), axis=1).cuda()
        f_scene = self.resnet(inp)
        f_sng = self.scene_graph(f_scene, feed_dict.objects.cuda(),
                                 feed_dict.objects_length.cuda())
        prototype_features = f_sng[0][1]

        attribute = feed_dict.attribute_name[0]
        concept = feed_dict.concept_name[0]
        # gdef.attribute_concepts[attribute].append(concept)
        # self.scene_loss.used_concepts["attribute"][attribute].append(concept)
        attribute_taxonomy = self.reasoning.embedding_attribute
        attribute_taxonomy.init_concept(
            concept,
            configs.model.vse_hidden_dims[1],
            known_belong=attribute if configs.model.vse_known_belong else None,
        )
        concept_initialization = attribute_taxonomy.get_attribute(attribute)(
            prototype_features)
        with torch.no_grad():
            embedding = attribute_taxonomy.get_concept(concept).embedding
            embedding.data = torch.reshape(concept_initialization,
                                           embedding.shape)
Example 8
    def forward(self, feed_dict, y_hat, additional_info=None):
        feed_dict = GView(feed_dict)
        target = feed_dict["target"].long()
        # y_hat has shape exp_batch_size x 10 x 81 x num_steps
        # x has shape exp_batch_size x 81 x num_steps
        if self.args.latent_sudoku_input_prob:
            x = torch.gather(
                y_hat.softmax(dim=1),
                dim=1,
                index=target.unsqueeze(-1).expand(
                    len(y_hat), 81,
                    self.args.sudoku_num_steps).unsqueeze(1)).squeeze(1)
        else:
            x = y_hat.argmax(dim=1).long()
            x = (x == target.unsqueeze(-1).expand(
                len(y_hat), 81, self.args.sudoku_num_steps)).float()

        # shuffle dimensions to make it exp_batch_size x num_steps x 81
        # reshape it to exp_batch_size x num_steps x 9 x 9
        x = x.transpose(1, 2).view(-1, self.args.sudoku_num_steps, 9, 9)

        for i in range(len(self.layers) + 1):
            x = torch.relu(self._modules["conv_{}".format(i)](x))
        x = x.view(-1, 81)
        return {'latent_z': self.linear(x)}
Example 9
    def forward(self, feed_dict):
        feed_dict = GView(feed_dict)
        feature_f = self._extract_sent_feature(feed_dict.sent_f,
                                               feed_dict.sent_f_length,
                                               self.gru_f)
        feature_b = self._extract_sent_feature(feed_dict.sent_b,
                                               feed_dict.sent_b_length,
                                               self.gru_b)
        feature_img = feed_dict.image

        feature = torch.cat([feature_f, feature_b, feature_img], dim=1)
        predict = self.predict(feature)

        if self.training:
            label = self.embedding(feed_dict.label)
            loss = cosine_loss(predict, label).mean()
            return loss, {}, {}
        else:
            output_dict = dict(pred=predict)
            if 'label' in feed_dict:
                dis = cosine_distance(predict, self.embedding.weight)
                _, topk = dis.topk(1000, dim=1, sorted=True)
                for k in [1, 10, 100, 1000]:
                    output_dict['top{}'.format(k)] = torch.eq(
                        topk,
                        feed_dict.label.unsqueeze(-1))[:, :k].float().sum(
                            dim=1).mean()
            return output_dict
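The evaluation branch above ranks every label embedding for each prediction and reports top-k retrieval accuracy by checking whether the ground-truth label index appears among the first k ranked candidates. A small stand-alone illustration of that computation with toy sizes and random scores (not the actual cosine_distance or embedding code):

import torch

num_queries, num_labels = 3, 5
scores = torch.randn(num_queries, num_labels)           # pretend ranking scores per label
_, topk = scores.topk(num_labels, dim=1, sorted=True)   # ranked label indices per query
label = torch.tensor([2, 0, 4])                         # ground-truth label per query

for k in [1, 3, 5]:
    hits = torch.eq(topk, label.unsqueeze(-1))[:, :k].float().sum(dim=1)
    print('top{}:'.format(k), hits.mean().item())       # fraction recalled within top k
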
Example 10
    def forward(self, feed_dict, y_hat, additional_info=None):
        feed_dict = GView(feed_dict)
        #        err = (81-(feed_dict["target"].float()==y_hat.argmax(dim=1).float())).sum(dim=1).float().unsqueeze(-1)
        #        err = err*self.avg_weight
        #        return {'latent_z':err}
        #convert it to graph
        g = self.collate_fn(feed_dict, y_hat)

        if self.args.latent_sudoku_input_type == 'pae':
            input_emb = g.ndata.pop('err')
        else:
            ##ALTERNATIVE - USING TYPE_EMB AS WEIGHTS OF A LINEAR LAYER
            #type_context_scores = F.linear(context_representation_bag.squeeze(-1),self.type_embeddings.weight[:num_types]).view(-1,bag_size, num_types)
            avg_emb = torch.mm(g.ndata.pop('prob'), self.digit_embed.weight)
            #print('Latent: ',self.digit_embed.weight.data[2,:4], self.row_embed.weight.data[2,:4])
            #print('Atn over steps: ',self.atn_across_steps)
            if self.args.latent_sudoku_input_type == 'dif':
                input_emb = avg_emb - self.digit_embed(g.ndata.pop('a'))
            else:
                input_emb = torch.cat(
                    [avg_emb, self.digit_embed(g.ndata.pop('a'))], -1)
        #
        #input_digits = self.digit_embed(g.ndata['q'])
        rows = self.row_embed(g.ndata.pop('row'))
        cols = self.col_embed(g.ndata.pop('col'))
        x = self.input_layer(torch.cat([input_emb, rows, cols], -1))

        g.ndata['x'] = x
        g.ndata['h'] = x
        g.ndata['rnn_h'] = torch.zeros_like(x, dtype=torch.float)
        g.ndata['rnn_c'] = torch.zeros_like(x, dtype=torch.float)

        outputs = self.rrn(g, True)[-1]
        outputs = outputs.view(-1, 81, outputs.size(-1))
        max_pool_output, _ = outputs.max(dim=1)
        #Pdb().set_trace()
        attn_wts = F.softmax(
            torch.bmm(outputs, max_pool_output.unsqueeze(-1)) /
            float(outputs.size(-1)),
            dim=1)
        outputs = (outputs * attn_wts.expand_as(outputs)).sum(dim=1)
        logits = self.output_layer(outputs)

        #logits = self.output_layer(outputs)
        #logits : of shape : args.latent_sudoku_num_steps x batchsize*81 x nullary_dim

        #logits = (self.atn_across_steps.unsqueeze(-1).unsqueeze(-1).expand_as(logits)*logits).sum(dim=0)
        #shape:  batchsize*81 x nullary_dim

        #if self.args.selector_model:
        #    logits = logits.view(-1,81,1)
        #else:
        #    logits = logits.view(-1,81,self.args.nlm_nullary_dim)
        #shape: batchsize x 81 x nullary_dim

        #logits = (self.atn_across_nodes.unsqueeze(0).unsqueeze(-1).expand_as(logits)*logits).sum(dim=1)
        #shape: batchsize x nullary_dim

        return {'latent_z': logits}
Example 11
    def forward(self, feed_dict):
        feed_dict = GView(feed_dict)

        # properties
        if args.task_is_adjacent:
            states = feed_dict.states.float()
        else:
            states = None

        # relations
        relations = feed_dict.relations.float()
        batch_size, nr = relations.size()[:2]

        if args.model == 'nlm':
            if args.task_is_adjacent and args.task_is_mnist_input:
                states_shape = states.size()
                states = states.view((-1, ) + states_shape[2:])
                states = self.lenet(states)
                states = states.view(states_shape[:2] + (-1, ))
                states = F.sigmoid(states)

            inp = [None for _ in range(args.nlm_breadth + 1)]
            inp[1] = states
            inp[2] = relations

            depth = None
            if args.nlm_recursion:
                depth = 1
                while 2**depth + 1 < nr:
                    depth += 1
                depth = depth * 2 + 1
            feature = self.features(inp, depth=depth)[self.feature_axis]
        elif args.model == 'memnet':
            feature = self.feature(relations, states)
            if args.task_is_adjacent and args.task_is_mnist_input:
                raise NotImplementedError()

        pred = self.pred(feature)
        if not args.task_is_adjacent:
            pred = pred.squeeze(-1)
        if args.task_is_connectivity:
            pred = meshgrid_exclude_self(pred)  # exclude self-cycle

        if self.training:
            monitors = dict()
            target = feed_dict.target.float()

            if args.task_is_adjacent:
                target = target[:, :, :args.adjacent_pred_colors]

            monitors.update(binary_accuracy(target, pred, return_float=False))

            loss = self.loss(pred, target)
            # ohem loss is unused.
            if args.ohem_size > 0:
                loss = loss.view(-1).topk(args.ohem_size)[0].mean()
            return loss, monitors, dict(pred=pred)
        else:
            return dict(pred=pred)
Example 12
    def forward(self, feed_dict, return_loss_matrix=False):
        feed_dict = GView(feed_dict)

        states = None

        # relations
        relations = feed_dict.relations.float()
        states = feed_dict.query.float()
        batch_size, nr = relations.size()[:2]
        inp = [None for _ in range(self.args.nlm_breadth + 1)]

        inp[1] = states
        inp[2] = relations
        depth = None
        if self.args.nlm_recursion:
            depth = 1
            while 2**depth + 1 < nr:
                depth += 1
            depth = depth * 2 + 1

        pred = self.distributed_pred(inp, depth=depth)

        if self.training:
            monitors = dict()
            target = feed_dict.target
            target = target.float()
            count = None
            if self.args.cc_loss or self.args.min_loss or self.args.naive_pll_loss or 'weights' in feed_dict or return_loss_matrix:
                target = feed_dict.target_set
                target = target.float()
                count = feed_dict.count.int()

            this_meters, _, reward = instance_accuracy(
                feed_dict.target.float(),
                pred,
                return_float=False,
                feed_dict=feed_dict,
                task=self.args.task,
                args=self.args)

            #logger.info("Reward: ")
            # logger.info(reward)
            monitors.update(this_meters)
            loss_matrix = self.loss(pred, target, count)

            if self.args.min_loss or self.args.cc_loss or self.args.naive_pll_loss:
                loss = loss_matrix.mean()
            elif 'weights' in feed_dict:
                loss = (feed_dict.weights*loss_matrix).sum() / \
                    feed_dict.weights.sum()
            else:
                loss = loss_matrix

            return loss, monitors, dict(pred=pred,
                                        reward=reward,
                                        loss_matrix=loss_matrix)
        else:
            return dict(pred=pred)
Example 13
    def forward(self, feed_dict):
        feed_dict = GView(feed_dict)
        monitors, outputs = {}, {}

        vid_shape = feed_dict.video.size()

        B = vid_shape[0]
        N_frames = vid_shape[1]
        video_frames = feed_dict.video.reshape(vid_shape[0] * vid_shape[1],
                                               vid_shape[2], vid_shape[3],
                                               vid_shape[4])
        f_scene = self.resnet(video_frames)
        f_scene = f_scene.reshape(B, N_frames, -1)

        f_scene, _ = self.lstm_video(f_scene)
        f_scene = f_scene[:, -1, :]
        f_scene = f_scene.squeeze()
        f_scene = f_scene.unsqueeze(-1).unsqueeze(-1)

        f_sng = self.scene_graph(f_scene, feed_dict.objects,
                                 feed_dict.objects_length)

        programs = feed_dict.program_qsseq
        programs, buffers, answers = self.reasoning(f_sng,
                                                    programs,
                                                    fd=feed_dict)
        outputs["buffers"] = buffers
        outputs["answer"] = answers

        update_from_loss_module(
            monitors,
            outputs,
            self.scene_loss(
                feed_dict,
                f_sng,
                self.reasoning.embedding_attribute,
                self.reasoning.embedding_relation,
            ),
        )
        update_from_loss_module(monitors, outputs,
                                self.qa_loss(feed_dict, answers))

        canonize_monitors(monitors)

        if self.training:
            loss = monitors["loss/qa"]
            if configs.train.scene_add_supervision:
                loss = loss + monitors["loss/scene"]
            return loss, monitors, outputs
        else:
            outputs["monitors"] = monitors
            outputs["buffers"] = buffers
            return outputs
Example 14
    def __getitem__(self, index):
        info = self.images[index]

        feed_dict = GView()
        feed_dict.image_filename = info['file_name']
        if self.image_root is not None:
            feed_dict.image = Image.open(
                osp.join(self.image_root,
                         feed_dict.image_filename)).convert('RGB')
            feed_dict.image = self.image_transform(feed_dict.image)

        return feed_dict.raw()
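The dataset __getitem__ examples all follow the same recipe: create an empty GView, fill it field by field (filenames, images, annotations), and return the plain dict via .raw(). A self-contained sketch of that recipe; the GView stand-in (including its raw() method) and ToyImageDataset are assumptions for illustration only.

import os.path as osp

from PIL import Image
from torch.utils.data import Dataset


class GView(dict):
    """Illustrative stand-in: attribute-style dict with a raw() accessor."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setattr__(self, key, value):
        self[key] = value

    def raw(self):
        return dict(self)


class ToyImageDataset(Dataset):
    def __init__(self, records, image_root=None, image_transform=None):
        self.records = records                        # e.g. [{'file_name': '0001.png'}, ...]
        self.image_root = image_root
        self.image_transform = image_transform or (lambda im: im)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        info = self.records[index]
        feed_dict = GView()
        feed_dict.image_filename = info['file_name']
        if self.image_root is not None:
            image = Image.open(
                osp.join(self.image_root, feed_dict.image_filename)).convert('RGB')
            feed_dict.image = self.image_transform(image)
        return feed_dict.raw()
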
Example 15
    def forward(self, feed_dict):
        feed_dict = GView(feed_dict)
        monitors, outputs = {}, {}

        depth = feed_dict.depth
        depth = F.tanh(depth) * 0.5
        inp = torch.cat((feed_dict.image, depth.unsqueeze(1)), axis=1)

        f_scene = self.resnet(inp)
        f_sng = self.scene_graph(f_scene, feed_dict.objects,
                                 feed_dict.objects_length)

        programs = feed_dict.program_qsseq
        programs, buffers, answers = self.reasoning(f_sng,
                                                    programs,
                                                    fd=feed_dict)
        outputs["buffers"] = buffers
        outputs["answer"] = answers

        update_from_loss_module(
            monitors,
            outputs,
            self.scene_loss(
                feed_dict,
                f_sng,
                self.reasoning.embedding_attribute,
                self.reasoning.embedding_relation,
            ),
        )
        update_from_loss_module(monitors, outputs,
                                self.qa_loss(feed_dict, answers))

        canonize_monitors(monitors)

        if self.training:
            loss = monitors["loss/qa"]
            if configs.train.full_scene_supervision:
                loss = loss + monitors["loss/scene"]
            return loss, monitors, outputs
        else:
            outputs["monitors"] = monitors
            outputs["buffers"] = buffers
            return outputs
Example 16
    def forward(self, feed_dict, y_hat, additional_info=None):
        feed_dict = GView(feed_dict)
        target = feed_dict["target"].long()
        # y_hat has shape exp_batch_size x 10 x 81 x num_steps
        # x has shape exp_batch_size x 81 x num_steps
        if self.args.latent_sudoku_input_prob:
            x = torch.gather(
                y_hat.softmax(dim=1),
                dim=1,
                index=target.unsqueeze(-1).expand(
                    len(y_hat), 81,
                    self.args.sudoku_num_steps).unsqueeze(1)).squeeze(1)
        else:
            x = y_hat.argmax(dim=1).long()
            x = (x == target.unsqueeze(-1).expand(
                len(y_hat), 81, self.args.sudoku_num_steps)).float()
        x = (x * self.atn_over_steps).sum(dim=2)

        x = torch.relu(self.flinear(x))
        #return {'latent_z':(err*self.weight).sum(dim=1,keepdim=True)}
        return {'latent_z': self.linear(x)}
Example 17
    def forward(self, feed_dict):
        feed_dict = GView(feed_dict)

        monitors, outputs = {}, {}

        f_scene = self.resnet(
            feed_dict.image)  # [batch_size=32,n_channels=256,h=16,w=24]
        f_sng = self.scene_graph(f_scene, feed_dict.objects,
                                 feed_dict.objects_length)

        programs = feed_dict.program_qsseq
        programs, buffers, answers = self.reasoning(f_sng,
                                                    programs,
                                                    fd=feed_dict)
        outputs['buffers'] = buffers
        outputs['answer'] = answers

        update_from_loss_module(
            monitors, outputs,
            self.scene_loss(feed_dict, f_sng,
                            self.reasoning.embedding_attribute,
                            self.reasoning.embedding_relation))
        update_from_loss_module(monitors, outputs,
                                self.qa_loss(feed_dict, answers))

        canonize_monitors(monitors)

        if self.training:
            loss = monitors['loss/qa']
            if configs.train.scene_add_supervision:
                loss = loss + monitors['loss/scene']
            return loss, monitors, outputs
        else:
            outputs['monitors'] = monitors
            outputs['buffers'] = buffers
            return outputs
Example 18
    def forward(self, feed_dict):
        feed_dict = GView(feed_dict)
        f = self.resnet(feed_dict.image)
        output_dict = {'features': f}
        return output_dict
Example 19
def validate_epoch(epoch,
                   model,
                   val_dataloader,
                   meters,
                   meter_prefix='validation'):
    end = time.time()

    visualized = 0
    vis = HTMLTableVisualizer(args.vis_dir, 'NSCL Execution Visualization')
    vis.begin_html()

    try:
        with tqdm_pbar(total=len(val_dataloader)) as pbar:
            for feed_dict in val_dataloader:
                if args.use_gpu:
                    if not args.gpu_parallel:
                        feed_dict = async_copy_to(feed_dict, 0)

                data_time = time.time() - end
                end = time.time()

                output_dict = model(feed_dict)
                monitors = {
                    meter_prefix + '/' + k: v
                    for k, v in as_float(output_dict['monitors']).items()
                }
                step_time = time.time() - end
                end = time.time()

                n = feed_dict['image'].size(0)
                meters.update(monitors, n=n)
                meters.update({'time/data': data_time, 'time/step': step_time})

                feed_dict = GView(as_detached(as_cpu(feed_dict)))
                output_dict = GView(as_detached(as_cpu(output_dict)))

                for i in range(n):
                    with vis.table(
                            'Visualize #{} Metainfo'.format(visualized), [
                                HTMLTableColumnDesc('id', 'QID', 'text',
                                                    {'width': '50px'}),
                                HTMLTableColumnDesc('image', 'Image', 'figure',
                                                    {'width': '400px'}),
                                HTMLTableColumnDesc('qa', 'QA', 'text',
                                                    {'width': '200px'}),
                                HTMLTableColumnDesc('p', 'Program', 'code',
                                                    {'width': '200px'})
                            ]):
                        image_filename = osp.join(args.data_image_root,
                                                  feed_dict.image_filename[i])
                        image = Image.open(image_filename)
                        fig, ax = vis_bboxes(image,
                                             feed_dict.objects_raw[i],
                                             'object',
                                             add_text=False)
                        _ = ax.set_title('object bounding box annotations')
                        QA_string = """
                            <p><b>Q</b>: {}</p>
                            <p><b>A</b>: {}</p>
                        """.format(feed_dict.question_raw[i],
                                   feed_dict.answer[i])
                        P_string = '\n'.join(
                            [repr(x) for x in feed_dict.program_seq[i]])

                        vis.row(id=i, image=fig, qa=QA_string, p=P_string)
                        plt.close()

                    with vis.table(
                            'Visualize #{} Metainfo'.format(visualized), [
                                HTMLTableColumnDesc('id', 'QID', 'text',
                                                    {'width': '50px'}),
                                HTMLTableColumnDesc('image', 'Image', 'figure',
                                                    {'width': '400px'}),
                                HTMLTableColumnDesc('mask', 'Mask', 'figure',
                                                    {'width': '700px'})
                            ]):
                        image_filename = osp.join(args.data_image_root,
                                                  feed_dict.image_filename[i])
                        image = Image.open(image_filename)
                        fig, ax = vis_bboxes(image,
                                             feed_dict.objects_raw[i],
                                             'object',
                                             add_text=False)
                        _ = ax.set_title('object bounding box annotations')
                        if not args.show_mask:
                            montage = fig
                        else:
                            num_slots = output_dict['monet/m'].shape[1]
                            monet_fig = [
                                [
                                    tensor2im(output_dict['monet/m'][i, k])
                                    for k in range(num_slots)
                                ],
                                [
                                    tensor2im(output_dict['monet/x'][i, k])
                                    for k in range(num_slots)
                                ],
                                [
                                    tensor2im(output_dict['monet/xm'][i, k])
                                    for k in range(num_slots)
                                ],
                                [tensor2im(output_dict['monet/x_input'][i])] +
                                [
                                    tensor2im(output_dict['monet/x_tilde'][i])
                                    for k in range(num_slots - 1)
                                ]
                            ]
                            montage = montage_fig(monet_fig)
                        vis.row(id=i, image=fig, mask=montage)
                        plt.close()

                    with vis.table('Visualize #{} Trace'.format(visualized), [
                            HTMLTableColumnDesc('id', 'Step', 'text',
                                                {'width': '50px'}),
                            HTMLTableColumnDesc('image', 'Image', 'figure',
                                                {'width': '600px'}),
                            HTMLTableColumnDesc('p', 'operation', 'text',
                                                {'width': '200px'}),
                            HTMLTableColumnDesc('r', 'result', 'code',
                                                {'width': '200px'})
                    ]):
                        # TODO(Jiayuan Mao @ 11/20): support output_dict.programs.
                        for j, (prog, buf) in enumerate(
                                zip(feed_dict.program_seq[i],
                                    output_dict.buffers[i])):
                            if j != len(feed_dict.program_seq[i]) - 1 and (
                                    buf > 0
                            ).long().sum().item() > 0 and buf.size(
                                    0) == feed_dict.objects_raw[i].shape[0]:
                                this_objects = feed_dict.objects_raw[i][
                                    torch.nonzero(buf > 0)[:, 0].numpy()]
                                fig, ax = vis_bboxes(image,
                                                     this_objects,
                                                     'object',
                                                     add_text=False)
                            else:
                                fig, ax = vis_bboxes(image, [],
                                                     'object',
                                                     add_text=False)
                            vis.row(id=j, image=fig, p=repr(prog), r=repr(buf))
                            plt.close()

                    visualized += 1
                    if visualized > args.nr_visualize:
                        raise StopIteration()

                pbar.set_description(
                    meters.format_simple(
                        'Epoch {} (validation)'.format(epoch), {
                            k: v
                            for k, v in meters.val.items()
                            if k.startswith('validation') and k.count('/') <= 1
                        },
                        compressed=True))
                pbar.update()

                end = time.time()
    except StopIteration:
        pass

    from jacinle.utils.meta import dict_deep_kv
    from jacinle.utils.printing import kvformat
    with vis.table('Info', [
            HTMLTableColumnDesc('name', 'Name', 'code', {}),
            HTMLTableColumnDesc('info', 'KV', 'code', {})
    ]):
        vis.row(name='args', info=kvformat(args.__dict__, max_key_len=32))
        vis.row(name='configs',
                info=kvformat(dict(dict_deep_kv(configs)), max_key_len=32))
    vis.end_html()

    logger.info(
        'Happy Holiday! You can find your result at "http://monday.csail.mit.edu/xiuming'
        + osp.realpath(args.vis_dir) + '".')
Example 20
    def __getitem__(self, index):
        metainfo = GView(self.get_metainfo(index))
        feed_dict = GView()

        # metainfo annotations
        if self.incl_scene:
            feed_dict.scene = metainfo.scene
            feed_dict.update(gdef.annotate_objects(metainfo.scene))
            if 'objects' in feed_dict:
                # NB(Jiayuan Mao): in some datasets, object information might be completely unavailable.
                feed_dict.objects_raw = feed_dict.objects.copy()
            feed_dict.update(gdef.annotate_scene(metainfo.scene))

        # image
        feed_dict.image_index = metainfo.image_index
        feed_dict.image_filename = metainfo.image_filename
        if self.image_root is not None and feed_dict.image_filename is not None:
            feed_dict.image = Image.open(osp.join(self.image_root, feed_dict.image_filename)).convert('RGB')
            feed_dict.image, feed_dict.objects = self.image_transform(feed_dict.image, feed_dict.objects)

        # program
        if 'program_raw' in metainfo:
            feed_dict.program_raw = metainfo.program_raw
            feed_dict.program_seq = metainfo.program_seq
            feed_dict.program_tree = metainfo.program_tree
            feed_dict.program_qsseq = metainfo.program_qsseq
            feed_dict.program_qstree = metainfo.program_qstree
        feed_dict.question_type = metainfo.question_type

        # question
        feed_dict.question_index = metainfo.question_index
        feed_dict.question_raw = metainfo.question
        feed_dict.question_raw_tokenized = metainfo.question_tokenized
        feed_dict.question_metainfo = gdef.annotate_question_metainfo(metainfo)
        feed_dict.question = metainfo.question_tokenized
        feed_dict.answer = gdef.canonize_answer(metainfo.answer, metainfo.question_type)
        feed_dict.update(gdef.annotate_question(metainfo))

        if self.question_transform is not None:
            self.question_transform(feed_dict)
        feed_dict.question = np.array(self.vocab.map_sequence(feed_dict.question), dtype='int64')

        return feed_dict.raw()
Example 21
    def __getitem__(self, index):
        metainfo = GView(self.get_metainfo(index))
        feed_dict = GView()

        # scene annotations
        if self.incl_scene:
            feed_dict.scene = metainfo.scene
            feed_dict.update(gdef.annotate_objects(metainfo.scene))
            feed_dict.objects_raw = feed_dict.objects.copy()
            feed_dict.update(gdef.annotate_scene(metainfo.scene))

        # image
        feed_dict.image_index = metainfo.image_index
        feed_dict.image_filename = metainfo.image_filename
        if self.image_root is not None:
            feed_dict.image = Image.open(osp.join(self.image_root, feed_dict.image_filename)).convert('RGB')
            feed_dict.image, feed_dict.objects = self.image_transform(feed_dict.image, feed_dict.objects)

        return feed_dict.raw()
Example 22
    def __getitem__(self, index):
        metainfo = GView(self.get_metainfo(index))
        feed_dict = GView()

        # scene annotations
        if self.incl_scene:
            feed_dict.scene = metainfo.scene
            feed_dict.update(gdef.annotate_objects(metainfo.scene))
            feed_dict.objects_raw = feed_dict.objects.copy()
            feed_dict.update(gdef.annotate_scene(metainfo.scene))

        # image
        feed_dict.image_index = metainfo.image_index
        feed_dict.image_filename = metainfo.image_filename
        if self.image_root is not None:
            feed_dict.image = Image.open(osp.join(self.image_root, feed_dict.image_filename)).convert('RGB')
            feed_dict.image, feed_dict.objects = self.image_transform(feed_dict.image, feed_dict.objects)

        # program
        feed_dict.program_raw = metainfo.program_raw
        feed_dict.program_seq = metainfo.program_seq
        feed_dict.program_tree = metainfo.program_tree
        feed_dict.program_qsseq = metainfo.program_qsseq
        feed_dict.program_qstree = metainfo.program_qstree
        feed_dict.question_type = metainfo.question_type

        # question
        feed_dict.answer = True

        return feed_dict.raw()
Example 23
    def forward(self, feed_dict):
        feed_dict = GView(feed_dict)

        if args.task in ['final', 'stack', 'sort']:
            states = feed_dict.states.float()
            batch_size = states.size(0)
        else:
            states = feed_dict.states
            batch_size = states[0].size(0)

        f, more_info = self.get_binary_relations(states)
        saved_for_fa = f
        if args.model == 'dlm':
            f = self.pred(f)
            logits = f[0].squeeze(dim=-1).view(batch_size, -1)
            logits = 1e-5 + logits * (1.0 - 2e-5)
            if args.distribution == 0:
                sigma = logits.sum(-1).unsqueeze(-1)
                policy = torch.where(sigma > 1.0, logits/sigma, logits + (1-sigma)/logits.shape[1])
            elif args.distribution == 1:
                policy = F.softmax(logits / args.last_tau, dim=-1).clamp(min=1e-20)
            elif args.distribution == 2:
                if self.training:
                    fa = self.ac_selector(saved_for_fa.detach())
                    policy = (fa.sigmoid() + 1e-5 )*logits
                else:
                    policy = logits
                policy = policy / policy.sum(-1).unsqueeze(-1)
            else:
                raise ValueError('unknown args.distribution: {}'.format(args.distribution))

            if feed_dict.training:
                if 'saturation' in more_info.keys():
                    more_info['saturation'].extend(f[1]['saturation'])
                else:
                    more_info['saturation']=[f[1]['saturation']]

                if 'entropies' in more_info.keys():
                    more_info['entropies'].extend(f[1]['entropies'])
        else:
            logits = self.pred(f).squeeze(dim=-1).view(batch_size, -1)
            policy = F.softmax(logits, dim=-1).clamp(min=1e-20)

        if not feed_dict.training:
            return dict(policy=policy, logits=logits)

        loss, monitors = self.loss(policy, feed_dict.actions, feed_dict.discount_rewards, feed_dict.entropy_beta)
 
        if args.pred_weight != 0.0:
            pred_states = feed_dict.pred_states.float()
            f, _ = self.get_binary_relations(pred_states, depth=args.pred_depth)
            if args.model == 'dlm':
                f = self.pred_valid(f)[0].squeeze(dim=-1).view(pred_states.size(0), -1)
            else:
                f = self.pred_valid(f).squeeze(dim=-1).view(pred_states.size(0), -1)
            # Clamp to a small positive value so the loss cannot become NaN.
            valid = f[range(pred_states.size(0)), feed_dict.pred_actions].clamp(min=1e-20)
            pred_loss = self.pred_loss(valid, feed_dict.valid)
            monitors['pred/accuracy'] = feed_dict.valid.eq((valid > 0.5).float()).float().mean()
            loss = loss + args.pred_weight * pred_loss
        if args.model == 'dlm':
            pred = (logits.detach().cpu() > 0.5).float()
            sat = 1 - (logits.detach().cpu() - pred).abs()
            monitors.update({'saturation/min': np.array(sat.min())})
            monitors.update({'saturation/mean': np.array(sat.mean())})
            saturation_inside = torch.cat([a.flatten() for a in more_info['saturation']])
            monitors.update({'saturation-inside/min': np.array(saturation_inside.cpu().min())})
            monitors.update({'saturation-inside/mean': np.array(saturation_inside.cpu().mean())})
            monitors.update({'tau': np.array(self.tau)})
            monitors.update({'dropout_prob': np.array(self.dropout_prob)})
            monitors.update({'gumbel_prob': np.array(self.gumbel_prob)})

        return loss, monitors, dict()
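In the distribution == 0 branch above, the DLM scores are renormalized into a probability vector in one of two ways: rows summing to more than 1 are divided by their sum, while rows summing to less than 1 get the remaining mass spread uniformly over the actions. Either way the result sums to 1, which a quick check illustrates (the numbers below are made up):

import torch

logits = torch.tensor([[0.2, 0.1, 0.4],   # row sum 0.7 -> add (1 - 0.7) / 3 to each entry
                       [0.5, 0.8, 0.9]])  # row sum 2.2 -> divide each entry by 2.2
sigma = logits.sum(-1).unsqueeze(-1)
policy = torch.where(sigma > 1.0, logits / sigma, logits + (1 - sigma) / logits.shape[1])
print(policy.sum(-1))                      # tensor([1., 1.])
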
Example 24
    def __init__(self, training, *, loss=0, monitors=None, output_dict=None):
        self.training = training
        self.loss = loss
        self.monitors = GView(monitors)
        self.output_dict = GView(output_dict)
        self.hyperparameters = dict()
Example 25
class ForwardContext(object):
    def __init__(self, training, *, loss=0, monitors=None, output_dict=None):
        self.training = training
        self.loss = loss
        self.monitors = GView(monitors)
        self.output_dict = GView(output_dict)
        self.hyperparameters = dict()

    def set_hyperparameter(self, key, value):
        self.hyperparameters[key] = value

    def get_hyperparameter(self, key, default=None):
        return self.hyperparameters.get(key, default)

    def add_loss(self, loss, key=None, accumulate=True):
        if float(accumulate) > 0:
            self.loss = self.loss + loss * float(accumulate)

        if key is not None:
            if f'loss/{key}' in self.monitors:
                self.monitors[f'loss/{key}'] += float(loss)
            else:
                self.monitors[f'loss/{key}'] = float(loss)
        return self

    def add_accuracy(self, accuracy, key):
        self.monitors[f'accuracy/{key}'] = float(accuracy)
        return self

    def add_output(self, output, key):
        self.output_dict[key] = output
        return self

    def update_monitors(self, monitors):
        self.monitors.update(monitors)
        return self

    def update_mo(self, monitors, output_dict):
        self.monitors.update(monitors)
        self.output_dict.update(output_dict)
        return self

    binary_classification_accuracy = _wrap_monitor_function(
        monitor.binary_classification_accuracy)
    classification_accuracy = _wrap_monitor_function(
        monitor.classification_accuracy)
    regression_accuracy = _wrap_monitor_function(monitor.regression_accuracy)
    monitor_rms = _wrap_monitor_function(monitor.monitor_rms)
    monitor_param_saturation = _wrap_monitor_function(
        monitor.monitor_param_saturation)
    monitor_param_rms = _wrap_monitor_function(monitor.monitor_param_rms)
    monitor_param_gradrms = _wrap_monitor_function(
        monitor.monitor_param_gradrms)
    monitor_param_gradrms_ratio = _wrap_monitor_function(
        monitor.monitor_param_gradrms_ratio)

    @wrap_custom_as_default(is_local=True)
    def as_default(self) -> 'ForwardContext':
        yield self

    def finalize(self):
        if self.training:
            return self.loss, self.monitors, self.output_dict
        else:
            self.output_dict.monitors = self.monitors
            return self.output_dict
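ForwardContext wraps the same (loss, monitors, output_dict) convention in a small accumulator: add_loss, add_accuracy and add_output collect values during a forward pass, and finalize() returns the training triple or the evaluation output dict. A hedged usage sketch of the class above, assuming GView(None) behaves like an empty dict-like container; the tensors and keys are made up, and the _wrap_monitor_function helpers are not exercised.

import torch

ctx = ForwardContext(training=True)

pred = torch.randn(8, 1)                  # hypothetical model output
target = torch.randn(8, 1)

mse = ((pred - target) ** 2).mean()
ctx.add_loss(mse, key='mse')              # accumulates into ctx.loss and monitors['loss/mse']
ctx.add_accuracy((pred.sign() == target.sign()).float().mean(), key='sign')
ctx.add_output(pred, key='pred')

loss, monitors, outputs = ctx.finalize()  # training=True, so the triple is returned
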
Example 26
    def __getitem__(self, index):
        metainfo = GView(self.get_metainfo(index))
        feed_dict = GView()

        # scene annotations
        if self.incl_scene:
            feed_dict.scene = metainfo.scene
            feed_dict.update(gdef.annotate_objects(metainfo.scene))
            feed_dict.objects_raw = feed_dict.objects.copy()
            feed_dict.update(gdef.annotate_scene(metainfo.scene))

        # image
        feed_dict.image_index = metainfo.image_index
        feed_dict.image_filename = metainfo.image_filename

        # video
        feed_dict.video_folder = metainfo.video_folder
        video = []
        original_objects = feed_dict.objects
        if self.image_root is not None:
            feed_dict.image = Image.open(
                osp.join(self.image_root,
                         feed_dict.image_filename)).convert("RGB")
            feed_dict.image, feed_dict.objects = self.image_transform(
                feed_dict.image, feed_dict.objects)

        if self.image_root is not None and feed_dict.video_folder is not None:
            import glob

            for name in glob.glob(
                    osp.join(self.image_root, feed_dict.video_folder) +
                    "/*.png"):
                image = Image.open(name).convert("RGB")
                image, _ = self.image_transform(image, original_objects)
                video += [image]

            feed_dict.video = torch.cat(video)

        # program
        feed_dict.program_raw = metainfo.program_raw
        feed_dict.program_seq = metainfo.program_seq
        feed_dict.program_tree = metainfo.program_tree
        feed_dict.program_qsseq = metainfo.program_qsseq
        feed_dict.program_qstree = metainfo.program_qstree
        feed_dict.question_type = metainfo.question_type

        # question
        feed_dict.answer = True

        return feed_dict.raw()
Example 27
    def __getitem__(self, index):
        metainfo = GView(self.get_metainfo(index))
        feed_dict = GView()

        # metainfo annotations
        if self.incl_scene:
            feed_dict.scene = metainfo.scene
            feed_dict.update(gdef.annotate_objects(metainfo.scene))
            if "objects" in feed_dict:
                # NB(Jiayuan Mao): in some datasets, object information might be completely unavailable.
                feed_dict.objects_raw = feed_dict.objects.copy()
            feed_dict.update(gdef.annotate_scene(metainfo.scene))

        # image
        feed_dict.image_index = metainfo.image_index
        feed_dict.image_filename = metainfo.image_filename
        # video
        feed_dict.video_folder = metainfo.video_folder
        video = []
        original_objects = feed_dict.objects
        if self.image_root is not None and feed_dict.image_filename is not None:
            feed_dict.image = Image.open(
                osp.join(self.image_root,
                         feed_dict.image_filename)).convert("RGB")
            feed_dict.image, feed_dict.objects = self.image_transform(
                feed_dict.image, feed_dict.objects)

            # print("Image:", feed_dict.image.shape)
            # print(feed_dict.objects)

        if self.image_root is not None and feed_dict.video_folder is not None:
            import glob

            for name in glob.glob(
                    osp.join(self.image_root, feed_dict.video_folder) +
                    "/*.png"):
                image = Image.open(name).convert("RGB")
                image, _ = self.image_transform(image, original_objects)
                video += [image]

            feed_dict.video = torch.stack(video)

            # Tensor
            # print("Video:", feed_dict.video.shape)

        # program
        if "program_raw" in metainfo:
            feed_dict.program_raw = metainfo.program_raw
            feed_dict.program_seq = metainfo.program_seq
            feed_dict.program_tree = metainfo.program_tree
            feed_dict.program_qsseq = metainfo.program_qsseq
            feed_dict.program_qstree = metainfo.program_qstree
        feed_dict.question_type = metainfo.question_type

        # question
        feed_dict.question_index = metainfo.question_index
        feed_dict.question_raw = metainfo.question
        feed_dict.question_raw_tokenized = metainfo.question_tokenized
        feed_dict.question_metainfo = gdef.annotate_question_metainfo(metainfo)
        feed_dict.question = metainfo.question_tokenized
        feed_dict.answer = gdef.canonize_answer(metainfo.answer,
                                                metainfo.question_type)
        feed_dict.update(gdef.annotate_question(metainfo))

        if self.question_transform is not None:
            self.question_transform(feed_dict)
        feed_dict.question = np.array(self.vocab.map_sequence(
            feed_dict.question),
                                      dtype="int64")

        return feed_dict.raw()
Example 28
    def forward(self, feed_dict, return_loss_matrix=False, can_break=False):
        # Pdb().set_trace()
        feed_dict = GView(feed_dict)
        # convert it to graph
        bg = self.collate_fn(feed_dict)
        # logits : of shape : args.sudoku_num_steps x batchsize*81 x 10 if training
        # logits: of shape : batch_size*81 x 10 if not training
        logits = self.sudoku_solver(bg, self.training)

        if self.training:
            # testing
            """
            labelsa = bg.ndata['a']
            labelsb = torch.stack([labelsa]*self.num_steps, 0)
            labels = labelsb.view([-1])
            labels1 = feed_dict.target.flatten().unsqueeze(0).expand(self.num_steps,-1).flatten().long()
            gl = dgl.unbatch(bg)
            gl[0].ndata['q']
            gl[1].ndata['q']
            Pdb().set_trace()
            print((labels != labels1).sum())
            loss = self.loss_func(logits.view([-1,10]), labels)
            #
            """
            logits = logits.transpose(1, 2)
            logits = logits.transpose(0, 2)
        else:
            logits = logits.unsqueeze(-1)
        # shape of logits now : BS*81 x 10 x 32 if self.training ,  otherwise BS*81 x 10 x 1
        logits = logits.view(-1, 81, logits.size(-2), logits.size(-1))
        # shape of logits now : BS x  81 x 10 x 32(1)
        logits = logits.transpose(1, 2)
        # shape of logits now : BS x  10 x 81 x 32(1)
        #pred = logits[:,:,:,-1].argmax(dim=1)
        pred = logits

        if self.training:
            # Pdb().set_trace()
            this_meters, _, reward = instance_accuracy(
                feed_dict.target.float(),
                pred,
                return_float=False,
                feed_dict=feed_dict,
                task=self.args.task,
                args=self.args)

            monitors = dict()
            target = feed_dict.target.float()
            count = None
            # Pdb().set_trace()
            loss_matrix = None
            if self.args.cc_loss or self.args.naive_pll_loss or self.args.min_loss or 'weights' in feed_dict or return_loss_matrix:
                loss_matrix = self.loss(logits, feed_dict.target_set,
                                        feed_dict.mask)
            else:
                loss_matrix = self.loss(logits, target.unsqueeze(1),
                                        feed_dict.mask[:, 0].unsqueeze(-1))
            # Pdb().set_trace()
            if 'weights' in feed_dict:
                loss = (feed_dict.weights*loss_matrix).sum() / \
                    feed_dict.weights.sum()

            else:
                loss = loss_matrix.mean()

            #loss = loss_ch
            # print(loss,loss_ch)

            #logger.info("Reward: ")
            # logger.info(reward)
            monitors.update(this_meters)
            # logits = logits.view([, 10])
            #labels = labels.view([-1])

            # loss_matrix of size: batch_size x target set size
            # when in training mode return prediction for all steps
            return loss, monitors, dict(pred=pred,
                                        reward=reward,
                                        loss_matrix=loss_matrix)
        else:
            return dict(pred=pred)
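The transpose/view chain above rearranges the recurrent solver output from (num_steps, batch*81, 10) in training mode into (batch, 10, 81, num_steps), i.e. one digit distribution per cell per step, matching the shape comments in the code. A quick shape check on dummy data (the sizes are assumptions taken from those comments):

import torch

num_steps, batch_size = 32, 4
logits = torch.randn(num_steps, batch_size * 81, 10)  # training-mode solver output

x = logits.transpose(1, 2)                  # (num_steps, 10, batch*81)
x = x.transpose(0, 2)                       # (batch*81, 10, num_steps)
x = x.view(-1, 81, x.size(-2), x.size(-1))  # (batch, 81, 10, num_steps)
x = x.transpose(1, 2)                       # (batch, 10, 81, num_steps)
assert x.shape == (batch_size, 10, 81, num_steps)
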
Example 29
    def __getitem__(self, index):
        metainfo = GView(self.get_metainfo(index))
        feed_dict = GView()

        # scene annotations
        if self.incl_scene:
            feed_dict.scene = metainfo.scene
            feed_dict.update(gdef.annotate_objects(metainfo.scene))
            feed_dict.objects_raw = feed_dict.objects.copy()
            feed_dict.update(gdef.annotate_scene(metainfo.scene))

        # image
        feed_dict.image_index = metainfo.image_index
        feed_dict.image_filename = metainfo.image_filename
        # video_folder
        feed_dict.video_folder = metainfo.video_folder
        feed_dict.video = []

        if self.image_root is not None:
            feed_dict.image = Image.open(
                osp.join(self.image_root,
                         feed_dict.image_filename)).convert("RGB")
            feed_dict.image, feed_dict.objects = self.image_transform(
                feed_dict.image, feed_dict.objects)

        if self.image_root is not None:
            import glob

            print("Got video 3")
            for name in glob.glob(
                    osp.join(self.image_root, feed_dict.video_folder) +
                    "/*.png"):
                feed_dict.video += [Image.open(name).convert("RGB")]

        return feed_dict.raw()
Example 30
    def __getitem__(self, index):
        # index = index % 200
        metainfo = GView(self.get_metainfo(index))
        metainfo.view_id = 1
        feed_dict = GView()
        feed_dict.scene = metainfo.scene
        feed_dict.attribute_name = "shape"
        feed_dict.concept_name = metainfo.scene["objects"][0][feed_dict.attribute_name]
        if self.incl_scene:
            feed_dict.scene = metainfo.scene
            feed_dict.update(gdef.annotate_objects(metainfo.scene))
            if "objects" in feed_dict:
                # NB(Jiayuan Mao): in some datasets, object information might be completely unavailable.
                feed_dict.objects_raw = feed_dict.objects.copy()
            feed_dict.update(gdef.annotate_scene(metainfo.scene))

        # image
        feed_dict.image_index = metainfo.image_index
        feed_dict.image_filename = metainfo.image_filename
        if self.image_root is not None and feed_dict.image_filename is not None:
            feed_dict.image = Image.open(
                osp.join(self.image_root, feed_dict.image_filename)
            ).convert("RGB")
            feed_dict.image, feed_dict.objects = self.image_transform(
                feed_dict.image, feed_dict.objects
            )
        if self.depth_root is not None and feed_dict.image_filename is not None:
            depth_filename = feed_dict.image_filename.split(".")[0] + ".exr"
            feed_dict.depth = torch.tensor(
                load_depth(osp.join(self.depth_root, depth_filename))
            )
        # program

        # Scene
        # feed_dict.bboxes = torch.tensor(feed_dict.scene["obj_bboxes"][0]).reshape(-1, 9)
        # # feed_dict.bboxes_len = torch.tensor(feed_dict.bboxes.size(0))
        # feed_dict.pix_T_cam = torch.tensor(metainfo.scene["pix_T_cams"]).float()
        # feed_dict.origin_T_cam = torch.tensor(
        #     metainfo.scene["origin_T_cams"][metainfo.view_id]
        # ).float()
        return feed_dict.raw()