import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# `sgd` is assumed to be Ray SGD v2 (this API later moved to `ray.train`):
import ray.util.sgd.v2 as sgd


def train_func(config):
    batch_size = config.get("batch_size", 32)
    hidden_size = config.get("hidden_size", 1)
    lr = config.get("lr", 1e-2)
    epochs = config.get("epochs", 3)

    train_dataset_pipeline_shard = sgd.get_dataset_shard("train")
    validation_dataset_pipeline_shard = sgd.get_dataset_shard("validation")

    # Pin each worker to its local GPU when one is available.
    device = torch.device(
        f"cuda:{sgd.local_rank()}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device)

    # Wrap the model in DistributedDataParallel so gradients are synchronized
    # across workers; on CPU, `device_ids` must be None.
    model = nn.Linear(1, hidden_size)
    model = model.to(device)
    model = DistributedDataParallel(
        model,
        device_ids=[sgd.local_rank()] if torch.cuda.is_available() else None)

    loss_fn = nn.MSELoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    results = []

    # Each shard is a DatasetPipeline; iter_datasets() yields one Dataset
    # per epoch (one per pipeline window).
    train_dataset_iterator = train_dataset_pipeline_shard.iter_datasets()
    validation_dataset_iterator = (
        validation_dataset_pipeline_shard.iter_datasets())

    for _ in range(epochs):
        train_dataset = next(train_dataset_iterator)
        validation_dataset = next(validation_dataset_iterator)

        train_torch_dataset = train_dataset.to_torch(
            label_column="y",
            feature_columns=["x"],
            label_column_dtype=torch.float,
            feature_column_dtypes=[torch.float],
            batch_size=batch_size,
        )
        validation_torch_dataset = validation_dataset.to_torch(
            label_column="y",
            feature_columns=["x"],
            label_column_dtype=torch.float,
            feature_column_dtypes=[torch.float],
            batch_size=batch_size)

        # `train` and `validate` are helper functions defined elsewhere in
        # the original example; sgd.report() sends metrics back to the driver.
        train(train_torch_dataset, model, loss_fn, optimizer, device)
        result = validate(validation_torch_dataset, model, loss_fn, device)
        sgd.report(**result)
        results.append(result)

    return results
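
For context, this is how a train_func like the one above is typically launched from the driver. The sketch below assumes the Ray SGD v2 Trainer API of the same era (Trainer, and trainer.run accepting a dataset argument) and uses made-up toy data; .repeat() turns each Dataset into a DatasetPipeline so that iter_datasets() yields one window per epoch.

import ray
from ray.util.sgd.v2 import Trainer  # assumed module path for SGD v2

# Hypothetical toy data: y = 2x. repeat() produces a DatasetPipeline.
train_ds = ray.data.from_items(
    [{"x": float(i), "y": 2.0 * i} for i in range(32)]).repeat()
val_ds = ray.data.from_items(
    [{"x": float(i), "y": 2.0 * i} for i in range(8)]).repeat()

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
results = trainer.run(
    train_func,
    config={"batch_size": 32, "epochs": 3},
    dataset={"train": train_ds, "validation": val_ds})
trainer.stop()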
Example #2

import tensorflow as tf

# `sgd` is assumed to be Ray SGD v2; `build_and_compile_model`,
# `prepare_dataset_shard`, and `SGDReportCallback` are helpers defined
# elsewhere in the original example.
import ray.util.sgd.v2 as sgd


def train_func(config):
    batch_size = config.get("batch_size", 64)
    epochs = config.get("epochs", 3)

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_and_compile_model(config)

    # The shard is a DatasetPipeline; iter_datasets() yields one Dataset
    # per epoch.
    dataset_pipeline = sgd.get_dataset_shard()
    dataset_iterator = dataset_pipeline.iter_datasets()

    results = []
    for _ in range(epochs):
        dataset = next(dataset_iterator)
        # prepare_dataset_shard() disables TensorFlow's auto-sharding, since
        # Ray has already split the data across workers. The label spec is
        # (None,): a 1-D batch of scalar labels.
        tf_dataset = prepare_dataset_shard(
            dataset.to_tf(label_column="y",
                          output_signature=(tf.TensorSpec(shape=(None, 1),
                                                          dtype=tf.float32),
                                            tf.TensorSpec(shape=(None,),
                                                          dtype=tf.float32)),
                          batch_size=batch_size))
        history = multi_worker_model.fit(tf_dataset,
                                         callbacks=[SGDReportCallback()])
        results.append(history.history)
    return results
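
build_and_compile_model is defined elsewhere in the original example; the sketch below shows one plausible implementation, assuming a single-feature regression model that matches the (None, 1) input signature used above (layer sizes and optimizer settings are illustrative):

import tensorflow as tf

def build_and_compile_model(config):
    learning_rate = config.get("lr", 1e-3)
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(1,)),
        tf.keras.layers.Dense(10),
        tf.keras.layers.Dense(1),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate),
        loss="mean_squared_error",
        metrics=["mean_squared_error"])
    return model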
Example #3
def get_dataset():
    data_train_all_epochs = []
    data_val_all_epochs = []
    for _ in range(2):
        data_this_epoch_train = []
        train_dataset = sgd.get_dataset_shard("train")
        for batch in train_dataset.iter_batches():
            data_this_epoch_train.extend(batch)
        data_train_all_epochs.append(data_this_epoch_train)

        data_this_epoch_val = []
        val_dataset = sgd.get_dataset_shard("val")
        for batch in val_dataset.iter_batches():
            data_this_epoch_val.extend(batch)
        data_val_all_epochs.append(data_this_epoch_val)

    return data_train_all_epochs, data_val_all_epochs
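
The shard names above ("train", "val") match the keys of the dataset dict handed to the trainer. A minimal driver-side sketch, under the same Trainer assumptions as in the first example (trainer.run returns one value per worker):

import ray
from ray.util.sgd.v2 import Trainer  # assumed module path for SGD v2

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
per_worker_data = trainer.run(
    get_dataset,
    dataset={"train": ray.data.range(8), "val": ray.data.range(8)})
trainer.stop()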
Example #4
def get_dataset():
    data_all_epochs = []
    for _ in range(2):
        data_this_epoch = []
        dataset = sgd.get_dataset_shard()
        for batch in dataset.iter_batches():
            data_this_epoch.extend(batch)
        data_all_epochs.append(data_this_epoch)
    return data_all_epochs
Example #5
def get_dataset():
    # `num_epochs` is assumed to be captured from the enclosing scope.
    pipeline_iterator = sgd.get_dataset_shard().iter_datasets()
    data_all_epochs = []
    for _ in range(num_epochs):
        dataset_this_epoch = next(pipeline_iterator)
        data_this_epoch = []
        for batch in dataset_this_epoch.iter_batches():
            data_this_epoch.extend(batch)
        data_all_epochs.append(data_this_epoch)
    return data_all_epochs
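
iter_datasets() is only available on a DatasetPipeline, so the dataset passed to the trainer here must already be pipelined. A minimal sketch of a matching input (the element count is an assumption):

import ray

num_epochs = 2
pipeline = ray.data.range(16).repeat(num_epochs)  # one window per epoch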
Example #6
def get_dataset():
    pipeline_iterator = sgd.get_dataset_shard().iter_datasets()
    data_all_epochs = []
    for _ in range(2):
        dataset_this_epoch = next(pipeline_iterator)
        data_this_epoch = []
        for batch in dataset_this_epoch.iter_batches():
            data_this_epoch.extend(batch)

        if len(data_all_epochs) > 0:
            # Make sure data is shuffled per epoch.
            assert data_this_epoch != data_all_epochs[-1]

        data_all_epochs.append(data_this_epoch)
    return data_all_epochs
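
The assertion in Example #6 only holds if each epoch's window is reshuffled. A sketch of a pipeline that satisfies it, assuming a Ray version where DatasetPipeline.random_shuffle_each_window() is available (the element count is an assumption):

import ray

pipeline = ray.data.range(16).repeat(2).random_shuffle_each_window()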