コード例 #1
0
    def __init__(self,
                 samples: List[Sample],
                 config: DBCASplitterConfig = None):
        """
        Set up a splitter over `samples` with empty train and test sets.

        Args:
            samples (List[Sample]): Full set of samples to create splits from.
            config (DBCASplitterConfig, optional): Settings for split generation.
                A default-constructed `DBCASplitterConfig` is used when `config`
                is None (or otherwise falsy).
        """
        store = SampleStore(samples)
        self.sample_store = store
        self.full_sample_set = FullSampleSet(sample_store=store)

        # Every sample id starts with no split assigned, so all ids are unused.
        self.sample_splits = dict.fromkeys(self.full_sample_set.sample_ids)
        self.unused_sample_ids = set(self.sample_splits)

        self.config = config or DBCASplitterConfig()
        self.logger = logging.getLogger(__name__)
        self.train_set = SplitSampleSet(split="train")
        self.test_set = SplitSampleSet(split="test")

        # Seed NumPy's global random state for reproducibility.
        np.random.seed(self.config.seed)
コード例 #2
0
    def add_sample_to_set(self, sample_id: str, sample_set: SplitSampleSet):
        """
        Assign an unused sample to `sample_set` and record the assignment.

        Args:
            sample_id (str): Id of the sample to add; it must currently be in
                `self.unused_sample_ids`.
            sample_set (SplitSampleSet): Split set that receives the sample.

        Raises:
            KeyError: If `sample_id` is not in `self.unused_sample_ids`. No
                state is modified in that case.
        """
        # Validate before touching any state, so an unknown or already-used id
        # cannot overwrite an existing split assignment.
        if sample_id not in self.unused_sample_ids:
            raise KeyError(sample_id)

        split = sample_set.split_type

        # Update the split set before the bookkeeping: if the update raises,
        # the sample is still recorded as unused.
        sample_set.update(sample_id, self.full_sample_set, inplace=True)

        self.sample_splits[sample_id] = split
        self.unused_sample_ids.remove(sample_id)
コード例 #3
0
    def peek_sample(self, sample_id: str, sample_set_to_update: SplitSampleSet,
                    other_sample_set: SplitSampleSet) -> float:
        """
        Score the split that would result from adding `sample_id` to
        `sample_set_to_update`, requesting a non-in-place update
        (`inplace=False`).

        Args:
            sample_id (str): id of the sample to evaluate.
            sample_set_to_update (SplitSampleSet): Sample set the sample would join.
            other_sample_set (SplitSampleSet): The other sample set; its current
                distributions are used unchanged.

        Returns:
            float: Result of `self.score` for the hypothetical train/test
            atom and compound distributions.
        """
        candidate_atoms, candidate_compounds = sample_set_to_update.update(
            sample_id, self.full_sample_set, inplace=False)
        updating_train = sample_set_to_update.is_train

        # The set that is not being updated contributes its current distributions.
        fixed_atoms = other_sample_set.atom_distribution
        fixed_compounds = other_sample_set.compound_distribution

        # score() expects (train atoms, test atoms, train compounds, test compounds).
        if updating_train:
            return self.score(candidate_atoms, fixed_atoms,
                              candidate_compounds, fixed_compounds)
        return self.score(fixed_atoms, candidate_atoms,
                          fixed_compounds, candidate_compounds)
コード例 #4
0
    def __init__(self,
                 samples: List[Sample],
                 config: DBCASplitterConfig = None):
        """
        Create a splitter and make sure a Ray runtime is running.

        Args:
            samples (List[Sample]): Full set of samples to create splits from.
            config (DBCASplitterConfig, optional): Settings for split generation.
                A default-constructed `DBCASplitterConfig` is used when `config`
                is None (or otherwise falsy). Its `num_processes` value is
                passed to Ray as `num_cpus`.
        """
        super(DBCASplitterRay, self).__init__(samples, config)
        # NOTE(review): the state below is rebuilt after the parent initializer
        # has already run; presumably the parent sets up the same attributes --
        # confirm against the parent class, then drop the duplication.
        self.sample_store = SampleStore(samples)
        self.full_sample_set = FullSampleSet(sample_store=self.sample_store)
        self.sample_splits = {
            s_id: None
            for s_id in self.full_sample_set.sample_ids
        }
        self.unused_sample_ids = set(self.sample_splits.keys())

        self.config = config if config else DBCASplitterConfig()
        self.train_set = SplitSampleSet(split="train")
        self.test_set = SplitSampleSet(split="test")

        # set seed for reproducibility
        np.random.seed(self.config.seed)

        # ray.init() raises when Ray is already initialized in this process,
        # so only start it when needed; this allows creating more than one
        # splitter per process.
        if not ray.is_initialized():
            ray.init(num_cpus=self.config.num_processes)
コード例 #5
0
def peek_ray(sample_id: str, sample_set_to_update: SplitSampleSet,
             other_sample_set: SplitSampleSet,
             dbca_config: DBCASplitterConfig, full_sample_set: FullSampleSet) -> float:
    """
    Score the split that would result from adding `sample_id` to
    `sample_set_to_update`, requesting a non-in-place update (`inplace=False`).

    Args:
        sample_id (str): id of the sample to evaluate.
        sample_set_to_update (SplitSampleSet): Sample set the sample would join.
        other_sample_set (SplitSampleSet): The other sample set; its current
            distributions are used unchanged.
        dbca_config (DBCASplitterConfig): Configuration forwarded to `score`.
        full_sample_set (FullSampleSet): Full sample set passed to the update.

    Returns:
        float: Result of `score` for the hypothetical train/test distributions.
    """
    candidate_atoms, candidate_compounds = sample_set_to_update.update(
        sample_id, full_sample_set, inplace=False)

    # score() takes (train atoms, test atoms, train compounds, test compounds,
    # config); the candidate distributions fill whichever side is being updated.
    if not sample_set_to_update.is_train:
        return score(other_sample_set.atom_distribution, candidate_atoms,
                     other_sample_set.compound_distribution, candidate_compounds,
                     dbca_config)
    return score(candidate_atoms, other_sample_set.atom_distribution,
                 candidate_compounds, other_sample_set.compound_distribution,
                 dbca_config)
コード例 #6
0
def peek_ray(sample_id: str, sample_set_to_update: SplitSampleSet,
             other_sample_set: SplitSampleSet, dbca_config: DBCASplitterConfig,
             full_sample_set: FullSampleSet) -> float:
    """
    Compute the split score for a hypothetical addition of `sample_id` to
    `sample_set_to_update`, using a non-in-place update (`inplace=False`).

    Args:
        sample_id (str): id of the sample to evaluate.
        sample_set_to_update (SplitSampleSet): Sample set the sample would join.
        other_sample_set (SplitSampleSet): The other sample set; its current
            distributions are used unchanged.
        dbca_config (DBCASplitterConfig): Configuration forwarded to `score`.
        full_sample_set (FullSampleSet): Full sample set passed to the update.

    Returns:
        float: Result of `score` for the hypothetical train/test distributions.
    """
    new_atoms, new_compounds = sample_set_to_update.update(sample_id,
                                                           full_sample_set,
                                                           inplace=False)
    updating_train = sample_set_to_update.is_train

    # The set that is not being updated contributes its current distributions.
    kept_atoms = other_sample_set.atom_distribution
    kept_compounds = other_sample_set.compound_distribution

    if updating_train:
        train_atoms, test_atoms = new_atoms, kept_atoms
        train_compounds, test_compounds = new_compounds, kept_compounds
    else:
        train_atoms, test_atoms = kept_atoms, new_atoms
        train_compounds, test_compounds = kept_compounds, new_compounds

    return score(train_atoms, test_atoms, train_compounds, test_compounds,
                 dbca_config)