Example #1
def prepare_environment():
    env = {}
    args = parser.parse_args()
    if args.dataset_format == 'KITTI':
        from datasets.shifted_sequence_folders import ShiftedSequenceFolder
    elif args.dataset_format == 'StillBox':
        from datasets.shifted_sequence_folders import StillBox as ShiftedSequenceFolder
    elif args.dataset_format == 'TUM':
        from datasets.shifted_sequence_folders import TUM as ShiftedSequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    args.test_batch_size = 4 * args.batch_size
    if args.evaluate:
        args.epochs = 0

    env['training_writer'] = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))
    env['output_writers'] = output_writers

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        # custom_transforms.RandomHorizontalFlip(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = ShiftedSequenceFolder(args.data,
                                      transform=train_transform,
                                      seed=args.seed,
                                      train=True,
                                      with_depth_gt=False,
                                      with_pose_gt=args.supervise_pose,
                                      sequence_length=args.sequence_length)
    val_set = ShiftedSequenceFolder(args.data,
                                    transform=valid_transform,
                                    seed=args.seed,
                                    train=False,
                                    sequence_length=args.sequence_length,
                                    with_depth_gt=args.with_gt,
                                    with_pose_gt=args.with_gt)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=4 * args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    env['train_set'] = train_set
    env['val_set'] = val_set
    env['train_loader'] = train_loader
    env['val_loader'] = val_loader

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    pose_net = models.PoseNet(seq_length=args.sequence_length,
                              batch_norm=args.bn in ['pose',
                                                     'both']).to(device)

    if args.pretrained_pose:
        print("=> using pre-trained weights for pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    depth_net = models.DepthNet(depth_activation="elu",
                                batch_norm=args.bn in ['depth',
                                                       'both']).to(device)

    if args.pretrained_depth:
        print("=> using pre-trained DepthNet model")
        data = torch.load(args.pretrained_depth)
        depth_net.load_state_dict(data['state_dict'])

    cudnn.benchmark = True
    depth_net = torch.nn.DataParallel(depth_net)
    pose_net = torch.nn.DataParallel(pose_net)

    env['depth_net'] = depth_net
    env['pose_net'] = pose_net

    print('=> setting adam solver')

    optim_params = [{
        'params': depth_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_net.parameters(),
        'lr': args.lr
    }]
    # parameters = chain(depth_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                args.lr_decay_frequency,
                                                gamma=0.5)
    env['optimizer'] = optimizer
    env['scheduler'] = scheduler

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()
    env['logger'] = logger

    env['args'] = args

    return env
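
A minimal, self-contained sketch of the per-parameter-group Adam + StepLR pattern used above (the tiny Linear modules and hyper-parameter values are placeholders, not the repository's DepthNet/PoseNet):

import torch

depth_net = torch.nn.Linear(8, 1)   # stand-in for models.DepthNet
pose_net = torch.nn.Linear(8, 6)    # stand-in for models.PoseNet

optim_params = [
    {'params': depth_net.parameters(), 'lr': 2e-4},
    {'params': pose_net.parameters(), 'lr': 2e-4},
]
optimizer = torch.optim.Adam(optim_params, betas=(0.9, 0.999), weight_decay=0)
# StepLR multiplies every group's learning rate by gamma once per step_size epochs,
# mirroring the lr_decay_frequency schedule above
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(30):
    loss = depth_net(torch.randn(4, 8)).mean() + pose_net(torch.randn(4, 8)).mean()
    loss.backward()
    optimizer.step()        # normally called once per batch
    optimizer.zero_grad()
    scheduler.step()        # called once per epoch
    print(epoch, [group['lr'] for group in optimizer.param_groups])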
Example #2
File: train.py  Project: ml-lab/DepthNet
def main():
    global args, best_error, viz
    args = util.set_params(parser)

    train_writer = SummaryWriter(args.save_path / 'train')
    val_writer = SummaryWriter(args.save_path / 'val')
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'val' / str(i)))
    torch.manual_seed(args.seed)

    # Data loading code
    mean = [0.5, 0.5, 0.5]
    std = [0.2, 0.2, 0.2]
    normalize = transforms.Normalize(mean=mean, std=std)
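    # images arrive as uint8 arrays: dividing by 255 first and then normalizing
    # with mean 0.5 / std 0.2 maps pixel values into roughly [-2.5, 2.5]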
    input_transform = transforms.Compose([
        co_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]), normalize
    ])
    target_transform = transforms.Compose(
        [co_transforms.Clip(0, 100),
         co_transforms.ArrayToTensor()])
    co_transform = co_transforms.Compose([
        co_transforms.RandomVerticalFlip(),
        co_transforms.RandomHorizontalFlip()
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set, val_set = datasets.still_box(args.data,
                                            transform=input_transform,
                                            target_transform=target_transform,
                                            co_transform=co_transform,
                                            split=args.split,
                                            seed=args.seed)
    print(
        '{} samples found: {} train samples and {} validation samples'.format(
            len(val_set) + len(train_set), len(train_set), len(val_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)
    # create model
    if args.pretrained:
        data = torch.load(args.pretrained)
        assert (not data['with_confidence'])
        print("=> using pre-trained model '{}'".format(data['arch']))
        model = models.DepthNet(batch_norm=data['bn'],
                                clamp=args.clamp,
                                depth_activation=args.activation_function)
        model.load_state_dict(data['state_dict'])
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.DepthNet(batch_norm=args.bn,
                                clamp=args.clamp,
                                depth_activation=args.activation_function)

    model = model.cuda()
    cudnn.benchmark = True

    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     betas=(args.momentum, args.beta),
                                     weight_decay=args.weight_decay)
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    dampening=args.momentum)

    with open(os.path.join(args.save_path, args.log_summary), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'train_depth_error', 'normalized_train_depth_error',
            'depth_error', 'normalized_depth_error'
        ])

    with open(os.path.join(args.save_path, args.log_full), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'train_depth_error'])

    term_logger = TermLogger(n_epochs=args.epochs,
                             train_size=min(len(train_loader),
                                            args.epoch_size),
                             test_size=len(val_loader))
    term_logger.epoch_bar.start()

    if args.evaluate:
        depth_error, normalized = validate(val_loader, model, 0, term_logger,
                                           output_writers)
        term_logger.test_writer.write(
            ' * Depth error : {:.3f}, normalized : {:.3f}'.format(
                depth_error, normalized))
        return

    for epoch in range(args.epochs):
        term_logger.epoch_bar.update(epoch)
        util.adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        term_logger.reset_train_bar()
        term_logger.train_bar.start()
        train_loss, train_error, train_normalized_error = train(
            train_loader, model, optimizer, args.epoch_size, term_logger,
            train_writer)
        term_logger.train_writer.write(
            ' * Avg Loss : {:.3f}, Avg Depth error : {:.3f}, normalized : {:.3f}'
            .format(train_loss, train_error, train_normalized_error))
        train_writer.add_scalar('metric_error', train_error, epoch)
        train_writer.add_scalar('metric_normalized_error',
                                train_normalized_error, epoch)

        # evaluate on validation set
        term_logger.reset_test_bar()
        term_logger.test_bar.start()
        depth_error, normalized = validate(val_loader, model, epoch,
                                           term_logger, output_writers)
        term_logger.test_writer.write(
            ' * Depth error : {:.3f}, normalized : {:.3f}'.format(
                depth_error, normalized))
        val_writer.add_scalar('metric_error', depth_error, epoch)
        val_writer.add_scalar('metric_normalized_error', normalized, epoch)

        if best_error < 0:
            best_error = depth_error

        # remember lowest error and save checkpoint
        is_best = depth_error < best_error
        best_error = min(depth_error, best_error)
        util.save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_EPE': best_error,
                'bn': args.bn,
                'with_confidence': False,
                'activation_function': args.activation_function,
                'clamp': args.clamp,
                'mean': mean,
                'std': std
            }, is_best)

        with open(os.path.join(args.save_path, args.log_summary),
                  'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, train_error, depth_error])
    term_logger.epoch_bar.finish()
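
util.save_checkpoint is not shown in this example; a plausible minimal stand-in that matches how it is called above (an assumption, not the project's actual implementation) would save the latest state every epoch and keep a copy of the best one:

import shutil
import torch

def save_checkpoint(save_path, state, is_best, filename='checkpoint.pth.tar'):
    # hypothetical helper: always write the most recent state,
    # and duplicate it as model_best when it beats the previous best error
    latest = '{}/{}'.format(save_path, filename)
    torch.save(state, latest)
    if is_best:
        shutil.copyfile(latest, '{}/model_best.pth.tar'.format(save_path))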
Example #3
def main():
    global args, best_error, viz
    args = util.set_params(parser)
    logging.info("[starting]" * 10)
    train_writer = SummaryWriter(args.save_path / 'train')
    val_writer = SummaryWriter(args.save_path / 'val')
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'val' / str(i)))
    torch.manual_seed(args.seed)

    # Data loading code
    mean = [0.5, 0.5, 0.5]
    std = [0.2, 0.2, 0.2]
    normalize = transforms.Normalize(mean=mean, std=std)
    input_transform = transforms.Compose([
        co_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]), normalize
    ])
    target_transform = transforms.Compose(
        [co_transforms.Clip(0, 100),
         co_transforms.ArrayToTensor()])
    co_transform = co_transforms.Compose([
        co_transforms.RandomVerticalFlip(),
        co_transforms.RandomHorizontalFlip()
    ])

    logging.info("=> fetching scenes in '{}'".format(args.data))
    train_set, val_set = datasets.still_box(args.data,
                                            transform=input_transform,
                                            target_transform=target_transform,
                                            co_transform=co_transform,
                                            split=args.split,
                                            seed=args.seed)
    logging.info(
        '{} samples found: {} train samples and {} validation samples'.format(
            len(val_set) + len(train_set), len(train_set), len(val_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)
    # create model
    if args.pretrained:
        data = torch.load(args.pretrained)
        assert (not data['with_confidence'])
        print("=> using pre-trained model '{}'".format(data['arch']))
        model = models.DepthNet(batch_norm=data['bn'],
                                clamp=args.clamp,
                                depth_activation=args.activation_function)
        model.load_state_dict(data['state_dict'])
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.DepthNet(batch_norm=args.bn,
                                clamp=args.clamp,
                                depth_activation=args.activation_function)
    model = model.to(device)
    logging.info("Model created")
    # if torch.cuda.device_count() > 1:
    # print("%"*100)
    # print("Let's use", torch.cuda.device_count(), "GPUs!")
    # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    # model = torch.nn.DataParallel(model, device_ids=device_ids)

    # if torch.cuda.is_available():
    # print("&"*100)
    # model.cuda()

    #model = torch.nn.DataParallel(model.cuda(1), device_ids=device_ids)
    cudnn.benchmark = True

    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     betas=(args.momentum, args.beta),
                                     weight_decay=args.weight_decay)
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    dampening=args.momentum)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[19, 30, 44, 53], gamma=0.3)
    logging.info("Optimizer created")

    with open(os.path.join(args.save_path, args.log_summary), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'train_depth_error', 'normalized_train_depth_error',
            'depth_error', 'normalized_depth_error'
        ])

    with open(os.path.join(args.save_path, args.log_full), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'train_depth_error'])

    term_logger = TermLogger(n_epochs=args.epochs,
                             train_size=min(len(train_loader),
                                            args.epoch_size),
                             test_size=len(val_loader))
    term_logger.epoch_bar.start()
    logging.info("Validate")
    if args.evaluate:
        depth_error, normalized = validate(val_loader, model, 0, term_logger,
                                           output_writers)
        term_logger.test_writer.write(
            ' * Depth error : {:.3f}, normalized : {:.3f}'.format(
                depth_error, normalized))
        return
    logging.info("epoch loop for %d time" % args.epochs)
    for epoch in range(args.epochs):
        logging.info("<epoch>=%d :start" % epoch)
        term_logger.epoch_bar.update(epoch)
        #scheduler.module.step()
        scheduler.step()

        # train for one epoch
        logging.info("train for one epoch: start       ")
        term_logger.reset_train_bar()
        term_logger.train_bar.start()
        logging.info("it might take more than 3min     ")
        train_loss, train_error, train_normalized_error = train(
            train_loader, model, optimizer, args.epoch_size, term_logger,
            train_writer)
        logging.info("train for one epoch: done         ")

        term_logger.train_writer.write(
            ' * Avg Loss : {:.3f}, Avg Depth error : {:.3f}, normalized : {:.3f}'
            .format(train_loss, train_error, train_normalized_error))
        train_writer.add_scalar('metric_error', train_error, epoch)
        train_writer.add_scalar('metric_normalized_error',
                                train_normalized_error, epoch)

        # evaluate on validation set
        logging.info("evaluate on validation set")
        term_logger.reset_test_bar()
        term_logger.test_bar.start()
        depth_error, normalized = validate(val_loader, model, epoch,
                                           term_logger, output_writers)
        term_logger.test_writer.write(
            ' * Depth error : {:.3f}, normalized : {:.3f}'.format(
                depth_error, normalized))
        val_writer.add_scalar('metric_error', depth_error, epoch)
        val_writer.add_scalar('metric_normalized_error', normalized, epoch)

        if best_error < 0:
            best_error = depth_error

        # remember lowest error and save checkpoint
        logging.info("remember lowest error and save checkpoint")
        is_best = depth_error < best_error
        best_error = min(depth_error, best_error)
        util.save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_error': best_error,
                'bn': args.bn,
                'with_confidence': False,
                'activation_function': args.activation_function,
                'clamp': args.clamp,
                'mean': mean,
                'std': std
            }, is_best)

        with open(os.path.join(args.save_path, args.log_summary),
                  'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, train_error, depth_error])
        logging.info("epoch=%d done" % epoch)
    term_logger.epoch_bar.finish()
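
The MultiStepLR schedule above multiplies the learning rate by gamma=0.3 each time a milestone epoch is reached; a short, self-contained illustration of that decay (the linear model and base rate are placeholders):

import torch

model = torch.nn.Linear(4, 1)                     # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[19, 30, 44, 53], gamma=0.3)

for epoch in range(60):
    model(torch.randn(2, 4)).sum().backward()     # dummy "epoch" of training
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()                              # rate drops by 0.3x at each milestone
    print(epoch, optimizer.param_groups[0]['lr'])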
Example #4
def main():
    global best_a3, n_iter, device
    args = parser.parse_args()
    torch.autograd.set_detect_anomaly(True)  # enable anomaly detection to locate where gradients break
    """====== step 1: load the data pipeline matching the dataset format ======"""
    if args.dataset_format == 'stacked':
        from DataFlow.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from DataFlow.sequence_folders import SequenceFolder
    """====== step 2 : 准备存储目录 ======"""
    save_path = save_path_formatter(args, parser)
    if sys.platform == 'win32':
        args.save_path = r'.\checkpoints' / save_path
    else:  # linux
        args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    tb_writer = SummaryWriter(args.save_path)  # tensorboardx writer
    """====== step 3 : 指定随机数种子以便于实验复现 ======"""
    torch.manual_seed(args.seed)
    """========= step 4 : 数据准备 =========="""
    # 数据扩增
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])
    # training set
    print("=> fetching data from '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.SEQ_LENGTH)
    # validation set
    val_set = ValidationSet(args.data, transform=valid_transform)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)
    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)
    """========= step 5 : 加载模型 =========="""
    print("=> creating models")

    depth_net = models.DepthNet().to(device)
    motion_net = models.MotionNet(intrinsic_pred=args.intri_pred).to(device)

    if args.pretrained_depth:
        print("=> using pre-trained weights for DepthNet")
        weights = torch.load(args.pretrained_depth)
        depth_net.load_state_dict(weights['state_dict'], strict=False)

    if args.pretrained_motion:
        print("=> using pre-trained weights for MotionNet")
        weights = torch.load(args.pretrained_motion)
        motion_net.load_state_dict(weights['state_dict'])

    cudnn.benchmark = True
    depth_net = torch.nn.DataParallel(depth_net)
    motion_net = torch.nn.DataParallel(motion_net)
    """========= step 6 : 设置求解器 =========="""
    print('=> setting adam solver')

    optim_params = [{
        'params': depth_net.parameters(),
        'lr': args.lr
    }, {
        'params': motion_net.parameters(),
        'lr': args.lr
    }]

    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)
    """====== step 7 : 初始化损失函数计算器======="""
    total_loss_calculator = LossFactory(
        SEQ_LENGTH=args.SEQ_LENGTH,
        rgb_weight=args.rgb_weight,
        depth_smoothing_weight=args.depth_smoothing_weight,
        ssim_weight=args.ssim_weight,
        motion_smoothing_weight=args.motion_smoothing_weight,
        rotation_consistency_weight=args.rotation_consistency_weight,
        translation_consistency_weight=args.translation_consistency_weight,
        depth_consistency_loss_weight=args.depth_consistency_loss_weight)
    """========= step 8 : 训练循环 =========="""
    if args.epoch_size == 0:
        args.epoch_size = len(
            train_loader)  # if epoch_size is not given, each epoch runs over the entire training set
    for epoch in range(args.epochs):
        tqdm.write("\n===========TRAIN EPOCH [{}/{}]===========".format(
            epoch + 1, args.epochs))
        """====== step 8.1 : 训练一个epoch ======"""
        train_loss = train(args, train_loader, depth_net, motion_net,
                           optimizer, args.epoch_size, total_loss_calculator,
                           tb_writer)
        tqdm.write('* Avg Loss : {:.3f}'.format(train_loss))
        """======= step 8.2 : 验证 ========"""
        # 验证时要输出 : 深度指标abs_diff, abs_rel, sq_rel, a1, a2, a3
        errors, error_names = validate_with_gt(args, val_loader, depth_net,
                                               motion_net, epoch, tb_writer)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        tqdm.write(error_string)
        # TODO: report trajectory metrics on the validation set

        # abs_rel, sq_rel, rms, log_rms, a1, a2, a3
        tb_writer.add_scalar("Relative Errors/abs_rel", errors[0], epoch)
        tb_writer.add_scalar("Relative Errors/sq_rel", errors[1], epoch)
        tb_writer.add_scalar("Root Mean Squared Error/rms", errors[2], epoch)
        tb_writer.add_scalar("Root Mean Squared Error/log_rms", errors[3],
                             epoch)
        tb_writer.add_scalar("Thresholding accuracy/a1", errors[4], epoch)
        tb_writer.add_scalar("Thresholding accuracy/a2", errors[5], epoch)
        tb_writer.add_scalar("Thresholding accuracy/a3", errors[6], epoch)
        """======= step 8.3 : 保存验证效果最佳的模型状态 =========="""
        decisive_a3 = errors[6]  # 选取a3为关键指标
        is_best = decisive_a3 > best_a3  # 如果当前的a3比之前记录的a3更大,那么模型的最佳状态就是现在的状态
        best_a3 = max(best_a3, decisive_a3)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': depth_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': motion_net.module.state_dict()
        }, is_best)
    pass  # end of main
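
torch.autograd.set_detect_anomaly(True) in step 1 makes autograd raise an error as soon as a backward function produces NaN, which is what "locating where gradients break" refers to; a tiny self-contained demonstration:

import torch

torch.autograd.set_detect_anomaly(True)

x = torch.zeros(1, requires_grad=True)
y = x / x                      # 0/0 produces NaN in the forward pass
try:
    y.backward()               # anomaly mode raises here instead of silently propagating NaN
except RuntimeError as err:
    print('anomaly detected:', err)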