def train(self):
    for curr_epoch in range(self.start_epoch, self.end_epoch):
        train_loss_record = AvgMeter()
        self._train_per_epoch(curr_epoch, train_loss_record)

        # Adjust the learning rate on a per-epoch basis.
        if not self.arg_dict["sche_usebatch"]:
            self.sche.step()

        # Save a checkpoint every epoch; the stored parameters correspond to epoch curr_epoch + 1.
        save_checkpoint(
            model=self.net,
            optimizer=self.opti,
            scheduler=self.sche,
            amp=self.amp,
            exp_name=self.exp_name,
            current_epoch=curr_epoch + 1,
            full_net_path=self.path_dict["final_full_net"],
            state_net_path=self.path_dict["final_state_net"],
        )  # save parameters

    if self.arg_dict["use_amp"]:
        # https://github.com/NVIDIA/apex/issues/567
        with self.amp.disable_casts():
            construct_print("When evaluating, we wish to evaluate in pure fp32.")
            self.test()
    else:
        self.test()
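
# NOTE: the loops in this section rely on an AvgMeter running-average helper
# (train_loss_record.update(...) / .avg). The class below is a minimal sketch of
# such a helper for reference; it is an assumption, not necessarily the exact
# implementation used in this codebase.
class AvgMeter:
    """Tracks the most recent value and a size-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of all values
        self.count = 0   # total number of accumulated samples
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count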
def train(
    model,
    start_epoch,
    end_epoch,
    tr_loader,
    optimizer,
    scheduler,
    loss_funcs,
    train_sampler,
    local_rank,
):
    for curr_epoch in range(start_epoch, end_epoch):
        if user_config["is_distributed"]:
            train_sampler.set_epoch(curr_epoch)
        if not user_config["sche_usebatch"]:
            scheduler.step(optimizer=optimizer, curr_epoch=curr_epoch)

        train_epoch_prefetch_generator(
            curr_epoch,
            end_epoch,
            loss_funcs,
            model,
            optimizer,
            scheduler,
            tr_loader,
            local_rank,
        )

        if local_rank == 0:
            # NOTE: to verify the correctness of the model and the training process
            if (user_config["val_freq"] > 0) and (curr_epoch + 1) % user_config["val_freq"] == 0:
                _ = test(model, mode="val", save_pre=False)
            if ((user_config["save_freq"] > 0) and (curr_epoch + 1) % user_config["save_freq"] == 0) or (
                curr_epoch == end_epoch - 1
            ):
                save_checkpoint(
                    model=model,
                    optimizer=optimizer,
                    amp=amp if user_config["use_amp"] else None,
                    exp_name=exp_name,
                    current_epoch=curr_epoch + 1,
                    full_net_path=path_config["final_full_net"],
                    state_net_path=path_config["final_state_net"],
                )

    if local_rank == 0:
        total_results = test(model, mode="test", save_pre=user_config["save_pre"])
        xlsx_recorder.write_xlsx(exp_name, total_results)
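
# NOTE: both loops above assume a save_checkpoint helper that writes two files: a
# resumable checkpoint (model/optimizer/scheduler/amp state plus the epoch counter)
# and a weights-only state dict for inference. The function below is a hedged sketch
# of that assumed behaviour; the real helper in this codebase may store extra fields
# or handle DistributedDataParallel wrappers differently.
import torch


def save_checkpoint(model, optimizer, exp_name, current_epoch,
                    full_net_path, state_net_path, scheduler=None, amp=None):
    full_state = {
        "arch": exp_name,
        "epoch": current_epoch,
        "net_state": model.state_dict(),
        "opti_state": optimizer.state_dict(),
        "sche_state": scheduler.state_dict() if scheduler is not None else None,
        "amp_state": amp.state_dict() if amp is not None else None,
    }
    torch.save(full_state, full_net_path)           # full checkpoint, for resuming training
    torch.save(model.state_dict(), state_net_path)  # weights only, for inference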
def train(self):
    for curr_epoch in range(self.start_epoch, self.end_epoch):
        train_loss_record = AvgMeter()
        for train_batch_id, train_data in enumerate(self.tr_loader):
            curr_iter = curr_epoch * len(self.tr_loader) + train_batch_id

            self.opti.zero_grad()
            train_inputs, train_masks, *train_other_data = train_data
            train_inputs = train_inputs.to(self.dev, non_blocking=True)
            train_masks = train_masks.to(self.dev, non_blocking=True)
            train_preds = self.net(train_inputs)

            train_loss, loss_item_list = get_total_loss(train_preds, train_masks, self.loss_funcs)
            train_loss.backward()
            self.opti.step()

            if self.args["sche_usebatch"]:
                self.sche.step()

            # Call item() only when accumulating the scalar value.
            train_iter_loss = train_loss.item()
            train_batch_size = train_inputs.size(0)
            train_loss_record.update(train_iter_loss, train_batch_size)

            # Write to TensorBoard.
            if self.args["tb_update"] > 0 and (curr_iter + 1) % self.args["tb_update"] == 0:
                self.tb.add_scalar("data/trloss_avg", train_loss_record.avg, curr_iter)
                self.tb.add_scalar("data/trloss_iter", train_iter_loss, curr_iter)
                for idx, param_groups in enumerate(self.opti.param_groups):
                    self.tb.add_scalar(f"data/lr_{idx}", param_groups["lr"], curr_iter)
                tr_tb_mask = make_grid(train_masks, nrow=train_batch_size, padding=5)
                self.tb.add_image("trmasks", tr_tb_mask, curr_iter)
                tr_tb_out_1 = make_grid(train_preds, nrow=train_batch_size, padding=5)
                self.tb.add_image("trsodout", tr_tb_out_1, curr_iter)

            # Log the data of each iteration.
            if self.args["print_freq"] > 0 and (curr_iter + 1) % self.args["print_freq"] == 0:
                lr_str = ",".join(
                    [f"{param_groups['lr']:.7f}" for param_groups in self.opti.param_groups]
                )
                log = (
                    f"[I:{curr_iter}/{self.iter_num}][E:{curr_epoch}:{self.end_epoch}]>"
                    f"[{self.exp_name}]"
                    f"[Lr:{lr_str}]"
                    f"[Avg:{train_loss_record.avg:.5f}|Cur:{train_iter_loss:.5f}|"
                    f"{loss_item_list}]"
                )
                print(log)
                make_log(self.path["tr_log"], log)

        # Adjust the learning rate on a per-epoch basis.
        if not self.args["sche_usebatch"]:
            self.sche.step()

        # Save a checkpoint every epoch; the stored parameters correspond to epoch curr_epoch + 1.
        save_checkpoint(
            model=self.net,
            optimizer=self.opti,
            scheduler=self.sche,
            exp_name=self.exp_name,
            current_epoch=curr_epoch + 1,
            full_net_path=self.path["final_full_net"],
            state_net_path=self.path["final_state_net"],
        )  # save parameters

    total_results = self.test()
    # Save the results into an xlsx file.
    write_xlsx(self.exp_name, total_results)
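
# NOTE: the per-batch loop above calls get_total_loss to combine several criteria
# into one scalar. The function below is a plausible sketch, assuming loss_funcs is
# an iterable of callables mapping (prediction, mask) to a scalar tensor; the actual
# helper in this codebase may weight or format the terms differently.
def get_total_loss(train_preds, train_masks, loss_funcs):
    loss_list = []
    loss_item_list = []
    for loss_fn in loss_funcs:
        loss = loss_fn(train_preds, train_masks)
        loss_list.append(loss)
        loss_item_list.append(f"{loss.item():.5f}")  # detached values, for logging only
    total_loss = sum(loss_list)  # summing tensors keeps the graph for backward()
    return total_loss, loss_item_list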