def loss_function(self, forward_ret, labels=None):
    (x, t2, qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar, pb_z1_b1_mu, pb_z1_b1_logvar,
     qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2, pt_z2_z1_mu, pt_z2_z1_logvar,
     pd_x2_z2) = forward_ret

    # replicate x multiple times
    x = x[None, ...].expand(self.flags.samples_per_seq, -1, -1, -1)  # size: copy, bs, time, dim
    x2 = torch.gather(x, 2, t2[..., None, None].expand(-1, -1, -1, x.size(3))).view(-1, x.size(3))
    batch_size = x2.size(0)

    kl_div_qs_pb = ops.kl_div_gaussian(qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar,
                                       pb_z1_b1_mu, pb_z1_b1_logvar).mean()

    kl_shift_qb_pt = (ops.gaussian_log_prob(qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2) -
                      ops.gaussian_log_prob(pt_z2_z1_mu, pt_z2_z1_logvar, qb_z2_b2)).mean()

    bce = F.binary_cross_entropy(pd_x2_z2, x2, reduction='sum') / batch_size
    bce_optimal = F.binary_cross_entropy(x2, x2, reduction='sum').detach() / batch_size
    bce_diff = bce - bce_optimal

    loss = bce_diff + kl_div_qs_pb + kl_shift_qb_pt

    return loss, bce_diff, kl_div_qs_pb, kl_shift_qb_pt, bce_optimal
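# The losses above and below lean on two helpers from `ops` that are not shown in
# this file. A minimal sketch of what they are assumed to compute (closed-form KL
# divergence and log-density for diagonal Gaussians parameterized by mean and
# log-variance) follows; the project's actual implementations may differ in
# reduction and broadcasting details.
import math

import torch


def kl_div_gaussian_sketch(q_mu, q_logvar, p_mu, p_logvar):
    """KL(q || p) between diagonal Gaussians, summed over the last dimension."""
    return 0.5 * torch.sum(
        p_logvar - q_logvar +
        (q_logvar.exp() + (q_mu - p_mu) ** 2) / p_logvar.exp() - 1.0,
        dim=-1)


def gaussian_log_prob_sketch(mu, logvar, x):
    """Log-density of x under a diagonal Gaussian, summed over the last dimension."""
    return -0.5 * torch.sum(
        logvar + (x - mu) ** 2 / logvar.exp() + math.log(2 * math.pi),
        dim=-1)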
def loss_function(self, forward_ret, labels=None):
    (x, qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2, pd_x2_z2) = forward_ret

    # replicate x multiple times (kept for shape bookkeeping; not used below)
    x_flat = x.flatten(2, -1)
    x_flat = x_flat.expand(self.flags.samples_per_seq, -1, -1, -1)  # size: copy, bs, time, dim
    batch_size = x.size(0)

    # KL to a fixed Gaussian prior
    # XXX prior logvar of ones means variance e; zeros would give a unit-variance prior
    eye = torch.ones(qb_z2_b2.size(-1)).to(qb_z2_b2.device)[None, None, :].expand(
        -1, qb_z2_b2.size(-2), -1)
    kl_div_qs_pb = ops.kl_div_gaussian(qb_z2_b2_mu, qb_z2_b2_logvar, 0, eye).mean()

    target = x.flatten()
    pred = pd_x2_z2.flatten()
    bce = F.binary_cross_entropy(pred, target, reduction='sum') / batch_size
    bce_optimal = F.binary_cross_entropy(target, target, reduction='sum') / batch_size
    bce_diff = bce - bce_optimal

    if self.adversarial and self.is_training():
        r_in = x.view(x.shape[0], x.shape[2], x.shape[3], x.shape[4])
        f_in = pd_x2_z2.view(x.shape[0], x.shape[2], x.shape[3], x.shape[4])
        for _ in range(self.d_steps):
            d_loss, g_loss, hidden_loss = self.dnet.get_loss(r_in, f_in)
            d_loss.backward(retain_graph=True)
            # print(d_loss, g_loss)
            self.adversarial_optim.step()
            self.adversarial_optim.zero_grad()
        bce_diff = hidden_loss  # XXX bce_diff added twice to loss?
    else:
        g_loss = 0
        hidden_loss = 0

    loss = bce_diff + hidden_loss + self.d_weight * g_loss + self.beta * kl_div_qs_pb

    return loss, bce_diff, kl_div_qs_pb, 0, bce_optimal
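# `self.dnet` (and its get_loss(real, fake) -> (d_loss, g_loss, hidden_loss)
# interface) is defined elsewhere in the project. The sketch below is only a
# hypothetical illustration of one way such a discriminator could produce that
# triple: a real/fake classification loss for the discriminator, a
# non-saturating generator loss, and a feature-matching ("hidden") loss on an
# intermediate layer.
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiscriminatorSketch(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2))
        self.classifier = nn.Conv2d(64, 1, 3, padding=1)

    def get_loss(self, real, fake):
        h_real = self.features(real)
        h_fake = self.features(fake)
        # discriminator loss: real -> 1, (detached) fake -> 0
        d_real = self.classifier(h_real)
        d_fake = self.classifier(h_fake.detach())
        d_loss = (F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real)) +
                  F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake)))
        # generator loss: push fake samples toward the "real" label
        d_fake_g = self.classifier(h_fake)
        g_loss = F.binary_cross_entropy_with_logits(d_fake_g, torch.ones_like(d_fake_g))
        # feature-matching loss on the discriminator's hidden activations
        hidden_loss = F.mse_loss(h_fake, h_real.detach())
        return d_loss, g_loss, hidden_loss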
def loss_function(self, forward_ret, labels=None, loss=F.binary_cross_entropy):
    (x_orig, actions, rewards, done, t1, t2, qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar,
     pb_z1_b1_mu, pb_z1_b1_logvar, qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2,
     pt_z2_z1_mu, pt_z2_z1_logvar, pd_x2_z2, pd_g2_z2_mu, q1, q2) = forward_ret

    # replicate x multiple times
    x = x_orig.flatten(2, -1)
    x = x[None, ...].expand(self.flags.samples_per_seq, -1, -1, -1)  # size: copy, bs, time, dim
    x2 = torch.gather(x, 2, t2[..., None, None].expand(-1, -1, -1, x.size(3))).view(-1, x.size(3))

    kl_div_qs_pb = ops.kl_div_gaussian(qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar,
                                       pb_z1_b1_mu, pb_z1_b1_logvar)

    kl_shift_qb_pt = (ops.gaussian_log_prob(qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2) -
                      ops.gaussian_log_prob(pt_z2_z1_mu, pt_z2_z1_logvar, qb_z2_b2))

    pd_x2_z2 = pd_x2_z2.flatten(1, -1)
    bce = loss(pd_x2_z2, x2, reduction='none').sum(dim=1)
    bce_optimal = loss(x2, x2, reduction='none').sum(dim=1)
    bce_diff = bce - bce_optimal

    if self.rl:
        # Note: x[t], rewards[t] is a result of actions[t]
        # Q(s[t], a[t+1]) = r[t+1] + γ max_a Q(s[t+1], a)
        returns, is_weight = labels

        # use pd_g2_z2_mu for returns modeling
        returns_loss = (pd_g2_z2_mu.squeeze(1) - (10.0 * returns))**2

        # reward clipping for Atari
        clipped_rewards = rewards.clamp(-1.0, 1.0)

        t1_next = t1 + 1
        t2_next = t2 + 1
        with torch.no_grad():
            # size: bs, action_space
            q1_next_target, q2_next_target = self.target_net.q_and_z_b(
                x_orig, actions, rewards, done, t1_next, t2_next)[:2]
            q1_next_index, q2_next_index = self.model.q_and_z_b(
                x_orig, actions, rewards, done, t1_next, t2_next)[:2]
            q1_next_index = torch.argmax(q1_next_index, dim=1, keepdim=True)
            q2_next_index = torch.argmax(q2_next_index, dim=1, keepdim=True)

        done = done[None, ...].expand(self.flags.samples_per_seq, -1, -1)  # size: copy, bs, time
        done1_next = torch.gather(done, 2, t1_next[..., None]).view(-1)  # size: bs
        done2_next = torch.gather(done, 2, t2_next[..., None]).view(-1)  # size: bs

        # size: copy, bs, time
        clipped_rewards = clipped_rewards[None, ...].expand(self.flags.samples_per_seq, -1, -1)
        r1_next = torch.gather(clipped_rewards, 2, t1_next[..., None]).view(-1)  # size: bs
        r2_next = torch.gather(clipped_rewards, 2, t2_next[..., None]).view(-1)  # size: bs

        actions = actions[None, ...].expand(self.flags.samples_per_seq, -1, -1)  # size: copy, bs, time
        a1_next = torch.gather(actions, 2, t1_next[..., None]).view(-1)  # size: bs
        a2_next = torch.gather(actions, 2, t2_next[..., None]).view(-1)  # size: bs

        pred_q1 = torch.gather(q1, 1, a1_next[..., None]).view(-1)
        pred_q2 = torch.gather(q2, 1, a2_next[..., None]).view(-1)
        q1_next = torch.gather(q1_next_target, 1, q1_next_index).view(-1)
        q2_next = torch.gather(q2_next_target, 1, q2_next_index).view(-1)

        target_q1 = r1_next + self.flags.discount_factor * (1.0 - done1_next) * q1_next
        target_q2 = r2_next + self.flags.discount_factor * (1.0 - done2_next) * q2_next
        rl_loss = 0.5 * (F.smooth_l1_loss(pred_q1, target_q1, reduction='none') +
                         F.smooth_l1_loss(pred_q2, target_q2, reduction='none'))
        # errors for prioritized experience replay
        rl_errors = 0.5 * (torch.abs(pred_q1 - target_q1) +
                           torch.abs(pred_q2 - target_q2)).detach()
    else:
        returns_loss = 0.0
        rl_loss = 0.0
        is_weight = 1.0
        rl_errors = 0.0

    # multiply is_weight separately for ease of reporting
    returns_loss = is_weight * returns_loss
    bce_optimal = is_weight * bce_optimal
    bce_diff = is_weight * bce_diff
    kl_div_qs_pb = is_weight * kl_div_qs_pb
    kl_shift_qb_pt = is_weight * kl_shift_qb_pt
    rl_loss = is_weight * rl_loss

    beta = self.beta_decay.get_y(self.get_train_steps())
    tdvae_loss = bce_diff + returns_loss + beta * (kl_div_qs_pb + kl_shift_qb_pt)
    loss = self.flags.tdvae_weight * tdvae_loss + self.flags.rl_weight * rl_loss

    if self.rl:  # workaround to work with non-RL setting
        rl_loss = rl_loss.mean()
        returns_loss = returns_loss.mean()

    return collections.OrderedDict([('loss', loss.mean()),
                                    ('bce_diff', bce_diff.mean()),
                                    ('returns_loss', returns_loss),
                                    ('kl_div_qs_pb', kl_div_qs_pb.mean()),
                                    ('kl_shift_qb_pt', kl_shift_qb_pt.mean()),
                                    ('rl_loss', rl_loss),
                                    ('bce_optimal', bce_optimal.mean()),
                                    ('rl_errors', rl_errors)])
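# The RL branch above implements a Double-DQN style target: the online network
# (self.model) picks the argmax action at the next step and the target network
# (self.target_net) evaluates it. The stand-alone snippet below, with dummy
# tensors and illustrative names, shows that target computation in isolation.
import torch
import torch.nn.functional as F

bs, n_actions, gamma = 4, 6, 0.99
q_online_next = torch.randn(bs, n_actions)    # online-network Q-values at the next step
q_target_next = torch.randn(bs, n_actions)    # target-network Q-values at the next step
r_next = torch.rand(bs)                       # (clipped) rewards
done_next = torch.zeros(bs)                   # episode-termination mask
pred_q = torch.randn(bs, requires_grad=True)  # Q(s, a) gathered at the taken action

a_star = torch.argmax(q_online_next, dim=1, keepdim=True)  # action selection (online net)
q_next = torch.gather(q_target_next, 1, a_star).view(-1)   # action evaluation (target net)
target_q = r_next + gamma * (1.0 - done_next) * q_next
rl_loss = F.smooth_l1_loss(pred_q, target_q, reduction='none')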
def loss_function(self, forward_ret, labels=None, loss=F.binary_cross_entropy):
    (x_orig, actions, options, rewards, done, t1, t2, t_encodings, qs_z1_z2_b1_mu,
     qs_z1_z2_b1_logvar, pb_z1_b1_mu, pb_z1_b1_logvar, pb_z1_b1, qb_z2_b2_mu,
     qb_z2_b2_logvar, qb_z2_b2, pt_z2_z1_mu, pt_z2_z1_logvar, pd_x2_z2, pd_g2_z2_mu,
     q1, q2, option_recon_loss, o_mean, o_logvar, option) = forward_ret

    # replicate x multiple times
    x = x_orig.flatten(2, -1)
    x = x[None, ...].expand(self.flags.samples_per_seq, -1, -1, -1)  # size: copy, bs, time, dim
    x2 = torch.gather(x, 2, t2[..., None, None].expand(-1, -1, -1, x.size(3))).view(-1, x.size(3))

    kl_div_qs_pb = ops.kl_div_gaussian(qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar,
                                       pb_z1_b1_mu, pb_z1_b1_logvar)
    # kl_div_option = ops.kl_div_gaussian(o_mean, o_logvar)
    kl_div_option = 0.5 * torch.sum(o_mean**2 + o_logvar.exp() - o_logvar - 1)

    kl_shift_qb_pt = (ops.gaussian_log_prob(qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2) -
                      ops.gaussian_log_prob(pt_z2_z1_mu, pt_z2_z1_logvar, qb_z2_b2))

    pd_x2_z2 = pd_x2_z2.flatten(1, -1)
    bce = loss(pd_x2_z2, x2, reduction='none').sum(dim=1)
    bce_optimal = loss(x2, x2, reduction='none').sum(dim=1)
    bce_diff = bce - bce_optimal

    if self.adversarial and self.is_training():
        r_in = x2.view(x2.shape[0], x.shape[2], x.shape[3], x.shape[4])
        f_in = pd_x2_z2.view(x2.shape[0], x.shape[2], x.shape[3], x.shape[4])
        for _ in range(self.d_steps):
            d_loss, g_loss, hidden_loss = self.dnet.get_loss(r_in, f_in)
            d_loss.backward(retain_graph=True)
            self.adversarial_optim.step()
            self.adversarial_optim.zero_grad()
        bce_diff = hidden_loss  # XXX bce_diff added twice to loss?
    else:
        g_loss = 0
        hidden_loss = 0

    if self.model_based:
        # pred_z2, pred_g = self.model.predict_forward(pb_z1_b1, options, t_encodings)
        # with torch.no_grad():
        #     # size: bs, action_space
        #     t1_next = t1 + 1
        #     t2_next = t2 + 1
        #     _, pred_values = self.target_net.q_and_z_b(x_orig, actions, rewards, done,
        #                                                t1_next, t2_next)[:2]
        #     # target_q2 = r2_next + self.flags.discount_factor * (1.0 - done2_next) * q2_next

        # Note: x[t], rewards[t] is a result of actions[t]
        # Q(s[t], a[t+1]) = r[t+1] + γ max_a Q(s[t+1], a)
        returns, is_weight = labels

        # use pd_g2_z2_mu for returns modeling
        returns_loss = (pd_g2_z2_mu.squeeze(1) - (10.0 * returns))**2

        # XXX reward clipping hardcoded for Seaquest
        clipped_rewards = (rewards / 10.0).clamp(0.0, 2.0)

        t1_next = t1 + 1
        t2_next = t2 + 1
        with torch.no_grad():
            # size: bs, action_space
            q1_next_target, q2_next_target = self.target_net.q_and_z_b(
                x_orig, actions, rewards, done, t1_next, t2_next)[:2]
            q1_next_index, q2_next_index = self.model.q_and_z_b(
                x_orig, actions, rewards, done, t1_next, t2_next)[:2]
            q1_next_index = torch.argmax(q1_next_index, dim=1, keepdim=True)
            q2_next_index = torch.argmax(q2_next_index, dim=1, keepdim=True)

        done = done[None, ...].expand(self.flags.samples_per_seq, -1, -1)  # size: copy, bs, time
        done1_next = torch.gather(done, 2, t1_next[..., None]).view(-1)  # size: bs
        done2_next = torch.gather(done, 2, t2_next[..., None]).view(-1)  # size: bs

        # size: copy, bs, time
        clipped_rewards = clipped_rewards[None, ...].expand(self.flags.samples_per_seq, -1, -1)
        r1_next = torch.gather(clipped_rewards, 2, t1_next[..., None]).view(-1)  # size: bs
        r2_next = torch.gather(clipped_rewards, 2, t2_next[..., None]).view(-1)  # size: bs

        # actions = actions[None, ...].expand(self.flags.samples_per_seq, -1, -1)  # size: copy, bs, time
        # a1_next = torch.gather(actions, 2, t1_next[..., None]).view(-1)  # size: bs
        # a2_next = torch.gather(actions, 2, t2_next[..., None]).view(-1)  # size: bs
        #
        # pred_q1 = torch.gather(q1, 1, a1_next[..., None]).view(-1)
        # pred_q2 = torch.gather(q2, 1, a2_next[..., None]).view(-1)
        q1 = q1.squeeze(-1)
        q2 = q2.squeeze(-1)
        q1_next = torch.gather(q1_next_target, 1, q1_next_index).view(-1)
        q2_next = torch.gather(q2_next_target, 1, q2_next_index).view(-1)

        target_q1 = r1_next + self.flags.discount_factor * (1.0 - done1_next) * q1_next
        target_q2 = r2_next + self.flags.discount_factor * (1.0 - done2_next) * q2_next
        rl_loss = 0.5 * (F.smooth_l1_loss(q1, target_q1, reduction='none') +
                         F.smooth_l1_loss(q2, target_q2, reduction='none'))
        # errors for prioritized experience replay
        rl_errors = 0.5 * (torch.abs(q1 - target_q1) + torch.abs(q2 - target_q2)).detach()
    else:
        returns_loss = 0.0
        rl_loss = 0.0
        is_weight = 1.0
        rl_errors = 0.0

    # multiply is_weight separately for ease of reporting
    if torch.is_tensor(is_weight):  # is_weight is a plain float in the non-RL branch
        is_weight = is_weight.float()
    returns_loss = is_weight * returns_loss
    bce_optimal = is_weight * bce_optimal
    bce_diff = is_weight * bce_diff
    hidden_loss = is_weight * hidden_loss
    g_loss = is_weight * g_loss
    kl_div_qs_pb = is_weight * kl_div_qs_pb
    kl_shift_qb_pt = is_weight * kl_shift_qb_pt
    rl_loss = is_weight * rl_loss

    beta = self.beta_decay.get_y(self.get_train_steps())
    tdvae_loss = bce_diff + returns_loss + hidden_loss + self.d_weight * g_loss + beta * (
        kl_div_qs_pb + kl_shift_qb_pt)
    option_loss = option_recon_loss + beta * kl_div_option * 0.001
    loss = self.flags.tdvae_weight * tdvae_loss + self.flags.rl_weight * rl_loss + option_loss

    if self.rl:  # workaround to work with non-RL setting
        rl_loss = rl_loss.mean()
        returns_loss = returns_loss.mean()

    return collections.OrderedDict([
        ('loss', loss.mean()),
        ('bce_diff', bce_diff.mean()),
        ('returns_loss', returns_loss),
        ('kl_div_qs_pb', kl_div_qs_pb.mean()),
        ('kl_shift_qb_pt', kl_shift_qb_pt.mean()),
        ('kl_div_option', kl_div_option.mean()),
        ('reconstruction_option', option_recon_loss.mean()),
        ('rl_loss', rl_loss),
        ('bce_optimal', bce_optimal.mean()),
        ('rl_errors', rl_errors)
    ])
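# `self.beta_decay.get_y(step)` is used above to anneal the KL weight; the
# schedule class itself is defined elsewhere in the project. A hypothetical
# linear schedule with the same get_y(step) interface might look like the
# sketch below (the start/end values and horizon in the usage comment are made
# up for illustration):
class LinearScheduleSketch:
    def __init__(self, start_y, end_y, num_steps):
        self.start_y = start_y
        self.end_y = end_y
        self.num_steps = num_steps

    def get_y(self, step):
        # clamp progress to [0, 1] and interpolate linearly between start and end
        frac = min(max(step / float(self.num_steps), 0.0), 1.0)
        return self.start_y + frac * (self.end_y - self.start_y)


# e.g. decay beta from 1.0 to 0.1 over the first 100_000 training steps:
# beta_decay = LinearScheduleSketch(1.0, 0.1, 100_000)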
def loss_function(self, forward_ret, labels=None):
    (x_orig, actions, rewards, done, t1, t2, qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar,
     pb_z1_b1_mu, pb_z1_b1_logvar, qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2,
     pt_z2_z1_mu, pt_z2_z1_logvar, pd_x2_z2, pd_g2_z2_mu, q1, q2) = forward_ret

    # replicate x multiple times
    x = x_orig.flatten(3, -1)
    x = x[None, ...].expand(self.flags.samples_per_seq, -1, -1, -1, -1)  # size: copy, bs, time, dim
    x2 = torch.gather(x, 2, t2[..., None, None, None].expand(-1, -1, -1, x.size(3), x.size(4)))
    x2 = x2.long().view(-1, x.size(3), x.size(4))

    kl_div_qs_pb = ops.kl_div_gaussian(qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar,
                                       pb_z1_b1_mu, pb_z1_b1_logvar)

    kl_shift_qb_pt = (ops.gaussian_log_prob(qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2) -
                      ops.gaussian_log_prob(pt_z2_z1_mu, pt_z2_z1_logvar, qb_z2_b2))

    ce_1 = F.cross_entropy(pd_x2_z2[0], x2[:, 0])
    ce_2 = F.cross_entropy(pd_x2_z2[1], x2[:, 1])
    ce_3 = F.cross_entropy(pd_x2_z2[2], x2[:, 2])
    obs_ce = F.cross_entropy(pd_x2_z2[3], x2[:, 3]) / x_orig.shape[1]
    total_ce = ce_1 + ce_2 + ce_3 + obs_ce

    if self.adversarial and self.is_training():
        r_in = x2.view(x2.shape[0], x.shape[2], x.shape[3], x.shape[4])
        f_in = pd_x2_z2.view(x2.shape[0], x.shape[2], x.shape[3], x.shape[4])
        for _ in range(self.d_steps):
            d_loss, g_loss, hidden_loss = self.dnet.get_loss(r_in, f_in)
            d_loss.backward(retain_graph=True)
            # print(d_loss, g_loss)
            self.adversarial_optim.step()
            self.adversarial_optim.zero_grad()
        bce_diff = hidden_loss  # XXX bce_diff added twice to loss?
    else:
        g_loss = 0
        hidden_loss = 0

    if self.rl:
        # Note: x[t], rewards[t] is a result of actions[t]
        # Q(s[t], a[t+1]) = r[t+1] + γ max_a Q(s[t+1], a)
        returns, is_weight = labels

        # use pd_g2_z2_mu for returns modeling
        returns_loss = (pd_g2_z2_mu.squeeze(1) - (10.0 * returns)) ** 2

        # reward clipping for Atari
        clipped_rewards = rewards.clamp(-1.0, 1.0)

        t1_next = t1 + 1
        t2_next = t2 + 1
        with torch.no_grad():
            # size: bs, action_space
            q1_next_target, q2_next_target = self.target_net.q_and_z_b(
                x_orig, actions, rewards, done, t1_next, t2_next)[:2]
            q1_next_index, q2_next_index = self.model.q_and_z_b(
                x_orig, actions, rewards, done, t1_next, t2_next)[:2]
            q1_next_index = torch.argmax(q1_next_index, dim=1, keepdim=True)
            q2_next_index = torch.argmax(q2_next_index, dim=1, keepdim=True)

        done = done[None, ...].expand(self.flags.samples_per_seq, -1, -1)  # size: copy, bs, time
        done1_next = torch.gather(done, 2, t1_next[..., None]).view(-1)  # size: bs
        done2_next = torch.gather(done, 2, t2_next[..., None]).view(-1)  # size: bs

        # size: copy, bs, time
        clipped_rewards = clipped_rewards[None, ...].expand(self.flags.samples_per_seq, -1, -1)
        r1_next = torch.gather(clipped_rewards, 2, t1_next[..., None]).view(-1)  # size: bs
        r2_next = torch.gather(clipped_rewards, 2, t2_next[..., None]).view(-1)  # size: bs

        actions = actions[None, ...].expand(self.flags.samples_per_seq, -1, -1)  # size: copy, bs, time
        a1_next = torch.gather(actions, 2, t1_next[..., None]).view(-1)  # size: bs
        a2_next = torch.gather(actions, 2, t2_next[..., None]).view(-1)  # size: bs

        pred_q1 = torch.gather(q1, 1, a1_next[..., None]).view(-1)
        pred_q2 = torch.gather(q2, 1, a2_next[..., None]).view(-1)
        q1_next = torch.gather(q1_next_target, 1, q1_next_index).view(-1)
        q2_next = torch.gather(q2_next_target, 1, q2_next_index).view(-1)

        target_q1 = r1_next + self.flags.discount_factor * (1.0 - done1_next) * q1_next
        target_q2 = r2_next + self.flags.discount_factor * (1.0 - done2_next) * q2_next
        rl_loss = 0.5 * (F.smooth_l1_loss(pred_q1, target_q1, reduction='none') +
                         F.smooth_l1_loss(pred_q2, target_q2, reduction='none'))
        # errors for prioritized experience replay
        rl_errors = 0.5 * (torch.abs(pred_q1 - target_q1) +
                           torch.abs(pred_q2 - target_q2)).detach()
    else:
        returns_loss = 0.0
        rl_loss = 0.0
        is_weight = 1.0
        rl_errors = 0.0

    # multiply is_weight separately for ease of reporting
    returns_loss = is_weight * returns_loss
    total_ce = is_weight * total_ce
    hidden_loss = is_weight * hidden_loss
    g_loss = is_weight * g_loss
    kl_div_qs_pb = is_weight * kl_div_qs_pb
    kl_shift_qb_pt = is_weight * kl_shift_qb_pt
    rl_loss = is_weight * rl_loss

    beta = self.beta_decay.get_y(self.get_train_steps())
    tdvae_loss = total_ce + returns_loss + hidden_loss + self.d_weight * g_loss + beta * (
        kl_div_qs_pb + kl_shift_qb_pt)
    loss = self.flags.tdvae_weight * tdvae_loss + self.flags.rl_weight * rl_loss

    if self.rl:  # workaround to work with non-RL setting
        rl_loss = rl_loss.mean()
        returns_loss = returns_loss.mean()

    return collections.OrderedDict([('loss', loss.mean()),
                                    ('total_ce', total_ce.mean()),
                                    ('returns_loss', returns_loss),
                                    ('kl_div_qs_pb', kl_div_qs_pb.mean()),
                                    ('kl_shift_qb_pt', kl_shift_qb_pt.mean()),
                                    ('rl_loss', rl_loss),
                                    # ('bce_optimal', bce_optimal.mean()),
                                    ('rl_errors', rl_errors)])
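# The `rl_errors` returned by the RL variants above are intended for a
# prioritized experience replay buffer, which is also where the `is_weight`
# importance weights come from. The helpers below are only an illustrative
# sketch of that interaction (the exponents alpha/beta and the epsilon are
# assumed values, not the project's settings):
import torch


def errors_to_priorities(rl_errors, alpha=0.6, eps=1e-6):
    """Turn per-sample TD errors into sampling priorities p_i = (|delta_i| + eps)^alpha."""
    return (rl_errors.abs() + eps) ** alpha


def importance_weights(sample_probs, buffer_size, beta=0.4):
    """Importance-sampling weights w_i = (N * P(i))^-beta, normalized to a max of 1."""
    w = (buffer_size * sample_probs) ** (-beta)
    return w / w.max()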