Example No. 1
    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            transform_spec = TransformSpec(
                self.transform_spec) if self.transform_spec else None
            # In general, make_batch_reader is faster than make_reader for reading the dataset.
            # However, we found out that make_reader performs data transformations much faster than
            # make_batch_reader with parallel worker processes. Therefore, the default reader
            # we choose is make_batch_reader unless there are data transformations.
            if transform_spec:
                reader_factory = make_reader
            else:
                reader_factory = make_batch_reader

            self.train_reader = reader_factory(
                self.train_dir,
                num_epochs=self.num_reader_epochs,
                cur_shard=self.cur_shard,
                shard_count=self.shard_count,
                hdfs_driver=PETASTORM_HDFS_DRIVER,
                schema_fields=self.schema_fields,
                storage_options=self.storage_options,
                # Only shuffle row groups when a shuffle buffer is configured.
                shuffle_row_groups=self.shuffle_size > 0)
            if self.has_val:
                self.val_reader = reader_factory(
                    self.val_dir,
                    num_epochs=self.num_reader_epochs,
                    cur_shard=self.cur_shard,
                    shard_count=self.shard_count,
                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                    schema_fields=self.schema_fields,
                    storage_options=self.storage_options,
                    shuffle_row_groups=False)
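
For context, a LightningDataModule like the one above typically hands the readers created in setup() to a petastorm DataLoader from its dataloader hooks. A minimal sketch of that wiring, assuming a self.batch_size attribute and petastorm.pytorch.DataLoader (the original class may use a different loader wrapper):

    def train_dataloader(self):
        from petastorm.pytorch import DataLoader
        # Wrap the reader created in setup(); a nonzero shuffling queue enables shuffling.
        return DataLoader(self.train_reader,
                          batch_size=self.batch_size,
                          shuffling_queue_capacity=self.shuffle_size)

    def val_dataloader(self):
        from petastorm.pytorch import DataLoader
        # Validation rows are read in order, without shuffling.
        return DataLoader(self.val_reader, batch_size=self.batch_size)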
def test_full_pytorch_example(large_mock_mnist_data, tmpdir):
    # First, generate mock dataset
    dataset_url = 'file://{}'.format(tmpdir)
    mnist_data_to_pycarbon_dataset(tmpdir,
                                   dataset_url,
                                   mnist_data=large_mock_mnist_data,
                                   spark_master='local[1]',
                                   carbon_files_count=1)

    # Next, run a round of training using the PyTorch-adapting DataLoader
    from petastorm.pytorch import DataLoader

    torch.manual_seed(1)
    device = torch.device('cpu')
    model = pytorch_example.Net().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    transform = TransformSpec(pytorch_example._transform_row,
                              removed_fields=['idx'])

    with DataLoader(make_carbon_reader('{}/train'.format(dataset_url),
                                       reader_pool_type='dummy',
                                       num_epochs=1,
                                       transform_spec=transform),
                    batch_size=32) as train_loader:
        pytorch_example.train(model, device, train_loader, 10, optimizer, 1)

    with DataLoader(make_carbon_reader('{}/test'.format(dataset_url),
                                       reader_pool_type='dummy',
                                       num_epochs=1,
                                       transform_spec=transform),
                    batch_size=100) as test_loader:
        pytorch_example.evaluation(model, device, test_loader)
def test_pytorch_dataloader_with_transform_function(synthetic_dataset):
    with DataLoader(make_reader(synthetic_dataset.url,
                                schema_fields=ALL_FIELDS - NULLABLE_FIELDS,
                                reader_pool_type='dummy',
                                transform_spec=TransformSpec(_str_to_int)),
                    collate_fn=_noop_collate) as loader:
        for item in loader:
            assert len(item) == 1
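
The helpers _str_to_int and _noop_collate are defined elsewhere in the test module and are not shown here. Judging from the usage above, _noop_collate simply passes each sampled window through unchanged so that every yielded item is a list of rows; a hypothetical one-liner consistent with that behaviour:

def _noop_collate(batch):
    # Hypothetical stand-in: return the list of rows untouched instead of stacking tensors.
    return batch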
Example No. 4
def test_simple_read(synthetic_dataset, reader_factory):
    with DataLoader(
            reader_factory(
                synthetic_dataset.url,
                schema_fields=BATCHABLE_FIELDS,
                transform_spec=TransformSpec(_sensor_name_to_int))) as loader:
        _check_simple_reader(loader, synthetic_dataset.data,
                             BATCHABLE_FIELDS - {TestSchema.sensor_name})
Example No. 5
    def make_petastorm_reader(model,
                              data_path,
                              dataloader_attr,
                              reader_worker_count,
                              reader_pool_type,
                              should_read=True):
        from petastorm import TransformSpec, make_reader, make_batch_reader
        import horovod.torch as hvd

        is_loader_overridden = False
        if LooseVersion(pl.__version__) >= LooseVersion('1.0.0'):
            from pytorch_lightning.utilities.model_helpers import is_overridden
            is_loader_overridden = is_overridden(dataloader_attr, model)

        if not should_read or is_loader_overridden:
            yield
            return

        transform_spec = TransformSpec(
            transformation) if transformation else None

        # In general, make_batch_reader is faster than make_reader for reading the dataset.
        # However, we found out that make_reader performs data transformations much faster than
        # make_batch_reader with parallel worker processes. Therefore, the default reader
        # we choose is make_batch_reader unless there are data transformations.
        reader_factory_kwargs = dict()
        if transform_spec:
            reader_factory = make_reader
            reader_factory_kwargs['pyarrow_serialize'] = True
        else:
            reader_factory = make_batch_reader

        # Petastorm: read data from the store with the correct shard for this rank.
        # num_epochs=1 reads the dataset once; setting num_epochs=None instead would
        # produce an infinite iterator and let ranks train and validate with an
        # unequal number of samples.
        with reader_factory(data_path,
                            num_epochs=1,
                            cur_shard=hvd.rank(),
                            shard_count=hvd.size(),
                            reader_pool_type=reader_pool_type,
                            workers_count=reader_worker_count,
                            hdfs_driver=PETASTORM_HDFS_DRIVER,
                            schema_fields=schema_fields,
                            transform_spec=transform_spec,
                            **reader_factory_kwargs) as reader:

            def dataloader_fn():
                return dataloader_cls(
                    reader,
                    batch_size=batch_size,
                    shuffling_queue_capacity=calculate_shuffle_buffer_size())

            try:
                setattr(model, dataloader_attr, dataloader_fn)
                yield
            finally:
                setattr(model, dataloader_attr, None)
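
make_petastorm_reader is written as a generator (it yields and restores the dataloader attribute in a finally block), so it is presumably wrapped with contextlib.contextmanager where it is defined. A hypothetical usage sketch with placeholder arguments:

# Assuming @contextlib.contextmanager is applied at the definition site (not shown here),
# and that `model` is a LightningModule and `trainer` a pl.Trainer:
with make_petastorm_reader(model, '/path/to/train_data', 'train_dataloader',
                           reader_worker_count=2, reader_pool_type='thread'):
    # While the context is active, model.train_dataloader returns the petastorm-backed loader.
    trainer.fit(model)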
Example No. 6
def test_transform_remove_field(synthetic_dataset, reader_factory):
    """Make sure we apply transform only after we apply the predicate"""

    with reader_factory(
            synthetic_dataset.url,
            schema_fields=[TestSchema.id, TestSchema.id2],
            transform_spec=TransformSpec(removed_fields=['id2'])) as reader:
        row = next(reader)
        assert 'id2' not in row._fields
        assert 'id' in row._fields
Example No. 7
def test_transform_function_batched(scalar_dataset):
    def double_float64(sample):
        sample['float64'] *= 2
        return sample

    with make_batch_reader(scalar_dataset.url, transform_spec=TransformSpec(double_float64)) as reader:
        actual = next(reader)
        for actual_id, actual_float64 in zip(actual.id, actual.float64):
            original_sample = next(d for d in scalar_dataset.data if d['id'] == actual_id)
            expected_matrix = original_sample['float64'] * 2
            np.testing.assert_equal(expected_matrix, actual_float64)
Example No. 8
def test_transform_function_batched_deleting_column(scalar_dataset):
    def delete_float64(sample):
        del sample['float64']
        return sample

    with make_batch_reader(scalar_dataset.url,
                           transform_spec=TransformSpec(delete_float64,
                                                        removed_fields=[
                                                            'float64'
                                                        ])) as reader:
        actual = next(reader)
        assert 'float64' not in actual._fields
Example No. 9
def test_transform_function_returns_a_new_dict(synthetic_dataset,
                                               reader_factory):
    """"""
    def double_matrix(sample):
        return {'id': -1}

    with reader_factory(synthetic_dataset.url,
                        schema_fields=[TestSchema.id],
                        transform_spec=TransformSpec(double_matrix)) as reader:
        all_samples = list(reader)
        actual_ids = list(map(lambda x: x.id, all_samples))

        np.testing.assert_equal(actual_ids, [-1] * len(synthetic_dataset.data))
Example No. 10
    def _init_petaloader(self):
        def _transform_row(df_batch):
            return df_batch

        transform = TransformSpec(_transform_row, removed_fields=['cat_id', 'store_id', 'state_id'])
        reader = make_batch_reader(self.filename,
                                   schema_fields=['id', 'item_id', 'dept_id', 'cat_id', 'day_id',
                                                  'sales', 'day_date_str', 'month_id', 'date', 'wm_yr_wk',
                                                  'snap_flag', 'sell_price', 'sales_dollars', 'store_id', 'state_id'],
                                   workers_count=1
                                   # ,transform_spec=transform
                                   )
        return PetaDataLoader(reader=reader, batch_size=128, shuffling_queue_capacity=100000)
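
Note that the transform_spec argument is commented out above and the no-op _transform_row never drops the columns named in removed_fields. If the transform were enabled, a pandas-batch version consistent with make_batch_reader (which hands the transform one pandas DataFrame per row-group batch) could look like this sketch:

    def _init_petaloader(self):
        def _transform_row(df_batch):
            # Drop the columns declared in removed_fields from each pandas batch.
            return df_batch.drop(columns=['cat_id', 'store_id', 'state_id'])

        transform = TransformSpec(_transform_row, removed_fields=['cat_id', 'store_id', 'state_id'])
        reader = make_batch_reader(self.filename,
                                   schema_fields=['id', 'item_id', 'dept_id', 'cat_id', 'day_id',
                                                  'sales', 'day_date_str', 'month_id', 'date', 'wm_yr_wk',
                                                  'snap_flag', 'sell_price', 'sales_dollars', 'store_id', 'state_id'],
                                   workers_count=1,
                                   transform_spec=transform)
        return PetaDataLoader(reader=reader, batch_size=128, shuffling_queue_capacity=100000)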
Example No. 11
def test_transform_function(synthetic_dataset, reader_factory):
    """"""

    def double_matrix(sample):
        sample['matrix'] *= 2
        return sample

    with reader_factory(synthetic_dataset.url, schema_fields=[TestSchema.id, TestSchema.matrix],
                        transform_spec=TransformSpec(double_matrix)) as reader:
        actual = next(reader)
        original_sample = next(d for d in synthetic_dataset.data if d['id'] == actual.id)
        expected_matrix = original_sample['matrix'] * 2
        np.testing.assert_equal(expected_matrix, actual.matrix)
Example No. 12
def test_transform_function_with_predicate_batched(scalar_dataset):
    def double_float64(sample):
        assert all(sample['id'] % 2 == 0)
        sample['float64'] *= 2
        return sample

    with make_batch_reader(scalar_dataset.url, transform_spec=TransformSpec(double_float64),
                           predicate=in_lambda(['id'], lambda id: id % 2 == 0)) as reader:
        actual = next(reader)
        for actual_id, actual_float64 in zip(actual.id, actual.float64):
            assert actual_id % 2 == 0
            original_sample = next(d for d in scalar_dataset.data if d['id'] == actual_id)
            expected_matrix = original_sample['float64'] * 2
            np.testing.assert_equal(expected_matrix, actual_float64)
Example No. 13
def test_transform_function_with_predicate(synthetic_dataset, reader_factory):
    """Make sure we apply transform only after we apply the predicate"""

    with reader_factory(
            synthetic_dataset.url,
            schema_fields=[TestSchema.id, TestSchema.id2],
            predicate=in_lambda(['id2'], lambda id2: id2 == 1),
            transform_spec=TransformSpec(removed_fields=['id2'])) as reader:
        rows = list(reader)
        assert 'id2' not in rows[0]._fields
        actual_ids = np.asarray(list(row.id for row in rows))
        assert actual_ids.size > 0
        # In the test data id2 = id % 2, which means we expect only odd ids to remain after
        # we apply lambda id2: id2 == 1 predicate.
        assert np.all(actual_ids % 2 == 1)
Example No. 14
def test_transform_function_new_field(synthetic_dataset, reader_factory):
    """"""

    def double_matrix(sample):
        sample['double_matrix'] = sample['matrix'] * 2
        del sample['matrix']
        return sample

    with reader_factory(synthetic_dataset.url, schema_fields=[TestSchema.id, TestSchema.matrix],
                        transform_spec=TransformSpec(double_matrix,
                                                     [('double_matrix', np.float32, (32, 16, 3), False)],
                                                     ['matrix'])) as reader:
        actual = next(reader)
        original_sample = next(d for d in synthetic_dataset.data if d['id'] == actual.id)
        expected_matrix = original_sample['matrix'] * 2
        np.testing.assert_equal(expected_matrix, actual.double_matrix)
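
The positional arguments of the TransformSpec above correspond to petastorm's (func, edit_fields, removed_fields) parameters; edit_fields is a list of (name, numpy_dtype, shape, is_nullable) tuples describing fields the transform adds or whose type/shape it changes. The same spec written with keyword arguments for clarity:

from petastorm import TransformSpec
import numpy as np

def double_matrix(sample):
    # Add a derived field and delete the one it was computed from.
    sample['double_matrix'] = sample['matrix'] * 2
    del sample['matrix']
    return sample

spec = TransformSpec(func=double_matrix,
                     edit_fields=[('double_matrix', np.float32, (32, 16, 3), False)],
                     removed_fields=['matrix'])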
Example No. 15
def train_net(params):
    trainer = emulator.MAIACTrainer(params)

    # set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    trainer.to(device)
    trainer.load_checkpoint()

    # data transformations
    drop_columns = [
        'year', 'dayofyear', 'hour', 'minute', 'fileahi05', 'fileahi12', 'h',
        'v', 'sample_id'
    ]
    transform = TransformSpec(_transform_row('AHI'),
                              removed_fields=drop_columns)
    while True:
        with DataLoader(make_reader(params['data_url'],
                                    num_epochs=1,
                                    shuffle_row_drop_partitions=5,
                                    transform_spec=transform),
                        batch_size=8) as loader:
            for example in loader:
                #x = example['AHI05'][:,:4].to(device)
                x = example['AHI05'].to(device)
                ahi12 = example['AHI12'].type(torch.FloatTensor)
                mask = (ahi12 != ahi12).type(
                    torch.FloatTensor)  # NaN != NaN, so null values become 1

                mask = mask.to(device)
                ahi12 = ahi12.to(device)
                log = False

                if trainer.global_step % params['log_iter'] == 0:
                    log = True

                loss = trainer.step(x, ahi12, mask, log=log)
                if log:
                    print(
                        f"Step: {trainer.global_step}\tLoss: {loss.item():4.4g}"
                    )

                if trainer.global_step % params['checkpoint_step'] == 1:
                    trainer.save_checkpoint()

                if trainer.global_step >= params['max_iter']:
                    break
Example No. 16
def test_transform_function_new_field(synthetic_dataset):
    def double_matrix(sample):
        sample['double_matrix'] = sample['matrix'] * 2
        del sample['matrix']
        return sample

    with make_reader(synthetic_dataset.url, reader_pool_type='dummy', schema_fields=[TestSchema.id, TestSchema.matrix],
                     transform_spec=TransformSpec(double_matrix,
                                                  [('double_matrix', np.float32, (32, 16, 3), False)],
                                                  ['matrix'])) as reader:
        row_tensors = tf_tensors(reader)
        with _tf_session() as sess:
            actual = sess.run(row_tensors)

        original_sample = next(d for d in synthetic_dataset.data if d['id'] == actual.id)
        expected_matrix = original_sample['matrix'] * 2
        np.testing.assert_equal(expected_matrix, actual.double_matrix)
Example No. 17
def test_transform_function_new_field_batched(scalar_dataset):
    def double_float64(sample):
        sample['new_float64'] = sample['float64'] * 2
        del sample['float64']
        return sample

    with make_batch_reader(scalar_dataset.url, reader_pool_type='dummy',
                           transform_spec=TransformSpec(double_float64,
                                                        [('new_float64', np.float64, (), False)],
                                                        ['float64'])) as reader:
        row_tensors = tf_tensors(reader)
        with _tf_session() as sess:
            actual = sess.run(row_tensors)

        for actual_id, actual_float64 in zip(actual.id, actual.new_float64):
            original_sample = next(d for d in scalar_dataset.data if d['id'] == actual_id)
            expected = original_sample['float64'] * 2
            np.testing.assert_equal(expected, actual_float64)
def test_torch_transform_spec(test_ctx):
    df = test_ctx.spark.range(8)
    conv = make_spark_converter(df)

    from torchvision import transforms
    from petastorm import TransformSpec

    def _transform_row(df_row):
        scale_transform = transforms.Compose([
            transforms.Lambda(lambda x: x * 0.1),
        ])
        return scale_transform(df_row)

    transform = TransformSpec(_transform_row)
    with conv.make_torch_dataloader(transform_spec=transform,
                                    num_epochs=1) as dataloader:
        for batch in dataloader:
            assert min(batch['id']) >= 0 and max(batch['id']) < 1
Example No. 19
def test_transform_function_returns_a_new_dict_with_predicate(
        synthetic_dataset, reader_factory):
    def transform(sample):
        return {'id': sample['id'], 'id2': -1}

    with reader_factory(
            synthetic_dataset.url,
            schema_fields=[TestSchema.id, TestSchema.id2],
            predicate=in_lambda(['id2'], lambda id2: id2 == 1),
            transform_spec=TransformSpec(func=transform)) as reader:
        rows = list(reader)
        actual_ids = np.asarray(list(row.id for row in rows))
        assert actual_ids.size > 0
        # In the test data id2 = id % 2, which means we expect only odd ids to remain after
        # we apply lambda id2: id2 == 1 predicate.
        assert np.all(actual_ids % 2 == 1)

        transformed_ids = np.asarray(list(row.id2 for row in rows))
        assert np.all(transformed_ids == -1)
def make_loaders(params):
    '''
    Build one training DataLoader per dataset listed in the training
        configuration file.
    Args:
        params: training configuration with a 'data' section and 'batch_size'
    Returns:
        dict mapping each dataset name to its petastorm DataLoader
    '''
    loaders = dict()
    datanames = params['data'].keys()
    drop_columns = [
        'year', 'dayofyear', 'hour', 'minute', 'file', 'h', 'v', 'sample_id'
    ]
    for key in datanames:
        url = params['data'][key]['data_url']
        transform = TransformSpec(_transform_row(key),
                                  removed_fields=drop_columns)
        loaders[key] = DataLoader(make_reader(url, transform_spec=transform),
                                  batch_size=params['batch_size'])
    return loaders
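
_transform_row is called with the dataset name and its return value is passed to TransformSpec, so it appears to be a factory that builds a per-dataset row transform. The real helper is not shown in this snippet; a hypothetical skeleton of that pattern:

def _transform_row(key):
    # Hypothetical factory: returns the row-transform function for the dataset named `key`.
    def _transform(row):
        # ...dataset-specific preprocessing for `key` would go here...
        return row
    return _transform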
def test_advanced_params(test_ctx):
    df = test_ctx.spark.range(8)
    conv = make_spark_converter(df)
    batch_size = 2
    with conv.make_torch_dataloader(batch_size=batch_size,
                                    num_epochs=1) as dataloader:
        for batch in dataloader:
            assert batch_size == batch['id'].shape[0]

    from torchvision import transforms
    from petastorm import TransformSpec

    def _transform_row(df_row):
        scale_transform = transforms.Compose([
            transforms.Lambda(lambda x: x * 0.1),
        ])
        return scale_transform(df_row)

    transform = TransformSpec(_transform_row)
    with conv.make_torch_dataloader(transform_spec=transform,
                                    num_epochs=1) as dataloader:
        for batch in dataloader:
            assert min(batch['id']) >= 0 and max(batch['id']) < 1

    with pytest.raises(TypeError, match="unexpected keyword argument 'xyz'"):
        conv.make_torch_dataloader(xyz=1)

    def mock_make_batch_reader(dataset_url,
                               schema_fields=None,
                               reader_pool_type='thread', workers_count=10,
                               shuffle_row_groups=True, shuffle_row_drop_partitions=1,
                               predicate=None,
                               rowgroup_selector=None,
                               num_epochs=1,
                               cur_shard=None, shard_count=None,
                               cache_type='null', cache_location=None, cache_size_limit=None,
                               cache_row_size_estimate=None, cache_extra_settings=None,
                               hdfs_driver='libhdfs3',
                               transform_spec=None):
        return {
            "dataset_url": dataset_url,
            "schema_fields": schema_fields,
            "reader_pool_type": reader_pool_type,
            "workers_count": workers_count,
            "shuffle_row_groups": shuffle_row_groups,
            "shuffle_row_drop_partitions": shuffle_row_drop_partitions,
            "predicate": predicate,
            "rowgroup_selector": rowgroup_selector,
            "num_epochs": num_epochs,
            "cur_shard": cur_shard,
            "shard_count": shard_count,
            "cache_type": cache_type,
            "cache_location": cache_location,
            "cache_size_limit": cache_size_limit,
            "cache_row_size_estimate": cache_row_size_estimate,
            "cache_extra_settings": cache_extra_settings,
            "hdfs_driver": hdfs_driver,
            "transform_spec": transform_spec,
        }

    original_fn = petastorm.make_batch_reader
    petastorm.make_batch_reader = mock_make_batch_reader
    ctm = conv.make_torch_dataloader(schema_fields="schema_1",
                                     reader_pool_type='type_1',
                                     workers_count="count_1",
                                     shuffle_row_groups="row_group_1",
                                     shuffle_row_drop_partitions="drop_1",
                                     predicate="predicate_1",
                                     rowgroup_selector="selector_1",
                                     num_epochs="num_1",
                                     cur_shard="shard_1",
                                     shard_count="total_shard",
                                     cache_type="cache_1",
                                     cache_location="location_1",
                                     cache_size_limit="limit_1",
                                     cache_extra_settings="extra_1",
                                     hdfs_driver="driver_1",
                                     transform_spec="transform_spec_1")
    assert ctm.reader["schema_fields"] == "schema_1"
    assert ctm.reader["reader_pool_type"] == "type_1"
    assert ctm.reader["workers_count"] == "count_1"
    assert ctm.reader["shuffle_row_groups"] == "row_group_1"
    assert ctm.reader["shuffle_row_drop_partitions"] == "drop_1"
    assert ctm.reader["predicate"] == "predicate_1"
    assert ctm.reader["rowgroup_selector"] == "selector_1"
    assert ctm.reader["num_epochs"] == "num_1"
    assert ctm.reader["cur_shard"] == "shard_1"
    assert ctm.reader["shard_count"] == "total_shard"
    assert ctm.reader["cache_type"] == "cache_1"
    assert ctm.reader["cache_location"] == "location_1"
    assert ctm.reader["cache_size_limit"] == "limit_1"
    assert ctm.reader["cache_extra_settings"] == "extra_1"
    assert ctm.reader["hdfs_driver"] == "driver_1"
    assert ctm.reader["transform_spec"] == "transform_spec_1"

    petastorm.make_batch_reader = original_fn
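
The test patches petastorm.make_batch_reader by hand and restores it afterwards; with pytest the same check can be written against the monkeypatch fixture, which undoes the patch automatically even if an assertion fails. A sketch, assuming mock_make_batch_reader is factored out to module level:

def test_make_torch_dataloader_passes_reader_kwargs(test_ctx, monkeypatch):
    conv = make_spark_converter(test_ctx.spark.range(8))
    # monkeypatch restores petastorm.make_batch_reader on teardown.
    monkeypatch.setattr(petastorm, "make_batch_reader", mock_make_batch_reader)
    ctm = conv.make_torch_dataloader(schema_fields="schema_1")
    assert ctm.reader["schema_fields"] == "schema_1"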
Example No. 22
    def train(serialized_model, train_rows, val_rows, avg_row_size):
        from petastorm import TransformSpec, make_reader, make_batch_reader
        import horovod as _horovod
        k = get_keras()
        k.backend.set_floatx(floatx)

        hvd = get_horovod()
        hvd.init()

        pin_gpu(hvd, tf, k)

        # If user specifies any user_shuffle_buffer_size (even 0), we should honor it.
        if user_shuffle_buffer_size is None:
            shuffle_buffer_size = calculate_shuffle_buffer_size(
                hvd, avg_row_size, train_rows / hvd.size())
        else:
            if user_shuffle_buffer_size < 0:
                raise ValueError(
                    "user_shuffle_buffer_size cannot be negative!")
            shuffle_buffer_size = user_shuffle_buffer_size

        # needs to be deserialized in the with scope
        with k.utils.custom_object_scope(custom_objects):
            model = deserialize_keras_model(serialized_model,
                                            lambda x: hvd.load_model(x))

        # Horovod: adjust learning rate based on number of processes.
        scaled_lr = k.backend.get_value(model.optimizer.lr) * hvd.size()
        k.backend.set_value(model.optimizer.lr, scaled_lr)

        # Verbose mode 1 will print a progress bar
        verbose = user_verbose if hvd.rank() == 0 else 0

        if verbose:
            print(
                f"Shared lib path is pointing to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}"
            )

        transform_spec = None
        if transformation:
            transform_spec = TransformSpec(transformation)

        # The initial_lr needs to be set to the scaled learning rate in the checkpointing callbacks.
        for callback in user_callbacks:
            if isinstance(
                    callback, _horovod._keras.callbacks.
                    LearningRateScheduleCallbackImpl):
                callback.initial_lr = scaled_lr

        with remote_store.get_local_output_dir() as run_output_dir:
            callbacks = [
                # Horovod: broadcast initial variable states from rank 0 to all other processes.
                # This is necessary to ensure consistent initialization of all workers when
                # training is started with random weights or restored from a checkpoint.
                hvd.callbacks.BroadcastGlobalVariablesCallback(root_rank=0),

                # Horovod: average metrics among workers at the end of every epoch.
                #
                # Note: This callback must be in the list before the ReduceLROnPlateau,
                # TensorBoard, or other metrics-based callbacks.
                hvd.callbacks.MetricAverageCallback(),
            ]

            callbacks += user_callbacks

            # Horovod: save checkpoints only on the first worker to prevent other workers from
            # corrupting them.
            if hvd.rank() == 0:
                ckpt_file = os.path.join(run_output_dir,
                                         remote_store.checkpoint_filename)
                logs_dir = os.path.join(run_output_dir,
                                        remote_store.logs_subdir)

                # This callback checkpoints the model that ultimately is wrapped and returned after
                # Estimator.fit is called.
                _checkpoint_callback = checkpoint_callback
                if _checkpoint_callback:
                    _checkpoint_callback.filepath = ckpt_file
                else:
                    if is_dbfs and LooseVersion(
                            tf.__version__) < LooseVersion("2.0.0"):
                        # Because DBFS local file APIs do not support the random writes required
                        # by the h5 format, save_weights_only=True is needed for switching
                        # to the TensorFlow SavedModel format.
                        _checkpoint_callback = k.callbacks.ModelCheckpoint(
                            ckpt_file, save_weights_only=True)
                    else:
                        _checkpoint_callback = k.callbacks.ModelCheckpoint(
                            ckpt_file)
                callbacks.append(_checkpoint_callback)

                if remote_store.saving_runs:
                    tb_callback = None
                    for i, c in enumerate(callbacks):
                        if isinstance(c, k.callbacks.TensorBoard):
                            tb_callback = c
                            print(
                                f"Found TensorBoard callback, updating log_dir to {logs_dir}"
                            )
                            tb_callback.log_dir = logs_dir
                            break
                    if tb_callback:
                        # Rather than a possibly arbitrary order, we always place the TensorBoard
                        # callback right before the SyncCallback
                        callbacks.pop(i)
                    callbacks.append(tb_callback
                                     or k.callbacks.TensorBoard(logs_dir))
                    callbacks.append(
                        SyncCallback(run_output_dir, remote_store.sync, k))

            if train_steps_per_epoch is None:
                steps_per_epoch = int(
                    math.ceil(train_rows / batch_size / hvd.size()))
            else:
                steps_per_epoch = train_steps_per_epoch

            if validation_steps_per_epoch is None:
                # math.ceil because if val_rows is smaller than val_batch_size we still get at least
                # one step. float(val_rows) because val_rows/val_batch_size evaluates to zero before
                # math.ceil
                validation_steps = int(math.ceil(float(val_rows) / val_batch_size / hvd.size())) \
                    if should_validate else None
            else:
                validation_steps = validation_steps_per_epoch

            schema_fields = feature_columns + label_columns
            if sample_weight_col:
                schema_fields.append(sample_weight_col)

            if verbose:
                print(
                    f"Training parameters: Epochs: {epochs}, Scaled lr: {scaled_lr}, Shuffle size: {shuffle_buffer_size}\n"
                    f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {steps_per_epoch}\n"
                    f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {validation_steps}\n"
                    f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n"
                )
            # In general, make_batch_reader is faster than make_reader for reading the dataset.
            # However, we found out that make_reader performs data transformations much faster than
            # make_batch_reader with parallel worker processes. Therefore, the default reader
            # we choose is make_batch_reader unless there are data transformations.
            reader_factory_kwargs = dict()
            if transform_spec:
                reader_factory = make_reader
                reader_factory_kwargs['pyarrow_serialize'] = True
                is_batch_reader = False
            else:
                reader_factory = make_batch_reader
                is_batch_reader = True

            with reader_factory(
                    remote_store.train_data_path,
                    num_epochs=1,
                    cur_shard=hvd.rank(),
                    reader_pool_type=reader_pool_type,
                    workers_count=train_reader_worker_count,
                    shard_count=hvd.size(),
                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                    schema_fields=schema_fields,
                    transform_spec=transform_spec,
                    storage_options=storage_options,
                    # Don't shuffle row groups if shuffle_buffer_size is 0 (non-shuffle case).
                    shuffle_row_groups=True
                    if shuffle_buffer_size > 0 else False,
                    **reader_factory_kwargs) as train_reader:
                with reader_factory(remote_store.val_data_path,
                                    num_epochs=1,
                                    cur_shard=hvd.rank(),
                                    reader_pool_type=reader_pool_type,
                                    workers_count=val_reader_worker_count,
                                    shard_count=hvd.size(),
                                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                                    schema_fields=schema_fields,
                                    transform_spec=transform_spec,
                                    storage_options=storage_options,
                                    shuffle_row_groups=False,
                                    **reader_factory_kwargs) \
                    if should_validate else empty_batch_reader() as val_reader:

                    train_data = make_dataset(
                        train_reader,
                        batch_size,
                        shuffle_buffer_size,
                        is_batch_reader,
                        shuffle=True if shuffle_buffer_size > 0 else False,
                        cache=inmemory_cache_all)
                    val_data = make_dataset(val_reader, val_batch_size, shuffle_buffer_size,
                                            is_batch_reader, shuffle=False, cache=inmemory_cache_all) \
                        if val_reader else None

                    history = fit(model, train_data, val_data, steps_per_epoch,
                                  validation_steps, callbacks, verbose)

            # Dataset API usage currently displays a wall of errors upon termination.
            # This global model registration ensures clean termination.
            # Tracked in https://github.com/tensorflow/tensorflow/issues/24570
            globals()['_DATASET_FINALIZATION_HACK'] = model

            if hvd.rank() == 0:
                if is_dbfs:
                    if LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
                        model.load_weights(ckpt_file)
                    else:
                        # needs to be deserialized in the with scope
                        with k.utils.custom_object_scope(custom_objects):
                            model = k.models.load_model(ckpt_file)
                    serialized_model = keras_utils.serialize_model(model)
                else:
                    if LooseVersion(tf.__version__) >= LooseVersion("2.0.0"):
                        with k.utils.custom_object_scope(custom_objects):
                            model = k.models.load_model(ckpt_file)
                        serialized_model = keras_utils.serialize_model(model)
                    else:
                        with open(ckpt_file, 'rb') as f:
                            serialized_model = codec.dumps_base64(f.read())

                return history.history, serialized_model, hvd.size()
Example No. 23
def test_transform_function_batched_auto_deleting_column(scalar_dataset):
    with make_batch_reader(scalar_dataset.url,
                           transform_spec=TransformSpec(
                               removed_fields=['float64'])) as reader:
        actual = next(reader)
        assert 'float64' not in actual._fields
Example No. 24
def test_reader_engine_v2_with_transform_is_not_supported(
        synthetic_dataset, reader_factory):
    with pytest.raises(NotImplementedError):
        make_reader(synthetic_dataset.url,
                    reader_engine='experimental_reader_v2',
                    transform_spec=TransformSpec(lambda x: x))
Example No. 25
    def train(serialized_model, train_rows, val_rows, avg_row_size):
        from petastorm import TransformSpec, make_reader, make_batch_reader

        k = get_keras()
        k.backend.set_floatx(floatx)

        hvd = get_horovod()
        hvd.init()
        pin_gpu(hvd, tf, k)

        if not user_shuffle_buffer_size:
            shuffle_buffer_size = calculate_shuffle_buffer_size(
                hvd, avg_row_size, train_rows / hvd.size())
        else:
            shuffle_buffer_size = user_shuffle_buffer_size

        # needs to be deserialized in the with scope
        with k.utils.custom_object_scope(custom_objects):
            model = deserialize_keras_model(serialized_model,
                                            lambda x: hvd.load_model(x))

        # Horovod: adjust learning rate based on number of processes.
        k.backend.set_value(
            model.optimizer.lr,
            k.backend.get_value(model.optimizer.lr) * hvd.size())

        # Verbose mode 1 will print a progress bar
        verbose = user_verbose if hvd.rank() == 0 else 0

        transform_spec = None
        if transformation:
            transform_spec = TransformSpec(transformation)

        with remote_store.get_local_output_dir() as run_output_dir:
            callbacks = [
                # Horovod: broadcast initial variable states from rank 0 to all other processes.
                # This is necessary to ensure consistent initialization of all workers when
                # training is started with random weights or restored from a checkpoint.
                hvd.callbacks.BroadcastGlobalVariablesCallback(root_rank=0),

                # Horovod: average metrics among workers at the end of every epoch.
                #
                # Note: This callback must be in the list before the ReduceLROnPlateau,
                # TensorBoard, or other metrics-based callbacks.
                hvd.callbacks.MetricAverageCallback(),
            ]
            callbacks += user_callbacks

            # Horovod: save checkpoints only on the first worker to prevent other workers from
            # corrupting them.
            if hvd.rank() == 0:
                ckpt_file = os.path.join(run_output_dir,
                                         remote_store.checkpoint_filename)
                logs_dir = os.path.join(run_output_dir,
                                        remote_store.logs_subdir)

                # This callback checkpoints the model that ultimately is wrapped and returned after
                # Estimator.fit is called.
                _checkpoint_callback = checkpoint_callback
                if _checkpoint_callback:
                    _checkpoint_callback.filepath = ckpt_file
                else:
                    if is_dbfs and LooseVersion(
                            tf.__version__) < LooseVersion("2.0.0"):
                        # Because DBFS local file APIs do not support the random writes required
                        # by the h5 format, save_weights_only=True is needed for switching
                        # to the TensorFlow SavedModel format.
                        _checkpoint_callback = k.callbacks.ModelCheckpoint(
                            ckpt_file, save_weights_only=True)
                    else:
                        _checkpoint_callback = k.callbacks.ModelCheckpoint(
                            ckpt_file)
                callbacks.append(_checkpoint_callback)

                if remote_store.saving_runs:
                    callbacks.append(k.callbacks.TensorBoard(logs_dir))
                    callbacks.append(
                        SyncCallback(run_output_dir, remote_store.sync, k))

            if train_steps_per_epoch is None:
                steps_per_epoch = int(
                    math.ceil(train_rows / batch_size / hvd.size()))
            else:
                steps_per_epoch = train_steps_per_epoch

            if validation_steps_per_epoch is None:
                # math.ceil because if val_rows is smaller than batch_size we still get at least
                # one step. float(val_rows) because val_rows/batch_size evaluates to zero before
                # math.ceil
                validation_steps = int(math.ceil(float(val_rows) / batch_size / hvd.size())) \
                    if should_validate else None
            else:
                validation_steps = validation_steps_per_epoch

            schema_fields = feature_columns + label_columns
            if sample_weight_col:
                schema_fields.append(sample_weight_col)

            # In general, make_batch_reader is faster than make_reader for reading the dataset.
            # However, we found out that make_reader performs data transformations much faster than
            # make_batch_reader with parallel worker processes. Therefore, the default reader
            # we choose is make_batch_reader unless there are data transformations.
            reader_factory_kwargs = dict()
            if transform_spec:
                reader_factory = make_reader
                reader_factory_kwargs['pyarrow_serialize'] = True
                is_batch_reader = False
            else:
                reader_factory = make_batch_reader
                is_batch_reader = True

            # Petastorm: read data from the store with the correct shard for this rank
            # setting num_epochs=None will cause an infinite iterator
            # and enables ranks to perform training and validation with
            # unequal number of samples
            with reader_factory(remote_store.train_data_path,
                                num_epochs=None,
                                cur_shard=hvd.rank(),
                                reader_pool_type='process',
                                workers_count=train_reader_worker_count,
                                shard_count=hvd.size(),
                                hdfs_driver=PETASTORM_HDFS_DRIVER,
                                schema_fields=schema_fields,
                                transform_spec=transform_spec,
                                **reader_factory_kwargs) as train_reader:
                with reader_factory(remote_store.val_data_path,
                                    num_epochs=None,
                                    cur_shard=hvd.rank(),
                                    reader_pool_type='process',
                                    workers_count=val_reader_worker_count,
                                    shard_count=hvd.size(),
                                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                                    schema_fields=schema_fields,
                                    transform_spec=transform_spec,
                                    **reader_factory_kwargs) \
                    if should_validate else empty_batch_reader() as val_reader:

                    train_data = make_dataset(train_reader,
                                              shuffle_buffer_size,
                                              is_batch_reader,
                                              shuffle=True)
                    val_data = make_dataset(val_reader, shuffle_buffer_size,
                                            is_batch_reader, shuffle=False) \
                        if val_reader else None

                    history = fit(model, train_data, val_data, steps_per_epoch,
                                  validation_steps, callbacks, verbose)

            # Dataset API usage currently displays a wall of errors upon termination.
            # This global model registration ensures clean termination.
            # Tracked in https://github.com/tensorflow/tensorflow/issues/24570
            globals()['_DATASET_FINALIZATION_HACK'] = model

            if hvd.rank() == 0:
                if is_dbfs:
                    if LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
                        model.load_weights(ckpt_file)
                    else:
                        # needs to be deserialized in the with scope
                        with k.utils.custom_object_scope(custom_objects):
                            model = k.models.load_model(ckpt_file)
                    serialized_model = keras_utils.serialize_model(model)
                else:
                    with open(ckpt_file, 'rb') as f:
                        serialized_model = codec.dumps_base64(f.read())

                return history.history, serialized_model, hvd.size()
Example No. 26
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Pycarbon MNIST Example')
    default_dataset_url = 'file://{}'.format(DEFAULT_MNIST_DATA_PATH)
    parser.add_argument(
        '--dataset-url',
        type=str,
        default=default_dataset_url,
        metavar='S',
        help='hdfs:// or file:/// URL to the MNIST pycarbon dataset '
        '(default: %s)' % default_dataset_url)
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--all-epochs',
                        action='store_true',
                        default=False,
                        help='train all epochs before testing accuracy/loss')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--carbon-sdk-path',
                        type=str,
                        default=DEFAULT_CARBONSDK_PATH,
                        help='carbon sdk path')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    jnius_config.set_classpath(args.carbon_sdk_path)

    torch.manual_seed(args.seed)

    device = torch.device('cuda' if use_cuda else 'cpu')

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)

    # Configure loop and Reader epoch for illustrative purposes.
    # Typical training usage would use the `all_epochs` approach.
    #
    if args.all_epochs:
        # Run training across all the epochs before testing for accuracy
        loop_epochs = 1
        reader_epochs = args.epochs
    else:
        # Test training accuracy after each epoch
        loop_epochs = args.epochs
        reader_epochs = 1

    transform = TransformSpec(_transform_row, removed_fields=['idx'])

    # Instantiate each pycarbon Reader with a single thread, shuffle enabled, and appropriate epoch setting
    for epoch in range(1, loop_epochs + 1):
        with make_data_loader(make_reader('{}/train'.format(args.dataset_url),
                                          is_batch=False,
                                          num_epochs=reader_epochs,
                                          transform_spec=transform),
                              batch_size=args.batch_size) as train_loader:
            train(model, device, train_loader, args.log_interval, optimizer,
                  epoch)

        with make_data_loader(make_reader('{}/test'.format(args.dataset_url),
                                          is_batch=False,
                                          num_epochs=reader_epochs,
                                          transform_spec=transform),
                              batch_size=args.test_batch_size) as test_loader:
            evaluation(model, device, test_loader)
Example No. 27
    def train(serialized_model, train_rows, val_rows, avg_row_size):
        from petastorm import make_batch_reader, TransformSpec

        k = get_keras()
        k.backend.set_floatx(floatx)

        hvd = get_horovod()
        hvd.init()
        pin_gpu(hvd, tf, k)

        if not user_shuffle_buffer_size:
            shuffle_buffer_size = calculate_shuffle_buffer_size(
                hvd, avg_row_size, train_rows / hvd.size())
        else:
            shuffle_buffer_size = user_shuffle_buffer_size

        # needs to be deserialized in the with scope
        with k.utils.custom_object_scope(custom_objects):
            model = deserialize_keras_model(serialized_model,
                                            lambda x: hvd.load_model(x))

        # Horovod: adjust learning rate based on number of processes.
        k.backend.set_value(
            model.optimizer.lr,
            k.backend.get_value(model.optimizer.lr) * hvd.size())

        # Verbose mode 1 will print a progress bar
        verbose = user_verbose if hvd.rank() == 0 else 0

        transform_spec = None
        if transformation:
            transform_spec = TransformSpec(transformation)

        with remote_store.get_local_output_dir() as run_output_dir:
            callbacks = [
                # Horovod: broadcast initial variable states from rank 0 to all other processes.
                # This is necessary to ensure consistent initialization of all workers when
                # training is started with random weights or restored from a checkpoint.
                hvd.callbacks.BroadcastGlobalVariablesCallback(root_rank=0),

                # Horovod: average metrics among workers at the end of every epoch.
                #
                # Note: This callback must be in the list before the ReduceLROnPlateau,
                # TensorBoard, or other metrics-based callbacks.
                hvd.callbacks.MetricAverageCallback(),
            ]
            callbacks += user_callbacks

            # Horovod: save checkpoints only on the first worker to prevent other workers from
            # corrupting them.
            if hvd.rank() == 0:
                ckpt_file = os.path.join(run_output_dir,
                                         remote_store.checkpoint_filename)
                logs_dir = os.path.join(run_output_dir,
                                        remote_store.logs_subdir)

                callbacks.append(k.callbacks.ModelCheckpoint(ckpt_file))
                if remote_store.saving_runs:
                    callbacks.append(k.callbacks.TensorBoard(logs_dir))
                    callbacks.append(
                        SyncCallback(run_output_dir, remote_store.sync, k))

            if train_steps_per_epoch is None:
                steps_per_epoch = int(
                    math.ceil(train_rows / batch_size / hvd.size()))
            else:
                steps_per_epoch = train_steps_per_epoch

            if validation_steps_per_epoch is None:
                # math.ceil because if val_rows is smaller than batch_size we still get at least
                # one step. float(val_rows) because val_rows/batch_size evaluates to zero before
                # math.ceil
                validation_steps = int(math.ceil(float(val_rows) / batch_size / hvd.size())) \
                    if should_validate else None
            else:
                validation_steps = validation_steps_per_epoch

            schema_fields = feature_columns + label_columns
            if sample_weight_col:
                schema_fields.append(sample_weight_col)

            # Petastorm: read data from the store with the correct shard for this rank
            # setting num_epochs=None will cause an infinite iterator and enables
            # ranks to perform training and validation with unequal number of
            # samples
            with make_batch_reader(
                    remote_store.train_data_path,
                    shuffle_row_groups=True,
                    num_epochs=None,
                    cur_shard=hvd.rank(),
                    shard_count=hvd.size(),
                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                    schema_fields=schema_fields,
                    transform_spec=transform_spec) as train_reader:
                with make_batch_reader(remote_store.val_data_path,
                                       num_epochs=None,
                                       cur_shard=hvd.rank(),
                                       shard_count=hvd.size(),
                                       hdfs_driver=PETASTORM_HDFS_DRIVER,
                                       schema_fields=schema_fields,
                                       transform_spec=transform_spec) \
                        if should_validate else empty_batch_reader() as val_reader:

                    train_data = make_dataset(train_reader,
                                              shuffle_buffer_size,
                                              shuffle=True)
                    val_data = make_dataset(val_reader, shuffle_buffer_size, shuffle=False) \
                        if val_reader else None

                    history = fit(model, train_data, val_data, steps_per_epoch,
                                  validation_steps, callbacks, verbose)

            # Dataset API usage currently displays a wall of errors upon termination.
            # This global model registration ensures clean termination.
            # Tracked in https://github.com/tensorflow/tensorflow/issues/24570
            globals()['_DATASET_FINALIZATION_HACK'] = model

            if hvd.rank() == 0:
                with open(ckpt_file, 'rb') as f:
                    return history.history, codec.dumps_base64(
                        f.read()), hvd.size()
Example No. 28
    def set_data_loader(model,
                        data_path,
                        dataloader_attr,
                        reader_worker_count,
                        reader_pool_type,
                        shuffling_queue_capacity,
                        should_read=True,
                        name="",
                        limit_step_per_epoch=-1):
        from petastorm import TransformSpec, make_reader, make_batch_reader
        import horovod.torch as hvd

        is_loader_overridden = False
        if LooseVersion(pl.__version__) >= LooseVersion('1.0.0'):
            from pytorch_lightning.utilities.model_helpers import is_overridden
            is_loader_overridden = is_overridden(dataloader_attr, model)

        if not should_read or is_loader_overridden:
            print(f"Will not set data loader: {name}.")
            yield
            return

        print(
            f"Setting data loader {name} with limit_step_per_epoch={limit_step_per_epoch}"
        )

        transform_spec = TransformSpec(
            transformation) if transformation else None

        # In general, make_batch_reader is faster than make_reader for reading the dataset.
        # However, we found out that make_reader performs data transformations much faster than
        # make_batch_reader with parallel worker processes. Therefore, the default reader
        # we choose is make_batch_reader unless there are data transformations.
        reader_factory_kwargs = dict()
        if transform_spec:
            reader_factory = make_reader
            reader_factory_kwargs['pyarrow_serialize'] = True
        else:
            reader_factory = make_batch_reader

        # Petastorm: read data from the store with the correct shard for this rank
        # Setting num_epochs=None will cause an infinite iterator
        # and enables ranks to perform training and validation with
        # unequal number of samples
        # `loader_num_epochs` is None by default.
        # This doesn't apply to the in-memory dataloader, which loads the whole reader into memory.
        with reader_factory(
                data_path,
                num_epochs=1 if inmemory_cache_all else loader_num_epochs,
                cur_shard=hvd.rank(),
                shard_count=hvd.size(),
                reader_pool_type=reader_pool_type,
                workers_count=reader_worker_count,
                hdfs_driver=PETASTORM_HDFS_DRIVER,
                schema_fields=schema_fields,
                transform_spec=transform_spec,
                storage_options=storage_options,
                **reader_factory_kwargs) as reader:

            def dataloader_fn():
                kwargs = dict(reader=reader,
                              batch_size=batch_size,
                              name=name,
                              limit_step_per_epoch=limit_step_per_epoch,
                              verbose=verbose)
                if inmemory_cache_all:
                    # Use inmem dataloader
                    kwargs['shuffle'] = shuffling_queue_capacity > 0
                    kwargs['num_epochs'] = epochs
                else:
                    kwargs[
                        'shuffling_queue_capacity'] = shuffling_queue_capacity
                return data_loader_cls(**kwargs)

            try:
                setattr(model, dataloader_attr, dataloader_fn)
                yield
            finally:
                setattr(model, dataloader_attr, None)
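The function above yields once, so in the full source it is presumably wrapped with contextlib.contextmanager and used as a context manager that installs the dataloader factory for the duration of training. A minimal, self-contained sketch of that pattern, with hypothetical names (attach_dataloader, make_train_loader, model, trainer):

# Sketch only: a context manager temporarily attaches a dataloader factory to a model
# attribute and clears it on exit, mirroring the setattr/try/finally logic of
# set_data_loader above.
from contextlib import contextmanager

@contextmanager
def attach_dataloader(model, attr, dataloader_fn):
    try:
        setattr(model, attr, dataloader_fn)
        yield
    finally:
        setattr(model, attr, None)

# Usage (assuming `model` and `trainer` exist):
#   with attach_dataloader(model, 'train_dataloader', make_train_loader):
#       trainer.fit(model)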
Example No. 29
0
    def train(serialized_model, optimizer_cls, model_opt_state_serialized,
              train_rows, val_rows, avg_row_size):
        from petastorm import TransformSpec, make_reader, make_batch_reader
        from petastorm.pytorch import BatchedDataLoader, InMemBatchedDataLoader
        import torch
        import horovod.torch as hvd

        # Deserializing objects
        model_opt_state = torch.load(model_opt_state_serialized)
        model = deserialize(serialized_model)

        if loss_fns_pre_train:
            loss_fns = loss_fns_pre_train
        if loss_constructors:
            local_vars = locals()
            loss_fns = [
                loss_constructor(**local_vars)
                for loss_constructor in loss_constructors
            ]

        # Horovod: initialize library.
        hvd.init()

        if not user_shuffle_buffer_size:
            shuffle_buffer_size = \
                calculate_shuffle_buffer_size(hvd, avg_row_size, train_rows / hvd.size())
        else:
            shuffle_buffer_size = user_shuffle_buffer_size

        cuda_available = torch.cuda.is_available()
        if cuda_available:
            # Horovod: pin GPU to local rank or the assigned GPU from spark.
            torch.cuda.set_device(
                _get_assigned_gpu_or_default(default=hvd.local_rank()))
            # Move model to GPU.
            model.cuda()

        # The optimizer object needs to be re-instantiated. Internally, it uses memory
        # addresses of objects as their identity, so it cannot simply be serialized and then
        # deserialized: the deserialized optimizer references parameters by their old memory
        # addresses, which no longer match the parameters of the reconstructed model, and
        # that causes problems.
        # Learning rate is a required parameter for the SGD optimizer; it will be overridden
        # by load_state_dict.
        optimizer = optimizer_cls(model.parameters(), lr=1)
        optimizer_state = model_opt_state['optimizer']

        if last_checkpoint_state is not None:
            model.load_state_dict(last_checkpoint_state['model'])
            optimizer.load_state_dict(last_checkpoint_state['optimizer'])
        else:
            # Scale the learning rate by the number of Horovod workers.
            for i in range(len(optimizer_state['param_groups'])):
                optimizer_state['param_groups'][i]['lr'] = \
                    optimizer_state['param_groups'][i]['lr'] * hvd.size()

            optimizer.load_state_dict(optimizer_state)

        # Horovod: broadcast parameters & optimizer state.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

        # Materialize optimizer state for any parameter that does not yet have state
        # (assign a zero gradient and take one dummy optimizer step) so that
        # hvd.broadcast_optimizer_state below has state to broadcast; some optimizers
        # create their per-parameter state lazily.
        for group in optimizer.param_groups:
            for p in group['params']:
                if id(p) not in optimizer.state_dict()['state']:
                    p.grad = p.data.new(p.size()).zero_()
        optimizer.step()
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        dist_optimizer_args = dict(optimizer=optimizer,
                                   named_parameters=model.named_parameters())
        if gradient_compression:
            # Pass the compression arg only if it is specified by the user.
            dist_optimizer_args['compression'] = gradient_compression
        # Horovod: wrap optimizer with DistributedOptimizer.
        optimizer = hvd.DistributedOptimizer(**dist_optimizer_args)

        # get_optimizer_with_unscaled_lr (used in save_checkpoint below) takes the current
        # optimizer and constructs a new optimizer with the same state, except with the
        # learning rate scaled back down by the number of Horovod workers.
        # This matters for retraining: the user may retrain the model with a different
        # number of workers, and we need the raw learning rate so it can be re-scaled for
        # the new worker count.
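        # A hedged sketch of that idea (the actual helper is defined elsewhere in the full
        # source): for each param group, unscaled_lr = group['lr'] / hvd.size(); rebuild the
        # optimizer with optimizer_cls(model.parameters(), lr=unscaled_lr) and copy the
        # remaining state over via load_state_dict.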

        transform_spec = None
        if transformation:
            transform_spec = TransformSpec(transformation)

        schema_fields = feature_columns + label_columns
        if sample_weight_col:
            schema_fields.append(sample_weight_col)

        if train_steps_per_epoch is None:
            steps_per_epoch = int(
                math.floor(float(train_rows) / batch_size / hvd.size()))
        else:
            steps_per_epoch = train_steps_per_epoch

        with remote_store.get_local_output_dir() as run_output_dir:
            logs_dir = os.path.join(run_output_dir, remote_store.logs_subdir)
            log_writer = SummaryWriter(logs_dir) if hvd.rank() == 0 else None
            ckpt_file = os.path.join(run_output_dir,
                                     remote_store.checkpoint_filename)

            def save_checkpoint():
                model.cpu()
                optimizer_with_scaled_down_lr = \
                    get_optimizer_with_unscaled_lr(hvd, optimizer, optimizer_cls, model)
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer_with_scaled_down_lr.state_dict(),
                }
                torch.save(state, ckpt_file)
                if cuda_available:
                    model.cuda()

            # In general, make_batch_reader is faster than make_reader for reading the dataset.
            # However, we found out that make_reader performs data transformations much faster than
            # make_batch_reader with parallel worker processes. Therefore, the default reader
            # we choose is make_batch_reader unless there are data transformations.
            reader_factory = None
            reader_factory_kwargs = dict()
            if transform_spec:
                reader_factory = make_reader
                reader_factory_kwargs['pyarrow_serialize'] = True
            else:
                reader_factory = make_batch_reader

            # Petastorm: read data from the store with the correct shard for this rank.
            # Setting num_epochs=None creates an infinite iterator, which lets ranks
            # perform training and validation with unequal numbers of samples.
            with reader_factory(remote_store.train_data_path,
                                num_epochs=None,
                                cur_shard=hvd.rank(),
                                reader_pool_type=reader_pool_type,
                                workers_count=train_reader_worker_count,
                                shard_count=hvd.size(),
                                hdfs_driver=PETASTORM_HDFS_DRIVER,
                                schema_fields=schema_fields,
                                transform_spec=transform_spec,
                                **reader_factory_kwargs) as train_reader:
                with reader_factory(remote_store.val_data_path,
                                    num_epochs=None,
                                    cur_shard=hvd.rank(),
                                    reader_pool_type=reader_pool_type,
                                    workers_count=val_reader_worker_count,
                                    shard_count=hvd.size(),
                                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                                    schema_fields=schema_fields,
                                    transform_spec=transform_spec,
                                    **reader_factory_kwargs) \
                    if should_validate else empty_batch_reader() as val_reader:

                    if inmemory_cache_all:
                        # Petastorm introduced the InMemBatchedDataLoader class in v0.11.0
                        train_loader = InMemBatchedDataLoader(
                            train_reader,
                            batch_size=batch_size,
                            num_epochs=epochs,
                            rows_capacity=steps_per_epoch * batch_size,
                            shuffle=True)
                    else:
                        train_loader = BatchedDataLoader(
                            train_reader,
                            batch_size=batch_size,
                            shuffling_queue_capacity=shuffle_buffer_size)
                    train_loader_iter = iter(train_loader)

                    def prepare_batch(row):
                        inputs = [
                            prepare_np_data(row[col].float(), col, metadata).reshape(shape)
                            for col, shape in zip(feature_columns, input_shapes)
                        ]
                        labels = [
                            prepare_np_data(row[col].float(), col, metadata)
                            for col in label_columns
                        ]

                        sample_weights = row.get(sample_weight_col, None)
                        if sample_weights is not None:
                            sample_weights = sample_weights.float()
                        if cuda_available:
                            inputs = [input.cuda() for input in inputs]
                            labels = [label.cuda() for label in labels]
                            if sample_weights is not None:
                                sample_weights = sample_weights.cuda()
                        return inputs, labels, sample_weights

                    def transform_outputs(outputs, labels):
                        if not isinstance(outputs, (tuple, list)):
                            outputs = [outputs]

                        # reshape labels to match the output shape of the model
                        if hasattr(outputs[0], 'shape'):
                            if label_shapes:
                                labels = [
                                    label.reshape(label_shape)
                                    for label, label_shape in zip(
                                        labels, label_shapes)
                                ]
                            else:
                                # If label_shapes parameter is not provided, reshape the label
                                # columns data to match the shape of the model output
                                labels = [
                                    label.reshape(output.shape)
                                    if output.shape.numel() == label.shape.numel()
                                    else label
                                    for label, output in zip(labels, outputs)
                                ]

                        return outputs, labels

                    def aggregate_metrics(stage, epoch, loss,
                                          metric_value_groups):
                        all_metric_groups_values = get_metric_avgs(
                            metric_value_groups)
                        if remote_store.saving_runs:
                            write_metrics_summary(stage, epoch, loss,
                                                  all_metric_groups_values,
                                                  log_writer)
                        return {
                            loss.name: loss.avg.item(),
                            'all_metrics': all_metric_groups_values
                        }

                    def loss_fn(outputs, labels, sample_weights):
                        loss = calculate_loss(outputs, labels, loss_weights,
                                              loss_fns, sample_weights)
                        return loss

                    def print_metrics(batch_idx, loss, metric_value_groups,
                                      phase):
                        if user_verbose > 0 and hvd.rank() == 0 and \
                                batch_idx % METRIC_PRINT_FREQUENCY == 0:
                            print(
                                "{phase}\tepoch:\t{epoch}\tstep\t{batch_idx}:\t{metrics}"
                                .format(phase=phase,
                                        epoch=epoch,
                                        batch_idx=batch_idx,
                                        metrics=aggregate_metrics(
                                            phase, epoch, loss,
                                            metric_value_groups)))

                    def _train(epoch):
                        model.train()
                        train_loss = metric_cls('loss', hvd)
                        metric_value_groups = construct_metric_value_holders(
                            metric_cls, metric_fn_groups, label_columns, hvd)

                        # iterate on one epoch
                        for batch_idx in range(steps_per_epoch):
                            row = next(train_loader_iter)
                            inputs, labels, sample_weights = prepare_batch(row)
                            outputs, loss = train_minibatch(
                                model, optimizer, transform_outputs, loss_fn,
                                inputs, labels, sample_weights)
                            update_metrics(metric_value_groups, outputs,
                                           labels)
                            train_loss.update(loss)
                            print_metrics(batch_idx, train_loss,
                                          metric_value_groups, 'train')

                        return aggregate_metrics('train', epoch, train_loss,
                                                 metric_value_groups)

                    if should_validate:
                        if validation_steps_per_epoch is None:
                            validation_steps = int(
                                math.ceil(
                                    float(val_rows) / val_batch_size /
                                    hvd.size()))
                        else:
                            validation_steps = validation_steps_per_epoch

                        if inmemory_cache_all:
                            # Petastorm introduced the InMemBatchedDataLoader class in v0.11.0
                            val_loader = InMemBatchedDataLoader(
                                val_reader,
                                batch_size=val_batch_size,
                                num_epochs=epochs,
                                rows_capacity=validation_steps *
                                val_batch_size,
                                shuffle=False)
                        else:
                            val_loader = BatchedDataLoader(
                                val_reader,
                                batch_size=val_batch_size,
                                shuffling_queue_capacity=0)
                        val_loader_iter = iter(val_loader)

                        def _validate(epoch):
                            model.eval()
                            val_loss = metric_cls('loss', hvd)

                            metric_value_groups = construct_metric_value_holders(
                                metric_cls, metric_fn_groups, label_columns,
                                hvd)

                            # iterate on one epoch
                            for batch_idx in range(validation_steps):
                                row = next(val_loader_iter)
                                inputs, labels, sample_weights = prepare_batch(
                                    row)

                                outputs = model(*inputs)
                                outputs, labels = transform_outputs(
                                    outputs, labels)

                                loss = calculate_loss(outputs, labels,
                                                      loss_weights, loss_fns,
                                                      sample_weights)
                                val_loss.update(loss)
                                update_metrics(metric_value_groups, outputs,
                                               labels)
                                print_metrics(batch_idx, val_loss,
                                              metric_value_groups, 'val')
                            return aggregate_metrics('val', epoch, val_loss,
                                                     metric_value_groups)

                    history = []
                    for epoch in range(epochs):
                        epoch_metrics = {
                            'epoch': epoch,
                            'train': _train(epoch)
                        }

                        if should_validate:
                            epoch_metrics['validation'] = _validate(epoch)

                        if user_verbose > 0:
                            pdt_dt = datetime.now(timezone.utc)
                            pdt_time_str = pdt_dt.strftime(
                                "%Y-%b-%d %H:%M:%S UTC")
                            print(pdt_time_str, epoch_metrics)

                        history.append(epoch_metrics)
                        if hvd.rank() == 0:
                            # Save model after every epoch
                            save_checkpoint()
                            if remote_store.saving_runs:
                                remote_store.sync(run_output_dir)

            if hvd.rank() == 0:
                best_checkpoint = torch.load(ckpt_file)
                serialized_checkpoint = io.BytesIO()
                torch.save(best_checkpoint, serialized_checkpoint)
                serialized_checkpoint.seek(0)
                return history, serialized_checkpoint
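Rank 0 returns the per-epoch metric history together with an io.BytesIO holding the torch checkpoint (keys 'model' and 'optimizer', as written by save_checkpoint). A minimal, hypothetical driver-side sketch for restoring it, assuming `result` holds the rank-0 return value and `model`/`optimizer` are already constructed:

import torch

history, serialized_checkpoint = result
checkpoint = torch.load(serialized_checkpoint, map_location='cpu')  # BytesIO is file-like
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])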