Beispiel #1
0
def test_moving(opt, log_dir, generator=None):
    """Run the generator over the "moving" test set and save every output.

    Args:
        opt: Options object. This function reads ``sample_num``, ``channels``,
            ``batch_size``, ``alpha``, ``load_model`` and
            ``num_workers_dataloader`` and forwards ``opt`` to
            ``MyDataset_test_moving``.
        log_dir: Root output directory. Results are written to
            ``<log_dir>/test/<name>/<name>_<k>.png``.
        generator: Optional ready-built generator. When ``None`` a ``UNet``
            is constructed and its weights are restored from the
            ``'g_state_dict'`` entry of the checkpoint at ``opt.load_model``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Identity comparison: ``== None`` would dispatch to ``__eq__``.
    if generator is None:
        generator = UNet(opt.sample_num, opt.channels, opt.batch_size,
                         opt.alpha)

        checkpoint = torch.load(opt.load_model, map_location=device)
        generator.load_state_dict(checkpoint['g_state_dict'])
        # Drop the checkpoint early so its tensors can be freed.
        del checkpoint
        torch.cuda.empty_cache()

    generator.to(device)
    generator.eval()

    dataloader = torch.utils.data.DataLoader(
        MyDataset_test_moving(opt),
        opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers_dataloader)

    for imgs, filename in dataloader:
        with torch.no_grad():
            # Only the first file name of the batch labels the outputs.
            filename = filename[0].split('/')[-1]
            # The target folder does not depend on k, so create it once.
            folder_path = os.path.join(log_dir, "test/%s" % filename)
            os.makedirs(folder_path, exist_ok=True)
            for k, img in enumerate(imgs):
                test_img = generator(img.to(device))
                filename_ = filename + '_' + str(k) + '.png'
                convert_im(test_img,
                           os.path.join(folder_path, filename_),
                           nrow=5,
                           normalize=True,
                           save_im=True)
Beispiel #2
0
def run_pipeline(root_dir, model_path, img_size, batch_size, use_gpu):
    """Predict masks for every image under ``root_dir`` and generate boxes.

    Args:
        root_dir: Directory expected to contain an ``images`` sub-folder;
            the ``masks`` and ``outputs`` sub-folder paths are passed on to
            ``predict`` and ``generate_bbox``.
        model_path: Path to the saved ``UNet`` state dict.
        img_size: Edge length; the dataset is built with
            ``im_size=[img_size, img_size]``.
        batch_size: Batch size for the test loader (also passed to
            ``predict``).
        use_gpu: Request GPU execution; ignored when CUDA is unavailable.
    """
    images_path = os.path.join(root_dir, "images")
    masks_path = os.path.join(root_dir, "masks")
    outputs_path = os.path.join(root_dir, "outputs")

    # Record the original size and name of every input image.
    sizes = []
    file_names = []
    for f in os.listdir(images_path):
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it as soon as the size has been read.
        with Image.open(os.path.join(images_path, f)) as im:
            sizes.append(im.size)
        file_names.append(f)

    model = UNet(num_channels=1, num_classes=2)
    use_gpu = use_gpu and torch.cuda.is_available()
    if use_gpu:
        model.cuda()
    # Tensors are read onto the CPU; load_state_dict copies them into the
    # model's parameters wherever those live.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    test_dataset = TestDataset(images_path,
                               im_size=[img_size, img_size],
                               transform=tr.ToTensor())
    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=1)
    print("Test Data : ", len(test_loader.dataset))
    start = timer()
    predict(model, test_loader, batch_size, sizes, file_names, masks_path,
            use_gpu)
    end = timer()
    print("Prediction completed in {:0.2f}s".format(end - start))
    generate_bbox(images_path, masks_path, outputs_path)
    end2 = timer()
    print("Bbox generation completed in {:0.2f}s".format(end2 - end))
Beispiel #3
0
def start():
    """Evaluate the saved UNet on the images under ``./Data``.

    Collects (and prints) the size and name of every file in
    ``./Data/images/``, builds the test loader, restores the weights from
    ``unet-model-16-100-0.001`` and hands everything to ``test_only``.
    """
    model = UNet(num_channels=1, num_classes=2)
    criterion = DICELossMultiClass()

    dir_path = './Data/images/'
    sizes = []
    file_names = []
    for f in os.listdir(dir_path):
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it once the size has been read.
        with Image.open(os.path.join(dir_path, f)) as im:
            print(im.size)
            sizes.append(im.size)
        file_names.append(f)
    print(sizes)
    print(file_names)

    test_dataset = TestDataset('./Data/',
                               im_size=[256, 256],
                               transform=tr.ToTensor())
    test_loader = DataLoader(test_dataset,
                             batch_size=4,
                             shuffle=False,
                             num_workers=1)
    print("Test Data : ", len(test_loader.dataset))
    model.load_state_dict(
        torch.load('unet-model-16-100-0.001', map_location='cpu'))
    test_only(model, test_loader, criterion, sizes, file_names)
Beispiel #4
0
def train(device, model_path, dataset_path):
    """Train the gray-to-colour UNet for ten epochs on ``dataset_path``.

    Args:
        device: Torch device the network and batches are moved to.
        model_path: Checkpoint path. If it already exists the weights are
            restored before training; the state dict is written back to the
            same path after every epoch.
        dataset_path: Location handed to ``GrayColorDataset``.
    """
    network = UNet(1, 3).to(device)
    optimizer = torch.optim.Adam(network.parameters())
    criteria = torch.nn.MSELoss()

    dataset = GrayColorDataset(dataset_path, transform=train_transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=16,
                                         shuffle=True,
                                         num_workers=cpu_count())

    if os.path.exists(model_path):
        # map_location lets a checkpoint saved on another device (e.g. a
        # GPU) be restored on the device used for this run.
        network.load_state_dict(torch.load(model_path, map_location=device))
    for _ in tqdm.trange(10, desc="Epoch"):
        network.train()
        for gray, color in tqdm.tqdm(loader, desc="Training", leave=False):
            gray, color = gray.to(device), color.to(device)
            optimizer.zero_grad()
            pred_color = network(gray)
            loss = criteria(pred_color, color)
            loss.backward()
            optimizer.step()
        # Checkpoint after every epoch so an interrupted run can resume.
        torch.save(network.state_dict(), model_path)
    def __init__(self, proxy_map, opts,startup_check=False):
        """Set up the worker: periodic timer, segmentation model and gait model.

        Args:
            proxy_map: Forwarded unchanged to the parent class constructor.
            opts: User options; stored on the instance, passed to
                ``ImageListInference``, and ``opts.gpu`` decides whether the
                gait model is moved to CUDA.
            startup_check: When true, ``self.startup_check()`` is run instead
                of starting the periodic ``compute`` timer.
        """
        super(SpecificWorker, self).__init__(proxy_map)
        # Timer interval handed to self.timer.start (presumably milliseconds — confirm).
        self.Period = 50
        # NOTE(review): this branch runs before the attributes below are
        # assigned, so startup_check() must not rely on them — confirm.
        if startup_check:
            self.startup_check()
        else:
            self.timer.timeout.connect(self.compute)
            self.timer.start(self.Period)

        self.opts = opts # Keep a reference to the user options (not a copy)
        self.wait_frames = 5 # Wait for this many frames to get stored before calling gaitset (see gaitSet/pretreatment.py)
        # Dictionary to store mapping between tracking id and segmentation mask
        self.id2mask = {}


        # Segmentation module: weights are read onto the CPU from the
        # 'state_dict' entry of the checkpoint; strict=False tolerates
        # missing/unexpected keys.
        segmentation_model = UNet(backbone="resnet18", num_classes=2)
        segmentation_model.load_state_dict(torch.load("./src/PretrainedModels/UNet_ResNet18.pth", map_location="cpu")['state_dict'], strict=False)
        self.segmentation_inference = ImageListInference(segmentation_model,opts)

        # Gait feature extraction: the checkpoint file is used directly as
        # the state dict, again loaded non-strictly.
        self.gait_model = SetNet(256)
        self.gait_model.load_state_dict(torch.load("./src/PretrainedModels/GaitSet_CASIA-B_73_False_256_0.2_128_full_30-80000-encoder.ptm",map_location="cpu"),strict=False)
        # A non-negative opts.gpu selects CUDA execution for the gait model.
        if self.opts.gpu >= 0:
            self.gait_model.cuda()
        self.gait_model.eval()

        # Variables to store data for computation
        self.lock = False # Lock from taking any more input until features are calculated
        self.input_image_list = [] # Store images in this list before performing segmentation
        self.input_tracking_id = [] # Store respective tracking id here
Beispiel #6
0
def get_model(model_path, model_type):
    """Build a single-class segmentation network and restore its weights.

    :param model_path: path to a checkpoint whose ``'model'`` entry holds
        the state dict (keys may carry a ``module.`` DataParallel prefix)
    :param model_type: 'UNet', 'UNet11', 'UNet16', 'AlbuNet34'; any other
        value falls back to ``UNet``
    :return: the model in eval mode, moved to the GPU when CUDA is available
    """

    num_classes = 1

    if model_type == 'UNet11':
        model = UNet11(num_classes=num_classes)
    elif model_type == 'UNet16':
        model = UNet16(num_classes=num_classes)
    elif model_type == 'AlbuNet34':
        model = AlbuNet34(num_classes=num_classes)
    else:
        # 'UNet' and every unrecognised name use the plain UNet.
        model = UNet(num_classes=num_classes)

    # Read onto the CPU so GPU-saved checkpoints also load on CPU-only
    # machines; the model is moved to the GPU afterwards when possible.
    state = torch.load(str(model_path), map_location='cpu')
    state = {
        key.replace('module.', ''): value
        for key, value in state['model'].items()
    }
    model.load_state_dict(state)

    # Switch to eval mode before returning on *either* device; previously
    # the CUDA path returned early and left the model in training mode.
    model.eval()

    if torch.cuda.is_available():
        return model.cuda()

    return model
Beispiel #7
0
def get_model(model_path, model_type, num_classes):
    """Build a segmentation network and restore its weights.

    :param model_path: path to a checkpoint whose ``'model'`` entry holds
        the state dict (keys may carry a ``module.`` DataParallel prefix)
    :param model_type: 'UNet' for the plain UNet, otherwise a key of
        ``model_list`` ('UNet16', 'UNet11', 'LinkNet34', ...)
    :param num_classes: number of output classes passed to the model
        constructor
    :return: the model in eval mode, moved to the GPU when CUDA is available
    """

    if model_type == 'UNet':
        model = UNet(num_classes=num_classes)
    else:
        model_name = model_list[model_type]
        model = model_name(num_classes=num_classes)

    # Read onto the CPU so GPU-saved checkpoints also load on CPU-only
    # machines; the model is moved to the GPU afterwards when possible.
    state = torch.load(str(model_path), map_location='cpu')
    state = {
        key.replace('module.', ''): value
        for key, value in state['model'].items()
    }
    model.load_state_dict(state)

    # Switch to eval mode before returning on *either* device; previously
    # the CUDA path returned early and left the model in training mode.
    model.eval()

    if torch.cuda.is_available():
        return model.cuda()

    return model
Beispiel #8
0
def test(opt, log_dir, generator=None):
    """Run the generator over the test set and save one image grid per batch.

    Args:
        opt: Options object. This function reads ``sample_num``, ``channels``,
            ``batch_size``, ``alpha`` and ``load_model`` and forwards ``opt``
            to ``MyDataset_test``.
        log_dir: Root output directory; results go to
            ``<log_dir>/test/<name>.png``.
        generator: Optional ready-built generator. When ``None`` a ``UNet``
            is constructed and its weights are restored from the
            ``'g_state_dict'`` entry of the checkpoint at ``opt.load_model``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Identity comparison: ``== None`` would dispatch to ``__eq__``.
    if generator is None:
        generator = UNet(opt.sample_num, opt.channels, opt.batch_size,
                         opt.alpha)

        checkpoint = torch.load(opt.load_model, map_location=device)
        generator.load_state_dict(checkpoint['g_state_dict'])
        # Drop the checkpoint early so its tensors can be freed.
        del checkpoint
        torch.cuda.empty_cache()

    generator.to(device)
    generator.eval()

    dataloader = torch.utils.data.DataLoader(MyDataset_test(opt),
                                             opt.batch_size,
                                             shuffle=True,
                                             num_workers=0)

    # Make sure the output folder exists before anything is saved into it.
    os.makedirs(os.path.join(log_dir, "test"), exist_ok=True)

    for imgs, filename in dataloader:
        with torch.no_grad():
            test_img = generator(imgs.to(device))
            # Only the first file name of the batch labels the output.
            filename = filename[0].split('/')[-1]
            filename = "test/" + filename + '.png'
            convert_im(test_img,
                       os.path.join(log_dir, filename),
                       nrow=5,
                       normalize=True,
                       save_im=True)
Beispiel #9
0
def export2caffe(weights, num_classes, img_size):
    """Convert a trained UNet checkpoint into Caffe prototxt/caffemodel files.

    Args:
        weights: Path to a checkpoint whose ``'model'`` entry is the state
            dict.
        num_classes: Passed to the ``UNet`` constructor.
        img_size: Two-element sequence; ``img_size[1]`` and ``img_size[0]``
            become the last two dimensions of the dummy input (presumably
            ``(width, height)`` — confirm).
    """
    # NOTE(review): the export is named after DeepLabV3Plus although a UNet
    # is converted; the name determines the output file names.
    export_name = 'DeepLabV3Plus'

    net = UNet(num_classes)
    checkpoint = torch.load(weights, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    net.eval()
    fuse(net)

    first, second = img_size[0], img_size[1]
    dummy_input = torch.ones([1, 3, second, first])

    pytorch2caffe.trans_net(net, dummy_input, export_name)
    pytorch2caffe.save_prototxt(export_name + '.prototxt')
    pytorch2caffe.save_caffemodel(export_name + '.caffemodel')
def setup(opts):
    """Create the two-class UNet and load its checkpoint.

    Args:
        opts: Mapping providing ``"backbone"`` (UNet backbone name) and
            ``"checkpoint"`` (path to a file whose ``'state_dict'`` entry
            holds the weights).

    Returns:
        The model, moved to CUDA when a GPU is available. Weights are loaded
        with ``strict=False``.
    """
    model = UNet(backbone=opts["backbone"], num_classes=2)

    on_gpu = torch.cuda.is_available()
    print("Using CUDA" if on_gpu else "Using CPU")

    # Without a GPU the tensors must be remapped onto the CPU while loading.
    load_kwargs = {} if on_gpu else {"map_location": "cpu"}
    trained_dict = torch.load(opts["checkpoint"], **load_kwargs)['state_dict']
    model.load_state_dict(trained_dict, strict=False)

    if on_gpu:
        model.cuda()

    return model
Beispiel #11
0
def denoise_3d(mat, model=None, filepath='/home/jupyter/notebooks/checkpoints/3d_denoise.pt', return_detached=True,
               batch_size=5000, device=torch.device('cuda')):
    """Apply a denoising model to ``mat`` in chunks along dimension 0.

    Args:
        mat: Input tensor; it is sliced along dimension 0 in chunks of
            ``batch_size`` and each chunk is passed to ``model``.
        model: Callable applied to each chunk. When ``None`` a ``UNet`` is
            built on ``device`` and its weights are read from ``filepath``.
        filepath: Checkpoint used only when ``model`` is ``None``.
        return_detached: Detach the result before returning it.
        batch_size: Number of leading-dimension entries per model call.
        device: Device for the internally constructed model.

    Returns:
        The per-chunk outputs concatenated along dimension 0.
    """
    if model is None:
        model = UNet(in_channels=1, num_classes=1, out_channels=[4, 8, 16], num_conv=2, n_dim=3,
                     kernel_size=[3, 3, 3], same_shape=True).to(device)
        # map_location lets a checkpoint saved on another device be restored.
        model.load_state_dict(torch.load(filepath, map_location=device))
    with torch.no_grad():
        chunks = [model(mat[begin:begin + batch_size])
                  for begin in range(0, mat.size(0), batch_size)]
        mat = torch.cat(chunks, dim=0)
    if return_detached:
        mat = mat.detach()
    # The former ``del locals()[k]`` loop was removed: mutating the mapping
    # returned by locals() does not unbind function locals, so it freed
    # nothing. Locals are released when the function returns.
    torch.cuda.empty_cache()
    return mat
Beispiel #12
0
def denoise_trace(trace, model=None, filepath='/home/jupyter/notebooks/checkpoints/denoise_trace.pt', return_detached=True,
                  device=torch.device('cuda')):
    """Denoise a trace by running the model twice on its standardised values.

    The trace is standardised with its own mean and standard deviation,
    passed through ``model`` twice in a row, and mapped back to the original
    scale.

    Args:
        trace: Input tensor.
        model: Callable applied to the standardised trace. When ``None`` a
            ``UNet`` is built on ``device`` and loaded from ``filepath``.
        filepath: Checkpoint used only when ``model`` is ``None``.
        return_detached: Detach the result before returning it.
        device: Device for the internally constructed model.

    Returns:
        The denoised trace on the original scale.
    """
    if model is None:
        model = UNet(in_channels=1, num_classes=1, out_channels=[8, 16, 32], num_conv=2,
                     n_dim=1, kernel_size=3).to(device)
        # map_location lets a checkpoint saved on another device be restored.
        model.load_state_dict(torch.load(filepath, map_location=device))
    with torch.no_grad():
        mean = trace.mean()
        std = trace.std()
        # NOTE(review): a constant trace has std == 0 and yields NaNs here.
        pred = model((trace - mean) / std)
        pred = model(pred)
        pred = pred * std + mean
    if return_detached:
        pred = pred.detach()
    # The former ``del locals()[k]`` loop was removed: mutating the mapping
    # returned by locals() does not unbind function locals, so it freed
    # nothing. Locals are released when the function returns.
    torch.cuda.empty_cache()
    return pred
Beispiel #13
0
def attention_map(mat, model=None, filepath='/home/jupyter/notebooks/checkpoints/segmentation_count_hardmask.pt',
                  batch_size=5000, return_detached=True, device=torch.device('cuda')):
    """Average the model output over dimension 0 of ``mat``.

    ``mat`` is processed in chunks along dimension 0; the chunk outputs are
    concatenated and reduced with ``mean(0)``.

    Args:
        mat: Three-dimensional input tensor (its last two sizes are unpacked
            as ``nrow, ncol``).
        model: Callable applied to each chunk. When ``None`` a ``UNet`` is
            built on ``device`` and loaded from ``filepath``.
        filepath: Checkpoint used only when ``model`` is ``None``.
        batch_size: Upper bound on entries per chunk; reduced so that a chunk
            holds at most about 1e7 elements, but never below one entry.
        return_detached: Detach the result before returning it.
        device: Device for the internally constructed model.

    Returns:
        The mean over dimension 0 of the concatenated model outputs.
    """
    if model is None:
        model = UNet(in_channels=1, num_classes=1, out_channels=[4, 8, 16], num_conv=2, n_dim=3,
                     kernel_size=[3, 3, 3], same_shape=True).to(device)
        # map_location lets a checkpoint saved on another device be restored.
        model.load_state_dict(torch.load(filepath, map_location=device))
    nrow, ncol = mat.shape[1:]
    if batch_size * nrow * ncol > 1e7:
        # Clamp to at least 1: for frames larger than 1e7 elements the cap
        # used to become 0 and the batch count raised ZeroDivisionError.
        batch_size = max(1, int(1e7 / (nrow * ncol)))
    with torch.no_grad():
        chunks = [model(mat[begin:begin + batch_size])
                  for begin in range(0, mat.size(0), batch_size)]
        mat = torch.cat(chunks, dim=0).mean(0)
    if return_detached:
        mat = mat.detach()
    # The former ``del locals()[k]`` loop was removed: mutating the mapping
    # returned by locals() does not unbind function locals, so it freed
    # nothing. Locals are released when the function returns.
    torch.cuda.empty_cache()
    return mat
Beispiel #14
0
def test(weights_path):
    """Predict egg/pan masks for the training images and write visualisations.

    For every ``.jpg``/``.JPG``/``.png`` file in ``dataset/train/images/``
    the model output is thresholded at 0.5 (after a sigmoid), post-processed
    and written to ``test_vis/`` next to a copy of the input image.

    Args:
        weights_path: Path to the saved ``UNet`` state dict.
    """
    image_dir = 'dataset/train/images/'
    output_dir = 'test_vis/'

    # Get all images in train set
    image_names = [
        name for name in os.listdir(image_dir)
        if name.endswith(('.jpg', '.JPG', '.png'))
    ]

    # Initialize model and transfer to device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet()
    model = model.to(device)
    model.eval()

    # Load weights
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # Misc info
    img_size = 512

    # cv2.imwrite only returns False (no exception) when the target folder
    # is missing, so make sure it exists up front.
    os.makedirs(output_dir, exist_ok=True)

    # Predict on images
    for image_name in tqdm(image_names):

        # Load image, prepare for inference
        img = cv2.imread(os.path.join(image_dir, image_name))
        if img is None:
            # cv2.imread returns None for files it cannot decode.
            continue

        img_torch = prepare_image(img, img_size)

        with torch.no_grad():

            # Get predictions for image
            pred_egg_mask, pred_pan_mask = model(img_torch)

            # Threshold by 0.5
            pred_egg_mask = (torch.sigmoid(pred_egg_mask) >= 0.5).type(pred_egg_mask.dtype)
            pred_pan_mask = (torch.sigmoid(pred_pan_mask) >= 0.5).type(pred_pan_mask.dtype)

            pred_egg_mask, pred_pan_mask = pred_egg_mask.cpu().detach().numpy(), pred_pan_mask.cpu().detach().numpy()

        # Take the first sample/channel and scale the binary mask.
        # NOTE(review): the factor is 256, not 255 — confirm this is intended.
        pred_egg_mask, pred_pan_mask = pred_egg_mask[0][0] * 256, pred_pan_mask[0][0] * 256
        # Resize masks back to original shape
        pred_egg_mask, pred_pan_mask = postprocess_masks(img, pred_egg_mask, pred_pan_mask)

        # Every accepted extension is four characters long, so splitext
        # yields the same stem/extension as the former [:-4] slicing.
        stem, ext = os.path.splitext(image_name)
        cv2.imwrite(output_dir + stem + '_egg' + ext, pred_egg_mask)
        cv2.imwrite(output_dir + stem + '_pan' + ext, pred_pan_mask)
        cv2.imwrite(output_dir + image_name, img)
Beispiel #15
0
def test(device, model_path, dataset_path, out_path):
    """Colourise every image of ``dataset_path`` and save the results.

    Args:
        device: Torch device the network and inputs are moved to.
        model_path: Checkpoint path; when it does not exist the network keeps
            its freshly initialised weights.
        dataset_path: Location handed to ``GrayDataset``.
        out_path: Output directory; results are written as zero-padded
            ``NNNNNN.png`` files. It is created if missing.
    """
    network = UNet(1, 3).to(device)
    if os.path.exists(model_path):
        # map_location lets a checkpoint saved on another device be restored.
        network.load_state_dict(torch.load(model_path, map_location=device))
    dataset = GrayDataset(dataset_path, transform=val_transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=cpu_count())
    # Saving into a missing directory would fail on the first image.
    os.makedirs(out_path, exist_ok=True)
    with torch.no_grad():
        network.eval()
        for i, gray in enumerate(tqdm.tqdm(loader, desc="Testing",
                                           leave=False)):
            gray = gray.to(device)
            pred_color = network(gray)
            # Rescale by 0.5 and shift by 0.5 before converting to an image
            # (presumably undoing a (0.5, 0.5) normalisation — confirm).
            result = F.to_pil_image((pred_color.cpu().squeeze() * 0.5) + 0.5)
            result.save(os.path.join(out_path, "{:06d}.png".format(i)))
Beispiel #16
0
def test(device, gen_model, fake_dataset_path, out_dir):
    """Run the GAN generator over a gray-image dataset and save the outputs.

    Args:
        device: Torch device the generator and inputs are moved to.
        gen_model: Generator checkpoint path; when it does not exist the
            generator keeps its freshly initialised weights.
        fake_dataset_path: Location handed to ``GrayDataset``.
        out_dir: Output directory; results are written as zero-padded
            ``NNNNNN.png`` files. It is created if missing.
    """
    print("Test a gan")
    val_transform = tv.transforms.Compose([
        tv.transforms.Resize((224, 224)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, ), (0.5, ))
    ])
    fakedataset = GrayDataset(fake_dataset_path, transform=val_transform)
    fakeloader = torch.utils.data.DataLoader(fakedataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=cpu_count())
    generator = UNet(1, 3).to(device)
    if os.path.exists(gen_model):
        # map_location lets a checkpoint saved on another device be restored.
        generator.load_state_dict(torch.load(gen_model, map_location=device))
    # Saving into a missing directory would fail on the first image.
    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        generator.eval()
        for i, fake_data in enumerate(tqdm.tqdm(fakeloader)):
            fake_data = fake_data.to(device)
            fake = generator(fake_data)
            # Invert the Normalize((0.5,), (0.5,)) applied by val_transform.
            fake_im = fake.squeeze().cpu() * 0.5 + 0.5
            fake_im = tv.transforms.functional.to_pil_image(fake_im)
            fake_im.save(os.path.join(out_dir, "{:06d}.png".format(i)))
Beispiel #17
0
def inference(img_dir='data/samples',
              img_size=256,
              output_dir='outputs',
              weights='weights/best_miou.pt',
              unet=False):
    """Segment every image in ``img_dir`` and write colour-coded label maps.

    Args:
        img_dir: Folder scanned for ``.jpg``/``.jpeg``/``.png``/``.tiff``
            files.
        img_size: Target length of the longer image side before it is
            rounded down to a multiple of 32.
        output_dir: Folder (created if missing) receiving one output image
            per input, under the same file name.
        weights: Checkpoint whose ``'model'`` entry is the state dict.
        unet: Use ``UNet(30)`` instead of ``DeepLabV3Plus(30)``.
    """
    os.makedirs(output_dir, exist_ok=True)
    model = UNet(30) if unet else DeepLabV3Plus(30)
    model = model.to(device)
    state_dict = torch.load(weights, map_location=device)
    model.load_state_dict(state_dict['model'])
    model.eval()
    names = [
        n for n in os.listdir(img_dir)
        if os.path.splitext(n)[1] in ['.jpg', '.jpeg', '.png', '.tiff']
    ]
    with torch.no_grad():
        for name in tqdm(names):
            path = os.path.join(img_dir, name)
            img = cv2.imread(path)
            if img is None:
                # cv2.imread returns None for files it cannot decode.
                continue
            img_shape = img.shape
            longest = max(img_shape[:2])
            # Number of 32-pixel blocks per side; keep at least one so very
            # elongated images do not produce a zero-sized resize target.
            h = max(1.0, (img_shape[0] / longest * img_size) // 32)
            w = max(1.0, (img_shape[1] / longest * img_size) // 32)
            img = cv2.resize(img, (int(w * 32), int(h * 32)))
            # Reverse the channel order and move channels first.
            img = img[:, :, ::-1].transpose(2, 0, 1)
            # The reversed view has negative strides, which torch cannot
            # wrap, so copy it first. The legacy torch.FloatTensor(...)
            # constructor used before rejects CUDA devices.
            img = torch.from_numpy(np.ascontiguousarray(img))
            img = img.float().unsqueeze(0).to(device) / 255.
            output = model(img)[0].cpu().numpy().transpose(1, 2, 0)
            # Back to the original resolution, then pick the best class.
            output = cv2.resize(output, (img_shape[1], img_shape[0]))
            output = output.argmax(2)
            seg = np.zeros(img_shape, dtype=np.uint8)
            for ci, color in enumerate(VOC_COLORMAP):
                seg[output == ci] = color
            cv2.imwrite(os.path.join(output_dir, name), seg)
def get_model(model_path, model_type='unet11', problem_type='binary'):
    """Build a segmentation network for ``problem_type`` and load its weights.

    :param model_path: path to a checkpoint whose ``'model'`` entry holds
        the state dict (keys may carry a ``module.`` DataParallel prefix)
    :param model_type: 'UNet', 'UNet16', 'UNet11', 'LinkNet34', 'DLinkNet'
        (matched case-insensitively, so the default ``'unet11'`` resolves)
    :param problem_type: 'binary' (1 class), 'parts' (4 classes),
        'instruments' (8 classes)
    :return: the model in eval mode, moved to the GPU when CUDA is available
    :raises ValueError: for an unknown ``model_type`` or ``problem_type``
    """
    classes_by_problem = {'binary': 1, 'parts': 4, 'instruments': 8}
    if problem_type not in classes_by_problem:
        raise ValueError('Unknown problem_type: {!r}'.format(problem_type))
    num_classes = classes_by_problem[problem_type]

    # Lambdas defer the constructor lookup until a model is selected.
    builders = {
        'unet16': lambda: UNet16(num_classes=num_classes),
        'unet11': lambda: UNet11(num_classes=num_classes),
        'linknet34': lambda: LinkNet34(num_classes=num_classes),
        'unet': lambda: UNet(num_classes=num_classes),
        'dlinknet': lambda: D_LinkNet34(num_classes=num_classes,
                                        pretrained=True),
    }
    # Case-insensitive match: the default 'unet11' previously matched no
    # branch and failed later with UnboundLocalError.
    key = str(model_type).lower()
    if key not in builders:
        raise ValueError('Unknown model_type: {!r}'.format(model_type))
    model = builders[key]()

    # Read onto the CPU so GPU-saved checkpoints also load on CPU-only
    # machines; the model is moved to the GPU afterwards when possible.
    state = torch.load(str(model_path), map_location='cpu')
    state = {key.replace('module.', ''): value for key, value in state['model'].items()}
    model.load_state_dict(state)

    # Switch to eval mode before returning on *either* device; previously
    # the CUDA path returned early and left the model in training mode.
    model.eval()

    if torch.cuda.is_available():
        return model.cuda()

    return model
Beispiel #19
0
def predict():
    """Run the UNet over the test split and save one boolean mask per batch.

    Reads its configuration from the module-level ``opt`` object
    (``load_model_path``, ``use_gpu``, ``test_batch_size``, ``num_workers``,
    ``out_threshold``). Each mask is saved as ``<10000 + index>_mask.npy``
    under the hard-coded output folder below.
    """
    net = UNet(n_channels=1, n_classes=1)
    net.eval()
    # Load a (possibly multi-GPU) checkpoint into the single-device model.
    if opt.load_model_path:
        # Read onto the CPU; net.cuda() below moves the weights if requested.
        checkpoint = t.load(opt.load_model_path, map_location='cpu')
        state_dict = checkpoint['net']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # Strip the DataParallel prefix only when it is present; slicing
            # unconditionally corrupted keys saved from an unwrapped model.
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)  # restore the weights
        print('加载预训练模型{}'.format(opt.load_model_path))
    if opt.use_gpu:
        net.cuda()

    test_data = NodeDataSet(test=True)

    test_dataloader = DataLoader(test_data,
                                 opt.test_batch_size,
                                 shuffle=False,
                                 num_workers=opt.num_workers)
    for ii, full_img in enumerate(test_dataloader):
        # First [0] selects the image entry of the batch, second [0] its
        # first sample; unsqueeze restores a batch dimension of one.
        img_test = full_img[0][0].unsqueeze(0)
        if opt.use_gpu:
            img_test = img_test.cuda()

        with t.no_grad():
            output = net(img_test)
            probs = t.sigmoid(output).squeeze(0)

        full_mask = probs.squeeze().cpu().numpy()
        # NOTE: dense-CRF refinement and matplotlib visualisation used to be
        # sketched here as commented-out code; the CRF path was marked as
        # possibly not handling single-channel images.
        # Original note: predicted mask values were all very small (max 0.01).
        mask = full_mask > opt.out_threshold

        # Save the mask as .npy
        np.save('/home/bobo/data/test/test8/' + str(10000 + ii) + '_mask.npy',
                mask)

    print('测试完毕')
Beispiel #20
0
def train():
    """Train the single-channel UNet, validating and checkpointing each epoch.

    Configuration comes from the module-level ``opt`` object. Per epoch the
    function trains on the train split (BCE loss, optionally plus Dice
    loss), writes ``<checkpoint_root><epoch>_unet.pth``, computes the mean
    Dice loss on the validation split and, every tenth epoch, the recall on
    the test split. Metrics are sent to the ``Visualizer``.

    When ``opt.load_model_path`` is set, model, optimizer and epoch counter
    are restored from that checkpoint first.
    """
    if opt.use_gpu:
        # NOTE(review): the device index is hard-coded — confirm.
        t.cuda.set_device(1)

    # n_channels: medical images are single-channel grayscale;
    # n_classes: binary segmentation.
    net = UNet(n_channels=1, n_classes=1)
    if opt.use_gpu:
        # Move the model before the optimizer is built/restored so that
        # restored optimizer state ends up on the same device as the
        # parameters.
        net.cuda()
    optimizer = t.optim.SGD(net.parameters(),
                            lr=opt.learning_rate,
                            momentum=0.9,
                            weight_decay=0.0005)
    # Binary cross-entropy (suited to masks covering a large image area).
    criterion = t.nn.BCELoss()

    start_epoch = 0
    if opt.load_model_path:
        checkpoint = t.load(opt.load_model_path, map_location='cpu')

        # Load (possibly multi-GPU) parameters into the single model. Strip
        # the DataParallel prefix only when present; checkpoints written
        # without DataParallel have no such prefix.
        state_dict = checkpoint['net']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)  # restore the model
        optimizer.load_state_dict(checkpoint['optimizer'])  # restore optimizer
        start_epoch = checkpoint['epoch']  # restore the epoch counter

    # The learning rate is decayed whenever a milestone epoch is reached.
    if start_epoch == 0:
        scheduler = t.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=opt.milestones,
                                                     gamma=0.1,
                                                     last_epoch=-1)  # default
        print('从头训练 ,学习率为{}'.format(optimizer.param_groups[0]['lr']))
    else:
        scheduler = t.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=opt.milestones,
                                                     gamma=0.1,
                                                     last_epoch=start_epoch)
        print('加载预训练模型{}并从{}轮开始训练,学习率为{}'.format(
            opt.load_model_path, start_epoch, optimizer.param_groups[0]['lr']))

    # Wrap for multi-GPU execution.
    if opt.use_gpu:
        net = t.nn.DataParallel(net, device_ids=opt.device_ids)
        net.cuda()
        cudnn.benchmark = True

    # Visualisation helper.
    vis = Visualizer(opt.env)

    train_data = NodeDataSet(train=True)
    val_data = NodeDataSet(val=True)
    test_data = NodeDataSet(test=True)

    # Data loaders.
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=True,
                                num_workers=opt.num_workers)
    test_dataloader = DataLoader(test_data,
                                 opt.test_batch_size,
                                 shuffle=False,
                                 num_workers=opt.num_workers)
    for epoch in range(opt.max_epoch - start_epoch):
        # Absolute epoch index: keeps log entries, checkpoint names and the
        # stored epoch monotonic across resumed runs.
        abs_epoch = start_epoch + epoch
        print('开始 epoch {}/{}.'.format(start_epoch + epoch + 1, opt.max_epoch))
        epoch_loss = 0

        # Let the scheduler decide whether the learning rate changes.
        scheduler.step()

        for ii, (img, mask) in enumerate(train_dataloader):
            # Always bind the target; it used to be defined only on the GPU
            # path, so CPU training failed with a NameError.
            true_masks = mask
            if opt.use_gpu:
                img = img.cuda()
                true_masks = true_masks.cuda()
            masks_pred = net(img)

            # Probabilities via sigmoid.
            masks_probs = t.sigmoid(masks_pred)

            # Loss = binary cross-entropy (+ optional Dice loss).
            loss = criterion(masks_probs.view(-1), true_masks.view(-1))

            if opt.use_dice_loss:
                loss += dice_loss(masks_probs, true_masks)

            epoch_loss += loss.item()

            if ii % 2 == 0:
                vis.plot('训练集loss', loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Status line for this epoch.
        vis.log("epoch:{epoch},lr:{lr},loss:{loss}".format(
            epoch=abs_epoch, loss=loss.item(),
            lr=optimizer.param_groups[0]['lr']))

        # ii is the last batch index, so the batch count is ii + 1.
        vis.plot('每轮epoch的loss均值', epoch_loss / (ii + 1))
        # Save model, optimizer and the current epoch.
        state = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': abs_epoch
        }
        t.save(state, opt.checkpoint_root + '{}_unet.pth'.format(abs_epoch))

        # ============ validation ===================

        net.eval()
        # Metric: Dice loss between thresholded prediction and ground truth.
        tot = 0
        # No gradients are needed for evaluation.
        with t.no_grad():
            for jj, (img_val, mask_val) in enumerate(val_dataloader):
                true_mask_val = mask_val
                if opt.use_gpu:
                    img_val = img_val.cuda()
                    true_mask_val = true_mask_val.cuda()

                mask_pred = net(img_val)
                mask_pred = (t.sigmoid(mask_pred) > 0.5).float()  # threshold 0.5
                tot += dice_loss(mask_pred, true_mask_val).item()
        # jj is the last batch index, so the batch count is jj + 1.
        val_dice = tot / (jj + 1)
        vis.plot('验证集 Dice损失', val_dice)

        # ============ test-set recall ===================
        # Evaluated every ten epochs.
        if epoch % 10 == 0:
            result_test = []
            with t.no_grad():
                for img_test, _ in test_dataloader:
                    # Only the segmentation output is needed here; the
                    # ground-truth mask is ignored.
                    if opt.use_gpu:
                        img_test = img_test.cuda()
                    mask_pred_test = net(img_test)  # [1,1,512,512] per the original note

                    probs = t.sigmoid(mask_pred_test).squeeze().squeeze().cpu(
                    ).detach().numpy()  # [512,512] per the original note
                    mask = probs > opt.out_threshold
                    result_test.append(mask)

            # Recall over all predicted test masks.
            vis.plot('测试集二维召回率', getRecall(result_test).getResult())
        net.train()
Beispiel #21
0
    # Training branch (the enclosing condition is outside this excerpt):
    # run train() followed by test() once per epoch. loss_list is handed to
    # train(), which presumably appends the per-batch losses — confirm.
    loss_list = []
    for i in tqdm(range(args.epochs)):
        train(i, exp_lr_scheduler, loss_list)
        test()

    # Plot the collected losses and save the figure under plots/.
    plt.plot(loss_list)
    plt.title("UNet bs={}, ep={}, lr={}".format(args.batch_size, args.epochs,
                                                args.lr))
    plt.xlabel("Number of iterations")
    plt.ylabel("Average DICE loss per batch")
    plt.savefig("plots/{}-UNet_Loss_bs={}_ep={}_lr={}.png".format(
        args.save, args.batch_size, args.epochs, args.lr))

    # Keep the raw loss values as a NumPy array as well.
    np.save(
        'npy-files/loss-files/{}-UNet_Loss_bs={}_ep={}_lr={}.npy'.format(
            args.save, args.batch_size, args.epochs, args.lr),
        np.asarray(loss_list))

    # Save the final weights; the file name encodes batch size, epochs and lr.
    torch.save(
        model.state_dict(),
        '{}unetsmall-final-{}-{}-{}'.format(SAVE_MODEL_NAME, args.batch_size,
                                            args.epochs, args.lr))

# elif args.pred:
#     predict()

# Load-only branch: restore the weights from args.load and run prediction.
elif args.load is not None:
    model.load_state_dict(torch.load(args.load))
    #test()
    predict()
Beispiel #22
0
    # Use the device named by CUDA_SELECT when CUDA is available, else the CPU.
    device = torch.device(CUDA_SELECT if torch.cuda.is_available() else "cpu")
    model = UNet(num_classes=2, input_channels=1)

    # Wrap the model for data-parallel execution when several GPUs exist.
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)

    # To handle epoch start number and pretrained weight
    # epoch_start is kept as a string: '0' by default, otherwise the part of
    # the checkpoint file name before the first '.'.
    epoch_start = '0'
    if (use_pretrained):
        print("Loading Model {}".format(
            os.path.basename(pretrained_model_path)))
        # NOTE(review): when the model was wrapped in DataParallel above, the
        # checkpoint keys presumably need the matching prefix — confirm.
        model.load_state_dict(torch.load(pretrained_model_path))
        epoch_start = os.path.basename(pretrained_model_path).split('.')[0]
        print(epoch_start)

    # Training loader uses batch_size; devLoader uses the DataLoader default
    # batch size; displayLoader iterates the same validation files with
    # val_batch_size.
    trainLoader = DataLoader(DatasetImageMaskLocal(train_file_names,
                                                   object_type,
                                                   mode='train'),
                             batch_size=batch_size)
    devLoader = DataLoader(
        DatasetImageMaskLocal(val_file_names, object_type, mode='valid'))
    displayLoader = DataLoader(DatasetImageMaskLocal(val_file_names,
                                                     object_type,
                                                     mode='valid'),
                               batch_size=val_batch_size)

    optimizer = Adam(model.parameters(), lr=1e-4)
Beispiel #23
0
if __name__ == '__main__':
    # Command-line entry point: load a trained UNet, predict the mask of a
    # single image and save it as a black/white picture.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--input', required=True)
    parser.add_argument('--output', default='predicted.jpg')
    parser.add_argument('--output-dir', default='logs')

    opt = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet(3, 1)
    # NOTE(review): the model is wrapped in DataParallel only on CUDA, so the
    # checkpoint key names must match the wrapping chosen here — confirm.
    if device == torch.device('cuda'):
        net = nn.DataParallel(net)
    net.to(device=device)

    # map_location lets a checkpoint saved on another device be restored.
    net.load_state_dict(
        torch.load(os.path.expanduser(opt.model), map_location=device))
    net.eval()

    # image
    image = Image.open(os.path.expanduser(opt.input))
    image = CarvanaDatasetTransforms([256, 512]).transform(image)
    # Tensor.to returns a new tensor; the result used to be discarded, which
    # left the input on the CPU.
    image = image.to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        mask = net(image.unsqueeze(0))[0]
        mask = torch.sigmoid(mask)
    mask = mask.squeeze(0).cpu().detach().numpy()
    mask = mask > 0.5

    mask = Image.fromarray((mask * 255).astype(np.uint8))

    # Create the output folder if needed so saving cannot fail on it.
    os.makedirs(opt.output_dir, exist_ok=True)
    mask.save(os.path.join(opt.output_dir, opt.output))
Beispiel #24
0
    # Read the test configuration (the enclosing function header is outside
    # this excerpt).
    test_file = config_params['test_images_txt']
    input_size = config_params['input_size']
    num_channels = config_params['num_channels']
    n_classes = config_params['n_classes']
    bilinear = config_params['bilinear']

    # torch.manual_seed(config_params['seed'])
    # RandomSampler with batch_size=1: each draw yields one randomly chosen
    # test sample.
    test_set = NucleusTestDataset(test_file, input_size)
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             sampler=RandomSampler(test_set))

    # Inference device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load Model
    model = UNet(n_channels=num_channels,
                 n_classes=n_classes,
                 bilinear=bilinear).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Threshold for prediction
    threshold = float(args.threshold)

    # Get test image
    # Only the first batch of the loader is used; idx is not used afterwards.
    img, idx = next(iter(test_loader))
    pred = predict_mask(model, img, threshold, device)

    # Visualise Prediction
    visualize(img, pred)
Beispiel #25
0
def train(device, gen_model, disc_model, real_dataset_path, epochs):
    """Train a generator/critic pair, checkpointing both after every iteration.

    Each of the ``epochs`` outer iterations performs ``n_critic`` critic
    updates (one real batch each) followed by a single generator update, and
    then writes both state dicts back to ``gen_model`` / ``disc_model``.
    The critic loss is ``mean(D(fake)) - mean(D(real))`` plus the penalty
    term returned by ``calc_gp``; the generator maximises ``mean(D(G(z)))``.

    Args:
        device: torch device that models and batches are moved to.
        gen_model: path of the generator checkpoint; loaded if it exists and
            overwritten after every outer iteration.
        disc_model: path of the critic checkpoint, handled the same way.
        real_dataset_path: location handed to ``ColorDataset``.
        epochs: number of outer iterations (i.e. generator updates).
    """
    train_transform = tv.transforms.Compose([
        tv.transforms.Resize((224, 224)),
        tv.transforms.RandomHorizontalFlip(0.5),
        tv.transforms.RandomVerticalFlip(0.5),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, ), (0.5, ))
    ])

    realdataset = ColorDataset(real_dataset_path, transform=train_transform)
    realloader = torch.utils.data.DataLoader(realdataset,
                                             batch_size=20,
                                             shuffle=True,
                                             num_workers=cpu_count(),
                                             drop_last=True)
    realiter = iter(realloader)

    discriminator = discriminator_model(3, 1024).to(device)
    disc_optimizer = torch.optim.Adam(discriminator.parameters(),
                                      lr=0.0001,
                                      betas=(0, 0.9))
    if os.path.exists(disc_model):
        # map_location lets a checkpoint saved on GPU be resumed on a
        # CPU-only host (and vice versa).
        discriminator.load_state_dict(
            torch.load(disc_model, map_location=device))

    generator = UNet(1, 3).to(device)
    gen_optimizer = torch.optim.Adam(generator.parameters(),
                                     lr=0.0001,
                                     betas=(0, 0.9))
    if os.path.exists(gen_model):
        generator.load_state_dict(torch.load(gen_model, map_location=device))

    # Scalar +1 / -1 used as the upstream gradient in backward(): passing
    # -1 turns "minimise" into "maximise" for that term.
    one = torch.FloatTensor([1])
    mone = one * -1
    one = one.to(device).squeeze()
    mone = mone.to(device).squeeze()

    n_critic = 5
    lam = 10
    for _ in tqdm.trange(epochs, desc="Epochs"):
        # ---- critic updates -------------------------------------------
        for param in discriminator.parameters():
            param.requires_grad = True

        for _ in range(n_critic):
            real_data, realiter = try_iter(realiter, realloader)
            real_data = real_data.to(device)

            disc_optimizer.zero_grad()

            disc_real = discriminator(real_data)
            real_cost = torch.mean(disc_real)
            real_cost.backward(mone)

            fake_data = torch.randn(real_data.shape[0], 1, 224, 224)
            fake_data = fake_data.to(device)
            # The generator is not updated in this phase and its gradients
            # are zeroed before the generator step, so skip building its
            # autograd graph here.
            with torch.no_grad():
                generated = generator(fake_data)
            disc_fake = discriminator(generated)
            fake_cost = torch.mean(disc_fake)
            fake_cost.backward(one)

            # NOTE(review): calc_gp receives the 1-channel noise tensor, not
            # the generated images — confirm against calc_gp whether that is
            # intended.
            gradient_penalty = calc_gp(device, discriminator, real_data,
                                       fake_data, lam)
            gradient_penalty.backward()

            disc_optimizer.step()

        # ---- generator update -----------------------------------------
        # Freeze the critic so only generator parameters receive gradients.
        for param in discriminator.parameters():
            param.requires_grad = False
        gen_optimizer.zero_grad()

        fake_data = torch.randn(real_data.shape[0], 1, 224, 224)
        fake_data = fake_data.to(device)
        disc_g = discriminator(generator(fake_data)).mean()
        disc_g.backward(mone)
        gen_optimizer.step()

        torch.save(generator.state_dict(), gen_model)
        torch.save(discriminator.state_dict(), disc_model)
Beispiel #26
0
def start():
    """Command-line entry point: train or evaluate the multiclass UNet.

    Parses the CLI options, builds a 3-class ``UNet`` together with the
    requested optimizer and a ``DICELossMultiClass`` criterion, and then:

    * with ``--train``: splits ``BraTSDatasetUnet(args.data)`` 90/10 into
      train/validation sets, trains for ``--epochs`` epochs (validating after
      each one), writes the loss plot, the loss ``.npy`` file and the model
      weights, and finally runs ``test_only`` on ``./pdf_data/``;
    * otherwise, with ``--load PATH``: loads the weights from ``PATH`` and
      runs ``test_only`` on ``TestDataset(args.data)``.

    If neither ``--train`` nor ``--load`` is given, nothing happens after the
    model has been constructed.
    """
    import os  # local import: only needed to create the output folders

    parser = argparse.ArgumentParser(
        description='UNet + BDCLSTM for BraTS Dataset')
    parser.add_argument('--batch-size',
                        type=int,
                        default=4,
                        metavar='N',
                        help='input batch size for training (default: 4)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=4,
                        metavar='N',
                        help='input batch size for testing (default: 4)')
    parser.add_argument('--train',
                        action='store_true',
                        default=False,
                        help='Argument to train model (default: False)')
    parser.add_argument('--epochs',
                        type=int,
                        default=2,
                        metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.9,
                        metavar='B1',
                        help='Adam beta1 (default: 0.9)')
    parser.add_argument('--beta2',
                        type=float,
                        default=0.999,
                        metavar='B2',
                        help='Adam beta2 (default: 0.999)')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='enables CUDA training (default: False)')
    parser.add_argument('--log-interval',
                        type=int,
                        default=1,
                        metavar='N',
                        help='batches to wait before logging training status')
    parser.add_argument('--size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='imsize')
    parser.add_argument('--load',
                        type=str,
                        default=None,
                        metavar='str',
                        help='weight file to load (default: None)')
    parser.add_argument('--data',
                        type=str,
                        default='./Data/',
                        metavar='str',
                        help='folder that contains data')
    parser.add_argument('--save',
                        type=str,
                        default='OutMasks',
                        metavar='str',
                        help='Identifier to save npy arrays with')
    parser.add_argument('--modality',
                        type=str,
                        default='flair',
                        metavar='str',
                        help='Modality to use for training (default: flair)')
    parser.add_argument('--optimizer',
                        type=str,
                        default='SGD',
                        metavar='str',
                        help='Optimizer: SGD or ADAM (default: SGD)')

    args = parser.parse_args()
    # Silently fall back to CPU when CUDA was requested but is unavailable.
    args.cuda = args.cuda and torch.cuda.is_available()

    DATA_FOLDER = args.data

    # Multiclass model (use num_classes=2 for the binary variant).
    model = UNet(num_channels=1, num_classes=3)

    if args.cuda:
        model.cuda()

    if args.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.99)
    elif args.optimizer == 'ADAM':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
    else:
        # Fail early with a usage message instead of a NameError on
        # `optimizer` deep inside the training loop.
        parser.error("unsupported optimizer {!r}; expected 'SGD' or "
                     "'ADAM'".format(args.optimizer))

    # Defining Loss Function
    criterion = DICELossMultiClass()

    if args.train:
        # Create the output folders up front so that a missing directory
        # cannot abort the run after training but before the weights are
        # saved.
        os.makedirs('./plots', exist_ok=True)
        os.makedirs('./npy-files/loss-files', exist_ok=True)

        full_dataset = BraTSDatasetUnet(DATA_FOLDER,
                                        im_size=[args.size, args.size],
                                        transform=tr.ToTensor())

        # 90 % train / 10 % validation split.
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, validation_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, test_size])

        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=1)
        validation_loader = DataLoader(validation_dataset,
                                       batch_size=args.test_batch_size,
                                       shuffle=False,
                                       num_workers=1)

        print("Training Data : ", len(train_loader.dataset))
        print("Validaion Data : ", len(validation_loader.dataset))

        loss_list = []
        # Named *_time so the local does not shadow this function's name.
        start_time = timer()
        for i in tqdm(range(args.epochs)):
            train(model, i, loss_list, train_loader, optimizer, criterion,
                  args)
            test(model, validation_loader, criterion, args, validation=True)
        end_time = timer()
        print("Training completed in {:0.2f}s".format(end_time - start_time))

        plt.plot(loss_list)
        plt.title("UNet bs={}, ep={}, lr={}".format(args.batch_size,
                                                    args.epochs, args.lr))
        plt.xlabel("Number of iterations")
        plt.ylabel("Average DICE loss per batch")
        plt.savefig("./plots/{}-UNet_Loss_bs={}_ep={}_lr={}.png".format(
            args.save, args.batch_size, args.epochs, args.lr))

        np.save(
            './npy-files/loss-files/{}-UNet_Loss_bs={}_ep={}_lr={}.npy'.format(
                args.save, args.batch_size, args.epochs, args.lr),
            np.asarray(loss_list))
        print("Testing Validation")
        test(model, validation_loader, criterion, args, save_output=True)
        torch.save(
            model.state_dict(),
            'unet-multiclass-model-{}-{}-{}'.format(args.batch_size,
                                                    args.epochs, args.lr))

        print("Testing PDF images")
        test_dataset = TestDataset('./pdf_data/',
                                   im_size=[args.size, args.size],
                                   transform=tr.ToTensor())
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.test_batch_size,
                                 shuffle=False,
                                 num_workers=1)
        print("Test Data : ", len(test_loader.dataset))
        test_only(model, test_loader, criterion, args)

    elif args.load is not None:
        test_dataset = TestDataset(DATA_FOLDER,
                                   im_size=[args.size, args.size],
                                   transform=tr.ToTensor())
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.test_batch_size,
                                 shuffle=False,
                                 num_workers=1)
        print("Test Data : ", len(test_loader.dataset))
        # Deserialise onto the CPU so GPU-trained weights also load on a
        # CPU-only host; load_state_dict copies them into the model's own
        # parameters wherever those live.
        model.load_state_dict(torch.load(args.load, map_location='cpu'))
        test_only(model, test_loader, criterion, args)
Beispiel #27
0
def main(cfg: DictConfig):
    """Evaluate a checkpoint or train a segmentation UNet on generated data.

    Validation and test ``SegmentationDataset`` loaders are always built from
    ``cfg.data_seg``. Then:

    * ``cfg.train`` false: load ``cfg.eval_checkpoint`` into a ``UNet`` and
      run ``trainer.test`` on the test loaders;
    * ``cfg.train`` true: build the generated-image train/val loaders (from
      disk or on the fly, depending on ``cfg.data_gen.load_from_disk``), fit
      the module, then run ``trainer.test``.

    In both cases the collected metrics are logged as a table.

    Args:
        cfg: configuration object; the fields read here are ``seed``,
            ``data_seg``, ``data_gen``, ``dataloader``, ``trainer``,
            ``train``, ``eval_checkpoint``, ``name`` and ``wandb``.
    """
    # Route all messages through the module logger (kept under a separate
    # name instead of shadowing the builtin `print`).
    log = logging.getLogger(__name__).info
    log(OmegaConf.to_yaml(cfg))
    pl.seed_everything(cfg.seed)

    # Create validation and test segmentation datasets
    # NOTE: The batch size must be 1 for test because the masks are different
    # sizes, and evaluation should be done using the mask at the original
    # resolution.
    val_dataloaders = []
    test_dataloaders = []
    for _cfg in cfg.data_seg.data:
        kwargs = dict(images_dir=_cfg.images_dir,
                      labels_dir=_cfg.labels_dir,
                      image_size=cfg.data_seg.image_size)
        val_dataset = SegmentationDataset(**kwargs, crop=True)
        test_dataset = SegmentationDataset(**kwargs,
                                           crop=_cfg.crop,
                                           resize_mask=False)
        val_dataloaders.append(DataLoader(val_dataset, **cfg.dataloader))
        test_dataloaders.append(
            DataLoader(test_dataset, **{
                **cfg.dataloader, 'batch_size': 1
            }))

    # Evaluate only
    if not cfg.train:
        assert cfg.eval_checkpoint is not None

        # Print dataset info
        for i, dataloader in enumerate(test_dataloaders):
            dataset = dataloader.dataset
            log(f'Test dataset / dataloader size [{i}]: '
                f'{len(dataset)} / {len(dataloader)}')

        # Create trainer
        trainer = pl.Trainer(**cfg.trainer)

        # Load checkpoint(s). Only a *leading* "net." is stripped so that
        # keys which merely contain that substring (e.g. "subnet.weight")
        # are left intact.
        net = UNet().eval()
        checkpoint = torch.load(cfg.eval_checkpoint, map_location='cpu')
        prefix = 'net.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint["state_dict"].items()
        }
        net.load_state_dict(state_dict)
        log(f'Loaded checkpoint from {cfg.eval_checkpoint}')

        # Create module
        module = SementationModule(net, cfg).eval()

        # Compute test results
        trainer.test(module, test_dataloaders=test_dataloaders)

        # Pretty print results
        table = utils.get_metrics_as_table(trainer.callback_metrics)
        log('\n' + str(table.round(decimals=3)))

    # Train
    else:

        # Generated images: load from disk
        if cfg.data_gen.load_from_disk:
            log('Loading images from disk')

            # Transforms (the same pipeline is used for train and val)
            train_transform = val_transform = A.Compose([
                A.Resize(cfg.data_gen.image_size, cfg.data_gen.image_size),
                A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                ToTensorV2()
            ])

            # Loaders
            gan_train_dataloader, gan_val_dataloader = create_train_and_val_dataloaders(
                cfg,
                train_transform=train_transform,
                val_transform=val_transform)

        # Generated images: generate on the fly
        else:
            log('Loading images on-the-fly')

            # Create GAN dataset
            gan_train_dataset = create_gan_dataset(cfg.data_gen)

            # GAN training dataloader
            # NOTE: Only 1 process (num_workers=0) supported
            gan_train_dataloader = DataLoader(gan_train_dataset, batch_size=1)

            # Create GAN validation batches (at least one).
            log('Creating new GAN validation set.')
            num_batches = max(
                1, cfg.data_gen.val_images // cfg.data_gen.kwargs.batch_size)
            gan_val_batches = utils.get_subset_of_dataset(
                dataset=gan_train_dataset, num_batches=num_batches)
            gan_val_dataset = TensorDataset(*gan_val_batches)

            # Save example images from GAN validation dataset
            fname = 'generated-val-examples.png'
            utils.save_overlayed_images(gan_val_batches,
                                        filename=fname,
                                        is_mask=True)
            log(f'Saved visualization images to {fname}')

            # Validation dataloader
            gan_val_dataloader = DataLoader(gan_val_dataset, **cfg.dataloader)

        # Summary of dataset/dataloader sizes
        log(f'Generated train {utils.get_dl_size(gan_train_dataloader)}')
        log(f'Generated val {utils.get_dl_size(gan_val_dataloader)}')
        for i, dl in enumerate(val_dataloaders):
            log(f'Seg val [{i}] {utils.get_dl_size(dl)}')

        # Validation dataloaders: generated set first, then the real ones
        val_dataloaders = [gan_val_dataloader, *val_dataloaders]

        # Checkpointer
        callbacks = [
            pl.callbacks.ModelCheckpoint(monitor='train_loss',
                                         save_top_k=20,
                                         save_last=True,
                                         verbose=True),
            pl.callbacks.LearningRateMonitor('step')
        ]

        # Logging
        logger = pl.loggers.WandbLogger(name=cfg.name) if cfg.wandb else True

        # Trainer
        trainer = pl.Trainer(logger=logger, callbacks=callbacks, **cfg.trainer)

        # Lightning
        net = UNet().train()
        module = SementationModule(net, cfg)

        # Train
        trainer.fit(module,
                    train_dataloader=gan_train_dataloader,
                    val_dataloaders=val_dataloaders)

        # Test
        trainer.test(module, test_dataloaders=test_dataloaders)

        # Pretty print results
        table = utils.get_metrics_as_table(trainer.callback_metrics)
        log('\n' + str(table.round(decimals=3)))
Beispiel #28
0
	COLOR1 = [255, 0, 0]
	COLOR2 = [0, 0, 255]


#------------------------------------------------------------------------------
#	Create model and load weights
#------------------------------------------------------------------------------
# Build the segmentation network with a MobileNetV2 backbone and two output
# classes.
# NOTE(review): pretrained_backbone=None presumably skips loading separate
# backbone weights because the full checkpoint is loaded below — confirm
# against the UNet constructor.
model = UNet(
    backbone="mobilenetv2",
    num_classes=2,
	pretrained_backbone=None
)
if args.use_cuda:
	model = model.cuda()
# The checkpoint is deserialised onto the CPU regardless of where it was
# saved; only its 'state_dict' entry is used.
trained_dict = torch.load(args.checkpoint, map_location="cpu")['state_dict']
# strict=False: keys that are missing from or unexpected in the checkpoint are
# tolerated instead of raising.
# NOTE(review): a key-name mismatch would therefore leave layers at their
# initial values without any error.
model.load_state_dict(trained_dict, strict=False)
# Switch to evaluation mode for inference.
model.eval()


#------------------------------------------------------------------------------
#   Predict frames
#------------------------------------------------------------------------------
i = 0
while(cap.isOpened()):
	# Read frame from camera
	start_time = time()
	_, frame = cap.read()
	image = cv2.transpose(frame[...,::-1])
	h, w = image.shape[:2]
	read_cam_time = time()
Beispiel #29
0
    model2 = torchvision.models.segmentation.deeplabv3_resnet101(
        pretrained=False, num_classes=1)
    model2 = model2.to(device)
    modelName2 = model2.__class__.__name__
    model3 = UNet(1, 1).to(device)
    modelName3 = model3.__class__.__name__

    model1_checkpoint = torch.load(
        train().checkpointsPath + '/' + modelName1 + '/' +
        '2019-08-30 13:21:52.559302_epoch-5_dice-0.4743926368317377.pth')
    model2_checkpoint = torch.load(
        train().checkpointsPath + '/' + modelName2 + '/' +
        '2019-08-22 08:37:06.839794_epoch-1_dice-0.4479589270841744.pth')
    model3_checkpoint = torch.load(
        train().checkpointsPath + '/' + modelName3 + '/' +
        '2019-09-03 03:21:05.647040_epoch-253_dice-0.46157537277322264.pth')

    model1.load_state_dict(model1_checkpoint['model_state_dict'])
    model2.load_state_dict(model2_checkpoint['model_state_dict'])
    model3.load_state_dict(model3_checkpoint['model_state_dict'])

    try:
        # Create model Directory
        train().main(model1, model2, model3, device)
    except KeyboardInterrupt:
        print('Interrupted')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
def main():
    """Train a UNet on CityScapes, optionally resuming from a snapshot.

    Reads its configuration from the module-level ``train_args``,
    ``val_args``, ``train_record``, ``ckpt_path``, ``exp_name``,
    ``num_classes`` and ``ignored_label``. When ``train_args['snapshot']`` is
    non-empty, network and optimizer state are restored from
    ``ckpt_path/exp_name`` and the start epoch and best-so-far records are
    parsed from the underscore-separated snapshot file name.
    """
    net = UNet(num_classes=num_classes).cuda()
    snapshot = train_args['snapshot']
    snapshot_dir = os.path.join(ckpt_path, exp_name)

    if len(snapshot) == 0:
        curr_epoch = 0
    else:
        # print() with a single argument behaves the same on Python 2 and 3;
        # the former print statement was a SyntaxError on Python 3.
        print('training resumes from ' + snapshot)
        net.load_state_dict(torch.load(os.path.join(snapshot_dir, snapshot)))
        # Fields 1, 3 and 6 of the snapshot name carry the epoch, the
        # validation loss and the mean IU respectively.
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1])
        train_record['best_val_loss'] = float(split_snapshot[3])
        train_record['corr_mean_iu'] = float(split_snapshot[6])
        train_record['corr_epoch'] = curr_epoch

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Transforms applied jointly to image and mask.
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    # Masks: the ignored label is remapped to the last class index, which
    # gets zero weight in the loss below.
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label, num_classes - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train', simul_transform=train_simul_transform, transform=img_transform,
                           target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=train_args['batch_size'], num_workers=16, shuffle=True)
    val_set = CityScapes('val', simul_transform=val_simul_transform, transform=img_transform,
                         target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=val_args['batch_size'], num_workers=16, shuffle=False)

    weight = torch.ones(num_classes)
    weight[num_classes - 1] = 0
    criterion = CrossEntropyLoss2d(weight).cuda()

    # don't use weight_decay for bias; biases also get twice the base lr
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name.endswith('bias')],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if not name.endswith('bias')],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=0.9, nesterov=True)

    if len(snapshot) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(snapshot_dir, 'opt_' + snapshot)))
        # The restored state carries the old learning rates; re-apply the
        # currently configured ones.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    # Creates ckpt_path and the experiment sub-folder in one go, including
    # any missing parents.
    os.makedirs(snapshot_dir, exist_ok=True)

    for epoch in range(curr_epoch, train_args['epoch_num']):
        train(train_loader, net, criterion, optimizer, epoch)
        validate(val_loader, net, criterion, optimizer, epoch, restore_transform)