Example #1

import numpy as np
import tensorflow as tf

# GANModel, Product, Gaussian, Categorical, Bernoulli and make_list are
# assumed to be provided by the surrounding project.


class RegularizedGAN(GANModel):
    def __init__(self, latent_spec, **kwargs):
        """
        Args:
            latent_spec (list): List of (Distribution, bool) pairs, one per
             latent block. The boolean indicates whether the distribution
             should be used for regularization.
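
        Example (hypothetical dimensions; the distribution constructors are
         assumed to take the code dimension as their first argument):
            latent_spec = [(Gaussian(62), False),    # noise, not regularized
                           (Categorical(10), True),  # regularized 10-way code
                           (Gaussian(1), True)]      # regularized continuous code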
        """

        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        super(RegularizedGAN, self).__init__(**kwargs)
        d = {
            'latent_code_influence': {
                'tensor': 'get_latent_code_influence_g_input_tensor',
            },
            'linear_interpolation': {
                'tensor': 'get_linear_interpolation_g_input_tensor',
            },
        }
        self.sampling_functions.update(d)

    def __getstate__(self):
        pickling_dict = super(RegularizedGAN, self).__getstate__()
        del pickling_dict['encoder_template']
        return pickling_dict

    def build_g_input(self):
        return self.latent_dist.sample_prior(self.batch_size)

    def get_g_feed_dict(self):
        return None

    def get_reg_dist_info(self, x_var):
        reg_dist_flat = self.encoder_template.construct(input=x_var)
        reg_dist_info = self.reg_latent_dist.activate_dist(reg_dist_flat)
        return reg_dist_info

    def disc_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(z_i)
        return self.reg_disc_latent_dist.join_vars(ret)

    def cont_reg_z(self, reg_z_var):
        ret = []
        for dist_i, z_i in zip(self.reg_latent_dist.dists,
                               self.reg_latent_dist.split_var(reg_z_var)):
            if isinstance(dist_i, Gaussian):
                ret.append(z_i)
        return self.reg_cont_latent_dist.join_vars(ret)

    def disc_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, (Categorical, Bernoulli)):
                ret.append(dist_info_i)
        return self.reg_disc_latent_dist.join_dist_infos(ret)

    def cont_reg_dist_info(self, reg_dist_info):
        ret = []
        for dist_i, dist_info_i in zip(
                self.reg_latent_dist.dists,
                self.reg_latent_dist.split_dist_info(reg_dist_info)):
            if isinstance(dist_i, Gaussian):
                ret.append(dist_info_i)
        return self.reg_cont_latent_dist.join_dist_infos(ret)

    def reg_z(self, z_var=None):
        """ Return the variables with distribution bool == True (concatenated). """
        ret = []
        if z_var is None:
            z_var = self.g_input
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            if reg_i:
                ret.append(z_i)
        return self.reg_latent_dist.join_vars(ret)

    def nonreg_z(self, z_var):
        ret = []
        for (_, reg_i), z_i in zip(self.latent_spec,
                                   self.latent_dist.split_var(z_var)):
            if not reg_i:
                ret.append(z_i)
        return self.nonreg_latent_dist.join_vars(ret)

    def reg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if reg_i:
                ret.append(dist_info_i)
        return self.reg_latent_dist.join_dist_infos(ret)

    def nonreg_dist_info(self, dist_info):
        ret = []
        for (_, reg_i), dist_info_i in zip(
                self.latent_spec, self.latent_dist.split_dist_info(dist_info)):
            if not reg_i:
                ret.append(dist_info_i)
        return self.nonreg_latent_dist.join_dist_infos(ret)

    def combine_reg_nonreg_z(self, reg_z_var, nonreg_z_var):
        reg_z_vars = self.reg_latent_dist.split_var(reg_z_var)
        reg_idx = 0
        nonreg_z_vars = self.nonreg_latent_dist.split_var(nonreg_z_var)
        nonreg_idx = 0
        ret = []
        for _, reg_i in self.latent_spec:
            if reg_i:
                ret.append(reg_z_vars[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_z_vars[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_vars(ret)

    def combine_reg_nonreg_dist_info(self, reg_dist_info, nonreg_dist_info):
        reg_dist_infos = self.reg_latent_dist.split_dist_info(reg_dist_info)
        reg_idx = 0
        nonreg_dist_infos = self.nonreg_latent_dist.split_dist_info(
            nonreg_dist_info)
        nonreg_idx = 0
        ret = []
        for _, reg_i in self.latent_spec:
            if reg_i:
                ret.append(reg_dist_infos[reg_idx])
                reg_idx += 1
            else:
                ret.append(nonreg_dist_infos[nonreg_idx])
                nonreg_idx += 1
        return self.latent_dist.join_dist_infos(ret)

    ###########################################################################
    # SAMPLING
    ###########################################################################
    def get_train_g_input_tensor(self):
        if len(self.reg_latent_dist.dists) > 0:
            sampling_type = 'latent_code_influence'
        else:
            sampling_type = 'random'
        return make_list(self.get_g_input_tensor(sampling_type))

    def get_g_input_value(self, sampling_type, **kwargs):
        # TODO: self.get_g_input_tensor should return a Tensor and
        # this method should .eval() its result with a tf.Session()
        if sampling_type == 'random':
            return self.get_g_input(sampling_type, 'value', **kwargs)
        return self.get_g_input_tensor(sampling_type, **kwargs)

    def get_linear_interpolation_g_input_tensor(self,
                                                n_samples=10,
                                                n_variations=10):
        """
        Returns:
            (ndarray, str):
        """
        # TODO: return a tensor instead of ndarray
        with tf.Session():
            n = n_samples * n_variations
            all_z_start = self.latent_dist.sample_prior(n_samples).eval()
            all_z_end = self.latent_dist.sample_prior(n_samples).eval()
            coefficients = np.linspace(start=0, stop=1, num=n_variations)
            z_var = []
            for z_start, z_end in zip(all_z_start, all_z_end):
                for coeff in coefficients:
                    # coeff == 0 yields z_start, coeff == 1 yields z_end
                    z_var.append((1 - coeff) * z_start + coeff * z_end)
            other = self.latent_dist.sample_prior(self.batch_size - n).eval()
            z_var = np.concatenate([z_var, other], axis=0)
            z_var = np.asarray(z_var, dtype=np.float32)
            return z_var, 'linear_interpolations'

    def get_latent_code_influence_g_input_tensor(self,
                                                 n_samples=10,
                                                 n_variations=10,
                                                 min_continuous=-2.,
                                                 max_continuous=2.):
        """
        Args:
            n_samples (int): The number of different samples (n_columns).
            n_variations (int): The number of variations for each latent code
             (n_rows).
            min_continuous (float): The minimum value for a continuous latent
             code
            max_continuous (float): The maximum value for a continuous latent
             code
        Returns:
            (ndarray, str):
        """
        # TODO: return a tensor instead of ndarray
        if len(self.reg_latent_dist.dists) == 0:
            raise ValueError('The model must have at least one regularization '
                             'latent distribution.')
        with tf.Session():
            # (n, d) with 10 * 10 samples + other samples
            n = n_samples * n_variations
            fixed_noncat = self.nonreg_latent_dist.sample_prior(n_samples)
            fixed_noncat = np.tile(fixed_noncat.eval(), [n_variations, 1])
            other = self.nonreg_latent_dist.sample_prior(self.batch_size - n)
            other = other.eval()
            fixed_noncat = np.concatenate([fixed_noncat, other], axis=0)

            fixed_cat = self.reg_latent_dist.sample_prior(n_samples).eval()
            fixed_cat = np.tile(fixed_cat, [n_variations, 1])
            other = self.reg_latent_dist.sample_prior(self.batch_size - n)
            other = other.eval()
            fixed_cat = np.concatenate([fixed_cat, other], axis=0)

        offset = 0
        z_vars_and_names = []
        for dist_idx, dist in enumerate(self.reg_latent_dist.dists):
            if isinstance(dist, Gaussian):
                assert dist.dim == 1, "Only dim=1 is currently supported"
                vary_cat = dist.varying_values(min_continuous, max_continuous,
                                               n_variations)
                vary_cat = np.repeat(vary_cat, n_samples)
                other = np.zeros(self.batch_size - n)
                vary_cat = np.concatenate((vary_cat, other)).reshape((-1, 1))
                vary_cat = np.asarray(vary_cat, dtype=np.float32)

                cur_cat = np.copy(fixed_cat)
                cur_cat[:, offset:offset + 1] = vary_cat
                offset += 1
            elif isinstance(dist, Categorical):
                lookup = np.eye(dist.dim, dtype=np.float32)
                cat_ids = []
                for idx in range(n_variations):
                    cat_ids.extend([idx] * n_samples)
                cat_ids.extend([0] * (self.batch_size - n))
                cur_cat = np.copy(fixed_cat)
                cur_cat[:, offset:offset + dist.dim] = lookup[cat_ids]
                offset += dist.dim
            elif isinstance(dist, Bernoulli):
                assert dist.dim == 1, "Only dim=1 is currently supported"
                cat_ids = []
                for idx in range(n_variations):
                    # Maps the first five variation blocks to 0 and the rest
                    # to 1 (assumes n_variations <= 10, as with the default).
                    cat_ids.extend([int(idx / 5)] * n_samples)
                cat_ids.extend([0] * (self.batch_size - n))
                cur_cat = np.copy(fixed_cat)
                cur_cat[:, offset:offset + dist.dim] = np.expand_dims(
                    np.array(cat_ids), axis=-1)
                offset += dist.dim
            else:
                raise NotImplementedError
            # (n, d)
            # The first n_samples rows have different z and c and are tiled
            # n_variations times, except for the varying code, which is
            # constant within each block of n_samples rows and varies
            # linearly across blocks.
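            # A hypothetical layout with n_samples=2, n_variations=2:
            #   row 0: (z_a, c_a) with varying code v_0
            #   row 1: (z_b, c_b) with varying code v_0
            #   row 2: (z_a, c_a) with varying code v_1
            #   row 3: (z_b, c_b) with varying code v_1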
            z_var = np.concatenate([fixed_noncat, cur_cat], axis=1)

            # Each column of the resulting image grid has a different fixed z
            # and c; the varying code changes along the column (a different
            # value in each row).
            name = 'image_{}_{}'.format(dist_idx, dist.__class__.__name__)
            z_vars_and_names.append((z_var, name))
        return z_vars_and_names
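

# A minimal usage sketch, assuming constructor signatures Gaussian(dim) and
# Categorical(dim) and a `batch_size` kwarg on GANModel (all hypothetical;
# adjust to the surrounding project):
#
#     latent_spec = [(Gaussian(62), False),    # unregularized noise
#                    (Categorical(10), True),  # regularized discrete code
#                    (Gaussian(1), True)]      # regularized continuous code
#     model = RegularizedGAN(latent_spec=latent_spec, batch_size=128)
#     z = model.build_g_input()   # prior sample fed to the generator
#     reg_z = model.reg_z(z)      # regularized slice of z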