def __init__(self, args):
    super().__init__(args)
    vocabs = data_loader.load_vocabs(args)
    self.in_vocab = vocabs['text']
    self.out_vocab = vocabs['program']

    # Construct NN model
    if self.model_id == BRIDGE:
        self.mdl = Bridge(args, self.in_vocab, self.out_vocab)
    elif self.model_id == SEQ2SEQ_PG:
        self.mdl = PointerGenerator(args, self.in_vocab, self.out_vocab)
    elif self.model_id == SEQ2SEQ:
        self.mdl = Seq2Seq(args, self.in_vocab, self.out_vocab)
    else:
        raise NotImplementedError

    # Specify loss function
    if self.args.loss == 'cross_entropy':
        self.loss_fun = MaskedCrossEntropyLoss(self.mdl.out_vocab.pad_id)
    else:
        raise NotImplementedError

    # Optimizer
    self.define_optimizer()
    self.define_lr_scheduler()

    # Post-process
    _, _, self.output_post_process, _ = tok.get_tokenizers(args)

    print('{} module created'.format(self.model))
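# The MaskedCrossEntropyLoss above is constructed with the output vocabulary's
# pad id. Below is a minimal sketch of such a loss, assuming its role is to
# skip padded target positions; this is an illustrative assumption, not the
# repository's actual implementation.
import torch.nn as nn

class MaskedCrossEntropyLossSketch(nn.Module):
    def __init__(self, pad_id):
        super().__init__()
        # Negative log-likelihood over log-probabilities, ignoring pad positions
        self.nll = nn.NLLLoss(ignore_index=pad_id)

    def forward(self, log_probs, targets):
        # log_probs: [batch, seq_len, vocab_size], targets: [batch, seq_len]
        return self.nll(log_probs.transpose(1, 2), targets)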
def demo_table(args, sp):
    """
    Run the semantic parser from the standard input.
    """
    sp.load_checkpoint(get_checkpoint_path(args))
    sp.eval()

    vocabs = data_loader.load_vocabs(args)

    table_array = [
        ['name', 'age', 'gender'],
        ['John', 18, 'male'],
        ['Kate', 19, 'female']
    ]
    table_name = 'employees'
    schema_graph = schema_loader.SchemaGraph(table_name)
    schema_graph.load_data_from_2d_array(table_array)

    sys.stdout.write('Enter a natural language question: ')
    sys.stdout.write('> ')
    sys.stdout.flush()
    text = sys.stdin.readline()
    while text:
        example = data_utils.Text2SQLExample(0, table_name, 0)
        example.text = text
        demo_preprocess(args, example, vocabs, schema_graph)
        output = sp.forward([example])
        for i, sql in enumerate(output['pred_decoded'][0]):
            print('Top {}: {}'.format(i, sql))
        sys.stdout.flush()
        sys.stdout.write('\nEnter a natural language question: ')
        sys.stdout.write('> ')
        text = sys.stdin.readline()
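# A minimal sketch of wiring up the demo above, assuming `args` comes from the
# project's own argument parsing (parse_args() is a hypothetical placeholder)
# and EncoderDecoderLFramework is the semantic-parser framework whose
# constructor appears earlier in this listing.
if __name__ == '__main__':
    args = parse_args()  # hypothetical helper; substitute the project's arg parser
    sp = EncoderDecoderLFramework(args)
    sp.cuda()
    demo_table(args, sp)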
def preprocess(args, dataset, process_splits=('train', 'dev', 'test'), print_aggregated_stats=False,
               verbose=False, save_processed_data=True):
    """
    Data pre-processing for baselines that do only shallow processing on the schema.
    """
    text_tokenize, program_tokenize, post_process, trans_utils = tok.get_tokenizers(args)
    parsed_programs = load_parsed_sqls(args, augment_with_wikisql=args.augment_with_wikisql)
    num_parsed_programs = len(parsed_programs)

    vocabs = load_vocabs(args)

    schema_graphs = dataset['schema']
    schema_graphs.lexicalize_graphs(tokenize=text_tokenize, normalized=(args.model_id in [BRIDGE]))

    ############################
    # data statistics
    ds = DatasetStatistics()
    ############################

    # parallel data
    for split in process_splits:
        if split not in dataset:
            continue
        ds_split, sl_split = preprocess_split(dataset, split, args, parsed_programs,
                                              text_tokenize, program_tokenize, post_process, trans_utils,
                                              schema_graphs, vocabs, verbose=verbose)
        ds_split.print(split)
        sl_split.print()
        ############################
        # update data statistics
        ds.accumulate(ds_split)
        ############################

    if len(parsed_programs) > num_parsed_programs:
        save_parsed_sqls(args, parsed_programs)

    if print_aggregated_stats:
        ds.print()

    if save_processed_data:
        out_pkl = get_processed_data_path(args)
        with open(out_pkl, 'wb') as o_f:
            pickle.dump(dataset, o_f)
            print('Processed data dumped to {}'.format(out_pkl))
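# A minimal sketch of reading back the pickle written by preprocess() above,
# assuming get_processed_data_path(args) resolves to the same file.
import pickle

def load_processed_data(args):
    with open(get_processed_data_path(args), 'rb') as f:
        dataset = pickle.load(f)
    # dataset['schema'] holds the lexicalized schema graphs produced during pre-processing
    return dataset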
def __init__(self, args, cs_args, schema, ensemble_model_dirs=None):
    self.args = args
    self.text_tokenize, _, _, self.tu = tok.get_tokenizers(args)

    # Vocabulary
    self.vocabs = data_loader.load_vocabs(args)

    # Confusion span detector
    self.confusion_span_detector = load_confusion_span_detector(cs_args)

    # Text-to-SQL model
    self.semantic_parsers = []
    self.model_ensemble = None
    if ensemble_model_dirs is None:
        sp = load_semantic_parser(args)
        sp.schema_graphs = SchemaGraphs()
        self.semantic_parsers.append(sp)
    else:
        sps = [EncoderDecoderLFramework(args) for _ in ensemble_model_dirs]
        for i, model_dir in enumerate(ensemble_model_dirs):
            checkpoint_path = os.path.join(model_dir, 'model-best.16.tar')
            sps[i].schema_graphs = SchemaGraphs()
            sps[i].load_checkpoint(checkpoint_path)
            sps[i].cuda()
            sps[i].eval()
        self.semantic_parsers = sps
        self.model_ensemble = [sp.mdl for sp in sps]

    if schema is not None:
        self.add_schema(schema)
        self.model_ensemble = None

    # When generating SQL in execution order, cache reordered SQLs to save time
    if args.process_sql_in_execution_order:
        self.pred_restored_cache = self.semantic_parsers[0].load_pred_restored_cache()
    else:
        self.pred_restored_cache = None
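# A hypothetical instantiation of the wrapper above. The class name
# Text2SQLWrapper, the checkpoint directory names, and my_schema_graph are
# placeholders for illustration; add_schema() is the same call used in the
# constructor when a schema is passed in directly.
wrapper = Text2SQLWrapper(args, cs_args, schema=None,
                          ensemble_model_dirs=['model/run_1', 'model/run_2'])
wrapper.add_schema(my_schema_graph)  # register a schema graph before querying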