Exemple #1
0
    def _batch_forward(self, batch, train=True):
        """Score one batch of few-shot episodes and return ``(loss, acc)``.

        ``batch`` is an ``(x, y)`` pair.  The supplied labels are ignored and
        rebuilt from the episode layout: ``x`` must be 5-D and is viewed as
        ``(b, n_ways, n_shots + n_queries, c, h, w)``.  When ``train`` is
        false a single query per class is used instead of
        ``hparams.n_queries``.
        """
        images, _ = batch  # provided labels are unreliable; rebuilt below
        bsz, _, chans, height, width = images.size()

        ways = self.hparams.n_ways
        shots = self.hparams.n_shots
        queries = self.hparams.n_queries if train else 1
        per_class = shots + queries

        episodes = images.view(bsz, ways, per_class, chans, height, width).contiguous()

        # Class id of every sample, laid out exactly like ``episodes``.
        class_ids = torch.arange(ways).to(images.device).view(1, ways, 1).contiguous()
        targets = class_ids.repeat(bsz, 1, per_class).contiguous()

        part_sizes = [shots, queries]
        support_x, query_x = torch.split_with_sizes(episodes, split_sizes=part_sizes, dim=2)
        _, query_y = torch.split_with_sizes(targets, split_sizes=part_sizes, dim=2)

        support_rep = self(support_x.contiguous().view(bsz, ways * shots, chans, height, width))
        query_rep = self(query_x.contiguous().view(bsz, ways * queries, chans, height, width))

        query_rep = query_rep.view(bsz, ways * queries, self.proj_dim)
        # One prototype per class: the mean of its support embeddings,
        # transposed so that queries can be multiplied against it.
        prototypes = support_rep.view(bsz, ways, shots, self.proj_dim).mean(dim=2)
        prototypes = prototypes.clone().permute(0, 2, 1).contiguous()

        scores = torch.matmul(query_rep, prototypes)  # batched matrix product
        logits = scores.view(-1, ways) / 0.1
        labels = query_y.contiguous().view(-1)

        loss = F.cross_entropy(logits, labels)
        acc = (logits.argmax(dim=1) == labels).float().mean()
        return loss, acc
Exemple #2
0
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the gated per-chunk embedding network on every time step.

        ``input`` is unpacked as ``(batch, seq_len, features)``; at each step
        the feature axis is split according to ``self.input_sizes``.  Returns
        a tensor of shape ``(batch, seq_len, self.output_size)`` allocated on
        ``input``'s device.
        """
        n_batch, n_steps, _ = input.size()
        result = torch.zeros(n_batch, n_steps, self.output_size, device=input.device)

        for step in range(n_steps):
            frame = input[:, step]
            chunks = torch.split_with_sizes(frame, self.input_sizes, dim=1)

            # Accumulate the embeddings of all gated chunks.
            mixed = torch.zeros(n_batch, self.embedding_size, device=frame.device)
            for idx, chunk in enumerate(chunks):
                # The gate of each chunk is computed from the whole frame.
                gate = torch.sigmoid(self.linear_gates[idx](frame))
                mixed = mixed + self.embeddings[idx](chunk * gate)

            hidden = self.activation(self.ln_embeddings(mixed))
            result[:, step] = self.output(self.fnn(hidden))

        return result
Exemple #3
0
    def forward(self, x):
        """Classify a batch of variable-length frame sequences.

        Args:
            x: sequence of per-sample tensors.  They are concatenated along
                dim 0 and given a channel axis, so each is presumably
                ``[n_frames, h, w]`` -- TODO confirm against the caller.

        Returns:
            ``self.classifier`` applied to the concatenation of the
            length-normalised sum and the max over time of the LSTM outputs.
        """
        lens = [len(sample) for sample in x]
        frames = torch.cat(x, dim=0).unsqueeze(1)  # [sum(n_frames), 1, h, w]
        feats = self.features(frames)  # [sum(n_frames), features]
        per_sample = torch.split_with_sizes(feats, lens, dim=0)

        packed = torch.nn.utils.rnn.pack_sequence(per_sample, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)

        # [seq, batch, features]; positions past a sample's length are zero.
        out, out_lens = torch.nn.utils.rnn.pad_packed_sequence(packed_out)

        # Follow the data's device instead of hard-coding CUDA, so the module
        # also runs on CPU or on a non-default accelerator.
        out_lens = out_lens.to(out.device)
        steps = torch.arange(out.size(0), device=out.device)
        valid = steps[:, None, None] < out_lens[None, :, None]  # [seq, batch, 1]

        # Padding is zero, so the plain sum divided by the length is the mean.
        x_sum = out.sum(0) / out_lens[:, None]  # [batch, features]

        # Padded steps must never win the max; fill them out of place so the
        # tensor used for the sum is not modified.
        x_max = out.masked_fill(~valid, float('-inf')).max(0)[0]  # [batch, features]

        pooled = torch.cat([x_sum, x_max], dim=-1)
        return self.classifier(pooled)
Exemple #4
0
    def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
        """Forever receive batch results from the runtime and resolve task futures.

        Every message is either an exception instance, which is re-raised
        here, or a ``(batch_index, batch_outputs)`` pair.  The tasks stored
        under ``batch_index`` are removed from ``pending_batches`` and each
        receives its slice of every output tensor.  A ``KeyboardInterrupt``
        ends the loop quietly.
        """
        try:
            while True:
                logger.debug(f"{self.uid} waiting for results from runtime")
                message = self.outputs_receiver.recv()
                if isinstance(message, BaseException):
                    raise message
                batch_index, batch_outputs = message
                logger.debug(f"{self.uid}, batch {batch_index}: got results")

                # Cut every output along dim 0 into per-task pieces, then
                # regroup so that each item holds all outputs of one task.
                tasks = pending_batches.pop(batch_index)
                sizes = [self.get_task_size(task) for task in tasks]
                pieces_per_output = [
                    torch.split_with_sizes(tensor, sizes, dim=0)
                    for tensor in batch_outputs
                ]
                pieces_per_task = zip(*pieces_per_output)
                logger.debug(
                    f"{self.uid}, batch {batch_index}: sending outputs to handlers"
                )

                # Hand each task its outputs through its future.
                for task, pieces in zip(tasks, pieces_per_task):
                    task.future.set_result(tuple(pieces))
        except KeyboardInterrupt:
            logger.debug(f"Caught KeyboardInterrupt, shutting down")
Exemple #5
0
 def decode(self, scores) -> torch.Tensor:
     """Turn ``scores`` into label ids.

     The result starts as the arg-max over the last axis.  When ``self.crf``
     is set, the CRF's Viterbi tags are written over that tensor in place,
     token by token, and the same tensor is returned.
     """
     best = torch.argmax(scores, dim=-1)
     if self.crf is None:
         return best

     crf_scores, crf_tags, token_masks = crf_prepare(scores, best)
     crf_masks = torch.ne(crf_tags, 0).bool()
     viterbi = self.crf.viterbi_tags(logits=crf_scores, mask=crf_masks)
     for row_labels, row_viterbi, row_masks in zip(best, viterbi, token_masks):
         # Run-length encode every token mask: (run value, run length).
         runs = [
             torch.unique_consecutive(m, return_counts=True)
             for m in row_masks
         ]
         run_values = torch.cat([value for value, _ in runs])
         run_lengths = torch.cat([length for _, length in runs])
         # Indexing with the run values keeps the lengths of the selected
         # runs only (presumably the runs of set mask entries -- confirm the
         # dtype returned by crf_prepare); these become the per-token sizes.
         tags_per_token = torch.split_with_sizes(
             torch.tensor(row_viterbi[0]), tuple(run_lengths[run_values]))
         # TODO: this doesn't do the right thing if a token is decoded as all pads (0)
         # In such a case the first mask is all False and the above split doesn't indicate that this token
         # should be skipped
         for token_idx, tags in enumerate(tags_per_token):
             row_labels[token_idx, :len(tags)] = tags
     return best
Exemple #6
0
def parse_dynamic_params(params, channels, weight_nums, bias_nums):
    """Split flat per-instance parameters into per-layer weights and biases.

    ``params`` is 2-D with ``sum(weight_nums) + sum(bias_nums)`` columns: all
    weight chunks first, then all bias chunks.  Every layer but the last is
    reshaped to ``(num_insts * channels, -1, 1, 1)`` with a bias of
    ``(num_insts * channels,)``; the last layer has one output per instance.

    Returns:
        ``(weight_splits, bias_splits)``, two lists with one tensor per layer.
    """
    assert params.dim() == 2
    assert len(weight_nums) == len(bias_nums)
    assert params.size(1) == sum(weight_nums) + sum(bias_nums)

    num_insts = params.size(0)
    num_layers = len(weight_nums)

    chunks = torch.split_with_sizes(params, weight_nums + bias_nums, dim=1)
    layer_pairs = zip(chunks[:num_layers], chunks[num_layers:])

    weight_splits, bias_splits = [], []
    for layer, (weight, bias) in enumerate(layer_pairs):
        # out_channels x in_channels x 1 x 1; the final layer has 1 output.
        out_channels = channels if layer < num_layers - 1 else 1
        weight_splits.append(weight.reshape(num_insts * out_channels, -1, 1, 1))
        bias_splits.append(bias.reshape(num_insts * out_channels))

    return weight_splits, bias_splits
Exemple #7
0
    def forward(self, x):
        """Classify a batch of variable-length frame sequences.

        Args:
            x: sequence of per-sample tensors.  They are concatenated along
                dim 0 and given a channel axis, so each is presumably
                ``[n_frames, h, w]`` -- TODO confirm against the caller.

        Returns:
            ``self.classifier`` applied to the concatenation of the
            length-normalised sum and the max over time of ``self.lstm``'s
            output.
        """
        lens = [len(sample) for sample in x]
        frames = torch.cat(x, dim=0).unsqueeze(1)
        feats = self.features(frames)  # [sum(n_frames), features]
        per_sample = torch.split_with_sizes(feats, lens, dim=0)

        # Packing and immediately unpacking zero-pads the sequences to a
        # common length: [seq, batch, features].
        packed = torch.nn.utils.rnn.pack_sequence(per_sample, enforce_sorted=False)
        x, l = torch.nn.utils.rnn.pad_packed_sequence(packed)

        # Follow the data's device instead of hard-coding CUDA, so the module
        # also runs on CPU or on a non-default accelerator.
        l = l.to(x.device)
        steps = torch.arange(x.size(0), device=x.device)
        # Kept broadcastable over the feature axis so it stays valid even if
        # self.lstm changes the feature width.
        mask = steps[:, None, None] < l[None, :, None]  # [seq, batch, 1]

        # TODO add src_key_padding_mask
        x = self.lstm(x)  # [seq, batch, features]

        # NOTE(review): padded steps are included in this sum; that is only
        # harmless if self.lstm keeps them at zero -- confirm.
        x_sum = x.sum(0) / l[:, None]  # [batch, features]

        # Padded steps must never win the max.
        x_max = x.masked_fill(~mask, float('-inf')).max(0)[0]  # [batch, features]

        pooled = torch.cat([x_sum, x_max], dim=-1)
        return self.classifier(pooled)
Exemple #8
0
def groupby_apply(
    keys: torch.Tensor, values: torch.Tensor, bins: int = 95, reduction: str = "mean", return_histogram: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Groupby apply for torch tensors

    Args:
        keys: 1-D integer tensor of group ids (``0`` to ``bins``); it is used
            as a scatter index.  The keys may appear in any order.
        values: values to aggregate - same size as keys
        bins: total number of groups
        reduction: either "mean" or "sum"
        return_histogram: if to return histogram on top

    Returns:
        tensor of size ``bins`` with aggregated values and optionally with counts of values.
        Groups that do not occur in ``keys`` are ``0`` in both tensors.

    Raises:
        ValueError: if ``reduction`` is neither "mean" nor "sum"
    """
    if reduction == "mean":
        reduce = torch.mean
    elif reduction == "sum":
        reduce = torch.sum
    else:
        raise ValueError(f"Unknown reduction '{reduction}'")

    reduced = torch.zeros(bins, dtype=values.dtype, device=values.device)
    hist = torch.zeros(bins, dtype=torch.long, device=values.device)

    # Nothing to aggregate: torch.stack would fail on an empty list.
    if keys.numel() > 0:
        # Splitting by group size only works when each group's values are
        # contiguous and in the same order as the unique keys, so bring the
        # values into key order first.  The stable sort keeps the original
        # order inside each group.
        sorted_keys, order = torch.sort(keys, stable=True)
        uniques, counts = torch.unique_consecutive(sorted_keys, return_counts=True)
        chunks = torch.split_with_sizes(values[order], counts.tolist())
        groups = torch.stack([reduce(item) for item in chunks])
        reduced = reduced.scatter(dim=0, index=uniques, src=groups)
        hist = hist.scatter(dim=0, index=uniques, src=counts)

    if return_histogram:
        return reduced, hist
    return reduced
Exemple #9
0
    def as_tuple(self) -> Tuple[torch.Tensor]:
        """Return the non-aggregated edge features as one tensor per graph.

        The batch's edge features are cut along dim 0 using the per-graph
        edge counts.  Prefer this over `tuple(batch.edge_features_by_graph)`."""
        batch = self._batch
        edges_per_graph = batch.num_edges_by_graph.tolist()
        return torch.split_with_sizes(batch.edge_features, edges_per_graph, dim=0)
Exemple #10
0
    def crossentropy_minimize(self,
                              u_logits,
                              u_images,
                              l_images,
                              l_labels,
                              u_labels=None):
        """Build the combined labeled + unlabeled cross-entropy loss.

        Labels for the unlabeled images are guessed from ``u_logits`` (and
        stored on ``self.guessed_label``), both sets are passed through
        ``self.augment`` and ``self.net``, and the labeled loss is added to
        the unlabeled loss scaled by ``FLAGS.ce_factor``.  ``u_labels`` is
        accepted but not used.

        NOTE(review): the tensors are handled with ``torch`` while the losses
        come from ``tf`` -- confirm that this mix is intended.
        """
        num_classes = self.params.num_classes
        batch_size = self.params.batch_size

        guessed_label = self.guess_label(u_logits)
        self.guessed_label = guessed_label
        # The guess is a fixed target: no gradient flows through it.
        guessed_label = torch.reshape(guessed_label.detach(),
                                      shape=(-1, num_classes))
        l_labels = torch.reshape(onehot(l_labels, num_classes),
                                 shape=(-1, num_classes))

        augment_images, augment_labels = self.augment(
            l_images, u_images, l_labels, guessed_label * self.params.nu,
            self.params.beta)
        logit = self.net(augment_images)

        # Layout along dim 0: the labeled rows first, then two unlabeled
        # groups of ``batch_size`` rows each.
        num_labeled = l_images.shape[0]
        normed = [
            logit_norm(part)
            for part in torch.split_with_sizes(
                logit, [num_labeled, batch_size, batch_size])
        ]
        l_logit = normed[0]
        u_logit = torch.cat(normed[1:], dim=0)

        l_augment_labels, u_augment_labels = torch.split_with_sizes(
            augment_labels, [num_labeled, batch_size * 2])

        u_loss = tf.losses.softmax_cross_entropy(u_augment_labels, u_logit)
        l_loss = tf.losses.softmax_cross_entropy(l_augment_labels, l_logit)

        return tf.math.add(l_loss,
                           u_loss * FLAGS.ce_factor,
                           name='crossentropy_minimization_loss')
Exemple #11
0
 def forward(self, z, x, g):
     """Map ``(z, x, g)`` to ``(tanh(actions), states, None)``.

     The inputs are concatenated on the last axis and passed through every
     layer of ``self.model`` and then ``self.out_layer``.  The result is
     split on the last axis into ``self.ac_dim`` action values and
     ``self.state_dim`` state values.
     """
     hidden = torch.cat((z, x, g), dim=-1)
     for layer in self.model:
         hidden = layer(hidden)
     head = self.out_layer(hidden)
     actions, states = torch.split_with_sizes(
         head, [self.ac_dim, self.state_dim], dim=-1)
     return torch.tanh(actions), states, None
Exemple #12
0
def restore_from_parts(
        chunks: Sequence[torch.Tensor],
        shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
    """ rebuilds tensors of the given shapes from the flat chunks produced by split_into_chunks """
    flat = torch.cat(tuple(chunks))
    # Number of elements each restored tensor takes from the flat buffer.
    numels = tuple(shape.numel() for shape in shapes)
    pieces = torch.split_with_sizes(flat, numels)
    return tuple(piece.reshape(shape) for piece, shape in zip(pieces, shapes))
Exemple #13
0
 def _batches(self) -> Iterator[List[int]]:
     """Yield batches of consecutive dataset indices in shuffled order.

     Indices ``0 .. len(dataset) - 1`` are cut into runs whose sizes come
     from ``self._get_lengths``.  The runs are shuffled and then ordered by
     decreasing length, so the shortest batch is yielded last.
     """
     n_samples = len(self.dataset)
     index_runs = torch.split_with_sizes(torch.arange(n_samples),
                                         self._get_lengths(n_samples))
     order = torch.randperm(len(index_runs)).tolist()
     shuffled = [index_runs[i].tolist() for i in order]
     # The sort is stable, so equally long batches keep their shuffled order
     # and only the shorter ones move to the end.
     shuffled.sort(key=len, reverse=True)
     yield from shuffled
    def forward(self, x, train=True, mode='meta'):
        """Score one batch of few-shot episodes and return ``(loss, acc)``.

        Args:
            x: 5-D tensor, viewed as
                ``(b, n_ways, n_shots + n_queries, c, h, w)``.
            train: when false a single query per class is used instead of
                ``self.n_queries``.
            mode: ``'meta'`` divides the scores by a temperature of 0.1;
                ``'margin'`` adds a margin of 1.0 to every non-target class.

        Raises:
            ValueError: if ``mode`` is neither ``'meta'`` nor ``'margin'``.
        """
        # Reject an unknown mode before doing any computation.
        if mode not in ('meta', 'margin'):
            raise ValueError('score mode {} not available'.format(mode))

        b, _, c, h, w = x.size()

        n_queries = self.n_queries if train else 1
        per_class = self.n_shots + n_queries
        x_ = x.view(b, self.n_ways, per_class, c, h, w).contiguous()

        # construct y: the class id of every sample, laid out like x_
        lbls = torch.arange(self.n_ways).to(x.device).view(1, self.n_ways,
                                                           1).contiguous()
        y_ = lbls.repeat(b, 1, per_class).contiguous()

        x_support, x_queries = torch.split_with_sizes(
            x_, split_sizes=[self.n_shots, n_queries], dim=2)
        _, y_queries = torch.split_with_sizes(
            y_, split_sizes=[self.n_shots, n_queries], dim=2)

        rep_s = self.x_forward(x_support.contiguous().view(
            b, self.n_ways * self.n_shots, c, h, w))
        rep_q = self.x_forward(x_queries.contiguous().view(
            b, self.n_ways * n_queries, c, h, w))

        q = rep_q.view(b, self.n_ways * n_queries, self.proj_dim)
        # centroid of same way/class, transposed for the matrix product
        s = rep_s.view(b, self.n_ways, self.n_shots, self.proj_dim).mean(dim=2)
        s = s.permute(0, 2, 1).contiguous()

        cosine_scores = q @ s  # batch matrix multiplication
        logits = cosine_scores.view(-1, self.n_ways)
        labels = y_queries.contiguous().view(-1)

        if mode == 'meta':
            logits = logits / 0.1  # scale with temperature=0.1
        else:
            margin = 1.0
            # The mask must have the 2-D shape of ``logits``: scatter_ needs
            # the index and the target to have the same number of dimensions,
            # and ``cosine_scores`` is 3-D.
            masked_margin = margin * torch.ones_like(logits).scatter_(
                dim=1, index=labels.unsqueeze(dim=1), value=0.)
            logits = logits + masked_margin

        loss = F.cross_entropy(logits, labels)
        acc = (logits.argmax(dim=1) == labels).float().mean()
        return loss, acc
Exemple #15
0
    def splitby(data_tensor: torch.Tensor, group_indices: torch.Tensor, split_dim=0) -> List[torch.Tensor]:
        """Cut ``data_tensor`` along ``split_dim`` into one chunk per distinct group index.

        The chunk sizes are the occurrence counts of each distinct value in
        ``group_indices``; chunks are returned in ascending group-index order.
        """
        # https://twitter.com/jeremyphoward/status/1185062637341593600
        group_ids, group_sizes = torch.unique(group_indices, return_counts=True)
        chunks = torch.split_with_sizes(data_tensor, tuple(group_sizes), dim=split_dim)

        ordered = sorted(zip(group_ids, chunks), key=lambda pair: pair[0])
        return [chunk for _, chunk in ordered]
def aggTensorBy(tensor, by, fun):
    """
    Group-by aggregation for a pytorch tensor
    :param tensor: tensor whose dim 0 is cut into groups
    :param by: 1d tensor with sorted (!) group indexes, one per row of tensor
    :param fun: aggregation function applied to each group's slice
    :return: tuple (unique indexes, stacked per-group results)
    """
    groups, sizes = torch.unique(by, return_counts=True)
    aggregated = [fun(chunk) for chunk in torch.split_with_sizes(tensor, tuple(sizes))]
    return groups, torch.stack(aggregated)
Exemple #17
0
    def groupby(data_tensor: torch.Tensor, doc_inds: torch.Tensor, split_dim=0) -> List[torch.Tensor]:
        """Cut ``data_tensor`` along ``split_dim`` into one chunk per distinct document index.

        Args:
            data_tensor: tensor to split.
            doc_inds: document index of every slice along ``split_dim``.  The
                slices of one document must be contiguous and the documents
                in ascending index order, because only the occurrence counts
                are used for splitting.
            split_dim: dimension to split along.

        Returns:
            List of chunks in ascending document-index order.  Indices that
            never occur produce no entry.
        """
        # https://twitter.com/jeremyphoward/status/1185062637341593600
        # torch.unique returns its values in ascending order, so the chunks
        # are already in document order.  No index-addressed placeholder list
        # is needed; the old one was one slot too short and raised IndexError.
        _, counts = torch.unique(doc_inds, return_counts=True)
        split_arrays = torch.split_with_sizes(data_tensor, counts.tolist(), dim=split_dim)
        return list(split_arrays)
Exemple #18
0
def plot_h(h: torch.Tensor, layer_sizes: List[int], data: List = None) -> None:
    """Plot the per-layer norms of ``h``, one figure per entry along dim 0.

    ``h`` is split along dim 2 according to ``layer_sizes`` and each part is
    reduced to its norm over that dimension.  The layer axis is reversed and
    every sample is handed, transposed, to ``_plot_h`` together with ``data``.
    """
    layers = torch.split_with_sizes(h.detach(), layer_sizes, dim=2)
    norms = torch.stack([torch.norm(layer, dim=2) for layer in layers], dim=2)
    flipped = np.flip(norms.numpy(), axis=2)

    for sample in flipped:
        _plot_h(sample.T, data)
Exemple #19
0
def parse_dynamic_params(params, channels, weight_nums, bias_nums, inds, concat=False):
    """Split flat per-instance parameters into per-layer weights and biases, once per index group.

    ``params`` is 2-D with ``sum(weight_nums) + sum(bias_nums)`` columns: all
    weight chunks first, then all bias chunks.  For every entry of ``inds``
    the selected instances are collected layer by layer:

    * layers before the last are flattened to ``(n * channels, -1, 1, 1)`` /
      ``(n * channels,)``.  With ``concat`` set, every group except the first
      keeps the unflattened ``(n, channels, -1, 1, 1)`` / ``(n, channels)``;
    * the last layer stays ``(n, -1, 1, 1)`` / ``(n,)``;
    * a group that selects no instance gets an empty list for that layer.

    Returns:
        ``(multi_weight_splits, multi_bias_splits)``: for each group, a list
        with one entry per layer.
    """
    assert params.dim() == 2
    assert len(weight_nums) == len(bias_nums)
    assert params.size(1) == sum(weight_nums) + sum(bias_nums)

    num_insts = params.size(0)
    num_layers = len(weight_nums)

    chunks = torch.split_with_sizes(params, weight_nums + bias_nums, dim=1)
    weight_chunks = chunks[:num_layers]
    bias_chunks = chunks[num_layers:]

    multi_weight_splits = [[] for _ in inds]
    multi_bias_splits = [[] for _ in inds]

    for layer in range(num_layers):
        is_hidden = layer < num_layers - 1
        if is_hidden:
            # instances x out_channels x in_channels x 1 x 1
            layer_weights = weight_chunks[layer].reshape(num_insts, channels, -1, 1, 1)
            layer_biases = bias_chunks[layer].reshape(num_insts, channels)
        else:
            # instances x in_channels x 1 x 1 (one output per instance)
            layer_weights = weight_chunks[layer].reshape(num_insts, -1, 1, 1)
            layer_biases = bias_chunks[layer].reshape(num_insts)

        for group, ind in enumerate(inds):
            group_weights = layer_weights[ind]
            group_biases = layer_biases[ind]
            if is_hidden:
                n, c, _, _, _ = group_weights.shape
            else:
                n, _, _, _ = group_weights.shape

            if n <= 0:
                multi_weight_splits[group].append([])
                multi_bias_splits[group].append([])
                continue

            # Hidden layers are flattened unless this is a later group in
            # concat mode.
            if is_hidden and not (concat and group):
                group_weights = group_weights.reshape(n * c, -1, 1, 1)
                group_biases = group_biases.reshape(n * c)
            multi_weight_splits[group].append(group_weights)
            multi_bias_splits[group].append(group_biases)

    return multi_weight_splits, multi_bias_splits
Exemple #20
0
        def parse_dynamic_params(params, channels, weight_nums, bias_nums):
            """Split flat per-instance parameters into per-layer weights and biases.

            ``params`` is 2-D with ``sum(weight_nums) + sum(bias_nums)``
            columns: all weight chunks first, then all bias chunks.

            Example with 10 instances and 169 columns: the column chunks are
            80, 64, 8 (weights) and 8, 8, 1 (biases); the returned weights
            have shapes (80, 10, 1, 1), (80, 8, 1, 1), (10, 8, 1, 1) and the
            biases (80,), (80,), (10,).
            """
            assert params.dim() == 2
            assert len(weight_nums) == len(bias_nums)
            assert params.size(1) == sum(weight_nums) + sum(bias_nums)

            num_insts = params.size(0)
            num_layers = len(weight_nums)

            chunks = list(
                torch.split_with_sizes(params, weight_nums + bias_nums, dim=1))
            weight_splits = chunks[:num_layers]
            bias_splits = chunks[num_layers:]

            last_layer = num_layers - 1
            for idx in range(num_layers):
                # out_channels x in_channels x 1 x 1; the last layer has a
                # single output channel per instance.
                rows = num_insts * channels if idx < last_layer else num_insts
                weight_splits[idx] = weight_splits[idx].contiguous().view(
                    rows, -1, 1, 1)
                bias_splits[idx] = bias_splits[idx].contiguous().view(rows)

            return weight_splits, bias_splits
Exemple #21
0
def plot_zh(z: torch.Tensor,
            h: torch.Tensor,
            layer_sizes: List[int],
            data: List = None) -> None:
    """Plot ``z`` interleaved with the per-layer norms of ``h``, one figure per sample.

    ``h`` is split along dim 2 according to ``layer_sizes``, reduced to one
    norm per layer, and its layer axis is reversed.  For every sample the
    two arrays are stacked depth-wise, reshaped to ``(S, 2 * L)`` where
    ``(_, S, L)`` is the shape of ``z``, and passed transposed to
    ``_plot_zh`` together with ``data``.
    """
    z_np = z.detach().numpy()

    layers = torch.split_with_sizes(h.detach(), layer_sizes, dim=2)
    layer_norms = torch.stack([torch.norm(layer, dim=2) for layer in layers], dim=2)
    h_np = np.flip(layer_norms.numpy(), axis=2)

    _, n_steps, n_layers = z_np.shape
    for sample in range(h_np.shape[0]):
        merged = np.dstack((z_np[sample], h_np[sample]))
        _plot_zh(merged.reshape((n_steps, 2 * n_layers)).T, data)
Exemple #22
0
    def _parse_params(
        pred_params,
        in_channels,
        channels,
        num_classes,
        num_weight_params,
        num_bias_params,
    ):
        """Split flat per-instance parameters into per-layer weight matrices and biases.

        ``pred_params`` is 2-D with
        ``sum(num_weight_params) + sum(num_bias_params)`` columns: all weight
        chunks first, then all bias chunks.  Per layer the result is a weight
        of shape ``(num_instances, out, in)`` and a bias of shape
        ``(num_instances, out, 1)``:

        * first layer: ``in_channels -> channels``
        * intermediate layers: ``channels -> channels``
        * last layer (when it is not also the first): ``channels -> num_classes``
        """
        assert pred_params.dim() == 2
        assert len(num_weight_params) == len(num_bias_params)
        assert pred_params.size(
            1) == sum(num_weight_params) + sum(num_bias_params)

        num_instances = pred_params.size(0)
        num_layers = len(num_weight_params)

        chunks = torch.split_with_sizes(pred_params,
                                        num_weight_params + num_bias_params,
                                        dim=1)

        weight_splits = []
        bias_splits = []
        for layer in range(num_layers):
            is_first = layer == 0
            is_last = layer == num_layers - 1
            fan_in = in_channels if is_first else channels
            # A single-layer network is treated as an input layer.
            fan_out = num_classes if (is_last and not is_first) else channels

            weight_splits.append(
                chunks[layer].reshape(num_instances, fan_out, fan_in))
            bias_splits.append(
                chunks[num_layers + layer].reshape(num_instances, fan_out, 1))

        return weight_splits, bias_splits
Exemple #23
0
def get_subnetworks_params(attns, num_bases, channels, num_outputs=17):
    """Split flat per-instance parameters into the weights and biases of a 3-layer network.

    Args:
        attns: 2-D tensor, one row per instance.  Its columns hold, in order,
            w0, b0, w1, b1, w2, b2.
        num_bases: the first layer has ``2 + num_bases`` input channels.
        channels: width of the two hidden layers.
        num_outputs: number of output channels of the last layer.  Defaults
            to 17, the value that used to be hard-coded.

    Returns:
        ``([w0, w1, w2], [b0, b1, b2])`` with weights shaped
        ``(n_inst * out_channels, in_channels, 1, 1)`` and biases shaped
        ``(n_inst * out_channels,)``.
    """
    assert attns.dim() == 2
    n_inst = attns.size(0)
    in_channels = 2 + num_bases

    w0, b0, w1, b1, w2, b2 = torch.split_with_sizes(
        attns, [in_channels * channels, channels, channels * channels,
                channels, channels * num_outputs, num_outputs],
        dim=1)

    # out_channels x in_channels x 1 x 1
    w0 = w0.reshape(n_inst * channels, in_channels, 1, 1)
    b0 = b0.reshape(n_inst * channels)
    w1 = w1.reshape(n_inst * channels, channels, 1, 1)
    b1 = b1.reshape(n_inst * channels)
    w2 = w2.reshape(n_inst * num_outputs, channels, 1, 1)
    b2 = b2.reshape(n_inst * num_outputs)

    return [w0, w1, w2], [b0, b1, b2]
Exemple #24
0
 def forward(self, token_seq):
     """Embed each sequence with BERT and average the sub-token vectors per token.

     ``token_seq[:, :, 1]`` is fed to ``self.bert`` as the input ids, with
     the tokenizer's pad id masked out.  ``token_seq[:, :, 0]`` is used as
     the grouping key: consecutive non-pad positions that share a key are
     averaged into one vector.

     Returns:
         A list with one tensor per sequence, holding one row per group.
     """
     xtoken_ids = token_seq[:, :, 1]
     mask = torch.ne(xtoken_ids, self.bert_tokenizer.pad_token_id)
     hidden = self.bert(xtoken_ids, attention_mask=mask).last_hidden_state

     emb_tokens = []
     for row, sequence in enumerate(token_seq):
         keep = mask[row]
         # Lengths of the runs of equal token ids among the kept positions.
         _, run_lengths = torch.unique_consecutive(sequence[:, 0][keep],
                                                   return_counts=True)
         pieces = torch.split_with_sizes(hidden[row][keep], tuple(run_lengths))
         averaged = [torch.mean(piece, dim=0) for piece in pieces]
         emb_tokens.append(torch.stack(averaged, dim=0))
     return emb_tokens
Exemple #25
0
    def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
        """Forever receive batch results from the runtime and resolve task futures.

        Every message is either an exception instance, which is re-raised
        here, or a ``(batch_index, batch_outputs)`` pair.  The tasks stored
        under ``batch_index`` are removed from ``pending_batches`` and each
        receives its slice of every output array.
        """
        while True:
            message = self.outputs_receiver.recv()
            if isinstance(message, BaseException):
                raise message
            batch_index, batch_outputs = message

            # Cut every output along dim 0 into per-task pieces, then regroup
            # so that each item holds all outputs of one task.
            tasks = pending_batches.pop(batch_index)
            sizes = [self.get_task_size(task) for task in tasks]
            pieces_per_output = [
                torch.split_with_sizes(array, sizes, dim=0)
                for array in batch_outputs
            ]

            # Hand each task its outputs through its future.
            for task, pieces in zip(tasks, zip(*pieces_per_output)):
                task.future.set_result(tuple(pieces))
Exemple #26
0
    def forward(self, x):
        """Classify a batch of variable-length frame sequences.

        ``x`` is a sequence of per-sample tensors (presumably
        ``[n_frames, h, w]`` each -- TODO confirm).  Frame features are
        zero-padded to a common length, rearranged to
        ``[batch, features, seq]`` and passed through ``self.lstm``.  The mean
        and the max over the last axis are concatenated and classified.
        """
        lens = list(map(len, x))
        stacked = torch.cat(x, dim=0).unsqueeze(1)
        feats = self.features(stacked)  # [sum(n_frames), features]
        sequences = torch.split_with_sizes(feats, lens, dim=0)

        # Packing and immediately unpacking pads to [seq, batch, features].
        packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted=False)
        padded, _ = torch.nn.utils.rnn.pad_packed_sequence(packed)

        encoded = self.lstm(padded.permute(1, 2, 0))  # [batch, features, seq]

        pooled = torch.cat([encoded.mean(-1), encoded.max(-1)[0]], dim=-1)
        return self.classifier(pooled)
Exemple #27
0
    def forward(self,
                xtoken_seq,
                char_seq,
                special_symbols,
                num_tokens,
                max_form_len,
                max_num_labels,
                target_chars=None):
        """Run the parent forward pass, then score every segment with each classifier.

        All arguments are passed unchanged to the parent ``forward``.  In
        addition this method reads ``special_symbols['</s>']`` and
        ``special_symbols['<sep>']``, uses ``num_tokens`` to limit the rows
        that are inspected, and pads the label scores up to
        ``max_num_labels``.  When ``target_chars`` is given it is used as the
        character sequence; otherwise the characters are decoded from the
        parent's scores.

        Returns:
            ``(morph_scores, morph_states, label_scores)``: the first two come
            from the parent and ``label_scores`` holds one padded tensor per
            entry of ``self.classifiers``.
        """
        morph_scores, morph_states, _ = super().forward(
            xtoken_seq, char_seq, special_symbols, num_tokens, max_form_len,
            max_num_labels, target_chars)
        if target_chars is not None:
            morph_chars = target_chars
        else:
            morph_chars, _ = self.decode(morph_scores, [])
            morph_chars = morph_chars.squeeze(0)
        eos, sep = special_symbols['</s>'], special_symbols['<sep>']
        eos_mask = torch.eq(morph_chars[:num_tokens], eos)
        # Force an end marker in the last column so every row has one.
        eos_mask[:, -1] = True
        # Keep only the first end marker of each row (running count == 1).
        eos_mask = torch.bitwise_and(
            torch.eq(torch.cumsum(eos_mask, dim=1), 1), eos_mask)

        sep_mask = torch.eq(morph_chars[:num_tokens], sep)
        # Keep only separators located before the row's first end marker.
        sep_mask = torch.bitwise_and(
            torch.eq(torch.cumsum(eos_mask, dim=1), 0), sep_mask)

        # A segment ends at each kept separator and at the first end marker.
        seg_state_mask = torch.bitwise_or(eos_mask, sep_mask)
        # Boolean indexing flattens the selected states over all rows.
        seg_states = morph_states[seg_state_mask]
        enc_seg_scores, _ = self.encoder(seg_states.unsqueeze(dim=1))
        enc_seg_scores = self.seg_dropout(enc_seg_scores)
        label_scores = []
        # Number of segments per row; used to regroup the flattened scores.
        seg_sizes = torch.sum(seg_state_mask, dim=1)
        for classifier in self.classifiers:
            scores = classifier(enc_seg_scores)
            scores = torch.split_with_sizes(scores.squeeze(dim=1),
                                            tuple(seg_sizes))
            scores = nn.utils.rnn.pad_sequence(scores, batch_first=True)
            # Pad dim 1 (presumably the segment axis -- TODO confirm the
            # classifier output shape) up to max_num_labels.
            fill_len = max_num_labels - scores.shape[1]
            label_scores.append(F.pad(scores, (0, 0, 0, fill_len)))
        return morph_scores, morph_states, label_scores
Exemple #28
0
    def update(self, output):
        """Record per-graph interactions of the predictions and of the targets.

        ``output[0]`` holds the predicted relations and ``output[1]`` the
        ground truth.  Edges are grouped per graph via ``n_edges``, and only
        edges whose subject class is ``0`` are kept.  For predictions the best
        score per ``(predicate, object class)`` pair is appended to
        ``self.pred``; for targets every such pair is mapped to ``True`` and
        appended to ``self.gt``.
        """
        relations = output[0]
        targets = output[1]

        sizes = relations.n_edges.tolist()
        subject_classes = relations.object_classes[relations.relation_indexes[0]]
        object_classes = relations.object_classes[relations.relation_indexes[1]]
        predicted_graphs = zip(
            torch.split_with_sizes(subject_classes, sizes),
            torch.split_with_sizes(relations.predicate_classes, sizes),
            torch.split_with_sizes(object_classes, sizes),
            torch.split_with_sizes(relations.relation_scores, sizes),
        )
        for subjs, preds, objs, rel_scores in predicted_graphs:
            best_scores = {}
            edges = zip(subjs.tolist(), preds.tolist(), objs.tolist(),
                        rel_scores.tolist())
            for subj, pred, obj, score in edges:
                if subj != 0:
                    continue
                hoi = (pred, obj)
                # Keep the highest score seen for this pair.
                if score > best_scores.get(hoi, -1):
                    best_scores[hoi] = score
            self.pred.append(best_scores)

        sizes = targets.n_edges.tolist()
        subject_classes = targets.object_classes[targets.relation_indexes[0]]
        object_classes = targets.object_classes[targets.relation_indexes[1]]
        target_graphs = zip(
            torch.split_with_sizes(subject_classes, sizes),
            torch.split_with_sizes(targets.predicate_classes, sizes),
            torch.split_with_sizes(object_classes, sizes),
        )
        for subjs, preds, objs in target_graphs:
            present = {}
            for subj, pred, obj in zip(subjs.tolist(), preds.tolist(),
                                       objs.tolist()):
                if subj != 0:
                    continue
                present[(pred, obj)] = True
            self.gt.append(present)
Exemple #29
0
    def parse_dynamic_params(self, params):
        """Parse per-instance weights and biases.

        Args:
            params (Tensor): per-location conv weights and biases,
                shape like (num_insts, sum(weight_nums)+sum(bias_nums)),
                with all weight chunks first and all bias chunks after.

        Returns:
            weight_splits (List[Tensor]): per-layer conv weights, shape like
                (num_insts * output_channels, input_channels_per_inst, 1, 1)
            bias_splits (List[Tensor]): per-layer conv biases, shape like
                (num_insts * output_channels,)

            ``output_channels`` is ``self.channels`` for every layer except
            the last one, where it is 1.
        """
        assert params.dim() == 2
        assert params.shape[1] == sum(self.weight_nums) + sum(self.bias_nums)

        num_insts = params.shape[0]
        num_layers = self.num_layers
        chunks = torch.split_with_sizes(params,
                                        self.weight_nums + self.bias_nums,
                                        dim=1)

        weight_splits = list(chunks[:num_layers])
        bias_splits = list(chunks[num_layers:])

        final_layer = num_layers - 1
        for layer_ind in range(num_layers):
            out_channels = self.channels if layer_ind < final_layer else 1
            rows = num_insts * out_channels
            weight_splits[layer_ind] = weight_splits[layer_ind].reshape(
                rows, -1, 1, 1)
            bias_splits[layer_ind] = bias_splits[layer_ind].reshape(rows)

        return weight_splits, bias_splits
def parse_dynamic_params(params, channels, weight_nums, bias_nums):
    """Split flat per-instance parameters into per-layer weights and biases.

    ``params`` is 2-D, ``(n, sum(weight_nums) + sum(bias_nums))``, with all
    weight chunks first and all bias chunks after.

    Example with ``weight_nums = [80, 64, 8]``, ``bias_nums = [8, 8, 1]``,
    ``channels = 8`` and 421 instances (169 columns): the weights come back as
    (3368, 10, 1, 1), (3368, 8, 1, 1), (421, 8, 1, 1) and the biases as
    (3368,), (3368,), (421,).
    """
    assert params.dim() == 2
    assert len(weight_nums) == len(bias_nums)
    assert params.size(1) == sum(weight_nums) + sum(bias_nums)

    num_insts = params.size(0)
    num_layers = len(weight_nums)

    chunks = torch.split_with_sizes(params, weight_nums + bias_nums, dim=1)

    # Leading dimension per layer: ``channels`` outputs per instance for
    # every layer except the last, which has a single output per instance.
    rows_per_layer = [num_insts * channels] * (num_layers - 1) + [num_insts]

    # out_channels x in_channels x 1 x 1
    weight_splits = [
        weight.reshape(rows, -1, 1, 1)
        for weight, rows in zip(chunks[:num_layers], rows_per_layer)
    ]
    bias_splits = [
        bias.reshape(rows)
        for bias, rows in zip(chunks[num_layers:], rows_per_layer)
    ]

    return weight_splits, bias_splits