Example #1
def main():
    global args
    checkpoint = None
    # is_eval = False
    is_eval = True  # added by me for testing, 2020/02/26
    if args.evaluate:
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return

    print("=> creating model and optimizer ... ", end='')
    model = DepthCompletionNet(args).to(device)
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")

    val_dataset = KittiDepth('test_completion', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        print("=> starting model test ...")
        result, is_best = iterate("test_completion", args, val_loader, model,
                                  None, logger, checkpoint['epoch'])
        return
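Note: because is_eval is hardcoded to True above, running this example without --evaluate leaves checkpoint as None, so checkpoint['epoch'] in the final call raises a TypeError. A minimal defensive variant of that block (the epoch fallback of 0 is an assumption, not part of the original):

    if is_eval:
        print("=> starting model test ...")
        # fall back to epoch 0 when no checkpoint was loaded (assumed default)
        epoch = checkpoint['epoch'] if checkpoint is not None else 0
        result, is_best = iterate("test_completion", args, val_loader, model,
                                  None, logger, epoch)
        return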
Example #2
def get_kitti_dataloader(mode, dataset_name, setname, args):
    """
    Get kitti dataset and dataloader according mode and setname
    :param mode: use this dataset for train or eval, possible value: train or eval
    :param dataset_name: kitti, ours, vkitti, by default, it use kitti
    :param setname: train, val, selval, test
    :param args: related arguments
    :return: dataset, dataloader
    """
    dataset_dir = get_dataset_dir(dataset_name)

    if dataset_name == 'ours':
        dataset = OurDataset(base_dir=dataset_dir,
                             mode=mode,
                             setname="f_c_1216_352",
                             args=args)
    elif dataset_name == 'ours_20190318':
        dataset = OurDataset(base_dir=dataset_dir,
                             mode=mode,
                             setname="f_c_1216_352_20190318",
                             args=args)
    elif dataset_name == 'vkitti':
        dataset = VKittiDataset(base_dir=dataset_dir,
                                mode=mode,
                                setname=setname,
                                args=args)
    elif dataset_name == 'nuscenes':
        dataset = NuScenesDataset(base_dir=dataset_dir,
                                  mode=mode,
                                  setname="f_c_1216_352",
                                  args=args)
    elif dataset_name == 'kitti':
        dataset = KittiDataset(base_dir=dataset_dir,
                               mode=mode,
                               setname=setname,
                               args=args)
    else:
        dataset = KittiDepth(setname, args)

    if mode == 'train':
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 sampler=None)
    elif mode == 'eval':
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=2,
            pin_memory=True)  # set batch size to be 1 for validation
    else:
        raise ValueError("Unrecognized mode " + str(mode))

    return dataset, dataloader
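A hypothetical usage sketch. The Namespace fields mirror only the attributes this function reads directly; the dataset classes and directories referenced inside are assumed to be available:

from argparse import Namespace

# Illustrative only: batch_size and workers are the two args fields used above.
args = Namespace(batch_size=8, workers=4)
train_set, train_loader = get_kitti_dataloader('train', 'kitti', 'train', args)
val_set, val_loader = get_kitti_dataloader('eval', 'kitti', 'selval', args)
print(len(train_loader), len(val_loader))  # the eval loader uses batch size 1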
Example #3
def main():
    global args
    if args.partial_train == 'yes':  # train on a part of the whole train set
        print(
            "Can't use partial train here. It is used only for test check. Exit..."
        )
        return

    if args.test != "yes":
        print(
            "This main should use only for testing, but test=yes wat not given. Exit..."
        )
        return

    print("Evaluating test set with main_test:")
    whole_ts = time.time()
    checkpoint = None
    is_eval = False
    if args.evaluate:  # test a finished model
        args_new = args  # alias: keeps a reference to the CLI args (rebinding args below leaves args_new intact)
        if os.path.isfile(args.evaluate):  # path is an existing regular file
            print("=> loading finished model from '{}' ... ".format(
                args.evaluate),
                  end='')  # "end=''" disables the newline
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.save_images = args_new.save_images
            args.result = args_new.result
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint from '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')
    model = DepthCompletionNet(args).to(device)
    model_named_params = [
        # "_, p" unpacks each (name, parameter) pair, keeping only the parameter
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")
    # print a short summary of the model's attributes
    print('\n'.join(f'{k:<20}: {v}' for k, v in model.__dict__.items()))

    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(
        model
    )  # run the model in parallel: DataParallel splits each batch automatically and dispatches the chunks to replicas on several GPUs.
    # After each replica finishes, DataParallel gathers and merges the results before returning them.

    # data loading code
    print("=> creating data loaders ... ")
    if not is_eval:  # we're not evaluating
        train_dataset = KittiDepth('train',
                                   args)  # get the paths for the files
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)  # load them
        print("\t==> train_loader size:{}".format(len(train_loader)))

    if args_new.test == "yes":  # will take the data from the "test" folders
        val_dataset = KittiDepth('test', args)
        is_test = 'yes'
    else:
        val_dataset = KittiDepth('val', args)
        is_test = 'no'
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args, is_test)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")  # logger records sequential data to a log file

    # main code - run the NN
    if is_eval:
        print("=> starting model evaluation ...")
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  checkpoint['epoch'])
        return

    print("=> starting model training ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> start training epoch {}".format(epoch) +
              "/{}..".format(args.epochs))
        train_ts = time.time()
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set
        helper.save_checkpoint({  # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer': optimizer.state_dict(),
            'args': args,
        }, is_best, epoch, logger.output_directory)
        print("finish training epoch {}, time elapsed {:.2f} hours, \n".format(
            epoch, (time.time() - train_ts) / 3600))
    last_checkpoint = os.path.join(
        logger.output_directory, 'checkpoint-' + str(epoch) + '.pth.tar'
    )  # delete last checkpoint because we have the best_model and we dont need it
    os.remove(last_checkpoint)
    print("finished model training, time elapsed {0:.2f} hours, \n".format(
        (time.time() - whole_ts) / 3600))
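The mains above all persist and restore the same checkpoint dictionary. A standalone round-trip sketch using the key names from the save_checkpoint call (the file name and values are illustrative):

import torch

state = {'epoch': 3, 'model': {}, 'best_result': None,
         'optimizer': {}, 'args': None}
torch.save(state, 'checkpoint-3.pth.tar')
restored = torch.load('checkpoint-3.pth.tar', map_location='cpu')
assert restored['epoch'] == 3  # resuming would then start at epoch 4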
Example #4
                            set_num, i + 1))
                        input_t = input_type
                        data_in = '../data_new/phase_' + str(
                            phase) + '/mini_set_' + str(set_num)
                        pred_dir = '../data_new/phase_' + str(
                            phase + 1) + '/mini_set_' + str(
                                set_num) + '/predictions_tmp/NN' + str(i + 1)

                        NN_arguments[i].data_folder = data_in
                        NN_arguments[i].pred_dir = pred_dir
                        NN_arguments[i].val = 'full'
                        NN_arguments[i].use_d = 'd' in input_t
                        NN_arguments[i].batch_size = predict_batch_size

                        train_dataset = KittiDepth(
                            'val', NN_arguments[i]
                        )  # we adjusted 'val-full' option for predicting on the train data
                        train_loader = torch.utils.data.DataLoader(
                            train_dataset,
                            batch_size=NN_arguments[i].batch_size,
                            shuffle=False,
                            num_workers=2,
                            pin_memory=True)
                        print("\t==> train_loader size:{}".format(
                            len(train_loader)))
                        print("=> starting predictions with args:\n {}".format(
                            NN_arguments[i]))
                        predict(NN_arguments[i], train_loader, models[i])
                        print("finished predictions\n")

                print(
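The fragment above builds directory paths by string concatenation. An equivalent sketch with os.path.join (the phase, set_num, and i values are placeholders):

import os

phase, set_num, i = 1, 0, 0
data_in = os.path.join('..', 'data_new', 'phase_' + str(phase),
                       'mini_set_' + str(set_num))
pred_dir = os.path.join('..', 'data_new', 'phase_' + str(phase + 1),
                        'mini_set_' + str(set_num), 'predictions_tmp',
                        'NN' + str(i + 1))
print(data_in)
print(pred_dir)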
Example #5
def create_data_loaders(data_path,
                        data_type='visim',
                        loader_type='val',
                        arch='',
                        sparsifier_type='uar',
                        num_samples=500,
                        modality='rgb-fd',
                        depth_divisor=1,
                        max_depth=-1,
                        max_gt_depth=-1,
                        batch_size=8,
                        workers=8):
    # Data loading code
    print("=> creating data loaders ...")

    # legacy compatibility with the sparse-to-dense data folder layout
    subfolder = os.path.join(data_path, loader_type)
    # if os.path.exists(subfolder):
    #     data_path = subfolder

    if not os.path.exists(data_path):
        raise RuntimeError('Data source does not exist: {}'.format(data_path))

    loader = None
    dataset = None
    max_depth = max_depth if max_depth >= 0.0 else np.inf
    max_gt_depth = max_gt_depth if max_gt_depth >= 0.0 else np.inf

    # sparsifier is a class for generating random sparse depth input from the ground truth
    sparsifier = None

    if sparsifier_type == UniformSampling.name:  # uar
        sparsifier = UniformSampling(num_samples=num_samples,
                                     max_depth=max_depth)
    elif sparsifier_type == SimulatedStereo.name:  # sim_stereo
        sparsifier = SimulatedStereo(num_samples=num_samples,
                                     max_depth=max_depth)

    if data_type == 'kitti':
        from dataloaders.kitti_loader import KittiDepth

        dataset = KittiDepth(data_path,
                             split=loader_type,
                             depth_divisor=depth_divisor)

    elif data_type == 'visim':
        from dataloaders.visim_dataloader import VISIMDataset

        dataset = VISIMDataset(data_path,
                               type=loader_type,
                               modality=modality,
                               sparsifier=sparsifier,
                               depth_divider=depth_divisor,
                               is_resnet=('resnet' in arch),
                               max_gt_depth=max_gt_depth)

    elif data_type == 'visim_seq':
        from dataloaders.visim_dataloader import VISIMSeqDataset
        dataset = VISIMSeqDataset(data_path,
                                  type=loader_type,
                                  modality=modality,
                                  sparsifier=sparsifier,
                                  depth_divider=depth_divisor,
                                  is_resnet=('resnet' in arch),
                                  max_gt_depth=max_gt_depth)
    else:
        raise RuntimeError(
            'Data type not found. '
            'The dataset must be one of kitti, visim, or visim_seq.')

    if loader_type == 'val':
        # set batch size to be 1 for validation
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=workers,
                                             pin_memory=True)
        print("=> Val loader:{}".format(len(dataset)))
    elif loader_type == 'train':
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=workers,
            pin_memory=True,
            sampler=None,
            worker_init_fn=lambda work_id: np.random.seed(work_id))
        print("=> Train loader:{}".format(len(dataset)))
        # worker_init_fn ensures different sampling patterns for each data loading thread

    print("=> data loaders created.")
    return loader, dataset
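A hypothetical call; the path and keyword values are placeholders consistent with the defaults above, and note the return order is (loader, dataset):

# Illustrative only; '/data/visim' is a placeholder path.
val_loader, val_dataset = create_data_loaders('/data/visim',
                                              data_type='visim',
                                              loader_type='val',
                                              modality='rgb-fd',
                                              sparsifier_type='uar',
                                              num_samples=500)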
Example #6
def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.result = args_new.result
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.result = args_new.result
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')
    model = DepthCompletionNet(args).to(device)
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))
    val_dataset = KittiDepth('val', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        print("=> starting model evaluation ...")
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  checkpoint['epoch'])
        return

    # main loop
    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set
        helper.save_checkpoint({  # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer': optimizer.state_dict(),
            'args': args,
        }, is_best, epoch, logger.output_directory)
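helper.save_checkpoint itself is not shown in these examples. A hypothetical stand-in, consistent with how it is called here and with the 'checkpoint-<epoch>.pth.tar' name used in Example #3 (the best-model file name is an assumption):

import os
import shutil
import torch

def save_checkpoint(state, is_best, epoch, output_directory):
    # write the per-epoch checkpoint
    path = os.path.join(output_directory,
                        'checkpoint-' + str(epoch) + '.pth.tar')
    torch.save(state, path)
    if is_best:
        # keep a copy of the best model so far (file name assumed)
        shutil.copyfile(path,
                        os.path.join(output_directory, 'model_best.pth.tar'))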
Example #7
def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    ################# model

    print("=> creating model and optimizer ... ", end='')
    parameters_to_train = []
    encoder = networks.ResnetEncoder(num_layers=18)
    encoder.to(device)
    parameters_to_train += list(encoder.parameters())
    decoder = networks.DepthDecoder(encoder.num_ch_enc)
    decoder.to(device)
    parameters_to_train += list(decoder.parameters())
    # encoder_named_params = [
    #     p for _, p in encoder.named_parameters() if p.requires_grad
    # ]
    optimizer = torch.optim.Adam(parameters_to_train,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    encoder = torch.nn.DataParallel(encoder)
    decoder = torch.nn.DataParallel(decoder)
    model = [encoder, decoder]
    print("completed.")
    # if checkpoint is not None:
    #     model.load_state_dict(checkpoint['model'])
    #     optimizer.load_state_dict(checkpoint['optimizer'])
    #     print("=> checkpoint state loaded.")

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))
    val_dataset = KittiDepth('val', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=12,  # changed from the usual validation batch size of 1
        shuffle=False,
        num_workers=2,
        pin_memory=True)
    print("\t==> val_loader size:{}".format(len(val_loader)))

    ##############################################################

    # create backups and results folder
    logger = helper.logger(args)
    # if checkpoint is not None:
    #     logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        print("=> starting model evaluation ...")
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  checkpoint['epoch'])
        return

    # main loop
    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set
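Example #7 trains an encoder and a decoder with a single optimizer by concatenating their parameter lists. A standalone sketch of that pattern with stand-in modules:

import torch
import torch.nn as nn

enc, dec = nn.Linear(4, 8), nn.Linear(8, 1)  # stand-ins for encoder/decoder
params = list(enc.parameters()) + list(dec.parameters())
opt = torch.optim.Adam(params, lr=1e-4, weight_decay=1e-6)
loss = dec(enc(torch.randn(2, 4))).sum()
loss.backward()
opt.step()  # one step updates both modules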
Example #8
def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            #args = checkpoint['args']
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            is_eval = True

            print("Completed.")
        else:
            is_eval = True
            print("No model found at '{}'".format(args.evaluate))
            #return

    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)

            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')
    model = None
    penet_accelerated = False
    if args.network_model == 'e':
        model = ENet(args).to(device)
    elif not is_eval:
        if args.dilation_rate == 1:
            model = PENet_C1_train(args).to(device)
        elif args.dilation_rate == 2:
            model = PENet_C2_train(args).to(device)
        elif args.dilation_rate == 4:
            model = PENet_C4(args).to(device)
            penet_accelerated = True
    else:
        if args.dilation_rate == 1:
            model = PENet_C1(args).to(device)
            penet_accelerated = True
        elif args.dilation_rate == 2:
            model = PENet_C2(args).to(device)
            penet_accelerated = True
        elif args.dilation_rate == 4:
            model = PENet_C4(args).to(device)
            penet_accelerated = True

    if penet_accelerated:
        model.encoder3.requires_grad = False
        model.encoder5.requires_grad = False
        model.encoder7.requires_grad = False

    model_named_params = None
    model_bone_params = None
    model_new_params = None
    optimizer = None

    if checkpoint is not None:
        # print(checkpoint.keys())
        if args.freeze_backbone:
            model.backbone.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint['model'], strict=False)
        # optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
        del checkpoint
    print("=> logger created.")

    test_dataset = None
    test_loader = None
    if args.test:
        test_dataset = KittiDepth('test_completion', args)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True)
        iterate("test_completion", args, test_loader, model, None, logger, 0)
        return

    val_dataset = KittiDepth('val', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    if is_eval:
        for p in model.parameters():
            p.requires_grad = False

        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  args.start_epoch - 1)
        return

    if args.freeze_backbone:
        for p in model.backbone.parameters():
            p.requires_grad = False
        model_named_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        optimizer = torch.optim.Adam(model_named_params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     betas=(0.9, 0.99))
    elif (args.network_model == 'pe'):
        model_bone_params = [
            p for _, p in model.backbone.named_parameters() if p.requires_grad
        ]
        model_new_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        model_new_params = list(set(model_new_params) - set(model_bone_params))
        optimizer = torch.optim.Adam([{
            'params': model_bone_params,
            'lr': args.lr / 10
        }, {
            'params': model_new_params
        }],
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     betas=(0.9, 0.99))
    else:
        model_named_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        optimizer = torch.optim.Adam(model_named_params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     betas=(0.9, 0.99))
    print("completed.")

    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))

    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch

        # validation memory reset
        for p in model.parameters():
            p.requires_grad = False
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set

        for p in model.parameters():
            p.requires_grad = True
        if args.freeze_backbone:
            for p in model.module.backbone.parameters():
                p.requires_grad = False
        if penet_accelerated:
            model.module.encoder3.requires_grad = False
            model.module.encoder5.requires_grad = False
            model.module.encoder7.requires_grad = False

        helper.save_checkpoint({  # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer': optimizer.state_dict(),
            'args': args,
        }, is_best, epoch, logger.output_directory)
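Example #8 gives the pretrained backbone a tenth of the learning rate of the new layers via optimizer parameter groups. A standalone sketch of that pattern (the modules and values are illustrative):

import torch
import torch.nn as nn

backbone, head = nn.Linear(4, 4), nn.Linear(4, 1)
lr = 1e-3
opt = torch.optim.Adam(
    [{'params': backbone.parameters(), 'lr': lr / 10},
     {'params': head.parameters()}],  # this group falls back to the default lr
    lr=lr, weight_decay=1e-6, betas=(0.9, 0.99))
print([g['lr'] for g in opt.param_groups])  # [0.0001, 0.001]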
Example #9
def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}'".format(args.evaluate))
            checkpoint = torch.load(args.evaluate)
            args = checkpoint['args']
            is_eval = True
            print("=> checkpoint loaded.")
        else:
            print("=> no model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # optionally resume from a checkpoint
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer...")
    model = DepthCompletionNet(args).cuda()
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("=> model and optimizer created.")
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(model)
    print("=> model transferred to multi-GPU.")

    # Data loading code
    print("=> creating data loaders ...")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
    val_dataset = KittiDepth('val', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=False)  # set batch size to be 1 for validation
    print("=> data loaders created.")

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        result, result_intensity, is_best = iterate("val", args, val_loader,
                                                    model, None, logger,
                                                    checkpoint['epoch'])
        return

    # main loop

    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        result, result_intensity, is_best = iterate(
            "val", args, val_loader, model, None, logger,
            epoch)  # evaluate on validation set
        helper.save_checkpoint({  # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer': optimizer.state_dict(),
            'args': args,
        }, is_best, epoch, logger.output_directory)

        logger.writer.add_scalar('eval/rmse_depth', result.rmse, epoch)
        logger.writer.add_scalar('eval/rmse_intensity', result_intensity.rmse,
                                 epoch)
        logger.writer.add_scalar('eval/mae_depth', result.mae, epoch)
        logger.writer.add_scalar('eval/mae_intensity', result_intensity.mae,
                                 epoch)
        # logger.writer.add_scalar('eval/irmse_depth', result.irmse, epoch)
        # logger.writer.add_scalar('eval/irmse_intensity', result_intensity.irmse, epoch)
        logger.writer.add_scalar('eval/rmse_total',
                                 result.rmse + args.wi * result_intensity.rmse,
                                 epoch)
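Example #9 logs validation metrics through logger.writer. A minimal sketch assuming that writer is a torch.utils.tensorboard SummaryWriter (tag names copied from the example; the log directory and values are placeholders):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/example')  # log_dir is a placeholder
writer.add_scalar('eval/rmse_depth', 1.234, 0)
writer.add_scalar('eval/rmse_intensity', 0.567, 0)
writer.add_scalar('eval/mae_depth', 0.890, 0)
writer.close()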