Example #1
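# Imports assumed by this snippet; get_dataset is a helper defined elsewhere
# in the same module and is not shown here.
import numpy as np

from schnetpack.data import AtomsLoader
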
def prepare_data(_seed, property_map, batch_size, num_train, num_val,
                 num_workers):
    """
    Create the dataloaders for training.

    Args:
        _seed (int): seed for controlled randomness
        property_map (dict): mapping between model properties and dataset
            properties
        batch_size (int): batch size
        num_train (int or float): number of training samples, or fraction of
            the dataset if smaller than 1
        num_val (int or float): number of validation samples, or fraction of
            the dataset if smaller than 1
        num_workers (int): number of workers for the dataloaders

    Returns:
        schnetpack.data.AtomsLoader objects for training, validation and
        testing, and the atomic reference data
    """
    # local seed
    np.random.seed(_seed)

    # load and split
    data = get_dataset(dataset_properties=property_map.values())

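    # interpret split sizes below 1 as fractions of the full dataset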
    if num_train < 1:
        num_train = int(num_train * len(data))
    if num_val < 1:
        num_val = int(num_val * len(data))

    train, val, test = data.create_splits(num_train, num_val)

    train_loader = AtomsLoader(train,
                               batch_size,
                               shuffle=True,
                               pin_memory=True,
                               num_workers=num_workers)
    val_loader = AtomsLoader(val,
                             batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=num_workers)
    test_loader = AtomsLoader(test,
                              batch_size,
                              shuffle=False,
                              pin_memory=True,
                              num_workers=num_workers)

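    # atomic reference values (per-element contributions) for each property
    # that maps onto a dataset target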
    atomrefs = {
        p: data.get_atomref(tgt)
        for p, tgt in property_map.items() if tgt is not None
    }

    return train_loader, val_loader, test_loader, atomrefs
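
A minimal usage sketch, not part of the original listing; the property map, seed and split sizes below are illustrative assumptions rather than values from the source:

property_map = {"energy": "energy", "forces": None}  # hypothetical mapping
train_loader, val_loader, test_loader, atomrefs = prepare_data(
    _seed=42,
    property_map=property_map,
    batch_size=32,
    num_train=0.8,  # below 1: fraction of the dataset
    num_val=0.1,
    num_workers=2,
)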
Example #2
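# Imports assumed by this snippet; the helper functions (create_dirs,
# save_config, get_property_map, get_dataset, build_dataloaders, stats,
# build_model, setup_trainer) are defined elsewhere in the same module.
import os

import numpy as np
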
def train(_log, _config, model_dir, properties, additional_outputs, device):
    """
    Build a trainer from the configuration and start the training.

    Args:
        _log (Logger): logger used for status messages
        _config (dict): configuration dictionary
        model_dir (str): path to the training directory
        properties (list): list of model properties
        additional_outputs (list): list of additional model properties that are
            not back-propagated
        device (str): device on which the calculations run (CPU/GPU)

    """
    create_dirs(_log=_log, output_dir=model_dir)
    save_config(_config=_config, output_dir=model_dir)
    property_map = get_property_map(properties)

    _log.info("Load data")
    dataset = get_dataset(dataset_properties=property_map.values())
    train_loader, val_loader, test_loader, atomrefs = build_dataloaders(
        property_map=property_map, dataset=dataset)
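    # persist the split indices and atomic references for reproducibility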
    np.savez(
        os.path.join(model_dir, "splits.npz"),
        train=train_loader.dataset.subset,
        val=val_loader.dataset.subset,
        test=test_loader.dataset.subset,
        atomrefs=atomrefs,
    )
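    # per-property mean and standard deviation used to standardize the model output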
    mean, stddev = stats(train_loader, atomrefs, property_map)

    _log.info("Build model")
    model_properties = [
        p for p, tgt in property_map.items() if tgt is not None
    ]
    model = build_model(
        mean=mean,
        stddev=stddev,
        atomrefs=atomrefs,
        model_properties=model_properties,
        additional_outputs=additional_outputs,
    ).to(device)
    _log.info("Setup training")
    trainer = setup_trainer(
        model=model,
        train_dir=model_dir,
        train_loader=train_loader,
        val_loader=val_loader,
        property_map=property_map,
    )
    _log.info("Training")
    trainer.train(device)
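
A hypothetical invocation sketch, assuming plain substitutes for the _log and _config arguments that the surrounding experiment framework would normally inject; the output directory and property list are made up for illustration:

import logging

import torch

logging.basicConfig(level=logging.INFO)
train(
    _log=logging.getLogger("train"),
    _config={},  # assumed: the full experiment configuration dictionary
    model_dir="./training_run",
    properties=["energy"],
    additional_outputs=[],
    device="cuda" if torch.cuda.is_available() else "cpu",
)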
Example #3
def download():
    get_dataset()
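
Run on its own, this command only calls get_dataset(), presumably so the dataset is downloaded and cached before any training run.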