def train_experiment(device, engine=None):
    """Run a one-epoch multilabel training pipeline on random data.

    A single linear layer is trained with ``BCEWithLogitsLoss`` on random
    features and random binary targets. The same loader is used as both the
    "train" and the "valid" loader, and logs are written to a temporary
    directory.

    Args:
        device: handed to ``dl.DeviceEngine`` when ``engine`` is falsy.
        engine: optional engine forwarded to ``runner.train`` as is.
    """
    with TemporaryDirectory() as logdir:
        # random features and random 0./1. targets
        n_rows, n_inputs, n_labels = 10000, 10, 4
        features = torch.rand(n_rows, n_inputs)
        labels = (torch.rand(n_rows, n_labels) > 0.5).to(torch.float32)

        # a single loader shared by both stages
        batches = DataLoader(
            TensorDataset(features, labels), batch_size=32, num_workers=1
        )
        loaders = {stage: batches for stage in ("train", "valid")}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(n_inputs, n_labels)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        runner = dl.SupervisedRunner(
            input_key="features",
            output_key="logits",
            target_key="targets",
            loss_key="loss",
        )

        # logits -> scores via sigmoid; the metric callbacks read "scores"
        to_scores = dl.BatchTransformCallback(
            transform="F.sigmoid",
            scope="on_batch_end",
            input_key="logits",
            output_key="scores",
        )
        accuracy = dl.MultilabelAccuracyCallback(
            input_key="scores", target_key="targets", threshold=0.5
        )
        precision_recall = dl.MultilabelPrecisionRecallF1SupportCallback(
            input_key="scores", target_key="targets", num_classes=n_labels
        )
        callbacks = [to_scores, accuracy, precision_recall]

        # AUC is added only when SETTINGS.amp_required is set and the engine
        # is not one of the AMP engines. The AMP engine classes are looked up
        # only inside this guard, and only when an engine was actually given.
        if SETTINGS.amp_required:
            runs_with_amp = engine is not None and isinstance(
                engine,
                (
                    dl.AMPEngine,
                    dl.DataParallelAMPEngine,
                    dl.DistributedDataParallelAMPEngine,
                ),
            )
            if not runs_with_amp:
                callbacks.append(
                    dl.AUCCallback(input_key="scores", target_key="targets")
                )

        runner.train(
            engine=engine if engine else dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir=logdir,
            num_epochs=1,
            valid_loader="valid",
            valid_metric="accuracy",
            minimize_valid_metric=False,
            verbose=False,
            callbacks=callbacks,
        )
def train_experiment(device):
    """Run a one-epoch multilabel training pipeline on random data.

    A single linear layer is trained with ``BCEWithLogitsLoss`` on random
    features and random binary targets. The same loader is used as both the
    "train" and the "valid" loader, and logs are written to a temporary
    directory.

    Args:
        device: handed to ``dl.DeviceEngine`` to select where training runs.
    """
    with TemporaryDirectory() as logdir:
        # sample data: random features, random 0./1. multilabel targets
        num_samples, num_features, num_classes = int(1e4), int(1e1), 4
        X = torch.rand(num_samples, num_features)
        y = (torch.rand(num_samples, num_classes) > 0.5).to(torch.float32)

        # pytorch loaders (one loader reused for train and valid)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_classes)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # model training
        runner = dl.SupervisedRunner(input_key="features",
                                     output_key="logits",
                                     target_key="targets",
                                     loss_key="loss")
        runner.train(
            engine=dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir=logdir,
            num_epochs=1,
            valid_loader="valid",
            valid_metric="accuracy",
            minimize_valid_metric=False,
            verbose=False,
            callbacks=[
                # AUC depends only on the ranking of the outputs, so raw
                # logits can be fed in directly.
                dl.AUCCallback(input_key="logits", target_key="targets"),
                # The model output is raw logits (no sigmoid is applied in
                # this pipeline), so the p = 0.5 decision boundary is a logit
                # of 0.0; thresholding logits at 0.5 would mean p ~= 0.62.
                dl.MultilabelAccuracyCallback(input_key="logits",
                                              target_key="targets",
                                              threshold=0.0),
            ],
        )