Ejemplo n.º 1
0
    def __init__(self, args, schema, price_tracker, model_path, timed):
        """Load a trained model and assemble everything needed for inference.

        The loaded model, its vocabulary, the preprocessor, the batcher, the
        generator and the utterance builder are bundled into ``self.env`` (an
        ``Env`` namedtuple).

        Args:
            args: Run-time options. This method reads ``args.alpha`` and
                ``args.n_best`` and forwards ``args`` to
                ``model_builder.load_test_model``, ``get_generator`` and
                ``use_gpu``.
            schema: Stored on the instance and passed to ``Preprocessor``.
            price_tracker: Stored on the instance and passed to
                ``Preprocessor``.
            model_path: Passed to ``model_builder.load_test_model`` as the
                location of the model to load.
            timed: Stored as ``self.timed_session``.

        Side effects:
            Sets the class-level attributes ``Dialogue.preprocessor``,
            ``Dialogue.textint_map``, ``Dialogue.mappings`` and
            ``Dialogue.num_context``, so they are shared by every user of
            ``Dialogue`` in the process.
        """
        super(PytorchNeuralSystem, self).__init__()
        self.schema = schema
        self.price_tracker = price_tracker
        self.timed_session = timed

        # TODO: do we need the dummy parser?
        # Parsing an empty argv yields the default values of the model and
        # data-generator options; they are handed to load_test_model below.
        dummy_parser = argparse.ArgumentParser(description='duh')
        options.add_model_arguments(dummy_parser)
        options.add_data_generator_arguments(dummy_parser)
        dummy_args = dummy_parser.parse_known_args([])[0]

        # Load the model.
        mappings, model, model_args = model_builder.load_test_model(
                model_path, args, dummy_args.__dict__)
        self.model_name = model_args.model
        vocab = mappings['utterance_vocab']
        self.mappings = mappings

        generator = get_generator(model, vocab, Scorer(args.alpha), args, model_args)
        builder = UtteranceBuilder(vocab, args.n_best, has_tgt=True)

        preprocessor = Preprocessor(schema, price_tracker, model_args.entity_encoding_form,
                model_args.entity_decoding_form, model_args.entity_target_form)
        textint_map = TextIntMap(vocab, preprocessor)
        # Build a real list rather than a bare ``map``: on Python 3 ``map`` is
        # a one-shot iterator, so it would appear empty after its first
        # traversal even though the Env below lives as long as this object.
        remove_symbols = [vocab.to_ind(marker)
                          for marker in (markers.EOS, markers.PAD)]
        use_cuda = use_gpu(args)

        kb_padding = mappings['kb_vocab'].to_ind(markers.PAD)
        dialogue_batcher = DialogueBatcherFactory.get_dialogue_batcher(model=self.model_name,
            kb_pad=kb_padding,
            mappings=mappings, num_context=model_args.num_context)

        # TODO: class variable is not a good way to do this
        Dialogue.preprocessor = preprocessor
        Dialogue.textint_map = textint_map
        Dialogue.mappings = mappings
        Dialogue.num_context = model_args.num_context

        # Immutable bundle of the inference-time components.
        Env = namedtuple('Env', ['model', 'vocab', 'preprocessor', 'textint_map',
            'stop_symbol', 'remove_symbols', 'gt_prefix',
            'max_len', 'dialogue_batcher', 'cuda',
            'dialogue_generator', 'utterance_builder', 'model_args'])
        self.env = Env(model, vocab, preprocessor, textint_map,
            stop_symbol=vocab.to_ind(markers.EOS), remove_symbols=remove_symbols,
            gt_prefix=1,
            max_len=20, dialogue_batcher=dialogue_batcher, cuda=use_cuda,
            dialogue_generator=generator, utterance_builder=builder, model_args=model_args)
Ejemplo n.º 2
0
    # Collect the default model and data-generator arguments so that we know
    # which arguments belong to the model and must not be overwritten at test time.
    dummy_parser = argparse.ArgumentParser(description='duh')
    add_model_arguments(dummy_parser)
    add_data_generator_arguments(dummy_parser)
    dummy_args = dummy_parser.parse_known_args([])[0]

    if cuda.is_available() and not args.gpuid:
        print("WARNING: You have a CUDA device, should run with --gpuid 0")

    if args.gpuid:
        cuda.set_device(args.gpuid[0])

    # Load the model.
    mappings, model, model_args = \
        model_builder.load_test_model(args.checkpoint, args, dummy_args.__dict__)

    # Figure out src and tgt vocab
    make_model_mappings(model_args.model, mappings)

    schema = Schema(model_args.schema_path, None)
    data_generator = get_data_generator(args, model_args, schema, test=True)

    # Prefix: [GO]
    scorer = Scorer(args.alpha)
    generator = get_generator(model, mappings['tgt_vocab'], scorer, args,
                              model_args)
    builder = UtteranceBuilder(mappings['tgt_vocab'],
                               args.n_best,
                               has_tgt=True)
    evaluator = Evaluator(model, mappings, generator, builder, gt_prefix=1)