Example #1
def main():
    net = U_Net(img_ch=1, num_classes=3).to(device)

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(384),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    center_crop = joint_transforms.CenterCrop(crop_size)
    train_input_transform = extended_transforms.ImgToTensor()

    target_transform = extended_transforms.MaskToTensor()
    make_dataset_fn = bladder.make_dataset_v2
    train_set = bladder.Bladder(data_path,
                                'train',
                                joint_transform=train_joint_transform,
                                center_crop=center_crop,
                                transform=train_input_transform,
                                target_transform=target_transform,
                                make_dataset_fn=make_dataset_fn)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    if loss_name == 'dice_':
        criterion = SoftDiceLossV2(activation='sigmoid',
                                   num_classes=3).to(device)
    elif loss_name == 'bcew_':
        criterion = nn.BCEWithLogitsLoss().to(device)

    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    train(train_loader, net, criterion, optimizer, n_epoch, 0)
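The dice_ branch hands off to the repo's SoftDiceLossV2. As a rough illustration of what a sigmoid-activated, multi-class soft-Dice criterion typically computes (a generic sketch, not the repo's actual implementation):

import torch
import torch.nn as nn

class SoftDiceSketch(nn.Module):
    """Generic soft-Dice loss: 1 minus the mean Dice score over classes."""
    def __init__(self, num_classes, eps=1e-6):
        super().__init__()
        self.num_classes = num_classes
        self.eps = eps

    def forward(self, logits, targets):
        # logits, targets: (N, C, H, W); targets are binary per-class masks
        probs = torch.sigmoid(logits)
        dims = (0, 2, 3)
        intersection = (probs * targets).sum(dims)
        cardinality = probs.sum(dims) + targets.sum(dims)
        dice = (2.0 * intersection + self.eps) / (cardinality + self.eps)
        return 1.0 - dice.mean()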
Example #2
    def transform_image_and_mask_tt(self, image, mask, angle=None, crop_size=None):
        assert not self.use_iaa
        transforms = [JT.RandomHorizontallyFlip()]
        if crop_size is not None:
            transforms.append(JT.RandomCrop(size=crop_size))
        if angle is not None:
            transforms.append(JT.RandomRotate(degree=angle))
        jt_random = JT.RandomOrderApply(transforms)

        jt_transform = JT.Compose([
            JT.ToPILImage(),
            jt_random,
            JT.ToNumpy(),
        ])

        return jt_transform(image, mask)
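Assuming the enclosing dataset object exposes this method and use_iaa is disabled, a hypothetical call site (the names and sizes here are illustrative only):

import numpy as np

image = np.zeros((256, 256, 3), dtype=np.uint8)  # placeholder image
mask = np.zeros((256, 256), dtype=np.uint8)      # placeholder mask

aug_image, aug_mask = dataset.transform_image_and_mask_tt(
    image, mask, angle=10, crop_size=224)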
Example #3
def main():
    net = PSPNet(num_classes=voc.num_classes).cuda()

    if len(args['snapshot']) == 0:
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth')))
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'],
                                                args['stride_rate'],
                                                voc.ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(args['val_img_display_size']),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train',
                        joint_transform=train_joint_transform,
                        sliding_crop=sliding_crop,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=1,
                              shuffle=True,
                              drop_last=True)
    val_set = voc.VOC('val',
                      transform=val_input_transform,
                      sliding_crop=sliding_crop,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            drop_last=True)

    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=voc.ignore_label).cuda()

    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, args,
          val_loader, visualize)
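The resume branch above assumes snapshot filenames that alternate keys and values separated by underscores, so the odd split positions hold the numbers it restores. A hypothetical name matching that parser (the key spellings are illustrative; note that if the name carries a '.pth' extension, the final field needs it stripped, as Example #9 does with [:-4]):

snapshot = 'epoch_166_loss_0.4610_acc_0.9298_acc-cls_0.8741_mean-iu_0.6310_fwavacc_0.8721'
parts = snapshot.split('_')
print(int(parts[1]), float(parts[3]), float(parts[5]))  # 166 0.461 0.9298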
Example #4
def main(train_args):
    print('No of classes', doc.num_classes)
    if train_args['network'] == 'psp':
        net = PSPNet(num_classes=doc.num_classes,
                     resnet=resnet,
                     res_path=res_path).cuda()
    elif train_args['network'] == 'mfcn':
        net = MFCN(num_classes=doc.num_classes, use_aux=True).cuda()
    elif train_args['network'] == 'psppen':
        net = PSPNet(num_classes=doc.num_classes,
                     resnet=resnet,
                     res_path=res_path).cuda()
    print("number of cuda devices = ", torch.cuda.device_count())
    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, 'model_' + exp_name,
                             train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net = torch.nn.DataParallel(net,
                                device_ids=list(
                                    range(torch.cuda.device_count())))
    net.train()
    mean_std = ([0.9584, 0.9588, 0.9586], [0.1246, 0.1223, 0.1224])
    weight = torch.FloatTensor(doc.num_classes)
    #weight[0] = 1.0/0.5511 # background
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(3),
        #simul_transforms.RandomHorizontallyFlip()
        simul_transforms.Scale(train_args['input_size']),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.RandomGaussianBlur(),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    train_set = doc.DOC('train',
                        Dataroot,
                        joint_transform=train_simul_transform,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['train_batch_size'],
                              num_workers=1,
                              shuffle=True,
                              drop_last=True)
    train_loader_temp = DataLoader(train_set,
                                   batch_size=1,
                                   num_workers=1,
                                   shuffle=True,
                                   drop_last=True)
    train_args['No_train_images'] = len(train_loader_temp)
    del train_loader_temp
    if train_args['No_train_images'] == 87677:
        train_args['Type_of_train_image'] = 'All Synthetic slide image'
    elif train_args['No_train_images'] == 150:
        train_args['Type_of_train_image'] = 'Real image'
    elif train_args['No_train_images'] == 84641:
        train_args['Type_of_train_image'] = 'Synthetic image'
    elif train_args['No_train_images'] == 151641:
        train_args['Type_of_train_image'] = 'Real + Synthetic image'
    val_set = doc.DOC('val',
                      Dataroot,
                      joint_transform=val_simul_transform,
                      transform=val_input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            drop_last=True)
    #criterion = CrossEntropyLoss2d(weight = weight, size_average = True, ignore_index = doc.ignore_label).cuda()
    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=doc.ignore_label).cuda()
    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * train_args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        train_args['lr'],
        'weight_decay':
        train_args['weight_decay']
    }],
                          momentum=train_args['momentum'])
    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, 'model_' + exp_name,
                             'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']
    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(train_args) + '\n\n')
    train(train_loader, net, criterion, optimizer, curr_epoch, train_args,
          val_loader)
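This example and the previous one share the same two-group SGD setup: biases train at twice the base learning rate and receive no weight decay. The pattern in isolation (a minimal sketch with stand-in hyper-parameters):

import torch.nn as nn
import torch.optim as optim

model = nn.Conv2d(3, 8, 3)  # stand-in network
base_lr, wd, momentum = 1e-2, 1e-4, 0.9

biases = [p for n, p in model.named_parameters() if n.endswith('bias')]
weights = [p for n, p in model.named_parameters() if not n.endswith('bias')]

optimizer = optim.SGD(
    [{'params': biases, 'lr': 2 * base_lr},  # biases: double LR, no weight decay
     {'params': weights, 'lr': base_lr, 'weight_decay': wd}],
    momentum=momentum)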
Example #5
def train_with_ignite(networks, dataset, data_dir, batch_size, img_size,
                      epochs, lr, momentum, num_workers, optimizer, logger):

    from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
    from ignite.metrics import Loss
    from utils.metrics import MultiThresholdMeasures, Accuracy, IoU, F1score

    # device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # build model
    model = get_network(networks)

    # log model summary
    input_size = (3, img_size, img_size)
    summarize_model(model.to(device), input_size, logger, batch_size, device)

    # build loss
    loss = torch.nn.BCEWithLogitsLoss()

    # build optimizer and scheduler
    model_optimizer = get_optimizer(optimizer, model, lr, momentum)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(model_optimizer)

    # transforms on both image and mask
    train_joint_transforms = jnt_trnsf.Compose([
        jnt_trnsf.RandomCrop(img_size),
        jnt_trnsf.RandomRotate(5),
        jnt_trnsf.RandomHorizontallyFlip()
    ])

    # transforms only on images
    train_image_transforms = std_trnsf.Compose([
        std_trnsf.ColorJitter(0.05, 0.05, 0.05, 0.05),
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    test_joint_transforms = jnt_trnsf.Compose([jnt_trnsf.Safe32Padding()])

    test_image_transforms = std_trnsf.Compose([
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # transforms only on mask
    mask_transforms = std_trnsf.Compose([std_trnsf.ToTensor()])

    # build train / test loader
    train_loader = get_loader(dataset=dataset,
                              data_dir=data_dir,
                              train=True,
                              joint_transforms=train_joint_transforms,
                              image_transforms=train_image_transforms,
                              mask_transforms=mask_transforms,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)

    test_loader = get_loader(dataset=dataset,
                             data_dir=data_dir,
                             train=False,
                             joint_transforms=test_joint_transforms,
                             image_transforms=test_image_transforms,
                             mask_transforms=mask_transforms,
                             batch_size=1,
                             shuffle=False,
                             num_workers=num_workers)

    # build trainer / evaluator with ignite
    trainer = create_supervised_trainer(model,
                                        model_optimizer,
                                        loss,
                                        device=device)
    measure = MultiThresholdMeasures()
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                '': measure,
                                                'pix-acc': Accuracy(measure),
                                                'iou': IoU(measure),
                                                'loss': Loss(loss),
                                                'f1': F1score(measure),
                                            },
                                            device=device)

    # initialize state variable for checkpoint
    state = update_state(model.state_dict(), 0, 0, 0, 0, 0)

    # make ckpt path
    ckpt_root = './ckpt/'
    filename = '{network}_{optimizer}_lr_{lr}_epoch_{epoch}.pth'
    ckpt_path = os.path.join(ckpt_root, filename)

    # execution after every training iteration
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(trainer):
        num_iter = (trainer.state.iteration - 1) % len(train_loader) + 1
        if num_iter % 20 == 0:
            logger.info("Epoch[{}] Iter[{:03d}] Loss: {:.2f}".format(
                trainer.state.epoch, num_iter, trainer.state.output))

    # execution after every training epoch
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        # evaluate on training set
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        logger.info(
            "Training Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n"
            .format(trainer.state.epoch, metrics['loss'],
                    str(metrics['pix-acc']), str(metrics['iou']),
                    str(metrics['f1'])))

        # update state
        update_state(weight=model.state_dict(),
                     train_loss=metrics['loss'],
                     val_loss=state['val_loss'],
                     val_pix_acc=state['val_pix_acc'],
                     val_iou=state['val_iou'],
                     val_f1=state['val_f1'])

    # execution after every epoch
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        # evaluate test(validation) set
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        logger.info(
            "Validation Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n"
            .format(trainer.state.epoch, metrics['loss'],
                    str(metrics['pix-acc']), str(metrics['iou']),
                    str(metrics['f1'])))

        # update scheduler
        lr_scheduler.step(metrics['loss'])

        # update and save state
        update_state(weight=model.state_dict(),
                     train_loss=state['train_loss'],
                     val_loss=metrics['loss'],
                     val_pix_acc=metrics['pix-acc'],
                     val_iou=metrics['iou'],
                     val_f1=metrics['f1'])

        path = ckpt_path.format(network=networks,
                                optimizer=optimizer,
                                lr=lr,
                                epoch=trainer.state.epoch)
        save_ckpt_file(path, state)

    trainer.run(train_loader, max_epochs=epochs)
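The checkpoint path template is filled in once per epoch by log_validation_results. In isolation (with assumed values):

import os

ckpt_root = './ckpt/'
filename = '{network}_{optimizer}_lr_{lr}_epoch_{epoch}.pth'
path = os.path.join(ckpt_root, filename).format(
    network='unet', optimizer='adam', lr=1e-4, epoch=3)
print(path)  # ./ckpt/unet_adam_lr_0.0001_epoch_3.pth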
Example #6
def train_without_ignite(model,
                         loss,
                         batch_size,
                         img_size,
                         epochs,
                         lr,
                         num_workers,
                         optimizer,
                         logger,
                         gray_image=False,
                         scheduler=None,
                         viz=True):
    import visdom
    from utils.metrics import Accuracy, IoU

    DEFAULT_PORT = 8097
    DEFAULT_HOSTNAME = "http://localhost"

    if viz:
        vis = visdom.Visdom(port=DEFAULT_PORT, server=DEFAULT_HOSTNAME)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    data_loader = {}

    joint_transforms = jnt_trnsf.Compose([
        jnt_trnsf.RandomCrop(img_size),
        jnt_trnsf.RandomRotate(5),
        jnt_trnsf.RandomHorizontallyFlip()
    ])

    train_image_transforms = std_trnsf.Compose([
        std_trnsf.ColorJitter(0.05, 0.05, 0.05, 0.05),
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    test_joint_transforms = jnt_trnsf.Compose([jnt_trnsf.Safe32Padding()])

    test_image_transforms = std_trnsf.Compose([
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    mask_transforms = std_trnsf.Compose([std_trnsf.ToTensor()])

    data_loader['train'] = get_loader(dataset='figaro',
                                      train=True,
                                      joint_transforms=joint_transforms,
                                      image_transforms=train_image_transforms,
                                      mask_transforms=mask_transforms,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      gray_image=gray_image)

    data_loader['test'] = get_loader(dataset='figaro',
                                     train=False,
                                     joint_transforms=test_joint_transforms,
                                     image_transforms=test_image_transforms,
                                     mask_transforms=mask_transforms,
                                     batch_size=1,
                                     shuffle=True,
                                     num_workers=num_workers,
                                     gray_image=gray_image)

    for epoch in range(epochs):
        for phase in ['train', 'test']:
            if phase == 'train':
                model.train(True)
            else:
                prev_grad_state = torch.is_grad_enabled()
                torch.set_grad_enabled(False)
                model.train(False)

            running_loss = 0.0

            for i, data in enumerate(tqdm(data_loader[phase],
                                          file=sys.stdout)):
                if i == len(data_loader[phase]) - 1: break
                data_ = [
                    t.to(device) if isinstance(t, torch.Tensor) else t
                    for t in data
                ]

                if gray_image:
                    img, mask, gray = data_
                else:
                    img, mask = data_

                model.zero_grad()

                pred_mask = model(img)

                if gray_image:
                    l = loss(pred_mask, mask, gray)
                else:
                    l = loss(pred_mask, mask)

                if phase == 'train':
                    l.backward()
                    optimizer.step()

                running_loss += l.item()

            epoch_loss = running_loss / len(data_loader[phase])

            if phase == 'train':
                logger.info(
                    f"Training Results - Epoch: {epoch} Avg-loss: {epoch_loss:.3f}"
                )
                if viz:
                    vis.images(
                        [
                            np.clip(pred_mask.detach().cpu().numpy()[0], 0, 1),
                            mask.detach().cpu().numpy()[0]
                        ],
                        opts=dict(title=f'pred img for {epoch}-th iter'))

            if phase == 'test':
                if viz:
                    vis.images(
                        [
                            np.clip(pred_mask.detach().cpu().numpy()[0], 0, 1),
                            mask.detach().cpu().numpy()[0]
                        ],
                        opts=dict(title=f'pred img for {epoch}-th iter'))
                logger.info(
                    f"Test Results - Epoch: {epoch} Avg-loss: {epoch_loss:.3f}"
                )

                if scheduler: scheduler.step(epoch_loss)

                torch.set_grad_enabled(prev_grad_state)
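The manual save-and-restore of the grad state above can also be written with torch.set_grad_enabled as a context manager; a minimal sketch of the same train/test skeleton (model, loss, optimizer and data_loader as in the example):

import torch

for phase in ['train', 'test']:
    model.train(phase == 'train')
    with torch.set_grad_enabled(phase == 'train'):
        for img, mask in data_loader[phase]:
            pred = model(img)
            l = loss(pred, mask)
            if phase == 'train':
                optimizer.zero_grad()
                l.backward()
                optimizer.step()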
Example #7
def train_with_correspondences(save_folder, startnet, args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)

    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')

        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        args['best_record'] = eval(best_val_dict_str.rstrip())

    net.train()
    freeze_bn(net)

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    target_transform = extended_transforms.MaskToTensor()

    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713)
    ])

    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set,
                             batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'],
                             shuffle=True)

    # Cityscapes Training
    c_config = data_configs.CityscapesConfig()
    seg_set_cs = cityscapes.CityScapes(
        c_config.train_im_folder,
        c_config.train_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    seg_loader_cs = DataLoader(seg_set_cs,
                               batch_size=args['train_batch_size'],
                               num_workers=args['n_workers'],
                               shuffle=True)
    seg_loaders.append(seg_loader_cs)

    # Cityscapes Validation
    val_set_cs = cityscapes.CityScapes(
        c_config.val_im_folder,
        c_config.val_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    val_loader_cs = DataLoader(val_set_cs,
                               batch_size=1,
                               num_workers=args['n_workers'],
                               shuffle=False)
    validator_cs = Validator(val_loader_cs,
                             n_classes=c_config.n_classes,
                             save_snapshot=False,
                             extra_name_str='Cityscapes')
    validators.append(validator_cs)

    # Vistas Training and Validation
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)

        seg_set_vis = cityscapes.CityScapes(
            v_config.train_im_folder,
            v_config.train_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        seg_loader_vis = DataLoader(seg_set_vis,
                                    batch_size=args['train_batch_size'],
                                    num_workers=args['n_workers'],
                                    shuffle=True)
        seg_loaders.append(seg_loader_vis)

        val_set_vis = cityscapes.CityScapes(
            v_config.val_im_folder,
            v_config.val_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader_vis = DataLoader(val_set_vis,
                                    batch_size=1,
                                    num_workers=args['n_workers'],
                                    shuffle=False)
        validator_vis = Validator(val_loader_vis,
                                  n_classes=v_config.n_classes,
                                  save_snapshot=False,
                                  extra_name_str='Vistas')
        validators.append(validator_vis)
    else:
        seg_loader_vis = None
        map_validator = None

    # Extra Training
    extra_seg_set = cityscapes.CityScapes(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    extra_seg_loader = DataLoader(extra_seg_set,
                                  batch_size=args['train_batch_size'],
                                  num_workers=args['n_workers'],
                                  shuffle=True)
    seg_loaders.append(extra_seg_loader)

    # Extra Validation
    extra_val_set = cityscapes.CityScapes(
        corr_set_config.val_im_folder,
        corr_set_config.val_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    extra_val_loader = DataLoader(extra_val_set,
                                  batch_size=1,
                                  num_workers=args['n_workers'],
                                  shuffle=False)
    extra_validator = Validator(extra_val_loader,
                                n_classes=corr_set_config.n_classes,
                                save_snapshot=True,
                                extra_name_str='Extra')
    validators.append(extra_validator)

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)

    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='mean',  # 'elementwise_mean' in the older PyTorch this was written for
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    open(os.path.join(save_folder,
                      str(datetime.datetime.now()) + '.txt'),
         'w').write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    ##########################################################################
    #
    #       MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE LOSS
    #
    ##########################################################################
    softm = torch.nn.Softmax2d()

    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()

    seg_loss_meters = list()
    seg_loss_meters.append(train_seg_cs_loss)
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)

    curr_iter = start_iter

    for i in range(args['max_iter']):
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        #######################################################################
        #       SEGMENTATION UPDATE STEP
        #######################################################################
        #
        for si, seg_loader in enumerate(seg_loaders):
            # get segmentation training sample
            inputs, gts = next(iter(seg_loader))

            slice_batch_pixel_size = inputs.size(0) * inputs.size(
                2) * inputs.size(3)

            inputs = inputs.to(device)
            gts = gts.to(device)

            optimizer.zero_grad()
            outputs, aux = net(inputs)

            main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
            aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
            loss = main_loss + 0.4 * aux_loss

            loss.backward()
            optimizer.step()

            seg_loss_meters[si].update(main_loss.item(),
                                       slice_batch_pixel_size)

        #######################################################################
        #       CORRESPONDENCE UPDATE STEP
        #######################################################################
        if args['corr_loss_weight'] > 0 and args[
                'n_iterations_before_corr_loss'] < curr_iter:
            img_ref, img_other, pts_ref, pts_other, weights = next(
                iter(corr_loader))

            # Transfer data to device
            # img_ref is from the "good" sequence with generally better
            # segmentation results
            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Forward pass
            if args['corr_loss_type'] == 'hingeF':  # Works on features
                net.output_all = True
                with torch.no_grad():
                    output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(
                        img_ref)
                output_feat_other, aux_feat_other, output_other, aux_other = net(
                    img_other
                )  # output1 must be last to backpropagate derivative correctly
                net.output_all = False

            else:  # Works on class probs
                with torch.no_grad():
                    output_ref, aux_ref = net(img_ref)
                    if args['corr_loss_type'] != 'hingeF' and args[
                            'corr_loss_type'] != 'hingeC':
                        output_ref = softm(output_ref)
                        aux_ref = softm(aux_ref)

                # output1 must be last to backpropagate derivative correctly
                output_other, aux_other = net(img_other)
                if args['corr_loss_type'] != 'hingeF' and args[
                        'corr_loss_type'] != 'hingeC':
                    output_other = softm(output_other)
                    aux_other = softm(aux_other)

            # Correspondence filtering
            pts_ref_orig, pts_other_orig, weights_orig, batch_inds_to_keep_orig = correspondences.refine_correspondence_sample(
                output_ref,
                output_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_ref_orig)
                if b.item() > 0
            ]
            pts_other_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_other_orig)
                if b.item() > 0
            ]
            weights_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, weights_orig)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                output_vals_other = output_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                output_vals_ref = output_ref[batch_inds_to_keep_orig]
                output_vals_other = output_other[batch_inds_to_keep_orig]

            pts_ref_aux, pts_other_aux, weights_aux, batch_inds_to_keep_aux = correspondences.refine_correspondence_sample(
                aux_ref,
                aux_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_ref_aux)
                if b.item() > 0
            ]
            pts_other_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_other_aux)
                if b.item() > 0
            ]
            weights_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, weights_aux)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                aux_vals_ref = aux_feat_ref[batch_inds_to_keep_orig]
                aux_vals_other = aux_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_other[batch_inds_to_keep_aux]

            optimizer.zero_grad()

            # correspondence loss
            if output_vals_ref.size(0) > 0:
                loss_corr_hr = corr_loss_fct(output_vals_ref,
                                             output_vals_other, pts_ref_orig,
                                             pts_other_orig, weights_orig)
            else:
                loss_corr_hr = 0 * output_vals_other.sum()

            if aux_vals_ref.size(0) > 0:
                loss_corr_aux = corr_loss_fct(
                    aux_vals_ref, aux_vals_other, pts_ref_aux, pts_other_aux,
                    weights_aux)  # use output from img1 as "reference"
            else:
                loss_corr_aux = 0 * aux_vals_other.sum()

            loss_corr = args['corr_loss_weight'] * \
                (loss_corr_hr + 0.4 * loss_corr_aux)
            loss_corr.backward()

            optimizer.step()
            train_corr_loss.update(loss_corr.item())

        #######################################################################
        #       LOGGING ETC
        #######################################################################
        curr_iter += 1
        val_iter += 1

        writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg,
                          curr_iter)
        writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
        writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

        if (i + 1) % args['print_freq'] == 0:
            str2write = '[iter %d / %d], [train corr loss %.5f] , [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f]. [lr %.10f]' % (
                curr_iter, len(corr_loader), train_corr_loss.avg,
                train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
            print(str2write)
            f_handle.write(str2write + "\n")

        if val_iter >= args['val_interval']:
            val_iter = 0
            for validator in validators:
                validator.run(net,
                              optimizer,
                              args,
                              curr_iter,
                              save_folder,
                              f_handle,
                              writer=writer)

    # Post training
    f_handle.close()
    writer.close()
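The head of the training loop applies polynomial learning-rate decay, lr_t = lr_0 * (1 - t / max_iter) ** lr_decay, with biases again at twice the base rate. The schedule on its own (illustrative values):

def poly_lr(base_lr, curr_iter, max_iter, power):
    """Polynomial decay as applied at the top of the loop above."""
    return base_lr * (1 - curr_iter / max_iter) ** power

print(poly_lr(1e-2, 0, 1000, 0.9))    # 0.01 at the start
print(poly_lr(1e-2, 500, 1000, 0.9))  # ~0.0054 halfway through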
Example #8
def main(train_args):
    net = PSPNet(num_classes=voc.num_classes).cuda()

    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(10),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.RandomGaussianBlur(),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train', joint_transform=train_simul_transform, transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=train_args['train_batch_size'], num_workers=8, shuffle=True)
    val_set = voc.VOC('val', joint_transform=val_simul_transform, transform=val_input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=True, ignore_index=voc.ignore_label).cuda()

    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=train_args['momentum'])

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(train_args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, train_args, val_loader, restore_transform, visualize)
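restore_transform undoes the earlier Normalize so network inputs can be rendered back as images. A generic sketch of such a DeNormalize (not necessarily the repo's extended_transforms implementation):

class DeNormalizeSketch:
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        # invert (x - mean) / std channel by channel, in place
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
        return tensor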
Example #9
def main():
    # args = parse_args()

    torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    # # if args.seed:
    # random.seed(args.seed)
    # np.random.seed(args.seed)
    # torch.manual_seed(args.seed)
    # # if args.gpu:
    # torch.cuda.manual_seed_all(args.seed)
    seed = 63
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # if args.gpu:
    torch.cuda.manual_seed_all(seed)

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # train_transforms = transforms.Compose([
    # 	transforms.RandomCrop(args['crop_size']),
    # 	transforms.RandomRotation(90),
    # 	transforms.RandomHorizontalFlip(p=0.5),
    # 	transforms.RandomVerticalFlip(p=0.5),

    # 	])
    short_size = int(min(args['input_size']) / 0.875)
    # val_transforms = transforms.Compose([
    # 	transforms.Scale(short_size, interpolation=Image.NEAREST),
    # 	# joint_transforms.Scale(short_size),
    # 	transforms.CenterCrop(args['input_size'])
    # 	])
    train_joint_transform = joint_transforms.Compose([
        # joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['crop_size']),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomRotate(90)
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(*mean_std)])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = transforms.Compose(
        [extended_transforms.DeNormalize(*mean_std),
         transforms.ToPILImage()])
    visualize = transforms.ToTensor()

    train_set = cityscapes.CityScapes('train',
                                      joint_transform=train_joint_transform,
                                      transform=input_transform,
                                      target_transform=target_transform)
    # train_set = cityscapes.CityScapes('train', transform=train_transforms)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True)
    val_set = cityscapes.CityScapes('val',
                                    joint_transform=val_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    # val_set = cityscapes.CityScapes('val', transform=val_transforms)
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            num_workers=8,
                            shuffle=True)

    print(len(train_loader), len(val_loader))

    vgg_model = VGGNet(requires_grad=True, remove_fc=True)
    net = FCN8s(pretrained_net=vgg_model,
                n_class=cityscapes.num_classes,
                dropout_rate=0.4)
    # net.apply(init_weights)
    criterion = nn.CrossEntropyLoss(ignore_index=cityscapes.ignore_label)

    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(args) + '\n\n')

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10)

    vgg_model = vgg_model.to(device)
    net = net.to(device)

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9][:-4])
        }

    criterion.to(device)

    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, device, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, device, criterion, optimizer,
                            epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
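The seeding block near the top pins the Python, NumPy and PyTorch RNGs; the same lines as a reusable helper:

import random
import numpy as np
import torch

def seed_everything(seed):
    """Seed every RNG the example above touches."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)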
Example #10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_filters = [16, 32]
batch_size = 32
lr = 1e-4
weight_decay = 1e-5
epochs = 600

# unet setting
num_conv_blocks = 1
# prob unet only
num_convs_per_block = num_conv_blocks
num_convs_fcomb = 1
partial_data = False
latent_dim = 6
beta = 10
# isotropic = False

# kaiming_normal and orthogonal
initializers = {'w': 'kaiming_normal', 'b': 'normal'}
# initializers = {'w':None, 'b':None}

# Transforms
joint_transfm = joint_transforms.Compose([joint_transforms.RandomHorizontallyFlip(),
                                          joint_transforms.RandomSizedCrop(128),
                                          joint_transforms.RandomRotate(60)])
# joint_transfm = None
input_transfm = transforms.Compose([transforms.ToPILImage()])
target_transfm = transforms.Compose([transforms.ToTensor()])
# joint_transfm=None
# input_transfm=None
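latent_dim and beta are probabilistic-U-Net hyper-parameters; beta is commonly read as the weight on the KL term of the objective (that reading of this config is an assumption, not something the snippet states):

def prob_unet_objective(reconstruction_loss, kl_divergence, beta=10):
    """ELBO-style loss: reconstruction plus beta-weighted KL term (assumed form)."""
    return reconstruction_loss + beta * kl_divergence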
Example #11
# define data loader ===========================================================
mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
normal_mean_std = ([0.5, 0.5, 0.5], [0.25, 0.25, 0.25])
short_size = int(min(inp_size) / 0.875)

if stop_random_rotation:
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(inp_size)
    ])
else:
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomRotate(rot_angle),
        joint_transforms.RandomCrop(inp_size)
    ])

val_joint_transform = joint_transforms.Compose(
    [joint_transforms.FreeScale(inp_size)])

if stop_jitter:
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ColorJitter(brightness=0.5,
                                        contrast=0.5,
                                        saturation=0.5,
                                        hue=0.5),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
Example #12
def main():
    net = PSPNet(19)
    net.load_pretrained_model(
        model_path='./Caffe-PSPNet/pspnet101_cityscapes.caffemodel')
    for param in net.parameters():
        param.requires_grad = False
    net.cbr_final = conv2DBatchNormRelu(4096, 128, 3, 1, 1, False)
    net.dropout = nn.Dropout2d(p=0.1, inplace=True)
    net.classification = nn.Conv2d(128, kitti_binary.num_classes, 1, 1, 0)

    # Find total parameters and trainable parameters
    total_params = sum(p.numel() for p in net.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(p.numel() for p in net.parameters()
                                 if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')

    if len(args['snapshot']) == 0:
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth')))
        args['best_record'] = {
            'epoch': 0,
            'iter': 0,
            'val_loss': 1e10,
            'accu': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'iter': int(split_snapshot[3]),
            'val_loss': float(split_snapshot[5]),
            'accu': float(split_snapshot[7])
        }
    net.cuda(args['gpu']).train()

    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    train_set = kitti_binary.KITTI(mode='train',
                                   joint_transform=train_joint_transform,
                                   transform=train_input_transform,
                                   target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True)
    val_set = kitti_binary.KITTI(mode='val',
                                 transform=val_input_transform,
                                 target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=8,
                            shuffle=False)

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.full([1], 1.05)).cuda(
        args['gpu'])

    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, args, val_loader)
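The input pipeline here mirrors Caffe-style preprocessing for the ported PSPNet weights: channels flipped to BGR, intensities rescaled to 0..255, and the BGR channel means subtracted (std of 1.0). The equivalent operation on a single tensor, assuming FlipChannels swaps RGB to BGR:

import torch

def caffe_preprocess(img_rgb_01):
    """img_rgb_01: (3, H, W) float tensor in [0, 1], RGB channel order."""
    img = img_rgb_01[[2, 1, 0], :, :] * 255.0  # RGB -> BGR, rescale to 0..255
    mean = torch.tensor([103.939, 116.779, 123.68]).view(3, 1, 1)
    return img - mean  # subtract per-channel BGR means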