Example #1
def test_adjoint_marginal(equation, plates):
    inputs, output = equation.split('->')
    inputs = inputs.split(',')
    operands = [torch.randn(torch.Size((2,) * len(input_)))
                for input_ in inputs]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # check forward pass
    for x in operands:
        require_backward(x)
    actual, = ubersum(equation, *operands, plates=plates, modulo_total=True,
                      backend='pyro.ops.einsum.torch_marginal')
    expected, = ubersum(equation, *operands, plates=plates, modulo_total=True,
                        backend='pyro.ops.einsum.torch_log')
    assert_equal(expected, actual)

    # check backward pass
    actual._pyro_backward()
    for input_, operand in zip(inputs, operands):
        marginal_equation = ','.join(inputs) + '->' + input_
        expected, = ubersum(marginal_equation, *operands, plates=plates, modulo_total=True,
                            backend='pyro.ops.einsum.torch_log')
        actual = operand._pyro_backward_result
        assert_equal(expected, actual)
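
These test snippets are shown without their surrounding imports; a minimal setup sketch, assuming the module paths used by the Pyro test suite (helpers such as make_example, naive_ubersum, and _normalize are not reproduced here):

import torch
import pytest

import pyro.ops.jit
from pyro.ops.contract import einsum, ubersum   # assumed location of the plated einsum/ubersum
from pyro.ops.adjoint import require_backward   # assumed location of require_backward
from tests.common import assert_equal           # Pyro test helper (assumed path)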
Example #2
def test_einsum_linear(equation, plates):
    inputs, outputs, log_operands, sizes = make_example(equation)
    operands = [x.exp() for x in log_operands]

    try:
        log_expected = ubersum(equation,
                               *log_operands,
                               plates=plates,
                               modulo_total=True)
        expected = [x.exp() for x in log_expected]
    except NotImplementedError:
        pytest.skip()

    # einsum() is in linear space whereas ubersum() is in log space.
    actual = einsum(equation, *operands, plates=plates, modulo_total=True)
    assert isinstance(actual, tuple)
    assert len(actual) == len(outputs)
    for output, expected_part, actual_part in zip(outputs, expected, actual):
        assert_equal(
            expected_part.log(),
            actual_part.log(),
            msg="For output '{}':\nExpected:\n{}\nActual:\n{}".format(
                output,
                expected_part.detach().cpu(),
                actual_part.detach().cpu()),
        )
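
A minimal concrete sketch of the relationship this test checks: einsum() contracts in linear space while ubersum() contracts in log space, so exponentiating the ubersum result should match einsum applied to exponentiated operands. The equation, shapes, and import path below are illustrative assumptions.

import torch
from pyro.ops.contract import einsum, ubersum  # assumed module path

log_x = torch.randn(2, 3)
log_y = torch.randn(3, 4)

# log-space contraction over the shared dim "b" ...
(log_expected,) = ubersum("ab,bc->ac", log_x, log_y, modulo_total=True)
# ... should agree with the linear-space contraction of the exponentiated operands
(actual,) = einsum("ab,bc->ac", log_x.exp(), log_y.exp(), modulo_total=True)
assert torch.allclose(actual.log(), log_expected, atol=1e-5)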
Example #3
def test_adjoint_shape(backend, equation, plates):
    backend = "pyro.ops.einsum.torch_{}".format(backend)
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    operands = [
        torch.randn(torch.Size((2, ) * len(input_))) for input_ in inputs
    ]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # run forward-backward algorithm
    for x in operands:
        require_backward(x)
    (result, ) = ubersum(equation,
                         *operands,
                         plates=plates,
                         modulo_total=True,
                         backend=backend)
    result._pyro_backward()

    for input_, x in zip(inputs, operands):
        backward_result = x._pyro_backward_result
        contract_dims = set(input_) - set(output) - set(plates)
        if contract_dims:
            assert backward_result is not None
        else:
            assert backward_result is None
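
In these adjoint tests, require_backward() marks the operands whose backward results are wanted; after calling result._pyro_backward(), each operand's _pyro_backward_result holds the corresponding adjoint, and it is None when that input has no contracted dims (every one of its dims appears in the output or in plates).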
Example #4
def test_ubersum_total(equation, plates):
    inputs, outputs, operands, sizes = make_example(equation, fill=1, sizes=(2,))
    output = outputs[0]

    expected = naive_ubersum(equation, *operands, plates=plates)[0]
    actual = ubersum(equation, *operands, plates=plates, modulo_total=True)[0]
    expected = _normalize(expected, output, plates)
    actual = _normalize(actual, output, plates)
    assert_equal(expected, actual,
                 msg=u"Expected:\n{}\nActual:\n{}".format(
                     expected.detach().cpu(), actual.detach().cpu()))
Example #5
    def forward(self, x, lens, k, kx):
        # model takes as input the text, aspect, and location
        # runs BLSTM over text using embedding(location, aspect) as
        # the initial hidden state, as opposed to a different lstm for every pair???
        # output sentiment

        # DBG
        words = x

        emb = self.drop(self.lut(x))
        p_emb = pack(emb, lens, True)

        l, a = k
        N = x.shape[0]
        T = x.shape[1]

        y_idx = l * len(self.A) + a if self.L is not None else a
        s = (self.lut_la(y_idx).view(N, 2, 2 * self.nlayers,
                                     self.rnn_sz).permute(1, 2, 0,
                                                          3).contiguous())
        state = (s[0], s[1])
        x, (h, c) = self.rnn(p_emb, state)
        # h: L * D x N x H
        x = unpack(x, True)[0]
        proj_s = self.proj_s[y_idx.squeeze(-1)]
        phi_s = torch.einsum("nsh,nth->nts", [proj_s, emb])

        idxs = torch.arange(0, max(lens)).to(lens.device)
        # mask: N x R x 1
        mask = (idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1))
        phi_y = torch.zeros(N, len(self.S)).to(self.lut.weight.device)
        psi_ys = self.proj_ys(x).view(N, T, len(self.S) - 1, len(self.S) - 1)
        left = torch.zeros(N, T, 1, len(self.S) - 1).to(psi_ys)
        top = torch.cat(
            [self.psi_none,
             torch.zeros(len(self.S) - 1).to(psi_ys)],
            0,
        ).view(1, 1, len(self.S), 1).expand(N, T, len(self.S), 1)
        psi_ys = torch.cat([top, torch.cat([left, psi_ys], -2)], -1)
        # mask phi_s, psi_ys...actually these are mostly unnecessary
        phi_s.masked_fill_(mask.unsqueeze(-1), 0)
        psi_ys = psi_ys.masked_fill_(mask.view(N, T, 1, 1), 0)
        Z, hy = ubersum("nts,ntys,ny->n,ny",
                        phi_s,
                        psi_ys,
                        phi_y,
                        batch_dims="t",
                        modulo_total=True)
        return hy
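
A standalone sketch of the plated contraction used above: "t" is treated as a plate over timesteps, "n" is the batch dim, "s" a per-step state, and "y" the sentence label. Shapes are illustrative, and batch_dims is the keyword this snippet uses, which appears to correspond to the plates argument in the test examples above.

import torch
from pyro.ops.contract import ubersum  # assumed module path

N, T, S, Y = 4, 7, 3, 3
phi_s = torch.randn(N, T, S)      # per-timestep state potentials
psi_ys = torch.randn(N, T, Y, S)  # label/state compatibility per timestep
phi_y = torch.zeros(N, Y)         # label prior

Z, hy = ubersum("nts,ntys,ny->n,ny", phi_s, psi_ys, phi_y,
                batch_dims="t", modulo_total=True)
# Z: (N,) per-instance normalizer (up to a constant, since modulo_total=True);
# hy: (N, Y) unnormalized label scores, so hy - Z gives log-marginals over y
log_marginals = hy - Z.unsqueeze(-1)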
Example #6
    def forward(self, x, lens, a, l):
        # model takes as input the text, aspect, and location
        # runs BLSTM over text using embedding(location, aspect) as
        # the initial hidden state, as opposed to a different lstm for every pair???
        # output sentiment

        # DBG
        N, T = x.shape
        words = x

        emb = self.drop(self.lut(x))
        p_emb = pack(emb, lens, True)

        state = None
        if self.outer_plate:
            y_idx = l * len(self.A) + a if self.L is not None else a
            s = (self.lut_la(y_idx).view(N, 2, 2 * self.nlayers,
                                         self.rnn_sz).permute(1, 2, 0,
                                                              3).contiguous())
            state = (s[0], s[1])
        x, (h, c) = self.rnn(p_emb, state)
        # h: L * D x N x H
        x = unpack(x, True)[0]
        phi_s = self.proj_s(x)
        idxs = torch.arange(0, max(lens)).to(lens.device)
        # mask: N x R x 1
        mask = (idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1))
        phi_s[:, :, -1].masked_fill_(1 - mask, float("-inf"))
        phi_s[:, :, :3].masked_fill_(mask.unsqueeze(-1), float("-inf"))

        phi_y = torch.zeros(N, len(self.S)).to(self.psi_ys.device)
        psi_ys = torch.cat(
            [
                torch.diag(self.psi_ys),
                torch.zeros(len(self.S), 1).to(self.psi_ys)
            ],
            dim=-1,
        ).expand(T, len(self.S),
                 len(self.S) + 1)
        Z, hy = ubersum("nts,tys,ny->n,ny",
                        phi_s,
                        psi_ys,
                        phi_y,
                        batch_dims="t",
                        modulo_total=True)

        return hy
Example #7
def test_ubersum(equation, plates):
    inputs, outputs, operands, sizes = make_example(equation)

    try:
        actual = ubersum(equation, *operands, plates=plates, modulo_total=True)
    except NotImplementedError:
        pytest.skip()

    assert isinstance(actual, tuple)
    assert len(actual) == len(outputs)
    expected = naive_ubersum(equation, *operands, plates=plates)
    for output, expected_part, actual_part in zip(outputs, expected, actual):
        actual_part = _normalize(actual_part, output, plates)
        expected_part = _normalize(expected_part, output, plates)
        assert_equal(expected_part, actual_part,
                     msg=u"For output '{}':\nExpected:\n{}\nActual:\n{}".format(
                         output, expected_part.detach().cpu(), actual_part.detach().cpu()))
Example #8
def test_ubersum_jit(equation, plates):
    inputs, outputs, operands, sizes = make_example(equation)

    try:
        expected = ubersum(equation, *operands, plates=plates, modulo_total=True)
    except NotImplementedError:
        pytest.skip()

    @pyro.ops.jit.trace
    def jit_ubersum(*operands):
        return ubersum(equation, *operands, plates=plates, modulo_total=True)

    actual = jit_ubersum(*operands)

    if not isinstance(actual, tuple):
        pytest.xfail(reason="https://github.com/pytorch/pytorch/issues/14875")
    assert len(expected) == len(actual)
    for e, a in zip(expected, actual):
        assert_equal(e, a)
Example #9
def jit_ubersum(*operands):
    return ubersum(equation, *operands, plates=plates, modulo_total=True)
Example #10
    def forward(self, x, lens, a, l):
        words = x

        emb = self.drop(self.lut(x))
        p_emb = pack(emb, lens, True)

        N = x.shape[0]
        T = x.shape[1]

        state = None
        if self.outer_plate:
            y_idx = l * len(self.A) + a if self.L is not None else a
            s = (self.lut_la(y_idx).view(N, 2, 2 * self.nlayers,
                                         self.rnn_sz).permute(1, 2, 0,
                                                              3).contiguous())
            state = (s[0], s[1])
        x, (h, c) = self.rnn(p_emb, state)
        # h: L * D x N x H
        x = unpack(x, True)[0]
        #import pdb; pdb.set_trace()

        phi_s, phi_neg = None, None
        if self.outer_plate:
            proj_s = self.proj_s[y_idx.squeeze(-1)]
            phi_s = torch.einsum("nsh,nth->nts", [proj_s, emb])
            proj_neg = self.proj_neg[y_idx.squeeze(-1)]
            phi_neg = torch.einsum("nbh,nth->ntb", [proj_neg, x])
        else:
            phi_s = self.proj_s(emb)
            phi_neg = self.proj_neg(x)
        # CONV
        if self.outer_plate:
            c = (self.conv(emb.transpose(-1, -2)).transpose(-1, -2).view(
                N, T, -1, 2 * self.rnn_sz))  #[:,:,y_idx.squeeze(-1),:]
            cy = c.gather(2,
                          y_idx.view(N, 1, 1, 1).expand(N, T, 1,
                                                        100)).squeeze(-2)
        #phi_neg = torch.einsum("nbh,nth->ntb", [proj_neg, cy])
        # /CONV
        # add prior
        phi_neg = phi_neg + self.phi_b.view(1, 1, 2)

        phi_y = torch.zeros(N, len(self.S)).to(self.lut.weight.device)
        psi_ybs0 = torch.diag(self.psi_ys)
        psi_ybs1 = psi_ybs0 @ self.flip
        psi_ybs = (
            torch.stack([psi_ybs0, psi_ybs1], 1)
            #psi_ybs = (torch.stack([psi_ybs0 @ self.fm1, psi_ybs0 @ self.fm2], 1)
            .view(1, 1, len(self.S), 2, len(self.S)).repeat(N, T, 1, 1, 1))
        idxs = torch.arange(0, max(lens)).to(lens.device)
        # mask: N x R
        mask = (idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1))
        phi_s.masked_fill_(mask.unsqueeze(-1), 0)
        phi_neg.masked_fill_(mask.unsqueeze(-1), 0)
        psi_ybs.masked_fill_(mask.view(N, T, 1, 1, 1).expand_as(psi_ybs), 0)
        Z, hy = ubersum("nts,ntb,ntybs,ny->n,ny",
                        phi_s,
                        phi_neg,
                        psi_ybs,
                        phi_y,
                        batch_dims="t",
                        modulo_total=True)
        if self.training:
            self._N += 1
        #if self._N > 1000 and self.training:
        """
        if self._N > 10 and self.training:
            Zt, hx, hb = ubersum(
                "nts,ntb,ntybs,ny->nt,nts,ntb",
                phi_s, phi_neg, psi_ybs, phi_y, batch_dims="t", modulo_total=True)
            xp = (hx - Zt.unsqueeze(-1)).exp()
            bp = (hb - Zt.unsqueeze(-1)).exp()
            yp = (hy - Z.unsqueeze(-1)).exp()
            def stuff(i):
                #loc = self.L.itos[l[i]]
                asp = self.A and self.A.itos[a[i]]
                return self.tostr(words[i]), None, asp, xp[i], yp[i], bp[i]
            if bp.max(-1)[1].sum() > 0:
                import pdb; pdb.set_trace()
            # wordsi, loc, asp, xpi, ypi, bpi = stuff(10)
        """
        return hy
Example #11
def jit_ubersum(*operands):
    return ubersum(equation, *operands, batch_dims=batch_dims, modulo_total=True)
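
Note: this variant passes batch_dims rather than the plates keyword used in the test snippets; batch_dims appears to be the older spelling of the same plate argument, so which keyword is accepted depends on the installed Pyro version.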
Example #12
    def forward(self, x, lens, k, kx):
        # model takes as input the text, aspect, and location
        # runs BLSTM over text using embedding(location, aspect) as
        # the initial hidden state, as opposed to a different lstm for every pair???
        # output sentiment

        # DBG
        words = x

        emb = self.drop(self.lut(x))
        p_emb = pack(emb, lens, True)

        l, a = k
        N = l.shape[0]
        T = x.shape[1]
        # factor this out, for sure. POSSIBLE BUGS
        y_idx = l * len(self.A) + a
        s = (self.lut_la(y_idx)
            .view(N, 2, 2 * self.nlayers, self.rnn_sz)
            .permute(1, 2, 0, 3)
            .contiguous())
        state = (s[0], s[1])
        x, (h, c) = self.rnn(p_emb, state)
        # h: L * D x N x H
        x = unpack(x, True)[0]
        # Get the last hidden states for both directions, POSSIBLE BUGS
        phi_s = self.proj_s(x)
        #"""
        idxs = torch.arange(0, max(lens)).to(lens.device)
        # mask: N x R x 1
        mask = (idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1))
        phi_s[:,:,-1].masked_fill_(1-mask, float("-inf"))
        phi_s[:,:,:3].masked_fill_(mask.unsqueeze(-1), float("-inf"))
        #"""
        """
        h = (h
            .view(self.nlayers, 2, -1, self.rnn_sz)[-1]
            .permute(1, 0, 2)
            .contiguous()
            .view(-1, 2 * self.rnn_sz))
        phi_y = self.proj_y(h)
        """
        phi_y = torch.zeros(N, len(self.S)).to(self.psi_ys.device)
        psi_ys = torch.cat(
            [torch.diag(self.psi_ys), torch.zeros(len(self.S), 1).to(self.psi_ys)],
            dim=-1,
        ).expand(T, len(self.S), len(self.S)+1)
        #psi_ys = torch.diag(self.psi_ys).repeat(T, 1, 1)
        # Z is really weird here
        Z, hy = ubersum("nts,tys,ny->n,ny", phi_s, psi_ys, phi_y, batch_dims="t", modulo_total=True)
        #Z, hy = ubersum("nts,tys,ny->n,ny", phi_s, psi_ys, phi_y, batch_dims="t", modulo_total=True)
        def stuff(i):
            loc = self.L.itos[l[i]]
            asp = self.A.itos[a[i]]
            return self.tostr(words[i]), loc, asp, xp[i], yp[i]
        if self.training:
            self._N += 1
        if self._N > 100 and self.training:
            Zx, hx = ubersum("nts,tys->nt,nts", phi_s, psi_ys, batch_dims="t", modulo_total=True)
            xp = (hx - Zx.unsqueeze(-1)).exp()
            yp = (hy - Z.unsqueeze(-1)).exp()
            #Zx, hx = ubersum("nts,ys->nt,nts", phi_s, self.psi_ys, batch_dims="t")
            import pdb; pdb.set_trace()
            pass
            # text, loc, asp, xpi, ypi = stuff(10)
        #import pdb; pdb.set_trace()
        return hy  # - Z.unsqueeze(-1)