def train_func(config):
    """Train and validate a single linear layer, one dataset per epoch.

    Hyperparameters are read from ``config`` via ``.get``:

    * ``"batch_size"`` (default ``32``) -- batch size passed to ``to_torch``.
    * ``"hidden_size"`` (default ``1``) -- output width of the linear layer.
    * ``"lr"`` (default ``1e-2``) -- SGD learning rate.
    * ``"epochs"`` (default ``3``) -- number of loop iterations.

    Each epoch pulls the next dataset from the "train" and "validation"
    shards, converts both to torch datasets, runs ``train`` and then
    ``validate``, and reports the validation result through ``sgd.report``.

    Returns:
        A list with the value returned by ``validate`` for every epoch, in
        order.
    """
    samples_per_batch = config.get("batch_size", 32)
    out_features = config.get("hidden_size", 1)
    learning_rate = config.get("lr", 1e-2)
    num_epochs = config.get("epochs", 3)

    train_shard = sgd.get_dataset_shard("train")
    validation_shard = sgd.get_dataset_shard("validation")

    # Decide once whether to run on the GPU indexed by the local rank or on
    # the CPU; DistributedDataParallel gets device_ids only in the GPU case.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        rank = sgd.local_rank()
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        ddp_device_ids = [rank]
    else:
        device = torch.device("cpu")
        ddp_device_ids = None

    model = DistributedDataParallel(
        nn.Linear(1, out_features).to(device), device_ids=ddp_device_ids)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    def as_torch(dataset):
        # Both splits use the same schema: feature column "x", label "y".
        return dataset.to_torch(
            label_column="y",
            feature_columns=["x"],
            label_column_dtype=torch.float,
            feature_column_dtypes=[torch.float],
            batch_size=samples_per_batch,
        )

    train_datasets = train_shard.iter_datasets()
    validation_datasets = validation_shard.iter_datasets()

    epoch_results = []
    for _ in range(num_epochs):
        # Advance both pipelines in lockstep: one dataset each per epoch.
        epoch_train_data = next(train_datasets)
        epoch_validation_data = next(validation_datasets)

        torch_train_data = as_torch(epoch_train_data)
        torch_validation_data = as_torch(epoch_validation_data)

        train(torch_train_data, model, loss_fn, optimizer, device)
        epoch_result = validate(torch_validation_data, model, loss_fn, device)

        # The result is unpacked as keyword arguments, so it must be a mapping.
        sgd.report(**epoch_result)
        epoch_results.append(epoch_result)

    return epoch_results
Example #2
0
 def train_actor_failure():
     """Exit the interpreter with status 0 instead of returning a value.

     NOTE(review): judging by the name, this is presumably a fixture that
     simulates a training actor dying -- confirm against the caller.
     """
     from sys import exit as terminate_process

     # sys.exit raises SystemExit, so the return below is never reached.
     terminate_process(0)
     return sgd.local_rank()
Example #3
0
 def train():
     """Return whatever ``sgd.local_rank()`` reports."""
     rank = sgd.local_rank()
     return rank