def comparison_ops(self):
    """Exercise torch's comparison and ordering operators once each.

    Two random 4-element vectors are drawn and every operator is applied to
    them (or to the first one, for unary operators).

    :return: a 30-element tuple holding each call's result, in call order.
    """
    lhs = torch.randn(4)
    rhs = torch.randn(4)
    results = (
        # closeness / ordering helpers
        torch.allclose(lhs, rhs),
        torch.argsort(lhs),
        # equality and >= / >
        torch.eq(lhs, rhs),
        torch.equal(lhs, rhs),
        torch.ge(lhs, rhs),
        torch.greater_equal(lhs, rhs),
        torch.gt(lhs, rhs),
        torch.greater(lhs, rhs),
        # element-wise predicates
        torch.isclose(lhs, rhs),
        torch.isfinite(lhs),
        torch.isin(lhs, rhs),
        torch.isinf(lhs),
        torch.isposinf(lhs),
        torch.isneginf(lhs),
        torch.isnan(lhs),
        torch.isreal(lhs),
        torch.kthvalue(lhs, 1),
        # <= / <
        torch.le(lhs, rhs),
        torch.less_equal(lhs, rhs),
        torch.lt(lhs, rhs),
        torch.less(lhs, rhs),
        # element-wise extrema
        torch.maximum(lhs, rhs),
        torch.minimum(lhs, rhs),
        torch.fmax(lhs, rhs),
        torch.fmin(lhs, rhs),
        # inequality
        torch.ne(lhs, rhs),
        torch.not_equal(lhs, rhs),
        # sorting
        torch.sort(lhs),
        torch.topk(lhs, 1),
        torch.msort(lhs),
    )
    return results
def cf_score(self, pos_items, users, r_test, t_test):
    """Predict user-item preference scores.

    :param pos_items: item ids to score
    :param users: user ids to score
    :param r_test: relation ids per item (2-D); entries equal to
        ``self.n_relations`` are masked out
    :param t_test: tail-entity ids per item, aligned with ``r_test``
    :return: ``einsum("ajk, kl->ajl")`` of the user/item element-wise
        products with ``self.pre_vec`` (first axis users, second axis items)
    """
    relation_vecs = self.Relation(r_test)
    tail_vecs = self.Entity(t_test)

    # 1.0 for real relations, 0.0 where the id equals n_relations.
    keep = torch.not_equal(r_test, self.n_relations).float()
    relation_vecs = torch.einsum("ab, abc->abc", keep, relation_vecs)
    tail_vecs = torch.einsum("ab, abc->abc", keep, tail_vecs)

    # Equ (10): attention-weighted sum of the tail entities, added to the
    # item's own embedding.
    attention = self.cal_att_weight(relation_vecs, tail_vecs, keep)
    neighbourhood = (attention * tail_vecs).sum(dim=1)
    item_vecs = self.Item(pos_items) + neighbourhood

    user_vecs = self.User(users)
    pairwise = torch.einsum("ac, bc->abc", user_vecs, item_vecs)
    return torch.einsum("ajk, kl->ajl", pairwise, self.pre_vec)
def __call__(self, batch):
    """Collate ``(token_ids, permuted)`` pairs into a masked-LM batch.

    Sequences are padded with ``self._pad_token_id``. Non-padding positions
    are selected with probability ``self._mask_prob``; a selected position is
    replaced by ``self._mask_token_id`` with probability
    ``1 - self._random_prob``, otherwise it is replaced by a random id in
    ``[5, self._vocab_size)`` half of the time and left unchanged the rest.

    :param batch: iterable of ``(token_ids, permuted)`` pairs
    :return: ``(input_ids, labels, permuted)``; ``labels`` holds the original
        id at selected positions and ``TARGET_IDX`` elsewhere
    """
    sequences, permuted_flags = zip(*batch)
    permuted = torch.tensor(permuted_flags, dtype=torch.float)
    input_ids = pad_sequence(sequences,
                             batch_first=True,
                             padding_value=self._pad_token_id)
    shape = input_ids.shape

    # The three uniform draws happen in this fixed order.
    selected = torch.logical_and(
        torch.rand(shape) < self._mask_prob,
        torch.not_equal(input_ids, self._pad_token_id))
    use_mask_token = torch.rand(shape) < 1 - self._random_prob
    use_random_token = torch.rand(shape) < 0.5

    labels = torch.where(selected, input_ids, TARGET_IDX)

    # masking some of the tokens
    input_ids = torch.where(selected & use_mask_token, self._mask_token_id,
                            input_ids)

    # randomly changing other tokens
    randomised = selected & ~use_mask_token & use_random_token
    random_ids = torch.randint_like(input_ids, low=5, high=self._vocab_size)
    input_ids = torch.where(randomised, random_ids, input_ids)

    return input_ids, labels, permuted
def forward(self, inputs, mask=None):
    """Encode a 2-D batch of integer ids and record each row's length.

    :param inputs: 2-D tensor of dtype int8/int16/int32/int64
    :param mask: accepted but not used by this method
    :return: the output of ``self.encoder_obj`` with the number of
        non-padding entries per row stored under ``LENGTHS``
    """
    assert isinstance(inputs, torch.Tensor)
    assert inputs.dtype in (torch.int8, torch.int16, torch.int32,
                            torch.int64)
    assert len(inputs.shape) == 2

    # Fall back to 0 as the padding id when none is configured.
    pad_value = 0 if self.pad_idx is None else self.pad_idx
    not_padding = torch.not_equal(inputs, pad_value)

    row_lengths = not_padding.type(torch.int32).sum(dim=1)
    encoder_output = self.encoder_obj(inputs.type(torch.int32),
                                      mask=not_padding)
    encoder_output[LENGTHS] = row_lengths
    return encoder_output
def index_nonzero(tensor, mask):
    """Yield the entries of ``tensor`` whose ``mask`` value is non-zero.

    ``mask.shape`` must equal the leading dimensions of ``tensor``.

    * For a mask with at least one dimension, the selected slices of
      ``tensor`` are yielded one by one.
    * For a 0-d mask, nothing is yielded when the mask is zero; otherwise a
      0-d ``tensor`` is yielded as is and a higher-rank ``tensor`` yields
      ``tensor[0]`` only.
    """
    assert tensor.shape[:mask.dim()] == mask.shape

    if mask.dim() > 0:
        yield from tensor[torch.not_equal(mask, 0)]
        return

    if mask.item() == 0:
        return

    # NOTE(review): with a scalar mask only the first row of a non-scalar
    # tensor is produced -- confirm this is intended.
    yield tensor if tensor.dim() == 0 else tensor[0]
def forward(self, inputs: torch.Tensor, mask=None):
    """Encode a 2-D batch of integer ids and record each row's length.

    Entries equal to 0 are treated as padding.

    :param inputs: 2-D tensor of dtype int8/int16/int32/int64
    :param mask: accepted but not used by this method
    :return: the output of ``self.encoder_obj`` with the number of non-zero
        entries per row stored under ``LENGTHS``
    :raises AssertionError: if ``inputs`` is not a 2-D integer tensor
    """
    assert isinstance(inputs, torch.Tensor)
    # The whitelist previously contained ``inputs.dtype`` itself, which made
    # this check pass for every dtype.
    assert inputs.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]
    assert len(inputs.shape) == 2

    inputs_exp = inputs.type(torch.int32)
    inputs_mask = torch.not_equal(inputs, 0)
    lengths = torch.sum(inputs_mask.type(torch.int32), dim=1)
    encoder_output = self.encoder_obj(inputs_exp, mask=inputs_mask)
    encoder_output[LENGTHS] = lengths
    return encoder_output
def forward(self, inputs, mask=None):
    """Encode a 2-D batch of integer ids and record each row's length.

    Entries equal to ``SpecialSymbol.PADDING.value`` are treated as padding.

    :param inputs: 2-D tensor of dtype int8/int16/int32/int64
    :param mask: accepted but not used by this method
    :return: the output of ``self.encoder_obj`` with the number of
        non-padding entries per row stored under ``LENGTHS``
    """
    allowed_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)
    assert isinstance(inputs, torch.Tensor)
    assert inputs.dtype in allowed_dtypes
    assert len(inputs.shape) == 2

    is_token = torch.not_equal(inputs, SpecialSymbol.PADDING.value)
    as_int32 = inputs.type(torch.int32)
    token_counts = is_token.type(torch.int32).sum(dim=1)

    encoder_output = self.encoder_obj(as_int32, mask=is_token)
    encoder_output[LENGTHS] = token_counts
    return encoder_output
def forward(self, x):
    """Embed the ids in ``x`` bracket by bracket and attach a padding mask.

    The id range is split at ``cfg.brackets`` (if any) with ``cfg.s_vocab``
    as the final upper bound. For bracket ``i`` covering ``[lo, hi)``, the
    ids falling inside it are shifted by ``lo`` and passed to
    ``self.lookup(ids, i)``; the result is added to the output at the
    matching positions. The first bracket starts at 1, so id 0 is never
    looked up and keeps an all-zero vector.

    The previous implementation called ``torch.int_shape``,
    ``torch.boolean_mask`` and ``torch.tensor_scatter_nd_add``, none of which
    are PyTorch functions; they are replaced by shape arithmetic and
    boolean-mask indexing.

    :param x: integer id tensor
    :return: tensor of shape ``x.shape + (cfg.d_model,)`` scaled by
        ``sqrt(cfg.d_model)``, with a ``mask`` attribute that is True where
        ``x != cfg.PAD``
    """
    cfg = self.cfg
    y = torch.zeros(tuple(x.shape) + (cfg.d_model, ), device=x.device)
    bs = (cfg.brackets or []) + [cfg.s_vocab]
    b = 0
    for i, e in enumerate(bs):
        # ``b or 1``: the lowest bracket begins at 1, not 0.
        m = (x >= (b or 1)) & (x < e)
        u = self.lookup(x[m] - b, i)
        # assumes lookup returns one d_model-sized row per selected id
        # -- TODO confirm against self.lookup
        y[m] = y[m] + u
        b = e
    y *= y.shape[-1]**0.5
    y.mask = torch.not_equal(x, cfg.PAD)
    return y
def vote_targets_torch(self, vote_base, gt_boxes_3d):
    """ Generating vote_targets for each vote_base point
    vote_base: [bs, points_num, 3]
    gt_boxes_3d: [bs, gt_num, 7]

    All-zero rows of gt_boxes_3d are dropped first. A point is marked in
    vote_mask when check_inside_points reports it inside at least one box
    whose dimensions were enlarged by
    cfg.TRAIN.AUGMENTATIONS.EXPAND_DIMS_LENGTH. Its target is the offset from
    the point to the first such box's (un-enlarged) first three coordinates,
    after the second coordinate has been lowered by half of column 4.

    Return:
        vote_mask: [bs, points_num]
        vote_target: [bs, points_num, 3]
    """
    bs, points_num, _ = vote_base.shape
    device = vote_base.device
    vote_mask = torch.zeros((bs, points_num)).float().to(device)
    vote_target = torch.zeros((bs, points_num, 3)).float().to(device)

    for i in range(bs):
        cur_vote_base = vote_base[i]
        cur_gt_boxes_3d = gt_boxes_3d[i]

        # Keep only the boxes that have at least one non-zero entry.
        filter_idx = torch.where(
            torch.any(torch.not_equal(cur_gt_boxes_3d, 0),
                      dim=-1))[0].to(device)
        cur_gt_boxes_3d = cur_gt_boxes_3d[filter_idx]

        cur_vote_base_numpy = cur_vote_base.cpu().detach().numpy()
        # .copy() is required: for CPU tensors .numpy() shares memory with
        # cur_gt_boxes_3d, so expanding in place would also enlarge the
        # boxes used for the vote targets below (and make CPU results
        # differ from GPU results, where .cpu() already copies).
        cur_expand_boxes_3d_numpy = (
            cur_gt_boxes_3d.cpu().detach().numpy().copy())
        cur_expand_boxes_3d_numpy[:, 3:-1] += (
            cfg.TRAIN.AUGMENTATIONS.EXPAND_DIMS_LENGTH)

        cur_points_mask = check_inside_points(
            cur_vote_base_numpy,
            cur_expand_boxes_3d_numpy)  # [pts_num, gt_num]

        cur_vote_mask = np.max(cur_points_mask, axis=1).astype(np.float32)
        vote_mask[i] = torch.from_numpy(cur_vote_mask).float().to(device)

        # argmax picks the first box containing the point (index 0 when the
        # point is inside none; vote_mask is 0 for those points).
        cur_vote_target_idx = np.argmax(cur_points_mask, axis=1)  # [pts_num]
        cur_vote_target_idx = torch.from_numpy(
            cur_vote_target_idx).long().to(device)

        cur_vote_target = cur_gt_boxes_3d[cur_vote_target_idx]
        cur_vote_target[:, 1] = (cur_vote_target[:, 1] -
                                 cur_vote_target[:, 4] / 2.)
        cur_vote_target = cur_vote_target[:, :3] - cur_vote_base
        vote_target[i] = cur_vote_target

    return vote_mask, vote_target
def forward(self, keys, query):
    """Pool ``keys`` with attention scores computed against ``query``.

    Per the original shape notes, ``keys`` is (B, len, emb_len). ``query``
    gains a length-1 axis at position 1 before scoring. Key positions whose
    first feature is 0 are treated as padding and receive a score of 0.

    :return: the score-weighted sum of ``keys``, shape (B, emb_len)
    """
    query = query.unsqueeze(1)

    # The mask has been verified to behave as expected.
    is_real_key = torch.not_equal(keys[:, :, 0], 0)  # (B, len)

    scores = self.local_activation_unit(keys, query)  # (B, len)
    masked_scores = torch.where(is_real_key, scores,
                                torch.zeros_like(scores))  # (B, len)

    # (B, 1, len) @ (B, len, emb_len) -> (B, 1, emb_len) -> (B, emb_len)
    pooled = torch.matmul(masked_scores.unsqueeze(dim=1), keys)
    return pooled.squeeze(dim=1)
def forward(self):
    """Call torch's comparison and ordering operators on two scalar tensors.

    Each operator is invoked with the 0-d tensors ``0`` and ``1`` (and, for
    several operators, with the Python scalar ``1`` as second operand).
    """
    a = torch.tensor(0)
    b = torch.tensor(1)
    # NOTE(review): every result is passed as a separate positional argument
    # to ``len``. Python's built-in ``len`` accepts exactly one argument, so
    # in eager mode this call raises TypeError once the operators have run.
    # Presumably the function exists only to reference each operator (e.g.
    # for an operator-coverage/export check) -- confirm the intended
    # execution context before changing it.
    return len(
        torch.allclose(a, b),
        torch.argsort(a),
        torch.eq(a, b),
        torch.eq(a, 1),
        torch.equal(a, b),
        torch.ge(a, b),
        torch.ge(a, 1),
        torch.greater_equal(a, b),
        torch.greater_equal(a, 1),
        torch.gt(a, b),
        torch.gt(a, 1),
        torch.greater(a, b),
        torch.isclose(a, b),
        torch.isfinite(a),
        torch.isin(a, b),
        torch.isinf(a),
        torch.isposinf(a),
        torch.isneginf(a),
        torch.isnan(a),
        torch.isreal(a),
        torch.kthvalue(a, 1),
        torch.le(a, b),
        torch.le(a, 1),
        torch.less_equal(a, b),
        torch.lt(a, b),
        torch.lt(a, 1),
        torch.less(a, b),
        torch.maximum(a, b),
        torch.minimum(a, b),
        torch.fmax(a, b),
        torch.fmin(a, b),
        torch.ne(a, b),
        torch.ne(a, 1),
        torch.not_equal(a, b),
        torch.sort(a),
        torch.topk(a, 1),
        torch.msort(a),
    )
def error_analyze(model, snaps, labels, mask):
    """Compare the rearrangement-error rate overall vs. on misclassified snaps.

    The model is switched to eval mode and run over ``snaps`` in batches of
    256 on the ``'cuda'`` device. ``mask`` entries that are falsy count as
    rearrangement errors. Both error fractions are printed and the
    misclassified samples are handed to ``incorrect_slideshow``.

    :param model: classifier returning per-class logits
    :param snaps: input snapshots, indexable along the first axis
    :param labels: correct class per snapshot
    :param mask: per-snapshot flag; falsy marks a rearrangement error
    """
    model.eval()

    # What fraction of all snapshots have rearrangement errors?
    rearrange_err = torch.count_nonzero(torch.logical_not(mask)) / len(mask)

    # Of the snapshots the model is incorrect on, what fraction have
    # rearrangement error?
    batch_size = 256
    wrong_snaps = []
    wrong_labels = []  # Label OF the correct one, not the incorrect label
    wrong_scores = []
    wrong_mask = []

    with torch.no_grad():
        for start in range(0, len(snaps), batch_size):
            window = slice(start, start + batch_size)
            batch = snaps[window].to(device='cuda')
            batch_mask = mask[window]
            batch_labels = labels[window]

            scores = torch.nn.functional.softmax(model(batch), dim=-1)
            predicted = scores.argmax(dim=1).cpu()
            is_wrong = torch.not_equal(predicted, batch_labels)

            wrong_snaps.append(batch[is_wrong])
            wrong_labels.append(batch_labels[is_wrong])
            wrong_scores.append(scores[is_wrong])
            wrong_mask.append(batch_mask[is_wrong])

    incorrect_snaps = torch.cat(wrong_snaps)
    incorrect_labels = torch.cat(wrong_labels)
    incorrect_scores = torch.cat(wrong_scores)
    incorrect_mask = torch.cat(wrong_mask)

    n_wrong_with_err = torch.count_nonzero(torch.logical_not(incorrect_mask))
    incorrect_rearrange_err = n_wrong_with_err / len(incorrect_mask)

    print("Orig Rearrange Err:", rearrange_err, ", Inc Rearrange Err:",
          incorrect_rearrange_err)
    incorrect_slideshow(incorrect_snaps, incorrect_labels, incorrect_scores,
                        incorrect_mask)
def get_reward_sum_gae(self, buf_len, ten_reward, ten_mask,
                       ten_value) -> (torch.Tensor, torch.Tensor):
    """
    Calculate the **reward-to-go** and **advantage estimation** using GAE.

    Both results are filled backwards, from the last step to the first.

    :param buf_len: the number of steps to process.
    :param ten_reward: rewards for the state-action pairs.
    :param ten_mask: masks (per the original notes: done signal times
        discount factor); a zero entry cuts the recursion at that step.
    :param ten_value: state values estimated by the ``Critic`` network.
    :return: the reward-to-go and advantage estimation.
    """
    reward_to_go = torch.empty(buf_len, dtype=torch.float32,
                               device=self.device)
    advantage = torch.empty(buf_len, dtype=torch.float32, device=self.device)

    # 1.0 where the mask is non-zero, 0.0 otherwise.
    not_done = torch.not_equal(ten_mask, 0).float()

    next_reward_sum = 0
    next_adv = 0  # advantage term carried over from the following step
    for step in reversed(range(buf_len)):
        reward_to_go[step] = ten_reward[step] + ten_mask[step] * next_reward_sum
        next_reward_sum = reward_to_go[step]

        advantage[step] = ten_reward[step] + not_done[step] * (
            next_adv - ten_value[step])
        next_adv = ten_value[step] + advantage[step] * self.lambda_gae_adv

    return reward_to_go, advantage
def loss_function_(real, pred, pad_index):
    '''
    Cross-entropy over a batch of sequences with padded targets zeroed out.

    :param real: target ids, shape (batch_size, sen_len - 1)
    :param pred: raw (un-softmaxed) scores,
        shape (batch_size, sen_len - 1, vocab_size)
    :param pad_index: target id that marks padding
    :return: mean of the per-position losses; padded positions contribute 0
        to the sum but still count in the denominator
    '''
    loss_object = nn.CrossEntropyLoss(reduction='none')
    mask = torch.not_equal(real, pad_index)

    # Type conversion: CrossEntropyLoss needs integer (long) class targets.
    real = real.type(torch.long)

    # Pitfall: the input to PyTorch's CrossEntropyLoss must NOT be passed
    # through softmax first.
    # For a 3-D input, CrossEntropyLoss expects shape (batch_size, C, K),
    # i.e. the last dimension of the input must match the last dimension of
    # the target. C is the number of classes and K is the sequence
    # dimension (sen_len here). See the official documentation.
    pred = pred.transpose(1, 2)  # (batch_size, vocab_size, sen_len - 1)
    loss_ = loss_object(pred, real)
    loss_ *= mask
    return torch.mean(loss_)
def infer(self, input_i, input_iu, input_hr, input_ht):
    """
    Build the knowledge and CF predictions for a batch of items and return
    the loss computed by ``self.cal_loss``.

    :param input_i: items
    :param input_iu: user ids per item; entries equal to ``self.n_users``
        are masked out
    :param input_hr: relation ids per item; entries equal to
        ``self.n_relations`` are masked out
    :param input_ht: tail-entity ids per item, aligned with ``input_hr``
    :return: total loss
    """
    item_vec = self.Item(input_i).reshape(-1, self.emb_size)

    # weights
    c = self.negative_c[input_i]
    ck = self.negative_ck[input_i]

    # Dropout
    item_vec_kg = self.dropout_kg(item_vec)

    # >>> knowledge, cal g_{hrt}
    relation_vecs = self.Relation(input_hr)
    tail_vecs = self.Entity(input_ht)
    # mask, for useless values 0.0, others 1.0
    keep_rel = torch.not_equal(input_hr, self.n_relations).float()
    relation_vecs = torch.einsum("ab, abc->abc", keep_rel, relation_vecs)
    tail_vecs = torch.einsum("ab, abc->abc", keep_rel, tail_vecs)
    # Equ (6); g_hrt is g_{hrt}^ in the paper
    g_hrt = torch.einsum("ac, abc->ab", item_vec_kg,
                         relation_vecs * tail_vecs)
    g_hrt = g_hrt.reshape(-1, self.max_i_r)
    # <<< knowledge

    # >> CF, cal y_{uv}
    # >>> cal items' representation q_v
    attention = self.cal_att_weight(relation_vecs, tail_vecs, keep_rel)
    # Equ (10), e_{N_v}
    neighbourhood = (attention * tail_vecs).sum(dim=1)
    neighbourhood = self.dropout_kg(neighbourhood)
    # Equ (10), e_v
    item_vec_cf = self.dropout_cf(item_vec)
    # Equ (10)
    item_vec_qv = item_vec_cf + neighbourhood
    # <<< cal items' representation q_v

    user_vecs = self.User(input_iu)
    # mask
    keep_user = torch.not_equal(input_iu, self.n_users).float()
    user_vecs = torch.einsum("ab, abc->abc", keep_user, user_vecs)

    # Equ (9), predict y_uv
    y_uv = torch.einsum("ac, abc->abc", item_vec_qv, user_vecs)
    y_uv = torch.einsum("ajk, kl->ajl", y_uv, self.pre_vec)
    y_uv = y_uv.reshape(-1, self.max_i_u)
    # << CF, cal y_{uv}

    return self.cal_loss(g_hrt=g_hrt,
                         item_batch_ev=item_vec_cf,
                         item_batch_qv=item_vec_qv,
                         y_uv=y_uv,
                         c=c,
                         ck=ck)
def __mask_assign_targets_anchors_torch(
        self, batch_points, batch_anchors_3d, batch_gt_boxes_3d,
        batch_gt_labels, minibatch_size, positive_rate, pos_iou, neg_iou,
        effective_sample_range, valid_mask):
    """ Mask assign targets function

    For every sample, each point is assigned the first ground-truth box that
    contains it (according to check_inside_points). Positive anchors are
    those whose point lies inside a box, whose centre is within
    effective_sample_range of the assigned box centre and whose class slot
    matches the assigned label; negative anchors are those whose point lies
    inside no box. When minibatch_size != -1 the positives/negatives are
    randomly sub-sampled with np.random.choice.

    batch_points: [bs, points_num, 3]
    batch_anchors_3d: [bs, points_num, cls_num, 7]
    batch_gt_boxes_3d: [bs, gt_num, 7]
    batch_gt_labels: [bs, gt_num]
    minibatch_size: number of points to sample per item, -1 disables sampling
    positive_rate: fraction of minibatch_size reserved for positives
    pos_iou, neg_iou: not used in this function
    effective_sample_range: maximum anchor-to-box centre distance
    valid_mask: [bs, points_num, cls_num]

    return:
        assigned_idx: [bs, points_num, cls_num], int32, the index of groundtruth
        assigned_pmask: [bs, points_num, cls_num], float32
        assigned_nmask: [bs, points_num, cls_num], float32
    """
    bs, pts_num, cls_num, _ = batch_anchors_3d.shape
    positive_size = int(minibatch_size * positive_rate)
    batch_assigned_idx = torch.zeros([bs, pts_num, cls_num
                                      ]).long().to(batch_points.device)
    batch_assigned_pmask = torch.zeros([bs, pts_num, cls_num
                                        ]).float().to(batch_points.device)
    batch_assigned_nmask = torch.zeros([bs, pts_num, cls_num
                                        ]).float().to(batch_points.device)
    for i in range(bs):
        cur_points = batch_points[i]
        cur_anchors_3d = batch_anchors_3d[i]  # [pts_num, cls_num, 3/7]
        cur_valid_mask = valid_mask[i]  # [pts_num, cls_num]
        # gt_num
        cur_gt_labels = batch_gt_labels[i]  # [gt_num]
        cur_gt_boxes_3d = batch_gt_boxes_3d[i]  # [gt_num, 7]
        # first filter gt_boxes: keep rows with at least one non-zero entry
        filter_idx = torch.where(
            torch.any(torch.not_equal(cur_gt_boxes_3d, 0),
                      dim=-1))[0].to(cur_gt_labels.device)
        cur_gt_labels = cur_gt_labels[filter_idx]
        cur_gt_boxes_3d = cur_gt_boxes_3d[filter_idx]
        cur_points_numpy = cur_points.cpu().detach().numpy()
        cur_gt_boxes_3d_numpy = cur_gt_boxes_3d.cpu().detach().numpy()
        points_mask_numpy = check_inside_points(
            cur_points_numpy, cur_gt_boxes_3d_numpy)  # [pts_num, gt_num]
        points_mask = torch.from_numpy(points_mask_numpy).int().to(
            cur_points.device)
        # argmax returns the first containing box, or 0 when the point is in
        # no box; such points are excluded from pmask further down.
        sampled_gt_idx_numpy = np.argmax(points_mask_numpy, axis=-1)
        sampled_gt_idx = torch.from_numpy(sampled_gt_idx_numpy).long().to(
            cur_points.device)  # [pts_num]
        # used for label_mask
        assigned_gt_label = cur_gt_labels[sampled_gt_idx]  # [pts_num]
        assigned_gt_label = assigned_gt_label - 1  # 1... -> 0...
        # used for dist_mask
        assigned_gt_boxes = cur_gt_boxes_3d[sampled_gt_idx]  # [pts_num, 7]
        # then calc the distance between anchors and assigned_boxes
        # dist = cur_anchors_3d[:, :, :3] - assigned_gt_boxes[:, 0:3].unsqueeze(dim=1).repeat((1, cur_anchors_3d.shape[1], 1))
        # dist = torch.sqrt(torch.sum(dist * dist, dim=-1))
        dist = torch.linalg.norm(
            cur_anchors_3d[:, :, :3] -
            assigned_gt_boxes[:, 0:3].unsqueeze(dim=1).repeat(
                (1, cur_anchors_3d.shape[1], 1)),
            dim=-1)
        # Map positions in the filtered list back to indices in the
        # unfiltered ground-truth list.
        filtered_assigned_idx = filter_idx[sampled_gt_idx]  # [pts_num]
        filtered_assigned_idx = filtered_assigned_idx.view(pts_num, 1).repeat(
            (1, cls_num))
        batch_assigned_idx[i] = filtered_assigned_idx
        # then we generate pos/neg mask
        if cls_num == 1:  # anchor_free
            label_mask = torch.ones(
                (pts_num, cls_num)).float().to(points_mask.device)
        else:  # multiple anchors
            # NOTE(review): this branch builds label_mask with NumPy from a
            # torch tensor (assigned_gt_label) and later multiplies it with
            # torch tensors; presumably this only works for CPU tensors --
            # confirm.
            label_mask = np.tile(
                np.reshape(np.arange(cls_num), [1, cls_num]), [pts_num, 1])
            label_mask = np.equal(label_mask,
                                  assigned_gt_label[:, np.newaxis]).astype(
                                      np.float32)
        # positive: inside some box AND close enough AND label slot matches
        pmask = torch.max(points_mask, dim=1)[0] > 0
        dist_mask = torch.less_equal(
            dist, effective_sample_range)  # pts_num, cls_num
        pmask = torch.logical_and(pmask.unsqueeze(-1), dist_mask)
        pmask = pmask.float() * label_mask
        pmask = pmask * cur_valid_mask
        # negative: inside no box at all
        nmask = torch.max(points_mask, dim=1)[0] == 0
        nmask = nmask.view(pts_num, 1).repeat((1, cls_num))
        nmask = nmask.float() * label_mask
        nmask = nmask * cur_valid_mask
        # then randomly sample
        if minibatch_size != -1:
            # NOTE(review): pmask/nmask are torch tensors passed to NumPy
            # functions here; this relies on implicit conversion --
            # confirm it behaves as intended on the devices in use.
            pts_pmask = np.any(pmask, axis=1)  # pts_num
            pts_nmask = np.any(nmask, axis=1)  # [pts_num]
            positive_inds = np.where(pts_pmask)[0]
            cur_positive_num = np.minimum(len(positive_inds), positive_size)
            if cur_positive_num > 0:
                positive_inds = np.random.choice(positive_inds,
                                                 cur_positive_num,
                                                 replace=False)
            pts_pmask = np.zeros_like(pts_pmask)
            pts_pmask[positive_inds] = 1
            # negatives fill whatever part of the minibatch is left
            cur_negative_num = minibatch_size - cur_positive_num
            negative_inds = np.where(pts_nmask)[0]
            cur_negative_num = np.minimum(len(negative_inds),
                                          cur_negative_num)
            if cur_negative_num > 0:
                negative_inds = np.random.choice(negative_inds,
                                                 cur_negative_num,
                                                 replace=False)
            pts_nmask = np.zeros_like(pts_nmask)
            pts_nmask[negative_inds] = 1
            pmask = pmask * pts_pmask[:, np.newaxis]
            nmask = nmask * pts_nmask[:, np.newaxis]
        batch_assigned_pmask[i] = pmask
        batch_assigned_nmask[i] = nmask
    return batch_assigned_idx, batch_assigned_pmask, batch_assigned_nmask
def _train_model(model: Any, criterion, optimizer, scheduler, dl, img_datasets):
    """Train ``model`` for ``EPOCH`` epochs and keep the best validation weights.

    Every epoch runs a ``'train'`` phase followed by a ``'val'`` phase. The
    scheduler is stepped once after each training phase. The weights with the
    highest validation accuracy are restored before returning.

    :param model: network to train
    :param criterion: loss called as ``criterion(outputs, labels)``
    :param optimizer: optimizer stepped after every training batch
    :param scheduler: learning-rate scheduler
    :param dl: mapping from phase name to an iterable of ``(imgs, labels)``
    :param img_datasets: mapping from phase name to a sized dataset
    :return: ``(model, stats)``; ``stats`` holds one list per epoch with the
        train loss, train accuracy, val loss and val accuracy
    """
    started_at = time.time()
    stats = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(EPOCH):
        epoch_stat = []
        if verbose:
            print('Epoch {}/{}'.format(epoch, EPOCH - 1))
            print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            loss_total = 0.0
            n_correct = 0  # predictions equal to the label
            n_incorrect = 0  # predictions different from the label

            # Iterate over data.
            for imgs, labels in dl[phase]:
                torch.cuda.empty_cache()  # clean up cache
                imgs = imgs.float().to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in the training phase
                with torch.set_grad_enabled(is_training):
                    outputs = model(imgs)
                    preds = torch.max(outputs, 1)[1]
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_training:
                        loss.backward()
                        optimizer.step()

                # statistics
                loss_total += loss.item() * imgs.size(0)
                n_correct += torch.eq(preds, labels.data).sum()
                n_incorrect += torch.not_equal(preds, labels.data).sum()

            if is_training:
                scheduler.step()

            dataset_size = len(img_datasets[phase])
            # n_correct   = True  (true positives + true negatives)
            # n_incorrect = False (false positives + false negatives)
            epoch_loss = loss_total / dataset_size
            epoch_acc = n_correct.double() / dataset_size
            epoch_err = n_incorrect.double() / dataset_size  # not reported

            if verbose:
                print("Running corrects: %s" % n_correct.item())
                print("Running incorrects: %s" % n_incorrect.item())
                print("Dataset size: %s" % dataset_size)
                print('{} Loss: {:.4f} Acc: {:.2f}%\n'.format(
                    phase, epoch_loss, epoch_acc * 100))

            epoch_stat.extend([epoch_loss, epoch_acc.item()])

            if phase == 'val':
                stats.append(epoch_stat)
                # deep copy the model when validation accuracy improves
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - started_at
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:2f}'.format(best_acc * 100))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, stats
def __ne__(self, other):
    """Element-wise ``!=`` against ``other``.

    The comparison is evaluated with torch on the operands' ``_t`` tensors
    and mirrored as an ``_ox.not_equal`` node; both are wrapped together via
    ``from_torch``.
    """
    lhs, rhs = self._to_binary_tensor_args(other)
    torch_result = torch.not_equal(lhs._t, rhs._t)
    ox_args = _EagerTensor.ox_args([lhs, rhs])
    ox_result = _ox.not_equal(*ox_args)
    return self.from_torch(torch_result, ox_result)
def maskedCrossEntropy(logits, targets_sparse, targets_mask):
    """Mean cross-entropy with positions whose target id is 0 zeroed out.

    :param logits: scores; axes 1 and 2 are swapped before being passed to
        ``crossEntropy``
    :param targets_sparse: target ids; id 0 is treated as padding
    :param targets_mask: accepted but not used -- the mask is derived from
        ``targets_sparse``
    :return: mean over all positions of the masked per-position loss
    """
    per_position = crossEntropy(logits.transpose(1, 2), targets_sparse)
    keep = torch.not_equal(targets_sparse, 0).float()
    return (keep * per_position).mean()
def mask_tokens_span(self, inputs: torch.Tensor, mask_labels: torch.Tensor,
                     attention_mask) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Prepare span-masked inputs for masked language modeling.

    Positions flagged in ``mask_labels`` (whole-word-mask reference), minus
    special tokens and padding, are corrupted: 80% MASK, 10% random,
    10% original. Runs of consecutive mask tokens are then collapsed so that
    one mask token represents a whole span, which makes the rows ragged.

    ``inputs`` and ``mask_labels`` are modified in place.

    :param inputs: 2-D tensor of token ids
    :param mask_labels: tensor shaped like ``inputs``; non-zero marks a
        position to mask
    :param attention_mask: tensor shaped like ``inputs``, or ``None``
    :return: ``(inputs, attention_mask)`` where each is a list with one 1-D
        tensor per row (``attention_mask`` stays ``None`` if it was ``None``)
    :raises ValueError: if the tokenizer has no mask token
    """
    if self.tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
        )

    device = inputs.device
    n_rows = inputs.shape[0]

    # The positions to mask come from mask_labels rather than being sampled.
    probability_matrix = mask_labels
    special_tokens_mask = [
        self.tokenizer.get_special_tokens_mask(
            val, already_has_special_tokens=True) for val in inputs.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(
        special_tokens_mask, dtype=torch.bool,
        device=probability_matrix.device),
                                    value=0.0)
    if self.tokenizer._pad_token is not None:
        padding_mask = inputs.eq(self.tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = probability_matrix.bool()

    # TODO: loss is currently computed on all labels, so no label tensor is
    # built here.

    # 80% of the time, we replace masked input tokens with
    # tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(
        torch.full(inputs.shape, 0.8,
                   device=device)).bool() & masked_indices
    mask_token_id = self.tokenizer.convert_tokens_to_ids(
        self.tokenizer.mask_token)
    inputs[indices_replaced] = mask_token_id

    # 10% of the time, we replace masked input tokens with random word
    # (half of the remaining 20%)
    indices_random = torch.bernoulli(
        torch.full(inputs.shape, 0.5, device=device)).bool(
        ) & masked_indices & ~indices_replaced
    random_words = torch.randint(len(self.tokenizer),
                                 inputs.shape,
                                 dtype=torch.long,
                                 device=device)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens
    # unchanged.

    # Remove consecutive duplicate mask tokens. One mask token represents
    # whole span. The padding column uses the dtype/device of `inputs`
    # (new_zeros) so the ids are not promoted to float and the tensors stay
    # on one device.
    inputs_left_shift = torch.cat(
        (inputs[:, 1:], inputs.new_zeros((n_rows, 1))), dim=-1)
    mask_left_shift = torch.not_equal((inputs - inputs_left_shift), 0)
    # Keep a position if it differs from its left neighbour or is not a
    # mask token; the first column is always kept.
    first_column = torch.full((n_rows, 1),
                              True,
                              dtype=torch.bool,
                              device=device)
    mask = torch.cat((first_column, mask_left_shift[:, :-1]),
                     dim=-1) | torch.not_equal(inputs, mask_token_id)

    inputs = [
        torch.masked_select(inputs[i, :], mask[i, :]) for i in range(n_rows)
    ]
    if attention_mask is not None:
        attention_mask = [
            torch.masked_select(attention_mask[i, :], mask[i, :])
            for i in range(attention_mask.shape[0])
        ]
    return inputs, attention_mask
def training(retrain=None):
    """Run the training loop on the module-level model and data.

    Relies on names defined outside this function (``model``, ``adam``,
    ``scheduler``, ``criterion``, ``inputs_batched``, ``outputs_batched``,
    ``batch_size``, ``np2tens``, ``vocabulary_inversed``, ``dataset_name``).
    Per epoch it shuffles the batches, back-propagates every batch, steps the
    optimizer every ``accumulate`` global steps, logs to TensorBoard, appends
    a row to ``./training/movie/training-log.csv`` and saves a checkpoint.

    :param retrain: optional path to a checkpoint whose ``"model_state"`` is
        loaded (onto CPU) before training starts
    """
    print("Ok, we are ready to train. On your go.")
    # NOTE(review): drops into the debugger on every call; execution only
    # proceeds once the user continues.
    breakpoint()
    if retrain is not None:
        checkpoint = torch.load(retrain, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
    epochs = 100000
    reporting = 2  # not used below
    accumulate = 4  # optimizer step every `accumulate` global steps
    version = "DEC212020_1_NODEC"
    # Short random ids: last 5 characters of a UUID4.
    modelID = str(uuid.uuid4())[-5:]
    initialRuntime = time.time()
    writer = SummaryWriter(f'./training/movie/logs/{modelID}')
    # random.shuffle(zipped_dataset)
    model.train()  # duh
    for epoch in range(epochs):
        #
        # if (epoch % 3 == 0) and epoch != 0:
        #     print(f'Taking a 15 min fridge break before starting at {epoch}...')
        #     for _ in tqdm(range(60*15)):
        #         time.sleep(1)
        #     print(f'Fridge break done. Let\'s get cracking on epoch {epoch}')
        checkpointID = str(uuid.uuid4())[-5:]
        batch_data_group = list(zip(inputs_batched, outputs_batched))
        random.shuffle(batch_data_group)
        batch_data_feed = tqdm(enumerate(batch_data_group),
                               total=len(inputs_batched))
        for batch, (inp, oup) in batch_data_feed:
            encinp_torch = np2tens(inp)
            decinp_torch = np2tens(oup)
            # Target = decoder input shifted left by one position, with a
            # zero column appended at the end.
            padding_row = torch.zeros(batch_size, 1)
            oup_torch = (torch.cat((np2tens(oup)[:, 1:], padding_row),
                                   dim=1)).long()
            prediction = model(encinp_torch, decinp_torch, None,
                               int(batch_size))
            # 1.0 where the target id is non-zero, 0.0 elsewhere.
            target_mask = torch.not_equal(oup_torch, 0).float()
            # loss_matrix = torch.mean((prediction-torch.nn.functional.one_hot(oup_torch, len(vocabulary)))**2, 2)
            # loss_val = torch.mean(target_mask*loss_matrix)
            # powered_value = torch.pow(prediction-oup_vector, 2)
            # loss_val = torch.mean(target_mask.unsqueeze(-1).expand_as(powered_value)*powered_value)
            loss_val = criterion(prediction, oup_torch, target_mask)
            # target_mask = torch.not_equal(oup_torch, 0).float()
            # loss_matrix = torch.mean((prediction-torch.nn.functional.one_hot(oup_torch, len(vocabulary)))**2, 2)
            # loss_val = torch.mean(target_mask*loss_matrix)
            loss_val.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
            # Only the first sequence of the batch is decoded for logging.
            prediction_values = np.array(torch.argmax(prediction,
                                                      2).cpu())[:1]
            # NOTE(review): the original nesting of the statements after
            # this `if` is not recoverable from the collapsed source;
            # presumably only the optimizer step is conditional -- confirm.
            if ((batch + (epoch * len(inputs_batched))) %
                    accumulate) == 0 and batch != 0:
                adam.step()
                adam.zero_grad()
            # prediction_values = np.array(torch.argmax(prediction,2).cpu())[:1]
            # Map predicted ids back to words; unknown ids become "<err>".
            prediction_sentences = []
            for e in prediction_values:
                prediction_value = []
                for i in e:
                    try:
                        prediction_value.append(vocabulary_inversed[i])
                    except KeyError:
                        prediction_value.append("<err>")
                prediction_sentences.append(prediction_value)
            final_sent = ""
            for word in prediction_sentences[0]:
                final_sent = final_sent + word + " "
            writer.add_scalar('Train/loss', loss_val.item(),
                              batch + (epoch * len(inputs_batched)))
            writer.add_text('Train/sample', final_sent,
                            batch + (epoch * len(inputs_batched)))
            batch_data_feed.set_description(
                f'| Model: {modelID}@{checkpointID} | Epoch: {epoch} | Batch: {batch} | Loss: {loss_val:.5f} |'
            )
            #plot_grad_flow(model.named_parameters())
        # CheckpointID,ModelID,ModelVersion,Dataset,Initial Runtime,Current Time,Epoch,Loss,Checkpoint Filename
        initialHumanTime = datetime.fromtimestamp(initialRuntime).strftime(
            "%m/%d/%Y, %H:%M:%S")
        nowHumanTime = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
        with open("./training/movie/training-log.csv", "a+") as df:
            csvfile = csv.writer(df)
            csvfile.writerow([
                checkpointID, modelID, version, dataset_name,
                initialHumanTime, nowHumanTime, epoch,
                loss_val.item(), f'{modelID}-{checkpointID}.model',
                f'{retrain}'
            ])
        # One checkpoint file per epoch, named after the model/checkpoint ids.
        torch.save(
            {
                'version': version,
                'modelID': modelID,
                'checkpointID': checkpointID,
                'datasetName': dataset_name,
                'epoch': epoch,
                'loss': loss_val,
                'model_state': model.state_dict(),
                'optimizer_state': adam.state_dict(),
                'lr': scheduler.get_last_lr()
            }, f'./training/movie/{modelID}-{checkpointID}.model')
        print(f'| EPOCH DONE | Epoch: {epoch} | Loss: {loss_val} |')
        scheduler.step()
    writer.close()
def kspace_mask(image_orig, name="kspace_mask", dtype=None):
    """Find k-space mask.

    :param image_orig: tensor whose non-zero entries define the mask
    :param name: accepted for interface compatibility; not used
    :param dtype: optional torch dtype to convert the mask to; when ``None``
        the boolean mask is returned unchanged
    :return: tensor shaped like ``image_orig`` that is true (or 1 after
        conversion) where ``image_orig != 0``
    """
    mask_x = torch.not_equal(image_orig, 0)
    if dtype is not None:
        # torch has no ``cast`` function; ``Tensor.to`` performs the dtype
        # conversion.
        mask_x = mask_x.to(dtype)
    return mask_x
<<<<<<< HEAD assert inputs.dtype == torch.int8 or inputs.dtype == torch.int16 or \ inputs.dtype == torch.int32 or inputs.dtype == torch.int64 >>>>>>> upstream/master ======= assert ( inputs.dtype == torch.int8 or inputs.dtype == torch.int16 or inputs.dtype == torch.int32 or inputs.dtype == torch.int64 ) >>>>>>> upstream/master assert len(inputs.shape) == 2 if self.pad_idx is not None: inputs_mask = torch.not_equal(inputs, self.pad_idx) else: inputs_mask = torch.not_equal(inputs, 0) inputs_exp = inputs.type(torch.int32) lengths = torch.sum(inputs_mask.type(torch.int32), dim=1) encoder_output = self.encoder_obj(inputs_exp, mask=inputs_mask) encoder_output[LENGTHS] = lengths <<<<<<< HEAD return encoder_output @property def input_dtype(self): return torch.int32 =======