Example #1
    def test(self):
        self.net.eval()

        total_results = {}
        for data_name, data_path in self.te_data_list.items():
            construct_print(f"Testing with testset: {data_name}")
            self.te_loader = create_loader(
                data_path=data_path,
                mode="test",
                get_length=False,
                prefix=self.args["prefix"],
            )
            self.save_path = os.path.join(self.path["save"], data_name)
            if not os.path.exists(self.save_path):
                construct_print(
                    f"{self.save_path} does not exist. Let's create it.")
                os.makedirs(self.save_path)
            results = self.__test_process(save_pre=self.save_pre)
            msg = f"Results on the testset({data_name}:'{data_path}'): {results}"
            construct_print(msg)
            make_log(self.path["te_log"], msg)

            total_results[data_name.upper()] = results

        self.net.train()
        return total_results
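
The helpers construct_print and make_log used above are not shown in this excerpt. A minimal sketch of the interface the snippet assumes (hypothetical bodies, not the project's actual implementation; the " ==>> ... <<== " framing is borrowed from the prints in Example #7):

def construct_print(msg: str) -> None:
    # Print the message framed by markers so it stands out in the console.
    print(f" ==>> {msg} <<== ")


def make_log(log_path: str, msg: str) -> None:
    # Append one line to a plain-text log file; the file is created if missing.
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(msg + "\n")
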
Example #2
def train_epoch(data_loader, curr_epoch):
    loss_record = AvgMeter()
    construct_print(f"Exp_Name: {exp_name}")
    for curr_iter_in_epoch, data in enumerate(data_loader):
        num_iter_per_epoch = len(data_loader)
        curr_iter = curr_epoch * num_iter_per_epoch + curr_iter_in_epoch

        curr_jpegs = data["image"].to(DEVICES, non_blocking=True)
        curr_masks = data["mask"].to(DEVICES, non_blocking=True)
        curr_flows = data["flow"].to(DEVICES, non_blocking=True)
        with amp.autocast(enabled=arg_config['use_amp']):
            preds = model(
                data=dict(curr_jpeg=curr_jpegs, curr_flow=curr_flows))
            seg_jpeg_logits = preds["curr_seg"]
            seg_flow_logits = preds["curr_seg_flow"]

            jpeg_loss_info = cal_total_loss(seg_logits=seg_jpeg_logits,
                                            seg_gts=curr_masks)
            flow_loss_info = cal_total_loss(seg_logits=seg_flow_logits,
                                            seg_gts=curr_masks)
            total_loss = jpeg_loss_info['loss'] + flow_loss_info['loss']
            total_loss_str = (jpeg_loss_info['loss_str']
                              + flow_loss_info['loss_str'])

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        iter_loss = total_loss.item()
        batch_size = curr_jpegs.size(0)
        loss_record.update(iter_loss, batch_size)

        if (arg_config["tb_update"] > 0
                and curr_iter % arg_config["tb_update"] == 0):
            tb_recorder.record_curve("loss_avg", loss_record.avg, curr_iter)
            tb_recorder.record_curve("iter_loss", iter_loss, curr_iter)
            tb_recorder.record_curve("lr", optimizer.param_groups, curr_iter)
            tb_recorder.record_image("jpegs", curr_jpegs, curr_iter)
            tb_recorder.record_image("flows", curr_flows, curr_iter)
            tb_recorder.record_image("masks", curr_masks, curr_iter)
            tb_recorder.record_image("segs", preds["curr_seg"].sigmoid(),
                                     curr_iter)

        if (arg_config["print_freq"] > 0
                and curr_iter % arg_config["print_freq"] == 0):
            lr_string = ",".join([f"{x:.7f}" for x in scheduler.get_last_lr()])
            log = (
                f"[{curr_iter_in_epoch}:{num_iter_per_epoch},{curr_iter}:{num_iter},"
                f"{curr_epoch}:{end_epoch}][{list(curr_jpegs.shape)}]"
                f"[Lr:{lr_string}][M:{loss_record.avg:.5f},C:{iter_loss:.5f}]{total_loss_str}"
            )
            print(log)
            make_log(path_dict["tr_log"], log)

        if scheduler_usebatch:
            scheduler.step()
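
train_epoch relies on an AvgMeter to track the running mean loss. A minimal sketch matching the interface the snippet assumes (update(value, n) plus an .avg attribute):

class AvgMeter:
    # Running average of per-batch loss values, weighted by batch size.
    def __init__(self) -> None:
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value: float, n: int = 1) -> None:
        # `value` is the batch-mean loss, `n` the number of samples it covers.
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count
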
Example #3
def train(tr_loader, val_loader=None):
    for curr_epoch in range(start_epoch, end_epoch):
        if val_loader is not None:
            seg_results = test(model=model, data_loader=val_loader)
            msg = f"Epoch: {curr_epoch}, Results on the valsel: {seg_results}"
            print(msg)
            make_log(path_dict["te_log"], msg)

        model.train()
        train_epoch(data_loader=tr_loader, curr_epoch=curr_epoch)

        # Adjust the learning rate once per epoch
        if not scheduler_usebatch:
            scheduler.step()

        # Save a checkpoint every epoch; the saved parameters correspond to epoch curr_epoch + 1
        save_checkpoint(exp_name=exp_name,
                        model=model,
                        current_epoch=curr_epoch + 1,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        scaler=scaler,
                        full_net_path=path_dict["final_full_net"],
                        state_net_path=path_dict["final_state_net"])
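
save_checkpoint is not defined in this excerpt. A plausible minimal version, matching only the keyword arguments the call passes (everything else is an assumption):

import torch


def save_checkpoint(exp_name, model, current_epoch, optimizer, scheduler,
                    scaler, full_net_path, state_net_path):
    # Full checkpoint for resuming: epoch counter plus every stateful object.
    torch.save(
        {
            "exp_name": exp_name,
            "epoch": current_epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "scaler": scaler.state_dict(),
        },
        full_net_path,
    )
    # Weights-only copy for inference or release.
    torch.save(model.state_dict(), state_net_path)
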
Example #4
    def train(self):
        for curr_epoch in range(self.start_epoch, self.end_epoch):
            train_loss_record = AvgMeter()
            for train_batch_id, train_data in enumerate(self.tr_loader):
                curr_iter = curr_epoch * len(self.tr_loader) + train_batch_id

                self.opti.zero_grad()
                train_inputs, train_masks, *train_other_data = train_data
                train_inputs = train_inputs.to(self.dev, non_blocking=True)
                train_masks = train_masks.to(self.dev, non_blocking=True)
                train_preds = self.net(train_inputs)

                train_loss, loss_item_list = self.total_loss(
                    train_preds, train_masks)
                train_loss.backward()
                self.opti.step()

                if self.args["sche_usebatch"]:
                    if self.args["lr_type"] == "poly":
                        self.sche.step(curr_iter + 1)
                    else:
                        raise NotImplementedError

                # Use item() to fetch the value only when accumulating it
                train_iter_loss = train_loss.item()
                train_batch_size = train_inputs.size(0)
                train_loss_record.update(train_iter_loss, train_batch_size)

                # Log to TensorBoard
                if (self.args["tb_update"] > 0
                        and (curr_iter + 1) % self.args["tb_update"] == 0):
                    self.tb.add_scalar("data/trloss_avg",
                                       train_loss_record.avg, curr_iter)
                    self.tb.add_scalar("data/trloss_iter", train_iter_loss,
                                       curr_iter)
                    self.tb.add_scalar("data/trlr",
                                       self.opti.param_groups[0]["lr"],
                                       curr_iter)
                    tr_tb_mask = make_grid(train_masks,
                                           nrow=train_batch_size,
                                           padding=5)
                    self.tb.add_image("trmasks", tr_tb_mask, curr_iter)
                    tr_tb_out_1 = make_grid(train_preds,
                                            nrow=train_batch_size,
                                            padding=5)
                    self.tb.add_image("trsodout", tr_tb_out_1, curr_iter)

                # Log the data of each iteration
                if (self.args["print_freq"] > 0
                        and (curr_iter + 1) % self.args["print_freq"] == 0):
                    log = (
                        f"[I:{curr_iter}/{self.iter_num}][E:{curr_epoch}:{self.end_epoch}]>"
                        f"[{self.model_name}]"
                        f"[Lr:{self.opti.param_groups[0]['lr']:.7f}]"
                        f"[Avg:{train_loss_record.avg:.5f}|Cur:{train_iter_loss:.5f}|"
                        f"{loss_item_list}]")
                    print(log)
                    make_log(self.path["tr_log"], log)

            # Adjust the learning rate once per epoch
            if not self.args["sche_usebatch"]:
                if self.args["lr_type"] == "poly":
                    self.sche.step(curr_epoch + 1)
                else:
                    raise NotImplementedError

            # Save a checkpoint every epoch; the saved parameters correspond to epoch curr_epoch + 1
            self.save_checkpoint(
                curr_epoch + 1,
                full_net_path=self.path['final_full_net'],
                state_net_path=self.path['final_state_net'])  # save the parameters

        total_results = {}
        for data_name, data_path in self.te_data_list.items():
            construct_print(f"Testing with testset: {data_name}")
            self.te_loader, self.te_length = create_loader(data_path=data_path,
                                                           mode='test',
                                                           get_length=True)
            self.save_path = os.path.join(self.path["save"], data_name)
            if not os.path.exists(self.save_path):
                construct_print(
                    f"{self.save_path} does not exist. Let's create it.")
                os.makedirs(self.save_path)
            results = self.test(save_pre=self.save_pre)
            msg = (
                f"Results on the testset({data_name}:'{data_path}'): {results}"
            )
            construct_print(msg)
            make_log(self.path["te_log"], msg)

            total_results[data_name.upper()] = results
        # save result into xlsx file.
        write_xlsx(self.model_name, total_results)
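
write_xlsx is assumed here; a hedged sketch using openpyxl, writing one row per testset and one column per metric (the real helper and its output path may be organized differently):

from openpyxl import Workbook


def write_xlsx(model_name: str, total_results: dict) -> None:
    # total_results maps testset names to {metric_name: value} dicts.
    wb = Workbook()
    ws = wb.active
    ws.title = model_name[:31]  # Excel caps sheet names at 31 characters.
    metrics = sorted({m for res in total_results.values() for m in res})
    ws.append(["dataset"] + metrics)
    for data_name, results in total_results.items():
        ws.append([data_name] + [results.get(m, "") for m in metrics])
    wb.save(f"{model_name}_results.xlsx")
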
Example #5
if not arg_config["resume"]:
    resume_checkpoint(exp_name=exp_name,
                      load_path=path_dict["final_state_net"],
                      model=model,
                      mode="onlynet")

if arg_config["has_test"]:
    for te_data_name, te_data_path in arg_config["data"]["te"]:
        construct_print(f"Testing with testset: {te_data_name}")
        te_loader_info = create_loader(
            data_path=te_data_path,
            training=False,
            in_size=arg_config['in_size']['te'],
            batch_size=arg_config['batch_size'],
            num_workers=arg_config['num_workers'],
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            use_custom_worker_init=arg_config['use_custom_worker_init'])
        pred_save_path = os.path.join(path_dict["save"], te_data_name)
        seg_results = test(model=model,
                           data_loader=te_loader_info['loader'],
                           save_path=pred_save_path)
        msg = (
            f"Results on the testset({te_data_name}:'{te_data_path['root']}'):\n{seg_results}"
        )
        print(msg)
        make_log(path_dict["te_log"], msg)

construct_print(f"{datetime.now()}: End training...")
Example #6
    def train(self):
        for curr_epoch in range(self.start_epoch, self.end_epoch):
            train_loss_record = AvgMeter()
            for train_batch_id, train_data in enumerate(self.tr_loader):
                curr_iter = curr_epoch * len(self.tr_loader) + train_batch_id

                self.opti.zero_grad()
                train_inputs, train_masks, *train_other_data = train_data
                train_inputs = train_inputs.to(self.dev, non_blocking=True)
                train_masks = train_masks.to(self.dev, non_blocking=True)
                train_preds = self.net(train_inputs)

                train_loss, loss_item_list = get_total_loss(
                    train_preds, train_masks, self.loss_funcs)
                train_loss.backward()
                self.opti.step()

                if self.args["sche_usebatch"]:
                    self.sche.step()

                # Use item() to fetch the value only when accumulating it
                train_iter_loss = train_loss.item()
                train_batch_size = train_inputs.size(0)
                train_loss_record.update(train_iter_loss, train_batch_size)

                # Log to TensorBoard
                if self.args["tb_update"] > 0 and (
                        curr_iter + 1) % self.args["tb_update"] == 0:
                    self.tb.add_scalar("data/trloss_avg",
                                       train_loss_record.avg, curr_iter)
                    self.tb.add_scalar("data/trloss_iter", train_iter_loss,
                                       curr_iter)
                    for idx, param_groups in enumerate(self.opti.param_groups):
                        self.tb.add_scalar(f"data/lr_{idx}",
                                           param_groups["lr"], curr_iter)
                    tr_tb_mask = make_grid(train_masks,
                                           nrow=train_batch_size,
                                           padding=5)
                    self.tb.add_image("trmasks", tr_tb_mask, curr_iter)
                    tr_tb_out_1 = make_grid(train_preds,
                                            nrow=train_batch_size,
                                            padding=5)
                    self.tb.add_image("trsodout", tr_tb_out_1, curr_iter)

                # Log the data of each iteration
                if self.args["print_freq"] > 0 and (
                        curr_iter + 1) % self.args["print_freq"] == 0:
                    lr_str = ",".join([
                        f"{param_groups['lr']:.7f}"
                        for param_groups in self.opti.param_groups
                    ])
                    log = (
                        f"[I:{curr_iter}/{self.iter_num}][E:{curr_epoch}:{self.end_epoch}]>"
                        f"[{self.exp_name}]"
                        f"[Lr:{lr_str}]"
                        f"[Avg:{train_loss_record.avg:.5f}|Cur:{train_iter_loss:.5f}|"
                        f"{loss_item_list}]")
                    print(log)
                    make_log(self.path["tr_log"], log)

            # Adjust the learning rate once per epoch
            if not self.args["sche_usebatch"]:
                self.sche.step()

            # Save a checkpoint every epoch; the saved parameters correspond to epoch curr_epoch + 1
            save_checkpoint(
                model=self.net,
                optimizer=self.opti,
                scheduler=self.sche,
                exp_name=self.exp_name,
                current_epoch=curr_epoch + 1,
                full_net_path=self.path["final_full_net"],
                state_net_path=self.path["final_state_net"],
            )  # save the parameters

        total_results = self.test()
        # save result into xlsx file.
        write_xlsx(self.exp_name, total_results)
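
Several of these examples select a "poly" learning-rate schedule. One common way to build it, consistent with the plain sche.step() calls above (per batch or per epoch), is a LambdaLR wrapper; this is a sketch, not necessarily the project's own factory:

from torch.optim.lr_scheduler import LambdaLR


def make_poly_scheduler(optimizer, total_steps: int, power: float = 0.9):
    # lr(t) = base_lr * (1 - t / total_steps) ** power, advanced by step().
    # max(0.0, ...) guards against a negative base once t reaches total_steps.
    return LambdaLR(optimizer,
                    lr_lambda=lambda t: max(0.0, 1 - t / total_steps) ** power)
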
Example #7
    def train(self):
        for curr_epoch in range(self.start_epoch, self.end_epoch):
            train_loss_record = AvgMeter()

            if self.args["lr_type"] == "poly":
                self.change_lr(curr_epoch)
            else:
                raise NotImplementedError

            for train_batch_id, train_data in enumerate(self.tr_loader):
                curr_iter = curr_epoch * len(self.tr_loader) + train_batch_id

                self.opti.zero_grad()
                train_inputs, train_masks, *train_other_data = train_data
                train_inputs = train_inputs.to(self.dev, non_blocking=True)
                train_masks = train_masks.to(self.dev, non_blocking=True)
                if self.data_mode == "RGBD":
                    # train_other_data is a list
                    train_depths = train_other_data[-1]
                    train_depths = train_depths.to(self.dev, non_blocking=True)
                    train_preds = self.net(train_inputs, train_depths)
                elif self.data_mode == "RGB":
                    train_preds = self.net(train_inputs)
                else:
                    raise NotImplementedError

                train_loss, loss_item_list = self.total_loss(
                    train_preds, train_masks)
                train_loss.backward()
                self.opti.step()

                # Use item() to fetch the value only when accumulating it
                train_iter_loss = train_loss.item()
                train_batch_size = train_inputs.size(0)
                train_loss_record.update(train_iter_loss, train_batch_size)

                # Log the data of each iteration
                if self.args["print_freq"] > 0 and (
                        curr_iter + 1) % self.args["print_freq"] == 0:
                    log = (
                        f"[I:{curr_iter}/{self.iter_num}][E:{curr_epoch}:{self.end_epoch}]>"
                        f"[{self.model_name}]"
                        f"[Lr:{self.opti.param_groups[0]['lr']:.7f}]"
                        f"[Avg:{train_loss_record.avg:.5f}|Cur:{train_iter_loss:.5f}|"
                        f"{loss_item_list}]")
                    print(log)
                    make_log(self.path["tr_log"], log)

            # Save a checkpoint every epoch; the saved parameters correspond to epoch curr_epoch + 1
            self.save_checkpoint(
                curr_epoch + 1,
                full_net_path=self.path["final_full_net"],
                state_net_path=self.path["final_state_net"],
            )

        # Run the final test; print the validation results first
        print(" ==>> Training finished <<== ")

        for data_name, data_path in self.te_data_list.items():
            print(f" ==>> 使用测试集{data_name}测试 <<== ")
            self.te_loader, self.te_length = create_loader(
                data_path=data_path,
                mode="test",
                get_length=True,
                data_mode=self.data_mode,
            )
            self.save_path = os.path.join(self.path["save"], data_name)
            if not os.path.exists(self.save_path):
                print(f" ==>> {self.save_path} 不存在, 这里创建一个 <<==")
                os.makedirs(self.save_path)
            results = self.test(save_pre=self.save_pre)
            fixed_pre_results = {k: f"{v:.3f}" for k, v in results.items()}
            msg = f" ==>> 在{data_name}:'{data_path}'测试集上结果\n >> {fixed_pre_results}"
            print(msg)
            make_log(self.path["te_log"], msg)