Example No. 1
def train_and_evaluate(lr=0.001, weight_decay=2, batch_size=BATCH_SIZE):
    hvd.init()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = FF_NN(num_features=num_features,
                  num_classes=2,
                  drop_prob=drop_prob,
                  embedding_table_shapes=embeddings,
                  num_continuous=num_continuous,
                  emb_dropout=emb_dropout)
    criterion = torch.nn.CrossEntropyLoss()

    # All model parameters are optimized.
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=0.9,
                                weight_decay=weight_decay)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=7,
                                                       gamma=0.1)

    with BatchedDataLoader(make_batch_reader(dataset_url_or_urls='file:///dbfs/tmp/assembled_t',
                                             num_epochs=None,
                                             transform_spec=None,
                                             shuffle_row_groups=False,
                                             workers_count=8,
                                             cur_shard=hvd.rank(),
                                             shard_count=hvd.size()),
                           batch_size=BATCH_SIZE) as train_dataloader, \
         BatchedDataLoader(make_batch_reader(dataset_url_or_urls='file:///dbfs/tmp/assembled_v',
                                             num_epochs=None,
                                             transform_spec=None,
                                             shuffle_row_groups=False,
                                             workers_count=8,
                                             cur_shard=hvd.rank(),
                                             shard_count=hvd.size()),
                           batch_size=BATCH_SIZE) as val_dataloader:

        train_dataloader_iter = iter(train_dataloader)
        steps_per_epoch = train_df_size // BATCH_SIZE

        val_dataloader_iter = iter(val_dataloader)
        validation_steps = max(1, val_df_size // (BATCH_SIZE))

        for epoch in range(NUM_EPOCHS):
            print('Epoch {}/{}'.format(epoch + 1, NUM_EPOCHS))
            print('-' * 10)

            train_loss, train_acc = train_one_epoch(model, optimizer,
                                                    exp_lr_scheduler,
                                                    train_dataloader_iter,
                                                    steps_per_epoch, epoch,
                                                    device)
            val_loss, val_acc, val_f1 = evaluate(model, val_dataloader_iter,
                                                 validation_steps, device)

    return val_loss
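
Because this function returns the validation loss, it can be plugged directly into a hyperparameter search. A minimal sketch, assuming hyperopt is available; the search space below is purely illustrative and not part of the original notebook:

from hyperopt import fmin, hp, tpe

# Hypothetical search ranges for the two tuned arguments of train_and_evaluate.
search_space = {
    'lr': hp.loguniform('lr', -7, -2),
    'weight_decay': hp.loguniform('weight_decay', -6, 0),
}

best_params = fmin(
    fn=lambda params: train_and_evaluate(**params),  # fmin minimizes the returned val_loss
    space=search_space,
    algo=tpe.suggest,
    max_evals=8,
)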
Example No. 2
    def _iterate(self):
        # Reset the reader if needed.
        if self.reader.last_row_consumed:
            self._print_verbose(f"[{self.name}]: Resetting Petastorm reader for {self.reader.dataset.paths}")
            self.reader.reset()

        # Re-create the data loader for each iteration. This is needed because there may be
        # some left-over data from the last epoch, which can cause petastorm's BatchedDataLoader
        # to fail to start a new iteration. To work around the issue, we have to re-create the
        # data loader each time a new iteration starts.
        data_loader = BatchedDataLoader(
            self.reader,
            batch_size=self.batch_size,
            shuffling_queue_capacity=self.shuffling_queue_capacity,
        )

        num_steps = 0

        self._print_verbose(f"[{self.name}]: Start to generate batch data. limit_step_per_epoch={self.limit_step_per_epoch}")

        for batch in data_loader:
            if num_steps == self.limit_step_per_epoch:
                self._print_verbose(f"[{self.name}]: Reach limit_step_per_epoch. Stop at step {num_steps}.")
                break

            num_steps += 1
            yield batch
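
The same workaround can be written inline: reset the reader once it has been drained and rebuild the BatchedDataLoader at the start of every epoch, rather than reusing one loader across epochs. A minimal sketch, with a hypothetical dataset path:

from petastorm import make_batch_reader
from petastorm.pytorch import BatchedDataLoader

with make_batch_reader('file:///tmp/example_dataset', num_epochs=1) as reader:
    for epoch in range(3):
        if reader.last_row_consumed:
            reader.reset()  # rewind the reader after the previous epoch consumed it
        data_loader = BatchedDataLoader(reader, batch_size=32)
        for batch in data_loader:
            pass  # training step goes here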
Example No. 3
def test_mem_cache_num_epochs_without_mem_cache_error(
        two_columns_non_petastorm_dataset):
    error_string = "num_epochs should not be specified when inmemory_cache_all is not enabled."
    with make_batch_reader(two_columns_non_petastorm_dataset.url,
                           num_epochs=1) as reader:
        with pytest.raises(ValueError, match=error_string):
            BatchedDataLoader(reader, num_epochs=2)
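
For contrast, a minimal sketch of the combination the loader does accept (hypothetical dataset URL): num_epochs on the loader is only meaningful together with inmemory_cache_all, and the underlying reader is then opened for a single epoch.

from petastorm import make_batch_reader
from petastorm.pytorch import BatchedDataLoader

with make_batch_reader('file:///tmp/example_dataset', num_epochs=1) as reader:
    loader = BatchedDataLoader(reader, batch_size=10,
                               inmemory_cache_all=True, num_epochs=2)
    for batch in loader:
        pass  # the cached rows are replayed for two epochs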
Example No. 4
def test_simple_read_batched(synthetic_dataset, reader_factory):
    with BatchedDataLoader(
            reader_factory(
                synthetic_dataset.url,
                schema_fields=TORCH_BATCHABLE_FIELDS,
                transform_spec=TransformSpec(_sensor_name_to_int))) as loader:
        _check_simple_reader(loader, synthetic_dataset.data,
                             TORCH_BATCHABLE_FIELDS - {TestSchema.sensor_name})
Example No. 5
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.reader.num_epochs is not None:
            raise ValueError("Need to set num_epochs as None in reader.")

        self.data_loader = BatchedDataLoader(
            self.reader,
            batch_size=self.batch_size,
            shuffling_queue_capacity=self.shuffling_queue_capacity)
        self.iterator = iter(self.data_loader)
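
A minimal sketch of satisfying that check, with a hypothetical dataset path: opening the reader with num_epochs=None makes it an endless stream, and the caller then bounds each epoch by drawing a fixed number of batches from the iterator.

from petastorm import make_batch_reader
from petastorm.pytorch import BatchedDataLoader

with make_batch_reader('file:///tmp/example_dataset', num_epochs=None) as reader:
    loader = BatchedDataLoader(reader, batch_size=16, shuffling_queue_capacity=0)
    iterator = iter(loader)
    steps_per_epoch = 100  # hypothetical epoch length
    for _ in range(steps_per_epoch):
        batch = next(iterator)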
Example No. 6
def train_and_evaluate(lr=0.016):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = FF_NN(num_features=num_features,
                  num_classes=2,
                  drop_prob=drop_prob,
                  embedding_table_shapes=embeddings,
                  num_continuous=num_continuous,
                  emb_dropout=emb_dropout)
    model.load_state_dict(
        torch.load(
            '/dbfs/ml/horovod_pytorch/take2/PetaFlights/checkpoint-2.pth.tar')
        ['model'])

    # All model parameters are optimized.
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=0.9,
                                weight_decay=1)

    # Decay LR by a factor of 0.1 every 5 epochs
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=5,
                                                       gamma=0.1)

    with BatchedDataLoader(make_batch_reader(
            dataset_url_or_urls='file:///dbfs/ml/tmp/assembled_test',
            num_epochs=None,
            transform_spec=None,
            shuffle_row_groups=False),
                           batch_size=BATCH_SIZE) as val_dataloader:

        val_dataloader_iter = iter(val_dataloader)
        validation_steps = val_df_size // BATCH_SIZE

        for epoch in range(NUM_EPOCHS):
            print('Epoch {}/{}'.format(epoch + 1, NUM_EPOCHS))
            print('-' * 10)

            val_loss, val_acc, val_f1 = evaluate(model, val_dataloader_iter,
                                                 validation_steps, device)

    return val_loss, val_acc, val_f1
Example No. 7
def main(device='cpu', batch=1000, dim=64):
    print("Testing DataLoader on cpu")
    reader = DummyReader(int(batch), int(dim))

    for batch_size in [10, 100, 1000]:
        iterations = 100
        loader = DataLoader(reader, shuffling_queue_capacity=batch_size * 10, batch_size=batch_size)
        it = iter(loader)

        # Warmup
        for _ in range(iterations):
            next(it)
        print("Done warming up")

        tstart = time.time()
        for _ in range(iterations):
            next(it)
        print("Samples per second for batch {}: {:.4g}".format(
            batch_size, (iterations * batch_size) / (time.time() - tstart)))

    print("Testing BatchedDataLoader on", device)
    for batch_size in [10, 100, 1000, 100000]:
        iterations = 100
        loader = BatchedDataLoader(reader, shuffling_queue_capacity=batch_size * 10, batch_size=batch_size,
                                   transform_fn=partial(torch.as_tensor, device=device))
        it = iter(loader)

        # Warmup
        for _ in range(iterations):
            next(it)
        print("Done warming up")

        tstart = time.time()
        for _ in range(iterations):
            next(it)
        print("Samples per second for batch {}: {:.4g}".format(
            batch_size, (iterations * batch_size) / (time.time() - tstart)))
Example No. 8
    def train(serialized_model, optimizer_cls, model_opt_state_serialized,
              train_rows, val_rows, avg_row_size):
        from petastorm import TransformSpec, make_reader, make_batch_reader
        from petastorm.pytorch import BatchedDataLoader, InMemBatchedDataLoader
        import torch
        import horovod.torch as hvd

        # Deserializing objects
        model_opt_state = torch.load(model_opt_state_serialized)
        model = deserialize(serialized_model)

        if loss_fns_pre_train:
            loss_fns = loss_fns_pre_train
        if loss_constructors:
            local_vars = locals()
            loss_fns = [
                loss_constructor(**local_vars)
                for loss_constructor in loss_constructors
            ]

        # Horovod: initialize library.
        hvd.init()

        if not user_shuffle_buffer_size:
            shuffle_buffer_size = \
                calculate_shuffle_buffer_size(hvd, avg_row_size, train_rows / hvd.size())
        else:
            shuffle_buffer_size = user_shuffle_buffer_size

        cuda_available = torch.cuda.is_available()
        if cuda_available:
            # Horovod: pin GPU to local rank or the assigned GPU from spark.
            torch.cuda.set_device(
                _get_assigned_gpu_or_default(default=hvd.local_rank()))
            # Move model to GPU.
            model.cuda()

        # The optimizer object needs to be re-instantiated. Internally, it uses memory
        # addresses of objects as their identity, so it cannot simply be serialized and then
        # deserialized: the deserialized optimizer stores parameter names keyed by their old
        # memory addresses, which no longer match the reconstructed objects, and that causes
        # problems.
        # Learning rate is a required parameter for the SGD optimizer. It will be overridden
        # by load_state_dict.
        optimizer = optimizer_cls(model.parameters(), lr=1)
        optimizer_state = model_opt_state['optimizer']

        if last_checkpoint_state is not None:
            model.load_state_dict(last_checkpoint_state['model'])
            optimizer.load_state_dict(last_checkpoint_state['optimizer'])
        else:
            # scale the learning rate with the number of horovod workers
            for i in range(len(optimizer_state['param_groups'])):
                optimizer_state['param_groups'][i]['lr'] = \
                    optimizer_state['param_groups'][i]['lr'] * hvd.size()

            optimizer.load_state_dict(optimizer_state)

        # Horovod: broadcast parameters & optimizer state.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

        for group in optimizer.param_groups:
            for p in group['params']:
                if id(p) not in optimizer.state_dict()['state']:
                    p.grad = p.data.new(p.size()).zero_()
        optimizer.step()
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        dist_optimizer_args = dict(optimizer=optimizer,
                                   named_parameters=model.named_parameters())
        if gradient_compression:
            # Pass the compression arg only if it is specified by the user.
            dist_optimizer_args['compression'] = gradient_compression
        # Horovod: wrap optimizer with DistributedOptimizer.
        optimizer = hvd.DistributedOptimizer(**dist_optimizer_args)

        # get_optimizer_with_unscaled_lr (used in save_checkpoint below) takes the current
        # optimizer and constructs a new optimizer with the same state, except the learning
        # rate is scaled back down by the number of horovod workers. This is important for
        # retraining the model: the user may retrain with a different number of workers, and
        # we need the raw learning rate to adjust to the new number of workers.

        transform_spec = None
        if transformation:
            transform_spec = TransformSpec(transformation)

        schema_fields = feature_columns + label_columns
        if sample_weight_col:
            schema_fields.append(sample_weight_col)

        if train_steps_per_epoch is None:
            steps_per_epoch = int(
                math.floor(float(train_rows) / batch_size / hvd.size()))
        else:
            steps_per_epoch = train_steps_per_epoch

        with remote_store.get_local_output_dir() as run_output_dir:
            logs_dir = os.path.join(run_output_dir, remote_store.logs_subdir)
            log_writer = SummaryWriter(logs_dir) if hvd.rank() == 0 else None
            ckpt_file = os.path.join(run_output_dir,
                                     remote_store.checkpoint_filename)

            def save_checkpoint():
                model.cpu()
                optimizer_with_scaled_down_lr = \
                    get_optimizer_with_unscaled_lr(hvd, optimizer, optimizer_cls, model)
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer_with_scaled_down_lr.state_dict(),
                }
                torch.save(state, ckpt_file)
                if cuda_available:
                    model.cuda()

            # In general, make_batch_reader is faster than make_reader for reading the dataset.
            # However, we found out that make_reader performs data transformations much faster than
            # make_batch_reader with parallel worker processes. Therefore, the default reader
            # we choose is make_batch_reader unless there are data transformations.
            reader_factory = None
            reader_factory_kwargs = dict()
            if transform_spec:
                reader_factory = make_reader
                reader_factory_kwargs['pyarrow_serialize'] = True
            else:
                reader_factory = make_batch_reader

            # Petastorm: read data from the store with the correct shard for this rank.
            # Setting num_epochs=None creates an infinite iterator, which lets ranks
            # perform training and validation even when they hold unequal numbers of
            # samples.
            with reader_factory(remote_store.train_data_path,
                                num_epochs=None,
                                cur_shard=hvd.rank(),
                                reader_pool_type=reader_pool_type,
                                workers_count=train_reader_worker_count,
                                shard_count=hvd.size(),
                                hdfs_driver=PETASTORM_HDFS_DRIVER,
                                schema_fields=schema_fields,
                                transform_spec=transform_spec,
                                **reader_factory_kwargs) as train_reader:
                with reader_factory(remote_store.val_data_path,
                                    num_epochs=None,
                                    cur_shard=hvd.rank(),
                                    reader_pool_type=reader_pool_type,
                                    workers_count=val_reader_worker_count,
                                    shard_count=hvd.size(),
                                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                                    schema_fields=schema_fields,
                                    transform_spec=transform_spec,
                                    **reader_factory_kwargs) \
                    if should_validate else empty_batch_reader() as val_reader:

                    if inmemory_cache_all:
                        # Petastorm introduced InMemBatchedDataLoader class in v0.11.0
                        train_loader = InMemBatchedDataLoader(
                            train_reader,
                            batch_size=batch_size,
                            num_epochs=epochs,
                            rows_capacity=steps_per_epoch * batch_size,
                            shuffle=True)
                    else:
                        train_loader = BatchedDataLoader(
                            train_reader,
                            batch_size=batch_size,
                            shuffling_queue_capacity=shuffle_buffer_size)
                    train_loader_iter = iter(train_loader)

                    def prepare_batch(row):
                        inputs = [
                            prepare_np_data(row[col].float(), col,
                                            metadata).reshape(shape) for col,
                            shape in zip(feature_columns, input_shapes)
                        ]
                        labels = [
                            prepare_np_data(row[col].float(), col, metadata)
                            for col in label_columns
                        ]

                        sample_weights = row.get(sample_weight_col, None)
                        if sample_weights is not None:
                            sample_weights = sample_weights.float()
                        if cuda_available:
                            inputs = [input.cuda() for input in inputs]
                            labels = [label.cuda() for label in labels]
                            if sample_weights is not None:
                                sample_weights = sample_weights.cuda()
                        return inputs, labels, sample_weights

                    def transform_outputs(outputs, labels):
                        if not isinstance(outputs, tuple) and not isinstance(
                                outputs, list):
                            outputs = [outputs]

                        # reshape labels to match the output shape of the model
                        if hasattr(outputs[0], 'shape'):
                            if label_shapes:
                                labels = [
                                    label.reshape(label_shape)
                                    for label, label_shape in zip(
                                        labels, label_shapes)
                                ]
                            else:
                                # If label_shapes parameter is not provided, reshape the label
                                # columns data to match the shape of the model output
                                labels = [
                                    label.reshape(output.shape)
                                    if output.shape.numel()
                                    == label.shape.numel() else label
                                    for label, output in zip(labels, outputs)
                                ]

                        return outputs, labels

                    def aggregate_metrics(stage, epoch, loss,
                                          metric_value_groups):
                        all_metric_groups_values = get_metric_avgs(
                            metric_value_groups)
                        if remote_store.saving_runs:
                            write_metrics_summary(stage, epoch, loss,
                                                  all_metric_groups_values,
                                                  log_writer)
                        return {
                            loss.name: loss.avg.item(),
                            'all_metrics': all_metric_groups_values
                        }

                    def loss_fn(outputs, labels, sample_weights):
                        loss = calculate_loss(outputs, labels, loss_weights,
                                              loss_fns, sample_weights)
                        return loss

                    def print_metrics(batch_idx, loss, metric_value_groups,
                                      phase):
                        if user_verbose > 0 and hvd.rank() == 0 and \
                                batch_idx % METRIC_PRINT_FREQUENCY == 0:
                            print(
                                "{phase}\tepoch:\t{epoch}\tstep\t{batch_idx}:\t{metrics}"
                                .format(phase=phase,
                                        epoch=epoch,
                                        batch_idx=batch_idx,
                                        metrics=aggregate_metrics(
                                            phase, epoch, loss,
                                            metric_value_groups)))

                    def _train(epoch):
                        model.train()
                        train_loss = metric_cls('loss', hvd)
                        metric_value_groups = construct_metric_value_holders(
                            metric_cls, metric_fn_groups, label_columns, hvd)

                        # iterate on one epoch
                        for batch_idx in range(steps_per_epoch):
                            row = next(train_loader_iter)
                            inputs, labels, sample_weights = prepare_batch(row)
                            outputs, loss = train_minibatch(
                                model, optimizer, transform_outputs, loss_fn,
                                inputs, labels, sample_weights)
                            update_metrics(metric_value_groups, outputs,
                                           labels)
                            train_loss.update(loss)
                            print_metrics(batch_idx, train_loss,
                                          metric_value_groups, 'train')

                        return aggregate_metrics('train', epoch, train_loss,
                                                 metric_value_groups)

                    if should_validate:
                        if validation_steps_per_epoch is None:
                            validation_steps = int(
                                math.ceil(
                                    float(val_rows) / val_batch_size /
                                    hvd.size()))
                        else:
                            validation_steps = validation_steps_per_epoch

                        if inmemory_cache_all:
                            # Petastorm introduced InMemBatchedDataLoader class in v0.11.0
                            val_loader = InMemBatchedDataLoader(
                                val_reader,
                                batch_size=val_batch_size,
                                num_epochs=epochs,
                                rows_capacity=validation_steps *
                                val_batch_size,
                                shuffle=False)
                        else:
                            val_loader = BatchedDataLoader(
                                val_reader,
                                batch_size=val_batch_size,
                                shuffling_queue_capacity=0)
                        val_loader_iter = iter(val_loader)

                        def _validate(epoch):
                            model.eval()
                            val_loss = metric_cls('loss', hvd)

                            metric_value_groups = construct_metric_value_holders(
                                metric_cls, metric_fn_groups, label_columns,
                                hvd)

                            # iterate on one epoch
                            for batch_idx in range(validation_steps):
                                row = next(val_loader_iter)
                                inputs, labels, sample_weights = prepare_batch(
                                    row)

                                outputs = model(*inputs)
                                outputs, labels = transform_outputs(
                                    outputs, labels)

                                loss = calculate_loss(outputs, labels,
                                                      loss_weights, loss_fns,
                                                      sample_weights)
                                val_loss.update(loss)
                                update_metrics(metric_value_groups, outputs,
                                               labels)
                                print_metrics(batch_idx, val_loss,
                                              metric_value_groups, 'val')
                            return aggregate_metrics('val', epoch, val_loss,
                                                     metric_value_groups)

                    history = []
                    for epoch in range(epochs):
                        epoch_metrics = {
                            'epoch': epoch,
                            'train': _train(epoch)
                        }

                        if should_validate:
                            epoch_metrics['validation'] = _validate(epoch)

                        if user_verbose > 0:
                            pdt_dt = datetime.now(timezone.utc)
                            pdt_time_str = pdt_dt.strftime(
                                "%Y-%b-%d %H:%M:%S UTC")
                            print(pdt_time_str, epoch_metrics)

                        history.append(epoch_metrics)
                        if hvd.rank() == 0:
                            # Save model after every epoch
                            save_checkpoint()
                            if remote_store.saving_runs:
                                remote_store.sync(run_output_dir)

            if hvd.rank() == 0:
                best_checkpoint = torch.load(ckpt_file)
                serialized_checkpoint = io.BytesIO()
                torch.save(best_checkpoint, serialized_checkpoint)
                serialized_checkpoint.seek(0)
                return history, serialized_checkpoint
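
The reader_factory branch above switches to make_reader only when the user supplies a row transformation. A minimal sketch of such a transformation, using a hypothetical column name, that would put the code on the make_reader path:

import numpy as np
from petastorm import TransformSpec

def _scale_features(data):
    # Hypothetical column name; rescales the column's values.
    data['feature'] = data['feature'] / np.float32(255.0)
    return data

transformation = _scale_features
transform_spec = TransformSpec(transformation)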
Example No. 9
def train_and_evaluate_hvd(lr=0.016):
    hvd.init()  # Initialize Horovod.

    # Horovod: pin GPU to local rank.
    if torch.cuda.is_available():
        torch.cuda.set_device(hvd.local_rank())
        device = torch.cuda.current_device()
    else:
        device = torch.device("cpu")

    model = FF_NN(num_features=num_features,
                  num_classes=2,
                  drop_prob=drop_prob,
                  embedding_table_shapes=embeddings,
                  num_continuous=num_continuous,
                  emb_dropout=emb_dropout)

    # Effective batch size in synchronous distributed training is scaled by the number of workers.
    # An increase in learning rate compensates for the increased batch size.
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr * hvd.size(),
                                momentum=0.9)

    # Broadcast initial parameters so all workers start with the same parameters.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Wrap the optimizer with Horovod's DistributedOptimizer.
    optimizer_hvd = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())

    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_hvd,
                                                       step_size=5,
                                                       gamma=0.1)

    with BatchedDataLoader(make_batch_reader(dataset_url_or_urls='file:///dbfs/ml/tmp/assembled_t',
                                             num_epochs=None,
                                             cur_shard=hvd.rank(),
                                             shard_count=hvd.size(),
                                             transform_spec=None,
                                             shuffle_row_groups=False,
                                             workers_count=8),
                           batch_size=BATCH_SIZE) as train_dataloader, \
         BatchedDataLoader(make_batch_reader(dataset_url_or_urls='file:///dbfs/ml/tmp/assembled_v',
                                             num_epochs=None,
                                             cur_shard=hvd.rank(),
                                             shard_count=hvd.size(),
                                             transform_spec=None,
                                             shuffle_row_groups=False,
                                             workers_count=8),
                           batch_size=BATCH_SIZE) as val_dataloader:

        train_dataloader_iter = iter(train_dataloader)
        steps_per_epoch = train_df_size // (BATCH_SIZE * hvd.size())

        val_dataloader_iter = iter(val_dataloader)
        validation_steps = val_df_size // (BATCH_SIZE * hvd.size())

        for epoch in range(NUM_EPOCHS):
            print('Epoch {}/{}'.format(epoch + 1, NUM_EPOCHS))
            print('-' * 10)

            train_loss, train_acc = train_one_epoch(model, optimizer_hvd,
                                                    exp_lr_scheduler,
                                                    train_dataloader_iter,
                                                    steps_per_epoch, epoch,
                                                    device)

            # save checkpoint
            if hvd.rank() == 0: save_checkpoint(model, optimizer_hvd, epoch)

            val_loss, val_acc = evaluate(model,
                                         val_dataloader_iter,
                                         validation_steps,
                                         device,
                                         metric_agg_fn=metric_average)

    return val_loss, val_acc
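
On Databricks, a training function like this is typically launched across workers with HorovodRunner. A minimal sketch, assuming sparkdl's HorovodRunner is available; the worker count is hypothetical:

from sparkdl import HorovodRunner

hr = HorovodRunner(np=2)  # hypothetical number of worker processes
val_loss, val_acc = hr.run(train_and_evaluate_hvd, lr=0.016)  # returns rank 0's result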
Example No. 10
# Assumed signature: cat_cols and device are taken to be defined earlier in the notebook.
def _transform_row(batch, cat_cols=cat_cols, device=device,
                   label_cols=['DEP_DEL15']):
    x_cat, x_cont, y = None, None, None
    x_cat = [batch[col].type(torch.LongTensor) for col in cat_cols]
    x_cat = torch.stack(x_cat, 1)
    x_cont = batch['scaledFeatures']
    y = batch['DEP_DEL15']
    return x_cat.to(device), x_cont.to(device), y.to(device)


# COMMAND ----------

# DBTITLE 1,Check Data
train_loader = BatchedDataLoader(make_batch_reader(
    dataset_url_or_urls='file:///dbfs/ml/tmp/assembled_t',
    num_epochs=None,
    transform_spec=None,
    shuffle_row_groups=False,
    workers_count=8),
                                 batch_size=4)
x_cat, x_cont, y = _transform_row(next(iter(train_loader)))
print(x_cat, x_cont.squeeze(1), y)

# COMMAND ----------


# DBTITLE 1,One Epoch Loop
def train_one_epoch(model, optimizer, scheduler, train_dataloader_iter,
                    steps_per_epoch, epoch, device):
    model.train()  # Set model to training mode

    # statistics
Example No. 11
def test_batched_data_loader_with_in_memory_cache(
        two_columns_non_petastorm_dataset, shuffling_queue_capacity,
        reader_factory, num_epochs):
    batch_size = 10
    extra_loader_params = dict(
        inmemory_cache_all=True,
        num_epochs=num_epochs,
        shuffling_queue_capacity=shuffling_queue_capacity)
    extra_reader_params = dict(num_epochs=1)

    with reader_factory(two_columns_non_petastorm_dataset.url,
                        cur_shard=0,
                        shard_count=1,
                        reader_pool_type='thread',
                        workers_count=2,
                        hdfs_driver='libhdfs',
                        schema_fields=['col_0'],
                        **extra_reader_params) as reader:

        loader = BatchedDataLoader(reader,
                                   batch_size=batch_size,
                                   transform_fn=partial(torch.as_tensor,
                                                        device='cpu'),
                                   **extra_loader_params)

        it = iter(loader)
        retrieved_so_far = None
        for idx in range(5):
            batch = next(it)
            if idx == 0:
                first_buffer = loader._shuffling_buffer

            this_batch = batch['col_0'].clone()
            assert list(this_batch.shape)[0] == batch_size

            if retrieved_so_far is None:
                retrieved_so_far = this_batch
            else:
                intersect = set(retrieved_so_far.tolist()).intersection(
                    set(this_batch.tolist()))
                assert not intersect
                retrieved_so_far = torch.cat([retrieved_so_far, this_batch], 0)

        retrieved_in_first_epoch = retrieved_so_far.clone()

        assert len(set(retrieved_so_far.tolist())) == 50

        if num_epochs == 1:
            with pytest.raises(StopIteration):
                next(it)

        if num_epochs in [2, 3, None]:
            for idx in range(5):
                batch = next(it)
                # Assert that a new buffer is created inside the loader
                assert loader._shuffling_buffer != first_buffer
                this_batch = batch['col_0'].clone()
                assert list(this_batch.shape)[0] == batch_size
                intersection = set(this_batch.tolist()).intersection(
                    set(retrieved_in_first_epoch.tolist()))
                assert len(intersection) == batch_size
                retrieved_so_far = torch.cat([retrieved_so_far, this_batch], 0)

        if num_epochs == 2:
            with pytest.raises(StopIteration):
                next(it)

        if num_epochs in [3, None]:
            for idx in range(5):
                batch = next(it)
                this_batch = batch['col_0'].clone()
                assert list(this_batch.shape)[0] == batch_size
                retrieved_so_far = torch.cat([retrieved_so_far, this_batch], 0)

        if num_epochs == 3:
            with pytest.raises(StopIteration):
                next(it)

        if num_epochs is None:
            for idx in range(5):
                batch = next(it)
                this_batch = batch['col_0'].clone()
                assert list(this_batch.shape)[0] == batch_size
                retrieved_so_far = torch.cat([retrieved_so_far, this_batch], 0)
Example No. 12
def test_mem_cache_reader_num_epochs_error(two_columns_non_petastorm_dataset):
    error_string = "reader.num_epochs is currently 2. When cache in memory is "
    with make_batch_reader(two_columns_non_petastorm_dataset.url,
                           num_epochs=2) as reader:
        with pytest.raises(ValueError, match=error_string):
            BatchedDataLoader(reader, inmemory_cache_all=True)