コード例 #1
0
def test_sample_valid_molecules_if_not_enough_valid_generated():
    """sample_valid_molecules returns what it could sample instead of raising.

    The mock generator yields 18 invalid SMILES followed by two valid ones
    ('CN' then 'CC'), so whether the valid molecules are reached depends on
    how many sampling attempts ``max_tries`` allows.
    """
    # sample_valid_molecules must not raise an exception if fewer valid
    # molecules than requested could be generated.
    molecules = ['invalid'] * 20
    molecules[-1] = 'CC'
    molecules[-2] = 'CN'
    generator = MockGenerator(molecules)

    # Samples a max of 9*2 molecules and so never reaches the valid ones;
    # in this case the list of generated molecules is empty.
    assert not sample_valid_molecules(generator, 2, max_tries=9)

    # Use a fresh generator so sampling starts again from the first molecule.
    # With a max of 10*2 molecules (the default max_tries; presumably 10 —
    # confirm against sample_valid_molecules) both valid molecules are reached.
    generator = MockGenerator(molecules)
    mols = sample_valid_molecules(generator, 2)
    assert mols == ['CN', 'CC']
コード例 #2
0
    def assess_model(self, model: DistributionMatchingGenerator) -> DistributionLearningBenchmarkResult:
        """Score ``model`` by the Frechet ChemNet Distance (FCD) to the reference set.

        Samples ``self.number_samples`` valid molecules from ``model``. It then
        computes distribution statistics (``mu``, ``cov``) for the reference and
        generated molecules with the loaded ChemNet. Finally it turns the Frechet
        distance between the two into a score via ``exp(-0.2 * FCD)``.

        Args:
            model: generator to sample molecules from.

        Returns:
            DistributionLearningBenchmarkResult containing the score and the
            elapsed sampling time. Its metadata holds the number of reference
            molecules, the number of generated molecules and the raw FCD.
        """
        chemnet = self._load_chemnet()

        # perf_counter is a monotonic clock, so the measured duration cannot
        # be skewed by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        generated_molecules = sample_valid_molecules(model=model, number_molecules=self.number_samples)
        end_time = time.perf_counter()

        # A shortfall is only reported; scoring continues with whatever was sampled.
        if len(generated_molecules) != self.number_samples:
            logger.warning('The model could not generate enough valid molecules.')

        mu_ref, cov_ref = self._calculate_distribution_statistics(chemnet, self.reference_molecules)
        mu, cov = self._calculate_distribution_statistics(chemnet, generated_molecules)

        # Not named `fcd`: that would shadow the fcd module used on this line.
        fcd_value = fcd.calculate_frechet_distance(mu1=mu_ref, mu2=mu,
                                                   sigma1=cov_ref, sigma2=cov)
        # Smaller distances give scores closer to 1 (a distance of 0 gives exactly 1).
        score = np.exp(-0.2 * fcd_value)

        metadata = {
            'number_reference_molecules': len(self.reference_molecules),
            'number_generated_molecules': len(generated_molecules),
            'FCD': fcd_value
        }

        return DistributionLearningBenchmarkResult(benchmark_name=self.name,
                                                   score=score,
                                                   sampling_time=end_time - start_time,
                                                   metadata=metadata)
コード例 #3
0
def test_sample_valid_molecules_with_invalid_molecules():
    """Invalid SMILES mixed among valid ones are skipped; valid ones keep their order."""
    smiles_stream = ['invalid', 'invalid', 'invalid', 'CCCC', 'invalid', 'CC']
    generator = MockGenerator(smiles_stream)

    sampled = sample_valid_molecules(generator, 2)

    expected = ['CCCC', 'CC']
    assert sampled == expected
コード例 #4
0
    def assess_model(
        self, model: DistributionMatchingGenerator
    ) -> DistributionLearningBenchmarkResult:
        """Score ``model`` by the fraction of unique molecules among its samples.

        Samples ``self.number_samples`` valid molecules from ``model`` and
        canonicalizes them without stereocenters. The score is the number of
        distinct canonical molecules divided by ``self.number_samples``.

        Args:
            model: generator to sample molecules from.

        Returns:
            DistributionLearningBenchmarkResult containing the uniqueness ratio
            and the elapsed sampling time. Its metadata holds the requested
            sample count and the number of unique molecules.
        """
        # perf_counter is a monotonic clock, so the measured duration cannot
        # be skewed by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        molecules = sample_valid_molecules(
            model=model, number_molecules=self.number_samples)
        end_time = time.perf_counter()

        if len(molecules) != self.number_samples:
            logger.warning(
                'The model could not generate enough valid molecules. The score will be penalized.'
            )

        # canonicalize_list removes duplicates (and invalid molecules, but there shouldn't be any)
        unique_molecules = canonicalize_list(molecules,
                                             include_stereocenters=False)

        # Divide by the requested count rather than len(molecules). A shortfall
        # of valid samples therefore lowers the score, which is the penalty
        # announced in the warning above.
        unique_ratio = len(unique_molecules) / self.number_samples
        metadata = {
            'number_samples': self.number_samples,
            'number_unique': len(unique_molecules)
        }

        return DistributionLearningBenchmarkResult(benchmark_name=self.name,
                                                   score=unique_ratio,
                                                   sampling_time=end_time -
                                                   start_time,
                                                   metadata=metadata)
コード例 #5
0
def test_sample_valid_molecules_for_valid_only():
    """When every generated SMILES is valid, all are returned in generation order."""
    expected = ['CCCC', 'CC']
    # Hand the mock its own copy so `expected` stays untouched.
    generator = MockGenerator(list(expected))

    assert sample_valid_molecules(generator, 2) == expected