Ejemplo n.º 1
0
    def run(self):
        """Preprocess the attack corpus, then train and evaluate an xnmt
        intent-to-snippet translation model.

        Steps:
          1. Seed ``random`` and ``numpy`` (seed 13).
          2. Tokenize the raw train/test intent ("text") and snippet ("code")
             files with a SentencePiece model trained on the training files,
             then build frequency-filtered vocabularies from the tokenized
             training files.
          3. Build a ``DefaultTranslator``: BiLSTM encoder, MLP attention and
             an LSTM decoder, all of width ``self.layer_dim``; decoding uses
             beam search (beam 5) with ``post_process='join-piece'``.
          4. Train with Adam, tracking dev loss and dev BLEU, then run a
             final BLEU evaluation.
          5. Terminate the process with ``sys.exit()``.

        Reads ``self.vocab_size``, ``self.model_type``, ``self.min_freq``,
        ``self.layer_dim``, ``self.layers``, ``self.alpha`` and
        ``self.epochs``.

        All paths are anchored at the absolute directory containing this
        file, so the run does not depend on the current working directory.
        """
        import sys

        seed = 13
        random.seed(seed)
        np.random.seed(seed)

        # abspath: dirname() of a relative __file__ is "" and would turn
        # every path below into an absolute "/conala-corpus/..." path.
        EXP_DIR = os.path.dirname(os.path.abspath(__file__))
        EXP = "annot"
        corpus_dir = f"{EXP_DIR}/conala-corpus"
        results_dir = f"{EXP_DIR}/results"

        model_file = f"{results_dir}/{EXP}.mod"
        log_file = f"{results_dir}/{EXP}.log"
        # NOTE(review): this configures a fresh DynetParams object that is
        # never applied (no .init() call), so it presumably has no effect on
        # DyNet's memory pool -- the original author noted it doesn't work.
        xnmt.tee.utils.dy.DynetParams().set_mem(1024)
        xnmt.tee.set_out_file(log_file, exp_name=EXP)

        ParamManager.init_param_col()
        ParamManager.param_col.model_file = model_file

        # Raw corpus files. "text" files are tokenized into the *.intent side
        # (model source) and "code" files into the *.snippet side (target).
        code_train = f"{corpus_dir}/attack_code_train.txt"
        text_train = f"{corpus_dir}/attack_text_train.txt"
        code_test = f"{corpus_dir}/attack_code_test.txt"
        text_test = f"{corpus_dir}/attack_text_test.txt"

        # SentencePiece-tokenized corpus files and their vocabularies.
        train_prefix = f"{corpus_dir}/attack-train.tmspm4000"
        test_prefix = f"{corpus_dir}/attack-test.tmspm4000"
        train_intent = f"{train_prefix}.intent"
        train_snippet = f"{train_prefix}.snippet"
        test_intent = f"{test_prefix}.intent"
        test_snippet = f"{test_prefix}.snippet"
        intent_vocab = f"{train_intent}.vocab"
        snippet_vocab = f"{train_snippet}.vocab"

        # The runner object is not used afterwards; the preprocessing
        # presumably happens on construction (its vocab outputs are read
        # immediately below). overwrite=False keeps existing outputs.
        PreprocRunner(
            tasks=[
                PreprocTokenize(
                    in_files=[code_train, text_train, code_test, text_test],
                    out_files=[
                        train_snippet, train_intent, test_snippet, test_intent
                    ],
                    specs=[{
                        'filenum':
                        'all',
                        'tokenizers': [
                            SentencepieceTokenizer(
                                hard_vocab_limit=False,
                                # Trained on the training split only.
                                train_files=[text_train, code_train],
                                vocab_size=self.vocab_size,
                                model_type=self.model_type,
                                # Anchored at EXP_DIR like every other path
                                # (was cwd-relative).
                                model_prefix=f"{train_prefix}.spm")
                        ]
                    }]),
                PreprocVocab(
                    in_files=[train_intent, train_snippet],
                    out_files=[intent_vocab, snippet_vocab],
                    specs=[{
                        'filenum':
                        'all',
                        'filters': [VocabFiltererFreq(min_freq=self.min_freq)]
                    }])
            ],
            overwrite=False)

        src_vocab = Vocab(vocab_file=intent_vocab)
        trg_vocab = Vocab(vocab_file=snippet_vocab)

        inference = AutoRegressiveInference(search_strategy=BeamSearch(
            len_norm=PolynomialNormalization(apply_during_search=True),
            beam_size=5),
                                            post_process='join-piece')

        layer_dim = self.layer_dim

        model = DefaultTranslator(
            src_reader=PlainTextReader(vocab=src_vocab),
            trg_reader=PlainTextReader(vocab=trg_vocab),
            src_embedder=SimpleWordEmbedder(emb_dim=layer_dim,
                                            vocab=src_vocab),
            encoder=BiLSTMSeqTransducer(input_dim=layer_dim,
                                        hidden_dim=layer_dim,
                                        layers=self.layers),
            attender=MlpAttender(hidden_dim=layer_dim,
                                 state_dim=layer_dim,
                                 input_dim=layer_dim),
            trg_embedder=SimpleWordEmbedder(emb_dim=layer_dim,
                                            vocab=trg_vocab),
            decoder=AutoRegressiveDecoder(
                input_dim=layer_dim,
                rnn=UniLSTMSeqTransducer(
                    input_dim=layer_dim,
                    hidden_dim=layer_dim,
                ),
                transform=AuxNonLinear(input_dim=layer_dim,
                                       output_dim=layer_dim,
                                       aux_input_dim=layer_dim),
                scorer=Softmax(vocab_size=len(trg_vocab), input_dim=layer_dim),
                trg_embed_dim=layer_dim,
                input_feeding=False,
                bridge=CopyBridge(dec_dim=layer_dim)),
            inference=inference)

        # NOTE(review): the dev tasks use the test split, so early stopping /
        # LR decay are tuned on the same data as the final evaluation.
        train = SimpleTrainingRegimen(
            name=EXP,
            model=model,
            batcher=WordSrcBatcher(avg_batch_size=64),
            trainer=AdamTrainer(alpha=self.alpha),
            patience=3,
            lr_decay=0.5,
            restart_trainer=True,
            run_for_epochs=self.epochs,
            src_file=train_intent,
            trg_file=train_snippet,
            dev_tasks=[
                LossEvalTask(src_file=test_intent,
                             ref_file=test_snippet,
                             model=model,
                             batcher=WordSrcBatcher(avg_batch_size=64)),
                AccuracyEvalTask(
                    eval_metrics='bleu',
                    src_file=test_intent,
                    # Hypotheses are generated code with pieces joined, so
                    # the reference is the raw test *code*, not the intent
                    # text that was previously used here by mistake.
                    ref_file=code_test,
                    hyp_file=f"{results_dir}/{EXP}.dev.hyp",
                    model=model)
            ])

        evaluate = [
            AccuracyEvalTask(eval_metrics="bleu",
                             src_file=test_intent,
                             # Reference is the raw test code (see dev task).
                             ref_file=code_test,
                             hyp_file=f"{results_dir}/{EXP}.test.hyp",
                             inference=inference,
                             model=model)
        ]

        standard_experiment = Experiment(exp_global=ExpGlobal(
            default_layer_dim=512,
            dropout=0.3,
            log_file=log_file,
            model_file=model_file),
                                         name=EXP,
                                         model=model,
                                         train=train,
                                         evaluate=evaluate)

        # run experiment
        standard_experiment(
            save_fct=lambda: save_to_file(model_file, standard_experiment))

        # sys.exit() raises SystemExit exactly like the ``exit()`` builtin,
        # but does not depend on the interactive ``site`` helpers.
        sys.exit()
Ejemplo n.º 2
0
def read_parallel_corpus(src_reader: InputReader,
                         trg_reader: InputReader,
                         src_file: str,
                         trg_file: str,
                         batcher: batchers.Batcher = None,
                         sample_sents=None,
                         max_num_sents=None,
                         max_src_len=None,
                         max_trg_len=None):
    """Read aligned source/target sentences from a pair of corpus files.

    Args:
      src_reader (InputReader): reader used to parse ``src_file``.
      trg_reader (InputReader): reader used to parse ``trg_file``.
      src_file (str): path of the source-side corpus.
      trg_file (str): path of the target-side corpus.
      batcher (Batcher): if given, ``batcher.pack`` groups the kept pairs
        into batches.
      sample_sents (int): if truthy, randomly choose this many distinct
        sentence indices (``np.random.choice`` without replacement) and read
        only those.
      max_num_sents (int): if truthy, stop once this many pairs have been
        kept; when sampling, it also caps the index range sampled from.
      max_src_len (int): drop a pair whose source ``sent_len()`` exceeds this.
      max_trg_len (int): drop a pair whose target ``sent_len()`` exceeds this.

    Returns:
      A tuple ``(src_data, trg_data, src_batches, trg_batches)``; without a
      batcher the ``*_batches`` entries are the ``*_data`` lists themselves.

    Raises:
      RuntimeError: when a mismatch between the number of source and target
        sentences is detected.
    """
    src_data, trg_data = [], []

    # Defaults for the non-sampling path; the lengths stay 0 there and are
    # only counted lazily if a mismatch is found while reading.
    filter_ids = None
    src_len = trg_len = 0
    if not sample_sents:
        logger.info(f"Starting to read {src_file} and {trg_file}")
    else:
        logger.info(
            f"Starting to read {sample_sents} parallel sentences of {src_file} and {trg_file}"
        )
        src_len = src_reader.count_sents(src_file)
        trg_len = trg_reader.count_sents(trg_file)
        if src_len != trg_len:
            raise RuntimeError(
                f"training src sentences don't match trg sentences: {src_len} != {trg_len}!"
            )
        if max_num_sents and max_num_sents < src_len:
            src_len = trg_len = max_num_sents
        filter_ids = np.random.choice(src_len, sample_sents, replace=False)

    sentence_pairs = zip_longest(src_reader.read_sents(src_file, filter_ids),
                                 trg_reader.read_sents(trg_file, filter_ids))
    for src_sent, trg_sent in sentence_pairs:
        # zip_longest pads the shorter side with None -> files are misaligned.
        if src_sent is None or trg_sent is None:
            raise RuntimeError(
                f"training src sentences don't match trg sentences: {src_len or src_reader.count_sents(src_file)} != {trg_len or trg_reader.count_sents(trg_file)}!"
            )
        if max_num_sents and len(src_data) >= max_num_sents:
            break
        src_fits = max_src_len is None or src_sent.sent_len() <= max_src_len
        trg_fits = max_trg_len is None or trg_sent.sent_len() <= max_trg_len
        if not (src_fits and trg_fits):
            continue
        src_data.append(src_sent)
        trg_data.append(trg_sent)

    logger.info(
        f"Done reading {src_file} and {trg_file}. Packing into batches.")

    # Without a batcher, the "batches" are simply the flat sentence lists.
    src_batches, trg_batches = (batcher.pack(src_data, trg_data)
                                if batcher is not None else
                                (src_data, trg_data))

    logger.info("Done packing batches.")

    return src_data, trg_data, src_batches, trg_batches