def train_step(self, sequence):
    """Run one adversarial training step using perturbed spectral factors.

    Takes a single batch from ``sequence`` and performs two passes:
    first a tape over the cached factors ``U``/``V`` to obtain an
    adversarial perturbation direction, then a tape over the model's
    trainable variables combining the clean loss with two perturbed
    losses (weighted by ``cfg.lamb1``/``cfg.lamb2``).

    Parameters
    ----------
    sequence : iterable yielding ``((X, A), y, out_weight)`` batches;
        only the first batch is consumed (``next(iter(...))``).

    Returns
    -------
    dict mapping ``model.metrics_names`` to ``[loss, *metric results]``.
    """
    # Only the first batch of the sequence is used for this step.
    (X, A), y, out_weight = next(iter(sequence))
    cfg = self.cfg.fit
    # NOTE(review): presumably U are eigenvectors and V eigenvalues of a
    # decomposition cached elsewhere, so that (U * V) @ U^T reconstructs
    # the adjacency — TODO confirm against the cache builder.
    U, V = self.cache.U, self.cache.V
    model = self.model
    loss_fn = getattr(model, LOSS)
    metrics = getattr(model, METRICS)
    optimizer = model.optimizer

    # Pass 1: differentiate the clean loss w.r.t. U and V (they are not
    # trainable variables, hence the explicit tape.watch).
    with tf.GradientTape() as tape:
        tape.watch([U, V])
        A0 = (U * V) @ tf.transpose(U)
        out = model([X, A0], training=True)
        out = gather(out, out_weight)
        loss = loss_fn(y, out)

    U_grad, V_grad = tape.gradient(loss, [U, V])
    # Normalize each perturbation to a fixed step size (eps1/eps2).
    U_grad = cfg.eps1 * U_grad / tf.norm(U_grad)
    V_grad = cfg.eps2 * V_grad / tf.norm(V_grad)

    # Adversarially shifted factors.
    U_hat = U + U_grad
    V_hat = V + V_grad

    # Pass 2: optimize model weights on clean + two perturbed graphs.
    with tf.GradientTape() as tape:
        A1 = (U_hat * V) @ tf.transpose(U_hat)
        A2 = (U * V_hat) @ tf.transpose(U)
        out0 = model([X, A0], training=True)
        out0 = gather(out0, out_weight)
        out1 = model([X, A1], training=True)
        out1 = gather(out1, out_weight)
        out2 = model([X, A2], training=True)
        out2 = gather(out2, out_weight)
        # Clean loss plus model regularization losses ...
        loss = loss_fn(y, out0) + tf.reduce_sum(model.losses)
        # ... plus the two weighted adversarial losses.
        loss += cfg.lamb1 * loss_fn(y, out1) + cfg.lamb2 * loss_fn(y, out2)

        # NOTE(review): metrics are updated with `out` from the FIRST
        # forward pass, not `out0` from this pass — with stochastic
        # layers (e.g. dropout) these can differ. Looks intentional
        # (reuses the already-computed clean output) but worth confirming.
        if isinstance(metrics, list):
            for metric in metrics:
                metric.update_state(y, out)
        else:
            metrics.update_state(y, out)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # `metrics` may be a list of Metric objects or a container exposing
    # `.metrics` (compiled metrics); getattr handles both.
    results = [loss] + [
        metric.result() for metric in getattr(metrics, "metrics", metrics)
    ]
    return dict(zip(model.metrics_names, results))
def train_step_on_batch(self, x, y=None, out_weight=None, device="CPU"):
    """Perform one gradient-descent update on a single batch.

    Parameters
    ----------
    x : model input(s) for the forward pass.
    y : ground-truth labels, optional.
    out_weight : index/mask forwarded to ``gather`` to select outputs.
    device : TensorFlow device string to run on (default ``"CPU"``).

    Returns
    -------
    dict mapping ``self.metrics_names`` to ``[loss, *metric results]``.
    """
    # FIXME: self.metrics would return '[]' for tensorflow>=2.2.0
    # See <https://github.com/tensorflow/tensorflow/issues/37990>
    # the loss or metrics must be called to build the compiled_loss
    # or compiled_metrics
    criterion = getattr(self, LOSS)
    tracked = getattr(self, METRICS)
    opt = self.optimizer

    with tf.device(device):
        with tf.GradientTape() as tape:
            predictions = gather(self(x, training=True), out_weight)
            # Task loss plus any layer regularization losses.
            total_loss = criterion(y, predictions) + tf.reduce_sum(self.losses)
            # `tracked` is either a list of Metric objects or a single
            # container; normalize to a list for the update loop.
            trackers = tracked if isinstance(tracked, list) else [tracked]
            for tracker in trackers:
                tracker.update_state(y, predictions)

        variables = self.trainable_variables
        gradients = tape.gradient(total_loss, variables)
        opt.apply_gradients(zip(gradients, variables))

        scores = [total_loss]
        scores.extend(m.result() for m in getattr(tracked, "metrics", tracked))
        return dict(zip(self.metrics_names, scores))
def predict_step_on_batch(self, x, out_weight=None, return_logits=True, device="CPU"):
    """Run inference on a single batch.

    Parameters
    ----------
    x : model input(s) for the forward pass.
    out_weight : index/mask forwarded to ``gather`` to select outputs.
    return_logits : if False, apply ``softmax`` to the gathered outputs.
    device : TensorFlow device string to run on (default ``"CPU"``).

    Returns
    -------
    The gathered logits, or probabilities when ``return_logits`` is False.
    """
    with tf.device(device):
        logits = gather(self(x, training=False), out_weight)
        result = logits if return_logits else softmax(logits)
    return result
def test_step_on_batch(self, x, y=None, out_weight=None, device="CPU"):
    """Evaluate loss and metrics on a single batch without updating weights.

    Parameters
    ----------
    x : model input(s) for the forward pass.
    y : ground-truth labels, optional.
    out_weight : index/mask forwarded to ``gather`` to select outputs.
    device : TensorFlow device string to run on (default ``"CPU"``).

    Returns
    -------
    dict mapping ``self.metrics_names`` to ``[loss, *metric results]``.
    """
    criterion = getattr(self, LOSS)
    tracked = getattr(self, METRICS)

    with tf.device(device):
        predictions = gather(self(x, training=False), out_weight)
        # Task loss plus any layer regularization losses.
        total_loss = criterion(y, predictions) + tf.reduce_sum(self.losses)

        # Normalize the single-container case to a list for updating.
        trackers = tracked if isinstance(tracked, list) else [tracked]
        for tracker in trackers:
            tracker.update_state(y, predictions)

        scores = [total_loss]
        scores.extend(m.result() for m in getattr(tracked, "metrics", tracked))
        return dict(zip(self.metrics_names, scores))
def train_step(self, sequence):
    """Run one training epoch over ``sequence`` with virtual adversarial
    training (VAT) regularization.

    For each batch, the supervised loss is augmented with an entropy
    term over all logits and a virtual adversarial loss (weighted by
    ``cfg.p1``/``cfg.p2``), and one optimizer step is applied.

    Parameters
    ----------
    sequence : iterable yielding ``((x, adj, adv_mask), y, out_weight)``.

    Returns
    -------
    dict mapping ``model.metrics_names`` to ``[loss, *metric results]``,
    where ``loss`` is the loss of the LAST batch only.
    """
    model = self.model
    cfg = self.cfg.fit
    loss_fn = getattr(model, LOSS)
    metrics = getattr(model, METRICS)
    optimizer = model.optimizer

    with tf.device(self.device):
        for inputs, y, out_weight in sequence:
            x, adj, adv_mask = inputs
            with tf.GradientTape() as tape:
                logit = model([x, adj], training=True)
                out = gather(logit, out_weight)
                # Supervised loss on the selected outputs only.
                loss = loss_fn(y, out)
                # Entropy regularization over ALL logits (not just the
                # gathered subset).
                entropy_loss = entropy_y_x(logit)
                # NOTE(review): virtual_adversarial_loss is defined
                # elsewhere; assumed to compute the VAT smoothness term
                # from the current logits — confirm its contract there.
                vat_loss = self.virtual_adversarial_loss(x, adj, logit=logit, adv_mask=adv_mask)
                loss += cfg.p1 * vat_loss + cfg.p2 * entropy_loss

                # Metrics are tracked on the supervised predictions.
                if isinstance(metrics, list):
                    for metric in metrics:
                        metric.update_state(y, out)
                else:
                    metrics.update_state(y, out)

            trainable_variables = model.trainable_variables
            gradients = tape.gradient(loss, trainable_variables)
            optimizer.apply_gradients(zip(gradients, trainable_variables))

        # NOTE(review): `loss` here is from the final batch only (metrics
        # accumulate across batches); also raises NameError if `sequence`
        # is empty — presumably never the case for callers. TODO confirm.
        results = [loss] + [
            metric.result() for metric in getattr(metrics, "metrics", metrics)
        ]
        return dict(zip(model.metrics_names, results))