Пример #1
0
 def test_show_individual_scores(self):
     """Training with per-task score reporting enabled should not raise."""
     try:
         self.args.show_individual_scores = True
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
         # are not swallowed; include the underlying error in the failure.
         self.fail(f'show_individual_scores: {e}')
Пример #2
0
 def test_scaffold(self):
     """Training with scaffold-balanced splitting should not raise."""
     try:
         self.args.split_type = 'scaffold_balanced'
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'scaffold: {e}')
Пример #3
0
 def test_no_cache(self):
     """Training with molecule caching disabled should not raise."""
     try:
         self.args.no_cache = True
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'no_cache: {e}')
Пример #4
0
def sklearn_train() -> None:
    """Entry point for the command line command :code:`sklearn_train`.

    Parses scikit-learn training arguments, then runs cross-validated
    training of a scikit-learn model.
    """
    train_args = SklearnTrainArgs().parse_args()
    cross_validate(args=train_args, train_func=run_sklearn)
Пример #5
0
 def test_bias(self):
     """Training with bias terms enabled should not raise."""
     try:
         self.args.bias = True
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'bias: {e}')
Пример #6
0
 def test_undirected_messages(self):
     """Training with undirected message passing should not raise."""
     try:
         self.args.undirected = True
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'undirected_messages: {e}')
Пример #7
0
 def test_save_smiles_splits(self):
     """Training with SMILES split saving enabled should not raise."""
     try:
         self.args.save_smiles_splits = True
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'save smiles splits: {e}')
Пример #8
0
 def test_config(self):
     """Training configured from the bundled config.json should not raise."""
     try:
         self.args.config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json')
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'config: {e}')
Пример #9
0
 def test_atom_messages(self):
     """Training with atom-centered messages should not raise."""
     try:
         self.args.atom_messages = True
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'atom_messages: {e}')
Пример #10
0
    def setUp(self):
        """Train a tiny model on the toy Delaney data, then build predict args.

        Leaves ``self.args`` holding prediction arguments that point at the
        checkpoints written into ``self.temp_dir``.
        """
        here = os.path.dirname(os.path.abspath(__file__))

        # --- Training phase: minimal regression run on the toy dataset. ---
        train_parser = ArgumentParser()
        add_train_args(train_parser)
        train_args = train_parser.parse_args([])
        train_args.data_path = os.path.join(here, 'delaney_toy.csv')
        train_args.dataset_type = 'regression'
        train_args.batch_size = 2
        train_args.hidden_size = 5
        train_args.epochs = 1
        train_args.quiet = True
        self.temp_dir = TemporaryDirectory()
        train_args.save_dir = self.temp_dir.name
        train_logger = create_logger(name='train',
                                     save_dir=train_args.save_dir,
                                     quiet=train_args.quiet)
        modify_train_args(train_args)
        cross_validate(train_args, train_logger)
        clear_cache()

        # --- Prediction phase: reuse the freshly written checkpoints. ---
        predict_parser = ArgumentParser()
        add_predict_args(predict_parser)
        predict_args = predict_parser.parse_args([])
        predict_args.batch_size = 2
        predict_args.checkpoint_dir = self.temp_dir.name
        predict_args.preds_path = NamedTemporaryFile().name
        predict_args.test_path = os.path.join(here, 'delaney_toy_smiles.csv')
        self.args = predict_args
Пример #11
0
 def test_rdkit_2d_features_unnormalized(self):
     """Training with unnormalized RDKit 2D features should not raise."""
     try:
         self.args.features_generator = ['rdkit_2d']
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'rdkit_2d_features_unnormalized: {e}')
Пример #12
0
 def test_activation_prelu(self):
     """Training with the PReLU activation should not raise."""
     try:
         self.args.activation = 'PReLU'
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'activation_prelu: {e}')
Пример #13
0
 def test_num_folds_ensemble(self):
     """Training with multiple folds and an ensemble should not raise."""
     try:
         self.args.num_folds = 2
         self.args.ensemble_size = 2
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'num_folds_ensemble: {e}')
Пример #14
0
 def test_features_path(self):
     """Training with precomputed features from disk should not raise."""
     try:
         self.args.features_path = [os.path.join(os.path.dirname(os.path.abspath(__file__)), 'delaney_toy_features.npz')]
         self.args.no_features_scaling = True
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'features_path: {e}')
Пример #15
0
 def test_classification_multiclass_default(self):
     """Default classification training on the toy Tox21 data should not raise."""
     try:
         self.args.data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tox21_toy.csv')
         self.args.dataset_type = 'classification'
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'classification_default: {e}')
Пример #16
0
 def test_features_only(self):
     """Training on Morgan features alone (no message passing) should not raise."""
     try:
         self.args.features_generator = ['morgan']
         # BUG FIX: the flag was previously set on the test case itself
         # (`self.features_only`), a dead attribute — it must go on the
         # training args to actually enable features-only mode.
         self.args.features_only = True
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'features_only: {e}')
Пример #17
0
 def test_rdkit_2d_features(self):
     """Training with normalized RDKit 2D features should not raise."""
     try:
         self.args.features_generator = ['rdkit_2d_normalized']
         self.args.no_features_scaling = True
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'rdkit_2d_features: {e}')
Пример #18
0
def main():
    """Parse training arguments, set up logging, and run cross-validation."""
    train_args = parse_train_args()
    train_logger = create_logger(name='train',
                                 save_dir=train_args.save_dir,
                                 quiet=train_args.quiet)
    cross_validate(train_args, train_logger)
Пример #19
0
 def test_predetermined_split(self):
     try:
         self.args.split_type = 'predetermined'
         self.args.folds_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'delaney_toy_folds.pkl')
         self.args.val_fold_index = 1
         self.args.test_fold_index = 2
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except:
         self.fail('predetermined_split')
Пример #20
0
def train_outside(args_dict):
    """Run training when called from another Python script.

    Builds a synthetic command line from the given mapping, parses it into
    :class:`TrainArgs`, and launches cross-validated training.

    :param args_dict: dict of args to use.
    """

    # NOTE(review): mutates the global ``sys.argv`` so that the downstream
    # argument parser sees the synthesized command line.
    sys.argv = create_args(args_dict, 'train.py')

    args = TrainArgs().parse_args()
    logger = create_logger(name='train', save_dir=args.save_dir, quiet=args.quiet)
    cross_validate(args, logger)
Пример #21
0
 def test_checkpoint(self):
     """Train, then rerun in test mode from the saved checkpoints."""
     try:
         args_copy = deepcopy(self.args)
         temp_dir = TemporaryDirectory()
         self.args.save_dir = temp_dir.name
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
         # Second run loads the checkpoints written by the first run.
         args_copy.checkpoint_dir = temp_dir.name
         args_copy.test = True
         modify_train_args(args_copy)
         cross_validate(args_copy, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'checkpoint: {e}')
Пример #22
0
def run_comparison(experiment_args: Namespace,
                   logger: logging.Logger,
                   features_dir: str = None):
    """Train and cross-validate on each dataset in ``experiment_args.datasets``.

    :param experiment_args: Base arguments, copied and specialized per dataset.
    :param logger: Logger used to report per-dataset scores and model sizes.
    :param features_dir: Optional directory containing precomputed
        ``<dataset>.pckl`` feature files to feed each run.
    """
    for dataset_name in experiment_args.datasets:
        dataset_type, dataset_path, num_folds, metric = DATASETS[dataset_name]
        logger.info(dataset_name)

        # Set up args: specialize a fresh copy of the shared args so one
        # dataset's settings (e.g. nested save_dir) never leak into the next.
        args = deepcopy(experiment_args)
        args.data_path = dataset_path
        args.dataset_type = dataset_type
        args.save_dir = os.path.join(args.save_dir, dataset_name)
        args.num_folds = num_folds
        args.metric = metric
        if features_dir is not None:
            args.features_path = [
                os.path.join(features_dir, dataset_name + '.pckl')
            ]
        modify_train_args(args)

        # Set up per-dataset file logging for training
        os.makedirs(args.save_dir, exist_ok=True)
        fh = logging.FileHandler(os.path.join(args.save_dir, args.log_name))
        fh.setLevel(logging.DEBUG)

        # Cross validate
        TRAIN_LOGGER.addHandler(fh)
        try:
            mean_score, std_score = cross_validate(args, TRAIN_LOGGER)
        finally:
            # FIX: always detach AND close the handler — previously the file
            # descriptor was never closed (leaking one per dataset), and the
            # handler stayed attached if cross_validate raised.
            TRAIN_LOGGER.removeHandler(fh)
            fh.close()

        # Record results
        logger.info(f'{mean_score} +/- {std_score} {metric}')
        temp_model = build_model(args)
        logger.info(f'num params: {param_count(temp_model):,}')
    def objective(hyperparams: Dict[str, Union[int, float]],
                  seed: int) -> Dict:
        """Cross-validate one hyperparameter setting for the given seed.

        Returns a hyperopt-style result dict with the loss, raw scores,
        parameter values, model size, and seed.
        """
        # Hyperopt samples every value as a float; restore integer keys.
        for int_key in INT_KEYS:
            hyperparams[int_key] = int(hyperparams[int_key])

        # Specialize a private copy so the shared args stay untouched.
        trial_args = deepcopy(args)

        if args.save_dir is not None:
            # Give each trial its own subdirectory named after its settings.
            suffix = '_'.join(f'{key}_{value}'
                              for key, value in hyperparams.items())
            trial_args.save_dir = os.path.join(trial_args.save_dir, suffix)

        for key, value in hyperparams.items():
            setattr(trial_args, key, value)

        # Keep the FFN width in sync with the message-passing hidden size.
        trial_args.ffn_hidden_size = trial_args.hidden_size

        # Cross validate this trial.
        mean_score, std_score = cross_validate(args=trial_args,
                                               train_func=run_training)

        # Log the trial's outcome and model size.
        num_params = param_count(MoleculeModel(trial_args))
        logger.info(f'Trial results with seed {seed}')
        logger.info(hyperparams)
        logger.info(f'num params: {num_params:,}')
        logger.info(f'{mean_score} +/- {std_score} {trial_args.metric}')

        # A nan classification score is treated as the worst possible score.
        if np.isnan(mean_score):
            if trial_args.dataset_type == 'classification':
                mean_score = 0
            else:
                raise ValueError(
                    'Can\'t handle nan score for non-classification dataset.')

        loss = (1 if trial_args.minimize_score else -1) * mean_score

        return {
            'loss': loss,
            'status': 'ok',
            'mean_score': mean_score,
            'std_score': std_score,
            'hyperparams': hyperparams,
            'num_params': num_params,
            'seed': seed,
        }
    def objective(hyperparams: Dict[str, Union[int, float]]) -> float:
        """Cross-validate one hyperparameter setting; lower return is better."""
        # Hyperopt hands back floats; coerce the integer-valued keys.
        for int_key in INT_KEYS:
            hyperparams[int_key] = int(hyperparams[int_key])

        # Specialize a private copy of the shared args for this trial.
        trial_args = deepcopy(args)

        if args.save_dir is not None:
            # Each trial saves into a subdirectory named after its settings.
            suffix = "_".join(f"{key}_{value}"
                              for key, value in hyperparams.items())
            trial_args.save_dir = os.path.join(trial_args.save_dir, suffix)

        for key, value in hyperparams.items():
            setattr(trial_args, key, value)

        # Keep the FFN width tied to the message-passing hidden size.
        trial_args.ffn_hidden_size = trial_args.hidden_size

        # Record the sampled hyperparameters before training.
        logger.info(hyperparams)

        # Cross validate this trial.
        mean_score, std_score = cross_validate(args=trial_args,
                                               train_func=run_training)

        # Log and accumulate the trial's outcome and model size.
        num_params = param_count(MoleculeModel(trial_args))
        logger.info(f"num params: {num_params:,}")
        logger.info(f"{mean_score} +/- {std_score} {trial_args.metric}")

        results.append({
            "mean_score": mean_score,
            "std_score": std_score,
            "hyperparams": hyperparams,
            "num_params": num_params,
        })

        # A nan classification score is treated as the worst possible score.
        if np.isnan(mean_score):
            if trial_args.dataset_type == "classification":
                mean_score = 0
            else:
                raise ValueError(
                    "Can't handle nan score for non-classification dataset.")

        return (1 if trial_args.minimize_score else -1) * mean_score
Пример #25
0
    def objective(hyperparams: Dict[str, Union[int, float]]) -> float:
        """Cross-validate one hyperparameter setting and return its loss.

        :param hyperparams: Sampled hyperparameter values (hyperopt floats).
        :return: Mean score negated when higher is better, so the caller
            can always minimize the return value.
        """
        # Convert hyperparams from float to int when necessary
        for key in INT_KEYS:
            hyperparams[key] = int(hyperparams[key])

        # Update a private copy of args with the sampled hyperparams
        hyper_args = deepcopy(args)
        if args.save_dir is not None:
            # FIX: the previous conditional expression produced the identical
            # f-string on both branches, so it is collapsed to one form.
            folder_name = '_'.join(f'{key}_{value}'
                                   for key, value in hyperparams.items())
            hyper_args.save_dir = os.path.join(hyper_args.save_dir,
                                               folder_name)
        for key, value in hyperparams.items():
            setattr(hyper_args, key, value)

        # Record hyperparameters
        logger.info(hyperparams)

        # Cross validate
        mean_score, std_score = cross_validate(hyper_args, TRAIN_LOGGER)

        # Record results
        temp_model = build_model(hyper_args)
        num_params = param_count(temp_model)
        logger.info(f'num params: {num_params:,}')
        logger.info(f'{mean_score} +/- {std_score} {hyper_args.metric}')

        results.append({
            'mean_score': mean_score,
            'std_score': std_score,
            'hyperparams': hyperparams,
            'num_params': num_params
        })

        # A nan score is only tolerated for classification, where it is
        # treated as the worst possible score (0).
        if np.isnan(mean_score):
            if hyper_args.dataset_type == 'classification':
                mean_score = 0
            else:
                raise ValueError(
                    'Can\'t handle nan score for non-classification dataset.')

        return (1 if hyper_args.minimize_score else -1) * mean_score
Пример #26
0
        def objective(hyperparams: Dict[str, Union[int, float]]) -> float:
            """Cross-validate one setting; lower return value is better."""
            # Hyperopt samples floats; restore the integer-valued keys.
            for int_key in INT_KEYS:
                hyperparams[int_key] = int(hyperparams[int_key])

            # Specialize a private copy of the dataset args for this trial.
            gs_args = deepcopy(dataset_args)
            for key, value in hyperparams.items():
                setattr(gs_args, key, value)

            # Record the sampled hyperparameters before training.
            logger.info(hyperparams)

            # Cross validate this trial.
            mean_score, std_score = cross_validate(gs_args, TRAIN_LOGGER)

            # Log and accumulate the trial's outcome and model size.
            num_params = param_count(build_model(gs_args))
            logger.info('num params: {:,}'.format(num_params))
            logger.info('{} +/- {} {}'.format(mean_score, std_score, metric))

            results.append({
                'mean_score': mean_score,
                'std_score': std_score,
                'hyperparams': hyperparams,
                'num_params': num_params
            })

            # A nan classification score counts as the worst possible score.
            if np.isnan(mean_score):
                if gs_args.dataset_type == 'classification':
                    mean_score = 0
                else:
                    raise ValueError(
                        'Can\'t handle nan score for non-classification dataset.'
                    )

            return (1 if gs_args.minimize_score else -1) * mean_score
Пример #27
0
"""Trains a model on a dataset."""
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from chemprop.args import TrainArgs
from chemprop.train import cross_validate, cross_validate_mechine
from chemprop.utils import create_logger

if __name__ == '__main__':
    args = TrainArgs().parse_args()
    logger = create_logger(name='train',
                           save_dir=args.save_dir,
                           quiet=args.quiet)
    # Run cross-validated training. The previous `model = ...` binding was
    # never read and the commented-out cross_validate_mechine call was stale,
    # so both are removed.
    cross_validate(args, logger)
Пример #28
0
import pandas as pd
import glob
import os

from chemprop.args import TrainArgs
from chemprop.train import cross_validate
from chemprop.utils import create_logger

# Gather every raw CSV from the Meyer et al. TMPRSS2 data directory and
# concatenate them into one frame.
csvs = glob.glob(os.path.join('../data/tmprss2_meyer_et_al/', '*.csv'))
frames = [pd.read_csv(csv_path) for csv_path in csvs]
raw_data = pd.concat(frames)

# Keep only the columns chemprop needs and write its input file.
chemprop_data = raw_data[['SMILES', 'Activity']]
chemprop_data.to_csv('chemprop_in.csv', index=False)

# argument passing pretty janky but it's set up to use command line
args = TrainArgs().parse_args(['--data_path', 'chemprop_in.csv', '--dataset_type', 'regression',
                               '--save_dir', 'models'])
logger = create_logger(name='train', save_dir=args.save_dir, quiet=args.quiet)
cross_validate(args, logger)
Пример #29
0
 def test_regression_default(self):
     """Default regression training should not raise."""
     try:
         modify_train_args(self.args)
         cross_validate(self.args, self.logger)
     except Exception as e:
         # Narrowed from a bare `except:`; report the underlying error.
         self.fail(f'regression_default: {e}')