def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
        sequence_output = all_encoder_layers[-1]
        pooled_output = self.pooler(sequence_output)
        return all_encoder_layers, pooled_output
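
A quick standalone sketch of the mask arithmetic above (toy values, no model weights): the 2D padding mask becomes a broadcastable additive bias, so masked positions end up with near-zero attention weight after the softmax.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])          # [batch_size, to_seq_length]
extended = attention_mask.unsqueeze(1).unsqueeze(2)    # [batch_size, 1, 1, to_seq_length]
extended = (1.0 - extended.float()) * -10000.0         # 0.0 where attended, -10000.0 where masked
scores = torch.zeros(1, 2, 4, 4)                       # [batch_size, num_heads, from_seq, to_seq]
weights = torch.softmax(scores + extended, dim=-1)
print(weights[0, 0, 0])                                # tensor([0.3333, 0.3333, 0.3333, 0.0000])
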
    def sample_conditional_a(self, resid_image, var_so_far, pixel_1d):

        is_on = (pixel_1d < (self.n_discrete_latent - 1)).float()

        # pass through galaxy encoder
        pixel_2d = self.one_galaxy_vae.pixel_1d_to_2d(pixel_1d)
        z_mean, z_var = self.one_galaxy_vae.enc(resid_image, pixel_2d)

        # sample z
        q_z = Normal(z_mean, z_var.sqrt())
        z_sample = q_z.rsample()

        # kl term for continuous latent vars
        log_q_z = q_z.log_prob(z_sample).sum(1)
        p_z = Normal(torch.zeros_like(z_sample), torch.ones_like(z_sample))
        log_p_z = p_z.log_prob(z_sample).sum(1)
        kl_z = is_on * (log_q_z - log_p_z)

        # run through decoder
        recon_mean, recon_var = self.one_galaxy_vae.dec(is_on, pixel_2d, z_sample)

        # NOTE: we will have to add to the recon means once we do more detections
        # recon_means = recon_mean + image_so_far
        # recon_vars = recon_var + var_so_far

        return recon_mean, recon_var, is_on, kl_z
Example #3
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate Expected Improvement on the candidate set X.

        Args:
            X: A `b1 x ... bk x 1 x d`-dim batched tensor of `d`-dim design points.
                Expected Improvement is computed for each point individually,
                i.e., what is considered are the marginal posteriors, not the
                joint.

        Returns:
            A `b1 x ... bk`-dim tensor of Expected Improvement values at the
            given design points `X`.
        """
        self.best_f = self.best_f.to(X)
        posterior = self.model.posterior(X)
        self._validate_single_output_posterior(posterior)
        mean = posterior.mean
        # deal with batch evaluation and broadcasting
        view_shape = mean.shape[:-2] if mean.dim() >= X.dim() else X.shape[:-2]
        mean = mean.view(view_shape)
        sigma = posterior.variance.clamp_min(1e-9).sqrt().view(view_shape)
        u = (mean - self.best_f.expand_as(mean)) / sigma
        if not self.maximize:
            u = -u
        normal = Normal(torch.zeros_like(u), torch.ones_like(u))
        ucdf = normal.cdf(u)
        updf = torch.exp(normal.log_prob(u))
        ei = sigma * (updf + u * ucdf)
        return ei
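
The closed form used above is EI = sigma * (phi(u) + u * Phi(u)) with u = (mean - best_f) / sigma. A minimal numeric check with made-up scalar values (no model or posterior object needed):

import torch
from torch.distributions import Normal

mean, sigma, best_f = torch.tensor(0.5), torch.tensor(1.0), torch.tensor(0.0)
u = (mean - best_f) / sigma
normal = Normal(torch.zeros_like(u), torch.ones_like(u))
ei = sigma * (torch.exp(normal.log_prob(u)) + u * normal.cdf(u))
print(ei)  # tensor(0.6978) for these values
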
    def test_index_setitem_bools_slices(self):
        true = torch.tensor(1, dtype=torch.uint8)
        false = torch.tensor(0, dtype=torch.uint8)

        tensors = [Variable(torch.randn(2, 3)), torch.tensor(3)]

        for a in tensors:
            # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
            # (some of these ops already prefix a 1 to the size)
            neg_ones = torch.ones_like(a) * -1
            neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
            a[True] = neg_ones_expanded
            self.assertEqual(a, neg_ones)
            a[False] = 5
            self.assertEqual(a, neg_ones)
            a[true] = neg_ones_expanded * 2
            self.assertEqual(a, neg_ones * 2)
            a[false] = 5
            self.assertEqual(a, neg_ones * 2)
            a[None] = neg_ones_expanded * 3
            self.assertEqual(a, neg_ones * 3)
            a[...] = neg_ones_expanded * 4
            self.assertEqual(a, neg_ones * 4)
            if a.dim() == 0:
                with self.assertRaises(RuntimeError):
                    a[:] = neg_ones_expanded * 5
Example #5
 def mse(self, prediction, target):
     if not hasattr(target, '__len__'):
         target = torch.ones_like(prediction)*target
         if prediction.is_cuda:
             target = target.cuda()
         target = Variable(target)
     return torch.nn.MSELoss()(prediction, target)
 def forward(self, inputs, targets):
     """
     Args:
     - inputs: feature matrix with shape (batch_size, feat_dim)
      - targets: ground truth labels with shape (batch_size)
     """
     n = inputs.size(0)
     
     # Compute pairwise distance, replace by the official when merged
     dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
     dist = dist + dist.t()
      dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)  # dist = dist - 2 * inputs @ inputs.t()
     dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
     
     # For each anchor, find the hardest positive and negative
     mask = targets.expand(n, n).eq(targets.expand(n, n).t())
     dist_ap, dist_an = [], []
     for i in range(n):
         dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
         dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
     dist_ap = torch.cat(dist_ap)
     dist_an = torch.cat(dist_an)
     
     # Compute ranking hinge loss
     y = torch.ones_like(dist_an)
     loss = self.ranking_loss(dist_an, dist_ap, y)
     return loss
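
The distance computation above uses the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. A small sketch that checks the expand-and-subtract pattern against torch.cdist on random features (the addmm_ call is written out as an explicit matrix product here):

import torch

inputs = torch.randn(5, 8)
n = inputs.size(0)
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
dist = dist + dist.t()
dist = dist - 2 * inputs @ inputs.t()
dist = dist.clamp(min=1e-12).sqrt()
print(torch.allclose(dist, torch.cdist(inputs, inputs), atol=1e-5))  # True, up to float error
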
Example #7
    def test_probability_of_improvement(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            mean = torch.tensor([0.0], device=device, dtype=dtype).view(1, 1)
            variance = torch.ones(1, 1, device=device, dtype=dtype)
            mm = MockModel(MockPosterior(mean=mean, variance=variance))

            module = ProbabilityOfImprovement(model=mm, best_f=1.96)
            X = torch.zeros(1, 1, device=device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor(0.0250, device=device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            module = ProbabilityOfImprovement(model=mm, best_f=1.96, maximize=False)
            X = torch.zeros(1, 1, device=device, dtype=dtype)
            pi = module(X)
            pi_expected = torch.tensor(0.9750, device=device, dtype=dtype)
            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            # check for proper error if multi-output model
            mean2 = torch.rand(1, 2, device=device, dtype=dtype)
            variance2 = torch.ones_like(mean2)
            mm2 = MockModel(MockPosterior(mean=mean2, variance=variance2))
            module2 = ProbabilityOfImprovement(model=mm2, best_f=0.0)
            with self.assertRaises(UnsupportedError):
                module2(X)
    def test_index_setitem_bools_slices(self):
        true = variable(1).byte()
        false = variable(0).byte()

        tensors = [Variable(torch.randn(2, 3))]
        if torch._C._with_scalars():
            tensors.append(variable(3))

        for a in tensors:
            a_clone = a.clone()
            # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
            # (some of these ops already prefix a 1 to the size)
            neg_ones = torch.ones_like(a) * -1
            neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
            a[True] = neg_ones_expanded
            self.assertEqual(a, neg_ones)
            a[False] = 5
            self.assertEqual(a, neg_ones)
            if torch._C._with_scalars():
                a[true] = neg_ones_expanded * 2
                self.assertEqual(a, neg_ones * 2)
                a[false] = 5
                self.assertEqual(a, neg_ones * 2)
            a[None] = neg_ones_expanded * 3
            self.assertEqual(a, neg_ones * 3)
            a[...] = neg_ones_expanded * 4
            self.assertEqual(a, neg_ones * 4)
            if a.dim() == 0:
                with self.assertRaises(RuntimeError):
                    a[:] = neg_ones_expanded * 5
Example #9
def dummy_mask(seq):
    '''
    create dummy mask (all 1)
    '''
    if isinstance(seq, tuple):
        seq = seq[0]
    assert len(seq.size()) == 1 or (len(seq.size()) == 2 and seq.size(1) == 1)
    return torch.ones_like(seq, dtype=torch.float)
Example #10
 def decode_step(self, enc_hs, enc_mask, input_, hidden):
     trans, emiss, hidden = super().decode_step(enc_hs, enc_mask, input_,
                                                hidden)
     trans_mask = torch.ones_like(trans[0]).triu().unsqueeze(0)
     trans_mask = (trans_mask - 1) * -np.log(EPSILON)
     trans = trans + trans_mask
     trans = trans - trans.logsumexp(-1, keepdim=True)
     return trans, emiss, hidden
Example #11
def topk_demo(wl_dist: Tensor, ncubes: int):
    """ torch.topk() only returns the chosen index, sometimes I want to split the tensor,
        i.e., derive the other indices as well.
    """
    batch_dist, topk_idxs = wl_dist.topk(ncubes, sorted=False)  # topk_idxs: size <K>
    other_idxs = torch.arange(len(wl_dist))  # <Batch>
    other_idxs = torch.ones_like(other_idxs).byte().scatter_(-1, topk_idxs, 0)  # topk_idxs are 0, others are 1
    other_idxs = other_idxs.nonzero().squeeze(dim=-1)  # <Batch-K>
    return topk_idxs, other_idxs
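
A usage sketch for topk_demo above (the values are made up); the two returned index tensors partition the batch dimension:

import torch

wl_dist = torch.tensor([0.3, 0.9, 0.1, 0.7, 0.5])
topk_idxs, other_idxs = topk_demo(wl_dist, ncubes=2)
print(topk_idxs)   # indices of the two largest entries, e.g. tensor([1, 3])
print(other_idxs)  # the remaining indices, tensor([0, 2, 4])
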
Example #12
 def propose_log_prob(self, value):
     v = value / self._d
     result = -self._d.log()
     y = v.pow(1 / 3)
     result -= torch.log(3 * y ** 2)
     x = (y - 1) / self._c
     result -= self._c.log()
     result += Normal(torch.zeros_like(self.concentration), torch.ones_like(self.concentration)).log_prob(x)
     return result
Example #13
 def penalty(self, dis, real_data, fake_data):
     probe = self.get_probe(real_data.detach(), fake_data.detach())
     probe.requires_grad = True
     probe_logit, _ = dis(probe)
     gradients = autograd.grad(outputs=F.sigmoid(probe_logit),
                               inputs=probe,
                               grad_outputs=torch.ones_like(probe_logit))[0]
     grad_norm = gradients.view(gradients.shape[0], -1).norm(2, dim=1)
     penalty = ((grad_norm - self.target) ** 2).mean()
     return self.weight * penalty, grad_norm.mean()
Example #14
def bce(prediction, target):
    if not hasattr(target, '__len__'):
        target = torch.ones_like(prediction)*target
        if prediction.is_cuda:
            target = target.cuda()
        target = Variable(target)
    loss = torch.nn.BCELoss()
    if prediction.is_cuda:
        loss = loss.cuda()
    return loss(prediction, target)
    def test_cuda_extension(self):
        import torch_test_cuda_extension as cuda_extension

        x = torch.FloatTensor(100).zero_().cuda()
        y = torch.FloatTensor(100).zero_().cuda()

        z = cuda_extension.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))
Example #16
    def forward(self, prob, targets, infos, wt=None):
        prob = prob.clamp(min=1e-7, max=1-1e-7)
        if wt is None:
            wt = torch.ones_like(prob)
        wt1 = wt
        if config.TRAIN.CE_LOSS_WEIGHTED and self.pos_wt is not None:
            # per-element class weighting; otherwise fall back to the plain weights
            wt1 = wt * (targets.detach() * self.pos_wt + (1-targets.detach()) * self.neg_wt)

        loss = -torch.mean(wt1 * (torch.log(prob) * targets + torch.log(1-prob) * (1-targets)))

        return loss
    def test_cuda_extension(self):
        import torch_test_cpp_extension.cuda as cuda_extension

        x = torch.zeros(100, device='cuda', dtype=torch.float32)
        y = torch.zeros(100, device='cuda', dtype=torch.float32)

        z = cuda_extension.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))
Example #18
 def sample(self, sample_shape=torch.Size()):
     sample_shape = torch.Size(sample_shape)
     samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
     # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
     # (sample_shape, batch_shape, total_count)
     shifted_idx = list(range(samples.dim()))
     shifted_idx.append(shifted_idx.pop(0))
     samples = samples.permute(*shifted_idx)
     counts = samples.new(self._extended_shape(sample_shape)).zero_()
     counts.scatter_add_(-1, samples, torch.ones_like(samples))
     return counts.type_as(self.probs)
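
The counting step above (scatter_add_ with a ones-like source) is a histogram over the last dimension. A standalone sketch with made-up category draws:

import torch

samples = torch.tensor([[0, 2, 2, 1, 2]])   # category ids drawn per trial
counts = torch.zeros(1, 3)                  # (batch, num_categories)
counts.scatter_add_(-1, samples, torch.ones_like(samples, dtype=counts.dtype))
print(counts)  # tensor([[1., 1., 3.]])
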
Example #19
def valid_lb_ub(lb: Tensor, ub: Tensor) -> bool:
    """ To be valid:
        (1) Size ==
        (2) LB <= UB
    """
    if lb.size() != ub.size():
        return False

    # '<=' will return a uint8 tensor of 1 or 0 for each element, it should have all 1s.
    rel = lb <= ub
    return torch.equal(rel, torch.ones_like(rel))
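
Usage sketch for valid_lb_ub above; note that rel.all() expresses the same "all ones" check as comparing against torch.ones_like(rel).

import torch

lb = torch.tensor([0.0, 1.0])
ub = torch.tensor([0.5, 1.0])
print(valid_lb_ub(lb, ub))                   # True (elementwise lb <= ub)
print(valid_lb_ub(lb, torch.tensor([0.5])))  # False (sizes differ)
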
Example #20
def torch_ones_like(x):
    """
    Polyfill for `torch.ones_like()`.
    """
    # Work around https://github.com/pytorch/pytorch/issues/2906
    if isinstance(x, Variable):
        return Variable(torch_ones_like(x.data))
    # Support Pytorch before https://github.com/pytorch/pytorch/pull/2489
    try:
        return torch.ones_like(x)
    except AttributeError:
        return torch.ones(x.size()).type_as(x)
Example #21
 def __init__(self, loc, scale, validate_args=None):
     self.loc, self.scale = broadcast_all(loc, scale)
     finfo = _finfo(self.loc)
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
         base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
     else:
         batch_shape = self.scale.size()
         base_dist = Uniform(self.loc.new(self.loc.size()).fill_(finfo.tiny), 1 - finfo.eps)
     transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
                   ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
     super(Gumbel, self).__init__(base_dist, transforms, validate_args=validate_args)
Example #22
 def calc_gen_loss(self, input_fake):
     # calculate the loss to train G
     outs0 = self.forward(input_fake)
     loss = 0
     for it, (out0) in enumerate(outs0):
         if self.gan_type == 'lsgan':
             loss += torch.mean((out0 - 1)**2) # LSGAN
         elif self.gan_type == 'nsgan':
             all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False)
             loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1))
         else:
             assert 0, "Unsupported GAN type: {}".format(self.gan_type)
     return loss
Example #23
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Parameters
        ----------
        predictions : ``torch.Tensor``, required.
            A tensor of predictions of shape (batch_size, ..., num_classes).
        gold_labels : ``torch.Tensor``, required.
            A tensor of integer class label of shape (batch_size, ...). It must be the same
            shape as the ``predictions`` tensor without the ``num_classes`` dimension.
        mask: ``torch.Tensor``, optional (default = None).
            A masking tensor the same size as ``gold_labels``.
        """
        predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)

        num_classes = predictions.size(-1)
        if (gold_labels >= num_classes).any():
            raise ConfigurationError("A gold label passed to F1Measure contains an id >= {}, "
                                     "the number of classes.".format(num_classes))
        if mask is None:
            mask = torch.ones_like(gold_labels)
        mask = mask.float()
        gold_labels = gold_labels.float()
        positive_label_mask = gold_labels.eq(self._positive_label).float()
        negative_label_mask = 1.0 - positive_label_mask

        argmax_predictions = predictions.max(-1)[1].float().squeeze(-1)

        # True Negatives: correct non-positive predictions.
        correct_null_predictions = (argmax_predictions !=
                                    self._positive_label).float() * negative_label_mask
        self._true_negatives += (correct_null_predictions.float() * mask).sum()

        # True Positives: correct positively labeled predictions.
        correct_non_null_predictions = (argmax_predictions ==
                                        self._positive_label).float() * positive_label_mask
        self._true_positives += (correct_non_null_predictions * mask).sum()

        # False Negatives: incorrect negatively labeled predictions.
        incorrect_null_predictions = (argmax_predictions !=
                                      self._positive_label).float() * positive_label_mask
        self._false_negatives += (incorrect_null_predictions * mask).sum()

        # False Positives: incorrect positively labeled predictions
        incorrect_non_null_predictions = (argmax_predictions ==
                                          self._positive_label).float() * negative_label_mask
        self._false_positives += (incorrect_non_null_predictions * mask).sum()
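
A minimal standalone sketch of the same counting logic, outside the metric class, with the positive label assumed to be 1 and made-up tensors:

import torch

predictions = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.3, 0.7]])  # (batch, num_classes)
gold_labels = torch.tensor([1, 0, 0]).float()
mask = torch.ones_like(gold_labels)
positive = gold_labels.eq(1).float()
negative = 1.0 - positive
argmax = predictions.max(-1)[1].float()
true_positives = ((argmax == 1).float() * positive * mask).sum()
false_positives = ((argmax == 1).float() * negative * mask).sum()
false_negatives = ((argmax != 1).float() * positive * mask).sum()
print(true_positives.item(), false_positives.item(), false_negatives.item())  # 1.0 1.0 0.0
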
Example #24
def _make_grads(outputs, grads):
    new_grads = []
    for out, grad in zip(outputs, grads):
        if isinstance(grad, torch.Tensor):
            new_grads.append(grad)
        elif grad is None:
            if out.requires_grad:
                if out.numel() != 1:
                    raise RuntimeError("grad can be implicitly created only for scalar outputs")
                new_grads.append(torch.ones_like(out))
            else:
                new_grads.append(None)
        else:
            raise TypeError("gradients can be either Tensors or None, but got " +
                            type(grad).__name__)
    return tuple(new_grads)
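
This helper is why backward() on a scalar output needs no explicit gradient: the implicit seed is torch.ones_like(out), while non-scalar outputs require the seed to be passed in. A quick illustration:

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
(3 * x).sum().backward()        # scalar output: implicit grad seed of ones
print(x.grad)                   # tensor([3., 3.])

y = 3 * x                       # non-scalar output
y.backward(torch.ones_like(y))  # the seed must be passed explicitly
print(x.grad)                   # tensor([6., 6.]) since gradients accumulate
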
Example #25
    def calc_dis_loss(self, input_fake, input_real):
        # calculate the loss to train D
        outs0 = self.forward(input_fake)
        outs1 = self.forward(input_real)
        loss = 0

        for it, (out0, out1) in enumerate(zip(outs0, outs1)):
            if self.gan_type == 'lsgan':
                loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)
            elif self.gan_type == 'nsgan':
                all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False)
                all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False)
                loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) +
                                   F.binary_cross_entropy(F.sigmoid(out1), all1))
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)
        return loss
Example #26
    def forward(self, input, adj):
        h = torch.mm(input, self.W)
        N = h.size()[0]

        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime
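
The key step above is masking non-edges with a large negative constant before the softmax, so they get essentially zero attention weight. A sketch of just that step, with a made-up score matrix and adjacency:

import torch
import torch.nn.functional as F

e = torch.randn(3, 3)                                   # raw attention scores
adj = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]])   # adjacency, 1 = edge
zero_vec = -9e15 * torch.ones_like(e)
attention = F.softmax(torch.where(adj > 0, e, zero_vec), dim=1)
print(attention)  # rows sum to 1; entries where adj == 0 are ~0
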
Example #27
def build_sequences(sequences, nenvs, nsteps, depth, return_mask=False, offset=0):
    # sequences are bs x size, containing e.g. rewards, actions, state reps
    # returns bs x depth x size processed sequences with a sliding window set by 'depth', padded with 0's
    # if return_mask=True also returns a mask showing where the sequences were padded
    # This can be used to produce targets for tree outputs, from the true observed sequences
    sequences = [s.view(nenvs, nsteps, -1) for s in sequences]
    if return_mask:
        mask = torch.ones_like(sequences[0]).float()
        sequences.append(mask)
    sequences = [F.pad(s, (0, 0, 0, depth+offset, 0, 0), mode="constant", value=0).data for s in sequences]
    proc_sequences = []
    for seq in sequences:
        proc_seq = []
        for env in range(seq.shape[0]):
            for t in range(nsteps):
                proc_seq.append(seq[env, t+offset:t+offset+depth, :])
        proc_sequences.append(torch.stack(proc_seq))
    return proc_sequences
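
A shape-level usage sketch for build_sequences above (nenvs, nsteps, and the reward values are made up); each of the nenvs*nsteps rows becomes a depth-long window:

import torch

rewards = torch.arange(12.0).view(12, 1)   # bs = nenvs * nsteps = 12, size = 1
(proc_rewards,) = build_sequences([rewards], nenvs=3, nsteps=4, depth=2)
print(proc_rewards.shape)                   # torch.Size([12, 2, 1])
print(proc_rewards[0].squeeze(-1))          # tensor([0., 1.]), the window starting at t=0
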
    def test_jit_cuda_extension(self):
        # NOTE: The name of the extension must equal the name of the module.
        module = torch.utils.cpp_extension.load(
            name='torch_test_cuda_extension',
            sources=[
                'cpp_extensions/cuda_extension.cpp',
                'cpp_extensions/cuda_extension.cu'
            ],
            extra_cuda_cflags=['-O2'],
            verbose=True)

        x = torch.FloatTensor(100).zero_().cuda()
        y = torch.FloatTensor(100).zero_().cuda()

        z = module.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))
    def _get_checklist_info(self,
                            agenda: torch.LongTensor,
                            all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor,
                                                                             torch.Tensor,
                                                                             torch.Tensor]:
        """
        Takes an agenda and a list of all actions and returns a target checklist against which the
        checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
        and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
        checklist loss computation. If ``self.penalize_non_agenda_actions`` is set to ``True``,
        ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
        ``False``, indices of all terminals that are not in the agenda will be masked.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRuleArray]``
            All actions for one instance.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRuleArray, a tuple where the first item is the production
            # rule string.
            if action[0] in self._terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        # (num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
        if self._penalize_non_agenda_actions:
            # All terminal actions are relevant
            checklist_mask = torch.ones_like(target_checklist)
        else:
            checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
Example #30
    def forward(self, prob, targets, infos, wt=None):
        if wt is None:
            wt = torch.ones_like(prob)
        prob = prob.clamp(min=1e-7, max=1-1e-7)
        with torch.no_grad():
            prob_diff_wt = torch.abs((prob - targets) * wt) ** config.TRAIN.RHEM_POWER
            idx = torch.multinomial(prob_diff_wt.view(-1), config.TRAIN.RHEM_BATCH_SIZE, replacement=True)
            # hist = np.histogram(idx.cpu().numpy(), np.arange(torch.numel(prob)+1))[0]
            # hist = np.reshape(hist, prob.shape)
            # pos = np.where(hist == np.max(hist))
            # row = pos[0][0]
            # col = pos[1][0]
            # print np.max(hist), prob[row, col].item(), targets[row, col].item(), \
            #     default.term_list[col], int(self.pos_wt[col].item()), infos[row][0]#, prob_diff_wt.mean(0)[col].item()

        targets = targets.view(-1)[idx]
        prob = prob.view(-1)[idx]
        loss_per_smp = - (torch.log(prob) * targets + torch.log(1-prob) * (1-targets))
        loss = loss_per_smp.mean()

        return loss
Example #31
def atom_distances(
    positions,
    neighbors,
    cell=None,
    cell_offsets=None,
    return_vecs=False,
    return_directions=False,
    neighbor_mask=None,
):
    """
    Use advanced torch indexing to compute differentiable distances
    of every central atom to its relevant neighbors. Indices of the
    neighbors to consider are stored in neighbors.

    Args:
        positions (torch.Tensor): Atomic positions, differentiable torch
            Variable (B x N_at x 3)
        neighbors (torch.Tensor): Indices of neighboring
            atoms (B x N_at x N_nbh)
        cell (torch.Tensor): cell for periodic systems (B x 3 x 3)
        cell_offsets (torch.Tensor): offset of atom in cell
            coordinates (B x N_at x N_nbh x 3)
        return_directions (bool): If true, also return direction cosines.
        neighbor_mask (torch.Tensor, optional): Boolean mask for neighbor
            positions. Required for the stable computation of forces in
            molecules with different sizes.

    Returns:
        torch.Tensor: Distances of every atom to its
            neighbors (B x N_at x N_nbh)
        torch.Tensor: Direction cosines of every atom to its
            neighbors (B x N_at x N_nbh x 3) (optional)
    """

    # Construct auxiliary index vector
    n_batch = positions.size()[0]
    idx_m = torch.arange(n_batch, device=positions.device, dtype=torch.long)[
        :, None, None
    ]
    # Get atomic positions of all neighboring indices
    pos_xyz = positions[idx_m, neighbors[:, :, :], :]

    # Subtract positions of central atoms to get distance vectors
    dist_vec = pos_xyz - positions[:, :, None, :]

    # add cell offset
    if cell is not None:
        B, A, N, D = cell_offsets.size()
        cell_offsets = cell_offsets.view(B, A * N, D)
        offsets = cell_offsets.bmm(cell)
        offsets = offsets.view(B, A, N, D)
        dist_vec += offsets

    # Compute vector lengths
    distances = torch.norm(dist_vec, 2, 3)

    if neighbor_mask is not None:
        # Avoid problems with zero distances in forces (instability of the square
        # root derivative at 0). This way is necessary, as gradients do not
        # work with in-place operations, such as e.g.
        # -> distances[mask==0] = 0.0
        tmp_distances = torch.zeros_like(distances)
        tmp_distances[neighbor_mask != 0] = distances[neighbor_mask != 0]
        distances = tmp_distances

    if return_directions or return_vecs:
        tmp_distances = torch.ones_like(distances)
        tmp_distances[neighbor_mask != 0] = distances[neighbor_mask != 0]

        if return_directions:
            dist_vec = dist_vec / tmp_distances[:, :, :, None]
        return distances, dist_vec
    return distances
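
A tiny usage sketch for atom_distances above, with one batch of three atoms and two neighbor indices per atom (coordinates made up):

import torch

positions = torch.tensor([[[0.0, 0.0, 0.0],
                           [1.0, 0.0, 0.0],
                           [0.0, 2.0, 0.0]]])          # (B=1, N_at=3, 3)
neighbors = torch.tensor([[[1, 2], [0, 2], [0, 1]]])   # (B, N_at, N_nbh=2)
distances = atom_distances(positions, neighbors)
print(distances)
# tensor([[[1.0000, 2.0000],
#          [1.0000, 2.2361],
#          [2.0000, 2.2361]]])
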
Example #32
'''
Numpy Bridge
'''
### Convert a Torch tensor to a NumPy array
a = torch.ones(5)
print(a)
b = a.numpy()
print(b)
a.add_(1)
print(a)
print(b)

### Convert a NumPy array to a Torch tensor
a = np.ones(5)
b = torch.as_tensor(a)
c = torch.from_numpy(a)  # either method works
np.add(a, 1, out=a)
print(a)
print(b)  # both change, since they share memory

'''
Move a tensor to a specific device (CPU or GPU)
'''
if torch.cuda.is_available():
    device = torch.device("cuda")
    x = torch.rand(4, 4)
    y = torch.ones_like(x, device=device)  # create a tensor directly on the GPU
    x = x.to(device)                       # use .to() to move data to the GPU
    z = x + y
    print(z)
    print(z.to("cpu", torch.double))       # .to() can also move data back to the CPU
Example #33
    def forward(
        self,
        input,
        mask=None,
        kv=None,
        past_key_value_state=None,
        query_length=None,
        use_cache=False,
        relpos=None,
    ):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).
        """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        # past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head)
        bs, qlen, dim = input.size()

        if past_key_value_state is not None:
            assert self.is_decoder is True, "Encoder cannot cache past key value states"
            assert (
                len(past_key_value_state) == 2
            ), "past_key_value_state should have 2 past states: keys and values. Got {} past states".format(
                len(past_key_value_state))
            real_qlen = qlen + past_key_value_state[0].shape[
                2] if query_length is None else query_length
        else:
            real_qlen = qlen

        if kv is None:
            klen = real_qlen
        else:
            klen = kv.size(1)

        def shape(x):
            """  projection """
            return x.view(bs, -1, self.n_heads, self.d_kv).transpose(1, 2)

        def unshape(x):
            """  compute context """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim)

        q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)

        if kv is None:
            k = shape(self.k(input))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v(input))  # (bs, n_heads, qlen, dim_per_head)
        elif past_key_value_state is None:
            k = v = kv
            k = shape(self.k(k))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v(v))  # (bs, n_heads, qlen, dim_per_head)

        if past_key_value_state is not None:
            if kv is None:
                k_, v_ = past_key_value_state
                k = torch.cat([k_, k],
                              dim=2)  # (bs, n_heads, klen, dim_per_head)
                v = torch.cat([v_, v],
                              dim=2)  # (bs, n_heads, klen, dim_per_head)
            else:
                k, v = past_key_value_state

        if self.is_decoder and use_cache is True:
            present_key_value_state = ((k, v), )
        else:
            present_key_value_state = (None, )

        scores = torch.einsum("bnqd,bnkd->bnqk", q,
                              k)  # (bs, n_heads, qlen, klen)
        scores = scores / math.sqrt(self.d_kv)

        if relpos is not None:
            assert self.rel_emb is not None, "can't process relpos because rel_emb is not initialized"
            relpos_scores = self.rel_emb.compute_scores(
                q, relpos
            )  # (bs, n_heads, qlen, dim)x(bs, qlen, klen)->(bs, n_heads, qlen, klen)
            scores = scores + relpos_scores

        if mask is not None:
            if not self.cross_attention and self.is_decoder:
                maskslots = (
                    (torch.arange(mask.size(3), device=mask.device) + 1) %
                    2).float()[None, :]
                maskslots = maskslots + torch.eye(mask.size(3),
                                                  device=mask.device)
                maskslots = torch.where(maskslots > 0,
                                        torch.ones_like(maskslots) * 0,
                                        torch.ones_like(maskslots) * -1e6)
                maskslots = maskslots[-mask.size(2):, :]
                maskslots = maskslots[None, None, :]
                mask = mask + maskslots
            scores = scores + mask

        weights = F.softmax(scores.float(), dim=-1).type_as(
            scores)  # (bs, n_heads, qlen, klen)
        weights = self.dropout(weights)  # (bs, n_heads, qlen, klen)

        context = torch.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)

        if relpos is not None:
            context_rel = self.rel_emb.compute_context(
                weights, relpos
            )  # (bs, n_heads, qlen, klen)x(bs, qlen, klen) -> (bs, n_heads, qlen, dim)
            context = context + context_rel

        context = unshape(context)  # (bs, qlen, dim)

        _context = context
        context = self.o(context)

        if self.config.vib_att:
            _context = torch.relu(context) * torch.sigmoid(
                self.o_gate(_context))
            _context = self.o_ln(_context)
            mu, logvar = self.o_mu(_context), self.o_logvar(_context)

            if self.training:
                ret = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
            else:
                ret = mu
            context = ret

            priorkl = torch.zeros(ret.size(0), ret.size(1), device=ret.device)
            if self.training:
                priorkl = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(),
                                           dim=-1)  # (batsize, seqlen)
                # priorkls = priorkls * mask.float()        # TOD: mask !!!
                # priorkl = priorkls.sum(-1)

            outputs = (context, ) + present_key_value_state + (priorkl, )
        else:
            outputs = (context, ) + present_key_value_state

        if self.output_attentions:
            outputs = outputs + (weights, )
        return outputs
z = x.view(-1, 8)  # the size -1 is inferred from other dimensions
print(x.size(), y.size(), z.size())

# to get the value of a one-element tensor as a Python number
x = torch.randn(1)
print(x)
print(x.item())

# Tensors support operations including transposing, indexing, slicing,
# mathematical operations, linear algebra, and random numbers.

# Let us run this cell only if CUDA is available.
# We will use ``torch.device`` objects to move tensors in and out of GPU.
if torch.cuda.is_available():
    device = torch.device("cuda")          # a CUDA device object
    y = torch.ones_like(x, device=device)  # directly create a tensor on GPU
    x = x.to(device)                       # or just use strings ``.to("cuda")``
    z = x + y
    print(z)
    print(z.to("cpu", torch.double))       # ``.to`` can also change dtype together!

# converting a torch tensor to a numpy array
a = torch.ones(5)
b = a.numpy()
print(b)

# output after the in-place change (the numpy array shares memory)
a.add_(1)
print(a)
print(b)

# converting a numpy array to a torch tensor
import numpy as np
a = np.ones(5)
Example #35
def hard_sigmoid(x):
    return torch.min(torch.max(x, torch.zeros_like(x)), torch.ones_like(x))
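
For inputs already in [0, 1] this is the identity; outside that range it saturates, i.e. it clamps elementwise to [0, 1]:

import torch

x = torch.tensor([-2.0, 0.3, 1.5])
print(hard_sigmoid(x))  # tensor([0.0000, 0.3000, 1.0000])
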
Example #36
    def model_fit(self,epoch):

        
        best_acc = 0

        lamb =0.1

        if ((epoch+1)%self.down_period) ==0 and (self.lr>1e-4) :
            self.lr = self.lr*self.lr_decay_rate

        if epoch>80:
            self.lr = self.lr*0.99
        all_parameters_h = sum([list(h.parameters()) for h in self.hypothesis], [])

        self.optimizer = optim.Adam(list(self.FE.parameters()) + list(all_parameters_h),
                                       lr=self.lr,  weight_decay = 1e-5)


        sem_mtrx = np.zeros((self.num_tsk, self.num_tsk))

        loss_mtrx_hypo_vlue = np.zeros((self.num_tsk, self.num_tsk))
        weigh_loss_hypo_vlue, correct_hypo = np.zeros(self.num_tsk), np.zeros(self.num_tsk)
        Total_loss = 0
        n_batch = 0

        # set train mode
        self.FE.train()
        for t in range(self.num_tsk):
            self.hypothesis[t].train()


        for tasks_batch in zip(*self.train_loader ):
            Loss_1, Loss_2 = 0, 0
            semantic_loss = 0
            n_batch += 1
            # data = (x,y)
            inputs = torch.cat([batch[0] for batch in tasks_batch])


            btch_sz = len(tasks_batch[0][0])
            targets = torch.cat([batch[1] for batch in tasks_batch])

            # inputs = (x1,...,xT)  targets = (y1,...,yT)
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            features = self.FE(inputs)
            features = features.view(features.size(0), -1)

            
            for t in range(self.num_tsk):
                w = torch.tensor([np.tile(self.alpha[t, i], reps=len(data[0])) for i, data in enumerate(tasks_batch)],
                                 dtype=torch.float).view(-1)
                w = w.to(self.device)

                label_prob, fea = self.hypothesis[t](features)

                pred = label_prob[t * (btch_sz):(t + 1) * btch_sz].argmax(dim=1, keepdim=True)
                correct_hypo[t] += (
                (pred.eq(targets[t * btch_sz:(t + 1) * btch_sz].view_as(pred)).sum().item()) / btch_sz)

                hypo_loss = torch.mean(w * F.cross_entropy(label_prob, targets, reduction='none'))


                Loss_1 += hypo_loss
                weigh_loss_hypo_vlue[t] += hypo_loss.item()



                for k in range(t + 1, self.num_tsk):
                    alpha_domain = torch.tensor(self.alpha[t, k] + self.alpha[k, t], dtype=torch.float)
                    alpha_domain = alpha_domain.to(self.device)
                    

                    sem_fea_t = fea[t * btch_sz:(t + 1) * btch_sz]
                    sem_fea_k = fea[k * btch_sz:(k + 1) * btch_sz]


                        
                    labels_t = targets[t * btch_sz:(t + 1) * btch_sz]
                    labels_k = targets[k * btch_sz:(k + 1) * btch_sz]

                        


                    _,d = sem_fea_t.shape


                    # image number in each class
                    ones = torch.ones_like(labels_t, dtype=torch.float)
                    zeros = torch.zeros(self.n_class)
                    if self.cudable:
                        zeros = zeros.cuda()
                    # samples per class
                    t_n_classes = zeros.scatter_add(0, labels_t, ones)
                    k_n_classes = zeros.scatter_add(0, labels_k, ones)

                    # image number cannot be 0, when calculating centroids
                    ones = torch.ones_like(t_n_classes)
                    t_n_classes = torch.max(t_n_classes, ones)
                    k_n_classes = torch.max(k_n_classes, ones)

                    # calculating centroids, sum and divide
                    zeros = torch.zeros(self.n_class, d)
                    if self.cudable:
                        zeros = zeros.cuda()
                    t_sum_feature = zeros.scatter_add(0, torch.transpose(labels_t.repeat(d, 1), 1, 0), sem_fea_t)
                    k_sum_feature = zeros.scatter_add(0, torch.transpose(labels_k.repeat(d, 1), 1, 0), sem_fea_k)
                    current_t_centroid = torch.div(t_sum_feature, t_n_classes.view(self.n_class, 1))
                    current_k_centroid = torch.div(k_sum_feature, k_n_classes.view(self.n_class, 1))

                    # Moving Centroid
                    decay = self.decay
                    t_centroid = (1-decay) * self.centroids[t] + decay * current_t_centroid
                    k_centroid = (1-decay) * self.centroids[k] + decay * current_k_centroid
                        
                    s_loss = self.MSEloss(t_centroid, k_centroid)
                    semantic_loss += torch.mean(s_loss)



                    self.centroids[t] = t_centroid.detach()
                    self.centroids[k]= k_centroid.detach()



            Loss_2 = torch.mean(alpha_domain * semantic_loss)
            Loss = torch.mean(Loss_1) + lamb * Loss_2 * (1.0 / self.num_tsk)

            self.optimizer.zero_grad()
            Loss.backward(retain_graph=True)
            self.optimizer.step()

        if epoch>0:
            c_2, c_3 = 1 * np.ones(self.num_tsk), self.c3_value * np.ones(self.num_tsk)
            self.alpha = min_alphacvx(self.alpha.T, c_2, c_3, loss_mtrx_hypo_vlue.T, sem_mtrx.T)

            self.alpha = self.alpha.T


        Total_loss += Loss.item()

        return  Total_loss
 def __init__(self, shape_like):
     super().__init__()
     if isinstance(shape_like, torch.Tensor):
         self.value = torch.ones_like(shape_like)
     else:
         self.value = torch.ones_like(torch.tensor([shape_like]))
Example #38
    def forward(self, data):
        batchSize = data['image'].shape[0]
        width = data['image'].shape[3]
        height = data['image'].shape[2]

        pixelFeatureMap = self.mainModel(data['image'])

        # Compute the embedding based on the symbols provided. symbols are usually traces of which lines of code got executed.
        symbolEmbeddings = self.symbolEmbedding(data['symbolIndexes'], data['symbolOffsets'], per_sample_weights=data['symbolWeights'])

        # Concatenate the step number with the rest of the additional features
        additionalFeaturesWithStep = torch.cat([torch.log10(data['stepNumber'] + torch.ones_like(data['stepNumber'])).reshape([-1, 1]), self.stampProjection(symbolEmbeddings)], dim=1)

        # Append the stamp layer along side the pixel-by-pixel features
        stamp = additionalFeaturesWithStep.reshape([-1, self.config['additional_features_stamp_depth_size'],
                                                    self.config['additional_features_stamp_edge_size'],
                                                    self.config['additional_features_stamp_edge_size']])

        featureMapHeight = pixelFeatureMap.shape[2]
        featureMapWidth = pixelFeatureMap.shape[3]
        stampTiler = stamp.repeat([1, 1, int(featureMapHeight / self.config['additional_features_stamp_edge_size']) + 1, int(featureMapWidth / self.config['additional_features_stamp_edge_size']) + 1])
        stampLayer = stampTiler[:, :, :featureMapHeight, :featureMapWidth].reshape([-1, self.config['additional_features_stamp_depth_size'], featureMapHeight, featureMapWidth])

        mergedPixelFeatureMap = torch.cat([stampLayer, pixelFeatureMap], dim=1)

        outputDict = {}

        if data['computeRewards']:
            presentRewards = self.presentRewardConvolution(mergedPixelFeatureMap) * data['pixelActionMaps'] + (1.0 - data['pixelActionMaps']) * self.config['reward_impossible_action']
            discountFutureRewards = self.discountedFutureRewardConvolution(mergedPixelFeatureMap) * data['pixelActionMaps'] + (1.0 - data['pixelActionMaps']) * self.config['reward_impossible_action']

            totalReward = (presentRewards + discountFutureRewards)

            outputDict['presentRewards'] = presentRewards
            outputDict['discountFutureRewards'] = discountFutureRewards
        else:
            totalReward = None

        if data["outputStamp"]:
            outputDict["stamp"] = stamp.detach()

        if data['outputFutureSymbolEmbedding']:
            # Compute the embedding based on the symbols provided for the future execution trace
            decayingFutureSymbolEmbedding = self.symbolEmbedding(data['decayingFutureSymbolIndexes'],
                                                    data['decayingFutureSymbolOffsets'],
                                                    per_sample_weights=data['decayingFutureSymbolWeights'])


            outputDict['decayingFutureSymbolEmbedding'] = decayingFutureSymbolEmbedding

        if data["computeActionProbabilities"]:
            actorLogProbs = self.actorConvolution(mergedPixelFeatureMap)

            actorProbExp = torch.exp(actorLogProbs) * data['pixelActionMaps']
            actorProbSums = torch.sum(actorProbExp.reshape(shape=[-1, width * height * self.numActions]), dim=1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
            actorProbSums = torch.max((actorProbSums == 0) * 1e-8, actorProbSums)
            actorActionProbs = actorProbExp / actorProbSums
            actorActionProbs = actorActionProbs.reshape([-1, self.numActions, height, width])

            outputDict["actionProbabilities"] = actorActionProbs

        if data['computeStateValues']:
            shrunkTrainingWidth = int(self.config['training_crop_width'] / 8)
            shrunkTrainingHeight = int(self.config['training_crop_width'] / 8)

            centerCropLeft = torch.div((torch.div(width, 8) - shrunkTrainingWidth), 2)
            centerCropTop = torch.div((torch.div(height, 8) - shrunkTrainingHeight), 2)

            centerCropRight = torch.add(centerCropLeft, shrunkTrainingWidth)
            centerCropBottom = torch.add(centerCropTop, shrunkTrainingHeight)

            croppedPixelFeatureMap = pixelFeatureMap[:, :, centerCropTop:centerCropBottom, centerCropLeft:centerCropRight]

            flatFeatureMap = torch.cat([additionalFeaturesWithStep.reshape(shape=[batchSize, -1]), croppedPixelFeatureMap.reshape(shape=[batchSize, -1])], dim=1)

            stateValuePredictions = self.stateValueLinear(flatFeatureMap)

            outputDict['stateValues'] = stateValuePredictions

        if data['computeAdvantageValues']:
            advantageValues = self.advantageConvolution(mergedPixelFeatureMap) * data['pixelActionMaps'] + (1.0 - data['pixelActionMaps']) * self.config['reward_impossible_action']
            outputDict['advantage'] = advantageValues

        if data['computeExtras']:
            if 'action_type' in data:
                action_types = data['action_type']
                action_xs = data['action_x']
                action_ys = data['action_y']
            else:
                action_types = []
                action_xs = []
                action_ys = []
                for sampleReward in totalReward:
                    action_type = sampleReward.reshape([self.numActions, width * height]).max(dim=1)[0].argmax(0)
                    action_types.append(action_type)

                    action_y = sampleReward[action_type].max(dim=1)[0].argmax(0)
                    action_ys.append(action_y)

                    action_x = sampleReward[action_type, action_y].argmax(0)
                    action_xs.append(action_x)

            forwardFeaturesForAuxillaryLosses = []
            for sampleIndex, action_type, action_x, action_y in zip(range(len(action_types)), action_types, action_xs, action_ys):
                featuresForAuxillaryLosses = mergedPixelFeatureMap[sampleIndex, :, int(action_y / 8), int(action_x / 8)].unsqueeze(0)
                forwardFeaturesForAuxillaryLosses.append(featuresForAuxillaryLosses)

            joinedFeatures = torch.cat(forwardFeaturesForAuxillaryLosses, dim=0)

            if self.config['enable_trace_prediction_loss']:
                outputDict['predictedTraces'] = self.predictedExecutionTraceLinear(joinedFeatures)
            if self.config['enable_execution_feature_prediction_loss']:
                outputDict['predictedExecutionFeatures'] = self.predictedExecutionFeaturesLinear(joinedFeatures)
            if self.config['enable_cursor_prediction_loss']:
                outputDict['predictedCursor'] = self.predictedCursorLinear(joinedFeatures)

        return outputDict
Example #39
    def forward(self, q_pts, s_pts, neighb_inds, x):

        ###################
        # Offset generation
        ###################

        if self.deformable:

            # Get offsets with a KPConv that only takes part of the features
            self.offset_features = self.offset_conv(q_pts, s_pts, neighb_inds,
                                                    x) + self.offset_bias

            if self.modulated:

                # Get offset (in normalized scale) from features
                unscaled_offsets = self.offset_features[:, :self.p_dim *
                                                        self.K]
                unscaled_offsets = unscaled_offsets.view(
                    -1, self.K, self.p_dim)

                # Get modulations
                modulations = 2 * torch.sigmoid(
                    self.offset_features[:, self.p_dim * self.K:])

            else:

                # Get offset (in normalized scale) from features
                unscaled_offsets = self.offset_features.view(
                    -1, self.K, self.p_dim)

                # No modulations
                modulations = None

            # Rescale offset for this layer
            offsets = unscaled_offsets * self.KP_extent

        else:
            offsets = None
            modulations = None

        ######################
        # Deformed convolution
        ######################

        # Add a fake point in the last row for shadow neighbors
        s_pts = torch.cat((s_pts, torch.zeros_like(s_pts[:1, :]) + 1e6), 0)

        # Get neighbor points [n_points, n_neighbors, dim]
        neighbors = s_pts[neighb_inds, :]

        # Center every neighborhood
        neighbors = neighbors - q_pts.unsqueeze(1)

        # Apply offsets to kernel points [n_points, n_kpoints, dim]
        if self.deformable:
            self.deformed_KP = offsets + self.kernel_points
            deformed_K_points = self.deformed_KP.unsqueeze(1)
        else:
            deformed_K_points = self.kernel_points

        # Get all difference matrices [n_points, n_neighbors, n_kpoints, dim]
        neighbors.unsqueeze_(2)
        differences = neighbors - deformed_K_points

        # Get the square distances [n_points, n_neighbors, n_kpoints]
        sq_distances = torch.sum(differences**2, dim=3)

        # Optimization by ignoring points outside a deformed KP range
        if self.deformable:

            # Save distances for loss
            self.min_d2, _ = torch.min(sq_distances, dim=1)

            # Boolean of the neighbors in range of a kernel point [n_points, n_neighbors]
            in_range = torch.any(sq_distances < self.KP_extent**2,
                                 dim=2).type(torch.int32)

            # New value of max neighbors
            new_max_neighb = torch.max(torch.sum(in_range, dim=1))

            # For each row of neighbors, indices of the ones that are in range [n_points, new_max_neighb]
            neighb_row_bool, neighb_row_inds = torch.topk(
                in_range, new_max_neighb.item(), dim=1)

            # Gather new neighbor indices [n_points, new_max_neighb]
            new_neighb_inds = neighb_inds.gather(1,
                                                 neighb_row_inds,
                                                 sparse_grad=False)

            # Gather new distances to KP [n_points, new_max_neighb, n_kpoints]
            neighb_row_inds.unsqueeze_(2)
            neighb_row_inds = neighb_row_inds.expand(-1, -1, self.K)
            sq_distances = sq_distances.gather(1,
                                               neighb_row_inds,
                                               sparse_grad=False)

            # New shadow neighbors have to point to the last shadow point
            new_neighb_inds *= neighb_row_bool
            new_neighb_inds -= (neighb_row_bool.type(torch.int64) -
                                1) * int(s_pts.shape[0] - 1)
        else:
            new_neighb_inds = neighb_inds

        # Get Kernel point influences [n_points, n_kpoints, n_neighbors]
        if self.KP_influence == 'constant':
            # Every point get an influence of 1.
            all_weights = torch.ones_like(sq_distances)
            all_weights = torch.transpose(all_weights, 1, 2)

        elif self.KP_influence == 'linear':
            # Influence decrease linearly with the distance, and get to zero when d = KP_extent.
            all_weights = torch.clamp(
                1 - torch.sqrt(sq_distances) / self.KP_extent, min=0.0)
            all_weights = torch.transpose(all_weights, 1, 2)

        elif self.KP_influence == 'gaussian':
            # Influence in gaussian of the distance.
            sigma = self.KP_extent * 0.3
            all_weights = radius_gaussian(sq_distances, sigma)
            all_weights = torch.transpose(all_weights, 1, 2)
        else:
            raise ValueError(
                'Unknown influence function type (config.KP_influence)')

        # In case of closest mode, only the closest KP can influence each point
        if self.aggregation_mode == 'closest':
            neighbors_1nn = torch.argmin(sq_distances, dim=2)
            all_weights *= torch.transpose(
                nn.functional.one_hot(neighbors_1nn, self.K), 1, 2)

        elif self.aggregation_mode != 'sum':
            raise ValueError(
                "Unknown convolution mode. Should be 'closest' or 'sum'")

        # Add a zero feature for shadow neighbors
        x = torch.cat((x, torch.zeros_like(x[:1, :])), 0)

        # Get the features of each neighborhood [n_points, n_neighbors, in_fdim]
        neighb_x = gather(x, new_neighb_inds)

        # Apply distance weights [n_points, n_kpoints, in_fdim]
        weighted_features = torch.matmul(all_weights, neighb_x)

        # Apply modulations
        if self.deformable and self.modulated:
            weighted_features *= modulations.unsqueeze(2)

        # Apply network weights [n_kpoints, n_points, out_fdim]
        weighted_features = weighted_features.permute((1, 0, 2))
        kernel_outputs = torch.matmul(weighted_features, self.weights)

        # Convolution sum [n_points, out_fdim]
        return torch.sum(kernel_outputs, dim=0)
Example #40
    def apply_rule(self, node: AlignedActionTree, rule: Union[str, int]):
        # if node.label() not in self.query_encoder.grammar.rules_by_type \
        #         or rule not in self.query_encoder.grammar.rules_by_type[node.label()]:
        #     raise Exception("something wrong")
        #     return
        self.nn_states["prev_action"] = torch.ones_like(self.nn_states["prev_action"]) \
                                        * self.query_encoder.vocab_actions[rule]
        self.out_rules.append(rule)
        assert (node == self.open_nodes[0])
        if isinstance(rule, str):
            ruleid = self.query_encoder.vocab_actions[rule]
            rulestr = rule
        elif isinstance(rule, int):
            ruleid = rule
            rulestr = self.query_encoder.vocab_actions(rule)

        head, body = rulestr.split(" -> ")
        func_splits = body.split(" :: ")
        sibl_splits = body.split(" -- ")

        if len(sibl_splits) > 1:
            raise Exception("sibling rules no longer supported")

        self.open_nodes.pop(0)

        if node.label(
        )[-1] in "*+" and body != f"{head}:END@":  # variable number of children
            # create new sibling node
            parent = node.parent()
            i = len(parent)

            new_sibl_node = AlignedActionTree(node.label(), [])
            parent.append(new_sibl_node)

            # manage open nodes
            self.open_nodes = ([new_sibl_node]
                               if (new_sibl_node.label() in self.query_encoder.grammar.rules_by_type
                                   or new_sibl_node.label()[:-1] in self.query_encoder.grammar.rules_by_type)
                               else []) \
                              + self.open_nodes

            if self.use_gold:
                gold_child = parent._align[i]
                new_sibl_node._align = gold_child

        if len(func_splits) > 1:
            rule_arg, rule_inptypes = func_splits
            rule_inptypes = rule_inptypes.split(" ")

            # replace label of tree
            node.set_label(rule_arg)
            node.set_action(rule)

            # align to gold
            if self.use_gold:
                gold_children = node._align[:]

            # create children nodes as open non-terminals
            for i, child in enumerate(rule_inptypes):
                child_node = AlignedActionTree(child, [])
                node.append(child_node)

                if self.use_gold:
                    child_node._align = gold_children[i]

            # manage open nodes
            self.open_nodes = [child_node for child_node in node if child_node.label() in self.query_encoder.grammar.rules_by_type]\
                              + self.open_nodes
        else:  # terminal
            node.set_label(body)
            node.set_action(rule)
Example #41
    def parameterize(self, u, x, t, *additional_tensors):
        r"""Re-parameterizes outputs such that the initial and boundary conditions are satisfied.

        The Initial condition is always :math:`u(x,t_0)=u_0(x)`. There are four boundary conditions that are
        currently implemented:

        - For Dirichlet-Dirichlet boundary condition :math:`u(x_0,t)=g(t)` and :math:`u(x_1,t)=h(t)`:

          The re-parameterization is
          :math:`\displaystyle u(x,t)=A(x,t)+\tilde{x}\big(1-\tilde{x}\big)\Big(1-e^{-\tilde{t}}\Big)\mathrm{ANN}(x,t)`,
          where :math:`\displaystyle A(x,t)=u_0(x)+
          \tilde{x}\big(h(t)-h(t_0)\big)+\big(1-\tilde{x}\big)\big(g(t)-g(t_0)\big)`.

        - For Dirichlet-Neumann boundary condition :math:`u(x_0,t)=g(t)` and :math:`u'_x(x_1, t)=q(t)`:

          The re-parameterization is
          :math:`\displaystyle u(x,t)=A(x,t)+\tilde{x}\Big(1-e^{-\tilde{t}}\Big)
          \Big(\mathrm{ANN}(x,t)-\big(x_1-x_0\big)\mathrm{ANN}'_x(x_1,t)-\mathrm{ANN}(x_1,t)\Big)`,
          where :math:`\displaystyle A(x,t)=u_0(x)+\big(x-x_0\big)\big(q(t)-q(t_0)\big)+\big(g(t)-g(t_0)\big)`.

        - For Neumann-Dirichlet boundary condition :math:`u'_x(x_0,t)=p(t)` and :math:`u(x_1, t)=h(t)`:

          The re-parameterization is
          :math:`\displaystyle u(x,t)=A(x,t)+\big(1-\tilde{x}\big)\Big(1-e^{-\tilde{t}}\Big)
          \Big(\mathrm{ANN}(x,t)-\big(x_1-x_0\big)\mathrm{ANN}'_x(x_0,t)-\mathrm{ANN}(x_0,t)\Big)`,
          where :math:`\displaystyle A(x,t)=u_0(x)+\big(x_1-x\big)\big(p(t)-p(t_0)\big)+\big(h(t)-h(t_0)\big)`.

        - For Neumann-Neumann boundary condition :math:`u'_x(x_0,t)=p(t)` and :math:`u'_x(x_1, t)=q(t)`

          The re-parameterization is
          :math:`\displaystyle u(x,t)=A(x,t)+\left(1-e^{-\tilde{t}}\right)
          \Big(
          \mathrm{ANN}(x,t)-\big(x-x_0\big)\mathrm{ANN}'_x(x_0,t)
          +\frac{1}{2}\tilde{x}^2\big(x_1-x_0\big)
          \big(\mathrm{ANN}'_x(x_0,t)-\mathrm{ANN}'_x(x_1,t)\big)
          \Big)`,
          where :math:`\displaystyle A(x,t)=u_0(x)
          -\frac{1}{2}\big(1-\tilde{x}\big)^2\big(x_1-x_0\big)\big(p(t)-p(t_0)\big)
          +\frac{1}{2}\tilde{x}^2\big(x_1-x_0\big)\big(q(t)-q(t_0)\big)`.

        Notations:

        - :math:`\displaystyle\tilde{t}=\frac{t-t_0}{t_1-t_0}`,
        - :math:`\displaystyle\tilde{x}=\frac{x-x_0}{x_1-x_0}`,
        - :math:`\displaystyle\mathrm{ANN}` is the neural network,
        - and :math:`\displaystyle\mathrm{ANN}'_x=\frac{\partial ANN}{\partial x}`.

        :param u: Output of the neural network.
        :type u: `torch.Tensor`
        :param x: The :math:`x`-coordinates of the samples; i.e., the spatial coordinates.
        :type x: `torch.Tensor`
        :param t: The :math:`t`-coordinates of the samples; i.e., the temporal coordinates.
        :type t: `torch.Tensor`
        :param additional_tensors: additional tensors that will be passed by ``enforce``
        :type additional_tensors: `torch.Tensor`
        :return: The re-parameterized output of the network.
        :rtype: `torch.Tensor`
        """

        t0 = self.t_min * torch.ones_like(t, requires_grad=True)
        x_tilde = (x - self.x_min) / (self.x_max - self.x_min)
        t_tilde = t - self.t_min

        if self.x_min_val and self.x_max_val:
            return self._parameterize_dd(u, x, t, x_tilde, t_tilde, t0)
        elif self.x_min_val and self.x_max_prime:
            return self._parameterize_dn(u, x, t, x_tilde, t_tilde, t0, *additional_tensors)
        elif self.x_min_prime and self.x_max_val:
            return self._parameterize_nd(u, x, t, x_tilde, t_tilde, t0, *additional_tensors)
        elif self.x_min_prime and self.x_max_prime:
            return self._parameterize_nn(u, x, t, x_tilde, t_tilde, t0, *additional_tensors)
        else:
            raise NotImplementedError('Sorry, this boundary condition is not implemented.')
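As an illustration of the formulas in the docstring, here is a minimal sketch of the Dirichlet-Dirichlet case (assuming callables u0, g, h for the initial condition and the two boundary conditions; this is not the library's actual `_parameterize_dd` helper):

import torch

def parameterize_dd_sketch(u, x, t, x_tilde, t_tilde, t0, u0, g, h):
    # A(x, t) = u0(x) + x_tilde * (h(t) - h(t0)) + (1 - x_tilde) * (g(t) - g(t0))
    A = u0(x) + x_tilde * (h(t) - h(t0)) + (1 - x_tilde) * (g(t) - g(t0))
    # u(x, t) = A(x, t) + x_tilde * (1 - x_tilde) * (1 - exp(-t_tilde)) * ANN(x, t)
    return A + x_tilde * (1 - x_tilde) * (1 - torch.exp(-t_tilde)) * u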
    def forward(self,
                query_title: Dict[str, torch.Tensor],
                query_emb: torch.Tensor,
                pos_author_match: torch.Tensor,
                pos_citation_overlap: torch.Tensor,
                pos_reference_overlap: torch.Tensor,
                pos_cites_query: torch.Tensor,
                query_cites_pos: torch.Tensor,
                pos_oldness: torch.Tensor,
                pos_relative_oldness: torch.Tensor,
                pos_number_citations: torch.Tensor,
                pos_position: torch.Tensor,
                pos_title: Dict[str, torch.Tensor],
                pos_title_match: torch.Tensor,
                pos_emb: torch.Tensor,
                neg_author_match: torch.LongTensor = None,
                neg_citation_overlap: torch.Tensor = None,
                neg_reference_overlap: torch.Tensor = None,
                neg_cites_query: torch.Tensor = None,
                query_cites_neg: torch.Tensor = None,
                neg_oldness: torch.Tensor = None,
                neg_relative_oldness: torch.Tensor = None,
                neg_number_citations: torch.Tensor = None,
                neg_position: torch.LongTensor = None,
                neg_title: Dict[str, torch.Tensor] = None,
                neg_title_match: torch.Tensor = None,
                neg_emb: torch.Tensor = None):
        # query_title["tokens"] is (batch size x num tokens in title)
        batch_size = query_title["tokens"].size(0)
        if self.text_encoder and self.encode_title:
            query_paper_encoding = self._paper_to_vec(query_title["tokens"])
            check_dimensions_match(query_paper_encoding.size(),
                                   (batch_size, self.total_paper_output_size),
                                   "Query paper encoding size",
                                   "Expected paper encoding size")

            pos_paper_encoding = self._paper_to_vec(pos_title["tokens"])
            check_dimensions_match(pos_paper_encoding.size(),
                                   (batch_size, self.total_paper_output_size),
                                   "Positive paper encoding size",
                                   "Expected paper encoding size")
            if neg_title:
                # neg_paper_encoding is (batch size x size of embedding)
                neg_paper_encoding = self._paper_to_vec(neg_title["tokens"])
                check_dimensions_match(
                    neg_paper_encoding.size(),
                    (batch_size, self.total_paper_output_size),
                    "Negative paper encoding size",
                    "Expected paper encoding size")
        # pos_features holds additional numeric features about this instance; shape (batch_size, num_extra_numeric_features)
        if self.project_query:
            proj_query_emb = self.query_projection(query_emb)
        else:
            proj_query_emb = query_emb

        pos_emb_sim = torch.nn.functional.cosine_similarity(proj_query_emb,
                                                            pos_emb,
                                                            dim=1)
        pos_emb_sim = pos_emb_sim.view(-1, 1)

        pos_features = torch.cat([
            pos_author_match, pos_position, pos_title_match, pos_emb_sim,
            pos_citation_overlap, pos_reference_overlap, pos_cites_query,
            query_cites_pos, pos_oldness, pos_relative_oldness,
            pos_number_citations
        ],
                                 dim=1)
        check_dimensions_match(pos_features.size(),
                               (batch_size, self.num_extra_numeric_features),
                               "Positive features size",
                               "Expected positive features size")

        # positive_paper_score is (batch size x 1)
        if self.text_encoder and self.encode_title:
            positive_paper_score = self.ff_positive_score(
                torch.cat(
                    [query_paper_encoding, pos_paper_encoding, pos_features],
                    dim=1))
        else:
            positive_paper_score = self.ff_positive_score(pos_features)
        check_dimensions_match(positive_paper_score.size(), (batch_size, 1),
                               "Positive score size",
                               "Expected positive scoresize")
        if neg_title:
            # negative_paper_score is (batch size x 1)
            neg_emb_sim = torch.nn.functional.cosine_similarity(proj_query_emb,
                                                                neg_emb,
                                                                dim=1)
            neg_emb_sim = neg_emb_sim.view(-1, 1)

            neg_features = torch.cat([
                neg_author_match, neg_position, neg_title_match, neg_emb_sim,
                neg_citation_overlap, neg_reference_overlap, neg_cites_query,
                query_cites_neg, neg_oldness, neg_relative_oldness,
                neg_number_citations
            ],
                                     dim=1)
            if self.text_encoder and self.encode_title:
                negative_paper_score = self.ff_positive_score(
                    torch.cat([
                        query_paper_encoding, neg_paper_encoding, neg_features
                    ],
                              dim=1))
            else:
                negative_paper_score = self.ff_positive_score(neg_features)
            check_dimensions_match(negative_paper_score.size(),
                                   (batch_size, 1), "negative score size",
                                   "Expected negative score size")
            long_pos_position = torch.round(pos_position * 10 - 1).long()
            propensity_score = self.adj_click_distribution[long_pos_position]
            # loss is a batch size x 1 vector
            loss = self.loss(positive_paper_score, negative_paper_score,
                             torch.ones_like(positive_paper_score))
            loss = loss / propensity_score
            loss = torch.mean(loss)
            check_dimensions_match(loss.dim(), 0, "Loss size",
                                   "Expected loss size")

        output = {}
        output['pos_score'] = positive_paper_score
        if neg_title:
            self.accuracy(
                torch.cat([positive_paper_score, negative_paper_score], dim=1),
                torch.zeros(len(positive_paper_score)))
            self.saved_loss(
                loss.item()
            )  #NOTE averages across batches, which is a bit wrong unless total examples is divisible by batch size
            output['neg_score'] = negative_paper_score
            output['loss'] = loss
        return output
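The loss block above combines a pairwise ranking loss with inverse-propensity weighting. A standalone sketch of that pattern (hypothetical scores and propensities; `MarginRankingLoss` stands in for whatever `self.loss` is actually configured to):

import torch

loss_fn = torch.nn.MarginRankingLoss(margin=1.0, reduction='none')
pos_score = torch.randn(8, 1)
neg_score = torch.randn(8, 1)
propensity = torch.rand(8, 1) * 0.9 + 0.1       # hypothetical click propensities in (0.1, 1.0]
loss = loss_fn(pos_score, neg_score, torch.ones_like(pos_score))
loss = (loss / propensity).mean()               # down-weight examples from easily-clicked positions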
    def __init__(self, pos, mass, max_levels=100, device='cpu'):
        super().__init__()
        self.device = device

        self.num_levels = 0
        self.max_levels = max_levels

        self.num_dim = pos.shape[1]
        self.num_o = 2**self.num_dim

        min_val = torch.min(pos) - 1e-4
        max_val = torch.max(pos) + 1e-4
        self.size = max_val - min_val

        norm_pos = (pos - min_val.unsqueeze(0)) / self.size.unsqueeze(
            0)  # normalized position of all points

        # level-wise tree parameters (list index is the corresponding level)
        self.node_mass = []
        self.center_of_mass = []
        self.is_end_node = []
        self.node_indexing = []

        point_nodes = torch.zeros(
            pos.shape[0], dtype=torch.long, device=self.device
        )  # node in which each point falls on the current level
        num_nodes = 1

        while True:
            self.num_levels += 1
            num_divisions = 2**self.num_levels

            # calculate the orthant in which each point falls
            point_orthant = torch.floor(norm_pos * num_divisions).long()
            point_orthant = (point_orthant % 2) * (2**torch.arange(
                self.num_dim, device=self.device).unsqueeze(0))
            point_orthant = torch.sum(point_orthant, dim=1)

            # calculate node indices from point orthants
            point_nodes *= self.num_o
            point_nodes += point_orthant

            # calculate total mass of each section
            node_mass = torch.zeros(num_nodes * self.num_o, device=self.device)
            node_mass.scatter_add_(0, point_nodes, mass)

            # calculate center of mass of each node
            node_com = torch.zeros(num_nodes * self.num_o,
                                   self.num_dim,
                                   device=self.device)
            for d in range(self.num_dim):
                node_com[:, d].scatter_add_(0, point_nodes, pos[:, d] * mass)
            node_com /= node_mass.unsqueeze(1)

            # determine if node is end node
            point_is_continued = node_mass[
                point_nodes] > mass  # only points that are not the only ones in their node are passed on to the next level
            end_nodes = point_nodes[
                point_is_continued ==
                0]  # nodes with only one point are end nodes
            is_end_node = torch.zeros(num_nodes * self.num_o,
                                      device=self.device,
                                      dtype=torch.bool)
            is_end_node[end_nodes] = 1

            node_is_continued = node_mass > 0.
            non_empty_nodes = node_is_continued.nonzero().squeeze(
                1)  # indices of non-empty nodes
            num_nodes = non_empty_nodes.shape[0]

            # create new node indexing: only non-empty nodes have positive indices, end nodes have the index -1
            node_indexing = self.create_non_empty_node_indexing(
                non_empty_nodes, node_mass.shape[0], self.num_o)

            # only pass on nodes that are continued
            is_end_node = is_end_node[node_is_continued]
            node_mass = node_mass[node_is_continued]
            node_com = node_com[node_is_continued, :]

            self.node_mass.append(node_mass)
            self.center_of_mass.append(node_com)
            self.node_indexing.append(node_indexing)
            self.is_end_node.append(is_end_node)

            # update the node index of each point
            point_nodes = node_indexing[point_nodes // self.num_o,
                                        point_nodes % self.num_o]

            # discard points in end nodes
            pos = pos[point_is_continued]
            mass = mass[point_is_continued]
            point_nodes = point_nodes[point_is_continued]
            norm_pos = norm_pos[point_is_continued]

            if torch.sum(point_is_continued) < 1:
                break
            if self.num_levels >= self.max_levels:
                num_points_in_nodes = torch.zeros_like(node_mass,
                                                       dtype=torch.long)
                num_points_in_nodes.scatter_add_(
                    0, point_nodes, torch.ones_like(mass, dtype=torch.long))
                max_points_in_node = torch.max(num_points_in_nodes)
                non_empty_nodes, index_of_point = torch.unique(
                    point_nodes, return_inverse=True)
                node_index_of_point = non_empty_nodes[index_of_point]
                scatter_indices = torch.arange(
                    node_index_of_point.shape[0],
                    device=self.device) % max_points_in_node
                point_order = torch.argsort(node_index_of_point)
                node_indexing = torch.zeros(num_nodes,
                                            max_points_in_node,
                                            dtype=torch.long,
                                            device=self.device) - 1
                node_indexing[node_index_of_point[point_order],
                              scatter_indices] = torch.arange(
                                  node_index_of_point.shape[0],
                                  device=self.device)[point_order]
                self.node_mass.append(mass)
                self.center_of_mass.append(pos)
                self.node_indexing.append(node_indexing)
                self.is_end_node.append(torch.ones_like(mass,
                                                        dtype=torch.bool))
                #print("too many levels!")
                break
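A worked example of the orthant computation used in the loop above, for 2-D points at the first level (num_divisions = 2): each point gets an index in {0, 1, 2, 3} according to the quadrant of the unit square it falls into.

import torch

norm_pos = torch.tensor([[0.1, 0.2], [0.8, 0.3], [0.4, 0.9], [0.7, 0.6]])
num_divisions = 2
point_orthant = torch.floor(norm_pos * num_divisions).long()            # per-dimension cell index
point_orthant = (point_orthant % 2) * (2 ** torch.arange(2)).unsqueeze(0)
point_orthant = torch.sum(point_orthant, dim=1)                         # tensor([0, 1, 2, 3])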
Exemple #44
def train(args):
    dataset = SimpsonDataset(args.image_path)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    latent_space = Z_Generator()
    generator = Generator()
    discriminator = Discriminator()

    # Initialize the models' parameters
    generator.apply(weights_init)
    discriminator.apply(weights_init)

    # Load models
    if args.model_load_flag:
        generator.load_state_dict(
            torch.load(os.path.join(args.model_path,
                                    args.generator_load_name)))
        discriminator.load_state_dict(
            torch.load(
                os.path.join(args.model_path, args.discriminator_load_name)))

    # Use GPU, if it's available
    if _CUDA_FLAG:
        generator.cuda()
        discriminator.cuda()

    # Loss function
    g_criterion = torch.nn.BCELoss()
    d_criterion = torch.nn.BCELoss()

    # Optimizer
    g_optimizer = torch.optim.Adam(generator.parameters(),
                                   lr=args.learning_rate,
                                   betas=args.betas)
    d_optimizer = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.learning_rate,
                                   betas=args.betas)

    for cur_epoch in range(args.epoch):
        # Train
        generator.train()
        discriminator.train()
        for cur_batch_num, images in enumerate(dataloader):
            if _CUDA_FLAG: images = images.cuda()
            g_optimizer.zero_grad()
            d_optimizer.zero_grad()

            # Make label for discriminator about real images
            d_toutput = discriminator(images).view(-1)
            d_tlabel = torch.ones_like(d_toutput)
            if _CUDA_FLAG: d_tlabel = d_tlabel.cuda()

            # Calculate loss for real images
            d_tloss = d_criterion(d_toutput, d_tlabel)
            d_tloss.backward()

            # Generate fake images from latent space
            latent_vectors = latent_space(len(images))
            if _CUDA_FLAG: latent_vectors = latent_vectors.cuda()
            fake_images = generator(latent_vectors)

            # Make label for discriminator about fake images
            d_foutput = discriminator(fake_images.detach()).view(-1)
            d_flabel = torch.zeros_like(d_foutput)
            if _CUDA_FLAG: d_flabel = d_flabel.cuda()

            # Calculate loss for fake images
            d_floss = d_criterion(d_foutput, d_flabel)
            d_floss.backward()

            # Update discriminator's parameters
            d_total_loss = (d_tloss + d_floss) / 2
            if cur_epoch < 50:
                if cur_batch_num % 2 == 0: d_optimizer.step()
            else:
                d_optimizer.step()

            # Make label for generator
            g_output = discriminator(fake_images).view(-1)
            g_label = torch.ones_like(g_output)
            if _CUDA_FLAG: g_label = g_label.cuda()

            # Update generator's parameters
            g_loss = g_criterion(g_output, g_label)
            g_loss.backward()
            g_optimizer.step()

            print("EPOCH {}/{} Iter {}/{} D TLoss {:.6f} FLoss {:.6f} TotalLoss {:.6f} G TotalLoss {:.6f}".format(\
                cur_epoch, args.epoch, cur_batch_num+1, len(dataloader), d_tloss, d_floss, d_total_loss, g_loss))
        if cur_epoch % 30 == 29:
            with torch.no_grad():
                generator.eval()
                # Save several images which are generated from generator model
                generator.cpu()
                latent_vectors = latent_space(20)
                test_images = generator(latent_vectors)
                save_images(test_images.numpy(), args.image_save_path,
                            cur_epoch)

                # Save model's parameters
                generator_save_name = "generator_{}_checkpoint.pth".format(
                    cur_epoch)
                discriminator_save_name = "discriminator_{}_checkpoint.pth".format(
                    cur_epoch)
                torch.save(generator.state_dict(),
                           os.path.join(args.model_path, generator_save_name))
                torch.save(
                    discriminator.state_dict(),
                    os.path.join(args.model_path, discriminator_save_name))
                if _CUDA_FLAG: generator.cuda()
Exemple #45
def reset_bn(module):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        module.running_mean = torch.zeros_like(module.running_mean)
        module.running_var = torch.ones_like(module.running_var)
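A typical way to use this helper (assuming `model` is any `torch.nn.Module`) is through `Module.apply`, which visits every submodule:

model.apply(reset_bn)  # resets the running statistics of every BatchNorm layer in the model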
    def get_online_pyramid_level(self, cls_scores_img, bbox_preds_img,
                                 gt_bbox_obj_xyxy, gt_label_obj):
        """
        :param cls_scores_img: 这个图片对应的 cls_scores 信息(也是多个 stage 的信息)
        :param bbox_preds_img: 这个图片对应的 bbox_preds 信息(也是多个 stage 的信息)
        :param gt_bbox_obj_xyxy: 要进行层次选择的 gt bbox 的信息 (x1, y1, x2, y2) 格式
        :param gt_label_obj: 要进行层次选择的 gt bbox 的 label(类别) 信息
        :return:
        """
        device = cls_scores_img[0].device
        # number of pyramid stages
        num_levels = len(cls_scores_img)
        # placeholder that accumulates the loss of each level
        level_losses = torch.zeros(num_levels)
        # compute the loss of every level
        for level in range(num_levels):
            # height and width of the feature map
            H, W = cls_scores_img[level].shape[1:]
            # grid position of the gt bbox on this feature map
            b_p_xyxy = gt_bbox_obj_xyxy / self.feat_strides[level]
            # effective area of the gt bbox
            b_e_xyxy = self.get_prop_xyxy(b_p_xyxy, self.eps_e, W, H)

            # Eqn-(1)
            # count of cells inside the effective area
            N = (b_e_xyxy[3] - b_e_xyxy[1] + 1) * (b_e_xyxy[2] - b_e_xyxy[0] +
                                                   1)

            # cls loss; FL
            # compute the focal loss on the classification scores
            # first gather the cls_score values inside the effective area
            score = cls_scores_img[level][gt_label_obj,
                                          b_e_xyxy[1]:b_e_xyxy[3] + 1,
                                          b_e_xyxy[0]:b_e_xyxy[2] + 1]
            # reshape score -> (N, 1)
            score = score.contiguous().view(-1).unsqueeze(1)
            # the label is an all-ones tensor of the same length as score
            label = torch.ones_like(score).long()
            label = label.contiguous().view(-1)
            # compute the focal loss applied after the sigmoid
            # label is the weight
            loss_cls = sigmoid_focal_loss(score,
                                          label,
                                          gamma=self.FL_gamma,
                                          alpha=self.FL_alpha,
                                          reduction='mean')
            # reduction='mean' already averages, so no extra division by N is needed
            # loss_cls /= N

            # compute the bbox regression loss
            # reg loss; IoU
            # gather the predicted bbox offsets inside the effective area
            offsets = bbox_preds_img[level][:, b_e_xyxy[1]:b_e_xyxy[3] + 1,
                                            b_e_xyxy[0]:b_e_xyxy[2] + 1]
            # reorder dims from (channel, height, width) to (height, width, channel)
            offsets = offsets.contiguous().permute(1, 2, 0)  # (b_e_H,b_e_W,4)
            # reshape -> (height * width, channel)
            offsets = offsets.reshape(-1, 4)  # (#pix-e,4)
            # NOTE: each offset actually predicts the distances to the four box edges

            # predicted bbox
            # first build the grid coordinates of the effective area on the feature map,
            # i.e. the region for which the target "anchor" points are generated
            y, x = torch.meshgrid([
                torch.arange(b_e_xyxy[1], b_e_xyxy[3] + 1),
                torch.arange(b_e_xyxy[0], b_e_xyxy[2] + 1)
            ])
            # centre coordinates of each grid cell (relative to the original input-image size)
            y = (y.float() + 0.5) * self.feat_strides[level]
            x = (x.float() + 0.5) * self.feat_strides[level]
            # concatenate the centre coordinates
            xy = torch.cat([x.unsqueeze(2), y.unsqueeze(2)],
                           dim=2).float().to(device)
            xy = xy.reshape(-1, 2)

            # map the offsets back to the input-image scale
            dist_pred = offsets * self.feat_strides[level]
            # convert (centre, distances) into bbox coordinates
            bboxes = self.dist2bbox(xy, dist_pred, self.bbox_offset_norm)
            # loss_reg is computed with the IoU loss
            loss_reg = iou_loss(bboxes,
                                gt_bbox_obj_xyxy.unsqueeze(0).repeat(N, 1),
                                reduction='mean')
            # again, reduction='mean' already averages, so no division by N is needed
            # NOTE: an in-place /= seems not to work here; use loss_reg = loss_reg / N instead
            # loss_reg /= N
            # loss of the current stage
            loss = loss_cls + loss_reg

            level_losses[level] = loss
        # find the level with the smallest loss and return its index
        min_level_idx = torch.argmin(level_losses)
        # print(level_losses, min_level_idx)
        return min_level_idx
 def exp_multi_wd(outputs, targets, weights, wd):
     neg_one = -torch.ones_like(
         outputs, device=outputs.device, dtype=outputs.dtype)
     return torch.exp(-outputs * neg_one.scatter_(1, targets.view(-1, 1), 1)).sum(dim=1).mean() + \
            sum([(torch.exp(wd * w) - 1).sum() + (torch.exp(-wd * w) - 1).sum() for w in weights]) / 2
Exemple #48
                loss = (trH + .5 * norm_s).mean()

            elif args.mode == "nce":
                noise_dist = distributions.Normal(init_mu, init_std)
                x_fake = noise_dist.sample((x.size(0),))

                pos_logits = modelICA(
                    x) + approx_normalizing_const - noise_dist.log_prob(x).sum(
                        1)
                neg_logits = modelICA(
                    x_fake) + approx_normalizing_const - noise_dist.log_prob(
                        x_fake).sum(1)

                pos_loss = nn.BCEWithLogitsLoss()(pos_logits,
                                                  torch.ones_like(pos_logits))
                neg_loss = nn.BCEWithLogitsLoss()(neg_logits,
                                                  torch.zeros_like(neg_logits))
                loss = pos_loss + neg_loss

            elif args.mode == "mle":
                loss = -modelICA.log_prob(x).mean()

            elif args.mode.startswith("cnce-"):
                eps = float(args.mode.split('-')[1])
                x_pert = x + torch.randn_like(x) * eps
                logits = modelICA(x) - modelICA(x_pert)
                loss = nn.BCEWithLogitsLoss()(logits, torch.ones_like(logits))

            else:
                assert False
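For reference, a self-contained sketch of the NCE objective built in the "nce" branch above, using a toy linear energy model and a standard-normal noise distribution (all names here are illustrative, not the script's):

import torch
import torch.nn as nn
from torch import distributions

energy = nn.Linear(2, 1)                       # stand-in for modelICA
log_Z = torch.tensor(0.0, requires_grad=True)  # stand-in for approx_normalizing_const
noise = distributions.Normal(torch.zeros(2), torch.ones(2))

x = torch.randn(16, 2)                         # "data" batch
x_fake = noise.sample((16,))
pos_logits = energy(x).squeeze(1) + log_Z - noise.log_prob(x).sum(1)
neg_logits = energy(x_fake).squeeze(1) + log_Z - noise.log_prob(x_fake).sum(1)
loss = nn.BCEWithLogitsLoss()(pos_logits, torch.ones_like(pos_logits)) + \
       nn.BCEWithLogitsLoss()(neg_logits, torch.zeros_like(neg_logits))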
 def gen_backward(self):
     ones = One(torch.ones_like(self.input))
     return [ones]
Exemple #50
    def reverse(self, input, hidden, buf, slice_dim=0, saved_hidden=None, mask=None):
        if buf is None:
            return saved_hidden.clone() 

        buf_h1, buf_h2, buf_c1, buf_c2 = buf
        group_size = self.h_size // 2
        h1 = hidden[:, :group_size]
        h2 = hidden[:, group_size:2*group_size]
        c1 = hidden[:, 2*group_size:3*group_size]
        c2 = hidden[:, 3*group_size:]
        mask = torch.ones_like(h1) if mask is None else mask[:, None].expand_as(h1)

        # Compute concatenated gates used to update h2, c2.
        h1_fl = ConvertToFloat.apply(h1, hidden_radix)
        zgfop2 = self.ih1_to_zgfop2(torch.cat([input, h1_fl], dim=1))

        # Compute gates used to update h2
        o2 = F.sigmoid(zgfop2[:, 3*group_size:4*group_size])
        p2 = F.sigmoid(zgfop2[:, 4*group_size:])
        p2 = self.max_forget * p2 + (1 - self.max_forget)

        # Reverse update/forgetting for h2.
        c2_fl = ConvertToFloat.apply(c2, hidden_radix)
        update_h2 = ConvertToFixed.apply(o2 * F.tanh(c2_fl), hidden_radix)
        h2 = h2 - update_h2 * mask
        p2_fix = ConvertToFixed.apply(p2, forget_radix)
        p2_fix = EnsureNotForgetAll.apply(p2_fix, self.max_forget)
        h2 = FixedDivide.apply(h2, p2_fix, buf_h2, mask)

        # Compute gates used to update c2.
        z2 = F.sigmoid(zgfop2[:, :group_size])
        z2 = self.max_forget * z2 + (1 - self.max_forget)
        g2 = F.tanh(zgfop2[:, group_size:2*group_size])
        f2 = F.sigmoid(zgfop2[:, 2*group_size:3*group_size])

        # Reverse update/forgetting for c2.
        update_c2 = ConvertToFixed.apply(f2 * g2, hidden_radix)
        c2 = c2 - update_c2 * mask
        z2_fix = ConvertToFixed.apply(z2, forget_radix)
        z2_fix = EnsureNotForgetAll.apply(z2_fix, self.max_forget)
        c2 = FixedDivide.apply(c2, z2_fix, buf_c2, mask)

        # Compute concatenated gates used to update h1, c1.
        h2_fl = ConvertToFloat.apply(h2, hidden_radix)
        zgfop1 = self.ih2_to_zgfop1(torch.cat([input, h2_fl], dim=1))

        # Compute gates used to update h1.
        o1 = F.sigmoid(zgfop1[:, 3*group_size:4*group_size])
        p1 = F.sigmoid(zgfop1[:, 4*group_size:])
        p1 = self.max_forget * p1 + (1 - self.max_forget)

        # Reverse update/forgetting for h1.
        c1_fl = ConvertToFloat.apply(c1, hidden_radix)
        update_h1 = ConvertToFixed.apply(o1 * F.tanh(c1_fl), hidden_radix)
        h1 = h1 - update_h1 * mask
        p1_fix = ConvertToFixed.apply(p1, forget_radix)
        p1_fix = EnsureNotForgetAll.apply(p1_fix, self.max_forget)
        h1 = FixedDivide.apply(h1, p1_fix, buf_h1, mask, slice_dim)
        if slice_dim > 0:
            h1[:, :slice_dim] = saved_hidden

        # Compute gates used to update c1.
        z1 = F.sigmoid(zgfop1[:, :group_size])
        z1 = self.max_forget * z1 + (1 - self.max_forget)
        g1 = F.tanh(zgfop1[:, group_size:2*group_size])
        f1 = F.sigmoid(zgfop1[:, 2*group_size:3*group_size])

        # Apply update/forgetting for c1.
        update_c1 = ConvertToFixed.apply(f1 * g1, hidden_radix)
        c1 = c1 - update_c1 * mask
        z1_fix = ConvertToFixed.apply(z1, forget_radix) 
        z1_fix = EnsureNotForgetAll.apply(z1_fix, self.max_forget)
        c1 = FixedDivide.apply(c1, z1_fix, buf_c1, mask)

        hidden = torch.cat([h1, h2, c1, c2], dim=1)
        return hidden
Exemple #51
 def forward(self, a):
     return torch.ones_like(a, dtype=torch.float64, pin_memory=False)
Exemple #52
 def forward(self, a):
     return torch.ones_like(a, dtype=torch.float32)
Exemple #53
import torch, torchvision
import numpy as np

data = [[1, 2], [3, 4]]
x_data = torch.tensor(data)
np_array = np.array(data)
x_np = torch.from_numpy(np_array)

x_ones = torch.ones_like(x_data)
x_rand = torch.rand_like(x_data, dtype=torch.float)

shape = (
    2,
    3,
)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)

tensor = torch.rand(3, 4)
print(f"Shape of tensor: {tensor.shape}")
print(f"Datatype of tensor: {tensor.dtype}")
print(f"Device tensor is stored on: {tensor.device}")

if torch.cuda.is_available():
    tensor = tensor.to('cuda')

tensor = torch.ones(4, 4)
tensor[:, 1] = 0
print(tensor)
    def sample(self, datas):
        mean, std = datas

        distribution    = Normal(torch.zeros_like(mean), torch.ones_like(std))
        rand            = distribution.sample().float().to(set_device(self.use_gpu))
        return (mean + std.squeeze() * rand).squeeze(0)
    def fsaf_target(self,
                    cls_scores,
                    bbox_preds,
                    gt_bboxes,
                    gt_labels,
                    img_metas,
                    cfg,
                    gt_bboxes_ignore_list=None):
        """
        :param cls_scores: 多个 stage 中 bbox forward 得到的分类的分数, 每个entry对应的 shape : classNum * height * width
        :param bbox_preds: 多个 stage 中 bbox forward 得到的bbox regression的结果, 每个entry对应的 shape : 4 * height * width
        :param gt_bboxes: GT bboxes 的信息
        :param gt_labels: GT label 的信息
        # 下面的参数都没有用到... 不做说明了
        :param img_metas:
        :param cfg:
        :param gt_bboxes_ignore_list:
        :return:
        """
        # record the device so that subsequent computations stay on the same device
        device = cls_scores[0].device

        # target placeholders (two lists that record the targets)
        num_levels = len(cls_scores)
        cls_targets_list = []
        reg_targets_list = []
        # initialise the targets of every stage with an all-zero / all-minus-one tensor of matching shape
        for level in range(num_levels):
            cls_targets_list.append(
                torch.zeros_like(cls_scores[level]).long())  # 0 init
            reg_targets_list.append(torch.ones_like(bbox_preds[level]) *
                                    -1)  # -1 init

        # detached network prediction for online GT generation
        # number of images (each image's bboxes correspond to one entry of gt_bboxes)
        num_imgs = len(gt_bboxes)
        # collect detached copies of cls_scores and bbox_preds
        cls_scores_list = []
        bbox_preds_list = []
        for img in range(num_imgs):
            # detached prediction for online pyramid level selection
            cls_scores_list.append([lvl[img].detach() for lvl in cls_scores])
            bbox_preds_list.append([lvl[img].detach() for lvl in bbox_preds])
        # start matching the GT boxes
        # generate online GT
        num_imgs = len(gt_bboxes)
        for img in range(num_imgs):
            # sort objects according to their size
            # take all gt bboxes of this image, in (x1, y1, x2, y2) format
            gt_bboxes_img_xyxy = gt_bboxes[img]
            # convert (x1, y1, x2, y2) into (x_center, y_center, width, height);
            # this step also turns gt_bboxes_img_xywh into a Tensor (with int-like values)
            gt_bboxes_img_xywh = self.xyxy2xywh(gt_bboxes_img_xyxy)
            # area of each GT bbox
            gt_bboxes_img_size = gt_bboxes_img_xywh[:,
                                                    -2] * gt_bboxes_img_xywh[:,
                                                                             -1]
            # indices that sort gt_bboxes_img_size in descending order
            _, gt_bboxes_img_idx = gt_bboxes_img_size.sort(descending=True)
            # process every GT bbox
            for obj_idx in gt_bboxes_img_idx:
                # according to the config, the cls predictions are activated with sigmoid, so
                # there is no background class (background is presumably removed purely by
                # thresholding, since there is no softmax to renormalise the predictions)
                label = gt_labels[img][obj_idx] - 1
                # geometry of this gt bbox
                gt_bbox_obj_xyxy = gt_bboxes_img_xyxy[obj_idx]
                # get optimal online pyramid level for each object
                # determine which stage level this bbox should be assigned to
                # note: cls_scores_list[img] and bbox_preds_list[img] are passed by reference,
                # so any in-place modification inside the call would really change them
                opt_level = self.get_online_pyramid_level(
                    cls_scores_list[img], bbox_preds_list[img],
                    gt_bbox_obj_xyxy, label)
                # get the effective/ignore area for classification
                # height and width of the current stage
                H, W = cls_scores[opt_level].shape[2:]
                # size of this gt box on the assigned stage level, in feature-map grid units;
                # the gt bbox is given at input-image scale, so it is rescaled to this level
                b_p_xyxy = gt_bbox_obj_xyxy / self.feat_strides[opt_level]
                # use get_spatial_idx on b_p_xyxy to obtain the effective and ignore spatial masks
                e_spatial_idx, i_spatial_idx = self.get_spatial_idx(
                    b_p_xyxy, W, H, device)

                # cls-GT
                # fill prob= 1 for the effective area
                # set the effective area of the cls target to 1
                cls_targets_list[opt_level][img, label, e_spatial_idx] = 1

                # fill prob=-1 for the ignoring area
                # set the ignore area of the cls target to -1
                # this step prevents the ignore area from overwriting an overlapping gt
                # region that was already marked as effective
                _i_spatial_idx = cls_targets_list[opt_level][
                    img, label] * i_spatial_idx.long()
                i_spatial_idx = i_spatial_idx - (_i_spatial_idx == 1).type(
                    torch.float32)
                i_spatial_idx = i_spatial_idx.long()
                cls_targets_list[opt_level][img, label, i_spatial_idx] = -1

                # fill prob=-1 for the adjacent ignoring area; lower
                # fill the ignoring area on the adjacent lower level
                if opt_level != 0:
                    H_l, W_l = cls_scores[opt_level - 1].shape[2:]
                    b_p_xyxy_l = gt_bbox_obj_xyxy / self.feat_strides[opt_level
                                                                      - 1]
                    _, i_spatial_idx_l = self.get_spatial_idx(
                        b_p_xyxy_l, W_l, H_l, device)
                    # preserve cls-gt that is already filled as effective area
                    _i_spatial_idx_l = cls_targets_list[opt_level - 1][
                        img, label] * i_spatial_idx_l.long()
                    i_spatial_idx_l = i_spatial_idx_l - (
                        _i_spatial_idx_l == 1).type(torch.float32)
                    i_spatial_idx_l = i_spatial_idx_l.long()
                    cls_targets_list[opt_level -
                                     1][img, label][i_spatial_idx_l] = -1

                # fill prob=-1 for the adjacent ignoring area; upper
                # fill the ignoring area on the adjacent upper level
                if opt_level != num_levels - 1:
                    H_u, W_u = cls_scores[opt_level + 1].shape[2:]
                    b_p_xyxy_u = gt_bbox_obj_xyxy / self.feat_strides[opt_level
                                                                      + 1]
                    _, i_spatial_idx_u = self.get_spatial_idx(
                        b_p_xyxy_u, W_u, H_u, device)
                    # preserve cls-gt that is already filled as effective area
                    _i_spatial_idx_u = cls_targets_list[opt_level + 1][
                        img, label] * i_spatial_idx_u.long()
                    i_spatial_idx_u = i_spatial_idx_u - (
                        _i_spatial_idx_u == 1).type(torch.float32)
                    i_spatial_idx_u = i_spatial_idx_u.long()
                    cls_targets_list[opt_level +
                                     1][img, label][i_spatial_idx_u] = -1

                # reg-GT
                reg_targets_list[opt_level][
                    img, :, e_spatial_idx] = gt_bbox_obj_xyxy.unsqueeze(1)
        return cls_targets_list, reg_targets_list
Exemple #56
 def forward_predict(self, x, Nsamples=0):
     """This function is different from forward to compactly represent eval functions"""
     mu = self.forward(x)
     return mu, torch.zeros_like(mu)  # zero predictive variance for this deterministic eval
Exemple #57
    def forward(self, input, hidden, buf=None, slice_dim=0, mask=None):
        """
        Arguments:
            input (FloatTensor): Of size (batch_size, in_size)
            hidden (IntTensor): Of size (batch_size, 2 * h_size)
        """
        if buf is not None:
            buf_h1, buf_h2, buf_c1, buf_c2 = buf
        else:
            buf_h1 = buf_h2 = buf_c1 = buf_c2 = None

        group_size = self.h_size // 2
        h1 = hidden[:, :group_size]
        h2 = hidden[:, group_size:2*group_size]
        c1 = hidden[:, 2*group_size:3*group_size]
        c2 = hidden[:, 3*group_size:]
        mask = torch.ones_like(h1) if mask is None else mask[:, None].expand_as(h1)

        # Compute concatenated gates required to update h1, c1.
        h2_fl = ConvertToFloat.apply(h2, hidden_radix)
        zgfop1 = self.ih2_to_zgfop1(torch.cat([input, h2_fl], dim=1))

        # Compute gates necessary to update c1.
        z1 = F.sigmoid(zgfop1[:, :group_size])
        z1 = self.max_forget * z1 + (1 - self.max_forget)
        g1 = F.tanh(zgfop1[:, group_size:2*group_size])
        f1 = F.sigmoid(zgfop1[:, 2*group_size:3*group_size])

        # Apply update/forgetting for c1.
        z1_fix = ConvertToFixed.apply(z1, forget_radix) 
        z1_fix = EnsureNotForgetAll.apply(z1_fix, self.max_forget)
        c1 = FixedMultiply.apply(c1, z1_fix, buf_c1, mask)
        update_c1 = ConvertToFixed.apply(f1 * g1, hidden_radix)
        c1 = c1 + MaskFixedMultiply.apply(update_c1, mask)

        # Compute gates necessary to update h1.
        o1 = F.sigmoid(zgfop1[:, 3*group_size:4*group_size])
        p1 = F.sigmoid(zgfop1[:, 4*group_size:])
        p1 = self.max_forget * p1 + (1 - self.max_forget)

        # Apply update/forgetting for h1.
        c1_fl = ConvertToFloat.apply(c1, hidden_radix)
        p1_fix = ConvertToFixed.apply(p1, forget_radix)
        p1_fix = EnsureNotForgetAll.apply(p1_fix, self.max_forget)
        h1 = FixedMultiply.apply(h1, p1_fix, buf_h1, mask, slice_dim)
        update_h1 = ConvertToFixed.apply(o1 * F.tanh(c1_fl), hidden_radix)
        h1 = h1 + MaskFixedMultiply.apply(update_h1, mask)

        # Compute concatenated gates required to update h2, c2.
        h1_fl = ConvertToFloat.apply(h1, hidden_radix)
        zgfop2 = self.ih1_to_zgfop2(torch.cat([input, h1_fl], dim=1))

        # Compute gates necessary to update c2.
        z2 = F.sigmoid(zgfop2[:, :group_size])
        z2 = self.max_forget * z2 + (1 - self.max_forget)
        g2 = F.tanh(zgfop2[:, group_size:2*group_size])
        f2 = F.sigmoid(zgfop2[:, 2*group_size:3*group_size])

        # Apply update/forgetting for c2.
        z2_fix = ConvertToFixed.apply(z2, forget_radix)
        z2_fix = EnsureNotForgetAll.apply(z2_fix, self.max_forget)
        c2 = FixedMultiply.apply(c2, z2_fix, buf_c2, mask)
        update_c2 = ConvertToFixed.apply(f2 * g2, hidden_radix)
        c2 = c2 + MaskFixedMultiply.apply(update_c2, mask)

        # Compute gates necessary to update h2
        o2 = F.sigmoid(zgfop2[:, 3*group_size:4*group_size])
        p2 = F.sigmoid(zgfop2[:, 4*group_size:])
        p2 = self.max_forget * p2 + (1 - self.max_forget)

        # Apply update/forgetting for h2.
        c2_fl = ConvertToFloat.apply(c2, hidden_radix)
        p2_fix = ConvertToFixed.apply(p2, forget_radix)
        p2_fix = EnsureNotForgetAll.apply(p2_fix, self.max_forget)
        h2 = FixedMultiply.apply(h2, p2_fix, buf_h2, mask)
        update_h2 = ConvertToFixed.apply(o2 * F.tanh(c2_fl), hidden_radix)
        h2 = h2 + MaskFixedMultiply.apply(update_h2, mask)

        recurrent_hidden = torch.cat([h1, h2, c1, c2], dim=1)
        output_hidden = ConvertToFloat.apply(torch.cat([h1, h2], dim=1), hidden_radix)
        hidden_dict = {"recurrent_hidden": recurrent_hidden, "output_hidden": output_hidden}

        nonattn_p1 = p1[:, slice_dim:]
        optimal_bits1 = torch.sum(-torch.log(z1.data) / ln_2) + torch.sum(-torch.log(nonattn_p1.data) / ln_2)
        optimal_bits2 = torch.sum(-torch.log(z2.data) / ln_2) + torch.sum(-torch.log(p2.data) / ln_2)
        stats = {"optimal_bits": optimal_bits1 + optimal_bits2 + 32*slice_dim*input.size(0)}

        return hidden_dict, stats
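The forward pass above is constructed so that `reverse` can recover the previous hidden state exactly. Ignoring the fixed-point bookkeeping, the idea reduces to this scalar sketch:

# forward stores   h_new = h * p + o * tanh(c_new)
# reverse recovers h     = (h_new - o * tanh(c_new)) / p
h, p, update = 0.7, 0.9, 0.25
h_new = h * p + update
assert abs((h_new - update) / p - h) < 1e-12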
Exemple #58
    def step(self, closure):
        """Performs a single optimization step.
        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
        """
        if closure is None:
            raise RuntimeError(
                'For now, Vadam only supports that the model/loss can be reevaluated inside the step function'
            )

        grads = []
        grads2 = []
        for group in self.param_groups:
            for p in group['params']:
                grads.append([])
                grads2.append([])

        # Compute grads and grads2 using num_samples MC samples
        for s in range(self.num_samples):

            # Sample noise for each parameter
            pid = 0
            original_values = {}
            for group in self.param_groups:
                for p in group['params']:

                    original_values.setdefault(pid, p.detach().clone())
                    state = self.state[p]
                    # State initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p.data)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.ones_like(p.data) * (
                            group['prec_init'] -
                            group['prior_prec']) / self.train_set_size

                    # A noisy sample
                    raw_noise = torch.normal(mean=torch.zeros_like(p.data),
                                             std=1.0)
                    p.data.addcdiv_(
                        raw_noise,
                        torch.sqrt(self.train_set_size * state['exp_avg_sq'] +
                                   group['prior_prec']),
                        value=1.)

                    pid = pid + 1

            # Call the loss function and do BP to compute gradient
            loss = closure()

            # Replace original values and store gradients
            pid = 0
            for group in self.param_groups:
                for p in group['params']:

                    # Restore original parameters
                    p.data = original_values[pid]

                    if p.grad is None:
                        continue

                    if p.grad.is_sparse:
                        raise RuntimeError(
                            'Vadam does not support sparse gradients')

                    # Aggregate gradients
                    g = p.grad.detach().clone()
                    if s == 0:
                        grads[pid] = g
                        grads2[pid] = g**2
                    else:
                        grads[pid] += g
                        grads2[pid] += g**2

                    pid = pid + 1

        # Update parameters and states
        pid = 0
        for group in self.param_groups:
            for p in group['params']:

                if grads[pid] is None:
                    continue

                # Compute MC estimate of g and g2
                grad = grads[pid].div(self.num_samples)
                grad2 = grads2[pid].div(self.num_samples)

                tlambda = group['prior_prec'] / self.train_set_size

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad + tlambda * original_values[pid],
                                         alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).add_(grad2, alpha=1 - beta2)

                bias_correction1 = 1 - beta1**state['step']
                bias_correction2 = 1 - beta2**state['step']

                numerator = exp_avg.div(bias_correction1)
                denominator = exp_avg_sq.div(bias_correction2).sqrt().add(
                    tlambda)

                # Update parameters
                p.data.addcdiv_(numerator, denominator, value=-group['lr'])

                pid = pid + 1

        return loss
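A minimal sketch (hypothetical sizes and values) of the weight perturbation performed inside step(): each parameter is temporarily replaced by a sample from the variational posterior N(p, 1 / (train_set_size * exp_avg_sq + prior_prec)) before the loss is re-evaluated.

import torch

train_set_size, prior_prec = 50000, 1.0
p = torch.zeros(3)                       # current mean of the variational posterior
exp_avg_sq = torch.full((3,), 0.01)      # running estimate of squared gradients
raw_noise = torch.randn(3)
p_sample = p + raw_noise / torch.sqrt(train_set_size * exp_avg_sq + prior_prec)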
    def forward(self, source_ids, target_ids, num_source_tokens, num_target_tokens,pseudo_ids=None, target_span_ids=None):
        # note that here the source ids must not include any pseudo labels
        relevant_scores, _, relevant_doc_features = self.retrieval(
            input_ids = source_ids
        )
        # reconstruct source ids and target_ids
        # print(source_ids.shape)
        # print(target_ids.shape)
        # print(len(num_source_tokens))
        # print(len(num_target_tokens))
        # print(len(relevant_doc_features))
        source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens = self.concator.concate(source_ids, relevant_doc_features, target_ids, num_source_tokens, num_target_tokens)
        # print(source_ids)
        # print(pseudo_ids)
        source_ids = torch.tensor(source_ids, dtype=torch.long)
        target_ids = torch.tensor(target_ids, dtype=torch.long)
        pseudo_ids = torch.tensor(pseudo_ids, dtype=torch.long)
        num_source_tokens = torch.tensor(num_source_tokens,dtype=torch.long)
        num_target_tokens = torch.tensor(num_target_tokens, dtype=torch.long)


        source_len = source_ids.size(1)
        target_len = target_ids.size(1)
        pseudo_len = pseudo_ids.size(1)
        assert target_len == pseudo_len
        assert source_len > 0 and target_len > 0
        split_lengths = (source_len, target_len, pseudo_len)
        input_ids = torch.cat((source_ids, target_ids, pseudo_ids), dim=1)
        token_type_ids = torch.cat(
            (torch.ones_like(source_ids) * self.source_type_id,
             torch.ones_like(target_ids) * self.target_type_id,
             torch.ones_like(pseudo_ids) * self.target_type_id), dim=1)

        source_mask, source_position_ids = \
            self.create_mask_and_position_ids(num_source_tokens, source_len)
        target_mask, target_position_ids = \
            self.create_mask_and_position_ids(num_target_tokens, target_len, offset=num_source_tokens)
        position_ids = torch.cat((source_position_ids, target_position_ids, target_position_ids), dim=1)
        if target_span_ids is None:
            target_span_ids = target_position_ids
        attention_mask = self.create_attention_mask(source_mask, target_mask, source_position_ids, target_span_ids)
        device = torch.device("cuda")
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        position_ids = position_ids.to(device)
        target_ids = target_ids.to(device)
        # split_lengths = (x.to(device) for x in split_lengths)

        outputs = self.bert(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
            position_ids=position_ids, split_lengths=split_lengths)
        sequence_output = outputs[0] # [batch_size * top_k ,sequence_length, embeddings]
        pseudo_sequence_output = sequence_output[:, source_len + target_len:, ] #[batch_size * top_k, sequence_length, embeddings]
        # print(pseudo_sequence_output.shape)
        def loss_mask_and_normalize(loss, mask):
            mask = mask.type_as(loss)
            loss = loss * mask
            denominator = torch.sum(mask) + 1e-5
            return (loss / denominator).sum()

        # make the prediction based on softmax-normalised relevance scores
        # relevant_scores.to(self.bert.device)
        pseudo_sequence_output = pseudo_sequence_output.view(relevant_scores.size(0), self.top_k, -1)
        #print(relevant_scores.shape)
        #print(pseudo_sequence_output.shape)
        pseudo_sequence_output = torch.bmm(relevant_scores.unsqueeze(1), pseudo_sequence_output)

        pseudo_sequence_output = pseudo_sequence_output.view(relevant_scores.size(0), -1 , self.config.hidden_size)
        #print(pseudo_sequence_output.shape)
        #print(target_ids.shape)
        target_ids = target_ids[:(relevant_scores.size(0))]
        prediction_scores_masked = self.cls(pseudo_sequence_output)
        #print(prediction_scores_masked.shape)
        #print(target_ids.shape)
        # print(target_mask)
        # print(target_mask.shape)
        target_mask = target_mask[:(relevant_scores.size(0))]
        if self.crit_mask_lm_smoothed:
            masked_lm_loss = self.crit_mask_lm_smoothed(
                F.log_softmax(prediction_scores_masked.float(), dim=-1), target_ids)
        else:
            masked_lm_loss = self.crit_mask_lm(
                prediction_scores_masked.transpose(1, 2).float(), target_ids)
        pseudo_lm_loss = loss_mask_and_normalize(
            masked_lm_loss.float(), target_mask)

        return pseudo_lm_loss
Exemple #60
 def forward(self, a):
     return torch.ones_like(a)