def validate(self, valid_iter, step=0):
    """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    # Set model in validating mode.
    self.model.eval()
    if self.args.acc_reporter == 1:
        stats = acc_reporter.Statistics()
    else:
        stats = Statistics()

    with torch.no_grad():
        for batch in valid_iter:
            src = batch.src
            labels = batch.src_sent_labels
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls

            if self.args.ext_sum_dec:
                # sent_scores: [B, tgt_len, n_sent]
                sent_scores, mask = self.model(src, segs, clss, mask,
                                               mask_cls, labels)
                tgt_len = 3
                _, labels_id = torch.topk(labels, k=tgt_len)  # [B, tgt_len]
                labels_id, _ = torch.sort(labels_id)
                # Linearly increasing position weights; defaults are
                # max_src_nsents=100, weight_up=20.
                weight = torch.linspace(
                    start=1, end=self.args.weight_up,
                    steps=self.args.max_src_nsents).type_as(sent_scores)
                weight = weight[:sent_scores.size(-1)]
                loss = F.nll_loss(
                    F.log_softmax(
                        sent_scores.view(-1, sent_scores.size(-1)),
                        dim=-1,
                        dtype=torch.float32,
                    ),
                    labels_id.view(-1),  # [B * tgt_len]
                    weight=weight,
                    reduction='sum',
                    ignore_index=-1,
                )
                prediction = torch.argmax(sent_scores, dim=-1)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info('valid prediction: %s |label %s '
                                % (str(prediction), str(labels_id)))
                # Elementwise comparison of predicted vs. oracle indices.
                accuracy = torch.div(
                    torch.sum(torch.eq(prediction, labels_id).float()), tgt_len)
            else:
                sent_scores, mask = self.model(src, segs, clss, mask,
                                               mask_cls)  # [B, n_sent]
                loss = self.loss(sent_scores, labels.float())
                loss = (loss * mask.float()).sum()
                tgt_len = 3
                _, labels_id = torch.topk(labels, k=tgt_len)  # [B, tgt_len]
                labels_id, _ = torch.sort(labels_id)
                _, prediction = torch.topk(sent_scores, k=tgt_len)
                # Sort predicted indices so they align with the sorted labels.
                prediction, _ = torch.sort(prediction)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info('valid prediction: %s |label %s '
                                % (str(prediction), str(labels_id)))
                accuracy = torch.div(
                    torch.sum(torch.eq(prediction, labels_id).float()), tgt_len)

            if self.args.acc_reporter == 1:
                batch_stats = acc_reporter.Statistics(
                    float(loss.cpu().data.numpy()), accuracy, len(labels))
            else:
                batch_stats = Statistics(
                    float(loss.cpu().data.numpy()), len(labels))
            stats.update(batch_stats)

        self._report_step(0, step, valid_stats=stats)
        return stats
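
# A minimal standalone sketch (not part of the trainer) of the top-k accuracy
# computed in validate() above: oracle sentence labels are turned into sorted
# indices via torch.topk + torch.sort, then compared elementwise with
# torch.eq. All tensor values below are illustrative assumptions.
def _topk_accuracy_sketch():
    import torch
    sent_scores = torch.tensor([[0.1, 0.7, 0.2, 0.9, 0.3]])  # [B=1, n_sent=5]
    labels = torch.tensor([[0., 1., 0., 1., 1.]])            # oracle labels
    tgt_len = 3
    _, labels_id = torch.topk(labels, k=tgt_len)             # [1, 3]
    labels_id, _ = torch.sort(labels_id)                     # [[1, 3, 4]]
    _, prediction = torch.topk(sent_scores, k=tgt_len)
    prediction, _ = torch.sort(prediction)                   # [[1, 3, 4]]
    # Fraction of matching positions among the k selected sentences.
    return torch.sum(torch.eq(prediction, labels_id).float()) / tgt_len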
def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                           report_stats):
    if self.grad_accum_count > 1:
        self.model.zero_grad()

    for batch in true_batchs:
        if self.grad_accum_count == 1:
            self.model.zero_grad()

        if self.args.jigsaw == 'jigsaw_lab':
            # jigsaw_lab: every sentence independently predicts its own
            # position (an earlier, unsuccessful variant).
            logits = self.model(batch.src_s, batch.segs_s, batch.clss_s,
                                batch.mask_src_s,
                                batch.mask_cls_s)  # [bsz, n_sent, max_n_sent]
            loss = F.nll_loss(
                F.log_softmax(
                    logits.view(-1, logits.size(-1)),
                    dim=-1,
                    dtype=torch.float32,
                ),
                batch.poss_s.view(-1),  # [bsz * n_sent]
                reduction='sum',
                ignore_index=-1,
            )
            prediction = torch.argmax(logits, dim=-1)
            if (self.optim._step + 1) % self.args.print_every == 0:
                logger.info('train prediction: %s |label %s '
                            % (str(prediction), str(batch.poss_s)))
            accuracy = torch.div(
                torch.sum(torch.eq(prediction, batch.poss_s) * batch.mask_cls_s),
                torch.sum(batch.mask_cls_s)) * len(logits)
        else:  # self.args.jigsaw == 'jigsaw_dec': jigsaw decoder
            poss_s = batch.poss_s
            mask_poss = torch.eq(poss_s, -1)
            # Push padded positions to the end of the sort:
            # poss_s[i] [5,1,4,0,2,3,-1,-1] -> [5,1,4,0,2,3,10000,10000],
            # so dec_labels[i] = [3,1,...,6,7] is the decoding order.
            poss_s.masked_fill_(mask_poss, 10000)
            dec_labels = torch.argsort(poss_s, dim=1)
            logits, _ = self.model(batch.src_s, batch.segs_s, batch.clss_s,
                                   batch.mask_src_s, batch.mask_cls_s,
                                   dec_labels)
            # Re-mask padded slots to -1 so nll_loss(ignore_index=-1) skips them.
            final_dec_labels = dec_labels.masked_fill(mask_poss, -1)
            loss = F.nll_loss(
                F.log_softmax(
                    logits.view(-1, logits.size(-1)),
                    dim=-1,
                    dtype=torch.float32,
                ),
                final_dec_labels.view(-1),  # [bsz * n_sent]
                reduction='sum',
                ignore_index=-1,
            )
            prediction = torch.argmax(logits, dim=-1)
            if (self.optim._step + 1) % self.args.print_every == 0:
                logger.info('train prediction: %s |label %s '
                            % (str(prediction), str(batch.poss_s)))
            accuracy = torch.div(
                torch.sum(torch.eq(prediction, final_dec_labels) * batch.mask_cls_s),
                torch.sum(batch.mask_cls_s)) * len(logits)

        with amp.scale_loss(loss / loss.numel(),
                            self.optim.optimizer) as scaled_loss:
            scaled_loss.backward()

        if self.args.acc_reporter:
            batch_stats = acc_reporter.Statistics(
                float(loss.cpu().data.numpy()), accuracy, normalization)
        else:
            batch_stats = Statistics(
                float(loss.cpu().data.numpy()), normalization)

        total_stats.update(batch_stats)
        report_stats.update(batch_stats)

        # 4. Update the parameters and statistics.
        if self.grad_accum_count == 1:
            # Multi GPU gradient gather
            if self.n_gpu > 1:
                grads = [p.grad.data for p in self.model.parameters()
                         if p.requires_grad and p.grad is not None]
                distributed.all_reduce_and_rescale_tensors(grads, float(1))
            self.optim.step()

    # In case of multi-step gradient accumulation,
    # update only after accum batches.
    if self.grad_accum_count > 1:
        if self.n_gpu > 1:
            grads = [p.grad.data for p in self.model.parameters()
                     if p.requires_grad and p.grad is not None]
            distributed.all_reduce_and_rescale_tensors(grads, float(1))
        self.optim.step()
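
# A minimal standalone sketch (illustrative values) of the jigsaw_dec label
# construction in _gradient_accumulation() above: padded positions (-1) are
# filled with a large value so argsort pushes them to the end, then re-masked
# to -1 so the loss ignores them.
def _dec_labels_sketch():
    import torch
    poss_s = torch.tensor([[5, 1, 4, 0, 2, 3, -1, -1]])
    mask_poss = torch.eq(poss_s, -1)
    poss_s = poss_s.masked_fill(mask_poss, 10000)  # [5,1,4,0,2,3,10000,10000]
    dec_labels = torch.argsort(poss_s, dim=1)      # [3,1,4,5,2,0,6,7]
    final_dec_labels = dec_labels.masked_fill(mask_poss, -1)  # [3,1,4,5,2,0,-1,-1]
    return final_dec_labels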
def train(self, train_iter_fct, train_steps, valid_iter_fct=None,
          valid_steps=-1):
    """
    The main training loop: iterate over training data (i.e. `train_iter_fct`)
    and run validation (i.e. iterate over `valid_iter_fct`).

    Args:
        train_iter_fct(function): a function that returns the train
            iterator, e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        valid_iter_fct(function): same as train_iter_fct, for valid data
        train_steps(int):
        valid_steps(int):
        save_checkpoint_steps(int):

    Return:
        None
    """
    logger.info('Start training...')
    step = self.optim._step + 1
    true_batchs = []
    accum = 0
    normalization = 0
    train_iter = train_iter_fct()

    if self.args.acc_reporter == 1:
        total_stats = acc_reporter.Statistics()
        report_stats = acc_reporter.Statistics()
    else:
        total_stats = Statistics()
        report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    while step <= train_steps:
        reduce_counter = 0
        for i, batch in enumerate(train_iter):
            if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                true_batchs.append(batch)
                normalization += batch.batch_size
                accum += 1
                if accum == self.grad_accum_count:
                    # `step` counts optimizer updates (num_updates).
                    reduce_counter += 1
                    if self.n_gpu > 1:
                        normalization = sum(distributed
                                            .all_gather_list(normalization))

                    self._gradient_accumulation(
                        true_batchs, normalization, total_stats,
                        report_stats)

                    report_stats = self._maybe_report_training(
                        step, train_steps,
                        self.optim.learning_rate,
                        report_stats)

                    true_batchs = []
                    accum = 0
                    normalization = 0
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)

                    step += 1
                    if step > train_steps:
                        break
        train_iter = train_iter_fct()

    return total_stats
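
# Hedged usage sketch (comments only): how train() is typically driven.
# `data_loader.Dataloader` and `load_dataset` are assumptions modeled on the
# surrounding PreSumm-style codebase; the essential point is that
# `train_iter_fct` must be a zero-argument callable returning a fresh
# iterator, so the while-loop above can restart iteration each epoch.
#
#   def train_iter_fct():
#       return data_loader.Dataloader(
#           args, load_dataset(args, 'train', shuffle=True),
#           args.batch_size, device, shuffle=True, is_test=False)
#
#   trainer.train(train_iter_fct, args.train_steps)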
def validate(self, valid_iter, step=0):
    """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    # Set model in validating mode.
    self.model.eval()
    if self.args.acc_reporter:
        stats = acc_reporter.Statistics()
    else:
        stats = Statistics()

    with torch.no_grad():
        for batch in valid_iter:
            if self.args.jigsaw == 'jigsaw_lab':
                # jigsaw_lab (note: validate had previously not been updated
                # to match training; rerun to verify).
                logits = self.model(batch.src_s, batch.segs_s, batch.clss_s,
                                    batch.mask_src_s,
                                    batch.mask_cls_s)  # [bsz, n_sent, max_n_sent]
                loss = F.nll_loss(
                    F.log_softmax(
                        logits.view(-1, logits.size(-1)),
                        dim=-1,
                        dtype=torch.float32,
                    ),
                    batch.poss_s.view(-1),  # [bsz * n_sent]
                    reduction='sum',
                    ignore_index=-1,
                )
                prediction = torch.argmax(logits, dim=-1)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info('valid prediction: %s |label %s '
                                % (str(prediction), str(batch.poss_s)))
                accuracy = torch.div(
                    torch.sum(torch.eq(prediction, batch.poss_s) * batch.mask_cls_s),
                    torch.sum(batch.mask_cls_s)) * len(logits)
            elif self.args.jigsaw == 'jigsaw_dec':  # jigsaw decoder
                poss_s = batch.poss_s
                mask_poss = torch.eq(poss_s, -1)
                # poss_s[i] [5,1,4,0,2,3,-1,-1] -> [5,1,4,0,2,3,10000,10000]
                poss_s.masked_fill_(mask_poss, 10000)
                dec_labels = torch.argsort(poss_s, dim=1)  # dec_labels[i] [3,1,...,6,7]
                logits, _ = self.model(batch.src_s, batch.segs_s, batch.clss_s,
                                       batch.mask_src_s, batch.mask_cls_s,
                                       dec_labels)
                # final_dec_labels[i] [3,1,...,-1,-1]
                final_dec_labels = dec_labels.masked_fill(mask_poss, -1)
                loss = F.nll_loss(
                    F.log_softmax(
                        logits.view(-1, logits.size(-1)),
                        dim=-1,
                        dtype=torch.float32,
                    ),
                    final_dec_labels.view(-1),  # [bsz * n_sent]
                    reduction='sum',
                    ignore_index=-1,
                )
                prediction = torch.argmax(logits, dim=-1)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info('valid prediction: %s |label %s '
                                % (str(prediction), str(batch.poss_s)))
                accuracy = torch.div(
                    torch.sum(torch.eq(prediction, final_dec_labels) * batch.mask_cls_s),
                    torch.sum(batch.mask_cls_s)) * len(logits)

            if self.args.acc_reporter:
                batch_stats = acc_reporter.Statistics(
                    float(loss.cpu().data.numpy()), accuracy,
                    len(batch.poss_s))
            else:
                batch_stats = Statistics(
                    float(loss.cpu().data.numpy()), len(batch.poss_s))
            stats.update(batch_stats)

        self._report_step(0, step, valid_stats=stats)
        return stats
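
# A minimal standalone sketch (illustrative values) of the masked position
# accuracy used in both jigsaw branches above: only real sentence slots
# (mask_cls == True) count toward the accuracy; padded slots are excluded.
def _masked_accuracy_sketch():
    import torch
    prediction = torch.tensor([[3, 1, 4, 5, 2, 0, 6, 7]])
    labels = torch.tensor([[3, 1, 4, 5, 2, 0, -1, -1]])
    mask_cls = torch.tensor([[True] * 6 + [False] * 2])
    correct = torch.eq(prediction, labels) & mask_cls
    return correct.sum().float() / mask_cls.sum().float()  # 1.0 here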