Example #1
    def prepare_summary_writers(self):
        """Prepares the summary writers for TensorBoard."""
        self.dnn_summary_writer = SummaryWriter(
            os.path.join(self.trial_directory, 'DNN'))
        self.gan_summary_writer = SummaryWriter(
            os.path.join(self.trial_directory, 'GAN'))
        self.dnn_summary_writer.summary_period = self.settings.summary_step_period
        self.gan_summary_writer.summary_period = self.settings.summary_step_period
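
The `summary_period` attribute assigned here is not part of the stock TensorBoard writer: the later examples also set a `step` attribute on the writer and call `is_summary_step()` on it, so the project's `SummaryWriter` is presumably a thin subclass. A minimal sketch of such a wrapper, assuming `tensorboardX` as the base class and inferring the attribute names from their usage in these examples:

from tensorboardX import SummaryWriter as SummaryWriterBase


class SummaryWriter(SummaryWriterBase):
    """A sketch of a writer that logs only on every `summary_period`th step."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.step = 0
        self.summary_period = 1
        self.steps_to_run = None

    def is_summary_step(self):
        return self.step % self.summary_period == 0

    def add_scalar(self, tag, scalar_value, global_step=None, **kwargs):
        # Default to the current training step when no explicit step is given.
        if global_step is None:
            global_step = self.step
        super().add_scalar(tag, scalar_value, global_step, **kwargs)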
Example #2
class DnnExperiment(Experiment, ABC):
    """A class to manage an experimental trial with only a DNN."""
    def prepare_summary_writers(self):
        """Prepares the summary writers for TensorBoard."""
        self.dnn_summary_writer = SummaryWriter(
            os.path.join(self.trial_directory, 'DNN'))
        self.dnn_summary_writer.summary_period = self.settings.summary_step_period

    def gpu_mode(self):
        """
        Moves the networks to the GPU (if available).
        """
        self.DNN.to(gpu)

    def eval_mode(self):
        """
        Changes the network to evaluation mode.
        """
        self.DNN.eval()

    def train_mode(self):
        """
        Converts the networks to train mode.
        """
        self.DNN.train()

    def prepare_optimizers(self):
        """Prepares the optimizers of the network."""
        d_lr = self.settings.learning_rate
        weight_decay = self.settings.weight_decay
        self.dnn_optimizer = Adam(self.DNN.parameters(),
                                  lr=d_lr,
                                  weight_decay=weight_decay,
                                  betas=(0.9, 0.999))

    def save_models(self, step):
        """Saves the network models."""
        model = {
            'DNN': self.DNN.state_dict(),
            'dnn_optimizer': self.dnn_optimizer.state_dict(),
            'step': step
        }
        torch.save(model,
                   os.path.join(self.trial_directory, f'model_{step}.pth'))

    def load_models(self, with_optimizers=True):
        """Loads existing models if they exist at `self.settings.load_model_path`."""
        if self.settings.load_model_path:
            latest_model = None
            model_path_file_names = os.listdir(self.settings.load_model_path)
            for file_name in model_path_file_names:
                match = re.search(r'model_?(\d+)?\.pth', file_name)
                if match:
                    latest_model = self.compare_model_path_for_latest(
                        latest_model, match)
            latest_model = None if latest_model is None else latest_model.group(0)
            if not torch.cuda.is_available():
                map_location = 'cpu'
            else:
                map_location = None
            if latest_model:
                model_path = os.path.join(self.settings.load_model_path,
                                          latest_model)
                loaded_model = torch.load(model_path, map_location)
                self.DNN.load_state_dict(loaded_model['DNN'])
                if with_optimizers:
                    self.dnn_optimizer.load_state_dict(
                        loaded_model['dnn_optimizer'])
                    self.optimizer_to_gpu(self.dnn_optimizer)
                print('Model loaded from `{}`.'.format(model_path))
                if self.settings.continue_existing_experiments:
                    self.starting_step = loaded_model['step'] + 1
                    print(f'Continuing from step {self.starting_step}')

    def training_loop(self):
        """Runs the main training loop."""
        train_dataset_generator = self.infinite_iter(self.train_dataset_loader)
        step_time_start = datetime.datetime.now()
        for step in range(self.starting_step, self.settings.steps_to_run):
            self.adjust_learning_rate(step)
            samples = next(train_dataset_generator)
            if len(samples) == 2:
                labeled_examples, labels = samples
                labeled_examples, labels = labeled_examples.to(gpu), labels.to(
                    gpu)
            else:
                labeled_examples, primary_labels, secondary_labels = samples
                labeled_examples, labels = labeled_examples.to(gpu), (
                    primary_labels.to(gpu), secondary_labels.to(gpu))
            self.dnn_training_step(labeled_examples, labels, step)
            if self.dnn_summary_writer.is_summary_step() or step == self.settings.steps_to_run - 1:
                print('\rStep {}, {}...'.format(
                    step,
                    datetime.datetime.now() - step_time_start),
                      end='')
                step_time_start = datetime.datetime.now()
                self.eval_mode()
                with torch.no_grad():
                    self.validation_summaries(step)
                self.train_mode()
            self.handle_user_input(step)
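
`DnnExperiment` still leaves dataset and model construction to its subclasses (`dataset_setup`, `model_setup`, and `validation_summaries` are abstract on the base `Experiment` class shown in the later examples). A minimal sketch of a concrete subclass using toy tensors; the class name and data here are hypothetical, for illustration only:

import torch
from torch.nn import Linear
from torch.utils.data import DataLoader, TensorDataset


class ToyRegressionExperiment(DnnExperiment):
    """A hypothetical concrete experiment for illustration only."""
    def dataset_setup(self):
        examples, labels = torch.randn(100, 8), torch.randn(100, 1)
        self.train_dataset = TensorDataset(examples, labels)
        self.train_dataset_loader = DataLoader(self.train_dataset, batch_size=10)
        self.validation_dataset = TensorDataset(examples, labels)

    def model_setup(self):
        self.DNN = Linear(8, 1)

    def validation_summaries(self, step):
        pass  # Log application-specific validation metrics here.


# experiment = ToyRegressionExperiment(settings)
# experiment.train()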
Example #3
class DnnExperiment(Experiment, ABC):
    """A class to manage an experimental trial with only a DNN."""
    def prepare_summary_writers(self):
        """Prepares the summary writers for TensorBoard."""
        self.dnn_summary_writer = SummaryWriter(os.path.join(self.trial_directory, 'DNN'))
        self.dnn_summary_writer.summary_period = self.settings.summary_step_period

    def load_models(self):
        """Loads existing models if they exist at `self.settings.load_model_path`."""
        if self.settings.load_model_path:
            latest_dnn_model = None
            model_path_file_names = os.listdir(self.settings.load_model_path)
            for file_name in model_path_file_names:
                match = re.search(r'(DNN|D|G)_model_?(\d+)?\.pth', file_name)
                if match:
                    if match.group(1) == 'DNN':
                        latest_dnn_model = self.compare_model_path_for_latest(latest_dnn_model, match)
            latest_dnn_model = None if latest_dnn_model is None else latest_dnn_model.group(0)
            if not torch.cuda.is_available():
                map_location = 'cpu'
            else:
                map_location = None
            if latest_dnn_model:
                dnn_model_path = os.path.join(self.settings.load_model_path, latest_dnn_model)
                print('DNN model loaded from `{}`.'.format(dnn_model_path))
                self.DNN.load_state_dict(torch.load(dnn_model_path, map_location), strict=False)

    def gpu_mode(self):
        """
        Moves the networks to the GPU (if available).
        """
        self.DNN.to(gpu)

    def eval_mode(self):
        """
        Changes the network to evaluation mode.
        """
        self.DNN.eval()

    def train_mode(self):
        """
        Converts the networks to train mode.
        """
        self.DNN.train()

    def prepare_optimizers(self):
        """Prepares the optimizers of the network."""
        d_lr = self.settings.learning_rate
        weight_decay = self.settings.weight_decay
        self.dnn_optimizer = Adam(self.DNN.parameters(), lr=d_lr, weight_decay=weight_decay, betas=(0.9, 0.999))

    def save_models(self, step=None):
        """Saves the network models."""
        if step is not None:
            suffix = '_{}'.format(step)
        else:
            suffix = ''
        torch.save(self.DNN.state_dict(), os.path.join(self.trial_directory, 'DNN_model{}.pth'.format(suffix)))

    def training_loop(self):
        """Runs the main training loop."""
        train_dataset_generator = self.infinite_iter(self.train_dataset_loader)
        step_time_start = datetime.datetime.now()
        for step in range(self.settings.steps_to_run):
            self.adjust_learning_rate(step)
            labeled_examples, labels, maps = next(train_dataset_generator)
            labeled_examples, labels, maps = labeled_examples.to(gpu), labels.to(gpu), maps.to(gpu)
            self.dnn_training_step(labeled_examples, labels, step)
            if self.dnn_summary_writer.is_summary_step() or step == self.settings.steps_to_run - 1:
                print('\rStep {}, {}...'.format(step, datetime.datetime.now() - step_time_start), end='')
                step_time_start = datetime.datetime.now()
                self.eval_mode()
                self.validation_summaries(step)
                self.train_mode()
            self.handle_user_input(step)
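
The checkpoint-selection rule used by `load_models` can be exercised in isolation. A standalone demonstration with made-up file names, reusing the comparison rule from `compare_model_path_for_latest` (shown in the later examples), where a file without a step number counts as the most trained:

import re


def compare_model_path_for_latest(model_path1, model_path2):
    # Same rule as in the examples: a missing step number wins; otherwise the higher step.
    if model_path1 is None:
        return model_path2
    elif model_path1.group(2) is None:
        return model_path1
    elif model_path2.group(2) is None:
        return model_path2
    elif int(model_path1.group(2)) > int(model_path2.group(2)):
        return model_path1
    else:
        return model_path2


latest = None
for file_name in ['DNN_model_100.pth', 'DNN_model_2500.pth', 'DNN_model.pth']:
    match = re.search(r'(DNN|D|G)_model_?(\d+)?\.pth', file_name)
    if match and match.group(1) == 'DNN':
        latest = compare_model_path_for_latest(latest, match)
print(latest.group(0))  # 'DNN_model.pth': no step number counts as fully trained.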
Example #4
class Experiment(ABC):
    """A class to manage an experimental trial."""
    def __init__(self, settings: Settings):
        self.settings = settings
        self.trial_directory: str = None
        self.dnn_summary_writer: SummaryWriter = None
        self.gan_summary_writer: SummaryWriter = None
        self.dataset_class = None
        self.train_dataset: Dataset = None
        self.train_dataset_loader: DataLoader = None
        self.unlabeled_dataset: Dataset = None
        self.unlabeled_dataset_loader: DataLoader = None
        self.validation_dataset: Dataset = None
        self.DNN: Module = None
        self.dnn_optimizer: Optimizer = None
        self.D: Module = None
        self.d_optimizer: Optimizer = None
        self.G: Module = None
        self.g_optimizer: Optimizer = None
        self.signal_quit = False

        self.labeled_features = None
        self.unlabeled_features = None
        self.fake_features = None
        self.interpolates_features = None

    def train(self):
        """
        Run the SRGAN training for the experiment.
        """
        self.trial_directory = os.path.join(self.settings.logs_directory,
                                            self.settings.trial_name)
        if (self.settings.skip_completed_experiment
                and os.path.exists(self.trial_directory)
                and '/check' not in self.trial_directory):
            print('`{}` experiment already exists. Skipping...'.format(
                self.trial_directory))
            return
        self.trial_directory = make_directory_name_unique(self.trial_directory)
        print(self.trial_directory)
        os.makedirs(
            os.path.join(self.trial_directory,
                         self.settings.temporary_directory))
        self.prepare_summary_writers()
        seed_all(0)

        self.dataset_setup()
        self.model_setup()
        self.load_models()
        self.gpu_mode()
        self.train_mode()
        self.prepare_optimizers()

        self.training_loop()

        print('Completed {}'.format(self.trial_directory))
        if self.settings.should_save_models:
            self.save_models()

    def save_models(self, step=None):
        """Saves the network models."""
        if step is not None:
            suffix = '_{}'.format(step)
        else:
            suffix = ''
        torch.save(
            self.DNN.state_dict(),
            os.path.join(self.trial_directory,
                         'DNN_model{}.pth'.format(suffix)))
        torch.save(
            self.D.state_dict(),
            os.path.join(self.trial_directory, 'D_model{}.pth'.format(suffix)))
        torch.save(
            self.G.state_dict(),
            os.path.join(self.trial_directory, 'G_model{}.pth'.format(suffix)))

    def training_loop(self):
        """Runs the main training loop."""
        train_dataset_generator = self.infinite_iter(self.train_dataset_loader)
        unlabeled_dataset_generator = self.infinite_iter(
            self.unlabeled_dataset_loader)
        step_time_start = datetime.datetime.now()
        for step in range(self.settings.steps_to_run):
            self.adjust_learning_rate(step)
            # DNN.
            samples = next(train_dataset_generator)
            if len(samples) == 2:
                labeled_examples, labels = samples
                labeled_examples, labels = labeled_examples.to(gpu), labels.to(
                    gpu)
            else:
                labeled_examples, primary_labels, secondary_labels = samples
                labeled_examples, labels = labeled_examples.to(gpu), (
                    primary_labels.to(gpu), secondary_labels.to(gpu))
            self.dnn_training_step(labeled_examples, labels, step)
            # GAN.
            unlabeled_examples = next(unlabeled_dataset_generator)[0]
            unlabeled_examples = unlabeled_examples.to(gpu)
            self.gan_training_step(labeled_examples, labels,
                                   unlabeled_examples, step)

            if self.gan_summary_writer.is_summary_step() or step == self.settings.steps_to_run - 1:
                print('\rStep {}, {}...'.format(
                    step,
                    datetime.datetime.now() - step_time_start),
                      end='')
                step_time_start = datetime.datetime.now()
                self.eval_mode()
                self.validation_summaries(step)
                self.train_mode()
            self.handle_user_input(step)

    def prepare_optimizers(self):
        """Prepares the optimizers of the network."""
        d_lr = self.settings.learning_rate
        g_lr = d_lr
        # betas = (0.9, 0.999)
        weight_decay = self.settings.weight_decay
        self.d_optimizer = Adam(self.D.parameters(),
                                lr=d_lr,
                                weight_decay=weight_decay,
                                betas=(0.99, 0.9999))
        self.g_optimizer = Adam(self.G.parameters(), lr=g_lr)
        self.dnn_optimizer = Adam(self.DNN.parameters(),
                                  lr=d_lr,
                                  weight_decay=weight_decay,
                                  betas=(0.99, 0.9999))

    def prepare_summary_writers(self):
        """Prepares the summary writers for TensorBoard."""
        self.dnn_summary_writer = SummaryWriter(
            os.path.join(self.trial_directory, 'DNN'))
        self.gan_summary_writer = SummaryWriter(
            os.path.join(self.trial_directory, 'GAN'))
        self.dnn_summary_writer.summary_period = self.settings.summary_step_period
        self.gan_summary_writer.summary_period = self.settings.summary_step_period

    def handle_user_input(self, step):
        """
        Handle input from the user.

        :param step: The current step of the program.
        :type step: int
        """
        while sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
            line = sys.stdin.readline()
            if 'save' in line:
                self.save_models(step)
                print('\rSaved model for step {}...'.format(step))
            if 'quit' in line:
                self.signal_quit = True
                print('\rQuit requested after current experiment...')

    def train_mode(self):
        """
        Converts the networks to train mode.
        """
        self.D.train()
        self.DNN.train()
        self.G.train()

    def gpu_mode(self):
        """
        Moves the networks to the GPU (if available).
        """
        self.D.to(gpu)
        self.DNN.to(gpu)
        self.G.to(gpu)

    def eval_mode(self):
        """
        Changes the network to evaluation mode.
        """
        self.D.eval()
        self.DNN.eval()
        self.G.eval()

    def cpu_mode(self):
        """
        Moves the networks to the CPU.
        """
        self.D.to('cpu')
        self.DNN.to('cpu')
        self.G.to('cpu')

    @staticmethod
    def compare_model_path_for_latest(model_path1, model_path2):
        """
        Compares two versions of the model path to see which one has trained longer. A model without any step number
        is considered to have trained the longest.

        :param model_path1: The first model path.
        :type model_path1: re.Match
        :param model_path2: The second model path.
        :type model_path2: re.Match
        :return: The model path which was newer.
        :rtype: re.Match
        """
        if model_path1 is None:
            return model_path2
        elif model_path1.group(2) is None:
            return model_path1
        elif model_path2.group(2) is None:
            return model_path2
        elif int(model_path1.group(2)) > int(model_path2.group(2)):
            return model_path1
        else:
            return model_path2

    def load_models(self):
        """Loads existing models if they exist at `self.settings.load_model_path`."""
        if self.settings.load_model_path:
            latest_dnn_model = None
            latest_d_model = None
            latest_g_model = None
            model_path_file_names = os.listdir(self.settings.load_model_path)
            for file_name in model_path_file_names:
                match = re.search(r'(DNN|D|G)_model_?(\d+)?\.pth', file_name)
                if match:
                    if match.group(1) == 'DNN':
                        latest_dnn_model = self.compare_model_path_for_latest(
                            latest_dnn_model, match)
                    elif match.group(1) == 'D':
                        latest_d_model = self.compare_model_path_for_latest(
                            latest_d_model, match)
                    elif match.group(1) == 'G':
                        latest_g_model = self.compare_model_path_for_latest(
                            latest_g_model, match)
            latest_dnn_model = None if latest_dnn_model is None else latest_dnn_model.group(0)
            latest_d_model = None if latest_d_model is None else latest_d_model.group(0)
            latest_g_model = None if latest_g_model is None else latest_g_model.group(0)
            if not torch.cuda.is_available():
                map_location = 'cpu'
            else:
                map_location = None
            if latest_dnn_model:
                dnn_model_path = os.path.join(self.settings.load_model_path,
                                              latest_dnn_model)
                print('DNN model loaded from `{}`.'.format(dnn_model_path))
                self.DNN.load_state_dict(
                    torch.load(dnn_model_path, map_location))
            if latest_d_model:
                d_model_path = os.path.join(self.settings.load_model_path,
                                            latest_d_model)
                print('D model loaded from `{}`.'.format(d_model_path))
                self.D.load_state_dict(torch.load(d_model_path, map_location))
            if latest_g_model:
                g_model_path = os.path.join(self.settings.load_model_path,
                                            latest_g_model)
                print('G model loaded from `{}`.'.format(g_model_path))
                self.G.load_state_dict(torch.load(g_model_path, map_location))

    def dnn_training_step(self, examples, labels, step):
        """Runs an individual round of DNN training."""
        self.dnn_summary_writer.step = step
        self.dnn_optimizer.zero_grad()
        dnn_loss = self.dnn_loss_calculation(examples, labels)
        dnn_loss.backward()
        self.dnn_optimizer.step()
        # Summaries.
        if self.dnn_summary_writer.is_summary_step():
            self.dnn_summary_writer.add_scalar('Discriminator/Labeled Loss',
                                               dnn_loss.item())
            if hasattr(self.DNN, 'features') and self.DNN.features is not None:
                self.dnn_summary_writer.add_scalar(
                    'Feature Norm/Labeled',
                    self.DNN.features.norm(dim=1).mean().item())

    def gan_training_step(self, labeled_examples, labels, unlabeled_examples,
                          step):
        """Runs an individual round of GAN training."""
        # Labeled.
        self.gan_summary_writer.step = step
        self.d_optimizer.zero_grad()
        loss = torch.tensor(0, dtype=torch.float)
        labeled_loss = self.labeled_loss_calculation(labeled_examples, labels)
        loss += labeled_loss
        # Unlabeled.
        # Make sure only labeled data is used for batch norm statistics.
        self.D.apply(disable_batch_norm_updates)
        unlabeled_loss = self.unlabeled_loss_calculation(unlabeled_examples)
        loss += unlabeled_loss
        # Feature regularization loss.
        if self.settings.regularize_feature_norm:
            feature_regularization_loss = torch.abs(
                self.unlabeled_features.mean(0).norm() - 1)
            loss += feature_regularization_loss
        # Fake.
        z = torch.tensor(
            MixtureModel([
                norm(-self.settings.mean_offset, 1),
                norm(self.settings.mean_offset, 1)
            ]).rvs(
                size=[unlabeled_examples.size(0), self.G.input_size]).astype(
                    np.float32)).to(gpu)
        fake_examples = self.G(z)
        fake_loss = self.fake_loss_calculation(fake_examples)
        loss += fake_loss
        # Gradient penalty.
        alpha = torch.rand(2, device=gpu)
        alpha = alpha / alpha.sum(0)
        interpolates = (
            alpha[0] * unlabeled_examples.detach().requires_grad_() +
            alpha[1] * fake_examples.detach().requires_grad_())
        interpolates_loss = self.interpolate_loss_calculation(interpolates)
        gradients = torch.autograd.grad(outputs=interpolates_loss,
                                        inputs=interpolates,
                                        grad_outputs=torch.ones_like(
                                            interpolates_loss, device=gpu),
                                        create_graph=True,
                                        only_inputs=True)[0]
        gradient_norms = gradients.view(unlabeled_examples.size(0), -1).norm(dim=1)
        gradient_penalty = ((gradient_norms - 1) ** 2).mean() * self.settings.gradient_penalty_multiplier
        # Discriminator update.
        loss += gradient_penalty
        loss.backward()
        self.d_optimizer.step()
        # Generator.
        if step % self.settings.generator_training_step_period == 0:
            self.g_optimizer.zero_grad()
            z = torch.randn(unlabeled_examples.size(0),
                            self.G.input_size).to(gpu)
            fake_examples = self.G(z)
            generator_loss = self.generator_loss_calculation(
                fake_examples, unlabeled_examples)
            generator_loss.backward()
            self.g_optimizer.step()
            if self.gan_summary_writer.is_summary_step():
                self.gan_summary_writer.add_scalar('Generator/Loss',
                                                   generator_loss.item())
        # Summaries.
        if self.gan_summary_writer.is_summary_step():
            self.gan_summary_writer.add_scalar('Discriminator/Labeled Loss',
                                               labeled_loss.item())
            self.gan_summary_writer.add_scalar('Discriminator/Unlabeled Loss',
                                               unlabeled_loss.item())
            self.gan_summary_writer.add_scalar('Discriminator/Fake Loss',
                                               fake_loss.item())
            if self.labeled_features is not None:
                self.gan_summary_writer.add_scalar(
                    'Feature Norm/Labeled',
                    self.labeled_features.mean(0).norm().item())
                self.gan_summary_writer.add_scalar(
                    'Feature Norm/Unlabeled',
                    self.unlabeled_features.mean(0).norm().item())
        # Make sure only labeled data is used for batch norm running statistics.
        self.D.apply(enable_batch_norm_updates)

    def dnn_loss_calculation(self, labeled_examples, labels):
        """Calculates the DNN loss."""
        predicted_labels = self.DNN(labeled_examples)
        labeled_loss = self.labeled_loss_function(
            predicted_labels, labels, order=self.settings.labeled_loss_order)
        labeled_loss *= self.settings.labeled_loss_multiplier
        return labeled_loss

    def labeled_loss_calculation(self, labeled_examples, labels):
        """Calculates the labeled loss."""
        predicted_labels = self.D(labeled_examples)
        self.labeled_features = self.D.features
        labeled_loss = self.labeled_loss_function(
            predicted_labels, labels, order=self.settings.labeled_loss_order)
        labeled_loss *= self.settings.labeled_loss_multiplier
        return labeled_loss

    def unlabeled_loss_calculation(self, unlabeled_examples):
        """Calculates the unlabeled loss."""
        _ = self.D(unlabeled_examples)
        self.unlabeled_features = self.D.features
        unlabeled_loss = feature_distance_loss(
            self.unlabeled_features,
            self.labeled_features) * self.settings.unlabeled_loss_multiplier
        return unlabeled_loss

    def fake_loss_calculation(self, fake_examples):
        """Calculates the fake loss."""
        _ = self.D(fake_examples.detach())
        self.fake_features = self.D.features
        fake_loss = feature_distance_loss(
            self.unlabeled_features, self.fake_features,
            distance_function=self.settings.fake_loss_distance
        ).neg() * self.settings.fake_loss_multiplier
        return fake_loss

    def interpolate_loss_calculation(self, interpolates):
        """Calculates the interpolate loss for use in the gradient penalty."""
        _ = self.D(interpolates)
        self.interpolates_features = self.D.features
        interpolates_loss = feature_distance_loss(
            self.unlabeled_features, self.interpolates_features,
            distance_function=self.settings.fake_loss_distance
        ).neg() * self.settings.fake_loss_multiplier
        return interpolates_loss

    def generator_loss_calculation(self, fake_examples, unlabeled_examples):
        """Calculates the generator's loss."""
        _ = self.D(fake_examples)
        self.fake_features = self.D.features
        _ = self.D(unlabeled_examples)
        detached_unlabeled_features = self.D.features.detach()
        generator_loss = feature_distance_loss(detached_unlabeled_features,
                                               self.fake_features)
        return generator_loss

    @abstractmethod
    def dataset_setup(self):
        """Prepares all the datasets and loaders required for the application."""
        self.train_dataset = Dataset()
        self.unlabeled_dataset = Dataset()
        self.validation_dataset = Dataset()
        self.train_dataset_loader = DataLoader(self.train_dataset)
        self.unlabeled_dataset_loader = DataLoader(self.unlabeled_dataset)

    @abstractmethod
    def model_setup(self):
        """Prepares all the model architectures required for the application."""
        self.DNN = Module()
        self.D = Module()
        self.G = Module()

    @abstractmethod
    def validation_summaries(self, step: int):
        """Prepares the summaries that should be run for the given application."""
        pass

    @staticmethod
    def labeled_loss_function(predicted_labels, labels, order=2):
        """Calculate the loss from the label difference prediction."""
        return (predicted_labels - labels).abs().pow(order).mean()
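        # With order=2 this is the mean squared error; with order=1, the mean
        # absolute error.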

    def evaluate(self):
        """Evaluates the model on the test dataset (needs to be overridden by subclass."""
        self.model_setup()
        self.load_models()
        self.eval_mode()

    @staticmethod
    def infinite_iter(dataset):
        """Create an infinite generator from a dataset. Forces full batch sizes."""
        while True:
            for examples in dataset:
                yield examples

    def adjust_learning_rate(self, step):
        """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
        lr = self.settings.learning_rate * (0.1**(step // 100000))
        for param_group in self.dnn_optimizer.param_groups:
            param_group['lr'] = lr
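
The schedule in `adjust_learning_rate` is piecewise constant, dropping by a factor of 10 every 100,000 steps. A quick check of the formula with an illustrative base rate:

base_learning_rate = 1e-4  # Illustrative; the real value comes from the settings.
for step in [0, 99999, 100000, 250000]:
    print(step, base_learning_rate * (0.1 ** (step // 100000)))
# Prints 1e-4 for steps 0 and 99999, 1e-5 at step 100000, and ~1e-6 at step 250000.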
Example #5
class Experiment(ABC):
    """A class to manage an experimental trial."""
    def __init__(self, settings: Settings):
        self.settings = settings
        self.trial_directory: str = None
        self.dnn_summary_writer: SummaryWriter = None
        self.gan_summary_writer: SummaryWriter = None
        self.dataset_class = None
        self.train_dataset: Dataset = None
        self.train_dataset_loader: DataLoader = None
        self.unlabeled_dataset: Dataset = None
        self.unlabeled_dataset_loader: DataLoader = None
        self.validation_dataset: Dataset = None
        self.DNN: Module = None
        self.dnn_optimizer: Optimizer = None
        self.D: Module = None
        self.d_optimizer: Optimizer = None
        self.G: Module = None
        self.g_optimizer: Optimizer = None
        self.signal_quit = False
        self.starting_step = 0

        self.labeled_features = None
        self.unlabeled_features = None
        self.fake_features = None
        self.interpolates_features = None
        self.gradient_norm = None

    def train(self):
        """
        Run the SRGAN training for the experiment.
        """
        self.trial_directory = os.path.join(self.settings.logs_directory,
                                            self.settings.trial_name)
        if (self.settings.skip_completed_experiment
                and os.path.exists(self.trial_directory)
                and '/check' not in self.trial_directory
                and not self.settings.continue_existing_experiments):
            print('`{}` experiment already exists. Skipping...'.format(
                self.trial_directory))
            return
        if not self.settings.continue_existing_experiments:
            self.trial_directory = make_directory_name_unique(
                self.trial_directory)
        else:
            if os.path.exists(self.trial_directory) and self.settings.load_model_path is not None:
                raise ValueError(
                    'Cannot load from path and continue existing at the same time.'
                )
            elif self.settings.load_model_path is None:
                self.settings.load_model_path = self.trial_directory
            elif not os.path.exists(self.trial_directory):
                self.settings.continue_existing_experiments = False
        print(self.trial_directory)
        os.makedirs(os.path.join(self.trial_directory,
                                 self.settings.temporary_directory),
                    exist_ok=True)
        self.prepare_summary_writers()
        seed_all(0)

        self.dataset_setup()
        self.model_setup()
        self.prepare_optimizers()
        self.load_models()
        self.gpu_mode()
        self.train_mode()

        self.training_loop()

        print('Completed {}'.format(self.trial_directory))
        if self.settings.should_save_models:
            self.save_models(step=self.settings.steps_to_run)

    def save_models(self, step):
        """Saves the network models."""
        model = {
            'DNN': self.DNN.state_dict(),
            'dnn_optimizer': self.dnn_optimizer.state_dict(),
            'D': self.D.state_dict(),
            'd_optimizer': self.d_optimizer.state_dict(),
            'G': self.G.state_dict(),
            'g_optimizer': self.g_optimizer.state_dict(),
            'step': step
        }
        torch.save(model,
                   os.path.join(self.trial_directory, f'model_{step}.pth'))

    def training_loop(self):
        """Runs the main training loop."""
        train_dataset_generator = self.infinite_iter(self.train_dataset_loader)
        unlabeled_dataset_generator = self.infinite_iter(
            self.unlabeled_dataset_loader)
        step_time_start = datetime.datetime.now()
        for step in range(self.starting_step, self.settings.steps_to_run):
            self.adjust_learning_rate(step)
            # DNN.
            samples = next(train_dataset_generator)
            if len(samples) == 2:
                labeled_examples, labels = samples
                labeled_examples, labels = labeled_examples.to(gpu), labels.to(
                    gpu)
            else:
                labeled_examples, primary_labels, secondary_labels = samples
                labeled_examples, labels = labeled_examples.to(gpu), (
                    primary_labels.to(gpu), secondary_labels.to(gpu))
            self.dnn_training_step(labeled_examples, labels, step)
            # GAN.
            unlabeled_examples = next(unlabeled_dataset_generator)[0]
            unlabeled_examples = unlabeled_examples.to(gpu)
            self.gan_training_step(labeled_examples, labels,
                                   unlabeled_examples, step)

            if self.gan_summary_writer.is_summary_step() or step == self.settings.steps_to_run - 1:
                print('\rStep {}, {}...'.format(
                    step,
                    datetime.datetime.now() - step_time_start),
                      end='')
                step_time_start = datetime.datetime.now()
                self.eval_mode()
                with torch.no_grad():
                    self.validation_summaries(step)
                self.train_mode()
            self.handle_user_input(step)
            if self.settings.save_step_period and step % self.settings.save_step_period == 0 and step != 0:
                self.save_models(step=step)

    def prepare_optimizers(self):
        """Prepares the optimizers of the network."""
        d_lr = self.settings.learning_rate
        g_lr = d_lr
        weight_decay = self.settings.weight_decay
        self.d_optimizer = Adam(self.D.parameters(),
                                lr=d_lr,
                                weight_decay=weight_decay)
        self.g_optimizer = Adam(self.G.parameters(), lr=g_lr)
        self.dnn_optimizer = Adam(self.DNN.parameters(),
                                  lr=d_lr,
                                  weight_decay=weight_decay)

    def prepare_summary_writers(self):
        """Prepares the summary writers for TensorBoard."""
        self.dnn_summary_writer = SummaryWriter(
            os.path.join(self.trial_directory, 'DNN'))
        self.gan_summary_writer = SummaryWriter(
            os.path.join(self.trial_directory, 'GAN'))
        self.dnn_summary_writer.summary_period = self.settings.summary_step_period
        self.gan_summary_writer.summary_period = self.settings.summary_step_period
        self.dnn_summary_writer.steps_to_run = self.settings.steps_to_run
        self.gan_summary_writer.steps_to_run = self.settings.steps_to_run

    def handle_user_input(self, step):
        """
        Handle input from the user.

        :param step: The current step of the program.
        :type step: int
        """
        while sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
            line = sys.stdin.readline()
            if 'save' in line:
                self.save_models(step)
                print('\rSaved model for step {}...'.format(step))
            if 'quit' in line:
                self.signal_quit = True
                print('\rQuit requested after current experiment...')

    def train_mode(self):
        """
        Converts the networks to train mode.
        """
        self.D.train()
        self.DNN.train()
        self.G.train()

    def gpu_mode(self):
        """
        Moves the networks to the GPU (if available).
        """
        self.D.to(gpu)
        self.DNN.to(gpu)
        self.G.to(gpu)

    def eval_mode(self):
        """
        Changes the network to evaluation mode.
        """
        self.D.eval()
        self.DNN.eval()
        self.G.eval()

    def cpu_mode(self):
        """
        Moves the networks to the CPU.
        """
        self.D.to('cpu')
        self.DNN.to('cpu')
        self.G.to('cpu')

    @staticmethod
    def compare_model_path_for_latest(model_path1, model_path2):
        """
        Compares two versions of the model path to see which one has trained longer. A model without any step number
        is considered to have trained the longest.

        :param model_path1: The first model path.
        :type model_path1: re.Match
        :param model_path2: The second model path.
        :type model_path2: re.Match
        :return: The model path which was newer.
        :rtype: re.Match
        """
        if model_path1 is None:
            return model_path2
        elif model_path1.group(1) is None:
            return model_path1
        elif model_path2.group(1) is None:
            return model_path2
        elif int(model_path1.group(1)) > int(model_path2.group(1)):
            return model_path1
        else:
            return model_path2

    def load_models(self, with_optimizers=True):
        """Loads existing models if they exist at `self.settings.load_model_path`."""
        if self.settings.load_model_path:
            latest_model = None
            model_path_file_names = os.listdir(self.settings.load_model_path)
            for file_name in model_path_file_names:
                match = re.search(r'model_?(\d+)?\.pth', file_name)
                if match:
                    latest_model = self.compare_model_path_for_latest(
                        latest_model, match)
            latest_model = None if latest_model is None else latest_model.group(0)
            if not torch.cuda.is_available():
                map_location = 'cpu'
            else:
                map_location = None
            if latest_model:
                model_path = os.path.join(self.settings.load_model_path,
                                          latest_model)
                loaded_model = torch.load(model_path, map_location)
                self.DNN.load_state_dict(loaded_model['DNN'])
                self.D.load_state_dict(loaded_model['D'])
                self.G.load_state_dict(loaded_model['G'])
                if with_optimizers:
                    self.dnn_optimizer.load_state_dict(
                        loaded_model['dnn_optimizer'])
                    self.optimizer_to_gpu(self.dnn_optimizer)
                    self.d_optimizer.load_state_dict(
                        loaded_model['d_optimizer'])
                    self.optimizer_to_gpu(self.d_optimizer)
                    self.g_optimizer.load_state_dict(
                        loaded_model['g_optimizer'])
                    self.optimizer_to_gpu(self.g_optimizer)
                print('Model loaded from `{}`.'.format(model_path))
                if self.settings.continue_existing_experiments:
                    self.starting_step = loaded_model['step'] + 1
                    print(f'Continuing from step {self.starting_step}')

    def optimizer_to_gpu(self, optimizer):
        """Moves the optimizer's state tensors to the GPU."""
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()

    def dnn_training_step(self, examples, labels, step):
        """Runs an individual round of DNN training."""
        self.DNN.apply(disable_batch_norm_updates)  # No batch norm
        self.dnn_summary_writer.step = step
        self.dnn_optimizer.zero_grad()
        dnn_loss = self.dnn_loss_calculation(examples, labels)
        dnn_loss.backward()
        self.dnn_optimizer.step()
        # Summaries.
        if self.dnn_summary_writer.is_summary_step():
            self.dnn_summary_writer.add_scalar('Discriminator/Labeled Loss',
                                               dnn_loss.item())
            if hasattr(self.DNN, 'features') and self.DNN.features is not None:
                self.dnn_summary_writer.add_scalar(
                    'Feature Norm/Labeled',
                    self.DNN.features.norm(dim=1).mean().item())

    def gan_training_step(self, labeled_examples, labels, unlabeled_examples,
                          step):
        """Runs an individual round of GAN training."""
        # Labeled.
        self.D.apply(disable_batch_norm_updates)  # No batch norm
        self.gan_summary_writer.step = step
        self.d_optimizer.zero_grad()
        labeled_loss = self.labeled_loss_calculation(labeled_examples, labels)
        labeled_loss.backward()
        # Unlabeled.
        # self.D.apply(disable_batch_norm_updates)  # Make sure only labeled data is used for batch norm statistics
        unlabeled_loss = self.unlabeled_loss_calculation(
            labeled_examples, unlabeled_examples)
        unlabeled_loss.backward()
        # Fake.
        z = torch.tensor(
            MixtureModel([
                norm(-self.settings.mean_offset, 1),
                norm(self.settings.mean_offset, 1)
            ]).rvs(
                size=[unlabeled_examples.size(0), self.G.input_size]).astype(
                    np.float32)).to(gpu)
        fake_examples = self.G(z)
        fake_loss = self.fake_loss_calculation(unlabeled_examples,
                                               fake_examples)
        fake_loss.backward()
        # Gradient penalty.
        gradient_penalty = self.gradient_penalty_calculation(
            fake_examples, unlabeled_examples)
        gradient_penalty.backward()
        # Discriminator update.
        self.d_optimizer.step()
        # Generator.
        if step % self.settings.generator_training_step_period == 0:
            self.g_optimizer.zero_grad()
            z = torch.randn(unlabeled_examples.size(0),
                            self.G.input_size).to(gpu)
            fake_examples = self.G(z)
            generator_loss = self.generator_loss_calculation(
                fake_examples, unlabeled_examples)
            generator_loss.backward()
            self.g_optimizer.step()
            if self.gan_summary_writer.is_summary_step():
                self.gan_summary_writer.add_scalar('Generator/Loss',
                                                   generator_loss.item())
        # Summaries.
        if self.gan_summary_writer.is_summary_step():
            self.gan_summary_writer.add_scalar('Discriminator/Labeled Loss',
                                               labeled_loss.item())
            self.gan_summary_writer.add_scalar('Discriminator/Unlabeled Loss',
                                               unlabeled_loss.item())
            self.gan_summary_writer.add_scalar('Discriminator/Fake Loss',
                                               fake_loss.item())
            self.gan_summary_writer.add_scalar(
                'Discriminator/Gradient Penalty', gradient_penalty.item())
            self.gan_summary_writer.add_scalar(
                'Discriminator/Gradient Norm',
                self.gradient_norm.mean().item())
            if self.labeled_features is not None and self.unlabeled_features is not None:
                self.gan_summary_writer.add_scalar(
                    'Feature Norm/Labeled',
                    self.labeled_features.mean(0).norm().item())
                self.gan_summary_writer.add_scalar(
                    'Feature Norm/Unlabeled',
                    self.unlabeled_features.mean(0).norm().item())
        # self.D.apply(enable_batch_norm_updates)  # Only labeled data used for batch norm running statistics

    def dnn_loss_calculation(self, labeled_examples, labels):
        """Calculates the DNN loss."""
        predicted_labels = self.DNN(labeled_examples)
        labeled_loss = self.labeled_loss_function(
            predicted_labels, labels, order=self.settings.labeled_loss_order)
        labeled_loss *= self.settings.labeled_loss_multiplier
        return labeled_loss

    def labeled_loss_calculation(self, labeled_examples, labels):
        """Calculates the labeled loss."""
        predicted_labels = self.D(labeled_examples)
        self.labeled_features = self.D.features
        labeled_loss = self.labeled_loss_function(
            predicted_labels, labels, order=self.settings.labeled_loss_order)
        labeled_loss *= self.settings.labeled_loss_multiplier
        return labeled_loss

    def unlabeled_loss_calculation(self, labeled_examples: Tensor,
                                   unlabeled_examples: Tensor):
        """Calculates the unlabeled loss."""
        _ = self.D(labeled_examples)
        self.labeled_features = self.D.features
        _ = self.D(unlabeled_examples)
        self.unlabeled_features = self.D.features
        unlabeled_loss = self.feature_distance_loss(self.unlabeled_features,
                                                    self.labeled_features)
        unlabeled_loss *= self.settings.matching_loss_multiplier
        unlabeled_loss *= self.settings.srgan_loss_multiplier
        return unlabeled_loss

    def fake_loss_calculation(self, unlabeled_examples: Tensor,
                              fake_examples: Tensor):
        """Calculates the fake loss."""
        _ = self.D(unlabeled_examples)
        self.unlabeled_features = self.D.features
        _ = self.D(fake_examples.detach())
        self.fake_features = self.D.features
        fake_loss = self.feature_distance_loss(
            self.unlabeled_features,
            self.fake_features,
            distance_function=self.settings.contrasting_distance_function)
        fake_loss *= self.settings.contrasting_loss_multiplier
        fake_loss *= self.settings.srgan_loss_multiplier
        return fake_loss

    def gradient_penalty_calculation(self, fake_examples: Tensor,
                                     unlabeled_examples: Tensor) -> Tensor:
        """Calculates the gradient penalty from the given fake and real examples."""
        alpha_shape = [1] * len(unlabeled_examples.size())
        alpha_shape[0] = self.settings.batch_size
        alpha = torch.rand(alpha_shape, device=gpu)
        interpolates = (alpha * unlabeled_examples.detach().requires_grad_() +
                        (1 - alpha) * fake_examples.detach().requires_grad_())
        interpolates_loss = self.interpolate_loss_calculation(interpolates)
        gradients = torch.autograd.grad(outputs=interpolates_loss,
                                        inputs=interpolates,
                                        grad_outputs=torch.ones_like(
                                            interpolates_loss, device=gpu),
                                        create_graph=True)[0]
        gradient_norm = gradients.view(unlabeled_examples.size(0),
                                       -1).norm(dim=1)
        self.gradient_norm = gradient_norm
        norm_excesses = torch.max(gradient_norm - 1,
                                  torch.zeros_like(gradient_norm))
        gradient_penalty = (norm_excesses ** 2).mean() * self.settings.gradient_penalty_multiplier
        return gradient_penalty
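        # Note: unlike the two-sided penalty of the original WGAN-GP, the
        # clipping above penalizes only gradient norms that exceed one (a
        # one-sided penalty).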

    def interpolate_loss_calculation(self, interpolates):
        """Calculates the interpolate loss for use in the gradient penalty."""
        _ = self.D(interpolates)
        self.interpolates_features = self.D.features
        return self.interpolates_features.norm(dim=1)

    def generator_loss_calculation(self, fake_examples, unlabeled_examples):
        """Calculates the generator's loss."""
        _ = self.D(fake_examples)
        self.fake_features = self.D.features
        _ = self.D(unlabeled_examples)
        detached_unlabeled_features = self.D.features.detach()
        generator_loss = self.feature_distance_loss(
            detached_unlabeled_features, self.fake_features)
        generator_loss *= self.settings.matching_loss_multiplier
        return generator_loss

    @abstractmethod
    def dataset_setup(self):
        """Prepares all the datasets and loaders required for the application."""
        self.train_dataset = Dataset()
        self.unlabeled_dataset = Dataset()
        self.validation_dataset = Dataset()
        self.train_dataset_loader = DataLoader(self.train_dataset)
        self.unlabeled_dataset_loader = DataLoader(self.unlabeled_dataset)

    @abstractmethod
    def model_setup(self):
        """Prepares all the model architectures required for the application."""
        self.DNN = Module()
        self.D = Module()
        self.G = Module()

    @abstractmethod
    def validation_summaries(self, step: int):
        """Prepares the summaries that should be run for the given application."""
        pass

    @staticmethod
    def labeled_loss_function(predicted_labels, labels, order=2):
        """Calculate the loss from the label difference prediction."""
        return (predicted_labels - labels).abs().pow(order).mean()

    def evaluate(self):
        """Evaluates the model on the test dataset (needs to be overridden by subclass)."""
        self.model_setup()
        self.load_models()
        self.eval_mode()

    @staticmethod
    def infinite_iter(dataset):
        """Create an infinite generator from a dataset"""
        while True:
            for examples in dataset:
                yield examples

    def adjust_learning_rate(self, step):
        """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
        lr = self.settings.learning_rate * (0.1**(step // 100000))
        for param_group in self.dnn_optimizer.param_groups:
            param_group['lr'] = lr

    def feature_distance_loss(self,
                              base_features,
                              other_features,
                              distance_function=None):
        """Calculate the loss based on the distance between feature vectors."""
        if distance_function is None:
            distance_function = self.settings.matching_distance_function
        base_mean_features = base_features.mean(0)
        other_mean_features = other_features.mean(0)
        if self.settings.normalize_feature_norm:
            epsilon = 1e-5
            base_mean_features = base_mean_features / (
                base_mean_features.norm() + epsilon)
            other_mean_features = other_mean_features / (
                other_mean_features.norm() + epsilon)
        distance_vector = distance_function(base_mean_features -
                                            other_mean_features)
        return distance_vector
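        # For example, with distance_function=torch.norm (a hypothetical
        # configuration), this reduces to the L2 distance between the two
        # mean feature vectors.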

    @property
    def inference_network(self):
        """The network to be used for inference."""
        return self.D

    def inference_setup(self):
        """
        Sets up the network for inference.
        """
        self.model_setup()
        self.load_models(with_optimizers=False)
        self.gpu_mode()
        self.eval_mode()

    def inference(self, input_):
        """
        Run the inference for the experiment.
        """
        raise NotImplementedError
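
`gan_training_step` draws the generator's latent vectors from a `MixtureModel` over two `scipy.stats` normal distributions, a class not shown in these examples. A minimal sketch consistent with how it is called (`rvs(size=...)` on equal-weight components); the implementation here is assumed, not taken from the project:

import numpy as np
from scipy.stats import norm


class MixtureModel:
    """A sketch (assumed implementation) of an equal-weight mixture of frozen
    scipy.stats distributions, exposing the `rvs` interface used above."""
    def __init__(self, submodels):
        self.submodels = submodels

    def rvs(self, size):
        # Sample every component, then pick one component per element.
        submodel_samples = [submodel.rvs(size=size) for submodel in self.submodels]
        choices = np.random.randint(len(self.submodels), size=size)
        return np.choose(choices, submodel_samples)


# z = MixtureModel([norm(-2, 1), norm(2, 1)]).rvs(size=[10, 100]).astype(np.float32)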
Example #6
    def train(self):
        """
        Run the SRGAN training for the experiment.
        """
        self.trial_directory = os.path.join(self.settings.logs_directory,
                                            self.settings.trial_name)
        if (self.settings.skip_completed_experiment
                and os.path.exists(self.trial_directory)
                and '/check' not in self.trial_directory):
            print('`{}` experiment already exists. Skipping...'.format(
                self.trial_directory))
            return
        self.trial_directory = make_directory_name_unique(self.trial_directory)
        print(self.trial_directory)
        os.makedirs(
            os.path.join(self.trial_directory,
                         self.settings.temporary_directory))
        self.dnn_summary_writer = SummaryWriter(
            os.path.join(self.trial_directory, 'DNN'))
        self.gan_summary_writer = SummaryWriter(
            os.path.join(self.trial_directory, 'GAN'))
        self.dnn_summary_writer.summary_period = self.settings.summary_step_period
        self.gan_summary_writer.summary_period = self.settings.summary_step_period
        seed_all(0)

        self.dataset_setup()
        self.model_setup()
        self.load_models()

        d_lr = self.settings.learning_rate
        g_lr = d_lr
        # betas = (0.9, 0.999)
        weight_decay = 1e-2
        self.d_optimizer = Adam(self.D.parameters(),
                                lr=d_lr,
                                weight_decay=weight_decay)
        self.g_optimizer = Adam(self.G.parameters(), lr=g_lr)
        self.dnn_optimizer = Adam(self.DNN.parameters(),
                                  lr=d_lr,
                                  weight_decay=weight_decay)

        step_time_start = datetime.datetime.now()
        train_dataset_generator = infinite_iter(self.train_dataset_loader)
        unlabeled_dataset_generator = infinite_iter(
            self.unlabeled_dataset_loader)

        for step in range(self.settings.steps_to_run):
            # DNN.
            labeled_examples, labels = next(train_dataset_generator)
            labeled_examples, labels = labeled_examples.to(gpu), labels.to(gpu)
            self.dnn_training_step(labeled_examples, labels, step)
            # GAN.
            unlabeled_examples, _ = next(unlabeled_dataset_generator)
            unlabeled_examples = unlabeled_examples.to(gpu)
            self.gan_training_step(labeled_examples, labels,
                                   unlabeled_examples, step)

            if self.gan_summary_writer.is_summary_step():
                print('\rStep {}, {}...'.format(
                    step,
                    datetime.datetime.now() - step_time_start),
                      end='')
                step_time_start = datetime.datetime.now()
                self.eval_mode()
                self.validation_summaries(step)
                self.train_mode()
                while sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                    line = sys.stdin.readline()
                    if 'save' in line:
                        torch.save(
                            self.DNN.state_dict(),
                            os.path.join(self.trial_directory,
                                         'DNN_model_{}.pth'.format(step)))
                        torch.save(
                            self.D.state_dict(),
                            os.path.join(self.trial_directory,
                                         'D_model_{}.pth'.format(step)))
                        torch.save(
                            self.G.state_dict(),
                            os.path.join(self.trial_directory,
                                         'G_model_{}.pth'.format(step)))
                        print('\rSaved model for step {}...'.format(step))
                    if 'quit' in line:
                        self.signal_quit = True
                        print('\rQuit requested after current experiment...')

        print('Completed {}'.format(self.trial_directory))
        if self.settings.should_save_models:
            torch.save(self.DNN.state_dict(),
                       os.path.join(self.trial_directory, 'DNN_model.pth'))
            torch.save(self.D.state_dict(),
                       os.path.join(self.trial_directory, 'D_model.pth'))
            torch.save(self.G.state_dict(),
                       os.path.join(self.trial_directory, 'G_model.pth'))
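
The `select.select([sys.stdin], [], [], 0)` pattern in the loop above is a zero-timeout poll: it returns immediately and only reports standard input as readable when input is already waiting, so the training loop is never blocked. A standalone sketch of the same pattern (POSIX only; `select` does not support regular file objects like `sys.stdin` on Windows):

import select
import sys
import time

while True:
    # A zero timeout makes this a non-blocking poll of standard input.
    if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
        if 'quit' in sys.stdin.readline():
            break
    time.sleep(1)  # Stand-in for one training step.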
Example #7
class Experiment(ABC):
    """A class to manage an experimental trial."""
    def __init__(self, settings: Settings):
        self.settings = settings
        self.trial_directory: str = None
        self.dnn_summary_writer: SummaryWriter = None
        self.gan_summary_writer: SummaryWriter = None
        self.train_dataset: Dataset = None
        self.train_dataset_loader: DataLoader = None
        self.unlabeled_dataset: Dataset = None
        self.unlabeled_dataset_loader: DataLoader = None
        self.validation_dataset: Dataset = None
        self.DNN: Module = None
        self.dnn_optimizer: Optimizer = None
        self.D: Module = None
        self.d_optimizer: Optimizer = None
        self.G: Module = None
        self.g_optimizer: Optimizer = None
        self.signal_quit = False

        self.labeled_features = None
        self.unlabeled_features = None
        self.fake_features = None
        self.interpolates_features = None

    def train(self):
        """
        Run the SRGAN training for the experiment.
        """
        self.trial_directory = os.path.join(self.settings.logs_directory,
                                            self.settings.trial_name)
        if (self.settings.skip_completed_experiment
                and os.path.exists(self.trial_directory)
                and '/check' not in self.trial_directory):
            print('`{}` experiment already exists. Skipping...'.format(
                self.trial_directory))
            return
        self.trial_directory = make_directory_name_unique(self.trial_directory)
        print(self.trial_directory)
        os.makedirs(
            os.path.join(self.trial_directory,
                         self.settings.temporary_directory))
        self.dnn_summary_writer = SummaryWriter(
            os.path.join(self.trial_directory, 'DNN'))
        self.gan_summary_writer = SummaryWriter(
            os.path.join(self.trial_directory, 'GAN'))
        self.dnn_summary_writer.summary_period = self.settings.summary_step_period
        self.gan_summary_writer.summary_period = self.settings.summary_step_period
        seed_all(0)

        self.dataset_setup()
        self.model_setup()
        self.load_models()

        d_lr = self.settings.learning_rate
        g_lr = d_lr
        # betas = (0.9, 0.999)
        weight_decay = 1e-2
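        # Weight decay is hard-coded here, and the same learning rate is
        # shared by the discriminator, the generator, and the DNN.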
        self.d_optimizer = Adam(self.D.parameters(),
                                lr=d_lr,
                                weight_decay=weight_decay)
        self.g_optimizer = Adam(self.G.parameters(), lr=g_lr)
        self.dnn_optimizer = Adam(self.DNN.parameters(),
                                  lr=d_lr,
                                  weight_decay=weight_decay)

        step_time_start = datetime.datetime.now()
        train_dataset_generator = infinite_iter(self.train_dataset_loader)
        unlabeled_dataset_generator = infinite_iter(
            self.unlabeled_dataset_loader)
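        # `infinite_iter` is assumed to cycle the loader endlessly, so the
        # `next()` calls below never raise StopIteration between epochs.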

        for step in range(self.settings.steps_to_run):
            # DNN.
            labeled_examples, labels = next(train_dataset_generator)
            labeled_examples, labels = labeled_examples.to(gpu), labels.to(gpu)
            self.dnn_training_step(labeled_examples, labels, step)
            # GAN.
            unlabeled_examples, _ = next(unlabeled_dataset_generator)
            unlabeled_examples = unlabeled_examples.to(gpu)
            self.gan_training_step(labeled_examples, labels,
                                   unlabeled_examples, step)

            if self.gan_summary_writer.is_summary_step():
                print('\rStep {}, {}...'.format(
                    step,
                    datetime.datetime.now() - step_time_start),
                      end='')
                step_time_start = datetime.datetime.now()
                self.eval_mode()
                self.validation_summaries(step)
                self.train_mode()
                while sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                    line = sys.stdin.readline()
                    if 'save' in line:
                        torch.save(
                            self.DNN.state_dict(),
                            os.path.join(self.trial_directory,
                                         'DNN_model_{}.pth'.format(step)))
                        torch.save(
                            self.D.state_dict(),
                            os.path.join(self.trial_directory,
                                         'D_model_{}.pth'.format(step)))
                        torch.save(
                            self.G.state_dict(),
                            os.path.join(self.trial_directory,
                                         'G_model_{}.pth'.format(step)))
                        print('\rSaved model for step {}...'.format(step))
                    if 'quit' in line:
                        self.signal_quit = True
                        print('\rQuit requested after current experiment...')

        print('Completed {}'.format(self.trial_directory))
        if self.settings.should_save_models:
            torch.save(self.DNN.state_dict(),
                       os.path.join(self.trial_directory, 'DNN_model.pth'))
            torch.save(self.D.state_dict(),
                       os.path.join(self.trial_directory, 'D_model.pth'))
            torch.save(self.G.state_dict(),
                       os.path.join(self.trial_directory, 'G_model.pth'))

    def train_mode(self):
        """
        Converts the networks to train mode, including moving them to the GPU.
        """
        self.D.train()
        self.DNN.train()
        self.G.train()
        self.D.to(gpu)
        self.DNN.to(gpu)
        self.G.to(gpu)

    def eval_mode(self):
        """
        Converts the networks to evaluation mode, including moving them to the CPU.
        """
        self.D.eval()
        self.DNN.eval()
        self.G.eval()
        self.D.to('cpu')
        self.DNN.to('cpu')
        self.G.to('cpu')

    def load_models(self):
        """
        Loads existing models from `self.settings.load_model_path` (if set),
        then moves the networks to the GPU.
        """
        if self.settings.load_model_path:
            if not torch.cuda.is_available():
                map_location = 'cpu'
            else:
                map_location = None
            self.DNN.load_state_dict(
                torch.load(
                    os.path.join(self.settings.load_model_path,
                                 'DNN_model.pth'), map_location))
            self.D.load_state_dict(
                torch.load(
                    os.path.join(self.settings.load_model_path, 'D_model.pth'),
                    map_location))
            self.G.load_state_dict(
                torch.load(
                    os.path.join(self.settings.load_model_path, 'G_model.pth'),
                    map_location))
        self.G = self.G.to(gpu)
        self.D = self.D.to(gpu)
        self.DNN = self.DNN.to(gpu)

    def dnn_training_step(self, examples, labels, step):
        """Runs an individual round of DNN training."""
        self.dnn_summary_writer.step = step
        self.dnn_optimizer.zero_grad()
        dnn_loss = self.dnn_loss_calculation(examples, labels)
        dnn_loss.backward()
        self.dnn_optimizer.step()
        # Summaries.
        if self.dnn_summary_writer.is_summary_step():
            self.dnn_summary_writer.add_scalar('Discriminator/Labeled Loss',
                                               dnn_loss.item())
            if self.DNN.features is not None:
                self.dnn_summary_writer.add_scalar(
                    'Feature Norm/Labeled',
                    self.DNN.features.norm(dim=1).mean().item())

    def gan_training_step(self, labeled_examples, labels, unlabeled_examples,
                          step):
        """Runs an individual round of GAN training."""
        # Labeled.
        self.gan_summary_writer.step = step
        self.d_optimizer.zero_grad()
        labeled_loss = self.labeled_loss_calculation(labeled_examples, labels)
        # Unlabeled.
        unlabeled_loss = self.unlabeled_loss_calculation(unlabeled_examples)
        # Fake.
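        # z is drawn from a two-component Gaussian mixture with unit-variance
        # components centered at -mean_offset and +mean_offset, rather than
        # from a single standard normal.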
        z = torch.tensor(
            MixtureModel([
                norm(-self.settings.mean_offset, 1),
                norm(self.settings.mean_offset, 1)
            ]).rvs(
                size=[unlabeled_examples.size(0), self.G.input_size]).astype(
                    np.float32)).to(gpu)
        fake_examples = self.G(z)
        fake_loss = self.fake_loss_calculation(fake_examples)
        # Gradient penalty.
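        # alpha is a single pair of random weights normalized to sum to 1 and
        # shared across the whole batch, unlike the per-example uniform weight
        # of the original WGAN-GP interpolation.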
        alpha = torch.rand(2, device=gpu)
        alpha = alpha / alpha.sum(0)
        interpolates = (
            alpha[0] * unlabeled_examples.detach().requires_grad_() +
            alpha[1] * fake_examples.detach().requires_grad_())
        interpolates_loss = self.interpolate_loss_calculation(interpolates)
        gradients = torch.autograd.grad(outputs=interpolates_loss,
                                        inputs=interpolates,
                                        grad_outputs=torch.ones_like(
                                            interpolates_loss, device=gpu),
                                        create_graph=True,
                                        only_inputs=True)[0]
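        # WGAN-GP-style penalty: lambda * E[(||grad D(x_hat)||_2 - 1)^2],
        # pushing the critic's gradient norm at the interpolates toward 1.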
        gradient_penalty = (
            (gradients.view(unlabeled_examples.size(0), -1).norm(dim=1) - 1)**
            2).mean() * self.settings.gradient_penalty_multiplier
        # Discriminator update.
        loss = labeled_loss + unlabeled_loss + fake_loss + gradient_penalty
        loss.backward()
        self.d_optimizer.step()
        # Generator (updated only every generator_training_step_period steps).
        if step % self.settings.generator_training_step_period == 0:
            self.g_optimizer.zero_grad()
            z = torch.randn(unlabeled_examples.size(0),
                            self.G.input_size).to(gpu)
            fake_examples = self.G(z)
            generator_loss = self.generator_loss_calculation(
                fake_examples, unlabeled_examples)
            generator_loss.backward()
            self.g_optimizer.step()
            if self.gan_summary_writer.is_summary_step():
                self.gan_summary_writer.add_scalar('Generator/Loss',
                                                   generator_loss.item())
        # Summaries.
        if self.gan_summary_writer.is_summary_step():
            self.gan_summary_writer.add_scalar('Discriminator/Labeled Loss',
                                               labeled_loss.item())
            self.gan_summary_writer.add_scalar('Discriminator/Unlabeled Loss',
                                               unlabeled_loss.item())
            self.gan_summary_writer.add_scalar('Discriminator/Fake Loss',
                                               fake_loss.item())
            if self.labeled_features is not None:
                self.gan_summary_writer.add_scalar(
                    'Feature Norm/Labeled',
                    self.labeled_features.mean(0).norm().item())
                self.gan_summary_writer.add_scalar(
                    'Feature Norm/Unlabeled',
                    self.unlabeled_features.mean(0).norm().item())

    def dnn_loss_calculation(self, labeled_examples, labels):
        """Calculates the DNN loss."""
        predicted_labels = self.DNN(labeled_examples)
        labeled_loss = self.labeled_loss_function(
            predicted_labels, labels, order=self.settings.labeled_loss_order
        ) * self.settings.labeled_loss_multiplier
        return labeled_loss

    def labeled_loss_calculation(self, labeled_examples, labels):
        """Calculates the labeled loss."""
        predicted_labels = self.D(labeled_examples)
        self.labeled_features = self.D.features
        labeled_loss = self.labeled_loss_function(
            predicted_labels, labels, order=self.settings.labeled_loss_order
        ) * self.settings.labeled_loss_multiplier
        return labeled_loss

    def unlabeled_loss_calculation(self, unlabeled_examples):
        """Calculates the unlabeled loss."""
        _ = self.D(unlabeled_examples)
        self.unlabeled_features = self.D.features
        unlabeled_loss = feature_distance_loss(
            self.unlabeled_features,
            self.labeled_features,
            order=self.settings.unlabeled_loss_order
        ) * self.settings.unlabeled_loss_multiplier
        return unlabeled_loss

    def fake_loss_calculation(self, fake_examples):
        """Calculates the fake loss."""
        _ = self.D(fake_examples.detach())
        self.fake_features = self.D.features
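        # The feature distance is negated (.neg()), so minimizing this loss
        # pushes the fake feature statistics away from the unlabeled ones.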
        fake_loss = feature_distance_loss(
            self.unlabeled_features,
            self.fake_features,
            scale=self.settings.normalize_fake_loss,
            order=self.settings.fake_loss_order).neg(
            ) * self.settings.fake_loss_multiplier
        return fake_loss

    def interpolate_loss_calculation(self, interpolates):
        """Calculates the interpolate loss for use in the gradient penalty."""
        _ = self.D(interpolates)
        self.interpolates_features = self.D.features
        interpolates_loss = feature_distance_loss(
            self.unlabeled_features,
            self.interpolates_features,
            scale=self.settings.normalize_fake_loss,
            order=self.settings.fake_loss_order).neg(
            ) * self.settings.fake_loss_multiplier
        return interpolates_loss

    def generator_loss_calculation(self, fake_examples, unlabeled_examples):
        """Calculates the generator's loss."""
        _ = self.D(fake_examples)
        self.fake_features = self.D.features
        _ = self.D(unlabeled_examples)
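        # The unlabeled features are detached so they act as a fixed target;
        # the generator's gradient flows only through the fake examples.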
        detached_unlabeled_features = self.D.features.detach()
        generator_loss = feature_distance_loss(
            detached_unlabeled_features,
            self.fake_features,
            order=self.settings.generator_loss_order)
        return generator_loss

    @abstractmethod
    def dataset_setup(self):
        """Prepares all the datasets and loaders required for the application."""
        self.train_dataset = Dataset()
        self.unlabeled_dataset = Dataset()
        self.validation_dataset = Dataset()
        self.train_dataset_loader = DataLoader(self.train_dataset)
        self.unlabeled_dataset_loader = DataLoader(self.unlabeled_dataset)

    @abstractmethod
    def model_setup(self):
        """Prepares all the model architectures required for the application."""
        self.DNN = Module()
        self.D = Module()
        self.G = Module()

    @abstractmethod
    def validation_summaries(self, step: int):
        """Prepares the summaries that should be run for the given application."""
        pass

    @staticmethod
    def labeled_loss_function(predicted_labels, labels, order=2):
        """Calculate the loss from the label difference prediction."""
        return (predicted_labels - labels).abs().pow(order).mean()

    def evaluate(self):
        """Loads the trained models and prepares them for evaluation."""
        self.model_setup()
        self.load_models()
        self.eval_mode()
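
Since `Experiment` is abstract, a subclass must implement `dataset_setup`, `model_setup`, and `validation_summaries` before `train()` can run. A hypothetical minimal skeleton (the placeholder datasets and networks below mirror the abstract stubs and would need real implementations, and it is assumed here that `Settings` can be constructed with defaults):

class ToyExperiment(Experiment):
    """A hypothetical subclass showing the required overrides."""
    def dataset_setup(self):
        # Placeholder datasets; a real application loads actual data here.
        self.train_dataset = Dataset()
        self.unlabeled_dataset = Dataset()
        self.validation_dataset = Dataset()
        self.train_dataset_loader = DataLoader(self.train_dataset)
        self.unlabeled_dataset_loader = DataLoader(self.unlabeled_dataset)

    def model_setup(self):
        # Placeholder networks; real ones must expose a `features` attribute,
        # as the loss calculations above expect.
        self.DNN = Module()
        self.D = Module()
        self.G = Module()

    def validation_summaries(self, step: int):
        pass


# experiment = ToyExperiment(Settings())
# experiment.train()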