def __init__(self, strategy=None, **kwargs):
    """Initialise the base autoencoder and register the adversarial models.

    Args:
        strategy: optional object stored on ``self.strategy``. ``fit`` uses
            it directly in a ``with`` statement when it is truthy.
        **kwargs: forwarded unchanged to ``autoencoder.__init__``.
    """
    self.strategy = strategy
    autoencoder.__init__(self, **kwargs)

    # Constant target columns, one row per sample of a batch.
    # NOTE(review): ``self.batch_size`` is presumably set by
    # ``autoencoder.__init__`` -- confirm.
    self.ONES = tf.ones(shape=[self.batch_size, 1])
    self.ZEROS = tf.zeros(shape=[self.batch_size, 1])

    # (registry key, target tensor) for each adversarial model.
    targets = (
        ('generative_discriminator_real', self.ONES),
        ('generative_discriminator_fake', self.ZEROS),
        ('generative_generator_fake', self.ONES),
    )
    # 'variable' stays None until ``fit`` stores a cloned model there;
    # 'adversarial_item' names the entry of ``self.get_variables()`` that
    # ``fit`` clones.
    self.adversarial_models = {
        name: {
            'variable': None,
            'adversarial_item': 'inference',
            'adversarial_value': target,
        }
        for name, target in targets
    }
def compile(self, adversarial_losses, adversarial_weights, **kwargs):
    """Record the adversarial losses/weights, then compile the base model.

    Args:
        adversarial_losses: stored as ``self.adversarial_losses``.
        adversarial_weights: stored as ``self.adversarial_weights``.
        **kwargs: forwarded unchanged to ``autoencoder.compile``.
    """
    # Stored before the base compile runs, so the attributes already exist
    # if the base implementation (or anything it calls) reads them.
    self.adversarial_losses, self.adversarial_weights = (
        adversarial_losses,
        adversarial_weights,
    )
    autoencoder.compile(self, **kwargs)
def fit(self, x, validation_data=None, **kwargs):
    """Train the autoencoder, each adversarial model alone, then all jointly.

    The numbered comments in the body mark the stages:

    1. fit the base autoencoder;
    2. clone one model per entry of ``self.adversarial_models``;
    3. snapshot ``self.get_variables`` with ``copy_fn``;
    4. compile the discriminators;
    5. fit every cloned model on its own;
    6. rebuild the combined models via ``self.__models_init__``;
    7. fit ``self._AA`` on all adversarial targets together.

    Args:
        x: training data; must provide ``.map`` (presumably a
            ``tf.data.Dataset`` -- confirm).
        validation_data: optional validation data with the same ``.map``
            interface, or ``None``.
        **kwargs: forwarded to the ``fit`` calls. ``verbose`` and
            ``callbacks`` are withheld from the stage-5 fits (which use
            ``verbose=1`` and a fresh ``EarlyStopping``) and restored for
            stage 7. ``input_kw`` only reaches the base autoencoder fit.
            All three are optional.
    """
    print()
    print(f'training {autoencoder}')
    # 1- train the basic basicAE
    autoencoder.fit(self, x=x, validation_data=validation_data, **kwargs)

    def create_discriminator():
        for model in self.get_variables().values():
            layer_stuffing(model)
        for k, model in self.adversarial_models.items():
            model['variable'] = clone_model(
                old_model=self.get_variables()[model['adversarial_item']],
                new_name=k,
                restore=self.filepath)

    # 2- create a latents discriminator
    # NOTE(review): ``self.strategy`` is used directly as a context manager,
    # so the caller presumably passes a scope object rather than a bare
    # ``tf.distribute.Strategy`` -- confirm.
    if self.strategy:
        with self.strategy:
            create_discriminator()
    else:
        create_discriminator()

    # 3- clone autoencoder variables
    self.ae_get_variables = copy_fn(self.get_variables)

    # 4- switch to discriminate
    # NOTE(review): this used to be wrapped in
    # ``if self.strategy: if self.strategy: ... else: ...``, which ran the
    # same call on every path. If the call was meant to run inside the
    # strategy context (as stage 2 does), that needs a deliberate change.
    self.discriminators_compile()

    # Withhold ``verbose``/``callbacks`` from the stage-5 fits, which supply
    # their own. Only keys the caller actually passed are removed and later
    # restored, so omitting them no longer raises KeyError.
    withheld = {
        key: kwargs.pop(key)
        for key in ('verbose', 'callbacks')
        if key in kwargs
    }
    callbacks = withheld.get('callbacks')
    # ``input_kw`` is not forwarded to any fit call after stage 1.
    kwargs.pop('input_kw', None)

    for k, model in self.adversarial_models.items():
        print()
        print(f'training {k}')
        # 5- train the latents discriminator
        model['variable'].fit(
            x=x.map(self.create_batch_cast({k: model})),
            validation_data=(
                None if validation_data is None
                else validation_data.map(self.create_batch_cast({k: model}))
            ),
            callbacks=[EarlyStopping()],
            verbose=1,
            **kwargs)

    kwargs.update(withheld)

    # 6- connect all for inference_adversarial training
    # NOTE(review): same redundant strategy check as stage 4 was removed
    # here; the call has always been unconditional.
    self.__models_init__()

    print()
    print('training adversarial models')
    csv_loggers = [
        cb for cb in callbacks or []
        if isinstance(cb, tf.keras.callbacks.CSVLogger)
    ]
    if csv_loggers:
        # Header columns are the same for every logger: build them once.
        metric_names = [
            k + '_' + fn.__name__
            for k, v in self.ae_metrics.items()
            for fn in v
        ]
        header = (
            ['loss']
            + [name + '_loss' for name in self._AA.output_names]
            + metric_names
        )
        for cb in csv_loggers:
            # Write the joint-training log next to the original CSV file.
            cb.filename = cb.filename.split('.csv')[0] + '_together.csv'
            cb.keys = list(header)
            cb.append_header = cb.keys

    # 7- training together
    self._AA.fit(
        x=x.map(self.create_batch_cast(self.adversarial_models)),
        validation_data=(
            None if validation_data is None
            else validation_data.map(
                self.create_batch_cast(self.adversarial_models))
        ),
        **kwargs
    )