def _select_action(self, logit, ended, is_prob=False, fix_action_ended=True):
    logit_cpu = logit.clone().cpu()
    if is_prob:
        probs = logit_cpu
    else:
        probs = F.softmax(logit_cpu, 1)

    if self.feedback == 'argmax':
        _, action = probs.max(1)  # student forcing - argmax
        action = action.detach()
    elif self.feedback == 'sample':
        # sampling an action from model
        m = D.Categorical(probs)
        action = m.sample()
    else:
        raise ValueError('Invalid feedback option: {}'.format(self.feedback))

    # set action to 0 if already ended
    if fix_action_ended:
        for i, _ended in enumerate(ended):
            if _ended:
                action[i] = 0
    return action
def sample(lnprobs, temperature=1.0):
    if temperature == 0.0:
        return lnprobs.argmax()
    prob = F.softmax(lnprobs / temperature, dim=0)
    cdf = dist.Categorical(prob)
    return cdf.sample()
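# Usage sketch for the temperature-controlled sampler above (illustrative only:
# the tensor values below are hypothetical, and `F` / `dist` are assumed to be
# torch.nn.functional and torch.distributions, matching the snippet's aliases).
# Dividing the log-probabilities by a temperature below 1.0 sharpens the
# Categorical distribution toward the argmax; above 1.0 it flattens it.
import torch
import torch.nn.functional as F
import torch.distributions as dist

lnprobs = F.log_softmax(torch.tensor([2.0, 1.0, 0.1]), dim=0)
greedy_idx = sample(lnprobs, temperature=0.0)   # deterministic argmax -> index 0
random_idx = sample(lnprobs, temperature=1.5)   # flatter distribution, more exploration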
def _sample_posterior(self, x, num_samples, context=None):
    log_weights = torch.log(self.module.soft_max(self.module.soft_weights))
    T = self.module.covars[None, :, :, :] + x[1][:, None, :, :]
    p_weights = log_weights + dist.MultivariateNormal(
        loc=self.module.means, covariance_matrix=T
    ).log_prob(x[0][:, None, :])
    p_weights -= torch.logsumexp(p_weights, axis=1)[:, None]

    L_t = torch.cholesky(T)
    T_inv = torch.cholesky_solve(
        torch.eye(self.d, device=self.device), L_t)
    diff = x[0][:, None, :] - self.module.means
    T_prod = torch.matmul(T_inv, diff[:, :, :, None])
    p_means = self.module.means + torch.matmul(
        self.module.covars, T_prod
    ).squeeze()
    p_covars = self.module.covars - torch.matmul(
        self.module.covars, torch.matmul(T_inv, self.module.covars)
    )

    idx = dist.Categorical(logits=p_weights).sample([num_samples])
    samples = dist.MultivariateNormal(
        loc=p_means, covariance_matrix=p_covars).sample([num_samples])
    return samples.transpose(0, 1)[
        torch.arange(len(x), device=self.device)[:, None, None, None],
        torch.arange(num_samples, device=self.device)[None, :, None, None],
        idx.T[:, :, None, None],
        torch.arange(self.d, device=self.device)[None, None, None, :]
    ].squeeze()
def compose_losses(outputs, log_selected_policies, total_advantages, targets,
                   batch, args):
    """Calculate loss values

    Returns:
        tuple: losses, statistic values, and the number of training data points
    """
    tmasks = batch['turn_mask']
    omasks = batch['observation_mask']
    losses = {}
    dcnt = tmasks.sum().item()

    turn_advantages = total_advantages.mul(tmasks).sum(2, keepdim=True)
    losses['p'] = (-log_selected_policies * turn_advantages).sum()
    if 'value' in outputs:
        losses['v'] = (
            (outputs['value'] - targets['value'])**2).mul(omasks).sum() / 2
    if 'return' in outputs:
        losses['r'] = F.smooth_l1_loss(outputs['return'], targets['return'],
                                       reduction='none').mul(omasks).sum()

    entropy = dist.Categorical(logits=outputs['policy']).entropy().mul(
        tmasks.sum(-1))
    losses['ent'] = entropy.sum()

    base_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0)
    entropy_loss = entropy.mul(
        1 - batch['progress'] * (1 - args['entropy_regularization_decay'])
    ).sum() * -args['entropy_regularization']
    losses['total'] = base_loss + entropy_loss

    return losses, dcnt
def test_words(net, chars, setence_len=50, iscuda=False):
    '''
    Given a network and the valid characters in the trained book, let the
    network generate a sentence. This is used when training the RNN model.
    '''
    # create hidden state
    ho = net.init_hidden()
    # create random word index
    x_in = torch.LongTensor([random.randint(0, len(chars) - 1)])
    # create output index
    output = [int(x_in)]
    if iscuda:
        ho = ho.cuda()
        x_in = x_in.cuda()
    # now we iterate through our sentence, passing x_in to get y_out,
    # and setting y_out as x_in for the next time step
    for i in range(setence_len):
        y_out, ho = net.forward(x_in, ho)
        dist = distributions.Categorical(probs=y_out.exp())
        # sample a character index from the output distribution
        sample = dist.sample()
        output.append(int(sample))
        x_in = sample
    # now we assemble our words into a string
    words = ''
    for item in output:
        words += chars[item]
    return words
def __init__(self, latent_dim=2, num_classes=10, distribution=None,
             categorical=None):
    """
    Initializes a new dataset where noise and a label are sampled from the
    given distributions. If no distribution is given, noise is sampled from a
    multivariate normal distribution with the given latent dimension and the
    label is sampled from a categorical distribution.

    Parameters
    ----------
    latent_dim: int
        The latent dimension of the normal distribution the noise is sampled from.
    num_classes: int
        Number of classes of the categorical distribution the label is sampled from.
    distribution: torch.distributions.Distribution
        The noise distribution to use. Overrides the setting of latent_dim if specified.
    categorical: torch.distributions.Distribution
        The distribution to sample labels from. Overrides the setting of num_classes if specified.
    """
    super().__init__(latent_dim=latent_dim, distribution=distribution)
    if categorical is None:
        self.categorical = D.Categorical(
            torch.Tensor([1.0 / num_classes] * num_classes))
    else:
        self.categorical = categorical
def _goal_likelihood(self, y: torch.Tensor, goal: torch.Tensor,
                     **hyperparams) -> torch.Tensor:
    """Returns the goal-likelihood of a plan `y`, given `goal`.

    Args:
      y: A plan under evaluation, with shape `[B, T, 2]`.
      goal: The goal locations, with shape `[B, K, 2]`.
      hyperparams: (keyword arguments) The goal-likelihood hyperparameters.

    Returns:
      The log-likelihood of the plan `y` under the `goal` distribution.
    """
    # Parses tensor dimensions.
    B, K, _ = goal.shape

    # Fetches goal-likelihood hyperparameters.
    epsilon = hyperparams.get("epsilon", 1.0)

    # TODO(filangel): implement other goal likelihoods from the DIM paper
    # Initializes the goal distribution.
    goal_distribution = D.MixtureSameFamily(
        mixture_distribution=D.Categorical(
            probs=torch.ones((B, K)).to(goal.device)),  # pylint: disable=no-member
        component_distribution=D.Independent(
            D.Normal(loc=goal, scale=torch.ones_like(goal) * epsilon),  # pylint: disable=no-member
            reinterpreted_batch_ndims=1,
        ))

    return torch.mean(goal_distribution.log_prob(y[:, -1, :]), dim=0)  # pylint: disable=no-member
def get_dist(self):
    n = len(self.mean)
    mix = D.Categorical(torch.ones(n, ))
    comp = D.Independent(D.Normal(self.mean, self.var * torch.ones(n, 2)), 1)
    return D.MixtureSameFamily(mix, comp)
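# Standalone sketch of the mixture construction used in get_dist above, with
# hypothetical component means and a shared isotropic variance; the mixture
# weights are uniform, exactly as in the snippet.
import torch
import torch.distributions as D

mean = torch.tensor([[0.0, 0.0], [3.0, 3.0]])              # two 2-D component means
var = 0.5
mix = D.Categorical(torch.ones(len(mean)))                  # uniform mixture weights
comp = D.Independent(D.Normal(mean, var * torch.ones(len(mean), 2)), 1)
gmm = D.MixtureSameFamily(mix, comp)
xs = gmm.sample((4,))                                       # [4, 2] draws from the mixture
logp = gmm.log_prob(xs)                                     # [4] mixture log-densities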
def predict(self, x, deterministic=True):
    out = self.actor(x)
    if deterministic:
        out = torch.max(out, dim=1)[1]
    else:
        out = distributions.Categorical(probs=out).sample()
    return out.cpu().numpy()
def act(self, s, epsilon):
    '''epsilon greedy action selection

    Arguments:
        s {np array} -- state
        epsilon {float} -- epsilon value

    Returns:
        action -- action index
    '''
    # get action logits
    action_logits = self.brain(s)
    # create a categorical distribution from logits
    categorical_distribution = distributions.Categorical(
        logits=action_logits)
    # sample actions according to the distribution
    actions = categorical_distribution.sample()
    # collect relevant log probabilities
    relevant_log_probs = categorical_distribution.log_prob(actions)
    return actions[0].item(), relevant_log_probs
def act_intrinsic(self, obs):
    assert self.intrinsic  # Only usable with random network distillation
    obs = torch.FloatTensor(obs)
    if self.action_type == "Discrete":
        logits, state_values, int_state_values = self.net(obs)
        state_values = state_values.squeeze()
        int_state_values = int_state_values.squeeze()
        dist = distributions.Categorical(F.softmax(logits, dim=-1))
        actions = dist.sample().squeeze()
        action_log_probs = dist.log_prob(actions).squeeze()
    elif self.action_type == "Box":
        logits, sd, state_values, int_state_values = self.net.forward_continuous(
            obs)
        state_values = state_values.squeeze()
        int_state_values = int_state_values.squeeze()
        dist = distributions.Normal(logits, torch.exp(sd))
        actions = dist.sample()
        action_log_probs = dist.log_prob(actions)
    dist_entropy = dist.entropy()
    return actions, state_values, int_state_values, action_log_probs
def test_mutual_info_penalty(self):
    real_loss_mean = 2.600133
    real_loss_sum = 5.200266
    real_losses = [0.7086121, 4.491654]

    mean = torch.Tensor([[1.3, 4.6, 7.1], [0.2, 11.4, 1.0]])
    std = torch.Tensor([[1.0, 0.5, 3.1], [0.2, 3.5, 4.9]])
    logits = torch.Tensor([[0.5, 0.5], [0.75, 0.25]])
    c_dis = torch.Tensor([[0, 1], [1, 0]])
    c_cont = torch.Tensor([[1.4, 4.0, 5.0], [-1.0, 7.0, 2.0]])
    q_cont = ds.Normal(loc=mean, scale=std)
    q_cat = ds.Categorical(logits=logits)
    mutualinfo = MutualInformationPenalty()

    loss_mean = mutualinfo(c_dis, c_cont, q_cat, q_cont)
    self.assertAlmostEqual(loss_mean.item(), real_loss_mean, 5)

    mutualinfo.reduction = "sum"
    loss_sum = mutualinfo(c_dis, c_cont, q_cat, q_cont)
    self.assertAlmostEqual(loss_sum.item(), real_loss_sum, 5)

    mutualinfo.reduction = "none"
    loss = mutualinfo(c_dis, c_cont, q_cat, q_cont)
    for i in range(2):
        self.assertAlmostEqual(loss[i].item(), real_losses[i], 5)
def select_action(self, obs):
    output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
    self.rnncs_ = self.actor.get_rnncs()
    value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
    if self.is_continuous:
        mu, log_std = output  # [B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        action = dist.sample().clamp(-1, 1)  # [B, A]
        log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
    else:
        logits = output  # [B, A]
        logp_all = logits.log_softmax(-1)  # [B, A]
        norm_dist = td.Categorical(logits=logp_all)
        action = norm_dist.sample()  # [B,]
        log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]
    acts_info = Data(action=action,
                     value=value,
                     log_prob=log_prob + th.finfo().eps)
    if self.use_rnn:
        acts_info.update(rnncs=self.rnncs)
    if self.is_continuous:
        acts_info.update(mu=mu, log_std=log_std)
    else:
        acts_info.update(logp_all=logp_all)
    return action, acts_info
def decoder(self, z, encoded_history, current_state, y_e=None, train=False):
    bs = encoded_history.shape[0]
    a_0 = F.dropout(self.action(current_state.reshape(bs, -1)),
                    self.dropout_p)
    state = F.dropout(self.state(encoded_history.reshape(bs, -1)),
                      self.dropout_p)
    current_state = current_state.unsqueeze(1)
    gauses = []
    inp = F.dropout(
        torch.cat((encoded_history.reshape(bs, -1), a_0), dim=-1),
        self.dropout_p)
    for i in range(12):
        h_state = self.gru(inp.reshape(bs, -1), state)
        _, deltas, log_sigmas, corrs = self.project_to_GMM_params(h_state)
        deltas = torch.clamp(deltas, max=1.5, min=-1.5)
        deltas = deltas.reshape(bs, -1, 2)
        log_sigmas = log_sigmas.reshape(bs, -1, 2)
        corrs = corrs.reshape(bs, -1, 1)
        mus = deltas + current_state
        current_state = mus
        variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2, max=1e3)
        m_diag = variance * torch.eye(2).to(variance.device)
        sigma_xy = torch.clamp(torch.prod(torch.exp(log_sigmas), dim=-1),
                               min=1e-8, max=1e3)
        if train:
            # log_pis = z.reshape(bs, 1) * torch.ones(bs, self.num_modes).cuda()
            log_pis = to_one_hot(z, n_dims=self.num_modes).cuda()
        else:
            log_pis = to_one_hot(z, n_dims=self.num_modes).cuda()
        log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
        mix = D.Categorical(logits=log_pis)
        comp = D.MultivariateNormal(mus, m_diag)
        gmm = D.MixtureSameFamily(mix, comp)
        t = (sigma_xy * corrs.squeeze()).reshape(-1, 1, 1)
        cov_matrix = m_diag  # + anti_diag
        gauses.append(gmm)
        a_t = gmm.sample()  # possible grad problems?
        a_tt = F.dropout(self.action(a_t.reshape(bs, -1)), self.dropout_p)
        state = h_state
        inp = F.dropout(
            torch.cat((encoded_history.reshape(bs, -1), a_tt), dim=-1),
            self.dropout_p)
    return gauses
def forward(self, output_sizes, hold_seed=None, hold_initial_set=False):
    """
    Sample from prior
    :param output_sizes: Tensor([B,])
    :param hold_seed: if given, seed the RNG so the seed noise is shared across the batch
    :param hold_initial_set: use a deterministic mask instead of sampling one
    :return: Tensor([B, N, D])
    """
    bsize = output_sizes.shape[0]
    if hold_initial_set:  # [B, N]
        x_mask = get_mask(output_sizes, self.max_outputs)
    else:
        x_mask = sample_mask(output_sizes, self.max_outputs)

    if hold_seed is not None:  # [B, N, Ds]
        torch.random.manual_seed(hold_seed)
        eps = torch.randn([1, self.max_outputs, self.dim_seed
                           ]).to(x_mask.device).repeat(bsize, 1, 1)
    else:
        eps = torch.randn([bsize, self.max_outputs,
                           self.dim_seed]).to(x_mask.device)

    if self.n_mixtures == 1:
        x = self.mu + torch.exp(self.logvar / 2.) * eps
    else:
        if self.train_gmm:
            if hold_seed is not None:
                torch.random.manual_seed(hold_seed)
                logits = self.logits.reshape([1, 1, self.n_mixtures]).repeat(
                    1, self.max_outputs, 1)  # [1, N, M]
                onehot = F.gumbel_softmax(
                    logits, tau=self.tau,
                    hard=True).repeat(bsize, 1, 1).unsqueeze(-1)  # [B, N, M, 1]
            else:
                logits = self.logits.reshape([1, 1, self.n_mixtures]).repeat(
                    bsize, self.max_outputs, 1)  # [B, N, M]
                onehot = F.gumbel_softmax(logits, tau=self.tau,
                                          hard=True).unsqueeze(-1)  # [B, N, M, 1]
            mu = self.mu.reshape([1, 1, self.n_mixtures, self.dim_seed])  # [1, 1, M, D]
            sig = self.sig.reshape([1, 1, self.n_mixtures, self.dim_seed])  # [1, 1, M, D]
            mu = (mu * onehot).sum(2)  # [B, N, D]
            sig = (sig * onehot).sum(2)  # [B, N, D]
            x = mu + sig * eps
        else:
            mix = D.Categorical(self.logits)
            comp = D.Independent(D.Normal(self.mu, self.sig.abs()), 1)
            mixture = D.MixtureSameFamily(mix, comp)
            x = mixture.sample((output_sizes.size(0), self.max_outputs))

    x = self.output(x)  # [B, N, D]
    return x, x_mask
def generate(self, bs):
    a = torch.zeros(
        (bs, self.Number_qubits)).type(torch.LongTensor).to(args.device)
    hidden = self.init_hidden.repeat(1, bs, 1)
    # BOS input at the beginning
    beginning = self.BOS.view(1, 1, -1)
    beginning = beginning.repeat(1, bs, 1)
    output, hidden = self.gru(beginning, hidden)
    output = self.logsoftmax(self.out(output[0]))
    sampled_op = dist.Categorical(output.squeeze(0).exp()).sample()
    a[:, 0] = sampled_op
    for i in range(0, self.Number_qubits - 1):
        output, hidden = self.forward(
            a[:, i], hidden)  # output: [1, bs, charset_length]
        sampled_op = dist.Categorical(output.squeeze(0).exp()).sample()
        a[:, i + 1] = sampled_op
    return a
def label(self, i: int, j: int) -> dist.Distribution:
    """
    Observed label distribution for each item (i) and label (j).
    """
    labeler = self.labelers[i, j].item()
    return dist.Categorical(
        self.confusion_matrix(labeler, self.true_label(i).item()))
def select_action(self, obs):
    q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, A]
    self.rnncs_ = self.q_net.get_rnncs()
    logits = ((q_values - self._get_v(q_values)) / self.alpha).exp()  # > 0  # [B, A]
    logits /= logits.sum(-1, keepdim=True)  # [B, A]
    cate_dist = td.Categorical(logits=logits)
    actions = cate_dist.sample()  # [B,]
    return actions, Data(action=actions)
def forward(self, x, z_prev):
    """
    x: shape=(BS,N)
    z_prev: shape=(BS,N)
    """
    logits = self.transform_x(x) + self.P[z_prev]  # shape=(BS,N,B)
    dist_z = dist.Categorical(logits=logits)
    return dist_z.sample()
def select_action(self, state, rand_flag=False, eps_flag=False,
                  eps_value=1.0, train_flag=False):

    def get_reverse_prob(probs):
        # assume probs.size() is Size([1, action_size])
        rev_idxs = torch.arange(probs.size(-1) - 1, -1, -1,
                                device=probs.device).long()
        with torch.no_grad():
            rev_probs = torch.index_select(probs, -1, rev_idxs)
        return rev_probs

    if train_flag:
        self.policy_net.train()
        action_probs = self.policy_net(state)  # size([1, action_size])
    else:
        self.policy_net.eval()
        with torch.no_grad():
            action_probs = self.policy_net(state)  # size([1, action_size])
    action_logps = probs_to_logits(action_probs)  # size([1, action_size])

    if rand_flag:
        if eps_flag and random.random() < eps_value:
            # epsilon-random policy: sample from the reversed probabilities
            action_rev_probs = get_reverse_prob(action_probs)
            m = dist.Categorical(probs=action_rev_probs)
        else:
            m = dist.Categorical(probs=action_probs)
        action = m.sample()  # size([1])
        # action_logp = m.log_prob(action)  # size([1])
    else:
        action = torch.argmax(action_probs, dim=-1)  # size([1])
    assert action.requires_grad is False

    action_logp = action_logps.gather(-1, action.unsqueeze(0)).squeeze(-1)  # size([1])
    action = action.item()
    self.episode_actions.append(action)
    self.episode_action_logps.append(action_logp)
    self.episode_action_probs.append(action_probs)
    return action
def select_action(env, model, side, hidden, config: Config):
    x = get_model_input(env, side).to(config.device).float()
    output, value, hidden = model.train().to(config.device)(x, hidden)
    distribution = dist.Categorical(F.softmax(output, dim=-1))
    action = distribution.sample()
    log_prob = distribution.log_prob(action)
    entropy = -(log_prob * output).sum(-1)
    return log_prob, action.item() + 1, value, hidden, entropy
def forward(self, x, **kwargs):
    p = self.p.expand(x.shape[0], self.p.shape[-1])
    if isinstance(self.action_space, spaces.Discrete):
        dist = distributions.Categorical(probs=F.softmax(p, dim=1))
    elif isinstance(self.action_space, spaces.Box):
        p = torch.chunk(p, 2, dim=1)
        dist = distributions.Normal(loc=p[0], scale=p[1])
    return dist, torch.ones_like(x)[:, :1]
def encode(self, x):
    feats = self.pointnet(x)
    log_prob_y = self.cat_encoder(feats)
    prob_y = torch.exp(log_prob_y)
    y_dis = distrib.Categorical(probs=prob_y).sample()
    y = one_hot(y_dis, self.clusters).to(x.device)
    z, _, _ = self.encode_z(y, feats)
    return y, z
def forward(self, observation: torch.FloatTensor) -> Any:
    if self.discrete:
        action_probs = self.logits_na(observation)
        return distributions.Categorical(action_probs)
    else:
        mean = self.mean_net(observation)
        dist = distributions.Normal(loc=mean, scale=torch.exp(self.logstd))
        return dist
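# Hedged usage sketch for the policy head above: `policy` is a hypothetical
# instance of the module defining this forward(), and the observation size is
# illustrative. The returned distribution supports both sampling and log_prob,
# which is what a policy-gradient loss needs.
import torch

obs = torch.randn(1, 4)                  # hypothetical observation batch
action_dist = policy(obs)                # Categorical or Normal, depending on self.discrete
action = action_dist.sample()
log_prob = action_dist.log_prob(action)  # e.g. used as -(log_prob * advantage) in REINFORCE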
def _sample(dist, sampling_mode='greedy'):
    if sampling_mode == 'greedy':
        _, sample = torch.topk(dist, 1, dim=-1)
    elif sampling_mode == 'random':
        p = F.softmax(dist, dim=-1)
        sample = dis.Categorical(p).sample()
    sample = sample.squeeze()
    return sample
def __init__(self, lib):
    assert isinstance(lib, Library)
    super(CharacterTypeDist, self).__init__(lib)
    # override part type dist
    self.pdist = StrokeTypeDist(lib)
    # distribution of 'k' (number of strokes)
    assert len(lib.pkappa.shape) == 1
    self.kappa = dist.Categorical(probs=lib.pkappa)
def select_action(self, state):
    action_prob, value_pred = self.policy(state)
    dist = distributions.Categorical(action_prob)
    action = dist.sample()
    log_prob_action = dist.log_prob(action)
    self.log_prob_actions.append(log_prob_action)
    self.values.append(value_pred)
    return action
def ppo_update(config, f_actor, diff_actor_opt, critic, critic_opt,
               memory_cache, update_type='meta'):
    # The actor is functional in the meta case and a normal module in the RL case.
    summed_policy_loss = torch.zeros(1)
    summed_value_loss = torch.zeros(1)

    (states, next_states, actions_init, rewards, dones,
     log_prob_actions_init) = get_shaped_memory_sample(config, memory_cache)

    # Use the critic to predict the last reward. This is just a placeholder in
    # case the trajectory is incomplete in batch mode.
    final_predicted_reward = 0.
    if dones[-1] == 0.:
        # The last step is not done, so its value has to be predicted.
        final_state = next_states[-1]
        with torch.no_grad():
            final_predicted_reward = critic(final_state).detach().item()

    returns = calculate_returns(config, rewards, dones,
                                predicted_end_reward=final_predicted_reward)  # Returns (samples, 1)

    # At this point everything should be a tensor and produce tensor outputs.
    values_init = critic(states)
    advantages = returns - values_init
    if config.normalize_rewards_and_advantages:
        advantages = (advantages - advantages.mean()) / advantages.std()
    # Necessary to keep the advantages from having a connection to the value model.
    advantages = advantages.detach()

    # Now the actor recalculates actions and log-probs from the current parameters for k epochs.
    for ppo_step in range(config.num_ppo_steps):
        action_prob = f_actor(states)
        values_pred = critic(states)

        if config.env_config.action_space_type == 'discrete':
            dist = distributions.Categorical(action_prob)
            # Flatten action indices to match the distribution's batch shape.
            actions_init = actions_init.squeeze(-1)
            new_log_prob_actions = dist.log_prob(actions_init)
            new_log_prob_actions = new_log_prob_actions.view(-1, 1)
        elif config.env_config.action_space_type == 'continuous':
            # Direct code from the actor's get_action, refer there.
            action_mean_vector = action_prob * f_actor.action_upper_limit
            dist = distributions.MultivariateNormal(action_mean_vector,
                                                    f_actor.covariance_matrix)
            actions_init = actions_init.view(-1, config.action_dim)
            new_log_prob_actions = dist.log_prob(actions_init)
            new_log_prob_actions = new_log_prob_actions.view(-1, 1)

        policy_ratio = (new_log_prob_actions - log_prob_actions_init).exp()
        policy_loss_1 = policy_ratio * advantages
        policy_loss_2 = torch.clamp(policy_ratio,
                                    min=1.0 - config.ppo_clip,
                                    max=1.0 + config.ppo_clip) * advantages

        if config.include_entropy_in_ppo:
            inner_policy_loss = (
                -torch.min(policy_loss_1, policy_loss_2) -
                config.entropy_coefficient * dist.entropy()).sum()
        else:
            inner_policy_loss = -torch.min(policy_loss_1, policy_loss_2).sum()

        if update_type == 'meta':
            diff_actor_opt.step(inner_policy_loss)
        else:
            # Normal RL: the update happens here rather than outside in the main function.
            diff_actor_opt.zero_grad()
            inner_policy_loss.backward()
            diff_actor_opt.step()

        inner_value_loss = F.smooth_l1_loss(values_pred, returns).sum()
        inner_value_loss.backward()
        critic_opt.step()

        summed_policy_loss += inner_policy_loss
        summed_value_loss += inner_value_loss

    return summed_policy_loss, summed_value_loss.item()
def generate_rollout(
    world: GridWorld,
    agent: nn.Module,
    grammar_goal: str,
    critic: nn.Module,
    task_idx: int,
    deterministic: bool = False,
) -> Trajectory:
    samples: Sequence[Sample] = []

    # Perform typical RL loop.
    agent.reset(grammar_goal, device=args.device)
    obs_raw: Observation = world.reset()
    while True:
        obs = encode_observation(obs_raw).to(args.device)  # size[D]

        primitive_idx = None
        if args.agent_type == "ppg":
            agent_state, action_probs = agent(obs.unsqueeze(0))  # size[1, *]
        elif args.agent_type == "sketch":
            agent_state, action_probs, primitive_idx = agent(obs.unsqueeze(0))
        state_value = critic(agent_state)[:, task_idx]

        action_probs = action_probs.squeeze(0)
        state_value = state_value.squeeze(0)

        action_dist = dist.Categorical(action_probs)
        action = action_dist.sample() if not deterministic else action_probs.argmax()
        log_prob = action_dist.log_prob(action)

        action_raw: Action = Action(action.item())
        obs_raw, reward_raw, done, info = world.step(action_raw)
        reward = torch.tensor(float(reward_raw)).to(args.device)

        # Must detach results computed by neural networks from the computational graph, as these
        # values are just used to compute gradients for the model. We don't actually want the
        # gradients to be propagating through them.
        samples.append(
            Sample(
                obs=obs,
                action=action,
                reward=reward,
                log_prob=log_prob.detach(),
                state_value=state_value.detach(),
                ret=None,
                advantage=None,
                primitive_idx=primitive_idx,
            ))

        if done:
            break

    samples = compute_returns(samples,
                              discount=args.discount_factor,
                              device=args.device)
    samples = compute_advantages(samples)
    return Trajectory(samples)
def step(self, t, state, prev_output, detections, seq, *args,
         mode='teacher_forcing'):
    assert (mode in ['teacher_forcing', 'feedback'])
    device = detections.device
    b_s = detections.size(0)
    bos_idx = self.bos_idx
    state_1, state_2 = state[:2], state[2:]
    detections_mask = (torch.sum(detections, -1, keepdim=True) != 0).float()
    detections_mean = torch.sum(detections, 1) / torch.sum(detections_mask, 1)

    if mode == 'teacher_forcing':
        if self.training and t > 0 and self.ss_prob > .0:
            # Scheduled sampling
            coin = detections.data.new(b_s).uniform_(0, 1)
            coin = (coin < self.ss_prob).long()
            distr = distributions.Categorical(logits=prev_output)
            action = distr.sample()
            it = coin * action.data + (1 - coin) * seq[:, t - 1].data
            it = it.to(device)
        else:
            it = seq[:, t]
    elif mode == 'feedback':  # test
        if t == 0:
            it = detections.data.new_full((b_s,), bos_idx).long()
        else:
            it = prev_output

    xt = self.embed(it)
    if self.with_relu:
        xt = F.relu(xt)
    input_1 = torch.cat([state_2[0], detections_mean, xt], 1)

    if self.with_visual_sentinel:
        g_t = torch.sigmoid(self.W_sx(input_1) + self.W_sh(state_1[0]))

    state_1 = self.lstm_cell_1(input_1, state_1)
    att_weights = torch.tanh(self.att_va(detections) +
                             self.att_ha(state_1[0]).unsqueeze(1))
    att_weights = self.att_a(att_weights)

    if self.with_visual_sentinel:
        s_t = g_t * torch.tanh(state_1[1])
        fc_sentinel = self.fc_sentinel(s_t).unsqueeze(1)
        if self.with_relu:
            fc_sentinel = F.relu(fc_sentinel)
        detections = torch.cat([fc_sentinel, detections], 1)
        detections_mask = (torch.sum(detections, -1, keepdim=True) != 0).float()
        sent_att_weights = torch.tanh(self.W_sas(s_t) +
                                      self.att_ha(state_1[0])).unsqueeze(1)
        sent_att_weights = self.W_sa(sent_att_weights)
        att_weights = torch.cat([sent_att_weights, att_weights], 1)

    att_weights = F.softmax(att_weights, 1)
    att_weights = detections_mask * att_weights
    att_weights = att_weights / torch.sum(att_weights, 1, keepdim=True)
    att_detections = torch.sum(detections * att_weights, 1)
    input_2 = torch.cat([state_1[0], att_detections], 1)
    state_2 = self.lstm_cell_2(input_2, state_2)

    out = F.log_softmax(self.out_fc(state_2[0]), dim=-1)
    return out, (state_1[0], state_1[1], state_2[0], state_2[1])