def real_image_test(
    img_dir="D:\\DataSets\\MyFishData\\0119_image\\",
    ckpt_path="checkpoints/CKPT/16.pth",
):
    """Interactively show SwiftNet predictions on a folder of real images.

    Every file in ``img_dir`` (sorted by path) is passed to ``run_image``
    together with the loaded model. The two returned arrays are displayed
    in the OpenCV windows "image" and "label". Press ``p`` for the next
    image or ``ESC`` to quit. All OpenCV windows are closed on exit.

    Args:
        img_dir: Directory holding the images to run (default keeps the
            original hard-coded dataset path).
        ckpt_path: Checkpoint whose ``"model_state_dict"`` entry is loaded.
    """
    device = torch.device("cuda")
    img_paths = sorted(os.path.join(img_dir, name) for name in os.listdir(img_dir))

    model = SwiftNet(resnet18(pretrained=True), num_classes=21).to(device)
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    try:
        for img_path in img_paths:
            image, label = run_image(img_path, model)
            cv2.imshow("image", image)
            cv2.imshow("label", label)
            # Wait for a key: ESC (27) ends the whole session, 'p' advances.
            while True:
                key = cv2.waitKey(20) & 0xFF
                if key == 27:
                    return
                if key == ord("p"):
                    break
    finally:
        # Close the display windows whether we quit early or ran out of images.
        cv2.destroyAllWindows()
def all_eval():
    """Run ``one_eval`` for every (checkpoint, validation split) pair.

    The checkpoints are ``checkpoints/CKPT/<n>.pth`` for ``n`` in 22..25.
    Each is evaluated on five distortion splits, using ``<split>_annot`` as
    the annotation folder and ``is_distortion=True``. A single SwiftNet
    instance is shared by all calls. Presumably ``one_eval`` loads the
    checkpoint weights into it (TODO confirm).
    """
    model = SwiftNet(resnet18(pretrained=True), num_classes=21)
    # Splits also used at times: "\\val_7DOF", "\\val_rotate10".
    splits = ("\\val_200f", "\\val_250f", "\\val_300f", "\\val_350f", "\\val_400f")
    ckpt_paths = tuple("checkpoints/CKPT/" + str(idx) + ".pth" for idx in range(22, 26))
    for ckpt_path in ckpt_paths:
        for split in splits:
            one_eval(model, ckpt_path, split, split + "_annot", is_distortion=True)
def final_eval():
    """Evaluate the checkpoint at ``Config.ckpt_path`` on each distortion split.

    Builds SwiftNet on a pretrained ResNet-18 and moves it to ``MyDevice``.
    It then loads ``"model_state_dict"`` from the checkpoint and switches to
    eval mode. For each validation folder under ``Config.data_dir`` (paired
    with its ``_annot`` folder) it prints the split name and runs
    ``val_distortion`` with printing enabled.
    """
    splits = ["\\val_200f", "\\val_250f", "\\val_300f", "\\val_350f", "\\val_400f"]
    annot_splits = [split + "_annot" for split in splits]

    model = SwiftNet(resnet18(pretrained=True), num_classes=21)
    model.to(MyDevice)
    state = torch.load(Config.ckpt_path)
    print("Load", Config.ckpt_path)
    model.load_state_dict(state["model_state_dict"])
    model.eval()

    for split, annot_split in zip(splits, annot_splits):
        dataset = CityScape(Config.data_dir + split, Config.data_dir + annot_split)
        loader = DataLoader(dataset, batch_size=Config.val_batch_size, shuffle=False)
        print("\n", split)
        val_distortion(model, loader, is_print=True)
def _train_checkpoint(epoch, model, optimizer, scheduler, loss):
    """Build the state dict that ``train`` writes to disk.

    ``loss`` is the last training-loss tensor, or ``None`` if no step ran.
    It is stored as a Python float, so the file holds no device tensor or
    autograd state. The scheduler state is included so a resumed run
    continues the LR schedule.
    """
    return {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "loss": None if loss is None else loss.item(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }


def train():
    """Train SwiftNet (ResNet-18 backbone, 21 classes) on CityScape data.

    All settings come from ``Config``:
    - the ``MyTransform`` augmentation: focal length and extrinsic
      parameters with optional randomisation, background label 20, and an
      optional random crop;
    - data directories, batch sizes and worker count;
    - ``max_epoch``, the TensorBoard ``logdir``, and checkpoint paths.

    Each epoch the function:
    1. optimises ``FocalLoss2d`` with two-group Adam;
    2. steps the cosine LR schedule;
    3. once ``epoch > 20``, runs ``val`` and logs its six metrics;
    4. saves ``<Config.ckpt_name>_<epoch>.pth``.

    After the last epoch it validates again and saves ``Config.model_path``.

    If ``Config.train_with_ckpt`` is set, it restores model and optimizer
    state from ``Config.ckpt_path``, plus scheduler state when present. It
    validates once and resumes at the epoch after the saved one.
    """
    train_transform = MyTransform(Config.f, Config.fish_size)
    train_transform.set_ext_params(Config.ext_param)
    train_transform.set_ext_param_range(Config.ext_range)
    if Config.rand_f:
        train_transform.rand_f(f_range=Config.f_range)
    if Config.rand_ext:
        train_transform.rand_ext_params()
    train_transform.set_bkg(bkg_label=20, bkg_color=[0, 0, 0])
    train_transform.set_crop(rand=Config.crop, rate=Config.crop_rate)

    train_set = CityScape(Config.train_img_dir, Config.train_annot_dir, transform=train_transform)
    train_loader = DataLoader(
        train_set,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=Config.dataloader_num_worker,
    )
    validation_set = CityScape(Config.valid_img_dir, Config.valid_annot_dir)
    validation_loader = DataLoader(validation_set, batch_size=Config.val_batch_size, shuffle=False)

    model = SwiftNet(resnet18(pretrained=True), num_classes=21)
    model.to(MyDevice)

    # 21 per-class loss weights, presumably indexed by label id (TODO
    # confirm against FocalLoss2d). Entry 19 is 0.0. Entry 20 (the
    # background label set above) is 3.6. Kept on the same device as the model.
    class_weights = torch.tensor([
        8.6979065, 8.497886, 8.741297, 5.983605, 8.662319, 8.681756, 8.683093,
        8.763641, 8.576978, 2.7114885, 6.237076, 3.582358, 8.439253, 8.316548,
        8.129169, 4.312109, 8.170293, 6.91469, 8.135018, 0.0, 3.6,
    ]).to(MyDevice)
    criterion = FocalLoss2d(weight=class_weights)

    # Parameter groups:
    # - model.random_init_params(): default lr=4e-4, weight_decay=1e-4;
    # - model.fine_tune_params(): smaller lr=1e-4, weight_decay=2.5e-5.
    optimizer = torch.optim.Adam(
        [
            {"params": model.random_init_params()},
            {"params": model.fine_tune_params(), "lr": 1e-4, "weight_decay": 2.5e-5},
        ],
        lr=4e-4,
        weight_decay=1e-4,
    )
    # NOTE(review): T_max is fixed at 200 epochs regardless of
    # Config.max_epoch — confirm this is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200, 1e-6)

    start_epoch = 0
    # Derive from the loader instead of assuming a 2975-image training set.
    step_per_epoch = len(train_loader)
    writer = SummaryWriter(Config.logdir)

    if Config.train_with_ckpt:
        checkpoint = torch.load(Config.ckpt_path, map_location=MyDevice)
        print("Load", Config.ckpt_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        if "scheduler_state_dict" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        else:
            # Older checkpoints have no scheduler state. They were written
            # right after scheduler.step(), so the restored optimizer lr is
            # already the lr for start_epoch. Align the scheduler's counter
            # so the cosine curve continues instead of restarting at epoch 0.
            scheduler.last_epoch = start_epoch
        val(model, validation_loader, is_print=True)

    # Defaults that keep the final save well-defined even if no epoch or
    # step runs (e.g. start_epoch >= Config.max_epoch).
    epoch = start_epoch - 1
    loss = None
    start_time = None
    steps_since_log = 0
    try:
        for epoch in range(start_epoch, Config.max_epoch):
            # val() may leave the model in eval mode; re-enable training.
            model.train()
            for i, (image, annot) in enumerate(train_loader):
                if start_time is None:
                    start_time = time.time()
                inputs = image.to(MyDevice)
                target = annot.to(MyDevice, dtype=torch.long)

                optimizer.zero_grad()
                score = model(inputs)
                loss = criterion(score, target)
                loss.backward()
                optimizer.step()
                steps_since_log += 1

                global_step = step_per_epoch * epoch + i
                if i % 20 == 0:
                    # Log the first sample of the batch. Labels are scaled
                    # by 10 so they are visible as grey levels.
                    predict = score[0].detach().argmax(0).to(MyCPU, dtype=torch.uint8)
                    height, width = predict.shape
                    writer.add_image("Images/original_image", image[0], global_step=global_step)
                    writer.add_image(
                        "Images/segmentation_output",
                        predict.unsqueeze(0) * 10,
                        global_step=global_step,
                    )
                    writer.add_image(
                        "Images/segmentation_ground_truth",
                        annot[0].reshape(1, height, width) * 10,
                        global_step=global_step,
                    )
                    if global_step > 0:
                        loss_value = loss.item()
                        writer.add_scalar("Monitor/Loss", loss_value, global_step=global_step)
                        # Average over the steps since the last report
                        # rather than reporting the whole interval's time.
                        sec_per_step = (time.time() - start_time) / steps_since_log
                        start_time = time.time()
                        steps_since_log = 0
                        print(
                            f"{epoch}/{Config.max_epoch-1} epoch, {i}/{step_per_epoch} step, loss:{loss_value}, "
                            f"{sec_per_step} sec/step; global step={global_step}")
            scheduler.step()
            if epoch > 20:
                (
                    mean_precision,
                    mean_recall,
                    mean_iou,
                    m_precision_19,
                    m_recall_19,
                    m_iou_19,
                ) = val(model, validation_loader, is_print=True)
                writer.add_scalar("Monitor/precision20", mean_precision, global_step=epoch)
                writer.add_scalar("Monitor/recall20", mean_recall, global_step=epoch)
                writer.add_scalar("Monitor/mIOU20", mean_iou, global_step=epoch)
                writer.add_scalar("Monitor1/precision19", m_precision_19, global_step=epoch)
                writer.add_scalar("Monitor1/recall19", m_recall_19, global_step=epoch)
                writer.add_scalar("Monitor1/mIOU19", m_iou_19, global_step=epoch)
            print(epoch, "/", Config.max_epoch, " loss:",
                  loss.item() if loss is not None else float("nan"))
            torch.save(
                _train_checkpoint(epoch, model, optimizer, scheduler, loss),
                f"{Config.ckpt_name}_{epoch}.pth",
            )
            print("model saved!")

        val(model, validation_loader, is_print=True)
        torch.save(_train_checkpoint(epoch, model, optimizer, scheduler, loss), Config.model_path)
        print("Save model to disk!")
    finally:
        writer.close()