def test_invalid_strategy(self):
  with self.assertRaisesRegex(
      ValueError,
      'distribution_strategy must be a string but got: False. If'):
    distribute_utils.get_distribution_strategy(False)
  with self.assertRaisesRegex(
      ValueError, 'distribution_strategy must be a string but got: 1'):
    distribute_utils.get_distribution_strategy(1)
def test_tpu_strategy(self):
  if not TPU_TEST:
    self.skipTest('Only Cloud TPU VM instances can have local TPUs.')
  with self.assertRaises(ValueError):
    _ = distribute_utils.get_distribution_strategy('tpu')
  ds = distribute_utils.get_distribution_strategy('tpu', tpu_address='local')
  self.assertIsInstance(ds, tf.distribute.TPUStrategy)
def test_mwms(self):
  distribute_utils.configure_cluster(worker_hosts=None, task_index=-1)
  ds = distribute_utils.get_distribution_strategy(
      'multi_worker_mirrored', all_reduce_alg='nccl')
  self.assertIsInstance(
      ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)

  with self.assertRaisesRegex(
      ValueError,
      'When used with `multi_worker_mirrored`, valid values.*'):
    _ = distribute_utils.get_distribution_strategy(
        'multi_worker_mirrored', all_reduce_alg='dummy')
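# A minimal sketch of the same strategy configured for a real two-host job.
# The worker addresses and task index below are hypothetical, not taken from
# the test above. configure_cluster writes TF_CONFIG, which
# MultiWorkerMirroredStrategy reads to discover its peers; the second host
# would run the same code with task_index=1.
def example_mwms_two_hosts():
  distribute_utils.configure_cluster(
      worker_hosts='10.0.0.1:2222,10.0.0.2:2222', task_index=0)
  return distribute_utils.get_distribution_strategy(
      'multi_worker_mirrored', all_reduce_alg='nccl')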
def test_get_strategy_scope(self):
  ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
  with distribute_utils.get_strategy_scope(ds):
    self.assertIs(tf.distribute.get_strategy(), ds)
  with distribute_utils.get_strategy_scope(None):
    self.assertIsNot(tf.distribute.get_strategy(), ds)
def main(_):
  logging.info('Parsing config files...')
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = get_exp_config()

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly improve model speed by using float16 on GPUs and
  # bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype,
        params.runtime.loss_scale,
        use_experimental_api=True)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = distillation.BertDistillationTask(
        strategy=distribution_strategy,
        progressive=params.trainer.progressive,
        optimizer_config=params.trainer.optimizer_config,
        task_config=params.task)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=FLAGS.model_dir)
def create_distribution_strategy(distribution_strategy,
                                 tpu_address,
                                 input_partition_dims=None,
                                 num_gpus=None):
  """Creates the distribution strategy to use for computation."""
  if input_partition_dims is not None:
    if distribution_strategy != 'tpu':
      raise ValueError('Spatial partitioning is only supported '
                       'for TPUStrategy.')

    # When `input_partition_dims` is specified, create a custom TPUStrategy
    # instance with a computation shape for model parallelism.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=tpu_address)
    if tpu_address not in ('', 'local'):
      tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    num_replicas = resolver.get_tpu_system_metadata().num_cores // np.prod(
        input_partition_dims)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        num_replicas=num_replicas,
        computation_shape=input_partition_dims)
    return tf.distribute.TPUStrategy(
        resolver, experimental_device_assignment=device_assignment)

  return distribute_utils.get_distribution_strategy(
      distribution_strategy=distribution_strategy,
      tpu_address=tpu_address,
      num_gpus=num_gpus)
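# A rough usage sketch of the helper above with spatial partitioning enabled
# (all values hypothetical). With input_partition_dims=[1, 2, 1, 1] each input
# is split across 2 cores, so a v3-8 (8 cores) yields 8 // 2 = 4 replicas.
strategy = create_distribution_strategy(
    distribution_strategy='tpu',
    tpu_address='local',
    input_partition_dims=[1, 2, 1, 1])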
def main(_):
  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if FLAGS.mode == 'export_only':
    export_squad(FLAGS.model_export_path, input_meta_data)
    return

  # Configures cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    _ = distribute_utils.configure_cluster(FLAGS.worker_hosts,
                                           FLAGS.task_index)
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)

  if 'train' in FLAGS.mode:
    train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
  if 'predict' in FLAGS.mode:
    predict_squad(strategy, input_meta_data)
  if 'eval' in FLAGS.mode:
    eval_metrics = eval_squad(strategy, input_meta_data)
    f1_score = eval_metrics['final_f1']
    logging.info('SQuAD eval F1-score: %f', f1_score)
    summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
    summary_writer = tf.summary.create_file_writer(summary_dir)
    with summary_writer.as_default():
      # TODO(lehou): write to the correct step number.
      tf.summary.scalar('F1-score', f1_score, step=0)
      summary_writer.flush()
    # Also write eval_metrics to a json file.
    squad_lib_sp.write_to_json_files(
        eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
    time.sleep(60)
def __init__(self, flags_obj):
  """Init function of TransformerMain.

  Args:
    flags_obj: Object containing parsed flag values, i.e., FLAGS.

  Raises:
    ValueError: if not using static batch for input data on TPU.
  """
  self.flags_obj = flags_obj
  self.predict_model = None

  # Add flag-defined parameters to params object
  num_gpus = flags_core.get_num_gpus(flags_obj)
  self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)

  params["num_gpus"] = num_gpus
  params["use_ctl"] = flags_obj.use_ctl
  params["data_dir"] = flags_obj.data_dir
  params["model_dir"] = flags_obj.model_dir
  params["static_batch"] = flags_obj.static_batch
  params["max_length"] = flags_obj.max_length
  params["decode_batch_size"] = flags_obj.decode_batch_size
  params["decode_max_length"] = flags_obj.decode_max_length
  params["padded_decode"] = flags_obj.padded_decode
  params["max_io_parallelism"] = (
      flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
  params["use_synthetic_data"] = flags_obj.use_synthetic_data
  params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
  params["repeat_dataset"] = None
  params["dtype"] = flags_core.get_tf_dtype(flags_obj)
  params["enable_tensorboard"] = flags_obj.enable_tensorboard
  params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
  params["steps_between_evals"] = flags_obj.steps_between_evals
  params["enable_checkpointing"] = flags_obj.enable_checkpointing
  params["save_weights_only"] = flags_obj.save_weights_only

  self.distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu or "")
  if self.use_tpu:
    params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
  else:
    logging.info("Running transformer with num_gpus = %d", num_gpus)

  if self.distribution_strategy:
    logging.info("For training, using distribution strategy: %s",
                 self.distribution_strategy)
  else:
    logging.info("Not using any distribution strategy.")

  performance.set_mixed_precision_policy(params["dtype"])
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly improve model speed by using float16 on GPUs and
  # bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu,
      **params.runtime.model_parallelism())
  with distribution_strategy.scope():
    task = classification_example.ClassificationExampleTask(params.task)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)
def main(_):
  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  max_seq_length = input_meta_data['max_seq_length']
  train_input_fn = run_classifier_bert.get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)
  eval_input_fn = run_classifier_bert.get_dataset_fn(
      FLAGS.eval_data_path,
      max_seq_length,
      FLAGS.eval_batch_size,
      is_training=False)

  albert_config = albert_configs.AlbertConfig.from_json_file(
      FLAGS.bert_config_file)
  if FLAGS.mode == 'train_and_eval':
    run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
                                 train_input_fn, eval_input_fn)
  elif FLAGS.mode == 'predict':
    predict(strategy, albert_config, input_meta_data, eval_input_fn)
  else:
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
  return
def test_invalid_args(self):
  with self.assertRaisesRegex(ValueError, '`num_gpus` can not be negative.'):
    _ = distribute_utils.get_distribution_strategy(num_gpus=-1)

  with self.assertRaisesRegex(ValueError,
                              '.*If you meant to pass the string .*'):
    _ = distribute_utils.get_distribution_strategy(
        distribution_strategy=False, num_gpus=0)

  with self.assertRaisesRegex(ValueError, 'When 2 GPUs are specified.*'):
    _ = distribute_utils.get_distribution_strategy(
        distribution_strategy='off', num_gpus=2)

  with self.assertRaisesRegex(ValueError,
                              '`OneDeviceStrategy` can not be used.*'):
    _ = distribute_utils.get_distribution_strategy(
        distribution_strategy='one_device', num_gpus=2)
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if "train" in FLAGS.mode:
    train_utils.serialize_config(params, model_dir)

  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu,
      **params.runtime.model_parallelism())

  with distribution_strategy.scope():
    if params.task.use_crf:
      task = ap_parsing_task.APParsingTaskCRF(params.task)
    else:
      task = ap_parsing_task.APParsingTaskBase(params.task)

    ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter(
        params, model_dir)
    trainer = train_utils.create_trainer(
        params,
        task,
        train="train" in FLAGS.mode,
        evaluate=("eval" in FLAGS.mode),
        checkpoint_exporter=ckpt_exporter)

  model, _ = train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      trainer=trainer,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)

  # Export saved model.
  if "train" in FLAGS.mode:
    saved_model_path = os.path.join(model_dir, "saved_models/latest")
    logging.info("Exporting SavedModel to %s", saved_model_path)
    tf.saved_model.save(model, saved_model_path)

    if ckpt_exporter:
      logging.info("Loading best checkpoint for export")
      trainer.checkpoint.restore(ckpt_exporter.best_ckpt_path)
      saved_model_path = os.path.join(model_dir, "saved_models/best")

      # Make sure restored and not re-initialized.
      if trainer.global_step > 0:
        logging.info(
            "Exporting best saved model by %s (from global step: %d) to %s",
            params.trainer.best_checkpoint_eval_metric,
            trainer.global_step.numpy(), saved_model_path)
        tf.saved_model.save(trainer.model, saved_model_path)
def main(unused_argv):
  del unused_argv
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.strategy_type, tpu_address=FLAGS.tpu)
  if strategy:
    logging.info("***** Number of cores used : %d",
                 strategy.num_replicas_in_sync)
  train_input_fn = functools.partial(data_utils.get_classification_input_data,
                                     FLAGS.train_batch_size, FLAGS.seq_len,
                                     strategy, True, FLAGS.train_tfrecord_path)
  test_input_fn = functools.partial(data_utils.get_classification_input_data,
                                    FLAGS.test_batch_size, FLAGS.seq_len,
                                    strategy, False, FLAGS.test_tfrecord_path)

  total_training_steps = FLAGS.train_steps
  steps_per_loop = FLAGS.iterations
  eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
  eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                              eval_steps)
  optimizer, learning_rate_fn = optimization.create_optimizer(
      FLAGS.learning_rate,
      total_training_steps,
      FLAGS.warmup_steps,
      adam_epsilon=FLAGS.adam_epsilon)
  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)
  model_fn = functools.partial(get_classificationxlnet_model, model_config,
                               run_config, FLAGS.n_class, FLAGS.summary_type)
  input_meta_data = {}
  input_meta_data["d_model"] = FLAGS.d_model
  input_meta_data["mem_len"] = FLAGS.mem_len
  input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                               strategy.num_replicas_in_sync)
  input_meta_data["n_layer"] = FLAGS.n_layer
  input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
  input_meta_data["n_class"] = FLAGS.n_class

  training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=eval_fn,
      metric_fn=get_metric_fn,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      init_from_transformerxl=FLAGS.init_from_transformerxl,
      total_training_steps=total_training_steps,
      steps_per_loop=steps_per_loop,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)
def test_mirrored_strategy(self):
  ds = distribute_utils.get_distribution_strategy(num_gpus=5)
  self.assertEqual(ds.num_replicas_in_sync, 5)
  self.assertEqual(len(ds.extended.worker_devices), 5)
  for device in ds.extended.worker_devices:
    self.assertIn('GPU', device)

  _ = distribute_utils.get_distribution_strategy(
      distribution_strategy='mirrored',
      num_gpus=2,
      all_reduce_alg='nccl',
      num_packs=2)
  with self.assertRaisesRegex(
      ValueError,
      'When used with `mirrored`, valid values for all_reduce_alg are.*'):
    _ = distribute_utils.get_distribution_strategy(
        distribution_strategy='mirrored',
        num_gpus=2,
        all_reduce_alg='dummy',
        num_packs=2)
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  if 'train_and_eval' in FLAGS.mode:
    assert (params.task.train_data.feature_shape ==
            params.task.validation_data.feature_shape), (
                f'train {params.task.train_data.feature_shape} != validate '
                f'{params.task.validation_data.feature_shape}')

  if 'assemblenet' in FLAGS.experiment:
    if 'eval' in FLAGS.mode:
      # Use the feature shape in validation_data for all jobs. The number of
      # frames in train_data will be used to construct the Assemblenet model.
      params.task.model.backbone.assemblenet.num_frames = (
          params.task.validation_data.feature_shape[0])
      shape = params.task.validation_data.feature_shape
    else:
      params.task.model.backbone.assemblenet.num_frames = (
          params.task.train_data.feature_shape[0])
      shape = params.task.train_data.feature_shape
    logging.info('mode %r num_frames %r feature shape %r', FLAGS.mode,
                 params.task.model.backbone.assemblenet.num_frames, shape)

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly improve model speed by using float16 on GPUs and
  # bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)
def run():
  """Runs NHNet using Keras APIs."""
  if FLAGS.enable_mlir_bridge:
    tf.config.experimental.enable_mlir_bridge()

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      tpu_address=FLAGS.tpu)
  if strategy:
    logging.info("***** Number of cores used : %d",
                 strategy.num_replicas_in_sync)

  params = models.get_model_params(FLAGS.model_type)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  params.override(
      {
          "len_title": FLAGS.len_title,
          "len_passage": FLAGS.len_passage,
          "num_hidden_layers": FLAGS.num_encoder_layers,
          "num_decoder_layers": FLAGS.num_decoder_layers,
          "passage_list":
              [chr(ord("b") + i) for i in range(FLAGS.num_nhnet_articles)],
      },
      is_strict=False)
  stats = {}
  if "train" in FLAGS.mode:
    stats = train(params, strategy)
  if "eval" in FLAGS.mode:
    timeout = 0 if FLAGS.mode == "train_and_eval" else FLAGS.eval_timeout
    # Uses padded decoding for TPU. Always uses cache.
    padded_decode = isinstance(strategy,
                               tf.distribute.experimental.TPUStrategy)
    params.override({
        "padded_decode": padded_decode,
    }, is_strict=False)
    stats = evaluation.continuous_eval(
        strategy,
        params,
        model_type=FLAGS.model_type,
        eval_file_pattern=FLAGS.eval_file_pattern,
        batch_size=FLAGS.eval_batch_size,
        eval_steps=FLAGS.eval_steps,
        model_dir=FLAGS.model_dir,
        timeout=timeout)
  return stats
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly improve model speed by using float16 on GPUs and
  # bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  if isinstance(params, cfg.ExperimentConfig):
    with distribution_strategy.scope():
      task = task_factory.get_task(params.task, logging_dir=model_dir)
    train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode=FLAGS.mode,
        params=params,
        model_dir=model_dir)
  elif isinstance(params, multi_cfg.MultiTaskExperimentConfig):
    with distribution_strategy.scope():
      task = multitask.MultiTask.from_config(params.task, model_dir)
      model = multihead_model.build_model(params.task)
    train_lib_multitask.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        model=model,
        mode=FLAGS.mode,
        params=params,
        model_dir=model_dir)
  else:
    raise ValueError(
        'Expected config to be either type cfg.ExperimentConfig or '
        'multi_cfg.MultiTaskExperimentConfig, got %s' % type(params))
  train_utils.save_gin_config(FLAGS.mode, model_dir)
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'
  # Configures cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    _ = distribute_utils.configure_cluster(FLAGS.worker_hosts,
                                           FLAGS.task_index)
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)
  if strategy:
    print('***** Number of cores used : ', strategy.num_replicas_in_sync)

  run_bert_pretrain(strategy)
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  superglue_flags.validate_flags(FLAGS, file_exists_fn=tf.io.gfile.exists)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  with distribution_strategy.scope():
    task = None
    if 'train_eval' in FLAGS.mode:
      logging.info('Starting training and eval...')
      logging.info('Model dir: %s', FLAGS.model_dir)

      exp_config = _get_exp_config(
          input_meta_data=input_meta_data,
          exp_config_files=FLAGS.config_file)
      train_utils.serialize_config(exp_config, FLAGS.model_dir)
      task = task_factory.get_task(exp_config.task,
                                   logging_dir=FLAGS.model_dir)
      train_lib.run_experiment(
          distribution_strategy=distribution_strategy,
          task=task,
          mode='train_and_eval',
          params=exp_config,
          model_dir=FLAGS.model_dir)

    if 'predict' in FLAGS.mode:
      logging.info('Starting predict...')
      # When mode is `predict`, `task` will be None.
      if task is None:
        exp_config = _get_exp_config(
            input_meta_data=input_meta_data,
            exp_config_files=[os.path.join(FLAGS.model_dir, 'params.yaml')])
        task = task_factory.get_task(exp_config.task,
                                     logging_dir=FLAGS.model_dir)
      _write_submission_file(task, input_meta_data['max_seq_length'])
def __init__(self, strategy_type=None, strategy_config=None):
  """Constructor.

  Args:
    strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'.
      If None, the user is responsible for setting the strategy before
      calling build_executor(...).
    strategy_config: necessary config for constructing the proper Strategy.
      Check strategy_flags_dict() for examples of the structure.
  """
  _ = distribute_utils.configure_cluster(strategy_config.worker_hosts,
                                         strategy_config.task_index)
  self._strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=strategy_type,
      num_gpus=strategy_config.num_gpus,
      all_reduce_alg=strategy_config.all_reduce_alg,
      num_packs=strategy_config.num_packs,
      tpu_address=strategy_config.tpu)
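# A minimal usage sketch of the constructor above. The enclosing class name is
# not shown in this snippet, so `ExecutorBuilder` is a hypothetical stand-in;
# `my_strategy_config` is assumed to expose the attributes read above
# (worker_hosts, task_index, num_gpus, all_reduce_alg, num_packs, tpu).
builder = ExecutorBuilder(
    strategy_type='mirrored', strategy_config=my_strategy_config)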
def get_v1_distribution_strategy(params):
  """Returns the distribution strategy to use."""
  if params["use_tpu"]:
    # Some of the networking libraries are quite chatty.
    for name in [
        "googleapiclient.discovery", "googleapiclient.discovery_cache",
        "oauth2client.transport"
    ]:
      logging.getLogger(name).setLevel(logging.ERROR)

    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=params["tpu"],
        zone=params["tpu_zone"],
        project=params["tpu_gcp_project"],
        coordinator_name="coordinator")

    logging.info("Issuing reset command to TPU to ensure a clean state.")
    tf.Session.reset(tpu_cluster_resolver.get_master())

    # Estimator looks at the master it connects to for MonitoredTrainingSession
    # by reading the `TF_CONFIG` environment variable, and the coordinator
    # is used by StreamingFilesDataset.
    tf_config_env = {
        "session_master": tpu_cluster_resolver.get_master(),
        "eval_session_master": tpu_cluster_resolver.get_master(),
        "coordinator":
            tpu_cluster_resolver.cluster_spec().as_dict()["coordinator"]
    }
    os.environ["TF_CONFIG"] = json.dumps(tf_config_env)

    distribution = tf.distribute.experimental.TPUStrategy(
        tpu_cluster_resolver, steps_per_run=100)
  else:
    distribution = distribute_utils.get_distribution_strategy(
        num_gpus=params["num_gpus"])
  return distribution
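# A rough usage sketch with a hand-built params dict. The keys are inferred
# from the lookups above; the values are hypothetical. With use_tpu=False only
# num_gpus is consulted, yielding a mirrored strategy over two GPUs.
params = {
    "use_tpu": False,
    "tpu": "",
    "tpu_zone": None,
    "tpu_gcp_project": None,
    "num_gpus": 2,
}
distribution = get_v1_distribution_strategy(params)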
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  logging.info('Experiment: %s', FLAGS.experiment)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly improve model speed by using float16 on GPUs and
  # bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype, params.runtime.loss_scale)

  if params.runtime.worker_hosts:
    num_workers = distribute_utils.configure_cluster(
        worker_hosts=params.runtime.worker_hosts,
        task_index=params.runtime.task_index)
    logging.info('Number of workers: %d', num_workers)

  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)
def train_and_eval(
    params: base_configs.ExperimentConfig,
    strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
  """Runs the train and eval path using compile/fit."""
  logging.info('Running train and eval.')

  distribute_utils.configure_cluster(params.runtime.worker_hosts,
                                     params.runtime.task_index)

  # Note: for TPUs, strategy and scope should be created before the dataset
  strategy = strategy_override or distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  strategy_scope = distribute_utils.get_strategy_scope(strategy)

  logging.info('Detected %d devices.',
               strategy.num_replicas_in_sync if strategy else 1)

  label_smoothing = params.model.loss.label_smoothing
  one_hot = label_smoothing and label_smoothing > 0

  builders = _get_dataset_builders(params, strategy, one_hot)
  datasets = [
      builder.build(strategy) if builder else None for builder in builders
  ]

  # Unpack datasets and builders based on train/val/test splits
  train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
  train_dataset, validation_dataset = datasets

  train_epochs = params.train.epochs
  train_steps = params.train.steps or train_builder.num_steps
  validation_steps = params.evaluation.steps or validation_builder.num_steps

  initialize(params, train_builder)

  logging.info('Global batch size: %d', train_builder.global_batch_size)

  with strategy_scope:
    model_params = params.model.model_params.as_dict()
    model = get_models()[params.model.name](**model_params)
    learning_rate = optimizer_factory.build_learning_rate(
        params=params.model.learning_rate,
        batch_size=train_builder.global_batch_size,
        train_epochs=train_epochs,
        train_steps=train_steps)
    optimizer = optimizer_factory.build_optimizer(
        optimizer_name=params.model.optimizer.name,
        base_learning_rate=learning_rate,
        params=params.model.optimizer.as_dict(),
        model=model)
    optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=train_builder.dtype == 'float16',
        loss_scale=get_loss_scale(params))

    metrics_map = _get_metrics(one_hot)
    metrics = [metrics_map[metric] for metric in params.train.metrics]
    steps_per_loop = train_steps if params.train.set_epoch_loop else 1

    if one_hot:
      loss_obj = tf.keras.losses.CategoricalCrossentropy(
          label_smoothing=params.model.loss.label_smoothing)
    else:
      loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
    model.compile(
        optimizer=optimizer,
        loss=loss_obj,
        metrics=metrics,
        steps_per_execution=steps_per_loop)

    initial_epoch = 0
    if params.train.resume_checkpoint:
      initial_epoch = resume_from_checkpoint(
          model=model, model_dir=params.model_dir, train_steps=train_steps)

    callbacks = custom_callbacks.get_callbacks(
        model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
        include_tensorboard=params.train.callbacks.enable_tensorboard,
        time_history=params.train.callbacks.enable_time_history,
        track_lr=params.train.tensorboard.track_lr,
        write_model_weights=params.train.tensorboard.write_model_weights,
        initial_step=initial_epoch * train_steps,
        batch_size=train_builder.global_batch_size,
        log_steps=params.train.time_history.log_steps,
        model_dir=params.model_dir,
        backup_and_restore=params.train.callbacks.enable_backup_and_restore)

  serialize_config(params=params, model_dir=params.model_dir)

  if params.evaluation.skip_eval:
    validation_kwargs = {}
  else:
    validation_kwargs = {
        'validation_data': validation_dataset,
        'validation_steps': validation_steps,
        'validation_freq': params.evaluation.epochs_between_evals,
    }

  history = model.fit(
      train_dataset,
      epochs=train_epochs,
      steps_per_epoch=train_steps,
      initial_epoch=initial_epoch,
      callbacks=callbacks,
      verbose=2,
      **validation_kwargs)

  validation_output = None
  if not params.evaluation.skip_eval:
    validation_output = model.evaluate(
        validation_dataset, steps=validation_steps, verbose=2)

  # TODO(dankondratyuk): eval and save final test accuracy
  stats = common.build_stats(history, validation_output, callbacks)
  return stats
def test_one_device_strategy_gpu(self):
  ds = distribute_utils.get_distribution_strategy(num_gpus=1)
  self.assertEqual(ds.num_replicas_in_sync, 1)
  self.assertEqual(len(ds.extended.worker_devices), 1)
  self.assertIn('GPU', ds.extended.worker_devices[0])
def run_executor(params,
                 mode,
                 checkpoint_path=None,
                 train_input_fn=None,
                 eval_input_fn=None,
                 callbacks=None,
                 prebuilt_strategy=None):
  """Runs the object detection model on distribution strategy defined by the user."""
  if params.architecture.use_bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  model_builder = model_factory.model_generator(params)

  if prebuilt_strategy is not None:
    strategy = prebuilt_strategy
  else:
    strategy_config = params.strategy_config
    distribute_utils.configure_cluster(strategy_config.worker_hosts,
                                       strategy_config.task_index)
    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=params.strategy_type,
        num_gpus=strategy_config.num_gpus,
        all_reduce_alg=strategy_config.all_reduce_alg,
        num_packs=strategy_config.num_packs,
        tpu_address=strategy_config.tpu)

  num_workers = int(strategy.num_replicas_in_sync + 7) // 8
  is_multi_host = (int(num_workers) >= 2)

  if mode == 'train':

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.TRAIN)

    logging.info(
        'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
        strategy.num_replicas_in_sync, num_workers, is_multi_host)

    dist_executor = DetectionDistributedExecutor(
        strategy=strategy,
        params=params,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        is_multi_host=is_multi_host,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())

    if is_multi_host:
      train_input_fn = functools.partial(
          train_input_fn,
          batch_size=params.train.batch_size // strategy.num_replicas_in_sync)

    return dist_executor.train(
        train_input_fn=train_input_fn,
        model_dir=params.model_dir,
        iterations_per_loop=params.train.iterations_per_loop,
        total_steps=params.train.total_steps,
        init_checkpoint=model_builder.make_restore_checkpoint_fn(),
        custom_callbacks=callbacks,
        save_config=True)
  elif mode == 'eval' or mode == 'eval_once':

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)

    logging.info(
        'Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
        strategy.num_replicas_in_sync, num_workers, is_multi_host)

    if is_multi_host:
      eval_input_fn = functools.partial(
          eval_input_fn,
          batch_size=params.eval.batch_size // strategy.num_replicas_in_sync)

    dist_executor = DetectionDistributedExecutor(
        strategy=strategy,
        params=params,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        is_multi_host=is_multi_host,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())

    if mode == 'eval':
      results = dist_executor.evaluate_from_model_dir(
          model_dir=params.model_dir,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          eval_timeout=params.eval.eval_timeout,
          min_eval_interval=params.eval.min_eval_interval,
          total_steps=params.train.total_steps)
    else:
      # Run evaluation once for a single checkpoint.
      if not checkpoint_path:
        raise ValueError('checkpoint_path cannot be empty.')
      if tf.io.gfile.isdir(checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
      summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
      results, _ = dist_executor.evaluate_checkpoint(
          checkpoint_path=checkpoint_path,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          summary_writer=summary_writer)
    for k, v in results.items():
      logging.info('Final eval metric %s: %f', k, v)
    return results
  else:
    raise ValueError('Mode not found: %s.' % mode)
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config()
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  if tf.config.list_physical_devices('GPU'):
    if flags_obj.tf_gpu_thread_mode:
      keras_utils.set_gpu_thread_mode_and_count(
          per_gpu_thread_count=flags_obj.per_gpu_thread_count,
          gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
          num_gpus=flags_obj.num_gpus,
          datasets_num_private_threads=flags_obj.datasets_num_private_threads)
    common.set_cudnn_batchnorm_mode()

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.config.list_physical_devices('GPU') else
                   'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  if flags_obj.steps_per_loop is None:
    steps_per_loop = per_epoch_steps
  elif flags_obj.steps_per_loop > per_epoch_steps:
    steps_per_loop = per_epoch_steps
    logging.warning('Setting steps_per_loop to %d to respect epoch boundary.',
                    steps_per_loop)
  else:
    steps_per_loop = flags_obj.steps_per_loop

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
      train_epochs * per_epoch_steps, eval_steps)

  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size,
      flags_obj.log_steps,
      logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
  with distribute_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                              per_epoch_steps)

  eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
  checkpoint_interval = (
      steps_per_loop * 5 if flags_obj.enable_checkpoint_and_export else None)
  summary_interval = steps_per_loop if flags_obj.enable_tensorboard else None

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=flags_obj.model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=checkpoint_interval)

  resnet_controller = orbit.Controller(
      strategy=strategy,
      trainer=runnable,
      evaluator=runnable if not flags_obj.skip_eval else None,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_interval,
      summary_dir=flags_obj.model_dir,
      eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))

  time_callback.on_train_begin()
  if not flags_obj.skip_eval:
    resnet_controller.train_and_evaluate(
        train_steps=per_epoch_steps * train_epochs,
        eval_steps=eval_steps,
        eval_interval=eval_interval)
  else:
    resnet_controller.train(steps=per_epoch_steps * train_epochs)
  time_callback.on_train_end()

  stats = build_stats(runnable, time_callback)
  return stats
def main(unused_argv):
  del unused_argv
  num_hosts = 1
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.strategy_type, tpu_address=FLAGS.tpu)
  if FLAGS.strategy_type == "tpu":
    num_hosts = strategy.extended.num_hosts
  if strategy:
    logging.info("***** Number of cores used : %d",
                 strategy.num_replicas_in_sync)
    logging.info("***** Number of hosts used : %d", num_hosts)
  online_masking_config = data_utils.OnlineMaskingConfig(
      sample_strategy=FLAGS.sample_strategy,
      max_num_tokens=FLAGS.max_num_tokens,
      min_num_tokens=FLAGS.min_num_tokens,
      max_num_words=FLAGS.max_num_words,
      min_num_words=FLAGS.min_num_words)

  train_input_fn = functools.partial(
      data_utils.get_pretrain_input_data, FLAGS.train_batch_size,
      FLAGS.seq_len, strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len,
      FLAGS.perm_size, FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased,
      online_masking_config, num_hosts)

  total_training_steps = FLAGS.train_steps
  steps_per_loop = FLAGS.iterations

  optimizer, learning_rate_fn = optimization.create_optimizer(
      init_lr=FLAGS.learning_rate,
      num_train_steps=total_training_steps,
      num_warmup_steps=FLAGS.warmup_steps,
      min_lr_ratio=FLAGS.min_lr_ratio,
      adam_epsilon=FLAGS.adam_epsilon,
      weight_decay_rate=FLAGS.weight_decay_rate)

  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)
  input_meta_data = {}
  input_meta_data["d_model"] = FLAGS.d_model
  input_meta_data["mem_len"] = FLAGS.mem_len
  input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                               strategy.num_replicas_in_sync)
  input_meta_data["n_layer"] = FLAGS.n_layer
  input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
  model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                               run_config)

  model = training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=None,
      metric_fn=None,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      init_from_transformerxl=FLAGS.init_from_transformerxl,
      total_training_steps=total_training_steps,
      steps_per_loop=steps_per_loop,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)

  # Export transformer-xl model checkpoint to be used in finetuning.
  checkpoint = tf.train.Checkpoint(transformer_xl=model.transformerxl_model)
  saved_path = checkpoint.save(
      os.path.join(FLAGS.model_dir, "pretrained/transformer_xl.ckpt"))
  logging.info("Exporting the transformer-xl model as a new TF checkpoint: %s",
               saved_path)
def run_continuous_finetune(
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    pretrain_steps: Optional[int] = None,
) -> Mapping[str, Any]:
  """Run modes with continuous training.

  Currently only supports continuous_train_and_eval.

  Args:
    mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
      checkpoint directory. Once a new checkpoint is discovered, loads the
      checkpoint, finetunes the model by training it (probably on another
      dataset or with another task), then evaluates the finetuned model.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    pretrain_steps: Optional, the number of total training steps for the
      pretraining job.

  Returns:
    eval logs: returns eval metrics logs when run_post_eval is set to True,
      otherwise, returns {}.
  """
  assert mode == 'continuous_train_and_eval', (
      'Only continuous_train_and_eval is supported by continuous_finetune. '
      'Got mode: {}'.format(mode))

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly improve model speed by using float16 on GPUs and
  # bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype, params.runtime.loss_scale)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  retry_times = 0
  while not tf.io.gfile.isdir(params.task.init_checkpoint):
    # Wait for the init_checkpoint directory to be created.
    if retry_times >= 60:
      raise ValueError(
          'ExperimentConfig.task.init_checkpoint must be a directory for '
          'continuous_train_and_eval mode.')
    retry_times += 1
    time.sleep(60)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(model_dir, 'eval'))

  global_step = 0

  def timeout_fn():
    if pretrain_steps and global_step < pretrain_steps:
      # Keeps waiting for another timeout period.
      logging.info(
          'Continue waiting for new checkpoint as current pretrain '
          'global_step=%d and target is %d.', global_step, pretrain_steps)
      return False
    # Quits the loop.
    return True

  for pretrain_ckpt in tf.train.checkpoints_iterator(
      checkpoint_dir=params.task.init_checkpoint,
      min_interval_secs=10,
      timeout=params.trainer.continuous_eval_timeout,
      timeout_fn=timeout_fn):

    with distribution_strategy.scope():
      global_step = train_utils.read_global_step_from_checkpoint(
          pretrain_ckpt)
    # Replaces params.task.init_checkpoint to make sure that we load
    # exactly this pretrain checkpoint.
    if params.trainer.best_checkpoint_export_subdir:
      best_ckpt_subdir = '{}_{}'.format(
          params.trainer.best_checkpoint_export_subdir, global_step)
      params_replaced = params.replace(
          task={'init_checkpoint': pretrain_ckpt},
          trainer={'best_checkpoint_export_subdir': best_ckpt_subdir})
    else:
      params_replaced = params.replace(task={'init_checkpoint': pretrain_ckpt})
    params_replaced.lock()
    logging.info('Running finetuning with params: %s', params_replaced)

    with distribution_strategy.scope():
      if isinstance(params, configs.MultiEvalExperimentConfig):
        task = task_factory.get_task(params_replaced.task)
        eval_tasks = multitask.MultiTask.from_config(
            params_replaced.eval_tasks)
        (_, eval_metrics) = multitask_train_lib.run_experiment_wtih_multitask_eval(
            distribution_strategy=distribution_strategy,
            train_task=task,
            eval_tasks=eval_tasks,
            mode='train_and_eval',
            params=params_replaced,
            model_dir=model_dir,
            run_post_eval=True,
            save_summary=False)
      else:
        task = task_factory.get_task(params_replaced.task,
                                     logging_dir=model_dir)
        _, eval_metrics = train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode='train_and_eval',
            params=params_replaced,
            model_dir=model_dir,
            run_post_eval=True,
            save_summary=False)
    logging.info('Evaluation finished. Pretrain global_step: %d', global_step)
    train_utils.write_json_summary(model_dir, global_step, eval_metrics)

    if not os.path.basename(model_dir):  # if model_dir.endswith('/')
      summary_grp = os.path.dirname(model_dir) + '_' + task.name
    else:
      summary_grp = os.path.basename(model_dir) + '_' + task.name
    summaries = {}
    for name, value in _flatten_dict(eval_metrics).items():
      summaries[summary_grp + '/' + '-'.join(name)] = value
    train_utils.write_summary(summary_writer, global_step, summaries)

    train_utils.remove_ckpts(model_dir)
    # In TF2, the resource life cycle is bound with the python object life
    # cycle. Force trigger python garbage collection here so those resources
    # can be deallocated in time, so it doesn't cause OOM when allocating new
    # objects.
    # TODO(b/169178664): Fix cycle reference in Keras model and revisit to see
    # if we need gc here.
    gc.collect()

  if run_post_eval:
    return eval_metrics
  return {}
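# A rough invocation sketch of the loop above (values hypothetical): point
# task.init_checkpoint at the pretraining job's checkpoint directory and the
# function finetunes and evaluates each new pretrain checkpoint as it appears.
eval_logs = run_continuous_finetune(
    mode='continuous_train_and_eval',
    params=experiment_config,  # assumed: a parsed ExperimentConfig
    model_dir='/tmp/continuous_finetune',
    run_post_eval=True,
    pretrain_steps=1000000)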
def run_ncf(_):
  """Run NCF training and eval with Keras."""
  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  if FLAGS.seed is not None:
    print("Setting tf seed")
    tf.random.set_seed(FLAGS.seed)

  model_helpers.apply_clean(FLAGS)

  if FLAGS.dtype == "fp16" and FLAGS.fp16_implementation == "keras":
    tf.keras.mixed_precision.set_global_policy("mixed_float16")

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)

  params = ncf_common.parse_flags(FLAGS)
  params["distribute_strategy"] = strategy
  params["use_tpu"] = (FLAGS.distribution_strategy == "tpu")

  if params["use_tpu"] and not params["keras_use_ctl"]:
    logging.error("Custom training loop must be used when using TPUStrategy.")
    return

  batch_size = params["batch_size"]
  time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
  callbacks = [time_callback]

  producer, input_meta_data = None, None
  generate_input_online = params["train_dataset_path"] is None

  if generate_input_online:
    # Start data producing thread.
    num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
    producer.start()
    per_epoch_callback = IncrementEpochCallback(producer)
    callbacks.append(per_epoch_callback)
  else:
    assert params["eval_dataset_path"] and params["input_meta_data_path"]
    with tf.io.gfile.GFile(params["input_meta_data_path"], "rb") as reader:
      input_meta_data = json.loads(reader.read().decode("utf-8"))
      num_users = input_meta_data["num_users"]
      num_items = input_meta_data["num_items"]

  params["num_users"], params["num_items"] = num_users, num_items

  if FLAGS.early_stopping:
    early_stopping_callback = CustomEarlyStopping(
        "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
    callbacks.append(early_stopping_callback)

  (train_input_dataset, eval_input_dataset, num_train_steps,
   num_eval_steps) = ncf_input_pipeline.create_ncf_input_data(
       params, producer, input_meta_data, strategy)
  steps_per_epoch = None if generate_input_online else num_train_steps

  with distribute_utils.get_strategy_scope(strategy):
    keras_model = _get_keras_model(params)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=params["learning_rate"],
        beta_1=params["beta1"],
        beta_2=params["beta2"],
        epsilon=params["epsilon"])
    if FLAGS.fp16_implementation == "graph_rewrite":
      optimizer = (
          tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
              optimizer,
              loss_scale=flags_core.get_loss_scale(
                  FLAGS, default_for_fp16="dynamic")))
    elif FLAGS.dtype == "fp16":
      loss_scale = flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic")
      # Note Model.compile automatically wraps the optimizer with a
      # LossScaleOptimizer using dynamic loss scaling. We explicitly wrap it
      # here for the case where a custom training loop or fixed loss scale is
      # used.
      if loss_scale == "dynamic":
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
      else:
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
            optimizer, dynamic=False, initial_scale=loss_scale)

    if params["keras_use_ctl"]:
      train_loss, eval_results = run_ncf_custom_training(
          params,
          strategy,
          keras_model,
          optimizer,
          callbacks,
          train_input_dataset,
          eval_input_dataset,
          num_train_steps,
          num_eval_steps,
          generate_input_online=generate_input_online)
    else:
      keras_model.compile(optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)

      if not FLAGS.ml_perf:
        # Create Tensorboard summary and checkpoint callbacks.
        summary_dir = os.path.join(FLAGS.model_dir, "summaries")
        summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
        checkpoint_path = os.path.join(FLAGS.model_dir, "checkpoint")
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            checkpoint_path, save_weights_only=True)

        callbacks += [summary_callback, checkpoint_callback]

      history = keras_model.fit(
          train_input_dataset,
          epochs=FLAGS.train_epochs,
          steps_per_epoch=steps_per_epoch,
          callbacks=callbacks,
          validation_data=eval_input_dataset,
          validation_steps=num_eval_steps,
          verbose=2)

      logging.info("Training done. Start evaluating")

      eval_loss_and_metrics = keras_model.evaluate(
          eval_input_dataset, steps=num_eval_steps, verbose=2)

      logging.info("Keras evaluation is done.")

      # Keras evaluate() API returns scalar loss and metric values from
      # evaluation as a list. Here, the returned list would contain
      # [evaluation loss, hr sum, hr count].
      eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]

      # Format evaluation result into [eval loss, eval hit accuracy].
      eval_results = [eval_loss_and_metrics[0], eval_hit_rate]

      if history and history.history:
        train_history = history.history
        train_loss = train_history["loss"][-1]

  stats = build_stats(train_loss, eval_results, time_callback)
  return stats