Example #1
def main(config, model_path, output_path, input_shape=(320, 320)):
    logger = Logger(-1, config.save_dir, False)
    model = build_model(config.model)
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    load_model_weight(model, checkpoint, logger)
    # torch.autograd.Variable is deprecated; a plain tensor works as the ONNX tracing input.
    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1])
    torch.onnx.export(model, dummy_input, output_path, verbose=True, keep_initializers_as_inputs=True, opset_version=11)
    print('Finished exporting ONNX.')
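
# Optional sanity check (a sketch, not part of the original export script): it assumes
# onnxruntime is installed and that output_path above was "nanodet.onnx".
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("nanodet.onnx",
                                       providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
dummy = np.random.randn(1, 3, 320, 320).astype(np.float32)
outputs = session.run(None, {input_name: dummy})
print([o.shape for o in outputs])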
Example #2
def __init__(self, cfg, model_path, logger, device='cuda:0'):
    self.cfg = cfg
    self.device = device
    # Build the network, restore the checkpoint weights, and switch to eval mode.
    model = build_model(cfg.model)
    ckpt = torch.load(model_path,
                      map_location=lambda storage, loc: storage)
    load_model_weight(model, ckpt, logger)
    self.model = model.to(device).eval()
    # Reuse the validation preprocessing pipeline for single-image inference.
    self.pipeline = Pipeline(cfg.data.val.pipeline,
                             cfg.data.val.keep_ratio)
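
# Hypothetical usage (the class name "Predictor" and the paths are assumptions;
# the snippet above only shows the constructor):
#   load_config(cfg, "config/nanodet-m.yml")
#   logger = Logger(-1, cfg.save_dir, False)
#   predictor = Predictor(cfg, "model_best.pth", logger, device="cuda:0")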
Example #3
def main(config, input_shape=(3, 320, 320)):
    model = build_model(config.model)
    #flops, params = get_model_complexity_info(model, input_shape)

    macs, params = get_model_complexity_info(model,
                                             input_shape,
                                             as_strings=True,
                                             print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
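
# A small variant (assuming the ptflops-style signature used above): with
# as_strings=False the call returns raw numbers, which is handier for scripting.
macs, params = get_model_complexity_info(model,
                                         input_shape,
                                         as_strings=False,
                                         print_per_layer_stat=False)
print('GMacs: {:.3f}, params (M): {:.3f}'.format(macs / 1e9, params / 1e6))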
Example #4
def main(args):
    load_config(cfg, args.config)
    local_rank = -1
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    cfg.defrost()
    timestr = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    cfg.save_dir = os.path.join(cfg.save_dir, timestr)
    cfg.freeze()
    mkdir(local_rank, cfg.save_dir)
    logger = Logger(local_rank, cfg.save_dir)

    logger.log('Creating model...')
    model = build_model(cfg.model)

    logger.log('Setting up data...')
    val_dataset = build_dataset(cfg.data.val, args.task)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=1,
                                                 pin_memory=True,
                                                 collate_fn=collate_function,
                                                 drop_last=True)
    trainer = build_trainer(local_rank, cfg, model, logger)
    if 'load_model' in cfg.schedule:
        trainer.load_model(cfg)
    evaluator = build_evaluator(cfg, val_dataset)
    logger.log('Starting testing...')
    with torch.no_grad():
        results, val_loss_dict, _ = trainer.run_epoch(0,
                                                      val_dataloader,
                                                      mode=args.task)
    if args.task == 'test':
        res_json = evaluator.results2json(results)
        json_path = os.path.join(cfg.save_dir,
                                 'results{}.json'.format(timestr))
        with open(json_path, 'w') as f:
            json.dump(res_json, f)
    elif args.task == 'val':
        eval_results = evaluator.evaluate(results,
                                          cfg.save_dir,
                                          0,
                                          logger,
                                          rank=local_rank)
        if args.save_result:
            txt_path = os.path.join(cfg.save_dir,
                                    "eval_results{}.txt".format(timestr))
            with open(txt_path, "a") as f:
                for k, v in eval_results.items():
                    f.write("{}: {}\n".format(k, v))
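
# A minimal argparse sketch (not taken from the original file) covering the
# attributes this main() reads: args.config, args.task and args.save_result.
import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='config file path')
    parser.add_argument('--task', type=str, default='val',
                        help="'val' for evaluation, 'test' for dumping results")
    parser.add_argument('--save_result', action='store_true',
                        help='write the evaluation metrics to a txt file')
    return parser.parse_args()

if __name__ == '__main__':
    main(parse_args())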
Example #5
def main(args):
    load_config(cfg, args.config)
    local_rank = int(args.local_rank)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    mkdir(local_rank, cfg.save_dir)
    logger = Logger(local_rank, cfg.save_dir)
    if args.seed is not None:
        logger.log('Set random seed to {}'.format(args.seed))
        init_seeds(args.seed)

    logger.log('Creating model...')
    model = build_model(cfg.model)

    print("model:", model)

    # pre_dict = model.state_dict()  # load the model parameters into pre_dict as key/value pairs
    # for k, v in pre_dict.items():  # print each layer's name and parameter shape
    #     print('%-50s%s' % (k, v.shape))

    # summary(model, (3, 320, 320))

    logger.log('Setting up data...')
    train_dataset = build_dataset(cfg.data.train, 'train')
    val_dataset = build_dataset(cfg.data.val, 'test')

    if len(cfg.device.gpu_ids) > 1:
        print('rank = ', local_rank)
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(local_rank % num_gpus)
        dist.init_process_group(backend='nccl')
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=cfg.device.batchsize_per_gpu,
            num_workers=cfg.device.workers_per_gpu,
            pin_memory=True,
            collate_fn=custom_collate_function,
            sampler=train_sampler,
            drop_last=True)
    else:
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=cfg.device.batchsize_per_gpu,
            shuffle=True,
            collate_fn=custom_collate_function,
            pin_memory=True,
            drop_last=True)

    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
        collate_fn=custom_collate_function,
        drop_last=True)

    trainer = build_trainer(local_rank, cfg, model, logger)

    if cfg.schedule.resume:
        trainer.resume(cfg)
        if 'load_model' in cfg.schedule:
            trainer.load_model(cfg)

    evaluator = build_evaluator(cfg, val_dataset)

    logger.log('Starting training...')
    trainer.run(train_dataloader, val_dataloader, evaluator)
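
# Launch note (an assumption based on the code above, not from the original
# file): the multi-GPU branch reads args.local_rank and calls
# dist.init_process_group(backend='nccl'), so distributed runs are typically
# started with torch.distributed.launch, e.g.
#   python -m torch.distributed.launch --nproc_per_node=2 train.py --config CONFIG.yml
# A single-GPU run can invoke the script directly.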