def gen_samples(params):
    """Evaluate pretrained object-removal GAN model(s) on a dataset split.

    For each image and each ground-truth class present in it, runs the
    generator to remove that class, re-runs the discriminator/classifier on
    the edited image, and accumulates per-class metrics: removal success,
    collateral ("false") removals, VGG perceptual distance, PSNR/SSIM,
    and (optionally) mask-segmentation IoU/precision/recall and AP.
    Results are printed as tables and optionally dumped to JSON / npz.

    params: dict of CLI-style options (model paths, dataset/split names,
    mask options, which metrics to compute, dump destinations, ...).
    Returns nothing; all output is via prints and optional file dumps.
    """
    # For fast training
    #cudnn.benchmark = True
    gpu_id = 0
    use_cuda = params['cuda']
    b_sz = params['batch_size']

    # Optionally load a single generator checkpoint to be shared by all
    # solvers below. NOTE(review): gCV is only bound when exactly one path
    # is given, but is used later whenever params['use_same_g'] is truthy —
    # more than one entry would raise NameError at L42. Confirm intent.
    if params['use_same_g']:
        if len(params['use_same_g']) == 1:
            gCV = torch.load(params['use_same_g'][0])
    # Build one Solver per model checkpoint. The first solver (i == 0) is the
    # one actually used for evaluation below; it runs in 'eval' mode and may
    # load a separate discriminator given via params['evaluating_discr'].
    solvers = []
    configs = []
    for i, mfile in enumerate(params['model']):
        model = torch.load(mfile)
        configs.append(model['arch'])
        configs[-1]['pretrained_model'] = mfile
        configs[-1]['load_encoder'] = 1
        # Skip loading the checkpoint's own discriminator when an external
        # evaluating discriminator is supplied.
        configs[-1]['load_discriminator'] = 0 if params[
            'evaluating_discr'] is not None else 1
        if i == 0:
            configs[i]['onlypretrained_discr'] = params['evaluating_discr']
        else:
            configs[i]['onlypretrained_discr'] = None

        # With an external full-resolution mask the model's own low-res mask
        # path and encoder are disabled; otherwise force mask_size to 32.
        if params['withExtMask'] and params['mask_size'] != 32:
            configs[-1]['lowres_mask'] = 0
            configs[-1]['load_encoder'] = 0
        else:
            params['mask_size'] = 32

        solvers.append(
            Solver(None,
                   None,
                   ParamObject(configs[-1]),
                   mode='test' if i > 0 else 'eval',
                   pretrainedcv=model))
        solvers[-1].G.eval()
        if configs[-1]['train_boxreconst'] > 0:
            solvers[-1].E.eval()
        # Override this solver's generator with the shared one, if requested.
        if params['use_same_g']:
            solvers[-1].no_inpainter = 0
            solvers[-1].load_pretrained_generator(gCV)
            print 'loaded generator again'

    # Only the first solver's discriminator/classifier is used for scoring.
    solvers[0].D.eval()
    solvers[0].D_cls.eval()

    # Main image dataset (with bounding-box loading enabled).
    dataset = get_dataset('',
                          '',
                          params['image_size'],
                          params['image_size'],
                          params['dataset'],
                          params['split'],
                          select_attrs=configs[0]['selected_attrs'],
                          datafile=params['datafile'],
                          bboxLoader=1,
                          bbox_size=params['box_size'],
                          randomrotate=params['randomrotate'],
                          randomscale=params['randomscale'],
                          max_object_size=params['max_object_size'],
                          use_gt_mask=configs[0]['use_gtmask_inp'],
                          onlyrandBoxes=params['extmask_type'] == 'randbox',
                          square_resize=configs[0].get('square_resize', 0)
                          if params['square_resize_override'] < 0 else
                          params['square_resize_override'],
                          filter_by_mincooccur=params['filter_by_mincooccur'],
                          only_indiv_occur=params['only_indiv_occur'])

    #gt_mask_data = get_dataset('','', params['mask_size'], params['mask_size'], params['dataset'], params['split'],
    #                         select_attrs=configs[0]['selected_attrs'], bboxLoader=0, loadMasks = True)
    #data_iter = DataLoader(targ_split, batch_size=b_sz, shuffle=True, num_workers=8)
    targ_split = dataset  #train if params['split'] == 'train' else valid if params['split'] == 'val' else test
    # data_iter is a sequence of dataset indices to visit (shuffled by default;
    # possibly replaced below by id-intersection index lists).
    data_iter = np.random.permutation(len(targ_split))

    # When measuring mask segmentation accuracy, restrict evaluation to images
    # that also have ground-truth masks available.
    if params['computeSegAccuracy']:
        gt_mask_data = get_dataset('',
                                   '',
                                   params['mask_size'],
                                   params['mask_size'],
                                   params['dataset'],
                                   params['split'],
                                   select_attrs=configs[0]['selected_attrs'],
                                   bboxLoader=0,
                                   loadMasks=True)
        commonIds = set(gt_mask_data.valid_ids).intersection(
            set(dataset.valid_ids))
        commonIndexes = [
            i for i in xrange(len(dataset.valid_ids))
            if dataset.valid_ids[i] in commonIds
        ]
        data_iter = commonIndexes

    # When using externally supplied masks, further restrict to images present
    # in the external mask source.
    if params['withExtMask'] and (params['extmask_type'] == 'mask'):
        ext_mask_data = get_dataset(
            '',
            '',
            params['mask_size'],
            params['mask_size'],
            params['dataset']
            if params['extMask_source'] == 'gt' else params['extMask_source'],
            params['split'],
            select_attrs=configs[0]['selected_attrs'],
            bboxLoader=0,
            loadMasks=True)
        curr_valid_ids = [dataset.valid_ids[i] for i in data_iter]
        commonIds = set(ext_mask_data.valid_ids).intersection(
            set(curr_valid_ids))
        commonIndexes = [
            i for i in xrange(len(dataset.valid_ids))
            if dataset.valid_ids[i] in commonIds
        ]
        data_iter = commonIndexes

    # Optional cap on the number of images evaluated.
    if params['nImages'] > -1:
        data_iter = data_iter[:params['nImages']]

    print('-----------------------------------------')
    print('%s' % (' | '.join(targ_split.selected_attrs)))
    print('-----------------------------------------')

    flatten = lambda l: [item for sublist in l for item in sublist]
    selected_attrs = configs[0]['selected_attrs']

    if params['showreconst'] and len(params['names']) > 0:
        params['names'] = flatten([[nm, nm + '-R'] for nm in params['names']])

    #discriminator.load_state_dict(cv['discriminator_state_dict'])
    c_idx = 0
    np.set_printoptions(precision=2)
    # 5px gray vertical separator strip (appears unused here — likely for
    # image montages elsewhere; kept as-is).
    padimg = np.zeros((params['image_size'], 5, 3), dtype=np.uint8)
    padimg[:, :, :] = 128
    vggLoss = VGGLoss(network='squeeze')
    cimg_cnt = 0

    # --- Per-class metric accumulators (indexed by class id) ---------------
    # The 1e-6 additions avoid division by zero when a class never occurs.
    perclass_removeSucc = np.zeros((len(selected_attrs)))
    perclass_confusion = np.zeros((len(selected_attrs), len(selected_attrs)))
    perclass_classScoreDrop = np.zeros(
        (len(selected_attrs), len(selected_attrs)))
    perclass_cooccurence = np.zeros(
        (len(selected_attrs), len(selected_attrs))) + 1e-6
    perclass_vgg = np.zeros((len(selected_attrs)))
    perclass_ssim = np.zeros((len(selected_attrs)))
    perclass_psnr = np.zeros((len(selected_attrs)))
    perclass_tp = np.zeros((len(selected_attrs)))
    perclass_fp = np.zeros((len(selected_attrs)))
    perclass_fn = np.zeros((len(selected_attrs)))
    perclass_acc = np.zeros((len(selected_attrs)))
    perclass_counts = np.zeros((len(selected_attrs))) + 1e-6
    perclass_int = np.zeros((len(selected_attrs)))
    perclass_union = np.zeros((len(selected_attrs)))
    perclass_gtsize = np.zeros((len(selected_attrs)))
    perclass_predsize = np.zeros((len(selected_attrs)))
    perclass_segacc = np.zeros((len(selected_attrs)))
    perclass_msz = np.zeros((len(selected_attrs)))
    #perclass_th = Variable(torch.FloatTensor(np.array([0., 0.5380775, -0.49303985, -0.48941165, 2.8394265, -0.37880898, 1.0709367, 1.6613332, -1.5602279, 1.2631614, 2.4104881, -0.29175103, -0.6607682, -0.2128999, -1.286599, -2.24577, -0.4130093, -1.0535073, 0.038890466, -0.6808476]))).cuda()

    # Per-class classifier decision thresholds (all zero; a tuned per-class
    # variant is commented out above).
    perclass_th = Variable(torch.FloatTensor(np.zeros(
        (len(selected_attrs))))).cuda()

    perImageRes = {'images': {}, 'overall': {}}
    total_count = 0.
    if params['computeAP']:
        allScores = []
        allGT = []
        allEditedSc = []
    # Optional dilation kernel applied to the predicted mask inside the
    # generator forward pass.
    if params['dilateMask']:
        dilateWeight = torch.ones(
            (1, 1, params['dilateMask'], params['dilateMask']))
        dilateWeight = Variable(dilateWeight, requires_grad=False).cuda()
    else:
        dilateWeight = None

    all_masks = []
    all_imgidAndCls = []

    # --- Main evaluation loop: one image at a time -------------------------
    for i in tqdm(xrange(len(data_iter))):
        #for i in tqdm(xrange(2)):
        idx = data_iter[i]
        x, real_label, boxImg, boxlabel, mask, bbox, curCls = targ_split[idx]
        cocoid = targ_split.getcocoid(idx)
        nnz_cls = real_label.nonzero()  # classes present in the image
        z_cls = (1 - real_label).nonzero()  # classes absent from the image

        z_cls = z_cls[:, 0] if len(z_cls.size()) > 1 else z_cls
        # Add a batch dimension of 1 to each tensor.
        x = x[None, ::]
        boxImg = boxImg[None, ::]
        mask = mask[None, ::]
        boxlabel = boxlabel[None, ::]
        real_label = real_label[None, ::]

        x, boxImg, mask, boxlabel = solvers[0].to_var(
            x, volatile=True), solvers[0].to_var(
                boxImg, volatile=True), solvers[0].to_var(
                    mask, volatile=True), solvers[0].to_var(boxlabel,
                                                            volatile=True)
        real_label = solvers[0].to_var(real_label, volatile=True)
        # Classifier scores on the unedited (real) image.
        _, out_cls_real = solvers[0].classify(x)
        out_cls_real = out_cls_real[0]  # Remove the singleton dimension
        pred_real_label = (out_cls_real > perclass_th)
        total_count += 1
        #;import ipdb; ipdb.set_trace()

        if params['computeAP']:
            allScores.append(out_cls_real[None, :])
            allGT.append(real_label)
            removeScores = out_cls_real.clone()

        # Classifier accuracy / tp / fp / fn bookkeeping on the real image.
        perclass_acc[(pred_real_label.float() == real_label
                      )[0, :].data.cpu().numpy().astype(np.bool)] += 1.
        if len(z_cls):
            perclass_fp[z_cls.numpy()] += pred_real_label.data.cpu()[z_cls]
        if len(nnz_cls):
            nnz_cls = nnz_cls[:, 0]
            perclass_tp[nnz_cls.numpy()] += pred_real_label.data.cpu()[nnz_cls]
            perclass_fn[
                nnz_cls.numpy()] += 1 - pred_real_label.data.cpu()[nnz_cls]

            perImageRes['images'][cocoid] = {'perclass': {}}
            if params['dump_cls_results']:
                perImageRes['images'][cocoid]['real_label'] = nnz_cls.tolist()
                perImageRes['images'][cocoid][
                    'real_scores'] = out_cls_real.data.cpu().tolist()
            if not params['eval_only_discr']:
                # Try to remove each present class in turn.
                for cid in nnz_cls:
                    # Choose the input mask for this class: ground-truth mask,
                    # external mask/box, or keep the loader-provided one.
                    if configs[0]['use_gtmask_inp']:
                        mask = solvers[0].to_var(targ_split.getGTMaskInp(
                            idx, configs[0]['selected_attrs'][cid])[None, ::],
                                                 volatile=True)
                    if params['withExtMask']:
                        if params['extmask_type'] == 'mask':
                            mask = solvers[0].to_var(
                                ext_mask_data.getbyIdAndclass(
                                    cocoid, configs[0]['selected_attrs'][cid])[
                                        None, ::],
                                volatile=True)
                        elif params['extmask_type'] == 'box':
                            mask = solvers[0].to_var(dataset.getGTMaskInp(
                                idx,
                                configs[0]['selected_attrs'][cid],
                                mask_type=2)[None, ::],
                                                     volatile=True)
                        elif params['extmask_type'] == 'randbox':
                            # Nothing to do here, mask is already set to random boxes
                            None
                    if params['computeSegAccuracy']:
                        gtMask = gt_mask_data.getbyIdAndclass(
                            cocoid, configs[0]['selected_attrs'][cid]).cuda()
                    # Target label: only class `cid` marked for removal.
                    mask_target = torch.zeros_like(real_label)
                    fake_label = real_label.clone()
                    fake_label[0, cid] = 0.
                    mask_target[0, cid] = 1
                    # Generate the edited image (and the mask actually used).
                    fake_x, mask_out = solvers[0].forward_generator(
                        x,
                        imagelabel=mask_target,
                        mask_threshold=params['mask_threshold'],
                        onlyMasks=False,
                        mask=mask,
                        withGTMask=params['withExtMask'],
                        dilate=dilateWeight)
                    # Re-classify the edited image.
                    _, out_cls_fake = solvers[0].classify(fake_x)
                    out_cls_fake = out_cls_fake[
                        0]  # Remove the singleton dimension
                    mask_out = mask_out.data[0, ::]

                    if params['dump_mask']:
                        all_masks.append(mask_out.cpu().numpy())
                        all_imgidAndCls.append((cocoid, selected_attrs[cid]))

                    perImageRes['images'][cocoid]['perclass'][
                        selected_attrs[cid]] = {}
                    # Mask-vs-GT segmentation metrics for this (image, class).
                    if params['computeSegAccuracy']:
                        union = torch.clamp((gtMask + mask_out), max=1.0).sum()
                        intersection = (gtMask * mask_out).sum()
                        img_iou = (intersection / (union + 1e-6))
                        img_acc = (gtMask == mask_out).float().mean()
                        img_recall = ((intersection / (gtMask.sum() + 1e-6)))
                        img_precision = (intersection /
                                         (mask_out.sum() + 1e-6))
                        perImageRes['images'][cocoid]['perclass'][
                            selected_attrs[cid]].update({
                                'iou': img_iou,
                                'rec': img_recall,
                                'prec': img_precision,
                                'acc': img_acc
                            })
                        perImageRes['images'][cocoid]['perclass'][
                            selected_attrs[cid]]['gtSize'] = gtMask.mean()
                        perImageRes['images'][cocoid]['perclass'][
                            selected_attrs[cid]]['predSize'] = mask_out.mean()

                        # Compute metrics now
                        perclass_counts[cid] += 1
                        perclass_int[cid] += intersection
                        perclass_union[cid] += union
                        perclass_gtsize[cid] += gtMask.sum()
                        perclass_predsize[cid] += mask_out.sum()
                        perclass_segacc[cid] += img_acc
                    if params['dump_cls_results']:
                        perImageRes['images'][cocoid]['perclass'][
                            selected_attrs[
                                cid]]['remove_scores'] = out_cls_fake.data.cpu(
                                ).tolist()
                    # Drop in classifier score for the removed class.
                    perImageRes['images'][cocoid]['perclass'][
                        selected_attrs[cid]]['diff'] = out_cls_real.data[
                            cid] - out_cls_fake.data[cid]

                    # Removal succeeded if the edited image's score for `cid`
                    # falls below the decision threshold.
                    remove_succ = float(
                        (out_cls_fake.data[cid] <
                         perclass_th[cid]))  # and (out_cls_real[cid]>0.))
                    perclass_removeSucc[cid] += remove_succ
                    # Perceptual distance between edited and original images.
                    vL = vggLoss(fake_x, x).data[0]
                    perclass_vgg[cid] += 100. * vL

                    # PSNR / SSIM between edited and original images
                    # (skimage expects HxWxC uint8-range arrays).
                    fake_x_sk = get_sk_image(fake_x)
                    x_sk = get_sk_image(x)
                    pSNR = compare_psnr(fake_x_sk, x_sk, data_range=255.)
                    ssim = compare_ssim(fake_x_sk,
                                        x_sk,
                                        data_range=255.,
                                        multichannel=True)
                    msz = mask_out.mean()
                    # Only count PSNR/SSIM when a non-empty mask was produced.
                    if msz > 0.:
                        perclass_ssim[cid] += ssim
                        perclass_psnr[cid] += pSNR

                    if params['computeAP']:
                        removeScores[cid] = out_cls_fake[cid]

                    #---------------------------------------------------------------
                    # These are classes not trying to be removed;
                    # correctly detect on real image and not detected on fake image
                    # This are collateral damage. Count these
                    #---------------------------------------------------------------
                    false_remove = fake_label.byte() * (
                        out_cls_fake < perclass_th) * (out_cls_real >
                                                       perclass_th)
                    perclass_cooccurence[cid, nnz_cls.numpy()] += 1.
                    perclass_confusion[cid,
                                       false_remove.data.cpu().numpy().
                                       astype(np.bool)[0, :]] += 1

                    perImageRes['images'][cocoid]['perclass'][
                        selected_attrs[cid]].update({
                            'remove_succ':
                            remove_succ,
                            'false_remove':
                            float(
                                false_remove.data.cpu().float().numpy().sum()),
                            'perceptual':
                            100. * vL
                        })
                if params['computeAP']:
                    allEditedSc.append(removeScores[None, :])

                # Per-image overall metrics = mean over this image's classes.
                perImageRes['images'][cocoid]['overall'] = {}
                perImageRes['images'][cocoid]['overall'][
                    'remove_succ'] = np.mean([
                        perImageRes['images'][cocoid]['perclass'][cls]
                        ['remove_succ']
                        for cls in perImageRes['images'][cocoid]['perclass']
                    ])
                perImageRes['images'][cocoid]['overall'][
                    'false_remove'] = np.mean([
                        perImageRes['images'][cocoid]['perclass'][cls]
                        ['false_remove']
                        for cls in perImageRes['images'][cocoid]['perclass']
                    ])
                perImageRes['images'][cocoid]['overall'][
                    'perceptual'] = np.mean([
                        perImageRes['images'][cocoid]['perclass'][cls]
                        ['perceptual']
                        for cls in perImageRes['images'][cocoid]['perclass']
                    ])
                perImageRes['images'][cocoid]['overall']['diff'] = np.mean([
                    perImageRes['images'][cocoid]['perclass'][cls]['diff']
                    for cls in perImageRes['images'][cocoid]['perclass']
                ])
                if params['computeSegAccuracy']:
                    perImageRes['images'][cocoid]['overall']['iou'] = np.mean([
                        perImageRes['images'][cocoid]['perclass'][cls]['iou']
                        for cls in perImageRes['images'][cocoid]['perclass']
                    ])
                    perImageRes['images'][cocoid]['overall']['acc'] = np.mean([
                        perImageRes['images'][cocoid]['perclass'][cls]['acc']
                        for cls in perImageRes['images'][cocoid]['perclass']
                    ])
                    perImageRes['images'][cocoid]['overall']['prec'] = np.mean(
                        [
                            perImageRes['images'][cocoid]['perclass'][cls]
                            ['prec'] for cls in perImageRes['images'][cocoid]
                            ['perclass']
                        ])
                    perImageRes['images'][cocoid]['overall']['rec'] = np.mean([
                        perImageRes['images'][cocoid]['perclass'][cls]['rec']
                        for cls in perImageRes['images'][cocoid]['perclass']
                    ])

        # Image has no positive classes: optionally still dump raw scores.
        elif params['dump_cls_results']:
            perImageRes['images'][cocoid] = {'perclass': {}}
            perImageRes['images'][cocoid]['real_label'] = nnz_cls.tolist()
            perImageRes['images'][cocoid][
                'real_scores'] = out_cls_real.data.cpu().tolist()

    # --- Post-loop aggregation and reporting -------------------------------
    if params['dump_mask']:
        np.savez('allMasks.npz',
                 masks=np.concatenate(all_masks).astype(np.uint8),
                 idAndClass=np.stack(all_imgidAndCls))

    # Average precision on real scores and (optionally) post-edit scores.
    if params['computeAP']:
        allScores = torch.cat(allScores, dim=0).data.cpu().numpy()
        allGT = torch.cat(allGT, dim=0).data.cpu().numpy()
        apR = computeAP(allScores, allGT)
        if not params['eval_only_discr']:
            allEditedSc = torch.cat(allEditedSc, dim=0).data.cpu().numpy()
            apEdited = computeAP(allEditedSc, allGT)
    #for i in xrange(len(selected_attrs)):
    #    pr,rec,th = precision_recall_curve(allGTArr[:,i],allPredArr[:,i]);
    #    f1s = 2*(pr*rec)/(pr+rec); mf1idx = np.argmax(f1s);
    #    #print 'Max f1 = %.2f, th =%.2f'%(f1s[mf1idx], th[mf1idx]);
    #    allMf1s.append(f1s[mf1idx])
    #    allTh.append(th[mf1idx])
    # Per-class precision / recall / F1 of the classifier on real images.
    recall = perclass_tp / (perclass_tp + perclass_fn + 1e-6)
    precision = perclass_tp / (perclass_tp + perclass_fp + 1e-6)
    f1_score = 2.0 * (recall * precision) / (recall + precision + 1e-6)

    # Restrict reported stats to classes that actually occur in the split.
    present_classes = (perclass_tp + perclass_fn) > 0.
    perclass_gt_counts = (perclass_tp + perclass_fn)
    # NOTE(review): apR is used here even when params['computeAP'] is falsy —
    # this would raise NameError; presumably computeAP is always on. Confirm.
    apROverall = (perclass_gt_counts * apR).sum() / (perclass_gt_counts.sum())
    apR = apR[present_classes]
    recall = recall[present_classes]
    f1_score = f1_score[present_classes]
    precision = precision[present_classes]
    perclass_acc = perclass_acc[present_classes]
    present_attrs = [
        att for i, att in enumerate(targ_split.selected_attrs)
        if present_classes[i]
    ]

    # Micro-averaged (overall) precision / recall / F1.
    rec_overall = perclass_tp.sum() / (perclass_tp.sum() + perclass_fn.sum() +
                                       1e-6)
    prec_overall = perclass_tp.sum() / (perclass_tp.sum() + perclass_fp.sum() +
                                        1e-6)
    f1_score_overall = 2.0 * (rec_overall * prec_overall) / (
        rec_overall + prec_overall + 1e-6)
    # --- Print metric tables: Overall | mean-over-classes | per-class ------
    print '------------------------------------------------------------'
    print '                Metrics have been computed                  '
    print '------------------------------------------------------------'
    print('Score: || %s |' % (' | '.join(
        ['%6s' % att[:6] for att in ['Overall', 'OverCls'] + present_attrs])))
    print('Acc  : || %s |' % (' | '.join([
        '  %.2f' % sc for sc in [(perclass_acc / total_count).mean()] +
        [(perclass_acc / total_count).mean()] +
        list(perclass_acc / total_count)
    ])))
    print('F1-sc: || %s |' % (' | '.join([
        '  %.2f' % sc
        for sc in [f1_score_overall] + [f1_score.mean()] + list(f1_score)
    ])))
    print('recal: || %s |' % (' | '.join([
        '  %.2f' % sc for sc in [rec_overall] + [recall.mean()] + list(recall)
    ])))
    print('prec : || %s |' % (' | '.join([
        '  %.2f' % sc
        for sc in [prec_overall] + [precision.mean()] + list(precision)
    ])))
    if params['computeAP']:
        print('AP   : || %s |' % (' | '.join(
            ['  %.2f' % sc
             for sc in [apROverall] + [apR.mean()] + list(apR)])))
    print('Count: || %s |' % (' | '.join([
        '  %4.0f' % sc for sc in [perclass_gt_counts.mean()] +
        [perclass_gt_counts.mean()] + list(perclass_gt_counts[present_classes])
    ])))
    # Editing metrics (removal success, collateral removal, perceptual,
    # PSNR, SSIM) — each normalized by per-class occurrence counts held on
    # the diagonal of perclass_cooccurence.
    if not params['eval_only_discr']:
        print('R-suc: || %s |' % (' | '.join([
            '  %.2f' % sc for sc in [(perclass_removeSucc.sum() /
                                      perclass_cooccurence.diagonal().sum())] +
            [(perclass_removeSucc / perclass_cooccurence.diagonal()).mean()] +
            list(perclass_removeSucc / perclass_cooccurence.diagonal())
        ])))
        print('R-fal: || %s |' % (' | '.join([
            '  %.2f' % sc
            for sc in [(perclass_confusion.sum() /
                        (perclass_cooccurence.sum() -
                         perclass_cooccurence.diagonal().sum()))] +
            [(perclass_confusion.sum(axis=1) /
              (perclass_cooccurence.sum(axis=1) -
               perclass_cooccurence.diagonal())).mean()] +
            list((perclass_confusion / perclass_cooccurence).sum(axis=1) /
                 (perclass_cooccurence.shape[0] - 1))
        ])))
        print('Percp: || %s |' % (' | '.join([
            '  %.2f' % sc for sc in [(perclass_vgg.sum() /
                                      perclass_cooccurence.diagonal().sum())] +
            [(perclass_vgg / perclass_cooccurence.diagonal()).mean()] +
            list(perclass_vgg / perclass_cooccurence.diagonal())
        ])))
        print('pSNR : || %s |' % (' | '.join([
            ' %.2f' % sc for sc in [(perclass_psnr.sum() /
                                     perclass_cooccurence.diagonal().sum())] +
            [(perclass_psnr / perclass_cooccurence.diagonal()).mean()] +
            list(perclass_psnr / perclass_cooccurence.diagonal())
        ])))
        print('ssim : || %s |' % (' | '.join([
            ' %.3f' % sc for sc in [(perclass_ssim.sum() /
                                     perclass_cooccurence.diagonal().sum())] +
            [(perclass_ssim / perclass_cooccurence.diagonal()).mean()] +
            list(perclass_ssim / perclass_cooccurence.diagonal())
        ])))
        if params['computeAP']:
            print('R-AP : || %s |' % (' | '.join([
                '  %.2f' % sc for sc in [apEdited.mean()] + [apEdited.mean()] +
                list(apEdited)
            ])))

        # Mask segmentation quality tables.
        if params['computeSegAccuracy']:
            print('mIou : || %s |' % (' | '.join([
                '  %.2f' % sc for sc in [(perclass_int.sum() /
                                          (perclass_union + 1e-6).sum())] +
                [(perclass_int / (perclass_union + 1e-6)).mean()] +
                list(perclass_int / (perclass_union + 1e-6))
            ])))
            print('mRec : || %s |' % (' | '.join([
                '  %.2f' % sc for sc in [(perclass_int.sum() /
                                          (perclass_gtsize + 1e-6).sum())] +
                [(perclass_int / (perclass_gtsize + 1e-6)).mean()] +
                list(perclass_int / (perclass_gtsize + 1e-6))
            ])))
            print('mPrc : || %s |' % (' | '.join([
                '  %.2f' % sc
                for sc in [(perclass_int.sum() / (perclass_predsize.sum()))] +
                [(perclass_int / (perclass_predsize + 1e-6)).mean()] +
                list(perclass_int / (perclass_predsize + 1e-6))
            ])))
            print('mSzR : || %s |' % (' | '.join([
                '  %.2f' % sc for sc in [(perclass_predsize.sum() /
                                          (perclass_gtsize.sum()))] +
                [(perclass_predsize / (perclass_gtsize + 1e-6)).mean()] +
                list(perclass_predsize / (perclass_gtsize + 1e-6))
            ])))
            print('Acc  : || %s |' % (' | '.join([
                '  %.2f' % sc
                for sc in [(perclass_segacc.sum() / (perclass_counts.sum()))] +
                [(perclass_segacc / (perclass_counts + 1e-6)).mean()] +
                list(perclass_segacc / (perclass_counts + 1e-6))
            ])))
            # Mean predicted-mask size as a % of the mask area.
            print('mSz  : || %s |' % (' | '.join([
                '  %.1f' % sc
                for sc in [(100. *
                            (perclass_predsize.sum() /
                             (params['mask_size'] * params['mask_size'] *
                              perclass_counts).sum()))] +
                [(100. * (perclass_predsize /
                          (params['mask_size'] * params['mask_size'] *
                           perclass_counts + 1e-6))).mean()] +
                list((100. * perclass_predsize) /
                     (params['mask_size'] * params['mask_size'] *
                      perclass_counts + 1e-6))
            ])))

        # Dataset-level summary stored alongside the per-image results.
        perImageRes['overall'] = {'iou': 0., 'rec': 0., 'prec': 0., 'acc': 0.}
        perImageRes['overall']['remove_succ'] = (
            perclass_removeSucc / perclass_cooccurence.diagonal()).mean()
        perImageRes['overall']['false_remove'] = (perclass_confusion /
                                                  perclass_cooccurence).mean()
        perImageRes['overall']['perceptual'] = (
            perclass_vgg / perclass_cooccurence.diagonal()).mean()
        if params['computeSegAccuracy']:
            perImageRes['overall']['iou'] = (perclass_int /
                                             (perclass_union + 1e-6)).mean()
            perImageRes['overall']['acc'] = (perclass_segacc /
                                             (perclass_counts + 1e-6)).mean()
            perImageRes['overall']['prec'] = (
                perclass_int / (perclass_predsize + 1e-6)).mean()
            perImageRes['overall']['psize'] = (perclass_predsize).mean()
            perImageRes['overall']['psize_rel'] = (
                perclass_predsize / (perclass_gtsize + 1e-6)).mean()
            perImageRes['overall']['rec'] = (perclass_int /
                                             (perclass_gtsize + 1e-6)).mean()
        if params['computeAP']:
            perImageRes['overall']['ap-orig'] = list(apR)
            perImageRes['overall']['ap-edit'] = list(apEdited)

    # Dump the full per-image results dict as JSON, named after the split
    # and the first model checkpoint's basename.
    if params['dump_perimage_res']:
        json.dump(
            perImageRes,
            open(
                join(
                    params['dump_perimage_res'], params['split'] + '_' +
                    basename(params['model'][0]).split('.')[0]), 'w'))
# Exemplo n.º 2 (extraction artifact from the example-site scrape; kept as a comment so the file stays parseable)
def gen_samples(params):
    """Interactively browse edit results from one or more trained models.

    Loads a Solver per checkpoint in params['model'], runs every model's
    generator on samples from the requested dataset split, and shows the
    input image next to each model's output in an OpenCV window.
    Keyboard controls (from the key-handling branch below): 'q' quits,
    'b' steps back one image, 's' saves the current montage (or individual
    images when params['savesepimages'] is set), any other key advances.
    When params['compute_deform_stats'] is set, deformation statistics are
    accumulated and printed instead of displaying images.
    """
    # For fast training
    #cudnn.benchmark = True
    gpu_id = 0
    use_cuda = params['cuda']
    b_sz = params['batch_size']

    # Default architecture hyper-parameters; the per-checkpoint 'arch'
    # config loaded below is what actually drives each Solver.
    g_conv_dim = 64
    d_conv_dim = 64
    c_dim = 5
    c2_dim = 8
    g_repeat_num = 6
    d_repeat_num = 6
    select_attrs = []

    # Optionally share a single pretrained generator checkpoint across all
    # models (loaded into every solver after construction, further below).
    if params['use_same_g']:
        if len(params['use_same_g']) == 1:
            gCV = torch.load(params['use_same_g'][0])
    solvers = []
    configs = []
    # Build one Solver per model checkpoint, patching each stored config with
    # backward-compatibility defaults for fields older checkpoints lack.
    for i, mfile in enumerate(params['model']):
        model = torch.load(mfile)
        configs.append(model['arch'])
        configs[-1]['pretrained_model'] = mfile
        configs[-1]['load_encoder'] = 1
        configs[-1]['load_discriminator'] = 0
        configs[-1]['image_size'] = params['image_size']
        if 'g_downsamp_layers' not in configs[-1]:
            configs[-1]['g_downsamp_layers'] = 2
        if 'g_dil_start' not in configs[-1]:
            configs[-1]['g_dil_start'] = 0
            configs[-1]['e_norm_type'] = 'drop'
            configs[-1]['e_ksize'] = 4
        # When an external mask is supplied at a non-default resolution,
        # disable the low-res mask path and skip loading the encoder.
        if len(params['withExtMask']) and params['mask_size'] != 32:
            if params['withExtMask'][i]:
                configs[-1]['lowres_mask'] = 0
                configs[-1]['load_encoder'] = 0

        solvers.append(
            Solver(None,
                   None,
                   ParamObject(configs[-1]),
                   mode='test',
                   pretrainedcv=model))
        solvers[-1].G.eval()
        #solvers[-1].D.eval()
        if configs[-1]['train_boxreconst'] > 0 and solvers[-1].E is not None:
            solvers[-1].E.eval()
        if params['use_same_g']:
            solvers[-1].load_pretrained_generator(gCV)

    # Pre-build one dilation kernel per model (or None) used to dilate the
    # predicted masks during generation.
    if len(params['dilateMask']):
        assert (len(params['model']) == len(params['dilateMask']))
        dilateWeightAll = []
        for di in range(len(params['dilateMask'])):
            if params['dilateMask'][di] > 0:
                dilateWeight = torch.ones(
                    (1, 1, params['dilateMask'][di], params['dilateMask'][di]))
                dilateWeight = Variable(dilateWeight,
                                        requires_grad=False).cuda()
            else:
                dilateWeight = None
            dilateWeightAll.append(dilateWeight)
    else:
        dilateWeightAll = [None for i in range(len(params['model']))]

    dataset = get_dataset('',
                          '',
                          params['image_size'],
                          params['image_size'],
                          params['dataset'],
                          params['split'],
                          select_attrs=configs[0]['selected_attrs'],
                          datafile=params['datafile'],
                          bboxLoader=1,
                          bbox_size=params['box_size'],
                          randomrotate=params['randomrotate'],
                          randomscale=params['randomscale'],
                          max_object_size=params['max_object_size'],
                          use_gt_mask=configs[0]['use_gtmask_inp'],
                          n_boxes=params['n_boxes'],
                          onlyrandBoxes=(params['extmask_type'] == 'randbox'))
    #data_iter = DataLoader(targ_split, batch_size=b_sz, shuffle=True, num_workers=8)
    targ_split = dataset  #train if params['split'] == 'train' else valid if params['split'] == 'val' else test
    # Default visit order: a random permutation of dataset indices; may be
    # overridden below by sort_by / show_ids.
    data_iter = np.random.permutation(len(targ_split))

    # Separate dataset that serves ground-truth (or external) segmentation
    # masks at mask resolution, used when external masks are requested.
    if len(params['withExtMask']) and (params['extmask_type'] == 'mask'):
        gt_mask_data = get_dataset(
            '',
            '',
            params['mask_size'],
            params['mask_size'],
            params['dataset']
            if params['extMask_source'] == 'gt' else params['extMask_source'],
            params['split'],
            select_attrs=configs[0]['selected_attrs'],
            bboxLoader=0,
            loadMasks=True)
    # Optionally re-order images by a per-image score read from previously
    # dumped result JSON files (highest score first).
    if len(params['sort_by']):
        resFiles = [json.load(open(fil, 'r')) for fil in params['sort_by']]
        for i in range(len(resFiles)):
            #if params['sort_score'] not in resFiles[i]['images'][resFiles[i]['images'].keys()[0]]['overall']:
            for k in resFiles[i]['images']:
                img = resFiles[i]['images'][k]
                # Ensure every image has an 'overall' entry holding the mean
                # of the per-class sort scores.
                if 'overall' in resFiles[i]['images'][k]:
                    resFiles[i]['images'][k]['overall'][
                        params['sort_score']] = np.mean([
                            img['perclass'][cls][params['sort_score']]
                            for cls in img['perclass']
                        ])
                else:
                    resFiles[i]['images'][k]['overall'] = {}
                    resFiles[i]['images'][k]['overall'][
                        params['sort_score']] = np.mean([
                            img['perclass'][cls][params['sort_score']]
                            for cls in img['perclass']
                        ])
        idToScore = {
            int(k): resFiles[0]['images'][k]['overall'][params['sort_score']]
            for k in resFiles[0]['images']
        }
        idToScore = OrderedDict(
            reversed(sorted(list(idToScore.items()), key=lambda t: t[1])))
        cocoIdToindex = {v: i for i, v in enumerate(dataset.valid_ids)}
        data_iter = [cocoIdToindex[k] for k in idToScore]
        dataIt2id = {cocoIdToindex[k]: str(k) for k in idToScore}

    # Explicit list of image ids to show overrides any previous ordering.
    if len(params['show_ids']) > 0:
        cocoIdToindex = {v: i for i, v in enumerate(dataset.valid_ids)}
        data_iter = [cocoIdToindex[k] for k in params['show_ids']]

    print(len(data_iter))

    print('-----------------------------------------')
    print(('%s' % (' | '.join(targ_split.selected_attrs))))
    print('-----------------------------------------')

    flatten = lambda l: [item for sublist in l for item in sublist]

    # With reconstruction display enabled, each model contributes two columns
    # (edit and re-edit), so duplicate the display names accordingly.
    if params['showreconst'] and len(params['names']) > 0:
        params['names'] = flatten([[nm, nm + '-R'] for nm in params['names']])

    #discriminator.load_state_dict(cv['discriminator_state_dict'])
    c_idx = 0
    np.set_printoptions(precision=2)
    # Grey 5-pixel-wide vertical separator strip used between montage tiles.
    padimg = np.zeros((params['image_size'], 5, 3), dtype=np.uint8)
    padimg[:, :, :] = 128
    if params['showperceptionloss']:
        vggLoss = VGGLoss(network='squeeze')
    cimg_cnt = 0
    # Per-deformation-level accumulators for compute_deform_stats mode.
    mean_hist = [[], [], []]
    max_hist = [[], [], []]
    lengths_hist = [[], [], []]
    if len(params['n_iter']) == 0:
        params['n_iter'] = [0] * len(params['model'])
    # Main interactive loop: fetch a sample, run all models, display montage,
    # and react to keyboard input. Exits on 'q' (or after stats collection).
    while True:
        cimg_cnt += 1
        #import ipdb; ipdb.set_trace()
        idx = data_iter[c_idx]
        x, real_label, boxImg, boxlabel, mask, bbox, curCls = targ_split[
            data_iter[c_idx]]
        fp = [targ_split.getfilename(data_iter[c_idx])]

        #if configs[0]['use_gtmask_inp']:
        #    mask = mask[1:,::]

        # Add a leading batch dimension of 1 to every tensor.
        x = x[None, ::]
        boxImg = boxImg[None, ::]
        mask = mask[None, ::]
        boxlabel = boxlabel[None, ::]
        real_label = real_label[None, ::]

        # volatile=True: inference-only Variables (legacy pre-0.4 PyTorch API).
        x, boxImg, mask, boxlabel = solvers[0].to_var(
            x, volatile=True), solvers[0].to_var(
                boxImg, volatile=True), solvers[0].to_var(
                    mask, volatile=True), solvers[0].to_var(boxlabel,
                                                            volatile=True)
        real_label = solvers[0].to_var(real_label, volatile=True)

        fake_image_list = [x]
        if params['showmask']:
            mask_image_list = [x - x]
        else:
            # Second column: input with the masked region blanked out.
            fake_image_list.append(x * (1 - mask) + mask)

        deformList = [[], []]
        if len(real_label[0, :].nonzero()):
            #rand_idx = random.choice(real_label[0,:].nonzero()).data[0]
            rand_idx = curCls[0]
            print(configs[0]['selected_attrs'][rand_idx])
            # Replace the loader's mask with an externally sourced one when
            # requested (segmentation mask, GT box, or keep the random boxes).
            if len(params['withExtMask']):
                cocoid = targ_split.getcocoid(idx)
                if params['extmask_type'] == 'mask':
                    mask = solvers[0].to_var(gt_mask_data.getbyIdAndclass(
                        cocoid,
                        configs[0]['selected_attrs'][rand_idx])[None, ::],
                                             volatile=True)
                elif params['extmask_type'] == 'box':
                    mask = solvers[0].to_var(dataset.getGTMaskInp(
                        idx,
                        configs[0]['selected_attrs'][rand_idx],
                        mask_type=2)[None, ::],
                                             volatile=True)
                elif params['extmask_type'] == 'randbox':
                    # Nothing to do here, mask is already set to random boxes
                    None
        else:
            rand_idx = curCls[0]
        if params['showdiff']:
            diff_image_list = [x - x] if params['showmask'] else [x - x, x - x]
        # Run each model's generator on the current sample.
        for i in range(len(params['model'])):
            if configs[i]['use_gtmask_inp']:
                mask = solvers[0].to_var(targ_split.getGTMaskInp(
                    idx,
                    configs[0]['selected_attrs'][rand_idx],
                    mask_type=configs[i]['use_gtmask_inp'])[None, ::],
                                         volatile=True)
            if len(params['withExtMask']) or params['no_maskgen']:
                withGTMask = True if params['no_maskgen'] else params[
                    'withExtMask'][i]
            else:
                withGTMask = False

            # Choose label inputs depending on this model's training mode.
            if configs[i]['train_boxreconst'] == 3:
                mask_target = torch.zeros_like(real_label)
                if len(real_label[0, :].nonzero()):
                    mask_target[0, rand_idx] = 1
                # This variable informs to the mask generator, which class to generate for
                boxlabelInp = boxlabel

            elif configs[i]['train_boxreconst'] == 2:
                boxlabelfake = torch.zeros_like(boxlabel)
                if configs[i]['use_box_label'] == 2:
                    boxlabelInp = torch.cat([boxlabel, boxlabelfake], dim=1)
                    if params['showreconst']:
                        boxlabelInpRec = torch.cat([boxlabelfake, boxlabel],
                                                   dim=1)
                mask_target = real_label
            else:
                boxlabelInp = boxlabel
                mask_target = real_label
            # Forward pass; get_feat=True additionally returns deformations.
            if params['showdeform']:
                img, maskOut, deform = solvers[i].forward_generator(
                    x,
                    boxImg=boxImg,
                    mask=mask,
                    imagelabel=mask_target,
                    boxlabel=boxlabelInp,
                    get_feat=True,
                    mask_threshold=params['mask_threshold'],
                    withGTMask=withGTMask,
                    dilate=dilateWeightAll[i],
                    n_iter=params['n_iter'][i])
                fake_image_list.append(img)
                deformList.append(deform)
            else:
                img, maskOut = solvers[i].forward_generator(
                    x,
                    boxImg=boxImg,
                    mask=mask,
                    imagelabel=mask_target,
                    boxlabel=boxlabelInp,
                    mask_threshold=params['mask_threshold'],
                    withGTMask=withGTMask,
                    dilate=dilateWeightAll[i],
                    n_iter=params['n_iter'][i])
                fake_image_list.append(img)
            if params['showmask']:
                # Replicate the single-channel mask into 3 channels for display.
                mask_image_list.append(
                    solvers[i].getImageSizeMask(maskOut)[:, [0, 0, 0], ::])
            if params['showdiff']:
                diff_image_list.append(x - fake_image_list[-1])
            # Optionally feed the edited image back through the generator to
            # show the reconstruction alongside the edit.
            if params['showreconst']:
                if params['showdeform']:
                    img, maskOut, deform = solvers[i].forward_generator(
                        fake_image_list[-1],
                        boxImg=boxImg,
                        mask=mask,
                        imagelabel=mask_target,
                        boxlabel=boxlabelInp,
                        get_feat=True,
                        mask_threshold=params['mask_threshold'],
                        withGTMask=withGTMask,
                        dilate=dilateWeightAll[i],
                        n_iter=params['n_iter'][i])
                    fake_image_list.append(img)
                    deformList.append(deform)
                else:
                    img, maskOut = solvers[i].forward_generator(
                        fake_image_list[-1],
                        boxImg=boxImg,
                        mask=mask,
                        imagelabel=mask_target,
                        boxlabel=boxlabelInp,
                        mask_threshold=params['mask_threshold'],
                        withGTMask=withGTMask,
                        dilate=dilateWeightAll[i],
                        n_iter=params['n_iter'][i])
                    fake_image_list.append(img)
                if params['showdiff']:
                    diff_image_list.append(x - fake_image_list[-1])

        if not params['compute_deform_stats']:
            # Assemble the display montage: outputs, then optional deform /
            # mask / diff rows, then optional name and score strips.
            img = make_image(fake_image_list, padimg)
            if params['showdeform']:
                defImg = make_image_with_deform(
                    fake_image_list, deformList,
                    np.vstack([padimg, padimg, padimg]))
                img = np.vstack([img, defImg])
            if params['showmask']:
                imgmask = make_image(mask_image_list, padimg)
                img = np.vstack([img, imgmask])
            if params['showdiff']:
                imgdiff = make_image(diff_image_list, padimg)
                img = np.vstack([img, imgdiff])
            if len(params['names']) > 0:
                nameList = ['Input'
                            ] + params['names'] if params['showmask'] else [
                                'Input', 'Masked Input'
                            ] + params['names']
                imgNames = np.hstack(
                    flatten([[
                        make_image_with_text((32, x.size(3), 3), nm),
                        padimg[:32, :, :].astype(np.uint8)
                    ] for nm in nameList]))
                img = np.vstack([imgNames, img])
            if len(params['sort_by']):
                clsname = configs[0]['selected_attrs'][rand_idx]
                cocoid = dataIt2id[data_iter[c_idx]]
                # Score per column: real-image score first, then one entry
                # per model, pulled from the dumped result files.
                curr_class_iou = [
                    resFiles[i]['images'][cocoid]['real_scores'][rand_idx]
                ] + [
                    resFiles[i]['images'][cocoid]['perclass'][clsname]
                    [params['sort_score']] for i in range(len(params['model']))
                ]
                if params['showperceptionloss']:
                    textToPrint = [
                        'P:%.2f, S:%.1f' % (vggLoss(
                            fake_image_list[0],
                            fake_image_list[i]).data[0], curr_class_iou[i])
                        for i in range(len(fake_image_list))
                    ]
                else:
                    textToPrint = [
                        'S:%.1f' % (curr_class_iou[i])
                        for i in range(len(fake_image_list))
                    ]
                if len(params['show_also']):
                    # Additional data to print
                    for val in params['show_also']:
                        curval = [0.] + [
                            resFiles[i]['images'][cocoid]['perclass'][clsname]
                            [val][rand_idx]
                            for i in range(len(params['model']))
                        ]
                        # NOTE(review): val[0] prints only the first character
                        # of the metric name — presumably a deliberate
                        # abbreviation to fit the strip; confirm.
                        textToPrint = [
                            txt + ' %s:%.1f' % (val[0], curval[i])
                            for i, txt in enumerate(textToPrint)
                        ]

                imgScore = np.hstack(
                    flatten([[
                        make_image_with_text((32, x.size(3), 3),
                                             textToPrint[i]),
                        padimg[:32, :, :].astype(np.uint8)
                    ] for i in range(len(fake_image_list))]))
                img = np.vstack([img, imgScore])
            elif params['showperceptionloss']:
                imgScore = np.hstack(
                    flatten([[
                        make_image_with_text(
                            (32, x.size(3), 3),
                            '%.2f' % vggLoss(fake_image_list[0],
                                             fake_image_list[i]).data[0]),
                        padimg[:32, :, :].astype(np.uint8)
                    ] for i in range(len(fake_image_list))]))
                img = np.vstack([img, imgScore])

            #if params['showmask']:
            #    imgmask = make_image(mask_list)
            #    img = np.vstack([img, imgmask])
            #if params['compmodel']:
            #    imgcomp = make_image(fake_image_list_comp)
            #    img = np.vstack([img, imgcomp])
            #    if params['showdiff']:
            #        imgdiffcomp = make_image([fimg - fake_image_list_comp[0] for fimg in fake_image_list_comp])
            #        img = np.vstack([img, imgdiffcomp])
            # Show the montage and wait for a key: q=quit, b=back, s=save,
            # anything else advances to the next image.
            cv2.imshow(
                'frame', img if params['scaleDisp'] == 0 else cv2.resize(
                    img, None, fx=params['scaleDisp'], fy=params['scaleDisp']))
            keyInp = cv2.waitKey(0)

            if keyInp & 0xFF == ord('q'):
                break
            elif keyInp & 0xFF == ord('b'):
                #print keyInp & 0xFF
                c_idx = c_idx - 1
            elif (keyInp & 0xFF == ord('s')):
                #sample_dir = join(params['sample_dump_dir'], basename(params['model'][0]).split('.')[0])
                sample_dir = join(
                    params['sample_dump_dir'],
                    '_'.join([params['split']] + params['names']))
                if not exists(sample_dir):
                    makedirs(sample_dir)
                fnames = ['%s.png' % splitext(basename(f))[0] for f in fp]
                fpaths = [join(sample_dir, f) for f in fnames]
                imgSaveName = fpaths[0]
                if params['savesepimages']:
                    saveIndividImages(fake_image_list, mask_image_list,
                                      nameList, sample_dir, fp,
                                      configs[0]['selected_attrs'][rand_idx])
                else:
                    print('Saving into file: ' + imgSaveName)
                    cv2.imwrite(imgSaveName, img)
                c_idx += 1
            else:
                c_idx += 1
        else:
            # Stats-only mode: accumulate deformation statistics per level.
            for di in range(len(deformList)):
                if len(deformList[di]) > 0 and len(deformList[di][0]) > 0:
                    for dLidx, d in enumerate(deformList[di]):
                        lengths, mean, maxl = compute_deform_statistics(
                            d[1], d[0])
                        mean_hist[dLidx].append(mean)
                        max_hist[dLidx].append(maxl)
                        lengthsH = np.histogram(lengths,
                                                bins=np.arange(0, 128, 0.5))[0]
                        if lengths_hist[dLidx] == []:
                            lengths_hist[dLidx] = lengthsH
                        else:
                            lengths_hist[dLidx] += lengthsH

        # NOTE(review): the comparison is '<', so the summary prints (and the
        # loop exits) while cimg_cnt is still below the requested count —
        # looks like it was meant to be '>=' ; confirm before changing.
        if params['compute_deform_stats'] and (cimg_cnt <
                                               params['compute_deform_stats']):
            print(np.mean(mean_hist[0]))
            print(np.mean(mean_hist[1]))
            print(np.mean(mean_hist[2]))
            print(np.mean(max_hist[0]))
            print(np.mean(max_hist[1]))
            print(np.mean(max_hist[2]))

            print(lengths_hist[0])
            print(lengths_hist[1])
            print(lengths_hist[2])
            break
# Exemplo n.º 3
# 0
def gen_samples(params):
    """Evaluate the first model's generator over a dataset split.

    For every image in the split (optionally restricted by
    params['nImages'] or by the ids shared with the external-mask dataset),
    runs solvers[0]'s generator to produce an edited image, then measures
    perceptual distance (squeeze-net VGGLoss), pSNR and SSIM between the
    edit and the original.  Metrics are accumulated into bins of relative
    mask size (bin width 0.1) and printed as per-bin plus overall averages.
    Per-image results are optionally dumped as JSON to
    params['dump_perimage_res'].

    Fix: removed a leftover debugging block that printed array shapes,
    saved the first image pairs via matplotlib, and aborted the loop after
    ten images ('if i == 9: break'), which made the reported metrics cover
    at most 10 images instead of the full split.  Also normalized three
    Python-2 'print' statements to the print() calls used everywhere else
    in this function.
    """
    # For fast training
    #cudnn.benchmark = True
    gpu_id = 0
    use_cuda = params['cuda']
    b_sz = params['batch_size']

    # Build one Solver per checkpoint; only solvers[0] is used for the
    # evaluation below, but all are loaded for interface parity with the
    # sibling gen_samples variants.
    solvers = []
    configs = []
    for i, mfile in enumerate(params['model']):
        model = torch.load(mfile)
        configs.append(model['arch'])
        configs[-1]['pretrained_model'] = mfile
        configs[-1]['load_encoder'] = 1
        configs[-1]['load_discriminator'] = 0
        configs[-1]['image_size'] = params['image_size']
        if i == 0:
            configs[i]['onlypretrained_discr'] = params['evaluating_discr']
        else:
            configs[i]['onlypretrained_discr'] = None

        # External masks at non-default resolution bypass the low-res mask
        # path and the encoder.
        if params['withExtMask'] and params['mask_size'] != 32:
            configs[-1]['lowres_mask'] = 0
            configs[-1]['load_encoder'] = 0

        solvers.append(
            Solver(None,
                   None,
                   ParamObject(configs[-1]),
                   mode='test' if i > 0 else 'eval',
                   pretrainedcv=model))
        solvers[-1].G.eval()
        if configs[-1]['train_boxreconst'] > 0:
            solvers[-1].E.eval()

    solvers[0].D.eval()
    solvers[0].D_cls.eval()

    dataset = get_dataset(
        '',
        '',
        params['image_size'],
        params['image_size'],
        params['dataset'],
        params['split'],
        select_attrs=configs[0]['selected_attrs'],
        datafile=params['datafile'],
        bboxLoader=1,
        bbox_size=params['box_size'],
        randomrotate=params['randomrotate'],
        randomscale=params['randomscale'],
        max_object_size=params['max_object_size'],
        use_gt_mask=0,
        n_boxes=params['n_boxes']
    )  #configs[0]['use_gtmask_inp'])#, imagenet_norm=(configs[0]['use_imagenet_pretrained'] is not None))

    #gt_mask_data = get_dataset('','', params['mask_size'], params['mask_size'], params['dataset'], params['split'],
    #                         select_attrs=configs[0]['selected_attrs'], bboxLoader=0, loadMasks = True)
    #data_iter = DataLoader(targ_split, batch_size=b_sz, shuffle=True, num_workers=8)
    targ_split = dataset  #train if params['split'] == 'train' else valid if params['split'] == 'val' else test
    # Visit order: random permutation, optionally truncated to nImages.
    data_iter = np.random.permutation(
        len(targ_split) if params['nImages'] == -1 else params['nImages'])

    if params['withExtMask'] or params['computeSegAccuracy']:
        # Mask dataset (ground truth or an external source); restrict the
        # evaluation to ids present in both datasets.
        gt_mask_data = get_dataset(
            '',
            '',
            params['mask_size'],
            params['mask_size'],
            params['dataset']
            if params['extMask_source'] == 'gt' else params['extMask_source'],
            params['split'],
            select_attrs=configs[0]['selected_attrs'],
            bboxLoader=0,
            loadMasks=True)
        commonIds = set(gt_mask_data.valid_ids).intersection(
            set(dataset.valid_ids))
        commonIndexes = [
            i for i in xrange(len(dataset.valid_ids))
            if dataset.valid_ids[i] in commonIds
        ]
        data_iter = commonIndexes if params[
            'nImages'] == -1 else commonIndexes[:params['nImages']]

    print('-----------------------------------------')
    print('%s' % (' | '.join(targ_split.selected_attrs)))
    print('-----------------------------------------')

    flatten = lambda l: [item for sublist in l for item in sublist]
    selected_attrs = configs[0]['selected_attrs']

    if params['showreconst'] and len(params['names']) > 0:
        params['names'] = flatten([[nm, nm + '-R'] for nm in params['names']])

    #discriminator.load_state_dict(cv['discriminator_state_dict'])
    c_idx = 0
    np.set_printoptions(precision=2)
    padimg = np.zeros((params['image_size'], 5, 3), dtype=np.uint8)
    padimg[:, :, :] = 128
    vggLoss = VGGLoss(network='squeeze')
    cimg_cnt = 0

    # Metric accumulators, one slot per relative-mask-size bin.
    mask_bin_size = 0.1
    n_bins = int(1.0 / mask_bin_size)
    vLTotal = np.zeros((n_bins, ))
    pSNRTotal = np.zeros((n_bins, ))
    ssimTotal = np.zeros((n_bins, ))
    total_count = np.zeros((n_bins, )) + 1e-8  # epsilon avoids divide-by-zero

    perImageRes = {'images': {}, 'overall': {}}
    # Optional square dilation kernel applied to predicted masks.
    if params['dilateMask']:
        dilateWeight = torch.ones(
            (1, 1, params['dilateMask'], params['dilateMask']))
        dilateWeight = Variable(dilateWeight, requires_grad=False).cuda()
    else:
        dilateWeight = None

    for i in tqdm(xrange(len(data_iter))):
        idx = data_iter[i]
        x, real_label, boxImg, boxlabel, mask, bbox, curCls = targ_split[idx]
        cocoid = targ_split.getcocoid(idx)
        nnz_cls = real_label.nonzero()
        z_cls = (1 - real_label).nonzero()

        z_cls = z_cls[:, 0]
        # Add a batch dimension of 1 to every tensor.
        x = x[None, ::]
        boxImg = boxImg[None, ::]
        mask = mask[None, ::]
        boxlabel = boxlabel[None, ::]
        real_label = real_label[None, ::]

        # volatile=True: inference-only Variables (legacy pre-0.4 PyTorch API).
        x, boxImg, mask, boxlabel = solvers[0].to_var(
            x, volatile=True), solvers[0].to_var(
                boxImg, volatile=True), solvers[0].to_var(
                    mask, volatile=True), solvers[0].to_var(boxlabel,
                                                            volatile=True)
        real_label = solvers[0].to_var(real_label, volatile=True)
        fake_x, mask_out = solvers[0].forward_generator(
            x,
            imagelabel=real_label,
            mask_threshold=params['mask_threshold'],
            onlyMasks=False,
            mask=mask,
            withGTMask=False,
            dilate=dilateWeight)
        vL = vggLoss(fake_x, x).data[0]
        # Change the image range to 0, 255
        fake_x_sk = get_sk_image(fake_x)
        x_sk = get_sk_image(x)
        pSNR = compare_psnr(fake_x_sk, x_sk, data_range=255.)
        ssim = compare_ssim(fake_x_sk,
                            x_sk,
                            data_range=255.,
                            multichannel=True)
        # Relative mask size (mean of the binary mask) selects the bin.
        msz = mask.data.cpu().numpy().mean()
        if msz > 0.:
            msz_bin = int((msz - 1e-8) / mask_bin_size)

            perImageRes['images'][cocoid] = {'overall': {}}
            perImageRes['images'][cocoid]['overall']['perceptual'] = float(vL)
            perImageRes['images'][cocoid]['overall']['pSNR'] = float(pSNR)
            perImageRes['images'][cocoid]['overall']['ssim'] = float(ssim)
            perImageRes['images'][cocoid]['overall']['mask_size'] = float(msz)
            vLTotal[msz_bin] += vL
            pSNRTotal[msz_bin] += pSNR
            ssimTotal[msz_bin] += ssim
            total_count[msz_bin] += 1

    print('------------------------------------------------------------')
    print('                Metrics have been computed                  ')
    print('------------------------------------------------------------')
    # Per metric: overall average first, then one column per mask-size bin.
    print('Percp: || %s |' % (' | '.join([
        '  %.3f' % sc for sc in [vLTotal.sum() / total_count.sum()] +
        list(vLTotal / total_count)
    ])))
    print('pSNR : || %s |' % (' | '.join([
        '  %.3f' % sc for sc in [pSNRTotal.sum() / total_count.sum()] +
        list(pSNRTotal / total_count)
    ])))
    print('ssim : || %s |' % (' | '.join([
        '  %.3f' % sc for sc in [ssimTotal.sum() / total_count.sum()] +
        list(ssimTotal / total_count)
    ])))
    if params['dump_perimage_res']:
        json.dump(
            perImageRes,
            open(
                join(
                    params['dump_perimage_res'], params['split'] + '_' +
                    basename(params['model'][0]).split('.')[0]), 'w'))
# ipdb session transcript (debugging residue, kept for reference):
# ipdb> plt.imshow(((fake_x.data.cpu().numpy()[2,[0,1,2],:,:].transpose(1,2,0)+1.0)*255./2.0).astype(np.uint8)); plt.show()
# <matplotlib.image.AxesImage object at 0x7fb64908acd0>
# ipdb> plt.imshow(((diffimg.data.cpu().numpy()[2,[0,1,2],:,:].transpose(1,2,0)+1.0)*255./2.0).astype(np.uint8)); plt.show()
# <matplotlib.image.AxesImage object at 0x7fb648fcccd0>
# ipdb> plt.imshow(((1-mask.data.cpu().numpy()[2,[0,1,2],:,:].transpose(1,2,0))*255.).astype(np.uint8)); plt.show()
# <matplotlib.image.AxesImage object at 0x7fb648f1e410>




#
from utils.data_loader_stargan import get_dataset
import matplotlib.pyplot as plt
import numpy as np

# Usage example (scratch): load the box-reconstruction dataset and visualize
# one sample — the image, the box crop, the mask, and the masked image.
# NOTE(review): `config` is not defined in this chunk; presumably this snippet
# was copied from a training script where `config` holds parsed options —
# confirm before running.
dataset = get_dataset(config.celebA_image_path, config.metadata_path, config.celebA_crop_size, config.image_size, config.dataset, config.mode, select_attrs=config.selected_attrs, datafile=config.datafile,bboxLoader=config.train_boxreconst)
# A sample unpacks to (image, image_label, box_image, box_label, mask) here.
img, imgLab, boxImg, boxLab, mask = dataset[0]

# The (x+1)*255/2 transform below implies CHW tensors in [-1, 1]; transpose to
# HWC uint8 for matplotlib display.
plt.figure();plt.imshow(((img.numpy()[[0,1,2],:,:].transpose(1,2,0)+1.0)*255./2.0).astype(np.uint8)); plt.figure();plt.imshow(((boxImg.numpy()[[0,1,2],:,:].transpose(1,2,0)+1.0)*255./2.0).astype(np.uint8));plt.figure(); plt.imshow(((mask.numpy()[[0,0,0],:,:].transpose(1,2,0)+1.0)*255./2.0).astype(np.uint8));plt.figure();plt.imshow((((img*mask).numpy()[[0,1,2],:,:].transpose(1,2,0)+1.0)*255./2.0).astype(np.uint8));plt.show()




###----------------------------------------------------------------------------------------------------------------

import numpy as np
from collections import defaultdict
from tqdm import tqdm


# pwd  -- stray interactive-shell/IPython command left in the paste; not valid Python