Esempio n. 1
0
def prepare_model(model_file, batch_size=1):
  """Build the gin-configured Trax model, load its weights and the vocab.

  The model constructor is read from the gin binding
  ``trax.supervised.trainer_lib.train.model``. It is instantiated in
  ``'eval'`` mode when ``FLAGS.use_eval_mode`` is set, and in ``'predict'``
  mode otherwise.

  Args:
    model_file: Checkpoint passed to ``model.init_from_file``. Only the
      weights are restored (``weights_only=True``).
    batch_size: Global batch size. It is floor-divided by
      ``jax.local_device_count()``, with a minimum of 1, to get the
      per-device batch used in the initialization input signature.

  Returns:
    A ``(vocab, model, initial_state)`` tuple containing:
    - the SentencePiece vocabulary loaded from ``data.DEFAULT_SPM_PATH``;
    - the model wrapped in ``tl.Accelerate``;
    - the model's ``state`` captured right after initialization.
  """
  mode = 'eval' if FLAGS.use_eval_mode else 'predict'
  print(f'Initializing the model in {mode} mode.', flush=True)

  # Resolve the model constructor bound in the gin file and build it in `mode`.
  model_reference = gin.query_parameter(
      'trax.supervised.trainer_lib.train.model')
  model = model_reference.scoped_configurable_fn(mode=mode)

  # Sequence length of the init signature: 32 in eval mode, 1 in predict mode.
  # NOTE(review): presumably predict mode feeds one token per decode step -- confirm.
  dec_len = 32 if FLAGS.use_eval_mode else 1
  # Per-device batch: floor division, but never fewer than 1 example.
  batch_size_pd = max(1, batch_size // jax.local_device_count())
  token_signature = shapes.ShapeDtype((batch_size_pd, dec_len), dtype=np.int32)
  # The same int32 signature is used for both model inputs
  # (presumably inputs and targets -- confirm against the model definition).
  model.init_from_file(
      model_file, weights_only=True,
      input_signature=(token_signature, token_signature))
  model = tl.Accelerate(model)

  initial_state = model.state
  vocab = t5_spc_vocab.SentencePieceVocabulary(data.DEFAULT_SPM_PATH)

  return vocab, model, initial_state
Esempio n. 2
0
def sentencepiece_vocab(extra_ids=0):
    """Return the SentencePiece vocabulary stored in the test data directory.

    Args:
      extra_ids: Value forwarded as ``extra_ids`` to the vocabulary
        constructor.

    Returns:
      A ``vocabularies.SentencePieceVocabulary`` built from
      ``TEST_DATA_DIR/sentencepiece/sentencepiece.model``.
    """
    model_path = os.path.join(
        TEST_DATA_DIR, "sentencepiece", "sentencepiece.model")
    return vocabularies.SentencePieceVocabulary(model_path, extra_ids=extra_ids)