def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name): utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, None, G_ema if config['ema'] else None) # Save an additional copy to mitigate accidental corruption if process # is killed during a save (it's happened to me before -.-) if config['num_save_copies'] > 0: utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'copy%d' % state_dict['save_num'], G_ema if config['ema'] else None) state_dict['save_num'] = ( state_dict['save_num'] + 1) % config['num_save_copies'] # Use EMA G for samples or non-EMA? which_G = G_ema if config['ema'] and config['use_ema'] else G # Accumulate standing statistics? if config['accumulate_stats']: utils.accumulate_standing_stats(G_ema if config['ema'] and config['use_ema'] else G, z_, y_, config['n_classes'], config['num_standing_accumulations']) # Save a random sample sheet with fixed z and y with torch.no_grad(): if config['parallel']: fixed_Gz = nn.parallel.data_parallel( which_G, (fixed_z, which_G.shared(fixed_y))) else: fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y)) if not os.path.isdir('%s/%s' % (config['samples_root'], experiment_name)): os.mkdir('%s/%s' % (config['samples_root'], experiment_name)) image_filename = '%s/%s/fixed_samples%d.jpg' % (config['samples_root'], experiment_name, state_dict['itr']) torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename, nrow=int(fixed_Gz.shape[0] ** 0.5), normalize=True) # For now, every time we save, also save sample sheets utils.sample_sheet(which_G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], num_classes=config['n_classes'], samples_per_class=10, parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=state_dict['itr'], z_=z_) # Also save interp sheets for fix_z, fix_y in zip([False, False, True], [False, True, False]): utils.interp_sheet(which_G, num_per_sheet=16, num_midpoints=8, num_classes=config['n_classes'], parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=state_dict['itr'], sheet_number=0, fix_z=fix_z, fix_y=fix_y, device='cuda')
def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name): utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, None, G_ema if config['ema'] else None) if config['num_save_copies'] > 0: utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'copy%d' % state_dict['save_num'], G_ema if config['ema'] else None) state_dict['save_num'] = (state_dict['save_num'] + 1) % config['num_save_copies'] which_G = G_ema if config['ema'] and config['use_ema'] else G if config['accumulate_stats']: if not config['conditional']: y_.zero_() utils.accumulate_standing_stats( G_ema if config['ema'] and config['use_ema'] else G, z_, y_, config['n_classes'], config['num_standing_accumulations']) with torch.no_grad(): if config['parallel']: fixed_Gz = nn.parallel.data_parallel( which_G, (fixed_z, which_G.shared(fixed_y))) else: fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y)) if not os.path.isdir('%s/%s' % (config['samples_root'], experiment_name)): os.mkdir('%s/%s' % (config['samples_root'], experiment_name)) image_filename = '%s/%s/fixed_samples%d.jpg' % ( config['samples_root'], experiment_name, state_dict['itr']) torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename, nrow=int(fixed_Gz.shape[0]**0.5), normalize=True) utils.sample_sheet( which_G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], num_classes=config['n_classes'], samples_per_class=10, parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=state_dict['itr'], z_=z_) for fix_z, fix_y in zip([False, False, True], [False, True, False]): utils.interp_sheet(which_G, num_per_sheet=16, num_midpoints=8, num_classes=config['n_classes'], parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=state_dict['itr'], sheet_number=0, fix_z=fix_z, fix_y=fix_y, device='cuda')
def run(config): # Prepare state dict, which holds things like epoch # and itr # state_dict = { 'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0, 'best_FID': 999999, 'config': config } # Optionally, get the configuration from the state dict. This allows for # recovery of the config provided only a state dict and experiment name, # and can be convenient for writing less verbose sample shell scripts. if config['config_from_name']: utils.load_weights(None, None, state_dict, config['weights_root'], config['experiment_name'], config['load_weights'], None, strict=False, load_optim=False) # Ignore items which we might want to overwrite from the command line for item in state_dict['config']: if item not in [ 'z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode' ]: config[item] = state_dict['config'][item] # update config (see train.py for explanation) config['resolution'] = utils.imsize_dict[config['dataset']] config['n_classes'] = utils.nclass_dict[config['dataset']] config['G_activation'] = utils.activation_dict[config['G_nl']] config['D_activation'] = utils.activation_dict[config['D_nl']] config = utils.update_config_roots(config) config['skip_init'] = True config['no_optim'] = True device = 'cuda' # Seed RNG utils.seed_rng(config['seed']) # Setup cudnn.benchmark for free speed torch.backends.cudnn.benchmark = True # Import the model--this line allows us to dynamically select different files. model = __import__(config['model']) experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config)) print('Experiment name is %s' % experiment_name) G = model.Generator(**config).cuda() utils.count_parameters(G) # Load weights print('Loading weights...') # Here is where we deal with the ema--load ema weights or load normal weights utils.load_weights(G if not (config['use_ema']) else None, None, state_dict, config['weights_root'], experiment_name, config['load_weights'], G if config['ema'] and config['use_ema'] else None, strict=False, load_optim=False) # Update batch size setting used for G G_batch_size = max(config['G_batch_size'], config['batch_size']) z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'], z_var=config['z_var']) if config['G_eval_mode']: print('Putting G in eval mode..') G.eval() else: print('G is in %s mode...' % ('training' if G.training else 'eval')) # Sample function sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) if config['accumulate_stats']: print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations']) utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], config['num_standing_accumulations']) # Sample a number of images and save them to an NPZ, for use with TF-Inception if config['sample_npz']: # Lists to hold images and labels for images x, y = [], [] print('Sampling %d images and saving them to npz...' % config['sample_num_npz']) for i in trange( int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))): with torch.no_grad(): images, labels = sample() x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)] y += [labels.cpu().numpy()] x = np.concatenate(x, 0)[:config['sample_num_npz']] y = np.concatenate(y, 0)[:config['sample_num_npz']] print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape)) npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name) print('Saving npz to %s...' % npz_filename) np.savez(npz_filename, **{'x': x, 'y': y}) # Prepare sample sheets if config['sample_sheets']: print('Preparing conditional sample sheets...') utils.sample_sheet( G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], num_classes=config['n_classes'], samples_per_class=10, parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=config['sample_sheet_folder_num'], z_=z_, ) # Sample interp sheets if config['sample_interps']: print('Preparing interp sheets...') for fix_z, fix_y in zip([False, False, True], [False, True, False]): utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8, num_classes=config['n_classes'], parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=config['sample_sheet_folder_num'], sheet_number=0, fix_z=fix_z, fix_y=fix_y, device='cuda') # Sample random sheet if config['sample_random']: print('Preparing random sample sheet...') images, labels = sample() torchvision.utils.save_image(images.float(), '%s/%s/random_samples.jpg' % (config['samples_root'], experiment_name), nrow=int(G_batch_size**0.5), normalize=True) # Get Inception Score and FID get_inception_metrics = inception_utils.prepare_inception_metrics( config['dataset'], config['parallel'], config['no_fid']) # Prepare a simple function get metrics that we use for trunc curves def get_metrics(): sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) IS_mean, IS_std, FID = get_inception_metrics( sample, config['num_inception_images'], num_splits=10, prints=False) # Prepare output string outstring = 'Using %s weights ' % ('ema' if config['use_ema'] else 'non-ema') outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else 'training') outstring += 'with noise variance %3.3f, ' % z_.var outstring += 'over %d images, ' % config['num_inception_images'] if config['accumulate_stats'] or not config['G_eval_mode']: outstring += 'with batch size %d, ' % G_batch_size if config['accumulate_stats']: outstring += 'using %d standing stat accumulations, ' % config[ 'num_standing_accumulations'] outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % ( state_dict['itr'], IS_mean, IS_std, FID) print(outstring) if config['sample_inception_metrics']: print('Calculating Inception metrics...') get_metrics() # Sample truncation curve stuff. This is basically the same as the inception metrics code if config['sample_trunc_curves']: start, step, end = [ float(item) for item in config['sample_trunc_curves'].split('_') ] print( 'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...' % (start, step, end)) for var in np.arange(start, end + step, step): z_.var = var # Optionally comment this out if you want to run with standing stats # accumulated at one z variance setting if config['accumulate_stats']: utils.accumulate_standing_stats( G, z_, y_, config['n_classes'], config['num_standing_accumulations']) get_metrics()
def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name,sample_only=False, use_real = False, real_batch = None, id = "", mixed=False, target_map = None): if not sample_only: utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, None, G_ema if config['ema'] else None) # Save an additional copy to mitigate accidental corruption if process # is killed during a save (it's happened to me before -.-) if config['num_save_copies'] > 0: utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'copy%d' % state_dict['save_num'], G_ema if config['ema'] else None) state_dict['save_num'] = (state_dict['save_num'] + 1 ) % config['num_save_copies'] else: # Use EMA G for samples or non-EMA? which_G = G_ema if config['ema'] and config['use_ema'] else G # Accumulate standing statistics? if config['accumulate_stats']: print("accumulating stats!") utils.accumulate_standing_stats(G_ema if config['ema'] and config['use_ema'] else G, z_, y_, config['n_classes'], config['num_standing_accumulations']) # Save a random sample sheet with fixed z and y with torch.no_grad(): if use_real: fixed_Gz = real_batch experiment_name += "_real" else: if config['parallel']: fixed_Gz = nn.parallel.data_parallel(which_G, (fixed_z, which_G.shared(fixed_y))) else: fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y)) if not os.path.isdir('%s/%s' % (config['samples_root'], experiment_name)): os.mkdir('%s/%s' % (config['samples_root'], experiment_name)) image_filename = '%s/%s/fixed_samples%d' % (config['samples_root'],experiment_name,state_dict['itr']) image_filename += id + ".jpg" if not (state_dict["itr"]>config["sample_every"] and use_real and not mixed): torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename, nrow=int(fixed_Gz.shape[0] **0.5), normalize=True) with torch.no_grad(): D_map, c = D(fixed_Gz ,fixed_y ) D_map = F.sigmoid(D_map) c = F.sigmoid(c) s = D_map.mean([2,3]) s = s.view(-1) c = c.view(-1) cs = torch.cat((c.view(c.size(0),1) ,s.view(s.size(0),1) ),dim=1) cs = cs.cpu().numpy() cs = cs.round(3) if mixed: s_real = D_map.clone() s_real = s_real*target_map # all fakes are zero now s_real = s_real.sum([2,3])/target_map.sum([2,3]) s_fake = D_map.clone() s_fake = s_fake*(1-target_map) # all real are zero now s_fake = s_fake.sum([2,3])/(1-target_map).sum([2,3]) s_fake = s_fake.view(-1) s_real = s_real.view(-1) cs_real = torch.cat((c.view(c.size(0),1) ,s_real.view(s_real.size(0),1) ),dim=1) cs_real = cs_real.cpu().numpy() cs_real = cs_real.round(3) cs_fake = torch.cat((c.view(c.size(0),1) ,s_fake.view(s_fake.size(0),1) ),dim=1) cs_fake = cs_fake.cpu().numpy() cs_fake = cs_fake.round(3) cs_mix = torch.cat((c.view(c.size(0),1) ,s_real.view(s_real.size(0),1) ,s_fake.view(s_fake.size(0),1) ),dim=1) cs_mix = cs_mix.cpu().numpy() cs_mix = cs_mix.round(3) for i in range(D_map.size(0)): D_map[i] = D_map[i] - D_map[i].min() D_map[i] = D_map[i]/D_map[i].max() image_filename = '%s/%s/fixed_samples_D%d' % (config['samples_root'],experiment_name,state_dict['itr']) image_filename += id + ".jpg" torchvision.utils.save_image(D_map.float().cpu(), image_filename, nrow=int(fixed_Gz.shape[0] **0.5), normalize=False) if config["resolution"]==128: num_per_sheet=16 num_midpoints=8 elif config["resolution"]==256: num_per_sheet=8 num_midpoints=4 elif config["resolution"]==64: num_per_sheet=32 num_midpoints=16 if not use_real: # For now, every time we save, also save sample sheets utils.sample_sheet(which_G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], num_classes=config['n_classes'], samples_per_class=10, parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=state_dict['itr'], z_=z_) # Also save interp sheets if config["dataset"]=="coco": for fix_z, fix_y in zip([False, False, True], [False, True, False]): utils.interp_sheet(which_G, num_per_sheet=num_per_sheet, num_midpoints=num_midpoints, num_classes=config['n_classes'], parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=state_dict['itr'], sheet_number=0, fix_z=fix_z, fix_y=fix_y, device='cuda',config=config)
def run(config): # Prepare state dict, which holds things like epoch # and itr # state_dict = { 'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0, 'best_FID': 999999, 'config': config } # Optionally, get the configuration from the state dict. This allows for # recovery of the config provided only a state dict and experiment name, # and can be convenient for writing less verbose sample shell scripts. if config['config_from_name']: utils.load_weights(None, None, state_dict, config['weights_root'], config['experiment_name'], config['load_weights'], None, strict=False, load_optim=False) # Ignore items which we might want to overwrite from the command line for item in state_dict['config']: if item not in [ 'z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode' ]: config[item] = state_dict['config'][item] # update config (see train.py for explanation) config['resolution'] = utils.imsize_dict[config['dataset']] config['n_classes'] = utils.nclass_dict[config['dataset']] config['n_channels'] = utils.nchannels_dict[config['dataset']] config['G_activation'] = utils.activation_dict[config['G_nl']] config['D_activation'] = utils.activation_dict[config['D_nl']] config = utils.update_config_roots(config) config['skip_init'] = True config['no_optim'] = True device = 'cuda' # Seed RNG # utils.seed_rng(config['seed']) # Setup cudnn.benchmark for free speed torch.backends.cudnn.benchmark = True # Import the model--this line allows us to dynamically select different files. model = __import__(config['model']) experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config)) print('Experiment name is %s' % experiment_name) G = model.Generator(**config).cuda() utils.count_parameters(G) # In some cases we need to load D if True or config['get_test_error'] or config['get_train_error'] or config[ 'get_self_error'] or config['get_generator_error']: disc_config = config.copy() if config['mh_csc_loss'] or config['mh_loss']: disc_config['output_dim'] = disc_config['n_classes'] + 1 D = model.Discriminator(**disc_config).to(device) def get_n_correct_from_D(x, y): """Gets the "classifications" from D. y: the correct labels In the case of projection discrimination we have to pass in all the labels as conditionings to get the class specific affinity. """ x = x.to(device) if config['model'] == 'BigGAN': # projection discrimination case if not config['get_self_error']: y = y.to(device) yhat = D(x, y) for i in range(1, config['n_classes']): yhat_ = D(x, ((y + i) % config['n_classes'])) yhat = torch.cat([yhat, yhat_], 1) preds_ = yhat.data.max(1)[1].cpu() return preds_.eq(0).cpu().sum() else: # the mh gan case if not config['get_self_error']: y = y.to(device) yhat = D(x) preds_ = yhat[:, :config['n_classes']].data.max(1)[1] return preds_.eq(y.data).cpu().sum() # Load weights print('Loading weights...') # Here is where we deal with the ema--load ema weights or load normal weights utils.load_weights(G if not (config['use_ema']) else None, D, state_dict, config['weights_root'], experiment_name, config['load_weights'], G if config['ema'] and config['use_ema'] else None, strict=False, load_optim=False) # Update batch size setting used for G G_batch_size = max(config['G_batch_size'], config['batch_size']) z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'], z_var=config['z_var']) if config['G_eval_mode']: print('Putting G in eval mode..') G.eval() else: print('G is in %s mode...' % ('training' if G.training else 'eval')) sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) brief_expt_name = config['experiment_name'][-30:] # load results dict always HIST_FNAME = 'scoring_hist.npy' def load_or_make_hist(d): """make/load history files in each """ if not os.path.isdir(d): raise Exception('%s is not a valid directory' % d) f = os.path.join(d, HIST_FNAME) if os.path.isfile(f): return np.load(f, allow_pickle=True).item() else: return defaultdict(dict) hist_dir = os.path.join(config['weights_root'], config['experiment_name']) hist = load_or_make_hist(hist_dir) if config['get_test_error'] or config['get_train_error']: loaders = utils.get_data_loaders( **{ **config, 'batch_size': config['batch_size'], 'start_itr': state_dict['itr'], 'use_test_set': config['get_test_error'] }) acc_type = 'Test' if config['get_test_error'] else 'Train' pbar = tqdm(loaders[0]) loader_total = len(loaders[0]) * config['batch_size'] sample_todo = min(config['sample_num_error'], loader_total) print('Getting %s error accross %i examples' % (acc_type, sample_todo)) correct = 0 total = 0 with torch.no_grad(): for i, (x, y) in enumerate(pbar): correct += get_n_correct_from_D(x, y) total += config['batch_size'] if loader_total > total and total >= config['sample_num_error']: print('Quitting early...') break accuracy = float(correct) / float(total) hist = load_or_make_hist(hist_dir) hist[state_dict['itr']][acc_type] = accuracy np.save(os.path.join(hist_dir, HIST_FNAME), hist) print('[%s][%06d] %s accuracy: %f.' % (brief_expt_name, state_dict['itr'], acc_type, accuracy * 100)) if config['get_self_error']: n_used_imgs = config['sample_num_error'] correct = 0 imageSize = config['resolution'] x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8) for l in tqdm(range(n_used_imgs // G_batch_size), desc='Generating [%s][%06d]' % (brief_expt_name, state_dict['itr'])): with torch.no_grad(): images, y = sample() correct += get_n_correct_from_D(images, y) accuracy = float(correct) / float(n_used_imgs) print('[%s][%06d] %s accuracy: %f.' % (brief_expt_name, state_dict['itr'], 'Self', accuracy * 100)) hist = load_or_make_hist(hist_dir) hist[state_dict['itr']]['Self'] = accuracy np.save(os.path.join(hist_dir, HIST_FNAME), hist) if config['get_generator_error']: if config['dataset'] == 'C10': from classification.models.densenet import DenseNet121 from torchvision import transforms compnet = DenseNet121() compnet = torch.nn.DataParallel(compnet) #checkpoint = torch.load(os.path.join('/scratch0/ilya/locDoc/classifiers/densenet121','ckpt_47.t7')) checkpoint = torch.load( os.path.join( '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar/densenet121', 'ckpt_47.t7')) compnet.load_state_dict(checkpoint['net']) compnet = compnet.to(device) compnet.eval() minimal_trans = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) elif config['dataset'] == 'C100': from classification.models.densenet import DenseNet121 from torchvision import transforms compnet = DenseNet121(num_classes=100) compnet = torch.nn.DataParallel(compnet) checkpoint = torch.load( os.path.join( '/scratch0/ilya/locDoc/classifiers/cifar100/densenet121', 'ckpt.copy.t7')) #checkpoint = torch.load(os.path.join('/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar100/densenet121','ckpt.copy.t7')) compnet.load_state_dict(checkpoint['net']) compnet = compnet.to(device) compnet.eval() minimal_trans = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)), ]) elif config['dataset'] == 'STL48': from classification.models.wideresnet import WideResNet48 from torchvision import transforms checkpoint = torch.load( os.path.join( '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/stl/mixmatch_48', 'model_best.pth.tar')) compnet = WideResNet48(num_classes=10) compnet = compnet.to(device) for param in compnet.parameters(): param.detach_() compnet.load_state_dict(checkpoint['ema_state_dict']) compnet.eval() minimal_trans = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) else: raise ValueError('Dataset %s has no comparison network.' % config['dataset']) n_used_imgs = 10000 correct = 0 mean_label = np.zeros(config['n_classes']) imageSize = config['resolution'] x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8) for l in tqdm(range(n_used_imgs // G_batch_size), desc='Generating [%s][%06d]' % (brief_expt_name, state_dict['itr'])): with torch.no_grad(): images, y = sample() fake = images.data.cpu().numpy() fake = np.floor((fake + 1) * 255 / 2.0).astype(np.uint8) fake_input = np.zeros(fake.shape) for bi in range(fake.shape[0]): fake_input[bi] = minimal_trans(np.moveaxis( fake[bi], 0, -1)) images.data.copy_(torch.from_numpy(fake_input)) lab = compnet(images).max(1)[1] mean_label += np.bincount(lab.data.cpu(), minlength=config['n_classes']) correct += int((lab == y).sum().cpu()) accuracy = float(correct) / float(n_used_imgs) mean_label_normalized = mean_label / float(n_used_imgs) print( '[%s][%06d] %s accuracy: %f.' % (brief_expt_name, state_dict['itr'], 'Generator', accuracy * 100)) hist = load_or_make_hist(hist_dir) hist[state_dict['itr']]['Generator'] = accuracy hist[state_dict['itr']]['Mean_Label'] = mean_label_normalized np.save(os.path.join(hist_dir, HIST_FNAME), hist) if config['accumulate_stats']: print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations']) utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], config['num_standing_accumulations']) # Sample a number of images and save them to an NPZ, for use with TF-Inception if config['sample_npz']: # Lists to hold images and labels for images x, y = [], [] print('Sampling %d images and saving them to npz...' % config['sample_num_npz']) for i in trange( int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))): with torch.no_grad(): images, labels = sample() x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)] y += [labels.cpu().numpy()] x = np.concatenate(x, 0)[:config['sample_num_npz']] y = np.concatenate(y, 0)[:config['sample_num_npz']] print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape)) npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name) print('Saving npz to %s...' % npz_filename) np.savez(npz_filename, **{'x': x, 'y': y}) if config['official_FID']: f = np.load(config['dataset_is_fid']) # this is for using the downloaded one from # https://github.com/bioinf-jku/TTUR #mdata, sdata = f['mu'][:], f['sigma'][:] # this one is for my format files mdata, sdata = f['mfid'], f['sfid'] # Sample a number of images and stick them in memory, for use with TF-Inception official_IS and official_FID data_gen_necessary = False if config['sample_np_mem']: is_saved = int('IS' in hist[state_dict['itr']]) is_todo = int(config['official_IS']) fid_saved = int('FID' in hist[state_dict['itr']]) fid_todo = int(config['official_FID']) data_gen_necessary = config['overwrite'] or (is_todo > is_saved) or ( fid_todo > fid_saved) if config['sample_np_mem'] and data_gen_necessary: n_used_imgs = 50000 imageSize = config['resolution'] x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8) for l in tqdm(range(n_used_imgs // G_batch_size), desc='Generating [%s][%06d]' % (brief_expt_name, state_dict['itr'])): start = l * G_batch_size end = start + G_batch_size with torch.no_grad(): images, labels = sample() fake = np.uint8(255 * (images.cpu().numpy() + 1) / 2.) x[start:end] = np.moveaxis(fake, 1, -1) #y += [labels.cpu().numpy()] if config['official_IS']: if (not ('IS' in hist[state_dict['itr']])) or config['overwrite']: mis, sis = iscore.get_inception_score(x) print('[%s][%06d] IS mu: %f. IS sigma: %f.' % (brief_expt_name, state_dict['itr'], mis, sis)) hist = load_or_make_hist(hist_dir) hist[state_dict['itr']]['IS'] = [mis, sis] np.save(os.path.join(hist_dir, HIST_FNAME), hist) else: mis, sis = hist[state_dict['itr']]['IS'] print( '[%s][%06d] Already done (skipping...): IS mu: %f. IS sigma: %f.' % (brief_expt_name, state_dict['itr'], mis, sis)) if config['official_FID']: import tensorflow as tf def fid_ms_for_imgs(images, mem_fraction=0.5): gpu_options = tf.GPUOptions( per_process_gpu_memory_fraction=mem_fraction) inception_path = fid.check_or_download_inception(None) fid.create_inception_graph( inception_path) # load the graph into the current TF graph with tf.Session(config=tf.ConfigProto( gpu_options=gpu_options)) as sess: sess.run(tf.global_variables_initializer()) mu_gen, sigma_gen = fid.calculate_activation_statistics( images, sess, batch_size=100) return mu_gen, sigma_gen if (not ('FID' in hist[state_dict['itr']])) or config['overwrite']: m1, s1 = fid_ms_for_imgs(x) fid_value = fid.calculate_frechet_distance(m1, s1, mdata, sdata) print('[%s][%06d] FID: %f' % (brief_expt_name, state_dict['itr'], fid_value)) hist = load_or_make_hist(hist_dir) hist[state_dict['itr']]['FID'] = fid_value np.save(os.path.join(hist_dir, HIST_FNAME), hist) else: fid_value = hist[state_dict['itr']]['FID'] print('[%s][%06d] Already done (skipping...): FID: %f' % (brief_expt_name, state_dict['itr'], fid_value)) # Prepare sample sheets if config['sample_sheets']: print('Preparing conditional sample sheets...') folder_number = config['sample_sheet_folder_num'] if folder_number == -1: folder_number = config['load_weights'] utils.sample_sheet( G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], num_classes=config['n_classes'], samples_per_class=10, parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=folder_number, z_=z_, ) # Sample interp sheets if config['sample_interps']: print('Preparing interp sheets...') folder_number = config['sample_sheet_folder_num'] if folder_number == -1: folder_number = config['load_weights'] for fix_z, fix_y in zip([False, False, True], [False, True, False]): utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8, num_classes=config['n_classes'], parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=int(folder_number), sheet_number=0, fix_z=fix_z, fix_y=fix_y, device='cuda') # Sample random sheet if config['sample_random']: print('Preparing random sample sheet...') images, labels = sample() torchvision.utils.save_image( images.float(), '%s/%s/%s.jpg' % (config['samples_root'], experiment_name, config['load_weights']), nrow=int(G_batch_size**0.5), normalize=True) # Prepare a simple function get metrics that we use for trunc curves def get_metrics(): # Get Inception Score and FID get_inception_metrics = inception_utils.prepare_inception_metrics( config['dataset'], config['parallel'], config['no_fid']) sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) IS_mean, IS_std, FID = get_inception_metrics( sample, config['num_inception_images'], num_splits=10, prints=False) # Prepare output string outstring = 'Using %s weights ' % ('ema' if config['use_ema'] else 'non-ema') outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else 'training') outstring += 'with noise variance %3.3f, ' % z_.var outstring += 'over %d images, ' % config['num_inception_images'] if config['accumulate_stats'] or not config['G_eval_mode']: outstring += 'with batch size %d, ' % G_batch_size if config['accumulate_stats']: outstring += 'using %d standing stat accumulations, ' % config[ 'num_standing_accumulations'] outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % ( state_dict['itr'], IS_mean, IS_std, FID) print(outstring) if config['sample_inception_metrics']: print('Calculating Inception metrics...') get_metrics() # Sample truncation curve stuff. This is basically the same as the inception metrics code if config['sample_trunc_curves']: start, step, end = [ float(item) for item in config['sample_trunc_curves'].split('_') ] print( 'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...' % (start, step, end)) for var in np.arange(start, end + step, step): z_.var = var # Optionally comment this out if you want to run with standing stats # accumulated at one z variance setting if config['accumulate_stats']: utils.accumulate_standing_stats( G, z_, y_, config['n_classes'], config['num_standing_accumulations']) get_metrics()
def run(config): # Prepare state dict, which holds things like epoch # and itr # state_dict = { 'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0, 'best_FID': 999999, 'config': config } # print(config) # exit() # Optionally, get the configuration from the state dict. This allows for # recovery of the config provided only a state dict and experiment name, # and can be convenient for writing less verbose sample shell scripts. if config['config_from_name']: # print(config['weights_root'],config['experiment_name'], config['load_weights']) utils.load_weights(None, None, state_dict, config['weights_root'], config['experiment_name'], config['load_weights'], None, strict=False, load_optim=False) # Ignore items which we might want to overwrite from the command line for item in state_dict['config']: if item not in [ 'z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode' ]: config[item] = state_dict['config'][item] # update config (see train.py for explanation) config['resolution'] = utils.imsize_dict[config['dataset']] config['n_classes'] = utils.nclass_dict[config['dataset']] config['G_activation'] = utils.activation_dict[config['G_nl']] config['D_activation'] = utils.activation_dict[config['D_nl']] config = utils.update_config_roots(config) config['skip_init'] = True config['no_optim'] = True device = 'cuda' # Seed RNG utils.seed_rng(config['seed']) # Setup cudnn.benchmark for free speed torch.backends.cudnn.benchmark = True # Import the model--this line allows us to dynamically select different files. model = __import__(config['model']) experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config)) print('Experiment name is %s' % experiment_name) G = model.Generator(**config).cuda() # zht: my code # D = model.Discriminator(**config).cuda() from torch.nn import ReLU config_fixed = { 'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 0, 'pin_memory': True, 'shuffle': True, 'load_in_mem': False, 'use_multiepoch_sampler': True, 'model': 'BigGAN', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'G_depth': 1, 'D_depth': 1, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'z_var': 1.0, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': True, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 64, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 100, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'G_eval_mode': True, 'save_every': 1000, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'data_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': '', 'config_from_name': False, 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': False, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'sample_npz': True, 'sample_num_npz': 50000, 'sample_sheets': True, 'sample_interps': True, 'sample_sheet_folder_num': -1, 'sample_random': True, 'sample_trunc_curves': '0.05_0.05_1.0', 'sample_inception_metrics': True, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace=True), 'D_activation': ReLU(inplace=True), 'no_optim': True } # config_fixed = {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 0, 'pin_memory': True, 'shuffle': True, 'load_in_mem': False, 'use_multiepoch_sampler': True, 'model': 'BigGAN', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'G_depth': 1, 'D_depth': 1, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'z_var': 1.0, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': True, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 64, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 100, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'G_eval_mode': True, 'save_every': 1000, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'data_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': '', 'config_from_name': False, 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': False, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'sample_npz': True, 'sample_num_npz': 50000, 'sample_sheets': True, 'sample_interps': True, 'sample_sheet_folder_num': -1, 'sample_random': True, 'sample_trunc_curves': '0.05_0.05_1.0', 'sample_inception_metrics': True, 'resolution': 128, 'n_classes': 1000, 'no_optim': True} D = model.Discriminator(**config_fixed).cuda() utils.load_weights(None, D, state_dict, config['weights_root'], experiment_name, config['load_weights'], None, strict=False, load_optim=False) D.eval() utils.count_parameters(G) # Load weights print('Loading weights...') # Here is where we deal with the ema--load ema weights or load normal weights utils.load_weights(G if not (config['use_ema']) else None, None, state_dict, config['weights_root'], experiment_name, config['load_weights'], G if config['ema'] and config['use_ema'] else None, strict=False, load_optim=False) # Update batch size setting used for G G_batch_size = max(config['G_batch_size'], config['batch_size']) z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'], z_var=config['z_var']) if config['G_eval_mode']: print('Putting G in eval mode..') G.eval() else: print('G is in %s mode...' % ('training' if G.training else 'eval')) #Sample function sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) if config['accumulate_stats']: print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations']) utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], config['num_standing_accumulations']) # Sample a number of images and save them to an NPZ, for use with TF-Inception if config['sample_npz']: # Lists to hold images and labels for images x, y = [], [] print('Sampling %d images and saving them to npz...' % config['sample_num_npz']) for i in trange( int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))): with torch.no_grad(): images, labels = sample() # zht: show discriminator results print(images.size(), labels.size()) dis_loss = D(x=images, y=labels) print(dis_loss.size()) print(dis_loss) exit() x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)] y += [labels.cpu().numpy()] plt.imshow(x[0][i, :, :, :].transpose((1, 2, 0))) plt.show() x = np.concatenate(x, 0)[:config['sample_num_npz']] y = np.concatenate(y, 0)[:config['sample_num_npz']] print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape)) npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name) print('Saving npz to %s...' % npz_filename) np.savez(npz_filename, **{'x': x, 'y': y}) # Prepare sample sheets if config['sample_sheets']: print('Preparing conditional sample sheets...') utils.sample_sheet( G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], num_classes=config['n_classes'], samples_per_class=10, parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=config['sample_sheet_folder_num'], z_=z_, ) # Sample interp sheets if config['sample_interps']: print('Preparing interp sheets...') for fix_z, fix_y in zip([False, False, True], [False, True, False]): utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8, num_classes=config['n_classes'], parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=config['sample_sheet_folder_num'], sheet_number=0, fix_z=fix_z, fix_y=fix_y, device='cuda') # Sample random sheet if config['sample_random']: print('Preparing random sample sheet...') images, labels = sample() torchvision.utils.save_image(images.float(), '%s/%s/random_samples.jpg' % (config['samples_root'], experiment_name), nrow=int(G_batch_size**0.5), normalize=True) # Get Inception Score and FID get_inception_metrics = inception_utils.prepare_inception_metrics( config['dataset'], config['parallel'], config['no_fid']) # Prepare a simple function get metrics that we use for trunc curves def get_metrics(): sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) IS_mean, IS_std, FID = get_inception_metrics( sample, config['num_inception_images'], num_splits=10, prints=False) # Prepare output string outstring = 'Using %s weights ' % ('ema' if config['use_ema'] else 'non-ema') outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else 'training') outstring += 'with noise variance %3.3f, ' % z_.var outstring += 'over %d images, ' % config['num_inception_images'] if config['accumulate_stats'] or not config['G_eval_mode']: outstring += 'with batch size %d, ' % G_batch_size if config['accumulate_stats']: outstring += 'using %d standing stat accumulations, ' % config[ 'num_standing_accumulations'] outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % ( state_dict['itr'], IS_mean, IS_std, FID) print(outstring) if config['sample_inception_metrics']: print('Calculating Inception metrics...') get_metrics() # Sample truncation curve stuff. This is basically the same as the inception metrics code if config['sample_trunc_curves']: start, step, end = [ float(item) for item in config['sample_trunc_curves'].split('_') ] print( 'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...' % (start, step, end)) for var in np.arange(start, end + step, step): z_.var = var # Optionally comment this out if you want to run with standing stats # accumulated at one z variance setting if config['accumulate_stats']: utils.accumulate_standing_stats( G, z_, y_, config['n_classes'], config['num_standing_accumulations']) get_metrics()
def run(config): # Prepare state dict, which holds things like epoch # and itr # state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0, 'best_FID': 999999, 'config': config} # Optionally, get the configuration from the state dict. This allows for # recovery of the config provided only a state dict and experiment name, # and can be convenient for writing less verbose sample shell scripts. if config['config_from_name']: utils.load_weights(None, None, state_dict, config['weights_root'], config['experiment_name'], config['load_weights'], None, strict=False, load_optim=False) # Ignore items which we might want to overwrite from the command line for item in state_dict['config']: if item not in ['z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode']: config[item] = state_dict['config'][item] # Update the config dict as necessary # This is for convenience, to add settings derived from the user-specified # configuration into the config-dict (e.g. inferring the number of classes # and size of the images from the dataset, passing in a pytorch object # for the activation specified as a string) config['resolution'] = utils.imsize_dict[config['dataset']] config['n_classes'] = utils.nclass_dict[config['dataset']] config['G_activation'] = utils.activation_dict[config['G_nl']] config['D_activation'] = utils.activation_dict[config['D_nl']] # By default, skip init if resuming training. if config['resume']: print('Skipping initialization for training resumption...') config['skip_init'] = True config = utils.update_config_roots(config) device = 'cuda' # Seed RNG utils.seed_rng(config['seed']) # Prepare root folders if necessary utils.prepare_root(config) # Setup cudnn.benchmark for free speed torch.backends.cudnn.benchmark = True # Import the model--this line allows us to dynamically select different files. model = __import__(config['model']) experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config)) print('Experiment name is %s' % experiment_name) # Next, build the model G = model.Generator(**config).to(device) D = model.Discriminator(**config).to(device) utils.count_parameters(G) # Load weights print('Loading weights...') # Here is where we deal with the ema--load ema weights or load normal weights utils.load_weights(G if not (config['use_ema']) else None, None, state_dict, config['weights_root'], experiment_name, config['load_weights'], G if config['ema'] and config['use_ema'] else None, strict=False, load_optim=False) # Update batch size setting used for G G_batch_size = max(config['G_batch_size'], config['batch_size']) z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'], z_var=config['z_var']) if config['G_eval_mode']: print('Putting G in eval mode..') G.eval() else: print('G is in %s mode...' % ('training' if G.training else 'eval')) #Sample function sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) if config['accumulate_stats']: print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations']) utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], config['num_standing_accumulations']) # Sample a number of images and save them to an NPZ, for use with TF-Inception if config['sample_npz']: # Lists to hold images and labels for images x, y = [], [] print('Sampling %d images and saving them to npz...' % config['sample_num_npz']) for i in trange(int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))): with torch.no_grad(): images, labels = sample() x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)] y += [labels.cpu().numpy()] x = np.concatenate(x, 0)[:config['sample_num_npz']] y = np.concatenate(y, 0)[:config['sample_num_npz']] print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape)) npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name) print('Saving npz to %s...' % npz_filename) np.savez(npz_filename, **{'x' : x, 'y' : y}) # Prepare sample sheets if config['sample_sheets']: print('Preparing conditional sample sheets...') utils.sample_sheet(G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], num_classes=config['n_classes'], samples_per_class=10, parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=config['sample_sheet_folder_num'], z_=z_,) # Sample interp sheets if config['sample_interps']: print('Preparing interp sheets...') for fix_z, fix_y in zip([False, False, True], [False, True, False]): utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8, num_classes=config['n_classes'], parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=config['sample_sheet_folder_num'], sheet_number=0, fix_z=fix_z, fix_y=fix_y, device='cuda') # Sample random sheet if config['sample_random']: print('Preparing random sample sheet...') images, labels = sample() print("labels size", labels) torchvision.utils.save_image(images.float(), '%s/%s/random_samples.jpg' % (config['samples_root'], experiment_name), nrow=int(G_batch_size**0.5), normalize=True) # If using EMA, prepare it if config['ema']: print('Preparing EMA for G with decay of {}'.format(config['ema_decay'])) G_ema = model.Generator(**{**config, 'skip_init':True, 'no_optim': True}).to(device) ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start']) else: G_ema, ema = None, None # FP16? if config['G_fp16']: print('Casting G to float16...') G = G.half() if config['ema']: G_ema = G_ema.half() if config['D_fp16']: print('Casting D to fp16...') D = D.half() # Consider automatically reducing SN_eps? GD = model.G_D(G, D) #print(G) #print(D) print('Number of params in G: {} D: {}'.format( *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]])) # Prepare state dict, which holds things like epoch # and itr # state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0, 'best_FID': 999999, 'config': config} # If loading from a pre-trained model, load weights if config['resume']: print('Loading weights...') utils.load_weights(G, D, state_dict, config['weights_root'], experiment_name, config['load_weights'] if config['load_weights'] else None, G_ema if config['ema'] else None) # If parallel, parallelize the GD module if config['parallel']: GD = nn.DataParallel(GD) if config['cross_replica']: patch_replication_callback(GD) D_fake = D(images[1,:,:,:],labels[0]) print("D_fake ",D_fake)
def activation_extract(G, D, E, G_ema, fixed_x, fixed_y, z_, y_, state_dict, config, experiment_name, save_weights=False): # Use EMA G for samples or non-EMA? which_G = G_ema if config['ema'] and config['use_ema'] else G # Save a random sample sheet with fixed z and y # TODO: change here to encode fixed x into z and feed z with fixed y into G print("check2 ---------------------------") with torch.no_grad(): if config['parallel']: print("fixed_x: ", fixed_x.shape) if config['inference_nosample']: _, fixed_z, _ = nn.parallel.data_parallel(E, fixed_x) else: fixed_z, _, _ = nn.parallel.data_parallel(E, fixed_x) print("fixed_z: ", fixed_z.shape) print("fixed_y: ", fixed_y.shape) fixed_Gz = nn.parallel.data_parallel( which_G, (fixed_z, which_G.shared(fixed_y))) else: # print("fixed_x: ", fixed_x.shape) fixed_z, _, _ = E(fixed_x) # print("fixed_z: ", fixed_z.shape) # print("fixed_y: ", fixed_y.shape) #### TODO: Catch the intermediate output here AND save the results into numpy fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y)) print("check3 -----------------------------") if not os.path.isdir('%s/%s' % (config['samples_root'], experiment_name)): os.mkdir('%s/%s' % (config['samples_root'], experiment_name)) image_filename = '%s/%s/fixed_samples%d.jpg' % ( config['samples_root'], experiment_name, state_dict['itr']) print("######### save_path #####: ", image_filename) torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename, nrow=int(fixed_Gz.shape[0]**0.5), normalize=True) # For now, every time we save, also save sample sheets utils.sample_sheet( which_G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], num_classes=config['n_classes'], samples_per_class=10, parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=state_dict['itr'], z_=z_) # Also save interp sheets for fix_z, fix_y in zip([False, False, True], [False, True, False]): utils.interp_sheet(which_G, num_per_sheet=16, num_midpoints=8, num_classes=config['n_classes'], parallel=config['parallel'], samples_root=config['samples_root'], experiment_name=experiment_name, folder_number=state_dict['itr'], sheet_number=0, fix_z=fix_z, fix_y=fix_y, device='cuda')