Ejemplo n.º 1
0
def real_image_test(image_dir="D:\\DataSets\\MyFishData\\0119_image\\",
                    ckpt_path="checkpoints/CKPT/16.pth"):
    """Interactively show SwiftNet results for every file in ``image_dir``.

    Files are processed in sorted-name order. For each one, ``run_image``
    produces an ``(image, label)`` pair, shown in the OpenCV windows
    ``"image"`` and ``"label"``. Press ``p`` to advance to the next file and
    ``Esc`` to stop early. The windows are closed when the function exits.

    Args:
        image_dir: Directory whose entries are all passed to ``run_image``
            (no extension filtering is done).
        ckpt_path: Checkpoint file containing a ``"model_state_dict"`` entry.
    """
    device = torch.device("cuda")
    imgs = sorted(os.path.join(image_dir, name) for name in os.listdir(image_dir))

    resnet = resnet18(pretrained=True).to(device)
    model = SwiftNet(resnet, num_classes=21).to(device)
    checkpoint = torch.load(ckpt_path, map_location=device)
    # load_state_dict copies into the existing (already on-device) parameters,
    # so no second .to(device) is needed afterwards.
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    try:
        for img in imgs:
            image, label = run_image(img, model)
            cv2.imshow("image", image)
            cv2.imshow("label", label)
            while True:
                key = cv2.waitKey(20) & 0xFF
                if key == 27:  # Esc: abort the whole run
                    return
                if key == ord("p"):  # p: go to the next image
                    break
    finally:
        cv2.destroyAllWindows()
Ejemplo n.º 2
0
def all_eval(ckpt_index=None, valid_path=None):
    """Run ``one_eval`` for every (checkpoint, validation split) pair.

    Args:
        ckpt_index: Iterable of checkpoint numbers. Each number ``n`` maps to
            ``checkpoints/CKPT/<n>.pth``. Defaults to 22..25.
        valid_path: Validation image sub-directories (Windows-style, leading
            backslash). The matching annotation directory is the same name
            with ``"_annot"`` appended. Defaults to the 200f-400f splits.

    The same ``SwiftNet`` instance is reused for every call. Presumably
    ``one_eval`` loads the given checkpoint into it each time; confirm
    against ``one_eval``.
    """
    if ckpt_index is None:
        ckpt_index = range(22, 26)
    if valid_path is None:
        valid_path = [
            "\\val_200f", "\\val_250f", "\\val_300f", "\\val_350f", "\\val_400f"
        ]

    resnet = resnet18(pretrained=True)
    model = SwiftNet(resnet, num_classes=21)
    ckpt = [f"checkpoints/CKPT/{x}.pth" for x in ckpt_index]
    for ckpt_file in ckpt:
        for split in valid_path:
            one_eval(model, ckpt_file, split, split + "_annot", is_distortion=True)
Ejemplo n.º 3
0
def final_eval(valid_path=None):
    """Evaluate the checkpoint at ``Config.ckpt_path`` on several validation splits.

    For each split, a ``CityScape`` dataset is built from
    ``Config.data_dir + split`` (images) and
    ``Config.data_dir + split + "_annot"`` (annotations). The split name is
    printed, then ``val_distortion`` is run on it.

    Args:
        valid_path: Validation sub-directories (Windows-style, leading
            backslash). Defaults to the 200f-400f splits.
    """
    if valid_path is None:
        valid_path = [
            "\\val_200f", "\\val_250f", "\\val_300f", "\\val_350f", "\\val_400f"
        ]
    valid_annot = [x + "_annot" for x in valid_path]

    resnet = resnet18(pretrained=True)
    model = SwiftNet(resnet, num_classes=21)
    model.to(MyDevice)
    # map_location keeps loading working even if the checkpoint was saved on
    # a different device than MyDevice.
    checkpoint = torch.load(Config.ckpt_path, map_location=MyDevice)
    print("Load", Config.ckpt_path)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    for img_split, annot_split in zip(valid_path, valid_annot):
        validation_set = CityScape(Config.data_dir + img_split,
                                   Config.data_dir + annot_split)
        validation_loader = DataLoader(validation_set,
                                       batch_size=Config.val_batch_size,
                                       shuffle=False)
        print("\n", img_split)
        val_distortion(model, validation_loader, is_print=True)
Ejemplo n.º 4
0
def _build_train_transform():
    """Build the random training augmentation (``MyTransform``) from ``Config``."""
    transform = MyTransform(Config.f, Config.fish_size)
    transform.set_ext_params(Config.ext_param)
    transform.set_ext_param_range(Config.ext_range)
    if Config.rand_f:
        transform.rand_f(f_range=Config.f_range)
    if Config.rand_ext:
        transform.rand_ext_params()
    # Pixels the transform treats as background get label 20 and colour black.
    transform.set_bkg(bkg_label=20, bkg_color=[0, 0, 0])
    transform.set_crop(rand=Config.crop, rate=Config.crop_rate)
    return transform


def _log_train_images(writer, image, annot, score, global_step):
    """Write the first sample's input, prediction and ground truth to TensorBoard."""
    predict = torch.argmax(score, 1).to(MyCPU, dtype=torch.uint8)
    # Use the prediction's spatial size instead of a hard-coded 640x640, so a
    # different Config.fish_size still logs correctly.
    height, width = predict.shape[-2:]
    writer.add_image("Images/original_image", image[0], global_step=global_step)
    # Class ids are multiplied by 10 so neighbouring labels are visually distinct.
    writer.add_image(
        "Images/segmentation_output",
        predict[0].reshape(1, height, width) * 10,
        global_step=global_step,
    )
    writer.add_image(
        "Images/segmentation_ground_truth",
        annot[0].reshape(1, height, width) * 10,
        global_step=global_step,
    )


def _log_val_metrics(writer, metrics, epoch):
    """Log the six values returned by ``val`` as per-epoch TensorBoard scalars."""
    (mean_precision, mean_recall, mean_iou,
     m_precision_19, m_recall_19, m_iou_19) = metrics
    tagged = (
        ("Monitor/precision20", mean_precision),
        ("Monitor/recall20", mean_recall),
        ("Monitor/mIOU20", mean_iou),
        ("Monitor1/precision19", m_precision_19),
        ("Monitor1/recall19", m_recall_19),
        ("Monitor1/mIOU19", m_iou_19),
    )
    for tag, value in tagged:
        writer.add_scalar(tag, value, global_step=epoch)


def _save_train_checkpoint(path, epoch, model, optimizer, scheduler, loss):
    """Save model, optimizer and LR-scheduler state to ``path``.

    ``loss`` is stored as a plain Python float (or ``None`` if no step ran).
    """
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "loss": loss,
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
        },
        path,
    )


def train():
    """Train SwiftNet (ResNet-18 backbone, 21 classes) on the ``Config`` dataset.

    Steps:
    - Optionally resumes from ``Config.ckpt_path`` when
      ``Config.train_with_ckpt`` is set.
    - Trains with ``FocalLoss2d`` and Adam under a cosine-annealed LR.
    - Logs images and loss to TensorBoard in ``Config.logdir``.
    - After epoch 20, validates and saves a checkpoint every epoch.
    - Finally validates once more and saves to ``Config.model_path``.
    """
    train_set = CityScape(Config.train_img_dir,
                          Config.train_annot_dir,
                          transform=_build_train_transform())
    train_loader = DataLoader(
        train_set,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=Config.dataloader_num_worker,
    )

    validation_set = CityScape(Config.valid_img_dir, Config.valid_annot_dir)
    validation_loader = DataLoader(validation_set,
                                   batch_size=Config.val_batch_size,
                                   shuffle=False)

    resnet = resnet18(pretrained=True)
    model = SwiftNet(resnet, num_classes=21)
    model.to(MyDevice)

    # One weight per class (21 entries). Class 19 has weight 0.0; assuming
    # FocalLoss2d applies weights per class like CrossEntropyLoss, that class
    # does not contribute to the loss.
    class_weights = torch.tensor([
        8.6979065, 8.497886, 8.741297, 5.983605, 8.662319, 8.681756,
        8.683093, 8.763641, 8.576978, 2.7114885, 6.237076, 3.582358,
        8.439253, 8.316548, 8.129169, 4.312109, 8.170293, 6.91469,
        8.135018, 0.0, 3.6,
    ]).to(MyDevice)
    criterion = FocalLoss2d(weight=class_weights)

    # Pretrained SwiftNet optimizer: randomly initialised layers use the group
    # defaults (lr=4e-4, wd=1e-4); pretrained backbone layers are fine-tuned
    # more gently. NOTE(review): Config.learning_rate is not used here.
    optimizer = torch.optim.Adam(
        [
            {"params": model.random_init_params()},
            {
                "params": model.fine_tune_params(),
                "lr": 1e-4,
                "weight_decay": 2.5e-5,
            },
        ],
        lr=4e-4,
        weight_decay=1e-4,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200, 1e-6)

    start_epoch = 0
    # Derived from the loader rather than a hard-coded dataset size (2975),
    # so global_step stays monotonic for any dataset / batch size.
    step_per_epoch = len(train_loader)
    writer = SummaryWriter(Config.logdir)

    try:
        if Config.train_with_ckpt:
            checkpoint = torch.load(Config.ckpt_path, map_location=MyDevice)
            print("Load", Config.ckpt_path)
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            # Older checkpoints carry no scheduler state; in that case the
            # cosine schedule restarts, as it always did before.
            if "scheduler_state_dict" in checkpoint:
                scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            start_epoch = checkpoint["epoch"] + 1
            val(model, validation_loader, is_print=True)
            model.train()

        # Pre-bind so the final save works even if no epoch/step runs.
        epoch = start_epoch - 1
        last_loss = None
        start_time = None
        for epoch in range(start_epoch, Config.max_epoch):
            for i, (image, annot) in enumerate(train_loader):
                if start_time is None:
                    start_time = time.time()
                inputs = image.to(MyDevice)
                target = annot.to(MyDevice, dtype=torch.long)
                # val() may switch the model to eval mode, so re-enable
                # training mode every step.
                model.train()
                optimizer.zero_grad()
                score = model(inputs)
                loss = criterion(score, target)
                loss.backward()
                optimizer.step()
                last_loss = loss.item()

                global_step = step_per_epoch * epoch + i
                if i % 20 == 0:
                    _log_train_images(writer, image, annot, score, global_step)
                    if global_step > 0:
                        writer.add_scalar("Monitor/Loss",
                                          last_loss,
                                          global_step=global_step)

                time_elapsed = time.time() - start_time
                start_time = time.time()
                print(
                    f"{epoch}/{Config.max_epoch-1} epoch, {i}/{step_per_epoch} step, loss:{last_loss}, "
                    f"{time_elapsed} sec/step; global step={global_step}")

            scheduler.step()
            # Validation and per-epoch checkpoints only start after epoch 20.
            if epoch > 20:
                metrics = val(model, validation_loader, is_print=True)
                _log_val_metrics(writer, metrics, epoch)
                print(epoch, "/", Config.max_epoch, " loss:", last_loss)
                _save_train_checkpoint(
                    Config.ckpt_name + "_" + str(epoch) + ".pth",
                    epoch, model, optimizer, scheduler, last_loss,
                )
                print("model saved!")

        val(model, validation_loader, is_print=True)
        _save_train_checkpoint(Config.model_path, epoch, model, optimizer,
                               scheduler, last_loss)
        print("Save model to disk!")
    finally:
        writer.close()