Example #1
 def sample(self, params, sample_shape=torch.Size()):
     mean, log_std = params['mean'], params['log_std']
     std = torch.exp(log_std)
     ac = Normal(loc=mean, scale=std).rsample(sample_shape)
     return ac
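A minimal sketch of how a method like this might be driven, assuming a `params` dict holding `mean` and `log_std` tensors (the shapes below are illustrative, not taken from the original class):

import torch
from torch.distributions import Normal

params = {'mean': torch.zeros(3), 'log_std': torch.zeros(3)}
std = torch.exp(params['log_std'])
# rsample() draws reparameterized samples, so gradients can flow
# back into mean and log_std
ac = Normal(loc=params['mean'], scale=std).rsample(torch.Size([5]))
print(ac.shape)  # torch.Size([5, 3])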
Example #2
 def reparameterize_transformation(self, mu, var):
     untran_z = Normal(mu, var.sqrt()).rsample()
     z = self.z_transformation(untran_z)
     return z, untran_z
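The point of `rsample()` here is the reparameterization trick (z = mu + sigma * eps), which keeps the draw differentiable with respect to the encoder outputs. A quick check with made-up tensors:

import torch
from torch.distributions import Normal

mu = torch.zeros(4, requires_grad=True)
var = torch.ones(4)
z = Normal(mu, var.sqrt()).rsample()  # differentiable w.r.t. mu
z.sum().backward()
print(mu.grad)  # all ones, since dz/dmu = 1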
Example #3
 def get_distribution(self, obs: torch.Tensor) -> Distribution:
     mu: torch.Tensor = self.mu_model(obs)
     std = torch.exp(self.log_std)
     self.pi = Normal(mu, std)
     return self.pi
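Because `self.log_std` is state-independent, the same standard deviation is broadcast across the batch. A hedged sketch of how the returned distribution is typically consumed in a policy-gradient loss (the tensors below stand in for `self.mu_model(obs)` and `self.log_std`):

import torch
from torch.distributions import Normal

mu = torch.zeros(8, 2)           # stand-in for self.mu_model(obs)
log_std = torch.zeros(2)         # stand-in for self.log_std
pi = Normal(mu, log_std.exp())   # std broadcasts over the batch
actions = pi.sample()
logp = pi.log_prob(actions).sum(dim=-1)  # joint log-prob per sample
print(logp.shape)  # torch.Size([8])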
Example #4
    def inference(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        label: Optional[torch.Tensor] = None,
        n_samples=1,
        cont_covs=None,
        cat_covs=None,
    ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Internal helper function to compute necessary inference quantities.

        We use the dictionary ``px_`` to contain the parameters of the ZINB/NB for genes.
        The rate refers to the mean of the NB, dropout refers to Bernoulli mixing parameters.
        `scale` refers to the quantity upon which differential expression is performed. For genes,
        this can be viewed as the mean of the underlying gamma distribution.

        We use the dictionary ``py_`` to contain the parameters of the Mixture NB distribution for proteins.
        `rate_fore` refers to foreground mean, while `rate_back` refers to background mean. ``scale`` refers to
        foreground mean adjusted for background probability and scaled to reside in the simplex.
        ``back_alpha`` and ``back_beta`` are the posterior parameters for ``rate_back``.  ``fore_scale`` is the scaling
        factor that enforces `rate_fore` > `rate_back`.

        ``px_["r"]`` and ``py_["r"]`` are the inverse dispersion parameters for genes and protein, respectively.

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input_genes)``
        y
            tensor of values with shape ``(batch_size, n_input_proteins)``
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size``
        label
            tensor of cell-type labels with shape ``(batch_size, n_labels)``
        n_samples
            Number of samples to sample from approximate posterior
        cont_covs
            Continuous covariates to condition on
        cat_covs
            Categorical covariates to condition on
        """
        x_ = x
        y_ = y
        if self.use_observed_lib_size:
            library_gene = x.sum(1).unsqueeze(1)
        if self.log_variational:
            x_ = torch.log(1 + x_)
            y_ = torch.log(1 + y_)

        if cont_covs is not None and self.encode_covariates is True:
            encoder_input = torch.cat((x_, y_, cont_covs), dim=-1)
        else:
            encoder_input = torch.cat((x_, y_), dim=-1)
        if cat_covs is not None and self.encode_covariates is True:
            categorical_input = torch.split(cat_covs, 1, dim=1)
        else:
            categorical_input = tuple()
        qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder(
            encoder_input, batch_index, *categorical_input)
        z = latent["z"]
        untran_z = untran_latent["z"]
        untran_l = untran_latent["l"]
        if not self.use_observed_lib_size:
            library_gene = latent["l"]

        if n_samples > 1:
            qz_m = qz_m.unsqueeze(0).expand(
                (n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(0).expand(
                (n_samples, qz_v.size(0), qz_v.size(1)))
            untran_z = Normal(qz_m, qz_v.sqrt()).sample()
            z = self.encoder.z_transformation(untran_z)
            ql_m = ql_m.unsqueeze(0).expand(
                (n_samples, ql_m.size(0), ql_m.size(1)))
            ql_v = ql_v.unsqueeze(0).expand(
                (n_samples, ql_v.size(0), ql_v.size(1)))
            untran_l = Normal(ql_m, ql_v.sqrt()).sample()
            if self.use_observed_lib_size:
                library_gene = library_gene.unsqueeze(0).expand(
                    (n_samples, library_gene.size(0), library_gene.size(1)))
            else:
                library_gene = self.encoder.l_transformation(untran_l)

        if self.gene_dispersion == "gene-label":
            # px_r gets transposed - last dimension is nb genes
            px_r = F.linear(one_hot(label, self.n_labels), self.px_r)
        elif self.gene_dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.gene_dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        if self.protein_dispersion == "protein-label":
            # py_r gets transposed - last dimension is n_proteins
            py_r = F.linear(one_hot(label, self.n_labels), self.py_r)
        elif self.protein_dispersion == "protein-batch":
            py_r = F.linear(one_hot(batch_index, self.n_batch), self.py_r)
        elif self.protein_dispersion == "protein":
            py_r = self.py_r
        py_r = torch.exp(py_r)
        # Background regularization
        if self.n_batch > 0:
            py_back_alpha_prior = F.linear(one_hot(batch_index, self.n_batch),
                                           self.background_pro_alpha)
            py_back_beta_prior = F.linear(
                one_hot(batch_index, self.n_batch),
                torch.exp(self.background_pro_log_beta),
            )
        else:
            py_back_alpha_prior = self.background_pro_alpha
            py_back_beta_prior = torch.exp(self.background_pro_log_beta)
        self.back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior)

        return dict(
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            untran_z=untran_z,
            ql_m=ql_m,
            ql_v=ql_v,
            library_gene=library_gene,
            untran_l=untran_l,
        )
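When `n_samples > 1`, the posterior parameters are tiled along a new leading sample dimension before drawing. A stripped-down sketch of that pattern with dummy shapes:

import torch
from torch.distributions import Normal

n_samples, batch, dim = 5, 16, 10
qz_m = torch.zeros(batch, dim)
qz_v = torch.ones(batch, dim)
qz_m = qz_m.unsqueeze(0).expand((n_samples, batch, dim))
qz_v = qz_v.unsqueeze(0).expand((n_samples, batch, dim))
z = Normal(qz_m, qz_v.sqrt()).sample()  # one draw per sample per cell
print(z.shape)  # torch.Size([5, 16, 10])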
Example #5
def reparameterize_gaussian(mu, var):
    return Normal(mu, var.sqrt()).rsample()
Example #6
    def loss(self,
             tensors,
             inference_outputs,
             generative_outputs,
             kl_weight: float = 1.0):
        # Get the data
        x = tensors[REGISTRY_KEYS.X_KEY]

        x_rna = x[:, :self.n_input_genes]
        x_chr = x[:, self.n_input_genes:]

        mask_expr = x_rna.sum(dim=1) > 0
        mask_acc = x_chr.sum(dim=1) > 0

        # Compute Accessibility loss
        x_accessibility = x_chr
        p = generative_outputs["p"]
        libsize_acc = inference_outputs["libsize_acc"]
        rl_accessibility = self.get_reconstruction_loss_accessibility(
            x_accessibility, p, libsize_acc)

        # Compute Expression loss
        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]
        px_dropout = generative_outputs["px_dropout"]
        x_expression = x_rna
        rl_expression = self.get_reconstruction_loss_expression(
            x_expression, px_rate, px_r, px_dropout)

        # mix losses to get the correct loss for each cell
        recon_loss = self._mix_modalities(
            rl_accessibility + rl_expression,  # paired
            rl_expression,  # expression
            rl_accessibility,  # accessibility
            mask_expr,
            mask_acc,
        )

        # Compute KLD between Z and N(0,I)
        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        kl_div_z = kld(
            Normal(qz_m, torch.sqrt(qz_v)),
            Normal(0, 1),
        ).sum(dim=1)

        # Compute KLD between distributions for paired data
        qzm_expr = inference_outputs["qzm_expr"]
        qzv_expr = inference_outputs["qzv_expr"]
        qzm_acc = inference_outputs["qzm_acc"]
        qzv_acc = inference_outputs["qzv_acc"]
        kld_paired = kld(Normal(qzm_expr, torch.sqrt(qzv_expr)),
                         Normal(qzm_acc, torch.sqrt(qzv_acc))) + kld(
                             Normal(qzm_acc, torch.sqrt(qzv_acc)),
                             Normal(qzm_expr, torch.sqrt(qzv_expr)))
        kld_paired = torch.where(
            torch.logical_and(mask_acc, mask_expr),
            kld_paired.T,
            torch.zeros_like(kld_paired).T,
        ).sum(dim=0)

        # KL WARMUP
        kl_local_for_warmup = kl_div_z
        weighted_kl_local = kl_weight * kl_local_for_warmup

        # PENALTY
        # distance_penalty = kl_weight * torch.pow(z_acc - z_expr, 2).sum(dim=1)

        # TOTAL LOSS
        loss = torch.mean(recon_loss + weighted_kl_local + kld_paired)

        kl_local = dict(kl_divergence_z=kl_div_z)
        kl_global = torch.tensor(0.0)
        return LossRecorder(loss, recon_loss, kl_local, kl_global)
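The paired-modality term above is a symmetrized KL between two diagonal Gaussians. A self-contained sketch of that quantity (shapes are illustrative):

import torch
from torch.distributions import Normal
from torch.distributions import kl_divergence as kld

qzm_expr, qzv_expr = torch.zeros(4, 8), torch.ones(4, 8)
qzm_acc, qzv_acc = torch.ones(4, 8), 2 * torch.ones(4, 8)
p = Normal(qzm_expr, qzv_expr.sqrt())
q = Normal(qzm_acc, qzv_acc.sqrt())
sym_kl = (kld(p, q) + kld(q, p)).sum(dim=1)  # one value per cell
print(sym_kl.shape)  # torch.Size([4])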
Example #7
 def __init__(self, fraction_on):
     super().__init__()
     self.fraction_on = fraction_on
     self.mix_off = Normal(-1, 0.1)
     self.mix_on = Normal(1, 2)
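A hedged guess at how such a two-component prior could be evaluated downstream, mixing the two densities with `fraction_on` (purely illustrative, not taken from the original class):

import math
import torch
from torch.distributions import Normal

fraction_on = 0.3
mix_off, mix_on = Normal(-1.0, 0.1), Normal(1.0, 2.0)
x = torch.linspace(-2.0, 2.0, 5)
# log-density of the two-component mixture at x
log_p = torch.logsumexp(torch.stack([
    mix_off.log_prob(x) + math.log(1.0 - fraction_on),
    mix_on.log_prob(x) + math.log(fraction_on),
]), dim=0)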
Example #8
    def forward(self, state, encoding=None, device='cpu'):
        """
            - At the first time step, pass in the encoding vector from Encoder with shape (batch_size, hidden_size)
                using the optional argument encoding= . h_list and c_list will be reset to 0s
            - At the following time steps, DO NOT pass in any value to the optional argument encoding=
        """

        # TODO: Test the dimensions of this multilayer LSTM policy net

        # If encoding is not None, reset lists of hidden states and cell states
        if encoding is not None:
            # One zero-initialized hidden/cell state per layer
            self.h_list = [
                torch.zeros((self.batch_size, self.hidden_size), device=device)
                for _ in range(self.num_layers)
            ]
            self.c_list = [
                torch.zeros((self.batch_size, self.hidden_size), device=device)
                for _ in range(self.num_layers)
            ]
            self.h_list[0] = encoding

        # Forward propagation
        h1_list = []
        c1_list = []
        # First layer
        h_1, c_1 = self.cell_list[0](state, (self.h_list[0], self.c_list[0]))
        h1_list.append(h_1)
        c1_list.append(c_1)
        # Following layers
        for i in range(1, self.num_layers):
            h_1, c_1 = self.cell_list[i](h_1, (self.h_list[i], self.c_list[i]))
            h1_list.append(h_1)
            c1_list.append(c_1)
        # Store hidden states list and cell state list
        self.h_list = h1_list
        self.c_list = c1_list

        decision_logit = self.FC_decision(h_1)
        values_mean = self.FC_values_mean(h_1)
        values_logstd = self.FC_values_logstd(h_1)

        # Take the exponentials of log standard deviation
        values_std = torch.exp(values_logstd)

        # Create a categorical (multinomial) distribution from which we can sample a decision on the action dimension
        m_decision = OneHotCategorical(logits=decision_logit)

        # Sample a decision and calculate its log probability. decision of shape (num_actions,)
        decision = m_decision.sample()
        decision_log_prob = m_decision.log_prob(decision)

        # Create a list of Normal distributions for sampling actions in each dimension
        # Note: the last action is assumed to be discrete, meaning "doing nothing", so it has a conditional probability
        #       of 1.
        m_values = []
        action_values = None
        actions_log_prob = None
        # All actions except the last one are assumed to have normal distribution
        for i in range(self.num_actions - 1):
            m_values.append(Normal(values_mean[:, i], values_std[:, i]))
            if action_values is None:
                action_values = m_values[-1].sample().unsqueeze(
                    1)  # Unsqueeze to spare the batch dimension
                actions_log_prob = m_values[-1].log_prob(
                    action_values[:, -1]).unsqueeze(1)
            else:
                action_values = torch.cat(
                    [action_values, m_values[-1].sample().unsqueeze(1)], dim=1)
                actions_log_prob = torch.cat([
                    actions_log_prob, m_values[-1].log_prob(
                        action_values[:, -1]).unsqueeze(1)
                ],
                                             dim=1)

        # Append the last action, which has value 0.0 and log probability 0.0.
        action_values = torch.cat(
            [action_values,
             torch.zeros((self.batch_size, 1), device=device)],
            dim=1)
        actions_log_prob = torch.cat([
            actions_log_prob,
            torch.zeros((self.batch_size, 1), device=device)
        ],
                                     dim=1)

        # Filter the final action value in the intended action dimension
        final_action_values = (action_values * decision).sum(dim=1)
        final_action_log_prob = (actions_log_prob * decision).sum(dim=1)

        # Scale the action value by act_lim
        final_action_values = final_action_values * self.act_lim

        # Calculate the final log probability:
        #   Pr(action value in the ith dimension) = Pr(action value given the agent chooses the ith dimension)
        #                                           * Pr(the agent chooses the ith dimension)
        log_prob = decision_log_prob + final_action_log_prob

        return decision, final_action_values, log_prob
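The one-hot `decision` acts as a mask: multiplying the per-dimension action values and log-probs by it and summing over dimensions selects the chosen one. A compact check of that masking trick:

import torch
from torch.distributions import Normal, OneHotCategorical

batch, num_actions = 4, 3
decision = OneHotCategorical(logits=torch.zeros(batch, num_actions)).sample()
action_values = Normal(torch.zeros(batch, num_actions), 1.0).sample()
selected = (action_values * decision).sum(dim=1)  # one value per row
print(selected.shape)  # torch.Size([4])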
Example #9
    def forward(self, X, A, beta=1, print_output=False):
        """
        Likelihood objective function for a given trajectory (change to a batched version later).
        X: data matrix of shape [seq_length, batch, output_size]      (we only feed one trajectory here for testing)
        A: data matrix of actions, shape [seq_length-1, batch]
        """
        assert X.size(0) == A.size(0) + 1, \
            'the seq lengths of X and A are inconsistent'
        kl_loss = 0  # KL divergence term
        Ell_loss = 0  # expected log likelihood term
        batch_size = X.size(1)

        if len(X.size()) != 3:
            print(
                'The input data matrix should be the shape of [seq_length, batch_size, input_dim]'
            )

        X = X.to(self.device)
        A = A.to(self.device)

        # container
        states = torch.zeros(A.size(0), A.size(1), self.state_size).to(
            self.device)  # [seq-1, batch, state]
        rnn_hiddens = torch.zeros(A.size(0), A.size(1), self.hidden_size).to(
            self.device)  # [seq-1, batch, hidden]

        # initialising state and rnn hidden state
        # state = torch.zeros(X.size(1), self.state_size).to(self.device)
        rnn_hidden = self.init_h(X[0]).to(self.device)  # [batch, hidden]
        if self.mode == 'LSTM':
            rnn_hidden_c = torch.zeros_like(rnn_hidden).to(
                self.device)  # [batch, hidden]

        # temp_prior = self.hidden_prior(rnn_hidden)      #[batch, state]
        temp_prior = rnn_hidden
        prior_mean = self.prior_mean(temp_prior)  # [batch, state]
        prior_sigma = torch.exp(self.prior_sigma(temp_prior))  # [batch, state]
        state = self.reparametrise(prior_mean, prior_sigma)  # [batch, state]

        # rnn_hidden = torch.zeros(X.size(1), self.hidden_size).to(self.device)

        # emission_mean = X[0]
        # For each time step, compute the free energy for each batch of data
        # (starting from the second hidden state).
        for t in range(1, X.size(0)):
            if self.mode == 'LSTM':
                next_state_prior_m, next_state_prior_sigma, rnn_hidden, rnn_hidden_c = self.prior(
                    state, A[t - 1].unsqueeze(-1), rnn_hidden, rnn_hidden_c)
            else:
                next_state_prior_m, next_state_prior_sigma, rnn_hidden = self.prior(
                    state, A[t - 1].unsqueeze(-1), rnn_hidden)

            next_state_post_m, next_state_post_sigma = self.posterior(
                rnn_hidden, X[t])
            state = self.reparametrise(
                next_state_post_m,
                next_state_post_sigma)  # [batch, state_size]
            states[t - 1] = state
            rnn_hiddens[t - 1] = rnn_hidden
            next_state_prior = Normal(next_state_prior_m,
                                      next_state_prior_sigma)
            next_state_post = Normal(next_state_post_m, next_state_post_sigma)

            # kl = kl_divergence(next_state_prior, next_state_post).sum(dim=1)        #[batch]
            kl = kl_divergence(next_state_post,
                               next_state_prior).sum(dim=1)  # [batch]

            kl_loss += kl.mean()
        kl_loss /= A.size(0)

        # compute nll

        # flatten state
        flatten_states = states.view(-1, self.state_size)
        flatten_rnn_hiddens = rnn_hiddens.view(-1, self.hidden_size)
        flatten_x_mean, flatten_x_sigma = self.obs_model(
            flatten_states, flatten_rnn_hiddens)

        nll = self.batched_gaussian_ll(
            flatten_x_mean, flatten_x_sigma,
            X[1:, :, :].reshape(-1, self.output_size))
        nll = nll.mean()

        FE = nll - kl_loss

        if print_output:
            # print('ELL loss=', Ell_loss, 'KL loss=', kl_loss)
            print(
                'Free energy of this batch = {}. Nll loss = {}. KL div = {}.'.
                format(float(FE.data), float(nll.data), float(kl_loss.data)))

        return FE, nll, kl_loss
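The per-step KL between the posterior and the learned prior is the standard ELBO ingredient; for diagonal Gaussians the analytic value can be sanity-checked against a Monte Carlo estimate:

import torch
from torch.distributions import Normal, kl_divergence

post = Normal(torch.zeros(2, 3), torch.ones(2, 3))
prior = Normal(torch.ones(2, 3), 2 * torch.ones(2, 3))
kl = kl_divergence(post, prior).sum(dim=1)  # analytic, shape [batch]
z = post.sample((10000,))
mc = (post.log_prob(z) - prior.log_prob(z)).sum(dim=-1).mean(dim=0)
print(kl, mc)  # the two should agree closely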
Example #10
 def forward(self, dists, other):
     results, out = other
     return Normal(self.block(results[-1]), self.logvar.exp())
Example #11
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('run_number',
                        default=999,
                        help="Consecutive number of this run")
    parser.add_argument('gym_id',
                        default='MountainCarContinuous-v0',
                        help="Id of the Gym environment")
    parser.add_argument('-e',
                        '--episodes',
                        type=int,
                        default=1000,
                        help="Number of episodes")
    parser.add_argument('-t',
                        '--timesteps',
                        type=int,
                        default=None,
                        help="Number of timesteps")
    parser.add_argument('-m',
                        '--max-episode-timesteps',
                        type=int,
                        default=None,
                        help="Maximum number of timesteps per episode")
    parser.add_argument('-ns', '--network-size', type=int, default=12)
    parser.add_argument('-cs', '--curious-size', type=int, default=32)
    parser.add_argument('-rs', '--random-seeds', type=int, default=100)
    parser.add_argument('-bt', '--batch-size', type=int, default=1028)
    parser.add_argument('-mc', '--memory-capacity', type=int, default=10000)
    parser.add_argument('-os', '--optimization-steps', type=int, default=10)
    parser.add_argument('-bs', '--baseline-steps', type=int, default=1)
    parser.add_argument('-sf', '--subsampling-fraction', type=int, default=256)
    parser.add_argument('-lr', '--likelihood-ratio', type=float, default=0.1)
    parser.add_argument('-sd',
                        '--seed',
                        type=int,
                        default=None,
                        help='Random seed for this trial')
    parser.add_argument('-gr', '--gamma-reward', type=float, default=0.99)
    parser.add_argument('-gc', '--gamma-curious', type=float, default=0.99)

    args = parser.parse_args()

    # SET BASIC PARAMETERS

    sys.path.append(
        os.path.abspath("C:\\Users\\genia\\Source\\Repos\\Box2dEnv\\Box2dEnv"))

    env = gym.make('MountainCarContinuous-v0')

    load = False

    N_STATES = 2
    N_CURIOUS_STATES = args.random_seeds
    N_ACTIONS = 1
    np.set_printoptions(precision=3,
                        linewidth=200,
                        floatmode='fixed',
                        suppress=True)
    torch.set_printoptions(precision=3)

    # Initialise network and hyper params
    device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
    # torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.set_default_tensor_type('torch.FloatTensor')
    N_HIDDEN = args.network_size
    N_HIDDEN_RND = args.curious_size
    N_CHANNELS_RND = args.curious_size

    random_seed = args.seed
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    ac_net_critic = Net_Critic.Net(N_STATES, N_HIDDEN)
    ac_net_actor = Net_Actor.Net(N_STATES, N_ACTIONS, N_HIDDEN)
    ac_net_c_critic = Net_Critic.Net(N_STATES, N_HIDDEN)
    ac_net_rnd = Net_rnd.Net(N_CURIOUS_STATES, N_CHANNELS_RND, N_HIDDEN_RND)
    ac_net_pred = Net_rnd.Net(N_CURIOUS_STATES, N_CHANNELS_RND, N_HIDDEN_RND)

    criterion_val = nn.SmoothL1Loss()
    # optimizer_c = torch.optim.Adam(ac_net_critic.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.00, amsgrad=False)
    # optimizer_cc = torch.optim.Adam(ac_net_c_critic.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.00, amsgrad=False)
    # optimizer_a = torch.optim.Adam(ac_net_actor.parameters(), lr=0.00001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.00, amsgrad=False)

    optimizer_c = torch.optim.SGD(ac_net_critic.parameters(),
                                  lr=0.001,
                                  momentum=0.9,
                                  nesterov=True)
    optimizer_cc = torch.optim.SGD(ac_net_c_critic.parameters(),
                                   lr=0.001,
                                   momentum=0.9,
                                   nesterov=True)
    optimizer_a = torch.optim.SGD(ac_net_actor.parameters(),
                                  lr=0.001,
                                  momentum=0.9,
                                  nesterov=True)
    optimizer_rnd = torch.optim.SGD(ac_net_pred.parameters(),
                                    lr=0.0005,
                                    momentum=0.0,
                                    nesterov=False)

    gamma1 = args.gamma_reward
    gamma2 = args.gamma_curious

    return_time = 1

    N_STEPS = args.memory_capacity
    # N_STEPS = 500
    N_TRAJECTORIES = 12
    K_epochs = args.optimization_steps
    B_epochs = args.baseline_steps
    R_epochs = 1
    N_MINI_BATCH = args.batch_size
    epsilon = args.likelihood_ratio
    N_CURIOUS_BATCH = args.subsampling_fraction

    avg_reward = deque(maxlen=50)
    avg_curious_reward = deque(maxlen=50)
    avg_max_height_q = deque(maxlen=50)
    avg_STD = deque()
    avg_critic_loss = deque()
    avg_reward_STD = deque()
    avg_value_STD = deque()

    p1 = np.random.normal(0, 5, (N_CURIOUS_STATES, 1))
    p2 = np.random.normal(0, 15, (N_CURIOUS_STATES, 1))

    def get_curious_state(curious_state, p1i, p2i):
        curious_state_t_new = np.zeros((len(curious_state), N_CURIOUS_STATES))
        curious_state_1 = curious_state[:, 0]
        curious_state_2 = curious_state[:, 1] / 0.07
        for x, p1x, p2x, p1y, p2y in zip(range(N_CURIOUS_STATES), p1i, p2i,
                                         reversed(p1i), reversed(p2i)):
            curious_state_t_new[:, x] = np.squeeze(
                p1x * np.cos(p2x * (-curious_state_1)) +
                p1y * np.sin(p2y * (-curious_state_1)))
            curious_state_t_new[:, x] += np.squeeze(
                p1x * np.cos(p2x * (-curious_state_2)) +
                p1y * np.sin(p2y * (-curious_state_2)))
        return torch.tensor(curious_state_t_new).float()

    # START TRAINING
    episodes = args.episodes
    env.env.unwrapped.seed(random_seed)
    first_batch = True
    episode_i = 0
    total_i = 0
    curious_reward_std = 0.2
    while episode_i < episodes:  # START MAIN LOOP
        cur_state_q = []
        next_state_q = []
        reward_q = []
        action_log_prob_q = []
        value_q = []
        advantage_q_new = []
        done_q = []
        action_q = []
        avg_reward_batch = []
        avg_curious_reward_batch = []
        curious_reward_q = []
        avg_max_height = []
        i_in_batch = 0
        completed_q = []

        BATCH_REWARD = []
        BATCH_CURIOUS_REWARD = []
        BATCH_MAX_HEIGHT = []

        while i_in_batch < N_STEPS:  # START EPISODE BATCH LOOP
            cur_state = env.reset()
            cur_state_copy = cur_state.copy()
            cur_state_copy[1] = cur_state_copy[1] / 0.035

            done = False
            ret = 0
            curious_ret = 0
            i_in_episode = 0
            episode_distance_q = []
            next_cur_state_episode_q = []

            while not done:  # RUN SINGLE EPISODE
                # Get parameters for distribution and assign action
                # cur_state[1] = cur_state[1]/0.0035
                torch_state = torch.tensor(cur_state).unsqueeze(0).float()
                with torch.no_grad():
                    mu, sd = ac_net_actor(torch_state)
                    # val_out = ac_net_critic(torch_state)
                    # curious_out = ac_net_c_critic(torch_state)
                distribution = Normal(mu[0], sd[0])
                action = distribution.sample()
                clamped_action = torch.clamp(action, -1, 1).data.numpy()

                episode_distance_q.append(cur_state[0])
                # Step environment
                next_state, reward, done, info = env.step(clamped_action)
                # Append values to queues
                cur_state_q.append(cur_state_copy)

                next_state_copy = next_state.copy()
                next_state_copy[1] = next_state_copy[1] / 0.035
                next_cur_state_episode_q.append(next_state_copy)
                next_state_q.append(next_state_copy)

                reward_i = reward / 20.0
                reward_q.append(float(reward_i))
                # value_q.append(val_out)
                action_q.append(action.data.numpy())
                action_log_prob_q.append(
                    distribution.log_prob(
                        torch.tensor(clamped_action)).data.numpy())
                done_q.append(1 - done)  # store 1 - done: 0 marks a terminal step

                ret += reward  # Sum total reward for episode

                # Iterate counters, etc
                cur_state = next_state
                cur_state_copy = next_state_copy
                i_in_episode += 1
                i_in_batch += 1
                total_i += 1
                if i_in_episode % 1 == 0 and episode_i % 500 == 0 and episode_i > 0:
                    env.render()
                # if i_in_episode > 500:
                #     done = True
                if done:
                    break

            # END SINGLE EPISODE

            if ret > 0.01:
                completed_q += np.ones((len(episode_distance_q), 1)).tolist()
            else:
                completed_q += np.zeros((len(episode_distance_q), 1)).tolist()

            next_state_episode = np.asarray(next_cur_state_episode_q)
            # next_curious_state = get_curious_state(next_state_episode, p1, p2)
            #
            # with torch.no_grad():
            #     rnd_val = ac_net_rnd(next_curious_state.unsqueeze(1))
            #     pred_val = ac_net_pred(next_curious_state.unsqueeze(1))
            #
            # curious_reward_episode = torch.pow((rnd_val - pred_val), 2)
            # curious_rewards_episode = (curious_reward_episode.data.numpy())

            curious_rewards_episode = completed_q
            curious_reward_q += curious_rewards_episode  # .tolist()
            curious_ret = np.sum(curious_rewards_episode)
            avg_curious_ret = curious_ret / i_in_episode

            episode_i += 1
            avg_reward.append(ret)
            avg_curious_reward.append(curious_ret)
            avg_reward_batch.append(ret)
            avg_curious_reward_batch.append(curious_ret)
            avg_max_height_q.append(np.max(episode_distance_q))
            avg_max_height.append(np.max(episode_distance_q))
            print("%4d, %6.2f, %6.0f | " %
                  (episode_i, np.max(episode_distance_q), curious_ret))

            BATCH_CURIOUS_REWARD.append(curious_ret)
            BATCH_MAX_HEIGHT.append(np.max(episode_distance_q))
            BATCH_REWARD.append(ret)

        # print("")
        # END EPISODE BATCH LOOP

        max_achieved_height_in_batch = np.max(avg_max_height)

        # NORMALIZE CURIOUS REWARD
        if first_batch:
            curious_reward_std = np.std(np.asarray(curious_reward_q))
            first_batch = False

        # START CUMULATIVE REWARD CALC
        curious_reward_q = np.asarray(curious_reward_q) / curious_reward_std
        discounted_reward = []
        discounted_curious_reward = []
        cul_reward = 0
        cul_curious_reward = 0
        for reward, cur_reward, done in zip(reversed(reward_q),
                                            reversed(curious_reward_q),
                                            reversed(done_q)):
            if done == 1:
                cul_curious_reward = cul_curious_reward * gamma2 + cur_reward
                cul_reward = cul_reward * gamma1 + reward
                discounted_reward.insert(0, cul_reward)
                discounted_curious_reward.insert(0, cul_curious_reward)
            elif done == 0:  # terminal step: restart the cumulative returns
                cul_reward = reward
                cul_curious_reward = cur_reward
                discounted_reward.insert(0, cul_reward)
                discounted_curious_reward.insert(0, cul_curious_reward)

        # CALCULATE ADVANTAGE (these per-element loops could be vectorized)
        current_state_t = torch.tensor(cur_state_q).float()
        curious_advantage_q_new = []
        advantage_q_new = []
        with torch.no_grad():
            value_t_new = ac_net_critic(current_state_t)
            curious_value_t_new = ac_net_c_critic(current_state_t)

        for reward_i, value_i in zip(np.asarray(discounted_reward),
                                     value_t_new.data.numpy()):
            advantage_q_new.append(reward_i - value_i)
        advantage_q_new = np.asarray(advantage_q_new)
        for reward_i, value_i in zip(np.asarray(discounted_curious_reward),
                                     curious_value_t_new.data.numpy()):
            curious_advantage_q_new.append(reward_i - value_i)
        curious_advantage_q_new = np.asarray(curious_advantage_q_new)

        advantage_q_new = (advantage_q_new - np.mean(advantage_q_new)) / (
            np.std(advantage_q_new)
        )  # Should advantage be recalculated at each optimize step?
        curious_advantage_q_new = (
            curious_advantage_q_new - np.mean(curious_advantage_q_new)) / (
                np.std(curious_advantage_q_new)
            )  # Should advantage be recalculated at each optimize step?
        # curious_advantage_q_new = (np.asarray(discounted_curious_reward) -np.mean(discounted_curious_reward))/(np.std(discounted_curious_reward))  # Should advantage be recalculated at each optimize step?

        max_curious_advantage = np.max(curious_advantage_q_new)
        std_curious_advantage = np.std(curious_advantage_q_new)
        mean_curious_advantage = np.mean(curious_advantage_q_new)

        max_advantage = np.max(advantage_q_new)
        std_advantage = np.std(advantage_q_new)
        mean_advantage = np.mean(advantage_q_new)

        advantage_t = torch.tensor(advantage_q_new).float()
        curious_advantage_t = torch.tensor(curious_advantage_q_new).float()
        completed_t = torch.tensor(np.asarray(completed_q)).float()
        # advantage_t = completed_t * advantage_t
        a_prop = 0.5
        summed_advantage_t = torch.add(torch.mul(advantage_t, 1),
                                       torch.mul(curious_advantage_t, 1))

        # START UPDATING NETWORKS

        batch_length = len(cur_state_q)

        action_log_prob_t = torch.tensor(action_log_prob_q).float()
        action_t = torch.tensor(action_q).float()
        reward_t = torch.tensor(discounted_reward).float()
        curious_reward_t = torch.tensor(discounted_curious_reward).float()
        summed_reward_t = torch.add(curious_reward_t, reward_t)

        # START BASELINE OPTIMIZE
        avg_baseline_loss = []
        for epoch in range(B_epochs):
            # Get random permutation of indexes
            indexes = torch.tensor(np.random.permutation(batch_length)).type(
                torch.LongTensor)
            n_batch = 0
            batch_start = 0
            batch_end = 0
            # Loop over permutation
            avg_baseline_batch_loss = []
            avg_baseline_curious_batch_loss = []
            while batch_end < batch_length:
                # Get batch indexes
                batch_end = batch_start + N_MINI_BATCH
                if batch_end > batch_length:
                    batch_end = batch_length

                batch_idx = indexes[batch_start:batch_end]

                # Gather data from saved tensors

                batch_state_t = torch.index_select(current_state_t, 0,
                                                   batch_idx).float()
                batch_reward_t = torch.index_select(reward_t, 0, batch_idx)
                batch_curious_reward_t = torch.index_select(
                    curious_reward_t, 0, batch_idx)
                batch_summed_reward_t = torch.index_select(
                    summed_reward_t, 0, batch_idx)
                batch_start = batch_end

                n_batch += 1

                # Get new baseline values
                new_val = ac_net_critic(batch_state_t)
                new_curious_val = ac_net_c_critic(batch_state_t)
                # Calculate loss compared with reward and optimize
                # NEEDS TO BE OPTIMIZED WITH CURIOUS VAL AS WELL
                # new_summed_val = new_val + new_curious_val
                critic_loss_batch = criterion_val(new_val,
                                                  batch_reward_t.unsqueeze(1))
                critic_curious_loss_batch = criterion_val(
                    new_curious_val, batch_curious_reward_t.unsqueeze(1))
                # critic_loss_batch = criterion_val(new_summed_val, batch_summed_reward_t.unsqueeze(1))
                # critic_loss_both = critic_curious_loss_batch  # + critic_loss_batch
                optimizer_c.zero_grad()
                optimizer_cc.zero_grad()

                critic_loss_batch.backward()
                critic_curious_loss_batch.backward()
                optimizer_cc.step()
                optimizer_c.step()

                # avg_value_STD.append(critic_loss_batch.item())
                avg_baseline_batch_loss.append(critic_loss_batch.item())
                avg_baseline_curious_batch_loss.append(
                    critic_curious_loss_batch.item())
            # print(np.mean(avg_baseline_batch_loss), np.mean(avg_baseline_curious_batch_loss), " ", end="")
            # avg_baseline_loss.append(np.mean(avg_baseline_batch_loss))

        # print("")
        # END BASELINE OPTIMIZE

        # START POLICY OPTIMIZE
        for epoch in range(K_epochs):
            # Get random permutation of indexes
            indexes = torch.tensor(np.random.permutation(batch_length)).type(
                torch.LongTensor)
            n_batch = 0
            batch_start = 0
            batch_end = 0
            # Loop over permutation
            while batch_end < batch_length:
                # Get batch indexes
                batch_end = batch_start + N_MINI_BATCH
                if batch_end > batch_length:
                    batch_end = batch_length

                batch_idx = indexes[batch_start:batch_end]

                # Gather data from saved tensors
                batch_state_t = torch.index_select(current_state_t, 0,
                                                   batch_idx).float()
                batch_advantage_t = torch.index_select(
                    advantage_t, 0, batch_idx)
                # batch_advantage_t = torch.index_select(curious_advantage_t, 0, batch_idx)

                # batch_advantage_t = torch.index_select(summed_advantage_t, 0, batch_idx)

                batch_action_log_prob_t = torch.index_select(
                    action_log_prob_t, 0, batch_idx)
                batch_action_t = torch.index_select(action_t, 0, batch_idx)
                # batch_reward_t = torch.index_select(reward_t, 0, batch_idx)

                batch_start = batch_end
                n_batch += 1

                # Get new batch of parameters and action log probs
                mu_batch, sd_batch = ac_net_actor(batch_state_t)
                batch_distribution = Normal(mu_batch, sd_batch)
                exp_probs = batch_distribution.log_prob(batch_action_t).exp()
                old_exp_probs = batch_action_log_prob_t.exp()
                r_theta_i = torch.div(exp_probs, old_exp_probs)

                # Advantage needs to include curious advantage. Should advantage be recalculated each epoch?
                batch_advantage_t4 = batch_advantage_t.expand_as(r_theta_i)

                surrogate1 = r_theta_i * batch_advantage_t4
                surrogate2 = torch.clamp(r_theta_i, 1 - epsilon,
                                         1 + epsilon) * batch_advantage_t4

                r_theta_surrogate_min = torch.min(surrogate1, surrogate2)
                L_clip = -torch.sum(
                    r_theta_surrogate_min) / r_theta_surrogate_min.size()[0]
                optimizer_a.zero_grad()
                L_clip.backward()
                optimizer_a.step()

        # END OPTIMIZE POLICY

        # START OPTIMIZE CURIOUS

        # curious_state_t = get_curious_state(np.asarray(cur_state_q), p1, p2)
        # avg_curious_loss = []
        # curious_batch_length = N_CURIOUS_BATCH
        # for epoch in range(R_epochs):
        #     # Get random permutation of indexes
        #     indexes = torch.tensor(np.random.permutation(batch_length)).type(torch.LongTensor)
        #     n_batch = 0
        #     batch_start = 0
        #     batch_end = 0
        #     # Loop over permutation
        #     # avg_curious_loss = []
        #     while batch_end < curious_batch_length:
        #         # Get batch indexes
        #         batch_end = batch_start + N_CURIOUS_BATCH
        #         if batch_end > curious_batch_length:
        #             batch_end = curious_batch_length
        #
        #         batch_idx = indexes[batch_start:batch_end]
        #
        #         # Gather data from saved tensors
        #         batch_state_t = torch.index_select(curious_state_t, 0, batch_idx).float()
        #         batch_state_t = batch_state_t.unsqueeze(1)
        #         # batch_reward_t = torch.index_select(reward_t, 0, batch_idx)
        #         # batch_summed_reward_t = torch.index_select(summed_reward_t, 0, batch_idx)
        #         batch_start = batch_end
        #         n_batch += 1
        #
        #         with torch.no_grad():
        #             rnd_val = ac_net_rnd(batch_state_t)
        #         pred_val = ac_net_pred(batch_state_t)
        #         # Calculate loss compared with reward and optimize
        #         optimizer_rnd.zero_grad()
        #         pred_loss_batch_curious = criterion_val(pred_val, rnd_val)
        #         pred_loss_batch_curious.backward()
        #         # nn.utils.clip_grad_norm(ac_net_pred.parameters(), 1)
        #         # nn.utils.clip_grad_value_(ac_net_pred.parameters(), 100)
        #         # clip_min_grad_value_(ac_net_pred.parameters(), 0.2)
        #
        #         optimizer_rnd.step()
        #         avg_curious_loss.append(pred_loss_batch_curious.item())

        # print((pred_loss_batch_curious.data.numpy()), " ", end="")
        # print("")
        # print(epoch)
        # print("")

        # Naming variables
        run_number = args.run_number
        nNum = str(run_number).zfill(3)
        nSeed = str(random_seed).zfill(2)

        nName = ("{}-{}-RT".format(nNum, nSeed))

        if episode_i % return_time == 0:
            print(
                "%4d | %6.0d | %6.1f, %6.1f | %6.1f, %6.1f | %6.2f, %6.2f, %6.2f | %6.2f, %6.2f, %6.2f | %6.2f, %6.2f"
                % (episode_i, total_i, np.mean(avg_reward_batch),
                   np.mean(avg_reward), np.mean(avg_curious_reward_batch),
                   np.mean(avg_curious_reward), max_advantage, mean_advantage,
                   std_advantage, max_curious_advantage,
                   mean_curious_advantage, std_curious_advantage,
                   max_achieved_height_in_batch, np.mean(avg_max_height_q)))
            with open(
                    'C:\\Users\\genia\\source\\repos\\Box2dEnv\\Box2dEnv\\saves\\{}.csv'
                    .format(nName), 'a+') as csv:
                for retw, curious_retw, max_heightw in zip(
                        BATCH_REWARD, BATCH_CURIOUS_REWARD, BATCH_MAX_HEIGHT):
                    csv.write("{:2.2f},{:2.2f},{:2.2f}\n".format(
                        retw, curious_retw, max_heightw))
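The PPO update buried in the script above boils down to the clipped surrogate objective. A self-contained sketch with random stand-ins for the saved batch tensors:

import torch
from torch.distributions import Normal

epsilon = 0.1
mu, sd = torch.zeros(32, 1), torch.ones(32, 1)
actions = torch.randn(32, 1)
old_log_prob = Normal(mu, sd).log_prob(actions)  # frozen snapshot
advantage = torch.randn(32, 1)

new_dist = Normal(mu + 0.01, sd)                 # 'updated' policy
r_theta = (new_dist.log_prob(actions) - old_log_prob).exp()
surrogate1 = r_theta * advantage
surrogate2 = torch.clamp(r_theta, 1 - epsilon, 1 + epsilon) * advantage
L_clip = -torch.min(surrogate1, surrogate2).mean()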
Example #12
    def sample_action(self):
        normal = Normal(0, 1)
        random_action = self.action_range * normal.sample((self._action_dim, ))

        return random_action.cpu().numpy()
Example #13
 def ent(self, params):
     mean = params['mean']
     log_std = params['log_std']
     std = torch.exp(log_std)
     return torch.sum(Normal(loc=mean, scale=std).entropy(), dim=-1)
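For a diagonal Gaussian the entropy has the closed form 0.5 * log(2*pi*e) + log(sigma) per dimension, summed over dimensions. A quick check against `Normal.entropy()`:

import math
import torch
from torch.distributions import Normal

log_std = torch.tensor([0.0, 0.5, -0.5])
ent = Normal(loc=torch.zeros(3), scale=log_std.exp()).entropy().sum(dim=-1)
closed_form = (0.5 * math.log(2 * math.pi * math.e) + log_std).sum()
print(ent, closed_form)  # identical up to float error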
Example #14
 def llh(self, x, params):
     mean, log_std = params['mean'], params['log_std']
     std = torch.exp(log_std)
     return torch.sum(Normal(loc=mean, scale=std).log_prob(x), dim=-1)
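Likewise, the summed `log_prob` matches the Gaussian log-density written out by hand; for a standard normal it is -0.5*log(2*pi) - x**2/2 per coordinate:

import math
import torch
from torch.distributions import Normal

x = torch.tensor([0.3, -1.2])
llh = Normal(torch.zeros(2), torch.ones(2)).log_prob(x).sum(dim=-1)
by_hand = sum(-0.5 * math.log(2 * math.pi) - 0.5 * v ** 2 for v in x.tolist())
print(llh.item(), by_hand)  # should agree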
Example #15
 def _sample(self, mean, var, num_sample=100):
     dist = Normal(mean, var)
     samples = dist.sample([num_sample])
     return samples
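One caveat when reusing this helper: `Normal(loc, scale)` expects a standard deviation, so if `var` really holds a variance it should be square-rooted first, as the other examples on this page do. The difference is easy to see empirically:

import torch
from torch.distributions import Normal

var = torch.full((1000,), 4.0)
as_scale = Normal(torch.zeros(1000), var).sample()            # std = 4
as_variance = Normal(torch.zeros(1000), var.sqrt()).sample()  # std = 2
print(as_scale.std().item(), as_variance.std().item())  # ~4.0 vs ~2.0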
Example #16
    def __init__(self, loc, scale, transforms, validate_args=None):

        super(Glow, self).__init__(Normal(loc, scale),
                                   transforms,
                                   validate_args=validate_args)
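This constructor pattern suggests `Glow` extends `TransformedDistribution`: a `Normal` base plus a chain of transforms. PyTorch's built-in log-normal works the same way; a minimal analogue:

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

base = Normal(torch.zeros(3), torch.ones(3))
d = TransformedDistribution(base, [ExpTransform()])  # a log-normal
x = d.rsample()
print(d.log_prob(x).shape)  # torch.Size([3])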
Example #17
    def inference(
        self,
        x,
        batch_index,
        cont_covs,
        cat_covs,
        n_samples=1,
    ) -> Dict[str, torch.Tensor]:

        # Get Data and Additional Covs
        x_rna = x[:, :self.n_input_genes]
        x_chr = x[:, self.n_input_genes:]

        mask_expr = x_rna.sum(dim=1) > 0
        mask_acc = x_chr.sum(dim=1) > 0

        if cont_covs is not None and self.encode_covariates:
            encoder_input_expression = torch.cat((x_rna, cont_covs), dim=-1)
            encoder_input_accessibility = torch.cat((x_chr, cont_covs), dim=-1)
        else:
            encoder_input_expression = x_rna
            encoder_input_accessibility = x_chr

        if cat_covs is not None and self.encode_covariates:
            categorical_input = torch.split(cat_covs, 1, dim=1)
        else:
            categorical_input = tuple()

        # Z Encoders
        qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility(
            encoder_input_accessibility, batch_index, *categorical_input)
        qzm_expr, qzv_expr, z_expr = self.z_encoder_expression(
            encoder_input_expression, batch_index, *categorical_input)

        # L encoders
        libsize_expr = self.l_encoder_expression(encoder_input_expression,
                                                 batch_index,
                                                 *categorical_input)
        libsize_acc = self.l_encoder_accessibility(encoder_input_accessibility,
                                                   batch_index,
                                                   *categorical_input)

        # ReFormat Outputs
        if n_samples > 1:
            qzm_acc = qzm_acc.unsqueeze(0).expand(
                (n_samples, qzm_acc.size(0), qzm_acc.size(1)))
            qzv_acc = qzv_acc.unsqueeze(0).expand(
                (n_samples, qzv_acc.size(0), qzv_acc.size(1)))
            untran_za = Normal(qzm_acc, qzv_acc.sqrt()).sample()
            z_acc = self.z_encoder_accessibility.z_transformation(untran_za)

            qzm_expr = qzm_expr.unsqueeze(0).expand(
                (n_samples, qzm_expr.size(0), qzm_expr.size(1)))
            qzv_expr = qzv_expr.unsqueeze(0).expand(
                (n_samples, qzv_expr.size(0), qzv_expr.size(1)))
            untran_zr = Normal(qzm_expr, qzv_expr.sqrt()).sample()
            z_expr = self.z_encoder_expression.z_transformation(untran_zr)

            libsize_expr = libsize_expr.unsqueeze(0).expand(
                (n_samples, libsize_expr.size(0), libsize_expr.size(1)))
            libsize_acc = libsize_acc.unsqueeze(0).expand(
                (n_samples, libsize_acc.size(0), libsize_acc.size(1)))

        ## Sample from the average distribution
        qzp_m = (qzm_acc + qzm_expr) / 2
        qzp_v = (qzv_acc + qzv_expr) / (2**0.5)
        zp = Normal(qzp_m, qzp_v.sqrt()).rsample()

        ## choose the correct latent representation based on the modality
        qz_m = self._mix_modalities(qzp_m, qzm_expr, qzm_acc, mask_expr,
                                    mask_acc)
        qz_v = self._mix_modalities(qzp_v, qzv_expr, qzv_acc, mask_expr,
                                    mask_acc)
        z = self._mix_modalities(zp, z_expr, z_acc, mask_expr, mask_acc)

        outputs = dict(
            z=z,
            qz_m=qz_m,
            qz_v=qz_v,
            z_expr=z_expr,
            qzm_expr=qzm_expr,
            qzv_expr=qzv_expr,
            z_acc=z_acc,
            qzm_acc=qzm_acc,
            qzv_acc=qzv_acc,
            libsize_expr=libsize_expr,
            libsize_acc=libsize_acc,
        )
        return outputs
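A plausible reading of `_mix_modalities` (its exact semantics are an assumption here): cells with both modalities take the averaged posterior, single-modality cells keep their own encoder's output. A sketch of that selection with `torch.where`:

import torch

def mix_modalities(paired, expr, acc, mask_expr, mask_acc):
    # assumed semantics: paired cells -> averaged posterior,
    # otherwise fall back to the available modality
    out = torch.where(mask_expr.unsqueeze(-1), expr, acc)
    both = (mask_expr & mask_acc).unsqueeze(-1)
    return torch.where(both, paired, out)

zp, ze, za = torch.zeros(4, 2), torch.ones(4, 2), 2 * torch.ones(4, 2)
m_e = torch.tensor([True, True, False, True])
m_a = torch.tensor([True, False, True, True])
print(mix_modalities(zp, ze, za, m_e, m_a))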
Example #18
def main():
    # Dataset.
    ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_ = make_data()

    # Plotting parameters.
    vis_batch_size = 1024
    ylims = (-1.75, 1.75)
    alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
    percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    vis_idx = npr.permutation(vis_batch_size)
    # From https://colorbrewer2.org/.
    if args.color == "blue":
        sample_colors = ('#8c96c6', '#8c6bb1', '#810f7c')
        fill_color = '#9ebcda'
        mean_color = '#4d004b'
        num_samples = len(sample_colors)
    else:
        sample_colors = ('#fc4e2a', '#e31a1c', '#bd0026')
        fill_color = '#fd8d3c'
        mean_color = '#800026'
        num_samples = len(sample_colors)

    # Fix seed for the random draws used in the plots.
    eps = torch.randn(vis_batch_size, 1).to(device)
    bm = BrownianPath(t0=ts_vis[0],
                      w0=torch.zeros(vis_batch_size, 1).to(device))

    # Model.
    model = LatentSDE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=.999)
    kl_scheduler = utils.LinearScheduler(iters=args.kl_anneal_iters)

    logp_metric = utils.EMAMetric()
    log_ratio_metric = utils.EMAMetric()
    loss_metric = utils.EMAMetric()

    if args.show_prior:
        with torch.no_grad():
            zs = model.sample_p(ts=ts_vis,
                                batch_size=vis_batch_size,
                                eps=eps,
                                bm=bm).squeeze()
            ts_vis_, zs_ = ts_vis.cpu().numpy(), zs.cpu().numpy()
            zs_ = np.sort(zs_, axis=1)

            img_dir = os.path.join(args.train_dir, 'prior.png')
            plt.subplot(frameon=False)
            for alpha, percentile in zip(alphas, percentiles):
                idx = int((1 - percentile) / 2. * vis_batch_size)
                zs_bot_ = zs_[:, idx]
                zs_top_ = zs_[:, -idx]
                plt.fill_between(ts_vis_,
                                 zs_bot_,
                                 zs_top_,
                                 alpha=alpha,
                                 color=fill_color)

            # `zorder` determines what's drawn on top; the larger the value, the higher the layer.
            plt.scatter(ts_, ys_, marker='x', zorder=3, color='k',
                        s=35)  # Data.
            plt.ylim(ylims)
            plt.xlabel('$t$')
            plt.ylabel('$Y_t$')
            plt.tight_layout()
            plt.savefig(img_dir, dpi=args.dpi)
            plt.close()
            logging.info(f'Saved prior figure at: {img_dir}')

    for global_step in tqdm.tqdm(range(args.train_iters)):
        # Plot and save.
        if global_step % args.pause_iters == 0:
            img_path = os.path.join(args.train_dir,
                                    f'global_step_{global_step}.png')

            with torch.no_grad():
                zs = model.sample_q(ts=ts_vis,
                                    batch_size=vis_batch_size,
                                    eps=eps,
                                    bm=bm).squeeze()
                samples = zs[:, vis_idx]
                ts_vis_ = ts_vis.cpu().numpy()
                zs_ = zs.cpu().numpy()
                samples_ = samples.cpu().numpy()
                zs_ = np.sort(zs_, axis=1)
                plt.subplot(frameon=False)

                if args.show_percentiles:
                    for alpha, percentile in zip(alphas, percentiles):
                        idx = int((1 - percentile) / 2. * vis_batch_size)
                        zs_bot_, zs_top_ = zs_[:, idx], zs_[:, -idx]
                        plt.fill_between(ts_vis_,
                                         zs_bot_,
                                         zs_top_,
                                         alpha=alpha,
                                         color=fill_color)

                if args.show_mean:
                    plt.plot(ts_vis_, zs_.mean(axis=1), color=mean_color)

                if args.show_samples:
                    for j in range(num_samples):
                        plt.plot(ts_vis_,
                                 samples_[:, j],
                                 color=sample_colors[j],
                                 linewidth=1.0)

                if args.show_arrows:
                    num, dt = 12, 0.12
                    t, y = torch.meshgrid([
                        torch.linspace(0.2, 1.8, num).to(device),
                        torch.linspace(-1.5, 1.5, num).to(device)
                    ])
                    t, y = t.reshape(-1, 1), y.reshape(-1, 1)
                    fty = model.f(t=t, y=y).reshape(num, num)
                    dt = torch.zeros(num, num).fill_(dt).to(device)
                    dy = fty * dt
                    dt_ = dt.cpu().numpy()
                    dy_ = dy.cpu().numpy()
                    t_ = t.cpu().numpy()
                    y_ = y.cpu().numpy()
                    plt.quiver(t_,
                               y_,
                               dt_,
                               dy_,
                               alpha=0.3,
                               edgecolors='k',
                               width=0.0035,
                               scale=50)

                if args.hide_ticks:
                    plt.xticks([], [])
                    plt.yticks([], [])

                plt.scatter(ts_, ys_, marker='x', zorder=3, color='k',
                            s=35)  # Data.
                plt.ylim(ylims)
                plt.xlabel('$t$')
                plt.ylabel('$Y_t$')
                plt.tight_layout()
                plt.savefig(img_path, dpi=args.dpi)
                plt.close()
                logging.info(f'Saved figure at: {img_path}')

                if args.save_ckpt:
                    torch.save({'model': model.state_dict()},
                               os.path.join(ckpt_dir,
                                            f'global_step_{global_step}.ckpt'))

        # Train.
        optimizer.zero_grad()
        zs, log_ratio = model(ts=ts_ext, batch_size=args.batch_size)
        zs = zs.squeeze()
        # Drop first and last, which are only used to penalize the out-of-data region and spread uncertainty.
        zs = zs[1:-1]

        likelihood = {
            "laplace": Laplace(loc=zs, scale=args.scale),
            "normal": Normal(loc=zs, scale=args.scale)
        }[args.likelihood]
        logp = likelihood.log_prob(ys).sum(dim=0).mean(dim=0)

        loss = -logp + log_ratio * kl_scheduler()
        loss.backward()
        optimizer.step()
        scheduler.step()
        kl_scheduler.step()

        logp_metric.step(logp)
        log_ratio_metric.step(log_ratio)
        loss_metric.step(loss)

        logging.info(
            f'global_step: {global_step}, '
            f'logp: {logp_metric.val():.3f}, log_ratio: {log_ratio_metric.val():.3f}, loss: {loss_metric.val():.3f}'
        )
Example #19
 def __init__(self, loc, scale):
     self._normal = Independent(Normal(loc, scale), 1)
     super().__init__()
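`Independent(Normal(loc, scale), 1)` reinterprets the last batch dimension as an event dimension, so `log_prob` returns one joint value per sample instead of one per coordinate:

import torch
from torch.distributions import Independent, Normal

loc, scale = torch.zeros(4, 3), torch.ones(4, 3)
d = Independent(Normal(loc, scale), 1)
x = d.sample()
print(Normal(loc, scale).log_prob(x).shape)  # torch.Size([4, 3])
print(d.log_prob(x).shape)                   # torch.Size([4])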
Example #20
 def __init__(self,
              data,
              G,
              F_att,
              F_lat,
              D,
              D_dis,
              params,
              to_gpu,
              logger=None):
     self.data = data
     self.G = G
     self.F_att = F_att
     self.F_lat = F_lat
     self.D = D
     self.D_dis = D_dis
     self.mods = (G, F_att, F_lat, D, D_dis)
     self.mod_names = ('G', 'F_att', 'F_lat', 'D', 'D_dis')
     self.params = params
     self.to_gpu = to_gpu
     if params['pz']:  # True = Uniform
         self.z_rand = Uniform(to_gpu(torch.tensor(0.0)),
                               to_gpu(torch.tensor(1.0)))
     else:  # False = Normal
         self.z_rand = Normal(to_gpu(torch.tensor(0.0)),
                              to_gpu(torch.tensor(1.0)))
     if params['divergence'] == 'JS' or params['divergence'] == 'standard':
         loss = nn.BCEWithLogitsLoss()
         self.criterion = lambda dec, label: -loss(dec, label)
     elif params['divergence'] == 'Wasserstein':
         self.criterion = lambda dec, label: torch.mean(
             dec * (2. * label - 1.)
         )  #loss(dec, label) #torch.sum(dec)  #torch.sum(dec*(2.*label-1.))
     if params['att_type'] == 0:
         ycounts = np.stack(data.ycounts, axis=0)
         wrel = self.to_gpu(
             Variable(
                 torch.from_numpy(ycounts[:, 0] / ycounts[:, 1]).float()))
         # wneg = self.to_gpu(Variable(torch.from_numpy(1./ycounts[:,0]).float()))
         # wpos = self.to_gpu(Variable(torch.from_numpy(1./ycounts[:,1]).float()))
         # self.att_loss = lambda pred, true: torch.mean(-(true*torch.log(pred+1e-10) + (1-true)*torch.log(1-pred+1e-10)))
         # self.att_loss = lambda pred, true: torch.mean(-(wpos*true*torch.log(pred+1e-10) + wneg*(1-true)*torch.log(1-pred+1e-10)))
         # print('ycounts0')
         # print(ycounts[:,0])
         # print('ycounts1')
         # print(ycounts[:,1])
         # print('wrel')
         # print(wrel.cpu().data.numpy())
         self.att_loss = lambda pred, true: torch.mean(
             -(wrel * true * torch.log(pred + 1e-10) +
               (1 - true) * torch.log(1 - pred + 1e-10)))
     elif params['att_type'] == 1:
         m = self.to_gpu(
             Variable(torch.from_numpy(data.powerfits[:, 0]).float()))
         b = self.to_gpu(
             Variable(torch.from_numpy(data.powerfits[:, 1]).float()))
         self.att_loss = lambda pred, true: torch.mean(
             torch.clamp(1e-3 * torch.exp(-m * true - b), 0., 100.) *
             (pred - true)**2.)
     elif params['att_type'] == 2:
         m = self.to_gpu(
             Variable(torch.from_numpy(data.powerfits[:, 0]).float()))
         b = self.to_gpu(
             Variable(torch.from_numpy(data.powerfits[:, 1]).float()))
         self.att_loss = lambda pred, true: torch.mean(
             torch.clamp(1e-3 * torch.exp(-m * true - b), 0., 100.) * torch.
             abs(pred - true))
     else:
         self.att_loss = lambda pred, true: 0 * torch.mean(pred)
     self.logger = logger
     self.feature_mask = self.load_feature_mask()
     self.feature_means = self.load_feature_means()
Example #21
 def density(self, state):
     loc = self.mean(state)
     scale = th.exp(th.clamp(self.sigma, min=math.log(EPSILON)))
     return Normal(loc=loc, scale=scale)
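Clamping the log-scale from below keeps `scale` strictly positive and bounded away from zero, avoiding degenerate densities. A minimal illustration (EPSILON is assumed to be a small positive float, as in the original module):

import math
import torch
from torch.distributions import Normal

EPSILON = 1e-6
sigma = torch.tensor([-20.0, 0.0])  # raw log-scales
scale = torch.exp(torch.clamp(sigma, min=math.log(EPSILON)))
d = Normal(loc=torch.zeros(2), scale=scale)  # scale >= EPSILON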
Example #22
start = time.time()
for i in range(NITER):
    rows, epoch_no, sgd_scale = yield_minibatch_rows(i, N, MINIBATCH)
    if i < 20 or i % 500 == 0:
        print2("[%.2fs] %i. iteration, %i. epoch" %
               (time.time() - start, i, epoch_no))

    #######################################################
    # preparation: selecting minibatch rows

    qz_loc0 = qz_loc[rows, :]
    qz_scale0 = qz_scale[rows, :]
    qw = Normal(qw_loc, softplus(qw_scale))
    qz = Normal(qz_loc0, softplus(qz_scale0))

    h0 = h[rows, :]
    x0 = x[rows, :]
    training_mask0 = training_mask[rows, :]
    utility_term_mask0 = utility_term_mask[rows, :]

    #######################################################
    # optimization

    w = qw.rsample(torch.Size([NSAMPLES]))
    z = qz.rsample(torch.Size([NSAMPLES]))
    # Minibatch ELBO estimate: model term minus variational log-densities;
    # the local z terms are rescaled to the full data set by sgd_scale.
    elbo = (model_log_prob(x0, w, z, training_mask0, sgd_scale).sum()
            - qw.log_prob(w).sum()
            - qz.log_prob(z).sum() * sgd_scale)
    elbo = elbo / NSAMPLES
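
The snippet ends at the ELBO estimate. A minimal hypothetical continuation of the loop body, assuming an optimizer `opt` (e.g. Adam) was built over the variational leaves `qw_loc`, `qw_scale`, `qz_loc`, `qz_scale` (the optimizer name is not in the original):

    # Ascend the ELBO by descending its negation.
    opt.zero_grad()
    (-elbo).backward()
    opt.step()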
Example #23
    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        pro_recons_weight=1.0,  # double check these defaults
        kl_weight=1.0,
    ) -> LossRecorder:
        """
        Returns the reconstruction loss and the Kullback-Leibler divergences.

        Parameters
        ----------
        tensors
            dict of input tensors (gene expression, batch indices, protein counts)
        inference_outputs
            dict of inference-step outputs (approximate posterior parameters)
        generative_outputs
            dict of generative-step outputs (the ``px_`` and ``py_`` parameters)
        pro_recons_weight
            weight on the protein reconstruction term
        kl_weight
            weight on the KL terms (e.g. for warm-up annealing)

        Returns
        -------
        type
            the reconstruction loss and the Kullback-Leibler divergences
        """
        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        ql_m = inference_outputs["ql_m"]
        ql_v = inference_outputs["ql_v"]
        px_ = generative_outputs["px_"]
        py_ = generative_outputs["py_"]

        x = tensors[REGISTRY_KEYS.X_KEY]
        batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
        y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]

        if self.protein_batch_mask is not None:
            pro_batch_mask_minibatch = torch.zeros_like(y)
            for b in torch.unique(batch_index):
                b_indices = (batch_index == b).reshape(-1)
                pro_batch_mask_minibatch[b_indices] = torch.tensor(
                    self.protein_batch_mask[b.item()].astype(np.float32),
                    device=y.device,
                )
        else:
            pro_batch_mask_minibatch = None

        reconst_loss_gene, reconst_loss_protein = self.get_reconstruction_loss(
            x, y, px_, py_, pro_batch_mask_minibatch)

        # KL Divergence
        kl_div_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1)).sum(dim=1)
        if not self.use_observed_lib_size:
            n_batch = self.library_log_means.shape[1]
            local_library_log_means = F.linear(one_hot(batch_index, n_batch),
                                               self.library_log_means)
            local_library_log_vars = F.linear(one_hot(batch_index, n_batch),
                                              self.library_log_vars)
            kl_div_l_gene = kl(
                Normal(ql_m, torch.sqrt(ql_v)),
                Normal(local_library_log_means,
                       torch.sqrt(local_library_log_vars)),
            ).sum(dim=1)
        else:
            kl_div_l_gene = 0.0

        kl_div_back_pro_full = kl(Normal(py_["back_alpha"], py_["back_beta"]),
                                  self.back_mean_prior)
        if pro_batch_mask_minibatch is not None:
            kl_div_back_pro = torch.zeros_like(kl_div_back_pro_full)
            kl_div_back_pro.masked_scatter_(pro_batch_mask_minibatch.bool(),
                                            kl_div_back_pro_full)
            kl_div_back_pro = kl_div_back_pro.sum(dim=1)
        else:
            kl_div_back_pro = kl_div_back_pro_full.sum(dim=1)
        loss = torch.mean(reconst_loss_gene +
                          pro_recons_weight * reconst_loss_protein +
                          kl_weight * kl_div_z + kl_div_l_gene +
                          kl_weight * kl_div_back_pro)

        reconst_losses = dict(
            reconst_loss_gene=reconst_loss_gene,
            reconst_loss_protein=reconst_loss_protein,
        )
        kl_local = dict(
            kl_div_z=kl_div_z,
            kl_div_l_gene=kl_div_l_gene,
            kl_div_back_pro=kl_div_back_pro,
        )

        return LossRecorder(loss,
                            reconst_losses,
                            kl_local,
                            kl_global=torch.tensor(0.0))
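
For reference, the `kl_div_z` term computed above has a closed form for a diagonal Gaussian against a standard normal prior, KL = ½ Σ (σ² + μ² − 1 − log σ²); a quick check sketch with made-up shapes:

import torch
from torch.distributions import Normal, kl_divergence as kl

qz_m, qz_v = torch.randn(4, 10), torch.rand(4, 10) + 0.1
analytic = 0.5 * (qz_v + qz_m**2 - 1 - torch.log(qz_v)).sum(dim=1)
numeric = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0., 1.)).sum(dim=1)
assert torch.allclose(analytic, numeric, atol=1e-5)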
Example #24
def metrify(obs, steps, wall_start, all_actions, all_pol_stats, all_stds,
            all_means, all_rewards, all_scenario_wins_rewards,
            all_final_reward, all_q_vals, to_plot):

    m1 = (all_means[-1][0], all_means[-1][1])

    # NOTE: multivariate_normal expects variances on the covariance
    # diagonal; raw stds are placed here, which misstates the spread.
    s1 = np.eye(2)
    s1[0][0] = all_stds[-1][0]
    s1[1][1] = all_stds[-1][1]

    k1 = multivariate_normal(mean=m1, cov=s1)

    # Evaluate the policy Gaussian on a dense (x, y) action grid.

    x, y = np.mgrid[-2.75:2.75:.05, -2.75:2.75:.05]
    pos = np.empty(x.shape + (2, ))
    pos[:, :, 0] = x
    pos[:, :, 1] = y

    z = k1.pdf(pos)
    z_th = np.tanh(z)

    fig, ax = plt.subplots()

    ax.contourf(x, y, z_th)
    ax.plot(all_pol_stats[-1][0],
            all_pol_stats[-1][1],
            marker='o',
            markersize=3,
            color="red")
    ax.set_xlabel('Throttle')
    ax.set_ylabel('Steering')

    pil_plot = fig2img(fig).resize((400, 300), Image.LANCZOS)

    plot_img = np.asarray(pil_plot)[..., :3]

    pil_obs = transforms.ToPILImage()(obs[0][0])
    _draw = ImageDraw.Draw(pil_obs)

    _draw.text(
        (5, 10),
        'FPS: %.3f, steps: %d' % (steps / (time.time() - wall_start), steps))

    _draw.text((5, 30), 'Steer: %.3f' % all_pol_stats[-1][0])

    _draw.text((5, 50), 'Throttle: %.3f' % all_pol_stats[-1][1])

    entropy = Normal(torch.FloatTensor(all_means[-1]),
                     torch.FloatTensor(all_stds[-1])).entropy()
    entropy1, entropy2 = entropy[0].item(), entropy[1].item()
    _draw.text((5, 70), 'Entropy: {:.3f}; {:.3f}'.format(entropy1, entropy2))

    _combined = Image.fromarray(np.hstack((plot_img, np.asarray(pil_obs))))

    cv2.imshow('Sensors', np.asarray(_combined))

    fig2, ax2 = plt.subplots()

    ax2.plot(np.clip(all_rewards[:, [0]], -1, 1.5), label='Speed')
    ax2.plot(np.clip(all_rewards[:, [1]], -1, 1.5), label='Time')
    ax2.plot(np.clip(all_rewards[:, [2]], -1, 1.5), label='Distance')
    ax2.plot(np.clip(all_rewards[:, [3]], -1, 1.5), label='Collision')
    ax2.plot(np.clip(all_rewards[:, [4]], -1, 1.5), label='Lane')

    ax2.plot(np.clip(all_q_vals, -1, 1.5), label='QVal')

    ax2.plot(np.clip(all_final_reward, -1, 1.5), label='Final R')

    if to_plot:
        np_arr_to_plot = np.asarray(to_plot)
        ax2.plot(np_arr_to_plot[:, [0]],
                 np_arr_to_plot[:, [1]],
                 marker='o',
                 markersize=3,
                 color="red")

    plt.legend()

    ax2.set_xlabel('Step')
    ax2.set_ylabel('Reward')

    pil_plot2 = fig2img(fig2)

    plot_img2 = np.asarray(pil_plot2)[..., :3]

    cv2.imshow('Rewards', cv2.cvtColor(np.asarray(plot_img2),
                                       cv2.COLOR_BGR2RGB))

    cv2.waitKey(1)

    fig3, ax3 = plt.subplots()
    ax3.plot(np.cumsum(all_final_reward), label='Cumulative Reward')

    pil_plot3 = fig2img(fig3)

    plot_img3 = np.asarray(pil_plot3)[..., :3]

    cv2.imshow('Cumulative Reward', np.asarray(plot_img3))

    plt.close(fig2)
    plt.close(fig)
    plt.close(fig3)

    if steps != 0 and steps % 200 == 0:
        # NOTE: this only rebinds the local names; the caller's buffers are
        # not cleared. A periodic reset would need in-place clearing (or a
        # reset at the call site).
        all_actions = []
        all_rewards = []
        all_pol_stats = []
        all_stds = []
        all_means = []
        all_scenario_wins_rewards = []
        all_final_reward = []
        all_q_vals = []
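
The entropy values drawn on the overlay come from the Gaussian's closed form, H = ½ ln(2πeσ²) per dimension; a quick check sketch:

import math
import torch
from torch.distributions import Normal

std = torch.tensor([0.3, 0.5])
analytic = 0.5 * torch.log(2 * math.pi * math.e * std ** 2)
assert torch.allclose(Normal(torch.zeros(2), std).entropy(), analytic)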
Example #25
    def forward(self, z: torch.Tensor, library_gene: torch.Tensor,
                *cat_list: int):
        """
        The forward computation for a single sample.

         #. Decodes the data from the latent space using the decoder network
         #. Returns local parameters for the ZINB distribution for genes
         #. Returns local parameters for the Mixture NB distribution for proteins

         We use the dictionary `px_` to contain the parameters of the ZINB/NB for genes.
         The rate refers to the mean of the NB, dropout refers to Bernoulli mixing parameters.
         `scale` refers to the quantity upon which differential expression is performed. For genes,
         this can be viewed as the mean of the underlying gamma distribution.

         We use the dictionary `py_` to contain the parameters of the Mixture NB distribution for proteins.
         `rate_fore` refers to foreground mean, while `rate_back` refers to background mean. `scale` refers to
         foreground mean adjusted for background probability and scaled to reside in simplex.
         `back_alpha` and `back_beta` are the posterior parameters for `rate_back`.  `fore_scale` is the scaling
         factor that enforces `rate_fore` > `rate_back`.

        Parameters
        ----------
        z
            tensor with shape ``(n_input,)``
        library_gene
            library size
        cat_list
            list of category membership(s) for this sample

        Returns
        -------
        3-tuple (:py:class:`dict`, :py:class:`dict`, :py:class:`torch.Tensor`)
            ``px_`` (ZINB parameters for genes), ``py_`` (mixture NB
            parameters for proteins), and the sampled log background mean

        """
        px_ = {}
        py_ = {}

        px = self.px_decoder(z, *cat_list)
        px_cat_z = torch.cat([px, z], dim=-1)
        unnorm_px_scale = self.px_scale_decoder(px_cat_z, *cat_list)
        if self.use_softmax:
            px_["scale"] = nn.Softmax(dim=-1)(unnorm_px_scale)
        else:
            px_["scale"] = torch.exp(unnorm_px_scale)
        px_["rate"] = library_gene * px_["scale"]

        py_back = self.py_back_decoder(z, *cat_list)
        py_back_cat_z = torch.cat([py_back, z], dim=-1)

        py_["back_alpha"] = self.py_back_mean_log_alpha(
            py_back_cat_z, *cat_list)
        py_["back_beta"] = torch.exp(
            self.py_back_mean_log_beta(py_back_cat_z, *cat_list))
        log_pro_back_mean = Normal(py_["back_alpha"],
                                   py_["back_beta"]).rsample()
        py_["rate_back"] = torch.exp(log_pro_back_mean)

        py_fore = self.py_fore_decoder(z, *cat_list)
        py_fore_cat_z = torch.cat([py_fore, z], dim=-1)
        py_["fore_scale"] = (
            self.py_fore_scale_decoder(py_fore_cat_z, *cat_list) + 1 + 1e-8)
        py_["rate_fore"] = py_["rate_back"] * py_["fore_scale"]

        p_mixing = self.sigmoid_decoder(z, *cat_list)
        p_mixing_cat_z = torch.cat([p_mixing, z], dim=-1)
        px_["dropout"] = self.px_dropout_decoder_gene(p_mixing_cat_z,
                                                      *cat_list)
        py_["mixing"] = self.py_background_decoder(p_mixing_cat_z, *cat_list)

        # Probability that a protein count comes from the background.
        protein_mixing = torch.sigmoid(py_["mixing"])
        # Foreground-adjusted means, L1-normalized onto the simplex.
        py_["scale"] = torch.nn.functional.normalize(
            (1 - protein_mixing) * py_["rate_fore"], p=1, dim=-1)

        return (px_, py_, log_pro_back_mean)
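
The final `py_["scale"]` lies on the simplex by construction: the entries are nonnegative and the L1 normalization divides by their sum. A minimal check sketch with made-up shapes:

import torch

rate_fore = torch.rand(8, 20)                  # nonnegative foreground means
fore_prob = 1 - torch.sigmoid(torch.randn(8, 20))
scale = torch.nn.functional.normalize(fore_prob * rate_fore, p=1, dim=-1)
assert torch.allclose(scale.sum(-1), torch.ones(8))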
Example #26
File: vae.py Project: Juan-JV/sbVAE
    def from_variables_to_densities(
        self,
        x,
        local_l_mean,
        local_l_var,
        px_r,
        px_rate,
        px_dropout,
        z,
        library,
        px_scale=None,
        qz_m=None,
        qz_v=None,
        ql_m=None,
        ql_v=None,
        log_qz_x=None,
        log_ql_x=None,
    ):
        """
        Unifies VAE outputs to construct loss

        :param x:
        :param local_l_mean:
        :param local_l_var:
        :param px_r:
        :param px_rate:
        :param px_dropout:
        :param qz_m:
        :param qz_v:
        :param z:
        :param ql_m:
        :param ql_v:
        :param library:
        :param log_qz_x:
        :param log_ql_x:
        :return:
        """
        log_px_zl = (-1) * self._reconstruction_loss(x, px_rate, px_r,
                                                     px_dropout)
        if log_qz_x is None:
            log_qz_x = Normal(qz_m, torch.sqrt(qz_v)).log_prob(z).sum(dim=-1)
        if log_ql_x is None:
            log_ql_x = Normal(ql_m,
                              torch.sqrt(ql_v)).log_prob(library).sum(dim=-1)
        log_pz = Normal(torch.zeros_like(z),
                        torch.ones_like(z)).log_prob(z).sum(dim=-1)
        log_pl = (Normal(
            local_l_mean,
            torch.sqrt(local_l_var)).log_prob(library).sum(dim=-1))

        a1_issue = torch.isnan(log_px_zl).any() or torch.isinf(log_px_zl).any()
        a2_issue = torch.isnan(log_pl).any() or torch.isinf(log_pl).any()
        a3_issue = torch.isnan(log_pz).any() or torch.isinf(log_pz).any()
        a4_issue = torch.isnan(log_qz_x).any() or torch.isinf(log_qz_x).any()
        a5_issue = torch.isnan(log_ql_x).any() or torch.isinf(log_ql_x).any()

        if a1_issue or a2_issue or a3_issue or a4_issue or a5_issue:
            print("Warning: NaN/Inf encountered in log-density terms")

        return dict(
            log_px_zl=log_px_zl,
            log_pl=log_pl,
            log_pz=log_pz,
            log_qz_x=log_qz_x,
            log_ql_x=log_ql_x,
        )
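
The returned log-densities are exactly the ingredients of the per-cell ELBO; a hypothetical downstream assembly (not in the original file), assuming `out` is the dictionary returned above:

elbo = (out["log_px_zl"] + out["log_pz"] + out["log_pl"]
        - out["log_qz_x"] - out["log_ql_x"])
loss = -elbo.mean()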
Example #27
 def forward(self, h):
     # Amortized Gaussian posterior: mean and std predicted from h.
     return Normal(self.enc_mean(h), self.enc_std(h))
Example #28
File: vae.py Project: Juan-JV/sbVAE
    def __init__(
        self,
        n_input: int,
        n_batch: int = 0,
        n_labels: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        iaf_t: int = 0,
        dropout_rate: float = 0.1,
        dispersion: str = "gene",
        log_variational: bool = True,
        reconstruction_loss: str = "zinb",
        prevent_library_saturation: bool = False,
        prevent_library_saturation2: bool = False,
        multi_encoder_keys=["default"],
    ):
        super().__init__()
        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss = reconstruction_loss
        # Automatically deactivate if useless
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.n_latent_layers = 1  # not sure what this is for, no usages?
        self.multi_encoder_keys = multi_encoder_keys
        do_multi_encoders = len(multi_encoder_keys) >= 2
        # NOTE: the prior is hard-coded to CUDA, which breaks on CPU-only
        # machines; a registered buffer would follow the module's device.
        z_prior_mean = torch.zeros(n_latent, device="cuda")
        z_prior_std = torch.ones(n_latent, device="cuda")
        self.z_prior = Normal(z_prior_mean, z_prior_std)

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        else:  # gene-cell
            pass

        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.do_iaf = iaf_t > 0
        if not self.do_iaf:
            logger.info("using MF encoder")
            self.z_encoder = nn.ModuleDict({
                key: Encoder(
                    n_input,
                    n_latent,
                    n_layers=n_layers,
                    n_hidden=n_hidden,
                    dropout_rate=dropout_rate,
                )
                for key in self.multi_encoder_keys
            })
        else:
            logger.info("using IAF encoder")
            assert not do_multi_encoders
            self.z_encoder = EncoderIAF(
                n_in=n_input,
                n_latent=n_latent,
                n_cat_list=None,
                n_hidden=n_hidden,
                n_layers=n_layers,
                t=iaf_t,
                dropout_rate=dropout_rate,
                use_batch_norm=True,
                do_h=True,
            )
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(
            n_input,
            1,
            n_layers=1,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            prevent_saturation=prevent_library_saturation,
            prevent_saturation2=prevent_library_saturation2,
        )
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(
            n_latent,
            n_input,
            n_cat_list=[n_batch],
            n_layers=n_layers,
            n_hidden=n_hidden,
        )

        # NOTE: this assertion makes the IAF branch above unusable;
        # constructing with iaf_t > 0 fails here.
        assert not self.do_iaf
Example #29
def sample(mod, y=None):
    # Sample posterior model parameters
    idx = [range(mod.N[i]) for i in range(mod.I)]
    params = mod.sample_params(idx)
    ll = []

    # Detach model parameters
    mu0 = -params['delta0'].cumsum(0).detach()
    mu1 = params['delta1'].cumsum(0).detach()
    eta0 = params['eta0'].detach()
    eta1 = params['eta1'].detach()
    sig = params['sig2'].detach().sqrt()
    H = params['H'].detach()
    v = params['v'].detach()

    if mod.use_stick_break:
        # Z = (v.cumprod(0) > Normal(0, 1).cdf(H)).double()
        Z = (v.cumprod(0) > H).double()
    else:
        # Z = (v > Normal(0, 1).cdf(H)).double()
        Z = (v > H).double()

    W = params['W'].detach()

    for i in range(mod.I):
        if y is None:
            # Use the imputed y[i]
            yi = params['y'][i].detach()
        else:
            # Use the user-provided y[i]
            yi = y[i]

        # compute probs
        d0 = Normal(mu0[None, None, :], sig[i]).log_prob(yi[:, :, None])
        d0 += eta0[i:i + 1, :, :].log()

        d1 = Normal(mu1[None, None, :], sig[i]).log_prob(yi[:, :, None])
        d1 += eta1[i:i + 1, :, :].log()

        # Ni x J
        logmix_L0 = torch.logsumexp(d0, 2)
        logmix_L1 = torch.logsumexp(d1, 2)

        # Ni x J x K
        c = Z[None, :, :] * logmix_L1[:, :, None] + (
            1 - Z[None, :, :]) * logmix_L0[:, :, None]

        # Ni x K
        d = c.sum(1)
        # loglike for lam[i]
        lli = d + W[i][None, :].log()

        # TODO: Implement this
        # if mod.model_noisy:
        #     eps_i = params['eps'][i].detach()
        #     lli_quiet = lli + torch.log1p(-eps_i)
        #     lli_noisy = Normal(0, mod.noisy_sd).log_prob(yi).sum(1) + eps_i.log()
        #     # lli = torch.cat([lli_noisy, lli_quiet], dim=-1)
        #     lli = torch.cat([lli_quiet, lli_noisy], dim=-1)
        # else:
        #     ll.append(lli)
        ll.append(lli)

    return ll
Example #30
    def forward(self, x, l_t, h_1, c_1, h_2, c_2, first=False, last=False):
        """
        Run the recurrent attention model for 1 timestep
        on the minibatch of images `x`.

        Args
        ----
        - x: a 5D Tensor of shape (B, C, H, W, D). The minibatch
          of images.
        - l_t_prev: a 3D tensor of shape (B, 3). The location vector
          containing the glimpse coordinates [x, y,z] for the previous
          timestep t-1.
        - h_1_prev, c_1_prev: a 2D tensor of shape (B, hidden_size). The 
          lower LSTM hidden state vector for the previous timestep t-1.
        - h_2_prev, c_2_prev: a 2D tensor of shape (B, hidden_size). The 
          upper LSTM hidden state vector for the previous timestep t-1.
        - last: a bool indicating whether this is the last timestep.
          If True, the action network returns an output probability
          vector over the classes and the baseline b_t for the
          current timestep t. 
          
        Returns
        -------
        - h_1_t, c_1_t, h_2_t, c_2_t: hidden LSTM states for current step
        - mu: a 3D tensor of shape (B, 3). The mean that parametrizes
          the Gaussian policy.
        - l_t: a 3D tensor of shape (B, 3). The location vector
          containing the glimpse coordinates [x, y,z] for the
          current timestep t.
        - b_t: a vector of length (B,). The baseline for the
          current time step t.
        - log_probas: a 2D tensor of shape (B, num_classes). The
          output log probability vector over the classes.
        - log_pi: a vector of length (B,).
        """
        if first:
            # t = 0: initialize the LSTM states and return a starting
            # location; no glimpse is taken yet.
            x = x.unsqueeze(1)
            B, C, H, W, D = x.shape

            # Start every glimpse at the origin of the normalized
            # coordinate system (the volume center).
            l_0 = T.zeros(B, 3)

            h_1, c_1, h_2, c_2 = (T.randn(1, B, self.hidden_size),
                                  T.randn(1, B, self.hidden_size),
                                  T.randn(1, B, self.hidden_size),
                                  T.randn(1, B, self.hidden_size))

            return (main_train.get_cuda(h_1), main_train.get_cuda(c_1),
                    main_train.get_cuda(h_2), main_train.get_cuda(c_2),
                    main_train.get_cuda(l_0))

        # print("[INFO] location :", l_t)
        x = x.unsqueeze(1)
        # l_t = get_cuda(l_t) #Added by FATIH
        g_t = self.sensor(x, l_t)  #,display,axes,labels,dem
        h_1, c_1, h_2, c_2 = self.rnn(g_t.unsqueeze(0), h_1, c_1, h_2, c_2)
        mu, l_t = self.locator(h_2)
        b_t = self.baseliner(h_2).squeeze()

        log_pi = Normal(mu, self.std).log_prob(l_t)
        log_pi = T.sum(log_pi,
                       dim=1)  #policy probabilities for REINFORCE training

        if last:
            # Final timestep: also return class log-probabilities.
            log_probas = self.classifier(h_1)
            return h_1, c_1, h_2, c_2, l_t, b_t, log_pi, log_probas

        return h_1, c_1, h_2, c_2, l_t, b_t, log_pi
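
Downstream, `log_pi` and the baseline `b_t` feed the usual REINFORCE-with-baseline update; a hypothetical sketch, assuming a per-episode reward `R` of shape (B,) and `import torch.nn.functional as F` (names illustrative, not from the original):

advantage = (R - b_t).detach()                 # centered reward signal
loss_actor = -(log_pi * advantage).mean()      # REINFORCE term
loss_baseline = F.mse_loss(b_t, R)             # regress baseline onto reward
loss = loss_actor + loss_baseline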