Example #1
def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Translate using a pre-trained model')
    parser.add_argument('model', help='a model previously trained with train.py')
    parser.add_argument('--batch_size', type=int, default=50, help='the batch size (defaults to 50)')
    parser.add_argument('--beam_size', type=int, default=12, help='the beam size (defaults to 12, 0 for greedy search)')
    parser.add_argument('--encoding', default='utf-8', help='the character encoding for input/output (defaults to utf-8)')
    parser.add_argument('-i', '--input', default=sys.stdin.fileno(), help='the input file (defaults to stdin)')
    parser.add_argument('-o', '--output', default=sys.stdout.fileno(), help='the output file (defaults to stdout)')
    parser.add_argument('--noise', type=float, default=0.5)
    parser.add_argument('--pass_att', action='store_true', default=False)
    parser.add_argument('--src_embeddings', default=None, help='common intersection source embeddings')
    parser.add_argument('--cutoff', type=int, default=None, help='cutoff for the source embeddings above')
    parser.add_argument('--cat_embedds', help='use torch.load to load src and trg embeddings')
    parser.add_argument('--ncontrol', type=int, default=0, help='control number given while using the decoder')
    args = parser.parse_args()

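    # Load the trained translator; if the checkpoint was saved on a different GPU,
    # retry with a map_location that remaps cuda:1 storages onto cuda:0.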
    try:
        t = torch.load(args.model)
    except Exception:
        t = torch.load(args.model,map_location={'cuda:1':'cuda:0'})

    # Translate sentences
    end = False
    fin = open(args.input, encoding=args.encoding, errors='surrogateescape')
    fout = open(args.output, mode='w', encoding=args.encoding, errors='surrogateescape')
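    # Move the loaded model to the GPU; if external source embeddings were given,
    # build the translator with those embeddings (and their dictionary) on the
    # encoder side, otherwise reuse the embeddings stored in the checkpoint.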
    if args.src_embeddings is not None:
        encoder_embeddings, src_dictionary = data.read_embeddings(open(args.src_embeddings, 'r'), threshold=args.cutoff)
        encoder_embeddings = gpu(encoder_embeddings)
        t.decoder_embeddings = gpu(t.decoder_embeddings)
        t.generator = gpu(t.generator)
        t.encoder = gpu(t.encoder)
        t.decoder = gpu(t.decoder)

        translator_new = Translator(encoder_embeddings, t.decoder_embeddings, t.generator, src_dictionary,
                                    t.trg_dictionary, t.encoder, t.decoder, t.denoising, t.device)
    else:
        t.device = gpu
        t.encoder = gpu(t.encoder)
        t.decoder = gpu(t.decoder)
        t.encoder_embeddings = gpu(t.encoder_embeddings)
        t.decoder_embeddings = gpu(t.decoder_embeddings)
        t.generator = gpu(t.generator)
        t.src_dictionary = data.Dictionary(t.src_dictionary.id2word[1:])
        t.trg_dictionary = data.Dictionary(t.trg_dictionary.id2word[1:])
        translator_new = Translator(t.encoder_embeddings, t.decoder_embeddings, t.generator, t.src_dictionary,
                                    t.trg_dictionary, t.encoder, t.decoder, t.denoising, t.device)
    # print (translator_new.denoising)
    # exit(0)
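    # Read the input in chunks of up to batch_size lines and translate each chunk.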
    while not end:
        batch = []
        while len(batch) < args.batch_size and not end:
            line = fin.readline()
            if not line:
                end = True
            else:
                batch.append(line)
        if args.beam_size <= 0 and len(batch) > 0:
            for translation in translator_new.greedy(batch, train=False):
                print(translation, file=fout)
        elif len(batch) > 0:
            translations = translator_new.beam_search(batch, train=False, beam_size=args.beam_size, max_ratio=2, rnk=6, noiseratio=args.noise, pass_att=args.pass_att, ncontrol=args.ncontrol if args.ncontrol != 0 else None)
            # print(translations)  # leftover debug output; the translations are written to fout below
            if args.pass_att:
                for translation1, trans2 in translations:
                    print(translation1, trans2, file=fout)
            else:
                for translation in translations:
                    print(translation, file=fout)
        fout.flush()
    fin.close()
    fout.close()
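This example moves every component of the loaded model through a gpu(...) helper that is not shown in the snippet (in Example #2 the equivalent role is played by devices.gpu / devices.cpu). A minimal sketch of what such a device helper presumably does, written here only as an assumption:

import torch

def gpu(x):
    # Assumed behaviour of the device helper: move a module or tensor to CUDA when a
    # GPU is available, otherwise leave it on the CPU. The real helper lives in the
    # original repository's devices module.
    return x.cuda() if torch.cuda.is_available() else x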
Example #2
def main_train():
    # Build argument parser
    parser = argparse.ArgumentParser(
        description='Train a neural machine translation model')

    # Training corpus
    corpora_group = parser.add_argument_group(
        'training corpora',
        'Corpora related arguments; specify either monolingual or parallel training corpora (or both)'
    )
    corpora_group.add_argument('--src',
                               help='the source language monolingual corpus')
    corpora_group.add_argument('--trg',
                               help='the target language monolingual corpus')
    corpora_group.add_argument('--src2trg',
                               metavar=('SRC', 'TRG'),
                               nargs=2,
                               help='the source-to-target parallel corpus')
    corpora_group.add_argument('--trg2src',
                               metavar=('TRG', 'SRC'),
                               nargs=2,
                               help='the target-to-source parallel corpus')
    corpora_group.add_argument(
        '--max_sentence_length',
        type=int,
        default=50,
        help='the maximum sentence length for training (defaults to 50)')
    corpora_group.add_argument(
        '--cache',
        type=int,
        default=1000000,
        help=
        'the cache size (in sentences) for corpus reading (defaults to 1000000)'
    )
    corpora_group.add_argument(
        '--cache_parallel',
        type=int,
        default=None,
        help='the cache size (in sentences) for parallel corpus reading')

    # Embeddings/vocabulary
    embedding_group = parser.add_argument_group(
        'embeddings',
        'Embedding related arguments; either give pre-trained cross-lingual embeddings, or a vocabulary and embedding dimensionality to randomly initialize them'
    )
    embedding_group.add_argument('--src_embeddings',
                                 help='the source language word embeddings')
    embedding_group.add_argument('--trg_embeddings',
                                 help='the target language word embeddings')
    embedding_group.add_argument('--src_vocabulary',
                                 help='the source language vocabulary')
    embedding_group.add_argument('--trg_vocabulary',
                                 help='the target language vocabulary')
    embedding_group.add_argument('--embedding_size',
                                 type=int,
                                 default=0,
                                 help='the word embedding size')
    embedding_group.add_argument('--cutoff',
                                 type=int,
                                 default=0,
                                 help='cutoff vocabulary to the given size')
    embedding_group.add_argument(
        '--learn_encoder_embeddings',
        action='store_true',
        help=
        'learn the encoder embeddings instead of using the pre-trained ones')
    embedding_group.add_argument(
        '--fixed_decoder_embeddings',
        action='store_true',
        help=
        'use fixed embeddings in the decoder instead of learning them from scratch'
    )
    embedding_group.add_argument(
        '--fixed_generator',
        action='store_true',
        help=
        'use fixed embeddings in the output softmax instead of learning it from scratch'
    )

    # Architecture
    architecture_group = parser.add_argument_group(
        'architecture', 'Architecture related arguments')
    architecture_group.add_argument(
        '--layers',
        type=int,
        default=2,
        help='the number of encoder/decoder layers (defaults to 2)')
    architecture_group.add_argument(
        '--hidden',
        type=int,
        default=600,
        help='the number of dimensions for the hidden layer (defaults to 600)')
    architecture_group.add_argument('--disable_bidirectional',
                                    action='store_true',
                                    help='use a single direction encoder')
    architecture_group.add_argument('--disable_denoising',
                                    action='store_true',
                                    help='disable random swaps')
    architecture_group.add_argument('--disable_backtranslation',
                                    action='store_true',
                                    help='disable backtranslation')

    # Optimization
    optimization_group = parser.add_argument_group(
        'optimization', 'Optimization related arguments')
    optimization_group.add_argument('--batch',
                                    type=int,
                                    default=50,
                                    help='the batch size (defaults to 50)')
    optimization_group.add_argument(
        '--learning_rate',
        type=float,
        default=0.0002,
        help='the global learning rate (defaults to 0.0002)')
    optimization_group.add_argument(
        '--dropout',
        metavar='PROB',
        type=float,
        default=0.3,
        help='dropout probability for the encoder/decoder (defaults to 0.3)')
    optimization_group.add_argument(
        '--param_init',
        metavar='RANGE',
        type=float,
        default=0.1,
        help=
        'uniform initialization in the specified range (defaults to 0.1,  0 for module specific default initialization)'
    )
    optimization_group.add_argument(
        '--iterations',
        type=int,
        default=300000,
        help='the number of training iterations (defaults to 300000)')

    # Model saving
    saving_group = parser.add_argument_group(
        'model saving', 'Arguments for saving the trained model')
    saving_group.add_argument('--save',
                              metavar='PREFIX',
                              help='save models with the given prefix')
    saving_group.add_argument('--save_interval',
                              type=int,
                              default=0,
                              help='save intermediate models at this interval')

    # Logging/validation
    logging_group = parser.add_argument_group(
        'logging', 'Logging and validation arguments')
    logging_group.add_argument('--log_interval',
                               type=int,
                               default=1000,
                               help='log at this interval (defaults to 1000)')
    logging_group.add_argument('--validation',
                               nargs='+',
                               default=(),
                               help='use parallel corpora for validation')
    logging_group.add_argument(
        '--validation_directions',
        nargs='+',
        default=['src2src', 'trg2trg', 'src2trg', 'trg2src'],
        help='validation directions')
    logging_group.add_argument(
        '--validation_output',
        metavar='PREFIX',
        help='output validation translations with the given prefix')
    logging_group.add_argument('--validation_beam_size',
                               type=int,
                               default=0,
                               help='use beam search for validation')

    # Other
    parser.add_argument(
        '--encoding',
        default='utf-8',
        help='the character encoding for input/output (defaults to utf-8)')
    parser.add_argument('--cuda',
                        default=False,
                        action='store_true',
                        help='use cuda')

    # Parse arguments
    args = parser.parse_args()

    # Validate arguments
    if (args.src_embeddings is None and args.src_vocabulary is None) or (args.trg_embeddings is None and args.trg_vocabulary is None):
        print('Either an embedding or a vocabulary file must be provided')
        sys.exit(-1)
    if (args.src_embeddings is None or args.trg_embeddings is None) and (
            not args.learn_encoder_embeddings or args.fixed_decoder_embeddings
            or args.fixed_generator):
        print(
            'Either provide pre-trained word embeddings or set to learn the encoder/decoder embeddings and generator'
        )
        sys.exit(-1)
    if args.src_embeddings is None and args.trg_embeddings is None and args.embedding_size == 0:
        print(
            'Either provide pre-trained word embeddings or the embedding size')
        sys.exit(-1)
    if len(args.validation) % 2 != 0:
        print(
            '--validation should have an even number of arguments (one pair for each validation set)'
        )
        sys.exit(-1)

    # Select device
    device = devices.gpu if args.cuda else devices.cpu

    # Create optimizer lists
    src2src_optimizers = []
    trg2trg_optimizers = []
    src2trg_optimizers = []
    trg2src_optimizers = []

    # Method to create a module optimizer and add it to the given lists
    def add_optimizer(module, directions=()):
        if args.param_init != 0.0:
            for param in module.parameters():
                param.data.uniform_(-args.param_init, args.param_init)
        optimizer = torch.optim.Adam(module.parameters(),
                                     lr=args.learning_rate)
        for direction in directions:
            direction.append(optimizer)
        return optimizer

    # Load word embeddings
    src_words = trg_words = src_embeddings = trg_embeddings = src_dictionary = trg_dictionary = None
    embedding_size = args.embedding_size
    if args.src_vocabulary is not None:
        f = open(args.src_vocabulary,
                 encoding=args.encoding,
                 errors='surrogateescape')
        src_words = [line.strip() for line in f.readlines()]
        if args.cutoff > 0:
            src_words = src_words[:args.cutoff]
        src_dictionary = data.Dictionary(src_words)
    if args.trg_vocabulary is not None:
        f = open(args.trg_vocabulary,
                 encoding=args.encoding,
                 errors='surrogateescape')
        trg_words = [line.strip() for line in f.readlines()]
        if args.cutoff > 0:
            trg_words = trg_words[:args.cutoff]
        trg_dictionary = data.Dictionary(trg_words)
    if args.src_embeddings is not None:
        f = open(args.src_embeddings,
                 encoding=args.encoding,
                 errors='surrogateescape')
        src_embeddings, src_dictionary = data.read_embeddings(
            f, args.cutoff, src_words)
        src_embeddings = device(src_embeddings)
        src_embeddings.requires_grad = False
        if embedding_size == 0:
            embedding_size = src_embeddings.weight.data.size()[1]
        if embedding_size != src_embeddings.weight.data.size()[1]:
            print('Embedding sizes do not match')
            sys.exit(-1)
    if args.trg_embeddings is not None:
        trg_file = open(args.trg_embeddings,
                        encoding=args.encoding,
                        errors='surrogateescape')
        trg_embeddings, trg_dictionary = data.read_embeddings(
            trg_file, args.cutoff, trg_words)
        trg_embeddings = device(trg_embeddings)
        trg_embeddings.requires_grad = False
        if embedding_size == 0:
            embedding_size = trg_embeddings.weight.data.size()[1]
        if embedding_size != trg_embeddings.weight.data.size()[1]:
            print('Embedding sizes do not match')
            sys.exit(-1)
    if args.learn_encoder_embeddings:
        src_encoder_embeddings = device(
            data.random_embeddings(src_dictionary.size(), embedding_size))
        trg_encoder_embeddings = device(
            data.random_embeddings(trg_dictionary.size(), embedding_size))
        add_optimizer(src_encoder_embeddings,
                      (src2src_optimizers, src2trg_optimizers))
        add_optimizer(trg_encoder_embeddings,
                      (trg2trg_optimizers, trg2src_optimizers))
    else:
        src_encoder_embeddings = src_embeddings
        trg_encoder_embeddings = trg_embeddings
    if args.fixed_decoder_embeddings:
        src_decoder_embeddings = src_embeddings
        trg_decoder_embeddings = trg_embeddings
    else:
        src_decoder_embeddings = device(
            data.random_embeddings(src_dictionary.size(), embedding_size))
        trg_decoder_embeddings = device(
            data.random_embeddings(trg_dictionary.size(), embedding_size))
        add_optimizer(src_decoder_embeddings,
                      (src2src_optimizers, trg2src_optimizers))
        add_optimizer(trg_decoder_embeddings,
                      (trg2trg_optimizers, src2trg_optimizers))
    if args.fixed_generator:
        src_embedding_generator = device(
            EmbeddingGenerator(hidden_size=args.hidden,
                               embedding_size=embedding_size))
        trg_embedding_generator = device(
            EmbeddingGenerator(hidden_size=args.hidden,
                               embedding_size=embedding_size))
        add_optimizer(src_embedding_generator,
                      (src2src_optimizers, trg2src_optimizers))
        add_optimizer(trg_embedding_generator,
                      (trg2trg_optimizers, src2trg_optimizers))
        src_generator = device(
            WrappedEmbeddingGenerator(src_embedding_generator, src_embeddings))
        trg_generator = device(
            WrappedEmbeddingGenerator(trg_embedding_generator, trg_embeddings))
    else:
        src_generator = device(
            LinearGenerator(args.hidden, src_dictionary.size()))
        trg_generator = device(
            LinearGenerator(args.hidden, trg_dictionary.size()))
        add_optimizer(src_generator, (src2src_optimizers, trg2src_optimizers))
        add_optimizer(trg_generator, (trg2trg_optimizers, src2trg_optimizers))

    # Build encoder
    encoder = device(
        RNNEncoder(embedding_size=embedding_size,
                   hidden_size=args.hidden,
                   bidirectional=not args.disable_bidirectional,
                   layers=args.layers,
                   dropout=args.dropout))
    add_optimizer(encoder, (src2src_optimizers, trg2trg_optimizers,
                            src2trg_optimizers, trg2src_optimizers))

    # Build decoders
    src_decoder = device(
        RNNAttentionDecoder(embedding_size=embedding_size,
                            hidden_size=args.hidden,
                            layers=args.layers,
                            dropout=args.dropout))
    trg_decoder = device(
        RNNAttentionDecoder(embedding_size=embedding_size,
                            hidden_size=args.hidden,
                            layers=args.layers,
                            dropout=args.dropout))
    add_optimizer(src_decoder, (src2src_optimizers, trg2src_optimizers))
    add_optimizer(trg_decoder, (trg2trg_optimizers, src2trg_optimizers))

    # Build translators
    src2src_translator = Translator(encoder_embeddings=src_encoder_embeddings,
                                    decoder_embeddings=src_decoder_embeddings,
                                    generator=src_generator,
                                    src_dictionary=src_dictionary,
                                    trg_dictionary=src_dictionary,
                                    encoder=encoder,
                                    decoder=src_decoder,
                                    denoising=not args.disable_denoising,
                                    device=device)
    src2trg_translator = Translator(encoder_embeddings=src_encoder_embeddings,
                                    decoder_embeddings=trg_decoder_embeddings,
                                    generator=trg_generator,
                                    src_dictionary=src_dictionary,
                                    trg_dictionary=trg_dictionary,
                                    encoder=encoder,
                                    decoder=trg_decoder,
                                    denoising=not args.disable_denoising,
                                    device=device)
    trg2trg_translator = Translator(encoder_embeddings=trg_encoder_embeddings,
                                    decoder_embeddings=trg_decoder_embeddings,
                                    generator=trg_generator,
                                    src_dictionary=trg_dictionary,
                                    trg_dictionary=trg_dictionary,
                                    encoder=encoder,
                                    decoder=trg_decoder,
                                    denoising=not args.disable_denoising,
                                    device=device)
    trg2src_translator = Translator(encoder_embeddings=trg_encoder_embeddings,
                                    decoder_embeddings=src_decoder_embeddings,
                                    generator=src_generator,
                                    src_dictionary=trg_dictionary,
                                    trg_dictionary=src_dictionary,
                                    encoder=encoder,
                                    decoder=src_decoder,
                                    denoising=not args.disable_denoising,
                                    device=device)

    # Build trainers
    trainers = []
    src2src_trainer = trg2trg_trainer = src2trg_trainer = trg2src_trainer = None
    srcback2trg_trainer = trgback2src_trainer = None
    if args.src is not None:
        f = open(args.src, encoding=args.encoding, errors='surrogateescape')
        corpus = data.CorpusReader(
            f,
            max_sentence_length=args.max_sentence_length,
            cache_size=args.cache)
        src2src_trainer = Trainer(translator=src2src_translator,
                                  optimizers=src2src_optimizers,
                                  corpus=corpus,
                                  batch_size=args.batch)
        trainers.append(src2src_trainer)
        if not args.disable_backtranslation:
            trgback2src_trainer = Trainer(
                translator=trg2src_translator,
                optimizers=trg2src_optimizers,
                corpus=data.BacktranslatorCorpusReader(
                    corpus=corpus, translator=src2trg_translator),
                batch_size=args.batch)
            trainers.append(trgback2src_trainer)
    if args.trg is not None:
        f = open(args.trg, encoding=args.encoding, errors='surrogateescape')
        corpus = data.CorpusReader(
            f,
            max_sentence_length=args.max_sentence_length,
            cache_size=args.cache)
        trg2trg_trainer = Trainer(translator=trg2trg_translator,
                                  optimizers=trg2trg_optimizers,
                                  corpus=corpus,
                                  batch_size=args.batch,
                                  loss_mult=0.01)
        trainers.append(trg2trg_trainer)
        if not args.disable_backtranslation:
            srcback2trg_trainer = Trainer(
                translator=src2trg_translator,
                optimizers=src2trg_optimizers,
                corpus=data.BacktranslatorCorpusReader(
                    corpus=corpus, translator=trg2src_translator),
                batch_size=args.batch,
                loss_mult=0.01)
            trainers.append(srcback2trg_trainer)
    if args.src2trg is not None:
        f1 = open(args.src2trg[0],
                  encoding=args.encoding,
                  errors='surrogateescape')
        f2 = open(args.src2trg[1],
                  encoding=args.encoding,
                  errors='surrogateescape')
        corpus = data.CorpusReader(
            f1,
            f2,
            max_sentence_length=args.max_sentence_length,
            cache_size=args.cache
            if args.cache_parallel is None else args.cache_parallel)
        src2trg_trainer = Trainer(translator=src2trg_translator,
                                  optimizers=src2trg_optimizers,
                                  corpus=corpus,
                                  batch_size=args.batch)
        trainers.append(src2trg_trainer)
    if args.trg2src is not None:
        f1 = open(args.trg2src[0],
                  encoding=args.encoding,
                  errors='surrogateescape')
        f2 = open(args.trg2src[1],
                  encoding=args.encoding,
                  errors='surrogateescape')
        corpus = data.CorpusReader(
            f1,
            f2,
            max_sentence_length=args.max_sentence_length,
            cache_size=args.cache
            if args.cache_parallel is None else args.cache_parallel)
        trg2src_trainer = Trainer(translator=trg2src_translator,
                                  optimizers=trg2src_optimizers,
                                  corpus=corpus,
                                  batch_size=args.batch)
        trainers.append(trg2src_trainer)

    # Build validators
    src2src_validators = []
    trg2trg_validators = []
    src2trg_validators = []
    trg2src_validators = []
    for i in range(0, len(args.validation), 2):
        src_validation = open(args.validation[i],
                              encoding=args.encoding,
                              errors='surrogateescape').readlines()
        trg_validation = open(args.validation[i + 1],
                              encoding=args.encoding,
                              errors='surrogateescape').readlines()
        if len(src_validation) != len(trg_validation):
            print('Validation sizes do not match')
            sys.exit(-1)
        src_validation = [line.strip() for line in src_validation]
        trg_validation = [line.strip() for line in trg_validation]
        if 'src2src' in args.validation_directions:
            src2src_validators.append(
                Validator(src2src_translator, src_validation, src_validation,
                          args.batch, args.validation_beam_size))
        if 'trg2trg' in args.validation_directions:
            trg2trg_validators.append(
                Validator(trg2trg_translator, trg_validation, trg_validation,
                          args.batch, args.validation_beam_size))
        if 'src2trg' in args.validation_directions:
            src2trg_validators.append(
                Validator(src2trg_translator, src_validation, trg_validation,
                          args.batch, args.validation_beam_size))
        if 'trg2src' in args.validation_directions:
            trg2src_validators.append(
                Validator(trg2src_translator, trg_validation, src_validation,
                          args.batch, args.validation_beam_size))

    # Build loggers
    loggers = []
    src2src_output = trg2trg_output = src2trg_output = trg2src_output = None
    if args.validation_output is not None:
        src2src_output = '{0}.src2src'.format(args.validation_output)
        trg2trg_output = '{0}.trg2trg'.format(args.validation_output)
        src2trg_output = '{0}.src2trg'.format(args.validation_output)
        trg2src_output = '{0}.trg2src'.format(args.validation_output)
    loggers.append(
        Logger('Source to target (backtranslation)', srcback2trg_trainer, [],
               None, args.encoding))
    loggers.append(
        Logger('Target to source (backtranslation)', trgback2src_trainer, [],
               None, args.encoding))
    loggers.append(
        Logger('Source to source', src2src_trainer, src2src_validators,
               src2src_output, args.encoding))
    loggers.append(
        Logger('Target to target', trg2trg_trainer, trg2trg_validators,
               trg2trg_output, args.encoding))
    loggers.append(
        Logger('Source to target', src2trg_trainer, src2trg_validators,
               src2trg_output, args.encoding))
    loggers.append(
        Logger('Target to source', trg2src_trainer, trg2src_validators,
               trg2src_output, args.encoding))

    # Method to save models
    def save_models(name):
        torch.save(src2src_translator,
                   '{0}.{1}.src2src.pth'.format(args.save, name))
        torch.save(trg2trg_translator,
                   '{0}.{1}.trg2trg.pth'.format(args.save, name))
        torch.save(src2trg_translator,
                   '{0}.{1}.src2trg.pth'.format(args.save, name))
        torch.save(trg2src_translator,
                   '{0}.{1}.trg2src.pth'.format(args.save, name))

    # Training
    for step in range(1, args.iterations + 1):
        for trainer in trainers:
            trainer.step()

        if args.save is not None and args.save_interval > 0 and step % args.save_interval == 0:
            save_models('it{0}'.format(step))

        if step % args.log_interval == 0:
            print()
            print('STEP {0} x {1}'.format(step, args.batch))
            for logger in loggers:
                logger.log(step)


    save_models('final')
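The Trainer, Validator and Logger classes used above are not part of this listing, but the optimizer lists explain the shape of a training step: every module that participates in a translation direction registers its Adam optimizer in that direction's list, so one step for that direction zeroes exactly those optimizers, backpropagates the loss of one batch, and updates them. A minimal, generic sketch of that pattern (compute_loss and the function name are illustrative, not the repository's API):

def train_step(compute_loss, optimizers, batch):
    # One generic training step for a single translation direction: reset the
    # optimizers attached to that direction, backpropagate the batch loss, and
    # update every participating module.
    for optimizer in optimizers:
        optimizer.zero_grad()
    loss = compute_loss(batch)  # e.g. the translator's cross-entropy on this batch
    loss.backward()
    for optimizer in optimizers:
        optimizer.step()
    return loss.item()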
Example #3
def main_train():
    # Build argument parser
    parser = argparse.ArgumentParser(description='Train a neural machine translation model')
    parser.add_argument("--config")

    '''
    # Training corpus
    corpora_group = parser.add_argument_group('training corpora', 'Corpora related arguments; specify either monolingual or parallel training corpora (or both)')
    corpora_group.add_argument('--src', help='the source language monolingual corpus')
    corpora_group.add_argument('--trg', help='the target language monolingual corpus')
    corpora_group.add_argument('--src2trg', metavar=('SRC', 'TRG'), nargs=2, help='the source-to-target parallel corpus')
    corpora_group.add_argument('--trg2src', metavar=('TRG', 'SRC'), nargs=2, help='the target-to-source parallel corpus')
    corpora_group.add_argument('--max_sentence_length', type=int, default=50, help='the maximum sentence length for training (defaults to 50)')
    corpora_group.add_argument('--cache', type=int, default=1000000, help='the cache size (in sentences) for corpus reading (defaults to 1000000)')
    corpora_group.add_argument('--cache_parallel', type=int, default=None, help='the cache size (in sentences) for parallel corpus reading')

    # Embeddings/vocabulary
    embedding_group = parser.add_argument_group('embeddings', 'Embedding related arguments; either give pre-trained cross-lingual embeddings, or a vocabulary and embedding dimensionality to randomly initialize them')
    embedding_group.add_argument('--src_embeddings', help='the source language word embeddings')
    embedding_group.add_argument('--trg_embeddings', help='the target language word embeddings')
    embedding_group.add_argument('--src_vocabulary', help='the source language vocabulary')
    embedding_group.add_argument('--trg_vocabulary', help='the target language vocabulary')
    embedding_group.add_argument('--embedding_size', type=int, default=0, help='the word embedding size')
    embedding_group.add_argument('--cutoff', type=int, default=0, help='cutoff vocabulary to the given size')
    embedding_group.add_argument('--learn_encoder_embeddings', action='store_true', help='learn the encoder embeddings instead of using the pre-trained ones')
    embedding_group.add_argument('--fixed_decoder_embeddings', action='store_true', help='use fixed embeddings in the decoder instead of learning them from scratch')
    embedding_group.add_argument('--fixed_generator', action='store_true', help='use fixed embeddings in the output softmax instead of learning it from scratch')

    # Architecture
    architecture_group = parser.add_argument_group('architecture', 'Architecture related arguments')
    architecture_group.add_argument('--layers', type=int, default=2, help='the number of encoder/decoder layers (defaults to 2)')
    architecture_group.add_argument('--hidden', type=int, default=600, help='the number of dimensions for the hidden layer (defaults to 600)')
    architecture_group.add_argument('--disable_bidirectional', action='store_true', help='use a single direction encoder')
    architecture_group.add_argument('--disable_denoising', action='store_true', help='disable random swaps')
    architecture_group.add_argument('--disable_backtranslation', action='store_true', help='disable backtranslation')

    # Optimization
    optimization_group = parser.add_argument_group('optimization', 'Optimization related arguments')
    optimization_group.add_argument('--batch', type=int, default=50, help='the batch size (defaults to 50)')
    optimization_group.add_argument('--learning_rate', type=float, default=0.0002, help='the global learning rate (defaults to 0.0002)')
    optimization_group.add_argument('--dropout', metavar='PROB', type=float, default=0.3, help='dropout probability for the encoder/decoder (defaults to 0.3)')
    optimization_group.add_argument('--param_init', metavar='RANGE', type=float, default=0.1, help='uniform initialization in the specified range (defaults to 0.1,  0 for module specific default initialization)')
    optimization_group.add_argument('--iterations', type=int, default=300000, help='the number of training iterations (defaults to 300000)')

    # Model saving
    saving_group = parser.add_argument_group('model saving', 'Arguments for saving the trained model')
    saving_group.add_argument('--save', metavar='PREFIX', help='save models with the given prefix')
    saving_group.add_argument('--save_interval', type=int, default=0, help='save intermediate models at this interval')
    saving_group.add_argument('--save_train_interval', type=int, default=0, help='save intermediate trainers at this interval')

    # Model loading
    loading_group = parser.add_argument_group('model loading', 'Arguments for loading the trained model')
    loading_group.add_argument('--load', metavar='path', help='load models with the path')
    loading_group.add_argument('--is_load_model', type=bool, default=False, help='whether load the model')
    loading_group.add_argument('--is_load_trainers', type=bool, default=False, help='whether load the trainers')

    # loading_group.add_argument('--load-iter', type=int, default=0, help='iter load the model')
    # loading_group.add_argument('--load-epoch', type=int, default=1, help='epoch save the model')



    # Logging/validation
    logging_group = parser.add_argument_group('logging', 'Logging and validation arguments')
    logging_group.add_argument('--log_interval', type=int, default=1000, help='log at this interval (defaults to 1000)')
    logging_group.add_argument('--validation', nargs='+', default=(), help='use parallel corpora for validation')
    logging_group.add_argument('--validation_directions', nargs='+', default=['src2src', 'trg2trg', 'src2trg', 'trg2src'], help='validation directions')
    logging_group.add_argument('--validation_output', metavar='PREFIX', help='output validation translations with the given prefix')
    logging_group.add_argument('--validation_beam_size', type=int, default=0, help='use beam search for validation')

    # Other
    parser.add_argument('--encoding', default='utf-8', help='the character encoding for input/output (defaults to utf-8)')
    parser.add_argument('--cuda', default=False, action='store_true', help='use cuda')
    
    # Parse arguments
    args = parser.parse_args()
    '''

    args = parser.parse_args()
    args = load_config(args.config)

    # Improve performance: let cuDNN benchmark and select the fastest kernels
    torch.backends.cudnn.benchmark = True

    print(args)



    # Validate arguments
    if (args.src_embeddings is None and args.src_vocabulary is None) or (args.trg_embeddings is None and args.trg_vocabulary is None):
        print('Either an embedding or a vocabulary file must be provided')
        sys.exit(-1)
    if (args.src_embeddings is None or args.trg_embeddings is None) and (not args.learn_encoder_embeddings or args.fixed_decoder_embeddings or args.fixed_generator):
        print('Either provide pre-trained word embeddings or set to learn the encoder/decoder embeddings and generator')
        sys.exit(-1)
    if args.src_embeddings is None and args.trg_embeddings is None and args.embedding_size == 0:
        print('Either provide pre-trained word embeddings or the embedding size')
        sys.exit(-1)
    if len(args.validation) % 2 != 0:
        print('--validation should have an even number of arguments (one pair for each validation set)')
        sys.exit(-1)

    # Select device
    device = devices.gpu if args.cuda else devices.cpu

    # Load from checkpoint: restore the saved optimizer lists, training step and the
    # four translators if is_load_model is set; otherwise build everything from scratch.
    if args.is_load_model:
        print('loading models from checkpoint')

        train = torch.load(os.path.join('trainers', 'trainers.pth'))
        src2src_optimizers = train['src2src_optimizers']
        trg2trg_optimizers = train['trg2trg_optimizers']
        src2trg_optimizers = train['src2trg_optimizers']
        trg2src_optimizers = train['trg2src_optimizers']
        step = train['step']

        src2src_translator = torch.load(os.path.join('models-ckpt', 'src2src.pth'))
        trg2trg_translator = torch.load(os.path.join('models-ckpt', 'trg2trg.pth'))
        src2trg_translator = torch.load(os.path.join('models-ckpt', 'src2trg.pth'))
        trg2src_translator = torch.load(os.path.join('models-ckpt', 'trg2src.pth'))

    else:
        # Create optimizer lists
        src2src_optimizers = []
        trg2trg_optimizers = []
        src2trg_optimizers = []
        trg2src_optimizers = []

        # Method to create a module optimizer and add it to the given lists
        def add_optimizer(module, directions=()):
            if args.param_init != 0.0:
                for param in module.parameters():
                    param.data.uniform_(-args.param_init, args.param_init)
            optimizer = torch.optim.Adam(module.parameters(), lr=args.learning_rate)
            for direction in directions:
                direction.append(optimizer)
            return optimizer

        # Load word embeddings
        src_words = trg_words = src_embeddings = trg_embeddings = src_dictionary = trg_dictionary = None
        embedding_size = args.embedding_size
        if args.src_vocabulary is not None:
            f = open(args.src_vocabulary, encoding=args.encoding, errors='surrogateescape')
            src_words = [line.strip() for line in f.readlines()]
            if args.cutoff > 0:
                src_words = src_words[:args.cutoff]
            src_dictionary = data.Dictionary(src_words)
        if args.trg_vocabulary is not None:
            f = open(args.trg_vocabulary, encoding=args.encoding, errors='surrogateescape')
            trg_words = [line.strip() for line in f.readlines()]
            if args.cutoff > 0:
                trg_words = trg_words[:args.cutoff]
            trg_dictionary = data.Dictionary(trg_words)
        if args.src_embeddings is not None:
            f = open(args.src_embeddings, encoding=args.encoding, errors='surrogateescape')
            src_embeddings, src_dictionary = data.read_embeddings(f, args.cutoff, src_words)
            src_embeddings = device(src_embeddings)
            src_embeddings.requires_grad = False
            if embedding_size == 0:
                embedding_size = src_embeddings.weight.data.size()[1]
            if embedding_size != src_embeddings.weight.data.size()[1]:
                print('Embedding sizes do not match: %s' % src_embeddings.weight.data.size()[1])
                sys.exit(-1)
        if args.trg_embeddings is not None:
            trg_file = open(args.trg_embeddings, encoding=args.encoding, errors='surrogateescape')
            trg_embeddings, trg_dictionary = data.read_embeddings(trg_file, args.cutoff, trg_words)
            trg_embeddings = device(trg_embeddings)
            trg_embeddings.requires_grad = False
            if embedding_size == 0:
                embedding_size = trg_embeddings.weight.data.size()[1]
            if embedding_size != trg_embeddings.weight.data.size()[1]:
                print('Embedding sizes do not match: %s' % trg_embeddings.weight.data.size()[1])
                sys.exit(-1)
        if args.learn_encoder_embeddings:
            src_encoder_embeddings = device(data.random_embeddings(src_dictionary.size(), embedding_size))
            trg_encoder_embeddings = device(data.random_embeddings(trg_dictionary.size(), embedding_size))
            add_optimizer(src_encoder_embeddings, (src2src_optimizers, src2trg_optimizers))
            add_optimizer(trg_encoder_embeddings, (trg2trg_optimizers, trg2src_optimizers))
        else:
            src_encoder_embeddings = src_embeddings
            trg_encoder_embeddings = trg_embeddings
        if args.fixed_decoder_embeddings:
            src_decoder_embeddings = src_embeddings
            trg_decoder_embeddings = trg_embeddings
        else:
            src_decoder_embeddings = device(data.random_embeddings(src_dictionary.size(), embedding_size))
            trg_decoder_embeddings = device(data.random_embeddings(trg_dictionary.size(), embedding_size))
            add_optimizer(src_decoder_embeddings, (src2src_optimizers, trg2src_optimizers))
            add_optimizer(trg_decoder_embeddings, (trg2trg_optimizers, src2trg_optimizers))
        if args.fixed_generator:
            src_embedding_generator = device(EmbeddingGenerator(hidden_size=args.hidden, embedding_size=embedding_size))
            trg_embedding_generator = device(EmbeddingGenerator(hidden_size=args.hidden, embedding_size=embedding_size))
            add_optimizer(src_embedding_generator, (src2src_optimizers, trg2src_optimizers))
            add_optimizer(trg_embedding_generator, (trg2trg_optimizers, src2trg_optimizers))
            src_generator = device(WrappedEmbeddingGenerator(src_embedding_generator, src_embeddings))
            trg_generator = device(WrappedEmbeddingGenerator(trg_embedding_generator, trg_embeddings))
        else:
            src_generator = device(LinearGenerator(args.hidden, src_dictionary.size()))
            trg_generator = device(LinearGenerator(args.hidden, trg_dictionary.size()))
            add_optimizer(src_generator, (src2src_optimizers, trg2src_optimizers))
            add_optimizer(trg_generator, (trg2trg_optimizers, src2trg_optimizers))

        # Build encoder
        encoder = device(RNNEncoder(embedding_size=embedding_size, hidden_size=args.hidden,
                                    bidirectional=not args.disable_bidirectional, layers=args.layers, dropout=args.dropout))
        add_optimizer(encoder, (src2src_optimizers, trg2trg_optimizers, src2trg_optimizers, trg2src_optimizers))

        # Build decoders
        src_decoder = device(RNNAttentionDecoder(embedding_size=embedding_size, hidden_size=args.hidden, layers=args.layers, dropout=args.dropout))
        trg_decoder = device(RNNAttentionDecoder(embedding_size=embedding_size, hidden_size=args.hidden, layers=args.layers, dropout=args.dropout))
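        # Share the target decoder's attention module and recurrent net with the
        # source decoder (the reverse direction is left commented out below).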
        src_decoder._share(trg_decoder.get_attention(), trg_decoder.get_net(), update_attnF=True)
        # trg_decoder._share(src_decoder.get_attention(), src_decoder.get_net(), update_attnF = True)
        add_optimizer(src_decoder, (src2src_optimizers, trg2src_optimizers))
        add_optimizer(trg_decoder, (trg2trg_optimizers, src2trg_optimizers))



        # Build translators
        src2src_translator = Translator(encoder_embeddings=src_encoder_embeddings,
                                        decoder_embeddings=src_decoder_embeddings, generator=src_generator,
                                        src_dictionary=src_dictionary, trg_dictionary=src_dictionary, encoder=encoder,
                                        decoder=src_decoder, denoising=not args.disable_denoising, device=device)
        src2trg_translator = Translator(encoder_embeddings=src_encoder_embeddings,
                                        decoder_embeddings=trg_decoder_embeddings, generator=trg_generator,
                                        src_dictionary=src_dictionary, trg_dictionary=trg_dictionary, encoder=encoder,
                                        decoder=trg_decoder, denoising=not args.disable_denoising, device=device)
        trg2trg_translator = Translator(encoder_embeddings=trg_encoder_embeddings,
                                        decoder_embeddings=trg_decoder_embeddings, generator=trg_generator,
                                        src_dictionary=trg_dictionary, trg_dictionary=trg_dictionary, encoder=encoder,
                                        decoder=trg_decoder, denoising=not args.disable_denoising, device=device)
        trg2src_translator = Translator(encoder_embeddings=trg_encoder_embeddings,
                                        decoder_embeddings=src_decoder_embeddings, generator=src_generator,
                                        src_dictionary=trg_dictionary, trg_dictionary=src_dictionary, encoder=encoder,
                                        decoder=src_decoder, denoising=not args.disable_denoising, device=device)
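Example #3 drives the same training code from a configuration file through a load_config helper that is not included in this listing. A minimal sketch of such a helper, assuming a YAML file whose keys mirror the command-line options of Example #2 (the file format and the Namespace conversion are assumptions, not the repository's actual implementation):

import argparse
import yaml  # assumption: PyYAML; the real config format is not shown in this listing

def load_config(path):
    # Hypothetical stand-in for the load_config helper used above: read a YAML
    # mapping of option names to values and expose it with attribute access so
    # main_train() can keep using args.<option>.
    with open(path) as f:
        options = yaml.safe_load(f)
    return argparse.Namespace(**options)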