Exemplo n.º 1
0
    def decide(self, choices: List[any]) -> int:
        """Pick one of ``choices`` by sampling the policy and record the estimate.

        Every choice is converted to a ``torch.FloatTensor`` and run through
        the base network. The resulting features are detached and scored by
        the policy-gradient model. A single draw from a one-trial Multinomial
        over the softmax of the concatenated scores selects the move.

        For the selected features, the value-function estimate and the draw's
        log-probability are appended to ``self.rounds`` as a ``Round``.

        Returns:
            Index of the selected choice as a Python int.
        """
        base_model = self._base_network.model
        policy_model = self._policy_gradient.model

        features = [base_model.forward(torch.FloatTensor(c)) for c in choices]
        # detached, so the policy loss does not backpropagate into the base network
        scores = [policy_model.forward(f.detach()) for f in features]

        # One draw from a single-trial multinomial over the choices.
        probabilities = Function.softmax(torch.cat(scores))
        distribution = Multinomial(1, probabilities)
        one_hot_move = distribution.sample()
        index_of_move = one_hot_move.max(0)[1]

        expected_reward = self._value_function.model(features[index_of_move])
        log_probability = distribution.log_prob(one_hot_move)

        self.rounds.append(
            Round(value=expected_reward, log_probability=log_probability))
        return index_of_move.item()
Exemplo n.º 2
0
def select_action(state):
    """Sample an action from the global ``policy`` for one environment state.

    ``state`` is converted from a NumPy array to a float tensor with a
    leading batch axis of size 1 and passed through ``policy``. The sampled
    action's log-probability is appended to ``policy.saved_log_probs`` for
    the later policy-gradient update.

    Returns:
        The sampled action index as a Python int.
    """
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = policy(Variable(state))
    # A single draw over action probabilities is a Categorical distribution.
    # In the current torch.distributions API the first positional argument of
    # Multinomial is ``total_count``, so ``Multinomial(probs)`` cannot work.
    m = torch.distributions.Categorical(probs)
    action = m.sample()
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item()
Exemplo n.º 3
0
    def _get_samples(self, n, attributes=None, split_id=None):
        """Draw ``n`` samples spread over the atomic concepts by ``self.weights``.

        How many samples each atomic concept contributes is itself drawn from
        a Multinomial with ``n`` trials and probabilities ``self.weights``.
        Concepts that receive zero samples are skipped.

        Returns:
            ``(samples, sample_attributes)``:
            - ``samples`` are the per-concept samples concatenated with
              ``torch.cat`` when they are tensors, otherwise with
              ``np.concatenate``.
            - ``sample_attributes`` are the concatenated attributes, or an
              empty tensor when no attributes were requested.

        NOTE(review): raises IndexError when no concept receives a sample
        (e.g. ``n == 0``).
        """
        attributes = attributes or []
        counts = Multinomial(n, probs=self.weights).sample().long().tolist()

        collected, collected_attrs = [], []
        for concept, count in zip(self.get_atomic_concepts(), counts):
            if not count:
                continue
            c_samples, c_attrs = concept._get_samples(
                count, attributes=attributes, split_id=split_id)
            collected.append(c_samples)
            collected_attrs.append(c_attrs)

        merged_attrs = (torch.cat(collected_attrs)
                        if attributes else torch.Tensor())
        join = torch.cat if torch.is_tensor(collected[0]) else np.concatenate
        return join(collected), merged_attrs
Exemplo n.º 4
0
    def forward(self, input):
        """Apply the (optionally ternary-quantized) linear layer to ``input``.

        When ``self.quant`` is set, each weight has a distribution
        ``p = [p_w_neg, p_w_0, p_w_pos]`` over ``self.w_candidate``, matched
        position-wise. ``p_w_0 = sigmoid(p_a)``, and the remaining mass is
        split by ``sigmoid(p_b)`` between the last and first entries.

        - Training: the output is sampled from a Gaussian. Its mean is
          ``F.linear(input, E[w], bias)`` and its variance is
          ``F.linear(input**2, Var[w]) + eps`` (local reparameterization).
        - Evaluation: one discrete weight per entry is drawn from a
          Multinomial over ``p`` and used in a plain ``F.linear``.

        Without ``self.quant`` the parent class' ``forward`` is used.

        Returns:
            The layer output ``y``.
        """
        if self.quant:
            p_a = torch.sigmoid(self.p_a)
            p_b = torch.sigmoid(self.p_b)
            p_w_0 = p_a
            p_w_pos = p_b * (1. - p_w_0)
            p_w_neg = (1. - p_b) * (1. - p_w_0)
            p = torch.stack([p_w_neg, p_w_0, p_w_pos], dim=-1)
            if self.training:
                # moments of the per-weight categorical distribution
                w_mean = (p * self.w_candidate).sum(dim=-1)
                w_var = (p *
                         self.w_candidate.pow(2)).sum(dim=-1) - w_mean.pow(2)
                act_mean = F.linear(input, w_mean, self.bias)
                act_var = F.linear(input.pow(2), w_var, None)
                var_eps = torch.randn_like(act_mean)
                y = act_mean + var_eps * act_var.add(self.eps).sqrt()
            else:
                m = Multinomial(probs=p)
                # one-hot draw -> index of the chosen candidate per weight
                indices = m.sample().argmax(dim=-1)
                w = self.w_candidate[indices]
                y = F.linear(input, w, self.bias)
        else:
            y = super().forward(input)

        return y
Exemplo n.º 5
0
def select_action(state):
    """Sample an action from the actor-critic ``model`` for one state.

    ``state`` is converted from a NumPy array to a float tensor with a
    leading batch axis of size 1. ``model`` returns ``(probs, state_value)``.
    The sampled action's log-probability and ``state_value`` are stored in
    ``model.saved_actions`` as a ``SavedAction``.

    Returns:
        The sampled action index as a Python int.
    """
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs, state_value = model(Variable(state))
    # A single draw over action probabilities is a Categorical distribution;
    # Multinomial's first positional argument is ``total_count``.
    m = torch.distributions.Categorical(probs)
    action = m.sample()
    model.saved_actions.append(SavedAction(m.log_prob(action), state_value))
    return action.item()
Exemplo n.º 6
0
def select_action(state, variance=1, temp=10):
    """Sample a vaccine allocation from the actor's temperature-scaled policy.

    Args:
        state: NumPy array describing the state. It is converted to a float
            tensor with a leading batch axis of size 1.
        variance: unused; kept for interface compatibility.
        temp: softmax temperature; larger values flatten the distribution.

    Returns:
        ``(action, log_prob, entropy)``:
        - ``action``: NumPy array of counts drawn from a Multinomial with
          the global ``vaccine_supply`` trials.
        - ``log_prob``: that allocation's log-probability.
        - ``entropy``: ``-sum(p * log p)`` of the per-category distribution,
          summed over dim 1 with the dimension kept.
    """
    state = torch.from_numpy(state).float().unsqueeze(0)
    action_scores = actor(state)
    scaled = action_scores / temp
    prob = F.softmax(scaled, dim=1)
    # log_softmax is the numerically stable log of ``prob``
    log_p = F.log_softmax(scaled, dim=1)
    m = Multinomial(vaccine_supply, prob[0])
    action = m.sample()
    log_prob = m.log_prob(action)
    # Entropy needs the per-category log-probabilities, not the sampled
    # allocation's log_prob.
    entropy = -(prob * log_p).sum(1, keepdim=True)
    return action.numpy(), log_prob, entropy
Exemplo n.º 7
0
def sample_with_weights(values: torch.Tensor, weights: torch.Tensor,
                        num_samples: int) -> torch.Tensor:
    """Draw ``num_samples`` entries of ``values`` with replacement.

    Each draw selects index ``i`` with probability proportional to
    ``weights[i]``. The selected entries are returned in draw order.
    """
    one_hot_draws = Multinomial(probs=weights).sample(
        sample_shape=(num_samples, ))
    # each row is one-hot; its nonzero column is the drawn index
    drawn = one_hot_draws.nonzero(as_tuple=True)[1]
    return values[drawn]
Exemplo n.º 8
0
    def test_multinomial_2d(self):
        """Check sample shape, log_prob gradients and log_prob values for a
        batch of two unnormalized 3-category weight rows.

        NOTE(review): this targets the legacy single-argument
        ``Multinomial(p)`` API. In current ``torch.distributions`` the first
        positional argument of ``Multinomial`` is ``total_count``.
        """
        weights = Variable(
            torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]),
            requires_grad=True)
        self.assertEqual(Multinomial(weights).sample().size(), (2, ))
        self._gradcheck_log_prob(Multinomial, (weights, ))

        def ref_log_prob(idx, val, log_prob):
            # reference: normalized probability of category ``val`` in row ``idx``
            row = weights.data[idx]
            self.assertEqual(log_prob, math.log(row[val] / row.sum()))

        self._check_log_prob(Multinomial(weights), ref_log_prob)
Exemplo n.º 9
0
    def sample(self, num_samples):
        """Draw ``num_samples`` points from the Gaussian mixture.

        For each sample, a component is chosen from a one-trial Multinomial
        with logits ``self.log_weight``. The one-hot choice selects that
        component's ``mean + exp(log_std / 2) * noise`` with standard-normal
        ``noise``.

        Shape assumption (TODO confirm): ``self.means`` and ``self.log_std``
        are ``(n_components, num_vars)``, giving an output of
        ``(num_samples, num_vars)``.

        NOTE(review): halving ``log_std`` suggests it actually stores a
        log-variance. Confirm the naming.

        Returns:
            Tensor on ``self.means.device``.
        """
        device = self.means.device
        noise = torch.randn(num_samples, self.num_vars).to(device)

        comp_sampler = Multinomial(logits=self.log_weight)
        # One-hot component indicators, shape (num_samples, n_components).
        # Keep them on the parameters' device rather than forcing CUDA, so
        # CPU-only use works; sample((n,)) replaces the deprecated sample_n.
        components = comp_sampler.sample((num_samples, )).to(device)

        return (components.unsqueeze(-1) *
                (self.means.unsqueeze(0) +
                 torch.exp(self.log_std / 2).unsqueeze(0) * noise.unsqueeze(1))
                ).sum(dim=1)
Exemplo n.º 10
0
def select_action(state, variance=1, temp=10):
    """Sample a vaccine allocation from the temperature-scaled actor policy.

    The raw actor scores are printed to ``myLog``. A softmax at temperature
    ``temp`` turns them into category probabilities, which parametrize a
    Multinomial with ``vaccine_supply`` trials. ``variance`` is not used.

    Returns:
        ``(action, log_prob, entropy)``:
        - ``action``: sampled counts as a NumPy array.
        - ``log_prob``: the counts' log-probability.
        - ``entropy``: ``-sum(p * log p)`` over the last axis.
    """
    batch = torch.tensor(state, dtype=torch.float32)[None]
    scores = actor(batch)
    print(scores, file=myLog)
    p = F.softmax(scores / temp, dim=1)
    dist = Multinomial(vaccine_supply, p[0])
    sampled = dist.sample()
    log_prob = dist.log_prob(sampled)
    entropy = -(p * torch.log(p)).sum(dim=-1)
    return sampled.numpy(), log_prob, entropy
Exemplo n.º 11
0
def sample(lp: Tensor, axis=1, numsamples=1, MAP=False):
    """Draw ``numsamples`` Multinomial trials over ``axis`` of logits ``lp``.

    The sampled frequencies and the per-slice minimum of
    ``lp - log(frequency)`` are computed.

    NOTE(review): neither result is returned; the function always returns
    ``(None, None)``, and ``MAP`` is unused.
    """
    last = lp.ndimension() - 1
    # Multinomial expects the categories on the last axis
    dist = Multinomial(total_count=numsamples, logits=lp.transpose(last, axis))
    freqs = dist.sample().detach().transpose(last, axis) / numsamples
    log_ratio = lp - freqs.detach().log()
    # NaN != NaN: NaNs become +inf so they never win the min below
    log_ratio[log_ratio != log_ratio] = float('Inf')
    log_ratio = log_ratio.min(dim=axis, keepdim=True)[0]
    return None, None
    def sample_from_population_with_weights(
        particles: Tensor, weights: Tensor, num_samples: int = 1
    ) -> Tensor:
        """Return ``num_samples`` particles drawn with replacement.

        Particle ``i`` is chosen with probability proportional to
        ``weights[i]``; results are returned in draw order.
        """
        draws = Multinomial(probs=weights).sample(sample_shape=(num_samples,))
        # each draw is one-hot; the nonzero column is the particle index
        chosen = draws.nonzero(as_tuple=True)[1]
        return particles[chosen]
Exemplo n.º 13
0
def select_action(state, variance=1, temp=10):
    """Choose a stochastic vaccine allocation from the actor's policy.

    The actor scores are logged to ``myLog``. A temperature-``temp`` softmax
    converts them to probabilities for a Multinomial with ``vaccine_supply``
    trials. ``variance`` is accepted but not used.

    Returns:
        ``(action, log_prob, entropy)``:
        - ``action``: sampled counts as a NumPy array.
        - ``log_prob``: their log-probability.
        - ``entropy``: ``-sum(log p * p)`` over the last axis.
    """
    scores = actor(torch.tensor(state, dtype=torch.float32).unsqueeze(0))
    print(scores, file=myLog)
    probabilities = F.softmax(scores / temp, dim=1)
    allocation_dist = Multinomial(vaccine_supply, probabilities[0])
    action = allocation_dist.sample()
    log_prob = allocation_dist.log_prob(action)
    entropy = -torch.sum(probabilities.log() * probabilities, dim=-1)
    return action.numpy(), log_prob, entropy
Exemplo n.º 14
0
def predict(encoded_text: torch.Tensor,
            model: nn.Module,
            k: int = 1,
            device: torch.device = "cpu") -> tuple:
    """Sample the next token for every sequence in a batch.

    The model is put in eval mode and run on ``encoded_text``. Its first
    output is used as logits, keeping only the last position. For each
    batch row, ``k`` multinomial trials are drawn from those logits and the
    most frequently drawn token is the prediction.

    Args:
        encoded_text: input batch; the first dimension is the batch size.
        model: module whose output has logits at ``[0]`` and a second item
            at ``[1]`` that is returned unchanged.
        k: number of multinomial trials per row.
        device: device the input is moved to before the forward pass.

    Returns:
        ``(prediction, out[1])``, with ``prediction`` of shape ``(batch,)``.
    """
    model.eval()

    out = model(encoded_text.to(device))

    # presumably logits are (batch, seq, vocab); keep the last position
    logits = out[0][:, -1]
    counts = Multinomial(k, logits=logits).sample()
    # argmax per row: a flat argmax() would mix rows and break for batch > 1
    prediction = counts.argmax(dim=-1).reshape((encoded_text.shape[0], ))

    return prediction, out[1]
Exemplo n.º 15
0
 def get_reconstruction_loss(self, x):
     """Return the negative mean Multinomial log-likelihood of ``x``.

     ``x`` is imputed, transformed with ``ilr`` using ``self.Psi``, encoded
     and decoded. The decoder output is mapped back through ``self.Psi.t()``
     to give the Multinomial logits.
     """
     hx = ilr(self.imputer(x), self.Psi)
     z_mean = self.encoder(hx)
     eta = self.decoder(z_mean)
     logp = self.Psi.t() @ eta.t()
     # ``x`` presumably holds counts whose row totals exceed the default
     # total_count=1, which the support check would reject. log_prob itself
     # uses the row sums of ``x``, so it is correct without the check.
     mult_loss = Multinomial(
         logits=logp.t(), validate_args=False).log_prob(x).mean()
     return -mult_loss
    def select_action(self, state, temp=1):
        """Sample a vaccine allocation from the actor for one state.

        The actor's logits are standardized (zero mean, unit std, with a
        1e-5 guard), divided by ``temp`` and used as logits of a Multinomial
        with ``self.args.vaccine_supply`` trials.

        Returns:
            ``(action, log_prob, entropy)``:
            - ``action``: sampled counts as a CPU NumPy array.
            - ``log_prob``: their log-probability.
            - ``entropy``: ``-sum(p * log p)``, computed from the
              distribution's normalized logits and probs.
        """
        batch = torch.tensor(state, dtype=torch.float32,
                             device=self.device)[None]
        raw = self.actor(batch)

        # TODO: check this one later
        standardized = (raw - torch.mean(raw)) / (torch.std(raw) + 1e-5)

        dist = Multinomial(self.args.vaccine_supply,
                           logits=standardized.squeeze() / temp)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        entropy = -(dist.logits * dist.probs).sum()
        return action.to('cpu').numpy(), log_prob, entropy
Exemplo n.º 17
0
def select_action(state, variance=1, temp=1):
    """Sample a vaccine allocation from the actor's standardized scores.

    The actor scores for ``state`` (placed on the global ``device``) are
    standardized to zero mean and unit std, with a 1e-5 guard. Divided by
    ``temp``, they are the logits of a Multinomial with ``vaccine_supply``
    trials. ``variance`` is not used.

    Returns:
        ``(action, log_prob, entropy)``:
        - ``action``: sampled counts as a CPU NumPy array.
        - ``log_prob``: their log-probability.
        - ``entropy``: ``-sum(p * log p)`` over the last axis.
    """
    batch = torch.tensor(state, dtype=torch.float32, device=device)[None]
    scores = actor(batch)
    centred = scores - torch.mean(scores)
    scaled = centred / (torch.std(scores) + 1e-5)
    dist = Multinomial(vaccine_supply, logits=scaled.squeeze() / temp)
    action = dist.sample()
    log_prob = dist.log_prob(action)
    entropy = -(dist.logits * dist.probs).sum(dim=-1)
    return action.to('cpu').numpy(), log_prob, entropy
Exemplo n.º 18
0
def sample_liklihood(lp, axis=1, numsamples=1):
    """Draw uniform Multinomial samples over ``axis`` and score them under ``lp``.

    Args:
        lp: log-probabilities; the categories lie along ``axis``.
        axis: the category axis.
        numsamples: number of Multinomial trials per slice.

    Returns:
        ``(samps, logprob, lpmodel)``:
        - ``samps``: detached sampled frequencies (counts / numsamples),
          same shape as ``lp``.
        - ``logprob``: the minimum over ``axis`` of ``lp - log(samps)``,
          with NaNs treated as +inf and ``axis`` kept.
        - ``lpmodel``: ``min_correction(lpunif - lp, axis)``.
    """
    lastaxis = lp.ndimension() - 1
    lporig = lp
    # Uniform log-probabilities, normalized over ``axis`` (previously a
    # hard-coded dim=1). Built from ``lp.exp() * 0``, presumably so the
    # result stays attached to lp's autograd graph.
    zeros = lp.exp() * 0
    lpunif = zeros - zeros.logsumexp(dim=axis, keepdim=True)
    samplinglp = lpunif
    # Multinomial expects the categories on the last axis
    lpt = samplinglp.transpose(lastaxis, axis)
    M = Multinomial(total_count=numsamples, logits=lpt)
    samps = M.sample().detach()
    samps = samps.transpose(lastaxis, axis) / numsamples
    logprob = (lporig - (samps.detach()).log())
    # NaN != NaN: NaNs become +inf so they never win the min below
    logprob[logprob != logprob] = float('Inf')
    logprob = logprob.min(dim=axis, keepdim=True)[0]

    lpmodel = min_correction(lpunif - lporig, axis)

    return samps.detach(), logprob, lpmodel
Exemplo n.º 19
0
    def evaluate(self, possible_boards):
        """Choose a move among ``possible_boards`` and record training data.

        The boards are run through the network. A single draw from a
        one-trial Multinomial over ``self.pg_plugin._softmax`` of the outputs
        picks the move. Two things are recorded:
        - the draw's log-probability, appended to
          ``self.saved_log_probabilities``;
        - the value model's estimate for the chosen output, appended to
          ``self.saved_value_estimations``.

        NOTE(review): no ``.detach()`` is applied, so the value estimate
        shares the graph of the network outputs.

        Returns:
            Index of the chosen move (0-dim tensor from ``Tensor.max``).
        """
        outputs = self.run_through_neural_network(possible_boards)

        distribution = Multinomial(1, self.pg_plugin._softmax(outputs))
        one_hot = distribution.sample()
        self.saved_log_probabilities.append(distribution.log_prob(one_hot))

        move = one_hot.max(0)[1]
        self.saved_value_estimations.append(
            self.pg_plugin.value_model(outputs[move]))
        return move
Exemplo n.º 20
0
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        """Compute the reconstruction loss and KL terms for a batch.

        Args:
            x: input batch; ``log(1 + x)`` is fed to the encoders.
            local_l_mean: mean of the Normal prior compared against the
                ``l_encoder`` posterior.
            local_l_var: variance of that prior (its square root is used as
                the scale).
            batch_index: forwarded to the decoder and the reconstruction loss.
            y: labels. When ``None``, the label-dependent terms are averaged
                over all ``self.n_labels`` labels, weighted by
                ``self.classifier(z1)``, and a KL from those class
                probabilities to ``self.y_prior`` is added.

        Returns:
            ``(reconst_loss, kl_divergence)``.
        """
        is_labelled = False if y is None else True

        x_ = torch.log(1 + x)
        qz1_m, qz1_v, z1 = self.z_encoder(x_)
        ql_m, ql_v, library = self.l_encoder(x_)

        # Enumerate choices of label
        ys, z1s = (broadcast_labels(y, z1, n_broadcast=self.n_labels))
        qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys)
        pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z1, library, batch_index)

        reconst_loss = self._reconstruction_loss(x, px_rate, px_r, px_dropout,
                                                 batch_index, y)

        # KL Divergence
        mean = torch.zeros_like(qz2_m)
        scale = torch.ones_like(qz2_v)

        kl_divergence_z2 = kl(Normal(qz2_m, torch.sqrt(qz2_v)),
                              Normal(mean, scale)).sum(dim=1)
        loss_z1_unweight = -Normal(pz1_m,
                                   torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
        loss_z1_weight = Normal(qz1_m,
                                torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)

        if is_labelled:
            return reconst_loss + loss_z1_weight + loss_z1_unweight, kl_divergence_z2 + kl_divergence_l

        probs = self.classifier(z1)
        reconst_loss += (loss_z1_weight + (
            (loss_z1_unweight).view(self.n_labels, -1).t() * probs).sum(dim=1))

        kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() *
                         probs).sum(dim=1)
        # torch.distributions registers no Multinomial-Multinomial KL; a
        # one-trial Multinomial is a Categorical, whose KL is registered.
        kl_divergence += kl(torch.distributions.Categorical(probs=probs),
                            torch.distributions.Categorical(probs=self.y_prior))

        return reconst_loss, kl_divergence
Exemplo n.º 21
0
 def get_reconstruction_loss(self, x, B):
     """Return the negative mean Multinomial log-likelihood of ``x``, with
     batch effects removed before encoding and restored after decoding.

     The batch effects are ``(self.Psi @ B.t()).t()``. They are subtracted
     in place from the ``ilr``-transformed input and added in place to the
     decoder output, which is then mapped through ``self.Psi.t()`` to
     Multinomial logits.
     """
     hx = ilr(self.imputer(x), self.Psi)
     batch_effects = (self.Psi @ B.t()).t()
     hx -= batch_effects  # Subtract out batch effects
     z_mean = self.encoder(hx)
     eta = self.decoder(z_mean)
     eta += batch_effects  # Add batch effects back in
     logp = self.Psi.t() @ eta.t()
     # ``x`` presumably holds counts whose row totals exceed the default
     # total_count=1, which the support check would reject. log_prob itself
     # uses the row sums of ``x``, so it is correct without the check.
     mult_loss = Multinomial(
         logits=logp.t(), validate_args=False).log_prob(x).mean()
     return -mult_loss
Exemplo n.º 22
0
    def update_environment(self, block, trial, responses):
        """Sample outcomes for the chosen arms and advance the state.

        The arm type of each response is looked up from the current trial's
        offers. Per-subject outcome probabilities for that arm type are read
        from ``self.states['probs']``, and one outcome per subject is drawn
        from a one-trial Multinomial.

        Returns:
            ``[responses, (out, next_state)]``:
            - ``out`` maps ``'locations'`` to ``responses`` and
              ``'features'`` to the index of each sampled outcome.
            - ``next_state`` is what ``self.update_states`` returns for
              ``trial + 1``.
        """
        chosen_types = self.arm_types[self.offers[block][trial], responses]
        outcome_probs = self.states['probs'][block, trial, range(self.nsub),
                                             chosen_types]
        outcomes = Multinomial(probs=outcome_probs).sample()

        observation = {'locations': responses, 'features': outcomes.argmax(-1)}
        next_state = self.update_states(block,
                                        trial + 1,
                                        responses=responses,
                                        outcomes=outcomes)
        return [responses, (observation, next_state)]
Exemplo n.º 23
0
    def sample_fn():
        """Autoregressively sample images from ``model`` and plot them.

        Pixels are generated in raster order: row, then column, then
        channel. At each step the model is rerun on the partially filled
        batch, and one value is drawn from a one-trial Multinomial over the
        logits for that position.

        With ``K`` set, the batch has ``2 * K`` samples conditioned on labels
        ``arange(2K) % K``. Otherwise it has 12 samples with ``Y = None``.
        Images are shown in a 2-row grid via ``imshow``, with multi-channel
        images scaled by 1/255.

        Returns:
            The figure created by ``figure``.
        """
        fig = figure(figsize=(12, 5))
        model.eval()

        if K is None:
            n_samp, Y = 12, None
        else:
            n_samp = 2 * K
            Y = torch.arange(n_samp, device=device) % K

        X = torch.zeros(n_samp, C, H, W, device=device).long()

        with torch.no_grad():
            for row in range(H):
                for col in range(W):
                    for ch in range(C):
                        _, logits = model(X, Y)
                        draw = Multinomial(
                            logits=logits[:, :, row, col, ch]).sample(
                                torch.Size([]))
                        X[:, ch, row, col] = torch.argmax(draw, dim=1)

        images = X.cpu().numpy()
        half = n_samp // 2
        if C > 1:
            images = images.astype('float') / 255.0
            grid = images.reshape(2, half, C, H, W).transpose(0, 3, 1, 4, 2)
            _ = imshow(grid.reshape(2 * H, half * W, C))
        else:
            grid = images.reshape(2, half, H, W).transpose(0, 2, 1, 3)
            _ = imshow(grid.reshape(2 * H, half * W))
        colorbar()

        return fig
Exemplo n.º 24
0
def sampleunif(lp: Tensor, axis=1, numsamples=1):
    """Sample the random variables uniformly and compare against a model.

    The model is given in probability space by the logit tensor ``lp``. The
    intent is to compute the probability that a uniform sample lies in the
    model.

    NOTE(review): every intermediate is discarded and the function returns
    ``(None, None, None)``. The uniform normalization and the final mean use
    dim=1 regardless of ``axis``.
    """
    last = lp.ndimension() - 1
    uniform = torch.zeros_like(lp)
    uniform = uniform - uniform.logsumexp(dim=1, keepdim=True)
    # Multinomial expects the categories on the last axis
    dist = Multinomial(total_count=numsamples,
                       logits=uniform.transpose(last, axis))
    freqs = dist.sample().detach().transpose(last, axis) / numsamples
    log_ratio = lp - freqs.detach().log()
    # NaN != NaN: NaNs become +inf so they never win the min below
    log_ratio[log_ratio != log_ratio] = float('Inf')
    log_ratio = log_ratio.min(dim=axis, keepdim=True)[0]
    # TODO min
    lpmodel = softmin(uniform - lp, axis)
    inmodel_lprobs = log_ratio + lpmodel - uniform.mean(dim=1, keepdim=True)
    return None, None, None
Exemplo n.º 25
0
 def forward(self, x):
     """Return the negative log-likelihood of ``x`` under the model.

     The loss is the negative sum of three mean terms:
     - a Multinomial log-likelihood of ``x`` with logits
       ``(self.Psi.t() @ self.eta.t()).t()``;
     - the log-density of ``self.eta`` under
       ``MultivariateNormalFactorIdentity(mu, var, D, W)``, built from the
       decoder output and weights;
     - a ``Normal(self.zm, self.zI)`` prior log-density of the encoding.
     """
     hx = ilr(self.imputer(x), self.Psi)
     z_mean = self.encoder(hx)
     mu = self.decoder(z_mean)
     W = self.decoder.weight
     # penalties
     D = torch.exp(self.variational_logvars)
     var = torch.exp(self.log_sigma_sq)
     qdist = MultivariateNormalFactorIdentity(mu, var, D, W)
     logp = self.Psi.t() @ self.eta.t()
     prior_loss = Normal(self.zm, self.zI).log_prob(z_mean).mean()
     logit_loss = qdist.log_prob(self.eta).mean()
     # ``x`` presumably holds counts whose row totals exceed the default
     # total_count=1, which the support check would reject. log_prob itself
     # uses the row sums of ``x``, so it is correct without the check.
     mult_loss = Multinomial(
         logits=logp.t(), validate_args=False).log_prob(x).mean()
     loglike = mult_loss + logit_loss + prior_loss
     return -loglike
Exemplo n.º 26
0
 def recon_model_loglik(self, x_in, x_out):
     """Mean log-likelihood of ``x_in`` given the decoded ``x_out``.

     ``x_out`` is mapped through ``self.Psi.t()`` to ``logp``. That is then
     used either as Multinomial logits (``'multinomial'``) or as the mean of
     a unit-scale Normal (``'gaussian'``), chosen by ``self.distribution``.
     Argument validation is disabled for both distributions.

     Raises:
         ValueError: if ``self.distribution`` is not supported.
     """
     logp = (self.Psi.t() @ x_out.t()).t()
     if self.distribution == 'gaussian':
         # MSE loss based out on DeepMicro
         # https://www.nature.com/articles/s41598-020-63159-5
         dist = Normal(loc=logp, scale=1, validate_args=False)
     elif self.distribution == 'multinomial':
         dist = Multinomial(logits=logp, validate_args=False)
     else:
         raise ValueError(
             f'Distribution {self.distribution} is not supported.')
     return dist.log_prob(x_in).mean()
Exemplo n.º 27
0
 def recon_model_loglik(self, x, eta):
     """Log-likelihood of counts ``x`` given decoder output ``eta``.

     Supported values of ``self.likelihood``:
     - ``'gaussian'`` (flagged unsupported below): element-wise Gaussian
       log-density of ``x`` around ``eta`` with variance
       ``exp(self.log_sigma_sq)``, without the 1/2 on log(2*pi) and without
       reduction.
     - ``'multinomial'``: mean Multinomial log-likelihood with logits
       ``(self.Psi.t() @ eta.t()).t()``.
     - ``'lognormal'``: mean LogNormal log-density of the non-zero entries
       of ``x``. The location is ``log_softmax(logits) + log(row total)``
       and the scale is ``exp(self.log_sigma_sq / 2)``.

     Raises:
         ValueError: for any other ``self.likelihood``.
     """
     # WARNING : the gaussian likelihood is not supported
     if self.likelihood == 'gaussian':
         diff = (x - eta)**2
         sigma_sq = torch.exp(self.log_sigma_sq)
         # No dimension constant as we sum after
         return 0.5 * (-diff / sigma_sq - LOG_2_PI - self.log_sigma_sq)
     elif self.likelihood == 'multinomial':
         logp = (self.Psi.t() @ eta.t()).t()
         # Row totals of the counts exceed the default total_count=1, which
         # the support check would reject. log_prob uses the row sums of
         # ``x``, so it is correct without the check.
         mult_loss = Multinomial(
             logits=logp, validate_args=False).log_prob(x).mean()
         return mult_loss
     elif self.likelihood == 'lognormal':
         logp = F.log_softmax((self.Psi.t() @ eta.t()).t(), dim=-1)
         # keep the row axis so the per-row log total broadcasts over columns
         logN = torch.log(x.sum(dim=-1, keepdim=True))
         mu = logp + logN
         # log_sigma_sq is a log-variance (see the gaussian branch), while
         # LogNormal's ``scale`` is a standard deviation
         sigma = torch.exp(0.5 * self.log_sigma_sq)
         nz = x > 0
         logn_loss = LogNormal(loc=mu[nz],
                               scale=sigma.expand_as(x)[nz]).log_prob(x[nz])
         return logn_loss.mean()
     else:
         raise ValueError(
             f'{self.likelihood} has not be properly specified.')
Exemplo n.º 28
0
    def generate_discrete_network(self, method: str = "sample"):
        """ generates discrete weights from the weights of the layer based on the weight distributions

        The level axis of the probabilities (axis 0 of the logits) is moved
        innermost. Each weight then takes one value of
        ``self.discrete_weight_values``, chosen either by sampling a one-hot
        mask from a Multinomial or by argmax.

        :param method: the method to use to generate the discrete weights. Either argmax or sample

        :returns: tuple (sampled_w, sampled_b) where sampled_w and sampled_b are tensors of the shapes
        (output_channels x input_channels x kernel rows x kernel columns) and (output_features x 1). sampled_b is None if the layer has no bias.
        Values are float64. "argmax" results are on the CPU; "sample" results stay on the logits' device.

        :raises ValueError: for an unknown method, a sampled mask that does not sum to 1, or an unexpected result shape
        """

        probabilities_w = self.generate_weight_probabilities(self.W_logits)
        # logit probabilities must be in inner dimension for torch.distribution.Multinomial
        # stepped transpose bc we need to keep the order of the other dimensions
        probabilities_w = probabilities_w.transpose(0, 1).transpose(
            1, 2).transpose(2, 3).transpose(3, 4)
        if self.b_logits is not None:
            probabilities_b = self.generate_weight_probabilities(self.b_logits)
            probabilities_b = probabilities_b.transpose(0, 1).transpose(1, 2)
        else:
            # layer does not use bias
            probabilities_b = None
        discrete_values_tensor = torch.tensor(
            self.discrete_weight_values).double()
        discrete_values_tensor = discrete_values_tensor.to(
            self.W_logits.device)
        if method == "sample":
            # this is a output_channels x input_channels x kernel rows x kernel columns x discretization_levels mask
            m_w = Multinomial(probs=probabilities_w)
            sampled_w = m_w.sample()
            # every one-hot mask must sum to 1; any violation is an error
            if torch.any(sampled_w.sum(dim=4) != 1):
                raise ValueError("sampled mask for weights does not sum to 1")

            # need to generate the discrete weights from the masks; the mask
            # has the probabilities' dtype, so match the double value table
            sampled_w = torch.matmul(
                sampled_w.to(discrete_values_tensor.dtype),
                discrete_values_tensor)

            if probabilities_b is not None:
                # this is a output channels x 1 x discretization levels mask
                m_b = Multinomial(probs=probabilities_b)
                sampled_b = m_b.sample()

                if torch.any(sampled_b.sum(dim=2) != 1):
                    raise ValueError("sampled mask for bias does not sum to 1")
                sampled_b = torch.matmul(
                    sampled_b.to(discrete_values_tensor.dtype),
                    discrete_values_tensor)
            else:
                sampled_b = None

        elif method == "argmax":
            # index of the most probable discretized value for every weight
            argmax_w = torch.argmax(probabilities_w, dim=4)
            # Index the value table directly. A zeros_like(argmax) buffer is
            # an integer tensor and would truncate fractional values.
            sampled_w = discrete_values_tensor[argmax_w].to("cpu")

            if probabilities_b is not None:
                argmax_b = torch.argmax(probabilities_b, dim=2)
                sampled_b = discrete_values_tensor[argmax_b].to("cpu")
            else:
                sampled_b = None
        else:
            raise ValueError(
                f"Invalid method {method} for layer discretization")

        # sanity checks
        if sampled_w.shape != probabilities_w.shape[:-1]:
            raise ValueError(
                "sampled probability mask for weights does not match expected shape"
            )
        # ``is not None``: truth-testing a multi-element tensor raises
        if sampled_b is not None:
            if sampled_b.shape != probabilities_b.shape[:-1]:
                raise ValueError(
                    "sampled probability mask for bias does not match expected shape"
                )

        return sampled_w, sampled_b
Exemplo n.º 29
0
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        """Compute the reconstruction loss and KL divergence for a batch.

        When ``y`` is ``None``, every label is enumerated:
        - the inputs and library priors are repeated ``self.n_labels`` times;
        - the per-label losses are weighted by ``self.classifier``
          probabilities;
        - a KL from those probabilities to ``self.y_prior`` is added.

        Args:
            x: input batch (log-transformed with ``log(1 + x)`` when
                ``self.log_variational``).
            local_l_mean: mean of the Normal prior compared against the
                ``l_encoder`` posterior.
            local_l_var: variance of that prior (its square root is used as
                the scale).
            batch_index: forwarded to the decoder.
            y: labels, one-hot encoded internally; may be ``None``.

        Returns:
            ``(reconst_loss, kl_divergence)``.

        NOTE(review): ``reconst_loss`` is only defined for
        ``self.reconstruction_loss`` in {'zinb', 'nb'}.
        """
        is_labelled = False if y is None else True

        # Prepare for sampling
        xs, ys = (x, y)

        # Enumerate choices of label
        if not is_labelled:
            ys = enumerate_discrete(xs, self.n_labels)
            xs = xs.repeat(self.n_labels, 1)
            if batch_index is not None:
                batch_index = batch_index.repeat(self.n_labels, 1)
            local_l_var = local_l_var.repeat(self.n_labels, 1)
            local_l_mean = local_l_mean.repeat(self.n_labels, 1)
        else:
            ys = one_hot(ys, self.n_labels)

        xs_ = xs
        if self.log_variational:
            xs_ = torch.log(1 + xs_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(xs_, ys)
        ql_m, ql_v, library = self.l_encoder(xs_)

        if self.dispersion == "gene-cell":
            px_scale, self.px_r, px_rate, px_dropout = self.decoder(
                self.dispersion, z, library, batch_index, y=ys)
        elif self.dispersion == "gene":
            px_scale, px_rate, px_dropout = self.decoder(self.dispersion,
                                                         z,
                                                         library,
                                                         batch_index,
                                                         y=ys)

        # Reconstruction Loss
        if self.reconstruction_loss == 'zinb':
            reconst_loss = -log_zinb_positive(xs, px_rate, torch.exp(
                self.px_r), px_dropout)
        elif self.reconstruction_loss == 'nb':
            reconst_loss = -log_nb_positive(xs, px_rate, torch.exp(self.px_r))

        # KL Divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)),
                             Normal(mean, scale)).sum(dim=1)
        kl_divergence_l = kl(Normal(ql_m, torch.sqrt(ql_v)),
                             Normal(local_l_mean,
                                    torch.sqrt(local_l_var))).sum(dim=1)
        kl_divergence = kl_divergence_z + kl_divergence_l

        if is_labelled:
            return reconst_loss, kl_divergence

        reconst_loss = reconst_loss.view(self.n_labels, -1)
        kl_divergence = kl_divergence.view(self.n_labels, -1)

        # The classifier sees the same (optionally log-transformed) input as
        # the encoders. ``x_`` must be defined on both paths, otherwise a
        # NameError is raised when log_variational is False.
        x_ = torch.log(1 + x) if self.log_variational else x

        probs = self.classifier(x_)
        reconst_loss = (reconst_loss.t() * probs).sum(dim=1)
        kl_divergence = (kl_divergence.t() * probs).sum(dim=1)
        # torch.distributions registers no Multinomial-Multinomial KL; a
        # one-trial Multinomial is a Categorical, whose KL is registered.
        kl_divergence += kl(torch.distributions.Categorical(probs=probs),
                            torch.distributions.Categorical(probs=self.y_prior))

        return reconst_loss, kl_divergence
Exemplo n.º 30
0
 def test_multinomial_1d(self):
     """Check the sample shape and log_prob gradients for a single
     distribution over three categories given by unnormalized weights.

     NOTE(review): this targets the legacy single-argument
     ``Multinomial(p)`` API. In current ``torch.distributions`` the first
     positional argument of ``Multinomial`` is ``total_count``.
     """
     # weights sum to 0.6; presumably normalized by the distribution
     p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
     # TODO: this should return a 0-dim tensor once we have Scalar support
     self.assertEqual(Multinomial(p).sample().size(), (1, ))
     self._gradcheck_log_prob(Multinomial, (p, ))