コード例 #1
0
ファイル: train_stage2.py プロジェクト: wymGAKKI/saps
def prepareSave(args, data, pred_c, pred, recorder, log):
    """Gather the items to save for a training batch and record accuracies.

    The returned list always begins with ``data['img']``, ``data['mask']``
    and ``data['normal']`` remapped through ``(x + 1) / 2``.  Depending on
    the flags on ``args`` the function additionally:

    * ``args.s1_est_d`` -- scores ``pred_c['dirs']`` against ``data['dirs']``
      and stores the second return value in ``data['dir_err']``;
    * ``args.s1_est_i`` -- scores ``pred_c['intens']`` against
      ``data['ints']`` and stores the second return value in
      ``data['int_err']``;
    * ``args.s2_est_n`` -- scores ``pred['normal']`` against
      ``data['normal']`` and appends the masked, remapped prediction plus the
      ``'angular_map'`` entry of the error map to the result list.

    Every accuracy dict is pushed to ``recorder`` under the ``'train'``
    split.  ``log`` is accepted for signature compatibility and not used.

    Returns:
        ``(results, recorder, nrow)`` where ``nrow`` is the first dimension
        of ``data['img']`` capped at 32.
    """
    input_var = data['img']
    mask_var = data['mask']

    results = [input_var.data, mask_var.data]
    results.append((data['normal'].data + 1) / 2)

    if args.s1_est_d:
        dir_acc, dir_err = eval_utils.calDirsAcc(
            data['dirs'].data, pred_c['dirs'].data, args.batch)
        data['dir_err'] = dir_err
        recorder.updateIter('train', dir_acc.keys(), dir_acc.values())

    if args.s1_est_i:
        int_acc, int_err = eval_utils.calIntsAcc(
            data['ints'].data, pred_c['intens'].data, args.batch)
        data['int_err'] = int_err
        recorder.updateIter('train', int_acc.keys(), int_acc.values())

    if args.s2_est_n:
        normal_acc, error_map = eval_utils.calNormalAcc(
            data['normal'].data, pred['normal'].data, mask_var.data)
        # Same (x + 1) / 2 remap as the reference normal above, then masked.
        remapped = (pred['normal'].data + 1) / 2
        mask_full = mask_var.data.expand_as(pred['normal'].data)
        results.extend([remapped * mask_full, error_map['angular_map']])
        recorder.updateIter('train', normal_acc.keys(), normal_acc.values())

    # Never report more than 32 rows, whatever the batch size is.
    nrow = min(input_var.shape[0], 32)
    return results, recorder, nrow
コード例 #2
0
def prepareRes(args, data, pred, recorder, log, split):
    """Evaluate the enabled predictions for one val/test batch.

    For each enabled flag on ``args`` the matching ``eval_utils`` routine is
    called, its accuracy dict is pushed to ``recorder`` under ``split``, one
    headline number is appended to the per-iteration result list and a
    ``'<letter>_<value>-'`` tag is appended to the error string:

    * ``args.s1_est_d`` -- directions; uses ``'l_err_mean'``, tag ``D``,
      and stores the second return value in ``data['dir_err']``;
    * ``args.s1_est_i`` -- intensities; uses ``'ints_ratio'``, tag ``I``,
      and stores the second return value in ``data['int_err']``;
    * ``args.s1_est_n`` -- normals; uses ``'n_err_mean'``, tag ``N``, and
      stores ``error_map['angular_map']`` in ``data['error_map']``.

    The batch size handed to the direction/intensity routines is
    ``args.val_batch`` when ``split == 'val'`` and ``args.test_batch``
    otherwise.  ``log`` is accepted but not used.

    Returns:
        ``(recorder, iter_res, error)``.
    """
    if split == 'val':
        data_batch = args.val_batch
    else:
        data_batch = args.test_batch

    iter_res = []
    tags = []

    if args.s1_est_d:
        dir_acc, dir_err = eval_utils.calDirsAcc(
            data['dirs'].data, pred['dirs'].data, data_batch)
        data['dir_err'] = dir_err
        recorder.updateIter(split, dir_acc.keys(), dir_acc.values())
        iter_res.append(dir_acc['l_err_mean'])
        tags.append('D_%.3f-' % (dir_acc['l_err_mean']))

    if args.s1_est_i:
        int_acc, int_err = eval_utils.calIntsAcc(
            data['ints'].data, pred['intens'].data, data_batch)
        data['int_err'] = int_err
        recorder.updateIter(split, int_acc.keys(), int_acc.values())
        iter_res.append(int_acc['ints_ratio'])
        tags.append('I_%.3f-' % (int_acc['ints_ratio']))

    if args.s1_est_n:
        normal_acc, error_map = eval_utils.calNormalAcc(
            data['normal'].data, pred['normal'].data, data['mask'].data)
        recorder.updateIter(split, normal_acc.keys(), normal_acc.values())
        iter_res.append(normal_acc['n_err_mean'])
        tags.append('N_%.3f-' % (normal_acc['n_err_mean']))
        data['error_map'] = error_map['angular_map']

    return recorder, iter_res, ''.join(tags)
コード例 #3
0
def prepareRes(args, data, pred_c, pred, random_loc, recorder, log, split):
    """Evaluate direction and cropped-normal predictions for one batch.

    * ``args.s1_est_d`` -- scores ``pred_c['dirs']`` against
      ``data['dirs']``, stores the second return value in
      ``data['dir_err']``, and reports ``'l_err_mean'`` with tag ``D``.
    * ``args.s2_est_n`` -- crops ``data['n']`` and ``data['m']`` to the
      16x16 window ``[loc - 8, loc + 8)`` around ``random_loc`` on the last
      two axes, scores ``pred['n']`` against that crop, stores
      ``error_map['angular_map']`` in ``data['error_map']``, and reports
      ``'n_err_mean'`` with tag ``N``.

    Accuracy dicts are pushed to ``recorder`` under ``split``.  The batch
    size given to the direction routine is ``args.val_batch`` for
    ``split == 'val'`` and ``args.test_batch`` otherwise.  ``log`` is
    accepted but not used.

    Returns:
        ``(recorder, iter_res, error)``.
    """
    full_mask = data['m']
    if split == 'val':
        data_batch = args.val_batch
    else:
        data_batch = args.test_batch

    iter_res = []
    tags = []

    if args.s1_est_d:
        dir_acc, dir_err = eval_utils.calDirsAcc(
            data['dirs'].data, pred_c['dirs'].data, data_batch)
        data['dir_err'] = dir_err
        recorder.updateIter(split, dir_acc.keys(), dir_acc.values())
        iter_res.append(dir_acc['l_err_mean'])
        tags.append('D_%.3f-' % (dir_acc['l_err_mean']))

    if args.s2_est_n:
        loc_x, loc_y = random_loc
        # NOTE(review): presumably pred['n'] already covers only this
        # window -- confirm against the caller.
        win_x = slice(loc_x - 8, loc_x + 8)
        win_y = slice(loc_y - 8, loc_y + 8)
        target_crop = data['n'][:, :, win_x, win_y]
        mask_crop = full_mask[:, :, win_x, win_y]
        normal_acc, error_map = eval_utils.calNormalAcc(
            target_crop.data, pred['n'].data, mask_crop.data)
        recorder.updateIter(split, normal_acc.keys(), normal_acc.values())
        iter_res.append(normal_acc['n_err_mean'])
        tags.append('N_%.3f-' % (normal_acc['n_err_mean']))
        data['error_map'] = error_map['angular_map']

    return recorder, iter_res, ''.join(tags)
コード例 #4
0
ファイル: test_stage4.py プロジェクト: wymGAKKI/saps
def prepareNormalRes(args, data, pred, recorder, log, split):
    """Evaluate the normal prediction for one batch and record the result.

    Scores ``pred['normal']`` against ``data['normal']`` under
    ``data['mask']`` via ``eval_utils.calNormalAcc``, pushes the accuracy
    dict to ``recorder`` under ``split``, and stores
    ``error_map['angular_map']`` in ``data['error_map']``.

    ``args`` and ``log`` are accepted for signature parity with the other
    ``prepare*Res`` helpers and are not used here; in particular the batch
    size from ``args`` is not needed for the normal metric.

    Returns:
        ``(recorder, iter_res, error)`` where ``iter_res`` is a one-element
        list holding ``acc['n_err_mean']`` and ``error`` is the tag string
        ``'N_<value>-'`` with the value formatted to three decimals.
    """
    acc, error_map = eval_utils.calNormalAcc(
        data['normal'].data, pred['normal'].data, data['mask'].data)
    recorder.updateIter(split, acc.keys(), acc.values())

    mean_err = acc['n_err_mean']
    iter_res = [mean_err]
    error = 'N_%.3f-' % (mean_err)

    # Side effect: expose the error map to the caller through ``data``.
    data['error_map'] = error_map['angular_map']

    return recorder, iter_res, error
コード例 #5
0
ファイル: test_utils.py プロジェクト: rkripa/PS-FCN
def test(args, split, loader, model, log, epoch, recorder, tf_writer):
    """Run one evaluation epoch over ``loader`` and log the results.

    The model is switched to eval mode and every batch is processed under
    ``torch.no_grad()``: the sample is parsed, fed through the model, and the
    output is scored against ``data['tar']`` under the mask ``data['m']``.
    The accuracy dict is pushed to ``recorder`` under ``split`` on every
    iteration.

    Every ``disp_intv`` iterations an iteration summary is printed and each
    accuracy entry is written to ``tf_writer``; every ``save_intv``
    iterations the masked prediction, remapped through ``(x + 1) / 2``, is
    saved through ``log``.  Both intervals come from ``get_itervals``.  An
    epoch summary is printed once at the end.
    """
    model.eval()
    print('---- Start %s Epoch %d: %d batches ----' %
          (split, epoch, len(loader)))
    timer = time_utils.Timer(args.time_sync)

    disp_intv, save_intv = get_itervals(args, split)
    with torch.no_grad():
        # Iterations are counted from 1 so the modulo checks fire on the
        # disp_intv-th / save_intv-th batch rather than on the first one.
        for iters, sample in enumerate(loader, start=1):
            data = model_utils.parseData(args, sample, timer, split)
            net_input = model_utils.getInput(args, data)

            out_var = model(net_input)
            timer.updateTime('Forward')

            acc = eval_utils.calNormalAcc(
                data['tar'].data, out_var.data, data['m'].data)
            recorder.updateIter(split, acc.keys(), acc.values())

            if iters % disp_intv == 0:
                # 'batch' carries the number of batches in the loader.
                iter_opt = dict(
                    split=split,
                    epoch=epoch,
                    iters=iters,
                    batch=len(loader),
                    timer=timer,
                    recorder=recorder,
                )
                log.printItersSummary(iter_opt)
                # The scalar step is the within-epoch iteration counter.
                for tag, value in acc.items():
                    tfboard.tensorboard_scalar(tf_writer, tag, value, iters)

            if iters % save_intv == 0:
                remapped = (out_var.data + 1) / 2
                mask_full = data['m'].data.expand_as(out_var.data)
                log.saveNormalResults(remapped * mask_full, split, epoch,
                                      iters)

    epoch_opt = dict(split=split, epoch=epoch, recorder=recorder)
    log.printEpochSummary(epoch_opt)