def _get_input_iterator(
        self,
        input_fn: Callable[[Optional[params_dict.ParamsDict]], tf.data.Dataset],
        strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
    """Build an iterator over the distributed form of an input pipeline.

    Args:
        input_fn: Factory producing a `tf.data.Dataset`. In single-host mode
            it is called here with `self._params`; in multi-host mode it is
            passed unchanged to
            `strategy.experimental_distribute_datasets_from_function`.
        strategy: The `tf.distribute.Strategy` that distributes the data.

    Returns:
        An iterator over the distributed dataset, or None when `input_fn`
        is None.
    """
    if input_fn is None:
        return None

    if not self._is_multi_host:
        local_dataset = input_fn(self._params)
        return iter(strategy.experimental_distribute_dataset(local_dataset))

    # With several TPU workers each one needs its own dataset. A Dataset
    # object cannot be cloned in eager mode, so the strategy is given the
    # factory instead of a built dataset.
    per_worker = strategy.experimental_distribute_datasets_from_function(
        input_fn)
    return iter(per_worker)
def get_datasets(train_batch_size: int, val_batch_size: int,
                 strategy: tf.distribute.Strategy) -> Tuple[Any, Any]:
    """Create the distributed training and validation datasets.

    Each split is built from the factory returned by `_make_get_dataset_fn`
    and distributed with
    `strategy.experimental_distribute_datasets_from_function`. The factory's
    third argument is True for 'train' and False for 'validation'
    (presumably a training/shuffle flag; confirm against
    `_make_get_dataset_fn`).

    Returns:
        A `(train, validation)` pair of distributed datasets.
    """
    distribute = strategy.experimental_distribute_datasets_from_function
    split_specs = (
        ('train', train_batch_size, True),
        ('validation', val_batch_size, False),
    )
    # Built in order: train first, then validation.
    ds_train, ds_val = (
        distribute(_make_get_dataset_fn(*spec)) for spec in split_specs)
    return ds_train, ds_val
def train_step(strategy: tf.distribute.Strategy, data_it, disc: Model,
               gen: Model, model_g: Model, model_d: Model, batch_size: int,
               z_size: int, num_cat: int,
               metrics: Dict[str, keras.metrics.Mean]):
    """Run one GAN round: three discriminator updates, then one generator update.

    Each discriminator update consumes the next batch from `data_it`. All
    updates are executed per replica through `strategy.run`.

    Args:
        strategy: Strategy used to run each update on every replica.
        data_it: Iterator yielding batches of real images.
        disc: Discriminator whose trainable variables/optimizer are updated.
        gen: Generator whose trainable variables/optimizer are updated.
        model_g: Model mapping latent input to `(disc_gen, cat_output)`.
        model_d: Model mapping `(z, real_images, eps)` to
            `(disc_gen, disc_real, iwgan_loss, cat_output)`.
        batch_size: Number of latent samples per generator update.
        z_size: Latent size passed to `CqGAN.generate_z`.
        num_cat: Number of categories passed to `CqGAN.generate_z`.
        metrics: Mean metrics keyed by "disc_gen", "disc_real", "loss_real",
            "iwgan_loss", "loss_gen" and "loss_cat".
    """

    def discriminate(real_images: tf.Tensor):
        """Apply one discriminator gradient step on `real_images`."""
        variables = disc.trainable_variables
        n_images = real_images.shape[0]
        eps = tf.random.uniform((n_images, 1, 1, 1), 0, 1)
        z_input, _, cat_input = CqGAN.generate_z(n_images, z_size, num_cat)
        with tf.GradientTape() as tape:
            disc_gen, disc_real, iwgan_loss, cat_output = model_d(
                (z_input, real_images, eps), training=True)
            partial = -disc_gen + disc_real + iwgan_loss
            full_loss = partial + CqGAN.get_loss_cat(cat_input, cat_output)
        gradients = tape.gradient(full_loss, variables)
        disc.optimizer.apply_gradients(zip(gradients, variables))
        # "loss_real" tracks the full discriminator objective.
        for key, value in (("disc_gen", disc_gen),
                           ("disc_real", disc_real),
                           ("loss_real", full_loss),
                           ("iwgan_loss", iwgan_loss)):
            metrics[key].update_state(value)

    def generate():
        """Apply one generator gradient step on fresh latent samples."""
        variables = gen.trainable_variables
        z_input, _, cat_input = CqGAN.generate_z(batch_size, z_size, num_cat)
        with tf.GradientTape() as tape:
            loss_gen, cat_output = model_g(z_input, training=True)
            loss_cat = CqGAN.get_loss_cat(cat_input, cat_output)
            total = loss_gen + loss_cat
        gradients = tape.gradient(total, variables)
        gen.optimizer.apply_gradients(zip(gradients, variables))
        metrics["loss_gen"].update_state(loss_gen)
        metrics["loss_cat"].update_state(loss_cat)

    for _ in range(3):
        strategy.run(discriminate, args=(next(data_it),))
    strategy.run(generate)
def read(
        self,
        mode: str,
        mirrored_strategy: tf.distribute.Strategy = None
) -> tuple:
    """Read the dataset for `mode`, optionally distributed over replicas.

    Args:
        mode: Mode/split name forwarded to `self._read`.
        mirrored_strategy: Optional strategy. When given, `self._batch_size`
            is treated as a per-replica batch size: the batch requested from
            `self._read` is scaled by `num_replicas_in_sync`, and the result
            is distributed with `experimental_distribute_dataset`.

    Returns:
        With a strategy: a `(distributed_dataset, num_iters)` pair.
        Without one: whatever `self._read` returns (presumably the same
        `(dataset, num_iters)` pair; confirm against `_read`).
    """
    if not mirrored_strategy:
        return self._read(mode, self._batch_size)

    num_replicas = mirrored_strategy.num_replicas_in_sync
    with mirrored_strategy.scope():
        # Global batch = per-replica batch * number of replicas.
        dataset, num_iters = self._read(mode, self._batch_size * num_replicas)
        dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
    return dataset, num_iters
def create_mbert_model_v3(model_type: str, strategy: tf.distribute.Strategy,
                          config: AutoConfig, max_len: int) -> tf.keras.Model:
    """Build an uncompiled Keras model exposing BERT's first-token embedding.

    The model takes two int32 inputs of length `max_len`, named `input_ids`
    and `attention_mask`. It outputs `last_hidden_state[:, 0, :]` of the
    pretrained `TFBertModel`. All layers are created under
    `strategy.scope()`.
    """
    with strategy.scope():
        def int_sequence_input(name):
            return tf.keras.layers.Input(shape=(max_len, ), dtype=tf.int32,
                                         name=name)

        input_ids = int_sequence_input('input_ids')
        input_masks = int_sequence_input('attention_mask')

        bert = TFBertModel.from_pretrained(model_type, config=config)
        encoded = bert({
            'input_ids': input_ids,
            'attention_mask': input_masks,
        })
        # Keep only the hidden state at sequence position 0.
        first_token = encoded.last_hidden_state[:, 0, :]
        model = tf.keras.Model(inputs=[input_ids, input_masks],
                               outputs=first_token)
    return model
def create_xlm_roberta_model_v1(use_default_weights: bool,
                                custom_pretrained_model_checkpoint: str,
                                strategy: tf.distribute.Strategy,
                                config: AutoConfig,
                                lr: float) -> tf.keras.Model:
    """Build and compile an XLM-RoBERTa sequence classifier under `strategy`.

    Args:
        use_default_weights: If True, load 'jplu/tf-xlm-roberta-base' as TF
            weights; otherwise load `custom_pretrained_model_checkpoint` with
            `from_pt=True`.
        custom_pretrained_model_checkpoint: Checkpoint used when
            `use_default_weights` is False.
        strategy: Strategy whose scope owns the model and optimizer.
        config: Model configuration passed to `from_pretrained`.
        lr: Adam learning rate.

    Returns:
        The compiled model (sparse categorical cross-entropy on logits,
        sparse categorical accuracy metric).
    """
    model_type, from_pt = (
        ('jplu/tf-xlm-roberta-base', False) if use_default_weights
        else (custom_pretrained_model_checkpoint, True))

    with strategy.scope():
        model = TFXLMRobertaForSequenceClassification.from_pretrained(
            model_type, config=config, from_pt=from_pt)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy('accuracy')])
    return model
def build(self, strategy: tf.distribute.Strategy = None) -> tf.data.Dataset:
    """Construct a dataset end-to-end and return it using an optional strategy.

    Args:
        strategy: A strategy that, if passed, will distribute the dataset
            according to that strategy. If passed and `num_devices > 1`,
            `use_per_replica_batch_size` must be set to `True`. This is not
            checked here.

    Returns:
        A TensorFlow dataset outputting batched images and labels. When a
        strategy is given, this is the distributed dataset returned by
        `strategy.experimental_distribute_datasets_from_function(self._build)`.
    """
    if not strategy:
        return self._build()

    if strategy.num_replicas_in_sync != self.config.num_devices:
        # A replica-count mismatch is reported but not treated as an error.
        logging.warning(
            'Passed a strategy with %d devices, but expected %d devices.',
            strategy.num_replicas_in_sync, self.config.num_devices)
    return strategy.experimental_distribute_datasets_from_function(
        self._build)
def _convert_per_replica_tensor(strategy: tf.distribute.Strategy,
                                *per_replica_tensors) -> list:
    """
    Concat the tensors distributed over the different GPU replicas.

    Parameters
    ----------
    strategy: Strategy used to distribute the GPUs.
    per_replica_tensors: Any number of per-replica values. Each one is
        unpacked with `strategy.experimental_local_results`.

    Returns
    -------
    list of tf.Tensor
        One tensor per input, holding that input's local replica results
        concatenated along axis 0, in the same order as the inputs.
    """
    return [
        tf.concat(strategy.experimental_local_results(per_replica_tensor),
                  axis=0)
        for per_replica_tensor in per_replica_tensors
    ]
def get_data_loader(config: DotMap,
                    strategy: tf.distribute.Strategy) -> DataLoader:
    """Instantiate the data loader named by `config.dataset.data_loader.type`.

    The loader is constructed inside `strategy.scope()`.

    Raises:
        ValueError: If the loader type is not 'pix2pix'.
    """
    data_loader_type = config.dataset.data_loader.type
    with strategy.scope():
        if data_loader_type != 'pix2pix':
            raise ValueError(f"unknown data loader type {data_loader_type}")
        return Pix2PixDataLoader(config, strategy)
def create_dataset(
    dataset_builder: base.BaseDataset, batch_size: int, process_fn: Any,
    distributed_strategy: tf.distribute.Strategy, distributed: bool
) -> Union[tf.data.Dataset, tf.distribute.DistributedDataset]:
    """Load, map and optionally distribute a dataset.

    Args:
        dataset_builder: Source whose `load(batch_size=...)` yields the data.
        batch_size: Batch size forwarded to `dataset_builder.load`.
        process_fn: Function mapped over every element.
        distributed_strategy: Strategy used only when `distributed` is True.
        distributed: Whether to wrap the result with
            `experimental_distribute_dataset`.

    Returns:
        The processed dataset, distributed if requested.
    """
    processed = dataset_builder.load(batch_size=batch_size).map(process_fn)
    if not distributed:
        return processed
    return distributed_strategy.experimental_distribute_dataset(processed)
def create_xlm_roberta_model_v2(use_default_weights: bool,
                                custom_pretrained_model_checkpoint: str,
                                strategy: tf.distribute.Strategy,
                                config: AutoConfig, max_len: int,
                                lr: float) -> tf.keras.Model:
    """Build and compile a custom 2-class head on top of a RoBERTa encoder.

    Architecture: the first-position vector of the encoder's output index 0
    feeds Dropout(0.2), then Dense(2) logits. The Dense kernel and bias
    initialisers are both RandomNormal(mean=0.0, stddev=0.05). The model is
    compiled with Adam(`lr`), sparse categorical cross-entropy on logits and
    sparse categorical accuracy, all under `strategy.scope()`.

    Args:
        use_default_weights: If True, load 'jplu/tf-xlm-roberta-base' as TF
            weights; otherwise load `custom_pretrained_model_checkpoint` with
            `from_pt=True`.
        custom_pretrained_model_checkpoint: Checkpoint used when
            `use_default_weights` is False.
        strategy: Strategy whose scope owns all created variables.
        config: Encoder configuration passed to `from_pretrained`.
        max_len: Length of the int32 `input_ids` / `attention_mask` inputs.
        lr: Adam learning rate.
    """
    if use_default_weights:
        model_type, from_pt = 'jplu/tf-xlm-roberta-base', False
    else:
        model_type, from_pt = custom_pretrained_model_checkpoint, True

    with strategy.scope():
        input_ids = tf.keras.layers.Input(shape=(max_len, ), dtype=tf.int32,
                                          name='input_ids')
        input_masks = tf.keras.layers.Input(shape=(max_len, ),
                                            dtype=tf.int32,
                                            name='attention_mask')

        kernel_init = tf.keras.initializers.RandomNormal(mean=0.0,
                                                         stddev=0.05,
                                                         seed=None)
        bias_init = tf.keras.initializers.RandomNormal(mean=0.0,
                                                       stddev=0.05,
                                                       seed=None)

        encoder = TFRobertaModel.from_pretrained(model_type, config=config,
                                                 from_pt=from_pt)
        encoder_outputs = encoder({
            'input_ids': input_ids,
            'attention_mask': input_masks,
        })
        # Output index 0 is presumably the last hidden state (batch, seq,
        # hidden); keep only sequence position 0.
        pooled = encoder_outputs[0][:, 0, :]
        pooled = tf.keras.layers.Dropout(0.2)(pooled)
        logits = tf.keras.layers.Dense(2,
                                       kernel_initializer=kernel_init,
                                       bias_initializer=bias_init)(pooled)
        model = tf.keras.Model(inputs=[input_ids, input_masks],
                               outputs=logits)

        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy('accuracy')])
    return model
def compute_predictions(
        model: PredictionModel, dataset: tf.data.Dataset,
        strategy: tf.distribute.Strategy, batch_size: int
) -> Iterator[Tuple[types.ModelPredictions, types.Features]]:
    """Yield the predictions of the model on the given dataset.

    The dataset is batched with `batch_size`, prefetched, set to DATA
    auto-sharding and distributed with `strategy`.

    Args:
        model: A function that takes tensor-valued features and returns a
            vector of predictions.
        dataset: The dataset that the function consumes to produce the
            predictions. It is batched here.
        strategy: The distribution strategy to use when computing.
        batch_size: The batch size that should be used.

    Yields:
        One `(model_predictions, metadata)` pair per example. `time_in_s` is
        the batch's wall-clock prediction time divided by the batch size.
    """
    with strategy.scope():
        dataset = dataset.batch(batch_size).prefetch(
            tf.data.experimental.AUTOTUNE)
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.DATA)
        dataset = dataset.with_options(options)

    # `tf.distribute.TPUStrategy` and the deprecated
    # `tf.distribute.experimental.TPUStrategy` are distinct classes in TF 2.x,
    # so both must be checked. Older releases only have the experimental one.
    tpu_strategy_types = tuple(
        cls for cls in (getattr(tf.distribute, "TPUStrategy", None),
                        getattr(tf.distribute.experimental, "TPUStrategy",
                                None))
        if cls is not None)
    on_tpu = isinstance(strategy, tpu_strategy_types)

    # The loop stays outside `strategy.scope()`: this is a generator, and the
    # caller's code between `yield`s would otherwise run inside the scope.
    for features in strategy.experimental_distribute_dataset(dataset):
        time_start = time.time()
        if on_tpu:
            # TODO(josipd): Figure this out better. We can't easily filter,
            # as they are PerReplica values, not tensors.
            features_model = {"image": features["image"]}
        else:
            features_model = features
        predictions = materialize(strategy,
                                  strategy.run(model, args=(features_model,)))
        time_end = time.time()
        time_delta_per_example = (time_end - time_start) / predictions.shape[0]
        metadatas = materialize(strategy, features["metadata"])
        for i in range(predictions.shape[0]):
            model_predictions = types.ModelPredictions(
                predictions=[predictions[i]],
                time_in_s=time_delta_per_example)
            metadata_i = _slice_dictionary(metadatas, i)
            yield model_predictions, metadata_i
def compute_predictions(
        model: PredictionModel, dataset: tf.data.Dataset,
        strategy: tf.distribute.Strategy) -> Iterator[types.ModelPredictions]:
    """Yield the predictions of the model on the given dataset.

    Note that the dataset is expected to yield batches of tensors.

    Args:
        model: A function that takes tensor-valued features and returns a
            vector of predictions. Only `features["image"]` is passed to it.
        dataset: The dataset that the function consumes to produce the
            predictions.
        strategy: The distribution strategy to use when computing.

    Yields:
        One `types.ModelPredictions` per example. `element_id` is None when
        the batch has no "element_id" feature, and `time_in_s` is the batch's
        wall-clock prediction time divided by the batch size.
    """
    for features in strategy.experimental_distribute_dataset(dataset):
        # TODO(josipd): Figure out how to pass only tpu-allowed types.
        time_start = time.time()
        predictions = materialize(
            strategy, strategy.run(model, args=({
                "image": features["image"]
            }, )))
        time_end = time.time()
        num_examples = predictions.shape[0]
        time_delta_per_example = (time_end - time_start) / num_examples

        # Only the feature lookup is guarded: a KeyError raised inside
        # materialize() is a real error and must not be mistaken for
        # "no element ids".
        try:
            element_id_values = features["element_id"]
        except KeyError:
            element_ids = [None] * num_examples
        else:
            element_ids = materialize(strategy, element_id_values)

        metadatas = materialize(strategy, features["metadata"])
        for i in range(num_examples):
            yield types.ModelPredictions(
                element_id=element_ids[i],
                metadata=_slice_dictionary(metadatas, i),
                predictions=[predictions[i]],
                time_in_s=time_delta_per_example)
def create_byt5_model(model_type: str, strategy: tf.distribute.Strategy,
                      config: AutoConfig, lr: float, max_label_len: int,
                      total_steps: int) -> tf.keras.Model:
    """Build and compile a ByT5 conditional-generation model under `strategy`.

    The optimizer is Lookahead (sync_period=6, slow_step_size=0.5) wrapped
    around RectifiedAdam. RectifiedAdam uses `lr`, `total_steps`, 10% warmup
    and `min_lr = lr / 3`.

    No loss is passed to `compile`; presumably the model computes its own
    loss (confirm against `KerasTFByT5ForConditionalGeneration`).
    """
    with strategy.scope():
        inner_optimizer = tfa.optimizers.RectifiedAdam(
            learning_rate=lr,
            total_steps=total_steps,
            warmup_proportion=0.10,
            min_lr=lr / 3.)
        optimizer = tfa.optimizers.Lookahead(inner_optimizer,
                                             sync_period=6,
                                             slow_step_size=0.5)
        model = KerasTFByT5ForConditionalGeneration.from_pretrained(
            model_type, config=config)
        accuracy = T5_Accuracy(label_length=max_label_len)
        model.compile(optimizer=optimizer, metrics=[accuracy])
    return model
def train_classifier(
        model: BertABSClassifier,
        optimizer: tf.keras.optimizers.Optimizer,
        train_dataset: Iterable[ClassifierTrainBatch],
        epochs: int,
        test_dataset: Iterable[ClassifierTrainBatch] = None,
        callbacks: List[Callback] = None,
        strategy: tf.distribute.Strategy = None):
    """Tune the classifier along with the language model.

    Args:
        model: Classifier whose `bert` and `classifier` sub-models are trained.
        optimizer: Optimizer applied to both sub-models' trainable variables.
        train_dataset: Batches unpacked as
            `(token_ids, attention_mask, token_type_ids, target_labels)`.
        epochs: Number of epochs, forwarded to `routines.train`.
        test_dataset: Optional evaluation batches with the same layout.
        callbacks: Optional callbacks forwarded to `routines.train`.
        strategy: Distribution strategy. If None, a
            `tf.distribute.OneDeviceStrategy('CPU')` is created for this call.
    """
    if strategy is None:
        # Created here rather than as a default-argument value. A default is
        # evaluated at import time, so merely importing the module would
        # construct a TF strategy, which may initialise the TF runtime before
        # callers configure devices.
        strategy = tf.distribute.OneDeviceStrategy('CPU')

    with strategy.scope():

        def train_step(*batch: tf.Tensor):
            """Run one optimisation step; return `[loss, *model_outputs]`."""
            token_ids, attention_mask, token_type_ids, target_labels = batch
            with tf.GradientTape() as tape:
                model_outputs = model.call(token_ids,
                                           attention_mask=attention_mask,
                                           token_type_ids=token_type_ids,
                                           training=True)
                logits, *_ = model_outputs
                loss_value = classifier_loss(target_labels, logits)

            variables = model.bert.trainable_variables \
                + model.classifier.trainable_variables
            grads = tape.gradient(loss_value, variables)
            optimizer.apply_gradients(zip(grads, variables))
            return [loss_value, *model_outputs]

        def test_step(*batch: tf.Tensor):
            """Evaluate one batch; return `[loss, *model_outputs]`."""
            token_ids, attention_mask, token_type_ids, target_labels = batch
            model_outputs = model.call(token_ids,
                                       attention_mask=attention_mask,
                                       token_type_ids=token_type_ids)
            logits, *_ = model_outputs
            loss_value = classifier_loss(target_labels, logits)
            return [loss_value, *model_outputs]

        routines.train(strategy=strategy,
                       train_step=train_step,
                       train_dataset=train_dataset,
                       test_step=test_step,
                       test_dataset=test_dataset,
                       epochs=epochs,
                       callbacks=callbacks)
def build_model_and_get_trainer(config: DotMap, data_loader: DataLoader,
                                strategy: tf.distribute.Strategy) -> Trainer:
    """Build the models named by `config.model.structure` and wrap them in a trainer.

    The generator and discriminator are created inside `strategy.scope()`.

    Raises:
        ValueError: If the structure is not 'pix2pix'.
    """
    model_structure = config.model.structure
    print('Create the model')
    if model_structure != 'pix2pix':
        raise ValueError(f"unknown model structure {model_structure}")

    with strategy.scope():
        generator = get_generator_model(config)
        discriminator = get_discriminator_model(config)

    return Pix2PixTrainer(generator=generator,
                          discriminator=discriminator,
                          data_loader=data_loader,
                          strategy=strategy,
                          config=config)
def materialize(strategy: tf.distribute.Strategy, value_or_nested_dict):
    """Turn PerReplica values (optionally inside nested dicts) into local arrays.

    Args:
        strategy: The strategy that will be used to evaluate.
        value_or_nested_dict: Either a single `PerReplica` object, or a nested
            dict with `PerReplica` values at the deepest level.

    Returns:
        The same structure, with each leaf replaced by the concatenation
        (along axis 0) of its local replica results, converted via
        `.numpy()`.
    """
    if not isinstance(value_or_nested_dict, dict):
        local_parts = strategy.experimental_local_results(value_or_nested_dict)
        return tf.concat(local_parts, axis=0).numpy()

    # Recurse into dict leaves, preserving keys.
    return {
        name: materialize(strategy, child)
        for name, child in value_or_nested_dict.items()
    }
def get_compiled_model(strategy: tf.distribute.Strategy) -> tf.keras.Model:
    """Build and compile a small CNN for 28x28x1 inputs with 10 logit outputs.

    Layers: Conv2D(32, 3, relu), MaxPooling2D, Flatten, Dense(64, relu),
    Dense(10). Compiled with Adam, sparse categorical cross-entropy on logits
    and accuracy. Created and compiled under `strategy.scope()`.
    """
    with strategy.scope():
        layer_stack = [
            tf.keras.layers.Conv2D(32, 3, activation="relu",
                                   input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(10),
        ]
        model = tf.keras.Sequential(layer_stack)
        model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
            optimizer=tf.keras.optimizers.Adam(),
            metrics=["accuracy"],
        )
    return model
def __init__(self, generator: tf.keras.Model,
             discriminator: tf.keras.Model,
             data_loader: DataLoader,
             strategy: tf.distribute.Strategy,
             config: DotMap) -> None:
    """Set up the pix2pix trainer state.

    Stores both models and creates, under `strategy.scope()`, one Adam
    optimizer per model (lr/beta1 read from `config.model.generator` and
    `config.model.discriminator`) and two binary-accuracy metrics. It also
    opens a TensorBoard writer in a timestamped subdirectory of
    `config.exp.tensorboard_dir` and builds a checkpoint over the models and
    optimizers.
    """
    super().__init__(data_loader, strategy, config)
    self.generator: tf.keras.Model = generator
    self.discriminator: tf.keras.Model = discriminator

    gen_cfg = config.model.generator
    disc_cfg = config.model.discriminator
    with strategy.scope():
        self.generator_optimizer = tf.keras.optimizers.Adam(
            learning_rate=gen_cfg.lr, beta_1=gen_cfg.beta1)
        self.discriminator_optimizer = tf.keras.optimizers.Adam(
            learning_rate=disc_cfg.lr, beta_1=disc_cfg.beta1)
        self.disc_real_accuracy = tf.keras.metrics.BinaryAccuracy(
            name='real_accuracy')
        self.disc_fake_accuracy = tf.keras.metrics.BinaryAccuracy(
            name='fake_accuracy')

    run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.join(config.exp.tensorboard_dir, run_stamp)
    self.summary_writer = tf.summary.create_file_writer(log_dir)

    self.checkpoint = tf.train.Checkpoint(
        generator_optimizer=self.generator_optimizer,
        discriminator_optimizer=self.discriminator_optimizer,
        generator=generator,
        discriminator=discriminator)
def train_eval(
        root_dir,
        strategy: tf.distribute.Strategy,
        env_name='HalfCheetah-v2',
        # Training params
        initial_collect_steps=10000,
        num_iterations=3200000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Agent params
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        gamma=0.99,
        target_update_tau=0.005,
        target_update_period=1,
        reward_scale_factor=0.1,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        policy_save_interval=10000,
        replay_buffer_save_interval=100000,
        eval_interval=10000,
        eval_episodes=30,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC.

    The setup phase builds the following:
      * a SAC agent for the `suite_mujoco` environment `env_name`;
      * a local Reverb server with a uniform-sampling, FIFO-removal table;
      * a Learner with policy-export, replay-checkpoint and steps/sec
        triggers;
      * collect and greedy-eval actors.

    The run phase then:
      * seeds the replay buffer with `initial_collect_steps` random-policy
        steps;
      * runs `num_iterations` rounds of one collect step plus one train
        iteration;
      * evaluates once before training and every `eval_interval` train
        steps, when `eval_interval` is non-zero.

    The Reverb writer and server are released even if collection or
    training raises.
    """
    logging.info('Training SAC on: %s', env_name)
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)

    _, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    actor_net = create_sequential_actor_network(
        actor_fc_layers=actor_fc_layers,
        action_tensor_spec=action_tensor_spec)
    critic_net = create_sequential_critic_network(
        obs_fc_layer_units=critic_obs_fc_layers,
        action_fc_layer_units=critic_action_fc_layers,
        joint_fc_layer_units=critic_joint_fc_layers)

    # The train-step counter and the agent's variables are created under
    # the strategy's scope.
    with strategy.scope():
        train_step = train_utils.create_train_step()
        agent = sac_agent.SacAgent(
            time_step_tensor_spec,
            action_tensor_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.keras.optimizers.Adam(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.keras.optimizers.Adam(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.keras.optimizers.Adam(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=tf.math.squared_difference,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=None,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step)
        agent.initialize()

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))

    reverb_checkpoint_dir = os.path.join(root_dir, learner.TRAIN_DIR,
                                         learner.REPLAY_BUFFER_CHECKPOINT_DIR)
    reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(
        path=reverb_checkpoint_dir)
    reverb_server = reverb.Server([table],
                                  port=reverb_port,
                                  checkpointer=reverb_checkpointer)

    rb_observer = None
    try:
        reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
            agent.collect_data_spec,
            sequence_length=2,
            table_name=table_name,
            local_server=reverb_server)
        rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
            reverb_replay.py_client,
            table_name,
            sequence_length=2,
            stride_length=1)

        dataset = reverb_replay.as_dataset(sample_batch_size=batch_size,
                                           num_steps=2).prefetch(50)
        experience_dataset_fn = lambda: dataset

        saved_model_dir = os.path.join(root_dir,
                                       learner.POLICY_SAVED_MODEL_DIR)
        env_step_metric = py_metrics.EnvironmentSteps()
        learning_triggers = [
            triggers.PolicySavedModelTrigger(
                saved_model_dir,
                agent,
                train_step,
                interval=policy_save_interval,
                metadata_metrics={
                    triggers.ENV_STEP_METADATA_KEY: env_step_metric
                }),
            triggers.ReverbCheckpointTrigger(
                train_step,
                interval=replay_buffer_save_interval,
                reverb_client=reverb_replay.py_client),
            # TODO(b/165023684): Add SIGTERM handler to checkpoint before
            # preemption.
            triggers.StepPerSecondLogTrigger(train_step, interval=1000),
        ]

        agent_learner = learner.Learner(root_dir,
                                        train_step,
                                        agent,
                                        experience_dataset_fn,
                                        triggers=learning_triggers,
                                        strategy=strategy)

        random_policy = random_py_policy.RandomPyPolicy(
            collect_env.time_step_spec(), collect_env.action_spec())
        initial_collect_actor = actor.Actor(
            collect_env,
            random_policy,
            train_step,
            steps_per_run=initial_collect_steps,
            observers=[rb_observer])
        logging.info('Doing initial collect.')
        initial_collect_actor.run()

        tf_collect_policy = agent.collect_policy
        collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
            tf_collect_policy, use_tf_function=True)

        collect_actor = actor.Actor(
            collect_env,
            collect_policy,
            train_step,
            steps_per_run=1,
            metrics=actor.collect_metrics(10),
            summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
            observers=[rb_observer, env_step_metric])

        tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
        eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
            tf_greedy_policy, use_tf_function=True)

        eval_actor = actor.Actor(
            eval_env,
            eval_greedy_policy,
            train_step,
            episodes_per_run=eval_episodes,
            metrics=actor.eval_metrics(eval_episodes),
            summary_dir=os.path.join(root_dir, 'eval'),
        )

        if eval_interval:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

        logging.info('Training.')
        for _ in range(num_iterations):
            collect_actor.run()
            agent_learner.run(iterations=1)

            if (eval_interval
                    and agent_learner.train_step_numpy % eval_interval == 0):
                logging.info('Evaluating.')
                eval_actor.run_and_log()
    finally:
        # Release the writer first, then the in-process server. This runs
        # on both success and error paths so a failure does not leave the
        # server running.
        try:
            if rb_observer is not None:
                rb_observer.close()
        finally:
            reverb_server.stop()
def run_experiment(distribution_strategy: tf.distribute.Strategy,
                   task: base_task.Task,
                   mode: str,
                   params: config_definitions.ExperimentConfig,
                   model_dir: str,
                   run_post_eval: bool = False,
                   save_summary: bool = True) \
        -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Run train and/or eval as configured by the experiment params.

    Args:
        distribution_strategy: Strategy under which the trainer is created
            and every train/eval call runs.
        task: A Task instance.
        mode: One of 'train', 'eval', 'train_and_eval' or 'continuous_eval'.
        params: ExperimentConfig instance.
        model_dir: Directory for checkpoints and (if enabled) summaries.
        run_post_eval: Whether to run one more evaluation after `mode`
            finishes and return its metric logs.
        save_summary: Whether to write train and validation summaries.

    Returns:
        A 2-tuple `(model, eval_logs)`: the trainer's `tf.keras.Model`, and
        the post-eval metric logs, or `{}` when `run_post_eval` is False.

    Raises:
        NotImplementedError: If `mode` is not one of the values above.
    """
    is_training = 'train' in mode
    with distribution_strategy.scope():
        trainer = train_utils.create_trainer(
            params,
            task,
            model_dir=model_dir,
            train=is_training,
            evaluate=('eval' in mode) or run_post_eval,
            checkpoint_exporter=maybe_create_best_ckpt_exporter(
                params, model_dir))

    checkpoint_manager = None
    if trainer.checkpoint:
        checkpoint_manager = tf.train.CheckpointManager(
            trainer.checkpoint,
            directory=model_dir,
            max_to_keep=params.trainer.max_to_keep,
            step_counter=trainer.global_step,
            checkpoint_interval=params.trainer.checkpoint_interval,
            init_fn=trainer.initialize)

    def summary_path(subdir):
        # Summaries are disabled (None) when save_summary is False.
        return os.path.join(model_dir, subdir) if save_summary else None

    controller = orbit.Controller(
        distribution_strategy,
        trainer=trainer if is_training else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=summary_path('train'),
        eval_summary_dir=summary_path('validation'),
        summary_interval=(params.trainer.summary_interval
                          if save_summary else None))

    logging.info('Starts to execute mode: %s', mode)
    trainer_cfg = params.trainer
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=trainer_cfg.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=trainer_cfg.train_steps,
                eval_steps=trainer_cfg.validation_steps,
                eval_interval=trainer_cfg.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=trainer_cfg.validation_steps)
        elif mode == 'continuous_eval':

            def timeout_fn():
                # Stop waiting for new checkpoints once training is done.
                return bool(
                    trainer.global_step.numpy() >= trainer_cfg.train_steps)

            controller.evaluate_continuously(
                steps=trainer_cfg.validation_steps,
                timeout=trainer_cfg.continuous_eval_timeout,
                timeout_fn=timeout_fn)
        else:
            raise NotImplementedError('The mode is not implemented: %s' % mode)

    if not run_post_eval:
        return trainer.model, {}
    with distribution_strategy.scope():
        post_eval_logs = trainer.evaluate(
            tf.convert_to_tensor(trainer_cfg.validation_steps))
        return trainer.model, post_eval_logs
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[base_trainer.Trainer] = None,
    controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Run train and/or eval as configured by the experiment params.

    Args:
        distribution_strategy: Strategy under which the trainer is created
            (if needed) and every train/eval call runs.
        task: A Task instance.
        mode: One of 'train', 'eval', 'train_and_eval' or 'continuous_eval'.
        params: ExperimentConfig instance.
        model_dir: Directory for checkpoints and summaries. It is required
            when the trainer has a checkpoint.
        run_post_eval: Whether to run one more evaluation after `mode`
            finishes and return its metric logs.
        save_summary: Whether to write train and validation summaries.
        trainer: Optional pre-built `base_trainer.Trainer`, which should
            have been created within `strategy.scope()`. If falsy, one is
            created here.
        controller_cls: Controller class that drives train/eval. It must be
            an `orbit.Controller` subclass.

    Returns:
        A 2-tuple `(model, eval_logs)`: the trainer's `tf.keras.Model`, and
        the post-eval metric logs, or `{}` when `run_post_eval` is False.

    Raises:
        ValueError: If the trainer has a checkpoint but `model_dir` is None.
        NotImplementedError: If `mode` is not one of the values above.
    """
    is_training = 'train' in mode
    with distribution_strategy.scope():
        if not trainer:
            trainer = train_utils.create_trainer(
                params,
                task,
                train=is_training,
                evaluate=('eval' in mode) or run_post_eval,
                checkpoint_exporter=maybe_create_best_ckpt_exporter(
                    params, model_dir))

    checkpoint_manager = None
    if trainer.checkpoint:
        if model_dir is None:
            raise ValueError('model_dir must be specified, but got None')
        checkpoint_manager = tf.train.CheckpointManager(
            trainer.checkpoint,
            directory=model_dir,
            max_to_keep=params.trainer.max_to_keep,
            step_counter=trainer.global_step,
            checkpoint_interval=params.trainer.checkpoint_interval,
            init_fn=trainer.initialize)

    def summary_path(subdir):
        # Summaries are disabled (None) when save_summary is False.
        return os.path.join(model_dir, subdir) if save_summary else None

    controller = controller_cls(
        strategy=distribution_strategy,
        trainer=trainer if is_training else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=summary_path('train'),
        eval_summary_dir=summary_path(
            params.trainer.validation_summary_subdir),
        summary_interval=(params.trainer.summary_interval
                          if save_summary else None),
        train_actions=actions.get_train_actions(
            params, trainer, model_dir,
            checkpoint_manager=checkpoint_manager),
        eval_actions=actions.get_eval_actions(params, trainer, model_dir))

    logging.info('Starts to execute mode: %s', mode)
    trainer_cfg = params.trainer
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=trainer_cfg.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=trainer_cfg.train_steps,
                eval_steps=trainer_cfg.validation_steps,
                eval_interval=trainer_cfg.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=trainer_cfg.validation_steps)
        elif mode == 'continuous_eval':

            def timeout_fn():
                # Stop waiting for new checkpoints once training is done.
                return bool(
                    trainer.global_step.numpy() >= trainer_cfg.train_steps)

            controller.evaluate_continuously(
                steps=trainer_cfg.validation_steps,
                timeout=trainer_cfg.continuous_eval_timeout,
                timeout_fn=timeout_fn)
        else:
            raise NotImplementedError('The mode is not implemented: %s' % mode)

    # Best-effort model statistics; each helper may return None.
    num_params = train_utils.try_count_params(trainer.model)
    if num_params is not None:
        logging.info('Number of trainable params in model: %f Millions.',
                     num_params / 10.**6)

    flops = train_utils.try_count_flops(trainer.model)
    if flops is not None:
        logging.info('FLOPs (multi-adds) in model: %f Billions.',
                     flops / 10.**9 / 2)

    if not run_post_eval:
        return trainer.model, {}
    with distribution_strategy.scope():
        post_eval_logs = trainer.evaluate(
            tf.convert_to_tensor(trainer_cfg.validation_steps))
        return trainer.model, post_eval_logs
def get_datasets(args, strategy: tf.distribute.Strategy, buffer_size: int = 256):
    """Load, preprocess and distribute the horse2zebra CycleGAN dataset.

    Side effect: sets `args.train_steps` and `args.test_steps` to the number
    of `args.global_batch_size` batches per epoch. Each split uses the
    smaller of its two domains, and both domains are truncated to that size.

    Returns:
        `(train_ds, test_ds, plot_ds)`. The first two are distributed zipped
        `(horse_batch, zebra_batch)` datasets. `plot_ds` is a local dataset
        of 5 single-image test pairs.
    """
    dataset, metadata = tfds.load("cycle_gan/horse2zebra",
                                  with_info=True,
                                  as_supervised=True)

    def split_size(name):
        return metadata.splits.get(name).num_examples

    num_train_samples = min(split_size('trainA'), split_size('trainB'))
    num_test_samples = min(split_size('testA'), split_size('testB'))
    args.train_steps = ceil(num_train_samples / args.global_batch_size)
    args.test_steps = ceil(num_test_samples / args.global_batch_size)

    def normalize_image(image):
        """Map pixel values (assumed to be in [0, 255]) to [-1, 1]."""
        return tf.cast(image, dtype=tf.float32) / 127.5 - 1.0

    def preprocess_train(image, _):
        # Random horizontal flip, then resize to IMAGE_SHAPE, random crop to
        # INPUT_SHAPE, and normalize.
        image = tf.image.random_flip_left_right(image)
        image = tf.image.resize(image, size=IMAGE_SHAPE)
        image = tf.image.random_crop(image, size=INPUT_SHAPE)
        return normalize_image(image)

    def preprocess_test(image, _):
        # Deterministic: resize straight to the model's spatial input size.
        image = tf.image.resize(image, size=INPUT_SHAPE[:2])
        return normalize_image(image)

    def prepare(split, num_samples, preprocess, shuffle):
        """Truncate, preprocess and cache one domain; optionally shuffle."""
        prepared = split.take(num_samples)
        prepared = prepared.map(preprocess, num_parallel_calls=AUTOTUNE)
        prepared = prepared.cache()
        return prepared.shuffle(buffer_size) if shuffle else prepared

    train_horses = prepare(dataset["trainA"], num_train_samples,
                           preprocess_train, True)
    train_zebras = prepare(dataset["trainB"], num_train_samples,
                           preprocess_train, True)
    test_horses = prepare(dataset["testA"], num_test_samples,
                          preprocess_test, False)
    test_zebras = prepare(dataset["testB"], num_test_samples,
                          preprocess_test, False)

    global_batch = args.global_batch_size
    train_ds = tf.data.Dataset.zip(
        (train_horses.batch(global_batch),
         train_zebras.batch(global_batch))).prefetch(AUTOTUNE)
    test_ds = tf.data.Dataset.zip(
        (test_horses.batch(global_batch), test_zebras.batch(global_batch)))

    # Five single-image (horse, zebra) test pairs for plotting.
    plot_ds = tf.data.Dataset.zip(
        (test_horses.take(5).batch(1), test_zebras.take(5).batch(1)))

    return (strategy.experimental_distribute_dataset(train_ds),
            strategy.experimental_distribute_dataset(test_ds),
            plot_ds)
def run(train_dataset: tf.data.Dataset,
        eval_datasets: Dict[str, tf.data.Dataset],
        steps_per_eval: Dict[str, int],
        params: utils.ModelParameters,
        model_dir: str,
        strategy: tf.distribute.Strategy,
        summary_writer: tf.summary.SummaryWriter,
        loss_type: str,
        graph_augmenter: augmentation_utils.GraphAugment):
    """Trains and evaluates the model.

    Each epoch runs `params.steps_per_epoch` distributed training steps, then
    evaluates on every dataset in `eval_datasets`. All metrics are written to
    `summary_writer`, appended to a history and reset, and the model is saved
    to `model_dir/model_<epoch>`. The history is written to
    `metrics_history.json` in `model_dir` at the end.

    Args:
        train_dataset: training data yielding `(features, labels)` or
            `(features, labels, sample_weights)` tuples.
        eval_datasets: evaluation datasets keyed by name.
        steps_per_eval: number of evaluation steps for each eval dataset.
        params: model and training hyper-parameters.
        model_dir: output directory for saved models and metrics history.
        strategy: distribution strategy used for training and evaluation.
        summary_writer: writer receiving one scalar per metric per epoch.
        loss_type: `'focal'` selects sigmoid focal loss; anything else uses
            categorical cross-entropy.
        graph_augmenter: applied to training features when
            `params.augmentations` is truthy.
    """
    with strategy.scope():
        spec = train_dataset.element_spec[0]
        model = ub.models.mpnn(
            nodes_shape=spec['atoms'].shape[1:],
            edges_shape=spec['pairs'].shape[1:],
            num_heads=params.num_heads,
            num_layers=params.num_layers,
            message_layer_size=params.message_layer_size,
            readout_layer_size=params.readout_layer_size,
            use_gp_layer=params.use_gp_layer)
        optimizer = tf.keras.optimizers.RMSprop(
            learning_rate=params.learning_rate)

        metrics = {
            'train/negative_log_likelihood': tf.keras.metrics.Mean(),
            'train/accuracy': tf.keras.metrics.CategoricalAccuracy(),
            'train/loss': tf.keras.metrics.Mean(),
            'train/roc_auc': tf.keras.metrics.AUC(),
        }
        for name in eval_datasets:
            metrics.update({
                f'{name}/accuracy': tf.keras.metrics.CategoricalAccuracy(),
                f'{name}/roc_auc': tf.keras.metrics.AUC(),
                f'{name}/negative_log_likelihood': tf.keras.metrics.Mean(),
                f'{name}/ece': rm.metrics.ExpectedCalibrationError(
                    num_bins=5 if name == 'test2' else 10),
                f'{name}/brier': rm.metrics.Brier(),
            })

    def _train_on_replica(inputs):
        """Per-Replica StepFn: one optimisation step on this replica's batch."""
        if len(inputs) == 3:
            features, labels, sample_weights = inputs
        else:
            features, labels = inputs
            sample_weights = 1

        if params.augmentations:
            # TODO(jihyeonlee): For now, choose 1 augmentation function from all
            # possible with equal probability. Allow user to specify number of
            # augmentations to apply per graph.
            features = graph_augmenter.augment(features)

        with tf.GradientTape() as tape:
            probs = model(features, training=True)
            nll = tf.reduce_mean(
                tf.keras.losses.categorical_crossentropy(labels, probs) *
                sample_weights)
            l2_loss = sum(model.losses)
            if loss_type != 'focal':
                loss = nll + l2_loss
            else:
                focal_fn = tfa_losses.SigmoidFocalCrossEntropy()
                focal = tf.reduce_mean(focal_fn(labels, probs) * sample_weights)
                loss = focal + l2_loss
            # The strategy sums gradients across replicas, so divide the loss
            # by the replica count. See
            # https://www.tensorflow.org/tutorials/distribute/custom_training#define_the_loss_function
            scaled_loss = loss / strategy.num_replicas_in_sync

        variables = model.trainable_variables
        optimizer.apply_gradients(zip(tape.gradient(scaled_loss, variables),
                                      variables))

        metrics['train/loss'].update_state(loss)
        metrics['train/negative_log_likelihood'].update_state(nll)
        metrics['train/accuracy'].update_state(labels, probs)
        metrics['train/roc_auc'].update_state(labels[:, 1], probs[:, 1])

    @tf.function
    def train_step(iterator):
        """Training StepFn: runs one epoch of distributed steps."""
        for _ in tf.range(tf.cast(params.steps_per_epoch, tf.int32)):
            strategy.run(_train_on_replica, args=(next(iterator), ))

    @tf.function
    def eval_step(iterator, dataset_name, num_steps):
        """Evaluation StepFn."""

        def _eval_on_replica(inputs):
            """Per-Replica StepFn."""
            if len(inputs) == 3:
                features, labels, _ = inputs
            else:
                features, labels = inputs

            probs = model(features, training=False)
            nll = tf.reduce_mean(
                tf.keras.losses.categorical_crossentropy(labels, probs))
            positive_labels = labels[:, 1]

            metrics[f'{dataset_name}/negative_log_likelihood'].update_state(nll)
            metrics[f'{dataset_name}/accuracy'].update_state(labels, probs)
            metrics[f'{dataset_name}/roc_auc'].update_state(
                positive_labels, probs[:, 1])
            metrics[f'{dataset_name}/ece'].add_batch(
                probs[:, 1], label=positive_labels)
            metrics[f'{dataset_name}/brier'].add_batch(
                probs, label=positive_labels)

        for _ in tf.range(tf.cast(num_steps, tf.int32)):
            strategy.run(_eval_on_replica, args=(next(iterator), ))

    # Makes datasets into distributed version.
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    eval_datasets = {
        name: strategy.experimental_distribute_dataset(ds)
        for name, ds in eval_datasets.items()
    }
    logging.info('Number of replicas in sync: %s',
                 strategy.num_replicas_in_sync)

    train_iterator = iter(train_dataset)
    start_time = time.time()
    metrics_history = collections.defaultdict(list)
    max_steps = params.steps_per_epoch * params.num_epochs

    for epoch in range(params.num_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        train_step(train_iterator)

        finished = epoch + 1
        current_step = finished * params.steps_per_epoch
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        logging.info(
            '{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
            'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                current_step / max_steps, finished, params.num_epochs,
                steps_per_sec, eta_seconds / 60, time_elapsed / 60))

        # Start evaluation.
        logging.info('Starting to run eval at epoch: %s', epoch)
        for name, ds in eval_datasets.items():
            eval_step(iter(ds), name, steps_per_eval[name])

        metrics_history['epoch'].append(finished)
        with summary_writer.as_default():
            for name, metric in metrics.items():
                value = utils.get_metric_result_value(metric)
                tf.summary.scalar(name, value, step=finished)
                metrics_history[name].append(str(value))

        for metric in metrics.values():
            metric.reset_states()

        model.save(os.path.join(model_dir, f'model_{finished}'), overwrite=True)

    utils.write_params(metrics_history,
                       os.path.join(model_dir, 'metrics_history.json'))
def train(
        strategy: tf.distribute.Strategy,
        model_fn: Callable,
        input_meta_data: Dict,
        train_input_fn: Callable,
        total_training_steps: int,
        steps_per_loop: int,
        optimizer: tf.keras.optimizers.Optimizer,
        learning_rate_fn: tf.keras.optimizers.schedules.LearningRateSchedule,
        eval_fn: Optional[Callable[
            [tf.keras.Model, int, tf.summary.SummaryWriter], Any]] = None,
        metric_fn: Optional[Callable[[], tf.keras.metrics.Metric]] = None,
        init_checkpoint: Optional[Text] = None,
        init_from_transformerxl: Optional[bool] = False,
        model_dir: Optional[Text] = None,
        save_steps: Optional[int] = None,
        run_eagerly: Optional[bool] = False):
    """Runs customized training.

    Args:
        strategy: Distribution strategy on which to run low level training loop.
        model_fn: The function returns a keras.Model.
        input_meta_data: A dictionary of params: `mem_len`,
            `lr_layer_decay_rate`, `n_layer`, `batch_size_per_core` and
            `d_model`.
        train_input_fn: Function returns a tf.data.Dataset used for training.
        total_training_steps: Number of steps to train in total.
        steps_per_loop: Number of steps per graph-mode loop. In order to reduce
            communication in eager context, training logs are printed every
            steps_per_loop.
        optimizer: The optimizer for model.
        learning_rate_fn: the learning rate schedule.
        eval_fn: A callback of evaluation function, that takes a keras.Model,
            current step and evaluation summary writer.
        metric_fn: A metrics function returns a Keras Metric object to record
            evaluation result using evaluation dataset or with training dataset
            after every epoch.
        init_checkpoint: Optional checkpoint to load to `sub_model` returned by
            `model_fn`.
        init_from_transformerxl: Whether to load to `transformerxl_model` of
            `model_fn`.
        model_dir: The directory of model (checkpoints, summaries).
        save_steps: The frequency to save checkpoints. Every save_steps, we save
            a model checkpoint. Model checkpoint will be saved and evaluation
            will be conducted if evaluation dataset is provided.
        run_eagerly: Whether to run training eagerly.

    Returns:
        The trained `tf.keras.Model` created by `model_fn`.

    Raises:
        ValueError: if any of `train_input_fn`, `total_training_steps`,
            `steps_per_loop`, `optimizer`, `learning_rate_fn` or `save_steps`
            is None.
        TypeError: if model directory is not specified.
    """
    required_arguments = [
        train_input_fn, total_training_steps, steps_per_loop, optimizer,
        learning_rate_fn, save_steps
    ]
    if any(arg is None for arg in required_arguments):
        raise ValueError("`train_input_fn`, `total_training_steps`, "
                         "`steps_per_loop`, `optimizer`, `save_steps` and "
                         "`learning_rate_fn` are required parameters.")
    if not model_dir:
        raise TypeError("Model directory must be specified.")
    train_iterator = data_utils.get_input_iterator(train_input_fn, strategy)

    # makedirs (unlike mkdir) also creates any missing parent directories.
    if not tf.io.gfile.exists(model_dir):
        tf.io.gfile.makedirs(model_dir)
    # Create summary writers
    summary_dir = os.path.join(model_dir, "summaries")
    if not tf.io.gfile.exists(summary_dir):
        tf.io.gfile.makedirs(summary_dir)
    train_summary_writer = None
    eval_summary_writer = None
    if eval_fn:
        eval_summary_writer = tf.summary.create_file_writer(
            os.path.join(summary_dir, "eval"))
    if steps_per_loop >= _MIN_SUMMARY_STEPS:
        # Only writes summary when the stats are collected sufficiently over
        # enough steps.
        train_summary_writer = tf.summary.create_file_writer(
            os.path.join(summary_dir, "train"))

    with strategy.scope():
        model = model_fn()

        if init_checkpoint:
            logging.info("restore from %s", init_checkpoint)
            if init_from_transformerxl:
                checkpoint = tf.train.Checkpoint(
                    transformer_xl=model.transformerxl_model)
            else:
                checkpoint = tf.train.Checkpoint(model=model)
            checkpoint.restore(init_checkpoint)

        model.optimizer = optimizer

        train_loss_metric = tf.keras.metrics.Mean("training_loss",
                                                  dtype=tf.float32)
        train_metric = metric_fn() if metric_fn else None

        def _replicated_step(inputs, mem=None):
            """Replicated training step.

            Returns the memory produced by the model when
            `input_meta_data["mem_len"] > 0`, otherwise None.
            """
            inputs["mems"] = mem
            with tf.GradientTape() as tape:
                mem, logits = model(inputs, training=True)
                loss = model.losses
                train_loss_metric.update_state(loss)
                if train_metric:
                    train_metric.update_state(inputs["label_ids"], logits)
                scaled_loss = loss[0] * 1.0 / float(
                    strategy.num_replicas_in_sync)

            # Collects training variables.
            tvars = model.trainable_variables
            grads = tape.gradient(scaled_loss, tvars)
            clipped, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)

            if input_meta_data["lr_layer_decay_rate"] != 1.0:
                # Number of transformer layers = highest layer index + 1 found
                # in the trainable variable names.
                n_layer = 0
                for var in tvars:
                    m = re.search(r"model/transformer/layer_(\d+?)/", var.name)
                    if not m:
                        continue
                    n_layer = max(n_layer, int(m.group(1)) + 1)

                # Scale each layer's gradient by decay_rate**(depth from top),
                # so lower layers receive smaller updates.
                for i, var in enumerate(tvars):
                    for layer in range(n_layer):
                        if "model/transformer/layer_{}/".format(
                                layer) in var.name:
                            abs_rate = input_meta_data[
                                "lr_layer_decay_rate"]**(n_layer - 1 - layer)
                            clipped[i] *= abs_rate
                            logging.info(
                                "Apply mult {:.4f} to layer-{} grad of {}".
                                format(abs_rate, layer, var.name))
                            break

            optimizer.apply_gradients(zip(clipped, tvars))
            if input_meta_data["mem_len"] > 0:
                return mem

        def train_steps(iterator, steps):
            """Performs distributed training steps in a loop.

            Args:
                iterator: the distributed iterator of training datasets.
                steps: an tf.int32 integer tensor to specify number of steps to
                    run inside host training loop.

            Raises:
                ValueError: if `steps` is not a `tf.Tensor` (a Python value
                    could cause retracing).
            """
            if not isinstance(steps, tf.Tensor):
                raise ValueError(
                    "steps should be an Tensor. Python object may cause "
                    "retracing.")

            def cache_fn():
                """Initializes memory tensor used in XLNet pretraining."""
                mems = []
                if input_meta_data["mem_len"] > 0:
                    for _ in range(input_meta_data["n_layer"]):
                        zeros = tf.zeros([
                            input_meta_data["batch_size_per_core"],
                            input_meta_data["mem_len"],
                            input_meta_data["d_model"]
                        ], dtype=tf.float32)
                        mems.append(zeros)
                return mems

            if input_meta_data["mem_len"] > 0:
                mem = strategy.run(cache_fn)
                for _ in tf.range(steps):
                    mem = strategy.run(_replicated_step,
                                       args=(next(iterator), mem,))
            else:
                for _ in tf.range(steps):
                    strategy.run(_replicated_step, args=(next(iterator),))

        if not run_eagerly:
            train_steps = tf.function(train_steps)

        logging.info("Start training...")
        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
        if latest_checkpoint_file:
            logging.info(
                "Checkpoint file %s found and restoring from checkpoint",
                latest_checkpoint_file)
            checkpoint.restore(latest_checkpoint_file)
            logging.info("Loading from checkpoint file completed")

        current_step = optimizer.iterations.numpy()
        checkpoint_name = "xlnet_step_{step}.ckpt"

        while current_step < total_training_steps:
            train_loss_metric.reset_states()
            if train_metric:
                train_metric.reset_states()

            steps = model_training_utils.steps_to_run(current_step, save_steps,
                                                      steps_per_loop)
            train_steps(train_iterator,
                        tf.convert_to_tensor(steps, dtype=tf.int32))
            current_step += steps
            train_loss = _float_metric_value(train_loss_metric)
            log_stream = "Train step: %d/%d / lr = %.9f / loss = %.7f" % (
                current_step, total_training_steps,
                learning_rate_fn(current_step), train_loss)
            if train_metric:
                log_stream += " / %s = %f" % (
                    train_metric.name, _float_metric_value(train_metric))
            logging.info(log_stream)

            if train_summary_writer:
                with train_summary_writer.as_default():
                    tf.summary.scalar("learning_rate",
                                      learning_rate_fn(current_step),
                                      step=current_step)
                    tf.summary.scalar(train_loss_metric.name, train_loss,
                                      step=current_step)
                    if train_metric:
                        tf.summary.scalar(train_metric.name,
                                          _float_metric_value(train_metric),
                                          step=current_step)
                train_summary_writer.flush()

            if model_dir and current_step % save_steps == 0:
                _save_checkpoint(checkpoint, model_dir,
                                 checkpoint_name.format(step=current_step))

            if eval_fn and current_step % save_steps == 0:
                logging.info("Running evaluation after step: %s.",
                             current_step)
                eval_fn(model, current_step, eval_summary_writer)

        if model_dir:
            _save_checkpoint(checkpoint, model_dir,
                             checkpoint_name.format(step=current_step))
            if eval_fn:
                logging.info(
                    "Running final evaluation after training is complete.")
                eval_metric = eval_fn(model, current_step,
                                      eval_summary_writer)

        training_summary = {
            "total_training_steps": total_training_steps,
            "train_loss": _float_metric_value(train_loss_metric),
        }
        if train_metric:
            training_summary["last_train_metrics"] = _float_metric_value(
                train_metric)
        if eval_fn:
            # eval_metric is supposed to be a float.
            training_summary["eval_metrics"] = eval_metric

        model_training_utils.write_txt_summary(training_summary, summary_dir)

        return model
def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
                   task: multitask.MultiTask,
                   model: base_model.MultiTaskBaseModel,
                   mode: str,
                   params: configs.MultiTaskExperimentConfig,
                   model_dir: str) -> base_model.MultiTaskBaseModel:
    """Runs train/eval configured by the experiment params.

    Args:
        distribution_strategy: A distribution distribution_strategy.
        task: A MultiTaskTask instance.
        model: A MultiTaskBaseModel instance.
        mode: A 'str', specifying the mode. Can be 'train', 'eval',
            'train_and_eval' or 'continuous_eval'.
        params: ExperimentConfig instance.
        model_dir: A 'str', a path to store model checkpoints and summaries.

    Returns:
        model: `base_model.MultiTaskBaseModel` instance.

    Raises:
        NotImplementedError: if `mode` is not one of the supported modes.
    """
    # Reject unsupported modes before building anything. A mode containing
    # neither 'train' nor 'eval' would otherwise leave both `trainer` and
    # `evaluator` as None and fail on `evaluator.checkpoint` with an
    # AttributeError instead of this NotImplementedError.
    if mode not in ('train', 'train_and_eval', 'eval', 'continuous_eval'):
        raise NotImplementedError('The mode is not implemented: %s' % mode)

    is_training = 'train' in mode
    is_eval = 'eval' in mode
    with distribution_strategy.scope():
        optimizer = task.create_optimizer(params.trainer.optimizer_config,
                                          params.runtime)
        kwargs = dict(multi_task=task, multi_task_model=model,
                      optimizer=optimizer)
        if params.trainer.trainer_type == 'interleaving':
            sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
                                                    task.task_weights)
            kwargs.update(dict(task_sampler=sampler))
        trainer = TRAINERS[params.trainer.trainer_type](
            **kwargs) if is_training else None
        if is_eval:
            eval_steps = task.task_eval_steps
            evaluator = evaluator_lib.MultiTaskEvaluator(
                eval_tasks=task.tasks.values(),
                model=model,
                eval_steps=eval_steps,
                global_step=trainer.global_step if is_training else None,
                checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
                    params, model_dir))
        else:
            evaluator = None

    if trainer:
        checkpoint = trainer.checkpoint
        global_step = trainer.global_step
    else:
        checkpoint = evaluator.checkpoint
        global_step = evaluator.global_step

    # TODO(hongkuny,haozhangthu): Revisit initialization method.
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=params.trainer.max_to_keep,
        step_counter=global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=model.initialize)

    controller = orbit.Controller(
        strategy=distribution_strategy,
        trainer=trainer,
        evaluator=evaluator,
        global_step=global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(model_dir, 'train'),
        eval_summary_dir=os.path.join(model_dir, 'validation'),
        summary_interval=params.trainer.summary_interval)

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=params.trainer.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=params.trainer.train_steps,
                eval_steps=params.trainer.validation_steps,
                eval_interval=params.trainer.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=params.trainer.validation_steps)
        elif mode == 'continuous_eval':

            def timeout_fn():
                # Stop waiting for new checkpoints once training has finished.
                if evaluator.global_step.numpy() >= params.trainer.train_steps:
                    return True
                return False

            controller.evaluate_continuously(
                steps=params.trainer.validation_steps,
                timeout=params.trainer.continuous_eval_timeout,
                timeout_fn=timeout_fn)

    return model
def run_experiment_with_multitask_eval(
        *,
        distribution_strategy: tf.distribute.Strategy,
        train_task: base_task.Task,
        eval_tasks: List[base_task.Task],
        mode: str,
        params: configs.MultiEvalExperimentConfig,
        model_dir: str,
        run_post_eval: bool = False,
        save_summary: bool = True,
        trainer: Optional[core_lib.Trainer] = None
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Runs train/eval configured by the experiment params.

    Args:
        distribution_strategy: A distribution distribution_strategy.
        train_task: A base_task.Task instance.
        eval_tasks: A list of evaluation tasks.
        mode: A 'str', specifying the mode. Can be 'train', 'eval',
            'train_and_eval' or 'continuous_eval'.
        params: MultiEvalExperimentConfig instance.
        model_dir: A 'str', a path to store model checkpoints and summaries.
        run_post_eval: Whether to run post eval once after training, metrics
            logs are returned.
        save_summary: Whether to save train and validation summary.
        trainer: the core_lib.Trainer instance. It should be created within the
            strategy.scope(). If not provided, an instance will be created by
            default if `mode` contains 'train'.

    Returns:
        A 2-tuple `(model, eval_logs)`: the `tf.keras.Model` instance, and the
        post-eval metrics logs when `run_post_eval` is True, otherwise `{}`.

    Raises:
        NotImplementedError: if `mode` is not one of the supported modes.
    """
    # Reject unsupported modes before building anything; otherwise a mode with
    # neither 'train' nor 'eval' fails on `evaluator.checkpoint` with an
    # AttributeError instead of this NotImplementedError.
    if mode not in ('train', 'train_and_eval', 'eval', 'continuous_eval'):
        raise NotImplementedError('The mode is not implemented: %s' % mode)

    is_training = 'train' in mode
    is_eval = 'eval' in mode
    # The evaluator is also required for the optional post-training eval,
    # even when `mode` itself does not evaluate (e.g. mode == 'train').
    needs_evaluator = is_eval or run_post_eval
    with distribution_strategy.scope():
        if is_training:
            trainer = trainer or core_lib.Trainer(
                config=params,
                task=train_task,
                model=train_task.build_model(),
                optimizer=train_task.create_optimizer(
                    params.trainer.optimizer_config, params.runtime),
                train=True,
                evaluate=False)
        else:
            trainer = None
        model = trainer.model if trainer else train_task.build_model()

        if needs_evaluator:
            eval_steps = {
                task_routine.task_config.name: task_routine.eval_steps
                for task_routine in params.eval_tasks
            }
            evaluator = evaluator_lib.MultiTaskEvaluator(
                eval_tasks=eval_tasks,
                model=model,
                global_step=trainer.global_step if is_training else None,
                eval_steps=eval_steps,
                checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
                    params, model_dir))
        else:
            evaluator = None

    if trainer:
        checkpoint = trainer.checkpoint
        global_step = trainer.global_step
    else:
        checkpoint = evaluator.checkpoint
        global_step = evaluator.global_step

    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=params.trainer.max_to_keep,
        step_counter=global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize if trainer else None)

    controller = orbit.Controller(
        strategy=distribution_strategy,
        trainer=trainer,
        # Only hand the evaluator to the controller for eval modes; a
        # post-eval-only evaluator is used directly below.
        evaluator=evaluator if is_eval else None,
        global_step=global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
        eval_summary_dir=os.path.join(model_dir, 'validation') if
        (save_summary) else None,
        summary_interval=params.trainer.summary_interval if
        (save_summary) else None)

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if mode == 'train':
            controller.train(steps=params.trainer.train_steps)
        elif mode == 'train_and_eval':
            controller.train_and_evaluate(
                train_steps=params.trainer.train_steps,
                eval_steps=params.trainer.validation_steps,
                eval_interval=params.trainer.validation_interval)
        elif mode == 'eval':
            controller.evaluate(steps=params.trainer.validation_steps)
        elif mode == 'continuous_eval':

            def timeout_fn():
                # Stop waiting for new checkpoints once training has finished.
                if evaluator.global_step.numpy() >= params.trainer.train_steps:
                    return True
                return False

            controller.evaluate_continuously(
                steps=params.trainer.validation_steps,
                timeout=params.trainer.continuous_eval_timeout,
                timeout_fn=timeout_fn)

        if run_post_eval:
            return model, evaluator.evaluate(
                tf.convert_to_tensor(params.trainer.validation_steps))
        return model, {}
def run_experiment(
        distribution_strategy: tf.distribute.Strategy,
        task: base_task.Task,
        mode: str,
        params: config_definitions.ExperimentConfig,
        model_dir: str,
        run_post_eval: bool = False,
        save_summary: bool = True,
        trainer: Optional[base_trainer.Trainer] = None
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Runs train/eval configured by the experiment params.

    Args:
        distribution_strategy: strategy under which the trainer is built and
            the train/eval loops run.
        task: A Task instance.
        mode: one of 'train', 'eval', 'train_and_eval' or 'continuous_eval'.
        params: ExperimentConfig instance.
        model_dir: path where checkpoints and summaries are stored.
        run_post_eval: if True, evaluate once more after `mode` finishes and
            return the resulting metrics logs.
        save_summary: whether to write train and validation summaries.
        trainer: optional pre-built base_trainer.Trainer (create it inside
            `distribution_strategy.scope()`); built from `params` if falsy.

    Returns:
        A 2-tuple `(model, eval_logs)`: the trainer's `tf.keras.Model`, and the
        post-eval metrics logs if `run_post_eval` is True, otherwise `{}`.

    Raises:
        NotImplementedError: if `mode` is not a supported mode.
    """
    with distribution_strategy.scope():
        trainer = trainer or train_utils.create_trainer(
            params,
            task,
            train='train' in mode,
            evaluate=('eval' in mode) or run_post_eval,
            checkpoint_exporter=maybe_create_best_ckpt_exporter(
                params, model_dir))

    checkpoint_manager = None
    if trainer.checkpoint:
        checkpoint_manager = tf.train.CheckpointManager(
            trainer.checkpoint,
            directory=model_dir,
            max_to_keep=params.trainer.max_to_keep,
            step_counter=trainer.global_step,
            checkpoint_interval=params.trainer.checkpoint_interval,
            init_fn=trainer.initialize)
        # Adds recovery handling.
        trainer.add_recovery(params.trainer,
                             checkpoint_manager=checkpoint_manager)

    # Create logs matching tensorboard log parser format
    # see tensorboard_for_parser.md
    hparams = {
        "batch_size": params.task.train_data.global_batch_size,
        "precision": params.runtime.mixed_precision_dtype
    }

    if save_summary:
        summary_dir = model_dir
        eval_summary_dir = os.path.join(
            model_dir, params.trainer.validation_summary_subdir)
        summary_interval = params.trainer.summary_interval
        summary_hparams = hparams
    else:
        summary_dir = eval_summary_dir = None
        summary_interval = summary_hparams = None

    controller = orbit.Controller(
        strategy=distribution_strategy,
        trainer=trainer if 'train' in mode else None,
        evaluator=trainer,
        global_step=trainer.global_step,
        steps_per_loop=params.trainer.steps_per_loop,
        checkpoint_manager=checkpoint_manager,
        summary_dir=summary_dir,
        eval_summary_dir=eval_summary_dir,
        summary_interval=summary_interval,
        hparams=summary_hparams,
        train_actions=None,
        eval_actions=actions.get_eval_actions(params, trainer, model_dir))

    logging.info('Starts to execute mode: %s', mode)
    with distribution_strategy.scope():
        if params.runtime.dump_config:
            from TensorFlow.common.debug import dump_callback
            run_context = dump_callback(params.runtime.dump_config)
        else:
            run_context = contextlib.ExitStack()

        with run_context:
            if mode == 'train':
                controller.train(steps=params.trainer.train_steps)
            elif mode == 'train_and_eval':
                controller.train_and_evaluate(
                    train_steps=params.trainer.train_steps,
                    eval_steps=params.trainer.validation_steps,
                    eval_interval=params.trainer.validation_interval)
            elif mode == 'eval':
                controller.evaluate(steps=params.trainer.validation_steps)
            elif mode == 'continuous_eval':

                def _training_finished():
                    # Stop waiting for checkpoints once training is complete.
                    done = (trainer.global_step.numpy() >=
                            params.trainer.train_steps)
                    return bool(done)

                controller.evaluate_continuously(
                    steps=params.trainer.validation_steps,
                    timeout=params.trainer.continuous_eval_timeout,
                    timeout_fn=_training_finished)
            else:
                raise NotImplementedError(
                    'The mode is not implemented: %s' % mode)

    num_params = train_utils.try_count_params(trainer.model)
    if num_params is not None:
        logging.info('Number of trainable params in model: %f Millions.',
                     num_params / 10.**6)

    if not run_post_eval:
        return trainer.model, {}

    with distribution_strategy.scope():
        return trainer.model, trainer.evaluate(
            tf.convert_to_tensor(params.trainer.validation_steps))
def run(
        train_dataset: tf.data.Dataset,
        eval_datasets: Dict[str, tf.data.Dataset],
        steps_per_eval: Dict[str, int],
        params: utils.ModelParameters,
        model_dir: str,
        gp_layer_kwargs: Dict[str, Any],
        strategy: tf.distribute.Strategy,
        summary_writer: tf.summary.SummaryWriter,
        loss_type: str,
        use_spec_norm: bool,
        spec_norm_multiplier: float,
        use_spec_norm_mp: bool,
        spec_norm_multiplier_mp: float):
    """Trains and evaluates the model.

    Each epoch runs `params.steps_per_epoch` distributed training steps, then
    evaluates on every dataset in `eval_datasets`. All metrics are logged to
    `summary_writer`, recorded in a history and reset, and the model is saved
    to `model_dir/model_<epoch>`. The history is finally written to
    `metrics_history.json` in `model_dir`.

    Args:
        train_dataset: tf dataset that provides training data.
        eval_datasets: A dictionary of tf datasets that provides data for model
            evaluation.
        steps_per_eval: A dictionary of steps needed for each evaluation
            dataset.
        params: ModelParameters object containing MPNN model parameters.
        model_dir: Directory for files generated during training and
            evaluation.
        gp_layer_kwargs: A dictionary of parameters used for GP layer.
        strategy: tf Distributed training strategy object.
        summary_writer: tf summary writer to log training and evaluation
            metrics.
        loss_type: str, loss type to use during training. `'focal'` selects
            sigmoid focal loss; any other value uses cross-entropy.
        use_spec_norm: Whether to use Spectral normalization for the dense
            layer.
        spec_norm_multiplier: Multiplier used to control the magnitude of
            eigenvalue of the dense layer weight matrix.
        use_spec_norm_mp: Whether to use Spectral normalization for the MP
            layer.
        spec_norm_multiplier_mp: Multiplier used to control the magnitude of
            eigenvalue of the MP layer weight matrix.
    """

    def _metrics_for(prefix):
        # Evaluation metrics for one dataset; creation order is significant
        # because it fixes the order of summaries and history entries.
        return {
            f'{prefix}/accuracy': tf.keras.metrics.CategoricalAccuracy(),
            f'{prefix}/roc_auc': tf.keras.metrics.AUC(),
            f'{prefix}/negative_log_likelihood': tf.keras.metrics.Mean(),
            f'{prefix}/ece': rm.metrics.ExpectedCalibrationError(
                num_bins=5 if prefix == 'test2' else 10),
            f'{prefix}/brier': rm.metrics.Brier(),
        }

    with strategy.scope():
        spec = train_dataset.element_spec[0]
        model = ub.models.mpnn(
            nodes_shape=spec['atoms'].shape[1:],
            edges_shape=spec['pairs'].shape[1:],
            num_heads=params.num_heads,
            num_layers=params.num_layers,
            message_layer_size=params.message_layer_size,
            readout_layer_size=params.readout_layer_size,
            use_gp_layer=params.use_gp_layer,
            gp_layer_kwargs=gp_layer_kwargs,
            use_spec_norm=use_spec_norm,
            spec_norm_multiplier=spec_norm_multiplier,
            use_spec_norm_mp=use_spec_norm_mp,
            spec_norm_multiplier_mp=spec_norm_multiplier_mp)
        optimizer = tf.keras.optimizers.RMSprop(
            learning_rate=params.learning_rate)

        metrics = {
            'train/negative_log_likelihood': tf.keras.metrics.Mean(),
            'train/accuracy': tf.keras.metrics.CategoricalAccuracy(),
            'train/loss': tf.keras.metrics.Mean(),
            'train/roc_auc': tf.keras.metrics.AUC(),
        }
        for name in eval_datasets:
            metrics.update(_metrics_for(name))

    def replica_train(inputs):
        """Per-Replica StepFn: one optimisation step on this replica's batch."""
        if len(inputs) == 3:
            features, labels, sample_weights = inputs
        else:
            features, labels = inputs
            sample_weights = 1

        with tf.GradientTape() as tape:
            probs = model(features, training=True)
            nll = tf.reduce_mean(
                tf.keras.losses.categorical_crossentropy(labels, probs) *
                sample_weights)
            l2_loss = sum(model.losses)
            if loss_type != 'focal':
                loss = nll + l2_loss
            else:
                focal_fn = tfa_losses.SigmoidFocalCrossEntropy()
                focal = tf.reduce_mean(focal_fn(labels, probs) * sample_weights)
                loss = focal + l2_loss
            # The strategy sums gradients across replicas, so divide the loss
            # by the replica count. See
            # https://www.tensorflow.org/tutorials/distribute/custom_training#define_the_loss_function
            scaled_loss = loss / strategy.num_replicas_in_sync

        variables = model.trainable_variables
        optimizer.apply_gradients(zip(tape.gradient(scaled_loss, variables),
                                      variables))

        metrics['train/loss'].update_state(loss)
        metrics['train/negative_log_likelihood'].update_state(nll)
        metrics['train/accuracy'].update_state(labels, probs)
        metrics['train/roc_auc'].update_state(labels[:, 1], probs[:, 1])

    def replica_eval(inputs, dataset_name):
        """Per-Replica StepFn: update `dataset_name`'s metrics on one batch."""
        if len(inputs) == 3:
            features, labels, _ = inputs
        else:
            features, labels = inputs

        probs = model(features, training=False)
        nll = tf.reduce_mean(
            tf.keras.losses.categorical_crossentropy(labels, probs))
        positive_labels = labels[:, 1]

        metrics[f'{dataset_name}/negative_log_likelihood'].update_state(nll)
        metrics[f'{dataset_name}/accuracy'].update_state(labels, probs)
        metrics[f'{dataset_name}/roc_auc'].update_state(positive_labels,
                                                        probs[:, 1])
        metrics[f'{dataset_name}/ece'].add_batch(probs[:, 1],
                                                 label=positive_labels)
        metrics[f'{dataset_name}/brier'].add_batch(probs,
                                                   label=positive_labels)

    @tf.function
    def distributed_train_step(iterator):
        """Training StepFn: one epoch of distributed steps."""
        for _ in tf.range(tf.cast(params.steps_per_epoch, tf.int32)):
            strategy.run(replica_train, args=(next(iterator),))

    @tf.function
    def distributed_eval_step(iterator, dataset_name, num_steps):
        """Evaluation StepFn."""
        for _ in tf.range(tf.cast(num_steps, tf.int32)):
            strategy.run(replica_eval, args=(next(iterator), dataset_name))

    # Makes datasets into distributed version.
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    eval_datasets = {
        name: strategy.experimental_distribute_dataset(ds)
        for name, ds in eval_datasets.items()
    }
    logging.info('Number of replicas in sync: %s',
                 strategy.num_replicas_in_sync)

    train_iterator = iter(train_dataset)
    start_time = time.time()
    metrics_history = collections.defaultdict(list)
    total_steps = params.steps_per_epoch * params.num_epochs

    for epoch in range(params.num_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        distributed_train_step(train_iterator)

        epoch_number = epoch + 1
        done_steps = epoch_number * params.steps_per_epoch
        elapsed = time.time() - start_time
        rate = float(done_steps) / elapsed
        remaining = (total_steps - done_steps) / rate
        logging.info(
            '{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
            'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                done_steps / total_steps, epoch_number, params.num_epochs,
                rate, remaining / 60, elapsed / 60))

        # Start evaluation.
        logging.info('Starting to run eval at epoch: %s', epoch)
        for name, ds in eval_datasets.items():
            distributed_eval_step(iter(ds), name, steps_per_eval[name])

        metrics_history['epoch'].append(epoch_number)
        with summary_writer.as_default():
            for name, metric in metrics.items():
                value = utils.get_metric_result_value(metric)
                tf.summary.scalar(name, value, step=epoch_number)
                metrics_history[name].append(str(value))

        for metric in metrics.values():
            metric.reset_states()

        model.save(os.path.join(model_dir, f'model_{epoch_number}'),
                   overwrite=True)

    utils.write_params(metrics_history,
                       os.path.join(model_dir, 'metrics_history.json'))
def predict_image(strategy: tf.distribute.Strategy, gen: Model,
                  input_z: tf.Tensor):
    """Run the generator `gen` on `input_z` through `strategy.run`.

    Returns:
        Whatever `strategy.run` returns for the zero-argument replica function
        that calls `gen(input_z)`.
    """

    def _generate():
        return gen(input_z)

    return strategy.run(_generate)