Example #1
0
def get_named_ground_truth_data(name):
    """Builds the ground truth data set registered under `name`.

    Args:
      name: String with the name of the dataset.

    Returns:
      The object produced by the constructor associated with `name`.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    # Ordered (name, factory) pairs. The factories are lambdas so that nothing
    # is constructed until the matching entry has been found.
    registry = (
        ("dsprites_full", lambda: dsprites.DSprites([1, 2, 3, 4, 5])),
        ("dsprites_noshape", lambda: dsprites.DSprites([2, 3, 4, 5])),
        ("color_dsprites", lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5])),
        ("noisy_dsprites", lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5])),
        ("scream_dsprites", lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5])),
        ("smallnorb", lambda: norb.SmallNORB()),
        ("cars3d", lambda: cars3d.Cars3D()),
        ("shapes3d", lambda: shapes3d.Shapes3D()),
        ("dummy_data", lambda: dummy_data.DummyData()),
    )

    # Walk the entries in declaration order and compare with ==, the same
    # test an if/elif chain would perform.
    for known_name, make_dataset in registry:
        if name == known_name:
            return make_dataset()

    raise ValueError("Invalid data set name.")
def get_named_ground_truth_data(name,
                                corr_type='plane',
                                corr_indices=None,
                                col=None):
    """Returns ground truth data set based on name.

    Args:
      name: String with the name of the dataset.
      corr_type: Forwarded unchanged to the "correlated_*" dSprites variants;
        ignored for every other data set. Defaults to 'plane'.
      corr_indices: Forwarded unchanged to the "correlated_*" dSprites
        variants; ignored for every other data set. When None (the default)
        a fresh `[3, 4]` list is used.
      col: Forwarded unchanged to `BackgroundColorDSprites` when name is
        "backgroundcolor_dsprites"; ignored otherwise.

    Returns:
      The data set object constructed for `name`.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    # Build the default per call: a list literal in the signature would be one
    # shared object that the data set constructors could mutate across calls.
    if corr_indices is None:
        corr_indices = [3, 4]

    if name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    if name == "correlated_dsprites_full":
        return dsprites.CorrelatedDSprites([1, 2, 3, 4, 5], corr_indices,
                                           corr_type)
    if name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    if name == "correlated_dsprites_noshape":
        return dsprites.CorrelatedDSprites([2, 3, 4, 5], corr_indices,
                                           corr_type)
    if name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    if name == "backgroundcolor_dsprites":
        return dsprites.BackgroundColorDSprites([1, 2, 3, 4, 5, 6], col)
    if name == "correlated_backgroundcolor_dsprites":
        return dsprites.CorrelatedBackgroundColorDSprites([1, 2, 3, 4, 5, 6],
                                                          corr_indices,
                                                          corr_type)
    if name == "correlated_color_dsprites":
        return dsprites.CorrelatedColorDSprites([1, 2, 3, 4, 5], corr_indices,
                                                corr_type)
    if name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    if name == "correlated_noisy_dsprites":
        return dsprites.CorrelatedNoisyDSprites([1, 2, 3, 4, 5], corr_indices,
                                                corr_type)
    if name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])
    if name == "correlated_scream_dsprites":
        return dsprites.CorrelatedScreamDSprites([1, 2, 3, 4, 5], corr_indices,
                                                 corr_type)
    if name == "smallnorb":
        return norb.SmallNORB()
    if name == "cars3d":
        return cars3d.Cars3D()
    if name == "shapes3d":
        return shapes3d.Shapes3D()
    if name == "dummy_data":
        return dummy_data.DummyData()

    # Use format() rather than "+": a non-string name (e.g. None) must still
    # surface as the documented ValueError, not a TypeError from concatenation.
    raise ValueError("Invalid data set name: {}.".format(name))
def get_named_ground_truth_data(name):
    """Builds the ground truth data set registered under `name`.

    Args:
      name: String with the name of the dataset.

    Returns:
      The object produced by the constructor associated with `name`.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    # Ordered (name, factory) pairs. Each factory is deferred behind a lambda
    # so only the selected data set is ever constructed.
    registry = (
        ("dsprites_full", lambda: dsprites.DSprites([1, 2, 3, 4, 5])),
        ("dsprites_noshape", lambda: dsprites.DSprites([2, 3, 4, 5])),
        ("color_dsprites", lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5])),
        ("noisy_dsprites", lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5])),
        ("scream_dsprites", lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5])),
        ("smallnorb", lambda: norb.SmallNORB()),
        ("cars3d", lambda: cars3d.Cars3D()),
        ("mpi3d_toy", lambda: mpi3d.MPI3D(mode="mpi3d_toy")),
        ("mpi3d_realistic", lambda: mpi3d.MPI3D(mode="mpi3d_realistic")),
        ("mpi3d_real", lambda: mpi3d.MPI3D(mode="mpi3d_real")),
        ("shapes3d", lambda: shapes3d.Shapes3D()),
        ("dummy_data", lambda: dummy_data.DummyData()),
        ("modelnet", lambda: DisLibGroundTruthData(**MODELNET_PARAMS)),
        ("arrow", lambda: DisLibGroundTruthData(**ARROW_PARAMS)),
        ("pixel4", lambda: DisLibGroundTruthData(**WRAPPED_PIXEL4_PARAMS)),
        ("pixel8", lambda: DisLibGroundTruthData(**WRAPPED_PIXEL8_PARAMS)),
    )

    # Compare in declaration order with ==, as an if/elif chain would.
    for known_name, make_dataset in registry:
        if name == known_name:
            return make_dataset()

    raise ValueError("Invalid data set name.")
def get_named_ground_truth_data(name):
    """Builds the ground truth data set registered under `name`.

    Args:
      name: String with the name of the dataset.

    Returns:
      The object produced by the constructor associated with `name`.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    # Ordered (name, factory) pairs. Each factory is deferred behind a lambda
    # so only the selected data set is ever constructed.
    registry = (
        # a large random sample from ThreeDots
        ("threeDotsCache", lambda: threeDots.ThreeDotsTrainingCache()),
        ("threeDots", lambda: threeDots.ThreeDots()),
        ("dsprites_full", lambda: dsprites.DSprites([1, 2, 3, 4, 5])),
        ("dsprites_noshape", lambda: dsprites.DSprites([2, 3, 4, 5])),
        ("color_dsprites", lambda: dsprites.ColorDSprites([1, 2, 3, 4, 5])),
        ("noisy_dsprites", lambda: dsprites.NoisyDSprites([1, 2, 3, 4, 5])),
        ("scream_dsprites", lambda: dsprites.ScreamDSprites([1, 2, 3, 4, 5])),
        ("smallnorb", lambda: norb.SmallNORB()),
        ("cars3d", lambda: cars3d.Cars3D()),
        ("mpi3d_toy", lambda: mpi3d.MPI3D(mode="mpi3d_toy")),
        ("mpi3d_realistic", lambda: mpi3d.MPI3D(mode="mpi3d_realistic")),
        ("mpi3d_real", lambda: mpi3d.MPI3D(mode="mpi3d_real")),
        ("mpi3d_multi_real", lambda: mpi3d_multi.MPI3DMulti(mode="mpi3d_real")),
        ("shapes3d", lambda: shapes3d.Shapes3D()),
        ("dummy_data", lambda: dummy_data.DummyData()),
    )

    # Compare in declaration order with ==, as an if/elif chain would.
    for known_name, make_dataset in registry:
        if name == known_name:
            return make_dataset()

    raise ValueError("Invalid data set name.")
Example #5
0
def get_named_ground_truth_data(name):
    """Returns ground truth data set based on name.

    Args:
      name: String with the name of the dataset.

    Returns:
      For most names, a single data set object. Two groups of names return a
      2-tuple instead:
        * "3dshapes_model_all" returns `(Shapes3DPartial(name), None)`.
        * Any other name starting with "3dshapes" (besides "3dshapes",
          "3dshapes_holdout" and "3dshapes_pca_holdout_s5000", which return a
          single object) returns
          `(Shapes3DPartial(name + '_train'), Shapes3DPartial(name + '_valid'))`.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    if name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    if name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    if name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    if name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    if name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])
    if name == "smallnorb":
        return norb.SmallNORB()
    if name == "cars3d":
        return cars3d.Cars3D()
    if name == "mpi3d_toy":
        return mpi3d.MPI3D(mode="mpi3d_toy")
    if name == "mpi3d_realistic":
        return mpi3d.MPI3D(mode="mpi3d_realistic")
    if name == "mpi3d_real":
        return mpi3d.MPI3D(mode="mpi3d_real")
    if name == "3dshapes":
        return shapes3d.Shapes3D()
    if name == "3dshapes_holdout" or name == "3dshapes_pca_holdout_s5000":
        return shapes3d_partial.Shapes3DPartial(name)
    if name == "3dshapes_model_all":
        return shapes3d_partial.Shapes3DPartial(name), None
    # Catch-all for the remaining "3dshapes*" names; the exact matches above
    # must be tested first. The isinstance guard keeps non-string names (None,
    # ints, ...) on the documented ValueError path instead of failing with a
    # TypeError when sliced.
    if isinstance(name, str) and name.startswith("3dshapes"):
        return (shapes3d_partial.Shapes3DPartial(name + '_train'),
                shapes3d_partial.Shapes3DPartial(name + '_valid'))
    if name == "dummy_data":
        return dummy_data.DummyData()

    raise ValueError("Invalid data set name.")