Exemplo n.º 1
0
class SBVAT(SemiSupervisedModel):
    """
        Implementation of sample-based Batch Virtual Adversarial Training
        Graph Convolutional Networks (SBVAT).
        `Batch Virtual Adversarial Training for Graph Convolutional Networks
        <https://arxiv.org/abs/1902.09192>`
        Tensorflow 1.x implementation: <https://github.com/thudzj/BVAT>
    """
    def __init__(self,
                 *graph,
                 n_samples=50,
                 adj_transform="normalize_adj",
                 attr_transform=None,
                 device='cpu:0',
                 seed=None,
                 name=None,
                 **kwargs):
        """Create a sample-based Batch Virtual Adversarial Training
        Graph Convolutional Networks (SBVAT) model.

        This can be instantiated in several ways:

            model = SBVAT(graph)
                with a `graphgallery.data.Graph` instance representing
                a sparse, attributed, labeled graph.

            model = SBVAT(adj_matrix, attr_matrix, labels)
                where `adj_matrix` is a 2D Scipy sparse matrix denoting the graph,
                `attr_matrix` is a 2D Numpy array-like matrix denoting the node
                attributes, `labels` is a 1D Numpy array denoting the node labels.

        Parameters:
        -----------
        graph: An instance of `graphgallery.data.Graph` or a tuple (list) of inputs.
            A sparse, attributed, labeled graph.
        n_samples (Positive integer, optional):
            The number of sampled subset nodes in the graph where the length of the
            shortest path between them is at least `4`. (default :obj: `50`)
        adj_transform: string, `transform`, or None. optional
            How to transform the adjacency matrix. See `graphgallery.transforms`
            (default: :obj:`'normalize_adj'` with normalize rate `-0.5`,
            i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}})
        attr_transform: string, `transform`, or None. optional
            How to transform the node attribute matrix. See `graphgallery.transforms`
            (default :obj: `None`)
        device: string. optional
            The device the model runs on. You can specify `CPU` or `GPU`
            for the model. (default: :str: `cpu:0`, i.e., running on the 0-th `CPU`)
        seed: integer scalar. optional
            Used in combination with `tf.random.set_seed` & `np.random.seed`
            & `random.seed` to create a reproducible sequence of tensors across
            multiple calls. (default :obj: `None`, i.e., using random seed)
        name: string. optional
            Specified name for the model. (default: :str: `class.__name__`)
        kwargs: other customized keyword Parameters.
        """
        super().__init__(*graph, device=device, seed=seed, name=name, **kwargs)

        self.adj_transform = T.get(adj_transform)
        self.attr_transform = T.get(attr_transform)
        self.n_samples = n_samples
        self.process()

    def process_step(self):
        """Transform the graph matrices and move them to `self.device` as tensors.

        Also stores `self.neighbors = find_4o_nbrs(adj_matrix)` (computed on the
        transformed adjacency matrix); the sample sequences consume it.
        """
        graph = self.graph
        adj_matrix = self.adj_transform(graph.adj_matrix)
        attr_matrix = self.attr_transform(graph.attr_matrix)
        self.neighbors = find_4o_nbrs(adj_matrix)

        self.feature_inputs, self.structure_inputs = T.astensors(
            attr_matrix, adj_matrix, device=self.device)

    # use decorator to make sure all list arguments have the same length
    @EqualVarLength()
    def build(self,
              hiddens=[16],
              activations=['relu'],
              dropout=0.5,
              lr=0.01,
              l2_norm=5e-4,
              use_bias=False,
              p1=1.,
              p2=1.,
              n_power_iterations=1,
              epsilon=0.03,
              xi=1e-6):
        """Build the GCN layers, the Keras model, metrics, loss and optimizer.

        `hiddens`/`activations` describe the hidden GraphConvolution layers
        (each followed by Dropout); a final GraphConvolution maps to
        `n_classes` logits. `p1` weights the virtual adversarial loss and `p2`
        the entropy loss; `xi`, `epsilon` and `n_power_iterations` configure
        the adversarial perturbation search.
        """
        with tf.device(self.device):

            x = Input(batch_shape=[None, self.graph.n_attrs],
                      dtype=self.floatx,
                      name='attr_matrix')
            adj = Input(batch_shape=[None, None],
                        dtype=self.floatx,
                        sparse=True,
                        name='adj_matrix')
            index = Input(batch_shape=[None],
                          dtype=self.intx,
                          name='node_index')

            GCN_layers = []
            dropout_layers = []
            for hidden, activation in zip(hiddens, activations):
                GCN_layers.append(
                    GraphConvolution(
                        hidden,
                        activation=activation,
                        use_bias=use_bias,
                        kernel_regularizer=regularizers.l2(l2_norm)))
                dropout_layers.append(Dropout(rate=dropout))

            GCN_layers.append(
                GraphConvolution(self.graph.n_classes, use_bias=use_bias))
            self.GCN_layers = GCN_layers
            self.dropout_layers = dropout_layers

            logit = self.forward(x, adj)
            output = Gather()([logit, index])
            model = Model(inputs=[x, adj, index], outputs=output)

            self.model = model
            self.train_metric = SparseCategoricalAccuracy()
            self.test_metric = SparseCategoricalAccuracy()
            self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
            # `lr` is a deprecated alias that recent Keras versions reject.
            self.optimizer = Adam(learning_rate=lr)

        self.p1 = p1  # Alpha
        self.p2 = p2  # Beta
        self.xi = xi  # Small constant for finite difference
        # Norm length for (virtual) adversarial training
        self.epsilon = epsilon
        self.n_power_iterations = n_power_iterations  # Number of power iterations

    def forward(self, x, adj, training=True):
        """Run the GCN stack on node features `x` with adjacency `adj`.

        Returns logits for every node. Dropout is active only when `training`.
        """
        h = x
        for dropout_layer, GCN_layer in zip(self.dropout_layers,
                                            self.GCN_layers[:-1]):
            h = GCN_layer([h, adj])
            h = dropout_layer(h, training=training)
        h = self.GCN_layers[-1]([h, adj])
        return h

    @tf.function
    def train_step(self, sequence):
        """Run one optimisation step per batch of `sequence`.

        The loss is the supervised cross-entropy on the indexed nodes, plus
        `p1 * vat_loss + p2 * entropy_loss`.
        Returns the last batch's loss and the accumulated training accuracy.
        """
        with tf.device(self.device):
            self.train_metric.reset_states()

            for inputs, labels in sequence:
                x, adj, index, adv_mask = inputs
                with tf.GradientTape() as tape:
                    logit = self.forward(x, adj)
                    output = tf.gather(logit, index)
                    loss = self.loss_fn(labels, output)
                    entropy_loss = entropy_y_x(logit)
                    vat_loss = self.virtual_adversarial_loss(x,
                                                             adj,
                                                             logit=logit,
                                                             adv_mask=adv_mask)
                    loss += self.p1 * vat_loss + self.p2 * entropy_loss

                    self.train_metric.update_state(labels, output)

                trainable_variables = self.model.trainable_variables
                gradients = tape.gradient(loss, trainable_variables)
                self.optimizer.apply_gradients(
                    zip(gradients, trainable_variables))

            return loss, self.train_metric.result()

    @tf.function
    def test_step(self, sequence):
        """Evaluate `sequence` without dropout.

        Returns the last batch's loss and the accumulated accuracy.
        """
        with tf.device(self.device):
            self.test_metric.reset_states()

            for inputs, labels in sequence:
                x, adj, index, _ = inputs
                logit = self.forward(x, adj, training=False)
                output = tf.gather(logit, index)
                loss = self.loss_fn(labels, output)
                self.test_metric.update_state(labels, output)

            return loss, self.test_metric.result()

    def virtual_adversarial_loss(self, x, adj, logit, adv_mask):
        """KL divergence between predictions on `x` and on `x + r_vadv`.

        `r_vadv` is found by `n_power_iterations` of the finite-difference
        power method (step `xi`) and then rescaled to norm `epsilon`.
        `adv_mask` is forwarded to `kl_divergence_with_logit` to select the
        sampled nodes.
        """
        d = tf.random.normal(shape=tf.shape(x), dtype=self.floatx)

        for _ in range(self.n_power_iterations):
            d = get_normalized_vector(d) * self.xi
            logit_p = logit
            with tf.GradientTape() as tape:
                tape.watch(d)
                logit_m = self.forward(x + d, adj)
                dist = kl_divergence_with_logit(logit_p, logit_m, adv_mask)
            grad = tape.gradient(dist, d)
            d = tf.stop_gradient(grad)

        r_vadv = get_normalized_vector(d) * self.epsilon
        # The clean prediction is the (fixed) target distribution.
        logit_p = tf.stop_gradient(logit)
        logit_m = self.forward(x + r_vadv, adj)
        loss = kl_divergence_with_logit(logit_p, logit_m, adv_mask)
        return tf.identity(loss)

    def train_sequence(self, index):
        """Build a resampling SBVATSampleSequence over the nodes in `index`."""
        index = T.asintarr(index)
        labels = self.graph.labels[index]

        sequence = SBVATSampleSequence(
            [self.feature_inputs, self.structure_inputs, index],
            labels,
            neighbors=self.neighbors,
            n_samples=self.n_samples,
            device=self.device)

        return sequence

    def test_sequence(self, index):
        """Build a non-resampling SBVATSampleSequence over the nodes in `index`."""
        index = T.asintarr(index)
        labels = self.graph.labels[index]

        sequence = SBVATSampleSequence(
            [self.feature_inputs, self.structure_inputs, index],
            labels,
            neighbors=self.neighbors,
            n_samples=self.n_samples,
            resample=False,
            device=self.device)

        return sequence

    def predict_step(self, sequence):
        """Return the logits of the indexed nodes of every batch in `sequence`.

        Batches are concatenated along the first axis. A single batch is
        returned unchanged.

        Raises:
            ValueError: if `sequence` yields no batches.
        """
        batch_logits = []
        with tf.device(self.device):
            for inputs, _ in sequence:
                x, adj, index, _ = inputs
                output = self.forward(x, adj, training=False)
                batch_logits.append(tf.gather(output, index))

        if not batch_logits:
            raise ValueError("`sequence` yielded no batches to predict on.")

        if len(batch_logits) == 1:
            logit = batch_logits[0]
        else:
            logit = tf.concat(batch_logits, axis=0)

        if tf.is_tensor(logit):
            logit = logit.numpy()
        return logit
Exemplo n.º 2
0
def train(input_params, train, test, valid, class_cnt, epochs=200):
    """Train a MobileNet-based image classifier, or load a previously saved one.

    An ImageNet-initialised MobileNet backbone (average pooling) is topped with
    Dropout and a softmax Dense layer. It is trained with Adam and saved to
    ``<model_dir>/model``. Per-batch training and validation loss/accuracy are
    written to TensorBoard under ``logs/gradient_tape/<timestamp>/``.

    Args:
        input_params: currently unused.
        train: iterable of ``(images, labels)`` training batches.
        test: currently unused (test evaluation is still a TODO).
        valid: iterable of ``(images, labels)`` validation batches.
        class_cnt: number of output classes.
        epochs: number of training epochs (defaults to the previously
            hard-coded 200).

    Returns:
        The trained (or loaded) Keras model.
    """
    import os

    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    # TensorBoard writers (no test writer yet: test evaluation is a TODO).
    log_root = 'logs/gradient_tape/' + current_time
    train_summary_writer = tf.summary.create_file_writer(log_root + '/train')
    valid_summary_writer = tf.summary.create_file_writer(log_root + '/valid')

    # TODO: derive model_dir from the hyper-parameters,
    #       e.g. '../data/models/params/<current_time>/'.
    # NOTE(review): model_dir is timestamped, so it is new on every call and
    # the "load existing model" branch below is effectively never taken.
    model_dir = '../data/models/model-' + current_time
    # Save *inside* model_dir; the old `model_dir + 'model'` yielded a sibling
    # path like '.../model-<time>model'.
    model_path = os.path.join(model_dir, 'model')

    if not utility.dir_empty(model_dir):
        print("model already exist. loading model...")
        return load_model(model_path)

    optimizer = Adam(learning_rate=0.001)
    # The classification head ends in a softmax, so the model outputs
    # probabilities, not logits.
    loss_fn = SparseCategoricalCrossentropy(from_logits=False)
    # Shared accuracy metric, reset between the training and validation passes.
    acc_metric = SparseCategoricalAccuracy()
    train_log_step = valid_log_step = 0

    mobilenet = MOBILENET(include_top=False,
                          input_shape=(224, 224, 3),
                          weights='imagenet',
                          pooling='avg',
                          dropout=0.001)
    mobilenet.summary()
    # select till which layer use mobilenet.
    base_model = Model(inputs=mobilenet.input, outputs=mobilenet.output)
    base_model.summary()

    model = Sequential([
        base_model,
        Dropout(0.2),
        Dense(units=class_cnt, activation='softmax'),
    ])
    model.summary()

    for epoch in range(epochs):
        print("\nStart of epoch %d" % (epoch,))
        for batch_idx, (x_batch_train, y_batch_train) in enumerate(train):
            with tf.GradientTape() as tape:
                predictions = model(x_batch_train, training=True)
                loss_value = loss_fn(y_batch_train, predictions)

            grads = tape.gradient(loss_value, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))

            acc_metric.update_state(y_batch_train, predictions)

            # TODO: log the mean loss over the epoch instead of per batch,
            #       and add test-set metrics.
            with train_summary_writer.as_default():
                tf.summary.scalar('loss', loss_value, step=train_log_step)
                tf.summary.scalar('accuracy', acc_metric.result(), step=train_log_step)
                train_log_step += 1

            if batch_idx % 10 == 0:
                print("training loss for one batch at step %d: %.4f" % (batch_idx, float(loss_value)))

        print("Training acc over epoch: %.4f" % (float(acc_metric.result()),))
        acc_metric.reset_states()

        # Validation pass (inference mode: dropout disabled).
        for x_batch_val, y_batch_val in valid:
            val_predictions = model(x_batch_val, training=False)
            val_loss = loss_fn(y_batch_val, val_predictions)
            acc_metric.update_state(y_batch_val, val_predictions)

            with valid_summary_writer.as_default():
                tf.summary.scalar('loss', val_loss, step=valid_log_step)
                tf.summary.scalar('accuracy', acc_metric.result(), step=valid_log_step)
                valid_log_step += 1

        print("Validation acc: %.4f" % (float(acc_metric.result()),))
        acc_metric.reset_states()

    model.save(model_path)
    return model
Exemplo n.º 3
0
@tf.function
def testing(images, labels):
    """Evaluate a single test batch and fold the results into the test metrics.

    Uses the script-level `model`, `loss_`, `test_loss` and `test_accuracy`
    objects; nothing is returned.
    """
    predictions = model(images)
    test_loss(loss_(labels, predictions))
    test_accuracy(labels, predictions)


# TRAINING: one pass over the training batches, then one over the test
# batches, a progress line, a metric reset and a weight checkpoint per epoch.
for epoch in range(EPOCHS):
    for train_images, train_labels in train:
        training(train_images, train_labels)

    for test_images, test_labels in test:
        testing(test_images, test_labels)

    to_print = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(to_print.format(
        epoch + 1,
        train_loss.result(),
        train_accuracy.result() * 100,
        test_loss.result(),
        test_accuracy.result() * 100,
    ))

    # Start the next epoch with empty accumulators.
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    # Overwrites the same 'model' checkpoint after every epoch.
    model.save_weights('model', save_format='tf')
Exemplo n.º 4
0
    train_reporter()

    # Record this epoch's history. Losses are stored unscaled and accuracies
    # as percentages for BOTH splits, so train/validation curves share units
    # (previously the validation loss was multiplied by 100 and the training
    # accuracy was stored as a fraction).
    train_losses.append(train_loss.result())
    train_accs.append(train_acc.result() * 100)
    validation_losses.append(validation_loss.result())
    validation_accs.append(validation_acc.result() * 100)

    # Clear the accumulators so the next epoch starts fresh.
    train_loss.reset_states()
    train_acc.reset_states()
    validation_loss.reset_states()
    validation_acc.reset_states()

# Final evaluation: accumulate loss and accuracy over every test batch.
for x, y in test_ds:
    predictions = model(x)
    loss = loss_object(y, predictions)

    test_loss(loss)
    test_acc(y, predictions)

# NOTE(review): `epoch` is not defined in this snippet; presumably it is the
# loop variable left over from the training loop above, so the last epoch
# number is printed next to the heading — confirm this is intended.
print(colored('Final Result: ', 'red', 'on_white'), epoch + 1)
# Test accuracy is reported as a percentage.
template = 'Test Loss: {:.4f}\t Test Accuracy: {:.2f}%\n'
print(template.format(test_loss.result(), test_acc.result() * 100))

# Defined outside this snippet; presumably plots the recorded history.
final_result_visualization()
Exemplo n.º 5
0
        testing(val_images, val_labels)

    # Progress line for this epoch; accuracies are printed as percentages.
    to_print = 'Epoch {}, Loss: {}, Accuracy: {}, Valid Loss: {}, Valid Accuracy: {}'
    print(
        to_print.format(epoch + 1, train_loss.result(),
                        train_accuracy.result() * 100, val_loss.result(),
                        val_accuracy.result() * 100))
    # Per-epoch history for plotting; accuracies are stored as fractions
    # (not multiplied by 100), unlike the printed values above.
    train_l.append(train_loss.result())
    train_a.append(train_accuracy.result())
    val_l.append(val_loss.result())
    val_a.append(val_accuracy.result())
    epochs.append(epoch)

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    val_loss.reset_states()
    val_accuracy.reset_states()

    # Overwrites the same 'model' checkpoint after every epoch.
    model.save_weights('model', save_format='tf')

# Two side-by-side panels: accuracy curves on the left, loss curves on the
# right; validation in red, training in green.
plt.figure(figsize=(24, 8))
for _panel, (_val_series, _val_label, _train_series, _train_label) in enumerate(
        ((val_a, "validation_accuracy", train_a, "training_accuracy"),
         (val_l, "validation_loss", train_l, "training_loss")),
        start=1):
    plt.subplot(1, 2, _panel)
    plt.plot(epochs, _val_series, label=_val_label, c="red")
    plt.plot(epochs, _train_series, label=_train_label, c="green")
    plt.legend()
Exemplo n.º 6
0
class PointNetSoftmaxClassification:
    """
    Softmax classification on point cloud data.

    Wraps a Keras model that outputs unnormalized class logits and already
    has an optimizer attached (``model.optimizer`` performs the updates).
    Per-epoch loss/accuracy are written to TensorBoard, and ``manager.save()``
    is called after every epoch (presumably a ``tf.train.CheckpointManager``).
    """
    def __init__(self, model, train_log_dir, test_log_dir, manager):
        self._model = model

        # Returns one loss value per example (no reduction).
        self._loss_fn = tf.nn.sparse_softmax_cross_entropy_with_logits
        self._manager = manager

        self._train_loss = Mean(name='train_loss')
        self._test_loss = Mean(name='test_loss')

        self._train_acc = SparseCategoricalAccuracy(name='train_acc')
        self._test_acc = SparseCategoricalAccuracy(name='test_acc')

        os.makedirs(train_log_dir, exist_ok=True)
        os.makedirs(test_log_dir, exist_ok=True)

        self._train_summary_writer = create_file_writer(train_log_dir)
        self._test_summary_writer = create_file_writer(test_log_dir)

    def _reset_metrics(self):
        """Clear all loss/accuracy accumulators."""
        self._train_loss.reset_states()
        self._train_acc.reset_states()
        self._test_loss.reset_states()
        self._test_acc.reset_states()

    @tf.function
    def _train_step(self, inputs, labels):
        """Apply one optimizer update on a batch and accumulate its loss."""
        with tf.GradientTape() as tape:
            predictions = self._model(inputs, training=True)
            per_example_loss = self._loss_fn(labels=labels, logits=predictions)
            # Differentiate the batch mean, not the implicit sum of the
            # per-example vector, so the step size is batch-size independent.
            loss = tf.reduce_mean(per_example_loss)

        gradients = tape.gradient(loss, self._model.trainable_variables)
        self._model.optimizer.apply_gradients(
            zip(gradients, self._model.trainable_variables))
        self._train_loss(per_example_loss)

    @tf.function
    def _test_step(self, inputs, labels):
        """Accumulate loss and accuracy on a test batch."""
        predictions = self._model(inputs)
        self._test_loss(self._loss_fn(labels=labels, logits=predictions))
        self._test_acc(labels, predictions)

    @tf.function
    def _train_acc_step(self, inputs, labels):
        """Accumulate accuracy on a training batch (no update)."""
        predictions = self._model(inputs)
        self._train_acc(labels, predictions)

    def train(self, train_data, test_data, num_epochs, init_epoch=0):
        """Train for ``num_epochs`` epochs, logging and checkpointing each one.

        Training accuracy is measured in a separate pass after the epoch's
        updates. ``init_epoch`` offsets the printed/logged epoch numbers
        (e.g. when resuming).
        """
        for ep_idx in range(num_epochs):
            epoch_number = ep_idx + init_epoch + 1
            print(f"===== Epoch {epoch_number} =====")
            for input_batch, label_batch in train_data.data:
                self._train_step(input_batch, label_batch)

            for input_batch, label_batch in train_data.data:
                self._train_acc_step(input_batch, label_batch)

            for input_batch, label_batch in test_data.data:
                self._test_step(input_batch, label_batch)

            self._update_log(epoch_number)
            self._reset_metrics()
            self._manager.save()

    def test(self, train_data, test_data):
        """Print train and test accuracy of the current model."""
        # Start from clean accumulators so repeated calls don't mix results.
        self._reset_metrics()

        for input_batch, label_batch in train_data.data:
            self._train_acc_step(input_batch, label_batch)

        for input_batch, label_batch in test_data.data:
            self._test_step(input_batch, label_batch)

        print(f"Train: {self._train_acc.result().numpy()}")
        print(f"Test:  {self._test_acc.result().numpy()}")

    def _update_log(self, epoch_idx):
        """Write the accumulated train/test loss and accuracy to TensorBoard."""
        with self._train_summary_writer.as_default():
            tf.summary.scalar('loss',
                              self._train_loss.result().numpy(),
                              step=epoch_idx)
            tf.summary.scalar('acc',
                              self._train_acc.result().numpy(),
                              step=epoch_idx)
        with self._test_summary_writer.as_default():
            tf.summary.scalar('loss',
                              self._test_loss.result().numpy(),
                              step=epoch_idx)
            tf.summary.scalar('acc',
                              self._test_acc.result().numpy(),
                              step=epoch_idx)
Exemplo n.º 7
0
Arquivo: model.py Projeto: jh88/fbnet
class Trainer:
    """Search-phase trainer for an FBNet super-network.

    Trainable variables whose name contains ``'theta'`` are optimised with
    Adam (``train_thetas``); all others with SGD (``train_weights``). The
    loss is sparse categorical cross-entropy on logits plus
    ``latency_loss(sum(fbnet.losses), alpha, beta)``. The sampling
    temperature is recomputed with ``exponential_decay`` whenever ``epoch``
    is set.
    """
    def __init__(self,
                 fbnet,
                 input_shape,
                 initial_temperature=5,
                 temperature_decay_rate=0.956,
                 temperature_decay_steps=1,
                 latency_alpha=0.2,
                 latency_beta=0.6,
                 weight_lr=0.01,
                 weight_momentum=0.9,
                 weight_decay=1e-4,
                 theta_lr=1e-3,
                 theta_beta1=0.9,
                 theta_beta2=0.999,
                 theta_decay=5e-4):
        self._epoch = 0

        self.initial_temperature = initial_temperature
        self.temperature = initial_temperature
        # The @tf.function methods read the temperature from this variable.
        # A plain Python attribute would be frozen into the graph at trace
        # time, so later decay would never reach the compiled steps.
        self._temperature_var = tf.Variable(float(initial_temperature),
                                            trainable=False,
                                            dtype=tf.float32,
                                            name='temperature')
        self.latency_alpha = latency_alpha
        self.latency_beta = latency_beta

        self.exponential_decay = lambda step: exponential_decay(
            initial_temperature, temperature_decay_rate,
            temperature_decay_steps, step)

        fbnet.build(input_shape)
        self.fbnet = fbnet

        # Split architecture parameters (thetas) from operator weights.
        self.weights = []
        self.thetas = []
        for trainable_weight in fbnet.trainable_weights:
            if 'theta' in trainable_weight.name:
                self.thetas.append(trainable_weight)
            else:
                self.weights.append(trainable_weight)

        self.weight_opt = SGD(learning_rate=weight_lr,
                              momentum=weight_momentum,
                              decay=weight_decay)

        self.theta_opt = Adam(learning_rate=theta_lr,
                              beta_1=theta_beta1,
                              beta_2=theta_beta2,
                              decay=theta_decay)

        self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
        self.accuracy_metric = SparseCategoricalAccuracy()
        self.loss_metric = Mean()

    @property
    def epoch(self):
        return self._epoch

    @epoch.setter
    def epoch(self, epoch):
        """Set the epoch and decay the temperature accordingly."""
        self._epoch = epoch
        self.temperature = self.exponential_decay(epoch)
        # Propagate the new value to the compiled training/prediction steps.
        self._temperature_var.assign(
            tf.cast(self.temperature, self._temperature_var.dtype))

    def reset_metrics(self):
        self.accuracy_metric.reset_states()
        self.loss_metric.reset_states()

    def _train(self, x, y, weights, opt, training=True):
        """One update of ``weights`` with ``opt``; accumulates accuracy and loss."""
        with tf.GradientTape() as tape:
            y_hat = self.fbnet(x, self._temperature_var, training=training)
            loss = self.loss_fn(y, y_hat)
            latency = sum(self.fbnet.losses)
            loss += latency_loss(latency, self.latency_alpha,
                                 self.latency_beta)

        grads = tape.gradient(loss, weights)
        opt.apply_gradients(zip(grads, weights))

        self.accuracy_metric.update_state(y, y_hat)
        self.loss_metric.update_state(loss)

    @tf.function
    def train_weights(self, x, y):
        self._train(x, y, self.weights, self.weight_opt)

    @tf.function
    def train_thetas(self, x, y):
        self._train(x, y, self.thetas, self.theta_opt, training=False)

    @property
    def training_accuracy(self):
        return self.accuracy_metric.result().numpy()

    @property
    def training_loss(self):
        return self.loss_metric.result().numpy()

    @tf.function
    def predict(self, x):
        """Run the super-network in inference mode at the current temperature."""
        return self.fbnet(x, self._temperature_var, training=False)

    def evaluate(self, dataset):
        """Return the accuracy of ``predict`` over ``(x, y)`` batches of ``dataset``."""
        accuracy_metric = SparseCategoricalAccuracy()
        for x, y in dataset:
            y_hat = self.predict(x)

            accuracy_metric.update_state(y, y_hat)

        return accuracy_metric.result().numpy()

    def sample_sequential_config(self):
        """Sample one op per MixedOperation and return a Sequential-style config.

        Identity ops are dropped from the resulting layer list.
        """
        ops = [
            op.sample(self.temperature)
            if isinstance(op, MixedOperation) else op for op in self.fbnet.ops
        ]

        sequential_config = {
            'name':
            'sampled_fbnet',
            'layers': [{
                'class_name': type(op).__name__,
                'config': op.get_config()
            } for op in ops if not isinstance(op, Identity)]
        }

        return sequential_config

    def save_weights(self, checkpoint):
        self.fbnet.save_weights(checkpoint, save_format='tf')

    def load_weights(self, checkpoint):
        self.fbnet.load_weights(checkpoint)
Exemplo n.º 8
0
class Train:
    """Train a Transformer with teacher forcing and checkpoint it every epoch.

    ``params`` must provide ``lr``, ``epochs`` and ``ckpt_dir``.
    """
    def __init__(self, params):
        self.lr = params.lr
        self.epochs = params.epochs
        # Define loss.
        # NOTE(review): from_logits defaults to False, which assumes the
        # Transformer outputs probabilities — confirm against its last layer.
        self.loss_object = SparseCategoricalCrossentropy()
        # Define optimizer with the configured learning rate (previously
        # `params.lr` was stored but Adam's default was used).
        self.optimizer = Adam(learning_rate=self.lr)
        # Define metrics for loss:
        self.train_loss = Mean()
        self.train_accuracy = SparseCategoricalAccuracy()
        self.test_loss = Mean()
        self.test_accuracy = SparseCategoricalAccuracy()
        # Define pre processor (params):
        preprocessor = Process(32, 1)
        self.train_ds, self.test_ds, encoder_stats = preprocessor.get_datasets()
        # Define model dims
        d_model = 512
        ff_dim = 2048
        heads = 8
        encoder_dim = 6
        decoder_dim = 6
        # Per-head key/value widths; integer division keeps them ints.
        dk = d_model // heads
        dv = d_model // heads
        vocab_size = encoder_stats['vocab_size']
        max_pos = 10000
        # Define model:
        self.model = Transformer(d_model, ff_dim, dk, dv, heads, encoder_dim,
                                 decoder_dim, vocab_size, max_pos)
        # Define Checkpoints:
        self.ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                        optimizer=self.optimizer,
                                        net=self.model)
        # Define Checkpoint manager:
        self.ckpt_manager = tf.train.CheckpointManager(
            self.ckpt, f'checkpoints{params.ckpt_dir}', max_to_keep=3)

    def _prepare_batch(self, inputs, labels):
        """Split ``labels`` into decoder inputs/targets and build the masks.

        Teacher forcing: the decoder is fed tokens ``[:, :-1]`` and trained
        to predict tokens ``[:, 1:]``.
        """
        dec_inputs = labels[:, :-1]
        targets = labels[:, 1:]
        inp_mask, latent_mask, dec_mask = masks(inputs, dec_inputs)
        return dec_inputs, targets, inp_mask, latent_mask, dec_mask

    # Feed forward through and update model on train data:
    @tf.function
    def _update(self, inputs, labels):
        dec_inputs, targets, inp_mask, latent_mask, dec_mask = \
            self._prepare_batch(inputs, labels)
        with tf.GradientTape() as tape:
            predictions = self.model(inputs, dec_inputs, inp_mask, latent_mask,
                                     dec_mask, True)
            loss = self.loss_object(targets, predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.model.trainable_variables))
        self.train_loss(loss)
        self.train_accuracy(targets, predictions)

    # Feed forward through model on test data (same inputs as training,
    # but with the last positional flag — presumably `training` — False):
    @tf.function
    def _test(self, inputs, labels):
        dec_inputs, targets, inp_mask, latent_mask, dec_mask = \
            self._prepare_batch(inputs, labels)
        predictions = self.model(inputs, dec_inputs, inp_mask, latent_mask,
                                 dec_mask, False)
        loss = self.loss_object(targets, predictions)

        self.test_loss(loss)
        self.test_accuracy(targets, predictions)

    # Log status of each epoch:
    def _log(self, epoch):
        template = 'Epoch {}, Loss: {}, Acc: {}, Test Loss: {}, Test Acc: {}'
        print(
            template.format(epoch + 1, self.train_loss.result(),
                            self.train_accuracy.result() * 100,
                            self.test_loss.result(),
                            self.test_accuracy.result() * 100))

    # Save model to checkpoint:
    def _save(self, verbose=False):
        save_path = self.ckpt_manager.save()
        if verbose:
            ckptLog = f"Saved checkpoint for step {int(self.ckpt.step)}: {save_path}"
            print(ckptLog)

    # Restore model from checkpoint:
    def _restore(self):
        self.ckpt.restore(self.ckpt_manager.latest_checkpoint).expect_partial()
        if self.ckpt_manager.latest_checkpoint:
            print(f"Restored from {self.ckpt_manager.latest_checkpoint}")
        else:
            print("Initializing from scratch.")

    # Reset network metrics:
    def _reset(self):
        self.train_loss.reset_states()
        self.train_accuracy.reset_states()
        self.test_loss.reset_states()
        self.test_accuracy.reset_states()

    # Train loop for network:
    def train(self):
        """Restore the latest checkpoint, then train/evaluate/save every epoch."""
        self._restore()
        for epoch in range(self.epochs):
            for batch, (inputs, labels) in enumerate(self.train_ds):
                print(batch)
                self._update(inputs, labels)
            for testInputs, testLabels in self.test_ds:
                self._test(testInputs, testLabels)
            self._log(epoch)
            # Advance the checkpoint step once per epoch (it was never
            # incremented before, so every save reported step 1).
            self.ckpt.step.assign_add(1)
            self._save()
            self._reset()
Exemplo n.º 9
0
class DualStudent(Model):
    """
    Dual Student for Automatic Speech Recognition (ASR).

    How to train: 1) set the optimizer by means of compile(), 2) use train()
    How to test: use test()

    Remarks:
    - Do not use fit() by Keras, use train()
    - Do not use evaluate() by Keras, use test()
    - Compiled metrics and loss (i.e. set by means of compile()) are not used

    Original proposal for image classification: https://arxiv.org/abs/1909.01804
    """
    def __init__(self,
                 n_classes,
                 n_hidden_layers=3,
                 n_units=96,
                 consistency_loss='mse',
                 consistency_scale=10,
                 stabilization_scale=100,
                 xi=0.6,
                 padding_value=0.,
                 sigma=0.01,
                 schedule='rampup',
                 schedule_length=5,
                 version='mono_directional'):
        """
        Constructs a Dual Student model.

        :param n_classes: number of classes (i.e. number of units in the last layer of each student)
        :param n_hidden_layers: number of hidden layers in each student (i.e. LSTM layers)
        :param n_units: number of units for each hidden layer
        :param consistency_loss: one of 'mse', 'kl'
        :param consistency_scale: maximum value of weight for consistency constraint
        :param stabilization_scale: maximum value of weight for stabilization constraint
        :param xi: threshold for stable sample
        :param padding_value: value used to pad input sequences (used as mask_value for Masking layer)
        :param sigma: standard deviation for noisy augmentation
        :param schedule: type of schedule for lambdas, one of 'rampup', 'triangular_cycling', 'sinusoidal_cycling'
        :param schedule_length: second argument given to the schedule function (together with the epoch index)
            when the lambdas are computed in train(); presumably the ramp-up/cycle length in epochs
        :param version: one of:
            - 'mono_directional': both students have mono-directional LSTM layers
            - 'bidirectional': both students have bidirectional LSTM layers
            - 'imbalanced': one student has mono-directional LSTM layers, the other one bidirectional
        :raises ValueError: if schedule, consistency_loss or version is not one of the accepted values
        """
        super(DualStudent, self).__init__()

        # store parameters
        self.n_classes = n_classes
        self.padding_value = padding_value
        self.n_units = n_units
        self.n_hidden_layers = n_hidden_layers
        self.xi = xi
        self.consistency_scale = consistency_scale
        self.stabilization_scale = stabilization_scale
        self.sigma = sigma
        self.version = version
        self.schedule = schedule
        self.schedule_length = schedule_length
        # weights of consistency (lambda1) and stabilization (lambda2) losses, updated at every epoch by train()
        self._lambda1 = None
        self._lambda2 = None

        # schedule for lambdas
        if schedule == 'rampup':
            self.schedule_fn = sigmoid_rampup
        elif schedule == 'triangular_cycling':
            self.schedule_fn = triangular_cycling
        elif schedule == 'sinusoidal_cycling':
            self.schedule_fn = sinusoidal_cycling
        else:
            raise ValueError('Invalid schedule')

        # loss
        self._loss_cls = SparseCategoricalCrossentropy()  # classification loss
        self._loss_sta = MeanSquaredError()  # stabilization loss
        if consistency_loss == 'mse':
            self._loss_con = MeanSquaredError()  # consistency loss
        elif consistency_loss == 'kl':
            self._loss_con = KLDivergence()
        else:
            raise ValueError('Invalid consistency metric')

        # metrics for training
        self._loss1 = Mean(
            name='loss1')  # we want to average the loss for each batch
        self._loss2 = Mean(name='loss2')
        self._loss1_cls = Mean(name='loss1_cls')
        self._loss2_cls = Mean(name='loss2_cls')
        self._loss1_con = Mean(name='loss1_con')
        self._loss2_con = Mean(name='loss2_con')
        self._loss1_sta = Mean(name='loss1_sta')
        self._loss2_sta = Mean(name='loss2_sta')
        self._acc1 = SparseCategoricalAccuracy(name='acc1')
        self._acc2 = SparseCategoricalAccuracy(name='acc2')

        # metrics for testing
        self._test_loss1 = Mean(name='test_loss1')
        self._test_loss2 = Mean(name='test_loss2')
        self._test_acc1_train_phones = SparseCategoricalAccuracy(
            name='test_acc1_train_phones')
        self._test_acc2_train_phones = SparseCategoricalAccuracy(
            name='test_acc2_train_phones')
        self._test_acc1 = Accuracy(name='test_acc1')
        self._test_acc2 = Accuracy(name='test_acc2')
        self._test_per1 = PhoneErrorRate(name='test_per1')
        self._test_per2 = PhoneErrorRate(name='test_per2')

        # compose students
        if version == 'mono_directional':
            lstm_types = ['mono_directional', 'mono_directional']
        elif version == 'bidirectional':
            lstm_types = ['bidirectional', 'bidirectional']
        elif version == 'imbalanced':
            lstm_types = ['mono_directional', 'bidirectional']
        else:
            raise ValueError('Invalid student version')
        self.student1 = self._get_student('student1', lstm_types[0])
        self.student2 = self._get_student('student2', lstm_types[1])

        # masking layer (just to use compute_mask and remove padding)
        self.mask = Masking(mask_value=self.padding_value)

    def _get_student(self, name, lstm_type):
        """
        Builds one student network.

        The network is Masking -> n_hidden_layers LSTM layers (wrapped in Bidirectional for 'bidirectional'),
        all returning sequences -> Dense softmax layer with n_classes units.

        :param name: name of the Sequential model
        :param lstm_type: one of 'mono_directional', 'bidirectional'
        :return: the (unbuilt) Sequential model
        :raises ValueError: if lstm_type is not valid
        """
        student = Sequential(name=name)
        student.add(Masking(mask_value=self.padding_value))
        if lstm_type == 'mono_directional':
            for _ in range(self.n_hidden_layers):
                student.add(LSTM(units=self.n_units, return_sequences=True))
        elif lstm_type == 'bidirectional':
            for _ in range(self.n_hidden_layers):
                student.add(
                    Bidirectional(
                        LSTM(units=self.n_units, return_sequences=True)))
        else:
            raise ValueError('Invalid LSTM version')
        student.add(Dense(units=self.n_classes, activation="softmax"))
        return student

    def _noisy_augment(self, x):
        """
        Returns x plus Gaussian noise with standard deviation self.sigma.

        tf.shape() (instead of the static x.shape) keeps this usable with dynamic dimensions, e.g. under
        tf.function; the noise uses the dtype of x so the addition does not mix dtypes.
        """
        return x + tf.random.normal(shape=tf.shape(x), stddev=self.sigma, dtype=x.dtype)

    def call(self, inputs, training=False, student='student1', **kwargs):
        """
        Feed-forwards inputs to one of the students.

        This function is called internally by __call__(). Do not use it directly, use the model as callable. You may
        prefer to use pad_and_predict() instead of this, because it pads the sequences and splits in batches. For a big
        dataset, it is strongly suggested that you use pad_and_predict().

        :param inputs: tensor of shape (batch_size, n_frames, n_features)
        :param training: boolean, whether the call is in inference mode or training mode
        :param student: one of 'student1', 'student2'
        :return: tensor of shape (batch_size, n_frames, n_classes), softmax activations (probabilities)
        :raises ValueError: if student is neither 'student1' nor 'student2'
        """
        if student == 'student1':
            return self.student1(inputs, training=training)
        elif student == 'student2':
            return self.student2(inputs, training=training)
        else:
            raise ValueError('Invalid student')

    def build(self, input_shape):
        """Builds the model and both students for inputs of shape input_shape."""
        super(DualStudent, self).build(input_shape)
        self.student1.build(input_shape)
        self.student2.build(input_shape)

    def train(self,
              x_labeled,
              x_unlabeled,
              y_labeled,
              x_val=None,
              y_val=None,
              n_epochs=10,
              batch_size=32,
              shuffle=True,
              evaluation_mapping=None,
              logs_path=None,
              checkpoints_path=None,
              initial_epoch=0,
              seed=None):
        """
        Trains the students with both labeled and unlabeled data (semi-supervised learning).

        :param x_labeled: numpy array of numpy arrays (n_frames, n_features), features corresponding to y_labeled.
            'n_frames' can vary, padding is added to make x_labeled a tensor.
        :param x_unlabeled: numpy array of numpy arrays of shape (n_frames, n_features), features without labels.
            'n_frames' can vary, padding is added to make x_unlabeled a tensor. It is not modified in place.
        :param y_labeled: numpy array of numpy arrays of shape (n_frames,), labels corresponding to x_labeled.
            'n_frames' can vary, padding is added to make y_labeled a tensor.
        :param x_val: like x_labeled, but for validation set
        :param y_val: like y_labeled, but for validation set
        :param n_epochs: integer, number of training epochs
        :param batch_size: integer, batch size (shared between labeled and unlabeled samples proportionally to the
            sizes of the two sets)
        :param shuffle: boolean, whether to shuffle at each epoch or not
        :param evaluation_mapping: dictionary {training label -> test label}, the test phones should be a subset of the
            training phones
        :param logs_path: path where to save logs for TensorBoard
        :param checkpoints_path: path to a directory. If the directory contains checkpoints, the latest checkpoint is
            restored.
        :param initial_epoch: int, initial epoch from which to start the training. It can be used together with
            checkpoints_path to resume the training from a previous run.
        :param seed: seed for the random number generator
        :raises ValueError: if x_labeled or x_unlabeled is empty, or if batch_size is too small to contain at least
            one labeled and one unlabeled sample
        """
        # set seed
        if seed is not None:
            np.random.seed(seed)
            tf.random.set_seed(seed)

        # compute batch sizes (validated up front, before any model building or checkpoint I/O)
        if len(x_labeled) == 0 or len(x_unlabeled) == 0:
            raise ValueError('x_labeled and x_unlabeled must both be non-empty')
        labeled_batch_size = ceil(
            len(x_labeled) / (len(x_unlabeled) + len(x_labeled)) * batch_size)
        unlabeled_batch_size = batch_size - labeled_batch_size
        if labeled_batch_size <= 0 or unlabeled_batch_size <= 0:
            raise ValueError(
                f'batch_size={batch_size} is too small to contain both labeled and unlabeled samples')
        # the number of batches per epoch is limited by whichever set runs out first
        n_batches = min(ceil(len(x_unlabeled) / unlabeled_batch_size),
                        ceil(len(x_labeled) / labeled_batch_size))

        # show summary
        self.build(input_shape=(None, ) + x_labeled[0].shape)
        self.student1.summary()
        self.student2.summary()

        # setup for logs
        train_summary_writer = None
        if logs_path is not None:
            train_summary_writer = tf.summary.create_file_writer(logs_path)

        # setup for checkpoints
        checkpoint = None
        if checkpoints_path is not None:
            checkpoint = tf.train.Checkpoint(optimizer=self.optimizer,
                                             model=self)
            checkpoint_path = tf.train.latest_checkpoint(checkpoints_path)
            if checkpoint_path is not None:
                checkpoint.restore(checkpoint_path)
            checkpoint_path = Path(checkpoints_path) / 'ckpt'
            checkpoint_path = str(checkpoint_path)

        # shuffle a private copy of the unlabeled set, so that the caller's array is not reordered in place
        if shuffle:
            x_unlabeled = x_unlabeled.copy()

        # training loop
        for epoch in trange(initial_epoch, n_epochs, desc='epochs'):
            # update lambda1 (consistency weight) and lambda2 (stabilization weight) according to the schedule
            self._lambda1 = self.consistency_scale * self.schedule_fn(
                epoch, self.schedule_length)
            self._lambda2 = self.stabilization_scale * self.schedule_fn(
                epoch, self.schedule_length)

            # shuffle training set
            if shuffle:
                indices = np.arange(
                    len(x_labeled)
                )  # get indices to shuffle coherently features and labels
                np.random.shuffle(indices)
                x_labeled = x_labeled[indices]
                y_labeled = y_labeled[indices]
                np.random.shuffle(x_unlabeled)

            for i in trange(n_batches, desc='batches'):
                # select batch
                x_labeled_batch = select_batch(x_labeled, i,
                                               labeled_batch_size)
                x_unlabeled_batch = select_batch(x_unlabeled, i,
                                                 unlabeled_batch_size)
                y_labeled_batch = select_batch(y_labeled, i,
                                               labeled_batch_size)

                # pad batch (labels are padded with -1)
                x_labeled_batch = pad_sequences(x_labeled_batch,
                                                padding='post',
                                                value=self.padding_value,
                                                dtype='float32')
                x_unlabeled_batch = pad_sequences(x_unlabeled_batch,
                                                  padding='post',
                                                  value=self.padding_value,
                                                  dtype='float32')
                y_labeled_batch = pad_sequences(y_labeled_batch,
                                                padding='post',
                                                value=-1)

                # convert to tensors
                x_labeled_batch = tf.convert_to_tensor(x_labeled_batch)
                x_unlabeled_batch = tf.convert_to_tensor(x_unlabeled_batch)
                y_labeled_batch = tf.convert_to_tensor(y_labeled_batch)

                # train step
                self._train_step(x_labeled_batch, x_unlabeled_batch,
                                 y_labeled_batch)

            # put metrics in dictionary (easy management)
            train_metrics = {
                self._loss1.name: self._loss1.result(),
                self._loss2.name: self._loss2.result(),
                self._loss1_cls.name: self._loss1_cls.result(),
                self._loss2_cls.name: self._loss2_cls.result(),
                self._loss1_con.name: self._loss1_con.result(),
                self._loss2_con.name: self._loss2_con.result(),
                self._loss1_sta.name: self._loss1_sta.result(),
                self._loss2_sta.name: self._loss2_sta.result(),
                self._acc1.name: self._acc1.result(),
                self._acc2.name: self._acc2.result(),
            }
            metrics = {'train': train_metrics}

            # test on validation set
            if x_val is not None and y_val is not None:
                val_metrics = self.test(x_val,
                                        y_val,
                                        evaluation_mapping=evaluation_mapping)
                metrics['val'] = val_metrics

            # print metrics
            for dataset, metrics_ in metrics.items():
                print(f'Epoch {epoch + 1} - ', dataset, ' - ', sep='', end='')
                for k, v in metrics_.items():
                    print(f'{k}: {v}, ', end='')
                print()

            # save logs
            if train_summary_writer is not None:
                with train_summary_writer.as_default():
                    for dataset, metrics_ in metrics.items():
                        for k, v in metrics_.items():
                            tf.summary.scalar(k, v, step=epoch)

            # save checkpoint
            if checkpoint is not None:
                checkpoint.save(file_prefix=checkpoint_path)

            # reset metrics
            self._loss1.reset_states()
            self._loss2.reset_states()
            self._loss1_cls.reset_states()
            self._loss2_cls.reset_states()
            self._loss1_con.reset_states()
            self._loss2_con.reset_states()
            self._loss1_sta.reset_states()
            self._loss2_sta.reset_states()
            self._acc1.reset_states()
            self._acc2.reset_states()

    # If you want to use graph execution, pad the whole dataset externally and uncomment the decorator below.
    # If you uncomment the decorator without padding the dataset, the graph will be compiled for each batch,
    # because train() pads at batch level and so the batches have different shapes. This would result in worse
    # performance compared to eager execution.

    # @tf.function
    def _train_step(self, x_labeled, x_unlabeled, y_labeled):
        """
        Performs one optimization step for both students on a padded batch.

        Each student is updated with its own loss: classification + lambda1 * consistency + lambda2 * stabilization.
        Gradients of each loss are taken only w.r.t. the variables of the corresponding student, so the other
        student's predictions act as fixed targets in the stabilization term.

        :param x_labeled: padded labeled features
        :param x_unlabeled: padded unlabeled features
        :param y_labeled: padded labels (padding value -1, as set by train())
        """
        # noisy augmented batches (TODO: improvement with data augmentation instead of noise)
        B1_labeled = self._noisy_augment(x_labeled)
        B2_labeled = self._noisy_augment(x_labeled)
        B1_unlabeled = self._noisy_augment(x_unlabeled)
        B2_unlabeled = self._noisy_augment(x_unlabeled)

        # compute masks (to remove padding)
        mask_labeled = self.mask.compute_mask(x_labeled)
        mask_unlabeled = self.mask.compute_mask(x_unlabeled)
        y_labeled = y_labeled[mask_labeled]  # remove padding from labels

        # forward pass
        with tf.GradientTape(persistent=True) as tape:
            # predict augmented labeled samples (for classification and consistency constraint)
            prob1_labeled_B1 = self.student1(B1_labeled, training=True)
            prob1_labeled_B2 = self.student1(B2_labeled, training=True)
            prob2_labeled_B1 = self.student2(B1_labeled, training=True)
            prob2_labeled_B2 = self.student2(B2_labeled, training=True)

            # predict augmented unlabeled samples (for consistency and stabilization constraints)
            prob1_unlabeled_B1 = self.student1(B1_unlabeled, training=True)
            prob1_unlabeled_B2 = self.student1(B2_unlabeled, training=True)
            prob2_unlabeled_B1 = self.student2(B1_unlabeled, training=True)
            prob2_unlabeled_B2 = self.student2(B2_unlabeled, training=True)

            # remove padding
            prob1_labeled_B1 = prob1_labeled_B1[mask_labeled]
            prob1_labeled_B2 = prob1_labeled_B2[mask_labeled]
            prob2_labeled_B1 = prob2_labeled_B1[mask_labeled]
            prob2_labeled_B2 = prob2_labeled_B2[mask_labeled]
            prob1_unlabeled_B1 = prob1_unlabeled_B1[mask_unlabeled]
            prob1_unlabeled_B2 = prob1_unlabeled_B2[mask_unlabeled]
            prob2_unlabeled_B1 = prob2_unlabeled_B1[mask_unlabeled]
            prob2_unlabeled_B2 = prob2_unlabeled_B2[mask_unlabeled]

            # compute classification losses
            L1_cls = self._loss_cls(y_labeled, prob1_labeled_B1)
            L2_cls = self._loss_cls(y_labeled, prob2_labeled_B2)

            # concatenate labeled and unlabeled probability predictions (for consistency loss)
            prob1_labeled_unlabeled_B1 = tf.concat(
                [prob1_labeled_B1, prob1_unlabeled_B1], axis=0)
            prob1_labeled_unlabeled_B2 = tf.concat(
                [prob1_labeled_B2, prob1_unlabeled_B2], axis=0)
            prob2_labeled_unlabeled_B1 = tf.concat(
                [prob2_labeled_B1, prob2_unlabeled_B1], axis=0)
            prob2_labeled_unlabeled_B2 = tf.concat(
                [prob2_labeled_B2, prob2_unlabeled_B2], axis=0)

            # compute consistency losses
            L1_con = self._loss_con(prob1_labeled_unlabeled_B1,
                                    prob1_labeled_unlabeled_B2)
            L2_con = self._loss_con(prob2_labeled_unlabeled_B1,
                                    prob2_labeled_unlabeled_B2)

            # prediction
            P1_unlabeled_B1 = tf.argmax(prob1_unlabeled_B1, axis=-1)
            P1_unlabeled_B2 = tf.argmax(prob1_unlabeled_B2, axis=-1)
            P2_unlabeled_B1 = tf.argmax(prob2_unlabeled_B1, axis=-1)
            P2_unlabeled_B2 = tf.argmax(prob2_unlabeled_B2, axis=-1)

            # confidence (probability of predicted class)
            M1_unlabeled_B1 = tf.reduce_max(prob1_unlabeled_B1, axis=-1)
            M1_unlabeled_B2 = tf.reduce_max(prob1_unlabeled_B2, axis=-1)
            M2_unlabeled_B1 = tf.reduce_max(prob2_unlabeled_B1, axis=-1)
            M2_unlabeled_B2 = tf.reduce_max(prob2_unlabeled_B2, axis=-1)

            # stable samples (masks to index probabilities): same prediction under both augmentations and
            # confidence above xi for at least one of them
            R1 = tf.logical_and(
                P1_unlabeled_B1 == P1_unlabeled_B2,
                tf.logical_or(M1_unlabeled_B1 > self.xi,
                              M1_unlabeled_B2 > self.xi))
            R2 = tf.logical_and(
                P2_unlabeled_B1 == P2_unlabeled_B2,
                tf.logical_or(M2_unlabeled_B1 > self.xi,
                              M2_unlabeled_B2 > self.xi))
            R12 = tf.logical_and(R1, R2)

            # stabilities (per-sample MSE between the two augmentations, on samples stable for both students)
            epsilon1 = MSE(prob1_unlabeled_B1[R12], prob1_unlabeled_B2[R12])
            epsilon2 = MSE(prob2_unlabeled_B1[R12], prob2_unlabeled_B2[R12])

            # compute stabilization losses (both stable: the less stable student is pulled towards the other)
            L1_sta = self._loss_sta(
                prob1_unlabeled_B1[R12][epsilon1 > epsilon2],
                prob2_unlabeled_B1[R12][epsilon1 > epsilon2])
            L2_sta = self._loss_sta(
                prob1_unlabeled_B2[R12][epsilon1 < epsilon2],
                prob2_unlabeled_B2[R12][epsilon1 < epsilon2])

            # only one student stable: the unstable one is pulled towards the stable one
            L1_sta += self._loss_sta(
                prob1_unlabeled_B1[tf.logical_and(tf.logical_not(R1), R2)],
                prob2_unlabeled_B1[tf.logical_and(tf.logical_not(R1), R2)])
            L2_sta += self._loss_sta(
                prob1_unlabeled_B2[tf.logical_and(R1, tf.logical_not(R2))],
                prob2_unlabeled_B2[tf.logical_and(R1, tf.logical_not(R2))])

            # compute complete losses
            L1 = L1_cls + self._lambda1 * L1_con + self._lambda2 * L1_sta
            L2 = L2_cls + self._lambda1 * L2_con + self._lambda2 * L2_sta

        # backward pass
        gradients1 = tape.gradient(L1, self.student1.trainable_variables)
        gradients2 = tape.gradient(L2, self.student2.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients1, self.student1.trainable_variables))
        self.optimizer.apply_gradients(
            zip(gradients2, self.student2.trainable_variables))
        del tape  # to release memory (persistent tape)

        # update metrics
        self._loss1.update_state(L1)
        self._loss2.update_state(L2)
        self._loss1_cls.update_state(L1_cls)
        self._loss2_cls.update_state(L2_cls)
        self._loss1_con.update_state(L1_con)
        self._loss2_con.update_state(L2_con)
        self._loss1_sta.update_state(L1_sta)
        self._loss2_sta.update_state(L2_sta)
        self._acc1.update_state(y_labeled, prob1_labeled_B1)
        self._acc2.update_state(y_labeled, prob2_labeled_B2)

    def test(self, x, y, batch_size=32, evaluation_mapping=None):
        """
        Tests the model (both students).

        :param x: numpy array of numpy arrays (n_frames, n_features), features corresponding to y.
            'n_frames' can vary, padding is added to make x a tensor.
        :param y: numpy array of numpy arrays of shape (n_frames,), labels corresponding to x.
            'n_frames' can vary, padding is added to make y a tensor.
        :param batch_size: integer, batch size
        :param evaluation_mapping: dictionary {training label -> test label}, the test phones should be a subset of the
            training phones
        :return: dictionary {metric_name -> value}; the test metrics are reset before returning
        """
        # test batch by batch
        n_batches = ceil(len(x) / batch_size)
        for i in trange(n_batches, desc='test batches'):
            # select batch
            x_batch = select_batch(x, i, batch_size)
            y_batch = select_batch(y, i, batch_size)

            # pad batch (labels are padded with -1)
            x_batch = pad_sequences(x_batch,
                                    padding='post',
                                    value=self.padding_value,
                                    dtype='float32')
            y_batch = pad_sequences(y_batch, padding='post', value=-1)

            # convert to tensors
            x_batch = tf.convert_to_tensor(x_batch)
            y_batch = tf.convert_to_tensor(y_batch)

            # test step
            self._test_step(x_batch, y_batch, evaluation_mapping)

        # put metrics in dictionary (easy management)
        test_metrics = {
            self._test_loss1.name:
            self._test_loss1.result(),
            self._test_loss2.name:
            self._test_loss2.result(),
            self._test_acc1_train_phones.name:
            self._test_acc1_train_phones.result(),
            self._test_acc2_train_phones.name:
            self._test_acc2_train_phones.result(),
            self._test_acc1.name:
            self._test_acc1.result(),
            self._test_acc2.name:
            self._test_acc2.result(),
            self._test_per1.name:
            self._test_per1.result(),
            self._test_per2.name:
            self._test_per2.result(),
        }

        # reset metrics
        self._test_loss1.reset_states()
        self._test_loss2.reset_states()
        self._test_acc1_train_phones.reset_states()
        self._test_acc2_train_phones.reset_states()
        self._test_acc1.reset_states()
        self._test_acc2.reset_states()
        self._test_per1.reset_states()
        self._test_per2.reset_states()

        return test_metrics

    # @tf.function      # see note above _train_step()
    def _test_step(self, x, y, evaluation_mapping):
        """
        Evaluates both students on one padded batch and updates the test metrics.

        :param x: padded features of the batch
        :param y: padded labels of the batch (padding value -1, as set by test())
        :param evaluation_mapping: optional dictionary {training label -> test label}; when given, it is applied to
            labels and predictions before updating phone error rate and test-phone accuracy
        """
        # compute mask (to remove padding)
        mask = self.mask.compute_mask(x)

        # forward pass
        y_prob1_train_phones = self.student1(x, training=False)
        y_prob2_train_phones = self.student2(x, training=False)
        y_pred1_train_phones = tf.argmax(y_prob1_train_phones, axis=-1)
        y_pred2_train_phones = tf.argmax(y_prob2_train_phones, axis=-1)
        y_train_phones = tf.identity(y)

        # map labels to set of test phones
        if evaluation_mapping is not None:
            def map_to_test_phones(labels):
                # tf.numpy_function only accepts tensors as inputs, so the mapping dict is captured by this closure;
                # the result is cast to float32 to match the declared output dtype
                return np.asarray(map_labels(labels, evaluation_mapping), dtype=np.float32)

            # a single (non-list) Tout makes tf.numpy_function return a tensor rather than a 1-element list
            y = tf.numpy_function(map_to_test_phones, [y_train_phones], tf.float32)
            y_pred1 = tf.numpy_function(map_to_test_phones, [y_pred1_train_phones], tf.float32)
            y_pred2 = tf.numpy_function(map_to_test_phones, [y_pred2_train_phones], tf.float32)
        else:
            y = y_train_phones
            y_pred1 = y_pred1_train_phones
            y_pred2 = y_pred2_train_phones

        # update phone error rate (on padded sequences, with the padding mask)
        self._test_per1.update_state(y, y_pred1, mask)
        self._test_per2.update_state(y, y_pred2, mask)

        # remove padding
        y_pred1 = y_pred1[mask]
        y_pred2 = y_pred2[mask]
        y_prob1_train_phones = y_prob1_train_phones[mask]
        y_prob2_train_phones = y_prob2_train_phones[mask]
        y_train_phones = y_train_phones[mask]
        y = y[mask]

        # compute loss
        loss1 = self._loss_cls(y_train_phones, y_prob1_train_phones)
        loss2 = self._loss_cls(y_train_phones, y_prob2_train_phones)

        # update loss
        self._test_loss1.update_state(loss1)
        self._test_loss2.update_state(loss2)

        # update accuracy using training phones
        self._test_acc1_train_phones.update_state(y_train_phones,
                                                  y_prob1_train_phones)
        self._test_acc2_train_phones.update_state(y_train_phones,
                                                  y_prob2_train_phones)

        # update accuracy using test phones
        self._test_acc1.update_state(y, y_pred1)
        self._test_acc2.update_state(y, y_pred2)
Exemplo n.º 10
0
class SBVAT(SupervisedModel):
    r"""
        Implementation of sample-based Batch Virtual Adversarial Training Graph Convolutional Networks (SBVAT).
        [Batch Virtual Adversarial Training for Graph Convolutional Networks](https://arxiv.org/pdf/1902.09192)
        Tensorflow 1.x implementation: https://github.com/thudzj/BVAT

        Arguments:
        ----------
            adj: `scipy.sparse.csr_matrix` (or `csc_matrix`) with shape (N, N)
                The input `symmetric` adjacency matrix, where `N` is the number of nodes
                in graph.
            features: `np.array` with shape (N, F)
                The input node feature matrix, where `F` is the dimension of node features.
            labels: `np.array` with shape (N,)
                The ground-truth labels for all nodes in graph.
            n_samples (Positive integer, optional):
                The number of sampled subset nodes in the graph where the shortest path
                length between them is at least 4. (default :obj: `100`)
            normalize_rate (Float scalar, optional):
                The normalize rate for adjacency matrix `adj`. (default: :obj:`-0.5`,
                i.e., math:: \hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}})
                If `None`, the adjacency matrix is used without normalization.
            normalize_features (Boolean, optional):
                Whether to use row-normalize for node feature matrix.
                (default :obj: `True`)
            device (String, optional):
                The device where the model is running on. You can specify `CPU` or `GPU`
                for the model. (default: :obj: `CPU:0`, i.e., the model is running on
                the 0-th device `CPU`)
            seed (Positive integer, optional):
                Used in combination with `tf.random.set_seed & np.random.seed & random.seed`
                to create a reproducible sequence of tensors across multiple calls.
                (default :obj: `None`, i.e., using random seed)
            name (String, optional):
                Name for the model, forwarded to the base class through `**kwargs`.
                (default: name of class)

    """

    def __init__(self, adj, features, labels, n_samples=100,
                 normalize_rate=-0.5, normalize_features=True, device='CPU:0', seed=None, **kwargs):

        super().__init__(adj, features, labels, device=device, seed=seed, **kwargs)

        self.normalize_rate = normalize_rate
        self.normalize_features = normalize_features
        self.preprocess(adj, features)
        # Number of nodes sampled for the adversarial mask in each sequence.
        self.n_samples = n_samples

    def preprocess(self, adj, features):
        """Normalize the inputs, precompute neighborhoods and move inputs to tensors.

        Sets `self.neighbors` (the result of `find_4o_nbrs` for every node, computed
        on the sparsity structure of the possibly-normalized `adj`) and
        `self.features` / `self.adj` as tensors on `self.device`.
        """
        if self.normalize_rate is not None:
            adj = self._normalize_adj(adj, self.normalize_rate)

        if self.normalize_features:
            features = self._normalize_features(features)

        # Uses the CSR structure (indices/indptr) of `adj`; presumably returns the
        # nodes within 4 hops of each node, used to keep sampled nodes far apart.
        self.neighbors = list(find_4o_nbrs(adj.indices, adj.indptr, np.arange(self.n_nodes)))

        with self.device:
            self.features, self.adj = self._to_tensor([features, adj])

    def build(self, hidden_layers=None, activations=None, dropout=0.5,
              learning_rate=0.01, l2_norm=5e-4, p1=1., p2=1.,
              n_power_iterations=1, epsilon=0.03, xi=1e-6):
        """Build the GCN, its metrics and optimizer, and store the VAT hyper-parameters.

        Parameters:
            hidden_layers: units of each hidden graph-convolution layer
                (default `[16]`). One layer is built per (units, activation) pair.
            activations: activation of each hidden layer (default `['relu']`).
            dropout: dropout rate applied before every graph-convolution layer.
            learning_rate: learning rate of the Adam optimizer.
            l2_norm: L2 regularization factor on the hidden layers' kernels.
            p1: weight (alpha) of the virtual adversarial loss.
            p2: weight (beta) of the entropy loss.
            n_power_iterations: number of power iterations used to estimate the
                adversarial direction.
            epsilon: norm length of the (virtual) adversarial perturbation.
            xi: small constant for the finite-difference step.
        """
        # `None` sentinels avoid shared mutable default lists.
        if hidden_layers is None:
            hidden_layers = [16]
        if activations is None:
            activations = ['relu']

        with self.device:

            x = Input(batch_shape=[self.n_nodes, self.n_features], dtype=tf.float32, name='features')
            adj = Input(batch_shape=[self.n_nodes, self.n_nodes], dtype=tf.float32, sparse=True, name='adj_matrix')
            index = Input(batch_shape=[None], dtype=tf.int32, name='index')

            # One regularized hidden layer per (units, activation) pair, followed by
            # an output layer producing one logit per class.
            self.GCN_layers = [
                GraphConvolution(units,
                                 activation=activation,
                                 kernel_regularizer=regularizers.l2(l2_norm))
                for units, activation in zip(hidden_layers, activations)
            ]
            self.GCN_layers.append(GraphConvolution(self.n_classes))
            self.dropout_layer = Dropout(dropout)

            logit = self.propagation(x, adj)
            output = tf.gather(logit, index)
            output = Softmax()(output)
            model = Model(inputs=[x, adj, index], outputs=output)

            self.model = model
            self.train_metric = SparseCategoricalAccuracy()
            self.test_metric = SparseCategoricalAccuracy()
            # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
            self.optimizer = Adam(learning_rate=learning_rate)
            self.built = True

        self.p1 = p1  # Alpha
        self.p2 = p2  # Beta
        self.xi = xi  # Small constant for finite difference
        self.epsilon = epsilon  # Norm length for (virtual) adversarial training
        self.n_power_iterations = n_power_iterations  # Number of power iterations

    def propagation(self, x, adj, training=True):
        """Run dropout + graph convolution for every layer and return logits for all nodes."""
        h = x
        for layer in self.GCN_layers:
            h = self.dropout_layer(h, training=training)
            h = layer([h, adj])
        return h

    @tf.function
    def do_train_forward(self, sequence):
        """Run one optimization step per batch of `sequence`.

        The loss is the cross-entropy on the indexed nodes plus `p1` times the
        virtual adversarial loss and `p2` times the entropy term.

        Returns:
            (loss of the last batch, training accuracy accumulated over `sequence`)
        """
        with self.device:
            self.train_metric.reset_states()

            for inputs, labels in sequence:
                x, adj, index, adv_mask = inputs
                with tf.GradientTape() as tape:
                    logit = self.propagation(x, adj)
                    output = tf.gather(logit, index)
                    output = softmax(output)

                    loss = tf.reduce_mean(sparse_categorical_crossentropy(labels, output))
                    entropy_loss = entropy_y_x(logit)
                    vat_loss = self.virtual_adversarial_loss(x, adj, logit=logit, adv_mask=adv_mask)
                    loss += self.p1 * vat_loss + self.p2 * entropy_loss

                    self.train_metric.update_state(labels, output)

                trainable_variables = self.model.trainable_variables
                gradients = tape.gradient(loss, trainable_variables)
                self.optimizer.apply_gradients(zip(gradients, trainable_variables))

        return loss, self.train_metric.result()

    @tf.function
    def do_test_forward(self, sequence):
        """Evaluate `sequence` without dropout.

        Returns:
            (cross-entropy of the last batch, accuracy accumulated over `sequence`)
        """
        with self.device:
            self.test_metric.reset_states()

            for inputs, labels in sequence:
                x, adj, index, _ = inputs
                logit = self.propagation(x, adj, training=False)
                output = tf.gather(logit, index)
                output = softmax(output)
                loss = tf.reduce_mean(sparse_categorical_crossentropy(labels, output))
                self.test_metric.update_state(labels, output)

        return loss, self.test_metric.result()

    def do_forward(self, sequence, training=True):
        """Dispatch to the train or test forward pass and return `(loss, accuracy)` as NumPy values."""
        if training:
            loss, accuracy = self.do_train_forward(sequence)
        else:
            loss, accuracy = self.do_test_forward(sequence)

        return loss.numpy(), accuracy.numpy()

    def virtual_adversarial_loss(self, x, adj, logit, adv_mask):
        """Compute the virtual adversarial loss restricted by `adv_mask`.

        Starting from random noise, the perturbation direction is refined by
        `n_power_iterations` steps: each scales the normalized direction by `xi`
        and replaces it with the gradient of the KL divergence between the clean
        and perturbed logits. The final direction is scaled to norm `epsilon`
        and the KL divergence against the (gradient-stopped) clean logits is returned.
        """
        d = tf.random.normal(shape=tf.shape(x))

        for _ in range(self.n_power_iterations):
            d = get_normalized_vector(d) * self.xi
            logit_p = logit
            with tf.GradientTape() as tape:
                tape.watch(d)
                logit_m = self.propagation(x + d, adj)
                dist = kl_divergence_with_logit(logit_p, logit_m, adv_mask)
            grad = tape.gradient(dist, d)
            d = tf.stop_gradient(grad)

        r_vadv = get_normalized_vector(d) * self.epsilon
        # Treat the clean prediction as a fixed target.
        logit_p = tf.stop_gradient(logit)
        logit_m = self.propagation(x + r_vadv, adj)
        loss = kl_divergence_with_logit(logit_p, logit_m, adv_mask)
        return tf.identity(loss)

    def train_sequence(self, index):
        """Build the training sequence for the nodes in `index` (nodes resampled by the sequence)."""
        index = self._check_and_convert(index)
        labels = self.labels[index]

        with self.device:
            sequence = NodeSampleSequence([self.features, self.adj, index], labels,
                                          neighbors=self.neighbors,
                                          n_samples=self.n_samples)

        return sequence

    def test_sequence(self, index):
        """Build the evaluation sequence for the nodes in `index` (with `resample=False`)."""
        index = self._check_and_convert(index)
        labels = self.labels[index]

        with self.device:
            sequence = NodeSampleSequence([self.features, self.adj, index], labels,
                                          neighbors=self.neighbors,
                                          n_samples=self.n_samples,
                                          resample=False)

        return sequence

    def predict(self, index):
        """Return class probabilities (softmax over logits) for the nodes in `index`.

        NOTE(review): the value from the last batch yielded by the sequence is
        returned; this assumes `NodeSampleSequence` yields a single batch — confirm.
        """
        super().predict(index)
        index = self._check_and_convert(index)

        with self.device:
            sequence = NodeSampleSequence([self.features, self.adj, index], None,
                                          neighbors=self.neighbors,
                                          n_samples=self.n_samples,
                                          resample=False)
            for inputs, _ in sequence:
                x, adj, index, adv_mask = inputs
                output = self.propagation(x, adj, training=False)
                logit = softmax(tf.gather(output, index))

        return logit.numpy()
Exemplo n.º 11
0
    # Create masks
    encoder_padding_mask = maskHandler.padding_mask(input_language)
    decoder_padding_mask = maskHandler.padding_mask(input_language)

    look_ahead_mask = maskHandler.look_ahead_mask(tf.shape(target_language)[1])
    decoder_target_padding_mask = maskHandler.padding_mask(target_language)
    combined_mask = tf.maximum(decoder_target_padding_mask, look_ahead_mask)

    # Run training step
    with tf.GradientTape() as tape:
        predictions, _ = transformer(input_language, target_input, True,
                                     encoder_padding_mask, combined_mask,
                                     decoder_padding_mask)
        total_loss = padded_loss_function(tartet_output, predictions)

    gradients = tape.gradient(total_loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
    training_loss(total_loss)
    training_accuracy(tartet_output, predictions)


# Train for 20 epochs; the running metrics are reset at the start of each epoch
# so the printed values cover that epoch only.
for epoch in tqdm(range(20)):
    training_loss.reset_states()
    training_accuracy.reset_states()

    for input_language, target_language in data_container.train_data:
        train_step(input_language, target_language)

    # Report the same metric objects that `train_step` updates.
    print('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(
        epoch, training_loss.result(), training_accuracy.result()))
Exemplo n.º 12
0
def train_eagerly(model: Model, train_dataset, val_dataset, optimizer,
                  epochs: int, log_dir: str):
    """Trains the model for a fixed number of epochs (iterations on a dataset).

    Each epoch runs `train_step` over every batch of `train_dataset`, then
    `val_step` over every batch of `val_dataset`. It logs the epoch's loss and
    accuracy to TensorBoard (under `log_dir/train` and `log_dir/validation`),
    prints them, and resets the metrics.

    Args:
        model: A Keras model instance.
        train_dataset: A `tf.data` dataset. Should return a tuple
            of `(inputs, labels)`
        val_dataset: A `tf.data` dataset on which to evaluate
            the loss and any metrics at the end of each epoch.
            Should return a tuple of `(inputs, labels)`
        optimizer: A Keras optimizer instance.
        epochs: Number of epochs to train the model.
        log_dir: Path to the directory where TensorBoard logs will be written.

    Returns:
        model: A Keras trained model.
    """

    criterion = SparseCategoricalCrossentropy()

    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(log_dir, 'train'))
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(log_dir, 'validation'))

    # Defines metrics for logging to TensorBoard.
    train_loss = Mean('train_loss', dtype=tf.float32)
    train_accuracy = SparseCategoricalAccuracy('train_accuracy')
    val_loss = Mean('val_loss', dtype=tf.float32)
    val_accuracy = SparseCategoricalAccuracy('val_accuracy')

    for epoch in range(epochs):
        # For human-readability purposes, epoch logging starts from 1 rather
        # than 0, so the last epoch is reported as `epochs/epochs`.
        print(f'Epoch: {epoch + 1}/{epochs}')

        # Training
        for inputs, labels in train_dataset:
            train_step(model, optimizer, criterion, inputs, labels, train_loss,
                       train_accuracy)
        with train_summary_writer.as_default():
            tf.summary.scalar('loss', train_loss.result(), step=epoch)
            tf.summary.scalar('accuracy', train_accuracy.result(), step=epoch)

        # Validation
        for inputs, labels in val_dataset:
            val_step(model, criterion, inputs, labels, val_loss, val_accuracy)
        with val_summary_writer.as_default():
            tf.summary.scalar('loss', val_loss.result(), step=epoch)
            tf.summary.scalar('accuracy', val_accuracy.result(), step=epoch)

        # Accuracies are printed as percentages (TensorBoard gets raw fractions).
        template = 'loss: {:.4f} - accuracy: {:.4f} - ' \
                   'val_loss: {:.4f} - val_accuracy: {:.4f}'
        print(
            template.format(train_loss.result(),
                            train_accuracy.result() * 100, val_loss.result(),
                            val_accuracy.result() * 100))

        # Reset metrics every epoch
        train_loss.reset_states()
        val_loss.reset_states()
        train_accuracy.reset_states()
        val_accuracy.reset_states()

    return model
Exemplo n.º 13
0
# Loss, optimizer and running metrics for the hand-written training loop below.
loss_object = SparseCategoricalCrossentropy()
optimizer = SGD(learning_rate=1)

train_loss = Mean()
train_acc = SparseCategoricalAccuracy()

EPOCHS = 10

for epoch in range(EPOCHS):
    for x, y in train_ds:
        with tf.GradientTape() as tape:
            predictions = model(x)
            loss = loss_object(y, predictions)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss(loss)
        train_acc(y, predictions)
    print('Epoch: ', epoch + 1)
    template = 'Train Loss: {:.4f}\t Train Accuracy: {:.2f}%\n'
    print(template.format(train_loss.result(), train_acc.result() * 100))

    # Reset so each epoch reports only its own batches.
    train_loss.reset_states()
    train_acc.reset_states()

# Same task with the built-in Keras training loop. `model1` gets its own
# optimizer: the one above already holds state for `model`'s variables, and
# recent Keras optimizers refuse to update a different set of variables.
optimizer1 = SGD(learning_rate=1)
model1 = Sequential()
model1.add(Dense(units=2, activation='softmax'))
model1.compile(loss=loss_object, optimizer=optimizer1, metrics=['accuracy'])
model1.fit(train_x_noise, train_y, epochs=10)
Exemplo n.º 14
0
        train_acc_metric.update_state(y, student_pred)

        # calculate gradients
        grads = tape.gradient(loss, stud_model.weights)

        # gradient descent
        optimizer.apply_gradients(zip(grads, stud_model.trainable_weights))

        # print some info
        print("Epoch {}, step {}, loss {:5f}".format(ep, step, loss))

    # get result of train accuracy metric
    print("Train accuracy is {:4f}".format(train_acc_metric.result()))

    # reset metric
    train_acc_metric.reset_states()

    # saving model
    stud_model.save(
        str(pathlib.Path(__file__).parent.absolute()) +
        "/saved_model_distillation")

    for x_val, y_val in val_ds:
        # forward pass of student model
        student_val_pred = stud_model(x_val, training=False)
        # assert stud_model.trainable == False, 'Student model should not be trainable in val'

        # update val accuracy metric
        val_acc_metric.update_state(y_val, student_val_pred)

    # get result of train accuracy metric
class Model_Trainer:
    """Training wrapper for TensorFlow/Keras models.

    Allows a predefined model to be trained step by step while tracking
    loss/accuracy metrics and, optionally, the most recent gradients. It can
    also save checkpoints and write TensorBoard summaries.

    NOTE: No longer in use -- `model.fit` provides a significant training
    speed-up; the features below are meant to be replaced with TensorFlow
    callbacks.

    Please ensure that `model_id` is unique: it determines the path of all
    checkpoints and summaries written for the model.
    """
    def __init__(self,
                 model,
                 model_id,
                 lr=1e-4,
                 optimizer=None,
                 data_augmentation=None):
        """
        Parameters
        ----------

        model: tensorflow.keras.Model
        model_id : string
            An identifying string used in saving model metrics.
        lr : float, tensorflow.keras.optimizers.schedules
            If using the default optimizer, this is the lr used in the Adam optimizer.
            This value is ignored if an optimizer is passed to the trainer.
        optimizer : tensorflow.keras.optimizers
            A pre-defined optimizer used in training the neural network
        data_augmentation : tensorflow.keras.Sequential
            A tensorflow model used to perform data augmentation during training.
            See here: https://www.tensorflow.org/tutorials/images/data_augmentation#use_keras_preprocessing_layers
            NOTE(review): stored here but not applied by `train_step`.
        """

        self.lr = lr

        self.model = model
        self.init_loss()

        # Can optionally pass a separate optimizer.
        if optimizer is not None:
            self.optimizer = optimizer
        else:
            self.init_optimizer()

        if data_augmentation is not None:
            self.is_data_augmentation = True
            self.data_augmentation = data_augmentation
        else:
            self.is_data_augmentation = False
            self.data_augmentation = None

        # Used to save the parameters of the model at a given point of time.
        self.checkpoint = tf.train.Checkpoint(model=self.model)
        self.checkpoint_path = (self.model.__class__.__name__ + "/" +
                                model_id + "/training_checkpoints")

        self.summary_path = (self.model.__class__.__name__ + "/" + model_id +
                             "/summaries/")
        self.summary_writer = tf.summary.create_file_writer(self.summary_path)

        # Most recent gradients, filled by `train_step(..., track_gradient=True)`.
        self.gradients = None

    def init_loss(self):
        """Initialize the loss function and the running metrics tracked over training."""
        self.loss_function = SparseCategoricalCrossentropy()

        self.train_loss = Mean(name="train_loss")
        self.train_accuracy = SparseCategoricalAccuracy(name="train_accuracy")

        self.test_loss = Mean(name="test_loss")
        self.test_accuracy = SparseCategoricalAccuracy(name="test_accuracy")

    def init_optimizer(self):
        """Create the default Adam optimizer using `self.lr`."""
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                                  epsilon=1e-8)

    def train_step(self, images, labels, track_gradient=False):
        """Take a single training step on one batch and update the training metrics.

        Args:
            images: input batch fed to the model.
            labels: integer class labels for the batch.
            track_gradient: if True, store the computed gradients in `self.gradients`.

        Returns:
            (running mean training loss, running training accuracy in percent)
        """
        with tf.GradientTape() as gtape:
            predictions = self.model(images, training=True)
            loss = self.loss_function(labels, predictions)

        gradients = gtape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.model.trainable_variables))

        if track_gradient:
            # Keep the latest gradients for later inspection.
            self.gradients = gradients

        # Track model Performance
        self.train_loss(loss)
        self.train_accuracy(labels, predictions)

        return self.train_loss.result(), self.train_accuracy.result() * 100

    def test_step(self, data_set, labels=None):
        """Evaluate the model on one batch of test data and update the test metrics.

        Args:
            data_set: either an `(images, labels)` pair, or the images themselves
                when `labels` is given separately.
            labels: optional labels for `data_set`.

        Returns:
            (running mean test loss, running test accuracy in percent)
        """
        if labels is None:
            images, labels = data_set
        else:
            images = data_set

        # Calling the model directly (inference mode) is the per-batch idiom;
        # `predict` runs its own batching loop and is meant for large inputs.
        predictions = self.model(images, training=False)
        test_loss = self.loss_function(labels, predictions)

        self.test_loss(test_loss)
        self.test_accuracy(labels, predictions)

        return self.test_loss.result(), self.test_accuracy.result() * 100

    def reset(self):
        """Reset all training and test metrics."""
        self.train_loss.reset_states()
        self.train_accuracy.reset_states()

        self.test_loss.reset_states()
        self.test_accuracy.reset_states()

    def model_checkpoint(self):
        """Save a checkpoint of the model and return its path.

        Files are written under the prefix `self.checkpoint_path`
        (`<ModelClass>/<model_id>/training_checkpoints`), suffixed by the
        checkpoint's save counter.
        """
        save_path = self.checkpoint.save(self.checkpoint_path)
        return save_path

    def log_metrics(self):
        """Log the current metric values with the TensorBoard summary writer.

        The summary step is the checkpoint's save counter, i.e. the number of
        checkpoints saved so far.
        """
        step = self.checkpoint.save_counter

        with self.summary_writer.as_default():
            tf.summary.scalar("Train Loss",
                              self.train_loss.result(),
                              step=step)
            tf.summary.scalar("Train Accuracy",
                              self.train_accuracy.result(),
                              step=step)
            tf.summary.scalar("Test Loss", self.test_loss.result(), step=step)
            tf.summary.scalar("Test Accuracy",
                              self.test_accuracy.result(),
                              step=step)
def main():
    # load data from the data files
    jpn_data = get_data(jpn_txt_path)
    en_data = get_data(en_txt_path)
    #train_jpn, val_jpn, train_en, val_en = train_test_split(jpn_data,
    #                                                        en_data,
    #                                                        test_size=TR_TE_RATIO)
    JPN_MAX_LEN = get_max_len(jpn_data)
    EN_MAX_LEN = get_max_len(en_data)
    # include [BOS] and [EOS] to each max len above
    JPN_MAX_LEN += 2
    EN_MAX_LEN += 2

    test_jpn_data = [
        "今日は夜ごはん何にしようかな?", "ここ最近暑い日がずっと続きますね。", "来年は本当にオリンピックが開催されるでしょうか?",
        "将来の夢はエンジニアになることです。", "子供のころはあの公園でたくさん遊んだなー。", "今日は早く帰りたいな。",
        "明日は父の日だ。", "試験勉強はなかなか大変です。", "来年はおいしいお店に行きたいです。",
        "あそこの家にはまだ誰か住んでいますか?"
    ]

    #test_en_data = [[""],
    #                [""],
    #                [""],
    #                [""],
    #                [""]]

    # preprocess for the train dataset
    train_dataset = tf.data.Dataset.from_tensor_slices((jpn_data, en_data))
    train_dataset = train_dataset.map(tf_encode)
    train_dataset = train_dataset.cache()
    train_dataset = train_dataset.shuffle(
        len(jpn_data)).padded_batch(BATCH_SIZE)
    train_dataset = train_dataset.prefetch(AUTOTUNE)
    ## preprocess for the validation dataset
    #val_dataset = tf.data.Dataset.from_tensor_slices((val_jpn, val_en))
    #val_dataset = val_dataset.map(tf_encode)
    #val_dataset = val_dataset.padded_batch(BATCH_SIZE)
    # preprocess for the test data
    #test_dataset = tf.data.Dataset.from_tensor_slices((test_jpn_data, test_en_data))
    #test_dataset = test_dataset.map(tf_encode)
    #test_dataset = test_dataset.cache()
    #test_dataset = test_dataset.padded_batch(len(test_jpn_data))
    #test_dataset = test_dataset.prefetch(AUTOTUNE)

    # instantiate the Transformer model
    transformer = Transformer(num_layers=num_layers,
                              d_model=d_model,
                              num_heads=num_heads,
                              dff=dff,
                              input_vocab_size=jpn_vocab_size,
                              target_vocab_size=en_vocab_size,
                              pe_input=JPN_MAX_LEN,
                              pe_target=EN_MAX_LEN)
    # set learning rate, optimizer, loss and matrics
    learning_rate = CustomSchedule(d_model)
    optimizer = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    loss_object = SparseCategoricalCrossentropy(from_logits=True,
                                                reduction="none")

    def loss_function(label, pred):
        mask = tf.math.logical_not(tf.math.equal(label, 0))
        loss_ = loss_object(label, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        return tf.reduce_sum(loss_) / tf.reduce_sum(mask)

    train_loss = Mean(name="train_loss")
    train_accuracy = SparseCategoricalAccuracy(name="train_accuracy")
    """
    The @tf.function trace-compiles train_step into a TF graph for faster
    execution. The function specializes to the precise shape of the argument
    tensors. To avoid re-tracing due to the variable sequence lengths or
    variable batch sizes(usually the last batch is smaller), use input_signature
    to specify more generic shapes.
    """
    train_step_signature = [
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
        tf.TensorSpec(shape=(None, None), dtype=tf.int64)
    ]

    @tf.function(input_signature=train_step_signature)
    def train_step(inp, tar):
        tar_inp = tar[:, :-1]
        tar_label = tar[:, 1:]
        training = True

        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
            inp, tar_inp)
        with tf.GradientTape() as tape:
            predictions, _ = transformer(inp, tar_inp, training,
                                         enc_padding_mask, combined_mask,
                                         dec_padding_mask)
            loss = loss_function(tar_label, predictions)

        gradients = tape.gradient(loss, transformer.trainable_variables)
        optimizer.apply_gradients(
            zip(gradients, transformer.trainable_variables))

        train_loss(loss)
        train_accuracy(tar_label, predictions)

    # set the checkpoint and the checkpoint manager
    ckpt = tf.train.Checkpoint(epoch=tf.Variable(0),
                               transformer=transformer,
                               optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=5)
    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print("Latest checkpoint restored.")

    # set up summary writers
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.join(log_path, current_time, "train")
    #test_log_dir = os.path.join(log_path, current_time, "validation")
    summary_writer = tf.summary.create_file_writer(log_dir)
    #test_summary_writer = tf.summary.create_file_writer(test_log_dir)

    for ckpt.epoch in range(EPOCHS):
        start = time.time()
        ckpt.epoch.assign_add(1)
        train_loss.reset_states()
        train_accuracy.reset_states()

        # inp: Japanese, tar: English
        for (batch, (inp, tar)) in enumerate(train_dataset):
            train_step(inp, tar)

            if batch % 100 == 0:
                print("Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}".format(
                    ckpt.epoch, batch, train_loss.result(),
                    train_accuracy.result()))

        # output the training log for every epoch
        print("Epoch {} Loss {:.4f} Accuracy {:.4f}".format(
            ckpt.epoch, train_loss.result(), train_accuracy.result()))
        print("Time taken for 1 epoch: {:.3f} secs\n".format(time.time() -
                                                             start))

        # check how the model performs for every epoch
        test_summary_log = test_translate(test_jpn_data, EN_MAX_LEN,
                                          transformer)

        with summary_writer.as_default():
            tf.summary.scalar("loss", train_loss.result(), step=ckpt.epoch)
            tf.summary.scalar("accuracy",
                              train_accuracy.result(),
                              step=ckpt.epoch)
            tf.summary.text("test_text", test_summary_log, step=ckpt.epoch)

        if (ckpt.epoch) % 5 == 0:
            ckpt_save_path = ckpt_manager.save()
            print("Saving checkpoint for epoch {} at {}".format(
                ckpt.epoch, ckpt_save_path))