示例#1
0
    def __init__(self, args):
        """Initialize configuration and per-state statistics from `args`.

        Copies hyper-parameters from `args`, fetches a rate through
        `utils.get_interest_rate()`, and, for every state returned by
        `create_stack_to_state()`, builds a benefit table (via
        `create_benefit_table`) and a `utils.MovingAverage(self.window)`
        timer.

        Args:
            args: Namespace-like object exposing the attributes quantity,
                max_purchase_quantity, amount_bin_size, state_bin_size,
                price, q_eps_decay, p_eps_decay, window, query_minimum,
                query_diff and query_std.
        """
        # Quantities and bin sizes.
        self.quantity = args.quantity
        self.max_purchase_quantity = args.max_purchase_quantity
        self.amount_bin_size = args.amount_bin_size
        self.state_bin_size = args.state_bin_size

        self.rate = utils.get_interest_rate()
        self.price = args.price

        # Both epsilons start at 1.0 (presumably epsilon-greedy exploration
        # rates); the decay factors are only stored here.
        # TODO: Should epsilon be kept separately per state, or per action?
        # TODO: Keep Q_Epsilon separately for each state,
        # TODO: and give P_Epsilon separately for each (State, Action) pair?
        self.q_eps, self.p_eps = 1.0, 1.0
        self.q_eps_decay = args.q_eps_decay
        self.p_eps_decay = args.p_eps_decay

        # Per-state tables, keyed by the states from create_stack_to_state().
        self.window = args.window
        self.stack_to_state = self.create_stack_to_state()
        self.benefit_tables = dict(
            (state, self.create_benefit_table(stack))
            for stack, state in self.stack_to_state.items()
        )
        timers = {}
        for state in self.stack_to_state.values():
            timers[state] = utils.MovingAverage(self.window)
        self.times = timers

        # HTTP endpoint settings; `id` starts out as None.
        self.uri = "http://localhost:3000"
        self.headers = {'Content-type': 'application/json'}
        self.id = None

        self.query_minimum = args.query_minimum
        self.query_diff = args.query_diff
        self.query_std = args.query_std
示例#2
0
 def create_benefit_table(self, stack):
     """Return one `utils.MovingAverage(self.window)` per action bin.

     Action bins are counted in steps of `amount_bin_size` starting at zero,
     so `stack` allows `stack // amount_bin_size + 1` bins. That count is
     capped by the bin count that `max_purchase_quantity` allows.

     Args:
         stack: Stack value of the state the table is built for.

     Returns:
         A list of freshly created moving averages, one per bin. The list is
         empty when the computed bin count is zero or negative.
     """
     # TODO: Would it be reasonable to give each agent its own maximum
     # purchase-quantity limit?
     bin_size = self.amount_bin_size
     cap = self.max_purchase_quantity // bin_size + 1
     requested = stack // bin_size + 1
     n_bins = min(cap, requested)
     return [utils.MovingAverage(self.window) for _ in range(n_bins)]
示例#3
0
def main(argv):
  """Train an EfficientNet on ImageNet with periodic evaluation and checkpoints.

  Uses `tf.distribute.MirroredStrategy` when `--use_gpu` is set, otherwise a
  TPU strategy. If a checkpoint exists in `FLAGS.output_dir`, training resumes
  from it. TensorBoard summaries go to `FLAGS.output_dir/summaries`, and
  checkpoints are written every `FLAGS.checkpoint_interval` epochs and once
  at the end.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_epoch = APPROX_IMAGENET_TRAIN_IMAGES // batch_size
  steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

  width_coefficient, depth_coefficient, input_image_size, dropout_rate = (
      efficientnet_model.efficientnet_params(FLAGS.model_name))
  # The train input is built with the per-core batch size and handed to
  # `experimental_distribute_datasets_from_function`; the eval input is built
  # with the global batch size and distributed via
  # `experimental_distribute_dataset`.
  imagenet_train = utils.ImageNetInput(
      is_training=True,
      use_bfloat16=FLAGS.use_bfloat16,
      data_dir=FLAGS.data_dir,
      batch_size=FLAGS.per_core_batch_size,
      image_size=input_image_size,
      normalize_input=True,
      one_hot=True)
  imagenet_eval = utils.ImageNetInput(
      is_training=False,
      use_bfloat16=FLAGS.use_bfloat16,
      data_dir=FLAGS.data_dir,
      batch_size=batch_size,
      image_size=input_image_size,
      normalize_input=True,
      one_hot=True)
  train_dataset = strategy.experimental_distribute_datasets_from_function(
      imagenet_train.input_fn)
  test_datasets = {
      'clean':
          strategy.experimental_distribute_dataset(imagenet_eval.input_fn()),
  }
  train_iterator = iter(train_dataset)
  test_iterator = iter(test_datasets['clean'])

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building %s model', FLAGS.model_name)
    model = efficientnet_model.Model(width_coefficient,
                                     depth_coefficient,
                                     dropout_rate)

    # Linear learning-rate scaling relative to a reference batch of 256.
    scaled_lr = FLAGS.base_learning_rate * (batch_size / 256.0)
    # Decay epoch is 2.4, warmup epoch is 5 according to the Efficientnet paper.
    decay_steps = steps_per_epoch * 2.4
    warmup_step = steps_per_epoch * 5
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        scaled_lr, decay_steps, decay_rate=0.97, staircase=True)
    learning_rate = utils.WarmupDecaySchedule(lr_schedule, warmup_step)
    optimizer = tf.keras.optimizers.RMSprop(
        learning_rate, rho=0.9, momentum=0.9, epsilon=0.001)
    if FLAGS.moving_average_decay > 0:
      optimizer = utils.MovingAverage(
          optimizer,
          average_decay=FLAGS.moving_average_decay)
      optimizer.shadow_copy(model)

    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.CategoricalAccuracy(),
        'train/ece': ed.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'train/loss': tf.keras.metrics.Mean(),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.CategoricalAccuracy(),
        'test/ece': ed.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
    }
    logging.info('Finished building %s model', FLAGS.model_name)

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  def train_step(inputs):
    """Build `step_fn` for efficientnet learning."""
    images, labels = inputs

    num_replicas = tf.cast(strategy.num_replicas_in_sync, tf.float32)
    l2_coeff = tf.cast(FLAGS.l2, tf.float32)

    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      logits = tf.cast(logits, tf.float32)
      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.categorical_crossentropy(
              labels,
              logits,
              from_logits=True,
              label_smoothing=FLAGS.label_smoothing))

      def _is_batch_norm(v):
        """Decide whether a variable belongs to `batch_norm`."""
        keywords = ['batchnorm', 'batch_norm', 'bn']
        return any(k in v.name.lower() for k in keywords)

      # L2 regularization over all trainable weights except batch-norm ones.
      l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in model.trainable_weights
                          if not _is_batch_norm(v)])
      loss = negative_log_likelihood + l2_coeff * l2_loss
      # Per-replica gradients are summed across replicas, so divide by the
      # replica count to obtain the mean.
      scaled_loss = loss / num_replicas

    gradients = tape.gradient(scaled_loss, model.trainable_weights)
    # MovingAverage optimizer automatically updates avg when applying gradients.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

    sparse_labels = tf.cast(
        tf.math.argmax(labels, axis=-1, output_type=tf.int32), tf.float32)
    probs = tf.nn.softmax(logits)
    metrics['train/loss'].update_state(loss)
    metrics['train/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['train/accuracy'].update_state(labels, logits)
    metrics['train/ece'].update_state(sparse_labels, probs)

    step_info = {
        'loss/negative_log_likelihood': negative_log_likelihood / num_replicas,
        'loss/total_loss': scaled_loss,
    }
    return step_info

  def eval_step(inputs):
    """A single step."""
    images, labels = inputs
    logits = model(images, training=False)
    logits = tf.cast(logits, tf.float32)
    negative_log_likelihood = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(
            labels, logits, from_logits=True))
    sparse_labels = tf.cast(
        tf.math.argmax(labels, axis=-1, output_type=tf.int32), tf.float32)
    probs = tf.nn.softmax(logits)
    metrics['test/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['test/accuracy'].update_state(labels, logits)
    metrics['test/ece'].update_state(sparse_labels, probs)

  @tf.function
  def epoch_fn(should_eval):
    """Build `epoch_fn` for training and potential eval."""
    for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
      info = strategy.run(train_step, args=(next(train_iterator),))

      optim_step = optimizer.iterations
      if optim_step % tf.cast(100, optim_step.dtype) == 0:
        for k, v in info.items():
          v_reduce = strategy.reduce(tf.distribute.ReduceOp.SUM, v, None)
          tf.summary.scalar(k, v_reduce, optim_step)
        tf.summary.scalar('loss/lr', learning_rate(optim_step), optim_step)
        summary_writer.flush()

    if should_eval:
      # Evaluate with the averaged weights, then swap the raw weights back.
      if isinstance(optimizer, utils.MovingAverage):
        optimizer.swap_weights(strategy)
      for _ in tf.range(tf.cast(steps_per_eval, tf.int32)):
        strategy.run(eval_step, args=(next(test_iterator),))
      if isinstance(optimizer, utils.MovingAverage):
        optimizer.swap_weights(strategy)

  # Main training loop.
  start_time = time.time()
  max_steps = steps_per_epoch * FLAGS.train_epochs
  with summary_writer.as_default():
    for epoch in range(initial_epoch, FLAGS.train_epochs):
      logging.info('Starting to run epoch: %s', epoch)
      should_eval = (epoch % FLAGS.evaluation_interval == 0)
      epoch_start_time = time.time()
      # Pass tf constant to avoid re-tracing.
      epoch_fn(tf.constant(should_eval))
      epoch_time = time.time() - epoch_start_time
      example_per_secs = (steps_per_epoch * batch_size) / epoch_time
      # Only log throughput for epochs whose timing excludes evaluation.
      if not should_eval:
        tf.summary.scalar(
            'examples_per_secs', example_per_secs, optimizer.iterations)
        summary_writer.flush()

      current_step = (epoch + 1) * steps_per_epoch
      time_elapsed = time.time() - start_time
      # `start_time` is reset on every launch, so only count the steps run in
      # this process. Otherwise, resuming from a checkpoint inflates steps/s
      # and underestimates the ETA.
      steps_this_run = current_step - initial_epoch * steps_per_epoch
      steps_per_sec = float(steps_this_run) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps,
                     epoch + 1,
                     FLAGS.train_epochs,
                     steps_per_sec,
                     eta_seconds / 60,
                     time_elapsed / 60))
      logging.info(message)

      logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                   metrics['train/loss'].result(),
                   metrics['train/accuracy'].result() * 100)

      if should_eval:
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)

      total_results = {name: metric.result()
                       for name, metric in metrics.items()}
      total_results.update({'lr': learning_rate(optimizer.iterations)})
      # The enclosing `summary_writer.as_default()` is already active; flush
      # explicitly so the per-epoch summaries are persisted.
      for name, result in total_results.items():
        if should_eval or 'test' not in name:
          tf.summary.scalar(name, result, step=epoch + 1)
      summary_writer.flush()

      for metric in metrics.values():
        metric.reset_states()

      if (FLAGS.checkpoint_interval > 0 and
          (epoch + 1) % FLAGS.checkpoint_interval == 0):
        checkpoint_name = checkpoint.save(os.path.join(
            FLAGS.output_dir, 'checkpoint'))
        logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)
示例#4
0
def main(argv):
    #######################################################################
    # Initial Setup. Logging, Flags, Random seeds.
    #######################################################################
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")
    absl_logging.use_python_logging()
    flags_dict = {
        flag.name: flag.value
        for flag in FLAGS.flags_by_module_dict()[argv[0]]
    }

    if FLAGS.use_subset:
        message = (f"{colorama.Back.RED}{colorama.Fore.WHITE}"
                   f"{colorama.Style.BRIGHT}USING A SUBSET OF THE DATASET"
                   f"{colorama.Style.RESET_ALL}")
        LOGGER.warning(message)

    utils.log_module_args(LOGGER, argv[0])
    if not FLAGS.output_dir.startswith("gs://"):
        utils.check_exists(FLAG_OUTPUT_DIR.value)
        if not tf.io.gfile.isdir(FLAG_OUTPUT_DIR.value):
            raise RuntimeError("Output dir needs to be a directory.")

    tf.random.set_seed(FLAG_RANDOM_SEED.value)
    np.random.seed(FLAG_RANDOM_SEED.value)

    # Prepare the instance output directory path and save the config there
    folder_name = time.strftime(
        f"{FLAG_RUN_NAME.value}_{FLAG_APPROACH_TYPE.value}_%Y%m%d-%H%M%S")
    instance_output_dir = os.path.join(FLAG_OUTPUT_DIR.value,
                                       folder_name).strip()
    if not instance_output_dir.endswith("/"):
        instance_output_dir += "/"
    json_target = os.path.join(instance_output_dir, "training_params.json")
    if not json_target.strip().startswith("gs://"):
        subprocess.check_call(["mkdir", "-p", instance_output_dir])
    utils.to_json_file(json_target, instance_output_dir)

    ##############################################################################
    # Initialization and Configuration of the Devices.
    ##############################################################################
    tpu_setup = None
    # current_acelerator_type is always "CPU" in the beginning with TPUs
    if tf_utils.current_accelerator_type() == "CPU":
        tpu_setup = tf_utils.init_tpus()

    LOGGER.debug("Devices we are computing on:\n%s",
                 utils.wrap_iterable(map(str, tf_utils.devices_to_use())))
    LOGGER.debug("All devices:")
    LOGGER.debug(tf_utils.device_mapping())

    if tf_utils.current_accelerator_type() == "GPU":
        tf.config.set_soft_device_placement(True)

    if tf_utils.current_accelerator_type() != "TPU":
        tf.debugging.set_log_device_placement(True)

    if FLAG_DISTRIBUTE_MODE.value in constants.PURE_DATA_PARALLEL_STRATEGIES:
        actual_num_replicas = len(tf_utils.devices_to_use())
    elif FLAG_DISTRIBUTE_MODE.value in constants.DATA_PARALLEL_DMC:
        actual_num_replicas = FLAG_NUM_REPLICAS.value
    else:
        actual_num_replicas = 1

    ##############################################################################
    # We load the retriever model if it is needed.
    ##############################################################################
    # Not currently used.

    retriever = None
    # if (FLAG_APPROACH_TYPE.value ==
    #     constants.ApproachTypeChoices.lm_and_realm):
    #   raise NotImplementedError("This part needs to be tested anew.")
    # config_path = FLAG_RETRIEVER_CONFIG_PATH.value
    # realm_save = tf_utils.REALMSave(**utils.from_json_file(config_path))
    #
    # # Approx 15 min when not in dev mode, on CPU
    # with utils.log_duration(LOGGER, "main",
    #                         "whole of BERTScaNNRetriever.__init__",
    #                         logging.INFO):
    #   scann_config = retrievers.ScannConfig(
    #       **utils.from_json_file(FLAG_SCANN_CONFIG_PATH.value))
    #   retriever = retrievers.BERTScaNNRetriever(
    #       retriever_module_path=realm_save.query_embedder_path,
    #       block_records_path=realm_save.text_records,
    #       num_block_records=realm_save.num_block_records,
    #       mode=tf.estimator.ModeKeys.EVAL,
    #       scann_config=scann_config)

    # elif (FLAG_APPROACH_TYPE.value ==
    #       constants.ApproachTypeChoices.cached_realm):
    #   raise NotImplementedError("This part needs to be tested anew.")
    # config_path = FLAG_RETRIEVER_CONFIG_PATH.value
    # realm_save = tf_utils.REALMSave(**utils.from_json_file(config_path))
    #
    # # Approx 15 min when not in dev mode, on CPU
    # with utils.log_duration(LOGGER, "main",
    #                         "whole of FullyCachedRetriever.__init__",
    #                         logging.INFO):
    #
    #   retriever = retrievers.FullyCachedRetriever(
    #       db_path=FLAG_FULLYCACHED_H5_PATH.value,
    #       block_records_path=realm_save.text_records,
    #       num_block_records=realm_save.num_block_records,
    #       )

    ##############################################################################
    # Distributed training task
    ##############################################################################
    if FLAG_TASK.value == constants.TaskChoices.train:
        with utils.log_duration(LOGGER, "main", "Load model"):
            utils.print_mem("before loading model", LOGGER)
            model_specific = task_specific.load_model(
                FLAG_MODEL_LOAD_PATH.value, FLAG_MODEL_KEY.value,
                FLAG_DISTRIBUTE_MODE.value, tpu_setup, FLAG_NUM_REPLICAS.value)
            utils.print_mem("after loading model", LOGGER)
            model_or_replicas = model_specific.model
            if isinstance(model_or_replicas, list):
                model_or_replicas: List[transformers.TFGPT2LMHeadModel]
            else:
                model_or_replicas: transformers.TFGPT2LMHeadModel

            tokenizer = model_specific.tokenizer

            def make_optimizer():
                return tensor2tensor.utils.adafactor.AdafactorOptimizer(
                    learning_rate=FLAG_LEARNING_RATE.value)

            if model_specific.strategy:
                with model_specific.strategy.scope():
                    optimizer = make_optimizer()
            else:
                optimizer = make_optimizer()

        ############################################################################
        # Prepare the dataset functions
        ############################################################################
        rg = np.random.default_rng(FLAG_RANDOM_SEED.value)

        def call_lm_preproc(repeat, split, random_seed):
            """Using functools.partial prevents the linter from doing its job."""
            if FLAG_DATASET_NAME.value == constants.DatasetNameChoices.kilt_eli5:
                return task_specific.create_lm_ds_kilt_eli5(
                    tokenizer=tokenizer,
                    context_window_size=(
                        model_or_replicas[0].config.n_positions if isinstance(
                            model_or_replicas,
                            list) else model_or_replicas.config.n_positions),
                    dataset_name=FLAG_DATASET_NAME.value,
                    # Batches are split over the replicas:
                    batch_size=FLAG_BATCH_SIZE.value * actual_num_replicas,
                    db_path=FLAG_DB_PATH.value,
                    random_seed=random_seed,
                    use_subset=FLAG_USE_SUBSET.value,
                    subset_size=FLAG_SUBSET_SIZE.value,
                    use_helper_words=FLAG_USE_HELPER_WORDS.value,
                    approach_type=FLAG_APPROACH_TYPE.value,
                    num_retrievals=FLAG_NUM_RETRIEVALS.value,
                    retrieval_temperature=FLAG_RETRIEVAL_TEMPERATURE.value,
                    retriever=retriever,
                    repeat=repeat,
                    split=split,
                    enable_debug_checks=FLAG_DATASET_DEBUG.value,
                    retrieval_bank_size=FLAG_RETRIEVAL_BANK_SIZE.value,
                    dataset_type=FLAG_DATASET_TYPE.value,
                    qty_shuffle=FLAG_QTY_SHUFFLE.value,
                    tfr_prefix=FLAG_TFR_PREFIX.value,
                    max_length_generation=FLAG_MAX_LENGTH_GENERATION.value,
                )
            else:
                raise NotImplementedError(
                    f"FLAG_DATASET_NAME.value unsupported: `{FLAG_DATASET_NAME.value}`"
                )

        make_training_dataset: Callable[Ellipsis,
                                        tf.data.Dataset] = functools.partial(
                                            call_lm_preproc,
                                            split="train",
                                            repeat=False,
                                        )
        make_eval_dataset: Callable[Ellipsis,
                                    tf.data.Dataset] = functools.partial(
                                        call_lm_preproc,
                                        split="eval",
                                        repeat=True,
                                    )

        ############################################################################
        # Prepare the step functions
        ############################################################################
        utils.check_contained(FLAG_DISTRIBUTE_MODE.value,
                              constants.DistributeModeChoices.choices())
        tf_function_flags = dict(
            experimental_compile=FLAG_EXPERIMENTAL_COMPILE.value,
            experimental_relax_shapes=not FLAG_INPUT_FIXED_SIZE.value)

        if (FLAG_DISTRIBUTE_MODE.value ==
                constants.DistributeModeChoices.split_and_data_parallel):
            if not isinstance(model_or_replicas, list):
                raise RuntimeError(type(model_or_replicas))
            training_step = build_manual_data_parallel_training_step(
                model_or_replicas, optimizer, tf_function_flags)

        else:
            training_step = build_regular_training_step(
                model_or_replicas,
                optimizer,
                strategy=model_specific.strategy,
                tf_function_kwargs=tf_function_flags)

        evaluation_step = build_evaluation_step(model_or_replicas,
                                                tf_function_flags)

        secs_since_last_ckpt = time.time()
        # Model checkpoints are saved to the tmp_directory and then rsynced to GCS
        ##########################################################################
        # Prepare the different logging facilities
        ##########################################################################
        train_log_dir = os.path.join(instance_output_dir, "tensorboard",
                                     "train")
        eval_log_dir = os.path.join(instance_output_dir, "tensorboard", "eval")
        flags_log_dir = os.path.join(instance_output_dir, "tensorboard",
                                     "params")
        writers = dict(train=tf.summary.create_file_writer(train_log_dir),
                       eval=tf.summary.create_file_writer(eval_log_dir),
                       flags=tf.summary.create_file_writer(flags_log_dir))
        with writers["flags"].as_default():
            tf.summary.text(
                "Flags",
                # Tensorboard takes Markdown:
                json.dumps(flags_dict, indent=4).replace("\n", "\n\n"),
                step=0)

        ma_loss = dict(train=utils.MovingAverage(0.9),
                       eval=utils.MovingAverage(0.9))
        step_counters = dict(train=0, eval=0)
        batch_counters = dict(train=0, eval=0)
        prev_batch_end = time.time()

        # The eval ds has no real concept of epoch, repeats forever, shuffling
        # each time it reaches its end
        with utils.log_duration(LOGGER, "main", "All of make_eval_dataset"):
            eval_ds_instance = make_eval_dataset(random_seed=rg.integers(
                -2**63, 2**63 - 1), )
        LOGGER.debug("Distributing the eval dataset to the replicas.")
        if FLAG_DATASET_TYPE.value == "tfr":
            eval_ds_instance = (
                model_specific.strategy.experimental_distribute_dataset(
                    eval_ds_instance))

        LOGGER.debug("Done distributing the eval dataset to the replcias.")
        eval_ds_instance = iter(eval_ds_instance)

        ##########################################################################
        # Training Loop
        ##########################################################################
        for epoch in itertools.count():
            ####################################################################
            # Epoch Setup
            ####################################################################
            LOGGER.debug("EPOCH %d START", epoch)
            # Shuffle differently every epoch
            with utils.log_duration(LOGGER, "main",
                                    "All of make_training_dataset"):
                train_ds_instance = make_training_dataset(
                    random_seed=rg.integers(-2**63, 2**63 - 1), )
            LOGGER.debug(
                "Attempting to distribute the training dataset to the replicas."
            )
            if FLAG_DATASET_TYPE.value == "tfr":
                train_ds_instance = (
                    model_specific.strategy.experimental_distribute_dataset(
                        train_ds_instance))

            LOGGER.debug(
                "Done distributing the training dataset to the replicas.")
            train_ds_instance = iter(train_ds_instance)

            # This allows us to see if we reached the end of the training iterator,
            # in which case "did_at_least_one_training_batch == False".
            # We could also test that it did all the batches, to similar results.
            did_at_least_one_training_batch = True
            split = "eval"
            while did_at_least_one_training_batch:
                # Invert split
                if split == "train":
                    split = "eval"
                else:
                    split = "train"

                # Prepare to test if we did at least one training batch
                if split == "train":
                    did_at_least_one_training_batch = False

                if split == "train":
                    dataset_iterator = itertools.islice(
                        train_ds_instance, FLAG_BATCHES_BETWEEN_EVALS.value)
                else:
                    # The evaluation DS is tiny, so we reshuffle and take a random
                    dataset_iterator = itertools.islice(
                        eval_ds_instance, FLAG_NUMBER_EVAL_BATCHES.value)

                LOGGER.debug("Batching")
                for batch in dataset_iterator:
                    # LOGGER.debug("Input sentence:\n\"%s\"",
                    #              tokenizer.decode([x for x in batch["input_ids"][0]
                    #                                if x != tokenizer.eos_token_id]))
                    # LOGGER.debug("Label:\n\"%s\"",
                    #              tokenizer.decode([(x if x != -100 else 0)
                    #                                for x in batch["label_ids"][0]]))

                    if FLAG_DATASET_TYPE.value != "tfr":
                        batch = (model_specific.strategy.
                                 experimental_distribute_values_from_function(
                                     tf_utils.make_dict_distribute_fn(batch)))

                    # We only care about training epochs as, obviously, we don't train
                    # over eval samples; the number of  eval samples seen only
                    # contributes to lowering the variance in the evaluation of when to
                    # do early stopping.
                    if split == "train":
                        did_at_least_one_training_batch = True

                    input_ids = batch["input_ids"]
                    label_ids = batch["label_ids"]

                    ####################################################################
                    # Training Step
                    ####################################################################
                    step_counters[split] += (FLAG_BATCH_SIZE.value *
                                             actual_num_replicas)

                    if split == "train":
                        batch_counters[split] += 1
                        training_kwargs = dict(
                            input_ids=input_ids,
                            label_ids=label_ids,
                        )

                        if model_specific.strategy:
                            utils.print_mem("before running", LOGGER)

                            LOGGER.debug("Training, Calling strategy.run")
                            loss = model_specific.strategy.run(
                                training_step, kwargs=training_kwargs)
                            LOGGER.debug("Training, Done with strategy.run")
                            utils.print_mem("after running", LOGGER)

                        else:
                            loss = training_step(**training_kwargs)  # pytype: disable=wrong-arg-count
                            # If we are in the strategy-free data parallel mode, we need
                            # to change the weights of all replicas to those of the model at
                            # index 0
                            if (FLAG_DISTRIBUTE_MODE.value ==
                                    constants.DistributeModeChoices.
                                    split_and_data_parallel):
                                for replica in model_or_replicas[1:]:
                                    replica.set_weights(
                                        model_or_replicas[0].get_weights())

                    ####################################################################
                    # Evaluation Step
                    ####################################################################
                    elif split == "eval":
                        evaluation_kwargs = dict(
                            input_ids=input_ids,
                            label_ids=label_ids,
                        )

                        if model_specific.strategy:
                            loss = model_specific.strategy.run(
                                evaluation_step, kwargs=evaluation_kwargs)
                        else:
                            loss = evaluation_step(**evaluation_kwargs)
                    else:
                        raise ValueError(
                            f"Unexpected value for split: {split}")

                    ####################################################################
                    # Logging
                    ####################################################################
                    if (FLAG_DISTRIBUTE_MODE.value
                            in constants.PURE_DATA_PARALLEL_STRATEGIES):
                        utils.check_equal(len(loss.values),
                                          actual_num_replicas)
                        LOGGER.debug("Split: %s", split)
                        LOGGER.debug("Real num replicas: %s",
                                     actual_num_replicas)
                        LOGGER.debug("Loss: %s", loss)
                        LOGGER.debug("Loss values: %s", loss.values)

                        average_loss = float(
                            tf.math.reduce_mean(loss.values).numpy())
                    else:
                        average_loss = float(loss.numpy())

                    # tf.debugging.check_numerics(loss)
                    now = time.time()
                    batch_duration = now - prev_batch_end
                    prev_batch_end = now
                    ma_loss[split].update(average_loss)

                    # Actual logging
                    LOGGER.info("Epoch: # %d", epoch)
                    LOGGER.info("Tensorboard_dir: %s", instance_output_dir)
                    LOGGER.info("Batch: %s # %d", split, batch_counters[split])
                    LOGGER.info("Step: %s # %d", split, step_counters[split])
                    if FLAG_USE_SUBSET.value:
                        LOGGER.warning(">> USING A SUBSET OF THE DATASET <<")
                    LOGGER.info("%(split)s Batch loss:           %(metric)f",
                                dict(split=split, metric=average_loss))
                    LOGGER.info(
                        "%(split)s Moving average loss:  %(metric)f",
                        dict(split=split, metric=ma_loss[split].average))
                    LOGGER.info(
                        "%(split)s Moving average ppl:   %(metric)f",
                        dict(split=split,
                             metric=np.exp(ma_loss[split].average)))
                    LOGGER.info(
                        "%(split)s Batch duration:       %(duration)s",
                        dict(split=split,
                             duration=utils.TimeStamp.from_seconds(
                                 batch_duration).format()))
                    if FLAG_DISTRIBUTE_MODE.value in constants.DATA_PARALLEL_DMC:
                        LOGGER.info(
                            "%(split)s Duration per sample:  %(duration)s",
                            dict(split=split,
                                 duration=utils.TimeStamp.from_seconds(
                                     batch_duration / (FLAG_BATCH_SIZE.value *
                                                       actual_num_replicas))))

                    # Write to Tensorboard
                    with writers[split].as_default():
                        tf.summary.scalar(f"Loss/{split}", average_loss,
                                          step_counters[split])
                        tf.summary.scalar(f"PPL/{split}", np.exp(average_loss),
                                          step_counters[split])
                    writers[split].flush()

                    # Save every 5 min
                    if (time.time() - secs_since_last_ckpt) / (60 * 20) >= 1:
                        secs_since_last_ckpt = time.time()
                        save_model(train_steps=step_counters["train"],
                                   model_or_replicas=model_or_replicas,
                                   instance_output_dir=instance_output_dir)

                secs_since_last_ckpt = time.time()
                save_model(train_steps=step_counters["train"],
                           model_or_replicas=model_or_replicas,
                           instance_output_dir=instance_output_dir)
        #############################################################
        # Post Training Cleanup
        #######################################################################
        for writer in writers.values():
            writer.close()
示例#5
0
    def train(self):
        """
        Main actor-learner loop for parallel advantage actor-critic learning.

        Each iteration runs ``self.rollout_steps`` steps on every emulator of
        ``self.batch_env``, updates the network via ``self.update_weights`` and
        periodically logs, evaluates and checkpoints, until
        ``self.global_step`` reaches ``self.total_steps``.

        When ``self.curr_learning`` is truthy and an evaluation reports
        ``stats.final_res > 0.95``, ``self.change_length_labyrinth()`` is
        called (curriculum learning: enlarge the labyrinth length).
        """
        logging.info('Starting training at step %d', self.global_step)
        logging.debug('Device: %s', self.device)

        counter = 0
        global_step_start = self.global_step
        average_loss = utils.MovingAverage(
            0.01, ['actor', 'critic', 'entropy', 'grad_norm'])

        total_rewards, training_stats = [], []

        num_emulators = self.batch_env.num_emulators
        total_episode_rewards = np.zeros(num_emulators)

        # Environment steps consumed by one update (one rollout on every emulator).
        steps_per_update = num_emulators * self.rollout_steps
        # Number of updates between logging / evaluation. Clamped to 1 so that
        # a print_every / eval_every smaller than one rollout means "every
        # update" instead of a ZeroDivisionError in the modulo below.
        print_interval = max(1, self.print_every // steps_per_update)
        eval_interval = max(1, self.eval_every // steps_per_update)

        # mask_t[i] is 0.0 if the episode in the i-th emulator has just started,
        # otherwise 1.0; it is used to cut rnn_state and episode rewards
        # between episodes.
        mask_t = th.zeros(num_emulators).to(self.device)

        # Feedforward networks also receive rnn_state; it is just empty for them.
        rnn_state = self.network.init_rnn_state(num_emulators)

        states, infos = self.batch_env.reset_all()
        self.batch_env.set_difficulty(self.starting_length)

        if self.evaluate is not None:
            stats = self.evaluate(self.network)
            training_stats.append((self.global_step, stats))

        start_time = time.time()
        while self.global_step < self.total_steps:
            loop_start_time = time.time()
            values, log_probs, rewards, entropies, masks = [], [], [], [], []
            # Presumably detaches rnn_state from the previous rollout's graph.
            self.network.detach_rnn_state(rnn_state)

            for _ in range(self.rollout_steps):
                outputs = self.choose_action(states, infos,
                                             mask_t.unsqueeze(1), rnn_state)
                a_t, v_t, log_probs_t, entropy_t, rnn_state = outputs
                states, rs, dones, infos = self.batch_env.next(a_t)

                tensor_rs = th.from_numpy(self.reshape_r(rs)).to(self.device)
                rewards.append(tensor_rs)
                entropies.append(entropy_t)
                log_probs.append(log_probs_t)
                values.append(v_t)

                # 1.0 if the episode is not done, 0.0 otherwise
                # (dones.dtype is expected to be np.float32).
                mask_t = 1.0 - th.from_numpy(dones).to(self.device)
                masks.append(mask_t)

                done_mask = dones.astype(bool)
                total_episode_rewards += rs

                if done_mask.any():
                    total_rewards.extend(total_episode_rewards[done_mask])
                    total_episode_rewards[done_mask] = 0.

            next_v = self.predict_values(states, infos, mask_t.unsqueeze(1),
                                         rnn_state)

            update_stats = self.update_weights(next_v, rewards, masks, values,
                                               log_probs, entropies)
            average_loss.update(**update_stats)

            self.global_step += steps_per_update
            counter += 1

            if counter % print_interval == 0:
                curr_time = time.time()
                self._training_info(
                    total_rewards=total_rewards,
                    average_speed=(self.global_step - global_step_start) /
                    (curr_time - start_time),
                    loop_speed=steps_per_update /
                    (curr_time - loop_start_time),
                    update_stats=average_loss)

            if counter % eval_interval == 0 and self.evaluate is not None:
                stats = self.evaluate(self.network)
                if stats.final_res > 0.95:
                    logging.info('stats.final_res: %s', stats.final_res)
                    # Curriculum learning: once final_res exceeds 0.95,
                    # enlarge the labyrinth length.
                    if self.curr_learning:
                        logging.info('self.curr_learning: %s',
                                     self.curr_learning)
                        self.change_length_labyrinth()
                training_stats.append((self.global_step, stats))

            if self.global_step - self.last_saving_step >= self.save_every:
                self._save_progress(self.checkpoint_dir,
                                    summaries=training_stats,
                                    is_best=False)
                training_stats = []
                self.last_saving_step = self.global_step

        # Pass the evaluation stats collected since the last periodic save so
        # they are not lost by the final checkpoint.
        self._save_progress(self.checkpoint_dir,
                            summaries=training_stats,
                            is_best=False)
        logging.info('Training ended at step %d', self.global_step)
示例#6
0
def test_moving_average():
    """Feeding a constant value must keep the moving average at that value."""
    moving_avg = utils.MovingAverage(0.9)
    for _ in range(3):
        assert moving_avg.update(10) == 10
    assert ma.update(10) == 10
示例#7
0
    def train(self):
        """
        Main actor-learner loop for parallel advantage actor-critic learning.

        Each iteration runs ``max_local_steps`` steps on all ``num_envs``
        emulators, computes n-step returns bootstrapped from the critic's
        value of the last state, performs one optimizer step and periodically
        logs, evaluates and checkpoints, until ``self.global_step`` reaches
        ``max_global_steps``.
        """
        logging.info('Starting training at step %d', self.global_step)
        logging.debug('use_cuda == %s', self.use_cuda)

        counter = 0
        global_step_start = self.global_step
        average_loss = utils.MovingAverage(0.01, ['total', 'actor', 'critic'])
        total_rewards, training_stats = [], []

        if self.eval_func is not None:
            stats = self.evaluate(verbose=True)
            training_stats.append((self.global_step, stats))

        num_emulators = self.args['num_envs']
        max_local_steps = self.args['max_local_steps']
        max_global_steps = self.args['max_global_steps']
        clip_norm = self.args['clip_norm']
        # Environment steps consumed by one update.
        rollout_steps = num_emulators * max_local_steps
        # Updates between logging / evaluation, clamped to 1 to avoid a
        # ZeroDivisionError when print_every / eval_every < rollout_steps.
        print_interval = max(1, self.print_every // rollout_steps)
        eval_interval = max(1, self.eval_every // rollout_steps)

        states, infos = self.batch_env.reset_all()

        total_episode_rewards = np.zeros(num_emulators)
        # not_done_masks[t, i] is 0.0 if emulator i finished its episode at
        # local step t, 1.0 otherwise.
        not_done_masks = torch.zeros(max_local_steps, num_emulators).type(
            self._tensors.FloatTensor)
        if self.use_rnn:
            hx_init, cx_init = self.network.get_initial_state(num_emulators)
            hx, cx = hx_init, cx_init
        else:  # feedforward nets just ignore this argument
            hx, cx = None, None

        start_time = time.time()
        while self.global_step < max_global_steps:
            loop_start_time = time.time()
            values, log_probs, rewards, entropies = [], [], [], []
            if self.use_rnn:
                # Detach the recurrent state so that backward() does not try to
                # propagate into the previous rollout's graph.
                hx, cx = hx.detach(), cx.detach()

            for t in range(max_local_steps):
                outputs = self.choose_action(states, infos, (hx, cx))
                a_t, v_t, log_probs_t, entropy_t, (hx, cx) = outputs
                states, rs, dones, infos = self.batch_env.next(a_t)

                # Clipped rewards feed the return computation; the episode
                # statistics below accumulate the raw rewards.
                rewards.append(np.clip(rs, -1., 1.))
                entropies.append(entropy_t)
                log_probs.append(log_probs_t)
                values.append(v_t)
                is_done = torch.from_numpy(dones).type(
                    self._tensors.FloatTensor)
                not_done_masks[t] = 1.0 - is_done

                done_mask = dones.astype(bool)
                total_episode_rewards += rs

                total_rewards.extend(total_episode_rewards[done_mask])
                total_episode_rewards[done_mask] = 0.
                if self.use_rnn and done_mask.any():
                    # Reset the LSTM states of the terminated emulators.
                    done_idx = is_done.nonzero().view(-1)
                    # hx/cx are used by the backward pass, so they must not be
                    # modified in-place: work on clones.
                    hx, cx = hx.clone(), cx.clone()
                    hx[done_idx, :] = hx_init[done_idx, :].detach()
                    cx[done_idx, :] = cx_init[done_idx, :].detach()

            self.global_step += rollout_steps
            next_v = self.predict_values(states, infos, (hx, cx))
            R = next_v.detach().view(-1)

            # Returns are accumulated backwards in time, but delta_v is stored
            # by time index so that, after concatenation, delta_v[t] lines up
            # with log_probs[t] and entropies[t] (both in forward order).
            delta_v = [None] * max_local_steps
            for t in reversed(range(max_local_steps)):
                rs = torch.from_numpy(rewards[t]).type(
                    self._tensors.FloatTensor)
                R = rs + self.gamma * R * not_done_masks[t]
                delta_v[t] = R - values[t].view(-1)

            loss, actor_loss, critic_loss = self.compute_loss(
                torch.cat(delta_v, 0),
                torch.cat(log_probs, 0).view(-1),
                torch.cat(entropies, 0).view(-1))

            self.lr_scheduler.adjust_learning_rate(self.global_step)
            self.optimizer.zero_grad()
            loss.backward()
            global_norm = self.clip_gradients(self.network.parameters(),
                                              clip_norm)
            self.optimizer.step()

            average_loss.update(total=loss.item(),
                                actor=actor_loss.item(),
                                critic=critic_loss.item())

            counter += 1
            if counter % print_interval == 0:
                curr_time = time.time()
                self._training_info(
                    total_rewards=total_rewards,
                    average_speed=(self.global_step - global_step_start) /
                    (curr_time - start_time),
                    loop_speed=rollout_steps / (curr_time - loop_start_time),
                    moving_averages=average_loss,
                    grad_norms=global_norm)

            if counter % eval_interval == 0 and self.eval_func is not None:
                stats = self.evaluate(verbose=True)
                training_stats.append((self.global_step, stats))

            if self.global_step - self.last_saving_step >= self.save_every:
                self._save_progress(self.checkpoint_dir,
                                    summaries=training_stats,
                                    is_best=False)
                training_stats = []
                self.last_saving_step = self.global_step

        # Pass the evaluation stats collected since the last periodic save so
        # they are not lost by the final checkpoint.
        self._save_progress(self.checkpoint_dir,
                            summaries=training_stats,
                            is_best=False)
        logging.info('Training ended at step %d', self.global_step)
示例#8
0
def main(argv):
    ##############################################################################
    # Initial Setup. Logging, Flags, Random seeds.
    ##############################################################################
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")
    absl_logging.use_python_logging()
    flags_dict = {
        flag.name: flag.value
        for flag in FLAGS.flags_by_module_dict()[argv[0]]
    }

    if FLAGS.use_subset:
        message = (f"{colorama.Back.RED}{colorama.Fore.WHITE}"
                   f"{colorama.Style.BRIGHT}USING A SUBSET OF THE DATASET"
                   f"{colorama.Style.RESET_ALL}")
        LOGGER.warning(message)

    utils.log_module_args(LOGGER, argv[0])
    if not FLAGS.output_dir.startswith("gs://"):
        utils.check_exists(FLAG_OUTPUT_DIR.value)
        if not tf.io.gfile.isdir(FLAG_OUTPUT_DIR.value):
            raise RuntimeError("Output dir needs to be a directory.")

    tf.random.set_seed(FLAG_RANDOM_SEED.value)
    np.random.seed(FLAG_RANDOM_SEED.value)

    # Prepare the instance output directory path and save the config there
    # Prepare the path
    folder_name = time.strftime(
        f"{FLAG_RUN_NAME.value}_{FLAG_APPROACH_TYPE.value}_%Y%m%d-%H%M%S")
    instance_output_dir = os.path.join(FLAG_OUTPUT_DIR.value,
                                       folder_name).strip()
    if not instance_output_dir.endswith("/"):
        instance_output_dir += "/"
    json_target = os.path.join(instance_output_dir, "training_params.json")

    # Make the folder if we're not on gcloud
    if not json_target.strip().startswith("gs://"):
        subprocess.check_call(["mkdir", "-p", instance_output_dir])

    # Safe the config file
    utils.to_json_file(json_target, flags_dict)

    ##############################################################################
    # Initialization and Configuration of the Devices.
    ##############################################################################
    tpu_setup = None

    accel = tf_utils.current_accelerator_type()
    if FLAG_TPU_IS_LOCAL.value:
        assert accel == "TPU", accel
    if accel == "TPU":
        assert FLAG_TPU_IS_LOCAL.value, FLAG_TPU_IS_LOCAL.value

    if tf_utils.current_accelerator_type() in {"CPU", "TPU"}:
        tpu_setup = tf_utils.init_tpus(tpu_name=FLAG_TPU_NAME.value,
                                       local=FLAG_TPU_IS_LOCAL.value)

    LOGGER.debug("Devices we are computing on:\n%s",
                 utils.wrap_iterable(map(str, tf_utils.devices_to_use())))
    LOGGER.debug("All devices:")
    LOGGER.debug(tf_utils.device_mapping())

    if tf_utils.current_accelerator_type() == "GPU":
        tf.config.set_soft_device_placement(True)

    if tf_utils.current_accelerator_type() != "TPU":
        tf.debugging.set_log_device_placement(True)

    utils.check_operator(operator.ne, tf_utils.current_accelerator_type(),
                         "CPU")

    assert FLAG_TPU_NAME.value == socket.gethostname(), (
        "This is a configuration choice. You can remove this. "
        "There will be no side effects.")

    if FLAG_DISTRIBUTE_MODE.value in constants.PURE_DATA_PARALLEL_STRATEGIES:
        actual_num_replicas = len(tf_utils.devices_to_use())
    elif FLAG_DISTRIBUTE_MODE.value in constants.DATA_PARALLEL_DMC:
        actual_num_replicas = FLAG_NUM_REPLICAS.value
    else:
        actual_num_replicas = 1

    ##############################################################################
    # We load the retriever model if it is needed.
    ##############################################################################
    # Not currently used. See old commits.
    retriever = None

    ##############################################################################
    # Distributed training task
    ##############################################################################
    if FLAG_TASK.value == constants.TaskChoices.train:
        with utils.log_duration(LOGGER, "main", "Load model"):
            utils.print_mem("before loading model", LOGGER)
            model_specific = task_specific.load_model(
                FLAG_MODEL_KEY.value, FLAG_DISTRIBUTE_MODE.value, tpu_setup,
                FLAG_NUM_REPLICAS.value)
            utils.print_mem("after loading model", LOGGER)
            model = model_specific.model
            if isinstance(model, list):
                model: List[transformers.TFGPT2LMHeadModel]
            else:
                model: transformers.TFGPT2LMHeadModel

            tokenizer = model_specific.tokenizer

            def make_optimizer():
                if FLAG_OPTIMIZER_TYPE.value == constants.OptimizerTypes.adafactor:
                    return tensor2tensor.utils.adafactor.AdafactorOptimizer(
                        learning_rate=FLAG_LEARNING_RATE.value)
                elif FLAG_OPTIMIZER_TYPE.value == constants.OptimizerTypes.adam:
                    return tf.keras.optimizers.Adam(
                        learning_rate=FLAG_LEARNING_RATE.value)
                else:
                    raise ValueError(FLAG_OPTIMIZER_TYPE.value)

            if model_specific.strategy:
                with model_specific.strategy.scope():
                    optimizer = make_optimizer()
            else:
                optimizer = make_optimizer()

        ############################################################################
        # Prepare the dataset functions
        ############################################################################
        rg = np.random.default_rng(FLAG_RANDOM_SEED.value)

        def call_lm_preproc(repeat, split, random_seed):
            """Using functools.partial prevents the linter from doing its job."""
            if FLAG_DATASET_NAME.value == constants.DatasetNameChoices.kilt_eli5:
                return task_specific.create_lm_ds_kilt_eli5(
                    tokenizer=tokenizer,
                    context_window_size=model.config.n_positions,
                    dataset_name=FLAG_DATASET_NAME.value,
                    # Batches are split over the replicas:
                    batch_size=FLAG_BATCH_SIZE.value * actual_num_replicas,
                    db_path=FLAG_DB_PATH.value,
                    random_seed=random_seed,
                    use_subset=FLAG_USE_SUBSET.value,
                    subset_size=FLAG_SUBSET_SIZE.value,
                    use_helper_words=FLAG_USE_HELPER_WORDS.value,
                    approach_type=FLAG_APPROACH_TYPE.value,
                    num_retrievals=FLAG_NUM_RETRIEVALS.value,
                    retrieval_temperature=FLAG_RETRIEVAL_TEMPERATURE.value,
                    retriever=retriever,
                    repeat=repeat,
                    split=split,
                    enable_debug_checks=FLAG_DATASET_DEBUG.value,
                    retrieval_bank_size=FLAG_RETRIEVAL_BANK_SIZE.value,
                    dataset_type=FLAG_DATASET_TYPE.value,
                    qty_shuffle=FLAG_QTY_SHUFFLE.value,
                    tfr_prefix=FLAG_TFR_PREFIX.value,
                    max_length_generation=FLAG_MAX_LENGTH_GENERATION.value,
                )
            else:
                raise NotImplementedError(
                    f"FLAG_DATASET_NAME.value unsupported: `{FLAG_DATASET_NAME.value}`"
                )

        make_training_dataset: Callable[...,
                                        tf.data.Dataset] = functools.partial(
                                            call_lm_preproc,
                                            split="train",
                                            repeat=False,
                                        )
        make_eval_dataset: Callable[..., tf.data.Dataset] = functools.partial(
            call_lm_preproc,
            split="eval",
            repeat=True,
        )

        ############################################################################
        # Prepare the step functions
        ############################################################################
        utils.check_contained(FLAG_DISTRIBUTE_MODE.value,
                              constants.DistributeModeChoices.choices())
        tf_function_flags = dict(
            experimental_compile=FLAG_EXPERIMENTAL_COMPILE.value,
            experimental_relax_shapes=not FLAG_INPUT_FIXED_SIZE.value)

        training_step = build_regular_training_step(
            model,
            optimizer,
            strategy=model_specific.strategy,
            tf_function_kwargs=tf_function_flags)

        evaluation_step = build_evaluation_step(model, tf_function_flags)

        timestamp_last_ckpt_secs = time.time()
        # Model checkpoints are saved to the tmp_directory and then rsynced to GCS

        ############################################################################
        # Prepare the statistics and the logging facilities.
        ############################################################################
        # Tensorboard
        with model_specific.strategy.scope():
            checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
        saver = Saver(instance_output_dir, checkpoint)
        train_log_dir = os.path.join(instance_output_dir, "tensorboard",
                                     "train")
        eval_log_dir = os.path.join(instance_output_dir, "tensorboard", "eval")
        flags_log_dir = os.path.join(instance_output_dir, "tensorboard",
                                     "params")
        writers = dict(train=tf.summary.create_file_writer(train_log_dir),
                       eval=tf.summary.create_file_writer(eval_log_dir),
                       flags=tf.summary.create_file_writer(flags_log_dir))
        with writers["flags"].as_default():
            tf.summary.text(
                "Flags",
                # Tensorboard takes Markdown:
                json.dumps(flags_dict, indent=4).replace("\n", "\n\n"),
                step=0)

        # Different information to log.
        ma_loss = dict(train=utils.MovingAverage(0.9),
                       eval=utils.MovingAverage(0.9))
        step_counters = dict(train=0, eval=0)
        batch_counters = dict(train=0, eval=0)
        prev_batch_end = time.time()

        ############################################################################
        # Create the Eval DS object.
        # ==========================================================================
        # The eval ds has no real concept of epoch, repeats forever, shuffling
        # each time it reaches its end.
        ############################################################################
        # Create
        with utils.log_duration(LOGGER, "main", "All of make_eval_dataset"):
            eval_ds_instance = make_eval_dataset(random_seed=rg.integers(
                -2**63, 2**63 - 1), )
        # Maybe distribute
        LOGGER.debug("Distributing the eval dataset to the replicas.")
        if FLAG_DATASET_TYPE.value == "tfr":
            eval_ds_instance = (
                model_specific.strategy.experimental_distribute_dataset(
                    eval_ds_instance))
        # Start the iteration. We step by calling `next(...)`.
        LOGGER.debug("Done distributing the eval dataset to the replicas.")
        eval_ds_instance = iter(eval_ds_instance)
        step_function = dict(train=training_step, eval=evaluation_step)

        ############################################################################
        # Training Loop
        # ==========================================================================
        # Create a new training dataset object that lasts for one epoch.
        # This is different from the eval training dataset object, which loops
        # forever.
        ############################################################################
        for epoch in itertools.count():
            ##########################################################################
            # Epoch Setup
            ##########################################################################
            LOGGER.debug("EPOCH %d START", epoch)
            # Shuffle differently every epoch
            with utils.log_duration(LOGGER, "main",
                                    "All of make_training_dataset"):
                train_ds_instance = make_training_dataset(
                    random_seed=rg.integers(-2**63, 2**63 - 1), )
            LOGGER.debug(
                "Attempting to distribute the training dataset to the replicas."
            )
            if FLAG_DATASET_TYPE.value == "tfr":
                train_ds_instance = (
                    model_specific.strategy.experimental_distribute_dataset(
                        train_ds_instance))

            LOGGER.debug(
                "Done distributing the training dataset to the replicas.")
            train_ds_instance = iter(train_ds_instance)

            # To change splits, we use `itertools.islice` over the dataset generator.
            # When the training dataset generator is done, a new loop of the following
            # while loop occurs, but no training batch is done because we are taking
            # an `islice` of a generator that is done.
            did_at_least_one_training_batch = True
            split = "eval"
            while did_at_least_one_training_batch:
                utils.check_operator(operator.ne,
                                     tf_utils.current_accelerator_type(),
                                     "CPU")

                # Invert split
                if split == "train":
                    split = "eval"
                else:
                    split = "train"

                # Prepare to test if we did at least one training batch
                if split == "train":
                    did_at_least_one_training_batch = False

                ########################################################################
                # Take slices from the dataset iterator
                # ======================================================================
                # We only want to do a certain number of batches before switching splits
                # We do this by using an `itertools.islice` of the dataset iterators.
                ########################################################################
                if split == "train":
                    dataset_iterator = toolz.take(
                        FLAG_BATCHES_BETWEEN_EVALS.value, train_ds_instance)
                else:
                    # The evaluation dataset generator is infinite, reshuffles everytime
                    # it gets to its end.
                    # Still, we take a fixed size slice form that infinite generator.
                    dataset_iterator = toolz.take(
                        FLAG_NUMBER_EVAL_BATCHES.value, eval_ds_instance)

                LOGGER.debug("Batching")
                for batch in dataset_iterator:
                    if FLAG_LOG_SAMPLES.value:
                        ####################################################################
                        # Print elements of the dataset
                        ####################################################################
                        # Make ourselves resistant to values possibly being a PerReplica
                        # object
                        LOGGER.warning(
                            f"%(red)sLOGGING SAMPLES. THIS IS VERY SLOW.%(reset)s",
                            dict(
                                red=colorama.Fore.RED,
                                reset=colorama.Style.RESET_ALL,
                            ))
                        is_distributed = isinstance(batch["input_ids"],
                                                    values.PerReplica)
                        for in_batch_idx in range(FLAG_BATCH_SIZE.value):
                            for replica_idx in (range(actual_num_replicas)
                                                if is_distributed else [0]):
                                if is_distributed:
                                    sample = {
                                        k: batch[k].values[replica_idx]
                                        for k in batch
                                    }
                                else:
                                    sample = batch

                                # input_sentence = tokenizer.decode(
                                #   [x for x in sample["input_ids"][i] if x != tokenizer.eos_token_id]
                                # )

                                # LOGGER.debug(
                                #   "%sInput [%d / %d]%s:\n\"%s\"",
                                #   colorama.Fore.GREEN,
                                #   replica_idx + 1,
                                #   actual_num_replicas,
                                #   colorama.Style.RESET_ALL,
                                #   input_sentence,
                                # )
                                #
                                # answer = tokenizer.decode(
                                #   [(x if x != -100 else 0) for x in sample["label_ids"][i]]
                                # )
                                # LOGGER.debug(
                                #   "%sLabel [%d / %d]%s:\n\"%s\"",
                                #   colorama.Fore.GREEN,
                                #   replica_idx + 1,
                                #   actual_num_replicas,
                                #   colorama.Style.RESET_ALL,
                                #   answer,
                                # )

                                cons = console.Console()
                                sentences = table.Table()
                                sentences.add_column("BPE Index",
                                                     justify="center")
                                sentences.add_column("Inputs",
                                                     justify="center")
                                sentences.add_column("Labels",
                                                     justify="center")
                                for bpe_idx, (x, y) in enumerate(
                                        itertools.zip_longest(
                                            sample["input_ids"]
                                            [in_batch_idx].numpy(),
                                            sample["label_ids"]
                                            [in_batch_idx].numpy(),
                                            fillvalue=None,
                                        )):
                                    x_w = tokenizer.decode(
                                        [x]) if x >= 0 else f"[ {x} ]"
                                    y_w = tokenizer.decode(
                                        [y]) if y >= 0 else f"[ {y} ]"
                                    sentences.add_row(str(bpe_idx), x_w, y_w)

                                cons.print(sentences)

                    # We only care about training epochs here, since we don't train
                    # on eval samples; the number of eval samples seen only
                    # contributes to lowering the variance of the evaluation used to
                    # decide when to do early stopping.
                    if split == "train":
                        did_at_least_one_training_batch = True

                    input_ids = batch["input_ids"]
                    label_ids = batch["label_ids"]

                    # Per split step counter
                    step_counters[
                        split] += FLAG_BATCH_SIZE.value * actual_num_replicas
                    batch_counters[split] += 1

                    ######################################################################
                    # Model step function.
                    ######################################################################
                    step_function_kwargs = dict(
                        input_ids=input_ids,
                        label_ids=label_ids,
                    )

                    utils.print_mem(f"[{split}] - Mem before `strategy.run`",
                                    LOGGER)
                    LOGGER.debug("[%s] - Calling `strategy.run`", split)
                    loss = model_specific.strategy.run(
                        step_function[split], kwargs=step_function_kwargs)
                    LOGGER.debug("[%s] - Done `strategy.run`", split)
                    utils.print_mem(f"[{split}] - Mem after `strategy.run`",
                                    LOGGER)

                    ####################################################################
                    # End of logging step code / Logging and saving the model.
                    ####################################################################
                    if (FLAG_DISTRIBUTE_MODE.value
                            in constants.PURE_DATA_PARALLEL_STRATEGIES):
                        utils.check_equal(len(loss.values),
                                          actual_num_replicas)
                        LOGGER.debug("[%s] - Real num replicas: %s", split,
                                     actual_num_replicas)
                        average_loss = float(
                            tf.math.reduce_mean(loss.values).numpy())

                        LOGGER.debug("[%s] - Loss: %s", str(split),
                                     str(average_loss))

                    else:
                        average_loss = float(loss.numpy())

                    tf.debugging.check_numerics(
                        loss.values if isinstance(loss, values.PerReplica) else
                        loss, "Numerics failed.")

                    now = time.time()
                    batch_duration = now - prev_batch_end
                    prev_batch_end = now
                    ma_loss[split].update(average_loss)

                    LOGGER.info("[%s] - Epoch: # %d", split, epoch)
                    LOGGER.info("[%s] - Tensorboard_dir: %s", split,
                                instance_output_dir)
                    LOGGER.info("[%s] - Batch: # %d", split,
                                batch_counters[split])
                    LOGGER.info("[%s] - Step:  # %d", split,
                                step_counters[split])
                    if FLAG_USE_SUBSET.value:
                        LOGGER.warning(">> USING A SUBSET OF THE DATASET <<")
                    LOGGER.info(
                        "[%(split)s] - Batch loss:           %(metric)f",
                        dict(split=split, metric=average_loss))
                    LOGGER.info(
                        "[%(split)s] - Moving average loss:  %(metric)f",
                        dict(split=split, metric=ma_loss[split].average))
                    LOGGER.info(
                        "[%(split)s] - Moving average ppl:   %(metric)f",
                        dict(split=split,
                             metric=np.exp(ma_loss[split].average)))
                    LOGGER.info(
                        "[%(split)s] - Batch duration:       %(duration)s",
                        dict(split=split,
                             duration=utils.TimeStamp.from_seconds(
                                 batch_duration).format()))

                    # Write to Tensorboard
                    with writers[split].as_default():
                        tf.summary.scalar(f"Loss/{split}", average_loss,
                                          step_counters[split])
                        tf.summary.scalar(f"PPL/{split}", np.exp(average_loss),
                                          step_counters[split])
                    writers[split].flush()

                    ######################################################################
                    # Save every `FLAG_SAVE_PERIOD_MIN.value` minutes.
                    ######################################################################
                    delta_sec = time.time() - timestamp_last_ckpt_secs
                    utils.check_operator(operator.gt, delta_sec, 0)
                    period_sec = 60 * FLAG_SAVE_PERIOD_MIN.value
                    utils.check_operator(operator.gt, period_sec, 0)
                    ratio = delta_sec / period_sec
                    LOGGER.info(
                        "[%(split)s] - RATIO:                  %(ratio)s",
                        dict(split=split, ratio=str(ratio)))
                    LOGGER.info(
                        "[%(split)s] - Target: %(target)s, Present: %(present)s",
                        dict(
                            split=split,
                            target=str(period_sec),
                            present=str(delta_sec),
                        ))

                    if ratio >= 1:
                        dur = delta_sec / 60
                        timestamp_last_ckpt_secs = time.time()
                        LOGGER.debug(
                            "SAVING MODEL - CAUSE: DURATION - %0.2f min", dur)
                        # checkpoint.save(ckpt_prefix)
                        saver.save_model(
                            train_steps=step_counters["train"],
                            model_or_replicas=model,
                            optimizer=optimizer,
                        )

        ############################################################################
        # Post Training Cleanup
        ############################################################################
        for writer in writers.values():
            writer.close()