def evaluate(model, dev_loader, table_data, beam_size):
    """Run beam-search decoding over *dev_loader* and score the predictions.

    For every row we build an example, parse it with the model, and compare
    the predicted sketch / full action sequence against the ground truth.

    Returns a 4-tuple:
        (sketch accuracy, full rule-label accuracy,
         fraction of rows where value extraction failed, list of prediction dicts)
    """
    model.eval()
    sketch_correct, rule_label_correct, not_all_values_found, total = 0, 0, 0, 0
    predictions = []
    for batch in tqdm(dev_loader, desc="Evaluating"):
        for data_row in batch:
            try:
                example = build_example(data_row, table_data)
            except Exception as e:
                print("Exception while building example (evaluation): {}".format(e))
                continue

            with torch.no_grad():
                results_all = model.parse(example, beam_size=beam_size)
            results = results_all[0]
            list_preds = []
            try:
                # Assemble the predicted actions (including leaf nodes) as a string.
                full_prediction = " ".join([str(x) for x in results[0].actions])
                for x in results:
                    # BUGFIX: join the per-action tokens. The original
                    # " ".join(str(x.actions)) space-separated the *characters*
                    # of the stringified list, producing garbage.
                    list_preds.append(" ".join(str(a) for a in x.actions))
            except Exception as e:
                # Best-effort: a failed decode counts as an empty prediction.
                full_prediction = ""

            prediction = example.sql_json['pre_sql']
            # Assemble the predicted sketch actions as a string.
            prediction['sketch_result'] = " ".join(str(x) for x in results_all[1])
            prediction['model_result'] = full_prediction

            truth_sketch = " ".join([str(x) for x in example.sketch])
            truth_rule_label = " ".join([str(x) for x in example.tgt_actions])

            if prediction['all_values_found']:
                if truth_sketch == prediction['sketch_result']:
                    sketch_correct += 1
                if truth_rule_label == prediction['model_result']:
                    rule_label_correct += 1
            else:
                question = prediction['question']
                print(f'Not all values found during pre-processing for question {question}. Replace values with dummy to make query fail')
                # Force the query to fail by replacing values with dummies.
                prediction['values'] = [1] * len(prediction['values'])
                not_all_values_found += 1

            total += 1
            predictions.append(prediction)

    # Guard: an empty dev set previously raised ZeroDivisionError.
    if total == 0:
        return 0.0, 0.0, 0.0, predictions
    return (float(sketch_correct) / float(total),
            float(rule_label_correct) / float(total),
            float(not_all_values_found) / float(total),
            predictions)
def _inference_semql(data_row, schemas, model):
    """Parse one data row with greedy decoding (beam size 1).

    Returns the prediction dict (the row's pre-processed SQL JSON with the
    decoded action string stored under 'model_result') and the built example.
    """
    example = build_example(data_row, schemas)
    with torch.no_grad():
        all_results = model.parse(example, beam_size=1)
    beams = all_results[0]
    # Assemble the predicted actions (including leaf nodes) as one string.
    top_beam = beams[0]
    full_prediction = " ".join(str(action) for action in top_beam.actions)
    prediction = example.sql_json['pre_sql']
    prediction['model_result'] = full_prediction
    return prediction, example
def train(global_step, tb_writer, train_dataloader, table_data, model,
          optimizer, clip_grad, sketch_loss_weight=1, lf_loss_weight=1):
    """Run one training epoch and return the updated global step counter.

    Per batch: build examples (rows that fail to build are skipped with a
    printed warning), compute the weighted sketch + logical-form loss,
    back-propagate, clip gradients, and step the optimizer. The cumulative
    per-step loss delta is logged to TensorBoard at each global step.
    """
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    model.train()
    for batch in tqdm(train_dataloader, desc="Training"):
        examples = []
        for data_row in batch:
            try:
                examples.append(build_example(data_row, table_data))
            except RuntimeError as e:
                print("Exception while building example (training): {}".format(e))
        # Sort longest source sentence first — presumably required by the
        # model's batching/packing of sequences; confirm against model.forward.
        examples.sort(key=lambda ex: -len(ex.src_sent))

        sketch_loss, lf_loss = model.forward(examples)
        # The model returns (log-)likelihood terms, hence the negation
        # before averaging into a minimizable loss.
        loss = (lf_loss_weight * torch.mean(-lf_loss)
                + sketch_loss_weight * torch.mean(-sketch_loss))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        tr_loss += loss.item()
        optimizer.step()
        # After the optimizer step, reset gradients for the next batch.
        model.zero_grad()

        global_step += 1
        tb_writer.add_scalar('loss', (tr_loss - logging_loss), global_step)
        logging_loss = tr_loss
    return global_step