Example #1
    def __init__(self, cfg, word_vocab_size, pos_vocab_size, dep_vocab_size):
        super().__init__()
        self.transformer = models.Transformer(cfg)

        #logits_pos
        self.fc1 = nn.Linear(cfg.dim, cfg.dim)
        self.activ1 = models.gelu
        self.norm1 = models.LayerNorm(cfg)
        self.decoder1 = nn.Linear(cfg.dim, pos_vocab_size)

        #logits_dep
        self.fc2 = nn.Linear(cfg.dim, cfg.dim)
        self.activ2 = models.gelu
        self.norm2 = models.LayerNorm(cfg)
        self.decoder2 = nn.Linear(cfg.dim, dep_vocab_size)

        #logits_word_vocab_size
        self.fc3 = nn.Linear(cfg.dim, cfg.dim)
        self.activ3 = models.gelu
        self.norm3 = models.LayerNorm(cfg)
        embed_weight = self.transformer.embed.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder3 = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder3.weight = embed_weight
        self.decoder3_bias = nn.Parameter(torch.zeros(n_vocab))
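
The head above ties the word-LM decoder to the token embedding matrix and adds a separate learned bias (the tied Linear must be bias-free). A minimal self-contained sketch of the same pattern, with illustrative names and sizes:

import torch
import torch.nn as nn

class TiedLMHead(nn.Module):
    """Output projection that shares its weight with the token embedding."""
    def __init__(self, n_vocab=100, n_dim=32):
        super().__init__()
        self.tok_embed = nn.Embedding(n_vocab, n_dim)
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = self.tok_embed.weight  # one shared tensor
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, h):  # h: [batch, seq, n_dim]
        return self.decoder(h) + self.decoder_bias  # [batch, seq, n_vocab]

head = TiedLMHead()
logits = head(torch.randn(2, 5, 32))
assert head.decoder.weight.data_ptr() == head.tok_embed.weight.data_ptr()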
Example #2
    def __init__(self, cfg):
        super().__init__()
        self.transformer = models.Transformer(cfg)
        self.fc = nn.Linear(cfg.hidden, cfg.hidden)
        self.activ1 = nn.Tanh()
        self.linear = nn.Linear(cfg.hidden, cfg.hidden)
        self.activ2 = models.gelu
        self.norm = models.LayerNorm(cfg)
        self.classifier = nn.Linear(cfg.hidden, 2)

        # decoder is shared with embedding layer
        ## project hidden layer to embedding layer
        embed_weight2 = self.transformer.embed.tok_embed2.weight
        n_hidden, n_embedding = embed_weight2.size()
        self.decoder1 = nn.Linear(n_hidden, n_embedding, bias=False)
        self.decoder1.weight.data = embed_weight2.data.t()

        ## project embedding layer to vocabulary layer
        embed_weight1 = self.transformer.embed.tok_embed1.weight
        n_vocab, n_embedding = embed_weight1.size()
        self.decoder2 = nn.Linear(n_embedding, n_vocab, bias=False)
        self.decoder2.weight = embed_weight1

        # self.tok_embed1 = nn.Embedding(cfg.vocab_size, cfg.embedding)
        # self.tok_embed2 = nn.Linear(cfg.embedding, cfg.hidden)

        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))  # learned output bias for the weight-tied decoder (the tied Linear layers above are bias=False)
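
This head decodes through a factorized (ALBERT-style) embedding: decoder1 undoes the embedding-to-hidden projection (hence the transposed weight) and decoder2 maps back to the vocabulary with the embedding matrix itself, while decoder_bias supplies the per-token output bias the two bias-free tied layers lack. A sketch with illustrative sizes:

import torch
import torch.nn as nn

vocab, emb, hidden = 100, 16, 32
tok_embed1 = nn.Embedding(vocab, emb)  # ids -> embedding, weight: [vocab, emb]
tok_embed2 = nn.Linear(emb, hidden)    # embedding -> hidden, weight: [hidden, emb]

decoder1 = nn.Linear(hidden, emb, bias=False)      # hidden -> embedding
decoder1.weight.data = tok_embed2.weight.data.t()  # [emb, hidden]
decoder2 = nn.Linear(emb, vocab, bias=False)       # embedding -> vocab
decoder2.weight = tok_embed1.weight                # [vocab, emb], shared
decoder_bias = nn.Parameter(torch.zeros(vocab))

h = torch.randn(2, 5, hidden)
logits = decoder2(decoder1(h)) + decoder_bias      # [2, 5, vocab]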
Example #3
def predict_step(inputs,
                 params,
                 cache,
                 eos_id,
                 max_decode_len,
                 config,
                 beam_size=4):
    """Predict translation with fast decoding beam search on a batch."""
    batch_size = inputs.shape[0]

    # Prepare transformer fast-decoder call for beam search: for beam search, we
    # need to set up our decoder model to handle a batch size equal to
    # batch_size * beam_size, where each batch item's data is expanded in-place
    # rather than tiled.
    # i.e. if we denote each batch element subtensor as el[n]:
    # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2]
    src_padding_mask = decode.flat_batch_beam_expand((inputs > 0)[..., None],
                                                     beam_size)
    tgt_padding_mask = decode.flat_batch_beam_expand(
        jnp.ones((batch_size, 1, 1)), beam_size)
    encoded_inputs = decode.flat_batch_beam_expand(
        models.Transformer(config).apply({'params': params},
                                         inputs,
                                         method=models.Transformer.encode),
        beam_size)

    def tokens_ids_to_logits(flat_ids, flat_cache):
        """Token slice to logits from decoder model."""
        # --> [batch * beam, 1, vocab]
        flat_logits, new_vars = models.Transformer(config).apply(
            {
                'params': params,
                'cache': flat_cache
            },
            encoded_inputs,
            src_padding_mask,
            flat_ids,
            tgt_padding_mask=tgt_padding_mask,
            mutable=['cache'],
            method=models.Transformer.decode)
        new_flat_cache = new_vars['cache']
        # Remove singleton sequence-length dimension:
        # [batch * beam, 1, vocab] --> [batch * beam, vocab]
        flat_logits = flat_logits.squeeze(axis=1)
        return flat_logits, new_flat_cache

    # Using the above-defined single-step decoder function, run a
    # beam search over possible sequences given input encoding.
    beam_seqs, _ = decode.beam_search(inputs,
                                      cache,
                                      tokens_ids_to_logits,
                                      beam_size=beam_size,
                                      alpha=0.6,
                                      eos_id=eos_id,
                                      max_decode_len=max_decode_len)

    # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension
    # sorted in increasing order of log-probability.
    # Return the highest scoring beam sequence, drop first dummy 0 token.
    return beam_seqs[:, -1, 1:]
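
decode.flat_batch_beam_expand is assumed to implement exactly the in-place expansion described in the comment above, which matches jnp.repeat along the batch axis; a quick sketch:

import jax.numpy as jnp

def flat_batch_beam_expand_sketch(x, beam_size):
    # [batch, ...] -> [batch * beam_size, ...]:
    # [el0, el1] --> beam_size=2 --> [el0, el0, el1, el1]
    return jnp.repeat(x, beam_size, axis=0)

print(flat_batch_beam_expand_sketch(jnp.array([[1], [2]]), 2))
# [[1] [1] [2] [2]]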
Example #4
def eval_step(params, batch, config, label_smoothing=0.0):
  """Calculate evaluation metrics on a batch."""
  inputs, targets = batch["inputs"], batch["targets"]
  weights = jnp.where(targets > 0, 1.0, 0.0)
  logits = models.Transformer(config).apply({"params": params}, inputs, targets)

  return compute_metrics(logits, targets, weights, label_smoothing)
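
The weights array masks out padding (token id 0) so padded positions contribute nothing to the metrics; for example:

import jax.numpy as jnp

targets = jnp.array([[5, 9, 2, 0, 0]])      # two trailing pad tokens
weights = jnp.where(targets > 0, 1.0, 0.0)  # [[1., 1., 1., 0., 0.]]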
Example #5
File: train.py Project: wrzadkow/flax
def initialize_cache(inputs, max_decode_len, config):
    """Initialize a cache for a given input shape and max decode length."""
    target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
    initial_variables = models.Transformer(config).init(
        jax.random.PRNGKey(0), jnp.ones(inputs.shape, config.dtype),
        jnp.ones(target_shape, config.dtype))
    return initial_variables['cache']
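
The shape arithmetic keeps the batch (and any trailing feature) dimensions of inputs while replacing the sequence length with max_decode_len; for example:

inputs_shape = (8, 37)  # [batch, input_len]
max_decode_len = 64
target_shape = (inputs_shape[0], max_decode_len) + inputs_shape[2:]
print(target_shape)     # (8, 64)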
Example #6
 def __init__(self, cfg, n_labels):
     super().__init__()
     self.transformer = models.Transformer(cfg)
     self.fc = nn.Linear(cfg.dim, cfg.dim)
     self.activ = nn.Tanh()
     self.drop = nn.Dropout(cfg.p_drop_hidden)
     self.classifier = nn.Linear(cfg.dim, n_labels)
Example #7
    def __init__(self, cfg):
        super().__init__()
        self.transformer = models.Transformer(cfg)

        #logits_sentence_clsf
        self.fc = nn.Linear(cfg.dim, cfg.dim)
        self.activ1 = nn.Tanh()
        self.classifier = nn.Linear(cfg.dim, 2)

        #logits_paragraph_clsf
        '''
        self.fc = nn.Linear(cfg.dim, 2)
        self.activ1 = nn.Tanh()
        self.norm1 = models.LayerNorm(cfg)
        self.drop = nn.Dropout(cfg.p_drop_hidden)
        self.classifier = nn.Linear(cfg.max_len * 2, 2)
        '''

        #logits_lm
        self.linear = nn.Linear(cfg.dim, cfg.dim)
        self.activ2 = models.gelu
        self.norm2 = models.LayerNorm(cfg)
        # decoder is shared with embedding layer
        embed_weight = self.transformer.embed.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

        #logits_same
        self.linear2 = nn.Linear(cfg.dim, cfg.vocab_size)
Example #8
 def __init__(self, cfg, n_labels):
     super().__init__()
     self.transformer = models.Transformer(cfg)
     self.fc = nn.Linear(cfg.hidden, cfg.hidden)
     self.activ = nn.ReLU()
     self.drop = nn.Dropout(cfg.p_drop_hidden)
     self.pool = nn.AdaptiveMaxPool1d(1)
     self.classifier = nn.Linear(cfg.hidden, n_labels)
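
nn.AdaptiveMaxPool1d(1) pools over the last axis and expects [batch, channels, length], so the forward pass presumably transposes the hidden states before pooling; a sketch of that step (the forward method is not shown in the snippet):

import torch
import torch.nn as nn

h = torch.randn(2, 5, 32)                     # [batch, seq, hidden]
pool = nn.AdaptiveMaxPool1d(1)
pooled = pool(h.transpose(1, 2)).squeeze(-1)  # [batch, hidden]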
Example #9
def single_translate(english_text):
    """ 只输入一句话,即batch_size==1 """
    conf = configuration.Config()
    tokenizer = tokenization.FullTokenizer(en_vocab_file=os.path.join(conf.file_config.data_path,
                                                                      conf.file_config.en_vocab),
                                           zh_vocab_file=os.path.join(conf.file_config.data_path,
                                                                      conf.file_config.zh_vocab))
    conf.model_config.src_vocab_size = tokenizer.get_en_vocab_size() + 2
    conf.model_config.trg_vocab_size = tokenizer.get_zh_vocab_size() + 2
    model = models.Transformer(conf)
    model.to(device)

    translate_dataset = datasets.Translate(mode='single_translate',
                                           config=conf,
                                           tokenizer=tokenizer,
                                           texts=english_text)
    # The encoder input is just the raw English token ids
    en_ids, _ = translate_dataset[0]
    en_ids = torch.tensor(en_ids).unsqueeze(dim=0)
    # The decoder's initial input is <BOS>
    decoder_input = torch.tensor([tokenizer.get_zh_vocab_size()]).view(1, 1)   # [1, 1]

    model.load_state_dict(torch.load(os.path.join(conf.train_config.model_dir,
                                                  conf.train_config.model_name + '.pth'),
                                     map_location=device))
    # checkpoint = torch.load(os.path.join(conf.train_config.model_dir,
    #                                      conf.train_config.model_name + '_epoch_{}.tar'.format(50)),
    #                         map_location=device)
    # model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    for i in range(51):
        if torch.cuda.is_available():
            prediction_logits = model(en_ids.cuda(),
                                      decoder_input.cuda())        # [1, i+1, vocab_size]
        else:
            prediction_logits = model(en_ids,
                                      decoder_input)
        # Take the distribution at the last position; argmax gives the newly predicted token
        predictions = prediction_logits[:, -1, :]              # [batch_size, vocab_size]
        predictions = F.softmax(predictions, dim=-1)
        predict_zh_ids = torch.argmax(predictions, dim=-1)      # [batch_size]

        # If the prediction is <EOS>, stop
        if predict_zh_ids.data == tokenizer.get_zh_vocab_size() + 1:
            break
        # Otherwise, append the prediction to the previous output and loop again
        else:
            if torch.cuda.is_available():
                decoder_input = torch.cat([decoder_input.cuda(), predict_zh_ids.view(1, 1)], dim=1)
            else:
                decoder_input = torch.cat([decoder_input, predict_zh_ids.view(1, 1)], dim=1)

    # Convert the generated Chinese ids back to Chinese text
    translated_text = tokenizer.convert_zh_ids_to_text(list(decoder_input.cpu().detach().numpy()[0])[1: 51])
    print('Source:', english_text)
    print('Translation:', translated_text)
    return translated_text
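
Since argmax over softmax(logits) equals argmax over the logits, the softmax call above is redundant for greedy decoding. A condensed sketch of the same loop (model and special-token ids are stand-ins):

import torch

def greedy_decode(model, src_ids, bos_id, eos_id, max_len=50):
    ys = torch.tensor([[bos_id]])
    for _ in range(max_len):
        logits = model(src_ids, ys)                # [1, cur_len, vocab]
        next_id = logits[:, -1, :].argmax(dim=-1)  # [1]
        if next_id.item() == eos_id:
            break
        ys = torch.cat([ys, next_id.view(1, 1)], dim=1)
    return ys[0, 1:].tolist()  # drop <BOS>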
Example #10
def main():
    conf = configuration.Config()
    tokenizer = tokenization.FullTokenizer(en_vocab_file=os.path.join(conf.file_config.data_path,
                                                                      conf.file_config.en_vocab),
                                           zh_vocab_file=os.path.join(conf.file_config.data_path,
                                                                      conf.file_config.zh_vocab))
    logging.info('Using Device: {}'.format(device))
    conf.model_config.src_vocab_size = tokenizer.get_en_vocab_size() + 2
    conf.model_config.trg_vocab_size = tokenizer.get_zh_vocab_size() + 2
    model = models.Transformer(conf)
    model = model.to(device)

    if args.train:
        train_dataset = datasets.Translate(mode='train',
                                           config=conf,
                                           tokenizer=tokenizer,
                                           auto_padding=conf.train_config.auto_padding,
                                           do_filter=False)

        logging.info("***** Running training *****")
        logging.info("  Num examples = %d", len(train_dataset))
        logging.info("  Total training steps: {}".format(train_dataset.num_steps))

        train_dataloader = DataLoader(train_dataset,
                                      batch_size=conf.train_config.train_batch_size,
                                      shuffle=True,
                                      collate_fn=collate_fn)

        run(config=conf,
            dataloader=train_dataloader,
            model=model,
            mode='train',
            start_epoch=0,
            total_steps=train_dataset.num_steps)

    if args.eval:
        eval_dataset = datasets.Translate(mode='eval',
                                          config=conf,
                                          tokenizer=tokenizer,
                                          auto_padding=conf.train_config.auto_padding,
                                          do_filter=False)

        logging.info("***** Running validating *****")
        logging.info("  Num examples = %d", len(eval_dataset))
        logging.info("  Total validating steps: {}".format(eval_dataset.num_steps))

        eval_dataloader = DataLoader(eval_dataset,
                                     batch_size=conf.train_config.eval_batch_size,
                                     collate_fn=collate_fn)
        run(config=conf,
            dataloader=eval_dataloader,
            model=model,
            mode='eval',
            start_epoch=0,
            total_steps=eval_dataset.num_steps)
Example #11
    def __init__(self, cfg):
        super().__init__()
        self.transformer = models.Transformer(cfg)
        self.fc = nn.Linear(cfg.hidden, cfg.hidden)
        self.activ1 = nn.Tanh()
        self.linear = nn.Linear(cfg.hidden, cfg.hidden)
        self.activ2 = models.gelu
        self.norm = models.LayerNorm(cfg)
        self.classifier = nn.Linear(cfg.hidden, 2)

        # discriminator head: one logit per token position
        self.discriminator = nn.Linear(cfg.hidden, 1, bias=False)
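
With a single output unit, the discriminator produces one logit per token position (as in ELECTRA-style replaced-token detection); squeezing the last dimension gives per-token scores:

import torch
import torch.nn as nn

h = torch.randn(2, 5, 32)                    # [batch, seq, hidden]
discriminator = nn.Linear(32, 1, bias=False)
token_logits = discriminator(h).squeeze(-1)  # [batch, seq]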
Example #12
    def __init__(self, cfg, n_labels, local_pretrained=False):
        super().__init__()

        if local_pretrained:
            self.transformer = models.Transformer(cfg)
        else:
            self.transformer = BertModel.from_pretrained('bert-base-uncased')

        self.fc = nn.Linear(cfg.dim, cfg.dim)
        self.activ = nn.Tanh()
        self.drop = nn.Dropout(cfg.p_drop_hidden)
        #self.classifier = nn.Linear(cfg.dim, n_labels)
        self.local_pretrained = local_pretrained
Example #13
    def __init__(self, cfg):
        super().__init__()
        self.transformer = models.Transformer(cfg)

        #logits_word_vocab_size
        self.fc3 = nn.Linear(cfg.dim, cfg.dim)
        self.activ3 = models.gelu
        self.norm3 = models.LayerNorm(cfg)
        embed_weight3 = self.transformer.embed.tok_embed.weight
        n_vocab3, n_dim3 = embed_weight3.size()
        self.decoder3 = nn.Linear(n_dim3, n_vocab3, bias=False)
        self.decoder3.weight = embed_weight3
        self.decoder3_bias = nn.Parameter(torch.zeros(n_vocab3))
Example #14
 def __init__(self, cfg):
     super().__init__()
     self.transformer = models.Transformer(cfg)
     self.fc = nn.Linear(cfg.dim, cfg.dim)
     self.activ1 = nn.Tanh()
     self.linear = nn.Linear(cfg.dim, cfg.dim)
     self.activ2 = models.gelu
     self.norm = models.LayerNorm(cfg)
     self.classifier = nn.Linear(cfg.dim, 2)
     # decoder is shared with embedding layer
     embed_weight = self.transformer.embed.tok_embed.weight
     n_vocab, n_dim = embed_weight.size()
     self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
     self.decoder.weight = embed_weight
     self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
Example #15
File: train.py Project: wrzadkow/flax
    def loss_fn(params):
        """loss function used for training."""
        logits = models.Transformer(config).apply(
            {'params': params},
            inputs,
            targets,
            inputs_positions=inputs_positions,
            targets_positions=targets_positions,
            inputs_segmentation=inputs_segmentation,
            targets_segmentation=targets_segmentation,
            rngs={'dropout': dropout_rng})

        loss, weight_sum = compute_weighted_cross_entropy(
            logits, targets, weights, label_smoothing)
        mean_loss = loss / weight_sum
        return mean_loss, logits
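
compute_weighted_cross_entropy is assumed to return the weight-masked token-loss sum together with the weight sum, so the division yields a per-token mean. A minimal sketch of that computation from log-probabilities (label smoothing omitted):

import jax.numpy as jnp

def weighted_xent(log_probs, targets, weights):
    # log_probs: [batch, len, vocab]; targets: [batch, len]; weights: 0/1 mask
    token_loss = -jnp.take_along_axis(
        log_probs, targets[..., None], axis=-1)[..., 0]
    loss = jnp.sum(token_loss * weights)
    weight_sum = jnp.sum(weights)
    return loss, weight_sum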
Example #16
File: train.py Project: wrzadkow/flax
 def tokens_ids_to_logits(flat_ids, flat_cache):
     """Token slice to logits from decoder model."""
     # --> [batch * beam, 1, vocab]
     flat_logits, new_vars = models.Transformer(config).apply(
         {
             'params': params,
             'cache': flat_cache
         },
         encoded_inputs,
         raw_inputs,  # only needed for input padding mask
         flat_ids,
         mutable=['cache'],
         method=models.Transformer.decode)
     new_flat_cache = new_vars['cache']
     # Remove singleton sequence-length dimension:
     # [batch * beam, 1, vocab] --> [batch * beam, vocab]
     flat_logits = flat_logits.squeeze(axis=1)
     return flat_logits, new_flat_cache
Example #17
    def __init__(self, cfg, n_labels):
        super().__init__()

        self.transformer = models.Transformer(cfg)
        self.fc = nn.Linear(cfg.dim, cfg.dim)
        self.activ = nn.Tanh()
        self.drop = nn.Dropout(cfg.p_drop_hidden)
        self.classifier = nn.Linear(cfg.dim, n_labels)
        '''
        self.n_labels = n_labels
        self.transformer = models.Transformer(cfg)

        self.fc = nn.Linear(cfg.dim, n_labels)
        self.norm = nn.BatchNorm1d(cfg.dim)
        self.activ = nn.Sigmoid()
        self.drop = nn.Dropout(cfg.p_drop_hidden)
        self.classifier = nn.Linear(cfg.max_len * n_labels, n_labels)
        '''
Example #18
def build_model(model_type, model_part, learned, seq_len, feature_num, kmer, clayer_num, filters, layer_sizes, embedding_dim, activation, output_activation, transfer, transfer_dim, dropout=0.1):
    
    if model_type == "ConvModel":
        convlayers = [{"kernel_size": kmer, "filters": filters, "activation": "ReLU"} for _ in range(clayer_num)]
        model = models.ConvModel(model_part, seq_len, feature_num, convlayers, layer_sizes, learned, embedding_dim, activation, output_activation, transfer, transfer_dim, dropout, posembed=False)
    elif model_type == "SpannyConvModel":
        convlayers = [{"kernel_size": kmer, "filters": filters, "activation": "ReLU"} for _ in range(clayer_num)]
        global_kernel = {"kernel_size": kmer, "filters": filters, "activation": "ReLU"}
        
        model = models.SpannyConvModel(model_part, seq_len, feature_num, global_kernel, convlayers, layer_sizes, learned, embedding_dim, activation, output_activation, transfer, transfer_dim, dropout, posembed=False)
    elif model_type == "MHCflurry":
        locally_connected_layers = [{"kernel_size": kmer, "filters": filters, "activation": "Tanh"} for _ in range(clayer_num)]
        model = models.MHCflurry(model_part, seq_len, feature_num, locally_connected_layers, layer_sizes, learned, embedding_dim, activation, output_activation, transfer, transfer_dim, dropout)
    elif model_type == "Transformer":
        # d_model: feature_num * 2 (inferred from the positional args below)
        model = models.Transformer(model_part, seq_len, feature_num, feature_num*2, filters, int(feature_num/filters), int(feature_num/filters), layer_sizes, learned, embedding_dim, activation, output_activation, transfer, transfer_dim)
    else:
        raise ValueError("Unsupported model type : "+model_type)
   
    if torch.cuda.is_available():
        model.cuda()
    return model
Example #19
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest one.
  """
  tf.io.gfile.makedirs(workdir)

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(workdir)
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))

  if config.batch_size % n_devices:
    raise ValueError("Batch size must be divisible by the number of devices")

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=n_devices,
      dataset_name=config.dataset_name,
      eval_dataset_name=config.eval_dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      vocab_path=vocab_path,
      target_vocab_size=config.vocab_size,
      batch_size=config.batch_size,
      max_corpus_chars=config.max_corpus_chars,
      max_length=config.max_target_length,
      max_eval_length=config.max_eval_target_length)
  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.
  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)
  input_shape = (config.batch_size, config.max_target_length)
  target_shape = (config.batch_size, config.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, config=eval_config,
          label_smoothing=config.label_smoothing),
      axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap"d training update for performance.
  dropout_rngs = random.split(rng, n_devices)

  logging.info("Starting training loop.")
  metrics_all = []
  t_loop_start = time.time()
  for step, batch in zip(range(start_step, config.num_train_steps), train_iter):
    # Shard data to devices and do a training step.
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    # Save a checkpoint on one host after every checkpoint_freq steps.
    if (config.save_checkpoints and step % config.checkpoint_freq == 0 and
        step > 0 and jax.host_id() == 0):
      checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                  step)

    # Periodic metric handling.
    if step % config.eval_frequency != 0 and step > 0:
      continue

    # Training Metrics
    logging.info("Gathering training metrics.")
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop("learning_rate").mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop("denominator")
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary["learning_rate"] = lr
    steps_per_eval = config.eval_frequency if step != 0 else 1
    steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
    t_loop_start = time.time()
    if jax.host_id() == 0:
      train_summary_writer.scalar("steps per second", steps_per_sec, step)
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    metrics_all = []
    logging.info("train in step: %d, loss: %.4f", step, summary["loss"])

    # Eval Metrics
    logging.info("Gathering evaluation metrics.")
    t_eval_start = time.time()
    eval_metrics = []
    eval_iter = iter(eval_ds)
    for _, eval_batch in zip(range(config.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      eval_batch = common_utils.shard(eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop("denominator")
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    if jax.host_id() == 0:
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
    logging.info("eval in step: %d, loss: %.4f", step, eval_summary["loss"])
    logging.info("eval time: %.4f s step %d", time.time() - t_eval_start, step)

    # Translation and BLEU Score.
    logging.info("Translating evaluation dataset.")
    t_inference_start = time.time()
    sources, references, predictions = [], [], []
    for pred_batch in predict_ds:
      pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch["inputs"].shape[0]
      if cur_pred_batch_size % n_devices:
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
      pred_batch = common_utils.shard(pred_batch)
      cache = p_init_cache(pred_batch["inputs"])
      predicted = p_pred_step(pred_batch["inputs"], optimizer.target, cache,
                              eos_id, config.max_predict_length)
      predicted = tohost(predicted)
      inputs = tohost(pred_batch["inputs"])
      targets = tohost(pred_batch["targets"])
      # Iterate through non-padding examples of batch.
      for i, s in enumerate(predicted[:cur_pred_batch_size]):
        sources.append(decode_tokens(inputs[i]))
        references.append(decode_tokens(targets[i]))
        predictions.append(decode_tokens(s))
    logging.info("Translation: %d predictions %d references %d sources.",
                 len(predictions), len(references), len(sources))
    logging.info("Translation time: %.4f s step %d.",
                 time.time() - t_inference_start, step)

    # Calculate BLEU score for translated eval corpus against reference.
    bleu_matches = bleu.bleu_partial(references, predictions)
    all_bleu_matches = per_host_sum_pmap(bleu_matches)
    bleu_score = bleu.complete_bleu(*all_bleu_matches)
    # Save translation samples for tensorboard.
    exemplars = ""
    for n in np.random.choice(np.arange(len(predictions)), 8):
      exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n"
    if jax.host_id() == 0:
      eval_summary_writer.scalar("bleu", bleu_score, step)
      eval_summary_writer.text("samples", exemplars, step)
      eval_summary_writer.flush()
    logging.info("Translation BLEU Score %.4f", bleu_score)
Example #20
File: train.py Project: wrzadkow/flax
 def initialize_variables(rng):
     return models.Transformer(eval_config).init(
         rng, jnp.ones(input_shape, jnp.float32),
         jnp.ones(target_shape, jnp.float32))
Example #21
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest one.
  """
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      reverse_translation=config.reverse_translation,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)
  target_shape = (config.per_device_batch_size, config.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, config=eval_config),
      axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap"d training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("translate_and_bleu"):
          exemplars, bleu_score = translate_and_calculate_bleu(
              p_pred_step=p_pred_step,
              p_init_cache=p_init_cache,
              target=optimizer.target,
              predict_ds=predict_ds,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_scalars(step, {"bleu": bleu_score})
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                      step)
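
create_learning_rate_scheduler with base_learning_rate and warmup_steps presumably combines linear warmup with inverse-square-root decay; one common formulation of that schedule, as a sketch:

import math

def warmup_rsqrt_schedule(base_lr, warmup_steps):
    def lr(step):
        s = max(step, 1)
        # ramp linearly to the peak, then decay as 1/sqrt(step)
        return base_lr * min(s / warmup_steps, 1.0) / math.sqrt(max(s, warmup_steps))
    return lr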
Example #22
 
 if args.dataset == 'speechcoco':
     audio_root_path_train = os.path.join(args.data_dir, 'train2014/wav/')
     image_root_path_train = os.path.join(args.data_dir, 'train2014/imgs/')
     segment_file_train = os.path.join(args.data_dir, 'train2014/mscoco_train_word_segments.txt')
     bbox_file_train = os.path.join(args.data_dir, 'train2014/mscoco_train_rcnn_feature.npz')
     audio_root_path_test = os.path.join(args.data_dir, 'val2014/wav/') 
     image_root_path_test = os.path.join(args.data_dir, 'val2014/imgs/')
     segment_file_test = os.path.join(args.data_dir, 'val2014/mscoco_val_word_segments.txt')
     bbox_file_test = os.path.join(args.data_dir, 'val2014/mscoco_val_rcnn_feature.npz')
     split_file = os.path.join(args.data_dir, 'val2014/mscoco_val_split.txt')
 
     if args.audio_model == 'davenet':
         audio_model = nn.DataParallel(models.Davenet(embedding_dim=1024))
     elif args.audio_model == 'transformer':
         audio_model = nn.DataParallel(models.Transformer(embedding_dim=1024))
 
     train_set = OnlineImageAudioCaptionDataset(audio_root_path_train,
                                      image_root_path_train,
                                      segment_file_train,
                                      bbox_file_train,
                                      configs={'return_boundary': True})
     test_set = OnlineImageAudioCaptionDataset(audio_root_path_test,
                                     image_root_path_test,
                                     segment_file_test,
                                     bbox_file_test,
                                     keep_index_file=split_file, 
                                     configs={'return_boundary': True})
 elif args.dataset == 'mscoco':
     segment_file_train = os.path.join(args.data_dir, 'train2014/mscoco_train_word_phone_segments.txt')
     segment_file_test = os.path.join(args.data_dir, 'val2014/mscoco_val_word_phone_segments.txt') # TODO
Example #23
File: run.py Project: lwang114/magnet
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=1,
                                          pin_memory=True)

# Initialize the image and audio models
if args.audio_model == 'tdnn':
    if args.supervision_level == 'text':
        audio_model = models.SmallDavenetEncoder(audio_encoder_configs)
    else:  # TODO
        audio_model = models.DavenetEncoder(audio_encoder_configs)
elif args.audio_model == 'lstm':
    audio_model = models.BLSTMEncoder(audio_encoder_configs)
elif args.audio_model == 'transformer':  # TODO
    pretrained_model_file = '/ws/ifp-53_1/hasegawa/tools/espnet/egs/discophone/ifp_lwang114/dump/mscoco/eval/deltafalse/split1utt/data_encoder.pth'
    audio_model = models.Transformer(
        n_class=49, pretrained_model_file=pretrained_model_file)

if args.image_model == 'res34':
    image_model = models.ResnetEncoder(image_encoder_configs)
elif args.image_model == 'linear' or args.image_model == 'rcnn':
    image_model = models.LinearEncoder(image_encoder_configs)

if args.supervision_level == 'audio':
    audio_segment_model = models.NoopSegmenter(audio_segmenter_configs)
else:
    audio_segment_model = models.FixedTextSegmenter(audio_segmenter_configs)

image_segment_model = models.NoopSegmenter(image_segmenter_configs)

if args.alignment_model == 'mixture_aligner':
    if args.translate_direction == 'sp2im':
Example #24
if not os.path.exists(args.exp_dir):
    os.makedirs("%s/models" % args.exp_dir)

if args.precompute_acoustic_feature:
    audio_model = models.NoOpEncoder(embedding_dim=1000)
    image_model = models.LinearTrans(input_dim=2048, embedding_dim=1000)
    attention_model = models.DotProductAttention(in_size=1000)
    if not args.only_eval:
        train_attention(audio_model, image_model, attention_model,
                        train_loader, val_loader, args)
    else:
        evaluation_attention(audio_model, image_model, attention_model,
                             val_loader, args)
else:
    audio_model = models.Transformer(embedding_dim=1024)
    image_model = models.LinearTrans(input_dim=2048, embedding_dim=1024)
    attention_model = models.DotProductAttention(in_size=1024)
    if not args.only_eval:
        train_attention(audio_model, image_model, attention_model,
                        train_loader, val_loader, args)
    else:
        loader_for_alignment = torch.utils.data.DataLoader(
            dataloaders.OnlineImageAudioCaptionDataset(
                audio_root_path_test,
                image_root_path_test,
                segment_file_test,
                bbox_file_test,
                keep_index_file=split_file,
                configs={'return_boundary': True}),
            batch_size=args.batch_size,
Example #25
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  batch_size = FLAGS.batch_size
  learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  eval_freq = FLAGS.eval_frequency
  random_seed = FLAGS.random_seed

  if not FLAGS.dev:
    raise app.UsageError('Please provide path to dev set.')
  if not FLAGS.train:
    raise app.UsageError('Please provide path to training set.')
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  device_batch_size = batch_size // jax.device_count()

  if jax.process_index() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))

  # create the training and development dataset
  vocabs = input_pipeline.create_vocabs(FLAGS.train)
  config = models.TransformerConfig(
      vocab_size=len(vocabs['forms']),
      output_vocab_size=len(vocabs['xpos']),
      max_len=FLAGS.max_length)

  attributes_input = [input_pipeline.CoNLLAttributes.FORM]
  attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
  train_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.train,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=config.max_len)
  train_iter = iter(train_ds)

  eval_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.dev,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=config.max_len,
      repeat=1)

  model = models.Transformer(config)

  rng = random.PRNGKey(random_seed)
  rng, init_rng = random.split(rng)

  # call a jitted initialization function to get the initial parameter tree
  @jax.jit
  def initialize_variables(init_rng):
    init_batch = jnp.ones((config.max_len, 1), jnp.float32)
    init_variables = model.init(init_rng, inputs=init_batch, train=False)
    return init_variables
  init_variables = initialize_variables(init_rng)

  optimizer_def = optim.Adam(learning_rate, beta1=0.9, beta2=0.98,
      eps=1e-9, weight_decay=1e-1)
  optimizer = optimizer_def.create(init_variables['params'])
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)

  p_train_step = jax.pmap(
      functools.partial(train_step, model=model, learning_rate_fn=learning_rate_fn),
      axis_name='batch')

  def eval_step(params, batch):
    """Calculate evaluation metrics on a batch."""
    inputs, targets = batch['inputs'], batch['targets']
    weights = jnp.where(targets > 0, 1.0, 0.0)
    logits = model.apply({'params': params}, inputs=inputs, train=False)
    return compute_metrics(logits, targets, weights)

  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())
  metrics_all = []
  tick = time.time()
  best_dev_score = 0
  for step, batch in zip(range(num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access

    optimizer, metrics, dropout_rngs = p_train_step(optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    if (step + 1) % eval_freq == 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
      if jax.process_index() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        train_summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          train_summary_writer.scalar(key, val, step)
        train_summary_writer.flush()

      metrics_all = []  # reset metric accumulation for next evaluation cycle.

      eval_metrics = []
      eval_iter = iter(eval_ds)

      for eval_batch in eval_iter:
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = eval_batch['inputs'].shape[0]
        if cur_pred_batch_size != batch_size:
          # pad up to batch size
          eval_batch = jax.tree_map(
              lambda x: pad_examples(x, batch_size), eval_batch)
        eval_batch = common_utils.shard(eval_batch)

        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])

      if best_dev_score < eval_summary['accuracy']:
        best_dev_score = eval_summary['accuracy']
        # TODO: save model.
      eval_summary['best_dev_score'] = best_dev_score
      logging.info('best development model score %.4f', best_dev_score)
      if jax.process_index() == 0:
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(key, val, step)
        eval_summary_writer.flush()
Example #26
    from torch.utils.data import DataLoader
    import models

    conf = configuration.Config()
    tokenizer = tokenization.FullTokenizer(
        en_vocab_file=os.path.join(conf.file_config.data_path,
                                   conf.file_config.en_vocab),
        zh_vocab_file=os.path.join(conf.file_config.data_path,
                                   conf.file_config.zh_vocab))
    conf.model_config.src_vocab_size = tokenizer.get_en_vocab_size() + 2
    conf.model_config.trg_vocab_size = tokenizer.get_zh_vocab_size() + 2
    dataset = Translate(mode='demo',
                        config=conf,
                        tokenizer=tokenizer,
                        do_filter=False)
    transformer = models.Transformer(conf)

    def collate_fn(batches):
        """ 每个batch做padding,而非所有样本做padding """
        batch_en_ids = [torch.tensor(batch[0]) for batch in batches]
        batch_en_ids = rnn_utils.pad_sequence(batch_en_ids,
                                              batch_first=True,
                                              padding_value=0)

        if batches[0][1] is not None:
            batch_zh_ids = [torch.tensor(batch[1]) for batch in batches]
            batch_zh_ids = rnn_utils.pad_sequence(batch_zh_ids,
                                                  batch_first=True,
                                                  padding_value=0)
        else:
            batch_zh_ids = None
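
rnn_utils.pad_sequence pads a list of variable-length tensors to the longest one in the batch, which is what makes the per-batch padding above work; a quick illustration:

import torch
from torch.nn.utils import rnn as rnn_utils

batch = [torch.tensor([5, 9, 2]), torch.tensor([7])]
padded = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=0)
print(padded)
# tensor([[5, 9, 2],
#         [7, 0, 0]])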