def forward(self, var_input, var_pred, var_true=None):
    """Return the square root of a weighted sum of RBF kernel values.

    The rows to compare are stacked into one matrix ``X``:

    * ``var_true is None``: ``var_input`` and ``var_pred`` are stacked
      along dim 0.
    * otherwise: ``[var_input, var_pred]`` and ``[var_input, var_true]``
      are each concatenated along dim 1 and then stacked along dim 0.

    For every bandwidth in ``self.bandwiths`` the kernel matrix
    ``exp(-0.5 * ||x_i - x_j||^2 / bandwidth)`` is multiplied elementwise
    by ``self.S`` and summed; the total over all bandwidths is
    square-rooted and returned as a 1-element tensor.
    """
    # MMD Loss
    if var_true is None:
        X = th.cat([var_input, var_pred], 0)
    else:
        X = th.cat([th.cat([var_input, var_pred], 1),
                    th.cat([var_input, var_true], 1)], 0)

    # dot product between all combinations of rows in 'X'
    XX = X.mm(X.t())

    # Squared norm of every row. keepdim=True keeps this an (n, 1) column so
    # that it broadcasts over rows and its transpose over columns; without it
    # the result is 1-D, `.t()` is a no-op and the same term would be
    # subtracted twice along a single axis.
    X2 = X.mul(X).sum(dim=1, keepdim=True)

    # exponent entries of the RBF kernel (without the sigma):
    # -0.5 * ||x - y||^2 = x^Ty - 0.5 * x^Tx - 0.5 * y^Ty
    exponent = XX - X2.mul(0.5) - X2.t().mul(0.5)

    if self.cuda:
        lossMMD = Variable(th.cuda.FloatTensor([0]))
    else:
        lossMMD = Variable(th.zeros(1))
    for bandwidth in self.bandwiths:
        kernel_val = exponent.mul(1. / bandwidth).exp()
        lossMMD.add_(self.S.mul(kernel_val).sum())

    return lossMMD.sqrt()
예제 #2
0
파일: vpg.py 프로젝트: mbalunovic/rl
    def update(self):
        """Apply one gradient-ascent step from the buffered transitions.

        The learning rate is decayed by 1% first. The objective is the sum,
        over buffered steps, of the log-probability of the action taken
        (``logprobs[0][action]``) times the corresponding value returned by
        ``self.discount_rewards``. Parameters of ``self.nn`` are moved along
        the gradient (ascent), gradients are cleared, and the three buffers
        are emptied.

        Raises:
            AssertionError: if the log-prob and reward buffers differ in
                length.
        """
        self.learning_rate *= 0.99
        assert len(self.b_logprobs) == len(self.b_rewards)

        rewards = self.discount_rewards(self.b_rewards)

        # Only the chosen action contributes to the objective, so index it
        # directly instead of building a one-hot mask and calling torch.dot
        # (which rejects the 2-D (1, n_actions) log-prob tensors).
        g = None
        for i, logprobs in enumerate(self.b_logprobs):
            term = logprobs[0][self.b_actions[i]] * rewards[i]
            g = term if g is None else g + term

        # An empty buffer has nothing to differentiate.
        if g is not None:
            g.backward()
            for f in self.nn.parameters():
                # Parameters that did not take part in the graph have no grad.
                if f.grad is None:
                    continue
                f.data.add_(f.grad.data * self.learning_rate)

        self.nn.zero_grad()
        self.b_logprobs = []
        self.b_rewards = []
        self.b_actions = []
    def forward(self, input, target):
        """Accumulate the per-step reward and return it as a loss.

        Per the original notes, ``input`` is
        (batch_size, seq_length + 2, vocab_size + 1) and ``target`` is
        (batch_size, seq_length).

        ``self.reward_func(input, target, tt)`` is summed for every step
        ``tt`` in ``range(target.size(1))``. The sum is negated (the result
        is minimised) and scaled by ``self.weight``, divided by the sample
        count when ``self.sizeAverage`` is set.

        Side effects: stores ``num_samples``, ``normalizing_coeff`` and
        ``output`` on ``self``.

        Returns:
            (output, num_samples)
        """
        seq_length = target.size(1)

        # The accumulator follows the input's device, so the criterion also
        # runs on CPU-only machines instead of always requiring CUDA.
        cumreward = Variable(torch.FloatTensor(1).zero_(), requires_grad=True)
        if input.is_cuda:
            cumreward = cumreward.cuda()

        # `range` works on both Python 2 and 3 (`xrange` is Python-2 only).
        # The sum is built out-of-place because in-place updates are not
        # allowed on a leaf that requires grad.
        for tt in range(seq_length):
            cumreward = cumreward + self.reward_func(input, target, tt)

        # num_samples
        self.num_samples = self.reward_func.num_samples(input, target)

        # normalizing_coeff: divide by the sample count only when averaging;
        # a falsy count (e.g. 0) falls back to 1 to avoid dividing by zero.
        if self.sizeAverage:
            denominator = self.num_samples or 1
        else:
            denominator = 1
        self.normalizing_coeff = self.weight / denominator

        # here there is a '-' because we minimize
        self.output = -cumreward * self.normalizing_coeff

        return self.output, self.num_samples
예제 #4
0
 def test_nn(self):
     """Fit TorchNN on noisy linear data and check that the loss decreases.

     The target is Gaussian noise plus ``2 * X[:, 0] + 3 * X[:, 1]``.
     """
     n_samples = 100
     n_features = 3
     n_out = 1
     X = Variable(torch.randn(n_samples, n_features).type(FloatTensor),
                  requires_grad=False)
     y = Variable(torch.randn(n_samples, n_out).type(FloatTensor),
                  requires_grad=False)
     # Slice columns as 0:1 / 1:2 so the addend keeps the (n_samples, 1)
     # shape of `y`; a 1-D (n_samples,) column cannot be added in place to a
     # (n_samples, 1) tensor under broadcasting rules.
     y.add_(X[:, 0:1] * 2)
     y.add_(X[:, 1:2] * 3)
     torch_nn = TorchNN(learning_rate=1e-3, n_hidden=4).fit(X, y, iters=100)
     self.assertGreater(torch_nn.loss_path[0], torch_nn.loss_path[-1])
예제 #5
0
    def test_inplace(self):
        """Check how in-place ops interact with tensors saved for backward.

        Four scenarios are exercised in sequence; each modifies a value in
        place and then asserts which `backward` calls still succeed and
        which raise RuntimeError.
        """
        x = Variable(torch.ones(5, 5), requires_grad=True)
        y = Variable(torch.ones(5, 5) * 4, requires_grad=True)

        # Scenario 1: modify z in place after it was consumed by q and w.
        z = x * y
        q = z + y
        w = z * y
        z.add_(2)
        # Add doesn't need its inputs to do backward, so it shouldn't raise
        q.backward(torch.ones(5, 5), retain_variables=True)
        # Mul saves both inputs in forward, so it should raise
        self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))

        # Scenario 2: the in-place result itself (w) is used for backward.
        z = x * y
        q = z * y
        r = z + y
        w = z.add_(y)
        # w is the last expression, so this should succeed
        w.backward(torch.ones(5, 5), retain_variables=True)
        # r doesn't use the modified value in backward, so it should succeed
        r.backward(torch.ones(5, 5), retain_variables=True)
        # q uses dirty z, so it should raise
        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

        # Scenario 3: in-place exp_ must bump the version counter.
        # With x == 1 and y == 4, z = x / 2 + y / 8 == 1 before exp_.
        x.grad.data.zero_()
        m = x / 2
        z = m + y / 8
        q = z * y
        r = z + y
        prev_version = z._version
        w = z.exp_()
        self.assertNotEqual(z._version, prev_version)
        # dr/dx == 1/2
        r.backward(torch.ones(5, 5), retain_variables=True)
        self.assertEqual(x.grad.data, torch.ones(5, 5) / 2)
        # dw/dx == exp(1) / 2, accumulated on top of the previous 1/2
        w.backward(torch.ones(5, 5), retain_variables=True)
        self.assertEqual(x.grad.data,
                         torch.Tensor(5, 5).fill_((1 + math.e) / 2))
        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

        # Scenario 4: in-place op on a clone of a leaf.
        leaf = Variable(torch.ones(5, 5), requires_grad=True)
        x = leaf.clone()
        x.add_(10)
        self.assertEqual(x.data, torch.ones(5, 5) * 11)
        # x should be still usable
        y = x + 2
        y.backward(torch.ones(5, 5))
        self.assertEqual(leaf.grad.data, torch.ones(5, 5))
        # z was built from x before x is modified again, so backward raises
        z = x * y
        x.add_(2)
        self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
예제 #6
0
    def forward(self, pred, target):
        """Sum the squared gaps between moments of `pred` and `target`.

        For each order ``k`` in ``range(1, self.moments)`` the k-th raw
        moment is taken along dim 0 for both inputs, and the mean squared
        difference of the two moments is added to the result.

        :param pred: predicted Variable
        :param target: Target Variable
        :return: 1-element float Variable holding the accumulated loss
        """
        total = Variable(th.FloatTensor([0]))
        for order in range(1, self.moments):
            pred_moment = th.pow(pred, order).mean(0)
            target_moment = th.pow(target, order).mean(0)
            gap = pred_moment - target_moment

            total.add_((gap ** 2).mean())  # L2

        return total
예제 #7
0
파일: structs.py 프로젝트: ulzee/mpgan
        def rec(node, mfunc):
            """Collect the incoming message for `node` and, recursively, its subtree.

            The message starts as a zero vector of length ``node.hsize`` on
            ``Tree.device`` and receives ``mfunc(node.h_v, other.h_v)`` for
            every child and, when present, the parent.

            Returns a dict with the node's own ``'msg'`` and a list
            ``'child_msgs'`` holding the same structure for each child.
            """
            msg = Variable(torch.zeros(node.hsize, )).to(Tree.device)

            # Children first, then the parent (if any) - same order in which
            # the messages are accumulated.
            neighbours = list(node.children)
            if node.parent is not None:
                neighbours.append(node.parent)
            for other in neighbours:
                msg.add_(mfunc(node.h_v, other.h_v))

            return {
                'msg': msg,
                'child_msgs': [rec(child, mfunc) for child in node.children],
            }
예제 #8
0
def add_jitter(mat):
    """
    Adds "jitter" (1e-3) to the diagonal of a matrix or a batch of matrices.
    This ensures that a matrix that *should* be positive definite *is*
    positive definite.

    The input is never modified: every branch returns a new object. (The
    3-D plain-tensor case used to add the jitter into `mat` in place, unlike
    all other cases.)

    Args:
        - mat (matrix nxn, or batch bxnxn) - LazyVariable, Variable or tensor
    Returns: (same shape as `mat`) `mat` with 1e-3 added to the diagonal;
        for a LazyVariable, whatever its own `add_jitter()` returns.
    """
    if isinstance(mat, LazyVariable):
        return mat.add_jitter()

    # Build an n x n matrix with 1e-3 on the diagonal, of the same tensor
    # type as `mat`.
    if isinstance(mat, Variable):
        diag = Variable(mat.data.new(mat.size(-1)).fill_(1e-3).diag())
    else:
        diag = mat.new(mat.size(-1)).fill_(1e-3).diag()

    # Batched input: repeat the jitter matrix (as a view) for every batch
    # entry.
    if mat.ndimension() == 3:
        diag = diag.unsqueeze(0).expand(mat.size(0), mat.size(1),
                                        mat.size(2))

    return mat + diag
예제 #9
0
  def forward(self, x=None, pos=None):
    """Compute positional embeddings.

    Positions are either given explicitly through `pos` or default to
    0..max_len-1 shared by every row of `x`.

    Args:
      x: Tensor of size [batch_size, max_len]. Only its size is used;
        required when `pos` is None.
      pos: optional Tensor of size [batch_size, max_len] holding explicit
        positions. It is not modified.

    Returns:
      emb: Tensor of size [batch_size, max_len, d_word_vec]. With
        `hparams.pos_emb_size` set, rows of `self.emb` looked up at
        position + 1; otherwise interleaved scaled sin/cos features
        (sin, cos per entry of `self.freq`).
    """

    d_word_vec = self.hparams.d_word_vec
    explicit_pos = pos is not None
    if explicit_pos:
      batch_size, max_len = pos.size()
      pos = Variable(pos)
    else:
      batch_size, max_len = x.size()
      pos = Variable(torch.arange(0, max_len))
    if self.hparams.cuda:
      pos = pos.cuda()

    if self.hparams.pos_emb_size is not None:
      # Out-of-place "+ 1": an in-place add would write through to the
      # caller's `pos` tensor.
      pos = (pos + 1).long()
      if not explicit_pos:
        # Shared positions: repeat the single row for every batch entry.
        pos = pos.unsqueeze(0).expand_as(x)
      emb = self.emb(pos.contiguous())
    else:
      angles = pos.float().unsqueeze(-1) * self.freq.unsqueeze(0)
      sin = torch.sin(angles).mul_(self.emb_scale).unsqueeze(-1)
      cos = torch.cos(angles).mul_(self.emb_scale).unsqueeze(-1)
      emb = torch.cat([sin, cos], dim=-1).contiguous()
      if explicit_pos:
        # Explicit positions already carry the batch dimension.
        emb = emb.view(batch_size, max_len, d_word_vec)
      else:
        emb = emb.view(max_len, d_word_vec)
        emb = emb.unsqueeze(0).expand(batch_size, -1, -1)

    return emb
예제 #10
0
 def test_log_reg(self):
     """Fit TorchLogreg on thresholded linear data.

     Labels are 1 where noise + ``2 * X[:, 0] + 3 * X[:, 1]`` is >= 0.
     The loss must decrease and `check_convergence` must accept the fit.
     """
     n_samples = 100
     n_features = 3
     n_out = 1
     X = Variable(torch.randn(n_samples, n_features).type(FloatTensor),
                  requires_grad=False)
     y = Variable(torch.randn(n_samples, n_out).type(FloatTensor),
                  requires_grad=False)
     # Slice columns as 0:1 / 1:2 so the addend keeps the (n_samples, 1)
     # shape of `y`; a 1-D (n_samples,) column cannot be added in place to a
     # (n_samples, 1) tensor under broadcasting rules.
     y.add_(X[:, 0:1] * 2)
     y.add_(X[:, 1:2] * 3)
     y = y.ge(0).type(FloatTensor)
     torch_log_reg = TorchLogreg(learning_rate=1e-2,
                                 verbose=True).fit(X, y, iters=100)
     self.assertGreater(torch_log_reg.loss_path[0],
                        torch_log_reg.loss_path[-1])
     # NOTE(review): ideal_W is passed through unchanged; its values differ
     # from the generating coefficients (2, 3) - confirm against
     # check_convergence.
     self.check_convergence(torch_log_reg, ideal_W=[.4, 8, 0])
예제 #11
0
 def test_torch_reg(self):
     """Fit TorchReg on noisy linear data and check the learned weights.

     The target is Gaussian noise plus ``2 * X[:, 0] + 3 * X[:, 1]``; the
     first two fitted coefficients must exceed 1.8 and 2.8 respectively.
     """
     n_samples = 100
     n_features = 3
     n_out = 1
     X = Variable(torch.randn(n_samples, n_features).type(FloatTensor),
                  requires_grad=False)
     y = Variable(torch.randn(n_samples, n_out).type(FloatTensor),
                  requires_grad=False)
     # Slice columns as 0:1 / 1:2 so the addend keeps the (n_samples, 1)
     # shape of `y`; a 1-D (n_samples,) column cannot be added in place to a
     # (n_samples, 1) tensor under broadcasting rules.
     y.add_(X[:, 0:1] * 2)
     y.add_(X[:, 1:2] * 3)
     torch_reg = TorchReg(learning_rate=1e-3).fit(X, y, iters=100)
     coeffs = torch_reg.W.data.numpy()[:, 0]
     print(coeffs)
     self.assertGreater(coeffs[0], 1.8)
     self.assertGreater(coeffs[1], 2.8)
     self.check_convergence(torch_reg, ideal_W=[2, 3, 0])
예제 #12
0
 def test_shared_storage(self):
     x = Variable(torch.ones(5, 5))
     y = x.t()
     z = x[1]
     self.assertRaises(RuntimeError, lambda: x.add_(2))
     self.assertRaises(RuntimeError, lambda: y.add_(2))
     self.assertRaises(RuntimeError, lambda: z.add_(2))
예제 #13
0
    def test_remote_var_binary_methods(self):
        """In-place add of two Vars sent to a virtual worker returns the sum."""
        hook = TorchHook()
        me = hook.local_worker
        worker = VirtualWorker(hook, 0)
        me.add_worker(worker)

        values = [1, 2, 3, 4, 5]
        lhs = Var(torch.FloatTensor(values)).send(worker)
        rhs = Var(torch.FloatTensor(values)).send(worker)

        # Add on the worker, then fetch the result back for comparison.
        fetched = lhs.add_(rhs).get()
        expected = Var(torch.FloatTensor([2, 4, 6, 8, 10]))
        assert torch.equal(fetched, expected)
예제 #14
0
    def test_remote_var_binary_methods(self):
        """In-place add of two Vars sent to a virtual worker returns the sum."""
        hook = TorchHook()
        me = hook.local_worker
        worker = VirtualWorker(hook, 0)
        me.add_worker(worker)

        values = [1, 2, 3, 4, 5]
        lhs = Var(torch.FloatTensor(values)).send(worker)
        rhs = Var(torch.FloatTensor(values)).send(worker)

        # Add on the worker, then fetch the result back for comparison.
        fetched = lhs.add_(rhs).get()
        expected = Var(torch.FloatTensor([2, 4, 6, 8, 10]))
        assert torch.equal(fetched, expected)
예제 #15
0
class NeuralDict(nn.Module):
    """Fixed-capacity key/value store with cosine-similarity lookup.

    Keys, values and a per-slot recency counter are kept as plain tensors
    (not parameters). Slots are filled in order; once the store is full,
    new entries overwrite the slot recorded in `stale_ind`.

    Args:
        key_size: length of each key vector.
        dict_size: number of slots.
        device: where to allocate the storage. Defaults to CUDA when it is
            available and to CPU otherwise.
    """

    def __init__(self, key_size, dict_size, device=None):
        super().__init__()
        self.dict_size = dict_size
        self.key_size = key_size
        self.fill_count = 0      # number of slots written so far
        self.stale_ind = None    # slot to overwrite once the store is full

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.keys = Variable(
            torch.zeros(self.dict_size, self.key_size, device=self.device))
        self.values = Variable(
            torch.zeros(self.dict_size, 1, device=self.device))
        self.recency_map = Variable(
            torch.zeros(self.dict_size, 1, device=self.device))

    def forward(self, _input):
        return

    def get_knn(self, query, k):
        """
        get K nearest neighbors by cosine distance for query

        Args:
            query: (n_queries, key_size) tensor.
            k: number of neighbours; must not exceed the number of filled
               slots.
        Returns:
            (similarities, indices), each (n_queries, k); indices are slot
            numbers into `keys` / `values`.
        """
        if self.fill_count < self.dict_size:
            # Only the slots written so far take part in the search.
            norm_keys = F.normalize(self.keys[:self.fill_count, :], dim=1)
        else:
            norm_keys = F.normalize(self.keys, dim=1)

        norm_query = F.normalize(query, dim=1)
        # Rows are queries and columns are keys, so top-k along dim 1 yields
        # key slots for every query.
        cosine_sim = torch.mm(norm_query, norm_keys.t())
        return torch.topk(cosine_sim, k, dim=1)

    def add_key(self, key, value):
        """Store (key, value) in the next free slot, or in the stale slot
        once all slots are used."""
        if self.fill_count >= self.dict_size:
            slot = self._stale_slot()
        else:
            slot = self.fill_count
            self.fill_count += 1
        self.keys[slot] = key
        self.values[slot] = value

    def _stale_slot(self):
        """Return the slot to overwrite, deriving it from the recency map
        when `update_recency_map` has not recorded one yet (indexing with
        None would overwrite every row)."""
        if self.stale_ind is None:
            _, self.stale_ind = self.recency_map.min(0)
        return self.stale_ind

    def update_recency_map(self, nn_indices):
        """Add one per occurrence of each slot in `nn_indices`, subtract one
        from every slot, clamp to [0, 100] and record the minimum slot."""
        # scatter_add_ needs index and source with the same rank as the
        # (dict_size, 1) recency map.
        index = nn_indices.reshape(-1, 1).long().to(self.recency_map.device)
        hits = torch.ones(index.size(0), 1,
                          dtype=self.recency_map.dtype,
                          device=self.recency_map.device)
        self.recency_map.scatter_add_(0, index, hits)
        self.recency_map.add_(-1).clamp_(0, 100)
        _, self.stale_ind = self.recency_map.min(0)
예제 #16
0
def gumbel_sample(input, temperature=1.0, avg=False, N=10000):
    """Perturb `input` with Gumbel noise and return a softmax over it.

    Noise is drawn as ``-log(-log(U + 1e-9) + 1e-9)`` with ``U`` uniform on
    [0, 1). The perturbed values are divided by `temperature`, flattened to
    ``(input.size(0), -1)``, passed through a softmax over dim 1 and
    reshaped like `input`.

    Args:
        input: tensor to perturb.
        temperature: divisor applied before the softmax.
        avg: if True, use the averaged estimator described in
            https://arxiv.org/abs/1706.04161: `N` noise rows are drawn,
            Euler's constant is subtracted, and the perturbed rows are
            averaged.
        N: number of noise rows used when `avg` is True.
    """

    def _noise(shape):
        # Two rounds of add-eps / log / negate turn uniform samples into
        # Gumbel samples, in place.
        sample = to_gpu(torch.rand(shape))
        sample.add_(1e-9).log_().neg_()
        sample.add_(1e-9).log_().neg_()
        return sample

    if avg:
        noise = _noise([input.size()[-1] * N])
        noise.add_(-EULER)
        noise = Variable(noise.view(N, input.size(-1)))
        x = torch.mean(input.expand_as(noise) + noise, 0) / temperature
    else:
        noise = Variable(_noise(input.size()))
        x = (input + noise) / temperature

    # The view is always 2-D; naming dim=1 matches the implicit choice for
    # 2-D inputs and avoids the implicit-dimension deprecation warning.
    x = F.softmax(x.view(input.size(0), -1), dim=1)
    return x.view_as(input)
예제 #17
0
    def backward(ctx, grad_output):
        """
        X : Distance between similar samples
        Y : Distance between dissimilar samples

        Gradients:
           dLoss/dX = 1 + 1 / (exp(-X) + exp(-Y)) * -1 * exp(-X)
                    = 1 - exp(-X) / (exp(-X) + exp(-Y))

           dLoss/dY = 0 + 1 / (exp(-X) + exp(-Y)) * -1 * exp(-Y)
                    = -exp(-Y) / (exp(-X) + exp(-Y))

        Returns the two gradients above scaled by `grad_output`, followed
        by two `None` placeholders.
        """
        # The third saved value (labels) is not used by the gradient.
        input1, input2, _ = ctx.saved_variables

        # Computed out of place: no zero buffers, no input clones and no
        # deprecated `add_(scalar, tensor)` overload.
        exp_neg1 = torch.exp(-input1)
        exp_neg2 = torch.exp(-input2)
        dist = exp_neg1 + exp_neg2

        grad_input1 = 1 - exp_neg1 / dist
        grad_input2 = -exp_neg2 / dist

        return grad_input1 * grad_output, grad_input2 * grad_output, None, None
예제 #18
0
def smoothness(grid):
    """
    Given a variable of dimensions (N, X, Y, [Z], C), sums the Euclidean
    distances (over C) between every grid point and its neighbour in each
    direction along X, Y and (optionally) Z. At a border the point is
    compared with itself. Returns a tensor of dimension N.
    """
    rank = grid.dim()
    total = Variable(
        torch.zeros(grid.size(0), dtype=grid.data.dtype,
                    device=grid.data.device))

    def take(axis, window):
        # Slice `grid` along one axis, keeping every other axis whole.
        index = [slice(None)] * rank
        index[axis] = window
        return grid[tuple(index)]

    # Axis 0 is the batch and the last axis the channels; everything in
    # between is spatial.
    for axis in range(1, rank - 1):
        # Shift by one step, repeating the border element.
        left = torch.cat(
            [take(axis, slice(1, None)), take(axis, slice(-1, None))], axis)
        right = torch.cat(
            [take(axis, slice(None, 1)), take(axis, slice(None, -1))], axis)

        for neighbour in (left, right):
            # 1e-10 keeps the square root away from exactly zero.
            step = ((neighbour - grid).pow(2).sum(-1) + 1e-10).pow(0.5)
            total.add_(step.sum(tuple(range(1, step.dim()))))

    return total
예제 #19
0
def add_jitter(mat):
    """
    Adds "jitter" (1e-3) to the diagonal of a matrix.
    This ensures that a matrix that *should* be positive definite *is*
    positive definite.

    Args:
        - mat (matrix nxn) - LazyVariable, Variable or tensor
    Returns: (matrix nxn) `mat` with 1e-3 added to the diagonal; for a
        LazyVariable, whatever its own `add_jitter()` returns.
    """
    if isinstance(mat, LazyVariable):
        return mat.add_jitter()

    if not isinstance(mat, Variable):
        # Plain tensor: accumulate into the freshly built jitter matrix.
        jitter = mat.new(len(mat)).fill_(1e-3).diag()
        return jitter.add_(mat)

    jitter = Variable(mat.data.new(len(mat)).fill_(1e-3).diag())
    return mat + jitter
예제 #20
0
# Training script: learns a 4x2 weight matrix `w` that maps a 4-dimensional
# state to 2 action logits, using rewards returned by `step_cartpole`.
# `gumbel_softmax` and `step_cartpole` are defined elsewhere.
w = torch.FloatTensor(4, 2)
init.xavier_normal(w)
w = Variable(w, requires_grad=True)

optimizer = optim.Adam([w], lr=1e-1)

BATCH_SIZE = 30

# Reusable buffer for the initial states; refilled in place every episode.
st = torch.FloatTensor(BATCH_SIZE, 4)
# Gradient passed to `reward.backward`: -1 per batch element, so the
# optimizer step minimises the negated reward.
dreward = torch.Tensor([-1]).expand(BATCH_SIZE)
for ep in range(1000):
    st.uniform_(-0.05, 0.05)
    vst = Variable(st)
    reward = Variable(torch.zeros(BATCH_SIZE))
    multiplier = 1  # discount factor, decayed by 0.99 per step
    notdone = Variable(torch.ones(BATCH_SIZE))
    for i in range(800):
        logits = vst @ w
        y = gumbel_softmax(logits, tau=1, hard=True)
        vst, r, d = step_cartpole(vst, y)
        # Reward counts only for elements not done before or at this step.
        tester = notdone * (1 - d) * r * multiplier
        reward.add_(tester)
        # NOTE(review): `notdone` is rebuilt from the current `d` only; if
        # `d` can return to 0 after being set, finished elements would be
        # counted again - confirm against step_cartpole.
        notdone = (d == 0).float()
        multiplier *= 0.99
        if (d.data > 0).all(): break
    if ep % 5 == 0:
        print("Num steps", i)
        # NOTE(review): gradients are cleared only every 5th episode, so
        # they accumulate across the episodes in between - confirm intended.
        optimizer.zero_grad()
    reward.backward(dreward)
    optimizer.step()
예제 #21
0
def ais_trajectory(model,
                   loader,
                   mode='forward',
                   schedule=np.linspace(0., 1., 500),
                   n_sample=100):
    """Compute annealed importance sampling trajectories for a batch of data.
    Could be used for *both* forward and reverse chain in bidirectional Monte Carlo
    (default: forward chain with linear schedule).

    For every batch, `n_sample` parallel chains per datapoint are moved
    through the temperatures in `schedule`: the log importance weight is
    updated with the change in log density between consecutive
    temperatures, then the chain state is updated by `hmc_trajectory`
    followed by `accept_reject`.

    Args:
        model (vae.VAE): VAE model; `z_size`, `dtype`, `decode` and `eval`
            are used
        loader (iterator): iterator that returns pairs, with first component being `x`,
            second would be `z` or label (used as the initial sample only in
            `backward` mode)
        mode (string): indicate forward/backward chain; must be either `forward` or
            `backward`
        schedule (list or 1D np.ndarray): temperature schedule,
            i.e. `p(z)p(x|z)^t`; forward chain has increasing values, whereas
            backward has decreasing values. The default array is created
            once, at definition time.
        n_sample (int): number of importance samples (i.e. number of parallel chains
            for each datapoint)

    Returns:
        A list with one entry per batch: the `.data` of the log-mean-exp of
        the log importance weights over the `n_sample` chains (negated in
        `backward` mode).
    """

    assert mode == 'forward' or mode == 'backward', 'Should have forward/backward mode'

    # `B`, `z_size` and `mdtype` used below are closure variables assigned
    # later in this function, before the first call.
    def log_f_i(z, data, t, log_likelihood_fn=log_bernoulli):
        """Unnormalized density for intermediate distribution `f_i`:
            f_i = p(z)^(1-t) p(x,z)^(t) = p(z) p(x|z)^t
        =>  log f_i = log p(z) + t * log p(x|z)
        """
        zeros = Variable(torch.zeros(B, z_size).type(mdtype))
        log_prior = log_normal(z, zeros, zeros)
        log_likelihood = log_likelihood_fn(model.decode(z), data)

        # The likelihood term is scaled in place.
        return log_prior + log_likelihood.mul_(t)

    model.eval()

    # shorter aliases
    z_size = model.z_size
    mdtype = model.dtype

    _time = time.time()
    logws = []  # for output

    print('In %s mode' % mode)

    for i, (batch, post_z) in enumerate(loader):

        # total number of chains: one per (datapoint, sample) pair
        B = batch.size(0) * n_sample
        batch = Variable(batch.type(mdtype))
        batch = safe_repeat(batch, n_sample)

        # batch of step sizes, one for each chain
        epsilon = Variable(torch.ones(B).type(model.dtype)).mul_(0.01)
        # accept/reject history for tuning step size
        accept_hist = Variable(torch.zeros(B).type(model.dtype))
        # record log importance weight; volatile=True reduces memory greatly
        logw = Variable(torch.zeros(B).type(mdtype), volatile=True)

        # initial sample of z: standard normal draws in forward mode, the
        # loader-provided `post_z` (repeated per sample) in backward mode
        if mode == 'forward':
            current_z = Variable(torch.randn(B, z_size).type(mdtype),
                                 requires_grad=True)
        else:
            current_z = Variable(safe_repeat(post_z, n_sample).type(mdtype),
                                 requires_grad=True)

        # j counts transitions starting from 1
        for j, (t0, t1) in tqdm(enumerate(zip(schedule[:-1], schedule[1:]),
                                          1)):
            # update log importance weight
            log_int_1 = log_f_i(current_z, batch, t0)
            log_int_2 = log_f_i(current_z, batch, t1)
            logw.add_(log_int_2 - log_int_1)

            # resample speed
            current_v = Variable(torch.randn(current_z.size()).type(mdtype))

            # potential energy at the new temperature
            def U(z):
                return -log_f_i(z, batch, t1)

            def grad_U(z):
                # grad w.r.t. outputs; mandatory in this case
                grad_outputs = torch.ones(B).type(mdtype)
                # torch.autograd.grad default returns volatile
                grad = torchgrad(U(z), z, grad_outputs=grad_outputs)[0]
                # avoid humongous gradients
                grad = torch.clamp(grad, -10000, 10000)
                # needs variable wrapper to make differentiable
                grad = Variable(grad.data, requires_grad=True)
                return grad

            def normalized_kinetic(v):
                zeros = Variable(torch.zeros(B, z_size).type(mdtype))
                # this is superior to the unnormalized version
                return -log_normal(v, zeros, zeros)

            z, v = hmc_trajectory(current_z, current_v, U, grad_U, epsilon)

            # accept-reject step
            current_z, epsilon, accept_hist = accept_reject(
                current_z,
                current_v,
                z,
                v,
                epsilon,
                accept_hist,
                j,
                U,
                K=normalized_kinetic)

        # IWAE lower bound
        logw = log_mean_exp(logw.view(n_sample, -1).transpose(0, 1))
        if mode == 'backward':
            logw = -logw
        logws.append(logw.data)

        print ('Time elapse %.4f, last batch stats %.4f' % \
            (time.time()-_time, logw.mean().cpu().data.numpy()))

        _time = time.time()
        sys.stdout.flush()  # for debugging

    return logws
예제 #22
0
class MANNCell(nn.Module):
    """GRU-style recurrent cell with two external memories.

    One memory stores past inputs (`imem`), the other past hidden states
    (`hmem`). At each step an auxiliary cell produces two read heads, one
    entry is read from each memory, and the entries are concatenated to the
    input and the hidden state before the gate computation.

    `_reset_mem` must have been called before `forward`, since it creates
    the memory tensors that `forward` reads.
    """

    def __init__(self, options, input_size, hidden_size, use_bias=True):
        """Build the parameters and sub-networks.

        `options` is read for the keys 'time_fac', 'mem_cap', 'head_size',
        'mode', and the batch-size key that matches the mode.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.time_fac = options['time_fac']
        # The weights take the input / hidden state concatenated with the
        # entry read from memory (hence the factor 2) and produce the three
        # gate pre-activations (hence the factor 3).
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(2 * input_size, 3 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(2 * hidden_size, 3 * hidden_size))
        self.bias_ih = nn.Parameter(torch.FloatTensor(1, 3 * hidden_size))
        self.bias_hh = nn.Parameter(torch.FloatTensor(1, 3 * hidden_size))
        self.memcnt = 0  # number of memory slots written so far
        self.memcap = options['mem_cap']
        self.head_size = options['head_size']
        mode = options['mode']
        # NOTE(review): any other mode leaves `batch_size` unbound and the
        # assignment below raises - confirm the set of valid modes.
        if mode == 'train':
            batch_size = options['batch_size']
        elif mode == 'val':
            batch_size = options['eval_batch_size']
        elif mode == 'test':
            batch_size = options['test_batch_size']
        self.batch_size = batch_size
        # Auxiliary cell emitting both read heads (head_size + 1 each).
        self.auxcell = GRUCell(input_size + hidden_size,
                               2 * (self.head_size + 1))
        self.tau = 1.
        # Networks mapping stored inputs / hidden states to head vectors.
        self.i_fc = nn.Sequential(
            nn.Linear(input_size, self.head_size // 2), nn.ReLU(),
            nn.Linear(self.head_size // 2, self.head_size), nn.Sigmoid())
        self.h_fc = nn.Sequential(
            nn.Linear(hidden_size, self.head_size // 2), nn.ReLU(),
            nn.Linear(self.head_size // 2, self.head_size), nn.Sigmoid())

        # Not read anywhere else in this class.
        self.last_usage = None
        self.mem = None

        self.reset_parameters()

    def _reset_mem(self):
        """Clear both memories and their last-use counters (allocated on CUDA)."""
        self.memcnt = 0
        self.imem = Variable(torch.zeros(self.batch_size, self.memcap,
                                         self.input_size),
                             requires_grad=True).cuda()
        self.hmem = Variable(torch.zeros(self.batch_size, self.memcap,
                                         self.hidden_size),
                             requires_grad=True).cuda()
        # Start from a very negative value for every slot.
        self.i_last_use = Variable(
            torch.ones(self.batch_size, self.memcap) * -9999999.).cuda()
        self.h_last_use = Variable(
            torch.ones(self.batch_size, self.memcap) * -9999999.).cuda()

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def set_tau(self, num):
        """Set the factor applied inside the Gumbel sampling functions."""
        self.tau = num

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.

        Every parameter whose name contains 'weight' is initialised
        orthogonally and every one whose name contains 'bias' is set to 0.
        """
        # NOTE(review): `stdv` is computed but never used.
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for n, p in self.named_parameters():
            if 'weight' in n:
                nn.init.orthogonal_(p)
            if 'bias' in n:
                nn.init.constant_(p.data, val=0)

    def forward(self, input_, h_0, aux_h_0):
        """Run one step.

        Returns the new hidden state `h_1` and the auxiliary cell's output
        `read_head`. Also updates the last-use counters and writes
        `input_` / `h_0` into the memories.
        """

        i = input_
        h = h_0.detach()
        # Both read heads come from the auxiliary cell; the hidden state is
        # detached for this path.
        read_head = self.auxcell(torch.cat([i, h], dim=1), aux_h_0)
        i_read_head, h_read_head = torch.split(read_head,
                                               self.head_size + 1,
                                               dim=1)
        # Head vector of each slot: features of the stored entry plus one
        # extra entry derived from the slot's last-use counter.
        i_head_vecs = torch.cat([
            self.i_fc(self.imem.detach()), self.time_fac *
            torch.sigmoid(self.i_last_use).detach().unsqueeze(2)
        ],
                                dim=2)
        h_head_vecs = torch.cat([
            self.h_fc(self.hmem.detach()), self.time_fac *
            torch.sigmoid(self.h_last_use).detach().unsqueeze(2)
        ],
                                dim=2)
        # Score per slot: inverse Euclidean distance between the read head
        # and the slot's head vector (1e-6 avoids division by zero).
        i_read_head = 1 / torch.sqrt(
            (1e-6 + (i_read_head.unsqueeze(1) - i_head_vecs)**2).sum(dim=2))
        h_read_head = 1 / torch.sqrt(
            (1e-6 + (h_read_head.unsqueeze(1) - h_head_vecs)**2).sum(dim=2))
        i_entry, i_read_index, h_entry, h_read_index = self.read(
            i_read_head, h_read_head, self.tau)
        # Decrement every counter by one, then zero the counter of the slot
        # that was just read (the second add_ sees the decremented value).
        self.i_last_use.add_(-1).add_(-self.i_last_use * i_read_index)
        self.h_last_use.add_(-1).add_(-self.h_last_use * h_read_index)

        # GRU-style gates on the input / hidden state extended with the
        # entries read from memory.
        new_i = torch.cat([input_, i_entry], dim=1)
        new_h0 = torch.cat([h_0, h_entry], dim=1)
        wi_b = torch.addmm(self.bias_ih, new_i, self.weight_ih)
        wh_b = torch.addmm(self.bias_hh, new_h0, self.weight_hh)
        ri, zi, ni = torch.split(wi_b, self.hidden_size, dim=1)
        rh, zh, nh = torch.split(wh_b, self.hidden_size, dim=1)
        r = torch.sigmoid(ri + rh)
        z = torch.sigmoid(zi + zh)
        n = torch.tanh(ni + r * nh)
        h_1 = (1 - z) * n + z * h_0

        if self.memcnt < self.memcap:
            # Memory not full yet: write to the next free slot, selected by
            # a one-hot row shared by the whole batch.
            h_write_index = i_write_index = Variable(
                torch.cat([
                    torch.zeros(self.memcnt),
                    torch.ones(1),
                    torch.zeros(self.memcap - 1 - self.memcnt)
                ]).unsqueeze(0)).cuda()
            self.memcnt += 1
        else:
            # Memory full: overwrite the slots that were just read.
            h_write_index = h_read_index
            i_write_index = i_read_index
        self.write(input_, i_write_index, h_0, h_write_index)

        return h_1, read_head

    def write(self, i, i_index, h, h_index):
        """Replace the slots selected by the index masks with `i` / `h`,
        keeping every other slot unchanged."""
        i_ones = i_index.unsqueeze(2)
        h_ones = h_index.unsqueeze(2)
        self.imem = i.unsqueeze(1) * i_ones + self.imem * (1. - i_ones)
        self.hmem = h.unsqueeze(1) * h_ones + self.hmem * (1. - h_ones)

    def read(self, i_read_head, h_read_head, tau):
        """Sample one slot per memory from the read scores and return the
        selected entries together with the selection masks."""
        i_index, _ = self.gumbel_softmax(i_read_head, tau)
        h_index, _ = self.gumbel_softmax(h_read_head, tau)
        i_entry = i_index.unsqueeze(2) * self.imem
        h_entry = h_index.unsqueeze(2) * self.hmem
        # Summing the masked memory over the slot dimension.
        i_entry = i_entry.sum(dim=1)
        h_entry = h_entry.sum(dim=1)
        return i_entry, i_index, h_entry, h_index

    def gumbel_softmax(self, input, tau):
        """Return a 0/1 mask of the row-wise maximum of a Gumbel-perturbed
        softmax, and the position of that maximum.

        The returned mask takes the hard value in the forward pass while
        gradients flow through the soft probabilities.
        """
        gumbel = Variable(-torch.log(
            1e-20 - torch.log(1e-20 + torch.rand(*input.shape)))).cuda()
        # NOTE(review): the perturbed logits are multiplied by `tau`, not
        # divided - confirm this is the intended role of `tau`.
        y = torch.nn.functional.softmax(
            (torch.log(1e-20 + input) + gumbel) * tau, dim=1)
        ymax, pos = y.max(dim=1)
        hard_y = torch.eq(y, ymax.unsqueeze(1)).float()
        # Forward value equals hard_y; the detached difference carries no
        # gradient, so the gradient is that of the soft y.
        y = (hard_y - y).detach() + y
        return y, pos

    def gumbel_sigmoid(self, input, tau):
        """Return the sigmoid of the Gumbel-perturbed input scaled by `tau`."""
        gumbel = Variable(-torch.log(
            1e-20 - torch.log(1e-20 + torch.rand(*input.shape)))).cuda()
        y = torch.sigmoid((input + gumbel) * tau)
        #hard_y=torch.eq(y,ymax.unsqueeze(1)).float()
        #y=(hard_y-y).detach()+y
        return y
예제 #23
0
    def stAdv_norm(self):
        """ Computes the norm used in
           "Spatially Transformed Adversarial Examples"

        For every example, the offset of the transformation grid from the
        identity grid is compared with the same offset shifted by one step
        up/down and left/right (borders are compared with themselves), and
        the per-pixel Euclidean differences are summed.

        Returns a 1-D Variable with one value per example.
        """

        # ONLY WORKS FOR SQUARE MATRICES
        tensor_type = self.xform_params.data.type()
        num_examples, height, width = tuple(self.xform_params.shape[0:3])
        assert height == width

        def shift_matrix(direction):
            # Matrix that shifts columns by one step when multiplied from the
            # right: torch.matmul(foo, shift)
            offset = -1 if direction == 'left' else 1
            idx = (torch.arange(width) - offset) % width
            idx = idx.type(tensor_type).type(torch.LongTensor)
            if self.xform_params.is_cuda:
                idx = idx.cuda()

            identity = torch.zeros(height, width).type(tensor_type)
            for k in range(height):
                identity[k, k] = 1

            shift = torch.zeros(height, width).index_copy_(
                1, idx.cpu(), identity.cpu())
            shift = shift.type(tensor_type)

            # Undo the wrap-around: the border column repeats itself.
            if direction == 'left':
                shift[-1][0] = 0
                shift[0][0] = 1
            else:
                shift[0][-1] = 0
                shift[-1][-1] = 1
            return Variable(shift)

        # Column shifts act from the right; their transposes act on rows
        # from the left.
        col_permuts = [shift_matrix(side) for side in ('left', 'right')]
        row_permuts = [permut.transpose(0, 1) for permut in col_permuts]

        # Offsets from the identity grid, arranged as (N, C, H, W).
        id_params = Variable(self.identity_params(self.img_shape))
        delta_grids = (self.xform_params - id_params).permute(0, 3, 1, 2)

        output = Variable(torch.zeros(num_examples).type(tensor_type))

        def accumulate(shifted):
            # Per-pixel Euclidean distance over the channel dimension,
            # summed over all pixels of each example.
            diff = delta_grids - shifted
            per_pixel = (diff.pow(2).sum(1) + 1e-10).pow(0.5)
            output.add_(per_pixel.sum((1, 2)))

        for permute in row_permuts:
            accumulate(torch.matmul(permute, delta_grids))
        for permute in col_permuts:
            accumulate(torch.matmul(delta_grids, permute))

        return output
예제 #24
0
#To compute all the gradients after performing all the forward propagation, .backward() can be used to automatically compute the gradients
#autograd.Variable
# extracting the tensor value form the variable - autograd.Variable.data
# to extract the gradients, .grad will of the help

#--------------------------------
# Functions in autograd -
#Variables and functions do make up the acyclic graphs..
# Every variable that was created due to some tranformations and operations can refer to the function that created that variable using .grad_fn()
import numpy
import torch
from torch.autograd import Variable

x = Variable(torch.ones(2,2))
# A Variable can also be created directly from a numpy array.
x = Variable(torch.Tensor(numpy.linspace(1,10,30)), requires_grad = True)
# Perform an operation on the variable. The add is out-of-place on purpose:
# autograd refuses in-place modification of a leaf that requires grad.
y = x + 2
# y is the result of an operation, so its grad_fn points at the function
# that created it (a leaf such as x has grad_fn None).
print(y.grad_fn)
# Some more computations
z = y*y*3
out = z.mean()
print(z, out)
# gradients in autograd
out.backward() # differentiates "out" with respect to x (the leaf the computation started from)
print(out)
#----------------------

x = torch.randn(3)
x = Variable(x, requires_grad = True)
y = x*2
while(y.data.norm() < 1000):
예제 #25
0
def _normalization_stats(batch_size, height, width):
    """Build the per-channel mean and std tensors used to (de)normalize images.

    Returns a ``(mean, std)`` pair of ``(batch_size, 3, height, width)`` float
    tensors in which every pixel of channel ``c`` holds that channel's
    constant.
    """
    mean = torch.zeros(batch_size, 3, height, width)
    std = torch.zeros(batch_size, 3, height, width)
    # (mean, std) per channel, in channel order 0, 1, 2.
    channel_constants = ((0.485, 0.229), (0.456, 0.224), (0.406, 0.225))
    for channel, (channel_mean, channel_std) in enumerate(channel_constants):
        mean[:, channel, :, :] = channel_mean
        std[:, channel, :, :] = channel_std
    return mean, std


def _first_image_as_uint8(batch, mean, std):
    """De-normalize ``batch`` and return its first image as an HWC uint8 array."""
    pixels = torch.clamp(batch.mul(std).add(mean).mul(255.0),
                         min=0.,
                         max=255.0)
    return pixels.byte()[0].data.cpu().numpy().transpose(1, 2, 0)


def main():
    """Run the super-resolution script in testing or training mode.

    Command-line options are read from the module-level ``parser``. With
    ``opt.testing`` set, ``opt.test_image`` is loaded, passed through the
    model and the result is written to ``testing-sr.jpg``. Otherwise the model
    is trained on ``opt.image_dir`` with an MSE loss, plus (unless
    ``opt.pretraining``) an MSE loss on truncated VGG19 features scaled by
    ``opt.percep_scale``. Every 400 iterations the running loss is printed, a
    checkpoint is saved and ``input.jpg`` / ``gt.jpg`` / ``sr.jpg`` previews
    are written.

    Rebinds the globals ``opt`` and ``model`` (and ``HEIGHT`` / ``WIDTH`` in
    testing mode).
    """
    global opt, model, HEIGHT, WIDTH, SCALE
    opt = parser.parse_args()
    print(opt)
    test_image = None
    if opt.testing:
        opt.batchSize = 1
        img = imread(opt.test_image)
        HEIGHT, WIDTH = img.shape[0], img.shape[1]
        test_image = Image.fromarray(np.uint8(img))
        test_image = np.asarray(test_image)

        if test_image.ndim == 3:
            # Drop any channels beyond the first three (e.g. alpha).
            if test_image.shape[2] != 3:
                test_image = test_image[:, :, 0:3]

            # HWC uint8 -> 1x3xHxW float in [0, 1].
            test_image = torch.ByteTensor(
                torch.ByteStorage.from_buffer(
                    test_image.transpose(2, 0,
                                         1).tobytes())).float().div(255).view(
                                             -1, 3, HEIGHT, WIDTH)
        else:
            print('not good... we do not upscale non color images yet')
            return

    cuda = opt.cuda
    if cuda and not torch.cuda.is_available():
        raise Exception('No GPU found, please run without --cuda')

    opt.seed = random.randint(1, 10000)
    print('Random Seed: ', opt.seed)
    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    model = SRResNet()

    # mean/std match the model output resolution, tmean/tstd the model input
    # resolution.
    if opt.testing:
        model.eval()
        mean, std = _normalization_stats(opt.batchSize, HEIGHT * SCALE,
                                         WIDTH * SCALE)
        tmean, tstd = _normalization_stats(opt.batchSize, HEIGHT, WIDTH)
    else:
        model.train()
        mean, std = _normalization_stats(opt.batchSize, HEIGHT, WIDTH)
        tmean, tstd = _normalization_stats(opt.batchSize, HEIGHT // SCALE,
                                           WIDTH // SCALE)

    use_perceptual = not opt.pretraining and not opt.testing
    if use_perceptual:
        percep_model = models.__dict__['vgg19'](pretrained=True)
        percep_model.features = nn.Sequential(
            *list(percep_model.features.children())[:-14])
        percep_model.eval()

    criterion = nn.MSELoss(size_average=False)
    lr = opt.lr

    # Wrap the statistics on both code paths so CPU runs combine them with
    # Variables exactly like CUDA runs do.
    mean = Variable(mean)
    std = Variable(std)
    tmean = Variable(tmean)
    tstd = Variable(tstd)

    if cuda:
        model = torch.nn.DataParallel(model).cuda()
        criterion = criterion.cuda()
        if use_perceptual:
            percep_model = percep_model.cuda()
        mean = mean.cuda()
        std = std.cuda()
        tmean = tmean.cuda()
        tstd = tstd.cuda()

    if opt.pretrained:
        if os.path.isfile(opt.pretrained):
            print('=> loading model {}'.format(opt.pretrained))
            weights = torch.load(opt.pretrained)
            model.load_state_dict(weights['model'].state_dict())
        else:
            print('=> no model found at {}'.format(opt.pretrained))

    if opt.testing:
        test_image = Variable(test_image)
        if cuda:
            test_image = test_image.cuda()

        test_image = test_image.sub(tmean).div(tstd)
        gen = model(test_image)
        gened = Image.fromarray(_first_image_as_uint8(gen, mean, std))
        gened.save('testing-sr.jpg')

    else:
        train_set = dataprovider.DatasetFromDir(opt.image_dir,
                                                samples=opt.images,
                                                width=WIDTH,
                                                height=HEIGHT)

        training_data_loader = DataLoader(dataset=train_set,
                                          num_workers=opt.threads,
                                          batch_size=opt.batchSize,
                                          shuffle=True)

        optimizer = optim.Adam(model.parameters(), lr=lr)

        def new_loss_sum():
            # Fresh zero accumulator on the device the model runs on.
            accumulator = Variable(torch.zeros(1), requires_grad=False)
            if cuda:
                accumulator = accumulator.cuda()
            return accumulator

        counter = 0
        for epoch in range(opt.nEpochs):

            loss_sum = new_loss_sum()

            for iteration, batch in enumerate(training_data_loader, 1):
                counter = counter + 1
                inputs, target = (Variable(batch[0]),
                                  Variable(batch[1], requires_grad=False))

                if cuda:
                    inputs = inputs.cuda()
                    target = target.cuda()

                inputs = inputs.sub(tmean).div(tstd)
                target = target.sub(mean).div(std)

                gen = model(inputs)
                optimizer.zero_grad()
                loss = criterion(gen, target)

                if not opt.pretraining:
                    out_percep = percep_model.features(gen)
                    # Re-wrap the target features so no gradient flows
                    # through the target branch.
                    out_percep_real = Variable(
                        percep_model.features(target).data,
                        requires_grad=False)
                    percep_loss = criterion(out_percep, out_percep_real)

                    loss = loss.add(percep_loss.mul(opt.percep_scale))

                loss.backward()
                nn.utils.clip_grad_norm(model.parameters(), opt.clip)
                # Accumulate a detached copy: adding `loss` itself would
                # attach every iteration's autograd graph to loss_sum and
                # keep it alive until the next reset.
                loss_sum.add_(loss.detach())
                optimizer.step()

                if counter % 400 == 0:
                    print('sum_of_loss = {}'.format(loss_sum.data.select(0,
                                                                         0)))
                    loss_sum = new_loss_sum()

                    save_checkpoint(model, epoch)
                    inp = Image.fromarray(
                        _first_image_as_uint8(inputs, tmean, tstd))
                    lab = Image.fromarray(
                        _first_image_as_uint8(target, mean, std))
                    gened = Image.fromarray(
                        _first_image_as_uint8(gen, mean, std))
                    inp.save('input.jpg')
                    lab.save('gt.jpg')
                    gened.save('sr.jpg')
예제 #26
0
 def test_local_var_binary_methods(self):
     """In-place ``add_`` of two equal Vars returns their element-wise sum."""
     operand_values = [1, 2, 3, 4, 5]
     left = Var(torch.FloatTensor(operand_values))
     right = Var(torch.FloatTensor(operand_values))
     summed = left.add_(right)
     expected = Var(torch.FloatTensor([2, 4, 6, 8, 10]))
     assert torch.equal(summed, expected)
def _sampled_task_delta(model, x_train, y_train, ystatus_train, shots_n,
                        lr_inner, n_inner, reg_scale):
    """Sample one task and return ``model - adapted model`` per parameter.

    Draws ``shots_n`` training rows without replacement, builds the pairwise
    matrix ``R[i, j] = (y[j] >= y[i])`` for them, runs ``do_base_learning``
    and returns one Variable per parameter holding the difference between the
    current parameter and the one in the model ``do_base_learning`` returned.
    """
    ind = random.sample(range(x_train.shape[0]), shots_n)
    x_batch = x_train[ind, ]
    ystatus_batch = ystatus_train[ind, ]
    y_batch = y_train[ind, ]

    n_rows = y_batch.shape[0]
    R_matrix_batch = np.zeros([n_rows, n_rows], dtype=int)
    for row in range(n_rows):
        for col in range(n_rows):
            R_matrix_batch[row, col] = y_batch[col] >= y_batch[row]

    new_model = do_base_learning(model, x_batch, R_matrix_batch,
                                 ystatus_batch, lr_inner, n_inner, reg_scale)

    deltas = list()
    for p, new_p in zip(model.parameters(), new_model.parameters()):
        temp = Variable(torch.zeros(p.size()))
        temp.add_(p.data - new_p.data)
        deltas.append(temp)
    return deltas


def meta_learn(model, x_train, y_train, ystatus_train, x_val, y_val,
               ystatus_val, iterations, lr_inner, lr_outer, n_inner, batch_n,
               reg_scale, shots_n):
    """Run ``iterations`` outer-loop updates of ``model``.

    Each iteration averages the parameter differences of ``batch_n`` sampled
    tasks, adds that average to every parameter's ``.grad`` and lets an Adam
    optimizer (learning rate ``lr_outer``) take one step. ``do_base_eval`` is
    then called on the validation data and the iteration time is printed.

    Args:
        model: module whose parameters are meta-learned (updated in place).
        x_train, y_train, ystatus_train: training arrays, indexed by row.
        x_val, y_val, ystatus_val: validation data passed to ``do_base_eval``.
        iterations: number of outer-loop iterations.
        lr_inner, n_inner, reg_scale: forwarded to ``do_base_learning``.
        lr_outer: learning rate of the outer Adam optimizer.
        batch_n: number of tasks averaged per iteration (at least one task is
            always sampled).
        shots_n: number of rows sampled per task.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_outer)

    # NOTE(review): these three are never read in the visible code; kept in
    # case callers or a truncated tail of this function rely on them.
    train_metalosses = []
    test_metalosses = []

    inner_optimizer_state = None

    task_args = (model, x_train, y_train, ystatus_train, shots_n, lr_inner,
                 n_inner, reg_scale)

    for t in range(iterations):

        start = time.time()
        # Sum the parameter differences over a batch of tasks.
        diff = _sampled_task_delta(*task_args)
        for _ in range(batch_n - 1):
            diff_next = _sampled_task_delta(*task_args)
            diff = list(map(add, diff, diff_next))

        diff_ave = [x / batch_n for x in diff]

        # Feed the averaged difference to the optimizer as the gradient.
        for p, delta in zip(model.parameters(), diff_ave):
            if p.grad is None:
                p.grad = Variable(torch.zeros(p.size()))
            p.grad.data.add_(delta)

        # Update meta-parameters
        optimizer.step()
        optimizer.zero_grad()

        val_metaloss, val_cind = do_base_eval(model, x_val, y_val, ystatus_val)

        end = time.time()
        print("1 iteration time:", end - start)
        print('Iteration', t)
예제 #28
0
파일: structs.py 프로젝트: ulzee/mpgan
        def rec(node, rfunc):
            """Fold the subtree rooted at ``node`` bottom-up with ``rfunc``.

            The results of all children are summed into a zero vector of
            length ``node.hsize``; ``rfunc`` then combines the node's own
            ``h_v`` with that sum.
            """
            children_total = Variable(torch.zeros(node.hsize, )).to(Tree.device)
            for kid in node.children:
                kid_result = rec(kid, rfunc)
                children_total.add_(kid_result)

            combined = rfunc(node.h_v, children_total)
            return combined
예제 #29
0
    def test_local_var_binary_methods(self):
        """In-place ``add_`` of two equal Vars returns their element-wise sum."""
        operand_values = [1, 2, 3, 4, 5]
        left = Var(torch.FloatTensor(operand_values))
        right = Var(torch.FloatTensor(operand_values))
        summed = left.add_(right)
        expected = Var(torch.FloatTensor([2, 4, 6, 8, 10]))
        assert torch.equal(summed, expected)
예제 #30
0
def Train(path, batch, shuffle, hiddenSize, latentSize, symSize, boxSize,
          catSize):
    """Train ``myVAE`` with SGD while live-plotting the loss history.

    Args:
        path, batch, shuffle: forwarded to ``get_loader`` to build the data
            iterable.
        hiddenSize, latentSize, symSize, boxSize, catSize: forwarded to the
            ``myVAE`` constructor.

    Runs 1500 epochs. The per-sample loss is the sum of a KL term, a node
    classification cross-entropy (weight 0.2), a parameter MSE and a
    reconstruction MSE (weight 0.8). Gradients are accumulated across samples
    and the optimizer steps whenever the sample index is a non-zero multiple
    of 20; the step happens before the current sample's ``backward``, so that
    sample contributes to the following step. The model state is saved to
    ``VAE.pkl`` every 10th epoch (except epoch 0).
    """
    matplotlib.use('Qt5Agg')
    plt.ion()
    ax = plt.gca()
    ax.set_autoscale_on(True)
    hl, = plt.plot([], [])

    data = get_loader(path, batch, shuffle)

    VAE = myVAE(hiddenSize, latentSize, symSize, boxSize, catSize)
    optimization = torch.optim.SGD(VAE.parameters(), lr=0.2 / 20)

    histLoss = list()

    for i in range(1500):
        batchLoss = list()
        for j, d in enumerate(data):
            symshapes = Variable(d['symshapes'].float())
            treekids = d['treekids']
            symparams = Variable(d['symparams'])

            # Forward pass
            myOut, paramOut, paramGT, NClrOut, NClrGT, mu, logvar = VAE(
                symshapes, treekids, symparams)

            # KL loss: -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
            KLD_element = mu.pow(2).add_(
                logvar.exp()).mul_(-1).add_(1).add_(logvar)
            KLloss = torch.sum(KLD_element).mul(-0.5)

            # Node classification loss, rescaled by the number of targets
            finalNCl = torch.nn.functional.cross_entropy(
                NClrOut, NClrGT).mul_(0.2).mul_(NClrGT.size(0))

            # Symmetry parameter loss, rescaled by the number of targets
            finalPL = torch.nn.functional.mse_loss(paramOut, paramGT).mul_(
                paramGT.size(0))

            # Reconstruction loss against the transposed, squeezed input
            symshapes = torch.t(symshapes.squeeze())
            finalRL = torch.nn.functional.mse_loss(myOut,
                                                   symshapes).mul_(0.8).mul_(
                                                       myOut.size(0))

            # Final loss
            FinalLoss = Variable(torch.FloatTensor([0]))
            FinalLoss = FinalLoss.add_(KLloss).add_(finalNCl).add_(
                finalPL).add_(finalRL)
            batchLoss.append(FinalLoss.data.numpy())

            if (j % 20 == 0 and j != 0):
                optimization.step()
                optimization.zero_grad()

                # Average over the losses actually collected: the first
                # window of an epoch holds 21 samples (indices 0..20), later
                # ones 20, so a fixed divisor of 20 over-reports the first.
                tmp = sum(batchLoss) / len(batchLoss)
                histLoss.append(tmp)
                print('Epoch [%d/%d],  Iter [%d/%d], Loss: %.4f, ' %
                      (i, 1500, j, len(data), tmp))
                batchLoss = list()

            FinalLoss.backward()

        update_line(hl, ax, histLoss)

        if (i % 10 == 0 and i != 0):
            torch.save(VAE.state_dict(), 'VAE.pkl')
예제 #31
0
    def forward(self, inputs, gold=None, inference=None):
        """Score every candidate entity of every mention.

        Args:
            inputs: dict providing 'token_ids', 'token_mask', 'entity_ids',
                'entity_mask', 'p_e_m', 'p_e_ent_net', 'n_negs' and the
                's_ltoken_*', 's_mtoken_*', 's_rtoken_*' id/mask pairs.
                ``entity_ids`` is 2-D: (n_ments, n_cands).
            gold: gold candidate indices; ignored (reset to None) unless
                ``self.oracle`` is set.
            inference: 'LBP' or 'star'; when None, ``self.inference`` is used.

        Returns:
            Scores viewed as (n_ments, n_cands); the row of the padding
            entity is removed when ``self.use_pad_ent`` is set.

        Side effects: may zero ``self.ew_embs.data[0]``, and sets
        ``self._rel_ctx_ctx_weights`` (or ``self.layer_scores`` on the
        local-only path).

        NOTE(review): tensors are created with hard-coded ``.cuda()`` calls,
        so this presumably requires a GPU -- confirm.
        """
        token_ids = inputs['token_ids']
        token_mask = inputs['token_mask']
        entity_ids = inputs['entity_ids']
        entity_mask = inputs['entity_mask']
        p_e_m = inputs['p_e_m']
        # NOTE(review): p_e_ent_net is not used anywhere below.
        p_e_ent_net = inputs['p_e_ent_net']
        n_negs = inputs['n_negs']

        n_ments, n_cands = entity_ids.size()
        n_rels = self.n_rels

        if self.mode == 'ment-norm' and self.first_head_uniform:
            # Zero the first relation's embedding in place.
            self.ew_embs.data[0] = 0

        if not self.oracle:
            gold = None

        if self.use_local:
            local_ent_scores = super(MulRelRanker, self).forward(token_ids, token_mask,
                                                                 entity_ids, entity_mask,
                                                                 p_e_m=None)
            # presumably populated by the parent forward call above -- verify
            ent_vecs = self._entity_vecs
        else:
            ent_vecs = self.entity_embeddings(entity_ids)
            local_ent_scores = Variable(torch.zeros(n_ments, n_cands).cuda(), requires_grad=False)

        # compute context vectors: masked mean of the token embeddings for the
        # left context, the mention itself and the right context (1e-5 guards
        # against division by zero when a mask is all zeros)
        ltok_vecs = self.snd_word_embeddings(inputs['s_ltoken_ids']) * inputs['s_ltoken_mask'].view(n_ments, -1, 1)
        local_lctx_vecs = torch.sum(ltok_vecs, dim=1) / torch.sum(inputs['s_ltoken_mask'], dim=1, keepdim=True).add_(1e-5)
        rtok_vecs = self.snd_word_embeddings(inputs['s_rtoken_ids']) * inputs['s_rtoken_mask'].view(n_ments, -1, 1)
        local_rctx_vecs = torch.sum(rtok_vecs, dim=1) / torch.sum(inputs['s_rtoken_mask'], dim=1, keepdim=True).add_(1e-5)
        mtok_vecs = self.snd_word_embeddings(inputs['s_mtoken_ids']) * inputs['s_mtoken_mask'].view(n_ments, -1, 1)
        ment_vecs = torch.sum(mtok_vecs, dim=1) / torch.sum(inputs['s_mtoken_mask'], dim=1, keepdim=True).add_(1e-5)
        bow_ctx_vecs = torch.cat([local_lctx_vecs, ment_vecs, local_rctx_vecs], dim=1)

        if self.use_pad_ent:
            # Append one extra "padding" mention whose only unmasked candidate
            # is index 0; every per-mention tensor gets a matching extra row.
            ent_vecs = torch.cat([ent_vecs, self.pad_ent_emb.view(1, 1, -1).repeat(1, n_cands, 1)], dim=0)
            tmp = torch.zeros(1, n_cands)
            tmp[0, 0] = 1
            tmp = Variable(tmp.cuda())
            entity_mask = torch.cat([entity_mask, tmp], dim=0)
            p_e_m = torch.cat([p_e_m, tmp], dim=0)
            local_ent_scores = torch.cat([local_ent_scores,
                                          Variable(torch.zeros(1, n_cands).cuda(), requires_grad=False)],
                                         dim=0)
            n_ments += 1

            if self.oracle:
                tmp = Variable(torch.zeros(1, 1).cuda().long())
                gold = torch.cat([gold, tmp], dim=0)

        if self.use_local_only:
            # Skip the pairwise model: combine local scores with log p(e|m).
            inputs = torch.cat([local_ent_scores.view(n_ments * n_cands, -1),
                                torch.log(p_e_m + 1e-20).view(n_ments * n_cands, -1)], dim=1)
            scores = self.score_combine(inputs).view(n_ments, n_cands)
            if self.use_pad_ent:
                scores = scores[:-1]

            self.layer_scores = [scores] * self.n_layers
            return scores

        if n_ments == 1:
            # A single mention has no neighbours to exchange evidence with.
            ent_scores = local_ent_scores

        else:
            # distance - to consider only neighbor mentions
            # After the four in-place steps below (for max_dist >= 1), dist is
            # 1 for distinct mentions at most max_dist positions apart and 0
            # otherwise, including the diagonal.
            ment_pos = torch.arange(0, n_ments).long().cuda()
            dist = (ment_pos.view(n_ments, 1) - ment_pos.view(1, n_ments)).abs()
            dist.masked_fill_(dist == 1, -1)
            dist.masked_fill_((dist > 1) & (dist <= self.max_dist), -1)
            dist.masked_fill_(dist > self.max_dist, 0)
            dist.mul_(-1)

            if self.uniform_att:
                rel_ctx_ctx_scores = Variable(torch.zeros(n_rels, n_ments, n_ments).cuda())

            else:
                ctx_vecs = self.ctx_layer(bow_ctx_vecs)
                if self.use_pad_ent:
                    ctx_vecs = torch.cat([ctx_vecs, self.pad_ctx_vec], dim=0)

                m1_ctx_vecs, m2_ctx_vecs = ctx_vecs, ctx_vecs
                rel_ctx_vecs = m1_ctx_vecs.view(1, n_ments, -1) * self.ew_embs.view(n_rels, 1, -1)
                rel_ctx_ctx_scores = torch.matmul(rel_ctx_vecs, m2_ctx_vecs.view(1, n_ments, -1).permute(0, 2, 1))  # n_rels x n_ments x n_ments

                # push pairs outside the neighbourhood, and the diagonal, to
                # about -1e10 so the softmax below gives them ~0 weight
                rel_ctx_ctx_scores = rel_ctx_ctx_scores.add_((1 - Variable(dist.float().cuda())).mul_(-1e10))
                eye = Variable(torch.eye(n_ments).cuda()).view(1, n_ments, n_ments)
                rel_ctx_ctx_scores.add_(eye.mul_(-1e10))
                rel_ctx_ctx_scores.mul_(1 / np.sqrt(self.ew_hid_dims))  # scaling proposed by "attention is all you need"


            if self.mode == 'ment-norm':
                # normalize over mentions (last axis), then symmetrize
                rel_ctx_ctx_probs = F.softmax(rel_ctx_ctx_scores, dim=2)
                rel_ctx_ctx_weights = rel_ctx_ctx_probs + rel_ctx_ctx_probs.permute(0, 2, 1)
                self._rel_ctx_ctx_weights = rel_ctx_ctx_probs
            elif self.mode == 'rel-norm':
                # normalize over relations instead
                ctx_ctx_rel_scores = rel_ctx_ctx_scores.permute(1, 2, 0).contiguous()
                if not self.use_stargmax:
                    ctx_ctx_rel_probs = F.softmax(ctx_ctx_rel_scores, dim=2)
                else:
                    ctx_ctx_rel_probs = STArgmax.apply(ctx_ctx_rel_scores)
                self._rel_ctx_ctx_weights = ctx_ctx_rel_probs.permute(2, 0, 1).contiguous()

            # compute phi(ei, ej)
            if self.mode == 'ment-norm':
                # NOTE(review): the outer test duplicates the first inner one,
                # so the 'trans_e' and "unknown ent_ent_comp" branches are
                # unreachable; for any value other than 'bilinear',
                # rel_ent_ent_scores is never assigned and the permute below
                # raises NameError.
                if self.ent_ent_comp == 'bilinear':
                    if self.ent_ent_comp == 'bilinear':
                        rel_ent_vecs = ent_vecs.view(1, n_ments, n_cands, -1) * self.rel_embs.view(n_rels, 1, 1, -1)
                    elif self.ent_ent_comp == 'trans_e':
                        rel_ent_vecs = ent_vecs.view(1, n_ments, n_cands, -1) - self.rel_embs.view(n_rels, 1, 1, -1)
                    else:
                        raise Exception("unknown ent_ent_comp")

                    rel_ent_ent_scores = torch.matmul(rel_ent_vecs.view(n_rels, n_ments, 1, n_cands, -1),
                                                      ent_vecs[:, n_negs:, :].contiguous().view(1, 1, n_ments, n_cands - n_negs, -1).permute(0, 1, 2, 4, 3))
                    # n_rels x n_ments x n_ments x n_cands x (n_cands - n_negs)

                rel_ent_ent_scores = rel_ent_ent_scores.permute(0, 1, 3, 2, 4)  # n_rel x n_ments x n_cands x n_ments x (n_cands - n_negs)
                # masked-out candidates (mask == 0) are pushed to -1e10
                rel_ent_ent_scores = (rel_ent_ent_scores * entity_mask[:, n_negs:]).add_((entity_mask[:, n_negs:] - 1).mul_(1e10))
                rel_ent_ent_weighted_scores = rel_ent_ent_scores * rel_ctx_ctx_weights.view(n_rels, n_ments, 1, n_ments, 1)
                # average over relations
                ent_ent_scores = torch.sum(rel_ent_ent_weighted_scores, dim=0).mul(1. / n_rels)  # n_ments x n_cands x n_ments x (n_cands - n_negs)

            elif self.mode == 'rel-norm':
                # Everything after this raise is unreachable.
                raise Exception('not implement for multi-layer yet')

                # TODO: check it again
                rel_vecs = torch.matmul(ctx_ctx_rel_probs.view(n_ments, n_ments, 1, n_rels),
                                        self.rel_embs.view(1, 1, n_rels, -1))\
                           .view(n_ments, n_ments, -1)
                ent_rel_vecs = ent_vecs.view(n_ments, 1, n_cands, -1) * rel_vecs.view(n_ments, n_ments, 1, -1)  # n_ments x n_ments x n_cands x dims
                ent_ent_scores = torch.matmul(
                        ent_rel_vecs,
                        ent_vecs[:, n_negs:, :].contiguous().view(1, n_ments, n_cands - n_negs, -1).permute(0, 1, 3, 2))\
                .permute(0, 2, 1, 3)  # n_ments x n_cands x n_ments x (n_cands - n_negs)

            # NOTE(review): if gold is None and the effective inference mode
            # is neither 'LBP' nor 'star', ent_scores is never assigned and
            # the torch.cat after this block raises NameError.
            if gold is None:
                if inference == 'LBP' or (inference is None and self.inference == 'LBP'):
                    # loopy belief propagation; messages are kept in log space
                    prev_msgs = Variable(torch.zeros(n_ments, n_cands, n_ments).cuda())

                    # NOTE(review): msgs is only defined if n_loops >= 1.
                    for _ in range(self.n_loops):
                        mask = 1 - Variable(torch.eye(n_ments).cuda())
                        ent_ent_votes = ent_ent_scores + local_ent_scores[:, n_negs:] * 1 + \
                                        torch.sum(prev_msgs.view(1, n_ments, n_cands, n_ments) *
                                                  mask.view(n_ments, 1, 1, n_ments), dim=3)\
                                        .view(n_ments, 1, n_ments, n_cands)
                        msgs, _ = torch.max(ent_ent_votes, dim=3)
                        # msgs = utils.log_sum_exp(ent_ent_votes, dim=3)
                        # damped update: mix the new normalized message (weight
                        # self.df) with the previous one (weight 1 - self.df)
                        msgs = (F.softmax(msgs, dim=1).mul(self.df) +
                                prev_msgs.exp().mul(1 - self.df)).log()
                        prev_msgs = msgs

                    # compute marginal belief
                    mask = 1 - Variable(torch.eye(n_ments).cuda())
                    ent_scores = local_ent_scores * 1 + torch.sum(msgs * mask.view(n_ments, 1, n_ments), dim=2)
                    ent_scores = F.softmax(ent_scores, dim=1)

                elif inference == 'star' or (inference is None and self.inference == 'star'):
                    # comp is fixed to 'max', so only the last branch of the
                    # comp dispatch below runs
                    comp = 'max'
                    ent_weights = Variable(torch.ones(n_ments, n_cands - n_negs).cuda())
                    self_mask = Variable(1 - torch.eye(n_ments).cuda()).view(n_ments, 1, n_ments)

                    # in-place: ent_ent_scores itself is modified here
                    ent_ent_scores += local_ent_scores[:, n_negs:]
                    ent_ent_weighted_scores = ent_ent_scores * ent_weights
                    ent_ent_weighted_scores = (ent_ent_weighted_scores * entity_mask[:, n_negs:]).add_((entity_mask[:, n_negs:] - 1).mul_(1e10))

                    if comp == 'weighted_sum':
                        ent_ent_att_probs = F.softmax(ent_ent_weighted_scores, dim=3)  # n_ments x n_cands x n_ments x (n_cands - n_negs)
                        ent_ent_weighted_scores = ent_ent_weighted_scores * ent_ent_att_probs
                        ent_ment_scores = torch.sum(ent_ent_weighted_scores, dim=3)  # n_ments x n_cands x n_ments
                    elif comp == 'softmax':
                        ent_ment_scores = utils.log_sum_exp(ent_ent_weighted_scores, dim=3)
                    elif comp == 'max':
                        ent_ment_scores, _ = torch.max(ent_ent_weighted_scores, dim=3)  # n_ments x n_cands x n_ments

                    ent_ment_scores = (ent_ment_scores * self_mask).add_((self_mask - 1).mul_(1e10))  # set self-self scores to -inf

                    # compute entity scores, using hard attention on neighbour mentions (eq 4)
                    ent_top_ment_scores, _ = torch.topk(ent_ment_scores, dim=2, k=max(1, min(self.ent_top_n, n_ments - 2)))
                    ent_scores = local_ent_scores + torch.sum(ent_top_ment_scores, dim=2)
                    ent_scores = (ent_scores * entity_mask).add_((entity_mask - 1).mul_(1e10))

            else:
                # oracle: score each candidate against the gold candidates of
                # the other mentions only
                onehot_gold = Variable(torch.zeros(n_ments, n_cands).cuda()).scatter_(1, gold, 1)
                ent_scores = torch.sum(torch.sum(ent_ent_scores * onehot_gold, dim=3), dim=2)

        # combine with p_e_m
        p_e_m_mul = 1
        inputs = torch.cat([ent_scores.view(n_ments * n_cands, -1),
                            p_e_m_mul * torch.log(p_e_m + 1e-20).view(n_ments * n_cands, -1)], dim=1)
        scores = self.score_combine(inputs).view(n_ments, n_cands)
        # scores = ent_scores
        if self.use_pad_ent:
            # drop the row of the padding mention appended earlier
            scores = scores[:-1]
        return scores
예제 #32
0
    def test_local_add(self):
        """In-place ``add_`` of two equal Vars returns their element-wise sum."""
        operand_values = [1, 2, 3, 4, 5]
        left = Var(torch.FloatTensor(operand_values))
        right = Var(torch.FloatTensor(operand_values))
        summed = left.add_(right)
        expected = Var(torch.FloatTensor([2, 4, 6, 8, 10]))
        assert torch.equal(summed, expected)