def test_non_mean_field_bern_normal_elbo_gradient(enumerate1, pi1, pi2, pi3, include_z=True):
    pyro.clear_param_store()
    num_particles = 10000

    def model():
        with pyro.iarange("particles", num_particles):
            q3 = pyro.param("q3", torch.tensor(pi3, requires_grad=True))
            y = pyro.sample("y", dist.Bernoulli(q3).expand_by([num_particles]))
            if include_z:
                pyro.sample("z", dist.Normal(0.55 * y + q3, 1.0))

    def guide():
        q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
        q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
        with pyro.iarange("particles", num_particles):
            y = pyro.sample("y", dist.Bernoulli(q1).expand_by([num_particles]),
                            infer={"enumerate": enumerate1})
            if include_z:
                pyro.sample("z", dist.Normal(q2 * y + 0.10, 1.0))

    logger.info("Computing gradients using surrogate loss")
    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    elbo.loss_and_grads(model, guide)
    actual_grad_q1 = pyro.param('q1').grad / num_particles
    if include_z:
        actual_grad_q2 = pyro.param('q2').grad / num_particles
    actual_grad_q3 = pyro.param('q3').grad / num_particles

    logger.info("Computing analytic gradients")
    q1 = torch.tensor(pi1, requires_grad=True)
    q2 = torch.tensor(pi2, requires_grad=True)
    q3 = torch.tensor(pi3, requires_grad=True)
    elbo = kl_divergence(dist.Bernoulli(q1), dist.Bernoulli(q3))
    if include_z:
        elbo = elbo + q1 * kl_divergence(dist.Normal(q2 + 0.10, 1.0), dist.Normal(q3 + 0.55, 1.0))
        elbo = elbo + (1.0 - q1) * kl_divergence(dist.Normal(0.10, 1.0), dist.Normal(q3, 1.0))
        expected_grad_q1, expected_grad_q2, expected_grad_q3 = grad(elbo, [q1, q2, q3])
    else:
        expected_grad_q1, expected_grad_q3 = grad(elbo, [q1, q3])

    prec = 0.04 if enumerate1 is None else 0.02
    assert_equal(actual_grad_q1, expected_grad_q1, prec=prec, msg="".join([
        "\nq1 expected = {}".format(expected_grad_q1.data.cpu().numpy()),
        "\nq1 actual = {}".format(actual_grad_q1.data.cpu().numpy()),
    ]))
    if include_z:
        assert_equal(actual_grad_q2, expected_grad_q2, prec=prec, msg="".join([
            "\nq2 expected = {}".format(expected_grad_q2.data.cpu().numpy()),
            "\nq2 actual = {}".format(actual_grad_q2.data.cpu().numpy()),
        ]))
    assert_equal(actual_grad_q3, expected_grad_q3, prec=prec, msg="".join([
        "\nq3 expected = {}".format(expected_grad_q3.data.cpu().numpy()),
        "\nq3 actual = {}".format(actual_grad_q3.data.cpu().numpy()),
    ]))
def test_elbo_bern(quantity, enumerate1):
    pyro.clear_param_store()
    num_particles = 1 if enumerate1 else 10000
    prec = 0.001 if enumerate1 else 0.1
    q = pyro.param("q", torch.tensor(0.5, requires_grad=True))
    kl = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(0.25))

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(0.25).expand_by([num_particles]))

    @config_enumerate(default=enumerate1)
    def guide():
        q = pyro.param("q")
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(q).expand_by([num_particles]))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))

    if quantity == "loss":
        actual = elbo.loss(model, guide) / num_particles
        expected = kl.item()
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected),
            "\n actual = {}".format(actual),
        ]))
    else:
        elbo.loss_and_grads(model, guide)
        actual = q.grad / num_particles
        expected = grad(kl, [q])[0]
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected.detach().cpu().numpy()),
            "\n actual = {}".format(actual.detach().cpu().numpy()),
        ]))
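# A minimal sanity sketch (not part of the tests above): with no observed sites,
# the negative ELBO reduces to KL(guide || model), which is the identity these
# Pyro tests check against. Plain torch.distributions; names are illustrative.
import torch
from torch.distributions import Bernoulli, kl_divergence

q = torch.tensor(0.5)
guide_d, model_d = Bernoulli(q), Bernoulli(0.25)
z = guide_d.sample((10000,))
mc_estimate = (guide_d.log_prob(z) - model_d.log_prob(z)).mean()  # Monte Carlo -ELBO
print(mc_estimate.item(), kl_divergence(guide_d, model_d).item())  # approximately equal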
def test_non_mean_field_bern_bern_elbo_gradient(enumerate1, pi1, pi2):
    pyro.clear_param_store()
    num_particles = 1 if enumerate1 else 20000

    def model():
        with pyro.iarange("particles", num_particles):
            y = pyro.sample("y", dist.Bernoulli(0.33).expand_by([num_particles]))
            pyro.sample("z", dist.Bernoulli(0.55 * y + 0.10))

    def guide():
        q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
        q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
        with pyro.iarange("particles", num_particles):
            y = pyro.sample("y", dist.Bernoulli(q1).expand_by([num_particles]))
            pyro.sample("z", dist.Bernoulli(q2 * y + 0.10))

    logger.info("Computing gradients using surrogate loss")
    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    elbo.loss_and_grads(model, config_enumerate(guide, default=enumerate1))
    actual_grad_q1 = pyro.param('q1').grad / num_particles
    actual_grad_q2 = pyro.param('q2').grad / num_particles

    logger.info("Computing analytic gradients")
    q1 = torch.tensor(pi1, requires_grad=True)
    q2 = torch.tensor(pi2, requires_grad=True)
    elbo = kl_divergence(dist.Bernoulli(q1), dist.Bernoulli(0.33))
    elbo = elbo + q1 * kl_divergence(dist.Bernoulli(q2 + 0.10), dist.Bernoulli(0.65))
    elbo = elbo + (1.0 - q1) * kl_divergence(dist.Bernoulli(0.10), dist.Bernoulli(0.10))
    expected_grad_q1, expected_grad_q2 = grad(elbo, [q1, q2])

    prec = 0.03 if enumerate1 is None else 0.001
    assert_equal(actual_grad_q1, expected_grad_q1, prec=prec, msg="".join([
        "\nq1 expected = {}".format(expected_grad_q1.data.cpu().numpy()),
        "\nq1 actual = {}".format(actual_grad_q1.data.cpu().numpy()),
    ]))
    assert_equal(actual_grad_q2, expected_grad_q2, prec=prec, msg="".join([
        "\nq2 expected = {}".format(expected_grad_q2.data.cpu().numpy()),
        "\nq2 actual = {}".format(actual_grad_q2.data.cpu().numpy()),
    ]))
def test_elbo_categoricals(enumerate1, enumerate2, enumerate3, max_iarange_nesting):
    pyro.clear_param_store()
    p1 = torch.tensor([0.6, 0.4])
    p2 = torch.tensor([0.3, 0.3, 0.4])
    p3 = torch.tensor([0.1, 0.2, 0.3, 0.4])
    q1 = pyro.param("q1", torch.tensor([0.4, 0.6], requires_grad=True))
    q2 = pyro.param("q2", torch.tensor([0.4, 0.3, 0.3], requires_grad=True))
    q3 = pyro.param("q3", torch.tensor([0.4, 0.3, 0.2, 0.1], requires_grad=True))

    def model():
        pyro.sample("x1", dist.Categorical(p1))
        pyro.sample("x2", dist.Categorical(p2))
        pyro.sample("x3", dist.Categorical(p3))

    def guide():
        pyro.sample("x1", dist.Categorical(pyro.param("q1")), infer={"enumerate": enumerate1})
        pyro.sample("x2", dist.Categorical(pyro.param("q2")), infer={"enumerate": enumerate2})
        pyro.sample("x3", dist.Categorical(pyro.param("q3")), infer={"enumerate": enumerate3})

    kl = (kl_divergence(dist.Categorical(q1), dist.Categorical(p1)) +
          kl_divergence(dist.Categorical(q2), dist.Categorical(p2)) +
          kl_divergence(dist.Categorical(q3), dist.Categorical(p3)))
    expected_loss = kl.item()
    expected_grads = grad(kl, [q1, q2, q3])

    elbo = TraceEnum_ELBO(max_iarange_nesting=max_iarange_nesting,
                          strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]))
    actual_loss = elbo.loss_and_grads(model, guide)
    actual_grads = [q1.grad, q2.grad, q3.grad]

    assert_equal(actual_loss, expected_loss, prec=0.001, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n actual loss = {}".format(actual_loss),
    ]))
    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        assert_equal(actual_grad, expected_grad, prec=0.001, msg="".join([
            "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
            "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()),
        ]))
def test_elbo_rsvi(enumerate1):
    pyro.clear_param_store()
    num_particles = 40000
    prec = 0.01 if enumerate1 else 0.02
    q = pyro.param("q", torch.tensor(0.5, requires_grad=True))
    a = pyro.param("a", torch.tensor(1.5, requires_grad=True))
    kl1 = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(0.25))
    kl2 = kl_divergence(dist.Gamma(a, 1.0), dist.Gamma(0.5, 1.0))

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(0.25).expand_by([num_particles]))
            pyro.sample("y", dist.Gamma(0.50, 1.0).expand_by([num_particles]))

    @config_enumerate(default=enumerate1)
    def guide():
        q = pyro.param("q")
        a = pyro.param("a")
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(q).expand_by([num_particles]))
            pyro.sample("y", ShapeAugmentedGamma(a, torch.tensor(1.0)).expand_by([num_particles]))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    elbo.loss_and_grads(model, guide)

    actual_q = q.grad / num_particles
    expected_q = grad(kl1, [q])[0]
    assert_equal(actual_q, expected_q, prec=prec, msg="".join([
        "\nexpected q.grad = {}".format(expected_q.detach().cpu().numpy()),
        "\n actual q.grad = {}".format(actual_q.detach().cpu().numpy()),
    ]))

    actual_a = a.grad / num_particles
    expected_a = grad(kl2, [a])[0]
    assert_equal(actual_a, expected_a, prec=prec, msg="".join([
        "\nexpected a.grad = {}".format(expected_a.detach().cpu().numpy()),
        "\n actual a.grad = {}".format(actual_a.detach().cpu().numpy()),
    ]))
def test_elbo_iarange_iarange(outer_dim, inner_dim, enumerate1, enumerate2, enumerate3, enumerate4):
    pyro.clear_param_store()
    num_particles = 1 if all([enumerate1, enumerate2, enumerate3, enumerate4]) else 100000
    q = pyro.param("q", torch.tensor(0.75, requires_grad=True))
    p = 0.2693204236205713  # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5

    def model():
        d = dist.Bernoulli(p)
        with pyro.iarange("particles", num_particles):
            context1 = pyro.iarange("outer", outer_dim, dim=-2)
            context2 = pyro.iarange("inner", inner_dim, dim=-3)
            pyro.sample("w", d.expand_by([num_particles]))
            with context1:
                pyro.sample("x", d.expand_by([outer_dim, num_particles]))
            with context2:
                pyro.sample("y", d.expand_by([inner_dim, 1, num_particles]))
            with context1, context2:
                pyro.sample("z", d.expand_by([inner_dim, outer_dim, num_particles]))

    def guide():
        d = dist.Bernoulli(pyro.param("q"))
        with pyro.iarange("particles", num_particles):
            context1 = pyro.iarange("outer", outer_dim, dim=-2)
            context2 = pyro.iarange("inner", inner_dim, dim=-3)
            pyro.sample("w", d.expand_by([num_particles]), infer={"enumerate": enumerate1})
            with context1:
                pyro.sample("x", d.expand_by([outer_dim, num_particles]),
                            infer={"enumerate": enumerate2})
            with context2:
                pyro.sample("y", d.expand_by([inner_dim, 1, num_particles]),
                            infer={"enumerate": enumerate3})
            with context1, context2:
                pyro.sample("z", d.expand_by([inner_dim, outer_dim, num_particles]),
                            infer={"enumerate": enumerate4})

    kl_node = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p))
    kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node
    expected_loss = kl.item()
    expected_grad = grad(kl, [q])[0]

    elbo = TraceEnum_ELBO(max_iarange_nesting=3,
                          strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]))
    actual_loss = elbo.loss_and_grads(model, guide) / num_particles
    actual_grad = pyro.param('q').grad / num_particles

    assert_equal(actual_loss, expected_loss, prec=0.1, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n actual loss = {}".format(actual_loss),
    ]))
    assert_equal(actual_grad, expected_grad, prec=0.1, msg="".join([
        "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
        "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()),
    ]))
def test_elbo_irange_irange(outer_dim, inner_dim, enumerate1, enumerate2, enumerate3):
    pyro.clear_param_store()
    num_particles = 1 if all([enumerate1, enumerate2, enumerate3]) else 50000
    q = pyro.param("q", torch.tensor(0.75, requires_grad=True))
    p = 0.2693204236205713  # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5

    def model():
        # The model mirrors the guide's irange structure so sample site names line up.
        with pyro.iarange("particles", num_particles):
            pyro.sample("x", dist.Bernoulli(p).expand_by([num_particles]))
            inner_irange = pyro.irange("inner", inner_dim)
            for i in pyro.irange("outer", outer_dim):
                pyro.sample("y_{}".format(i), dist.Bernoulli(p).expand_by([num_particles]))
                for j in inner_irange:
                    pyro.sample("z_{}_{}".format(i, j), dist.Bernoulli(p).expand_by([num_particles]))

    def guide():
        q = pyro.param("q")
        with pyro.iarange("particles", num_particles):
            pyro.sample("x", dist.Bernoulli(q).expand_by([num_particles]),
                        infer={"enumerate": enumerate1})
            inner_irange = pyro.irange("inner", inner_dim)
            for i in pyro.irange("outer", outer_dim):
                pyro.sample("y_{}".format(i), dist.Bernoulli(q).expand_by([num_particles]),
                            infer={"enumerate": enumerate2})
                for j in inner_irange:
                    pyro.sample("z_{}_{}".format(i, j), dist.Bernoulli(q).expand_by([num_particles]),
                                infer={"enumerate": enumerate3})

    kl = (1 + outer_dim * (1 + inner_dim)) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p))
    expected_loss = kl.item()
    expected_grad = grad(kl, [q])[0]

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]))
    actual_loss = elbo.loss_and_grads(model, guide) / num_particles
    actual_grad = pyro.param('q').grad / num_particles

    assert_equal(actual_loss, expected_loss, prec=0.1, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n actual loss = {}".format(actual_loss),
    ]))
    assert_equal(actual_grad, expected_grad, prec=0.1, msg="".join([
        "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
        "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()),
    ]))
def test_svi_enum(Elbo, irange_dim, enumerate1, enumerate2):
    pyro.clear_param_store()
    num_particles = 10
    q = pyro.param("q", torch.tensor(0.75), constraint=constraints.unit_interval)
    p = 0.2693204236205713  # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5

    def model():
        pyro.sample("x", dist.Bernoulli(p))
        for i in pyro.irange("irange", irange_dim):
            pyro.sample("y_{}".format(i), dist.Bernoulli(p))

    def guide():
        q = pyro.param("q")
        pyro.sample("x", dist.Bernoulli(q), infer={"enumerate": enumerate1})
        for i in pyro.irange("irange", irange_dim):
            pyro.sample("y_{}".format(i), dist.Bernoulli(q), infer={"enumerate": enumerate2})

    kl = (1 + irange_dim) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p))
    expected_loss = kl.item()
    expected_grad = grad(kl, [q.unconstrained()])[0]

    inner_particles = 2
    outer_particles = num_particles // inner_particles
    elbo = TraceEnum_ELBO(max_iarange_nesting=0,
                          strict_enumeration_warning=any([enumerate1, enumerate2]),
                          num_particles=inner_particles)
    actual_loss = sum(elbo.loss_and_grads(model, guide)
                      for i in range(outer_particles)) / outer_particles
    actual_grad = q.unconstrained().grad / outer_particles

    assert_equal(actual_loss, expected_loss, prec=0.3, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n actual loss = {}".format(actual_loss),
    ]))
    assert_equal(actual_grad, expected_grad, prec=0.5, msg="".join([
        "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
        "\n actual grad = {}".format(actual_grad.detach().cpu().numpy()),
    ]))
def test_elbo_berns(enumerate1, enumerate2, enumerate3):
    pyro.clear_param_store()
    num_particles = 1 if all([enumerate1, enumerate2, enumerate3]) else 10000
    prec = 0.001 if all([enumerate1, enumerate2, enumerate3]) else 0.1
    q = pyro.param("q", torch.tensor(0.75, requires_grad=True))

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("x1", dist.Bernoulli(0.1).expand_by([num_particles]))
            pyro.sample("x2", dist.Bernoulli(0.2).expand_by([num_particles]))
            pyro.sample("x3", dist.Bernoulli(0.3).expand_by([num_particles]))

    def guide():
        q = pyro.param("q")
        with pyro.iarange("particles", num_particles):
            pyro.sample("x1", dist.Bernoulli(q).expand_by([num_particles]), infer={"enumerate": enumerate1})
            pyro.sample("x2", dist.Bernoulli(q).expand_by([num_particles]), infer={"enumerate": enumerate2})
            pyro.sample("x3", dist.Bernoulli(q).expand_by([num_particles]), infer={"enumerate": enumerate3})

    kl = sum(kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) for p in [0.1, 0.2, 0.3])
    expected_loss = kl.item()
    expected_grad = grad(kl, [q])[0]

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]))
    actual_loss = elbo.loss_and_grads(model, guide) / num_particles
    actual_grad = q.grad / num_particles

    assert_equal(actual_loss, expected_loss, prec=prec, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n actual loss = {}".format(actual_loss),
    ]))
    assert_equal(actual_grad, expected_grad, prec=prec, msg="".join([
        "\nexpected grads = {}".format(expected_grad.detach().cpu().numpy()),
        "\n actual grads = {}".format(actual_grad.detach().cpu().numpy()),
    ]))
def elbo(self, x: torch.Tensor, p_dists: List, q_dists: List, lv_z: List, lv_g: List, lv_bg: List, pa_recon: List) -> Tuple: bs = x.size(0) p_global_all, p_pres_given_g_probs_reshaped, \ p_where_given_g, p_depth_given_g, p_what_given_g, p_bg = p_dists q_global_all, q_pres_given_x_and_g_probs_reshaped, \ q_where_given_x_and_g, q_depth_given_x_and_g, q_what_given_x_and_g, q_bg = q_dists y, y_nobg, alpha_map, bg = pa_recon if self.args.log.phase_nll: # (bs, dim, num_cell, num_cell) z_pres, _, z_depth, z_what, z_where_origin = lv_z # (bs * num_cell * num_cell, dim) z_pres_reshape = z_pres.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_pres_dim) z_depth_reshape = z_depth.permute(0, 2, 3, 1).reshape( -1, self.args.z.z_depth_dim) z_what_reshape = z_what.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_what_dim) z_where_origin_reshape = z_where_origin.permute( 0, 2, 3, 1).reshape(-1, self.args.z.z_where_dim) # (bs, dim, 1, 1) z_bg = lv_bg[0] # (bs, step, dim, 1, 1) z_g = lv_g[0] else: z_pres, _, _, _, z_where_origin = lv_z z_pres_reshape = z_pres.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_pres_dim) if self.args.train.p_pres_anneal_end_step != 0: self.aux_p_pres_probs = linear_schedule_tensor( self.args.train.global_step, self.args.train.p_pres_anneal_start_step, self.args.train.p_pres_anneal_end_step, self.args.train.p_pres_anneal_start_value, self.args.train.p_pres_anneal_end_value, self.aux_p_pres_probs.device) if self.args.train.aux_p_scale_anneal_end_step != 0: aux_p_scale_mean = linear_schedule_tensor( self.args.train.global_step, self.args.train.aux_p_scale_anneal_start_step, self.args.train.aux_p_scale_anneal_end_step, self.args.train.aux_p_scale_anneal_start_value, self.args.train.aux_p_scale_anneal_end_value, self.aux_p_where_mean.device) self.aux_p_where_mean[:, 0] = aux_p_scale_mean auxiliary_prior_z_pres_probs = self.aux_p_pres_probs[None][ None, :].expand(bs * self.args.arch.num_cell**2, -1) aux_kl_pres = kl_divergence_bern_bern( q_pres_given_x_and_g_probs_reshaped, auxiliary_prior_z_pres_probs) aux_kl_where = kl_divergence( q_where_given_x_and_g, self.aux_p_where) * z_pres_reshape.clamp(min=1e-5) aux_kl_depth = kl_divergence( q_depth_given_x_and_g, self.aux_p_depth) * z_pres_reshape.clamp(min=1e-5) aux_kl_what = kl_divergence( q_what_given_x_and_g, self.aux_p_what) * z_pres_reshape.clamp(min=1e-5) kl_pres = kl_divergence_bern_bern(q_pres_given_x_and_g_probs_reshaped, p_pres_given_g_probs_reshaped) kl_where = kl_divergence(q_where_given_x_and_g, p_where_given_g) kl_depth = kl_divergence(q_depth_given_x_and_g, p_depth_given_g) kl_what = kl_divergence(q_what_given_x_and_g, p_what_given_g) kl_global_all = kl_divergence(q_global_all, p_global_all) if self.args.arch.phase_background: kl_bg = kl_divergence(q_bg, p_bg) aux_kl_bg = kl_divergence(q_bg, self.aux_p_bg) else: kl_bg = self.background.new_zeros(bs, 1) aux_kl_bg = self.background.new_zeros(bs, 1) log_like = Normal(y, self.args.const.likelihood_sigma).log_prob(x) log_imp_list = [] if self.args.log.phase_nll: log_pres_prior = z_pres_reshape * torch.log(p_pres_given_g_probs_reshaped + self.args.const.eps) + \ (1 - z_pres_reshape) * torch.log(1 - p_pres_given_g_probs_reshaped + self.args.const.eps) log_pres_pos = z_pres_reshape * torch.log(q_pres_given_x_and_g_probs_reshaped + self.args.const.eps) + \ (1 - z_pres_reshape) * torch.log( 1 - q_pres_given_x_and_g_probs_reshaped + self.args.const.eps) log_imp_pres = log_pres_prior - log_pres_pos log_imp_depth = p_depth_given_g.log_prob(z_depth_reshape) - \ 
q_depth_given_x_and_g.log_prob(z_depth_reshape) log_imp_what = p_what_given_g.log_prob(z_what_reshape) - \ q_what_given_x_and_g.log_prob(z_what_reshape) log_imp_where = p_where_given_g.log_prob(z_where_origin_reshape) - \ q_where_given_x_and_g.log_prob(z_where_origin_reshape) if self.args.arch.phase_background: log_imp_bg = p_bg.log_prob(z_bg) - q_bg.log_prob(z_bg) else: log_imp_bg = x.new_zeros(bs, 1) log_imp_g = p_global_all.log_prob(z_g) - q_global_all.log_prob(z_g) log_imp_list = [ log_imp_pres.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(1), log_imp_depth.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(1), log_imp_what.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(1), log_imp_where.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(1), log_imp_bg.flatten(start_dim=1).sum(1), log_imp_g.flatten(start_dim=1).sum(1), ] return log_like.flatten(start_dim=1).sum(1), \ [ aux_kl_pres.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum( -1), aux_kl_where.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum( -1), aux_kl_depth.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum( -1), aux_kl_what.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum( -1), aux_kl_bg.flatten(start_dim=1).sum(-1), kl_pres.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1), kl_where.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1), kl_depth.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1), kl_what.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1), kl_global_all.flatten(start_dim=2).sum(-1), kl_bg.flatten(start_dim=1).sum(-1) ], log_imp_list
l2_sum = 0
kl_sum = 0
disc_inv_sum = 0
disc_latent_sum = 0
total_sum = 0
for i in range(args.samples_per_epoch):
    x, target = next(train_loader)
    x = x.cuda()
    target = target.cuda()
    z_inv, mu_logvar, hidden = dvml_model(x)
    q_dist = utils.get_normal_from_params(mu_logvar)
    p_dist = dist.Normal(0, 1)
    # KL loss between the variational posterior and the standard normal prior
    kl_loss = dist.kl_divergence(q_dist, p_dist).sum(-1).mean()
    # Discriminative (triplet) loss on the invariant part of the embedding
    disc_inv_loss = triplet_loss(z_inv, target)
    if phase1:
        z_var = q_dist.sample()
        z = z_inv + z_var
        decoded = decoder_model(z.detach())
        # L2 reconstruction loss
        l2_loss = F.mse_loss(hidden, decoded)
    else:
        z_var = q_dist.sample((20,))
        z = z_inv[None] + z_var
        decoded = decoder_model(z.reshape([-1, args.z_dims]))
        decoded = decoded.reshape([20, -1, 1024])
def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate estimated Q-Values # shape = (bs, self.n_agents, -1) mac_out = [] mu_out = [] sigma_out = [] logits_out = [] m_sample_out = [] g_out = [] #reconstruct_losses = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): if self.args.comm and self.args.use_IB: # agent_outs, (mu, sigma), logits, m_sample = self.mac.forward(batch, t=t) # mu_out.append(mu) # sigma_out.append(sigma) # logits_out.append(logits) # m_sample_out.append(m_sample) #reconstruct_losses.append((info['reconstruct_loss']**2).sum()) agent_outs, info = self.mac.forward(batch, t=t) mu_out.append(info['mu']) sigma_out.append(info['sigma']) logits_out.append(info['logits']) m_sample_out.append(info['m_sample']) else: agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time if self.args.use_IB: mu_out = th.stack(mu_out, dim=1)[:, :-1] # Concat over time sigma_out = th.stack(sigma_out, dim=1)[:, :-1] # Concat over time logits_out = th.stack(logits_out, dim=1)[:, :-1] m_sample_out = th.stack(m_sample_out, dim=1)[:, :-1] # Pick the Q-Values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim # I believe that code up to here is right... # Q values are right, the main issue is to calculate loss for message... # Calculate the Q-Values necessary for the target target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): if self.args.comm and self.args.use_IB: #target_agent_outs, (target_mu, target_sigma), target_logits, target_m_sample = \ # self.target_mac.forward(batch, t=t) target_agent_outs, target_info = self.target_mac.forward(batch, t=t) else: target_agent_outs = self.target_mac.forward(batch, t=t) target_mac_out.append(target_agent_outs) # label label_target_max_out = th.stack(target_mac_out[:-1], dim=1) label_target_max_out[avail_actions[:, :-1] == 0] = -9999999 label_target_actions = label_target_max_out.max(dim=3, keepdim=True)[1] # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time # Mask out unavailable actions target_mac_out[avail_actions[:, 1:] == 0] = -9999999 # Max over target Q-Values if self.args.double_q: # Get actions that maximise live Q (for double q-learning) mac_out_detach = mac_out.clone().detach() mac_out_detach[avail_actions == 0] = -9999999 cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1] target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) else: target_max_qvals = target_mac_out.max(dim=3)[0] # Mix if self.mixer is not None: chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:]) # Calculate 1-step Q-Learning targets targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals # Td-error td_error = (chosen_action_qvals - targets.detach()) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Normal L2 loss, take mean over actual 
data loss = (masked_td_error ** 2).sum() / mask.sum() if self.args.only_downstream or not self.args.use_IB: expressiveness_loss = th.Tensor([0.]) compactness_loss = th.Tensor([0.]) entropy_loss = th.Tensor([0.]) comm_loss = th.Tensor([0.]) comm_beta = th.Tensor([0.]) comm_entropy_beta = th.Tensor([0.]) #rec_loss = sum(reconstruct_losses) #loss += 0.01*rec_loss else: # ### Optimize message # Message are controlled only by expressiveness and compactness loss. # Compute cross entropy with target q values of the same time step expressiveness_loss = 0 label_prob = th.gather(logits_out, 3, label_target_actions).squeeze(3) expressiveness_loss += (-th.log(label_prob + 1e-6)).sum() / mask.sum() # Compute KL divergence compactness_loss = D.kl_divergence(D.Normal(mu_out, sigma_out), D.Normal(self.s_mu, self.s_sigma)).sum() / \ mask.sum() # Entropy loss entropy_loss = -D.Normal(self.s_mu, self.s_sigma).log_prob(m_sample_out).sum() / mask.sum() # Gate loss gate_loss = 0 # Total loss comm_beta = self.get_comm_beta(t_env) comm_entropy_beta = self.get_comm_entropy_beta(t_env) comm_loss = expressiveness_loss + comm_beta * compactness_loss + comm_entropy_beta * entropy_loss comm_loss *= self.args.c_beta loss += comm_loss comm_beta = th.Tensor([comm_beta]) comm_entropy_beta = th.Tensor([comm_entropy_beta]) # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.optimiser.step() # Update target if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) self.logger.log_stat("comm_loss", comm_loss.item(), t_env) self.logger.log_stat("exp_loss", expressiveness_loss.item(), t_env) self.logger.log_stat("comp_loss", compactness_loss.item(), t_env) self.logger.log_stat("comm_beta", comm_beta.item(), t_env) self.logger.log_stat("entropy_loss", entropy_loss.item(), t_env) self.logger.log_stat("comm_beta", comm_beta.item(), t_env) self.logger.log_stat("comm_entropy_beta", comm_entropy_beta.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) #self.logger.log_stat("reconstruct_loss", rec_loss.item(), t_env) self.log_stats_t = t_env
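# A self-contained sketch of the compactness term used above: the KL between the
# per-message posterior N(mu, sigma) and the fixed prior N(s_mu, s_sigma),
# normalised by the mask sum. Shapes and values below are made up for illustration.
import torch
import torch.distributions as D

mu = torch.zeros(4, 10, 3, 8)        # (batch, time, agents, msg_dim), illustrative
sigma = torch.ones(4, 10, 3, 8)
s_mu, s_sigma = torch.zeros(1), torch.ones(1)
mask_sum = torch.tensor(40.0)
compactness_loss = D.kl_divergence(D.Normal(mu, sigma),
                                   D.Normal(s_mu, s_sigma)).sum() / mask_sum
print(compactness_loss)  # zero here because the posterior equals the prior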
def run_epoch(experiment, network, optimizer, dataloader, config, use_tqdm = False, debug=False, plot=False): cum_loss = 0.0 cum_param_loss = 0.0 cum_position_loss = 0.0 cum_velocity_loss = 0.0 num_samples=0.0 if use_tqdm: t = tqdm(enumerate(dataloader), total=len(dataloader)) else: t = enumerate(dataloader) network.train() # This is important to call before training! dataloaderlen = len(dataloader) dev = next(network.parameters()).device # we are only doing single-device training for now, so this works fine. dtype = next(network.parameters()).dtype # we are only doing single-device training for now, so this works fine. loss_weights = config["loss_weights"] positionerror = loss_functions.SquaredLpNormLoss().type(dtype).to(dev) #_, _, _, _, _, _, sample_session_times,_,_ = dataloader.dataset[0] bezier_order = network.bezier_order d = network.output_dimension for (i, imagedict) in t: track_names = imagedict["track"] input_images = imagedict["images"].type(dtype).to(device=dev) batch_size = input_images.shape[0] session_times = imagedict["session_times"].type(dtype).to(device=dev) ego_positions = imagedict["ego_positions"].type(dtype).to(device=dev) ego_velocities = imagedict["ego_velocities"].type(dtype).to(device=dev) targets = ego_positions dt = session_times[:,-1]-session_times[:,0] s_torch_cur = (session_times - session_times[:,0,None])/dt[:,None] M, controlpoints_fit = deepracing_models.math_utils.bezier.bezierLsqfit(targets, bezier_order, t = s_torch_cur) Msquare = torch.square(M) means, varfactors, covarfactors = network(input_images) scale_trils = torch.diag_embed(varfactors) + torch.diag_embed(covarfactors, offset=-1) covars = torch.matmul(scale_trils, scale_trils.transpose(2,3)) covars_expand = covars.unsqueeze(1).expand(batch_size, Msquare.shape[1], Msquare.shape[2], d, d) poscovar = torch.sum(Msquare[:,:,:,None,None]*covars_expand, dim=2) posmeans = torch.matmul(M, means) initial_points = targets[:,0].unsqueeze(1) final_points = (dt[:,None]*ego_velocities[:,0]).unsqueeze(1) deltas = final_points - initial_points ds = torch.linspace(0.0,1.0,steps=means.shape[1]) straight_lines = torch.cat([initial_points + t.item()*deltas for t in ds], dim=1) priorscaletril = torch.diag_embed(torch.ones_like(straight_lines)) priorcurves = D.MultivariateNormal(controlpoints_fit, scale_tril=priorscaletril, validate_args=False) distcurves = D.MultivariateNormal(means, scale_tril=scale_trils, validate_args=False) distpos = D.MultivariateNormal(posmeans, covariance_matrix=poscovar, validate_args=False) position_error = positionerror(posmeans, targets) log_probs = distpos.log_prob(ego_positions) NLL = torch.mean(-log_probs) kl_divergences = D.kl_divergence(distcurves, priorcurves) mean_kl = torch.mean(kl_divergences) if debug and plot: fig, (ax1, ax2) = plt.subplots(1, 2, sharey=False) print("position_error: %f" % position_error.item() ) images_np = np.round(255.0*input_images[0].detach().cpu().numpy().copy().transpose(0,2,3,1)).astype(np.uint8) #image_np_transpose=skimage.util.img_as_ubyte(images_np[-1].transpose(1,2,0)) # oap = other_agent_positions[other_agent_positions==other_agent_positions].view(1,-1,60,2) # print(oap) ims = [] for i in range(images_np.shape[0]): ims.append([ax1.imshow(images_np[i])]) ani = animation.ArtistAnimation(fig, ims, interval=250, blit=True, repeat=True) fit_points = torch.matmul(M, controlpoints_fit) prior_points = torch.matmul(M, straight_lines) # gt_points_np = ego_positions[0].detach().cpu().numpy().copy() gt_points_np = targets[0].detach().cpu().numpy().copy() 
pred_points_np = posmeans[0].detach().cpu().numpy().copy() pred_control_points_np = means[0].detach().cpu().numpy().copy() fit_points_np = fit_points[0].cpu().numpy().copy() fit_control_points_np = controlpoints_fit[0].cpu().numpy().copy() prior_points_np = prior_points[0].cpu().numpy().copy() prior_control_points_np = straight_lines[0].cpu().numpy().copy() ymin = np.min(np.hstack([gt_points_np[:,1], pred_points_np[:,1] ])) - 2.5 ymax = np.max(np.hstack([gt_points_np[:,1], pred_points_np[:,1] ])) + 2.5 xmin = np.min(np.hstack([gt_points_np[:,0], fit_points_np[:,0] ])) - 2.5 xmax = np.max(np.hstack([gt_points_np[:,0], fit_points_np[:,0] ])) ax2.set_xlim(xmax,xmin) ax2.set_ylim(ymin,ymax) ax2.plot(gt_points_np[:,0],gt_points_np[:,1],'g+', label="Ground Truth Waypoints") ax2.plot(pred_points_np[:,0],pred_points_np[:,1],'r-', label="Predicted Bézier Curve") ax2.plot(prior_points_np[:,0],prior_points_np[:,1], label="Prior") # ax2.plot(fit_points_np[:,0],fit_points_np[:,1],'b-', label="Best-fit Bézier Curve") #ax2.scatter(fit_control_points_np[1:,0],fit_control_points_np[1:,1],c="b", label="Bézier Curve's Control Points") # ax2.plot(pred_points_np[:,1],pred_points_np[:,0],'r-', label="Predicted Bézier Curve") # ax2.scatter(pred_control_points_np[:,1],pred_control_points_np[:,0], c='r', label="Predicted Bézier Curve's Control Points") plt.legend() plt.show() # loss = position_error loss = loss_weights["position"]*position_error + loss_weights["nll"]*NLL + loss_weights["kl_divergence"]*mean_kl #loss = loss_weights["position"]*position_error optimizer.zero_grad() loss.backward() # Weight and bias updates. optimizer.step() # logging information current_position_loss_float = float(position_error.item()) num_samples += 1.0 if not debug: experiment.log_metric("current_position_loss", current_position_loss_float) experiment.log_metric("logprob", NLL.item()) experiment.log_metric("kl_divergence", mean_kl.item()) if use_tqdm: t.set_postfix({"current_position_loss" : current_position_loss_float})
def _kl_loss(self, prior_dist, post_dist):
    return td.kl_divergence(prior_dist, post_dist).clamp(min=self.kl_free_nats).mean()
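# A small sketch of the free-nats clamp applied by _kl_loss: the KL is floored at
# kl_free_nats so gradients stop pushing it below that value. Note the helper is
# called with the prior as the first argument. Distributions below are placeholders.
import torch
from torch.distributions import Normal, kl_divergence

kl_free_nats = 1.0
prior = Normal(torch.zeros(4), torch.ones(4))
post = Normal(torch.zeros(4), 0.5 * torch.ones(4))
kl = kl_divergence(prior, post).clamp(min=kl_free_nats).mean()
print(kl)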
def _prod_prod(p, q):
    assert p.n_dists == q.n_dists, "I need the same number of distributions"
    return torch.cat([
        kl_divergence(pi, qi).unsqueeze(-1)
        for pi, qi in zip(p.distributions, q.distributions)
    ], -1).sum(-1)
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)
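# Quick check of what this rule computes, assuming a PyTorch version where the
# analogous Independent/Independent KL is already registered: the reinterpreted
# (event) dimensions are summed out, leaving the batch shape.
import torch
from torch.distributions import Independent, Normal, kl_divergence

p = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)
q = Independent(Normal(torch.ones(3), 2.0 * torch.ones(3)), 1)
print(kl_divergence(p, q))  # scalar tensor: the event dim has been summed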
def learn(  # type: ignore
    self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
    actor_losses, vf_losses, step_sizes, kls = [], [], [], []
    for step in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            # optimize actor
            # direction: calculate vanilla gradient
            dist = self(b).dist  # TODO could come from batch
            ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
            ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
            actor_loss = -(ratio * b.adv).mean()
            flat_grads = self._get_flat_grad(
                actor_loss, self.actor, retain_graph=True).detach()

            # direction: calculate natural gradient
            with torch.no_grad():
                old_dist = self(b).dist

            kl = kl_divergence(old_dist, dist).mean()
            # calculate first order gradient of kl with respect to theta
            flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
            search_direction = -self._conjugate_gradients(
                flat_grads, flat_kl_grad, nsteps=10)

            # stepsize: calculate max stepsize constrained by kl bound
            step_size = torch.sqrt(
                2 * self._delta /
                (search_direction *
                 self._MVP(search_direction, flat_kl_grad)).sum(0, keepdim=True))

            # stepsize: linesearch stepsize
            with torch.no_grad():
                flat_params = torch.cat(
                    [param.data.view(-1) for param in self.actor.parameters()])
                for i in range(self._max_backtracks):
                    new_flat_params = flat_params + step_size * search_direction
                    self._set_from_flat_params(self.actor, new_flat_params)
                    # calculate kl and check that it stays in bound and the loss actually went down
                    new_dist = self(b).dist
                    new_dratio = (
                        new_dist.log_prob(b.act) - b.logp_old).exp().float()
                    new_dratio = new_dratio.reshape(
                        new_dratio.size(0), -1).transpose(0, 1)
                    new_actor_loss = -(new_dratio * b.adv).mean()
                    kl = kl_divergence(old_dist, new_dist).mean()

                    if kl < self._delta and new_actor_loss < actor_loss:
                        if i > 0:
                            warnings.warn(f"Backtracking to step {i}.")
                        break
                    elif i < self._max_backtracks - 1:
                        step_size = step_size * self._backtrack_coeff
                    else:
                        self._set_from_flat_params(self.actor, new_flat_params)
                        step_size = torch.tensor([0.0])
                        warnings.warn(
                            "Line search failed! It seems hyperparameters"
                            " are poor and need to be changed.")

            # optimize critic
            for _ in range(self._optim_critic_iters):
                value = self.critic(b.obs).flatten()
                vf_loss = F.mse_loss(b.returns, value)
                self.optim.zero_grad()
                vf_loss.backward()
                self.optim.step()

            actor_losses.append(actor_loss.item())
            vf_losses.append(vf_loss.item())
            step_sizes.append(step_size.item())
            kls.append(kl.item())

    # update learning rate if lr_scheduler is given
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()

    return {
        "loss/actor": actor_losses,
        "loss/vf": vf_losses,
        "step_size": step_sizes,
        "kl": kls,
    }
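# Toy illustration of the trust-region check inside the backtracking line search
# above: a candidate update is accepted only if the mean KL to the old policy stays
# under delta (and the surrogate loss improves). delta and the Gaussians are made up.
import torch
from torch.distributions import Normal, kl_divergence

delta = 0.01
old_dist = Normal(torch.zeros(8), torch.ones(8))
new_dist = Normal(0.05 * torch.ones(8), torch.ones(8))
mean_kl = kl_divergence(old_dist, new_dist).mean()
print(mean_kl.item(), bool(mean_kl < delta))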
old_agent.load_state_dict(agent.state_dict())
aux_inds = np.arange(args.aux_batch_size)
print("aux phase starts")
for auxiliary_update in range(1, args.e_auxiliary + 1):
    np.random.shuffle(aux_inds)
    for i, start in enumerate(range(0, args.aux_batch_size, args.aux_minibatch_size)):
        end = start + args.aux_minibatch_size
        aux_minibatch_ind = aux_inds[start:end]
        try:
            m_aux_obs = aux_obs[aux_minibatch_ind].to(device)
            m_aux_returns = aux_returns[aux_minibatch_ind].to(device)
            new_values = agent.get_value(m_aux_obs).view(-1)
            new_aux_values = agent.get_aux_value(m_aux_obs).view(-1)
            kl_loss = td.kl_divergence(old_agent.get_pi(m_aux_obs),
                                       agent.get_pi(m_aux_obs)).mean()
            real_value_loss = 0.5 * ((new_values - m_aux_returns) ** 2).mean()
            aux_value_loss = 0.5 * ((new_aux_values - m_aux_returns) ** 2).mean()
            joint_loss = aux_value_loss + args.beta_clone * kl_loss
            optimizer.zero_grad()
            loss = (joint_loss + real_value_loss) / args.n_aux_grad_accum
            loss.backward()
            if (i + 1) % args.n_aux_grad_accum == 0:
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()
        except RuntimeError:
if __name__ == '__main__':
    from pyro.distributions import Dirichlet, MultivariateNormal
    from torch.distributions import kl_divergence

    from distributions.mixture import Mixture

    B, D1, D2 = 5, 3, 4
    N = 1000
    dist1 = MultivariateNormal(torch.zeros(D1), torch.eye(D1)).expand((B,))
    dist2 = Dirichlet(torch.ones(D2)).expand((B,))
    print(dist1.batch_shape, dist1.event_shape)
    print(dist2.batch_shape, dist2.event_shape)

    fact = Factorised([dist1, dist2])
    print(fact.batch_shape, fact.event_shape)
    samples = fact.rsample((N,))
    print(samples[0])
    print(samples.shape)
    logp = fact.log_prob(samples)
    print(logp.shape)
    entropy = fact.entropy()
    print(entropy.shape)
    print(entropy, -logp.mean())
    print()
    print(kl_divergence(fact, fact))

    mixture = Mixture(torch.ones(B), fact)
    samples = mixture.rsample((N,))
    logp = mixture.log_prob(samples)
    print(samples.shape)
    print(logp.shape)
def _kl_factorised_factorised(p: Factorised, q: Factorised):
    return sum(kl_divergence(p_factor, q_factor)
               for p_factor, q_factor in zip(p.factors, q.factors))
def total_kld(posterior, prior=None):
    if prior is None:
        prior = standard_prior_like(posterior)
    return torch.sum(kl_divergence(posterior, prior))
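# Hypothetical usage of total_kld, assuming standard_prior_like builds a unit
# Gaussian with the posterior's shape (names and shapes below are illustrative).
import torch
from torch.distributions import Normal, kl_divergence

posterior = Normal(torch.randn(16, 32), torch.rand(16, 32) + 0.1)
prior = Normal(torch.zeros_like(posterior.loc), torch.ones_like(posterior.scale))
print(torch.sum(kl_divergence(posterior, prior)))  # same value total_kld would return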
def forward(self, x, img_enc, alpha_map_prop, ids_prop, lengths, t, eps=1e-15): """ :param z_what_prop: (bs, max_num_obj, dim) :param z_where_prop: (bs, max_num_obj, 4) :param z_pres_prop: (bs, max_num_obj, 1) :param alpha_map_prop: (bs, 1, img_h, img_w) """ bs = x.size(0) device = x.device alpha_map_prop = alpha_map_prop.detach() max_num_disc_obj = (self.max_num_obj - lengths).long() self.prior_z_pres_prob = linear_annealing( self.args.global_step, self.z_pres_anneal_start_step, self.z_pres_anneal_end_step, self.z_pres_anneal_start_value, self.z_pres_anneal_end_value, device) # z_where: (bs * num_cell_h * num_cell_w, 4) # z_pres, z_depth, z_pres_logits: (bs, dim, num_cell_h, num_cell_w) z_where, z_pres, z_depth, z_where_mean, z_where_std, \ z_depth_mean, z_depth_std, z_pres_logits, z_pres_y, z_where_origin = self.ProposalNet( img_enc, alpha_map_prop, self.args.tau, t, gen_pres_probs=x.new_ones(1) * self.args.gen_disc_pres_probs, gen_depth_mean=self.prior_depth_mean, gen_depth_std=self.prior_depth_std, gen_where_mean=self.prior_where_mean, gen_where_std=self.prior_where_std ) num_cell_h, num_cell_w = z_pres.shape[2], z_pres.shape[3] q_z_where = Normal(z_where_mean, z_where_std) q_z_depth = Normal(z_depth_mean, z_depth_std) z_pres_orgin = z_pres if self.args.phase_generate and t >= self.args.observe_frames: z_what_mean, z_what_std = self.prior_what_mean.view(1, 1).expand(bs * self.args.num_cell_h * self.args.num_cell_w, z_what_dim), \ self.prior_what_std.view(1, 1).expand(bs * self.args.num_cell_h * self.args.num_cell_w, z_what_dim) x_att = x.new_zeros(1) else: # (bs * num_cell_h * num_cell_w, 3, glimpse_size, glimpse_size) x_att = spatial_transform( torch.stack(num_cell_h * num_cell_w * (x, ), dim=1).view(-1, 3, img_h, img_w), z_where, (bs * num_cell_h * num_cell_w, 3, glimpse_size, glimpse_size), inverse=False) # (bs * num_cell_h * num_cell_w, dim) z_what_mean, z_what_std = self.z_what_net(x_att) z_what_std = F.softplus(z_what_std) q_z_what = Normal(z_what_mean, z_what_std) z_what = q_z_what.rsample() # (bs * num_cell_h * num_cell_w, dim, glimpse_size, glimpse_size) o_att, alpha_att = self.glimpse_dec(z_what) # Rejection if phase_rejection and t > 0: alpha_map_raw = spatial_transform( alpha_att, z_where, (bs * num_cell_h * num_cell_w, 1, img_h, img_w), inverse=True) alpha_map_proposed = (alpha_map_raw > 0.3).float() alpha_map_prop = (alpha_map_prop > 0.1).float().view(bs, 1, 1, img_h, img_w) \ .expand(-1, num_cell_h * num_cell_w, -1, -1, -1).contiguous().view(-1, 1, img_h, img_w) alpha_map_intersect = alpha_map_proposed * alpha_map_prop explained_ratio = alpha_map_intersect.view(bs * num_cell_h * num_cell_w, -1).sum(1) / \ (alpha_map_proposed.view(bs * num_cell_h * num_cell_w, -1).sum(1) + eps) pres_mask = (explained_ratio < self.args.explained_ratio_threshold).view( bs, 1, num_cell_h, num_cell_w).float() z_pres = z_pres * pres_mask # The following "if" is useful only if you don't have high-memery GPUs, better to remove it if you do if self.training and phase_obj_num_contrain: z_pres = z_pres.view(bs, -1) z_pres_threshold = z_pres.sort( dim=1, descending=True)[0][torch.arange(bs), max_num_disc_obj] z_pres_mask = (z_pres > z_pres_threshold.view(bs, -1)).float() if self.args.phase_generate and t >= self.args.observe_frames: z_pres_mask = x.new_zeros(z_pres_mask.size()) z_pres = z_pres * z_pres_mask z_pres = z_pres.view(bs, 1, num_cell_h, num_cell_w) alpha_att_hat = alpha_att * z_pres.view(-1, 1, 1, 1) y_att = alpha_att_hat * o_att # (bs * num_cell_h * num_cell_w, 3, img_h, img_w) 
y_each_cell = spatial_transform( y_att, z_where, (bs * num_cell_h * num_cell_w, 3, img_h, img_w), inverse=True) # (bs * num_cell_h * num_cell_w, 1, glimpse_size, glimpse_size) importance_map = alpha_att_hat * torch.sigmoid(-z_depth).view( -1, 1, 1, 1) # importance_map = -z_depth.view(-1, 1, 1, 1).expand_as(alpha_att_hat) # (bs * num_cell_h * num_cell_w, 1, img_h, img_w) importance_map_full_res = spatial_transform( importance_map, z_where, (bs * num_cell_h * num_cell_w, 1, img_h, img_w), inverse=True) # (bs * num_cell_h * num_cell_w, 1, img_h, img_w) alpha_map = spatial_transform( alpha_att_hat, z_where, (bs * num_cell_h * num_cell_w, 1, img_h, img_w), inverse=True) # (bs * num_cell_h * num_cell_w, z_what_dim) kl_z_what = kl_divergence(q_z_what, self.p_z_what) * z_pres_orgin.view( -1, 1) # (bs, num_cell_h * num_cell_w, z_what_dim) kl_z_what = kl_z_what.view(-1, num_cell_h * num_cell_w, z_what_dim) # (bs * num_cell_h * num_cell_w, z_depth_dim) kl_z_depth = kl_divergence(q_z_depth, self.p_z_depth) * z_pres_orgin # (bs, num_cell_h * num_cell_w, z_depth_dim) kl_z_depth = kl_z_depth.view(-1, num_cell_h * num_cell_w, z_depth_dim) # (bs, dim, num_cell_h, num_cell_w) kl_z_where = kl_divergence(q_z_where, self.p_z_where) * z_pres_orgin if phase_rejection and t > 0: kl_z_pres = calc_kl_z_pres_bernoulli( z_pres_logits, self.prior_z_pres_prob * pres_mask + self.z_pres_masked_prior * (1 - pres_mask)) else: kl_z_pres = calc_kl_z_pres_bernoulli(z_pres_logits, self.prior_z_pres_prob) kl_z_pres = kl_z_pres.view(-1, num_cell_h * num_cell_w, z_pres_dim) ########################################### Compute log importance ############################################ log_imp = x.new_zeros(bs, 1) if not self.training and self.args.phase_nll: z_pres_orgin_binary = (z_pres_orgin > 0.5).float() # (bs * num_cell_h * num_cell_w, dim) log_imp_what = ( self.p_z_what.log_prob(z_what) - q_z_what.log_prob(z_what)) * z_pres_orgin_binary.view(-1, 1) log_imp_what = log_imp_what.view(-1, num_cell_h * num_cell_w, z_what_dim) # (bs, dim, num_cell_h, num_cell_w) log_imp_depth = (self.p_z_depth.log_prob(z_depth) - q_z_depth.log_prob(z_depth)) * z_pres_orgin_binary # (bs, dim, num_cell_h, num_cell_w) log_imp_where = ( self.p_z_where.log_prob(z_where_origin) - q_z_where.log_prob(z_where_origin)) * z_pres_orgin_binary if phase_rejection and t > 0: p_z_pres = self.prior_z_pres_prob * pres_mask + self.z_pres_masked_prior * ( 1 - pres_mask) else: p_z_pres = self.prior_z_pres_prob z_pres_binary = (z_pres > 0.5).float() log_pres_prior = z_pres_binary * torch.log(p_z_pres + eps) + \ (1 - z_pres_binary) * torch.log(1 - p_z_pres + eps) log_pres_pos = z_pres_binary * torch.log(torch.sigmoid(z_pres_logits) + eps) + \ (1 - z_pres_binary) * torch.log(1 - torch.sigmoid(z_pres_logits) + eps) log_imp_pres = log_pres_prior - log_pres_pos log_imp = log_imp_what.flatten(start_dim=1).sum(dim=1) + log_imp_depth.flatten(start_dim=1).sum(1) + \ log_imp_where.flatten(start_dim=1).sum(1) + log_imp_pres.flatten(start_dim=1).sum(1) ######################################## End of Compute log importance ######################################### # (bs, num_cell_h * num_cell_w) ids = torch.arange(num_cell_h * num_cell_w).view(1, -1).expand(bs, -1).to(x.device).float() + \ ids_prop.max(dim=1, keepdim=True)[0] + 1 if self.args.log_phase: self.log = { 'z_what': z_what, 'z_where': z_where, 'z_pres': z_pres, 'z_pres_logits': z_pres_logits, 'z_what_std': q_z_what.stddev, 'z_what_mean': q_z_what.mean, 'z_where_std': q_z_where.stddev, 'z_where_mean': q_z_where.mean, 
'x_att': x_att, 'y_att': y_att, 'prior_z_pres_prob': self.prior_z_pres_prob.unsqueeze(0), 'o_att': o_att, 'alpha_att_hat': alpha_att_hat, 'alpha_att': alpha_att, 'y_each_cell': y_each_cell, 'z_depth': z_depth, 'z_depth_std': q_z_depth.stddev, 'z_depth_mean': q_z_depth.mean, # 'importance_map_full_res_norm': importance_map_full_res_norm, 'z_pres_y': z_pres_y, 'ids': ids } else: self.log = {} return y_each_cell.view(bs, num_cell_h * num_cell_w, 3, img_h, img_w), \ alpha_map.view(bs, num_cell_h * num_cell_w, 1, img_h, img_w), \ importance_map_full_res.view(bs, num_cell_h * num_cell_w, 1, img_h, img_w), \ z_what.view(bs, num_cell_h * num_cell_w, -1), z_where.view(bs, num_cell_h * num_cell_w, -1), \ torch.zeros_like(z_where.view(bs, num_cell_h * num_cell_w, -1)), \ z_depth.view(bs, num_cell_h * num_cell_w, -1), z_pres.view(bs, num_cell_h * num_cell_w, -1), ids, \ kl_z_what.flatten(start_dim=1).sum(dim=1), \ kl_z_where.flatten(start_dim=1).sum(dim=1), \ kl_z_pres.flatten(start_dim=1).sum(dim=1), \ kl_z_depth.flatten(start_dim=1).sum(dim=1), \ log_imp, self.log
def train_from_torch(self, batch): rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] context = batch['context'] if self.reward_transform: rewards = self.reward_transform(rewards) if self.terminal_transform: terminals = self.terminal_transform(terminals) """ Policy and Alpha Loss """ dist, p_z, task_z_with_grad = self.agent( obs, context, return_latent_posterior_and_task_z=True, ) task_z_detached = task_z_with_grad.detach() new_obs_actions, log_pi = dist.rsample_and_logprob() log_pi = log_pi.unsqueeze(1) next_dist = self.agent(next_obs, context) if self._debug_ignore_context: task_z_with_grad = task_z_with_grad * 0 # flattens out the task dimension t, b, _ = obs.size() obs = obs.view(t * b, -1) actions = actions.view(t * b, -1) next_obs = next_obs.view(t * b, -1) unscaled_rewards_flat = rewards.view(t * b, 1) rewards_flat = unscaled_rewards_flat * self.reward_scale terms_flat = terminals.view(t * b, 1) if self.use_automatic_entropy_tuning: alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() alpha = self.log_alpha.exp() else: alpha_loss = 0 alpha = self.alpha """ QF Loss """ if self.backprop_q_loss_into_encoder: q1_pred = self.qf1(obs, actions, task_z_with_grad) q2_pred = self.qf2(obs, actions, task_z_with_grad) else: q1_pred = self.qf1(obs, actions, task_z_detached) q2_pred = self.qf2(obs, actions, task_z_detached) # Make sure policy accounts for squashing functions like tanh correctly! new_next_actions, new_log_pi = next_dist.rsample_and_logprob() new_log_pi = new_log_pi.unsqueeze(1) with torch.no_grad(): target_q_values = torch.min( self.target_qf1(next_obs, new_next_actions, task_z_detached), self.target_qf2(next_obs, new_next_actions, task_z_detached), ) - alpha * new_log_pi q_target = rewards_flat + ( 1. 
- terms_flat) * self.discount * target_q_values qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) """ Context Encoder Loss """ if self._debug_use_ground_truth_context: kl_div = kl_loss = ptu.zeros(0) else: kl_div = kl_divergence(p_z, self.agent.latent_prior).mean(dim=0).sum() kl_loss = self.kl_lambda * kl_div if self.train_context_decoder: # TODO: change to use a distribution reward_pred = self.context_decoder(obs, actions, task_z_with_grad) reward_prediction_loss = ((reward_pred - unscaled_rewards_flat)**2).mean() context_loss = kl_loss + reward_prediction_loss else: context_loss = kl_loss reward_prediction_loss = ptu.zeros(1) """ Policy Loss """ qf1_new_actions = self.qf1(obs, new_obs_actions, task_z_detached) qf2_new_actions = self.qf2(obs, new_obs_actions, task_z_detached) q_new_actions = torch.min( qf1_new_actions, qf2_new_actions, ) # Advantage-weighted regression if self.vf_K > 1: vs = [] for i in range(self.vf_K): u = dist.sample() q1 = self.qf1(obs, u, task_z_detached) q2 = self.qf2(obs, u, task_z_detached) v = torch.min(q1, q2) # v = q1 vs.append(v) v_pi = torch.cat(vs, 1).mean(dim=1) else: # v_pi = self.qf1(obs, new_obs_actions) v1_pi = self.qf1(obs, new_obs_actions, task_z_detached) v2_pi = self.qf2(obs, new_obs_actions, task_z_detached) v_pi = torch.min(v1_pi, v2_pi) u = actions if self.awr_min_q: q_adv = torch.min(q1_pred, q2_pred) else: q_adv = q1_pred policy_logpp = dist.log_prob(u) if self.use_automatic_beta_tuning: buffer_dist = self.buffer_policy(obs) beta = self.log_beta.exp() kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist) beta_loss = -1 * (beta * (kldiv - self.beta_epsilon).detach()).mean() self.beta_optimizer.zero_grad() beta_loss.backward() self.beta_optimizer.step() else: beta = self.beta_schedule.get_value(self._n_train_steps_total) beta_loss = ptu.zeros(1) score = q_adv - v_pi if self.mask_positive_advantage: score = torch.sign(score) if self.clip_score is not None: score = torch.clamp(score, max=self.clip_score) weights = batch.get('weights', None) if self.weight_loss and weights is None: if self.normalize_over_batch == True: weights = F.softmax(score / beta, dim=0) elif self.normalize_over_batch == "whiten": adv_mean = torch.mean(score) adv_std = torch.std(score) + 1e-5 normalized_score = (score - adv_mean) / adv_std weights = torch.exp(normalized_score / beta) elif self.normalize_over_batch == "exp": weights = torch.exp(score / beta) elif self.normalize_over_batch == "step_fn": weights = (score > 0).float() elif self.normalize_over_batch == False: weights = score elif self.normalize_over_batch == 'uniform': weights = F.softmax(ptu.ones_like(score) / beta, dim=0) else: raise ValueError(self.normalize_over_batch) weights = weights[:, 0] policy_loss = alpha * log_pi.mean() if self.use_awr_update and self.weight_loss: policy_loss = policy_loss + self.awr_weight * ( -policy_logpp * len(weights) * weights.detach()).mean() elif self.use_awr_update: policy_loss = policy_loss + self.awr_weight * ( -policy_logpp).mean() if self.use_reparam_update: policy_loss = policy_loss + self.train_reparam_weight * ( -q_new_actions).mean() policy_loss = self.rl_weight * policy_loss """ Update networks """ if self._n_train_steps_total % self.q_update_period == 0: if self.train_encoder_decoder: self.context_optimizer.zero_grad() if self.train_agent: self.qf1_optimizer.zero_grad() self.qf2_optimizer.zero_grad() context_loss.backward(retain_graph=True) # retain graph because the encoder is trained by both QF 
losses qf1_loss.backward(retain_graph=True) qf2_loss.backward() if self.train_agent: self.qf1_optimizer.step() self.qf2_optimizer.step() if self.train_encoder_decoder: self.context_optimizer.step() if self.train_agent: if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy: self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() self._num_gradient_steps += 1 """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau) ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau) """ Save some statistics for eval """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False """ Eval should set this to None. This way, these statistics are only computed for one batch. """ policy_loss = (log_pi - q_new_actions).mean() self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean( ptu.get_numpy(policy_loss)) self.eval_statistics.update( create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics['task_embedding/kl_divergence'] = ( ptu.get_numpy(kl_div)) self.eval_statistics['task_embedding/kl_loss'] = ( ptu.get_numpy(kl_loss)) self.eval_statistics['task_embedding/reward_prediction_loss'] = ( ptu.get_numpy(reward_prediction_loss)) self.eval_statistics['task_embedding/context_loss'] = ( ptu.get_numpy(context_loss)) self.eval_statistics.update( create_stats_ordered_dict( 'Log Pis', ptu.get_numpy(log_pi), )) self.eval_statistics.update( create_stats_ordered_dict( 'rewards', ptu.get_numpy(rewards), )) self.eval_statistics.update( create_stats_ordered_dict( 'terminals', ptu.get_numpy(terminals), )) policy_statistics = add_prefix(dist.get_diagnostics(), "policy/") self.eval_statistics.update(policy_statistics) self.eval_statistics.update( create_stats_ordered_dict( 'Advantage Weights', ptu.get_numpy(weights), )) self.eval_statistics.update( create_stats_ordered_dict( 'Advantage Score', ptu.get_numpy(score), )) self.eval_statistics['reparam_weight'] = self.train_reparam_weight self.eval_statistics['num_gradient_steps'] = ( self._num_gradient_steps) if self.use_automatic_entropy_tuning: self.eval_statistics['Alpha'] = alpha.item() self.eval_statistics['Alpha Loss'] = alpha_loss.item() if self.use_automatic_beta_tuning: self.eval_statistics.update({ "adaptive_beta/beta": ptu.get_numpy(beta.mean()), "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()), }) self._n_train_steps_total += 1
def kl_divergence(self):
    return kl_divergence(self.W, self.W_prior)
def draw_message_distributions_tracker1_2d(mu, sigma, save_dir, azim): # sigma *= torch.where(sigma<0.01, torch.ones(sigma.shape).cuda(), 10*torch.ones(sigma.shape).cuda()) mu = mu.view(-1, 2) sigma = sigma.view(-1, 2) s_mu = torch.Tensor([0.0, 0.0]) s_sigma = torch.Tensor([1.0, 1.0]) x = y = np.arange(100) t = np.meshgrid(x, y) d = D.Normal(s_mu, s_sigma) d1 = D.Normal(mu[0], sigma[0]) d21 = D.Normal(mu[2], sigma[2]) d22 = D.Normal(mu[3], sigma[3]) d31 = D.Normal(mu[4], sigma[4]) d32 = D.Normal(mu[5], sigma[5]) print('Entropy') print(d1.entropy().detach().cpu().numpy()) print(d21.entropy().detach().cpu().numpy()) print(d22.entropy().detach().cpu().numpy()) print(d31.entropy().detach().cpu().numpy()) print(d32.entropy().detach().cpu().numpy()) print('KL Divergence') for tt_i in range(3): d1 = D.Normal(mu[tt_i * 2 + 0], sigma[tt_i * 2 + 0]) d2 = D.Normal(mu[tt_i * 2 + 1], sigma[tt_i * 2 + 1]) print(tt_i, D.kl_divergence(d1, d2).sum().detach().cpu().numpy(), D.kl_divergence(d1, d).sum().detach().cpu().numpy(), D.kl_divergence(d2, d).sum().detach().cpu().numpy(), sigma[tt_i * 2 + 0].mean().detach().cpu().numpy(), sigma[tt_i * 2 + 1].mean().detach().cpu().numpy()) # Numpy array of mu and sigma s_mu_ = s_mu.detach().cpu().numpy() mu_0 = mu[0].detach().cpu().numpy() mu_2 = mu[2].detach().cpu().numpy() mu_3 = mu[3].detach().cpu().numpy() mu_4 = mu[4].detach().cpu().numpy() mu_5 = mu[5].detach().cpu().numpy() s_sigma_ = s_sigma.detach().cpu().numpy() sigma_0 = sigma[0].detach().cpu().numpy() sigma_2 = sigma[2].detach().cpu().numpy() sigma_3 = sigma[3].detach().cpu().numpy() sigma_4 = sigma[4].detach().cpu().numpy() sigma_5 = sigma[5].detach().cpu().numpy() # Print print('mu and sigma') print(mu_0, sigma_0) print(mu_2, sigma_2) print(mu_3, sigma_3) print(mu_4, sigma_4) print(mu_5, sigma_5) # Create grid x = np.linspace(-5, 5, 5000) y = np.linspace(-5, 5, 5000) X, Y = np.meshgrid(x, y) pos = np.empty(X.shape + (2, )) pos[:, :, 0] = X pos[:, :, 1] = Y rv = multivariate_normal(s_mu_, [[s_sigma[0], 0], [0, s_sigma[1]]]) # Agent 1 # Create multivariate normal rv1 = multivariate_normal(mu_0, [[sigma_0[0], 0], [0, sigma_0[1]]]) # Make a 3D plot fig = plt.figure() ax = fig.gca(projection='3d') ax.plot_surface(X, Y, rv.pdf(pos) + rv1.pdf(pos), cmap='viridis', linewidth=0) ax.set_xlabel('X axis') ax.set_ylabel('Y axis') ax.set_zlabel('message') plt.tight_layout() plt.savefig(save_dir + 'agent1.png') ax.view_init(elev=0., azim=azim) plt.savefig(save_dir + ("agent1_0_%i.png" % int(azim))) ax.view_init(elev=90., azim=0.) plt.savefig(save_dir + "agent1_90_0.png") # plt.show() plt.close() # Agent 2 # Create multivariate normal rv21 = multivariate_normal(mu_2, [[sigma_2[0], 0], [0, sigma_2[1]]]) rv22 = multivariate_normal(mu_3, [[sigma_3[0], 0], [0, sigma_3[1]]]) # Make a 3D plot fig = plt.figure() ax = fig.gca(projection='3d') ax.plot_surface(X, Y, rv.pdf(pos) + rv21.pdf(pos) + rv22.pdf(pos), cmap='viridis', linewidth=0) ax.set_xlabel('X axis') ax.set_ylabel('Y axis') ax.set_zlabel('message') plt.tight_layout() plt.savefig(save_dir + 'agent2.png') ax.view_init(elev=0., azim=azim) plt.savefig(save_dir + ("agent2_0_%i.png" % int(azim))) ax.view_init(elev=90., azim=0.) 
plt.savefig(save_dir + "agent2_90_0.png") # plt.show() plt.close() # Agent 3 rv31 = multivariate_normal(mu_4, [[sigma_4[0], 0], [0, sigma_4[1]]]) rv32 = multivariate_normal(mu_5, [[sigma_5[0], 0], [0, sigma_5[1]]]) # Make a 3D plot fig = plt.figure() ax = fig.gca(projection='3d') ax.plot_surface(X, Y, rv.pdf(pos) + rv31.pdf(pos) + rv32.pdf(pos), cmap='viridis', linewidth=0) ax.set_xlabel('X axis') ax.set_ylabel('Y axis') ax.set_zlabel('message') plt.tight_layout() plt.savefig(save_dir + 'agent3.png') ax.view_init(elev=0., azim=azim) plt.savefig(save_dir + ("agent3_0_%i.png" % int(azim))) ax.view_init(elev=90., azim=0.) plt.savefig(save_dir + "agent3_90_0.png") # plt.show() plt.close() # Overall # Make a 3D plot fig = plt.figure() ax = fig.gca(projection='3d') ax.plot_surface(X, Y, rv.pdf(pos) + rv1.pdf(pos) + rv21.pdf(pos) + rv22.pdf(pos) + rv31.pdf(pos) + rv32.pdf(pos), cmap='viridis', linewidth=0) ax.set_xlabel('X axis') ax.set_ylabel('Y axis') ax.set_zlabel('message') plt.tight_layout() plt.savefig(save_dir + 'overall.png') ax.view_init(elev=0., azim=azim) plt.savefig(save_dir + ("overall_0_%i.png" % int(azim))) ax.view_init(elev=90., azim=0.) plt.savefig(save_dir + "overall_90_0.png") # plt.show() plt.close()
def div_from_prior(posterior): prior = Normal(torch.zeros_like(posterior.loc), torch.ones_like(posterior.scale)) return kl_divergence(posterior, prior).sum(dim=-1)
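# Sanity check for div_from_prior: KL(N(mu, sigma^2) || N(0, 1)) has the
# closed form 0.5 * (sigma^2 + mu^2 - 1) - log(sigma), so the library call
# should match it elementwise before the final sum over the last dimension.
import torch
from torch.distributions import Normal, kl_divergence

mu, sigma = torch.randn(3, 5), torch.rand(3, 5) + 0.1
kl_lib = kl_divergence(Normal(mu, sigma),
                       Normal(torch.zeros_like(mu), torch.ones_like(sigma)))
kl_manual = 0.5 * (sigma ** 2 + mu ** 2 - 1.0) - sigma.log()
assert torch.allclose(kl_lib, kl_manual, atol=1e-5)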
def embedding_loss(output, target, epoch_iter, n_iters_start=0, lambd=0.3): # assert len(output) == 3 or len(output) == 5 # TODO: implement KL-div with Pytorch lib # Pass the distribution and a sample; the distribution is computed beforehand. Use rsample, don't get tangled up. # G**2 is the number of slots. It would be like the number of timesteps, I suppose. It should keep only the batch dimension. rec = output["rec"] pred = output["pred"] pred_rev = output["pred_rev"] g = output["obs_rec_pred"] posteriors = output["gauss_post"] A = output["A"] B = output["B"] u = output["u"] # Get bs, T, n_obj; and take order into account. bs = rec.shape[0] T = rec.shape[1] device = rec.device std_rec = .15 # local_geo_loss = torch.zeros(1) if epoch_iter[0] < n_iters_start: #TODO: Add rec loss in G with spectralnorm. fit_error # lambd_lg = 1 lambd_fit_error = 1 lambd_hrank = 0.05 lambd_rec = 1 lambd_pred = 0.2 prior_pres_prob = 0.05 prior_g_mask_prob = 0.8 lambd_u = 0.1 lambd_I = 1 else: # lambd_lg = 1 lambd_fit_error = 1 lambd_hrank = 0.2 lambd_rec = .7 lambd_pred = 1 prior_pres_prob = 0.1 prior_g_mask_prob = 0.4 lambd_u = 1 lambd_I = 1 '''Rec and pred losses''' # Shape latent: [bs, T-n_timesteps+1, 1, 64, 64] rec_distr = Normal(rec, std_rec) logprob_rec = rec_distr.log_prob(target[:, -rec.shape[1]:])\ .flatten(start_dim=1).sum(1) pred_distr = Normal(pred, std_rec) logprob_pred = pred_distr.log_prob(target[:, -pred.shape[1]:])\ .flatten(start_dim=1).sum(1) # TODO: Check if correct. pred_rev_distr = Normal(pred_rev, std_rec) logprob_pred_rev = pred_rev_distr.log_prob(target[:, :pred_rev.shape[1]]) \ .flatten(start_dim=1).sum(1) # TODO: Check if correct. kl_bern_loss = 0 '''G composition bernoulli KL div loss''' # if "g_bern_logit" == any(output.keys()): if output["g_bern_logit"] is not None: g_mask_logit = output["g_bern_logit"] # TODO: Check shape kl_g_mask_loss = kl_divergence_bern_bern( g_mask_logit, prior_pres_prob=torch.FloatTensor([prior_g_mask_prob]).to(g_mask_logit.device)).sum() kl_bern_loss = kl_bern_loss + kl_g_mask_loss '''Input bernoulli KL div loss''' # if "u_bern_logit" == any(output.keys()): if output["u_bern_logit"] is not None: u_logit = output["u_bern_logit"] kl_u_loss = kl_divergence_bern_bern(u_logit, prior_pres_prob=torch.FloatTensor([prior_pres_prob]).to(u_logit.device)).sum() kl_bern_loss = kl_bern_loss + kl_u_loss l1_u = 0. else: '''Input sparsity loss u: [bs, T, feat_dim] ''' up_bound = 0.3 N_elem = u.shape[0] l1_u_sparse = F.relu(l1_loss(u, torch.zeros_like(u)).mean() - up_bound) # l1_u_diff_sparse = F.relu(l1_loss(u[:, :-1] - u[:, 1:], torch.zeros_like(u[:, 1:])).mean() - up_bound) * u.shape[0] * u.shape[1] l1_u = lambd_u * l1_u_sparse * N_elem # l1_u = lambd_u * l1_loss(u, torch.zeros_like(u)).mean() '''Gaussian vectors KL div loss''' posteriors = posteriors[:-1] # If only rec prior = Normal(torch.zeros(1).to(device), torch.ones(1).to(device)) kl_loss = torch.stack([ kl_divergence(post, prior).flatten(start_dim=1).sum(1) for post in posteriors ]).sum(0) nelbo = ( kl_loss + kl_bern_loss - logprob_rec * lambd_rec - logprob_pred * lambd_pred #TODO: There's a mismatch here?
- logprob_pred_rev * lambd_pred).mean() '''Cycle consistency''' if output["Ainv"] is not None: A_inv = output["Ainv"] cyc_cons_loss = cycle_consistency_loss(A, A_inv) else: cyc_cons_loss = 0. '''LSQ fit loss''' # Note: only has correspondence if last frames of pred correspond to last frames of rec g_rec, g_pred = g[:, :T], g[:, T:] T_min = min(T, g_pred.shape[1]) fit_error = lambd_fit_error * mse_loss(g_rec[:, -T_min:], g_pred[:, -T_min:]).sum() '''Low rank G''' # Option 1: # T = g_for_koop.shape[1] # n_timesteps = g_for_koop.shape[-1] # g_for_koop = g_for_koop.permute(0, 2, 1, 3).reshape(-1, T-1, n_timesteps) # h_rank_loss = 0 # reg_mask = torch.zeros_like(g_for_koop[..., :n_timesteps, :]) # ids = torch.arange(0, reg_mask.shape[-1]) # reg_mask[..., ids, ids] = 0.01 # for t in range(T-1-n_timesteps): # logdet_H = torch.slogdet(g_for_koop[..., t:t+n_timesteps, :] + reg_mask) # h_rank_loss = h_rank_loss + .01*(logdet_H[1]).mean() # Option 2: h_rank_loss = ( l1_loss(A, torch.zeros_like(A)).sum(-1).sum(-1).sum() # + l1_loss(B, torch.zeros_like(B)).sum(-1).sum(-1).mean() ) # h_rank_loss = (l1_loss(A, torch.zeros_like(A)).sum(-1).sum(-1).sum() # + l1_loss(B, torch.zeros_like(B)).sum(-1).sum(-1).sum()) h_rank_loss = lambd_hrank * h_rank_loss '''Input KL div.''' # KL loss for gumbel-softmax. We should use increasing temperature for softmax. Check SPACE pres variable. '''Local geometry loss''' # g = g[:, :rec.shape[1]] # local_geo_loss = lambd_lg * local_geo(g, target[:, -g.shape[1]:]) '''Total Loss''' loss = ( nelbo + h_rank_loss + l1_u + fit_error + cyc_cons_loss # + local_geo_loss ) rec_mse = mse_loss(rec, target[:, -rec.shape[1]:]).mean(1).reshape( bs, -1).flatten(start_dim=1).sum(1).mean() pred_mse = mse_loss(pred, target[:, -pred.shape[1]:]).mean(1).reshape( bs, -1).flatten(start_dim=1).sum(1).mean() return loss, { 'Rec mse': rec_mse, 'Pred mse': pred_mse, 'KL Loss': kl_loss.mean(), # 'Rec llik':logprob_rec.mean(), # 'Pred llik':logprob_pred.mean(), 'Cycle consistency Loss': cyc_cons_loss, 'H rank Loss': h_rank_loss, 'Fit error': fit_error, # 'G Pred Loss':fit_error, # 'Local geo Loss':local_geo_loss, # 'l1_u':l1_u, }
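# kl_divergence_bern_bern is used above but not defined in this snippet.
# A minimal sketch of what it plausibly computes (an assumption): the KL
# between a Bernoulli parameterized by logits and a Bernoulli prior given as
# a presence probability, broadcast to the logit shape.
import torch
from torch.distributions import Bernoulli, kl_divergence

def kl_divergence_bern_bern(logit: torch.Tensor, prior_pres_prob: torch.Tensor) -> torch.Tensor:
    posterior = Bernoulli(logits=logit)
    prior = Bernoulli(probs=prior_pres_prob.expand_as(logit))
    return kl_divergence(posterior, prior)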
def _kl_mv_diag_normal_mv_diag_normal(p, q): return kl_divergence(p.distribution, q.distribution)
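# Hedged sketch of how a helper like _kl_mv_diag_normal_mv_diag_normal is
# usually wired into torch.distributions: register_kl makes kl_divergence(p, q)
# dispatch to it for a wrapper class exposing `.distribution`. DiagNormalWrapper
# is a hypothetical stand-in for whatever wrapper the snippet above assumes.
import torch
from torch.distributions import Distribution, Normal, kl_divergence
from torch.distributions.kl import register_kl

class DiagNormalWrapper(Distribution):
    arg_constraints = {}

    def __init__(self, loc, scale):
        self.distribution = Normal(loc, scale)
        super().__init__(self.distribution.batch_shape, validate_args=False)

@register_kl(DiagNormalWrapper, DiagNormalWrapper)
def _kl_diag_wrapper_diag_wrapper(p, q):
    return kl_divergence(p.distribution, q.distribution)

# kl_divergence(DiagNormalWrapper(torch.zeros(3), torch.ones(3)),
#               DiagNormalWrapper(torch.ones(3), torch.ones(3)))  # shape (3,)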
def kl_relaxed_one_hot_categorical(p, q): p = Categorical(probs=p.probs) q = Categorical(probs=q.probs) return kl_divergence(p, q)
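# Usage note: there is no simple closed-form KL between two Gumbel-softmax
# (RelaxedOneHotCategorical) distributions, so the helper above approximates
# it by the KL between the underlying Categorical distributions; only the
# class probabilities matter and the temperatures are ignored.
import torch
from torch.distributions import RelaxedOneHotCategorical, Categorical, kl_divergence

p = RelaxedOneHotCategorical(torch.tensor(0.5), probs=torch.tensor([0.7, 0.2, 0.1]))
q = RelaxedOneHotCategorical(torch.tensor(1.0), probs=torch.tensor([0.3, 0.3, 0.4]))
kl_approx = kl_divergence(Categorical(probs=p.probs), Categorical(probs=q.probs))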
def forward(self, p_params, q_params=None, forced_latent=None, use_mode=False, force_constant_output=False, analytical_kl=False): assert (forced_latent is None) or (not use_mode) if self.transform_p_params: p_params = self.conv_in_p(p_params) else: assert p_params.size(1) == 2 * self.c_vars if q_params is not None: q_params = self.conv_in_q(q_params) mu_lv = q_params else: mu_lv = p_params if forced_latent is None: if use_mode: z = torch.chunk(mu_lv, 2, dim=1)[0] else: z = normal_rsample(mu_lv) else: z = forced_latent # Copy one sample (and distrib parameters) over the whole batch if force_constant_output: z = z[0:1].expand_as(z).clone() p_params = p_params[0:1].expand_as(p_params).clone() # Output of stochastic layer out = self.conv_out(z) kl_elementwise = kl_samplewise = kl_spatial_analytical = None logprob_q = None # Compute log p(z) p_mu, p_lv = p_params.chunk(2, dim=1) p = Normal(p_mu, (p_lv / 2).exp()) logprob_p = p.log_prob(z).sum((1, 2, 3)) if q_params is not None: # Compute log q(z) q_mu, q_lv = q_params.chunk(2, dim=1) q = Normal(q_mu, (q_lv / 2).exp()) logprob_q = q.log_prob(z).sum((1, 2, 3)) # Compute KL (analytical or MC estimate) if analytical_kl: kl_elementwise = kl_divergence(q, p) else: kl_elementwise = kl_normal_mc(z, p_params, q_params) kl_samplewise = kl_elementwise.sum((1, 2, 3)) # Compute spatial KL analytically (but conditioned on samples from # previous layers) kl_spatial_analytical = -0.5 * (1 + q_lv - q_mu.pow(2) - q_lv.exp()) kl_spatial_analytical = kl_spatial_analytical.sum(1) data = { 'z': z, 'p_params': p_params, 'q_params': q_params, 'logprob_p': logprob_p, 'logprob_q': logprob_q, 'kl_elementwise': kl_elementwise, 'kl_samplewise': kl_samplewise, 'kl_spatial': kl_spatial_analytical, } return out, data
def forward(self, representation, viewpoint_query, image_query): batch_size, _, *spatial_dims = image_query.shape spatial_dims_scaled = tuple(np.array(spatial_dims) // self.scale) kl = 0 # Increase dimensions viewpoint_query = viewpoint_query.view(batch_size, -1, *[ 1, ] * len(spatial_dims)).repeat(1, 1, *spatial_dims_scaled) if representation.shape[2:] != spatial_dims_scaled: representation = representation.view(batch_size, -1, *[ 1, ] * len(spatial_dims)).repeat(1, 1, *spatial_dims_scaled) # Reset hidden state hidden_g = image_query.new_zeros( (batch_size, self.h_channels, *spatial_dims_scaled)) hidden_i = image_query.new_zeros( (batch_size, self.h_channels, *spatial_dims_scaled)) # Reset cell state cell_g = image_query.new_zeros( (batch_size, self.h_channels, *spatial_dims_scaled)) cell_i = image_query.new_zeros( (batch_size, self.h_channels, *spatial_dims_scaled)) # better name for u? u = image_query.new_zeros((batch_size, self.h_channels, *spatial_dims)) for i in range(self.core_repeat): if self.core_shared: current_generator_core = self.generator_core current_inference_core = self.inference_core else: current_generator_core = self.generator_core[i] current_inference_core = self.inference_core[i] # Prior o = self.prior_net(hidden_g) prior_mu, prior_std_pseudo = torch.split(o, self.z_channels, dim=1) prior = Normal(prior_mu, F.softplus(prior_std_pseudo)) # Inference state update cell_i, hidden_i = current_inference_core(image_query, viewpoint_query, representation, cell_i, hidden_i, hidden_g, u) # Posterior o = self.posterior_net(hidden_i) posterior_mu, posterior_std_pseudo = torch.split(o, self.z_channels, dim=1) posterior = Normal(posterior_mu, F.softplus(posterior_std_pseudo)) # Posterior sample if self.training: z = posterior.rsample() else: z = prior.loc # Generator update cell_g, hidden_g, u = current_generator_core( viewpoint_query, representation, z, cell_g, hidden_g, u) # Calculate KL-divergence kl += kl_divergence(posterior, prior) image_prediction = self.output_activation(self.observation_net(u)) return image_prediction, kl
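# Hedged sketch of how the (image_prediction, kl) pair returned above is
# typically reduced to a scalar ELBO-style loss. The Gaussian likelihood and
# the exact reduction over channel/spatial dimensions are assumptions, not
# taken from this model; kl is assumed to still carry its spatial dimensions.
import torch
from torch.distributions import Normal

def elbo_style_loss(image_prediction, image_query, kl, sigma=1.0):
    nll = -Normal(image_prediction, sigma).log_prob(image_query)
    nll = nll.flatten(start_dim=1).sum(dim=1)    # sum over C, H, W
    kl_sum = kl.flatten(start_dim=1).sum(dim=1)  # sum over latent channels and space
    return (nll + kl_sum).mean()                 # average over the batch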
def train(self, env_fn, policy, n_itr, normalize=None, logger=None): if normalize != None: policy.train() else: policy.train(0) env = Vectorize([env_fn]) # this will be useful for parallelism later if normalize is not None: env = normalize(env) mean, std = env.ob_rms.mean, np.sqrt(env.ob_rms.var + 1E-8) policy.obs_mean = torch.Tensor(mean) policy.obs_std = torch.Tensor(std) policy.train(0) env = Vectorize([env_fn]) old_policy = deepcopy(policy) optimizer = optim.Adam(policy.parameters(), lr=self.lr, eps=self.eps) start_time = time.time() for itr in range(n_itr): print("********** Iteration {} ************".format(itr)) sample_t = time.time() if self.n_proc > 1: print("doing multi samp") batch = self.sample_parallel(env_fn, policy, self.num_steps, 300) else: batch = self._sample(env_fn, policy, self.num_steps, 300) #TODO: fix this print("sample time: {:.2f} s".format(time.time() - sample_t)) observations, actions, returns, values = map( torch.Tensor, batch.get()) advantages = returns - values advantages = (advantages - advantages.mean()) / (advantages.std() + self.eps) minibatch_size = self.minibatch_size or advantages.numel() print("timesteps in batch: %i" % advantages.numel()) old_policy.load_state_dict( policy.state_dict()) # WAY faster than deepcopy for _ in range(self.epochs): losses = [] sampler = BatchSampler(SubsetRandomSampler( range(advantages.numel())), minibatch_size, drop_last=True) for indices in sampler: indices = torch.LongTensor(indices) obs_batch = observations[indices] action_batch = actions[indices] return_batch = returns[indices] advantage_batch = advantages[indices] values, pdf = policy.evaluate(obs_batch) # TODO, move this outside loop? with torch.no_grad(): _, old_pdf = old_policy.evaluate(obs_batch) old_log_probs = old_pdf.log_prob(action_batch).sum( -1, keepdim=True) log_probs = pdf.log_prob(action_batch).sum(-1, keepdim=True) ratio = (log_probs - old_log_probs).exp() cpi_loss = ratio * advantage_batch clip_loss = ratio.clamp(1.0 - self.clip, 1.0 + self.clip) * advantage_batch actor_loss = -torch.min(cpi_loss, clip_loss).mean() critic_loss = 0.5 * (return_batch - values).pow(2).mean() entropy_penalty = -self.entropy_coeff * pdf.entropy().mean( ) # TODO: add ability to optimize critic and actor seperately, with different learning rates optimizer.zero_grad() (actor_loss + critic_loss + entropy_penalty).backward() # Clip the gradient norm to prevent "unlucky" minibatches from # causing pathalogical updates torch.nn.utils.clip_grad_norm_(policy.parameters(), self.grad_clip) optimizer.step() losses.append([ actor_loss.item(), pdf.entropy().mean().item(), critic_loss.item(), ratio.mean().item() ]) # TODO: add verbosity arguments to suppress this print(' '.join(["%g" % x for x in np.mean(losses, axis=0)])) # Early stopping if kl_divergence(pdf, old_pdf).mean() > 0.02: print("Max kl reached, stopping optimization early.") break if logger is not None: test = self.sample(env, policy, 800 // self.n_proc, 400, deterministic=True) _, pdf = policy.evaluate(observations) _, old_pdf = old_policy.evaluate(observations) entropy = pdf.entropy().mean().item() kl = kl_divergence(pdf, old_pdf).mean().item() logger.record("Return (test)", np.mean(test.ep_returns)) logger.record("Return (batch)", np.mean(batch.ep_returns)) logger.record("Mean Eplen", np.mean(batch.ep_lens)) logger.record("Mean KL Div", kl) logger.record("Mean Entropy", entropy) logger.dump() # TODO: add option for how often to save model # if itr % 10 == 0: if np.mean(test.ep_returns) > self.max_return: self.max_return = 
np.mean(test.ep_returns) self.save(policy, env) self.save_optim(optimizer) print("Total time: {:.2f} s".format(time.time() - start_time))
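# Note on the early-stopping test above: for diagonal Gaussian policies,
# kl_divergence(pdf, old_pdf) is computed per action dimension, so .mean()
# averages over both the batch and the action dimensions. A minimal sketch
# of the same check, with the 0.02 threshold taken from the loop above:
import torch
from torch.distributions import Normal, kl_divergence

def kl_early_stop(pdf: Normal, old_pdf: Normal, max_kl: float = 0.02) -> bool:
    return kl_divergence(pdf, old_pdf).mean().item() > max_kl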
def test_non_empty_bit_vector(batch_shape=tuple(), K=3): assert K<= 10, "I test against explicit enumeration, K>10 might be too slow for that" # Uniform F = NonEmptyBitVector(torch.zeros(batch_shape + (K,))) # Shapes assert F.batch_shape == batch_shape, "NonEmptyBitVector has the wrong batch_shape" assert F.dim == K, "NonEmptyBitVector has the wrong dim" assert F.event_shape == (K,), "NonEmptyBitVector has the wrong event_shape" assert F.scores.shape == batch_shape + (K,), "NonEmptyBitVector.score has the wrong shape" assert F.arc_weight.shape == batch_shape + (K+1,3,3), "NonEmptyBitVector.arc_weight has the wrong shape" assert F.state_value.shape == batch_shape + (K+2,3), "NonEmptyBitVector.state_value has the wrong shape" assert F.state_rvalue.shape == batch_shape + (K+2,3), "NonEmptyBitVector.state_rvalue has the wrong shape" # shape: [num_faces] + batch_shape + [K] support = F.enumerate_support() # test shape of support assert support.shape == (2**K-1,) + batch_shape + (K,), "The support has the wrong shape" assert F.expand((2,3) + batch_shape).batch_shape == (2,3) + batch_shape, "Bad expand batch_shape" assert F.expand((2,3) + batch_shape).event_shape == (K,), "Bad expand event_shape" assert F.expand((2,3) + batch_shape).sample().shape == (2,3) + batch_shape + (K,), "Bad expand single sample" assert F.expand((2,3) + batch_shape).sample((13,)).shape == (13,2,3) + batch_shape + (K,), "Bad expand multiple samples" # Constraints assert (support.sum(-1) > 0).all(), "The support has an empty bit vector" for _ in range(100): # testing one sample at a time assert F.sample().sum(-1).all(), "I found an empty vector" # testing a batch of samples assert F.sample((100,)).sum(-1).all(), "I found an empty vector" # testing a complex batch of samples assert F.sample((2, 100,)).sum(-1).all(), "I found an empty vector" # Distribution # check for uniform probabilities assert torch.isclose(F.log_prob(support).exp(), torch.tensor(1./F.support_size)).all(), "Non-uniform" # check for uniform marginal probabilities assert torch.isclose(F.sample((10000,)).float().mean(0), support.mean(0), atol=1e-1).all(), "Bad MC marginals" assert torch.isclose(F.marginals(), support.mean(0)).all(), "Bad exact marginals" # Entropy # [num_faces, B] log_prob = F.log_prob(support) assert torch.isclose(F.entropy(), (-(log_prob.exp() * log_prob).sum(0)), atol=1e-2).all(), "Problem in the entropy DP" # Non-Uniform # Entropy P = NonEmptyBitVector(td.Normal(torch.zeros(batch_shape + (K,)), torch.ones(batch_shape + (K,))).sample()) log_p = P.log_prob(support) assert torch.isclose(P.entropy(), (-(log_p.exp() * log_p).sum(0)), atol=1e-2).all(), "Problem in the entropy DP" # Cross-Entropy Q = NonEmptyBitVector(td.Normal(torch.zeros(batch_shape + (K,)), torch.ones(batch_shape + (K,))).sample()) log_q = Q.log_prob(support) assert torch.isclose(P.cross_entropy(Q), -(log_p.exp() * log_q).sum(0), atol=1e-2).all(), "Problem in the cross-entropy DP" # KL assert torch.isclose(td.kl_divergence(P, Q), (log_p.exp() * (log_p - log_q)).sum(0), atol=1e-2).all(), "Problem in KL" # Constraints for _ in range(100): # testing one sample at a time assert P.sample().sum(-1).all(), "I found an empty vector" assert Q.sample().sum(-1).all(), "I found an empty vector" # testing a batch of samples assert P.sample((100,)).sum(-1).all(), "I found an empty vector" assert Q.sample((100,)).sum(-1).all(), "I found an empty vector" # testing a complex batch of samples assert P.sample((2, 100,)).sum(-1).all(), "I found an empty vector" assert Q.sample((2, 
100,)).sum(-1).all(), "I found an empty vector"
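# The enumeration-based checks above generalize to any finite-support
# distribution: exact KL can be validated against an explicit sum over the
# support. A minimal self-contained version with td.Categorical:
import torch
import torch.distributions as td

p = td.Categorical(logits=torch.randn(5))
q = td.Categorical(logits=torch.randn(5))
log_p = p.logits.log_softmax(-1)  # Categorical stores normalized logits; log_softmax is a safeguard
log_q = q.logits.log_softmax(-1)
kl_explicit = (log_p.exp() * (log_p - log_q)).sum(-1)
assert torch.isclose(td.kl_divergence(p, q), kl_explicit, atol=1e-6)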
def update(self, policy, old_policy, optimizer, observations, actions, returns, advantages, env_fn): env = env_fn() mirror_observation = env.mirror_observation mirror_action = env.mirror_action minibatch_size = self.minibatch_size or advantages.numel() for _ in range(self.epochs): losses = [] sampler = BatchSampler(SubsetRandomSampler( range(advantages.numel())), minibatch_size, drop_last=True) for indices in sampler: indices = torch.LongTensor(indices) obs_batch = observations[indices] # obs_batch = torch.cat( # [obs_batch, # obs_batch @ torch.Tensor(env.obs_symmetry_matrix)] # ).detach() action_batch = actions[indices] # action_batch = torch.cat( # [action_batch, # action_batch @ torch.Tensor(env.action_symmetry_matrix)] # ).detach() return_batch = returns[indices] # return_batch = torch.cat( # [return_batch, # return_batch] # ).detach() advantage_batch = advantages[indices] # advantage_batch = torch.cat( # [advantage_batch, # advantage_batch] # ).detach() values, pdf = policy.evaluate(obs_batch) # TODO, move this outside loop? with torch.no_grad(): _, old_pdf = old_policy.evaluate(obs_batch) old_log_probs = old_pdf.log_prob(action_batch).sum( -1, keepdim=True) log_probs = pdf.log_prob(action_batch).sum(-1, keepdim=True) ratio = (log_probs - old_log_probs).exp() cpi_loss = ratio * advantage_batch clip_loss = ratio.clamp(1.0 - self.clip, 1.0 + self.clip) * advantage_batch actor_loss = -torch.min(cpi_loss, clip_loss).mean() critic_loss = 0.5 * (return_batch - values).pow(2).mean() # Mirror Symmetry Loss _, deterministic_actions = policy(obs_batch) _, mirror_actions = policy(mirror_observation(obs_batch)) mirror_actions = mirror_action(mirror_actions) mirror_loss = 4 * (deterministic_actions - mirror_actions).pow(2).mean() entropy_penalty = -self.entropy_coeff * pdf.entropy().mean() # TODO: add ability to optimize critic and actor seperately, with different learning rates optimizer.zero_grad() (actor_loss + critic_loss + mirror_loss + entropy_penalty).backward() # Clip the gradient norm to prevent "unlucky" minibatches from # causing pathalogical updates torch.nn.utils.clip_grad_norm_(policy.parameters(), self.grad_clip) optimizer.step() losses.append([ actor_loss.item(), pdf.entropy().mean().item(), critic_loss.item(), ratio.mean().item(), mirror_loss.item() ]) # TODO: add verbosity arguments to suppress this print(' '.join(["%g" % x for x in np.mean(losses, axis=0)])) # Early stopping if kl_divergence(pdf, old_pdf).mean() > 0.02: print("Max kl reached, stopping optimization early.") break
def forward(self, *inputs: Tensor, **kwargs ) -> Tuple[List[MultivariateNormal], Tensor]: """Forward propagate the model. Parameters ---------- inputs: Tensor. output_sequence: Tensor. Tensor of output data [batch_size x sequence_length x dim_outputs]. input_sequence: Tensor. Tensor of input data [batch_size x sequence_length x dim_inputs]. Returns ------- output_distribution: List[Normal]. List of length sequence_length of distributions of size [batch_size x dim_outputs x num_particles] """ output_sequence, input_sequence = inputs num_particles = self.num_particles # dim_states = self.dim_states batch_size, sequence_length, dim_inputs = input_sequence.shape _, _, dim_outputs = output_sequence.shape ################################################################################ # SAMPLE GP # ################################################################################ self.forward_model.resample() self.backward_model.resample() ################################################################################ # PERFORM Backward Pass # ################################################################################ if self.training: output_distribution = self.backward(output_sequence, input_sequence) ################################################################################ # Initial State # ################################################################################ state = self.recognition(output_sequence[:, :self.recognition.length], input_sequence[:, :self.recognition.length], num_particles=num_particles) ################################################################################ # PREDICT Outputs # ################################################################################ outputs = [] y_pred = self.emissions(state) outputs.append(MultivariateNormal(y_pred.loc.detach(), y_pred.covariance_matrix.detach())) ################################################################################ # INITIALIZE losses # ################################################################################ # entropy = torch.tensor(0.) if self.training: output_distribution.pop(0) # entropy += y_tilde.entropy().mean() / sequence_length y = output_sequence[:, 0].expand(num_particles, batch_size, dim_outputs ).permute(1, 2, 0) log_lik = y_pred.log_prob(y).sum(dim=1).mean() # type: torch.Tensor l2 = ((y_pred.loc - y) ** 2).sum(dim=1).mean() # type: torch.Tensor kl_cond = torch.tensor(0.) for t in range(sequence_length - 1): ############################################################################ # PREDICT Next State # ############################################################################ u = input_sequence[:, t].expand(num_particles, batch_size, dim_inputs) u = u.permute(1, 2, 0) # Move last component to end. 
state_samples = state.rsample() state_input = torch.cat((state_samples, u), dim=1) next_f = self.forward_model(state_input) next_state = self.transitions(next_f) next_state.loc += state_samples if self.independent_particles: next_state = diagonal_covariance(next_state) ############################################################################ # CONDITION Next State # ############################################################################ if self.training: y_tilde = output_distribution.pop(0) p_next_state = next_state next_state = self._condition(next_state, y_tilde) kl_cond += kl_divergence(next_state, p_next_state).mean() ############################################################################ # RESAMPLE State # ############################################################################ state = next_state ############################################################################ # PREDICT Outputs # ############################################################################ y_pred = self.emissions(state) outputs.append(y_pred) ############################################################################ # COMPUTE Losses # ############################################################################ y = output_sequence[:, t + 1].expand( num_particles, batch_size, dim_outputs).permute(1, 2, 0) log_lik += y_pred.log_prob(y).sum(dim=1).mean() l2 += ((y_pred.loc - y) ** 2).sum(dim=1).mean() # entropy += y_tilde.entropy().mean() / sequence_length assert len(outputs) == sequence_length # if self.training: # del output_distribution ################################################################################ # Compute model KL divergences Divergences # ################################################################################ factor = 1 # batch_size / self.dataset_size kl_uf = self.forward_model.kl_divergence() kl_ub = self.backward_model.kl_divergence() if self.forward_model.independent: kl_uf *= sequence_length if self.backward_model.independent: kl_ub *= sequence_length kl_cond = kl_cond * self.loss_factors['kl_conditioning'] * factor kl_ub = kl_ub * self.loss_factors['kl_u'] * factor kl_uf = kl_uf * self.loss_factors['kl_u'] * factor if self.loss_key.lower() == 'loglik': loss = -log_lik elif self.loss_key.lower() == 'elbo': loss = -(log_lik - kl_uf - kl_ub - kl_cond) if kwargs.get('print', False): str_ = 'elbo: {}, log_lik: {}, kluf: {}, klub: {}, klcond: {}' print(str_.format(loss.item(), log_lik.item(), kl_uf.item(), kl_ub.item(), kl_cond.item())) elif self.loss_key.lower() == 'l2': loss = l2 elif self.loss_key.lower() == 'rmse': loss = torch.sqrt(l2) else: raise NotImplementedError("Key {} not implemented".format(self.loss_key)) return outputs, loss
def kld_loss(self, encoded_distribution): prior = Normal(torch.zeros_like(encoded_distribution.mean), torch.ones_like(encoded_distribution.variance)) kld = kl_divergence(encoded_distribution, prior).sum(dim=1) return kld
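# Hedged usage sketch: kld_loss above returns one KL value per batch element
# (summed over latent dimensions). In a VAE-style objective it is combined
# with a reconstruction term and averaged over the batch; the unit-variance
# Gaussian likelihood below is an illustrative assumption, not this model's
# decoder.
import torch
from torch.distributions import Normal, kl_divergence

def vae_objective(encoded_distribution: Normal, decoded_mean: torch.Tensor,
                  target: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    prior = Normal(torch.zeros_like(encoded_distribution.mean),
                   torch.ones_like(encoded_distribution.stddev))
    kld = kl_divergence(encoded_distribution, prior).sum(dim=1)
    recon = -Normal(decoded_mean, 1.0).log_prob(target).flatten(start_dim=1).sum(dim=1)
    return (recon + beta * kld).mean()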
np.random.shuffle(aux_inds) for i, start in enumerate( range(0, args.aux_batch_size, args.aux_minibatch_size)): end = start + args.aux_minibatch_size aux_minibatch_ind = aux_inds[start:end] try: m_aux_obs = aux_obs[aux_minibatch_ind].to(device) m_aux_returns = aux_returns[aux_minibatch_ind].to(device) new_pi, new_values, new_aux_values = agent.get_pi_value_and_aux_value( m_aux_obs) new_values = new_values.view(-1) new_aux_values = new_aux_values.view(-1) with torch.no_grad(): old_pi = old_agent.get_pi(m_aux_obs) kl_loss = td.kl_divergence(old_pi, new_pi).mean() real_value_loss = 0.5 * ( (new_values - m_aux_returns)**2).mean() aux_value_loss = 0.5 * ( (new_aux_values - m_aux_returns)**2).mean() joint_loss = aux_value_loss + args.beta_clone * kl_loss optimizer.zero_grad() loss = (joint_loss + real_value_loss) / args.n_aux_grad_accum loss.backward() if (i + 1) % args.n_aux_grad_accum == 0: nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) optimizer.step()
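# Hedged sketch of the PPG-style auxiliary objective assembled above: an
# auxiliary value loss plus a KL "clone" penalty that keeps the updated policy
# close to the frozen one, with the whole thing scaled for gradient
# accumulation. Argument names are illustrative.
import torch
import torch.distributions as td

def ppg_aux_loss(old_pi, new_pi, new_values, new_aux_values, returns,
                 beta_clone: float, n_aux_grad_accum: int) -> torch.Tensor:
    kl_clone = td.kl_divergence(old_pi, new_pi).mean()
    real_value_loss = 0.5 * ((new_values - returns) ** 2).mean()
    aux_value_loss = 0.5 * ((new_aux_values - returns) ** 2).mean()
    joint_loss = aux_value_loss + beta_clone * kl_clone
    return (joint_loss + real_value_loss) / n_aux_grad_accum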
def train(self, env_fn, policy, n_itr, logger=None): old_policy = deepcopy(policy) optimizer = optim.Adam(policy.parameters(), lr=self.lr, eps=self.eps) start_time = time.time() for itr in range(n_itr): print("********** Iteration {} ************".format(itr)) sample_start = time.time() batch = self.sample_parallel(env_fn, policy, self.num_steps, self.max_traj_len) print("time elapsed: {:.2f} s".format(time.time() - start_time)) print("sample time elapsed: {:.2f} s".format(time.time() - sample_start)) observations, actions, returns, values = map( torch.Tensor, batch.get()) advantages = returns - values advantages = (advantages - advantages.mean()) / (advantages.std() + self.eps) minibatch_size = self.minibatch_size or advantages.numel() print("timesteps in batch: %i" % advantages.numel()) old_policy.load_state_dict( policy.state_dict()) # WAY faster than deepcopy optimizer_start = time.time() self.update(policy, old_policy, optimizer, observations, actions, returns, advantages, env_fn) print("optimizer time elapsed: {:.2f} s".format(time.time() - optimizer_start)) if logger is not None: evaluate_start = time.time() test = self.sample_parallel(env_fn, policy, 800 // self.n_proc, self.max_traj_len, deterministic=True) print("evaluate time elapsed: {:.2f} s".format(time.time() - evaluate_start)) _, pdf = policy.evaluate(observations) _, old_pdf = old_policy.evaluate(observations) entropy = pdf.entropy().mean().item() kl = kl_divergence(pdf, old_pdf).mean().item() logger.record("Return (test)", np.mean(test.ep_returns)) logger.record("Return (batch)", np.mean(batch.ep_returns)) logger.record("Mean Eplen", np.mean(batch.ep_lens)) logger.record("Mean KL Div", kl) logger.record("Mean Entropy", entropy) logger.dump() # TODO: add option for how often to save model if itr % 10 == 0: self.save(policy)
data_loader = PairedComparison(4, direction=False, dichotomized=dichotomized, ranking=True) for i in tqdm(range(num_tasks)): if dichotomized: model1 = VariationalFirstDiscriminatingCue( data_loader.num_inputs) else: model1 = VariationalFirstCue(data_loader.num_inputs) model2 = VariationalProbitRegression(data_loader.num_inputs) inputs, targets, _, _ = data_loader.get_batch(1, num_steps) predictive_distribution1 = model1.forward(inputs, targets) predictive_distribution2 = model2.forward(inputs, targets) kl[k, i, :] = kl_divergence(predictive_distribution1, predictive_distribution2).squeeze() torch.save(kl, 'data/power_analysis.pth') print(kl.mean(1)) print(kl.mean(1).sum(-1)) else: kl = torch.load('data/power_analysis.pth').div(2.303) # ln to log10 mean = kl.sum(-1).mean(1) variance = kl.sum(-1).std(1).pow(2) num_tasks = torch.arange(1, 31) expected_bf = torch.arange(1, 31).unsqueeze(-1) * mean confidence_bf = (torch.arange(1, 31).unsqueeze(-1) * variance).sqrt() * 1.96 styles = [':', '-'] for i in range(expected_bf.shape[1]):