def train_nn_models(model, event_seqs, args):
    """Train a neural model on event sequences and return the best checkpoint.

    The input sequences are split 8:1 into training/validation loaders,
    optionally bucketed by sequence length. After every epoch the tuning
    metric on the validation set is compared against the best value seen so
    far; improving models are checkpointed to ``output_path/model.pt`` and
    the best one is reloaded before returning.
    """
    train_loader = DataLoader(EventSeqDataset(event_seqs), **dataloader_args)
    # Hold out 1/9 of the batches for validation.
    train_loader, valid_loader = split_dataloader(train_loader, 8 / 9)

    if "bucket_seqs" in args and args.bucket_seqs:
        # Group similar-length sequences together to reduce padding waste.
        train_loader = convert_to_bucketed_dataloader(train_loader, key_fn=len)
        valid_loader = convert_to_bucketed_dataloader(
            valid_loader, key_fn=len, shuffle_same_key=False
        )

    # Resolve the optimizer class by name (e.g. "Adam" -> torch.optim.Adam).
    optimizer = getattr(torch.optim, args.optimizer)(
        model.parameters(), lr=args.lr
    )

    model.train()
    # NaN sentinel: compare_metric_value is expected to treat it as
    # "no best value yet" so the first epoch always wins.
    best_metric = float("nan")

    for epoch in range(args.epochs):
        train_metrics, valid_metrics = model.train_epoch(
            train_loader,
            optimizer,
            valid_loader,
            device=device,
            **vars(args),
        )

        # Emit one summary line each for training and validation metrics.
        for tag, metrics in (("Training", train_metrics),
                             ("Validation", valid_metrics)):
            parts = [f"[{tag}] Epoch={epoch}"]
            parts += [f"{name}={meter.avg:.4f}"
                      for name, meter in metrics.items()]
            logger.info(", ".join(parts))

        if compare_metric_value(
            valid_metrics[args.tune_metric].avg, best_metric, args.tune_metric
        ):
            # Only announce improvements in the second half of training to
            # cut down on early-epoch log noise.
            if epoch > args.epochs // 2:
                logger.info(f"Found a better model at epoch {epoch}.")
            best_metric = valid_metrics[args.tune_metric].avg
            torch.save(model.state_dict(), osp.join(output_path, "model.pt"))

    # Restore the best checkpoint before handing the model back.
    model.load_state_dict(torch.load(osp.join(output_path, "model.pt")))
    return model
def get_infectivity_matrix(model, event_seqs, args):
    """Return the model's inferred infectivity (pairwise influence) matrix.

    Neural models ("RME", "ERPP", "RPPN") require an attribution pass over
    the data; other (parametric kernel) models expose the matrix directly
    via ``get_kernel_norms``.
    """
    if args.model not in ["RME", "ERPP", "RPPN"]:
        # Parametric models: the kernel norms are the infectivity matrix.
        return model.get_kernel_norms()

    # Attribution may need a different (usually smaller) batch size than
    # training; fall back to the shared dataloader_args otherwise.
    _dataloader_args = dataloader_args.copy()
    if "attr_batch_size" in args and args.attr_batch_size:
        _dataloader_args.update(batch_size=args.attr_batch_size)

    dataloader = DataLoader(EventSeqDataset(event_seqs), **_dataloader_args)
    # Bucket by length so padding overhead stays low during attribution.
    dataloader = convert_to_bucketed_dataloader(dataloader, key_fn=len)
    return model.get_infectivity(dataloader, device, **vars(args))
def predict_next_event(model, event_seqs, args):
    """Predict the next event for each sequence in ``event_seqs``.

    Supported models: "ERPP"/"RPPN" (neural, via a non-shuffled DataLoader)
    and "HExp" (Hawkes with exponential kernel, via a helper routine).
    Returns the predicted sequences, or None when the model is unsupported.
    """
    if args.model in ["ERPP", "RPPN"]:
        # shuffle=False keeps predictions aligned with the input order.
        dataloader = DataLoader(
            EventSeqDataset(event_seqs), shuffle=False, **dataloader_args
        )
        event_seqs_pred = model.predict_next_event(dataloader, device=device)
    elif args.model == "HExp":
        # Local import: only needed (and importable) for this model family.
        from pkg.utils.pp import predict_next_event_hawkes_exp_kern

        event_seqs_pred = predict_next_event_hawkes_exp_kern(
            event_seqs, model, verbose=True
        )
    else:
        # Fix: use the module logger instead of print, consistent with the
        # rest of this file's diagnostics.
        logger.warning(
            "Predicting next event is not supported for "
            f"model={args.model} yet."
        )
        event_seqs_pred = None
    return event_seqs_pred
def eval_nll(model, event_seqs, args):
    """Evaluate the average negative log-likelihood (NLL) on ``event_seqs``.

    Neural models are evaluated through a non-shuffled DataLoader; "HSG" and
    "HExp" use dedicated helpers. Returns NaN for unsupported models.
    """
    if args.model in ["RME", "ERPP", "RPPN"]:
        dataloader = DataLoader(
            EventSeqDataset(event_seqs), shuffle=False, **dataloader_args
        )
        metrics = model.evaluate(dataloader, device=device)
        logger.info(
            "[Test]"
            + ", ".join(f"{k}={v.avg:.4f}" for k, v in metrics.items())
        )
        nll = metrics["nll"].avg.item()
    elif args.model == "HSG":
        nll = eval_nll_hawkes_sum_gaussians(event_seqs, model, verbose=True)
    elif args.model == "HExp":
        nll = eval_nll_hawkes_exp_kern(event_seqs, model, verbose=True)
    else:
        # Fix: log via the module logger (consistent with the rest of the
        # file) and say *what* is unsupported; keep NaN as the sentinel.
        logger.warning(
            f"NLL evaluation is not supported for model={args.model} yet."
        )
        nll = float("nan")
    return nll
model.load_state_dict( torch.load( osp.join(args.output_dir, args.dataset, args.model, "model.pt"))) # NOTE: Currently pytorch doesn't support gradient evaluation with cuda # backend for RNN; thus here we set the model to training model globally # first and then manually set those submodules that behaves differently # between training and test models (e.g, Dropout, BatchNorm) to evaluation # model. model = model.to(device) model.train() set_eval_mode(model) # freeze the model parameters to reduce unnecessary backpropogation. for param in model.parameters(): param.requires_grad_(False) dataloader = DataLoader( EventSeqDataset(event_seqs), batch_size=args.batch_size, collate_fn=EventSeqDataset.collate_fn, num_workers=args.num_workers, ) infectivity = model.get_infectivity(dataloader, device, steps=args.steps) print(infectivity) np.savetxt( osp.join(args.output_dir, args.dataset, args.model, "infectivity.txt"), infectivity, )
# freeze the model parameters to reduce unnecessary backpropogation. for param in model.parameters(): param.requires_grad_(False) # %% df = pd.DataFrame( columns=["scheme", "batch_size", "seq_length", "time_per_seq"]) seq_lengths = [10, 50, 100, 150, 200] batch_sizes = [1, 2, 4, 8, 16, 32] n_seqs = batch_sizes[-1] for seq_length in seq_lengths: for batch_size in batch_sizes: dataloader = DataLoader( EventSeqDataset(event_seqs[i][:seq_length] for i in range(n_seqs)), batch_size=batch_size, collate_fn=EventSeqDataset.collate_fn, shuffle=False, ) t_start = time.time() model.get_infectivity(dataloader, device, **vars(args)) df.loc[len(df)] = ( "batch", batch_size, seq_length, (time.time() - t_start) / n_seqs, )