示例#1
0
def train_experiment(device):
    """Train a linear ranking model for one epoch on random data.

    Random user features and random binary item targets are generated, a
    single ``torch.nn.Linear`` layer is optimised with ``BCEWithLogitsLoss``
    and the run is monitored with AUC, hitrate, MRR, MAP and NDCG callbacks.
    Checkpoints are written into a ``TemporaryDirectory`` that only lives
    for the duration of the call.

    Args:
        device: value forwarded to ``dl.DeviceEngine`` to pick the device.
    """
    n_users, n_features, n_items = 10_000, 10, 10
    cutoffs = (1, 3, 5)
    score_io = {"input_key": "scores", "target_key": "targets"}

    with TemporaryDirectory() as logdir:
        # synthetic inputs: uniform features, coin-flip relevance targets
        features = torch.rand(n_users, n_features)
        targets = (torch.rand(n_users, n_items) > 0.5).to(torch.float32)

        # the same loader is used for both the train and the valid stage
        shared_loader = DataLoader(
            TensorDataset(features, targets), batch_size=32, num_workers=1
        )
        loaders = {split: shared_loader for split in ("train", "valid")}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(n_features, n_items)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2])

        class CustomRunner(dl.Runner):
            """Runner that stores logits and their sigmoid scores per batch."""

            def handle_batch(self, batch):
                inputs, labels = batch
                logits = self.model(inputs)
                scores = torch.sigmoid(logits)
                self.batch = {
                    "features": inputs,
                    "logits": logits,
                    "scores": scores,
                    "targets": labels,
                }

        # callbacks are created in exactly the order they are registered
        callbacks = [
            dl.CriterionCallback(
                input_key="logits", target_key="targets", metric_key="loss"
            ),
            dl.AUCCallback(**score_io),
        ]
        callbacks.extend(
            ranking_callback(topk_args=cutoffs, **score_io)
            for ranking_callback in (
                dl.HitrateCallback,
                dl.MRRCallback,
                dl.MAPCallback,
                dl.NDCGCallback,
            )
        )
        callbacks += [
            dl.OptimizerCallback(metric_key="loss"),
            dl.SchedulerCallback(),
            # checkpoint selection: "map01" on the valid loader, not minimised
            dl.CheckpointCallback(
                logdir=logdir,
                loader_key="valid",
                metric_key="map01",
                minimize=False,
            ),
        ]

        # model training
        runner = CustomRunner()
        engine = dl.DeviceEngine(device)
        runner.train(
            engine=engine,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=callbacks,
        )
示例#2
0
def train_experiment(device, engine=None):
    """Train a linear ranking model for one epoch, optionally on a given engine.

    Random user features and random binary item targets are generated, a
    single ``torch.nn.Linear`` layer is optimised with ``BCEWithLogitsLoss``
    and the run is monitored with AUC and ranking callbacks. An additional
    AUC callback on the raw logits is registered unless ``engine`` is one of
    the AMP engines. Checkpoints are written into a ``TemporaryDirectory``.

    ``CustomRunner`` is not defined in this function and must be available
    in the enclosing scope.

    Args:
        device: value forwarded to ``dl.DeviceEngine`` when ``engine`` is falsy.
        engine: optional engine instance to train with.
    """
    n_users, n_features, n_items = 10_000, 10, 10
    cutoffs = (1, 3, 5)
    score_io = {"input_key": "scores", "target_key": "targets"}

    with TemporaryDirectory() as logdir:
        # synthetic inputs: uniform features, coin-flip relevance targets
        features = torch.rand(n_users, n_features)
        targets = (torch.rand(n_users, n_items) > 0.5).to(torch.float32)

        # the same loader is used for both the train and the valid stage
        shared_loader = DataLoader(
            TensorDataset(features, targets), batch_size=32, num_workers=1
        )
        loaders = {split: shared_loader for split in ("train", "valid")}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(n_features, n_items)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2])

        # callbacks are created in exactly the order they are registered
        callbacks = [
            dl.CriterionCallback(
                input_key="logits", target_key="targets", metric_key="loss"
            ),
            dl.AUCCallback(**score_io),
        ]
        callbacks.extend(
            ranking_callback(topk_args=cutoffs, **score_io)
            for ranking_callback in (
                dl.HitrateCallback,
                dl.MRRCallback,
                dl.MAPCallback,
                dl.NDCGCallback,
            )
        )
        callbacks += [
            dl.OptimizerCallback(metric_key="loss"),
            dl.SchedulerCallback(),
            # checkpoint selection: "map01" on the valid loader, not minimised
            dl.CheckpointCallback(
                logdir=logdir,
                loader_key="valid",
                metric_key="map01",
                minimize=False,
            ),
        ]

        # the AMP engine classes are only looked up when an engine was passed
        runs_under_amp = engine is not None and isinstance(
            engine,
            (
                dl.AMPEngine,
                dl.DataParallelAMPEngine,
                dl.DistributedDataParallelAMPEngine,
            ),
        )
        if not runs_under_amp:
            callbacks.append(dl.AUCCallback(input_key="logits", target_key="targets"))

        # model training
        runner = CustomRunner()
        active_engine = engine if engine else dl.DeviceEngine(device)
        runner.train(
            engine=active_engine,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=callbacks,
        )
示例#3
0
    }

    item_num = len(train_dataset[0])
    model = MultiVAE([200, 600, item_num], dropout=0.5)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = StepLR(optimizer, step_size=20, gamma=0.1)
    engine = dl.Engine()
    hparams = {
        "anneal_cap": 0.2,
        "total_anneal_steps": 6000,
    }
    callbacks = [
        dl.NDCGCallback("logits", "targets", [20, 50, 100]),
        dl.MAPCallback("logits", "targets", [20, 50, 100]),
        dl.MRRCallback("logits", "targets", [20, 50, 100]),
        dl.HitrateCallback("logits", "targets", [20, 50, 100]),
        dl.BackwardCallback("loss"),
        dl.OptimizerCallback("loss", accumulation_steps=1),
        dl.SchedulerCallback(),
    ]

    runner = RecSysRunner()
    runner.train(
        model=model,
        optimizer=optimizer,
        engine=engine,
        hparams=hparams,
        scheduler=lr_scheduler,
        loaders=loaders,
        num_epochs=100,
        verbose=True,
示例#4
0
def train_experiment(engine=None):
    """Train a linear ranking model for one epoch with a ``SupervisedRunner``.

    Random user features and random binary item targets are generated and a
    single ``torch.nn.Linear`` layer is optimised with ``BCEWithLogitsLoss``.
    A batch transform applies ``torch.sigmoid`` to ``"logits"`` and stores the
    result under ``"scores"``, which the ranking callbacks consume. An AUC
    callback on the logits is added only when ``engine`` is a
    ``dl.CPUEngine``. Checkpoints are written into a ``TemporaryDirectory``.

    Args:
        engine: optional engine instance, passed to ``runner.train`` as given.
    """
    n_users, n_features, n_items = 10_000, 10, 10

    with TemporaryDirectory() as logdir:
        # synthetic inputs: uniform features, coin-flip relevance targets
        features = torch.rand(n_users, n_features)
        targets = (torch.rand(n_users, n_items) > 0.5).to(torch.float32)

        # the same loader is used for both the train and the valid stage
        shared_loader = DataLoader(
            TensorDataset(features, targets), batch_size=32, num_workers=1
        )
        loaders = {split: shared_loader for split in ("train", "valid")}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(n_features, n_items)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2])

        # callbacks are created in exactly the order they are registered
        callbacks = [
            dl.BatchTransformCallback(
                input_key="logits",
                output_key="scores",
                transform=torch.sigmoid,
                scope="on_batch_end",
            ),
            dl.CriterionCallback(
                input_key="logits", target_key="targets", metric_key="loss"
            ),
        ]
        # NDCG is configured with a shorter cutoff list than the other metrics
        ranking_specs = (
            (dl.HitrateCallback, (1, 3, 5)),
            (dl.MRRCallback, (1, 3, 5)),
            (dl.MAPCallback, (1, 3, 5)),
            (dl.NDCGCallback, (1, 3)),
        )
        callbacks.extend(
            ranking_callback(input_key="scores", target_key="targets", topk=cutoffs)
            for ranking_callback, cutoffs in ranking_specs
        )
        callbacks += [
            dl.BackwardCallback(metric_key="loss"),
            dl.OptimizerCallback(metric_key="loss"),
            dl.SchedulerCallback(),
            # checkpoint selection: "map01" on the valid loader, not minimised
            dl.CheckpointCallback(
                logdir=logdir,
                loader_key="valid",
                metric_key="map01",
                minimize=False,
            ),
        ]
        if isinstance(engine, dl.CPUEngine):
            callbacks.append(dl.AUCCallback(input_key="logits", target_key="targets"))

        # model training
        runner = dl.SupervisedRunner(
            input_key="features",
            output_key="logits",
            target_key="targets",
            loss_key="loss",
        )
        train_settings = {
            "engine": engine,
            "model": model,
            "criterion": criterion,
            "optimizer": optimizer,
            "scheduler": scheduler,
            "loaders": loaders,
            "num_epochs": 1,
            "verbose": False,
            "callbacks": callbacks,
        }
        runner.train(**train_settings)