def resume_from_checkpoint(model: tf.keras.Model,
                           model_dir: str,
                           params=None,
                           flags=None,
                           train_steps: int = 0) -> int:
  """Resumes from the latest checkpoint, if possible.

  Loads the model weights and optimizer settings from a checkpoint.
  This function should be used in case of preemption recovery.

  Args:
    model: The model whose weights should be restored.
    model_dir: The directory where model weights were saved.
    params: Experiment config. Only `params.model.learning_rate` is read, and
      only when `flags.freeze_lr` is set; may be None otherwise.
    flags: Flags object carrying `init_chkpt`, `freeze_lr`, and `SWITCH_FROM`.
      May be None, in which case only the latest checkpoint in `model_dir` is
      considered.
    train_steps: The number of steps to train per epoch. Must be positive when
      a checkpoint is actually restored (used to derive the epoch count).

  Returns:
    The epoch of the latest checkpoint, or 0 if not restoring.

  Raises:
    ValueError: If resuming from an init checkpoint and the epoch derived from
      the optimizer step counter disagrees with the epoch encoded in the
      checkpoint path.
  """
  logging.info('Load from checkpoint is enabled.')
  latest_checkpoint = tf.train.latest_checkpoint(model_dir)
  logging.info('latest_checkpoint: %s', latest_checkpoint)
  init_chkpt = getattr(flags, 'init_chkpt', None) if flags is not None else None
  if latest_checkpoint:
    logging.info(
        'Checkpoint file %s found and restoring from '
        'checkpoint', latest_checkpoint)
    model.load_weights(latest_checkpoint)
  elif init_chkpt:
    logging.info('Load init checkpoint from: %s', init_chkpt)
    model.load_weights(init_chkpt)
    # The init checkpoint path is assumed to end in a 4-digit epoch number;
    # fast-forward the optimizer step counter to the start of that epoch.
    # TODO confirm this naming convention holds for all init checkpoints.
    model.optimizer.iterations.assign(
        int(train_steps * int(init_chkpt[-4:])))
  else:
    logging.info('No checkpoint detected.')
    return 0

  if flags is not None and getattr(flags, 'freeze_lr', False):
    # Rebuild the LR schedule against the original batch size so the restored
    # run keeps the pre-switch learning-rate trajectory.
    learning_rate = optimizer_factory.build_learning_rate(
        params=params.model.learning_rate,
        batch_size=int(flags.SWITCH_FROM),
        train_steps=train_steps)
    model.optimizer.lr = learning_rate

  initial_epoch = model.optimizer.iterations // train_steps
  # Validate the epoch count when resuming from an init checkpoint. Use an
  # explicit exception (not `assert`, which is stripped under -O).
  if not latest_checkpoint and init_chkpt:
    expected_epoch = int(init_chkpt[-4:])
    if initial_epoch != expected_epoch:
      raise ValueError(
          'Epoch mismatch when loading init checkpoint: '
          'iterations=%s, train_steps=%s, initial_epoch=%s, expected=%s' %
          (model.optimizer.iterations, train_steps, initial_epoch,
           expected_epoch))
  logging.info('Completed loading from checkpoint.')
  logging.info('Resuming from epoch %d', initial_epoch)
  return int(initial_epoch)
def test_learning_rate_without_decay_or_warmups(self):
  """Builds an exponential LR schedule with decay/warmup epochs disabled."""
  lr_config = base_configs.LearningRateConfig(
      name='exponential',
      initial_lr=0.01,
      decay_rate=0.01,
      decay_epochs=None,
      warmup_epochs=None,
      scale_by_batch_size=0.01,
      examples_per_epoch=1,
      boundaries=[0],
      multipliers=[0, 1])
  schedule = optimizer_factory.build_learning_rate(
      params=lr_config, batch_size=1, train_steps=1)
  self.assertTrue(
      issubclass(
          type(schedule),
          tf.keras.optimizers.schedules.LearningRateSchedule))
def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
  """Basic smoke test for syntax."""
  lr_config = base_configs.LearningRateConfig(
      name=lr_decay_type,
      initial_lr=0.01,
      decay_rate=0.01,
      decay_epochs=1,
      warmup_epochs=1,
      scale_by_batch_size=0.01,
      examples_per_epoch=1,
      boundaries=[0],
      multipliers=[0, 1])
  schedule = optimizer_factory.build_learning_rate(
      params=lr_config, batch_size=1, train_steps=1)
  self.assertTrue(
      issubclass(
          type(schedule),
          tf.keras.optimizers.schedules.LearningRateSchedule))
def train_and_eval(
    params: base_configs.ExperimentConfig,
    strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
  """Runs the train and eval path using compile/fit.

  Builds the distribution strategy, datasets, model, optimizer and callbacks
  from `params`, then trains via `model.fit` and (optionally) evaluates via
  `model.evaluate`.

  Args:
    params: The experiment config specifying model, data, optimizer, training
      and evaluation settings.
    strategy_override: If truthy, used in place of the distribution strategy
      derived from `params.runtime`.

  Returns:
    A mapping of training/evaluation statistics built by `common.build_stats`.
  """
  logging.info('Running train and eval.')

  # Note: for TPUs, strategy and scope should be created before the dataset
  strategy = strategy_override or distribution_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  strategy_scope = distribution_utils.get_strategy_scope(strategy)

  logging.info('Detected %d devices.',
               strategy.num_replicas_in_sync if strategy else 1)

  # A positive label-smoothing value implies one-hot labels and the
  # categorical (rather than sparse-categorical) cross-entropy loss below.
  label_smoothing = params.model.loss.label_smoothing
  one_hot = label_smoothing and label_smoothing > 0

  builders = _get_dataset_builders(params, strategy, one_hot)
  datasets = [
      builder.build(strategy) if builder else None for builder in builders
  ]

  # Unpack datasets and builders based on train/val/test splits
  train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
  train_dataset, validation_dataset = datasets

  train_epochs = params.train.epochs
  # Explicit step counts in the config take precedence over dataset-derived
  # values.
  train_steps = params.train.steps or train_builder.num_steps
  validation_steps = params.evaluation.steps or validation_builder.num_steps

  initialize(params, train_builder)

  logging.info('Global batch size: %d', train_builder.global_batch_size)

  # Model, optimizer and compile must happen under the strategy scope so the
  # variables are created on the right devices.
  with strategy_scope:
    model_params = params.model.model_params.as_dict()
    model = get_models()[params.model.name](**model_params)
    learning_rate = optimizer_factory.build_learning_rate(
        params=params.model.learning_rate,
        batch_size=train_builder.global_batch_size,
        train_steps=train_steps)
    optimizer = optimizer_factory.build_optimizer(
        optimizer_name=params.model.optimizer.name,
        base_learning_rate=learning_rate,
        params=params.model.optimizer.as_dict())

    metrics_map = _get_metrics(one_hot)
    metrics = [metrics_map[metric] for metric in params.train.metrics]

    if one_hot:
      loss_obj = tf.keras.losses.CategoricalCrossentropy(
          label_smoothing=params.model.loss.label_smoothing)
    else:
      loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
    model.compile(
        optimizer=optimizer,
        loss=loss_obj,
        metrics=metrics,
        experimental_steps_per_execution=params.train.steps_per_loop)

    initial_epoch = 0
    if params.train.resume_checkpoint:
      # NOTE(review): the `resume_from_checkpoint` defined earlier in this
      # file also takes `params` and `flags` arguments that are not passed
      # here — confirm which definition this call is meant to resolve to.
      initial_epoch = resume_from_checkpoint(model=model,
                                             model_dir=params.model_dir,
                                             train_steps=train_steps)

    callbacks = custom_callbacks.get_callbacks(
        model_checkpoint=params.train.callbacks.
        enable_checkpoint_and_export,
        include_tensorboard=params.train.callbacks.enable_tensorboard,
        time_history=params.train.callbacks.enable_time_history,
        track_lr=params.train.tensorboard.track_lr,
        write_model_weights=params.train.tensorboard.write_model_weights,
        # Start step is derived from the resumed epoch so TensorBoard curves
        # continue where the previous run left off.
        initial_step=initial_epoch * train_steps,
        batch_size=train_builder.global_batch_size,
        log_steps=params.train.time_history.log_steps,
        model_dir=params.model_dir)

  # Persist the effective config next to the model artifacts.
  serialize_config(params=params, model_dir=params.model_dir)

  if params.evaluation.skip_eval:
    validation_kwargs = {}
  else:
    validation_kwargs = {
        'validation_data': validation_dataset,
        'validation_steps': validation_steps,
        'validation_freq': params.evaluation.epochs_between_evals,
    }

  history = model.fit(train_dataset,
                      epochs=train_epochs,
                      steps_per_epoch=train_steps,
                      initial_epoch=initial_epoch,
                      callbacks=callbacks,
                      verbose=2,
                      **validation_kwargs)

  validation_output = None
  if not params.evaluation.skip_eval:
    validation_output = model.evaluate(validation_dataset,
                                       steps=validation_steps,
                                       verbose=2)

  # TODO(dankondratyuk): eval and save final test accuracy
  stats = common.build_stats(history, validation_output, callbacks)
  return stats