def __init__(self, samples: List[Sample], config: DBCASplitterConfig = None):
    """
    Create new DBCASplitter.

    Args:
        samples (List[Sample]): Full set of samples to create splits from.
        config (DBCASplitterConfig, optional): Optional settings for split
            generation; a default-constructed config is used when omitted.
    """
    self.config = config if config else DBCASplitterConfig()
    self.logger = logging.getLogger(__name__)

    self.sample_store = SampleStore(samples)
    self.full_sample_set = FullSampleSet(sample_store=self.sample_store)

    # Maps sample id -> assigned split; None means not yet assigned.
    self.sample_splits = dict.fromkeys(self.full_sample_set.sample_ids)
    self.unused_sample_ids = set(self.sample_splits)

    self.train_set = SplitSampleSet(split="train")
    self.test_set = SplitSampleSet(split="test")

    # Set seed for reproducibility.
    np.random.seed(self.config.seed)
def add_sample_to_set(self, sample_id: str, sample_set: SplitSampleSet):
    """
    Add a sample to the given sample set and record its split assignment.

    Args:
        sample_id (str): id of the sample to add. (The original annotation
            said `Sample`, but ids are used as dict keys here and sibling
            methods such as `peek_sample` take ids as `str`.)
        sample_set (SplitSampleSet): Split set that receives the sample.

    Raises:
        KeyError: If `sample_id` has already been assigned to a split
            (it is no longer in `self.unused_sample_ids`).
    """
    split = sample_set.split_type
    self.sample_splits[sample_id] = split
    # A sample may only be assigned once; set.remove raises on a re-add.
    self.unused_sample_ids.remove(sample_id)
    sample_set.update(sample_id, self.full_sample_set, inplace=True)
def peek_sample(self, sample_id: str, sample_set_to_update: SplitSampleSet, other_sample_set: SplitSampleSet) -> float:
    """
    Check score for adding sample `sample_id` to `sample_set_to_update`
    without actually making the update (not in-place).

    Args:
        sample_id (str): id of sample to check update for.
        sample_set_to_update (SplitSampleSet): Sample set to be updated with
            chosen sample.
        other_sample_set (SplitSampleSet): The other sample set (not updated).

    Returns:
        float: Split score if we had added `sample_id` to
            `sample_set_to_update`.
    """
    # Hypothetical distributions of the updated set (computed out-of-place).
    new_a, new_c = sample_set_to_update.update(sample_id, self.full_sample_set, inplace=False)
    other_a = other_sample_set.atom_distribution
    other_c = other_sample_set.compound_distribution
    # score expects (train_atoms, test_atoms, train_compounds, test_compounds);
    # which side the new distributions land on depends on the updated set's split.
    if sample_set_to_update.is_train:
        return self.score(new_a, other_a, new_c, other_c)
    return self.score(other_a, new_a, other_c, new_c)
def __init__(self, samples: List[Sample], config: DBCASplitterConfig = None):
    """
    Create new DBCASplitterRay and start the Ray runtime.

    Args:
        samples (List[Sample]): Full set of samples to create splits from.
        config (DBCASplitterConfig, optional): Optional settings for split
            generation.
    """
    # The base-class __init__ already builds the sample store and sample
    # sets, records the config, creates the train/test split sets, sets up
    # the logger and seeds numpy; the original code duplicated all of that
    # work line-for-line here (and dropped the logger), so the duplication
    # is removed.
    super().__init__(samples, config)
    # NOTE(review): assumes Ray is not already initialized — ray.init()
    # raises on a second call unless ignore_reinit_error=True; confirm
    # against how callers construct this class.
    ray.init(num_cpus=self.config.num_processes)
def peek_ray(sample_id: str, sample_set_to_update: SplitSampleSet, other_sample_set: SplitSampleSet, dbca_config: DBCASplitterConfig, full_sample_set: FullSampleSet) -> float:
    """
    Score the split that would result from adding `sample_id` to
    `sample_set_to_update`, without mutating either set (not in-place).

    Args:
        sample_id (str): id of sample to check the update for.
        sample_set_to_update (SplitSampleSet): Set that would receive the sample.
        other_sample_set (SplitSampleSet): The other split set (left unchanged).
        dbca_config (DBCASplitterConfig): Settings forwarded to `score`.
        full_sample_set (FullSampleSet): Full sample set the distributions are
            computed against.

    Returns:
        float: Split score if `sample_id` were added to `sample_set_to_update`.
    """
    a_dist, c_dist = sample_set_to_update.update(sample_id, full_sample_set, inplace=False)
    # Arrange as (train_atoms, test_atoms, train_compounds, test_compounds)
    # depending on which split the updated set represents.
    if sample_set_to_update.is_train:
        dists = (a_dist, other_sample_set.atom_distribution,
                 c_dist, other_sample_set.compound_distribution)
    else:
        dists = (other_sample_set.atom_distribution, a_dist,
                 other_sample_set.compound_distribution, c_dist)
    return score(*dists, dbca_config)
def peek_ray(sample_id: str, sample_set_to_update: SplitSampleSet, other_sample_set: SplitSampleSet, dbca_config: DBCASplitterConfig, full_sample_set: FullSampleSet) -> float:
    """
    Compute the split score that would result from adding `sample_id` to
    `sample_set_to_update`, without modifying either set (not in-place).

    Args:
        sample_id (str): id of sample to check the update for.
        sample_set_to_update (SplitSampleSet): Set that would receive the sample.
        other_sample_set (SplitSampleSet): The other split set (left unchanged).
        dbca_config (DBCASplitterConfig): Settings forwarded to `score`.
        full_sample_set (FullSampleSet): Full sample set the distributions are
            computed against.

    Returns:
        float: Split score if `sample_id` were added to `sample_set_to_update`.
    """
    # Commented-out debug prints from the original were removed.
    a_dist, c_dist = sample_set_to_update.update(sample_id, full_sample_set, inplace=False)
    # `score` expects (train_atoms, test_atoms, train_compounds, test_compounds):
    # the hypothetical distributions belong to whichever split the updated set is.
    if sample_set_to_update.is_train:
        train_a_dist, train_c_dist = a_dist, c_dist
        test_a_dist = other_sample_set.atom_distribution
        test_c_dist = other_sample_set.compound_distribution
    else:
        test_a_dist, test_c_dist = a_dist, c_dist
        train_a_dist = other_sample_set.atom_distribution
        train_c_dist = other_sample_set.compound_distribution
    return score(train_a_dist, test_a_dist, train_c_dist, test_c_dist, dbca_config)