Example #1
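A hedged reconstruction of the imports and module-level state this snippet assumes; none of this is part of the original listing:

import os

import matplotlib.pyplot as plt
from PIL import Image

# Also assumed to exist at module scope: `model`, `dataset_val`,
# `dataloader_val`, the project-local `utils` module, and the preallocated
# buffers `input_imgs`, `input_seqs`, `gt_seqs`, `input_num`, `input_ppls`,
# `gt_bboxs`, and `mask_bboxs` (created as in Example #2 below).
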
def demo(opt):
    model.eval()
    #########################################################################################
    # eval begins here
    #########################################################################################
    data_iter_val = iter(dataloader_val)
    predictions = []
    for step in range(1000):
        data = next(data_iter_val)
        img, iseq, gts_seq, num, proposals, bboxs, box_mask, img_id = data

        # Trim padded proposal rows to the largest per-image proposal count in the batch.
        proposals = proposals[:, :max(int(max(num[:, 1])), 1), :]

        # Copy the batch into the preallocated (optionally CUDA) buffers in place.
        input_imgs.resize_(img.size()).copy_(img)
        input_seqs.resize_(iseq.size()).copy_(iseq)
        gt_seqs.resize_(gts_seq.size()).copy_(gts_seq)
        input_num.resize_(num.size()).copy_(num)
        input_ppls.resize_(proposals.size()).copy_(proposals)
        gt_bboxs.resize_(bboxs.size()).copy_(bboxs)
        mask_bboxs.resize_(box_mask.size()).copy_(box_mask)

        eval_opt = {'sample_max': 1, 'beam_size': opt.beam_size, 'inference_mode': True, 'tag_size': opt.cbs_tag_size}
        seq, bn_seq, fg_seq, _, _, _ = model._sample(input_imgs, input_ppls, input_num, eval_opt)
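        # The word stream (seq) and the bn/fg detection streams are decoded
        # jointly below into a sentence plus the proposal indices it grounds.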

        sents, det_idx, det_word = utils.decode_sequence_det(
            dataset_val.itow, dataset_val.itod, dataset_val.ltow,
            dataset_val.itoc, dataset_val.wtod,
            seq, bn_seq, fg_seq, opt.vocab_size, opt)

        if opt.dataset == 'flickr30k':
            im2show = Image.open(os.path.join(opt.image_path, '%d.jpg' % img_id[0])).convert('RGB')
        else:
            if os.path.isfile(os.path.join(opt.image_path, 'val2014/COCO_val2014_%012d.jpg' % img_id[0])):
                im2show = Image.open(os.path.join(opt.image_path, 'val2014/COCO_val2014_%012d.jpg' % img_id[0])).convert('RGB')
            else:
                im2show = Image.open(os.path.join(opt.image_path, 'train2014/COCO_train2014_%012d.jpg' % img_id[0])).convert('RGB')

        w, h = im2show.size

        # Indices of proposals that the generated caption does not ground.
        rest_idx = []
        for i in range(proposals[0].shape[0]):
            if i not in det_idx:
                rest_idx.append(i)


        if len(det_idx) > 0:
            # Rescale boxes from network-input coordinates back to image size for visualization.
            proposals = proposals[0].numpy()
            proposals[:, 0] = proposals[:, 0] * w / float(opt.image_crop_size)
            proposals[:, 2] = proposals[:, 2] * w / float(opt.image_crop_size)
            proposals[:, 1] = proposals[:, 1] * h / float(opt.image_crop_size)
            proposals[:, 3] = proposals[:, 3] * h / float(opt.image_crop_size)

            cls_dets = proposals[det_idx]
            rest_dets = proposals[rest_idx]

        fig = plt.figure(frameon=False)
        # fig.set_size_inches(5,5*h/w)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
        a = fig.gca()
        a.set_frame_on(False)
        a.set_xticks([])
        a.set_yticks([])
        plt.axis('off')
        plt.xlim(0, w)
        plt.ylim(h, 0)

        # Draw unmentioned boxes in grey, then the grounded boxes in color.
        plt.imshow(im2show)

        # Also require det_idx: rest_dets is only defined when det_idx is non-empty.
        if len(det_idx) > 0 and len(rest_idx) > 0:
            for i in range(len(rest_dets)):
                ax = utils.vis_detections(ax, dataset_val.itoc[int(rest_dets[i,4])], rest_dets[i,:5], i, 1)

        if len(det_idx) > 0:
            for i in range(len(cls_dets)):
                ax = utils.vis_detections(ax, dataset_val.itoc[int(cls_dets[i,4])], cls_dets[i,:5], i, 0)

        fig.savefig('visu/%d.jpg' %(img_id[0].item()), bbox_inches='tight', pad_inches=0, dpi=150)
        print(str(img_id[0].item()) + ': ' + sents[0])

        entry = {'image_id': img_id[0].item(), 'caption': sents[0]}
        predictions.append(entry)

    return predictions
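
A minimal, hypothetical driver for `demo`; `parse_args` is an assumed argparse wrapper supplying the `opt.*` attributes read above, and the `visu/` directory must exist because `fig.savefig` writes there:

import json
import os

if __name__ == '__main__':
    opt = parse_args()  # hypothetical: provides beam_size, image_path, dataset, ...
    os.makedirs('visu', exist_ok=True)  # demo() saves one annotated image per sample
    predictions = demo(opt)
    with open('demo_captions.json', 'w') as f:
        json.dump(predictions, f)  # [{'image_id': ..., 'caption': ...}, ...]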
Example #2
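A hedged reconstruction of the imports this variant assumes (again, not part of the original listing); `utils`, `fusion_beam_sample`, and `prepare_graph_variables` are project-local helpers assumed to be importable:

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
from tqdm import tqdm
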
def demo_fusion_models(opt,
                       dataset_val,
                       imp_pro,
                       spa_pro,
                       sem_pro,
                       imp_model=None,
                       spa_model=None,
                       sem_model=None,
                       save_name=''):
    dataloader_val = torch.utils.data.DataLoader(dataset_val,
                                                 batch_size=opt.batch_size,
                                                 shuffle=False,
                                                 num_workers=opt.num_workers)
    # Preallocate reusable input buffers; they are resized/copied per batch below.
    input_imgs = torch.FloatTensor(1)
    input_seqs = torch.LongTensor(1)
    input_ppls = torch.FloatTensor(1)
    gt_bboxs = torch.FloatTensor(1)
    mask_bboxs = torch.ByteTensor(1)
    gt_seqs = torch.LongTensor(1)
    input_num = torch.LongTensor(1)

    if opt.cuda:
        input_imgs = input_imgs.cuda()
        input_seqs = input_seqs.cuda()
        gt_seqs = gt_seqs.cuda()
        input_num = input_num.cuda()
        input_ppls = input_ppls.cuda()
        gt_bboxs = gt_bboxs.cuda()
        mask_bboxs = mask_bboxs.cuda()

    input_imgs = Variable(input_imgs)
    input_seqs = Variable(input_seqs)
    gt_seqs = Variable(gt_seqs)
    input_num = Variable(input_num)
    input_ppls = Variable(input_ppls)
    gt_bboxs = Variable(gt_bboxs)
    mask_bboxs = Variable(mask_bboxs)
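    # Note: torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4,
    # so the wrapping above is kept only for backward compatibility.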

    predictions = []
    progress_bar = tqdm(dataloader_val,
                        desc='|Validation process',
                        leave=False)
    for step, data in enumerate(progress_bar):
        # Only visualize roughly the first 1000 images.
        if step * opt.batch_size > 1000:
            break
        img, iseq, gts_seq, num, proposals, bboxs, box_mask, img_id, spa_adj_matrix, sem_adj_matrix = data

        proposals = proposals[:, :max(int(max(num[:, 1])), 1), :]

        # Copy the batch into the preallocated buffers in place (no .data
        # indirection is needed in current PyTorch).
        input_imgs.resize_(img.size()).copy_(img)
        input_seqs.resize_(iseq.size()).copy_(iseq)
        gt_seqs.resize_(gts_seq.size()).copy_(gts_seq)
        input_num.resize_(num.size()).copy_(num)
        input_ppls.resize_(proposals.size()).copy_(proposals)
        gt_bboxs.resize_(bboxs.size()).copy_(bboxs)
        # Cast the 0/1 mask to bool before copying it into the mask buffer.
        mask_bboxs.resize_(box_mask.size()).copy_(box_mask.bool())

        # Crop the adjacency matrices to the same trimmed proposal count.
        n_ppl = max(int(max(num[:, 1])), 1)
        if len(spa_adj_matrix[0]) != 0:
            spa_adj_matrix = spa_adj_matrix[:, :n_ppl, :n_ppl]
        if len(sem_adj_matrix[0]) != 0:
            sem_adj_matrix = sem_adj_matrix[:, :n_ppl, :n_ppl]

        # Select which relation models (implicit/spatial/semantic) join the fusion.
        eval_opt_rel = {
            'imp_model': opt.imp_model,
            'spa_model': opt.spa_model,
            'sem_model': opt.sem_model,
            "graph_att": opt.graph_attention
        }
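        # prepare_graph_variables (defined elsewhere in this project) is expected
        # to build the position embeddings and the labeled spatial/semantic
        # adjacency tensors for the chosen relation type.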
        pos_emb_var, spa_adj_matrix, sem_adj_matrix = prepare_graph_variables(
            opt.relation_type, proposals[:, :, :4], sem_adj_matrix,
            spa_adj_matrix, opt.nongt_dim, opt.imp_pos_emb_dim,
            opt.spa_label_num, opt.sem_label_num, eval_opt_rel)

        eval_opt = {
            'sample_max': 1,
            'beam_size': opt.beam_size,
            'inference_mode': True,
            'tag_size': opt.cbs_tag_size
        }
        seq, bn_seq, fg_seq, seqLogprobs, bnLogprobs, fgLogprobs, _ = fusion_beam_sample(
            opt, imp_pro, spa_pro, sem_pro, input_ppls, input_imgs, input_num,
            pos_emb_var, spa_adj_matrix, sem_adj_matrix, eval_opt, imp_model,
            spa_model, sem_model)
        sents, det_idx, det_word = utils.decode_sequence_det(
            dataset_val.itow, dataset_val.itod, dataset_val.ltow,
            dataset_val.itoc, dataset_val.wtod, seq.data, bn_seq.data,
            fg_seq.data, opt.vocab_size, opt)

        # Iterate over the actual batch; the last batch may be smaller than
        # opt.batch_size.
        for i in range(img_id.size(0)):

            if os.path.isfile(
                    os.path.join(opt.image_path,
                                 'val2014/COCO_val2014_%012d.jpg' %
                                 img_id[i])):
                im2show = Image.open(
                    os.path.join(opt.image_path,
                                 'val2014/COCO_val2014_%012d.jpg' %
                                 img_id[i])).convert('RGB')
            else:
                im2show = Image.open(
                    os.path.join(
                        opt.image_path, 'train2014/COCO_train2014_%012d.jpg' %
                        img_id[i])).convert('RGB')

            w, h = im2show.size

            rest_idx = []
            # Drop all-zero padding rows from this image's proposals.
            proposals_one = proposals[i].numpy()
            ppl_mask = np.all(np.equal(proposals_one, 0), axis=1)
            proposals_one = proposals_one[~ppl_mask]
            # Keep only in-range, first-occurrence detection indices.
            new_det_idx = []
            for j in range(len(det_idx)):
                if det_idx[j] < proposals_one.shape[0] and det_idx[j] not in new_det_idx:
                    new_det_idx.append(det_idx[j])
            det_idx = new_det_idx
            for j in range(proposals_one.shape[0]):
                if j not in det_idx:
                    rest_idx.append(j)

            if len(det_idx) > 0:
                # Rescale boxes from network-input coordinates back to image size
                # for visualization.
                proposals_one[:, 0] = proposals_one[:, 0] * w / float(
                    opt.image_crop_size)
                proposals_one[:, 2] = proposals_one[:, 2] * w / float(
                    opt.image_crop_size)
                proposals_one[:, 1] = proposals_one[:, 1] * h / float(
                    opt.image_crop_size)
                proposals_one[:, 3] = proposals_one[:, 3] * h / float(
                    opt.image_crop_size)

                cls_dets = proposals_one[det_idx]
                rest_dets = proposals_one[rest_idx]

            fig = plt.figure(frameon=False)
            # fig.set_size_inches(5,5*h/w)
            ax = plt.Axes(fig, [0., 0., 1., 1.])
            ax.set_axis_off()
            fig.add_axes(ax)
            a = fig.gca()
            a.set_frame_on(False)
            a.set_xticks([])
            a.set_yticks([])
            plt.axis('off')
            plt.xlim(0, w)
            plt.ylim(h, 0)

            # Draw unmentioned boxes in grey, then the grounded boxes in color.
            plt.imshow(im2show)

            # Also require det_idx: rest_dets is only defined when det_idx is non-empty.
            if len(det_idx) > 0 and len(rest_idx) > 0:
                for j in range(len(rest_dets)):
                    ax = utils.vis_detections(
                        ax, dataset_val.itoc[int(rest_dets[j, 4])],
                        rest_dets[j, :5], j, 1)
            if len(det_idx) > 0:
                for j in range(len(cls_dets)):
                    ax = utils.vis_detections(
                        ax, dataset_val.itoc[int(cls_dets[j, 4])],
                        cls_dets[j, :5], j, 0)

            # Note: hard-coded, machine-specific output directory from the original
            # script; save_name is appended to the directory name itself.
            fig.savefig(
                '/import/nobackup_mmv_ioannisp/tx301/vg_feature/visu_relation'
                + save_name + '/%d.jpg' % (img_id[i].item()),
                bbox_inches='tight',
                pad_inches=0,
                dpi=150)
            print(str(img_id[i].item()) + ': ' + sents[i])

            entry = {'image_id': img_id[i].item(), 'caption': sents[i]}
            predictions.append(entry)

    return predictions
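
A correspondingly hedged call for the fusion variant; the fusion weights and model handles below are placeholders, since the listing only shows how the arguments are consumed:

predictions = demo_fusion_models(
    opt,
    dataset_val,
    imp_pro=0.4,  # assumed: fusion weights for the implicit,
    spa_pro=0.3,  # spatial, and semantic relation models
    sem_pro=0.3,
    imp_model=imp_model,  # hypothetical pre-loaded models
    spa_model=spa_model,
    sem_model=sem_model,
    save_name='_fusion')  # appended to the hard-coded output directory name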