Пример #1
0
    def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
        r"""Eliminates all but the first element from every consecutive group of equivalent elements.

        Exact :class:`Tensor` instances go straight to
        :func:`torch.unique_consecutive`; subclasses that participate in the
        ``__torch_function__`` protocol are routed through
        :func:`torch.overrides.handle_torch_function` instead.

        See :func:`torch.unique_consecutive`
        """
        from torch.overrides import has_torch_function, handle_torch_function

        dispatch_args = (self,)
        options = {
            "return_inverse": return_inverse,
            "return_counts": return_counts,
            "dim": dim,
        }
        if type(self) is Tensor or not has_torch_function(dispatch_args):
            return torch.unique_consecutive(self, **options)
        return handle_torch_function(
            Tensor.unique_consecutive, dispatch_args, self, **options)
Пример #2
0
    def forward(self, log_h, y):
        """Negative Cox partial log-likelihood of a batch.

        Args:
            log_h: Predicted log hazards; flattened so there is one value per
                sample.
            y: Targets whose transpose unpacks into ``(durations, events)`` —
                presumably shape ``(n, 2)`` with a non-zero event flag for
                observed events (TODO confirm against callers).

        Returns:
            Scalar tensor ``log_den - log_num``.

        Raises:
            ValueError: if ``self.method`` is neither ``"breslow"`` nor
                ``"efron"`` (this used to surface as an UnboundLocalError on
                ``log_den``).
        """
        log_h = log_h.flatten()

        durations, events = y.T

        # sort input by decreasing duration so that a cumulative reduction at
        # position i covers every sample with duration >= durations[i]
        durations, idx = durations.sort(descending=True)
        log_h = log_h[idx]
        events = events[idx]

        event_ind = events.nonzero().flatten()

        # numerator
        log_num = log_h[event_ind].mean()

        # logcumsumexp of events
        event_lcse = torch.logcumsumexp(log_h, dim=0)[event_ind]

        # number of events for each unique risk set
        _, tie_inverses, tie_count = torch.unique_consecutive(
            durations[event_ind], return_counts=True, return_inverse=True)

        # position of last event (lowest duration) of each unique risk set
        # NOTE(review): censored samples tied with an event time but sorted
        # after the last event of that tie are not included in that tie's
        # cumulative sum — confirm this is acceptable.
        tie_pos = tie_count.cumsum(dim=0) - 1

        # logcumsumexp by tie for each event
        event_tie_lcse = event_lcse[tie_pos][tie_inverses]

        if self.method == "breslow":
            log_den = event_tie_lcse.mean()

        elif self.method == "efron":
            # based on https://bydmitry.github.io/efron-tensorflow.html

            # logsumexp of ties, duplicated within tie set
            tie_lse = scatter_logsumexp(log_h[event_ind], tie_inverses,
                                        dim=0)[tie_inverses]
            # multiply (add in log space) with corrective factor
            # event_id_in_tie is the 0-based rank of each event inside its tie
            # group: aux restarts the cumulative sum at each group start.
            # log(0) = -inf for the first event means "no discount"
            # (presumably log_substract(a, -inf) == a; TODO confirm).
            aux = torch.ones_like(tie_inverses)
            aux[tie_pos[:-1] + 1] -= tie_count[:-1]
            event_id_in_tie = torch.cumsum(aux, dim=0) - 1
            discounted_tie_lse = (tie_lse + torch.log(event_id_in_tie) -
                                  torch.log(tie_count[tie_inverses]))

            # denominator
            log_den = log_substract(event_tie_lcse, discounted_tie_lse).mean()

        else:
            raise ValueError(
                f"unknown tie-handling method {self.method!r}; "
                "expected 'breslow' or 'efron'")

        # loss is negative log likelihood
        return log_den - log_num
Пример #3
0
    def forward(self, x_all, segment_key):
        """Convolve each contiguous segment separately, then apply ``self.aft_conv``.

        Rows of ``x_all`` are grouped into runs of equal consecutive values in
        ``segment_key``. Each run is convolved independently with
        ``self.conv``, and the per-run outputs are concatenated back in input
        order.

        Splitting by run lengths keeps every row exactly once. The previous
        masking with ``segment_key == value`` over the run values processed a
        key that occurs in two separate runs once per run. That duplicated
        those rows in the output.

        Args:
            x_all: Per-row features; dim 0 must match ``segment_key``
                (assumed 2-D ``(rows, channels)`` — TODO confirm).
            segment_key: 1-D key per row; rows of a segment are consecutive.

        Returns:
            ``self.aft_conv`` applied to the concatenated per-segment outputs.
        """
        _, run_lengths = torch.unique_consecutive(segment_key, return_counts=True)
        x_all_new = []
        for x in torch.split(x_all, run_lengths.tolist(), dim=0):
            # (L, C) -> (1, C, L): treat the run as one channels-first sequence
            x = x.unsqueeze(0).transpose(1, 2)
            x = self.conv(x)
            # (1, C', L') -> (L', C')
            x = x.squeeze(0).transpose(0, 1)
            x_all_new.append(x)

        x_all_new = torch.cat(x_all_new, dim=0)
        return self.aft_conv(x_all_new)
Пример #4
0
def _get_to_orthogonalize(matrix: torch.Tensor, filt_len: int) -> torch.Tensor:
    """Find matrix rows whose number of stored entries differs from filt_len.

    The returned rows will need to be orthogonalized. Rows are read from
    the coalesced (row-sorted) sparse indices. A row with no stored entry
    at all does not appear there and is therefore never returned.

    Args:
        matrix (torch.Tensor): The sparse wavelet matrix under consideration.
        filt_len (int): The number of entries we would expect per row.

    Returns:
        torch.Tensor: The row indices whose entry count is not filt_len.
    """
    row_of_entry = matrix.coalesce().indices()[0, :]
    rows, entries_per_row = torch.unique_consecutive(row_of_entry,
                                                     return_counts=True)
    mismatched = entries_per_row != filt_len
    return rows[mismatched]
Пример #5
0
    def forward_pain(self, input_dict):
        """Encode frames, run ``self.lstm`` per segment, and score pain.

        Args:
            input_dict: Mapping with at least ``'img_crop'`` (input to
                ``self.encoder``; dim 0 is the batch) and ``'segment_key'``
                (one key per batch row; rows of a segment are consecutive).

        Returns:
            dict restricted to the keys in ``self.output_types``, chosen from
            ``'pain'`` (logits from ``self.to_pain[1]``), ``'pain_pred'``
            (softmax over dim 1) and ``'segment_key'`` (passed through).
        """
        img_crop = input_dict['img_crop']
        segment_key = input_dict['segment_key']
        batch_size = img_crop.size(0)

        out_enc_conv = self.encoder(img_crop)
        center_flat = out_enc_conv.view(batch_size, -1)
        latent_3d = self.to_3d(center_flat)

        # Split by run lengths so every row lands in exactly one segment, in
        # input order. Masking with `segment_key == value` duplicated rows
        # whenever a key value reappeared in a later run.
        _, run_lengths = torch.unique_consecutive(segment_key, return_counts=True)
        h_n_all = []
        for rel_latent in torch.split(latent_3d, run_lengths.tolist(), dim=0):
            init_size = rel_latent.size(0)

            rel_latent = self.pad_input(rel_latent)
            out, _ = self.lstm(rel_latent)

            # merge the first two dims, then keep only the first init_size
            # rows (presumably dropping whatever pad_input added — confirm)
            out = out.view(out.size(0) * out.size(1), -1)
            h_n_all.append(out[:init_size, :])

        h_n = torch.cat(h_n_all, dim=0)
        output_pain = self.to_pain[1](h_n)

        pain_pred = torch.nn.functional.softmax(output_pain, dim=1)

        ###############################################
        # Select the right output
        output_dict_all = {
            'pain': output_pain,
            'pain_pred': pain_pred,
            'segment_key': segment_key
        }
        return {key: output_dict_all[key] for key in self.output_types}
Пример #6
0
def invert_indices(qs, n_qs):
    """Build a :class:`Ragged` that inverts an index-to-value list.

    ``qs[i]`` is read as the value assigned to index ``i``. The result groups
    the indices by value: its first argument is the indices reordered by
    sorted value, and ``cardinalities[v]`` is how many indices carry value
    ``v`` for ``v`` in ``range(n_qs)``.
    """
    values, order = torch.sort(qs)
    if len(qs) == 0:
        # As of Pytorch 1.4, unique_consecutive breaks on empty inputs.
        unique = torch.zeros_like(values)
        counts = torch.zeros_like(values)
    else:
        unique, counts = torch.unique_consecutive(values, return_counts=True)

    cardinalities = qs.new_zeros(n_qs)
    cardinalities[unique] = counts
    return Ragged(order, cardinalities)
Пример #7
0
 def forward(self, inputs):
     """Run the backbone once per run of same-resolution crops, then the head.

     ``inputs`` is a tensor or a list of tensors. Consecutive entries sharing
     the same last-dimension size are concatenated and sent through
     ``self.forward_backbone`` in one call, after being moved with
     ``.cuda(non_blocking=True)``. All backbone outputs are concatenated
     before ``self.forward_head``.
     """
     crops = inputs if isinstance(inputs, list) else [inputs]
     widths = torch.tensor([crop.shape[-1] for crop in crops])
     run_lengths = torch.unique_consecutive(widths, return_counts=True)[1]
     boundaries = torch.cumsum(run_lengths, 0)

     begin = 0
     for finish in boundaries:
         batch = torch.cat(crops[begin:finish]).cuda(non_blocking=True)
         features = self.forward_backbone(batch)
         output = features if begin == 0 else torch.cat((output, features))
         begin = finish
     return self.forward_head(output)
Пример #8
0
def from_pairs(pairs, n_ps, n_qs):
    """Build a :class:`Ragged` from ``(q, p)`` pairs; a ``q`` may repeat.

    ``pairs.T`` unpacks into ``qs`` and ``ps``. The ``ps`` are reordered by
    sorted ``qs``, and ``cardinalities[q]`` counts the pairs carrying each
    ``q`` in ``range(n_qs)``. ``n_ps`` is accepted but not used here.
    """
    qs, ps = pairs.T
    sorted_qs, order = torch.sort(qs)
    image = ps[order]
    if len(pairs) == 0:
        # As of Pytorch 1.4, unique_consecutive breaks on empty inputs.
        unique = torch.zeros_like(sorted_qs)
        counts = torch.zeros_like(sorted_qs)
    else:
        unique, counts = torch.unique_consecutive(sorted_qs, return_counts=True)

    cardinalities = qs.new_zeros(n_qs)
    cardinalities[unique] = counts
    return Ragged(image, cardinalities)
Пример #9
0
    def greedy_assignment(self, scores, k=1):
        """Send each token to its top-``k`` scoring workers and exchange split sizes.

        Returns:
            ``(worker2token, input_splits, output_splits)``:

            - ``worker2token`` lists source token ids ordered by destination
              worker.
            - ``output_splits[w]`` is the number of entries sent to worker
              ``w``.
            - ``input_splits`` is ``All2All.apply(output_splits)``.

            Both split counts are returned as Python lists.
        """
        top_workers = torch.topk(scores, k=k, dim=1, largest=True).indices.view(-1)
        sorted_workers, order = torch.sort(top_workers)
        # every token contributed k consecutive entries to the flattened top-k
        worker2token = order // k

        # workers that receive nothing keep a zero count
        output_splits = torch.zeros((self.num_workers, ),
                                    dtype=torch.long,
                                    device=scores.device)
        busy_workers, per_worker = torch.unique_consecutive(sorted_workers,
                                                            return_counts=True)
        output_splits[busy_workers] = per_worker

        # tell the other workers how many tokens to expect from us
        input_splits = All2All.apply(output_splits)
        return worker2token, input_splits.tolist(), output_splits.tolist()
Пример #10
0
def piecewise_arange(piecewise_idxer):
    """
    count repeated indices

    For every run of consecutive equal values, emit 0, 1, ..., run_length - 1.
    The input is treated as flattened, as ``torch.unique_consecutive`` does
    without ``dim``. An empty input yields an empty int64 tensor.

    Example:
    [0, 0, 0, 3, 3, 3, 3, 1, 1, 2] -> [0, 1, 2, 0, 1, 2, 3, 0, 1, 0]
    """
    dv = piecewise_idxer.device
    total = piecewise_idxer.numel()
    if total == 0:
        # there are no runs; taking a max over empty counts would raise
        return torch.zeros(0, dtype=torch.long, device=dv)
    _, counts = torch.unique_consecutive(piecewise_idxer, return_counts=True)
    # offset of the first element of each run, repeated over that run
    run_starts = torch.cumsum(counts, dim=0) - counts
    # global position minus its run's start = position inside the run
    # (O(n) memory, unlike a dense num_runs x max_run_length mask)
    return torch.arange(total, device=dv) - torch.repeat_interleave(run_starts, counts)
Пример #11
0
    def forward(self, logits: torch.Tensor) -> str:
        """Greedy (best-path) decoding of one sequence of label logits.

        The arg-max label is taken at every step and consecutive repeats are
        collapsed. Labels equal to ``self.blank`` are then dropped, and the
        rest are looked up in ``self.labels`` and joined.

        Args:
            logits (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        step_labels = torch.argmax(logits, dim=-1)  # [num_seq,]
        collapsed = torch.unique_consecutive(step_labels, dim=-1)
        return ''.join(self.labels[i] for i in collapsed if i != self.blank)
Пример #12
0
def confidence_based_inlier_selection(residuals, ransidx, rdims, idxoffsets,
                                      dv, min_confidence):
    """Select inliers per RANSAC problem from the best-scoring iteration.

    NOTE(review): the semantics below are read off the visible tensor
    operations. ``stable_sort_residuals``, ``group_sum_and_cumsum`` and
    ``arange_sequence`` are defined elsewhere; confirm against them.

    Args:
        residuals: Residual matrix; dim 0 indexes iterations/hypotheses
            (``numiters = residuals.shape[0]``). Presumably shaped
            (numiters, total_samples) — TODO confirm.
        ransidx: Problem id of each sample, used for grouping.
        rdims: Number of samples of each problem
            (``numransacs = rdims.shape[0]``).
        idxoffsets: Start offset of each problem along the sample axis.
        dv: Device for newly created index tensors.
        min_confidence: Factor applied to squared residuals before comparing
            them with the progressive inlier rate.

    Returns:
        Tuple ``(inl_ransidx, inl_sampleidx, inl_counts, inl_iters, score)``:

        - ``inl_ransidx``: problem id of every selected inlier.
        - ``inl_sampleidx``: sample index of every selected inlier.
        - ``inl_counts``: the per-problem inlier count.
        - ``inl_iters``: the chosen iteration of each problem.
        - ``score``: ``1 - expected_extra_inl / inl_counts``.
    """
    numransacs = rdims.shape[0]
    numiters = residuals.shape[0]

    sorted_res, sorting_idxes = stable_sort_residuals(residuals, ransidx)
    sorted_res_sqr = sorted_res**2

    # Near-zero residuals get zero weight below but are always accepted by
    # good_inl_mask.
    too_perfect_fits = sorted_res_sqr <= 1e-8
    # Index of the last sample of each problem along the sample axis.
    end_rans_indexing = torch.cumsum(rdims, dim=0) - 1

    # With dim=1, consecutive identical *columns* (compared after a round trip
    # through float16) are collapsed: inv_indices maps each column to its
    # group, res_dup_counts is each group's size.
    _, inv_indices, res_dup_counts = torch.unique_consecutive(
        sorted_res_sqr.half().float(),
        dim=1,
        return_counts=True,
        return_inverse=True)

    # Weight each column by 1/group size so every duplicate group sums to 1
    # per row; the same weights are used for every iteration.
    duplicates_per_sample = res_dup_counts[inv_indices]
    inlier_weights = (1. / duplicates_per_sample).repeat(numiters, 1)
    inlier_weights[too_perfect_fits] = 0.

    balanced_rdims, weights_cumsums = group_sum_and_cumsum(
        inlier_weights, end_rans_indexing, ransidx)
    # Running weighted inlier rate along each problem's sorted samples
    # (presumably balanced_rdims is the per-problem weighted size — confirm).
    progressive_inl_rates = weights_cumsums / (
        balanced_rdims.repeat_interleave(rdims, dim=1)).float()

    # Keep a sample when its scaled squared residual does not exceed the
    # running inlier rate, or when it is a near-perfect fit.
    good_inl_mask = (sorted_res_sqr * min_confidence <=
                     progressive_inl_rates) | too_perfect_fits

    inlier_weights[~good_inl_mask] = 0.
    inlier_counts_matrix, _ = group_sum_and_cumsum(inlier_weights,
                                                   end_rans_indexing)

    # Best iteration (over dim 0) per problem, and its inlier count.
    inl_counts, inl_iters = torch.max(inlier_counts_matrix, dim=0)

    # Gather the first inl_counts sorted samples of each problem at its best
    # iteration (assumes arange_sequence yields 0..count-1 per problem).
    relative_inl_idxes = arange_sequence(inl_counts)
    inl_ransidx = torch.arange(numransacs,
                               device=dv).repeat_interleave(inl_counts)
    inl_sampleidx = sorting_idxes[inl_iters.repeat_interleave(inl_counts),
                                  idxoffsets[inl_ransidx] + relative_inl_idxes]
    # Squared residual of the last accepted sorted sample of each problem.
    highest_accepted_sqr_residuals = sorted_res_sqr[inl_iters, idxoffsets +
                                                    inl_counts - 1]
    expected_extra_inl = balanced_rdims[
        inl_iters, torch.arange(numransacs, device=dv)].float(
        ) * highest_accepted_sqr_residuals
    return inl_ransidx, inl_sampleidx, inl_counts, inl_iters, 1. - expected_extra_inl / inl_counts.float(
    )
Пример #13
0
    def sample(self, xs, batch_size=None, **kwargs):
        """Draw samples for each distinct ``x`` by importance resampling.

        ``xs`` holds each distinct ``x`` repeated in a consecutive run, one
        row per requested sample. All runs must have the same length, and
        that length becomes the number of samples drawn per ``x``.

        Args:
            xs: Conditioning inputs; moved to ``self.device``.
            batch_size: Optional number of distinct ``x`` processed per chunk.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Tensor of shape ``(n_distinct_x * n_samples, self.y_dim)``.

        Raises:
            ValueError: if the runs of repeated ``x`` differ in length. Also
                raised if the number of samples is not smaller than
                ``self.imp_samples``.
        """
        xs = xs.to(self.device)

        # Fix since x is unnecessarily repeated: keep one row per run.
        unique_xs, counts = torch.unique_consecutive(xs,
                                                     return_counts=True,
                                                     dim=0)
        n_samples = counts[0].item()
        # Unequal runs would silently yield the wrong number of rows per x.
        if not bool((counts == n_samples).all()):
            raise ValueError(
                "All consecutive runs of repeated x must have the same length "
                f"to infer the sample count; got run lengths {counts.tolist()}")

        if n_samples >= self.imp_samples:
            raise ValueError(
                """Trying to draw more or same amount of samples from DCTD model as is used
          for importance sampling. Increase amount of importance samples used.""")

        # Batch over x:s
        if batch_size:
            batches = torch.split(unique_xs, batch_size, dim=0)
        else:
            batches = (unique_xs, )

        sample_list = []
        for x_batch in batches:
            proposal_means = self.find_modes(x_batch)

            _, imp_weights, imp_samples = self.estimate_log_z(
                x_batch, proposal_means, return_weights=True)
            # imp_samples shape: (n_batch, n_imp_samples, dim_y)

            # Sample index from categorical distribution parametrized
            # by importance weights
            sample_indexes = torch.multinomial(imp_weights,
                                               n_samples,
                                               replacement=True)
            # shape: (n_batch, n_samples)

            indexes = sample_indexes.unsqueeze(2).repeat(1, 1, self.y_dim)
            sample_batch = torch.gather(imp_samples, dim=1, index=indexes)
            # Shape (n_batch, n_samples, dim_y) -> (n_batch*n_samples, dim_y)
            sample_list.append(
                torch.flatten(sample_batch, start_dim=0, end_dim=1))

        return torch.cat(sample_list, dim=0)
Пример #14
0
def from_pairs(pairs, n_ps, n_qs):
    """Creates a :class:`Ragged` from a list of index-value pairs. A single index can occur many times.

    The values are reordered by sorted index, and ``cardinalities[q]`` holds
    how many pairs use index ``q`` (for ``q`` in ``range(n_qs)``).
    ``n_ps`` is not used here.
    """
    qs, ps = pairs.T
    by_index = torch.sort(qs)
    image = ps[by_index.indices]

    has_pairs = len(pairs) > 0
    if not has_pairs:
        # As of Pytorch 1.4, unique_consecutive breaks on empty inputs.
        unique = torch.zeros_like(by_index.values)
        counts = torch.zeros_like(by_index.values)
    else:
        unique, counts = torch.unique_consecutive(by_index.values,
                                                  return_counts=True)

    cardinalities = qs.new_zeros(n_qs)
    cardinalities[unique] = counts
    return Ragged(image, cardinalities)
Пример #15
0
def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Input: indices for pairs of atoms that are close to each other.
    each pair only appear once, i.e. only one of the pairs (1, 2) and
    (2, 1) exists.

    Output: indices for all central atoms and its pairs of neighbors. For
    example, if input has pair (0, 1), (0, 2), (0, 3), (0, 4), (1, 2),
    (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), then the output would have
    central atom 0, 1, 2, 3, 4 and for central atom 0, its pairs of neighbors
    are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)

    Args:
        atom_index12: Atom pairs; presumably shape (2, n_pairs), with column
            j holding the two atoms of pair j (``n = atom_index12.shape[1]``
            is used as the row length below).

    Returns:
        ``(central_atom_index, local_index12 % n, sign12)``:

        - ``central_atom_index``: the central atom of each pair of neighbor
          pairs.
        - ``local_index12 % n``: a (2, n_triples) tensor of pair (column)
          indices.
        - ``sign12``: int8, +1 where the central atom sits in row 0 of that
          pair and -1 where it sits in row 1.
    """
    # convert representation from pair to central-others
    # flattened position p refers to row p // n, column p % n
    ai1 = atom_index12.view(-1)
    sorted_ai1, rev_indices = ai1.sort()

    # sort and compute unique key
    # only the unique values ([0]) and counts ([-1]) are used below; the
    # inverse indices in the middle are not needed
    unique_results = torch.unique_consecutive(sorted_ai1,
                                              return_inverse=True,
                                              return_counts=True)
    uniqued_central_atom_index = unique_results[0]
    counts = unique_results[-1]

    # compute central_atom_index
    # an atom occurring c times yields c * (c - 1) / 2 neighbor pairs
    pair_sizes = counts * (counts - 1) // 2
    # pair_indices[k] = which unique atom owns output pair k
    pair_indices = torch.repeat_interleave(pair_sizes)
    central_atom_index = uniqued_central_atom_index.index_select(
        0, pair_indices)

    # do local combinations within unique key, assuming sorted
    # m = largest group size (0 when there are no atoms at all)
    m = counts.max().item() if counts.numel() > 0 else 0
    n = pair_sizes.shape[0]
    # strictly-lower-triangular (row, col) pairs of an m x m matrix,
    # broadcast to every group: shape (2, n_groups, m * (m - 1) / 2)
    intra_pair_indices = torch.tril_indices(
        m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1)
    # tril_indices enumerates row by row, so its first c * (c - 1) / 2 entries
    # only use local indices < c; keep exactly that prefix for each group
    mask = (torch.arange(intra_pair_indices.shape[2], device=ai1.device) <
            pair_sizes.unsqueeze(1)).flatten()
    sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask]
    # shift group-local indices to positions in sorted_ai1 (assumes
    # cumsum_from_zero is an exclusive prefix sum — TODO confirm)
    sorted_local_index12 += cumsum_from_zero(counts).index_select(
        0, pair_indices)

    # unsort result from last part
    # map sorted positions back to flattened positions in atom_index12
    local_index12 = rev_indices[sorted_local_index12]

    # compute mapping between representation of central-other to pair
    # flattened positions < n lie in row 0 (sign +1), the rest in row 1 (-1)
    n = atom_index12.shape[1]
    sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1
    return central_atom_index, local_index12 % n, sign12
Пример #16
0
def check_repeat(features,
                 indices,
                 features_add=None,
                 sort_first=True,
                 flip_first=True):
    """Check whether there are repeated indices in the sparse features and
    merge the repeated entries if any.

    Adjacent rows whose coordinates ``indices[:, 1:]`` (three columns) are
    equal get merged into one row:

    - their ``features`` are summed;
    - their ``features_add`` values are averaged;
    - one representative row of ``indices`` is kept.

    Only adjacent duplicates are detected, which is why ``sort_first`` is on
    by default.

    NOTE(review): column 0 of ``indices`` is not part of the key (presumably
    a batch index). Confirm that rows of different batches cannot end up
    adjacent with equal coordinates.

    Args:
        features: Per-row features; dim 0 matches ``indices``.
        indices: Integer index rows (at least four columns).
        features_add: Optional 1-D per-row values.
        sort_first: Sort all inputs with ``sort_by_indices`` first.
        flip_first: Reverse the row order of all inputs first.

    Returns:
        ``(features, indices, features_add)`` after merging.
    """
    if sort_first:
        features, indices, features_add = sort_by_indices(
            features, indices, features_add)

    if flip_first:
        features, indices = features.flip([0]), indices.flip([0])
        # features_add must be flipped together with the other two inputs;
        # flipping it unconditionally misaligned it when flip_first=False
        if features_add is not None:
            features_add = features_add.flip([0])

    if indices.shape[0] == 0:
        # nothing to merge (and the per-axis min/max below need a row)
        return features, indices, features_add

    # Encode each (x, y, z) triple as one integer key. The coordinates are
    # shifted to start at 0 and extent = max + 1 is used as the mixed radix,
    # which makes the key injective. A radix of max let e.g. (0, 0, 2) and
    # (0, 1, 0) collide. int64 avoids overflow on large grids.
    idx = indices[:, 1:].long()
    idx = idx - idx.min(dim=0).values
    extent = idx.max(dim=0).values + 1
    idx_sum = (idx[:, 0] * extent[1] + idx[:, 1]) * extent[2] + idx[:, 2]
    _unique, inverse, counts = torch.unique_consecutive(idx_sum,
                                                        return_inverse=True,
                                                        return_counts=True,
                                                        dim=0)

    n_unique = _unique.shape[0]
    if n_unique < indices.shape[0]:
        perm = torch.arange(inverse.size(0),
                            dtype=inverse.dtype,
                            device=inverse.device)
        # sum the features of each group of duplicates (same dtype/device)
        features_new = features.new_zeros((n_unique, ) +
                                          tuple(features.shape[1:]))
        features_new.index_add_(0, inverse, features)
        features = features_new
        # keep one representative index row per group; scatter_ writes the
        # position of some member of the group (which one is unspecified)
        perm_ = inverse.new_empty(n_unique).scatter_(0, inverse, perm)
        indices = indices[perm_].int()

        if features_add is not None:
            features_add_new = features_add.new_zeros((n_unique, ))
            features_add_new.index_add_(0, inverse, features_add)
            features_add = features_add_new / counts
    return features, indices, features_add
Пример #17
0
def triple_by_molecule(atom_index1, atom_index2):
    """Input: indices for pairs of atoms that are close to each other.
    each pair only appear once, i.e. only one of the pairs (1, 2) and
    (2, 1) exists.

    Output: indices for all central atoms and its pairs of neighbors. For
    example, if input has pair (0, 1), (0, 2), (0, 3), (0, 4), (1, 2),
    (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), then the output would have
    central atom 0, 1, 2, 3, 4 and for central atom 0, its pairs of neighbors
    are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)

    Args:
        atom_index1: First atom of each pair (``n = atom_index1.shape[0]``).
        atom_index2: Second atom of each pair; presumably the same length n.

    Returns:
        ``(central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2)``:

        - ``central_atom_index``: the shared central atom.
        - ``local_index1 % n``, ``local_index2 % n``: the pair indices of the
          two neighbor pairs.
        - ``sign1``, ``sign2``: +1 where the central atom came from
          ``atom_index1`` and -1 where it came from ``atom_index2``.
    """
    # convert representation from pair to central-others
    # position p < n of ai1 is atom_index1[p]; p >= n is atom_index2[p - n]
    n = atom_index1.shape[0]
    ai1 = torch.cat([atom_index1, atom_index2])
    sorted_ai1, rev_indices = ai1.sort()

    # sort and compute unique key
    uniqued_central_atom_index, counts = torch.unique_consecutive(
        sorted_ai1, return_counts=True)

    # do local combinations within unique key, assuming sorted
    # an atom occurring c times yields c * (c - 1) / 2 neighbor pairs
    pair_sizes = counts * (counts - 1) // 2
    total_size = pair_sizes.sum()
    # pair_indices[k] = which unique atom owns output pair k
    pair_indices = torch.repeat_interleave(pair_sizes)
    central_atom_index = uniqued_central_atom_index.index_select(
        0, pair_indices)
    # rank of every output pair inside its own group (assumes
    # cumsum_from_zero is an exclusive prefix sum — TODO confirm)
    cumsum = cumsum_from_zero(pair_sizes)
    cumsum = cumsum.index_select(0, pair_indices)
    sorted_local_pair_index = torch.arange(total_size,
                                           device=cumsum.device) - cumsum
    # convert_pair_index (defined elsewhere) turns a pair rank into the two
    # group-local member indices
    sorted_local_index1, sorted_local_index2 = convert_pair_index(
        sorted_local_pair_index)
    # shift group-local indices to positions in sorted_ai1
    cumsum = cumsum_from_zero(counts)
    cumsum = cumsum.index_select(0, pair_indices)
    sorted_local_index1 += cumsum
    sorted_local_index2 += cumsum

    # unsort result from last part
    local_index1 = rev_indices[sorted_local_index1]
    local_index2 = rev_indices[sorted_local_index2]

    # compute mapping between representation of central-other to pair
    # positions < n came from atom_index1 (+1), the rest from atom_index2 (-1)
    sign1 = ((local_index1 < n) * 2).to(torch.long) - 1
    sign2 = ((local_index2 < n) * 2).to(torch.long) - 1
    return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2
Пример #18
0
    def _project2D_edges_init(self, rot_mat, edge_index, edge_distance_vec):
        """Build edge-to-edge pairs and pass them to ``self._project2D_init``.

        Every edge ``e`` is paired with every edge ``f`` whose target node
        (``edge_index[1]``) equals ``e``'s source node (``edge_index[0]``).
        The flat lists ``source_edge`` (the ``f``) and ``target_edge``
        (the ``e``) are forwarded together with ``rot_mat`` and
        ``edge_distance_vec``.

        Edges sharing a target node must form one consecutive run in
        ``edge_index[1]``. The runs need not be sorted: the old prefix sum
        over node ids silently required ascending targets. This version no
        longer changes global ``torch`` print options.
        """
        length = len(edge_distance_vec)
        device = edge_distance_vec.device

        # one run per target node
        _, neigh_count = torch.unique_consecutive(
            edge_index[1], return_counts=True)
        max_neighbors = int(neigh_count.max().item())

        # slot of every edge inside its target's run: 0, 1, ..., count - 1
        run_starts = torch.cumsum(neigh_count, dim=0) - neigh_count
        edge_ids = torch.arange(length, device=device)
        neigh_index = edge_ids - torch.repeat_interleave(run_starts, neigh_count)

        # target_lookup[node, k] = id of the k-th edge ending at node, else -1
        edge_map_index = edge_index[1] * max_neighbors + neigh_index
        target_lookup = torch.full((self.num_atoms * max_neighbors, ), -1,
                                   dtype=torch.long, device=device)
        target_lookup.index_copy_(0, edge_map_index.long(), edge_ids)
        target_lookup = target_lookup.view(self.num_atoms, max_neighbors)

        # for each edge, the edges arriving at its source node are its sources
        source_edge = target_lookup[edge_index[0]].view(-1)
        target_edge = edge_ids.view(-1, 1).repeat(1, max_neighbors).view(-1)

        # drop the -1 padding slots
        mask_unused = source_edge.ge(0)
        source_edge = torch.masked_select(source_edge, mask_unused)
        target_edge = torch.masked_select(target_edge, mask_unused)

        return self._project2D_init(source_edge, target_edge, rot_mat,
                                    edge_distance_vec)
Пример #19
0
 def forward(self, x):
     """Multi-crop forward: one backbone call per run of equal-size crops.

     A single tensor is treated as a one-crop list. Consecutive crops whose
     last dimension matches are concatenated for ``self.backbone``. The
     backbone outputs are concatenated and passed once through ``self.head``.
     """
     crops = x if isinstance(x, list) else [x]
     last_dims = torch.tensor([crop.shape[-1] for crop in crops])
     _, per_size = torch.unique_consecutive(last_dims, return_counts=True)
     ends = torch.cumsum(per_size, 0)
     lo = 0
     for hi in ends:
         feats = self.backbone(torch.cat(crops[lo:hi]))
         output = feats if lo == 0 else torch.cat((output, feats))
         lo = hi
     # Run the head forward on the concatenated features.
     return self.head(output)
    def forward(self, x: Union[List[Tensor], Tensor]) -> Tensor:
        """Apply backbone and head; a list of crops is grouped by resolution.

        A single tensor goes straight through ``backbone`` and ``head``. For a
        list, consecutive crops with the same last-dimension size share one
        ``backbone`` call. Their outputs are concatenated before one ``head``
        call.
        """
        if isinstance(x, Tensor):
            return self.head(self.backbone(x))

        _, run_sizes = torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )
        start_idx = 0
        for end_idx in torch.cumsum(run_sizes, 0):
            features = self.backbone(torch.cat(x[start_idx:end_idx]))
            if start_idx == 0:
                output: Tensor = features
            else:
                output = torch.cat((output, features))  # type: ignore
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        return self.head(output)  # type: ignore
Пример #21
0
 def forward(self, token_seq):
     """Average BERT sub-token embeddings into one vector per token.

     ``token_seq[..., 0]`` holds the token id of each sub-token and
     ``token_seq[..., 1]`` the sub-token id fed to BERT. Positions whose
     sub-token id equals the tokenizer's pad id are masked out.

     Returns:
         list with one tensor per sequence, stacking the mean embedding of
         each run of consecutive equal token ids.
     """
     sub_ids = token_seq[:, :, 1]
     mask = torch.ne(sub_ids, self.bert_tokenizer.pad_token_id)
     bert_emb_tokens = self.bert(sub_ids, attention_mask=mask).last_hidden_state
     emb_tokens = []
     for seq, emb, keep in zip(token_seq, bert_emb_tokens, mask):
         # group consecutive sub-tokens that share a token id
         _, run_lengths = torch.unique_consecutive(seq[:, 0][keep],
                                                   return_counts=True)
         # plain ints for the split sizes rather than 0-d tensors
         pieces = torch.split_with_sizes(emb[keep], run_lengths.tolist())
         emb_tokens.append(
             torch.stack([piece.mean(dim=0) for piece in pieces], dim=0))
     return emb_tokens
Пример #22
0
    def unique_consecutive(self,
                           return_inverse=False,
                           return_counts=False,
                           dim=None):
        r"""Eliminates all but the first element from every consecutive group of equivalent elements.

        Tensors without ``__torch_function__`` overrides call
        :func:`torch.unique_consecutive` directly. Overriding types are
        dispatched through :func:`torch.overrides.handle_torch_function`.

        See :func:`torch.unique_consecutive`
        """
        options = {
            "return_inverse": return_inverse,
            "return_counts": return_counts,
            "dim": dim,
        }
        if not has_torch_function_unary(self):
            return torch.unique_consecutive(self, **options)
        return handle_torch_function(Tensor.unique_consecutive, (self, ),
                                     self, **options)
Пример #23
0
    def forward_pain(self, input_dict):
        """Encode frames, pool each segment with ``self.to_pain[1]``, and score pain.

        Args:
            input_dict: Mapping with at least ``'img_crop'`` (input to
                ``self.encoder``; dim 0 is the batch) and ``'segment_key'``
                (one key per batch row; rows of a segment are consecutive).

        Returns:
            dict restricted to the keys in ``self.output_types``, chosen from
            ``'pain'`` (logits from ``self.to_pain[2]``), ``'pain_pred'``
            (softmax over dim 1) and ``'segment_key'`` (passed through).
        """
        img_crop = input_dict['img_crop']
        segment_key = input_dict['segment_key']
        batch_size = img_crop.size(0)

        out_enc_conv = self.encoder(img_crop)
        center_flat = out_enc_conv.view(batch_size, -1)
        latent_3d = self.to_3d(center_flat)

        latent_3d = self.to_pain[0](latent_3d)

        # Split by run lengths so every row lands in exactly one segment, in
        # input order. Masking with `segment_key == value` duplicated rows
        # whenever a key value reappeared in a later run.
        _, run_lengths = torch.unique_consecutive(segment_key, return_counts=True)
        latent_pooled_all = []
        for latent_pooled in torch.split(latent_3d, run_lengths.tolist(), dim=0):
            init_size = latent_pooled.size(0)
            # (L, C) -> (1, C, L) for to_pain[1], then back to (L, C')
            latent_pooled = latent_pooled.unsqueeze(0).transpose(1, 2)
            latent_pooled = self.to_pain[1](latent_pooled)
            latent_pooled = latent_pooled.squeeze(0).transpose(0, 1)

            # to_pain[1] is expected to preserve the sequence length
            assert latent_pooled.size(0) == init_size
            latent_pooled_all.append(latent_pooled)

        latent_pooled_all = torch.cat(latent_pooled_all, dim=0)

        output_pain = self.to_pain[2](latent_pooled_all)
        pain_pred = torch.nn.functional.softmax(output_pain, dim=1)

        ###############################################
        # Select the right output
        output_dict_all = {
            'pain': output_pain,
            'pain_pred': pain_pred,
            'segment_key': segment_key
        }
        return {key: output_dict_all[key] for key in self.output_types}
Пример #24
0
 def reduction_ops(self):
     """Exercise a set of torch reduction ops on small random tensors.

     Returns:
         int: the number of op results produced.
     """
     a = torch.randn(4)
     b = torch.randn(4)
     c = torch.tensor(0.5)
     # len() takes a single sequence, so the results are collected in a
     # tuple; passing them as separate arguments raised TypeError.
     return len((
         torch.argmax(a),
         torch.argmin(a),
         torch.amax(a),
         torch.amin(a),
         torch.aminmax(a),
         torch.all(a),
         torch.any(a),
         torch.max(a),
         a.max(a),
         torch.max(a, 0),
         torch.min(a),
         a.min(a),
         torch.min(a, 0),
         torch.dist(a, b),
         torch.logsumexp(a, 0),
         torch.mean(a),
         torch.mean(a, 0),
         torch.nanmean(a),
         torch.median(a),
         torch.nanmedian(a),
         torch.mode(a),
         torch.norm(a),
         a.norm(2),
         torch.norm(a, dim=0),
         torch.norm(c, torch.tensor(2)),
         torch.nansum(a),
         torch.prod(a),
         torch.quantile(a, torch.tensor([0.25, 0.5, 0.75])),
         torch.quantile(a, 0.5),
         torch.nanquantile(a, torch.tensor([0.25, 0.5, 0.75])),
         torch.std(a),
         torch.std_mean(a),
         torch.sum(a),
         torch.unique(a),
         torch.unique_consecutive(a),
         torch.var(a),
         torch.var_mean(a),
         torch.count_nonzero(a),
     ))
Пример #25
0
 def forward(self, inputs, return_before_head=False):
     """Backbone + head over crops grouped by last-dimension size.

     Consecutive crops of equal last-dimension size are concatenated and
     passed through ``self._forward_backbone``; each backbone output goes
     through ``self._forward_head``. Returns ``z`` (head outputs), or
     ``(h, z)`` including the backbone outputs when ``return_before_head``
     is true.
     """
     crops = inputs if isinstance(inputs, list) else [inputs]
     sizes = torch.tensor([crop.shape[-1] for crop in crops])
     boundaries = torch.cumsum(
         torch.unique_consecutive(sizes, return_counts=True)[1], 0)
     first = 0
     for last in boundaries:
         h_part = self._forward_backbone(torch.cat(crops[first:last]))
         z_part = self._forward_head(h_part)
         if first == 0:
             h, z = h_part, z_part
         else:
             h = torch.cat((h, h_part))
             z = torch.cat((z, z_part))
         first = last
     return (h, z) if return_before_head else z
Пример #26
0
    def __next__(self):
        """Return the next batch, optionally grouped by type.

        Without ``self.mapping_keys`` the raw batch is returned. Otherwise the
        rows are treated as ``(type_id, index)`` pairs and grouped by
        type_id. The result is a dict ``{mapping_keys[type_id]: indices}``,
        with keys in ascending type_id order.
        """
        batch = self._next_indices()
        if self.mapping_keys is None:
            return batch

        order = torch.argsort(batch[:, 0])
        sorted_types = batch[:, 0][order]
        sorted_indices = batch[:, 1][order]
        uniq_types, type_counts = torch.unique_consecutive(sorted_types,
                                                           return_counts=True)
        bounds = [0] + type_counts.cumsum(0).tolist()
        return {
            self.mapping_keys[type_id]: sorted_indices[lo:hi]
            for type_id, lo, hi in zip(uniq_types.tolist(), bounds[:-1],
                                       bounds[1:])
        }
Пример #27
0
def test_on_task_switch_is_called():
    """Apply a DummyMethod to a 5-task stationary CartPole setting.

    Checks that ``fit`` is called exactly once. Also checks that the task
    labels the method observed, with consecutive repeats collapsed, are not
    simply ``[0, 1, ..., nb_tasks - 1]``.
    """
    setting = TraditionalRLSetting(
        dataset="CartPole-v0",
        nb_tasks=5,
        train_max_steps=500,
        test_max_steps=500,
    )
    assert setting.stationary_context

    method = DummyMethod()
    setting.apply(method)
    assert method.n_fit_calls == 1

    import torch
    observed = torch.as_tensor(method.observation_task_labels)
    collapsed = torch.unique_consecutive(observed).tolist()
    assert collapsed != list(range(setting.nb_tasks))
Пример #28
0
 def __call__(self, x: Any, segmentation: bool = False) -> Dict[str, List]:
     """Greedy CTC decoding of a padded batch.

     ``transform_batch`` returns the batch tensor and per-item lengths. Item
     ``i`` is sliced as ``x[:length, i, :]`` and max-reduced over dim 1.

     Returns:
         dict whose ``"hyp"`` entry holds, per item, the label ids after
         collapsing repeats and removing zeros (the CTC blank). When
         ``segmentation`` is true, ``"prob"`` and ``"segmentation"`` are
         added.
     """
     x, xs = transform_batch(x)
     x = x.detach()
     maxima = [x[:length, i, :].max(dim=1) for i, length in enumerate(xs)]
     out = {}
     if segmentation:
         out["prob"] = [m.values.exp() for m in maxima]
         out["segmentation"] = [
             CTCGreedyDecoder.compute_segmentation(m.indices.tolist())
             for m in maxima
         ]
     hyps = []
     for m in maxima:
         # Remove repeated symbols
         collapsed = torch.unique_consecutive(m.indices)
         # Remove CTC blank symbol
         hyps.append(collapsed[torch.nonzero(collapsed, as_tuple=True)].tolist())
     out["hyp"] = hyps
     return out
Пример #29
0
        def get_formula(atomic_numbers: Tensor) -> str:
            """Helper function to get reduced formula.

            With more than 30 atoms, elements are grouped with
            ``Tensor.unique`` (one sorted entry per element). Otherwise
            consecutive runs are counted in the order the atoms were given.
            Atomic number 0 (padding) is skipped, and a count of 1 is left
            implicit.
            """
            if len(atomic_numbers) > 30:
                species, counts = atomic_numbers.unique(return_counts=True)
            else:
                species, counts = torch.unique_consecutive(
                    atomic_numbers, return_counts=True)
            return ''.join(
                chemical_symbols[z] if n == 1 else f'{chemical_symbols[z]}{n}'
                for z, n in zip(species.tolist(), counts.tolist())
                if z != 0)  # <- Ignore zeros (padding)
Пример #30
0
    def call(self, gt_boxes, gt_labels, pred_boxes, iou_matrices=None):
        """Calculate box recall for class-agnostic task.

        For every ground-truth box, the best IoU over its image's predicted
        boxes is taken (0 when the image has no prediction). These values are
        averaged per ground-truth label, and the mean of the per-label
        averages is returned. If ``iou_matrices`` is omitted, it is built
        with ``box_iou`` for each image and the method calls itself again.
        """
        assert len(gt_boxes) == len(pred_boxes)
        if iou_matrices is None:
            iou_matrices = [
                box_iou(gt_i, pred_i)
                for gt_i, pred_i in zip(gt_boxes, pred_boxes)
            ]
            return self.call(gt_boxes,
                             gt_labels,
                             pred_boxes,
                             iou_matrices=iou_matrices)

        per_image = []
        for iou_matrix, labels_i in zip(iou_matrices, gt_labels):
            if iou_matrix.numel() == 0:
                # In case there is no predicted box
                per_image.append(
                    torch.zeros((labels_i.shape[0], )).to(iou_matrix))
            else:
                per_image.append(iou_matrix.max(dim=-1)[0])

        all_overlaps = torch.cat(per_image)
        all_labels = torch.cat(gt_labels).to(torch.int16)
        assert len(all_overlaps) == len(all_labels)

        # Sort so that boxes with the same label are contiguous
        all_labels, order = torch.sort(all_labels)
        all_overlaps = all_overlaps[order]

        # Since tensors are sorted, we can use `unique_consecutive` here
        _, counts = torch.unique_consecutive(all_labels, return_counts=True)
        ends = torch.cumsum(counts, dim=0).tolist()
        starts = [0] + ends[:-1]
        per_label = [all_overlaps[lo:hi].mean() for lo, hi in zip(starts, ends)]
        return sum(per_label) / len(per_label)