def __init__(self):
    """Load the SmallNORB chunks and set up the factor bookkeeping.

    Sets the image array and its length, the per-factor sizes, the indices
    of the factors treated as latent, the total factor count, a feature
    index and the split state space.
    """
    images, features = _load_small_norb_chunks(
        SMALLNORB_TEMPLATE, SMALLNORB_CHUNKS)
    self.images = images
    self.image_leng = images.shape[0]

    factor_sizes = [5, 10, 9, 18, 6]
    self.factor_sizes = factor_sizes
    # Instances are not part of the latent space (factor index 1 is left out).
    latent_indices = [0, 2, 3, 4]
    self.latent_factor_indices = latent_indices

    # The factor count comes from the feature table's second dimension.
    self.num_total_factors = features.shape[1]
    self.index = util.StateSpaceAtomIndex(factor_sizes, features)
    self.state_space = util.SplitDiscreteStateSpace(
        factor_sizes, latent_indices)
def __init__(self):
    """Set up factor bookkeeping for a three-factor dataset and load images.

    The feature table is built from the cartesian product of each factor's
    value range; images are loaded last via `self._load_data()`.
    """
    sizes = [4, 24, 183]
    self.factor_sizes = sizes

    # One integer range per factor, combined into a single feature table.
    value_ranges = [np.arange(size) for size in sizes]
    features = cartesian(value_ranges)

    self.latent_factor_indices = [0, 1, 2]
    self.num_total_factors = features.shape[1]

    atom_index = util.StateSpaceAtomIndex(sizes, features)
    self.index = atom_index
    self.factor_bases = atom_index.factor_bases
    self.state_space = util.SplitDiscreteStateSpace(
        sizes, self.latent_factor_indices)

    self.data_shape = [64, 64, 3]
    # Loading is kept as the final step, after all other attributes are set.
    self.images = self._load_data()
def __init__(self, mode="mpi3d_toy"):
    """Load one MPI3D variant from disk and set up its factor bookkeeping.

    The archive is looked up at `<root>/<mode>/<mode>.npz`, where `<root>` is
    the `DISENTANGLEMENT_LIB_DATA` environment variable, falling back to the
    current directory when it is unset.

    Args:
      mode: Which variant to load; one of "mpi3d_toy", "mpi3d_realistic" or
        "mpi3d_real".

    Raises:
      ValueError: If `mode` is not one of the known variants, or if the
        archive for `mode` does not exist.
    """
    # The variants share all loading logic and differ only in factor sizes.
    factor_sizes_by_mode = {
        "mpi3d_toy": [4, 4, 2, 3, 3, 40, 40],
        "mpi3d_realistic": [4, 4, 2, 3, 3, 40, 40],
        "mpi3d_real": [6, 6, 2, 3, 3, 40, 40],
    }
    try:
        factor_sizes = factor_sizes_by_mode[mode]
    except (KeyError, TypeError):
        # TypeError covers unhashable `mode` values, which previously also
        # ended up in the "unknown mode" branch.
        factor_sizes = None
    if factor_sizes is None:
        raise ValueError("Unknown mode provided.")

    mpi3d_path = os.path.join(
        os.environ.get("DISENTANGLEMENT_LIB_DATA", "."), mode,
        "{}.npz".format(mode))
    if not tf.io.gfile.exists(mpi3d_path):
        raise ValueError(
            ("Dataset '{}' not found. Make sure the dataset is publicly "
             "available and downloaded correctly.").format(mode))

    with tf.io.gfile.GFile(mpi3d_path, "rb") as f:
        data = np.load(f)
        # np.load on an .npz archive reads members lazily on access, so the
        # array has to be pulled out while the file handle is still open.
        images = data["images"]

    self.factor_sizes = factor_sizes
    self.images = images
    self.image_leng = self.images.shape[0]
    self.latent_factor_indices = [0, 1, 2, 3, 4, 5, 6]
    self.num_total_factors = 7
    self.state_space = util.SplitDiscreteStateSpace(
        self.factor_sizes, self.latent_factor_indices)
    # Total number of factor combinations divided by the running product of
    # the factor sizes.
    self.factor_bases = np.prod(self.factor_sizes) / np.cumprod(
        self.factor_sizes)
def __init__(self):
    """Load the 3D Shapes HDF5 file and set up the factor bookkeeping.

    Reads the "images" and "labels" datasets from `SHAPES3D_PATH` fully
    into memory, scales the images to float32 in [0, 1] by dividing by 255,
    and sets the factor sizes, latent indices, state space and factor bases.
    """
    # Open the file in a context manager so the HDF5 handle is always
    # released. `[()]` reads a whole dataset into an in-memory array, so the
    # data stays valid after the file is closed.
    with h5py.File(SHAPES3D_PATH, "r") as data:
        images = data["images"][()]
        labels = data["labels"][()]

    n_samples = 480000
    self.images = (
        images.reshape([n_samples, 64, 64, 3]).astype(np.float32) / 255.)
    features = labels.reshape([n_samples, 6])

    self.factor_sizes = [10, 10, 10, 8, 4, 15]
    self.latent_factor_indices = list(range(6))
    self.num_total_factors = features.shape[1]
    self.state_space = util.SplitDiscreteStateSpace(
        self.factor_sizes, self.latent_factor_indices)
    # Total number of factor combinations divided by the running product of
    # the factor sizes.
    self.factor_bases = np.prod(self.factor_sizes) / np.cumprod(
        self.factor_sizes)
def __init__(self):
    """Load the 3D Shapes archive and set up the factor bookkeeping.

    Reads "images" and "labels" from `SHAPES3D_PATH`, flattens the leading
    six image dimensions into one sample axis, scales the images to float32
    by dividing by 255, and sets the factor sizes, latent indices, state
    space and factor bases.
    """
    with tf.gfile.GFile(SHAPES3D_PATH, "rb") as f:
        # Data was saved originally using python2, so we need to set the
        # encoding.
        data = np.load(f, encoding="latin1")
        # Extract the arrays while the handle is open: if this is an .npz
        # archive (the string keys suggest so), members are read lazily on
        # access rather than by np.load itself.
        images = data["images"]
        labels = data["labels"]

    # The first six image dimensions are collapsed into the sample axis.
    n_samples = np.prod(images.shape[0:6])
    self.images = (
        images.reshape([n_samples, 64, 64, 3]).astype(np.float32) / 255.)
    self.image_leng = self.images.shape[0]
    features = labels.reshape([n_samples, 6])

    self.factor_sizes = [10, 10, 10, 8, 4, 15]
    self.latent_factor_indices = list(range(6))
    self.num_total_factors = features.shape[1]
    self.state_space = util.SplitDiscreteStateSpace(
        self.factor_sizes, self.latent_factor_indices)
    # Total number of factor combinations divided by the running product of
    # the factor sizes.
    self.factor_bases = np.prod(self.factor_sizes) / np.cumprod(
        self.factor_sizes)
def __init__(self, latent_factor_indices=None):
    """Load the dSprites archive and set up the factor bookkeeping.

    Args:
      latent_factor_indices: Indices of the factors to treat as ground-truth
        latent factors. When None, all six factors are used.
    """
    # By default, all factors (including shape) are considered ground truth
    # factors.
    if latent_factor_indices is None:
        self.latent_factor_indices = list(range(6))
    else:
        self.latent_factor_indices = latent_factor_indices
    self.data_shape = [64, 64, 1]

    # Load the data so that we can sample from it.
    with gfile.Open(DSPRITES_PATH, "rb") as data_file:
        # Data was saved originally using python2, so we need to set the
        # encoding.
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only safe if DSPRITES_PATH points at a trusted file.
        archive = np.load(data_file, encoding="latin1", allow_pickle=True)
        self.images = np.array(archive["imgs"])
        # "metadata" is a 0-d array; `[()]` unwraps the object it holds.
        metadata = archive["metadata"][()]
        self.factor_sizes = np.array(
            metadata["latents_sizes"], dtype=np.int64)

    self.full_factor_sizes = [1, 3, 6, 40, 32, 32]
    # Total number of factor combinations divided by the running product of
    # the factor sizes.
    num_combinations = np.prod(self.factor_sizes)
    self.factor_bases = num_combinations / np.cumprod(self.factor_sizes)
    self.state_space = util.SplitDiscreteStateSpace(
        self.factor_sizes, self.latent_factor_indices)