def run(config):
  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = vae_utils.update_config_roots(config)
  device = 'cuda'

  # Seed RNG
  utils.seed_rng(config['seed'])
  # Prepare root folders if necessary
  utils.prepare_root(config)
  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  # Derive the experiment name from the config if one was not given explicitly.
  experiment_name = (config['experiment_name'] if config['experiment_name']
                     else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  E = Ex.Extractor(**config).to(device)

  # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for E with decay of {}'.format(config['ema_decay']))
    E_ema = Ex.Extractor(**{**config, 'skip_init': True,
                            'no_optim': True}).to(device)
    ema = utils.ema(E, E_ema, config['ema_decay'], config['ema_start'])
  else:
    E_ema, ema = None, None

  print(E)
  print('Number of params in E: {}'.format(
      sum([p.data.nelement() for p in E.parameters()])))

  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    vae_utils.load_weights([E], state_dict, config['weights_root'],
                           experiment_name,
                           config['load_weights'] if config['load_weights'] else None,
                           [E_ema] if config['ema'] else None)

  # If parallel, parallelize the extractor
  if config['parallel']:
    E_parallel = nn.DataParallel(E)
    if config['cross_replica']:
      patch_replication_callback(E_parallel)
  else:
    E_parallel = E

  # Prepare loggers for stats; metrics holds test metrics,
  # lmetrics holds any desired training metrics.
  test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
  train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
  print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
  test_log = utils.MetricsLogger(test_metrics_fname,
                                 reinitialize=(not config['resume']))
  print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname,
                             reinitialize=(not config['resume']),
                             logstyle=config['logstyle'])
  # Write metadata
  utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)

  # Prepare data; only the loader batch size needs to be passed in.
  # Note that at every loader iteration we pass in enough data to complete
  # a full training iteration (the factor of 8 is hard-coded here).
  D_batch_size = (config['batch_size'] * 8)
  loaders = mini_datasets.get_data_loaders(**{**config,
                                              'batch_size': D_batch_size,
                                              'start_itr': state_dict['itr']})

  # Prepare noise and randomly sampled label arrays
  # Allow for different batch sizes in G
  G_batch_size = max(config['G_batch_size'], config['batch_size'])

  # Loaders are loaded, prepare the training function
  if config['which_train_fn'] == 'GAN':
    train = Ex.Extractor_training_function(E, ema, E_parallel, state_dict, config)
  # Else, assume debugging and use the dummy train fn
  else:
    train = train_fns.dummy_training_function()

  print('Beginning training at epoch %d...' % state_dict['epoch'])
  # Train for specified number of epochs, although we mostly track G iterations.
  for epoch in range(state_dict['epoch'], config['num_epochs']):
    # Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
      pbar = utils.progress(zip(loaders[0], loaders[1]),
                            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
      pbar = tqdm(zip(loaders[0], loaders[1]))
    # Each step draws one labelled batch and one unlabelled batch and
    # concatenates them.
    for i, ((lx, ly), (ux, uy)) in enumerate(pbar):
      x = torch.cat([lx, ux], 0)
      y = torch.cat([ly, uy])
      # Increment the iteration counter
      state_dict['itr'] += 1
      # Make sure E is in training mode, just in case it got set to eval.
      E.train()
      if config['ema']:
        E_ema.train()
      if config['D_fp16']:
        x, y = x.to(device).half(), y.to(device)
      else:
        x, y = x.to(device), y.to(device)
      metrics = train(x, y)
      train_log.log(itr=int(state_dict['itr']), **metrics)

      # Every sv_log_interval, log singular values
      if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
        train_log.log(itr=int(state_dict['itr']),
                      **utils.get_SVs(E, 'E'))

      # If using my progbar, print metrics.
      if config['pbar'] == 'mine':
        print(', '.join(['itr: %d' % state_dict['itr']]
                        + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
              end=' ')

      # Save weights and copies as configured at specified interval.
      # NOTE: the two blocks below are carried over from the GAN training
      # script and still reference objects (G, D, G_ema, z_, y_, fixed_z,
      # fixed_y, sample, get_inception_metrics) that are not constructed in
      # this extractor run(); they need to be adapted or removed before use.
      if not (state_dict['itr'] % config['save_every']):
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                  state_dict, config, experiment_name)

      # Test every specified interval
      if not (state_dict['itr'] % config['test_every']):
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)

    # Increment epoch counter at end of epoch
    state_dict['epoch'] += 1
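
# Illustrative sketch: `utils.ema(E, E_ema, ...)` above is used but not defined
# in this file. The class below is a minimal, self-contained sketch of the EMA
# weight-copying pattern it presumably implements; the name `SimpleEMA`, its
# constructor signature, and the warm-up behaviour are assumptions made for
# illustration only (the real helper may also track buffers, etc.).
class SimpleEMA(object):
  """Keep `target` as an exponential moving average of `source`'s parameters."""

  def __init__(self, source, target, decay=0.9999, start_itr=0):
    self.source, self.target = source, target
    self.decay, self.start_itr = decay, start_itr

  @torch.no_grad()
  def update(self, itr):
    # Warm-up: hard-copy the source weights until start_itr is reached,
    # then blend with the configured decay.
    decay = self.decay if itr >= self.start_itr else 0.0
    src = dict(self.source.named_parameters())
    for name, tgt_p in self.target.named_parameters():
      tgt_p.copy_(decay * tgt_p + (1.0 - decay) * src[name])

# Hypothetical usage, mirroring how `ema.update(state_dict['itr'])` is called
# from the training functions above:
#   ema = SimpleEMA(E, E_ema, decay=config['ema_decay'], start_itr=config['ema_start'])
#   ema.update(state_dict['itr'])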
def run(config):
  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = vae_utils.update_config_roots(config)
  device = 'cuda'

  # Seed RNG
  utils.seed_rng(config['seed'])
  # Prepare root folders if necessary
  utils.prepare_root(config)
  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  # Derive the experiment name from the config if one was not given explicitly.
  experiment_name = (config['experiment_name'] if config['experiment_name']
                     else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = Generator(**config).to(device)

  # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = Generator(**{**config, 'skip_init': True,
                         'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None

  print(G)
  print('Number of params in G: {}'.format(
      sum([p.data.nelement() for p in G.parameters()])))

  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config,
                'best_precise': 0.0}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    vae_utils.load_weights([G], state_dict, config['weights_root'],
                           experiment_name,
                           config['load_weights'] if config['load_weights'] else None,
                           [G_ema] if config['ema'] else [None])

  # Thin wrapper so DataParallel sees a single module mapping (w, y) -> image.
  class Wrapper(nn.Module):
    def __init__(self):
      super(Wrapper, self).__init__()
      self.G = G

    def forward(self, w, y):
      x = self.G(w, self.G.shared(y))
      return x

  W = Wrapper()

  # If parallel, parallelize the wrapper module
  if config['parallel']:
    W = nn.DataParallel(W)
    if config['cross_replica']:
      patch_replication_callback(W)

  # Prepare loggers for stats; metrics holds test metrics,
  # lmetrics holds any desired training metrics.
  test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
  train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
  print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
  test_log = utils.MetricsLogger(test_metrics_fname,
                                 reinitialize=(not config['resume']))
  print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname,
                             reinitialize=(not config['resume']),
                             logstyle=config['logstyle'])
  # Write metadata
  utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)

  # Prepare inception metrics: FID and IS
  get_inception_metrics = inception_utils.prepare_inception_metrics(
      config['dataset'], config['parallel'], config['data_root'], config['no_fid'])

  # Prepare noise and label distributions, plus a fixed (w, y) pair used to
  # track sample evolution over training.
  z_, y_ = utils.prepare_z_y(config['batch_size'], G.dim_z, config['n_classes'],
                             device=device, fp16=config['G_fp16'])
  fixed_w, fixed_y = utils.prepare_z_y(config['batch_size'], G.dim_z,
                                       config['n_classes'], device=device,
                                       fp16=config['G_fp16'])
  fixed_w.sample_()
  fixed_y.sample_()

  # Decay the generator learning rate by 10x every 50 epochs.
  G_scheduler = torch.optim.lr_scheduler.StepLR(G.optim, step_size=50, gamma=0.1)
  MSE = torch.nn.MSELoss(reduction='mean')

  def train(w, img):
    y_.sample_()
    G.optim.zero_grad()
    x = W(w, y_)
    loss = MSE(x, img)
    loss.backward()
    if config['E_ortho'] > 0.0:
      # Debug print to indicate we're using ortho reg in G.
      print('using modified ortho reg in G')
      utils.ortho(G, config['G_ortho'])
    G.optim.step()
    out = {'loss': float(loss.item())}
    if config['ema']:
      ema.update(state_dict['itr'])
    del loss, x
    return out

  # Frozen linear layer whose weight is a precomputed 120x120 embedding matrix
  # loaded from disk (path is hard-coded).
  class Embed(nn.Module):
    def __init__(self):
      super(Embed, self).__init__()
      embed = np.load('/ghome/fengrl/home/FGAN/embed_ema.npy')
      self.dense = nn.Linear(120, 120, bias=False)
      self.embed = torch.tensor(embed, requires_grad=False)
      self.dense.load_state_dict({'weight': self.embed})
      for param in self.dense.parameters():
        param.requires_grad = False

    def forward(self, z):
      z = self.dense(z)
      return z

  embedding = Embed().to(device)
  fixed_w = embedding(fixed_w)

  sample = functools.partial(
      sample_with_embed,
      embed=embedding,
      G=(G_ema if config['ema'] and config['use_ema'] else G),
      z_=z_, y_=y_, config=config)

  batch_size = (config['batch_size'] * config['num_D_steps']
                * config['num_D_accumulations'])
  loader = sampled_ssgan.get_SSGAN_sample_loader(**{**config,
                                                    'batch_size': batch_size,
                                                    'start_itr': state_dict['itr'],
                                                    'is_slice': False})

  print('Beginning training at epoch %d...' % state_dict['epoch'])
  # Train for specified number of epochs, although we mostly track G iterations.
  for epoch in range(state_dict['epoch'], config['num_epochs']):
    # Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
      pbar = utils.progress(loader, displaytype='eta')
    else:
      pbar = tqdm(loader)
    for i, (img, z, w) in enumerate(pbar):
      # Increment the iteration counter
      state_dict['itr'] += 1
      # Make sure G is in training mode, just in case it got set to eval.
      G.train()
      if config['ema']:
        G_ema.train()
      img, w = img.to(device), w.to(device)
      # Split the large loader batch into per-step chunks; only the first
      # chunk is consumed here.
      img = torch.split(img, config['batch_size'])
      w = torch.split(w, config['batch_size'])
      counter = 0
      metrics = train(w[counter], img[counter])
      counter += 1
      del img, w
      train_log.log(itr=int(state_dict['itr']), **metrics)

      # Every sv_log_interval, log singular values
      if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
        train_log.log(itr=int(state_dict['itr']),
                      **utils.get_SVs(G, 'G'))

      # If using my progbar, print metrics.
      if config['pbar'] == 'mine':
        print(', '.join(['itr: %d' % state_dict['itr']]
                        + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
              end=' ')

      # Save weights and copies as configured at specified interval
      if not (state_dict['itr'] % config['save_every']):
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        train_fns.save_and_sample(G, None, G_ema, z_, y_, fixed_w, fixed_y,
                                  state_dict, config, experiment_name)

      # Test every specified interval
      if not (state_dict['itr'] % config['test_every']):
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
        train_fns.test(G, None, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)

    # Increment epoch counter and step the LR scheduler at end of epoch
    state_dict['epoch'] += 1
    G_scheduler.step()
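
# Illustrative sketch: the `Embed` module above bakes a precomputed 120x120
# matrix into a frozen nn.Linear. The helper below is a self-contained sketch
# of that pattern; the class name, the random stand-in matrix, and the shapes
# are placeholders for illustration, not the real `embed_ema.npy` data.
class FrozenLinearFromNumpy(nn.Module):
  """Wrap a precomputed weight matrix in a linear layer the optimizer never updates."""

  def __init__(self, weight):
    super(FrozenLinearFromNumpy, self).__init__()
    out_dim, in_dim = weight.shape
    self.dense = nn.Linear(in_dim, out_dim, bias=False)
    # Load the weights, then freeze them so no gradient is ever computed.
    self.dense.load_state_dict({'weight': torch.tensor(weight, dtype=torch.float32)})
    for p in self.dense.parameters():
      p.requires_grad = False

  def forward(self, z):
    return self.dense(z)

# Hypothetical usage (a random matrix stands in for np.load('.../embed_ema.npy')):
#   weight = np.random.randn(120, 120).astype(np.float32)
#   embed = FrozenLinearFromNumpy(weight)
#   w = torch.randn(4, 120)
#   print(embed(w).shape)  # torch.Size([4, 120])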
def run(config):
  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = vae_utils.update_config_roots(config)
  device = 'cuda'

  # Seed RNG
  utils.seed_rng(config['seed'])
  # Prepare root folders if necessary
  utils.prepare_root(config)
  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  # Import the model--this line allows us to dynamically select different files.
  model = import_module('Network.' + config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                     else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = model.Generator(**config).to(device)
  D = model.Discriminator(**config).to(device)

  # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = model.Generator(**{**config, 'skip_init': True,
                               'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None

  # FP16?
  if config['G_fp16']:
    print('Casting G to float16...')
    G = G.half()
    if config['ema']:
      G_ema = G_ema.half()
  if config['D_fp16']:
    print('Casting D to fp16...')
    D = D.half()
    # Consider automatically reducing SN_eps?

  GD = model.G_D(G, D)
  print(G)
  print(D)
  print('Number of params in G: {} D: {}'.format(
      *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D]]))

  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    utils.load_weights(G, D, state_dict, config['weights_root'],
                       experiment_name,
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None)

  # If parallel, parallelize the GD module
  if config['parallel']:
    GD = nn.DataParallel(GD)
    if config['cross_replica']:
      patch_replication_callback(GD)

  # Prepare loggers for stats; metrics holds test metrics,
  # lmetrics holds any desired training metrics.
  test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
  train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
  print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
  test_log = utils.MetricsLogger(test_metrics_fname,
                                 reinitialize=(not config['resume']))
  print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname,
                             reinitialize=(not config['resume']),
                             logstyle=config['logstyle'])
  # Write metadata
  utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)

  # Prepare data; the Discriminator's batch size is all that needs to be passed
  # to the dataloader, as G doesn't require dataloading.
  # Note that at every loader iteration we pass in enough data to complete
  # a full D iteration (regardless of number of D steps and accumulations)
  D_batch_size = (config['batch_size'] * config['num_D_steps']
                  * config['num_D_accumulations'])
  loaders = utils.get_data_loaders(**{**config,
                                      'batch_size': D_batch_size,
                                      'start_itr': state_dict['itr']})

  # Prepare inception metrics: FID and IS
  get_inception_metrics = inception_utils.prepare_inception_metrics(
      config['dataset'], config['parallel'], config['data_root'], config['no_fid'])

  # Prepare noise and randomly sampled label arrays
  # Allow for different batch sizes in G
  G_batch_size = max(config['G_batch_size'], config['batch_size'])
  z_, _ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                            device=device, fp16=config['G_fp16'])
  # Single label! Every generated sample is conditioned on class 1.
  y_ = torch.ones(G_batch_size, dtype=torch.int64, requires_grad=False).to(device)

  # Prepare a fixed z & y to see individual sample evolution throughout training
  fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                                       device=device, fp16=config['G_fp16'])
  fixed_z.sample_()
  fixed_y.sample_()

  # Loaders are loaded, prepare the training function
  if config['which_train_fn'] == 'GAN':
    train = vae_utils.SL_training_function(G, D, GD, z_, y_, ema, state_dict, config)
  # Else, assume debugging and use the dummy train fn
  else:
    train = train_fns.dummy_training_function()

  # Prepare Sample function for use with inception metrics
  sample = functools.partial(
      vae_utils.sample_for_SL,
      G=(G_ema if config['ema'] and config['use_ema'] else G),
      z_=z_, y_=y_, config=config)

  print('Beginning training at epoch %d...' % state_dict['epoch'])
  # Train for specified number of epochs, although we mostly track G iterations.
  for epoch in range(state_dict['epoch'], config['num_epochs']):
    # Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
      pbar = utils.progress(loaders[0],
                            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
      pbar = tqdm(loaders[0])
    for i, (x, y) in enumerate(pbar):
      # Increment the iteration counter
      state_dict['itr'] += 1
      # Make sure G and D are in training mode, just in case they got set to eval
      # For D, which typically doesn't have BN, this shouldn't matter much.
      G.train()
      D.train()
      if config['ema']:
        G_ema.train()
      if config['D_fp16']:
        x, y = x.to(device).half(), y.to(device)
      else:
        x, y = x.to(device), y.to(device)
      # Replace the dataset labels with a single constant label.
      yy = torch.ones_like(y).to(device)
      metrics = train(x, yy)
      train_log.log(itr=int(state_dict['itr']), **metrics)

      # Every sv_log_interval, log singular values
      if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
        train_log.log(itr=int(state_dict['itr']),
                      **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

      # If using my progbar, print metrics.
      if config['pbar'] == 'mine':
        print(', '.join(['itr: %d' % state_dict['itr']]
                        + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
              end=' ')

      # Save weights and copies as configured at specified interval
      if not (state_dict['itr'] % config['save_every']):
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                  state_dict, config, experiment_name)

      # Test every specified interval
      if not (state_dict['itr'] % config['test_every']):
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)

    # Increment epoch counter at end of epoch
    state_dict['epoch'] += 1
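
# Illustrative sketch: each run(config) above expects a fully populated,
# BigGAN-PyTorch-style config dict. A typical entry point builds that dict
# from an argument parser and calls run(); the sketch below assumes a
# utils.prepare_parser() helper like the one in BigGAN-PyTorch, which may
# not match this repo's actual launcher.
def main():
  # Parse command-line arguments into the flat config dict that run() expects.
  parser = utils.prepare_parser()
  config = vars(parser.parse_args())
  print(config)
  run(config)


if __name__ == '__main__':
  main()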