def main():
    """Build the data loaders and networks, then run the train/validate loop.

    Configuration is read from the module-level ``args`` and the networks are
    placed on the module-level ``device``.  The module-level ``data_parallel``
    flag is set to ``True`` when the networks get wrapped in
    ``torch.nn.DataParallel``.
    """
    global device
    global data_parallel

    print("=> will save everthing to {}".format(args.output_dir))
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    # Data loading code
    # NOTE(review): train_transform is built but currently not handed to
    # dataset() below (the argument is commented out there).
    train_transform = pose_transforms.Compose([
        pose_transforms.RandomHorizontalFlip(),
        pose_transforms.ArrayToTensor(),
    ])
    valid_transform = pose_transforms.Compose(
        [pose_transforms.ArrayToTensor()])

    print("=> fetching sequences in '{}'".format(args.dataset_dir))
    dataset_dir = Path(args.dataset_dir)

    print("=> preparing train set")
    train_set = dataset()  # transform=train_transform)

    print("=> preparing val set")
    val_set = pose_framework_KITTI(dataset_dir,
                                   args.test_sequences,
                                   transform=valid_transform,
                                   seed=args.seed,
                                   shuffle=False)

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # Validation runs one sample at a time, in dataset order.
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # create model
    odometry_net = PoseExpNet().to(device)
    depth_net = DispNetS().to(device)
    feat_extractor = FeatExtractor().to(device)

    # init weights of model
    # ``None`` means "initialise from scratch"; a non-empty path means "load
    # that checkpoint".  Any other falsy value leaves the network untouched.
    # map_location remaps the stored tensors onto ``device`` so a checkpoint
    # saved on a device that is unavailable here can still be deserialised.
    if args.odometry is None:
        odometry_net.init_weights()
    elif args.odometry:
        weights = torch.load(args.odometry, map_location=device)
        # The odometry checkpoint is used directly as the state dict.
        odometry_net.load_state_dict(weights)

    if args.depth is None:
        depth_net.init_weights()
    elif args.depth:
        weights = torch.load(args.depth, map_location=device)
        # The depth checkpoint keeps its parameters under 'state_dict'.
        depth_net.load_state_dict(weights['state_dict'])

    feat_extractor.init_weights()

    cudnn.benchmark = True
    if args.cuda and args.gpu_id in range(2):
        # NOTE(review): the networks were already moved to ``device`` above;
        # if that initialised CUDA, this variable may be set too late to have
        # any effect -- confirm against how ``device`` is defined.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    elif args.cuda:
        data_parallel = True
        odometry_net = torch.nn.DataParallel(odometry_net)
        depth_net = torch.nn.DataParallel(depth_net)
        feat_extractor = torch.nn.DataParallel(feat_extractor)

    # One parameter group per network, all sharing the same learning rate.
    optim_params = [
        {'params': net.parameters(), 'lr': args.lr}
        for net in (odometry_net, depth_net, feat_extractor)
    ]
    optimizer = optim.Adam(optim_params,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=args.weight_decay)

    # NOTE(review): no validation pass is run before training at the moment;
    # only the message is printed.
    print("=> validating before training")

    print("=> training & validating")
    for epoch in range(1, args.epochs + 1):
        train(odometry_net, depth_net, feat_extractor, train_loader, epoch,
              optimizer)
        validate(odometry_net, depth_net, feat_extractor, val_loader, epoch,
                 output_dir)