def __init__(self, params):
    """Caches training, checkpoint, summary and platform settings from `params`.

    Args:
      params: nested config object. This constructor reads
        `params.architecture.use_bfloat16`; `params.train.l2_weight_decay`,
        `.optimizer`, `.learning_rate`, `.gradient_clip_norm`,
        `.frozen_variable_prefix`, `.checkpoint.path`, `.checkpoint.prefix`
        and `.iterations_per_loop`; plus `params.enable_summary`,
        `params.model_dir` and `params.use_tpu`.

    Raises:
      TypeError / ValueError: if `params.train.l2_weight_decay` cannot be
        converted with `float()`. NOTE(review): this includes a config that
        leaves it as None.
    """
    self._use_bfloat16 = params.architecture.use_bfloat16
    # Cast explicitly, matching the newer variant of this constructor. YAML 1.1
    # loaders (e.g. PyYAML) resolve exponent literals without a dot, such as
    # `4e-5`, to strings. Converting here makes a bad value fail at
    # construction instead of wherever the coefficient is used later.
    self._l2_weight_decay = float(params.train.l2_weight_decay)

    # Optimization.
    # NOTE(review): OptimizerFactory / learning_rate_generator are defined
    # elsewhere; presumably they return callables used when building the
    # train op. Confirm against their definitions.
    self._optimizer_fn = OptimizerFactory(params.train.optimizer)
    self._learning_rate_fn = learning_rates.learning_rate_generator(
        params.train.learning_rate)
    self._gradient_clip_norm = params.train.gradient_clip_norm
    self._frozen_variable_prefix = params.train.frozen_variable_prefix

    # Checkpoint restoration.
    self._checkpoint = params.train.checkpoint.path
    self._checkpoint_prefix = params.train.checkpoint.prefix

    # Summary.
    self._enable_summary = params.enable_summary
    # Starts empty; presumably filled in by other methods of this class.
    self._summaries = {}
    self._model_dir = params.model_dir
    self._iterations_per_loop = params.train.iterations_per_loop

    # Platform device.
    self._use_tpu = params.use_tpu
def __init__(self, params):
    """Pulls model-building and training settings out of a nested config.

    Args:
      params: config object exposing `train` (with a nested `checkpoint`
        section), `architecture`, `enable_summary`, `model_dir` and
        `use_tpu`. Each field read below is copied onto a private attribute.
    """
    train_cfg = params.train
    self._transpose_input = train_cfg.transpose_input

    arch_cfg = params.architecture
    self._space_to_depth_block_size = arch_cfg.space_to_depth_block_size
    self._use_bfloat16 = arch_cfg.use_bfloat16
    # The decay coefficient is normalized to a Python float up front.
    self._l2_weight_decay = float(train_cfg.l2_weight_decay)

    # --- Optimizer and learning-rate schedule ---
    self._optimizer_fn = OptimizerFactory(train_cfg.optimizer)
    self._learning_rate_fn = learning_rates.learning_rate_generator(
        train_cfg.learning_rate, train_cfg.total_steps)
    self._gradient_clip_norm = train_cfg.gradient_clip_norm
    self._frozen_var_prefix = train_cfg.frozen_variable_prefix
    self._regularization_var_regex = train_cfg.regularization_variable_regex

    # --- Checkpoint to restore from ---
    ckpt_cfg = train_cfg.checkpoint
    self._checkpoint = ckpt_cfg.path
    self._checkpoint_prefix = ckpt_cfg.prefix
    self._skip_variables_regex = ckpt_cfg.skip_variables_regex

    # --- Summaries (both dicts start empty) ---
    self._enable_summary = params.enable_summary
    self._summaries, self._image_summaries = {}, {}
    self._model_dir = params.model_dir
    self._iterations_per_loop = train_cfg.iterations_per_loop

    # --- Device placement ---
    self._use_tpu = params.use_tpu
def __init__(self, params):
    """Configures precision, optimization, checkpoint and summary settings.

    Args:
      params: nested config object. This constructor reads
        `params.architecture.use_bfloat16`; `params.train.optimizer`,
        `.learning_rate`, `.frozen_variable_prefix`,
        `.regularization_variable_regex`, `.l2_weight_decay` and
        `.checkpoint`; plus `params.enable_summary` and `params.model_dir`.

    Raises:
      TypeError / ValueError: if `params.train.l2_weight_decay` cannot be
        converted with `float()`. NOTE(review): this includes a config that
        leaves it as None.
    """
    self._use_bfloat16 = params.architecture.use_bfloat16
    if self._use_bfloat16:
        # NOTE(review): per the TF API docs, `set_policy` installs the Keras
        # dtype policy globally, not per model instance. Confirm this
        # process-wide side effect is intended.
        policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
            'mixed_bfloat16')
        tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    # Optimization.
    self._optimizer_fn = OptimizerFactory(params.train.optimizer)
    # Stored as `_learning_rate` (not `_learning_rate_fn` as in the sibling
    # variants). The name is kept because other methods presumably read it.
    self._learning_rate = learning_rates.learning_rate_generator(
        params.train.learning_rate)
    self._frozen_variable_prefix = params.train.frozen_variable_prefix
    self._regularization_var_regex = params.train.regularization_variable_regex
    # Cast explicitly. YAML 1.1 loaders (e.g. PyYAML) resolve exponent
    # literals without a dot, such as `4e-5`, to strings. Converting here makes
    # a bad value fail at construction instead of at its later use.
    self._l2_weight_decay = float(params.train.l2_weight_decay)

    # Checkpoint restoration: the whole checkpoint sub-config as a plain dict.
    self._checkpoint = params.train.checkpoint.as_dict()

    # Summary.
    self._enable_summary = params.enable_summary
    self._model_dir = params.model_dir