Example 1
def sample_sequence(model,
                    context,
                    num_samples=1,
                    temperature=1,
                    top_k=0,
                    top_p=0.0,
                    repetition_penalty=1.0,
                    device='cpu',
                    tokenizer=None):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    generated = context.clone().detach()
    prev_generated = generated
    past = None

    with torch.no_grad():
        while True:

            output, past = model(context, past=past)
            next_token_logits = output[:, -1, :] / (temperature
                                                    if temperature > 0 else 1.)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for i in range(num_samples):
                for token_id in set(generated[i].tolist()):
                    next_token_logits[i, token_id] /= repetition_penalty

            filtered_logits = top_k_top_p_filtering(next_token_logits,
                                                    top_k=top_k,
                                                    top_p=top_p)
            if temperature == 0:  # greedy sampling:
                next_token = torch.argmax(filtered_logits,
                                          dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits,
                                                         dim=-1),
                                               num_samples=1)

            context = next_token
            generated = torch.cat((generated, next_token), dim=1)

            eos = False
            for o in next_token.tolist():
                text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
                print(text, end="", flush=True)
                if '.' in text:
                    eos = True

            while eos:
                print()
                raw_text = input('> ')
                if raw_text == 'quit':
                    return
                if raw_text == 'revert':
                    generated = prev_generated
                    context = generated
                    past = None
                    continue

                prev_generated = generated
                eos = False

                if raw_text != '':
                    next_input = tokenizer.encode(' ' + raw_text,
                                                  add_special_tokens=False)
                    next_input = torch.tensor(next_input,
                                              dtype=torch.long,
                                              device=device)
                    next_input = next_input.unsqueeze(0).repeat(num_samples, 1)
                    generated = torch.cat((generated, next_input), dim=1)
                    context = generated
                    past = None

            if past and past[0].size()[3] > MAX_PAST:
                past = None
                context_len = MAX_PAST - BUFFER_SIZE
                context_start = generated.size()[1] - context_len
                context = torch.narrow(generated, 1, context_start,
                                       context_len)
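
Example 1 bounds the size of the cached `past`: once the cache grows beyond `MAX_PAST`, it is dropped and the context is rebuilt with `torch.narrow` as a sliding window over the most recent tokens. A minimal sketch of just that windowing step, using made-up values for `MAX_PAST` and `BUFFER_SIZE`:

import torch

MAX_PAST = 1000      # assumed cache limit, mirroring the constant used above
BUFFER_SIZE = 100    # assumed slack so the window is rebuilt less often

generated = torch.arange(2 * 1200).view(2, 1200)   # toy (num_samples, seq_len) ids
context_len = MAX_PAST - BUFFER_SIZE
context_start = generated.size(1) - context_len
# narrow returns a view of the last `context_len` tokens along dim 1
context = torch.narrow(generated, 1, context_start, context_len)
print(context.shape)  # torch.Size([2, 900])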
Example 2
 def forward(self, x):
     a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
     b = torch.narrow(a, 0, 0, 1)
     return b + x
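
Example 2 wraps a single `torch.narrow` call (as does Example 3 below). For reference, this is roughly what that call returns on the tensor used above; note that the result is a view sharing storage with `a`, not a copy:

import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.narrow(a, 0, 0, 1)     # first row only, as a view
print(b)                         # tensor([[1., 2., 3.]])
b[0, 0] = 9.
print(a[0, 0])                   # tensor(9.) -- the view shares storage with `a`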
Example 3
 def forward(self, input):
     return torch.narrow(input, 0, 0, 2)
Example 4
    def seg(self, inputs: List[str]):
        tokenizerd = self.tokenizer.batch_encode_plus(inputs,
                                                      return_tensors='pt',
                                                      padding=True)

        input_ids = tokenizerd['input_ids'].to(self.device)
        attention_mask = tokenizerd['attention_mask'].to(self.device)
        token_type_ids = tokenizerd['token_type_ids'].to(self.device)
        length = torch.sum(attention_mask, dim=-1) - 2

        pretrained_output, *_ = self.model.pretrained(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids)

        # remove [CLS] [SEP]
        word_cls = pretrained_output[:, :1]
        char_input = torch.narrow(pretrained_output, 1, 1,
                                  pretrained_output.size(1) - 2)

        segment_output = torch.argmax(self.model.seg_decoder(char_input),
                                      dim=-1).cpu().numpy()
        segment_output = self._convert_idx_to_name(segment_output, length,
                                                   self.seg_vocab)

        # todo: performance -- maybe cython / c++ / rust
        sentences = []
        word_idx = []
        word_length = []
        for source_text, encoding, sentence_seg_tag in zip(
                inputs, tokenizerd.encodings, segment_output):
            text = [
                source_text[start:end] for start, end in encoding.offsets[1:-1]
                if end != 0
            ]

            last_word = 0
            for idx, word in enumerate(encoding.words[1:-1]):
                if word is None or is_chinese_char(text[idx][-1]):
                    continue
                if word != last_word:
                    text[idx] = ' ' + text[idx]
                    last_word = word
                else:
                    sentence_seg_tag[idx] = WORD_MIDDLE

            entities = get_entities(sentence_seg_tag)
            word_length.append(len(entities))

            sentences.append([
                ''.join(text[entity[1]:entity[2] + 1]).strip()
                for entity in entities
            ])
            word_idx.append(
                torch.as_tensor([entity[1] for entity in entities],
                                device=self.device))

        word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
        word_idx = word_idx.unsqueeze(-1).expand(-1, -1, char_input.shape[-1])

        word_input = torch.gather(char_input, dim=1, index=word_idx)

        word_cls_input = torch.cat([word_cls, word_input], dim=1)
        word_cls_mask = length_to_mask(
            torch.as_tensor(word_length, device=self.device) + 1)
        word_cls_mask[:, 0] = False  # ignore the first token of each sentence
        return sentences, {
            'word_cls': word_cls,
            'word_input': word_input,
            'word_length': word_length,
            'word_cls_input': word_cls_input,
            'word_cls_mask': word_cls_mask
        }
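
The central trick in Example 4 is splitting the encoder output into the `[CLS]` vector and the character positions, using slicing plus `torch.narrow` to drop `[CLS]` and `[SEP]`. A small standalone sketch of that split, with toy shapes rather than the real LTP model:

import torch

batch, seq_len, hidden = 2, 8, 16                 # assumed toy sizes
pretrained_output = torch.randn(batch, seq_len, hidden)

word_cls = pretrained_output[:, :1]               # the [CLS] position, shape (2, 1, 16)
# drop [CLS] at position 0 and [SEP] at the end: keep seq_len - 2 positions
char_input = torch.narrow(pretrained_output, 1, 1, pretrained_output.size(1) - 2)
print(word_cls.shape, char_input.shape)           # (2, 1, 16) (2, 6, 16)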
Example 5
    def forward(self, p_vects, q_vects, p_frames_mask, q_frames_mask,
                num_phones_mask):
        '''
        p/q_vects = [num_speakers X num_feats X max_num_mfcc_frames x mfcc_dim]
        p/q_lengths = [num_speakers X num_feats] -> stores the number of observed
                                                    frames associated
                                                    with the corresponding phone
        p/q_frames_mask = [num_speakers X num_feats X max_num_mfcc_frames x mfcc_dim]
                          -> The associated 0s and 1s mask of p/q_lengths
        num_phones_mask = [num_speakers X num_feats],
        with a 0 at positions where no phones were observed (these should be -1)
        and a 1 everywhere else.
        n.b. mfcc_dim = 13 usually (using c0 for energy instead of log-energy)
             num_feats = 46*47*0.5 = 1128 usually
             max_num_mfcc_frames = the maximum number of frames associated
             with a particular phone for any speaker -> often set to 4000
        '''
        # Apply the attack
        noise = torch.exp(self.noise_root)

        # Need to add spectral noise
        # Pad to spectral dimension
        padding = torch.zeros(p_vects.size(0), p_vects.size(1),
                              p_vects.size(2), self.spectral_dim -
                              self.mfcc_dim).to(self.device)
        padded_p_vects = torch.cat((p_vects, padding), 3)
        padded_q_vects = torch.cat((q_vects, padding), 3)

        # Apply inverse dct
        log_spectral_p = dct.idct(padded_p_vects)
        log_spectral_q = dct.idct(padded_q_vects)

        # Apply inverse log
        spectral_p = torch.exp(log_spectral_p)
        spectral_q = torch.exp(log_spectral_q)

        # Add the adversarial attack noise
        attacked_spectral_p = spectral_p + noise
        attacked_spectral_q = spectral_q + noise

        # Apply the log
        attacked_log_spectral_p = torch.log(attacked_spectral_p)
        attacked_log_spectral_q = torch.log(attacked_spectral_q)

        # Apply the dct
        attacked_padded_p = dct.dct(attacked_log_spectral_p)
        attacked_padded_q = dct.dct(attacked_log_spectral_q)

        # Truncate to mfcc dimension
        p_vects_attacked = torch.narrow(attacked_padded_p, 3, 0, self.mfcc_dim)
        q_vects_attacked = torch.narrow(attacked_padded_q, 3, 0, self.mfcc_dim)

        # Apply mask of zeros/ones, to ensure spectral noise only applied up to p/q lengths
        p_vects_masked = p_vects_attacked * p_frames_mask
        q_vects_masked = q_vects_attacked * q_frames_mask

        # Compute the p/q_means tensor and covariance tensor
        p_means, p_covariances, q_means, q_covariances = self.get_pq_means_covs(
            p_vects_masked, q_vects_masked, p_frames_mask, q_frames_mask,
            num_phones_mask)

        # add small noise to all covariance matrices to ensure they are non-singular
        p_covariances_noised = p_covariances + (1e-2 *
                                                torch.eye(13).to(self.device))
        q_covariances_noised = q_covariances + (1e-2 *
                                                torch.eye(13).to(self.device))

        #        print(p_covariances_noised[0,3,:,:])
        #        print(q_covariances_noised[1,4,:,:])

        # Pass through trained model
        trained_model = torch.load(self.trained_model_path)
        trained_model.to(self.device)
        trained_model.eval()
        y = trained_model(p_means, p_covariances_noised, q_means,
                          q_covariances_noised, num_phones_mask)

        return y
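
Example 5 pads the MFCC axis up to the spectral dimension, works in the (log-)spectral domain, and then uses `torch.narrow` to truncate back to `mfcc_dim`. A rough sketch of just the pad-then-truncate pattern, under assumed dimensions and with the DCT/noise steps elided:

import torch

speakers, feats, frames, mfcc_dim, spectral_dim = 2, 4, 10, 13, 40  # assumed sizes
p_vects = torch.randn(speakers, feats, frames, mfcc_dim)

# pad the last axis from mfcc_dim up to spectral_dim with zeros
padding = torch.zeros(speakers, feats, frames, spectral_dim - mfcc_dim)
padded = torch.cat((p_vects, padding), dim=3)        # (2, 4, 10, 40)

# ... idct / noise / dct would happen here in the real forward pass ...

# truncate back to the first mfcc_dim coefficients along dim 3
truncated = torch.narrow(padded, 3, 0, mfcc_dim)     # (2, 4, 10, 13)
print(truncated.shape)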
Example 6
    def forward(ctx, input, input_mask, self, grads, layer_id, attn_qkvw,
                attn_qkvb, attn_ow, attn_ob, attn_nw, attn_nb, inter_w,
                inter_b, output_w, output_b, norm_w, norm_b, config):

        cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module
        forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32

        inp_size = input.size()
        if inp_size[1] % 16 != 0:
            input = torch.cat(
                (input,
                 torch.randn(
                     (inp_size[0], (16 - (inp_size[1] % 16)), inp_size[2]),
                     device=input.device,
                     dtype=input.dtype)), 1)
            input_mask = torch.cat((input_mask, torch.ones((inp_size[0], input_mask.shape[1], input_mask.shape[2], \
                                            (16 - (inp_size[1] % 16))), device=input_mask.device, dtype=input_mask.dtype) * -10000), 3)

        (output, inp_norm, qkv_tf, soft_inp, ctx_bufB, attn_o_inp, add_res,
         ff1_inp, gelu_inp, ff2_inp, attn_prob_dropout_mask,
         attn_output_dropout_mask, layer_output_dropout_mask,
         attn_layer_norm_var, attn_layer_norm_mean,
         layer_norm_var, layer_norm_mean) = forward_func(
             config.layer_id, input, input_mask, attn_qkvw, attn_qkvb, attn_ow,
             attn_ob, attn_nw, attn_nb, inter_w, inter_b, output_w, output_b,
             norm_w, norm_b, config.training, config.pre_layer_norm,
             config.attn_dropout_checkpoint, config.normalize_invertible,
             config.gelu_checkpoint)

        # For testing only.
        if grads is not None:
            for i in [2]:
                attn_qkvw.register_hook(
                    lambda x, i=i, self=self: grads.append([
                        x[i * attn_ow.size(0):(i + 1) * attn_ow.size(0)],
                        ("Q_W" if i == 0 else "K_W" if i == 1 else "V_W")
                    ]))
            for i in [2]:
                attn_qkvb.register_hook(
                    lambda x, i=i, self=self: grads.append([
                        x[i * attn_ow.size(0):(i + 1) * attn_ow.size(0)],
                        ("Q_B" if i == 0 else "K_B" if i == 1 else "V_B")
                    ]))

            attn_ow.register_hook(
                lambda x, self=self: grads.append([x, "O_W"]))
            attn_ob.register_hook(
                lambda x, self=self: grads.append([x, "O_B"]))
            attn_nw.register_hook(
                lambda x, self=self: grads.append([x, "N2_W"]))
            attn_nb.register_hook(
                lambda x, self=self: grads.append([x, "N2_B"]))
            inter_w.register_hook(
                lambda x, self=self: grads.append([x, "int_W"]))
            inter_b.register_hook(
                lambda x, self=self: grads.append([x, "int_B"]))
            output_w.register_hook(
                lambda x, self=self: grads.append([x, "out_W"]))
            output_b.register_hook(
                lambda x, self=self: grads.append([x, "out_B"]))
            norm_w.register_hook(
                lambda x, self=self: grads.append([x, "norm_W"]))
            norm_b.register_hook(
                lambda x, self=self: grads.append([x, "norm_B"]))

        if config.is_grad_enabled and config.training:
            if (config.pre_layer_norm and config.normalize_invertible):
                ctx.save_for_backward(input_mask, attn_qkvw, attn_qkvb,
                                      attn_ow, attn_ob, attn_nw, attn_nb,
                                      inter_w, inter_b, output_w, output_b,
                                      norm_w, norm_b)
            else:
                ctx.save_for_backward(output, input, input_mask, attn_qkvw,
                                      attn_qkvb, attn_ow, attn_ob, attn_nw,
                                      attn_nb, inter_w, inter_b, output_w,
                                      output_b, norm_w, norm_b)

            ctx.config = config
            if (config.pre_layer_norm or not config.normalize_invertible):
                ctx.inp_norm = inp_norm

            ctx.qkv_tf = qkv_tf
            ctx.soft_inp = soft_inp
            if not config.attn_dropout_checkpoint:
                ctx.ctx_bufB = ctx_bufB

            ctx.attn_o_inp = attn_o_inp
            if not config.normalize_invertible:
                ctx.add_res = add_res

            ctx.attn_layer_norm_mean = attn_layer_norm_mean
            ctx.layer_norm_mean = layer_norm_mean

            ctx.ff1_inp = ff1_inp
            if not config.gelu_checkpoint:
                ctx.gelu_inp = gelu_inp

            ctx.ff2_inp = ff2_inp
            ctx.attn_prob_dropout_mask = attn_prob_dropout_mask
            ctx.attn_output_dropout_mask = attn_output_dropout_mask
            ctx.layer_output_dropout_mask = layer_output_dropout_mask
            ctx.attn_layer_norm_var = attn_layer_norm_var
            ctx.layer_norm_var = layer_norm_var

        if inp_size[1] % 16 != 0:
            output = torch.narrow(output, 1, 0, inp_size[1])

        if config.huggingface:
            return (output, )  # outputs -> (output) : outputs[0] = output
        else:
            return output
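
Example 6 pads the sequence dimension to a multiple of 16 before calling the fused kernel and then narrows the output back to the original length. The same bookkeeping, stripped of the DeepSpeed-specific parts:

import torch

def pad_to_multiple_of_16(x):
    """Pad dim 1 of a (batch, seq, hidden) tensor up to a multiple of 16."""
    seq_len = x.size(1)
    if seq_len % 16 == 0:
        return x, seq_len
    pad_len = 16 - seq_len % 16
    pad = torch.randn(x.size(0), pad_len, x.size(2), device=x.device, dtype=x.dtype)
    return torch.cat((x, pad), dim=1), seq_len

x = torch.randn(3, 30, 8)
padded, orig_len = pad_to_multiple_of_16(x)     # seq dim becomes 32
# ... run the kernel on `padded` ...
output = torch.narrow(padded, 1, 0, orig_len)   # back to the original 30 timesteps
print(padded.shape, output.shape)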
Example 7
    def forward(self,
                input_volume,
                last_s=None,
                input_action=None,
                input_motion=None,
                next_mask=False,
                no_warp=False):
        B, _, S1, S2, S3 = input_volume.size()
        K = self.K
        device = input_volume.device
        output = {}

        input = torch.cat(
            (input_volume, self.coord_feature.expand(B, -1, -1, -1,
                                                     -1).to(device)),
            dim=1)
        input = torch.cat((input, last_s), dim=1)  # aggregate history

        volume_embedding, cache = self.volume_encoder(input)
        mask_feature = self.feature_decoder(volume_embedding, cache)

        if self.motion_type == 'conv':
            motion = self.motion_decoder(mask_feature, input_action)
            output['motion'] = motion

            return output

        assert (self.motion_type == 'se3')
        logit, mask = self.mask_decoder(mask_feature)
        output['init_logit'] = logit
        transform_param = self.transform_decoder(mask_feature, input_action)

        # trans, pivot: [B, K-1, 3]
        # rot_matrix:   [B, K-1, 3, 3]
        trans_vec, rot_mat = self.se3(transform_param)
        mask_object = torch.narrow(mask, 1, 0, K - 1)
        sum_mask = torch.sum(mask_object, dim=(2, 3, 4))
        heatmap = torch.unsqueeze(mask_object, dim=2) * self.grids.to(device)
        pivot_vec = torch.sum(heatmap, dim=(3, 4, 5)) / torch.unsqueeze(
            sum_mask, dim=2)

        # [Important] The last one is the background!
        trans_vec = torch.cat(
            [trans_vec, self.zero_vec.expand(B, -1, -1).to(device)],
            dim=1).unsqueeze(-1)
        rot_mat = torch.cat(
            [rot_mat, self.eye_mat.expand(B, 1, -1, -1).to(device)], dim=1)
        pivot_vec = torch.cat(
            [pivot_vec, self.zero_vec.expand(B, -1, -1).to(device)],
            dim=1).unsqueeze(-1)

        grids_flat = self.grids_flat.to(device)
        grids_after_flat = rot_mat @ (grids_flat -
                                      pivot_vec) + pivot_vec + trans_vec
        motion = (grids_after_flat - grids_flat).view([B, K, 3, S1, S2, S3])

        motion = torch.sum(motion * torch.unsqueeze(mask, 2), 1)

        output['motion'] = motion

        if no_warp:
            output['s'] = mask_feature
        elif input_motion is not None:
            mask_feature_warp = self.forward_warp(
                mask_feature, input_motion, torch.sum(mask[:, :-1, ], dim=1))
            output['s'] = mask_feature_warp
        else:
            mask_feature_warp = self.forward_warp(
                mask_feature, motion, torch.sum(mask[:, :-1, ], dim=1))
            output['s'] = mask_feature_warp

        if next_mask:
            mask_warp = self.forward_warp(mask, motion,
                                          torch.sum(mask[:, :-1, ], dim=1))
            output['next_mask'] = mask_warp

        return output
Example 8
def one_label_loss(gt_percent, predict, moe, batch_node_num):
    """
    Proposed Loss Function
    Our proposed Loss Functions calculates cost of training batch using
    -GCN's output graphs and weak image level annotations.
    For more information, please refer to our paper.


    Keyword arguments:
    gt_percent       --Ground-Trueth percent, a weak image-level annotation
    predict          --GCN module output, gradient required
    moe              --Margin of Error, a weak image-level annotation
    batch_node_num   --integer list of node numbers per image in batch
    """
    curr_index = 0
    batch_top_k_loss = []
    batch_bottom_k_loss = []
    batch_pairwise_loss = []
    positive_num = 0.00000001
    negative_num = 0.00000001
    for i in range(len(gt_percent)):
        total_length = batch_node_num[i] #one graph length
        predict_slice = torch.narrow(input=predict, dim=0,
                                     start=curr_index, length=total_length)
        curr_index += total_length
        one_gt_percent = gt_percent[i]
        one_moe = moe[i]
        select = torch.tensor([0])
        if use_cuda:
            select = select.to('cuda')

        threshold_ceil = int(total_length * (one_gt_percent - one_moe)) #100 * (0.8 - 0.1) = top 70 %
        if threshold_ceil < 0:
            threshold_ceil = 0
        threshold_floor = int(total_length * (1.0 - one_gt_percent - one_moe)) #100 * (1 - 0.8 - 0.1) = bottom 10 %
        if threshold_floor < 0:
            threshold_floor = 0

        top_k, _ = torch.topk(input=predict_slice, k=threshold_ceil, dim=0,
                              largest=True, sorted=False)
        bottom_k, _ = torch.topk(input=predict_slice, k=threshold_floor, dim=0,
                                 largest=False, sorted=False)

        top_k_mean = torch.mean(top_k, dim=0)
        bottom_k_mean = torch.mean(bottom_k, dim=0)

        predict_slice = None
        top_k = None
        select = None
        bottom_k = None
        loss_fn = nn.SmoothL1Loss()
        if use_cuda:
            temp_ones = torch.ones(1, dtype=torch.float).to('cuda')
            temp_zeros = torch.tensor([-1], dtype=torch.float).to('cuda')
            temp_ground = torch.zeros(1, dtype=torch.float).to('cuda')
            if threshold_ceil > 0:
                #top_k_loss = F.l1_loss(top_k_mean, temp_ones)
                top_k_loss = loss_fn(top_k_mean, temp_ones)
                positive_num += top_k_loss.detach().cpu().numpy()
            else:
                top_k_loss = None

            if threshold_floor > 0:
                #bottom_k_loss = F.l1_loss(bottom_k_mean, temp_zeros)
                bottom_k_loss = loss_fn(bottom_k_mean, temp_zeros)
                negative_num += bottom_k_loss.detach().cpu().numpy()
            else:
                bottom_k_loss = None
            temp_ones = None
            temp_zeros = None
        else:
            if threshold_ceil > 0:
                #top_k_loss = F.l1_loss(top_k_mean, torch.ones(1, dtype = torch.float))
                top_k_loss = loss_fn(top_k_mean, torch.ones(1, dtype=torch.float))
                positive_num += 1.0
            else:
                top_k_loss = None

            if threshold_floor > 0:
                #bottom_k_loss = F.l1_loss(bottom_k_mean, torch.zeros(1, dtype = torch.float))
                bottom_k_loss = loss_fn(bottom_k_mean, torch.zeros(1, dtype=torch.float))
                negative_num += 1.0
            else:
                bottom_k_loss = None
        batch_top_k_loss.append(top_k_loss)
        batch_bottom_k_loss.append(bottom_k_loss)
    top_k_loss = None
    bottom_k_loss = None
    pairwise_loss = None
    print("-------------------------------------------------------------------------------")
    print("Targeted Regions Losses Per Image")
    print([round(float(x.detach().cpu().numpy()), 2) if x is not None else -1.00 for x in batch_top_k_loss])
    print("Background Regions Losses Per Image")
    print([round(float(x.detach().cpu().numpy()), 2) if x is not None else -1.00 for x in batch_bottom_k_loss])
    print("-------------------------------------------------------------------------------")

    for t, b, g, a in zip(batch_top_k_loss, batch_bottom_k_loss, gt_percent, moe):
        if top_k_loss is None and t is not None:
            top_k_loss = (g - a) * t
        elif t is not None:
            top_k_loss += (g - a) * t
        if bottom_k_loss is None and b is not None:
            bottom_k_loss = (1.0 - g - a) * b
        elif b is not None:
            bottom_k_loss += (1.0 - g - a) * b
    return top_k_loss, bottom_k_loss
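
In Example 8, `torch.narrow` walks a flat per-node prediction tensor graph by graph, using the node counts in `batch_node_num` as slice lengths. A minimal sketch of that walk (the loss computation itself is omitted, and the sizes are made up):

import torch

predict = torch.randn(10, 1)          # flat predictions for all nodes in the batch
batch_node_num = [3, 5, 2]            # assumed node counts per graph

curr_index = 0
for total_length in batch_node_num:
    # view of this graph's rows only; no copy is made
    predict_slice = torch.narrow(predict, 0, curr_index, total_length)
    curr_index += total_length
    print(predict_slice.shape)        # (3, 1), then (5, 1), then (2, 1)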
Example 9
def train_text2mel(load_trained):
    # create log dir
    logdir = os.path.join(Hyper.logdir, "text2mel")
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    if not os.path.exists(os.path.join(logdir, "pkg")):
        os.mkdir(os.path.join(logdir, "pkg"))

    # device
    device = Hyper.device_text2mel

    graph = Text2Mel().to(device)
    # set the training flag
    graph.train()
    # load data and get batch maker
    names, lengths, texts = load_data()
    batch_maker = BatchMaker(Hyper.batch_size, names, lengths, texts)

    criterion_mels = nn.L1Loss().to(device)
    criterion_bd1 = nn.BCEWithLogitsLoss().to(device)
    criterion_atten = nn.L1Loss().to(device)
    optimizer = torch.optim.Adam(graph.parameters(),
                                 lr=Hyper.adam_alpha,
                                 betas=Hyper.adam_betas,
                                 eps=Hyper.adam_eps)

    lossplot_mels = LogHelper("mel_l1", logdir)
    lossplot_bd1 = LogHelper("mel_BCE", logdir)
    lossplot_atten = LogHelper("atten", logdir)

    dynamic_guide = float(Hyper.guide_weight)
    global_step = 0

    # check if load
    if load_trained > 0:
        print("load model trained for {}k batches".format(load_trained))
        global_step = load(
            os.path.join(logdir, "pkg/save_{}k.pkg".format(load_trained)),
            graph, {
                "mels": criterion_mels,
                "bd1": criterion_bd1,
                "atten": criterion_atten
            }, optimizer)
        dynamic_guide *= Hyper.guide_decay**(load_trained * 1000)

    evaluator = Evaluator()
    for loop_cnt in range(
            int(Hyper.num_batches / batch_maker.num_batches() + 0.5)):
        print("loop", loop_cnt)
        bar = PrettyBar(batch_maker.num_batches())
        bar.set_description("training...")
        loss_str0 = MovingAverage()
        loss_str1 = MovingAverage()
        loss_str2 = MovingAverage()
        for bi in bar:
            batch = batch_maker.next_batch()
            # make batch
            texts = torch.LongTensor(batch["texts"]).to(device)
            # shift mel
            shift_mels = torch.FloatTensor(
                np.concatenate((np.zeros(
                    (batch["mels"].shape[0], batch["mels"].shape[1], 1)),
                                batch["mels"][:, :, :-1]),
                               axis=2)).to(device)
            # ground truth
            mels = torch.FloatTensor(batch["mels"]).to(device)

            # forward
            pred_logits, pred_mels = graph(texts, shift_mels)
            # loss
            if False:
                loss_mels = sum(
                    criterion_mels(
                        torch.narrow(pred_mels[i], -1, 0, batch["mel_lengths"]
                                     [i]),
                        torch.narrow(mels[i], -1, 0, batch["mel_lengths"][i]))
                    for i in range(batch_maker.batch_size())) / float(
                        batch_maker.batch_size())
                loss_bd1 = sum(
                    criterion_bd1(
                        torch.narrow(pred_logits[i], -1, 0,
                                     batch["mel_lengths"][i]),
                        torch.narrow(mels[i], -1, 0, batch["mel_lengths"][i]))
                    for i in range(batch_maker.batch_size())) / float(
                        batch_maker.batch_size())
            else:
                loss_mels = criterion_mels(pred_mels, mels)
                loss_bd1 = criterion_bd1(pred_logits, mels)
            # guide attention
            atten_guide = torch.FloatTensor(batch["atten_guides"]).to(device)
            atten_mask = torch.FloatTensor(batch["atten_masks"]).to(device)
            atten_mask = torch.ones_like(graph.attention)
            loss_atten = criterion_atten(
                atten_guide * graph.attention * atten_mask,
                torch.zeros_like(graph.attention)) * dynamic_guide
            loss = loss_mels + loss_bd1 + loss_atten

            # backward
            graph.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            # clip grad
            nn.utils.clip_grad_value_(graph.parameters(), 1)
            optimizer.step()
            # log
            loss_str0.add(loss_mels.cpu().data.mean())
            loss_str1.add(loss_bd1.cpu().data.mean())
            loss_str2.add(loss_atten.cpu().data.mean())
            lossplot_mels.add(loss_str0(), global_step)
            lossplot_bd1.add(loss_str1(), global_step)
            lossplot_atten.add(loss_str2(), global_step)

            # adjust dynamic_guide
            # dynamic_guide = float((loss_mels + loss_bd1).cpu().data.mean() / loss_atten.cpu().data.mean())
            dynamic_guide *= Hyper.guide_decay
            if dynamic_guide < Hyper.guide_lowbound:
                dynamic_guide = Hyper.guide_lowbound
            bar.set_description(
                "gs: {}, mels: {}, bd1: {}, atten: {}, scale: {}".format(
                    global_step, loss_str0(), loss_str1(), loss_str2(),
                    "%4f" % dynamic_guide))

            if global_step % Hyper.synth_freq == 0:
                evaluator.evaluate(loop_cnt)
                evaluator.export()

            # plot
            if global_step % 100 == 0:
                gs = 0
                plot_spectrum(mels[0].cpu().data, "mel_true", gs, dir=logdir)
                plot_spectrum(shift_mels[0].cpu().data,
                              "mel_input",
                              gs,
                              dir=logdir)
                plot_spectrum(pred_mels[0].cpu().data,
                              "mel_pred",
                              gs,
                              dir=logdir)
                plot_spectrum(graph.query[0].cpu().data,
                              "query",
                              gs,
                              dir=logdir)
                plot_attention(graph.attention[0].cpu().data,
                               "atten",
                               gs,
                               True,
                               dir=logdir)
                plot_attention(atten_guide[0].cpu().data,
                               "atten_guide",
                               gs,
                               True,
                               dir=logdir)
                if global_step % 500 == 0:
                    lossplot_mels.plot()
                    lossplot_bd1.plot()
                    lossplot_atten.plot()

                if global_step % 10000 == 0:
                    save(
                        os.path.join(logdir, "pkg/save_{}k.pkg").format(
                            global_step // 1000), graph, {
                                "mels": criterion_mels,
                                "bd1": criterion_bd1,
                                "atten": criterion_atten
                            }, optimizer, global_step, True)

            # increase global step
            global_step += 1
Example 10
 def tb(a):
     return torch.narrow(a, 3, 1, N - k)
Example 11
 def forward(ctx, world_size, start_pos, chunk_size, weight, pg, bias):
     ctx.weight = weight
     ctx.pg = pg
     ctx.world_size = world_size
     return torch.narrow(bias, 0, start_pos, chunk_size)
Example 12
 def tf(a):
     return torch.narrow(a, 3, 0, N - k)
Example 13
    def _dp(self, arc_scores, lengths=None, force_grad=False):
        semiring = self.semiring
        arc_scores = _convert(arc_scores)
        arc_scores, batch, N, lengths = self._check_potentials(
            arc_scores, lengths)
        DIRS = 2
        alpha = [
            self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
            for _ in range(2)
        ]

        def stack(a, b):
            return torch.stack([a, b], dim=1)

        def sstack(a):
            return torch.stack([a, a], dim=1)

        arcs = [
            self._make_chart(1, (DIRS, batch, N - k), arc_scores,
                             force_grad)[0] for k in range(N)
        ]

        # Inside step. assumes first token is root symbol
        semiring.one_(alpha[A][C][:, :, :, :, 0].data)
        semiring.one_(alpha[B][C][:, :, :, :, -1].data)
        k = 0

        AIR = alpha[A][I][:, R, :, :N - k, 1:k]
        BIL = alpha[B][I][:, L, :, k:N, N - k:N - 1]
        k = 1
        AC2 = alpha[A][C][:, :, :, :N - k, :k]
        BC2 = alpha[B][C][:, :, :, k:, N - k:]
        AC, BC, AC_next = None, None, None

        ends = [None]
        for k in range(1, N):

            def tf(a):
                return torch.narrow(a, 3, 0, N - k)

            def tb(a):
                return torch.narrow(a, 3, 1, N - k)

            f = torch.arange(N - k), torch.arange(k, N)
            if k > 1:
                AC2 = torch.cat([tf(AC), tf(AC_next).unsqueeze(-1)], dim=4)
            if k > 1:
                BC2 = torch.cat([tb(AC_next).unsqueeze(-1), tb(BC)], dim=4)

            ACL, ACR = AC2.unbind(dim=1)
            BCL, BCR = BC2.unbind(dim=1)
            start = semiring.dot(BCL, ACR)
            # if k == 1:
            arcs[k] = stack(
                semiring.times(start, arc_scores[:, :, f[1], f[0]]),
                semiring.times(start, arc_scores[:, :, f[0], f[1]]),
            )
            arcsL, arcR = arcs[k].unbind(dim=1)
            # else:
            #     arcs[k] = stack(semiring.times(start),   #, arc_scores[:, f[1], f[0]]),
            #                     semiring.times(start)) #, arc_scores[:, f[0], f[1]]))

            AIR2 = torch.cat(
                [torch.narrow(AIR, 2, 0, N - k),
                 arcR.unsqueeze(-1)], dim=3)
            BIL2 = torch.cat(
                [arcsL.unsqueeze(-1),
                 torch.narrow(BIL, 2, 1, N - k)], dim=3)
            AC_next = stack(semiring.dot(ACL, BIL2), semiring.dot(AIR2, BCR))

            ends.append(AC_next[:, R, :, 0])
            AC = AC2
            BC = BC2
            AIR = AIR2
            BIL = BIL2
        v = torch.stack([ends[l][:, i] for i, l in enumerate(lengths)], dim=1)
        # v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)])
        return (semiring.unconvert(v), arcs[1:], alpha)
Example 14
 def narrow(self, dim, start, length):
     tensor = torch.narrow(self, dim, start, length)
     tensor.dtype = self.dtype
     return tensor
Example 15
def showtensor(tensor):
    x = torch.narrow(tensor, 0, 0, 1)
    plt.figure()
    plt.imshow(x.squeeze().numpy())
    plt.show()
Example 16
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, example_idx=None, extend=True):

        group_ids = None
        output_length = len(input_ids) if input_ids is not None else len(inputs_embeds)

        if extend and not self.training:
            # use the hook
            input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, example_idx, group_ids, group_sizes =\
                 self.extend_batch_examples_eval(
                            input_ids=input_ids, 
                            attention_mask=attention_mask, 
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            example_idx=example_idx)

        if not extend and input_ids is not None and inputs_embeds is not None:
            inputs_embeds = None

        result_logits = None
        for i in range(0, attention_mask.shape[0], self.hparams.batch_size):
            bs = min(self.hparams.batch_size, attention_mask.shape[0] - i)
            if input_ids is not None:
                batch_input_ids = input_ids.narrow(0, i, bs)
            else:
                batch_input_ids = None

            if inputs_embeds is not None:
                batch_inputs_embeds = inputs_embeds.narrow(0, i, bs)
            else:
                batch_inputs_embeds = None

            batch_attention_mask = attention_mask.narrow(0, i, bs)
            batch_token_type_ids = token_type_ids.narrow(0, i, bs)

            if position_ids is not None:
                batch_position_ids = position_ids.narrow(0, i, bs)
            else:
                batch_position_ids = None

            if head_mask is not None:
                batch_head_mask = head_mask.narrow(0, i, bs)
            else:
                batch_head_mask = None

            logits = LightningBertForSequenceClassification.forward(self, 
                            input_ids=batch_input_ids, 
                            attention_mask=batch_attention_mask, 
                            token_type_ids=batch_token_type_ids,
                            position_ids=batch_position_ids,
                            head_mask=batch_head_mask,
                            inputs_embeds=batch_inputs_embeds)

            if result_logits is None:
                result_logits = logits
            else:
                result_logits = torch.cat((result_logits, logits), dim=0)

        logits = result_logits

        #
        # time to vote:
        # make a zero tensor with one row per group, then use torch.narrow
        # to get a view of each group's row and accumulate into it in place
        #
        if group_ids is not None:
            # prepare a couple of output tensors of the right dimensions
            avg_logits = torch.zeros(output_length, self.num_labels).to(self.device)
            counted_logits = torch.zeros(output_length, self.num_labels).to(self.device)
            original_logits = logits[:output_length]

            # now go through the whole extended batch
            for i, (logit, group_id) in enumerate(zip(logits, group_ids)):

                # first, tally logits by averaging across replacement groups
                current_group_logits = torch.narrow(avg_logits, 0, group_id, 1)
                torch.add(current_group_logits, torch.div(logit, group_sizes[group_id]), out=current_group_logits)

                # but also, record the individual VOTES (argmax)
                current_vote = torch.argmax(logit).item()
                counted_logits[group_id, current_vote] += 1/group_sizes[group_id]

                # let us know what is happening here
                self._debug_print_vote(i, group_id, logit, example_idx)
    
            # print the results for this batch
            self._debug_print_votes(original_logits, avg_logits, counted_logits)

            if self.hparams.vote_avg_logits:
                logits = avg_logits
            else:
                logits = counted_logits


        #
        # return whatever we have at this point
        #
        return logits
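
The voting step in Example 16 relies on `torch.narrow` returning a view: `current_group_logits` aliases one row of `avg_logits`, so `torch.add(..., out=current_group_logits)` accumulates directly into the output tensor. A toy illustration of that in-place accumulation, with invented group ids and sizes:

import torch

num_groups, num_labels = 2, 3
avg_logits = torch.zeros(num_groups, num_labels)
logits = torch.tensor([[1., 2., 3.], [3., 2., 1.], [0., 0., 6.]])
group_ids = [0, 0, 1]             # assumed mapping: rows 0,1 -> group 0, row 2 -> group 1
group_sizes = {0: 2, 1: 1}        # assumed replacement-group sizes

for logit, group_id in zip(logits, group_ids):
    row = torch.narrow(avg_logits, 0, group_id, 1)       # view of one row
    torch.add(row, logit / group_sizes[group_id], out=row)

print(avg_logits)   # tensor([[2., 2., 2.], [0., 0., 6.]])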
Example 17
def diff(a, dim=0, out=None, func=torch.not_equal):
    sz = a.size(dim) - 1
    if out is None:
        out = torch.empty(sz, dtype=torch.bool, device=a.device)
    return func(torch.narrow(a, dim, 1, sz), torch.narrow(a, dim, 0, sz), out=out)
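
`diff` in Example 17 compares adjacent elements by taking two overlapping `torch.narrow` views of length `size - 1`, one starting at index 1 and one at index 0. A quick usage check, reusing the `diff` defined above:

import torch

a = torch.tensor([1, 1, 2, 2, 2, 3])
print(diff(a))            # tensor([False,  True, False, False,  True])

# equivalent to comparing the two shifted views directly:
sz = a.size(0) - 1
print(torch.not_equal(torch.narrow(a, 0, 1, sz), torch.narrow(a, 0, 0, sz)))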
Example 18
    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # for layer in self.layers:
        #     x, attn = layer(
        #         x,
        #         encoder_out['encoder_out'] if encoder_out is not None else None,
        #         encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
        #         incremental_state,
        #         self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
        #     )
        residual = x
        x = self.maybe_layer_norm(self.layer_norm_self_attn, x, before=True)
        ins = self.fc1(x)
        q1 = torch.narrow(ins, -1, 0, self.embed_dim)
        k1 = torch.narrow(ins, -1, self.embed_dim, self.embed_dim)
        v1 = torch.narrow(ins, -1, 2 * self.embed_dim, self.embed_dim)

        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            q=q1,
            k=k1,
            v=v1,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.layer_norm_self_attn, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.layer_norm_context_attn,
                                      x,
                                      before=True)
            q2 = self.fc3(x)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x, attn = self.encoder_attn(
                q=q2,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = self.fc4(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.layer_norm_context_attn,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.layer_norm_ffn, x, before=True)
        x = self.ffn(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.layer_norm_ffn, x, after=True)

        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn, self_attn_state
        return x, attn
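
Example 18 (and Examples 20 and 22 below) uses the same idiom: one fused linear projection produces the concatenated Q, K and V, and three `torch.narrow` calls slice them out of the last dimension without copying. In isolation, with assumed sizes; `torch.chunk` would give the same three views:

import torch
import torch.nn as nn

embed_dim = 8
fc1 = nn.Linear(embed_dim, 3 * embed_dim)     # fused QKV projection (assumed sizes)

x = torch.randn(5, 2, embed_dim)              # (seq_len, batch, embed_dim)
ins = fc1(x)                                  # (5, 2, 24)

q = torch.narrow(ins, -1, 0, embed_dim)
k = torch.narrow(ins, -1, embed_dim, embed_dim)
v = torch.narrow(ins, -1, 2 * embed_dim, embed_dim)
# q2, k2, v2 = ins.chunk(3, dim=-1) would give the same three views
print(q.shape, k.shape, v.shape)              # each (5, 2, 8)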
Example 19
    def backward(ctx, grad_output):
        bsz = grad_output.shape[0]
        grad_output_shape = grad_output.size()
        if grad_output_shape[1] % 16 != 0:
            grad_output = torch.cat((grad_output, torch.zeros((bsz, (16 - (grad_output_shape[1] % 16)), \
                                        grad_output_shape[2]), device=grad_output.device, dtype=grad_output.dtype)), 1)

        assert ctx.config.training

        if (ctx.config.pre_layer_norm and ctx.config.normalize_invertible):
            (input_mask, attn_qkvw, attn_qkvb, attn_ow, attn_ob, attn_nw,
             attn_nb, inter_w, inter_b, output_w, output_b, norm_w,
             norm_b) = ctx.saved_tensors
        else:
            (output, input, input_mask, attn_qkvw, attn_qkvb, attn_ow, attn_ob,
             attn_nw, attn_nb, inter_w, inter_b, output_w, output_b, norm_w,
             norm_b) = ctx.saved_tensors

        cuda_module = stochastic_transformer_cuda_module if ctx.config.stochastic_mode else transformer_cuda_module
        backward_func = cuda_module.backward_fp16 if ctx.config.fp16 else cuda_module.backward_fp32

        (grad_input, grad_attn_qkvw, grad_attn_qkvb, grad_attn_ow,
         grad_attn_ob, grad_attn_nw, grad_attn_nb, grad_inter_w, grad_inter_b,
         grad_output_w, grad_output_b,
         grad_norm_w, grad_norm_b) = backward_func(
             ctx.config.layer_id, grad_output,
             (ctx.inp_norm if
              (ctx.config.pre_layer_norm
               and ctx.config.normalize_invertible) else output),
             (ctx.inp_norm if
              (ctx.config.pre_layer_norm
               or not ctx.config.normalize_invertible) else input), ctx.qkv_tf,
             ctx.soft_inp, (ctx.soft_inp if ctx.config.attn_dropout_checkpoint
                            else ctx.ctx_bufB), ctx.attn_o_inp,
             (ctx.ff1_inp if ctx.config.normalize_invertible else ctx.add_res),
             ctx.ff1_inp,
             (ctx.ff2_inp if ctx.config.gelu_checkpoint else ctx.gelu_inp),
             ctx.ff2_inp, ctx.attn_prob_dropout_mask,
             ctx.attn_output_dropout_mask, ctx.layer_output_dropout_mask,
             ctx.attn_layer_norm_var, ctx.attn_layer_norm_mean,
             ctx.layer_norm_var, ctx.layer_norm_mean,
             (ctx.inp_norm if
              (ctx.config.pre_layer_norm
               and ctx.config.normalize_invertible) else input), input_mask,
             attn_qkvw, attn_qkvb, attn_ow, attn_ob, attn_nw, attn_nb, inter_w,
             inter_b, output_w, output_b, norm_w, norm_b)

        # This appears to be an effective way to release context memory
        ctx.qkv_tf = None
        ctx.soft_inp = None
        ctx.ctx_bufB = None
        ctx.gelu_inp = None
        ctx.ff2_inp = None
        ctx.attn_o_inp = None
        ctx.ff1_inp = None
        ctx.add_res = None
        ctx.inp_norm = None
        ctx.config = None
        ctx.attn_layer_norm_mean = None
        ctx.layer_norm_mean = None
        ctx.attn_prob_dropout_mask = None
        ctx.attn_output_dropout_mask = None
        ctx.layer_output_dropout_mask = None
        ctx.attn_layer_norm_var = None
        ctx.layer_norm_var = None

        if grad_output_shape[1] % 16 != 0:
            grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1])

        return (grad_input, None, None, None, None, grad_attn_qkvw,
                grad_attn_qkvb, grad_attn_ow, grad_attn_ob, grad_attn_nw,
                grad_attn_nb, grad_inter_w, grad_inter_b, grad_output_w,
                grad_output_b, grad_norm_w, grad_norm_b, None)
Example 20
    def forward(self, query, key, value, key_padding_mask=None):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Timesteps can be masked by supplying a T x T mask in the
        `attn_mask` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        ins = self.fc1(query)
        q = torch.narrow(ins, -1, 0, embed_dim)
        k = torch.narrow(ins, -1, embed_dim, embed_dim)
        v = torch.narrow(ins, -1, 2 * embed_dim, embed_dim)

        q = q * self.scaling
        q = q.contiguous().view(tgt_len, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        src_len = k.size(1)
        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size(
            []):
            key_padding_mask = None
        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            if self.onnx_trace:
                attn_weights = torch.where(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    torch.Tensor([float("-Inf")]),
                    attn_weights.float()).type_as(attn_weights)
            else:
                attn_weights = attn_weights.float().masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    float('-inf'),
                ).type_as(attn_weights)  # FP16 support: cast to float and back
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        attn_weights = utils.softmax(
            attn_weights,
            dim=-1,
            onnx_trace=self.onnx_trace,
        ).type_as(attn_weights)
        attn_weights = F.dropout(attn_weights,
                                 p=self.attention_dropout,
                                 training=self.training)
        attn = torch.bmm(attn_weights, v)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if (self.onnx_trace and attn.size(1) == 1):
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)

        attn = self.fc2(attn)
        attn_weights = None
        return attn, attn_weights
Example 21
    def _setup_for_real_optimizer(self):
        dp_world_size = dist.get_world_size(group=self.dp_process_group)
        self.partition_count = [
            dp_world_size for i in range(len(self.optimizer.param_groups))
        ]

        for i, param_group in enumerate(self.optimizer.param_groups):
            see_memory_usage(f'before initializing group {i}', force=True)

            partition_id = dist.get_rank(group=self.real_dp_process_group[i])

            # grab the original list
            self.bf16_groups.append(param_group['params'])

            # create flat bf16 params
            self.bf16_groups_flat.append(
                self._flatten_dense_tensors_aligned(
                    self.bf16_groups[i],
                    self.nccl_start_alignment_factor * dp_world_size))

            # Make bf16 params point to flat tensor storage
            self._update_storage_to_flattened_tensor(
                tensor_list=self.bf16_groups[i],
                flat_tensor=self.bf16_groups_flat[i])

            # divide flat weights into equal sized partitions
            partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
            bf16_dp_partitions = [
                self.bf16_groups_flat[i].narrow(0, dp_index * partition_size,
                                                partition_size)
                for dp_index in range(dp_world_size)
            ]
            self.bf16_partitioned_groups.append(bf16_dp_partitions)

            # create fp32 params partition
            self.fp32_groups_flat_partition.append(
                bf16_dp_partitions[partition_id].clone().float().detach())
            self.fp32_groups_flat_partition[i].requires_grad = True

            num_elem_list = [t.numel() for t in self.bf16_groups[i]]

            # create fp32 gradients
            self.fp32_groups_gradients_flat.append(
                torch.zeros_like(self.bf16_groups_flat[i],
                                 dtype=torch.float32))

            # track individual fp32 gradients for entire model
            fp32_gradients = self._split_flat_tensor(
                flat_tensor=self.fp32_groups_gradients_flat[i],
                num_elem_list=num_elem_list)
            self.fp32_groups_gradients.append(fp32_gradients)

            # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
            length_without_padding = sum(num_elem_list)
            self.fp32_groups_actual_gradients_flat.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0,
                             length_without_padding))

            # flat tensor corresponding to gradient partition
            self.fp32_groups_gradient_flat_partition.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0,
                             partition_id * partition_size, partition_size))

            # track fp32 gradient updates
            self.fp32_groups_has_gradients.append([False] *
                                                  len(self.bf16_groups[i]))

            # Record padding required for alignment
            if partition_id == dist.get_world_size(
                    group=self.real_dp_process_group[i]) - 1:
                padding = self.bf16_groups_flat[i].numel(
                ) - length_without_padding
            else:
                padding = 0

            self.group_paddings.append(padding)

            # update optimizer param groups to reference fp32 params partition
            param_group['params'] = [self.fp32_groups_flat_partition[i]]

            see_memory_usage(f'after initializing group {i}', force=True)

        see_memory_usage('before initialize_optimizer', force=True)
        self.initialize_optimizer_states()
        see_memory_usage('end initialize_optimizer', force=True)

        # Need optimizer states initialized before linking lp to optimizer state
        self._link_all_hp_params()
        self._param_slice_mappings = self._create_param_mapping()
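
In Example 21 each data-parallel rank carves its partition out of the flat bf16 parameter buffer with `Tensor.narrow`, so every partition is a view into the same storage. A reduced sketch of that partitioning, with an invented world size and buffer size:

import torch

dp_world_size = 4
flat = torch.zeros(20)                          # assumed flat parameter buffer (padded to a multiple of 4)
partition_size = flat.numel() // dp_world_size  # 5 elements per rank

partitions = [
    flat.narrow(0, dp_index * partition_size, partition_size)
    for dp_index in range(dp_world_size)
]

partitions[2].fill_(1.0)      # writing through a partition updates the flat buffer
print(flat)                   # elements 10..14 are now 1.0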
Example 22
    def forward(self,
                q,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                need_weights=True,
                static_kv=False,
                attn_mask=None):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Timesteps can be masked by supplying a T x T mask in the
        `attn_mask` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        qkv_same = False
        kv_same = True

        tgt_len, bsz, embed_dim = q.size()
        assert embed_dim == self.embed_dim
        assert list(q.size()) == [tgt_len, bsz, embed_dim]

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert kv_same and not qkv_same
                    key = value = None
        else:
            saved_state = None

        # encoder-decoder attention
        # q = self.in_proj_q(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            kv = self.kv_fc(key)
            k = torch.narrow(kv, -1, 0, self.embed_dim)
            v = torch.narrow(kv, -1, self.embed_dim, self.embed_dim)
        q *= self.scaling
        q = q.contiguous().view(tgt_len, bsz * self.num_heads,
                                self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads,
                                    self.head_dim).transpose(0, 1)
        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(
                    bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)
            self._set_input_buffer(incremental_state, saved_state)
        src_len = k.size(1)
        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size(
            []):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            if self.onnx_trace:
                attn_weights = torch.where(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    torch.Tensor([float("-Inf")]),
                    attn_weights.float()).type_as(attn_weights)
            else:
                attn_weights = attn_weights.float().masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    float('-inf'),
                ).type_as(attn_weights)  # FP16 support: cast to float and back
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        attn_weights = utils.softmax(
            attn_weights,
            dim=-1,
            onnx_trace=self.onnx_trace,
        ).type_as(attn_weights)
        attn_weights = F.dropout(attn_weights,
                                 p=self.dropout,
                                 training=self.training)

        attn = torch.bmm(attn_weights, v)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if (self.onnx_trace and attn.size(1) == 1):
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)

        if need_weights:
            # average attention weights over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            attn_weights = attn_weights.sum(dim=1) / self.num_heads
        else:
            attn_weights = None

        return attn, attn_weights
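
The kv_fc projection above packs key and value into a single tensor of width 2 * embed_dim, and the two torch.narrow calls carve out each half as views, without copying. A minimal standalone sketch of that split, assuming a hypothetical embed_dim of 8 (kv_fc here is just a plain nn.Linear standing in for the module's projection):

import torch
import torch.nn as nn

embed_dim = 8                          # hypothetical size for illustration
kv_fc = nn.Linear(embed_dim, 2 * embed_dim)

key = torch.randn(5, 2, embed_dim)     # (seq_len, batch, embed_dim)
kv = kv_fc(key)                        # last dim holds key and value side by side

# narrow(dim, start, length) returns a view of the underlying storage
k = torch.narrow(kv, -1, 0, embed_dim)
v = torch.narrow(kv, -1, embed_dim, embed_dim)

assert k.shape == v.shape == key.shape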
Ejemplo n.º 23
0
 def test_narrow(self):
     x = torch.randn(3, 3, requires_grad=True)
     self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)
Ejemplo n.º 24
0
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current CUDA process.
        local_shard_t: transposed row-wise sharded local weight used for the local matmul.
        bias: bias term of linear op.
        pg: process group.

    Returns: final result of linear operation.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size, input_t_size[1], device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input, input_t, input_split_sizes=input_split_sizes, group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
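
The heart of the local computation is the loop near the end: after the all-to-all, gathered_input holds world_size shards side by side along dim 1, and torch.narrow peels off each shard before the local matmul. A single-process sketch of just that slicing step, with made-up sizes and no process group or collective calls:

import torch

world_size, shard_size, out_features = 4, 3, 5
# stand-in for the transposed result of all_to_all_single
gathered_input = torch.randn(7, world_size * shard_size)
local_shard_t = torch.randn(shard_size, out_features)

results = []
for r in range(world_size):
    # view of columns [r * shard_size, (r + 1) * shard_size), no copy made
    inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
    results.append(inp.matmul(local_shard_t))

# in the real code these partial products are combined with reduce_scatter
assert all(t.shape == (7, out_features) for t in results)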
Ejemplo n.º 25
0
Archivo: tensor.py Proyecto: yifeim/dgl
def slice_axis(data, axis, begin, end):
    return th.narrow(data, axis, begin, end - begin)
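
This adapter maps a begin/end slice, as DGL's backend API expresses it, onto narrow's start/length form: the length is simply end - begin. A quick check of the equivalence on an arbitrary tensor (assuming the slice_axis defined above is in scope):

import torch as th

data = th.arange(24).reshape(4, 6)
# rows 1..2 (end exclusive): narrow(dim=0, start=1, length=2)
assert th.equal(slice_axis(data, 0, 1, 3), data[1:3])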
Ejemplo n.º 26
0
 def backward(ctx, grad_output):
     slice_size = grad_output.size(ctx.dim) // ctx.world_size
     return torch.narrow(grad_output.clone(), ctx.dim,
                         ctx.ordinal * slice_size, slice_size), None
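
This backward hands each participant only its piece of the incoming gradient: the gradient is divided along ctx.dim into world_size equal slices and ctx.ordinal picks which one. A standalone sketch of that arithmetic, with hypothetical values standing in for the ctx attributes:

import torch

world_size, ordinal, dim = 4, 2, 0
grad_output = torch.randn(8, 3)      # size along dim 0 divisible by world_size

slice_size = grad_output.size(dim) // world_size
grad_slice = torch.narrow(grad_output.clone(), dim,
                          ordinal * slice_size, slice_size)
assert grad_slice.shape == (2, 3)    # this rank keeps rows 4 and 5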
Ejemplo n.º 27
0
    def forward(self, adv_patch, lab_batch, img_size, do_rotate=True, rand_loc=True):
        #adv_patch = F.conv2d(adv_patch.unsqueeze(0),self.kernel,padding=(2,2))
        adv_patch = self.medianpooler(adv_patch.unsqueeze(0))
        #print('lab_batch---------------------------: ',lab_batch)
        # Determine size of padding
        pad = (img_size - adv_patch.size(-1)) / 2
        # Make a batch of patches
        adv_patch = adv_patch.unsqueeze(0)#.unsqueeze(0)
        adv_batch = adv_patch.expand(lab_batch.size(0), lab_batch.size(1), -1, -1, -1)
        batch_size = torch.Size((lab_batch.size(0), lab_batch.size(1)))
        #print('--========+++++======---adv_patch/adv_batch-----',adv_patch.shape,adv_batch.shape)
        #torch.Size([1, 1, 3, 300, 300]) torch.Size([8, 14, 3, 300, 300])
        # Contrast, brightness and noise transforms
        
        # Create random contrast tensor
        contrast = torch.cuda.FloatTensor(batch_size).uniform_(self.min_contrast, self.max_contrast)
        contrast = contrast.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        contrast = contrast.expand(-1, -1, adv_batch.size(-3), adv_batch.size(-2), adv_batch.size(-1))
        contrast = contrast.cuda()


        # Create random brightness tensor
        brightness = torch.cuda.FloatTensor(batch_size).uniform_(self.min_brightness, self.max_brightness)
        brightness = brightness.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        brightness = brightness.expand(-1, -1, adv_batch.size(-3), adv_batch.size(-2), adv_batch.size(-1))
        brightness = brightness.cuda()


        # Create random noise tensor
        noise = torch.cuda.FloatTensor(adv_batch.size()).uniform_(-1, 1) * self.noise_factor


        # Apply contrast/brightness/noise, clamp
        adv_batch = adv_batch * contrast + brightness + noise

        adv_batch = torch.clamp(adv_batch, 0.000001, 0.99999)

        # The max number of labels (maxlab) is manually fixed at 14; rows that fall short are padded entirely with 1s, which is how padding is told apart here
        # Where the label class_id is 1 we don't want a patch (padding) --> fill mask with zeros
        # class_id == 1 marks the artificially padded rows
        # lab_batch.size() == torch.Size([8, 14, 5]): batch=8, maxlab=14, and 5 = class label + box coordinates
        cls_ids = torch.narrow(lab_batch, 2, 0, 1) # take length 1 at index 0 along dim 2 (the size-5 dim), i.e. the class id
        cls_mask = cls_ids.expand(-1, -1, 3)
        cls_mask = cls_mask.unsqueeze(-1)
        cls_mask = cls_mask.expand(-1, -1, -1, adv_batch.size(3))
        cls_mask = cls_mask.unsqueeze(-1)
        cls_mask = cls_mask.expand(-1, -1, -1, -1, adv_batch.size(4))
        msk_batch = torch.cuda.FloatTensor(cls_mask.size()).fill_(1) - cls_mask
        #print('++++++++++====msk_batch=========',msk_batch.shape,msk_batch)#torch.Size([8, 14, 3, 300, 300])
        # Pad patch and mask to image dimensions
        mypad = nn.ConstantPad2d((int(pad + 0.5), int(pad), int(pad + 0.5), int(pad)), 0)
        # pad left/right/top/bottom with the given numbers of zeros so the padded patch matches the original image size
        # the padded values are 0 and get filtered out when the patch and the image are merged
        # the patch and the (expanded) label-id mask have the same size; multiplying them filters the patch by class id
        adv_batch = mypad(adv_batch)
        msk_batch = mypad(msk_batch)


        # Rotation and rescaling transforms: place the patch according to the size and orientation of the real labels
        anglesize = (lab_batch.size(0) * lab_batch.size(1))
        if do_rotate:
            angle = torch.cuda.FloatTensor(anglesize).uniform_(self.minangle, self.maxangle)
        else: 
            angle = torch.cuda.FloatTensor(anglesize).fill_(0)

        # Resizes and rotates
        current_patch_size = adv_patch.size(-1)
        lab_batch_scaled = torch.cuda.FloatTensor(lab_batch.size()).fill_(0)
        # recover the real annotated box size (x, y, w, h) from the label coordinates
        lab_batch_scaled[:, :, 1] = lab_batch[:, :, 1] * img_size
        lab_batch_scaled[:, :, 2] = lab_batch[:, :, 2] * img_size
        lab_batch_scaled[:, :, 3] = lab_batch[:, :, 3] * img_size
        lab_batch_scaled[:, :, 4] = lab_batch[:, :, 4] * img_size
        # target patch size
        target_size = torch.sqrt(((lab_batch_scaled[:, :, 3].mul(0.2)) ** 2) + ((lab_batch_scaled[:, :, 4].mul(0.2)) ** 2))
        target_x = lab_batch[:, :, 1].view(np.prod(batch_size))
        target_y = lab_batch[:, :, 2].view(np.prod(batch_size))
        targetoff_x = lab_batch[:, :, 3].view(np.prod(batch_size))
        targetoff_y = lab_batch[:, :, 4].view(np.prod(batch_size))
        if(rand_loc):
            off_x = targetoff_x*(torch.cuda.FloatTensor(targetoff_x.size()).uniform_(-0.4,0.4))
            target_x = target_x + off_x
            off_y = targetoff_y*(torch.cuda.FloatTensor(targetoff_y.size()).uniform_(-0.4,0.4))
            target_y = target_y + off_y
        target_y = target_y - 0.05
        scale = target_size / current_patch_size
        scale = scale.view(anglesize)

        s = adv_batch.size()
        adv_batch = adv_batch.view(s[0] * s[1], s[2], s[3], s[4])
        msk_batch = msk_batch.view(s[0] * s[1], s[2], s[3], s[4])


        tx = (-target_x+0.5)*2
        ty = (-target_y+0.5)*2
        sin = torch.sin(angle)
        cos = torch.cos(angle)        

        # Theta = rotation,rescale matrix
        theta = torch.cuda.FloatTensor(anglesize, 2, 3).fill_(0)
        theta[:, 0, 0] = cos/scale
        theta[:, 0, 1] = sin/scale
        theta[:, 0, 2] = tx*cos/scale+ty*sin/scale
        theta[:, 1, 0] = -sin/scale
        theta[:, 1, 1] = cos/scale
        theta[:, 1, 2] = -tx*sin/scale+ty*cos/scale

        b_sh = adv_batch.shape
        # affine transform: apply the rotation, translation and scaling; the output matches the image size (416)
        grid = F.affine_grid(theta, adv_batch.shape)

        adv_batch_t = F.grid_sample(adv_batch, grid)
        msk_batch_t = F.grid_sample(msk_batch, grid)
        #print('-_______________adv_batch_t/mas________-----',adv_batch_t.shape,msk_batch_t.shape)
        #torch.Size([112, 3, 416, 416]) torch.Size([112, 3, 416, 416])
        '''
        # Theta2 = translation matrix
        theta2 = torch.cuda.FloatTensor(anglesize, 2, 3).fill_(0)
        theta2[:, 0, 0] = 1
        theta2[:, 0, 1] = 0
        theta2[:, 0, 2] = (-target_x + 0.5) * 2
        theta2[:, 1, 0] = 0
        theta2[:, 1, 1] = 1
        theta2[:, 1, 2] = (-target_y + 0.5) * 2

        grid2 = F.affine_grid(theta2, adv_batch.shape)
        adv_batch_t = F.grid_sample(adv_batch_t, grid2)
        msk_batch_t = F.grid_sample(msk_batch_t, grid2)

        '''
        adv_batch_t = adv_batch_t.view(s[0], s[1], s[2], s[3], s[4])
        msk_batch_t = msk_batch_t.view(s[0], s[1], s[2], s[3], s[4])

        adv_batch_t = torch.clamp(adv_batch_t, 0.000001, 0.999999)
        #img = msk_batch_t[0, 0, :, :, :].detach().cpu()
        #img = transforms.ToPILImage()(img)
        #img.show()
        #exit()

        return adv_batch_t * msk_batch_t
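
The mask construction above hinges on one narrow call: taking length 1 at index 0 along the last label dimension isolates the class-id column, which is then expanded up to the patch shape and inverted. A CPU-only sketch of that step with a toy lab_batch (batch=2, maxlab=3, 5 values per label), independent of the CUDA tensors used above:

import torch

lab_batch = torch.tensor([[[1., .5, .5, .2, .2],   # class_id 1 = padded row
                           [0., .3, .3, .1, .1],
                           [1., .0, .0, .0, .0]],
                          [[0., .6, .4, .2, .3],
                           [1., .0, .0, .0, .0],
                           [1., .0, .0, .0, .0]]])

cls_ids = torch.narrow(lab_batch, 2, 0, 1)   # (2, 3, 1): just the class ids
msk = 1.0 - cls_ids                          # 1 where a real object sits, 0 for padding
print(msk.squeeze(-1))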
Ejemplo n.º 28
0
def d_slice(dist, i, j, length):
    return torch.narrow(torch.narrow(dist, 0, i, length), 1, j, length)
Ejemplo n.º 29
0
    def _forward_biobert(
        self, tokens: List[List[str]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return BioBERT Hidden state for the tokenized documents.
        Documents with different lengths will be accepted.

        list(list(str)) -> tuple(torch.tensor, torch.tensor)
        """
        # Convert each token of each document into a list of subwords.
        # e.g.,
        #   [['Admission', 'Date', ...], ['Service', ':', ...]]
        #       |
        #       V
        #   [[['Ad', '##mission'], ['Date'], ...], [['Service'], [':'], ...]]
        subwords_unchained = [
            [self.tokenizer.tokenize(tok) for tok in doc] for doc in tokens
        ]

        # Simply replace each token of each document with corresponding subwords.
        # e.g.,
        #   [['Admission', 'Date', ...], ['Service', ':', ...]]
        #       |
        #       V
        #   [['Ad', '##mission', 'Date', ...], ['Service', ':', ...]]
        subwords = [
            list(itertools.chain(*[self.tokenizer.tokenize(tok) for tok in doc]))
            for doc in tokens
        ]

        # Record (i) the starting position of each token and (ii) how many subwords each token produced.
        # e.g.,
        #   For document ['Admission', 'Date'] -> ['Ad', '##mission', 'Date'],
        #   subword_info will be {'start':[0,2], 'length':[2,1]}.
        subword_info = []
        for doc in subwords_unchained:
            word_lengths = [len(word) for word in doc]
            word_head_ix = [0]
            for i in range(len(word_lengths) - 1):
                word_head_ix.append(word_head_ix[-1] + word_lengths[i])
            assert len(word_lengths) == len(word_head_ix)
            subword_info.append({"start": word_head_ix, "length": word_lengths})

        assert [len(info["start"]) for info in subword_info] == [
            len(doc) for doc in tokens
        ]

        # Split each document into chunks no longer than max_length.
        # Here, each document is simply split every (max_length - 2) subwords (510 under BERT's 512-token limit).

        max_length = min(
            self.bertconfig.max_position_embeddings, self.hparams.max_length
        )

        longest_length = max([len(doc) for doc in subwords])
        n_chunks = (longest_length - 1) // (max_length - 2) + 1
        chunks = []
        for n in range(n_chunks):
            chunk_of_all_documents = []
            for document in subwords:
                chunk_of_single_document = document[
                    (max_length - 2) * n : (max_length - 2) * (n + 1)
                ]
                if chunk_of_single_document == []:
                    chunk_of_all_documents.append([""])
                else:
                    chunk_of_all_documents.append(chunk_of_single_document)
            chunks.append(chunk_of_all_documents)

        # Convert chunks into BERT input form.
        inputs = []
        for chunk in chunks:
            if type(chunk) is str:
                unsqueezed_chunk = [[chunk]]
            elif type(chunk) is list:
                if type(chunk[0]) is str:
                    unsqueezed_chunk = [chunk]
                elif type(chunk[0]) is list:
                    unsqueezed_chunk = chunk

            inputs.append(
                self.tokenizer.batch_encode_plus(
                    unsqueezed_chunk,
                    pad_to_max_length=True,
                    is_pretokenized=True,
                )
            )

        # Get BioBERT hidden states.
        hidden_states = []
        for inpt in inputs:
            inpt_tensors = {
                k: torch.tensor(v).to(self.get_device()) for k, v in inpt.items()
            }
            hidden_state = self.biobert(**inpt_tensors)[0][:, 1:-1, :]
            hidden_states.append(hidden_state)

        # Concatenate hidden states from each chunk.
        hidden_states_cat = torch.cat(hidden_states, dim=1)

        # If a word was tokenized into multiple subwords, take average of them.
        # e.g. Hidden state for "Admission" equals average of hidden states for "Ad" and "##mission"
        hidden_states_shrunk = torch.zeros_like(hidden_states_cat)
        for n in range(hidden_states_cat.size()[0]):
            hidden_state_shrunk = torch.stack(
                [
                    torch.narrow(hidden_states_cat[n], dim=0, start=s, length=l).mean(
                        dim=0
                    )
                    for s, l in zip(subword_info[n]["start"], subword_info[n]["length"])
                ]
            )
            hidden_states_shrunk[
                n, : hidden_state_shrunk.size()[0], :
            ] = hidden_state_shrunk

        # Truncate lengthy tail that will not be used.
        hidden_states_shrunk = hidden_states_shrunk[
            :, : max([len(doc) for doc in tokens]), :
        ]

        # Create mask for CRF.
        crf_mask = torch.zeros(hidden_states_shrunk.size()[:2]).to(torch.uint8)
        for i, length in enumerate([len(doc) for doc in tokens]):
            crf_mask[i, :length] = 1
        crf_mask = crf_mask > 0
        crf_mask = crf_mask.to(self.get_device())

        return (hidden_states_shrunk, crf_mask)
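
The narrow call inside the stack/mean comprehension is what re-aligns subword vectors with the original words: for each word, start and length locate its subwords in the concatenated hidden states, and their mean becomes the word-level vector. A small self-contained sketch of that re-alignment, using random hidden states and a made-up subword_info entry:

import torch

hidden = torch.randn(5, 768)   # 5 subwords, e.g. ['Ad', '##mission', 'Date', ':', '2020']
info = {"start": [0, 2, 3, 4], "length": [2, 1, 1, 1]}   # 'Admission' spans subwords 0-1

word_level = torch.stack([
    torch.narrow(hidden, dim=0, start=s, length=l).mean(dim=0)
    for s, l in zip(info["start"], info["length"])
])
assert word_level.shape == (4, 768)   # one vector per original word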
Ejemplo n.º 30
0
    def extract_feature_matrix(self):

        # define generator
        generator = self.generator(self.paths)

        # load extractor
        extractor = self.load_vgg19(self.layer)

        # initialize sketch and label matrices
        features = []
        paths = []
        n = 0
        quit = False

        # generate batches of sketches and labels
        if generator:
            while True:
                batch_size = self.batch_size
                img_batch = torch.zeros(batch_size, 3, self.imsize,
                                        self.imsize)
                paths_batch = []
                if self.use_cuda:
                    img_batch = img_batch.to(self.cuda_device)

                if (n + 1) % 5 == 0:
                    print('Batch {}'.format(n + 1))

                for b in range(batch_size):
                    try:
                        img, path = next(generator)
                        img_batch[b] = img
                        paths_batch.append(path)

                    except StopIteration:
                        quit = True
                        print('stopped!')
                        break

                if n == self.num_images // self.batch_size:
                    print('b', b)
                    print(img_batch.size())
                    img_batch = torch.narrow(img_batch, 0, 0, b)
                    print(img_batch.size())
                    paths_batch = paths_batch[:b + 1]

                # extract features from batch
                n += 1
                feats_batch = extractor(img_batch)
                feats_batch = [feat.cpu().data.numpy() for feat in feats_batch]
                feats_batch = np.squeeze(np.array(feats_batch), axis=0)
                #                feats_batch = feats_batch.cpu().data.numpy()
                #                print('features shape', features.shape)

                if len(features) == 0:
                    features = feats_batch
                else:
                    features = np.vstack((features, feats_batch))

                paths.append(paths_batch)
                if n == self.num_images // batch_size + 1:
                    break

        return features, paths
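
On the final pass the generator can run dry partway through a batch, leaving the trailing rows of img_batch as zeros; the narrow call keeps only the b rows that were actually filled. A sketch of that trimming step, with invented sizes in place of the class attributes:

import torch

batch_size, imsize = 4, 32
img_batch = torch.zeros(batch_size, 3, imsize, imsize)

b = 0
for b in range(batch_size):
    if b == 3:                 # pretend the generator raised StopIteration here
        break
    img_batch[b] = torch.randn(3, imsize, imsize)

img_batch = torch.narrow(img_batch, 0, 0, b)   # keep only the filled rows
assert img_batch.shape == (3, 3, imsize, imsize)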