Beispiel #1
0
def linear_alkanes(max_carbon=10):
    """Build a toy dataset of linear alkanes with 1 to `max_carbon` carbons.

    Each molecule is created from a SMILES string made of repeated ``"C"``
    characters, and every point is then annotated with
    ``y = len(point.smiles)``.

    Parameters
    ----------
    max_carbon : int
        Maximum number of carbons in the molecules generated.

    Returns
    -------
    Dataset
        A dataset containing the alkanes, one point per chain length.

    Examples
    --------
    >>> dataset = linear_alkanes(10)

    """
    points = [Point("C" * n_carbon) for n_carbon in range(1, max_carbon + 1)]

    def annotate(point):
        # The SMILES here is a run of "C" characters, so its length is
        # used as the label.
        point.y = len(point.smiles)
        return point

    return Dataset(points).apply(annotate)
Beispiel #2
0
def test_dataset_subtraction():
    """Check that subtracting a one-point dataset from a two-point one leaves one point."""
    from malt.data.dataset import Dataset
    from malt.point import Point

    shared = Point("C")
    extra = Point("CC")

    smaller = Dataset([shared])
    larger = Dataset([shared, extra])

    difference = larger - smaller
    assert len(difference) == 1
Beispiel #3
0
def test_split_dataset():
    """Check that a 50/50 split of a two-point dataset gives two one-point parts."""
    from malt.data.dataset import Dataset
    from malt.point import Point

    first = Point("C")
    second = Point("CC")
    full = Dataset([first, second])

    left_part, right_part = full.split([0.5, 0.5])

    assert len(left_part) == 1
    assert len(right_part) == 1
Beispiel #4
0
def test_dataset_view_batch_of_g():
    """Check that the ``"batch_of_g"`` view yields one graph holding all nodes of the batch."""
    import dgl
    from malt.data.dataset import Dataset
    from malt.point import Point

    first = Point("C", y=0.0)
    second = Point("CC", y=0.0)
    points = Dataset([first, second])

    loader = points.view(batch_size=2, collate_fn="batch_of_g")
    batched = next(iter(loader))

    assert isinstance(batched, dgl.DGLGraph)
    # The batched graph should contain the nodes of both individual graphs.
    assert (
        batched.number_of_nodes()
        == first.g.number_of_nodes() + second.g.number_of_nodes()
    )
Beispiel #5
0
def test_dataset_view():
    """Check that the default view is a DataLoader yielding a graph and a (2, 1)-leading label tensor."""
    import torch
    import dgl
    from malt.data.dataset import Dataset
    from malt.point import Point

    first = Point("C", y=0.0)
    second = Point("CC", y=0.0)
    points = Dataset([first, second])

    loader = points.view(batch_size=2)
    assert isinstance(loader, torch.utils.data.DataLoader)

    graph, labels = next(iter(loader))
    assert isinstance(graph, dgl.DGLGraph)
    assert isinstance(labels, torch.Tensor)

    # One row per point in the batch, one label column.
    assert labels.shape[0] == 2
    assert labels.shape[1] == 1
Beispiel #6
0
def test_build_dataset():
    """Check that a dataset built from two points has length two and indexes to the first point."""
    from malt.data.dataset import Dataset
    from malt.point import Point

    first = Point("C")
    second = Point("CC")

    built = Dataset([first, second])

    assert len(built) == 2
    assert built[0] == first
Beispiel #7
0
 def step(self):
     """Run one round: prioritize, merchandize, assay, then train.

     Returns
     -------
     The value returned by ``self.assay`` for the selected candidate, or
     ``None`` when ``self.prioritize()`` returns ``None``.
     """
     candidate = self.prioritize()
     if candidate is None:
         # Nothing was selected, so no order, assay, or training happens.
         return None

     # The candidate is wrapped in a single-element Dataset before being
     # passed through merchandize and assay.
     ordered = self.merchandize(Dataset([candidate]))
     assayed = self.assay(ordered)
     self.train()
     return assayed
Beispiel #8
0
def _dataset_from_dgllife(dgllife_dataset):
    """Convert an iterable of ``(smiles, g, y)`` triples into a ``Dataset``.

    Parameters
    ----------
    dgllife_dataset : iterable
        Yields ``(smiles, g, y)`` triples. ``y`` must provide ``.item()``
        (presumably a single-element tensor -- confirm against callers).

    Returns
    -------
    Dataset
        One ``Point`` per triple, in iteration order. Each point records
        its zero-based position in the source under ``extra["idx"]``.
    """
    return Dataset(
        [
            Point(smiles, g, y.item(), extra={"idx": idx})
            for idx, (smiles, g, y) in enumerate(dgllife_dataset)
        ]
    )
Beispiel #9
0
 def __init__(
     self,
     model: SupervisedModel,
     policy: Callable,
     trainer: Callable,
     merchant: Merchant,
     assayer: Assayer,
     portfolio: Union[Dataset, None] = None,
 ):
     """Store the player's collaborators and its starting portfolio.

     Parameters
     ----------
     model : SupervisedModel
         Stored as ``self.model``.
     policy : Callable
         Stored as ``self.policy``.
     trainer : Callable
         Stored as ``self.trainer``.
     merchant : Merchant
         Stored as ``self.merchant``.
     assayer : Assayer
         Stored as ``self.assayer``.
     portfolio : Union[Dataset, None]
         Initial portfolio; an empty ``Dataset`` is created when ``None``.
     """
     super(ModelBasedPlayer, self).__init__()
     self.model = model
     self.policy = policy
     self.trainer = trainer
     self.merchant = merchant
     self.assayer = assayer
     # A fresh empty Dataset is built per instance when none is supplied.
     self.portfolio = Dataset([]) if portfolio is None else portfolio
Beispiel #10
0
 def __init__(
     self,
     dataset: Dataset,
 ):
     """Keep a clone of `dataset` with ``erase_annotation()`` applied.

     Parameters
     ----------
     dataset : Dataset
         Source dataset. ``clone()`` is called on it first and
         ``erase_annotation()`` is called on the clone (presumably so the
         caller's dataset keeps its annotations -- confirm in ``Dataset``).
     """
     super(DatasetMerchant, self).__init__()
     cloned = dataset.clone()
     self.dataset = cloned.erase_annotation()