def sample_sequence(model, context, num_samples=1, temperature=1, top_k=0, top_p=0.0,
                    repetition_penalty=1.0, device='cpu', tokenizer=None):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    generated = context.clone().detach()
    prev_generated = generated
    past = None
    with torch.no_grad():
        while True:
            output, past = model(context, past=past)
            next_token_logits = output[:, -1, :] / (temperature if temperature > 0 else 1.)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for i in range(num_samples):
                for _ in set(generated[i].tolist()):
                    next_token_logits[i, _] /= repetition_penalty

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:  # greedy sampling:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            context = next_token
            generated = torch.cat((generated, next_token), dim=1)

            eos = False
            for o in next_token.tolist():
                text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
                print(text, end="", flush=True)
                if '.' in text:
                    eos = True

            while eos:
                print()
                raw_text = input('> ')
                if raw_text == 'quit':
                    return
                if raw_text == 'revert':
                    generated = prev_generated
                    context = generated
                    past = None
                    continue
                prev_generated = generated
                eos = False
                if raw_text != '':
                    next_input = tokenizer.encode(' ' + raw_text, add_special_tokens=False)
                    next_input = torch.tensor(next_input, dtype=torch.long, device=device)
                    next_input = next_input.unsqueeze(0).repeat(num_samples, 1)
                    generated = torch.cat((generated, next_input), dim=1)
                    context = generated
                    past = None

            if past and past[0].size()[3] > MAX_PAST:
                past = None
                context_len = MAX_PAST - BUFFER_SIZE
                context_start = generated.size()[1] - context_len
                context = torch.narrow(generated, 1, context_start, context_len)
def forward(self, x):
    a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
    b = torch.narrow(a, 0, 0, 1)
    return b + x
def forward(self, input):
    return torch.narrow(input, 0, 0, 2)
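A minimal, self-contained sketch (not part of the snippets above) showing what the forward() above returns: torch.narrow(input, 0, 0, 2) is a view of the first two rows that shares storage with the original tensor.

# Illustrative only: demonstrates narrow's view semantics with made-up data.
import torch

x = torch.arange(12.).reshape(4, 3)
y = torch.narrow(x, 0, 0, 2)          # rows 0..1, shape (2, 3)
assert y.shape == (2, 3)
assert y.data_ptr() == x.data_ptr()   # a view, no copy is made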
def seg(self, inputs: List[str]):
    tokenizerd = self.tokenizer.batch_encode_plus(inputs, return_tensors='pt', padding=True)
    input_ids = tokenizerd['input_ids'].to(self.device)
    attention_mask = tokenizerd['attention_mask'].to(self.device)
    token_type_ids = tokenizerd['token_type_ids'].to(self.device)
    length = torch.sum(attention_mask, dim=-1) - 2

    pretrained_output, *_ = self.model.pretrained(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids)

    # remove [CLS] [SEP]
    word_cls = pretrained_output[:, :1]
    char_input = torch.narrow(pretrained_output, 1, 1, pretrained_output.size(1) - 2)

    segment_output = torch.argmax(self.model.seg_decoder(char_input), dim=-1).cpu().numpy()
    segment_output = self._convert_idx_to_name(segment_output, length, self.seg_vocab)

    # todo: performance -- maybe cython / c++ / rust
    sentences = []
    word_idx = []
    word_length = []

    for source_text, encoding, sentence_seg_tag in zip(inputs, tokenizerd.encodings, segment_output):
        text = [
            source_text[start:end] for start, end in encoding.offsets[1:-1] if end != 0
        ]

        last_word = 0
        for idx, word in enumerate(encoding.words[1:-1]):
            if word is None or is_chinese_char(text[idx][-1]):
                continue
            if word != last_word:
                text[idx] = ' ' + text[idx]
                last_word = word
            else:
                sentence_seg_tag[idx] = WORD_MIDDLE

        entities = get_entities(sentence_seg_tag)
        word_length.append(len(entities))
        sentences.append([
            ''.join(text[entity[1]:entity[2] + 1]).strip() for entity in entities
        ])
        word_idx.append(
            torch.as_tensor([entity[1] for entity in entities], device=self.device))

    word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
    word_idx = word_idx.unsqueeze(-1).expand(-1, -1, char_input.shape[-1])
    word_input = torch.gather(char_input, dim=1, index=word_idx)

    word_cls_input = torch.cat([word_cls, word_input], dim=1)
    word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
    word_cls_mask[:, 0] = False  # ignore the first token of each sentence
    return sentences, {
        'word_cls': word_cls,
        'word_input': word_input,
        'word_length': word_length,
        'word_cls_input': word_cls_input,
        'word_cls_mask': word_cls_mask
    }
def forward(self, p_vects, q_vects, p_frames_mask, q_frames_mask, num_phones_mask):
    '''
    p/q_vects = [num_speakers X num_feats X max_num_mfcc_frames x mfcc_dim]
    p/q_lengths = [num_speakers X num_feats]
        -> stores the number of observed frames associated with the corresponding phone
    p/q_frames_mask = [num_speakers X num_feats X max_num_mfcc_frames x mfcc_dim]
        -> The associated 0s and 1s mask of p/q_lengths
    num_phones_mask = [num_speakers X num_feats], with a 0 corresponding to a position
        that should be -1 (no phones observed) and a 1 everywhere else.

    n.b. mfcc_dim = 13 usually (using c0 for energy instead of log-energy)
         num_feats = 46*47*0.5 = 1128 usually
         max_num_mfcc_frames = the maximum number of frames associated with a particular
             phone for any speaker -> often set to 4000
    '''
    # Apply the attack
    noise = torch.exp(self.noise_root)

    # Need to add spectral noise
    # Pad to spectral dimension
    padding = torch.zeros(p_vects.size(0), p_vects.size(1), p_vects.size(2),
                          self.spectral_dim - self.mfcc_dim).to(self.device)
    padded_p_vects = torch.cat((p_vects, padding), 3)
    padded_q_vects = torch.cat((q_vects, padding), 3)

    # Apply inverse dct
    log_spectral_p = dct.idct(padded_p_vects)
    log_spectral_q = dct.idct(padded_q_vects)

    # Apply inverse log
    spectral_p = torch.exp(log_spectral_p)
    spectral_q = torch.exp(log_spectral_q)

    # Add the adversarial attack noise
    attacked_spectral_p = spectral_p + noise
    attacked_spectral_q = spectral_q + noise

    # Apply the log
    attacked_log_spectral_p = torch.log(attacked_spectral_p)
    attacked_log_spectral_q = torch.log(attacked_spectral_q)

    # Apply the dct
    attacked_padded_p = dct.dct(attacked_log_spectral_p)
    attacked_padded_q = dct.dct(attacked_log_spectral_q)

    # Truncate to mfcc dimension
    p_vects_attacked = torch.narrow(attacked_padded_p, 3, 0, self.mfcc_dim)
    q_vects_attacked = torch.narrow(attacked_padded_q, 3, 0, self.mfcc_dim)

    # Apply mask of zeros/ones, to ensure spectral noise only applied up to p/q lengths
    p_vects_masked = p_vects_attacked * p_frames_mask
    q_vects_masked = q_vects_attacked * q_frames_mask

    # Compute the p/q_means tensor and covariance tensor
    p_means, p_covariances, q_means, q_covariances = self.get_pq_means_covs(
        p_vects_masked, q_vects_masked, p_frames_mask, q_frames_mask, num_phones_mask)

    # add small noise to all covariance matrices to ensure they are non-singular
    p_covariances_noised = p_covariances + (1e-2 * torch.eye(13).to(self.device))
    q_covariances_noised = q_covariances + (1e-2 * torch.eye(13).to(self.device))
    # print(p_covariances_noised[0,3,:,:])
    # print(q_covariances_noised[1,4,:,:])

    # Pass through trained model
    trained_model = torch.load(self.trained_model_path)
    trained_model.to(self.device)
    trained_model.eval()

    y = trained_model(p_means, p_covariances_noised, q_means, q_covariances_noised,
                      num_phones_mask)
    return y
def forward(ctx, input, input_mask, self, grads, layer_id, attn_qkvw, attn_qkvb,
            attn_ow, attn_ob, attn_nw, attn_nb, inter_w, inter_b, output_w, output_b,
            norm_w, norm_b, config):
    cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module
    forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32

    inp_size = input.size()
    if inp_size[1] % 16 != 0:
        input = torch.cat(
            (input,
             torch.randn((inp_size[0], (16 - (inp_size[1] % 16)), inp_size[2]),
                         device=input.device,
                         dtype=input.dtype)), 1)
        input_mask = torch.cat(
            (input_mask,
             torch.ones((inp_size[0], input_mask.shape[1], input_mask.shape[2],
                         (16 - (inp_size[1] % 16))),
                        device=input_mask.device,
                        dtype=input_mask.dtype) * -10000), 3)

    (output, inp_norm, qkv_tf, soft_inp, ctx_bufB, attn_o_inp, add_res, ff1_inp,
     gelu_inp, ff2_inp, attn_prob_dropout_mask, attn_output_dropout_mask,
     layer_output_dropout_mask, attn_layer_norm_var, attn_layer_norm_mean,
     layer_norm_var, layer_norm_mean) = forward_func(
         config.layer_id, input, input_mask, attn_qkvw, attn_qkvb, attn_ow, attn_ob,
         attn_nw, attn_nb, inter_w, inter_b, output_w, output_b, norm_w, norm_b,
         config.training, config.pre_layer_norm, config.attn_dropout_checkpoint,
         config.normalize_invertible, config.gelu_checkpoint)

    # For testing only.
    if grads is not None:
        for i in [2]:
            attn_qkvw.register_hook(
                lambda x, i=i, self=self: grads.append([
                    x[i * attn_ow.size(0):(i + 1) * attn_ow.size(0)],
                    ("Q_W" if i == 0 else "K_W" if i == 1 else "V_W")
                ]))
        for i in [2]:
            attn_qkvb.register_hook(
                lambda x, i=i, self=self: grads.append([
                    x[i * attn_ow.size(0):(i + 1) * attn_ow.size(0)],
                    ("Q_B" if i == 0 else "K_B" if i == 1 else "V_B")
                ]))

        attn_ow.register_hook(lambda x, self=self: grads.append([x, "O_W"]))
        attn_ob.register_hook(lambda x, self=self: grads.append([x, "O_B"]))
        attn_nw.register_hook(lambda x, self=self: grads.append([x, "N2_W"]))
        attn_nb.register_hook(lambda x, self=self: grads.append([x, "N2_B"]))
        inter_w.register_hook(lambda x, self=self: grads.append([x, "int_W"]))
        inter_b.register_hook(lambda x, self=self: grads.append([x, "int_B"]))
        output_w.register_hook(lambda x, self=self: grads.append([x, "out_W"]))
        output_b.register_hook(lambda x, self=self: grads.append([x, "out_B"]))
        norm_w.register_hook(lambda x, self=self: grads.append([x, "norm_W"]))
        norm_b.register_hook(lambda x, self=self: grads.append([x, "norm_B"]))

    if config.is_grad_enabled and config.training:
        if (config.pre_layer_norm and config.normalize_invertible):
            ctx.save_for_backward(input_mask, attn_qkvw, attn_qkvb, attn_ow, attn_ob,
                                  attn_nw, attn_nb, inter_w, inter_b, output_w,
                                  output_b, norm_w, norm_b)
        else:
            ctx.save_for_backward(output, input, input_mask, attn_qkvw, attn_qkvb,
                                  attn_ow, attn_ob, attn_nw, attn_nb, inter_w, inter_b,
                                  output_w, output_b, norm_w, norm_b)

        ctx.config = config
        if (config.pre_layer_norm or not config.normalize_invertible):
            ctx.inp_norm = inp_norm

        ctx.qkv_tf = qkv_tf
        ctx.soft_inp = soft_inp
        if not config.attn_dropout_checkpoint:
            ctx.ctx_bufB = ctx_bufB

        ctx.attn_o_inp = attn_o_inp
        if not config.normalize_invertible:
            ctx.add_res = add_res

        ctx.attn_layer_norm_mean = attn_layer_norm_mean
        ctx.layer_norm_mean = layer_norm_mean

        ctx.ff1_inp = ff1_inp
        if not config.gelu_checkpoint:
            ctx.gelu_inp = gelu_inp

        ctx.ff2_inp = ff2_inp
        ctx.attn_prob_dropout_mask = attn_prob_dropout_mask
        ctx.attn_output_dropout_mask = attn_output_dropout_mask
        ctx.layer_output_dropout_mask = layer_output_dropout_mask
        ctx.attn_layer_norm_var = attn_layer_norm_var
        ctx.layer_norm_var = layer_norm_var

    if inp_size[1] % 16 != 0:
        output = torch.narrow(output, 1, 0, inp_size[1])

    if config.huggingface:
        return (output, )  # outputs -> (output) : outputs[0] = output
    else:
        return output
def forward(self, input_volume, last_s=None, input_action=None, input_motion=None,
            next_mask=False, no_warp=False):
    B, _, S1, S2, S3 = input_volume.size()
    K = self.K
    device = input_volume.device
    output = {}

    input = torch.cat(
        (input_volume, self.coord_feature.expand(B, -1, -1, -1, -1).to(device)), dim=1)
    input = torch.cat((input, last_s), dim=1)  # aggregate history

    volume_embedding, cache = self.volume_encoder(input)
    mask_feature = self.feature_decoder(volume_embedding, cache)

    if self.motion_type == 'conv':
        motion = self.motion_decoder(mask_feature, input_action)
        output['motion'] = motion
        return output

    assert (self.motion_type == 'se3')
    logit, mask = self.mask_decoder(mask_feature)
    output['init_logit'] = logit

    transform_param = self.transform_decoder(mask_feature, input_action)

    # trans, pivot: [B, K-1, 3]
    # rot_matrix: [B, K-1, 3, 3]
    trans_vec, rot_mat = self.se3(transform_param)
    mask_object = torch.narrow(mask, 1, 0, K - 1)
    sum_mask = torch.sum(mask_object, dim=(2, 3, 4))
    heatmap = torch.unsqueeze(mask_object, dim=2) * self.grids.to(device)
    pivot_vec = torch.sum(heatmap, dim=(3, 4, 5)) / torch.unsqueeze(sum_mask, dim=2)

    # [Important] The last one is the background!
    trans_vec = torch.cat(
        [trans_vec, self.zero_vec.expand(B, -1, -1).to(device)], dim=1).unsqueeze(-1)
    rot_mat = torch.cat(
        [rot_mat, self.eye_mat.expand(B, 1, -1, -1).to(device)], dim=1)
    pivot_vec = torch.cat(
        [pivot_vec, self.zero_vec.expand(B, -1, -1).to(device)], dim=1).unsqueeze(-1)

    grids_flat = self.grids_flat.to(device)
    grids_after_flat = rot_mat @ (grids_flat - pivot_vec) + pivot_vec + trans_vec
    motion = (grids_after_flat - grids_flat).view([B, K, 3, S1, S2, S3])
    motion = torch.sum(motion * torch.unsqueeze(mask, 2), 1)
    output['motion'] = motion

    if no_warp:
        output['s'] = mask_feature
    elif input_motion is not None:
        mask_feature_warp = self.forward_warp(
            mask_feature, input_motion, torch.sum(mask[:, :-1, ], dim=1))
        output['s'] = mask_feature_warp
    else:
        mask_feature_warp = self.forward_warp(
            mask_feature, motion, torch.sum(mask[:, :-1, ], dim=1))
        output['s'] = mask_feature_warp

    if next_mask:
        mask_warp = self.forward_warp(mask, motion, torch.sum(mask[:, :-1, ], dim=1))
        output['next_mask'] = mask_warp

    return output
def one_label_loss(gt_percent, predict, moe, batch_node_num):
    """Proposed Loss Function.

    Calculates the cost of a training batch from the GCN's output graphs and
    weak image-level annotations. For more information, please refer to our paper.

    Keyword arguments:
    gt_percent     -- Ground-Truth percent, a weak image-level annotation
    predict        -- GCN module output, gradient required
    moe            -- Margin of Error, a weak image-level annotation
    batch_node_num -- integer list of node numbers per image in batch
    """
    curr_index = 0
    batch_top_k_loss = []
    batch_bottom_k_loss = []
    batch_pairwise_loss = []
    positive_num = 0.00000001
    negative_num = 0.00000001

    for i in range(len(gt_percent)):
        total_length = batch_node_num[i]  # one graph length
        predict_slice = torch.narrow(input=predict, dim=0, start=curr_index, length=total_length)
        curr_index += total_length

        one_gt_percent = gt_percent[i]
        one_moe = moe[i]

        select = torch.tensor([0])
        if use_cuda:
            select = select.to('cuda')

        threshold_ceil = int(total_length * (one_gt_percent - one_moe))  # 100 * (0.8 - 0.1) = top 70 %
        if threshold_ceil < 0:
            threshold_ceil = 0
        threshold_floor = int(total_length * (1.0 - one_gt_percent - one_moe))  # 100 * (1 - 0.8 - 0.1) = bottom 10 %
        if threshold_floor < 0:
            threshold_floor = 0

        top_k, _ = torch.topk(input=predict_slice, k=threshold_ceil, dim=0, largest=True, sorted=False)
        bottom_k, _ = torch.topk(input=predict_slice, k=threshold_floor, dim=0, largest=False, sorted=False)

        top_k_mean = torch.mean(top_k, dim=0)
        bottom_k_mean = torch.mean(bottom_k, dim=0)

        predict_slice = None
        top_k = None
        select = None
        bottom_k = None

        loss_fn = nn.SmoothL1Loss()
        if use_cuda:
            temp_ones = torch.ones(1, dtype=torch.float).to('cuda')
            temp_zeros = torch.tensor([-1], dtype=torch.float).to('cuda')
            temp_ground = torch.zeros(1, dtype=torch.float).to('cuda')
            if threshold_ceil > 0:
                # top_k_loss = F.l1_loss(top_k_mean, temp_ones)
                top_k_loss = loss_fn(top_k_mean, temp_ones)
                positive_num += top_k_loss.detach().cpu().numpy()
            else:
                top_k_loss = None
            if threshold_floor > 0:
                # bottom_k_loss = F.l1_loss(bottom_k_mean, temp_zeros)
                bottom_k_loss = loss_fn(bottom_k_mean, temp_zeros)
                negative_num += bottom_k_loss.detach().cpu().numpy()
            else:
                bottom_k_loss = None
            temp_ones = None
            temp_zeros = None
        else:
            if threshold_ceil > 0:
                # top_k_loss = F.l1_loss(top_k_mean, torch.ones(1, dtype=torch.float))
                top_k_loss = loss_fn(top_k_mean, torch.ones(1, dtype=torch.float))
                positive_num += 1.0
            else:
                top_k_loss = None
            if threshold_floor > 0:
                # bottom_k_loss = F.l1_loss(bottom_k_mean, torch.zeros(1, dtype=torch.float))
                bottom_k_loss = loss_fn(bottom_k_mean, torch.zeros(1, dtype=torch.float))
                negative_num += 1.0
            else:
                bottom_k_loss = None

        batch_top_k_loss.append(top_k_loss)
        batch_bottom_k_loss.append(bottom_k_loss)
        top_k_loss = None
        bottom_k_loss = None
        pairwise_loss = None

    print("-------------------------------------------------------------------------------")
    print("Targeted Regions Losses Per Image")
    print([round(float(x.data.cpu().detach().numpy()), 2) if x is not None else -1.00 for x in batch_top_k_loss])
    print("Background Regions Losses Per Image")
    print([round(float(x.data.cpu().detach().numpy()), 2) if x is not None else -1.00 for x in batch_bottom_k_loss])
    print("-------------------------------------------------------------------------------")

    for t, b, g, a in zip(batch_top_k_loss, batch_bottom_k_loss, gt_percent, moe):
        if top_k_loss is None and t is not None:
            top_k_loss = (g - a) * t
        elif t is not None:
            top_k_loss += (g - a) * t
        if bottom_k_loss is None and b is not None:
            bottom_k_loss = (1.0 - g - a) * b
        elif b is not None:
            bottom_k_loss += (1.0 - g - a) * b

    return top_k_loss, bottom_k_loss
def train_text2mel(load_trained):
    # create log dir
    logdir = os.path.join(Hyper.logdir, "text2mel")
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    if not os.path.exists(os.path.join(logdir, "pkg")):
        os.mkdir(os.path.join(logdir, "pkg"))

    # device
    device = Hyper.device_text2mel

    graph = Text2Mel().to(device)
    # set the training flag
    graph.train()
    # load data and get batch maker
    names, lengths, texts = load_data()
    batch_maker = BatchMaker(Hyper.batch_size, names, lengths, texts)

    criterion_mels = nn.L1Loss().to(device)
    criterion_bd1 = nn.BCEWithLogitsLoss().to(device)
    criterion_atten = nn.L1Loss().to(device)
    optimizer = torch.optim.Adam(graph.parameters(),
                                 lr=Hyper.adam_alpha,
                                 betas=Hyper.adam_betas,
                                 eps=Hyper.adam_eps)

    lossplot_mels = LogHelper("mel_l1", logdir)
    lossplot_bd1 = LogHelper("mel_BCE", logdir)
    lossplot_atten = LogHelper("atten", logdir)

    dynamic_guide = float(Hyper.guide_weight)
    global_step = 0

    # check if load
    if load_trained > 0:
        print("load model trained for {}k batches".format(load_trained))
        global_step = load(
            os.path.join(logdir, "pkg/save_{}k.pkg".format(load_trained)), graph, {
                "mels": criterion_mels,
                "bd1": criterion_bd1,
                "atten": criterion_atten
            }, optimizer)
        dynamic_guide *= Hyper.guide_decay**(load_trained * 1000)

    evaluator = Evaluator()
    for loop_cnt in range(int(Hyper.num_batches / batch_maker.num_batches() + 0.5)):
        print("loop", loop_cnt)
        bar = PrettyBar(batch_maker.num_batches())
        bar.set_description("training...")
        loss_str0 = MovingAverage()
        loss_str1 = MovingAverage()
        loss_str2 = MovingAverage()
        for bi in bar:
            batch = batch_maker.next_batch()
            # make batch
            texts = torch.LongTensor(batch["texts"]).to(device)
            # shift mel
            shift_mels = torch.FloatTensor(
                np.concatenate((np.zeros(
                    (batch["mels"].shape[0], batch["mels"].shape[1], 1)),
                                batch["mels"][:, :, :-1]),
                               axis=2)).to(device)
            # ground truth
            mels = torch.FloatTensor(batch["mels"]).to(device)

            # forward
            pred_logits, pred_mels = graph(texts, shift_mels)
            # loss
            if False:
                loss_mels = sum(
                    criterion_mels(
                        torch.narrow(pred_mels[i], -1, 0, batch["mel_lengths"][i]),
                        torch.narrow(mels[i], -1, 0, batch["mel_lengths"][i]))
                    for i in range(batch_maker.batch_size())) / float(
                        batch_maker.batch_size())
                loss_bd1 = sum(
                    criterion_bd1(
                        torch.narrow(pred_logits[i], -1, 0, batch["mel_lengths"][i]),
                        torch.narrow(mels[i], -1, 0, batch["mel_lengths"][i]))
                    for i in range(batch_maker.batch_size())) / float(
                        batch_maker.batch_size())
            else:
                loss_mels = criterion_mels(pred_mels, mels)
                loss_bd1 = criterion_bd1(pred_logits, mels)
            # guide attention
            atten_guide = torch.FloatTensor(batch["atten_guides"]).to(device)
            atten_mask = torch.FloatTensor(batch["atten_masks"]).to(device)
            atten_mask = torch.ones_like(graph.attention)
            loss_atten = criterion_atten(
                atten_guide * graph.attention * atten_mask,
                torch.zeros_like(graph.attention)) * dynamic_guide
            loss = loss_mels + loss_bd1 + loss_atten

            # backward
            graph.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            # clip grad
            nn.utils.clip_grad_value_(graph.parameters(), 1)
            optimizer.step()

            # log
            loss_str0.add(loss_mels.cpu().data.mean())
            loss_str1.add(loss_bd1.cpu().data.mean())
            loss_str2.add(loss_atten.cpu().data.mean())
            lossplot_mels.add(loss_str0(), global_step)
            lossplot_bd1.add(loss_str1(), global_step)
            lossplot_atten.add(loss_str2(), global_step)

            # adjust dynamic_guide
            # dynamic_guide = float((loss_mels + loss_bd1).cpu().data.mean() / loss_atten.cpu().data.mean())
            dynamic_guide *= Hyper.guide_decay
            if dynamic_guide < Hyper.guide_lowbound:
                dynamic_guide = Hyper.guide_lowbound
            bar.set_description(
                "gs: {}, mels: {}, bd1: {}, atten: {}, scale: {}".format(
                    global_step, loss_str0(), loss_str1(), loss_str2(),
                    "%4f" % dynamic_guide))

            if global_step % Hyper.synth_freq == 0:
                evaluator.evaluate(loop_cnt)
                evaluator.export()

            # plot
            if global_step % 100 == 0:
                gs = 0
                plot_spectrum(mels[0].cpu().data, "mel_true", gs, dir=logdir)
                plot_spectrum(shift_mels[0].cpu().data, "mel_input", gs, dir=logdir)
                plot_spectrum(pred_mels[0].cpu().data, "mel_pred", gs, dir=logdir)
                plot_spectrum(graph.query[0].cpu().data, "query", gs, dir=logdir)
                plot_attention(graph.attention[0].cpu().data, "atten", gs, True, dir=logdir)
                plot_attention((atten_guide)[0].cpu().data, "atten_guide", gs, True, dir=logdir)

            if global_step % 500 == 0:
                lossplot_mels.plot()
                lossplot_bd1.plot()
                lossplot_atten.plot()

            if global_step % 10000 == 0:
                save(
                    os.path.join(logdir, "pkg/save_{}k.pkg").format(global_step // 1000),
                    graph, {
                        "mels": criterion_mels,
                        "bd1": criterion_bd1,
                        "atten": criterion_atten
                    }, optimizer, global_step, True)

            # increase global step
            global_step += 1
def tb(a):
    return torch.narrow(a, 3, 1, N - k)
def forward(ctx, world_size, start_pos, chunk_size, weight, pg, bias):
    ctx.weight = weight
    ctx.pg = pg
    ctx.world_size = world_size
    return torch.narrow(bias, 0, start_pos, chunk_size)
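A small illustrative sketch (not from the original source) of the narrow-on-bias pattern used in the forward() above: each rank takes its own contiguous chunk of the bias vector; the sizes and values here are made up.

# Illustrative only: split a bias into equal per-rank chunks with torch.narrow.
import torch

bias = torch.arange(8.)
world_size = 4
chunk_size = bias.numel() // world_size
for rank in range(world_size):
    start_pos = rank * chunk_size
    local_bias = torch.narrow(bias, 0, start_pos, chunk_size)
    print(rank, local_bias.tolist())  # rank 0 -> [0.0, 1.0], rank 1 -> [2.0, 3.0], ...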
def tf(a):
    return torch.narrow(a, 3, 0, N - k)
def _dp(self, arc_scores, lengths=None, force_grad=False):
    semiring = self.semiring
    arc_scores = _convert(arc_scores)
    arc_scores, batch, N, lengths = self._check_potentials(arc_scores, lengths)
    DIRS = 2

    alpha = [
        self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
        for _ in range(2)
    ]

    def stack(a, b):
        return torch.stack([a, b], dim=1)

    def sstack(a):
        return torch.stack([a, a], dim=1)

    arcs = [
        self._make_chart(1, (DIRS, batch, N - k), arc_scores, force_grad)[0]
        for k in range(N)
    ]

    # Inside step. assumes first token is root symbol
    semiring.one_(alpha[A][C][:, :, :, :, 0].data)
    semiring.one_(alpha[B][C][:, :, :, :, -1].data)
    k = 0

    AIR = alpha[A][I][:, R, :, :N - k, 1:k]
    BIL = alpha[B][I][:, L, :, k:N, N - k:N - 1]
    k = 1
    AC2 = alpha[A][C][:, :, :, :N - k, :k]
    BC2 = alpha[B][C][:, :, :, k:, N - k:]
    AC, BC, AC_next = None, None, None

    ends = [None]
    for k in range(1, N):

        def tf(a):
            return torch.narrow(a, 3, 0, N - k)

        def tb(a):
            return torch.narrow(a, 3, 1, N - k)

        f = torch.arange(N - k), torch.arange(k, N)
        if k > 1:
            AC2 = torch.cat([tf(AC), tf(AC_next).unsqueeze(-1)], dim=4)
        if k > 1:
            BC2 = torch.cat([tb(AC_next).unsqueeze(-1), tb(BC)], dim=4)
        ACL, ACR = AC2.unbind(dim=1)
        BCL, BCR = BC2.unbind(dim=1)
        start = semiring.dot(BCL, ACR)
        # if k == 1:
        arcs[k] = stack(
            semiring.times(start, arc_scores[:, :, f[1], f[0]]),
            semiring.times(start, arc_scores[:, :, f[0], f[1]]),
        )
        arcsL, arcR = arcs[k].unbind(dim=1)
        # else:
        #     arcs[k] = stack(semiring.times(start), #, arc_scores[:, f[1], f[0]]),
        #                     semiring.times(start)) #, arc_scores[:, f[0], f[1]]))
        AIR2 = torch.cat(
            [torch.narrow(AIR, 2, 0, N - k), arcR.unsqueeze(-1)], dim=3)
        BIL2 = torch.cat(
            [arcsL.unsqueeze(-1), torch.narrow(BIL, 2, 1, N - k)], dim=3)
        AC_next = stack(semiring.dot(ACL, BIL2), semiring.dot(AIR2, BCR))

        ends.append(AC_next[:, R, :, 0])
        AC = AC2
        BC = BC2
        AIR = AIR2
        BIL = BIL2

    v = torch.stack([ends[l][:, i] for i, l in enumerate(lengths)], dim=1)
    # v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)])
    return (semiring.unconvert(v), arcs[1:], alpha)
def narrow(self, dim, start, length):
    tensor = torch.narrow(self, dim, start, length)
    tensor.dtype = self.dtype
    return tensor
def showtensor(tensor):
    x = torch.narrow(tensor, 0, 0, 1)
    plt.figure()
    plt.imshow(x.squeeze().numpy())
    plt.show()
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
            position_ids=None, head_mask=None, inputs_embeds=None, example_idx=None,
            extend=True):
    group_ids = None
    output_length = len(input_ids) if input_ids is not None else len(inputs_embeds)

    if extend and not self.training:
        # use the hook
        input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, example_idx, group_ids, group_sizes = \
            self.extend_batch_examples_eval(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                example_idx=example_idx)

    if not extend and input_ids is not None and inputs_embeds is not None:
        inputs_embeds = None

    result_logits = None
    for i in range(0, attention_mask.shape[0], self.hparams.batch_size):
        bs = min(self.hparams.batch_size, attention_mask.shape[0] - i)
        if input_ids is not None:
            batch_input_ids = input_ids.narrow(0, i, bs)
        else:
            batch_input_ids = None
        if inputs_embeds is not None:
            batch_inputs_embeds = inputs_embeds.narrow(0, i, bs)
        else:
            batch_inputs_embeds = None
        batch_attention_mask = attention_mask.narrow(0, i, bs)
        batch_token_type_ids = token_type_ids.narrow(0, i, bs)
        if position_ids is not None:
            batch_position_ids = position_ids.narrow(0, i, bs)
        else:
            batch_position_ids = None
        if head_mask is not None:
            batch_head_mask = head_mask.narrow(0, i, bs)
        else:
            batch_head_mask = None

        logits = LightningBertForSequenceClassification.forward(
            self,
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            token_type_ids=batch_token_type_ids,
            position_ids=batch_position_ids,
            head_mask=batch_head_mask,
            inputs_embeds=batch_inputs_embeds)

        if result_logits is None:
            result_logits = logits
        else:
            result_logits = torch.cat((result_logits, logits), dim=0)
    logits = result_logits

    #
    # time to vote
    # Makes big empty tensor for all groups
    # uses torch.view to update individual group
    #
    if group_ids is not None:
        # prepare a couple of output tensors of the right dimensions
        avg_logits = torch.zeros(output_length, self.num_labels).to(self.device)
        counted_logits = torch.zeros(output_length, self.num_labels).to(self.device)
        original_logits = logits[:output_length]

        # now go through the whole extended batch
        for i, (logit, group_id) in enumerate(zip(logits, group_ids)):
            # first, tally logits by averaging across replacement groups
            current_group_logits = torch.narrow(avg_logits, 0, group_id, 1)
            torch.add(current_group_logits,
                      torch.div(logit, group_sizes[group_id]),
                      out=current_group_logits)

            # but also, record the individual VOTES (argmax)
            current_vote = torch.argmax(logit).item()
            counted_logits[group_id, current_vote] += 1 / group_sizes[group_id]

            # let us know what is happening here
            self._debug_print_vote(i, group_id, logit, example_idx)

        # print the results for this batch
        self._debug_print_votes(original_logits, avg_logits, counted_logits)

        if self.hparams.vote_avg_logits:
            logits = avg_logits
        else:
            logits = counted_logits

    #
    # return whatever we have at this point
    #
    return logits
def diff(a, dim=0, out=None, func=torch.not_equal):
    sz = a.size(dim) - 1
    if out is None:
        out = torch.empty(sz, dtype=torch.bool, device=a.device)
    return func(torch.narrow(a, dim, 1, sz), torch.narrow(a, dim, 0, sz), out=out)
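An illustrative usage sketch (not from the original source) of the diff() helper above: with the default torch.not_equal it flags positions where consecutive elements change, much like a boolean np.diff.

# Illustrative only: relies on the diff() definition above; data is made up.
import torch

a = torch.tensor([1, 1, 2, 2, 2, 3])
changed = diff(a)                             # tensor([False, True, False, False, True])
boundaries = changed.nonzero().flatten() + 1  # indices where a new run of values starts
print(boundaries.tolist())                    # [2, 5]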
def forward(
    self,
    x,
    encoder_out=None,
    encoder_padding_mask=None,
    incremental_state=None,
    prev_self_attn_state=None,
    prev_attn_state=None,
    self_attn_mask=None,
    self_attn_padding_mask=None,
):
    """
    Args:
        x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
        encoder_padding_mask (ByteTensor): binary ByteTensor of shape
            `(batch, src_len)` where padding elements are indicated by ``1``.

    Returns:
        encoded output of shape `(batch, src_len, embed_dim)`
    """
    # for layer in self.layers:
    #     x, attn = layer(
    #         x,
    #         encoder_out['encoder_out'] if encoder_out is not None else None,
    #         encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
    #         incremental_state,
    #         self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
    #     )
    residual = x
    x = self.maybe_layer_norm(self.layer_norm_self_attn, x, before=True)
    ins = self.fc1(x)
    q1 = torch.narrow(ins, -1, 0, self.embed_dim)
    k1 = torch.narrow(ins, -1, self.embed_dim, self.embed_dim)
    v1 = torch.narrow(ins, -1, 2 * self.embed_dim, self.embed_dim)
    if prev_self_attn_state is not None:
        if incremental_state is None:
            incremental_state = {}
        prev_key, prev_value = prev_self_attn_state
        saved_state = {"prev_key": prev_key, "prev_value": prev_value}
        self.self_attn._set_input_buffer(incremental_state, saved_state)
    x, attn = self.self_attn(
        q=q1,
        k=k1,
        v=v1,
        key_padding_mask=self_attn_padding_mask,
        incremental_state=incremental_state,
        need_weights=False,
        attn_mask=self_attn_mask,
    )
    x = self.fc2(x)
    x = F.dropout(x, p=self.dropout, training=self.training)
    x = residual + x
    x = self.maybe_layer_norm(self.layer_norm_self_attn, x, after=True)

    if self.encoder_attn is not None:
        residual = x
        x = self.maybe_layer_norm(self.layer_norm_context_attn, x, before=True)
        q2 = self.fc3(x)
        if prev_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.encoder_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.encoder_attn(
            q=q2,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=(not self.training and self.need_attn),
        )
        x = self.fc4(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.layer_norm_context_attn, x, after=True)

    residual = x
    x = self.maybe_layer_norm(self.layer_norm_ffn, x, before=True)
    x = self.ffn(x)
    x = F.dropout(x, p=self.dropout, training=self.training)
    x = residual + x
    x = self.maybe_layer_norm(self.layer_norm_ffn, x, after=True)
    if self.onnx_trace and incremental_state is not None:
        saved_state = self.self_attn._get_input_buffer(incremental_state)
        self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
        return x, attn, self_attn_state
    return x, attn
def backward(ctx, grad_output):
    bsz = grad_output.shape[0]
    grad_output_shape = grad_output.size()
    if grad_output_shape[1] % 16 != 0:
        grad_output = torch.cat(
            (grad_output,
             torch.zeros((bsz, (16 - (grad_output_shape[1] % 16)), grad_output_shape[2]),
                         device=grad_output.device,
                         dtype=grad_output.dtype)), 1)

    assert ctx.config.training

    if (ctx.config.pre_layer_norm and ctx.config.normalize_invertible):
        (input_mask, attn_qkvw, attn_qkvb, attn_ow, attn_ob, attn_nw, attn_nb, inter_w,
         inter_b, output_w, output_b, norm_w, norm_b) = ctx.saved_tensors
    else:
        (output, input, input_mask, attn_qkvw, attn_qkvb, attn_ow, attn_ob, attn_nw,
         attn_nb, inter_w, inter_b, output_w, output_b, norm_w, norm_b) = ctx.saved_tensors

    cuda_module = stochastic_transformer_cuda_module if ctx.config.stochastic_mode else transformer_cuda_module
    backward_func = cuda_module.backward_fp16 if ctx.config.fp16 else cuda_module.backward_fp32

    (grad_input, grad_attn_qkvw, grad_attn_qkvb, grad_attn_ow, grad_attn_ob,
     grad_attn_nw, grad_attn_nb, grad_inter_w, grad_inter_b, grad_output_w,
     grad_output_b, grad_norm_w, grad_norm_b) = backward_func(
         ctx.config.layer_id,
         grad_output,
         (ctx.inp_norm if (ctx.config.pre_layer_norm
                           and ctx.config.normalize_invertible) else output),
         (ctx.inp_norm if (ctx.config.pre_layer_norm
                           or not ctx.config.normalize_invertible) else input),
         ctx.qkv_tf,
         ctx.soft_inp,
         (ctx.soft_inp if ctx.config.attn_dropout_checkpoint else ctx.ctx_bufB),
         ctx.attn_o_inp,
         (ctx.ff1_inp if ctx.config.normalize_invertible else ctx.add_res),
         ctx.ff1_inp,
         (ctx.ff2_inp if ctx.config.gelu_checkpoint else ctx.gelu_inp),
         ctx.ff2_inp,
         ctx.attn_prob_dropout_mask,
         ctx.attn_output_dropout_mask,
         ctx.layer_output_dropout_mask,
         ctx.attn_layer_norm_var,
         ctx.attn_layer_norm_mean,
         ctx.layer_norm_var,
         ctx.layer_norm_mean,
         (ctx.inp_norm if (ctx.config.pre_layer_norm
                           and ctx.config.normalize_invertible) else input),
         input_mask, attn_qkvw, attn_qkvb, attn_ow, attn_ob, attn_nw, attn_nb,
         inter_w, inter_b, output_w, output_b, norm_w, norm_b)

    # This appears to be an effective way to release context memory
    ctx.qkv_tf = None
    ctx.soft_inp = None
    ctx.ctx_bufB = None
    ctx.gelu_inp = None
    ctx.ff2_inp = None
    ctx.attn_o_inp = None
    ctx.ff1_inp = None
    ctx.add_res = None
    ctx.inp_norm = None
    ctx.config = None
    ctx.attn_layer_norm_mean = None
    ctx.layer_norm_mean = None
    ctx.attn_prob_dropout_mask = None
    ctx.attn_output_dropout_mask = None
    ctx.layer_output_dropout_mask = None
    ctx.attn_layer_norm_var = None
    ctx.layer_norm_var = None

    if grad_output_shape[1] % 16 != 0:
        grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1])

    return (grad_input, None, None, None, None, grad_attn_qkvw, grad_attn_qkvb,
            grad_attn_ow, grad_attn_ob, grad_attn_nw, grad_attn_nb, grad_inter_w,
            grad_inter_b, grad_output_w, grad_output_b, grad_norm_w, grad_norm_b, None)
def forward(self, query, key, value, key_padding_mask=None):
    """Input shape: Time x Batch x Channel

    Self-attention can be implemented by passing in the same arguments for
    query, key and value. Timesteps can be masked by supplying a T x T mask in the
    `attn_mask` argument. Padding elements can be excluded from
    the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
    batch x src_len, where padding elements are indicated by 1s.
    """
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]

    ins = self.fc1(query)
    q = torch.narrow(ins, -1, 0, embed_dim)
    k = torch.narrow(ins, -1, embed_dim, embed_dim)
    v = torch.narrow(ins, -1, 2 * embed_dim, embed_dim)
    q = q * self.scaling

    q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    src_len = k.size(1)

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    attn_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        if self.onnx_trace:
            attn_weights = torch.where(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                torch.Tensor([float("-Inf")]),
                attn_weights.float()).type_as(attn_weights)
        else:
            attn_weights = attn_weights.float().masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            ).type_as(attn_weights)  # FP16 support: cast to float and back
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    attn_weights = utils.softmax(
        attn_weights, dim=-1, onnx_trace=self.onnx_trace,
    ).type_as(attn_weights)
    attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

    attn = torch.bmm(attn_weights, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    if (self.onnx_trace and attn.size(1) == 1):
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.fc2(attn)

    attn_weights = None

    return attn, attn_weights
def _setup_for_real_optimizer(self):
    dp_world_size = dist.get_world_size(group=self.dp_process_group)
    self.partition_count = [
        dp_world_size for i in range(len(self.optimizer.param_groups))
    ]

    for i, param_group in enumerate(self.optimizer.param_groups):
        see_memory_usage(f'before initializing group {i}', force=True)

        partition_id = dist.get_rank(group=self.real_dp_process_group[i])

        # grab the original list
        self.bf16_groups.append(param_group['params'])

        # create flat bf16 params
        self.bf16_groups_flat.append(
            self._flatten_dense_tensors_aligned(
                self.bf16_groups[i],
                self.nccl_start_alignment_factor * dp_world_size))

        # Make bf16 params point to flat tensor storage
        self._update_storage_to_flattened_tensor(
            tensor_list=self.bf16_groups[i],
            flat_tensor=self.bf16_groups_flat[i])

        # divide flat weights into equal sized partitions
        partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
        bf16_dp_partitions = [
            self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
            for dp_index in range(dp_world_size)
        ]
        self.bf16_partitioned_groups.append(bf16_dp_partitions)

        # create fp32 params partition
        self.fp32_groups_flat_partition.append(
            bf16_dp_partitions[partition_id].clone().float().detach())
        self.fp32_groups_flat_partition[i].requires_grad = True

        num_elem_list = [t.numel() for t in self.bf16_groups[i]]

        # create fp32 gradients
        self.fp32_groups_gradients_flat.append(
            torch.zeros_like(self.bf16_groups_flat[i], dtype=torch.float32))

        # track individual fp32 gradients for entire model
        fp32_gradients = self._split_flat_tensor(
            flat_tensor=self.fp32_groups_gradients_flat[i],
            num_elem_list=num_elem_list)
        self.fp32_groups_gradients.append(fp32_gradients)

        # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
        length_without_padding = sum(num_elem_list)
        self.fp32_groups_actual_gradients_flat.append(
            torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0,
                         length_without_padding))

        # flat tensor corresponding to gradient partition
        self.fp32_groups_gradient_flat_partition.append(
            torch.narrow(self.fp32_groups_gradients_flat[i], 0,
                         partition_id * partition_size, partition_size))

        # track fp32 gradient updates
        self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))

        # Record padding required for alignment
        if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
            padding = self.bf16_groups_flat[i].numel() - length_without_padding
        else:
            padding = 0
        self.group_paddings.append(padding)

        # update optimizer param groups to reference fp32 params partition
        param_group['params'] = [self.fp32_groups_flat_partition[i]]

        see_memory_usage(f'after initializing group {i}', force=True)

    see_memory_usage('before initialize_optimizer', force=True)
    self.initialize_optimizer_states()
    see_memory_usage('end initialize_optimizer', force=True)

    # Need optimizer states initialized before linking lp to optimizer state
    self._link_all_hp_params()
    self._param_slice_mappings = self._create_param_mapping()
def forward(self, q, key, value, key_padding_mask=None, incremental_state=None,
            need_weights=True, static_kv=False, attn_mask=None):
    """Input shape: Time x Batch x Channel

    Self-attention can be implemented by passing in the same arguments for
    query, key and value. Timesteps can be masked by supplying a T x T mask in the
    `attn_mask` argument. Padding elements can be excluded from
    the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
    batch x src_len, where padding elements are indicated by 1s.
    """
    qkv_same = False
    kv_same = True

    tgt_len, bsz, embed_dim = q.size()
    assert embed_dim == self.embed_dim
    assert list(q.size()) == [tgt_len, bsz, embed_dim]

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if 'prev_key' in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert kv_same and not qkv_same
                key = value = None
    else:
        saved_state = None

    # encoder-decoder attention
    # q = self.in_proj_q(query)
    if key is None:
        assert value is None
        k = v = None
    else:
        kv = self.kv_fc(key)
        k = torch.narrow(kv, -1, 0, self.embed_dim)
        v = torch.narrow(kv, -1, self.embed_dim, self.embed_dim)
    q *= self.scaling

    q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if 'prev_key' in saved_state:
            prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                k = torch.cat((prev_key, k), dim=1)
        if 'prev_value' in saved_state:
            prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                v = torch.cat((prev_value, v), dim=1)
        saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)

        self._set_input_buffer(incremental_state, saved_state)

    src_len = k.size(1)

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    attn_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        if self.onnx_trace:
            attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        if self.onnx_trace:
            attn_weights = torch.where(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                torch.Tensor([float("-Inf")]),
                attn_weights.float()).type_as(attn_weights)
        else:
            attn_weights = attn_weights.float().masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            ).type_as(attn_weights)  # FP16 support: cast to float and back
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    attn_weights = utils.softmax(
        attn_weights, dim=-1, onnx_trace=self.onnx_trace,
    ).type_as(attn_weights)
    attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

    attn = torch.bmm(attn_weights, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    if (self.onnx_trace and attn.size(1) == 1):
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

    if need_weights:
        # average attention weights over heads
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights.sum(dim=1) / self.num_heads
    else:
        attn_weights = None

    return attn, attn_weights
def test_narrow(self):
    x = torch.randn(3, 3, requires_grad=True)
    self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: # of cuda process.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        final result of linear operation.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(
        input_split_sizes[rank] * world_size, input_t_size[1], device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(
        gathered_input, input_t, input_split_sizes=input_split_sizes, group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
def slice_axis(data, axis, begin, end):
    return th.narrow(data, axis, begin, end - begin)
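An illustrative check (not from the original source): slice_axis() above is equivalent to ordinary slicing along `axis`, expressed with torch.narrow; note the end-to-length conversion (end - begin).

# Illustrative only: relies on the slice_axis() definition above; data is made up.
import torch as th

data = th.arange(24).reshape(2, 3, 4)
out = slice_axis(data, 2, 1, 3)        # slice dim 2 from index 1 up to (not including) 3
assert out.shape == (2, 3, 2)
assert th.equal(out, data[:, :, 1:3])  # same result as plain slicing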
def backward(ctx, grad_output):
    slice_size = grad_output.size(ctx.dim) // ctx.world_size
    return torch.narrow(grad_output.clone(), ctx.dim,
                        ctx.ordinal * slice_size, slice_size), None
def forward(self, adv_patch, lab_batch, img_size, do_rotate=True, rand_loc=True):
    # adv_patch = F.conv2d(adv_patch.unsqueeze(0), self.kernel, padding=(2, 2))
    adv_patch = self.medianpooler(adv_patch.unsqueeze(0))
    # print('lab_batch---------------------------: ', lab_batch)
    # Determine size of padding
    pad = (img_size - adv_patch.size(-1)) / 2
    # Make a batch of patches
    adv_patch = adv_patch.unsqueeze(0)  # .unsqueeze(0)
    adv_batch = adv_patch.expand(lab_batch.size(0), lab_batch.size(1), -1, -1, -1)
    batch_size = torch.Size((lab_batch.size(0), lab_batch.size(1)))
    # print('--========+++++======---adv_patch/adv_batch-----', adv_patch.shape, adv_batch.shape)
    # torch.Size([1, 1, 3, 300, 300]) torch.Size([8, 14, 3, 300, 300])

    # Contrast, brightness and noise transforms

    # Create random contrast tensor
    contrast = torch.cuda.FloatTensor(batch_size).uniform_(self.min_contrast, self.max_contrast)
    contrast = contrast.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    contrast = contrast.expand(-1, -1, adv_batch.size(-3), adv_batch.size(-2), adv_batch.size(-1))
    contrast = contrast.cuda()

    # Create random brightness tensor
    brightness = torch.cuda.FloatTensor(batch_size).uniform_(self.min_brightness, self.max_brightness)
    brightness = brightness.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    brightness = brightness.expand(-1, -1, adv_batch.size(-3), adv_batch.size(-2), adv_batch.size(-1))
    brightness = brightness.cuda()

    # Create random noise tensor
    noise = torch.cuda.FloatTensor(adv_batch.size()).uniform_(-1, 1) * self.noise_factor

    # Apply contrast/brightness/noise, clamp
    adv_batch = adv_batch * contrast + brightness + noise
    adv_batch = torch.clamp(adv_batch, 0.000001, 0.99999)

    # The max number of labels per image is fixed at 14; images with fewer labels are
    # padded with 1s, which is how padded rows are told apart here.
    # Where the label class_id is 1 we don't want a patch (padding) --> fill mask with zero's
    # class_id = 1 marks the artificially padded entries
    # lab_batch.size = torch.Size([8, 14, 5]): batch=8, maxlab=14, 5 = class label + box coordinates
    cls_ids = torch.narrow(lab_batch, 2, 0, 1)  # take index 0 along lab_batch's last dimension (size 5), i.e. the class id
    cls_mask = cls_ids.expand(-1, -1, 3)
    cls_mask = cls_mask.unsqueeze(-1)
    cls_mask = cls_mask.expand(-1, -1, -1, adv_batch.size(3))
    cls_mask = cls_mask.unsqueeze(-1)
    cls_mask = cls_mask.expand(-1, -1, -1, -1, adv_batch.size(4))
    msk_batch = torch.cuda.FloatTensor(cls_mask.size()).fill_(1) - cls_mask
    # print('++++++++++====msk_batch=========', msk_batch.shape, msk_batch)  # torch.Size([8, 14, 3, 300, 300])

    # Pad patch and mask to image dimensions
    mypad = nn.ConstantPad2d((int(pad + 0.5), int(pad), int(pad + 0.5), int(pad)), 0)
    # Pad left/right/top/bottom with the given number of zeros so the padded patch matches
    # the original image size. The padded values are 0 and are filtered out when the patch
    # and image are blended. The equally-sized patch and the (expanded) label-id mask are
    # multiplied, so the class id filters the patch.
    adv_batch = mypad(adv_batch)
    msk_batch = mypad(msk_batch)

    # Rotation and rescaling transforms: place the patch according to the size and
    # orientation of the ground-truth label
    anglesize = (lab_batch.size(0) * lab_batch.size(1))
    if do_rotate:
        angle = torch.cuda.FloatTensor(anglesize).uniform_(self.minangle, self.maxangle)
    else:
        angle = torch.cuda.FloatTensor(anglesize).fill_(0)

    # Resizes and rotates
    current_patch_size = adv_patch.size(-1)
    lab_batch_scaled = torch.cuda.FloatTensor(lab_batch.size()).fill_(0)
    # Recover the real bounding-box size (x, y, w, h) from the label coordinates
    lab_batch_scaled[:, :, 1] = lab_batch[:, :, 1] * img_size
    lab_batch_scaled[:, :, 2] = lab_batch[:, :, 2] * img_size
    lab_batch_scaled[:, :, 3] = lab_batch[:, :, 3] * img_size
    lab_batch_scaled[:, :, 4] = lab_batch[:, :, 4] * img_size
    # Patch size
    target_size = torch.sqrt(((lab_batch_scaled[:, :, 3].mul(0.2)) ** 2) +
                             ((lab_batch_scaled[:, :, 4].mul(0.2)) ** 2))
    target_x = lab_batch[:, :, 1].view(np.prod(batch_size))
    target_y = lab_batch[:, :, 2].view(np.prod(batch_size))
    targetoff_x = lab_batch[:, :, 3].view(np.prod(batch_size))
    targetoff_y = lab_batch[:, :, 4].view(np.prod(batch_size))
    if rand_loc:
        off_x = targetoff_x * (torch.cuda.FloatTensor(targetoff_x.size()).uniform_(-0.4, 0.4))
        target_x = target_x + off_x
        off_y = targetoff_y * (torch.cuda.FloatTensor(targetoff_y.size()).uniform_(-0.4, 0.4))
        target_y = target_y + off_y
    target_y = target_y - 0.05
    scale = target_size / current_patch_size
    scale = scale.view(anglesize)

    s = adv_batch.size()
    adv_batch = adv_batch.view(s[0] * s[1], s[2], s[3], s[4])
    msk_batch = msk_batch.view(s[0] * s[1], s[2], s[3], s[4])

    tx = (-target_x + 0.5) * 2
    ty = (-target_y + 0.5) * 2
    sin = torch.sin(angle)
    cos = torch.cos(angle)

    # Theta = rotation, rescale matrix
    theta = torch.cuda.FloatTensor(anglesize, 2, 3).fill_(0)
    theta[:, 0, 0] = cos / scale
    theta[:, 0, 1] = sin / scale
    theta[:, 0, 2] = tx * cos / scale + ty * sin / scale
    theta[:, 1, 0] = -sin / scale
    theta[:, 1, 1] = cos / scale
    theta[:, 1, 2] = -tx * sin / scale + ty * cos / scale

    b_sh = adv_batch.shape
    # Affine transform: rotate/translate/scale; the output is resampled at the image size (416)
    grid = F.affine_grid(theta, adv_batch.shape)

    adv_batch_t = F.grid_sample(adv_batch, grid)
    msk_batch_t = F.grid_sample(msk_batch, grid)
    # print('-_______________adv_batch_t/mas________-----', adv_batch_t.shape, msk_batch_t.shape)
    # torch.Size([112, 3, 416, 416]) torch.Size([112, 3, 416, 416])

    '''
    # Theta2 = translation matrix
    theta2 = torch.cuda.FloatTensor(anglesize, 2, 3).fill_(0)
    theta2[:, 0, 0] = 1
    theta2[:, 0, 1] = 0
    theta2[:, 0, 2] = (-target_x + 0.5) * 2
    theta2[:, 1, 0] = 0
    theta2[:, 1, 1] = 1
    theta2[:, 1, 2] = (-target_y + 0.5) * 2

    grid2 = F.affine_grid(theta2, adv_batch.shape)
    adv_batch_t = F.grid_sample(adv_batch_t, grid2)
    msk_batch_t = F.grid_sample(msk_batch_t, grid2)
    '''

    adv_batch_t = adv_batch_t.view(s[0], s[1], s[2], s[3], s[4])
    msk_batch_t = msk_batch_t.view(s[0], s[1], s[2], s[3], s[4])

    adv_batch_t = torch.clamp(adv_batch_t, 0.000001, 0.999999)
    # img = msk_batch_t[0, 0, :, :, :].detach().cpu()
    # img = transforms.ToPILImage()(img)
    # img.show()
    # exit()

    return adv_batch_t * msk_batch_t
def d_slice(dist, i, j, length):
    return torch.narrow(torch.narrow(dist, 0, i, length), 1, j, length)
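A hedged sketch (not part of the source) of the d_slice() helper above: two nested narrow calls carve a length x length block out of a square distance matrix, starting at row i and column j.

# Illustrative only: relies on the d_slice() definition above; data is made up.
import torch

dist = torch.arange(25.).reshape(5, 5)
block = d_slice(dist, 1, 2, 3)
# rows 1..3 and columns 2..4 of dist, still a view of the original storage
assert torch.equal(block, dist[1:4, 2:5])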
def _forward_biobert(
    self, tokens: List[List[str]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return BioBERT hidden states for the tokenized documents.
    Documents with different lengths will be accepted.
        list(list(str)) -> tuple(torch.tensor, torch.tensor)
    """
    # Convert each token of each document into a list of subwords.
    # e.g.,
    # [['Admission', 'Date', ...], ['Service', ':', ...]]
    #   |
    #   V
    # [[['Ad', '##mission'], ['Date'], ...], [['Service'], [':'], ...]]
    subwords_unchained = [
        [self.tokenizer.tokenize(tok) for tok in doc] for doc in tokens
    ]

    # Simply replace each token of each document with corresponding subwords.
    # e.g.,
    # [['Admission', 'Date', ...], ['Service', ':', ...]]
    #   |
    #   V
    # [['Ad', '##mission', 'Date', ...], ['Service', ':', ...]]
    subwords = [
        list(itertools.chain(*[self.tokenizer.tokenize(tok) for tok in doc]))
        for doc in tokens
    ]

    # Memorize (i) header place of each token and (ii) how many subwords each token gave birth.
    # e.g.,
    # For document ['Admission', 'Date'] -> ['Ad', '##mission', 'Date'],
    # subword_info will be {'start':[0,2], 'length':[2,1]}.
    subword_info = []
    for doc in subwords_unchained:
        word_lengths = [len(word) for word in doc]
        word_head_ix = [0]
        for i in range(len(word_lengths) - 1):
            word_head_ix.append(word_head_ix[-1] + word_lengths[i])
        assert len(word_lengths) == len(word_head_ix)
        subword_info.append({"start": word_head_ix, "length": word_lengths})

    assert [len(info["start"]) for info in subword_info] == [
        len(doc) for doc in tokens
    ]

    # Split each document into chunks shorter than max_length.
    # Here, each document will be simply split at every 510 tokens.
    max_length = min(
        self.bertconfig.max_position_embeddings, self.hparams.max_length
    )
    longest_length = max([len(doc) for doc in subwords])
    n_chunks = (longest_length - 1) // (max_length - 2) + 1
    chunks = []
    for n in range(n_chunks):
        chunk_of_all_documents = []
        for document in subwords:
            chunk_of_single_document = document[
                (max_length - 2) * n : (max_length - 2) * (n + 1)
            ]
            if chunk_of_single_document == []:
                chunk_of_all_documents.append([""])
            else:
                chunk_of_all_documents.append(chunk_of_single_document)
        chunks.append(chunk_of_all_documents)

    # Convert chunks into BERT input form.
    inputs = []
    for chunk in chunks:
        if type(chunk) is str:
            unsqueezed_chunk = [[chunk]]
        elif type(chunk) is list:
            if type(chunk[0]) is str:
                unsqueezed_chunk = [chunk]
            elif type(chunk[0]) is list:
                unsqueezed_chunk = chunk
        inputs.append(
            self.tokenizer.batch_encode_plus(
                unsqueezed_chunk,
                pad_to_max_length=True,
                is_pretokenized=True,
            )
        )

    # Get BioBERT hidden states.
    hidden_states = []
    for inpt in inputs:
        inpt_tensors = {
            k: torch.tensor(v).to(self.get_device()) for k, v in inpt.items()
        }
        hidden_state = self.biobert(**inpt_tensors)[0][:, 1:-1, :]
        hidden_states.append(hidden_state)

    # Concatenate hidden states from each chunk.
    hidden_states_cat = torch.cat(hidden_states, dim=1)

    # If a word was tokenized into multiple subwords, take average of them.
    # e.g. Hidden state for "Admission" equals average of hidden states for "Ad" and "##mission"
    hidden_states_shrunk = torch.zeros_like(hidden_states_cat)
    for n in range(hidden_states_cat.size()[0]):
        hidden_state_shrunk = torch.stack(
            [
                torch.narrow(hidden_states_cat[n], dim=0, start=s, length=l).mean(dim=0)
                for s, l in zip(subword_info[n]["start"], subword_info[n]["length"])
            ]
        )
        hidden_states_shrunk[n, : hidden_state_shrunk.size()[0], :] = hidden_state_shrunk

    # Truncate lengthy tail that will not be used.
    hidden_states_shrunk = hidden_states_shrunk[:, : max([len(doc) for doc in tokens]), :]

    # Create mask for CRF.
    crf_mask = torch.zeros(hidden_states_shrunk.size()[:2]).to(torch.uint8)
    for i, length in enumerate([len(doc) for doc in tokens]):
        crf_mask[i, :length] = 1
    crf_mask = crf_mask > 0
    crf_mask = crf_mask.to(self.get_device())

    return (hidden_states_shrunk, crf_mask)
def extract_feature_matrix(self):
    # define generator
    generator = self.generator(self.paths)

    # load extractor
    extractor = self.load_vgg19(self.layer)

    # initialize sketch and label matrices
    features = []
    paths = []

    n = 0
    quit = False

    # generate batches of sketches and labels
    if generator:
        while True:
            batch_size = self.batch_size
            img_batch = torch.zeros(batch_size, 3, self.imsize, self.imsize)
            paths_batch = []
            if self.use_cuda:
                img_batch = img_batch.to(self.cuda_device)
            if (n + 1) % 5 == 0:
                print('Batch {}'.format(n + 1))
            for b in range(batch_size):
                try:
                    img, path = next(generator)
                    img_batch[b] = img
                    paths_batch.append(path)
                except StopIteration:
                    quit = True
                    print('stopped!')
                    break

            if n == self.num_images // self.batch_size:
                print('b', b)
                print(img_batch.size())
                img_batch = torch.narrow(img_batch, 0, 0, b)
                print(img_batch.size())
                paths_batch = paths_batch[:b + 1]

            # extract features from batch
            n += 1
            feats_batch = extractor(img_batch)
            feats_batch = [feat.cpu().data.numpy() for feat in feats_batch]
            feats_batch = np.squeeze(np.array(feats_batch), axis=0)
            # feats_batch = feats_batch.cpu().data.numpy()
            # print('features shape', features.shape)

            if len(features) == 0:
                features = feats_batch
            else:
                features = np.vstack((features, feats_batch))
            paths.append(paths_batch)

            if n == self.num_images // batch_size + 1:
                break

    return features, paths