コード例 #1
0
                    type=int,
                    default=1,
                    metavar='S',
                    help='random seed (default: 1)')

# Parse the command line; CUDA is used only if it is available and
# --no-cuda was not requested.
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed the CPU RNG, and the CUDA RNG when CUDA is in use, for reproducibility.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Collect the train/test file lists under args.datapath using the kl2012
# loader module (presumably KITTI 2012 — confirm against the import).
all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = kl2012.dataloader(
    args.datapath)

# Training loader: shuffled batches of 3. The trailing True is passed to
# myImageFloder, presumably as its training-mode flag — TODO confirm.
TrainImgLoader = torch.utils.data.DataLoader(kl.myImageFloder(
    all_left_img, all_right_img, all_left_disp, True),
                                             batch_size=3,
                                             shuffle=True,
                                             num_workers=8,
                                             drop_last=False)

# Test loader: same batch size, fixed order, training flag False.
TestImgLoader = torch.utils.data.DataLoader(kl.myImageFloder(
    test_left_img, test_right_img, test_left_disp, False),
                                            batch_size=3,
                                            shuffle=False,
                                            num_workers=4,
                                            drop_last=False)

if args.model == 'stackhourglass':
    model = stackhourglass(args.maxdisp)
elif args.model == 'basic':
コード例 #2
0
    from dataloader import KITTIloader_VirtualKT2 as ls
    train_file_list = args.vkt2_train_list
    val_file_list = args.vkt2_val_list
    virtual_kitti2 = True

else:
    raise Exception("No suitable KITTI found ...")

# Debug echo of the dataset root being used.
print('[??] args.datapath = ', args.datapath)
# Build the train/val file lists from the dataset root plus the explicit
# train/val list files selected in the branch above.
all_left_img, all_right_img, all_left_disp, test_left_img, \
        test_right_img, test_left_disp = ls.dataloader(
                args.datapath, train_file_list, val_file_list)

# Training loader: shuffled batches of args.batch_size. The virtual_kitti2
# flag is forwarded so the dataset can handle that dataset's format
# (exact handling lives in DA.myImageFloder — not visible here).
TrainImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
    all_left_img,
    all_right_img,
    all_left_disp,
    training=True,
    virtual_kitti2=virtual_kitti2),
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=8,
                                             drop_last=False)

TestImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
    test_left_img,
    test_right_img,
    test_left_disp,
    training=False,
    virtual_kitti2=virtual_kitti2),
                                            batch_size=args.batch_size,
                                            shuffle=False,
コード例 #3
0
ファイル: finetune.py プロジェクト: tareeqav/AnyNet
def main():
    """Fine-tune AnyNet and evaluate it after every epoch.

    Reads the train/test file lists from ``args.datapath``, optionally
    initialises weights from ``args.pretrained`` and/or resumes model and
    optimizer state from ``args.resume``, then trains from
    ``args.start_epoch`` up to ``args.epochs``.  After each epoch the model
    and optimizer state are written to ``<args.save_path>/checkpoint.tar``
    and the test loader is evaluated; one more evaluation runs after the
    loop.  Progress is logged to ``<args.save_path>/training.log``.
    """
    global args
    # Create the output directory first: the logger below opens a file in it.
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = logger.setup_logger(args.save_path + '/training.log')

    train_left_img, train_right_img, train_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader(
        args.datapath, log)

    TrainImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
        train_left_img, train_right_img, train_left_disp, True),
                                                 batch_size=args.train_bsize,
                                                 shuffle=True,
                                                 num_workers=4,
                                                 drop_last=False)

    TestImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
        test_left_img, test_right_img, test_left_disp, False),
                                                batch_size=args.test_bsize,
                                                shuffle=False,
                                                num_workers=4,
                                                drop_last=False)

    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ': ' + str(value))

    model = models.anynet.AnyNet(args)
    model = nn.DataParallel(model).cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    log.info('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    if args.pretrained:
        if os.path.isfile(args.pretrained):
            checkpoint = torch.load(args.pretrained)
            model.load_state_dict(checkpoint['state_dict'])
            log.info("=> loaded pretrained model '{}'".format(args.pretrained))
        else:
            log.info("=> no pretrained model found at '{}'".format(
                args.pretrained))
            log.info("=> Will start from scratch.")
    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # The checkpoint stores the last completed epoch (see the save
            # below); continue with the next one instead of restarting at 0.
            args.start_epoch = checkpoint['epoch'] + 1
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info("=> no checkpoint found at '{}'".format(args.resume))
            log.info("=> Will start from scratch.")
    else:
        log.info('Not Resume')
    cudnn.benchmark = True
    start_full_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        log.info('This is {}-th epoch'.format(epoch))
        adjust_learning_rate(optimizer, epoch)

        train(TrainImgLoader, model, optimizer, log, epoch)

        savefilename = args.save_path + '/checkpoint.tar'
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, savefilename)

        # Evaluate after every epoch.
        test(TestImgLoader, model, log)

    # Final evaluation; also runs when the loop body never executes
    # (e.g. resuming a run that had already reached args.epochs).
    test(TestImgLoader, model, log)
    log.info('full training time = {:.2f} Hours'.format(
        (time.time() - start_full_time) / 3600))
コード例 #4
0
def main():
    """Fine-tune the module-level ``model`` and keep the best checkpoint.

    Runs epochs ``0 .. args.epochs`` inclusive.  Each epoch trains over the
    whole training loader, then averages the 3-pixel error returned by
    ``test`` over the validation loader.  Whenever the resulting accuracy
    ``(1 - mean_error) * 100`` improves, the model, optimizer state and
    bookkeeping are saved to ``./kitti15.tar``.

    Relies on module-level ``args``, ``model``, ``optimizer``, ``train``,
    ``test``, ``ls`` and ``DA`` defined elsewhere in the script.
    """
    epoch_start = 0
    max_acc = 0
    max_epo = 0

    # Dataset
    all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader(
        args.datapath)

    TrainImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
        all_left_img, all_right_img, all_left_disp, True),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 num_workers=12,
                                                 drop_last=True)

    TestImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
        test_left_img, test_right_img, test_left_disp, False),
                                                batch_size=args.batch_size,
                                                shuffle=False,
                                                num_workers=4,
                                                drop_last=False)

    start_full_time = time.time()
    for epoch in range(epoch_start, args.epochs + 1):
        print("epoch:", epoch)

        # training
        total_train_loss = 0
        for batch_idx, (imgL_crop, imgR_crop,
                        disp_crop_L) in enumerate(TrainImgLoader):

            loss = train(imgL_crop, imgR_crop, disp_crop_L)
            total_train_loss += loss
            print('epoch:{}, step:{}, loss:{}'.format(epoch, batch_idx, loss))

        print('epoch %d average training loss = %.3f' %
              (epoch, total_train_loss / len(TrainImgLoader)))

        # test
        total_test_three_pixel_error_rate = 0

        for batch_idx, (imgL, imgR, disp_L) in enumerate(TestImgLoader):
            test_three_pixel_error_rate = test(imgL, imgR, disp_L)
            total_test_three_pixel_error_rate += test_three_pixel_error_rate

        print('epoch %d total 3-px error in val = %.3f' %
              (epoch,
               total_test_three_pixel_error_rate / len(TestImgLoader) * 100))

        acc = (1 -
               total_test_three_pixel_error_rate / len(TestImgLoader)) * 100
        if acc > max_acc:
            max_acc = acc
            max_epo = epoch
            savefilename = './kitti15.tar'
            #
            # savefilename = root_path + '/checkpoints/checkpoint_finetune_kitti15.tar'
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'total_train_loss': total_train_loss,
                    # Stored as the next epoch to run when resuming.
                    'epoch': epoch + 1,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'max_acc': max_acc,
                    'max_epoch': max_epo
                }, savefilename)
            print("-- max acc checkpoint saved --")
        print('MAX epoch %d test 3 pixel correct rate = %.3f' %
              (max_epo, max_acc))

    print('full finetune time = %.2f HR' %
          ((time.time() - start_full_time) / 3600))
    print(max_epo)
    print(max_acc)
コード例 #5
0
parser.add_argument('--seed',
                    type=int,
                    default=1,
                    metavar='S',
                    help='random seed (default: 1)')
args = parser.parse_args()
# Restrict the visible GPUs *before* the first CUDA query below: CUDA reads
# CUDA_VISIBLE_DEVICES when it initialises and ignores later changes.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed_all(args.seed)

# Only the second (test) half of the returned lists is used by this script.
test_left_img1, test_right_img1, test_left_disp1, test_left_img, test_right_img, test_left_disp = ls.dataloader(
    args.datapath)

TestImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
    test_left_img, test_right_img, test_left_disp, False),
                                            batch_size=1,
                                            shuffle=False,
                                            num_workers=4,
                                            drop_last=False)

# NOTE(review): 'ShuffleStereo8' builds MABNet_origin — confirm intentional.
if args.model == 'ShuffleStereo8':
    model = MABNet_origin(args.maxdisp)
elif args.model == 'ShuffleStereo16':
    model = ShuffleStereo16(args.maxdisp)
else:
    # Fail fast; every later use of `model` would otherwise raise NameError.
    raise ValueError('no model: {}'.format(args.model))

if args.cuda:
    model = nn.DataParallel(model)
    model.cuda()
コード例 #6
0
def main():
    """Evaluate AnyNet once on the validation split (no training loop).

    Loads the file lists, including semantic maps, from ``args.datapath`` and
    ``args.split_file``, builds the loaders and restores weights from
    ``args.pretrained`` (non-strict) and/or ``args.resume``.  It then runs
    ``test`` on the validation loader.  Logs go to
    ``<args.save_path>/training.log``.
    """
    global args
    # Create the output directory first: the logger below opens a file in it.
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = logger.setup_logger(args.save_path + '/training.log')

    train_left_img, train_right_img, train_left_disp, test_left_img, test_right_img, left_val_disp, val_fn, left_train_semantic, left_val_semantic = ls.dataloader(
        args.datapath, log, args.split_file)

    # NOTE(review): the training loader is not used by this function.
    TrainImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
        train_left_img, train_right_img, train_left_disp, left_train_semantic,
        True),
                                                 batch_size=args.train_bsize,
                                                 shuffle=True,
                                                 num_workers=4,
                                                 drop_last=False)

    TestImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
        test_left_img, test_right_img, left_val_disp, left_val_semantic, False,
        val_fn),
                                                batch_size=args.test_bsize,
                                                shuffle=False,
                                                num_workers=4,
                                                drop_last=False)

    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ': ' + str(value))

    model = models.anynet.AnyNet(args)
    model = nn.DataParallel(model).cuda()
    # The optimizer exists so a --resume checkpoint can restore its state.
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    log.info('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    if args.pretrained:
        if os.path.isfile(args.pretrained):
            checkpoint = torch.load(args.pretrained)
            # strict=False: tolerate missing/unexpected keys in the checkpoint.
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            log.info("=> loaded pretrained model '{}'".format(args.pretrained))
        else:
            log.info("=> no pretrained model found at '{}'".format(
                args.pretrained))
            log.info("=> Will start from scratch.")
    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info("=> no checkpoint found at '{}'".format(args.resume))
            log.info("=> Will start from scratch.")
    else:
        log.info('Not Resume')
    cudnn.benchmark = True

    test(TestImgLoader, model, log)
コード例 #7
0
def main(log):
    """Fine-tune a stereo network on KITTI, optionally multi-GPU distributed.

    The model is chosen by ``args.model_types`` and the data by
    ``args.datatype`` ('2015', '2012' or 'mix').  The model is wrapped with
    apex AMP, and with apex DDP when ``WORLD_SIZE`` > 1.  Weights may be
    loaded from ``args.pretrained`` or resumed from ``args.resume``.
    Training then runs for epochs ``args.start_epoch .. args.epochs - 1``.
    Validation runs every 50 epochs, and on rank 0 a checkpoint is written
    every 100 epochs.

    Args:
        log: logger for progress messages; only rank 0 writes to it.

    Raises:
        ValueError: if ``args.model_types`` or ``args.datatype`` is unknown.
    """
    global args

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    ## init dist ##
    # Distributed mode only when the launcher sets WORLD_SIZE > 1.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.world_size = 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # Model selection; loss_weights holds one weight per output stage.
    if args.model_types == "PSMNet":
        model = PSMNet(args)
        args.loss_weights = [0.5, 0.7, 1.]

    elif args.model_types == "PSMNet_DSM":
        model = PSMNet_DSM(args)
        args.loss_weights = [0.5, 0.7, 1.]

    # Membership test: `x == "a" or "b"` is always truthy and would route
    # every unknown model type into this branch.
    elif args.model_types in ("Hybrid_Net_DSM", "Hybrid_Net"):
        model = Hybrid_Net(args)
        args.loss_weights = [0.5, 0.7, 1., 1., 1.]

    else:
        raise ValueError("model error: unknown model_types {!r}".format(
            args.model_types))

    if args.datatype == '2015':
        all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader2015(
            args.datapath2015, split=args.split_for_val)

    elif args.datatype == '2012':
        all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader2012(
            args.datapath2012, split=False)

    elif args.datatype == 'mix':
        # Concatenate the 2015 and 2012 file lists.
        all_left_img_2015, all_right_img_2015, all_left_disp_2015, test_left_img_2015, test_right_img_2015, test_left_disp_2015 = ls.dataloader2015(
            args.datapath2015, split=False)
        all_left_img_2012, all_right_img_2012, all_left_disp_2012, test_left_img_2012, test_right_img_2012, test_left_disp_2012 = ls.dataloader2012(
            args.datapath2012, split=False)
        all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = \
            all_left_img_2015 + all_left_img_2012, all_right_img_2015 + all_right_img_2012, \
            all_left_disp_2015 + all_left_disp_2012, test_left_img_2015 + test_left_img_2012, \
            test_right_img_2015 + test_right_img_2012, test_left_disp_2015 + test_left_disp_2012
    else:
        raise ValueError("please define the finetune dataset (got {!r})".format(
            args.datatype))

    train_set = DA.myImageFloder(all_left_img, all_right_img, all_left_disp,
                                 True)
    val_set = DA.myImageFloder(test_left_img, test_right_img, test_left_disp,
                               False)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_set)
    else:
        train_sampler = None

    # DataLoader forbids shuffle=True together with a sampler.  The
    # DistributedSampler shuffles itself (reseeded per epoch via set_epoch
    # below), so the loader only shuffles in the single-process case.
    TrainImgLoader = torch.utils.data.DataLoader(train_set,
                                                 batch_size=args.train_bsize,
                                                 shuffle=(train_sampler is None),
                                                 num_workers=4,
                                                 pin_memory=True,
                                                 sampler=train_sampler,
                                                 drop_last=False)

    # Validation is not sharded: every rank evaluates the full set.
    TestImgLoader = torch.utils.data.DataLoader(val_set,
                                                batch_size=args.test_bsize,
                                                shuffle=False,
                                                num_workers=4,
                                                pin_memory=True,
                                                sampler=None,
                                                drop_last=False)

    num_train = len(TrainImgLoader)
    num_test = len(TestImgLoader)

    if args.local_rank == 0:
        for key, value in sorted(vars(args).items()):
            log.info(str(key) + ': ' + str(value))

    stages = len(args.loss_weights)

    # Optionally convert BatchNorm layers to apex synchronized BN.
    if args.sync_bn:
        import apex
        model = apex.parallel.convert_syncbn_model(model)
        if args.local_rank == 0:
            log.info(
                "using apex synced BN-----------------------------------------------------"
            )

    model = model.cuda()
    # The lr here is a placeholder; adjust_learning_rate sets it every epoch.
    optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))

    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    if args.distributed:
        model = DDP(model, delay_allreduce=True)
        if args.local_rank == 0:
            log.info(
                "using distributed-----------------------------------------------------"
            )

    if args.pretrained:
        if os.path.isfile(args.pretrained):
            checkpoint = torch.load(args.pretrained, map_location='cpu')

            model.load_state_dict(checkpoint['state_dict'], strict=True)

            if args.local_rank == 0:
                log.info("=> loaded pretrained model '{}'".format(
                    args.pretrained))

        else:
            if args.local_rank == 0:
                log.info("=> no pretrained model found at '{}'".format(
                    args.pretrained))
                log.info("=> Will start from scratch.")
    args.start_epoch = 0

    if args.resume:
        if os.path.isfile(args.resume):
            if args.local_rank == 0:
                log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.local_rank == 0:
                log.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if args.local_rank == 0:
                log.info("=> no checkpoint found at '{}'".format(args.resume))
                log.info("=> Will start from scratch.")
    else:
        if args.local_rank == 0:
            log.info('Not Resume')

    if args.local_rank == 0:
        log.info('Number of model parameters: {}'.format(
            sum(p.data.nelement() for p in model.parameters())))

    start_full_time = time.time()

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reseed the sampler's shuffle so each epoch sees a new order.
            train_sampler.set_epoch(epoch)
        total_train_loss = 0
        total_d1 = 0
        total_epe = 0
        adjust_learning_rate(optimizer, epoch)

        losses = [AverageMeter() for _ in range(stages)]

        ## training ##
        for batch_idx, (imgL_crop, imgR_crop,
                        disp_crop_L) in enumerate(TrainImgLoader):

            loss = train(imgL_crop, imgR_crop, disp_crop_L, model, optimizer)

            # Track each stage's loss with its weight divided back out.
            for idx in range(stages):
                losses[idx].update(loss[idx].item() / args.loss_weights[idx])

            # record loss
            info_str = '\t'.join(
                'Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val,
                                                   losses[x].avg)
                for x in range(stages))
            if args.local_rank == 0:
                log.info('losses  Epoch{} [{}/{}] {}'.format(
                    epoch, batch_idx, num_train, info_str))

            total_train_loss += loss[-1]

        if args.local_rank == 0:
            log.info('epoch %d total training loss = %.3f' %
                     (epoch, total_train_loss / num_train))

        if epoch % 50 == 0:
            ## Test ##
            inference_time = 0

            for batch_idx, (imgL, imgR, disp_L) in enumerate(TestImgLoader):

                epe, d1, single_inference_time = test(imgL, imgR, disp_L,
                                                      model)

                inference_time += single_inference_time

                total_d1 += d1
                total_epe += epe

                # Combine the running totals across ranks (reduce_tensor is
                # defined elsewhere).
                if args.distributed:
                    total_epe = reduce_tensor(total_epe.data)
                    total_d1 = reduce_tensor(total_d1.data)

            if args.local_rank == 0:
                log.info('epoch %d avg_3-px error in val = %.3f' %
                         (epoch, total_d1 / num_test * 100))
                log.info('epoch %d avg_epe  in val = %.3f' %
                         (epoch, total_epe / num_test))
                log.info(('=> Mean inference time for %d images: %.3fs' %
                          (num_test, inference_time / num_test)))

            if args.local_rank == 0:
                if epoch % 100 == 0:
                    savefilename = args.save_path + '/finetune_' + str(
                        epoch) + '.tar'
                    torch.save(
                        {
                            'epoch': epoch,
                            'state_dict': model.state_dict(),
                            'train_loss': total_train_loss / num_train,
                            'test_loss': total_d1 / num_test * 100,
                        }, savefilename)

    if args.local_rank == 0:
        log.info('full finetune time = %.2f HR' %
                 ((time.time() - start_full_time) / 3600))
コード例 #8
0
def main():
    """Train AnyNet, then evaluate once on the test split.

    Besides training, this function pickles the bare (unwrapped) model to
    ``./model_para.pth`` and writes the model graph to TensorBoard via
    ``SummaryWriter``.  Progress is logged to
    ``<args.save_path>/training.log``.
    """
    global args
    # Create the output directory first: the logger below opens a file in it.
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = logger.setup_logger(args.save_path + '/training.log')
    train_left_img, train_right_img, train_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader(
        args.datapath, log, args.split_file)
    n_train = len(train_left_img)
    TrainImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
        train_left_img, train_right_img, train_left_disp, True),
                                                 batch_size=args.train_bsize,
                                                 shuffle=True,
                                                 num_workers=4,
                                                 drop_last=False)

    TestImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
        test_left_img, test_right_img, test_left_disp, False),
                                                batch_size=args.test_bsize,
                                                shuffle=False,
                                                num_workers=4,
                                                drop_last=False)

    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ': ' + str(value))

    model = models.anynet.AnyNet(args)
    torch.save(model, './model_para.pth')

    model = nn.DataParallel(model).cuda()

    # Trace the graph for TensorBoard with fixed-seed dummy inputs of shape
    # (1, 3, 256, 512).
    torch.manual_seed(2)
    left = torch.randn(1, 3, 256, 512)
    right = torch.randn(1, 3, 256, 512)

    with SummaryWriter(comment='AnyNet_model_stracture') as w:
        w.add_graph(model, (left, right))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    log.info('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    args.start_epoch = 0
    if args.resume:  # resume an interrupted training run
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # NOTE(review): start_epoch is not passed to train() below;
            # presumably train() reads the global args — confirm.
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])  # model weights
            optimizer.load_state_dict(checkpoint['optimizer'])  # optimizer state
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info("=> no checkpoint found at '{}'".format(args.resume))
            log.info("=> Will start from scratch.")
    else:
        log.info('Not Resume')
    cudnn.benchmark = True
    start_full_time = time.time()

    # Run the training loop (train() also receives the test loader).
    train(TrainImgLoader, model, optimizer, log, n_train, TestImgLoader)
    test(TestImgLoader, model, log)
    log.info('full training time = {:.2f} Hours'.format(
        (time.time() - start_full_time) / 3600))
コード例 #9
0
def main():
    """Train StereoNet and keep the checkpoint with the lowest test error.

    Trains for ``args.epoch`` epochs with Adam and a MultiStepLR schedule
    that multiplies the lr by ``args.gamma`` after 200 scheduler steps (one
    step per epoch).  The test set is evaluated after every epoch, and
    ``finetune_checkpoint_<epoch>.pth`` is saved under ``args.save_path``
    whenever the test error reaches a new minimum.
    """
    global args
    # These must be set before CUDA initialises to take effect.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    train_left_img, train_right_img, train_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader(
        args.datapath)

    TrainImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
        train_left_img, train_right_img, train_left_disp, True),
                                                 batch_size=args.train_bsize,
                                                 shuffle=True,
                                                 num_workers=1,
                                                 drop_last=False)

    TestImgLoader = torch.utils.data.DataLoader(DA.myImageFloder(
        test_left_img, test_right_img, test_left_disp, False),
                                                batch_size=args.test_bsize,
                                                shuffle=False,
                                                num_workers=4,
                                                drop_last=False)

    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    # os.path.join works whether or not save_path ends with a separator
    # (plain concatenation wrote "<dir>training.log" next to the directory).
    log = logger.setup_logger(os.path.join(args.save_path, 'training.log'))
    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ':' + str(value))

    model = StereoNet(maxdisp=args.maxdisp)
    model = nn.DataParallel(model).cuda()
    model.apply(weights_init)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[200],
                                         gamma=args.gamma)

    log.info('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    args.start_epoch = 0

    if args.resume:
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info("=> no checkpoint found at '{}'".format(args.resume))
            log.info("=> will start from scratch.")
    else:
        log.info("Not Resume")

    min_erro = 100000
    max_epo = 0
    start_full_time = time.time()
    for epoch in range(args.start_epoch, args.epoch):
        log.info('This is {}-th epoch'.format(epoch))

        train(TrainImgLoader, model, optimizer, log, epoch)
        scheduler.step()

        erro = test(TestImgLoader, model, log)
        # Keep only checkpoints that improve on the best test error so far.
        if erro < min_erro:
            max_epo = epoch
            min_erro = erro
            savefilename = os.path.join(
                args.save_path, 'finetune_checkpoint_{}.pth'.format(max_epo))
            torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict()
            }, savefilename)
        log.info('MIN epoch %d total test erro = %.3f' % (max_epo, min_erro))
    log.info('full training time = {:.2f} Hours'.format(
        (time.time() - start_full_time) / 3600))
コード例 #10
0
                    help='random seed (default: 1)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Select the KITTI file-list loader for the requested benchmark year.
if args.datatype == '2015':
    from dataloader import KITTIloader2015 as ls
elif args.datatype == '2012':
    from dataloader import KITTIloader2012 as ls
else:
    # Fail here with a clear message instead of a NameError on `ls` below.
    raise ValueError('unknown datatype: {}'.format(args.datatype))

(all_left_img, all_right_img, all_left_disp,
 test_left_img, test_right_img, test_left_disp) = ls.dataloader(args.datapath)

# The boolean passed to myImageFloder is presumably its training-mode flag
# (True for the train split, False for test) — TODO confirm.
TrainImgLoader = torch.utils.data.DataLoader(
    DA.myImageFloder(all_left_img, all_right_img, all_left_disp, True),
    batch_size=2, shuffle=True, num_workers=4, drop_last=False)

TestImgLoader = torch.utils.data.DataLoader(
    DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False),
    batch_size=2, shuffle=False, num_workers=4, drop_last=False)

if args.model == 'stackhourglass':
    model = stackhourglass(args.maxdisp)
elif args.model == 'basic':
    model = basic(args.maxdisp)
else:
    # Every later use of `model` would otherwise raise NameError.
    raise ValueError('no model: {}'.format(args.model))

if args.cuda:
    model = nn.DataParallel(model)
コード例 #11
0
ファイル: main.py プロジェクト: SHILAIFU/PSMNet-Tensorflow
# Select the KITTI file-list loader for the requested benchmark year.
if args.datatype == '2015':
    from dataloader import KITTIloader2015 as ls
elif args.datatype == '2012':
    from dataloader import KITTIloader2012 as ls
else:
    # Fail here with a clear message instead of a NameError on `ls` below.
    raise ValueError('unknown datatype: {}'.format(args.datatype))

# Read the dataset file lists (image/disparity paths, not pixel data) for the
# train and test splits.
all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader(
    args.datapath)

# Wrap the training set in batches of BATCH_SIZE (1 here).
# NOTE(review): DA.ImgLoader is project code; per the original author's note,
# each item is one batch of (imgL_crop, imgR_crop, disp_crop_L) and images
# are shuffled — confirm against DA.ImgLoader.
TrainImgLoader = DA.ImgLoader(DA.myImageFloder(all_left_img, all_right_img,
                                               all_left_disp, True),
                              BATCH_SIZE=1)

# Test loader currently disabled:
#TestImgLoader = DA.ImgLoader(
#         DA.myImageFloder(test_left_img,test_right_img,test_left_disp, False),
#         BATCH_SIZE= 8)
"""
#读取已有的模型(后续在写)
if args.loadmodel is not None:
    state_dict = torch.load(args.loadmodel)
    model.load_state_dict(state_dict['state_dict'])

print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
"""
"""
def test(imgL,imgR,disp_true):
コード例 #12
0
    from dataloader import KITTILoader as DA
elif args.datatype == 'sceneflow':
    from dataloader import listflowfile as ls
    from dataloader import SecenFlowLoader as DA
elif args.datatype == 'kitti_object':
    from dataloader.KITTIObjectLoader import KITTIObjectLoader
else:
    print('unknown datatype: ', args.datatype)
    sys.exit()

# Build the dataset to run on: KITTI object 'trainval' split, or the test
# split of the stereo file lists for the other datatypes.
if args.datatype == 'kitti_object':
    dataloader = KITTIObjectLoader(args.datapath, 'trainval')
else:
    all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader(
        args.datapath)
    dataloader = DA.myImageFloder(test_left_img, test_right_img,
                                  test_left_disp, False)

if args.model == 'stackhourglass':
    model = stackhourglass(args.maxdisp)
elif args.model == 'basic':
    model = basic(args.maxdisp)
else:
    # Exit like the unknown-datatype branch does, but with a non-zero status;
    # continuing would crash with NameError at nn.DataParallel(model).
    sys.exit('no model: {}'.format(args.model))

# Depth-refinement network (resnet34 U-Net variant with RGB-D input).
refine_model = unet_refine.resnet34(pretrained=True, rgbd=True)
#refine_model = unet_refine.resnet18(pretrained=True)

# The stereo model is pinned to GPU 0; the refine model uses all visible GPUs.
model = nn.DataParallel(model, device_ids=[0])
model.cuda()
refine_model = nn.DataParallel(refine_model)
refine_model.cuda()
コード例 #13
0
ファイル: finetune.py プロジェクト: wpfhtl/PSMNet
                    help='load image as RGB or gray')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Select the KITTI file-list loader for the requested benchmark year.
if args.datatype == '2015':
    from dataloader import KITTIloader2015 as ls
elif args.datatype == '2012':
    from dataloader import KITTIloader2012 as ls
else:
    # Fail here with a clear message instead of a NameError on `ls` below.
    raise ValueError('unknown datatype: {}'.format(args.datatype))

(all_left_img, all_right_img, all_left_disp,
 test_left_img, test_right_img, test_left_disp) = ls.dataloader(args.datapath)

# Both the datasets and the model are built for the same colormode
# (RGB or gray input, per the --colormode option).
TrainImgLoader = torch.utils.data.DataLoader(
    DA.myImageFloder(all_left_img, all_right_img, all_left_disp, True,
                     colormode=args.colormode),
    batch_size=12, shuffle=True, num_workers=8, drop_last=False)

TestImgLoader = torch.utils.data.DataLoader(
    DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False,
                     colormode=args.colormode),
    batch_size=3, shuffle=False, num_workers=4, drop_last=False)

if args.model == 'stackhourglass':
    model = stackhourglass(args.maxdisp, colormode=args.colormode)
elif args.model == 'basic':
    model = basic(args.maxdisp, colormode=args.colormode)
else:
    # Every later use of `model` would otherwise raise NameError.
    raise ValueError('no model: {}'.format(args.model))

if args.cuda:
    model = nn.DataParallel(model)
コード例 #14
0
args = parser.parse_args()
# Pick both the submission (test) loader and the training file-list loader
# for the same benchmark: args.datapath points at that benchmark's directory
# layout, so loading 2015-style training lists from a 2012 path would fail.
if args.KITTI == '2015':
    from dataloader import KITTI_submission_loader as DA
    from dataloader import KITTIloader2015 as lt
else:
    from dataloader import KITTI_submission_loader2012 as DA
    from dataloader import KITTIloader2012 as lt
from dataloader import KITTILoader as DA_tmp

test_left_img, test_right_img = DA.dataloader(args.datapath)

all_left_img, all_right_img, all_left_disp, test_left_img_tmp, test_right_img_tmp, test_left_disp = lt.dataloader(args.datapath)

TrainImgLoader = torch.utils.data.DataLoader(
    DA_tmp.myImageFloder(all_left_img, all_right_img, all_left_disp, True),
    batch_size=1, shuffle=True, num_workers=8, drop_last=False)


args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

if args.model == 'stackhourglass':
    model = psm_net.PSMNet(args.maxdisp)
elif args.model == 'basic':
    model = basic_net.PSMNet(args.maxdisp)
elif args.model == 'concatNet':
    model = concatNet.PSMNet(args.maxdisp)