Example #1
0
        def forward(self, x):
            """Sample one Bernoulli decision per (step, batch) cell of ``x``.

            ``x`` is flattened to ``(seq * batch, -1)`` before ``self.net`` is
            applied; the network output is reshaped to ``(seq, batch)`` and used
            as Bernoulli probabilities.

            :return: ``(pred, logits, entropy, params)`` where ``logits`` holds
                the per-cell log-probabilities of ``pred`` (despite its name)
                and ``entropy`` is summed over the first (sequence) axis.
            """
            n_steps, n_batch = x.size(0), x.size(1)

            flat_params = self.net(x.view(n_batch * n_steps, -1))
            probs = flat_params.view(n_steps, n_batch)

            dist = Bernoulli(probs)
            pred = dist.sample()
            return pred, dist.log_prob(pred), dist.entropy().sum(0), probs
    def forward(self, seq, mask):
        """Encode ``seq``, score the predictor's binary actions and mask the results.

        :param seq: input passed to ``self.encoder`` together with ``mask``.
        :param mask: mask passed to ``self.encoder`` and to ``apply_mask``.
        :return: ``(actions, log_probas, entropy, dist_params)`` where
            ``log_probas`` and ``entropy`` are element-wise Bernoulli values
            passed through ``apply_mask(..., mask)``.
        """
        encoded = self.encoder(seq, mask)
        dist_params, actions = self.predictor(encoded)
        # Transpose both predictor outputs (``.t()`` needs tensors of <= 2 dims).
        dist_params, actions = dist_params.t(), actions.t()
        sampler = Bernoulli(dist_params)

        # Log-probability of the predictor's actions under its own parameters.
        log_probas = apply_mask(sampler.log_prob(actions), mask)

        # Mask the entropy tensor itself; masking ``log_probas`` here would
        # silently return a second copy of the log-probabilities as "entropy".
        entropy = apply_mask(sampler.entropy(), mask)

        return actions, log_probas, entropy, dist_params
Example #3
0
    def loss(self, X, y, n_samples=100):
        """Monte-Carlo negative ELBO of a probit model with Gaussian weights.

        Weights are drawn with the reparameterization trick from the
        variational posterior ``N(self.mean, exp(self.log_sigma))``; each draw
        gives success probabilities ``Phi(X @ w / (sqrt(2) * self.noise))``.

        :param X: design matrix multiplied with the sampled weights.
        :param y: binary targets scored under the Bernoulli likelihood.
        :param n_samples: number of weight samples.
        :return: mean negative log-likelihood plus the summed KL divergence
            from the posterior to the prior
            ``N(self.prior_mean, exp(self.prior_log_sigma))``.
        """
        prior = Normal(self.prior_mean, torch.exp(self.prior_log_sigma))
        posterior = Normal(self.mean, torch.exp(self.log_sigma))

        # rsample keeps the draws differentiable w.r.t. the posterior params.
        sampled_weights = posterior.rsample((n_samples, ))
        scaled = X @ sampled_weights / (math.sqrt(2) * self.noise)
        theta = Normal(0, 1).cdf(scaled)

        nll = -Bernoulli(theta).log_prob(y).mean()
        kld = kl_divergence(posterior, prior).sum()
        return nll + kld
Example #4
0
class BernoulliDistribution(Distribution):
    """
    Bernoulli distribution for MultiBinary action spaces.

    Each of the ``action_dims`` binary actions is an independent Bernoulli
    parameterized by a logit. Log-probabilities and entropies are summed over
    dimension 1 (the action dimension) to give one value per sample.

    :param action_dims: (int) Number of binary actions
    """

    def __init__(self, action_dims: int):
        super(BernoulliDistribution, self).__init__()
        # Set by ``proba_distribution``; None until then.
        self.distribution = None
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the logits of the Bernoulli distribution.

        :param latent_dim: (int) Dimension of the last layer
            of the policy network (before the action layer)
        :return: (nn.Linear) mapping ``latent_dim`` -> ``action_dims`` logits
        """
        action_logits = nn.Linear(latent_dim, self.action_dims)
        return action_logits

    def proba_distribution(self, action_logits: th.Tensor) -> 'BernoulliDistribution':
        """Build the torch Bernoulli from ``action_logits``; returns ``self`` for chaining."""
        self.distribution = Bernoulli(logits=action_logits)
        return self

    def mode(self) -> th.Tensor:
        """Most likely action: each bit's probability rounded to 0 or 1."""
        # th.round rounds half to even, so a probability of exactly 0.5 gives 0.
        return th.round(self.distribution.probs)

    def sample(self) -> th.Tensor:
        """Draw a random binary action from the current distribution."""
        return self.distribution.sample()

    def entropy(self) -> th.Tensor:
        """Entropy of the independent bits, summed over the action dimension."""
        return self.distribution.entropy().sum(dim=1)

    def actions_from_params(self, action_logits: th.Tensor,
                            deterministic: bool = False) -> th.Tensor:
        """Set the distribution from ``action_logits`` and return actions.

        :param deterministic: forwarded to ``get_actions``.
        """
        # Update the proba distribution
        self.proba_distribution(action_logits)
        # ``get_actions`` is not defined in this class; presumably inherited
        # from ``Distribution`` -- confirm.
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """Sample actions from ``action_logits`` and return them with their log-probability."""
        actions = self.actions_from_params(action_logits)
        log_prob = self.log_prob(actions)
        return actions, log_prob

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """Joint log-probability of ``actions``, summed over the action dimension."""
        return self.distribution.log_prob(actions).sum(dim=1)
Example #5
0
 def log_likelihood_vec(self, y, x, z):
     '''
     Log-likelihood p(y_t | y_1:t-1, x_1:t, z_1:t), summed over observed entries.

     The Bernoulli logits are the inner product of ``x`` and ``z`` along the
     last axis, with ``z`` broadcast over axis 1 of ``x``. Only the entries
     selected by ``self.return_train_ind(y)`` are scored; per the original
     description, unobserved entries of ``y`` are marked with -1.
     '''
     logits = (x * z[:, None, :]).sum(dim=2)
     observed = self.return_train_ind(y)
     # Restrict both logits and targets to the observed entries.
     bern = Bernoulli(logits=logits[observed])
     return bern.log_prob(y[observed]).sum()
Example #6
0
 def act_parallel(self, batch_states, batch_z, values=None):
     """Sample one binary action per state for a batch of conditioned policies.

     :param batch_states: numpy array of integer state indices, one per row
         of ``batch_z``.
     :param batch_z: per-row offsets added to ``self.theta`` before the sigmoid.
     :param values: optional tensor indexed by state; when given, the values
         of the batch states are returned as a third element.
     :return: ``(actions, log_probs)`` or ``(actions, log_probs, values)``,
         with ``actions`` as a numpy int array.
     """
     assert batch_states.shape[0] == batch_z.shape[0]
     state_idx = torch.from_numpy(batch_states).long()
     conditioned = self.theta + batch_z.to(self.device)
     per_state = torch.sigmoid(conditioned).gather(1, state_idx.view(-1, 1))
     probs = per_state.squeeze(1)
     # The Bernoulli success probability is the complement of the sigmoid.
     dist = Bernoulli(1 - probs)
     actions = dist.sample()
     log_probs_actions = dist.log_prob(actions)
     int_actions = actions.numpy().astype(int)
     if values is None:
         return int_actions, log_probs_actions
     return int_actions, log_probs_actions, values[state_idx]
    def backward(self, states, actions, rewards, dummy_index):
        """REINFORCE update over a batch of ``(state, action, reward)`` steps.

        Gradients of ``-log pi(action | state) * reward`` are accumulated over
        all steps and applied with a single optimizer step. Stepping inside the
        loop (without re-zeroing) would re-apply earlier steps' gradients each
        time and evaluate later steps with already-updated parameters.

        :param states: per-step inputs to ``self.forward``.
        :param actions: per-step action tensors (0/1 values).
        :param rewards: per-step scalar rewards.
        :param dummy_index: unused; kept for interface compatibility.
        """
        self.optimizer.zero_grad()
        for i, state in enumerate(states):
            probs = self.forward(state)
            m = Bernoulli(probs)
            # Negative score function x reward
            loss = -m.log_prob(actions[i]) * rewards[i]
            loss.backward()
        self.optimizer.step()
Example #8
0
 def compute_theta_star(self):
     """Fit MAP weights of a Bernoulli-logit model under a Gaussian prior.

     Minimizes ``-sum log p(yt | logits) - sum log N(weights; 0, 1/sqrt(delta))``
     with Adam for ``self.n_epochs`` full-batch steps, decaying the learning
     rate on plateaus. Every loss value is appended to ``self.losses`` and the
     final weights are stored in ``self.theta_star`` as a numpy array.
     """
     optimizer = Adam(self.model.parameters(), lr=self.alpha, weight_decay=0)
     plateau = ReduceLROnPlateau(optimizer, 'min', factor=0.99, min_lr=1e-10)
     for _ in range(self.n_epochs):
         optimizer.zero_grad()
         logits = self.model.forward(self.Xt).flatten()
         likelihood = Bernoulli(logits=logits)
         weight_prior = Normal(0, 1 / np.sqrt(self.delta))
         objective = - torch.sum(likelihood.log_prob(self.yt)) - torch.sum(weight_prior.log_prob(self.model.weights))
         objective.backward()
         optimizer.step()
         current = objective.item()
         plateau.step(current)
         self.losses.append(current)
     self.theta_star = self.model.weights.detach().numpy()
Example #9
0
    def loss(self, x):
        """Single-sample negative ELBO of ``x`` in bits per dimension.

        Uses a Bernoulli decoder over 784 values per example, a standard
        normal prior on ``z`` and the encoder's Gaussian posterior
        ``N(mu, std)``. The nat-valued bound is converted to bits
        (``* log2(e)``) and divided by 784.
        """
        mu, std = self.encode(x)
        z = self.reparameterize(mu, std)
        recon_x = self.decode(z)

        log_px_z = torch.sum(Bernoulli(recon_x).log_prob(x.view(-1, 784)), dim=1)
        prior = Normal(torch.tensor([0.0]).to(device), torch.tensor([1.0]).to(device))
        log_pz = torch.sum(prior.log_prob(z), dim=1)
        log_qz = torch.sum(Normal(mu, std).log_prob(z), dim=1)

        return -torch.mean(log_px_z + log_pz - log_qz) * np.log2(np.e) / 784
 def train_pg(self, state, action, reward):
     """
     Train the policy with a REINFORCE-style policy-gradient step.

     :param state: the input state(s) passed to ``self.forward``
     :param action: the taken binary action(s); converted to a FloatTensor
     :param reward: the resulting reward(s); converted to a FloatTensor
     :return: the per-sample loss handed to ``self.update``
     """
     action_t = torch.FloatTensor(action)
     reward_t = torch.FloatTensor(reward)
     dist = Bernoulli(self.forward(state))
     # Negative score function (log-likelihood summed over the last axis) x reward.
     log_lik = dist.log_prob(action_t).sum(dim=-1)
     loss = -log_lik * reward_t
     self.update(loss)
     return loss
Example #11
0
    def sampled_from_logit_p(self, num_samples):
        """Draw ``num_samples`` keep-masks from the learned dropout logits.

        ``sigmoid(self.logit_p)`` is the dropout probability, so each mask
        entry is 1 with probability ``1 - sigmoid(logit_p)``. The mask's
        log-probability is stored in ``self.log_prob`` and the mask itself in
        ``self.bern_val``.

        :raises Exception: if ``self.log_prob`` from a previous call was not
            reset to None (checked after sampling).
        """
        shape = self.logit_p.size()
        logits = self.logit_p.unsqueeze(0).expand(num_samples, *shape)

        keep_dist = Bernoulli(1. - torch.sigmoid(logits))
        mask = keep_dist.sample()

        if self.log_prob is not None:
            raise Exception('Log probability should be cleaned up after use')
        self.log_prob = keep_dist.log_prob(mask)

        self.bern_val = mask
        return mask
Example #12
0
    def forward(self, h_t, alpha=0.05, eps=1):
        """Sample a binary halting decision from hidden state ``h_t``.

        The halting probability ``sigmoid(self.fc(h_t))`` is mixed with the
        constant ``eps``: ``(1 - alpha) * p + alpha * eps``.

        :param h_t: hidden state; detached so the halting head does not
            backpropagate into it.
        :param alpha: mixing weight of the constant ``eps``.
        :param eps: constant probability mixed in.
        :return: ``(halt, log_pi, -log(probs), probs)``.
        """
        # Compute halting-probability
        probs = torch.sigmoid(self.fc(h_t.detach()))

        # A Python scalar broadcasts on probs' own device; the former
        # ``torch.FloatTensor([eps]).cuda()`` failed on CPU and on other GPUs.
        probs = (1 - alpha) * probs + alpha * eps

        # Bernoulli distribution parameterized with the predicted probability
        m = Bernoulli(probs=probs)
        halt = m.sample()
        # Log-probability of the sampled action, used for optimization.
        log_pi = m.log_prob(halt)

        return halt, log_pi, -torch.log(probs), probs
Example #13
0
    def select_action(self, state):
        """Sample an action from the Bernoulli policy and record its log-prob.

        The actor's output is the probability of index 1; the sampled index
        is mapped through ``ACTIONS``. The log-probability of the sample is
        appended to ``self.memory['log_probs']`` for a later update.
        """
        dist = Bernoulli(self.actor.forward(state))
        sample = dist.sample()
        chosen = ACTIONS[int(sample.item())]

        self.memory['log_probs'].append(dist.log_prob(sample))
        return chosen
Example #14
0
    def forward(self, x, mask):
        """Autoencode ``x``, score the predicted binary actions and mask the results.

        :param x: input; its size is exposed as ``self.x_sizes`` while
            ``encode``/``decode``/``predict`` run.
        :param mask: mask passed to ``apply_mask``.
        :return: ``(actions, log_probas, entropy, dist_params)`` where
            ``log_probas`` and ``entropy`` are masked element-wise values.
        """
        # Presumably read by encode/decode; only valid during this call, so
        # it is reset even if one of them raises.
        self.x_sizes = x.size()
        try:
            encoded = self.encode(x)
            decoded = self.decode(encoded)
            dist_params, actions = self.predict(decoded)
        finally:
            self.x_sizes = None
        sampler = Bernoulli(dist_params)

        # Log-probability of the predicted actions.
        log_probas = apply_mask(sampler.log_prob(actions), mask)

        # Mask the entropy tensor itself, not the log-probabilities.
        entropy = apply_mask(sampler.entropy(), mask)

        return actions, log_probas, entropy, dist_params
def predict(agent, parameters, deterministic=False):
    """Choose a binary action for the first row of ``parameters``.

    :param agent: policy whose ``forward`` maps the feature vector to the
        probability of action 1.
    :param parameters: 2-D array; only row 0 is used as the feature vector.
    :param deterministic: if True, take action 1 exactly when the probability
        exceeds 0.5 instead of sampling.
    :return: ``(action, log_prob)`` where ``log_prob`` is the log-probability
        of the returned action. (Previously a thresholded action was returned
        together with the log-prob of an unrelated sample, which corrupts any
        policy-gradient update built from it.)
    """
    features = parameters[0, :]
    # Approximate feature scales: velx ~ 4, pip_y ~ 512, player_height ~ 512, dist ~ 288
    y = agent.forward(torch.from_numpy(np.array(features)).float())

    # Bernoulli distribution with the predicted probability of action 1.
    prob_dist = Bernoulli(y)
    if deterministic:
        action = (y > 0.5).float()
    else:
        action = prob_dist.sample()
    log_prob = prob_dist.log_prob(action)

    return int(action.item()), log_prob
Example #16
0
 def forward(self, model_input):
     """Sample one binary action per frame, conditioning on recent actions.

     For each of the first ``self.max_frames`` frames, the frame features are
     concatenated with a rolling window of the last ``self.num_hist`` actions,
     passed through ``fc1``/``fc2``/``fc3`` and a softmax; column 0 of the
     softmax is used as the Bernoulli probability.

     :param model_input: tensor indexed as ``[batch, frame, feature]``.
     :return: ``(actions, log_probs)``, each of shape
         ``(batch_size, max_frames)``.
     """
     batch_size = model_input.shape[0]
     # Allocated on the input's device instead of a hard-coded ``.cuda()``.
     history_vector = torch.zeros([batch_size, self.num_hist], device=model_input.device)
     actions = []
     log_probs = []
     for t in range(self.max_frames):
         video_frames = model_input[:, t, :]
         features = torch.cat([video_frames, history_vector], dim=1)
         fc1 = F.relu(self.fc1(features))
         fc2 = F.relu(self.fc2(fc1))
         probs = F.softmax(self.fc3(fc2), dim=1)[:, 0]  # (batch_size,)
         m = Bernoulli(probs)
         action = m.sample()  # (batch_size,)
         action_col = action.view(-1, 1)  # (batch_size, 1)
         actions.append(action_col)
         log_probs.append(m.log_prob(action).view(-1, 1))
         # Drop the oldest action and append the newest; the 1-D sample has
         # to be a column for this concatenation.
         history_vector = torch.cat([history_vector[:, 1:], action_col], dim=1)
     return torch.cat(actions, dim=1), torch.cat(log_probs, dim=1)
Example #17
0
    def learn(self):
        """One REINFORCE update over every stored step, then reset memory.

        Rewards are first adjusted by ``self._adjust_reward``; the gradients of
        ``-log pi(action | state) * reward`` for all ``self.steps`` steps are
        accumulated before a single optimizer step.
        """
        self._adjust_reward()

        self.optimizer.zero_grad()
        for step in range(self.steps):
            # Steps may come from several games.
            state = self.state_pool[step]
            target = torch.FloatTensor([self.action_pool[step]])
            weight = self.reward_pool[step]

            dist = Bernoulli(self.act(state))
            (-dist.log_prob(target) * weight).backward()
        self.optimizer.step()

        self._init_memory()
Example #18
0
    def leaning(self):
        """One REINFORCE update over every stored step, then reset memory.

        (The method name is kept as-is for callers.) Rewards are adjusted by
        ``self._adjust_reward``; per-step losses
        ``-log pi(action | state) * reward`` are backpropagated one by one and
        applied with a single optimizer step.
        """
        self._adjust_reward()

        self.optimizer.zero_grad()
        for step in range(self.steps):
            state = self.state_pool[step]
            target = torch.FloatTensor([self.action_pool[step]])
            weight = self.reward_pool[step]

            dist = Bernoulli(self.act(state))
            (-dist.log_prob(target) * weight).backward()
        self.optimizer.step()
        self._init_memory()
Example #19
0
    def forward(self, encoder_input, decoder_input, labels):
        """Run the discrete-latent model and build its REINFORCE training loss.

        The encoder yields a binary latent sample, its Bernoulli logits and
        its entropy; the task loss is computed on the thresholded latent
        (logit > 0). The returned loss is the batch mean of

        * ``(loss - baseline) * log q(z)`` (score-function term),
        * ``-encoder_entropy * self.encoder_entropy_coeff`` (entropy bonus),
        * the task loss itself.

        ``self.encoder.baseline_type`` selects the baseline: ``'runavg'`` uses
        a running mean of past batch losses (updated here while training),
        ``'sample'`` re-scores an independent latent sample.

        :raises ValueError: if ``baseline_type`` is neither of the above.
        :return: ``{'loss': full_loss, 'log': logs}``.
        """
        baseline_type = self.encoder.baseline_type
        if baseline_type not in ('runavg', 'sample'):
            # Fail fast: otherwise ``baseline`` would be unbound further down.
            raise ValueError('Unknown baseline_type: %r' % (baseline_type,))

        discrete_latent_z, encoder_scores, encoder_entropy = \
            self.encoder(encoder_input)
        decoder_output, decoder_log_prob, decoder_entropy = \
            self.decoder(discrete_latent_z, decoder_input)

        # logit > 0 <=> probability > 0.5
        argmax = (encoder_scores > 0).to(torch.float)

        loss, logs = self.loss(encoder_input, argmax, decoder_input,
                               decoder_output, labels)

        encoder_bernoull_distr = Bernoulli(logits=encoder_scores)
        encoder_sample_log_probs = \
            encoder_bernoull_distr.log_prob(discrete_latent_z).sum(dim=1)

        if baseline_type == 'runavg':
            baseline = self.mean_baseline
        else:  # 'sample'
            alt_z_sample = encoder_bernoull_distr.sample().detach()
            alt_decoder_output, _, _ = self.decoder(alt_z_sample, decoder_input)
            baseline, _ = self.loss(encoder_input, alt_z_sample, decoder_input,
                                    alt_decoder_output, labels)

        policy_loss = (loss.detach() - baseline) * encoder_sample_log_probs
        entropy_loss = -encoder_entropy * self.encoder_entropy_coeff

        full_loss = (policy_loss + entropy_loss + loss).mean()

        if self.training and baseline_type == 'runavg':
            # Incremental running mean of the batch-mean loss.
            self.n_points += 1.0
            self.mean_baseline += (loss.detach().mean() -
                                   self.mean_baseline) / self.n_points

        for k, v in logs.items():
            if hasattr(v, 'mean'):
                logs[k] = v.mean()

        logs['baseline'] = self.mean_baseline
        logs['loss'] = loss.mean()
        logs['encoder_entropy'] = encoder_entropy.mean()
        logs['decoder_entropy'] = decoder_entropy.mean()
        logs['distr'] = encoder_bernoull_distr

        return {'loss': full_loss, 'log': logs}
Example #20
0
    def forward(self, embedded_message, bits, _aux_input=None):
        """Predict output bits from a message embedding and input bits.

        The embedded bits and the message are concatenated along dim 1 and
        passed through ``fc1`` -> leaky ReLU -> ``fc2`` -> sigmoid, giving one
        Bernoulli probability per output bit.

        :return: ``(sample, log_prob, entropy)``; ``sample`` is random while
            training and thresholded at 0.5 otherwise, ``log_prob`` is summed
            over dim 1 and ``entropy`` is element-wise.
        """
        bit_features = self.emb_column(bits.float())
        hidden = self.fc1(torch.cat([bit_features, embedded_message], dim=1))
        probs = self.fc2(F.leaky_relu(hidden)).sigmoid()

        distr = Bernoulli(probs=probs)
        entropy = distr.entropy()

        sample = distr.sample() if self.training else (probs > 0.5).float()
        log_prob = distr.log_prob(sample).sum(dim=1)
        return sample, log_prob, entropy
Example #21
0
    def gradient_step(self, trajectories):
        """One REINFORCE update over a batch of trajectories.

        Each trajectory is a list of ``(state, action, reward)`` tuples. The
        rewards are replaced in place by discounted returns (``self.gamma``)
        and then standardized per trajectory. After that, the gradients of
        ``-log pi(action | state) * reward`` for all steps are accumulated and
        applied with a single optimizer step.

        :param trajectories: list of trajectories, mutated in place so that
            each tuple's third entry holds the normalized discounted return.
        """
        for trajectory in trajectories:
            # Discounted returns, computed back to front.
            discounted_reward = 0
            for j in reversed(range(len(trajectory))):
                state, action, reward = trajectory[j]
                discounted_reward = discounted_reward * self.gamma + reward
                trajectory[j] = (state, action, discounted_reward)

            # Standardize the returns. A zero standard deviation (single-step
            # trajectory or constant returns) would yield NaN rewards and NaN
            # parameters, so in that case the returns are only centered.
            rewards = [frame[2] for frame in trajectory]
            reward_mean = np.mean(rewards)
            reward_std = np.std(rewards)
            if reward_std == 0:
                reward_std = 1.0

            for j, (state, action, reward) in enumerate(trajectory):
                normalized_reward = (reward - reward_mean) / reward_std
                trajectory[j] = (state, action, normalized_reward)

        # Accumulate gradients over every step, then step once.
        self.optimizer.zero_grad()
        for trajectory in trajectories:
            for state, action, reward in trajectory:
                probs = self.policy_net(state)
                m = Bernoulli(probs) if self.binary_action_space else Categorical(probs)
                loss = -m.log_prob(action) * reward
                loss.backward()

        self.optimizer.step()
Example #22
0
def get_log_pi_gradient(policy, action, state, mode="param"):
    """
    Calculate the gradient of the log-probability of taking ``action`` in
    ``state`` under ``policy``.

    input:
      policy: nn.Module, the policy specified
      action: int, the action specified
      state: int, the state specified (one-hot encoded to length 6 here)
      mode: "param" treats the policy output as Bernoulli probabilities
            (clamped to [0, 1]); any other value treats it as Categorical
            probabilities
    return:
      grad_log_pi: 1-D CUDA tensor holding the gradients of all policy
      parameters, flattened and concatenated in ``named_parameters()`` order.
      Parameters that received no gradient contribute zeros.
    """

    # clean the grad
    policy.zero_grad()

    # convert state to one-hot
    state = get_one_hot(state, 6)

    # forward pass
    probs = policy(state)

    if mode == "param":
        # Bernoulli over the policy output; clamp so the probabilities are
        # valid for the distribution.
        action_prob = torch.clamp(probs, min=0.0, max=1.0)
        c = Bernoulli(action_prob)
    else:
        # Categorical distribution over the policy output.
        c = Categorical(probs)

    log_pi = c.log_prob(torch.tensor(action).float().cuda())

    # calculate the gradient
    log_pi.backward()

    # Flatten every parameter's gradient into one vector. After zero_grad a
    # parameter not reached by backward can have ``grad is None``; use zeros
    # so the vector keeps one entry per parameter element.
    grad_log_pi = torch.cat([
        torch.flatten(value.grad) if value.grad is not None
        else torch.zeros_like(value).flatten()
        for _, value in policy.named_parameters()
    ]).detach()

    return grad_log_pi.cuda()
Example #23
0
class CommonDistribution:
    """Joint distribution of one categorical intent and independent Bernoulli slots.

    :param intent_probs: probabilities of the Categorical intent component.
    :param slot_sigms: per-slot probabilities of the Bernoulli component.
    """

    def __init__(self, intent_probs, slot_sigms):
        self.cd = Categorical(intent_probs)
        self.bd = Bernoulli(slot_sigms)

    def sample(self):
        """Return ``(intent, slots)``, sampling the intent first."""
        intent_sample = self.cd.sample()
        slot_sample = self.bd.sample()
        return intent_sample, slot_sample

    def log_prob(self, intent, slots):
        """Joint log-probability: intent log-prob plus the sum of slot log-probs (per row)."""
        intent_lp = self.cd.log_prob(intent.squeeze()).unsqueeze(1)
        slot_lp = self.bd.log_prob(slots)
        stacked = torch.cat([intent_lp, slot_lp], dim=1)
        return stacked.sum(dim=1)

    def entropy(self):
        """Mean slot entropy (over dim 1) plus the intent entropy."""
        # NOTE(review): log_prob sums the slot terms while this averages them;
        # the entropy of independent components would be a sum -- confirm
        # whether the mean is an intentional scaling of the entropy bonus.
        return self.bd.entropy().mean(dim=1) + self.cd.entropy()
Example #24
0
    def forward(self, h, eps=0.):
        """Read in hidden state, predict one halting probability per class.

        The predicted probabilities are mixed with a constant 0.05 using
        ``self._epsilon``: ``(1 - eps_) * p + eps_ * 0.05``, so no probability
        is exactly 0.

        :param h: hidden state fed to ``self.fc``.
        :param eps: unused; the mixing weight is read from ``self._epsilon``.
        :return: ``(action, log_pi, -log(probs))``.
        """
        # Predict one probability per class (``h`` is the input; the previous
        # code referenced an undefined ``x``).
        probs = torch.sigmoid(self.fc(h))

        # Balance between explore/exploit. A Python scalar keeps the result
        # on probs' device (a CPU FloatTensor constant breaks CUDA inputs).
        probs = (1 - self._epsilon) * probs + self._epsilon * 0.05

        # Parameterize bernoulli distribution with predicted probabilities
        m = Bernoulli(probs=probs)

        # Sample an action and the log-probability of having picked it
        # (for use during optimization).
        action = m.sample()
        log_pi = m.log_prob(action)

        # Negative log of the probabilities themselves; minimizing it
        # increases the probability of halting.
        return action, log_pi, -torch.log(probs)
Example #25
0
def marginal(model, x, z):
    """Importance-weighted estimate of log p(x) for each batch item.

    For each of the ``k`` latent samples the log weight is
    ``log p(x|z) + log p(z) - log q(z|x)``. Here p(x|z) is a Bernoulli
    decoder over 784 values, and the prior and posterior terms come from
    ``calc_normal_log_pdf`` with (0, 1) and with the encoder's (mu, std)
    respectively. The result is ``logsumexp(log weights) - log k``.

    :param model: provides ``encode`` (-> mu, logvar), ``decode`` and
        ``latent_dim``.
    :param x: batch of inputs, viewed as ``(batchsize, 784)``.
    :param z: latent samples; ``z.shape[1]`` is taken as k.
    :return: tensor of shape ``(batchsize,)``.
    """
    k = z.shape[1]
    mu, logvar = model.encode(x)
    mu = mu.squeeze(0)
    logvar = logvar.squeeze(0)
    std = torch.exp(0.5 * logvar)
    batchsize = x.data.shape[0]

    # NOTE(review): ``view`` reinterprets memory without permuting; this pairs
    # samples with batch items correctly only if z is laid out as
    # (k, batchsize, ...) in memory -- confirm against the caller.
    z = z.view(k, batchsize, -1)

    zero_mean = torch.zeros(batchsize, model.latent_dim).to(device)
    one_std = torch.ones(batchsize, model.latent_dim).to(device)
    xs = x.view(batchsize, 784)

    # Collect per-sample log weights and stack once, instead of repeatedly
    # concatenating onto an uninitialized placeholder column.
    log_weights = []
    for zs in z:  # k samples of shape (batchsize, latent)
        recon_xs = model.decode(zs)
        p_xz = Bernoulli(recon_xs.view(batchsize, 784))
        log_pxs = torch.sum(p_xz.log_prob(xs), dim=1)

        log_prior = calc_normal_log_pdf(zs, zero_mean, one_std).sum(dim=1)
        log_posterior = calc_normal_log_pdf(zs, mu, std).sum(dim=1)

        log_weights.append(log_pxs + log_prior - log_posterior)

    logsums = torch.stack(log_weights, dim=1)  # (batchsize, k)
    return torch.logsumexp(logsums, dim=1) - math.log(k)
Example #26
0
    def update(self, reward_pool, state_pool, action_pool):
        """One REINFORCE update from a flat buffer of transitions.

        A reward of 0 is treated as an episode boundary: the running
        discounted return restarts there and that entry stays 0. The returns
        are then standardized, and the gradients of
        ``-log pi(action | state) * return`` are accumulated over all steps
        and applied with a single optimizer step.

        :param reward_pool: list of rewards; mutated in place into the
            normalized discounted returns.
        :param state_pool: numpy state arrays, one per reward.
        :param action_pool: binary actions, one per reward.
        """
        # Discounted return G_t, computed back to front.
        running_add = 0
        for i in reversed(range(len(reward_pool))):
            if reward_pool[i] == 0:
                running_add = 0
            else:
                running_add = running_add * self.gamma + reward_pool[i]
                reward_pool[i] = running_add

        # Normalize the returns. A zero standard deviation (e.g. all-equal
        # returns) would produce NaN rewards and NaN parameters, so in that
        # case the returns are only centered.
        reward_mean = np.mean(reward_pool)
        reward_std = np.std(reward_pool)
        if reward_std == 0:
            reward_std = 1.0
        for i in range(len(reward_pool)):
            reward_pool[i] = (reward_pool[i] - reward_mean) / reward_std

        # Gradient descent, visiting the steps front to back.
        self.optimizer.zero_grad()
        for i, reward in enumerate(reward_pool):
            action = torch.FloatTensor([action_pool[i]])
            state = torch.from_numpy(state_pool[i]).float()
            probs = self.policy_net(state)
            m = Bernoulli(probs)
            # Negative score function x reward
            loss = -m.log_prob(action) * reward
            loss.backward()

        self.optimizer.step()
Example #27
0
    def logjoint(self, x, z, model_params):
        '''
        Log joint density log p(x, z) of a Gaussian random-walk latent path
        with Bernoulli observations.

        input: x (observations T x D), scored under Bernoulli(sigmoid(z))
        input: z latent path, indexed by time along dim 0
        input: model_params -- currently unused
        return: log p(z_0) + sum log p(z_t | z_{t-1}) + sum log p(x | z)

        Side effect: sets ``self.transition_scale`` to
        ``exp(self.transition_log_scale)``.
        '''
        # NOTE(review): model_params[0] was read into an unused local while
        # the method uses self.transition_log_scale; confirm which the caller
        # intends to be authoritative.
        self.transition_scale = torch.exp(self.transition_log_scale)
        # init log prior
        init_latent_logpdf = Normal(self.init_latent_loc,
                                    torch.exp(self.init_latent_log_scale))
        # transitions: z[t] ~ Normal(z[t-1], transition_scale)
        transition_logpdf = Normal(z[:-1], self.transition_scale)
        # observations
        obs_logpdf = Bernoulli(self.sigmoid(z))

        # compute log probs
        logprob = init_latent_logpdf.log_prob(z[0])
        logprob += torch.sum(transition_logpdf.log_prob(z[1:]))
        logprob += torch.sum(obs_logpdf.log_prob(x))

        return logprob
Example #28
0
def point_distribution(logits: [..., 'T']) -> ([...], [...], [...]):
    '''
    Categorical-proposal / Bernoulli-acceptance sampling over the last axis.

    A proposal index is drawn from ``Categorical(logits=logits)``. The logit
    at the proposed index, selected with ``select_on_last``, is then reused as
    the logit of a Bernoulli accept/reject decision.

    Returns
        a) the proposals
        b) a boolean mask that is True where the proposal was accepted
        c) the log-probability of (proposal and acceptance decision)
    '''
    proposal_dist = Categorical(logits=logits)
    proposals = proposal_dist.sample()
    proposal_logp = proposal_dist.log_prob(proposals)

    chosen_logits = select_on_last(logits, proposals).squeeze(-1)
    accept_dist = Bernoulli(logits=chosen_logits)
    accepted = accept_dist.sample()

    joint_logp = proposal_logp + accept_dist.log_prob(accepted)
    return proposals, accepted == 1., joint_logp
Example #29
0
def train(epoch):
    """Train the patch-selection agent and the classifier for one epoch.

    For each batch the agent scores a low-resolution copy of the images. A
    sampled policy and a thresholded (p >= 0.5) policy each select
    high-resolution input for ``rnet``. The reward difference between them
    is the advantage in a REINFORCE loss, and ``rnet`` is trained jointly with
    cross-entropy on the sampled-policy input. Epoch statistics are printed
    and logged.

    Relies on module-level globals: ``agent``, ``rnet``, ``trainloader``,
    ``optimizer``, ``args``, ``utils``, ``mappings``, ``patch_size``,
    ``log_value`` and ``tqdm``.
    """
    agent.train()
    rnet.train()

    matches, rewards, rewards_baseline, policies = [], [], [], []
    for _, (inputs, targets) in tqdm.tqdm(enumerate(trainloader),
                                          total=len(trainloader)):

        # ``async`` is a reserved word since Python 3.7, so ``cuda(async=True)``
        # is a SyntaxError; ``non_blocking`` is the supported spelling.
        targets = targets.cuda(non_blocking=True)
        if not args.parallel:
            inputs = inputs.cuda()

        # Get the low resolution agent images
        inputs_agent = torch.nn.functional.interpolate(
            inputs.clone(), (args.lr_size, args.lr_size))
        probs = torch.sigmoid(
            agent.forward(inputs_agent,
                          args.model.split('_')[1], 'lr'))
        # For 0.5 < alpha < 1 this squeezes probabilities into
        # [1 - alpha, alpha], keeping some exploration.
        probs = probs * args.alpha + (1 - probs) * (1 - args.alpha)

        # Sample the policies from the Bernoulli distribution characterized by agent's output
        distr = Bernoulli(probs)
        policy_sample = distr.sample()

        # Test time policy - used as baseline policy in the training step
        policy_map = probs.data.clone()
        policy_map[policy_map < 0.5] = 0.0
        policy_map[policy_map >= 0.5] = 1.0

        # Agent sampled high resolution images
        inputs_map = utils.agent_chosen_input(inputs.clone(), policy_map,
                                              mappings, patch_size)
        inputs_sample = utils.agent_chosen_input(inputs.clone(),
                                                 policy_sample.int(), mappings,
                                                 patch_size)

        # Get the predictions for baseline and sampled policy
        preds_map = rnet.forward(inputs_map, args.model.split('_')[1], 'hr')
        preds_sample = rnet.forward(inputs_sample,
                                    args.model.split('_')[1], 'hr')

        # Get the rewards for both policies
        reward_map, match = utils.compute_reward(preds_map, targets,
                                                 policy_map.data, args.penalty)
        reward_sample, _ = utils.compute_reward(preds_sample, targets,
                                                policy_sample.data,
                                                args.penalty)

        # Find the joint loss from the classifier and agent
        advantage = reward_sample - reward_map
        loss = -distr.log_prob(policy_sample).sum(1, keepdim=True) * advantage
        loss = loss.mean()
        loss += F.cross_entropy(preds_sample, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        matches.append(match.cpu())
        rewards.append(reward_sample.cpu())
        rewards_baseline.append(reward_map.cpu())
        policies.append(policy_sample.data.cpu())

    accuracy, reward, sparsity, variance, policy_set = utils.performance_stats(
        policies, rewards, matches)

    print('Train: %d | Acc: %.3f | Rw: %.2E | S: %.3f | V: %.3f | #: %d' %
          (epoch, accuracy, reward, sparsity, variance, len(policy_set)))
    log_value('train_accuracy', accuracy, epoch)
    log_value('train_reward', reward, epoch)
    log_value('train_sparsity', sparsity, epoch)
    log_value('train_variance', variance, epoch)
    log_value('train_baseline_reward',
              torch.cat(rewards_baseline, 0).mean(), epoch)
    log_value('train_unique_policies', len(policy_set), epoch)
Example #30
0
    def forward(self,
                s,
                a,
                use_prior=False,
                train_dynamics=False,
                train_correspondence=False,
                return_loss_breakdown=False):
        """Segment a trajectory with hard Bernoulli boundaries and score it.

        ``self.encode(s, a)`` returns latents ``z``, goals ``g`` and boundary
        probabilities ``alpha``. A hard 0/1 mask is sampled from ``alpha``, and
        trajectory 0 of the batch is compressed to the steps where the mask is
        set (``short_z``, ``short_g``).

        Eval mode returns ``(short_z, short_g, hard_alpha)``. Training mode
        also returns a loss built from reconstruction BCE, an optional
        sparsity prior (``use_prior``), optional dynamics/correspondence MSE
        terms and a REINFORCE term for ``alpha``. With
        ``return_loss_breakdown`` it additionally returns the list
        ``[loss, prior, recon, dynamics, reinforce]`` of Python floats.
        """
        # encode the batch of trajectories
        z, g, alpha = self.encode(s, a)

        # harden and save the log probabilities for REINFORCE
        p_alpha = Bernoulli(alpha)
        with torch.no_grad():
            hard_alpha = p_alpha.sample()

        reinforce_log_probs = p_alpha.log_prob(hard_alpha).mean()

        # NOTE: I couldn't think of a way to vectorize this along the
        # batch dimension. I think that's ok because we might always
        # end up training this with a batch size of 1 (where each batch
        # is one trajectory)
        i = 0  # the first trajectory in the batch
        mask = hard_alpha[i, :, 0].bool()
        short_z = z[i, mask]
        short_g = g[i, mask]

        if not self.training:
            return short_z, short_g, hard_alpha

        # ``reduction='none'`` is the supported replacement for the deprecated
        # ``reduce=False``; both keep element-wise losses.
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')
        # the prior loss is the number of alpha == 1 (to encourage sparsity)
        prior_losses = hard_alpha * use_prior

        # Optional terms, allocated like ``z`` (same device/dtype) so they
        # can be added to losses living on a GPU.
        dynamics_loss = z.new_zeros(1)
        correspondence_loss = z.new_zeros(1)

        if train_dynamics:
            short_z_recon = self.f(short_z[:-1], short_g[:-1])
            dynamics_loss += ((short_z_recon - short_z[1:])**2).mean()

        if train_correspondence:
            short_z_recon = self.c(s[i, mask])
            correspondence_loss += ((short_z_recon - short_z)**2).mean()

        # use the alpha mask to extend the high level goals for their duration
        hard_g = self.extend_goals_hard(g, hard_alpha)
        # use the high level goals to reconstruct low level actions
        a_recon = self.decode(s, hard_g)
        # and the quality of the reconstruction
        recon_losses = bce_loss(a_recon, a)

        per_timestep_losses = prior_losses + recon_losses + dynamics_loss + correspondence_loss
        loss = per_timestep_losses.mean()

        # use REINFORCE to estimate the gradients of the alpha parameters
        reinforce_loss = (reinforce_log_probs *
                          per_timestep_losses.detach()).mean()
        loss += reinforce_loss

        # return the latents and the losses
        if not return_loss_breakdown:
            return short_z, short_g, hard_alpha, loss
        loss_breakdown = [
            loss.item(),
            prior_losses.mean().item(),
            recon_losses.mean().item(),
            dynamics_loss.item(),
            reinforce_loss.item()
        ]
        return short_z, short_g, hard_alpha, loss, loss_breakdown