示例#1
0
def get_named_ground_truth_data(name):
    """Build the ground truth data set registered under ``name``.

    Args:
        name: String with the name of the dataset.

    Returns:
        A newly constructed data set object for the requested name.

    Raises:
        ValueError: if an invalid data set name is provided.
    """
    # Each entry pairs a data set name with a zero-argument factory. The
    # lambdas keep construction lazy, so only the requested data set is built.
    factories = (
        ("dsprites_full", lambda: dsprites.DSprites([1, 2, 3, 4, 5])),
        ("dsprites_noshape", lambda: dsprites.DSprites([2, 3, 4, 5])),
        ("color_dsprites", lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5])),
        ("noisy_dsprites", lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5])),
        ("scream_dsprites", lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5])),
        ("smallnorb", lambda: norb.SmallNORB()),
        ("cars3d", lambda: cars3d.Cars3D()),
        ("shapes3d", lambda: shapes3d.Shapes3D()),
        ("dummy_data", lambda: dummy_data.DummyData()),
    )

    for known_name, make_dataset in factories:
        if name == known_name:
            return make_dataset()

    raise ValueError("Invalid data set name.")
示例#2
0
def get_named_ground_truth_data(name,
                                corr_type='plane',
                                corr_indices=None,
                                col=None):
    """Returns ground truth data set based on name.

    Args:
        name: String with the name of the dataset.
        corr_type: Forwarded unchanged to the ``Correlated*`` data set
            constructors; ignored for every other data set.
        corr_indices: Forwarded unchanged to the ``Correlated*`` data set
            constructors; ignored for every other data set. Defaults to
            ``[3, 4]`` when omitted or ``None``.
        col: Forwarded unchanged to ``BackgroundColorDSprites``; ignored for
            every other data set.

    Returns:
        A newly constructed data set object for the requested name.

    Raises:
        ValueError: if an invalid data set name is provided.
    """
    # Build the default per call: a list literal in the signature would be a
    # single object shared by every call and handed to the constructors.
    if corr_indices is None:
        corr_indices = [3, 4]

    if name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    elif name == "correlated_dsprites_full":
        return dsprites.CorrelatedDSprites([1, 2, 3, 4, 5], corr_indices,
                                           corr_type)
    elif name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    elif name == "correlated_dsprites_noshape":
        return dsprites.CorrelatedDSprites([2, 3, 4, 5], corr_indices,
                                           corr_type)
    elif name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    elif name == "backgroundcolor_dsprites":
        return dsprites.BackgroundColorDSprites([1, 2, 3, 4, 5, 6], col)
    elif name == "correlated_backgroundcolor_dsprites":
        return dsprites.CorrelatedBackgroundColorDSprites([1, 2, 3, 4, 5, 6],
                                                          corr_indices,
                                                          corr_type)
    elif name == "correlated_color_dsprites":
        return dsprites.CorrelatedColorDSprites([1, 2, 3, 4, 5], corr_indices,
                                                corr_type)
    elif name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    elif name == "correlated_noisy_dsprites":
        return dsprites.CorrelatedNoisyDSprites([1, 2, 3, 4, 5], corr_indices,
                                                corr_type)
    elif name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])
    elif name == "correlated_scream_dsprites":
        return dsprites.CorrelatedScreamDSprites([1, 2, 3, 4, 5], corr_indices,
                                                 corr_type)
    elif name == "smallnorb":
        return norb.SmallNORB()
    elif name == "cars3d":
        return cars3d.Cars3D()
    elif name == "shapes3d":
        return shapes3d.Shapes3D()
    elif name == "dummy_data":
        return dummy_data.DummyData()
    else:
        # format() instead of "+" so a non-string name still produces the
        # documented ValueError rather than a TypeError from concatenation.
        raise ValueError("Invalid data set name: {}.".format(name))
def get_named_ground_truth_data(name):
    """Build the ground truth data set registered under ``name``.

    Args:
        name: String with the name of the dataset.

    Returns:
        A newly constructed data set object for the requested name.

    Raises:
        ValueError: if an invalid data set name is provided.
    """
    # Ordered (name, factory) pairs. Factories are lambdas so that nothing is
    # constructed, and no module is touched, until a name actually matches.
    factories = (
        ("dsprites_full", lambda: dsprites.DSprites([1, 2, 3, 4, 5])),
        ("dsprites_noshape", lambda: dsprites.DSprites([2, 3, 4, 5])),
        ("color_dsprites", lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5])),
        ("noisy_dsprites", lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5])),
        ("scream_dsprites", lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5])),
        ("smallnorb", lambda: norb.SmallNORB()),
        ("cars3d", lambda: cars3d.Cars3D()),
        ("mpi3d_toy", lambda: mpi3d.MPI3D(mode="mpi3d_toy")),
        ("mpi3d_realistic", lambda: mpi3d.MPI3D(mode="mpi3d_realistic")),
        ("mpi3d_real", lambda: mpi3d.MPI3D(mode="mpi3d_real")),
        ("shapes3d", lambda: shapes3d.Shapes3D()),
        ("dummy_data", lambda: dummy_data.DummyData()),
        ("modelnet", lambda: DisLibGroundTruthData(**MODELNET_PARAMS)),
        ("arrow", lambda: DisLibGroundTruthData(**ARROW_PARAMS)),
        ("pixel4", lambda: DisLibGroundTruthData(**WRAPPED_PIXEL4_PARAMS)),
        ("pixel8", lambda: DisLibGroundTruthData(**WRAPPED_PIXEL8_PARAMS)),
    )

    for known_name, make_dataset in factories:
        if name == known_name:
            return make_dataset()

    raise ValueError("Invalid data set name.")
def get_named_ground_truth_data(name):
    """Build the ground truth data set registered under ``name``.

    Args:
        name: String with the name of the dataset.

    Returns:
        A newly constructed data set object for the requested name.

    Raises:
        ValueError: if an invalid data set name is provided.
    """
    # Ordered (name, factory) pairs; the lambdas defer construction until the
    # matching entry is found.
    factories = (
        # a large random sample from ThreeDots
        ("threeDotsCache", lambda: threeDots.ThreeDotsTrainingCache()),
        ("threeDots", lambda: threeDots.ThreeDots()),
        ("dsprites_full", lambda: dsprites.DSprites([1, 2, 3, 4, 5])),
        ("dsprites_noshape", lambda: dsprites.DSprites([2, 3, 4, 5])),
        ("color_dsprites", lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5])),
        ("noisy_dsprites", lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5])),
        ("scream_dsprites", lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5])),
        ("smallnorb", lambda: norb.SmallNORB()),
        ("cars3d", lambda: cars3d.Cars3D()),
        ("mpi3d_toy", lambda: mpi3d.MPI3D(mode="mpi3d_toy")),
        ("mpi3d_realistic", lambda: mpi3d.MPI3D(mode="mpi3d_realistic")),
        ("mpi3d_real", lambda: mpi3d.MPI3D(mode="mpi3d_real")),
        ("mpi3d_multi_real", lambda: mpi3d_multi.MPI3DMulti(mode="mpi3d_real")),
        ("shapes3d", lambda: shapes3d.Shapes3D()),
        ("dummy_data", lambda: dummy_data.DummyData()),
    )

    for known_name, make_dataset in factories:
        if name == known_name:
            return make_dataset()

    raise ValueError("Invalid data set name.")
示例#5
0
def get_named_ground_truth_data(name):
    """Build the ground truth data set(s) registered under ``name``.

    Args:
        name: String with the name of the dataset.

    Returns:
        For most names, a single data set object. For "3dshapes_model_all",
        a ``(dataset, None)`` tuple. For any other name starting with
        "3dshapes" (apart from the exact names handled earlier), a tuple of
        two ``Shapes3DPartial`` objects built from ``name + '_train'`` and
        ``name + '_valid'``.

    Raises:
        ValueError: if an invalid data set name is provided.
    """
    # Exact-match names that need no special handling. Lambdas keep the
    # construction lazy so only the requested data set is built.
    exact_factories = (
        ("dsprites_full", lambda: dsprites.DSprites([1, 2, 3, 4, 5])),
        ("dsprites_noshape", lambda: dsprites.DSprites([2, 3, 4, 5])),
        ("color_dsprites", lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5])),
        ("noisy_dsprites", lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5])),
        ("scream_dsprites", lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5])),
        ("smallnorb", lambda: norb.SmallNORB()),
        ("cars3d", lambda: cars3d.Cars3D()),
        ("mpi3d_toy", lambda: mpi3d.MPI3D(mode="mpi3d_toy")),
        ("mpi3d_realistic", lambda: mpi3d.MPI3D(mode="mpi3d_realistic")),
        ("mpi3d_real", lambda: mpi3d.MPI3D(mode="mpi3d_real")),
        ("3dshapes", lambda: shapes3d.Shapes3D()),
    )
    for known_name, make_dataset in exact_factories:
        if name == known_name:
            return make_dataset()

    # Partial 3dshapes variants. The exact names must be tested before the
    # generic prefix rule, which would otherwise capture them.
    if name == "3dshapes_holdout" or name == "3dshapes_pca_holdout_s5000":
        return shapes3d_partial.Shapes3DPartial(name)
    if name == "3dshapes_model_all":
        return shapes3d_partial.Shapes3DPartial(name), None
    if name[:8] == "3dshapes":
        train_split = shapes3d_partial.Shapes3DPartial(name + '_train')
        valid_split = shapes3d_partial.Shapes3DPartial(name + '_valid')
        return train_split, valid_split

    if name == "dummy_data":
        return dummy_data.DummyData()

    raise ValueError("Invalid data set name.")
示例#6
0
def get_dlib_data(task):
  """Log the requested task and build the matching data set.

  Args:
    task: Name of the data set to load.

  Returns:
    The constructed data set object, or None when `task` matches none of
    the known names.
  """
  ut.log("Loading {}".format(task))

  # Ordered (task, factory) pairs; lambdas defer construction until a match.
  loaders = (
      # 5 factors
      ("dsprites", lambda: dsprites.DSprites(list(range(1, 6)))),
      # 6 factors
      ("shapes3d", lambda: shapes3d.Shapes3D()),
      # 4 factors + 1 nuisance (which we'll handle via n_dim=2)
      ("norb", lambda: norb.SmallNORB()),
      # 3 factors
      ("cars3d", lambda: cars3d.Cars3D()),
      # 5 factors + 2 nuisance (handled as n_dim=2)
      ("scream", lambda: dsprites.ScreamDSprites(list(range(1, 6)))),
  )

  for known_task, load in loaders:
    if task == known_task:
      return load()

  # Unknown tasks yield None rather than raising.
  return None
示例#7
0
def _render_sequences(sampler, factor_sequences, time_len, flat_size,
                      random_state):
    """Sample one flattened observation sequence per factor sequence.

    :param sampler: callable accepting ``factors=`` and ``random_state=``
        keyword arguments and returning an array of observations
    :param factor_sequences: array indexed along axis 0 by sequence
    :param time_len: size of axis 1 of the output buffer
    :param flat_size: number of values per flattened observation
    :param random_state: RNG handed to every sampler call
    :return: array of shape (len(factor_sequences), time_len, flat_size)
    """
    rendered = np.zeros([factor_sequences.shape[0], time_len, flat_size])
    for i, sequence in enumerate(factor_sequences):
        observations = sampler(factors=sequence, random_state=random_state)
        # Reshape straight from the sampler output instead of squeezing
        # first: np.squeeze drops every length-1 axis, so a single-step
        # sequence would lose its leading axis and the reshape would fail.
        # NOTE(review): assumes axis 0 of the sampler output indexes the
        # steps of the sequence -- confirm against the data set classes.
        rendered[i, :, :] = observations.reshape(observations.shape[0],
                                                 flat_size)
    return rendered


def create_data(factors):
    """Render observations for the train and test factor sequences.

    The data set is selected by ``FLAGS.data_type`` and sampling uses a
    ``np.random.RandomState`` seeded with ``FLAGS.seed``. Training sequences
    are rendered before test sequences, both with the same RNG.

    :param factors: underlying factors of variation; a mapping with
        'factors_train' and 'factors_test' arrays, each transposed with axes
        (0, 2, 1) before use
    :return: (data_train, data_test): observational data from the underlying
        factors, as float32 arrays of shape (N, time_len, flat_size), where
        flat_size is 64 * 64 for "dsprites"/"smallnorb" and 64 * 64 * 3 for
        "cars3d"/"shapes3d"
    :raises ValueError: if ``FLAGS.data_type`` is not a supported data set
    """
    data_type = FLAGS.data_type

    # Construct only the requested data set and remember how to sample from
    # it and how many values one flattened observation holds.
    if data_type == "dsprites":
        sampler = dsprites.DSprites().sample_observations_from_factors_no_color
        flat_size = 64 * 64
    elif data_type == "smallnorb":
        sampler = norb.SmallNORB().sample_observations_from_factors
        flat_size = 64 * 64
    elif data_type == "cars3d":
        sampler = cars3d.Cars3D().sample_observations_from_factors
        flat_size = 64 * 64 * 3
    elif data_type == "shapes3d":
        sampler = shapes3d.Shapes3D().sample_observations_from_factors
        flat_size = 64 * 64 * 3
    else:
        # Previously an unknown type fell through every branch and surfaced
        # later as an unbound local variable.
        raise ValueError("Unsupported data type: {!r}".format(data_type))

    random_state = np.random.RandomState(FLAGS.seed)

    factors_train = np.transpose(factors['factors_train'], (0, 2, 1))
    factors_test = np.transpose(factors['factors_test'], (0, 2, 1))

    # The sequence length is taken from the training factors and used for
    # both output buffers.
    time_len = factors_train.shape[1]

    # Order matters: both splits draw from the same RNG, training first.
    data_train = _render_sequences(sampler, factors_train, time_len,
                                   flat_size, random_state)
    data_test = _render_sequences(sampler, factors_test, time_len,
                                  flat_size, random_state)

    return data_train.astype('float32'), data_test.astype('float32')
示例#8
0
    def __init__(self, dataset_name, batch_size, random_seed=42):
        """Store the dataset name, create the RNG, and load the data set.

        Args:
            dataset_name: Stored unchanged on the instance.
                NOTE(review): the lines visible here always load DSprites
                whatever this value is -- confirm that is intended.
            batch_size: Not referenced in the lines visible here.
            random_seed: Seed for the instance's ``np.random.RandomState``.
        """
        self.dataset_name = dataset_name
        # Per-instance RNG seeded from the constructor argument.
        self.random_state = np.random.RandomState(random_seed)

        # DSprites built with the index list [1, 2, 3, 4, 5].
        self.dataset = dsprites.DSprites(list(range(1, 6)))
        # _set_transforms is defined elsewhere on the class (not visible here).
        self.transform = self._set_transforms()