Ejemplo n.º 1
0
    def test_batch(self, inputs):
        """Run a single inference pass over ``inputs``.

        Inputs are cast to the trainer's compute dtype (``self.dtype``)
        before the forward pass, and the predictions are cast back to
        float32 before being returned.
        """
        net = self.model
        x = cast(inputs, self.dtype)
        return cast(net(x, training=False), tf.float32)
Ejemplo n.º 2
0
    def train_batch(self, batch):
        """Run one training step updating both model weights and
        architecture parameters with their respective optimizers.

        The batch is a ``(inputs, target)`` pair. Inputs are cast to the
        compute dtype (``self.dtype``); logits are cast back to float32 so
        the loss is computed in full precision. When ``self.add_arch_loss``
        is set, the mean of ``model.arch_loss()`` is added to the task loss
        before differentiation.
        """
        model = self.model
        # NOTE(review): assumes self.optimizers is ordered
        # (arch optimizer, model optimizer) — confirm against the setup code.
        optimizer_arch, optimizer_model = self.optimizers

        inputs, target = batch
        with tf.GradientTape() as tape:
            inputs = cast(inputs, self.dtype)
            logits = model(inputs, training=True)
            # Loss is computed in float32 regardless of compute dtype.
            logits = cast(logits, tf.float32)
            per_example_loss = self.criterion(target, logits)
            loss = self.reduce_loss(per_example_loss)
            if self.add_arch_loss:
                arch_loss = model.arch_loss()
                arch_loss = tf.reduce_mean(arch_loss)
                loss = loss + arch_loss

        variables = model.trainable_variables
        grads = tape.gradient(loss, variables)
        # param_splits() presumably returns two slice objects partitioning
        # trainable_variables into model weights and architecture
        # parameters — TODO confirm the split matches variable ordering.
        model_slice, arch_slice = model.param_splits()
        # Only the model-weight gradients are clipped; architecture
        # gradients are applied without clipping.
        self.apply_gradients(optimizer_model, grads[model_slice],
                             variables[model_slice], self.grad_clip_norm)
        self.apply_gradients(optimizer_arch, grads[arch_slice],
                             variables[arch_slice])
        self.update_metrics(self.train_metrics, target, logits,
                            per_example_loss)
Ejemplo n.º 3
0
    def simple_eval_batch(self, batch):
        """Evaluate one ``(inputs, target)`` batch.

        Returns the untouched targets together with the model's
        predictions cast back to float32.
        """
        net = self.model
        inputs, target = batch
        x = cast(inputs, self.dtype)
        preds = cast(net(x, training=False), tf.float32)
        return target, preds
Ejemplo n.º 4
0
    def eval_batch(self, batch):
        """Evaluate one ``(inputs, target)`` batch and accumulate the
        result into ``self.eval_metrics``.
        """
        net = self.model
        inputs, target = batch
        outputs = net(cast(inputs, self.dtype), training=False)
        outputs = cast(outputs, tf.float32)
        self.update_metrics(self.eval_metrics, target, outputs)
Ejemplo n.º 5
0
    def train_batch(self, batch):
        """Run one training step with a single optimizer, with optional
        batch transformation and float16 loss scaling.

        The batch is a ``(inputs, target)`` pair. If ``self.batch_transform``
        is set it is applied first (e.g. augmentation — TODO confirm its
        contract with callers).
        """
        model = self.model
        optimizer = self.optimizers[0]

        inputs, target = batch
        if self.batch_transform is not None:
            inputs, target = self.batch_transform(inputs, target)
        with tf.GradientTape() as tape:
            inputs = cast(inputs, self.dtype)
            preds = model(inputs, training=True)
            # Loss is computed in float32 regardless of compute dtype.
            preds = cast(preds, tf.float32)
            per_example_loss = self.criterion(target, preds)
            loss = self.reduce_loss(per_example_loss)
            if self.dtype == tf.float16:
                # NOTE(review): assumes optimizer is a mixed-precision
                # LossScaleOptimizer and that self.minimize unscales the
                # gradients after tape.gradient — confirm, otherwise the
                # applied gradients would be scaled.
                loss = optimizer.get_scaled_loss(loss)
        self.minimize(tape, optimizer, loss, model.trainable_variables)
        self.update_metrics(self.train_metrics, target, preds,
                            per_example_loss)