def _replicated_step(inputs):
  """Replicated training step."""
  inputs, labels = inputs
  with tf.GradientTape() as tape:
    model_outputs = model(inputs, training=True)
    loss = loss_fn(labels, model_outputs)
    # Raw loss is used for reporting in metrics/logs.
    raw_loss = loss
    if scale_loss:
      # Scale the loss down so gradients are invariant to the number of
      # replicas.
      loss = loss / strategy.num_replicas_in_sync

  if explicit_allreduce:
    grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                 training_vars,
                                                 pre_allreduce_callbacks,
                                                 post_allreduce_callbacks)
  else:
    if isinstance(optimizer,
                  tf.keras.mixed_precision.experimental.LossScaleOptimizer):
      with tape:
        scaled_loss = optimizer.get_scaled_loss(loss)
      scaled_grads = tape.gradient(scaled_loss, training_vars)
      grads = optimizer.get_unscaled_gradients(scaled_grads)
    else:
      grads = tape.gradient(loss, training_vars)
    optimizer.apply_gradients(zip(grads, training_vars))

  # For reporting, the metric takes the mean of losses.
  train_loss_metric.update_state(raw_loss)
  for metric in train_metrics:
    metric.update_state(labels, model_outputs)
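# A minimal sketch (not taken from the source) of how a per-replica step like
# `_replicated_step` above is typically dispatched under a
# tf.distribute.Strategy. The `strategy`, `train_iterator`, and `steps` names
# are assumptions for illustration; the original training loop may differ.
import tensorflow as tf

@tf.function
def train_steps(train_iterator, steps):
  """Runs `steps` replicated training steps drawn from `train_iterator`."""
  for _ in tf.range(steps):
    # `strategy.run` executes the step on every replica; losses and metrics
    # are updated per replica and aggregated by the metric objects.
    strategy.run(_replicated_step, args=(next(train_iterator),))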
def step_fn(inputs):
  """Function to run on the device."""
  images, labels = inputs
  with tf.GradientTape() as tape:
    logits = self.model(images, training=True)
    prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits)
    loss = tf.reduce_sum(prediction_loss) * (1.0 / self.flags_obj.batch_size)
    num_replicas = self.strategy.num_replicas_in_sync

    if self.flags_obj.single_l2_loss_op:
      l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
          tf.nn.l2_loss(v)
          for v in self.model.trainable_variables
          if 'bn' not in v.name
      ])
      loss += (l2_loss / num_replicas)
    else:
      loss += (tf.reduce_sum(self.model.losses) / num_replicas)

  grad_utils.minimize_using_explicit_allreduce(
      tape, self.optimizer, loss, self.model.trainable_variables)
  self.train_loss.update_state(loss)
  self.train_accuracy.update_state(labels, logits)
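# A minimal sketch (an assumption, not the Model Garden source) of the
# explicit all-reduce pattern that grad_utils.minimize_using_explicit_allreduce
# relies on: gradients are computed per replica, summed across replicas inside
# the replica context, and then applied without the optimizer's own implicit
# aggregation.
import tensorflow as tf

def _minimize_with_explicit_allreduce_sketch(tape, optimizer, loss, variables):
  """Illustrative only; the real helper also supports callbacks and filtering."""
  grads = tape.gradient(loss, variables)
  replica_ctx = tf.distribute.get_replica_context()
  # Sum gradients over all replicas explicitly (assumes every variable
  # received a non-None gradient).
  reduced = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, grads)
  # Gradients are already aggregated, so skip the optimizer's cross-replica
  # aggregation.
  optimizer.apply_gradients(
      zip(reduced, variables), experimental_aggregate_gradients=False)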
def step_fn(inputs):
  """Function to run on the device."""
  images, labels = inputs
  if self.one_hot:
    labels = tf.cast(labels, tf.int32)
    labels = tf.one_hot(labels, 1001)
    labels = tf.squeeze(labels)

  with tf.GradientTape() as tape:
    logits = self.model(images, training=True)
    prediction_loss = self.get_prediction_loss(labels, logits)
    loss = tf.reduce_sum(prediction_loss) * (1.0 / self.flags_obj.batch_size)

    if not self.use_lars_optimizer:
      num_replicas = self.strategy.num_replicas_in_sync

      if self.flags_obj.single_l2_loss_op:
        l2_loss = self.flags_obj.weight_decay * tf.add_n([
            tf.nn.l2_loss(v)
            for v in self.model.trainable_variables
            if ('bn' not in v.name)
        ])
        loss += (l2_loss / num_replicas)
      else:
        loss += (tf.reduce_sum(self.model.losses) / num_replicas)

  if horovod_enabled():
    # Wrap the tape so gradients are averaged across Horovod workers.
    tape = hvd.DistributedGradientTape(tape)
    grads = tape.gradient(loss, self.model.trainable_variables)
    grads_and_vars = zip(grads, self.model.trainable_variables)
    self.optimizer.apply_gradients(
        grads_and_vars, experimental_aggregate_gradients=False)
    # On the first step, broadcast the initial variable values from rank 0 so
    # all workers start from identical weights.
    tf.cond(
        self.global_step == 1,
        lambda: hvd.broadcast_variables(
            self.model.variables + self.optimizer.variables(), root_rank=0),
        lambda: tf.constant(True))
  else:
    grad_utils.minimize_using_explicit_allreduce(
        tape, self.optimizer, loss, self.model.trainable_variables)

  self.train_loss.update_state(loss)
  self.train_accuracy.update_state(labels, logits)
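# The `horovod_enabled()` guard above is a project-specific helper whose
# implementation is not shown in this excerpt. A hypothetical minimal version
# might look like the following, treating Horovod as optional at import time.
try:
  import horovod.tensorflow as hvd
except ImportError:
  hvd = None

def horovod_enabled():
  """Returns True if Horovod is importable and has been initialized."""
  return hvd is not None and hvd.is_initialized()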
def train_step(images, labels, step):
  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = loss_object(labels, predictions)

  if horovod_enabled():
    tape = hvd.DistributedGradientTape(tape)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables),
                              experimental_aggregate_gradients=True)
    tf.cond(
        step == 0,
        lambda: hvd.broadcast_variables(
            model.variables + optimizer.variables(), root_rank=0),
        lambda: tf.constant(True))
  else:
    grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                 model.trainable_variables)

  train_loss(loss)
  train_accuracy(labels, predictions)
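# A minimal sketch (an assumption, not taken from the source) of the Horovod
# setup a `train_step` like the one above relies on: Horovod is initialized
# once per process, each worker is pinned to a single accelerator, and the
# learning rate is commonly scaled by the number of workers. The optimizer and
# learning-rate values here are illustrative.
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

base_learning_rate = 0.001  # illustrative value
optimizer = tf.keras.optimizers.SGD(
    learning_rate=base_learning_rate * hvd.size())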
def step_fn(inputs):
  """Function to run on the device."""
  images, labels = inputs
  if self.one_hot:
    labels = tf.cast(labels, tf.int32)
    labels = tf.one_hot(labels, 1001)
    labels = tf.squeeze(labels)

  with tf.GradientTape() as tape:
    logits = self.model(images, training=True)
    prediction_loss = self.get_prediction_loss(labels, logits)
    loss = tf.reduce_sum(prediction_loss) * (1.0 / self.flags_obj.batch_size)

    if not self.use_lars_optimizer:
      num_replicas = self.strategy.num_replicas_in_sync

      if self.flags_obj.single_l2_loss_op:
        l2_loss = self.flags_obj.weight_decay * tf.add_n([
            tf.nn.l2_loss(v)
            for v in self.model.trainable_variables
            if ('bn' not in v.name)
        ])
        loss += (l2_loss / num_replicas)
      else:
        loss += (tf.reduce_sum(self.model.losses) / num_replicas)

  if horovod_enabled():
    tape = hvd.DistributedGradientTape(tape)
    grads = tape.gradient(loss, self.model.trainable_variables)
    grads_and_vars = zip(grads, self.model.trainable_variables)
    self.optimizer.apply_gradients(
        grads_and_vars, experimental_aggregate_gradients=False)
    tf.cond(
        self.global_step == 1,
        lambda: hvd.broadcast_variables(
            self.model.variables + self.optimizer.variables(), root_rank=0),
        lambda: tf.constant(True))
  else:
    grad_utils.minimize_using_explicit_allreduce(
        tape, self.optimizer, loss, self.model.trainable_variables)

  if self.flags_obj.modeling:
    sess = tf.compat.v1.Session()
    # pbtxt generation
    tf.io.write_graph(sess.graph.as_graph_def(add_shapes=True),
                      self.flags_obj.model_dir, 'graph.pbtxt')
    # meta graph generation
    tf.compat.v1.train.export_meta_graph(
        filename='checkpoint_model.meta',
        meta_info_def=None,
        graph_def=None,
        saver_def=None,
        collection_list=None,
        as_text=False,
        graph=None,
        export_scope=None,
        clear_devices=False,
        clear_extraneous_savers=False,
        strip_default_attrs=False,
        save_debug_info=False)

  if self.train_loss:
    self.train_loss.update_state(loss)
  if self.train_accuracy:
    self.train_accuracy.update_state(labels, logits)
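# `get_prediction_loss` is a method of the surrounding runnable and is not
# shown in this excerpt. A hypothetical version consistent with the one-hot
# handling above might simply dispatch between the two Keras cross-entropy
# losses; the actual implementation may differ (e.g. label smoothing).
def get_prediction_loss(self, labels, logits):
  """Per-example cross-entropy for one-hot or sparse integer labels."""
  if self.one_hot:
    return tf.keras.losses.categorical_crossentropy(labels, logits)
  return tf.keras.losses.sparse_categorical_crossentropy(labels, logits)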