def _batch_forward(self, batch, train=True): x, _ = batch # there is problem with y, so discard and construct ourselves. b, way_shot_query, c, h, w = x.size() n_queries = self.hparams.n_queries if train else 1 x_ = x.view(b, self.hparams.n_ways, (self.hparams.n_shots + n_queries), c, h, w).contiguous() # construct y lbls = torch.arange(self.hparams.n_ways).to(x.device).view(1, self.hparams.n_ways, 1).contiguous() y_ = lbls.repeat(b, 1, (self.hparams.n_shots + n_queries)).contiguous() x_support, x_queries = torch.split_with_sizes(x_, split_sizes=[self.hparams.n_shots, n_queries], dim=2) y_support, y_queries = torch.split_with_sizes(y_, split_sizes=[self.hparams.n_shots, n_queries], dim=2) rep_s = self(x_support.contiguous().view(b, self.hparams.n_ways * self.hparams.n_shots, c, h, w)) rep_q = self(x_queries.contiguous().view(b, self.hparams.n_ways * n_queries, c, h, w)) q = rep_q.view(b, self.hparams.n_ways * n_queries, self.proj_dim) # centroid of same way/class s = rep_s.view(b, self.hparams.n_ways, self.hparams.n_shots, self.proj_dim).mean(dim=2) s = s.clone().permute(0, 2, 1).contiguous() cosine_scores = q @ s # batch matrix multiplication logits = cosine_scores.view(-1, self.hparams.n_ways) / 0.1 labels = y_queries.contiguous().view(-1) loss = F.cross_entropy(logits, labels) acc = (logits.argmax(dim=1) == labels).float().mean() return loss, acc
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Gated per-layer embedding network applied independently at each time step.

    At each step ``t`` the vector ``input[:, t]`` is split into chunks of
    ``self.input_sizes``. For each chunk ``l``:

    1. Multiply it by a sigmoid gate computed from the whole step vector with
       ``self.linear_gates[l]``.
    2. Embed the result with ``self.embeddings[l]``.

    The per-chunk embeddings are summed. The sum then goes through
    ``self.ln_embeddings``, ``self.activation``, ``self.fnn`` and ``self.output``.

    Args:
        input: ``[batch, seq_len, sum(self.input_sizes)]``.

    Returns:
        ``[batch, seq_len, self.output_size]``, with the same dtype and device as
        ``input``.
    """
    batch, seq_len, _ = input.size()
    # new_zeros keeps input's dtype/device (torch.zeros defaulted to float32,
    # silently downcasting float64 outputs and mismatching half-precision layers)
    out = input.new_zeros(batch, seq_len, self.output_size)
    for t in range(seq_len):
        h = input[:, t]
        h_layers = torch.split_with_sizes(h, self.input_sizes, dim=1)
        s = h.new_zeros(batch, self.embedding_size)
        for l, hl in enumerate(h_layers):
            g = torch.sigmoid(self.linear_gates[l](h))
            s = s + self.embeddings[l](hl * g)
        s = self.ln_embeddings(s)
        he = self.activation(s)
        fnn = self.fnn(he)
        out[:, t] = self.output(fnn)
    return out
def forward(self, x):
    """Encode variable-length frame sequences with an LSTM and classify them.

    Args:
        x: sequence of per-sample tensors, each ``[n_frames_i, h, w]``. Frame
            counts may differ between samples.

    Returns:
        ``self.classifier`` applied to ``[batch, 2 * hidden]`` features. These
        concatenate the mean and the max of the LSTM outputs over each
        sample's valid (non-padded) frames.
    """
    lens = [len(_x) for _x in x]
    xs = torch.cat(x, dim=0).unsqueeze(1)  # [sum(n_frames), 1, h, w]
    xs = self.features(xs)  # [sum(n_frames), features]
    xs = torch.split_with_sizes(xs, lens, dim=0)  # per sample: [n_frames_i, features]
    xs = torch.nn.utils.rnn.pack_sequence(xs, enforce_sorted=False)
    out, _ = self.lstm(xs)
    out, lengths = torch.nn.utils.rnn.pad_packed_sequence(out)  # out: [seq, batch, hidden]
    # derive the device from the data instead of hard-coding .cuda()
    lengths = lengths.to(out.device)
    valid = (torch.arange(out.size(0), device=out.device)[:, None] < lengths[None, :]).unsqueeze(-1)
    # pad_packed_sequence pads with zeros, so the plain sum covers valid frames only
    x_mean = out.sum(0) / lengths[:, None].to(out.dtype)  # [batch, hidden]
    x_max = out.masked_fill(~valid, float('-inf')).max(0)[0]  # [batch, hidden]
    return self.classifier(torch.cat([x_mean, x_max], dim=-1))
def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
    """Receive runtime results forever and resolve the matching task futures.

    Each payload from ``self.outputs_receiver`` is one of:

    * an exception, which is re-raised;
    * ``(batch_index, batch_outputs)``. The batch's tasks are then popped from
      ``pending_batches``, every output tensor is split along dim 0 by the tasks'
      sizes, and each task's future receives a tuple of its slices.

    A KeyboardInterrupt ends the loop quietly, after a debug log.
    """
    try:
        while True:
            logger.debug(f"{self.uid} waiting for results from runtime")
            payload = self.outputs_receiver.recv()
            if isinstance(payload, BaseException):
                raise payload

            batch_index, batch_outputs = payload
            logger.debug(f"{self.uid}, batch {batch_index}: got results")

            tasks = pending_batches.pop(batch_index)
            sizes = [self.get_task_size(task) for task in tasks]
            # transpose "per output -> per task" into "per task -> per output"
            per_task = zip(*[torch.split_with_sizes(tensor, sizes, dim=0) for tensor in batch_outputs])
            logger.debug(
                f"{self.uid}, batch {batch_index}: sending outputs to handlers"
            )

            for task, outputs in zip(tasks, per_task):
                task.future.set_result(tuple(outputs))
    except KeyboardInterrupt:
        logger.debug(f"Caught KeyboardInterrupt, shutting down")
def decode(self, scores) -> torch.Tensor:
    """Turn label scores into label ids, optionally re-decoding with a CRF.

    Labels start as the argmax over the last dimension of ``scores``. When
    ``self.crf`` is set, the CRF's Viterbi tags are written over the leading
    positions of each token's argmax labels. The write happens in place in the
    returned tensor.

    NOTE(review): ``decoded_labels[i]`` is indexed below as ``[token, position]``.
    ``scores`` therefore appears to be at least 4-D, e.g.
    ``[batch, tokens, positions, num_labels]`` -- confirm against callers.

    Returns:
        Label-id tensor with the shape of ``scores`` minus its last dimension.
    """
    decoded_labels = torch.argmax(scores, dim=-1)
    if self.crf is not None:
        # crf_prepare is defined elsewhere; presumably it rearranges scores/tags
        # for the CRF and returns per-token masks -- TODO confirm.
        crf_scores, crf_tags, token_masks = crf_prepare(
            scores, decoded_labels)
        # positions whose prepared tag is 0 are masked out of Viterbi decoding
        crf_masks = torch.ne(crf_tags, 0).bool()
        crf_decoded_labels = self.crf.viterbi_tags(logits=crf_scores,
                                                   mask=crf_masks)
        # iterating decoded_labels yields views, so writes into `labels`
        # update the returned decoded_labels
        for labels, crf_labels, token_mask in zip(decoded_labels,
                                                  crf_decoded_labels,
                                                  token_masks):
            # run-length encode each mask: idxs are the run values, vals the run lengths
            idxs_vals = [
                torch.unique_consecutive(mask, return_counts=True)
                for mask in token_mask
            ]
            idxs = torch.cat([idx for idx, _ in idxs_vals])
            vals = torch.cat([val for _, val in idxs_vals])
            # assuming boolean masks, vals[idxs] keeps the lengths of the True runs.
            # crf_labels[0] is presumably the decoded tag sequence of a
            # (tags, score) pair -- confirm against the CRF implementation.
            decoded_token_tags = torch.split_with_sizes(
                torch.tensor(crf_labels[0]), tuple(vals[idxs]))
            # TODO: this doesn't do the right thing if a token is decoded as all pads (0)
            # In such a case the first mask is all False and the above split doesn't indicate that this token
            # should be skipped
            for idx, token_tags in enumerate(decoded_token_tags):
                labels[idx, :len(token_tags)] = token_tags
    return decoded_labels
def parse_dynamic_params(params, channels, weight_nums, bias_nums):
    """Split flat per-instance parameters into grouped 1x1 conv weights and biases.

    Args:
        params: 2-D tensor ``(num_insts, sum(weight_nums) + sum(bias_nums))``.
            Its columns hold every layer's weights first, then every layer's
            biases.
        channels: output channels per instance for every layer except the last.
            The last layer has a single output channel.
        weight_nums, bias_nums: per-layer parameter counts, of equal length.

    Returns:
        ``(weight_splits, bias_splits)``, lists with one entry per layer.

        * Weights are ``(num_insts * out, -1, 1, 1)``.
        * Biases are ``(num_insts * out,)``.

        Here ``out`` is ``channels`` for hidden layers and 1 for the last layer.
    """
    assert params.dim() == 2
    assert len(weight_nums) == len(bias_nums)
    assert params.size(1) == sum(weight_nums) + sum(bias_nums)

    n_inst = params.size(0)
    depth = len(weight_nums)
    chunks = torch.split_with_sizes(params, weight_nums + bias_nums, dim=1)

    weights, biases = [], []
    for layer, (w, b) in enumerate(zip(chunks[:depth], chunks[depth:])):
        out_per_inst = channels if layer < depth - 1 else 1
        # out_channels x in_channels x 1 x 1, instances stacked on the first axis
        weights.append(w.reshape(n_inst * out_per_inst, -1, 1, 1))
        biases.append(b.reshape(n_inst * out_per_inst))
    return weights, biases
def forward(self, x):
    """Encode variable-length frame sequences with ``self.lstm`` and classify them.

    Args:
        x: sequence of per-sample tensors, each ``[n_frames_i, h, w]``. Frame
            counts may differ between samples.

    Returns:
        ``self.classifier`` applied to ``[batch, 2 * features]``. This is the
        concatenation of the mean and the max of ``self.lstm``'s outputs, both
        taken over each sample's valid (non-padded) frames.
    """
    lens = [len(_x) for _x in x]
    xs = torch.cat(x, dim=0).unsqueeze(1)
    xs = self.features(xs)  # [sum(n_frames), features]
    xs = torch.split_with_sizes(xs, lens, dim=0)
    xs = torch.nn.utils.rnn.pack_sequence(xs, enforce_sorted=False)
    x, lengths = torch.nn.utils.rnn.pad_packed_sequence(xs)  # [seq, batch, features]
    # derive the device from the data instead of hard-coding .cuda()
    lengths = lengths.to(x.device)
    valid = (torch.arange(x.size(0), device=x.device)[:, None] < lengths[None, :]).unsqueeze(-1)
    # TODO: self.lstm is not given a padding mask; if it mixes information
    # across time, padded steps can still influence valid ones.
    x = self.lstm(x)  # [seq, batch, features]
    # self.lstm's output at padded steps is not guaranteed to be zero, so mask
    # it out before averaging (previously it leaked into the mean)
    x_mean = x.masked_fill(~valid, 0.0).sum(0) / lengths[:, None].to(x.dtype)
    x_max = x.masked_fill(~valid, float('-inf')).max(0)[0]
    return self.classifier(torch.cat([x_mean, x_max], dim=-1))
def groupby_apply(
    keys: torch.Tensor, values: torch.Tensor, bins: int = 95, reduction: str = "mean", return_histogram: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Groupby apply for torch tensors

    Args:
        keys: 1-D integer tensor of groups (``0`` to ``bins - 1``); does not need to be sorted
        values: values to aggregate - same size as keys
        bins: total number of groups
        reduction: either "mean" or "sum"
        return_histogram: if to return histogram on top

    Returns:
        tensor of size ``bins`` with aggregated values (0 for groups without values)
        and optionally with counts of values

    Raises:
        ValueError: for an unknown ``reduction``.
    """
    if reduction == "mean":
        reduce = torch.mean
    elif reduction == "sum":
        reduce = torch.sum
    else:
        raise ValueError(f"Unknown reduction '{reduction}'")

    reduced = torch.zeros(bins, dtype=values.dtype, device=values.device)
    hist = torch.zeros(bins, dtype=torch.long, device=values.device)
    if keys.numel() > 0:  # torch.stack([]) would fail on empty input
        # split_with_sizes needs each group contiguous and in the same ascending
        # order as the unique keys, so order values by key first. The stable
        # sort keeps the within-group order, and hence the float summation order.
        sorted_keys, order = torch.sort(keys, stable=True)
        uniques, counts = torch.unique_consecutive(sorted_keys, return_counts=True)
        groups = torch.stack([reduce(item) for item in torch.split_with_sizes(values[order], counts.tolist())])
        # torch.sum promotes integer inputs to int64; scatter requires matching dtypes
        reduced = reduced.scatter(dim=0, index=uniques, src=groups.to(values.dtype))
        hist = hist.scatter(dim=0, index=uniques, src=counts)
    if return_histogram:
        return reduced, hist
    else:
        return reduced
def as_tuple(self) -> Tuple[torch.Tensor, ...]:
    """Return the non-aggregated edge features, split into one tensor per graph.

    The rows of ``self._batch.edge_features`` are split along dim 0 by
    ``self._batch.num_edges_by_graph``, in batch order. This is a convenience
    method and is better than building a tuple from the iterator:
    ``tuple(batch.edge_features_by_graph)``.
    """
    # the annotation is a variable-length tuple; Tuple[torch.Tensor] would mean exactly one element
    return torch.split_with_sizes(self._batch.edge_features,
                                  self._batch.num_edges_by_graph.tolist(),
                                  dim=0)
def crossentropy_minimize(self, u_logits, u_images, l_images, l_labels, u_labels=None):
    """Cross-entropy step on labelled data plus guessed-label unlabelled data.

    Steps:

    1. Guess labels for the unlabelled batch with ``self.guess_label`` and store
       them on ``self.guessed_label``.
    2. Mix the guesses (scaled by ``self.params.nu``) with the one-hot labelled
       targets using ``self.augment``.
    3. Score the augmented images with ``self.net``.

    The loss is a soft-target cross-entropy:
    ``l_loss + u_loss * FLAGS.ce_factor``. ``u_labels`` is accepted for
    interface compatibility but is unused.

    Returns:
        Scalar loss tensor.
    """
    num_classes = self.params.num_classes
    batch_size = self.params.batch_size

    def soft_cross_entropy(targets, logits):
        # Computes -sum_k t_k * log_softmax(logits)_k, averaged over the batch.
        # This is what tf.losses.softmax_cross_entropy computed with its default
        # reduction. The original called it on torch tensors, which cannot
        # work; a torch op keeps the loss differentiable through self.net.
        return -(targets * torch.log_softmax(logits, dim=-1)).sum(dim=-1).mean()

    guessed_label = self.guess_label(u_logits)
    self.guessed_label = guessed_label
    guessed_label = torch.reshape(guessed_label.detach(), shape=(-1, num_classes))
    l_labels = torch.reshape(onehot(l_labels, num_classes), shape=(-1, num_classes))
    augment_images, augment_labels = self.augment(
        l_images, u_images, l_labels, guessed_label * self.params.nu,
        self.params.beta)
    logit = self.net(augment_images)

    n_labelled = l_images.shape[0]
    zbs = batch_size * 2
    halfzbs = batch_size
    # logit_norm is applied separately to the labelled part and to each half
    # of the unlabelled part
    logit = [
        logit_norm(lgt)
        for lgt in torch.split_with_sizes(logit, [n_labelled, halfzbs, halfzbs])
    ]
    u_logit = torch.cat(logit[1:], dim=0)
    l_augment_labels, u_augment_labels = torch.split_with_sizes(
        augment_labels, [n_labelled, zbs])

    u_loss = soft_cross_entropy(u_augment_labels, u_logit)
    l_loss = soft_cross_entropy(l_augment_labels, logit[0])
    return l_loss + u_loss * FLAGS.ce_factor
def forward(self, z, x, g):
    """Run the concatenated ``(z, x, g)`` through the MLP and split the head output.

    Returns:
        ``(tanh(actions), states, None)``. ``actions`` holds the first
        ``self.ac_dim`` output features and ``states`` the next
        ``self.state_dim``.
    """
    hidden = torch.cat((z, x, g), dim=-1)
    for layer in self.model:
        hidden = layer(hidden)
    head = self.out_layer(hidden)
    actions, states = torch.split_with_sizes(head, [self.ac_dim, self.state_dim], dim=-1)
    return torch.tanh(actions), states, None
def restore_from_parts(
        chunks: Sequence[torch.Tensor],
        shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
    """Inverse of ``split_into_chunks``: rebuild the original tensors from flat chunks.

    The chunks are concatenated into one flat tensor. That tensor is cut into
    pieces of ``shape.numel()`` elements, and each piece is reshaped to its
    entry in ``shapes``.
    """
    flat = torch.cat(tuple(chunks))
    numels = tuple(torch.Size.numel(shape) for shape in shapes)
    pieces = torch.split_with_sizes(flat, numels)
    return tuple(torch.Tensor.reshape(piece, shape) for piece, shape in zip(pieces, shapes))
def _batches(self) -> Iterator[List[int]]: total_samples = len(self.dataset) batches = torch.split_with_sizes(torch.arange(total_samples), self._get_lengths(total_samples)) sort_keys = torch.randperm(len(batches)).tolist() # here we ensure that the shortest batch is last yield from sorted([batches[i].tolist() for i in sort_keys], key=len, reverse=True)
def forward(self, x, train=True, mode='meta'):
    """Few-shot episode loss and accuracy, scoring queries against class prototypes.

    Args:
        x: ``[b, n_ways * (n_shots + n_queries), c, h, w]``. It is viewed as
            ``[b, n_ways, n_shots + n_queries, ...]``, with support samples
            before query samples.
        train: when False, one query per class is assumed.
        mode: how scores become logits.
            ``'meta'``: scores divided by a temperature of 0.1.
            ``'margin'``: a margin of 1.0 is added to every non-target score.

    Returns:
        ``(loss, acc)``. ``acc`` is measured on the raw scores, so it does not
        depend on ``mode``.

    Raises:
        ValueError: for an unknown ``mode`` (a subclass of the previously
            raised ``Exception``).
    """
    b, _, c, h, w = x.size()
    n_queries = self.n_queries if train else 1
    x_ = x.view(b, self.n_ways, self.n_shots + n_queries, c, h, w).contiguous()
    # construct y: class id of every sample in the episode layout
    lbls = torch.arange(self.n_ways).to(x.device).view(1, self.n_ways, 1)
    y_ = lbls.repeat(b, 1, self.n_shots + n_queries)
    x_support, x_queries = torch.split_with_sizes(
        x_, split_sizes=[self.n_shots, n_queries], dim=2)
    _, y_queries = torch.split_with_sizes(
        y_, split_sizes=[self.n_shots, n_queries], dim=2)
    rep_s = self.x_forward(x_support.contiguous().view(
        b, self.n_ways * self.n_shots, c, h, w))
    rep_q = self.x_forward(x_queries.contiguous().view(
        b, self.n_ways * n_queries, c, h, w))
    q = rep_q.view(b, self.n_ways * n_queries, self.proj_dim)
    # centroid of same way/class, arranged as [b, proj_dim, n_ways]
    s = rep_s.view(b, self.n_ways, self.n_shots, self.proj_dim).mean(dim=2)
    s = s.permute(0, 2, 1).contiguous()
    # dot products (cosine similarities if x_forward L2-normalizes -- not visible here)
    scores = (q @ s).view(-1, self.n_ways)  # [b * n_ways * n_queries, n_ways]
    labels = y_queries.contiguous().view(-1)
    if mode == 'meta':
        logits = scores / 0.1  # scale with temperature=0.1
    elif mode == 'margin':
        margin = 1.0
        # The mask must match the 2-D scores that `labels` indexes; scattering
        # into the 3-D score tensor always raised before. Target column -> 0,
        # every other class -> margin.
        masked_margin = margin * torch.ones_like(scores).scatter_(
            dim=1, index=labels.unsqueeze(dim=1), value=0.)
        logits = scores + masked_margin
    else:
        raise ValueError('score mode {} not available'.format(mode))
    loss = F.cross_entropy(logits, labels)
    # raw scores: the training margin should not bias the metric
    # (identical to the old result in 'meta' mode, where scaling keeps the argmax)
    acc = (scores.argmax(dim=1) == labels).float().mean()
    return loss, acc
def splitby(data_tensor: torch.Tensor, group_indices: torch.Tensor, split_dim=0) -> List[torch.Tensor]:
    """Split ``data_tensor`` along ``split_dim`` into one chunk per distinct group index.

    Args:
        data_tensor: tensor to split.
        group_indices: group id of every slice of ``data_tensor`` along
            ``split_dim``. The ids do not need to be sorted.
        split_dim: dimension to split along.

    Returns:
        List of chunks in ascending group-id order. When ``group_indices`` is
        already sorted the chunks are views of ``data_tensor``; otherwise they
        are copies taken after reordering.
    """
    # https://twitter.com/jeremyphoward/status/1185062637341593600
    flat = group_indices.reshape(-1)
    if flat.numel() > 1 and bool((flat[1:] < flat[:-1]).any()):
        # unique() reports counts in ascending group order, so the data must be
        # grouped the same way before splitting (stable keeps in-group order)
        order = torch.sort(flat, stable=True).indices
        data_tensor = data_tensor.index_select(split_dim, order.to(data_tensor.device))
    _, counts = torch.unique(group_indices, return_counts=True)
    return list(torch.split_with_sizes(data_tensor, counts.tolist(), dim=split_dim))
def aggTensorBy(tensor, by, fun):
    """Group-by aggregation for a PyTorch tensor.

    :param tensor: tensor whose leading-dimension slices are aggregated
    :param by: 1-D tensor of group indexes, which MUST be sorted so that each
        group is a contiguous run aligned with ``tensor``
    :param fun: aggregation applied to each group's chunk; its results are stacked
    :return: tuple ``(unique indexes, stacked aggregated values)``
    """
    unique_idx, counts = torch.unique(by, return_counts=True)
    chunks = torch.split_with_sizes(tensor, tuple(counts))
    aggregated = torch.stack(list(map(fun, chunks)))
    return unique_idx, aggregated
def groupby(data_tensor: torch.Tensor, doc_inds: torch.Tensor, split_dim=0) -> List[torch.Tensor]:
    """Group slices of ``data_tensor`` along ``split_dim`` by document index.

    Args:
        data_tensor: tensor to split.
        doc_inds: document index of every slice along ``split_dim``. The
            indexes do not need to be sorted.
        split_dim: dimension to split along.

    Returns:
        One tensor per distinct document index, in ascending index order.
    """
    # https://twitter.com/jeremyphoward/status/1185062637341593600
    # The old `[None] * max(idxs)` list was one slot short, so the largest index
    # raised IndexError. After dropping the Nones it only ever reproduced the
    # ascending order of torch.unique, so the chunks are returned directly.
    flat = doc_inds.reshape(-1)
    if flat.numel() > 1 and bool((flat[1:] < flat[:-1]).any()):
        # split_with_sizes needs each document's slices contiguous and ordered
        # like unique()'s ascending output
        order = torch.sort(flat, stable=True).indices
        data_tensor = data_tensor.index_select(split_dim, order.to(data_tensor.device))
    _, counts = torch.unique(doc_inds, return_counts=True)
    return list(torch.split_with_sizes(data_tensor, counts.tolist(), dim=split_dim))
def plot_h(h: torch.Tensor, layer_sizes: List[int], data: List = None) -> None:
    """Plot, for each sample, the L2 norm of every layer's slice of ``h``.

    ``h`` is split along dim 2 into chunks of ``layer_sizes``. Each chunk is
    reduced to its norm over dim 2, giving ``[N, dim1, num_layers]``, and the
    layer axis is reversed. Every ``[dim1, num_layers]`` slice is passed
    transposed to ``_plot_h`` together with ``data``. Works for CPU and CUDA
    tensors.
    """
    h_layers = torch.split_with_sizes(h.detach(), layer_sizes, dim=2)
    norms = torch.stack([torch.norm(hl, dim=2) for hl in h_layers], dim=2)
    # .cpu(): Tensor.numpy() raises for tensors that live on a GPU
    norms = np.flip(norms.cpu().numpy(), axis=2)
    for sample in norms:
        _plot_h(sample.T, data)
def parse_dynamic_params(params, channels, weight_nums, bias_nums, inds, concat=False):
    """Split dynamic conv parameters per layer, then per instance group.

    Args:
        params: 2-D tensor ``(num_insts, sum(weight_nums) + sum(bias_nums))``.
            Its columns hold every layer's weights first, then every layer's
            biases.
        channels: per-instance output channels of every layer except the last.
            The last layer is reshaped with one output per instance.
        weight_nums, bias_nums: per-layer parameter counts, of equal length.
            They are joined with ``+``, so they are expected to be lists.
        inds: re-iterable collection of instance selectors, one per group. Each
            is applied as ``tensor[ind]`` on the instance dimension, and the
            result is unpacked as still having that dimension. Presumably a
            boolean mask or index tensor -- TODO confirm.
        concat: when True, hidden-layer parameters of every group except the
            first (``idx == 0``) are kept per instance, as
            ``(n, channels, in, 1, 1)`` weights and ``(n, channels)`` biases,
            instead of being flattened.

    Returns:
        ``(multi_weight_splits, multi_bias_splits)``: one list per group, with
        one entry per layer.

        * Flattened hidden layers: ``(n * channels, in, 1, 1)`` weights and
          ``(n * channels,)`` biases.
        * Last layer: ``(n, in, 1, 1)`` weights and ``(n,)`` biases.
        * A group that selects no instances gets ``[]`` for that layer.
    """
    assert params.dim() == 2
    assert len(weight_nums) == len(bias_nums)
    assert params.size(1) == sum(weight_nums) + sum(bias_nums)

    num_insts = params.size(0)
    num_layers = len(weight_nums)

    params_splits = list(torch.split_with_sizes(
        params, weight_nums + bias_nums, dim=1
    ))

    weight_splits = params_splits[:num_layers]
    bias_splits = params_splits[num_layers:]
    # one list of per-layer parameters for every instance group
    multi_weight_splits = [[] for _ in inds]
    multi_bias_splits = [[] for _ in inds]
    for l in range(num_layers):
        if l < num_layers - 1:
            # out_channels x in_channels x 1 x 1, kept per instance until grouped
            weight_splits[l] = weight_splits[l].reshape(num_insts, channels, -1, 1, 1)
            bias_splits[l] = bias_splits[l].reshape(num_insts, channels)
            for idx, ind in enumerate(inds):
                weight_splits_per_ind = weight_splits[l][ind]
                bias_splits_per_ind = bias_splits[l][ind]
                n, c, _, _, _ = weight_splits_per_ind.shape
                if n > 0:
                    if concat and idx:
                        # non-first groups stay unflattened when concat is requested
                        multi_weight_splits[idx].append(weight_splits_per_ind)
                        multi_bias_splits[idx].append(bias_splits_per_ind)
                    else:
                        # stack the group's instances along the output-channel axis
                        multi_weight_splits[idx].append(weight_splits_per_ind.reshape(n * c, -1, 1, 1))
                        multi_bias_splits[idx].append(bias_splits_per_ind.reshape(n * c))
                else:
                    # empty group: placeholder keeps per-layer positions aligned
                    multi_weight_splits[idx].append([])
                    multi_bias_splits[idx].append([])
        else:
            # out_channels x in_channels x 1 x 1 (one output channel per instance)
            weight_splits[l] = weight_splits[l].reshape(num_insts, -1, 1, 1)
            bias_splits[l] = bias_splits[l].reshape(num_insts)
            for idx, ind in enumerate(inds):
                weight_splits_per_ind = weight_splits[l][ind]
                bias_splits_per_ind = bias_splits[l][ind]
                n, _, _, _ = weight_splits_per_ind.shape
                if n > 0:
                    multi_weight_splits[idx].append(weight_splits_per_ind)
                    multi_bias_splits[idx].append(bias_splits_per_ind)
                else:
                    multi_weight_splits[idx].append([])
                    multi_bias_splits[idx].append([])

    return multi_weight_splits, multi_bias_splits
def parse_dynamic_params(params, channels, weight_nums, bias_nums):
    """Split flat per-instance parameters into grouped 1x1 conv weights and biases.

    Args:
        params: ``(num_insts, sum(weight_nums) + sum(bias_nums))``, with all
            weights first and then all biases.
        channels: per-instance output channels of the hidden layers. The last
            layer has one.
        weight_nums, bias_nums: per-layer parameter counts.

    Returns:
        ``(weight_splits, bias_splits)`` lists with one tensor per layer.
        Weights have shape ``(rows, -1, 1, 1)`` and biases ``(rows,)``, where
        ``rows = num_insts * channels`` for hidden layers and ``num_insts``
        for the last layer.
    """
    assert params.dim() == 2
    assert len(weight_nums) == len(bias_nums)
    assert params.size(1) == sum(weight_nums) + sum(bias_nums)

    num_insts = params.size(0)
    num_layers = len(weight_nums)

    # e.g. params (10, 169) with weight_nums [80, 64, 8] and bias_nums [8, 8, 1]
    # splits into (10, 80), (10, 64), (10, 8), (10, 8), (10, 8), (10, 1)
    pieces = torch.split_with_sizes(params, weight_nums + bias_nums, dim=1)

    def rows(layer):
        # hidden layers have `channels` filters per instance, the last layer one
        return num_insts * (channels if layer < num_layers - 1 else 1)

    # with channels=8 the example gives weights [80, 10, 1, 1], [80, 8, 1, 1],
    # [10, 8, 1, 1] and biases [80], [80], [10]
    weight_splits = [pieces[l].contiguous().view(rows(l), -1, 1, 1)
                     for l in range(num_layers)]
    bias_splits = [pieces[num_layers + l].contiguous().view(rows(l))
                   for l in range(num_layers)]
    return weight_splits, bias_splits
def plot_zh(z: torch.Tensor, h: torch.Tensor, layer_sizes: List[int], data: List = None) -> None:
    """Plot ``z`` interleaved with per-layer norms of ``h``, one figure per sample.

    ``z`` is unpacked as ``[N, S, L]``. ``h`` is split along dim 2 into chunks
    of ``layer_sizes``, each chunk is reduced to its norm, and the layer axis is
    reversed. For each sample, the columns of ``z[n]`` and ``h[n]`` are
    interleaved (``z_0, h_0, z_1, h_1, ...``) into an ``[S, 2 * L]`` array.
    That array is passed transposed to ``_plot_zh`` with ``data``.

    NOTE(review): this assumes ``len(layer_sizes) == L`` and that ``h`` has
    ``S`` entries along dim 1 -- confirm.
    """
    # .cpu(): Tensor.numpy() raises for tensors that live on a GPU
    z = z.detach().cpu().numpy()
    h_layers = torch.split_with_sizes(h.detach(), layer_sizes, dim=2)
    h = torch.stack([torch.norm(hl, dim=2) for hl in h_layers], dim=2)
    h = np.flip(h.cpu().numpy(), axis=2)
    _, S, L = z.shape
    for n in range(h.shape[0]):
        zh = np.dstack((z[n], h[n])).reshape((S, 2 * L))
        _plot_zh(zh.T, data)
def _parse_params( pred_params, in_channels, channels, num_classes, num_weight_params, num_bias_params, ): assert pred_params.dim() == 2 assert len(num_weight_params) == len(num_bias_params) assert pred_params.size( 1) == sum(num_weight_params) + sum(num_bias_params) num_instances = pred_params.size(0) num_layers = len(num_weight_params) params_splits = list( torch.split_with_sizes(pred_params, num_weight_params + num_bias_params, dim=1)) weight_splits = params_splits[:num_layers] bias_splits = params_splits[num_layers:] for l in range(num_layers): if l == 0: # input layer weight_splits[l] = weight_splits[l].reshape( num_instances, channels, in_channels) bias_splits[l] = bias_splits[l].reshape( num_instances, channels, 1) elif l < num_layers - 1: # intermediate layer weight_splits[l] = weight_splits[l].reshape( num_instances, channels, channels) bias_splits[l] = bias_splits[l].reshape( num_instances, channels, 1) else: # output layer weight_splits[l] = weight_splits[l].reshape( num_instances, num_classes, channels) bias_splits[l] = bias_splits[l].reshape( num_instances, num_classes, 1) return weight_splits, bias_splits
def get_subnetworks_params(attns, num_bases, channels, num_outputs=17):
    """Split per-instance parameters into a three-layer 1x1 conv head.

    Args:
        attns: 2-D tensor ``(n_inst, P)``, where
            ``P = (2 + num_bases) * channels + channels + channels ** 2
            + channels + channels * num_outputs + num_outputs``.
        num_bases: number of basis inputs. The first layer sees
            ``2 + num_bases`` input channels.
        channels: width of the first two layers.
        num_outputs: output channels of the last layer. This was hard-coded
            to 17; the default keeps that behavior.

    Returns:
        ``([w0, w1, w2], [b0, b1, b2])``, with weights shaped
        ``(n_inst * out, in, 1, 1)`` and biases ``(n_inst * out,)``.
    """
    assert attns.dim() == 2
    n_inst = attns.size(0)
    in_channels = 2 + num_bases

    w0, b0, w1, b1, w2, b2 = torch.split_with_sizes(
        attns,
        [in_channels * channels, channels,
         channels * channels, channels,
         channels * num_outputs, num_outputs],
        dim=1)

    # out_channels x in_channels x 1 x 1, instances stacked on the first axis
    w0 = w0.reshape(n_inst * channels, in_channels, 1, 1)
    b0 = b0.reshape(n_inst * channels)
    w1 = w1.reshape(n_inst * channels, channels, 1, 1)
    b1 = b1.reshape(n_inst * channels)
    w2 = w2.reshape(n_inst * num_outputs, channels, 1, 1)
    b2 = b2.reshape(n_inst * num_outputs)

    return [w0, w1, w2], [b0, b1, b2]
def forward(self, token_seq):
    """Average BERT sub-word embeddings into one vector per token.

    ``token_seq[..., 1]`` holds the sub-word ids fed to ``self.bert``. Entries
    equal to the tokenizer's pad id are masked out. ``token_seq[..., 0]`` holds
    the token id of each sub-word, and consecutive equal ids form one token.

    Returns:
        A list with one ``[num_tokens_i, hidden]`` tensor per sequence. Each
        row is the mean of that token's non-pad sub-word embeddings.
    """
    subword_ids = token_seq[:, :, 1]
    mask = torch.ne(subword_ids, self.bert_tokenizer.pad_token_id)
    hidden = self.bert(subword_ids, attention_mask=mask).last_hidden_state

    emb_tokens = []
    for seq, states, keep in zip(token_seq, hidden, mask):
        # run lengths of consecutive identical token ids = sub-words per token
        _, run_lengths = torch.unique_consecutive(seq[:, 0][keep], return_counts=True)
        pieces = torch.split_with_sizes(states[keep], tuple(run_lengths))
        emb_tokens.append(torch.stack([piece.mean(dim=0) for piece in pieces], dim=0))
    return emb_tokens
def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
    """Receive runtime results forever and resolve the matching task futures.

    Each payload from ``self.outputs_receiver`` is one of:

    * an exception, which is re-raised;
    * ``(batch_index, batch_outputs)``. The batch's tasks are then popped from
      ``pending_batches``, every output array is split along dim 0 by the
      tasks' sizes, and each task's future receives a tuple of its slices.
    """
    while True:
        payload = self.outputs_receiver.recv()
        if isinstance(payload, BaseException):
            raise payload

        batch_index, batch_outputs = payload
        tasks = pending_batches.pop(batch_index)
        sizes = [self.get_task_size(task) for task in tasks]
        # transpose "per output -> per task" into "per task -> per output"
        per_task = zip(*[torch.split_with_sizes(array, sizes, dim=0) for array in batch_outputs])

        for task, outputs in zip(tasks, per_task):
            task.future.set_result(tuple(outputs))
def forward(self, x):
    """Encode variable-length frame sequences and classify their pooled features.

    Args:
        x: sequence of per-sample tensors, each ``[n_frames_i, h, w]``.

    Returns:
        ``self.classifier`` applied to the concatenation of the mean and the max
        of ``self.lstm``'s output over its last (sequence) axis.

    NOTE(review): the pooling runs over the padded length, so shorter samples
    include zero-padded steps in both the mean and the max. Confirm this is
    intended.
    """
    lens = [len(frames) for frames in x]
    stacked = torch.cat(x, dim=0).unsqueeze(1)
    feats = self.features(stacked)  # [sum(n_frames), features]
    per_sample = torch.split_with_sizes(feats, lens, dim=0)
    packed = torch.nn.utils.rnn.pack_sequence(per_sample, enforce_sorted=False)
    padded, _ = torch.nn.utils.rnn.pad_packed_sequence(packed)  # [seq, batch, features]
    encoded = self.lstm(padded.permute(1, 2, 0))  # [batch, features, seq]
    pooled = torch.cat([encoded.mean(-1), encoded.max(-1)[0]], dim=-1)
    return self.classifier(pooled)
def forward(self, xtoken_seq, char_seq, special_symbols, num_tokens,
            max_form_len, max_num_labels, target_chars=None):
    """Run the parent model, then classify every predicted morpheme segment.

    The parent ``forward`` yields character scores and per-character states.
    Each segment end is either a ``<sep>`` before the first ``</s>`` in a row,
    or that first ``</s>`` itself. Every segment end selects one state. These
    states go through ``self.encoder`` and ``self.seg_dropout``, and then
    through every classifier in ``self.classifiers``.

    Args:
        special_symbols: mapping that must contain ``'</s>'`` and ``'<sep>'``.
        num_tokens: number of leading rows of the character matrix to use.
        max_num_labels: length that each token's segment axis is padded to.
        target_chars: gold characters. When given, they define the segments
            instead of the model's own decoding.
        Other arguments are forwarded unchanged to the parent ``forward``.

    Returns:
        ``(morph_scores, morph_states, label_scores)``. ``label_scores`` holds
        one tensor per classifier, padded along dim 1 to ``max_num_labels``.
    """
    morph_scores, morph_states, _ = super().forward(
        xtoken_seq, char_seq, special_symbols, num_tokens, max_form_len,
        max_num_labels, target_chars)
    if target_chars is not None:
        morph_chars = target_chars
    else:
        morph_chars, _ = self.decode(morph_scores, [])
        morph_chars = morph_chars.squeeze(0)
    eos, sep = special_symbols['</s>'], special_symbols['<sep>']
    eos_mask = torch.eq(morph_chars[:num_tokens], eos)
    # force an end marker in the last column so every row has at least one
    eos_mask[:, -1] = True
    # keep only the first end marker of each row
    eos_mask = torch.bitwise_and(
        torch.eq(torch.cumsum(eos_mask, dim=1), 1), eos_mask)
    sep_mask = torch.eq(morph_chars[:num_tokens], sep)
    # keep only separators that occur before the row's first end marker
    sep_mask = torch.bitwise_and(
        torch.eq(torch.cumsum(eos_mask, dim=1), 0), sep_mask)
    # one True per segment end: every kept separator plus the first end marker
    seg_state_mask = torch.bitwise_or(eos_mask, sep_mask)
    # NOTE(review): assumes morph_states' leading dims match the mask
    # ([num_tokens, chars]); the result stacks segment states row by row
    seg_states = morph_states[seg_state_mask]
    # unsqueeze(dim=1) gives [num_segments, 1, hidden]; if self.encoder is
    # sequence-first, all segments form one sequence of batch size 1 -- confirm
    enc_seg_scores, _ = self.encoder(seg_states.unsqueeze(dim=1))
    enc_seg_scores = self.seg_dropout(enc_seg_scores)
    label_scores = []
    # number of segments per token, used to regroup the flat segment outputs
    seg_sizes = torch.sum(seg_state_mask, dim=1)
    for classifier in self.classifiers:
        scores = classifier(enc_seg_scores)
        scores = torch.split_with_sizes(scores.squeeze(dim=1), tuple(seg_sizes))
        # [num_tokens, max segments in batch, labels]
        scores = nn.utils.rnn.pad_sequence(scores, batch_first=True)
        fill_len = max_num_labels - scores.shape[1]
        # pad the segment axis (dim 1) up to max_num_labels
        label_scores.append(F.pad(scores, (0, 0, 0, fill_len)))
    return morph_scores, morph_states, label_scores
def update(self, output):
    """Accumulate per-graph human-object-interaction (HOI) sets for predictions and targets.

    ``output`` is ``(relations, targets)``. Both are batched graphs exposing:

    * ``n_edges``;
    * ``object_classes``;
    * ``relation_indexes`` (row 0 = subject node, row 1 = object node);
    * ``predicate_classes``.

    ``relations`` also has ``relation_scores``. Only edges whose subject class
    is 0 are kept (presumably the "person" class -- confirm). For each graph:

    * ``self.pred`` receives ``{(predicate, object_class): best_score}``;
    * ``self.gt`` receives ``{(predicate, object_class): True}``.
    """
    relations = output[0]
    targets = output[1]

    def split_edges(graphs, *extra):
        # per-graph (subject classes, predicates, object classes, *extra) chunks
        sizes = graphs.n_edges.tolist()
        return zip(
            torch.split_with_sizes(
                graphs.object_classes[graphs.relation_indexes[0]], sizes),
            torch.split_with_sizes(graphs.predicate_classes, sizes),
            torch.split_with_sizes(
                graphs.object_classes[graphs.relation_indexes[1]], sizes),
            *(torch.split_with_sizes(t, sizes) for t in extra),
        )

    for subjs, preds, objs, rel_scores in split_edges(relations, relations.relation_scores):
        graph_hois = {}
        for subj, pred, obj, hoi_score in zip(subjs, preds, objs, rel_scores):
            if subj.item() != 0:
                continue
            hoi = (pred.item(), obj.item())
            score = hoi_score.item()
            # -inf default: a -1 sentinel would silently drop HOIs scored <= -1
            if score > graph_hois.get(hoi, float('-inf')):
                graph_hois[hoi] = score
        self.pred.append(graph_hois)

    for subjs, preds, objs in split_edges(targets):
        graph_hois = {}
        for subj, pred, obj in zip(subjs, preds, objs):
            if subj.item() != 0:
                continue
            graph_hois[(pred.item(), obj.item())] = True
        self.gt.append(graph_hois)
def parse_dynamic_params(self, params):
    """Parse per-instance dynamic conv weights and biases.

    Args:
        params (Tensor): per-instance conv weights and biases, shape
            ``(num_insts, sum(self.weight_nums) + sum(self.bias_nums))``.
            All layers' weights come first, then all layers' biases.

    Returns:
        weight_splits (List[Tensor]): per-layer 1x1 conv weights.
            Hidden layers are ``(num_insts * self.channels, in_per_inst, 1, 1)``;
            the last layer is ``(num_insts, in_per_inst, 1, 1)``.
        bias_splits (List[Tensor]): per-layer 1-D conv biases.
            Hidden layers are ``(num_insts * self.channels,)``; the last layer
            is ``(num_insts,)``.
    """
    assert params.dim() == 2
    assert params.shape[1] == sum(self.weight_nums) + sum(self.bias_nums)

    num_insts = params.shape[0]
    params_splits = list(
        torch.split_with_sizes(params, self.weight_nums + self.bias_nums, dim=1))

    weight_splits = params_splits[:self.num_layers]
    bias_splits = params_splits[self.num_layers:]

    for layer_ind in range(self.num_layers):
        # hidden layers have self.channels outputs per instance, the last layer one
        out_per_inst = self.channels if layer_ind < self.num_layers - 1 else 1
        weight_splits[layer_ind] = weight_splits[layer_ind].reshape(
            num_insts * out_per_inst, -1, 1, 1)
        bias_splits[layer_ind] = bias_splits[layer_ind].reshape(
            num_insts * out_per_inst)

    return weight_splits, bias_splits
def parse_dynamic_params(params, channels, weight_nums, bias_nums):
    """Split flat per-instance parameters into grouped 1x1 conv weights and biases.

    Example: ``params`` is ``(n, 169)`` with ``weight_nums = [80, 64, 8]`` and
    ``bias_nums = [8, 8, 1]``. With ``channels = 8`` this gives:

    * weights ``(8n, 10, 1, 1)``, ``(8n, 8, 1, 1)``, ``(n, 8, 1, 1)``;
    * biases ``(8n,)``, ``(8n,)``, ``(n,)``.

    Args:
        params: ``(num_insts, sum(weight_nums) + sum(bias_nums))``, with all
            weights first and then all biases.
        channels: per-instance output channels of every hidden layer. The last
            layer has one.
        weight_nums, bias_nums: per-layer parameter counts, of equal length.

    Returns:
        ``(weight_splits, bias_splits)``, lists with one tensor per layer.
    """
    assert params.dim() == 2
    assert len(weight_nums) == len(bias_nums)
    assert params.size(1) == sum(weight_nums) + sum(bias_nums)

    num_insts, num_layers = params.size(0), len(weight_nums)
    splits = list(torch.split_with_sizes(params, weight_nums + bias_nums, dim=1))
    weight_splits, bias_splits = splits[:num_layers], splits[num_layers:]

    last = num_layers - 1
    for l in range(last):  # hidden layers: out_channels x in_channels x 1 x 1
        weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1)
        bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
    if num_layers:  # output layer: a single channel per instance
        weight_splits[last] = weight_splits[last].reshape(num_insts, -1, 1, 1)
        bias_splits[last] = bias_splits[last].reshape(num_insts)
    return weight_splits, bias_splits