Example #1
0
def ensemble(checkpoint_name='model-best.16.tar'):
    """Evaluate an ensemble of trained models on the dev (or test) split.

    One ``EncoderDecoderLFramework`` is built and loaded per directory in
    ``ensemble_model_dirs``. Inference is run through the first model with every
    member's ``mdl`` passed as ``model_ensemble``. The top-1 prediction per
    example is written to a text file in the first model's directory, and
    top-k exact-match metrics are printed (plus execution-match metrics for
    WikiSQL, the only dataset for which a DB engine is built).

    Args:
        checkpoint_name: file name of the checkpoint loaded from each model
            directory (defaults to the previously hard-coded name).

    Raises:
        ValueError: if the number of decoded predictions differs from the
            number of evaluated examples.

    Relies on names defined outside this function: ``args``,
    ``ensemble_model_dirs``, ``data_loader``, ``DBEngine``,
    ``EncoderDecoderLFramework`` and ``eval_tools``.
    """
    dataset = data_loader.load_processed_data(args)
    split = 'test' if args.test else 'dev'
    dev_examples = dataset[split]
    print('{} dev examples loaded'.format(len(dev_examples)))
    # Execution-based evaluation needs a database engine; only WikiSQL builds one.
    if args.dataset_name == 'wikisql':
        engine_path = os.path.join(args.data_dir, '{}.db'.format(split))
        engine = DBEngine(engine_path)
    else:
        engine = None

    sps = [EncoderDecoderLFramework(args) for _ in ensemble_model_dirs]
    for sp, model_dir in zip(sps, ensemble_model_dirs):
        checkpoint_path = os.path.join(model_dir, checkpoint_name)
        sp.schema_graphs = dataset['schema']
        sp.load_checkpoint(checkpoint_path)
        sp.cuda()
        sp.eval()

    pred_restored_cache = sps[0].load_pred_restored_cache()
    pred_restored_cache_size = sum(len(v)
                                   for v in pred_restored_cache.values())

    out_dict = sps[0].inference(dev_examples, restore_clause_order=args.process_sql_in_execution_order,
                                pred_restored_cache=pred_restored_cache,
                                check_schema_consistency_=args.sql_consistency_check, engine=engine,
                                inline_eval=True, model_ensemble=[sp.mdl for sp in sps], verbose=True)

    if args.process_sql_in_execution_order:
        # Persist the cache only when inference added new entries to it.
        new_pred_restored_cache_size = sum(
            len(v) for v in out_dict['pred_restored_cache'].values())
        newly_cached_size = new_pred_restored_cache_size - pred_restored_cache_size
        if newly_cached_size > 0:
            sps[0].save_pred_restored_cache(
                out_dict['pred_restored_cache'], newly_cached_size)

    pred_decoded = out_dict['pred_decoded']
    # Validate before opening the output file: opening with 'w' truncates it,
    # so checking afterwards would destroy a previous predictions file.
    if len(dev_examples) != len(pred_decoded):
        raise ValueError('Expected {} predictions, got {}'.format(
            len(dev_examples), len(pred_decoded)))

    out_txt = os.path.join(sps[0].model_dir, 'predictions.ens.{}.{}.{}.{}.txt'.format(
        args.beam_size, args.bs_alpha, split, len(ensemble_model_dirs)))
    # Predicted SQL may contain non-ASCII text; do not depend on the locale encoding.
    with open(out_txt, 'w', encoding='utf-8') as o_f:
        for example, pred_sql in zip(dev_examples, pred_decoded):
            if args.dataset_name == 'wikisql':
                o_f.write('{}\n'.format(json.dumps(
                    {'sql': pred_sql[0], 'table_id': example.db_name})))
            else:
                o_f.write('{}\n'.format(pred_sql[0]))
    print('Model predictions saved to {}'.format(out_txt))

    print('{} set performance'.format(split.upper()))
    metrics = eval_tools.get_exact_match_metrics(
        dev_examples, pred_decoded, engine=engine)
    top_ks = (1, 2, 3, 5, 10)
    for k in top_ks:
        print('Top-{} exact match: {:.3f}'.format(k, metrics['top_{}_em'.format(k)]))
    if args.dataset_name == 'wikisql':
        # Execution accuracy is available because an engine was built above.
        for k in top_ks:
            print('Top-{} exe match: {:.3f}'.format(k, metrics['top_{}_ex'.format(k)]))
    if 'table_err' in metrics:
        print('Table error: {:.3f}'.format(metrics['table_err']))
 def evaluate(examples, out_dict):
     """Print top-k exact-match (and, for WikiSQL, execution-match) accuracy.

     Args:
         examples: gold examples passed to ``eval_tools.get_exact_match_metrics``.
         out_dict: inference output; only ``out_dict['pred_decoded']`` is read.

     ``engine``, ``args`` and ``eval_tools`` are free names resolved outside
     this function (enclosing or module scope).
     """
     metrics = eval_tools.get_exact_match_metrics(examples,
                                                  out_dict['pred_decoded'],
                                                  engine=engine)
     top_ks = (1, 2, 3, 5, 10)
     for k in top_ks:
         print('Top-{} exact match: {:.3f}'.format(k, metrics['top_{}_em'.format(k)]))
     # Execution-match keys are only reported for WikiSQL.
     if args.dataset_name == 'wikisql':
         for k in top_ks:
             print('Top-{} exe match: {:.3f}'.format(k, metrics['top_{}_ex'.format(k)]))
     print('Table error: {:.3f}'.format(metrics['table_err']))
Example #3
0
 def evaluate(examples, out_dict):
     """Print top-k metrics and save them to a per-run performance file.

     Prints top-k exact-match accuracy (and execution-match accuracy for
     WikiSQL) and the table error. It then writes one ``key: value`` line per
     available metric to ``test_performance_<tag>_<beam_size>.txt`` in
     ``sp.model_dir``.

     Args:
         examples: gold examples passed to ``eval_tools.get_exact_match_metrics``.
         out_dict: inference output; only ``out_dict['pred_decoded']`` is read.

     ``engine``, ``args``, ``sp`` and ``eval_tools`` are free names resolved
     outside this function (enclosing or module scope).
     """
     metrics = eval_tools.get_exact_match_metrics(
         examples, out_dict['pred_decoded'], engine=engine)
     top_ks = (1, 2, 3, 5, 10)
     for k in top_ks:
         print('Top-{} exact match: {:.3f}'.format(k, metrics['top_{}_em'.format(k)]))
     # Execution-match keys are only reported for WikiSQL.
     if args.dataset_name == 'wikisql':
         for k in top_ks:
             print('Top-{} exe match: {:.3f}'.format(k, metrics['top_{}_ex'.format(k)]))
     print('Table error: {:.3f}'.format(metrics['table_err']))

     # File-name tag: the second '/'-separated component of data_dir
     # (e.g. 'spider' for 'data/spider'). Fall back to the whole string when
     # data_dir contains no '/', instead of raising IndexError.
     data_dir_parts = args.data_dir.split('/')
     dataset_tag = data_dir_parts[1] if len(data_dir_parts) > 1 else data_dir_parts[0]
     performance = os.path.join(
         sp.model_dir, f"test_performance_{dataset_tag}_{args.beam_size}.txt")
     metric_keys = ['top_1_em', 'top_2_em', 'top_3_em', 'top_5_em', 'top_10_em',
                    'top_1_ex', 'top_2_ex', 'top_3_ex', 'top_5_ex', 'top_10_ex',
                    'table_err']
     with open(performance, 'w') as pf:
         for key in metric_keys:
             # Execution keys are only reported for WikiSQL above; skip any key
             # the metrics dict does not provide rather than raising KeyError.
             if key in metrics:
                 pf.write(f'{key}: {metrics[key]:.3f}\n')
Example #4
0
    def run_train(self, train_data, dev_data):
        """Train the model with periodic dev-set evaluation and checkpointing.

        Iterates over micro-steps from ``self.start_step`` up to
        ``self.num_steps * self.num_accumulation_steps``. Gradients are
        accumulated over ``self.num_accumulation_steps`` micro-steps before each
        gradient-clipping / scheduler / optimizer step.

        After every block of ``num_peek_steps`` micro-steps:
          * the average training loss is printed and logged;
          * the model decodes ``dev_data``;
          * a checkpoint is saved whenever the dev metric matches or beats the
            best so far. The metric is top-1 execution match for WikiSQL and
            top-1 exact match otherwise.

        Optional behaviours controlled by ``self.args``:
          * ``augment_with_wikisql``: WikiSQL examples are split out of the
            train/dev data. They fill the remainder (about 30%) of every
            mini-batch, and WikiSQL dev accuracy is reported every
            ``3 * num_peek_steps`` micro-steps.
          * ``curriculum_interval > 0``: training uses [easy, medium] examples
            first, adds hard examples after one interval and all examples
            after two intervals.
          * ``process_sql_in_execution_order``: the restored-prediction cache is
            loaded before and saved after each dev evaluation.

        Args:
            train_data: training examples.
            dev_data: dev examples used for model selection.
        """
        self.print_model_parameters()

        import wandb
        wandb.init(project='smore-{}-group-{}-final'.format(
            self.args.dataset_name,
            get_no_join_tag(self.args, separator_in_front=True)),
                   group=get_wandb_group(self.args),
                   name=get_wandb_tag(self.args))
        os.environ["WANDB_RUN_GROUP"] = get_wandb_group(self.args)
        wandb.watch(self)

        if self.args.augment_with_wikisql:
            # Separate WikiSQL examples so they can be mixed into each batch at a fixed ratio.
            train_data_, train_data_augment = [], []
            for example in train_data:
                if example.dataset_id == WIKISQL:
                    train_data_augment.append(example)
                else:
                    train_data_.append(example)
            train_data = train_data_
            train_batch_size = round(self.train_batch_size * 0.7)
            train_augment_batch_size = self.train_batch_size - train_batch_size

            dev_data_, dev_data_augment = [], []
            for example in dev_data:
                if example.dataset_id == WIKISQL:
                    dev_data_augment.append(example)
                else:
                    dev_data_.append(example)
            # Non-WikiSQL dev examples drive model selection; the WikiSQL ones
            # are evaluated separately below.
            dev_data = dev_data_
            print('**************************')
            print('{} training examples'.format(len(train_data)))
            print('{} augmented training examples'.format(
                len(train_data_augment)))
            print('train batch size = {}'.format(train_batch_size))
            print('train augment batch size = {}'.format(
                train_augment_batch_size))
            print('{} dev examples'.format(len(dev_data)))
            print('{} augmented dev examples'.format(len(dev_data_augment)))
            print('**************************')
        else:
            train_batch_size = self.train_batch_size
            train_augment_batch_size = 0

        # Track training losses dev metrics changes
        ############################
        epoch_losses = []
        best_dev_metrics = 0
        dev_metrics_history = []
        ############################

        # All step counters below are in micro-steps (one forward/backward pass each).
        num_steps = self.num_steps * self.num_accumulation_steps
        num_peek_steps = self.num_peek_steps * self.num_accumulation_steps
        curriculum_interval = self.args.curriculum_interval * self.num_accumulation_steps

        all_train_data = copy.deepcopy(train_data)
        # Curriculum learning (start from easy category)
        if self.args.curriculum_interval > 0:
            # assert(self.args.curriculum_interval % self.args.num_peek_steps == 0)
            # The in-loop stage switch only fires when step_id lands exactly on
            # a stage boundary. A run resumed past a boundary (start_step > 0)
            # must therefore begin in the stage it had already reached, or it
            # would stay on the [easy, medium] subset forever.
            if curriculum_interval > 0:
                start_stage = min(int(self.start_step // curriculum_interval), 2)
            else:
                start_stage = 0
            if start_stage == 0:
                train_data = [
                    exp for exp in all_train_data
                    if exp.hardness in ['easy', 'medium']
                ]
                print('Curriculumn: [easy, medium] ({}) ------'.format(
                    len(train_data)))
            elif start_stage == 1:
                train_data = [
                    exp for exp in all_train_data
                    if exp.hardness in ['easy', 'medium', 'hard']
                ]
                print('Curriculumn: [easy, medium, hard] ({}) ------'.format(
                    len(train_data)))
            else:
                train_data = all_train_data
                print('Curriculumn: [easy, medium, hard, extra] ({}) ------'.format(
                    len(train_data)))

        random.shuffle(train_data)
        if self.args.augment_with_wikisql:
            random.shuffle(train_data_augment)
            augment_example_id = 0
        step_id, example_id = 0, 0

        self.optim.zero_grad()
        self.train()

        for interval_step_id in range(self.start_step, num_steps,
                                      num_peek_steps):
            # Update model parameters
            self.train()

            for s_id in tqdm(range(num_peek_steps)):
                step_id = interval_step_id + s_id
                if self.log_in_wandb(step_id / self.num_accumulation_steps):
                    wandb.log({
                        'learning_rate/{}'.format(self.dataset):
                        self.optim.param_groups[0]['lr']
                    })
                    wandb.log({
                        'fine_tuning_rate/{}'.format(self.dataset):
                        self.optim.param_groups[1]['lr']
                    })

                batch_end = example_id + train_batch_size
                # Advance the curriculum exactly at the 1st and 2nd interval boundaries.
                if curriculum_interval > 0 and step_id % curriculum_interval == 0 and \
                        0 < step_id / curriculum_interval <= 2:
                    if float(step_id) / curriculum_interval == 1:
                        train_data = [
                            exp for exp in all_train_data
                            if exp.hardness in ['easy', 'medium', 'hard']
                        ]
                        print('Curriculumn: [easy, medium, hard] ({}) ------'.
                              format(len(train_data)))
                    elif float(step_id) / curriculum_interval == 2:
                        train_data = all_train_data
                        print(
                            'Curriculumn: [easy, medium, hard, extra] ({}) ------'
                            .format(len(train_data)))
                    random.shuffle(train_data)
                    example_id, batch_end = 0, train_batch_size
                # Start a new reshuffled pass when the next batch would run past the data.
                if batch_end > len(train_data):
                    random.shuffle(train_data)
                    example_id, batch_end = 0, train_batch_size
                mini_batch = train_data[example_id:batch_end]
                example_id = batch_end
                if self.args.augment_with_wikisql:
                    augment_batch_end = augment_example_id + train_augment_batch_size
                    if augment_batch_end > len(train_data_augment):
                        random.shuffle(train_data_augment)
                        augment_example_id, augment_batch_end = 0, train_augment_batch_size
                    mini_batch += train_data_augment[
                        augment_example_id:augment_batch_end]
                    augment_example_id = augment_batch_end

                formatted_batch = self.format_batch(mini_batch)
                loss = self.loss(formatted_batch)
                loss.backward()
                # NOTE(review): the multiplication presumably undoes a per-micro-step
                # normalisation inside self.loss so the logged value is per batch — confirm.
                epoch_losses.append(float(loss) * self.num_accumulation_steps)

                # Apply the accumulated gradients once per num_accumulation_steps micro-steps.
                if (step_id + 1) % self.num_accumulation_steps == 0:
                    # Gradient clipping
                    if self.grad_norm > 0:
                        nn.utils.clip_grad_norm_(self.parameters(),
                                                 self.grad_norm)
                    # Update learning rate scheduler
                    self.lr_scheduler.step()
                    # Update parameters
                    self.optim.step()
                    self.optim.zero_grad()

            # At the end of each peek interval: report training loss, then evaluate on dev.
            if step_id > 0 and (step_id + 1) % num_peek_steps == 0:
                # Check training statistics
                stdout_msg = 'Step {}: average training loss = {}'.format(
                    step_id / self.num_accumulation_steps,
                    np.mean(epoch_losses))
                print(stdout_msg)
                wandb.log({
                    'cross_entropy_loss/{}'.format(self.dataset):
                    np.mean(epoch_losses)
                })
                epoch_losses = []

                # Check model performance
                self.eval()
                if self.args.process_sql_in_execution_order:
                    pred_restored_cache = self.load_pred_restored_cache()
                    pred_restored_cache_size = sum(
                        len(v) for v in pred_restored_cache.values())
                else:
                    pred_restored_cache = None
                engine_path = os.path.join(
                    self.args.data_dir,
                    'dev.db') if self.args.dataset_name == 'wikisql' else None
                engine = DBEngine(engine_path) if engine_path else None

                output_dict = self.inference(
                    dev_data,
                    restore_clause_order=self.args.
                    process_sql_in_execution_order,
                    pred_restored_cache=pred_restored_cache,
                    check_schema_consistency_=self.args.sql_consistency_check,
                    engine=engine,
                    inline_eval=True,
                    verbose=False)
                metrics = eval_tools.get_exact_match_metrics(
                    dev_data, output_dict['pred_decoded'], engine=engine)
                dev_metrics_history.append(metrics)

                # Model-selection metric: execution accuracy for WikiSQL, exact match otherwise.
                eval_metrics = metrics[
                    'top_1_ex'] if self.args.dataset_name == 'wikisql' else metrics[
                        'top_1_em']
                wandb.log(
                    {'dev_exact_match/{}'.format(self.dataset): eval_metrics})

                print('Dev set performance:')
                print('Top-1 exact match: {}'.format(metrics['top_1_em']))
                print('Top-3 exact match: {}'.format(metrics['top_3_em']))
                if self.args.dataset_name == 'wikisql':
                    print('Top-1 exe acc: {}'.format(metrics['top_1_ex']))
                    print('Top-3 exe acc: {}'.format(metrics['top_3_ex']))

                if eval_metrics >= best_dev_metrics:
                    best_dev_metrics = eval_metrics
                    self.save_checkpoint(step_id,
                                         step_id / num_peek_steps,
                                         output_dict['pred_decoded'],
                                         is_best=True)
                if self.args.augment_with_wikisql and (step_id + 1) % (
                        num_peek_steps * 3) == 0:
                    wikisql_output_dict = self.inference(dev_data_augment,
                                                         inline_eval=True,
                                                         verbose=False)
                    wikisql_metrics = eval_tools.get_exact_match_metrics(
                        dev_data_augment, wikisql_output_dict['pred_decoded'])
                    wandb.log({
                        'wikisql_dev_exact_match/{}'.format(self.dataset):
                        wikisql_metrics['top_1_em']
                    })
                    print('WikiSQL dev set performance:')
                    print('Top-1 exact match: {}'.format(
                        wikisql_metrics['top_1_em']))
                    print('Top-3 exact match: {}'.format(
                        wikisql_metrics['top_3_em']))
                if self.args.process_sql_in_execution_order:
                    # Persist the cache only when this evaluation added new entries.
                    new_pred_restored_cache_size = sum(
                        len(v)
                        for v in output_dict['pred_restored_cache'].values())
                    newly_cached_size = new_pred_restored_cache_size - pred_restored_cache_size
                    if newly_cached_size > 0:
                        self.save_pred_restored_cache(
                            output_dict['pred_restored_cache'],
                            newly_cached_size)