def run(config):
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    if config['resume']:
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    G3 = model.Generator(**config).to(device)
    D3 = model.Discriminator(**config).to(device)
    if config['ema']:
        G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    if config['G_fp16']:
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        D = D.half()
    GD = model.G_D(G, D, config['conditional'])
    GD3 = model.G_D(G3, D3, config['conditional'])
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    if config['resume']:
        utils.load_weights(G, D, state_dict, config['weights_root'], experiment_name,
                           config['load_weights'] if config['load_weights'] else None,
                           G_ema if config['ema'] else None)
    #utils.load_weights(G, D, state_dict, '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    #utils.load_weights(G, D, state_dict, '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    #utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'last0', G_ema if config['ema'] else None)
    #utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights',
                       'C10Ukl5', 'last0', G_ema if config['ema'] else None)
    utils.load_weights(G, D, state_dict, '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights',
                       'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    test_log = utils.MetricsLogger(test_metrics_fname, reinitialize=(not config['resume']))
    train_log = utils.MyLogger(train_metrics_fname, reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] * config['num_D_accumulations'])
    # Use: config['abnormal_class']
    #print(config['abnormal_class'])
    abnormal_class = config['abnormal_class']
    select_dataset = config['select_dataset']
    #print(config['select_dataset'])
    #print(select_dataset)
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                        'start_itr': state_dict['itr'],
                                        'abnormal_class': abnormal_class,
                                        'select_dataset': select_dataset})
    # Usage: --select_dataset cifar10 --abnormal_class 0 --shuffle --batch_size 64 --parallel --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 --dataset C10 --data_root ../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/data/ --G_ortho 0.0 --G_attn 0 --D_attn 0 --G_init N02 --D_init N02 --ema --use_ema --ema_start 1000 --start_eval 50 --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --loss_type kl_5 --seed 2 --which_best FID --model BigGAN --experiment_name C10Ukl5
    # Usage: --select_dataset mnist --abnormal_class 1 --shuffle --batch_size 64 --parallel --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 --dataset C10 --data_root ../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/data/ --G_ortho 0.0 --G_attn 0 --D_attn 0 --G_init N02 --D_init N02 --ema --use_ema --ema_start 1000 --start_eval 50 --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --loss_type kl_5 --seed 2 --which_best FID --model BigGAN --experiment_name C10Ukl5
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                                         device=device, fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G3, D3, GD3, G3, D3, GD3, G, D, GD,
                                                z_, y_, ema, state_dict, config)
    else:
        train = train_fns.dummy_training_function()
    sample = functools.partial(utils.sample,
                               G=(G_ema if config['ema'] and config['use_ema'] else G),
                               z_=z_, y_=y_, config=config)
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
        #data_moments = '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz'
    else:
        print("Cannot find the dataset.")
        sys.exit()
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            state_dict['itr'] += 1
            G.eval()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            print('')
            # Random seed
            #print(config['seed'])
            if epoch == 0 and i == 0:
                print(config['seed'])
            metrics = train(x, y)
            # We double the learning rate if we double the batch size.
            train_log.log(itr=int(state_dict['itr']), **metrics)
            if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']]
                                + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]), end=' ')
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                          state_dict, config, experiment_name)
            experiment_name = (config['experiment_name'] if config['experiment_name']
                               else utils.name_from_config(config))
            if (not (state_dict['itr'] % config['test_every'])) and (epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                utils.sample_inception(G_ema if config['ema'] and config['use_ema'] else G,
                                       config, str(epoch))
                folder_number = str(epoch)
                sample_moments = '%s/%s/%s/samples.npz' % (config['samples_root'],
                                                           experiment_name, folder_number)
                FID = fid_score.calculate_fid_given_paths([data_moments, sample_moments],
                                                          batch_size=50, cuda=True, dims=2048)
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID, experiment_name, test_log)
        state_dict['epoch'] += 1
        #utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'be01Bes01Best%d' % state_dict['save_best_num'], G_ema if config['ema'] else None)
        utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name,
                           'last%d' % 0, G_ema if config['ema'] else None)
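# A minimal command-line entry point for run() above, shown as a sketch only: it assumes
# the repo follows the BigGAN convention of exposing an argparse builder named
# utils.prepare_parser(); if the parser lives in a different module, adjust the call.
def main():
    # Parse the command line into the config dict consumed by run(), e.g.
    #   python train.py --select_dataset cifar10 --abnormal_class 0 --shuffle --batch_size 64 \
    #       --parallel --loss_type kl_5 --which_best FID --model BigGAN --experiment_name C10Ukl5
    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    print(config)
    run(config)

if __name__ == '__main__':
    main()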
def run(config):
    config['resolution'] = imsize_dict[config['dataset']]
    config['n_classes'] = nclass_dict[config['dataset']]
    config['G_activation'] = activation_dict[config['G_nl']]
    config['D_activation'] = activation_dict[config['D_nl']]
    if config['resume']:
        config['skip_init'] = True
    config = update_config_roots(config)
    device = 'cuda'
    utils_Task1_KLWGAN_Simulation_Experiment.seed_rng(config['seed'])
    utils_Task1_KLWGAN_Simulation_Experiment.prepare_root(config)
    torch.backends.cudnn.benchmark = True
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils_Task1_KLWGAN_Simulation_Experiment.name_from_config(config))
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    if config['ema']:
        G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = utils_Task1_KLWGAN_Simulation_Experiment.ema(G, G_ema, config['ema_decay'],
                                                           config['ema_start'])
    else:
        G_ema, ema = None, None
    if config['G_fp16']:
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        D = D.half()
    GD = model.G_D(G, D, config['conditional'])
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    if config['resume']:
        utils_Task1_KLWGAN_Simulation_Experiment.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    test_log = utils_Task1_KLWGAN_Simulation_Experiment.MetricsLogger(
        test_metrics_fname, reinitialize=(not config['resume']))
    train_log = utils_Task1_KLWGAN_Simulation_Experiment.MyLogger(
        train_metrics_fname, reinitialize=(not config['resume']), logstyle=config['logstyle'])
    utils_Task1_KLWGAN_Simulation_Experiment.write_metadata(config['logs_root'], experiment_name,
                                                            config, state_dict)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] * config['num_D_accumulations'])
    # Use: config['abnormal_class']
    #print(config['abnormal_class'])
    abnormal_class = config['abnormal_class']
    select_dataset = config['select_dataset']
    #print(config['select_dataset'])
    #print(select_dataset)
    loaders = utils_Task1_KLWGAN_Simulation_Experiment.get_data_loaders(
        **{**config, 'batch_size': D_batch_size, 'start_itr': state_dict['itr'],
           'abnormal_class': abnormal_class, 'select_dataset': select_dataset})
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils_Task1_KLWGAN_Simulation_Experiment.prepare_z_y(
        G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z, fixed_y = utils_Task1_KLWGAN_Simulation_Experiment.prepare_z_y(
        G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema, state_dict, config)
    else:
        train = train_fns.dummy_training_function()
    sample = functools.partial(utils_Task1_KLWGAN_Simulation_Experiment.sample,
                               G=(G_ema if config['ema'] and config['use_ema'] else G),
                               z_=z_, y_=y_, config=config)
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
        #data_moments = '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz'
    else:
        print("Cannot find the dataset.")
        sys.exit()
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['pbar'] == 'mine':
            pbar = utils_Task1_KLWGAN_Simulation_Experiment.progress(
                loaders[0], displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            state_dict['itr'] += 1
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            print('')
            # Random seed
            #print(config['seed'])
            if epoch == 0 and i == 0:
                print(config['seed'])
            metrics = train(x, y)
            # We double the learning rate if we double the batch size.
            train_log.log(itr=int(state_dict['itr']), **metrics)
            if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils_Task1_KLWGAN_Simulation_Experiment.get_SVs(G, 'G'),
                                 **utils_Task1_KLWGAN_Simulation_Experiment.get_SVs(D, 'D')})
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']]
                                + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]), end=' ')
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                          state_dict, config, experiment_name)
            experiment_name = (config['experiment_name'] if config['experiment_name']
                               else utils_Task1_KLWGAN_Simulation_Experiment.name_from_config(config))
            if (not (state_dict['itr'] % config['test_every'])) and (epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                utils_Task1_KLWGAN_Simulation_Experiment.sample_inception(
                    G_ema if config['ema'] and config['use_ema'] else G, config, str(epoch))
                folder_number = str(epoch)
                sample_moments = '%s/%s/%s/samples.npz' % (config['samples_root'],
                                                           experiment_name, folder_number)
                # Use the files train_fns.py and utils_Task1_KLWGAN_Simulation_Experiment.py
                # Use the functions update_FID() and save_weights()
                # Save the lowest FID score
                FID = fid_score.calculate_fid_given_paths([data_moments, sample_moments],
                                                          batch_size=50, cuda=True, dims=2048)
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID, experiment_name, test_log)
                # FID also from: https://github.com/DarthSid95/RumiGANs/blob/main/gan_metrics.py
                # Implicit generative models and GANs generate sharp, low-FID, realistic, and high-quality images.
                # We use implicit generative models and GANs for the challenging task of anomaly detection in high-dimensional spaces.
        state_dict['epoch'] += 1
        # Save the last model
        utils_Task1_KLWGAN_Simulation_Experiment.save_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            'last%d' % 0, G_ema if config['ema'] else None)
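# For reference, a simplified sketch of what the exponential-moving-average helper created
# above (the `ema` object) typically does: after ema_start iterations it blends G's weights
# into G_ema with the given decay. The repo's own utils file provides the real implementation;
# the class name and the update(itr) signature here are assumptions for illustration only.
import torch

class SimpleEMA(object):
    def __init__(self, source, target, decay=0.9999, start_itr=1000):
        self.source, self.target = source, target
        self.decay, self.start_itr = decay, start_itr
        # Start the target as an exact copy of the source.
        with torch.no_grad():
            for s, t in zip(source.parameters(), target.parameters()):
                t.copy_(s)

    def update(self, itr=None):
        # Before start_itr, track the source exactly (decay 0); afterwards, blend slowly.
        decay = self.decay if (itr is None or itr >= self.start_itr) else 0.0
        with torch.no_grad():
            for s, t in zip(self.source.parameters(), self.target.parameters()):
                t.copy_(decay * t + (1.0 - decay) * s)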
def run(config):
    # Update the config dict as necessary.
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a PyTorch object
    # for the activation specified as a string).
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    # Seed RNG
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True
    # Import the model -- this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)
    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D, config['conditional'])  # setting conditional to false
    print(G)
    print(D)
    print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G, D, state_dict, config['weights_root'], experiment_name,
                           config['load_weights'] if config['load_weights'] else None,
                           G_ema if config['ema'] else None)
    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname, reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname, reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations).
    D_batch_size = (config['batch_size'] * config['num_D_steps'] * config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                        'start_itr': state_dict['itr']})
    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                                         device=device, fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema, state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(utils.sample,
                               G=(G_ema if config['ema'] and config['use_ema'] else G),
                               z_=z_, y_=y_, config=config)
    print('Beginning training at epoch %d...' % state_dict['epoch'])
    print('Total training epochs ', config['num_epochs'])
    print('The dataset is ', config['dataset'])
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
    else:
        print('Cannot find the dataset.')
        sys.exit()
    print('The data moments file is ', data_moments)
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progress bar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval.
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)
            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
            # If using my progress bar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']]
                                + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]), end=' ')
            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                          state_dict, config, experiment_name)
            # Test every specified interval
            experiment_name = (config['experiment_name'] if config['experiment_name']
                               else utils.name_from_config(config))
            if (not (state_dict['itr'] % config['test_every'])) and (epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                # Sample images and save them to samples/<experiment_name>/<epoch>
                utils.sample_inception(G_ema if config['ema'] and config['use_ema'] else G,
                                       config, str(epoch))
                # Get saved sample path
                folder_number = str(epoch)
                sample_moments = '%s/%s/%s/samples.npz' % (config['samples_root'],
                                                           experiment_name, folder_number)
                # Calculate FID
                FID = fid_score.calculate_fid_given_paths([data_moments, sample_moments],
                                                          batch_size=50, cuda=True, dims=2048)
                print('FID calculated')
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID, experiment_name, test_log)
                # train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                #                get_inception_metrics, experiment_name, test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
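# To illustrate the comment above on D_batch_size = batch_size * num_D_steps * num_D_accumulations:
# each loader iteration yields one mega-batch that the training function can split into
# num_D_steps * num_D_accumulations chunks of size batch_size. This is only a schematic sketch
# of that splitting; the repo's actual logic lives in train_fns.GAN_training_function, and the
# helper name below is hypothetical.
import torch

def split_D_batch(x, y, config):
    # Split a single D mega-batch into per-step, per-accumulation chunks of size batch_size.
    chunks_x = torch.split(x, config['batch_size'])
    chunks_y = torch.split(y, config['batch_size'])
    counter = 0
    for step in range(config['num_D_steps']):
        for accumulation in range(config['num_D_accumulations']):
            # Each yielded chunk feeds one forward/backward pass of D; gradients are
            # accumulated across accumulations before the optimizer step for this D step.
            yield step, chunks_x[counter], chunks_y[counter]
            counter += 1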