def gen_samples(params):
    # For fast training
    #cudnn.benchmark = True
    gpu_id = 0
    use_cuda = params['cuda']
    b_sz = params['batch_size']

    if params['use_same_g']:
        if len(params['use_same_g']) == 1:
            gCV = torch.load(params['use_same_g'][0])

    solvers = []
    configs = []
    for i, mfile in enumerate(params['model']):
        model = torch.load(mfile)
        configs.append(model['arch'])
        configs[-1]['pretrained_model'] = mfile
        configs[-1]['load_encoder'] = 1
        configs[-1]['load_discriminator'] = 0 if params['evaluating_discr'] is not None else 1
        if i == 0:
            configs[i]['onlypretrained_discr'] = params['evaluating_discr']
        else:
            configs[i]['onlypretrained_discr'] = None
        if params['withExtMask'] and params['mask_size'] != 32:
            configs[-1]['lowres_mask'] = 0
            configs[-1]['load_encoder'] = 0
        else:
            params['mask_size'] = 32
        solvers.append(Solver(None, None, ParamObject(configs[-1]),
                              mode='test' if i > 0 else 'eval', pretrainedcv=model))
        solvers[-1].G.eval()
        if configs[-1]['train_boxreconst'] > 0:
            solvers[-1].E.eval()
        if params['use_same_g']:
            solvers[-1].no_inpainter = 0
            solvers[-1].load_pretrained_generator(gCV)
            print('loaded generator again')
    solvers[0].D.eval()
    solvers[0].D_cls.eval()

    dataset = get_dataset('', '', params['image_size'], params['image_size'], params['dataset'], params['split'],
                          select_attrs=configs[0]['selected_attrs'], datafile=params['datafile'], bboxLoader=1,
                          bbox_size=params['box_size'], randomrotate=params['randomrotate'],
                          randomscale=params['randomscale'], max_object_size=params['max_object_size'],
                          use_gt_mask=configs[0]['use_gtmask_inp'],
                          onlyrandBoxes=(params['extmask_type'] == 'randbox'),
                          square_resize=configs[0].get('square_resize', 0)
                          if params['square_resize_override'] < 0 else params['square_resize_override'],
                          filter_by_mincooccur=params['filter_by_mincooccur'],
                          only_indiv_occur=params['only_indiv_occur'])
    #gt_mask_data = get_dataset('', '', params['mask_size'], params['mask_size'], params['dataset'], params['split'],
    #                           select_attrs=configs[0]['selected_attrs'], bboxLoader=0, loadMasks=True)
    #data_iter = DataLoader(targ_split, batch_size=b_sz, shuffle=True, num_workers=8)
    targ_split = dataset  #train if params['split'] == 'train' else valid if params['split'] == 'val' else test
    data_iter = np.random.permutation(len(targ_split))

    if params['computeSegAccuracy']:
        gt_mask_data = get_dataset('', '', params['mask_size'], params['mask_size'], params['dataset'],
                                   params['split'], select_attrs=configs[0]['selected_attrs'],
                                   bboxLoader=0, loadMasks=True)
        commonIds = set(gt_mask_data.valid_ids).intersection(set(dataset.valid_ids))
        commonIndexes = [i for i in xrange(len(dataset.valid_ids)) if dataset.valid_ids[i] in commonIds]
        data_iter = commonIndexes
    if params['withExtMask'] and (params['extmask_type'] == 'mask'):
        ext_mask_data = get_dataset('', '', params['mask_size'], params['mask_size'],
                                    params['dataset'] if params['extMask_source'] == 'gt' else params['extMask_source'],
                                    params['split'], select_attrs=configs[0]['selected_attrs'],
                                    bboxLoader=0, loadMasks=True)
        curr_valid_ids = [dataset.valid_ids[i] for i in data_iter]
        commonIds = set(ext_mask_data.valid_ids).intersection(set(curr_valid_ids))
        commonIndexes = [i for i in xrange(len(dataset.valid_ids)) if dataset.valid_ids[i] in commonIds]
        data_iter = commonIndexes
    if params['nImages'] > -1:
        data_iter = data_iter[:params['nImages']]

    print('-----------------------------------------')
    print('%s' % (' | '.join(targ_split.selected_attrs)))
    print('-----------------------------------------')
    flatten = lambda l: [item for sublist in l for item in sublist]

    selected_attrs = configs[0]['selected_attrs']
    if params['showreconst'] and len(params['names']) > 0:
        params['names'] = flatten([[nm, nm + '-R'] for nm in params['names']])

    #discriminator.load_state_dict(cv['discriminator_state_dict'])
    c_idx = 0
    np.set_printoptions(precision=2)
    padimg = np.zeros((params['image_size'], 5, 3), dtype=np.uint8)
    padimg[:, :, :] = 128
    vggLoss = VGGLoss(network='squeeze')
    cimg_cnt = 0

    perclass_removeSucc = np.zeros((len(selected_attrs)))
    perclass_confusion = np.zeros((len(selected_attrs), len(selected_attrs)))
    perclass_classScoreDrop = np.zeros((len(selected_attrs), len(selected_attrs)))
    perclass_cooccurence = np.zeros((len(selected_attrs), len(selected_attrs))) + 1e-6
    perclass_vgg = np.zeros((len(selected_attrs)))
    perclass_ssim = np.zeros((len(selected_attrs)))
    perclass_psnr = np.zeros((len(selected_attrs)))
    perclass_tp = np.zeros((len(selected_attrs)))
    perclass_fp = np.zeros((len(selected_attrs)))
    perclass_fn = np.zeros((len(selected_attrs)))
    perclass_acc = np.zeros((len(selected_attrs)))
    perclass_counts = np.zeros((len(selected_attrs))) + 1e-6
    perclass_int = np.zeros((len(selected_attrs)))
    perclass_union = np.zeros((len(selected_attrs)))
    perclass_gtsize = np.zeros((len(selected_attrs)))
    perclass_predsize = np.zeros((len(selected_attrs)))
    perclass_segacc = np.zeros((len(selected_attrs)))
    perclass_msz = np.zeros((len(selected_attrs)))
    #perclass_th = Variable(torch.FloatTensor(np.array([0., 0.5380775, -0.49303985, -0.48941165, 2.8394265,
    #    -0.37880898, 1.0709367, 1.6613332, -1.5602279, 1.2631614, 2.4104881, -0.29175103, -0.6607682,
    #    -0.2128999, -1.286599, -2.24577, -0.4130093, -1.0535073, 0.038890466, -0.6808476]))).cuda()
    perclass_th = Variable(torch.FloatTensor(np.zeros((len(selected_attrs))))).cuda()
    perImageRes = {'images': {}, 'overall': {}}
    total_count = 0.
    if params['computeAP']:
        allScores = []
        allGT = []
        allEditedSc = []
    if params['dilateMask']:
        dilateWeight = torch.ones((1, 1, params['dilateMask'], params['dilateMask']))
        dilateWeight = Variable(dilateWeight, requires_grad=False).cuda()
    else:
        dilateWeight = None
    all_masks = []
    all_imgidAndCls = []

    for i in tqdm(xrange(len(data_iter))):
        #for i in tqdm(xrange(2)):
        idx = data_iter[i]
        x, real_label, boxImg, boxlabel, mask, bbox, curCls = targ_split[idx]
        cocoid = targ_split.getcocoid(idx)
        nnz_cls = real_label.nonzero()
        z_cls = (1 - real_label).nonzero()
        z_cls = z_cls[:, 0] if len(z_cls.size()) > 1 else z_cls
        x = x[None, ::]
        boxImg = boxImg[None, ::]
        mask = mask[None, ::]
        boxlabel = boxlabel[None, ::]
        real_label = real_label[None, ::]
        x, boxImg, mask, boxlabel = (solvers[0].to_var(x, volatile=True), solvers[0].to_var(boxImg, volatile=True),
                                     solvers[0].to_var(mask, volatile=True), solvers[0].to_var(boxlabel, volatile=True))
        real_label = solvers[0].to_var(real_label, volatile=True)

        _, out_cls_real = solvers[0].classify(x)
        out_cls_real = out_cls_real[0]  # Remove the singleton dimension
        pred_real_label = (out_cls_real > perclass_th)
        total_count += 1  #;import ipdb; ipdb.set_trace()
        if params['computeAP']:
            allScores.append(out_cls_real[None, :])
            allGT.append(real_label)
        removeScores = out_cls_real.clone()
        perclass_acc[(pred_real_label.float() == real_label)[0, :].data.cpu().numpy().astype(np.bool)] += 1.
        if len(z_cls):
            perclass_fp[z_cls.numpy()] += pred_real_label.data.cpu()[z_cls]
        if len(nnz_cls):
            nnz_cls = nnz_cls[:, 0]
            perclass_tp[nnz_cls.numpy()] += pred_real_label.data.cpu()[nnz_cls]
            perclass_fn[nnz_cls.numpy()] += 1 - pred_real_label.data.cpu()[nnz_cls]
            perImageRes['images'][cocoid] = {'perclass': {}}
            if params['dump_cls_results']:
                perImageRes['images'][cocoid]['real_label'] = nnz_cls.tolist()
                perImageRes['images'][cocoid]['real_scores'] = out_cls_real.data.cpu().tolist()
            if not params['eval_only_discr']:
                for cid in nnz_cls:
                    if configs[0]['use_gtmask_inp']:
                        mask = solvers[0].to_var(
                            targ_split.getGTMaskInp(idx, configs[0]['selected_attrs'][cid])[None, ::], volatile=True)
                    if params['withExtMask']:
                        if params['extmask_type'] == 'mask':
                            mask = solvers[0].to_var(
                                ext_mask_data.getbyIdAndclass(cocoid, configs[0]['selected_attrs'][cid])[None, ::],
                                volatile=True)
                        elif params['extmask_type'] == 'box':
                            mask = solvers[0].to_var(
                                dataset.getGTMaskInp(idx, configs[0]['selected_attrs'][cid], mask_type=2)[None, ::],
                                volatile=True)
                        elif params['extmask_type'] == 'randbox':
                            # Nothing to do here, mask is already set to random boxes
                            pass
                    if params['computeSegAccuracy']:
                        gtMask = gt_mask_data.getbyIdAndclass(cocoid, configs[0]['selected_attrs'][cid]).cuda()
                    mask_target = torch.zeros_like(real_label)
                    fake_label = real_label.clone()
                    fake_label[0, cid] = 0.
                    mask_target[0, cid] = 1
                    fake_x, mask_out = solvers[0].forward_generator(
                        x, imagelabel=mask_target, mask_threshold=params['mask_threshold'], onlyMasks=False,
                        mask=mask, withGTMask=params['withExtMask'], dilate=dilateWeight)
                    _, out_cls_fake = solvers[0].classify(fake_x)
                    out_cls_fake = out_cls_fake[0]  # Remove the singleton dimension
                    mask_out = mask_out.data[0, ::]
                    if params['dump_mask']:
                        all_masks.append(mask_out.cpu().numpy())
                        all_imgidAndCls.append((cocoid, selected_attrs[cid]))
                    perImageRes['images'][cocoid]['perclass'][selected_attrs[cid]] = {}
                    if params['computeSegAccuracy']:
                        union = torch.clamp((gtMask + mask_out), max=1.0).sum()
                        intersection = (gtMask * mask_out).sum()
                        img_iou = (intersection / (union + 1e-6))
                        img_acc = (gtMask == mask_out).float().mean()
                        img_recall = (intersection / (gtMask.sum() + 1e-6))
                        img_precision = (intersection / (mask_out.sum() + 1e-6))
                        perImageRes['images'][cocoid]['perclass'][selected_attrs[cid]].update(
                            {'iou': img_iou, 'rec': img_recall, 'prec': img_precision, 'acc': img_acc})
                        perImageRes['images'][cocoid]['perclass'][selected_attrs[cid]]['gtSize'] = gtMask.mean()
                        perImageRes['images'][cocoid]['perclass'][selected_attrs[cid]]['predSize'] = mask_out.mean()
                        # Compute metrics now
                        perclass_counts[cid] += 1
                        perclass_int[cid] += intersection
                        perclass_union[cid] += union
                        perclass_gtsize[cid] += gtMask.sum()
                        perclass_predsize[cid] += mask_out.sum()
                        perclass_segacc[cid] += img_acc
                    if params['dump_cls_results']:
                        perImageRes['images'][cocoid]['perclass'][selected_attrs[cid]]['remove_scores'] = \
                            out_cls_fake.data.cpu().tolist()
                    perImageRes['images'][cocoid]['perclass'][selected_attrs[cid]]['diff'] = \
                        out_cls_real.data[cid] - out_cls_fake.data[cid]
                    remove_succ = float((out_cls_fake.data[cid] < perclass_th[cid]))  # and (out_cls_real[cid] > 0.)
                    perclass_removeSucc[cid] += remove_succ
                    vL = vggLoss(fake_x, x).data[0]
                    perclass_vgg[cid] += 100. * vL
                    fake_x_sk = get_sk_image(fake_x)
                    x_sk = get_sk_image(x)
                    pSNR = compare_psnr(fake_x_sk, x_sk, data_range=255.)
                    ssim = compare_ssim(fake_x_sk, x_sk, data_range=255., multichannel=True)
                    msz = mask_out.mean()
                    if msz > 0.:
                        perclass_ssim[cid] += ssim
                        perclass_psnr[cid] += pSNR
                    if params['computeAP']:
                        removeScores[cid] = out_cls_fake[cid]
                    #---------------------------------------------------------------
                    # These are classes we are not trying to remove: correctly
                    # detected on the real image but not detected on the fake image.
                    # These are collateral damage. Count these.
                    #---------------------------------------------------------------
                    false_remove = fake_label.byte() * (out_cls_fake < perclass_th) * (out_cls_real > perclass_th)
                    perclass_cooccurence[cid, nnz_cls.numpy()] += 1.
                    perclass_confusion[cid, false_remove.data.cpu().numpy().astype(np.bool)[0, :]] += 1
                    perImageRes['images'][cocoid]['perclass'][selected_attrs[cid]].update(
                        {'remove_succ': remove_succ,
                         'false_remove': float(false_remove.data.cpu().float().numpy().sum()),
                         'perceptual': 100. * vL})
                if params['computeAP']:
                    allEditedSc.append(removeScores[None, :])
                perImageRes['images'][cocoid]['overall'] = {}
                perImageRes['images'][cocoid]['overall']['remove_succ'] = np.mean(
                    [perImageRes['images'][cocoid]['perclass'][cls]['remove_succ']
                     for cls in perImageRes['images'][cocoid]['perclass']])
                perImageRes['images'][cocoid]['overall']['false_remove'] = np.mean(
                    [perImageRes['images'][cocoid]['perclass'][cls]['false_remove']
                     for cls in perImageRes['images'][cocoid]['perclass']])
                perImageRes['images'][cocoid]['overall']['perceptual'] = np.mean(
                    [perImageRes['images'][cocoid]['perclass'][cls]['perceptual']
                     for cls in perImageRes['images'][cocoid]['perclass']])
                perImageRes['images'][cocoid]['overall']['diff'] = np.mean(
                    [perImageRes['images'][cocoid]['perclass'][cls]['diff']
                     for cls in perImageRes['images'][cocoid]['perclass']])
                if params['computeSegAccuracy']:
                    perImageRes['images'][cocoid]['overall']['iou'] = np.mean(
                        [perImageRes['images'][cocoid]['perclass'][cls]['iou']
                         for cls in perImageRes['images'][cocoid]['perclass']])
                    perImageRes['images'][cocoid]['overall']['acc'] = np.mean(
                        [perImageRes['images'][cocoid]['perclass'][cls]['acc']
                         for cls in perImageRes['images'][cocoid]['perclass']])
                    perImageRes['images'][cocoid]['overall']['prec'] = np.mean(
                        [perImageRes['images'][cocoid]['perclass'][cls]['prec']
                         for cls in perImageRes['images'][cocoid]['perclass']])
                    perImageRes['images'][cocoid]['overall']['rec'] = np.mean(
                        [perImageRes['images'][cocoid]['perclass'][cls]['rec']
                         for cls in perImageRes['images'][cocoid]['perclass']])
        elif params['dump_cls_results']:
            perImageRes['images'][cocoid] = {'perclass': {}}
            perImageRes['images'][cocoid]['real_label'] = nnz_cls.tolist()
            perImageRes['images'][cocoid]['real_scores'] = out_cls_real.data.cpu().tolist()

    if params['dump_mask']:
        np.savez('allMasks.npz', masks=np.concatenate(all_masks).astype(np.uint8),
                 idAndClass=np.stack(all_imgidAndCls))

    if params['computeAP']:
        allScores = torch.cat(allScores, dim=0).data.cpu().numpy()
        allGT = torch.cat(allGT, dim=0).data.cpu().numpy()
        apR = computeAP(allScores, allGT)
        if not params['eval_only_discr']:
            allEditedSc = torch.cat(allEditedSc, dim=0).data.cpu().numpy()
            apEdited = computeAP(allEditedSc, allGT)
    #for i in xrange(len(selected_attrs)):
    #    pr, rec, th = precision_recall_curve(allGTArr[:, i], allPredArr[:, i])
    #    f1s = 2 * (pr * rec) / (pr + rec); mf1idx = np.argmax(f1s)
    #    #print 'Max f1 = %.2f, th = %.2f' % (f1s[mf1idx], th[mf1idx])
    #    allMf1s.append(f1s[mf1idx])
    #    allTh.append(th[mf1idx])
    recall = perclass_tp / (perclass_tp + perclass_fn + 1e-6)
    precision = perclass_tp / (perclass_tp + perclass_fp + 1e-6)
    f1_score = 2.0 * (recall * precision) / (recall + precision + 1e-6)
    present_classes = (perclass_tp + perclass_fn) > 0.
    perclass_gt_counts = (perclass_tp + perclass_fn)
    apROverall = (perclass_gt_counts * apR).sum() / (perclass_gt_counts.sum())
    apR = apR[present_classes]
    recall = recall[present_classes]
    f1_score = f1_score[present_classes]
    precision = precision[present_classes]
    perclass_acc = perclass_acc[present_classes]
    present_attrs = [att for i, att in enumerate(targ_split.selected_attrs) if present_classes[i]]
    rec_overall = perclass_tp.sum() / (perclass_tp.sum() + perclass_fn.sum() + 1e-6)
    prec_overall = perclass_tp.sum() / (perclass_tp.sum() + perclass_fp.sum() + 1e-6)
    f1_score_overall = 2.0 * (rec_overall * prec_overall) / (rec_overall + prec_overall + 1e-6)

    print('------------------------------------------------------------')
    print('                Metrics have been computed                  ')
    print('------------------------------------------------------------')
    print('Score: || %s |' % (' | '.join(['%6s' % att[:6] for att in ['Overall', 'OverCls'] + present_attrs])))
    print('Acc  : || %s |' % (' | '.join([' %.2f' % sc for sc in [(perclass_acc / total_count).mean()] +
                                          [(perclass_acc / total_count).mean()] + list(perclass_acc / total_count)])))
    print('F1-sc: || %s |' % (' | '.join([' %.2f' % sc for sc in
                                          [f1_score_overall] + [f1_score.mean()] + list(f1_score)])))
    print('recal: || %s |' % (' | '.join([' %.2f' % sc for sc in [rec_overall] + [recall.mean()] + list(recall)])))
    print('prec : || %s |' % (' | '.join([' %.2f' % sc for sc in [prec_overall] + [precision.mean()] + list(precision)])))
    if params['computeAP']:
        print('AP   : || %s |' % (' | '.join([' %.2f' % sc for sc in [apROverall] + [apR.mean()] + list(apR)])))
    print('Count: || %s |' % (' | '.join([' %4.0f' % sc for sc in [perclass_gt_counts.mean()] +
                                          [perclass_gt_counts.mean()] + list(perclass_gt_counts[present_classes])])))
    if not params['eval_only_discr']:
        print('R-suc: || %s |' % (' | '.join([' %.2f' % sc for sc in
            [(perclass_removeSucc.sum() / perclass_cooccurence.diagonal().sum())] +
            [(perclass_removeSucc / perclass_cooccurence.diagonal()).mean()] +
            list(perclass_removeSucc / perclass_cooccurence.diagonal())])))
        print('R-fal: || %s |' % (' | '.join([' %.2f' % sc for sc in
            [(perclass_confusion.sum() / (perclass_cooccurence.sum() - perclass_cooccurence.diagonal().sum()))] +
            [(perclass_confusion.sum(axis=1) /
              (perclass_cooccurence.sum(axis=1) - perclass_cooccurence.diagonal())).mean()] +
            list((perclass_confusion / perclass_cooccurence).sum(axis=1) / (perclass_cooccurence.shape[0] - 1))])))
        print('Percp: || %s |' % (' | '.join([' %.2f' % sc for sc in
            [(perclass_vgg.sum() / perclass_cooccurence.diagonal().sum())] +
            [(perclass_vgg / perclass_cooccurence.diagonal()).mean()] +
            list(perclass_vgg / perclass_cooccurence.diagonal())])))
        print('pSNR : || %s |' % (' | '.join([' %.2f' % sc for sc in
            [(perclass_psnr.sum() / perclass_cooccurence.diagonal().sum())] +
            [(perclass_psnr / perclass_cooccurence.diagonal()).mean()] +
            list(perclass_psnr / perclass_cooccurence.diagonal())])))
        print('ssim : || %s |' % (' | '.join([' %.3f' % sc for sc in
            [(perclass_ssim.sum() / perclass_cooccurence.diagonal().sum())] +
            [(perclass_ssim / perclass_cooccurence.diagonal()).mean()] +
            list(perclass_ssim / perclass_cooccurence.diagonal())])))
        if params['computeAP']:
            print('R-AP : || %s |' % (' | '.join([' %.2f' % sc for sc in
                [apEdited.mean()] + [apEdited.mean()] + list(apEdited)])))
        if params['computeSegAccuracy']:
            print('mIou : || %s |' % (' | '.join([' %.2f' % sc for sc in
                [(perclass_int.sum() / (perclass_union + 1e-6).sum())] +
                [(perclass_int / (perclass_union + 1e-6)).mean()] +
                list(perclass_int / (perclass_union + 1e-6))])))
            print('mRec : || %s |' % (' | '.join([' %.2f' % sc for sc in
                [(perclass_int.sum() / (perclass_gtsize + 1e-6).sum())] +
                [(perclass_int / (perclass_gtsize + 1e-6)).mean()] +
                list(perclass_int / (perclass_gtsize + 1e-6))])))
            print('mPrc : || %s |' % (' | '.join([' %.2f' % sc for sc in
                [(perclass_int.sum() / (perclass_predsize.sum()))] +
                [(perclass_int / (perclass_predsize + 1e-6)).mean()] +
                list(perclass_int / (perclass_predsize + 1e-6))])))
            print('mSzR : || %s |' % (' | '.join([' %.2f' % sc for sc in
                [(perclass_predsize.sum() / (perclass_gtsize.sum()))] +
                [(perclass_predsize / (perclass_gtsize + 1e-6)).mean()] +
                list(perclass_predsize / (perclass_gtsize + 1e-6))])))
            print('Acc  : || %s |' % (' | '.join([' %.2f' % sc for sc in
                [(perclass_segacc.sum() / (perclass_counts.sum()))] +
                [(perclass_segacc / (perclass_counts + 1e-6)).mean()] +
                list(perclass_segacc / (perclass_counts + 1e-6))])))
            print('mSz  : || %s |' % (' | '.join([' %.1f' % sc for sc in
                [(100. * (perclass_predsize.sum() /
                          (params['mask_size'] * params['mask_size'] * perclass_counts).sum()))] +
                [(100. * (perclass_predsize /
                          (params['mask_size'] * params['mask_size'] * perclass_counts + 1e-6))).mean()] +
                list((100. * perclass_predsize) /
                     (params['mask_size'] * params['mask_size'] * perclass_counts + 1e-6))])))
        perImageRes['overall'] = {'iou': 0., 'rec': 0., 'prec': 0., 'acc': 0.}
        perImageRes['overall']['remove_succ'] = (perclass_removeSucc / perclass_cooccurence.diagonal()).mean()
        perImageRes['overall']['false_remove'] = (perclass_confusion / perclass_cooccurence).mean()
        perImageRes['overall']['perceptual'] = (perclass_vgg / perclass_cooccurence.diagonal()).mean()
        if params['computeSegAccuracy']:
            perImageRes['overall']['iou'] = (perclass_int / (perclass_union + 1e-6)).mean()
            perImageRes['overall']['acc'] = (perclass_segacc / (perclass_counts + 1e-6)).mean()
            perImageRes['overall']['prec'] = (perclass_int / (perclass_predsize + 1e-6)).mean()
            perImageRes['overall']['psize'] = (perclass_predsize).mean()
            perImageRes['overall']['psize_rel'] = (perclass_predsize / (perclass_gtsize + 1e-6)).mean()
            perImageRes['overall']['rec'] = (perclass_int / (perclass_gtsize + 1e-6)).mean()
        if params['computeAP']:
            perImageRes['overall']['ap-orig'] = list(apR)
            perImageRes['overall']['ap-edit'] = list(apEdited)

    if params['dump_perimage_res']:
        json.dump(perImageRes,
                  open(join(params['dump_perimage_res'],
                            params['split'] + '_' + basename(params['model'][0]).split('.')[0]), 'w'))
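# ---------------------------------------------------------------------------
# Illustration (not part of the original script): the segmentation numbers above
# ('mIou', 'mRec', 'mPrc') come from running per-class sums of intersection,
# union, ground-truth mask size and predicted mask size, which are only turned
# into IoU / recall / precision at print time. A stripped-down numpy sketch of
# that accumulation scheme, with assumed names and a simplified interface:
# ---------------------------------------------------------------------------
def _accumulate_seg_stats(gt_masks, pred_masks, class_ids, n_classes):
    """gt_masks, pred_masks: binary HxW arrays; class_ids: the class index of each mask pair."""
    import numpy as np
    inter = np.zeros(n_classes)
    union = np.zeros(n_classes)
    gt_sz = np.zeros(n_classes)
    pred_sz = np.zeros(n_classes)
    for g, p, c in zip(gt_masks, pred_masks, class_ids):
        inter[c] += (g * p).sum()
        union[c] += np.clip(g + p, 0, 1).sum()
        gt_sz[c] += g.sum()
        pred_sz[c] += p.sum()
    iou = inter / (union + 1e-6)           # per-class 'mIou'
    recall = inter / (gt_sz + 1e-6)        # per-class 'mRec'
    precision = inter / (pred_sz + 1e-6)   # per-class 'mPrc'
    return iou, recall, precision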
def gen_samples(params):
    # For fast training
    #cudnn.benchmark = True
    gpu_id = 0
    use_cuda = params['cuda']
    b_sz = params['batch_size']
    g_conv_dim = 64
    d_conv_dim = 64
    c_dim = 5
    c2_dim = 8
    g_repeat_num = 6
    d_repeat_num = 6
    select_attrs = []

    if params['use_same_g']:
        if len(params['use_same_g']) == 1:
            gCV = torch.load(params['use_same_g'][0])

    solvers = []
    configs = []
    for i, mfile in enumerate(params['model']):
        model = torch.load(mfile)
        configs.append(model['arch'])
        configs[-1]['pretrained_model'] = mfile
        configs[-1]['load_encoder'] = 1
        configs[-1]['load_discriminator'] = 0
        configs[-1]['image_size'] = params['image_size']
        if 'g_downsamp_layers' not in configs[-1]:
            configs[-1]['g_downsamp_layers'] = 2
        if 'g_dil_start' not in configs[-1]:
            configs[-1]['g_dil_start'] = 0
            configs[-1]['e_norm_type'] = 'drop'
            configs[-1]['e_ksize'] = 4
        if len(params['withExtMask']) and params['mask_size'] != 32:
            if params['withExtMask'][i]:
                configs[-1]['lowres_mask'] = 0
                configs[-1]['load_encoder'] = 0
        solvers.append(Solver(None, None, ParamObject(configs[-1]), mode='test', pretrainedcv=model))
        solvers[-1].G.eval()
        #solvers[-1].D.eval()
        if configs[-1]['train_boxreconst'] > 0 and solvers[-1].E is not None:
            solvers[-1].E.eval()
        if params['use_same_g']:
            solvers[-1].load_pretrained_generator(gCV)

    if len(params['dilateMask']):
        assert (len(params['model']) == len(params['dilateMask']))
        dilateWeightAll = []
        for di in range(len(params['dilateMask'])):
            if params['dilateMask'][di] > 0:
                dilateWeight = torch.ones((1, 1, params['dilateMask'][di], params['dilateMask'][di]))
                dilateWeight = Variable(dilateWeight, requires_grad=False).cuda()
            else:
                dilateWeight = None
            dilateWeightAll.append(dilateWeight)
    else:
        dilateWeightAll = [None for i in range(len(params['model']))]

    dataset = get_dataset('', '', params['image_size'], params['image_size'], params['dataset'], params['split'],
                          select_attrs=configs[0]['selected_attrs'], datafile=params['datafile'], bboxLoader=1,
                          bbox_size=params['box_size'], randomrotate=params['randomrotate'],
                          randomscale=params['randomscale'], max_object_size=params['max_object_size'],
                          use_gt_mask=configs[0]['use_gtmask_inp'], n_boxes=params['n_boxes'],
                          onlyrandBoxes=(params['extmask_type'] == 'randbox'))
    #data_iter = DataLoader(targ_split, batch_size=b_sz, shuffle=True, num_workers=8)
    targ_split = dataset  #train if params['split'] == 'train' else valid if params['split'] == 'val' else test
    data_iter = np.random.permutation(len(targ_split))

    if len(params['withExtMask']) and (params['extmask_type'] == 'mask'):
        gt_mask_data = get_dataset('', '', params['mask_size'], params['mask_size'],
                                   params['dataset'] if params['extMask_source'] == 'gt' else params['extMask_source'],
                                   params['split'], select_attrs=configs[0]['selected_attrs'],
                                   bboxLoader=0, loadMasks=True)
    if len(params['sort_by']):
        resFiles = [json.load(open(fil, 'r')) for fil in params['sort_by']]
        for i in range(len(resFiles)):
            #if params['sort_score'] not in resFiles[i]['images'][resFiles[i]['images'].keys()[0]]['overall']:
            for k in resFiles[i]['images']:
                img = resFiles[i]['images'][k]
                if 'overall' in resFiles[i]['images'][k]:
                    resFiles[i]['images'][k]['overall'][params['sort_score']] = np.mean(
                        [img['perclass'][cls][params['sort_score']] for cls in img['perclass']])
                else:
                    resFiles[i]['images'][k]['overall'] = {}
                    resFiles[i]['images'][k]['overall'][params['sort_score']] = np.mean(
                        [img['perclass'][cls][params['sort_score']] for cls in img['perclass']])
        idToScore = {int(k): resFiles[0]['images'][k]['overall'][params['sort_score']]
                     for k in resFiles[0]['images']}
        idToScore = OrderedDict(reversed(sorted(list(idToScore.items()), key=lambda t: t[1])))
        cocoIdToindex = {v: i for i, v in enumerate(dataset.valid_ids)}
        data_iter = [cocoIdToindex[k] for k in idToScore]
        dataIt2id = {cocoIdToindex[k]: str(k) for k in idToScore}
    if len(params['show_ids']) > 0:
        cocoIdToindex = {v: i for i, v in enumerate(dataset.valid_ids)}
        data_iter = [cocoIdToindex[k] for k in params['show_ids']]
    print(len(data_iter))

    print('-----------------------------------------')
    print('%s' % (' | '.join(targ_split.selected_attrs)))
    print('-----------------------------------------')
    flatten = lambda l: [item for sublist in l for item in sublist]
    if params['showreconst'] and len(params['names']) > 0:
        params['names'] = flatten([[nm, nm + '-R'] for nm in params['names']])

    #discriminator.load_state_dict(cv['discriminator_state_dict'])
    c_idx = 0
    np.set_printoptions(precision=2)
    padimg = np.zeros((params['image_size'], 5, 3), dtype=np.uint8)
    padimg[:, :, :] = 128
    if params['showperceptionloss']:
        vggLoss = VGGLoss(network='squeeze')
    cimg_cnt = 0
    mean_hist = [[], [], []]
    max_hist = [[], [], []]
    lengths_hist = [[], [], []]
    if len(params['n_iter']) == 0:
        params['n_iter'] = [0] * len(params['model'])

    while True:
        cimg_cnt += 1
        #import ipdb; ipdb.set_trace()
        idx = data_iter[c_idx]
        x, real_label, boxImg, boxlabel, mask, bbox, curCls = targ_split[data_iter[c_idx]]
        fp = [targ_split.getfilename(data_iter[c_idx])]
        #if configs[0]['use_gtmask_inp']:
        #    mask = mask[1:, ::]
        x = x[None, ::]
        boxImg = boxImg[None, ::]
        mask = mask[None, ::]
        boxlabel = boxlabel[None, ::]
        real_label = real_label[None, ::]
        x, boxImg, mask, boxlabel = (solvers[0].to_var(x, volatile=True), solvers[0].to_var(boxImg, volatile=True),
                                     solvers[0].to_var(mask, volatile=True), solvers[0].to_var(boxlabel, volatile=True))
        real_label = solvers[0].to_var(real_label, volatile=True)

        fake_image_list = [x]
        if params['showmask']:
            mask_image_list = [x - x]
        else:
            fake_image_list.append(x * (1 - mask) + mask)
        deformList = [[], []]
        if len(real_label[0, :].nonzero()):
            #rand_idx = random.choice(real_label[0, :].nonzero()).data[0]
            rand_idx = curCls[0]
            print(configs[0]['selected_attrs'][rand_idx])
            if len(params['withExtMask']):
                cocoid = targ_split.getcocoid(idx)
                if params['extmask_type'] == 'mask':
                    mask = solvers[0].to_var(
                        gt_mask_data.getbyIdAndclass(cocoid, configs[0]['selected_attrs'][rand_idx])[None, ::],
                        volatile=True)
                elif params['extmask_type'] == 'box':
                    mask = solvers[0].to_var(
                        dataset.getGTMaskInp(idx, configs[0]['selected_attrs'][rand_idx], mask_type=2)[None, ::],
                        volatile=True)
                elif params['extmask_type'] == 'randbox':
                    # Nothing to do here, mask is already set to random boxes
                    pass
        else:
            rand_idx = curCls[0]
        if params['showdiff']:
            diff_image_list = [x - x] if params['showmask'] else [x - x, x - x]

        for i in range(len(params['model'])):
            if configs[i]['use_gtmask_inp']:
                mask = solvers[0].to_var(
                    targ_split.getGTMaskInp(idx, configs[0]['selected_attrs'][rand_idx],
                                            mask_type=configs[i]['use_gtmask_inp'])[None, ::], volatile=True)
            if len(params['withExtMask']) or params['no_maskgen']:
                withGTMask = True if params['no_maskgen'] else params['withExtMask'][i]
            else:
                withGTMask = False
            if configs[i]['train_boxreconst'] == 3:
                mask_target = torch.zeros_like(real_label)
                if len(real_label[0, :].nonzero()):
                    mask_target[0, rand_idx] = 1
                # This variable tells the mask generator which class to generate a mask for
                boxlabelInp = boxlabel
            elif configs[i]['train_boxreconst'] == 2:
                boxlabelfake = torch.zeros_like(boxlabel)
                if configs[i]['use_box_label'] == 2:
                    boxlabelInp = torch.cat([boxlabel, boxlabelfake], dim=1)
                    if params['showreconst']:
                        boxlabelInpRec = torch.cat([boxlabelfake, boxlabel], dim=1)
                mask_target = real_label
            else:
                boxlabelInp = boxlabel
                mask_target = real_label
            if params['showdeform']:
                img, maskOut, deform = solvers[i].forward_generator(
                    x, boxImg=boxImg, mask=mask, imagelabel=mask_target, boxlabel=boxlabelInp, get_feat=True,
                    mask_threshold=params['mask_threshold'], withGTMask=withGTMask, dilate=dilateWeightAll[i],
                    n_iter=params['n_iter'][i])
                fake_image_list.append(img)
                deformList.append(deform)
            else:
                img, maskOut = solvers[i].forward_generator(
                    x, boxImg=boxImg, mask=mask, imagelabel=mask_target, boxlabel=boxlabelInp,
                    mask_threshold=params['mask_threshold'], withGTMask=withGTMask, dilate=dilateWeightAll[i],
                    n_iter=params['n_iter'][i])
                fake_image_list.append(img)
            if params['showmask']:
                mask_image_list.append(solvers[i].getImageSizeMask(maskOut)[:, [0, 0, 0], ::])
            if params['showdiff']:
                diff_image_list.append(x - fake_image_list[-1])
            if params['showreconst']:
                if params['showdeform']:
                    img, maskOut, deform = solvers[i].forward_generator(
                        fake_image_list[-1], boxImg=boxImg, mask=mask, imagelabel=mask_target,
                        boxlabel=boxlabelInp, get_feat=True, mask_threshold=params['mask_threshold'],
                        withGTMask=withGTMask, dilate=dilateWeightAll[i], n_iter=params['n_iter'][i])
                    fake_image_list.append(img)
                    deformList.append(deform)
                else:
                    img, maskOut = solvers[i].forward_generator(
                        fake_image_list[-1], boxImg=boxImg, mask=mask, imagelabel=mask_target,
                        boxlabel=boxlabelInp, mask_threshold=params['mask_threshold'],
                        withGTMask=withGTMask, dilate=dilateWeightAll[i], n_iter=params['n_iter'][i])
                    fake_image_list.append(img)
                if params['showdiff']:
                    diff_image_list.append(x - fake_image_list[-1])

        if not params['compute_deform_stats']:
            img = make_image(fake_image_list, padimg)
            if params['showdeform']:
                defImg = make_image_with_deform(fake_image_list, deformList, np.vstack([padimg, padimg, padimg]))
                img = np.vstack([img, defImg])
            if params['showmask']:
                imgmask = make_image(mask_image_list, padimg)
                img = np.vstack([img, imgmask])
            if params['showdiff']:
                imgdiff = make_image(diff_image_list, padimg)
                img = np.vstack([img, imgdiff])
            if len(params['names']) > 0:
                nameList = ['Input'] + params['names'] if params['showmask'] else \
                           ['Input', 'Masked Input'] + params['names']
                imgNames = np.hstack(flatten([[make_image_with_text((32, x.size(3), 3), nm),
                                               padimg[:32, :, :].astype(np.uint8)] for nm in nameList]))
                img = np.vstack([imgNames, img])
            if len(params['sort_by']):
                clsname = configs[0]['selected_attrs'][rand_idx]
                cocoid = dataIt2id[data_iter[c_idx]]
                curr_class_iou = [resFiles[i]['images'][cocoid]['real_scores'][rand_idx]] + \
                                 [resFiles[i]['images'][cocoid]['perclass'][clsname][params['sort_score']]
                                  for i in range(len(params['model']))]
                if params['showperceptionloss']:
                    textToPrint = ['P:%.2f, S:%.1f' % (vggLoss(fake_image_list[0], fake_image_list[i]).data[0],
                                                       curr_class_iou[i]) for i in range(len(fake_image_list))]
                else:
                    textToPrint = ['S:%.1f' % (curr_class_iou[i]) for i in range(len(fake_image_list))]
                if len(params['show_also']):
                    # Additional data to print
                    for val in params['show_also']:
                        curval = [0.] + [resFiles[i]['images'][cocoid]['perclass'][clsname][val][rand_idx]
                                         for i in range(len(params['model']))]
                        textToPrint = [txt + ' %s:%.1f' % (val[0], curval[i])
                                       for i, txt in enumerate(textToPrint)]
                imgScore = np.hstack(flatten([[make_image_with_text((32, x.size(3), 3), textToPrint[i]),
                                               padimg[:32, :, :].astype(np.uint8)]
                                              for i in range(len(fake_image_list))]))
                img = np.vstack([img, imgScore])
            elif params['showperceptionloss']:
                imgScore = np.hstack(flatten(
                    [[make_image_with_text((32, x.size(3), 3),
                                           '%.2f' % vggLoss(fake_image_list[0], fake_image_list[i]).data[0]),
                      padimg[:32, :, :].astype(np.uint8)] for i in range(len(fake_image_list))]))
                img = np.vstack([img, imgScore])
            #if params['showmask']:
            #    imgmask = make_image(mask_list)
            #    img = np.vstack([img, imgmask])
            #if params['compmodel']:
            #    imgcomp = make_image(fake_image_list_comp)
            #    img = np.vstack([img, imgcomp])
            #    if params['showdiff']:
            #        imgdiffcomp = make_image([fimg - fake_image_list_comp[0] for fimg in fake_image_list_comp])
            #        img = np.vstack([img, imgdiffcomp])
            cv2.imshow('frame', img if params['scaleDisp'] == 0 else
                       cv2.resize(img, None, fx=params['scaleDisp'], fy=params['scaleDisp']))
            keyInp = cv2.waitKey(0)
            if keyInp & 0xFF == ord('q'):
                break
            elif keyInp & 0xFF == ord('b'):
                #print keyInp & 0xFF
                c_idx = c_idx - 1
            elif (keyInp & 0xFF == ord('s')):
                #sample_dir = join(params['sample_dump_dir'], basename(params['model'][0]).split('.')[0])
                sample_dir = join(params['sample_dump_dir'], '_'.join([params['split']] + params['names']))
                if not exists(sample_dir):
                    makedirs(sample_dir)
                fnames = ['%s.png' % splitext(basename(f))[0] for f in fp]
                fpaths = [join(sample_dir, f) for f in fnames]
                imgSaveName = fpaths[0]
                if params['savesepimages']:
                    saveIndividImages(fake_image_list, mask_image_list, nameList, sample_dir, fp,
                                      configs[0]['selected_attrs'][rand_idx])
                else:
                    print('Saving into file: ' + imgSaveName)
                    cv2.imwrite(imgSaveName, img)
                c_idx += 1
            else:
                c_idx += 1
        else:
            for di in range(len(deformList)):
                if len(deformList[di]) > 0 and len(deformList[di][0]) > 0:
                    for dLidx, d in enumerate(deformList[di]):
                        lengths, mean, maxl = compute_deform_statistics(d[1], d[0])
                        mean_hist[dLidx].append(mean)
                        max_hist[dLidx].append(maxl)
                        lengthsH = np.histogram(lengths, bins=np.arange(0, 128, 0.5))[0]
                        if lengths_hist[dLidx] == []:
                            lengths_hist[dLidx] = lengthsH
                        else:
                            lengths_hist[dLidx] += lengthsH
        if params['compute_deform_stats'] and (cimg_cnt < params['compute_deform_stats']):
            print(np.mean(mean_hist[0]))
            print(np.mean(mean_hist[1]))
            print(np.mean(mean_hist[2]))
            print(np.mean(max_hist[0]))
            print(np.mean(max_hist[1]))
            print(np.mean(max_hist[2]))
            print(lengths_hist[0])
            print(lengths_hist[1])
            print(lengths_hist[2])
            break
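# ---------------------------------------------------------------------------
# For reference (summary added here, not in the original code): the cv2.waitKey()
# handling in the viewer loop above implements a small keyboard protocol:
#   q             : quit the viewer
#   b             : step back to the previous image
#   s             : save the composite (or individual images when 'savesepimages' is set), then advance
#   any other key : advance to the next image
# ---------------------------------------------------------------------------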
def gen_samples(params):
    # For fast training
    #cudnn.benchmark = True
    gpu_id = 0
    use_cuda = params['cuda']
    b_sz = params['batch_size']

    solvers = []
    configs = []
    for i, mfile in enumerate(params['model']):
        model = torch.load(mfile)
        configs.append(model['arch'])
        configs[-1]['pretrained_model'] = mfile
        configs[-1]['load_encoder'] = 1
        configs[-1]['load_discriminator'] = 0
        configs[-1]['image_size'] = params['image_size']
        if i == 0:
            configs[i]['onlypretrained_discr'] = params['evaluating_discr']
        else:
            configs[i]['onlypretrained_discr'] = None
        if params['withExtMask'] and params['mask_size'] != 32:
            configs[-1]['lowres_mask'] = 0
            configs[-1]['load_encoder'] = 0
        solvers.append(Solver(None, None, ParamObject(configs[-1]),
                              mode='test' if i > 0 else 'eval', pretrainedcv=model))
        solvers[-1].G.eval()
        if configs[-1]['train_boxreconst'] > 0:
            solvers[-1].E.eval()
    solvers[0].D.eval()
    solvers[0].D_cls.eval()

    dataset = get_dataset('', '', params['image_size'], params['image_size'], params['dataset'], params['split'],
                          select_attrs=configs[0]['selected_attrs'], datafile=params['datafile'], bboxLoader=1,
                          bbox_size=params['box_size'], randomrotate=params['randomrotate'],
                          randomscale=params['randomscale'], max_object_size=params['max_object_size'],
                          use_gt_mask=0, n_boxes=params['n_boxes'])
    #configs[0]['use_gtmask_inp'])#, imagenet_norm=(configs[0]['use_imagenet_pretrained'] is not None))
    #gt_mask_data = get_dataset('', '', params['mask_size'], params['mask_size'], params['dataset'], params['split'],
    #                           select_attrs=configs[0]['selected_attrs'], bboxLoader=0, loadMasks=True)
    #data_iter = DataLoader(targ_split, batch_size=b_sz, shuffle=True, num_workers=8)
    targ_split = dataset  #train if params['split'] == 'train' else valid if params['split'] == 'val' else test
    data_iter = np.random.permutation(len(targ_split) if params['nImages'] == -1 else params['nImages'])

    if params['withExtMask'] or params['computeSegAccuracy']:
        gt_mask_data = get_dataset('', '', params['mask_size'], params['mask_size'],
                                   params['dataset'] if params['extMask_source'] == 'gt' else params['extMask_source'],
                                   params['split'], select_attrs=configs[0]['selected_attrs'],
                                   bboxLoader=0, loadMasks=True)
        commonIds = set(gt_mask_data.valid_ids).intersection(set(dataset.valid_ids))
        commonIndexes = [i for i in xrange(len(dataset.valid_ids)) if dataset.valid_ids[i] in commonIds]
        data_iter = commonIndexes if params['nImages'] == -1 else commonIndexes[:params['nImages']]

    print('-----------------------------------------')
    print('%s' % (' | '.join(targ_split.selected_attrs)))
    print('-----------------------------------------')
    flatten = lambda l: [item for sublist in l for item in sublist]

    selected_attrs = configs[0]['selected_attrs']
    if params['showreconst'] and len(params['names']) > 0:
        params['names'] = flatten([[nm, nm + '-R'] for nm in params['names']])

    #discriminator.load_state_dict(cv['discriminator_state_dict'])
    c_idx = 0
    np.set_printoptions(precision=2)
    padimg = np.zeros((params['image_size'], 5, 3), dtype=np.uint8)
    padimg[:, :, :] = 128
    vggLoss = VGGLoss(network='squeeze')
    cimg_cnt = 0

    mask_bin_size = 0.1
    n_bins = int(1.0 / mask_bin_size)
    vLTotal = np.zeros((n_bins,))
    pSNRTotal = np.zeros((n_bins,))
    ssimTotal = np.zeros((n_bins,))
    total_count = np.zeros((n_bins,)) + 1e-8
    perImageRes = {'images': {}, 'overall': {}}
    if params['dilateMask']:
        dilateWeight = torch.ones((1, 1, params['dilateMask'], params['dilateMask']))
        dilateWeight = Variable(dilateWeight, requires_grad=False).cuda()
    else:
        dilateWeight = None

    for i in tqdm(xrange(len(data_iter))):
        #for i in tqdm(xrange(2)):
        idx = data_iter[i]
        x, real_label, boxImg, boxlabel, mask, bbox, curCls = targ_split[idx]
        cocoid = targ_split.getcocoid(idx)
        nnz_cls = real_label.nonzero()
        z_cls = (1 - real_label).nonzero()
        z_cls = z_cls[:, 0]
        x = x[None, ::]
        boxImg = boxImg[None, ::]
        mask = mask[None, ::]
        boxlabel = boxlabel[None, ::]
        real_label = real_label[None, ::]
        x, boxImg, mask, boxlabel = (solvers[0].to_var(x, volatile=True), solvers[0].to_var(boxImg, volatile=True),
                                     solvers[0].to_var(mask, volatile=True), solvers[0].to_var(boxlabel, volatile=True))
        real_label = solvers[0].to_var(real_label, volatile=True)

        fake_x, mask_out = solvers[0].forward_generator(
            x, imagelabel=real_label, mask_threshold=params['mask_threshold'], onlyMasks=False, mask=mask,
            withGTMask=False, dilate=dilateWeight)
        vL = vggLoss(fake_x, x).data[0]
        # Change the image range to 0, 255
        fake_x_sk = get_sk_image(fake_x)
        x_sk = get_sk_image(x)
        a = x.data.cpu().numpy()
        print(a.shape)
        plt.subplot(121)
        plt.imshow(np.rollaxis(x.data.cpu().numpy().squeeze(), 0, start=3))
        plt.subplot(122)
        plt.imshow(np.rollaxis(fake_x.data.cpu().numpy().squeeze(), 0, start=3))
        plt.savefig('%d.png' % (i,))
        if i == 9:
            break
        pSNR = compare_psnr(fake_x_sk, x_sk, data_range=255.)
        ssim = compare_ssim(fake_x_sk, x_sk, data_range=255., multichannel=True)
        msz = mask.data.cpu().numpy().mean()
        if msz > 0.:
            msz_bin = int((msz - 1e-8) / mask_bin_size)
            perImageRes['images'][cocoid] = {'overall': {}}
            perImageRes['images'][cocoid]['overall']['perceptual'] = float(vL)
            perImageRes['images'][cocoid]['overall']['pSNR'] = float(pSNR)
            perImageRes['images'][cocoid]['overall']['ssim'] = float(ssim)
            perImageRes['images'][cocoid]['overall']['mask_size'] = float(msz)
            vLTotal[msz_bin] += vL
            pSNRTotal[msz_bin] += pSNR
            ssimTotal[msz_bin] += ssim
            total_count[msz_bin] += 1

    print('------------------------------------------------------------')
    print('                Metrics have been computed                  ')
    print('------------------------------------------------------------')
    print('Percp: || %s |' % (' | '.join([' %.3f' % sc for sc in
        [vLTotal.sum() / total_count.sum()] + list(vLTotal / total_count)])))
    print('pSNR : || %s |' % (' | '.join([' %.3f' % sc for sc in
        [pSNRTotal.sum() / total_count.sum()] + list(pSNRTotal / total_count)])))
    print('ssim : || %s |' % (' | '.join([' %.3f' % sc for sc in
        [ssimTotal.sum() / total_count.sum()] + list(ssimTotal / total_count)])))
    if params['dump_perimage_res']:
        json.dump(perImageRes,
                  open(join(params['dump_perimage_res'],
                            params['split'] + '_' + basename(params['model'][0]).split('.')[0]), 'w'))
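# ---------------------------------------------------------------------------
# Note (illustration, not original code): the loop above buckets every image by
# the fraction of pixels covered by its mask, in bins of width mask_bin_size = 0.1,
# and reports the mean perceptual loss / pSNR / SSIM per bucket. The binning rule
# reduces to the small helper below (the name and interface are assumed):
# ---------------------------------------------------------------------------
def _mask_size_bin(mask_fraction, bin_size=0.1):
    """Map a mask coverage fraction in (0, 1] to a bin index in 0 .. int(1/bin_size) - 1."""
    return int((mask_fraction - 1e-8) / bin_size)

# e.g. _mask_size_bin(0.05) -> 0, _mask_size_bin(0.10) -> 0, _mask_size_bin(0.35) -> 3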
# Debugging notes (ipdb session): visualizing the edited image, the difference image and the mask.
# ipdb> plt.imshow(((fake_x.data.cpu().numpy()[2, [0, 1, 2], :, :].transpose(1, 2, 0) + 1.0) * 255. / 2.0).astype(np.uint8)); plt.show()
# ipdb> plt.imshow(((diffimg.data.cpu().numpy()[2, [0, 1, 2], :, :].transpose(1, 2, 0) + 1.0) * 255. / 2.0).astype(np.uint8)); plt.show()
# ipdb> plt.imshow(((1 - mask.data.cpu().numpy()[2, [0, 1, 2], :, :].transpose(1, 2, 0)) * 255.).astype(np.uint8)); plt.show()

# Scratch snippet: load a dataset sample and show the image, box crop, mask and masked image.
# from utils.data_loader_stargan import get_dataset
import matplotlib.pyplot as plt
import numpy as np

dataset = get_dataset(config.celebA_image_path, config.metadata_path, config.celebA_crop_size,
                      config.image_size, config.dataset, config.mode,
                      select_attrs=config.selected_attrs, datafile=config.datafile,
                      bboxLoader=config.train_boxreconst)
img, imgLab, boxImg, boxLab, mask = dataset[0]
plt.figure(); plt.imshow(((img.numpy()[[0, 1, 2], :, :].transpose(1, 2, 0) + 1.0) * 255. / 2.0).astype(np.uint8))
plt.figure(); plt.imshow(((boxImg.numpy()[[0, 1, 2], :, :].transpose(1, 2, 0) + 1.0) * 255. / 2.0).astype(np.uint8))
plt.figure(); plt.imshow(((mask.numpy()[[0, 0, 0], :, :].transpose(1, 2, 0) + 1.0) * 255. / 2.0).astype(np.uint8))
plt.figure(); plt.imshow((((img * mask).numpy()[[0, 1, 2], :, :].transpose(1, 2, 0) + 1.0) * 255. / 2.0).astype(np.uint8))
plt.show()

###----------------------------------------------------------------------------------------------------------------
import numpy as np
from collections import defaultdict
from tqdm import tqdm
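# ---------------------------------------------------------------------------
# Helper sketch (not part of the original code): the ipdb commands and the snippet
# above repeat the same conversion from a [-1, 1] CHW float tensor to a displayable
# uint8 HWC image. Something like the function below could replace the inline
# expressions; the name and interface are assumptions.
# ---------------------------------------------------------------------------
def to_display_image(chw_array):
    """Convert a CHW float tensor/array with values in [-1, 1] to an HWC uint8 image."""
    import numpy as np
    arr = chw_array.numpy() if hasattr(chw_array, 'numpy') else np.asarray(chw_array)
    return ((arr.transpose(1, 2, 0) + 1.0) * 255. / 2.0).astype(np.uint8)

# e.g. plt.imshow(to_display_image(img)); plt.show()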