def test_ppo_value_baseline(self):
    """Train PPO with ``phi_key='value-baseline'`` and check actor and critic both changed.

    The ``phi_key`` override is applied to a per-test copy of the module-level
    ``trainer_base_config`` so that it does not leak into other tests.
    """
    before_check_actor = get_checksum_network_parameters(self.network.get_actor_parameters())
    before_check_critic = get_checksum_network_parameters(self.network.get_critic_parameters())
    # Copy instead of assigning into the shared dict: mutating it would change
    # phi_key for every test that runs after this one (test-order dependence).
    config_dict = {**trainer_base_config, 'phi_key': 'value-baseline'}
    trainer = ProximatePolicyGradient(config=TrainerConfig().create(config_dict=config_dict),
                                      network=self.network)
    trainer.train()
    # Training must have updated both the actor and the critic parameters.
    self.assertNotEqual(get_checksum_network_parameters(self.network.get_actor_parameters()),
                        before_check_actor)
    self.assertNotEqual(get_checksum_network_parameters(self.network.get_critic_parameters()),
                        before_check_critic)
 def get_checksum(self):
     return get_checksum_network_parameters(self.parameters())
    def test_discriminator(self):
        """Check alternating main/discriminator training and a checkpoint round-trip.

        Builds an ``auto_encoder_deeply_supervised_with_discriminator`` network and
        verifies that:

        * ``_train_main_network`` changes the deeply-supervised parameters and
          leaves the discriminator parameters unchanged;
        * ``_train_discriminator_network`` changes the discriminator parameters and
          leaves the deeply-supervised parameters unchanged;
        * a second network built with ``random_seed=29078`` starts with a different
          checksum, and after ``load_checkpoint`` every named parameter matches the
          trained network exactly.

        Both networks are removed even if an assertion fails.
        """
        # create dataset
        dataset = generate_dataset_by_length(
            length=5,
            input_size=(1, 200, 200),
            output_size=(200, 200),
        )

        def build_network(random_seed=None):
            """Build the discriminator architecture described by ``self.base_config``."""
            self.base_config['architecture'] = 'auto_encoder_deeply_supervised_with_discriminator'
            self.base_config['initialisation_type'] = 'xavier'
            if random_seed is not None:
                self.base_config['random_seed'] = random_seed
            # eval() resolves the architecture module by name; the name is the
            # hard-coded constant above, not external input.
            return eval(self.base_config['architecture']).Net(
                config=ArchitectureConfig().create(config_dict=self.base_config))

        network = build_network()
        second_network = None

        def main_checksum():
            """Checksum of the trained network's deeply-supervised parameters."""
            return get_checksum_network_parameters(network.deeply_supervised_parameters())

        def discriminator_checksum():
            """Checksum of the trained network's discriminator parameters."""
            return get_checksum_network_parameters(network.discriminator_parameters())

        try:
            # create trainer
            output_path = os.path.join(self.output_dir, 'discriminator')
            os.makedirs(output_path, exist_ok=True)
            trainer_config = {
                'output_path': output_path,
                'optimizer': 'Adam',
                'learning_rate': 0.1,
                'factory_key': 'DeepSupervisionWithDiscriminator',
                'data_loader_config': {
                    'batch_size': 2
                },
                'target_data_loader_config': {
                    'batch_size': 2
                },
                'criterion': 'WeightedBinaryCrossEntropyLoss',
                "criterion_args_str": 'beta=0.9',
            }
            trainer = TrainerFactory().create(
                config=TrainerConfig().create(config_dict=trainer_config),
                network=network)
            trainer.data_loader.set_dataset(dataset)
            trainer.target_data_loader.set_dataset(dataset)

            # Training the main network must leave the discriminator untouched.
            initial_main_checksum = main_checksum()
            initial_discriminator_checksum = discriminator_checksum()
            trainer._train_main_network()
            self.assertEqual(initial_discriminator_checksum, discriminator_checksum())
            self.assertNotEqual(initial_main_checksum, main_checksum())

            # Training the discriminator must leave the main network untouched.
            initial_discriminator_checksum = discriminator_checksum()
            initial_main_checksum = main_checksum()
            trainer._train_discriminator_network()
            self.assertEqual(initial_main_checksum, main_checksum())
            self.assertNotEqual(initial_discriminator_checksum, discriminator_checksum())

            # A network built with a different seed starts out different...
            second_network = build_network(random_seed=29078)
            self.assertNotEqual(second_network.get_checksum(), network.get_checksum())
            # ...and must match the trained network exactly after loading its checkpoint.
            second_network.load_checkpoint(network.get_checkpoint())
            # zip() would silently stop at the shorter sequence, so compare the
            # parameter names first.
            self.assertEqual([name for name, _ in network.named_parameters()],
                             [name for name, _ in second_network.named_parameters()])
            for (name, param), (_, second_param) in zip(network.named_parameters(),
                                                       second_network.named_parameters()):
                # Element-wise check: equal sums do not imply equal tensors.
                self.assertTrue(param.eq(second_param).all(),
                                msg=f'parameter {name} differs after load_checkpoint')
        finally:
            # Release both networks even when an assertion above failed.
            try:
                network.remove()
            finally:
                if second_network is not None:
                    second_network.remove()