# Example #1
def train_nn_models(model, event_seqs, args):
    """Train `model` on `event_seqs`, checkpointing the epoch that is best
    on the validation tune metric, then reload and return that best model.
    """
    full_loader = DataLoader(
        EventSeqDataset(event_seqs), **dataloader_args
    )
    # Hold out 1/9 of the batches for validation.
    train_loader, valid_loader = split_dataloader(full_loader, 8 / 9)

    if "bucket_seqs" in args and args.bucket_seqs:
        train_loader = convert_to_bucketed_dataloader(
            train_loader, key_fn=len
        )
    # Validation is always bucketed by sequence length, without shuffling
    # inside a bucket, so evaluation batches are deterministic.
    valid_loader = convert_to_bucketed_dataloader(
        valid_loader, key_fn=len, shuffle_same_key=False
    )

    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    model.train()
    # NaN compares unequal/False to everything, so the first epoch always
    # becomes the initial "best".
    best_metric = float("nan")

    for epoch in range(args.epochs):
        train_metrics, valid_metrics = model.train_epoch(
            train_loader,
            optimizer,
            valid_loader,
            device=device,
            **vars(args),
        )

        # Log both phases with the same message shape.
        for tag, metrics in (
            ("Training", train_metrics),
            ("Validation", valid_metrics),
        ):
            parts = [f"[{tag}] Epoch={epoch}"]
            parts.extend(f"{k}={v.avg:.4f}" for k, v in metrics.items())
            logger.info(", ".join(parts))

        if compare_metric_value(
            valid_metrics[args.tune_metric].avg, best_metric, args.tune_metric
        ):
            if epoch > args.epochs // 2:
                logger.info(f"Found a better model at epoch {epoch}.")
            best_metric = valid_metrics[args.tune_metric].avg
            torch.save(model.state_dict(), osp.join(output_path, "model.pt"))

    # Restore the checkpointed best weights before returning.
    model.load_state_dict(torch.load(osp.join(output_path, "model.pt")))

    return model
# Example #2
def get_infectivity_matrix(model, event_seqs, args):
    """Return the model's type-to-type infectivity estimate.

    Neural models (RME/ERPP/RPPN) need a pass over the data via a
    dataloader; other models expose kernel norms directly.
    """
    if args.model not in ["RME", "ERPP", "RPPN"]:
        return model.get_kernel_norms()

    loader_kwargs = dict(dataloader_args)
    # Attribution may need its own (typically smaller) batch size.
    if "attr_batch_size" in args and args.attr_batch_size:
        loader_kwargs["batch_size"] = args.attr_batch_size

    loader = DataLoader(EventSeqDataset(event_seqs), **loader_kwargs)
    loader = convert_to_bucketed_dataloader(loader, key_fn=len)
    return model.get_infectivity(loader, device, **vars(args))
# Example #3
def predict_next_event(model, event_seqs, args):
    """Predict the next event for each sequence in `event_seqs`.

    Returns the predicted sequences, or None when the model type has no
    next-event prediction implemented.
    """
    if args.model in ["ERPP", "RPPN"]:
        loader = DataLoader(
            EventSeqDataset(event_seqs), shuffle=False, **dataloader_args
        )
        return model.predict_next_event(loader, device=device)

    if args.model == "HExp":
        # Imported lazily so other model types don't require this module.
        from pkg.utils.pp import predict_next_event_hawkes_exp_kern

        return predict_next_event_hawkes_exp_kern(
            event_seqs, model, verbose=True
        )

    print(
        f"Predicting next event is not supported for model={args.model} yet."
    )
    return None
# Example #4
def eval_nll(model, event_seqs, args):
    """Evaluate the negative log-likelihood of `event_seqs` under `model`.

    Returns NaN (after printing a notice) for unsupported model types.
    """
    if args.model in ["RME", "ERPP", "RPPN"]:
        loader = DataLoader(
            EventSeqDataset(event_seqs), shuffle=False, **dataloader_args
        )
        metrics = model.evaluate(loader, device=device)
        summary = ", ".join(
            f"{k}={v.avg:.4f}" for k, v in metrics.items()
        )
        logger.info("[Test]" + summary)
        return metrics["nll"].avg.item()

    if args.model == "HSG":
        return eval_nll_hawkes_sum_gaussians(event_seqs, model, verbose=True)

    if args.model == "HExp":
        return eval_nll_hawkes_exp_kern(event_seqs, model, verbose=True)

    print("not supported yet")
    return float("nan")
# Example #5
    # Fragment of a larger entry point (its `def` is outside this chunk):
    # restores the trained checkpoint, computes the infectivity matrix over
    # the dataset, and saves it as a text file.
    model.load_state_dict(
        torch.load(
            osp.join(args.output_dir, args.dataset, args.model, "model.pt")))

    # NOTE: Currently pytorch doesn't support gradient evaluation with cuda
    # backend for RNN; thus here we set the model to training mode globally
    # first and then manually set those submodules that behave differently
    # between training and evaluation (e.g., Dropout, BatchNorm) back to
    # evaluation mode.
    model = model.to(device)
    model.train()
    set_eval_mode(model)
    # Freeze the model parameters to reduce unnecessary backpropagation.
    for param in model.parameters():
        param.requires_grad_(False)

    dataloader = DataLoader(
        EventSeqDataset(event_seqs),
        batch_size=args.batch_size,
        collate_fn=EventSeqDataset.collate_fn,
        num_workers=args.num_workers,
    )

    # NOTE(review): `steps` presumably controls the number of attribution
    # steps inside get_infectivity — confirm against the model class.
    infectivity = model.get_infectivity(dataloader, device, steps=args.steps)
    print(infectivity)

    np.savetxt(
        osp.join(args.output_dir, args.dataset, args.model, "infectivity.txt"),
        infectivity,
    )
# Freeze all parameters so the benchmark avoids needless backpropagation.
for param in model.parameters():
    param.requires_grad_(False)

# %%
# Benchmark: per-sequence attribution time as a function of batch size and
# (truncated) sequence length.
df = pd.DataFrame(
    columns=["scheme", "batch_size", "seq_length", "time_per_seq"])

seq_lengths = [10, 50, 100, 150, 200]
batch_sizes = [1, 2, 4, 8, 16, 32]
n_seqs = batch_sizes[-1]  # enough sequences to fill the largest batch

for seq_length in seq_lengths:
    for batch_size in batch_sizes:
        # Truncate the first n_seqs sequences to the target length.
        truncated = (event_seqs[i][:seq_length] for i in range(n_seqs))
        dataloader = DataLoader(
            EventSeqDataset(truncated),
            batch_size=batch_size,
            collate_fn=EventSeqDataset.collate_fn,
            shuffle=False,
        )

        started_at = time.time()
        model.get_infectivity(dataloader, device, **vars(args))
        per_seq = (time.time() - started_at) / n_seqs

        # Append one result row at the end of the frame.
        df.loc[len(df)] = ("batch", batch_size, seq_length, per_seq)