Exemple #1
0
    with strategy.scope():
        file_path = utils.download(download_path)
        train_ds, test_ds, inp_lang, targ_lang = utils.create_dataset(
            file_path, num_examples, buffer_size, batch_size)
        vocab_inp_size = len(inp_lang.word_index) + 1
        vocab_tar_size = len(targ_lang.word_index) + 1

        num_train_steps_per_epoch = tf.data.experimental.cardinality(train_ds)
        num_test_steps_per_epoch = tf.data.experimental.cardinality(test_ds)

        train_iterator = strategy.make_dataset_iterator(train_ds)
        test_iterator = strategy.make_dataset_iterator(test_ds)

        encoder = nmt.Encoder(vocab_inp_size, embedding_dim, enc_units,
                              batch_size)
        decoder = nmt.Decoder(vocab_tar_size, embedding_dim, dec_units,
                              batch_size)

        train_obj = DistributedTrain(epochs, enable_function, encoder, decoder,
                                     inp_lang, targ_lang, batch_size)
        print('Training ...')
        return train_obj.training_loop(train_iterator, test_iterator,
                                       num_train_steps_per_epoch,
                                       num_test_steps_per_epoch, strategy)


# Script entry point: the body runs only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    # NOTE(review): presumably registers the command-line flags that the
    # training code reads later -- confirm against utils.nmt_flags.
    utils.nmt_flags()
    # Hand run_main to app.run as the program's main callback.
    app.run(run_main)
Exemple #2
0
    return (self.train_loss_metric.result().numpy(),
            self.test_loss_metric.result().numpy())


def run_main(argv):
  """Launch `main` using the keyword arguments from `utils.flags_dict()`.

  Args:
    argv: Positional command-line arguments; discarded without being read.
  """
  del argv  # Unused: every setting comes from utils.flags_dict().
  # The result of main() is intentionally not propagated to the caller.
  main(**utils.flags_dict())


def main(epochs, enable_function, buffer_size, batch_size, download_path,
         num_examples=70000, embedding_dim=256, enc_units=1024, dec_units=1024):
  """Download the data, build the encoder/decoder pair and run training.

  Args:
    epochs: Forwarded to `Train` as its first argument.
    enable_function: Forwarded to `Train` as its second argument.
    buffer_size: Forwarded to `utils.create_dataset`.
    batch_size: Forwarded to `utils.create_dataset`, to `nmt.Encoder`, and
      passed twice (as the last two arguments) to `Train`.
    download_path: Forwarded to `utils.download`.
    num_examples: Forwarded to `utils.create_dataset`.
    embedding_dim: Forwarded to both `nmt.Encoder` and `nmt.Decoder`.
    enc_units: Forwarded to `nmt.Encoder`.
    dec_units: Forwarded to `nmt.Decoder`.

  Returns:
    Whatever `Train.training_loop(train_ds, test_ds)` returns.
  """
  corpus_path = utils.download(download_path)
  train_split, test_split, source_lang, target_lang = utils.create_dataset(
      corpus_path, num_examples, buffer_size, batch_size)

  # One more than the number of entries in each language's word index.
  # NOTE(review): presumably the extra slot covers a reserved index such as
  # padding -- confirm against how utils builds `word_index`.
  source_vocab_size = 1 + len(source_lang.word_index)
  target_vocab_size = 1 + len(target_lang.word_index)

  encoder_model = nmt.Encoder(
      source_vocab_size, embedding_dim, enc_units, batch_size)
  decoder_model = nmt.Decoder(target_vocab_size, embedding_dim, dec_units)

  # batch_size is deliberately supplied for both trailing Train arguments.
  trainer = Train(epochs, enable_function, encoder_model, decoder_model,
                  source_lang, target_lang, batch_size, batch_size)
  print('Training ...')
  return trainer.training_loop(train_split, test_split)

# Script entry point: the body runs only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
  # NOTE(review): presumably registers the command-line flags that
  # utils.flags_dict() later exposes to run_main -- confirm in utils.
  utils.nmt_flags()
  # Hand run_main to app.run as the program's main callback.
  app.run(run_main)