Esempio n. 1
0
def numpy_episodes(train_dir,
                   test_dir,
                   shape,
                   reader=None,
                   loader=None,
                   num_chunks=None,
                   preprocess_fn=None):
    """Build train and test TensorFlow datasets from episodes stored as NPZ files.

    Episodes produced by `loader` are cut into chunks, batched, optionally
    preprocessed and prefetched.

    Args:
      train_dir: Directory containing NPZ files of the training dataset.
      test_dir: Directory containing NPZ files of the testing dataset.
      shape: Tuple of batch size and chunk length for the datasets.
      reader: Callable that reads an episode from a NPZ filename. Replaced by
        `episode_reader` when falsy.
      loader: Generator function that yields episodes; it is invoked as
        `loader(reader, directory, shape[0])`. Replaced by `cache_loader`
        when falsy.
      num_chunks: Forwarded unchanged as the last argument of
        `chunk_sequence.chunk_sequence`.
      preprocess_fn: Optional callable applied to the `'image'` entry of each
        batch; skipped when falsy.

    Returns:
      AttrDict with `train` and `test` entries, each a batched and prefetched
      dataset of sequence chunks.
    """
    if not reader:
        reader = episode_reader
    if not loader:
        loader = cache_loader

    try:
        dtypes, shapes = _read_spec(reader, train_dir)
    except ZeroDivisionError:
        # The spec could not be derived from the training directory
        # (presumably because it holds no episodes yet -- confirm against
        # `_read_spec`), so fall back to the test directory.
        dtypes, shapes = _read_spec(reader, test_dir)

    batch_size = shape[0]

    def make_source(directory):
        # Dataset of whole episodes generated by the loader for `directory`.
        generator = functools.partial(loader, reader, directory, batch_size)
        return tf.data.Dataset.from_generator(generator, dtypes, shapes)

    def split_into_chunks(episode):
        # One episode becomes a dataset holding one element per chunk.
        chunks = chunk_sequence.chunk_sequence(
            episode, shape[1], True, num_chunks)
        return tf.data.Dataset.from_tensor_slices(chunks)

    def apply_preprocessing(batch):
        if preprocess_fn:
            batch['image'] = preprocess_fn(batch['image'])
        return batch

    def finalize(episodes):
        # Chunk, batch (dropping incomplete batches), preprocess, prefetch.
        chunked = episodes.flat_map(split_into_chunks)
        batched = chunked.batch(batch_size, drop_remainder=True)
        processed = batched.map(apply_preprocessing, 10)
        return processed.prefetch(10)

    # Both generator datasets are created before either pipeline is extended.
    train_source = make_source(train_dir)
    test_source = make_source(test_dir)
    train = finalize(train_source)
    test = finalize(test_source)
    return attr_dict.AttrDict(train=train, test=test)
Esempio n. 2
0
def numpy_episodes(train_dir,
                   test_dir,
                   shape,
                   reader=None,
                   loader=None,
                   num_chunks=None,
                   preprocess_fn=None,
                   aug_fn=None,
                   simclr=False):
    """Read sequences stored as compressed Numpy files as a TensorFlow dataset.

    Episodes produced by `loader` are cut into chunks, optionally augmented
    per chunk, batched, optionally preprocessed and prefetched.

    Args:
      train_dir: Directory containing NPZ files of the training dataset.
      test_dir: Directory containing NPZ files of the testing dataset.
      shape: Tuple of batch size and chunk length for the datasets.
      reader: Callable that reads an episode from a NPZ filename. Defaults to
        `episode_reader` when falsy.
      loader: Generator function that yields episodes; it is invoked as
        `loader(reader, directory, shape[0])`. Defaults to `cache_loader`
        when falsy.
      num_chunks: Forwarded unchanged as the last argument of
        `chunk_sequence.chunk_sequence`.
      preprocess_fn: Optional callable applied to each batch. It must support
        `preprocess_fn(image, return_noise=True)` returning an
        `(image, noise)` pair, and `preprocess_fn(image, noise=noise)`
        returning an image; the latter is used for the `'ori_img'` entry
        when the batch contains one.
      aug_fn: Optional callable applied to every chunk before batching,
        invoked as `aug_fn(chunk, phase='train')` or
        `aug_fn(chunk, phase='test')`.
      simclr: If true, batches are built with `shape[0] // 2` elements and the
        `'image'`, `'action'`, `'reward'` and `'aug'` entries are reshaped so
        their leading dimensions are `(shape[0], shape[1])`, while `'length'`
        is flattened. The batch must therefore contain an `'aug'` entry.

    Returns:
      Structured data from numpy episodes as Tensors: an AttrDict with `train`
      and `test` datasets.
    """
    reader = reader or episode_reader
    loader = loader or cache_loader
    try:
        dtypes, shapes = _read_spec(reader, train_dir)
    except ZeroDivisionError:
        # The spec could not be derived from the training directory
        # (presumably because it holds no episodes yet -- confirm against
        # `_read_spec`), so fall back to the test directory.
        dtypes, shapes = _read_spec(reader, test_dir)

    batch_size, chunk_length = shape[0], shape[1]
    # NOTE(review): the halved batch presumably compensates for `aug_fn`
    # emitting two views per chunk, which the reshape below then unfolds;
    # an odd `shape[0]` would likely not reshape cleanly -- confirm.
    elements_per_batch = batch_size // 2 if simclr else batch_size

    train = tf.data.Dataset.from_generator(
        functools.partial(loader, reader, train_dir, batch_size),
        dtypes, shapes)
    test = tf.data.Dataset.from_generator(
        functools.partial(loader, reader, test_dir, batch_size),
        dtypes, shapes)

    def chunking(episode):
        # One episode becomes a dataset holding one element per chunk.
        return tf.data.Dataset.from_tensor_slices(
            chunk_sequence.chunk_sequence(
                episode, chunk_length, True, num_chunks))

    def sequence_preprocess_fn(sequence):
        if preprocess_fn:
            sequence['image'], noise = preprocess_fn(
                sequence['image'], return_noise=True)
            if 'ori_img' in sequence:
                # Reuse the noise returned for 'image' on the original image.
                sequence['ori_img'] = preprocess_fn(
                    sequence['ori_img'], noise=noise)
        if simclr:
            leading = (batch_size, chunk_length)
            sequence['image'] = tf.reshape(
                sequence['image'],
                leading + tuple(sequence['image'].shape[2:]))
            sequence['action'] = tf.reshape(
                sequence['action'],
                leading + (sequence['action'].shape[2],))
            sequence['reward'] = tf.reshape(sequence['reward'], leading)
            sequence['aug'] = tf.reshape(sequence['aug'], leading + (2,))
            sequence['length'] = tf.reshape(sequence['length'], [-1])
        return sequence

    def build_pipeline(episodes, phase, augmented_message):
        # Chunk, optionally augment, batch (dropping incomplete batches),
        # preprocess and prefetch one split.
        dataset = episodes.flat_map(chunking)
        if aug_fn:
            print(augmented_message)
            dataset = dataset.map(
                lambda chunk: aug_fn(chunk, phase=phase), 10)
        dataset = dataset.batch(elements_per_batch, drop_remainder=True)
        return dataset.map(sequence_preprocess_fn, 10).prefetch(10)

    train = build_pipeline(train, 'train', 'Training set is augmented')
    test = build_pipeline(test, 'test', 'Test set is augmented')
    return attr_dict.AttrDict(train=train, test=test)