コード例 #1
0
ファイル: train_stage2.py プロジェクト: wymGAKKI/saps
def prepareSave(args, data, pred_c, pred, recorder, log):
    """Gather tensors to save and log per-iteration training accuracies.

    Args:
        args: option object. Reads the ``s1_est_d``, ``s1_est_i`` and
            ``s2_est_n`` flags and ``batch``.
        data: batch dict. Reads ``'img'``, ``'mask'`` and ``'normal'``
            (plus ``'dirs'`` / ``'ints'`` when the matching flag is set) and
            writes ``'dir_err'`` / ``'int_err'`` back into it.
        pred_c: prediction dict whose ``'dirs'`` / ``'intens'`` entries are
            compared against ``data['dirs']`` / ``data['ints']``.
        pred: prediction dict whose ``'normal'`` entry is compared against
            ``data['normal']``.
        recorder: object whose ``updateIter('train', keys, values)`` is
            called with every accuracy dict.
        log: not used by this function.

    Returns:
        ``(results, recorder, nrow)``: the list of tensors to save, the
        ``recorder`` that was passed in, and the batch size (``img.shape[0]``)
        capped at 32.
    """
    img, mask = data['img'], data['mask']
    # (x + 1) / 2 presumably rescales normals from [-1, 1] for display -- confirm.
    saved = [img.data, mask.data, (data['normal'].data + 1) / 2]

    if args.s1_est_d:
        dir_acc, dir_err = eval_utils.calDirsAcc(
            data['dirs'].data, pred_c['dirs'].data, args.batch)
        data['dir_err'] = dir_err
        recorder.updateIter('train', dir_acc.keys(), dir_acc.values())

    if args.s1_est_i:
        int_acc, int_err = eval_utils.calIntsAcc(
            data['ints'].data, pred_c['intens'].data, args.batch)
        data['int_err'] = int_err
        recorder.updateIter('train', int_acc.keys(), int_acc.values())

    if args.s2_est_n:
        est_normal = pred['normal'].data
        normal_acc, err_maps = eval_utils.calNormalAcc(
            data['normal'].data, est_normal, mask.data)
        # Rescale the prediction, then multiply by the mask broadcast
        # (expand_as) to the prediction's shape.
        visible = (est_normal + 1) / 2 * mask.data.expand_as(est_normal)
        saved.extend([visible, err_maps['angular_map']])
        recorder.updateIter('train', normal_acc.keys(), normal_acc.values())

    return saved, recorder, min(img.shape[0], 32)
コード例 #2
0
def prepareRes(args, data, pred, recorder, log, split):
    """Record evaluation accuracies for ``split`` and build a summary tag.

    Args:
        args: option object. Reads the ``s1_est_d``, ``s1_est_i`` and
            ``s1_est_n`` flags, plus ``val_batch`` when ``split == 'val'``
            or ``test_batch`` otherwise.
        data: batch dict. Ground truth is read from ``'dirs'``, ``'ints'``,
            ``'normal'`` and ``'mask'`` as the flags require; ``'dir_err'``,
            ``'int_err'`` and ``'error_map'`` are written back into it.
        pred: prediction dict with ``'dirs'``, ``'intens'`` and ``'normal'``
            entries as the flags require.
        recorder: object whose ``updateIter(split, keys, values)`` receives
            every accuracy dict.
        log: not used by this function.
        split: split name forwarded to the recorder; ``'val'`` selects
            ``args.val_batch``.

    Returns:
        ``(recorder, iter_res, error)``: the recorder passed in, the list of
        headline metrics in D, I, N order (only the enabled ones), and a tag
        string such as ``'D_1.234-N_5.678-'`` (empty if nothing is enabled).
    """
    batch_size = args.test_batch if split != 'val' else args.val_batch
    metrics = []
    tags = []

    if args.s1_est_d:
        dir_acc, data['dir_err'] = eval_utils.calDirsAcc(
            data['dirs'].data, pred['dirs'].data, batch_size)
        recorder.updateIter(split, dir_acc.keys(), dir_acc.values())
        metrics.append(dir_acc['l_err_mean'])
        tags.append('D_%.3f-' % dir_acc['l_err_mean'])

    if args.s1_est_i:
        int_acc, data['int_err'] = eval_utils.calIntsAcc(
            data['ints'].data, pred['intens'].data, batch_size)
        recorder.updateIter(split, int_acc.keys(), int_acc.values())
        metrics.append(int_acc['ints_ratio'])
        tags.append('I_%.3f-' % int_acc['ints_ratio'])

    if args.s1_est_n:
        normal_acc, err_maps = eval_utils.calNormalAcc(
            data['normal'].data, pred['normal'].data, data['mask'].data)
        recorder.updateIter(split, normal_acc.keys(), normal_acc.values())
        metrics.append(normal_acc['n_err_mean'])
        tags.append('N_%.3f-' % normal_acc['n_err_mean'])
        data['error_map'] = err_maps['angular_map']

    return recorder, metrics, ''.join(tags)
コード例 #3
0
def prepareRes(args, data, pred_c, pred, random_loc, recorder, log, split):
    """Record evaluation accuracies, scoring normals on a cropped window.

    Args:
        args: option object. Reads the ``s1_est_d`` and ``s2_est_n`` flags,
            plus ``val_batch`` when ``split == 'val'`` or ``test_batch``
            otherwise.
        data: batch dict. ``'m'`` is always read; ``'dirs'`` and ``'n'`` are
            read as the flags require. ``'dir_err'`` and ``'error_map'`` are
            written back into it.
        pred_c: prediction dict whose ``'dirs'`` entry is compared against
            ``data['dirs']``.
        pred: prediction dict whose ``'n'`` entry is compared against the
            cropped ``data['n']``.
        random_loc: pair of centre indices ``(x, y)`` for axes 2 and 3 of
            ``data['n']`` / ``data['m']``.
        recorder: object whose ``updateIter(split, keys, values)`` receives
            every accuracy dict.
        log: not used by this function.
        split: split name forwarded to the recorder; ``'val'`` selects
            ``args.val_batch``.

    Returns:
        ``(recorder, iter_res, error)``: the recorder passed in, the list of
        enabled headline metrics (D then N), and a tag string such as
        ``'D_1.234-N_5.678-'`` (empty if nothing is enabled).
    """
    full_mask = data['m']
    batch_size = args.test_batch if split != 'val' else args.val_batch
    metrics = []
    tags = []

    if args.s1_est_d:
        dir_acc, data['dir_err'] = eval_utils.calDirsAcc(
            data['dirs'].data, pred_c['dirs'].data, batch_size)
        recorder.updateIter(split, dir_acc.keys(), dir_acc.values())
        metrics.append(dir_acc['l_err_mean'])
        tags.append('D_%.3f-' % dir_acc['l_err_mean'])

    if args.s2_est_n:
        x_center, y_center = random_loc
        # Window [centre - 8, centre + 8) on axes 2 and 3.
        # NOTE(review): a centre closer than 8 to the start of an axis gives a
        # negative slice start, which slicing counts from the END of the axis;
        # presumably the caller keeps random_loc away from borders -- confirm.
        x_span = slice(x_center - 8, x_center + 8)
        y_span = slice(y_center - 8, y_center + 8)
        gt_patch = data['n'][:, :, x_span, y_span]
        mask_patch = full_mask[:, :, x_span, y_span]
        normal_acc, err_maps = eval_utils.calNormalAcc(
            gt_patch.data, pred['n'].data, mask_patch.data)
        recorder.updateIter(split, normal_acc.keys(), normal_acc.values())
        metrics.append(normal_acc['n_err_mean'])
        tags.append('N_%.3f-' % normal_acc['n_err_mean'])
        data['error_map'] = err_maps['angular_map']

    return recorder, metrics, ''.join(tags)
コード例 #4
0
def prepareSave(args, data, pred, recorder, log):
    """Gather image, mask and ground-truth normal; optionally log direction accuracy.

    Args:
        args: option object. Reads the ``s1_est_d`` flag and ``batch``.
        data: batch dict. Reads ``'img'``, ``'m'`` and ``'n'`` (plus
            ``'dirs'`` when ``s1_est_d`` is set) and writes ``'dir_err'``
            back into it in that case.
        pred: prediction dict whose ``'dirs'`` entry is compared against
            ``data['dirs']``.
        recorder: object whose ``updateIter('train', keys, values)`` is
            called with the direction accuracy dict.
        log: not used by this function.

    Returns:
        ``(results, recorder, nrow)``: the three tensors to save, the
        ``recorder`` that was passed in, and ``img.shape[0]`` capped at 32.
    """
    img = data['img']
    # (x + 1) / 2 presumably rescales normals from [-1, 1] for display -- confirm.
    to_save = [img.data, data['m'].data, (data['n'].data + 1) / 2]

    if args.s1_est_d:
        dir_acc, dir_err = eval_utils.calDirsAcc(
            data['dirs'].data, pred['dirs'].data, args.batch)
        data['dir_err'] = dir_err
        recorder.updateIter('train', dir_acc.keys(), dir_acc.values())

    return to_save, recorder, min(img.shape[0], 32)
コード例 #5
0
def prepareRes(args, data, pred, recorder, log, split):
    """Record direction accuracy for ``split`` when ``args.s1_est_d`` is set.

    Args:
        args: option object. Reads ``s1_est_d``, plus ``val_batch`` when
            ``split == 'val'`` or ``test_batch`` otherwise.
        data: batch dict. ``'dirs'`` is read and ``'dir_err'`` is written
            back into it when the flag is set.
        pred: prediction dict whose ``'dirs'`` entry is compared against
            ``data['dirs']``.
        recorder: object whose ``updateIter(split, keys, values)`` receives
            the accuracy dict.
        log: not used by this function.
        split: split name forwarded to the recorder; ``'val'`` selects
            ``args.val_batch``.

    Returns:
        ``(recorder, iter_res, error)``: the recorder passed in, a list with
        the mean direction error (or empty), and its ``'D_%.3f-'`` tag (or
        ``''``).
    """
    batch_size = args.val_batch if split == 'val' else args.test_batch
    if not args.s1_est_d:
        return recorder, [], ''

    dir_acc, data['dir_err'] = eval_utils.calDirsAcc(
        data['dirs'].data, pred['dirs'].data, batch_size)
    recorder.updateIter(split, dir_acc.keys(), dir_acc.values())
    mean_err = dir_acc['l_err_mean']
    return recorder, [mean_err], 'D_%.3f-' % mean_err