Example no. 1
0
 def select_action(self, state, deterministic, reparameterize=False):
     alpha, beta = self.forward(state)
     dist = Beta(concentration1=alpha, concentration0=beta)
     if reparameterize:
         action = dist.rsample()  # (bsize, action_dim)
     else:
         action = dist.sample()  # (bsize, action_dim)
     return action, dist
Example no. 2
0
def get_loss(enc,
             dec_ecg,
             dec_demo,
             x_batch_ecg,
             x_batch_other,
             y_batch,
             full=True):
    """Compute the negative-ELBO loss for one batch.

    Encodes the batch into per-dimension LogNormal latents plus a Beta
    latent ``pi``, draws reparameterized samples, evaluates the model
    log-probability terms, and returns the negated sum.

    Args:
        enc: Encoder module; ``enc.forward`` returns
            ``(zmu, zstd, pi_alpha, pi_beta)``.
        dec_ecg: ECG decoder, passed through to the likelihood term.
        dec_demo: Demographics decoder, passed through likewise.
        x_batch_ecg: ECG input batch.
        x_batch_other: Non-ECG input batch.
        y_batch: Labels for the ``pi`` likelihood term.
        full: When True, include the reconstruction term ``log_p_xz``
            in the loss; otherwise omit it.

    Returns:
        Tuple ``(loss, lossdict)``: the scalar loss tensor and a dict of
        per-example (divided by batch size) diagnostic log-probabilities.
    """
    zmu, zstd, pi_alpha, pi_beta = enc.forward(x_batch_ecg, x_batch_other)

    # One LogNormal per latent column: R, C, Ts, Td, CO (in that order).
    z_dists = [LogNormal(zmu[:, i], zstd[:, i]) for i in range(5)]
    pi_obj = Beta(pi_alpha, pi_beta)

    # rsample keeps the draws differentiable; sampling order (z's then pi)
    # matches the original so RNG consumption is identical.
    z_draws = [d.rsample() for d in z_dists]
    pi_samp = pi_obj.rsample()

    # Stack the five (batch,) draws into a single (batch, 5) tensor.
    z_samp = torch.cat([draw.view(-1, 1) for draw in z_draws], dim=1)

    log_p_phi = get_log_p_phi()
    log_p_zpi, zpi_ld = get_log_p_zpi(z_samp, pi_samp, all_dist_params)
    log_p_xz, xz_ld = get_log_p_xz(x_batch_ecg, x_batch_other, z_samp, dec_ecg,
                                   dec_demo)
    log_p_ypi = get_log_p_ypi(pi_samp, y_batch)
    log_qzpi = get_log_qzpi(zmu, zstd, z_samp, pi_alpha, pi_beta, pi_samp)
    log_ppi = get_log_ppi(pi_samp)

    # Scalar diagnostics; the helper dicts contribute their own entries.
    lossdict = {
        'phi_logprob': log_p_phi.item(),
        'ypi_logprob': log_p_ypi.item(),
        'qzpi_logprob': log_qzpi.item(),
        'ppi_logprob': log_ppi.item()
    }
    lossdict.update(zpi_ld)
    lossdict.update(xz_ld)

    # Negative ELBO; the reconstruction term is optional via `full`.
    # Addition order mirrors the original exactly.
    if full:
        loss = -(log_ppi + log_p_phi + log_p_zpi + log_p_xz + log_p_ypi -
                 log_qzpi)
    else:
        loss = -(log_ppi + log_p_phi + log_p_zpi + log_p_ypi - log_qzpi)

    lossdict['loss'] = loss.item()

    # Report everything normalized per example.
    lossdict = {key: val / len(x_batch_ecg) for (key, val) in lossdict.items()}

    if torch.isnan(loss):
        print("got a nan")

    return loss, lossdict