예제 #1
0
    # Two optimizer parameter groups:
    #   - node/edge/final layers ("backbone"): lr = 1% of cfg.TRAIN.LR
    #   - every other parameter of the model: lr = cfg.TRAIN.LR
    backbone_params = list(model.node_layers.parameters()) + list(
        model.edge_layers.parameters())
    backbone_params += list(model.final_layers.parameters())

    # Match parameters by object identity. `in` on the tensors themselves
    # would use element-wise `==`. A set keeps each membership test O(1)
    # instead of rescanning the backbone list for every model parameter.
    backbone_ids = {id(item) for item in backbone_params}

    new_params = [
        param for param in model.parameters() if id(param) not in backbone_ids
    ]
    opt_params = [
        dict(params=backbone_params, lr=cfg.TRAIN.LR * 0.01),
        dict(params=new_params, lr=cfg.TRAIN.LR),
    ]
    optimizer = optim.Adam(opt_params)

    # Create the checkpoint/log directory in one race-free call.
    # An already existing directory is fine.
    Path(cfg.model_dir).mkdir(parents=True, exist_ok=True)

    # The schedule entry is a 3-tuple; only its first item (epoch count)
    # is needed here.
    num_epochs, _, __ = lr_schedules[cfg.TRAIN.lr_schedule]
    # DupStdoutFileManager (defined elsewhere) presumably tees stdout into
    # <model_dir>/train_log.log for the duration of training. TODO confirm.
    with DupStdoutFileManager(str(Path(cfg.model_dir) / "train_log.log")) as _:
        model, accs = train_eval_model(
            model,
            criterion,
            optimizer,
            dataloader,
            num_epochs=num_epochs,
            # resume is enabled exactly when a warm-start path is configured.
            resume=cfg.warmstart_path is not None,
            start_epoch=0,
        )
예제 #2
0
    # Resolve the network class from the module named in the config.
    mod = importlib.import_module(cfg.MODULE)
    Net = mod.Net

    # Seed torch's global RNG from the config.
    torch.manual_seed(cfg.RANDOM_SEED)

    image_dataset = GMDataset(cfg.DATASET_FULL_NAME,
                              sets='test',
                              length=cfg.EVAL.SAMPLES,
                              obj_resize=cfg.PAIR.RESCALE)
    dataloader = get_dataloader(image_dataset)

    # Use the first GPU when available, otherwise fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = Net()
    # Move the model with .to(device) rather than .cuda(): .cuda() raises on
    # machines without CUDA, which defeated the CPU fallback computed above.
    model = model.to(device)
    # With no GPUs, device_count() is 0. PyTorch's DataParallel then simply
    # forwards calls to the wrapped module.
    model = DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Create the output directory in one race-free call.
    Path(cfg.OUTPUT_PATH).mkdir(parents=True, exist_ok=True)
    # Timestamp (to the second) so separate runs log to separate files.
    now_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    log_path = Path(cfg.OUTPUT_PATH) / ('eval_log_' + now_time + '.log')
    # DupStdoutFileManager (defined elsewhere) presumably tees stdout into
    # log_path while evaluating. TODO confirm.
    with DupStdoutFileManager(str(log_path)) as _:
        print_easydict(cfg)
        classes = dataloader.dataset.classes
        pcks = eval_model(
            model,
            dataloader,
            # A configured EVAL.EPOCH of 0 is passed on as None.
            eval_epoch=cfg.EVAL.EPOCH if cfg.EVAL.EPOCH != 0 else None,
            verbose=True)
예제 #3
0
    xlsx_data['t_mae'] = []
    xlsx_data['err_r_deg_mean'] = []
    xlsx_data['err_t_mean'] = []
    xlsx_data['CCD'] = []
    xlsx_data['acc'] = []
    xlsx_data['acc'] = []
    return xlsx_data


# Result folders are read from the "output" directory next to this script.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(ROOT_DIR, 'output')
# Keep only sub-directories, processed in sorted order. The loop below lists
# each entry and writes acc_record.log inside it, so a stray regular file in
# output_dir would otherwise make the script fail.
allalgfile = np.sort([
    entry for entry in os.listdir(output_dir)
    if os.path.isdir(os.path.join(output_dir, entry))
])
# Accumulator of per-metric columns, filled while walking the result files.
xlsx_data = xlsx_init({})

for file in allalgfile:
    with DupStdoutFileManager(output_dir + '/' + file +
                              '/acc_record.log') as _:
        npyfile = np.sort([
            txt for txt in os.listdir(output_dir + '/' + file)
            if txt.endswith('metric.npy')
        ])
        for npyfile_i in npyfile:
            metric_all = np.load(output_dir + '/' + file + '/' + npyfile_i,
                                 allow_pickle=True).item()
            print(file + '/' + npyfile_i)
            summary_metrics = summarize_metrics(metric_all)
            print_metrics(summary_metrics)
            xlsx_data = output2dict(file + '/' + npyfile_i, summary_metrics,
                                    xlsx_data)
            for i in [0.1, 0.5, 1, 5, 10, 20]:
                cur_acc = sum(
                    (metric_all['r_mae'] <= i) *