def select_action(self, state, deterministic, reparameterize=False):
    """Choose an action from the Beta policy evaluated at ``state``.

    Args:
        state: Input passed unchanged to ``self.forward``, which must return
            the ``(alpha, beta)`` concentration parameters.
        deterministic: If true, return the mean of the Beta distribution,
            ``alpha / (alpha + beta)``, instead of a random draw. (Previously
            this flag was ignored and a sample was always returned.)
        reparameterize: Only used when sampling. If true, draw with
            ``rsample`` so gradients flow through the action; otherwise use
            ``sample``, which is detached from the graph.

    Returns:
        ``(action, dist)``: the chosen action (values within the Beta support
        [0, 1]) and the ``Beta`` distribution it came from.
    """
    alpha, beta = self.forward(state)
    dist = Beta(concentration1=alpha, concentration0=beta)
    if deterministic:
        # The mean is a differentiable function of alpha and beta, so it
        # serves both the reparameterized and non-reparameterized callers.
        action = dist.mean
    elif reparameterize:
        action = dist.rsample()  # (bsize, action_dim)
    else:
        action = dist.sample()  # (bsize, action_dim)
    return action, dist
def get_loss(enc, dec_ecg, dec_demo, x_batch_ecg, x_batch_other, y_batch, full=True):
    """Compute the training loss for one batch together with a logging dict.

    ``enc.forward(x_batch_ecg, x_batch_other)`` yields ``(zmu, zstd,
    pi_alpha, pi_beta)``. The first five columns of ``zmu``/``zstd``
    parameterise independent ``LogNormal`` posteriors, and
    ``pi_alpha``/``pi_beta`` parameterise a ``Beta`` posterior for ``pi``.
    One reparameterised sample is drawn from each posterior. The samples are
    scored by the ``get_log_*`` helpers (defined elsewhere), and the loss is
    the negative of (sum of log-probability terms minus ``log_qzpi``).

    Args:
        enc: Encoder model.
        dec_ecg, dec_demo: Decoders handed to ``get_log_p_xz``.
        x_batch_ecg, x_batch_other: Batch inputs.
        y_batch: Targets handed to ``get_log_p_ypi``.
        full: If true, the ``log_p_xz`` term is included in the loss. If
            false it is left out, although it is still computed and its
            ``xz_ld`` entries still appear in the dict.

    Returns:
        ``(loss, lossdict)``. ``loss`` is the un-normalised loss tensor.
        Every value in ``lossdict`` is divided by ``len(x_batch_ecg)``.
    """
    zmu, zstd, pi_alpha, pi_beta = enc.forward(x_batch_ecg, x_batch_other)

    # Columns 0..4 hold the R, C, Ts, Td and CO latents, in that order.
    # All posteriors are built before any sample is drawn, and the draws
    # happen in column order, so the random stream is consumed identically.
    z_dists = [LogNormal(zmu[:, col], zstd[:, col]) for col in range(5)]
    pi_dist = Beta(pi_alpha, pi_beta)
    z_draws = [d.rsample() for d in z_dists]
    pi_samp = pi_dist.rsample()
    z_samp = torch.cat([draw.view(-1, 1) for draw in z_draws], dim=1)

    log_p_phi = get_log_p_phi()
    log_p_zpi, zpi_ld = get_log_p_zpi(z_samp, pi_samp, all_dist_params)
    log_p_xz, xz_ld = get_log_p_xz(
        x_batch_ecg, x_batch_other, z_samp, dec_ecg, dec_demo
    )
    log_p_ypi = get_log_p_ypi(pi_samp, y_batch)
    log_qzpi = get_log_qzpi(zmu, zstd, z_samp, pi_alpha, pi_beta, pi_samp)
    log_ppi = get_log_ppi(pi_samp)

    lossdict = {
        'phi_logprob': log_p_phi.item(),
        'ypi_logprob': log_p_ypi.item(),
        'qzpi_logprob': log_qzpi.item(),
        'ppi_logprob': log_ppi.item(),
    }
    lossdict.update(zpi_ld)
    lossdict.update(xz_ld)

    # Terms are added left to right in a fixed order so floating-point
    # results do not depend on ``full``; only log_p_xz is conditional.
    objective = log_ppi + log_p_phi + log_p_zpi
    if full:
        objective = objective + log_p_xz
    loss = -1 * (objective + log_p_ypi - log_qzpi)

    lossdict['loss'] = loss.item()
    batch_size = len(x_batch_ecg)
    lossdict = {key: val / batch_size for key, val in lossdict.items()}
    if torch.isnan(loss):
        print("got a nan")
    return loss, lossdict