def KL(nn_state, target_psi, space, bases=None, **kwargs):
    r"""A function for calculating the total KL divergence.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param target_psi: The true wavefunction of the system. Can be a dictionary
                       with each value being the wavefunction represented in a
                       different basis.
    :type target_psi: torch.Tensor or dict(str, torch.Tensor)
    :param space: The Hilbert space of the system.
    :type space: torch.Tensor
    :param bases: An array of unique bases.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    psi_r = torch.zeros(
        2, 1 << nn_state.num_visible, dtype=torch.double, device=nn_state.device
    )
    KL = 0.0

    if isinstance(target_psi, dict):
        target_psi = {k: v.to(nn_state.device) for k, v in target_psi.items()}
        if bases is None:
            bases = list(target_psi.keys())
        else:
            assert set(bases) == set(
                target_psi.keys()
            ), "Given bases must match the keys of the target_psi dictionary."
    else:
        target_psi = target_psi.to(nn_state.device)

    Z = nn_state.compute_normalization(space)
    if bases is None:
        target_probs = cplx.absolute_value(target_psi) ** 2
        nn_probs = nn_state.probability(space, Z)
        KL += torch.sum(target_probs * probs_to_logits(target_probs))
        KL -= torch.sum(target_probs * probs_to_logits(nn_probs))
    else:
        unitary_dict = nn_state.unitary_dict
        for basis in bases:
            psi_r = rotate_psi(nn_state, basis, space, unitary_dict)
            if isinstance(target_psi, dict):
                target_psi_r = target_psi[basis]
            else:
                target_psi_r = rotate_psi(nn_state, basis, space, unitary_dict, target_psi)

            probs_r = (cplx.absolute_value(psi_r) ** 2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r) ** 2

            KL += torch.sum(target_probs_r * probs_to_logits(target_probs_r))
            KL -= torch.sum(target_probs_r * probs_to_logits(probs_r))

        KL /= float(len(bases))

    return KL.item()
def NLL(nn_state, samples, space, train_bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The Hilbert space of the system.
    :type space: torch.Tensor
    :param train_bases: An array of bases where measurements were taken.
    :type train_bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The Negative Log-Likelihood.
    :rtype: float
    """
    psi_r = torch.zeros(
        2, 1 << nn_state.num_visible, dtype=torch.double, device=nn_state.device
    )
    NLL = 0.0
    Z = nn_state.compute_normalization(space)

    if train_bases is None:
        nn_probs = nn_state.probability(samples, Z)
        # accumulate the *negative* log-probability of the samples
        NLL = -torch.sum(probs_to_logits(nn_probs))
    else:
        unitary_dict = nn_state.unitary_dict
        for i in range(len(samples)):
            # Check whether the sample was measured in the reference (Z) basis.
            is_reference_basis = True
            for j in range(nn_state.num_visible):
                if train_bases[i][j] != "Z":
                    is_reference_basis = False
                    break

            if is_reference_basis:
                nn_probs = nn_state.probability(samples[i], Z)
                NLL -= torch.sum(probs_to_logits(nn_probs))
            else:
                psi_r = rotate_psi(nn_state, train_bases[i], space, unitary_dict)
                # Get the index of the sample state in the Hilbert space.
                ind = 0
                for j in range(nn_state.num_visible):
                    if samples[i, nn_state.num_visible - j - 1] == 1:
                        ind += pow(2, j)
                probs_r = cplx.norm_sqr(psi_r[:, ind]) / Z
                NLL -= probs_to_logits(probs_r).item()

    return (NLL / float(len(samples))).item()
def NLL(nn_state, samples, space=None, sample_bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood (NLL).

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The basis elements of the Hilbert space of the system
                  :math:`\mathcal{H}`. If `None`, will generate them using
                  the provided `nn_state`.
    :type space: torch.Tensor
    :param sample_bases: An array of bases where measurements were taken.
    :type sample_bases: numpy.ndarray
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The Negative Log-Likelihood.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)

    if sample_bases is None:
        nn_probs = nn_state.probability(samples, Z)
        NLL_ = -torch.mean(probs_to_logits(nn_probs)).item()
        return NLL_
    else:
        NLL_ = 0.0

        unique_bases, indices = np.unique(sample_bases, axis=0, return_inverse=True)
        indices = torch.Tensor(indices).to(samples)

        for i in range(unique_bases.shape[0]):
            basis = unique_bases[i, :]
            rot_sites = np.where(basis != "Z")[0]

            if rot_sites.size != 0:
                if isinstance(nn_state, WaveFunctionBase):
                    Upsi = rotate_psi_inner_prod(
                        nn_state, basis, samples[indices == i, :]
                    )
                    nn_probs = (cplx.absolute_value(Upsi) ** 2) / Z
                else:
                    nn_probs = (
                        rotate_rho_probs(nn_state, basis, samples[indices == i, :]) / Z
                    )
            else:
                nn_probs = nn_state.probability(samples[indices == i, :], Z)

            NLL_ -= torch.sum(probs_to_logits(nn_probs))

        return NLL_ / float(len(samples))
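# In the `sample_bases is None` branch above, `probs_to_logits` (non-binary) is just a clamped
# log, so the NLL is the negative mean log-probability of the samples. A minimal, self-contained
# illustration with made-up toy tensors (no qucumber objects required):
import torch
from torch.distributions.utils import probs_to_logits

nn_probs = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.double)  # toy model distribution
sample_probs = nn_probs[torch.tensor([0, 3, 3, 2])]  # probabilities assigned to four drawn samples

nll = -torch.mean(probs_to_logits(sample_probs)).item()

# for the non-binary case, probs_to_logits(p) == log(p) (up to clamping at the boundaries)
assert torch.allclose(probs_to_logits(sample_probs), torch.log(sample_probs))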
def __init__(
    self,
    total_count: Optional[torch.Tensor] = None,
    probs: Optional[torch.Tensor] = None,
    logits: Optional[torch.Tensor] = None,
    mu: Optional[torch.Tensor] = None,
    theta: Optional[torch.Tensor] = None,
    validate_args: bool = False,
):
    self._eps = 1e-8
    if (mu is None) == (total_count is None):
        raise ValueError(
            "Please use one of the two possible parameterizations. Refer to the documentation for more information."
        )

    using_param_1 = total_count is not None and (
        logits is not None or probs is not None
    )
    if using_param_1:
        logits = logits if logits is not None else probs_to_logits(probs)
        total_count = total_count.type_as(logits)
        total_count, logits = broadcast_all(total_count, logits)
        mu, theta = _convert_counts_logits_to_mean_disp(total_count, logits)
    else:
        mu, theta = broadcast_all(mu, theta)
    self.mu = mu
    self.theta = theta
    super().__init__(validate_args=validate_args)
def __init__(self, cont, logits0=None, probs0=None, validate_args=None):
    """
    - with probability p_0 = sigmoid(logits0) this returns 0
    - with probability 1 - p_0 this returns a sample in the open interval (0, 1)

    logits0: logits for p_0
    cont: a (properly normalised) distribution over (0, 1),
        e.g. RightTruncatedExponential
    """
    if logits0 is None and probs0 is None:
        raise ValueError("You must specify either logits0 or probs0")
    if logits0 is not None and probs0 is not None:
        raise ValueError("You cannot specify both logits0 and probs0")
    shape = cont.batch_shape
    super(MixtureD0C01, self).__init__(batch_shape=shape, validate_args=validate_args)
    if logits0 is None:
        self.logits = probs_to_logits(probs0, is_binary=True)
    else:
        self.logits = logits0
    self.cont = cont
    self.p0, self.pc = bernoulli_probs_from_logit(self.logits)
    self.log_p0, self.log_pc = bernoulli_log_probs_from_logit(self.logits)
    self.uniform = Uniform(
        torch.zeros(shape).to(self.logits.device),
        torch.ones(shape).to(self.logits.device),
    )
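# The binary form of probs_to_logits used above is the numerically stable inverse sigmoid.
# A quick standalone check with toy values, independent of the mixture class itself:
import torch
from torch.distributions.utils import probs_to_logits

p0 = torch.tensor([0.05, 0.5, 0.9])
logits0 = probs_to_logits(p0, is_binary=True)

assert torch.allclose(logits0, torch.log(p0) - torch.log1p(-p0))  # logit(p) = log(p) - log(1 - p)
assert torch.allclose(torch.sigmoid(logits0), p0)                 # sigmoid inverts it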
def __init__(self, cont, logits=None, probs=None, validate_args=None):
    """
    cont: a (properly normalised) distribution over (0, 1),
        e.g. RightTruncatedExponential, Uniform(0, 1)
    logits: [..., 3]
    probs: [..., 3]
    """
    if logits is None and probs is None:
        raise ValueError("You must specify either logits or probs")
    if logits is not None and probs is not None:
        raise ValueError("You cannot specify both logits and probs")
    shape = cont.batch_shape
    super(MixtureD01C01, self).__init__(batch_shape=shape, validate_args=validate_args)
    if logits is None:
        self.logits = probs_to_logits(probs, is_binary=False)
        self.probs = probs
    else:
        self.logits = logits
        self.probs = logits_to_probs(logits, is_binary=False)
    self.logprobs = F.log_softmax(self.logits, dim=-1)
    self.cont = cont
    self.p0, self.p1, self.pc = [
        t.squeeze(-1) for t in torch.split(self.probs, 1, dim=-1)
    ]
    self.log_p0, self.log_p1, self.log_pc = [
        t.squeeze(-1) for t in torch.split(self.logprobs, 1, dim=-1)
    ]
    self.uniform = Uniform(
        torch.zeros(shape).to(self.logits.device),
        torch.ones(shape).to(self.logits.device),
    )
def aggregate_predictions(self, predictions, dim=0):
    probs = (
        dist_utils.logits_to_probs(predictions, is_binary=self.is_binary)
        if self.logit_predictions
        else predictions
    )
    avg_probs = probs.mean(dim)
    return (
        dist_utils.probs_to_logits(avg_probs, is_binary=self.is_binary)
        if self.logit_predictions
        else avg_probs
    )
def KL_div(target_state, nn_probs):
    """Compute the KL divergence between the target and reconstructed distributions.

    target_state: torch.Tensor
        the state to be reconstructed
    nn_probs: torch.Tensor
        probabilities of each basis state, as predicted by the NN

    returns: float
        the KL divergence of the two distributions
    """
    targ_probs = torch.pow(torch.abs(target_state), 2)
    div = torch.sum(targ_probs * probs_to_logits(targ_probs)) - torch.sum(
        targ_probs * probs_to_logits(nn_probs)
    )
    return div.item()
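# Sanity check of the expression used in KL_div: written with probs_to_logits it matches the
# textbook form sum_x p(x) log(p(x) / q(x)). The amplitudes and model probabilities below are
# made up purely for illustration.
import torch
from torch.distributions.utils import probs_to_logits

target_state = torch.sqrt(torch.tensor([0.1, 0.2, 0.3, 0.4]))  # toy wavefunction amplitudes
nn_probs = torch.full((4,), 0.25)                               # toy model distribution

p = torch.abs(target_state) ** 2
kl = torch.sum(p * probs_to_logits(p)) - torch.sum(p * probs_to_logits(nn_probs))

assert torch.allclose(kl, torch.sum(p * torch.log(p / nn_probs)))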
def density_matrix_KL(nn_state, target, bases, v_space, a_space):
    """Computes the KL divergence between the current and target density matrix.

    :param nn_state: The neural network state holding the current density matrix.
    :param target: The target density matrix
    :type target: torch.Tensor
    :param bases: The bases in which measurement is made
    :type bases: numpy.ndarray
    :param v_space: The space of the visible states
    :type v_space: torch.Tensor
    :param a_space: The space of the auxiliary states
    :type a_space: torch.Tensor
    :returns: The KL divergence
    :rtype: float
    """
    Z = nn_state.rbm_am.partition(v_space, a_space)
    unitary_dict = nn_state.unitary_dict

    rho_r_diag = torch.zeros(2, 2 ** nn_state.num_visible, dtype=torch.double)
    target_rho_r_diag = torch.zeros_like(rho_r_diag)

    KL = 0.0
    for basis in bases:
        rho_r = nn_state.rotate_rho(basis, v_space, Z, unitary_dict)
        target_rho_r = nn_state.rotate_rho(basis, v_space, Z, unitary_dict, rho=target)

        rho_r_diag[0] = torch.diagonal(rho_r[0])
        rho_r_diag[1] = torch.diagonal(rho_r[1])
        target_rho_r_diag[0] = torch.diagonal(target_rho_r[0])
        target_rho_r_diag[1] = torch.diagonal(target_rho_r[1])

        # the measurement probabilities are the (real) diagonal elements of the rotated matrices
        KL += torch.sum(target_rho_r_diag[0] * probs_to_logits(target_rho_r_diag[0]))
        KL -= torch.sum(target_rho_r_diag[0] * probs_to_logits(rho_r_diag[0]))

    KL /= float(len(bases))

    return KL.item()
def forward(self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs):
    assert not self.decoder.src_embedding_copy, "do not support embedding copy."

    # encoding
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)

    # length prediction
    length_out = self.decoder.forward_length(normalize=False, encoder_out=encoder_out)
    length_tgt = self.decoder.forward_length_prediction(
        length_out, encoder_out, tgt_tokens
    )

    # posterior & prior
    prior_tokens = self.initialize_prior_input(prev_output_tokens)
    prior_out = self.prior(encoder_out, prior_tokens)
    posterior_out = self.posterior(encoder_out, prev_output_tokens)

    # decoding
    word_ins_out = self.decoder(
        normalize=False,
        prev_output_tokens=prev_output_tokens,  # (B, T)
        prev_output_embeds=posterior_out.out,  # must match the positional embeddings, (B, Ty, dim)
        encoder_out=encoder_out,
    )
    word_ins_mask = tgt_tokens.ne(self.pad)

    return {
        "word_ins": {
            "out": word_ins_out,
            "tgt": tgt_tokens,
            "mask": word_ins_mask,
            "ls": self.args.label_smoothing,
            "nll_loss": True,
        },
        "length": {
            "out": length_out,
            "tgt": length_tgt,
            "factor": self.decoder.length_loss_factor,
        },
        # nat_loss computes the KL divergence from these logits and targets.
        "kl_div": {
            "out": probs_to_logits(prior_out.attn),
            "tgt": posterior_out.attn.detach(),  # don't let gradients flow through the target
            "mask": tgt_tokens.ne(self.pad),
            "factor": self.kl_div_loss_factor,
        },
    }
def log_bernoulli(probs, target):
    """
    Args:
        probs: [B, X]
        target: [B, X]

    Returns:
        output: [B]
    """
    logit = probs_to_logits(probs, is_binary=True)
    loss = (
        -F.relu(logit)
        + torch.mul(target, logit)
        - torch.log(1.0 + torch.exp(-logit.abs()))
    )
    loss = torch.sum(loss, 1)
    return loss
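# Elementwise, the expression above is the stable Bernoulli log-likelihood; it is exactly the
# negative of F.binary_cross_entropy_with_logits. A standalone check on toy tensors:
import torch
import torch.nn.functional as F
from torch.distributions.utils import probs_to_logits

probs = torch.tensor([[0.2, 0.7, 0.9], [0.4, 0.5, 0.1]])
target = torch.tensor([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0]])

logit = probs_to_logits(probs, is_binary=True)
log_lik = (-F.relu(logit) + target * logit - torch.log(1.0 + torch.exp(-logit.abs()))).sum(1)
bce = F.binary_cross_entropy_with_logits(logit, target, reduction="none").sum(1)

assert torch.allclose(log_lik, -bce)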
def gumbel_softmax(probs, temperature, hard=False):
    """
    input: [*, n_class]
    return: [*, n_class], a one-hot vector
    """
    logits = probs_to_logits(probs)
    y = gumbel_softmax_sample(logits, temperature)
    shape = y.size()
    if hard:
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
        # straight-through estimator: hard one-hot forward pass, soft gradients backward
        return (y_hard - y).detach() + y
    else:
        return y
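# gumbel_softmax above relies on a gumbel_softmax_sample helper defined elsewhere in that
# codebase. A typical minimal implementation (a sketch, not the original) perturbs the logits
# with Gumbel(0, 1) noise and applies a tempered softmax:
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, temperature, eps=1e-20):
    u = torch.rand_like(logits)                            # U ~ Uniform(0, 1)
    gumbel_noise = -torch.log(-torch.log(u + eps) + eps)   # Gumbel(0, 1) samples
    return F.softmax((logits + gumbel_noise) / temperature, dim=-1)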
def select_action(self, state, rand_flag=False, eps_flag=False, eps_value=1.0, train_flag=False):
    def get_reverse_prob(probs):
        # assume probs.size() is Size([1, action_size])
        rev_idxs = torch.arange(
            probs.size(-1) - 1, -1, -1, device=probs.device
        ).long()
        with torch.no_grad():
            rev_probs = torch.index_select(probs, -1, rev_idxs)
        return rev_probs

    if train_flag:
        self.policy_net.train()
        action_probs = self.policy_net(state)  # Size([1, action_size])
    else:
        self.policy_net.eval()
        with torch.no_grad():
            action_probs = self.policy_net(state)  # Size([1, action_size])
    action_logps = probs_to_logits(action_probs)  # Size([1, action_size])

    if rand_flag:
        if eps_flag and random.random() < eps_value:
            # epsilon-random policy: sample from the reversed probabilities
            action_rev_probs = get_reverse_prob(action_probs)
            m = dist.Categorical(probs=action_rev_probs)
        else:
            m = dist.Categorical(probs=action_probs)
        action = m.sample()  # Size([1])
    else:
        action = torch.argmax(action_probs, dim=-1)  # Size([1])

    assert action.requires_grad is False
    action_logp = action_logps.gather(-1, action.unsqueeze(0)).squeeze(-1)  # Size([1])
    action = action.item()

    self.episode_actions.append(action)
    self.episode_action_logps.append(action_logp)
    self.episode_action_probs.append(action_probs)

    return action
def _single_basis_KL(target_probs, nn_probs):
    return torch.sum(target_probs * probs_to_logits(target_probs)) - torch.sum(
        target_probs * probs_to_logits(nn_probs)
    )
def _logitsfn(self):
    return lambda conds: tcdu.probs_to_logits(self._probsfn(conds))
def gate_logits(self):
    return probs_to_logits(self.gate)
def logits(self):
    return probs_to_logits(self.probs, is_binary=True)
def zi_logits(self) -> torch.Tensor:
    return probs_to_logits(self.zi_probs, is_binary=True)
def _logitsfn(self):
    return lambda conds: tcdu.probs_to_logits(self._probsfn(conds), is_binary=True)
def logits(self):
    return probs_to_logits(self.weights)
def KL(nn_state, target_psi, space, bases=None, **kwargs):
    r"""A function for calculating the total KL divergence.

    .. math:: KL(P_{target} \vert P_{RBM}) = \sum_{x \in \mathcal{H}}
                P_{target}(x)\log\left(\frac{P_{target}(x)}{P_{RBM}(x)}\right)

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: qucumber.nn_states.WaveFunctionBase
    :param target_psi: The true wavefunction of the system. Can be a dictionary
                       with each value being the wavefunction represented in a
                       different basis, and the key identifying the basis.
    :type target_psi: torch.Tensor or dict(str, torch.Tensor)
    :param space: The basis elements of the Hilbert space of the system
                  :math:`\mathcal{H}`. The ordering of the basis elements must
                  match with the ordering of the coefficients given in
                  `target_psi`.
    :type space: torch.Tensor
    :param bases: An array of unique bases. If given, the KL divergence will be
                  computed for each basis and the average will be returned.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    psi_r = torch.zeros(
        2, 1 << nn_state.num_visible, dtype=torch.double, device=nn_state.device
    )
    KL = 0.0

    if isinstance(target_psi, dict):
        target_psi = {k: v.to(nn_state.device) for k, v in target_psi.items()}
        if bases is None:
            bases = list(target_psi.keys())
        else:
            assert set(bases) == set(
                target_psi.keys()
            ), "Given bases must match the keys of the target_psi dictionary."
    else:
        target_psi = target_psi.to(nn_state.device)

    Z = nn_state.compute_normalization(space)
    if bases is None:
        target_probs = cplx.absolute_value(target_psi) ** 2
        nn_probs = nn_state.probability(space, Z)
        KL += torch.sum(target_probs * probs_to_logits(target_probs))
        KL -= torch.sum(target_probs * probs_to_logits(nn_probs))
    else:
        unitary_dict = nn_state.unitary_dict
        for basis in bases:
            psi_r = rotate_psi(nn_state, basis, space, unitary_dict)
            if isinstance(target_psi, dict):
                target_psi_r = target_psi[basis]
            else:
                target_psi_r = rotate_psi(nn_state, basis, space, unitary_dict, target_psi)

            probs_r = (cplx.absolute_value(psi_r) ** 2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r) ** 2

            KL += torch.sum(target_probs_r * probs_to_logits(target_probs_r))
            KL -= torch.sum(target_probs_r * probs_to_logits(probs_r))

        KL /= float(len(bases))

    return KL.item()
def forward(self, outputs, context=None):
    # stable implementation of the inverse sigmoid (logit)
    inputs = probs_to_logits(outputs, is_binary=True)
    log_p, log_q = -F.softplus(-inputs), -F.softplus(inputs)
    return inputs, -log_p - log_q
def compute_log_pdf_bernoulli(x, p):
    logits = probs_to_logits(p, is_binary=True)
    logits, x = broadcast_all(logits, x)
    return -F.binary_cross_entropy_with_logits(logits, x, reduction='sum')
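# compute_log_pdf_bernoulli agrees with summing Bernoulli log-probabilities directly, since
# torch.distributions.Bernoulli uses the same logits/BCE formulation internally. A toy check:
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli
from torch.distributions.utils import broadcast_all, probs_to_logits

x = torch.tensor([[0.0, 1.0, 1.0], [1.0, 0.0, 1.0]])
p = torch.tensor([[0.3, 0.8, 0.6], [0.5, 0.1, 0.9]])

logits, x_b = broadcast_all(probs_to_logits(p, is_binary=True), x)
log_pdf = -F.binary_cross_entropy_with_logits(logits, x_b, reduction="sum")

assert torch.allclose(log_pdf, Bernoulli(probs=p).log_prob(x).sum())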
def __init__(self, p=0.5):
    super(CDropout, self).__init__()
    self.p_logit = nn.Parameter(
        probs_to_logits(torch.as_tensor(p), is_binary=True)
    )
def logits(self):
    return probs_to_logits(self.probs)
def __init__(self, output_dim, grid_z, grid_c, grid_cz,
             mapping_z=None, mapping_c=None, mapping_cz=None,
             has_feature_level_sparsity=True, penalty_type="fixed", lambda0=1.0,
             likelihood="Gaussian", p1=0.2, p2=0.2, p3=0.2, device="cpu"):
    """
    NN mapping with constraints to be used as the decoder in a CVAE.
    Performs Neural Decomposition.

    :param output_dim: data dimensionality
    :param grid_z: grid for quadrature (estimation of integral for f(z))
    :param grid_c: grid for quadrature (estimation of integral for f(c))
    :param grid_cz: grid for quadrature (estimation of integral for f(z, c))
    :param mapping_z: neural net mapping z to data
    :param mapping_c: neural net mapping c to data
    :param mapping_cz: neural net mapping (z, c) to data
    :param has_feature_level_sparsity: whether to use (Relaxed) Bernoulli feature-level sparsity
    :param penalty_type: which penalty to apply
    :param lambda0: initialisation for both the fixed penalty $c$ and the $\lambda$ values
    :param likelihood: Gaussian or Bernoulli
    :param p1: Bernoulli prior for sparsity on mapping_z
    :param p2: Bernoulli prior for sparsity on mapping_c
    :param p3: Bernoulli prior for sparsity on mapping_cz
    :param device: cpu or cuda
    """
    super().__init__()

    self.output_dim = output_dim
    self.likelihood = likelihood
    self.has_feature_level_sparsity = has_feature_level_sparsity
    self.penalty_type = penalty_type

    self.grid_z = grid_z.to(device)
    self.grid_c = grid_c.to(device)
    self.grid_cz = grid_cz.to(device)

    self.n_grid_z = grid_z.shape[0]
    self.n_grid_c = grid_c.shape[0]
    self.n_grid_cz = grid_cz.shape[0]

    # input -> output
    self.mapping_z = mapping_z
    self.mapping_c = mapping_c
    self.mapping_cz = mapping_cz

    if self.likelihood == "Gaussian":
        # feature-specific variances (for Gaussian likelihood)
        self.noise_sd = torch.nn.Parameter(-1.0 * torch.ones(1, output_dim))

    self.intercept = torch.nn.Parameter(torch.zeros(1, output_dim))

    self.Lambda_z = Variable(lambda0 * torch.ones(1, output_dim, device=device), requires_grad=True)
    self.Lambda_c = Variable(lambda0 * torch.ones(1, output_dim, device=device), requires_grad=True)
    self.Lambda_cz_1 = Variable(lambda0 * torch.ones(self.n_grid_z, output_dim, device=device), requires_grad=True)
    self.Lambda_cz_2 = Variable(lambda0 * torch.ones(self.n_grid_c, output_dim, device=device), requires_grad=True)

    self.lambda0 = lambda0
    self.device = device

    # RelaxedBernoulli temperature
    self.temperature = 1.0 * torch.ones(1, device=device)

    if self.has_feature_level_sparsity:
        # for the prior RelaxedBernoulli(logits)
        self.logits_z = probs_to_logits(p1 * torch.ones(1, output_dim).to(device), is_binary=True)
        self.logits_c = probs_to_logits(p2 * torch.ones(1, output_dim).to(device), is_binary=True)
        self.logits_cz = probs_to_logits(p3 * torch.ones(1, output_dim).to(device), is_binary=True)

        # for the approx posterior
        self.qlogits_z = torch.nn.Parameter(3.0 * torch.ones(1, output_dim).to(device))
        self.qlogits_c = torch.nn.Parameter(3.0 * torch.ones(1, output_dim).to(device))
        self.qlogits_cz = torch.nn.Parameter(2.0 * torch.ones(1, output_dim).to(device))
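# The sparsity prior implied by logits_z above can be instantiated with
# torch.distributions.RelaxedBernoulli. A small standalone sketch (output_dim, p1 and the
# temperature below are illustrative values, not taken from the class above):
import torch
from torch.distributions import RelaxedBernoulli
from torch.distributions.utils import probs_to_logits

output_dim, p1 = 5, 0.2
temperature = torch.ones(1)

logits_z = probs_to_logits(p1 * torch.ones(1, output_dim), is_binary=True)
prior_z = RelaxedBernoulli(temperature, logits=logits_z)

gates = prior_z.rsample()  # soft "inclusion" gates in (0, 1), one per output feature
assert torch.allclose(prior_z.probs, torch.full((1, output_dim), p1))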
def forward(self, img_enc, alpha, tau, t, gen_pres_probs=None, gen_depth_mean=None,
            gen_depth_std=None, gen_where_mean=None, gen_where_std=None):
    """
    :param img_enc: (bs, dim, img_h, img_w)
    :param tau:
    :return:
    """
    bs = img_enc.size(0)

    if self.args.phase_generate and t >= self.args.observe_frames:
        gen_pres_logits = probs_to_logits(gen_pres_probs, is_binary=True) \
            .view(1, 1, 1, 1).expand(bs, -1, self.args.num_cell_h, self.args.num_cell_w)
        z_pres_logits = gen_pres_logits
        z_depth_mean = gen_depth_mean.view(1, -1, 1, 1).expand(bs, -1, self.args.num_cell_h, self.args.num_cell_w)
        z_depth_std = gen_depth_std.view(1, -1, 1, 1).expand(bs, -1, self.args.num_cell_h, self.args.num_cell_w)
        z_where_mean = gen_where_mean.view(1, -1, 1, 1).expand(bs, -1, self.args.num_cell_h, self.args.num_cell_w)
        z_where_std = gen_where_std.view(1, -1, 1, 1).expand(bs, -1, self.args.num_cell_h, self.args.num_cell_w)
    else:
        mask_enc = self.mask_enc_net(alpha)
        x_alpha_enc = torch.cat((img_enc, mask_enc), dim=1)
        cat_enc = self.img_mask_cat_enc(x_alpha_enc)
        # (bs, 1, 8, 8)
        z_pres_logits = pres_logit_factor * torch.tanh(
            self.z_pres_net(cat_enc) + self.z_pres_bias
        )
        # (bs, dim, 8, 8)
        z_depth_mean, z_depth_std = self.z_depth_net(cat_enc).chunk(2, 1)
        z_depth_std = F.softplus(z_depth_std)
        # (bs, 4 + 4, 8, 8)
        z_where_mean, z_where_std = self.z_where_net(cat_enc).chunk(2, 1)
        z_where_std = F.softplus(z_where_std)

    q_z_pres = NumericalRelaxedBernoulli(logits=z_pres_logits, temperature=tau)
    z_pres_y = q_z_pres.rsample()
    z_pres = torch.sigmoid(z_pres_y)

    q_z_depth = Normal(z_depth_mean, z_depth_std)
    z_depth = q_z_depth.rsample()

    q_z_where = Normal(z_where_mean, z_where_std)
    z_where = q_z_where.rsample()

    # (bs, dim, 8, 8)
    z_where_origin = z_where.clone()

    scale, ratio = z_where[:, :2].tanh().chunk(2, 1)
    scale = self.args.size_anc + self.args.var_s * scale
    ratio = self.args.ratio_anc + self.args.var_anc * ratio
    ratio_sqrt = ratio.sqrt()
    z_where[:, 0:1] = scale / ratio_sqrt
    z_where[:, 1:2] = scale * ratio_sqrt
    z_where[:, 2:] = 2. / self.args.num_cell_h * (self.offset + 0.5 + z_where[:, 2:].tanh()) - 1

    z_where = z_where.permute(0, 2, 3, 1).reshape(-1, 4)

    return z_where, z_pres, z_depth, z_where_mean, z_where_std, \
        z_depth_mean, z_depth_std, z_pres_logits, z_pres_y, z_where_origin
def batch_interact_with(self, env, sample_cnt, fix_example=None, train_flag=False,
                        rand_flag=False, eps_flag=False, eps_value=0.1):
    # get next example batch
    env.next_example(example=fix_example)

    # Size([seq_len, batch_size, action_size])
    seq_action_probs = self.batch_calcu_action_prob(
        env.input_emb_seq, env.input_seq_lens, train_flag=train_flag
    )
    seq_action_logps = probs_to_logits(seq_action_probs)

    batch_seq_lens = env.input_seq_lens.tolist()
    batch_size = len(batch_seq_lens)
    batch_mean_rewards = torch.zeros_like(env.input_seq_lens).float()  # Size([batch_size])

    mask_list = []
    reward_list = []
    for sample_idx in range(sample_cnt):
        # reset episode variables of env and agent
        env.reset_state()
        self.reset_episode_info()

        # Size([seq_len, batch_size, action_size]), dtype=torch.float
        seq_action_masks = self.batch_select_action(
            seq_action_probs, batch_seq_lens,
            rand_flag=rand_flag, eps_flag=eps_flag, eps_value=eps_value
        )
        batch_rewards = env.batch_transition_with(seq_action_masks)  # Size([batch_size])
        batch_mean_rewards += batch_rewards  # Size([batch_size])

        # TODO: whether to add reward decay strategy
        seq_action_rewards = batch_rewards.unsqueeze(0).unsqueeze(-1).expand_as(seq_action_masks)
        # for torch 1.1.0
        seq_action_rewards = seq_action_rewards.contiguous()

        assert seq_action_masks.requires_grad is False
        assert seq_action_rewards.requires_grad is False

        mask_list.append(seq_action_masks)
        reward_list.append(seq_action_rewards)

    # reduce reward variance by subtracting a mean-reward baseline
    if sample_cnt > 1 or batch_size > 1:
        batch_mean_rewards /= sample_cnt  # Size([batch_size])
        base_mean_rewards = batch_mean_rewards.unsqueeze(0).unsqueeze(-1).expand_as(seq_action_probs)
        for sample_idx in range(sample_cnt):
            reward_list[sample_idx] -= base_mean_rewards

    return seq_action_probs, seq_action_logps, mask_list, reward_list, batch_mean_rewards
def __init__(self, output_dim, n_covariates, grid_z, grid_c, mapping_z, mappings_c, mappings_cz,
             has_feature_level_sparsity=True, penalty_type="fixed", lambda0=1.0,
             likelihood="Gaussian", p1=0.2, p2=0.2, p3=0.2, device="cpu"):
    """
    Decoder for multiple covariates (i.e. multivariate c)
    """
    super().__init__()

    self.output_dim = output_dim
    self.likelihood = likelihood
    self.has_feature_level_sparsity = has_feature_level_sparsity
    self.penalty_type = penalty_type
    self.n_covariates = n_covariates

    assert isinstance(grid_c, list), "grid_c must be a list"
    assert len(grid_c) == n_covariates

    self.grid_z = grid_z
    self.grid_c = grid_c
    self.grid_cz = [
        torch.cat(expand_grid(self.grid_z, self.grid_c[j]), dim=1)
        for j in range(self.n_covariates)
    ]

    self.grid_z = grid_z.to(device)
    self.grid_c = [c.to(device) for c in grid_c]
    self.grid_cz = [cz.to(device) for cz in self.grid_cz]

    self.n_grid_z = grid_z.shape[0]
    self.n_grid_c = [c.shape[0] for c in grid_c]
    self.n_grid_cz = [cz.shape[0] for cz in self.grid_cz]

    # input -> output
    self.mapping_z = mapping_z
    self.mappings_c = mappings_c
    self.mappings_cz = mappings_cz

    if self.likelihood == "Gaussian":
        # feature-specific variances (for Gaussian likelihood)
        self.noise_sd = torch.nn.Parameter(-1.0 * torch.ones(1, output_dim))

    self.intercept = torch.nn.Parameter(torch.zeros(1, output_dim))

    self.Lambda_z = Variable(torch.ones(1, output_dim, device=device), requires_grad=True)
    self.Lambda_c = [
        Variable(torch.ones(n_covariates, 1, output_dim, device=device), requires_grad=True)
        for _ in range(self.n_covariates)
    ]
    self.Lambda_cz_1 = [
        Variable(torch.ones(self.n_grid_z, output_dim, device=device), requires_grad=True)
        for _ in range(self.n_covariates)
    ]
    self.Lambda_cz_2 = [
        Variable(torch.ones(self.n_grid_c[j], output_dim, device=device), requires_grad=True)
        for j in range(self.n_covariates)
    ]

    self.lambda0 = lambda0
    self.device = device

    # RelaxedBernoulli temperature
    self.temperature = 1.0 * torch.ones(1, device=device)

    if self.has_feature_level_sparsity:
        # for the prior RelaxedBernoulli(logits)
        self.logits_z = probs_to_logits(p1 * torch.ones(1, output_dim).to(device), is_binary=True)
        self.logits_c = probs_to_logits(p2 * torch.ones(n_covariates, output_dim).to(device), is_binary=True)
        self.logits_cz = probs_to_logits(p3 * torch.ones(n_covariates, output_dim).to(device), is_binary=True)

        # for the approx posterior
        self.qlogits_z = torch.nn.Parameter(3.0 * torch.ones(1, output_dim).to(device))
        self.qlogits_c = torch.nn.Parameter(3.0 * torch.ones(n_covariates, output_dim).to(device))
        self.qlogits_cz = torch.nn.Parameter(2.0 * torch.ones(n_covariates, output_dim).to(device))