Example #1
def main():
    global args
    checkpoint = None
    #is_eval = False
    is_eval = True  # I added this for testing, 2020/02/26
    if args.evaluate:
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return

    print("=> creating model and optimizer ... ", end='')
    model = DepthCompletionNet(args).to(device)
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")

    val_dataset = KittiDepth('test_completion', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        print("=> starting model test ...")
        result, is_best = iterate("test_completion", args, val_loader, model,
                                  None, logger, checkpoint['epoch'])
        return
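All of these scripts read checkpoints as a dict with the keys 'epoch', 'model', 'best_result', 'optimizer', and 'args'. A minimal, self-contained sketch of producing and reloading that format; the toy Linear network and Namespace are placeholders for DepthCompletionNet and the real CLI arguments, and helper.save_checkpoint is assumed to wrap roughly this torch.save call:

import argparse
import torch

model = torch.nn.DataParallel(torch.nn.Linear(4, 1))  # placeholder network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
state = {
    'epoch': 0,
    'model': model.module.state_dict(),  # unwrap DataParallel before saving
    'best_result': None,
    'optimizer': optimizer.state_dict(),
    'args': argparse.Namespace(lr=1e-3, weight_decay=0.0),
}
torch.save(state, 'checkpoint-0.pth.tar')

# reloading mirrors the evaluate/resume branches above; on recent PyTorch,
# torch.load may need weights_only=False to unpickle the Namespace
checkpoint = torch.load('checkpoint-0.pth.tar', map_location='cpu')
model.module.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])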
Example #2
def main():
    global args
    if args.partial_train == 'yes':  # train on a part of the whole train set
        print(
            "Can't use partial train here; it is used only for test checks. Exiting..."
        )
        return

    if args.test != "yes":
        print(
            "This main should be used only for testing, but test=yes was not given. Exiting..."
        )
        return

    print("Evaluating test set with main_test:")
    whole_ts = time.time()
    checkpoint = None
    is_eval = False
    if args.evaluate:  # test a finished model
        args_new = args  # an alias, not a copy; works because args is rebound to checkpoint['args'] below
        if os.path.isfile(args.evaluate):  # path is an existing regular file
            print("=> loading finished model from '{}' ... ".format(
                args.evaluate),
                  end='')  # "end=''" disables the newline
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.save_images = args_new.save_images
            args.result = args_new.result
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint from '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')
    model = DepthCompletionNet(args).to(device)
    model_named_params = [
        p for _, p in model.named_parameters()  # yields (name, parameter) pairs; the name is discarded
        if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")
    print('\n'.join(f'{k:<20}: {v}' for k, v in model.__dict__.items()))  # dump the model's attributes for inspection

    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(
        model
    )  # run the model in parallel: DataParallel splits each input batch across the available GPUs
    # and runs a model replica on each; it then gathers and merges the per-GPU results before returning them

    # data loading code
    print("=> creating data loaders ... ")
    if not is_eval:  # we're not evaluating
        train_dataset = KittiDepth('train',
                                   args)  # get the paths for the files
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)  # load them
        print("\t==> train_loader size:{}".format(len(train_loader)))

    if args_new.test == "yes":  # will take the data from the "test" folders
        val_dataset = KittiDepth('test', args)
        is_test = 'yes'
    else:
        val_dataset = KittiDepth('val', args)
        is_test = 'no'
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args, is_test)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")  # logger records sequential data to a log file

    # main code - run the NN
    if is_eval:
        print("=> starting model evaluation ...")
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  checkpoint['epoch'])
        return

    print("=> starting model training ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> start training epoch {}".format(epoch) +
              "/{}..".format(args.epochs))
        train_ts = time.time()
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set
        helper.save_checkpoint({  # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer': optimizer.state_dict(),
            'args': args,
        }, is_best, epoch, logger.output_directory)
        print("finished training epoch {}, time elapsed {:.2f} hours, \n".format(
            epoch, (time.time() - train_ts) / 3600))
    last_checkpoint = os.path.join(
        logger.output_directory, 'checkpoint-' + str(epoch) + '.pth.tar'
    )  # delete the last per-epoch checkpoint; the best model is already saved, so it isn't needed
    os.remove(last_checkpoint)
    print("finished model training, time elapsed {0:.2f} hours, \n".format(
        (time.time() - whole_ts) / 3600))
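Note that `args_new = args` binds a second name to the same Namespace rather than copying it; the pattern still works in these scripts only because `args = checkpoint['args']` immediately rebinds `args` to a different object, leaving `args_new` pointing at the original command-line namespace. A minimal illustration:

import argparse
import copy

cli = argparse.Namespace(val='full')
alias = cli                             # no copy: both names reference one object
cli = argparse.Namespace(val='select')  # rebinding; alias still sees the old object
print(alias.val)                        # prints 'full'
independent = copy.deepcopy(cli)        # a true copy, if one were actually needed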
Example #3
def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.result = args_new.result
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.result = args_new.result
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')
    model = DepthCompletionNet(args).to(device)
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))
    val_dataset = KittiDepth('val', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        print("=> starting model evaluation ...")
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  checkpoint['epoch'])
        return

    # main loop
    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set
        helper.save_checkpoint({ # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer' : optimizer.state_dict(),
            'args' : args,
        }, is_best, epoch, logger.output_directory)
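Every loop here saves model.module.state_dict() rather than model.state_dict(): torch.nn.DataParallel registers the wrapped network under the attribute name `module`, so the wrapper's state_dict keys gain a 'module.' prefix, and unwrapping first keeps checkpoints loadable into the bare network. A quick demonstration:

import torch

net = torch.nn.Linear(2, 2)
wrapped = torch.nn.DataParallel(net)
print(list(net.state_dict()))      # ['weight', 'bias']
print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']
# wrapped.module.state_dict() therefore round-trips cleanly through
# load_state_dict on an unwrapped model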
Example #4
def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.every = args_new.every
            args.evaluate = args_new.evaluate
            args.type_feature = args_new.type_feature
            args.instancewise = args_new.instancewise
            args.sparse_depth_source = args_new.sparse_depth_source
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.every = args_new.every
            args.sparse_depth_source = args_new.sparse_depth_source
            args.val = args_new.val
            args.save_checkpoint_path = args_new.save_checkpoint_path
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')

    # model
    if args.type_feature == "sq":
        if args.instancewise:
            model = DepthCompletionNetQSquareNet(args).to(device)
        else:
            model = DepthCompletionNetQSquare(args).to(device)
    elif args.type_feature == "lines":
        if args.instancewise:
            model = DepthCompletionNetQLinesNet(args).to(device)
        else:
            model = DepthCompletionNetQLines(args).to(device)
    else:
        raise ValueError("unknown type_feature: {}".format(args.type_feature))

    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]

    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'], strict=False)
        #optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")
    model = torch.nn.DataParallel(model)

    # Data loading code
    def split_dataset(dataset, num):
        # split the dataset into `num` contiguous chunks, each wrapped in its
        # own loader; the last chunk also absorbs the len(dataset) % num remainder
        subloaders = []
        dataset_len = len(dataset)
        chunk = dataset_len // num

        for i in range(num):
            if i < num - 1:
                indices = torch.arange(i * chunk, (i + 1) * chunk)
            else:
                indices = torch.arange(i * chunk, dataset_len)
            dataset_sub = torch.utils.data.Subset(dataset, indices)

            sub_train_loader = torch.utils.data.DataLoader(
                dataset_sub,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                sampler=None)

            subloaders.append(sub_train_loader)

        return subloaders

    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))
    val_dataset = KittiDepth('val', args)
    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,
    #     batch_size=1,
    #     shuffle=False,
    #     num_workers=2,
    #     pin_memory=True)  # set batch size to be 1 for validation
    # print("\t==> val_loader size:{}".format(len(val_loader)))
    val_dataset_sub = torch.utils.data.Subset(val_dataset,
                                              torch.arange(1000))  #1000
    val_loader = torch.utils.data.DataLoader(
        val_dataset_sub,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        print("=> starting model evaluation ...")
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  checkpoint['epoch'])
        return

    # for name, param in model.named_parameters():
    # #for name, param in model.state_dict().items():
    #     #print(name, param.shape)
    #     if "parameter" not in name:
    #     #if 1:
    #         h = param.register_hook(lambda grad: grad * 0)  # zero out the gradient

    # main loop
    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print(f"\n\n=> starting training epoch {epoch} .. \n\n")
        splits_total = 500  #30
        for split_it, subdatloader in enumerate(
                split_dataset(train_dataset, splits_total)):
            print("subdataloader: ", split_it)
            is_eval = False
            iterate("train", args, subdatloader, model, optimizer, logger,
                    epoch, splits_total, split_it)  # train for one epoch
            if args.instancewise:
                result, is_best = iterate("val", args, val_loader, model, None,
                                          logger,
                                          epoch)  # evaluate on validation set
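A self-contained check of the Subset-based chunking that split_dataset performs, on a toy dataset with arbitrary sizes; the last chunk absorbs the remainder:

import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
num = 3
chunk = len(dataset) // num  # 3 chunks of 3 samples, with 1 left over
loaders = []
for i in range(num):
    end = (i + 1) * chunk if i < num - 1 else len(dataset)  # last chunk keeps the remainder
    loaders.append(DataLoader(Subset(dataset, range(i * chunk, end)), batch_size=2))
print([len(l.dataset) for l in loaders])  # [3, 3, 4]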
Example #5
def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            args.every = args_new.every
            args.evaluate = args_new.evaluate
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.every = args_new.every
            args.sparse_depth_source = args_new.sparse_depth_source
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')

    # model
    if args.type_feature == "sq":
        if args.instancewise:
            model = DepthCompletionNetQSquareNet(args).to(device)
        else:
            model = DepthCompletionNetQSquare(args).to(device)
    elif args.type_feature == "lines":
        model = DepthCompletionNetQ(args).to(device)
    else:
        raise ValueError("unknown type_feature: {}".format(args.type_feature))
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("completed.")
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'], strict=False)
        #optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")
    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))
    val_dataset = KittiDepth('val', args)
    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,
    #     batch_size=1,
    #     shuffle=False,
    #     num_workers=2,
    #     pin_memory=True)  # set batch size to be 1 for validation
    # print("\t==> val_loader size:{}".format(len(val_loader)))
    val_dataset_sub = torch.utils.data.Subset(val_dataset, torch.arange(1000))
    val_loader = torch.utils.data.DataLoader(
        val_dataset_sub,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        print("=> starting model evaluation ...")
        result, is_best = iterate("val", args, val_loader, model, None, logger, checkpoint['epoch'])
        return

    # for name, param in model.named_parameters():
    # #for name, param in model.state_dict().items():
    #     #print(name, param.shape)
    #     if "parameter" not in name:
    #     #if 1:
    #         h = param.register_hook(lambda grad: grad * 0)  # zero out the gradient

    # main loop
    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("\n\n=> starting training epoch {} .. \n\n".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set
        helper.save_checkpoint({ # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer' : optimizer.state_dict(),
            'args' : args,
        }, is_best, epoch, logger.output_directory, args.type_feature)
Example #6
def post(self):
    comp_id = request.form['ID']
    logger(str(request.environ['REMOTE_ADDR']), comp_id)
    return database_call(comp_id)
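For context, a minimal self-contained sketch of how a handler like this typically sits inside a Flask-RESTful Resource; the class name and the logger and database_call stand-ins below are hypothetical, not taken from the original:

from flask import Flask, request
from flask_restful import Api, Resource

app = Flask(__name__)
api = Api(app)

def logger(addr, comp_id):   # hypothetical stand-in
    print(addr, comp_id)

def database_call(comp_id):  # hypothetical stand-in
    return {'id': comp_id}

class Component(Resource):   # hypothetical class name
    def post(self):
        comp_id = request.form['ID']
        logger(str(request.environ['REMOTE_ADDR']), comp_id)
        return database_call(comp_id)

api.add_resource(Component, '/component')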
Example #7
def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}'".format(args.evaluate))
            checkpoint = torch.load(args.evaluate)
            args = checkpoint['args']
            is_eval = True
            print("=> checkpoint loaded.")
        else:
            print("=> no model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # optionally resume from a checkpoint
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer...")
    model = DepthCompletionNet(args).cuda()
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("=> model and optimizer created.")
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(model)
    print("=> model transferred to multi-GPU.")

    # Data loading code
    print("=> creating data loaders ...")
    if not is_eval:
        train_dataset, train_loader = get_kitti_dataloader(
            mode='train',
            dataset_name=dataset_name,
            setname='train',
            args=args)
        # train_dataset = KittiDepth('train', args)
        # train_loader = torch.utils.data.DataLoader(
        #     train_dataset, batch_size=args.batch_size, shuffle=True,
        #     num_workers=args.workers, pin_memory=True, sampler=None)

    val_dataset, val_loader = get_kitti_dataloader(mode='eval',
                                                   dataset_name=dataset_name,
                                                   setname='test',
                                                   args=args)

    # change dataset here:
    # val_dataset = KittiDepth('val', args)
    # val_dataset = KittiDataset(base_dir="./data/kitti/", setname="selval")
    # val_dataset = vKittiDataset(base_dir="./data/vkitti/", setname="test")
    # val_dataset = OurDataset(base_dir="/home/bird/data2/dataset/our_lidar/20190315/f_c_1216_352", setname="f_c_1216_352")
    # val_dataset = OurDataset(base_dir="/home/bird/data2/dataset/our_lidar/20190318/f_c_1216_352", setname="f_c_1216_352_20190318")
    # val_dataset = NuScenesDataset(base_dir="/home/bird/data2/dataset/nuscenes/projected", setname="f_c_1216_352")
    # val_loader = torch.utils.data.DataLoader(val_dataset,
    #     batch_size=1, shuffle=False, num_workers=2, pin_memory=True)  # set batch size to be 1 for validation

    print("=> data loaders created.")

    # create backups and results folder
    logger = helper.logger(args)
    # if checkpoint is not None:
    #     logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        result, is_best = iterate("eval", args, val_loader, model, None,
                                  logger, checkpoint['epoch'], val_dataset)
        print(result)
        print(is_best)
        return

    # main loop
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger, epoch,
                train_dataset)  # train for one epoch
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch,
                                  val_dataset)  # evaluate on validation set
        helper.save_checkpoint({ # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer' : optimizer.state_dict(),
            'args' : args,
        }, is_best, epoch, logger.output_directory)
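get_kitti_dataloader is an external factory not shown in this example. A minimal sketch of the shape such a factory could take, assuming it simply pairs a dataset with a mode-dependent DataLoader (the name, defaults, and behavior here are guesses, not the actual API):

import torch
from torch.utils.data import DataLoader, TensorDataset

def get_dataloader_sketch(mode, dataset, batch_size=4, workers=2):
    # 'train' shuffles full batches; anything else mimics the batch_size=1
    # validation loaders used throughout these examples
    if mode == 'train':
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=workers, pin_memory=True)
    else:
        loader = DataLoader(dataset, batch_size=1, shuffle=False,
                            num_workers=2, pin_memory=True)
    return dataset, loader

_, loader = get_dataloader_sketch('eval', TensorDataset(torch.zeros(8, 3)))
print(len(loader))  # 8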
Example #8
def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            args = checkpoint['args']
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            is_eval = True
            print("Completed.")
        else:
            print("No model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    ################# model

    print("=> creating model and optimizer ... ", end='')
    parameters_to_train = []
    encoder = networks.ResnetEncoder(num_layers=18)
    encoder.to(device)
    parameters_to_train += list(encoder.parameters())
    decoder = networks.DepthDecoder(encoder.num_ch_enc)
    decoder.to(device)
    parameters_to_train += list(decoder.parameters())
    # encoder_named_params = [
    #     p for _, p in encoder.named_parameters() if p.requires_grad
    # ]
    optimizer = torch.optim.Adam(parameters_to_train,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    encoder = torch.nn.DataParallel(encoder)
    decoder = torch.nn.DataParallel(decoder)
    model = [encoder, decoder]
    print("completed.")
    # if checkpoint is not None:
    #     model.load_state_dict(checkpoint['model'])
    #     optimizer.load_state_dict(checkpoint['optimizer'])
    #     print("=> checkpoint state loaded.")

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))
    val_dataset = KittiDepth('val', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=12,  # changed from the batch size of 1 that validation normally uses
        shuffle=False,
        num_workers=2,
        pin_memory=True)
    print("\t==> val_loader size:{}".format(len(val_loader)))

    ##############################################################

    # create backups and results folder
    logger = helper.logger(args)
    # if checkpoint is not None:
    #     logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        print("=> starting model evaluation ...")
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  checkpoint['epoch'])
        return

    # main loop
    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set
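This example trains two separate networks with one optimizer by concatenating their parameter lists; the same pattern on toy modules:

import torch

encoder = torch.nn.Linear(4, 8)  # stand-in for networks.ResnetEncoder
decoder = torch.nn.Linear(8, 1)  # stand-in for networks.DepthDecoder
parameters_to_train = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(parameters_to_train, lr=1e-4, weight_decay=0.0)
print(sum(p.numel() for p in parameters_to_train))  # one optimizer.step() updates both nets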
Example #9
def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        args_new = args
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}' ... ".format(args.evaluate),
                  end='')
            checkpoint = torch.load(args.evaluate, map_location=device)
            #args = checkpoint['args']
            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            is_eval = True

            print("Completed.")
        else:
            is_eval = True
            print("No model found at '{}'".format(args.evaluate))
            #return

    elif args.resume:  # optionally resume from a checkpoint
        args_new = args
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}' ... ".format(args.resume),
                  end='')
            checkpoint = torch.load(args.resume, map_location=device)

            args.start_epoch = checkpoint['epoch'] + 1
            args.data_folder = args_new.data_folder
            args.val = args_new.val
            print("Completed. Resuming from epoch {}.".format(
                checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer ... ", end='')
    model = None
    penet_accelerated = False
    if (args.network_model == 'e'):
        model = ENet(args).to(device)
    elif (is_eval == False):
        if (args.dilation_rate == 1):
            model = PENet_C1_train(args).to(device)
        elif (args.dilation_rate == 2):
            model = PENet_C2_train(args).to(device)
        elif (args.dilation_rate == 4):
            model = PENet_C4(args).to(device)
            penet_accelerated = True
    else:
        if (args.dilation_rate == 1):
            model = PENet_C1(args).to(device)
            penet_accelerated = True
        elif (args.dilation_rate == 2):
            model = PENet_C2(args).to(device)
            penet_accelerated = True
        elif (args.dilation_rate == 4):
            model = PENet_C4(args).to(device)
            penet_accelerated = True

    if penet_accelerated:
        # assigning to a Module's .requires_grad attribute does not freeze its
        # parameters; requires_grad_() propagates the flag to every parameter
        model.encoder3.requires_grad_(False)
        model.encoder5.requires_grad_(False)
        model.encoder7.requires_grad_(False)

    model_named_params = None
    model_bone_params = None
    model_new_params = None
    optimizer = None

    if checkpoint is not None:
        #print(checkpoint.keys())
        if (args.freeze_backbone == True):
            model.backbone.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint['model'], strict=False)
        #optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
        del checkpoint
    print("=> logger created.")

    test_dataset = None
    test_loader = None
    if (args.test):
        test_dataset = KittiDepth('test_completion', args)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True)
        iterate("test_completion", args, test_loader, model, None, logger, 0)
        return

    val_dataset = KittiDepth('val', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True)  # set batch size to be 1 for validation
    print("\t==> val_loader size:{}".format(len(val_loader)))

    if is_eval == True:
        for p in model.parameters():
            p.requires_grad = False

        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  args.start_epoch - 1)
        return

    if (args.freeze_backbone == True):
        for p in model.backbone.parameters():
            p.requires_grad = False
        model_named_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        optimizer = torch.optim.Adam(model_named_params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     betas=(0.9, 0.99))
    elif (args.network_model == 'pe'):
        model_bone_params = [
            p for _, p in model.backbone.named_parameters() if p.requires_grad
        ]
        model_new_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        model_new_params = list(set(model_new_params) - set(model_bone_params))
        optimizer = torch.optim.Adam([{
            'params': model_bone_params,
            'lr': args.lr / 10
        }, {
            'params': model_new_params
        }],
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     betas=(0.9, 0.99))
    else:
        model_named_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        optimizer = torch.optim.Adam(model_named_params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     betas=(0.9, 0.99))
    print("completed.")

    model = torch.nn.DataParallel(model)

    # Data loading code
    print("=> creating data loaders ... ")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
        print("\t==> train_loader size:{}".format(len(train_loader)))

    print("=> starting main loop ...")
    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch

        # validation memory reset
        for p in model.parameters():
            p.requires_grad = False
        result, is_best = iterate("val", args, val_loader, model, None, logger,
                                  epoch)  # evaluate on validation set

        for p in model.parameters():
            p.requires_grad = True
        if (args.freeze_backbone == True):
            for p in model.module.backbone.parameters():
                p.requires_grad = False
        if penet_accelerated:
            model.module.encoder3.requires_grad_(False)
            model.module.encoder5.requires_grad_(False)
            model.module.encoder7.requires_grad_(False)

        helper.save_checkpoint({ # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer' : optimizer.state_dict(),
            'args' : args,
        }, is_best, epoch, logger.output_directory)
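The two-group optimizer above (backbone at lr/10, new layers at the full lr) is standard Adam parameter-group usage; a minimal self-contained sketch:

import torch

backbone = torch.nn.Linear(8, 8)  # stand-in for the pretrained backbone
head = torch.nn.Linear(8, 1)      # stand-in for the new layers
lr = 1e-3
optimizer = torch.optim.Adam(
    [{'params': backbone.parameters(), 'lr': lr / 10},  # lower lr for the pretrained part
     {'params': head.parameters()}],                    # falls back to the default lr below
    lr=lr, weight_decay=1e-6, betas=(0.9, 0.99))
print([group['lr'] for group in optimizer.param_groups])  # [0.0001, 0.001]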
Example #10
def main():
    global args
    checkpoint = None
    is_eval = False
    if args.evaluate:
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}'".format(args.evaluate))
            checkpoint = torch.load(args.evaluate)
            args = checkpoint['args']
            is_eval = True
            print("=> checkpoint loaded.")
        else:
            print("=> no model found at '{}'".format(args.evaluate))
            return
    elif args.resume:  # optionally resume from a checkpoint
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            return

    print("=> creating model and optimizer...")
    model = DepthCompletionNet(args).cuda()
    model_named_params = [
        p for _, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = torch.optim.Adam(model_named_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    print("=> model and optimizer created.")
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> checkpoint state loaded.")

    model = torch.nn.DataParallel(model)
    print("=> model transferred to multi-GPU.")

    # Data loading code
    print("=> creating data loaders ...")
    if not is_eval:
        train_dataset = KittiDepth('train', args)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=None)
    val_dataset = KittiDepth('val', args)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=False)  # set batch size to be 1 for validation
    print("=> data loaders created.")

    # create backups and results folder
    logger = helper.logger(args)
    if checkpoint is not None:
        logger.best_result = checkpoint['best_result']
    print("=> logger created.")

    if is_eval:
        result, result_intensity, is_best = iterate("val", args, val_loader,
                                                    model, None, logger,
                                                    checkpoint['epoch'])
        return

    # main loop

    for epoch in range(args.start_epoch, args.epochs):
        print("=> starting training epoch {} ..".format(epoch))
        iterate("train", args, train_loader, model, optimizer, logger,
                epoch)  # train for one epoch
        result, result_intensity, is_best = iterate(
            "val", args, val_loader, model, None, logger,
            epoch)  # evaluate on validation set
        helper.save_checkpoint({ # save checkpoint
            'epoch': epoch,
            'model': model.module.state_dict(),
            'best_result': logger.best_result,
            'optimizer' : optimizer.state_dict(),
            'args' : args,
        }, is_best, epoch, logger.output_directory)

        logger.writer.add_scalar('eval/rmse_depth', result.rmse, epoch)
        logger.writer.add_scalar('eval/rmse_intensity', result_intensity.rmse,
                                 epoch)
        logger.writer.add_scalar('eval/mae_depth', result.mae, epoch)
        logger.writer.add_scalar('eval/mae_intensity', result_intensity.mae,
                                 epoch)
        # logger.writer.add_scalar('eval/irmse_depth', result.irmse, epoch)
        # logger.writer.add_scalar('eval/irmse_intensity', result_intensity.irmse, epoch)
        logger.writer.add_scalar('eval/rmse_total',
                                 result.rmse + args.wi * result_intensity.rmse,
                                 epoch)
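The logger.writer calls above follow the usual TensorBoard add_scalar(tag, value, step) interface; a minimal self-contained equivalent with torch.utils.tensorboard (the log directory, wi weight, and metric values below are dummies):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/demo')  # hypothetical log directory
wi = 0.5                             # stand-in for args.wi above
for epoch, (rmse_d, rmse_i) in enumerate([(3.2, 1.1), (2.9, 1.0)]):  # dummy metrics
    writer.add_scalar('eval/rmse_depth', rmse_d, epoch)
    writer.add_scalar('eval/rmse_intensity', rmse_i, epoch)
    writer.add_scalar('eval/rmse_total', rmse_d + wi * rmse_i, epoch)
writer.close()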