def __init__(self, args):
    self.quantity = args.quantity
    self.max_purchase_quantity = args.max_purchase_quantity
    self.amount_bin_size = args.amount_bin_size
    self.state_bin_size = args.state_bin_size
    self.rate = utils.get_interest_rate()
    self.price = args.price
    # TODO: Should epsilon be kept separately per state, or per action?
    # TODO: Keep Q_Epsilon separately per State,
    # TODO: and P_Epsilon separately per (State, Action)?
    self.q_eps = 1.0
    self.p_eps = 1.0
    self.q_eps_decay = args.q_eps_decay
    self.p_eps_decay = args.p_eps_decay
    self.window = args.window
    self.stack_to_state = self.create_stack_to_state()
    self.benefit_tables = {
        state: self.create_benefit_table(stack)
        for stack, state in self.stack_to_state.items()
    }
    self.times = {
        state: utils.MovingAverage(self.window)
        for state in self.stack_to_state.values()
    }
    self.uri = "http://localhost:3000"
    self.headers = {'Content-type': 'application/json'}
    self.id = None
    self.query_minimum = args.query_minimum
    self.query_diff = args.query_diff
    self.query_std = args.query_std
def create_benefit_table(self, stack):
    # TODO: Would it make sense to give each agent a different maximum purchase
    # quantity limit?
    max_n_actions = (self.max_purchase_quantity // self.amount_bin_size) + 1
    n_actions = (stack // self.amount_bin_size) + 1
    n_actions = min(max_n_actions, n_actions)
    benefit_table = [
        utils.MovingAverage(self.window) for _ in range(n_actions)
    ]
    return benefit_table
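# The benefit tables and per-state timing statistics above are built from
# `utils.MovingAverage(self.window)`, whose implementation is not shown in this
# file. Below is a minimal sketch of a fixed-window moving average with that
# constructor signature; the interface mirrors the usage above, but the body is
# an assumption, not the repository's actual class.
from collections import deque


class WindowedMovingAverage:
    """Average over the last `window` samples (hypothetical sketch)."""

    def __init__(self, window):
        # Oldest samples are dropped automatically once the window is full.
        self.values = deque(maxlen=window)

    def update(self, value):
        self.values.append(value)
        return self.average

    @property
    def average(self):
        return sum(self.values) / len(self.values) if self.values else 0.0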
def main(argv):
  del argv  # unused arg
  tf.io.gfile.makedirs(FLAGS.output_dir)
  logging.info('Saving checkpoints at %s', FLAGS.output_dir)
  tf.random.set_seed(FLAGS.seed)

  batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
  steps_per_epoch = APPROX_IMAGENET_TRAIN_IMAGES // batch_size
  steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

  if FLAGS.use_gpu:
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()
  else:
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

  width_coefficient, depth_coefficient, input_image_size, dropout_rate = (
      efficientnet_model.efficientnet_params(FLAGS.model_name))
  imagenet_train = utils.ImageNetInput(
      is_training=True,
      use_bfloat16=FLAGS.use_bfloat16,
      data_dir=FLAGS.data_dir,
      batch_size=FLAGS.per_core_batch_size,
      image_size=input_image_size,
      normalize_input=True,
      one_hot=True)
  imagenet_eval = utils.ImageNetInput(
      is_training=False,
      use_bfloat16=FLAGS.use_bfloat16,
      data_dir=FLAGS.data_dir,
      batch_size=batch_size,
      image_size=input_image_size,
      normalize_input=True,
      one_hot=True)
  train_dataset = strategy.experimental_distribute_datasets_from_function(
      imagenet_train.input_fn)
  test_datasets = {
      'clean': strategy.experimental_distribute_dataset(
          imagenet_eval.input_fn()),
  }
  train_iterator = iter(train_dataset)
  test_iterator = iter(test_datasets['clean'])

  if FLAGS.use_bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.output_dir, 'summaries'))

  with strategy.scope():
    logging.info('Building %s model', FLAGS.model_name)
    model = efficientnet_model.Model(width_coefficient,
                                     depth_coefficient,
                                     dropout_rate)

    scaled_lr = FLAGS.base_learning_rate * (batch_size / 256.0)
    # Decay epoch is 2.4, warmup epoch is 5 according to the EfficientNet paper.
    decay_steps = steps_per_epoch * 2.4
    warmup_step = steps_per_epoch * 5
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        scaled_lr, decay_steps, decay_rate=0.97, staircase=True)
    learning_rate = utils.WarmupDecaySchedule(lr_schedule, warmup_step)
    optimizer = tf.keras.optimizers.RMSprop(
        learning_rate, rho=0.9, momentum=0.9, epsilon=0.001)
    if FLAGS.moving_average_decay > 0:
      optimizer = utils.MovingAverage(
          optimizer, average_decay=FLAGS.moving_average_decay)
      optimizer.shadow_copy(model)

    metrics = {
        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
        'train/accuracy': tf.keras.metrics.CategoricalAccuracy(),
        'train/ece': ed.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
        'train/loss': tf.keras.metrics.Mean(),
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.CategoricalAccuracy(),
        'test/ece': ed.metrics.ExpectedCalibrationError(
            num_bins=FLAGS.num_bins),
    }
    logging.info('Finished building %s model', FLAGS.model_name)

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
    initial_epoch = 0
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

  def train_step(inputs):
    """Build `step_fn` for efficientnet learning."""
    images, labels = inputs
    num_replicas = tf.cast(strategy.num_replicas_in_sync, tf.float32)
    l2_coeff = tf.cast(FLAGS.l2, tf.float32)

    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      logits = tf.cast(logits, tf.float32)
      negative_log_likelihood = tf.reduce_mean(
          tf.keras.losses.categorical_crossentropy(
              labels,
              logits,
              from_logits=True,
              label_smoothing=FLAGS.label_smoothing))

      def _is_batch_norm(v):
        """Decide whether a variable belongs to `batch_norm`."""
        keywords = ['batchnorm', 'batch_norm', 'bn']
        return any([k in v.name.lower() for k in keywords])

      l2_loss = tf.add_n([
          tf.nn.l2_loss(v)
          for v in model.trainable_weights
          if not _is_batch_norm(v)
      ])
      loss = negative_log_likelihood + l2_coeff * l2_loss
      scaled_loss = loss / num_replicas

    gradients = tape.gradient(scaled_loss, model.trainable_weights)
    # MovingAverage optimizer automatically updates avg when applying gradients.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

    sparse_labels = tf.cast(
        tf.math.argmax(labels, axis=-1, output_type=tf.int32), tf.float32)
    probs = tf.nn.softmax(logits)
    metrics['train/loss'].update_state(loss)
    metrics['train/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['train/accuracy'].update_state(labels, logits)
    metrics['train/ece'].update_state(sparse_labels, probs)

    step_info = {
        'loss/negative_log_likelihood': negative_log_likelihood / num_replicas,
        'loss/total_loss': scaled_loss,
    }
    return step_info

  def eval_step(inputs):
    """A single step."""
    images, labels = inputs
    logits = model(images, training=False)
    logits = tf.cast(logits, tf.float32)
    negative_log_likelihood = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(
            labels, logits, from_logits=True))
    sparse_labels = tf.cast(
        tf.math.argmax(labels, axis=-1, output_type=tf.int32), tf.float32)
    probs = tf.nn.softmax(logits)
    metrics['test/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['test/accuracy'].update_state(labels, logits)
    metrics['test/ece'].update_state(sparse_labels, probs)

  @tf.function
  def epoch_fn(should_eval):
    """Build `epoch_fn` for training and potential eval."""
    for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
      info = strategy.run(train_step, args=(next(train_iterator),))

      optim_step = optimizer.iterations
      if optim_step % tf.cast(100, optim_step.dtype) == 0:
        for k, v in info.items():
          v_reduce = strategy.reduce(tf.distribute.ReduceOp.SUM, v, None)
          tf.summary.scalar(k, v_reduce, optim_step)
        tf.summary.scalar('loss/lr', learning_rate(optim_step), optim_step)
        summary_writer.flush()

    if should_eval:
      if isinstance(optimizer, utils.MovingAverage):
        optimizer.swap_weights(strategy)
      for _ in tf.range(tf.cast(steps_per_eval, tf.int32)):
        strategy.run(eval_step, args=(next(test_iterator),))
      if isinstance(optimizer, utils.MovingAverage):
        optimizer.swap_weights(strategy)

  # Main training loop.
  start_time = time.time()
  with summary_writer.as_default():
    for epoch in range(initial_epoch, FLAGS.train_epochs):
      logging.info('Starting to run epoch: %s', epoch)
      should_eval = (epoch % FLAGS.evaluation_interval == 0)
      epoch_start_time = time.time()
      # Pass tf constant to avoid re-tracing.
      epoch_fn(tf.constant(should_eval))
      epoch_time = time.time() - epoch_start_time
      example_per_secs = (steps_per_epoch * batch_size) / epoch_time
      if not should_eval:
        tf.summary.scalar('examples_per_secs', example_per_secs,
                          optimizer.iterations)
        summary_writer.flush()

      current_step = (epoch + 1) * steps_per_epoch
      max_steps = steps_per_epoch * FLAGS.train_epochs
      time_elapsed = time.time() - start_time
      steps_per_sec = float(current_step) / time_elapsed
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
      logging.info(message)

      logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                   metrics['train/loss'].result(),
                   metrics['train/accuracy'].result() * 100)
      if should_eval:
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)

      total_metrics = metrics.copy()
      total_results = {
          name: metric.result() for name, metric in total_metrics.items()
      }
      total_results.update({'lr': learning_rate(optimizer.iterations)})
      with summary_writer.as_default():
        for name, result in total_results.items():
          if should_eval or 'test' not in name:
            tf.summary.scalar(name, result, step=epoch + 1)

      for metric in metrics.values():
        metric.reset_states()

      if (FLAGS.checkpoint_interval > 0 and
          (epoch + 1) % FLAGS.checkpoint_interval == 0):
        checkpoint_name = checkpoint.save(
            os.path.join(FLAGS.output_dir, 'checkpoint'))
        logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)
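# In the training script above, `utils.MovingAverage` wraps the optimizer,
# keeps an exponential moving average (EMA) of the model weights
# (`shadow_copy(model)`, updated when gradients are applied), and swaps the
# averaged weights in around evaluation (`swap_weights(strategy)`). The sketch
# below illustrates that idea under the assumption that it follows the usual
# EMA-of-weights recipe; it is not the wrapper's actual implementation, and
# `WeightEMA` is a hypothetical name.
import tensorflow as tf


class WeightEMA:
  """Tracks shadow copies of variables and can swap them with the live weights."""

  def __init__(self, variables, decay=0.999):
    self.decay = decay
    self.variables = variables
    # Shadow copies, analogous to what `shadow_copy(model)` would create.
    self.shadow = [tf.Variable(v, trainable=False) for v in variables]

  def update(self):
    # shadow := decay * shadow + (1 - decay) * weight, after each optimizer step.
    for s, v in zip(self.shadow, self.variables):
      s.assign(self.decay * s + (1.0 - self.decay) * v)

  def swap_weights(self):
    # Exchange live weights and shadow weights; call once before eval to
    # evaluate the averaged model, and once after to resume training.
    for s, v in zip(self.shadow, self.variables):
      tmp = tf.identity(v)
      v.assign(s)
      s.assign(tmp)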
def main(argv):
  #############################################################################
  # Initial Setup. Logging, Flags, Random seeds.
  #############################################################################
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  absl_logging.use_python_logging()
  flags_dict = {
      flag.name: flag.value
      for flag in FLAGS.flags_by_module_dict()[argv[0]]
  }

  if FLAGS.use_subset:
    message = (f"{colorama.Back.RED}{colorama.Fore.WHITE}"
               f"{colorama.Style.BRIGHT}USING A SUBSET OF THE DATASET"
               f"{colorama.Style.RESET_ALL}")
    LOGGER.warning(message)

  utils.log_module_args(LOGGER, argv[0])
  if not FLAGS.output_dir.startswith("gs://"):
    utils.check_exists(FLAG_OUTPUT_DIR.value)
    if not tf.io.gfile.isdir(FLAG_OUTPUT_DIR.value):
      raise RuntimeError("Output dir needs to be a directory.")

  tf.random.set_seed(FLAG_RANDOM_SEED.value)
  np.random.seed(FLAG_RANDOM_SEED.value)

  # Prepare the instance output directory path and save the config there.
  folder_name = time.strftime(
      f"{FLAG_RUN_NAME.value}_{FLAG_APPROACH_TYPE.value}_%Y%m%d-%H%M%S")
  instance_output_dir = os.path.join(FLAG_OUTPUT_DIR.value,
                                     folder_name).strip()
  if not instance_output_dir.endswith("/"):
    instance_output_dir += "/"
  json_target = os.path.join(instance_output_dir, "training_params.json")
  if not json_target.strip().startswith("gs://"):
    subprocess.check_call(["mkdir", "-p", instance_output_dir])
  utils.to_json_file(json_target, flags_dict)

  #############################################################################
  # Initialization and Configuration of the Devices.
  #############################################################################
  tpu_setup = None
  # current_accelerator_type is always "CPU" in the beginning with TPUs
  if tf_utils.current_accelerator_type() == "CPU":
    tpu_setup = tf_utils.init_tpus()

  LOGGER.debug("Devices we are computing on:\n%s",
               utils.wrap_iterable(map(str, tf_utils.devices_to_use())))
  LOGGER.debug("All devices:")
  LOGGER.debug(tf_utils.device_mapping())

  if tf_utils.current_accelerator_type() == "GPU":
    tf.config.set_soft_device_placement(True)

  if tf_utils.current_accelerator_type() != "TPU":
    tf.debugging.set_log_device_placement(True)

  if FLAG_DISTRIBUTE_MODE.value in constants.PURE_DATA_PARALLEL_STRATEGIES:
    actual_num_replicas = len(tf_utils.devices_to_use())
  elif FLAG_DISTRIBUTE_MODE.value in constants.DATA_PARALLEL_DMC:
    actual_num_replicas = FLAG_NUM_REPLICAS.value
  else:
    actual_num_replicas = 1

  #############################################################################
  # We load the retriever model if it is needed.
  #############################################################################
  # Not currently used.
  retriever = None
  # if (FLAG_APPROACH_TYPE.value ==
  #     constants.ApproachTypeChoices.lm_and_realm):
  #   raise NotImplementedError("This part needs to be tested anew.")
  #   config_path = FLAG_RETRIEVER_CONFIG_PATH.value
  #   realm_save = tf_utils.REALMSave(**utils.from_json_file(config_path))
  #
  #   # Approx 15 min when not in dev mode, on CPU
  #   with utils.log_duration(LOGGER, "main",
  #                           "whole of BERTScaNNRetriever.__init__",
  #                           logging.INFO):
  #     scann_config = retrievers.ScannConfig(
  #         **utils.from_json_file(FLAG_SCANN_CONFIG_PATH.value))
  #     retriever = retrievers.BERTScaNNRetriever(
  #         retriever_module_path=realm_save.query_embedder_path,
  #         block_records_path=realm_save.text_records,
  #         num_block_records=realm_save.num_block_records,
  #         mode=tf.estimator.ModeKeys.EVAL,
  #         scann_config=scann_config)
  # elif (FLAG_APPROACH_TYPE.value ==
  #       constants.ApproachTypeChoices.cached_realm):
  #   raise NotImplementedError("This part needs to be tested anew.")
  #   config_path = FLAG_RETRIEVER_CONFIG_PATH.value
  #   realm_save = tf_utils.REALMSave(**utils.from_json_file(config_path))
  #
  #   # Approx 15 min when not in dev mode, on CPU
  #   with utils.log_duration(LOGGER, "main",
  #                           "whole of FullyCachedRetriever.__init__",
  #                           logging.INFO):
  #     retriever = retrievers.FullyCachedRetriever(
  #         db_path=FLAG_FULLYCACHED_H5_PATH.value,
  #         block_records_path=realm_save.text_records,
  #         num_block_records=realm_save.num_block_records,
  #     )

  #############################################################################
  # Distributed training task
  #############################################################################
  if FLAG_TASK.value == constants.TaskChoices.train:
    with utils.log_duration(LOGGER, "main", "Load model"):
      utils.print_mem("before loading model", LOGGER)
      model_specific = task_specific.load_model(FLAG_MODEL_LOAD_PATH.value,
                                                FLAG_MODEL_KEY.value,
                                                FLAG_DISTRIBUTE_MODE.value,
                                                tpu_setup,
                                                FLAG_NUM_REPLICAS.value)
      utils.print_mem("after loading model", LOGGER)
    model_or_replicas = model_specific.model
    if isinstance(model_or_replicas, list):
      model_or_replicas: List[transformers.TFGPT2LMHeadModel]
    else:
      model_or_replicas: transformers.TFGPT2LMHeadModel

    tokenizer = model_specific.tokenizer

    def make_optimizer():
      return tensor2tensor.utils.adafactor.AdafactorOptimizer(
          learning_rate=FLAG_LEARNING_RATE.value)

    if model_specific.strategy:
      with model_specific.strategy.scope():
        optimizer = make_optimizer()
    else:
      optimizer = make_optimizer()

    ###########################################################################
    # Prepare the dataset functions
    ###########################################################################
    rg = np.random.default_rng(FLAG_RANDOM_SEED.value)

    def call_lm_preproc(repeat, split, random_seed):
      """Using functools.partial prevents the linter from doing its job."""
      if FLAG_DATASET_NAME.value == constants.DatasetNameChoices.kilt_eli5:
        return task_specific.create_lm_ds_kilt_eli5(
            tokenizer=tokenizer,
            context_window_size=(
                model_or_replicas[0].config.n_positions
                if isinstance(model_or_replicas, list)
                else model_or_replicas.config.n_positions),
            dataset_name=FLAG_DATASET_NAME.value,
            # Batches are split over the replicas:
            batch_size=FLAG_BATCH_SIZE.value * actual_num_replicas,
            db_path=FLAG_DB_PATH.value,
            random_seed=random_seed,
            use_subset=FLAG_USE_SUBSET.value,
            subset_size=FLAG_SUBSET_SIZE.value,
            use_helper_words=FLAG_USE_HELPER_WORDS.value,
            approach_type=FLAG_APPROACH_TYPE.value,
            num_retrievals=FLAG_NUM_RETRIEVALS.value,
            retrieval_temperature=FLAG_RETRIEVAL_TEMPERATURE.value,
            retriever=retriever,
            repeat=repeat,
            split=split,
            enable_debug_checks=FLAG_DATASET_DEBUG.value,
            retrieval_bank_size=FLAG_RETRIEVAL_BANK_SIZE.value,
            dataset_type=FLAG_DATASET_TYPE.value,
            qty_shuffle=FLAG_QTY_SHUFFLE.value,
            tfr_prefix=FLAG_TFR_PREFIX.value,
            max_length_generation=FLAG_MAX_LENGTH_GENERATION.value,
        )
      else:
        raise NotImplementedError(
            f"FLAG_DATASET_NAME.value unsupported: `{FLAG_DATASET_NAME.value}`")

    make_training_dataset: Callable[Ellipsis, tf.data.Dataset] = functools.partial(
        call_lm_preproc,
        split="train",
        repeat=False,
    )
    make_eval_dataset: Callable[Ellipsis, tf.data.Dataset] = functools.partial(
        call_lm_preproc,
        split="eval",
        repeat=True,
    )

    ###########################################################################
    # Prepare the step functions
    ###########################################################################
    utils.check_contained(FLAG_DISTRIBUTE_MODE.value,
                          constants.DistributeModeChoices.choices())
    tf_function_flags = dict(
        experimental_compile=FLAG_EXPERIMENTAL_COMPILE.value,
        experimental_relax_shapes=not FLAG_INPUT_FIXED_SIZE.value)

    if (FLAG_DISTRIBUTE_MODE.value ==
        constants.DistributeModeChoices.split_and_data_parallel):
      if not isinstance(model_or_replicas, list):
        raise RuntimeError(type(model_or_replicas))
      training_step = build_manual_data_parallel_training_step(
          model_or_replicas, optimizer, tf_function_flags)
    else:
      training_step = build_regular_training_step(
          model_or_replicas,
          optimizer,
          strategy=model_specific.strategy,
          tf_function_kwargs=tf_function_flags)

    evaluation_step = build_evaluation_step(model_or_replicas,
                                            tf_function_flags)

    secs_since_last_ckpt = time.time()
    # Model checkpoints are saved to the tmp_directory and then rsynced to GCS.

    ###########################################################################
    # Prepare the different logging facilities
    ###########################################################################
    train_log_dir = os.path.join(instance_output_dir, "tensorboard", "train")
    eval_log_dir = os.path.join(instance_output_dir, "tensorboard", "eval")
    flags_log_dir = os.path.join(instance_output_dir, "tensorboard", "params")
    writers = dict(train=tf.summary.create_file_writer(train_log_dir),
                   eval=tf.summary.create_file_writer(eval_log_dir),
                   flags=tf.summary.create_file_writer(flags_log_dir))
    with writers["flags"].as_default():
      tf.summary.text(
          "Flags",
          # Tensorboard takes Markdown:
          json.dumps(flags_dict, indent=4).replace("\n", "\n\n"),
          step=0)

    ma_loss = dict(train=utils.MovingAverage(0.9),
                   eval=utils.MovingAverage(0.9))
    step_counters = dict(train=0, eval=0)
    batch_counters = dict(train=0, eval=0)
    prev_batch_end = time.time()

    # The eval ds has no real concept of epoch, repeats forever, shuffling
    # each time it reaches its end.
    with utils.log_duration(LOGGER, "main", "All of make_eval_dataset"):
      eval_ds_instance = make_eval_dataset(
          random_seed=rg.integers(-2**63, 2**63 - 1),
      )
    LOGGER.debug("Distributing the eval dataset to the replicas.")
    if FLAG_DATASET_TYPE.value == "tfr":
      eval_ds_instance = (
          model_specific.strategy.experimental_distribute_dataset(
              eval_ds_instance))
    LOGGER.debug("Done distributing the eval dataset to the replicas.")
    eval_ds_instance = iter(eval_ds_instance)

    ###########################################################################
    # Training Loop
    ###########################################################################
    for epoch in itertools.count():
      #########################################################################
      # Epoch Setup
      #########################################################################
      LOGGER.debug("EPOCH %d START", epoch)
      # Shuffle differently every epoch
      with utils.log_duration(LOGGER, "main", "All of make_training_dataset"):
        train_ds_instance = make_training_dataset(
            random_seed=rg.integers(-2**63, 2**63 - 1),
        )
      LOGGER.debug(
          "Attempting to distribute the training dataset to the replicas.")
      if FLAG_DATASET_TYPE.value == "tfr":
        train_ds_instance = (
            model_specific.strategy.experimental_distribute_dataset(
                train_ds_instance))
      LOGGER.debug("Done distributing the training dataset to the replicas.")
      train_ds_instance = iter(train_ds_instance)

      # This allows us to see if we reached the end of the training iterator,
      # in which case "did_at_least_one_training_batch == False".
      # We could also test that it did all the batches, to similar results.
      did_at_least_one_training_batch = True
      split = "eval"
      while did_at_least_one_training_batch:
        # Invert split
        if split == "train":
          split = "eval"
        else:
          split = "train"

        # Prepare to test if we did at least one training batch
        if split == "train":
          did_at_least_one_training_batch = False

        if split == "train":
          dataset_iterator = itertools.islice(
              train_ds_instance, FLAG_BATCHES_BETWEEN_EVALS.value)
        else:
          # The evaluation DS is tiny, so we reshuffle and take a random slice.
          dataset_iterator = itertools.islice(
              eval_ds_instance, FLAG_NUMBER_EVAL_BATCHES.value)

        LOGGER.debug("Batching")
        for batch in dataset_iterator:
          # LOGGER.debug("Input sentence:\n\"%s\"",
          #              tokenizer.decode([x for x in batch["input_ids"][0]
          #                                if x != tokenizer.eos_token_id]))
          # LOGGER.debug("Label:\n\"%s\"",
          #              tokenizer.decode([(x if x != -100 else 0)
          #                                for x in batch["label_ids"][0]]))
          if FLAG_DATASET_TYPE.value != "tfr":
            batch = (
                model_specific.strategy
                .experimental_distribute_values_from_function(
                    tf_utils.make_dict_distribute_fn(batch)))

          # We only care about training epochs as, obviously, we don't train
          # over eval samples; the number of eval samples seen only
          # contributes to lowering the variance in the evaluation of when to
          # do early stopping.
          if split == "train":
            did_at_least_one_training_batch = True

          input_ids = batch["input_ids"]
          label_ids = batch["label_ids"]

          #####################################################################
          # Training Step
          #####################################################################
          step_counters[split] += FLAG_BATCH_SIZE.value * actual_num_replicas

          if split == "train":
            batch_counters[split] += 1
            training_kwargs = dict(
                input_ids=input_ids,
                label_ids=label_ids,
            )

            if model_specific.strategy:
              utils.print_mem("before running", LOGGER)
              LOGGER.debug("Training, Calling strategy.run")
              loss = model_specific.strategy.run(training_step,
                                                 kwargs=training_kwargs)
              LOGGER.debug("Training, Done with strategy.run")
              utils.print_mem("after running", LOGGER)
            else:
              loss = training_step(**training_kwargs)  # pytype: disable=wrong-arg-count

            # If we are in the strategy-free data parallel mode, we need
            # to change the weights of all replicas to those of the model at
            # index 0.
            if (FLAG_DISTRIBUTE_MODE.value ==
                constants.DistributeModeChoices.split_and_data_parallel):
              for replica in model_or_replicas[1:]:
                replica.set_weights(model_or_replicas[0].get_weights())

          #####################################################################
          # Evaluation Step
          #####################################################################
          elif split == "eval":
            evaluation_kwargs = dict(
                input_ids=input_ids,
                label_ids=label_ids,
            )

            if model_specific.strategy:
              loss = model_specific.strategy.run(evaluation_step,
                                                 kwargs=evaluation_kwargs)
            else:
              loss = evaluation_step(**evaluation_kwargs)
          else:
            raise ValueError(f"Unexpected value for split: {split}")

          #####################################################################
          # Logging
          #####################################################################
          if (FLAG_DISTRIBUTE_MODE.value
              in constants.PURE_DATA_PARALLEL_STRATEGIES):
            utils.check_equal(len(loss.values), actual_num_replicas)
            LOGGER.debug("Split: %s", split)
            LOGGER.debug("Real num replicas: %s", actual_num_replicas)
            LOGGER.debug("Loss: %s", loss)
            LOGGER.debug("Loss values: %s", loss.values)
            average_loss = float(tf.math.reduce_mean(loss.values).numpy())
          else:
            average_loss = float(loss.numpy())
          # tf.debugging.check_numerics(loss)

          now = time.time()
          batch_duration = now - prev_batch_end
          prev_batch_end = now
          ma_loss[split].update(average_loss)

          # Actual logging
          LOGGER.info("Epoch: # %d", epoch)
          LOGGER.info("Tensorboard_dir: %s", instance_output_dir)
          LOGGER.info("Batch: %s # %d", split, batch_counters[split])
          LOGGER.info("Step: %s # %d", split, step_counters[split])
          if FLAG_USE_SUBSET.value:
            LOGGER.warning(">> USING A SUBSET OF THE DATASET <<")
          LOGGER.info("%(split)s Batch loss: %(metric)f",
                      dict(split=split, metric=average_loss))
          LOGGER.info("%(split)s Moving average loss: %(metric)f",
                      dict(split=split, metric=ma_loss[split].average))
          LOGGER.info("%(split)s Moving average ppl: %(metric)f",
                      dict(split=split,
                           metric=np.exp(ma_loss[split].average)))
          LOGGER.info("%(split)s Batch duration: %(duration)s",
                      dict(split=split,
                           duration=utils.TimeStamp.from_seconds(
                               batch_duration).format()))
          if FLAG_DISTRIBUTE_MODE.value in constants.DATA_PARALLEL_DMC:
            LOGGER.info("%(split)s Duration per sample: %(duration)s",
                        dict(split=split,
                             duration=utils.TimeStamp.from_seconds(
                                 batch_duration /
                                 (FLAG_BATCH_SIZE.value *
                                  actual_num_replicas))))

          # Write to Tensorboard
          with writers[split].as_default():
            tf.summary.scalar(f"Loss/{split}", average_loss,
                              step_counters[split])
            tf.summary.scalar(f"PPL/{split}", np.exp(average_loss),
                              step_counters[split])
          writers[split].flush()

          # Save a checkpoint if at least 20 minutes have elapsed
          # (per the check below).
          if (time.time() - secs_since_last_ckpt) / (60 * 20) >= 1:
            secs_since_last_ckpt = time.time()
            save_model(train_steps=step_counters["train"],
                       model_or_replicas=model_or_replicas,
                       instance_output_dir=instance_output_dir)

      secs_since_last_ckpt = time.time()
      save_model(train_steps=step_counters["train"],
                 model_or_replicas=model_or_replicas,
                 instance_output_dir=instance_output_dir)

    ###########################################################################
    # Post Training Cleanup
    ###########################################################################
    for writer in writers.values():
      writer.close()
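# In the pure data-parallel modes above, the value returned by `strategy.run`
# is a PerReplica object, and the loop averages it with
# `tf.math.reduce_mean(loss.values)`. A strategy-native alternative that gives
# the same scalar is `strategy.reduce`. The self-contained illustration below
# is a sketch with illustrative names (it does not reuse the training script's
# objects); it assumes a machine where MirroredStrategy can create at least one
# replica.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()


def replica_fn():
  # Each replica reports its own scalar loss.
  return tf.constant(2.0)


per_replica_loss = strategy.run(replica_fn)
# Average across replicas, equivalent to tf.math.reduce_mean(loss.values) in
# the loop above.
mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss,
                            axis=None)
print(float(mean_loss.numpy()))  # 2.0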
def train(self):
    """Main actor-learner loop for parallel advantage actor-critic learning."""
    logging.info('Starting training at step %d' % self.global_step)
    logging.debug('Device: {}'.format(self.device))

    counter = 0
    global_step_start = self.global_step
    average_loss = utils.MovingAverage(
        0.01, ['actor', 'critic', 'entropy', 'grad_norm'])
    total_rewards, training_stats, total_length = [], [], []

    num_emulators = self.batch_env.num_emulators
    total_episode_rewards = np.zeros(num_emulators)
    # Stores 0.0 in the i-th element if the episode in the i-th emulator has
    # just started, otherwise stores 1.0.
    # The mask is used to cut rnn_state and episode rewards between episodes.
    mask_t = th.zeros(num_emulators).to(self.device)
    # Feedforward networks also use rnn_state, it's just empty!
    rnn_state = self.network.init_rnn_state(num_emulators)
    states, infos = self.batch_env.reset_all()
    self.batch_env.set_difficulty(self.starting_length)

    if self.evaluate is not None:
        stats = self.evaluate(self.network)
        training_stats.append((self.global_step, stats))

    start_time = time.time()
    while self.global_step < self.total_steps:
        loop_start_time = time.time()
        values, log_probs, rewards, entropies, masks = [], [], [], [], []
        self.network.detach_rnn_state(rnn_state)

        for t in range(self.rollout_steps):
            outputs = self.choose_action(states, infos, mask_t.unsqueeze(1),
                                         rnn_state)
            a_t, v_t, log_probs_t, entropy_t, rnn_state = outputs
            states, rs, dones, infos = self.batch_env.next(a_t)

            tensor_rs = th.from_numpy(self.reshape_r(rs)).to(self.device)
            rewards.append(tensor_rs)
            entropies.append(entropy_t)
            log_probs.append(log_probs_t)
            values.append(v_t)

            # dones.dtype == np.float32
            mask_t = 1.0 - th.from_numpy(dones).to(self.device)
            masks.append(mask_t)  # 1.0 if the episode is not done, 0.0 otherwise

            done_mask = dones.astype(bool)
            total_episode_rewards += rs
            if any(done_mask):
                total_rewards.extend(total_episode_rewards[done_mask])
                total_episode_rewards[done_mask] = 0.

        next_v = self.predict_values(states, infos, mask_t.unsqueeze(1),
                                     rnn_state)

        update_stats = self.update_weights(next_v, rewards, masks, values,
                                           log_probs, entropies)
        average_loss.update(**update_stats)

        self.global_step += num_emulators * self.rollout_steps
        counter += 1

        if counter % (self.print_every //
                      (num_emulators * self.rollout_steps)) == 0:
            curr_time = time.time()
            self._training_info(
                total_rewards=total_rewards,
                average_speed=(self.global_step - global_step_start) /
                (curr_time - start_time),
                loop_speed=(num_emulators * self.rollout_steps) /
                (curr_time - loop_start_time),
                update_stats=average_loss)

        if counter % (self.eval_every //
                      (num_emulators * self.rollout_steps)) == 0:
            if self.evaluate is not None:
                stats = self.evaluate(self.network)
                if stats.final_res > 0.95:
                    print(stats.final_res, 'stats.final_res')
                    # If this is curriculum learning and final_res > 95%,
                    # then enlarge the labyrinth length.
                    if self.curr_learning:
                        print(self.curr_learning, 'self.curr_learning')
                        self.change_length_labyrinth()
                training_stats.append((self.global_step, stats))

        if self.global_step - self.last_saving_step >= self.save_every:
            self._save_progress(self.checkpoint_dir,
                                summaries=training_stats,
                                is_best=False)
            training_stats = []
            self.last_saving_step = self.global_step

    self._save_progress(self.checkpoint_dir, is_best=False)
    logging.info('Training ended at step %d' % self.global_step)
def test_moving_average():
    ma = utils.MovingAverage(0.9)
    assert ma.update(10) == 10
    assert ma.update(10) == 10
    assert ma.update(10) == 10
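# A decay-based moving average consistent with the test above (update returns
# the current average, and repeatedly feeding the same value leaves it
# unchanged) and with usages such as `ma_loss[split].average` elsewhere in this
# section might look like the sketch below. The constructor argument is
# interpreted as the smoothing factor; this is an assumption, not the
# repository's actual `utils.MovingAverage`.
class ExponentialMovingAverage:
    """Hypothetical sketch of an exponential moving average."""

    def __init__(self, decay):
        self.decay = decay
        self.average = None  # initialized lazily from the first sample

    def update(self, value):
        if self.average is None:
            # The first update returns the sample itself, as the test expects.
            self.average = value
        else:
            self.average = self.decay * self.average + (1 - self.decay) * value
        return self.average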
def train(self):
    """Main actor-learner loop for parallel advantage actor-critic learning."""
    logging.info('Starting training at step %d' % self.global_step)
    logging.debug('use_cuda == {}'.format(self.use_cuda))

    counter = 0
    global_step_start = self.global_step
    average_loss = utils.MovingAverage(0.01, ['total', 'actor', 'critic'])
    total_rewards, training_stats = [], []

    if self.eval_func is not None:
        stats = self.evaluate(verbose=True)
        training_stats.append((self.global_step, stats))

    # num_actions = self.args['num_actions']
    num_emulators = self.args['num_envs']
    max_local_steps = self.args['max_local_steps']
    max_global_steps = self.args['max_global_steps']
    clip_norm = self.args['clip_norm']
    rollout_steps = num_emulators * max_local_steps

    states, infos = self.batch_env.reset_all()
    emulator_steps = np.zeros(num_emulators, dtype=int)
    total_episode_rewards = np.zeros(num_emulators)
    not_done_masks = torch.zeros(max_local_steps, num_emulators).type(
        self._tensors.FloatTensor)

    if self.use_rnn:
        hx_init, cx_init = self.network.get_initial_state(num_emulators)
        hx, cx = hx_init, cx_init
    else:
        # Feedforward nets just ignore this argument.
        hx, cx = None, None

    start_time = time.time()
    while self.global_step < max_global_steps:
        loop_start_time = time.time()
        values, log_probs, rewards, entropies = [], [], [], []

        if self.use_rnn:
            hx, cx = hx.detach(), cx.detach()  # Do I really need to detach here?

        for t in range(max_local_steps):
            outputs = self.choose_action(states, infos, (hx, cx))
            a_t, v_t, log_probs_t, entropy_t, (hx, cx) = outputs
            states, rs, dones, infos = self.batch_env.next(a_t)

            # actions_sum += a_t
            rewards.append(np.clip(rs, -1., 1.))
            entropies.append(entropy_t)
            log_probs.append(log_probs_t)
            values.append(v_t)

            is_done = torch.from_numpy(dones).type(self._tensors.FloatTensor)
            not_done_masks[t] = 1.0 - is_done
            done_mask = dones.astype(bool)

            total_episode_rewards += rs
            emulator_steps += 1

            total_rewards.extend(total_episode_rewards[done_mask])
            total_episode_rewards[done_mask] = 0.
            emulator_steps[done_mask] = 0

            if self.use_rnn and any(done_mask):
                # We need to clear all LSTM states corresponding to the
                # terminated emulators.
                done_idx = is_done.nonzero().view(-1)
                # hx_t, cx_t are used for the backward op, so we can't modify
                # them in-place.
                hx, cx = hx.clone(), cx.clone()
                hx[done_idx, :] = hx_init[done_idx, :].detach()
                cx[done_idx, :] = cx_init[done_idx, :].detach()

        self.global_step += rollout_steps

        next_v = self.predict_values(states, infos, (hx, cx))
        R = next_v.detach().view(-1)

        delta_v = []
        for t in reversed(range(max_local_steps)):
            rs = Variable(torch.from_numpy(rewards[t])).type(
                self._tensors.FloatTensor)
            not_done_t = Variable(not_done_masks[t])
            R = rs + self.gamma * R * not_done_t
            delta_v_t = R - values[t].view(-1)
            delta_v.append(delta_v_t)

        loss, actor_loss, critic_loss = self.compute_loss(
            torch.cat(delta_v, 0),
            torch.cat(log_probs, 0).view(-1),
            torch.cat(entropies, 0).view(-1))

        self.lr_scheduler.adjust_learning_rate(self.global_step)
        self.optimizer.zero_grad()
        loss.backward()
        global_norm = self.clip_gradients(self.network.parameters(), clip_norm)
        self.optimizer.step()

        average_loss.update(total=loss.data.item(),
                            actor=actor_loss.item(),
                            critic=critic_loss.item())

        counter += 1
        if counter % (self.print_every // rollout_steps) == 0:
            curr_time = time.time()
            self._training_info(
                total_rewards=total_rewards,
                average_speed=(self.global_step - global_step_start) /
                (curr_time - start_time),
                loop_speed=rollout_steps / (curr_time - loop_start_time),
                moving_averages=average_loss,
                grad_norms=global_norm)

        if counter % (self.eval_every // rollout_steps) == 0:
            if self.eval_func is not None:
                stats = self.evaluate(verbose=True)
                training_stats.append((self.global_step, stats))

        if self.global_step - self.last_saving_step >= self.save_every:
            self._save_progress(self.checkpoint_dir,
                                summaries=training_stats,
                                is_best=False)
            training_stats = []
            self.last_saving_step = self.global_step

    self._save_progress(self.checkpoint_dir, is_best=False)
    logging.info('Training ended at step %d' % self.global_step)
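# The two actor-critic loops above construct `utils.MovingAverage(0.01, names)`
# and feed it keyword statistics via `update(total=..., actor=..., critic=...)`.
# A minimal sketch of such a per-key smoother is given below; treating the first
# constructor argument as the update rate is an assumption, and this is not the
# repository's actual implementation.
class NamedMovingAverage:
    """Hypothetical per-key exponential moving average (update rate `alpha`)."""

    def __init__(self, alpha, names):
        self.alpha = alpha
        self.averages = {name: None for name in names}

    def update(self, **stats):
        for name, value in stats.items():
            prev = self.averages[name]
            if prev is None:
                self.averages[name] = value
            else:
                # new = old + alpha * (value - old)
                self.averages[name] = prev + self.alpha * (value - prev)
        return self.averages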
def main(argv):
  #############################################################################
  # Initial Setup. Logging, Flags, Random seeds.
  #############################################################################
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  absl_logging.use_python_logging()
  flags_dict = {
      flag.name: flag.value
      for flag in FLAGS.flags_by_module_dict()[argv[0]]
  }

  if FLAGS.use_subset:
    message = (f"{colorama.Back.RED}{colorama.Fore.WHITE}"
               f"{colorama.Style.BRIGHT}USING A SUBSET OF THE DATASET"
               f"{colorama.Style.RESET_ALL}")
    LOGGER.warning(message)

  utils.log_module_args(LOGGER, argv[0])
  if not FLAGS.output_dir.startswith("gs://"):
    utils.check_exists(FLAG_OUTPUT_DIR.value)
    if not tf.io.gfile.isdir(FLAG_OUTPUT_DIR.value):
      raise RuntimeError("Output dir needs to be a directory.")

  tf.random.set_seed(FLAG_RANDOM_SEED.value)
  np.random.seed(FLAG_RANDOM_SEED.value)

  # Prepare the instance output directory path and save the config there.
  # Prepare the path.
  folder_name = time.strftime(
      f"{FLAG_RUN_NAME.value}_{FLAG_APPROACH_TYPE.value}_%Y%m%d-%H%M%S")
  instance_output_dir = os.path.join(FLAG_OUTPUT_DIR.value,
                                     folder_name).strip()
  if not instance_output_dir.endswith("/"):
    instance_output_dir += "/"
  json_target = os.path.join(instance_output_dir, "training_params.json")

  # Make the folder if we're not on gcloud.
  if not json_target.strip().startswith("gs://"):
    subprocess.check_call(["mkdir", "-p", instance_output_dir])

  # Save the config file.
  utils.to_json_file(json_target, flags_dict)

  #############################################################################
  # Initialization and Configuration of the Devices.
  #############################################################################
  tpu_setup = None

  accel = tf_utils.current_accelerator_type()
  if FLAG_TPU_IS_LOCAL.value:
    assert accel == "TPU", accel
  if accel == "TPU":
    assert FLAG_TPU_IS_LOCAL.value, FLAG_TPU_IS_LOCAL.value

  if tf_utils.current_accelerator_type() in {"CPU", "TPU"}:
    tpu_setup = tf_utils.init_tpus(tpu_name=FLAG_TPU_NAME.value,
                                   local=FLAG_TPU_IS_LOCAL.value)

  LOGGER.debug("Devices we are computing on:\n%s",
               utils.wrap_iterable(map(str, tf_utils.devices_to_use())))
  LOGGER.debug("All devices:")
  LOGGER.debug(tf_utils.device_mapping())

  if tf_utils.current_accelerator_type() == "GPU":
    tf.config.set_soft_device_placement(True)

  if tf_utils.current_accelerator_type() != "TPU":
    tf.debugging.set_log_device_placement(True)

  utils.check_operator(operator.ne, tf_utils.current_accelerator_type(), "CPU")

  assert FLAG_TPU_NAME.value == socket.gethostname(), (
      "This is a configuration choice. You can remove this. "
      "There will be no side effects.")

  if FLAG_DISTRIBUTE_MODE.value in constants.PURE_DATA_PARALLEL_STRATEGIES:
    actual_num_replicas = len(tf_utils.devices_to_use())
  elif FLAG_DISTRIBUTE_MODE.value in constants.DATA_PARALLEL_DMC:
    actual_num_replicas = FLAG_NUM_REPLICAS.value
  else:
    actual_num_replicas = 1

  #############################################################################
  # We load the retriever model if it is needed.
  #############################################################################
  # Not currently used. See old commits.
  retriever = None

  #############################################################################
  # Distributed training task
  #############################################################################
  if FLAG_TASK.value == constants.TaskChoices.train:
    with utils.log_duration(LOGGER, "main", "Load model"):
      utils.print_mem("before loading model", LOGGER)
      model_specific = task_specific.load_model(FLAG_MODEL_KEY.value,
                                                FLAG_DISTRIBUTE_MODE.value,
                                                tpu_setup,
                                                FLAG_NUM_REPLICAS.value)
      utils.print_mem("after loading model", LOGGER)
    model = model_specific.model
    if isinstance(model, list):
      model: List[transformers.TFGPT2LMHeadModel]
    else:
      model: transformers.TFGPT2LMHeadModel
    tokenizer = model_specific.tokenizer

    def make_optimizer():
      if FLAG_OPTIMIZER_TYPE.value == constants.OptimizerTypes.adafactor:
        return tensor2tensor.utils.adafactor.AdafactorOptimizer(
            learning_rate=FLAG_LEARNING_RATE.value)
      elif FLAG_OPTIMIZER_TYPE.value == constants.OptimizerTypes.adam:
        return tf.keras.optimizers.Adam(learning_rate=FLAG_LEARNING_RATE.value)
      else:
        raise ValueError(FLAG_OPTIMIZER_TYPE.value)

    if model_specific.strategy:
      with model_specific.strategy.scope():
        optimizer = make_optimizer()
    else:
      optimizer = make_optimizer()

    ###########################################################################
    # Prepare the dataset functions
    ###########################################################################
    rg = np.random.default_rng(FLAG_RANDOM_SEED.value)

    def call_lm_preproc(repeat, split, random_seed):
      """Using functools.partial prevents the linter from doing its job."""
      if FLAG_DATASET_NAME.value == constants.DatasetNameChoices.kilt_eli5:
        return task_specific.create_lm_ds_kilt_eli5(
            tokenizer=tokenizer,
            context_window_size=model.config.n_positions,
            dataset_name=FLAG_DATASET_NAME.value,
            # Batches are split over the replicas:
            batch_size=FLAG_BATCH_SIZE.value * actual_num_replicas,
            db_path=FLAG_DB_PATH.value,
            random_seed=random_seed,
            use_subset=FLAG_USE_SUBSET.value,
            subset_size=FLAG_SUBSET_SIZE.value,
            use_helper_words=FLAG_USE_HELPER_WORDS.value,
            approach_type=FLAG_APPROACH_TYPE.value,
            num_retrievals=FLAG_NUM_RETRIEVALS.value,
            retrieval_temperature=FLAG_RETRIEVAL_TEMPERATURE.value,
            retriever=retriever,
            repeat=repeat,
            split=split,
            enable_debug_checks=FLAG_DATASET_DEBUG.value,
            retrieval_bank_size=FLAG_RETRIEVAL_BANK_SIZE.value,
            dataset_type=FLAG_DATASET_TYPE.value,
            qty_shuffle=FLAG_QTY_SHUFFLE.value,
            tfr_prefix=FLAG_TFR_PREFIX.value,
            max_length_generation=FLAG_MAX_LENGTH_GENERATION.value,
        )
      else:
        raise NotImplementedError(
            f"FLAG_DATASET_NAME.value unsupported: `{FLAG_DATASET_NAME.value}`")

    make_training_dataset: Callable[..., tf.data.Dataset] = functools.partial(
        call_lm_preproc,
        split="train",
        repeat=False,
    )
    make_eval_dataset: Callable[..., tf.data.Dataset] = functools.partial(
        call_lm_preproc,
        split="eval",
        repeat=True,
    )

    ###########################################################################
    # Prepare the step functions
    ###########################################################################
    utils.check_contained(FLAG_DISTRIBUTE_MODE.value,
                          constants.DistributeModeChoices.choices())
    tf_function_flags = dict(
        experimental_compile=FLAG_EXPERIMENTAL_COMPILE.value,
        experimental_relax_shapes=not FLAG_INPUT_FIXED_SIZE.value)

    training_step = build_regular_training_step(
        model,
        optimizer,
        strategy=model_specific.strategy,
        tf_function_kwargs=tf_function_flags)

    evaluation_step = build_evaluation_step(model, tf_function_flags)

    timestamp_last_ckpt_secs = time.time()
    # Model checkpoints are saved to the tmp_directory and then rsynced to GCS.

    ###########################################################################
    # Prepare the statistics and the logging facilities.
    ###########################################################################
    # Tensorboard
    with model_specific.strategy.scope():
      checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    saver = Saver(instance_output_dir, checkpoint)
    train_log_dir = os.path.join(instance_output_dir, "tensorboard", "train")
    eval_log_dir = os.path.join(instance_output_dir, "tensorboard", "eval")
    flags_log_dir = os.path.join(instance_output_dir, "tensorboard", "params")
    writers = dict(train=tf.summary.create_file_writer(train_log_dir),
                   eval=tf.summary.create_file_writer(eval_log_dir),
                   flags=tf.summary.create_file_writer(flags_log_dir))
    with writers["flags"].as_default():
      tf.summary.text(
          "Flags",
          # Tensorboard takes Markdown:
          json.dumps(flags_dict, indent=4).replace("\n", "\n\n"),
          step=0)

    # Different information to log.
    ma_loss = dict(train=utils.MovingAverage(0.9),
                   eval=utils.MovingAverage(0.9))
    step_counters = dict(train=0, eval=0)
    batch_counters = dict(train=0, eval=0)
    prev_batch_end = time.time()

    ###########################################################################
    # Create the Eval DS object.
    # =========================================================================
    # The eval ds has no real concept of epoch, repeats forever, shuffling
    # each time it reaches its end.
    ###########################################################################
    # Create
    with utils.log_duration(LOGGER, "main", "All of make_eval_dataset"):
      eval_ds_instance = make_eval_dataset(
          random_seed=rg.integers(-2**63, 2**63 - 1),
      )

    # Maybe distribute
    LOGGER.debug("Distributing the eval dataset to the replicas.")
    if FLAG_DATASET_TYPE.value == "tfr":
      eval_ds_instance = (
          model_specific.strategy.experimental_distribute_dataset(
              eval_ds_instance))

    # Start the iteration. We step by calling `next(...)`.
    LOGGER.debug("Done distributing the eval dataset to the replicas.")
    eval_ds_instance = iter(eval_ds_instance)

    step_function = dict(train=training_step, eval=evaluation_step)

    ###########################################################################
    # Training Loop
    # =========================================================================
    # Create a new training dataset object that lasts for one epoch.
    # This is different from the eval dataset object, which loops forever.
    ###########################################################################
    for epoch in itertools.count():
      #########################################################################
      # Epoch Setup
      #########################################################################
      LOGGER.debug("EPOCH %d START", epoch)
      # Shuffle differently every epoch
      with utils.log_duration(LOGGER, "main", "All of make_training_dataset"):
        train_ds_instance = make_training_dataset(
            random_seed=rg.integers(-2**63, 2**63 - 1),
        )
      LOGGER.debug(
          "Attempting to distribute the training dataset to the replicas.")
      if FLAG_DATASET_TYPE.value == "tfr":
        train_ds_instance = (
            model_specific.strategy.experimental_distribute_dataset(
                train_ds_instance))
      LOGGER.debug("Done distributing the training dataset to the replicas.")
      train_ds_instance = iter(train_ds_instance)

      # To change splits, we use `itertools.islice` over the dataset generator.
      # When the training dataset generator is done, a new iteration of the
      # following while loop occurs, but no training batch is done because we
      # are taking an `islice` of a generator that is done.
      did_at_least_one_training_batch = True
      split = "eval"
      while did_at_least_one_training_batch:
        utils.check_operator(operator.ne, tf_utils.current_accelerator_type(),
                             "CPU")

        # Invert split
        if split == "train":
          split = "eval"
        else:
          split = "train"

        # Prepare to test if we did at least one training batch
        if split == "train":
          did_at_least_one_training_batch = False

        ########################################################################
        # Take slices from the dataset iterator
        # ======================================================================
        # We only want to do a certain number of batches before switching
        # splits. We do this by using an `itertools.islice` of the dataset
        # iterators.
        ########################################################################
        if split == "train":
          dataset_iterator = toolz.take(FLAG_BATCHES_BETWEEN_EVALS.value,
                                        train_ds_instance)
        else:
          # The evaluation dataset generator is infinite, and reshuffles every
          # time it gets to its end.
          # Still, we take a fixed-size slice from that infinite generator.
          dataset_iterator = toolz.take(FLAG_NUMBER_EVAL_BATCHES.value,
                                        eval_ds_instance)

        LOGGER.debug("Batching")
        for batch in dataset_iterator:
          if FLAG_LOG_SAMPLES.value:
            ####################################################################
            # Print elements of the dataset
            ####################################################################
            # Make ourselves resistant to values possibly being a PerReplica
            # object.
            LOGGER.warning(
                "%(red)sLOGGING SAMPLES. THIS IS VERY SLOW.%(reset)s",
                dict(red=colorama.Fore.RED,
                     reset=colorama.Style.RESET_ALL))
            is_distributed = isinstance(batch["input_ids"], values.PerReplica)
            for in_batch_idx in range(FLAG_BATCH_SIZE.value):
              for replica_idx in (range(actual_num_replicas)
                                  if is_distributed else [0]):
                if is_distributed:
                  sample = {k: batch[k].values[replica_idx] for k in batch}
                else:
                  sample = batch

                # input_sentence = tokenizer.decode(
                #     [x for x in sample["input_ids"][i]
                #      if x != tokenizer.eos_token_id])
                # LOGGER.debug(
                #     "%sInput [%d / %d]%s:\n\"%s\"",
                #     colorama.Fore.GREEN,
                #     replica_idx + 1,
                #     actual_num_replicas,
                #     colorama.Style.RESET_ALL,
                #     input_sentence,
                # )
                #
                # answer = tokenizer.decode(
                #     [(x if x != -100 else 0) for x in sample["label_ids"][i]])
                # LOGGER.debug(
                #     "%sLabel [%d / %d]%s:\n\"%s\"",
                #     colorama.Fore.GREEN,
                #     replica_idx + 1,
                #     actual_num_replicas,
                #     colorama.Style.RESET_ALL,
                #     answer,
                # )

                cons = console.Console()
                sentences = table.Table()
                sentences.add_column("BPE Index", justify="center")
                sentences.add_column("Inputs", justify="center")
                sentences.add_column("Labels", justify="center")
                for bpe_idx, (x, y) in enumerate(
                    itertools.zip_longest(
                        sample["input_ids"][in_batch_idx].numpy(),
                        sample["label_ids"][in_batch_idx].numpy(),
                        fillvalue=None,
                    )):
                  x_w = tokenizer.decode([x]) if x >= 0 else f"[ {x} ]"
                  y_w = tokenizer.decode([y]) if y >= 0 else f"[ {y} ]"
                  sentences.add_row(str(bpe_idx), x_w, y_w)
                cons.print(sentences)

          # We only care about training epochs as, obviously, we don't train
          # over eval samples; the number of eval samples seen only
          # contributes to lowering the variance in the evaluation of when to
          # do early stopping.
          if split == "train":
            did_at_least_one_training_batch = True

          input_ids = batch["input_ids"]
          label_ids = batch["label_ids"]

          # Per split step counter
          step_counters[split] += FLAG_BATCH_SIZE.value * actual_num_replicas
          batch_counters[split] += 1

          #####################################################################
          # Model step function.
          #####################################################################
          step_function_kwargs = dict(
              input_ids=input_ids,
              label_ids=label_ids,
          )

          utils.print_mem(f"[{split}] - Mem before `strategy.run`", LOGGER)
          LOGGER.debug("[%s] - Calling `strategy.run`", split)
          loss = model_specific.strategy.run(step_function[split],
                                             kwargs=step_function_kwargs)
          LOGGER.debug("[%s] - Done `strategy.run`", split)
          utils.print_mem(f"[{split}] - Mem after `strategy.run`", LOGGER)

          #####################################################################
          # End of logging step code / Logging and saving the model.
          #####################################################################
          if (FLAG_DISTRIBUTE_MODE.value
              in constants.PURE_DATA_PARALLEL_STRATEGIES):
            utils.check_equal(len(loss.values), actual_num_replicas)
            LOGGER.debug("[%s] - Real num replicas: %s", split,
                         actual_num_replicas)
            average_loss = float(tf.math.reduce_mean(loss.values).numpy())
            LOGGER.debug("[%s] - Loss: %s", str(split), str(average_loss))
          else:
            average_loss = float(loss.numpy())

          tf.debugging.check_numerics(
              loss.values if isinstance(loss, values.PerReplica) else loss,
              "Numerics failed.")

          now = time.time()
          batch_duration = now - prev_batch_end
          prev_batch_end = now

          ma_loss[split].update(average_loss)

          LOGGER.info("[%s] - Epoch: # %d", split, epoch)
          LOGGER.info("[%s] - Tensorboard_dir: %s", split,
                      instance_output_dir)
          LOGGER.info("[%s] - Batch: # %d", split, batch_counters[split])
          LOGGER.info("[%s] - Step: # %d", split, step_counters[split])
          if FLAG_USE_SUBSET.value:
            LOGGER.warning(">> USING A SUBSET OF THE DATASET <<")
          LOGGER.info("[%(split)s] - Batch loss: %(metric)f",
                      dict(split=split, metric=average_loss))
          LOGGER.info("[%(split)s] - Moving average loss: %(metric)f",
                      dict(split=split, metric=ma_loss[split].average))
          LOGGER.info("[%(split)s] - Moving average ppl: %(metric)f",
                      dict(split=split,
                           metric=np.exp(ma_loss[split].average)))
          LOGGER.info("[%(split)s] - Batch duration: %(duration)s",
                      dict(split=split,
                           duration=utils.TimeStamp.from_seconds(
                               batch_duration).format()))

          # Write to Tensorboard
          with writers[split].as_default():
            tf.summary.scalar(f"Loss/{split}", average_loss,
                              step_counters[split])
            tf.summary.scalar(f"PPL/{split}", np.exp(average_loss),
                              step_counters[split])
          writers[split].flush()

          #####################################################################
          # Save every `FLAG_SAVE_PERIOD_MIN.value` minutes.
          #####################################################################
          delta_sec = time.time() - timestamp_last_ckpt_secs
          utils.check_operator(operator.gt, delta_sec, 0)
          period_sec = 60 * FLAG_SAVE_PERIOD_MIN.value
          utils.check_operator(operator.gt, period_sec, 0)
          ratio = delta_sec / period_sec
          LOGGER.info("[%(split)s] - RATIO: %(ratio)s",
                      dict(split=split, ratio=str(ratio)))
          LOGGER.info(
              "[%(split)s] - Target: %(target)s, Present: %(present)s",
              dict(split=split,
                   target=str(period_sec),
                   present=str(delta_sec)))

          if ratio >= 1:
            dur = delta_sec / 60
            timestamp_last_ckpt_secs = time.time()
            LOGGER.debug("SAVING MODEL - CAUSE: DURATION - %0.2f min", dur)
            # checkpoint.save(ckpt_prefix)
            saver.save_model(
                train_steps=step_counters["train"],
                model_or_replicas=model,
                optimizer=optimizer,
            )

    ###########################################################################
    # Post Training Cleanup
    ###########################################################################
    for writer in writers.values():
      writer.close()
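# The alternating train/eval scheduling above draws a fixed number of batches
# from each iterator before switching splits. `toolz.take(n, it)` used in this
# version is equivalent to `itertools.islice(it, n)` used in the earlier
# version of the loop: both lazily consume the next `n` items of the same
# underlying iterator. A tiny standalone illustration:
import itertools

import toolz

stream = iter(range(10))
assert list(toolz.take(3, stream)) == [0, 1, 2]         # first slice
assert list(itertools.islice(stream, 3)) == [3, 4, 5]   # continues from the same iterator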