def test_adjoint_marginal(equation, plates):
    inputs, output = equation.split('->')
    inputs = inputs.split(',')
    operands = [torch.randn(torch.Size((2,) * len(input_))) for input_ in inputs]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # check forward pass
    for x in operands:
        require_backward(x)
    actual, = ubersum(equation, *operands, plates=plates, modulo_total=True,
                      backend='pyro.ops.einsum.torch_marginal')
    expected, = ubersum(equation, *operands, plates=plates, modulo_total=True,
                        backend='pyro.ops.einsum.torch_log')
    assert_equal(expected, actual)

    # check backward pass
    actual._pyro_backward()
    for input_, operand in zip(inputs, operands):
        marginal_equation = ','.join(inputs) + '->' + input_
        expected, = ubersum(marginal_equation, *operands, plates=plates, modulo_total=True,
                            backend='pyro.ops.einsum.torch_log')
        actual = operand._pyro_backward_result
        assert_equal(expected, actual)
def test_einsum_linear(equation, plates):
    inputs, outputs, log_operands, sizes = make_example(equation)
    operands = [x.exp() for x in log_operands]
    try:
        log_expected = ubersum(equation, *log_operands, plates=plates, modulo_total=True)
        expected = [x.exp() for x in log_expected]
    except NotImplementedError:
        pytest.skip()

    # einsum() is in linear space whereas ubersum() is in log space.
    actual = einsum(equation, *operands, plates=plates, modulo_total=True)
    assert isinstance(actual, tuple)
    assert len(actual) == len(outputs)
    for output, expected_part, actual_part in zip(outputs, expected, actual):
        assert_equal(
            expected_part.log(), actual_part.log(),
            msg="For output '{}':\nExpected:\n{}\nActual:\n{}".format(
                output, expected_part.detach().cpu(), actual_part.detach().cpu()),
        )
def test_adjoint_shape(backend, equation, plates):
    backend = "pyro.ops.einsum.torch_{}".format(backend)
    inputs, output = equation.split("->")
    inputs = inputs.split(",")
    operands = [torch.randn(torch.Size((2,) * len(input_))) for input_ in inputs]
    for input_, x in zip(inputs, operands):
        x._pyro_dims = input_

    # run forward-backward algorithm
    for x in operands:
        require_backward(x)
    result, = ubersum(equation, *operands, plates=plates, modulo_total=True, backend=backend)
    result._pyro_backward()
    for input_, x in zip(inputs, operands):
        backward_result = x._pyro_backward_result
        contract_dims = set(input_) - set(output) - set(plates)
        if contract_dims:
            assert backward_result is not None
        else:
            assert backward_result is None
def test_ubersum_total(equation, plates):
    inputs, outputs, operands, sizes = make_example(equation, fill=1, sizes=(2,))
    output = outputs[0]

    expected = naive_ubersum(equation, *operands, plates=plates)[0]
    actual = ubersum(equation, *operands, plates=plates, modulo_total=True)[0]
    expected = _normalize(expected, output, plates)
    actual = _normalize(actual, output, plates)
    assert_equal(expected, actual,
                 msg=u"Expected:\n{}\nActual:\n{}".format(
                     expected.detach().cpu(), actual.detach().cpu()))
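# Hedged usage sketch (not part of the test suite above): the tests exercise
# pyro.ops.contract.ubersum, which contracts log-space operands and treats the
# symbols named in `plates` as batch dims that are product-reduced. The
# equation, sizes, and variable names below are made up for illustration only,
# and modulo_total=True means results are only defined up to a constant, so we
# check shapes rather than values.
import torch
from pyro.ops.contract import ubersum

x = torch.randn(3, 2)  # dims 'ia': plate i of size 3, sum dim a of size 2
y = torch.randn(3, 2)  # dims 'ia'
z, = ubersum('ia,ia->i', x, y, plates='i', modulo_total=True)
assert z.shape == (3,)  # one logsumexp over 'a' per slice of plate 'i'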
def forward(self, x, lens, k, kx):
    # model takes as input the text, aspect, and location
    # runs BLSTM over text using embedding(location, aspect) as
    # the initial hidden state, as opposed to a different lstm for every pair???
    # output sentiment
    # DBG
    words = x
    emb = self.drop(self.lut(x))
    p_emb = pack(emb, lens, True)

    l, a = k
    N = x.shape[0]
    T = x.shape[1]

    y_idx = l * len(self.A) + a if self.L is not None else a
    s = (self.lut_la(y_idx)
         .view(N, 2, 2 * self.nlayers, self.rnn_sz)
         .permute(1, 2, 0, 3)
         .contiguous())
    state = (s[0], s[1])

    x, (h, c) = self.rnn(p_emb, state)
    # h: L * D x N x H
    x = unpack(x, True)[0]

    proj_s = self.proj_s[y_idx.squeeze(-1)]
    phi_s = torch.einsum("nsh,nth->nts", [proj_s, emb])

    idxs = torch.arange(0, max(lens)).to(lens.device)
    # mask: N x R x 1
    mask = (idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1))

    phi_y = torch.zeros(N, len(self.S)).to(self.lut.weight.device)
    psi_ys = self.proj_ys(x).view(N, T, len(self.S) - 1, len(self.S) - 1)
    left = torch.zeros(N, T, 1, len(self.S) - 1).to(psi_ys)
    top = torch.cat(
        [self.psi_none, torch.zeros(len(self.S) - 1).to(psi_ys)],
        0,
    ).view(1, 1, len(self.S), 1).expand(N, T, len(self.S), 1)
    psi_ys = torch.cat([top, torch.cat([left, psi_ys], -2)], -1)

    # mask phi_s, psi_ys...actually these are mostly unnecessary
    phi_s.masked_fill_(mask.unsqueeze(-1), 0)
    psi_ys = psi_ys.masked_fill_(mask.view(N, T, 1, 1), 0)

    Z, hy = ubersum("nts,ntys,ny->n,ny", phi_s, psi_ys, phi_y,
                    batch_dims="t", modulo_total=True)
    return hy
def forward(self, x, lens, a, l):
    # model takes as input the text, aspect, and location
    # runs BLSTM over text using embedding(location, aspect) as
    # the initial hidden state, as opposed to a different lstm for every pair???
    # output sentiment
    # DBG
    N, T = x.shape
    words = x
    emb = self.drop(self.lut(x))
    p_emb = pack(emb, lens, True)

    state = None
    if self.outer_plate:
        y_idx = l * len(self.A) + a if self.L is not None else a
        s = (self.lut_la(y_idx)
             .view(N, 2, 2 * self.nlayers, self.rnn_sz)
             .permute(1, 2, 0, 3)
             .contiguous())
        state = (s[0], s[1])

    x, (h, c) = self.rnn(p_emb, state)
    # h: L * D x N x H
    x = unpack(x, True)[0]

    phi_s = self.proj_s(x)

    idxs = torch.arange(0, max(lens)).to(lens.device)
    # mask: N x R x 1
    mask = (idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1))
    # use ~mask (not 1 - mask): the comparison above yields a bool tensor
    phi_s[:, :, -1].masked_fill_(~mask, float("-inf"))
    phi_s[:, :, :3].masked_fill_(mask.unsqueeze(-1), float("-inf"))

    phi_y = torch.zeros(N, len(self.S)).to(self.psi_ys.device)
    psi_ys = torch.cat(
        [torch.diag(self.psi_ys), torch.zeros(len(self.S), 1).to(self.psi_ys)],
        dim=-1,
    ).expand(T, len(self.S), len(self.S) + 1)

    Z, hy = ubersum("nts,tys,ny->n,ny", phi_s, psi_ys, phi_y,
                    batch_dims="t", modulo_total=True)
    return hy
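# Toy sketch (assumed sizes, not part of the model above) of the
# "nts,tys,ny->n,ny" contraction used in the forward methods: per batch
# element n, the token plate t is product-reduced, the per-token sentiment
# dim s is summed out, and the label dim y is preserved, yielding the log
# normalizer Z (shape N) and unnormalized log label marginals hy (shape N x Y).
import torch
from pyro.ops.contract import ubersum

N, T, S, Y = 4, 7, 3, 3
phi_s = torch.randn(N, T, S)   # token-level sentiment scores
psi_ys = torch.randn(T, Y, S)  # label/sentiment coupling per position
phi_y = torch.zeros(N, Y)      # uniform prior over labels
Z, hy = ubersum("nts,tys,ny->n,ny", phi_s, psi_ys, phi_y,
                batch_dims="t", modulo_total=True)
assert Z.shape == (N,) and hy.shape == (N, Y)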
def test_ubersum(equation, plates):
    inputs, outputs, operands, sizes = make_example(equation)
    try:
        actual = ubersum(equation, *operands, plates=plates, modulo_total=True)
    except NotImplementedError:
        pytest.skip()
    assert isinstance(actual, tuple)
    assert len(actual) == len(outputs)

    expected = naive_ubersum(equation, *operands, plates=plates)
    for output, expected_part, actual_part in zip(outputs, expected, actual):
        actual_part = _normalize(actual_part, output, plates)
        expected_part = _normalize(expected_part, output, plates)
        assert_equal(expected_part, actual_part,
                     msg=u"For output '{}':\nExpected:\n{}\nActual:\n{}".format(
                         output, expected_part.detach().cpu(), actual_part.detach().cpu()))
def test_ubersum_jit(equation, plates):
    inputs, outputs, operands, sizes = make_example(equation)
    try:
        expected = ubersum(equation, *operands, plates=plates, modulo_total=True)
    except NotImplementedError:
        pytest.skip()

    @pyro.ops.jit.trace
    def jit_ubersum(*operands):
        return ubersum(equation, *operands, plates=plates, modulo_total=True)

    actual = jit_ubersum(*operands)
    if not isinstance(actual, tuple):
        pytest.xfail(reason="https://github.com/pytorch/pytorch/issues/14875")
    assert len(expected) == len(actual)
    for e, a in zip(expected, actual):
        assert_equal(e, a)
def jit_ubersum(*operands):
    return ubersum(equation, *operands, plates=plates, modulo_total=True)
def forward(self, x, lens, a, l):
    words = x
    emb = self.drop(self.lut(x))
    p_emb = pack(emb, lens, True)

    N = x.shape[0]
    T = x.shape[1]

    state = None
    if self.outer_plate:
        y_idx = l * len(self.A) + a if self.L is not None else a
        s = (self.lut_la(y_idx)
             .view(N, 2, 2 * self.nlayers, self.rnn_sz)
             .permute(1, 2, 0, 3)
             .contiguous())
        state = (s[0], s[1])

    x, (h, c) = self.rnn(p_emb, state)
    # h: L * D x N x H
    x = unpack(x, True)[0]
    #import pdb; pdb.set_trace()

    phi_s, phi_neg = None, None
    if self.outer_plate:
        proj_s = self.proj_s[y_idx.squeeze(-1)]
        phi_s = torch.einsum("nsh,nth->nts", [proj_s, emb])
        proj_neg = self.proj_neg[y_idx.squeeze(-1)]
        phi_neg = torch.einsum("nbh,nth->ntb", [proj_neg, x])
    else:
        phi_s = self.proj_s(emb)
        phi_neg = self.proj_neg(x)

    # CONV
    if self.outer_plate:
        c = (self.conv(emb.transpose(-1, -2))
             .transpose(-1, -2)
             .view(N, T, -1, 2 * self.rnn_sz))  # [:,:,y_idx.squeeze(-1),:]
        cy = c.gather(2, y_idx.view(N, 1, 1, 1).expand(N, T, 1, 100)).squeeze(-2)
        #phi_neg = torch.einsum("nbh,nth->ntb", [proj_neg, cy])
    # /CONV

    # add prior
    phi_neg = phi_neg + self.phi_b.view(1, 1, 2)

    phi_y = torch.zeros(N, len(self.S)).to(self.lut.weight.device)
    psi_ybs0 = torch.diag(self.psi_ys)
    psi_ybs1 = psi_ybs0 @ self.flip
    psi_ybs = (torch.stack([psi_ybs0, psi_ybs1], 1)
               #psi_ybs = (torch.stack([psi_ybs0 @ self.fm1, psi_ybs0 @ self.fm2], 1)
               .view(1, 1, len(self.S), 2, len(self.S))
               .repeat(N, T, 1, 1, 1))

    idxs = torch.arange(0, max(lens)).to(lens.device)
    # mask: N x R
    mask = (idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1))
    phi_s.masked_fill_(mask.unsqueeze(-1), 0)
    phi_neg.masked_fill_(mask.unsqueeze(-1), 0)
    psi_ybs.masked_fill_(mask.view(N, T, 1, 1, 1).expand_as(psi_ybs), 0)

    Z, hy = ubersum("nts,ntb,ntybs,ny->n,ny", phi_s, phi_neg, psi_ybs, phi_y,
                    batch_dims="t", modulo_total=True)

    if self.training:
        self._N += 1
    #if self._N > 1000 and self.training:
    """
    if self._N > 10 and self.training:
        Zt, hx, hb = ubersum(
            "nts,ntb,ntybs,ny->nt,nts,ntb",
            phi_s, phi_neg, psi_ybs, phi_y,
            batch_dims="t", modulo_total=True)
        xp = (hx - Zt.unsqueeze(-1)).exp()
        bp = (hb - Zt.unsqueeze(-1)).exp()
        yp = (hy - Z.unsqueeze(-1)).exp()

        def stuff(i):
            #loc = self.L.itos[l[i]]
            asp = self.A and self.A.itos[a[i]]
            return self.tostr(words[i]), None, asp, xp[i], yp[i], bp[i]

        if bp.max(-1)[1].sum() > 0:
            import pdb; pdb.set_trace()
            # wordsi, loc, asp, xpi, ypi, bpi = stuff(10)
    """
    return hy
def jit_ubersum(*operands):
    return ubersum(equation, *operands, batch_dims=batch_dims, modulo_total=True)
def forward(self, x, lens, k, kx):
    # model takes as input the text, aspect, and location
    # runs BLSTM over text using embedding(location, aspect) as
    # the initial hidden state, as opposed to a different lstm for every pair???
    # output sentiment
    # DBG
    words = x
    emb = self.drop(self.lut(x))
    p_emb = pack(emb, lens, True)

    l, a = k
    N = l.shape[0]
    T = x.shape[1]

    # factor this out, for sure. POSSIBLE BUGS
    y_idx = l * len(self.A) + a
    s = (self.lut_la(y_idx)
         .view(N, 2, 2 * self.nlayers, self.rnn_sz)
         .permute(1, 2, 0, 3)
         .contiguous())
    state = (s[0], s[1])

    x, (h, c) = self.rnn(p_emb, state)
    # h: L * D x N x H
    x = unpack(x, True)[0]

    # Get the last hidden states for both directions, POSSIBLE BUGS
    phi_s = self.proj_s(x)

    idxs = torch.arange(0, max(lens)).to(lens.device)
    # mask: N x R x 1
    mask = (idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1))
    # use ~mask (not 1 - mask): the comparison above yields a bool tensor
    phi_s[:, :, -1].masked_fill_(~mask, float("-inf"))
    phi_s[:, :, :3].masked_fill_(mask.unsqueeze(-1), float("-inf"))

    """
    h = (h
         .view(self.nlayers, 2, -1, self.rnn_sz)[-1]
         .permute(1, 0, 2)
         .contiguous()
         .view(-1, 2 * self.rnn_sz))
    phi_y = self.proj_y(h)
    """
    phi_y = torch.zeros(N, len(self.S)).to(self.psi_ys.device)
    psi_ys = torch.cat(
        [torch.diag(self.psi_ys), torch.zeros(len(self.S), 1).to(self.psi_ys)],
        dim=-1,
    ).expand(T, len(self.S), len(self.S) + 1)
    #psi_ys = torch.diag(self.psi_ys).repeat(T, 1, 1)

    # Z is really weird here
    Z, hy = ubersum("nts,tys,ny->n,ny", phi_s, psi_ys, phi_y,
                    batch_dims="t", modulo_total=True)

    def stuff(i):
        loc = self.L.itos[l[i]]
        asp = self.A.itos[a[i]]
        return self.tostr(words[i]), loc, asp, xp[i], yp[i]

    if self.training:
        self._N += 1
    if self._N > 100 and self.training:
        Zx, hx = ubersum("nts,tys->nt,nts", phi_s, psi_ys,
                         batch_dims="t", modulo_total=True)
        xp = (hx - Zx.unsqueeze(-1)).exp()
        yp = (hy - Z.unsqueeze(-1)).exp()
        #Zx, hx = ubersum("nts,ys->nt,nts", phi_s, self.psi_ys, batch_dims="t")
        import pdb; pdb.set_trace()
        # text, loc, asp, xpi, ypi = stuff(10)

    return hy  # - Z.unsqueeze(-1)