import copy
import json
import os
import random
from argparse import ArgumentParser

import numpy as np
import torch.nn as nn
from tqdm import tqdm

# Project-level helpers (data_loader, eval_tools, data_utils, DBEngine, Query,
# count_lines, EncoderDecoderLFramework, get_checkpoint_path, args,
# ensemble_model_dirs, ...) are assumed to be provided by the surrounding repo.


def ensemble():
    dataset = data_loader.load_processed_data(args)
    split = 'test' if args.test else 'dev'
    dev_examples = dataset[split]
    print('{} {} examples loaded'.format(len(dev_examples), split))

    if args.dataset_name == 'wikisql':
        engine_path = os.path.join(args.data_dir, '{}.db'.format(split))
        engine = DBEngine(engine_path)
    else:
        engine = None

    # Load every ensemble member from its best checkpoint.
    sps = [EncoderDecoderLFramework(args) for _ in ensemble_model_dirs]
    for i, model_dir in enumerate(ensemble_model_dirs):
        checkpoint_path = os.path.join(model_dir, 'model-best.16.tar')
        sps[i].schema_graphs = dataset['schema']
        sps[i].load_checkpoint(checkpoint_path)
        sps[i].cuda()
        sps[i].eval()

    pred_restored_cache = sps[0].load_pred_restored_cache()
    pred_restored_cache_size = sum(len(v) for v in pred_restored_cache.values())

    out_dict = sps[0].inference(
        dev_examples,
        restore_clause_order=args.process_sql_in_execution_order,
        pred_restored_cache=pred_restored_cache,
        check_schema_consistency_=args.sql_consistency_check,
        engine=engine,
        inline_eval=True,
        model_ensemble=[sp.mdl for sp in sps],
        verbose=True)

    if args.process_sql_in_execution_order:
        new_pred_restored_cache_size = sum(
            len(v) for v in out_dict['pred_restored_cache'].values())
        newly_cached_size = new_pred_restored_cache_size - pred_restored_cache_size
        if newly_cached_size > 0:
            sps[0].save_pred_restored_cache(
                out_dict['pred_restored_cache'], newly_cached_size)

    out_txt = os.path.join(sps[0].model_dir, 'predictions.ens.{}.{}.{}.{}.txt'.format(
        args.beam_size, args.bs_alpha, split, len(ensemble_model_dirs)))
    with open(out_txt, 'w') as o_f:
        assert len(dev_examples) == len(out_dict['pred_decoded'])
        for i, pred_sql in enumerate(out_dict['pred_decoded']):
            if args.dataset_name == 'wikisql':
                example = dev_examples[i]
                o_f.write('{}\n'.format(json.dumps(
                    {'sql': pred_sql[0], 'table_id': example.db_name})))
            else:
                o_f.write('{}\n'.format(pred_sql[0]))
        print('Model predictions saved to {}'.format(out_txt))

    print('{} set performance'.format(split.upper()))
    metrics = eval_tools.get_exact_match_metrics(
        dev_examples, out_dict['pred_decoded'], engine=engine)
    print('Top-1 exact match: {:.3f}'.format(metrics['top_1_em']))
    print('Top-2 exact match: {:.3f}'.format(metrics['top_2_em']))
    print('Top-3 exact match: {:.3f}'.format(metrics['top_3_em']))
    print('Top-5 exact match: {:.3f}'.format(metrics['top_5_em']))
    print('Top-10 exact match: {:.3f}'.format(metrics['top_10_em']))
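# The `model_ensemble` argument above hands the underlying decoders of all
# checkpoints to a single inference call. Below is a minimal sketch of what
# step-wise ensembling of autoregressive decoders typically looks like,
# assuming each member exposes a per-step logits call; `decode_step` and the
# surrounding names are hypothetical, not the framework's actual API.
import torch


def ensemble_decode_step(models, decoder_inputs):
    """Average the next-token distributions of all ensemble members (sketch)."""
    probs = []
    for m in models:
        step_logits = m.decode_step(decoder_inputs)  # hypothetical per-model call
        probs.append(torch.softmax(step_logits, dim=-1))
    # Averaging probabilities (rather than raw logits) keeps the result a valid
    # distribution over the output vocabulary at every decoding step.
    return torch.stack(probs, dim=0).mean(dim=0)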
def inference(sp):
    dataset = data_loader.load_processed_data(args)
    split = 'test' if args.test else 'dev'
    if args.dataset_name == 'wikisql':
        engine_path = os.path.join(args.data_dir, '{}.db'.format(split))
        engine = DBEngine(engine_path)
    else:
        engine = None

    def evaluate(examples, out_dict):
        metrics = eval_tools.get_exact_match_metrics(
            examples, out_dict['pred_decoded'], engine=engine)
        print('Top-1 exact match: {:.3f}'.format(metrics['top_1_em']))
        print('Top-2 exact match: {:.3f}'.format(metrics['top_2_em']))
        print('Top-3 exact match: {:.3f}'.format(metrics['top_3_em']))
        print('Top-5 exact match: {:.3f}'.format(metrics['top_5_em']))
        print('Top-10 exact match: {:.3f}'.format(metrics['top_10_em']))
        if args.dataset_name == 'wikisql':
            print('Top-1 exe match: {:.3f}'.format(metrics['top_1_ex']))
            print('Top-2 exe match: {:.3f}'.format(metrics['top_2_ex']))
            print('Top-3 exe match: {:.3f}'.format(metrics['top_3_ex']))
            print('Top-5 exe match: {:.3f}'.format(metrics['top_5_ex']))
            print('Top-10 exe match: {:.3f}'.format(metrics['top_10_ex']))
            print('Table error: {:.3f}'.format(metrics['table_err']))
        performance = os.path.join(
            sp.model_dir,
            f"test_performance_{args.data_dir.split('/')[1]}_{args.beam_size}.txt")
        metric_keys = ['top_1_em', 'top_2_em', 'top_3_em', 'top_5_em', 'top_10_em',
                       'top_1_ex', 'top_2_ex', 'top_3_ex', 'top_5_ex', 'top_10_ex',
                       'table_err']
        with open(performance, 'w') as pf:
            for key in metric_keys:
                # Guard against execution-accuracy keys that are absent when no
                # execution engine is available (non-WikiSQL datasets).
                if key in metrics:
                    pf.write(f'{key}: {metrics[key]:.3f}\n')

    examples = dataset[split]
    # random.shuffle(examples)
    sp.schema_graphs = dataset['schema']
    print('{} {} examples loaded'.format(len(examples), split))

    if sp.args.use_pred_tables:
        in_table = os.path.join(sp.args.model_dir, 'predicted_tables.txt')
        with open(in_table) as f:
            content = f.readlines()
        assert len(content) == len(examples)
        for example, line in zip(examples, content):
            pred_tables = set(
                x.strip()[1:-1] for x in line.strip()[1:-1].split(','))
            example.leaf_condition_vals_list = pred_tables

    sp.load_checkpoint(get_checkpoint_path(args))
    sp.eval()

    if sp.args.augment_with_wikisql:
        examples_, examples_wikisql = [], []
        for example in examples:
            if example.dataset_id == data_utils.WIKISQL:
                examples_wikisql.append(example)
            else:
                examples_.append(example)
        examples = examples_

    pred_restored_cache = sp.load_pred_restored_cache()
    pred_restored_cache_size = sum(len(v) for v in pred_restored_cache.values())

    out_dict = sp.inference(
        examples,
        restore_clause_order=args.process_sql_in_execution_order,
        pred_restored_cache=pred_restored_cache,
        check_schema_consistency_=args.sql_consistency_check,
        engine=engine,
        inline_eval=True,
        verbose=True)

    if args.process_sql_in_execution_order:
        new_pred_restored_cache_size = sum(
            len(v) for v in out_dict['pred_restored_cache'].values())
        newly_cached_size = new_pred_restored_cache_size - pred_restored_cache_size
        if newly_cached_size > 0:
            sp.save_pred_restored_cache(
                out_dict['pred_restored_cache'], newly_cached_size)

    out_txt = os.path.join(sp.model_dir, 'predictions.{}.{}.{}.txt'.format(
        args.beam_size, args.bs_alpha, split))
    with open(out_txt, 'w') as o_f:
        assert len(examples) == len(out_dict['pred_decoded'])
        for i, pred_sql in enumerate(out_dict['pred_decoded']):
            if args.dataset_name == 'wikisql':
                example = examples[i]
                o_f.write('{}\n'.format(json.dumps(
                    {'sql': pred_sql[0], 'table_id': example.db_name})))
            else:
                o_f.write('{}\n'.format(pred_sql[0]))
        print('Model predictions saved to {}'.format(out_txt))

    print('{} set performance'.format(split.upper()))
    evaluate(examples, out_dict)

    if args.augment_with_wikisql:
        wikisql_out_dict = sp.forward(examples_wikisql, verbose=False)
        print('*** WikiSQL ***')
        evaluate(examples_wikisql, wikisql_out_dict)
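# A quick sanity check of the predicted-tables parsing in `inference` above.
# Each line of predicted_tables.txt is assumed (inferred from the slicing
# logic, not documented in the source) to be a bracketed, comma-separated list
# of quoted table names, one line per example.
def parse_pred_tables_line(line):
    return set(x.strip()[1:-1] for x in line.strip()[1:-1].split(','))


assert parse_pred_tables_line("['concert', 'stadium']\n") == {'concert', 'stadium'}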
def run_train(self, train_data, dev_data):
    self.print_model_parameters()

    import wandb
    wandb.init(
        project='smore-{}-group-{}-final'.format(
            self.args.dataset_name,
            get_no_join_tag(self.args, separator_in_front=True)),
        group=get_wandb_group(self.args),
        name=get_wandb_tag(self.args))
    os.environ['WANDB_RUN_GROUP'] = get_wandb_group(self.args)
    wandb.watch(self)

    if self.args.augment_with_wikisql:
        # Split out the WikiSQL augmentation examples so they can be sampled
        # with their own batch quota.
        train_data_, train_data_augment = [], []
        for example in train_data:
            if example.dataset_id == WIKISQL:
                train_data_augment.append(example)
            else:
                train_data_.append(example)
        train_data = train_data_
        train_batch_size = round(self.train_batch_size * 0.7)
        train_augment_batch_size = self.train_batch_size - train_batch_size

        dev_data_, dev_data_augment = [], []
        for example in dev_data:
            if example.dataset_id == WIKISQL:
                dev_data_augment.append(example)
            else:
                dev_data_.append(example)
        dev_data = dev_data_

        print('**************************')
        print('{} training examples'.format(len(train_data)))
        print('{} augmented training examples'.format(len(train_data_augment)))
        print('train batch size = {}'.format(train_batch_size))
        print('train augment batch size = {}'.format(train_augment_batch_size))
        print('{} dev examples'.format(len(dev_data)))
        print('{} augmented dev examples'.format(len(dev_data_augment)))
        print('**************************')
    else:
        train_batch_size = self.train_batch_size
        train_augment_batch_size = 0

    # Track training losses and dev metrics changes
    epoch_losses = []
    best_dev_metrics = 0
    dev_metrics_history = []

    all_train_data = copy.deepcopy(train_data)

    # Curriculum learning (start from the easy categories)
    if self.args.curriculum_interval > 0:
        # assert(self.args.curriculum_interval % self.args.num_peek_steps == 0)
        train_data = [exp for exp in all_train_data
                      if exp.hardness in ['easy', 'medium']]
        print('Curriculum: [easy, medium] ({}) ------'.format(len(train_data)))

    num_steps = self.num_steps * self.num_accumulation_steps
    num_peek_steps = self.num_peek_steps * self.num_accumulation_steps
    curriculum_interval = self.args.curriculum_interval * self.num_accumulation_steps

    random.shuffle(train_data)
    if self.args.augment_with_wikisql:
        random.shuffle(train_data_augment)
        augment_example_id = 0
    step_id, example_id = 0, 0

    self.optim.zero_grad()
    self.train()

    for interval_step_id in range(self.start_step, num_steps, num_peek_steps):
        # Update model parameters
        self.train()
        for s_id in tqdm(range(num_peek_steps)):
            step_id = interval_step_id + s_id
            if self.log_in_wandb(step_id / self.num_accumulation_steps):
                wandb.log({'learning_rate/{}'.format(self.dataset):
                           self.optim.param_groups[0]['lr']})
                wandb.log({'fine_tuning_rate/{}'.format(self.dataset):
                           self.optim.param_groups[1]['lr']})

            batch_end = example_id + train_batch_size
            if curriculum_interval > 0 and step_id % curriculum_interval == 0 and \
                    0 < step_id / curriculum_interval <= 2:
                if float(step_id) / curriculum_interval == 1:
                    train_data = [exp for exp in all_train_data
                                  if exp.hardness in ['easy', 'medium', 'hard']]
                    print('Curriculum: [easy, medium, hard] ({}) ------'.format(
                        len(train_data)))
                elif float(step_id) / curriculum_interval == 2:
                    train_data = all_train_data
                    print('Curriculum: [easy, medium, hard, extra] ({}) ------'.format(
                        len(train_data)))
                random.shuffle(train_data)
                example_id, batch_end = 0, train_batch_size
            if batch_end > len(train_data):
                random.shuffle(train_data)
                example_id, batch_end = 0, train_batch_size
            mini_batch = train_data[example_id:batch_end]
            example_id = batch_end

            if self.args.augment_with_wikisql:
                augment_batch_end = augment_example_id + train_augment_batch_size
                if augment_batch_end > len(train_data_augment):
                    random.shuffle(train_data_augment)
                    augment_example_id, augment_batch_end = \
                        0, train_augment_batch_size
                mini_batch += train_data_augment[
                    augment_example_id:augment_batch_end]
                augment_example_id = augment_batch_end

            formatted_batch = self.format_batch(mini_batch)
            loss = self.loss(formatted_batch)
            loss.backward()
            epoch_losses.append(float(loss) * self.num_accumulation_steps)

            if (step_id + 1) % self.num_accumulation_steps == 0:
                # Gradient clipping
                if self.grad_norm > 0:
                    nn.utils.clip_grad_norm_(self.parameters(), self.grad_norm)
                # Update learning rate scheduler
                self.lr_scheduler.step()
                # Update parameters
                self.optim.step()
                self.optim.zero_grad()

        # Check training statistics
        if step_id > 0 and (step_id + 1) % num_peek_steps == 0:
            stdout_msg = 'Step {}: average training loss = {}'.format(
                step_id / self.num_accumulation_steps, np.mean(epoch_losses))
            print(stdout_msg)
            wandb.log({'cross_entropy_loss/{}'.format(self.dataset):
                       np.mean(epoch_losses)})
            epoch_losses = []

        # Check model performance
        if step_id > 0 and (step_id + 1) % num_peek_steps == 0:
            self.eval()
            if self.args.process_sql_in_execution_order:
                pred_restored_cache = self.load_pred_restored_cache()
                pred_restored_cache_size = sum(
                    len(v) for v in pred_restored_cache.values())
            else:
                pred_restored_cache = None
            engine_path = os.path.join(self.args.data_dir, 'dev.db') \
                if self.args.dataset_name == 'wikisql' else None
            engine = DBEngine(engine_path) if engine_path else None
            output_dict = self.inference(
                dev_data,
                restore_clause_order=self.args.process_sql_in_execution_order,
                pred_restored_cache=pred_restored_cache,
                check_schema_consistency_=self.args.sql_consistency_check,
                engine=engine,
                inline_eval=True,
                verbose=False)
            metrics = eval_tools.get_exact_match_metrics(
                dev_data, output_dict['pred_decoded'], engine=engine)
            dev_metrics_history.append(metrics)
            eval_metrics = metrics['top_1_ex'] \
                if self.args.dataset_name == 'wikisql' else metrics['top_1_em']
            wandb.log({'dev_exact_match/{}'.format(self.dataset): eval_metrics})
            print('Dev set performance:')
            print('Top-1 exact match: {}'.format(metrics['top_1_em']))
            print('Top-3 exact match: {}'.format(metrics['top_3_em']))
            if self.args.dataset_name == 'wikisql':
                print('Top-1 exe acc: {}'.format(metrics['top_1_ex']))
                print('Top-3 exe acc: {}'.format(metrics['top_3_ex']))

            if eval_metrics >= best_dev_metrics:
                best_dev_metrics = eval_metrics
                self.save_checkpoint(step_id, step_id / num_peek_steps,
                                     output_dict['pred_decoded'], is_best=True)

            if self.args.augment_with_wikisql and \
                    (step_id + 1) % (num_peek_steps * 3) == 0:
                wikisql_output_dict = self.inference(
                    dev_data_augment, inline_eval=True, verbose=False)
                wikisql_metrics = eval_tools.get_exact_match_metrics(
                    dev_data_augment, wikisql_output_dict['pred_decoded'])
                wandb.log({'wikisql_dev_exact_match/{}'.format(self.dataset):
                           wikisql_metrics['top_1_em']})
                print('WikiSQL dev set performance:')
                print('Top-1 exact match: {}'.format(
                    wikisql_metrics['top_1_em']))
                print('Top-3 exact match: {}'.format(
                    wikisql_metrics['top_3_em']))

            if self.args.process_sql_in_execution_order:
                new_pred_restored_cache_size = sum(
                    len(v) for v in output_dict['pred_restored_cache'].values())
                newly_cached_size = \
                    new_pred_restored_cache_size - pred_restored_cache_size
                if newly_cached_size > 0:
                    self.save_pred_restored_cache(
                        output_dict['pred_restored_cache'], newly_cached_size)
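# `run_train` above interleaves gradient accumulation with curriculum
# scheduling: backward() runs on every mini-batch, while clipping, the LR
# scheduler, and the optimizer step fire once per num_accumulation_steps
# batches. (Logging `float(loss) * num_accumulation_steps` suggests, though the
# excerpt does not show it, that self.loss scales the loss down internally.)
# A self-contained sketch of the same accumulation pattern, with placeholder
# model/optimizer/batches:
import torch
import torch.nn as nn


def train_with_accumulation(model, optim, batches, loss_fn,
                            num_accumulation_steps=4, grad_norm=5.0):
    optim.zero_grad()
    for step_id, (x, y) in enumerate(batches):
        # Scale each mini-batch loss so the accumulated gradient matches a
        # single update computed over the combined batch.
        loss = loss_fn(model(x), y) / num_accumulation_steps
        loss.backward()
        if (step_id + 1) % num_accumulation_steps == 0:
            nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
            optim.step()
            optim.zero_grad()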
    except Exception as e:
        pred = repr(e)
    correct = (pred == gold) if engine else False
    match = qp == qg
    return correct, match


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('source_file', help='source file for the prediction')
    parser.add_argument('db_file', help='source database for the prediction')
    parser.add_argument('pred_file', help='predictions by the model')
    parser.add_argument('--ordered', action='store_true',
                        help='whether the exact match should consider the order of conditions')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    exact_match = []
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'], ordered=args.ordered)
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            pred = ep.get('error', None)
            qp = None
            if not ep.get('error', None):
                try:
                    qp = Query.from_dict(ep['query'], ordered=args.ordered)
                    pred = engine.execute_query(eg['table_id'], qp, lower=True)
                except Exception as e:
                    pred = repr(e)
            # Execution accuracy compares query results; logical-form accuracy
            # compares the parsed queries themselves.
            grades.append(pred == gold)
            exact_match.append(qp == qg)
        print(json.dumps({
            'ex_accuracy': sum(grades) / len(grades),
            'lf_accuracy': sum(exact_match) / len(exact_match),
        }, indent=2))
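# Example invocation (the script filename here is an assumption; the positional
# arguments follow the ArgumentParser definition above):
#
#   python evaluate.py dev.jsonl dev.db predictions.dev.txt --ordered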