예제 #1
0
def make_unmnist(config):
    """Creates the unlabelled MNIST dataset.

    Pixels are divided by 255 (presumably uint8 input, giving values in
    [0, 1]). When ``config.canvas_size`` differs from 28 (MNIST's image
    side), images are passed through ``preprocess.pad_and_shift`` to reach
    that size. The ``"label"`` entry is deleted from both splits, so only
    images are exposed.

    Args:
        config: Object exposing ``canvas_size`` and ``batch_size``.

    Returns:
        AttrDict with ``trainset`` (built from the 'train' subset) and
        ``validset`` (built from the 'test' subset).
    """
    def to_float(x):
        # tf.to_float is deprecated (and absent from the TF 2 namespace);
        # tf.cast to float32 is the supported equivalent.
        return tf.cast(x, tf.float32) / 255.

    transform = [to_float]

    # Only pad/shift when a canvas other than the native 28x28 is requested.
    if config.canvas_size != 28:
        transform.append(
            functools.partial(preprocess.pad_and_shift,
                              output_size=config.canvas_size,
                              shift=None))  # NOTE(review): shift=None presumably disables shifting; verify in preprocess

    batch_size = config.batch_size
    trainset = image.create('mnist',
                            subset='train',
                            batch_size=batch_size,
                            transforms=transform)
    # Labels are dropped: this dataset is meant to be unlabelled.
    del trainset["label"]
    validset = image.create('mnist',
                            subset='test',
                            batch_size=batch_size,
                            transforms=transform)
    del validset["label"]

    res = AttrDict(trainset=trainset, validset=validset)
    return res
예제 #2
0
def make_svhn(config):
    """Creates the SVHN dataset.

    Pixels are divided by 255 (presumably uint8 input, giving values in
    [0, 1]). When ``config.canvas_size`` differs from 32 (SVHN's image
    side), images are passed through ``preprocess.pad_and_shift`` to reach
    that size.

    Args:
        config: Object exposing ``canvas_size`` and ``batch_size``.

    Returns:
        AttrDict with ``trainset`` (built from the 'train' subset) and
        ``validset`` (built from the 'test' subset).
    """
    def to_float(x):
        # tf.to_float is deprecated (and absent from the TF 2 namespace);
        # tf.cast to float32 is the supported equivalent.
        return tf.cast(x, tf.float32) / 255.

    transform = [to_float]

    # Only pad/shift when a canvas other than the native 32x32 is requested.
    if config.canvas_size != 32:
        transform.append(
            functools.partial(preprocess.pad_and_shift,
                              output_size=config.canvas_size,
                              shift=None))  # NOTE(review): shift=None presumably disables shifting; verify in preprocess
    # NOTE: an earlier variant also appended preprocess.normalized_sobel_edges
    # to this transform list; it is currently disabled.

    batch_size = config.batch_size
    trainset = image.create('svhn',
                            subset='train',
                            batch_size=batch_size,
                            transforms=transform)
    validset = image.create('svhn',
                            subset='test',
                            batch_size=batch_size,
                            transforms=transform)

    res = AttrDict(trainset=trainset, validset=validset)
    return res


def make_dataset256(config):
    """Creates the dataset_256 dataset.

    Pixels are divided by 255 (presumably uint8 input, giving values in
    [0, 1]). Unlike the MNIST/SVHN helpers, ``preprocess.pad_and_shift`` is
    applied unconditionally, always targeting ``config.canvas_size``.

    Args:
        config: Object exposing ``canvas_size`` and ``batch_size``.

    Returns:
        AttrDict with ``trainset`` (built from the 'train' subset) and
        ``validset`` (built from the 'test' subset).
    """
    # NOTE(review): this data is reportedly generated online, so a separate
    # validation split adds little. A 'test' subset is still built below,
    # presumably so callers get the same trainset/validset interface as the
    # other make_* helpers. Confirm whether 'test' differs from 'train' here.

    def to_float(x):
        # tf.to_float is deprecated (and absent from the TF 2 namespace);
        # tf.cast to float32 is the supported equivalent.
        return tf.cast(x, tf.float32) / 255.

    transform = [to_float]
    transform.append(
        functools.partial(preprocess.pad_and_shift,
                          output_size=config.canvas_size,
                          shift=None))  # NOTE(review): shift=None presumably disables shifting; verify in preprocess

    batch_size = config.batch_size

    trainset = image.create('dataset256',
                            subset='train',
                            batch_size=batch_size,
                            transforms=transform)
    validset = image.create('dataset256',
                            subset='test',
                            batch_size=batch_size,
                            transforms=transform)

    res = AttrDict(trainset=trainset, validset=validset)
    return res
예제 #4
0
def make_clevr_veggies(config):
    """Creates the CLEVR Veggies dataset.

    Pixels are divided by 255 (presumably uint8 input, giving values in
    [0, 1]). No resizing or padding is applied, so ``config.canvas_size``
    is not used here.

    Args:
        config: Object exposing ``batch_size``.

    Returns:
        AttrDict with ``trainset`` (built from the 'train' subset) and
        ``validset`` (built from the 'val' subset).
    """
    def to_float(x):
        # tf.to_float is deprecated (and absent from the TF 2 namespace);
        # tf.cast to float32 is the supported equivalent.
        return tf.cast(x, tf.float32) / 255.

    transform = [to_float]

    batch_size = config.batch_size
    trainset = image.create('clevr_veggies',
                            subset='train',
                            batch_size=batch_size,
                            transforms=transform)
    # This source exposes a dedicated 'val' split (the others use 'test').
    validset = image.create('clevr_veggies',
                            subset='val',
                            batch_size=batch_size,
                            transforms=transform)

    res = AttrDict(trainset=trainset, validset=validset)
    return res