def test_batch(self, inputs):
    """Run a forward pass on one batch of inputs with ``training=False``.

    Args:
        inputs: Batch fed to ``self.model`` after being cast to
            ``self.dtype``.

    Returns:
        The model output, cast to ``tf.float32``.
    """
    net = self.model
    features = cast(inputs, self.dtype)
    outputs = net(features, training=False)
    # Hand back float32 regardless of the compute dtype used by the model.
    return cast(outputs, tf.float32)
def train_batch(self, batch):
    """Run one training step that updates two groups of variables.

    The forward pass and loss are recorded under a ``tf.GradientTape``.
    When ``self.add_arch_loss`` is truthy, the mean of
    ``model.arch_loss()`` is added to the reduced loss. Gradients for all
    trainable variables are then split using the two slices returned by
    ``model.param_splits()``; each part is applied with its own optimizer.

    Args:
        batch: An ``(inputs, target)`` pair.
    """
    net = self.model
    arch_opt, model_opt = self.optimizers
    inputs, target = batch

    with tf.GradientTape() as tape:
        features = cast(inputs, self.dtype)
        # Loss is always computed on float32 logits.
        logits = cast(net(features, training=True), tf.float32)
        per_example_loss = self.criterion(target, logits)
        loss = self.reduce_loss(per_example_loss)
        if self.add_arch_loss:
            loss = loss + tf.reduce_mean(net.arch_loss())

    weights = net.trainable_variables
    gradients = tape.gradient(loss, weights)
    model_part, arch_part = net.param_splits()

    # Only the model-parameter update is given `self.grad_clip_norm`.
    self.apply_gradients(
        model_opt,
        gradients[model_part],
        weights[model_part],
        self.grad_clip_norm,
    )
    self.apply_gradients(arch_opt, gradients[arch_part], weights[arch_part])

    self.update_metrics(self.train_metrics, target, logits, per_example_loss)
def simple_eval_batch(self, batch):
    """Run a forward pass with ``training=False`` and return the outputs.

    No metrics are updated here; the caller receives the raw pair.

    Args:
        batch: An ``(inputs, target)`` pair.

    Returns:
        A ``(target, preds)`` tuple, where ``preds`` is the model output
        cast to ``tf.float32``.
    """
    inputs, target = batch
    net = self.model
    features = cast(inputs, self.dtype)
    outputs = net(features, training=False)
    return target, cast(outputs, tf.float32)
def eval_batch(self, batch):
    """Run a forward pass with ``training=False`` and update eval metrics.

    Args:
        batch: An ``(inputs, target)`` pair.
    """
    inputs, target = batch
    net = self.model
    features = cast(inputs, self.dtype)
    # Metrics receive float32 predictions.
    outputs = cast(net(features, training=False), tf.float32)
    self.update_metrics(self.eval_metrics, target, outputs)
def train_batch(self, batch):
    """Run one training step with the first optimizer in ``self.optimizers``.

    If ``self.batch_transform`` is set, it is applied to
    ``(inputs, target)`` before the forward pass. When ``self.dtype`` is
    ``tf.float16``, the loss is passed through
    ``optimizer.get_scaled_loss`` before being handed to ``self.minimize``.

    Args:
        batch: An ``(inputs, target)`` pair.
    """
    net = self.model
    opt = self.optimizers[0]
    inputs, target = batch

    if self.batch_transform is not None:
        inputs, target = self.batch_transform(inputs, target)

    with tf.GradientTape() as tape:
        features = cast(inputs, self.dtype)
        # Loss is always computed on float32 predictions.
        preds = cast(net(features, training=True), tf.float32)
        per_example_loss = self.criterion(target, preds)
        loss = self.reduce_loss(per_example_loss)
        if self.dtype == tf.float16:
            # Scaling stays under the tape so it is part of the recorded
            # computation.
            loss = opt.get_scaled_loss(loss)

    self.minimize(tape, opt, loss, net.trainable_variables)
    self.update_metrics(self.train_metrics, target, preds, per_example_loss)