def train(cfg, network):
    """Fit `network` for cfg.train.epoch epochs, checkpointing every
    cfg.save_ep epochs and validating every cfg.eval_ep epochs.

    Returns the trained network.
    """
    # Non-Cityscapes datasets may exhaust file descriptors with the default
    # tensor-sharing strategy, so fall back to the file-system strategy.
    if not cfg.train.dataset.startswith('City'):
        torch.multiprocessing.set_sharing_strategy('file_system')

    train_loader = make_data_loader(cfg, is_train=True, max_iter=cfg.ep_iter)
    val_loader = make_data_loader(cfg, is_train=False)

    trainer = make_trainer(cfg, network)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)
    evaluator = make_evaluator(cfg)

    # Resume training state from cfg.model_dir when cfg.resume is set.
    begin_epoch = load_model(network, optimizer, scheduler, recorder,
                             cfg.model_dir, resume=cfg.resume)

    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        trainer.train(epoch, train_loader, optimizer, recorder)
        scheduler.step()

        finished = epoch + 1
        if finished % cfg.save_ep == 0:
            save_model(network, optimizer, scheduler, recorder, epoch,
                       cfg.model_dir)
        if finished % cfg.eval_ep == 0:
            trainer.val(epoch, val_loader, evaluator, recorder)

    return network
Example #2
0
def train(cfg, network):
    """Standard training loop: train, checkpoint every cfg.save_ep epochs,
    validate every cfg.eval_ep epochs, and return the trained network."""
    trainer = make_trainer(cfg, network)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)
    evaluator = make_evaluator(cfg)

    # Resume the full training state (weights, optimizer, scheduler,
    # recorder) from cfg.model_dir when cfg.resume is set.
    begin_epoch = load_model(network, optimizer, scheduler, recorder,
                             cfg.model_dir, resume=cfg.resume)

    train_loader = make_data_loader(cfg, is_train=True)
    val_loader = make_data_loader(cfg, is_train=False)

    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        trainer.train(epoch, train_loader, optimizer, recorder)
        scheduler.step()

        finished = epoch + 1
        if finished % cfg.save_ep == 0:
            save_model(network, optimizer, scheduler, recorder, epoch,
                       cfg.model_dir)
        if finished % cfg.eval_ep == 0:
            trainer.val(epoch, val_loader, evaluator, recorder)

    return network
Example #3
0
def train(cfg, network):
    """Train `network`, optionally logging to Neptune.

    Resumes from cfg.model_dir when cfg.resume is set; checkpoints every
    cfg.save_ep epochs and validates every cfg.eval_ep epochs.  Coco
    datasets are validated through trainer.val_coco and never build an
    evaluator.  Returns the trained network.
    """
    # Non-Cityscapes datasets switch tensor sharing to the file-system
    # strategy (avoids exhausting file descriptors in loader workers).
    if cfg.train.dataset[:4] != 'City':
        torch.multiprocessing.set_sharing_strategy('file_system')
    trainer = make_trainer(cfg, network)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)
    # Coco runs are scored via trainer.val_coco below, so the evaluator
    # is only built for non-Coco datasets.
    if 'Coco' not in cfg.train.dataset:
        evaluator = make_evaluator(cfg)

    begin_epoch = load_model(network,
                             optimizer,
                             scheduler,
                             recorder,
                             cfg.model_dir,
                             resume=cfg.resume)
    # set_lr_scheduler(cfg, scheduler)

    train_loader = make_data_loader(cfg, is_train=True)
    val_loader = make_data_loader(cfg, is_train=False)
    # train_loader = make_data_loader(cfg, is_train=True, max_iter=100)

    # Step counters shared with the trainer for Neptune logging; stays None
    # (and is passed through as None) when Neptune is disabled.
    global_steps = None
    if cfg.neptune:
        global_steps = {
            'train_global_steps': 0,
            'valid_global_steps': 0,
        }

        neptune.init('hccccccccc/clean-pvnet')
        neptune.create_experiment(cfg.model_dir.split('/')[-1])
        neptune.append_tag('pose')

    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        trainer.train(epoch, train_loader, optimizer, recorder, global_steps)
        scheduler.step()

        if (epoch + 1) % cfg.save_ep == 0:
            save_model(network, optimizer, scheduler, recorder, epoch,
                       cfg.model_dir)

        if (epoch + 1) % cfg.eval_ep == 0:
            if 'Coco' in cfg.train.dataset:
                trainer.val_coco(val_loader, global_steps)
            else:
                trainer.val(epoch, val_loader, evaluator, recorder)

    if cfg.neptune:
        neptune.stop()

    return network
Example #4
0
def train(cfg, network):
    """Train `network` for cfg.train.epoch epochs, checkpointing every
    cfg.save_ep epochs and validating every cfg.eval_ep epochs.
    Returns the trained network.
    """
    trainer = make_trainer(cfg, network)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)
    evaluator = make_evaluator(cfg)

    begin_epoch = load_model(network,
                             optimizer,
                             scheduler,
                             recorder,
                             cfg.model_dir,
                             resume=cfg.resume)
    # begin_epoch = 0  # uncomment to restart the epoch counter; keep it
    # commented out to resume training from the checkpointed epoch
    # set_lr_scheduler(cfg, scheduler)
    # print("before train loader")
    train_loader = make_data_loader(cfg, is_train=True)  # the data is only actually read at this point
    # print("under train loader")
    val_loader = make_data_loader(cfg, is_train=False)

    # # (debug) inspect train_loader's parameters and structure:
    # tmp_file = open('/home/tianhao.lu/code/Deep_snake/snake/Result/Contour/contour.log', 'w')
    # tmp_file.writelines("train_loader type:" + str(type(train_loader)) + "\n")
    # tmp_file.writelines("train_loader len:" + str(len(train_loader)) + "\n")
    # tmp_file.writelines("train_loader data:" + str(train_loader) + "\n")
    # for tmp_data in train_loader:
    #     tmp_file.writelines("one train_loader data type:" + str(type(tmp_data)) + "\n")
    #     for key in tmp_data:
    #         tmp_file.writelines("one train_loader data key:" + str(key) + "\n")
    #         tmp_file.writelines("one train_loader data len:" + str(len(tmp_data[key])) + "\n")
    #     # tmp_file.writelines("one train_loader data:" + str(tmp_data) + "\n")
    #     break
    # tmp_file.writelines(str("*************************************************************** \n"))
    # tmp_file.close()

    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        trainer.train(epoch, train_loader, optimizer, recorder)
        scheduler.step()  # optimizer.step() updates the weights; scheduler.step() only adjusts the learning rate

        if (epoch + 1) % cfg.save_ep == 0:
            save_model(network, optimizer, scheduler, recorder, epoch,
                       cfg.model_dir)

        if (epoch + 1) % cfg.eval_ep == 0:
            trainer.val(epoch, val_loader, evaluator, recorder)

    return network
Example #5
0
def run_visualize():
    """Visualize network predictions on the test split.

    When args.vis_out is set, each visualization is written to
    <vis_out>/NNNN_.png (1-based index); otherwise the visualizer is
    expected to display interactively.
    """
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer

    network = make_network(cfg).cuda()
    load_network(network,
                 cfg.model_dir,
                 resume=cfg.resume,
                 epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    if args.vis_out:
        # os.mkdir raised FileExistsError on re-runs; makedirs with
        # exist_ok=True is idempotent and also creates missing parents.
        os.makedirs(args.vis_out, exist_ok=True)
    #start = timeit.default_timer()
    # start=1 matches the original manual counter (incremented before use).
    for idx, batch in enumerate(tqdm.tqdm(data_loader), start=1):
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        if args.vis_out:
            visualizer.visualize(
                output, batch,
                os.path.join(args.vis_out, "{:04d}_.png".format(idx)))
        else:
            visualizer.visualize(output, batch)
Example #6
0
def run_evaluate():
    """Evaluate the trained network on the test split.

    Coco datasets are scored through trainer.val_coco; everything else
    goes through the standard evaluator's evaluate/summarize cycle.
    """
    import tqdm
    import torch
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    from lib.networks import make_network
    from lib.train import make_trainer
    from lib.utils.net_utils import load_network

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    trainer = make_trainer(cfg, network)
    network.eval()

    loader = make_data_loader(cfg, is_train=False)
    if 'Coco' in cfg.train.dataset:
        trainer.val_coco(loader)
        return

    evaluator = make_evaluator(cfg)
    for batch in tqdm.tqdm(loader):
        inp = batch['inp'].cuda()
        with torch.no_grad():
            output = network(inp)
        evaluator.evaluate(output, batch)
    evaluator.summarize()
Example #7
0
def run_visualize():
    """Run inference over the test split and save one visualization per
    batch, named by a zero-padded running frame number."""
    import tqdm
    import torch
    from lib.datasets import make_data_loader
    from lib.networks import make_network
    from lib.utils.net_utils import load_network
    from lib.visualizers import make_visualizer

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    network.eval()

    loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    for frame_idx, batch in enumerate(tqdm.tqdm(loader)):
        name = '%06d.jpg' % frame_idx
        for key in batch:
            if key != 'meta':
                batch[key] = batch[key].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        visualizer.visualize(output, batch, name)
Example #8
0
def run_network():
    """Benchmark the network: print the mean forward-pass wall time per
    batch over the full test split."""
    import time
    import tqdm
    import torch
    from lib.datasets import make_data_loader
    from lib.networks import make_network
    from lib.utils.net_utils import load_network

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    loader = make_data_loader(cfg, is_train=False)
    elapsed = 0
    for batch in tqdm.tqdm(loader):
        for key in batch:
            if key != 'meta':
                batch[key] = batch[key].cuda()
        with torch.no_grad():
            # Synchronize on both sides so the timer brackets only this
            # forward pass, not queued async CUDA work.
            torch.cuda.synchronize()
            tic = time.time()
            network(batch['inp'])
            torch.cuda.synchronize()
            elapsed += time.time() - tic
    print(elapsed / len(loader))
Example #9
0
def run_visualize():
    """Visualize network output for every test batch, passing the 0-based
    batch index through to the visualizer."""
    import tqdm
    import torch
    from lib.datasets import make_data_loader
    from lib.networks import make_network
    from lib.utils.net_utils import load_network
    from lib.visualizers import make_visualizer

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    network.eval()

    loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    for batch_idx, batch in enumerate(tqdm.tqdm(loader)):
        for key in batch:
            if key != 'meta':
                batch[key] = batch[key].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        visualizer.visualize(output, batch, batch_idx)
Example #10
0
def run_visualize():
    """Visualize network output for every test batch, with an optional
    debug dump of args/cfg when the module-level DEBUG flag is set."""
    import tqdm
    import torch
    from lib.datasets import make_data_loader
    from lib.networks import make_network
    from lib.utils.net_utils import load_network
    from lib.visualizers import make_visualizer

    if DEBUG:
        print('===================================Visualizing=================================')
        import pprint
        print('args:', args)
        print('cfg:')
        pprint.pprint(cfg)

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir,
                 resume=cfg.resume, epoch=cfg.test.epoch)
    network.eval()

    loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    for batch in tqdm.tqdm(loader):
        for key in batch:
            if key != 'meta':
                batch[key] = batch[key].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        visualizer.visualize(output, batch)
Example #11
0
def run_evaluate():
    """Run the network over the test split, feed each output to the
    evaluator, and print the summary metrics."""
    import tqdm
    import torch
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    from lib.networks import make_network
    from lib.utils.net_utils import load_network

    if DEBUG:
        print('-------------------------------Evaluating---------------------------------')

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    evaluator = make_evaluator(cfg)
    loader = make_data_loader(cfg, is_train=False)
    for batch in tqdm.tqdm(loader):
        inp = batch['inp'].cuda()
        with torch.no_grad():
            output = network(inp)
        evaluator.evaluate(output, batch)
    evaluator.summarize()
Example #12
0
def run_visualize():
    """Render every test batch with the renderer built around the network
    and hand each rendered output to the visualizer."""
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    from lib.utils import net_utils
    import tqdm
    import torch
    from lib.visualizers import make_visualizer
    from lib.networks.renderer import make_renderer

    # Presumably disables sampling perturbation for deterministic renders —
    # verify against the renderer/dataset implementation.
    cfg.perturb = 0

    network = make_network(cfg).cuda()
    load_network(network,
                 cfg.trained_model_dir,
                 resume=cfg.resume,
                 epoch=cfg.test.epoch)
    # NOTE(review): train() (not eval()) is called here, so norm/dropout
    # layers stay in training mode during rendering — confirm intentional.
    network.train()

    data_loader = make_data_loader(cfg, is_train=False)
    renderer = make_renderer(cfg, network)
    visualizer = make_visualizer(cfg)
    for batch in tqdm.tqdm(data_loader):
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = renderer.render(batch)
            visualizer.visualize(output, batch)
Example #13
0
def run_dataset():
    """Iterate the test split once without a network — a smoke test /
    throughput check for the data pipeline."""
    import tqdm
    from lib.datasets import make_data_loader

    # Single-process loading makes dataset errors surface directly.
    cfg.train.num_workers = 0
    loader = make_data_loader(cfg, is_train=False)
    for _ in tqdm.tqdm(loader):
        pass
Example #14
0
def train(cfg, network):
    """Distributed-aware training loop.

    When cfg.distributed is set, the train sampler is re-seeded every epoch
    and only rank 0 writes checkpoints.  A regular checkpoint is written
    every cfg.save_ep epochs plus a rolling "latest" checkpoint (last=True)
    every cfg.save_latest_ep epochs.  Validates every cfg.eval_ep epochs.
    Returns the trained network.
    """
    trainer = make_trainer(cfg, network)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)
    evaluator = make_evaluator(cfg)

    begin_epoch = load_model(network,
                             optimizer,
                             scheduler,
                             recorder,
                             cfg.trained_model_dir,
                             resume=cfg.resume)
    set_lr_scheduler(cfg, scheduler)

    # max_iter=cfg.ep_iter limits the number of iterations per epoch.
    train_loader = make_data_loader(cfg,
                                    is_train=True,
                                    is_distributed=cfg.distributed,
                                    max_iter=cfg.ep_iter)
    val_loader = make_data_loader(cfg, is_train=False)

    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        if cfg.distributed:
            # Re-seed the distributed sampler so each epoch shuffles
            # differently across ranks.
            train_loader.batch_sampler.sampler.set_epoch(epoch)

        trainer.train(epoch, train_loader, optimizer, recorder)
        scheduler.step()

        # Only rank 0 saves, to avoid concurrent writes to the same files.
        if (epoch + 1) % cfg.save_ep == 0 and cfg.local_rank == 0:
            save_model(network, optimizer, scheduler, recorder,
                       cfg.trained_model_dir, epoch)

        if (epoch + 1) % cfg.save_latest_ep == 0 and cfg.local_rank == 0:
            save_model(network,
                       optimizer,
                       scheduler,
                       recorder,
                       cfg.trained_model_dir,
                       epoch,
                       last=True)

        if (epoch + 1) % cfg.eval_ep == 0:
            trainer.val(epoch, val_loader, evaluator, recorder)

    return network
def test(cfg, network):
    """Load a checkpoint into `network` and run one validation pass over
    the test split."""
    trainer = make_trainer(cfg, network)
    val_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    epoch = load_network(network, cfg.model_dir,
                         resume=cfg.resume, epoch=cfg.test.epoch)
    trainer.val(epoch, val_loader, evaluator)
Example #16
0
def run_visualize():
    """Visualize predictions on a COCO-style test split, naming each output
    after its source image looked up through the COCO annotation file."""
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer

    network = make_network(cfg).cuda()
    print("model dir:{}".format(cfg.model_dir))
    load_network(network,
                 cfg.model_dir,
                 resume=cfg.resume,
                 epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    ann_file = 'data/NICE1/NICE1/coco_int/test/annotations/NICE_test.json'  # test-split json, since the test set is what is used here
    coco = COCO(ann_file)
    for batch in tqdm.tqdm(data_loader):  # tqdm renders a progress bar
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        # print("batch  ->{}".format(batch))  # batch is correct; it also carries img_id and related info
        # print("batch['meta']  ->{}".format(batch['meta']))  # batch['meta'] is correct
        img_info = batch['meta']
        img_id = img_info['img_id']  # img_id is used to look up the image name, which determines the output file name
        img_scale = img_info['scale']  # img_scale is used to read the size in order to resize the image
        # print("img id ->{}".format(img_id))
        # print("img scale ->{}".format(img_scale))
        img_name = coco.loadImgs(int(img_id))[0]['file_name']
        img_name, _ = os.path.splitext(img_name)
        # ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=0)  # two options: fetch the annotations, or loadImgs directly for the file name
        # anno = coco.loadAnns(ann_ids)
        visualizer.visualize(output, batch, img_name)
        # NOTE(review): the log is opened with mode 'w' on every iteration,
        # so it only ever contains the last batch — confirm this is intended.
        tmp_file = open(
            '/home/tianhao.lu/code/Deep_snake/snake/Result/Contour/Output_check.log',
            'w',
            encoding='utf8')
        tmp_file.writelines("Output -> :" + str(output) + "\n")
        tmp_file.writelines("batch -> :" + str(batch) + "\n")
        # for tmp_data in train_loader:
        #     tmp_file.writelines("one train_loader data type:" + str(type(tmp_data)) + "\n")
        #     for key in tmp_data:
        #         tmp_file.writelines("one train_loader data key:" + str(key) + "\n")
        #         tmp_file.writelines("one train_loader data len:" + str(len(tmp_data[key])) + "\n")
        #     # tmp_file.writelines("one train_loader data:" + str(tmp_data) + "\n")
        #     break
        tmp_file.writelines(
            str("*************************************************************** \n"
                ))
        tmp_file.close()
Example #17
0
def run_visualize():
    """Render test batches through a DataParallel-wrapped renderer on
    `n_devices` GPUs and visualize each output.

    Tensors are reshaped per key before scattering: keys in `expand_k` are
    replicated along a new leading device dimension; all other tensors are
    flattened (to (-1, 3) for point-like data, else to 1-D).
    """
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    from lib.utils import net_utils
    import tqdm
    import torch
    from lib.visualizers import make_visualizer
    from lib.networks.renderer import make_renderer
    import os

    # Presumably disables sampling perturbation for deterministic renders —
    # verify against the renderer/dataset implementation.
    cfg.perturb = 0
    cfg.render_name = args.render_name

    network = make_network(cfg)#.cuda()
    load_network(network,
                 cfg.trained_model_dir,
                 resume=cfg.resume,
                 epoch=cfg.test.epoch)
    # NOTE(review): train() (not eval()) keeps norm/dropout layers in
    # training mode during rendering — confirm this is intentional.
    network.train()

    data_loader = make_data_loader(cfg, is_train=False)
    n_devices = 2  # hard-coded device count used for the expand below

    # Keys whose tensors are replicated per device instead of flattened.
    expand_k = ['index', 'bounds', 'R', 'Th',
                'center', 'rot', 'trans', 'i',
                'cam_ind', 'feature', 'coord', 'out_sh']
    renderer = torch.nn.DataParallel(make_renderer(cfg, network)).cuda()
    #renderer.set_save_mem(True)
    visualizer = make_visualizer(cfg)
    for batch in tqdm.tqdm(data_loader):
        batch_cuda = {}
        for k in batch:
            if k != 'meta':
                # TODO: ugly hack...
                if k not in expand_k:
                    # Flatten: (-1, 3) when the last dim is 3, else 1-D.
                    sh = batch[k].shape
                    if sh[-1] == 3:
                        batch_cuda[k] = batch[k].view(-1, 3)
                    else:
                        batch_cuda[k] = batch[k].view(-1)
                else:
                    # Replicate along a new leading dim, one copy per device.
                    sh = batch[k].shape
                    batch_cuda[k] = batch[k].expand(n_devices, *sh[1:])

                batch_cuda[k] = batch_cuda[k].cuda()
            else:
                # 'meta' is passed through untouched (kept on CPU).
                batch_cuda[k] = batch[k]


        with torch.no_grad():
            output = renderer(batch_cuda)
            visualizer.visualize(output, batch)
Example #18
0
def run_evaluate_nv():
    """Evaluate batches directly (no network forward pass here — the
    evaluator consumes the batch itself) and print the summary."""
    import tqdm
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    from lib.utils import net_utils

    loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    for batch in tqdm.tqdm(loader):
        for key in batch:
            if key != 'meta':
                batch[key] = batch[key].cuda()
        evaluator.evaluate(batch)
    evaluator.summarize()
Example #19
0
def run_evaluate():
    """Evaluate the network on the test split and record summary metrics.

    When args.vis_out is set, a per-sample visualization is written there,
    named by the drill-tip 3D translation error and 1-based sample index.
    Metrics are saved to metrics.pkl and plotted to evaluation.html.
    """
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils.net_utils import load_network
    from lib.evaluators.custom.monitor import MetricMonitor
    from lib.visualizers import make_visualizer

    torch.manual_seed(0)  # deterministic evaluation

    monitor = MetricMonitor()
    network = make_network(cfg).cuda()
    epoch = load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()
    # NOTE: counts all parameters, not only those with requires_grad=True.
    print("Trainable parameters: {}".format(
        sum(p.numel() for p in network.parameters())))

    data_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    visualizer = make_visualizer(cfg)

    if args.vis_out:
        # os.mkdir raised FileExistsError on re-runs; makedirs with
        # exist_ok=True is idempotent and creates missing parents too.
        os.makedirs(args.vis_out, exist_ok=True)

    # start=1 matches the original manual counter (incremented before use).
    for idx, batch in enumerate(tqdm.tqdm(data_loader), start=1):
        inp = batch['inp'].cuda()
        with torch.no_grad():
            output = network(inp)
        evaluator.evaluate(output, batch)

        if args.vis_out:
            # Most recent per-sample drill-tip 3D translation error.
            err = evaluator.data["obj_drilltip_trans_3d"][-1]
            visualizer.visualize(
                output, batch,
                os.path.join(
                    args.vis_out,
                    "tiperr{:.4f}_idx{:04d}.png".format(err.item(), idx)))
    result = evaluator.summarize()
    monitor.add('val', epoch, result)
    monitor.save_metrics("metrics.pkl")
    monitor.plot_histogram("evaluation.html", plotly=True)
Example #20
0
def train(cfg, network):
    """Train `network` with validation disabled (evaluator/val_loader are
    commented out).

    Two init modes: with cfg.network_full_init, weights are warm-started
    from a checkpoint but the epoch counter is reset to 0; otherwise the
    full training state (optimizer/scheduler/recorder) is resumed.
    Returns the trained network.
    """
    trainer = make_trainer(cfg, network)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)
    #evaluator = make_evaluator(cfg)

    if cfg.network_full_init:
        # Load weights only; the returned epoch is deliberately discarded so
        # training restarts at epoch 0 with fresh optimizer/scheduler state.
        begin_epoch = load_network(network,
                                   cfg.model_dir,
                                   resume=cfg.resume,
                                   epoch=cfg.test.epoch)
        begin_epoch = 0
    else:
        begin_epoch = load_model(network,
                                 optimizer,
                                 scheduler,
                                 recorder,
                                 cfg.model_dir,
                                 resume=cfg.resume)

    # set_lr_scheduler(cfg, scheduler)
    if DEBUG:
        print('------------------Loading training set-------------------')
    train_loader = make_data_loader(cfg, is_train=True)
    if DEBUG:
        print('Loading training set done...')
        print('---------------------------------------------------------')
    #val_loader = make_data_loader(cfg, is_train=False)

    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        trainer.train(epoch, train_loader, optimizer, recorder)
        scheduler.step()

        if (epoch + 1) % cfg.save_ep == 0:
            save_model(network, optimizer, scheduler, recorder, epoch,
                       cfg.model_dir)

        #if (epoch + 1) % cfg.eval_ep == 0:
        #    trainer.val(epoch, val_loader, evaluator, recorder)

    return network
Example #21
0
def run_detector_pvnet():
    """Run the pvnet detector over the test split and visualize each
    batch's output."""
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer

    # NOTE(review): load_network is imported but never called, so this
    # network runs with whatever weights make_network produces (possibly
    # random init) — confirm weights are loaded elsewhere or add the call.
    network = make_network(cfg).cuda()
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    for batch in tqdm.tqdm(data_loader):
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        visualizer.visualize(output, batch)
Example #22
0
def run_evaluate():
    """Deterministically evaluate the network on up to 1200 test batches.

    Seeds numpy/torch and pins cudnn to deterministic mode so repeated runs
    produce identical metrics, then prints the evaluator summary.
    """
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils.net_utils import load_network

    import numpy as np
    # Fix all RNG seeds and disable cudnn autotuning for reproducibility.
    np.random.seed(1000)
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)

    count = 1200  # evaluate at most this many batches
    for i, batch in enumerate(tqdm.tqdm(data_loader)):
        if i == count:
            break
        inp = batch['inp'].cuda()
        # The original also copied batch['mask'] to the GPU into an unused
        # local (seg_gt) — dropped here to avoid the wasted transfer.
        with torch.no_grad():
            output = network(inp)
        evaluator.evaluate(output, batch)
    evaluator.summarize()
Example #23
0
def run_analyze():
    """Run the network over the test split and feed every output to the
    analyzer built from the config."""
    import tqdm
    import torch
    from lib.analyzers import make_analyzer
    from lib.datasets import make_data_loader
    from lib.networks import make_network
    from lib.utils.net_utils import load_network

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    # Single-process data loading keeps analysis runs easy to debug.
    cfg.train.num_workers = 0
    loader = make_data_loader(cfg, is_train=False)
    analyzer = make_analyzer(cfg)
    for batch in tqdm.tqdm(loader):
        for key in batch:
            if key != 'meta':
                batch[key] = batch[key].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        analyzer.analyze(output, batch)