Example #1
def inference(model_outputs, activities, sequence_ids, args):
    # Convert (T, B, num_classes) scores into per-frame class probabilities
    model_output_probs = torch.nn.functional.softmax(model_outputs, dim=-1)
    model_output_probs = model_output_probs.detach().cpu().numpy()
    batch_earley_pred_labels = list()
    batch_tokens = list()
    batch_seg_pos = list()
    for batch_i in range(model_outputs.size(1)):
        grammar_file = os.path.join(args.paths.grammar_root,
                                    activities[batch_i] + '.pcfg')
        grammar = grammarutils.read_grammar(grammar_file,
                                            index=True,
                                            mapping=args.metadata.action_index)
        gen_earley_parser = GEP.GeneralizedEarley(grammar, args.prior)
        best_string, prob = gen_earley_parser.parse(
            model_output_probs[:, batch_i, :])
        # print([int(s) for s in best_string.split()], "{:.2e}".format(decimal.Decimal(prob)))

        # Back trace to get labels of the entire sequence
        earley_pred_labels, tokens, seg_pos = gen_earley_parser.compute_labels()
        batch_earley_pred_labels.append(earley_pred_labels)
        batch_tokens.append(tokens)
        batch_seg_pos.append(seg_pos)

    # Frame-wise argmax of the raw network output (grammar-free baseline)
    _, nn_pred_labels = torch.max(model_outputs, dim=2)

    return nn_pred_labels, batch_earley_pred_labels, batch_tokens, batch_seg_pos
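
A minimal driver for this batched variant might look like the sketch below. The (T, B, C) shape, the activity names, and the args object are placeholders for illustration; GEP and grammarutils are the repository's own modules.

import torch

T, B, C = 50, 4, 10                       # frames, batch size, action classes (made up)
model_outputs = torch.randn(T, B, C)      # raw network scores; softmaxed inside inference()
nn_labels, gep_labels, tokens, seg_pos = inference(
    model_outputs,
    activities=['stacking_objects'] * B,  # hypothetical activity names
    sequence_ids=list(range(B)),
    args=args)                            # args: the repo's config namespace
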
Example #2
def inference(prob_mat, activity, sequence_id, args):
    grammar_file = os.path.join(args.paths.grammar_root, activity + '.pcfg')
    grammar = grammarutils.read_grammar(
        grammar_file, index=True, mapping=args.metadata.subactivity_index)
    gen_earley_parser = GEP.GeneralizedEarley(grammar)
    best_string, prob = gen_earley_parser.parse(prob_mat)
    # print([int(s) for s in best_string.split()], "{:.2e}".format(decimal.Decimal(prob)))

    # Back trace to get labels of the entire sequence
    earley_pred_labels, tokens, seg_pos = gen_earley_parser.compute_labels()
    # Frame-wise argmax baseline (no grammar), for comparison with the parse
    nn_pred_labels = np.argmax(prob_mat, axis=1)
    return nn_pred_labels, earley_pred_labels, tokens, seg_pos
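
The single-sequence variant takes a plain (T, num_classes) NumPy matrix. A hedged toy call, with made-up frame counts and class indices:

import numpy as np

prob_mat = np.full((30, 10), 1e-8)     # 30 frames, 10 classes, near-zero floor
prob_mat[:15, 0] = 0.9                 # frames 0-14: class 0 dominates
prob_mat[15:, 3] = 0.9                 # frames 15-29: class 3 dominates
nn_labels, gep_labels, tokens, seg_pos = inference(
    prob_mat, activity='stacking_objects', sequence_id=0, args=args)
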
Example #3
File: test.py  Project: tf369/GEP_PAMI
def visualize_grammar():
    paths = config.Paths()
    dataset_name = 'wnp'
    for pcfg in os.listdir(os.path.join(paths.tmp_root, 'grammar', dataset_name)):
        if not pcfg.endswith('.pcfg'):
            continue
        grammar_file = os.path.join(paths.tmp_root, 'grammar', dataset_name, pcfg)
        grammar = grammarutils.read_grammar(grammar_file, insert=False)
        dot_filename = os.path.join(paths.tmp_root, 'visualize', 'grammar', dataset_name, pcfg.replace('.pcfg', '.dot'))
        pdf_filename = os.path.join(paths.tmp_root, 'visualize', 'grammar', dataset_name, pcfg.replace('.pcfg', '.pdf'))
        os.makedirs(os.path.dirname(dot_filename), exist_ok=True)
        grammarutils.grammar_to_dot(grammar, dot_filename)
        os.system('dot -Tpdf {} -o {}'.format(dot_filename, pdf_filename))  # requires Graphviz on PATH
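
The last line shells out to Graphviz's dot binary. A sketch of a slightly safer equivalent that raises on failure instead of silently returning a nonzero exit code:

import subprocess

subprocess.run(['dot', '-Tpdf', dot_filename, '-o', pdf_filename], check=True)
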
Example #4
File: test.py  Project: tf369/GEP_PAMI
def test_time():
    paths = config.Paths()
    start_time = time.time()
    np.random.seed(int(start_time))
    classifier_output = np.random.rand(100000, 10)
    classifier_output = classifier_output / np.sum(classifier_output, axis=1)[:, None]  # Normalize to probability
    for pcfg in os.listdir(os.path.join(paths.tmp_root, 'grammar', 'cad')):
        if not pcfg.endswith('.pcfg'):
            continue
        grammar_file = os.path.join(paths.tmp_root, 'grammar', 'cad', pcfg)
        grammar = grammarutils.read_grammar(grammar_file, index=True, mapping=datasets.cad_metadata.subactivity_index)
        test_generalized_earley(grammar, classifier_output)
    print('Time elapsed: {}s'.format(time.time() - start_time))
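
The row normalization is what turns the random matrix into a valid per-frame distribution; a quick sanity check (sketch):

import numpy as np

assert np.allclose(classifier_output.sum(axis=1), 1.0)  # each row sums to 1
assert (classifier_output >= 0).all()                   # all entries non-negative
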
Example #5
File: test.py  Project: tf369/GEP_PAMI
def test_grammar():
    paths = config.Paths()
    for pcfg in os.listdir(os.path.join(paths.tmp_root, 'grammar', 'cad')):
        if not pcfg.endswith('.pcfg'):
            continue
        grammar_file = os.path.join(paths.tmp_root, 'grammar', 'cad', pcfg)
        grammar = grammarutils.read_grammar(grammar_file, index=True, mapping=datasets.cad_metadata.subactivity_index)
        corpus_file = os.path.join(paths.tmp_root, 'corpus', 'cad', pcfg.replace('pcfg', 'txt'))
        earley_parser = nltk.EarleyChartParser(grammar, trace=0)
        with open(corpus_file, 'r') as f:
            for line in f:
                tokens = [str(datasets.cad_metadata.subactivity_index[token]) for token in line.strip(' *#\n').split(' ')]
                e_chart = earley_parser.chart_parse(tokens)
                print(e_chart.edges()[-1])
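
Printing the last chart edge is only a heuristic check. An explicit grammaticality test using NLTK's chart API (a sketch; Chart.parses yields complete parse trees rooted at the start symbol):

trees = list(e_chart.parses(grammar.start()))
print('parses found:', len(trees))  # > 0 iff the token sequence is in the language
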
Example #6
File: test.py  Project: tf369/GEP_PAMI
def test_valid():
    paths = config.Paths()
    grammar_file = os.path.join(paths.tmp_root, 'grammar', 'cad', 'stacking_objects.pcfg')

    # sentence = 'null reaching moving placing'
    # grammar = grammarutils.read_grammar(grammar_file, index=False)
    # test_earley(grammar, sentence.split())

    sentence = 'null reaching'
    tokens = sentence.split()
    grammar = grammarutils.read_grammar(grammar_file, index=True, mapping=datasets.cad_metadata.subactivity_index)
    seg_length = 15
    correct_prob = 0.8
    # Near-zero floor everywhere, then 0.8 on the ground-truth class per segment
    classifier_output = np.ones((seg_length * 2, 10)) * 1e-10
    classifier_output[:seg_length, datasets.cad_metadata.subactivity_index[tokens[0]]] = correct_prob
    classifier_output[seg_length:, datasets.cad_metadata.subactivity_index[tokens[1]]] = correct_prob

    # The remaining 0.2 goes to the neighboring class index, as a distractor
    classifier_output[:seg_length, datasets.cad_metadata.subactivity_index[tokens[0]] + 1] = 1 - correct_prob
    classifier_output[seg_length:, datasets.cad_metadata.subactivity_index[tokens[1]] + 1] = 1 - correct_prob
    test_generalized_earley(grammar, classifier_output)
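
This construction puts probability 0.8 on the ground-truth token at every frame, so a frame-wise argmax already recovers the ground truth; the parser should agree. A sketch of that check:

import numpy as np

expected = ([datasets.cad_metadata.subactivity_index['null']] * seg_length +
            [datasets.cad_metadata.subactivity_index['reaching']] * seg_length)
assert np.argmax(classifier_output, axis=1).tolist() == expected
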
Example #7
def validate(data_loader, detection_model, prediction_model, args):
    all_gt_frame_predictions = list()
    all_frame_predictions = list()
    all_nn_frame_predictions = list()

    task_acc_ratio = logutils.AverageMeter()
    task_macro_prec = logutils.AverageMeter()
    task_macro_rec = logutils.AverageMeter()
    task_macro_f1 = logutils.AverageMeter()
    task_acc_ratio_nn = logutils.AverageMeter()

    # switch to evaluate mode
    detection_model.eval()
    prediction_model.eval()

    for batch_idx, data_unit in enumerate(
            tqdm(data_loader, desc='GEP evaluation')):
        features_batch, labels_batch, activities, sequence_ids, total_lengths, obj_nums, ctc_labels, ctc_lengths, probs_batch, additional = data_unit

        # Prepend copies of the first frame so every frame has a full-length
        # input window for the prediction model
        padding = features_batch[0, :, :].repeat(args.using_pred_duration - 1,
                                                 1, 1)
        prediction_features = torch.cat((padding, features_batch), dim=0)
        prediction_output = prediction_model(prediction_features)
        detection_output = detection_model(features_batch)

        _, detection_labels = torch.max(detection_output, dim=-1)
        detection_labels = detection_labels.cpu().numpy()

        for batch_i in range(detection_output.size(1)):

            gt_all_pred_labels = labels_batch[1:total_lengths[batch_i],
                                              batch_i].cpu().numpy().tolist()
            _, nn_all_pred_labels = torch.max(
                prediction_output[:total_lengths[batch_i] - 1, batch_i, :],
                dim=-1)
            nn_all_pred_labels = nn_all_pred_labels.cpu().numpy().tolist()

            # Initialization of Earley Parser
            class_num = detection_output.shape[2]
            grammar_file = os.path.join(args.paths.grammar_root,
                                        activities[batch_i] + '.pcfg')
            grammar = grammarutils.read_grammar(grammar_file, index=True)
            gen_earley_parser = GEP.GeneralizedEarley(
                grammar, class_num, mapping=args.metadata.action_index)
            with open(
                    os.path.join(args.paths.prior_root,
                                 'duration_prior.json')) as f:
                duration_prior = json.load(f)

            record = dict()

            start_time = time.time()
            for frame in range(total_lengths[batch_i] -
                               args.using_pred_duration):
                nn_pred_labels = nn_all_pred_labels[frame:frame +
                                                    args.using_pred_duration]
                gt_pred_labels = gt_all_pred_labels[frame:frame +
                                                    args.using_pred_duration]
                update_length = len(nn_pred_labels)

                pred_labels = predict(gen_earley_parser,
                                      detection_output[frame, batch_i, :],
                                      duration_prior, record, frame, args)
                # gt = torch.ones(detection_output.size(2)) * 1e-5
                # gt[labels_batch[frame, batch_i]] = 1
                # gt = torch.log(gt / torch.sum(gt))
                # pred_labels = predict(gen_earley_parser, gt,
                #                       duration_prior, record, frame, args)
                # print(frame)
                # print('detection_labels', detection_labels[max(0, frame - 44) : frame + 1, batch_i].tolist())
                # print('gt_detect labels', labels_batch[max(0, frame - 44) :frame+1, batch_i].cpu().numpy().tolist())
                # print('gt_predic_labels', gt_pred_labels)
                # print('nn_predic_labels', nn_pred_labels)
                # print('xx_predic_labels', pred_labels)

                micro_prec = logutils.compute_accuracy(gt_pred_labels,
                                                       pred_labels)
                nn_micro_prec = logutils.compute_accuracy(
                    gt_pred_labels, nn_pred_labels)
                # Note: macro metrics here are computed on the NN baseline predictions
                macro_prec, macro_rec, macro_f1 = logutils.compute_accuracy(
                    gt_pred_labels, nn_pred_labels, metric='macro')
                task_acc_ratio.update(micro_prec, update_length)
                task_acc_ratio_nn.update(nn_micro_prec, update_length)
                task_macro_prec.update(macro_prec, update_length)
                task_macro_rec.update(macro_rec, update_length)
                task_macro_f1.update(macro_f1, update_length)

                all_gt_frame_predictions.extend(gt_pred_labels)
                all_frame_predictions.extend(pred_labels)
                all_nn_frame_predictions.extend(nn_pred_labels)

            print('Parse time for sequence: {:.2f}s'.format(time.time() - start_time))

        tqdm.write('Task {} {} Batch [{}/{}]\t'
                   'Acc {top1.val:.4f} ({top1.avg:.4f})\t'
                   'NN Acc {nn.val:.4f} ({nn.avg:.4f})\t'
                   'Prec {prec.val:.4f} ({prec.avg:.4f})\t'
                   'Recall {recall.val:.4f} ({recall.avg:.4f})\t'
                   'F1 {f1.val:.4f} ({f1.avg:.4f})'.format(
                       args.task,
                       'test',
                       batch_idx,
                       len(data_loader),
                       top1=task_acc_ratio,
                       nn=task_acc_ratio_nn,
                       prec=task_macro_prec,
                       recall=task_macro_rec,
                       f1=task_macro_f1))

    micro_prec = logutils.compute_accuracy(all_gt_frame_predictions,
                                           all_frame_predictions)
    nn_micro_prec = logutils.compute_accuracy(all_gt_frame_predictions,
                                              all_nn_frame_predictions)
    macro_prec, macro_recall, macro_fscore = logutils.compute_accuracy(
        all_gt_frame_predictions, all_nn_frame_predictions, metric='weighted')  # class-frequency-weighted averages
    tqdm.write('[Evaluation] Micro Prec: {}\t'
               'NN Micro Prec: {}\t'
               'Macro Precision: {}\t'
               'Macro Recall: {}\t'
               'Macro F-score: {}'.format(micro_prec, nn_micro_prec,
                                          macro_prec, macro_recall,
                                          macro_fscore))
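
logutils.compute_accuracy is the repository's own helper; under the assumption that it follows standard metric definitions, a roughly equivalent computation with scikit-learn would be:

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

micro = accuracy_score(all_gt_frame_predictions, all_frame_predictions)
prec, rec, f1, _ = precision_recall_fscore_support(
    all_gt_frame_predictions, all_nn_frame_predictions, average='weighted')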