Example #1
def test_non_mean_field_bern_normal_elbo_gradient(enumerate1, pi1, pi2, pi3, include_z=True):
    pyro.clear_param_store()
    num_particles = 10000

    def model():
        with pyro.iarange("particles", num_particles):
            q3 = pyro.param("q3", torch.tensor(pi3, requires_grad=True))
            y = pyro.sample("y", dist.Bernoulli(q3).expand_by([num_particles]))
            if include_z:
                pyro.sample("z", dist.Normal(0.55 * y + q3, 1.0))

    def guide():
        q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
        q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
        with pyro.iarange("particles", num_particles):
            y = pyro.sample("y", dist.Bernoulli(q1).expand_by([num_particles]), infer={"enumerate": enumerate1})
            if include_z:
                pyro.sample("z", dist.Normal(q2 * y + 0.10, 1.0))

    logger.info("Computing gradients using surrogate loss")
    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    elbo.loss_and_grads(model, guide)
    actual_grad_q1 = pyro.param('q1').grad / num_particles
    if include_z:
        actual_grad_q2 = pyro.param('q2').grad / num_particles
    actual_grad_q3 = pyro.param('q3').grad / num_particles

    logger.info("Computing analytic gradients")
    q1 = torch.tensor(pi1, requires_grad=True)
    q2 = torch.tensor(pi2, requires_grad=True)
    q3 = torch.tensor(pi3, requires_grad=True)
    elbo = kl_divergence(dist.Bernoulli(q1), dist.Bernoulli(q3))
    if include_z:
        elbo = elbo + q1 * kl_divergence(dist.Normal(q2 + 0.10, 1.0), dist.Normal(q3 + 0.55, 1.0))
        elbo = elbo + (1.0 - q1) * kl_divergence(dist.Normal(0.10, 1.0), dist.Normal(q3, 1.0))
        expected_grad_q1, expected_grad_q2, expected_grad_q3 = grad(elbo, [q1, q2, q3])
    else:
        expected_grad_q1, expected_grad_q3 = grad(elbo, [q1, q3])

    prec = 0.04 if enumerate1 is None else 0.02

    assert_equal(actual_grad_q1, expected_grad_q1, prec=prec, msg="".join([
        "\nq1 expected = {}".format(expected_grad_q1.data.cpu().numpy()),
        "\nq1   actual = {}".format(actual_grad_q1.data.cpu().numpy()),
    ]))
    if include_z:
        assert_equal(actual_grad_q2, expected_grad_q2, prec=prec, msg="".join([
            "\nq2 expected = {}".format(expected_grad_q2.data.cpu().numpy()),
            "\nq2   actual = {}".format(actual_grad_q2.data.cpu().numpy()),
        ]))
    assert_equal(actual_grad_q3, expected_grad_q3, prec=prec, msg="".join([
        "\nq3 expected = {}".format(expected_grad_q3.data.cpu().numpy()),
        "\nq3   actual = {}".format(actual_grad_q3.data.cpu().numpy()),
    ]))
Example #2
def test_elbo_bern(quantity, enumerate1):
    pyro.clear_param_store()
    num_particles = 1 if enumerate1 else 10000
    prec = 0.001 if enumerate1 else 0.1
    q = pyro.param("q", torch.tensor(0.5, requires_grad=True))
    kl = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(0.25))

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(0.25).expand_by([num_particles]))

    @config_enumerate(default=enumerate1)
    def guide():
        q = pyro.param("q")
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(q).expand_by([num_particles]))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))

    if quantity == "loss":
        actual = elbo.loss(model, guide) / num_particles
        expected = kl.item()
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected),
            "\n  actual = {}".format(actual),
        ]))
    else:
        elbo.loss_and_grads(model, guide)
        actual = q.grad / num_particles
        expected = grad(kl, [q])[0]
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected.detach().cpu().numpy()),
            "\n  actual = {}".format(actual.detach().cpu().numpy()),
        ]))
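The tests in this listing all follow the same pattern: build the analytic KL with torch.distributions.kl_divergence, differentiate it with torch.autograd.grad, and compare against the surrogate (Monte-Carlo or enumerated) ELBO gradient. A minimal sketch of that analytic side in plain PyTorch, without Pyro:

import torch
from torch.distributions import Bernoulli, kl_divergence

q = torch.tensor(0.5, requires_grad=True)
kl = kl_divergence(Bernoulli(q), Bernoulli(0.25))   # closed-form KL(Bernoulli(q) || Bernoulli(0.25))
expected_grad = torch.autograd.grad(kl, [q])[0]     # the "expected" gradient the test compares against
print(kl.item(), expected_grad.item())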
Example #3
def test_non_mean_field_bern_bern_elbo_gradient(enumerate1, pi1, pi2):
    pyro.clear_param_store()
    num_particles = 1 if enumerate1 else 20000

    def model():
        with pyro.iarange("particles", num_particles):
            y = pyro.sample("y", dist.Bernoulli(0.33).expand_by([num_particles]))
            pyro.sample("z", dist.Bernoulli(0.55 * y + 0.10))

    def guide():
        q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
        q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
        with pyro.iarange("particles", num_particles):
            y = pyro.sample("y", dist.Bernoulli(q1).expand_by([num_particles]))
            pyro.sample("z", dist.Bernoulli(q2 * y + 0.10))

    logger.info("Computing gradients using surrogate loss")
    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    elbo.loss_and_grads(model, config_enumerate(guide, default=enumerate1))
    actual_grad_q1 = pyro.param('q1').grad / num_particles
    actual_grad_q2 = pyro.param('q2').grad / num_particles

    logger.info("Computing analytic gradients")
    q1 = torch.tensor(pi1, requires_grad=True)
    q2 = torch.tensor(pi2, requires_grad=True)
    elbo = kl_divergence(dist.Bernoulli(q1), dist.Bernoulli(0.33))
    elbo = elbo + q1 * kl_divergence(dist.Bernoulli(q2 + 0.10), dist.Bernoulli(0.65))
    elbo = elbo + (1.0 - q1) * kl_divergence(dist.Bernoulli(0.10), dist.Bernoulli(0.10))
    expected_grad_q1, expected_grad_q2 = grad(elbo, [q1, q2])

    prec = 0.03 if enumerate1 is None else 0.001

    assert_equal(actual_grad_q1, expected_grad_q1, prec=prec, msg="".join([
        "\nq1 expected = {}".format(expected_grad_q1.data.cpu().numpy()),
        "\nq1  actual = {}".format(actual_grad_q1.data.cpu().numpy()),
    ]))
    assert_equal(actual_grad_q2, expected_grad_q2, prec=prec, msg="".join([
        "\nq2 expected = {}".format(expected_grad_q2.data.cpu().numpy()),
        "\nq2   actual = {}".format(actual_grad_q2.data.cpu().numpy()),
    ]))
Example #4
def test_elbo_categoricals(enumerate1, enumerate2, enumerate3, max_iarange_nesting):
    pyro.clear_param_store()
    p1 = torch.tensor([0.6, 0.4])
    p2 = torch.tensor([0.3, 0.3, 0.4])
    p3 = torch.tensor([0.1, 0.2, 0.3, 0.4])
    q1 = pyro.param("q1", torch.tensor([0.4, 0.6], requires_grad=True))
    q2 = pyro.param("q2", torch.tensor([0.4, 0.3, 0.3], requires_grad=True))
    q3 = pyro.param("q3", torch.tensor([0.4, 0.3, 0.2, 0.1], requires_grad=True))

    def model():
        pyro.sample("x1", dist.Categorical(p1))
        pyro.sample("x2", dist.Categorical(p2))
        pyro.sample("x3", dist.Categorical(p3))

    def guide():
        pyro.sample("x1", dist.Categorical(pyro.param("q1")), infer={"enumerate": enumerate1})
        pyro.sample("x2", dist.Categorical(pyro.param("q2")), infer={"enumerate": enumerate2})
        pyro.sample("x3", dist.Categorical(pyro.param("q3")), infer={"enumerate": enumerate3})

    kl = (kl_divergence(dist.Categorical(q1), dist.Categorical(p1)) +
          kl_divergence(dist.Categorical(q2), dist.Categorical(p2)) +
          kl_divergence(dist.Categorical(q3), dist.Categorical(p3)))
    expected_loss = kl.item()
    expected_grads = grad(kl, [q1, q2, q3])

    elbo = TraceEnum_ELBO(max_iarange_nesting=max_iarange_nesting,
                          strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]))
    actual_loss = elbo.loss_and_grads(model, guide)
    actual_grads = [q1.grad, q2.grad, q3.grad]

    assert_equal(actual_loss, expected_loss, prec=0.001, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n  actual loss = {}".format(actual_loss),
    ]))
    for actual_grad, expected_grad in zip(actual_grads, expected_grads):
        assert_equal(actual_grad, expected_grad, prec=0.001, msg="".join([
            "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
            "\n  actual grad = {}".format(actual_grad.detach().cpu().numpy()),
        ]))
Example #5
def test_elbo_rsvi(enumerate1):
    pyro.clear_param_store()
    num_particles = 40000
    prec = 0.01 if enumerate1 else 0.02
    q = pyro.param("q", torch.tensor(0.5, requires_grad=True))
    a = pyro.param("a", torch.tensor(1.5, requires_grad=True))
    kl1 = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(0.25))
    kl2 = kl_divergence(dist.Gamma(a, 1.0), dist.Gamma(0.5, 1.0))

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(0.25).expand_by([num_particles]))
            pyro.sample("y", dist.Gamma(0.50, 1.0).expand_by([num_particles]))

    @config_enumerate(default=enumerate1)
    def guide():
        q = pyro.param("q")
        a = pyro.param("a")
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(q).expand_by([num_particles]))
            pyro.sample("y", ShapeAugmentedGamma(a, torch.tensor(1.0)).expand_by([num_particles]))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    elbo.loss_and_grads(model, guide)

    actual_q = q.grad / num_particles
    expected_q = grad(kl1, [q])[0]
    assert_equal(actual_q, expected_q, prec=prec, msg="".join([
        "\nexpected q.grad = {}".format(expected_q.detach().cpu().numpy()),
        "\n  actual q.grad = {}".format(actual_q.detach().cpu().numpy()),
    ]))
    actual_a = a.grad / num_particles
    expected_a = grad(kl2, [a])[0]
    assert_equal(actual_a, expected_a, prec=prec, msg="".join([
        "\nexpected a.grad= {}".format(expected_a.detach().cpu().numpy()),
        "\n  actual a.grad = {}".format(actual_a.detach().cpu().numpy()),
    ]))
Example #6
def test_elbo_iarange_iarange(outer_dim, inner_dim, enumerate1, enumerate2, enumerate3, enumerate4):
    pyro.clear_param_store()
    num_particles = 1 if all([enumerate1, enumerate2, enumerate3, enumerate4]) else 100000
    q = pyro.param("q", torch.tensor(0.75, requires_grad=True))
    p = 0.2693204236205713  # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5

    def model():
        d = dist.Bernoulli(p)
        with pyro.iarange("particles", num_particles):
            context1 = pyro.iarange("outer", outer_dim, dim=-2)
            context2 = pyro.iarange("inner", inner_dim, dim=-3)
            pyro.sample("w", d.expand_by([num_particles]))
            with context1:
                pyro.sample("x", d.expand_by([outer_dim, num_particles]))
            with context2:
                pyro.sample("y", d.expand_by([inner_dim, 1, num_particles]))
            with context1, context2:
                pyro.sample("z", d.expand_by([inner_dim, outer_dim, num_particles]))

    def guide():
        d = dist.Bernoulli(pyro.param("q"))
        with pyro.iarange("particles", num_particles):
            context1 = pyro.iarange("outer", outer_dim, dim=-2)
            context2 = pyro.iarange("inner", inner_dim, dim=-3)
            pyro.sample("w", d.expand_by([num_particles]), infer={"enumerate": enumerate1})
            with context1:
                pyro.sample("x", d.expand_by([outer_dim, num_particles]), infer={"enumerate": enumerate2})
            with context2:
                pyro.sample("y", d.expand_by([inner_dim, 1, num_particles]), infer={"enumerate": enumerate3})
            with context1, context2:
                pyro.sample("z", d.expand_by([inner_dim, outer_dim, num_particles]), infer={"enumerate": enumerate4})

    kl_node = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p))
    kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node
    expected_loss = kl.item()
    expected_grad = grad(kl, [q])[0]

    elbo = TraceEnum_ELBO(max_iarange_nesting=3,
                          strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3, enumerate4]))
    actual_loss = elbo.loss_and_grads(model, guide) / num_particles
    actual_grad = pyro.param('q').grad / num_particles

    assert_equal(actual_loss, expected_loss, prec=0.1, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n  actual loss = {}".format(actual_loss),
    ]))
    assert_equal(actual_grad, expected_grad, prec=0.1, msg="".join([
        "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
        "\n  actual grad = {}".format(actual_grad.detach().cpu().numpy()),
    ]))
Example #7
def test_elbo_irange_irange(outer_dim, inner_dim, enumerate1, enumerate2, enumerate3):
    pyro.clear_param_store()
    num_particles = 1 if all([enumerate1, enumerate2, enumerate3]) else 50000
    q = pyro.param("q", torch.tensor(0.75, requires_grad=True))
    p = 0.2693204236205713  # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("x", dist.Bernoulli(p).expand_by([num_particles]))
            inner_irange = pyro.irange("inner", outer_dim)
            for i in pyro.irange("outer", inner_dim):
                pyro.sample("y_{}".format(i), dist.Bernoulli(p).expand_by([num_particles]))
                for j in inner_irange:
                    pyro.sample("z_{}_{}".format(i, j), dist.Bernoulli(p).expand_by([num_particles]))

    def guide():
        q = pyro.param("q")
        with pyro.iarange("particles", num_particles):
            pyro.sample("x", dist.Bernoulli(q).expand_by([num_particles]),
                        infer={"enumerate": enumerate1})
            inner_irange = pyro.irange("inner", inner_dim)
            for i in pyro.irange("outer", outer_dim):
                pyro.sample("y_{}".format(i), dist.Bernoulli(q).expand_by([num_particles]),
                            infer={"enumerate": enumerate2})
                for j in inner_irange:
                    pyro.sample("z_{}_{}".format(i, j), dist.Bernoulli(q).expand_by([num_particles]),
                                infer={"enumerate": enumerate3})

    kl = (1 + outer_dim * (1 + inner_dim)) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p))
    expected_loss = kl.item()
    expected_grad = grad(kl, [q])[0]

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]))
    actual_loss = elbo.loss_and_grads(model, guide) / num_particles
    actual_grad = pyro.param('q').grad / num_particles

    assert_equal(actual_loss, expected_loss, prec=0.1, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n  actual loss = {}".format(actual_loss),
    ]))
    assert_equal(actual_grad, expected_grad, prec=0.1, msg="".join([
        "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
        "\n  actual grad = {}".format(actual_grad.detach().cpu().numpy()),
    ]))
Example #8
File: test_jit.py  Project: lewisKit/pyro
def test_svi_enum(Elbo, irange_dim, enumerate1, enumerate2):
    pyro.clear_param_store()
    num_particles = 10
    q = pyro.param("q", torch.tensor(0.75), constraint=constraints.unit_interval)
    p = 0.2693204236205713  # for which kl(Bernoulli(q), Bernoulli(p)) = 0.5

    def model():
        pyro.sample("x", dist.Bernoulli(p))
        for i in pyro.irange("irange", irange_dim):
            pyro.sample("y_{}".format(i), dist.Bernoulli(p))

    def guide():
        q = pyro.param("q")
        pyro.sample("x", dist.Bernoulli(q), infer={"enumerate": enumerate1})
        for i in pyro.irange("irange", irange_dim):
            pyro.sample("y_{}".format(i), dist.Bernoulli(q), infer={"enumerate": enumerate2})

    kl = (1 + irange_dim) * kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p))
    expected_loss = kl.item()
    expected_grad = grad(kl, [q.unconstrained()])[0]

    inner_particles = 2
    outer_particles = num_particles // inner_particles
    elbo = TraceEnum_ELBO(max_iarange_nesting=0,
                          strict_enumeration_warning=any([enumerate1, enumerate2]),
                          num_particles=inner_particles)
    actual_loss = sum(elbo.loss_and_grads(model, guide)
                      for i in range(outer_particles)) / outer_particles
    actual_grad = q.unconstrained().grad / outer_particles

    assert_equal(actual_loss, expected_loss, prec=0.3, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n  actual loss = {}".format(actual_loss),
    ]))
    assert_equal(actual_grad, expected_grad, prec=0.5, msg="".join([
        "\nexpected grad = {}".format(expected_grad.detach().cpu().numpy()),
        "\n  actual grad = {}".format(actual_grad.detach().cpu().numpy()),
    ]))
Example #9
def test_elbo_berns(enumerate1, enumerate2, enumerate3):
    pyro.clear_param_store()
    num_particles = 1 if all([enumerate1, enumerate2, enumerate3]) else 10000
    prec = 0.001 if all([enumerate1, enumerate2, enumerate3]) else 0.1
    q = pyro.param("q", torch.tensor(0.75, requires_grad=True))

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("x1", dist.Bernoulli(0.1).expand_by([num_particles]))
            pyro.sample("x2", dist.Bernoulli(0.2).expand_by([num_particles]))
            pyro.sample("x3", dist.Bernoulli(0.3).expand_by([num_particles]))

    def guide():
        q = pyro.param("q")
        with pyro.iarange("particles", num_particles):
            pyro.sample("x1", dist.Bernoulli(q).expand_by([num_particles]), infer={"enumerate": enumerate1})
            pyro.sample("x2", dist.Bernoulli(q).expand_by([num_particles]), infer={"enumerate": enumerate2})
            pyro.sample("x3", dist.Bernoulli(q).expand_by([num_particles]), infer={"enumerate": enumerate3})

    kl = sum(kl_divergence(dist.Bernoulli(q), dist.Bernoulli(p)) for p in [0.1, 0.2, 0.3])
    expected_loss = kl.item()
    expected_grad = grad(kl, [q])[0]

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1, enumerate2, enumerate3]))
    actual_loss = elbo.loss_and_grads(model, guide) / num_particles
    actual_grad = q.grad / num_particles

    assert_equal(actual_loss, expected_loss, prec=prec, msg="".join([
        "\nexpected loss = {}".format(expected_loss),
        "\n  actual loss = {}".format(actual_loss),
    ]))
    assert_equal(actual_grad, expected_grad, prec=prec, msg="".join([
        "\nexpected grads = {}".format(expected_grad.detach().cpu().numpy()),
        "\n  actual grads = {}".format(actual_grad.detach().cpu().numpy()),
    ]))
Example #10
File: model.py  Project: JindongJiang/GNM
    def elbo(self, x: torch.Tensor, p_dists: List, q_dists: List, lv_z: List,
             lv_g: List, lv_bg: List, pa_recon: List) -> Tuple:

        bs = x.size(0)

        p_global_all, p_pres_given_g_probs_reshaped, \
        p_where_given_g, p_depth_given_g, p_what_given_g, p_bg = p_dists

        q_global_all, q_pres_given_x_and_g_probs_reshaped, \
        q_where_given_x_and_g, q_depth_given_x_and_g, q_what_given_x_and_g, q_bg = q_dists

        y, y_nobg, alpha_map, bg = pa_recon

        if self.args.log.phase_nll:
            # (bs, dim, num_cell, num_cell)
            z_pres, _, z_depth, z_what, z_where_origin = lv_z
            # (bs * num_cell * num_cell, dim)
            z_pres_reshape = z_pres.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_pres_dim)
            z_depth_reshape = z_depth.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_depth_dim)
            z_what_reshape = z_what.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_what_dim)
            z_where_origin_reshape = z_where_origin.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_where_dim)
            # (bs, dim, 1, 1)
            z_bg = lv_bg[0]
            # (bs, step, dim, 1, 1)
            z_g = lv_g[0]
        else:
            z_pres, _, _, _, z_where_origin = lv_z

            z_pres_reshape = z_pres.permute(0, 2, 3, 1).reshape(-1, self.args.z.z_pres_dim)

        if self.args.train.p_pres_anneal_end_step != 0:
            self.aux_p_pres_probs = linear_schedule_tensor(
                self.args.train.global_step,
                self.args.train.p_pres_anneal_start_step,
                self.args.train.p_pres_anneal_end_step,
                self.args.train.p_pres_anneal_start_value,
                self.args.train.p_pres_anneal_end_value,
                self.aux_p_pres_probs.device)

        if self.args.train.aux_p_scale_anneal_end_step != 0:
            aux_p_scale_mean = linear_schedule_tensor(
                self.args.train.global_step,
                self.args.train.aux_p_scale_anneal_start_step,
                self.args.train.aux_p_scale_anneal_end_step,
                self.args.train.aux_p_scale_anneal_start_value,
                self.args.train.aux_p_scale_anneal_end_value,
                self.aux_p_where_mean.device)
            self.aux_p_where_mean[:, 0] = aux_p_scale_mean

        auxiliary_prior_z_pres_probs = self.aux_p_pres_probs[None][
            None, :].expand(bs * self.args.arch.num_cell**2, -1)

        aux_kl_pres = kl_divergence_bern_bern(
            q_pres_given_x_and_g_probs_reshaped, auxiliary_prior_z_pres_probs)
        aux_kl_where = kl_divergence(
            q_where_given_x_and_g,
            self.aux_p_where) * z_pres_reshape.clamp(min=1e-5)
        aux_kl_depth = kl_divergence(
            q_depth_given_x_and_g,
            self.aux_p_depth) * z_pres_reshape.clamp(min=1e-5)
        aux_kl_what = kl_divergence(
            q_what_given_x_and_g,
            self.aux_p_what) * z_pres_reshape.clamp(min=1e-5)

        kl_pres = kl_divergence_bern_bern(q_pres_given_x_and_g_probs_reshaped,
                                          p_pres_given_g_probs_reshaped)

        kl_where = kl_divergence(q_where_given_x_and_g, p_where_given_g)
        kl_depth = kl_divergence(q_depth_given_x_and_g, p_depth_given_g)
        kl_what = kl_divergence(q_what_given_x_and_g, p_what_given_g)

        kl_global_all = kl_divergence(q_global_all, p_global_all)

        if self.args.arch.phase_background:
            kl_bg = kl_divergence(q_bg, p_bg)
            aux_kl_bg = kl_divergence(q_bg, self.aux_p_bg)
        else:
            kl_bg = self.background.new_zeros(bs, 1)
            aux_kl_bg = self.background.new_zeros(bs, 1)

        log_like = Normal(y, self.args.const.likelihood_sigma).log_prob(x)

        log_imp_list = []
        if self.args.log.phase_nll:
            log_pres_prior = z_pres_reshape * torch.log(p_pres_given_g_probs_reshaped + self.args.const.eps) + \
                             (1 - z_pres_reshape) * torch.log(1 - p_pres_given_g_probs_reshaped + self.args.const.eps)
            log_pres_pos = z_pres_reshape * torch.log(q_pres_given_x_and_g_probs_reshaped + self.args.const.eps) + \
                           (1 - z_pres_reshape) * torch.log(
                1 - q_pres_given_x_and_g_probs_reshaped + self.args.const.eps)

            log_imp_pres = log_pres_prior - log_pres_pos

            log_imp_depth = p_depth_given_g.log_prob(z_depth_reshape) - \
                            q_depth_given_x_and_g.log_prob(z_depth_reshape)

            log_imp_what = p_what_given_g.log_prob(z_what_reshape) - \
                           q_what_given_x_and_g.log_prob(z_what_reshape)

            log_imp_where = p_where_given_g.log_prob(z_where_origin_reshape) - \
                            q_where_given_x_and_g.log_prob(z_where_origin_reshape)

            if self.args.arch.phase_background:
                log_imp_bg = p_bg.log_prob(z_bg) - q_bg.log_prob(z_bg)
            else:
                log_imp_bg = x.new_zeros(bs, 1)

            log_imp_g = p_global_all.log_prob(z_g) - q_global_all.log_prob(z_g)

            log_imp_list = [
                log_imp_pres.view(bs, self.args.arch.num_cell,
                                  self.args.arch.num_cell,
                                  -1).flatten(start_dim=1).sum(1),
                log_imp_depth.view(bs, self.args.arch.num_cell,
                                   self.args.arch.num_cell,
                                   -1).flatten(start_dim=1).sum(1),
                log_imp_what.view(bs, self.args.arch.num_cell,
                                  self.args.arch.num_cell,
                                  -1).flatten(start_dim=1).sum(1),
                log_imp_where.view(bs, self.args.arch.num_cell,
                                   self.args.arch.num_cell,
                                   -1).flatten(start_dim=1).sum(1),
                log_imp_bg.flatten(start_dim=1).sum(1),
                log_imp_g.flatten(start_dim=1).sum(1),
            ]

        return log_like.flatten(start_dim=1).sum(1), \
               [
                   aux_kl_pres.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(
                       -1),
                   aux_kl_where.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(
                       -1),
                   aux_kl_depth.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(
                       -1),
                   aux_kl_what.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(
                       -1),
                   aux_kl_bg.flatten(start_dim=1).sum(-1),
                   kl_pres.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1),
                   kl_where.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1),
                   kl_depth.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1),
                   kl_what.view(bs, self.args.arch.num_cell, self.args.arch.num_cell, -1).flatten(start_dim=1).sum(-1),
                   kl_global_all.flatten(start_dim=2).sum(-1),
                   kl_bg.flatten(start_dim=1).sum(-1)
               ], log_imp_list
Example #11
    l2_sum = 0
    kl_sum = 0
    disc_inv_sum = 0
    disc_latent_sum = 0
    total_sum = 0
    for i in range(args.samples_per_epoch):
        x,target = next(train_loader)
        x = x.cuda()
        target = target.cuda()
        z_inv,mu_logvar,hidden = dvml_model(x)

        q_dist = utils.get_normal_from_params(mu_logvar)
        p_dist = dist.Normal(0,1)

        #KL Loss
        kl_loss = dist.kl_divergence(q_dist,p_dist).sum(-1).mean()

        #Discriminate invar
        disc_inv_loss = triplet_loss(z_inv,target)

        if phase1:
            z_var = q_dist.sample()
            z = z_inv + z_var
            decoded = decoder_model(z.detach())
            #l2_loss
            l2_loss = F.mse_loss(hidden,decoded)
        else:
            z_var = q_dist.sample((20,))
            z = z_inv[None] + z_var
            decoded = decoder_model(z.reshape([-1,args.z_dims]))
            decoded = decoded.reshape([20,-1,1024])
Example #12
	def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
		# Get the relevant quantities
		rewards = batch["reward"][:, :-1]
		actions = batch["actions"][:, :-1]
		terminated = batch["terminated"][:, :-1].float()
		mask = batch["filled"][:, :-1].float()
		mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
		avail_actions = batch["avail_actions"]

		# Calculate estimated Q-Values
		# shape = (bs, self.n_agents, -1)
		mac_out = []
		mu_out = []
		sigma_out = []
		logits_out = []
		m_sample_out = []
		g_out = []
		#reconstruct_losses = []
		self.mac.init_hidden(batch.batch_size)
		for t in range(batch.max_seq_length):
			if self.args.comm and self.args.use_IB:
				# agent_outs, (mu, sigma), logits, m_sample = self.mac.forward(batch, t=t)
				# mu_out.append(mu)
				# sigma_out.append(sigma)
				# logits_out.append(logits)
				# m_sample_out.append(m_sample)
				
				#reconstruct_losses.append((info['reconstruct_loss']**2).sum())
				agent_outs, info = self.mac.forward(batch, t=t)
				mu_out.append(info['mu'])
				sigma_out.append(info['sigma'])
				logits_out.append(info['logits'])
				m_sample_out.append(info['m_sample'])
				
			else:
				agent_outs = self.mac.forward(batch, t=t)
			mac_out.append(agent_outs)
		mac_out = th.stack(mac_out, dim=1)  # Concat over time
		if self.args.use_IB:
			mu_out = th.stack(mu_out, dim=1)[:, :-1]  # Concat over time
			sigma_out = th.stack(sigma_out, dim=1)[:, :-1]  # Concat over time
			logits_out = th.stack(logits_out, dim=1)[:, :-1]
			m_sample_out = th.stack(m_sample_out, dim=1)[:, :-1]

		# Pick the Q-Values for the actions taken by each agent
		chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim
		# I believe that code up to here is right...

		# Q values are right, the main issue is to calculate loss for message...

		# Calculate the Q-Values necessary for the target
		target_mac_out = []
		self.target_mac.init_hidden(batch.batch_size)
		for t in range(batch.max_seq_length):
			if self.args.comm and self.args.use_IB:
				#target_agent_outs, (target_mu, target_sigma), target_logits, target_m_sample = \
				#	self.target_mac.forward(batch, t=t)
				target_agent_outs, target_info = self.target_mac.forward(batch, t=t)
			else:
				target_agent_outs = self.target_mac.forward(batch, t=t)
			target_mac_out.append(target_agent_outs)

		# label
		label_target_max_out = th.stack(target_mac_out[:-1], dim=1)
		label_target_max_out[avail_actions[:, :-1] == 0] = -9999999
		label_target_actions = label_target_max_out.max(dim=3, keepdim=True)[1]

		# We don't need the first timesteps Q-Value estimate for calculating targets
		target_mac_out = th.stack(target_mac_out[1:], dim=1)  # Concat across time

		# Mask out unavailable actions
		target_mac_out[avail_actions[:, 1:] == 0] = -9999999

		# Max over target Q-Values
		if self.args.double_q:
			# Get actions that maximise live Q (for double q-learning)
			mac_out_detach = mac_out.clone().detach()
			mac_out_detach[avail_actions == 0] = -9999999
			cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]

			target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
		else:
			target_max_qvals = target_mac_out.max(dim=3)[0]

		# Mix
		if self.mixer is not None:
			chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1])
			target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:])

		# Calculate 1-step Q-Learning targets
		targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals

		# Td-error
		td_error = (chosen_action_qvals - targets.detach())

		mask = mask.expand_as(td_error)

		# 0-out the targets that came from padded data
		masked_td_error = td_error * mask

		# Normal L2 loss, take mean over actual data
		loss = (masked_td_error ** 2).sum() / mask.sum()

		if self.args.only_downstream or not self.args.use_IB:
			expressiveness_loss = th.Tensor([0.])
			compactness_loss = th.Tensor([0.])
			entropy_loss = th.Tensor([0.])
			comm_loss = th.Tensor([0.])
			comm_beta = th.Tensor([0.])
			comm_entropy_beta = th.Tensor([0.])
			#rec_loss = sum(reconstruct_losses)
			#loss += 0.01*rec_loss
		else:
			# ### Optimize message
			# Message are controlled only by expressiveness and compactness loss.
			# Compute cross entropy with target q values of the same time step
			expressiveness_loss = 0
			label_prob = th.gather(logits_out, 3, label_target_actions).squeeze(3)
			expressiveness_loss += (-th.log(label_prob + 1e-6)).sum() / mask.sum()

			# Compute KL divergence
			compactness_loss = D.kl_divergence(D.Normal(mu_out, sigma_out), D.Normal(self.s_mu, self.s_sigma)).sum() / \
			                   mask.sum()

			# Entropy loss
			entropy_loss = -D.Normal(self.s_mu, self.s_sigma).log_prob(m_sample_out).sum() / mask.sum()

			# Gate loss
			gate_loss = 0

			# Total loss
			comm_beta = self.get_comm_beta(t_env)
			comm_entropy_beta = self.get_comm_entropy_beta(t_env)
			comm_loss = expressiveness_loss + comm_beta * compactness_loss + comm_entropy_beta * entropy_loss
			comm_loss *= self.args.c_beta
			loss += comm_loss
			comm_beta = th.Tensor([comm_beta])
			comm_entropy_beta = th.Tensor([comm_entropy_beta])

		# Optimise
		self.optimiser.zero_grad()
		loss.backward()
		grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
		self.optimiser.step()

		# Update target
		if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
			self._update_targets()
			self.last_target_update_episode = episode_num

		if t_env - self.log_stats_t >= self.args.learner_log_interval:
			self.logger.log_stat("loss", loss.item(), t_env)
			self.logger.log_stat("comm_loss", comm_loss.item(), t_env)
			self.logger.log_stat("exp_loss", expressiveness_loss.item(), t_env)
			self.logger.log_stat("comp_loss", compactness_loss.item(), t_env)
			self.logger.log_stat("comm_beta", comm_beta.item(), t_env)
			self.logger.log_stat("entropy_loss", entropy_loss.item(), t_env)
			self.logger.log_stat("comm_beta", comm_beta.item(), t_env)
			self.logger.log_stat("comm_entropy_beta", comm_entropy_beta.item(), t_env)
			self.logger.log_stat("grad_norm", grad_norm, t_env)
			mask_elems = mask.sum().item()
			self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env)
			self.logger.log_stat("q_taken_mean",
			                     (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env)
			self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents),
			                     t_env)
			#self.logger.log_stat("reconstruct_loss", rec_loss.item(), t_env)
			self.log_stats_t = t_env
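A hedged sketch of the compactness term computed above: the KL between the per-message Gaussian posterior N(mu, sigma) and the fixed prior N(s_mu, s_sigma), normalised by the number of unpadded timesteps. The shapes and values below are illustrative assumptions, not the project's real dimensions:

import torch
import torch.distributions as D

mu = torch.zeros(4, 10, 3, 8)                  # assumed (batch, time, agents, message_dim)
sigma = torch.ones(4, 10, 3, 8)
s_mu, s_sigma = torch.zeros(1), torch.ones(1)  # fixed message prior, broadcast over the batch
mask = torch.ones(4, 10, 1)                    # 1 for valid timesteps, 0 for padding

compactness_loss = D.kl_divergence(D.Normal(mu, sigma), D.Normal(s_mu, s_sigma)).sum() / mask.sum()
print(compactness_loss.item())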
Example #13
def run_epoch(experiment, network, optimizer, dataloader, config, use_tqdm = False, debug=False, plot=False):
    cum_loss = 0.0
    cum_param_loss = 0.0
    cum_position_loss = 0.0
    cum_velocity_loss = 0.0
    num_samples=0.0
    if use_tqdm:
        t = tqdm(enumerate(dataloader), total=len(dataloader))
    else:
        t = enumerate(dataloader)
    network.train()  # This is important to call before training!
    dataloaderlen = len(dataloader)
    
    dev = next(network.parameters()).device  # we are only doing single-device training for now, so this works fine.
    dtype = next(network.parameters()).dtype # we are only doing single-device training for now, so this works fine.
    loss_weights = config["loss_weights"]
    positionerror = loss_functions.SquaredLpNormLoss().type(dtype).to(dev)

    #_, _, _, _, _, _, sample_session_times,_,_ = dataloader.dataset[0]
    bezier_order = network.bezier_order
    d = network.output_dimension
    for (i, imagedict) in t:
        track_names = imagedict["track"]
        input_images = imagedict["images"].type(dtype).to(device=dev)
        batch_size = input_images.shape[0]
        session_times = imagedict["session_times"].type(dtype).to(device=dev)
        ego_positions = imagedict["ego_positions"].type(dtype).to(device=dev)
        ego_velocities = imagedict["ego_velocities"].type(dtype).to(device=dev)
        targets = ego_positions

        
        dt = session_times[:,-1]-session_times[:,0]
        s_torch_cur = (session_times - session_times[:,0,None])/dt[:,None]

        M, controlpoints_fit = deepracing_models.math_utils.bezier.bezierLsqfit(targets, bezier_order, t = s_torch_cur)
        Msquare = torch.square(M)

                
        means, varfactors, covarfactors = network(input_images)
        scale_trils = torch.diag_embed(varfactors) + torch.diag_embed(covarfactors, offset=-1)
        covars = torch.matmul(scale_trils, scale_trils.transpose(2,3))
        covars_expand = covars.unsqueeze(1).expand(batch_size, Msquare.shape[1], Msquare.shape[2], d, d)
        poscovar = torch.sum(Msquare[:,:,:,None,None]*covars_expand, dim=2)
        posmeans = torch.matmul(M, means)
        initial_points = targets[:,0].unsqueeze(1)
        final_points = (dt[:,None]*ego_velocities[:,0]).unsqueeze(1)
        deltas = final_points - initial_points
        ds = torch.linspace(0.0,1.0,steps=means.shape[1])
        straight_lines = torch.cat([initial_points + t.item()*deltas for t in ds], dim=1)
        priorscaletril = torch.diag_embed(torch.ones_like(straight_lines))
        priorcurves = D.MultivariateNormal(controlpoints_fit, scale_tril=priorscaletril, validate_args=False)
        distcurves = D.MultivariateNormal(means, scale_tril=scale_trils, validate_args=False)
        distpos = D.MultivariateNormal(posmeans, covariance_matrix=poscovar, validate_args=False)

        position_error = positionerror(posmeans, targets)
        log_probs = distpos.log_prob(ego_positions)
        NLL = torch.mean(-log_probs)
        kl_divergences = D.kl_divergence(distcurves, priorcurves)
        mean_kl = torch.mean(kl_divergences)

        
        if debug and plot:
            fig, (ax1, ax2) = plt.subplots(1, 2, sharey=False)
            print("position_error: %f" % position_error.item() )
            images_np = np.round(255.0*input_images[0].detach().cpu().numpy().copy().transpose(0,2,3,1)).astype(np.uint8)
            #image_np_transpose=skimage.util.img_as_ubyte(images_np[-1].transpose(1,2,0))
            # oap = other_agent_positions[other_agent_positions==other_agent_positions].view(1,-1,60,2)
            # print(oap)
            ims = []
            for i in range(images_np.shape[0]):
                ims.append([ax1.imshow(images_np[i])])
            ani = animation.ArtistAnimation(fig, ims, interval=250, blit=True, repeat=True)

            fit_points = torch.matmul(M, controlpoints_fit)
            prior_points = torch.matmul(M, straight_lines)

            # gt_points_np = ego_positions[0].detach().cpu().numpy().copy()
            gt_points_np = targets[0].detach().cpu().numpy().copy()
            pred_points_np = posmeans[0].detach().cpu().numpy().copy()
            pred_control_points_np = means[0].detach().cpu().numpy().copy()
            fit_points_np = fit_points[0].cpu().numpy().copy()
            fit_control_points_np = controlpoints_fit[0].cpu().numpy().copy()

            prior_points_np = prior_points[0].cpu().numpy().copy()
            prior_control_points_np = straight_lines[0].cpu().numpy().copy()
            
            ymin = np.min(np.hstack([gt_points_np[:,1], pred_points_np[:,1] ])) - 2.5
            ymax = np.max(np.hstack([gt_points_np[:,1], pred_points_np[:,1] ])) + 2.5
            xmin = np.min(np.hstack([gt_points_np[:,0], fit_points_np[:,0] ])) -  2.5
            xmax = np.max(np.hstack([gt_points_np[:,0], fit_points_np[:,0] ]))
            ax2.set_xlim(xmax,xmin)
            ax2.set_ylim(ymin,ymax)
            ax2.plot(gt_points_np[:,0],gt_points_np[:,1],'g+', label="Ground Truth Waypoints")
            ax2.plot(pred_points_np[:,0],pred_points_np[:,1],'r-', label="Predicted Bézier Curve")
            ax2.plot(prior_points_np[:,0],prior_points_np[:,1], label="Prior")
            # ax2.plot(fit_points_np[:,0],fit_points_np[:,1],'b-', label="Best-fit Bézier Curve")
            #ax2.scatter(fit_control_points_np[1:,0],fit_control_points_np[1:,1],c="b", label="Bézier Curve's Control Points")
       #     ax2.plot(pred_points_np[:,1],pred_points_np[:,0],'r-', label="Predicted Bézier Curve")
          #  ax2.scatter(pred_control_points_np[:,1],pred_control_points_np[:,0], c='r', label="Predicted Bézier Curve's Control Points")
            plt.legend()
            plt.show()

        # loss = position_error
        loss = loss_weights["position"]*position_error + loss_weights["nll"]*NLL + loss_weights["kl_divergence"]*mean_kl
        #loss = loss_weights["position"]*position_error
        optimizer.zero_grad()
        loss.backward() 
        # Weight and bias updates.
        optimizer.step()
        # logging information
        current_position_loss_float = float(position_error.item())
        num_samples += 1.0
        if not debug:
            experiment.log_metric("current_position_loss", current_position_loss_float)
            experiment.log_metric("logprob", NLL.item())
            experiment.log_metric("kl_divergence", mean_kl.item())
        if use_tqdm:
            t.set_postfix({"current_position_loss" : current_position_loss_float})
Example #14
    def _kl_loss(self, prior_dist, post_dist):
        return td.kl_divergence(prior_dist, post_dist).clamp(min=self.kl_free_nats).mean()
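A short sketch of the "free nats" clamp used above (common in PlaNet/Dreamer-style latent models): any KL term below the threshold is clamped to a constant, so it exerts no gradient pressure on the posterior. The threshold and shapes are assumptions; td is torch.distributions:

import torch
import torch.distributions as td

kl_free_nats = 3.0
prior_dist = td.Normal(torch.zeros(8), torch.ones(8))
post_dist = td.Normal(0.1 * torch.ones(8), torch.ones(8))
kl_loss = td.kl_divergence(prior_dist, post_dist).clamp(min=kl_free_nats).mean()
print(kl_loss.item())   # equals kl_free_nats here, since the true KL is below the threshold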
Example #15
def _prod_prod(p, q):
    assert p.n_dists == q.n_dists, "I need the same number of distributions"
    return torch.cat([
        kl_divergence(pi, qi).unsqueeze(-1)
        for pi, qi in zip(p.distributions, q.distributions)
    ], -1).sum(-1)
Example #16
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)
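The rule above mirrors the stock Independent/Independent KL in torch.distributions: take the base-distribution KL and sum it over the reinterpreted event dimensions. A minimal check of that behaviour with standard distributions:

import torch
from torch.distributions import Independent, Normal, kl_divergence

p = Independent(Normal(torch.zeros(3, 4), torch.ones(3, 4)), 1)        # last axis treated as an event dim
q = Independent(Normal(torch.ones(3, 4), 2.0 * torch.ones(3, 4)), 1)
kl_joint = kl_divergence(p, q)                                         # shape (3,)
kl_by_hand = kl_divergence(p.base_dist, q.base_dist).sum(-1)           # same values, summed manually
print(torch.allclose(kl_joint, kl_by_hand))                            # True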
Example #17
    def learn(  # type: ignore
            self, batch: Batch, batch_size: int, repeat: int,
            **kwargs: Any) -> Dict[str, List[float]]:
        actor_losses, vf_losses, step_sizes, kls = [], [], [], []
        for step in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                # optimize actor
                # direction: calculate vanilla gradient
                dist = self(b).dist  # TODO could come from batch
                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                actor_loss = -(ratio * b.adv).mean()
                flat_grads = self._get_flat_grad(actor_loss,
                                                 self.actor,
                                                 retain_graph=True).detach()

                # direction: calculate natural gradient
                with torch.no_grad():
                    old_dist = self(b).dist

                kl = kl_divergence(old_dist, dist).mean()
                # calculate first order gradient of kl with respect to theta
                flat_kl_grad = self._get_flat_grad(kl,
                                                   self.actor,
                                                   create_graph=True)
                search_direction = -self._conjugate_gradients(
                    flat_grads, flat_kl_grad, nsteps=10)

                # stepsize: calculate max stepsize constrained by kl bound
                step_size = torch.sqrt(
                    2 * self._delta / (search_direction * self._MVP(
                        search_direction, flat_kl_grad)).sum(0, keepdim=True))

                # stepsize: linesearch stepsize
                with torch.no_grad():
                    flat_params = torch.cat([
                        param.data.view(-1)
                        for param in self.actor.parameters()
                    ])
                    for i in range(self._max_backtracks):
                        new_flat_params = flat_params + step_size * search_direction
                        self._set_from_flat_params(self.actor, new_flat_params)
                        # check that the KL stays within the bound and the loss actually decreased
                        new_dist = self(b).dist
                        new_dratio = (new_dist.log_prob(b.act) -
                                      b.logp_old).exp().float()
                        new_dratio = new_dratio.reshape(
                            new_dratio.size(0), -1).transpose(0, 1)
                        new_actor_loss = -(new_dratio * b.adv).mean()
                        kl = kl_divergence(old_dist, new_dist).mean()

                        if kl < self._delta and new_actor_loss < actor_loss:
                            if i > 0:
                                warnings.warn(f"Backtracking to step {i}.")
                            break
                        elif i < self._max_backtracks - 1:
                            step_size = step_size * self._backtrack_coeff
                        else:
                            self._set_from_flat_params(self.actor,
                                                       new_flat_params)
                            step_size = torch.tensor([0.0])
                            warnings.warn(
                                "Line search failed! It seems hyperparamters"
                                " are poor and need to be changed.")

                # optimize critic
                for _ in range(self._optim_critic_iters):
                    value = self.critic(b.obs).flatten()
                    vf_loss = F.mse_loss(b.returns, value)
                    self.optim.zero_grad()
                    vf_loss.backward()
                    self.optim.step()

                actor_losses.append(actor_loss.item())
                vf_losses.append(vf_loss.item())
                step_sizes.append(step_size.item())
                kls.append(kl.item())

        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "step_size": step_sizes,
            "kl": kls,
        }
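In the line search above, a candidate update is accepted only if the mean KL between the old and the new policy stays below the trust-region radius self._delta and the surrogate loss actually improves. A toy sketch of that acceptance test with Gaussian policies; every shape and value here is illustrative:

import torch
from torch.distributions import Normal, kl_divergence

delta = 0.01                                                    # assumed trust-region radius
old_dist = Normal(torch.zeros(32, 2), torch.ones(32, 2))        # frozen policy outputs on a batch
new_dist = Normal(0.05 * torch.ones(32, 2), torch.ones(32, 2))  # candidate policy after the step
kl = kl_divergence(old_dist, new_dist).mean()
print(kl.item(), bool(kl < delta))                              # small mean shift -> step accepted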
Example #18
    old_agent.load_state_dict(agent.state_dict())
    aux_inds = np.arange(args.aux_batch_size, )
    print("aux phase starts")
    for auxiliary_update in range(1, args.e_auxiliary + 1):
        np.random.shuffle(aux_inds)
        for i, start in enumerate(
                range(0, args.aux_batch_size, args.aux_minibatch_size)):
            end = start + args.aux_minibatch_size
            aux_minibatch_ind = aux_inds[start:end]
            try:
                m_aux_obs = aux_obs[aux_minibatch_ind].to(device)
                m_aux_returns = aux_returns[aux_minibatch_ind].to(device)

                new_values = agent.get_value(m_aux_obs).view(-1)
                new_aux_values = agent.get_aux_value(m_aux_obs).view(-1)
                kl_loss = td.kl_divergence(old_agent.get_pi(m_aux_obs),
                                           agent.get_pi(m_aux_obs)).mean()

                real_value_loss = 0.5 * (
                    (new_values - m_aux_returns)**2).mean()
                aux_value_loss = 0.5 * (
                    (new_aux_values - m_aux_returns)**2).mean()
                joint_loss = aux_value_loss + args.beta_clone * kl_loss

                optimizer.zero_grad()
                loss = (joint_loss + real_value_loss) / args.n_aux_grad_accum
                loss.backward()
                if (i + 1) % args.n_aux_grad_accum == 0:
                    nn.utils.clip_grad_norm_(agent.parameters(),
                                             args.max_grad_norm)
                    optimizer.step()
            except RuntimeError:
Example #19
if __name__ == '__main__':
    from pyro.distributions import Dirichlet, MultivariateNormal
    from torch.distributions import kl_divergence
    from distributions.mixture import Mixture

    B, D1, D2 = 5, 3, 4
    N = 1000

    dist1 = MultivariateNormal(torch.zeros(D1), torch.eye(D1)).expand((B, ))
    dist2 = Dirichlet(torch.ones(D2)).expand((B, ))
    print(dist1.batch_shape, dist1.event_shape)
    print(dist2.batch_shape, dist2.event_shape)
    fact = Factorised([dist1, dist2])
    print(fact.batch_shape, fact.event_shape)
    samples = fact.rsample((N, ))
    print(samples[0])
    print(samples.shape)
    logp = fact.log_prob(samples)
    print(logp.shape)
    entropy = fact.entropy()
    print(entropy.shape)
    print(entropy, -logp.mean())
    print()

    print(kl_divergence(fact, fact))
    mixture = Mixture(torch.ones(B), fact)
    samples = mixture.rsample((N, ))
    logp = mixture.log_prob(samples)
    print(samples.shape)
    print(logp.shape)
Example #20
def _kl_factorised_factorised(p: Factorised, q: Factorised):
    return sum(
        kl_divergence(p_factor, q_factor)
        for p_factor, q_factor in zip(p.factors, q.factors))
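Examples #15 and #20 both express the KL of a factorised (product) distribution as the sum of per-factor KLs. A minimal sketch of that identity with stock distributions, without the custom Factorised class:

import torch
from torch.distributions import Normal, Dirichlet, kl_divergence

p1, q1 = Normal(0.0, 1.0), Normal(1.0, 2.0)
p2, q2 = Dirichlet(torch.ones(4)), Dirichlet(torch.tensor([2.0, 1.0, 1.0, 1.0]))
kl_total = kl_divergence(p1, q1) + kl_divergence(p2, q2)   # KL(p1 x p2 || q1 x q2) for independent factors
print(kl_total.item())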
Example #21
def total_kld(posterior, prior=None):
    if prior is None:
        prior = standard_prior_like(posterior)
    return torch.sum(kl_divergence(posterior, prior))
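total_kld above relies on standard_prior_like, which presumably builds a matching standard-normal prior; a hedged sketch of what the call then computes for a diagonal Gaussian posterior (shapes are assumptions):

import torch
from torch.distributions import Normal, kl_divergence

posterior = Normal(torch.randn(16, 32), torch.rand(16, 32) + 0.1)                 # assumed (batch, latent_dim)
prior = Normal(torch.zeros_like(posterior.loc), torch.ones_like(posterior.scale))  # standard-normal prior
total = torch.sum(kl_divergence(posterior, prior))                                  # summed over batch and latent dims
print(total.item())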
Example #22
    def forward(self,
                x,
                img_enc,
                alpha_map_prop,
                ids_prop,
                lengths,
                t,
                eps=1e-15):
        """
            :param z_what_prop: (bs, max_num_obj, dim)
            :param z_where_prop: (bs, max_num_obj, 4)
            :param z_pres_prop: (bs, max_num_obj, 1)
            :param alpha_map_prop: (bs, 1, img_h, img_w)
        """
        bs = x.size(0)
        device = x.device
        alpha_map_prop = alpha_map_prop.detach()

        max_num_disc_obj = (self.max_num_obj - lengths).long()

        self.prior_z_pres_prob = linear_annealing(
            self.args.global_step, self.z_pres_anneal_start_step,
            self.z_pres_anneal_end_step, self.z_pres_anneal_start_value,
            self.z_pres_anneal_end_value, device)

        # z_where: (bs * num_cell_h * num_cell_w, 4)
        # z_pres, z_depth, z_pres_logits: (bs, dim, num_cell_h, num_cell_w)
        z_where, z_pres, z_depth, z_where_mean, z_where_std, \
        z_depth_mean, z_depth_std, z_pres_logits, z_pres_y, z_where_origin = self.ProposalNet(
            img_enc, alpha_map_prop, self.args.tau, t, gen_pres_probs=x.new_ones(1) * self.args.gen_disc_pres_probs,
            gen_depth_mean=self.prior_depth_mean, gen_depth_std=self.prior_depth_std,
            gen_where_mean=self.prior_where_mean, gen_where_std=self.prior_where_std
        )
        num_cell_h, num_cell_w = z_pres.shape[2], z_pres.shape[3]

        q_z_where = Normal(z_where_mean, z_where_std)
        q_z_depth = Normal(z_depth_mean, z_depth_std)

        z_pres_orgin = z_pres

        if self.args.phase_generate and t >= self.args.observe_frames:
            z_what_mean, z_what_std = self.prior_what_mean.view(1, 1).expand(bs * self.args.num_cell_h *
                                                                             self.args.num_cell_w, z_what_dim), \
                                      self.prior_what_std.view(1, 1).expand(bs * self.args.num_cell_h *
                                                                            self.args.num_cell_w, z_what_dim)
            x_att = x.new_zeros(1)
        else:
            # (bs * num_cell_h * num_cell_w, 3, glimpse_size, glimpse_size)
            x_att = spatial_transform(
                torch.stack(num_cell_h * num_cell_w * (x, ),
                            dim=1).view(-1, 3, img_h, img_w),
                z_where,
                (bs * num_cell_h * num_cell_w, 3, glimpse_size, glimpse_size),
                inverse=False)

            # (bs * num_cell_h * num_cell_w, dim)
            z_what_mean, z_what_std = self.z_what_net(x_att)
            z_what_std = F.softplus(z_what_std)

        q_z_what = Normal(z_what_mean, z_what_std)

        z_what = q_z_what.rsample()

        # (bs * num_cell_h * num_cell_w, dim, glimpse_size, glimpse_size)
        o_att, alpha_att = self.glimpse_dec(z_what)

        # Rejection
        if phase_rejection and t > 0:
            alpha_map_raw = spatial_transform(
                alpha_att,
                z_where, (bs * num_cell_h * num_cell_w, 1, img_h, img_w),
                inverse=True)

            alpha_map_proposed = (alpha_map_raw > 0.3).float()

            alpha_map_prop = (alpha_map_prop > 0.1).float().view(bs, 1, 1, img_h, img_w) \
                .expand(-1, num_cell_h * num_cell_w, -1, -1, -1).contiguous().view(-1, 1, img_h, img_w)

            alpha_map_intersect = alpha_map_proposed * alpha_map_prop

            explained_ratio = alpha_map_intersect.view(bs * num_cell_h * num_cell_w, -1).sum(1) / \
                              (alpha_map_proposed.view(bs * num_cell_h * num_cell_w, -1).sum(1) + eps)

            pres_mask = (explained_ratio <
                         self.args.explained_ratio_threshold).view(
                             bs, 1, num_cell_h, num_cell_w).float()

            z_pres = z_pres * pres_mask

        # The following "if" is useful only if you don't have high-memery GPUs, better to remove it if you do
        if self.training and phase_obj_num_contrain:
            z_pres = z_pres.view(bs, -1)

            z_pres_threshold = z_pres.sort(
                dim=1, descending=True)[0][torch.arange(bs), max_num_disc_obj]

            z_pres_mask = (z_pres > z_pres_threshold.view(bs, -1)).float()

            if self.args.phase_generate and t >= self.args.observe_frames:
                z_pres_mask = x.new_zeros(z_pres_mask.size())

            z_pres = z_pres * z_pres_mask

            z_pres = z_pres.view(bs, 1, num_cell_h, num_cell_w)

        alpha_att_hat = alpha_att * z_pres.view(-1, 1, 1, 1)

        y_att = alpha_att_hat * o_att

        # (bs * num_cell_h * num_cell_w, 3, img_h, img_w)
        y_each_cell = spatial_transform(
            y_att,
            z_where, (bs * num_cell_h * num_cell_w, 3, img_h, img_w),
            inverse=True)

        # (bs * num_cell_h * num_cell_w, 1, glimpse_size, glimpse_size)
        importance_map = alpha_att_hat * torch.sigmoid(-z_depth).view(
            -1, 1, 1, 1)
        # importance_map = -z_depth.view(-1, 1, 1, 1).expand_as(alpha_att_hat)
        # (bs * num_cell_h * num_cell_w, 1, img_h, img_w)
        importance_map_full_res = spatial_transform(
            importance_map,
            z_where, (bs * num_cell_h * num_cell_w, 1, img_h, img_w),
            inverse=True)

        # (bs * num_cell_h * num_cell_w, 1, img_h, img_w)
        alpha_map = spatial_transform(
            alpha_att_hat,
            z_where, (bs * num_cell_h * num_cell_w, 1, img_h, img_w),
            inverse=True)

        # (bs * num_cell_h * num_cell_w, z_what_dim)
        kl_z_what = kl_divergence(q_z_what, self.p_z_what) * z_pres_orgin.view(
            -1, 1)
        # (bs, num_cell_h * num_cell_w, z_what_dim)
        kl_z_what = kl_z_what.view(-1, num_cell_h * num_cell_w, z_what_dim)
        # (bs * num_cell_h * num_cell_w, z_depth_dim)
        kl_z_depth = kl_divergence(q_z_depth, self.p_z_depth) * z_pres_orgin
        # (bs, num_cell_h * num_cell_w, z_depth_dim)
        kl_z_depth = kl_z_depth.view(-1, num_cell_h * num_cell_w, z_depth_dim)
        # (bs, dim, num_cell_h, num_cell_w)
        kl_z_where = kl_divergence(q_z_where, self.p_z_where) * z_pres_orgin
        if phase_rejection and t > 0:
            kl_z_pres = calc_kl_z_pres_bernoulli(
                z_pres_logits, self.prior_z_pres_prob * pres_mask +
                self.z_pres_masked_prior * (1 - pres_mask))
        else:
            kl_z_pres = calc_kl_z_pres_bernoulli(z_pres_logits,
                                                 self.prior_z_pres_prob)

        kl_z_pres = kl_z_pres.view(-1, num_cell_h * num_cell_w, z_pres_dim)

        ########################################### Compute log importance ############################################
        log_imp = x.new_zeros(bs, 1)
        if not self.training and self.args.phase_nll:
            z_pres_orgin_binary = (z_pres_orgin > 0.5).float()
            # (bs * num_cell_h * num_cell_w, dim)
            log_imp_what = (
                self.p_z_what.log_prob(z_what) -
                q_z_what.log_prob(z_what)) * z_pres_orgin_binary.view(-1, 1)
            log_imp_what = log_imp_what.view(-1, num_cell_h * num_cell_w,
                                             z_what_dim)
            # (bs, dim, num_cell_h, num_cell_w)
            log_imp_depth = (self.p_z_depth.log_prob(z_depth) -
                             q_z_depth.log_prob(z_depth)) * z_pres_orgin_binary
            # (bs, dim, num_cell_h, num_cell_w)
            log_imp_where = (
                self.p_z_where.log_prob(z_where_origin) -
                q_z_where.log_prob(z_where_origin)) * z_pres_orgin_binary
            if phase_rejection and t > 0:
                p_z_pres = self.prior_z_pres_prob * pres_mask + self.z_pres_masked_prior * (
                    1 - pres_mask)
            else:
                p_z_pres = self.prior_z_pres_prob

            z_pres_binary = (z_pres > 0.5).float()

            log_pres_prior = z_pres_binary * torch.log(p_z_pres + eps) + \
                             (1 - z_pres_binary) * torch.log(1 - p_z_pres + eps)

            log_pres_pos = z_pres_binary * torch.log(torch.sigmoid(z_pres_logits) + eps) + \
                           (1 - z_pres_binary) * torch.log(1 - torch.sigmoid(z_pres_logits) + eps)

            log_imp_pres = log_pres_prior - log_pres_pos

            log_imp = log_imp_what.flatten(start_dim=1).sum(dim=1) + log_imp_depth.flatten(start_dim=1).sum(1) + \
                      log_imp_where.flatten(start_dim=1).sum(1) + log_imp_pres.flatten(start_dim=1).sum(1)

        ######################################## End of Compute log importance #########################################

        # (bs, num_cell_h * num_cell_w)
        ids = torch.arange(num_cell_h * num_cell_w).view(1, -1).expand(bs, -1).to(x.device).float() + \
              ids_prop.max(dim=1, keepdim=True)[0] + 1

        if self.args.log_phase:
            self.log = {
                'z_what': z_what,
                'z_where': z_where,
                'z_pres': z_pres,
                'z_pres_logits': z_pres_logits,
                'z_what_std': q_z_what.stddev,
                'z_what_mean': q_z_what.mean,
                'z_where_std': q_z_where.stddev,
                'z_where_mean': q_z_where.mean,
                'x_att': x_att,
                'y_att': y_att,
                'prior_z_pres_prob': self.prior_z_pres_prob.unsqueeze(0),
                'o_att': o_att,
                'alpha_att_hat': alpha_att_hat,
                'alpha_att': alpha_att,
                'y_each_cell': y_each_cell,
                'z_depth': z_depth,
                'z_depth_std': q_z_depth.stddev,
                'z_depth_mean': q_z_depth.mean,
                # 'importance_map_full_res_norm': importance_map_full_res_norm,
                'z_pres_y': z_pres_y,
                'ids': ids
            }
        else:
            self.log = {}

        return y_each_cell.view(bs, num_cell_h * num_cell_w, 3, img_h, img_w), \
               alpha_map.view(bs, num_cell_h * num_cell_w, 1, img_h, img_w), \
               importance_map_full_res.view(bs, num_cell_h * num_cell_w, 1, img_h, img_w), \
               z_what.view(bs, num_cell_h * num_cell_w, -1), z_where.view(bs, num_cell_h * num_cell_w, -1), \
               torch.zeros_like(z_where.view(bs, num_cell_h * num_cell_w, -1)), \
               z_depth.view(bs, num_cell_h * num_cell_w, -1), z_pres.view(bs, num_cell_h * num_cell_w, -1), ids, \
               kl_z_what.flatten(start_dim=1).sum(dim=1), \
               kl_z_where.flatten(start_dim=1).sum(dim=1), \
               kl_z_pres.flatten(start_dim=1).sum(dim=1), \
               kl_z_depth.flatten(start_dim=1).sum(dim=1), \
               log_imp, self.log
Example #23
0
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        context = batch['context']

        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist, p_z, task_z_with_grad = self.agent(
            obs,
            context,
            return_latent_posterior_and_task_z=True,
        )
        task_z_detached = task_z_with_grad.detach()
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        log_pi = log_pi.unsqueeze(1)
        next_dist = self.agent(next_obs, context)

        if self._debug_ignore_context:
            task_z_with_grad = task_z_with_grad * 0

        # flattens out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)
        unscaled_rewards_flat = rewards.view(t * b, 1)
        rewards_flat = unscaled_rewards_flat * self.reward_scale
        terms_flat = terminals.view(t * b, 1)

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        """
        QF Loss
        """
        if self.backprop_q_loss_into_encoder:
            q1_pred = self.qf1(obs, actions, task_z_with_grad)
            q2_pred = self.qf2(obs, actions, task_z_with_grad)
        else:
            q1_pred = self.qf1(obs, actions, task_z_detached)
            q2_pred = self.qf2(obs, actions, task_z_detached)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        new_log_pi = new_log_pi.unsqueeze(1)
        with torch.no_grad():
            target_q_values = torch.min(
                self.target_qf1(next_obs, new_next_actions, task_z_detached),
                self.target_qf2(next_obs, new_next_actions, task_z_detached),
            ) - alpha * new_log_pi

        q_target = rewards_flat + (
            1. - terms_flat) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Context Encoder Loss
        """
        if self._debug_use_ground_truth_context:
            kl_div = kl_loss = ptu.zeros(1)
        else:
            kl_div = kl_divergence(p_z,
                                   self.agent.latent_prior).mean(dim=0).sum()
            kl_loss = self.kl_lambda * kl_div

        if self.train_context_decoder:
            # TODO: change to use a distribution
            reward_pred = self.context_decoder(obs, actions, task_z_with_grad)
            reward_prediction_loss = ((reward_pred -
                                       unscaled_rewards_flat)**2).mean()
            context_loss = kl_loss + reward_prediction_loss
        else:
            context_loss = kl_loss
            reward_prediction_loss = ptu.zeros(1)
        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions, task_z_detached)
        qf2_new_actions = self.qf2(obs, new_obs_actions, task_z_detached)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.vf_K > 1:
            vs = []
            for i in range(self.vf_K):
                u = dist.sample()
                q1 = self.qf1(obs, u, task_z_detached)
                q2 = self.qf2(obs, u, task_z_detached)
                v = torch.min(q1, q2)
                # v = q1
                vs.append(v)
            v_pi = torch.cat(vs, 1).mean(dim=1)
        else:
            # v_pi = self.qf1(obs, new_obs_actions)
            v1_pi = self.qf1(obs, new_obs_actions, task_z_detached)
            v2_pi = self.qf2(obs, new_obs_actions, task_z_detached)
            v_pi = torch.min(v1_pi, v2_pi)

        u = actions
        if self.awr_min_q:
            q_adv = torch.min(q1_pred, q2_pred)
        else:
            q_adv = q1_pred

        policy_logpp = dist.log_prob(u)

        if self.use_automatic_beta_tuning:
            buffer_dist = self.buffer_policy(obs)
            beta = self.log_beta.exp()
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
            beta_loss = -1 * (beta *
                              (kldiv - self.beta_epsilon).detach()).mean()

            self.beta_optimizer.zero_grad()
            beta_loss.backward()
            self.beta_optimizer.step()
        else:
            beta = self.beta_schedule.get_value(self._n_train_steps_total)
            beta_loss = ptu.zeros(1)

        score = q_adv - v_pi
        if self.mask_positive_advantage:
            score = torch.sign(score)

        if self.clip_score is not None:
            score = torch.clamp(score, max=self.clip_score)

        weights = batch.get('weights', None)
        if self.weight_loss and weights is None:
            if self.normalize_over_batch == True:
                weights = F.softmax(score / beta, dim=0)
            elif self.normalize_over_batch == "whiten":
                adv_mean = torch.mean(score)
                adv_std = torch.std(score) + 1e-5
                normalized_score = (score - adv_mean) / adv_std
                weights = torch.exp(normalized_score / beta)
            elif self.normalize_over_batch == "exp":
                weights = torch.exp(score / beta)
            elif self.normalize_over_batch == "step_fn":
                weights = (score > 0).float()
            elif self.normalize_over_batch == False:
                weights = score
            elif self.normalize_over_batch == 'uniform':
                weights = F.softmax(ptu.ones_like(score) / beta, dim=0)
            else:
                raise ValueError(self.normalize_over_batch)
        weights = weights[:, 0]

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.train_reparam_weight * (
                -q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            if self.train_encoder_decoder:
                self.context_optimizer.zero_grad()
            if self.train_agent:
                self.qf1_optimizer.zero_grad()
                self.qf2_optimizer.zero_grad()
            context_loss.backward(retain_graph=True)
            # retain graph because the encoder is trained by both QF losses
            qf1_loss.backward(retain_graph=True)
            qf2_loss.backward()
            if self.train_agent:
                self.qf1_optimizer.step()
                self.qf2_optimizer.step()
            if self.train_encoder_decoder:
                self.context_optimizer.step()

        if self.train_agent:
            if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()
        self._num_gradient_steps += 1
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics['task_embedding/kl_divergence'] = (
                ptu.get_numpy(kl_div))
            self.eval_statistics['task_embedding/kl_loss'] = (
                ptu.get_numpy(kl_loss))
            self.eval_statistics['task_embedding/reward_prediction_loss'] = (
                ptu.get_numpy(reward_prediction_loss))
            self.eval_statistics['task_embedding/context_loss'] = (
                ptu.get_numpy(context_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'terminals',
                    ptu.get_numpy(terminals),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Weights',
                    ptu.get_numpy(weights),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Score',
                    ptu.get_numpy(score),
                ))
            self.eval_statistics['reparam_weight'] = self.train_reparam_weight
            self.eval_statistics['num_gradient_steps'] = (
                self._num_gradient_steps)

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":
                    ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss":
                    ptu.get_numpy(beta_loss.mean()),
                })

        self._n_train_steps_total += 1
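The soft target update above relies on ptu.soft_update_from_to; below is a plausible sketch of that helper under the usual Polyak-averaging semantics (the exact rlkit implementation may differ).

import torch

def soft_update_from_to(source, target, tau):
    # Sketch: target <- tau * source + (1 - tau) * target, parameter by parameter.
    with torch.no_grad():
        for src_param, tgt_param in zip(source.parameters(), target.parameters()):
            tgt_param.data.mul_(1.0 - tau).add_(src_param.data, alpha=tau)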
Example #24
0
 def kl_divergence(self):
     return kl_divergence(self.W, self.W_prior)
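A minimal usage sketch for the one-line method above, assuming self.W and self.W_prior are torch distributions over a weight matrix; the BayesianWeight module and its parameters are illustrative, not taken from the source.

import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence

class BayesianWeight(nn.Module):
    # Hypothetical module with a variational posterior and a standard-normal prior
    # over a weight matrix.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(out_features, in_features))
        self.log_scale = nn.Parameter(torch.zeros(out_features, in_features))

    @property
    def W(self):
        return Normal(self.loc, self.log_scale.exp())

    @property
    def W_prior(self):
        return Normal(torch.zeros_like(self.loc), torch.ones_like(self.loc))

    def kl_divergence(self):
        # Same pattern as the example above: analytic KL between posterior and prior.
        return kl_divergence(self.W, self.W_prior)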
Example #25
0
def draw_message_distributions_tracker1_2d(mu, sigma, save_dir, azim):
    # sigma *= torch.where(sigma<0.01, torch.ones(sigma.shape).cuda(), 10*torch.ones(sigma.shape).cuda())
    mu = mu.view(-1, 2)
    sigma = sigma.view(-1, 2)

    s_mu = torch.Tensor([0.0, 0.0])
    s_sigma = torch.Tensor([1.0, 1.0])

    x = y = np.arange(100)
    t = np.meshgrid(x, y)

    d = D.Normal(s_mu, s_sigma)
    d1 = D.Normal(mu[0], sigma[0])
    d21 = D.Normal(mu[2], sigma[2])
    d22 = D.Normal(mu[3], sigma[3])
    d31 = D.Normal(mu[4], sigma[4])
    d32 = D.Normal(mu[5], sigma[5])

    print('Entropy')
    print(d1.entropy().detach().cpu().numpy())
    print(d21.entropy().detach().cpu().numpy())
    print(d22.entropy().detach().cpu().numpy())
    print(d31.entropy().detach().cpu().numpy())
    print(d32.entropy().detach().cpu().numpy())

    print('KL Divergence')

    for tt_i in range(3):
        d1 = D.Normal(mu[tt_i * 2 + 0], sigma[tt_i * 2 + 0])
        d2 = D.Normal(mu[tt_i * 2 + 1], sigma[tt_i * 2 + 1])
        print(tt_i,
              D.kl_divergence(d1, d2).sum().detach().cpu().numpy(),
              D.kl_divergence(d1, d).sum().detach().cpu().numpy(),
              D.kl_divergence(d2, d).sum().detach().cpu().numpy(),
              sigma[tt_i * 2 + 0].mean().detach().cpu().numpy(),
              sigma[tt_i * 2 + 1].mean().detach().cpu().numpy())

    # Numpy array of mu and sigma
    s_mu_ = s_mu.detach().cpu().numpy()
    mu_0 = mu[0].detach().cpu().numpy()
    mu_2 = mu[2].detach().cpu().numpy()
    mu_3 = mu[3].detach().cpu().numpy()
    mu_4 = mu[4].detach().cpu().numpy()
    mu_5 = mu[5].detach().cpu().numpy()

    s_sigma_ = s_sigma.detach().cpu().numpy()
    sigma_0 = sigma[0].detach().cpu().numpy()
    sigma_2 = sigma[2].detach().cpu().numpy()
    sigma_3 = sigma[3].detach().cpu().numpy()
    sigma_4 = sigma[4].detach().cpu().numpy()
    sigma_5 = sigma[5].detach().cpu().numpy()

    # Print
    print('mu and sigma')
    print(mu_0, sigma_0)
    print(mu_2, sigma_2)
    print(mu_3, sigma_3)
    print(mu_4, sigma_4)
    print(mu_5, sigma_5)

    # Create grid
    x = np.linspace(-5, 5, 5000)
    y = np.linspace(-5, 5, 5000)
    X, Y = np.meshgrid(x, y)
    pos = np.empty(X.shape + (2, ))
    pos[:, :, 0] = X
    pos[:, :, 1] = Y

    rv = multivariate_normal(s_mu_, [[s_sigma_[0], 0], [0, s_sigma_[1]]])

    # Agent 1
    # Create multivariate normal
    rv1 = multivariate_normal(mu_0, [[sigma_0[0], 0], [0, sigma_0[1]]])

    # Make a 3D plot
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    ax.plot_surface(X,
                    Y,
                    rv.pdf(pos) + rv1.pdf(pos),
                    cmap='viridis',
                    linewidth=0)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('message')
    plt.tight_layout()
    plt.savefig(save_dir + 'agent1.png')
    ax.view_init(elev=0., azim=azim)
    plt.savefig(save_dir + ("agent1_0_%i.png" % int(azim)))
    ax.view_init(elev=90., azim=0.)
    plt.savefig(save_dir + "agent1_90_0.png")
    # plt.show()
    plt.close()

    # Agent 2
    # Create multivariate normal
    rv21 = multivariate_normal(mu_2, [[sigma_2[0], 0], [0, sigma_2[1]]])
    rv22 = multivariate_normal(mu_3, [[sigma_3[0], 0], [0, sigma_3[1]]])

    # Make a 3D plot
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    ax.plot_surface(X,
                    Y,
                    rv.pdf(pos) + rv21.pdf(pos) + rv22.pdf(pos),
                    cmap='viridis',
                    linewidth=0)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('message')
    plt.tight_layout()
    plt.savefig(save_dir + 'agent2.png')
    ax.view_init(elev=0., azim=azim)
    plt.savefig(save_dir + ("agent2_0_%i.png" % int(azim)))
    ax.view_init(elev=90., azim=0.)
    plt.savefig(save_dir + "agent2_90_0.png")
    # plt.show()
    plt.close()

    # Agent 3
    rv31 = multivariate_normal(mu_4, [[sigma_4[0], 0], [0, sigma_4[1]]])
    rv32 = multivariate_normal(mu_5, [[sigma_5[0], 0], [0, sigma_5[1]]])

    # Make a 3D plot
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    ax.plot_surface(X,
                    Y,
                    rv.pdf(pos) + rv31.pdf(pos) + rv32.pdf(pos),
                    cmap='viridis',
                    linewidth=0)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('message')
    plt.tight_layout()
    plt.savefig(save_dir + 'agent3.png')
    ax.view_init(elev=0., azim=azim)
    plt.savefig(save_dir + ("agent3_0_%i.png" % int(azim)))
    ax.view_init(elev=90., azim=0.)
    plt.savefig(save_dir + "agent3_90_0.png")
    # plt.show()
    plt.close()

    # Overall
    # Make a 3D plot
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    ax.plot_surface(X,
                    Y,
                    rv.pdf(pos) + rv1.pdf(pos) + rv21.pdf(pos) +
                    rv22.pdf(pos) + rv31.pdf(pos) + rv32.pdf(pos),
                    cmap='viridis',
                    linewidth=0)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('message')
    plt.tight_layout()
    plt.savefig(save_dir + 'overall.png')
    ax.view_init(elev=0., azim=azim)
    plt.savefig(save_dir + ("overall_0_%i.png" % int(azim)))
    ax.view_init(elev=90., azim=0.)
    plt.savefig(save_dir + "overall_90_0.png")
    # plt.show()
    plt.close()
Example #26
0
def div_from_prior(posterior):
    prior = Normal(torch.zeros_like(posterior.loc),
                   torch.ones_like(posterior.scale))
    return kl_divergence(posterior, prior).sum(dim=-1)
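A usage sketch for div_from_prior (shapes and names are illustrative): the posterior is a diagonal Gaussian, e.g. from an encoder, and the function returns one KL value per batch element.

import torch
from torch.distributions import Normal

loc = torch.randn(8, 16)         # assumed encoder output: means
scale = torch.rand(8, 16) + 0.1  # assumed encoder output: positive scales
posterior = Normal(loc, scale)

kl_per_example = div_from_prior(posterior)  # shape: (8,)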
Example #27
0
def embedding_loss(output, target, epoch_iter, n_iters_start=0, lambd=0.3):
    # assert len(output) == 3 or len(output) == 5

    # TODO: implement KL-div with the PyTorch lib
    #  Pass the distribution and the sample. The distribution is computed beforehand. Use rsample, don't get confused.
    #  G**2 is the number of slots. It would be like the number of timesteps, I suppose. Should keep only the batch dimension.

    rec = output["rec"]
    pred = output["pred"]
    pred_rev = output["pred_rev"]

    g = output["obs_rec_pred"]
    posteriors = output["gauss_post"]
    A = output["A"]
    B = output["B"]
    u = output["u"]

    # Get bs, T, n_obj; and take order into account.
    bs = rec.shape[0]
    T = rec.shape[1]
    device = rec.device
    std_rec = .15

    # local_geo_loss = torch.zeros(1)
    if epoch_iter[
            0] < n_iters_start:  #TODO: Add rec loss in G with spectralnorm. fit_error
        # lambd_lg = 1
        lambd_fit_error = 1
        lambd_hrank = 0.05
        lambd_rec = 1
        lambd_pred = 0.2
        prior_pres_prob = 0.05
        prior_g_mask_prob = 0.8
        lambd_u = 0.1
        lambd_I = 1
    else:
        # lambd_lg = 1
        lambd_fit_error = 1
        lambd_hrank = 0.2
        lambd_rec = .7
        lambd_pred = 1
        prior_pres_prob = 0.1
        prior_g_mask_prob = 0.4
        lambd_u = 1
        lambd_I = 1
    '''Rec and pred losses'''
    # Shape latent: [bs, T-n_timesteps+1, 1, 64, 64]

    rec_distr = Normal(rec, std_rec)
    logprob_rec = rec_distr.log_prob(target[:, -rec.shape[1]:])\
        .flatten(start_dim=1).sum(1)

    pred_distr = Normal(pred, std_rec)
    logprob_pred = pred_distr.log_prob(target[:, -pred.shape[1]:])\
        .flatten(start_dim=1).sum(1) # TODO: Check if correct.

    pred_rev_distr = Normal(pred_rev, std_rec)
    logprob_pred_rev = pred_rev_distr.log_prob(target[:, :pred_rev.shape[1]]) \
    .flatten(start_dim=1).sum(1) # TODO: Check if correct.

    kl_bern_loss = 0
    '''G composition bernoulli KL div loss'''
    # if "g_bern_logit" == any(output.keys()):
    if output["g_bern_logit"] is not None:
        g_mask_logit = output["g_bern_logit"]  # TODO: Check shape
        kl_g_mask_loss = kl_divergence_bern_bern(
            g_mask_logit,
            prior_pres_prob=torch.FloatTensor(
                [prior_g_mask_prob]).to(g_mask_logit.device)).sum()
        kl_bern_loss = kl_bern_loss + kl_g_mask_loss
    '''Input bernoulli KL div loss'''
    # if "u_bern_logit" == any(output.keys()):
    if output["u_bern_logit"] is not None:
        u_logit = output["u_bern_logit"]
        kl_u_loss = kl_divergence_bern_bern(u_logit,
                                            prior_pres_prob=torch.FloatTensor([
                                                prior_pres_prob
                                            ]).to(u_logit.device)).sum()
        kl_bern_loss = kl_bern_loss + kl_u_loss
        l1_u = 0.
    else:
        '''Input sparsity loss
        u: [bs, T, feat_dim]
        '''
        up_bound = 0.3
        N_elem = u.shape[0]
        l1_u_sparse = F.relu(l1_loss(u, torch.zeros_like(u)).mean() - up_bound)
        # l1_u_diff_sparse = F.relu(l1_loss(u[:, :-1] - u[:, 1:], torch.zeros_like(u[:, 1:])).mean() - up_bound) * u.shape[0] * u.shape[1]
        l1_u = lambd_u * l1_u_sparse * N_elem
        # l1_u = lambd_u * l1_loss(u, torch.zeros_like(u)).mean()
    '''Gaussian vectors KL div loss'''
    posteriors = posteriors[:-1]  # If only rec
    prior = Normal(torch.zeros(1).to(device), torch.ones(1).to(device))
    kl_loss = torch.stack([
        kl_divergence(post, prior).flatten(start_dim=1).sum(1)
        for post in posteriors
    ]).sum(0)

    nelbo = (
        kl_loss + kl_bern_loss - logprob_rec * lambd_rec -
        logprob_pred * lambd_pred  #TODO: There's a mismatch here?
        - logprob_pred_rev * lambd_pred).mean()
    '''Cycle consistency'''
    if output["Ainv"] is not None:
        A_inv = output["Ainv"]
        cyc_conc_loss = cycle_consistency_loss(A, A_inv)
    '''LSQ fit loss'''
    # Note: only has correspondence if last frames of pred correspond to last frames of rec
    g_rec, g_pred = g[:, :T], g[:, T:]
    T_min = min(T, g_pred.shape[1])
    fit_error = lambd_fit_error * mse_loss(g_rec[:, -T_min:],
                                           g_pred[:, -T_min:]).sum()
    '''Low rank G'''
    # Option 1:
    # T = g_for_koop.shape[1]
    # n_timesteps = g_for_koop.shape[-1]
    # g_for_koop = g_for_koop.permute(0, 2, 1, 3).reshape(-1, T-1, n_timesteps)
    # h_rank_loss = 0
    # reg_mask = torch.zeros_like(g_for_koop[..., :n_timesteps, :])
    # ids = torch.arange(0, reg_mask.shape[-1])
    # reg_mask[..., ids, ids] = 0.01
    # for t in range(T-1-n_timesteps):
    #     logdet_H = torch.slogdet(g_for_koop[..., t:t+n_timesteps, :] + reg_mask)
    #     h_rank_loss = h_rank_loss + .01*(logdet_H[1]).mean()
    # Option 2:
    h_rank_loss = (
        l1_loss(A, torch.zeros_like(A)).sum(-1).sum(-1).sum()
        # + l1_loss(B, torch.zeros_like(B)).sum(-1).sum(-1).mean()
    )
    # h_rank_loss = (l1_loss(A, torch.zeros_like(A)).sum(-1).sum(-1).sum()
    #                + l1_loss(B, torch.zeros_like(B)).sum(-1).sum(-1).sum())
    h_rank_loss = lambd_hrank * h_rank_loss
    '''Input KL div.'''
    # KL loss for gumbel-softmax. We should use increasing temperature for softmax. Check SPACE pres variable.
    '''Local geometry loss'''
    # g = g[:, :rec.shape[1]]
    # local_geo_loss = lambd_lg * local_geo(g, target[:, -g.shape[1]:])
    '''Total Loss'''
    loss = (
        nelbo + h_rank_loss + l1_u + fit_error + cyc_cons_loss
        # + local_geo_loss
    )

    rec_mse = mse_loss(rec, target[:, -rec.shape[1]:]).mean(1).reshape(
        bs, -1).flatten(start_dim=1).sum(1).mean()
    pred_mse = mse_loss(pred, target[:, -pred.shape[1]:]).mean(1).reshape(
        bs, -1).flatten(start_dim=1).sum(1).mean()

    return loss, {
        'Rec mse': rec_mse,
        'Pred mse': pred_mse,
        'KL Loss': kl_loss.mean(),
        # 'Rec llik':logprob_rec.mean(),
        # 'Pred llik':logprob_pred.mean(),
        'Cycle consistency Loss': cyc_cons_loss,
        'H rank Loss': h_rank_loss,
        'Fit error': fit_error,
        # 'G Pred Loss':fit_error,
        # 'Local geo Loss':local_geo_loss,
        # 'l1_u':l1_u,
    }
Example #28
0
def _kl_mv_diag_normal_mv_diag_normal(p, q):
    return kl_divergence(p.distribution, q.distribution)
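A helper like this is typically registered with torch.distributions.kl.register_kl so that kl_divergence dispatches to it; the sketch below assumes a hypothetical MVDiagNormal wrapper that exposes the underlying torch distribution via a .distribution attribute (the wrapper is illustrative, not from the source).

import torch
from torch.distributions import Distribution, Independent, Normal, kl_divergence
from torch.distributions.kl import register_kl

class MVDiagNormal(Distribution):
    # Hypothetical wrapper: a diagonal multivariate Normal that keeps the
    # underlying torch distribution in .distribution.
    arg_constraints = {}

    def __init__(self, loc, scale):
        self.distribution = Independent(Normal(loc, scale), 1)
        super().__init__(self.distribution.batch_shape,
                         self.distribution.event_shape,
                         validate_args=False)

@register_kl(MVDiagNormal, MVDiagNormal)
def _kl_mv_diag_normal_mv_diag_normal(p, q):
    # Delegate to the analytic KL of the wrapped distributions, as in the example above.
    return kl_divergence(p.distribution, q.distribution)

p = MVDiagNormal(torch.zeros(3), torch.ones(3))
q = MVDiagNormal(torch.ones(3), torch.ones(3))
print(kl_divergence(p, q))  # dispatches to the registered function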
Example #29
0
def kl_relaxed_one_hot_categorical(p, q):
    p = Categorical(probs=p.probs)
    q = Categorical(probs=q.probs)
    return kl_divergence(p, q)
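A usage sketch (illustrative only): both arguments are RelaxedOneHotCategorical distributions, and the KL is approximated through their underlying category probabilities.

import torch
from torch.distributions import RelaxedOneHotCategorical

p = RelaxedOneHotCategorical(torch.tensor(0.5), probs=torch.tensor([0.1, 0.3, 0.6]))
q = RelaxedOneHotCategorical(torch.tensor(0.5), probs=torch.tensor([0.4, 0.4, 0.2]))
print(kl_relaxed_one_hot_categorical(p, q))
Example #30
0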
    def forward(self, p_params, q_params=None, forced_latent=None,
                use_mode=False, force_constant_output=False,
                analytical_kl=False):

        assert (forced_latent is None) or (not use_mode)

        if self.transform_p_params:
            p_params = self.conv_in_p(p_params)
        else:
            assert p_params.size(1) == 2 * self.c_vars

        if q_params is not None:
            q_params = self.conv_in_q(q_params)
            mu_lv = q_params
        else:
            mu_lv = p_params

        if forced_latent is None:
            if use_mode:
                z = torch.chunk(mu_lv, 2, dim=1)[0]
            else:
                z = normal_rsample(mu_lv)
        else:
            z = forced_latent

        # Copy one sample (and distrib parameters) over the whole batch
        if force_constant_output:
            z = z[0:1].expand_as(z).clone()
            p_params = p_params[0:1].expand_as(p_params).clone()

        # Output of stochastic layer
        out = self.conv_out(z)

        kl_elementwise = kl_samplewise = kl_spatial_analytical = None
        logprob_q = None

        # Compute log p(z)
        p_mu, p_lv = p_params.chunk(2, dim=1)
        p = Normal(p_mu, (p_lv / 2).exp())
        logprob_p = p.log_prob(z).sum((1, 2, 3))

        if q_params is not None:

            # Compute log q(z)
            q_mu, q_lv = q_params.chunk(2, dim=1)
            q = Normal(q_mu, (q_lv / 2).exp())
            logprob_q = q.log_prob(z).sum((1, 2, 3))

            # Compute KL (analytical or MC estimate)
            if analytical_kl:
                kl_elementwise = kl_divergence(q, p)
            else:
                kl_elementwise = kl_normal_mc(z, p_params, q_params)
            kl_samplewise = kl_elementwise.sum((1, 2, 3))

            # Compute spatial KL analytically (but conditioned on samples from
            # previous layers)
            kl_spatial_analytical = -0.5 * (1 + q_lv - q_mu.pow(2) - q_lv.exp())
            kl_spatial_analytical = kl_spatial_analytical.sum(1)

        data = {
            'z': z,
            'p_params': p_params,
            'q_params': q_params,
            'logprob_p': logprob_p,
            'logprob_q': logprob_q,
            'kl_elementwise': kl_elementwise,
            'kl_samplewise': kl_samplewise,
            'kl_spatial': kl_spatial_analytical,
        }
        return out, data
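kl_normal_mc is defined elsewhere in the source; below is a plausible sketch of such a single-sample Monte Carlo KL estimator, assuming p_params and q_params concatenate mean and log-variance along dim=1 as in the code above (the signature and behavior are assumptions).

import torch
from torch.distributions import Normal

def kl_normal_mc(z, p_params, q_params):
    # Hypothetical implementation: elementwise log q(z) - log p(z), a one-sample
    # Monte Carlo estimate of KL(q || p) evaluated at the sampled latent z.
    p_mu, p_lv = p_params.chunk(2, dim=1)
    q_mu, q_lv = q_params.chunk(2, dim=1)
    p = Normal(p_mu, (p_lv / 2).exp())
    q = Normal(q_mu, (q_lv / 2).exp())
    return q.log_prob(z) - p.log_prob(z)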
Example #31
0
    def forward(self, representation, viewpoint_query, image_query):

        batch_size, _, *spatial_dims = image_query.shape
        spatial_dims_scaled = tuple(np.array(spatial_dims) // self.scale)
        kl = 0

        # Increase dimensions
        viewpoint_query = viewpoint_query.view(batch_size, -1, *[
            1,
        ] * len(spatial_dims)).repeat(1, 1, *spatial_dims_scaled)
        if representation.shape[2:] != spatial_dims_scaled:
            representation = representation.view(batch_size, -1, *[
                1,
            ] * len(spatial_dims)).repeat(1, 1, *spatial_dims_scaled)

        # Reset hidden state
        hidden_g = image_query.new_zeros(
            (batch_size, self.h_channels, *spatial_dims_scaled))
        hidden_i = image_query.new_zeros(
            (batch_size, self.h_channels, *spatial_dims_scaled))

        # Reset cell state
        cell_g = image_query.new_zeros(
            (batch_size, self.h_channels, *spatial_dims_scaled))
        cell_i = image_query.new_zeros(
            (batch_size, self.h_channels, *spatial_dims_scaled))

        # better name for u?
        u = image_query.new_zeros((batch_size, self.h_channels, *spatial_dims))

        for i in range(self.core_repeat):

            if self.core_shared:
                current_generator_core = self.generator_core
                current_inference_core = self.inference_core
            else:
                current_generator_core = self.generator_core[i]
                current_inference_core = self.inference_core[i]

            # Prior
            o = self.prior_net(hidden_g)
            prior_mu, prior_std_pseudo = torch.split(o, self.z_channels, dim=1)
            prior = Normal(prior_mu, F.softplus(prior_std_pseudo))

            # Inference state update
            cell_i, hidden_i = current_inference_core(image_query,
                                                      viewpoint_query,
                                                      representation, cell_i,
                                                      hidden_i, hidden_g, u)

            # Posterior
            o = self.posterior_net(hidden_i)
            posterior_mu, posterior_std_pseudo = torch.split(o,
                                                             self.z_channels,
                                                             dim=1)
            posterior = Normal(posterior_mu, F.softplus(posterior_std_pseudo))

            # Posterior sample
            if self.training:
                z = posterior.rsample()
            else:
                z = prior.loc

            # Generator update
            cell_g, hidden_g, u = current_generator_core(
                viewpoint_query, representation, z, cell_g, hidden_g, u)

            # Calculate KL-divergence
            kl += kl_divergence(posterior, prior)

        image_prediction = self.output_activation(self.observation_net(u))

        return image_prediction, kl
Example #32
0
    def train(self, env_fn, policy, n_itr, normalize=None, logger=None):

        if normalize is not None:
            policy.train()
        else:
            policy.train(0)

        env = Vectorize([env_fn])  # this will be useful for parallelism later

        if normalize is not None:
            env = normalize(env)

            mean, std = env.ob_rms.mean, np.sqrt(env.ob_rms.var + 1E-8)
            policy.obs_mean = torch.Tensor(mean)
            policy.obs_std = torch.Tensor(std)
            policy.train(0)

        env = Vectorize([env_fn])

        old_policy = deepcopy(policy)

        optimizer = optim.Adam(policy.parameters(), lr=self.lr, eps=self.eps)

        start_time = time.time()

        for itr in range(n_itr):
            print("********** Iteration {} ************".format(itr))

            sample_t = time.time()
            if self.n_proc > 1:
                print("doing multi samp")
                batch = self.sample_parallel(env_fn, policy, self.num_steps,
                                             300)
            else:
                batch = self._sample(env_fn, policy, self.num_steps,
                                     300)  #TODO: fix this

            print("sample time: {:.2f} s".format(time.time() - sample_t))

            observations, actions, returns, values = map(
                torch.Tensor, batch.get())

            advantages = returns - values
            advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                             self.eps)

            minibatch_size = self.minibatch_size or advantages.numel()

            print("timesteps in batch: %i" % advantages.numel())

            old_policy.load_state_dict(
                policy.state_dict())  # WAY faster than deepcopy

            for _ in range(self.epochs):
                losses = []
                sampler = BatchSampler(SubsetRandomSampler(
                    range(advantages.numel())),
                                       minibatch_size,
                                       drop_last=True)

                for indices in sampler:
                    indices = torch.LongTensor(indices)

                    obs_batch = observations[indices]
                    action_batch = actions[indices]

                    return_batch = returns[indices]
                    advantage_batch = advantages[indices]

                    values, pdf = policy.evaluate(obs_batch)

                    # TODO, move this outside loop?
                    with torch.no_grad():
                        _, old_pdf = old_policy.evaluate(obs_batch)
                        old_log_probs = old_pdf.log_prob(action_batch).sum(
                            -1, keepdim=True)

                    log_probs = pdf.log_prob(action_batch).sum(-1,
                                                               keepdim=True)

                    ratio = (log_probs - old_log_probs).exp()

                    cpi_loss = ratio * advantage_batch
                    clip_loss = ratio.clamp(1.0 - self.clip,
                                            1.0 + self.clip) * advantage_batch
                    actor_loss = -torch.min(cpi_loss, clip_loss).mean()

                    critic_loss = 0.5 * (return_batch - values).pow(2).mean()

                    entropy_penalty = -self.entropy_coeff * pdf.entropy().mean(
                    )

                    # TODO: add ability to optimize critic and actor separately, with different learning rates

                    optimizer.zero_grad()
                    (actor_loss + critic_loss + entropy_penalty).backward()

                    # Clip the gradient norm to prevent "unlucky" minibatches from
                    # causing pathological updates
                    torch.nn.utils.clip_grad_norm_(policy.parameters(),
                                                   self.grad_clip)
                    optimizer.step()

                    losses.append([
                        actor_loss.item(),
                        pdf.entropy().mean().item(),
                        critic_loss.item(),
                        ratio.mean().item()
                    ])

                # TODO: add verbosity arguments to suppress this
                print(' '.join(["%g" % x for x in np.mean(losses, axis=0)]))

                # Early stopping
                if kl_divergence(pdf, old_pdf).mean() > 0.02:
                    print("Max kl reached, stopping optimization early.")
                    break

            if logger is not None:
                test = self.sample(env,
                                   policy,
                                   800 // self.n_proc,
                                   400,
                                   deterministic=True)
                _, pdf = policy.evaluate(observations)
                _, old_pdf = old_policy.evaluate(observations)

                entropy = pdf.entropy().mean().item()
                kl = kl_divergence(pdf, old_pdf).mean().item()

                logger.record("Return (test)", np.mean(test.ep_returns))
                logger.record("Return (batch)", np.mean(batch.ep_returns))
                logger.record("Mean Eplen", np.mean(batch.ep_lens))

                logger.record("Mean KL Div", kl)
                logger.record("Mean Entropy", entropy)
                logger.dump()

            # TODO: add option for how often to save model
            # if itr % 10 == 0:
            if np.mean(test.ep_returns) > self.max_return:
                self.max_return = np.mean(test.ep_returns)
                self.save(policy, env)
                self.save_optim(optimizer)

            print("Total time: {:.2f} s".format(time.time() - start_time))
Example #33
0
def test_non_empty_bit_vector(batch_shape=tuple(), K=3):
    assert K <= 10, "I test against explicit enumeration, K>10 might be too slow for that"
    
    # Uniform 
    F = NonEmptyBitVector(torch.zeros(batch_shape + (K,)))
    
    # Shapes
    assert F.batch_shape == batch_shape, "NonEmptyBitVector has the wrong batch_shape"
    assert F.dim == K, "NonEmptyBitVector has the wrong dim"
    assert F.event_shape == (K,), "NonEmptyBitVector has the wrong event_shape"
    assert F.scores.shape == batch_shape + (K,), "NonEmptyBitVector.score has the wrong shape"
    assert F.arc_weight.shape == batch_shape + (K+1,3,3), "NonEmptyBitVector.arc_weight has the wrong shape"
    assert F.state_value.shape == batch_shape + (K+2,3), "NonEmptyBitVector.state_value has the wrong shape"
    assert F.state_rvalue.shape == batch_shape + (K+2,3), "NonEmptyBitVector.state_rvalue has the wrong shape"
    # shape: [num_faces] + batch_shape + [K]
    support = F.enumerate_support()    
    # test shape of support
    assert support.shape == (2**K-1,) + batch_shape + (K,), "The support has the wrong shape"

    assert F.expand((2,3) + batch_shape).batch_shape == (2,3) + batch_shape, "Bad expand batch_shape"
    assert F.expand((2,3) + batch_shape).event_shape == (K,), "Bad expand event_shape"
    assert F.expand((2,3) + batch_shape).sample().shape == (2,3) + batch_shape + (K,), "Bad expand single sample"
    assert F.expand((2,3) + batch_shape).sample((13,)).shape == (13,2,3) + batch_shape + (K,), "Bad expand multiple samples"

    # Constraints
    assert (support.sum(-1) > 0).all(), "The support has an empty bit vector"
    for _ in range(100):  # testing one sample at a time
        assert F.sample().sum(-1).all(), "I found an empty vector"
    # testing a batch of samples
    assert F.sample((100,)).sum(-1).all(), "I found an empty vector"
    # testing a complex batch of samples
    assert F.sample((2, 100,)).sum(-1).all(), "I found an empty vector"
    
    # Distribution
    # check for uniform probabilities
    assert torch.isclose(F.log_prob(support).exp(), torch.tensor(1./F.support_size)).all(), "Non-uniform"
    # check for uniform marginal probabilities
    assert torch.isclose(F.sample((10000,)).float().mean(0), support.mean(0), atol=1e-1).all(), "Bad MC marginals"    
    assert torch.isclose(F.marginals(), support.mean(0)).all(), "Bad exact marginals"

    # Entropy

    # [num_faces, B]
    log_prob = F.log_prob(support)
    assert torch.isclose(F.entropy(), (-(log_prob.exp() * log_prob).sum(0)), atol=1e-2).all(), "Problem in the entropy DP"

    # Non-Uniform  

    # Entropy  
    P = NonEmptyBitVector(td.Normal(torch.zeros(batch_shape + (K,)), torch.ones(batch_shape + (K,))).sample())
    log_p = P.log_prob(support)
    assert torch.isclose(P.entropy(), (-(log_p.exp() * log_p).sum(0)), atol=1e-2).all(), "Problem in the entropy DP"
    # Cross-Entropy
    Q = NonEmptyBitVector(td.Normal(torch.zeros(batch_shape + (K,)), torch.ones(batch_shape + (K,))).sample())
    log_q = Q.log_prob(support)
    assert torch.isclose(P.cross_entropy(Q), -(log_p.exp() * log_q).sum(0), atol=1e-2).all(), "Problem in the cross-entropy DP"
    # KL
    assert torch.isclose(td.kl_divergence(P, Q), (log_p.exp() * (log_p - log_q)).sum(0), atol=1e-2).all(), "Problem in KL"

    # Constraints
    for _ in range(100):  # testing one sample at a time
        assert P.sample().sum(-1).all(), "I found an empty vector"
        assert Q.sample().sum(-1).all(), "I found an empty vector"
    # testing a batch of samples
    assert P.sample((100,)).sum(-1).all(), "I found an empty vector"
    assert Q.sample((100,)).sum(-1).all(), "I found an empty vector"
    # testing a complex batch of samples
    assert P.sample((2, 100,)).sum(-1).all(), "I found an empty vector"
    assert Q.sample((2, 100,)).sum(-1).all(), "I found an empty vector"
Example #34
0
    def update(self, policy, old_policy, optimizer, observations, actions,
               returns, advantages, env_fn):
        env = env_fn()
        mirror_observation = env.mirror_observation
        mirror_action = env.mirror_action

        minibatch_size = self.minibatch_size or advantages.numel()

        for _ in range(self.epochs):
            losses = []
            sampler = BatchSampler(SubsetRandomSampler(
                range(advantages.numel())),
                                   minibatch_size,
                                   drop_last=True)
            for indices in sampler:
                indices = torch.LongTensor(indices)

                obs_batch = observations[indices]
                # obs_batch = torch.cat(
                #     [obs_batch,
                #      obs_batch @ torch.Tensor(env.obs_symmetry_matrix)]
                # ).detach()

                action_batch = actions[indices]
                # action_batch = torch.cat(
                #     [action_batch,
                #      action_batch @ torch.Tensor(env.action_symmetry_matrix)]
                # ).detach()

                return_batch = returns[indices]
                # return_batch = torch.cat(
                #     [return_batch,
                #      return_batch]
                # ).detach()

                advantage_batch = advantages[indices]
                # advantage_batch = torch.cat(
                #     [advantage_batch,
                #      advantage_batch]
                # ).detach()

                values, pdf = policy.evaluate(obs_batch)

                # TODO, move this outside loop?
                with torch.no_grad():
                    _, old_pdf = old_policy.evaluate(obs_batch)
                    old_log_probs = old_pdf.log_prob(action_batch).sum(
                        -1, keepdim=True)

                log_probs = pdf.log_prob(action_batch).sum(-1, keepdim=True)

                ratio = (log_probs - old_log_probs).exp()

                cpi_loss = ratio * advantage_batch
                clip_loss = ratio.clamp(1.0 - self.clip,
                                        1.0 + self.clip) * advantage_batch
                actor_loss = -torch.min(cpi_loss, clip_loss).mean()

                critic_loss = 0.5 * (return_batch - values).pow(2).mean()

                # Mirror Symmetry Loss
                _, deterministic_actions = policy(obs_batch)
                _, mirror_actions = policy(mirror_observation(obs_batch))
                mirror_actions = mirror_action(mirror_actions)

                mirror_loss = 4 * (deterministic_actions -
                                   mirror_actions).pow(2).mean()

                entropy_penalty = -self.entropy_coeff * pdf.entropy().mean()

                # TODO: add ability to optimize critic and actor separately, with different learning rates

                optimizer.zero_grad()
                (actor_loss + critic_loss + mirror_loss +
                 entropy_penalty).backward()

                # Clip the gradient norm to prevent "unlucky" minibatches from
                # causing pathological updates
                torch.nn.utils.clip_grad_norm_(policy.parameters(),
                                               self.grad_clip)
                optimizer.step()

                losses.append([
                    actor_loss.item(),
                    pdf.entropy().mean().item(),
                    critic_loss.item(),
                    ratio.mean().item(),
                    mirror_loss.item()
                ])

            # TODO: add verbosity arguments to suppress this
            print(' '.join(["%g" % x for x in np.mean(losses, axis=0)]))

            # Early stopping
            if kl_divergence(pdf, old_pdf).mean() > 0.02:
                print("Max kl reached, stopping optimization early.")
                break
Example #35
0
    def forward(self, *inputs: Tensor, **kwargs
                ) -> Tuple[List[MultivariateNormal], Tensor]:
        """Forward propagate the model.

        Parameters
        ----------
        inputs: Tensor.
            output_sequence: Tensor.
            Tensor of output data [batch_size x sequence_length x dim_outputs].

            input_sequence: Tensor.
            Tensor of input data [batch_size x sequence_length x dim_inputs].

        Returns
        -------
        outputs: List[MultivariateNormal].
            List of length sequence_length of predictive distributions of size
            [batch_size x dim_outputs x num_particles].
        loss: Tensor.
            Scalar loss selected by ``self.loss_key``.
        """
        output_sequence, input_sequence = inputs
        num_particles = self.num_particles
        # dim_states = self.dim_states
        batch_size, sequence_length, dim_inputs = input_sequence.shape
        _, _, dim_outputs = output_sequence.shape

        ################################################################################
        # SAMPLE GP #
        ################################################################################
        self.forward_model.resample()
        self.backward_model.resample()

        ################################################################################
        # PERFORM Backward Pass #
        ################################################################################
        if self.training:
            output_distribution = self.backward(output_sequence, input_sequence)

        ################################################################################
        # Initial State #
        ################################################################################
        state = self.recognition(output_sequence[:, :self.recognition.length],
                                 input_sequence[:, :self.recognition.length],
                                 num_particles=num_particles)

        ################################################################################
        # PREDICT Outputs #
        ################################################################################
        outputs = []
        y_pred = self.emissions(state)
        outputs.append(MultivariateNormal(y_pred.loc.detach(),
                                          y_pred.covariance_matrix.detach()))

        ################################################################################
        # INITIALIZE losses #
        ################################################################################

        # entropy = torch.tensor(0.)
        if self.training:
            output_distribution.pop(0)
            # entropy += y_tilde.entropy().mean() / sequence_length

        y = output_sequence[:, 0].expand(num_particles, batch_size, dim_outputs
                                         ).permute(1, 2, 0)
        log_lik = y_pred.log_prob(y).sum(dim=1).mean()  # type: torch.Tensor
        l2 = ((y_pred.loc - y) ** 2).sum(dim=1).mean()  # type: torch.Tensor
        kl_cond = torch.tensor(0.)

        for t in range(sequence_length - 1):
            ############################################################################
            # PREDICT Next State #
            ############################################################################
            u = input_sequence[:, t].expand(num_particles, batch_size, dim_inputs)
            u = u.permute(1, 2, 0)  # Move last component to end.
            state_samples = state.rsample()
            state_input = torch.cat((state_samples, u), dim=1)

            next_f = self.forward_model(state_input)
            next_state = self.transitions(next_f)
            next_state.loc += state_samples

            if self.independent_particles:
                next_state = diagonal_covariance(next_state)
            ############################################################################
            # CONDITION Next State #
            ############################################################################
            if self.training:
                y_tilde = output_distribution.pop(0)
                p_next_state = next_state
                next_state = self._condition(next_state, y_tilde)
                kl_cond += kl_divergence(next_state, p_next_state).mean()
            ############################################################################
            # RESAMPLE State #
            ############################################################################
            state = next_state

            ############################################################################
            # PREDICT Outputs #
            ############################################################################
            y_pred = self.emissions(state)
            outputs.append(y_pred)

            ############################################################################
            # COMPUTE Losses #
            ############################################################################
            y = output_sequence[:, t + 1].expand(
                num_particles, batch_size, dim_outputs).permute(1, 2, 0)
            log_lik += y_pred.log_prob(y).sum(dim=1).mean()
            l2 += ((y_pred.loc - y) ** 2).sum(dim=1).mean()
            # entropy += y_tilde.entropy().mean() / sequence_length

        assert len(outputs) == sequence_length

        # if self.training:
        #     del output_distribution
        ################################################################################
        # Compute model KL divergences Divergences #
        ################################################################################
        factor = 1  # batch_size / self.dataset_size
        kl_uf = self.forward_model.kl_divergence()
        kl_ub = self.backward_model.kl_divergence()

        if self.forward_model.independent:
            kl_uf *= sequence_length
        if self.backward_model.independent:
            kl_ub *= sequence_length

        kl_cond = kl_cond * self.loss_factors['kl_conditioning'] * factor
        kl_ub = kl_ub * self.loss_factors['kl_u'] * factor
        kl_uf = kl_uf * self.loss_factors['kl_u'] * factor

        if self.loss_key.lower() == 'loglik':
            loss = -log_lik
        elif self.loss_key.lower() == 'elbo':
            loss = -(log_lik - kl_uf - kl_ub - kl_cond)
            if kwargs.get('print', False):
                str_ = 'elbo: {}, log_lik: {}, kluf: {}, klub: {}, klcond: {}'
                print(str_.format(loss.item(), log_lik.item(), kl_uf.item(),
                                  kl_ub.item(), kl_cond.item()))
        elif self.loss_key.lower() == 'l2':
            loss = l2
        elif self.loss_key.lower() == 'rmse':
            loss = torch.sqrt(l2)
        else:
            raise NotImplementedError("Key {} not implemented".format(self.loss_key))

        return outputs, loss
Example #36
0
    def kld_loss(self, encoded_distribution):
        prior = Normal(torch.zeros_like(encoded_distribution.mean),
                       torch.ones_like(encoded_distribution.variance))
        kld = kl_divergence(encoded_distribution, prior).sum(dim=1)

        return kld
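For a standard-normal prior this KL has the familiar closed form -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2); a small illustrative check that kl_divergence agrees with it.

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 8)
log_var = torch.randn(4, 8)
q = Normal(mu, (0.5 * log_var).exp())
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))

kld = kl_divergence(q, prior).sum(dim=1)
closed_form = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(dim=1)
print(torch.allclose(kld, closed_form, atol=1e-5))  # True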
Example #37
0
        np.random.shuffle(aux_inds)
        for i, start in enumerate(
                range(0, args.aux_batch_size, args.aux_minibatch_size)):
            end = start + args.aux_minibatch_size
            aux_minibatch_ind = aux_inds[start:end]
            try:
                m_aux_obs = aux_obs[aux_minibatch_ind].to(device)
                m_aux_returns = aux_returns[aux_minibatch_ind].to(device)

                new_pi, new_values, new_aux_values = agent.get_pi_value_and_aux_value(
                    m_aux_obs)
                new_values = new_values.view(-1)
                new_aux_values = new_aux_values.view(-1)
                with torch.no_grad():
                    old_pi = old_agent.get_pi(m_aux_obs)
                kl_loss = td.kl_divergence(old_pi, new_pi).mean()

                real_value_loss = 0.5 * (
                    (new_values - m_aux_returns)**2).mean()
                aux_value_loss = 0.5 * (
                    (new_aux_values - m_aux_returns)**2).mean()
                joint_loss = aux_value_loss + args.beta_clone * kl_loss

                optimizer.zero_grad()
                loss = (joint_loss + real_value_loss) / args.n_aux_grad_accum
                loss.backward()

                if (i + 1) % args.n_aux_grad_accum == 0:
                    nn.utils.clip_grad_norm_(agent.parameters(),
                                             args.max_grad_norm)
                    optimizer.step()
Example #38
0
    def train(self, env_fn, policy, n_itr, logger=None):

        old_policy = deepcopy(policy)

        optimizer = optim.Adam(policy.parameters(), lr=self.lr, eps=self.eps)

        start_time = time.time()

        for itr in range(n_itr):
            print("********** Iteration {} ************".format(itr))

            sample_start = time.time()
            batch = self.sample_parallel(env_fn, policy, self.num_steps,
                                         self.max_traj_len)

            print("time elapsed: {:.2f} s".format(time.time() - start_time))
            print("sample time elapsed: {:.2f} s".format(time.time() -
                                                         sample_start))

            observations, actions, returns, values = map(
                torch.Tensor, batch.get())

            advantages = returns - values
            advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                             self.eps)

            minibatch_size = self.minibatch_size or advantages.numel()

            print("timesteps in batch: %i" % advantages.numel())

            old_policy.load_state_dict(
                policy.state_dict())  # WAY faster than deepcopy

            optimizer_start = time.time()

            self.update(policy, old_policy, optimizer, observations, actions,
                        returns, advantages, env_fn)

            print("optimizer time elapsed: {:.2f} s".format(time.time() -
                                                            optimizer_start))

            if logger is not None:
                evaluate_start = time.time()
                test = self.sample_parallel(env_fn,
                                            policy,
                                            800 // self.n_proc,
                                            self.max_traj_len,
                                            deterministic=True)
                print("evaluate time elapsed: {:.2f} s".format(time.time() -
                                                               evaluate_start))

                _, pdf = policy.evaluate(observations)
                _, old_pdf = old_policy.evaluate(observations)

                entropy = pdf.entropy().mean().item()
                kl = kl_divergence(pdf, old_pdf).mean().item()

                logger.record("Return (test)", np.mean(test.ep_returns))
                logger.record("Return (batch)", np.mean(batch.ep_returns))
                logger.record("Mean Eplen", np.mean(batch.ep_lens))

                logger.record("Mean KL Div", kl)
                logger.record("Mean Entropy", entropy)
                logger.dump()

            # TODO: add option for how often to save model
            if itr % 10 == 0:
                self.save(policy)
        data_loader = PairedComparison(4,
                                       direction=False,
                                       dichotomized=dichotomized,
                                       ranking=True)
        for i in tqdm(range(num_tasks)):
            if dichotomized:
                model1 = VariationalFirstDiscriminatingCue(
                    data_loader.num_inputs)
            else:
                model1 = VariationalFirstCue(data_loader.num_inputs)
            model2 = VariationalProbitRegression(data_loader.num_inputs)
            inputs, targets, _, _ = data_loader.get_batch(1, num_steps)

            predictive_distribution1 = model1.forward(inputs, targets)
            predictive_distribution2 = model2.forward(inputs, targets)
            kl[k, i, :] = kl_divergence(predictive_distribution1,
                                        predictive_distribution2).squeeze()

    torch.save(kl, 'data/power_analysis.pth')
    print(kl.mean(1))
    print(kl.mean(1).sum(-1))
else:
    kl = torch.load('data/power_analysis.pth').div(2.303)  # ln to log10
    mean = kl.sum(-1).mean(1)
    variance = kl.sum(-1).std(1).pow(2)

    num_tasks = torch.arange(1, 31)
    expected_bf = torch.arange(1, 31).unsqueeze(-1) * mean
    confidence_bf = (torch.arange(1, 31).unsqueeze(-1) *
                     variance).sqrt() * 1.96
    styles = [':', '-']
    for i in range(expected_bf.shape[1]):