def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
    r"""Eliminates all but the first element from every consecutive group of equivalent elements.

    See :func:`torch.unique_consecutive`
    """
    relevant_args = (self,)
    from torch.overrides import has_torch_function, handle_torch_function
    if type(self) is not Tensor and has_torch_function(relevant_args):
        return handle_torch_function(
            Tensor.unique_consecutive, relevant_args, self,
            return_inverse=return_inverse, return_counts=return_counts, dim=dim
        )
    return torch.unique_consecutive(
        self, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
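# Illustrative usage sketch (not from the source above): the method simply defers to
# torch.unique_consecutive, which collapses runs of equal values.
import torch

t = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
values, inverse, counts = torch.unique_consecutive(
    t, return_inverse=True, return_counts=True)
# values  -> tensor([1, 2, 3, 1, 2])
# inverse -> tensor([0, 0, 1, 1, 2, 3, 3, 4])
# counts  -> tensor([2, 2, 1, 1, 2])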
def forward(self, log_h, y):
    log_h = log_h.flatten()

    durations, events = y.T

    # sort input
    durations, idx = durations.sort(descending=True)
    log_h = log_h[idx]
    events = events[idx]

    event_ind = events.nonzero().flatten()

    # numerator
    log_num = log_h[event_ind].mean()

    # logcumsumexp of events
    event_lcse = torch.logcumsumexp(log_h, dim=0)[event_ind]

    # number of events for each unique risk set
    _, tie_inverses, tie_count = torch.unique_consecutive(
        durations[event_ind], return_counts=True, return_inverse=True)

    # position of last event (lowest duration) of each unique risk set
    tie_pos = tie_count.cumsum(axis=0) - 1

    # logcumsumexp by tie for each event
    event_tie_lcse = event_lcse[tie_pos][tie_inverses]

    if self.method == "breslow":
        log_den = event_tie_lcse.mean()
    elif self.method == "efron":
        # based on https://bydmitry.github.io/efron-tensorflow.html
        # logsumexp of ties, duplicated within tie set
        tie_lse = scatter_logsumexp(log_h[event_ind], tie_inverses, dim=0)[tie_inverses]
        # multiply (add in log space) with corrective factor
        aux = torch.ones_like(tie_inverses)
        aux[tie_pos[:-1] + 1] -= tie_count[:-1]
        event_id_in_tie = torch.cumsum(aux, dim=0) - 1
        discounted_tie_lse = (tie_lse
                              + torch.log(event_id_in_tie)
                              - torch.log(tie_count[tie_inverses]))
        # denominator
        log_den = log_substract(event_tie_lcse, discounted_tie_lse).mean()

    # loss is negative log likelihood
    return log_den - log_num
def forward(self, x_all, segment_key):
    x_all_new = []
    for segment_val in torch.unique_consecutive(segment_key):
        segment_bin = segment_key == segment_val
        x = x_all[segment_bin, :]
        x = x.unsqueeze(0).transpose(1, 2)
        x = self.conv(x)
        x = x.squeeze(0).transpose(0, 1)
        x_all_new.append(x)
        # s = raw_input()
    x_all_new = torch.cat(x_all_new, axis=0)
    x_all_new = self.aft_conv(x_all_new)
    return x_all_new
def _get_to_orthogonalize(matrix: torch.Tensor, filt_len: int) -> torch.Tensor:
    """Find matrix rows with fewer entries than filt_len.

    The returned rows will need to be orthogonalized.

    Args:
        matrix (torch.Tensor): The wavelet matrix under consideration.
        filt_len (int): The number of entries we would expect per row.

    Returns:
        torch.Tensor: The row indices with too few entries.
    """
    unique, count = torch.unique_consecutive(
        matrix.coalesce().indices()[0, :], return_counts=True)
    return unique[count != filt_len]
def forward_pain(self, input_dict):
    input = input_dict['img_crop']
    segment_key = input_dict['segment_key']
    batch_size = input_dict['img_crop'].size()[0]
    device = input.device

    out_enc_conv = self.encoder(input)
    center_flat = out_enc_conv.view(batch_size, -1)
    latent_3d = self.to_3d(center_flat)

    h_n_all = []
    segment_key_new = []
    for segment_val in torch.unique_consecutive(segment_key):
        segment_bin = segment_key == segment_val
        rel_latent = latent_3d[segment_bin, :]
        init_size = rel_latent.size(0)
        rel_latent = self.pad_input(rel_latent)
        out, _ = self.lstm(rel_latent)
        out = out.view(out.size(0) * out.size(1), -1)
        out = out[:init_size, :]
        h_n_all.append(out)
        # segment_key_rel = segment_key[segment_bin][:h_n.size(0)]
        # segment_key_new.append(segment_key_rel)
    h_n = torch.cat(h_n_all, axis=0)
    output_pain = self.to_pain[1](h_n)
    # segment_key_new = torch.cat(segment_key_new, axis=0)
    pain_pred = torch.nn.functional.softmax(output_pain, dim=1)

    ###############################################
    # Select the right output
    output_dict_all = {
        'pain': output_pain,
        'pain_pred': pain_pred,
        'segment_key': segment_key
    }
    output_dict = {}
    # print (self.output_types)
    for key in self.output_types:
        output_dict[key] = output_dict_all[key]

    return output_dict
def invert_indices(qs, n_qs):
    """Creates a :class:`Ragged` from a list of values, by interpreting the list
    as an index-to-value mapping. The ragged is then the mapping from
    value-to-indices-with-that-value."""
    sort = torch.sort(qs)
    # As of Pytorch 1.4, unique_consecutive breaks on empty inputs.
    if len(qs) > 0:
        unique, counts = torch.unique_consecutive(sort.values, return_counts=True)
    else:
        unique, counts = torch.zeros_like(sort.values), torch.zeros_like(sort.values)

    cardinalities = qs.new_zeros(n_qs)
    cardinalities[unique] = counts

    return Ragged(sort.indices, cardinalities)
def forward(self, inputs):
    if not isinstance(inputs, list):
        inputs = [inputs]
    idx_crops = torch.cumsum(torch.unique_consecutive(
        torch.tensor([inp.shape[-1] for inp in inputs]),
        return_counts=True,
    )[1], 0)
    start_idx = 0
    for end_idx in idx_crops:
        _out = self.forward_backbone(
            torch.cat(inputs[start_idx:end_idx]).cuda(non_blocking=True))
        if start_idx == 0:
            output = _out
        else:
            output = torch.cat((output, _out))
        start_idx = end_idx
    return self.forward_head(output)
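# Illustrative sketch of the crop-grouping trick used above (toy resolutions, assumed
# values, not from the source): consecutive crops sharing the same spatial size are
# batched into a single backbone forward pass.
import torch

resolutions = torch.tensor([224, 224, 96, 96, 96, 96])  # e.g. 2 global + 4 local crops
idx_crops = torch.cumsum(
    torch.unique_consecutive(resolutions, return_counts=True)[1], 0)
# idx_crops -> tensor([2, 6]): crops [0:2] go through one pass, crops [2:6] through another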
def from_pairs(pairs, n_ps, n_qs):
    qs, ps = pairs.T
    sort = torch.sort(qs)
    image = ps[sort.indices]
    # As of Pytorch 1.4, unique_consecutive breaks on empty inputs.
    if len(pairs) > 0:
        unique, counts = torch.unique_consecutive(sort.values, return_counts=True)
    else:
        unique, counts = torch.zeros_like(sort.values), torch.zeros_like(sort.values)

    cardinalities = qs.new_zeros(n_qs)
    cardinalities[unique] = counts

    return Ragged(image, cardinalities)
def greedy_assignment(self, scores, k=1):
    token_to_workers = torch.topk(scores, dim=1, k=k, largest=True).indices.view(-1)
    token_to_workers, sort_ordering = torch.sort(token_to_workers)
    worker2token = sort_ordering // k

    # Find how many tokens we're sending to each other worker
    # (being careful for sending 0 tokens to some workers)
    output_splits = torch.zeros((self.num_workers,), dtype=torch.long, device=scores.device)
    workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True)
    output_splits[workers] = counts
    # Tell other workers how many tokens to expect from us
    input_splits = All2All.apply(output_splits)
    return worker2token, input_splits.tolist(), output_splits.tolist()
def piecewise_arange(piecewise_idxer):
    """
    count repeated indices
    Example:
    [0, 0, 0, 3, 3, 3, 3, 1, 1, 2] -> [0, 1, 2, 0, 1, 2, 3, 0, 1, 0]
    """
    dv = piecewise_idxer.device
    # print(piecewise_idxer)
    uni, counts = torch.unique_consecutive(piecewise_idxer, return_counts=True)
    # print(counts)
    maxcnt = torch.max(counts).item()
    numuni = uni.shape[0]
    tmp = torch.zeros(size=(numuni, maxcnt), device=dv).bool()
    ranges = torch.arange(maxcnt, device=dv).unsqueeze(0).expand(numuni, -1)
    tmp[ranges < counts.unsqueeze(-1)] = True
    return ranges[tmp]
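# Quick, illustrative check of piecewise_arange against its docstring example
# (assumes the function above is in scope; not part of the source).
import torch

idxer = torch.tensor([0, 0, 0, 3, 3, 3, 3, 1, 1, 2])
print(piecewise_arange(idxer))
# -> tensor([0, 1, 2, 0, 1, 2, 3, 0, 1, 0])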
def forward(self, logits: torch.Tensor) -> str:
    """Given a sequence logits over labels, get the best path string

    Args:
        logits (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

    Returns:
        str: The resulting transcript
    """
    best_path = torch.argmax(logits, dim=-1)  # [num_seq,]
    best_path = torch.unique_consecutive(best_path, dim=-1)
    hypothesis = []
    for i in best_path:
        if i != self.blank:
            hypothesis.append(self.labels[i])
    return ''.join(hypothesis)
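# Illustrative sketch of the greedy CTC decoding step above with toy labels and a
# blank index of 0 (both assumed, not from the source): repeated argmax symbols are
# merged before the blank is dropped.
import torch

labels = ['-', 'a', 'b', 'c']   # index 0 plays the role of the CTC blank here
blank = 0
logits = torch.tensor([[9., 0., 0., 0.],   # blank
                       [0., 9., 0., 0.],   # 'a'
                       [0., 9., 0., 0.],   # 'a' again (merged)
                       [9., 0., 0., 0.],   # blank
                       [0., 0., 9., 0.]])  # 'b'
best_path = torch.unique_consecutive(torch.argmax(logits, dim=-1), dim=-1)
transcript = ''.join(labels[i] for i in best_path if i != blank)
# transcript -> 'ab'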
def confidence_based_inlier_selection(residuals, ransidx, rdims, idxoffsets, dv, min_confidence):
    numransacs = rdims.shape[0]
    numiters = residuals.shape[0]

    sorted_res, sorting_idxes = stable_sort_residuals(residuals, ransidx)
    sorted_res_sqr = sorted_res**2

    too_perfect_fits = sorted_res_sqr <= 1e-8
    end_rans_indexing = torch.cumsum(rdims, dim=0) - 1

    _, inv_indices, res_dup_counts = torch.unique_consecutive(
        sorted_res_sqr.half().float(), dim=1, return_counts=True, return_inverse=True)

    duplicates_per_sample = res_dup_counts[inv_indices]
    inlier_weights = (1. / duplicates_per_sample).repeat(numiters, 1)
    inlier_weights[too_perfect_fits] = 0.

    balanced_rdims, weights_cumsums = group_sum_and_cumsum(
        inlier_weights, end_rans_indexing, ransidx)
    progressive_inl_rates = weights_cumsums / (
        balanced_rdims.repeat_interleave(rdims, dim=1)).float()

    good_inl_mask = (sorted_res_sqr * min_confidence <= progressive_inl_rates) | too_perfect_fits

    inlier_weights[~good_inl_mask] = 0.
    inlier_counts_matrix, _ = group_sum_and_cumsum(inlier_weights, end_rans_indexing)
    inl_counts, inl_iters = torch.max(inlier_counts_matrix, dim=0)

    relative_inl_idxes = arange_sequence(inl_counts)
    inl_ransidx = torch.arange(numransacs, device=dv).repeat_interleave(inl_counts)
    inl_sampleidx = sorting_idxes[inl_iters.repeat_interleave(inl_counts),
                                  idxoffsets[inl_ransidx] + relative_inl_idxes]
    highest_accepted_sqr_residuals = sorted_res_sqr[inl_iters, idxoffsets + inl_counts - 1]
    expected_extra_inl = balanced_rdims[
        inl_iters, torch.arange(numransacs, device=dv)].float() * highest_accepted_sqr_residuals

    return (inl_ransidx, inl_sampleidx, inl_counts, inl_iters,
            1. - expected_extra_inl / inl_counts.float())
def sample(self, xs, batch_size=None, **kwargs):
    xs = xs.to(self.device)

    # Fix since x is unnecessarily repeated
    unique_xs, counts = torch.unique_consecutive(xs, return_counts=True, dim=0)
    n_samples = counts[0].item()  # Assume all counts the same

    assert n_samples < self.imp_samples, (
        """Trying to draw more or same amount of samples from DCTD model
        as is used for importance sampling.
        Increase amount of importance samples used.""")

    # Batch over x:s
    if batch_size:
        batches = torch.split(unique_xs, batch_size, dim=0)
    else:
        batches = (unique_xs,)

    sample_list = []
    for x_batch in batches:
        proposal_means = self.find_modes(x_batch)
        _, imp_weights, imp_samples = self.estimate_log_z(
            x_batch, proposal_means, return_weights=True)
        # imp_samples shape: (n_batch, n_imp_samples, dim_y)

        # Sample index from categorical distribution parametrized
        # by importance weights
        sample_indexes = torch.multinomial(imp_weights, n_samples, replacement=True)
        # shape: (n_batch, n_samples)

        indexes = sample_indexes.unsqueeze(2).repeat(1, 1, self.y_dim)
        sample_batch = torch.gather(imp_samples, dim=1, index=indexes)
        # Shape (n_batch, n_samples, dim_y)

        sample_batch_flat = torch.flatten(sample_batch, start_dim=0, end_dim=1)
        # Shape (n_batch*n_samples, dim_y)
        sample_list.append(sample_batch_flat)

    samples = torch.cat(sample_list, dim=0)
    return samples
def from_pairs(pairs, n_ps, n_qs):
    """Creates a :class:`Ragged` from a list of index-value pairs. A single index
    can occur many times."""
    qs, ps = pairs.T
    sort = torch.sort(qs)
    image = ps[sort.indices]
    # As of Pytorch 1.4, unique_consecutive breaks on empty inputs.
    if len(pairs) > 0:
        unique, counts = torch.unique_consecutive(sort.values, return_counts=True)
    else:
        unique, counts = torch.zeros_like(sort.values), torch.zeros_like(sort.values)

    cardinalities = qs.new_zeros(n_qs)
    cardinalities[unique] = counts

    return Ragged(image, cardinalities)
def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Input: indices for pairs of atoms that are close to each other.
    Each pair only appears once, i.e. only one of the pairs (1, 2) and (2, 1) exists.

    Output: indices for all central atoms and their pairs of neighbors. For example,
    if the input has pairs (0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
    (2, 3), (2, 4), (3, 4), then the output would have central atoms 0, 1, 2, 3, 4,
    and for central atom 0, its pairs of neighbors are (1, 2), (1, 3), (1, 4),
    (2, 3), (2, 4), (3, 4).
    """
    # convert representation from pair to central-others
    ai1 = atom_index12.view(-1)
    sorted_ai1, rev_indices = ai1.sort()

    # sort and compute unique key
    unique_results = torch.unique_consecutive(
        sorted_ai1, return_inverse=True, return_counts=True)
    uniqued_central_atom_index = unique_results[0]
    counts = unique_results[-1]

    # compute central_atom_index
    pair_sizes = counts * (counts - 1) // 2
    pair_indices = torch.repeat_interleave(pair_sizes)
    central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)

    # do local combinations within unique key, assuming sorted
    m = counts.max().item() if counts.numel() > 0 else 0
    n = pair_sizes.shape[0]
    intra_pair_indices = torch.tril_indices(
        m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1)
    mask = (torch.arange(intra_pair_indices.shape[2], device=ai1.device)
            < pair_sizes.unsqueeze(1)).flatten()
    sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask]
    sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices)

    # unsort result from last part
    local_index12 = rev_indices[sorted_local_index12]

    # compute mapping between representation of central-other to pair
    n = atom_index12.shape[1]
    sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1
    return central_atom_index, local_index12 % n, sign12
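# Illustrative sketch of the counting step above with a toy pair list (assumed values,
# not from the source): flattening and sorting the pair index lets unique_consecutive
# report how many neighbors each central atom has, from which the number of neighbor
# pairs per atom follows.
import torch

atom_index12 = torch.tensor([[0, 0, 0, 1],
                             [1, 2, 3, 2]])   # pairs (0, 1), (0, 2), (0, 3), (1, 2)
sorted_ai1, _ = atom_index12.view(-1).sort()  # tensor([0, 0, 0, 1, 1, 2, 2, 3])
central, counts = torch.unique_consecutive(sorted_ai1, return_counts=True)
# central -> tensor([0, 1, 2, 3]), counts -> tensor([3, 2, 2, 1])
pair_sizes = counts * (counts - 1) // 2       # tensor([3, 1, 1, 0]) neighbor pairs per atom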
def check_repeat(features, indices, features_add=None, sort_first=True, flip_first=True):
    """
    Check whether there are duplicate indices in the sparse features;
    merge the duplicated features if any.
    """
    if sort_first:
        features, indices, features_add = sort_by_indices(features, indices, features_add)

    if flip_first:
        features, indices = features.flip([0]), indices.flip([0])
        if features_add is not None:
            features_add = features_add.flip([0])

    idx = indices[:, 1:].int()
    idx_sum = torch.add(
        torch.add(idx.select(1, 0) * idx[:, 1].max() * idx[:, 2].max(),
                  idx.select(1, 1) * idx[:, 2].max()),
        idx.select(1, 2))
    _unique, inverse, counts = torch.unique_consecutive(
        idx_sum, return_inverse=True, return_counts=True, dim=0)

    if _unique.shape[0] < indices.shape[0]:
        perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
        features_new = torch.zeros((_unique.shape[0], features.shape[-1]),
                                   device=features.device)
        features_new.index_add_(0, inverse.long(), features)
        features = features_new
        perm_ = inverse.new_empty(_unique.size(0)).scatter_(0, inverse, perm)
        indices = indices[perm_].int()

        if features_add is not None:
            features_add_new = torch.zeros((_unique.shape[0],),
                                           device=features_add.device)
            features_add_new.index_add_(0, inverse.long(), features_add)
            features_add = features_add_new / counts

    return features, indices, features_add
def triple_by_molecule(atom_index1, atom_index2):
    """Input: indices for pairs of atoms that are close to each other.
    Each pair only appears once, i.e. only one of the pairs (1, 2) and (2, 1) exists.

    Output: indices for all central atoms and their pairs of neighbors. For example,
    if the input has pairs (0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
    (2, 3), (2, 4), (3, 4), then the output would have central atoms 0, 1, 2, 3, 4,
    and for central atom 0, its pairs of neighbors are (1, 2), (1, 3), (1, 4),
    (2, 3), (2, 4), (3, 4).
    """
    # convert representation from pair to central-others
    n = atom_index1.shape[0]
    ai1 = torch.cat([atom_index1, atom_index2])
    sorted_ai1, rev_indices = ai1.sort()

    # sort and compute unique key
    uniqued_central_atom_index, counts = torch.unique_consecutive(
        sorted_ai1, return_counts=True)

    # do local combinations within unique key, assuming sorted
    pair_sizes = counts * (counts - 1) // 2
    total_size = pair_sizes.sum()
    pair_indices = torch.repeat_interleave(pair_sizes)
    central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)
    cumsum = cumsum_from_zero(pair_sizes)
    cumsum = cumsum.index_select(0, pair_indices)
    sorted_local_pair_index = torch.arange(total_size, device=cumsum.device) - cumsum
    sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index)
    cumsum = cumsum_from_zero(counts)
    cumsum = cumsum.index_select(0, pair_indices)
    sorted_local_index1 += cumsum
    sorted_local_index2 += cumsum

    # unsort result from last part
    local_index1 = rev_indices[sorted_local_index1]
    local_index2 = rev_indices[sorted_local_index2]

    # compute mapping between representation of central-other to pair
    sign1 = ((local_index1 < n) * 2).to(torch.long) - 1
    sign2 = ((local_index2 < n) * 2).to(torch.long) - 1
    return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2
def _project2D_edges_init(self, rot_mat, edge_index, edge_distance_vec):
    torch.set_printoptions(sci_mode=False)
    length = len(edge_distance_vec)
    device = edge_distance_vec.device

    # Assuming the edges are consecutive based on the target index
    target_node_index, neigh_count = torch.unique_consecutive(
        edge_index[1], return_counts=True)
    max_neighbors = torch.max(neigh_count)

    target_neigh_count = torch.zeros(self.num_atoms, device=device).long()
    target_neigh_count.index_copy_(0, target_node_index.long(), neigh_count)

    index_offset = torch.cumsum(target_neigh_count, dim=0) - target_neigh_count
    neigh_index = torch.arange(length, device=device)
    neigh_index = neigh_index - index_offset[edge_index[1]]

    edge_map_index = edge_index[1] * max_neighbors + neigh_index
    target_lookup = (
        torch.zeros(self.num_atoms * max_neighbors, device=device) - 1).long()
    target_lookup.index_copy_(
        0,
        edge_map_index.long(),
        torch.arange(length, device=device).long(),
    )
    target_lookup = target_lookup.view(self.num_atoms, max_neighbors)

    # target_lookup - For each target node, a list of edge indices
    # target_neigh_count - number of neighbors for each target node
    source_edge = target_lookup[edge_index[0]]
    target_edge = (
        torch.arange(length, device=device).long().view(-1, 1).repeat(1, max_neighbors))

    source_edge = source_edge.view(-1)
    target_edge = target_edge.view(-1)

    mask_unused = source_edge.ge(0)
    source_edge = torch.masked_select(source_edge, mask_unused)
    target_edge = torch.masked_select(target_edge, mask_unused)

    return self._project2D_init(source_edge, target_edge, rot_mat, edge_distance_vec)
def forward(self, x):
    # convert to list
    if not isinstance(x, list):
        x = [x]
    idx_crops = torch.cumsum(torch.unique_consecutive(
        torch.tensor([inp.shape[-1] for inp in x]),
        return_counts=True,
    )[1], 0)
    start_idx = 0
    for end_idx in idx_crops:
        _out = self.backbone(torch.cat(x[start_idx:end_idx]))
        if start_idx == 0:
            output = _out
        else:
            output = torch.cat((output, _out))
        start_idx = end_idx
    # Run the head forward on the concatenated features.
    return self.head(output)
def forward(self, x: Union[List[Tensor], Tensor]) -> Tensor:
    if isinstance(x, Tensor):
        return self.head(self.backbone(x))
    idx_crops = torch.cumsum(
        torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1],
        0,
    )
    start_idx = 0
    for end_idx in idx_crops:
        _out = self.backbone(torch.cat(x[start_idx:end_idx]))
        output: Tensor = _out if start_idx == 0 else torch.cat((output, _out))  # type: ignore
        start_idx = end_idx
    # Run the head forward on the concatenated features.
    return self.head(output)  # type: ignore
def forward(self, token_seq):
    mask = torch.ne(token_seq[:, :, 1], self.bert_tokenizer.pad_token_id)
    bert_output = self.bert(token_seq[:, :, 1], attention_mask=mask)
    bert_emb_tokens = bert_output.last_hidden_state
    emb_tokens = []
    for i in range(len(token_seq)):
        # # groupby token_id
        # mask = torch.ne(input_xtokens[i, :, 1], 0)
        idxs, vals = torch.unique_consecutive(token_seq[i, :, 0][mask[i]],
                                              return_counts=True)
        token_emb_xtoken_split = torch.split_with_sizes(
            bert_emb_tokens[i][mask[i]], tuple(vals))
        # token_xcontext = {k.item(): v for k, v in zip(idxs, [torch.mean(t, dim=0) for t in token_emb_xtokens])}
        emb_tokens.append(
            torch.stack([torch.mean(t, dim=0) for t in token_emb_xtoken_split], dim=0))
    return emb_tokens
def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
    r"""Eliminates all but the first element from every consecutive group of equivalent elements.

    See :func:`torch.unique_consecutive`
    """
    if has_torch_function_unary(self):
        return handle_torch_function(
            Tensor.unique_consecutive, (self,), self,
            return_inverse=return_inverse, return_counts=return_counts, dim=dim)
    return torch.unique_consecutive(
        self, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
def forward_pain(self, input_dict):
    input = input_dict['img_crop']
    segment_key = input_dict['segment_key']
    batch_size = input_dict['img_crop'].size()[0]
    device = input.device

    out_enc_conv = self.encoder(input)
    center_flat = out_enc_conv.view(batch_size, -1)
    latent_3d = self.to_3d(center_flat)
    latent_3d = self.to_pain[0](latent_3d)

    latent_pooled_all = []
    for segment_val in torch.unique_consecutive(segment_key):
        segment_bin = segment_key == segment_val
        latent_pooled = latent_3d[segment_bin, :]
        init_size = latent_pooled.size(0)
        latent_pooled = latent_pooled.unsqueeze(0).transpose(1, 2)
        latent_pooled = self.to_pain[1](latent_pooled)
        latent_pooled = latent_pooled.squeeze(0).transpose(0, 1)
        assert latent_pooled.size(0) == init_size
        latent_pooled_all.append(latent_pooled)
    latent_pooled_all = torch.cat(latent_pooled_all, axis=0)
    output_pain = self.to_pain[2](latent_pooled_all)
    pain_pred = torch.nn.functional.softmax(output_pain, dim=1)

    ###############################################
    # Select the right output
    output_dict_all = {
        'pain': output_pain,
        'pain_pred': pain_pred,
        'segment_key': segment_key
    }
    output_dict = {}
    # print (self.output_types)
    for key in self.output_types:
        output_dict[key] = output_dict_all[key]

    return output_dict
def reduction_ops(self):
    a = torch.randn(4)
    b = torch.randn(4)
    c = torch.tensor(0.5)
    # Wrap the results in a tuple so len() receives a single argument.
    return len((
        torch.argmax(a),
        torch.argmin(a),
        torch.amax(a),
        torch.amin(a),
        torch.aminmax(a),
        torch.all(a),
        torch.any(a),
        torch.max(a),
        a.max(a),
        torch.max(a, 0),
        torch.min(a),
        a.min(a),
        torch.min(a, 0),
        torch.dist(a, b),
        torch.logsumexp(a, 0),
        torch.mean(a),
        torch.mean(a, 0),
        torch.nanmean(a),
        torch.median(a),
        torch.nanmedian(a),
        torch.mode(a),
        torch.norm(a),
        a.norm(2),
        torch.norm(a, dim=0),
        torch.norm(c, torch.tensor(2)),
        torch.nansum(a),
        torch.prod(a),
        torch.quantile(a, torch.tensor([0.25, 0.5, 0.75])),
        torch.quantile(a, 0.5),
        torch.nanquantile(a, torch.tensor([0.25, 0.5, 0.75])),
        torch.std(a),
        torch.std_mean(a),
        torch.sum(a),
        torch.unique(a),
        torch.unique_consecutive(a),
        torch.var(a),
        torch.var_mean(a),
        torch.count_nonzero(a),
    ))
def forward(self, inputs, return_before_head=False):
    if not isinstance(inputs, list):
        inputs = [inputs]
    idx_crops = torch.cumsum(torch.unique_consecutive(
        torch.tensor([inp.shape[-1] for inp in inputs]),
        return_counts=True,
    )[1], 0)
    start_idx = 0
    for end_idx in idx_crops:
        _h = self._forward_backbone(torch.cat(inputs[start_idx:end_idx]))
        _z = self._forward_head(_h)
        if start_idx == 0:
            h, z = _h, _z
        else:
            h, z = torch.cat((h, _h)), torch.cat((z, _z))
        start_idx = end_idx

    if return_before_head:
        return h, z
    return z
def __next__(self):
    batch = self._next_indices()
    if self.mapping_keys is None:
        return batch

    # convert the type-ID pairs to dictionary
    type_ids = batch[:, 0]
    indices = batch[:, 1]
    type_ids_sortidx = torch.argsort(type_ids)
    type_ids = type_ids[type_ids_sortidx]
    indices = indices[type_ids_sortidx]
    type_id_uniq, type_id_count = torch.unique_consecutive(type_ids, return_counts=True)
    type_id_uniq = type_id_uniq.tolist()
    type_id_offset = type_id_count.cumsum(0).tolist()
    type_id_offset.insert(0, 0)
    id_dict = {
        self.mapping_keys[type_id_uniq[i]]: indices[type_id_offset[i]:type_id_offset[i + 1]]
        for i in range(len(type_id_uniq))}
    return id_dict
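# Illustrative sketch of the type-ID bucketing above with a toy batch (assumed values,
# not from the source): after sorting by type, unique_consecutive counts give the
# offsets used to slice out each type's indices. Within-group order depends on the sort.
import torch

type_ids = torch.tensor([1, 0, 1, 0, 0])
indices = torch.tensor([10, 11, 12, 13, 14])
order = torch.argsort(type_ids)
type_ids, indices = type_ids[order], indices[order]
uniq, count = torch.unique_consecutive(type_ids, return_counts=True)
offsets = [0] + count.cumsum(0).tolist()      # [0, 3, 5]
buckets = {int(t): indices[offsets[i]:offsets[i + 1]] for i, t in enumerate(uniq)}
# buckets groups indices by type, e.g. {0: tensor([11, 13, 14]), 1: tensor([10, 12])}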
def test_on_task_switch_is_called():
    setting = TraditionalRLSetting(
        dataset="CartPole-v0",
        nb_tasks=5,
        # steps_per_task=100,
        train_max_steps=500,
        test_max_steps=500,
    )
    assert setting.stationary_context
    method = DummyMethod()
    _ = setting.apply(method)
    # assert setting.task_labels_at_test_time
    # assert False, method.observation_task_labels
    assert method.n_fit_calls == 1

    import numpy as np
    import torch

    assert torch.unique_consecutive(
        torch.as_tensor(method.observation_task_labels)).tolist() != list(
            range(setting.nb_tasks))
def __call__(self, x: Any, segmentation: bool = False) -> Dict[str, List]:
    x, xs = transform_batch(x)
    x = x.detach()
    x = [x[:xs[i], i, :] for i in range(len(xs))]
    x = [x_n.max(dim=1) for x_n in x]
    out = {}
    if segmentation:
        out["prob"] = [x_n.values.exp() for x_n in x]
        out["segmentation"] = [
            CTCGreedyDecoder.compute_segmentation(x_n.indices.tolist()) for x_n in x
        ]
    x = [x_n.indices for x_n in x]
    # Remove repeated symbols
    x = [torch.unique_consecutive(x_n) for x_n in x]
    # Remove CTC blank symbol
    x = [x_n[torch.nonzero(x_n, as_tuple=True)] for x_n in x]
    out["hyp"] = [x_n.tolist() for x_n in x]
    return out
def get_formula(atomic_numbers: Tensor) -> str:
    """Helper function to get reduced formula."""
    # If n atoms > 30; then use the reduced formula
    if len(atomic_numbers) > 30:
        return ''.join([
            f'{chemical_symbols[z]}{n}' if n != 1 else f'{chemical_symbols[z]}'
            for z, n in zip(*atomic_numbers.unique(return_counts=True))
            if z != 0])  # <- Ignore zeros (padding)
    # Otherwise list the elements in the order they were specified
    else:
        return ''.join([
            f'{chemical_symbols[int(z)]}{int(n)}' if n != 1 else f'{chemical_symbols[z]}'
            for z, n in zip(*torch.unique_consecutive(atomic_numbers, return_counts=True))
            if z != 0])
def call(self, gt_boxes, gt_labels, pred_boxes, iou_matrices=None):
    """Calculate box recall for class-agnostic task."""
    assert len(gt_boxes) == len(pred_boxes)
    if iou_matrices is None:
        iou_matrices = [
            box_iou(gt_boxes_i, pred_boxes_i)
            for gt_boxes_i, pred_boxes_i in zip(gt_boxes, pred_boxes)
        ]
        return self.call(gt_boxes, gt_labels, pred_boxes, iou_matrices=iou_matrices)

    best_overlaps = []
    for iou_matrix, gt_labels_i in zip(iou_matrices, gt_labels):
        # In case there is no predicted box
        if iou_matrix.numel() == 0:
            best_overlap = torch.zeros((gt_labels_i.shape[0],)).to(iou_matrix)
        else:
            best_overlap, _ = iou_matrix.max(dim=-1)
        best_overlaps.append(best_overlap)
    best_overlaps = torch.cat(best_overlaps)
    gt_labels = torch.cat(gt_labels).to(torch.int16)
    assert len(best_overlaps) == len(gt_labels)

    # Sort
    gt_labels, sort_idxs = torch.sort(gt_labels)
    best_overlaps = best_overlaps[sort_idxs]
    # Since tensors are sorted, we can use `unique_consecutive` here
    _, counts = torch.unique_consecutive(gt_labels, return_counts=True)

    start = 0
    mabo = []
    for c in counts:
        end = start + c
        abo = best_overlaps[start:end].mean()
        mabo.append(abo)
        start = end

    return sum(mabo) / len(mabo)
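# Illustrative sketch of the per-class averaging above with toy values (assumed, not
# from the source): once labels are sorted, consecutive counts slice the overlaps into
# one chunk per class, and each chunk's mean is that class's average best overlap.
import torch

labels = torch.tensor([0, 0, 1, 1, 1, 2])      # already sorted
overlaps = torch.tensor([0.5, 0.7, 0.2, 0.4, 0.6, 0.9])
_, counts = torch.unique_consecutive(labels, return_counts=True)
per_class = [chunk.mean() for chunk in torch.split(overlaps, counts.tolist())]
# per_class -> [0.6, 0.4, 0.9]; their mean is the MABO reported by `call`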