import os
import pickle
import multiprocessing
from pathlib import Path

import torch
import torch.distributed as dist
from scipy import stats

# Project-specific helpers (e.g. load_from_checkpoint, transfer_batch_to_device, iw_log_p_x,
# get_dist_validation_loader, seed_everything, make_batch_from_model_samples, cat_pad_uneven,
# dump_pickle, utils_train, vae, do_train_step, do_valid_step) are assumed to be imported
# from the surrounding NewsVAE codebase.


def iw_log_p_x_dataset(data_loader, model=None, path=None, n_samples=600, n_chunks=3,
                       verbose=False, ddp=False, device_name="cuda:0", max_batches=-1):
    # Importance-weighted estimate of log p(x) for all batches in <data_loader>.
    if model is None and path is None:
        print("Provide either a model or a checkpoint path. Aborting.")
        quit()

    if path is not None:
        print("Loading a model because a path was provided: {}".format(path))
        model = load_from_checkpoint(path, world_master=True, ddp=ddp, device_name=device_name,
                                     evaluation=True, return_loss_term_manager=False)

    N = len(data_loader) if max_batches < 0 else max_batches
    print("N", N)

    log_p_xs = []
    sent_lens = []  # handy for perplexity

    for batch_i, batch in enumerate(data_loader):
        if verbose is True:
            print("*" * 40)
            print(f"{batch_i + 1:3d}/{N}")
            print("*" * 40)

        batch = transfer_batch_to_device(batch, device_name=device_name)

        with torch.no_grad():
            log_p_x = iw_log_p_x(model, batch, verbose=verbose, n_chunks=n_chunks, n_samples=n_samples)

        sent_lens.append(batch["attention_mask"].sum(dim=1))
        log_p_xs.append(log_p_x)

        if batch_i + 1 == N:
            break

    log_likelihood = torch.cat(log_p_xs, dim=0).cpu()
    sent_lens = torch.cat(sent_lens, dim=0).cpu()
    log_likelihood_p_w = log_likelihood / sent_lens  # per-word log-likelihood

    return log_likelihood, log_likelihood_p_w, sent_lens
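
# Illustrative usage sketch (not part of the original code): how one might estimate the
# importance-weighted log-likelihood of a validation set from a checkpoint. The checkpoint
# path, dataset name and loader settings below are placeholders; get_dist_validation_loader
# is the loader helper used elsewhere in this file.
def _example_iw_eval_validation(checkpoint_path="path/to/checkpoint.pth", device_name="cuda:0"):
    valid_loader = get_dist_validation_loader(batch_size=32, num_workers=2, max_seq_len=64,
                                              world_size=1, dataset_name="ptb_text_only",
                                              tokenizer_name="roberta", device_name=device_name,
                                              gpu_rank=0)

    log_p_x, log_p_x_per_word, lens = iw_log_p_x_dataset(valid_loader, path=checkpoint_path,
                                                         n_samples=600, n_chunks=3,
                                                         verbose=True, max_batches=10)

    # Simple summaries: mean log-likelihood and a perplexity estimate from the per-word values
    print("Mean log p(x):", log_p_x.mean().item())
    print("Perplexity:", torch.exp(-log_p_x_per_word.mean()).item())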
def evaluation_function(device_rank, run_name, model_path, max_batches, result_dir_path,
                        batch_size, dataset_name, objective, world_size, num_workers):
    # Prepare some variables & result directory
    device_name = f"cuda:{device_rank}"
    latent_size = 32 if "latent32" in model_path else 64

    result_dir = Path(result_dir_path)
    os.makedirs(result_dir, exist_ok=True)

    result_file = result_dir / f"{device_name}_{run_name}_max_batches_{max_batches}.pickle"

    if os.path.isfile(result_file):
        print('_' * 80)
        print("Have done this one already!")
        print('_' * 80)
    else:
        print("-" * 30)
        print("run_name:", run_name)
        print("batch size:", batch_size)
        print("max_batches:", max_batches)
        print("latent size:", latent_size)
        print("device name:", device_name)
        print("-" * 30)

        # Get model
        vae_model = load_from_checkpoint(
            path=model_path,
            device_name=device_name,
            latent_size=latent_size,
            do_tie_embedding_spaces=True,
            add_decoder_output_embedding_bias=False,
            do_tie_weights=True,
            add_latent_via_embeddings=False,
            add_latent_via_memory=True,
            objective=objective,
            evaluation=True)
        vae_model = vae_model.to(device_name)

        # Get distributed validation data loader of the PTB data set
        loader = get_dist_validation_loader(batch_size=batch_size, num_workers=num_workers, max_seq_len=64,
                                            world_size=world_size, dataset_name=dataset_name,
                                            tokenizer_name="roberta", device_name=device_name,
                                            gpu_rank=device_rank)

        dist.init_process_group(backend='nccl', init_method='env://',
                                world_size=world_size, rank=device_rank)

        # Seed everything
        seed_everything(0)

        print(f"Len data loader on device {device_name}: {len(loader)}")

        N = max_batches if max_batches > 0 else len(loader)

        results = {}
        for batch_i, batch in enumerate(loader):
            print(f"{batch_i:3d}/{N} - {device_name}")

            batch = transfer_batch_to_device(batch, device_name=device_name)

            with torch.no_grad():
                out = vae_model(input_ids=batch["input_ids"],
                                attention_mask=batch["attention_mask"],
                                auto_regressive=False,
                                return_latents=False,
                                return_mu_logvar=False,
                                return_exact_match=True,
                                return_cross_entropy=True,
                                return_reconstruction_loss=True,
                                return_posterior_stats=True,
                                reduce_seq_dim_ce="mean",
                                reduce_seq_dim_exact_match="mean",
                                reduce_batch_dim_exact_match="mean",
                                reduce_batch_dim_ce="mean")

            # Accumulate per-batch outputs, converting tensors to Python scalars
            for k, v in out.items():
                value = v.item() if torch.is_tensor(v) else v
                if k not in results:
                    results[k] = [value]
                else:
                    results[k].append(value)

            if batch_i + 1 == max_batches:
                break

        # Dump the results for this device
        pickle.dump(results, open(result_file, "wb"))
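
# Illustrative launcher sketch (not part of the original code): evaluation_function takes the
# device rank as its first argument and calls dist.init_process_group with init_method='env://',
# so it can be spawned per GPU with torch.multiprocessing on a single node. The MASTER_ADDR /
# MASTER_PORT values and all other arguments below are placeholders.
def _example_launch_evaluation():
    import torch.multiprocessing as mp

    world_size = 2  # e.g. two GPUs on one node
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    mp.spawn(evaluation_function,
             args=("my-run", "path/to/checkpoint.pth", 50, "results", 32,
                   "ptb_text_only", "beta-vae", world_size, 2),
             nprocs=world_size,
             join=True)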
def summary_statistics(path, run_name, data_loader, max_batches=-1, device="cuda:0",
                       result_folder="result-files", result_file=None):
    os.makedirs(f"{result_folder}/{run_name}", exist_ok=True)

    if result_file is None:
        result_file = f"{result_folder}/{run_name}/sum_stats_{run_name}.pth"

    # Make a loss term manager from checkpoint (includes the model)
    loss_term_manager = load_from_checkpoint(path, world_master=True, ddp=False,
                                             dataset_size=len(data_loader), device_name=device,
                                             evaluation=True, return_loss_term_manager=True)

    # Set to the standard VAE objective
    loss_term_manager.objective = "vae"

    results = {}
    N = max_batches if max_batches > 0 else len(data_loader)

    decoder_only = "decoderOnly" in path

    for batch_i, batch in enumerate(data_loader):
        print("Batch {:3d}/{:3d}".format(batch_i + 1, N), end="\r")

        with torch.no_grad():
            batch = transfer_batch_to_device(batch, device)

            out = loss_term_manager(input_ids=batch["input_ids"],
                                    attention_mask=batch["attention_mask"],
                                    return_exact_match=True,
                                    return_reconstruction_loss=True,
                                    decoder_only=decoder_only,
                                    return_posterior_stats=True,
                                    device_name=device,
                                    return_cross_entropy=False,
                                    reduce_seq_dim_ce="mean",
                                    reduce_batch_dim_ce="mean",
                                    reduce_seq_dim_exact_match="mean",
                                    reduce_batch_dim_exact_match="mean",
                                    train=False)

            # Store scalars as Python floats, keep other outputs as-is
            for k, v in out.items():
                x = v.item() if torch.is_tensor(v) and v.dim() == 0 else v
                if k in results:
                    results[k].append(x)
                else:
                    results[k] = [x]

        if batch_i + 1 == max_batches:
            break

    # Concatenate tensor-valued statistics, keep scalar lists as-is
    results_cat = {}
    for k, v in results.items():
        if torch.is_tensor(v[0]):
            results_cat[k] = torch.cat(v, dim=0)
        else:
            results_cat[k] = v

    dump_pickle(results_cat, result_file)
def iw_log_p_x_generated(model=None, path=None, n_batches=10, batch_size=64, n_samples=600,
                         n_chunks=3, verbose=False, ddp=False, device_name="cuda:0",
                         max_seq_len_gen=64):
    # Importance-weighted estimate of log p(x) for samples generated by the model itself.
    if model is None and path is None:
        print("Provide either a model or a checkpoint path. Aborting.")
        quit()

    if path is not None:
        print("Loading a model because a path was provided: {}".format(path))
        model = load_from_checkpoint(path, world_master=True, ddp=ddp, device_name=device_name,
                                     evaluation=True, return_loss_term_manager=False)

    log_p_xs, log_p_x_ws, all_lens = [], [], []

    for batch_i in range(n_batches):
        if verbose:
            print(f"Batch {batch_i}/{n_batches}")

        with torch.no_grad():
            # Sample from the model by decoding from the prior auto-regressively (with sampling)
            out = model(return_reconstruction_loss=False,
                        return_posterior_stats=False,
                        auto_regressive=True,
                        max_seq_len=max_seq_len_gen,
                        return_predictions=True,
                        nucleus_sampling=True,
                        top_k=0,    # no filtering
                        top_p=1.0,  # no filtering
                        decode_sample_from_prior=True,
                        n_prior_samples=batch_size,
                        device_name=device_name)

            padded_predictions, mask, lens = make_batch_from_model_samples(out["predictions"])

            batch = dict(input_ids=padded_predictions.to(device_name), attention_mask=mask.to(device_name))

            log_p_x = iw_log_p_x(model, batch, n_samples=n_samples, n_chunks=n_chunks, verbose=True).cpu()
            log_p_x_w = log_p_x / lens

            log_p_xs.append(log_p_x)
            log_p_x_ws.append(log_p_x_w)
            all_lens.append(lens)  # collect lengths so they stay aligned with the concatenated log-likelihoods

    log_p_xs = torch.cat(log_p_xs)
    log_p_x_ws = torch.cat(log_p_x_ws)
    all_lens = torch.cat(all_lens)

    return log_p_xs, log_p_x_ws, all_lens
def acc_drop_over_relative_seq_len(data_loader, model=None, path=None, device="cuda:0",
                                   max_batches=-1, N_bins=30):
    N = max_batches if max_batches > 0 else len(data_loader)

    assert not (model is None and path is None), "Either supply a model or a path. Aborting."

    if path is not None and model is None:
        model = load_from_checkpoint(path, world_master=True, ddp=False, device_name=device, evaluation=True)

    prior_accs = []
    post_accs = []
    masks = []

    for batch_i, batch in enumerate(data_loader):
        print("Batch {:3d}/{:3d}".format(batch_i + 1, N), end="\r")

        # Save the label mask
        labels = batch["input_ids"][:, 1:].contiguous()  # skip <s> token
        label_mask = (labels != 1).float()  # pad token is int 1
        masks.append(label_mask)

        # Transfer batch to device
        batch = transfer_batch_to_device(batch, device)

        # Save per-position accuracy, decoding once from the prior and once from the posterior
        for decode_prior_samples in [True, False]:
            with torch.no_grad():
                preds = model(input_ids=batch["input_ids"],
                              attention_mask=batch["attention_mask"],
                              auto_regressive=False,
                              max_seq_len=64,
                              return_exact_match=True,
                              return_cross_entropy=False,
                              return_reconstruction_loss=False,
                              return_posterior_stats=False,
                              reduce_seq_dim_ce="mean",
                              reduce_seq_dim_exact_match="none",
                              reduce_batch_dim_exact_match="none",
                              reduce_batch_dim_ce="none",
                              nucleus_sampling=False,
                              top_k=0,
                              top_p=1.0,
                              decode_sample_from_prior=decode_prior_samples,
                              n_prior_samples=batch["input_ids"].shape[0],
                              device_name=device)

                if decode_prior_samples is True:
                    prior_accs.append(preds["exact_match"].cpu())
                else:
                    post_accs.append(preds["exact_match"].cpu())

        if (batch_i + 1) == max_batches:
            break

    prior_accs = cat_pad_uneven(prior_accs, pad_value=0)
    post_accs = cat_pad_uneven(post_accs, pad_value=0)
    masks = cat_pad_uneven(masks, pad_value=0)

    seq_lens = masks.sum(dim=1)
    n_samples, max_len = prior_accs.shape

    positions = torch.arange(1, max_len + 1).unsqueeze(0).repeat(n_samples, 1)
    relative_positions = positions / seq_lens.unsqueeze(1)

    prior_accs_masked = torch.masked_select(prior_accs, masks == 1.0)
    post_accs_masked = torch.masked_select(post_accs, masks == 1.0)

    acc_drops = post_accs_masked - prior_accs_masked
    relative_positions_masked = torch.masked_select(relative_positions, masks == 1.0)

    # Bin the accuracy drops by relative position in the sequence
    bin_means, bin_edges, bin_ids = stats.binned_statistic(relative_positions_masked.tolist(),
                                                           acc_drops.tolist(),
                                                           statistic='mean', bins=N_bins)

    return_dict = dict(bin_means=bin_means,
                       bin_edges=bin_edges,
                       acc_drops=acc_drops,
                       prior_accs=prior_accs_masked,
                       posterior_accs=post_accs_masked)

    return return_dict
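
# Illustrative plotting sketch (not part of the original code): visualising the binned accuracy
# drop returned by acc_drop_over_relative_seq_len. Assumes matplotlib is available; the data
# loader and checkpoint path are placeholders.
def _example_plot_acc_drop(data_loader, checkpoint_path="path/to/checkpoint.pth"):
    import matplotlib.pyplot as plt

    res = acc_drop_over_relative_seq_len(data_loader, path=checkpoint_path, max_batches=20, N_bins=30)

    # Centre of each relative-position bin on the x-axis
    bin_centers = 0.5 * (res["bin_edges"][1:] + res["bin_edges"][:-1])

    plt.plot(bin_centers, res["bin_means"])
    plt.xlabel("Relative position in sequence")
    plt.ylabel("Accuracy drop (posterior - prior)")
    plt.show()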
def dist_iw_log_likelihood_x_obs_x_gen(device_rank, run_name, model_path, max_batches, result_dir_path,
                                       batch_size, dataset_name, world_size, num_workers, n_samples,
                                       n_chunks, max_seq_len_gen):
    # Prepare some variables & result directory
    device_name = f"cuda:{device_rank}"

    result_dir = Path(result_dir_path) / run_name
    os.makedirs(result_dir, exist_ok=True)

    # Single-GPU result file
    result_file = result_dir / f"{device_name}_{run_name}_max_batches_{max_batches}.pickle"
    # Combined result file for all GPUs
    full_result_file = result_dir / f"{run_name}_world_size_{world_size}_max_batches_{max_batches}_" \
                                    f"batch_size_{batch_size}_n_samples_{n_samples}.pickle"

    if os.path.isfile(result_file) or os.path.isfile(full_result_file):
        print('_' * 80)
        print("Have done this one already!")
        print('_' * 80)
    else:
        print("-" * 30)
        print("run_name:", run_name)
        print("batch size:", batch_size)
        print("max_batches:", max_batches)
        print("device name:", device_name)
        print("-" * 30)

        # Get model
        vae_model = load_from_checkpoint(model_path, world_master=True, ddp=False, device_name=device_name,
                                         evaluation=True, return_loss_term_manager=False)

        # Get distributed train and validation data loaders of the PTB data set
        valid_loader = get_dist_validation_loader(batch_size=batch_size, num_workers=num_workers, max_seq_len=64,
                                                  world_size=world_size, dataset_name=dataset_name,
                                                  tokenizer_name="roberta", device_name=device_name,
                                                  gpu_rank=device_rank, train_validation="validation")
        train_loader = get_dist_validation_loader(batch_size=batch_size, num_workers=num_workers, max_seq_len=64,
                                                  world_size=world_size, dataset_name=dataset_name,
                                                  tokenizer_name="roberta", device_name=device_name,
                                                  gpu_rank=device_rank, train_validation="train")

        dist.init_process_group(backend='nccl', init_method='env://',
                                world_size=world_size, rank=device_rank)

        # Seed everything
        seed_everything(0)

        N_valid = max_batches if max_batches > 0 else len(valid_loader)
        N_train = max_batches if max_batches > 0 else len(train_loader)
        N = max([N_valid, N_train])
        print(f"N_valid {N_valid} N_train {N_train} N {N}")

        with torch.no_grad():
            # Importance-weighted log-likelihood of observed validation data
            log_p_x_obs_valid, log_p_x_w_obs_valid, lens_obs_valid = iw_log_p_x_dataset(
                valid_loader, model=vae_model, path=None, n_samples=n_samples, n_chunks=n_chunks,
                verbose=True, ddp=False, device_name=device_name, max_batches=N_valid)

            # Importance-weighted log-likelihood of observed train data
            log_p_x_obs_train, log_p_x_w_obs_train, lens_obs_train = iw_log_p_x_dataset(
                train_loader, model=vae_model, path=None, n_samples=n_samples, n_chunks=n_chunks,
                verbose=True, ddp=False, device_name=device_name, max_batches=N_train)

            # Importance-weighted log-likelihood of data generated by the model itself
            log_p_x_gen, log_p_x_w_gen, lens_gen = iw_log_p_x_generated(
                model=vae_model, path=None, n_batches=N, batch_size=batch_size, n_samples=n_samples,
                n_chunks=n_chunks, verbose=True, ddp=False, device_name=device_name,
                max_seq_len_gen=max_seq_len_gen)

        results = dict(log_p_x_obs_valid=log_p_x_obs_valid.cpu(),
                       log_p_x_w_obs_valid=log_p_x_w_obs_valid.cpu(),
                       lens_obs_valid=lens_obs_valid.cpu(),
                       log_p_x_obs_train=log_p_x_obs_train.cpu(),
                       log_p_x_w_obs_train=log_p_x_w_obs_train.cpu(),
                       lens_obs_train=lens_obs_train.cpu(),
                       log_p_x_gen=log_p_x_gen.cpu(),
                       log_p_x_w_gen=log_p_x_w_gen.cpu(),
                       lens_gen=lens_gen.cpu())

        # Dump the results for this device
        pickle.dump(results, open(result_file, "wb"))
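
# Illustrative post-processing sketch (not part of the original code): after all ranks have
# finished, the per-device pickles written above could be merged into the combined result file.
# The file naming mirrors dist_iw_log_likelihood_x_obs_x_gen; everything else is a placeholder.
def _example_merge_device_results(result_dir_path, run_name, world_size, max_batches, batch_size, n_samples):
    result_dir = Path(result_dir_path) / run_name

    merged = {}
    for rank in range(world_size):
        device_file = result_dir / f"cuda:{rank}_{run_name}_max_batches_{max_batches}.pickle"
        device_results = pickle.load(open(device_file, "rb"))
        for k, v in device_results.items():
            merged.setdefault(k, []).append(v)

    # Concatenate the per-device tensors along the batch dimension
    merged = {k: torch.cat(v, dim=0) for k, v in merged.items()}

    full_result_file = result_dir / f"{run_name}_world_size_{world_size}_max_batches_{max_batches}_" \
                                    f"batch_size_{batch_size}_n_samples_{n_samples}.pickle"
    pickle.dump(merged, open(full_result_file, "wb"))

    return merged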
# Calculate MI bounds for these models
mutual_information_results = {}

for name, path in run_names_paths_to_evaluate:
    vae_model = get_model_on_device(device_name=DEVICE_NAME,
                                    latent_size=768,
                                    gradient_checkpointing=False,
                                    add_latent_via_memory=True,
                                    add_latent_via_embeddings=True,
                                    do_tie_weights=True,
                                    world_master=True)

    _, _, vae_model, _, _, _, _ = load_from_checkpoint(vae_model, path, world_master=True,
                                                       ddp=False, use_amp=False)

    mi_results = calc_all_mi_bounds(vae_model, VALID_LOADER, device_name=DEVICE_NAME,
                                    max_batches=MAX_BATCHES, batch_size=BATCH_SIZE)

    mutual_information_results[name] = mi_results

prefix = "/home/cbarkhof/code-thesis/NewsVAE/evaluation/29DEC/"
pickle_filename = "29DEC-mutual-information-results.p"
pickle_path = prefix + pickle_filename

# Dump the collected MI results (assumed intent of pickle_path above)
pickle.dump(mutual_information_results, open(pickle_path, "wb"))
def train(device_rank, config, run_name):
    print("**** DEVICE: ", device_rank)

    # Device
    device_name = utils_train.set_device(device_rank)

    # Determine world size and whether this device is world master
    world_master, world_size = utils_train.get_world_specs(config.n_gpus, config.n_nodes, device_name)

    # Determine the maximum number of steps for this device
    global_max_steps, global_max_grad_steps = utils_train.determine_global_max_steps(
        config.max_global_train_steps, config.batch_size, world_size, config.accumulate_n_batches_grad)

    # Initiate process group and specify backend configurations
    if config.ddp:
        if world_master:
            print("Init process group...")
            print(f"--> CPU count {multiprocessing.cpu_count()}")
        dist.init_process_group(backend='nccl', init_method='env://',
                                world_size=int(config.n_gpus * config.n_nodes), rank=device_rank)

    # Seed everything
    seed_everything(config.seed)

    # Data loaders / data set / samplers (if ddp)
    data_loaders, data, samplers = utils_train.get_dataloader(["train", "validation"], ddp=config.ddp,
                                                              batch_size=config.batch_size,
                                                              num_workers=config.num_workers,
                                                              max_seq_len=config.max_seq_len,
                                                              world_size=world_size,
                                                              dataset_name=config.dataset_name,
                                                              tokenizer_name=config.tokenizer_name,
                                                              device_name=device_name,
                                                              world_master=world_master,
                                                              gpu_rank=device_rank)

    # These are actual steps, not gradient steps, so they work in combination with global step
    max_train_steps_epoch_per_rank, max_valid_steps_epoch_per_rank = utils_train.determine_max_epoch_steps_per_rank(
        config.max_train_steps_epoch_per_rank, config.max_valid_steps_epoch_per_rank, data.datasets,
        config.batch_size, world_size=world_size, world_master=world_master)
    max_epochs = config.max_epochs if config.max_epochs > 0 else 100

    config.max_train_steps_epoch_per_rank = max_train_steps_epoch_per_rank  # overwrite this
    config.max_valid_steps_epoch_per_rank = max_valid_steps_epoch_per_rank  # overwrite this

    print("*" * 80)
    print("config.max_train_steps_epoch_per_rank", config.max_train_steps_epoch_per_rank)
    print("*" * 80)

    # Get model and loss term manager
    dataset_size = data.datasets['train'].shape[0]
    if config.load_from_checkpoint:
        assert os.path.isfile(config.checkpoint_file), \
            f"checkpoint file does not exist: {config.checkpoint_file}"
        loss_term_manager = utils_train.load_from_checkpoint(config.checkpoint_file, world_master=world_master,
                                                             ddp=config.ddp, device_name=device_name,
                                                             evaluation=False,
                                                             return_loss_term_manager=True,
                                                             loss_term_manager_config=config)
    else:
        loss_term_manager = vae.get_loss_term_manager_with_model(config, world_master=world_master,
                                                                 dataset_size=dataset_size,
                                                                 device_name=device_name)

    autoencoder = False
    if config.objective == "beta-vae" and config.b_vae_beta_constant_linear_lagrangian == "constant" \
            and config.b_vae_beta == 0.0:
        print("** AUTOENCODER OBJECTIVE!")
        autoencoder = True

    # Initialise logging
    if config.logging and world_master:
        utils_train.init_logging(loss_term_manager.vae_model, run_name, config.code_dir_path,
                                 config.wandb_project, config, config.run_dir_name)

    # Set up DDP
    if config.ddp:
        # Wrap both the model and constraints etc. in a loss_term_manager nn.Module, as suggested here:
        # https://discuss.pytorch.org/t/multiple-modules-with-distributed-data-parallel/115621
        loss_term_manager = torch.nn.parallel.DistributedDataParallel(loss_term_manager,
                                                                      device_ids=[device_rank],
                                                                      find_unused_parameters=False)
        print(f"-> Turned on DDP for device rank {device_rank}")

    # Zero grads TODO: fix this
    # loss_term_manager.zero_grad()

    # Initialise the stats to keep track of
    stats = utils_train.make_nested_dict()
    finished_training = False

    epoch, global_step, global_grad_step, not_improved_epochs = 0, 0, 0, 0

    # NB: D_ks is not used for Pareto checkpointing anymore.
    epoch_pareto_efficiency_dict = utils_train.prepare_pareto_dict(config=config)
    current_efficient_epochs = []

    if world_master:
        print("Start or resume training!")

    # ----------------------------------------------------------------------------------------------------
    # TRAINING!
    # ----------------------------------------------------------------------------------------------------
    while not finished_training:
        print("finished_training", finished_training)

        # TRAIN, VALID
        for phase in data_loaders.keys():

            if finished_training:
                break

            if config.ddp:
                print(f"-> Setting epoch explicitly to {epoch} on device {device_name}")
                samplers[phase].set_epoch(epoch)  # needed to explicitly shuffle

            max_steps = max_train_steps_epoch_per_rank if phase == 'train' else max_valid_steps_epoch_per_rank

            atts_to_latent, masks = [], []
            # latents = []

            for batch_i, batch in enumerate(data_loaders[phase]):

                # ------------------------------------------------------------------------------------------
                # TRAIN / VALIDATION STEPS
                # ------------------------------------------------------------------------------------------

                # SET DEVICE
                batch = utils_train.transfer_batch_to_device(batch, device_name)

                # PERFORM TRAIN / VALIDATION STEP
                if phase == 'train':
                    loss_term_manager, losses = do_train_step(
                        loss_term_manager, batch, global_step, use_amp=config.use_amp,
                        accumulate_n_batches_grad=config.accumulate_n_batches_grad,
                        device_name=device_name, gradient_clipping=config.gradient_clipping,
                        decoder_only=config.decoder_only, ddp=config.ddp)
                else:
                    # save_latents now happens outside the train loop
                    losses = do_valid_step(loss_term_manager, batch,
                                           device_name=device_name, ddp=config.ddp,
                                           decoder_only=config.decoder_only,
                                           iw_ll_n_samples=config.iw_ll_n_samples,
                                           eval_iw_ll_x_gen=config.eval_iw_ll_x_gen,
                                           max_seq_len_x_gen=config.max_seq_len_x_gen,
                                           save_latents=False)

                    # if "latent_z" in losses:
                    #     latents.append(losses["latent_z"])
                    #     del losses["latent_z"]

                    if "attention_to_latent" in losses:
                        atts_to_latent.append(losses["attention_to_latent"].cpu())
                        masks.append(batch["attention_mask"][:, 1:].cpu())
                        del losses["attention_to_latent"]

                # ------------------------------------------------------------------------------------------
                # INSERT STATISTICS, PRINT, LOG, CHECKPOINT
                # ------------------------------------------------------------------------------------------

                # INSERT STATISTICS
                stats = utils_train.insert_stats(stats, losses, epoch, phase)

                # PRINT
                if world_master and global_step % config.print_every_n_steps == 0 and config.print_stats:
                    utils_train.print_stats(stats, epoch, phase, global_step, global_max_steps,
                                            global_grad_step, global_max_grad_steps,
                                            batch_i, max_steps, config.objective)

                # LOG STEP (only if world master)
                if batch_i % config.log_every_n_steps == 0 and config.logging and world_master and phase == 'train':
                    if config.add_latent_w_matrix_influence:
                        utils_train.add_matrix_influence_weight_to_loss(loss_term_manager, global_step,
                                                                        global_grad_step, ddp=config.ddp)
                    utils_train.log_losses_step(losses, phase, epoch, config.log_every_n_steps,
                                                global_step, global_grad_step)

                # Analyse and save latents for runs with save_latents == True
                if global_step % config.save_latents_every_x_steps == 0 and config.save_latents:
                    utils_train.analyse_save_latents(data_loaders["validation"], loss_term_manager.vae_model,
                                                     stats, config.code_dir_path, config.run_dir_name,
                                                     run_name, global_step, epoch, device_name=device_name)

                # ------------------------------------------------------------------------------------------
                # KEEP TRACK OF STEPS (IN PHASE AND GLOBALLY)
                # ------------------------------------------------------------------------------------------

                # ADVANCE STEP if in train mode
                if phase == "train":
                    global_step += 1
                    if global_step % config.accumulate_n_batches_grad == 0:
                        global_grad_step += 1

                # CHECK IF EPOCH PHASE IS OVER (after advancing one)
                if batch_i >= max_steps:
                    break

                if global_step >= global_max_steps or epoch >= max_epochs:
                    finished_training = True
                    break

            # ----------------------------------------------------------------------------------------------
            # END OF TRAIN / VALID PHASE
            # ----------------------------------------------------------------------------------------------

            # BEST MODEL CHECKPOINT
            if phase == 'validation' and world_master:
                val_epoch_stats = stats[epoch]["validation"]

                # if len(latents) > 0:
                #     utils_train.save_latents(latents, global_step, epoch, run_name,
                #                              config.code_dir_path, config.run_dir_name)

                # Update the epoch_pareto_efficiency_dict and determine efficient_epochs
                epoch_pareto_efficiency_dict, efficient_epochs = utils_train.determine_pareto_checkpoint(
                    val_epoch_stats, epoch_pareto_efficiency_dict, epoch, logging=config.logging,
                    decoder_only=config.decoder_only or autoencoder)  # if AE, also evaluate based on -D

                # Check if anything changed; if not, keep count of epochs without improvement
                if efficient_epochs == current_efficient_epochs:
                    not_improved_epochs += 1
                else:
                    not_improved_epochs = 0

                current_efficient_epochs = efficient_epochs

                # Early stopping
                if (not_improved_epochs >= config.early_stop_epochs) and config.early_stopping:
                    print("*" * 50)
                    print("EARLY STOPPING!")
                    print("*" * 50)
                    finished_training = True

                # Checkpoint according to efficient_epochs, save the data
                if config.checkpoint:
                    vae_model = loss_term_manager.vae_model if config.ddp is False else loss_term_manager.module.vae_model
                    utils_train.save_checkpoint_model(vae_model, run_name, config.code_dir_path, global_step,
                                                      epoch, config, efficient_epochs,
                                                      epoch_pareto_efficiency_dict, config.run_dir_name)

        # --------------------------------------------------------------------------------------------------
        # END OF EPOCH
        # --------------------------------------------------------------------------------------------------

        # LOG EPOCH STATS (if world master)
        if config.logging and world_master:
            print("LOG EPOCH STATS")
            utils_train.log_stats_epoch(stats, epoch, global_step, global_grad_step, atts_to_latent, masks)

        epoch += 1

    # Dump train stats and Pareto stats
    path = config.code_dir_path + "/" + run_name
    pickle.dump(stats, open(path + "/stats.pickle", "wb"))
    pickle.dump(epoch_pareto_efficiency_dict, open(path + "/pareto_dict.pickle", "wb"))
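
# Illustrative launcher sketch (not part of the original code): like evaluation_function above,
# train takes the device rank as its first argument, so it can be spawned per GPU. The config
# object is assumed to be the project's training configuration (how it is built is not shown
# here); MASTER_ADDR / MASTER_PORT values are placeholders.
def _example_launch_training(config, run_name="my-training-run"):
    import torch.multiprocessing as mp

    if config.ddp:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        mp.spawn(train, args=(config, run_name), nprocs=int(config.n_gpus * config.n_nodes), join=True)
    else:
        # Single-device run
        train(0, config, run_name)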