Example 1
    def train_step(self, sequence):
        (X, A), y, out_weight = next(iter(sequence))
        cfg = self.cfg.fit

        U, V = self.cache.U, self.cache.V
        model = self.model
        loss_fn = getattr(model, LOSS)
        metrics = getattr(model, METRICS)
        optimizer = model.optimizer

        # First pass: compute gradients of the loss w.r.t. the cached
        # factors U and V used to reconstruct the adjacency.
        with tf.GradientTape() as tape:
            tape.watch([U, V])
            A0 = (U * V) @ tf.transpose(U)
            out = model([X, A0], training=True)
            out = gather(out, out_weight)
            loss = loss_fn(y, out)

        U_grad, V_grad = tape.gradient(loss, [U, V])
        # Normalize the factor gradients and scale them to the configured
        # perturbation radii, then build the perturbed factors.
        U_grad = cfg.eps1 * U_grad / tf.norm(U_grad)
        V_grad = cfg.eps2 * V_grad / tf.norm(V_grad)

        U_hat = U + U_grad
        V_hat = V + V_grad

        # Second pass: train on the clean reconstruction plus the two
        # perturbed ones, weighting the perturbed losses by lamb1/lamb2.
        with tf.GradientTape() as tape:
            A1 = (U_hat * V) @ tf.transpose(U_hat)
            A2 = (U * V_hat) @ tf.transpose(U)

            out0 = model([X, A0], training=True)
            out0 = gather(out0, out_weight)
            out1 = model([X, A1], training=True)
            out1 = gather(out1, out_weight)
            out2 = model([X, A2], training=True)
            out2 = gather(out2, out_weight)

            loss = loss_fn(y, out0) + tf.reduce_sum(model.losses)
            loss += cfg.lamb1 * loss_fn(y, out1) + cfg.lamb2 * loss_fn(y, out2)
            # Update the metrics with the predictions on the clean adjacency.
            if isinstance(metrics, list):
                for metric in metrics:
                    metric.update_state(y, out0)
            else:
                metrics.update_state(y, out0)

        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        results = [loss] + [
            metric.result() for metric in getattr(metrics, "metrics", metrics)
        ]
        return dict(zip(model.metrics_names, results))
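
The reconstruction (U * V) @ tf.transpose(U) in Example 1 reads as U @ diag(V) @ U^T, i.e. a spectral factorization of the adjacency. The sketch below is a minimal illustration under that assumption (U holding eigenvectors and V eigenvalues of a dense symmetric adjacency); the trainer's actual cache may hold a truncated or otherwise preprocessed factorization.

    import tensorflow as tf

    # Build a small symmetric "adjacency" and factorize it.
    adj = tf.random.uniform((5, 5))
    adj = (adj + tf.transpose(adj)) / 2

    V, U = tf.linalg.eigh(adj)          # eigenvalues V, eigenvectors U
    A0 = (U * V) @ tf.transpose(U)      # same expression as in train_step

    # The reconstruction matches the original matrix up to numerical error.
    print(tf.reduce_max(tf.abs(A0 - adj)))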
Example 2
    def train_step_on_batch(self, x, y=None, out_weight=None, device="CPU"):
        # FIXME: self.metrics returns '[]' for tensorflow>=2.2.0
        # (see <https://github.com/tensorflow/tensorflow/issues/37990>);
        # the loss or the metrics must be called once to build
        # compiled_loss / compiled_metrics.
        loss_fn = getattr(self, LOSS)
        metrics = getattr(self, METRICS)
        optimizer = self.optimizer

        with tf.device(device):
            with tf.GradientTape() as tape:
                out = self(x, training=True)
                out = gather(out, out_weight)
                loss = loss_fn(y, out) + tf.reduce_sum(self.losses)
                if isinstance(metrics, list):
                    for metric in metrics:
                        metric.update_state(y, out)
                else:
                    metrics.update_state(y, out)

            grad = tape.gradient(loss, self.trainable_variables)
            optimizer.apply_gradients(zip(grad, self.trainable_variables))

            results = [loss] + [
                metric.result()
                for metric in getattr(metrics, "metrics", metrics)
            ]
            return dict(zip(self.metrics_names, results))
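
The gather helper used throughout these snippets is not shown. A plausible minimal stand-in, assuming out_weight carries the integer indices of the labeled nodes (and is None when every node contributes), is:

    import tensorflow as tf

    def gather(out, out_index=None):
        # Select the output rows belonging to the labeled nodes;
        # pass-through when no index is given.
        # `out_index` is assumed to hold integer node indices.
        if out_index is None:
            return out
        return tf.gather(out, out_index)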
Example 3
    def predict_step_on_batch(self,
                              x,
                              out_weight=None,
                              return_logits=True,
                              device="CPU"):
        with tf.device(device):
            out = self(x, training=False)
            out = gather(out, out_weight)
            if not return_logits:
                # convert logits to class probabilities
                out = softmax(out)
        return out
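
A hypothetical call of Example 3, assuming a trained model exposing this method, inputs x and adj, and test-node indices test_index (all names are placeholders):

    # Ask for class probabilities rather than raw logits.
    probs = model.predict_step_on_batch(
        [x, adj], out_weight=test_index, return_logits=False)
    pred = tf.argmax(probs, axis=1)     # predicted class per test node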
Example 4
    def test_step_on_batch(self, x, y=None, out_weight=None, device="CPU"):
        loss_fn = getattr(self, LOSS)
        metrics = getattr(self, METRICS)

        with tf.device(device):
            out = self(x, training=False)
            out = gather(out, out_weight)
            loss = loss_fn(y, out) + tf.reduce_sum(self.losses)
            if isinstance(metrics, list):
                for metric in metrics:
                    metric.update_state(y, out)
            else:
                metrics.update_state(y, out)

            results = [loss] + [
                metric.result()
                for metric in getattr(metrics, "metrics", metrics)
            ]
            return dict(zip(self.metrics_names, results))
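
Because update_state accumulates across calls, the metric states should be cleared before an evaluation pass. A hypothetical loop over a validation sequence (val_sequence and its (inputs, labels, index) triples are assumptions):

    model.reset_metrics()    # clear running metric states before this pass
    for inputs, y, out_index in val_sequence:
        results = model.test_step_on_batch(inputs, y, out_weight=out_index)
    print(results)           # loss of the last batch, metrics over the pass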
Example 5
    def train_step(self, sequence):
        model = self.model
        cfg = self.cfg.fit

        loss_fn = getattr(model, LOSS)
        metrics = getattr(model, METRICS)
        optimizer = model.optimizer

        with tf.device(self.device):

            for inputs, y, out_weight in sequence:
                x, adj, adv_mask = inputs
                with tf.GradientTape() as tape:
                    logit = model([x, adj], training=True)
                    out = gather(logit, out_weight)
                    loss = loss_fn(y, out)
                    # Regularize the supervised loss with the virtual
                    # adversarial loss (weight cfg.p1) and the entropy
                    # loss (weight cfg.p2).
                    entropy_loss = entropy_y_x(logit)
                    vat_loss = self.virtual_adversarial_loss(x,
                                                             adj,
                                                             logit=logit,
                                                             adv_mask=adv_mask)
                    loss += cfg.p1 * vat_loss + cfg.p2 * entropy_loss

                    if isinstance(metrics, list):
                        for metric in metrics:
                            metric.update_state(y, out)
                    else:
                        metrics.update_state(y, out)

                trainable_variables = model.trainable_variables
                gradients = tape.gradient(loss, trainable_variables)
                optimizer.apply_gradients(zip(gradients, trainable_variables))

            # NOTE: the reported loss is the one of the last batch in the
            # sequence; the metrics accumulate over all batches.
            results = [loss] + [
                metric.result()
                for metric in getattr(metrics, "metrics", metrics)
            ]
            return dict(zip(model.metrics_names, results))
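
entropy_y_x in Example 5 is the usual entropy-minimization regularizer found in virtual adversarial training; its exact implementation is not shown here. A plausible sketch (the real helper may differ):

    import tensorflow as tf

    def entropy_y_x(logit):
        # Mean Shannon entropy of the predicted class distribution.
        # Minimizing it pushes (unlabeled) predictions toward confident,
        # low-entropy class assignments.
        p = tf.nn.softmax(logit, axis=-1)
        log_p = tf.nn.log_softmax(logit, axis=-1)
        return -tf.reduce_mean(tf.reduce_sum(p * log_p, axis=-1))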