def test_full_pytorch_example(large_mock_mnist_data, tmpdir):
    # First, generate mock dataset
    dataset_url = 'file://{}'.format(tmpdir)
    mnist_data_to_petastorm_dataset(tmpdir,
                                    dataset_url,
                                    mnist_data=large_mock_mnist_data,
                                    spark_master='local[1]',
                                    parquet_files_count=1)

    # Next, run a round of training using the PyTorch-adapting DataLoader
    from petastorm.pytorch import DataLoader

    torch.manual_seed(1)
    device = torch.device('cpu')
    model = pytorch_example.Net().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    with DataLoader(Reader('{}/train'.format(dataset_url),
                           reader_pool=DummyPool(),
                           num_epochs=1),
                    batch_size=32,
                    transform=pytorch_example._transform_row) as train_loader:
        pytorch_example.train(model, device, train_loader, 10, optimizer, 1)
    with DataLoader(Reader('{}/test'.format(dataset_url),
                           reader_pool=DummyPool(),
                           num_epochs=1),
                    batch_size=100,
                    transform=pytorch_example._transform_row) as test_loader:
        pytorch_example.test(model, device, test_loader)
def test_full_pytorch_example(large_mock_mnist_data, tmpdir):
    # First, generate mock dataset
    dataset_url = 'file://{}'.format(tmpdir)
    mnist_data_to_pycarbon_dataset(tmpdir,
                                   dataset_url,
                                   mnist_data=large_mock_mnist_data,
                                   spark_master='local[1]',
                                   carbon_files_count=1)

    # Next, run a round of training using the PyTorch-adapting DataLoader
    from petastorm.pytorch import DataLoader

    torch.manual_seed(1)
    device = torch.device('cpu')
    model = pytorch_example.Net().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    transform = TransformSpec(pytorch_example._transform_row,
                              removed_fields=['idx'])

    with DataLoader(make_carbon_reader('{}/train'.format(dataset_url),
                                       reader_pool_type='dummy',
                                       num_epochs=1,
                                       transform_spec=transform),
                    batch_size=32) as train_loader:
        pytorch_example.train(model, device, train_loader, 10, optimizer, 1)

    with DataLoader(make_carbon_reader('{}/test'.format(dataset_url),
                                       reader_pool_type='dummy',
                                       num_epochs=1,
                                       transform_spec=transform),
                    batch_size=100) as test_loader:
        pytorch_example.evaluation(model, device, test_loader)
Example #3
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='Petastorm MNIST Example')
    default_dataset_url = 'file://{}'.format(DEFAULT_MNIST_DATA_PATH)
    parser.add_argument('--dataset-url', type=str,
                        default=default_dataset_url, metavar='S',
                        help='hdfs:// or file:/// URL to the MNIST petastorm dataset '
                             '(default: %s)' % default_dataset_url)
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--all-epochs', action='store_true', default=False,
                        help='train all epochs before testing accuracy/loss')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device('cuda' if use_cuda else 'cpu')

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    # Configure loop and Reader epoch for illustrative purposes.
    # Typical training usage would use the `all_epochs` approach.
    #
    if args.all_epochs:
        # Run training across all the epochs before testing for accuracy
        loop_epochs = 1
        reader_epochs = args.epochs
    else:
        # Test training accuracy after each epoch
        loop_epochs = args.epochs
        reader_epochs = 1

    # Instantiate a fresh petastorm Reader for each loop iteration with the appropriate epoch setting
    for epoch in range(1, loop_epochs + 1):
        with DataLoader(make_reader('{}/train'.format(args.dataset_url), num_epochs=reader_epochs),
                        batch_size=args.batch_size, transform=_transform_row) as train_loader:
            train(model, device, train_loader, args.log_interval, optimizer, epoch)
        with DataLoader(make_reader('{}/test'.format(args.dataset_url), num_epochs=reader_epochs),
                        batch_size=args.test_batch_size, transform=_transform_row) as test_loader:
            test(model, device, test_loader)
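The `_transform_row` helper passed around in these examples is never shown; a minimal sketch in the spirit of the petastorm MNIST example, assuming rows with 'image' and 'digit' fields:

from torchvision import transforms

def _transform_row(mnist_row):
    # Convert the raw image array to a tensor, normalize with the standard
    # MNIST mean/std, and return an (image, label) pair.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    return (transform(mnist_row['image']), mnist_row['digit'])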
def test_pytorch_dataloader_context(synthetic_dataset):
    reader = make_reader(synthetic_dataset.url,
                         schema_fields=PYTORCH_COMPATIBLE_FIELDS,
                         reader_pool_type='dummy')
    with DataLoader(reader, collate_fn=_noop_collate) as loader:
        for item in loader:
            assert len(item) == 1
Example #5
def test_pytorch_dataloader_batched(synthetic_dataset):
    batch_size = 10
    loader = DataLoader(Reader(synthetic_dataset.url, reader_pool=DummyPool()),
                        batch_size=batch_size,
                        collate_fn=_noop_collate)
    for item in loader:
        assert len(item) == batch_size
Example #6
    def train_dataloader(self):
        reader = make_reader(Path(self.data_path).absolute().as_uri(), reader_pool_type='process', workers_count=12,
                             pyarrow_serialize=True, shuffle_row_groups=True, shuffle_row_drop_partitions=2,
                             num_epochs=self.hparams.epoch)
        dataloader = DataLoader(reader, batch_size=16, shuffling_queue_capacity=4096)

        return dataloader
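A LightningModule usually pairs this with a val_dataloader; a minimal sketch, assuming a hypothetical self.val_data_path attribute alongside self.data_path:

    def val_dataloader(self):
        # Single pass, no shuffling: validation should be deterministic.
        reader = make_reader(Path(self.val_data_path).absolute().as_uri(),
                             reader_pool_type='process', workers_count=4,
                             shuffle_row_groups=False, num_epochs=1)
        return DataLoader(reader, batch_size=16)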
def test_pytorch_dataloader_with_transform_function(synthetic_dataset):
    with DataLoader(make_reader(synthetic_dataset.url,
                                schema_fields=ALL_FIELDS - NULLABLE_FIELDS,
                                reader_pool_type='dummy',
                                transform_spec=TransformSpec(_str_to_int)),
                    collate_fn=_noop_collate) as loader:
        for item in loader:
            assert len(item) == 1
def test_simple_read(synthetic_dataset, reader_factory):
    with DataLoader(
            reader_factory(
                synthetic_dataset.url,
                schema_fields=BATCHABLE_FIELDS,
                transform_spec=TransformSpec(_sensor_name_to_int))) as loader:
        _check_simple_reader(loader, synthetic_dataset.data,
                             BATCHABLE_FIELDS - {TestSchema.sensor_name})
Example #9
    def __enter__(self):
        from petastorm.pytorch import DataLoader

        _wait_file_available(self.parquet_file_url_list)
        self.reader = make_batch_reader(self.parquet_file_url_list,
                                        **self.petastorm_reader_kwargs)
        self.loader = DataLoader(reader=self.reader, batch_size=self.batch_size)
        return self.loader
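The matching __exit__ is elided; a plausible sketch that tears the pair down in reverse order (petastorm readers expose stop() and join()):

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Close the loader first, then shut down the reader's worker pool.
        self.loader.__exit__(exc_type, exc_value, exc_traceback)
        self.reader.stop()
        self.reader.join()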
Example #10
def pytorch_hello_world(dataset_url='file:///tmp/carbon_pycarbon_dataset'):
    with DataLoader(make_reader(dataset_url, is_batch=False)) as train_loader:
        sample = next(iter(train_loader))
        print(sample['id'])

    with make_data_loader(make_reader(dataset_url,
                                      is_batch=False)) as train_loader:
        sample = next(iter(train_loader))
        print(sample['id'])
def test_pytorch_dataloader_batched(synthetic_dataset):
    batch_size = 10
    loader = DataLoader(make_reader(synthetic_dataset.url,
                                    schema_fields=PYTORCH_COMPATIBLE_FIELDS,
                                    reader_pool_type='dummy'),
                        batch_size=batch_size,
                        collate_fn=_noop_collate)
    for item in loader:
        assert len(item) == batch_size
Example #12
def get_data_loader(data_path: str = None,
                    num_epochs: int = 1,
                    batch_size: int = 16):
    if not data_path:
        return None

    return DataLoader(make_batch_reader(dataset_url=data_path,
                                        num_epochs=num_epochs),
                      batch_size=batch_size)
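Usage, with a hypothetical Parquet dataset URL:

loader = get_data_loader(data_path='file:///tmp/my_dataset',  # hypothetical URL
                         num_epochs=1,
                         batch_size=32)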
Example #13
def pytorch_hello_world(dataset_url='file:///tmp/carbon_external_dataset'):
    with DataLoader(make_reader(dataset_url)) as train_loader:
        sample = next(iter(train_loader))
        # Because the reader operates in batch mode (the default here), each read returns a batch of rows instead of a single row
        print("id batch: {0}".format(sample['id']))

    with make_data_loader(make_reader(dataset_url)) as train_loader:
        sample = next(iter(train_loader))
        # Because the reader operates in batch mode (the default here), each read returns a batch of rows instead of a single row
        print("id batch: {0}".format(sample['id']))
def test_no_shuffling(synthetic_dataset, reader_factory):
    with DataLoader(
            reader_factory(synthetic_dataset.url,
                           schema_fields=['^id$'],
                           workers_count=1,
                           shuffle_row_groups=False)) as loader:
        ids = [row['id'][0].numpy() for row in loader]
        # expected_ids would be [0, 1, 2, ...]
        expected_ids = [row['id'] for row in synthetic_dataset.data]
        np.testing.assert_array_equal(expected_ids, ids)
def test_with_torch_api(synthetic_dataset):
    """Verify that WeightedSamplingReader is compatible with petastorm.pytorch.DataLoader"""
    # The original elides the reader construction; assume two readers over the
    # same synthetic dataset (any pair of compatible readers would work).
    reader0 = make_reader(synthetic_dataset.url, reader_pool_type='dummy')
    reader1 = make_reader(synthetic_dataset.url, reader_pool_type='dummy')
    readers = [reader0, reader1]

    with WeightedSamplingReader(readers, [0.5, 0.5]) as mixer:
        assert not mixer.batched_output
        sample = next(mixer)
        assert sample is not None
        with DataLoader(mixer, batch_size=2) as loader:
            for batch in loader:
                assert batch['f1'].shape[0] == 2
                break
def test_with_shuffling_buffer(synthetic_dataset, reader_factory):
    with DataLoader(reader_factory(synthetic_dataset.url,
                                   schema_fields=['^id$'],
                                   workers_count=1,
                                   shuffle_row_groups=False),
                    shuffling_queue_capacity=51) as loader:
        ids = [row['id'][0].numpy() for row in loader]

        assert len(ids) == len(
            synthetic_dataset.data
        ), 'All samples should be returned after reshuffling'

        # diff(ids) would be all ones for an unshuffled sequence (note that we used shuffle_row_groups=False).
        # For the sake of the test we assume fewer than 10% of elements remain consecutive after shuffling
        # (the probability of exceeding that is very close to zero).
        assert sum(np.diff(ids) == 1) < len(synthetic_dataset.data) / 10.0
Example #17
def train_net(params):
    trainer = emulator.MAIACTrainer(params)

    # set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    trainer.to(device)
    trainer.load_checkpoint()

    # data transformations
    drop_columns = [
        'year', 'dayofyear', 'hour', 'minute', 'fileahi05', 'fileahi12', 'h',
        'v', 'sample_id'
    ]
    transform = TransformSpec(_transform_row('AHI'),
                              removed_fields=drop_columns)
    # Re-create the one-epoch reader until the step budget is exhausted.
    # (The inner `break` only exits the for-loop, so the stopping condition
    # must also be checked here.)
    while trainer.global_step < params['max_iter']:
        with DataLoader(make_reader(params['data_url'],
                                    num_epochs=1,
                                    shuffle_row_drop_partitions=5,
                                    transform_spec=transform),
                        batch_size=8) as loader:
            for example in loader:
                #x = example['AHI05'][:,:4].to(device)
                x = example['AHI05'].to(device)
                ahi12 = example['AHI12'].type(torch.FloatTensor)
                mask = (ahi12 != ahi12).type(
                    torch.FloatTensor)  # null values = 1

                mask = mask.to(device)
                ahi12 = ahi12.to(device)
                log = False

                if trainer.global_step % params['log_iter'] == 0:
                    log = True

                loss = trainer.step(x, ahi12, mask, log=log)
                if log:
                    print(
                        f"Step: {trainer.global_step}\tLoss: {loss.item():4.4g}"
                    )

                if trainer.global_step % params['checkpoint_step'] == 1:
                    trainer.save_checkpoint()

                if trainer.global_step >= params['max_iter']:
                    break
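The params dict this expects is not shown; a minimal sketch with the keys the training loop reads (values are placeholders, and the MAIACTrainer constructor likely consumes additional keys):

params = {
    'data_url': 'file:///tmp/ahi_dataset',  # hypothetical dataset URL
    'log_iter': 100,           # log every 100 steps
    'checkpoint_step': 1000,   # checkpoint every 1000 steps
    'max_iter': 50000,         # total step budget
}
train_net(params)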
Example #18
    def __init__(self, data_url, batch_size, num_epochs, workers_count,
                 cur_shard, shard_count, **petastorm_reader_kwargs):
        """
        :param data_url: A string specifying the data URL.
        See `SparkDatasetConverter.make_torch_dataloader()` for the definitions
        of the other parameters.
        """
        from petastorm.pytorch import DataLoader

        petastorm_reader_kwargs["num_epochs"] = num_epochs
        if workers_count is not None:
            petastorm_reader_kwargs["workers_count"] = workers_count
        petastorm_reader_kwargs["cur_shard"] = cur_shard
        petastorm_reader_kwargs["shard_count"] = shard_count

        self.reader = petastorm.make_batch_reader(data_url,
                                                  **petastorm_reader_kwargs)
        self.loader = DataLoader(reader=self.reader, batch_size=batch_size)
Example #19
    def get_dataloader(self, dataset: Dataset, identity: str = "Default"):
        batch_preprocessor = self.build_batch_preprocessor()
        reader_options = self.reader_options
        assert reader_options
        data_reader = make_batch_reader(
            # pyre-fixme[16]: `HiveDataSetClass` has no attribute `parquet_url`.
            dataset.parquet_url,
            num_epochs=1,
            reader_pool_type=reader_options.petastorm_reader_pool_type,
        )
        # NOTE: must be wrapped by DataLoaderWrapper to call __exit__() on end of epoch
        dataloader = DataLoader(
            data_reader,
            batch_size=reader_options.minibatch_size,
            collate_fn=collate_and_preprocess(
                batch_preprocessor=batch_preprocessor, use_gpu=False),
        )
        return _closing_iter(dataloader)
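_closing_iter is referenced but not defined here; a plausible sketch that guarantees the underlying reader is shut down when the epoch ends:

def _closing_iter(dataloader):
    # Drain the loader, then invoke its __exit__ so the wrapped petastorm
    # reader is stopped even though no `with` block is used by the caller.
    try:
        yield from dataloader
    finally:
        dataloader.__exit__(None, None, None)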
Example #20
def get_petastorm_dataloader(
    dataset: Dataset,
    batch_size: int,
    batch_preprocessor: BatchPreprocessor,
    use_gpu: bool,
    reader_options: ReaderOptions,
):
    """get petastorm loader for dataset (with preprocessor)"""
    data_reader = make_batch_reader(
        dataset.parquet_url,
        num_epochs=1,
        reader_pool_type=reader_options.petastorm_reader_pool_type,
    )
    return DataLoader(
        data_reader,
        batch_size=batch_size,
        collate_fn=collate_and_preprocess(
            batch_preprocessor=batch_preprocessor, use_gpu=use_gpu),
    )
Example #21
def confusion_matrix(data_path, model, num_class):
    data_path = Path(data_path)
    model.eval()

    cm = np.zeros((num_class, num_class), dtype=np.float64)  # the np.float alias was removed in NumPy 1.24

    dataloader = DataLoader(make_reader(str(data_path.absolute().as_uri()),
                                        reader_pool_type='process',
                                        num_epochs=1),
                            batch_size=4096)
    for batch in dataloader:
        x = batch['feature'].float()
        y = batch['label'].long()
        y_hat = torch.argmax(F.log_softmax(model(x), dim=1), dim=1)

        for i in range(len(y)):
            cm[y[i], y_hat[i]] += 1

    return cm
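Overall accuracy falls out of the returned matrix as the trace over the total count; for example, with a hypothetical test-set path:

cm = confusion_matrix('/data/mnist_test', model, num_class=10)  # hypothetical path
accuracy = np.trace(cm) / cm.sum()
print('accuracy: {:.4f}'.format(accuracy))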
Example #22
def main(data_path: str, model_path: str, gpu: bool):
    if gpu:
        gpu = -1
    else:
        gpu = None

    # initialise logger
    logger = logging.getLogger(__file__)
    logger.addHandler(logging.StreamHandler())
    logger.setLevel('INFO')

    logger.info('Initialise data loader...')
    # get number of cores
    num_cores = psutil.cpu_count(logical=True)
    # load data loader
    reader = make_reader(Path(data_path).absolute().as_uri(),
                         schema_fields=['feature'],
                         reader_pool_type='process',
                         workers_count=num_cores,
                         pyarrow_serialize=True,
                         shuffle_row_groups=True,
                         shuffle_row_drop_partitions=2,
                         num_epochs=1)
    dataloader = DataLoader(reader,
                            batch_size=300,
                            shuffling_queue_capacity=4096)

    logger.info('Initialise model...')
    # init model
    model = VAE()

    logger.info('Start Training...')
    # train
    trainer = Trainer(val_check_interval=100, max_epochs=50, gpus=gpu)
    trainer.fit(model, dataloader)

    logger.info('Persisting...')
    # persist model
    Path(model_path).parent.mkdir(parents=True, exist_ok=True)
    trainer.save_checkpoint(model_path)

    logger.info('Done')
def make_loaders(params):
    '''
    Build one training DataLoader per dataset described in the training
    configuration file.
    Args:
        params: configuration dict with a 'data' mapping and a 'batch_size'
    Returns:
        dict mapping each dataset name to its DataLoader
    '''
    loaders = dict()
    datanames = params['data'].keys()
    drop_columns = [
        'year', 'dayofyear', 'hour', 'minute', 'file', 'h', 'v', 'sample_id'
    ]
    for key in datanames:
        url = params['data'][key]['data_url']
        transform = TransformSpec(_transform_row(key),
                                  removed_fields=drop_columns)
        loaders[key] = DataLoader(make_reader(url, transform_spec=transform),
                                  batch_size=params['batch_size'])
    return loaders
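A sketch of consuming the returned dict, assuming the same params layout the function reads:

loaders = make_loaders(params)
for name, loader in loaders.items():
    batch = next(iter(loader))  # one batch of tensors per configured dataset
    print(name, sorted(batch.keys()))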
def test_with_batch_reader(scalar_dataset, shuffling_queue_capacity):
    """See if we are getting correct batch sizes when using DataLoader with make_batch_reader"""
    pytorch_compatible_fields = [
        k for k, v in scalar_dataset.data[0].items()
        if not isinstance(v, (np.datetime64, np.unicode_))
    ]
    with DataLoader(
            make_batch_reader(scalar_dataset.url,
                              schema_fields=pytorch_compatible_fields),
            batch_size=3,
            shuffling_queue_capacity=shuffling_queue_capacity) as loader:
        batches = list(loader)
        assert len(scalar_dataset.data) == sum(batch['id'].shape[0]
                                               for batch in batches)

        # list types are broken in pyarrow 0.15.0. Don't test list-of-int field
        if pa.__version__ != '0.15.0':
            assert len(scalar_dataset.data) == sum(
                batch['int_fixed_size_list'].shape[0] for batch in batches)
            assert batches[0]['int_fixed_size_list'].shape[1] == len(
                scalar_dataset.data[0]['int_fixed_size_list'])
Example #25
File: timing.py  Project: yariv/petastorm
def main(device):
    print("Testing DataLoader on", device)
    reader = DummyReader()
    for batch_size in [10, 100, 1000, 100000]:
        iterations = 100
        loader = DataLoader(reader,
                            shuffling_queue_capacity=batch_size * 10,
                            batch_size=batch_size,
                            collate_fn=partial(torch.as_tensor, device=device))
        it = iter(loader)

        # Warmup
        for _ in range(iterations):
            next(it)
        print("Done warming up")

        tstart = time.time()
        for _ in range(iterations):
            next(it)
        print("Samples per second for batch {}: {:.4g}".format(
            batch_size, (iterations * batch_size) / (time.time() - tstart)))
Example #26
def get_petastorm_dataloader(
    dataset: Dataset,
    batch_size: int,
    batch_preprocessor: BatchPreprocessor,
    use_gpu: bool,
    reader_options: ReaderOptions,
):
    """ get petastorm loader for dataset (with preprocessor) """
    data_reader = make_batch_reader(
        # pyre-fixme[16]: `HiveDataSetClass` has no attribute `parquet_url`.
        dataset.parquet_url,
        num_epochs=1,
        # pyre-fixme[16]: `ReaderOptions` has no attribute `petastorm_reader_pool_type`.
        reader_pool_type=reader_options.petastorm_reader_pool_type,
    )
    # NOTE: must be wrapped by DataLoaderWrapper to call __exit__() on end of epoch
    return DataLoader(
        data_reader,
        batch_size=batch_size,
        collate_fn=collate_and_preprocess(
            batch_preprocessor=batch_preprocessor, use_gpu=use_gpu),
    )
Example #27
def pytorch_hello_world(dataset_url='file:///tmp/hello_world_dataset'):
    with DataLoader(Reader(dataset_url)) as train_loader:
        sample = next(iter(train_loader))
        print(sample['id'])
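This snippet uses the older direct Reader construction; with the make_reader factory used elsewhere on this page, the equivalent would be:

def pytorch_hello_world(dataset_url='file:///tmp/hello_world_dataset'):
    with DataLoader(make_reader(dataset_url)) as train_loader:
        sample = next(iter(train_loader))
        print(sample['id'])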
Example #28
def test_pytorch_dataloader_context(synthetic_dataset):
    with DataLoader(Reader(synthetic_dataset.url, reader_pool=DummyPool()),
                    collate_fn=_noop_collate) as loader:
        assert len(loader) == len(synthetic_dataset.data)
        for item in loader:
            assert len(item) == 1
Example #29
    def train(serialized_model, optimizer_cls, model_opt_state_serialized,
              train_rows, val_rows, avg_row_size):
        from petastorm import make_batch_reader
        from petastorm.pytorch import DataLoader
        import torch
        import horovod.torch as hvd

        # Deserializing objects
        model_opt_state = torch.load(model_opt_state_serialized)
        model = deserialize(serialized_model)

        if loss_fns_pre_train:
            loss_fns = loss_fns_pre_train
        if loss_constructors:
            local_vars = locals()
            loss_fns = [
                loss_constructor(**local_vars)
                for loss_constructor in loss_constructors
            ]

        # Horovod: initialize library.
        hvd.init()

        if not user_shuffle_buffer_size:
            shuffle_buffer_size = \
                calculate_shuffle_buffer_size(hvd, avg_row_size, train_rows / hvd.size())
        else:
            shuffle_buffer_size = user_shuffle_buffer_size

        cuda_available = torch.cuda.is_available()
        if cuda_available:
            # Horovod: pin GPU to local rank.
            torch.cuda.set_device(hvd.local_rank())
            # Move model to GPU.
            model.cuda()

        # The optimizer object needs to be re-instantiated. Internally, it uses memory
        # addresses of objects as their identity, so it cannot simply be serialized and
        # deserialized: a deserialized optimizer stores parameter names keyed by their
        # old memory addresses, which no longer match the reconstructed objects, and
        # that creates problems.
        # Learning rate is a required parameter of the SGD optimizer. It will be
        # overridden by load_state_dict.
        optimizer = optimizer_cls(model.parameters(), lr=1)
        optimizer_state = model_opt_state['optimizer']

        if last_checkpoint_state is not None:
            model.load_state_dict(last_checkpoint_state['model'])
            optimizer.load_state_dict(last_checkpoint_state['optimizer'])
        else:
            # scale the learning rate with the number of horovod workers
            for i in range(len(optimizer_state['param_groups'])):
                optimizer_state['param_groups'][i]['lr'] = \
                    optimizer_state['param_groups'][i]['lr'] * hvd.size()

            optimizer.load_state_dict(optimizer_state)

        # Horovod: broadcast parameters & optimizer state.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

        for group in optimizer.param_groups:
            for p in group['params']:
                if id(p) not in optimizer.state_dict()['state']:
                    p.grad = p.data.new(p.size()).zero_()
        optimizer.step()
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        dist_optimizer_args = dict(optimizer=optimizer,
                                   named_parameters=model.named_parameters())
        if gradient_compression:
            # Pass the compression arg only if it is specified by the user.
            dist_optimizer_args['compression'] = gradient_compression
        # Horovod: wrap optimizer with DistributedOptimizer.
        optimizer = hvd.DistributedOptimizer(**dist_optimizer_args)

        # get_optimizer_with_unscaled_lr (used in save_checkpoint below) takes the
        # current optimizer and constructs a new one with the same state, except the
        # learning rate is scaled back down by the number of Horovod workers. This
        # matters for retraining: the user may retrain with a different number of
        # workers, so we need the raw learning rate to rescale for the new worker count.

        schema_fields = feature_columns + label_columns
        if sample_weight_col:
            schema_fields.append(sample_weight_col)

        if train_steps_per_epoch is None:
            steps_per_epoch = int(
                math.ceil(float(train_rows) / batch_size / hvd.size()))
        else:
            steps_per_epoch = train_steps_per_epoch

        with remote_store.get_local_output_dir() as run_output_dir:
            logs_dir = os.path.join(run_output_dir, remote_store.logs_subdir)
            log_writer = SummaryWriter(logs_dir) if hvd.rank() == 0 else None
            ckpt_file = os.path.join(run_output_dir,
                                     remote_store.checkpoint_filename)

            def save_checkpoint():
                model.cpu()
                optimizer_with_scaled_down_lr = \
                    get_optimizer_with_unscaled_lr(hvd, optimizer, optimizer_cls, model)
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer_with_scaled_down_lr.state_dict(),
                }
                torch.save(state, ckpt_file)
                if cuda_available:
                    model.cuda()

            # Petastorm: read data from the store with the correct shard for this rank
            # setting num_epochs=None will cause an infinite iterator
            # and enables ranks to perform training and validation with
            # unequal number of samples
            with make_batch_reader(
                    remote_store.train_data_path,
                    num_epochs=None,
                    cur_shard=hvd.rank(),
                    shard_count=hvd.size(),
                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                    schema_fields=schema_fields) as train_reader:
                with make_batch_reader(remote_store.val_data_path,
                                       num_epochs=None,
                                       cur_shard=hvd.rank(),
                                       shard_count=hvd.size(),
                                       hdfs_driver=PETASTORM_HDFS_DRIVER,
                                       schema_fields=schema_fields) \
                        if should_validate else empty_batch_reader() as val_reader:

                    train_loader = DataLoader(
                        train_reader,
                        batch_size=batch_size,
                        shuffling_queue_capacity=shuffle_buffer_size)
                    train_loader_iter = iter(train_loader)

                    def prepare_batch(row):
                        inputs = [
                            prepare_np_data(row[col].float(), col,
                                            metadata).reshape(shape) for col,
                            shape in zip(feature_columns, input_shapes)
                        ]
                        labels = [
                            prepare_np_data(row[col].float(), col, metadata)
                            for col in label_columns
                        ]

                        sample_weights = row.get(sample_weight_col, None)
                        if cuda_available:
                            inputs = [input.cuda() for input in inputs]
                            labels = [label.cuda() for label in labels]
                        return inputs, labels, sample_weights

                    def transform_outputs(outputs, labels):
                        if type(outputs) != tuple and type(outputs) != list:
                            outputs = [outputs]

                        # reshape labels to match the output shape of the model
                        if hasattr(outputs[0], 'shape'):
                            labels = [
                                label.reshape(output.shape)
                                if output.shape.numel() == label.shape.numel()
                                else label
                                for label, output in zip(labels, outputs)
                            ]
                        return outputs, labels

                    def aggregate_metrics(stage, epoch, loss,
                                          metric_value_groups):
                        all_metric_groups_values = get_metric_avgs(
                            metric_value_groups)
                        if remote_store.saving_runs:
                            write_metrics_summary(stage, epoch, loss,
                                                  all_metric_groups_values,
                                                  log_writer)
                        return {
                            loss.name: loss.avg.item(),
                            'all_metrics': all_metric_groups_values
                        }

                    def loss_fn(outputs, labels, sample_weights):
                        loss = calculate_loss(outputs, labels, loss_weights,
                                              loss_fns, sample_weights)
                        return loss

                    def print_metrics(batch_idx, loss, metric_value_groups,
                                      phase):
                        if user_verbose > 0 and hvd.rank() == 0 and \
                                batch_idx % METRIC_PRINT_FREQUENCY == 0:
                            print(
                                "epoch:\t{epoch}\tstep\t{batch_idx}:\t{metrics}"
                                .format(epoch=epoch,
                                        batch_idx=batch_idx,
                                        metrics=aggregate_metrics(
                                            phase, epoch, loss,
                                            metric_value_groups)))

                    def _train(epoch):
                        model.train()
                        train_loss = metric_cls('loss', hvd)
                        metric_value_groups = construct_metric_value_holders(
                            metric_cls, metric_fn_groups, label_columns, hvd)

                        # iterate on one epoch
                        for batch_idx in range(steps_per_epoch):
                            row = next(train_loader_iter)
                            inputs, labels, sample_weights = prepare_batch(row)
                            outputs, loss = train_minibatch(
                                model, optimizer, transform_outputs, loss_fn,
                                inputs, labels, sample_weights)
                            update_metrics(metric_value_groups, outputs,
                                           labels)
                            train_loss.update(loss)
                            print_metrics(batch_idx, train_loss,
                                          metric_value_groups, 'train')

                        return aggregate_metrics('train', epoch, train_loss,
                                                 metric_value_groups)

                    if should_validate:
                        val_loader = DataLoader(val_reader,
                                                batch_size=batch_size)
                        val_loader_iter = iter(val_loader)
                        if validation_steps_per_epoch is None:
                            validation_steps = int(
                                math.ceil(
                                    float(val_rows) / batch_size / hvd.size()))
                        else:
                            validation_steps = validation_steps_per_epoch

                        def _validate(epoch):
                            model.eval()
                            val_loss = metric_cls('loss', hvd)

                            metric_value_groups = construct_metric_value_holders(
                                metric_cls, metric_fn_groups, label_columns,
                                hvd)

                            # iterate on one epoch
                            for batch_idx in range(validation_steps):
                                row = next(val_loader_iter)
                                inputs, labels, sample_weights = prepare_batch(
                                    row)

                                outputs = model(*inputs)
                                outputs, labels = transform_outputs(
                                    outputs, labels)

                                loss = calculate_loss(outputs, labels,
                                                      loss_weights, loss_fns,
                                                      sample_weights)
                                val_loss.update(loss)
                                update_metrics(metric_value_groups, outputs,
                                               labels)
                                print_metrics(batch_idx, val_loss,
                                              metric_value_groups, 'val')
                            return aggregate_metrics('val', epoch, val_loss,
                                                     metric_value_groups)

                    history = []
                    for epoch in range(epochs):
                        epoch_metrics = {
                            'epoch': epoch,
                            'train': _train(epoch)
                        }

                        if should_validate:
                            epoch_metrics['validation'] = _validate(epoch)

                        if user_verbose > 0:
                            print(epoch_metrics)

                        history.append(epoch_metrics)
                        if hvd.rank() == 0:
                            # Save model after every epoch
                            save_checkpoint()
                            if remote_store.saving_runs:
                                remote_store.sync(run_output_dir)

            if hvd.rank() == 0:
                best_checkpoint = torch.load(ckpt_file)
                serialized_checkpoint = io.BytesIO()
                torch.save(best_checkpoint, serialized_checkpoint)
                serialized_checkpoint.seek(0)
                return history, serialized_checkpoint
Example #30
File: train_vae.py  Project: hyh2010/GEE
def main(data_path: str, model_path: str, model_name: str, gpu: bool, vae: bool):
    if gpu:
        gpu = -1
    else:
        gpu = None

    # initialise logger
    logger = logging.getLogger(__file__)
    logger.addHandler(logging.StreamHandler())
    logger.setLevel('INFO')

    logger.info('Initialise data loader...')

    # create tensorboard logdir
    #logdir = pathlib.Path(tempfile.mkdtemp())/"tensorboard_logs"
    #shutil.rmtree(logdir, ignore_errors=True)
    logdir = pathlib.Path(model_path)/"tensorboard_logs"
    logger.info('Tensorboard dir: %s' % logdir) 

    # get number of cores
    #num_cores = psutil.cpu_count(logical=True)
    num_cores = 8
    print(num_cores)
    # load data loader
    reader = make_reader(
        Path(data_path).absolute().as_uri(), schema_fields=['feature'], reader_pool_type='process',
        workers_count=num_cores, pyarrow_serialize=True, shuffle_row_groups=True, shuffle_row_drop_partitions=2,
        num_epochs=1
    )
    dataloader = DataLoader(reader, batch_size=300, shuffling_queue_capacity=4096)

    logger.info('Initialise model...')
    # init model

    if vae:
        model = VAE()
    else:
        model = AE()

    #data = next(iter(dataloader))


    #writer = SummaryWriter(logdir/'vae')
    #writer.add_graph(model, data['feature'])
    #writer.close()

    logger.info('Start Training...')

    tb_logger = pl_loggers.TensorBoardLogger(str(logdir), name=model_name)

    # train
    trainer = Trainer(val_check_interval=100, max_epochs=50, gpus=gpu, logger=tb_logger)

    trainer.fit(model, dataloader)

    logger.info('Persisting...')
    # persist model
    Path(model_path).mkdir(parents=True, exist_ok=True)
    trainer.save_checkpoint(model_path + '/' + model_name + '.model')

    logger.info('Done')
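The original presumably wires this up behind a CLI; called directly, with placeholder arguments:

main(data_path='/data/train.parquet',  # hypothetical paths
     model_path='/models/gee',
     model_name='vae_v1',
     gpu=False,
     vae=True)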