np.max(data_test_lengths, axis=0)],
        axis=0)

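# Use the command-line length limits when given; otherwise fall back to the
# longest source/target sequences observed across the data splits.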
if args.src_max_len > 0:
    src_max_len = args.src_max_len
else:
    src_max_len = max_len[0]
if args.tgt_max_len > 0:
    tgt_max_len = args.tgt_max_len
else:
    tgt_max_len = max_len[1]

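# Build the Transformer encoder/decoder pair; positional encodings are sized to
# cover at least 500 positions on both the source and target side.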
encoder, decoder = get_transformer_encoder_decoder(units=args.num_units,
                                                   hidden_size=args.hidden_size,
                                                   dropout=args.dropout,
                                                   num_layers=args.num_layers,
                                                   num_heads=args.num_heads,
                                                   max_src_length=max(src_max_len, 500),
                                                   max_tgt_length=max(tgt_max_len, 500),
                                                   scaled=args.scaled)
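# Wrap the encoder and decoder in an NMTModel; source/target embeddings are
# shared and tied to the output projection except for the TOY dataset.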
model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
                 share_embed=args.dataset != 'TOY', embed_size=args.num_units,
                 tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_')

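# Fetch the pretrained WMT2014 en-de parameters if they are not already on disk.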
param_name = args.model_parameter
if not os.path.exists(param_name):
    archive_param_url = 'http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/{}'
    archive_file_hash = ('transformer_en_de_512_WMT2014-97ffd554a.zip',
                         'c182aae397ead66cc91f1bf241ce07a91884c869')
    param_file_hash = ('transformer_en_de_512_WMT2014-97ffd554a.params',
                       '97ffd554aac1f4ba2c5a99483543f47440bd9738')
    archive_file, archive_hash = archive_file_hash
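    # The example is truncated at this point. What follows is a minimal,
    # hypothetical sketch of the remaining steps, assuming mxnet.gluon.utils.download
    # and the standard zipfile module; the destination path and the final
    # assignment to param_name are assumptions, not part of the original snippet.
    import zipfile
    from mxnet.gluon.utils import download
    param_file, param_hash = param_file_hash
    # Download the archive, verifying it against the published SHA-1 hash.
    archive_path = download(archive_param_url.format(archive_file),
                            path='.', sha1_hash=archive_hash)
    # Unpack the .params file next to the script and point param_name at it.
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall('.')
    param_name = param_file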
Example #2
max_len = np.max(
    [np.max(data_train_lengths, axis=0), np.max(data_val_lengths, axis=0),
     np.max(data_test_lengths, axis=0)],
    axis=0)
if args.src_max_len > 0:
    src_max_len = args.src_max_len
else:
    src_max_len = max_len[0]
if args.tgt_max_len > 0:
    tgt_max_len = args.tgt_max_len
else:
    tgt_max_len = max_len[1]
encoder, decoder = get_transformer_encoder_decoder(units=args.num_units,
                                                   hidden_size=args.hidden_size,
                                                   dropout=args.dropout,
                                                   num_layers=args.num_layers,
                                                   num_heads=args.num_heads,
                                                   max_src_length=max(src_max_len, 500),
                                                   max_tgt_length=max(tgt_max_len, 500),
                                                   scaled=args.scaled)
model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
                 share_embed=args.dataset != 'TOY', embed_size=args.num_units,
                 tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_')
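# Xavier-initialize the parameters on the target context, then hybridize with
# static memory allocation for faster repeated execution.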
model.initialize(init=mx.init.Xavier(magnitude=args.magnitude), ctx=ctx)
static_alloc = True
model.hybridize(static_alloc=static_alloc)
logging.info(model)

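# Beam-search decoder; alpha and K of the scorer control the length penalty.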
translator = BeamSearchTranslator(model=model, beam_size=args.beam_size,
                                  scorer=nlp.model.BeamSearchScorer(alpha=args.lp_alpha,
                                                                    K=args.lp_k),
                                  max_length=200)