Example #1
import tensorflow as tf

# Assumed module-level names; the SEED value here is illustrative.
Dataset = tf.data.Dataset
AUTOTUNE = tf.data.AUTOTUNE
SEED = 42


def batch_and_repeat(ds: Dataset, batch_size: int, shuffle: bool,
                     repeat: bool) -> Dataset:
    # Prefetch, then optionally shuffle, repeat and batch the dataset.
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(1024, seed=SEED)
    if repeat:
        ds = ds.repeat()
    if batch_size > 0:
        ds = ds.batch(batch_size, drop_remainder=False)
    return ds
Example #2
    def train(self,
              train_dataset: Dataset,
              valid_dataset: Dataset = None,
              batch_size: int = 256,
              epochs: int = 16,
              checkpoints_path: Path = None):
        print("Training model...")

        ckpt = None
        manager = None
        if checkpoints_path is not None:
            checkpoints_path.mkdir(parents=True, exist_ok=True)
            ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                       optimizer=self.optimizer,
                                       net=self.network)
            manager = tf.train.CheckpointManager(ckpt,
                                                 checkpoints_path,
                                                 max_to_keep=3)
            ckpt.restore(manager.latest_checkpoint)
            if manager.latest_checkpoint:
                print(f"Restored from {manager.latest_checkpoint}")
            else:
                print("Initializing from scratch.")

        # Shuffle, batch and prefetch the training data; batch the validation
        # data only if it was provided.
        train_dataset = train_dataset.shuffle(1024).batch(batch_size).prefetch(
            buffer_size=tf.data.experimental.AUTOTUNE)
        if valid_dataset is not None:
            valid_dataset = valid_dataset.batch(batch_size)

        # Start training the model.
        for epoch in range(1, epochs + 1):
            for images, labels in train_dataset:
                self._train_step(images, labels)

            if valid_dataset is not None:
                for valid_images, valid_labels in valid_dataset:
                    self._test_step(valid_images, valid_labels)

            if checkpoints_path is not None:
                ckpt.step.assign_add(1)
                if int(ckpt.step) % 10 == 0:
                    save_path = manager.save()
                    print(
                        f"💾 Saved checkpoint for step {int(ckpt.step)}: {save_path}"
                    )

            print(
                f"Epoch {epoch}, "
                f"Loss: {self.train_loss.result()}, Accuracy: {self.train_accuracy.result() * 100}, "
                f"Valid Loss: {self.test_loss.result()}, Valid Accuracy: {self.test_accuracy.result() * 100}"
            )

        # Save the model.
        self.network.trainable = False
        self.network.save(self.save_path)
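
The restore-or-initialize checkpoint handling is the least obvious part of the method above. Below is a minimal standalone sketch of the same tf.train.Checkpoint / tf.train.CheckpointManager pattern; the toy model, optimizer and the "./tf_ckpts" directory are illustrative stand-ins for self.network, self.optimizer and checkpoints_path.

import tensorflow as tf

# Toy stand-ins for the trainer's network and optimizer.
net = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()

ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=net)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# restore() is a no-op when latest_checkpoint is None, so one code path covers
# both a fresh run and a resumed one.
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print(f"Restored from {manager.latest_checkpoint}")
else:
    print("Initializing from scratch.")

# After each unit of work, bump the counter and save every 10th step.
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
    print("Saved checkpoint:", manager.save())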
Example #3
    def train(self, checkpoints_path: Path, train_dataset: Dataset, valid_dataset: Dataset = None,
              batch_size: int = 256, epochs: int = 16):
        checkpoints_path.mkdir(parents=True, exist_ok=True)
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=self.optimizer, net=self.network)
        manager = tf.train.CheckpointManager(ckpt, checkpoints_path, max_to_keep=3)
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            print(f"Restored from {manager.latest_checkpoint}")
        else:
            print("Initializing from scratch.")

        # Shuffle, batch and prefetch the training data; batch the validation
        # data only if it was provided.
        train_dataset = train_dataset.shuffle(1024).batch(batch_size).prefetch(
            buffer_size=tf.data.experimental.AUTOTUNE)
        if valid_dataset is not None:
            valid_dataset = valid_dataset.batch(batch_size)

        # Start training the model.
        for epoch in range(1, epochs + 1):
            for images, labels in train_dataset:
                self._train_step(images, labels)

            if valid_dataset is not None:
                for valid_images, valid_labels in valid_dataset:
                    self._test_step(valid_images, valid_labels)

            ckpt.step.assign_add(1)
            if int(ckpt.step) % 10 == 0:
                save_path = manager.save()
                print(f"💾 Saved checkpoint for step {int(ckpt.step)}: {save_path}")

            print(f"Epoch {epoch}, "
                  f"Loss: {self.train_loss.result()}, Accuracy: {self.train_accuracy.result() * 100}, "
                  f"Valid Loss: {self.test_loss.result()}, Valid Accuracy: {self.test_accuracy.result() * 100}")

        # Save the model.
        self.network.trainable = False
        self.network.save(self.save_path)
Example #4
def batch_and_repeat(
    ds: Dataset, batch_size: int, shuffle: bool, repeat: bool
) -> Dataset:
    """Helper method for to apply tensorflow shuffle, repeat and
    batch (in this order)

    Args:
        ds (Dataset): Tensorflow Dataset
        batch_size (int): Will call ds.batch(batch_size, drop_remainder=False)
            if batch_size is greater zero
        shuffle (int): Will call ds.shuffle(1024)
        repeat (bool): Will call ds.repeat()

    Returns:
        Dataset
    """
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(1024, seed=SEED)
    if repeat:
        ds = ds.repeat()
    if batch_size > 0:
        ds = ds.batch(batch_size, drop_remainder=False)
    return ds
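
A short usage sketch for batch_and_repeat, assuming a small in-memory tf.data.Dataset and the module-level AUTOTUNE and SEED constants from Example #1; the toy data here is purely illustrative.

import tensorflow as tf

# Hypothetical toy dataset; any tf.data.Dataset works.
ds = tf.data.Dataset.from_tensor_slices(tf.range(10))

# A shuffled, repeated and batched view suitable for a training loop;
# batch_size <= 0 would leave the dataset unbatched.
train_view = batch_and_repeat(ds, batch_size=4, shuffle=True, repeat=True)

for batch in train_view.take(3):
    print(batch.numpy())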