def main(args):
    """Build the encoder/decoder networks and run a single evaluation pass.

    Evaluates on the CityScapes and BDD validation splits; no training is
    performed in this variant.
    """
    # Network builders
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(weights=args.weights_encoder)
    net_decoder_1 = builder.build_decoder(weights=args.weights_decoder_1)
    # The second decoder uses the 'c1' architecture variant.
    net_decoder_2 = builder.build_decoder(arch='c1',
                                          weights=args.weights_decoder_2)

    # NLL loss over log-probabilities; label -1 marks ignored pixels.
    # Optional per-class weighting (args.class_weight) to counter imbalance.
    if args.weighted_class:
        crit = nn.NLLLoss(ignore_index=-1, weight=args.class_weight)
    else:
        crit = nn.NLLLoss(ignore_index=-1)

    # Dataset and loaders. The GTA train set is only instantiated here to
    # size an epoch (epoch_iters below); no train loader is built.
    dataset_train = GTA(root=args.root_gta, cropSize=args.imgSize, is_train=1)
    dataset_val = CityScapes('val', root=args.root_cityscapes,
                             cropSize=args.imgSize,
                             max_sample=args.num_val, is_train=0)
    dataset_val_2 = BDD('val', root=args.root_bdd, cropSize=args.imgSize,
                        max_sample=args.num_val, is_train=0)

    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size_eval,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=True)
    loader_val_2 = torch.utils.data.DataLoader(
        dataset_val_2,
        batch_size=args.batch_size_eval,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=True)

    # Floor division avoids the float round-trip of int(len / batch_size).
    args.epoch_iters = len(dataset_train) // args.batch_size
    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Wrap the nets for multi-GPU execution when requested.
    if args.num_gpus > 1:
        net_encoder = nn.DataParallel(net_encoder,
                                      device_ids=range(args.num_gpus))
        net_decoder_1 = nn.DataParallel(net_decoder_1,
                                        device_ids=range(args.num_gpus))
        net_decoder_2 = nn.DataParallel(net_decoder_2,
                                        device_ids=range(args.num_gpus))

    # The criterion is an nn.Module too, so .cuda() moves all four.
    nets = (net_encoder, net_decoder_1, net_decoder_2, crit)
    for net in nets:
        net.cuda()

    history = {split: {'epoch': [], 'err': [], 'acc': [], 'mIoU': []}
               for split in ('train', 'val', 'val_2')}

    # Single evaluation at "epoch 0" on both validation sets.
    evaluate(nets, loader_val, loader_val_2, history, 0, args)
    print('Evaluation Done!')
# ---------------------------------------------------------------------------
# Beispiel #2 (example separator from the scraped listing; "0" was its score)
# ---------------------------------------------------------------------------
def main(args):
    """Build encoder, two decoders and a synthesis net, then evaluate.

    Runs a single evaluation pass on the BDD validation split.
    """
    # Network builders
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                        fc_dim=args.fc_dim,
                                        weights=args.weights_encoder)
    net_decoder_1 = builder.build_decoder(arch=args.arch_decoder,
                                          fc_dim=args.fc_dim,
                                          num_class=args.num_class,
                                          weights=args.weights_decoder_1)
    net_decoder_2 = builder.build_decoder(arch=args.arch_decoder,
                                          fc_dim=args.fc_dim,
                                          num_class=args.num_class,
                                          weights=args.weights_decoder_2)
    net_syn = builder.build_syn(weights=args.weights_syn)

    # nn.NLLLoss2d is deprecated (removed in recent PyTorch); nn.NLLLoss
    # handles 2-D targets identically and is what the other entry points
    # in this file use. Label -1 marks ignored pixels.
    crit = nn.NLLLoss(ignore_index=-1)

    # Dataset and loader
    # dataset_val = CityScapes('val', root=args.root_cityscapes, max_sample=args.num_val, is_train=0)
    dataset_val = BDD('val',
                      root=args.root_unlabeled,
                      cropSize=args.imgSize,
                      max_sample=args.num_val,
                      is_train=0)
    loader_val = torch.utils.data.DataLoader(dataset_val,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=2,
                                             drop_last=True)

    # The criterion is an nn.Module too, so .cuda() moves all five.
    nets = (net_encoder, net_decoder_1, net_decoder_2, net_syn, crit)
    for net in nets:
        net.cuda()

    # Main loop: one evaluation pass, no training.
    evaluate(nets, loader_val, args)

    print('Evaluation Done!')
# ---------------------------------------------------------------------------
# Beispiel #3 (example separator from the scraped listing; "0" was its score)
# ---------------------------------------------------------------------------
def main(args):
    """Train the shared-encoder / dual-decoder model on GTA.

    Trains on the GTA dataset and periodically evaluates on the CityScapes
    and BDD validation splits, checkpointing and decaying the learning rate
    every epoch.
    """
    # Network builders
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(weights=args.weights_encoder)
    net_decoder_1 = builder.build_decoder(weights=args.weights_decoder_1)
    # The second decoder uses the 'c1' architecture variant.
    net_decoder_2 = builder.build_decoder(arch='c1',
                                          weights=args.weights_decoder_2)

    # NLL loss over log-probabilities; label -1 marks ignored pixels.
    # Optional per-class weighting (args.class_weight) to counter imbalance.
    if args.weighted_class:
        crit = nn.NLLLoss(ignore_index=-1, weight=args.class_weight)
    else:
        crit = nn.NLLLoss(ignore_index=-1)

    # Datasets: GTA for training (optionally with random masking),
    # CityScapes and BDD for validation.
    dataset_train = GTA(root=args.root_gta,
                        cropSize=args.imgSize,
                        is_train=1,
                        random_mask=args.mask)
    dataset_val = CityScapes('val',
                             root=args.root_cityscapes,
                             cropSize=args.imgSize,
                             max_sample=args.num_val,
                             is_train=0)
    dataset_val_2 = BDD('val',
                        root=args.root_bdd,
                        cropSize=args.imgSize,
                        max_sample=args.num_val,
                        is_train=0)

    loader_train = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=int(args.workers),
                                               drop_last=True)
    loader_val = torch.utils.data.DataLoader(dataset_val,
                                             batch_size=args.batch_size_eval,
                                             shuffle=False,
                                             num_workers=int(args.workers),
                                             drop_last=True)
    loader_val_2 = torch.utils.data.DataLoader(dataset_val_2,
                                               batch_size=args.batch_size_eval,
                                               shuffle=False,
                                               num_workers=int(args.workers),
                                               drop_last=True)

    # Floor division avoids the float round-trip of int(len / batch_size).
    args.epoch_iters = len(dataset_train) // args.batch_size
    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Wrap the nets for multi-GPU execution when requested.
    if args.num_gpus > 1:
        net_encoder = nn.DataParallel(net_encoder,
                                      device_ids=range(args.num_gpus))
        net_decoder_1 = nn.DataParallel(net_decoder_1,
                                        device_ids=range(args.num_gpus))
        net_decoder_2 = nn.DataParallel(net_decoder_2,
                                        device_ids=range(args.num_gpus))

    # The criterion is an nn.Module too, so .cuda() moves all four.
    nets = (net_encoder, net_decoder_1, net_decoder_2, crit)
    for net in nets:
        net.cuda()

    # Set up optimizers
    optimizers = create_optimizers(nets, args)

    # Per-split training/validation history, appended to by train()/evaluate().
    history = {
        split: {
            'epoch': [],
            'err': [],
            'acc': [],
            'mIoU': []
        }
        for split in ('train', 'val', 'val_2')
    }

    # optional initial eval
    # evaluate(nets, loader_val, loader_val_2, history, 0, args)
    for epoch in range(1, args.num_epoch + 1):
        train(nets, loader_train, optimizers, history, epoch, args)

        # Evaluate only every args.eval_epoch epochs.
        if epoch % args.eval_epoch == 0:
            evaluate(nets, loader_val, loader_val_2, history, epoch, args)

        # Checkpoint every epoch regardless of evaluation cadence.
        checkpoint(nets, history, args)

        # Learning-rate schedule step.
        adjust_learning_rate(optimizers, epoch, args)

    print('Training Done!')
# ---------------------------------------------------------------------------
# Beispiel #4 (example separator from the scraped listing; "0" was its score)
# ---------------------------------------------------------------------------
def main(args):
    """Train a single encoder/decoder pair on BDD and evaluate on BDD val.

    Runs an initial evaluation, then trains for args.num_epoch epochs with
    per-epoch checkpointing, learning-rate decay, and periodic evaluation.
    """
    # Network builders
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch=args.arch_encoder,
                                        fc_dim=args.fc_dim,
                                        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(arch=args.arch_decoder,
                                        fc_dim=args.fc_dim,
                                        num_class=args.num_class,
                                        weights=args.weights_decoder)

    # nn.NLLLoss2d is deprecated (removed in recent PyTorch); nn.NLLLoss
    # handles 2-D targets identically and is what the other entry points
    # in this file use. Label -1 marks ignored pixels.
    crit = nn.NLLLoss(ignore_index=-1)

    # Dataset and loaders. Alternative datasets kept for reference:
    #dataset_train = GTA(cropSize=args.imgSize, root=args.root_labeled)
    #dataset_train =  CityScapes('train', root=args.root_unlabeled, cropSize=args.imgSize, is_train=1)
    dataset_train = BDD('train',
                        root=args.root_labeled,
                        cropSize=args.imgSize,
                        is_train=1)
    #dataset_val = CityScapes('val', root=args.root_unlabeled, cropSize=args.imgSize, max_sample=args.num_val, is_train=0)
    dataset_val = BDD('val',
                      root=args.root_unlabeled,
                      cropSize=args.imgSize,
                      max_sample=args.num_val,
                      is_train=0)
    loader_train = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=int(args.workers),
                                               drop_last=True)
    loader_val = torch.utils.data.DataLoader(dataset_val,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=2,
                                             drop_last=True)

    # Floor division avoids the float round-trip of int(len / batch_size).
    args.epoch_iters = len(dataset_train) // args.batch_size
    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Wrap the nets for multi-GPU execution when requested.
    if args.num_gpus > 1:
        net_encoder = nn.DataParallel(net_encoder,
                                      device_ids=range(args.num_gpus))
        net_decoder = nn.DataParallel(net_decoder,
                                      device_ids=range(args.num_gpus))

    # The criterion is an nn.Module too, so .cuda() moves all three.
    nets = (net_encoder, net_decoder, crit)
    for net in nets:
        net.cuda()

    # Set up optimizers
    optimizers = create_optimizers(nets, args)

    # Per-split training/validation history, appended to by train()/evaluate().
    history = {
        split: {
            'epoch': [],
            'err': [],
            'acc': []
        }
        for split in ('train', 'val')
    }
    # Initial evaluation at "epoch 0" before any training.
    evaluate(nets, loader_val, history, 0, args)
    for epoch in range(1, args.num_epoch + 1):
        train(nets, loader_train, optimizers, history, epoch, args)

        # Checkpoint every epoch regardless of evaluation cadence.
        checkpoint(nets, history, args)

        # Learning-rate schedule step.
        adjust_learning_rate(optimizers, epoch, args)

        # Evaluate only every args.eval_epoch epochs.
        if epoch % args.eval_epoch == 0:
            evaluate(nets, loader_val, history, epoch, args)

    print('Training Done!')