Esempio n. 1
0
def main():
    """Train a 3-class U-Net (single-channel input) on the bladder dataset.

    Relies on module-level globals: ``device``, ``crop_size``, ``data_path``,
    ``batch_size``, ``loss_name``, ``n_epoch`` and the transform/dataset
    modules — none of which are defined in this block.
    """
    net = U_Net(img_ch=1, num_classes=3).to(device)

    # Joint (image + mask) augmentation applied during training.
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(384),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    center_crop = joint_transforms.CenterCrop(crop_size)
    train_input_transform = extended_transforms.ImgToTensor()

    target_transform = extended_transforms.MaskToTensor()
    make_dataset_fn = bladder.make_dataset_v2
    train_set = bladder.Bladder(data_path,
                                'train',
                                joint_transform=train_joint_transform,
                                center_crop=center_crop,
                                transform=train_input_transform,
                                target_transform=target_transform,
                                make_dataset_fn=make_dataset_fn)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    if loss_name == 'dice_':
        criterion = SoftDiceLossV2(activation='sigmoid',
                                   num_classes=3).to(device)
    elif loss_name == 'bcew_':
        criterion = nn.BCEWithLogitsLoss().to(device)
    else:
        # Fail fast: the original fell through silently and only raised a
        # NameError on `criterion` once train() was reached.
        raise ValueError('unsupported loss_name: ' + repr(loss_name))

    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    train(train_loader, net, criterion, optimizer, n_epoch, 0)
Esempio n. 2
0
def main():
    """Train PSPNet on VOC with sliding-crop data, optionally resuming.

    Relies on module-level globals: ``args``, ``ckpt_path``, ``exp_name``,
    the ``voc`` dataset module, the transform modules, ``check_mkdir`` and
    ``train``.  Side effects: creates checkpoint directories and writes the
    hyper-parameter dict to a timestamped .txt log file.
    """
    net = PSPNet(num_classes=voc.num_classes).cuda()

    if len(args['snapshot']) == 0:
        # Fresh run: start at epoch 1 with worst-possible best-record values.
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # Snapshot filenames embed metrics at fixed underscore-separated
        # positions (epoch at 1, val_loss at 3, ...) — assumed from this
        # parsing; confirm against the snapshot-writing code.
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net.train()

    # ImageNet channel mean/std for the pretrained backbone.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'],
                                                args['stride_rate'],
                                                voc.ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(args['val_img_display_size']),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train',
                        joint_transform=train_joint_transform,
                        sliding_crop=sliding_crop,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=1,
                              shuffle=True,
                              drop_last=True)
    val_set = voc.VOC('val',
                      transform=val_input_transform,
                      sliding_crop=sliding_crop,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            drop_last=True)

    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=voc.ignore_label).cuda()

    # Common segmentation convention: biases get twice the base learning
    # rate and no weight decay.
    bias_params = [
        param for name, param in net.named_parameters() if name[-4:] == 'bias'
    ]
    weight_params = [
        param for name, param in net.named_parameters() if name[-4:] != 'bias'
    ]
    optimizer = optim.SGD([{
        'params': bias_params,
        'lr': 2 * args['lr']
    }, {
        'params': weight_params,
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        # Re-apply the configured learning rates; the loaded optimizer state
        # may carry stale values.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Log the hyper-parameters; use a context manager so the handle is closed
    # (the original open(...).write(...) leaked the file object).
    log_name = str(datetime.datetime.now()) + '.txt'
    with open(os.path.join(ckpt_path, exp_name, log_name), 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, args,
          val_loader, visualize)
Esempio n. 3
0
def main():
    """Train FCN8s (Caffe VGG init) on fine-annotated Cityscapes with an
    LR-on-plateau schedule, optionally resuming from a snapshot.

    Relies on module-level globals: ``args``, ``ckpt_path``, ``exp_name``,
    the ``cityscapes`` dataset module, the transform modules, ``train`` and
    ``validate``.
    """
    net = FCN8s(num_classes=cityscapes.num_classes, caffe=True).cuda()

    if len(args['snapshot']) == 0:
        # Fresh run: worst-possible best-record metrics.
        curr_epoch = 1
        args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # Metrics are encoded at fixed positions of the underscore-split
        # snapshot filename — assumed from this parsing; confirm against the
        # snapshot-writing code.
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                               'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                               'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()

    # Caffe-style preprocessing: BGR channel order, 0-255 value range, mean
    # subtraction only (std of 1.0) — matches FlipChannels/mul_(255) below.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    # Scale so that a crop of args['input_size'] covers ~87.5% of the
    # shorter side (standard crop-ratio heuristic).
    short_size = int(min(args['input_size']) / 0.875)
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip()
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Inverse of input_transform, for visualising predictions.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels()
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('fine', 'train', joint_transform=train_joint_transform,
                                      transform=input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('fine', 'val', joint_transform=val_joint_transform, transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False, ignore_index=cityscapes.ignore_label).cuda()

    # Biases: double learning rate, no weight decay.
    optimizer = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], betas=(args['momentum'], 0.999))

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        # Re-apply the configured learning rates after loading state.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Log the hyper-parameters; 'with' closes the handle instead of leaking
    # it like the original open(...).write(...).
    with open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
Esempio n. 4
0
def testing(net, train_args, curr_iter):
    """Run inference on the DOC 'test' split, save colorized predictions
    (for curr_iter > 1), rsync the experiment directory to a remote host,
    then evaluate on the 'test_eva' split and append
    ``[curr_iter, per-class IoUs..., mean_iu]`` to the module-level
    ``sep_iou_test`` list.

    NOTE(review): the rsync destination (user/host/path) is hard-coded
    below — consider moving it to configuration.
    """
    net.eval()

    # Dataset-specific channel mean/std (mostly-white document images).
    mean_std = ([0.9584, 0.9588, 0.9586], [0.1246, 0.1223, 0.1224])

    test_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    test_simul_transform = simul_transforms.Scale(train_args['input_size'])

    # Segmentation on test images (no ground truth needed here).
    test_set = doc.DOC('test',
                       ckpt_path,
                       joint_transform=test_simul_transform,
                       transform=test_input_transform)
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             num_workers=1,
                             shuffle=False)

    check_mkdir(os.path.join(ckpt_path, exp_name, str(curr_iter)))
    for vi, data in enumerate(test_loader):
        img_name, img = data
        img_name = img_name[0]
        with torch.no_grad():
            img = img.cuda()
            output = net(img)

            # argmax over the class dimension -> HxW label map on the CPU.
            prediction = output.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
            prediction = doc.colorize_mask_combine(
                prediction, ckpt_path + 'data/img/' + img_name)
            # Only persist predictions after the first iteration.
            if curr_iter > 1:
                prediction.save(
                    os.path.join(ckpt_path, exp_name, str(curr_iter),
                                 img_name))
    # Push the whole experiment directory to the remote results server.
    call([
        "rsync", "-avz",
        os.path.join(ckpt_path, exp_name),
        "[email protected]:/home/jobinkv/Documents/r1/19wavc/"
    ])

    # Quantitative evaluation on the labelled 'test_eva' split.
    test_set_eva = doc.DOC('test_eva',
                           ckpt_path,
                           joint_transform=test_simul_transform,
                           transform=test_input_transform,
                           target_transform=target_transform)
    test_loader_eva = DataLoader(test_set_eva,
                                 batch_size=1,
                                 num_workers=1,
                                 shuffle=False)

    gts_all, predictions_all = [], []

    for vi, data in enumerate(test_loader_eva):
        img, gts = data
        with torch.no_grad():
            img = img.cuda()
            gts = gts.cuda()

            output = net(img)
            prediction = output.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()

            gts_all.append(gts.squeeze_(0).cpu().numpy())
            predictions_all.append(prediction)

    acc, acc_cls, mean_iu, fwavacc, sep_iu = evaluate(predictions_all, gts_all,
                                                      doc.num_classes)
    del predictions_all
    print(
        '--------------------------------------------------------------------')
    print(
        '[test acc %.5f], [test acc_cls %.5f], [test mean_iu %.5f], [test fwavacc %.5f]'
        % (acc, acc_cls, mean_iu, fwavacc))
    print(
        '--------------------------------------------------------------------')
    # Record [curr_iter, per-class IoUs..., mean_iu] for this checkpoint.
    sep_iu = sep_iu.tolist()
    sep_iu.append(mean_iu)
    sep_iu.insert(0, curr_iter)
    sep_iou_test.append(sep_iu)
Esempio n. 5
0
def main(train_args):
    """Build the selected network, optionally resume from a snapshot, wrap
    it in DataParallel and train it on the DOC dataset.

    Relies on module-level globals: ``ckpt_path``, ``exp_name``, ``Dataroot``,
    ``resnet``, ``res_path``, the ``doc`` dataset module, the transform
    modules, ``check_mkdir`` and ``train``.
    """
    print('No of classes', doc.num_classes)
    if train_args['network'] == 'psp':
        net = PSPNet(num_classes=doc.num_classes,
                     resnet=resnet,
                     res_path=res_path).cuda()
    elif train_args['network'] == 'mfcn':
        net = MFCN(num_classes=doc.num_classes, use_aux=True).cuda()
    elif train_args['network'] == 'psppen':
        # Same construction as 'psp'; kept separate in case the penalised
        # variant diverges later.
        net = PSPNet(num_classes=doc.num_classes,
                     resnet=resnet,
                     res_path=res_path).cuda()
    else:
        # Fail fast: the original fell through and raised a NameError on
        # `net` at the DataParallel wrap below.
        raise ValueError('unknown network: ' + repr(train_args['network']))
    print("number of cuda devices = ", torch.cuda.device_count())
    if len(train_args['snapshot']) == 0:
        # Fresh run: worst-possible best-record metrics.
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, 'model_' + exp_name,
                             train_args['snapshot'])))
        # Metrics at fixed positions in the underscore-split filename —
        # assumed from this parsing; confirm against the snapshot writer.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    # Wrap AFTER loading weights: the snapshot keys carry no 'module.'
    # prefix, which DataParallel would add.
    net = torch.nn.DataParallel(net,
                                device_ids=list(
                                    range(torch.cuda.device_count())))
    net.train()
    # Dataset-specific channel mean/std (mostly-white document images).
    mean_std = ([0.9584, 0.9588, 0.9586], [0.1246, 0.1223, 0.1224])
    # NOTE: a per-class `weight` tensor (torch.FloatTensor(doc.num_classes),
    # i.e. uninitialized memory) was removed here — it was only referenced
    # by the commented-out weighted criterion below.  Build it explicitly
    # (e.g. torch.ones) if re-enabling class weighting.
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(3),
        #simul_transforms.RandomHorizontallyFlip()
        simul_transforms.Scale(train_args['input_size']),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.RandomGaussianBlur(),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    train_set = doc.DOC('train',
                        Dataroot,
                        joint_transform=train_simul_transform,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['train_batch_size'],
                              num_workers=1,
                              shuffle=True,
                              drop_last=True)
    # Throwaway batch_size=1 loader used only to count training images.
    train_loader_temp = DataLoader(train_set,
                                   batch_size=1,
                                   num_workers=1,
                                   shuffle=True,
                                   drop_last=True)
    train_args['No_train_images'] = len(train_loader_temp)
    del train_loader_temp
    # Tag the run from the known per-dataset image counts.
    if train_args['No_train_images'] == 87677:
        train_args['Type_of_train_image'] = 'All Synthetic slide image'
    elif train_args['No_train_images'] == 150:
        train_args['Type_of_train_image'] = 'Real image'
    elif train_args['No_train_images'] == 84641:
        train_args['Type_of_train_image'] = 'Synthetic image'
    elif train_args['No_train_images'] == 151641:
        train_args['Type_of_train_image'] = 'Real + Synthetic image'
    val_set = doc.DOC('val',
                      Dataroot,
                      joint_transform=val_simul_transform,
                      transform=val_input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            drop_last=True)
    #criterion = CrossEntropyLoss2d(weight = weight, size_average = True, ignore_index = doc.ignore_label).cuda()
    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=doc.ignore_label).cuda()
    # Biases: double learning rate, no weight decay.  (name[-4:] still works
    # through the DataParallel 'module.' prefix.)
    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * train_args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        train_args['lr'],
        'weight_decay':
        train_args['weight_decay']
    }],
                          momentum=train_args['momentum'])
    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, 'model_' + exp_name,
                             'opt_' + train_args['snapshot'])))
        # Re-apply the configured learning rates after loading state.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']
    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Log the hyper-parameters; 'with' closes the handle instead of leaking
    # it like the original open(...).write(...).
    with open(
            os.path.join(ckpt_path, exp_name,
                         str(datetime.datetime.now()) + '.txt'),
            'w') as log_file:
        log_file.write(str(train_args) + '\n\n')
    train(train_loader, net, criterion, optimizer, curr_epoch, train_args,
          val_loader)
Esempio n. 6
0
def get_transforms(scale_size, input_size, region_size, supervised, test,
                   al_algorithm, full_res, dataset):
    """Build the joint (image+mask), input and target transforms.

    Returns a 5-tuple ``(input_transform, target_transform,
    train_joint_transform, val_joint_transform, al_train_joint_transform)``.
    ``scale_size == 0`` disables the leading Scale step everywhere.
    """
    # ImageNet channel statistics for Normalize.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    scaling = scale_size != 0

    def leading_scale():
        # Fresh list (and fresh Scale instance) for each composition.
        return [joint_transforms.Scale(scale_size)] if scaling else []

    if scaling:
        print('(Data loading) Scaling training data: ' + str(scale_size) +
              ' width dimension')
    else:
        print('(Data loading) Not scaling the data')
    print('(Data loading) Random crops of ' + str(input_size) +
          ' in training')
    if scaling:
        print('(Data loading) No crops nor scale_size in validation')
    else:
        print('(Data loading) No crops in validation')

    # Training: random crop + flip, region-aware when not supervised.
    if supervised:
        train_joint_transform = joint_transforms.Compose(
            leading_scale() + [
                joint_transforms.RandomCrop(input_size),
                joint_transforms.RandomHorizontallyFlip()
            ])
    else:
        train_joint_transform = joint_transforms.ComposeRegion(
            leading_scale() + [
                joint_transforms.RandomCropRegion(input_size,
                                                  region_size=region_size),
                joint_transforms.RandomHorizontallyFlip()
            ])

    # Active-learning training: deterministic region crop + flip.
    al_train_joint_transform = joint_transforms.ComposeRegion(
        leading_scale() + [
            joint_transforms.CropRegion(region_size, region_size=region_size),
            joint_transforms.RandomHorizontallyFlip()
        ])

    # Validation: mostly untouched images, with two special cases.
    if scaling:
        if dataset == 'gta_for_camvid':
            val_joint_transform = joint_transforms.ComposeRegion(
                [joint_transforms.Scale(scale_size)])
        else:
            val_joint_transform = None
    else:
        if (not test and al_algorithm == 'ralis') and not full_res:
            val_joint_transform = joint_transforms.Scale(1024)
        else:
            val_joint_transform = None

    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()

    return input_transform, target_transform, train_joint_transform, val_joint_transform, al_train_joint_transform
Esempio n. 7
0
    'val_img_sample_rate':
    0.05  # randomly sample some validation results to display
}

# Paths to trained models & epoch counts

DUCHDC_trainedModelPath = './ducModelFinal.pth'  # DUC-HDC checkpoint
FCN8_trainedModelPath = './fcnModelFinal.pth'  # FCN-8s checkpoint
Unet_trainedModelPath = './unetModelFinal.pth'  # U-Net checkpoint

# Transforms
# ImageNet channel mean/std used by Normalize/DeNormalize below.
mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# Scale so a crop of args['input_size'] covers ~87.5% of the shorter side.
# NOTE(review): `args` is defined earlier in this file (not fully visible
# here) — assumes it carries an 'input_size' entry.
short_size = int(min(args['input_size']) / 0.875)

# Joint image+mask augmentation applied during training.
joint_transform = joint_transforms.Compose([
    joint_transforms.Scale(short_size),
    joint_transforms.RandomCrop(args['input_size']),
    joint_transforms.RandomHorizontallyFlip()
])

# Image-only pipeline: PIL image -> normalized float tensor.
input_transform = standard_transforms.Compose(
    [standard_transforms.ToTensor(),
     standard_transforms.Normalize(*mean_std)])

# Mask-only pipeline: PIL mask -> long tensor of class ids.
target_transform = extended_transforms.MaskToTensor()

# Inverse of input_transform, for visualising predictions.
restore_transform = standard_transforms.Compose([
    extended_transforms.DeNormalize(*mean_std),
    standard_transforms.ToPILImage()
])
def main(train_args):
    """Train PSPNet on VOC with random-sized crops, optionally resuming.

    Relies on module-level globals: ``ckpt_path``, ``exp_name``, the ``voc``
    dataset module, the transform modules, ``check_mkdir`` and ``train``.
    """
    net = PSPNet(num_classes=voc.num_classes).cuda()

    if len(train_args['snapshot']) == 0:
        # Fresh run: worst-possible best-record metrics.
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # Metrics at fixed positions in the underscore-split filename —
        # assumed from this parsing; confirm against the snapshot writer.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()

    # ImageNet channel mean/std for the pretrained backbone.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(10),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.RandomGaussianBlur(),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Inverse of the input normalization, for visualising predictions.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train', joint_transform=train_simul_transform, transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=train_args['train_batch_size'], num_workers=8, shuffle=True)
    val_set = voc.VOC('val', joint_transform=val_simul_transform, transform=val_input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=True, ignore_index=voc.ignore_label).cuda()

    # Biases: double learning rate, no weight decay.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=train_args['momentum'])

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        # Re-apply the configured learning rates after loading state.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Log the hyper-parameters; 'with' closes the handle instead of leaking
    # it like the original open(...).write(...).
    with open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w') as log_file:
        log_file.write(str(train_args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, train_args, val_loader, restore_transform, visualize)
Esempio n. 9
0
def main():
    """Train FCN32-VGG on Mapillary (semantic) with an LR-on-plateau
    schedule, optionally resuming, then save the final state dict.

    Relies on module-level globals: ``args``, ``ckpt_path``, ``exp_name``,
    ``PATH``, the ``mapillary`` dataset module, the transform modules,
    ``check_mkdir``, ``train`` and ``validate``.
    """
    net = FCN32VGG(num_classes=mapillary.num_classes).cuda()

    if len(args['snapshot']) == 0:
        # Fresh run: worst-possible best-record metrics.
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # Metrics at fixed positions in the underscore-split filename —
        # assumed from this parsing; confirm against the snapshot writer.
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net.train()

    # ImageNet channel mean/std for the pretrained backbone.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Scale so a crop of args['input_size'] covers ~87.5% of the shorter side.
    short_size = int(min(args['input_size']) / 0.875)
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip()
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Inverse of input_transform, for visualising predictions.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])
    visualize = standard_transforms.ToTensor()

    train_set = mapillary.Mapillary('semantic',
                                    'training',
                                    joint_transform=train_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True,
                              pin_memory=True)
    val_set = mapillary.Mapillary('semantic',
                                  'validation',
                                  joint_transform=val_joint_transform,
                                  transform=input_transform,
                                  target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            num_workers=8,
                            shuffle=False,
                            pin_memory=True)

    criterion = CrossEntropyLoss2d(size_average=False).cuda()

    # Biases: double learning rate, no weight decay.
    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'])

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        # Re-apply the configured learning rates after loading state.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # ':' is invalid in Windows filenames, hence the replace.  Use 'with' so
    # the handle is closed (the original open(...).write(...) leaked it).
    log_name = str(datetime.datetime.now()).replace(':', '-') + '.txt'
    with open(os.path.join(ckpt_path, exp_name, log_name), 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  patience=args['lr_patience'],
                                  min_lr=1e-10)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args,
                            restore_transform, visualize)
        scheduler.step(val_loss)

    # NOTE(review): PATH is a module-level global not visible in this chunk —
    # confirm it is defined before relying on this final save.
    torch.save(net.state_dict(), PATH)
def main():
    """Run a trained FCN8s (VGG backbone) over the Cityscapes test split.

    Saves the de-normalized input image and the colorized prediction for each
    sample under ``<ckpt_path>/<exp_name>/test``, then reports accuracy,
    per-class accuracy and mean IoU over the whole split.

    Relies on the module-level ``args`` dict for the snapshot/experiment
    configuration; no arguments, no return value.
    """
    torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    vgg_model = VGGNet(requires_grad=True, remove_fc=True)
    net = FCN8s(pretrained_net=vgg_model,
                n_class=cityscapes.num_classes,
                dropout_rate=0.4)
    print('load model ' + args['snapshot'])

    vgg_model = vgg_model.to(device)
    net = net.to(device)

    # Wrap in DataParallel BEFORE loading: the snapshot was saved with
    # 'module.'-prefixed keys when trained on multiple GPUs.
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, args['exp_name'],
                                args['snapshot'])))
    net.eval()

    # ImageNet normalization statistics (mean, std).
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Scale the short side up so the CenterCrop never exceeds the image.
    short_size = int(min(args['input_size']) / 0.875)
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(*mean_std)])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = transforms.Compose(
        [extended_transforms.DeNormalize(*mean_std),
         transforms.ToPILImage()])

    test_set = cityscapes.CityScapes('test',
                                     joint_transform=val_joint_transform,
                                     transform=test_transform,
                                     target_transform=target_transform)

    test_loader = DataLoader(test_set,
                             batch_size=1,
                             num_workers=8,
                             shuffle=False)

    check_mkdir(os.path.join(ckpt_path, args['exp_name'], 'test'))

    gts_all, predictions_all = [], []
    for vi, data in enumerate(test_loader):
        img_name, img, gts = data

        img_name = img_name[0].split('/')[-1]

        # Save the de-normalized input image next to the prediction.
        img_transform = restore_transform(img[0])
        img_transform.save(
            os.path.join(ckpt_path, args['exp_name'], 'test', img_name))
        img_name = img_name.split('_leftImg8bit.png')[0]

        img, gts = img.to(device), gts.to(device)
        # Inference only: no gradients needed.
        with torch.no_grad():
            output = net(img)

        # (1, C, H, W) logits -> argmax over classes -> (H, W) label map.
        prediction = output.data.max(1)[1].squeeze_(1).squeeze_(
            0).cpu().numpy()
        prediction_img = cityscapes.colorize_mask(prediction)
        prediction_img.save(
            os.path.join(ckpt_path, args['exp_name'], 'test',
                         img_name + '.png'))

        print('%d / %d' % (vi + 1, len(test_loader)))
        gts_all.append(gts.data.cpu().numpy())
        predictions_all.append(prediction)

    gts_all = np.concatenate(gts_all)
    # BUG FIX: was np.concatenate(prediction) — that evaluated only the LAST
    # frame (and joined its rows). Stack all (H, W) predictions into (N, H, W)
    # to match gts_all's shape.
    predictions_all = np.stack(predictions_all)
    acc, acc_cls, mean_iou, _ = evaluate(predictions_all, gts_all,
                                         cityscapes.num_classes)

    print(
        '-----------------------------------------------------------------------------------------------------------'
    )
    # BUG FIX: was printing the undefined name `mean_iu` (NameError).
    print('[acc %.5f], [acc_cls %.5f], [mean_iu %.5f]' %
          (acc, acc_cls, mean_iou))
def main():
    """Train an FCN8s (VGG backbone) on Cityscapes with a fixed random seed.

    Builds the joint/image/mask transform pipelines, the train/val loaders,
    the model, loss, optimizer and LR scheduler, optionally resumes from the
    snapshot named in the module-level ``args`` dict, then runs the
    train/validate epoch loop.
    """
    torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    # Seed every RNG we rely on so runs are reproducible.
    seed = 63
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # ImageNet normalization statistics (mean, std).
    norm_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Short side for validation resize before center-cropping.
    resize_short = int(min(args['input_size']) / 0.875)

    joint_train = joint_transforms.Compose([
        joint_transforms.RandomCrop(args['crop_size']),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomRotate(90)
    ])
    joint_val = joint_transforms.Compose([
        joint_transforms.Scale(resize_short),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    to_tensor_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*norm_stats)
    ])
    mask_to_tensor = extended_transforms.MaskToTensor()
    restore_transform = transforms.Compose([
        extended_transforms.DeNormalize(*norm_stats),
        transforms.ToPILImage()
    ])
    visualize = transforms.ToTensor()

    train_loader = DataLoader(
        cityscapes.CityScapes('train',
                              joint_transform=joint_train,
                              transform=to_tensor_norm,
                              target_transform=mask_to_tensor),
        batch_size=args['train_batch_size'],
        num_workers=8,
        shuffle=True)
    val_loader = DataLoader(
        cityscapes.CityScapes('val',
                              joint_transform=joint_val,
                              transform=to_tensor_norm,
                              target_transform=mask_to_tensor),
        batch_size=args['val_batch_size'],
        num_workers=8,
        shuffle=True)

    print(len(train_loader), len(val_loader))

    backbone = VGGNet(requires_grad=True, remove_fc=True)
    net = FCN8s(pretrained_net=backbone,
                n_class=cityscapes.num_classes,
                dropout_rate=0.4)
    criterion = nn.CrossEntropyLoss(ignore_index=cityscapes.ignore_label)
    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    # Log the full configuration to a timestamped file in the experiment dir.
    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(args) + '\n\n')

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10)

    backbone = backbone.to(device)
    net = net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    if len(args['snapshot']) == 0:
        # Fresh run: start from epoch 1 with an empty best record.
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0
        }
    else:
        # Resume: restore weights and parse the best metrics back out of the
        # underscore-delimited snapshot filename.
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        fields = args['snapshot'].split('_')
        curr_epoch = int(fields[1]) + 1
        args['best_record'] = {
            'epoch': int(fields[1]),
            'val_loss': float(fields[3]),
            'acc': float(fields[5]),
            'acc_cls': float(fields[7]),
            'mean_iu': float(fields[9][:-4])
        }

    criterion.to(device)

    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, device, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, device, criterion, optimizer,
                            epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
def main():
    """Train an FCN8s (VGG backbone) on the games dataset.

    Configuration comes from ``parse_args()`` (an argparse ``Namespace``).
    Builds loaders, model, loss, optimizer and scheduler, optionally resumes
    from ``args.snapshot``, then runs the train/validate epoch loop.
    """
    args = parse_args()

    torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    # Random seed for reproducibility
    if args.seed:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if args.gpu:
            torch.cuda.manual_seed_all(args.seed)

    # ImageNet normalization statistics (fixed typo: was 'denoramlize_argument').
    denormalize_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # NOTE(review): these augmentations are image-only (no joint transform),
    # so the mask is presumably handled inside the dataset — TODO confirm
    # the crops/flips stay aligned with the labels.
    train_transforms = transforms.Compose([
        transforms.RandomCrop((args.crop_size, args.crop_size)),
        transforms.RandomRotation(90),
        transforms.RandomHorizontalFlip(p=0.5),
    ])

    # Resize the short side up so the CenterCrop never exceeds the image.
    img_resize_shape = int(min(args.input_size) / 0.8)

    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(img_resize_shape),
        joint_transforms.CenterCrop(args.input_size)
    ])
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*denormalize_stats)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = transforms.Compose([
        extended_transforms.DeNormalize(*denormalize_stats),
        transforms.ToPILImage()
    ])
    visualize = transforms.ToTensor()

    train_set = games_data.CityScapes('train', transform=train_transforms)
    train_loader = DataLoader(train_set,
                              batch_size=args.training_batch_size,
                              num_workers=8,
                              shuffle=True)
    val_set = games_data.CityScapes('val',
                                    joint_transform=val_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args.val_batch_size,
                            num_workers=8,
                            shuffle=True)

    print(len(train_loader), len(val_loader))

    # Load pretrained VGG model
    vgg_model = VGGNet(requires_grad=True, remove_fc=True)

    # FCN architecture load
    model = FCN8s(pretrained_net=vgg_model,
                  n_class=games_data.num_classes,
                  dropout_rate=0.4)

    # Loss function
    criterion = nn.CrossEntropyLoss(ignore_index=games_data.ignore_label)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Create directory for checkpoints and log the configuration.
    exist_directory(ckpt_path)
    exist_directory(os.path.join(ckpt_path, exp_name))
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(args) + '\n\n')

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     patience=args.lr_patience,
                                                     min_lr=1e-10)

    # Send model to CUDA device
    vgg_model = vgg_model.to(device)
    model = model.to(device)

    # Use if more than 1 GPU
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    if len(args.snapshot) == 0:
        curr_epoch = 1
        # BUG FIX: was `best_args['best_record']` — `best_args` is undefined
        # (NameError on every fresh run). Keep the record on the argparse
        # namespace, mirroring the sibling training scripts.
        args.best_record = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0
        }
    else:
        # BUG FIX: `args` is an argparse.Namespace, so `args['snapshot']`
        # raised TypeError — use attribute access throughout the resume path.
        print('training resumes from ' + args.snapshot)
        model.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args.snapshot)))
        split_snapshot = args.snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args.best_record = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9][:-4])
        }

    criterion.to(device)

    for epoch in range(curr_epoch, args.epochs + 1):
        train(train_loader, model, device, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, model, device, criterion, optimizer,
                            epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
# Esempio n. 13
def main(train_args):
    """Train an FCN8s (11 classes) on the PRIMA layout dataset.

    Args:
        train_args: dict with keys 'snapshot', 'lr', 'weight_decay',
            'momentum', 'lr_patience', 'epoch_num'; a 'best_record' entry is
            added/updated here for the training loop's bookkeeping.
    """
    net = FCN8s(num_classes=11)

    if len(train_args['snapshot']) == 0:
        # Fresh run: start from epoch 1 with an empty best record.
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        # Resume: restore weights and parse the best metrics back out of the
        # underscore-delimited snapshot filename.
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()
    # ImageNet normalization statistics (mean, std).
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(2000),
        joint_transforms.RandomHorizontallyFlip()
    ])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = PRIMA('train', joint_transform=train_joint_transform, transform=input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=1, num_workers=0, shuffle=True)

    # BUG FIX: removed leftover debugging code that pulled the first batch and
    # dropped into `pdb.set_trace()`, halting every training run unattended.

    criterion = CrossEntropyLoss2d(size_average=False).cuda()
    # Adam with the "2x learning rate, no weight decay for biases" split.
    optimizer = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], betas=(train_args['momentum'], 0.999))

    if len(train_args['snapshot']) > 0:
        # Restore optimizer state, then re-apply the current learning rates.
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Log the full configuration to a timestamped file in the experiment dir.
    open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(train_args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=train_args['lr_patience'], min_lr=1e-10, verbose=True)
    # NOTE(review): scheduler.step() is never called and there is no
    # validation pass here, so the LR never actually changes — presumably a
    # validate() call was meant to be wired in as in the sibling scripts.
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
def main():
    """Fine-tune a Caffe-pretrained PSPNet-101 for binary segmentation on
    KITTI: the pretrained backbone is frozen and only a freshly attached
    head (conv-BN-ReLU block, dropout, 1x1 classifier) is trained.

    Configuration comes from the module-level ``args`` dict; no arguments,
    no return value.
    """
    net = PSPNet(19)
    # Load Cityscapes-pretrained weights (19 classes) from the Caffe model.
    net.load_pretrained_model(
        model_path='./Caffe-PSPNet/pspnet101_cityscapes.caffemodel')
    # Freeze the entire pretrained network first...
    for param in net.parameters():
        param.requires_grad = False
    # ...then replace the head: the freshly constructed layers default to
    # requires_grad=True, so only they receive gradient updates.
    net.cbr_final = conv2DBatchNormRelu(4096, 128, 3, 1, 1, False)
    net.dropout = nn.Dropout2d(p=0.1, inplace=True)
    net.classification = nn.Conv2d(128, kitti_binary.num_classes, 1, 1, 0)

    # Find total parameters and trainable parameters
    total_params = sum(p.numel() for p in net.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(p.numel() for p in net.parameters()
                                 if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')

    if len(args['snapshot']) == 0:
        # Fresh run: initialise the best-record bookkeeping.
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth')))
        args['best_record'] = {
            'epoch': 0,
            'iter': 0,
            'val_loss': 1e10,
            'accu': 0
        }
    else:
        # Resume: restore weights and parse the best metrics back out of the
        # underscore-delimited snapshot filename.
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        # NOTE(review): curr_epoch is computed here but never used below.
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'iter': int(split_snapshot[3]),
            'val_loss': float(split_snapshot[5]),
            'accu': float(split_snapshot[7])
        }
    net.cuda(args['gpu']).train()

    # Caffe-style preprocessing: channel flip (see FlipChannels below),
    # values scaled to 0-255, per-channel mean subtraction with std of 1.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    train_set = kitti_binary.KITTI(mode='train',
                                   joint_transform=train_joint_transform,
                                   transform=train_input_transform,
                                   target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True)
    val_set = kitti_binary.KITTI(mode='val',
                                 transform=val_input_transform,
                                 target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=8,
                            shuffle=False)

    # Binary segmentation loss; pos_weight slightly up-weights the positive
    # class (1.05x).
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.full([1], 1.05)).cuda(
        args['gpu'])

    # SGD with the classic split: biases get 2x learning rate and no weight
    # decay; all other parameters get the base lr plus weight decay.
    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        # Restore optimizer state, then re-apply the current learning rates.
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Log the full configuration to a timestamped file in the experiment dir.
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, args, val_loader)