cmd + config["decoding"]["decoding_script_folder"] + "/" + config["decoding"]["decoding_script"] + " " + os.path.abspath(config_dec_file) + " " + out_dec_folder + ' "' + files_dec + '"' ) run_shell(cmd_decode, log_file) # remove ark files if needed if not forward_save_files[k]: list_rem = glob.glob(files_dec) for rem_ark in list_rem: os.remove(rem_ark) # Print WER results and write info file cmd_res = "./check_res_dec.sh " + out_dec_folder wers = run_shell(cmd_res, log_file).decode("utf-8") res_file = open(res_file_path, "a") res_file.write("%s\n" % wers) print(wers) # Saving Loss and Err as .txt and plotting curves if not is_production: create_curves(out_folder, N_ep, valid_data_lst)
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']] * config['cluster_per_class']
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    if config['is_encoder']:
        config['E_fp16'] = float(config['D_fp16'])
        config['num_E_accumulations'] = int(config['num_D_accumulations'])
        config['dataset_channel'] = utils.channel_dict[config['dataset']]
        config['lambda_encoder'] = config['resolution'] ** 2 * config['dataset_channel']

    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model -- this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    if config['is_encoder']:
        E = model.Encoder(**{**config, 'D': D}).to(device)
    Prior = layers.Prior(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
        if not config['prior_type'] == 'default':
            Prior = Prior.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    if config['is_encoder'] and config['E_fp16']:
        print('Casting E to fp16...')
        E = E.half()

    print(G)
    print(D)
    if config['is_encoder']:
        print(E)
    print(Prior)

    if not config['is_encoder']:
        GD = model.G_D(G, D)
        print('Number of params in G: {} D: {}'.format(
            *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D]]))
    else:
        GD = model.G_D(G, D, E, Prior)
        GE = model.G_E(G, E, Prior)
        print('Number of params in G: {} D: {} E: {}'.format(
            *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D, E]]))

    # Prepare state dict, which holds things like epoch # and itr #
    # TODO: also track rec error, discriminator loss and generator loss here?
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'best_error_rec': 99999,
                  'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None,
            E=None if not config['is_encoder'] else E,
            Prior=Prior if not config['prior_type'] == 'default' else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # If parallel, parallelize the GE module
    # if config['parallel'] and config['is_encoder']:
    #     GE = nn.DataParallel(GE)
    #     if config['cross_replica']:
    #         patch_replication_callback(GE)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])

    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                        'start_itr': state_dict['itr']})
    if config['is_encoder']:
        config_aux = config.copy()
        config_aux['augment'] = False
        dataloader_noaug = utils.get_data_loaders(
            **{**config_aux, 'batch_size': D_batch_size,
               'start_itr': state_dict['itr']})

    # Prepare inception metrics: FID and IS
    if config['dataset'] in ['C10']:
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])
    else:
        get_inception_metrics = None

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(
            G, D, GD, Prior, ema, state_dict, config,
            losses.Loss_obj(**config),
            None if not config['is_encoder'] else E)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()

    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        Prior=Prior,
        config=config)

    # Create fixed z and y for sampling
    fixed_z, fixed_y = Prior.sample_noise_and_y()
    fixed_z, fixed_y = fixed_z.clone(), fixed_y.clone()
    iter_num = 0

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):

        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(
                loaders[0],
                displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['is_encoder']:
                E.train()
            if not config['prior_type'] == 'default':
                Prior.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y, iter_num)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
                if config['is_encoder']:
                    train_log.log(itr=int(state_dict['itr']),
                                  **{**utils.get_SVs(E, 'E')})

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']]
                                + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if not config['prior_type'] == 'default':
                        Prior.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(
                    G, D, G_ema, Prior, fixed_z, fixed_y, state_dict, config,
                    experiment_name, None if not config['is_encoder'] else E)

            if not (state_dict['itr'] % config['test_every']) and config['is_encoder']:
                if not config['prior_type'] == 'default':
                    test_acc, test_acc_iter, error_rec = train_fns.test_accuracy(
                        GE, dataloader_noaug, device, config['D_fp16'], config)
                    p_mse, p_lik = train_fns.test_p_acc(GE, device, config)
                if config['n_classes'] == 10:
                    utils.reconstruction_sheet(
                        GE,
                        classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
                        num_classes=config['n_classes'],
                        samples_per_class=20,
                        parallel=config['parallel'],
                        samples_root=config['samples_root'],
                        experiment_name=experiment_name,
                        folder_number=state_dict['itr'],
                        dataloader=dataloader_noaug,
                        device=device,
                        D_fp16=config['D_fp16'],
                        config=config)

            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    if not config['prior_type'] == 'default':
                        Prior.eval()
                    G.eval()
                train_fns.test(
                    G, D, G_ema, Prior, state_dict, config, sample,
                    get_inception_metrics, experiment_name, test_log,
                    None if not config['is_encoder'] else E,
                    None if config['prior_type'] == 'default'
                    else (test_acc, test_acc_iter, error_rec, p_mse, p_lik))

            if not (state_dict['itr'] % config['test_every']):
                utils.create_curves(train_metrics_fname,
                                    plot_sv=False,
                                    prior_type=config['prior_type'],
                                    is_E=config['is_encoder'])
                utils.plot_IS_FID(train_metrics_fname)

        # Increment iteration and epoch counters at the end of each epoch
        iter_num += 1
        state_dict['epoch'] += 1
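# A minimal entry-point sketch for this training script. BigGAN-derived code
# bases usually build the config dict from utils.prepare_parser(); that parser
# name is an assumption here rather than something taken from this file, so
# adapt it if the repository exposes its arguments differently.
def main():
    # Parse command-line arguments into a plain dict and start training.
    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    print(config)
    run(config)


if __name__ == '__main__':
    main()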