Example #1
    def train(self):
        for curr_epoch in range(self.start_epoch, self.end_epoch):
            train_loss_record = AvgMeter()
            self._train_per_epoch(curr_epoch, train_loss_record)

            # Adjust the learning rate on a per-epoch basis
            if not self.arg_dict["sche_usebatch"]:
                self.sche.step()

            # Save a checkpoint every epoch; the stored parameters correspond to epoch curr_epoch + 1
            save_checkpoint(
                model=self.net,
                optimizer=self.opti,
                scheduler=self.sche,
                amp=self.amp,
                exp_name=self.exp_name,
                current_epoch=curr_epoch + 1,
                full_net_path=self.path_dict["final_full_net"],
                state_net_path=self.path_dict["final_state_net"],
            )  # save parameters

        if self.arg_dict["use_amp"]:
            # https://github.com/NVIDIA/apex/issues/567
            with self.amp.disable_casts():
                construct_print(
                    "When evaluating, we wish to evaluate in pure fp32.")
                self.test()
        else:
            self.test()
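
The examples above and below accumulate the per-batch loss with an `AvgMeter`. The repository's own class is not shown here; the following is a minimal sketch consistent with the call sites (`update(value, n)` and `.avg`), offered as an illustration only:

class AvgMeter:
    """Keep a running average of a scalar, e.g. the per-batch training loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # `value` is assumed to be a batch mean, so weight it by the batch size `n`.
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count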
Example #2
def train(
    model,
    start_epoch,
    end_epoch,
    tr_loader,
    optimizer,
    scheduler,
    loss_funcs,
    train_sampler,
    local_rank,
):
    for curr_epoch in range(start_epoch, end_epoch):
        if user_config["is_distributed"]:
            train_sampler.set_epoch(curr_epoch)
        if not user_config["sche_usebatch"]:
            scheduler.step(optimizer=optimizer, curr_epoch=curr_epoch)

        # Run one full epoch of training; this must execute on every epoch,
        # not only when the scheduler steps per epoch.
        train_epoch_prefetch_generator(
            curr_epoch,
            end_epoch,
            loss_funcs,
            model,
            optimizer,
            scheduler,
            tr_loader,
            local_rank,
        )

        if local_rank == 0:
            # note: validate periodically to verify the correctness of the
            # model and the training process
            if (user_config["val_freq"] > 0
                    and (curr_epoch + 1) % user_config["val_freq"] == 0):
                _ = test(model, mode="val", save_pre=False)

            if ((user_config["save_freq"] > 0
                 and (curr_epoch + 1) % user_config["save_freq"] == 0)
                    or curr_epoch == end_epoch - 1):
                save_checkpoint(
                    model=model,
                    optimizer=optimizer,
                    amp=amp if user_config["use_amp"] else None,
                    exp_name=exp_name,
                    current_epoch=curr_epoch + 1,
                    full_net_path=path_config["final_full_net"],
                    state_net_path=path_config["final_state_net"],
                )

    if local_rank == 0:
        total_results = test(model,
                             mode="test",
                             save_pre=user_config["save_pre"])
        xlsx_recorder.write_xlsx(exp_name, total_results)
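
Example #2 seeds its distributed sampler with `train_sampler.set_epoch(curr_epoch)` so that each epoch reshuffles the data differently across processes. Below is a minimal sketch of that standard PyTorch `DistributedSampler` pattern; the toy dataset is hypothetical, and a distributed process group is assumed to be initialized (e.g. via torch.distributed.init_process_group):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Hypothetical toy dataset; any map-style Dataset works the same way.
dataset = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
sampler = DistributedSampler(dataset, shuffle=True)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)

for epoch in range(3):
    # Without set_epoch, every epoch replays the same shuffled order:
    # the sampler seeds its RNG with (seed + epoch).
    sampler.set_epoch(epoch)
    for inputs, labels in loader:
        pass  # forward/backward as in the examples above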
Example #3
    def train(self):
        for curr_epoch in range(self.start_epoch, self.end_epoch):
            train_loss_record = AvgMeter()
            for train_batch_id, train_data in enumerate(self.tr_loader):
                curr_iter = curr_epoch * len(self.tr_loader) + train_batch_id

                self.opti.zero_grad()
                train_inputs, train_masks, *train_other_data = train_data
                train_inputs = train_inputs.to(self.dev, non_blocking=True)
                train_masks = train_masks.to(self.dev, non_blocking=True)
                train_preds = self.net(train_inputs)

                train_loss, loss_item_list = get_total_loss(
                    train_preds, train_masks, self.loss_funcs)
                train_loss.backward()
                self.opti.step()

                if self.args["sche_usebatch"]:
                    self.sche.step()

                # Use item() to fetch the scalar value only when accumulating
                train_iter_loss = train_loss.item()
                train_batch_size = train_inputs.size(0)
                train_loss_record.update(train_iter_loss, train_batch_size)

                # Log to TensorBoard
                if self.args["tb_update"] > 0 and (
                        curr_iter + 1) % self.args["tb_update"] == 0:
                    self.tb.add_scalar("data/trloss_avg",
                                       train_loss_record.avg, curr_iter)
                    self.tb.add_scalar("data/trloss_iter", train_iter_loss,
                                       curr_iter)
                    for idx, param_groups in enumerate(self.opti.param_groups):
                        self.tb.add_scalar(f"data/lr_{idx}",
                                           param_groups["lr"], curr_iter)
                    tr_tb_mask = make_grid(train_masks,
                                           nrow=train_batch_size,
                                           padding=5)
                    self.tb.add_image("trmasks", tr_tb_mask, curr_iter)
                    tr_tb_out_1 = make_grid(train_preds,
                                            nrow=train_batch_size,
                                            padding=5)
                    self.tb.add_image("trsodout", tr_tb_out_1, curr_iter)

                # Log the statistics of the current iteration
                if self.args["print_freq"] > 0 and (
                        curr_iter + 1) % self.args["print_freq"] == 0:
                    lr_str = ",".join([
                        f"{param_groups['lr']:.7f}"
                        for param_groups in self.opti.param_groups
                    ])
                    log = (
                        f"[I:{curr_iter}/{self.iter_num}][E:{curr_epoch}:{self.end_epoch}]>"
                        f"[{self.exp_name}]"
                        f"[Lr:{lr_str}]"
                        f"[Avg:{train_loss_record.avg:.5f}|Cur:{train_iter_loss:.5f}|"
                        f"{loss_item_list}]")
                    print(log)
                    make_log(self.path["tr_log"], log)

            # Adjust the learning rate on a per-epoch basis
            if not self.args["sche_usebatch"]:
                self.sche.step()

            # Save a checkpoint every epoch; the stored parameters correspond to epoch curr_epoch + 1
            save_checkpoint(
                model=self.net,
                optimizer=self.opti,
                scheduler=self.sche,
                exp_name=self.exp_name,
                current_epoch=curr_epoch + 1,
                full_net_path=self.path["final_full_net"],
                state_net_path=self.path["final_state_net"],
            )  # save parameters

        total_results = self.test()
        # Save the results into an xlsx file.
        write_xlsx(self.exp_name, total_results)
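
Example #3 computes its objective through `get_total_loss`, which is not shown. Judging only from the call site, it sums several criteria over the same predictions and also returns per-term values for logging; a hypothetical sketch, not the repository's actual implementation:

def get_total_loss(preds, masks, loss_funcs):
    """Sum every criterion in `loss_funcs` over (preds, masks).

    Returns the total loss tensor (kept differentiable for backward())
    and a list of per-term scalar strings for logging.
    """
    loss_list = []
    loss_item_list = []
    for loss_fn in loss_funcs:
        loss = loss_fn(preds, masks)
        loss_list.append(loss)
        loss_item_list.append(f"{loss.item():.5f}")
    return sum(loss_list), loss_item_list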