Example #1
    def __init__(self, args):
        super().__init__(args)
        vocabs = data_loader.load_vocabs(args)
        self.in_vocab = vocabs['text']
        self.out_vocab = vocabs['program']

        # Construct NN model
        if self.model_id == BRIDGE:
            self.mdl = Bridge(args, self.in_vocab, self.out_vocab)
        elif self.model_id == SEQ2SEQ_PG:
            self.mdl = PointerGenerator(args, self.in_vocab, self.out_vocab)
        elif self.model_id == SEQ2SEQ:
            self.mdl = Seq2Seq(args, self.in_vocab, self.out_vocab)
        else:
            raise NotImplementedError

        # Specify loss function
        if self.args.loss == 'cross_entropy':
            self.loss_fun = MaskedCrossEntropyLoss(self.mdl.out_vocab.pad_id)
        else:
            raise NotImplementedError

        # Optimizer
        self.define_optimizer()
        self.define_lr_scheduler()

        # Post-process
        _, _, self.output_post_process, _ = tok.get_tokenizers(args)

        print('{} module created'.format(self.model))
Example #2
def demo_table(args, sp):
    """
    Run the semantic parser from the standard input.
    """
    # Load the trained checkpoint and put the parser in evaluation mode
    sp.load_checkpoint(get_checkpoint_path(args))
    sp.eval()

    vocabs = data_loader.load_vocabs(args)

    # Build a toy schema from an in-memory 2-D array (header row followed by data rows)
    table_array = [
        ['name', 'age', 'gender'],
        ['John', 18, 'male'],
        ['Kate', 19, 'female']
    ]
    table_name = 'employees'
    schema_graph = schema_loader.SchemaGraph(table_name)
    schema_graph.load_data_from_2d_array(table_array)

    sys.stdout.write('Enter a natural language question: ')
    sys.stdout.write('> ')
    sys.stdout.flush()
    text = sys.stdin.readline()

    # Parse each question and print the ranked SQL predictions until EOF
    while text:
        example = data_utils.Text2SQLExample(0, table_name, 0)
        example.text = text
        demo_preprocess(args, example, vocabs, schema_graph)
        output = sp.forward([example])
        for i, sql in enumerate(output['pred_decoded'][0]):
            print('Top {}: {}'.format(i, sql))
        sys.stdout.flush()
        sys.stdout.write('\nEnter a natural language question: ')
        sys.stdout.write('> ')
        text = sys.stdin.readline()
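
A minimal sketch of how this interactive demo might be driven, assuming the parser passed in as sp is the EncoderDecoderLFramework wrapper constructed in Example #4 and that args comes from the project's usual argument parser (both are assumptions, not shown in these snippets):

def run_demo(args):
    # Hypothetical driver; EncoderDecoderLFramework, demo_table and get_checkpoint_path
    # appear in these snippets, run_demo itself is illustrative only.
    sp = EncoderDecoderLFramework(args)   # semantic parser wrapper (see Example #4)
    sp.cuda()                             # Example #4 moves the parser to GPU the same way
    demo_table(args, sp)                  # loads the checkpoint and starts the question loop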
Example #3
def preprocess(args, dataset, process_splits=('train', 'dev', 'test'), print_aggregated_stats=False, verbose=False,
               save_processed_data=True):
    """
    Data pre-processing for baselines that does only shallow processing on the schema.
    """
    text_tokenize, program_tokenize, post_process, trans_utils = tok.get_tokenizers(args)
    parsed_programs = load_parsed_sqls(args, augment_with_wikisql=args.augment_with_wikisql)
    num_parsed_programs = len(parsed_programs)

    vocabs = load_vocabs(args)

    schema_graphs = dataset['schema']
    schema_graphs.lexicalize_graphs(tokenize=text_tokenize, normalized=(args.model_id in [BRIDGE]))

    ############################
    # data statistics
    ds = DatasetStatistics()
    ############################

    # parallel data
    for split in process_splits:
        if split not in dataset:
            continue
        ds_split, sl_split = preprocess_split(dataset, split, args, parsed_programs,
                                              text_tokenize, program_tokenize, post_process, trans_utils,
                                              schema_graphs, vocabs, verbose=verbose)
        ds_split.print(split)
        sl_split.print()
        ############################
        # update data statistics
        ds.accumulate(ds_split)
        ############################

    # Save the parsed-SQL cache only if new programs were parsed during preprocessing
    if len(parsed_programs) > num_parsed_programs:
        save_parsed_sqls(args, parsed_programs)

    if print_aggregated_stats:
        ds.print()

    if save_processed_data:
        out_pkl = get_processed_data_path(args)
        with open(out_pkl, 'wb') as o_f:
            pickle.dump(dataset, o_f)
            print('Processed data dumped to {}'.format(out_pkl))
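
For reference, a hedged sketch of the dataset structure preprocess appears to expect, inferred from how the function indexes it (a 'schema' entry plus one entry per split); the loader name below is a placeholder, not an API shown in these snippets:

# Hypothetical call site; load_raw_dataset is a placeholder for however the
# project actually builds the dict consumed by preprocess above.
dataset = load_raw_dataset(args)   # {'schema': ..., 'train': [...], 'dev': [...], 'test': [...]}
preprocess(args, dataset,
           process_splits=('train', 'dev'),
           print_aggregated_stats=True,
           save_processed_data=True)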
Example #4
    def __init__(self, args, cs_args, schema, ensemble_model_dirs=None):
        self.args = args
        self.text_tokenize, _, _, self.tu = tok.get_tokenizers(args)

        # Vocabulary
        self.vocabs = data_loader.load_vocabs(args)

        # Confusion span detector
        self.confusion_span_detector = load_confusion_span_detector(cs_args)

        # Text-to-SQL model
        self.semantic_parsers = []
        self.model_ensemble = None
        if ensemble_model_dirs is None:
            sp = load_semantic_parser(args)
            sp.schema_graphs = SchemaGraphs()
            self.semantic_parsers.append(sp)
        else:
            sps = [EncoderDecoderLFramework(args) for _ in ensemble_model_dirs]
            for i, model_dir in enumerate(ensemble_model_dirs):
                checkpoint_path = os.path.join(model_dir, 'model-best.16.tar')
                sps[i].schema_graphs = SchemaGraphs()
                sps[i].load_checkpoint(checkpoint_path)
                sps[i].cuda()
                sps[i].eval()
            self.semantic_parsers = sps
            self.model_ensemble = [sp.mdl for sp in sps]

        if schema is not None:
            self.add_schema(schema)


        # When generating SQL in execution order, cache reordered SQLs to save time
        if args.process_sql_in_execution_order:
            self.pred_restored_cache = self.semantic_parsers[0].load_pred_restored_cache()
        else:
            self.pred_restored_cache = None
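
A minimal sketch of constructing this wrapper in single-model and ensemble mode; the class name Text2SQLWrapper and the checkpoint directory paths are assumptions for illustration only:

# Single parser, loaded internally via load_semantic_parser(args)
wrapper = Text2SQLWrapper(args, cs_args, schema=None)   # hypothetical class name

# Ensemble: one EncoderDecoderLFramework per checkpoint directory
wrapper = Text2SQLWrapper(args, cs_args, schema=None,
                          ensemble_model_dirs=['model/run1', 'model/run2'])   # placeholder paths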