Example #1
def KL(nn_state, target_psi, space, bases=None, **kwargs):
    r"""A function for calculating the total KL divergence.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param target_psi: The true wavefunction of the system. Can be a dictionary
                       with each value being the wavefunction represented in a
                       different basis.
    :type target_psi: torch.Tensor or dict(str, torch.Tensor)
    :param space: The Hilbert space of the system.
    :type space: torch.Tensor
    :param bases: An array of unique bases.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    psi_r = torch.zeros(2,
                        1 << nn_state.num_visible,
                        dtype=torch.double,
                        device=nn_state.device)
    KL = 0.0

    if isinstance(target_psi, dict):
        target_psi = {k: v.to(nn_state.device) for k, v in target_psi.items()}
        if bases is None:
            bases = list(target_psi.keys())
        else:
            assert set(bases) == set(
                target_psi.keys()
            ), "Given bases must match the keys of the target_psi dictionary."
    else:
        target_psi = target_psi.to(nn_state.device)

    Z = nn_state.compute_normalization(space)
    if bases is None:
        target_probs = cplx.absolute_value(target_psi)**2
        nn_probs = nn_state.probability(space, Z)
        KL += torch.sum(target_probs * probs_to_logits(target_probs))
        KL -= torch.sum(target_probs * probs_to_logits(nn_probs))
    else:
        unitary_dict = nn_state.unitary_dict
        for basis in bases:
            psi_r = rotate_psi(nn_state, basis, space, unitary_dict)
            if isinstance(target_psi, dict):
                target_psi_r = target_psi[basis]
            else:
                target_psi_r = rotate_psi(nn_state, basis, space, unitary_dict,
                                          target_psi)

            probs_r = (cplx.absolute_value(psi_r)**2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r)**2

            KL += torch.sum(target_probs_r * probs_to_logits(target_probs_r))
            KL -= torch.sum(target_probs_r * probs_to_logits(probs_r))
        KL /= float(len(bases))

    return KL.item()
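A minimal sketch (not QuCumber code) of the identity behind the two sums above: for plain probability vectors, probs_to_logits without is_binary just takes the (clamped) log, so the difference of the two sums is exactly the KL divergence KL(target || model).

import torch
from torch.distributions.utils import probs_to_logits

target_probs = torch.tensor([0.5, 0.25, 0.25], dtype=torch.double)
nn_probs = torch.tensor([0.4, 0.4, 0.2], dtype=torch.double)

kl = (torch.sum(target_probs * probs_to_logits(target_probs))
      - torch.sum(target_probs * probs_to_logits(nn_probs)))
kl_direct = torch.sum(target_probs * (target_probs / nn_probs).log())
assert torch.allclose(kl, kl_direct)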
Example #2
def NLL(nn_state, samples, space, train_bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The Hilbert space of the system.
    :type space: torch.Tensor
    :param train_bases: An array of bases where measurements were taken.
    :type train_bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The Negative Log-Likelihood.
    :rtype: float
    """
    psi_r = torch.zeros(2,
                        1 << nn_state.num_visible,
                        dtype=torch.double,
                        device=nn_state.device)
    NLL = 0.0
    Z = nn_state.compute_normalization(space)
    if train_bases is None:
        nn_probs = nn_state.probability(samples, Z)
        NLL = -torch.sum(probs_to_logits(nn_probs))
    else:
        unitary_dict = nn_state.unitary_dict
        for i in range(len(samples)):
            # Check whether the sample was measured in the reference basis
            is_reference_basis = True
            for j in range(nn_state.num_visible):
                if train_bases[i][j] != "Z":
                    is_reference_basis = False
                    break
            if is_reference_basis:
                nn_probs = nn_state.probability(samples[i], Z)
                NLL -= torch.sum(probs_to_logits(nn_probs))
            else:
                psi_r = rotate_psi(nn_state, train_bases[i], space,
                                   unitary_dict)
                # Get the index value of the sample state
                ind = 0
                for j in range(nn_state.num_visible):
                    if samples[i, nn_state.num_visible - j - 1] == 1:
                        ind += pow(2, j)
                probs_r = cplx.norm_sqr(psi_r[:, ind]) / Z
                NLL -= probs_to_logits(probs_r).item()
    return (NLL / float(len(samples))).item()
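A minimal sketch (assumption, not library code) of what the index loop above does: it reads the binary sample most-significant-bit first and turns it into the row index of the corresponding basis state.

import torch

sample = torch.tensor([1, 0, 1])
num_visible = sample.numel()
ind = 0
for j in range(num_visible):
    if sample[num_visible - j - 1] == 1:
        ind += pow(2, j)
assert ind == 5  # [1, 0, 1] read as binary 101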
Example #3
def NLL(nn_state, samples, space=None, sample_bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood (NLL).

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
                  If `None`, will generate them using the provided `nn_state`.
    :type space: torch.Tensor
    :param sample_bases: An array of bases where measurements were taken.
    :type sample_bases: numpy.ndarray
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The Negative Log-Likelihood.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)

    if sample_bases is None:
        nn_probs = nn_state.probability(samples, Z)
        NLL_ = -torch.mean(probs_to_logits(nn_probs)).item()
        return NLL_
    else:
        NLL_ = 0.0

        unique_bases, indices = np.unique(sample_bases,
                                          axis=0,
                                          return_inverse=True)
        indices = torch.Tensor(indices).to(samples)

        for i in range(unique_bases.shape[0]):
            basis = unique_bases[i, :]
            rot_sites = np.where(basis != "Z")[0]

            if rot_sites.size != 0:
                if isinstance(nn_state, WaveFunctionBase):
                    Upsi = rotate_psi_inner_prod(nn_state, basis,
                                                 samples[indices == i, :])
                    nn_probs = (cplx.absolute_value(Upsi)**2) / Z
                else:
                    nn_probs = (rotate_rho_probs(nn_state, basis,
                                                 samples[indices == i, :]) / Z)
            else:
                nn_probs = nn_state.probability(samples[indices == i, :], Z)

            NLL_ -= torch.sum(probs_to_logits(nn_probs))

        return NLL_ / float(len(samples))
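A minimal sketch (assumption) of the grouping trick used above: np.unique with return_inverse=True maps every measurement basis to the index of its unique row, so the samples can be processed one basis at a time.

import numpy as np

sample_bases = np.array([["Z", "Z"], ["X", "Z"], ["Z", "Z"]])
unique_bases, indices = np.unique(sample_bases, axis=0, return_inverse=True)
# unique_bases -> [["X", "Z"], ["Z", "Z"]], indices -> [1, 0, 1]
for i in range(unique_bases.shape[0]):
    rows = np.where(indices == i)[0]  # rows of `samples` measured in basis i
    print(unique_bases[i], rows)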
Example #4
    def __init__(
        self,
        total_count: Optional[torch.Tensor] = None,
        probs: Optional[torch.Tensor] = None,
        logits: Optional[torch.Tensor] = None,
        mu: Optional[torch.Tensor] = None,
        theta: Optional[torch.Tensor] = None,
        validate_args: bool = False,
    ):
        self._eps = 1e-8
        if (mu is None) == (total_count is None):
            raise ValueError(
                "Please use one of the two possible parameterizations. Refer to the documentation for more information."
            )

        using_param_1 = total_count is not None and (
            logits is not None or probs is not None
        )
        if using_param_1:
            logits = logits if logits is not None else probs_to_logits(probs)
            total_count = total_count.type_as(logits)
            total_count, logits = broadcast_all(total_count, logits)
            mu, theta = _convert_counts_logits_to_mean_disp(total_count, logits)
        else:
            mu, theta = broadcast_all(mu, theta)
        self.mu = mu
        self.theta = theta
        super().__init__(validate_args=validate_args)
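A minimal sketch of my assumption about the conversion performed by _convert_counts_logits_to_mean_disp: with total_count r and success probability p, the mean is mu = r * p / (1 - p) = r * exp(logits) and the inverse dispersion is theta = r.

import torch
from torch.distributions.utils import probs_to_logits

total_count = torch.tensor([5.0])
probs = torch.tensor([0.3])
logits = probs_to_logits(probs, is_binary=True)  # log(p / (1 - p))
mu = total_count * logits.exp()                  # assumed mean parameter
theta = total_count                              # assumed inverse dispersion
assert torch.allclose(mu, total_count * probs / (1 - probs))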
Example #5
 def __init__(self, cont, logits0=None, probs0=None, validate_args=None):
     """
     - with probability p_0 = sigmoid(logits0) this returns 0
     - with probability 1 - p_0 this returns a sample in the open interval (0, 1)
     
     logits0: logits for p_0
     cont: a (properly normalised) distribution over (0, 1)
         e.g. RightTruncatedExponential
     """
     if logits0 is None and probs0 is None:
         raise ValueError("You must specify either logits0 or probs0")
     if logits0 is not None and probs0 is not None:
         raise ValueError("You cannot specify both logits0 and probs0")
     shape = cont.batch_shape
     super(MixtureD0C01, self).__init__(batch_shape=shape,
                                        validate_args=validate_args)
     if logits0 is None:
         self.logits = probs_to_logits(probs0, is_binary=True)
     else:
         self.logits = logits0
     self.cont = cont
     self.p0, self.pc = bernoulli_probs_from_logit(self.logits)
     self.log_p0, self.log_pc = bernoulli_log_probs_from_logit(self.logits)
     self.uniform = Uniform(
         torch.zeros(shape).to(self.logits.device),
         torch.ones(shape).to(self.logits.device))
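A minimal sketch (assuming bernoulli_probs_from_logit returns (sigmoid(logits), 1 - sigmoid(logits))): probs_to_logits(..., is_binary=True) is the logit function, the exact inverse of the sigmoid.

import torch
from torch.distributions.utils import probs_to_logits

probs0 = torch.tensor([0.1, 0.5, 0.9])
logits = probs_to_logits(probs0, is_binary=True)  # log(p0 / (1 - p0))
assert torch.allclose(torch.sigmoid(logits), probs0, atol=1e-6)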
Example #6
    def __init__(self, cont, logits=None, probs=None, validate_args=None):
        """
        cont: a (properly normalised) distribution over (0, 1)
            e.g. RightTruncatedExponential, Uniform(0, 1)
        logits: [..., 3] 
        probs: [..., 3]
        """
        if logits is None and probs is None:
            raise ValueError("You must specify either logits or probs")
        if logits is not None and probs is not None:
            raise ValueError("You cannot specify both logits and probs")
        shape = cont.batch_shape
        super(MixtureD01C01, self).__init__(batch_shape=shape,
                                            validate_args=validate_args)
        if logits is None:
            self.logits = probs_to_logits(probs, is_binary=False)
            self.probs = probs
        else:
            self.logits = logits
            self.probs = logits_to_probs(logits, is_binary=False)

        self.logprobs = F.log_softmax(self.logits, dim=-1)
        self.cont = cont
        self.p0, self.p1, self.pc = [
            t.squeeze(-1) for t in torch.split(self.probs, 1, dim=-1)
        ]
        self.log_p0, self.log_p1, self.log_pc = [
            t.squeeze(-1) for t in torch.split(self.logprobs, 1, dim=-1)
        ]
        self.uniform = Uniform(
            torch.zeros(shape).to(self.logits.device),
            torch.ones(shape).to(self.logits.device))
Example #7
 def aggregate_predictions(self, predictions, dim=0):
     probs = dist_utils.logits_to_probs(
         predictions, is_binary=self.is_binary
     ) if self.logit_predictions else predictions
     avg_probs = probs.mean(dim)
     return dist_utils.probs_to_logits(
         avg_probs,
         is_binary=self.is_binary) if self.logit_predictions else avg_probs
Example #8
def KL_div(target_state, nn_probs):
    """
    target_state:   torch.Tensor
                    the state to be reconstructed
    nn_probs:       torch.Tensor
                    probabilities of each basis state, predicted by NN

    returns:        float
                    the KL divergence of distributions
    """
    targ_probs = torch.pow(torch.abs(target_state), 2)

    div = torch.sum(targ_probs * probs_to_logits(targ_probs)) - torch.sum(
        targ_probs * probs_to_logits(nn_probs))

    return div.item()
Example #9
def density_matrix_KL(nn_state, target, bases, v_space, a_space):
    """Computes the KL divergence between the current and target density matrix

    :param target: The target density matrix
    :type target: torch.Tensor
    :param bases: The bases in which measurement is made
    :type bases: numpy.ndarray
    :param v_space: The space of the visible states
    :type v_space: torch.Tensor
    :param a_space: The space of the auxiliary states
    :type a_space: torch.Tensor
    :returns: The KL divergence
    :rtype: float
    """
    Z = nn_state.rbm_am.partition(v_space, a_space)
    unitary_dict = nn_state.unitary_dict
    rho_r_diag = torch.zeros(2, 2**nn_state.num_visible, dtype=torch.double)
    target_rho_r_diag = torch.zeros_like(rho_r_diag)

    KL = 0.0

    for basis in bases:
        rho_r = nn_state.rotate_rho(basis, v_space, Z, unitary_dict)
        target_rho_r = nn_state.rotate_rho(basis,
                                           v_space,
                                           Z,
                                           unitary_dict,
                                           rho=target)

        rho_r_diag[0] = torch.diagonal(rho_r[0])
        rho_r_diag[1] = torch.diagonal(rho_r[1])
        target_rho_r_diag[0] = torch.diagonal(target_rho_r[0])
        target_rho_r_diag[1] = torch.diagonal(target_rho_r[1])

        KL += torch.sum(target_rho_r_diag[0] *
                        probs_to_logits(target_rho_r_diag[0]))
        KL -= torch.sum(target_rho_r_diag[0] * probs_to_logits(rho_r_diag[0]))

    KL /= float(len(bases))

    return KL.item()
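A minimal sketch (assumption, not QuCumber code): the KL above only compares the diagonals of the rotated density matrices, i.e. the measurement probabilities in each basis.

import torch

rho = torch.tensor([[0.7, 0.1], [0.1, 0.3]], dtype=torch.double)  # toy 2x2 density matrix
probs = torch.diagonal(rho)  # measurement probabilities; they sum to Tr(rho) = 1
assert torch.isclose(probs.sum(), torch.tensor(1.0, dtype=torch.double))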
Example #10
    def forward(self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens,
                **kwargs):
        assert not self.decoder.src_embedding_copy, "do not support embedding copy."

        # encoding
        encoder_out = self.encoder(src_tokens,
                                   src_lengths=src_lengths,
                                   **kwargs)

        # length prediction
        length_out = self.decoder.forward_length(normalize=False,
                                                 encoder_out=encoder_out)
        length_tgt = self.decoder.forward_length_prediction(
            length_out, encoder_out, tgt_tokens)

        # posterior & prior
        prior_tokens = self.initialize_prior_input(prev_output_tokens)
        prior_out = self.prior(encoder_out, prior_tokens)
        posterior_out = self.posterior(encoder_out, prev_output_tokens)

        # decoding
        word_ins_out = self.decoder(
            normalize=False,
            prev_output_tokens=prev_output_tokens,  # prev_output_tokens, (B, T)
            prev_output_embeds=posterior_out.out,  # need to be same as pos emb (B, Ty, dim)
            encoder_out=encoder_out)

        word_ins_mask = tgt_tokens.ne(self.pad)

        return {
            "word_ins": {
                "out": word_ins_out,
                "tgt": tgt_tokens,
                "mask": word_ins_mask,
                "ls": self.args.label_smoothing,
                "nll_loss": True
            },
            "length": {
                "out": length_out,
                "tgt": length_tgt,
                "factor": self.decoder.length_loss_factor
            },
            # this will leave nat_loss to compute kl_div for you.
            "kl_div": {
                "out": probs_to_logits(prior_out.attn),
                "tgt": posterior_out.attn.detach(
                ),  # cannot let gradient go through your target!
                "mask": tgt_tokens.ne(self.pad),
                "factor": self.kl_div_loss_factor
            }
        }
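A minimal sketch (my assumption, not the fairseq criterion) of the KL term built from the "kl_div" entry above: `out` holds prior log-probabilities obtained with probs_to_logits, `tgt` holds detached posterior probabilities, and the loss is KL(posterior || prior) over the unmasked positions.

import torch
import torch.nn.functional as F
from torch.distributions.utils import probs_to_logits

posterior_attn = torch.softmax(torch.randn(2, 5, 7), dim=-1)  # (B, T, src_len); hypothetical shapes
prior_attn = torch.softmax(torch.randn(2, 5, 7), dim=-1)
mask = torch.ones(2, 5, dtype=torch.bool)

out = probs_to_logits(prior_attn)   # log of (clamped) prior probabilities
tgt = posterior_attn.detach()       # gradients must not flow into the target
kl = F.kl_div(out, tgt, reduction="none").sum(-1)  # KL(tgt || prior) per position
loss = kl[mask].mean()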
Example #11
def log_bernoulli(probs, target):
    """
    Args:
        logit:  [B, X]
        target: [B, X]
    
    Returns:
        output:      [B]
    """
    logit = probs_to_logits(probs, is_binary=True)
    loss = -F.relu(logit) + torch.mul(
        target, logit) - torch.log(1. + torch.exp(-logit.abs()))
    loss = torch.sum(loss, 1)
    return loss
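A minimal sketch (assumption) checking the stabilised expression above: it equals the Bernoulli log-likelihood, i.e. the negative of binary cross-entropy with logits.

import torch
import torch.nn.functional as F
from torch.distributions.utils import probs_to_logits

probs = torch.rand(4, 3) * 0.98 + 0.01  # keep probabilities away from 0 and 1
target = torch.bernoulli(torch.rand(4, 3))

logit = probs_to_logits(probs, is_binary=True)
manual = -F.relu(logit) + torch.mul(target, logit) - torch.log(1. + torch.exp(-logit.abs()))
reference = -F.binary_cross_entropy_with_logits(logit, target, reduction="none")
assert torch.allclose(manual, reference, atol=1e-6)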
Example #12
def gumbel_softmax(probs, temperature, hard=False):
    """
    input: [*, n_class]
    return: [*, n_class] an one-hot vector
    """
    logits = probs_to_logits(probs)
    y = gumbel_softmax_sample(logits, temperature)
    shape = y.size()
    if hard:
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
        return (y_hard - y).detach() + y
    else:
        return y
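A minimal sketch (assumption): PyTorch's built-in F.gumbel_softmax does the same job starting from logits, so the helper above is essentially probs_to_logits followed by the built-in.

import torch
import torch.nn.functional as F
from torch.distributions.utils import probs_to_logits

probs = torch.tensor([[0.7, 0.2, 0.1]])
y = F.gumbel_softmax(probs_to_logits(probs), tau=0.5, hard=True)
assert y.shape == (1, 3) and y.sum().item() == 1.0  # straight-through one-hot sample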
Example #13
    def select_action(self,
                      state,
                      rand_flag=False,
                      eps_flag=False,
                      eps_value=1.0,
                      train_flag=False):
        def get_reverse_prob(probs):
            # assume probs.size() is size([1, action_size])
            rev_idxs = torch.arange(probs.size(-1) - 1,
                                    -1,
                                    -1,
                                    device=probs.device).long()
            with torch.no_grad():
                rev_probs = torch.index_select(probs, -1, rev_idxs)
            return rev_probs

        if train_flag:
            self.policy_net.train()
            action_probs = self.policy_net(state)  # size([1, action_size])
        else:
            self.policy_net.eval()
            with torch.no_grad():
                action_probs = self.policy_net(state)  # size([1, action_size])
        action_logps = probs_to_logits(action_probs)  # size([1, action_size])

        if rand_flag:
            if eps_flag and random.random() < eps_value:
                # print('use epsilon random policy')
                action_rev_probs = get_reverse_prob(action_probs)
                m = dist.Categorical(probs=action_rev_probs)
            else:
                m = dist.Categorical(probs=action_probs)
            action = m.sample()  # size([1])
            # action_logp = m.log_prob(action)  # size([1])
        else:
            action = torch.argmax(action_probs, dim=-1)  # size([1])

        assert action.requires_grad is False
        action_logp = action_logps.gather(-1, action.unsqueeze(0)).squeeze(
            -1)  # size([1])
        action = action.item()
        self.episode_actions.append(action)
        self.episode_action_logps.append(action_logp)
        self.episode_action_probs.append(action_probs)

        return action
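A minimal sketch (assumption) of the last step of select_action above: gather picks the log-probability of the chosen action out of the full row of log-probabilities.

import torch
from torch.distributions.utils import probs_to_logits

action_probs = torch.tensor([[0.1, 0.6, 0.3]])  # size([1, action_size])
action_logps = probs_to_logits(action_probs)    # log-probabilities
action = torch.argmax(action_probs, dim=-1)     # size([1])
action_logp = action_logps.gather(-1, action.unsqueeze(0)).squeeze(-1)
assert torch.allclose(action_logp, torch.tensor([0.6]).log(), atol=1e-6)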
Example #14
def _single_basis_KL(target_probs, nn_probs):
    return torch.sum(target_probs * probs_to_logits(target_probs)) - torch.sum(
        target_probs * probs_to_logits(nn_probs))
Example #15
 def _logitsfn(self):
     return lambda conds: tcdu.probs_to_logits(self._probsfn(conds))
Example #16
 def gate_logits(self):
     return probs_to_logits(self.gate)
Example #17
 def logits(self):
     return probs_to_logits(self.probs, is_binary=True)
Example #18
 def zi_logits(self) -> torch.Tensor:
     return probs_to_logits(self.zi_probs, is_binary=True)
Example #19
 def _logitsfn(self):
     return lambda conds: tcdu.probs_to_logits(self._probsfn(conds),
                                               is_binary=True)
Example #20
 def logits(self):
     return probs_to_logits(self.weights)
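A minimal sketch (assumption) behind these small logits() properties: probs_to_logits and logits_to_probs are inverses of each other, in both the binary and the categorical case.

import torch
from torch.distributions.utils import logits_to_probs, probs_to_logits

p = torch.tensor([0.25, 0.75])  # binary probabilities
assert torch.allclose(
    logits_to_probs(probs_to_logits(p, is_binary=True), is_binary=True), p, atol=1e-6)

w = torch.tensor([0.2, 0.3, 0.5])  # categorical weights
assert torch.allclose(logits_to_probs(probs_to_logits(w)), w, atol=1e-6)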
Example #21
def KL(nn_state, target_psi, space, bases=None, **kwargs):
    r"""A function for calculating the total KL divergence.

    .. math:: KL(P_{target} \vert P_{RBM}) = \sum_{x \in \mathcal{H}} P_{target}(x)\log(\frac{P_{RBM}(x)}{P_{target}(x)})

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: qucumber.nn_states.WaveFunctionBase
    :param target_psi: The true wavefunction of the system. Can be a dictionary
                       with each value being the wavefunction represented in a
                       different basis, and the key identifying the basis.
    :type target_psi: torch.Tensor or dict(str, torch.Tensor)
    :param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
                  The ordering of the basis elements must match with the ordering of the
                  coefficients given in `target_psi`.
    :type space: torch.Tensor
    :param bases: An array of unique bases. If given, the KL divergence will be
                  computed for each basis and the average will be returned.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    psi_r = torch.zeros(2,
                        1 << nn_state.num_visible,
                        dtype=torch.double,
                        device=nn_state.device)
    KL = 0.0

    if isinstance(target_psi, dict):
        target_psi = {k: v.to(nn_state.device) for k, v in target_psi.items()}
        if bases is None:
            bases = list(target_psi.keys())
        else:
            assert set(bases) == set(
                target_psi.keys()
            ), "Given bases must match the keys of the target_psi dictionary."
    else:
        target_psi = target_psi.to(nn_state.device)

    Z = nn_state.compute_normalization(space)
    if bases is None:
        target_probs = cplx.absolute_value(target_psi)**2
        nn_probs = nn_state.probability(space, Z)
        KL += torch.sum(target_probs * probs_to_logits(target_probs))
        KL -= torch.sum(target_probs * probs_to_logits(nn_probs))
    else:
        unitary_dict = nn_state.unitary_dict
        for basis in bases:
            psi_r = rotate_psi(nn_state, basis, space, unitary_dict)
            if isinstance(target_psi, dict):
                target_psi_r = target_psi[basis]
            else:
                target_psi_r = rotate_psi(nn_state, basis, space, unitary_dict,
                                          target_psi)

            probs_r = (cplx.absolute_value(psi_r)**2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r)**2

            KL += torch.sum(target_probs_r * probs_to_logits(target_probs_r))
            KL -= torch.sum(target_probs_r * probs_to_logits(probs_r))
        KL /= float(len(bases))

    return KL.item()
Example #22
 def forward(self, outputs, context=None):
     inputs = probs_to_logits(
         outputs,
         is_binary=True)  # stable implementation of inverse sigmoid
     log_p, log_q = -F.softplus(-inputs), -F.softplus(inputs)
     return inputs, -log_p - log_q
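A minimal sketch (assumption) of the identities used above: probs_to_logits(..., is_binary=True) is a stable inverse sigmoid, -softplus(-x) equals log(sigmoid(x)), and -softplus(x) equals log(1 - sigmoid(x)).

import torch
import torch.nn.functional as F
from torch.distributions.utils import probs_to_logits

outputs = torch.tensor([0.1, 0.5, 0.9])
inputs = probs_to_logits(outputs, is_binary=True)  # inverse sigmoid
log_p, log_q = -F.softplus(-inputs), -F.softplus(inputs)
assert torch.allclose(log_p, torch.log(torch.sigmoid(inputs)), atol=1e-6)
assert torch.allclose(log_q, torch.log(1 - torch.sigmoid(inputs)), atol=1e-6)
assert torch.allclose(torch.sigmoid(inputs), outputs, atol=1e-6)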
Example #23
def compute_log_pdf_bernoulli(x, p):

    logits = probs_to_logits(p, is_binary=True)
    logits, x = broadcast_all(logits, x)

    return -F.binary_cross_entropy_with_logits(logits, x, reduction='sum')
Example #24
 def __init__(self, p=0.5):
     super(CDropout, self).__init__()
     self.p_logit = nn.Parameter(
         probs_to_logits(torch.as_tensor(p), is_binary=True))
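A minimal sketch (assumption) of the parameterisation above: the module learns the dropout probability through its logit, and sigmoid recovers p.

import torch
import torch.nn as nn
from torch.distributions.utils import probs_to_logits

p_logit = nn.Parameter(probs_to_logits(torch.as_tensor(0.5), is_binary=True))
assert torch.isclose(torch.sigmoid(p_logit), torch.tensor(0.5))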
Example #25
 def logits(self):
     return probs_to_logits(self.probs)
Example #26
    def __init__(self, output_dim,
                 grid_z, grid_c, grid_cz,
                 mapping_z=None, mapping_c=None, mapping_cz=None,
                 has_feature_level_sparsity=True,
                 penalty_type="fixed", lambda0=1.0,
                 likelihood="Gaussian",
                 p1=0.2, p2=0.2, p3=0.2, device="cpu"):
        """
        NN mapping with constraints to be used as the decoder in a CVAE. Performs Neural Decomposition.
        :param output_dim: data dimensionality
        :param grid_z: grid for quadrature (estimation of integral for f(z))
        :param grid_c: grid for quadrature (estimation of integral for f(c))
        :param grid_cz: grid for quadrature (estimation of integral for f(z, c))
        :param mapping_z: neural net mapping z to data
        :param mapping_c: neural net mapping c to data
        :param mapping_cz: neural net mapping (z, c) to data
        :param has_feature_level_sparsity: whether to use (Relaxed) Bernoulli feature-level sparsity
        :param penalty_type: which penalty to apply
        :param lambda0: initialisation for both fixed penalty $c$ as well as $lambda$ values
        :param likelihood: Gaussian or Bernoulli
        :param p1: Bernoulli prior for sparsity on mapping_z
        :param p2: Bernoulli prior for sparsity on mapping_c
        :param p3: Bernoulli prior for sparsity on mapping_cz
        :param device: cpu or cuda
        """
        super().__init__()

        self.output_dim = output_dim
        self.likelihood = likelihood
        self.has_feature_level_sparsity = has_feature_level_sparsity
        self.penalty_type = penalty_type

        self.grid_z = grid_z.to(device)
        self.grid_c = grid_c.to(device)
        self.grid_cz = grid_cz.to(device)

        self.n_grid_z = grid_z.shape[0]
        self.n_grid_c = grid_c.shape[0]
        self.n_grid_cz = grid_cz.shape[0]

        # input -> output
        self.mapping_z = mapping_z
        self.mapping_c = mapping_c
        self.mapping_cz = mapping_cz

        if self.likelihood == "Gaussian":
            # feature-specific variances (for Gaussian likelihood)
            self.noise_sd = torch.nn.Parameter(-1.0 * torch.ones(1, output_dim))

        self.intercept = torch.nn.Parameter(torch.zeros(1, output_dim))

        self.Lambda_z = Variable(lambda0*torch.ones(1, output_dim, device=device), requires_grad=True)

        self.Lambda_c = Variable(lambda0*torch.ones(1, output_dim, device=device), requires_grad=True)

        self.Lambda_cz_1 = Variable(lambda0*torch.ones(self.n_grid_z, output_dim, device=device), requires_grad=True)

        self.Lambda_cz_2 = Variable(lambda0*torch.ones(self.n_grid_c, output_dim, device=device), requires_grad=True)

        self.lambda0 = lambda0

        self.device = device

        # RelaxedBernoulli
        self.temperature = 1.0 * torch.ones(1, device=device)

        if self.has_feature_level_sparsity:

            # for the prior RelaxedBernoulli(logits)
            self.logits_z = probs_to_logits(p1 * torch.ones(1, output_dim).to(device), is_binary=True)
            self.logits_c = probs_to_logits(p2 * torch.ones(1, output_dim).to(device), is_binary=True)
            self.logits_cz = probs_to_logits(p3 * torch.ones(1, output_dim).to(device), is_binary=True)

            # for the approx posterior
            self.qlogits_z = torch.nn.Parameter(3.0 * torch.ones(1, output_dim).to(device))
            self.qlogits_c = torch.nn.Parameter(3.0 * torch.ones(1, output_dim).to(device))
            self.qlogits_cz = torch.nn.Parameter(2.0 * torch.ones(1, output_dim).to(device))
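A minimal sketch (assumption) of how the prior logits above would be consumed: a RelaxedBernoulli built from the stored temperature and logits, yielding continuous relaxations of the sparsity gates in (0, 1).

import torch
from torch.distributions import RelaxedBernoulli
from torch.distributions.utils import probs_to_logits

output_dim = 4
temperature = torch.ones(1)
logits_z = probs_to_logits(0.2 * torch.ones(1, output_dim), is_binary=True)
prior = RelaxedBernoulli(temperature, logits=logits_z)
gates = prior.rsample()  # values lie in (0, 1)
assert gates.shape == (1, output_dim)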
Example #27
 def logits(self):
     return probs_to_logits(self.probs)
Example #28
    def forward(self,
                img_enc,
                alpha,
                tau,
                t,
                gen_pres_probs=None,
                gen_depth_mean=None,
                gen_depth_std=None,
                gen_where_mean=None,
                gen_where_std=None):
        """

        :param x: (bs, dim, img_h, img_w)
        :param propagate_encode: (bs, propagate_encode_dim)
        :param tau:
        :return:
        """

        bs = img_enc.size(0)

        if self.args.phase_generate and t >= self.args.observe_frames:
            gen_pres_logits = probs_to_logits(gen_pres_probs, is_binary=True).view(1, 1, 1, 1). \
                expand(bs, -1, self.args.num_cell_h, self.args.num_cell_w)
            z_pres_logits = gen_pres_logits
            z_depth_mean, z_depth_std = gen_depth_mean.view(1, -1, 1, 1).expand(bs, -1, self.args.num_cell_h,
                                                                                self.args.num_cell_w), \
                                        gen_depth_std.view(1, -1, 1, 1).expand(bs, -1, self.args.num_cell_h,
                                                                               self.args.num_cell_w)
            z_where_mean, z_where_std = gen_where_mean.view(1, -1, 1, 1).expand(bs, -1, self.args.num_cell_h,
                                                                                self.args.num_cell_w), \
                                        gen_where_std.view(1, -1, 1, 1).expand(bs, -1, self.args.num_cell_h,
                                                                               self.args.num_cell_w)
        else:
            mask_enc = self.mask_enc_net(alpha)

            x_alpha_enc = torch.cat((img_enc, mask_enc), dim=1)

            cat_enc = self.img_mask_cat_enc(x_alpha_enc)

            # (bs, 1, 8, 8)
            z_pres_logits = pres_logit_factor * torch.tanh(
                self.z_pres_net(cat_enc) + self.z_pres_bias)

            # (bs, dim, 8, 8)
            z_depth_mean, z_depth_std = self.z_depth_net(cat_enc).chunk(2, 1)
            z_depth_std = F.softplus(z_depth_std)
            # (bs, 4 + 4, 8, 8)
            z_where_mean, z_where_std = self.z_where_net(cat_enc).chunk(2, 1)
            z_where_std = F.softplus(z_where_std)

        q_z_pres = NumericalRelaxedBernoulli(logits=z_pres_logits,
                                             temperature=tau)
        z_pres_y = q_z_pres.rsample()

        z_pres = torch.sigmoid(z_pres_y)

        q_z_depth = Normal(z_depth_mean, z_depth_std)

        z_depth = q_z_depth.rsample()

        q_z_where = Normal(z_where_mean, z_where_std)

        z_where = q_z_where.rsample()

        # (bs, dim, 8, 8)
        z_where_origin = z_where.clone()

        scale, ratio = z_where[:, :2].tanh().chunk(2, 1)
        scale = self.args.size_anc + self.args.var_s * scale
        ratio = self.args.ratio_anc + self.args.var_anc * ratio
        ratio_sqrt = ratio.sqrt()
        z_where[:, 0:1] = scale / ratio_sqrt
        z_where[:, 1:2] = scale * ratio_sqrt
        z_where[:,
                2:] = 2. / self.args.num_cell_h * (self.offset + 0.5 +
                                                   z_where[:, 2:].tanh()) - 1

        z_where = z_where.permute(0, 2, 3, 1).reshape(-1, 4)

        return z_where, z_pres, z_depth, z_where_mean, z_where_std, \
               z_depth_mean, z_depth_std, z_pres_logits, z_pres_y, z_where_origin
Example #29
    def batch_interact_with(self,
                            env,
                            sample_cnt,
                            fix_example=None,
                            train_flag=False,
                            rand_flag=False,
                            eps_flag=False,
                            eps_value=0.1):
        # get next example batch
        env.next_example(example=fix_example)

        # Size(seq_len, batch_size, action_size)
        seq_action_probs = self.batch_calcu_action_prob(env.input_emb_seq,
                                                        env.input_seq_lens,
                                                        train_flag=train_flag)
        seq_action_logps = probs_to_logits(seq_action_probs)

        batch_seq_lens = env.input_seq_lens.tolist()
        batch_size = len(batch_seq_lens)
        batch_mean_rewards = torch.zeros_like(
            env.input_seq_lens).float()  # Size(batch_size)

        mask_list = []
        reward_list = []
        for sample_idx in range(sample_cnt):
            # reset episode variables of env and agent
            env.reset_state()
            self.reset_episode_info()

            # Size([seq_len, batch_size, action_size]), dtype=torch.float
            seq_action_masks = self.batch_select_action(seq_action_probs,
                                                        batch_seq_lens,
                                                        rand_flag=rand_flag,
                                                        eps_flag=eps_flag,
                                                        eps_value=eps_value)

            batch_rewards = env.batch_transition_with(
                seq_action_masks)  # Size([batch_size])
            batch_mean_rewards += batch_rewards  # Size([batch_size])
            # TODO: whether to add reward decay strategy
            seq_action_rewards = batch_rewards.unsqueeze(0).unsqueeze(
                -1).expand_as(seq_action_masks)
            # for torch 1.1.0
            seq_action_rewards = seq_action_rewards.contiguous()

            assert seq_action_masks.requires_grad is False
            assert seq_action_rewards.requires_grad is False
            mask_list.append(seq_action_masks)
            reward_list.append(seq_action_rewards)

        # reduce reward variance
        if sample_cnt > 1 or batch_size > 1:
            batch_mean_rewards /= sample_cnt  # Size([batch_size])
            base_mean_rewards = batch_mean_rewards.unsqueeze(0).unsqueeze(
                -1).expand_as(seq_action_probs)
            # cur_baseline = torch.sum(batch_mean_rewards) / batch_size  # Size([])
            for sample_idx in range(sample_cnt):
                # reward_list[sample_idx] -= cur_baseline  # Size([batch_size])
                reward_list[
                    sample_idx] -= base_mean_rewards  # Size([batch_size])

        return seq_action_probs, seq_action_logps, mask_list, reward_list, batch_mean_rewards
Example #30
 def logits(self):
     return probs_to_logits(self.probs, is_binary=True)
Example #31
    def __init__(self,
                 output_dim,
                 n_covariates,
                 grid_z,
                 grid_c,
                 mapping_z,
                 mappings_c,
                 mappings_cz,
                 has_feature_level_sparsity=True,
                 penalty_type="fixed",
                 lambda0=1.0,
                 likelihood="Gaussian",
                 p1=0.2,
                 p2=0.2,
                 p3=0.2,
                 device="cpu"):
        """
        Decoder for multiple covariates (i.e. multivariate c)
        """
        super().__init__()

        self.output_dim = output_dim
        self.likelihood = likelihood
        self.has_feature_level_sparsity = has_feature_level_sparsity
        self.penalty_type = penalty_type
        self.n_covariates = n_covariates

        assert isinstance(grid_c, list), "grid_c must be a list"
        assert len(grid_c) == n_covariates

        self.grid_z = grid_z
        self.grid_c = grid_c
        self.grid_cz = [
            torch.cat(expand_grid(self.grid_z, self.grid_c[j]), dim=1)
            for j in range(self.n_covariates)
        ]

        self.grid_z = grid_z.to(device)
        self.grid_c = [c.to(device) for c in grid_c]
        self.grid_cz = [cz.to(device) for cz in self.grid_cz]

        self.n_grid_z = grid_z.shape[0]
        self.n_grid_c = [c.shape[0] for c in grid_c]
        self.n_grid_cz = [cz.shape[0] for cz in self.grid_cz]

        # input -> output
        self.mapping_z = mapping_z
        self.mappings_c = mappings_c
        self.mappings_cz = mappings_cz

        if self.likelihood == "Gaussian":
            # feature-specific variances (for Gaussian likelihood)
            self.noise_sd = torch.nn.Parameter(-1.0 *
                                               torch.ones(1, output_dim))

        self.intercept = torch.nn.Parameter(torch.zeros(1, output_dim))

        self.Lambda_z = Variable(torch.ones(1, output_dim, device=device),
                                 requires_grad=True)

        self.Lambda_c = [
            Variable(torch.ones(n_covariates, 1, output_dim, device=device),
                     requires_grad=True) for _ in range(self.n_covariates)
        ]

        self.Lambda_cz_1 = [
            Variable(torch.ones(self.n_grid_z, output_dim, device=device),
                     requires_grad=True) for _ in range(self.n_covariates)
        ]

        self.Lambda_cz_2 = [
            Variable(torch.ones(self.n_grid_c[j], output_dim, device=device),
                     requires_grad=True) for j in range(self.n_covariates)
        ]

        self.lambda0 = lambda0

        self.device = device

        # RelaxedBernoulli
        self.temperature = 1.0 * torch.ones(1, device=device)

        if self.has_feature_level_sparsity:

            # for the prior RelaxedBernoulli(logits)
            self.logits_z = probs_to_logits(
                p1 * torch.ones(1, output_dim).to(device), is_binary=True)
            self.logits_c = probs_to_logits(
                p2 * torch.ones(n_covariates, output_dim).to(device),
                is_binary=True)
            self.logits_cz = probs_to_logits(
                p3 * torch.ones(n_covariates, output_dim).to(device),
                is_binary=True)

            # for the approx posterior
            self.qlogits_z = torch.nn.Parameter(
                3.0 * torch.ones(1, output_dim).to(device))
            self.qlogits_c = torch.nn.Parameter(
                3.0 * torch.ones(n_covariates, output_dim).to(device))
            self.qlogits_cz = torch.nn.Parameter(
                2.0 * torch.ones(n_covariates, output_dim).to(device))