Example #1
    mean, std = train_loader.get_statistics('ip',
                                            divide_by_atoms=args.atomwise)
    model = build_fn(atom_features=args.atom_features,
                     message_steps=args.num_messages,
                     output_layers=args.output_layers,
                     reduce_fn=args.readout_fn,
                     atomwise=args.atomwise,
                     mean=mean['ip'],
                     std=std['ip'])

    # Train the model
    #  Following:
    init_learn_rate = 1e-4
    opt = optim.Adam(model.parameters(), lr=init_learn_rate)

    loss = trn.build_mse_loss(['ip'])
    metrics = [spk.metrics.MeanSquaredError('ip')]
    hooks = [
        trn.CSVHook(log_path=test_dir, metrics=metrics),
        trn.ReduceLROnPlateauHook(opt,
                                  patience=args.num_epochs // 8,
                                  factor=0.8,
                                  min_lr=1e-6,
                                  stop_after_min=True)
    ]

    trainer = trn.Trainer(
        model_path=test_dir,
        model=model,
        hooks=hooks,
        loss_fn=loss,
        optimizer=opt,
        train_loader=train_loader,
        validation_loader=valid_loader,  # assumed: a validation AtomsLoader built like train_loader
    )
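Note on the schedule: ReduceLROnPlateauHook multiplies the learning rate by factor (0.8) whenever the validation loss has stalled for patience epochs, and stop_after_min=True ends training once the rate falls below min_lr. A quick sketch (an illustration only, not part of the original snippet) of how many reductions the settings above permit:

import math

init_lr, min_lr, factor = 1e-4, 1e-6, 0.8

# init_lr * factor**n <= min_lr  =>  n >= log(min_lr / init_lr) / log(factor)
n_reductions = math.ceil(math.log(min_lr / init_lr) / math.log(factor))
print(n_reductions)  # 21 plateau events before stop_after_min can trigger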
Example #2
    mean, std = train_loader.get_statistics('delta',
                                            divide_by_atoms=args.atomwise)
    model = build_fn(atom_features=args.atom_features,
                     message_steps=args.num_messages,
                     output_layers=args.output_layers,
                     reduce_fn=args.readout_fn,
                     atomwise=args.atomwise,
                     mean=mean['delta'],
                     std=std['delta'])

    # Train the model
    #  Following:
    init_learn_rate = 1e-4
    opt = optim.Adam(model.parameters(), lr=init_learn_rate)

    loss = trn.build_mse_loss(['delta'])
    metrics = [spk.metrics.MeanSquaredError('delta')]
    hooks = [
        trn.CSVHook(log_path=test_dir, metrics=metrics),
        trn.ReduceLROnPlateauHook(opt,
                                  patience=args.num_epochs // 8,
                                  factor=0.8,
                                  min_lr=1e-6,
                                  stop_after_min=True)
    ]

    trainer = trn.Trainer(
        model_path=test_dir,
        model=model,
        hooks=hooks,
        loss_fn=loss,
        optimizer=opt,
        train_loader=train_loader,
        validation_loader=valid_loader,  # assumed: a validation AtomsLoader built like train_loader
    )
Example #3
# Output modules (head of snippet truncated in the source; a per-property Atomwise head is assumed)
output_modules = [
    schnetpack.atomistic.Atomwise(n_in=representation.n_atom_basis, property=p)
    for p in properties
]
model = schnetpack.atomistic.model.AtomisticModel(representation, output_modules)

# build optimizer
optimizer = Adam(params=model.parameters(), lr=1e-4)

# hooks
logging.info("build trainer")
metrics = [MeanAbsoluteError(p, p) for p in properties]
# hooks = [CSVHook(log_path=model_dir, metrics=metrics), ReduceLROnPlateauHook(optimizer)]
hooks = [CSVHook(log_path=model_dir, metrics=metrics)]

# trainer
clip_norm = None

loss = build_mse_loss(properties, loss_tradeoff=[0.01, 0.99])
trainer = Trainer(
    model_dir,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
    clip_norm=clip_norm,
)

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
np.savetxt('./parms.txt', [total_params], fmt='%d')

# run training
logging.info("training")
trainer.train(device="cpu")
Example #4
# Output modules (head of snippet truncated in the source; a per-property Atomwise head is assumed)
output_modules = [
    schnetpack.atomistic.Atomwise(
        n_in=representation.n_atom_basis,
        property=p,
    )
    for p in properties
]
model = schnetpack.AtomisticModel(representation, output_modules)

# build optimizer
optimizer = Adam(model.parameters(), lr=1e-4)

# hooks
logging.info("build trainer")
metrics = [MeanAbsoluteError(p, p) for p in properties]
hooks = [
    CSVHook(log_path=model_dir, metrics=metrics),
    ReduceLROnPlateauHook(optimizer)
]

# trainer
loss = build_mse_loss(properties)
trainer = Trainer(
    model_dir,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
)

# run training
logging.info("training")
trainer.train(device="cpu")
Example #5
def train_schnet(
    model: Union[TorchMessage, torch.nn.Module, Path],
    database: Dict[str, float],
    num_epochs: int,
    reset_weights: bool = True,
    property_name: str = 'output',
    test_set: Optional[List[str]] = None,
    device: str = 'cpu',
    batch_size: int = 32,
    validation_split: float = 0.1,
    bootstrap: bool = False,
    random_state: int = 1,
    learning_rate: float = 1e-3,
    patience: Optional[int] = None,
    timeout: Optional[float] = None
) -> Union[Tuple[TorchMessage, pd.DataFrame], Tuple[TorchMessage, pd.DataFrame,
                                                    List[float]]]:
    """Train a SchNet model

    Args:
        model: Model to be retrained
        database: Mapping of XYZ format structure to property
        num_epochs: Number of training epochs
        reset_weights: Whether to re-initialize weights before training, or to continue from the previous weights
        property_name: Name of the property being predicted
        test_set: Hold-out set. If provided, the function will also return the model's predictions on these structures
        device: Device (e.g., 'cuda', 'cpu') used for training
        batch_size: Batch size during training
        validation_split: Fraction of the data to hold out for computing the validation loss
        bootstrap: Whether to take a bootstrap sample of the training set before training
        random_state: Random seed used for generating validation set and bootstrap sampling
        learning_rate: Initial learning rate for optimizer
        patience: Patience until learning rate is lowered. Default: epochs / 8
        timeout: Maximum training time in seconds
    Returns:
        - model: Retrained model
        - history: Training history
        - test_pred: Predictions on ``test_set``, if provided
    """

    # Make sure the models are converted to Torch models
    if isinstance(model, TorchMessage):
        model = model.get_model(device)
    elif isinstance(model, (Path, str)):
        model = torch.load(model,
                           map_location='cpu')  # Load to main memory first

    # If desired, re-initialize weights
    if reset_weights:
        for module in model.modules():
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

    # Separate the database into molecules and properties
    xyz, y = zip(*database.items())
    xyz = np.array(xyz)
    y = np.array(y)

    # Convert the xyz files to ase Atoms
    atoms = np.array([next(read_xyz(StringIO(x), slice(None))) for x in xyz])

    # Make the training and validation splits
    rng = np.random.RandomState(random_state)
    train_split = rng.rand(len(xyz)) > validation_split
    train_X = atoms[train_split]
    train_y = y[train_split]
    valid_X = atoms[~train_split]
    valid_y = y[~train_split]

    # Perform a bootstrap sample of the training data
    if bootstrap:
        sample = rng.choice(len(train_X), size=(len(train_X), ), replace=True)
        train_X = train_X[sample]
        train_y = train_y[sample]

    # Start the training process
    with TemporaryDirectory() as td:
        # Save the data to an ASE Atoms database
        train_file = os.path.join(td, 'train_data.db')
        db = AtomsData(train_file, available_properties=[property_name])
        db.add_systems(train_X, [{property_name: i} for i in train_y])
        train_loader = AtomsLoader(db, batch_size=batch_size, shuffle=True)

        valid_file = os.path.join(td, 'valid_data.db')
        db = AtomsData(valid_file, available_properties=[property_name])
        db.add_systems(valid_X, [{property_name: i} for i in valid_y])
        valid_loader = AtomsLoader(db, batch_size=batch_size)

        # Make the trainer
        opt = optim.Adam(model.parameters(), lr=learning_rate)

        loss = trn.build_mse_loss([property_name])
        metrics = [spk.metrics.MeanSquaredError(property_name)]
        if patience is None:
            patience = num_epochs // 8
        hooks = [
            trn.CSVHook(log_path=td, metrics=metrics),
            trn.ReduceLROnPlateauHook(opt,
                                      patience=patience,
                                      factor=0.8,
                                      min_lr=1e-6,
                                      stop_after_min=True)
        ]

        if timeout is not None:
            hooks.append(TimeoutHook(timeout))

        trainer = trn.Trainer(
            model_path=td,
            model=model,
            hooks=hooks,
            loss_fn=loss,
            optimizer=opt,
            train_loader=train_loader,
            validation_loader=valid_loader,
            checkpoint_interval=num_epochs + 1  # Turns off checkpointing
        )

        trainer.train(device, n_epochs=num_epochs)

        # Load in the best model
        model = torch.load(os.path.join(td, 'best_model'))

        # If desired, report the performance on a test set
        test_pred = None
        if test_set is not None:
            test_pred = evaluate_schnet([model],
                                        test_set,
                                        property_name=property_name,
                                        batch_size=batch_size,
                                        device=device)

        # Move the model off of the GPU to save memory
        if 'cuda' in device:
            model.to('cpu')

        # Load in the training results
        train_results = pd.read_csv(os.path.join(td, 'log.csv'))

        # Return the results
        if test_pred is None:
            return TorchMessage(model), train_results
        else:
            return TorchMessage(model), train_results, test_pred[:, 0].tolist()
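For reference, a minimal usage sketch for train_schnet; the names training_xyz and training_targets are hypothetical placeholders, and the model argument is assumed to be a path to a serialized torch.nn.Module:

# Hypothetical inputs: training_xyz is a list of XYZ-format strings, training_targets floats
database = dict(zip(training_xyz, training_targets))
model_msg, history = train_schnet(
    model='best_model',      # assumed path to a serialized torch model
    database=database,
    num_epochs=64,
    property_name='delta',
    device='cpu',
)
print(history.tail(1))       # final row of the CSV training log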
Example #6
def test_loss_tradeoff(self, batch, result_named, properties,
                       loss_value_traded, loss_tradeoff):
    loss_fn = build_mse_loss(properties, loss_tradeoff)
    loss = loss_fn(batch, result_named)
    assert np.allclose(loss, loss_value_traded)
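The loss_tradeoff argument weights each property's contribution to the total loss. A minimal sketch of the weighted sum that build_mse_loss computes (an illustration, not SchNetPack's actual implementation; batch and result are assumed to be dicts of tensors keyed by property name):

import torch

def weighted_mse(batch, result, properties, loss_tradeoff):
    # Weighted sum of per-property mean-squared errors
    total = 0.0
    for prop, weight in zip(properties, loss_tradeoff):
        diff = batch[prop] - result[prop]
        total = total + weight * torch.mean(diff ** 2)
    return total

With loss_tradeoff=[0.01, 0.99], as in Example #3, the second property dominates the loss by roughly two orders of magnitude.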