Example 1: custom `train_step` supporting either a single optimizer or a (linear, dnn) optimizer pair.
    def train_step(self, data):
        """Run one optimization step on a single batch.

        Unpacks ``data`` into inputs/targets/weights, computes the compiled
        loss under a gradient tape, applies gradients — separately to the
        linear and DNN sub-models when ``self.optimizer`` is a pair — and
        returns the current metric results keyed by metric name.
        """
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        x, y, sample_weight = data_adapter.expand_1d((x, y, sample_weight))

        with tf.GradientTape() as tape:
            predictions = self(x, training=True)
            loss = self.compiled_loss(y,
                                      predictions,
                                      sample_weight,
                                      regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, predictions, sample_weight)

        if isinstance(self.optimizer, (list, tuple)):
            # Two optimizers configured: optimizer[0] drives the linear
            # branch, optimizer[1] drives the DNN branch. A single
            # tape.gradient call computes both gradient lists at once.
            wide_vars = self.linear_model.trainable_variables
            deep_vars = self.dnn_model.trainable_variables
            wide_grads, deep_grads = tape.gradient(loss,
                                                   (wide_vars, deep_vars))

            wide_optimizer, deep_optimizer = self.optimizer[0], self.optimizer[1]
            wide_optimizer.apply_gradients(zip(wide_grads, wide_vars))
            deep_optimizer.apply_gradients(zip(deep_grads, deep_vars))
        else:
            # Single optimizer: update every trainable variable together.
            all_vars = self.trainable_variables
            all_grads = tape.gradient(loss, all_vars)
            self.optimizer.apply_gradients(zip(all_grads, all_vars))

        return {metric.name: metric.result() for metric in self.metrics}
Example 2: unit test verifying that `data_adapter.expand_1d` leaves sparse tensors untouched.
 def test_expand_1d_sparse_tensors_untouched(self):
     """expand_1d must pass a rank-1 SparseTensor through without adding an axis."""
     # dense_shape must exceed the largest index: indices 0 and 10 need a
     # dense shape of at least 11. The original dense_shape=[10] made
     # index 10 out of range per the tf.SparseTensor contract (indices
     # must lie in [0, dense_shape[i])); construction does not validate,
     # which is why the invalid shape went unnoticed.
     st = tf.SparseTensor(indices=[[0], [10]],
                          values=[1, 2],
                          dense_shape=[11])
     st = data_adapter.expand_1d(st)
     self.assertEqual(st.shape.rank, 1)