示例#1
0
 def __init__(self, episode_len, **kwargs):
     """Store the episode settings, then run the ``baseVAE`` initializer.

     Args:
         episode_len: value kept on the instance as ``self.episode_len``.
         **kwargs: forwarded untouched to ``baseVAE.__init__``.
     """
     # Both attributes are assigned before the base initializer runs.
     self.input_scale = 1
     self.episode_len = episode_len
     baseVAE.__init__(self, **kwargs)
示例#2
0
    def __init__(self, strategy=None, **kwargs):
        """Initialize the autoencoder and the adversarial model registry.

        Args:
            strategy: optional object stored as ``self.strategy``.
            **kwargs: forwarded untouched to ``autoencoder.__init__``.
        """
        self.strategy = strategy
        autoencoder.__init__(self, **kwargs)

        # self.batch_size is read here, so it has to exist once the base
        # initializer has returned.
        target_shape = [self.batch_size, 1]
        self.ONES = tf.ones(shape=target_shape)
        self.ZEROS = tf.zeros(shape=target_shape)

        # (registry key, name of the model it is paired with, target value)
        specs = (
            ('inference_discriminator_real', 'generative', self.ONES),
            ('inference_discriminator_fake', 'generative', self.ZEROS),
            ('inference_generator_fake', 'generative', self.ONES),
            ('generative_discriminator_real', 'inference_mean', self.ONES),
            ('generative_discriminator_fake', 'inference_mean', self.ZEROS),
            ('generative_generator_fake', 'inference_mean', self.ONES),
        )

        # 'variable' starts out empty for every entry.
        self.adversarial_models = {
            name: {
                'variable': None,
                'adversarial_item': item,
                'adversarial_value': value,
            }
            for name, item, value in specs
        }
示例#3
0
    def fit(self, x, validation_data=None, **kwargs):
        """Train the autoencoder, each adversarial model, then all jointly.

        Stages, in order:
          1. fit the plain autoencoder through ``autoencoder.fit``;
          2. clone one model per ``self.adversarial_models`` entry;
          3. keep a copy of ``get_variables`` as ``self.ae_get_variables``;
          4. compile the discriminators;
          5. fit every cloned model on its own;
          6. rebuild the combined models through ``__models_init__``;
          7. fit ``self._AA`` with all adversarial targets.

        Args:
            x: training dataset; it must provide ``map`` because it is
                re-mapped through ``self.create_batch_cast``.
            validation_data: optional dataset handled like ``x``; ``None``
                means no validation data is passed on.
            **kwargs: forwarded to the underlying ``fit`` calls. ``verbose``
                and ``callbacks`` are withheld from stage 5 (which uses
                ``verbose=1`` and a fresh ``EarlyStopping``) and restored for
                stage 7. ``input_kw`` only reaches stage 1. All three keys
                are optional.
        """
        print()
        print(f'training {autoencoder}')
        # 1- train the basic autoencoder
        autoencoder.fit(self, x=x, validation_data=validation_data, **kwargs)

        def create_discriminator():
            # Pass every current model through layer_stuffing, then clone the
            # model each adversarial entry names in 'adversarial_item'.
            for model in self.get_variables().values():
                layer_stuffing(model)

            for k, model in self.adversarial_models.items():
                model['variable'] = clone_model(
                    old_model=self.get_variables()[model['adversarial_item']],
                    new_name=k,
                    restore=self.filepath)

        # 2- create the discriminators
        if self.strategy:
            with self.strategy:
                create_discriminator()
        else:
            create_discriminator()

        # 3- clone autoencoder variables
        self.ae_get_variables = copy_fn(self.get_variables)

        # 4- switch to discriminate
        # NOTE(review): unlike step 2, this call has never been made inside
        # `with self.strategy:`; the previous duplicated `if self.strategy:`
        # ran the same call on every path. Confirm whether entering the
        # strategy context here is wanted (and safe to re-enter).
        self.discriminators_compile()

        # Withhold the caller's verbose/callbacks while the discriminators are
        # trained alone. Only keys that were actually supplied are saved, so
        # omitting them no longer raises KeyError and nothing new is injected
        # into the final fit call.
        withheld = {
            key: kwargs.pop(key)
            for key in ('verbose', 'callbacks') if key in kwargs
        }
        callbacks = withheld.get('callbacks')
        kwargs.pop('input_kw', None)

        for k, model in self.adversarial_models.items():
            print()
            print(f'training {k}')
            # 5- train this discriminator on its own
            model['variable'].fit(
                x=x.map(self.create_batch_cast({k: model})),
                validation_data=None if validation_data is None else
                validation_data.map(self.create_batch_cast({k: model})),
                callbacks=[EarlyStopping()],
                verbose=1,
                **kwargs)

        kwargs.update(withheld)

        # 6- connect all for inference_adversarial training
        # NOTE(review): same remark as step 4 regarding the strategy context.
        self.__models_init__()

        print()
        print('training adversarial models')
        csv_loggers = [
            cb for cb in callbacks or []
            if isinstance(cb, tf.keras.callbacks.CSVLogger)
        ]
        for cb in csv_loggers:
            # Send the joint-training log to a separate file and set the
            # column names to match the outputs of the combined model.
            cb.filename = cb.filename.split('.csv')[0] + '_together.csv'
            metric_names = [
                k + '_' + fn.__name__
                for k, fns in self.ae_metrics.items()
                for fn in fns
            ]
            loss_names = [name + '_loss' for name in self._AA.output_names]
            cb.keys = ['loss'] + loss_names + metric_names
            cb.append_header = cb.keys

        # 7- training together
        self._AA.fit(
            x=x.map(self.create_batch_cast(self.adversarial_models)),
            validation_data=None if validation_data is None else
            validation_data.map(
                self.create_batch_cast(self.adversarial_models)),
            **kwargs
        )
示例#4
0
 def compile(self, adversarial_losses, adversarial_weights, **kwargs):
     """Record the adversarial settings, then compile the autoencoder.

     Args:
         adversarial_losses: stored as ``self.adversarial_losses``.
         adversarial_weights: stored as ``self.adversarial_weights``.
         **kwargs: forwarded untouched to ``autoencoder.compile``.
     """
     # Both attributes are assigned before the base compile runs.
     self.adversarial_losses, self.adversarial_weights = (
         adversarial_losses,
         adversarial_weights,
     )
     autoencoder.compile(self, **kwargs)