示例#1
0
def run_training(head_host, head_port, debug_out=None):
    """Worker entry point: serve a Seq2SeqTrainingService and register it with the head.

    @param head_host: hostname of the head
    @param head_port: head port number
    @param debug_out: path to the debugging output file (no debug stream is set if None)
    """
    # route debugging output to a file only when a path was given
    if debug_out is not None:
        debug_stream = file_stream(debug_out, mode='w')
        set_debug_stream(debug_stream)

    # bring up the worker server on a background thread
    log_info('Creating training server...')
    worker_server = ThreadPoolServer(service=Seq2SeqTrainingService, nbThreads=1)
    serving = Thread(target=worker_server.start)
    serving.start()

    local_host = socket.getfqdn()
    address_info = (local_host, worker_server.port, head_host, head_port)
    log_info('Worker server created at %s:%d. Connecting to head at %s:%d...' % address_info)

    # tell the head where this worker can be reached, then drop the connection
    head_conn = connect(head_host, head_port, config={'allow_pickle': True})
    head_conn.root.register_worker(local_host, worker_server.port)
    head_conn.close()
    log_info('Worker is registered with the head.')

    # block on the server thread; it keeps serving until the process is killed
    serving.join()
示例#2
0
    def save_to_file(self, model_fname):
        """Dump the whole ensemble into one file as a sequence of pickles.

        Written in this order: the class object, the configuration, a list of
        (settings, parameters) pairs for the individual generators and, if a
        classification filter is set, its settings followed by its parameters.

        @param model_fname: path of the file to write
        """
        # TODO support for lexicalizer

        log_info("Saving generator to %s..." % model_fname)
        proto = pickle.HIGHEST_PROTOCOL
        with file_stream(model_fname, 'wb', encoding=None) as out:
            pickle.dump(self.__class__, out, protocol=proto)
            pickle.dump(self.cfg, out, protocol=proto)

            dumped_gens = []
            for member in self.gens:
                member_settings = member.get_all_settings()
                member_params = member.get_model_params()
                # record in each generator's settings whether a filter is attached
                member_settings['classif_filter'] = self.classif_filter is not None
                dumped_gens.append((member_settings, member_params))
            pickle.dump(dumped_gens, out, protocol=proto)

            if self.classif_filter:
                filter_settings = self.classif_filter.get_all_settings()
                pickle.dump(filter_settings, out, protocol=proto)
                filter_params = self.classif_filter.get_model_params()
                pickle.dump(filter_params, out, protocol=proto)
示例#3
0
 def save_to_file(self, model_fname):
     """Pickle the classifier's class object, followed by all its settings, into a file.

     @param model_fname: path of the file to write
     """
     log_info("Saving classifier to %s..." % model_fname)
     proto = pickle.HIGHEST_PROTOCOL
     with file_stream(model_fname, 'wb', encoding=None) as out:
         # the class object goes first, the settings second
         pickle.dump(self.__class__, out, protocol=proto)
         settings = self.get_all_settings()
         pickle.dump(settings, out, protocol=proto)
示例#4
0
 def load_surface_forms(self, surface_forms_fname):
     """Load all proper name surface forms from a JSON file.

     The parsing below relies on the file mapping slot -> value -> list of
     "form<TAB>tag" strings. The forms are stored in self._sf_all (per slot and value),
     self._sf_by_tag (additionally per tag subset) and self._sf_by_formeme (additionally
     per compatible formeme).

     All values of the 'street' slot get a ' _' street number placeholder appended and
     are stored under the 'address' slot.

     @param surface_forms_fname: path to the surface forms file
     """
     log_info('Loading surface forms from %s...' % surface_forms_fname)
     with file_stream(surface_forms_fname) as fh:
         data = json.load(fh)
     for slot, values in data.iteritems():
         # Decide on the street -> address conversion once per slot. Doing this inside
         # the innermost loop renamed the slot after the very first surface form, so
         # only the first street value ever received the placeholder.
         is_street = slot == 'street'
         if is_street:
             slot = 'address'
         sf_all = {}
         sf_formeme = {}
         sf_tag = {}
         for orig_value, surface_forms in values.iteritems():
             # add street number placeholders to addresses
             value = orig_value + ' _' if is_street else orig_value
             for surface_form in surface_forms:
                 form, tag = surface_form.split("\t")
                 # store the value globally + for all possible tag subsets/formemes
                 sf_all.setdefault(value, []).append(form)
                 forms_by_tag = sf_tag.setdefault(value, {})
                 forms_by_formeme = sf_formeme.setdefault(value, {})
                 for tag_subset in self._get_tag_subsets(tag):
                     forms_by_tag.setdefault(tag_subset, []).append(form)
                 for formeme in self._get_compatible_formemes(tag):
                     forms_by_formeme.setdefault(formeme, []).append(form)
         self._sf_all[slot] = sf_all
         self._sf_by_formeme[slot] = sf_formeme
         self._sf_by_tag[slot] = sf_tag
示例#5
0
 def save_model(self, model_fname_pattern):
     """Pickle the RNNLM settings and then its parameters into a '.rnnlm' file.

     The target file name is derived from the given pattern by the substitution below.

     @param model_fname_pattern: file name from which the target file name is derived
     """
     # NOTE(review): the dots in this pattern are unescaped, so they match any character
     suffix_re = r'(.pickle)?(.gz)?$'
     rnnlm_fname = re.sub(suffix_re, '.rnnlm', model_fname_pattern)
     with file_stream(rnnlm_fname, 'wb', encoding=None) as out:
         settings = self.get_all_settings()
         pickle.dump(settings, out, pickle.HIGHEST_PROTOCOL)
         params = self.get_model_params()
         pickle.dump(params, out, pickle.HIGHEST_PROTOCOL)
示例#6
0
 def save_model(self, model_fname_pattern):
     """Pickle the word frequency data (self._word_freq) into a '.wfreq' file.

     A warning is logged if the word frequency data are empty.

     @param model_fname_pattern: file name from which the target file name is derived
     """
     model_is_empty = not self._word_freq
     if model_is_empty:
         # NOTE(review): the message announces that saving is skipped, but the code
         # below still writes the file -- confirm which of the two is intended
         # (a loader may expect the file to exist).
         log_warn('No lexicalizer model trained, skipping saving!')
     wfreq_fname = re.sub(r'(.pickle)?(.gz)?$', '.wfreq', model_fname_pattern)
     with file_stream(wfreq_fname, 'wb', encoding=None) as out:
         pickle.dump(self._word_freq, out, pickle.HIGHEST_PROTOCOL)
示例#7
0
def percrank_train(args):
    """Train a ranker as specified by command-line arguments and save the model to a file.

    Recognized options: -d (debug output file), -s (portion of the training data, a float),
    -c (candidate generator model, stored into the ranker configuration), -j (number of
    jobs; turns on parallel training), -w (working directory for parallel training),
    -e (experiment ID for parallel training), -r (random seed; ignored if empty).
    Four file names must follow: ranker configuration, training DAs, training trees and
    the target model file; otherwise the program exits with the module docstring as
    the exit message.

    @param args: the list of command-line arguments
    """
    opts, files = getopt(args, 'c:d:s:j:w:e:r:')

    # defaults, overridden by the options below
    data_portion = 1.0
    candgen_fname = None
    run_parallel = False
    num_jobs = 0
    exp_id = None
    work_dir = None

    for opt, arg in opts:
        if opt == '-r':
            if arg:
                rnd.seed(arg)
        elif opt == '-e':
            exp_id = arg
        elif opt == '-w':
            work_dir = arg
        elif opt == '-j':
            run_parallel = True
            num_jobs = int(arg)
        elif opt == '-c':
            candgen_fname = arg
        elif opt == '-s':
            data_portion = float(arg)
        elif opt == '-d':
            set_debug_stream(file_stream(arg, mode='w'))

    if len(files) != 4:
        sys.exit(__doc__)

    fname_rank_config, fname_train_das, fname_train_ttrees, fname_rank_model = files
    log_info('Training perceptron ranker...')

    rank_config = Config(fname_rank_config)
    if candgen_fname:
        rank_config['candgen_model'] = candgen_fname

    # select the ranker implementation according to the 'nn' configuration entry
    if not rank_config.get('nn'):
        ranker_class = PerceptronRanker
    else:
        from tgen.rank_nn import SimpleNNRanker, EmbNNRanker
        uses_embeddings = rank_config['nn'] in ['emb', 'emb_trees', 'emb_prev']
        ranker_class = EmbNNRanker if uses_embeddings else SimpleNNRanker

    log_info('Using %s for ranking' % ranker_class.__name__)

    if run_parallel:
        rank_config['jobs_number'] = num_jobs
        if work_dir is None:
            # default to the directory of the ranker configuration file
            work_dir = os.path.split(fname_rank_config)[0]
        ranker = ParallelRanker(rank_config, work_dir, exp_id, ranker_class)
    else:
        ranker = ranker_class(rank_config)

    ranker.train(fname_train_das, fname_train_ttrees, data_portion=data_portion)

    # avoid the "maximum recursion depth exceeded" error
    sys.setrecursionlimit(100000)
    ranker.save_to_file(fname_rank_model)
示例#8
0
 def load_model(self, model_fname_pattern):
     """Restore the RNNLM from a '.rnnlm' file: settings first, then the parameters.

     @param model_fname_pattern: file name from which the source file name is derived
     """
     rnnlm_fname = re.sub(r'(.pickle)?(.gz)?$', '.rnnlm', model_fname_pattern)
     with file_stream(rnnlm_fname, 'rb', encoding=None) as src:
         settings = pickle.load(src)
         self.load_all_settings(settings)
         # the network is initialized after the settings are loaded and before
         # the stored parameters are set
         self._init_neural_network()
         params = pickle.load(src)
         self.set_model_params(params)
示例#9
0
文件: run_tgen.py 项目: UFAL-DSG/tgen
def percrank_train(args):
    """Train a ranker as specified by command-line arguments and save the model to a file.

    Recognized options: -d (debug output file), -s (portion of the training data, a float),
    -c (candidate generator model, stored into the ranker configuration), -j (number of
    jobs; turns on parallel training), -w (working directory for parallel training),
    -e (experiment ID for parallel training), -r (random seed; ignored if empty).
    Four file names must follow: ranker configuration, training DAs, training trees and
    the target model file; otherwise the program exits with the module docstring as
    the exit message.

    @param args: the list of command-line arguments
    """
    opts, files = getopt(args, 'c:d:s:j:w:e:r:')

    # defaults, overridden by the options below
    data_portion = 1.0
    candgen_fname = None
    run_parallel = False
    num_jobs = 0
    exp_id = None
    work_dir = None

    for opt, arg in opts:
        if opt == '-r':
            if arg:
                rnd.seed(arg)
        elif opt == '-e':
            exp_id = arg
        elif opt == '-w':
            work_dir = arg
        elif opt == '-j':
            run_parallel = True
            num_jobs = int(arg)
        elif opt == '-c':
            candgen_fname = arg
        elif opt == '-s':
            data_portion = float(arg)
        elif opt == '-d':
            set_debug_stream(file_stream(arg, mode='w'))

    if len(files) != 4:
        sys.exit(__doc__)

    fname_rank_config, fname_train_das, fname_train_ttrees, fname_rank_model = files
    log_info('Training perceptron ranker...')

    rank_config = Config(fname_rank_config)
    if candgen_fname:
        rank_config['candgen_model'] = candgen_fname

    # select the ranker implementation according to the 'nn' configuration entry
    if not rank_config.get('nn'):
        ranker_class = PerceptronRanker
    else:
        from tgen.rank_nn import SimpleNNRanker, EmbNNRanker
        uses_embeddings = rank_config['nn'] in ['emb', 'emb_trees', 'emb_prev']
        ranker_class = EmbNNRanker if uses_embeddings else SimpleNNRanker

    log_info('Using %s for ranking' % ranker_class.__name__)

    if run_parallel:
        rank_config['jobs_number'] = num_jobs
        if work_dir is None:
            # default to the directory of the ranker configuration file
            work_dir = os.path.split(fname_rank_config)[0]
        ranker = ParallelRanker(rank_config, work_dir, exp_id, ranker_class)
    else:
        ranker = ranker_class(rank_config)

    ranker.train(fname_train_das, fname_train_ttrees, data_portion=data_portion)

    # avoid the "maximum recursion depth exceeded" error
    sys.setrecursionlimit(100000)
    ranker.save_to_file(fname_rank_model)
示例#10
0
    def save_to_file(self, lexicalizer_fname):
        """Pickle the lexicalizer settings into a file; unless the form selection model is
        a RandomFormSelect, have it save itself as well (passing it the same file name).

        @param lexicalizer_fname: path of the lexicalizer settings file to write
        """
        log_info("Saving lexicalizer to %s..." % lexicalizer_fname)
        with file_stream(lexicalizer_fname, 'wb', encoding=None) as out:
            settings = self.get_all_settings()
            pickle.dump(settings, out, protocol=pickle.HIGHEST_PROTOCOL)

        # RandomFormSelect is the only form selection type that is not saved separately
        if isinstance(self._form_select, RandomFormSelect):
            return
        self._form_select.save_model(lexicalizer_fname)
示例#11
0
 def load_from_file(reranker_fname):
     """Peek at the first pickled object to pick the model class, then let that class load
     the file.

     If the first object in the file is a class, it is used as the model type; otherwise
     RerankingClassifier is used.

     @param reranker_fname: path to the saved reranker model
     @return: whatever the selected class's load_from_file returns for this file
     """
     with file_stream(reranker_fname, 'rb', encoding=None) as src:
         first_obj = pickle.load(src)
         if isinstance(first_obj, type):
             from tgen.e2e.slot_error import E2EPatternClassifier
             loader_class = first_obj
         else:
             # default to classifier
             loader_class = RerankingClassifier
     return loader_class.load_from_file(reranker_fname)
示例#12
0
文件: run_tgen.py 项目: UFAL-DSG/tgen
def seq2seq_train(args):
    """Train a sequence-to-sequence generator as specified by command-line arguments and
    save the trained model to a file.

    With -j/--jobs, training is delegated to ParallelSeq2SeqTraining, otherwise a single
    Seq2SeqGen instance is trained.

    @param args: the list of command-line arguments (see the argument definitions below)
    """
    ap = ArgumentParser(prog=' '.join(sys.argv[0:2]))

    ap.add_argument('-s', '--train-size', type=float,
                    help='Portion of the training data to use (default: 1.0)', default=1.0)
    ap.add_argument('-d', '--debug-logfile', type=str, help='Debug output file name')
    ap.add_argument('-j', '--jobs', type=int, help='Number of parallel jobs to use')
    ap.add_argument('-w', '--work-dir', type=str, help='Main working directory for parallel jobs')
    ap.add_argument('-e', '--experiment-id', type=str,
                    help='Experiment ID for parallel jobs (used as job name prefix)')
    ap.add_argument('-r', '--random-seed', type=str,
                    help='Initial random seed (used as string).')
    ap.add_argument('-c', '--context-file', type=str,
                    help='Input ttree/text file with context utterances')
    ap.add_argument('-v', '--valid-data', type=str,
                    help='Validation data paths (2-3 comma-separated files: DAs, trees/sentences, contexts)')
    # the two halves of this help text used to be joined without a space after the comma
    ap.add_argument('-l', '--lexic-data', type=str,
                    help='Lexicalization data paths (1-2 comma-separated files: surface forms, ' +
                    'training lexic. instructions)')
    ap.add_argument('-t', '--tb-summary-dir', '--tensorboard-summary-dir', '--tensorboard', type=str,
                    help='Directory where Tensorboard summaries are saved during training')

    ap.add_argument('seq2seq_config_file', type=str, help='Seq2Seq generator configuration file')
    ap.add_argument('da_train_file', type=str, help='Input training DAs')
    ap.add_argument('tree_train_file', type=str, help='Input training trees/sentences')
    ap.add_argument('seq2seq_model_file', type=str,
                    help='File name where to save the trained Seq2Seq generator model')

    args = ap.parse_args(args)

    if args.debug_logfile:
        set_debug_stream(file_stream(args.debug_logfile, mode='w'))
    if args.random_seed:
        rnd.seed(args.random_seed)

    log_info('Training sequence-to-sequence generator...')

    config = Config(args.seq2seq_config_file)

    if args.tb_summary_dir:  # override Tensorboard setting
        config['tb_summary_dir'] = args.tb_summary_dir
    if args.jobs:  # parallelize when training
        config['jobs_number'] = args.jobs
        work_dir = args.work_dir
        if not work_dir:
            # default to the directory of the configuration file
            work_dir, _ = os.path.split(args.seq2seq_config_file)
        generator = ParallelSeq2SeqTraining(config, work_dir, args.experiment_id)
    else:  # just a single training instance
        generator = Seq2SeqGen(config)

    generator.train(args.da_train_file, args.tree_train_file,
                    data_portion=args.train_size, context_file=args.context_file,
                    validation_files=args.valid_data, lexic_files=args.lexic_data)

    # the recursion limit is raised before saving (presumably to avoid recursion depth
    # errors while the model is being stored -- TODO confirm)
    sys.setrecursionlimit(100000)
    generator.save_to_file(args.seq2seq_model_file)
示例#13
0
def seq2seq_train(args):
    """Train a sequence-to-sequence generator as specified by command-line arguments and
    save the trained model to a file.

    With -j/--jobs, training is delegated to ParallelSeq2SeqTraining, otherwise a single
    Seq2SeqGen instance is trained.

    @param args: the list of command-line arguments (see the argument definitions below)
    """
    ap = ArgumentParser(prog=' '.join(sys.argv[0:2]))

    ap.add_argument('-s', '--train-size', type=float,
                    help='Portion of the training data to use (default: 1.0)', default=1.0)
    ap.add_argument('-d', '--debug-logfile', type=str, help='Debug output file name')
    ap.add_argument('-j', '--jobs', type=int, help='Number of parallel jobs to use')
    ap.add_argument('-w', '--work-dir', type=str, help='Main working directory for parallel jobs')
    ap.add_argument('-e', '--experiment-id', type=str,
                    help='Experiment ID for parallel jobs (used as job name prefix)')
    ap.add_argument('-r', '--random-seed', type=str,
                    help='Initial random seed (used as string).')
    ap.add_argument('-c', '--context-file', type=str,
                    help='Input ttree/text file with context utterances')
    ap.add_argument('-v', '--valid-data', type=str,
                    help='Validation data paths (2-3 comma-separated files: DAs, trees/sentences, contexts)')
    # the two halves of this help text used to be joined without a space after the comma
    ap.add_argument('-l', '--lexic-data', type=str,
                    help='Lexicalization data paths (1-2 comma-separated files: surface forms, ' +
                    'training lexic. instructions)')
    ap.add_argument('-t', '--tb-summary-dir', '--tensorboard-summary-dir', '--tensorboard', type=str,
                    help='Directory where Tensorboard summaries are saved during training')

    ap.add_argument('seq2seq_config_file', type=str, help='Seq2Seq generator configuration file')
    ap.add_argument('da_train_file', type=str, help='Input training DAs')
    ap.add_argument('tree_train_file', type=str, help='Input training trees/sentences')
    ap.add_argument('seq2seq_model_file', type=str,
                    help='File name where to save the trained Seq2Seq generator model')

    args = ap.parse_args(args)

    if args.debug_logfile:
        set_debug_stream(file_stream(args.debug_logfile, mode='w'))
    if args.random_seed:
        rnd.seed(args.random_seed)

    log_info('Training sequence-to-sequence generator...')

    config = Config(args.seq2seq_config_file)

    if args.tb_summary_dir:  # override Tensorboard setting
        config['tb_summary_dir'] = args.tb_summary_dir
    if args.jobs:  # parallelize when training
        config['jobs_number'] = args.jobs
        work_dir = args.work_dir
        if not work_dir:
            # default to the directory of the configuration file
            work_dir, _ = os.path.split(args.seq2seq_config_file)
        generator = ParallelSeq2SeqTraining(config, work_dir, args.experiment_id)
    else:  # just a single training instance
        generator = Seq2SeqGen(config)

    generator.train(args.da_train_file, args.tree_train_file,
                    data_portion=args.train_size, context_file=args.context_file,
                    validation_files=args.valid_data, lexic_files=args.lexic_data)

    # the recursion limit is raised before saving (presumably to avoid recursion depth
    # errors while the model is being stored -- TODO confirm)
    sys.setrecursionlimit(100000)
    generator.save_to_file(args.seq2seq_model_file)
示例#14
0
def process_file(args):
    """Compose a Treex scenario according to the generator mode and run it through the
    `treex` command-line tool.

    The mode defaults to 'trees'; if args.model is given, it is read from the 'cfg' entry
    of the first object pickled in that file ('mode' and/or 'use_tokens' keys). The scenario
    reads YAML from args.input_file and writes SGM output to args.output_file.

    @param args: parsed arguments; the attributes read here are model, surface_forms, \
        grammatemes, input_das, input_file, output_file and selector
    """

    # find out generator mode
    mode = 'trees'
    if args.model:
        with file_stream(args.model, 'rb', encoding=None) as fh:
            data = pickle.load(fh)
            if 'mode' in data['cfg']:
                mode = data['cfg']['mode']
            # 'use_tokens', if present, takes precedence over 'mode'
            if 'use_tokens' in data['cfg']:
                mode = 'tokens' if data['cfg']['use_tokens'] else 'trees'

    # compose scenario

    scen = ['T2A::CS::VocalizePrepos',  # this surface stuff is needed for both tokens and tagged lemmas
            'T2A::CS::CapitalizeSentStart',
            'A2W::ConcatenateTokens',
            'A2W::CS::DetokenizeUsingRules',
            'A2W::CS::RemoveRepeatedTokens', ]
    if mode == 'tokens':
        # prepend blocks that copy the tree and set each node's form to its lemma
        scen = ['T2A::CopyTtree',
                'Util::Eval anode="$.set_form($.lemma);"', ] + scen
    elif mode == 'tagged_lemmas':
        # the Perl snippet takes descendants in pairs: the second node's lemma becomes the
        # first node's tag and the second node is removed; it is single-quoted since the
        # whole command goes through the shell (see the call at the end)
        scen = ['T2A::CopyTtree',
                'Util::Eval atree=\'my @as=$.get_descendants({ordered=>1}); ' +
                'while (my ($l, $t) = splice @as, 0, 2){ next if (!defined($t)); $l->set_tag($t->lemma); $t->remove(); }\'',
                'Misc::TagToMorphcat',
                'T2A::CS::GenerateWordforms', ] + scen
    else:
        # get the canonical CS generation scenario
        # (this replaces the surface blocks set above by the output of `treex -d`,
        # skipping empty lines and lines starting with '#')
        scen_dump_ps = subprocess.Popen('treex -d Scen::Synthesis::CS', shell=True, stdout=subprocess.PIPE)
        scen, _ = scen_dump_ps.communicate()
        scen = [block for block in scen.split("\n") if block and not block.startswith('#')]

        # insert our custom morphological processing block into it
        # (before the first block matching 'generateword'; next() raises StopIteration
        # if there is no such block)
        pos = next(i for i, block in enumerate(scen) if re.search(r'generateword', block, re.IGNORECASE))
        scen.insert(pos, 'Misc::GenerateWordformsFromJSON surface_forms="%s"' % args.surface_forms)

        # add grammatemes and clause number processing
        scen = ['Util::Eval tnode="$.set_functor(\\"???\\"); ' +
                '$.set_t_lemma(\\"\\") if (!defined($.t_lemma)); ' +
                '$.set_formeme(\\"x\\") if (!defined($.formeme));"',
                'T2T::AssignDefaultGrammatemes grammateme_file="%s" da_file="%s"' %
                (args.grammatemes, args.input_das),
                'Misc::RestoreCoordNodes',
                'T2T::CS2CS::MarkClauseHeads',
                'T2T::SetClauseNumber'] + scen

    # wrap the scenario with the reader at the start and the writers at the end
    scen = ['Read::YAML from="%s"' % args.input_file] + scen
    scen += ['Write::Treex',
             'Util::Eval document="$.set_path(\\"\\"); $.set_file_stem(\\"test\\");"',
             'Write::SgmMTEval to="%s" set_id=CsRest sys_id=TGEN add_header=tstset' % args.output_file]

    # NOTE(review): the command is a single string interpreted by the shell, with file
    # names from args interpolated without escaping
    subprocess.call(('treex -Lcs -S%s ' % args.selector) + " ".join(scen), shell=True)
示例#15
0
    def load_from_file(lexicalizer_fname):
        """Create a Lexicalizer from pickled settings; unless its form selection model is a
        RandomFormSelect, have that model load itself as well (passing it the same file name).

        @param lexicalizer_fname: path to the saved lexicalizer settings
        @return: the loaded Lexicalizer instance
        """
        log_info("Loading lexicalizer from %s..." % lexicalizer_fname)
        with file_stream(lexicalizer_fname, 'rb', encoding=None) as src:
            stored = pickle.load(src)
            lexicalizer = Lexicalizer(cfg=stored['cfg'])
            lexicalizer.__dict__.update(stored)
            # after the update, _form_select holds a callable (presumably the form
            # selection class); calling it with the configuration yields the instance
            make_form_select = lexicalizer._form_select
            lexicalizer._form_select = make_form_select(stored['cfg'])

        needs_own_model = not isinstance(lexicalizer._form_select, RandomFormSelect)
        if needs_own_model:
            lexicalizer._form_select.load_model(lexicalizer_fname)
        return lexicalizer
示例#16
0
    def load_from_file(model_fname):
        """Load an E2EPatternClassifier from a file holding its pickled class object
        followed by its pickled settings.

        @param model_fname: path to the saved classifier
        @return: the loaded E2EPatternClassifier instance
        @raise ValueError: if the first object in the file is not the E2EPatternClassifier class
        """
        log_info("Loading classifier from %s..." % model_fname)

        with file_stream(model_fname, 'rb', encoding=None) as src:
            stored_type = pickle.load(src)
            if stored_type != E2EPatternClassifier:
                message = 'Wrong type identifier in file %s' % model_fname
                raise ValueError(message)
            settings = pickle.load(src)

        classifier = E2EPatternClassifier(settings)
        # load the trained settings
        classifier.__dict__.update(settings)
        return classifier
示例#17
0
def process_file(args):
    """Compose a Treex scenario according to the generator mode and run it through the
    `treex` command-line tool.

    The mode defaults to 'trees'; if args.model is given, it is read from the 'cfg' entry
    of the first object pickled in that file ('mode' and/or 'use_tokens' keys). The scenario
    reads YAML from args.input_file and writes SGM output to args.output_file.

    @param args: parsed arguments; the attributes read here are model, grammatemes, \
        input_das, input_file, output_file and selector
    """

    # find out generator mode
    mode = 'trees'
    if args.model:
        with file_stream(args.model, 'rb', encoding=None) as fh:
            data = pickle.load(fh)
            if 'mode' in data['cfg']:
                mode = data['cfg']['mode']
            # 'use_tokens', if present, takes precedence over 'mode'
            if 'use_tokens' in data['cfg']:
                mode = 'tokens' if data['cfg']['use_tokens'] else 'trees'

    # compose scenario

    scen = [
        'T2A::CS::VocalizePrepos',  # this surface stuff is needed for both tokens and tagged lemmas
        'T2A::CS::CapitalizeSentStart',
        'A2W::ConcatenateTokens',
        'A2W::CS::DetokenizeUsingRules',
        'A2W::CS::RemoveRepeatedTokens',
    ]
    if mode == 'tokens':
        scen = [
            'T2A::CopyTtree',
            'Util::Eval anode="$.set_form($.lemma);"',
        ] + scen
    elif mode == 'tagged_lemmas':
        # The Perl code must be single-quoted: the command is run through the shell, and
        # inside double quotes the shell would expand $l and $t to empty strings before
        # Treex ever sees them. The lemma is read with a method call ($t->lemma);
        # "$t.lemma" would be a Perl string concatenation. Pairs with a missing second
        # node are skipped.
        scen = [
            'T2A::CopyTtree',
            'Util::Eval atree=\'my @as=$.get_descendants({ordered=>1}); ' +
            'while (my ($l, $t) = splice @as, 0, 2){ next if (!defined($t)); ' +
            '$l->set_tag($t->lemma); $t->remove(); }\'',
            'T2A::CS::GenerateWordforms',
        ] + scen
    else:
        # full tree mode: the surface blocks above are replaced by the synthesis scenario
        scen = [
            'T2T::AssignDefaultGrammatemes grammateme_file="%s" da_file="%s"' %
            (args.grammatemes, args.input_das),
            'Scen::Synthesis::CS',
        ]

    # wrap the scenario with the reader at the start and the writers at the end
    scen = ['Read::YAML from="%s"' % args.input_file] + scen
    scen += [
        'Write::Treex',
        'Util::Eval document="$.set_path(\\"\\"); $.set_file_stem(\\"test\\");"',
        'Write::SgmMTEval to="%s" set_id=CsRest sys_id=TGEN add_header=tstset'
        % args.output_file
    ]

    # NOTE(review): the command is a single string interpreted by the shell, with file
    # names from args interpolated without escaping
    subprocess.call(('treex -Lcs -S%s ' % args.selector) + " ".join(scen),
                    shell=True)
示例#18
0
def seq2seq_train(args):
    """Train a sequence-to-sequence generator as specified by command-line arguments and
    save the trained model to a file.

    With -j/--jobs, training is delegated to ParallelSeq2SeqTraining, otherwise a single
    Seq2SeqGen instance is trained.

    @param args: the list of command-line arguments (see the argument definitions below)
    """
    ap = ArgumentParser()

    ap.add_argument('-s', '--train-size', type=float,
                    help='Portion of the training data to use (default: 1.0)', default=1.0)
    ap.add_argument('-d', '--debug-logfile', type=str, help='Debug output file name')
    ap.add_argument('-j', '--jobs', type=int, help='Number of parallel jobs to use')
    ap.add_argument('-w', '--work-dir', type=str, help='Main working directory for parallel jobs')
    ap.add_argument('-e', '--experiment-id', type=str,
                    help='Experiment ID for parallel jobs (used as job name prefix)')
    ap.add_argument('-r', '--random-seed', type=str,
                    help='Initial random seed (used as string).')
    ap.add_argument('-c', '--context-file', type=str,
                    help='Input ttree/text file with context utterances')
    ap.add_argument('-v', '--valid-data', type=str,
                    help='Validation data paths (2-3 comma-separated files: DAs, trees/sentences, contexts)')

    ap.add_argument('seq2seq_config_file', type=str, help='Seq2Seq generator configuration file')
    ap.add_argument('da_train_file', type=str, help='Input training DAs')
    ap.add_argument('tree_train_file', type=str, help='Input training trees/sentences')
    ap.add_argument('seq2seq_model_file', type=str,
                    help='File name where to save the trained Seq2Seq generator model')

    args = ap.parse_args(args)

    if args.debug_logfile:
        set_debug_stream(file_stream(args.debug_logfile, mode='w'))
    if args.random_seed:
        # Seed exactly once: seed() returns None, so wrapping the call in another seed()
        # re-seeded the generator with None and discarded the requested seed.
        rnd.seed(args.random_seed)

    log_info('Training sequence-to-sequence generator...')

    config = Config(args.seq2seq_config_file)
    if args.jobs:  # parallelize when training
        config['jobs_number'] = args.jobs
        work_dir = args.work_dir
        if not work_dir:
            # default to the directory of the configuration file
            work_dir, _ = os.path.split(args.seq2seq_config_file)
        generator = ParallelSeq2SeqTraining(config, work_dir, args.experiment_id)
    else:  # just a single training instance
        generator = Seq2SeqGen(config)

    generator.train(args.da_train_file, args.tree_train_file,
                    data_portion=args.train_size, context_file=args.context_file,
                    validation_files=args.valid_data)

    # the recursion limit is raised before saving (presumably to avoid recursion depth
    # errors while the model is being stored -- TODO confirm)
    sys.setrecursionlimit(100000)
    generator.save_to_file(args.seq2seq_model_file)
示例#19
0
    def save_to_file(self, model_fname):
        """Save the classifier into two files: the pickled settings and the TensorFlow session.

        If a checkpoint path is set, the checkpoint is restored and its directory is removed
        before the session is saved.

        @param model_fname: file name for the settings file; the TF session is stored under
            a name derived from it, with a '.tfsess' extension
        """
        model_fname = self.tf_check_filename(model_fname)
        log_info("Saving classifier to %s..." % model_fname)
        with file_stream(model_fname, 'wb', encoding=None) as out:
            settings = self.get_all_settings()
            pickle.dump(settings, out, protocol=pickle.HIGHEST_PROTOCOL)

        session_fname = re.sub(r'(.pickle)?(.gz)?$', '.tfsess', model_fname)
        has_checkpoint = hasattr(self, 'checkpoint_path') and self.checkpoint_path
        if has_checkpoint:
            self.restore_checkpoint()
            checkpoint_dir = os.path.dirname(self.checkpoint_path)
            shutil.rmtree(checkpoint_dir)
        self.saver.save(self.session, session_fname)
示例#20
0
def candgen_train(args):
    """Train a RandomCandidateGenerator as specified by command-line arguments and save it.

    Recognized options: -p (prune threshold, an integer), -d (debug output file),
    -l (use parent lemmas), -n (use node limits), -c TYPE[:LIMIT] (compatible DAIs type,
    optionally with an integer limit), -s (use compatible slots), -t (tree classifier
    configuration file). Three file names must follow: training DAs, training trees and
    the target model file; otherwise the program exits with an error message.

    @param args: the list of command-line arguments
    """
    opts, files = getopt(args, 'p:lnc:sd:t:')

    # candidate generator settings with their defaults, overridden by the options below
    candgen_cfg = {
        'prune_threshold': 1,
        'parent_lemmas': False,
        'node_limits': False,
        'compatible_dais_type': None,
        'compatible_dais_limit': None,
        'compatible_slots': False,
        'tree_classif': False,
    }

    for opt, arg in opts:
        if opt == '-d':
            set_debug_stream(file_stream(arg, mode='w'))
        elif opt == '-p':
            candgen_cfg['prune_threshold'] = int(arg)
        elif opt == '-l':
            candgen_cfg['parent_lemmas'] = True
        elif opt == '-n':
            candgen_cfg['node_limits'] = True
        elif opt == '-s':
            candgen_cfg['compatible_slots'] = True
        elif opt == '-t':
            candgen_cfg['tree_classif'] = Config(arg)
        elif opt == '-c':
            # the limit is whatever follows the first colon, if there is any
            dais_type, colon, dais_limit = arg.partition(':')
            candgen_cfg['compatible_dais_type'] = dais_type
            if colon:
                candgen_cfg['compatible_dais_limit'] = int(dais_limit)

    if len(files) != 3:
        sys.exit("Invalid arguments.\n" + __doc__)
    fname_da_train, fname_ttrees_train, fname_cand_model = files

    log_info('Training candidate generator...')
    candgen = RandomCandidateGenerator(candgen_cfg)
    candgen.train(fname_da_train, fname_ttrees_train)
    candgen.save_to_file(fname_cand_model)
示例#21
0
文件: run_tgen.py 项目: UFAL-DSG/tgen
def candgen_train(args):
    """Train a RandomCandidateGenerator as specified by command-line arguments and save it.

    Recognized options: -p (prune threshold, an integer), -d (debug output file),
    -l (use parent lemmas), -n (use node limits), -c TYPE[:LIMIT] (compatible DAIs type,
    optionally with an integer limit), -s (use compatible slots), -t (tree classifier
    configuration file). Three file names must follow: training DAs, training trees and
    the target model file; otherwise the program exits with an error message.

    @param args: the list of command-line arguments
    """
    opts, files = getopt(args, 'p:lnc:sd:t:')

    # candidate generator settings with their defaults, overridden by the options below
    candgen_cfg = {
        'prune_threshold': 1,
        'parent_lemmas': False,
        'node_limits': False,
        'compatible_dais_type': None,
        'compatible_dais_limit': None,
        'compatible_slots': False,
        'tree_classif': False,
    }

    for opt, arg in opts:
        if opt == '-d':
            set_debug_stream(file_stream(arg, mode='w'))
        elif opt == '-p':
            candgen_cfg['prune_threshold'] = int(arg)
        elif opt == '-l':
            candgen_cfg['parent_lemmas'] = True
        elif opt == '-n':
            candgen_cfg['node_limits'] = True
        elif opt == '-s':
            candgen_cfg['compatible_slots'] = True
        elif opt == '-t':
            candgen_cfg['tree_classif'] = Config(arg)
        elif opt == '-c':
            # the limit is whatever follows the first colon, if there is any
            dais_type, colon, dais_limit = arg.partition(':')
            candgen_cfg['compatible_dais_type'] = dais_type
            if colon:
                candgen_cfg['compatible_dais_limit'] = int(dais_limit)

    if len(files) != 3:
        sys.exit("Invalid arguments.\n" + __doc__)
    fname_da_train, fname_ttrees_train, fname_cand_model = files

    log_info('Training candidate generator...')
    candgen = RandomCandidateGenerator(candgen_cfg)
    candgen.train(fname_da_train, fname_ttrees_train)
    candgen.save_to_file(fname_cand_model)
示例#22
0
def run_worker(head_host, head_port, debug_out=None):
    """Worker entry point: serve a RankerTrainingService and register it with the head.

    @param head_host: hostname of the head
    @param head_port: head port number
    @param debug_out: path to the debugging output file (no debug stream is set if None)
    """
    # route debugging output to a file only when a path was given
    if debug_out is not None:
        debug_stream = file_stream(debug_out, mode='w')
        set_debug_stream(debug_stream)

    # bring up the worker server on a background thread
    log_info('Creating worker server...')
    worker_server = ThreadPoolServer(service=RankerTrainingService, nbThreads=1)
    serving = Thread(target=worker_server.start)
    serving.start()

    local_host = socket.getfqdn()
    address_info = (local_host, worker_server.port, head_host, head_port)
    log_info('Worker server created at %s:%d. Connecting to head at %s:%d...' % address_info)

    # tell the head where this worker can be reached, then drop the connection
    head_conn = connect(head_host, head_port, config={'allow_pickle': True})
    head_conn.root.register_worker(local_host, worker_server.port)
    head_conn.close()
    log_info('Worker is registered with the head.')

    # block on the server thread; it keeps serving until the process is killed
    serving.join()
示例#23
0
    def load_from_file(model_fname):
        """Load the reranker from two files: the pickled settings and the TensorFlow session.

        Settings without a 'version' entry are treated as version 1.

        @param model_fname: file name of the settings file; the TF session is read from a
            file whose name is derived from it, with a '.tfsess' extension
        @return: the loaded RerankingClassifier instance
        """
        log_info("Loading reranker from %s..." % model_fname)
        with file_stream(model_fname, 'rb', encoding=None) as src:
            stored = pickle.load(src)
            stored.setdefault('version', 1)
            reranker = RerankingClassifier(cfg=stored['cfg'])
            reranker.load_all_settings(stored)

        # re-build TF graph and restore the TF session
        session_fname = re.sub(r'(.pickle)?(.gz)?$', '.tfsess', model_fname)
        session_fname = os.path.abspath(session_fname)
        reranker._init_neural_network()
        reranker.saver.restore(reranker.session, session_fname)
        return reranker
示例#24
0
文件: lexicalize.py 项目: ml-lab/tgen
 def load_surface_forms(self, surface_forms_fname):
     """Load all proper name surface forms from a JSON file.

     The parsing below relies on the file mapping slot -> value -> list of
     "lemma<TAB>form<TAB>tag" strings. Forms are stored per slot and value in
     self._sf_all and, additionally per tag subset, in self._sf_by_tag; lemmas are stored
     per compatible formeme in self._sf_by_formeme; self._lemma_for_sf maps forms back
     to lemmas.

     The 'street' slot is stored as 'address'; values, lemmas and forms of the 'address'
     slot get a ' _' street number placeholder appended.

     @param surface_forms_fname: path to the surface forms file
     """
     log_info('Loading surface forms from %s...' % surface_forms_fname)
     with file_stream(surface_forms_fname) as src:
         all_slots = json.load(src)
     for slot, values in all_slots.iteritems():
         # this is domain-specific: street names -> street name + number
         if slot == 'street':
             slot = 'address'  # TODO change this in the surface form file
         is_address = slot == 'address'
         forms_by_value = {}
         lemmas_by_formeme = {}
         forms_by_tag = {}
         lemma_for_sf = {}
         for orig_value in values.keys():
             # add street number placeholders to addresses
             # TODO change this in the surface form file (and drop orig_value)
             value = orig_value + ' _' if is_address else orig_value
             for surface_form in values[orig_value]:
                 lemma, form, tag = surface_form.split("\t")
                 if is_address:
                     lemma += ' _'
                     form += ' _'
                 # store the value globally + for all possible tag subsets/formemes
                 # store lemmas for formemes, forms for tags/global
                 forms_by_value.setdefault(value, []).append(form)
                 value_tags = forms_by_tag.setdefault(value, {})
                 value_formemes = lemmas_by_formeme.setdefault(value, {})
                 for tag_subset in self._get_tag_subsets(tag):
                     value_tags.setdefault(tag_subset, []).append(form)
                 for formeme in self._get_compatible_formemes(tag):
                     value_formemes.setdefault(formeme, []).append(lemma)
                 # store lemma for form (for lexicalizing training sentences with trees)
                 lemma_for_sf[form] = lemma
         self._sf_all[slot] = forms_by_value
         self._sf_by_formeme[slot] = lemmas_by_formeme
         self._sf_by_tag[slot] = forms_by_tag
         self._lemma_for_sf[slot] = lemma_for_sf
示例#25
0
    def load_from_file(model_fname):
        """Load the whole ensemble from a file and build it.

        The file is read as a sequence of pickles: the class object (checked against
        Seq2SeqEnsemble), the configuration, the dump of the individual generators and,
        if the configuration contains 'classif_filter', the reranker settings and
        parameters.

        @param model_fname: path to the saved ensemble
        @return: the built Seq2SeqEnsemble instance
        @raise ValueError: if the first object in the file is not the Seq2SeqEnsemble class
        """
        # TODO support for lexicalizer

        log_info("Loading ensemble generator from %s..." % model_fname)

        with file_stream(model_fname, 'rb', encoding=None) as src:
            stored_type = pickle.load(src)
            if stored_type != Seq2SeqEnsemble:
                message = 'Wrong type identifier in file %s' % model_fname
                raise ValueError(message)
            cfg = pickle.load(src)
            ensemble = Seq2SeqEnsemble(cfg)
            gens_dump = pickle.load(src)
            # reranker data are only read if the configuration mentions a filter
            rerank_settings = None
            rerank_params = None
            if 'classif_filter' in cfg:
                rerank_settings = pickle.load(src)
                rerank_params = pickle.load(src)

        ensemble.build_ensemble(gens_dump, rerank_settings, rerank_params)
        return ensemble
示例#26
0
文件: run_tgen.py 项目: UFAL-DSG/tgen
def seq2seq_gen(args):
    """Sequence-to-sequence generation: parse the command-line arguments, load a trained
    generator, produce one tree per input DA and optionally evaluate, lexicalize and write
    the results.

    @param args: list of command-line arguments (options, model file, DA file)
    """
    parser = ArgumentParser(prog=' '.join(sys.argv[0:2]))

    # optional settings
    parser.add_argument('-e', '--eval-file', type=str,
                        help='A ttree/text file for evaluation')
    parser.add_argument('-a', '--abstr-file', type=str,
                        help='Lexicalization file (a.k.a. abstraction instructions, for postprocessing)')
    parser.add_argument('-r', '--ref-selector', type=str, default='',
                        help='Selector for reference trees in the evaluation file')
    parser.add_argument('-t', '--target-selector', type=str, default='',
                        help='Target selector for generated trees in the output file')
    parser.add_argument('-d', '--debug-logfile', type=str,
                        help='Debug output file name')
    parser.add_argument('-w', '--output-file', type=str,
                        help='Output tree/text file')
    parser.add_argument('-b', '--beam-size', type=int,
                        help='Override beam size for beam search decoding')
    parser.add_argument('-c', '--context-file', type=str,
                        help='Input ttree/text file with context utterances')

    # positional arguments
    parser.add_argument('seq2seq_model_file', type=str,
                        help='Trained Seq2Seq generator model')
    parser.add_argument('da_test_file', type=str,
                        help='Input DAs for generation')

    args = parser.parse_args(args)

    if args.debug_logfile:
        set_debug_stream(file_stream(args.debug_logfile, mode='w'))

    # load the generator (the beam size may be overridden from the command line)
    tgen = Seq2SeqBase.load_from_file(args.seq2seq_model_file)
    if args.beam_size is not None:
        tgen.beam_size = args.beam_size

    # read input files (DAs, contexts)
    das = read_das(args.da_test_file)
    if args.context_file:
        if tgen.use_context or tgen.context_bleu_weight:
            if args.context_file.endswith('.txt'):
                contexts = read_tokens(args.context_file)
            else:
                contexts = tokens_from_doc(read_ttrees(args.context_file),
                                           tgen.language, tgen.selector)
            # from now on, each input is a (context, DA) pair
            das = list(zip(contexts, das))
        else:
            log_warn('Generator is not trained to use context, ignoring context input file.')
    elif tgen.use_context or tgen.context_bleu_weight:
        log_warn('Generator is trained to use context. ' +
                 'Using empty contexts, expect lower performance.')
        das = [([], da) for da in das]

    # generate
    log_info('Generating...')
    gen_trees = []
    for tree_no, gen_input in enumerate(das, start=1):
        log_debug("\n\nTREE No. %03d" % tree_no)
        gen_trees.append(tgen.generate_tree(gen_input))
        # report progress after every 100 trees
        if tree_no % 100 == 0:
            log_info("Generated tree %d" % tree_no)
    log_info(tgen.get_slot_err_stats())

    # the evaluation file is treated as plain text iff its name ends with '.txt'
    eval_is_text = bool(args.eval_file) and args.eval_file.endswith('.txt')

    # evaluate the generated trees against golden trees (delexicalized)
    eval_doc = None
    if args.eval_file and not eval_is_text:
        eval_doc = read_ttrees(args.eval_file)
        evaler = Evaluator()
        evaler.process_eval_doc(eval_doc, gen_trees, tgen.language, args.ref_selector,
                                args.target_selector or tgen.selector)

    # lexicalize, if required
    if args.abstr_file and tgen.lexicalizer:
        log_info('Lexicalizing...')
        tgen.lexicalize(gen_trees, args.abstr_file)

    # we won't need contexts anymore, but we do need DAs
    if tgen.use_context or tgen.context_bleu_weight:
        das = [da for _, da in das]

    # evaluate the generated & lexicalized tokens (F1 and BLEU scores)
    if eval_is_text:
        ref_tokens = read_tokens(args.eval_file, ref_mode=True)
        eval_tokens(das, ref_tokens, [t.to_tok_list() for t in gen_trees])

    # write output .yaml.gz or .txt
    if args.output_file is None:
        return
    log_info('Writing output...')
    if args.output_file.endswith('.txt'):
        gen_toks = [t.to_tok_list() for t in gen_trees]
        postprocess_tokens(gen_toks, das)
        write_tokens(gen_toks, args.output_file)
    else:
        out_doc = create_ttree_doc(gen_trees, eval_doc, tgen.language,
                                   args.target_selector or tgen.selector)
        write_ttrees(out_doc, args.output_file)
示例#27
0
 def load_from_file(fname):
     """Unpickle and return the model stored in the given file.

     @param fname: path to the pickled model file
     @return: the object read from the file
     """
     log_info('Loading model from ' + fname)
     with file_stream(fname, mode='rb', encoding=None) as model_fh:
         return pickle.load(model_fh)
示例#28
0
def asearch_gen(args):
    """A*search generation: load the candidate generator and ranker models, generate a tree
    for each input DA and, if an evaluation file is given, compare the results to gold trees.

    Recognized options: -e (evaluation t-tree file), -s (selector of the gold trees in the
    evaluation file), -d (debug output file), -w (output t-tree file), -c (configuration file).
    Exactly three positional arguments are required: candidate generator model, ranker model
    and the input DA file; the program exits with an error message otherwise.

    @param args: list of command-line arguments
    """
    from pytreex.core.document import Document

    opts, files = getopt(args, 'e:d:w:c:s:')
    eval_file = None
    fname_ttrees_out = None
    cfg_file = None
    eval_selector = ''

    for opt, arg in opts:
        if opt == '-e':
            eval_file = arg
        elif opt == '-s':
            eval_selector = arg
        elif opt == '-d':
            set_debug_stream(file_stream(arg, mode='w'))
        elif opt == '-w':
            fname_ttrees_out = arg
        elif opt == '-c':
            cfg_file = arg

    if len(files) != 3:
        # the module docstring may be missing (None), do not crash on the error path
        sys.exit('Invalid arguments.\n' + (__doc__ or ''))
    fname_cand_model, fname_rank_model, fname_da_test = files

    log_info('Initializing...')
    candgen = RandomCandidateGenerator.load_from_file(fname_cand_model)
    ranker = PerceptronRanker.load_from_file(fname_rank_model)
    cfg = Config(cfg_file) if cfg_file else {}
    cfg.update({'candgen': candgen, 'ranker': ranker})
    tgen = ASearchPlanner(cfg)

    log_info('Generating...')
    das = read_das(fname_da_test)

    if eval_file is None:
        gen_doc = Document()
    else:
        eval_doc = read_ttrees(eval_file)
        # generate into a fresh document if the gold and generated selectors are the same,
        # otherwise add the generated trees to the evaluation document
        if eval_selector == tgen.selector:
            gen_doc = Document()
        else:
            gen_doc = eval_doc

    # generate and evaluate
    if eval_file is not None:
        # generate + analyze open&close lists
        lists_analyzer = ASearchListsAnalyzer()
        gold_trees = trees_from_doc(eval_doc, tgen.language, eval_selector)
        for num, (da, gold_tree) in enumerate(zip(das, gold_trees), start=1):
            log_debug("\n\nTREE No. %03d" % num)
            gen_tree = tgen.generate_tree(da, gen_doc)
            lists_analyzer.append(gold_tree, tgen.open_list, tgen.close_list)
            if gen_tree != gold_tree:
                log_debug("\nDIFFING TREES:\n" +
                          tgen.ranker.diffing_trees_with_scores(
                              da, gold_tree, gen_tree) + "\n")

        # all three figures use 4 decimal places (the last one used to be '%4f', i.e. a
        # minimum width of 4 with the default precision)
        log_info('Gold tree BEST: %.4f, on CLOSE: %.4f, on ANY list: %.4f' %
                 lists_analyzer.stats())

        # evaluate the generated trees against golden trees
        eval_ttrees = ttrees_from_doc(eval_doc, tgen.language, eval_selector)
        gen_ttrees = ttrees_from_doc(gen_doc, tgen.language, tgen.selector)

        log_info('Evaluating...')
        evaler = Evaluator()
        for eval_bundle, eval_ttree, gen_ttree, da in zip(
                eval_doc.bundles, eval_ttrees, gen_ttrees, das):
            # add some stats about the tree directly into the output file
            add_bundle_text(
                eval_bundle, tgen.language, tgen.selector + 'Xscore',
                "P: %.4f R: %.4f F1: %.4f" %
                p_r_f1_from_counts(*corr_pred_gold(eval_ttree, gen_ttree)))

            # collect overall stats
            evaler.append(eval_ttree, gen_ttree,
                          ranker.score(TreeData.from_ttree(eval_ttree), da),
                          ranker.score(TreeData.from_ttree(gen_ttree), da))
        # print overall stats
        log_info("NODE precision: %.4f, Recall: %.4f, F1: %.4f" %
                 evaler.p_r_f1())
        log_info("DEP  precision: %.4f, Recall: %.4f, F1: %.4f" %
                 evaler.p_r_f1(EvalTypes.DEP))
        log_info("Tree size stats:\n * GOLD %s\n * PRED %s\n * DIFF %s" %
                 evaler.size_stats())
        log_info("Score stats:\n * GOLD %s\n * PRED %s\n * DIFF %s" %
                 evaler.score_stats())
        log_info(
            "Common subtree stats:\n -- SIZE: %s\n -- ΔGLD: %s\n -- ΔPRD: %s" %
            evaler.common_substruct_stats())
    # just generate
    else:
        for da in das:
            tgen.generate_tree(da, gen_doc)

    # write output
    if fname_ttrees_out is not None:
        log_info('Writing output...')
        write_ttrees(gen_doc, fname_ttrees_out)
示例#29
0
def seq2seq_gen(args):
    """Sequence-to-sequence generation: parse the command-line arguments, load a trained
    generator, produce one tree per input DA and optionally write the delexicalized output,
    evaluate, lexicalize and write the final output.

    @param args: list of command-line arguments (options, model file, DA file)
    """
    def write_trees_or_tokens(output_file, das, gen_trees, base_doc, language,
                              selector):
        """Decide to write t-trees or tokens based on the output file name.

        Files ending with '.txt' get postprocessed tokens, anything else gets a t-tree
        document created from `gen_trees` and `base_doc`. `das` are only used for the token
        postprocessing and are expected to be plain DAs (without contexts).
        """
        if output_file.endswith('.txt'):
            gen_toks = [t.to_tok_list() for t in gen_trees]
            postprocess_tokens(gen_toks, das)
            write_tokens(gen_toks, output_file)
        else:
            write_ttrees(
                create_ttree_doc(gen_trees, base_doc, language, selector),
                output_file)

    ap = ArgumentParser(prog=' '.join(sys.argv[0:2]))

    ap.add_argument('-e',
                    '--eval-file',
                    type=str,
                    help='A ttree/text file for evaluation')
    ap.add_argument(
        '-a',
        '--abstr-file',
        type=str,
        help=
        'Lexicalization file (a.k.a. abstraction instructions, for postprocessing)'
    )
    ap.add_argument('-r',
                    '--ref-selector',
                    type=str,
                    default='',
                    help='Selector for reference trees in the evaluation file')
    ap.add_argument(
        '-t',
        '--target-selector',
        type=str,
        default='',
        help='Target selector for generated trees in the output file')
    ap.add_argument('-d',
                    '--debug-logfile',
                    type=str,
                    help='Debug output file name')
    ap.add_argument('-w',
                    '--output-file',
                    type=str,
                    help='Output tree/text file')
    ap.add_argument('-D',
                    '--delex-output-file',
                    type=str,
                    help='Output file for trees/text before lexicalization')
    ap.add_argument('-b',
                    '--beam-size',
                    type=int,
                    help='Override beam size for beam search decoding')
    ap.add_argument('-c',
                    '--context-file',
                    type=str,
                    help='Input ttree/text file with context utterances')

    ap.add_argument('seq2seq_model_file',
                    type=str,
                    help='Trained Seq2Seq generator model')
    ap.add_argument('da_test_file', type=str, help='Input DAs for generation')

    args = ap.parse_args(args)

    if args.debug_logfile:
        set_debug_stream(file_stream(args.debug_logfile, mode='w'))

    # load the generator
    tgen = Seq2SeqBase.load_from_file(args.seq2seq_model_file)
    if args.beam_size is not None:
        tgen.beam_size = args.beam_size

    # read input files (DAs, contexts)
    das = read_das(args.da_test_file)
    if args.context_file:
        if not tgen.use_context and not tgen.context_bleu_weight:
            log_warn(
                'Generator is not trained to use context, ignoring context input file.'
            )
        else:
            if args.context_file.endswith('.txt'):
                contexts = read_tokens(args.context_file)
            else:
                contexts = tokens_from_doc(read_ttrees(args.context_file),
                                           tgen.language, tgen.selector)
            das = [(context, da) for context, da in zip(contexts, das)]
    elif tgen.use_context or tgen.context_bleu_weight:
        log_warn('Generator is trained to use context. ' +
                 'Using empty contexts, expect lower performance.')
        das = [([], da) for da in das]

    # generate
    log_info('Generating...')
    gen_trees = []
    for num, da in enumerate(das, start=1):
        log_debug("\n\nTREE No. %03d" % num)
        gen_trees.append(tgen.generate_tree(da))
        if num % 100 == 0:
            log_info("Generated tree %d" % num)
    log_info(tgen.get_slot_err_stats())

    # we won't need contexts anymore, but we do need DAs
    # (this must happen before any output is written, so that the delexicalized and the final
    # output both pass plain DAs -- not (context, DA) pairs -- to the token postprocessing)
    if tgen.use_context or tgen.context_bleu_weight:
        das = [da for _, da in das]

    if args.delex_output_file is not None:
        log_info('Writing delex output...')
        write_trees_or_tokens(args.delex_output_file, das, gen_trees, None,
                              tgen.language, args.target_selector
                              or tgen.selector)

    # evaluate the generated trees against golden trees (delexicalized)
    eval_doc = None
    if args.eval_file and not args.eval_file.endswith('.txt'):
        eval_doc = read_ttrees(args.eval_file)
        evaler = Evaluator()
        evaler.process_eval_doc(eval_doc, gen_trees, tgen.language,
                                args.ref_selector, args.target_selector
                                or tgen.selector)

    # lexicalize, if required
    if args.abstr_file and tgen.lexicalizer:
        log_info('Lexicalizing...')
        tgen.lexicalize(gen_trees, args.abstr_file)

    # evaluate the generated & lexicalized tokens (F1 and BLEU scores)
    if args.eval_file and args.eval_file.endswith('.txt'):
        eval_tokens(das, read_tokens(args.eval_file, ref_mode=True),
                    [t.to_tok_list() for t in gen_trees])

    # write output .yaml.gz or .txt
    if args.output_file is not None:
        log_info('Writing output...')
        write_trees_or_tokens(args.output_file, das, gen_trees, eval_doc,
                              tgen.language, args.target_selector
                              or tgen.selector)
def process_file(args):
    """Compose a Treex scenario according to the generator mode and run it on the input file.

    The mode ('trees', 'tokens' or 'tagged_lemmas') is read from the 'cfg' entry of the pickled
    model, if a model is given; it defaults to 'trees'. The scenario is then executed by calling
    the `treex` command through the shell.

    @param args: parsed arguments; the attributes read here are `model`, `surface_forms`, \
        `grammatemes`, `input_das`, `input_file`, `output_file` and `selector`
    """

    # find out generator mode
    mode = 'trees'
    if args.model:
        # NOTE: unpickling can execute arbitrary code -- only use trusted model files
        with file_stream(args.model, 'rb', encoding=None) as fh:
            data = pickle.load(fh)
            if 'mode' in data['cfg']:
                mode = data['cfg']['mode']
            # an explicit 'use_tokens' setting takes precedence over 'mode'
            if 'use_tokens' in data['cfg']:
                mode = 'tokens' if data['cfg']['use_tokens'] else 'trees'

    # compose scenario

    scen = [
        'T2A::CS::VocalizePrepos',  # this surface stuff is needed for both tokens and tagged lemmas
        'T2A::CS::CapitalizeSentStart',
        'A2W::ConcatenateTokens',
        'A2W::CS::DetokenizeUsingRules',
        'A2W::CS::RemoveRepeatedTokens',
    ]
    if mode == 'tokens':
        scen = [
            'T2A::CopyTtree',
            'Util::Eval anode="$.set_form($.lemma);"',
        ] + scen
    elif mode == 'tagged_lemmas':
        scen = [
            'T2A::CopyTtree',
            'Util::Eval atree=\'my @as=$.get_descendants({ordered=>1}); ' +
            'while (my ($l, $t) = splice @as, 0, 2){ next if (!defined($t)); $l->set_tag($t->lemma); $t->remove(); }\'',
            'Misc::TagToMorphcat',
            'T2A::CS::GenerateWordforms',
        ] + scen
    else:
        # get the canonical CS generation scenario
        # (universal_newlines=True makes the output a text string on both Python 2 and 3,
        # so that it can be split on "\n" below)
        scen_dump_ps = subprocess.Popen('treex -d Scen::Synthesis::CS',
                                        shell=True,
                                        stdout=subprocess.PIPE,
                                        universal_newlines=True)
        scen, _ = scen_dump_ps.communicate()
        # skip empty lines and comments
        scen = [
            block for block in scen.split("\n")
            if block and not block.startswith('#')
        ]

        # insert our custom morphological processing block into it
        # (just before the first block whose name contains "generateword")
        pos = next(i for i, block in enumerate(scen)
                   if re.search(r'generateword', block, re.IGNORECASE))
        scen.insert(
            pos, 'Misc::GenerateWordformsFromJSON surface_forms="%s"' %
            args.surface_forms)

        # add grammatemes and clause number processing
        scen = [
            'Util::Eval tnode="$.set_functor(\\"???\\"); ' +
            '$.set_t_lemma(\\"\\") if (!defined($.t_lemma)); ' +
            '$.set_formeme(\\"x\\") if (!defined($.formeme));"',
            'T2T::AssignDefaultGrammatemes grammateme_file="%s" da_file="%s"' %
            (args.grammatemes, args.input_das), 'Misc::RestoreCoordNodes',
            'T2T::CS2CS::MarkClauseHeads', 'T2T::SetClauseNumber'
        ] + scen

    # wrap the scenario into reading the input and writing the outputs
    scen = ['Read::YAML from="%s"' % args.input_file] + scen
    scen += [
        'Write::Treex',
        'Util::Eval document="$.set_path(\\"\\"); $.set_file_stem(\\"test\\");"',
        'Write::SgmMTEval to="%s" set_id=CsRest sys_id=TGEN add_header=tstset'
        % args.output_file
    ]

    # the scenario blocks rely on shell quoting, hence the single command string
    subprocess.call(('treex -Lcs -S%s ' % args.selector) + " ".join(scen),
                    shell=True)
示例#31
0
 def save_to_file(self, fname):
     """Pickle this object into the given file, using the highest pickle protocol.

     @param fname: path to the target model file
     """
     log_info('Saving model to ' + fname)
     with file_stream(fname, mode='wb', encoding=None) as model_fh:
         pickle.dump(self, model_fh, protocol=pickle.HIGHEST_PROTOCOL)
示例#32
0
 def load_model(self, model_fname_pattern):
     """Load pickled word frequencies into `self._word_freq`.

     The file name is derived from the given pattern by replacing an optional
     "pickle"/"gz" suffix with '.wfreq'.

     @param model_fname_pattern: model file name from which the '.wfreq' file name is derived
     """
     # NOTE(review): the dots in the pattern are unescaped (they match any character) and the
     # pattern can also match an empty string at the end -- verify how re.sub treats this on
     # the targeted Python version and keep it in sync with the code that saves the file
     suffix_pattern = re.compile(r'(.pickle)?(.gz)?$')
     wfreq_fname = suffix_pattern.sub('.wfreq', model_fname_pattern)
     with file_stream(wfreq_fname, 'rb', encoding=None) as wfreq_fh:
         self._word_freq = pickle.load(wfreq_fh)
示例#33
0
 def load_from_file(model_fname):
     """Read a pickled ranker from the given file and return it.

     @param model_fname: path to the pickled ranker model
     @return: the object read from the file
     """
     log_info("Loading ranker from %s..." % model_fname)
     with file_stream(model_fname, 'rb', encoding=None) as model_fh:
         ranker = pickle.load(model_fh)
     return ranker
示例#34
0
 def save_to_file(self, model_fname):
     """Pickle this ranker into the given file, using the highest pickle protocol.

     @param model_fname: path to the target model file
     """
     log_info("Saving ranker to %s..." % model_fname)
     with file_stream(model_fname, 'wb', encoding=None) as model_fh:
         pickle.dump(self, model_fh, pickle.HIGHEST_PROTOCOL)
示例#35
0
文件: run_tgen.py 项目: qjay612/tgen
def seq2seq_gen(args):
    """Sequence-to-sequence generation: parse the command-line arguments, load a trained
    generator, produce one tree per input DA and optionally evaluate, lexicalize and write
    the results.

    @param args: list of command-line arguments (options, model file, DA file)
    """
    parser = ArgumentParser()

    # optional settings
    parser.add_argument('-e', '--eval-file', type=str,
                        help='A ttree/text file for evaluation')
    parser.add_argument('-a', '--abstr-file', type=str,
                        help='Lexicalization file (a.k.a. abstraction instructions, for postprocessing)')
    parser.add_argument('-r', '--ref-selector', type=str, default='',
                        help='Selector for reference trees in the evaluation file')
    parser.add_argument('-t', '--target-selector', type=str, default='',
                        help='Target selector for generated trees in the output file')
    parser.add_argument('-d', '--debug-logfile', type=str,
                        help='Debug output file name')
    parser.add_argument('-w', '--output-file', type=str,
                        help='Output tree/text file')
    parser.add_argument('-b', '--beam-size', type=int,
                        help='Override beam size for beam search decoding')
    parser.add_argument('-c', '--context-file', type=str,
                        help='Input ttree/text file with context utterances')

    # positional arguments
    parser.add_argument('seq2seq_model_file', type=str,
                        help='Trained Seq2Seq generator model')
    parser.add_argument('da_test_file', type=str,
                        help='Input DAs for generation')

    args = parser.parse_args(args)

    if args.debug_logfile:
        set_debug_stream(file_stream(args.debug_logfile, mode='w'))

    # load the generator (the beam size may be overridden from the command line)
    tgen = Seq2SeqBase.load_from_file(args.seq2seq_model_file)
    if args.beam_size is not None:
        tgen.beam_size = args.beam_size

    # read input files
    das = read_das(args.da_test_file)
    if args.context_file:
        if tgen.use_context or tgen.context_bleu_weight:
            if args.context_file.endswith('.txt'):
                contexts = read_tokens(args.context_file)
            else:
                contexts = tokens_from_doc(read_ttrees(args.context_file),
                                           tgen.language, tgen.selector)
            # from now on, each input is a (context, DA) pair
            das = list(zip(contexts, das))
        else:
            log_warn('Generator is not trained to use context, ignoring context input file.')

    # generate
    log_info('Generating...')
    gen_trees = []
    for tree_no, gen_input in enumerate(das, start=1):
        log_debug("\n\nTREE No. %03d" % tree_no)
        gen_trees.append(tgen.generate_tree(gen_input))
    log_info(tgen.get_slot_err_stats())

    # the evaluation file is treated as plain text iff its name ends with '.txt'
    eval_is_text = bool(args.eval_file) and args.eval_file.endswith('.txt')

    # evaluate the generated trees against golden trees (delexicalized)
    eval_doc = None
    if args.eval_file and not eval_is_text:
        eval_doc = read_ttrees(args.eval_file)
        evaler = Evaluator()
        evaler.process_eval_doc(eval_doc, gen_trees, tgen.language, args.ref_selector,
                                args.target_selector or tgen.selector)

    # lexicalize, if required
    if args.abstr_file and tgen.lexicalizer:
        log_info('Lexicalizing...')
        tgen.lexicalize(gen_trees, args.abstr_file)

    # evaluate the generated & lexicalized tokens (F1 and BLEU scores)
    if eval_is_text:
        ref_tokens = read_tokens(args.eval_file, ref_mode=True)
        eval_tokens(das, ref_tokens, gen_trees)

    # write output .yaml.gz or .txt
    if args.output_file is None:
        return
    log_info('Writing output...')
    if args.output_file.endswith('.txt'):
        write_tokens(gen_trees, args.output_file)
    else:
        out_doc = create_ttree_doc(gen_trees, eval_doc, tgen.language,
                                   args.target_selector or tgen.selector)
        write_ttrees(out_doc, args.output_file)
示例#36
0
文件: run_tgen.py 项目: UFAL-DSG/tgen
def asearch_gen(args):
    """A*search generation: load the candidate generator and ranker models, generate a tree
    for each input DA and, if an evaluation file is given, compare the results to gold trees.

    Recognized options: -e (evaluation t-tree file), -s (selector of the gold trees in the
    evaluation file), -d (debug output file), -w (output t-tree file), -c (configuration file).
    Exactly three positional arguments are required: candidate generator model, ranker model
    and the input DA file; the program exits with an error message otherwise.

    @param args: list of command-line arguments
    """
    from pytreex.core.document import Document

    opts, files = getopt(args, 'e:d:w:c:s:')
    eval_file = None
    fname_ttrees_out = None
    cfg_file = None
    eval_selector = ''

    for opt, arg in opts:
        if opt == '-e':
            eval_file = arg
        elif opt == '-s':
            eval_selector = arg
        elif opt == '-d':
            set_debug_stream(file_stream(arg, mode='w'))
        elif opt == '-w':
            fname_ttrees_out = arg
        elif opt == '-c':
            cfg_file = arg

    if len(files) != 3:
        # the module docstring may be missing (None), do not crash on the error path
        sys.exit('Invalid arguments.\n' + (__doc__ or ''))
    fname_cand_model, fname_rank_model, fname_da_test = files

    log_info('Initializing...')
    candgen = RandomCandidateGenerator.load_from_file(fname_cand_model)
    ranker = PerceptronRanker.load_from_file(fname_rank_model)
    cfg = Config(cfg_file) if cfg_file else {}
    cfg.update({'candgen': candgen, 'ranker': ranker})
    tgen = ASearchPlanner(cfg)

    log_info('Generating...')
    das = read_das(fname_da_test)

    if eval_file is None:
        gen_doc = Document()
    else:
        eval_doc = read_ttrees(eval_file)
        # generate into a fresh document if the gold and generated selectors are the same,
        # otherwise add the generated trees to the evaluation document
        if eval_selector == tgen.selector:
            gen_doc = Document()
        else:
            gen_doc = eval_doc

    # generate and evaluate
    if eval_file is not None:
        # generate + analyze open&close lists
        lists_analyzer = ASearchListsAnalyzer()
        gold_trees = trees_from_doc(eval_doc, tgen.language, eval_selector)
        for num, (da, gold_tree) in enumerate(zip(das, gold_trees), start=1):
            log_debug("\n\nTREE No. %03d" % num)
            gen_tree = tgen.generate_tree(da, gen_doc)
            lists_analyzer.append(gold_tree, tgen.open_list, tgen.close_list)
            if gen_tree != gold_tree:
                log_debug("\nDIFFING TREES:\n" + tgen.ranker.diffing_trees_with_scores(da, gold_tree, gen_tree) + "\n")

        # all three figures use 4 decimal places (the last one used to be '%4f', i.e. a
        # minimum width of 4 with the default precision)
        log_info('Gold tree BEST: %.4f, on CLOSE: %.4f, on ANY list: %.4f' % lists_analyzer.stats())

        # evaluate the generated trees against golden trees
        eval_ttrees = ttrees_from_doc(eval_doc, tgen.language, eval_selector)
        gen_ttrees = ttrees_from_doc(gen_doc, tgen.language, tgen.selector)

        log_info('Evaluating...')
        evaler = Evaluator()
        for eval_bundle, eval_ttree, gen_ttree, da in zip(eval_doc.bundles, eval_ttrees, gen_ttrees, das):
            # add some stats about the tree directly into the output file
            add_bundle_text(eval_bundle, tgen.language, tgen.selector + 'Xscore',
                            "P: %.4f R: %.4f F1: %.4f" % p_r_f1_from_counts(*corr_pred_gold(eval_ttree, gen_ttree)))

            # collect overall stats
            evaler.append(eval_ttree,
                          gen_ttree,
                          ranker.score(TreeData.from_ttree(eval_ttree), da),
                          ranker.score(TreeData.from_ttree(gen_ttree), da))
        # print overall stats
        log_info("NODE precision: %.4f, Recall: %.4f, F1: %.4f" % evaler.p_r_f1())
        log_info("DEP  precision: %.4f, Recall: %.4f, F1: %.4f" % evaler.p_r_f1(EvalTypes.DEP))
        log_info("Tree size stats:\n * GOLD %s\n * PRED %s\n * DIFF %s" % evaler.size_stats())
        log_info("Score stats:\n * GOLD %s\n * PRED %s\n * DIFF %s" % evaler.score_stats())
        log_info("Common subtree stats:\n -- SIZE: %s\n -- ΔGLD: %s\n -- ΔPRD: %s" %
                 evaler.common_substruct_stats())
    # just generate
    else:
        for da in das:
            tgen.generate_tree(da, gen_doc)

    # write output
    if fname_ttrees_out is not None:
        log_info('Writing output...')
        write_ttrees(gen_doc, fname_ttrees_out)