Example #1
# Let us begin by choosing a simple dataset and problem to allow us to focus on how the hybrid
# model is constructed. Our objective is to classify points generated from scikit-learn's
# binary-class
# `make_moons() <https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html>`__ dataset:

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.datasets import make_moons

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

X, y = make_moons(n_samples=200, noise=0.1)
y_ = torch.unsqueeze(torch.tensor(y), 1)  # used for one-hot encoded labels
y_hot = torch.scatter(torch.zeros((200, 2)), 1, y_, 1)

c = ["#1f77b4" if y_ == 0 else "#ff7f0e" for y_ in y]  # colours for each class
plt.axis("off")
plt.scatter(X[:, 0], X[:, 1], c=c)
plt.show()

###############################################################################
# Defining a QNode
# ----------------
#
# Our next step is to define the QNode that we want to interface with ``torch.nn``. Any
# combination of device, operations and measurements that is valid in PennyLane can be used to
# compose the QNode. However, the QNode arguments must satisfy additional :doc:`conditions
# <code/api/pennylane.qnn.TorchLayer>` including having an argument called ``inputs``. All other
# arguments must be arrays or tensors and are treated as trainable weights in the model. We fix a
# two-qubit QNode for this problem:
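
# A minimal sketch of one such QNode, wrapped into a ``torch.nn`` layer. The choice of embedding,
# entangling layers and weight shapes below is an illustrative assumption rather than the only
# valid circuit:

import pennylane as qml

n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev)
def qnode(inputs, weights):
    # ``inputs`` is the mandatory argument name; ``weights`` becomes a trainable parameter
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(w)) for w in range(n_qubits)]

weight_shapes = {"weights": (3, n_qubits)}  # 3 layers of entangling rotations
qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)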
Example #2
    def forward(self,
                xs,
                ilens,
                ys,
                labels,
                olens,
                spembs=None,
                *args,
                **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of padded stop token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_ilen = max(ilens)
        max_olen = max(olens)
        if max_ilen != xs.shape[1]:
            xs = xs[:, :max_ilen]
        if max_olen != ys.shape[1]:
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]

        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, h_masks = self.encoder(xs, x_masks)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # thin out frames for reduction factor (B, Lmax, odim) ->  (B, Lmax//r, odim)
        if self.reduction_factor > 1:
            ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
            olens_in = olens.new(
                [olen // self.reduction_factor for olen in olens])
        else:
            ys_in, olens_in = ys, olens

        # add first zero frame and remove last frame for auto-regressive
        ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

        # forward decoder
        y_masks = self._target_mask(olens_in)
        zs, _ = self.decoder(ys_in, y_masks, hs, h_masks)
        # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
        # (B, Lmax//r, r) -> (B, Lmax//r * r)
        logits = self.prob_out(zs).view(zs.size(0), -1)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)

        # modify mod part of ground truth
        if self.reduction_factor > 1:
            assert olens.ge(self.reduction_factor).all(
            ), "Output length must be greater than or equal to reduction factor."
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]
            labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1),
                                   1.0)  # see #3388

        # calculate loss values
        l1_loss, l2_loss, bce_loss = self.criterion(after_outs, before_outs,
                                                    logits, ys, labels, olens)
        if self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = l2_loss + bce_loss
        elif self.loss_type == "L1+L2":
            loss = l1_loss + l2_loss + bce_loss
        else:
            raise ValueError("unknown --loss-type " + self.loss_type)
        report_keys = [
            {
                "l1_loss": l1_loss.item()
            },
            {
                "l2_loss": l2_loss.item()
            },
            {
                "bce_loss": bce_loss.item()
            },
            {
                "loss": loss.item()
            },
        ]

        # calculate guided attention loss
        if self.use_guided_attn_loss:
            # calculate for encoder
            if "encoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.encoder.encoders)))):
                    att_ws += [
                        self.encoder.encoders[layer_idx].self_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_in, T_in)
                enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
                loss = loss + enc_attn_loss
                report_keys += [{"enc_attn_loss": enc_attn_loss.item()}]
            # calculate for decoder
            if "decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.decoder.decoders)))):
                    att_ws += [
                        self.decoder.decoders[layer_idx].self_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_out, T_out)
                dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in)
                loss = loss + dec_attn_loss
                report_keys += [{"dec_attn_loss": dec_attn_loss.item()}]
            # calculate for encoder-decoder
            if "encoder-decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.decoder.decoders)))):
                    att_ws += [
                        self.decoder.decoders[layer_idx].src_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_out, T_in)
                enc_dec_attn_loss = self.attn_criterion(
                    att_ws, ilens, olens_in)
                loss = loss + enc_dec_attn_loss
                report_keys += [{
                    "enc_dec_attn_loss": enc_dec_attn_loss.item()
                }]

        # report extra information
        if self.use_scaled_pos_enc:
            report_keys += [
                {
                    "encoder_alpha": self.encoder.embed[-1].alpha.data.item()
                },
                {
                    "decoder_alpha": self.decoder.embed[-1].alpha.data.item()
                },
            ]
        self.reporter.report(report_keys)

        return loss
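
The ``torch.scatter`` call above (the one tagged "see #3388") marks the last valid frame of each
padded ``labels`` sequence as the stop token. A standalone sketch of just that step, with made-up
shapes for illustration:

import torch

labels = torch.zeros(2, 6)    # (B, Lmax) stop-token targets
olens = torch.tensor([6, 4])  # true output length of each utterance

# set labels[b, olens[b] - 1] = 1.0, i.e. force a stop at the last real frame
labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0)
print(labels)
# tensor([[0., 0., 0., 0., 0., 1.],
#         [0., 0., 0., 1., 0., 0.]])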
Example #3
    def forward(self, batch_dict):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            batch_dict: list of per-view batch dicts; each entry holds a
                "contrastive_projection_norm" tensor of shape [bsz, ...] and a
                "target" tensor from which the labels and the anchor mask are derived.
        Returns:
            A loss scalar (plus a placeholder value of -1).
        """

        # join features
        features = torch.cat(
            [
                bd["contrastive_projection_norm"].unsqueeze(dim=1)
                for bd in batch_dict
            ],
            dim=1,
        )

        # targets for the batch is the one with highest score
        labels = batch_dict[0]["target"].argmax(dim=-1).view(-1, 1)

        # samples without an answer cannot work as anchor points
        mask_samples = (batch_dict[0]["target"].sum(dim=-1) != 0).int()

        # mask
        pos_mask = None

        device = torch.device("cuda") if features.is_cuda else torch.device(
            "cpu")

        if len(features.shape) < 3:
            raise ValueError("`features` needs to be [bsz, n_views, ...],"
                             "at least 3 dimensions are required")
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and pos_mask is not None:
            raise ValueError("Cannot define both `labels` and `mask`")
        elif labels is None and pos_mask is None:
            pos_mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError(
                    "Num of labels does not match num of features")
            pos_mask = torch.eq(labels, labels.T).float().to(device)
        else:
            pos_mask = pos_mask.float().to(device)

        # remove samples without gt
        pos_mask = pos_mask * mask_samples
        contrast_count = features.shape[1]

        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == "one":
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == "all":
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError("Unknown mode: {}".format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T), self.temperature)

        # for numerical stability, doesn't affect any values ahead
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        pos_mask = pos_mask.repeat(anchor_count, contrast_count)

        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(pos_mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0,
        )
        # This is just an inverted identity matrix
        # assert logits_mask.cpu() == (torch.eye(logits_mask.shape[0]) == 0).int()
        pos_mask = pos_mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask

        if self.formulation == "custom":
            negs_mask = (pos_mask == 0).int() * logits_mask
            negs_sum = (exp_logits * negs_mask).sum(dim=-1, keepdim=True)
            denominator = negs_sum + exp_logits * pos_mask
            log_prob = logits - torch.log(denominator.sum(1, keepdim=True))
        else:
            assert self.formulation == "normal"
            log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # re-scaling rephrasings
        scl_mask_rescale_factor = registry.scl_mask_rescale_factor
        if scl_mask_rescale_factor > 0:
            secondary_mask = (torch.eye(batch_size,
                                        device=pos_mask.device).repeat(
                                            anchor_count,
                                            contrast_count).fill_diagonal_(0))
            secondary_mask = secondary_mask * scl_mask_rescale_factor
            secondary_mask[secondary_mask == 0] = 1
            pos_mask = pos_mask * secondary_mask

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (pos_mask * log_prob).sum(1) / torch.max(
            pos_mask.sum(1),
            torch.ones(1).to(pos_mask.device))

        # loss
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss, -1
Example #4
def sample_sequence(model,
                    length,
                    context,
                    position_ids,
                    num_samples=1,
                    temperature=1,
                    top_k=0,
                    top_p=0.0,
                    repetition_penalty=1.0,
                    is_xlnet=False,
                    is_xlm_mlm=False,
                    xlm_mask_token=None,
                    xlm_lang=None,
                    device='cpu',
                    stop_token_ids=None,
                    pad_token_id=None,
                    supports_past=False,
                    prompt_token_id=None,
                    segment_token_ids=None):
    """
    Generates sequence of tokens for the batch of input contexts.
    Inputs:
        context: a list of token_ids, sorted by length from longest to shortest
        position_ids: a list of lists indicating the positional embedding to use for each token in context
        num_samples: the number of sequences to output for each input context
        length: The maximum length of generation in addition to the original sentence's length
        stop_token_ids: generation of each sequence will stop if we generate any of these tokens
        supports_past: set to True if the model accepts the 'past' input for more efficient generation. For example, GPT-2/Transfo-XL/XLNet/CTRL do
        segment_token_ids: a list of two integers that indicate the tokens we should use for each of the two segments
    """
    max_length = len(
        context[0])  # context is sorted by length from longest to shortest
    min_length = len(context[-1])
    for a in context:
        a.extend([pad_token_id] * (max_length - len(a)))

    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.repeat(num_samples, 1)
    next_index = min_length
    generated = context[:, :next_index]
    should_finish = None
    length = max_length + length
    segment_ids = []
    for p in position_ids:
        segment_ids.append([segment_token_ids[0]] * len(p) +
                           [segment_token_ids[1]] *
                           (length + max_length - len(p)))
        p.extend(range(length + max_length - len(p)))

    position_ids = torch.tensor(position_ids, dtype=torch.long, device=device)
    position_ids = position_ids.repeat(num_samples, 1)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long, device=device)
    segment_ids = segment_ids.repeat(num_samples, 1)

    # print('context = ', context)
    # print('position_ids = ', position_ids)
    # print('segment_ids = ', segment_ids)

    past = None
    next_token = None
    with torch.no_grad():
        # rep_penalty = np.random.random(length) < 0.1
        # original_rep_penalty = repetition_penalty
        # print('rep_penalty = ', rep_penalty)
        for _ in trange(length):
            inputs = {
                'input_ids': generated,
                'position_ids': position_ids[:, :next_index],
                'token_type_ids': segment_ids[:, :next_index]
            }
            if is_xlnet:
                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
                input_ids = torch.cat(
                    (generated,
                     torch.zeros((1, 1), dtype=torch.long, device=device)),
                    dim=1)
                perm_mask = torch.zeros(
                    (1, input_ids.shape[1], input_ids.shape[1]),
                    dtype=torch.float,
                    device=device)
                perm_mask[:, :,
                          -1] = 1.0  # Previous tokens don't see last token
                target_mapping = torch.zeros((1, 1, input_ids.shape[1]),
                                             dtype=torch.float,
                                             device=device)
                target_mapping[0, 0, -1] = 1.0  # predict last token
                inputs = {
                    'input_ids': input_ids,
                    'perm_mask': perm_mask,
                    'target_mapping': target_mapping
                }

            if is_xlm_mlm and xlm_mask_token:
                # XLM MLM models are direct models (predict same token, not next token)
                # => need one additional dummy token in the input (will be masked and guessed)
                input_ids = torch.cat((generated,
                                       torch.full((1, 1),
                                                  xlm_mask_token,
                                                  dtype=torch.long,
                                                  device=device)),
                                      dim=1)
                inputs = {'input_ids': input_ids}

            if xlm_lang is not None:
                inputs["langs"] = torch.tensor([xlm_lang] *
                                               inputs["input_ids"].shape[1],
                                               device=device).view(1, -1)

            if supports_past:
                inputs['past'] = past
                if past is not None:
                    inputs['input_ids'] = next_token
                    inputs['position_ids'] = position_ids[:, next_index - 1]
                    inputs['token_type_ids'] = segment_ids[:, next_index - 1]

            outputs = model(**inputs)
            next_token_logits = outputs[0][:, -1, :] / (
                temperature if temperature > 0 else 1.)
            past = outputs[1]

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858), but much faster on GPU
            # for repetition_penalty, we penalize the tokens that appear in the context
            m = torch.scatter(input=torch.zeros_like(next_token_logits),
                              dim=1,
                              index=context,
                              value=1)
            m[:, prompt_token_id] = 0  # do not penalize the prompt token
            m[:, pad_token_id] = 0  # do not penalize the padding token
            # print('m = ', m.shape)
            need_change = m * next_token_logits
            need_divide = need_change > 0
            need_multiply = need_change < 0
            next_token_logits = need_divide * next_token_logits / repetition_penalty + need_multiply * next_token_logits * repetition_penalty + (
                1 - m) * next_token_logits

            # Old, slow implementation
            # if repetition_penalty != 1.0:
            # for i in range(context.shape[0]):
            # for _ in set(generated[i].tolist()):
            # if next_token_logits[i, _] > 0:
            # next_token_logits[i, _] /= repetition_penalty
            # else:
            # next_token_logits[i, _] *= repetition_penalty

            filtered_logits = top_k_top_p_filtering(next_token_logits,
                                                    top_k=top_k,
                                                    top_p=top_p)

            if temperature == 0:  # greedy sampling:
                next_token = torch.argmax(filtered_logits,
                                          dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits,
                                                         dim=-1),
                                               num_samples=1)

            # throw away the tokens that we already have from the context
            if next_index < context.shape[1]:
                m = (context[:, next_index:next_index + 1] !=
                     pad_token_id).long()
                next_token = m * context[:, next_index:next_index +
                                         1] + (1 - m) * next_token
            else:
                m = torch.zeros(1, device=device)

            for stop_token_id in stop_token_ids:
                if should_finish is None:
                    should_finish = ((next_token == stop_token_id) &
                                     (1 - m).bool())
                else:
                    should_finish = should_finish | (
                        (next_token == stop_token_id) & (1 - m).bool())
            next_index += 1
            generated = torch.cat((generated, next_token), dim=1)
            if should_finish.all():
                break
    return generated
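
The repetition-penalty block above uses ``torch.scatter`` to build a 0/1 membership mask over the
vocabulary for the tokens already present in the context, then divides or multiplies the matching
logits. A small self-contained sketch of that masking step (vocabulary size and token ids are made up):

import torch

vocab_size = 10
context = torch.tensor([[2, 5, 5, 7], [1, 2, 3, 0]])  # (batch, context_len)
next_token_logits = torch.randn(2, vocab_size)

# m[b, v] = 1 if token v occurs anywhere in context[b], else 0
m = torch.scatter(input=torch.zeros_like(next_token_logits),
                  dim=1,
                  index=context,
                  value=1)

repetition_penalty = 1.2
need_change = m * next_token_logits
need_divide = need_change > 0
need_multiply = need_change < 0
next_token_logits = (need_divide * next_token_logits / repetition_penalty +
                     need_multiply * next_token_logits * repetition_penalty +
                     (1 - m) * next_token_logits)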
Example #5
    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        # what is the mask for? a 0/1 matrix that excludes the i == j case
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
            # identity matrix (each sample is its own positive)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            # reshape labels into a column vector, in preparation for torch.eq
            if labels.shape[0] != batch_size:
                raise ValueError(
                    'Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
            """
            Only labels are given: build the mask from the labels.
            mask: a 2-D array whose (i, j) entry says whether samples i and j share a class.
            """
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]  # number of views
        # unbind drops dim 1 and returns a tuple (one tensor per view)
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)

        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        '''
        ---------------------------------------------------------------
        Everything above is initialization.
        '''
        # compute logits : i-j matrix
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(
            anchor_dot_contrast, dim=1, keepdim=True)  # row-wise max cosine similarity
        # logits = anchor_dot_contrast - logits_max.detach()  # subtract each row's max from every element
        logits = anchor_dot_contrast
        '''
        the reference implementation subtracts each row's max (dim=1) from logits -- why is that needed?
        '''

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases: 1 everywhere except for oneself
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask  # map of same-class vs. different-class pairs
        '''
        logits_mask: 0 on the diagonal and 1 elsewhere -- separates self from others
        mask: separates same-class from different-class pairs
        mask * logits_mask (element-wise): same-class entries become 1, different-class and (i, i) entries become 0
        after this, mask is 1 for same-class pairs and 0 for self and different-class pairs
        '''

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        '''
        if i == j: 0
        else: exp_logits[i, j] = exp(sim[i, j] / temperature)
        '''
        # ''' exp_logits: the denominator terms of the contrastive loss '''
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        # ''' exp_logits.sum(1): the row-wise sum forming the denominator '''

        # accumulated row by row

        # compute mean of log-likelihood over positives
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        # is the view really necessary here?
        # loss = loss.view(anchor_count, batch_size).mean()
        loss = loss.mean()

        return loss
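
The ``logits_mask`` built with ``torch.scatter`` in the SupCon implementations above is simply a
matrix of ones with zeros on the diagonal, used to drop the self-contrast terms. A quick standalone check:

import torch

batch_size, anchor_count = 3, 2
n = batch_size * anchor_count
logits_mask = torch.scatter(
    torch.ones(n, n), 1,
    torch.arange(n).view(-1, 1), 0)
# equivalent to an "inverted" identity matrix
assert torch.equal(logits_mask, 1.0 - torch.eye(n))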
Example #6
 def _impl(x, dim, index, src):
     dim = dim.item()
     return (torch.scatter(x, dim, index, src), )
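
A hypothetical call to the helper above (shapes are made up; the 0-d tensor passed for ``dim``
matches the ``dim.item()`` unpacking inside ``_impl``):

import torch

x = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
src = torch.arange(10, dtype=torch.float).reshape(2, 5)
(result,) = _impl(x, torch.tensor(0), index, src)  # scatters rows of src into x along dim 0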
Example #7
import cv2
import torch

# x = torch.rand(2, 5)
# print(x)
# x=torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
# print(x)

img = cv2.imread('annotations/pixmask/56.jpg')
# print()
# img=cv2.cvtColor(img,cv2.COLOR_GRAY2BGR)
# cv2.imshow('1', img)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
#
data = torch.randint(1, 4, (3, 3))  # per-pixel class ids
zeros = torch.zeros(1, 59, 3, 3)    # (batch, num_classes, H, W)
# one-hot encode the class ids along the channel dimension
label = torch.scatter(zeros, dim=1, index=data.unsqueeze(0).unsqueeze(0), value=1)
print(data)
print(label)
Example #8
    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError(
                    'Num of labels does not match num of features')
            mask = torch.eq(labels, labels.t()).float().to(device)
            Negative_mask = torch.ne(labels, labels.t()).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute the similarity numerator (raw dot products)
        dot_contrast = torch.matmul(anchor_feature, anchor_feature.t())
        dot_contrast_copy = dot_contrast.clone()
        norm_2 = []
        for i in range(batch_size * 2):
            v1 = torch.norm(anchor_feature[i], 2)
            norm_2.append(v1)
        for i in range(batch_size * 2):
            for j in range(batch_size * 2):
                # divide by the denominator (the product of the two norms)
                dot_contrast[i][j] = dot_contrast_copy[i][j] / (norm_2[i] *
                                                                norm_2[j])
        # divide by the temperature
        logits = torch.div(dot_contrast, self.temperature)

        mask = mask.repeat(anchor_count, contrast_count)  # same-class mask
        Negative_mask = Negative_mask.repeat(anchor_count, contrast_count)
        logits_mask = torch.scatter(
            torch.ones_like(mask), 1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 0)
        mask = mask * logits_mask

        # everything outside this identity, i.e. the negatives (Negative_mask is only set when labels are given)
        exp_logits = torch.exp(logits) * Negative_mask

        # the denominator is the negatives plus the current pair's own term (the other positives are not added)
        v1 = exp_logits.sum(1, keepdim=True)
        v2 = v1.repeat(1, batch_size * 2)
        v2 = v2.t()
        v2 = v2 + torch.exp(logits)
        # v3 is the per-pair denominator matrix
        v3 = torch.log(v2)
        # subtract the logs
        log_prob = logits - v3

        # sum over the positives
        mean_log_prob_pos = (mask * log_prob).sum(1)

        # loss
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
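
The double loop above divides each dot product by the product of the two vectors' L2 norms, i.e. it
computes pairwise cosine similarity. A vectorized equivalent (standalone, with stand-in shapes):

import torch
import torch.nn.functional as F

feats = torch.randn(8, 128)                 # stands in for anchor_feature, shape (2 * bsz, dim)
normed = F.normalize(feats, p=2, dim=1)     # unit-normalize each row
cos_sim = torch.matmul(normed, normed.t())  # pairwise cosine similarity
logits = cos_sim / 0.07                     # divide by the temperature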
Example #9
def cal_concept_word_probs(logits,
                           final_pool,
                           concept2words_map,
                           softmax,
                           temperature=1.0):
    assert len(logits.size()) == 3
    assert len(final_pool.size()) == 2
    batch_size = logits.size(0)
    output_len = logits.size(1)
    # concept_probs = (jump_probs * hybrid_weights['jump'] + walk_probs * hybrid_weights['walk'])

    concept_word_probs, concept_word_mask = None, None
    if final_pool is not None:
        # [bs, 2680]
        topk_concept2words_map = (final_pool.unsqueeze(-1) *
                                  concept2words_map).view(batch_size, -1)
        # assert topk_concept2words_map.size() == (batch_size, 2680 * 7)

        # map to the word vocab
        idx = topk_concept2words_map.unsqueeze(1).expand(-1, output_len,
                                                         -1).type(torch.int64)
        concept_word_logits_mask = torch.scatter(input=torch.zeros_like(
            logits, dtype=torch.int64),
                                                 index=idx,
                                                 src=torch.ones_like(idx),
                                                 dim=-1)
        concept_word_logits_mask[:, :, 0] = 0
        concept_word_logits = logits * concept_word_logits_mask

        concept_word_logits = torch.where(
            concept_word_logits.eq(0),
            torch.ones_like(concept_word_logits) * -1e10, concept_word_logits)
        concept_word_probs = softmax(concept_word_logits / temperature)

    # topk_concept_idx = concept_probs.topk(topk)[1]
    # topk_concept_probs = concept_probs.topk(topk)[0]
    #
    # #  [bs, topk, 7]
    # topk_concept2words_map = torch.gather(input=concept2words_map.unsqueeze(0).expand(batch_size, -1, -1), dim=1,
    #                                       index=topk_concept_idx.unsqueeze(-1).expand(-1, -1, 7))
    #
    # # topk_concept_probs = torch.gather(input=concept_probs, dim=1, index=topk_concept_idx)
    # topk_concept2words_mask = topk_concept2words_map.ne(0)
    #
    # #  [bs, len, topk, 7]
    # concept_word_logits = torch.gather(lm_logits.unsqueeze(-2).expand(batch_size, output_len, topk, -1), dim=-1,
    #                                    index=topk_concept2words_map.type(torch.int64).unsqueeze(1).expand(
    #                                        batch_size, output_len, topk, -1))
    # concept_word_logits2 = concept_word_logits * topk_concept2words_mask.unsqueeze(1).expand(-1, output_len, -1, -1)

    # if use_lm_logits:
    #     # map to the word vocab
    #     idx = topk_concept2words_map.unsqueeze(1).expand(-1, output_len, -1, -1).view(batch_size, output_len, -1).type(
    #         torch.int64)
    #     src = concept_word_logits2.view(batch_size, output_len, -1)
    #     tgt = torch.zeros_like(lm_logits)
    #     final_logits = tgt.scatter(dim=-1, index=idx, src=src)
    #     final_logits = torch.where(final_logits.eq(0), torch.ones_like(final_logits) * -1e10, final_logits)
    #     final_probs = softmax(final_logits)
    #
    # else:
    #     concept_word_logits3 = torch.where(concept_word_logits2.eq(0), torch.ones_like(concept_word_logits2) * -1e10,
    #                                        concept_word_logits2)
    #     word_probs_given_concept = softmax(concept_word_logits3)
    #     # word_probs_given_concept[:, :, 0:2] = 0
    #
    #     concept_word_probs = word_probs_given_concept * (topk_concept_probs.unsqueeze(-1).unsqueeze(1))
    #
    #     # map to the word vocab
    #     idx = topk_concept2words_map.unsqueeze(1).expand(-1, output_len, -1, -1).view(batch_size, output_len, -1).type(
    #         torch.int64)
    #     src = concept_word_probs.view(batch_size, output_len, -1)
    #     tgt = torch.zeros_like(lm_logits)
    #     final_probs = tgt.scatter(dim=-1, index=idx, src=src)

    return concept_word_probs
Example #10
 def __init__(self):
     # Fixing the dataset and problem
     self.X, self.y = make_moons(n_samples=200, noise=0.1)
     self.y_ = torch.unsqueeze(torch.tensor(self.y), 1)  # used for one-hot encoded labels
     self.y_hot = torch.scatter(torch.zeros((200, 2)), 1, self.y_, 1)
Example #11
def cal_finding_common_ground_score(send_messages_list,
                                    receive_messages_list,
                                    trainer_persona,
                                    partner_persona,
                                    kw_graph_distance_matrix,
                                    device,
                                    r=None):
    # calculate persona ground
    both_persona_str = trainer_persona + ' ' + partner_persona
    persona_concepts = extract_concepts(both_persona_str, 50)
    persona_ground = torch.scatter(
        input=torch.zeros(2680).to(device),
        dim=-1,
        index=torch.tensor(persona_concepts).to(device),
        src=torch.ones_like(
            torch.tensor(persona_concepts, dtype=torch.float).to(device)))
    persona_ground[0] = 0
    # num_persona_ground_concepts = persona_ground.sum().item()

    batch_size = len(send_messages_list[0])
    num_turn = len(send_messages_list)
    # calculate common ground
    # common_grounds = [[[] for _ in range(num_turn)] for _ in range(batch_size)]
    # num_common_ground_concepts = [[0 for _ in range(num_turn)] for _ in range(batch_size)]
    fcg_scores = [[0 for _ in range(num_turn)] for _ in range(batch_size)]
    recall_scores = [[0 for _ in range(num_turn)] for _ in range(batch_size)]
    common_ground_history = [
        torch.zeros(2680).to(device) for _ in range(batch_size)
    ]
    for idx_turn, receive_messages, send_messages in zip(
            reversed(range(num_turn)), reversed(receive_messages_list),
            reversed(send_messages_list)):
        for idx_batch, receive_message, send_message in zip(
                range(batch_size), receive_messages, send_messages):
            concepts = extract_concepts(send_message + ' ' + receive_message,
                                        50)
            common_ground_current = torch.scatter(
                input=torch.zeros(2680).to(device),
                dim=-1,
                index=torch.tensor(concepts).to(device),
                src=torch.ones_like(
                    torch.tensor(concepts, dtype=torch.float).to(device)))
            if have_concepts_in(common_ground_current):
                common_ground_current[0] = 0
            common_ground = (common_ground_current +
                             common_ground_history[idx_batch]).clamp(0, 1)
            common_ground_history[idx_batch] = common_ground
            # if no concept, then the common_ground_one_turn[0] will be scattered by 1.

            # num_common_ground_concepts[idx_batch][idx_turn] += common_ground.sum().item()
            precision_score = fcg_precision_score(persona_ground,
                                                  common_ground,
                                                  kw_graph_distance_matrix)
            fcg_scores[idx_batch][idx_turn] += precision_score

            recall_score = fcg_recall_score(persona_ground, common_ground,
                                            kw_graph_distance_matrix, r)
            recall_scores[idx_batch][idx_turn] += recall_score / (num_turn -
                                                                  idx_turn + 1)
            # common_grounds[idx_batch][idx_turn] += common_ground.tolist()

    # common_grounds = torch.tensor(common_grounds, dtype=torch.bool).to(device)
    # num_common_ground_concepts = torch.tensor(num_common_ground_concepts).to(device)
    # concepts2persona_ground = (kw_graph_distance_matrix * persona_ground).sum(-1) / num_persona_ground_concepts
    # fcg_precision = (common_grounds * concepts2persona_ground).sum(-1) / num_common_ground_concepts
    return fcg_scores, recall_scores
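
``persona_ground`` and ``common_ground_current`` above are multi-hot indicator vectors over the
2680 concepts, built by scattering ones at the extracted concept ids. A standalone sketch with toy ids:

import torch

num_concepts = 2680
concept_ids = torch.tensor([5, 17, 42])  # stands in for extract_concepts(...)
ground = torch.scatter(input=torch.zeros(num_concepts),
                       dim=-1,
                       index=concept_ids,
                       src=torch.ones_like(concept_ids, dtype=torch.float))
ground[0] = 0  # clear index 0, the "no concept" slot, as in the function above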
Example #12
 def create_mask(l_probs, lbl):
     with torch.no_grad():
         mask = torch.zeros_like(l_probs)
         mask = torch.scatter(mask, 2, lbl.unsqueeze(2), 1.)
     return mask
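
A hypothetical usage of ``create_mask`` above: it one-hot encodes the per-token gold labels against
a (batch, time, vocab) probability tensor, without tracking gradients.

import torch

l_probs = torch.randn(2, 4, 10)      # (batch, time, vocab) token probabilities
lbl = torch.randint(0, 10, (2, 4))   # gold label id for each time step
mask = create_mask(l_probs, lbl)     # (2, 4, 10) with 1.0 at each gold id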
Example #13
    def forward(self,
                xs,
                ilens,
                ys,
                labels,
                olens,
                spembs=None,
                spcs=None,
                *args,
                **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded acoustic features (B, Tmax, idim).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of padded stop token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            spcs (Tensor, optional):
                Batch of groundtruth spectrograms (B, Lmax, spc_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # thin out input frames for reduction factor
        # (B, Lmax, idim) ->  (B, Lmax // r, idim * r)
        if self.encoder_reduction_factor > 1:
            B, Lmax, idim = xs.shape
            if Lmax % self.encoder_reduction_factor != 0:
                xs = xs[:, :-(Lmax % self.encoder_reduction_factor), :]
            xs_ds = xs.contiguous().view(
                B,
                int(Lmax / self.encoder_reduction_factor),
                idim * self.encoder_reduction_factor,
            )
            ilens_ds = ilens.new(
                [ilen // self.encoder_reduction_factor for ilen in ilens])
        else:
            xs_ds, ilens_ds = xs, ilens

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs_ds, ilens_ds)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # calculate src reconstruction
        if self.src_reconstruction_loss_lambda > 0:
            B, _in_length, _adim = hs.shape
            xt, xtlens = self.src_reconstructor(hs, hlens)
            xt = self.src_reconstructor_linear(xt)
            if self.encoder_reduction_factor > 1:
                xt = xt.view(B, -1, self.idim)

        # calculate trg reconstruction
        if self.trg_reconstruction_loss_lambda > 0:
            olens_trg_cp = olens.new(
                sorted([olen // self.reduction_factor for olen in olens],
                       reverse=True))
            B, _in_length, _adim = hs.shape
            _, _out_length, _ = att_ws.shape
            # att_R should be [B, out_length / r_d, adim]
            att_R = torch.sum(
                hs.view(B, 1, _in_length, _adim) *
                att_ws.view(B, _out_length, _in_length, 1),
                dim=2,
            )
            yt, ytlens = self.trg_reconstructor(
                att_R, olens_trg_cp)  # is using olens correct?
            yt = self.trg_reconstructor_linear(yt)
            if self.reduction_factor > 1:
                yt = yt.view(
                    B, -1,
                    self.odim)  # now att_R should be [B, out_length, adim]

        # modify mod part of ground truth
        if self.reduction_factor > 1:
            assert olens.ge(self.reduction_factor).all(
            ), "Output length must be greater than or equal to reduction factor."
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1),
                                   1.0)  # see #3388
        if self.encoder_reduction_factor > 1:
            ilens = ilens.new([
                ilen - ilen % self.encoder_reduction_factor for ilen in ilens
            ])
            max_in = max(ilens)
            xs = xs[:, :max_in]

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(after_outs, before_outs,
                                                      logits, ys, labels,
                                                      olens)
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {
                "l1_loss": l1_loss.item()
            },
            {
                "mse_loss": mse_loss.item()
            },
            {
                "bce_loss": bce_loss.item()
            },
        ]

        # calculate context preservation loss
        if self.src_reconstruction_loss_lambda > 0:
            src_recon_l1_loss, src_recon_mse_loss = self.src_reconstruction_loss(
                xt, xs, ilens)
            loss = loss + src_recon_l1_loss
            report_keys += [
                {
                    "src_recon_l1_loss": src_recon_l1_loss.item()
                },
                {
                    "src_recon_mse_loss": src_recon_mse_loss.item()
                },
            ]
        if self.trg_reconstruction_loss_lambda > 0:
            trg_recon_l1_loss, trg_recon_mse_loss = self.trg_reconstruction_loss(
                yt, ys, olens)
            loss = loss + trg_recon_l1_loss
            report_keys += [
                {
                    "trg_recon_l1_loss": trg_recon_l1_loss.item()
                },
                {
                    "trg_recon_mse_loss": trg_recon_mse_loss.item()
                },
            ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive input
            #   will be changed when r > 1
            if self.encoder_reduction_factor > 1:
                ilens_in = ilens.new(
                    [ilen // self.encoder_reduction_factor for ilen in ilens])
            else:
                ilens_in = ilens
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens_in, olens_in)
            loss = loss + attn_loss
            report_keys += [
                {
                    "attn_loss": attn_loss.item()
                },
            ]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != spcs.shape[1]:
                spcs = spcs[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(
                cbhg_outs, spcs, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {
                    "cbhg_l1_loss": cbhg_l1_loss.item()
                },
                {
                    "cbhg_mse_loss": cbhg_mse_loss.item()
                },
            ]

        report_keys += [{"loss": loss.item()}]
        self.reporter.report(report_keys)

        return loss
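
The ``encoder_reduction_factor`` block above stacks every ``r`` consecutive input frames into one,
reshaping (B, Lmax, idim) into (B, Lmax // r, idim * r) after trimming frames that do not divide
evenly. A standalone sketch:

import torch

B, Lmax, idim, r = 2, 10, 4, 3
xs = torch.randn(B, Lmax, idim)

xs = xs[:, :Lmax - Lmax % r]  # drop the 10 % 3 = 1 leftover frame
xs_ds = xs.contiguous().view(B, (Lmax - Lmax % r) // r, idim * r)
print(xs_ds.shape)  # torch.Size([2, 3, 12])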
Example #14
 def forward(self, data: torch.Tensor, indices: torch.Tensor,
             updates: torch.Tensor):
     return torch.scatter(data, self.dim, indices, updates)
Example #15
neis_d = x1.coordinate_manager.get_kernel_map(
        x1.coordinate_map_key,
        x1.coordinate_map_key,
        kernel_size=2,
        stride=1,
        )
k = 8
N, dim = x1.F.shape
out = torch.zeros([N, dim], device=x1.device)
for k_ in range(k):

    if k_ not in neis_d:
        continue
    neis_ = torch.gather(x1.F, dim=0, index=neis_d[k_][0].reshape(-1,1).repeat(1,dim).long())
    neis = torch.zeros([N, dim], device=x1.device)
    neis = torch.scatter(neis, dim=0, index=neis_d[k_][1].reshape(-1,1).repeat(1,dim).long(),src=neis_)
    out += neis

out_ = debug_channel_conv(x1)
out_conv = debug_conv(x1)

print(out, '\n',out_.F)
import ipdb; ipdb.set_trace()

'''
check clone 2 convs
'''
# conv2 = ME.MinkowskiConvolution(1,1,kernel_size=1,dimension=3)
# conv2.kernel = nn.Parameter(conv1.kernel.clone())
# # out1 = conv1(x1)
# out2 = conv2(x2)
Example #16
    def forward(self,
                xs,
                ilens,
                ys,
                labels,
                olens,
                spembs=None,
                extras=None,
                *args,
                **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of padded stop token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            extras (Tensor, optional):
                Batch of groundtruth spectrograms (B, Lmax, spc_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # modify mod part of ground truth
        if self.reduction_factor > 1:
            assert olens.ge(self.reduction_factor).all(
            ), "Output length must be greater than or equal to reduction factor."
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1),
                                   1.0)  # see #3388

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(after_outs, before_outs,
                                                      logits, ys, labels,
                                                      olens)
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {
                "l1_loss": l1_loss.item()
            },
            {
                "mse_loss": mse_loss.item()
            },
            {
                "bce_loss": bce_loss.item()
            },
        ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi):
            # length of output for auto-regressive input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            report_keys += [
                {
                    "attn_loss": attn_loss.item()
                },
            ]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != extras.shape[1]:
                extras = extras[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(
                cbhg_outs, extras, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {
                    "cbhg_l1_loss": cbhg_l1_loss.item()
                },
                {
                    "cbhg_mse_loss": cbhg_mse_loss.item()
                },
            ]

        report_keys += [{"loss": loss.item()}]
        self.reporter.report(report_keys)

        return loss
Example #17
    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        torch.set_printoptions(threshold=10000)
        device = (torch.device('cuda')
                  if features.is_cuda else torch.device('cpu'))
        print("features:", features.size())  # torch.Size([bs, 2, 128])

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            print("labels", labels.size(), labels)
            if labels.shape[0] != batch_size:
                raise ValueError(
                    'Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
            #print("mask", mask.size(), mask) # [bsz, bsz]
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]  # 2
        #print("contrast_count(should be 2):", contrast_count)
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        #print("contrast_feature shape(should be [2*bsz, 128]:", contrast_feature.size())
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            #print("contrast_mode is all")
            anchor_feature = contrast_feature  # [2*bsz, 128]
            anchor_count = contrast_count  #2
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
        #print("anchor_dot_contrast shape(should be [32, 32]):", anchor_dot_contrast.size())
        ### for numerical stability ###
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        #print("logits shape(should be [32, 32]):", logits.size())
        #print("trace of logits(should be 0):", torch.trace(logits))
        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        #print("mask repeat:", mask) # [32, 32] or [2*bsz, 2*bsz]
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask), 1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 0)
        #print("mask", mask)
        #print("logits_mask", logits_mask)
        mask = mask * logits_mask
        #print("mask", mask)
        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1,
                                                     keepdim=True))  #[32, 32]

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)  # [32,]

        # loss
        loss_inter = -(self.temperature /
                       self.base_temperature) * mean_log_prob_pos  # [32,]

        loss_inter = loss_inter.view(anchor_count, batch_size).mean()

        exp_logits_intra = torch.exp(logits) * mask
        print("check mask one's amount:", mask.sum(1))
        log_prob_intra = logits - torch.log(
            exp_logits_intra.sum(1, keepdim=True))  #[32, 32]
        intra_mask = torch.eye(batch_size).repeat(
            anchor_count, contrast_count).to(device) - torch.eye(
                batch_size * anchor_count).to(device)
        mean_log_prob_pos_intra = (
            intra_mask * log_prob_intra).sum(1) / intra_mask.sum(1)  # [32,]
        #print("intra mask:", intra_mask)
        #print("check intra mask one's amount(must be 1 for each row):", intra_mask.sum(1))
        loss_intra = -(self.temperature / self.base_temperature
                       ) * mean_log_prob_pos_intra  # [32,]

        loss_intra = loss_intra.view(anchor_count, batch_size).mean()
        loss = loss_inter + loss_intra
        return loss
Example #18
def train(config):
    assert config.model_type in ('RNN', 'LSTM')

    # Initialize the device which to run the model on
    device = torch.device(config.device)

    # Initialize params for models
    seq_length = config.input_length
    input_dim = config.input_dim
    num_hidden = config.num_hidden
    num_classes = config.num_classes

    print(seq_length, input_dim, num_classes, num_hidden)

    # Testing for convergence
    epsilon = 5e-4
    # minimum number of steps the model definitely trains for; the LSTM trains more slowly, so it needs more iterations
    if seq_length < 30:
        if config.model_type == 'RNN':
            min_steps = 3000 if seq_length > 15 else 1000
        else:
            min_steps = 5000 if seq_length > 15 else 1500
    else:
        min_steps = 6500

    # Initialize the model that we are going to use
    if config.model_type == 'RNN':
        model = VanillaRNN(seq_length, input_dim, num_hidden, num_classes,
                           device)
    else:
        model = LSTM(seq_length, input_dim, num_hidden, num_classes, device)

    model.to(device)

    # Initialize the dataset and data loader (note the +1)
    dataset = PalindromeDataset(config.input_length + 1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Setup the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(model.parameters(), lr=config.learning_rate)

    # Train losses and accuracies for debugging purposes
    accuracies, losses = [], []

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Only for time measurement of step through network
        t1 = time.time()

        # convert to one-hot representation
        batch_inputs = torch.scatter(
            torch.zeros(*batch_inputs.size(), num_classes), 2,
            batch_inputs[..., None].to(torch.int64), 1).to(device)

        batch_targets = batch_targets.to(device)

        train_output = model.forward(batch_inputs)
        loss = criterion(train_output, batch_targets)

        optimizer.zero_grad()
        loss.backward()

        ############################################################################
        # QUESTION: what happens here and why?
        ############################################################################
        # Clip exploding gradients; clipping must happen after backward() has
        # populated the gradients and before optimizer.step() applies them
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       max_norm=config.max_norm)
        ############################################################################

        optimizer.step()

        accuracy = torch.sum(
            torch.eq(torch.argmax(train_output, dim=1),
                     batch_targets)).item() / train_output.size(0)
        accuracies.append(accuracy)
        losses.append(loss.item())

        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size / float(t2 - t1)

        if step % 100 == 0:

            print(
                "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    config.train_steps, config.batch_size, examples_per_second,
                    accuracy, loss))

            if step > min_steps and (
                    np.absolute(np.mean(losses[-102:-2]) - losses[-1]) <
                    epsilon):
                print("Convergence reached after {} steps".format(step))
                break

        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print('Done training.')
    return model
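# A small standalone sketch (hypothetical batch and sequence sizes) of the one-hot
# conversion used in the training loop above: torch.scatter writes a 1 at each digit's
# class index along a new last dimension.
import torch

num_classes = 10
batch_inputs = torch.randint(0, num_classes, (2, 5))        # (batch, seq) of digit ids
one_hot = torch.scatter(
    torch.zeros(*batch_inputs.size(), num_classes), 2,
    batch_inputs[..., None].to(torch.int64), 1)
assert one_hot.shape == (2, 5, num_classes)
assert (one_hot.argmax(dim=2) == batch_inputs).all()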
Example #19
    def forward(self, anchor_points_list, gt_bboxes, labels,
                inside_gt_bbox_mask):
        """Get the center prior of each point on the feature map for each
        instance.

        Args:
            anchor_points_list (list[Tensor]): list of coordinate
                of points on feature map. Each with shape
                (num_points, 2).
            gt_bboxes (Tensor): The gt_bboxes with shape of
                (num_gt, 4).
            labels (Tensor): The gt_labels with shape of (num_gt).
            inside_gt_bbox_mask (Tensor): Tensor of bool type,
                with shape of (num_points, num_gt), each
                value is used to mark whether this point falls
                within a certain gt.

        Returns:
            tuple(Tensor):

                - center_prior_weights(Tensor): Float tensor with shape \
                    of (num_points, num_gt). Each value represents \
                    the center weighting coefficient.
                - inside_gt_bbox_mask (Tensor): Tensor of bool type, \
                    with shape of (num_points, num_gt), each \
                    value is used to mark whether this point falls \
                    within a certain gt or is the topk nearest points for \
                    a specific gt_bbox.
        """
        inside_gt_bbox_mask = inside_gt_bbox_mask.clone()
        num_gts = len(labels)
        num_points = sum([len(item) for item in anchor_points_list])
        if num_gts == 0:
            return gt_bboxes.new_zeros(num_points,
                                       num_gts), inside_gt_bbox_mask
        center_prior_list = []
        for slvl_points, stride in zip(anchor_points_list, self.strides):
            # slvl_points: points from single level in FPN, has shape (h*w, 2)
            # single_level_points has shape (h*w, num_gt, 2)
            single_level_points = slvl_points[:, None, :].expand(
                (slvl_points.size(0), len(gt_bboxes), 2))
            gt_center_x = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2)
            gt_center_y = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2)
            gt_center = torch.stack((gt_center_x, gt_center_y), dim=1)
            gt_center = gt_center[None]
            # instance_center has shape (1, num_gt, 2)
            instance_center = self.mean[labels][None]
            # instance_sigma has shape (1, num_gt, 2)
            instance_sigma = self.sigma[labels][None]
            # distance has shape (num_points, num_gt, 2)
            distance = (((single_level_points - gt_center) / float(stride) -
                         instance_center)**2)
            center_prior = torch.exp(-distance /
                                     (2 * instance_sigma**2)).prod(dim=-1)
            center_prior_list.append(center_prior)
        center_prior_weights = torch.cat(center_prior_list, dim=0)

        if self.force_topk:
            gt_inds_no_points_inside = torch.nonzero(
                inside_gt_bbox_mask.sum(0) == 0).reshape(-1)
            if gt_inds_no_points_inside.numel():
                topk_center_index = \
                    center_prior_weights[:, gt_inds_no_points_inside].topk(
                                                             self.topk,
                                                             dim=0)[1]
                temp_mask = inside_gt_bbox_mask[:, gt_inds_no_points_inside]
                inside_gt_bbox_mask[:, gt_inds_no_points_inside] = \
                    torch.scatter(temp_mask,
                                  dim=0,
                                  index=topk_center_index,
                                  src=torch.ones_like(
                                    topk_center_index,
                                    dtype=torch.bool))

        center_prior_weights[~inside_gt_bbox_mask] = 0
        return center_prior_weights, inside_gt_bbox_mask
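# A minimal standalone sketch (toy numbers, not from the original code) of the
# force_topk branch above: when a gt box contains no points, torch.scatter flips the
# top-k highest-weighted rows of its boolean column mask to True.
import torch

num_points, topk = 6, 2
weights = torch.rand(num_points, 1)                  # center prior for one empty gt
mask = torch.zeros(num_points, 1, dtype=torch.bool)  # no point falls inside this gt
topk_index = weights.topk(topk, dim=0)[1]
mask = torch.scatter(mask, 0, topk_index,
                     torch.ones_like(topk_index, dtype=torch.bool))
assert mask.sum() == topk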
Example #20
    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]

        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError(
                    'Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)

        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T), self.temperature)

        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask), 1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 0)
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
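# A quick standalone check (random toy logits, illustrative only) of the
# numerical-stability step above: subtracting the per-row maximum before exponentiating
# leaves the log-probabilities unchanged while keeping exp() away from overflow.
import torch

torch.manual_seed(0)
raw = torch.randn(4, 4) / 0.07        # large values, as after dividing by temperature
raw_max, _ = torch.max(raw, dim=1, keepdim=True)
shifted = raw - raw_max.detach()
log_prob_raw = raw - torch.log(torch.exp(raw).sum(1, keepdim=True))
log_prob_shifted = shifted - torch.log(torch.exp(shifted).sum(1, keepdim=True))
assert torch.allclose(log_prob_raw, log_prob_shifted, atol=1e-4)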
Example #21
def compute_contrastive_loss(posterior,
                             aug_posterior,
                             contrastive_fc,
                             labels=None,
                             mask=None,
                             contrast_mode='all',
                             temperature=0.07,
                             base_temperature=0.07,
                             loss_weight=1.0):
    '''Compute contrastive loss'''

    zis = contrastive_fc(posterior)
    # normalize
    zis = torch.nn.functional.normalize(zis, dim=1)

    zjs = contrastive_fc(aug_posterior)
    # normalize
    zjs = torch.nn.functional.normalize(zjs, dim=1)

    features = torch.stack([zis, zjs], dim=1)

    if labels is not None:
        labels = torch.argmax(labels, dim=-1)

    device = (torch.device('cuda')
              if features.is_cuda else torch.device('cpu'))

    features = features.view(features.shape[0], features.shape[1], -1)

    batch_size = features.shape[0]
    if labels is not None and mask is not None:
        raise ValueError('Cannot define both `labels` and `mask`')
    elif labels is None and mask is None:
        mask = torch.eye(batch_size, dtype=torch.float32).to(device)
    elif labels is not None:
        labels = labels.contiguous().view(-1, 1)
        if labels.shape[0] != batch_size:
            raise ValueError('Num of labels does not match num of features')
        mask = torch.eq(labels, labels.permute(1, 0)).float().to(device)
    else:
        mask = mask.float().to(device)

    contrast_count = features.shape[1]
    contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
    if contrast_mode == 'one':
        anchor_feature = features[:, 0]
        anchor_count = 1
    elif contrast_mode == 'all':
        anchor_feature = contrast_feature
        anchor_count = contrast_count
    else:
        raise ValueError('Unknown mode: {}'.format(contrast_mode))

    # compute logits
    anchor_dot_contrast = torch.div(
        torch.matmul(anchor_feature, contrast_feature.permute(1, 0)),
        temperature)
    # for numerical stability
    logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
    logits = anchor_dot_contrast - logits_max.detach()

    # tile mask
    mask = mask.repeat(anchor_count, contrast_count)
    # mask-out self-contrast cases
    logits_mask = torch.scatter(
        torch.ones_like(mask), 1,
        torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 0)
    mask = mask * logits_mask

    # compute log_prob
    exp_logits = torch.exp(logits) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

    # compute mean of log-likelihood over positives
    mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

    # loss
    loss = -(temperature / base_temperature) * mean_log_prob_pos
    loss = loss.view(anchor_count, batch_size).mean()

    return loss * loss_weight
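# A hedged usage sketch for the function above: the feature dimensions, the linear
# projection head and the one-hot labels below are made up for illustration; only the
# call signature comes from compute_contrastive_loss itself.
import torch
import torch.nn as nn

bsz, feat_dim, proj_dim, n_classes = 8, 32, 16, 4
contrastive_fc = nn.Linear(feat_dim, proj_dim)
posterior = torch.randn(bsz, feat_dim)
aug_posterior = posterior + 0.01 * torch.randn(bsz, feat_dim)  # a lightly perturbed view
labels = nn.functional.one_hot(torch.randint(0, n_classes, (bsz,)), n_classes).float()
loss = compute_contrastive_loss(posterior, aug_posterior, contrastive_fc, labels=labels)
print(loss)  # a scalar tensor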
Example #22
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        joint_training: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded character ids (B, T_text).
            text_lengths (LongTensor): Batch of lengths of each input batch (B,).
            feats (Tensor): Batch of padded target features (B, T_feats, odim).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).
            joint_training (bool): Whether to perform joint training with vocoder.

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value if not joint training else model outputs.

        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        feats = feats[:, :feats_lengths.max()]  # for data-parallel

        batch_size = text.size(0)

        # Add eos at the last of sequence
        xs = F.pad(text, [0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys = feats
        olens = feats_lengths

        # make labels for stop prediction
        labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
        labels = F.pad(labels, [0, 1], "constant", 1.0)

        # calculate tacotron2 outputs
        after_outs, before_outs, logits, att_ws = self._forward(
            xs=xs,
            ilens=ilens,
            ys=ys,
            olens=olens,
            spembs=spembs,
            sids=sids,
            lids=lids,
        )

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            assert olens.ge(self.reduction_factor).all(
            ), "Output length must be greater than or equal to reduction factor."
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1),
                                   1.0)  # see #3388

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(after_outs, before_outs,
                                                      logits, ys, labels,
                                                      olens)
        if self.loss_type == "L1+L2":
            loss = l1_loss + mse_loss + bce_loss
        elif self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = mse_loss + bce_loss
        else:
            raise ValueError(f"unknown --loss-type {self.loss_type}")

        stats = dict(
            l1_loss=l1_loss.item(),
            mse_loss=mse_loss.item(),
            bce_loss=bce_loss.item(),
        )

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive
            # input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            stats.update(attn_loss=attn_loss.item())

        if not joint_training:
            stats.update(loss=loss.item())
            loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                                   loss.device)
            return loss, stats, weight
        else:
            return loss, stats, after_outs
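# A standalone sketch (toy lengths, no ESPnet dependencies) of the stop-label fix above:
# after truncating targets to a multiple of the reduction factor, torch.scatter re-sets
# the stop flag at the new last valid frame of each utterance.
import torch

olens = torch.tensor([5, 3])            # output lengths after the modulo truncation
labels = torch.zeros(2, 5)              # stop labels; 1.0 marks the frame to stop at
labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0)
print(labels)
# tensor([[0., 0., 0., 0., 1.],
#         [0., 0., 1., 0., 0.]])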
Example #23
def analyze_grads_over_time(config, pretrain_model=False):
    device = torch.device(config.device)
    config.input_length = 150

    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)

    total_norms = []

    for m in ["RNN", "LSTM"]:

        # pretrain model
        if pretrain_model:
            model = train(config)
        else:
            # Initialize params for models
            seq_length = config.input_length
            input_dim = config.input_dim
            num_hidden = config.num_hidden
            num_classes = config.num_classes

            # Initialize the model that we are going to use
            if m == 'RNN':
                model = VanillaRNN(seq_length, input_dim, num_hidden,
                                   num_classes, device)
            else:
                model = LSTM(seq_length, input_dim, num_hidden, num_classes,
                             device)

            model.to(device)

        # Initialize the dataset and data loader (note the +1)
        dataset = PalindromeDataset(config.input_length + 1)
        # data_loader = DataLoader(dataset, batch_size=1, num_workers=1)
        data_loader = DataLoader(dataset,
                                 batch_size=config.batch_size,
                                 num_workers=1)

        # Setup the loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.RMSprop(model.parameters(), lr=config.learning_rate)

        # Get single batch from dataloader
        batch_inputs, batch_targets = next(iter(data_loader))

        # convert to one-hot
        batch_inputs = torch.scatter(
            torch.zeros(*batch_inputs.size(), config.num_classes), 2,
            batch_inputs[..., None].to(torch.int64), 1).to(device)
        batch_targets = batch_targets.to(device)

        train_output = model.analyze_hs_gradients(batch_inputs)
        loss = criterion(train_output, batch_targets)

        optimizer.zero_grad()
        loss.backward()

        gradient_norms = []
        for i, (t, h) in enumerate(reversed(model.h_states)):
            _grad = h.grad  # (batch_size x hidden_dim)
            average_grads = torch.mean(
                _grad, dim=0
            )  # Calculate average gradient to get more stable estimate
            grad_l2_norm = average_grads.norm(2).item()
            gradient_norms.append(grad_l2_norm)

        print(len(gradient_norms))
        total_norms.append(gradient_norms)

    time_steps = np.arange(150)
    print(time_steps)

    fig = plt.figure(figsize=(15, 10), dpi=150)
    # fig.suptitle('L2-norm of Gradients across Time Steps (LSTM $b_f = 2$)', fontsize=32)
    fig.suptitle('L2-norm of Gradients across Time Steps', fontsize=36)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(total_norms[0], linewidth=2, color="tomato", label="RNN")
    ax.plot(total_norms[1], linewidth=2, color="darkblue", label="LSTM")
    ax.tick_params(labelsize=16)
    ax.set_xticks(time_steps[::10])
    ax.set_xticklabels(time_steps[::10])

    ax.set_xlabel('Backpropagation Step', fontsize=24)
    ax.set_ylabel('Gradient Norm (L2)', fontsize=24)
    ax.legend(prop={'size': 16})

    if not os.path.exists('part1/figures/'):
        os.makedirs('part1/figures/')

    plt.savefig("part1/figures/Analyze_gradients_pt_{}.png".format(
        str(pretrain_model)))
    # plt.savefig("part1/figures/Analyze_gradients_pt_{}_bias_2.png".format(str(pretrain_model)))
    plt.show()
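# The call to model.analyze_hs_gradients above presumably stores hidden states with
# retained gradients; a minimal sketch of that mechanism (hypothetical toy recurrence,
# not the actual VanillaRNN/LSTM implementation) looks like this:
import torch

h = torch.randn(4, 8)                            # (batch, hidden)
w = torch.randn(8, 8, requires_grad=True)
h_states = []
for t in range(3):
    h = torch.tanh(h @ w)
    h.retain_grad()                              # keep .grad for this non-leaf tensor
    h_states.append((t, h))
h_states[-1][1].sum().backward()
print([h.grad.norm(2).item() for _, h in h_states])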