Example #1
0
    def __init__(self):
        logging.basicConfig(
            filename="train_logs.txt",
            level=logging.DEBUG,
        )
        logger = logging.getLogger(__name__)

        strategy = tf.distribute.MirroredStrategy()
        timing = TimingLogger()
        super().__init__(logger, strategy, timing)
Example #2
0
def main():
    """Main valiation function."""
    timing = TimingLogger()
    timing.start()
    network_settings, train_settings, preprocess_settings = parseConfigsFile(
        ["network", "train", "preprocess"])

    strategy = tf.distribute.MirroredStrategy()
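    # Global batch size: the per-replica batch size multiplied by the number
    # of replicas MirroredStrategy keeps in sync.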
    BATCH_SIZE = train_settings["batch_size"] * strategy.num_replicas_in_sync

    LOGGER.info(" -------- Importing Datasets --------")

    vgg_dataset = VggFace2(mode="concatenated")
    synthetic_num_classes = vgg_dataset.get_number_of_classes()
    validation_dataset = _instantiate_dataset(strategy, BATCH_SIZE)
    # synthetic_num_classes = 8529

    LOGGER.info(" -------- Creating Models and Optimizers --------")

    srfr_model = _instantiate_models(strategy, network_settings,
                                     preprocess_settings,
                                     synthetic_num_classes)

    checkpoint, manager = _create_checkpoint_and_manager(srfr_model)

    test_summary_writer = _create_summary_writer()

    LOGGER.info(" -------- Starting Validation --------")
    with strategy.scope():
        validate_model_use_case = ValidateModelUseCase(strategy,
                                                       test_summary_writer,
                                                       TimingLogger(), LOGGER)

        # Restore each saved checkpoint in turn and run validation on it.
        for model_checkpoint in manager.checkpoints:
            try:
                checkpoint.restore(model_checkpoint)
            except Exception:
                LOGGER.warning(f" Failed to restore {model_checkpoint}")
                continue
            LOGGER.info(f" Restored from {model_checkpoint}")

            validate_model_use_case.execute(srfr_model, validation_dataset,
                                            BATCH_SIZE, checkpoint)
Example #3
    def __init__(
            self,
            strategy,
            srfr_model,
            srfr_optimizer,
            discriminator_model,
            discriminator_optimizer,
            train_summary_writer,
            test_summary_writer,
            checkpoint,
            manager,
            ):
        self.strategy = strategy
        self.srfr_model = srfr_model
        self.srfr_optimizer = srfr_optimizer
        self.discriminator_model = discriminator_model
        self.discriminator_optimizer = discriminator_optimizer
        self.train_summary_writer = train_summary_writer
        self.test_summary_writer = test_summary_writer
        self.checkpoint = checkpoint
        self.manager = manager

        self.timing = TimingLogger()
        self.losses: Loss = None
Example #4
0
sys.path.append(os.path.abspath("."))  # isort:skip

import logging
from pathlib import Path

import cv2
import tensorflow as tf

from utils.input_data import InputData, parseConfigsFile
from utils.timing import TimingLogger

logging.basicConfig(filename="vgg_to_tfrecords.txt", level=logging.INFO)
LOGGER = logging.getLogger(__name__)

timing = TimingLogger()
timing.start()

LOGGER.info("--- Setting Functions ---")

SHAPE = tuple(
    parseConfigsFile(["preprocess"])["image_shape_low_resolution"][:2])

BASE_DATA_DIR = Path("/datasets/VGGFace2_LR/Images")
BASE_OUTPUT_PATH = Path("/workspace/datasets/VGGFace2")


def _reduce_resolution(high_resolution_image):
    low_resolution_image = cv2.cvtColor(
        cv2.resize(high_resolution_image, SHAPE,
                   interpolation=cv2.INTER_CUBIC), cv2.COLOR_BGR2RGB)
    high_resolution_image = cv2.cvtColor(high_resolution_image,
                                         cv2.COLOR_BGR2RGB)
    return tf.image.encode_png(low_resolution_image), tf.image.encode_png(
        high_resolution_image)
Example #5
def main():
    """Main training function."""
    timing = TimingLogger()
    timing.start()
    network_settings, train_settings, preprocess_settings = parseConfigsFile(
        ['network', 'train', 'preprocess'])

    strategy = tf.distribute.MirroredStrategy()
    BATCH_SIZE = train_settings['batch_size'] * strategy.num_replicas_in_sync
    temp_folder = Path.cwd().joinpath('temp', 'synthetic_ds')

    LOGGER.info(' -------- Importing Datasets --------')

    vgg_dataset = VggFace2(mode='concatenated')
    synthetic_dataset = vgg_dataset.get_dataset()
    synthetic_dataset = vgg_dataset.augment_dataset()
    synthetic_dataset = vgg_dataset.normalize_dataset()
    synthetic_dataset = synthetic_dataset.cache(str(temp_folder))
    #synthetic_dataset_len = vgg_dataset.get_dataset_size()
    synthetic_dataset_len = 100_000
    synthetic_num_classes = vgg_dataset.get_number_of_classes()
    synthetic_dataset = synthetic_dataset.shuffle(
        buffer_size=2_048).repeat().batch(BATCH_SIZE).prefetch(1)

    lfw_path = Path.cwd().joinpath('temp', 'lfw')
    lfw_dataset = LFW()
    (left_pairs, left_aug_pairs, right_pairs, right_aug_pairs,
     is_same_list) = lfw_dataset.get_dataset()
    left_pairs = left_pairs.batch(BATCH_SIZE).cache(
        str(lfw_path.joinpath('left'))).prefetch(AUTOTUNE)
    left_aug_pairs = left_aug_pairs.batch(BATCH_SIZE).cache(
        str(lfw_path.joinpath('left_aug'))).prefetch(AUTOTUNE)
    right_pairs = right_pairs.batch(BATCH_SIZE).cache(
        str(lfw_path.joinpath('right'))).prefetch(AUTOTUNE)
    right_aug_pairs = right_aug_pairs.batch(BATCH_SIZE).cache(
        str(lfw_path.joinpath('right_aug'))).prefetch(AUTOTUNE)

    # Using `distribute_dataset` to distribute the batches across the GPUs
    synthetic_dataset = strategy.experimental_distribute_dataset(
        synthetic_dataset)
    left_pairs = strategy.experimental_distribute_dataset(left_pairs)
    left_aug_pairs = strategy.experimental_distribute_dataset(left_aug_pairs)
    right_pairs = strategy.experimental_distribute_dataset(right_pairs)
    right_aug_pairs = strategy.experimental_distribute_dataset(right_aug_pairs)

    LOGGER.info(' -------- Creating Models and Optimizers --------')

    EPOCHS = generate_num_epochs(
        train_settings['iterations'],
        synthetic_dataset_len,
        BATCH_SIZE,
    )

    with strategy.scope():
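        # Models and optimizers are built inside the strategy scope so their
        # variables are mirrored across all replicas.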
        srfr_model = SRFR(
            num_filters=network_settings['num_filters'],
            depth=50,
            categories=network_settings['embedding_size'],
            num_gc=network_settings['gc'],
            num_blocks=network_settings['num_blocks'],
            residual_scailing=network_settings['residual_scailing'],
            training=True,
            input_shape=preprocess_settings['image_shape_low_resolution'],
            num_classes_syn=synthetic_num_classes,
        )
        sr_discriminator_model = DiscriminatorNetwork()

        srfr_optimizer = NovoGrad(
            learning_rate=train_settings['learning_rate'],
            beta_1=train_settings['momentum'],
            beta_2=train_settings['beta_2'],
            weight_decay=train_settings['weight_decay'],
            name='novograd_srfr',
        )
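        # Wrap the optimizer for mixed precision: losses are scaled
        # dynamically to avoid underflow in float16 gradients.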
        srfr_optimizer = mixed_precision.LossScaleOptimizer(
            srfr_optimizer,
            loss_scale='dynamic',
        )
        discriminator_optimizer = NovoGrad(
            learning_rate=train_settings['learning_rate'],
            beta_1=train_settings['momentum'],
            beta_2=train_settings['beta_2'],
            weight_decay=train_settings['weight_decay'],
            name='novograd_discriminator',
        )
        discriminator_optimizer = mixed_precision.LossScaleOptimizer(
            discriminator_optimizer, loss_scale='dynamic')

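        # Reduction helper used by the training loop to average the collected
        # per-step losses (MEAN reduce over axis 0).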
        train_loss = partial(
            strategy.reduce,
            reduce_op=tf.distribute.ReduceOp.MEAN,
            axis=0,
        )

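    # The checkpoint captures the epoch/step counters, both models and both
    # optimizers, so interrupted training can resume where it stopped.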
    checkpoint = tf.train.Checkpoint(
        epoch=tf.Variable(1),
        step=tf.Variable(1),
        srfr_model=srfr_model,
        sr_discriminator_model=sr_discriminator_model,
        srfr_optimizer=srfr_optimizer,
        discriminator_optimizer=discriminator_optimizer,
    )
    manager = tf.train.CheckpointManager(checkpoint,
                                         directory='./training_checkpoints',
                                         max_to_keep=5)

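    # Separate TensorBoard writers for training and test summaries, grouped
    # under a timestamped log directory.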
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_summary_writer = tf.summary.create_file_writer(
        str(Path.cwd().joinpath('logs', 'gradient_tape', current_time,
                                'train')), )
    test_summary_writer = tf.summary.create_file_writer(
        str(Path.cwd().joinpath('logs', 'gradient_tape', current_time,
                                'test')), )

    LOGGER.info(' -------- Starting Training --------')
    with strategy.scope():
        checkpoint.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            LOGGER.info(f' Restored from {manager.latest_checkpoint}')
        else:
            LOGGER.info(' Initializing from scratch.')

        for epoch in range(int(checkpoint.epoch), EPOCHS + 1):
            timing.start(Train.__name__)
            LOGGER.info(f' Start of epoch {epoch}')

            train = Train(strategy, srfr_model, srfr_optimizer,
                          sr_discriminator_model, discriminator_optimizer,
                          train_summary_writer, test_summary_writer,
                          checkpoint, manager)
            srfr_loss, discriminator_loss = train.train_srfr_model(
                BATCH_SIZE,
                train_loss,
                synthetic_dataset,
                synthetic_num_classes,
                left_pairs,
                left_aug_pairs,
                right_pairs,
                right_aug_pairs,
                is_same_list,
                sr_weight=train_settings['super_resolution_weight'],
                scale=train_settings['scale'],
                margin=train_settings['angular_margin'],
                # natural_ds,
                # num_classes_natural,
            )
            elapsed_time = timing.end(Train.__name__, True)
            with train_summary_writer.as_default():
                tf.summary.scalar('srfr_loss_per_epoch', srfr_loss, step=epoch)
                tf.summary.scalar(
                    'discriminator_loss_per_epoch',
                    discriminator_loss,
                    step=epoch,
                )
                tf.summary.scalar('training_time_per_epoch',
                                  elapsed_time,
                                  step=epoch)
            LOGGER.info((f' Epoch {epoch}, SRFR Loss: {srfr_loss:.3f},'
                         f' Discriminator Loss: {discriminator_loss:.3f}'))

            train.save_model()

            checkpoint.epoch.assign_add(1)
Example #6
0
    def train(self):
        """Main training function."""
        self.timing.start()

        dimensions = self._create_dimensions()
        hyperparameters = self._create_hyprparameters_domain()
        with tf.summary.create_file_writer(
                str(Path.cwd().joinpath("output", "logs",
                                        "hparam_tuning"))).as_default():
            hp.hparams_config(
                hparams=hyperparameters,
                metrics=[hp.Metric("accuracy", display_name="Accuracy")],
            )

        (
            network_settings,
            train_settings,
            preprocess_settings,
        ) = parseConfigsFile(["network", "train", "preprocess"])

        BATCH_SIZE = train_settings[
            "batch_size"] * self.strategy.num_replicas_in_sync

        (
            synthetic_train,
            synthetic_test,
            synthetic_dataset_len,
            synthetic_num_classes,
        ) = self._get_datasets(BATCH_SIZE)

        srfr_model, discriminator_model = self._instantiate_models(
            synthetic_num_classes, network_settings, preprocess_settings)

        train_model_sr_only_use_case = TrainModelSrOnlyUseCase(
            self.strategy,
            TimingLogger(),
            self.logger,
            BATCH_SIZE,
            synthetic_dataset_len,
        )

        _training = partial(
            self._fitness_function,
            train_model_use_case=train_model_sr_only_use_case,
            srfr_model=srfr_model,
            discriminator_model=discriminator_model,
            batch_size=BATCH_SIZE,
            synthetic_train=synthetic_train,
            synthetic_test=synthetic_test,
            num_classes=synthetic_num_classes,
            train_settings=train_settings,
            hparams=hyperparameters,
        )
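        # `use_named_args` maps the skopt search dimensions onto keyword
        # arguments of the objective that gp_minimize evaluates below.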
        _train = use_named_args(dimensions=dimensions)(_training)

        initial_parameters = [0.0002, 0.9, 1.0, 0.005, 0.01]

        search_result = gp_minimize(
            func=_train,
            dimensions=dimensions,
            acq_func="EI",
            n_calls=20,
            x0=initial_parameters,
        )

        self.logger.info(f"Best hyperparameters: {search_result.x}")
Example #7
0
def main():
    """Main training function."""
    timing = TimingLogger()
    timing.start()
    strategy = tf.distribute.MirroredStrategy()
    # strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
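    # Build the scikit-optimize search space and the HParams domain that is
    # registered with TensorBoard for this hyperparameter search.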
    dimensions = _create_dimensions()
    hyperparameters = _create_hyprparameters_domain()
    with tf.summary.create_file_writer(
            str(Path.cwd().joinpath("output", "logs",
                                    "hparam_tuning"))).as_default():
        hp.hparams_config(
            hparams=hyperparameters,
            metrics=[hp.Metric("accuracy", display_name="Accuracy")],
        )

    (
        network_settings,
        train_settings,
        preprocess_settings,
    ) = parseConfigsFile(["network", "train", "preprocess"])

    BATCH_SIZE = train_settings["batch_size"] * strategy.num_replicas_in_sync

    (
        synthetic_train,
        synthetic_test,
        synthetic_dataset_len,
        synthetic_num_classes,
    ) = _get_datasets(BATCH_SIZE, strategy)

    srfr_model, discriminator_model = _instantiate_models(
        strategy, synthetic_num_classes, network_settings, preprocess_settings)

    train_model_use_case = TrainModelJointLearnUseCase(
        strategy,
        TimingLogger(),
        LOGGER,
        BATCH_SIZE,
        synthetic_dataset_len,
    )

    _training = partial(
        _instantiate_training,
        strategy=strategy,
        train_model_use_case=train_model_use_case,
        srfr_model=srfr_model,
        discriminator_model=discriminator_model,
        batch_size=BATCH_SIZE,
        synthetic_train=synthetic_train,
        synthetic_test=synthetic_test,
        num_classes=synthetic_num_classes,
        train_settings=train_settings,
        hparams=hyperparameters,
    )
    _train = use_named_args(dimensions=dimensions)(_training)

    search_result = gp_minimize(func=_train,
                                dimensions=dimensions,
                                acq_func="EI",
                                n_calls=20)

    LOGGER.info(f"Best hyperparameters: {search_result.x}")
Example #8
0
from pathlib import Path
import logging

import cv2
import tensorflow as tf

from utils.timing import TimingLogger
from utils.input_data import InputData, parseConfigsFile

logging.basicConfig(filename='vgg_to_tfrecords.txt', level=logging.INFO)
LOGGER = logging.getLogger(__name__)

timing = TimingLogger()
timing.start()

LOGGER.info('--- Setting Functions ---')

shape = tuple(
    parseConfigsFile(['preprocess'])['image_shape_low_resolution'][:2])


def _reduce_resolution(high_resolution_image):
    low_resolution_image = cv2.cvtColor(
        cv2.resize(high_resolution_image, shape,
                   interpolation=cv2.INTER_CUBIC), cv2.COLOR_BGR2RGB)
    high_resolution_image = cv2.cvtColor(high_resolution_image,
                                         cv2.COLOR_BGR2RGB)
    return tf.image.encode_png(low_resolution_image), tf.image.encode_png(
        high_resolution_image)

Example #9
class Train:
    def __init__(
            self,
            strategy,
            srfr_model,
            srfr_optimizer,
            discriminator_model,
            discriminator_optimizer,
            train_summary_writer,
            test_summary_writer,
            checkpoint,
            manager,
            ):
        self.strategy = strategy
        self.srfr_model = srfr_model
        self.srfr_optimizer = srfr_optimizer
        self.discriminator_model = discriminator_model
        self.discriminator_optimizer = discriminator_optimizer
        self.train_summary_writer = train_summary_writer
        self.test_summary_writer = test_summary_writer
        self.checkpoint = checkpoint
        self.manager = manager

        self.timing = TimingLogger()
        self.losses: Loss = None

    def train_srfr_model(
            self,
            batch_size,
            train_loss_function,
            synthetic_dataset,
            num_classes_synthetic: int,
            left_pairs,
            left_aug_pairs,
            right_pairs,
            right_aug_pairs,
            is_same_list,
            sr_weight: float = 0.1,
            scale: float = 64,
            margin: float = 0.5,
            natural_dataset=None,
            num_classes_natural: int = None,
        ) -> float:
        """Train the model using the given dataset, compute the loss_function
        and apply the optimizer.

        Parameters
        ----------
            srfr_model: The Super Resolution Face Recognition model.
            sr_discriminator_model: The Discriminator model.
            batch_size: The Batch size.
            srfr_optimizer: Optimizer for the SRFR network.
            discriminator_optimizer: Optimizer for the Discriminator network.
            train_loss_function:
            sr_weight: Weight for the SR Loss.
            scale:
            margin:
            synthetic_dataset:
            num_classes_synthetic:
            natural_dataset:
            num_classes_natural:

        Returns
        -------
            (srfr_loss, discriminator_loss) The loss value for SRFR and
            Discriminator networks.
        """
        batch_size = tf.constant(batch_size, dtype=tf.float32)
        num_classes_synthetic = tf.constant(num_classes_synthetic,
                                            dtype=tf.int32)
        sr_weight = tf.constant(sr_weight, dtype=tf.float32)
        scale = tf.constant(scale, dtype=tf.float32)
        margin = tf.constant(margin, dtype=tf.float32)
        self.losses = Loss(self.srfr_model, batch_size,
                           self.train_summary_writer, sr_weight,
                           scale, margin)
        #if natural_dataset:
        #    return self._train_with_natural_images(
        #        batch_size,
        #        train_loss_function,
        #        synthetic_dataset,
        #        num_classes_synthetic,
        #        natural_dataset,
        #        num_classes_natural,
        #        sr_weight,
        #        scale,
        #        margin
        #    )

        return self._train_with_synthetic_images_only(
            batch_size,
            train_loss_function,
            synthetic_dataset,
            num_classes_synthetic,
            left_pairs,
            left_aug_pairs,
            right_pairs,
            right_aug_pairs,
            is_same_list,
        )

    def _train_with_synthetic_images_only(
            self,
            batch_size,
            train_loss_function,
            dataset,
            num_classes: int,
            left_pairs,
            left_aug_pairs,
            right_pairs,
            right_aug_pairs,
            is_same_list,
            ) -> float:
        srfr_losses = []
        discriminator_losses = []
        with self.strategy.scope():
            for step, (synthetic_images, groud_truth_images,
                       synthetic_classes) in enumerate(dataset, start=1):
                srfr_loss, discriminator_loss, super_resolution_images = \
                    self._train_step_synthetic_only(synthetic_images,
                                                    groud_truth_images,
                                                    synthetic_classes,
                                                    num_classes)
                srfr_losses.append(srfr_loss)
                discriminator_losses.append(discriminator_loss)

                # Save a checkpoint every 1000 steps; metrics from those steps
                # are logged per 'batch', otherwise per 'step'.
                if step % 1000 == 0:
                    step_batch = 'batch'
                    self.save_model()
                else:
                    step_batch = 'step'
                self._save_metrics(step, srfr_loss, discriminator_loss,
                                   batch_size,
                                   synthetic_images, groud_truth_images,
                                   super_resolution_images, step_batch)
                self.checkpoint.step.assign_add(1)

                # Validate on LFW every 5000 steps.
                if step % 5000 == 0:
                    self._validate_on_lfw(left_pairs, left_aug_pairs,
                                          right_pairs, right_aug_pairs,
                                          is_same_list)

        return (
            train_loss_function(srfr_losses),
            train_loss_function(discriminator_losses),
        )

    def _save_metrics(self, step, srfr_loss, discriminator_loss, batch_size,
                      synthetic_images, groud_truth_images,
                      super_resolution_images, step_batch='step') -> None:
        LOGGER.info(
            (
                f' SRFR Training loss (for one batch) at step {step}:'
                f' {float(srfr_loss):.3f}'
            )
        )
        LOGGER.info(
            (
                f' Discriminator loss (for one batch) at step {step}:'
                f' {float(discriminator_loss):.3f}'
            )
        )
        LOGGER.info(f' Seen so far: {step * batch_size} samples')
        if step_batch == 'step':
            step = int(self.checkpoint.step)
        else:
            step = int(self.checkpoint.epoch)
        with self.train_summary_writer.as_default():
            tf.summary.scalar(
                f'srfr_loss_per_{step_batch}',
                float(srfr_loss),
                step=step,
            )
            tf.summary.scalar(
                f'discriminator_loss_per_{step_batch}',
                float(discriminator_loss),
                step=step,
            )
            tf.summary.image(
                f'lr_images_per_{step_batch}',
                tf.concat(synthetic_images.values, axis=0),
                max_outputs=10,
                step=step
            )
            tf.summary.image(
                f'hr_images_per_{step_batch}',
                tf.concat(groud_truth_images.values, axis=0),
                max_outputs=10,
                step=step
            )
            tf.summary.image(
                f'sr_images_per_{step_batch}',
                tf.concat(super_resolution_images.values, axis=0),
                max_outputs=10,
                step=step
            )

    def save_model(self):
        save_path = self.manager.save()
        LOGGER.info((f' Saved checkpoint for step {int(self.checkpoint.step)}:'
                     f' {save_path}'))

    def _validate_on_lfw(self, left_pairs, left_aug_pairs, right_pairs,
                         right_aug_pairs, is_same_list):
        self.timing.start(validate_model_on_lfw.__name__)
        (accuracy_mean, accuracy_std, validation_rate, validation_std,
         far, auc, eer) = validate_model_on_lfw(
             self.strategy,
             self.srfr_model,
             left_pairs,
             left_aug_pairs,
             right_pairs,
             right_aug_pairs,
             is_same_list,
         )
        elapsed_time = self.timing.end(validate_model_on_lfw.__name__, True)
        with self.test_summary_writer.as_default():
            tf.summary.scalar('accuracy_mean', accuracy_mean,
                              step=int(self.checkpoint.step),)
            tf.summary.scalar('accuracy_std', accuracy_std,
                              step=int(self.checkpoint.step))
            tf.summary.scalar('validation_rate', validation_rate,
                              step=int(self.checkpoint.step))
            tf.summary.scalar('validation_std', validation_std,
                              step=int(self.checkpoint.step))
            tf.summary.scalar('far', far, step=int(self.checkpoint.step))
            tf.summary.scalar('auc', auc, step=int(self.checkpoint.step))
            tf.summary.scalar('eer', eer, step=int(self.checkpoint.step))
            tf.summary.scalar('testing_time', elapsed_time,
                              step=int(self.checkpoint.step))

        LOGGER.info((
            f' Validation on LFW: Step {int(self.checkpoint.step)} -'
            f' Accuracy: {accuracy_mean:.3f} +- {accuracy_std:.3f} -'
            f' Validation Rate: {validation_rate:.3f} +-'
            f' {validation_std:.3f} @ FAR {far:.3f} -'
            f' Area Under Curve (AUC): {auc:.3f} -'
            f' Equal Error Rate (EER): {eer:.3f} -'
        ))

    #@tf.function
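    # Replica-local training step: forward pass through the SRFR and
    # discriminator models, joint and discriminator losses, mixed-precision
    # loss scaling, then gradient application on both optimizers.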
    def _step_function(self, low_resolution_batch, groud_truth_batch,
                       ground_truth_classes, num_classes):
        with tf.GradientTape() as srfr_tape, \
                tf.GradientTape() as discriminator_tape:
            (super_resolution_images, embeddings) = self.srfr_model(
                low_resolution_batch)
            discriminator_sr_predictions = self.discriminator_model(
                super_resolution_images)
            discriminator_gt_predictions = self.discriminator_model(
                groud_truth_batch)
            synthetic_face_recognition = (embeddings, ground_truth_classes,
                                          num_classes)
            srfr_loss = self.losses.compute_joint_loss(
                super_resolution_images,
                groud_truth_batch,
                discriminator_sr_predictions,
                discriminator_gt_predictions,
                synthetic_face_recognition,
                self.checkpoint,
            )
            discriminator_loss = self.losses.compute_discriminator_loss(
                discriminator_sr_predictions,
                discriminator_gt_predictions,
            )
            srfr_loss = srfr_loss / self.strategy.num_replicas_in_sync
            discriminator_loss = (discriminator_loss /
                                  self.strategy.num_replicas_in_sync)
            srfr_scaled_loss = self.srfr_optimizer.get_scaled_loss(srfr_loss)
            discriminator_scaled_loss = self.discriminator_optimizer.\
                get_scaled_loss(discriminator_loss)

        srfr_grads = srfr_tape.gradient(srfr_scaled_loss,
                                        self.srfr_model.trainable_weights)
        discriminator_grads = discriminator_tape.gradient(
            discriminator_scaled_loss,
            self.discriminator_model.trainable_weights,
        )
        self.srfr_optimizer.apply_gradients(
            zip(self.srfr_optimizer.get_unscaled_gradients(srfr_grads),
                self.srfr_model.trainable_weights)
        )
        self.discriminator_optimizer.apply_gradients(
            zip(self.discriminator_optimizer.get_unscaled_gradients(
                discriminator_grads),
                self.discriminator_model.trainable_weights)
        )
        return srfr_loss, discriminator_loss, super_resolution_images

    #@tf.function
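    # Distributed wrapper around _step_function: runs it on every replica and
    # averages the per-replica losses.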
    def _train_step_synthetic_only(
            self,
            synthetic_images,
            groud_truth_images,
            synthetic_classes,
            num_classes,
        ):
        """Does a training step

        Parameters
        ----------
            model:
            images: Batch of images for training.
            classes: Batch of classes to compute the loss.
            num_classes: Total number of classes in the dataset.
            s:
            margin:

        Returns
        -------
            (srfr_loss, srfr_grads, discriminator_loss, discriminator_grads)
            The loss value and the gradients for SRFR network, as well as the
            loss value and the gradients for the Discriminative network.
        """
        srfr_loss, discriminator_loss, super_resolution_images = \
            self.strategy.experimental_run_v2(
                self._step_function,
                args=(
                    synthetic_images,
                    groud_truth_images,
                    synthetic_classes,
                    num_classes,
                ),
            )
        srfr_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                         srfr_loss, None)
        discriminator_loss = self.strategy.reduce(
            tf.distribute.ReduceOp.MEAN,
            discriminator_loss,
            None,
        )
        return srfr_loss, discriminator_loss, super_resolution_images
Example #10
0
import logging
from pathlib import Path

import cv2
import tensorflow as tf
from utils.timing import TimingLogger
from utils.input_data import InputData, parseConfigsFile

logging.basicConfig(
    filename='lfw_to_tfrecords.txt',
    level=logging.INFO
)
LOGGER = logging.getLogger(__name__)
timing = TimingLogger()
timing.start()

LOGGER.info('--- Setting Functions ---')
shape = tuple(parseConfigsFile(['preprocess'])['image_shape_low_resolution'][:2])


def _reduce_resolution(high_resolution_image):
    low_resolution_image = cv2.cvtColor(
        cv2.resize(high_resolution_image, shape, interpolation=cv2.INTER_CUBIC),
        cv2.COLOR_BGR2RGB)
    return tf.image.encode_png(low_resolution_image)


def _bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    try: