Example #1
    def __init__(self, schema, lexicon, model_path, fact_check, decoding, timed_session=False, consecutive_entity=True, realizer=None):
        super(NeuralSystem, self).__init__()
        self.schema = schema
        self.lexicon = lexicon
        self.timed_session = timed_session
        self.consecutive_entity = consecutive_entity

        # Load arguments
        args_path = os.path.join(model_path, 'config.json')
        config = read_json(args_path)
        config['batch_size'] = 1
        config['gpu'] = 0  # Don't need GPU for batch_size=1
        config['decoding'] = decoding
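        # Re-hydrate the saved JSON config into an argparse.Namespace so the
        # rest of the code can treat it like parsed command-line arguments.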
        args = argparse.Namespace(**config)

        mappings_path = os.path.join(model_path, 'vocab.pkl')
        mappings = read_pickle(mappings_path)
        vocab = mappings['vocab']

        # TODO: different models have the same key now
        args.dropout = 0  # disable dropout at inference time
        logstats.add_args('model_args', args)
        model = build_model(schema, mappings, args)

        # Tensorflow config
        # ('config' is rebound here: the JSON dict above becomes a tf.ConfigProto.)
        if args.gpu == 0:
            print 'GPU is disabled'
            config = tf.ConfigProto(device_count={'GPU': 0})
        else:
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5, allow_growth=True)
            config = tf.ConfigProto(device_count={'GPU': 1}, gpu_options=gpu_options)

        # NOTE: need to close the session when done
        tf_session = tf.Session(config=config)
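        # initialize_all_variables() is the pre-TF-1.0 initializer; on TF >= 1.0
        # the equivalent call is tf.global_variables_initializer().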
        tf.initialize_all_variables().run(session=tf_session)

        # Load TF model parameters
        ckpt = tf.train.get_checkpoint_state(model_path + '-best')
        assert ckpt, 'No checkpoint found'
        assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'
        saver = tf.train.Saver()
        saver.restore(tf_session, ckpt.model_checkpoint_path)

        self.model_name = args.model
        if self.model_name == 'attn-copy-encdec':
            args.entity_target_form = 'graph'
            copy = True
        else:
            copy = False
        preprocessor = Preprocessor(schema, lexicon, args.entity_encoding_form,
                                    args.entity_decoding_form,
                                    args.entity_target_form, args.prepend)
        textint_map = TextIntMap(vocab, mappings['entity'], preprocessor)

        Env = namedtuple('Env', ['model', 'tf_session', 'preprocessor', 'vocab',
                                 'copy', 'textint_map', 'stop_symbol',
                                 'remove_symbols', 'max_len', 'evaluator',
                                 'prepend', 'consecutive_entity', 'realizer'])
        self.env = Env(model, tf_session, preprocessor, mappings['vocab'], copy,
                       textint_map,
                       stop_symbol=vocab.to_ind(markers.EOS),
                       remove_symbols=map(vocab.to_ind, (markers.EOS, markers.PAD)),
                       max_len=20,
                       evaluator=FactEvaluator() if fact_check else None,
                       prepend=args.prepend,
                       consecutive_entity=self.consecutive_entity,
                       realizer=realizer)
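
A minimal usage sketch for this constructor; the paths, the Lexicon arguments, and the decoding value are illustrative assumptions, not part of the example:

schema = Schema('data/schema.json')    # hypothetical path
lexicon = Lexicon(schema, False)       # hypothetical arguments
system = NeuralSystem(schema, lexicon, 'checkpoints/model',
                      fact_check=False, decoding=['sample'])
# ... run dialogue sessions ...
system.env.tf_session.close()  # per the NOTE above: close the session when done
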
Example #2
def compute_statistics(args, lexicon, schema, scenario_db, transcripts):
    stats_dir = os.path.dirname(args.stats_output)
    if stats_dir and not os.path.exists(stats_dir):
        os.makedirs(stats_dir)
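    # (On Python 3 this guard collapses to os.makedirs(stats_dir, exist_ok=True).)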

    stats = {}
    statsfile = open(args.stats_output, 'w')
    stats["total"] = total_stats = get_total_statistics(
        transcripts, scenario_db)
    print "Aggregated total dataset statistics"
    print_group_stats(total_stats)

    # LM
    if args.lm:
        import kenlm
        lm = kenlm.Model(args.lm)
    else:
        lm = None
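    # When supplied, kenlm.Model exposes lm.score(sentence), a log10 probability.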

    # Speech acts
    preprocessor = Preprocessor(schema, lexicon, 'canonical', 'canonical',
                                'canonical', False)
    strategy_stats = analyze_strategy(transcripts, scenario_db, preprocessor,
                                      args.text_output, lm)
    print_strategy_stats(strategy_stats)
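    # strategy_stats['speech_act'] is keyed by tuples of acts; keep only the
    # single-act entries and unwrap each tuple into a plain label.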
    stats["speech_act"] = {
        k[0]: v
        for k, v in strategy_stats['speech_act'].iteritems() if len(k) == 1
    }
    stats["kb_strategy"] = strategy_stats['kb_strategy']
    stats["dialog_stats"] = strategy_stats['dialog_stats']
    stats["lm_score"] = strategy_stats['lm_score']
    stats["correct"] = strategy_stats['correct']
    stats["entity_mention"] = strategy_stats['entity_mention']
    stats['multi_speech_act'] = strategy_stats['multi_speech_act']

    if args.plot_alpha_stats:
        plot_alpha_stats(strategy_stats["alpha_stats"], args.plot_alpha_stats)

    if args.plot_item_stats:
        plot_num_item_stats(strategy_stats["num_items_stats"],
                            args.plot_item_stats)
    json.dump(stats, statsfile)
    statsfile.close()
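
For reference, a self-contained sketch of the KenLM scoring wired in above; the model path is a placeholder for any trained ARPA or binary file:

import kenlm

lm = kenlm.Model('lm.binary')  # placeholder path
# score() returns the log10 probability of the sentence under the model
print lm.score('how many friends like hiking', bos=True, eos=True)
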
Example #3
    raw_chats = read_json(args.dialogue_transcripts)
    uuid_to_chat = {chat['uuid']: chat for chat in raw_chats}
    schema = Schema(args.schema_path)
    scenario_db = ScenarioDB.from_dict(schema, read_json(args.scenarios_path))
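    # 'filter' below appears to be a project-local helper rather than the
    # builtin, since raw_eval is data, not a predicate.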
    dialogue_ids = filter(raw_eval, uuid_to_chat)

    for eval_ in raw_eval:
        read_eval(eval_, question_scores, mask=dialogue_ids)

    if args.hist:
        hist(question_scores, args.outdir, partner=args.partner)

    if args.summary:
        summary = summarize(question_scores)
        write_json(summary, args.stats)

    if args.analyze:
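        # (schema was already built above; re-creating it here is redundant
        # but harmless.)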
        schema = Schema(args.schema_path)
        lexicon = Lexicon(schema,
                          False,
                          scenarios_json=args.scenarios_path,
                          stop_words=args.stop_words)
        preprocessor = Preprocessor(schema, lexicon, 'canonical', 'canonical',
                                    'canonical')
        analyze(question_scores, uuid_to_chat, preprocessor)

    # Visualize
    if args.html_output:
        visualize(args.viewer_mode, args.html_output, question_scores,
                  uuid_to_chat)
Example #4
            # (This snippet begins mid-loop: 'e' is an event dict and 'action'
            # is presumably e["action"].)
            agent = e["agent"]

            if action == "message":
                raw_tokens = re.findall(re_pattern, msg_data)
                lower_raw_tokens = [r.lower() for r in raw_tokens]
                _, candidate_annotation = lexicon.link_entity(
                    lower_raw_tokens,
                    return_entities=True,
                    agent=agent,
                    uuid=scenario_uuid)
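                # Each candidate annotation is (surface_span, (entity, type)).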

                for c in candidate_annotation:
                    # Entity, Span, Type
                    fout.write(c[1][0] + "\t" + c[0] + "\t" + c[1][1] + "\n")

    preprocessor = Preprocessor(schema, lexicon, 'canonical', 'canonical',
                                'canonical')
    for raw in examples:
        ex = Example.from_dict(None, raw)
        kbs = ex.scenario.kbs
        mentioned_entities = set()
        for i, event in enumerate(ex.events):
            if event.action == 'message':
                utterance = preprocessor.process_event(event, kbs[event.agent],
                                                       mentioned_entities)
                # Skip empty utterances
                if utterance:
                    utterance = utterance[0]
                    for token in utterance:
                        if is_entity(token):
                            span, entity = token
                            entity, type_ = entity
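
The token convention above (plain strings for words, (span, (entity, type)) tuples for linked entities) can be sketched standalone; this is_entity is an assumption about the helper, not its actual definition:

def is_entity(token):
    # Assumption: entity tokens are tuples, plain words are strings.
    return isinstance(token, tuple)

token = ('the bay area', ('bay_area', 'location'))  # hypothetical entity token
if is_entity(token):
    span, (entity, type_) = token
    print span, entity, type_
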
Example #5
File: main.py  Project: tigerneil/cocoa
    schema = Schema(model_args.schema_path, model_args.domain)
    scenario_db = ScenarioDB.from_dict(schema, read_json(args.scenarios_path))
    dataset = read_dataset(scenario_db, args)
    print 'Building lexicon...'
    start = time.time()
    lexicon = Lexicon(schema, args.learned_lex, stop_words=args.stop_words)
    print '%.2f s' % (time.time() - start)

    # Dataset
    use_kb = model_args.model != 'encdec'
    copy = model_args.model == 'attn-copy-encdec'
    if model_args.model == 'attn-copy-encdec':
        model_args.entity_target_form = 'graph'
    preprocessor = Preprocessor(schema, lexicon,
                                model_args.entity_encoding_form,
                                model_args.entity_decoding_form,
                                model_args.entity_target_form)
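    # The first three DataGenerator arguments appear to be the train/dev/test
    # splits; only the ones needed for the current mode are populated.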
    if args.test:
        model_args.dropout = 0
        data_generator = DataGenerator(None, None, dataset.test_examples,
                                       preprocessor, schema,
                                       model_args.num_items, mappings, use_kb,
                                       copy)
    else:
        data_generator = DataGenerator(dataset.train_examples,
                                       dataset.test_examples, None,
                                       preprocessor, schema,
                                       model_args.num_items, mappings, use_kb,
                                       copy)
    for d, n in data_generator.num_examples.iteritems():
        logstats.add('data', d, 'num_dialogues', n)
Example #6
        mappings = read_pickle(vocab_path)
        print 'Done [%fs]' % (time.time() - start)
    else:
        # Save config
        if not os.path.isdir(args.checkpoint):
            os.makedirs(args.checkpoint)
        config_path = os.path.join(args.checkpoint, 'config.json')
        write_json(vars(args), config_path)
        model_args = args
        mappings = None
        ckpt = None

    schema = Schema(model_args.schema_path, model_args.domain)
    scenario_db = ScenarioDB.from_dict(schema, read_json(args.scenarios_path))
    dataset = read_dataset(scenario_db, args)
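    # chain() feeds the train and test examples to count_words as one stream.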
    word_counts = Preprocessor.count_words(
        chain(dataset.train_examples, dataset.test_examples))
    print 'Building lexicon...'
    start = time.time()
    lexicon = Lexicon(schema, args.learned_lex, stop_words=args.stop_words)
    print '%.2f s' % (time.time() - start)

    # Dataset
    use_kb = model_args.model != 'encdec'
    copy = model_args.model == 'attn-copy-encdec'
    if model_args.model == 'attn-copy-encdec':
        model_args.entity_target_form = 'graph'
    preprocessor = Preprocessor(schema, lexicon,
                                model_args.entity_encoding_form,
                                model_args.entity_decoding_form,
                                model_args.entity_target_form,
                                model_args.prepend)
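
The config round-trip used across these examples (write_json(vars(args)) at save time in Example #6, argparse.Namespace(**config) at load time in Example #1) boils down to:

import argparse
import json

args = argparse.Namespace(model='encdec', dropout=0.5)
with open('config.json', 'w') as f:
    json.dump(vars(args), f)  # vars() turns the Namespace into a plain dict
with open('config.json') as f:
    restored = argparse.Namespace(**json.load(f))
assert restored.model == 'encdec'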