Example #1
0
def main():
    """Measure per-example inference time of the model on the test split."""
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Fall back to a filesystem-safe timestamp when no run name was given.
    if args.savemodel == "":
        stamp = datetime.now().isoformat().replace(':', '-').replace('.', 'MS')
        args.savemodel = stamp
    savepath = os.path.join(args.savedir, args.savemodel)
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    log_file = os.path.join(savepath, 'run.log')

    # Mirror every log record to stderr and to <savepath>/run.log.
    logger = logging.getLogger('FS')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(filename)s - %(lineno)s: %(message)s')
    for handler in (logging.StreamHandler(sys.stderr),
                    logging.FileHandler(log_file)):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    # Dump the full configuration for reproducibility.
    for key, value in sorted(vars(args).items()):
        logger.info('%s - %s' % (key, value))

    if args.seed != 0:
        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)

    # Build file lists; only the test split is used below.
    if args.dataset == "flow":
        (train_l, train_r, train_d,
         test_l, test_r, test_d) = lt.list_flow_file(args.datapath)
    elif args.dataset == "kitti":
        (train_l, train_r, train_d,
         test_l, test_r, test_d) = lt.list_kitti_file(args.datapath, args.date)
    elif args.dataset == "middlebury":
        (train_l, train_r, train_d,
         test_l, test_r, test_d) = lt.list_middlebury_file(args.datapath)

    # Replicate the test set 16x so the timing run is long enough to average.
    TimeImgLoader = torch.utils.data.DataLoader(
        DA.ImageFloder(test_l * 16, test_r * 16, test_d * 16,
                       training=False, args=args),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=5,
        drop_last=False)

    model = get_model(args)

    if args.cuda:
        model.cuda()

    # Optionally restore weights from a previous run's best checkpoint.
    if args.loadmodel is not None:
        checkpoint = torch.load(
            os.path.join(args.savedir, args.loadmodel, "max_loss.tar"))
        model.load_state_dict(checkpoint['state_dict'])

    logger.info('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    ## Timing ##
    start_time = time.time()
    total_time = 0.
    for imgL, imgR, disp_L in TimeImgLoader:
        total_time += runtime(model, args, imgL, imgR, disp_L)
    logger.info(
        'total test time = %.5f, per example time = %.5f' %
        (time.time() - start_time, total_time / len(TimeImgLoader.dataset)))
Example #2
0
def main():
    """Train the model, evaluate after every epoch, and checkpoint the
    best-performing weights under ``<savedir>/<savemodel>/max_loss.tar``.

    Depends on module-level helpers defined elsewhere in this file:
    ``parser``, ``lt`` (file listing), ``DA`` (dataset), ``get_model``,
    ``train`` and ``test``.
    """
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Default the run name to a filesystem-safe timestamp.
    if args.savemodel == "":
        timestr = datetime.now().isoformat().replace(':', '-').replace('.', 'MS')
        args.savemodel = timestr
    savepath = os.path.join(args.savedir, args.savemodel)
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    log_file = os.path.join(savepath, 'run.log')

    # Mirror every log record to stderr and to <savepath>/run.log.
    logger = logging.getLogger('FS')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(filename)s - %(lineno)s: %(message)s')
    fh = logging.StreamHandler(sys.stderr)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # Record the full configuration for reproducibility.
    for k, v in sorted(vars(args).items()):
        logger.info('%s - %s' % (k, v))

    # Optional fixed delay before starting -- presumably to wait for a
    # shared resource to free up; confirm intent with the flag's help text.
    if args.time_out:
        logger.info('start sleeping')
        time.sleep(3600)
        logger.info('end sleeping')

    if args.seed != 0:
        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)

    # Build train/test file lists for the selected dataset.
    if args.dataset == "flow":
        trli, trri, trld, teli, teri, teld = lt.list_flow_file(args.datapath)
    elif args.dataset == "kitti":
        if args.all_train:
            # Merge the KITTI 2015 and KITTI 2012 splits.
            trli, trri, trld, teli, teri, teld = lt.list_kitti_file(
                os.path.join(args.datapath, 'kitti2015', 'training'), '2015')
            trli2, trri2, trld2, teli2, teri2, teld2 = lt.list_kitti_file(
                os.path.join(args.datapath, 'kitti2012', 'train_194'), '2012')
            trli.extend(trli2)
            trri.extend(trri2)
            trld.extend(trld2)
            teli.extend(teli2)
            teri.extend(teri2)
            teld.extend(teld2)
        else:
            trli, trri, trld, teli, teri, teld = lt.list_kitti_file(args.datapath, args.date)
    elif args.dataset == "middlebury":
        trli, trri, trld, teli, teri, teld = lt.list_middlebury_file(args.datapath)

    # With --all_train the test split is folded into the training set too.
    if args.all_train:
        trli.extend(teli)
        trri.extend(teri)
        trld.extend(teld)

    # NOTE(review): training=args.no_train_aug looks inverted (augmentation
    # enabled when a "no aug" flag is set?) -- confirm against DA.ImageFloder
    # before changing; left as-is here.
    TrainImgLoader = torch.utils.data.DataLoader(
        DA.ImageFloder(trli, trri, trld, training=args.no_train_aug, args=args),
        batch_size=args.batch_size, shuffle=True, num_workers=5, drop_last=True)

    TestImgLoader = torch.utils.data.DataLoader(
        DA.ImageFloder(teli, teri, teld, training=False, args=args),
        batch_size=args.batch_size, shuffle=False, num_workers=5, drop_last=False)

    model = get_model(args)

    if args.cuda:
        model = nn.DataParallel(model)
        model.cuda()

    # Optionally resume from a previous run's best checkpoint.
    if args.loadmodel is not None:
        state_dict = torch.load(os.path.join(args.savedir, args.loadmodel, "max_loss.tar"))
        model.load_state_dict(state_dict['state_dict'])

    logger.info('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'mom':
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate,
                              momentum=0.9, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate,
                              weight_decay=args.weight_decay)

    # Decay the learning rate every args.decay_epochs epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_epochs)

    best_loss = 1e10  # lowest test metric seen so far (was misnamed `max_loss`)
    best_epoch = 0
    loss_avg = 0.
    for epoch in range(1, args.epochs + 1):
        logger.info('This is %d-th epoch' % (epoch))
        total_train_loss = 0.

        ## training ##
        start_time = time.time()
        for batch_idx, (imgL_crop, imgR_crop, disp_crop_L) in enumerate(TrainImgLoader):
            loss = train(model, optimizer, args, imgL_crop, imgR_crop, disp_crop_L, epoch=epoch)
            # Exponential moving average purely for smoother log output.
            loss_avg = 0.99 * loss_avg + 0.01 * loss
            if (batch_idx + 1) % args.log_steps == 0:
                logger.info('Iter %d training loss = %.3f , time = %.2f' % (
                    batch_idx + 1, loss_avg, time.time() - start_time))
                start_time = time.time()
            total_train_loss += loss
        logger.info('epoch %d total training loss = %.3f' % (epoch, total_train_loss / len(TrainImgLoader)))

        # BUG FIX: step the scheduler *after* this epoch's optimizer updates.
        # It previously ran at the top of the loop, before any optimizer.step(),
        # so training never used the initial learning rate (and PyTorch >= 1.1
        # warns about this ordering).
        scheduler.step()

        ## TEST ##
        start_time = time.time()
        total_test_loss = 0.
        total_test_num = 0.
        for batch_idx, (imgL, imgR, disp_L) in enumerate(TestImgLoader):
            loss, num = test(model, args, imgL, imgR, disp_L)
            if (batch_idx + 1) % args.log_steps == 0:
                logger.info('Iter %d test loss = %.5f , time = %.2f' % (
                    batch_idx + 1, loss / num, time.time() - start_time))
                start_time = time.time()
            total_test_loss += loss
            total_test_num += num

        # KITTI/Middlebury accumulate an accuracy, so convert it to an error
        # percentage; other datasets use the mean loss directly.
        if args.dataset == "kitti" or args.dataset == "middlebury":
            total_test_loss = (1 - total_test_loss / total_test_num) * 100.
        else:
            total_test_loss = total_test_loss / total_test_num
        logger.info('total test loss = %.5f' % total_test_loss)

        # Checkpoint whenever this epoch beats the best test metric so far.
        if total_test_loss < best_loss:
            best_loss = total_test_loss
            best_epoch = epoch

            savefilename = os.path.join(savepath, 'max_loss.tar')
            torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'test_loss': total_test_loss,
            }, savefilename)
        logger.info('MAX epoch %d total test error = %.5f' % (best_epoch, best_loss))