def eval(self, epoch=0, model_file_name=None, result_path=None):
        """Evaluate the network on the validation set.

        Computes image-level classification metrics (MAE / F1 / accuracy)
        and, when the SS branch is enabled, segmentation metrics; optionally
        dumps per-sample visualizations.

        Args:
            epoch: epoch index, used only for logging.
            model_file_name: optional checkpoint path; if given, the model
                is loaded from it before evaluation.
            result_path: optional output directory; if given (and SS is
                enabled), the input image, the colorized prediction
                ("*_p.png") and the colorized ground truth ("*_l.png") of
                every sample are saved there.
        """
        if model_file_name is not None:
            # Fixed log-message typo: "form" -> "from".
            Tools.print("Load model from {}".format(model_file_name),
                        txt_path=self.config.save_result_txt)
            self.load_model(model_file_name)
            pass

        avg_meter = AverageMeter()
        # num_classes + 1: object classes plus one background class.
        ss_meter = StreamSegMetrics(self.config.num_classes + 1)
        self.net.eval()
        with torch.no_grad():
            for i, (inputs, masks, labels,
                    image_info_list) in tqdm(enumerate(self.data_loader_val),
                                             total=len(self.data_loader_val)):
                inputs = inputs.float().cuda()
                masks, labels = masks.numpy(), labels.numpy()

                result = self.net.module.forward_inference(
                    inputs,
                    has_class=self.config.has_class,
                    has_cam=self.config.has_cam,
                    has_ss=self.config.has_ss)

                # SS: semantic-segmentation branch.
                if self.config.has_ss:
                    # Per-pixel argmax over the class dimension -> label map.
                    ss_out = result["ss"]["out_up"].detach().max(
                        dim=1)[1].cpu().numpy()
                    ss_meter.update(masks, ss_out)

                    if result_path is not None:
                        for image_info_one, ss_out_one, mask_one in zip(
                                image_info_list, ss_out, masks):
                            result_file = Tools.new_dir(
                                os.path.join(result_path,
                                             os.path.basename(image_info_one)))
                            # Input image, colorized prediction ("_p") and
                            # colorized ground truth ("_l").
                            Image.open(image_info_one).save(result_file)
                            DataUtil.gray_to_color(
                                np.asarray(ss_out_one, dtype=np.uint8)).save(
                                    result_file.replace(".jpg", "_p.png"))
                            DataUtil.gray_to_color(
                                np.asarray(mask_one, dtype=np.uint8)).save(
                                    result_file.replace(".jpg", "_l.png"))
                            pass
                        pass

                    pass

                # Class: image-level classification metrics.
                class_out = torch.sigmoid(
                    result["class_logits"]).detach().cpu().numpy()
                # Average the MAE of positives and negatives separately so
                # the (usually larger) negative set does not dominate.
                one, zero = labels == 1, labels != 1
                avg_meter.update(
                    "mae", (np.abs(class_out[one] - labels[one]).mean() +
                            np.abs(class_out[zero] - labels[zero]).mean()) / 2)
                avg_meter.update(
                    "f1",
                    metrics.f1_score(y_true=labels,
                                     y_pred=class_out > 0.5,
                                     average='micro'))
                avg_meter.update("acc",
                                 self._acc(net_out=class_out, labels=labels))
                pass
            pass

        Tools.print("[E:{:3d}] val mae:{:.4f} f1:{:.4f} acc:{:.4f}".format(
            epoch, avg_meter.get_results("mae"), avg_meter.get_results("f1"),
            avg_meter.get_results("acc")),
                    txt_path=self.config.save_result_txt)
        if self.config.has_ss:
            # Also write the SS summary to the result txt for consistency
            # with the other log calls (txt_path was missing here before,
            # so the SS metrics never reached the log file).
            Tools.print("[E:{:3d}] ss {}".format(
                epoch, ss_meter.to_str(ss_meter.get_results())),
                        txt_path=self.config.save_result_txt)
            pass
        pass
    def train(self, start_epoch=0, model_file_name=None):
        """Train the network: classification loss plus, when SS is enabled,
        a feature-similarity (BCE) loss and a pseudo-label CE loss derived
        from CAM masks. Saves and evaluates the model periodically.

        Args:
            start_epoch: first epoch index (for resuming training).
            model_file_name: optional checkpoint to load before training.
        """

        if model_file_name is not None:
            # Fixed log-message typo: "form" -> "from".
            Tools.print("Load model from {}".format(model_file_name),
                        txt_path=self.config.save_result_txt)
            self.load_model(model_file_name)
            pass

        for epoch in range(start_epoch, self.config.epoch_num):
            Tools.print()
            self._adjust_learning_rate(self.optimizer,
                                       epoch,
                                       lr=self.config.lr,
                                       change_epoch=self.config.change_epoch)
            Tools.print('Epoch:{:03d}, lr={:.6f}'.format(
                epoch, self.optimizer.param_groups[0]['lr']),
                        txt_path=self.config.save_result_txt)

            ###########################################################################
            # 1 Train the model
            avg_meter = AverageMeter()
            self.net.train()
            for i, (pair_labels, inputs, masks,
                    labels) in tqdm(enumerate(self.data_loader_train),
                                    total=len(self.data_loader_train)):
                pair_labels = pair_labels.long().cuda()
                x1, x2 = inputs[0].float().cuda(), inputs[1].float().cuda()
                # mask1, mask2 = masks[0].cuda(), masks[1].cuda()
                label1, label2 = labels[0].cuda(), labels[1].cuda()
                self.optimizer.zero_grad()

                result = self.net(x1,
                                  x2,
                                  pair_labels,
                                  label1,
                                  label2,
                                  has_class=self.config.has_class,
                                  has_cam=self.config.has_cam,
                                  has_ss=self.config.has_ss)

                # Classification loss (weight 10 on the pair's two branches).
                class_logits = result["class_logits"]
                loss_class = 10 * (
                    self.bce_with_logits_loss(class_logits["x1"], label1) +
                    self.bce_with_logits_loss(class_logits["x2"], label2))
                loss = loss_class
                avg_meter.update("loss_class", loss_class.item())

                if self.config.has_ss:
                    ####################################################################################################
                    # CAM max-value mask and min-value mask (boolean, per pixel).
                    ss_where_cam_mask_large_1 = torch.squeeze(
                        result["our"]["cam_mask_large_1"], dim=1) > 0.5
                    ss_where_cam_mask_large_2 = torch.squeeze(
                        result["our"]["cam_mask_large_2"], dim=1) > 0.5
                    ss_where_cam_mask_min_large_1 = torch.squeeze(
                        result["our"]["cam_mask_min_large_1"], dim=1) > 0.5
                    ss_where_cam_mask_min_large_2 = torch.squeeze(
                        result["our"]["cam_mask_min_large_2"], dim=1) > 0.5
                    # Similarity values gathered under each mask; the trailing
                    # comment gives the BCE target used below (1 or 0).
                    ss_value_cam_mask_large_1 = result["our"][
                        "d5_mask_2_to_1"][ss_where_cam_mask_large_1]  # 1
                    ss_value_cam_mask_large_2 = result["our"][
                        "d5_mask_1_to_2"][ss_where_cam_mask_large_2]  # 1
                    ss_value_cam_mask_min_large_12 = result["our"][
                        "d5_mask_neg_2_to_1"][ss_where_cam_mask_large_1]  # 0
                    ss_value_cam_mask_min_large_22 = result["our"][
                        "d5_mask_neg_1_to_2"][ss_where_cam_mask_large_2]  # 0
                    ss_value_cam_mask_large_12 = result["our"][
                        "d5_mask_2_to_1"][ss_where_cam_mask_min_large_1]  # 0
                    ss_value_cam_mask_large_22 = result["our"][
                        "d5_mask_1_to_2"][ss_where_cam_mask_min_large_2]  # 0

                    # Feature-similarity loss: each term only contributes when
                    # its mask selected at least one element.
                    loss_ss = 0
                    #########################################
                    # Normalized "=" to "+=" for consistency with the five
                    # sibling branches (identical behavior: loss_ss starts at 0).
                    if len(ss_value_cam_mask_large_1) > 0:
                        loss_ss += self.bce_loss(
                            ss_value_cam_mask_large_1,
                            torch.ones_like(ss_value_cam_mask_large_1))
                    if len(ss_value_cam_mask_large_2) > 0:
                        loss_ss += self.bce_loss(
                            ss_value_cam_mask_large_2,
                            torch.ones_like(ss_value_cam_mask_large_2))
                    if len(ss_value_cam_mask_min_large_12) > 0:
                        loss_ss += self.bce_loss(
                            ss_value_cam_mask_min_large_12,
                            torch.zeros_like(ss_value_cam_mask_min_large_12))
                    if len(ss_value_cam_mask_min_large_22) > 0:
                        loss_ss += self.bce_loss(
                            ss_value_cam_mask_min_large_22,
                            torch.zeros_like(ss_value_cam_mask_min_large_22))
                    if len(ss_value_cam_mask_large_12) > 0:
                        loss_ss += self.bce_loss(
                            ss_value_cam_mask_large_12,
                            torch.zeros_like(ss_value_cam_mask_large_12))
                    if len(ss_value_cam_mask_large_22) > 0:
                        loss_ss += self.bce_loss(
                            ss_value_cam_mask_large_22,
                            torch.zeros_like(ss_value_cam_mask_large_22))
                    #########################################
                    if loss_ss > 0:
                        loss = loss + loss_ss
                        avg_meter.update("loss_ss", loss_ss.item())
                        pass
                    ####################################################################################################

                    ####################################################################################################
                    # Positive pseudo-labels for the output: 255 = ignore index,
                    # pixels under the CAM-max mask get class (pair_label + 1).
                    ce_where_cam_mask_large_1 = ss_where_cam_mask_large_1
                    ce_mask_large_1 = torch.ones_like(
                        ce_where_cam_mask_large_1).long() * 255
                    now_pair_labels_1 = (pair_labels + 1).view(
                        -1, 1, 1).expand_as(ce_mask_large_1)
                    ce_mask_large_1[
                        ce_where_cam_mask_large_1] = now_pair_labels_1[
                            ce_where_cam_mask_large_1]
                    ce_where_cam_mask_large_2 = ss_where_cam_mask_large_2
                    ce_mask_large_2 = torch.ones_like(
                        ce_where_cam_mask_large_2).long() * 255
                    now_pair_labels_2 = (pair_labels + 1).view(
                        -1, 1, 1).expand_as(ce_mask_large_2)
                    ce_mask_large_2[
                        ce_where_cam_mask_large_2] = now_pair_labels_2[
                            ce_where_cam_mask_large_2]

                    # Negative pseudo-labels: pixels under the CAM-min mask
                    # get class 0 (background); everything else is ignored.
                    ce_where_cam_mask_min_large_1 = ss_where_cam_mask_min_large_1
                    ce_mask_min_large_1 = torch.ones_like(
                        ce_where_cam_mask_min_large_1).long() * 255
                    ce_mask_min_large_1[ce_where_cam_mask_min_large_1] = 0
                    ce_where_cam_mask_min_large_2 = ss_where_cam_mask_min_large_2
                    ce_mask_min_large_2 = torch.ones_like(
                        ce_where_cam_mask_min_large_2).long() * 255
                    ce_mask_min_large_2[ce_where_cam_mask_min_large_2] = 0

                    # Prediction (CE) loss against the four pseudo-label maps.
                    loss_ce = self.ce_loss(result["ss"]["out_1"], ce_mask_large_1) + \
                              self.ce_loss(result["ss"]["out_2"], ce_mask_large_2) + \
                              self.ce_loss(result["ss"]["out_1"], ce_mask_min_large_1) + \
                              self.ce_loss(result["ss"]["out_2"], ce_mask_min_large_2)
                    loss = loss + loss_ce
                    avg_meter.update("loss_ce", loss_ce.item())
                    ####################################################################################################
                    pass

                loss.backward()
                self.optimizer.step()
                avg_meter.update("loss", loss.item())
                pass
            ###########################################################################

            Tools.print(
                "[E:{:3d}/{:3d}] loss:{:.4f} class:{:.4f} ss:{:.4f} ce:{:.4f}".
                format(
                    epoch, self.config.epoch_num,
                    avg_meter.get_results("loss"),
                    avg_meter.get_results("loss_class"),
                    avg_meter.get_results("loss_ss")
                    if self.config.has_ss else 0.0,
                    avg_meter.get_results("loss_ce")
                    if self.config.has_ss else 0.0),
                txt_path=self.config.save_result_txt)

            ###########################################################################
            # 2 Save the model
            if epoch % self.config.save_epoch_freq == 0:
                Tools.print()
                save_file_name = Tools.new_dir(
                    os.path.join(self.config.model_dir,
                                 "{}.pth".format(epoch)))
                torch.save(self.net.state_dict(), save_file_name)
                Tools.print("Save Model to {}".format(save_file_name),
                            txt_path=self.config.save_result_txt)
                Tools.print()
                pass
            ###########################################################################

            ###########################################################################
            # 3 Evaluate the model
            if epoch % self.config.eval_epoch_freq == 0:
                self.eval(epoch=epoch)
                pass
            ###########################################################################

            pass

        # Final Save
        Tools.print()
        save_file_name = Tools.new_dir(
            os.path.join(self.config.model_dir,
                         "final_{}.pth".format(self.config.epoch_num)))
        torch.save(self.net.state_dict(), save_file_name)
        Tools.print("Save Model to {}".format(save_file_name),
                    txt_path=self.config.save_result_txt)
        Tools.print()

        self.eval(epoch=self.config.epoch_num)
        pass
# ---------------------------------------------------------------------------
# Example #3 (web-page extraction artifact: "示例#3" heading and vote count
# "0" from the code-example site; kept as a comment so the file parses)
# ---------------------------------------------------------------------------
    def train(self, start_epoch=0, model_file_name=None):
        """Train the network with up to three loss groups, each gated by a
        config flag: classification (has_class), CAM + feature-similarity
        MSE losses (has_cam), and a pseudo-label CE loss (has_ss). Saves
        and evaluates the model periodically.

        Args:
            start_epoch: first epoch index (for resuming training).
            model_file_name: optional checkpoint to load before training.
        """

        if model_file_name is not None:
            # Fixed log-message typo: "form" -> "from".
            Tools.print("Load model from {}".format(model_file_name),
                        txt_path=self.config.save_result_txt)
            self.load_model(model_file_name)
            pass

        for epoch in range(start_epoch, self.config.epoch_num):
            Tools.print()
            Tools.print('Epoch:{:2d}, lr={:.6f} lr2={:.6f}'.format(
                epoch, self.optimizer.param_groups[0]['lr'],
                self.optimizer.param_groups[1]['lr']),
                        txt_path=self.config.save_result_txt)

            ###########################################################################
            # 1 Train the model
            self.net.train()
            avg_meter = AverageMeter()
            for i, (pair_labels, inputs, masks,
                    labels) in tqdm(enumerate(self.data_loader_train),
                                    total=len(self.data_loader_train)):
                pair_labels = pair_labels.long().cuda()
                x1, x2 = inputs[0].float().cuda(), inputs[1].float().cuda()
                mask1, mask2 = masks[0].long().cuda(), masks[1].long().cuda()
                label1, label2 = labels[0].cuda(), labels[1].cuda()
                self.optimizer.zero_grad()

                result = self.net(x1,
                                  x2,
                                  pair_labels,
                                  has_class=self.config.has_class,
                                  has_cam=self.config.has_cam,
                                  has_ss=self.config.has_ss)

                # NOTE(review): if all three has_* flags are False, loss stays
                # the int 0 and loss.backward() below would fail — presumably
                # the config always enables at least one branch; confirm.
                loss = 0

                ####################################################################################################
                # Classification loss (weight 5 on the pair's two branches).
                if self.config.has_class:
                    class_logits = result["class_logits"]
                    loss_class = 5 * (
                        self.bce_loss(class_logits["x1"], label1) +
                        self.bce_loss(class_logits["x2"], label2))
                    loss = loss + loss_class
                    avg_meter.update("loss_class", loss_class.item())
                    pass
                ####################################################################################################

                ####################################################################################################
                # CAM loss and feature-similarity loss.
                if self.config.has_cam:
                    # Confident pixels only: near-zero or above threshold a.
                    where_1 = torch.squeeze(
                        (result["cam"]["cam_norm_aff_1"].detach() < 1e-6) |
                        (result["cam"]["cam_norm_aff_1"].detach() >
                         self.config.a),
                        dim=1)
                    where_2 = torch.squeeze(
                        (result["cam"]["cam_norm_aff_2"].detach() < 1e-6) |
                        (result["cam"]["cam_norm_aff_2"].detach() >
                         self.config.a),
                        dim=1)
                    # Pixels above threshold b become positive targets (1).
                    mask_where_1 = torch.squeeze(
                        result["cam"]["cam_norm_aff_1"].detach() >
                        self.config.b,
                        dim=1)
                    mask_where_2 = torch.squeeze(
                        result["cam"]["cam_norm_aff_2"].detach() >
                        self.config.b,
                        dim=1)

                    cam_mask_large_1 = torch.zeros_like(
                        result["our"]["d5_mask_2_to_1"])
                    cam_mask_large_1[mask_where_1] = 1
                    cam_mask_large_2 = torch.zeros_like(
                        result["our"]["d5_mask_1_to_2"])
                    cam_mask_large_2[mask_where_2] = 1

                    # CAM (activation-map) loss on the confident pixels.
                    loss_cam = self.mse_loss(torch.squeeze(result["cam"]["cam_norm_1"], dim=1)[where_1], cam_mask_large_1[where_1]) + \
                               self.mse_loss(torch.squeeze(result["cam"]["cam_norm_2"], dim=1)[where_2], cam_mask_large_2[where_2])
                    loss = loss + loss_cam
                    avg_meter.update("loss_cam", loss_cam.item())
                    ##################################################
                    # Feature-similarity loss against the same targets.
                    loss_ss = self.mse_loss(result["our"]["d5_mask_2_to_1"][where_1], cam_mask_large_1[where_1]) + \
                              self.mse_loss(result["our"]["d5_mask_1_to_2"][where_2], cam_mask_large_2[where_2])
                    loss = loss + loss_ss
                    avg_meter.update("loss_ss", loss_ss.item())
                    pass
                ####################################################################################################

                ####################################################################################################
                # Prediction (CE) loss on CAM-derived pseudo-labels:
                # 255 = ignore, 0 = background, (pair_label + 1) = object.
                if self.config.has_ss:
                    final_mask1 = torch.zeros_like(mask1) + 255
                    black1 = torch.squeeze(
                        (result["cam"]["cam_norm_aff_1"].detach() < 1e-6),
                        dim=1)
                    white1 = torch.squeeze(
                        (result["cam"]["cam_norm_aff_1"].detach() >
                         self.config.a),
                        dim=1)
                    final_mask1[black1] = 0
                    final_mask1[white1] = 1
                    final_mask1 = final_mask1 * (pair_labels + 1).view(
                        -1, 1, 1).expand_as(final_mask1)
                    # Anything that started as 255 (ignore) scaled past 255;
                    # clamp it back to the ignore index.
                    final_mask1[final_mask1 >= 255] = 255

                    final_mask2 = torch.zeros_like(mask2) + 255
                    black2 = torch.squeeze(
                        (result["cam"]["cam_norm_aff_2"].detach() < 1e-6),
                        dim=1)
                    white2 = torch.squeeze(
                        (result["cam"]["cam_norm_aff_2"].detach() >
                         self.config.a),
                        dim=1)
                    final_mask2[black2] = 0
                    final_mask2[white2] = 1
                    final_mask2 = final_mask2 * (pair_labels + 1).view(
                        -1, 1, 1).expand_as(final_mask2)
                    final_mask2[final_mask2 >= 255] = 255

                    loss_ce = self.ce_loss(result["ss"]["out_up_1"], final_mask1) + \
                              self.ce_loss(result["ss"]["out_up_2"], final_mask2)
                    loss = loss + loss_ce
                    avg_meter.update("loss_ce", loss_ce.item())
                    pass
                ####################################################################################################

                loss.backward()
                self.optimizer.step()
                avg_meter.update("loss", loss.item())
                pass
            # LR schedule steps once per epoch.
            self.scheduler.step()
            ###########################################################################

            Tools.print(
                "[E:{:3d}/{:3d}] loss:{:.4f} class:{:.4f} ss:{:.4f} ce:{:.4f} cam:{:.4f}"
                .format(
                    epoch, self.config.epoch_num,
                    avg_meter.get_results("loss"),
                    avg_meter.get_results("loss_class")
                    if self.config.has_class else 0.0,
                    avg_meter.get_results("loss_ss")
                    if self.config.has_cam else 0.0,
                    avg_meter.get_results("loss_ce")
                    if self.config.has_ss else 0.0,
                    avg_meter.get_results("loss_cam")
                    if self.config.has_cam else 0.0),
                txt_path=self.config.save_result_txt)

            ###########################################################################
            # 2 Save the model
            if epoch % self.config.save_epoch_freq == 0:
                Tools.print()
                save_file_name = Tools.new_dir(
                    os.path.join(self.config.model_dir,
                                 "{}.pth".format(epoch)))
                torch.save(self.net.state_dict(), save_file_name)
                Tools.print("Save Model to {}".format(save_file_name),
                            txt_path=self.config.save_result_txt)
                Tools.print()
                pass
            ###########################################################################

            ###########################################################################
            # 3 Evaluate the model
            if epoch % self.config.eval_epoch_freq == 0:
                self.eval(epoch=epoch)
                pass
            ###########################################################################

            pass

        # Final Save
        Tools.print()
        save_file_name = Tools.new_dir(
            os.path.join(self.config.model_dir,
                         "final_{}.pth".format(self.config.epoch_num)))
        torch.save(self.net.state_dict(), save_file_name)
        Tools.print("Save Model to {}".format(save_file_name),
                    txt_path=self.config.save_result_txt)
        Tools.print()

        self.eval(epoch=self.config.epoch_num)
        pass