示例#1
0
    def prior(self, batch_size, **kwargs):
        """ Draw one-hot samples from a uniform categorical prior.

        Passing ``soft_prior=True`` delegates to ``self._soft_prior`` instead.

        :param batch_size: number of prior samples.
        :returns: one-hot encoded prior samples, cast to
                  float_type(self.config['cuda'])
        :rtype: torch.Tensor

        """
        if kwargs.get('soft_prior') is True:
            return self._soft_prior(batch_size)  # softmax-based prior

        use_cuda = self.config['cuda']
        tensor_type = float_type(use_cuda)

        # every class gets the same probability 1 / output_size
        probs = tensor_type(1, self.output_size).fill_(1.0 / self.output_size)
        categorical = torch.distributions.Categorical(probs=probs)
        class_indices = categorical.sample((batch_size, ))

        encoded = one_hot(self.output_size, class_indices, use_cuda=use_cuda)
        return Variable(encoded).type(tensor_type)
示例#2
0
    def generate_synthetic_sequential_samples(self, model, num_rows=8):
        """ Decode random one-hot codes drawn from every discrete group of ``model``.

        The model's discrete latent is split into consecutive groups of
        ``self.config['discrete_size']`` dimensions. ``num_rows`` indices are
        drawn uniformly inside each group and one-hot encoded. For a mixture
        reparameterizer (when vae_type is not 'sequential'), they are prefixed
        with draws from the gaussian prior. Decoding runs under
        ``torch.no_grad()``.

        :param model: model providing has_discrete, reparameterizer, generate
                      and nll_activation
        :param num_rows: number of codes drawn per discrete group
        :returns: model.nll_activation applied to the decoded samples
        :rtype: torch.Tensor

        """
        assert model.has_discrete()

        # one row of random indices per original-sized discrete group
        # (shown as one row per group in visdom)
        group_size = self.config['discrete_size']
        total_discrete = model.reparameterizer.config['discrete_size']
        group_starts = range(0, total_discrete, group_size)
        group_ends = range(group_size, total_discrete + 1, group_size)
        index_grid = np.array([
            np.random.randint(start, stop, size=num_rows)
            for start, stop in zip(group_starts, group_ends)
        ])
        flat_indices = index_grid.reshape(-1)

        with torch.no_grad():
            encoded = torch.from_numpy(one_hot_np(total_discrete, flat_indices))
            z_samples = Variable(encoded).type(float_type(self.config['cuda']))

            is_mixture = self.config['reparam_type'] == 'mixture'
            if is_mixture and self.config['vae_type'] != 'sequential':
                # prepend draws from the gaussian prior
                z_gauss = model.reparameterizer.gaussian.prior(
                    z_samples.size(0))
                z_samples = torch.cat([z_gauss, z_samples], dim=-1)

            decoded = model.generate(z_samples)
            return model.nll_activation(decoded)
示例#3
0
文件: gumbel.py 项目: mdiephuis/vae-1
    def prior(self, batch_size, **kwargs):
        """ Draw one-hot samples from a uniform categorical prior.

        Extra keyword arguments are accepted and ignored.

        :param batch_size: number of prior samples.
        :returns: one-hot encoded prior samples, cast to
                  float_type(self.config['cuda'])
        :rtype: torch.Tensor

        """
        use_cuda = self.config['cuda']
        tensor_type = float_type(use_cuda)

        # every class gets the same probability 1 / output_size
        probs = tensor_type(1, self.output_size).fill_(1.0 / self.output_size)
        categorical = torch.distributions.Categorical(probs=probs)
        class_indices = categorical.sample((batch_size, ))

        encoded = one_hot(self.output_size, class_indices, use_cuda=use_cuda)
        return Variable(encoded).type(tensor_type)
示例#4
0
    def _reparametrize_gaussian(self, mu, logvar):
        """ Reparameterized sample from N(mu, exp(logvar)).

        In training mode returns z = mu + eps * std with eps ~ N(0, I) and
        std = exp(0.5 * logvar). In eval mode returns the mean ``mu`` itself.

        :param mu: mean tensor
        :param logvar: log-variance tensor (broadcastable with mu)
        :returns: (z, {'z': z, 'mu': mu, 'logvar': logvar})
        :rtype: tuple(torch.Tensor, dict)

        """
        if not self.training:
            return mu, {'z': mu, 'mu': mu, 'logvar': logvar}

        std = logvar.mul(0.5).exp()
        # randn_like takes shape, dtype and device from std, so the noise
        # always matches the inputs instead of a global 'cuda' flag that
        # forces float32 on a fixed device
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z, {'z': z, 'mu': mu, 'logvar': logvar}
示例#5
0
    def sample_gumbel(x, tau, hard=False, use_cuda=True):
        """ Sample a gumbel-softmax relaxation of x, optionally with a hard version.

        :param x: the input tensor
        :param tau: temperature
        :param hard: whether to also return a straight-through one-hot version
        :param use_cuda: forwarded to GumbelSoftmax._gumbel_softmax
        :returns: soft, hard or soft, None (both viewed as x)
        :rtype: torch.Tensor, Optional(torch.Tensor, None)

        """
        y = GumbelSoftmax._gumbel_softmax(x, tau, use_cuda=use_cuda)

        if hard:
            last_dim = y.dim() - 1
            # Scatter a single 1 at the argmax index. An equality test against
            # the max would leave several 1s in a row when the maximum is tied.
            # zeros_like also keeps y's dtype/device instead of forcing
            # float_type(use_cuda).
            _, y_argmax = torch.max(y, dim=last_dim, keepdim=True)
            y_hard = torch.zeros_like(y).scatter_(last_dim, y_argmax, 1.0)
            # straight-through estimator: forward value is y_hard,
            # gradient flows through the soft sample y
            y_hard = (y_hard - y).detach() + y
            return y.view_as(x), y_hard.view_as(x)

        return y.view_as(x), None
示例#6
0
    def mut_info(self, dist_params, batch_size):
        """ Mutual-information term between z and x.

        Returns a zero vector of length ``batch_size`` unless
        config['continuous_mut_info'] or config['discrete_mut_info'] is
        positive. In that case, the reparameterizer's mutual information for
        ``dist_params`` is passed through ``self._clamp_mut_info`` and returned.

        :param dist_params: the distribution dict
        :param batch_size: length of the zero vector used when disabled
        :returns: mutual-information tensor (shape of the enabled branch is
                  whatever _clamp_mut_info returns)
        :rtype: torch.Tensor

        """
        zeros = float_type(self.config['cuda'])(batch_size).zero_()

        # any() short-circuits, so 'discrete_mut_info' is only read when the
        # continuous scalar is not positive
        enabled = any(self.config[key] > 0
                      for key in ('continuous_mut_info', 'discrete_mut_info'))
        if not enabled:
            return zeros

        return self._clamp_mut_info(
            self.reparameterizer.mutual_info(dist_params))
示例#7
0
    def fork(self):
        """ Freeze the current student as the teacher and build a new student.

        Steps:
        1. Deep-copy the student into ``self.teacher``.
        2. Construct a fresh student of the configured ``vae_type``
           ('sequential' or 'parallel') for model index
           ``self.current_model + 1``.
        3. Run one forward pass on standard-normal data to build lazily
           created modules.
        4. Call ``self.copy_model`` to copy teacher parameters into the
           student.
        5. Increment ``self.current_model`` and set ``self.ratio`` to
           current_model / (current_model + 1).

        :raises Exception: if ``self.config['vae_type']`` is neither
            'sequential' nor 'parallel'. NOTE(review): this is raised after
            ``self.student`` has already been deleted.

        """
        # copy the old student into the teacher
        # dont increase discrete dim for ewc
        config_copy = deepcopy(self.student.config)
        # with EWC (ewc_gamma > 0) discrete_size is unchanged; otherwise it
        # grows by self.config['discrete_size']
        config_copy['discrete_size'] += 0 if self.config[
            'ewc_gamma'] > 0 else self.config['discrete_size']
        self.teacher = deepcopy(self.student)
        del self.student

        # create a new student
        # (the copied config is passed as a keyword argument named 'kwargs')
        if self.config['vae_type'] == 'sequential':
            self.student = SequentiallyReparameterizedVAE(
                input_shape=self.teacher.input_shape,
                num_current_model=self.current_model + 1,
                reparameterizer_strs=self.teacher.reparameterizer_strs,
                **{'kwargs': config_copy})
        elif self.config['vae_type'] == 'parallel':
            self.student = ParallellyReparameterizedVAE(
                input_shape=self.teacher.input_shape,
                num_current_model=self.current_model + 1,
                **{'kwargs': config_copy})
        else:
            raise Exception("unknown vae type requested")

        # forward pass once to build lazy modules
        # NOTE(review): unlike a model.eval()-then-forward pattern, this runs
        # in whatever train/eval mode the new student is in; if training, BN
        # running stats may absorb this random noise -- confirm whether
        # copy_model overwrites them.
        data = float_type(
            self.config['cuda'])(self.student.config['batch_size'],
                                 *self.student.input_shape).normal_()
        self.student(Variable(data))

        # copy teacher params into student while
        # omitting the projection weights
        self.teacher, self.student \
            = self.copy_model(self.teacher, self.student, disable_dst_grads=False)

        # update the current model's ratio
        self.current_model += 1
        self.ratio = self.current_model / (self.current_model + 1.0)
        # the sample counts below are only computed for the log line
        num_teacher_samples = int(self.config['batch_size'] * self.ratio)
        num_student_samples = max(
            self.config['batch_size'] - num_teacher_samples, 1)
        print("#teacher_samples: ", num_teacher_samples,
              " | #student_samples: ", num_student_samples)
示例#8
0
文件: gumbel.py 项目: mdiephuis/vae-1
    def sample_gumbel(x, tau, hard=False, dim=-1, use_cuda=False):
        """ Sample from the gumbel distribution and return hard and soft versions.

        :param x: the input tensor
        :param tau: temperature
        :param hard: whether to generate hard version (argmax)
        :param dim: dimension to operate over
        :param use_cuda: forwarded to GumbelSoftmax._gumbel_softmax
        :returns: soft, hard or soft, None (both viewed as x)
        :rtype: torch.Tensor, Optional(torch.Tensor, None)

        """
        y = GumbelSoftmax._gumbel_softmax(x, tau, dim=dim, use_cuda=use_cuda)

        if hard:
            # Scatter a single 1 at the argmax index. An equality test against
            # the max would leave several 1s in a row when the maximum is tied.
            # zeros_like also keeps y's dtype/device instead of forcing
            # float_type(use_cuda).
            _, y_argmax = torch.max(y, dim=dim, keepdim=True)
            y_hard = torch.zeros_like(y).scatter_(dim, y_argmax, 1.0)
            # straight-through estimator: forward value is y_hard,
            # gradient flows through the soft sample y
            y_hard = (y_hard - y).detach() + y
            return y.view_as(x), y_hard.view_as(x)

        return y.view_as(x), None
    def loss_function(self, recon_x, x, params, mut_info=None):
        """ Per-sample loss = nll + kl_reg * kld - (clamped) mutual information.

        :param recon_x: reconstruction passed to nll_fn
        :param x: the original input; x.size(0) sizes the default mut-info
        :param params: distribution parameters forwarded to self.kld
        :param mut_info: optional mutual-information term; zeros when None,
                         otherwise transformed by config['mut_clamp_strategy']
                         ('none', 'norm' or 'clamp')
        :returns: dict with the per-sample 'loss' and batch means of each term
        :rtype: dict
        :raises KeyError: for an unknown mut_clamp_strategy

        """
        # tf: elbo = -log_likelihood + latent_kl
        # tf: cost = elbo + consistency_kl - self.mutual_info_reg * mutual_info_regularizer
        nll = nll_fn(x, recon_x, self.config['nll_type'])
        kld = self.config['kl_reg'] * self.kld(params)
        elbo = nll + kld

        # handle the mutual information term
        if mut_info is None:
            mut_info = Variable(
                float_type(self.config['cuda'])(x.size(0)).zero_())
        else:
            # Clamping strategies.
            # 'norm' clamps the denominator away from zero: an all-zero
            # mut_info would otherwise give 0/0 = NaN and poison the loss.
            mut_clamp_strategy_map = {
                'none':
                lambda mut_info: mut_info,
                'norm':
                lambda mut_info: mut_info / torch.clamp(
                    torch.norm(mut_info, p=2), min=1e-12),
                'clamp':
                lambda mut_info: torch.clamp(
                    mut_info,
                    min=-self.config['mut_clamp_value'],
                    max=self.config['mut_clamp_value'])
            }
            mut_info = mut_clamp_strategy_map[
                self.config['mut_clamp_strategy'].strip().lower()](mut_info)

        loss = elbo - mut_info
        return {
            'loss': loss,
            'loss_mean': torch.mean(loss),
            'elbo_mean': torch.mean(elbo),
            'nll_mean': torch.mean(nll),
            'kld_mean': torch.mean(kld),
            'mut_info_mean': torch.mean(mut_info)
        }
示例#10
0
    def generate_synthetic_sequential_samples(self,
                                              num_original_discrete,
                                              num_rows=8):
        """ Iterates over all discrete positions and generates samples (for mix or disc only).

        The discrete latent is split into consecutive groups of
        ``num_original_discrete`` dimensions. ``num_rows`` one-hot codes are
        drawn uniformly inside each group, and each code is decoded. Leaves
        the model in eval mode.

        :param num_original_discrete: The original discrete size (useful for LLVAE).
        :param num_rows: number of codes drawn per discrete group (rows in visdom)
        :returns: decoded logits, one row per drawn code, in group order
        :rtype: torch.Tensor

        """
        assert self.has_discrete()

        # create a grid of one-hot vectors for displaying in visdom
        # uses one row for original dimension of discrete component
        total_discrete = self.reparameterizer.config['discrete_size']
        discrete_indices = np.array([
            np.random.randint(begin, end, size=num_rows) for begin, end in zip(
                range(0, total_discrete, num_original_discrete),
                range(num_original_discrete, total_discrete + 1,
                      num_original_discrete))
        ])
        discrete_indices = discrete_indices.reshape(-1)

        self.eval()  # lock BN / Dropout, etc
        with torch.no_grad():
            z_samples = Variable(
                torch.from_numpy(
                    one_hot_np(total_discrete, discrete_indices)))
            z_samples = z_samples.type(float_type(self.config['cuda']))

            if self.config['reparam_type'] == 'mixture' and self.config[
                    'vae_type'] != 'sequential':
                # add in the continuous prior
                z_cont = self.reparameterizer.continuous.prior(
                    z_samples.size(0))
                z_samples = torch.cat([z_cont, z_samples], dim=-1)

            # The decoder is always run on exactly batch_size rows (BN issues).
            # Pad z by cycling it up to a whole number of batches, then decode
            # each batch-sized slice in turn. Previously z was truncated to
            # batch_size and that same slice was decoded repeatedly, so rows
            # past batch_size never came from their own z.
            batch_size = self.config['batch_size']
            number_to_return = z_samples.shape[0]  # original generate number
            number_batches = int(
                max(1,
                    np.ceil(float(number_to_return) / float(batch_size))))
            padded_size = number_batches * batch_size
            number_repeats = int(
                np.ceil(float(padded_size) / float(number_to_return)))
            z_padded = torch.cat([z_samples] * number_repeats,
                                 0)[0:padded_size]

            generated = torch.cat([
                self.generate_synthetic_samples(
                    batch_size,
                    z_samples=z_padded[i * batch_size:(i + 1) * batch_size])
                for i in range(number_batches)
            ], 0)
            return generated[0:number_to_return]  # only return num_requested
示例#11
0
def lazy_generate_modules(model, img_shp, cuda=None, batch_size=None):
    ''' Super hax, but needed for building lazy modules

    Puts ``model`` in eval mode (it is left there) and runs one forward pass
    on standard-normal data of shape (batch_size, *img_shp).

    :param model: the model to build
    :param img_shp: per-sample input shape
    :param cuda: passed to float_type; defaults to the module-level ``args.cuda``
    :param batch_size: number of random inputs; defaults to ``args.batch_size``
    :returns: None

    '''
    # fall back to the module-level ``args`` only for values not supplied,
    # so callers passing both do not depend on that global
    if cuda is None:
        cuda = args.cuda
    if batch_size is None:
        batch_size = args.batch_size

    model.eval()
    data = float_type(cuda)(batch_size, *img_shp).normal_()
    model(Variable(data))
示例#12
0
 def prior(self, batch_size, **kwargs):
     """ Draw batch_size samples from a zero-mean isotropic gaussian prior.

     :param batch_size: number of prior samples
     :param scale_var: (keyword, default 1.0) NOTE(review): despite the name,
                       this is passed to normal_ as the standard deviation
     :returns: tensor of shape (batch_size, self.output_size)
     :rtype: torch.Tensor

     """
     std = kwargs.get('scale_var', 1.0)
     noise = float_type(self.config['cuda'])(batch_size, self.output_size)
     return Variable(noise.normal_(mean=0, std=std))
示例#13
0
 def _prior_distribution(self, batch_size):
     """ Bernoulli prior with every probability equal to 0.5.

     :param batch_size: number of rows in the probability tensor
     :returns: Bernoulli over a (batch_size, self.output_size) probability grid
     :rtype: torch.distributions.Bernoulli

     """
     tensor_type = float_type(self.config['cuda'])
     half_probs = tensor_type(batch_size, self.output_size).fill_(0.5)
     return torch.distributions.Bernoulli(probs=half_probs)