예제 #1
0
def main():
    """Entry point: train MVDNet (multi-view depth) and/or DepthCons.

    Parses CLI arguments, builds train/validation datasets and loaders,
    instantiates the networks (optionally from pre-trained checkpoints),
    then runs the epoch loop: train -> validate -> log to CSV/tensorboard
    -> save a checkpoint for whichever network is being trained.
    """
    global n_iter  # iteration counter shared with train() -- presumably advanced there; confirm
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    # args.save_path behaves like a path object ('/' join, makedirs_p) -- path.py style
    args.save_path = 'checkpoints' / (args.exp + '_' + save_path)
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    # one writer per validation sample index that gets image/output logging
    output_writers = []
    for i in range(3):
        output_writers.append(SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               ttype=args.ttype,
                               dataset=args.dataset)
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             ttype=args.ttype2,
                             dataset=args.dataset)

    # Drop the trailing partial batch so every training batch is full-size.
    train_set.samples = train_set.samples[:len(train_set) -
                                          len(train_set) % args.batch_size]

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # Validation runs one sample at a time (batch_size=1), in order.
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # epoch_size == 0 means "use the whole training set each epoch"
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    mvdnet = MVDNet(args.nlabel, args.mindepth, no_pool=args.no_pool).cuda()
    mvdnet.init_weights()
    if args.pretrained_mvdn:
        print("=> using pre-trained weights for MVDNet")
        weights = torch.load(args.pretrained_mvdn)
        mvdnet.load_state_dict(weights['state_dict'])

    depth_cons = DepthCons().cuda()
    depth_cons.init_weights()

    if args.pretrained_cons:
        print("=> using pre-trained weights for ConsNet")
        weights = torch.load(args.pretrained_cons)
        depth_cons.load_state_dict(weights['state_dict'])

    cons_loss_ = ConsLoss().cuda()
    print('=> setting adam solver')

    # Optimize only the network being trained; when training the consistency
    # net, MVDNet is frozen in eval mode.
    if args.train_cons:
        optimizer = torch.optim.Adam(depth_cons.parameters(),
                                     args.lr,
                                     betas=(args.momentum, args.beta),
                                     weight_decay=args.weight_decay)
        mvdnet.eval()
    else:
        optimizer = torch.optim.Adam(mvdnet.parameters(),
                                     args.lr,
                                     betas=(args.momentum, args.beta),
                                     weight_decay=args.weight_decay)

    cudnn.benchmark = True
    mvdnet = torch.nn.DataParallel(mvdnet)
    depth_cons = torch.nn.DataParallel(depth_cons)

    print(' ==> setting log files')
    # log_summary: one row per epoch; log_full: one row per training loss entry
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'validation_abs_rel', 'validation_abs_diff',
            'validation_sq_rel', 'validation_a1', 'validation_a2',
            'validation_a3', 'mean_angle_error'
        ])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    print(' ==> main Loop')
    for epoch in range(args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        # train for one epoch (skipped entirely in evaluate-only mode)
        if args.evaluate:
            train_loss = 0
        else:
            train_loss = train(args, train_loader, mvdnet, depth_cons,
                               cons_loss_, optimizer, args.epoch_size,
                               training_writer, epoch)
        # skip_v: skip validation while training, report zeroed errors instead
        if not args.evaluate and (args.skip_v):
            error_names = [
                'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'angle'
            ]
            errors = [0] * 7
        else:
            errors, error_names = validate_with_gt(args, val_loader, mvdnet,
                                                   depth_cons, epoch,
                                                   output_writers)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        # FIX: error_string was computed but never displayed (the sibling
        # scripts write it to their logger); print it so validation results
        # are visible on the console.
        print(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to chose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        decisive_error = errors[0]
        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                train_loss, decisive_error, errors[1], errors[2], errors[3],
                errors[4], errors[5], errors[6]
            ])
        if args.evaluate:
            break
        # Checkpoint only the network currently being trained.
        if args.train_cons:
            save_checkpoint(args.save_path, {
                'epoch': epoch + 1,
                'state_dict': depth_cons.module.state_dict()
            },
                            epoch,
                            file_prefixes=['cons'])
        else:
            save_checkpoint(args.save_path, {
                'epoch': epoch + 1,
                'state_dict': mvdnet.module.state_dict()
            },
                            epoch,
                            file_prefixes=['mvdnet'])
def main():
    """Entry point: joint training of DispNetS and PoseExpNet on shifted
    sequences, with periodic displacement-shift adjustment of the train set.

    Parses CLI arguments, builds datasets/loaders, creates the networks,
    then runs the epoch loop: train -> (every 5 epochs) adjust shifts ->
    validate -> log -> checkpoint best model.
    """
    global args, best_error, n_iter, device
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    # args.save_path behaves like a path object ('/' join, makedirs_p)
    args.save_path = 'checkpoints_shifted' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    # optional per-sample writers for validation image output
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = ShiftedSequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,
        target_displacement=args.target_displacement)

    # if no Groundtruth is avalaible, Validation set is the same type as training set to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # Second loader over the SAME train_set, unshuffled, used only by
    # adjust_shifts(); single-process loading so the dataset can be mutated.
    adjust_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )  # workers is set to 0 to avoid multiple instances to be modified at the same time
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # epoch_size == 0 means "use the whole training set each epoch"
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # expose args to train() as a function attribute
    train.args = args
    # create model
    print("=> creating model")

    disp_net = models.DispNetS().cuda()
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

    # single optimizer over both networks' parameters
    parameters = chain(disp_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # log_summary: one row per epoch; log_full: per-iteration loss breakdown
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # Every 5 epochs, re-estimate per-scene frame shifts so the apparent
        # displacement matches args.target_displacement.
        if (epoch + 1) % 5 == 0:
            train_set.adjust = True
            logger.reset_train_bar(len(adjust_loader))
            average_shifts = adjust_shifts(args, train_set, adjust_loader,
                                           pose_exp_net, epoch, logger,
                                           training_writer)
            shifts_string = ' '.join(
                ['{:.3f}'.format(s) for s in average_shifts])
            logger.train_writer.write(
                ' * adjusted shifts, average shifts are now : {}'.format(
                    shifts_string))
            for i, shift in enumerate(average_shifts):
                training_writer.add_scalar('shifts{}'.format(i), shift, epoch)
            train_set.adjust = False

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to chose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        decisive_error = errors[0]
        # best_error < 0 is the "uninitialized" sentinel (module-level global)
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
예제 #3
0
def prepare_environment():
    """Build and return the full training environment as a dict.

    Parses CLI arguments, selects the dataset class for the requested
    format, constructs datasets/loaders, networks (DepthNet + PoseNet),
    optimizer, LR scheduler, CSV logs and terminal logger.

    Returns:
        dict with keys: 'training_writer', 'output_writers', 'train_set',
        'val_set', 'train_loader', 'val_loader', 'depth_net', 'pose_net',
        'optimizer', 'scheduler', 'logger', 'args'.
    """
    env = {}
    args = parser.parse_args()
    # Pick the dataset implementation matching the requested on-disk format.
    # NOTE(review): an unknown dataset_format leaves ShiftedSequenceFolder
    # unbound and would raise NameError below -- presumably argparse choices
    # restrict it; confirm.
    if args.dataset_format == 'KITTI':
        from datasets.shifted_sequence_folders import ShiftedSequenceFolder
    elif args.dataset_format == 'StillBox':
        from datasets.shifted_sequence_folders import StillBox as ShiftedSequenceFolder
    elif args.dataset_format == 'TUM':
        from datasets.shifted_sequence_folders import TUM as ShiftedSequenceFolder
    save_path = save_path_formatter(args, parser)
    # args.save_path behaves like a path object ('/' join, makedirs_p)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    args.test_batch_size = 4 * args.batch_size
    # evaluate-only mode: zero epochs of training
    if args.evaluate:
        args.epochs = 0

    env['training_writer'] = SummaryWriter(args.save_path)
    # optional per-sample writers for validation image output
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))
    env['output_writers'] = output_writers

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        # custom_transforms.RandomHorizontalFlip(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = ShiftedSequenceFolder(args.data,
                                      transform=train_transform,
                                      seed=args.seed,
                                      train=True,
                                      with_depth_gt=False,
                                      with_pose_gt=args.supervise_pose,
                                      sequence_length=args.sequence_length)
    val_set = ShiftedSequenceFolder(args.data,
                                    transform=valid_transform,
                                    seed=args.seed,
                                    train=False,
                                    sequence_length=args.sequence_length,
                                    with_depth_gt=args.with_gt,
                                    with_pose_gt=args.with_gt)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # validation uses a larger batch (no backward pass, so more fits in memory)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=4 * args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    env['train_set'] = train_set
    env['val_set'] = val_set
    env['train_loader'] = train_loader
    env['val_loader'] = val_loader

    # epoch_size == 0 means "use the whole training set each epoch"
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    # args.bn selects which networks use batch norm: 'pose', 'depth' or 'both'
    pose_net = models.PoseNet(seq_length=args.sequence_length,
                              batch_norm=args.bn in ['pose',
                                                     'both']).to(device)

    if args.pretrained_pose:
        print("=> using pre-trained weights for pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    depth_net = models.DepthNet(depth_activation="elu",
                                batch_norm=args.bn in ['depth',
                                                       'both']).to(device)

    if args.pretrained_depth:
        print("=> using pre-trained DepthNet model")
        data = torch.load(args.pretrained_depth)
        depth_net.load_state_dict(data['state_dict'])

    cudnn.benchmark = True
    depth_net = torch.nn.DataParallel(depth_net)
    pose_net = torch.nn.DataParallel(pose_net)

    env['depth_net'] = depth_net
    env['pose_net'] = pose_net

    print('=> setting adam solver')

    # one parameter group per network (same lr; kept separate so they can be
    # tuned independently)
    optim_params = [{
        'params': depth_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_net.parameters(),
        'lr': args.lr
    }]
    # parameters = chain(depth_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)
    # halve the learning rate every lr_decay_frequency epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                args.lr_decay_frequency,
                                                gamma=0.5)
    env['optimizer'] = optimizer
    env['scheduler'] = scheduler

    # log_summary: one row per epoch; log_full: per-iteration loss breakdown
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()
    env['logger'] = logger

    env['args'] = args

    return env
예제 #4
0
def main():
    """Entry point: train DispNetS + PoseExpNet with a (frozen, pre-trained)
    segmentation network in the loop.

    Parses CLI arguments, builds datasets/loaders, creates the three
    networks, optionally runs a pre-training evaluation pass, then runs the
    epoch loop: train -> validate -> log -> checkpoint best model.
    """
    global best_error, n_iter, device
    args = parser.parse_args()
    # Pick the dataset implementation matching the on-disk sequence format.
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = save_path_formatter(args, parser)
    # args.save_path behaves like a path object ('/' join, makedirs_p)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    # evaluate-only mode: zero epochs of training
    if args.evaluate:
        args.epochs = 0

    tb_writer = SummaryWriter(args.save_path)
    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # if no Groundtruth is avalaible, Validation set is the same type as training set to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    # drop_last so every training batch is full-size
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # epoch_size == 0 means "use the whole training set each epoch"
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    disp_net = models.DispNetS().to(device)
    seg_net = DeepLab(num_classes=args.nclass,
                      backbone=args.backbone,
                      output_stride=args.out_stride,
                      sync_bn=args.sync_bn,
                      freeze_bn=args.freeze_bn).to(device)
    if args.pretrained_seg:
        print("=> using pre-trained weights for seg net")
        weights = torch.load(args.pretrained_seg)
        seg_net.load_state_dict(weights, strict=False)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)
    seg_net = torch.nn.DataParallel(seg_net)

    print('=> setting adam solver')

    # seg_net is intentionally excluded from the optimizer (used frozen)
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # log_summary: one row per epoch; log_full: per-iteration loss breakdown
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    # Pre-training (or evaluate-only) validation pass at epoch 0.
    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.with_gt:
            # FIX: this call previously omitted seg_net, so `0` bound to the
            # seg_net parameter and the remaining arguments shifted by one;
            # it now matches the in-loop call signature below.
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   seg_net, 0, logger,
                                                   tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      0, logger, tb_writer)
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, 0)
        error_string = ', '.join(
            '{} : {:.3f}'.format(name, error)
            for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net, seg_net,
                           optimizer, args.epoch_size, logger, tb_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   seg_net, epoch, logger,
                                                   tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger, tb_writer)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, epoch)

        # Up to you to chose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        decisive_error = errors[1]
        # best_error < 0 is the "uninitialized" sentinel (module-level global)
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
예제 #5
0
def main():
    """Entry point: parse CLI args, build stereo sequence train/valid sets,
    create DispResNet + PoseExpNet, and run the train/validate/checkpoint loop.

    Uses module globals: ``best_error`` (best validation error so far, negative
    until first validation), ``n_iter`` (global training step), ``device``.
    """
    global best_error, n_iter, device
    args = parser.parse_args()
    # Dataset layout on disk selects which SequenceFolder implementation to use.
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder, StereoSequenceFolder
    # NOTE(review): StereoSequenceFolder is only imported in the 'sequential'
    # branch but is used unconditionally below, so 'stacked' would raise
    # NameError — confirm which formats this script actually supports.
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints'/save_path  # path.py-style '/' join
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0  # evaluate-only run: skip the training loop entirely

    training_writer = SummaryWriter(args.save_path)
    output_writers = []  # one tensorboard writer per logged validation sample
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path/'valid'/str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = StereoSequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length
    )

    # if no Groundtruth is avalaible, Validation set is the same type as training set to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(
            args.data,
            transform=valid_transform
        )
    else:
        val_set = StereoSequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # When epoch_size is 0 (unset), each epoch trains on every sample in
    # train_set; otherwise each epoch only covers part of train_set.
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model (network initialization)
    print("=> creating model")

    # disp_net = models.DispNetS().to(device)
    disp_net = models.DispResNet(3).to(device)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    # With a mask loss, PoseExpNet must output both the explainability mask and
    # the pose estimate, since the two outputs share the encoder network.
    # pose_exp_net = PoseExpNet(nb_ref_imgs=args.sequence_length - 1, output_exp=args.mask_loss_weight > 0).to(device)
    pose_exp_net = models.PoseExpNet(nb_ref_imgs=args.sequence_length - 1, output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    # Parallelize both networks across available GPUs.
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    # Optimizer: Adam over both networks jointly.
    print('=> setting adam solver')
    optim_params = [
        {'params': disp_net.parameters(), 'lr': args.lr},
        {'params': pose_exp_net.parameters(), 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path/args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path/args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    # Evaluate a pretrained model once before training starts (epoch 0 logs).
    if args.pretrained_disp or args.evaluate:
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, 0, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, 0, output_writers)
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names[2:9], errors[2:9]))

    # Main training loop.
    for epoch in range(args.epochs):

        # train for one epoch
        print('\n')
        train_loss = train(args, train_loader, disp_net, pose_exp_net, optimizer, args.epoch_size, training_writer, epoch)

        # evaluate on validation set
        print('\n')
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, epoch, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to chose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        # Validation reports several losses (overall final loss, warping loss,
        # mask regularization loss); errors[0] is used as the "best model"
        # criterion here.
        decisive_error = errors[0]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint of the best model so far
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': disp_net.module.state_dict()
            }, {
                'epoch': epoch + 1,
                'state_dict': pose_exp_net.module.state_dict()
            },
            is_best)

        with open(args.save_path/args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
예제 #6
0
파일: train.py 프로젝트: zuru/DeepSFM
def main():
    """Entry point: parse CLI args, build train/valid SequenceFolder datasets
    (optionally with geometry cost inputs), create a PSNet depth network, and
    run the per-epoch train/checkpoint loop.

    Uses module global ``n_iter`` (global training step counter).
    """
    global n_iter
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path  # path.py-style '/' join
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []  # one tensorboard writer per logged validation sample
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    # pose_source: '<init>_poses.txt' when a pose initialization is given,
    # otherwise the default 'poses.txt'.
    train_set = SequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        ttype=args.ttype,
        add_geo=args.geo,
        depth_source=args.depth_init,
        pose_source='%s_poses.txt' %
        args.pose_init if args.pose_init else 'poses.txt',
        scale=False)
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             ttype=args.ttype2,
                             add_geo=args.geo,
                             depth_source=args.depth_init,
                             pose_source='%s_poses.txt' %
                             args.pose_init if args.pose_init else 'poses.txt',
                             scale=False)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    # NOTE(review): val_set is built and reported but never wrapped in a
    # DataLoader, and no validation is run in the loop below — confirm whether
    # validation was intentionally dropped from this script.
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # epoch_size == 0 means "one full pass over train_set per epoch".
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    depth_net = PSNet(args.nlabel,
                      args.mindepth,
                      add_geo_cost=args.geo,
                      depth_augment=False).cuda()

    if args.pretrained_dps:
        # for param in depth_net.feature_extraction.parameters():
        #     param.requires_grad = False

        print("=> using pre-trained weights for DPSNet")
        # Merge checkpoint weights into the current state dict so keys missing
        # from the checkpoint keep their freshly initialized values.
        model_dict = depth_net.state_dict()
        weights = torch.load(args.pretrained_dps)['state_dict']
        pretrained_dict = {k: v for k, v in weights.items() if k in model_dict}

        model_dict.update(pretrained_dict)

        depth_net.load_state_dict(model_dict)

    else:
        depth_net.init_weights()

    cudnn.benchmark = True
    depth_net = torch.nn.DataParallel(depth_net)

    print('=> setting adam solver')

    # Only optimize parameters that still require gradients.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        depth_net.parameters()),
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    for epoch in range(args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        # train for one epoch
        train_loss = train(args, train_loader, depth_net, optimizer,
                           args.epoch_size, training_writer)

        # A checkpoint is written every epoch (no best-model selection here).
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': depth_net.module.state_dict()
        }, epoch)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss])
예제 #7
0
def main():
    """Entry point: parse CLI args, build train/valid sequence datasets, create
    the encoder / GP-layer / decoder model, optionally restore pretrained
    weights, then run the train/validate/checkpoint loop.

    Uses module globals: ``args``, ``best_error`` (negative until first
    validation), ``n_iter`` (global step), ``device``.
    """
    global args, best_error, n_iter, device
    args = parser.parse_args()
    from dataset_loader import SequenceFolder

    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path  # path.py-style '/' join
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []  # one tensorboard writer per logged validation sample
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    train_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
    ])

    valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    val_set = SequenceFolder(
        args.data,
        transform=valid_transform,
        seed=args.seed,
        train=False,
        sequence_length=args.sequence_length,
    )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # epoch_size == 0 means "one full pass over train_set per epoch".
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    encoder = enCoder()
    decoder = deCoder()
    gplayer = GPlayer()

    if args.pretrained_dict:
        print("=> using pre-trained weights")
        weights = torch.load(args.pretrained_dict)
        pretrained_dict = weights['state_dict']

        # Merge matching checkpoint weights into the current state dicts so
        # keys missing from the checkpoint keep their initialized values, then
        # load the MERGED dict. Loading the partial dict directly would fail
        # load_state_dict's default strict=True key check whenever the
        # checkpoint does not cover every parameter.
        encoder_dict = encoder.state_dict()
        pretrained_dict_encoder = {
            k: v
            for k, v in pretrained_dict.items() if k in encoder_dict
        }
        encoder_dict.update(pretrained_dict_encoder)
        encoder.load_state_dict(encoder_dict)

        decoder_dict = decoder.state_dict()
        pretrained_dict_decoder = {
            k: v
            for k, v in pretrained_dict.items() if k in decoder_dict
        }
        decoder_dict.update(pretrained_dict_decoder)
        decoder.load_state_dict(decoder_dict)

    encoder = encoder.to(device)
    decoder = decoder.to(device)
    # NOTE(review): gplayer is neither moved to `device` nor wrapped in
    # DataParallel even though its parameters are optimized — confirm it is
    # intended to stay on CPU.

    cudnn.benchmark = True
    encoder = torch.nn.DataParallel(encoder)
    decoder = torch.nn.DataParallel(decoder)

    # One optimizer over all three components.
    parameters = chain(encoder.parameters(), gplayer.parameters(),
                       decoder.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))

    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()

        train_loss = train(train_loader, encoder, gplayer, decoder, optimizer,
                           args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()

        errors, error_names = validate(val_loader, encoder, gplayer, decoder,
                                       epoch, logger, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Last reported validation error drives best-model selection.
        decisive_error = errors[-1]
        if best_error < 0:
            best_error = decisive_error

        # save best checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': encoder.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': gplayer.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': decoder.state_dict()
        }, is_best)

    logger.epoch_bar.finish()
예제 #8
0
def main():
    """Entry point: parse CLI args, build sequence train/valid datasets, create
    DispNetS + PoseExpNet, optionally evaluate pretrained weights, then run the
    train/validate/checkpoint loop.

    Uses module globals: ``best_error`` (negative until first validation),
    ``n_iter`` (global training step), ``device``.
    """
    global best_error, n_iter, device
    args = parser.parse_args()
    # Dataset layout on disk selects which SequenceFolder implementation to use.
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()  # create if missing, no-op otherwise (path.py helper)
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0  # evaluate-only run: skip the training loop entirely
    # tensorboard SummaryWriter
    training_writer = SummaryWriter(args.save_path)  # for tensorboard

    output_writers = []  # one writer per logged validation sample
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))
    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])
    # Validation uses no augmentation, only tensor conversion + normalization.
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(
        args.data,  # preprocessed training data root
        transform=train_transform,  # augmentation/conversion pipeline
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length)
    # if no Groundtruth is avalaible, Validation set is
    # the same type as training set to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(
        len(train_set), len(train_set.scenes)))  # train data are sequences
    print('{} samples found in {} valid scenes'.format(
        len(val_set), len(val_set.scenes)))  # valid data are sequences too
    train_loader = torch.utils.data.DataLoader(  # batches: [tensor(B,3,H,W), list(B), (B,H,W), (b,h,w)]
        dataset=train_set,  # SequenceFolder
        batch_size=args.batch_size,
        shuffle=True,  # shuffle training samples
        num_workers=args.workers,  # multi-process data loading
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        dataset=val_set,
        batch_size=args.batch_size,
        shuffle=False,  # keep validation order fixed
        num_workers=args.workers,
        pin_memory=True)

    # epoch_size == 0 means "one full pass over train_set per epoch".
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    # disparity network
    disp_net = models.DispNetS().to(device)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    # pose network (shares its encoder with the optional explainability mask)
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).to(device)

    # initialize PoseExpNet
    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    # initialize DispNet
    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')
    # Both networks are trained jointly by a single Adam optimizer.
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)
    # Training results are appended to CSV logs.
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    # Evaluate a pretrained model once before training starts (epoch 0 logs).
    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   0, logger, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      0, logger,
                                                      output_writers)

        # Record each per-epoch validation metric
        # (e.g. ['Total loss', 'Photo loss', 'Exp loss']) at step 0.
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join(
            '{} : {:.3f}'.format(name, error)
            for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    # main cycle
    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        logger.reset_train_bar()
        # 1. train for one epoch
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, training_writer)

        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        logger.reset_valid_bar()

        # 2. validate on validation set
        if args.with_gt:  # errors e.g.: ['Total loss', 'Photo loss', 'Exp loss']
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger,
                                                      output_writers)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error,
                                       epoch)  # per-epoch validation metrics

        # Up to you to chose the most relevant error to measure
        # your model's performance, careful some measures are to maximize (such as a1,a2,a3)

        # 3. remember lowest error and save checkpoint
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)

        # save models
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        # Log this epoch's summary row: train loss and the decisive
        # validation error.
        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
예제 #9
0
def main():
    """Entry point: parse CLI args, build train/valid SequenceFolder datasets,
    create a PSNet (DPSNet) depth network, and run the per-epoch
    train/validate/checkpoint loop.

    Uses module global ``n_iter`` (global training step counter).
    """
    global n_iter
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path  # path.py-style '/' join
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []  # one tensorboard writer per logged validation sample
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               ttype=args.ttype)
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             ttype=args.ttype2)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # epoch_size == 0 means "one full pass over train_set per epoch".
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    dpsnet = PSNet(args.nlabel, args.mindepth).cuda()

    if args.pretrained_dps:
        print("=> using pre-trained weights for DPSNet")
        weights = torch.load(args.pretrained_dps)
        dpsnet.load_state_dict(weights['state_dict'])
    else:
        dpsnet.init_weights()

    cudnn.benchmark = True
    dpsnet = torch.nn.DataParallel(dpsnet)

    print('=> setting adam solver')

    # NOTE(review): chain() over a single iterable is a no-op wrapper —
    # dpsnet.parameters() alone would do.
    parameters = chain(dpsnet.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    for epoch in range(args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        # train for one epoch
        train_loss = train(args, train_loader, dpsnet, optimizer,
                           args.epoch_size, training_writer)
        errors, error_names = validate_with_gt(args, val_loader, dpsnet, epoch,
                                               output_writers)

        # NOTE(review): error_string is built but never printed or logged —
        # confirm whether a logger/print call was dropped here.
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to chose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        # A checkpoint is saved every epoch regardless of decisive_error.
        decisive_error = errors[0]
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': dpsnet.module.state_dict()
        }, epoch)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
예제 #10
0
def main():
    """Train a DISVO model on KITTI sequences.

    Parses command-line arguments, builds the train/validation loaders,
    constructs the model (optionally restoring pretrained weights), then
    runs the epoch loop: train -> validate -> tensorboard/CSV logging ->
    checkpoint (tracking the best validation error).
    """
    global best_error, n_iter, device
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    # save_path is a path.py Path, so 'checkpoints'/save_path joins paths
    args.save_path = 'checkpoints'/save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        # evaluation-only run: skip the training loop below
        args.epochs = 0

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path/'valid'/str(i)))

    # Data loading code
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor()
    ])

    # NOTE(review): 'normalize' is not defined in this function — it must be a
    # module-level transform; confirm it exists at import time.
    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    # NOTE(review): 'root_dir' and 'sequences' are also not defined locally —
    # presumably module-level configuration; verify. Both splits are built
    # from the same sequences here, only the transforms/shuffling differ.
    train_set = KITTIDataset(
        root_dir,
        sequences,
        max_distance=args.max_distance,
        transform=None
    )

    val_set = KITTIDataset(
        root_dir,
        sequences,
        max_distance=args.max_distance,
        transform=None
    )
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set)))

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.epoch_size == 0:
        # 0 means "use the full dataset" for one epoch
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    config_file = "./configs/e2e_mask_rcnn_R_50_FPN_1x.yaml"
    cfg.merge_from_file(config_file)
    cfg.freeze()
    pretrained_model_path = "./e2e_mask_rcnn_R_50_FPN_1x.pth"
    disvo = DISVO(cfg, pretrained_model_path).cuda()

    if args.pretrained_disvo:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disvo)
        disvo.load_state_dict(weights['state_dict'])
    else:
        disvo.init_weights()

    cudnn.benchmark = True
    # BUGFIX: wrap in DataParallel — the checkpointing code below accesses
    # disvo.module.state_dict(), which only exists on a wrapped model.
    disvo = torch.nn.DataParallel(disvo)

    print('=> setting adam solver')

    optim_params = [
        {'params': disvo.parameters(), 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # (re)create the CSV logs with their header rows
    with open(args.save_path/args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path/args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader))
    logger.epoch_bar.start()

    # initial validation pass (logged as epoch 0) when starting from
    # pretrained weights or running evaluate-only
    if args.pretrained_disvo or args.evaluate:
        logger.reset_valid_bar()
        errors, error_names = validate(args, val_loader, disvo, 0, logger, output_writers)
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disvo, optimizer, args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        # BUGFIX: pass the current epoch (the original hard-coded 0 here,
        # so every epoch's validation outputs were logged under epoch 0)
        logger.reset_valid_bar()
        errors, error_names = validate(args, val_loader, disvo, epoch, logger, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to chose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        decisive_error = errors[1]
        if best_error < 0:
            # first measurement: initialise the running best
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': disvo.module.state_dict()
            },
            is_best)

        with open(args.save_path/args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
예제 #11
0
def main():
    """Train a selective-fusion visual-inertial pose-regression pipeline.

    Builds KITTI train/val/test loaders, constructs a FlowNet feature
    extractor plus recurrent visual/IMU fusion modules for the selected
    fusion_mode, restores pretrained FlowNet weights, then runs the epoch
    loop: train -> validate -> test -> checkpoint + CSV logging.
    """

    global best_error, n_iter

    args = parser.parse_args()

    # run configuration (hard-coded for this experiment)
    n_save_model = 1      # checkpoint every n_save_model epochs
    degradation_mode = 0  # data degradation applied by KITTI_Loader
    fusion_mode = 3
    # 0: vision only 1: direct 2: soft 3: hard

    # set saving path
    # NOTE(review): this mixes str and Path arithmetic — 'checkpoints' / save_path
    # and args.save_path + '/imgs/' both require save_path to be a path.py Path;
    # confirm save_path_formatter returns one.
    abs_path = ''
    save_path = save_path_formatter(args, parser)
    args.save_path = abs_path + 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    if not os.path.exists(args.save_path + '/imgs/'):
        os.makedirs(args.save_path + '/imgs/')
    if not os.path.exists(args.save_path + '/models/'):
        os.makedirs(args.save_path + '/models/')
    torch.manual_seed(args.seed)

    # image transform: scale raw pixels to [0, 1] (divide by 255), then
    # subtract the dataset mean (std of 1 leaves the scale unchanged)
    normalize = custom_transforms.Normalize(mean=[0, 0, 0],
                                            std=[255, 255, 255])
    normalize2 = custom_transforms.Normalize(mean=[0.411, 0.432, 0.45],
                                             std=[1, 1, 1])
    input_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize, normalize2])

    # Data loading code: train=0/1/2 selects the train/val/test split;
    # only the test split is loaded in deterministic order (data_random=False)
    print("=> fetching scenes in '{}'".format(args.data))
    train_set = KITTI_Loader(args.data,
                             transform=input_transform,
                             seed=args.seed,
                             train=0,
                             sequence_length=args.sequence_length,
                             data_degradation=degradation_mode,
                             data_random=True)

    val_set = KITTI_Loader(args.data,
                           transform=input_transform,
                           seed=args.seed,
                           train=1,
                           sequence_length=args.sequence_length,
                           data_degradation=degradation_mode,
                           data_random=True)

    test_set = KITTI_Loader(args.data,
                            transform=input_transform,
                            seed=args.seed,
                            train=2,
                            sequence_length=args.sequence_length,
                            data_degradation=degradation_mode,
                            data_random=False)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    # test uses batch size 1 so each sequence is evaluated individually
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    if args.epoch_size == 0:
        # 0 means "use the full dataset" for one epoch
        args.epoch_size = len(train_loader)

    # create pose model
    print("=> creating pose model")

    feature_dim = 256

    # The four fusion modes build the same module set and differ only in the
    # fc_flownet output width (mode 0 doubles it) and the mask type
    # (Soft_Mask for mode 2, Hard_Mask otherwise). fusion_mode is fixed
    # above, so exactly one branch runs.
    if fusion_mode == 0:
        feature_ext = FlowNet(args.batch_size).cuda()
        fc_flownet = Fc_Flownet(32 * 1024, feature_dim * 2).cuda()
        rec_feat = RecFeat(feature_dim * 2, feature_dim * 2, args.batch_size,
                           2).cuda()
        rec_imu = RecImu(6, int(feature_dim / 2), args.batch_size, 2,
                         feature_dim).cuda()
        selectfusion = Hard_Mask(feature_dim * 2, feature_dim * 2).cuda()
        pose_net = PoseRegressor(feature_dim * 2).cuda()

    if fusion_mode == 1:
        feature_ext = FlowNet(args.batch_size).cuda()
        fc_flownet = Fc_Flownet(32 * 1024, feature_dim).cuda()
        rec_feat = RecFeat(feature_dim * 2, feature_dim * 2, args.batch_size,
                           2).cuda()
        rec_imu = RecImu(6, int(feature_dim / 2), args.batch_size, 2,
                         feature_dim).cuda()
        selectfusion = Hard_Mask(feature_dim * 2, feature_dim * 2).cuda()
        pose_net = PoseRegressor(feature_dim * 2).cuda()

    if fusion_mode == 2:
        feature_ext = FlowNet(args.batch_size).cuda()
        fc_flownet = Fc_Flownet(32 * 1024, feature_dim).cuda()
        rec_feat = RecFeat(feature_dim * 2, feature_dim * 2, args.batch_size,
                           2).cuda()
        rec_imu = RecImu(6, int(feature_dim / 2), args.batch_size, 2,
                         feature_dim).cuda()
        selectfusion = Soft_Mask(feature_dim * 2, feature_dim * 2).cuda()
        pose_net = PoseRegressor(feature_dim * 2).cuda()

    if fusion_mode == 3:
        feature_ext = FlowNet(args.batch_size).cuda()
        fc_flownet = Fc_Flownet(32 * 1024, feature_dim).cuda()
        rec_feat = RecFeat(feature_dim * 2, feature_dim * 2, args.batch_size,
                           2).cuda()
        rec_imu = RecImu(6, int(feature_dim / 2), args.batch_size, 2,
                         feature_dim).cuda()
        selectfusion = Hard_Mask(feature_dim * 2, feature_dim * 2).cuda()
        pose_net = PoseRegressor(feature_dim * 2).cuda()

    pose_net.init_weights()

    # restore pretrained FlowNet weights, keeping only the keys that exist
    # in feature_ext's own state dict (partial load)
    flownet_model_path = abs_path + '../../../pretrain/flownets_EPE1.951.pth'
    pretrained_flownet = True
    if pretrained_flownet:
        weights = torch.load(flownet_model_path)
        model_dict = feature_ext.state_dict()
        update_dict = {
            k: v
            for k, v in weights['state_dict'].items() if k in model_dict
        }
        model_dict.update(update_dict)
        feature_ext.load_state_dict(model_dict)
        print('restrore depth model from ' + flownet_model_path)

    cudnn.benchmark = True
    # wrap every module in DataParallel; checkpointing below saves the
    # underlying .module state dicts
    feature_ext = torch.nn.DataParallel(feature_ext)
    rec_feat = torch.nn.DataParallel(rec_feat)
    pose_net = torch.nn.DataParallel(pose_net)
    rec_imu = torch.nn.DataParallel(rec_imu)
    fc_flownet = torch.nn.DataParallel(fc_flownet)
    selectfusion = torch.nn.DataParallel(selectfusion)

    print('=> setting adam solver')

    # feature_ext (FlowNet) is deliberately excluded from the optimizer,
    # so its pretrained weights stay frozen during training
    parameters = chain(rec_feat.parameters(), rec_imu.parameters(),
                       fc_flownet.parameters(), pose_net.parameters(),
                       selectfusion.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # create the CSV summary log with its header row
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'train_pose', 'train_euler', 'validation_loss',
            'val_pose', 'val_euler'
        ])

    # start training loop
    print('=> training pose model')

    best_val = 100

    # best translation / orientation validation errors seen so far
    best_tra = 10000.0
    best_ori = 10000.0

    for epoch in range(args.epochs):

        # train for one epoch
        train_loss, pose_loss, euler_loss, temp = train(
            args, train_loader, feature_ext, rec_feat, rec_imu, pose_net,
            fc_flownet, selectfusion, optimizer, epoch, fusion_mode)

        # NOTE(review): the temperature returned by train() is discarded and
        # fixed to 0.5 for validation/test — confirm this is intended.
        temp = 0.5

        # evaluate on validation set
        val_loss, val_pose_loss, val_euler_loss =\
            validate(args, val_loader, feature_ext, rec_feat, rec_imu, pose_net, fc_flownet, selectfusion, temp, epoch,
                     fusion_mode)

        # evaluate on test set (results reported/saved inside test())
        test(args, test_loader, feature_ext, rec_feat, rec_imu, pose_net,
             fc_flownet, selectfusion, temp, epoch, fusion_mode)

        if val_pose_loss < best_tra:
            best_tra = val_pose_loss

        if val_euler_loss < best_ori:
            best_ori = val_euler_loss

        print('Best: {}, Best Translation {:.5} Best Orientation {:.5}'.format(
            epoch + 1, best_tra, best_ori))

        # save checkpoint
        # NOTE(review): with n_save_model == 1 this saves every epoch, and
        # best_val is overwritten even when the save was triggered by the
        # periodic condition rather than an improvement — confirm intended.
        if (epoch % n_save_model == 0) or (val_loss < best_val):

            best_val = val_loss

            # persist each trainable module's underlying state dict
            fn = args.save_path + '/models/rec_' + str(epoch) + '.pth'
            torch.save(rec_feat.module.state_dict(), fn)

            fn = args.save_path + '/models/pose_' + str(epoch) + '.pth'
            torch.save(pose_net.module.state_dict(), fn)

            fn = args.save_path + '/models/fc_flownet_' + str(epoch) + '.pth'
            torch.save(fc_flownet.module.state_dict(), fn)

            fn = args.save_path + '/models/rec_imu_' + str(epoch) + '.pth'
            torch.save(rec_imu.module.state_dict(), fn)

            fn = args.save_path + '/models/selectfusion_' + str(epoch) + '.pth'
            torch.save(selectfusion.module.state_dict(), fn)
            print('Model has been saved')

        # append this epoch's losses to the summary CSV
        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                train_loss, pose_loss, euler_loss, val_loss, val_pose_loss,
                val_euler_loss
            ])