def get_named_ground_truth_data(name):
    """Looks up and constructs a ground truth data set by name.

    Args:
      name: String with the name of the dataset.

    Returns:
      The data set object constructed for `name`.

    Raises:
      ValueError: if `name` does not match any known data set.
    """
    all_factors = [1, 2, 3, 4, 5]
    # Ordered (name, constructor) pairs. The constructors are wrapped in
    # lambdas so that only the requested data set is ever built.
    registry = (
        ("dsprites_full", lambda: dsprites.DSprites(all_factors)),
        ("dsprites_noshape", lambda: dsprites.DSprites([2, 3, 4, 5])),
        ("color_dsprites", lambda: dsprites.ColorDSprites(all_factors)),
        ("noisy_dsprites", lambda: dsprites.NoisyDSprites(all_factors)),
        ("scream_dsprites", lambda: dsprites.ScreamDSprites(all_factors)),
        ("smallnorb", lambda: norb.SmallNORB()),
        ("cars3d", lambda: cars3d.Cars3D()),
        ("shapes3d", lambda: shapes3d.Shapes3D()),
        ("dummy_data", lambda: dummy_data.DummyData()),
    )
    for known_name, build in registry:
        if name == known_name:
            return build()
    raise ValueError("Invalid data set name.")
def get_named_ground_truth_data(name, corr_type='plane', corr_indices=None,
                                col=None):
    """Returns ground truth data set based on name.

    Args:
      name: String with the name of the dataset.
      corr_type: Value forwarded to the constructors of the "correlated_*"
        data sets. Defaults to 'plane'.
      corr_indices: Value forwarded to the constructors of the "correlated_*"
        data sets. When None, a fresh `[3, 4]` list is used.
      col: Value forwarded to `BackgroundColorDSprites`; only used for
        "backgroundcolor_dsprites".

    Returns:
      The data set object constructed for `name`.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    # Build the default per call so that no list is shared between calls (the
    # previous `corr_indices=[3, 4]` default was a single shared object).
    if corr_indices is None:
        corr_indices = [3, 4]

    if name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    if name == "correlated_dsprites_full":
        return dsprites.CorrelatedDSprites(
            [1, 2, 3, 4, 5], corr_indices, corr_type)
    if name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    if name == "correlated_dsprites_noshape":
        return dsprites.CorrelatedDSprites(
            [2, 3, 4, 5], corr_indices, corr_type)
    if name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    if name == "backgroundcolor_dsprites":
        return dsprites.BackgroundColorDSprites([1, 2, 3, 4, 5, 6], col)
    if name == "correlated_backgroundcolor_dsprites":
        return dsprites.CorrelatedBackgroundColorDSprites(
            [1, 2, 3, 4, 5, 6], corr_indices, corr_type)
    if name == "correlated_color_dsprites":
        return dsprites.CorrelatedColorDSprites(
            [1, 2, 3, 4, 5], corr_indices, corr_type)
    if name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    if name == "correlated_noisy_dsprites":
        return dsprites.CorrelatedNoisyDSprites(
            [1, 2, 3, 4, 5], corr_indices, corr_type)
    if name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])
    if name == "correlated_scream_dsprites":
        return dsprites.CorrelatedScreamDSprites(
            [1, 2, 3, 4, 5], corr_indices, corr_type)
    if name == "smallnorb":
        return norb.SmallNORB()
    if name == "cars3d":
        return cars3d.Cars3D()
    if name == "shapes3d":
        return shapes3d.Shapes3D()
    if name == "dummy_data":
        return dummy_data.DummyData()
    # Format instead of concatenating so that a non-string name still yields
    # the documented ValueError rather than a TypeError.
    raise ValueError("Invalid data set name: {}.".format(name))
def get_dlib_data(task):
    """Logs the requested task and constructs the matching data set.

    Args:
      task: One of "dsprites", "shapes3d", "norb", "cars3d" or "scream".

    Returns:
      The constructed data set, or None when `task` is not one of the
      names above.
    """
    ut.log("Loading {}".format(task))

    dsprites_factors = list(range(1, 6))
    if task == "dsprites":
        # 5 factors
        return dsprites.DSprites(dsprites_factors)
    if task == "shapes3d":
        # 6 factors
        return shapes3d.Shapes3D()
    if task == "norb":
        # 4 factors + 1 nuisance (which we'll handle via n_dim=2)
        return norb.SmallNORB()
    if task == "cars3d":
        # 3 factors
        return cars3d.Cars3D()
    if task == "scream":
        # 5 factors + 2 nuisance (handled as n_dim=2)
        return dsprites.ScreamDSprites(dsprites_factors)
    # Unrecognised tasks fall through without raising.
    return None
def get_named_ground_truth_data(name):
    """Looks up and constructs a ground truth data set by name.

    Args:
      name: String with the name of the dataset.

    Returns:
      The data set object constructed for `name`.

    Raises:
      ValueError: if `name` does not match any known data set.
    """
    all_factors = [1, 2, 3, 4, 5]
    # Ordered (name, constructor) pairs. Wrapping each constructor in a lambda
    # defers both the construction and the lookup of the *_PARAMS constants
    # until the matching name is requested.
    registry = (
        ("dsprites_full", lambda: dsprites.DSprites(all_factors)),
        ("dsprites_noshape", lambda: dsprites.DSprites([2, 3, 4, 5])),
        ("color_dsprites", lambda: dsprites.ColorDSprites(all_factors)),
        ("noisy_dsprites", lambda: dsprites.NoisyDSprites(all_factors)),
        ("scream_dsprites", lambda: dsprites.ScreamDSprites(all_factors)),
        ("smallnorb", lambda: norb.SmallNORB()),
        ("cars3d", lambda: cars3d.Cars3D()),
        ("mpi3d_toy", lambda: mpi3d.MPI3D(mode="mpi3d_toy")),
        ("mpi3d_realistic", lambda: mpi3d.MPI3D(mode="mpi3d_realistic")),
        ("mpi3d_real", lambda: mpi3d.MPI3D(mode="mpi3d_real")),
        ("shapes3d", lambda: shapes3d.Shapes3D()),
        ("dummy_data", lambda: dummy_data.DummyData()),
        ("modelnet", lambda: DisLibGroundTruthData(**MODELNET_PARAMS)),
        ("arrow", lambda: DisLibGroundTruthData(**ARROW_PARAMS)),
        ("pixel4", lambda: DisLibGroundTruthData(**WRAPPED_PIXEL4_PARAMS)),
        ("pixel8", lambda: DisLibGroundTruthData(**WRAPPED_PIXEL8_PARAMS)),
    )
    for known_name, build in registry:
        if name == known_name:
            return build()
    raise ValueError("Invalid data set name.")
def get_named_ground_truth_data(name):
    """Constructs the ground truth data set registered under `name`.

    Args:
      name: String with the name of the dataset.

    Returns:
      The data set object constructed for `name`.

    Raises:
      ValueError: if `name` does not match any known data set.
    """
    if name == "threeDotsCache":
        # a large random sample from ThreeDots
        return threeDots.ThreeDotsTrainingCache()
    if name == "threeDots":
        return threeDots.ThreeDots()

    # dSprites family.
    if name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    if name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    if name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    if name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    if name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])

    if name == "smallnorb":
        return norb.SmallNORB()
    if name == "cars3d":
        return cars3d.Cars3D()

    # MPI3D family.
    if name == "mpi3d_toy":
        return mpi3d.MPI3D(mode="mpi3d_toy")
    if name == "mpi3d_realistic":
        return mpi3d.MPI3D(mode="mpi3d_realistic")
    if name == "mpi3d_real":
        return mpi3d.MPI3D(mode="mpi3d_real")
    if name == "mpi3d_multi_real":
        return mpi3d_multi.MPI3DMulti(mode="mpi3d_real")

    if name == "shapes3d":
        return shapes3d.Shapes3D()
    if name == "dummy_data":
        return dummy_data.DummyData()

    raise ValueError("Invalid data set name.")
def get_named_ground_truth_data(name):
    """Returns ground truth data set(s) based on name.

    Args:
      name: String with the name of the dataset.

    Returns:
      For most names, a single data set object. Two cases return a tuple
      instead:
        * "3dshapes_model_all" returns `(Shapes3DPartial(name), None)`.
        * Any other name starting with "3dshapes" (apart from "3dshapes",
          "3dshapes_holdout" and "3dshapes_pca_holdout_s5000") returns
          `(Shapes3DPartial(name + '_train'), Shapes3DPartial(name + '_valid'))`.

    Raises:
      ValueError: if an invalid data set name is provided.
    """
    if name == "dsprites_full":
        return dsprites.DSprites([1, 2, 3, 4, 5])
    elif name == "dsprites_noshape":
        return dsprites.DSprites([2, 3, 4, 5])
    elif name == "color_dsprites":
        return dsprites.ColorDSprites([1, 2, 3, 4, 5])
    elif name == "noisy_dsprites":
        return dsprites.NoisyDSprites([1, 2, 3, 4, 5])
    elif name == "scream_dsprites":
        return dsprites.ScreamDSprites([1, 2, 3, 4, 5])
    elif name == "smallnorb":
        return norb.SmallNORB()
    elif name == "cars3d":
        return cars3d.Cars3D()
    elif name == "mpi3d_toy":
        return mpi3d.MPI3D(mode="mpi3d_toy")
    elif name == "mpi3d_realistic":
        return mpi3d.MPI3D(mode="mpi3d_realistic")
    elif name == "mpi3d_real":
        return mpi3d.MPI3D(mode="mpi3d_real")
    elif name == "3dshapes":
        return shapes3d.Shapes3D()
    elif name == "3dshapes_holdout" or name == "3dshapes_pca_holdout_s5000":
        return shapes3d_partial.Shapes3DPartial(name)
    elif name == "3dshapes_model_all":
        return shapes3d_partial.Shapes3DPartial(name), None

    # Slicing a non-sliceable name (e.g. None or an int) raises TypeError;
    # treat that as "no prefix match" so such names reach the documented
    # ValueError below instead of escaping as a TypeError.
    try:
        has_3dshapes_prefix = name[:8] == "3dshapes"
    except TypeError:
        has_3dshapes_prefix = False

    if has_3dshapes_prefix:
        return (shapes3d_partial.Shapes3DPartial(name + '_train'),
                shapes3d_partial.Shapes3DPartial(name + '_valid'))
    elif name == "dummy_data":
        return dummy_data.DummyData()
    else:
        raise ValueError("Invalid data set name.")
def _render_sequences(sample_fn, factor_sequences, pixels_per_frame,
                      random_state):
    """Renders each factor sequence into flattened observations.

    :param sample_fn: callable taking `factors` and `random_state` keyword
        arguments and returning one observation per row of `factors`
    :param factor_sequences: array whose first axis indexes sequences and whose
        second axis indexes time steps
    :param pixels_per_frame: number of values in one flattened observation
    :param random_state: random state forwarded to `sample_fn`
    :return: float array of shape
        [num_sequences, time_len, pixels_per_frame]
    """
    num_sequences, time_len = factor_sequences.shape[:2]
    rendered = np.zeros([num_sequences, time_len, pixels_per_frame])
    for i in range(num_sequences):
        frames = np.asarray(
            sample_fn(factors=factor_sequences[i, :, :],
                      random_state=random_state))
        # Reshape directly instead of squeezing first: np.squeeze would also
        # drop the leading axis of a length-1 sequence and break the reshape,
        # while for longer sequences both forms give the same values.
        rendered[i, :, :] = frames.reshape(frames.shape[0], pixels_per_frame)
    return rendered


def create_data(factors):
    """Generates observations for the train and test factor sequences.

    The data set is selected by FLAGS.data_type and sampling is seeded with
    FLAGS.seed.

    :param factors: underlying factors of variation; mapping with the keys
        'factors_train' and 'factors_test', each a 3-D array whose last two
        axes are swapped before sampling
    :return: tuple (data_train, data_test) of float32 arrays with shape
        [N, time_len, 64 * 64] for "dsprites"/"smallnorb" and
        [N, time_len, 64 * 64 * 3] for "cars3d"/"shapes3d"
    :raises ValueError: if FLAGS.data_type is not one of the supported names
    """
    data_type = FLAGS.data_type
    if data_type == "dsprites":
        sample_fn = dsprites.DSprites().sample_observations_from_factors_no_color
        pixels_per_frame = 64 * 64
    elif data_type == "smallnorb":
        sample_fn = norb.SmallNORB().sample_observations_from_factors
        pixels_per_frame = 64 * 64
    elif data_type == "cars3d":
        sample_fn = cars3d.Cars3D().sample_observations_from_factors
        pixels_per_frame = 64 * 64 * 3
    elif data_type == "shapes3d":
        sample_fn = shapes3d.Shapes3D().sample_observations_from_factors
        pixels_per_frame = 64 * 64 * 3
    else:
        # Fail early with a clear message; previously an unknown type only
        # surfaced later as an unbound local variable.
        raise ValueError("Unknown data_type: {}".format(data_type))

    random_state = np.random.RandomState(FLAGS.seed)

    factors_train = np.transpose(factors['factors_train'], (0, 2, 1))
    factors_test = np.transpose(factors['factors_test'], (0, 2, 1))

    # Train is rendered before test so the shared random_state is consumed in
    # the same order as before.
    data_train = _render_sequences(
        sample_fn, factors_train, pixels_per_frame, random_state)
    data_test = _render_sequences(
        sample_fn, factors_test, pixels_per_frame, random_state)

    return data_train.astype('float32'), data_test.astype('float32')
# NOTE(review): fragment of a larger script; the `if` that the leading `else` closes is not visible here,
# and the original line breaks/indentation were lost, so the nesting of these statements cannot be confirmed.
# Visible steps: import matplotlib when settings.showplot or settings.saveplot is set (selecting the 'Agg'
# backend when settings.showplot is false); build the data set and its factors for args.dataset
# ('cars3d' or 'norb'); feed dta.images (expanded to 4 dims if needed) through the "reconstructions"
# signature of the hub module at dlib_model_path; set input_shape to [64, 64, 3] for the listed datanames.
# NOTE(review): `raise Error` uses a name `Error` that is not defined in the visible code - confirm it exists.
else: pass if settings.showplot or settings.saveplot: import matplotlib if not settings.showplot: matplotlib.use('Agg') import matplotlib.pyplot as plt # get ground truth factors if args.dataset == 'cars3d': dta = cars3d.Cars3D() factors = get_cars_factors() elif args.dataset == 'norb': dta = norb.SmallNORB() factors = get_norb_factors() else: raise Error with hub.eval_function_for_module(dlib_model_path) as f: # Save reconstructions. inputs = dta.images if inputs.ndim < 4: inputs = np.expand_dims(inputs, 3) targets = f(dict(images=inputs), signature="reconstructions", as_dict=True)["images"] if settings.dataname == 'faces' or settings.dataname == 'faces2' or settings.dataname == 'planes' or settings.dataname == 'cars' or settings.dataname == 'chairs' or settings.dataname == 'dlib_cars3d' or settings.dataname == 'dlib_faces3d': input_shape = [64, 64, 3]
def _load_path_dataset(dataroot, inds, index_file, targets_file, subdir,
                       keep_parts):
    '''Builds image paths and loads targets for a file-backed data set.

    Args:
        dataroot: directory containing the index file and the targets file.
        inds: slice selecting which samples to keep.
        index_file: name of the JSON file (inside dataroot) listing the
            stored image paths.
        targets_file: name of the NumPy file (inside dataroot) holding the
            targets.
        subdir: sub-directory of dataroot prepended to every image path, or
            None to join directly onto dataroot.
        keep_parts: slice applied to the '/'-separated components of each
            stored path to pick the part that is kept.

    Returns:
        Tuple (inputs, targets): array of image paths and array of targets.
    '''
    # Use a context manager so the index file is closed deterministically.
    with open(os.path.join(dataroot, index_file)) as index_handle:
        stored_paths = json.load(index_handle)[inds]
    image_dir = dataroot if subdir is None else os.path.join(dataroot, subdir)
    inputs = np.array([
        os.path.join(image_dir, '/'.join(path.split('/')[keep_parts]))
        for path in stored_paths
    ])
    targets = np.load(os.path.join(dataroot, targets_file))[inds]
    return inputs, targets


def get_dataset(dataname, dataroot, numsamples):
    '''Loads the inputs and targets of the named data set.

    Args:
        dataname: one of 'dSprites', 'faces', 'faces2', 'planes', 'cars',
            'chairs', 'dlib_cars3d', 'dlib_smallnorb' or 'dlib_faces3d'.
        dataroot: for 'dSprites', the directory containing the dsprites .npz
            file; for the image data sets, the directory containing the
            images and index/target files; for the 'dlib_*' data sets, the
            path to the tensorflow_hub directory.
        numsamples: number of leading samples to keep, or -1 for all. It is
            not applied to the 'dlib_*' data sets, which always return
            every image.

    Returns:
        Tuple (inputs, targets).
        inputs: x
        targets: c or y (y if using dlib)

    Raises:
        NotImplementedError: if dataname is not one of the names above.
    '''
    inds = slice(0, None if numsamples == -1 else numsamples, 1)
    drop_first = slice(1, None)   # strip the leading path component
    last_two = slice(-2, None)    # keep only "<dir>/<file>"

    if dataname == 'dSprites':
        # Dataroot is the directory containing the dsprites .npz file
        archive_path = os.path.join(
            dataroot, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        # Close the archive once both arrays have been read.
        with np.load(archive_path, encoding='bytes') as dataset_zip:
            # Append a trailing axis. The previous `axis=4` is out of range
            # for rank-3 input, which current NumPy rejects.
            # presumably 'imgs' is (N, 64, 64) - verify against the archive
            inputs = np.expand_dims(dataset_zip['imgs'], axis=-1)[inds]
            targets = dataset_zip['latents_values'][inds]
    elif dataname == 'faces':
        inputs, targets = _load_path_dataset(
            dataroot, inds, 'img_store', 'z_store', None, drop_first)
    elif dataname == 'faces2':
        inputs, targets = _load_path_dataset(
            dataroot, inds, 'img_store', 'z_store_full', None, drop_first)
    elif dataname == 'planes':
        inputs, targets = _load_path_dataset(
            dataroot, inds, 'img_store_planes', 'z_store_trans', 'planes',
            last_two)
    elif dataname == 'cars':
        inputs, targets = _load_path_dataset(
            dataroot, inds, 'img_store_cars', 'z_store_new_cars',
            'cars_for_transfer', last_two)
    elif dataname == 'chairs':
        inputs, targets = _load_path_dataset(
            dataroot, inds, 'img_store_chairs', 'z_store_new_chairs',
            'chairs_for_transfer', last_two)
    elif dataname == 'dlib_cars3d':
        ## dataroot: path to the tensorflow_hub directory
        dta = cars3d.Cars3D()
        with hub.eval_function_for_module(dataroot) as f:
            # Save reconstructions.
            inputs = dta.images
            targets = f(dict(images=inputs), signature="reconstructions",
                        as_dict=True)["images"]
    elif dataname == 'dlib_smallnorb':
        ## dataroot: path to the tensorflow_hub directory
        dta = norb.SmallNORB()
        with hub.eval_function_for_module(dataroot) as f:
            # Save reconstructions.
            images = np.expand_dims(dta.images, 3)
            # The module is evaluated in two chunks split at index 25000.
            inputs1 = images[:25000]
            targets1 = f(dict(images=inputs1), signature="reconstructions",
                         as_dict=True)["images"]
            inputs2 = images[25000:]
            targets2 = f(dict(images=inputs2), signature="reconstructions",
                         as_dict=True)["images"]
            inputs = np.concatenate((inputs1, inputs2), axis=0)
            targets = np.concatenate((targets1, targets2), axis=0)
        print(inputs.shape)
        print(targets.shape)
    elif dataname == 'dlib_faces3d':
        ## dataroot: path to the tensorflow_hub directory
        dta = faces3d.Faces3D()
        with hub.eval_function_for_module(dataroot) as f:
            # Save reconstructions.
            inputs = dta.images
            targets = f(dict(images=inputs), signature="reconstructions",
                        as_dict=True)["images"]
        print(dataname)
        print(inputs.shape)
        print(targets.shape)
    else:
        raise NotImplementedError

    return (inputs, targets)