Example #1
0
def train_reranker_and_test(args):
    """Train (or load) a grid-search reranker and report its test performance.

    Reads from ``args``: ``test_file``, ``dev_file``, ``features``,
    ``evaluator``, ``dev_decode_file``, ``test_decode_file``,
    ``load_reranker``, ``num_workers`` and ``save_to``.  All progress and
    result messages are printed to stderr.

    Raises:
        StopIteration: if none of the instantiated features exposes a
            ``transition_system`` attribute.
    """
    print('load dataset [test %s], [dev %s]' % (args.test_file, args.dev_file), file=sys.stderr)
    test_set = Dataset.from_bin_file(args.test_file)
    dev_set = Dataset.from_bin_file(args.dev_file)

    features = []
    for feat_name in args.features:
        feat_cls = Registrable.by_name(feat_name)
        print('Add feature %s' % feat_name, file=sys.stderr)
        if issubclass(feat_cls, nn.Module):
            # Module-based features are restored from a checkpoint named after the feature.
            feat_path = os.path.join('saved_models/conala/', feat_name + '.bin')
            feat_inst = feat_cls.load(feat_path)
            print('Load feature %s from %s' % (feat_name, feat_path), file=sys.stderr)
        else:
            feat_inst = feat_cls()

        features.append(feat_inst)

    # The transition system is borrowed from the first feature that carries one.
    transition_system = next(feat.transition_system for feat in features if hasattr(feat, 'transition_system'))
    evaluator = Registrable.by_name(args.evaluator)(transition_system)

    # NOTE: pickle can execute arbitrary code on load; only point these
    # arguments at decode files you produced yourself.
    print('load dev decode results [%s]' % args.dev_decode_file, file=sys.stderr)
    with open(args.dev_decode_file, 'rb') as dev_decode_fh:
        dev_decode_results = pickle.load(dev_decode_fh)
    dev_eval_results = evaluator.evaluate_dataset(dev_set, dev_decode_results, fast_mode=False)

    print('load test decode results [%s]' % args.test_decode_file, file=sys.stderr)
    with open(args.test_decode_file, 'rb') as test_decode_fh:
        test_decode_results = pickle.load(test_decode_fh)
    test_eval_results = evaluator.evaluate_dataset(test_set, test_decode_results, fast_mode=False)

    print('Dev Eval Results', file=sys.stderr)
    print(dev_eval_results, file=sys.stderr)
    print('Test Eval Results', file=sys.stderr)
    print(test_eval_results, file=sys.stderr)

    if args.load_reranker:
        reranker = GridSearchReranker.load(args.load_reranker)
    else:
        reranker = GridSearchReranker(features, transition_system=transition_system)

        # The reranker is fitted on the dev set decodes only.
        if args.num_workers == 1:
            reranker.train(dev_set.examples, dev_decode_results, evaluator=evaluator)
        else:
            reranker.train_multiprocess(dev_set.examples, dev_decode_results, evaluator=evaluator, num_workers=args.num_workers)

        if args.save_to:
            print('Save Reranker to %s' % args.save_to, file=sys.stderr)
            reranker.save(args.save_to)

    test_score_with_rerank = reranker.compute_rerank_performance(test_set.examples, test_decode_results, verbose=True,
                                                                 evaluator=evaluator, args=args)

    print('Test Eval Results After Reranking', file=sys.stderr)
    print(test_score_with_rerank, file=sys.stderr)
Example #2
0
def eval_streg():
    """Check decoded hypotheses against gold code and write a per-example report.

    ``sys.argv[1]`` is the binary test dataset and ``sys.argv[2]`` a pickle of
    decode results (each item exposing ``.decodes``).  A running
    ``acc----idx/total`` line is printed for every example, and the results
    are written to ``testi-first50-report.txt``.
    """
    test_set = Dataset.from_bin_file(sys.argv[1])
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(sys.argv[2], "rb") as decode_file:
        decodes = pickle.load(decode_file)
    decodes = [x.decodes for x in decodes]
    results = []
    acc = 0
    for idx, (pred_hyps, gt_exs) in enumerate(zip(decodes, test_set)):
        # Whitespace is stripped from both sides before comparison.
        pred_codes = [x.code.replace(" ", "") for x in pred_hyps]
        gt_code = gt_exs.tgt_code.replace(" ", "")

        match_result = batch_filtering_test(gt_code,
                                            pred_codes,
                                            gt_exs.meta,
                                            flag_force=True)
        results.append(match_result)
        # The first field of the result decides whether the example counts as correct.
        if match_result[0]:
            acc += 1
        print("{}----{}/{}".format(acc, idx, len(decodes)))

    with open("testi-first50-report.txt", "w") as f:
        for i, res in enumerate(results):
            line_fields = [str(i), str(res[0]), str(res[1])]
            # Third field is emitted as "position,value" pairs.
            line_fields.extend(
                "{},{}".format(pos, val) for pos, val in enumerate(res[2]))
            f.write(" ".join(line_fields) + "\n")
Example #3
0
def test(args):
    """Evaluate a saved parser on the test set and optionally save its decodes.

    Reads from ``args``: ``test_file``, ``load_model``, ``cuda``, ``verbose``
    and ``save_decode_to``.  ``args.lang`` is overwritten with the language
    stored in the checkpoint.  Evaluation results are printed to stderr.

    Raises:
        AssertionError: if ``args.load_model`` is empty.
    """
    test_set = Dataset.from_bin_file(args.test_file)
    assert args.load_model

    print('load model from [%s]' % args.load_model, file=sys.stderr)
    # map_location returns the storage unchanged so the checkpoint is not
    # moved to the device it was saved from.
    params = torch.load(args.load_model,
                        map_location=lambda storage, loc: storage)
    vocab = params['vocab']
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_state = params['state_dict']
    saved_args.cuda = args.cuda
    # set the correct domain from saved arg
    args.lang = saved_args.lang

    update_args(saved_args)

    parser_cls = get_parser_class(saved_args.lang)
    parser = parser_cls(saved_args, vocab, transition_system)

    parser.load_state_dict(saved_state)

    if args.cuda:
        parser = parser.cuda()
    parser.eval()

    eval_results, decode_results = evaluation.evaluate(
        test_set.examples,
        parser,
        args,
        verbose=args.verbose,
        return_decode_result=True)
    print(eval_results, file=sys.stderr)
    if args.save_decode_to:
        # Context manager guarantees the pickle is flushed and the handle closed.
        with open(args.save_decode_to, 'wb') as decode_file:
            pickle.dump(decode_results, decode_file)
Example #4
0
    def MAF_comparison_boxplot(self):
        """Draw, save and show one MAF boxplot per population level.

        Each long-format frame returned by ``_generate_maf_long_df`` is
        restricted to the populations selected for its level and plotted
        with one box per population, coloured by panel.  The figure is saved
        under ``PLOTS_DIR`` as ``MAF_comparison__<level>``.
        """
        frames_by_level = self._generate_maf_long_df()

        wanted_by_level = {
            "superpopulation": ['AFR', 'EUR', 'AMR'],
            "population": Dataset.used_populations(),
        }

        for level, frame in frames_by_level.items():
            # Keep only the rows of the populations selected for this level.
            wanted = wanted_by_level[level]
            frame = frame[frame["population"].isin(wanted)]

            # The per-population figure is wider than the superpopulation one.
            if level == "population":
                width_in = 13
            else:
                width_in = 7
            figure = plt.figure(figsize=(width_in, 4))
            axes = figure.add_subplot(1, 1, 1)

            # Palette: colours of the panels present, in panel_colors() order.
            present_panels = frame["panel"].unique()
            palette_colors = []
            for label, color in panel_colors().items():
                if label in present_panels:
                    palette_colors.append(color)

            sns.boxplot(
                data=frame, x="population", y="MAF", hue="panel", ax=axes,
                linewidth=0.3, showcaps=False, showfliers=False,
                palette=sns.color_palette(palette_colors), width=0.70,
            )

            self._boxplot_aesthetics(axes)

            out_path = join(self.PLOTS_DIR, "MAF_comparison__{}".format(level))
            plt.savefig(out_path, bbox_inches="tight")
            plt.show()
    def plot(self, filename, panel_list, ancestries_df):
        """Draw one ternary scatter subplot per panel and save the figure.

        Args:
            filename: name of the output file inside ``PLOTS_DIR``.
            panel_list: panels to plot; each must expose ``label`` and ``name``.
            ancestries_df: frame indexed (among others) by a ``panel`` level,
                holding a ``population`` column and ``EUR``/``AFR``/``AMR``
                columns.
        """
        dataset_label, _ = self._unique_dataset_and_K_check(ancestries_df)
        dataset = Dataset(dataset_label)
        population_order = Dataset.used_populations()

        # Single row of subplots, one column per panel.
        n_rows = 1
        n_cols = len(panel_list)
        unit_width, unit_height = self.PLOT_SIZE
        total_width = n_cols * unit_width
        total_height = n_rows * unit_height
        fig = plt.figure(figsize=(total_width, total_height), dpi=30)
        fig.set_size_inches(total_width, total_height)

        # Subplot positions are 1-based and filled left to right.
        for position, panel in enumerate(panel_list, start=1):
            panel_df = ancestries_df.xs(panel.label, level="panel")
            panel_df = panel_df.reset_index(drop=True)
            panel_df = panel_df.set_index("population")
            title = "Dataset: {}\n{}".format(dataset.name, panel.name)

            ax = fig.add_subplot(n_rows, n_cols, position)
            fig, tax = ternary.figure(scale=1, ax=ax)

            # Order rows by population, keep the three ancestry columns and
            # drop incomplete rows before grouping.
            panel_df = panel_df.loc[population_order]
            panel_df = panel_df[["EUR", "AFR", "AMR"]].dropna()
            by_population = panel_df.groupby(level="population", sort=False)

            for pop_name, pop_df in by_population:
                tax.scatter(
                    pop_df.values,
                    label=pop_name,
                    s=45,
                    alpha=0.75,
                    color=population_colors(pop_name),
                    marker=population_markers(pop_name),
                )

            self._ternary_plot_aesthetics(tax, title, panel_df)

        makedirs(self.PLOTS_DIR, exist_ok=True)
        out_path = join(self.PLOTS_DIR, filename)
        plt.savefig(out_path, bbox_inches="tight")
Example #6
0
def test(args):
    """Evaluate a saved parser on the test set and optionally save its decodes.

    Reads from ``args``: ``test_file``, ``load_model``, ``cuda``, ``parser``,
    ``evaluator``, ``verbose`` and ``save_decode_to``.  ``args.lang`` is
    overwritten with the language stored in the checkpoint.  Evaluation
    results are printed to stderr.

    Raises:
        AssertionError: if ``args.load_model`` is empty.
    """
    test_set = Dataset.from_bin_file(args.test_file)
    assert args.load_model

    print('load model from [%s]' % args.load_model, file=sys.stderr)
    # The checkpoint is read here only for its metadata; the parser itself is
    # restored below through ``parser_cls.load``.
    params = torch.load(args.load_model,
                        map_location=lambda storage, loc: storage)
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_args.cuda = args.cuda
    # set the correct domain from saved arg
    args.lang = saved_args.lang

    parser_cls = Registrable.by_name(args.parser)
    parser = parser_cls.load(model_path=args.load_model, cuda=args.cuda)
    parser.eval()
    evaluator = Registrable.by_name(args.evaluator)(transition_system,
                                                    args=args)
    eval_results, decode_results = evaluation.evaluate(
        test_set.examples,
        parser,
        evaluator,
        args,
        verbose=args.verbose,
        return_decode_result=True)
    print(eval_results, file=sys.stderr)
    if args.save_decode_to:
        # Context manager guarantees the pickle is flushed and the handle closed.
        with open(args.save_decode_to, 'wb') as decode_file:
            pickle.dump(decode_results, decode_file)
Example #7
0
def test(args):
    """Evaluate a saved parser on the test set and optionally save its decodes.

    Reads from ``args``: ``test_file``, ``load_model``, ``cuda``, ``verbose``
    and ``save_decode_to``.  ``args.lang`` is overwritten with the language
    stored in the checkpoint.  Evaluation results are printed to stderr.

    Raises:
        AssertionError: if ``args.load_model`` is empty.
    """
    test_set = Dataset.from_bin_file(args.test_file)
    assert args.load_model

    print('load model from [%s]' % args.load_model, file=sys.stderr)
    params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    vocab = params['vocab']
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_state = params['state_dict']
    # The device choice comes from the current run, not from the checkpoint.
    saved_args.cuda = args.cuda
    # set the correct domain from saved arg
    args.lang = saved_args.lang

    update_args(saved_args)

    parser_cls = get_parser_class(saved_args.lang)
    parser = parser_cls(saved_args, vocab, transition_system)

    parser.load_state_dict(saved_state)

    if args.cuda:
        parser = parser.cuda()
    parser.eval()

    eval_results, decode_results = evaluation.evaluate(test_set.examples, parser, args,
                                                       verbose=args.verbose, return_decode_result=True)
    print(eval_results, file=sys.stderr)
    if args.save_decode_to:
        # Context manager guarantees the pickle is flushed and the handle closed.
        with open(args.save_decode_to, 'wb') as decode_file:
            pickle.dump(decode_results, decode_file)
Example #8
0
def test(args):
    """Annotate train/test examples, load a BERT-backed parser and evaluate it.

    Reads from ``args``: ``train_file``, ``test_file``, ``load_model``,
    ``cuda``, ``parser``, ``evaluator``, ``verbose`` and ``save_decode_to``.
    ``args.lang`` is overwritten with the language stored in the checkpoint.

    The training set is annotated first, with the same vocabulary dicts that
    are later used for the test set and handed to the parser.
    NOTE(review): presumably the helpers grow these dicts in place, which is
    why the training pass must precede the test pass -- confirm against
    ``dependency``.

    Raises:
        AssertionError: if ``args.load_model`` is empty.
    """
    tmpvocab = {'null': 0}
    tmpvocab1 = {'null': 0, 'unk': 1}
    tmpvocab2 = {'null': 0, 'unk': 1}
    tmpvocab3 = {'null': 0, 'unk': 1}
    train_set = Dataset.from_bin_file(args.train_file)
    from dependency import sentencetoadj, sentencetoextra_message

    def _annotate(example, extra_flag):
        """Attach graph and extra-message annotations to ``example``.

        ``extra_flag`` is forwarded as the last argument of
        ``sentencetoextra_message``; returns the last value produced by
        ``sentencetoadj``.
        """
        (example.mainnode, example.adj, example.edge,
         _, isv) = sentencetoadj(example.src_sent, tmpvocab)
        (example.contains, example.pos, example.ner,
         example.types, example.tins, example.F1) = sentencetoextra_message(
            example.src_sent,
            [item.tokens for item in example.table.header],
            [item.type for item in example.table.header],
            tmpvocab1, tmpvocab2, tmpvocab3, extra_flag)
        return isv

    valid = 0
    for example in tqdm.tqdm(train_set.examples):
        valid += _annotate(example, True)
    print('bukey', valid)

    test_set = Dataset.from_bin_file(args.test_file)
    for example in tqdm.tqdm(test_set.examples):
        _annotate(example, False)

    bertmodels = BertModel.from_pretrained('./pretrained_models/base-uncased/')
    tokenizer = BertTokenizer.from_pretrained('./pretrained_models/bert-base-uncased-vocab.txt')
    assert args.load_model

    print('load model from [%s]' % args.load_model, file=sys.stderr)
    params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_args.cuda = args.cuda
    # set the correct domain from saved arg
    args.lang = saved_args.lang

    parser_cls = Registrable.by_name(args.parser)
    parser = parser_cls.load(model_path=args.load_model, cuda=args.cuda, bert=bertmodels,
                             tmpv=tmpvocab, v1=tmpvocab1, v2=tmpvocab2, v3=tmpvocab3)
    parser.tokenizer = tokenizer
    parser.eval()
    evaluator = Registrable.by_name(args.evaluator)(transition_system, args=args)
    eval_results, decode_results = evaluation.evaluate(test_set.examples, parser, evaluator, args,
                                                       verbose=args.verbose, return_decode_result=True)
    print(eval_results, file=sys.stderr)
    if args.save_decode_to:
        # Context manager guarantees the pickle is flushed and the handle closed.
        with open(args.save_decode_to, 'wb') as decode_file:
            pickle.dump(decode_results, decode_file)
Example #9
0
def pl_debug(args):
    """Run ``pl_evaluate`` on the examples listed in ``debug_idx`` and dump details.

    Reads from ``args``: ``test_file``, ``load_model``, ``cuda``, ``parser``,
    ``evaluator`` and ``verbose``.  ``args.lang`` is overwritten with the
    language stored in the checkpoint.  Evaluation results go to stderr; a
    per-example report (source, target, scored predictions and, when
    available, two beams) is written to ``debug_info.txt``.

    Raises:
        AssertionError: if ``args.load_model`` is empty.
    """
    test_set = Dataset.from_bin_file(args.test_file)
    # Keep only the examples whose position is listed in the module-level debug_idx.
    test_set.examples = [
        example
        for position, example in enumerate(test_set.examples)
        if position in debug_idx
    ]
    assert args.load_model
    print('load model from [%s]' % args.load_model, file=sys.stderr)
    checkpoint = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    transition_system = checkpoint['transition_system']
    saved_args = checkpoint['args']
    saved_args.cuda = args.cuda
    # set the correct domain from saved arg
    args.lang = saved_args.lang

    parser = Registrable.by_name(args.parser).load(model_path=args.load_model, cuda=args.cuda)
    parser.eval()
    evaluator_cls = Registrable.by_name(args.evaluator)
    evaluator = evaluator_cls(transition_system, args=args)

    eval_results, decode_results, debug_info = evaluation.pl_evaluate(
        test_set.examples, parser, evaluator, args,
        verbose=args.verbose, return_decode_result=True, debug=True)
    print(eval_results, file=sys.stderr)

    def _by_descending_score(hyp):
        return -hyp.score

    def _write_beam(out, beam):
        """Write the beam's time step header and one line per hypothesis."""
        out.write("Beam {}:\n".format(beam[0].t))
        for hyp in beam:
            _, partial_ast = partial_asdl_ast_to_streg_ast(hyp.tree)
            out.write("\t{:.2f} {}\n".format(hyp.score, partial_ast.debug_form()))

    with open("debug_info.txt", "w") as report:
        rows = zip(debug_idx, test_set.examples, decode_results, debug_info)
        for idx, ex, pred_hyps, info in rows:
            predictions = [hyp.code.replace(" ", "") for hyp in pred_hyps]
            report.write("----------------{}------------------\n".format(idx))
            report.write("Src: {}\n".format(" ".join(ex.src_sent)))
            report.write("Tgt: {}\n".format(ex.tgt_code.replace(" ", "")))
            report.write("Predictions:\n")
            pred_results = eval_streg_predictions(predictions, ex)
            for pred, outcome in zip(predictions, pred_results):
                report.write("\t{} {}\n".format(outcome, pred.replace(" ", "")))

            if info is not None:
                # Both beams are sorted in place, best score first, before printing.
                prev_beam, latter_beam = info
                prev_beam.sort(key=_by_descending_score)
                latter_beam.sort(key=_by_descending_score)
                _write_beam(report, prev_beam)
                _write_beam(report, latter_beam)

            report.write("\n")
Example #10
0
    def read_ancestry_files(self, only_optimal_Ks=False):
        """Collect the ancestry tables of every dataset/K/panel run into one frame.

        For each combination whose ``DATASET_PANEL`` directory exists under
        ``ADMIXTURE_DIR``, the ``.Q`` ratios are read, labelled with the
        sample IDs from the matching ``.fam`` file, joined with the sample
        metadata, and indexed by ``dataset``, ``K`` and ``panel``.

        Args:
            only_optimal_Ks: when true, skip every K that differs from
                ``self.optimal_Ks()[dataset.label]``.

        Returns:
            The concatenation of all per-run frames.

        Raises:
            ValueError: from ``pd.concat`` when no run directory was found.
        """
        dataframes = []

        datasets = Dataset.all_datasets()
        Ks = self.available_Ks()
        panels = Panel.all_panels() + Panel.all_control_panels()

        for dataset, K, panel in product(datasets, Ks, panels):

            if only_optimal_Ks and self.optimal_Ks()[dataset.label] != K:
                continue

            # Results are sorted in directories named like DATASET_PANEL
            tag = "{}_{}".format(dataset.label, panel.label)
            basedir = join(ADMIXTURE_DIR, tag)

            if not isdir(basedir):
                continue

            # Read the .Q file for ratios of ancestry per sample.
            # Raw string: "\s" is not a valid escape in a normal literal.
            fname = "{}.{}.Q".format(tag, K)
            ancestries_df = pd.read_csv(join(basedir, fname), sep=r"\s+",
                                        names=list(range(K)))

            # Read the .fam file for the sample IDs (they're in the same order)
            fname = "{}.fam".format(tag)
            samples = pd.read_csv(join(basedir, fname), sep=r"\s+", index_col=0,
                                  usecols=[0], names=["sample"])
            ancestries_df.index = samples.index

            # Add population data to the sample IDs
            samples_df = ThousandGenomes().all_samples()
            ancestries_df = samples_df.join(ancestries_df).dropna()

            # Inference from sample origin is attempted only when at least
            # three superpopulations are present.
            continents_present = len(ancestries_df["superpopulation"].unique())
            if continents_present >= 3:
                self.infer_ancestral_components_from_samples_origin(ancestries_df)

            self.infer_ancestral_components_from_reference_pop(ancestries_df)

            # Arrange the hierarchical index
            ancestries_df.reset_index(inplace=True)
            ancestries_df["dataset"] = dataset.label
            ancestries_df["K"] = K
            ancestries_df["panel"] = panel.label
            ancestries_df.set_index(["dataset", "K", "panel"], inplace=True)

            dataframes.append(ancestries_df)

        return pd.concat(dataframes)
Example #11
0
def eval():
    """Compare the top decoded hypothesis of each example with its gold code.

    ``sys.argv[1]`` is the binary test dataset and ``sys.argv[2]`` a pickle of
    decode results.  Prints the running count and example index after each
    example, then the total of all equivalence results.
    """
    test_set = Dataset.from_bin_file(sys.argv[1])
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(sys.argv[2], "rb") as decode_file:
        decodes = pickle.load(decode_file)
    results = []
    acc = 0
    for i, (pred_hyps, gt_exs) in enumerate(zip(decodes, test_set)):
        # Only the first hypothesis is scored.
        top_pred = pred_hyps[0]
        gt_code = post_process(gt_exs.tgt_code)
        pred_code = post_process(top_pred.code)

        eq_res = check_equiv(pred_code, gt_code)
        results.append(eq_res)
        acc += eq_res
        print(acc, i)
    print(sum(results))
Example #12
0
def test(cfg: argparse.Namespace):
    """Evaluate a saved parser on the test set and optionally save its decodes.

    Reads from ``cfg``: ``test_file``, ``load_model``, ``cuda``, ``parser``,
    ``evaluator``, ``verbose`` and ``save_decode_to``.  ``cfg.lang`` is
    overwritten with the language stored in the checkpoint, and the
    checkpoint's training arguments are dumped to ``was_trained_with.txt``
    in the experiment directory returned by ``prologue``.
    """
    cli_logger.info("=== Testing ===")
    experiment_base_dir = prologue(cfg)

    test_set = Dataset.from_bin_file(cfg.test_file)
    cli_logger.info(f"Loaded test file from [{cfg.test_file}]")

    params = torch.load(cfg.load_model, map_location=lambda storage, loc: storage)
    cli_logger.info(f"Loaded model from [{cfg.load_model}]")

    transition_system = params['transition_system']
    cli_logger.info(f"Loaded transition system [{type(transition_system)}]")

    saved_args: argparse.Namespace = params['args']
    saved_args.cuda = cfg.cuda
    # FIXME ?? set the correct domain from saved arg
    cfg.lang = saved_args.lang

    dump_cfg(experiment_base_dir + "/was_trained_with.txt", cfg=saved_args.__dict__)

    parser_cls = Registrable.by_name(cfg.parser)
    parser = parser_cls.load(model_path=cfg.load_model, cuda=cfg.cuda)
    parser.eval()
    cli_logger.info(f"Loaded parser model [{cfg.parser}]")

    evaluator = Registrable.by_name(cfg.evaluator)(transition_system, args=cfg)
    cli_logger.info(f"Loaded evaluator [{cfg.evaluator}]")

    # Do the evaluation
    eval_results, decoded_results = evaluation.evaluate(
        examples=test_set.examples, model=parser, evaluator=evaluator, args=cfg, verbose=cfg.verbose, return_decoded_result=True
    )

    cli_logger.info(eval_results)

    if cfg.save_decode_to:
        # Context manager guarantees the pickle is flushed and the handle
        # closed before the "saved" message is logged.
        with open(cfg.save_decode_to, 'wb') as decode_file:
            pickle.dump(decoded_results, decode_file)
        cli_logger.info(f"Saved decoded results to [{cfg.save_decode_to}]")

    epilogue(cfg)
    cli_logger.info("=== Done ===")
Example #13
0
def train(args):
    """Maximum Likelihood Estimation

    Reads from ``args``: ``train_file``, ``vocab``, ``asdl_file``,
    ``transition_system``, ``parser``, ``cuda``, ``optimizer``, ``lr``,
    ``batch_size``, ``clip_grad``, ``log_every``, ``save_to`` and
    ``max_epoch``.  A checkpoint is saved after every epoch and the process
    exits with status 0 once ``max_epoch`` is reached.

    Raises:
        AttributeError: if ``args.optimizer`` does not name an attribute of
            ``torch.optim``.
    """

    # load in train/dev set
    train_set = Dataset.from_bin_file(args.train_file)

    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = Registrable.by_name(args.transition_system)(grammar)

    parser_cls = Registrable.by_name(args.parser)  # TODO: add arg
    model = parser_cls(args, vocab, transition_system)
    model.train()

    if args.cuda:
        model.cuda()

    # Look the optimizer class up by name instead of eval()-ing a string
    # built from a command-line argument.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    nn_utils.glorot_init(model.parameters())

    print('begin training, %d training examples' % len(train_set),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            # Examples without target actions are dropped from the batch.
            batch_examples = [e for e in batch_examples if len(e.tgt_actions)]
            train_iter += 1
            optimizer.zero_grad()

            ret_val = model.score(batch_examples)
            loss = -ret_val

            # The summed loss feeds the running report; the mean is optimized.
            loss_val = torch.sum(loss).data
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] loss=%.5f' % (train_iter, report_loss /
                                                   report_examples)

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        model_file = args.save_to + '.iter%d.bin' % train_iter
        print('save model to [%s]' % model_file, file=sys.stderr)
        model.save(model_file)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)
Example #14
0
def train_semi_jae(args):
    """Semi-supervised training of a JAE built from a saved encoder and decoder.

    Each iteration combines a supervised loss on a labeled batch with an
    unsupervised loss on an unlabeled batch, weighted by
    ``args.unsup_loss_weight``.  After every epoch the encoder is evaluated
    on the dev set; the best model and the optimizer state are saved under
    ``args.save_to``.  When ``args.patience`` epochs pass without
    improvement, the best checkpoint is reloaded and the learning rate is
    multiplied by ``args.lr_decay``.

    Reads from ``args``: ``bi_direction``, ``load_model``, ``load_decoder``,
    ``load_prior``, ``load_src_lm``, ``cache``, ``cuda``, ``train_file``,
    ``unlabeled_file``, ``dev_file``, ``lr``, ``batch_size``,
    ``unsup_batch_size``, ``decode_max_time_step``, ``moves``,
    ``unsup_loss_weight``, ``clip_grad``, ``log_every``, ``save_to``,
    ``max_epoch``, ``patience``, ``max_num_trial``, ``lr_decay`` and
    ``reset_optimizer``.

    The function never returns normally: it exits with status 0 on reaching
    ``max_epoch`` or on early stopping.
    """

    bi_direction = args.bi_direction

    encoder_params = torch.load(args.load_model,
                                map_location=lambda storage, loc: storage)
    decoder_params = torch.load(args.load_decoder,
                                map_location=lambda storage, loc: storage)

    print('loaded encoder at %s' % args.load_model, file=sys.stderr)
    print('loaded decoder at %s' % args.load_decoder, file=sys.stderr)

    transition_system = encoder_params['transition_system']
    encoder_params['args'].cuda = decoder_params['args'].cuda = args.cuda

    encoder = Parser(encoder_params['args'], encoder_params['vocab'],
                     transition_system)
    encoder.load_state_dict(encoder_params['state_dict'])
    decoder = Reconstructor(decoder_params['args'], decoder_params['vocab'],
                            transition_system)
    decoder.load_state_dict(decoder_params['state_dict'])

    zprior = LSTMPrior.load(args.load_prior,
                            transition_system=transition_system,
                            cuda=args.cuda)
    print('loaded p(z) prior at %s' % args.load_prior, file=sys.stderr)
    # freeze prior parameters
    for p in zprior.parameters():
        p.requires_grad = False
    zprior.eval()
    xprior = LSTMLanguageModel.load(args.load_src_lm)
    print('loaded p(x) prior at %s' % args.load_src_lm, file=sys.stderr)
    xprior.eval()

    if args.cache:
        jae = JAE_cache(encoder, decoder, zprior, xprior, args)
    else:
        jae = JAE(encoder, decoder, zprior, xprior, args)

    jae.train()
    encoder.train()
    decoder.train()
    if args.cuda:
        jae.cuda()

    labeled_data = Dataset.from_bin_file(args.train_file)
    unlabeled_data = Dataset.from_bin_file(
        args.unlabeled_file)  # pretend they are un-labeled!
    dev_set = Dataset.from_bin_file(args.dev_file)

    # Only parameters that still require gradients are optimized (the p(z)
    # prior was frozen above).
    optimizer = torch.optim.Adam(
        [p for p in jae.parameters() if p.requires_grad], lr=args.lr)

    print(
        '*** begin semi-supervised training %d labeled examples, %d unlabeled examples ***'
        % (len(labeled_data), len(unlabeled_data)),
        file=sys.stderr)
    report_encoder_loss = report_decoder_loss = report_examples = 0.
    # NOTE(review): report_unsup_baseline_loss is never accumulated below, so
    # the logged "baseline loss" is always 0.
    report_unsup_examples = report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = 0.
    patience = 0
    num_trial = 1
    epoch = train_iter = 0
    history_dev_scores = []
    while True:
        epoch += 1
        epoch_begin = time.time()
        unlabeled_examples_iter = unlabeled_data.batch_iter(
            batch_size=args.unsup_batch_size, shuffle=True)
        for labeled_examples in labeled_data.batch_iter(
                batch_size=args.batch_size, shuffle=True):
            labeled_examples = [
                e for e in labeled_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]

            train_iter += 1

            optimizer.zero_grad()

            report_examples += len(labeled_examples)

            sup_encoder_loss = -encoder.score(labeled_examples)
            sup_decoder_loss = -decoder.score(labeled_examples)

            report_encoder_loss += sup_encoder_loss.sum().data[0]
            report_decoder_loss += sup_decoder_loss.sum().data[0]

            sup_encoder_loss = torch.mean(sup_encoder_loss)
            sup_decoder_loss = torch.mean(sup_decoder_loss)

            sup_loss = sup_encoder_loss + sup_decoder_loss

            # compute unsupervised loss

            try:
                unlabeled_examples = next(unlabeled_examples_iter)
            except StopIteration:
                # if finished unlabeled data stream, restart it with the same
                # batch size the stream was opened with
                unlabeled_examples_iter = unlabeled_data.batch_iter(
                    batch_size=args.unsup_batch_size, shuffle=True)
                unlabeled_examples = next(unlabeled_examples_iter)
            # The length filter applies to every unlabeled batch, not only to
            # the first batch after a restart.
            unlabeled_examples = [
                e for e in unlabeled_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]

            unsup_encoder_loss, unsup_decoder_loss, meta_data = jae.get_unsupervised_loss(
                unlabeled_examples, args.moves)
            if bi_direction:
                unsup_encoder_loss_back, unsup_decoder_loss_back, meta_data_back = jae.get_unsupervised_loss_backward(
                    unlabeled_examples, args.moves)

            # Skip the whole update when any loss term contains NaN.
            nan = False
            if nn_utils.isnan(sup_loss.data):
                print('Nan in sup_loss', file=sys.stderr)
                nan = True
            if nn_utils.isnan(unsup_encoder_loss.data):
                print('Nan in unsup_encoder_loss!', file=sys.stderr)
                nan = True
            if nn_utils.isnan(unsup_decoder_loss.data):
                print('Nan in unsup_decoder_loss!', file=sys.stderr)
                nan = True
            if bi_direction:
                if nn_utils.isnan(unsup_encoder_loss_back.data):
                    print('Nan in unsup_encoder_loss_back!', file=sys.stderr)
                    nan = True
                if nn_utils.isnan(unsup_decoder_loss_back.data):
                    print('Nan in unsup_decoder_loss_back!', file=sys.stderr)
                    nan = True

            if nan:
                continue
            if bi_direction:
                report_unsup_encoder_loss += (
                    unsup_encoder_loss.sum().data[0] +
                    unsup_encoder_loss_back.sum().data[0])
                report_unsup_decoder_loss += (
                    unsup_decoder_loss.sum().data[0] +
                    unsup_decoder_loss_back.sum().data[0])
            else:
                report_unsup_encoder_loss += unsup_encoder_loss.sum().data[0]
                report_unsup_decoder_loss += unsup_decoder_loss.sum().data[0]
            report_unsup_examples += unsup_encoder_loss.size(0)

            if bi_direction:
                unsup_loss = torch.mean(unsup_encoder_loss) + torch.mean(
                    unsup_decoder_loss) + torch.mean(
                        unsup_encoder_loss_back) + torch.mean(
                            unsup_decoder_loss_back)
            else:
                unsup_loss = torch.mean(unsup_encoder_loss) + torch.mean(
                    unsup_decoder_loss)
            loss = sup_loss + args.unsup_loss_weight * unsup_loss

            loss.backward()
            # clip gradient
            torch.nn.utils.clip_grad_norm(jae.parameters(), args.clip_grad)
            optimizer.step()
            if train_iter % args.log_every == 0:
                print(
                    '[Iter %d] supervised: encoder loss=%.5f, decoder loss=%.5f'
                    % (train_iter, report_encoder_loss / report_examples,
                       report_decoder_loss / report_examples),
                    file=sys.stderr)

                print(
                    '[Iter %d] unsupervised: encoder loss=%.5f, decoder loss=%.5f, baseline loss=%.5f'
                    % (train_iter,
                       report_unsup_encoder_loss / report_unsup_examples,
                       report_unsup_decoder_loss / report_unsup_examples,
                       report_unsup_baseline_loss / report_unsup_examples),
                    file=sys.stderr)

                samples = meta_data['samples']
                for v in meta_data.values():
                    if isinstance(v, Variable): v.cpu()
                # Only the first sample of the batch is printed.
                for i, sample in enumerate(samples[:1]):
                    print('\t[%s] Source: %s' %
                          (sample.idx, ' '.join(sample.src_sent)),
                          file=sys.stderr)
                    print('\t[%s] Code: \n%s' % (sample.idx, sample.tgt_code),
                          file=sys.stderr)
                    # The part of sample.idx before '-' is the index of the
                    # example the sample was drawn for.
                    ref_example = [
                        e for e in unlabeled_examples
                        if e.idx == int(sample.idx[:sample.idx.index('-')])
                    ][0]
                    print('\t[%s] Gold Code: \n%s' %
                          (sample.idx, ref_example.tgt_code),
                          file=sys.stderr)
                    print(
                        '\t[%s] Log p(z|x): %f' %
                        (sample.idx, meta_data['encoding_scores'][i].data[0]),
                        file=sys.stderr)
                    print('\t[%s] Log p(x|z): %f' %
                          (sample.idx,
                           meta_data['reconstruction_scores'][i].data[0]),
                          file=sys.stderr)
                    print('\t[%s] Encoder Loss: %f' %
                          (sample.idx, unsup_encoder_loss[i].data[0]),
                          file=sys.stderr)
                    print('\t**************************', file=sys.stderr)

                report_encoder_loss = report_decoder_loss = report_examples = 0.
                report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = report_unsup_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)
        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)

        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples,
                                           encoder,
                                           args,
                                           verbose=True)
        encoder.train()
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' %
              (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(
            history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            jae.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')

        # Checked independently of is_better: otherwise an improving final
        # epoch would skip the stop and training would never terminate.
        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if not is_better and patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load best model's parameters
            jae.load_parameters(args.save_to + '.bin')
            if args.cuda:
                jae = jae.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset to a new infer_optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(
                    [p for p in jae.parameters() if p.requires_grad], lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #15
0
def train_semi(args):
    """Semi-supervised training of a StructVAE on labeled plus unlabeled data.

    A saved ``Parser`` (encoder, ``args.load_model``) and ``Reconstructor``
    (decoder, ``args.load_decoder``) are loaded and wrapped, together with a
    prior, in the StructVAE variant selected by ``args.baseline``.  Every
    training step adds the supervised encoder and decoder losses of one
    labeled batch to ``args.unsup_loss_weight`` times the unsupervised loss
    of one unlabeled batch.

    After each epoch the encoder is evaluated on the dev set.  When the dev
    ``accuracy`` improves, the model is saved to ``args.save_to + '.bin'`` and
    the optimizer state to ``args.save_to + '.optim.bin'``.  After
    ``args.patience`` non-improving epochs in a row, the best checkpoint is
    reloaded and the learning rate is multiplied by ``args.lr_decay``.

    The function never returns normally: it terminates the process with
    ``exit(0)`` once ``args.max_epoch`` is reached or ``args.max_num_trial``
    trials have been used.

    Args:
        args: option namespace.  Fields read here include ``load_model``,
            ``load_decoder``, ``load_prior``, ``load_src_lm``, ``prior``,
            ``baseline``, ``cuda``, ``train_file``, ``unlabeled_file``,
            ``dev_file``, ``lr``, ``batch_size``, ``unsup_batch_size``,
            ``decode_max_time_step``, ``unsup_loss_weight``, ``clip_grad``,
            ``log_every``, ``save_to``, ``max_epoch``, ``patience``,
            ``max_num_trial``, ``lr_decay`` and ``reset_optimizer``.

    Raises:
        ValueError: if ``args.baseline`` is not ``'mlp'``, ``'src_lm'`` or
            ``'src_lm_and_linear'``.
    """
    # always deserialize onto CPU first; the model is moved to GPU below if requested
    encoder_params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    decoder_params = torch.load(args.load_decoder, map_location=lambda storage, loc: storage)

    print('loaded encoder at %s' % args.load_model, file=sys.stderr)
    print('loaded decoder at %s' % args.load_decoder, file=sys.stderr)

    transition_system = encoder_params['transition_system']
    # the saved arguments carry the device setting of the original run; override it
    encoder_params['args'].cuda = decoder_params['args'].cuda = args.cuda

    encoder = Parser(encoder_params['args'], encoder_params['vocab'], transition_system)
    encoder.load_state_dict(encoder_params['state_dict'])
    decoder = Reconstructor(decoder_params['args'], decoder_params['vocab'], transition_system)
    decoder.load_state_dict(decoder_params['state_dict'])

    if args.prior == 'lstm':
        prior = LSTMPrior.load(args.load_prior, transition_system=transition_system, cuda=args.cuda)
        print('loaded prior at %s' % args.load_prior, file=sys.stderr)
        # freeze prior parameters
        for p in prior.parameters():
            p.requires_grad = False
        prior.eval()
    else:
        prior = UniformPrior()

    if args.baseline == 'mlp':
        structVAE = StructVAE(encoder, decoder, prior, args)
    elif args.baseline == 'src_lm' or args.baseline == 'src_lm_and_linear':
        src_lm = LSTMLanguageModel.load(args.load_src_lm)
        print('loaded source LM at %s' % args.load_src_lm, file=sys.stderr)
        vae_cls = StructVAE_LMBaseline if args.baseline == 'src_lm' else StructVAE_SrcLmAndLinearBaseline
        structVAE = vae_cls(encoder, decoder, prior, src_lm, args)
    else:
        raise ValueError('unknown baseline')

    structVAE.train()
    if args.cuda: structVAE.cuda()

    labeled_data = Dataset.from_bin_file(args.train_file)
    # labeled_data.examples = labeled_data.examples[:10]
    unlabeled_data = Dataset.from_bin_file(args.unlabeled_file)   # pretend they are un-labeled!
    dev_set = Dataset.from_bin_file(args.dev_file)
    # dev_set.examples = dev_set.examples[:10]

    # only parameters that still require gradients are optimized (the LSTM prior is frozen above)
    optimizer = torch.optim.Adam(ifilter(lambda p: p.requires_grad, structVAE.parameters()), lr=args.lr)

    print('*** begin semi-supervised training %d labeled examples, %d unlabeled examples ***' %
          (len(labeled_data), len(unlabeled_data)), file=sys.stderr)
    # running sums for the periodic log line; reset after every report
    report_encoder_loss = report_decoder_loss = report_examples = 0.
    report_unsup_examples = report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = 0.
    patience = 0
    num_trial = 1
    epoch = train_iter = 0
    history_dev_scores = []
    while True:
        epoch += 1
        epoch_begin = time.time()
        unlabeled_examples_iter = unlabeled_data.batch_iter(batch_size=args.unsup_batch_size, shuffle=True)

        # one epoch == one pass over the labeled data; unlabeled batches are drawn alongside
        for labeled_examples in labeled_data.batch_iter(batch_size=args.batch_size, shuffle=True):
            labeled_examples = [e for e in labeled_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            train_iter += 1
            optimizer.zero_grad()
            report_examples += len(labeled_examples)

            sup_encoder_loss = -encoder.score(labeled_examples)
            sup_decoder_loss = -decoder.score(labeled_examples)

            report_encoder_loss += sup_encoder_loss.sum().data[0]
            report_decoder_loss += sup_decoder_loss.sum().data[0]

            sup_encoder_loss = torch.mean(sup_encoder_loss)
            sup_decoder_loss = torch.mean(sup_decoder_loss)

            sup_loss = sup_encoder_loss + sup_decoder_loss

            # compute unsupervised loss
            try:
                unlabeled_examples = next(unlabeled_examples_iter)
            except StopIteration:
                # if finished unlabeled data stream, restart it
                # (with the same unsupervised batch size the stream was created with)
                unlabeled_examples_iter = unlabeled_data.batch_iter(batch_size=args.unsup_batch_size, shuffle=True)
                unlabeled_examples = next(unlabeled_examples_iter)
            # apply the length filter to every unlabeled batch, not only to the
            # first batch drawn after a restart
            unlabeled_examples = [e for e in unlabeled_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            try:
                unsup_encoder_loss, unsup_decoder_loss, unsup_baseline_loss, meta_data = structVAE.get_unsupervised_loss(
                    unlabeled_examples)

                nan = False
                if nn_utils.isnan(sup_loss.data):
                    print('Nan in sup_loss', file=sys.stderr)
                    nan = True
                if nn_utils.isnan(unsup_encoder_loss.data):
                    print('Nan in unsup_encoder_loss!', file=sys.stderr)
                    nan = True
                if nn_utils.isnan(unsup_decoder_loss.data):
                    print('Nan in unsup_decoder_loss!', file=sys.stderr)
                    nan = True
                if nn_utils.isnan(unsup_baseline_loss.data):
                    print('Nan in unsup_baseline_loss!', file=sys.stderr)
                    nan = True

                if nan:
                    # skip the whole update for this step
                    # torch.save((unsup_encoder_loss, unsup_decoder_loss, unsup_baseline_loss, meta_data), 'nan_data.bin')
                    continue

                report_unsup_encoder_loss += unsup_encoder_loss.sum().data[0]
                report_unsup_decoder_loss += unsup_decoder_loss.sum().data[0]
                report_unsup_baseline_loss += unsup_baseline_loss.sum().data[0]
                report_unsup_examples += unsup_encoder_loss.size(0)
            except ValueError as e:
                # `e.message` does not exist on Python 3; printing the exception works everywhere
                print(e, file=sys.stderr)
                continue
            # except Exception as e:
            #     print('********** Error **********', file=sys.stderr)
            #     print('batch labeled examples: ', file=sys.stderr)
            #     for example in labeled_examples:
            #         print('%s %s' % (example.idx, ' '.join(example.src_sent)), file=sys.stderr)
            #     print('batch unlabeled examples: ', file=sys.stderr)
            #     for example in unlabeled_examples:
            #         print('%s %s' % (example.idx, ' '.join(example.src_sent)), file=sys.stderr)
            #     print(e.message, file=sys.stderr)
            #     traceback.print_exc(file=sys.stderr)
            #     for k, v in meta_data.iteritems():
            #         print('%s: %s' % (k, v), file=sys.stderr)
            #     print('********** Error **********', file=sys.stderr)
            #     continue

            unsup_loss = torch.mean(unsup_encoder_loss) + torch.mean(unsup_decoder_loss) + torch.mean(unsup_baseline_loss)

            loss = sup_loss + args.unsup_loss_weight * unsup_loss

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm(structVAE.parameters(), args.clip_grad)
            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] supervised: encoder loss=%.5f, decoder loss=%.5f' %
                      (train_iter,
                       report_encoder_loss / report_examples,
                       report_decoder_loss / report_examples),
                      file=sys.stderr)

                print('[Iter %d] unsupervised: encoder loss=%.5f, decoder loss=%.5f, baseline loss=%.5f' %
                      (train_iter,
                       report_unsup_encoder_loss / report_unsup_examples,
                       report_unsup_decoder_loss / report_unsup_examples,
                       report_unsup_baseline_loss / report_unsup_examples),
                      file=sys.stderr)

                # print('[Iter %d] unsupervised: baseline=%.5f, raw learning signal=%.5f, learning signal=%.5f' % (train_iter,
                #                                                                        meta_data['baseline'].mean().data[0],
                #                                                                        meta_data['raw_learning_signal'].mean().data[0],
                #                                                                        meta_data['learning_signal'].mean().data[0]), file=sys.stderr)

                if isinstance(structVAE, StructVAE_LMBaseline):
                    print('[Iter %d] baseline: source LM b_lm_weight: %.3f, b: %.3f' % (train_iter,
                                                                                        structVAE.b_lm_weight.data[0],
                                                                                        structVAE.b.data[0]),
                          file=sys.stderr)

                samples = meta_data['samples']
                # NOTE(review): the result of `v.cpu()` is discarded, so this loop
                # does not replace the entries of `meta_data` -- confirm intent
                for v in meta_data.itervalues():
                    if isinstance(v, Variable): v.cpu()
                # dump details for at most the first 15 samples of the current unlabeled batch
                for i, sample in enumerate(samples[:15]):
                    print('\t[%s] Source: %s' % (sample.idx, ' '.join(sample.src_sent)), file=sys.stderr)
                    print('\t[%s] Code: \n%s' % (sample.idx, sample.tgt_code), file=sys.stderr)
                    # the part of `sample.idx` before the first '-' is the idx of the source example
                    ref_example = [e for e in unlabeled_examples if e.idx == int(sample.idx[:sample.idx.index('-')])][0]
                    print('\t[%s] Gold Code: \n%s' % (sample.idx, ref_example.tgt_code), file=sys.stderr)
                    print('\t[%s] Log p(z|x): %f' % (sample.idx, meta_data['encoding_scores'][i].data[0]), file=sys.stderr)
                    print('\t[%s] Log p(x|z): %f' % (sample.idx, meta_data['reconstruction_scores'][i].data[0]), file=sys.stderr)
                    print('\t[%s] KL term: %f' % (sample.idx, meta_data['kl_term'][i].data[0]), file=sys.stderr)
                    print('\t[%s] Prior: %f' % (sample.idx, meta_data['prior'][i].data[0]), file=sys.stderr)
                    print('\t[%s] baseline: %f' % (sample.idx, meta_data['baseline'][i].data[0]), file=sys.stderr)
                    print('\t[%s] Raw Learning Signal: %f' % (sample.idx, meta_data['raw_learning_signal'][i].data[0]), file=sys.stderr)
                    print('\t[%s] Learning Signal - baseline: %f' % (sample.idx, meta_data['learning_signal'][i].data[0]), file=sys.stderr)
                    print('\t[%s] Encoder Loss: %f' % (sample.idx, unsup_encoder_loss[i].data[0]), file=sys.stderr)
                    print('\t**************************', file=sys.stderr)

                report_encoder_loss = report_decoder_loss = report_examples = 0.
                report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = report_unsup_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)
        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)

        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples, encoder, args, verbose=True)
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' % (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(history_dev_scores)
        history_dev_scores.append(dev_acc)

        # model_file = args.save_to + '.iter%d.bin' % train_iter
        # print('save model to [%s]' % model_file, file=sys.stderr)
        # structVAE.save(model_file)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            structVAE.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')

        # checked independently of `is_better`: otherwise training would run
        # past `max_epoch` whenever the final epoch happens to be the best one
        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)

        if not is_better and patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

            # load best model's parameters
            structVAE.load_parameters(args.save_to + '.bin')
            if args.cuda: structVAE = structVAE.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset to a new infer_optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(ifilter(lambda p: p.requires_grad, structVAE.parameters()), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

            # set new lr (the restored optimizer state carries the old learning rate)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #16
0
def log_semi(args):
    """Run a trained StructVAE over unlabeled data and pickle per-sample statistics.

    The VAE checkpoint at ``args.load_model`` is loaded together with its
    encoder and decoder checkpoints, which are expected next to it under the
    names ``<stem>.encoder<ext>`` and ``<stem>.decoder<ext>``.  The model is
    rebuilt from the arguments stored in the VAE checkpoint, then
    ``get_unsupervised_loss`` is called on every batch of
    ``args.unlabeled_file``.  One dict per returned sample (scores, KL term,
    prior, baseline, learning signals and losses) is collected and the whole
    list is pickled to ``args.save_to``.

    Args:
        args: option namespace.  Fields read here: ``load_model``, ``cuda``,
            ``unlabeled_file``, ``batch_size``, ``decode_max_time_step`` and
            ``save_to``.  Prior and baseline settings come from the arguments
            saved inside the VAE checkpoint, not from ``args``.

    Raises:
        ValueError: if the saved ``baseline`` setting is not ``'mlp'``,
            ``'src_lm'`` or ``'src_lm_and_linear'``.
    """
    print('loading VAE at %s' % args.load_model, file=sys.stderr)
    fname, ext = os.path.splitext(args.load_model)
    encoder_path = fname + '.encoder' + ext
    decoder_path = fname + '.decoder' + ext

    # always deserialize onto CPU first; the model is moved to GPU below if requested
    vae_params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    encoder_params = torch.load(encoder_path, map_location=lambda storage, loc: storage)
    decoder_params = torch.load(decoder_path, map_location=lambda storage, loc: storage)

    transition_system = encoder_params['transition_system']
    # the saved arguments carry the device setting of the original run; override it
    vae_params['args'].cuda = encoder_params['args'].cuda = decoder_params['args'].cuda = args.cuda
    vae_args = vae_params['args']

    encoder = Parser(encoder_params['args'], encoder_params['vocab'], transition_system)
    decoder = Reconstructor(decoder_params['args'], decoder_params['vocab'], transition_system)

    if vae_args.prior == 'lstm':
        prior = LSTMPrior.load(vae_args.load_prior, transition_system=decoder_params['transition_system'], cuda=args.cuda)
        print('loaded prior at %s' % vae_args.load_prior, file=sys.stderr)
        # freeze prior parameters
        for p in prior.parameters():
            p.requires_grad = False
        prior.eval()
    else:
        prior = UniformPrior()

    if vae_args.baseline == 'mlp':
        structVAE = StructVAE(encoder, decoder, prior, vae_args)
    elif vae_args.baseline == 'src_lm' or vae_args.baseline == 'src_lm_and_linear':
        src_lm = LSTMLanguageModel.load(vae_args.load_src_lm)
        print('loaded source LM at %s' % vae_args.load_src_lm, file=sys.stderr)
        # pick the class from the *saved* baseline setting, like the branch
        # conditions above; the runtime `args` need not carry the same value
        Baseline = StructVAE_LMBaseline if vae_args.baseline == 'src_lm' else StructVAE_SrcLmAndLinearBaseline
        structVAE = Baseline(encoder, decoder, prior, src_lm, vae_args)
    else:
        raise ValueError('unknown baseline')

    # encoder/decoder weights are not loaded individually above; the VAE checkpoint is loaded here
    structVAE.load_parameters(args.load_model)
    structVAE.train()
    if args.cuda: structVAE.cuda()

    unlabeled_data = Dataset.from_bin_file(args.unlabeled_file)  # pretend they are un-labeled!

    print('*** begin sampling ***', file=sys.stderr)
    start_time = time.time()
    log_entries = []
    for unlabeled_examples in unlabeled_data.batch_iter(batch_size=args.batch_size, shuffle=False):
        unlabeled_examples = [e for e in unlabeled_examples if len(e.tgt_actions) <= args.decode_max_time_step]

        try:
            unsup_encoder_loss, unsup_decoder_loss, unsup_baseline_loss, meta_data = structVAE.get_unsupervised_loss(
                unlabeled_examples)

        except ValueError as e:
            # `e.message` does not exist on Python 3; printing the exception works everywhere
            print(e, file=sys.stderr)
            continue

        samples = meta_data['samples']
        # NOTE(review): the result of `v.cpu()` is discarded, so this loop does
        # not replace the entries of `meta_data` -- confirm intent
        for v in meta_data.itervalues():
            if isinstance(v, Variable): v.cpu()

        for i, sample in enumerate(samples):
            # the part of `sample.idx` before the first '-' is the idx of the source example
            ref_example = [e for e in unlabeled_examples if e.idx == int(sample.idx[:sample.idx.index('-')])][0]
            log_entry = {
                'sample': sample,
                'ref_example': ref_example,
                'log_p_z_x': meta_data['encoding_scores'][i].data[0],
                'log_p_x_z': meta_data['reconstruction_scores'][i].data[0],
                'kl': meta_data['kl_term'][i].data[0],
                'prior': meta_data['prior'][i].data[0],
                'baseline': meta_data['baseline'][i].data[0],
                'learning_signal': meta_data['raw_learning_signal'][i].data[0],
                'learning_signal - baseline': meta_data['learning_signal'][i].data[0],
                'encoder_loss': unsup_encoder_loss[i].data[0],
                'decoder_loss': unsup_decoder_loss[i].data[0]
            }

            log_entries.append(log_entry)

    print('done! took %d s' % (time.time() - start_time), file=sys.stderr)
    # close the output file deterministically so the pickle is fully flushed
    with open(args.save_to, 'wb') as log_file:
        pkl.dump(log_entries, log_file)
Example #17
0
def train(args):
    """Maximum Likelihood Estimation

    Trains the parser class registered under ``args.parser`` on
    ``args.train_file``, either from scratch or starting from the checkpoint
    ``args.pretrain``.  When ``args.dev_file`` is given, the model is
    evaluated every ``args.valid_every_epoch`` epochs and is saved to
    ``args.save_to + '.bin'`` (optimizer state to ``'.optim.bin'``) whenever
    the evaluator's default metric improves; without a dev file every epoch
    counts as an improvement.  After ``args.patience`` non-improving epochs
    the best checkpoint is reloaded with a decayed learning rate.

    The function never returns normally: it terminates the process with
    ``exit(0)`` at ``args.max_epoch``, after ``args.max_num_trial`` trials, or
    at the first exhausted patience when a warmup/annealing schedule is used.

    Args:
        args: option namespace holding data paths, model/optimizer choices
            and the training schedule (see the attributes read below).
    """

    # load in train/dev set
    train_set = Dataset.from_bin_file(args.train_file)

    if args.dev_file:
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])

    # context managers make sure the file handles are closed right away
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = Registrable.by_name(args.transition_system)(grammar)

    parser_cls = Registrable.by_name(args.parser)  # TODO: add arg
    if args.pretrain:
        print('Finetune with: ', args.pretrain, file=sys.stderr)
        model = parser_cls.load(model_path=args.pretrain, cuda=args.cuda)
    else:
        model = parser_cls(args, vocab, transition_system)

    model.train()
    evaluator = Registrable.by_name(args.evaluator)(transition_system,
                                                    args=args)
    if args.cuda: model.cuda()

    # parameters are split by name: those containing 'automodel' versus the rest
    trainable_parameters = [
        p for n, p in model.named_parameters()
        if 'automodel' not in n and p.requires_grad
    ]
    bert_parameters = [
        p for n, p in model.named_parameters()
        if 'automodel' in n and p.requires_grad
    ]

    # The 'automodel' parameters are optimized and clipped only when
    # `finetune_bert` is set.  The initial optimizer used to be built with the
    # opposite polarity of the clipping and optimizer-reset code further down.
    if args.finetune_bert:
        optim_parameters = trainable_parameters + bert_parameters
    else:
        optim_parameters = trainable_parameters

    # attribute lookup instead of eval(): same result for a plain optimizer
    # name, without executing an arbitrary command-line string
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(optim_parameters, lr=args.lr)

    if not args.pretrain:
        if args.uniform_init:
            print('uniformly initialize parameters [-%f, +%f]' %
                  (args.uniform_init, args.uniform_init),
                  file=sys.stderr)
            nn_utils.uniform_init(-args.uniform_init, args.uniform_init,
                                  trainable_parameters)
        elif args.glorot_init:
            print('use glorot initialization', file=sys.stderr)
            nn_utils.glorot_init(trainable_parameters)

        # load pre-trained word embedding (optional)
        if args.glove_embed_path:
            print('load glove embedding from: %s' % args.glove_embed_path,
                  file=sys.stderr)
            glove_embedding = GloveHelper(args.glove_embed_path)
            glove_embedding.load_to(model.src_embed, vocab.source)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    history_dev_scores = []
    num_trial = patience = 0

    # a warmup/annealing schedule is used only with a consistent step configuration
    use_lr_scheduler = (args.warmup_step > 0
                        and args.annealing_step > args.warmup_step)
    if use_lr_scheduler:
        lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                       args.warmup_step,
                                                       args.annealing_step)

    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]
            train_iter += 1
            optimizer.zero_grad()

            # ret_val[0] is negated into the per-example loss; ret_val[1] is
            # read only for supervised attention
            ret_val = model.score(batch_examples)
            loss = -ret_val[0]

            # print(loss.data)
            loss_val = torch.sum(loss).data.item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            if args.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    # `.item()` (as for `loss_val` above) instead of `.data[0]`,
                    # which fails on 0-dim tensors in current PyTorch
                    sup_att_loss_val = sup_att_loss.data.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient (over exactly the parameters the optimizer updates)
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(optim_parameters,
                                               args.clip_grad)

            optimizer.step()

            # warmup and annealing
            if use_lr_scheduler:
                lr_scheduler.step()

            if train_iter % args.log_every == 0:
                lr = optimizer.param_groups[0]['lr']
                log_str = '[Iter %d] encoder loss=%.5f, lr=%.6f' % (
                    train_iter, report_loss / report_examples, lr)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (
                        report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation
        is_better = False
        if args.dev_file:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(
                    dev_set.examples,
                    model,
                    evaluator,
                    args,
                    verbose=False,
                    eval_top_pred_only=args.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]

                print(
                    '[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)'
                    % (epoch, eval_results, evaluator.default_metric,
                       dev_score, time.time() - eval_start),
                    file=sys.stderr)

                is_better = history_dev_scores == [] or dev_score > max(
                    history_dev_scores)
                history_dev_scores.append(dev_score)
        else:
            # without a dev set every epoch overwrites the saved model
            is_better = True

        if args.decay_lr_every_epoch and epoch > args.lr_decay_after_epoch:
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('decay learning rate to %f' % lr, file=sys.stderr)

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience and epoch >= args.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            # with a warmup/annealing schedule the first exhausted patience ends training
            if num_trial == args.max_num_trial or use_lr_scheduler:
                print('early stop!', file=sys.stderr)
                exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda: model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # NOTE(review): the reset always builds Adam, regardless of
                # `args.optimizer` -- confirm this is intended
                optimizer = torch.optim.Adam(optim_parameters, lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr (the restored optimizer state carries the old learning rate)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #18
0
def self_training(args):
    """Perform self-training

    First load decoding results on disjoint data
    also load pre-trained model and perform supervised
    training on both existing training data and the
    decoded results

    The first hypothesis of every unlabeled example that has at least one
    hypothesis is turned into a pseudo-labeled ``Example``.  The model is
    then trained on the labeled data plus these examples.  After each epoch
    it is evaluated on the dev set; the best model (by ``accuracy``) is saved
    to ``args.save_to + '.bin'`` and the optimizer state to ``'.optim.bin'``.
    After ``args.patience`` non-improving epochs in a row the best checkpoint
    is reloaded with a decayed learning rate.

    The function never returns normally: it terminates the process with
    ``exit(0)`` at ``args.max_epoch`` or after ``args.max_num_trial`` trials.

    Args:
        args: option namespace.  Only ``load_model``, ``cuda``, ``save_to``,
            ``train_file``, ``unlabeled_file``, ``dev_file`` and
            ``load_decode_results`` are taken from it; all other settings
            come from the arguments stored in the loaded checkpoint.
    """

    print('load pre-trained model from [%s]' % args.load_model, file=sys.stderr)
    params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    vocab = params['vocab']
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_state = params['state_dict']

    # transfer arguments
    saved_args.cuda = args.cuda
    saved_args.save_to = args.save_to
    saved_args.train_file = args.train_file
    saved_args.unlabeled_file = args.unlabeled_file
    saved_args.dev_file = args.dev_file
    saved_args.load_decode_results = args.load_decode_results
    # from here on `args` is the checkpoint's namespace with the fields above overridden
    args = saved_args

    update_args(args)

    model = Parser(saved_args, vocab, transition_system)
    model.load_state_dict(saved_state)

    if args.cuda: model = model.cuda()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print('load unlabeled data [%s]' % args.unlabeled_file, file=sys.stderr)
    unlabeled_data = Dataset.from_bin_file(args.unlabeled_file)

    print('load decoding results of unlabeled data [%s]' % args.load_decode_results, file=sys.stderr)
    # pickles must be read in binary mode; the context manager closes the file
    with open(args.load_decode_results, 'rb') as decode_file:
        decode_results = pickle.load(decode_file)

    labeled_data = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)

    print('Num. examples in unlabeled data: %d' % len(unlabeled_data), file=sys.stderr)
    # decode results are paired with unlabeled examples by position
    assert len(unlabeled_data) == len(decode_results)
    self_train_examples = []
    for example, hyps in zip(unlabeled_data, decode_results):
        if hyps:
            # use the first hypothesis as the pseudo-label
            hyp = hyps[0]
            sampled_example = Example(idx='self_train-%s' % example.idx,
                                      src_sent=example.src_sent,
                                      tgt_code=hyp.code,
                                      tgt_actions=hyp.action_infos,
                                      tgt_ast=hyp.tree)
            self_train_examples.append(sampled_example)
    print('Num. self training examples: %d, Num. labeled examples: %d' % (len(self_train_examples), len(labeled_data)),
          file=sys.stderr)

    train_set = Dataset(examples=labeled_data.examples + self_train_examples)

    print('begin training, %d training examples, %d dev examples' % (len(train_set), len(dev_set)), file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            # print(loss.data)
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter,
                       report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)
        # model_file = args.save_to + '.iter%d.bin' % train_iter
        # print('save model to [%s]' % model_file, file=sys.stderr)
        # model.save(model_file)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples, model, args, verbose=True)
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' % (epoch, dev_acc, time.time() - eval_start), file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')

        # checked independently of `is_better`: otherwise training would run
        # past `max_epoch` whenever the final epoch happens to be the best one
        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)

        if not is_better and patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda: model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # same parameter set as the initial optimizer built above
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

            # set new lr (the restored optimizer state carries the old learning rate)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #19
0
def self_training(args):
    """Perform self-training on top of a pre-trained parser.

    The function loads a checkpoint from ``args.load_model`` and the decoding
    results stored at ``args.load_decode_results``. For every unlabeled
    example that has at least one hypothesis, the first hypothesis becomes a
    new training example. Training then continues on the labeled examples
    plus these generated ones, validating on the dev set after each epoch.

    Args:
        args: run-time arguments. Only ``load_model``, ``cuda``, ``save_to``,
            ``train_file``, ``unlabeled_file``, ``dev_file`` and
            ``load_decode_results`` are read from it; all remaining settings
            come from the argument object stored inside the checkpoint.

    The training loop has no normal return: it ends the process with
    ``sys.exit(0)`` once ``max_epoch`` is reached or ``max_num_trial``
    learning-rate decay trials have been used up.
    """

    print('load pre-trained model from [%s]' % args.load_model,
          file=sys.stderr)
    params = torch.load(args.load_model,
                        map_location=lambda storage, loc: storage)
    vocab = params['vocab']
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_state = params['state_dict']

    # transfer arguments
    saved_args.cuda = args.cuda
    saved_args.save_to = args.save_to
    saved_args.train_file = args.train_file
    saved_args.unlabeled_file = args.unlabeled_file
    saved_args.dev_file = args.dev_file
    saved_args.load_decode_results = args.load_decode_results
    args = saved_args

    update_args(args)

    model = Parser(saved_args, vocab, transition_system)
    model.load_state_dict(saved_state)

    if args.cuda:
        model = model.cuda()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print('load unlabeled data [%s]' % args.unlabeled_file, file=sys.stderr)
    unlabeled_data = Dataset.from_bin_file(args.unlabeled_file)

    print('load decoding results of unlabeled data [%s]' %
          args.load_decode_results,
          file=sys.stderr)
    # pickles are binary: the file must be opened in 'rb' mode, and the
    # handle is closed as soon as the results are loaded
    with open(args.load_decode_results, 'rb') as decode_file:
        decode_results = pickle.load(decode_file)

    labeled_data = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)

    print('Num. examples in unlabeled data: %d' % len(unlabeled_data),
          file=sys.stderr)
    assert len(unlabeled_data) == len(decode_results)

    # the first hypothesis of every decoded example becomes its target
    self_train_examples = [
        Example(idx='self_train-%s' % example.idx,
                src_sent=example.src_sent,
                tgt_code=hyps[0].code,
                tgt_actions=hyps[0].action_infos,
                tgt_ast=hyps[0].tree)
        for example, hyps in zip(unlabeled_data, decode_results)
        if hyps
    ]
    print('Num. self training examples: %d, Num. labeled examples: %d' %
          (len(self_train_examples), len(labeled_data)),
          file=sys.stderr)

    train_set = Dataset(examples=labeled_data.examples + self_train_examples)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]
            if not batch_examples:
                # every example of this batch exceeded the length limit
                continue

            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            # ``.item()`` is the supported way to read a 0-dim tensor
            report_loss += torch.sum(loss).item()
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter, report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples,
                                           model,
                                           args,
                                           verbose=True)
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' %
              (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = (not history_dev_scores
                     or dev_acc > max(history_dev_scores))
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')

        # Checked independently of ``is_better``: otherwise an improvement in
        # the last allowed epoch would skip the stop and, because the test is
        # an equality, training would never hit the epoch limit again.
        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if not is_better and patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # use the ``inference_model`` sub-module when the model has
                # one, otherwise fall back to the parameters of the whole
                # model (which is what the initial optimizer was built on)
                trainable = getattr(model, 'inference_model', model)
                optimizer = torch.optim.Adam(trainable.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #20
0
def train_rerank_feature(args):
    """Train a model that is used as a reranking feature.

    ``args.mode`` selects the model class: ``'train_reconstructor'`` builds a
    ``Reconstructor`` and ``'train_paraphrase_identifier'`` builds a
    ``ParaphraseIdentificationModel``. For the paraphrase identifier the
    decoding results of the training and dev sets are loaded as well;
    incorrect hypotheses from them are turned into negative examples.

    After every epoch the model is scored on the dev set (paraphrase accuracy
    or negative perplexity). The best model and its optimizer state are saved
    under ``args.save_to``. When the score has not improved for
    ``args.patience`` epochs the best checkpoint is restored and the learning
    rate is multiplied by ``args.lr_decay``.

    Args:
        args: parsed run-time arguments (data paths, ``mode``, ``lang``,
            optimization settings, ...).

    Raises:
        ValueError: if ``args.mode`` names neither supported training mode.

    The training loop has no normal return: it ends the process with
    ``sys.exit(0)`` after ``args.max_num_trial`` decay trials.
    """
    train_set = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = TransitionSystem.get_class_by_lang(args.lang)(grammar)

    train_paraphrase_model = args.mode == 'train_paraphrase_identifier'

    def _get_feat_class():
        """Map ``args.mode`` to the model class to train."""
        if args.mode == 'train_reconstructor':
            return Reconstructor
        elif args.mode == 'train_paraphrase_identifier':
            return ParaphraseIdentificationModel
        # fail with a clear message instead of calling ``None`` below
        raise ValueError('unsupported mode for training a rerank feature: %r'
                         % args.mode)

    def _filter_hyps(_decode_results):
        """Drop, in place, the hypotheses whose code cannot be tokenized."""
        for i, hyps in enumerate(_decode_results):
            valid_hyps = []
            for hyp in hyps:
                try:
                    transition_system.tokenize_code(hyp.code)
                except Exception:
                    # best effort: an untokenizable hypothesis is discarded,
                    # but KeyboardInterrupt / SystemExit still propagate
                    continue
                valid_hyps.append(hyp)

            _decode_results[i] = valid_hyps

    model = _get_feat_class()(args, vocab, transition_system)

    if args.glorot_init:
        print('use glorot initialization', file=sys.stderr)
        nn_utils.glorot_init(model.parameters())

    model.train()
    if args.cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # if training the paraphrase model, also load in decoding results
    if train_paraphrase_model:
        print('load training decode results [%s]' % args.train_decode_file, file=sys.stderr)
        with open(args.train_decode_file, 'rb') as decode_file:
            train_decode_results = pickle.load(decode_file)
        _filter_hyps(train_decode_results)
        train_decode_results = {e.idx: hyps for e, hyps in zip(train_set, train_decode_results)}

        print('load dev decode results [%s]' % args.dev_decode_file, file=sys.stderr)
        with open(args.dev_decode_file, 'rb') as decode_file:
            dev_decode_results = pickle.load(decode_file)
        _filter_hyps(dev_decode_results)
        dev_decode_results = {e.idx: hyps for e, hyps in zip(dev_set, dev_decode_results)}

    def evaluate_ppl():
        """Return the perplexity of the model on the dev set."""
        model.eval()
        cum_loss = 0.
        cum_tgt_words = 0.
        for batch in dev_set.batch_iter(args.batch_size):
            loss = -model.score(batch).sum()
            cum_loss += loss.data.item()
            cum_tgt_words += sum(len(e.src_sent) + 1 for e in batch)  # add ending </s>

        ppl = np.exp(cum_loss / cum_tgt_words)
        model.train()
        return ppl

    def evaluate_paraphrase_acc():
        """Return the dev accuracy over positive and negative examples.

        A dev example counts as correct when its exponentiated score is at
        least 0.5; its best-scoring incorrect hypothesis counts as correct
        when the exponentiated score is below 0.5.
        """
        model.eval()
        labels = []
        for batch in dev_set.batch_iter(args.batch_size):
            probs = model.score(batch).exp().data.cpu().numpy()
            labels.extend(p >= 0.5 for p in probs)

            # get negative examples
            batch_decoding_results = [dev_decode_results[e.idx] for e in batch]
            batch_negative_examples = [get_negative_example(e, _hyps, sample_type='best')
                                       for e, _hyps in zip(batch, batch_decoding_results)]
            batch_negative_examples = [e for e in batch_negative_examples if e]
            if not batch_negative_examples:
                # no incorrect hypothesis in this batch: nothing to score
                continue
            probs = model.score(batch_negative_examples).exp().data.cpu().numpy()
            labels.extend(p < 0.5 for p in probs)

        acc = np.average(labels)
        model.train()
        return acc

    def get_negative_example(_example, _hyps, sample_type='sample'):
        """Build negative example(s) from the incorrect hypotheses in ``_hyps``.

        ``sample_type`` is ``'best'`` (highest-scoring incorrect hypothesis),
        ``'sample'`` (one hypothesis drawn with probability proportional to
        ``exp(score)``) or ``'all'`` (a list with one example per incorrect
        hypothesis). Returns ``None`` when there is no incorrect hypothesis
        or the sample type is not one of the three.
        """
        incorrect_hyps = [hyp for hyp in _hyps if not hyp.is_correct]
        if not incorrect_hyps:
            return None

        if sample_type in ('best', 'sample'):
            incorrect_hyp_scores = [hyp.score for hyp in incorrect_hyps]
            if sample_type == 'best':
                sample_idx = np.argmax(incorrect_hyp_scores)
            else:
                incorrect_hyp_probs = [np.exp(score) for score in incorrect_hyp_scores]
                incorrect_hyp_probs = np.array(incorrect_hyp_probs) / sum(incorrect_hyp_probs)
                # draw an index rather than an element so that numpy never
                # has to convert the hypothesis objects into an array
                sample_idx = np.random.choice(len(incorrect_hyps), p=incorrect_hyp_probs)
            sampled_hyp = incorrect_hyps[sample_idx]

            return Example(idx='negative-%s' % _example.idx,
                           src_sent=_example.src_sent,
                           tgt_code=sampled_hyp.code,
                           tgt_actions=None,
                           tgt_ast=None)

        if sample_type == 'all':
            return [Example(idx='negative-%s-%d' % (_example.idx, i),
                            src_sent=_example.src_sent,
                            tgt_code=hyp.code,
                            tgt_actions=None,
                            tgt_ast=None)
                    for i, hyp in enumerate(incorrect_hyps)]

        return None

    print('begin training decoder, %d training examples, %d dev examples' % (len(train_set), len(dev_set)), file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= args.decode_max_time_step]
            if not batch_examples:
                # every example of this batch exceeded the length limit
                continue

            if train_paraphrase_model:
                # label 0 marks the original examples, label 1 the negatives;
                # the label picks the column of the model output used as loss
                labels = [0] * len(batch_examples)
                negative_samples = []
                batch_decoding_results = [train_decode_results[e.idx] for e in batch_examples]
                # sample negative examples
                for example, hyps in zip(batch_examples, batch_decoding_results):
                    if not hyps:
                        continue
                    negative_sample = get_negative_example(example, hyps,
                                                           sample_type=args.negative_sample_type)
                    if not negative_sample:
                        continue
                    if isinstance(negative_sample, Example):
                        negative_samples.append(negative_sample)
                        labels.append(1)
                    else:
                        negative_samples.extend(negative_sample)
                        labels.extend([1] * len(negative_sample))

                batch_examples += negative_samples

            train_iter += 1
            optimizer.zero_grad()

            nll = -model(batch_examples)
            if train_paraphrase_model:
                idx_tensor = Variable(torch.LongTensor(labels).unsqueeze(-1), requires_grad=False)
                if args.cuda:
                    idx_tensor = idx_tensor.cuda()
                loss = torch.gather(nll, 1, idx_tensor)
            else:
                loss = nll

            report_loss += torch.sum(loss).data.item()
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter,
                       report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        # evaluate dev_score; perplexity is negated so that higher is better
        dev_score = evaluate_paraphrase_acc() if train_paraphrase_model else -evaluate_ppl()
        print('[Epoch %d] dev_score=%.5f took %ds' % (epoch, dev_score, time.time() - eval_start), file=sys.stderr)
        is_better = not history_dev_scores or dev_score > max(history_dev_scores)
        history_dev_scores.append(dev_score)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # use the ``inference_model`` sub-module when the model has
                # one, otherwise fall back to the parameters of the whole
                # model (which is what the initial optimizer was built on)
                trainable = getattr(model, 'inference_model', model)
                optimizer = torch.optim.Adam(trainable.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #21
0
def train(args):
    """Train an ``LSTMLanguageModel`` on the source sentences of a dataset.

    The model is trained on ``e.src_sent`` of the training examples. After
    every epoch the perplexity on the dev set is computed; the model with the
    lowest perplexity so far is saved to ``args.save_to + '.bin'`` together
    with the optimizer state. Each epoch without improvement raises the
    patience counter. When it reaches ``args.patience`` the best checkpoint is
    restored and the learning rate is multiplied by ``args.lr_decay``.

    Args:
        args: parsed run-time arguments (data paths, model sizes,
            optimization settings, ...).

    The training loop has no normal return: it ends the process with
    ``sys.exit(0)`` after ``args.max_num_trial`` decay trials.
    """
    train_set = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    model = LSTMLanguageModel(vocab.source,
                              args.embed_size,
                              args.hidden_size,
                              dropout=args.dropout)
    model.train()
    if args.cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def evaluate_ppl():
        """Return the perplexity of the model on the dev set."""
        model.eval()
        cum_loss = 0.
        cum_tgt_words = 0.
        for batch in dev_set.batch_iter(args.batch_size):
            src_sents_var = nn_utils.to_input_variable(
                [e.src_sent for e in batch],
                vocab.source,
                cuda=args.cuda,
                append_boundary_sym=True)
            loss = model(src_sents_var).sum()
            # ``.item()`` is the supported way to read a 0-dim tensor
            cum_loss += loss.item()
            cum_tgt_words += sum(len(e.src_sent) + 1
                                 for e in batch)  # add ending </s>

        ppl = np.exp(cum_loss / cum_tgt_words)
        model.train()
        return ppl

    print('begin training decoder, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)))
    print('vocab size: %d' % len(vocab.source))

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            # examples with more than 100 target actions are left out
            batch_examples = [
                e for e in batch_examples if len(e.tgt_actions) <= 100
            ]
            if not batch_examples:
                # nothing left to train on in this batch
                continue

            src_sents = [e.src_sent for e in batch_examples]
            src_sents_var = nn_utils.to_input_variable(
                src_sents,
                vocab.source,
                cuda=args.cuda,
                append_boundary_sym=True)

            train_iter += 1
            optimizer.zero_grad()

            loss = model(src_sents_var)
            report_loss += torch.sum(loss).item()
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter, report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        # evaluate ppl
        ppl = evaluate_ppl()
        print('[Epoch %d] ppl=%.5f took %ds' %
              (epoch, ppl, time.time() - eval_start),
              file=sys.stderr)
        # negate the perplexity so that a higher score means a better model
        dev_score = -ppl
        is_better = (not history_dev_scores
                     or dev_score > max(history_dev_scores))
        history_dev_scores.append(dev_score)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

            if patience == args.patience:
                num_trial += 1
                print('hit #%d trial' % num_trial, file=sys.stderr)
                if num_trial == args.max_num_trial:
                    print('early stop!', file=sys.stderr)
                    sys.exit(0)

                # decay lr, and restore from previously best checkpoint
                lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                print(
                    'load previously best model and decay learning rate to %f'
                    % lr,
                    file=sys.stderr)

                # load model
                params = torch.load(args.save_to + '.bin',
                                    map_location=lambda storage, loc: storage)
                model.load_state_dict(params['state_dict'])
                if args.cuda:
                    model = model.cuda()

                # load optimizers
                if args.reset_optimizer:
                    print('reset optimizer', file=sys.stderr)
                    # use the ``inference_model`` sub-module when the model
                    # has one, otherwise fall back to the parameters of the
                    # whole model (as the initial optimizer does)
                    trainable = getattr(model, 'inference_model', model)
                    optimizer = torch.optim.Adam(trainable.parameters(),
                                                 lr=lr)
                else:
                    print('restore parameters of the optimizers',
                          file=sys.stderr)
                    optimizer.load_state_dict(
                        torch.load(args.save_to + '.optim.bin'))

                # set new lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                # reset patience
                patience = 0
Example #22
0
def train(args):
    """Train a ``Parser`` from scratch and keep the best dev checkpoint.

    After every epoch the model is evaluated on the dev set. The model with
    the highest ``'accuracy'`` so far is saved to ``args.save_to + '.bin'``
    together with the optimizer state. Each epoch without improvement raises
    the patience counter. When it reaches ``args.patience`` the best
    checkpoint is restored and the learning rate is multiplied by
    ``args.lr_decay``.

    Args:
        args: parsed run-time arguments (data paths, ``lang``, optimization
            settings, ...).

    The training loop has no normal return: it ends the process with
    ``sys.exit(0)`` once ``args.max_epoch`` is reached or
    ``args.max_num_trial`` decay trials have been used up.
    """
    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = TransitionSystem.get_class_by_lang(args.lang)(grammar)
    train_set = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pkl.load(vocab_file)

    model = Parser(args, vocab, transition_system)
    model.train()
    if args.cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]
            if not batch_examples:
                # every example of this batch exceeded the length limit
                continue

            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            # ``.item()`` is the supported way to read a 0-dim tensor
            report_loss += torch.sum(loss).item()
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter, report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples,
                                           model,
                                           args,
                                           verbose=True)
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' %
              (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = (not history_dev_scores
                     or dev_acc > max(history_dev_scores))
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')

        # Checked independently of ``is_better``: otherwise an improvement in
        # the last allowed epoch would skip the stop and, because the test is
        # an equality, training would never hit the epoch limit again.
        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if not is_better and patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # use the ``inference_model`` sub-module when the model has
                # one, otherwise fall back to the parameters of the whole
                # model (which is what the initial optimizer was built on)
                trainable = getattr(model, 'inference_model', model)
                optimizer = torch.optim.Adam(trainable.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #23
0
def train_rl(args):
    """Fine-tune a pre-trained parser with per-example weights.

    The parser is loaded from ``args.load_model`` and evaluated once on the
    test set. It is then trained for 100 passes over the training set. For
    every batch the function:

    * samples one hypothesis per example with ``model.sample``,
    * parses every example with beam search (``model.parse``),
    * passes samples, original examples and top beam hypotheses to
      ``asdl.transition_system.get_scores`` to obtain weights, and
    * minimizes the negated ``model.score(batch_examples, weights=...)``.

    Every 200 iterations the parser is evaluated on the first 2000 test
    examples. After training, the full test set is evaluated and, when
    ``args.save_decode_to`` is set, the decoding results are pickled there.

    Args:
        args: parsed run-time arguments; ``args.load_model`` must be set.
            ``args.lang`` is overwritten with the value stored in the
            checkpoint.
    """
    test_set = Dataset.from_bin_file(args.test_file)
    assert args.load_model
    train_set = Dataset.from_bin_file(args.train_file)
    print('load model from [%s]' % args.load_model, file=sys.stderr)
    params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_args.cuda = args.cuda
    # set the correct domain from saved arg
    args.lang = saved_args.lang

    def getnew(model, e, shows):
        """Return a copy of ``e`` whose target is a sampled hypothesis.

        When sampling yields no hypothesis, ``e`` itself is returned.
        """
        example = copy.deepcopy(e)
        hyps, extra = model.sample(example.src_sent, context=example.table, show=shows)
        if len(hyps) == 0:
            return e
        actions = hyps[0].actions
        example.tgt_actions = get_action_infos(example.src_sent, actions, force_copy=True)
        example.actions = actions
        example.extra = extra
        if shows:
            print('haha')
        return example

    parser_cls = Registrable.by_name(args.parser)
    parser = parser_cls.load(model_path=args.load_model, cuda=args.cuda)

    parser.train()
    model = parser
    evaluator = Registrable.by_name(args.evaluator)(transition_system, args=args)
    if args.cuda:
        model.cuda()
    # look the optimizer class up by name instead of eval()-ing a string
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)
    for e in train_set.examples:
        e.actions = [item.action for item in e.tgt_actions]
    train_iter = 1

    # evaluate the loaded parser once before any update
    parser.eval()
    eval_start = time.time()
    eval_results = evaluation.evaluate(test_set.examples, model, evaluator, args,
                                       verbose=True, eval_top_pred_only=args.eval_top_pred_only)
    dev_score = eval_results[evaluator.default_metric]

    print('[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)' % (
        train_iter, eval_results,
        evaluator.default_metric,
        dev_score,
        time.time() - eval_start), file=sys.stderr)

    # NOTE(review): waits for a line on stdin before training starts; this
    # looks like a debugging pause and fails when stdin is closed - confirm
    # whether it is still wanted.
    input('df')
    print(len(train_set.examples) / 64)

    # imported once here instead of on every batch
    import asdl

    bestscore = 0
    for _ in range(100):
        model.train()
        show = False
        batches = list(train_set.batch_iter(batch_size=args.batch_size, shuffle=False))
        for batch_examples in tqdm.tqdm(batches):
            # sampling and beam search below are run in eval mode
            model.eval()
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= args.decode_max_time_step]
            newbatch_examples = [getnew(parser, e, show) for e in batch_examples]

            # top beam hypothesis per example, ``None`` when parsing fails
            means = []
            for example in batch_examples:
                hyps = model.parse(example.src_sent, context=example.table, beam_size=args.beam_size)
                means.append(None if len(hyps) == 0 else hyps[0])

            weight = asdl.transition_system.get_scores(newbatch_examples, batch_examples, means)

            train_iter += 1
            model.train()
            optimizer.zero_grad()

            ret_val, _ = model.score(batch_examples, weights=weight)
            ret_val1, _ = model.score(batch_examples)
            loss = -ret_val[0]
            loss1 = -ret_val1[0]

            # the unweighted loss enters with factor 0.0, so it adds nothing
            # to the value of the objective; kept as in the original
            loss = torch.mean(loss) + torch.mean(loss1) * 0.0

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % 200 == 0:
                parser.eval()
                eval_start = time.time()
                eval_results = evaluation.evaluate(test_set.examples[:2000], model, evaluator, args,
                                                   verbose=True, eval_top_pred_only=args.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]
                if bestscore < dev_score:
                    bestscore = dev_score
                print('[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)' % (
                    train_iter, eval_results,
                    evaluator.default_metric,
                    dev_score,
                    time.time() - eval_start), file=sys.stderr)
                print('hhah' + str(bestscore))

    # the last training step leaves the parser in train mode; switch to eval
    # mode like the other evaluations in this function do
    parser.eval()
    evaluator = Registrable.by_name(args.evaluator)(transition_system, args=args)
    eval_results, decode_results = evaluation.evaluate(test_set.examples, parser, evaluator, args,
                                                       verbose=args.verbose, return_decode_result=True)
    print(eval_results, file=sys.stderr)
    if args.save_decode_to:
        with open(args.save_decode_to, 'wb') as decode_file:
            pickle.dump(decode_results, decode_file)
Example #24
0
def train(args):
    """Train a parser by maximum likelihood estimation.

    Builds the grammar, transition system, datasets, vocabulary, model and
    optimizer from ``args`` and then trains epoch by epoch. The best model is
    written to ``args.save_to + '.bin'`` and the optimizer state to
    ``args.save_to + '.optim.bin'``. Once patience is exhausted the best
    checkpoint is reloaded and the learning rate is multiplied by
    ``args.lr_decay``.

    The training loop has no normal return: the process is terminated with
    ``sys.exit(0)`` when ``args.max_epoch`` epochs have run or when
    ``args.max_num_trial`` restarts have been used.

    Args:
        args: namespace of run options. Attributes read here include the file
            paths (``asdl_file``, ``train_file``, ``dev_file``, ``vocab``,
            ``glove_embed_path``, ``save_to``), model/optimizer settings
            (``lang``, ``cuda``, ``optimizer``, ``lr``, ``uniform_init``,
            ``glorot_init``, ``batch_size``, ``clip_grad``, ``sup_attention``)
            and the schedule (``log_every``, ``valid_every_epoch``,
            ``lr_decay``, ``lr_decay_after_epoch``, ``patience``,
            ``max_epoch``, ``max_num_trial``, ``reset_optimizer``).
    """
    # Context managers make sure the grammar / vocabulary files are closed.
    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = TransitionSystem.get_class_by_lang(args.lang)(grammar)
    train_set = Dataset.from_bin_file(args.train_file)

    if args.dev_file:
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])

    # NOTE: unpickling runs code stored in the file; only load trusted vocabularies.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    if args.lang == 'wikisql':
        # import additional packages for wikisql dataset
        from model.wikisql.dataset import WikiSqlExample, WikiSqlTable, TableColumn

    parser_cls = get_parser_class(args.lang)
    model = parser_cls(args, vocab, transition_system)
    model.train()
    if args.cuda: model.cuda()

    # Look the optimizer class up by name instead of eval()-ing a string.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    if args.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' % (args.uniform_init, args.uniform_init), file=sys.stderr)
        nn_utils.uniform_init(-args.uniform_init, args.uniform_init, model.parameters())
    elif args.glorot_init:
        print('use glorot initialization', file=sys.stderr)
        nn_utils.glorot_init(model.parameters())

    # load pre-trained word embedding (optional)
    if args.glove_embed_path:
        print('load glove embedding from: %s' % args.glove_embed_path, file=sys.stderr)
        glove_embedding = GloveHelper(args.glove_embed_path)
        glove_embedding.load_to(model.src_embed, vocab.source)

    print('begin training, %d training examples, %d dev examples' % (len(train_set), len(dev_set)), file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    history_dev_scores = []
    num_trial = patience = 0
    # Must exist before the first `if is_better` check: with a dev file and
    # valid_every_epoch > 1 no validation runs in epoch 1, which previously
    # raised UnboundLocalError. On later non-validation epochs the value of the
    # most recent validation is reused, as before.
    is_better = False
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            # drop examples whose action sequence exceeds the decoding budget
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            train_iter += 1
            optimizer.zero_grad()

            ret_val = model.score(batch_examples)
            loss = -ret_val[0]

            # .item() replaces `.data[0]`, which fails on 0-dim tensors in
            # current PyTorch releases.
            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            if args.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    sup_att_loss_val = sup_att_loss.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient (in-place variant; `clip_grad_norm` is deprecated)
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f' % (train_iter, report_loss / report_examples)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation
        if args.dev_file:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(dev_set.examples, model, args,
                                                   verbose=True, eval_top_pred_only=args.eval_top_pred_only)
                dev_acc = eval_results['accuracy']
                print('[Epoch %d] code generation accuracy=%.5f took %ds' % (epoch, dev_acc, time.time() - eval_start), file=sys.stderr)
                is_better = history_dev_scores == [] or dev_acc > max(history_dev_scores)
                history_dev_scores.append(dev_acc)
        else:
            # without a dev set every epoch counts as an improvement
            is_better = True

            if epoch > args.lr_decay_after_epoch:
                lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                print('decay learning rate to %f' % lr, file=sys.stderr)

                # set new lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience and epoch >= args.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

            # load model (onto CPU first, then move back to the GPU if requested)
            params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda: model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #25
0
def train(args):
    """Train a parser by maximum likelihood estimation.

    Sets up grammar, transition system, datasets, vocabulary, model and
    optimizer from ``args``, then runs training epochs. The best model is
    saved to ``args.save_to + '.bin'`` with the optimizer state next to it in
    ``args.save_to + '.optim.bin'``; when patience is used up the best
    checkpoint is restored and the learning rate is scaled by
    ``args.lr_decay``.

    The loop does not return normally. The process exits through
    ``sys.exit(0)`` after ``args.max_epoch`` epochs or after
    ``args.max_num_trial`` restarts.

    Args:
        args: namespace of run options; every setting is read as an attribute
            (paths such as ``asdl_file``/``train_file``/``dev_file``/``vocab``/
            ``save_to``, optimizer settings such as ``optimizer``/``lr``/
            ``clip_grad``, and schedule settings such as ``log_every``/
            ``valid_every_epoch``/``patience``/``max_epoch``).
    """
    # `with` closes the grammar file once it has been read.
    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = TransitionSystem.get_class_by_lang(args.lang)(grammar)
    train_set = Dataset.from_bin_file(args.train_file)

    if args.dev_file:
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])

    # NOTE: pickle can execute code embedded in the file; load trusted files only.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    if args.lang == 'wikisql':
        # import additional packages for wikisql dataset
        from model.wikisql.dataset import WikiSqlExample, WikiSqlTable, TableColumn

    parser_cls = get_parser_class(args.lang)
    model = parser_cls(args, vocab, transition_system)
    model.train()
    if args.cuda: model.cuda()

    # Attribute lookup instead of eval() on a user-supplied string.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    if args.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' %
              (args.uniform_init, args.uniform_init),
              file=sys.stderr)
        nn_utils.uniform_init(-args.uniform_init, args.uniform_init,
                              model.parameters())
    elif args.glorot_init:
        print('use glorot initialization', file=sys.stderr)
        nn_utils.glorot_init(model.parameters())

    # load pre-trained word embedding (optional)
    if args.glove_embed_path:
        print('load glove embedding from: %s' % args.glove_embed_path,
              file=sys.stderr)
        glove_embedding = GloveHelper(args.glove_embed_path)
        glove_embedding.load_to(model.src_embed, vocab.source)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    history_dev_scores = []
    num_trial = patience = 0
    # Defined up front so the `if is_better` check below cannot hit an
    # unassigned name when a dev file is given but validation is skipped in
    # the first epoch (valid_every_epoch > 1). Later non-validation epochs
    # keep reusing the latest validation outcome, as before.
    is_better = False
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            # keep only examples that fit in the decoding step budget
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]

            train_iter += 1
            optimizer.zero_grad()

            ret_val = model.score(batch_examples)
            loss = -ret_val[0]

            # `.item()` instead of `.data[0]`: indexing a 0-dim tensor raises
            # in current PyTorch releases.
            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            if args.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    sup_att_loss_val = sup_att_loss.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient (non-deprecated in-place API)
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f' % (
                    train_iter, report_loss / report_examples)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (
                        report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation
        if args.dev_file:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(
                    dev_set.examples,
                    model,
                    args,
                    verbose=True,
                    eval_top_pred_only=args.eval_top_pred_only)
                dev_acc = eval_results['accuracy']
                print('[Epoch %d] code generation accuracy=%.5f took %ds' %
                      (epoch, dev_acc, time.time() - eval_start),
                      file=sys.stderr)
                is_better = history_dev_scores == [] or dev_acc > max(
                    history_dev_scores)
                history_dev_scores.append(dev_acc)
        else:
            # no dev set: treat every epoch as an improvement
            is_better = True

            if epoch > args.lr_decay_after_epoch:
                lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                print('decay learning rate to %f' % lr, file=sys.stderr)

                # set new lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience and epoch >= args.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model (mapped to CPU, moved back to GPU below if requested)
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda: model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #26
0
def train(args):
    """Train a parser with an attached BERT model by maximum likelihood.

    Loads a BERT tokenizer/model from fixed local paths, builds the parser
    named by ``args.parser`` and attaches the tokenizer and BERT model to it.
    Two optimizers are used: one over the parameters whose name does not
    contain ``'bert_model'`` (learning rate ``args.lr``) and one over the
    ``'bert_model'`` parameters (learning rate ``1e-5``).

    Training stops when more than 40 epochs have run (the function then
    returns ``None``), or the process exits through ``sys.exit(0)`` when
    ``args.max_epoch`` is reached or ``args.max_num_trial`` restarts have been
    used. The best model goes to ``args.save_to + '.bin'`` and the state of
    the non-BERT optimizer to ``args.save_to + '.optim.bin'``.

    Args:
        args: namespace of run options; all settings are read as attributes
            (paths, ``parser``/``transition_system``/``evaluator`` registry
            names, optimizer settings and the logging/validation schedule).
    """
    tokenizer = BertTokenizer.from_pretrained('./pretrained_models/bert-base-uncased-vocab.txt')
    bertmodels = BertModel.from_pretrained('./pretrained_models/base-uncased/')
    print(len(tokenizer.vocab))

    # auxiliary vocabularies handed to the parser constructor
    tmpvocab = {'null': 0}
    tmpvocab1 = {'null': 0, 'unk': 1}
    tmpvocab2 = {'null': 0, 'unk': 1}
    tmpvocab3 = {'null': 0, 'unk': 1}

    # load in train/dev set
    train_set = Dataset.from_bin_file(args.train_file)
    # NOTE(review): these names are not used below; the import is kept in case
    # the module is relied on for import-time side effects - confirm and drop.
    from dependency import sentencetoadj, sentencetoextra_message

    if args.dev_file:
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])

    # NOTE: unpickling runs code stored in the file; only load trusted vocabularies.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)
    print(len(vocab.source))
    vocab.source.copyandmerge(tokenizer)
    print(len(vocab.source))

    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = Registrable.by_name(args.transition_system)(grammar)

    parser_cls = Registrable.by_name(args.parser)  # TODO: add arg
    model = parser_cls(args, vocab, transition_system, tmpvocab, tmpvocab1, tmpvocab2, tmpvocab3)

    model.train()
    model.tokenizer = tokenizer
    evaluator = Registrable.by_name(args.evaluator)(transition_system, args=args)
    if args.cuda: model.cuda()

    # Parameter initialisation happens before the BERT model is attached, so
    # the pretrained BERT weights are not touched by it.
    if args.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' % (args.uniform_init, args.uniform_init), file=sys.stderr)
        nn_utils.uniform_init(-args.uniform_init, args.uniform_init, model.parameters())
    elif args.glorot_init:
        print('use glorot initialization', file=sys.stderr)
        nn_utils.glorot_init(model.parameters())

    # load pre-trained word embedding (optional)
    if args.glove_embed_path:
        print('load glove embedding from: %s' % args.glove_embed_path, file=sys.stderr)
        glove_embedding = GloveHelper(args.glove_embed_path)
        glove_embedding.load_to(model.src_embed, vocab.source)
    print([name for name, _ in model.named_parameters()])
    model.bert_model = bertmodels
    # repeat train()/cuda() so the freshly attached BERT module is covered too
    model.train()
    if args.cuda: model.cuda()

    # Look the optimizer class up by name instead of eval()-ing a string.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    parameters = [p for name, p in model.named_parameters() if 'bert_model' not in name]
    parameters1 = [p for name, p in model.named_parameters() if 'bert_model' in name]
    optimizer = optimizer_cls(parameters, lr=args.lr)
    optimizer1 = optimizer_cls(parameters1, lr=0.00001)
    print('begin training, %d training examples, %d dev examples' % (len(train_set), len(dev_set)), file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    def process(header, src_sent, tokenizer):
        """Return True when question + table header stay under 130 pieces.

        The count is: tokenizer pieces of the header tokens, plus one per
        header column, plus tokenizer pieces of the source sentence.
        (The former per-example debug prints and blocking ``input()`` prompt
        were removed; they stalled training on stdin for every example.)
        """
        src_pieces = []
        for token in src_sent:
            src_pieces.extend(tokenizer._tokenize(token))
        header_pieces = []
        for column in header:
            for token in column.tokens:
                header_pieces.extend(tokenizer._tokenize(token))
        return len(header_pieces) + len(header) + len(src_pieces) < 130

    is_better = False
    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    report_loss1 = 0
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        if epoch > 40: break
        epoch += 1
        epoch_begin = time.time()
        model.train()
        for batch_examples in tqdm.tqdm(train_set.batch_iter(batch_size=args.batch_size, shuffle=True)):
            batch_examples = [e for e in batch_examples
                              if len(e.tgt_actions) <= args.decode_max_time_step
                              and process(e.table.header, e.src_sent, tokenizer)]
            train_iter += 1
            optimizer.zero_grad()
            optimizer1.zero_grad()

            ret_val, _ = model.score(batch_examples)
            loss = -ret_val[0]
            loss1 = ret_val[1]

            loss_val = torch.sum(loss).data.item()
            report_loss += loss_val
            # .item() keeps the running total a plain float instead of a
            # tensor that holds on to autograd history between log steps.
            report_loss1 += 1.0 * torch.sum(ret_val[2]).item()
            report_examples += len(batch_examples)
            # the auxiliary terms are multiplied by 0, so they do not change
            # the loss value
            loss = torch.mean(loss) + 0 * loss1 + 0 * torch.mean(ret_val[2])

            if args.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    sup_att_loss_val = sup_att_loss.data.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient (in-place variant; `clip_grad_norm` is deprecated)
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

            optimizer.step()
            optimizer1.step()
            loss = None  # drop the reference to this step's loss tensor
            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f,coverage loss=%.5f' % (train_iter, report_loss / report_examples, report_loss1 / report_examples)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.
                report_loss1 = 0

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation (only from epoch 6 on; earlier epochs always save)
        if args.dev_file and epoch >= 6:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(dev_set.examples, model, evaluator, args,
                                                   verbose=True, eval_top_pred_only=args.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]

                print('[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)' % (
                                    epoch, eval_results,
                                    evaluator.default_metric,
                                    dev_score,
                                    time.time() - eval_start), file=sys.stderr)
                is_better = history_dev_scores == [] or dev_score > max(history_dev_scores)
                history_dev_scores.append(dev_score)
                print('[Epoch %d] begin validation2' % epoch, file=sys.stderr)
        else:
            is_better = True

        if args.decay_lr_every_epoch and epoch > args.lr_decay_after_epoch:
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('decay learning rate to %f' % lr, file=sys.stderr)

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience and epoch >= args.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            sys.exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                # exit directly; the old `input('hj')` prompt blocked
                # unattended runs here
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

            # load model (mapped to CPU, moved back to GPU below if requested)
            params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda: model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # Rebuild over the non-BERT parameters only, like the initial
                # optimizer; using model.parameters() here would make this
                # optimizer step the BERT weights as well as optimizer1.
                optimizer = torch.optim.Adam(parameters, lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #27
0
def train_decoder(args):
    """Train a ``Reconstructor`` model, selecting checkpoints by dev perplexity.

    After every epoch the perplexity on the dev set is computed; the negated
    perplexity is used as the dev score. The best model is written to
    ``args.save_to + '.bin'`` with the optimizer state in
    ``args.save_to + '.optim.bin'``. When ``args.patience`` epochs pass
    without improvement the best checkpoint is reloaded and the learning rate
    is multiplied by ``args.lr_decay``.

    The loop does not return normally: the process exits through
    ``sys.exit(0)`` after ``args.max_num_trial`` such restarts.

    Args:
        args: namespace of run options; all settings are read as attributes
            (``train_file``, ``dev_file``, ``vocab``, ``asdl_file``, ``lang``,
            ``cuda``, ``lr``, ``batch_size``, ``decode_max_time_step``,
            ``clip_grad``, ``log_every``, ``save_to``, ``patience``,
            ``max_num_trial``, ``lr_decay``, ``reset_optimizer``).
    """
    train_set = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)
    # Pickles must be opened in binary mode on Python 3 (text mode fails).
    # NOTE: unpickling runs code stored in the file; load trusted files only.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = TransitionSystem.get_class_by_lang(args.lang)(grammar)

    model = Reconstructor(args, vocab, transition_system)
    model.train()
    if args.cuda: model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def evaluate_ppl():
        """Return exp(total negative score / total target words) on the dev set."""
        model.eval()
        cum_loss = 0.
        cum_tgt_words = 0.
        # no gradients are needed for evaluation
        with torch.no_grad():
            for batch in dev_set.batch_iter(args.batch_size):
                loss = -model.score(batch).sum()
                cum_loss += loss.item()
                cum_tgt_words += sum(len(e.src_sent) + 1
                                     for e in batch)  # add ending </s>

        ppl = np.exp(cum_loss / cum_tgt_words)
        model.train()
        return ppl

    print('begin training decoder, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            # keep only examples that fit in the decoding step budget
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]

            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            # `.item()` instead of `.data[0]`: indexing a 0-dim tensor raises
            # in current PyTorch releases.
            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient (in-place variant; `clip_grad_norm` is deprecated)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter, report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        # evaluate ppl
        ppl = evaluate_ppl()
        print('[Epoch %d] ppl=%.5f took %ds' %
              (epoch, ppl, time.time() - eval_start),
              file=sys.stderr)
        # lower perplexity is better, so negate it to get a "higher is better" score
        dev_acc = -ppl
        is_better = history_dev_scores == [] or dev_acc > max(
            history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model (mapped to CPU, moved back to GPU below if requested)
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda: model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # Rebuild over the same parameters the initial optimizer was
                # created with; the previous `model.inference_model` lookup
                # did not match it.
                # NOTE(review): presumably a copy-paste leftover - confirm
                # that Reconstructor has no `inference_model` to optimise.
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #28
0
def train(cfg: argparse.Namespace):
    """Train a parser by maximum likelihood and keep the best checkpoint.

    Builds the datasets, vocabulary, grammar, transition system, model,
    evaluator and optimizer from ``cfg`` and then trains epoch by epoch.
    After each epoch the model is optionally validated on the dev set; the
    model and optimizer state are written to ``cfg.save_to + '.bin'`` and
    ``cfg.save_to + '.optim.bin'`` whenever the dev score improves (or after
    every epoch when no dev file is configured).

    When ``patience`` reaches ``cfg.patience`` the best checkpoint is
    reloaded and the learning rate is multiplied by ``cfg.lr_decay``.
    Training stops after ``cfg.max_epoch`` epochs or ``cfg.max_num_trial``
    such restarts; in both cases ``epilogue`` is called before returning.

    Args:
        cfg: parsed options; every ``cfg.<name>`` attribute read below must
            be present.
    """
    cli_logger.info("=== Training ===")

    # initial setup
    summary_writer = prologue(cfg)

    # load train/dev set
    train_set = Dataset.from_bin_file(cfg.train_file)

    if cfg.dev_file:
        dev_set = Dataset.from_bin_file(cfg.dev_file)
    else:
        dev_set = Dataset(examples=[])

    # NOTE: unpickling can execute arbitrary code; only load trusted vocab files.
    with open(cfg.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    with open(cfg.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = Registrable.by_name(cfg.transition_system)(grammar)

    parser_cls = Registrable.by_name(cfg.parser)
    model = parser_cls(cfg, vocab, transition_system)
    model.train()

    evaluator = Registrable.by_name(cfg.evaluator)(transition_system, args=cfg)
    if cfg.cuda:
        model.cuda()

    # look the optimizer class up by name instead of eval()-ing a string
    optimizer_cls = getattr(torch.optim, cfg.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=cfg.lr)

    if cfg.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' %
              (cfg.uniform_init, cfg.uniform_init),
              file=sys.stderr)
        nn_utils.uniform_init(-cfg.uniform_init, cfg.uniform_init,
                              model.parameters())
    elif cfg.xavier_init:
        print('use xavier initialization', file=sys.stderr)
        nn_utils.xavier_init(model.parameters())

    # load pre-trained word embedding (optional)
    if cfg.glove_embed_path:
        print('load glove embedding from: %s' % cfg.glove_embed_path,
              file=sys.stderr)
        glove_embedding = GloveHelper(cfg.glove_embed_path)
        glove_embedding.load_to(model.src_embed, vocab.source)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=cfg.batch_size,
                                                   shuffle=True):
            # keep only examples with at most decode_max_time_step actions
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= cfg.decode_max_time_step
            ]
            train_iter += 1
            optimizer.zero_grad()

            ret_val = model.score(batch_examples)
            loss = -ret_val[0]

            # .item() is the supported way to read a scalar tensor; indexing
            # a 0-dim tensor with [0] raises on current PyTorch releases.
            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            if cfg.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    sup_att_loss_val = sup_att_loss.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient (in-place variant; the un-suffixed name is deprecated)
            if cfg.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               cfg.clip_grad)

            optimizer.step()

            if train_iter % cfg.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f' % (
                    train_iter, report_loss / report_examples)
                if cfg.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (
                        report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str, file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        if cfg.save_all_models:
            model_file = cfg.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

        # perform validation
        # Reset every epoch: on epochs where validation is skipped the flag
        # would otherwise be unbound (first epoch) or carry over a stale
        # value and overwrite the best checkpoint with an unvalidated model.
        is_better = False
        if cfg.dev_file:
            if epoch % cfg.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
                eval_start = time.time()
                eval_results = evaluation.evaluate(
                    dev_set.examples,
                    model,
                    evaluator,
                    cfg,
                    verbose=True,
                    eval_top_pred_only=cfg.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]

                print(
                    '[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)'
                    % (epoch, eval_results, evaluator.default_metric,
                       dev_score, time.time() - eval_start),
                    file=sys.stderr)

                is_better = not history_dev_scores or dev_score > max(
                    history_dev_scores)
                history_dev_scores.append(dev_score)
        else:
            # without a dev set every epoch's model is saved
            is_better = True

        if cfg.decay_lr_every_epoch and epoch > cfg.lr_decay_after_epoch:
            lr = optimizer.param_groups[0]['lr'] * cfg.lr_decay
            print('decay learning rate to %f' % lr, file=sys.stderr)

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = cfg.save_to + '.bin'
            print('save the current model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), cfg.save_to + '.optim.bin')
        elif patience < cfg.patience and epoch >= cfg.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if epoch == cfg.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            # leave the loop (instead of exiting the process) so that the
            # epilogue below actually runs
            break

        if patience >= cfg.patience and epoch >= cfg.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == cfg.max_num_trial:
                print('early stop!', file=sys.stderr)
                break

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * cfg.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model (onto CPU first, then moved back to GPU if requested)
            params = torch.load(cfg.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if cfg.cuda:
                model = model.cuda()

            # load optimizers
            if cfg.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(cfg.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0

    # final setup
    epilogue(cfg, summary_writer)
Example #29
0
def train(args):
    """Maximum Likelihood Estimation.

    Loads the train/dev sets, vocabulary and grammar named in ``args``,
    builds (or, with ``args.pretrain``, loads) the parser, and trains it
    epoch by epoch. The model and optimizer state are saved to
    ``args.save_to + '.bin'`` / ``'.optim.bin'`` whenever the dev score
    improves (or after every epoch when no dev file is given).

    When ``patience`` reaches ``args.patience`` the best checkpoint is
    reloaded and the learning rate is multiplied by ``args.lr_decay``.
    The process exits with status 0 after ``args.max_epoch`` epochs or
    ``args.max_num_trial`` such restarts.

    Args:
        args: parsed options; every ``args.<name>`` attribute read below
            must be present.
    """

    # load in train/dev set
    print("loading files")
    print(f"Loading Train at {args.train_file}")
    train_set = Dataset.from_bin_file(args.train_file)
    # a set gives O(1) membership for the id sanity check below
    train_ids = {e.idx.split('-')[-1] for e in train_set.examples}
    print(f"{len(train_set.examples)} total examples")
    print('Checking ids:')
    for idx in ['4170655', '13704860', '4170655', '13704860', '3862010']:
        print(f"\t{idx} is {'not in' if idx not in train_ids else 'in'} train")
    print("")
    # Preview at most 10 examples; clamping avoids an IndexError on small
    # training sets and keeps the header in sync with what is printed.
    num_preview = min(10, len(train_set.examples))
    print(f"First {num_preview} Examples in Train:")
    for i in range(num_preview):
        print(f'\tExample {i + 1}(idx:{train_set.examples[i].idx}):')
        print(f"\t\tSource:{repr(' '.join(train_set.all_source[i])[:100])}")
        print(f"\t\tTarget:{repr(train_set.all_targets[i][:100])}")
    if args.dev_file:
        print(f"Loading dev at {args.dev_file}")
        dev_set = Dataset.from_bin_file(args.dev_file)
    else:
        dev_set = Dataset(examples=[])
    print("Loading vocab")

    # NOTE: unpickling can execute arbitrary code; only load trusted vocab files.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    print(f"Loading grammar {args.asdl_file}")
    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = Registrable.by_name(args.transition_system)(grammar)

    parser_cls = Registrable.by_name(args.parser)  # TODO: add arg
    if args.pretrain:
        print('Finetune with: ', args.pretrain)
        model = parser_cls.load(model_path=args.pretrain, cuda=args.cuda)
    else:
        model = parser_cls(args, vocab, transition_system)

    model.train()
    evaluator = Registrable.by_name(args.evaluator)(transition_system,
                                                    args=args)
    if args.cuda:
        model.cuda()

    # look the optimizer class up by name instead of eval()-ing a string
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)

    # a pretrained model keeps its loaded weights and embeddings
    if not args.pretrain:
        if args.uniform_init:
            print('uniformly initialize parameters [-%f, +%f]' %
                  (args.uniform_init, args.uniform_init))
            nn_utils.uniform_init(-args.uniform_init, args.uniform_init,
                                  model.parameters())
        elif args.glorot_init:
            print('use glorot initialization')
            nn_utils.glorot_init(model.parameters())

        # load pre-trained word embedding (optional)
        if args.glove_embed_path:
            print('load glove embedding from: %s' % args.glove_embed_path)
            glove_embedding = GloveHelper(args.glove_embed_path)
            glove_embedding.load_to(model.src_embed, vocab.source)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab))

    epoch = train_iter = 0
    report_loss = report_examples = report_sup_att_loss = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            # keep only examples with at most decode_max_time_step actions
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]
            train_iter += 1
            optimizer.zero_grad()

            ret_val = model.score(batch_examples)
            loss = -ret_val[0]

            # print(loss.data)
            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            if args.sup_attention:
                att_probs = ret_val[1]
                if att_probs:
                    sup_att_loss = -torch.log(torch.cat(att_probs)).mean()
                    # .item(), matching loss_val above; indexing a 0-dim
                    # tensor with [0] raises on current PyTorch releases
                    sup_att_loss_val = sup_att_loss.item()
                    report_sup_att_loss += sup_att_loss_val

                    loss += sup_att_loss

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                log_str = '[Iter %d] encoder loss=%.5f' % (
                    train_iter, report_loss / report_examples)
                if args.sup_attention:
                    log_str += ' supervised attention loss=%.5f' % (
                        report_sup_att_loss / report_examples)
                    report_sup_att_loss = 0.

                print(log_str)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin))

        if args.save_all_models:
            model_file = args.save_to + '.iter%d.bin' % train_iter
            print('save model to [%s]' % model_file)
            model.save(model_file)

        # perform validation
        # reset each epoch so a skipped validation never reuses a stale flag
        is_better = False
        if args.dev_file:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch)
                eval_start = time.time()
                eval_results = evaluation.evaluate(
                    dev_set.examples,
                    model,
                    evaluator,
                    args,
                    verbose=False,
                    eval_top_pred_only=args.eval_top_pred_only)
                dev_score = eval_results[evaluator.default_metric]

                print(
                    '[Epoch %d] evaluate details: %s, dev %s: %.5f (took %ds)'
                    % (epoch, eval_results, evaluator.default_metric,
                       dev_score, time.time() - eval_start))

                is_better = not history_dev_scores or dev_score > max(
                    history_dev_scores)
                history_dev_scores.append(dev_score)
        else:
            # without a dev set every epoch's model is saved
            is_better = True

        if args.decay_lr_every_epoch and epoch > args.lr_decay_after_epoch:
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('decay learning rate to %f' % lr)

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current model ..')
            print('save model to [%s]' % model_file)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience and epoch >= args.lr_decay_after_epoch:
            patience += 1
            print('hit patience %d' % patience)

        if epoch == args.max_epoch:
            print('reached max epoch, stop!')
            sys.exit(0)

        if patience >= args.patience and epoch >= args.lr_decay_after_epoch:
            num_trial += 1
            print('hit #%d trial' % num_trial)
            if num_trial == args.max_num_trial:
                print('early stop!')
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr)

            # load model (onto CPU first, then moved back to GPU if requested)
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer')
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers')
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example #30
0
def train_decoder(args):
    """Train a ``Reconstructor`` and keep the checkpoint with the lowest dev perplexity.

    After every epoch the dev perplexity is computed; its negation is used
    as the score, so a higher score is better. The model and optimizer state
    are saved to ``args.save_to + '.bin'`` / ``'.optim.bin'`` whenever the
    score improves. Once ``args.patience`` non-improving epochs accumulate,
    the best checkpoint is reloaded and the learning rate is multiplied by
    ``args.lr_decay``; after ``args.max_num_trial`` such restarts the process
    exits with status 0.

    Args:
        args: parsed options; every ``args.<name>`` attribute read below
            must be present.
    """
    train_set = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)
    # Pickles must be read in binary mode (text mode fails on Python 3).
    # NOTE: unpickling can execute arbitrary code; only load trusted vocab files.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    with open(args.asdl_file) as asdl_file:
        grammar = ASDLGrammar.from_text(asdl_file.read())
    transition_system = TransitionSystem.get_class_by_lang(args.lang)(grammar)

    model = Reconstructor(args, vocab, transition_system)
    model.train()
    if args.cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def evaluate_ppl():
        """Return exp(summed negative score / word count) over the dev set."""
        model.eval()
        cum_loss = 0.
        cum_tgt_words = 0.
        for batch in dev_set.batch_iter(args.batch_size):
            loss = -model.score(batch).sum()
            # .item() reads the scalar; [0] on a 0-dim tensor raises on
            # current PyTorch releases
            cum_loss += loss.item()
            cum_tgt_words += sum(len(e.src_sent) + 1 for e in batch)  # add ending </s>

        ppl = np.exp(cum_loss / cum_tgt_words)
        # switch back to training mode for the caller
        model.train()
        return ppl

    print('begin training decoder, %d training examples, %d dev examples' % (len(train_set), len(dev_set)), file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            # keep only examples with at most decode_max_time_step actions
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            loss_val = torch.sum(loss).item()
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient (in-place variant; the un-suffixed name is deprecated)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter,
                       report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        # evaluate ppl
        ppl = evaluate_ppl()
        print('[Epoch %d] ppl=%.5f took %ds' % (epoch, ppl, time.time() - eval_start), file=sys.stderr)
        # negate so that "higher is better" holds for the comparison below
        dev_acc = -ppl
        is_better = not history_dev_scores or dev_acc > max(history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

            # load model (onto CPU first, then moved back to GPU if requested)
            params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # NOTE(review): rebuilt over model.parameters() to match the
                # optimizer created above; this branch previously used
                # model.inference_model.parameters(), an attribute nothing
                # else in this function references -- confirm.
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0