Ejemplo n.º 1
0
def test_region_grounding_model(config):
    """Smoke-test RegionGroundingModel on a single batch of the VG test split.

    If ``config.pretrained`` is set, weights are restored (non-strictly) from
    ``<data_dir>/caches/region_grounding_ckpts/<pretrained>.pkl``.  One
    forward pass, the final loss and the evaluation metrics are computed and
    their shapes/values printed; the loop stops after the first batch.
    """
    loader = DataLoader(region_loader(vg(config, 'test')),
                        batch_size=3 * config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)

    net = RegionGroundingModel(config)
    if config.pretrained is not None:
        ckpt_file = osp.join(config.data_dir,
                             'caches/region_grounding_ckpts',
                             config.pretrained + '.pkl')
        # Map every storage onto the CPU regardless of where it was saved.
        checkpoint = torch.load(ckpt_file,
                                map_location=lambda storage, loc: storage)
        # strict=False tolerates missing / unexpected keys.
        net.load_state_dict(checkpoint['state_dict'], strict=False)
    net.train()
    for param_name, tensor in net.named_parameters():
        print(param_name, tensor.size())

    for batched in loader:
        scene_inds, sent_inds, sent_msks = (
            batched[key].long()
            for key in ('scene_inds', 'sent_inds', 'sent_msks'))
        region_feats = batched['region_feats'].float()
        region_clses = batched['region_clses'].long()
        region_masks = batched['region_masks'].float()

        outputs = net(scene_inds, sent_inds, sent_msks,
                      None, None, None,
                      region_feats, region_clses, region_masks,
                      config.explore_mode)
        (img_feats, masked_feats, txt_feats,
         subspace_masks, sample_logits, sample_indices) = outputs
        if config.instance_dim > 1:
            print(sample_indices[0])

        # Every text position is treated as valid.
        txt_masks = txt_feats.new_ones(txt_feats.size(0), txt_feats.size(1))
        losses = net.final_loss(img_feats, masked_feats, region_masks,
                                txt_feats, txt_masks, sample_logits,
                                sample_indices)
        print('losses', losses.size(), torch.mean(losses))

        # Subspace alignment evaluates on the masked features instead.
        if config.subspace_alignment_mode > 0:
            eval_feats = masked_feats
        else:
            eval_feats = img_feats
        metrics, _ = net.evaluate(eval_feats, region_masks, txt_feats)
        print('metrics', metrics)
        print('txt_feats', txt_feats.size())
        print('img_feats', img_feats.size())
        break
Ejemplo n.º 2
0
def test_region_model(config):
    """Smoke-test RegionModel on one batch split into query and source halves.

    The loader yields ``3 * batch_size`` samples per batch.  The first
    ``batch_size`` samples supply the scene/sentence inputs and target
    regions.  The next ``batch_size`` samples supply the source regions.
    The function prints parameter shapes, output shapes and the wall-clock
    time of one forward pass, then stops.
    """
    db = vg(config, 'test')
    loader = DataLoader(region_loader(db),
                        batch_size=3 * config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)

    net = RegionModel(config)
    net.train()

    for pname, p in net.named_parameters():
        print(pname, p.size())

    n = config.batch_size
    query = slice(0, n)
    source = slice(n, 2 * n)
    for batched in loader:
        tic = time()
        scene_inds = batched['scene_inds'].long()[query]
        sent_inds = batched['sent_inds'].long()[query]
        sent_msks = batched['sent_msks'].long()[query]
        region_feats = batched['region_feats'].float()[query]
        region_clses = batched['region_clses'].long()[query]
        region_masks = batched['region_masks'].float()[query]
        src_region_feats = batched['region_feats'].float()[source]
        src_region_clses = batched['region_clses'].long()[source]
        src_region_masks = batched['region_masks'].float()[source]

        outputs = net(scene_inds, sent_inds, sent_msks,
                      src_region_feats, src_region_clses, src_region_masks,
                      region_feats, region_clses, region_masks,
                      config.explore_mode)
        (img_feats, masked_feats, txt_feats,
         subspace_masks, sample_logits, sample_indices) = outputs

        print('img_feats', img_feats.size())
        print('txt_feats', txt_feats.size())
        if config.subspace_alignment_mode > 0:
            print('masked_feats', masked_feats.size())
            print('subspace_masks', subspace_masks.size())
        if config.instance_dim > 1:
            print('sample_logits', sample_logits.size())
            print('sample_indices', sample_indices.size())
        print('time:', time() - tic)
        break
Ejemplo n.º 3
0
def check_region_clses(config):
    """Scan the whole VG train split and print the min / max region class index.

    Prints a progress line once every 1000 batches.  If no batch contains any
    region classes, the printed bounds stay at +inf / -inf.
    """
    db = vg(config, 'train')
    loaddb = region_loader(db)
    loader = DataLoader(loaddb,
                        batch_size=config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)

    # Unbounded sentinels: correct for any class-index range.
    min_index = float('inf')
    max_index = float('-inf')
    for cnt, batched in enumerate(loader):
        # Progress once every 1000 batches (cnt % 1000 is 0 on those).
        if cnt % 1000 == 0:
            print('iter:', cnt)
        region_clses = batched['region_clses'].long()
        # torch.min / torch.max raise on empty tensors.
        if region_clses.numel() == 0:
            continue
        min_index = min(min_index, torch.min(region_clses).item())
        max_index = max(max_index, torch.max(region_clses).item())
    print('min_index', min_index)
    print('max_index', max_index)
Ejemplo n.º 4
0
def test_grounding_loss(config):
    """Cross-check GroundingLoss's batched similarity against its debug version.

    Runs RegionModel on one batch.  The first ``batch_size`` samples are the
    queries and the next ``batch_size`` are the source regions.  It then
    prints the shapes of both similarity tensors, their total absolute
    difference, and the shape of ``forward_loss``.  The unmasked image
    features are deliberately used in place of ``masked_feats`` for all of
    these.
    """
    loader = DataLoader(region_loader(vg(config, 'test')),
                        batch_size=3 * config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)

    net = RegionModel(config)
    criterion = GroundingLoss(config)
    n = config.batch_size
    query = slice(0, n)
    source = slice(n, 2 * n)
    for batched in loader:
        scene_inds = batched['scene_inds'].long()[query]
        sent_inds = batched['sent_inds'].long()[query]
        sent_msks = batched['sent_msks'].long()[query]
        region_feats = batched['region_feats'].float()[query]
        region_clses = batched['region_clses'].long()[query]
        region_masks = batched['region_masks'].float()[query]
        src_region_feats = batched['region_feats'].float()[source]
        src_region_clses = batched['region_clses'].long()[source]
        src_region_masks = batched['region_masks'].float()[source]

        outputs = net(scene_inds, sent_inds, sent_msks,
                      src_region_feats, src_region_clses, src_region_masks,
                      region_feats, region_clses, region_masks,
                      config.explore_mode)
        img_feats, txt_feats = outputs[0], outputs[2]

        # Use the unmasked image features for every comparison below.
        feats = img_feats
        fast_sim = criterion.compute_batch_mutual_similarity(
            feats, region_masks, txt_feats)
        debug_sim = criterion.debug_compute_batch_mutual_similarity(
            feats, region_masks, txt_feats)
        print('sim1', fast_sim.size())
        print('sim2', debug_sim.size())
        print('diff', torch.sum(torch.abs(fast_sim - debug_sim)))

        txt_masks = txt_feats.new_ones(txt_feats.size(0), txt_feats.size(1))
        losses = criterion.forward_loss(feats, region_masks, txt_feats,
                                        txt_masks, config.loss_reduction_mode)
        print('losses', losses.size())
        break
Ejemplo n.º 5
0
def test_text_encoder(config):
    """Smoke-test TextEncoder on one batch of the VG test split.

    The 3-D sentence index / mask tensors have their first two dimensions
    folded together before encoding.  The shapes of the three encoder
    outputs are then printed.
    """
    db = vg(config, 'test')
    loader = DataLoader(region_loader(db),
                        batch_size=3 * config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)

    net = TextEncoder(config)
    for batched in loader:
        sent_inds = batched['sent_inds'].long()
        sent_msks = batched['sent_msks'].float()
        d0, d1, d2 = sent_inds.size()
        print('sent_inds', sent_inds.size())
        print('sent_msks', sent_msks.size())
        # Merge the first two dimensions into one batch dimension.
        flat_shape = (d0 * d1, d2)
        f1, f2, h = net(sent_inds.view(*flat_shape),
                        sent_msks.view(*flat_shape))
        print(f1.size(), f2.size(), h.size())
        break
Ejemplo n.º 6
0
def test_region_encoder(config):
    """Smoke-test RegionEncoder on one batch of the VG test split.

    Prints the input region feature / class shapes and the encoder output
    shapes.  The masked features and ``mm`` are only printed when subspace
    alignment is enabled.
    """
    db = vg(config, 'test')
    loader = DataLoader(region_loader(db),
                        batch_size=3 * config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)

    encoder = RegionEncoder(config)
    for batched in loader:
        feats_in = batched['region_feats'].float()
        clses_in = batched['region_clses'].long()
        print('region_feats', feats_in.size())
        print('region_clses', clses_in.size())
        img_feats, masked_feats, mm = encoder(feats_in, clses_in)
        print('img_feats', img_feats.size())
        if config.subspace_alignment_mode > 0:
            print('masked_feats', masked_feats.size())
            print('mm', mm.size())
        break
Ejemplo n.º 7
0
def test_region_loader(config):
    """Visually inspect the region loader on the VG train split.

    For the first three batches (cnt 0, 1, 2), prints tensor shapes and
    per-batch metadata.  For every sample it also saves a figure
    ``<model_dir>/test_region_loader/<image_index>.png``.  The figure shows
    up to 14 region layouts rendered over the resized image, titled with the
    caption, token indices and token mask for turns below
    ``config.max_turns``, plus the original image in the 15th panel.
    """
    db = vg(config, 'train')
    # db = coco(config, 'train')
    loaddb = region_loader(db)
    loader = DataLoader(loaddb,
                        batch_size=config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)

    output_dir = osp.join(config.model_dir, 'test_region_loader')
    maybe_create(output_dir)

    start = time()
    # Non-interactive backend: figures are only written to disk.
    plt.switch_backend('agg')
    for cnt, batched in enumerate(loader):
        print('scene_inds', batched['scene_inds'])
        sent_inds = batched['sent_inds'].long()
        sent_msks = batched['sent_msks'].long()
        widths = batched['widths']
        heights = batched['heights']

        captions = batched['captions']
        region_boxes = batched['region_boxes'].float()
        region_feats = batched['region_feats'].float()
        region_clses = batched['region_clses'].long()
        region_masks = batched['region_masks'].long()

        print('sent_inds', sent_inds.size())
        print('sent_msks', sent_msks.size())
        print('region_boxes', region_boxes.size())
        print('region_feats', region_feats.size())
        print('region_clses', region_clses.size())
        print('region_masks', region_masks.size())
        print('clses', torch.min(region_clses), torch.max(region_clses))
        print('widths', widths)
        print('heights', heights)

        for i in range(len(sent_inds)):
            # print('####')
            # print(len(captions), len(captions[0]))
            entry = {}
            image_index = batched['image_inds'][i]
            entry['width'] = widths[i]
            entry['height'] = heights[i]
            # Number of valid regions for sample i; assumes region_masks is
            # 1 for real regions and 0 for padding -- TODO confirm.
            nr = torch.sum(region_masks[i])
            # Keep only the first nr boxes, converted from xyxy to xywh.
            entry['region_boxes'] = xyxys_to_xywhs(
                region_boxes[i, :nr].cpu().data.numpy())

            color = cv2.imread(db.color_path_from_index(image_index),
                               cv2.IMREAD_COLOR)
            color, _, _ = create_squared_image(color)

            out_path = osp.join(output_dir, '%d.png' % image_index)
            # cv2 loads BGR; [:, :, ::-1] reverses the channel axis for the
            # background.  NOTE(review): visu_size[0] is used for both width
            # and height -- presumably fine because the image was squared
            # above, but confirm visu_size[1] was not intended.
            layouts = db.render_regions_as_output(
                entry,
                bg=cv2.resize(
                    color,
                    (config.visu_size[0], config.visu_size[0]))[:, :, ::-1])

            # 3x5 grid: panels 1..14 for layouts, panel 15 for the raw image.
            fig = plt.figure(figsize=(32, 16))
            for j in range(min(14, len(layouts))):
                plt.subplot(3, 5, j + 1)
                if j < config.max_turns:
                    # Title: caption, then token ids, then token mask values.
                    plt.title(
                        captions[i][j] + '\n' +
                        ' '.join([str(x.data.item())
                                  for x in sent_inds[i, j]]) + '\n' +
                        ' '.join([str(x.data.item())
                                  for x in sent_msks[i, j]]))
                plt.imshow(layouts[j].astype(np.uint8))
                plt.axis('off')
            plt.subplot(3, 5, 15)
            plt.imshow(color[:, :, ::-1])
            plt.axis('off')
            fig.savefig(out_path, bbox_inches='tight')
            # Close explicitly so figures do not accumulate across samples.
            plt.close(fig)

        print('------------------')
        # Only inspect the first three batches.
        if cnt == 2:
            break
    print("Time", time() - start)
Ejemplo n.º 8
0
    def test(self, test_db):
        """Evaluate ``self.net`` on ``test_db`` and write the results to ``cfg.model_dir``.

        Inference runs over the full split (in order, without gradients).  It
        collects one scalar loss per batch plus the text / image features and
        region masks.  It then writes:

        * ``img_features_<n_feature_dim>.pkl`` -- image features and masks,
        * ``test_metrics.json`` -- the metrics from ``self.net.evaluate``,
        * ``test_caches.pkl`` -- the cached results from ``self.net.evaluate``,
        * ``evaluation.png`` -- a plot produced by ``visualize``.

        Leaves the network in eval mode.

        Args:
            test_db: dataset to be wrapped by ``region_loader``.

        Returns:
            ``(losses, metrics, caches_results)``.  ``losses`` is a NumPy array
            with one entry per batch.

        Raises:
            RuntimeError: if the loader yields no batches.
        """
        ##################################################################
        ## LOG
        ##################################################################
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)
        test_loaddb = region_loader(test_db)
        test_loader = DataLoader(test_loaddb,
                                 batch_size=self.cfg.batch_size,
                                 shuffle=False,
                                 num_workers=self.cfg.num_workers,
                                 collate_fn=region_collate_fn)

        # RL fine-tuned models are evaluated with sample_mode 0 instead of
        # the configured exploration mode.
        sample_mode = 0 if self.cfg.rl_finetune > 0 else self.cfg.explore_mode
        all_txt_feats, all_img_feats, all_img_masks, losses = [], [], [], []
        self.net.eval()
        for cnt, batched in enumerate(test_loader):
            ##################################################################
            ## Batched data
            ##################################################################
            scene_inds, sent_inds, sent_msks, region_feats, region_masks, region_clses = \
                self.batch_data(batched)

            ##################################################################
            ## Inference one step
            ##################################################################
            with torch.no_grad():
                img_feats, masked_feats, txt_feats, subspace_masks, sample_logits, sample_indices = \
                    self.net(scene_inds, sent_inds, sent_msks, None, None, None, region_feats, region_clses, region_masks, sample_mode=sample_mode)
                # Every text position is treated as valid.
                txt_masks = txt_feats.new_ones(txt_feats.size(0),
                                               txt_feats.size(1))
                batch_losses = self.net.final_loss(img_feats, masked_feats,
                                                   region_masks, txt_feats,
                                                   txt_masks, sample_logits,
                                                   sample_indices)
                # Mean over the last axis, summed over everything else.
                loss = torch.sum(torch.mean(batch_losses, -1))
            losses.append(loss.item())
            all_txt_feats.append(txt_feats)
            all_img_masks.append(region_masks)
            # Subspace alignment evaluates on the masked image features.
            if self.cfg.subspace_alignment_mode > 0:
                all_img_feats.append(masked_feats)
            else:
                all_img_feats.append(img_feats)
            ##################################################################
            ## Print info
            ##################################################################
            if cnt % self.cfg.log_per_steps == 0:
                print('Iter %07d:' % (cnt))
                print('mean loss: ', np.mean(losses))
                print('-------------------------')

        if not all_img_feats:
            raise RuntimeError('test_db produced no batches; nothing to evaluate')

        torch.cuda.empty_cache()
        losses = np.array(losses)
        all_img_feats = torch.cat(all_img_feats, 0)
        all_img_masks = torch.cat(all_img_masks, 0)
        all_txt_feats = torch.cat(all_txt_feats, 0)

        all_img_feats_np = all_img_feats.cpu().data.numpy()
        all_img_masks_np = all_img_masks.cpu().data.numpy()
        with open(
                osp.join(self.cfg.model_dir,
                         'img_features_%d.pkl' % self.cfg.n_feature_dim),
                'wb') as fid:
            pickle.dump({
                'feats': all_img_feats_np,
                'masks': all_img_masks_np
            }, fid, pickle.HIGHEST_PROTOCOL)

        ##################################################################
        ## Evaluation
        ##################################################################
        print('Evaluating the per-turn performance, may take a while.')
        metrics, caches_results = self.net.evaluate(all_img_feats,
                                                    all_img_masks,
                                                    all_txt_feats)

        with open(osp.join(self.cfg.model_dir, 'test_metrics.json'),
                  'w') as fp:
            json.dump(metrics, fp, indent=4, sort_keys=True)
        with open(osp.join(self.cfg.model_dir, 'test_caches.pkl'),
                  'wb') as fid:
            pickle.dump(caches_results, fid, pickle.HIGHEST_PROTOCOL)

        visualize(self.cfg.exp_name, metrics,
                  osp.join(self.cfg.model_dir, 'evaluation.png'))

        return losses, metrics, caches_results
Ejemplo n.º 9
0
    def train(self, train_db, val_db, test_db):
        """Train for epochs ``self.epoch .. cfg.n_epochs - 1`` with per-epoch validation.

        Each epoch does the following:

        * Runs ``self.train_epoch`` and then ``self.validate_epoch``.
        * Steps ``self.optimizer`` with the mean validation loss.
        * Logs losses and the averaged t2i retrieval metrics through ``logz``.
        * Checkpoints the best epoch so far.

        "Best" means the lowest validation loss in plain training
        (``rl_finetune == 0`` and ``coco_mode < 0``).  Otherwise it means the
        highest mean of R@1/R@5/R@10.

        When ``cfg.coco_mode >= 0`` it is mutated in place: it is re-drawn
        uniformly from ``[0, max_turns)`` before every training epoch and
        reset to 0 before validation.

        ``test_db`` is accepted for interface compatibility but is unused.
        """
        ##################################################################
        ## LOG
        ##################################################################
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        def save_best(epoch, val_metrics, caches_results):
            # Checkpoint the model and dump this epoch's validation outputs.
            self.save_checkpoint(epoch)
            metrics_path = osp.join(self.cfg.model_dir,
                                    'val_metrics_%d.json' % epoch)
            with open(metrics_path, 'w') as fp:
                json.dump(val_metrics, fp, indent=4, sort_keys=True)
            caches_path = osp.join(self.cfg.model_dir,
                                   'val_top5_inds_%d.pkl' % epoch)
            with open(caches_path, 'wb') as fid:
                pickle.dump(caches_results, fid, pickle.HIGHEST_PROTOCOL)

        ##################################################################
        ## Main loop
        ##################################################################
        start = time()
        # Start at +inf so the first epoch always qualifies for checkpointing,
        # however large its validation loss is.
        min_val_loss = float('inf')
        max_val_recall = -1.0
        train_loaddb = region_loader(train_db)
        val_loaddb = region_loader(val_db)
        # TODO
        train_loader = DataLoader(train_loaddb,
                                  batch_size=self.cfg.batch_size,
                                  shuffle=True,
                                  num_workers=self.cfg.num_workers,
                                  collate_fn=region_collate_fn)
        val_loader = DataLoader(val_loaddb,
                                batch_size=self.cfg.batch_size,
                                shuffle=False,
                                num_workers=self.cfg.num_workers,
                                collate_fn=region_collate_fn)

        for epoch in range(self.epoch, self.cfg.n_epochs):
            ##################################################################
            ## Training
            ##################################################################
            if self.cfg.coco_mode >= 0:
                self.cfg.coco_mode = np.random.randint(0, self.cfg.max_turns)
            torch.cuda.empty_cache()
            train_losses = self.train_epoch(train_loaddb, train_loader, epoch)

            ##################################################################
            ## Validation
            ##################################################################
            if self.cfg.coco_mode >= 0:
                self.cfg.coco_mode = 0
            torch.cuda.empty_cache()
            val_losses, val_metrics, caches_results = self.validate_epoch(
                val_loaddb, val_loader, epoch)

            #################################################################
            # Logging
            #################################################################

            # update optim scheduler
            current_val_loss = np.mean(val_losses)
            self.optimizer.update(current_val_loss, epoch)
            logz.log_tabular("Time", time() - start)
            logz.log_tabular("Iteration", epoch)
            logz.log_tabular("TrainAverageLoss", np.mean(train_losses))
            logz.log_tabular("ValAverageLoss", current_val_loss)

            # Average the metric vectors over all keys of val_metrics.
            # Each vector is expected to be [R1, R5, R10, medr, meanr],
            # per the labels logged below.
            mmm = np.zeros((5, ), dtype=np.float64)
            for v in val_metrics.values():
                mmm = mmm + np.array(v)
            mmm /= len(val_metrics)
            logz.log_tabular("t2i_R1", mmm[0])
            logz.log_tabular("t2i_R5", mmm[1])
            logz.log_tabular("t2i_R10", mmm[2])
            logz.log_tabular("t2i_medr", mmm[3])
            logz.log_tabular("t2i_meanr", mmm[4])
            logz.dump_tabular()
            current_val_recall = np.mean(mmm[:3])

            ##################################################################
            ## Checkpoint
            ##################################################################
            if self.cfg.rl_finetune == 0 and self.cfg.coco_mode < 0:
                if current_val_loss < min_val_loss:
                    min_val_loss = current_val_loss
                    save_best(epoch, val_metrics, caches_results)
            elif current_val_recall > max_val_recall:
                max_val_recall = current_val_recall
                save_best(epoch, val_metrics, caches_results)