コード例 #1
0
ファイル: norb.py プロジェクト: ThomasMrY/dis_lib
 def __init__(self):
     """Load the small NORB chunks and build the factor index and state space."""
     images, features = _load_small_norb_chunks(SMALLNORB_TEMPLATE,
                                                SMALLNORB_CHUNKS)
     self.images = images
     self.image_leng = images.shape[0]

     sizes = [5, 10, 9, 18, 6]
     # Index 1 is left out: instances are not part of the latent space.
     latent_indices = [0, 2, 3, 4]
     self.factor_sizes = sizes
     self.latent_factor_indices = latent_indices
     # One column of `features` per ground-truth factor.
     self.num_total_factors = features.shape[1]

     self.index = util.StateSpaceAtomIndex(sizes, features)
     self.state_space = util.SplitDiscreteStateSpace(sizes, latent_indices)
コード例 #2
0
 def __init__(self):
   """Enumerate all factor combinations, index them, then load the images."""
   sizes = [4, 24, 183]
   self.factor_sizes = sizes
   self.latent_factor_indices = [0, 1, 2]

   # Every combination of factor values; `cartesian` is expected to return
   # one column per factor (its shape[1] is used as the factor count below).
   per_factor_values = [np.arange(size) for size in sizes]
   features = cartesian(per_factor_values)
   self.num_total_factors = features.shape[1]

   atom_index = util.StateSpaceAtomIndex(sizes, features)
   self.index = atom_index
   self.factor_bases = atom_index.factor_bases
   self.state_space = util.SplitDiscreteStateSpace(
       sizes, self.latent_factor_indices)

   self.data_shape = [64, 64, 3]
   # Kept last so every attribute above is set before _load_data() runs.
   self.images = self._load_data()
コード例 #3
0
    def __init__(self, mode="mpi3d_toy"):
        """Load one MPI3D variant and set up its factor bookkeeping.

        Args:
            mode: Variant to load: "mpi3d_toy", "mpi3d_realistic" or
                "mpi3d_real". The archive is looked up at
                ``$DISENTANGLEMENT_LIB_DATA/<mode>/<mode>.npz``; the current
                directory is used when the variable is unset.

        Raises:
            ValueError: If ``mode`` is not a known variant, or if the archive
                for that variant does not exist.
        """
        # The variants share one loading path and differ only in the sizes
        # of the first two factors.
        factor_sizes_by_mode = {
            "mpi3d_toy": [4, 4, 2, 3, 3, 40, 40],
            "mpi3d_realistic": [4, 4, 2, 3, 3, 40, 40],
            "mpi3d_real": [6, 6, 2, 3, 3, 40, 40],
        }
        if mode not in factor_sizes_by_mode:
            raise ValueError("Unknown mode provided.")

        mpi3d_path = os.path.join(
            os.environ.get("DISENTANGLEMENT_LIB_DATA", "."), mode,
            mode + ".npz")
        if not tf.io.gfile.exists(mpi3d_path):
            raise ValueError(
                "Dataset '{}' not found. Make sure the dataset is publicly available and downloaded correctly."
                .format(mode))

        with tf.io.gfile.GFile(mpi3d_path, "rb") as f:
            # np.load on an .npz archive reads members lazily, so pull the
            # array out while the file handle is still open.
            self.images = np.load(f)["images"]

        self.factor_sizes = factor_sizes_by_mode[mode]
        self.image_leng = self.images.shape[0]
        self.latent_factor_indices = [0, 1, 2, 3, 4, 5, 6]
        self.num_total_factors = 7
        self.state_space = util.SplitDiscreteStateSpace(
            self.factor_sizes, self.latent_factor_indices)
        # factor_bases[i] is the product of the sizes of all factors after i
        # (true division, so this is a float array).
        self.factor_bases = np.prod(self.factor_sizes) / np.cumprod(
            self.factor_sizes)
コード例 #4
0
 def __init__(self):
     """Load the 3D Shapes HDF5 file and set up factor bookkeeping.

     Reads the "images" and "labels" datasets from SHAPES3D_PATH into
     memory, reshapes images to (n_samples, 64, 64, 3), converts them to
     float32 and divides by 255.
     """
     # `[()]` reads each dataset fully into memory, so the file can be
     # closed as soon as the block exits instead of leaking the handle.
     with h5py.File(SHAPES3D_PATH, 'r') as data:
         images = data['images'][()]
         labels = data["labels"][()]
     # Equals the product of factor_sizes below (10 * 10 * 10 * 8 * 4 * 15).
     n_samples = 480000
     self.images = (
         images.reshape([n_samples, 64, 64, 3]).astype(np.float32) / 255.)
     features = labels.reshape([n_samples, 6])
     self.factor_sizes = [10, 10, 10, 8, 4, 15]
     self.latent_factor_indices = list(range(6))
     self.num_total_factors = features.shape[1]
     self.state_space = util.SplitDiscreteStateSpace(
         self.factor_sizes, self.latent_factor_indices)
     # factor_bases[i] is the product of the sizes of all factors after i.
     self.factor_bases = np.prod(self.factor_sizes) / np.cumprod(
         self.factor_sizes)
コード例 #5
0
ファイル: shapes3d.py プロジェクト: ThomasMrY/dis_lib
 def __init__(self):
   """Load the 3D Shapes archive and set up factor bookkeeping.

   Reads "images" and "labels" from SHAPES3D_PATH, flattens the first six
   image axes into a single sample axis, converts images to float32 and
   divides them by 255.
   """
   with tf.gfile.GFile(SHAPES3D_PATH, "rb") as f:
     # Data was saved originally using python2, so we need to set the encoding.
     data = np.load(f, encoding="latin1")
     # `data` is read by key, i.e. it is an archive whose members are loaded
     # on access; fetch both arrays while the handle is still open.
     images = data["images"]
     labels = data["labels"]
   # The first six axes of `images` are collapsed into the sample axis.
   n_samples = np.prod(images.shape[0:6])
   self.images = (
       images.reshape([n_samples, 64, 64, 3]).astype(np.float32) / 255.)
   self.image_leng = self.images.shape[0]
   features = labels.reshape([n_samples, 6])
   self.factor_sizes = [10, 10, 10, 8, 4, 15]
   self.latent_factor_indices = list(range(6))
   self.num_total_factors = features.shape[1]
   self.state_space = util.SplitDiscreteStateSpace(self.factor_sizes,
                                                   self.latent_factor_indices)
   # factor_bases[i] is the product of the sizes of all factors after i.
   self.factor_bases = np.prod(self.factor_sizes) / np.cumprod(
       self.factor_sizes)
コード例 #6
0
ファイル: dsprites.py プロジェクト: ThomasMrY/DisCo
 def __init__(self, latent_factor_indices=None):
     """Load the dSprites archive and set up factor bookkeeping.

     Args:
         latent_factor_indices: Indices of the factors treated as ground
             truth. When None, all six factors (including shape) are used.
     """
     self.latent_factor_indices = (
         list(range(6)) if latent_factor_indices is None
         else latent_factor_indices)
     self.data_shape = [64, 64, 1]

     # Keep the images and factor sizes in memory so they can be sampled.
     with gfile.Open(DSPRITES_PATH, "rb") as data_file:
         # The archive was written by python2, hence the explicit encoding.
         archive = np.load(data_file, encoding="latin1", allow_pickle=True)
         self.images = np.array(archive["imgs"])
         # "metadata" is a 0-d array wrapping a dict; `[()]` unwraps it.
         metadata = archive["metadata"][()]
         self.factor_sizes = np.array(metadata["latents_sizes"],
                                      dtype=np.int64)

     self.full_factor_sizes = [1, 3, 6, 40, 32, 32]
     total_states = np.prod(self.factor_sizes)
     self.factor_bases = total_states / np.cumprod(self.factor_sizes)
     self.state_space = util.SplitDiscreteStateSpace(
         self.factor_sizes, self.latent_factor_indices)