Example #1
def load(cfg, **unused_kwargs):
    """
    Args:
        cfg (obj): Forge config
    Returns:
        (DataLoader, DataLoader, DataLoader):
            Tuple of data loaders for train, val, test
    """
    del unused_kwargs
    if not os.path.exists(cfg.data_folder):
        raise FileNotFoundError(f"Data folder {cfg.data_folder} does not exist.")
    print(f"Using {cfg.num_workers} data workers.")

    if not hasattr(cfg, 'unique_colours'):
        cfg.unique_colours = False

    # Paths
    if cfg.unique_colours:
        train_path = 'training_images_rand4_unique.npy'
        val_path = 'validation_images_rand4_unique.npy'
        test_path = 'test_images_rand4_unique.npy'
    else:
        train_path = 'training_images_rand4.npy'
        val_path = 'validation_images_rand4.npy'
        test_path = 'test_images_rand4.npy'

    # Training
    train_dataset = dSpritesDataset(os.path.join(cfg.data_folder, train_path),
                                    cfg.load_instances, cfg.img_size,
                                    cfg.mem_map)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.batch_size,
                              shuffle=True,
                              num_workers=cfg.num_workers)
    # Validation
    val_dataset = dSpritesDataset(os.path.join(cfg.data_folder, val_path),
                                  cfg.load_instances, cfg.img_size,
                                  cfg.mem_map)
    val_loader = DataLoader(val_dataset,
                            batch_size=cfg.batch_size,
                            shuffle=False,
                            num_workers=cfg.num_workers)
    # Test
    test_dataset = dSpritesDataset(os.path.join(cfg.data_folder, test_path),
                                   cfg.load_instances, cfg.img_size,
                                   cfg.mem_map)
    test_loader = DataLoader(test_dataset,
                             batch_size=cfg.batch_size,
                             shuffle=True,
                             num_workers=1)

    # Throughput stats
    if not cfg.debug:
        loader_throughput(train_loader)

    return (train_loader, val_loader, test_loader)
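For orientation, here is a minimal sketch of how this loader might be driven. The attribute names mirror exactly what load() reads above, but the SimpleNamespace stand-in for the Forge config and the folder path are assumptions for illustration:

from types import SimpleNamespace

# Hypothetical stand-in for the Forge config; the attributes are the
# ones load() actually reads.
cfg = SimpleNamespace(
    data_folder='data/dsprites',   # assumed path
    batch_size=32,
    num_workers=4,
    load_instances=True,
    img_size=64,
    mem_map=False,
    unique_colours=False,
    debug=True,                    # skip the throughput check
)

train_loader, val_loader, test_loader = load(cfg)
for batch in train_loader:
    break  # feed batches to the model here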
Example #2
def load(cfg, **unused_kwargs):
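    """
    Args:
        cfg (obj): Forge config
    Returns:
        (DataLoader, DataLoader, DataLoader):
            Tuple of data loaders for train, val, test
    """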
    del unused_kwargs
    if not os.path.exists(cfg.data_folder):
        raise FileNotFoundError(f"Data folder {cfg.data_folder} does not exist.")
    print(f"Using {cfg.num_workers} data workers.")

    # Copy all images and splits to /tmp
    if cfg.copy_to_tmp:
        for directory in ['/recordings', '/splits']:
            src = cfg.data_folder + directory
            dst = '/tmp' + directory
            fprint(f"Copying dataset from {src} to {dst}.")
            copytree(src, dst)
        cfg.data_folder = '/tmp'

    # Training
    tng_set = ShapeStacksDataset(cfg.data_folder,
                                 cfg.split_name,
                                 'train',
                                 cfg.img_size)
    tng_loader = DataLoader(tng_set,
                            batch_size=cfg.batch_size,
                            shuffle=True,
                            num_workers=cfg.num_workers)
    # Validation
    val_set = ShapeStacksDataset(cfg.data_folder,
                                 cfg.split_name,
                                 'eval',
                                 cfg.img_size)
    val_loader = DataLoader(val_set,
                            batch_size=cfg.batch_size,
                            shuffle=False,
                            num_workers=cfg.num_workers)
    # Test
    tst_set = ShapeStacksDataset(cfg.data_folder,
                                 cfg.split_name,
                                 'test',
                                 cfg.img_size,
                                 shuffle_files=cfg.shuffle_test)
    tst_loader = DataLoader(tst_set,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    # Throughput stats
    loader_throughput(tng_loader)

    return (tng_loader, val_loader, tst_loader)
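One caveat on the copy-to-/tmp step: shutil.copytree raises FileExistsError when the destination directory already exists, so a second run against a warm /tmp would fail. A minimal sketch of an idempotent variant, assuming Python 3.8+ for the dirs_exist_ok flag:

from shutil import copytree

# Same copy step as above, but safe to re-run (Python 3.8+ only).
for directory in ['/recordings', '/splits']:
    copytree(cfg.data_folder + directory, '/tmp' + directory,
             dirs_exist_ok=True)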
Example #3
def load(cfg, **unused_kwargs):
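    """
    Args:
        cfg (obj): Forge config
    Returns:
        (GQNLoader, GQNLoader, GQNLoader):
            Tuple of data loaders for train, val, test
    """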
    # Fix TensorFlow seed
    global SEED
    SEED = cfg.seed
    tf.set_random_seed(SEED)

    if cfg.num_workers == 0:
        fprint("Need to use at least one worker for loading tfrecords.")
        cfg.num_workers = 1

    del unused_kwargs
    if not os.path.exists(cfg.data_folder):
        raise FileNotFoundError(f"Data folder {cfg.data_folder} does not exist.")
    print(f"Using {cfg.num_workers} data workers.")
    # Create data iterators
    train_loader = GQNLoader(data_folder=cfg.data_folder,
                             mode='devel_train',
                             img_size=cfg.img_size,
                             val_frac=cfg.val_frac,
                             batch_size=cfg.batch_size,
                             num_workers=cfg.num_workers,
                             buffer_size=cfg.buffer_size)
    val_loader = GQNLoader(data_folder=cfg.data_folder,
                           mode='devel_val',
                           img_size=cfg.img_size,
                           val_frac=cfg.val_frac,
                           batch_size=cfg.batch_size,
                           num_workers=cfg.num_workers,
                           buffer_size=cfg.buffer_size)
    test_loader = GQNLoader(data_folder=cfg.data_folder,
                            mode='test',
                            img_size=cfg.img_size,
                            val_frac=cfg.val_frac,
                            batch_size=1,
                            num_workers=1,
                            buffer_size=cfg.buffer_size)
    # Create session to be used by loaders
    sess = tf.InteractiveSession()
    train_loader.sess = sess
    val_loader.sess = sess
    test_loader.sess = sess

    # Throughput stats
    if not cfg.debug:
        loader_throughput(train_loader)

    return (train_loader, val_loader, test_loader)
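The three loaders share a single tf.InteractiveSession, the TF1 pattern in which one session evaluates every input pipeline. Since the GQNLoader internals are not shown here, the following self-contained sketch only illustrates the session/iterator mechanics with a plain tf.data pipeline (TF1 compat API):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

dataset = tf.data.Dataset.range(10).batch(2)
next_batch = tf.data.make_one_shot_iterator(dataset).get_next()

sess = tf.InteractiveSession()  # one session shared by all pipelines
print(sess.run(next_batch))     # [0 1]
print(sess.run(next_batch))     # [2 3]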
Example #4
def load(cfg, **unused_kwargs):
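    """
    Args:
        cfg (obj): Forge config
    Returns:
        (MineRLLoader, MineRLLoader, MineRLLoader):
            Tuple of data loaders for train, val, test
    """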
    # Fix TensorFlow seed
    global SEED
    SEED = cfg.seed

    if cfg.num_workers == 0:
        fprint("Need to use at least one worker for loading.")
        cfg.num_workers = 1

    del unused_kwargs
    print(f"Using {cfg.num_workers} data workers.")
    # Create data iterators
    train_loader = MineRLLoader(
        mode="devel_train",
        img_size=cfg.img_size,
        val_frac=cfg.val_frac,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        buffer_size=cfg.buffer_size,
    )
    val_loader = MineRLLoader(
        mode="devel_val",
        img_size=cfg.img_size,
        val_frac=cfg.val_frac,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        buffer_size=cfg.buffer_size,
    )
    test_loader = MineRLLoader(
        mode="test",
        img_size=cfg.img_size,
        val_frac=cfg.val_frac,
        batch_size=1,
        num_workers=1,
        buffer_size=cfg.buffer_size,
    )

    # Throughput stats
    loader_throughput(train_loader)

    return (train_loader, val_loader, test_loader)
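The 'devel_train'/'devel_val' modes presumably carve a single pool of recordings into train and validation subsets according to val_frac; the real MineRLLoader logic is not shown in this example. A purely hypothetical sketch of such a split over a file list:

# Hypothetical: how a val_frac split over recordings might look.
files = sorted(f"recording_{i:03d}.npz" for i in range(100))
n_val = int(len(files) * 0.1)            # val_frac = 0.1
val_files, train_files = files[:n_val], files[n_val:]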
Example #5
def load(cfg, **unused_kwargs):
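    """
    Args:
        cfg (obj): Forge config
    Returns:
        (MultiOjectLoader, MultiOjectLoader, MultiOjectLoader):
            Tuple of data loaders for train, val, test
    """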
    # Fix TensorFlow seed
    global SEED
    SEED = cfg.seed
    tf.set_random_seed(SEED)

    del unused_kwargs
    fprint(f"Using {cfg.num_workers} data workers.")

    sess = tf.InteractiveSession()

    if cfg.dataset == 'multi_dsprites':
        cfg.img_size = 64 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 5 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 1
        max_frames = 60000
        raw_dataset = multi_dsprites.dataset(cfg.data_folder + MULTI_DSPRITES,
                                             'colored_on_colored',
                                             map_parallel_calls=cfg.num_workers
                                             if cfg.num_workers > 0 else None)
    elif cfg.dataset == 'objects_room':
        cfg.img_size = 64 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 7 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 4
        max_frames = 1000000
        raw_dataset = objects_room.dataset(cfg.data_folder + OBJECTS_ROOM,
                                           'train',
                                           map_parallel_calls=cfg.num_workers
                                           if cfg.num_workers > 0 else None)
    elif cfg.dataset == 'clevr':
        cfg.img_size = 128 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 11 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 1
        max_frames = 70000
        raw_dataset = clevr_with_masks.dataset(
            cfg.data_folder + CLEVR,
            map_parallel_calls=cfg.num_workers
            if cfg.num_workers > 0 else None)
    elif cfg.dataset == 'tetrominoes':
        cfg.img_size = 32 if cfg.img_size < 0 else cfg.img_size
        cfg.K_steps = 4 if cfg.K_steps < 0 else cfg.K_steps
        background_entities = 1
        max_frames = 60000
        raw_dataset = tetrominoes.dataset(cfg.data_folder + TETROMINOS,
                                          map_parallel_calls=cfg.num_workers
                                          if cfg.num_workers > 0 else None)
    else:
        raise NotImplementedError(f"{cfg.dataset} is not a valid dataset.")

    # Split into train / val / test
    if cfg.dataset_size > max_frames:
        fprint(f"WARNING: {cfg.dataset_size} frames requested, "\
                "but only {max_frames} available.")
        cfg.dataset_size = max_frames
    if cfg.dataset_size > 0:
        total_sz = cfg.dataset_size
        raw_dataset = raw_dataset.take(total_sz)
    else:
        total_sz = max_frames
    if total_sz < 0:
        fprint("Determining size of dataset...")
        total_sz = len_tfrecords(raw_dataset, sess)
    fprint(f"Dataset has {total_sz} frames")

    val_sz = 10000
    tst_sz = 10000
    tng_sz = total_sz - val_sz - tst_sz
    assert tng_sz > 0
    fprint(f"Splitting into {tng_sz}/{val_sz}/{tst_sz} for tng/val/tst")
    tst_dataset = raw_dataset.take(tst_sz)
    val_dataset = raw_dataset.skip(tst_sz).take(val_sz)
    tng_dataset = raw_dataset.skip(tst_sz + val_sz)

    tng_loader = MultiOjectLoader(sess, tng_dataset, background_entities,
                                  tng_sz, cfg.batch_size, cfg.img_size,
                                  cfg.buffer_size)
    val_loader = MultiOjectLoader(sess, val_dataset, background_entities,
                                  val_sz, cfg.batch_size, cfg.img_size,
                                  cfg.buffer_size)
    tst_loader = MultiOjectLoader(sess, tst_dataset, background_entities,
                                  tst_sz, cfg.batch_size, cfg.img_size,
                                  cfg.buffer_size)

    # Throughput stats
    if not cfg.debug:
        loader_throughput(tng_loader)

    return (tng_loader, val_loader, tst_loader)
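The split above relies on tf.data's take/skip producing disjoint, deterministic slices of the unshuffled dataset (test first, then validation, then the remainder for training). A minimal self-contained demonstration, in the same TF1 style as the code above:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

raw = tf.data.Dataset.range(100)
tst = raw.take(10)             # elements 0..9
val = raw.skip(10).take(10)    # elements 10..19
tng = raw.skip(20)             # elements 20..99

sess = tf.Session()
first_val = tf.data.make_one_shot_iterator(val).get_next()
print(sess.run(first_val))     # 10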