import argparse
import os
import os.path as osp

import torch

import torchfcn

# `configurations`, `get_log_dir`, and `get_parameters` are assumed to be
# defined elsewhere in the training script this example is taken from.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpu', type=str, required=True)
    parser.add_argument('-c',
                        '--config',
                        type=int,
                        default=1,
                        choices=configurations.keys())
    parser.add_argument('--resume', help='Checkpoint path')
    args = parser.parse_args()

    gpu = args.gpu
    cfg = configurations[args.config]
    out = get_log_dir('fcn8s-atonce', args.config, cfg)
    resume = args.resume

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # Scale the batch size with the number of visible GPUs; fall back to 1 on
    # a single-GPU or CPU-only machine (device_count() would otherwise be 0).
    if torch.cuda.device_count() <= 1:
        batch_size = 1
    else:
        batch_size = 2 * torch.cuda.device_count()

    # 1. dataset
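    # Three loaders over the Cityscapes data: `mix_loader` draws from the
    # combined train+val splits, `train_loader` from train only, and
    # `val_loader` from val only (without shuffling).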

    root = osp.expanduser('~/data/datasets')
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    mix_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.CityScapesClassSeg(
            root,
            split=['train', 'val'],
            transform=True,
            preprocess=False,
        ),
        batch_size=batch_size,
        shuffle=True,
        **kwargs)
    train_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.CityScapesClassSeg(
            root,
            split=['train'],
            transform=True,
            preprocess=False,
        ),
        batch_size=batch_size,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.CityScapesClassSeg(
            root,
            split=['val'],
            transform=True,
            preprocess=False,
        ),
        batch_size=batch_size,
        shuffle=False,
        **kwargs)

    # train_loader = torch.utils.data.DataLoader(
    #     torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
    #     batch_size=1, shuffle=True, **kwargs)
    # val_loader = torch.utils.data.DataLoader(
    #     torchfcn.datasets.VOC2011ClassSeg(
    #         root, split='seg11valid', transform=True),
    #     batch_size=1, shuffle=False, **kwargs)

    # 2. model
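    # FCN8s "at once" variant; unless resuming from a checkpoint, its layers
    # are initialised below from a pretrained VGG16.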

    model = torchfcn.models.FCN8sAtOnce(n_class=20)

    start_epoch = 0
    start_iteration = 0
    if resume:
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    else:
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)
    if cuda:
        if torch.cuda.device_count() == 1:
            model = model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # 3. optimizer
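    # Two parameter groups: biases get twice the base learning rate and no
    # weight decay, as in the usual FCN training recipe. Adam is used here,
    # so the SGD-style momentum argument stays commented out.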

    optim = torch.optim.Adam(
        [
            {
                'params': get_parameters(model, bias=False)
            },
            {
                'params': get_parameters(model, bias=True),
                'lr': cfg['lr'] * 2,
                'weight_decay': 0
            },
        ],
        lr=cfg['lr'],
        # momentum=cfg['momentum'],
        weight_decay=cfg['weight_decay'])
    if resume:
        optim.load_state_dict(checkpoint['optim_state_dict'])

    trainer = torchfcn.Trainer(
        cuda=cuda,
        model=model,
        optimizer=optim,
        train_loader=train_loader,
        val_loader=val_loader,
        mix_loader=mix_loader,
        out=out,
        nEpochs=5,
        max_iter=cfg['max_iteration'],
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
Example #2
import argparse
import os
import os.path as osp

import torch

import torchfcn

# As in the previous example, `configurations`, `get_log_dir`, and
# `get_parameters` are assumed to be defined in the surrounding script.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpu', type=int, required=True)
    parser.add_argument('-c',
                        '--config',
                        type=int,
                        default=1,
                        choices=configurations.keys())
    parser.add_argument('--resume', help='Checkpoint path')
    args = parser.parse_args()

    gpu = args.gpu
    cfg = configurations[args.config]
    out = get_log_dir('fcn16s', args.config, cfg)
    resume = args.resume

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset
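    # The usual torchfcn setup: train on SBD, validate on the VOC2011
    # 'seg11valid' split, both with batch size 1.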

    root = osp.expanduser('~/data/datasets')
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    train_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
        batch_size=1,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.VOC2011ClassSeg(
            root, split='seg11valid', transform=True),
        batch_size=1,
        shuffle=False,
        **kwargs)

    # 2. model
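    # Unless resuming, FCN16s is initialised from a trained FCN32s checkpoint
    # given by cfg['fcn32s_pretrained_model'].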

    model = torchfcn.models.FCN16s(n_class=21)
    start_epoch = 0
    start_iteration = 0
    if resume:
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    else:
        fcn32s = torchfcn.models.FCN32s()
        fcn32s.load_state_dict(torch.load(cfg['fcn32s_pretrained_model']))
        model.copy_params_from_fcn32s(fcn32s)
    if cuda:
        model = model.cuda()

    # 3. optimizer
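    # Two-group SGD: biases get twice the base learning rate and no weight
    # decay; learning rate, momentum and weight decay come from the config.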

    optim = torch.optim.SGD(
        [
            {
                'params': get_parameters(model, bias=False)
            },
            {
                'params': get_parameters(model, bias=True),
                'lr': cfg['lr'] * 2,
                'weight_decay': 0
            },
        ],
        lr=cfg['lr'],
        momentum=cfg['momentum'],
        weight_decay=cfg['weight_decay'])
    if resume:
        optim.load_state_dict(checkpoint['optim_state_dict'])

    trainer = torchfcn.Trainer(
        cuda=cuda,
        model=model,
        optimizer=optim,
        train_loader=train_loader,
        val_loader=val_loader,
        out=out,
        max_iter=cfg['max_iteration'],
        interval_validate=cfg.get('interval_validate', len(train_loader)),
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
Example #3
import argparse
import os
import os.path as osp

import torch

import torchfcn

# `configurations`, `get_log_dir`, and `get_parameters` are again assumed to
# be defined in the surrounding script.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpu', type=int, default=0)
    parser.add_argument('-c',
                        '--config',
                        type=int,
                        default=1,
                        choices=configurations.keys())
    parser.add_argument('--resume', help='Checkpoint path')
    args = parser.parse_args()

    gpu = args.gpu
    cfg = configurations[args.config]
    out = get_log_dir('fcn8s-atonce', args.config, cfg)
    resume = args.resume

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset
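    # Note: the dataset root points at a KITTI semantics dataset, although the
    # loader classes are named MRIClassSeg / MRIClassSegValidate.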

    root = osp.expanduser(
        '/media/zhi/Drive3/KITTI/rwth_kitti_semantics_dataset')
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    train_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.MRIClassSeg(root, split='train', transform=True),
        batch_size=3,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.MRIClassSegValidate(root,
                                              split='validation',
                                              transform=True),
        batch_size=3,
        shuffle=False,
        **kwargs)

    # 2. model
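    # n_class=9 for this dataset; the VGG16 initialisation is commented out
    # below, so without --resume the model keeps its default initialisation.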

    model = torchfcn.models.FCN8sAtOnce(n_class=9)
    start_epoch = 0
    start_iteration = 0
    if resume:
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    # else:
    #     vgg16 = torchfcn.models.VGG16(pretrained=True)
    #     model.copy_params_from_vgg16(vgg16)
    if cuda:
        model = model.cuda()

    # 3. optimizer
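    # Same two-group SGD setup as in the previous example (biases: 2x learning
    # rate, no weight decay).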

    optim = torch.optim.SGD(
        [
            {
                'params': get_parameters(model, bias=False)
            },
            {
                'params': get_parameters(model, bias=True),
                'lr': cfg['lr'] * 2,
                'weight_decay': 0
            },
        ],
        lr=cfg['lr'],
        momentum=cfg['momentum'],
        weight_decay=cfg['weight_decay'])
    if resume:
        optim.load_state_dict(checkpoint['optim_state_dict'])

    trainer = torchfcn.Trainer(
        cuda=cuda,
        model=model,
        optimizer=optim,
        train_loader=train_loader,
        val_loader=val_loader,
        out=out,
        max_iter=cfg['max_iteration'],
        interval_validate=100,  # cfg.get('interval_validate', len(train_loader))
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()