def load_checkpoint(args, model):
    folder_path = args.checkpoint_path
    if not os.path.exists(folder_path):
        print_log(f"Checkpoint path [{folder_path}] does not exist.")
        return 0

    print_log(f"Checkpoint path [{folder_path}] does exist.")
    files = [f for f in os.listdir(folder_path) if ".pt" in f]
    if len(files) == 0:
        print_log("No .pt files found in checkpoint path.")
        return 0

    latest_model = sorted(files)[-1]
    file_path = "{}/{}".format(folder_path.rstrip("/"), latest_model)
    if not os.path.exists(file_path):
        print_log(f"File [{file_path}] not found.")
        return 0

    model.load_state_dict(
        torch.load(file_path, map_location=lambda storage, loc: storage))
    if args.cuda:
        model.cuda(torch.cuda.current_device())
    print_log("Loaded model from {}".format(file_path))

    # Checkpoints are named "model_{epoch:03d}.pt" (see save_checkpoint),
    # so lexicographic sorting picks the latest one; resume at the next epoch.
    return int(latest_model.replace("model_", "").replace(".pt", "")) + 1
def print_results(args, items, epoch_number, iteration, data_len, training=True):
    msg = "[{}] Epoch {}/{} | Iter {}/{} | ".format(
        "T" if training else "V", epoch_number, args.train_epochs, iteration,
        data_len)
    msg += "".join("{} {:.4E} | ".format(k, v) for k, v in items)
    print_log(msg)
def save_checkpoint(args, model, optimizer, lr_scheduler, epoch):
    # Create folder if not already created
    folder_path = args.checkpoint_path
    folders = folder_path.split("/")
    for i in range(len(folders)):
        if folders[i] == "":
            continue
        intermediate_path = "/".join(folders[:i + 1])
        if not os.path.exists(intermediate_path):
            os.mkdir(intermediate_path)

    final_path = "{}/model_{:03d}.pt".format(folder_path.rstrip("/"), epoch)
    if os.path.exists(final_path):
        os.remove(final_path)
    torch.save(model.state_dict(), final_path)
    print_log("Saved model at {}".format(final_path))
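# Note: the incremental directory-creation loop in save_checkpoint is equivalent
# to the standard-library one-liner below (shown for reference only):
#
#     os.makedirs(args.checkpoint_path, exist_ok=True)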
def main(args):
    print_log("Setting seed.")
    set_random_seed(args)

    print_log("Setting up dataloaders.")
    train_dataloader, valid_dataloader, test_dataloader = get_data(args)

    print_log("Setting up model, optimizer, and learning rate scheduler.")
    model, optimizer, lr_scheduler = setup_model_and_optim(
        args, len(train_dataloader))
    report_model_stats(model)

    if args.finetune:
        epoch = load_checkpoint(args, model)
    else:
        epoch = 0
    original_epoch = epoch

    print_log("Starting training.")
    results = {"valid": [], "train": [], "test": []}
    last_valid_ll = -float('inf')
    epsilon = 0.03
    while epoch < args.train_epochs or args.early_stop:
        results["train"].append(
            train_epoch(args, model, optimizer, lr_scheduler,
                        train_dataloader, epoch + 1))

        if args.do_valid and ((epoch + 1) % args.valid_epochs == 0):
            new_valid = eval_epoch(args, model, valid_dataloader,
                                   train_dataloader, epoch + 1)
            results["valid"].append(new_valid)
            if args.early_stop:
                # Stop once the validation log-likelihood improves by less than epsilon.
                if new_valid["log_likelihood"] - last_valid_ll < epsilon:
                    break
                last_valid_ll = new_valid["log_likelihood"]

        if ((epoch + 1) % args.save_epochs == 0):
            save_checkpoint(args, model, optimizer, lr_scheduler, epoch)

        epoch += 1

    if args.save_epochs > 0 and original_epoch != epoch:
        save_checkpoint(args, model, optimizer, lr_scheduler, epoch)

    if args.do_valid:
        overall_test_results = {}
        reps = 5
        for _ in range(reps):
            test_results = eval_epoch(args, model, test_dataloader,
                                      train_dataloader, epoch + 1,
                                      num_samples=500)
            # for k, v in test_results.items():
            #     if k not in overall_test_results:
            #         overall_test_results[k] = v / reps
            #     else:
            #         overall_test_results[k] += v / reps
            results["test"].append(test_results)  # overall_test_results

    del model
    del optimizer
    del lr_scheduler
    del train_dataloader
    del valid_dataloader
    del test_dataloader
    torch.cuda.empty_cache()

    return results
def report_model_stats(model):
    encoder_parameter_count = 0
    aggregator_parameter_count = 0
    decoder_parameter_count = 0
    total = 0
    for name, param in model.named_parameters():
        if name.startswith("encoder"):
            encoder_parameter_count += param.numel()
        elif name.startswith("aggregator"):
            aggregator_parameter_count += param.numel()
        else:
            decoder_parameter_count += param.numel()
        total += param.numel()

    print_log()
    print_log("<Parameter Counts>")
    print_log("Encoder........{}".format(encoder_parameter_count))
    print_log("Aggregator.....{}".format(aggregator_parameter_count))
    print_log("Decoder........{}".format(decoder_parameter_count))
    print_log("---Total.......{}".format(total))
    print_log()
def get_data(args):
    train_dataset = PointPatternDataset(
        file_path=args.train_data_path,
        args=args,
        keep_pct=args.train_data_percentage,
        set_dominating_rate=args.sample_generations,
        is_test=False,
    )
    args.num_channels = train_dataset.vocab_size

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=args.shuffle,
        num_workers=args.num_workers,
        collate_fn=lambda x: pad_and_combine_instances(
            x, train_dataset.max_period),
        drop_last=True,
    )
    args.max_period = train_dataset.get_max_T() / 2.0
    print_log("Loaded {} / {} training examples / batches from {}".format(
        len(train_dataset), len(train_dataloader), args.train_data_path))

    if args.do_valid:
        valid_dataset = PointPatternDataset(
            file_path=args.valid_data_path,
            args=args,
            keep_pct=args.valid_to_test_pct,
            set_dominating_rate=False,
            is_test=False,
        )
        valid_dataloader = DataLoader(
            dataset=valid_dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            num_workers=args.num_workers,
            collate_fn=lambda x: pad_and_combine_instances(
                x, valid_dataset.max_period),
            drop_last=True,
        )
        print_log(
            "Loaded {} / {} validation examples / batches from {}".format(
                len(valid_dataset), len(valid_dataloader),
                args.valid_data_path))

        test_dataset = PointPatternDataset(
            file_path=args.valid_data_path,
            args=args,
            # With is_test=True, the dataset object assigns the remaining
            # (1 - valid_to_test_pct) fraction of the file to the test split.
            keep_pct=args.valid_to_test_pct,
            set_dominating_rate=False,
            is_test=True,
        )
        test_dataloader = DataLoader(
            dataset=test_dataset,
            batch_size=args.batch_size // 4,
            shuffle=args.shuffle,
            num_workers=args.num_workers,
            collate_fn=lambda x: pad_and_combine_instances(
                x, test_dataset.max_period),
            drop_last=True,
            pin_memory=args.pin_test_memory,
        )
        print_log("Loaded {} / {} test examples / batches from {}".format(
            len(test_dataset), len(test_dataloader), args.valid_data_path))
    else:
        valid_dataloader = None
        test_dataloader = None

    return train_dataloader, valid_dataloader, test_dataloader
def eval_epoch(args, model, eval_dataloader, train_dataloader, epoch_number,
               num_samples=150):
    model.eval()

    with torch.no_grad():
        total_losses = defaultdict(lambda: 0.0)
        data_len = len(eval_dataloader)
        valid_latents, valid_labels = [], []
        for i, batch in enumerate(eval_dataloader):
            batch_loss, results = eval_step(args, model, batch, num_samples)
            if args.classify_latents:
                valid_latents.append(
                    results["latent_state_dict"]["latent_state"])
                valid_labels.append(batch["pp_id"])
            for k, v in batch_loss.items():
                total_losses[k] += v.item()

        print_results(args,
                      [(k, v / data_len) for k, v in total_losses.items()],
                      epoch_number, i + 1, data_len, False)

    if args.classify_latents:
        with torch.no_grad():
            train_latents, train_labels = [], []
            for batch in train_dataloader:
                _, results = eval_step(args, model, batch)
                train_latents.append(
                    results["latent_state_dict"]["latent_state"])
                train_labels.append(batch["pp_id"])

        train_latents = torch.cat(train_latents, dim=0).squeeze().numpy()
        train_labels = torch.cat(train_labels, dim=0).squeeze().numpy()
        valid_latents = torch.cat(valid_latents, dim=0).squeeze().numpy()
        valid_labels = torch.cat(valid_labels, dim=0).squeeze().numpy()

        clf = LogisticRegression(
            random_state=args.seed,
            solver="liblinear",
            multi_class="auto",
        ).fit(train_latents, train_labels)
        train_acc, valid_acc = clf.score(train_latents, train_labels), \
            clf.score(valid_latents, valid_labels)

        # Naive baseline: always predict the most frequent training label.
        t_vals, t_counts = np.unique(train_labels, return_counts=True)
        t_most_freq_val, t_most_freq_count = t_vals[
            t_counts.argmax()], t_counts.max()
        naive_train_acc = t_most_freq_count / len(train_labels)
        v_vals, v_counts = np.unique(valid_labels, return_counts=True)
        v_most_freq_count = v_counts[np.where(v_vals == t_most_freq_val)[0][0]]
        naive_valid_acc = v_most_freq_count / len(valid_labels)

        print_log(
            "[C] Epoch {}/{} | Train Acc {:.4E} | Valid Acc {:.4E} | (N) Train Acc {:.4E} | (N) Valid Acc {:.4E}"
            .format(
                epoch_number,
                args.train_epochs,
                train_acc,
                valid_acc,
                naive_train_acc,
                naive_valid_acc,
            ))

    return {k: v / data_len for k, v in total_losses.items()}
if __name__ == "__main__":
    print_log("Getting arguments.")
    args = get_args()
    main(args)
def train_step(args, model, optimizer, lr_scheduler, batch):
    loss_results, forward_results = forward_pass(args, batch, model)

    if backward_pass(args, loss_results["loss"], model, optimizer):
        optimizer.step()
        lr_scheduler.step()
    else:
        # Dump diagnostics when the loss goes NaN, then pause for inspection.
        print_log('======= NAN-Loss =======')
        print_log(
            "Loss Results:", {
                k: (torch.isnan(v).any().item(), v.min().item(),
                    v.max().item())
                for k, v in loss_results.items()
                if isinstance(v, torch.Tensor)
            })
        print_log("Loss Results:", loss_results)
        print_log("")
        print_log(
            "Batch:", {
                k: (torch.isnan(v).any().item(), v.min().item(),
                    v.max().item())
                for k, v in batch.items() if isinstance(v, torch.Tensor)
            })
        print_log("Batch:", batch)
        print_log("")
        print_log(
            "Results:", {
                k: (torch.isnan(v).any().item(), v.min().item(),
                    v.max().item())
                for k, v in forward_results["state_dict"].items()
            })
        print_log(
            "Results:", {
                k: (torch.isnan(v).any().item(), v.min().item(),
                    v.max().item())
                for k, v in forward_results["tgt_intensities"].items()
            })
        print_log(
            "Results:", {
                k: (torch.isnan(v).any().item(), v.min().item(),
                    v.max().item())
                for k, v in forward_results["sample_intensities"].items()
            })
        print_log("Results:", forward_results)
        print_log("========================")
        input()

    return loss_results
def anomaly_detection(args, model):
    model.eval()
    num_samples = 10000
    lengths_to_test = [1, 5, 10, 20, 50, None]
    labels = [f"cond_{i if i is not None else 'all'}" for i in lengths_to_test]
    all_results = {}

    for length, label in zip(lengths_to_test, labels):
        results = []
        dataset = AnomalyDetectionDataset(
            file_path=args.valid_data_path,
            args=args,
            max_tgt_seq_len=length,
            num_total_pairs=10000,
            test=True,
        )
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size // 8,
            shuffle=False,
            num_workers=0,
            collate_fn=lambda x: pad_and_combine_instances(
                x, dataset.max_period),
            drop_last=False,
            pin_memory=args.pin_test_memory,
        )

        for i, batch in enumerate(dataloader):
            if ((i + 1) % 10 == 0) or (i < 20):
                print_log(
                    f"Batch {i+1}/{len(dataloader)} for conditioning on {length} events processed."
                )
            if args.cuda:
                batch = {
                    k: v.cuda(torch.cuda.current_device())
                    for k, v in batch.items()
                }

            ref_marks, ref_timestamps, context_lengths, padding_mask \
                = batch["ref_marks"], batch["ref_times"], batch["context_lengths"], batch["padding_mask"]
            ref_marks_backwards, ref_timestamps_backwards = batch[
                "ref_marks_backwards"], batch["ref_times_backwards"]
            tgt_marks, tgt_timestamps = batch["tgt_marks"], batch["tgt_times"]
            pp_id = batch["pp_id"]
            T = batch["T"]

            sample_timestamps = torch.rand(
                tgt_timestamps.shape[0],
                num_samples,
                dtype=tgt_timestamps.dtype,
                device=tgt_timestamps.device).clamp(min=1e-8) * T  # ~ U(0, T)

            model_res = model(
                ref_marks=ref_marks,
                ref_timestamps=ref_timestamps,
                ref_marks_bwd=ref_marks_backwards,
                ref_timestamps_bwd=ref_timestamps_backwards,
                tgt_marks=tgt_marks,
                tgt_timestamps=tgt_timestamps,
                context_lengths=context_lengths,
                sample_timestamps=sample_timestamps,
                pp_id=pp_id,
            )

            ll_results = model.log_likelihood(
                return_dict=model_res,
                right_window=T,
                left_window=0.0,
                mask=padding_mask,
                reduce=False,
            )

            for same_source, ll in zip(
                    batch["same_source"].tolist(),
                    ll_results["batch_log_likelihood"].tolist()):
                results.append((same_source[0], ll))

        # Rank pairs by log-likelihood; ideally the top half are the same-source pairs.
        sorted_results = sorted(results, key=lambda x: -x[1])
        most_likely = sorted_results[:(len(sorted_results) // 2)]
        correctly_ranked = [same_source for same_source, ll in most_likely]
        proportion_ranked = sum(correctly_ranked) / len(correctly_ranked)

        all_results[label] = {
            "raw": results,
            "agg": proportion_ranked,
        }
        print_log(
            f"Finished anomaly detection for {length} length target sequences."
        )
        print_log(
            f"Final proportion of correctly ranked pairs: {proportion_ranked}."
        )
        print_log(
            f"Results up to now: { {k: v['agg'] for k, v in all_results.items()} }"
        )
        print_log("")

        res_path = "{}/anomaly_detection_results_{}_{}.pickle".format(
            args.checkpoint_path.rstrip("/"),
            "diff_refs" if args.anomaly_same_tgt_diff_refs else "diff_tgt",
            "trunc_tgt" if args.anomaly_truncate_tgts else "trunc_refs")
        print_log("Saving intermediate results to", res_path)
        pickle.dump(all_results, open(res_path, "wb"))
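# The uniform sample_timestamps above feed a Monte Carlo estimate of the
# compensator term in the marked point process log-likelihood on [0, T]
# (see the formulas in the commented block inside next_event_prediction):
#
#     log p(seq) = sum_i log lambda_{k_i}(t_i) - \int_0^T lambda(s) ds,
#     \int_0^T lambda(s) ds  ~=  T * mean_j lambda(s_j),   s_j ~ U(0, T).
#
# A minimal, self-contained sketch of that estimator (illustration only;
# the function name and intensity below are made up, not part of the pipeline):
def _mc_compensator_sketch(total_intensity_fn, T, num_samples=10000, seed=0):
    rng = np.random.default_rng(seed)
    s = rng.uniform(0.0, T, size=num_samples)   # s_j ~ U(0, T)
    return T * total_intensity_fn(s).mean()     # T * mean_j lambda(s_j)

# Example: for lambda(s) = 2 + s on [0, 3] the exact integral is 2*3 + 3**2/2 = 10.5,
# and _mc_compensator_sketch(lambda s: 2.0 + s, T=3.0) returns roughly that value.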
def baseline_anomaly_detection(args, train_dataloader):
    train_dataset = train_dataloader.dataset

    # Find MLE mark counts from the training data (cached to disk).
    max_T = train_dataset.max_period
    mle_counts = defaultdict(int)
    mle_props = defaultdict(float)
    total_obs = 0
    mle_path = "{}/anomaly_detection_baseline_mle.pickle".format(
        args.checkpoint_path.rstrip("/"))
    if os.path.exists(mle_path):
        mle_counts, total_obs = pickle.load(open(mle_path, "rb"))
    else:
        print_log("Finding mle from training data")
        for i, obs in enumerate(train_dataset):
            if i % 50000 == 0:
                print_log(f"\tProgress {i} / {len(train_dataset)}")
            obs = {k: v.numpy() for k, v in obs.items()}
            if abs(obs["T"].item() - max_T) > 1e-2:
                continue
            total_obs += 1
            for mark in obs["tgt_marks"]:
                mle_counts[mark] += 1
        pickle.dump((mle_counts, total_obs), open(mle_path, "wb"))

    # Tune the prior variance scale on validation data.
    prior_alpha = mle_counts
    prior_beta = {k: total_obs for k in mle_counts}
    prior_mu, prior_var = _gamma_ab_to_ms(prior_alpha, prior_beta)
    var_scales = [10**(-i) for i in range(12)]
    valid_dataset = AnomalyDetectionDataset(
        file_path=args.valid_data_path,
        args=args,
        max_tgt_seq_len=None,
        num_total_pairs=1000,
        test=False,
    )
    acc_results = {}
    print_log(f"Testing different variance scales {var_scales}")
    for var_scale in var_scales:
        print_log(f"Trying {var_scale}")
        results = []
        # adj_prior_alpha, adj_prior_beta = _gamma_ms_to_ab(prior_mu, {k: v*var_scale for k, v in prior_var.items()})
        # adj_prior_alpha, adj_prior_beta = _gamma_adjust_priors(prior_mu, {k: v*var_scale for k, v in prior_var.items()})
        # _, good_prior = _gamma_mean(adj_prior_alpha, adj_prior_beta)
        # if not good_prior:
        #     print_log("Not a good prior")
        #     continue
        for i, obs in enumerate(valid_dataset):
            if i % 100 == 0:
                print_log(f"\t Progress {i} / {len(valid_dataset)}")
            obs = {k: v.numpy() for k, v in obs.items()}
            # post_alpha, post_beta = _gamma_post(obs, adj_prior_alpha, adj_prior_beta, var_scale)
            # lambda_modes, _ = _gamma_mode(post_alpha, post_beta)
            # ll = _pois_lik(obs, lambda_modes, max_T)
            lambda_means = _gamma_post_means(obs, prior_alpha, prior_beta,
                                             var_scale)
            ll = _pois_lik(obs, lambda_means, max_T)
            results.append((obs["same_source"].item(), ll))

        sorted_results = sorted(results, key=lambda x: -x[1])
        # print_log(sorted_results[:100], sorted_results[-100:])
        most_likely = sorted_results[:(len(sorted_results) // 2)]
        correctly_ranked = [same_source for same_source, ll in most_likely]
        proportion_ranked = sum(correctly_ranked) / len(correctly_ranked)
        acc_results[var_scale] = proportion_ranked
        print_log(
            f"Var Scale {var_scale} used, resulted in {proportion_ranked} acc")

    acc_results = sorted(acc_results.items(), key=lambda x: -x[1])
    var_scale, best_valid_acc = acc_results[0]
    print_log(
        f"Var Scale chosen: {var_scale} w/ valid accuracy of {best_valid_acc}")

    lengths_to_test = [1, 5, 10, 20, 50, None][::-1]
    labels = [f"cond_{i if i is not None else 'all'}" for i in lengths_to_test]
    all_results = {}
    for length, label in zip(lengths_to_test, labels):
        print_log("Performing test on", label)
        test_dataset = AnomalyDetectionDataset(
            file_path=args.valid_data_path,
            args=args,
            max_tgt_seq_len=length,
            num_total_pairs=10000,
            test=True,
        )

        # Get the ranking of test pairs by baseline log-likelihood.
        test_results = []
        print_log("Adjusting priors")
        adj_prior_alpha, adj_prior_beta = _gamma_ms_to_ab(
            prior_mu, {k: v * var_scale for k, v in prior_var.items()})
        for i, obs in enumerate(test_dataset):
            if i % 100 == 0:
                print_log(f"\t Progress {i} / {len(test_dataset)}")
            obs = {k: v.numpy() for k, v in obs.items()}
            # post_alpha, post_beta = _gamma_post(obs, adj_prior_alpha, adj_prior_beta)
            # lambda_modes, _ = _gamma_mode(post_alpha, post_beta)
            # ll = _pois_lik(obs, lambda_modes, max_T)
            lambda_means = _gamma_post_means(obs, prior_alpha, prior_beta,
                                             var_scale)
            ll = _pois_lik(obs, lambda_means, max_T)
            test_results.append((obs["same_source"].item(), ll))

        sorted_results = sorted(test_results, key=lambda x: -x[1])
        most_likely = sorted_results[:(len(sorted_results) // 2)]
        correctly_ranked = [same_source for same_source, ll in most_likely]
        proportion_ranked = sum(correctly_ranked) / len(correctly_ranked)

        all_results[label] = {
            "raw": test_results,
            "agg": proportion_ranked,
            "var_scale": var_scale
        }
        print_log(
            f"Finished anomaly detection for {length} length target sequences."
        )
        print_log(
            f"Final proportion of correctly ranked pairs: {proportion_ranked}."
        )
        print_log(
            f"Results up to now: { {k: v['agg'] for k, v in all_results.items()} }"
        )
        print_log("")

        res_path = "{}/anomaly_detection_baseline_results_{}_{}.pickle".format(
            args.checkpoint_path.rstrip("/"),
            "diff_refs" if args.anomaly_same_tgt_diff_refs else "diff_tgt",
            "trunc_tgt" if args.anomaly_truncate_tgts else "trunc_refs")
        print_log("Saving intermediate results to", res_path)
        pickle.dump(all_results, open(res_path, "wb"))
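# The helpers used above (_gamma_ab_to_ms, _gamma_ms_to_ab, _gamma_post_means,
# _pois_lik) are defined elsewhere in the repo. A minimal, self-contained sketch
# of the conjugate Gamma-Poisson update they are assumed to implement (the
# var_scale knob, which rescales the prior variance, is omitted here): with a
# Gamma(alpha_k, beta_k) prior on each mark's rate lambda_k and n_k observed
# events of mark k over a window of length T_obs, the posterior is
# Gamma(alpha_k + n_k, beta_k + T_obs); its mean is used as a plug-in rate for a
# homogeneous Poisson score. Function and argument names are hypothetical.
def _gamma_poisson_sketch(prior_alpha, prior_beta, mark_counts, T_obs):
    """Illustrative only: posterior-mean rates and the resulting Poisson log-likelihood."""
    post_means = {
        k: (prior_alpha.get(k, 0.0) + mark_counts.get(k, 0)) /
        (prior_beta.get(k, 0.0) + T_obs)
        for k in set(prior_alpha) | set(mark_counts)
    }
    # log p(counts) = sum_k [ n_k * log(lambda_k * T_obs) - lambda_k * T_obs - log(n_k!) ]
    ll = sum(
        mark_counts.get(k, 0) * math.log(rate * T_obs) - rate * T_obs -
        math.lgamma(mark_counts.get(k, 0) + 1)
        for k, rate in post_means.items() if rate > 0)
    return post_means, ll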
def next_event_prediction(args, model, dataloader):
    model.eval()
    samples_per_time = args.samples_per_sequence
    div = 4
    num_batches = args.num_samples // (args.batch_size // div)
    num_samples = 10000
    num_samples_iterated = torch.arange(start=1, end=num_samples + 1)
    base_linspace = torch.linspace(1e-10, 1.0, num_samples + 1).unsqueeze(0)
    if args.cuda:
        num_samples_iterated = num_samples_iterated.cuda(
            torch.cuda.current_device())
        base_linspace = base_linspace.cuda(torch.cuda.current_device())

    select_condition_amounts = False  # TODO: Make this an option in the args
    if select_condition_amounts:
        events_to_cond = [2, 5, 10, 20, 50]  # , 0.05, 0.0]
    else:
        # args.max_seq_len is set in the data
        events_to_cond = list(range(1, min(args.max_seq_len, 50)))

    all_results = {}
    mean_results = {}
    if select_condition_amounts:
        all_res_path = "{}/pred_task_all_results.pickle".format(
            args.checkpoint_path.rstrip("/"))
        mean_res_path = "{}/pred_task_mean_results.pickle".format(
            args.checkpoint_path.rstrip("/"))
        if os.path.exists(all_res_path) and os.path.exists(mean_res_path):
            all_results = pickle.load(open(all_res_path, "rb"))
            mean_results = pickle.load(open(mean_res_path, "rb"))
            events_to_cond = [
                i for i in events_to_cond if i not in mean_results
            ]
    else:
        all_mean_res_path = "{}/all_pred_task_mean_results.pickle".format(
            args.checkpoint_path.rstrip("/"))
        if os.path.exists(all_mean_res_path):
            mean_results = pickle.load(open(all_mean_res_path, "rb"))
            events_to_cond = [
                i for i in events_to_cond if i not in mean_results
            ]

    print_log(
        f"Next event prediction with {(args.batch_size // div) * len(dataloader)} predictions for {events_to_cond} different condition lengths."
    )
    print_log(
        f"Batch size of {args.batch_size // div} with {len(dataloader)} total batches. {num_samples} samples per prediction."
    )

    for cond_num in events_to_cond:
        print_log(
            f"Starting prediction tasks where we condition on {cond_num} events prior to prediction."
        )
        _, _, dataloader = get_data(args)
        # data_iter = iter(dataloader)
        results = {k: [] for k in all_metrics.keys()}
        results["pred_time"] = []
        results["true_time"] = []
        results["last_time"] = []
        # while i < num_batches:
        for i, batch in enumerate(dataloader):
            if ((i + 1) % 10 == 0) or (i < 20):
                print_log(
                    f"Batch {i+1}/{len(dataloader)} for conditioning on {cond_num} events processed."
                )
            # batch = next(data_iter)
            invalid_examples = batch["padding_mask"].sum(dim=-1) < (cond_num + 1)
            if ((1.0 * invalid_examples).mean().item() == 1.0) or (
                    batch["tgt_times"].shape[-1] < (cond_num + 1)):
                print_log(f"Skipped batch at i={i-1}")
                continue
            else:
                batch = {
                    k: v[~invalid_examples, ...]
                    for k, v in batch.items()
                }

            # if i % (len(dataloader) // 10) == 0:
            if args.cuda:
                batch = {
                    k: v.cuda(torch.cuda.current_device())
                    for k, v in batch.items()
                }

            ref_marks, ref_timestamps, context_lengths, padding_mask \
                = batch["ref_marks"], batch["ref_times"], batch["context_lengths"], batch["padding_mask"]
            ref_marks_backwards, ref_timestamps_backwards = batch[
                "ref_marks_backwards"], batch["ref_times_backwards"]
            tgt_marks, tgt_timestamps = batch["tgt_marks"], batch["tgt_times"]
            pp_id = batch["pp_id"]
            T = batch["T"]

            # truncate inputs
            true_times, true_events = tgt_timestamps[..., cond_num], \
                tgt_marks[..., cond_num]
            tgt_timestamps = tgt_timestamps[..., :cond_num]
            tgt_marks = tgt_marks[..., :cond_num]
            padding_mask = padding_mask[..., :cond_num]
            last_times = tgt_timestamps[..., -1].unsqueeze(
                -1)  ## commented code below assumes there is no `unsqueeze(-1)` operation

            # get output intensity values
            # sample_timestamps = torch.rand(
            #     tgt_timestamps.shape[0],
            #     num_samples,
            #     dtype=tgt_timestamps.dtype,
            #     device=tgt_timestamps.device
            # ).clamp(min=1e-8)  # ~ U(0,1)
            # sample_timestamps = sample_timestamps * (T.squeeze(-1) - last_times).unsqueeze(-1) + last_times.unsqueeze(-1)  # ~ U(t_{i-1}, T)
            # sample_timestamps = []
            # for i in range(last_times.shape[0]):
            #     sample_timestamps.append(torch.linspace(last_times[i]+1e-9, T[i,0], num_samples+1))
            # sample_timestamps = torch.stack(sample_timestamps, dim=0)
            sample_timestamps = base_linspace * (T - last_times) + last_times
            timestep = (T - last_times) / num_samples

            model_res = model(
                ref_marks=ref_marks,
                ref_timestamps=ref_timestamps,
                ref_marks_bwd=ref_marks_backwards,
                ref_timestamps_bwd=ref_timestamps_backwards,
                tgt_marks=tgt_marks,
                tgt_timestamps=tgt_timestamps,
                context_lengths=context_lengths,
                sample_timestamps=sample_timestamps,
                pp_id=pp_id,
            )

            sample_intensities = model_res["sample_intensities"]
            log_mark_intensity = sample_intensities["all_log_mark_intensities"]
            total_intensity = sample_intensities["total_intensity"]
            mark_prob = log_mark_intensity.exp() / total_intensity.unsqueeze(-1)
            # log_total_intensity = total_intensity.clamp(0.0001, None).log()
            # log_mark_prob = log_mark_intensity - log_total_intensity.unsqueeze(-1)
            # mark_prob = log_mark_prob.exp()
            intensity_integral = torch.cumsum(timestep * total_intensity,
                                              dim=-1)
            t_density = total_intensity * torch.exp(-intensity_integral)
            t_pit = sample_timestamps * t_density  # integrand for time estimator
            pm_pit = mark_prob * t_density.unsqueeze(-1)  # integrand for mark estimator

            # use the trapezoid method of integration
            pred_times = (timestep * 0.5 *
                          (t_pit[..., 1:] + t_pit[..., :-1])).sum(
                              dim=-1)  # sum over sample timestep dimension
            pred_dists = (timestep.unsqueeze(-1) * 0.5 *
                          (pm_pit[..., 1:, :] + pm_pit[..., :-1, :])).sum(
                              dim=-2)  # sum over sample timestep dimension

            # MC estimate probability distributions
            # sample_intensities = model_res["sample_intensities"]
            # log_mark_intensity = sample_intensities["all_log_mark_intensities"]
            # total_intensity = sample_intensities["total_intensity"]
            # log_total_intensity = total_intensity.clamp(0.0001, None).log()
            # log_mark_prob = log_mark_intensity - log_total_intensity.unsqueeze(-1)
            # #mark_prob = log_mark_prob.exp()
            # ## p(t_i=t) = \lambda(t) exp(-\int_{t_{i-1}}^t \lambda(s) ds)
            # ## \int_{t_{i-1}}^t \lambda(s) ds \approx (t - t_{i-1}) * 1/N * \sum_{i=1}^N \lambda(s_i)
            # ## for s_i \sim U(t_{i-1}, t]
            # cum_hazard = total_intensity.cumsum(dim=-1)
            # cum_hazard = cum_hazard * (sample_timestamps - last_times.unsqueeze(-1))
            # cum_hazard = -cum_hazard / num_samples_iterated
            #
            # p_t = total_intensity * cum_hazard.exp()
            # log_p_t = log_total_intensity + cum_hazard
            # ## \hat{t_i} = \int_{t_{i-1}}^T t p(t_i=t) dt
            # pred_times = (T.squeeze() - last_times) / num_samples * (sample_timestamps * p_t).sum(dim=-1)
            # ## p(k_i=k) \propto \int_{t_{i-1}}^T \lambda_k(t) / \lambda(t) * P(t_i=t) dt
            # ## since we only care about rankings, we will compute the following instead
            # ## p(k_i=k) \propto \int_{t_{i-1}}^T log \lambda_k(t) - log \lambda(t) + log P(t_i=t) dt
            # ## log_mark_prob is size (batch, num_samples, total_marks)
            # pred_dists = (log_mark_prob + log_p_t.unsqueeze(-1)).sum(dim=-2)  # sum over sample dim
            # pred_dists = pred_dists * (T.squeeze() - last_times).unsqueeze(-1) / num_samples

            # evaluate metrics
            r = _rank(pred_dists, true_events)  # compute this so we only rank them once per batch
            batch_res = {
                k: metric(
                    pred_times=pred_times,
                    pred_dists=pred_dists,
                    true_times=true_times,
                    true_events=true_events,
                    r=r,
                )
                for k, metric in all_metrics.items()
            }
            for t, k in zip([pred_times, true_times, last_times],
                            ["pred_time", "true_time", "last_time"]):
                _t = t.squeeze().tolist()
                if not isinstance(_t, list):
                    _t = [_t]
                batch_res[k] = _t

            # import readline  # optional, will allow Up/Down/History in the console
            # import code
            # variables = globals().copy()
            # variables.update(locals())
            # shell = code.InteractiveConsole(variables)
            # shell.interact()
            # print("DONE")
            # input()

            # store results
            for k, b_res in batch_res.items():
                if k not in results:
                    results[k] = []
                results[k].extend(b_res)

            ## this was debugging for lastfm predictions
            ## makes no sense for other datasets
            # if any(x > 30 for x in batch_res["time_l1"]):
            #     print_log("BAD BATCH DETECTED")
            #     print_log("BAD BATCH DETECTED")
            #     print_log("BAD BATCH DETECTED")
            #     if "bad_batches" not in results:
            #         results["bad_batches"] = []
            #     results["bad_batches"].append((batch_res, {k: v.tolist() for k, v in batch.items()}))

        # add to overall results
        if select_condition_amounts:
            all_results[cond_num] = results
        mean_res = {}
        bad_indices = set()
        for k, v in results.items():
            if k != "bad_batches":
                bad_indices = bad_indices.union(
                    set(i for i, el in enumerate(v)
                        if el != el))  # filter out nan's
        for k, v in results.items():
            if k != "bad_batches":
                filtered_v = [
                    el for i, el in enumerate(v) if i not in bad_indices
                ]
                if len(filtered_v) > 0:
                    mean_res[k] = sum(filtered_v) / len(filtered_v)
                else:
                    mean_res[k] = -1
                num_seqs = len(filtered_v)
        mean_res["num_predictions"] = num_seqs
        # mean_results[cond_num] = {k: ((sum(v) / len(v)) if len(v) > 0 else None) for k, v in results.items() if k != "bad_batches"}
        mean_results[cond_num] = mean_res

        # save results to file
        if select_condition_amounts:
            mean_res_path = "{}/pred_task_mean_results.pickle".format(
                args.checkpoint_path.rstrip("/"))
            all_res_path = "{}/pred_task_all_results.pickle".format(
                args.checkpoint_path.rstrip("/"))
            print_log("Saving intermediate results to", mean_res_path,
                      all_res_path)
            pickle.dump(all_results, open(all_res_path, "wb"))
            pickle.dump(mean_results, open(mean_res_path, "wb"))
        else:
            mean_res_path = "{}/all_pred_task_mean_results.pickle".format(
                args.checkpoint_path.rstrip("/"))
            print_log("Saving intermediate results to", mean_res_path)
            pickle.dump(mean_results, open(mean_res_path, "wb"))
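# The active branch above estimates the next event time as
#     t_hat = \int_{t_last}^{T} t * p(t) dt,  with  p(t) = lambda(t) * exp(-\int_{t_last}^{t} lambda(s) ds),
# using a cumulative sum for the inner integral and the trapezoid rule for the
# outer one. A minimal, self-contained sanity check of that numerical scheme
# (illustration only, not part of the evaluation pipeline; the function name and
# constant-intensity setup are made up): for a constant intensity lam on [0, T],
#     \int_0^T t * lam * exp(-lam * t) dt = (1 - exp(-lam*T)) / lam - T * exp(-lam*T).
def _trapezoid_time_estimate_sketch(lam=2.0, T=3.0, num_samples=10000):
    grid = np.linspace(0.0, T, num_samples + 1)
    dt = T / num_samples
    intensity = np.full_like(grid, lam)
    intensity_integral = np.cumsum(dt * intensity)       # approx of \int_0^t lambda(s) ds
    t_density = intensity * np.exp(-intensity_integral)  # approx of p(t)
    integrand = grid * t_density
    approx = (dt * 0.5 * (integrand[1:] + integrand[:-1])).sum()
    exact = (1.0 - np.exp(-lam * T)) / lam - T * np.exp(-lam * T)
    return approx, exact

# e.g. _trapezoid_time_estimate_sketch() returns two values that both round to ~0.491.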
def save_latents(args, model, dataloader):
    num_samples = len(dataloader)
    model.eval()
    latents = []
    for i, batch in enumerate(dataloader):
        if args.cuda:
            batch = {
                k: v.cuda(torch.cuda.current_device())
                for k, v in batch.items()
            }
        if i % (num_samples // 10) == 0:
            print_log("{} Latent state batches extracted".format(i))
        if i > num_samples:
            break

        ref_marks, ref_timestamps, context_lengths, padding_mask \
            = batch["ref_marks"], batch["ref_times"], batch["context_lengths"], batch["padding_mask"]
        ref_marks_backwards, ref_timestamps_backwards = batch[
            "ref_marks_backwards"], batch["ref_times_backwards"]
        pp_id = batch["pp_id"]

        with torch.no_grad():
            latent = model.get_latent(
                ref_marks_fwd=ref_marks,
                ref_timestamps_fwd=ref_timestamps,
                ref_marks_bwd=ref_marks_backwards,
                ref_timestamps_bwd=ref_timestamps_backwards,
                context_lengths=context_lengths,
                pp_id=pp_id,
            )
        mean = latent["latent_state"]
        sigma = latent["q_z_x"]
        if sigma is None:
            sigma = torch.zeros_like(mean)
        else:
            sigma = sigma.scale

        for ls, sm, cl, m, t, pp in zip(mean.tolist(), sigma.tolist(),
                                        context_lengths.squeeze().tolist(),
                                        ref_marks.tolist(),
                                        ref_timestamps.tolist(),
                                        pp_id.squeeze().tolist()):
            m, t = m[:cl + 1], t[:cl + 1]
            mark_counts = {}
            t_delta = [t1 - t0 for t1, t0 in zip(t[1:], t[:-1])]
            if len(t_delta) == 0:
                continue
            for k in m:
                if k not in mark_counts:
                    mark_counts[k] = 1
                else:
                    mark_counts[k] += 1
            latents.append({
                "latent_mu": ls,
                "latent_sigma": sm,
                "mark_counts": mark_counts,
                "total_events": len(m),
                "mean_inter_event_time": sum(t_delta) / len(t_delta),
                "median_inter_event_time": sorted(t_delta)[len(t_delta) // 2],
                "user_id": pp,
            })

    pickle.dump(
        latents,
        open(
            "{}/extracted_latents.pickle".format(
                args.checkpoint_path.rstrip("/")), "wb"))
def sample_generations(args, model, dataloader):
    model.eval()
    samples_per_time = args.samples_per_sequence
    users_sampled = args.num_samples
    T_pcts = [0.5, 0.3, 0.1]  # , 0.05, 0.0]
    all_samples = []
    data_iter = iter(dataloader)
    i = 0
    while i < users_sampled:
        # for i, batch in enumerate(dataloader):
        #     if i >= users_sampled:
        #         break
        try:
            batch = next(data_iter)
            print_log("New user {}".format(i))
            if args.cuda:
                batch = {
                    k: v.cuda(torch.cuda.current_device())
                    for k, v in batch.items()
                }

            ref_marks, ref_timestamps, context_lengths, padding_mask \
                = batch["ref_marks"], batch["ref_times"], batch["context_lengths"], batch["padding_mask"]
            ref_marks_backwards, ref_timestamps_backwards = batch[
                "ref_marks_backwards"], batch["ref_times_backwards"]
            tgt_marks, tgt_timestamps = batch["tgt_marks"], batch["tgt_times"]
            pp_id = batch["pp_id"]
            tgt_timestamps = tgt_timestamps[
                ..., :padding_mask.cumsum(-1).max().item()]
            tgt_marks = tgt_marks[..., :padding_mask.cumsum(-1).max().item()]
            T = batch["T"]

            user_samples = {
                "original_times": tgt_timestamps.squeeze().tolist(),
                "original_marks": tgt_marks.squeeze().tolist(),
                "original_T": T.squeeze().tolist(),
                "samples": {}
            }

            for pct in T_pcts:
                print_log("New pct {}".format(pct))
                user_samples["samples"][pct] = []
                if pct == 0.0:
                    new_tgt_timestamps = tgt_timestamps[..., :1] * 10000
                    new_tgt_marks = tgt_marks[..., :1]
                    left_window = 0.0
                else:
                    new_tgt_timestamps = tgt_timestamps[
                        ..., :math.floor(pct * tgt_timestamps.shape[-1]) +
                        1]  # torch.where(good_times, tgt_timestamps, torch.ones_like(tgt_timestamps) * 10000)
                    new_tgt_marks = tgt_marks[
                        ..., :math.floor(pct * tgt_timestamps.shape[-1]) + 1]
                    left_window = new_tgt_timestamps[..., -1].squeeze().item()

                for j in range(samples_per_time):
                    print("New sample {}".format(j))
                    samples = None
                    m = 1.0
                    while samples is None:
                        if m >= 10.0:
                            break
                        samples = model.sample_points(
                            ref_marks=ref_marks,
                            ref_timestamps=ref_timestamps,
                            ref_marks_bwd=ref_marks_backwards,
                            ref_timestamps_bwd=ref_timestamps_backwards,
                            tgt_marks=new_tgt_marks,
                            tgt_timestamps=new_tgt_timestamps,
                            context_lengths=context_lengths,
                            dominating_rate=args.dominating_rate * m,
                            T=T,
                            left_window=left_window,
                            top_k=args.top_k,
                            top_p=args.top_p,
                        )
                        m *= 1.5
                    if samples is None:
                        print("No good sample found. Skipping")
                        continue
                    sampled_times, sampled_marks = samples
                    held_out_marks = set(
                        tgt_marks[..., math.floor(pct * tgt_timestamps.
                                                  shape[-1]):].squeeze().tolist())
                    print(
                        "Pct: {} | Left Window: {} | Num Original: {} | Num Conditioned: {} | Num Sampled Alone: {} | Unique Marks on Held Out: {} | Unique Marks Sampled: {} | Common Marks: {}"
                        .format(
                            pct,
                            left_window,
                            tgt_timestamps.squeeze().shape[0],
                            math.floor(pct * tgt_timestamps.shape[-1]),
                            len(sampled_times),
                            len(held_out_marks),
                            len(set(sampled_marks)),
                            len(held_out_marks.intersection(
                                set(sampled_marks))),
                        ))
                    assert (len(sampled_times) == 0
                            or left_window <= min(sampled_times))
                    user_samples["samples"][pct].append(
                        (sampled_times, sampled_marks))

            all_samples.append(user_samples)
            i += 1
        except StopIteration:
            break  # ran out of data
        except:
            continue  # data processing error

    pickle.dump(
        all_samples,
        open(
            "{}/scaling_samples_top_p_{}_top_k_{}.pickle".format(
                args.checkpoint_path.rstrip("/"), args.top_p, args.top_k),
            "wb"))
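# The retry loop above grows the dominating rate by 1.5x (up to 10x) whenever
# model.sample_points returns None, which is consistent with thinning-based
# sampling: the dominating rate must upper-bound the conditional intensity for
# a draw to be valid. A minimal, generic sketch of thinning with a constant
# dominating rate (illustration only; this is not the repo's sampler, and
# `intensity_fn` is a hypothetical stand-in for the model's conditional intensity):
def _thinning_sketch(intensity_fn, dominating_rate, left_window, T, rng):
    times = []
    t = left_window
    while True:
        # Propose the next candidate from a homogeneous process at the dominating rate.
        t += rng.exponential(1.0 / dominating_rate)
        if t > T:
            break
        lam = intensity_fn(t)
        if lam > dominating_rate:
            return None  # dominating rate too small; caller should retry with a larger one
        # Accept the candidate with probability intensity(t) / dominating_rate.
        if rng.uniform() < lam / dominating_rate:
            times.append(t)
    return times

# Usage sketch:
#     _thinning_sketch(lambda t: 1.0 + 0.5 * math.sin(t), 2.0, 0.0, 10.0, np.random.default_rng(0))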
def likelihood_over_time(args, model, dataloader):
    lik_total_contributions = {}
    pos_total_contributions = {}
    neg_total_contributions = {}
    ce_total_contributions = {}
    overall_freq = {}
    lik_diff_contributions = {}
    pos_diff_contributions = {}
    neg_diff_contributions = {}
    ce_diff_contributions = {}
    all_contributions = {
        "lik_total": lik_total_contributions,
        "pos_total": pos_total_contributions,
        "neg_total": neg_total_contributions,
        "ce_total": ce_total_contributions,
    }
    res = args.likelihood_resolution
    model.eval()

    for i, batch in enumerate(dataloader):
        if i % 20 == 0:
            print_log("Progress: {} / {}".format(i, len(dataloader)))
        with torch.no_grad():
            ll_results, sample_timestamps, tgt_timestamps = forward_pass(
                args,
                batch,
                model,
                sample_timestamps=None,
                num_samples=1000,
                get_raw_likelihoods=True)
            pos_cont, neg_cont, ce = ll_results[
                "positive_contribution"], ll_results[
                    "negative_contribution"], ll_results["cross_entropy"]

            prev_lik, prev_pos, prev_neg, prev_count, prev_ce = 0, 0, 0, 0, 0
            for T in np.arange(res, batch["T"].max().item() + res, res):
                partial_pos_sum, partial_pos_mean, partial_pos_mean_scaled = partial_neg_contributions(
                    pos_cont, tgt_timestamps, T)
                partial_ce_sum, partial_ce_mean, partial_ce_mean_scaled = partial_neg_contributions(
                    ce, tgt_timestamps, T)
                partial_neg_sum, partial_neg_mean, partial_neg_mean_scaled = partial_neg_contributions(
                    neg_cont, sample_timestamps, T)
                partial_lik_cont = partial_pos_sum - partial_neg_mean_scaled
                new_conts = {
                    "lik_total": partial_lik_cont,
                    "pos_total": partial_ce_mean + partial_pos_mean,
                    "neg_total": partial_neg_mean,
                    "ce_total": partial_ce_mean,
                }
                for key, new_cont in new_conts.items():
                    add_contribution(all_contributions[key], new_cont, T,
                                     batch["T"])

    mean_contributions = {}
    lower_ci_contributions = {}
    upper_ci_contributions = {}
    for key, total_contributions in all_contributions.items():
        if "ce_" in key:
            mean_contributions[key] = sorted([
                (t, sum(ls) / sum(1 for x in ls if (x != 0) and (x != 0.0)))
                for t, ls in total_contributions.items()
            ])
        elif "pos_" in key:
            mean_contributions[key] = sorted([
                (t, sum(ls) / sum(1 for x in all_contributions["ce_total"][t]
                                  if (x != 0) and (x != 0.0)))
                for t, ls in total_contributions.items()
            ])
        else:
            mean_contributions[key] = sorted([
                (t, sum(ls) / len(ls))
                for t, ls in total_contributions.items()
            ])

    pickle.dump(
        {"mean": mean_contributions},
        open(
            "{}/likelihood_data.pickle".format(
                args.checkpoint_path.rstrip("/")), "wb"),
    )
def main():
    print_log("Getting arguments.")
    args = get_args()
    # args.anomaly_detection = True
    # args.anomaly_same_tgt_diff_refs = True  # default is True, True, False
    # args.anomaly_truncate_tgts = False
    # args.anomaly_truncate_refs = True
    args.sample_generations = True
    args.top_k = 0
    args.top_p = 0

    if args.visualize or args.sample_generations:
        args.batch_size = 4
    if args.get_latents:
        args.shuffle = False
        args.same_tgt_and_ref = True
    else:
        args.shuffle = False

    if not (args.next_event_prediction or args.anomaly_detection):
        args.train_data_path = [
            fp.replace("train", "vis" if args.visualize else "valid")
            for fp in args.train_data_path
        ]

    print_log("Setting seed.")
    set_random_seed(args)

    print_log("Setting up dataloaders.")
    args.pin_test_memory = True
    # train_dataloader contains the right data for most tasks
    train_dataloader, valid_dataloader, test_dataloader = get_data(args)

    print_log("Setting up model, optimizer, and learning rate scheduler.")
    model, _, _ = setup_model_and_optim(args, len(train_dataloader))
    report_model_stats(model)

    load_result = load_checkpoint(args, model)
    if load_result == 0:
        old_path = args.checkpoint_path
        args.checkpoint_path = old_path.rstrip("/") + "/data_ablation/"
        print_log(f"Model not found in {old_path}.")
        print_log(f"Trying to load model instead from {args.checkpoint_path}.")
        load_checkpoint(args, model)
        args.checkpoint_path = old_path

    if args.visualize:
        print_log("Starting visualization.")
        save_and_vis_intensities(args, model, train_dataloader)
    elif args.sample_generations:
        print_log("Sampling generations.")
        sample_generations(args, model, test_dataloader)  # train_dataloader
    elif args.likelihood_over_time:
        print_log("Starting likelihood over time analysis.")
        if "amazon" in args.checkpoint_path:
            args.likelihood_resolution = args.likelihood_resolution / 4.0  # 1/4 day resolution
        elif "lastfm" in args.checkpoint_path:
            args.likelihood_resolution = args.likelihood_resolution / 6.0  # 10 minute resolution
        # else: 1 hour resolution over 1 week = 168 bins
        likelihood_over_time(args, model, test_dataloader)  # train_dataloader
    elif args.get_latents:
        print_log("Extracting latent states.")
        save_latents(args, model, train_dataloader)
    elif args.anomaly_detection:
        print_log("Starting anomaly detection experiments.")
        with torch.no_grad():
            anomaly_detection(args, model)
            if "rmtpp" in args.checkpoint_path:
                baseline_anomaly_detection(args, train_dataloader)
    elif args.next_event_prediction:
        print_log("Performing next event prediction experiments.")
        # args.num_workers = 0
        args.num_samples = (len(test_dataloader) - 1) * args.batch_size
        with torch.no_grad():
            next_event_prediction(args, model, test_dataloader)