예제 #1
0
def get_named_ground_truth_data(name):
    """Returns ground truth data set based on name.

    Args:
      name: String with the name of the dataset.

    Returns:
      A newly constructed ground truth data set object.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    # Lambdas so that only the selected data set is constructed (and only its
    # module attribute is looked up).
    factories = {
        "dsprites_full": lambda: dsprites.DSprites([1, 2, 3, 4, 5]),
        "dsprites_noshape": lambda: dsprites.DSprites([2, 3, 4, 5]),
        "color_dsprites": lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5]),
        "noisy_dsprites": lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5]),
        "scream_dsprites": lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5]),
        "smallnorb": lambda: norb.SmallNORB(),
        "cars3d": lambda: cars3d.Cars3D(),
        "shapes3d": lambda: shapes3d.Shapes3D(),
        "dummy_data": lambda: dummy_data.DummyData(),
    }
    try:
        factory = factories[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which can never match a key.
        raise ValueError("Invalid data set name: {}.".format(name)) from None
    return factory()
예제 #2
0
def get_named_ground_truth_data(name,
                                corr_type='plane',
                                corr_indices=None,
                                col=None):
    """Returns ground truth data set based on name.

    Args:
      name: String with the name of the dataset.
      corr_type: Correlation type forwarded to the ``correlated_*`` dSprites
        variants.
      corr_indices: Factor indices forwarded to the ``correlated_*`` dSprites
        variants. Defaults to ``[3, 4]``; a fresh list is created on every
        call so no state is shared between calls.
      col: Forwarded to ``BackgroundColorDSprites`` for
        "backgroundcolor_dsprites"; unused for every other name.

    Returns:
      A newly constructed ground truth data set object.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    if corr_indices is None:
        # Avoid a shared mutable default argument.
        corr_indices = [3, 4]

    # Lambdas so that only the selected data set is constructed.
    factories = {
        "dsprites_full": lambda: dsprites.DSprites([1, 2, 3, 4, 5]),
        "correlated_dsprites_full": lambda: dsprites.CorrelatedDSprites(
            [1, 2, 3, 4, 5], corr_indices, corr_type),
        "dsprites_noshape": lambda: dsprites.DSprites([2, 3, 4, 5]),
        "correlated_dsprites_noshape": lambda: dsprites.CorrelatedDSprites(
            [2, 3, 4, 5], corr_indices, corr_type),
        "color_dsprites": lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5]),
        "backgroundcolor_dsprites": lambda: dsprites.BackgroundColorDSprites(
            [1, 2, 3, 4, 5, 6], col),
        "correlated_backgroundcolor_dsprites": (
            lambda: dsprites.CorrelatedBackgroundColorDSprites(
                [1, 2, 3, 4, 5, 6], corr_indices, corr_type)),
        "correlated_color_dsprites": lambda: dsprites.CorrelatedColorDSprites(
            [1, 2, 3, 4, 5], corr_indices, corr_type),
        "noisy_dsprites": lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5]),
        "correlated_noisy_dsprites": lambda: dsprites.CorrelatedNoisyDSprites(
            [1, 2, 3, 4, 5], corr_indices, corr_type),
        "scream_dsprites": lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5]),
        "correlated_scream_dsprites": (
            lambda: dsprites.CorrelatedScreamDSprites(
                [1, 2, 3, 4, 5], corr_indices, corr_type)),
        "smallnorb": lambda: norb.SmallNORB(),
        "cars3d": lambda: cars3d.Cars3D(),
        "shapes3d": lambda: shapes3d.Shapes3D(),
        "dummy_data": lambda: dummy_data.DummyData(),
    }
    try:
        factory = factories[name]
    except (KeyError, TypeError):
        # str.format instead of "+" so a non-str name still yields ValueError.
        raise ValueError("Invalid data set name: {}.".format(name)) from None
    return factory()
예제 #3
0
def get_dlib_data(task):
  """Load the ground-truth data set for ``task``.

  Args:
    task: One of "dsprites", "shapes3d", "norb", "cars3d" or "scream".

  Returns:
    The constructed data set object.

  Raises:
    ValueError: if ``task`` is not one of the supported names.
  """
  ut.log("Loading {}".format(task))
  if task == "dsprites":
    # 5 factors
    return dsprites.DSprites(list(range(1, 6)))
  if task == "shapes3d":
    # 6 factors
    return shapes3d.Shapes3D()
  if task == "norb":
    # 4 factors + 1 nuisance (which we'll handle via n_dim=2)
    return norb.SmallNORB()
  if task == "cars3d":
    # 3 factors
    return cars3d.Cars3D()
  if task == "scream":
    # 5 factors + 2 nuisance (handled as n_dim=2)
    return dsprites.ScreamDSprites(list(range(1, 6)))
  # Fail loudly instead of implicitly returning None for an unknown task.
  raise ValueError("Unknown task: {!r}".format(task))
def get_named_ground_truth_data(name):
    """Returns ground truth data set based on name.

    Args:
      name: String with the name of the dataset.

    Returns:
      A newly constructed ground truth data set object.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    # Lambdas so that only the selected data set is constructed (and only its
    # module attribute / parameter dict is looked up).
    factories = {
        "dsprites_full": lambda: dsprites.DSprites([1, 2, 3, 4, 5]),
        "dsprites_noshape": lambda: dsprites.DSprites([2, 3, 4, 5]),
        "color_dsprites": lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5]),
        "noisy_dsprites": lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5]),
        "scream_dsprites": lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5]),
        "smallnorb": lambda: norb.SmallNORB(),
        "cars3d": lambda: cars3d.Cars3D(),
        "mpi3d_toy": lambda: mpi3d.MPI3D(mode="mpi3d_toy"),
        "mpi3d_realistic": lambda: mpi3d.MPI3D(mode="mpi3d_realistic"),
        "mpi3d_real": lambda: mpi3d.MPI3D(mode="mpi3d_real"),
        "shapes3d": lambda: shapes3d.Shapes3D(),
        "dummy_data": lambda: dummy_data.DummyData(),
        "modelnet": lambda: DisLibGroundTruthData(**MODELNET_PARAMS),
        "arrow": lambda: DisLibGroundTruthData(**ARROW_PARAMS),
        "pixel4": lambda: DisLibGroundTruthData(**WRAPPED_PIXEL4_PARAMS),
        "pixel8": lambda: DisLibGroundTruthData(**WRAPPED_PIXEL8_PARAMS),
    }
    try:
        factory = factories[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which can never match a key.
        raise ValueError("Invalid data set name: {}.".format(name)) from None
    return factory()
def get_named_ground_truth_data(name):
    """Returns ground truth data set based on name.

    Args:
      name: String with the name of the dataset.

    Returns:
      A newly constructed ground truth data set object.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    # Lambdas so that only the selected data set is constructed.
    factories = {
        # a large random sample from ThreeDots
        "threeDotsCache": lambda: threeDots.ThreeDotsTrainingCache(),
        "threeDots": lambda: threeDots.ThreeDots(),
        "dsprites_full": lambda: dsprites.DSprites([1, 2, 3, 4, 5]),
        "dsprites_noshape": lambda: dsprites.DSprites([2, 3, 4, 5]),
        "color_dsprites": lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5]),
        "noisy_dsprites": lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5]),
        "scream_dsprites": lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5]),
        "smallnorb": lambda: norb.SmallNORB(),
        "cars3d": lambda: cars3d.Cars3D(),
        "mpi3d_toy": lambda: mpi3d.MPI3D(mode="mpi3d_toy"),
        "mpi3d_realistic": lambda: mpi3d.MPI3D(mode="mpi3d_realistic"),
        "mpi3d_real": lambda: mpi3d.MPI3D(mode="mpi3d_real"),
        "mpi3d_multi_real": lambda: mpi3d_multi.MPI3DMulti(mode="mpi3d_real"),
        "shapes3d": lambda: shapes3d.Shapes3D(),
        "dummy_data": lambda: dummy_data.DummyData(),
    }
    try:
        factory = factories[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which can never match a key.
        raise ValueError("Invalid data set name: {}.".format(name)) from None
    return factory()
예제 #6
0
def get_named_ground_truth_data(name):
    """Returns ground truth data set based on name.

    Args:
      name: String with the name of the dataset.

    Returns:
      Usually a single ground truth data set object, with two exceptions:
        * "3dshapes_model_all" returns ``(Shapes3DPartial(name), None)``.
        * Any other name starting with "3dshapes" (except "3dshapes",
          "3dshapes_holdout" and "3dshapes_pca_holdout_s5000") returns a
          ``(train, valid)`` pair built from ``name + '_train'`` and
          ``name + '_valid'``.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    if name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    elif name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    elif name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    elif name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    elif name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])
    elif name == "smallnorb":
        return norb.SmallNORB()
    elif name == "cars3d":
        return cars3d.Cars3D()
    elif name == "mpi3d_toy":
        return mpi3d.MPI3D(mode="mpi3d_toy")
    elif name == "mpi3d_realistic":
        return mpi3d.MPI3D(mode="mpi3d_realistic")
    elif name == "mpi3d_real":
        return mpi3d.MPI3D(mode="mpi3d_real")
    elif name == "3dshapes":
        return shapes3d.Shapes3D()
    elif name in ("3dshapes_holdout", "3dshapes_pca_holdout_s5000"):
        return shapes3d_partial.Shapes3DPartial(name)
    elif name == "3dshapes_model_all":
        return shapes3d_partial.Shapes3DPartial(name), None
    elif isinstance(name, str) and name.startswith("3dshapes"):
        # Remaining 3dshapes variants come as a train/validation split.
        return (shapes3d_partial.Shapes3DPartial(name + '_train'),
                shapes3d_partial.Shapes3DPartial(name + '_valid'))
    elif name == "dummy_data":
        return dummy_data.DummyData()
    else:
        raise ValueError("Invalid data set name: {}.".format(name))
예제 #7
0
def _render_sequences(sample_fn, factor_seqs, num_pixels, random_state):
    """Render every factor sequence and flatten each frame to num_pixels values.

    :param sample_fn: dataset sampling method, called as
        ``sample_fn(factors=..., random_state=...)`` once per sequence.
    :param factor_seqs: array indexed as [sequence, time step, ...].
    :param num_pixels: flattened size of one rendered frame.
    :param random_state: RandomState shared across all calls (consumed in
        sequence order).
    :return: float32 array of shape [num_sequences, time_len, num_pixels].
    """
    num_seqs, time_len = factor_seqs.shape[0], factor_seqs.shape[1]
    rendered = np.zeros([num_seqs, time_len, num_pixels], dtype=np.float32)
    for i in range(num_seqs):
        frames = sample_fn(factors=factor_seqs[i], random_state=random_state)
        # Reshape directly (no squeeze first) so a length-1 sequence keeps its
        # time axis; element order is the same as squeeze-then-reshape.
        rendered[i] = np.reshape(frames, (time_len, num_pixels))
    return rendered


def create_data(factors):
    """Render observational data from sequences of underlying factors.

    Reads ``FLAGS.data_type`` to pick the dataset and ``FLAGS.seed`` to seed a
    single RandomState; training sequences are rendered first, then test
    sequences, both drawing from that RandomState.

    :param factors: mapping with 'factors_train' and 'factors_test' arrays;
        each is transposed with axes (0, 2, 1) so that axis 1 is time.
    :return: (data_train, data_test), float32 arrays of shape
        [num_sequences, time_len, 64 * 64 * channels], where channels is 1
        for "dsprites"/"smallnorb" and 3 for "cars3d"/"shapes3d".
    :raises ValueError: if FLAGS.data_type is not one of those four names.
    """
    data_type = FLAGS.data_type
    if data_type == "dsprites":
        sample_fn = dsprites.DSprites().sample_observations_from_factors_no_color
        channels = 1
    elif data_type == "smallnorb":
        sample_fn = norb.SmallNORB().sample_observations_from_factors
        channels = 1
    elif data_type == "cars3d":
        sample_fn = cars3d.Cars3D().sample_observations_from_factors
        channels = 3
    elif data_type == "shapes3d":
        sample_fn = shapes3d.Shapes3D().sample_observations_from_factors
        channels = 3
    else:
        raise ValueError("Unsupported data_type: {!r}".format(data_type))
    num_pixels = 64 * 64 * channels

    random_state = np.random.RandomState(FLAGS.seed)

    factors_train = np.transpose(factors['factors_train'], (0, 2, 1))
    factors_test = np.transpose(factors['factors_test'], (0, 2, 1))

    # Order matters: train is rendered before test on the same RandomState.
    data_train = _render_sequences(sample_fn, factors_train, num_pixels,
                                   random_state)
    data_test = _render_sequences(sample_fn, factors_test, num_pixels,
                                  random_state)
    return data_train, data_test
예제 #8
0
else:
    pass

if settings.showplot or settings.saveplot:
    import matplotlib

    if not settings.showplot:
        # Non-interactive backend; must be selected before pyplot is imported.
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt

# Select the ground-truth dataset and its factors.
if args.dataset == 'cars3d':
    dta = cars3d.Cars3D()
    factors = get_cars_factors()
elif args.dataset == 'norb':
    dta = norb.SmallNORB()
    factors = get_norb_factors()
else:
    raise ValueError("Unsupported dataset: {!r}".format(args.dataset))

with hub.eval_function_for_module(dlib_model_path) as f:
    # Run the module's "reconstructions" signature over every dataset image.
    inputs = dta.images
    if inputs.ndim < 4:
        # Images without a channel axis get one appended at axis 3.
        inputs = np.expand_dims(inputs, 3)

    targets = f(dict(images=inputs), signature="reconstructions",
                as_dict=True)["images"]

if settings.dataname in ('faces', 'faces2', 'planes', 'cars', 'chairs',
                         'dlib_cars3d', 'dlib_faces3d'):
    input_shape = [64, 64, 3]
예제 #9
0
# Path-list datasets: dataname -> (JSON list of image paths, image
# subdirectory, target array file), all relative to dataroot. With subdirectory
# None the first component of each stored path is dropped; otherwise the last
# two components are joined under dataroot/subdirectory.
_PATH_LIST_DATASETS = {
    'faces': ('img_store', None, 'z_store'),
    'faces2': ('img_store', None, 'z_store_full'),
    'planes': ('img_store_planes', 'planes', 'z_store_trans'),
    'cars': ('img_store_cars', 'cars_for_transfer', 'z_store_new_cars'),
    'chairs': ('img_store_chairs', 'chairs_for_transfer',
               'z_store_new_chairs'),
}


def _load_path_list_dataset(dataroot, inds, store_name, subdir, target_name):
    """Resolve the stored image paths under dataroot and load their targets."""
    with open(os.path.join(dataroot, store_name), encoding='utf-8') as fh:
        stored = json.load(fh)[inds]
    if subdir is None:
        paths = [os.path.join(dataroot, '/'.join(p.split('/')[1:]))
                 for p in stored]
    else:
        paths = [os.path.join(dataroot, subdir, '/'.join(p.split('/')[-2:]))
                 for p in stored]
    targets = np.load(os.path.join(dataroot, target_name))[inds]
    return np.array(paths), targets


def _hub_reconstructions(module_path, images, chunk_size=None):
    """Return the "reconstructions" output of a TF-Hub module for images.

    With chunk_size set, images are fed in consecutive chunks of at most that
    many items and the outputs are concatenated along axis 0.
    """
    if chunk_size is None:
        batches = [images]
    else:
        batches = [images[start:start + chunk_size]
                   for start in range(0, len(images), chunk_size)]
    with hub.eval_function_for_module(module_path) as f:
        outputs = [
            f(dict(images=batch), signature="reconstructions",
              as_dict=True)["images"]
            for batch in batches
        ]
    if len(outputs) == 1:
        return outputs[0]
    return np.concatenate(outputs, axis=0)


def get_dataset(dataname, dataroot, numsamples):
    """Load (inputs, targets) for ``dataname``.

    inputs: x
    targets: c or y (y if using dlib)

    Args:
      dataname: 'dSprites', 'faces', 'faces2', 'planes', 'cars', 'chairs',
        'dlib_cars3d', 'dlib_smallnorb' or 'dlib_faces3d'.
      dataroot: directory holding the data files; for the 'dlib_*' datasets,
        the path of the TF-Hub module to load.
      numsamples: number of leading samples to keep, or -1 for all. Not
        applied to the 'dlib_*' datasets.

    Returns:
      (inputs, targets). For 'dSprites', inputs are the images with a
      trailing axis appended. For the path-list datasets, inputs are an array
      of image file paths. For 'dlib_*', targets are the module's
      "reconstructions" output for the inputs.

    Raises:
      NotImplementedError: for any other ``dataname``.
    """
    if numsamples == -1:
        inds = slice(0, None, 1)
    else:
        inds = slice(0, numsamples, 1)

    if dataname == 'dSprites':
        # Dataroot is the directory containing the dsprites .npz file.
        npz_path = os.path.join(
            dataroot, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        with np.load(npz_path, encoding='bytes') as dataset_zip:
            # axis=-1 appends the channel axis; a fixed axis=4 is out of range
            # for a 3-D image array.
            inputs = np.expand_dims(dataset_zip['imgs'], axis=-1)[inds]
            targets = dataset_zip['latents_values'][inds]
    elif dataname in _PATH_LIST_DATASETS:
        # Dataroot is the directory containing the images.
        inputs, targets = _load_path_list_dataset(
            dataroot, inds, *_PATH_LIST_DATASETS[dataname])
    elif dataname == 'dlib_cars3d':
        inputs = cars3d.Cars3D().images
        targets = _hub_reconstructions(dataroot, inputs)
    elif dataname == 'dlib_smallnorb':
        inputs = np.expand_dims(norb.SmallNORB().images, 3)
        # Fed to the module in chunks of 25000 images.
        targets = _hub_reconstructions(dataroot, inputs, chunk_size=25000)
        print(inputs.shape)
        print(targets.shape)
    elif dataname == 'dlib_faces3d':
        inputs = faces3d.Faces3D().images
        targets = _hub_reconstructions(dataroot, inputs)
        print(dataname)
        print(inputs.shape)
        print(targets.shape)
    else:
        raise NotImplementedError("Unknown dataname: {!r}".format(dataname))
    return (inputs, targets)