def forward(self, var_input, var_pred, var_true=None):
    # MMD Loss
    if var_true is None:
        X = th.cat([var_input, var_pred], 0)
    else:
        X = th.cat([th.cat([var_input, var_pred], 1),
                    th.cat([var_input, var_true], 1)], 0)

    # dot product between all combinations of rows in 'X'
    XX = X.mm(X.t())
    # dot product of rows with themselves
    X2 = (X.mul(X)).sum(dim=1)

    # exponent entries of the RBF kernel (without the sigma) for each
    # combination of the rows in 'X'
    # -0.5 * (x^Tx - 2*x^Ty + y^Ty)
    exponent = XX.sub((X2.mul(0.5)).expand_as(XX)) - \
        (((X2.t()).mul(0.5)).expand_as(XX))

    if self.cuda:
        lossMMD = Variable(th.cuda.FloatTensor([0]))
    else:
        lossMMD = Variable(th.zeros(1))
    for i in range(len(self.bandwiths)):
        kernel_val = exponent.mul(1. / self.bandwiths[i]).exp()
        lossMMD.add_((self.S.mul(kernel_val)).sum())

    return lossMMD.sqrt()
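# A minimal, self-contained sketch (not the class above) of a biased RBF-kernel MMD
# between two samples with a single assumed bandwidth; it uses the same
# -0.5 * (x^T x - 2 x^T y + y^T y) expansion of the squared distances.
import torch

def rbf_mmd2(a, b, bandwidth=1.0):
    X = torch.cat([a, b], dim=0)                       # stack both samples
    XX = X @ X.t()                                     # all pairwise dot products
    X2 = (X * X).sum(dim=1)
    sq_dists = X2.unsqueeze(0) + X2.unsqueeze(1) - 2 * XX
    K = torch.exp(-sq_dists / (2 * bandwidth ** 2))    # RBF kernel matrix
    n = a.size(0)
    Kaa, Kbb, Kab = K[:n, :n], K[n:, n:], K[:n, n:]
    return Kaa.mean() + Kbb.mean() - 2 * Kab.mean()    # biased MMD^2 estimate

print(rbf_mmd2(torch.randn(100, 3), torch.randn(100, 3) + 1.0))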
def update(self):
    self.learning_rate *= 0.99
    assert len(self.b_logprobs) == len(self.b_rewards)
    rewards = self.discount_rewards(self.b_rewards)
    g = Variable(torch.zeros(1))
    for i, logprobs in enumerate(self.b_logprobs):
        action = self.b_actions[i]
        rews = torch.zeros(logprobs.size())
        rews[0][action] = rewards[i]
        rews = Variable(rews)
        # print(logprobs, rews)
        g.add_(torch.dot(logprobs, rews))
    # print(g)
    g.backward()
    for f in self.nn.parameters():
        f.data.add_(f.grad.data * self.learning_rate)
    self.nn.zero_grad()
    self.b_logprobs = []
    self.b_rewards = []
    self.b_actions = []
def forward(self, input, target):
    # truncate to the same size
    # input  (batch_size * (seq_length + 2) * (vocab_size + 1))
    # target (batch_size * (seq_length))
    batch_size, L, Mp1 = input.size(0), input.size(1), input.size(2)
    seq_length = target.size(1)
    cumreward = Variable(torch.FloatTensor(1).zero_(),
                         requires_grad=True).cuda()
    for tt in xrange(seq_length):
        # reward
        reward = self.reward_func(input, target, tt)
        cumreward.add_(reward)
    # num_samples
    self.num_samples = self.reward_func.num_samples(input, target)
    # normalizing_coeff
    self.normalizing_coeff = self.weight / (self.sizeAverage and self.num_samples or 1)
    # here there is a '-' because we minimize
    self.output = -cumreward * self.normalizing_coeff
    # cumreward
    return self.output, self.num_samples
def test_nn(self):
    n_samples = 100
    n_features = 3
    n_out = 1
    X = Variable(torch.randn(n_samples, n_features).type(FloatTensor),
                 requires_grad=False)
    y = Variable(torch.randn(n_samples, n_out).type(FloatTensor),
                 requires_grad=False)
    y.add_(X[:, 0] * 2)
    y.add_(X[:, 1] * 3)
    torch_nn = TorchNN(learning_rate=1e-3, n_hidden=4).fit(X, y, iters=100)
    self.assertGreater(torch_nn.loss_path[0], torch_nn.loss_path[-1])
def test_inplace(self):
    x = Variable(torch.ones(5, 5), requires_grad=True)
    y = Variable(torch.ones(5, 5) * 4, requires_grad=True)

    z = x * y
    q = z + y
    w = z * y
    z.add_(2)
    # Add doesn't need its inputs to do backward, so it shouldn't raise
    q.backward(torch.ones(5, 5), retain_variables=True)
    # Mul saves both inputs in forward, so it should raise
    self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))

    z = x * y
    q = z * y
    r = z + y
    w = z.add_(y)
    # w is the last expression, so this should succeed
    w.backward(torch.ones(5, 5), retain_variables=True)
    # r doesn't use the modified value in backward, so it should succeed
    r.backward(torch.ones(5, 5), retain_variables=True)
    # q uses dirty z, so it should raise
    self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

    x.grad.data.zero_()
    m = x / 2
    z = m + y / 8
    q = z * y
    r = z + y
    prev_version = z._version
    w = z.exp_()
    self.assertNotEqual(z._version, prev_version)
    r.backward(torch.ones(5, 5), retain_variables=True)
    self.assertEqual(x.grad.data, torch.ones(5, 5) / 2)
    w.backward(torch.ones(5, 5), retain_variables=True)
    self.assertEqual(x.grad.data, torch.Tensor(5, 5).fill_((1 + math.e) / 2))
    self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

    leaf = Variable(torch.ones(5, 5), requires_grad=True)
    x = leaf.clone()
    x.add_(10)
    self.assertEqual(x.data, torch.ones(5, 5) * 11)
    # x should be still usable
    y = x + 2
    y.backward(torch.ones(5, 5))
    self.assertEqual(leaf.grad.data, torch.ones(5, 5))
    z = x * y
    x.add_(2)
    self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
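# A minimal sketch of the same dirty-tensor rule on current PyTorch (plain tensors,
# no Variable wrapper); assumes behavior matches the legacy test above: mutating a
# tensor that an earlier op saved for backward makes that backward raise.
import torch

x = torch.ones(5, requires_grad=True)
y = torch.ones(5, requires_grad=True) * 4
z = x * y
q = z * y          # mul saves z for its backward
z.add_(2)          # in-place update bumps z's version counter
try:
    q.backward(torch.ones(5))
except RuntimeError as e:
    print('raised as expected:', type(e).__name__)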
def forward(self, pred, target):
    """Compute the loss model.

    :param pred: predicted Variable
    :param target: Target Variable
    :return: Loss
    """
    loss = Variable(th.FloatTensor([0]))
    for i in range(1, self.moments):
        mk_pred = th.mean(th.pow(pred, i), 0)
        mk_tar = th.mean(th.pow(target, i), 0)
        loss.add_(th.mean((mk_pred - mk_tar) ** 2))  # L2
    return loss
def rec(node, mfunc):
    msg = Variable(torch.zeros(node.hsize, )).to(Tree.device)
    for child in node.children:
        msg.add_(mfunc(node.h_v, child.h_v))
    if node.parent is not None:
        msg.add_(mfunc(node.h_v, node.parent.h_v))
    child_msgs = []
    for child in node.children:
        child_msgs.append(rec(child, mfunc))
    return {
        'msg': msg,
        'child_msgs': child_msgs,
    }
def add_jitter(mat):
    """
    Adds "jitter" to the diagonal of a matrix.
    This ensures that a matrix that *should* be positive definite *is* positive definite.

    Args:
        - mat (matrix nxn) - Positive definite matrix

    Returns: (matrix nxn)
    """
    if isinstance(mat, LazyVariable):
        return mat.add_jitter()
    elif isinstance(mat, Variable):
        diag = Variable(mat.data.new(mat.size(-1)).fill_(1e-3).diag())
        if mat.ndimension() == 3:
            return mat + diag.unsqueeze(0).expand(mat.size(0), mat.size(1), mat.size(2))
        else:
            return mat + diag
    else:
        diag = mat.new(mat.size(-1)).fill_(1e-3).diag()
        if mat.ndimension() == 3:
            return mat.add_(
                diag.unsqueeze(0).expand(mat.size(0), mat.size(1), mat.size(2)))
        else:
            return diag.add_(mat)
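# A small illustrative sketch of the jitter idea itself, on plain tensors: add a tiny
# multiple of the identity to the diagonal, broadcasting it over a batch in the 3-D case.
# The 1e-3 value mirrors the constant used above; tensor names here are illustrative only.
import torch

mat2d = torch.ones(3, 3)
jittered2d = mat2d + 1e-3 * torch.eye(3)

mat3d = torch.ones(5, 3, 3)                       # batch of 5 matrices
eye = torch.eye(3).unsqueeze(0).expand(5, 3, 3)   # identity broadcast over the batch
jittered3d = mat3d + 1e-3 * eye
print(jittered2d[0, 0].item(), jittered3d.shape)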
def forward(self, x=None, pos=None):
    """Compute positional embeddings.

    Args:
        x: Tensor of size [batch_size, max_len]

    Returns:
        emb: Tensor of size [batch_size, max_len, d_word_vec].
    """
    d_word_vec = self.hparams.d_word_vec
    if pos is not None:
        batch_size, max_len = pos.size()
        pos = Variable(pos)
    else:
        batch_size, max_len = x.size()
        pos = Variable(torch.arange(0, max_len))
    if self.hparams.cuda:
        pos = pos.cuda()
    if self.hparams.pos_emb_size is not None:
        pos = pos.add_(1).long().unsqueeze(0).expand_as(x).contiguous()
        emb = self.emb(pos)
    else:
        emb = pos.float().unsqueeze(-1) * self.freq.unsqueeze(0)
        sin = torch.sin(emb).mul_(self.emb_scale).unsqueeze(-1)
        cos = torch.cos(emb).mul_(self.emb_scale).unsqueeze(-1)
        # emb = pos.float().unsqueeze(-1) / self.freq.unsqueeze(0)
        # sin = torch.sin(emb).unsqueeze(-1)
        # cos = torch.cos(emb).unsqueeze(-1)
        emb = torch.cat([sin, cos], dim=-1).contiguous().view(max_len, d_word_vec)
        emb = emb.unsqueeze(0).expand(batch_size, -1, -1)
    return emb
def test_log_reg(self):
    n_samples = 100
    n_features = 3
    n_out = 1
    X = Variable(torch.randn(n_samples, n_features).type(FloatTensor),
                 requires_grad=False)
    y = Variable(torch.randn(n_samples, n_out).type(FloatTensor),
                 requires_grad=False)
    y.add_(X[:, 0] * 2)
    y.add_(X[:, 1] * 3)
    y = y.ge(0).type(FloatTensor)
    torch_log_reg = TorchLogreg(learning_rate=1e-2,
                                verbose=True).fit(X, y, iters=100)
    self.assertGreater(torch_log_reg.loss_path[0], torch_log_reg.loss_path[-1])
    self.check_convergence(torch_log_reg, ideal_W=[.4, 8, 0])
def test_torch_reg(self):
    n_samples = 100
    n_features = 3
    n_out = 1
    X = Variable(torch.randn(n_samples, n_features).type(FloatTensor),
                 requires_grad=False)
    y = Variable(torch.randn(n_samples, n_out).type(FloatTensor),
                 requires_grad=False)
    y.add_(X[:, 0] * 2)
    y.add_(X[:, 1] * 3)
    torch_reg = TorchReg(learning_rate=1e-3).fit(X, y, iters=100)
    coeffs = torch_reg.W.data.numpy()[:, 0]
    print(coeffs)
    self.assertGreater(coeffs[0], 1.8)
    self.assertGreater(coeffs[1], 2.8)
    self.check_convergence(torch_reg, ideal_W=[2, 3, 0])
def test_shared_storage(self):
    x = Variable(torch.ones(5, 5))
    y = x.t()
    z = x[1]
    self.assertRaises(RuntimeError, lambda: x.add_(2))
    self.assertRaises(RuntimeError, lambda: y.add_(2))
    self.assertRaises(RuntimeError, lambda: z.add_(2))
def test_remote_var_binary_methods(self):
    hook = TorchHook()
    local = hook.local_worker
    remote = VirtualWorker(hook, 0)
    local.add_worker(remote)
    x = Var(torch.FloatTensor([1, 2, 3, 4, 5])).send(remote)
    y = Var(torch.FloatTensor([1, 2, 3, 4, 5])).send(remote)
    assert torch.equal(x.add_(y).get(), Var(torch.FloatTensor([2, 4, 6, 8, 10])))
class NeuralDict(nn.Module):
    def __init__(self, key_size, dict_size):
        super().__init__()
        self.dict_size = dict_size
        self.key_size = key_size
        self.fill_count = 0
        self.stale_ind = None
        self.keys = Variable(torch.zeros(self.dict_size, self.key_size).cuda())
        self.values = Variable(torch.zeros(self.dict_size, 1).cuda())
        self.recency_map = Variable(torch.zeros(self.dict_size, 1).cuda())

    def forward(self, _input):
        return

    def get_knn(self, query, k):
        """Get the K nearest neighbors of `query` by cosine distance."""
        if self.fill_count < self.dict_size:
            norm_keys = F.normalize(self.keys[:self.fill_count, :], dim=1)
        else:
            norm_keys = F.normalize(self.keys, dim=1)
        norm_query = F.normalize(query, dim=1)
        cosine_dist = torch.mm(norm_keys, norm_query.t())
        return torch.topk(cosine_dist, k, dim=1)

    def add_key(self, key, value):
        if self.fill_count >= self.dict_size:
            self.keys[self.stale_ind] = key
            self.values[self.stale_ind] = value
        else:
            self.keys[self.fill_count] = key
            self.values[self.fill_count] = value
            self.fill_count += 1

    def update_recency_map(self, nn_indices):
        mask = Variable(torch.zeros(*nn_indices.size()).cuda().fill_(1))
        self.recency_map.scatter_add(0, nn_indices.view(-1), mask)
        self.recency_map.add_(-1).clamp_(0, 100)
        _, self.stale_ind = self.recency_map.min(0)
def gumbel_sample(input, temperature=1.0, avg=False, N=10000):
    # more accurate version of the gumbel estimator as described in
    # https://arxiv.org/abs/1706.04161
    # averages N gumbel distributions and subtracts out Euler's constant
    if avg:
        noise = to_gpu(torch.rand([input.size()[-1] * N]))
        noise.add_(1e-9).log_().neg_()
        noise.add_(1e-9).log_().neg_()
        noise.add_(-EULER)
        noise = Variable(noise.view(N, input.size(-1)))
        x = (input.expand_as(noise) + noise)
        x = torch.mean(x, 0) / temperature
    else:
        noise = to_gpu(torch.rand(input.size()))
        noise.add_(1e-9).log_().neg_()
        noise.add_(1e-9).log_().neg_()
        noise = Variable(noise)
        x = (input + noise) / temperature
    x = F.softmax(x.view(input.size(0), -1))
    return x.view_as(input)
def backward(ctx, grad_output):
    """
    X : Distance between similar samples
    Y : Distance between dissimilar samples

    Gradients:
        dLoss/dX = 1 + 1 / (exp(-X) + exp(-Y)) * -1 * exp(-X)
                 = 1 - exp(-X) / (exp(-X) + exp(-Y))
        dLoss/dY = 0 + 1 / (exp(-X) + exp(-Y)) * -1 * exp(-Y)
                 = -exp(-Y) / (exp(-X) + exp(-Y))
    """
    input1, input2, y = ctx.saved_variables
    grad_input1 = Variable(input1.data.new(input1.size()).zero_())
    grad_input2 = Variable(input1.data.new(input1.size()).zero_())

    grad_input1.add_(-1, torch.exp(-input1.clone()))
    grad_input2.add_(-1, torch.exp(-input2.clone()))

    dist = input1.clone().mul(-1).exp() + input2.clone().mul(-1).exp()
    grad_input1.div_(dist)
    grad_input2.div_(dist)

    grad_input1.add_(1)
    # grad_input1[y == 1].add_(1)   # supporting switched samples
    # grad_input2[y == -1].add_(1)  # supporting switched samples

    return grad_input1 * grad_output, grad_input2 * grad_output, None, None
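# A sketch checking the hand-derived gradients above against autograd, assuming the
# forward pass computes X + log(exp(-X) + exp(-Y)) (the loss those formulas correspond to);
# tensor names are illustrative, not the module's actual inputs.
import torch

X = torch.randn(8, requires_grad=True)
Y = torch.randn(8, requires_grad=True)
loss = (X + torch.log(torch.exp(-X) + torch.exp(-Y))).sum()
loss.backward()

with torch.no_grad():
    denom = torch.exp(-X) + torch.exp(-Y)
    manual_dX = 1 - torch.exp(-X) / denom
    manual_dY = -torch.exp(-Y) / denom
print(torch.allclose(X.grad, manual_dX, atol=1e-6),
      torch.allclose(Y.grad, manual_dY, atol=1e-6))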
def smoothness(grid):
    """
    Given a variable of dimensions (N, X, Y, [Z], C), computes the sum of
    the differences between adjacent points in the grid formed by the
    dimensions X, Y, and (optionally) Z. Returns a tensor of dimension N.
    """
    num_dims = len(grid.size()) - 2
    batch_size = grid.size()[0]
    norm = Variable(
        torch.zeros(batch_size, dtype=grid.data.dtype, device=grid.data.device))
    for dim in range(num_dims):
        slice_before = (slice(None), ) * (dim + 1)
        slice_after = (slice(None), ) * (num_dims - dim)
        shifted_grids = [
            # left
            torch.cat([
                grid[slice_before + (slice(1, None), ) + slice_after],
                grid[slice_before + (slice(-1, None), ) + slice_after],
            ], dim + 1),
            # right
            torch.cat([
                grid[slice_before + (slice(None, 1), ) + slice_after],
                grid[slice_before + (slice(None, -1), ) + slice_after],
            ], dim + 1)
        ]
        for shifted_grid in shifted_grids:
            delta = shifted_grid - grid
            norm_components = (delta.pow(2).sum(-1) + 1e-10).pow(0.5)
            norm.add_(
                norm_components.sum(
                    tuple(range(1, len(norm_components.size())))))
    return norm
def add_jitter(mat):
    """
    Adds "jitter" to the diagonal of a matrix.
    This ensures that a matrix that *should* be positive definite *is* positive definite.

    Args:
        - mat (matrix nxn) - Positive definite matrix

    Returns: (matrix nxn)
    """
    if isinstance(mat, LazyVariable):
        return mat.add_jitter()
    elif isinstance(mat, Variable):
        diag = Variable(mat.data.new(len(mat)).fill_(1e-3).diag())
        return mat + diag
    else:
        diag = mat.new(len(mat)).fill_(1e-3).diag()
        return diag.add_(mat)
w = torch.FloatTensor(4, 2)
init.xavier_normal(w)
w = Variable(w, requires_grad=True)
optimizer = optim.Adam([w], lr=1e-1)

BATCH_SIZE = 30
st = torch.FloatTensor(BATCH_SIZE, 4)
dreward = torch.Tensor([-1]).expand(BATCH_SIZE)

for ep in range(1000):
    st.uniform_(-0.05, 0.05)
    vst = Variable(st)
    reward = Variable(torch.zeros(BATCH_SIZE))
    multiplier = 1
    notdone = Variable(torch.ones(BATCH_SIZE))
    for i in range(800):
        logits = vst @ w
        y = gumbel_softmax(logits, tau=1, hard=True)
        vst, r, d = step_cartpole(vst, y)
        tester = notdone * (1 - d) * r * multiplier
        reward.add_(tester)
        notdone = (d == 0).float()
        multiplier *= 0.99
        if (d.data > 0).all():
            break
    if ep % 5 == 0:
        print("Num steps", i)
    optimizer.zero_grad()
    reward.backward(dreward)
    optimizer.step()
def ais_trajectory(model,
                   loader,
                   mode='forward',
                   schedule=np.linspace(0., 1., 500),
                   n_sample=100):
    """Compute annealed importance sampling trajectories for a batch of data.

    Can be used for *both* the forward and the reverse chain in bidirectional
    Monte Carlo (default: forward chain with linear schedule).

    Args:
        model (vae.VAE): VAE model
        loader (iterator): iterator that returns pairs, with the first component
            being `x`; the second would be `z` or a label (not used)
        mode (string): indicates forward/backward chain; must be either
            `forward` or `backward`
        schedule (list or 1D np.ndarray): temperature schedule, i.e. `p(z)p(x|z)^t`;
            the forward chain has increasing values, whereas the backward chain has
            decreasing values
        n_sample (int): number of importance samples (i.e. number of parallel
            chains for each datapoint)

    Returns:
        A list where each element is a torch.autograd.Variable that contains the
        log importance weights for a single batch of data
    """
    assert mode in ('forward', 'backward'), 'Should have forward/backward mode'

    def log_f_i(z, data, t, log_likelihood_fn=log_bernoulli):
        """Unnormalized density for intermediate distribution `f_i`:
            f_i = p(z)^(1-t) p(x,z)^(t) = p(z) p(x|z)^t
        => log f_i = log p(z) + t * log p(x|z)
        """
        zeros = Variable(torch.zeros(B, z_size).type(mdtype))
        log_prior = log_normal(z, zeros, zeros)
        log_likelihood = log_likelihood_fn(model.decode(z), data)
        return log_prior + log_likelihood.mul_(t)

    model.eval()

    # shorter aliases
    z_size = model.z_size
    mdtype = model.dtype

    _time = time.time()
    logws = []  # for output

    print('In %s mode' % mode)
    for i, (batch, post_z) in enumerate(loader):
        B = batch.size(0) * n_sample
        batch = Variable(batch.type(mdtype))
        batch = safe_repeat(batch, n_sample)

        # batch of step sizes, one for each chain
        epsilon = Variable(torch.ones(B).type(model.dtype)).mul_(0.01)
        # accept/reject history for tuning step size
        accept_hist = Variable(torch.zeros(B).type(model.dtype))
        # record log importance weight; volatile=True reduces memory greatly
        logw = Variable(torch.zeros(B).type(mdtype), volatile=True)

        # initial sample of z
        if mode == 'forward':
            current_z = Variable(torch.randn(B, z_size).type(mdtype),
                                 requires_grad=True)
        else:
            current_z = Variable(safe_repeat(post_z, n_sample).type(mdtype),
                                 requires_grad=True)

        for j, (t0, t1) in tqdm(enumerate(zip(schedule[:-1], schedule[1:]), 1)):
            # update log importance weight
            log_int_1 = log_f_i(current_z, batch, t0)
            log_int_2 = log_f_i(current_z, batch, t1)
            logw.add_(log_int_2 - log_int_1)

            # resample velocity
            current_v = Variable(torch.randn(current_z.size()).type(mdtype))

            def U(z):
                return -log_f_i(z, batch, t1)

            def grad_U(z):
                # grad w.r.t. outputs; mandatory in this case
                grad_outputs = torch.ones(B).type(mdtype)
                # torch.autograd.grad default returns volatile
                grad = torchgrad(U(z), z, grad_outputs=grad_outputs)[0]
                # avoid humongous gradients
                grad = torch.clamp(grad, -10000, 10000)
                # needs variable wrapper to make differentiable
                grad = Variable(grad.data, requires_grad=True)
                return grad

            def normalized_kinetic(v):
                zeros = Variable(torch.zeros(B, z_size).type(mdtype))
                # this is superior to the unnormalized version
                return -log_normal(v, zeros, zeros)

            z, v = hmc_trajectory(current_z, current_v, U, grad_U, epsilon)

            # accept-reject step
            current_z, epsilon, accept_hist = accept_reject(
                current_z, current_v,
                z, v,
                epsilon,
                accept_hist, j,
                U, K=normalized_kinetic)

        # IWAE lower bound
        logw = log_mean_exp(logw.view(n_sample, -1).transpose(0, 1))
        if mode == 'backward':
            logw = -logw
        logws.append(logw.data)

        print('Time elapsed %.4f, last batch stats %.4f' %
              (time.time() - _time, logw.mean().cpu().data.numpy()))
        _time = time.time()
        sys.stdout.flush()  # for debugging

    return logws
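# A toy sketch of the intermediate density used above, log f_t = log p(z) + t * log p(x|z),
# with unit-variance 1-D Gaussians standing in for the prior and likelihood; purely
# illustrative and independent of the VAE.
import math
import torch

def log_f_t(z, x, t):
    log_prior = -0.5 * (z ** 2 + math.log(2 * math.pi))
    log_lik = -0.5 * ((x - z) ** 2 + math.log(2 * math.pi))
    return log_prior + t * log_lik

z, x = torch.tensor(0.3), torch.tensor(1.0)
for t in (0.0, 0.5, 1.0):   # t sweeps from the prior alone to the full joint density
    print(t, log_f_t(z, x, t).item())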
class MANNCell(nn.Module):
    def __init__(self, options, input_size, hidden_size, use_bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.time_fac = options['time_fac']
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(2 * input_size, 3 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(2 * hidden_size, 3 * hidden_size))
        self.bias_ih = nn.Parameter(torch.FloatTensor(1, 3 * hidden_size))
        self.bias_hh = nn.Parameter(torch.FloatTensor(1, 3 * hidden_size))
        self.memcnt = 0
        self.memcap = options['mem_cap']
        self.head_size = options['head_size']
        mode = options['mode']
        if mode == 'train':
            batch_size = options['batch_size']
        elif mode == 'val':
            batch_size = options['eval_batch_size']
        elif mode == 'test':
            batch_size = options['test_batch_size']
        self.batch_size = batch_size
        self.auxcell = GRUCell(input_size + hidden_size,
                               2 * (self.head_size + 1))
        self.tau = 1.
        self.i_fc = nn.Sequential(
            nn.Linear(input_size, self.head_size // 2), nn.ReLU(),
            nn.Linear(self.head_size // 2, self.head_size), nn.Sigmoid())
        self.h_fc = nn.Sequential(
            nn.Linear(hidden_size, self.head_size // 2), nn.ReLU(),
            nn.Linear(self.head_size // 2, self.head_size), nn.Sigmoid())
        self.last_usage = None
        self.mem = None
        self.reset_parameters()

    def _reset_mem(self):
        self.memcnt = 0
        self.imem = Variable(torch.zeros(self.batch_size, self.memcap,
                                         self.input_size),
                             requires_grad=True).cuda()
        self.hmem = Variable(torch.zeros(self.batch_size, self.memcap,
                                         self.hidden_size),
                             requires_grad=True).cuda()
        self.i_last_use = Variable(
            torch.ones(self.batch_size, self.memcap) * -9999999.).cuda()
        self.h_last_use = Variable(
            torch.ones(self.batch_size, self.memcap) * -9999999.).cuda()

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def set_tau(self, num):
        self.tau = num

    def reset_parameters(self):
        """Initialize parameters following the way proposed in the paper."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for n, p in self.named_parameters():
            if 'weight' in n:
                nn.init.orthogonal_(p)
            if 'bias' in n:
                nn.init.constant_(p.data, val=0)

    def forward(self, input_, h_0, aux_h_0):
        i = input_
        h = h_0.detach()
        read_head = self.auxcell(torch.cat([i, h], dim=1), aux_h_0)
        i_read_head, h_read_head = torch.split(read_head,
                                               self.head_size + 1,
                                               dim=1)
        i_head_vecs = torch.cat([
            self.i_fc(self.imem.detach()),
            self.time_fac * torch.sigmoid(self.i_last_use).detach().unsqueeze(2)
        ], dim=2)
        h_head_vecs = torch.cat([
            self.h_fc(self.hmem.detach()),
            self.time_fac * torch.sigmoid(self.h_last_use).detach().unsqueeze(2)
        ], dim=2)
        i_read_head = 1 / torch.sqrt(
            (1e-6 + (i_read_head.unsqueeze(1) - i_head_vecs)**2).sum(dim=2))
        h_read_head = 1 / torch.sqrt(
            (1e-6 + (h_read_head.unsqueeze(1) - h_head_vecs)**2).sum(dim=2))
        i_entry, i_read_index, h_entry, h_read_index = self.read(
            i_read_head, h_read_head, self.tau)
        self.i_last_use.add_(-1).add_(-self.i_last_use * i_read_index)
        self.h_last_use.add_(-1).add_(-self.h_last_use * h_read_index)
        new_i = torch.cat([input_, i_entry], dim=1)
        new_h0 = torch.cat([h_0, h_entry], dim=1)
        wi_b = torch.addmm(self.bias_ih, new_i, self.weight_ih)
        wh_b = torch.addmm(self.bias_hh, new_h0, self.weight_hh)
        ri, zi, ni = torch.split(wi_b, self.hidden_size, dim=1)
        rh, zh, nh = torch.split(wh_b, self.hidden_size, dim=1)
        r = torch.sigmoid(ri + rh)
        z = torch.sigmoid(zi + zh)
        n = torch.tanh(ni + r * nh)
        h_1 = (1 - z) * n + z * h_0
        if self.memcnt < self.memcap:
            h_write_index = i_write_index = Variable(
                torch.cat([
                    torch.zeros(self.memcnt),
                    torch.ones(1),
                    torch.zeros(self.memcap - 1 - self.memcnt)
                ]).unsqueeze(0)).cuda()
            self.memcnt += 1
        else:
            h_write_index = h_read_index
            i_write_index = i_read_index
        self.write(input_, i_write_index, h_0, h_write_index)
        return h_1, read_head

    def write(self, i, i_index, h, h_index):
        i_ones = i_index.unsqueeze(2)
        h_ones = h_index.unsqueeze(2)
        self.imem = i.unsqueeze(1) * i_ones + self.imem * (1. - i_ones)
        self.hmem = h.unsqueeze(1) * h_ones + self.hmem * (1. - h_ones)

    def read(self, i_read_head, h_read_head, tau):
        i_index, _ = self.gumbel_softmax(i_read_head, tau)
        h_index, _ = self.gumbel_softmax(h_read_head, tau)
        i_entry = i_index.unsqueeze(2) * self.imem
        h_entry = h_index.unsqueeze(2) * self.hmem
        i_entry = i_entry.sum(dim=1)
        h_entry = h_entry.sum(dim=1)
        return i_entry, i_index, h_entry, h_index

    def gumbel_softmax(self, input, tau):
        gumbel = Variable(-torch.log(
            1e-20 - torch.log(1e-20 + torch.rand(*input.shape)))).cuda()
        y = torch.nn.functional.softmax(
            (torch.log(1e-20 + input) + gumbel) * tau, dim=1)
        ymax, pos = y.max(dim=1)
        hard_y = torch.eq(y, ymax.unsqueeze(1)).float()
        y = (hard_y - y).detach() + y
        return y, pos

    def gumbel_sigmoid(self, input, tau):
        gumbel = Variable(-torch.log(
            1e-20 - torch.log(1e-20 + torch.rand(*input.shape)))).cuda()
        y = torch.sigmoid((input + gumbel) * tau)
        # hard_y = torch.eq(y, ymax.unsqueeze(1)).float()
        # y = (hard_y - y).detach() + y
        return y
def stAdv_norm(self):
    """Computes the norm used in "Spatially Transformed Adversarial Examples"."""
    # ONLY WORKS FOR SQUARE MATRICES
    dtype = self.xform_params.data.type()
    num_examples, height, width = tuple(self.xform_params.shape[0:3])
    assert height == width

    ##########################################################################
    #   Build permutation matrices                                           #
    ##########################################################################

    def id_builder():
        x = torch.zeros(height, width).type(dtype)
        for i in range(height):
            x[i, i] = 1
        return x

    col_permuts = []
    row_permuts = []

    # torch.matmul(foo, col_permut)
    for col in ['left', 'right']:
        col_val = {'left': -1, 'right': 1}[col]
        idx = ((torch.arange(width) - col_val) % width)
        idx = idx.type(dtype).type(torch.LongTensor)
        if self.xform_params.is_cuda:
            idx = idx.cuda()

        col_permut = torch.zeros(height, width).index_copy_(
            1, idx.cpu(), id_builder().cpu())
        col_permut = col_permut.type(dtype)

        if col == 'left':
            col_permut[-1][0] = 0
            col_permut[0][0] = 1
        else:
            col_permut[0][-1] = 0
            col_permut[-1][-1] = 1
        col_permut = Variable(col_permut)
        col_permuts.append(col_permut)
        row_permuts.append(col_permut.transpose(0, 1))

    ##########################################################################
    #   Build delta_u, delta_v grids                                         #
    ##########################################################################
    id_params = Variable(self.identity_params(self.img_shape))
    delta_grids = self.xform_params - id_params
    delta_grids = delta_grids.permute(0, 3, 1, 2)

    ##########################################################################
    #   Compute the norm                                                     #
    ##########################################################################
    output = Variable(torch.zeros(num_examples).type(dtype))

    for row_or_col, permutes in zip(['row', 'col'],
                                    [row_permuts, col_permuts]):
        for permute in permutes:
            if row_or_col == 'row':
                temp = delta_grids - torch.matmul(permute, delta_grids)
            else:
                temp = delta_grids - torch.matmul(delta_grids, permute)
            temp = temp.pow(2)
            temp = temp.sum(1)
            temp = (temp + 1e-10).pow(0.5)
            output.add_(temp.sum((1, 2)))
    return output
# To compute all the gradients after performing the forward propagation,
# .backward() can be used to compute them automatically.
# autograd.Variable
#   - extracting the tensor value from the variable: autograd.Variable.data
#   - to extract the gradients, .grad will be of help
# --------------------------------
# Functions in autograd -
# Variables and functions make up the acyclic graph.
# Every variable that was created by some transformation or operation can refer
# to the function that created it using .grad_fn

import torch
from torch.autograd import Variable
import numpy

x = Variable(torch.ones(2, 2))
# a variable can also be created directly from a numpy array
x = Variable(torch.Tensor(numpy.linspace(1, 10, 30)), requires_grad=True)

# performing an in-place operation on the variable
x.add_(2)
# the function that produced the variable is pointed to by grad_fn
print(x.grad_fn)

# some more computations
y = x + 2
z = y * y * 3
out = z.mean()
print(z, out)

# gradients in autograd
out.backward()  # differentiates "out" w.r.t. x (the variable we started the transformation from)
print(out)

# ----------------------
x = torch.randn(3)
x = Variable(x, requires_grad=True)
y = x * 2
while y.data.norm() < 1000:
    y = y * 2
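# A minimal modern-PyTorch sketch (assumption: plain tensors with requires_grad replace
# Variable) of the same autograd flow described in the comments above.
import torch

x = torch.linspace(1, 10, 30, requires_grad=True)
z = (x + 2).pow(2) * 3
out = z.mean()
out.backward()        # d(out)/dx
print(x.grad[:3])     # gradients live in .grad; values in .detach() / .data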
def main():
    global opt, model, HEIGHT, WIDTH, SCALE
    opt = parser.parse_args()
    print(opt)

    test_image = None
    if opt.testing:
        opt.batchSize = 1
        img = imread(opt.test_image)
        HEIGHT, WIDTH = img.shape[0], img.shape[1]
        test_image = Image.fromarray(np.uint8(img))
        test_image = np.asarray(test_image)
        if test_image.ndim == 3:
            if test_image.shape[2] != 3:
                test_image = test_image[:, :, 0:3]
            test_image = torch.ByteTensor(
                torch.ByteStorage.from_buffer(
                    test_image.transpose(2, 0, 1).tobytes())).float().div(255).view(
                        -1, 3, HEIGHT, WIDTH)
        else:
            print('not good... we do not upscale non color images yet')
            return

    cuda = opt.cuda
    if cuda and not torch.cuda.is_available():
        raise Exception('No GPU found, please run without --cuda')

    opt.seed = random.randint(1, 10000)
    print('Random Seed: ', opt.seed)
    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    model = SRResNet()

    # clean this mess up!
    if opt.testing:
        model.eval()
        mean = torch.zeros(opt.batchSize, 3, HEIGHT * SCALE, WIDTH * SCALE)
        mean[:, 0, :, :] = 0.485
        mean[:, 1, :, :] = 0.456
        mean[:, 2, :, :] = 0.406
        std = torch.zeros(opt.batchSize, 3, HEIGHT * SCALE, WIDTH * SCALE)
        std[:, 0, :, :] = 0.229
        std[:, 1, :, :] = 0.224
        std[:, 2, :, :] = 0.225
        tmean = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH)
        tmean[:, 0, :, :] = 0.485
        tmean[:, 1, :, :] = 0.456
        tmean[:, 2, :, :] = 0.406
        tstd = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH)
        tstd[:, 0, :, :] = 0.229
        tstd[:, 1, :, :] = 0.224
        tstd[:, 2, :, :] = 0.225
    else:
        model.train()
        mean = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH)
        mean[:, 0, :, :] = 0.485
        mean[:, 1, :, :] = 0.456
        mean[:, 2, :, :] = 0.406
        std = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH)
        std[:, 0, :, :] = 0.229
        std[:, 1, :, :] = 0.224
        std[:, 2, :, :] = 0.225
        tmean = torch.zeros(opt.batchSize, 3, HEIGHT // SCALE, WIDTH // SCALE)
        tmean[:, 0, :, :] = 0.485
        tmean[:, 1, :, :] = 0.456
        tmean[:, 2, :, :] = 0.406
        tstd = torch.zeros(opt.batchSize, 3, HEIGHT // SCALE, WIDTH // SCALE)
        tstd[:, 0, :, :] = 0.229
        tstd[:, 1, :, :] = 0.224
        tstd[:, 2, :, :] = 0.225

    if not opt.pretraining and not opt.testing:
        percep_model = models.__dict__['vgg19'](pretrained=True)
        percep_model.features = nn.Sequential(
            *list(percep_model.features.children())[:-14])
        percep_model.eval()

    criterion = nn.MSELoss(size_average=False)
    lr = opt.lr

    if cuda:
        model = torch.nn.DataParallel(model).cuda()
        criterion = criterion.cuda()
        if not opt.pretraining and not opt.testing:
            percep_model = percep_model.cuda()
        mean = Variable(mean).cuda()
        std = Variable(std).cuda()
        tmean = Variable(tmean).cuda()
        tstd = Variable(tstd).cuda()

    if opt.pretrained:
        if os.path.isfile(opt.pretrained):
            print('=> loading model {}'.format(opt.pretrained))
            weights = torch.load(opt.pretrained)
            model.load_state_dict(weights['model'].state_dict())
        else:
            print('=> no model found at {}'.format(opt.pretrained))

    if opt.testing:
        test_image = Variable(test_image)
        if cuda:
            test_image = test_image.cuda()
        test_image = test_image.sub(tmean).div(tstd)
        gen = model(test_image)
        gened = torch.clamp(
            gen.mul(std).add(mean).mul(255.0),
            min=0., max=255.0).byte()[0].data.cpu().numpy().transpose(1, 2, 0)
        gened = Image.fromarray(gened)
        gened.save('testing-sr.jpg')
    else:
        train_set = dataprovider.DatasetFromDir(opt.image_dir,
                                                samples=opt.images,
                                                width=WIDTH,
                                                height=HEIGHT)
        training_data_loader = DataLoader(dataset=train_set,
                                          num_workers=opt.threads,
                                          batch_size=opt.batchSize,
                                          shuffle=True)
        optimizer = optim.Adam(model.parameters(), lr=lr)

        counter = 0
        for epoch in range(opt.nEpochs):
            loss_sum = Variable(torch.zeros(1), requires_grad=False)
            if cuda:
                loss_sum = loss_sum.cuda()
            for iteration, batch in enumerate(training_data_loader, 1):
                counter = counter + 1
                input, target = (Variable(batch[0]),
                                 Variable(batch[1], requires_grad=False))
                if cuda:
                    input = input.cuda()
                    target = target.cuda()
                input = input.sub(tmean).div(tstd)
                target = target.sub(mean).div(std)
                gen = model(input)
                optimizer.zero_grad()
                loss = criterion(gen, target)
                if not opt.pretraining:
                    out_percep = percep_model.features(gen)
                    out_percep_real = Variable(
                        percep_model.features(target).data, requires_grad=False)
                    percep_loss = criterion(out_percep, out_percep_real)
                    # loss_relation = percep_loss.div(loss)
                    loss = loss.add(percep_loss.mul(opt.percep_scale))  # loss_relation))
                loss.backward()
                nn.utils.clip_grad_norm(model.parameters(), opt.clip)
                loss_sum.add_(loss)
                optimizer.step()
                if counter % 400 == 0:
                    print('sum_of_loss = {}'.format(loss_sum.data.select(0, 0)))
                    loss_sum = Variable(torch.zeros(1), requires_grad=False)
                    if cuda:
                        loss_sum = loss_sum.cuda()
                    save_checkpoint(model, epoch)
                    input = torch.clamp(
                        input.mul(tstd).add(tmean).mul(255.0),
                        min=0., max=255.0).byte()[0].data.cpu().numpy().transpose(1, 2, 0)
                    inp = Image.fromarray(input)
                    label = torch.clamp(
                        target.mul(std).add(mean).mul(255.0),
                        min=0., max=255.0).byte()[0].data.cpu().numpy().transpose(1, 2, 0)
                    lab = Image.fromarray(label)
                    gened = torch.clamp(
                        gen.mul(std).add(mean).mul(255.0),
                        min=0., max=255.0).byte()[0].data.cpu().numpy().transpose(1, 2, 0)
                    gened = Image.fromarray(gened)
                    inp.save('input.jpg')
                    lab.save('gt.jpg')
                    gened.save('sr.jpg')
def test_local_var_binary_methods(self):
    x = Var(torch.FloatTensor([1, 2, 3, 4, 5]))
    y = Var(torch.FloatTensor([1, 2, 3, 4, 5]))
    assert torch.equal(x.add_(y), Var(torch.FloatTensor([2, 4, 6, 8, 10])))
def meta_learn(model, x_train, y_train, ystatus_train, x_val, y_val,
               ystatus_val, iterations, lr_inner, lr_outer, n_inner, batch_n,
               reg_scale, shots_n):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_outer)
    train_metalosses = []
    test_metalosses = []
    inner_optimizer_state = None

    for t in range(iterations):
        start = time.time()
        # Average gradient of a batch of tasks
        ind = random.sample(range(x_train.shape[0]), shots_n)
        x_batch = x_train[ind, ]
        ystatus_batch = ystatus_train[ind, ]
        y_batch = y_train[ind, ]
        R_matrix_batch = np.zeros([y_batch.shape[0], y_batch.shape[0]],
                                  dtype=int)
        for i in range(y_batch.shape[0]):
            for j in range(y_batch.shape[0]):
                R_matrix_batch[i, j] = y_batch[j] >= y_batch[i]
        new_model = do_base_learning(model, x_batch, R_matrix_batch,
                                     ystatus_batch, lr_inner, n_inner,
                                     reg_scale)
        diff = list()
        for p, new_p in zip(model.parameters(), new_model.parameters()):
            temp = Variable(torch.zeros(p.size()))
            temp.add_(p.data - new_p.data)
            diff.append(temp)

        for j in range(batch_n - 1):
            ind = random.sample(range(x_train.shape[0]), shots_n)
            x_batch = x_train[ind, ]
            ystatus_batch = ystatus_train[ind, ]
            y_batch = y_train[ind, ]
            R_matrix_batch = np.zeros([y_batch.shape[0], y_batch.shape[0]],
                                      dtype=int)
            for i in range(y_batch.shape[0]):
                for j in range(y_batch.shape[0]):
                    R_matrix_batch[i, j] = y_batch[j] >= y_batch[i]
            new_model = do_base_learning(model, x_batch, R_matrix_batch,
                                         ystatus_batch, lr_inner, n_inner,
                                         reg_scale)
            diff_next = list()
            for p, new_p in zip(model.parameters(), new_model.parameters()):
                temp = Variable(torch.zeros(p.size()))
                temp.add_(p.data - new_p.data)
                diff_next.append(temp)
            diff = list(map(add, diff, diff_next))

        diff_ave = [x / batch_n for x in diff]

        ind_k = 0
        for p in model.parameters():
            if p.grad is None:
                p.grad = Variable(torch.zeros(p.size()))
            p.grad.data.add_(diff_ave[ind_k])
            ind_k = ind_k + 1

        # Update meta-parameters
        optimizer.step()
        optimizer.zero_grad()

        val_metaloss, val_cind = do_base_eval(model, x_val, y_val, ystatus_val)

        end = time.time()
        print("1 iteration time:", end - start)
        print('Iteration', t)
def rec(node, rfunc):
    rsum = Variable(torch.zeros(node.hsize, )).to(Tree.device)
    for child in node.children:
        rsum.add_(rec(child, rfunc))
    return rfunc(node.h_v, rsum)
def test_local_var_binary_methods(self):
    x = Var(torch.FloatTensor([1, 2, 3, 4, 5]))
    y = Var(torch.FloatTensor([1, 2, 3, 4, 5]))
    assert torch.equal(x.add_(y), Var(torch.FloatTensor([2, 4, 6, 8, 10])))
def Train(path, batch, shuffle, hiddenSize, latentSize, symSize, boxSize,
          catSize):
    matplotlib.use('Qt5Agg')
    plt.ion()
    ax = plt.gca()
    ax.set_autoscale_on(True)
    hl, = plt.plot([], [])

    data = get_loader(path, batch, shuffle)
    VAE = myVAE(hiddenSize, latentSize, symSize, boxSize, catSize)
    # VAE.load_state_dict(torch.load('VAE.pkl'))
    # optimization = torch.optim.SGD(
    #     [{'params': VAE.encoder.BoxEnco.parameters()},
    #      {'params': VAE.decoder.BoxDeco.parameters()},
    #      {'params': VAE.encoder.symEnco1.parameters()},
    #      {'params': VAE.encoder.symEnco2.parameters()},
    #      {'params': VAE.decoder.symDeco1.parameters()},
    #      {'params': VAE.decoder.symDeco2.parameters()},
    #      {'params': VAE.encoder.AdjEnco1.parameters()},
    #      {'params': VAE.encoder.AdjEnco2.parameters()},
    #      {'params': VAE.decoder.AdjDeco1.parameters()},
    #      {'params': VAE.decoder.AdjDeco2.parameters()},
    #      {'params': VAE.decoder.NClr1.parameters(), 'lr': 0.5/20},
    #      {'params': VAE.decoder.NClr2.parameters(), 'lr': 0.5/20},
    #      {'params': VAE.ranen1.parameters()},
    #      {'params': VAE.ranen2.parameters()},
    #      {'params': VAE.rande1.parameters()},
    #      {'params': VAE.rande2.parameters()}],
    #     lr=0.2/20)
    optimization = torch.optim.SGD(VAE.parameters(), lr=0.2 / 20)

    histLoss = list()
    for i in range(1500):
        batchLoss = list()
        for j, d in enumerate(data):
            boxes = d['boxes']
            symshapes = Variable(d['symshapes'].float())
            treekids = d['treekids']
            symparams = Variable(d['symparams'])

            # calculate Output!!!
            myOut, paramOut, paramGT, NClrOut, NClrGT, mu, logvar = VAE(
                symshapes, treekids, symparams)

            # calculate KL Loss
            KLD_element = mu.pow(2).add_(
                logvar.exp()).mul_(-1).add_(1).add_(logvar)
            KLloss = torch.sum(KLD_element).mul(-0.5)

            # calculate Node Classification Loss
            finalNCl = torch.nn.functional.cross_entropy(
                NClrOut, NClrGT).mul_(0.2).mul_(NClrGT.size(0))
            # finalNCl = Variable(torch.FloatTensor([0]))
            # for ii in range(np.shape(NClrGT)[0]):
            #     finalNCl.add_(torch.nn.functional.cross_entropy(NClrOut[ii], NClrGT[ii]).mul_(0.2))

            # calculate Sym parameters Loss
            finalPL = torch.nn.functional.mse_loss(paramOut, paramGT).mul_(
                paramGT.size(0))
            # finalPL = Variable(torch.FloatTensor([0]))
            # for ii in range(np.shape(paramGT)[0]):
            #     finalPL.add_(torch.nn.functional.mse_loss(paramOut[ii], paramGT[ii]))

            # calculate reconstruction Loss
            symshapes = torch.t(symshapes.squeeze())
            finalRL = torch.nn.functional.mse_loss(myOut, symshapes).mul_(0.8).mul_(
                myOut.size(0))
            # finalRL = Variable(torch.FloatTensor([0]))
            # for ii in range(np.shape(myOut)[0]):
            #     finalRL.add_(torch.nn.functional.mse_loss(myOut[ii], symshapes[:, :, ii]).mul_(0.8))

            # final Loss
            FinalLoss = Variable(torch.FloatTensor([0]))
            FinalLoss = FinalLoss.add_(KLloss).add_(finalNCl).add_(
                finalPL).add_(finalRL)
            batchLoss.append(FinalLoss.data.numpy())

            # leafcount = VAE.decoder.gLeafcount
            # gAssemcount = VAE.decoder.gAssemcount
            # gSymcount = VAE.decoder.gSymcount
            # optimization = torch.optim.SGD(
            #     [{'params': VAE.encoder.parameters(), 'lr': 0.2},
            #      {'params': VAE.decoder.parameters(), 'lr': 0.2}], lr=0.2)
            # for k in range(len(optimization.param_groups)):
            #     if (k <= 1):
            #         optimization.param_groups[k]['lr'] = optimization.param_groups[k]['lr'] / leafcount
            #     elif (k <= 5):
            #         optimization.param_groups[k]['lr'] = optimization.param_groups[k]['lr'] / gSymcount
            #     elif (k <= 9):
            #         optimization.param_groups[k]['lr'] = optimization.param_groups[k]['lr'] / gAssemcount
            #     elif (k <= 11):
            #         optimization.param_groups[k]['lr'] = optimization.param_groups[k]['lr'] / treekids.size(1)

            if (j % 20 == 0 and j != 0):
                optimization.step()
                optimization.zero_grad()
                tmp = sum(batchLoss) / 20
                histLoss.append(tmp)
                print('Epoch [%d/%d], Iter [%d/%d], Loss: %.4f, ' %
                      (i, 1500, j, len(data), tmp))
                del batchLoss[:]
                batchLoss = list()

            FinalLoss.backward()
            update_line(hl, ax, histLoss)

        if (i % 10 == 0 and i != 0):
            torch.save(VAE.state_dict(), 'VAE.pkl')
def forward(self, inputs, gold=None, inference=None):
    token_ids = inputs['token_ids']
    token_mask = inputs['token_mask']
    entity_ids = inputs['entity_ids']
    entity_mask = inputs['entity_mask']
    p_e_m = inputs['p_e_m']
    p_e_ent_net = inputs['p_e_ent_net']
    n_negs = inputs['n_negs']

    n_ments, n_cands = entity_ids.size()
    n_rels = self.n_rels

    if self.mode == 'ment-norm' and self.first_head_uniform:
        self.ew_embs.data[0] = 0

    if not self.oracle:
        gold = None

    if self.use_local:
        local_ent_scores = super(MulRelRanker, self).forward(
            token_ids, token_mask, entity_ids, entity_mask, p_e_m=None)
        ent_vecs = self._entity_vecs
    else:
        ent_vecs = self.entity_embeddings(entity_ids)
        local_ent_scores = Variable(torch.zeros(n_ments, n_cands).cuda(),
                                    requires_grad=False)

    # compute context vectors
    ltok_vecs = self.snd_word_embeddings(inputs['s_ltoken_ids']) * \
        inputs['s_ltoken_mask'].view(n_ments, -1, 1)
    local_lctx_vecs = torch.sum(ltok_vecs, dim=1) / \
        torch.sum(inputs['s_ltoken_mask'], dim=1, keepdim=True).add_(1e-5)
    rtok_vecs = self.snd_word_embeddings(inputs['s_rtoken_ids']) * \
        inputs['s_rtoken_mask'].view(n_ments, -1, 1)
    local_rctx_vecs = torch.sum(rtok_vecs, dim=1) / \
        torch.sum(inputs['s_rtoken_mask'], dim=1, keepdim=True).add_(1e-5)
    mtok_vecs = self.snd_word_embeddings(inputs['s_mtoken_ids']) * \
        inputs['s_mtoken_mask'].view(n_ments, -1, 1)
    ment_vecs = torch.sum(mtok_vecs, dim=1) / \
        torch.sum(inputs['s_mtoken_mask'], dim=1, keepdim=True).add_(1e-5)
    bow_ctx_vecs = torch.cat([local_lctx_vecs, ment_vecs, local_rctx_vecs], dim=1)

    if self.use_pad_ent:
        ent_vecs = torch.cat(
            [ent_vecs, self.pad_ent_emb.view(1, 1, -1).repeat(1, n_cands, 1)], dim=0)
        tmp = torch.zeros(1, n_cands)
        tmp[0, 0] = 1
        tmp = Variable(tmp.cuda())
        entity_mask = torch.cat([entity_mask, tmp], dim=0)
        p_e_m = torch.cat([p_e_m, tmp], dim=0)
        local_ent_scores = torch.cat(
            [local_ent_scores,
             Variable(torch.zeros(1, n_cands).cuda(), requires_grad=False)], dim=0)
        n_ments += 1

        if self.oracle:
            tmp = Variable(torch.zeros(1, 1).cuda().long())
            gold = torch.cat([gold, tmp], dim=0)

    if self.use_local_only:
        inputs = torch.cat(
            [local_ent_scores.view(n_ments * n_cands, -1),
             torch.log(p_e_m + 1e-20).view(n_ments * n_cands, -1)], dim=1)
        scores = self.score_combine(inputs).view(n_ments, n_cands)
        if self.use_pad_ent:
            scores = scores[:-1]
        self.layer_scores = [scores] * self.n_layers
        return scores

    if n_ments == 1:
        ent_scores = local_ent_scores
    else:
        # distance - to consider only neighbor mentions
        ment_pos = torch.arange(0, n_ments).long().cuda()
        dist = (ment_pos.view(n_ments, 1) - ment_pos.view(1, n_ments)).abs()
        dist.masked_fill_(dist == 1, -1)
        dist.masked_fill_((dist > 1) & (dist <= self.max_dist), -1)
        dist.masked_fill_(dist > self.max_dist, 0)
        dist.mul_(-1)

        if self.uniform_att:
            rel_ctx_ctx_scores = Variable(
                torch.zeros(n_rels, n_ments, n_ments).cuda())
        else:
            ctx_vecs = self.ctx_layer(bow_ctx_vecs)
            if self.use_pad_ent:
                ctx_vecs = torch.cat([ctx_vecs, self.pad_ctx_vec], dim=0)

            m1_ctx_vecs, m2_ctx_vecs = ctx_vecs, ctx_vecs
            rel_ctx_vecs = m1_ctx_vecs.view(1, n_ments, -1) * \
                self.ew_embs.view(n_rels, 1, -1)
            rel_ctx_ctx_scores = torch.matmul(
                rel_ctx_vecs,
                m2_ctx_vecs.view(1, n_ments, -1).permute(0, 2, 1))  # n_rels x n_ments x n_ments
            rel_ctx_ctx_scores = rel_ctx_ctx_scores.add_(
                (1 - Variable(dist.float().cuda())).mul_(-1e10))

        eye = Variable(torch.eye(n_ments).cuda()).view(1, n_ments, n_ments)
        rel_ctx_ctx_scores.add_(eye.mul_(-1e10))
        rel_ctx_ctx_scores.mul_(1 / np.sqrt(self.ew_hid_dims))  # scaling proposed by "attention is all you need"

        if self.mode == 'ment-norm':
            rel_ctx_ctx_probs = F.softmax(rel_ctx_ctx_scores, dim=2)
            rel_ctx_ctx_weights = rel_ctx_ctx_probs + rel_ctx_ctx_probs.permute(0, 2, 1)
            self._rel_ctx_ctx_weights = rel_ctx_ctx_probs
        elif self.mode == 'rel-norm':
            ctx_ctx_rel_scores = rel_ctx_ctx_scores.permute(1, 2, 0).contiguous()
            if not self.use_stargmax:
                ctx_ctx_rel_probs = F.softmax(ctx_ctx_rel_scores, dim=2)
            else:
                ctx_ctx_rel_probs = STArgmax.apply(ctx_ctx_rel_scores)
            self._rel_ctx_ctx_weights = ctx_ctx_rel_probs.permute(2, 0, 1).contiguous()

        # compute phi(ei, ej)
        if self.mode == 'ment-norm':
            if self.ent_ent_comp == 'bilinear':
                if self.ent_ent_comp == 'bilinear':
                    rel_ent_vecs = ent_vecs.view(1, n_ments, n_cands, -1) * \
                        self.rel_embs.view(n_rels, 1, 1, -1)
                elif self.ent_ent_comp == 'trans_e':
                    rel_ent_vecs = ent_vecs.view(1, n_ments, n_cands, -1) - \
                        self.rel_embs.view(n_rels, 1, 1, -1)
                else:
                    raise Exception("unknown ent_ent_comp")

            rel_ent_ent_scores = torch.matmul(
                rel_ent_vecs.view(n_rels, n_ments, 1, n_cands, -1),
                ent_vecs[:, n_negs:, :].contiguous()
                .view(1, 1, n_ments, n_cands - n_negs, -1)
                .permute(0, 1, 2, 4, 3))
            # n_rels x n_ments x n_ments x n_cands x (n_cands - n_negs)

            rel_ent_ent_scores = rel_ent_ent_scores.permute(0, 1, 3, 2, 4)
            # n_rel x n_ments x n_cands x n_ments x (n_cands - n_negs)
            rel_ent_ent_scores = (rel_ent_ent_scores * entity_mask[:, n_negs:]).add_(
                (entity_mask[:, n_negs:] - 1).mul_(1e10))
            rel_ent_ent_weighted_scores = rel_ent_ent_scores * \
                rel_ctx_ctx_weights.view(n_rels, n_ments, 1, n_ments, 1)
            ent_ent_scores = torch.sum(rel_ent_ent_weighted_scores, dim=0).mul(1. / n_rels)
            # n_ments x n_cands x n_ments x (n_cands - n_negs)
        elif self.mode == 'rel-norm':
            raise Exception('not implement for multi-layer yet')
            # TODO: check it again
            rel_vecs = torch.matmul(
                ctx_ctx_rel_probs.view(n_ments, n_ments, 1, n_rels),
                self.rel_embs.view(1, 1, n_rels, -1)).view(n_ments, n_ments, -1)
            ent_rel_vecs = ent_vecs.view(n_ments, 1, n_cands, -1) * \
                rel_vecs.view(n_ments, n_ments, 1, -1)  # n_ments x n_ments x n_cands x dims
            ent_ent_scores = torch.matmul(
                ent_rel_vecs,
                ent_vecs[:, n_negs:, :].contiguous()
                .view(1, n_ments, n_cands - n_negs, -1)
                .permute(0, 1, 3, 2)).permute(0, 2, 1, 3)
            # n_ments x n_cands x n_ments x (n_cands - n_negs)

        if gold is None:
            if inference == 'LBP' or (inference is None and self.inference == 'LBP'):
                prev_msgs = Variable(torch.zeros(n_ments, n_cands, n_ments).cuda())

                for _ in range(self.n_loops):
                    mask = 1 - Variable(torch.eye(n_ments).cuda())
                    ent_ent_votes = ent_ent_scores + local_ent_scores[:, n_negs:] * 1 + \
                        torch.sum(prev_msgs.view(1, n_ments, n_cands, n_ments) *
                                  mask.view(n_ments, 1, 1, n_ments), dim=3) \
                        .view(n_ments, 1, n_ments, n_cands)
                    msgs, _ = torch.max(ent_ent_votes, dim=3)
                    # msgs = utils.log_sum_exp(ent_ent_votes, dim=3)
                    msgs = (F.softmax(msgs, dim=1).mul(self.df) +
                            prev_msgs.exp().mul(1 - self.df)).log()
                    prev_msgs = msgs

                # compute marginal belief
                mask = 1 - Variable(torch.eye(n_ments).cuda())
                ent_scores = local_ent_scores * 1 + \
                    torch.sum(msgs * mask.view(n_ments, 1, n_ments), dim=2)
                ent_scores = F.softmax(ent_scores, dim=1)

            elif inference == 'star' or (inference is None and self.inference == 'star'):
                comp = 'max'
                ent_weights = Variable(torch.ones(n_ments, n_cands - n_negs).cuda())
                self_mask = Variable(1 - torch.eye(n_ments).cuda()).view(n_ments, 1, n_ments)

                ent_ent_scores += local_ent_scores[:, n_negs:]
                ent_ent_weighted_scores = ent_ent_scores * ent_weights
                ent_ent_weighted_scores = (ent_ent_weighted_scores * entity_mask[:, n_negs:]).add_(
                    (entity_mask[:, n_negs:] - 1).mul_(1e10))

                if comp == 'weighted_sum':
                    ent_ent_att_probs = F.softmax(ent_ent_weighted_scores, dim=3)
                    # n_ments x n_cands x n_ments x (n_cands - n_negs)
                    ent_ent_weighted_scores = ent_ent_weighted_scores * ent_ent_att_probs
                    ent_ment_scores = torch.sum(ent_ent_weighted_scores, dim=3)
                    # n_ments x n_cands x n_ments
                elif comp == 'softmax':
                    ent_ment_scores = utils.log_sum_exp(ent_ent_weighted_scores, dim=3)
                elif comp == 'max':
                    ent_ment_scores, _ = torch.max(ent_ent_weighted_scores, dim=3)
                    # n_ments x n_cands x n_ments

                ent_ment_scores = (ent_ment_scores * self_mask).add_(
                    (self_mask - 1).mul_(1e10))  # set self-self scores to -inf

                # compute entity scores, using hard attention on neighbour mentions (eq 4)
                ent_top_ment_scores, _ = torch.topk(
                    ent_ment_scores, dim=2,
                    k=max(1, min(self.ent_top_n, n_ments - 2)))
                ent_scores = local_ent_scores + torch.sum(ent_top_ment_scores, dim=2)
                ent_scores = (ent_scores * entity_mask).add_((entity_mask - 1).mul_(1e10))
        else:
            onehot_gold = Variable(torch.zeros(n_ments, n_cands).cuda()).scatter_(1, gold, 1)
            ent_scores = torch.sum(torch.sum(ent_ent_scores * onehot_gold, dim=3), dim=2)

    # combine with p_e_m
    p_e_m_mul = 1
    inputs = torch.cat(
        [ent_scores.view(n_ments * n_cands, -1),
         p_e_m_mul * torch.log(p_e_m + 1e-20).view(n_ments * n_cands, -1)], dim=1)
    scores = self.score_combine(inputs).view(n_ments, n_cands)
    # scores = ent_scores

    if self.use_pad_ent:
        scores = scores[:-1]
    return scores
def test_local_add(self):
    x = Var(torch.FloatTensor([1, 2, 3, 4, 5]))
    y = Var(torch.FloatTensor([1, 2, 3, 4, 5]))
    assert torch.equal(x.add_(y), Var(torch.FloatTensor([2, 4, 6, 8, 10])))