Example #1
def main():
    global device
    print("=> will save everthing to {}".format(args.output_dir))
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    # Data loading code
    # normalize = pose_transforms.Normalize(mean=[0.5, 0.5, 0.5],
    #                                         std=[0.5, 0.5, 0.5])
    valid_transform = pose_transforms.Compose(
        [pose_transforms.ArrayToTensor()])

    print("=> fetching sequences in '{}'".format(args.dataset_dir))
    dataset_dir = Path(args.dataset_dir)
    print("=> preparing val set")
    val_set = pose_framework_KITTI(
        dataset_dir,
        "/home/gaof/workspace/Depth-VO-Feat/kitti_00.txt",  # hard-coded sequence list
        transform=valid_transform)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # create model
    model = FixOdometryNet(bit_width=8).to(device)
    if args.checkpoint is None:
        model.init_weights()
    else:
        weights = torch.load(args.checkpoint)
        model.load_state_dict(weights)

    cudnn.benchmark = True
    if args.cuda and args.gpu_id in range(4):
        # pin the run to a single GPU by id
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    elif args.cuda:
        # no specific GPU requested: span all visible GPUs
        model = torch.nn.DataParallel(model)
    inference(model, val_loader, args.output_dir)
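
The snippet above leans on module-level globals (parser/args, device, and the imports) that this listing omits. A minimal sketch of that setup, assuming flag names matching what main() reads; the defaults and help strings are assumptions, not the original project's:

import argparse

import torch
import torch.backends.cudnn as cudnn
from path import Path  # path.py's Path provides makedirs_p()

# Hypothetical parser mirroring the flags main() reads above.
parser = argparse.ArgumentParser(description="FixOdometryNet inference on KITTI")
parser.add_argument("--output-dir", default="output/", help="where results are written")
parser.add_argument("--dataset-dir", default="data/kitti/", help="KITTI odometry root")
parser.add_argument("--workers", type=int, default=4, help="DataLoader worker count")
parser.add_argument("--checkpoint", default=None, help="optional weights to load")
parser.add_argument("--cuda", action="store_true", help="run on GPU when available")
parser.add_argument("--gpu-id", type=int, default=-1, help="GPU to pin; -1 uses all")
args = parser.parse_args()

device = torch.device("cuda" if args.cuda and torch.cuda.is_available() else "cpu")
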
Example #2
def main():
    global device
    print("=> will save everthing to {}".format(args.output_dir))
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    # Data loading code
    train_transform = pose_transforms.Compose([
        pose_transforms.RandomHorizontalFlip(),
        pose_transforms.ArrayToTensor()
    ])
    valid_transform = pose_transforms.Compose(
        [pose_transforms.ArrayToTensor()])

    print("=> fetching sequences in '{}'".format(args.dataset_dir))
    dataset_dir = Path(args.dataset_dir)
    print("=> preparing train set")
    train_set = dataset()  # transform=train_transform
    print("=> preparing val set")
    val_set = pose_framework_KITTI(dataset_dir,
                                   args.test_sequences,
                                   transform=valid_transform,
                                   seed=args.seed,
                                   shuffle=False)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # create model
    # Per-layer fixed-point toggles: six conv layers and three fc layers.
    vo_input_fix, vo_output_fix = False, False
    vo_conv_weight_fix = [False] * 6  # or [True] * 6 to fix all conv weights
    vo_fc_weight_fix = [False] * 3
    vo_conv_output_fix = [False] * 6  # or [True] * 6 to fix all conv outputs
    vo_fc_output_fix = [False] * 3

    odometry_net = FixOdometryNet(bit_width=BITWIDTH,
                                  input_fix=vo_input_fix,
                                  output_fix=vo_output_fix,
                                  conv_weight_fix=vo_conv_weight_fix,
                                  fc_weight_fix=vo_fc_weight_fix,
                                  conv_output_fix=vo_conv_output_fix,
                                  fc_output_fix=vo_fc_output_fix).to(device)

    depth_net = DispNetS().to(device)

    # init weights of model
    if args.odometry is None:
        odometry_net.init_weights()
    else:
        weights = torch.load(args.odometry)
        odometry_net.load_state_dict(weights)
    if args.depth is None:
        depth_net.init_weights()
    else:
        weights = torch.load(args.depth)
        depth_net.load_state_dict(weights['state_dict'])

    cudnn.benchmark = True
    if args.cuda and args.gpu_id in range(2):
        # pin the run to a single GPU by id
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    elif args.cuda:
        # no specific GPU requested: span all visible GPUs
        odometry_net = torch.nn.DataParallel(odometry_net)
        depth_net = torch.nn.DataParallel(depth_net)

    optim_params = [{
        'params': odometry_net.parameters(),
        'lr': args.lr
    }, {
        'params': depth_net.parameters(),
        'lr': args.lr
    }]

    # model = model.to(device)
    # optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    optimizer = optim.Adam(optim_params,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=args.weight_decay)
    print("=> validating before training")
    #validate(odometry_net, depth_net, val_loader, 0, output_dir, True)
    print("=> training & validating")
    #validate(odometry_net, depth_net, val_loader, 0, output_dir)
    for epoch in range(1, args.epochs + 1):
        train(odometry_net, depth_net, train_loader, epoch, optimizer)
        validate(odometry_net, depth_net, val_loader, epoch, output_dir)
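
The six-element and three-element boolean lists above gate per-layer fixed-point quantization in FixOdometryNet (six conv layers, three fully connected layers), and the commented-out [True] * 6 lines hint at a fully quantized variant. A sketch of that configuration, reusing the constructor signature from the example; FixOdometryNet, BITWIDTH, and device all come from the surrounding project:

# Fully fixed-point configuration: flips every toggle the example leaves False.
odometry_net = FixOdometryNet(
    bit_width=BITWIDTH,
    input_fix=True,                 # quantize the network input
    output_fix=True,                # quantize the network output
    conv_weight_fix=[True] * 6,     # fix weights of all six conv layers
    fc_weight_fix=[True] * 3,       # fix weights of all three fc layers
    conv_output_fix=[True] * 6,     # fix activations after each conv layer
    fc_output_fix=[True] * 3,       # fix activations after each fc layer
).to(device)
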
Example #3
def main():
    global n_iter, device
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    if args.gpu_id in (0, 1):
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    else:
        args.gpu_id = -1  # no valid single GPU id: DataParallel is used below

    print("=> will save everthing to {}".format(args.output_dir))
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    # Data loading code
    normalize = pose_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                          std=[0.5, 0.5, 0.5])
    #train_transform = pose_transforms.Compose([
    #    custom_transforms.RandomHorizontalFlip(),
    #    custom_transforms.RandomScaleCrop(),
    #    custom_transforms.ArrayToTensor(),
    #    normalize
    #])

    valid_transform = pose_transforms.Compose(
        [pose_transforms.ArrayToTensor()])  # normalize omitted

    print("=> fetching sequences in '{}'".format(args.dataset_dir))
    dataset_dir = Path(args.dataset_dir)
    print("=> preparing train set")
    train_set = dataset()  # transform=train_transform
    print("=> preparing val set")
    val_set = pose_framework_KITTI(dataset_dir,
                                   args.test_sequences,
                                   transform=valid_transform,
                                   seed=args.seed,
                                   shuffle=False)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    # the network for single-view depth prediction
    # disp_net = models.DispNetS().to(device)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    disp_net = DispNetS().to(device)
    pose_exp_net = PoseExpNet().to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    if args.gpu_id < 0:
        # no GPU pinned above: span all visible GPUs
        disp_net = torch.nn.DataParallel(disp_net)
        pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # baseline validation before any training
    validate(args, pose_exp_net, disp_net, val_loader, 0, output_dir)
    for epoch in range(args.epochs):
        # train for one epoch
        train(args, train_loader, disp_net, pose_exp_net, optimizer, epoch)
        validate(args, pose_exp_net, disp_net, val_loader, epoch, output_dir)
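
Both pretrained loaders above index the checkpoint with weights['state_dict'], so any saver must write that layout. A minimal sketch of a compatible saver; the filename pattern and the epoch/optimizer entries are assumptions, not the original project's:

import torch

def save_checkpoint(model, optimizer, epoch, output_dir):
    # Write the {'state_dict': ...} layout that
    # load_state_dict(weights['state_dict']) above expects.
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    # output_dir is a path.py Path, so '/' joins path components.
    torch.save(state, output_dir / 'checkpoint_{:03d}.pth.tar'.format(epoch))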