Пример #1
0
def main(epochs, enable_function, buffer_size, batch_size, download_path,
         num_examples=70000, embedding_dim=256, enc_units=1024, dec_units=1024):
    """Download the corpus, build the encoder/decoder and run training.

    Single-device variant: no distribution strategy is used.

    Args:
        epochs: Number of epochs, forwarded to ``Train``.
        enable_function: Flag forwarded to ``Train``.
        buffer_size: Forwarded to ``utils.create_dataset``.
        batch_size: Global batch size, forwarded to ``utils.create_dataset``,
            ``nmt.Encoder`` and ``Train``.
        download_path: Location handed to ``utils.download``.
        num_examples: Number of examples requested from
            ``utils.create_dataset``.
        embedding_dim: Embedding size for both encoder and decoder.
        enc_units: Hidden units of the encoder.
        dec_units: Hidden units of the decoder.

    Returns:
        Whatever ``Train.training_loop`` returns for the train/test datasets.
    """
    corpus_file = utils.download(download_path)
    (train_dataset, test_dataset,
     source_lang, target_lang) = utils.create_dataset(
         corpus_file, num_examples, buffer_size, batch_size)

    # Vocabulary size is the word_index length plus one (presumably to
    # reserve index 0, e.g. for padding -- TODO confirm in utils).
    source_vocab_size = len(source_lang.word_index) + 1
    target_vocab_size = len(target_lang.word_index) + 1

    encoder = nmt.Encoder(source_vocab_size, embedding_dim, enc_units,
                          batch_size)
    decoder = nmt.Decoder(target_vocab_size, embedding_dim, dec_units)

    # The last Train argument appears to be a per-replica batch size. With a
    # single device it equals the global batch size. NOTE(review): confirm
    # against Train's signature.
    per_replica_batch_size = batch_size
    trainer = Train(epochs, enable_function, encoder, decoder,
                    source_lang, target_lang,
                    batch_size, per_replica_batch_size)
    print('Training ...')
    return trainer.training_loop(train_dataset, test_dataset)
Пример #2
0
def main(epochs,
         enable_function,
         buffer_size,
         batch_size,
         download_path,
         num_examples=70000,
         embedding_dim=256,
         enc_units=1024,
         dec_units=1024,
         num_gpu=1):
    """Train the ``nmt`` encoder/decoder with ``tf.distribute.MirroredStrategy``.

    Args:
        epochs: Number of epochs, forwarded to ``DistributedTrain``.
        enable_function: Flag forwarded to ``DistributedTrain``.
        buffer_size: Forwarded to ``utils.create_dataset``.
        batch_size: Global batch size. It must be divisible by the number of
            replicas, because each replica receives an equal share.
        download_path: Location handed to ``utils.download``.
        num_examples: Number of examples requested from
            ``utils.create_dataset``.
        embedding_dim: Embedding size for both encoder and decoder.
        enc_units: Hidden units of the encoder.
        dec_units: Hidden units of the decoder.
        num_gpu: Number of GPUs to mirror across. Devices
            ``/device:GPU:0`` .. ``/device:GPU:{num_gpu-1}`` are used.

    Returns:
        Whatever ``DistributedTrain.training_loop`` returns.

    Raises:
        ValueError: If ``batch_size`` is not divisible by the number of
            replicas in sync.
    """
    devices = ['/device:GPU:{}'.format(i) for i in range(num_gpu)]
    strategy = tf.distribute.MirroredStrategy(devices)
    num_replicas = strategy.num_replicas_in_sync

    # Validate the batch split before any expensive work (download, dataset
    # and iterator construction) so a bad configuration fails immediately.
    local_batch_size, remainder = divmod(batch_size, num_replicas)
    if remainder:
        raise ValueError(
            'Batch size ({}) must be divisible by the '
            'number of replicas ({})'.format(batch_size, num_replicas))

    with strategy.scope():
        file_path = utils.download(download_path)
        train_ds, test_ds, inp_lang, targ_lang = utils.create_dataset(
            file_path, num_examples, buffer_size, batch_size)
        # Vocabulary size is the word_index length plus one (presumably to
        # reserve index 0, e.g. for padding -- TODO confirm in utils).
        vocab_inp_size = len(inp_lang.word_index) + 1
        vocab_tar_size = len(targ_lang.word_index) + 1

        # NOTE(review): cardinality returns a sentinel (UNKNOWN / INFINITE)
        # rather than a step count if the pipeline hides its size. This
        # assumes utils.create_dataset yields finite, known-size datasets.
        num_train_steps_per_epoch = tf.data.experimental.cardinality(train_ds)
        num_test_steps_per_epoch = tf.data.experimental.cardinality(test_ds)

        train_iterator = strategy.make_dataset_iterator(train_ds)
        test_iterator = strategy.make_dataset_iterator(test_ds)

        # The encoder is sized with the per-replica batch, since each replica
        # only sees its own slice of the global batch.
        encoder = nmt.Encoder(vocab_inp_size, embedding_dim, enc_units,
                              local_batch_size)
        decoder = nmt.Decoder(vocab_tar_size, embedding_dim, dec_units)

        train_obj = DistributedTrain(epochs, enable_function, encoder, decoder,
                                     inp_lang, targ_lang, batch_size,
                                     local_batch_size)
        print('Training ...')
        return train_obj.training_loop(train_iterator, test_iterator,
                                       num_train_steps_per_epoch,
                                       num_test_steps_per_epoch, strategy)