def train_step(self,
               inputs: Tuple[Any, Any],
               model: tf.keras.Model,
               optimizer: tf.keras.optimizers.Optimizer,
               metrics: Optional[List[Any]] = None):
  """Does forward and backward.

  Args:
    inputs: a tuple of (features, labels) input tensors.
    model: the model, forward pass definition.
    optimizer: the optimizer for this training step.
    metrics: a nested structure of metrics objects.

  Returns:
    A dictionary of logs.
  """
  features, labels = inputs
  input_partition_dims = self.task_config.train_input_partition_dims
  if input_partition_dims:
    strategy = tf.distribute.get_strategy()
    features['image'] = strategy.experimental_split_to_logical_devices(
        features['image'], input_partition_dims)

  num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
  with tf.GradientTape() as tape:
    outputs = model(features, training=True)
    # Casting the outputs to float32 is necessary when the mixed_precision
    # policy is mixed_float16 or mixed_bfloat16, so that losses are computed
    # in float32.
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

    # Computes per-replica loss.
    if self._is_multilabel():
      outputs = tf.nest.map_structure(tf.math.sigmoid, outputs)
    else:
      outputs = tf.nest.map_structure(tf.math.softmax, outputs)
    all_losses = self.build_losses(
        model_outputs=outputs, labels=labels, aux_losses=model.losses)
    loss = all_losses[self.loss]
    # Scales the loss, as the default gradients allreduce performs a sum
    # inside the optimizer.
    scaled_loss = loss / num_replicas

    # For a mixed_precision policy, when LossScaleOptimizer is used, the loss
    # is scaled for numerical stability.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      scaled_loss = optimizer.get_scaled_loss(scaled_loss)

  tvars = model.trainable_variables
  grads = tape.gradient(scaled_loss, tvars)
  # Scales back the gradients before apply_gradients when LossScaleOptimizer
  # is used.
  if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
    grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(list(zip(grads, tvars)))

  logs = all_losses
  if metrics:
    self.process_metrics(metrics, labels, outputs)
    logs.update({m.name: m.result() for m in metrics})
  elif model.compiled_metrics:
    self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
    logs.update({m.name: m.result() for m in model.metrics})
  return logs
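# A minimal driver sketch showing how a step like the one above is typically
# launched under a tf.distribute strategy. It assumes `task` exposes the
# train_step above and that `model`, `optimizer`, and `dataset` were created
# under the desired strategy scope; `run_training` and its arguments are
# hypothetical names, not part of the task API. Each replica divides its loss
# by num_replicas_in_sync, so the summed gradient allreduce inside the
# optimizer yields the global mean gradient.
import tensorflow as tf


def run_training(task, model, optimizer, dataset, metrics=None, steps=100):
  strategy = tf.distribute.get_strategy()
  dist_dataset = strategy.experimental_distribute_dataset(dataset)

  @tf.function
  def distributed_step(inputs):
    # Runs one train_step on every replica; the returned logs hold
    # per-replica values.
    return strategy.run(
        task.train_step, args=(inputs, model, optimizer, metrics))

  iterator = iter(dist_dataset)
  logs = None
  for _ in range(steps):
    logs = distributed_step(next(iterator))
  return logs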
def train_step(self,
               inputs: Tuple[Any, Any],
               model: tf.keras.Model,
               optimizer: tf.keras.optimizers.Optimizer,
               metrics: Optional[List[Any]] = None) -> Dict[str, Any]:
  """Does forward and backward.

  Args:
    inputs: a tuple of (images, labels) input tensors.
    model: the model, forward pass definition.
    optimizer: the optimizer for this training step.
    metrics: a nested structure of metrics objects.

  Returns:
    A dictionary of logs.
  """
  images, labels = inputs
  num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
  with tf.GradientTape() as tape:
    outputs = model(
        images,
        image_info=labels['image_info'],
        anchor_boxes=labels['anchor_boxes'],
        gt_boxes=labels['gt_boxes'],
        gt_classes=labels['gt_classes'],
        gt_masks=(labels['gt_masks']
                  if self.task_config.model.include_mask else None),
        training=True)
    # Casting the outputs to float32 is necessary when the mixed_precision
    # policy is mixed_float16 or mixed_bfloat16, so that losses are computed
    # in float32.
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

    # Computes per-replica loss.
    losses = self.build_losses(
        outputs=outputs, labels=labels, aux_losses=model.losses)
    # Scales the loss, as the default gradients allreduce performs a sum
    # inside the optimizer.
    scaled_loss = losses['total_loss'] / num_replicas

    # For a mixed_precision policy, when LossScaleOptimizer is used, the loss
    # is scaled for numerical stability.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      scaled_loss = optimizer.get_scaled_loss(scaled_loss)

  tvars = model.trainable_variables
  grads = tape.gradient(scaled_loss, tvars)
  # Scales back the gradients when LossScaleOptimizer is used.
  if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
    grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(list(zip(grads, tvars)))

  logs = {self.loss: losses['total_loss']}

  if metrics:
    for m in metrics:
      m.update_state(losses[m.name])

  if self.task_config.segmentation_evaluation.report_train_mean_iou:
    segmentation_labels = {
        'masks': labels['gt_segmentation_mask'],
        'valid_masks': labels['gt_segmentation_valid_mask'],
        'image_info': labels['image_info'],
    }
    self.process_metrics(
        metrics=[self.segmentation_train_mean_iou],
        labels=segmentation_labels,
        model_outputs=outputs['segmentation_outputs'])
    logs.update({
        self.segmentation_train_mean_iou.name:
            self.segmentation_train_mean_iou.result()
    })

  return logs
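# A hedged sketch of the `labels` dictionary the step above indexes into,
# reconstructed purely from the key accesses in the code. All shapes are
# illustrative placeholders (batch of 2, one feature level, up to 100
# zero-padded ground-truth boxes), not the task's actual input spec.
import tensorflow as tf

example_labels = {
    'image_info': tf.zeros([2, 4, 2]),                 # image size/scale metadata
    'anchor_boxes': {'5': tf.zeros([2, 20, 20, 12])},  # per-level anchor boxes
    'gt_boxes': tf.zeros([2, 100, 4]),                 # zero-padded boxes
    'gt_classes': tf.zeros([2, 100], tf.int32),        # zero-padded class ids
    'gt_masks': tf.zeros([2, 100, 112, 112]),          # used only if include_mask
    'gt_segmentation_mask': tf.zeros([2, 640, 640, 1], tf.int32),
    'gt_segmentation_valid_mask': tf.ones([2, 640, 640, 1], dtype=tf.bool),
}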
def train_step(self,
               inputs,
               model: tf.keras.Model,
               optimizer: tf.keras.optimizers.Optimizer,
               metrics=None):
  """Does forward and backward.

  With distribution strategies, this method runs on each replica.

  Args:
    inputs: a dictionary of input tensors, or a (features, labels) tuple.
    model: the model, forward pass definition.
    optimizer: the optimizer for this training step.
    metrics: a nested structure of metrics objects.

  Returns:
    A dictionary of logs.
  """
  if isinstance(inputs, tuple) and len(inputs) == 2:
    features, labels = inputs
  else:
    # A single structure serves as both features and labels.
    features, labels = inputs, inputs
  with tf.GradientTape() as tape:
    outputs = model(features, training=True)
    # Computes per-replica loss.
    if model.compiled_loss:
      loss = model.compiled_loss(
          labels, outputs, regularization_losses=model.losses)
      loss += self.build_losses(
          labels=labels, model_outputs=outputs, aux_losses=None)
    else:
      loss = self.build_losses(
          labels=labels, model_outputs=outputs, aux_losses=model.losses)
    # Scales the loss, as the default gradients allreduce performs a sum
    # inside the optimizer.
    scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync

    # For mixed precision, when a LossScaleOptimizer is used, the loss is
    # scaled to avoid numeric underflow.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      scaled_loss = optimizer.get_scaled_loss(scaled_loss)

  tvars = model.trainable_variables
  grads = tape.gradient(scaled_loss, tvars)
  # Scales back the gradients before apply_gradients when LossScaleOptimizer
  # is used.
  if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
    grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(list(zip(grads, tvars)))

  logs = {self.loss: loss}
  if metrics:
    self.process_metrics(metrics, labels, outputs)
  if model.compiled_metrics:
    self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
    logs.update({m.name: m.result() for m in metrics or []})
    logs.update({m.name: m.result() for m in model.metrics})
  return logs
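# A self-contained sketch of the loss-scaling round trip used above: the loss
# is scaled up before differentiation so that small float16 gradients do not
# underflow, and the gradients are unscaled before apply_gradients so the
# effective step size is unchanged. The tiny model and random data are
# illustrative only.
import tensorflow as tf


def loss_scaling_demo():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
      tf.keras.optimizers.SGD(learning_rate=0.1))
  x = tf.random.normal([8, 4])
  y = tf.random.normal([8, 1])
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    # Multiplies the loss by the current (dynamic) loss scale.
    scaled_loss = optimizer.get_scaled_loss(loss)
  grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Divides the gradients by the same loss scale before the update.
  grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss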
def train_step(self,
               inputs: Tuple[Any, Any],
               model: tf.keras.Model,
               optimizer: tf.keras.optimizers.Optimizer,
               metrics: Optional[List[Any]] = None):
  """Does forward and backward.

  Args:
    inputs: A tuple of (features, labels) input tensors.
    model: A tf.keras.Model instance.
    optimizer: The optimizer for this training step.
    metrics: A nested structure of metrics objects.

  Returns:
    A dictionary of logs.
  """
  features, labels = inputs
  is_multilabel = self.task_config.train_data.is_multilabel
  if self.task_config.losses.one_hot and not is_multilabel:
    labels = tf.one_hot(labels, self.task_config.model.num_classes)

  self.image_summary_manager.write_summaries({'input_images': features})

  num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
  with tf.GradientTape() as tape:
    outputs = model(features, training=True)
    # Casting the outputs to float32 is necessary when the mixed_precision
    # policy is mixed_float16 or mixed_bfloat16, so that losses are computed
    # in float32.
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

    # Computes per-replica loss.
    loss = self.build_losses(
        model_outputs=outputs, labels=labels, aux_losses=model.losses)
    # Scales the loss, as the default gradients allreduce performs a sum
    # inside the optimizer.
    scaled_loss = loss / num_replicas

    # For a mixed_precision policy, when LossScaleOptimizer is used, the loss
    # is scaled for numerical stability.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      scaled_loss = optimizer.get_scaled_loss(scaled_loss)

  tvars = model.trainable_variables
  grads = tape.gradient(scaled_loss, tvars)
  # Scales back the gradients before apply_gradients when LossScaleOptimizer
  # is used.
  if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
    grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(list(zip(grads, tvars)))

  logs = {self.loss: loss}
  if metrics:
    self.process_metrics(metrics, labels, outputs)
  elif model.compiled_metrics:
    self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
    logs.update({m.name: m.result() for m in model.metrics})
  return logs
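# Sketch of the label conversion at the top of the step above: when
# losses.one_hot is set and the task is not multilabel, sparse integer labels
# are expanded to one-hot vectors before the loss is computed. num_classes=5
# is an illustrative value.
import tensorflow as tf

sparse_labels = tf.constant([0, 3, 1])
one_hot_labels = tf.one_hot(sparse_labels, depth=5)
# one_hot_labels[1] == [0., 0., 0., 1., 0.]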