Example #1
    def calculate_predictions(self):
        for i, (model, scaler_list) in enumerate(
                tqdm(zip(self.models, self.scalers), total=self.num_models)):
            (
                scaler,
                features_scaler,
                atom_descriptor_scaler,
                bond_feature_scaler,
            ) = scaler_list
            if (features_scaler is not None
                    or atom_descriptor_scaler is not None
                    or bond_feature_scaler is not None):
                self.test_data.reset_features_and_targets()
                if features_scaler is not None:
                    self.test_data.normalize_features(features_scaler)
                if atom_descriptor_scaler is not None:
                    self.test_data.normalize_features(
                        atom_descriptor_scaler, scale_atom_descriptors=True)
                if bond_feature_scaler is not None:
                    self.test_data.normalize_features(bond_feature_scaler,
                                                      scale_bond_features=True)
            preds = predict(
                model=model,
                data_loader=self.test_data_loader,
                scaler=scaler,
                return_unc_parameters=False,
            )
            if self.dataset_type == "spectra":
                preds = normalize_spectra(
                    spectra=preds,
                    phase_features=self.test_data.phase_features(),
                    phase_mask=self.spectra_phase_mask,
                    excluded_sub_value=float("nan"),
                )
            if i == 0:
                sum_preds = np.array(preds)
                sum_squared = np.square(preds)
                if self.individual_ensemble_predictions:
                    individual_preds = np.expand_dims(np.array(preds), axis=-1)
                if model.train_class_sizes is not None:
                    self.train_class_sizes = [model.train_class_sizes]
            else:
                sum_preds += np.array(preds)
                sum_squared += np.square(preds)
                if self.individual_ensemble_predictions:
                    individual_preds = np.append(
                        individual_preds,
                        np.expand_dims(np.array(preds), axis=-1),
                        axis=-1)
                if model.train_class_sizes is not None:
                    self.train_class_sizes.append(model.train_class_sizes)

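        # Uncalibrated ensemble statistics: the mean over models and the
        # population variance Var[X] = E[X^2] - (E[X])^2, recovered from the
        # running sums accumulated in the loop above.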
        uncal_preds = sum_preds / self.num_models
        uncal_vars = (sum_squared / self.num_models
                      - np.square(sum_preds) / self.num_models ** 2)
        self.uncal_preds = uncal_preds.tolist()
        self.uncal_vars = uncal_vars.tolist()
        if self.individual_ensemble_predictions:
            self.individual_preds = individual_preds.tolist()
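
The running sums above are the one-pass form of the ensemble mean and population variance. As a sanity check, here is a minimal self-contained sketch (NumPy only, with hypothetical shapes) that computes the same quantities by stacking the per-model predictions instead:

import numpy as np

num_models, num_mols, num_tasks = 3, 5, 2
# Hypothetical stand-in for the per-model `preds` collected in the loop.
preds_per_model = [np.random.rand(num_mols, num_tasks) for _ in range(num_models)]

stacked = np.stack(preds_per_model, axis=0)     # (num_models, num_mols, num_tasks)
uncal_preds = stacked.mean(axis=0)              # equals sum_preds / num_models
uncal_vars = stacked.var(axis=0)                # equals E[X^2] - (E[X])^2 (ddof=0)
individual_preds = np.moveaxis(stacked, 0, -1)  # matches the axis=-1 appends above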
Example #2
    def test_predict_spectra(self,
                             name: str,
                             model_type: str,
                             expected_score: float,
                             expected_nans: int,
                             train_flags: List[str] = None,
                             predict_flags: List[str] = None):
        with TemporaryDirectory() as save_dir:
            # Train
            dataset_type = 'spectra'
            self.train(dataset_type=dataset_type,
                       metric='sid',
                       save_dir=save_dir,
                       model_type=model_type,
                       flags=train_flags)

            # Predict
            preds_path = os.path.join(save_dir, 'preds.csv')
            self.predict(dataset_type=dataset_type,
                         preds_path=preds_path,
                         save_dir=save_dir,
                         model_type=model_type,
                         flags=predict_flags)

            # Check results
            pred = pd.read_csv(preds_path)
            true = pd.read_csv(os.path.join(TEST_DATA_DIR, 'spectra.csv'))
            self.assertEqual(list(pred.keys()), list(true.keys()))
            self.assertEqual(list(pred['smiles']), list(true['smiles']))

            pred = pred.drop(columns=['smiles']).to_numpy()
            true = true.drop(columns=['smiles']).to_numpy()
            phase_features = load_features(predict_flags[1])
            if '--spectra_phase_mask_path' in train_flags:
                mask = load_phase_mask(train_flags[5])
            else:
                mask = None
            true = normalize_spectra(true, phase_features, mask)
            sid = evaluate_predictions(preds=pred,
                                       targets=true,
                                       num_tasks=len(true[0]),
                                       metrics=['sid'],
                                       dataset_type='spectra')['sid'][0]
            self.assertAlmostEqual(sid,
                                   expected_score,
                                   delta=DELTA * expected_score)
            self.assertEqual(np.sum(np.isnan(pred)), expected_nans)
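
The 'sid' metric asserted on above is the spectral information divergence, a symmetric KL-style divergence between normalized spectra. Below is a minimal sketch of the underlying quantity, assuming both spectra are strictly positive and normalized to sum to one (Chemprop's implementation additionally handles masked regions and a target floor):

import numpy as np

def sid(pred: np.ndarray, target: np.ndarray) -> float:
    # Symmetric sum of KL divergences between two normalized spectra.
    return float(np.sum(pred * np.log(pred / target)
                        + target * np.log(target / pred)))

p = np.array([0.2, 0.3, 0.5])
q = np.array([0.25, 0.25, 0.5])
print(sid(p, q))  # small positive value; exactly 0.0 when the spectra match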
Example #3
def run_training(args: TrainArgs,
                 data: MoleculeDataset,
                 logger: Logger = None) -> Dict[str, List[float]]:
    """
    Loads data, trains a Chemprop model, and returns test scores for the model checkpoint with the highest validation score.

    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for
                 loading data and training the Chemprop model.
    :param data: A :class:`~chemprop.data.MoleculeDataset` containing the data.
    :param logger: A logger to record output.
    :return: A dictionary mapping each metric in :code:`args.metrics` to a list of values for each task.
    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Set pytorch seed for random initial weights
    torch.manual_seed(args.pytorch_seed)

    # Split data
    debug(f'Splitting data with seed {args.seed}')
    if args.separate_test_path:
        test_data = get_data(
            path=args.separate_test_path,
            args=args,
            features_path=args.separate_test_features_path,
            atom_descriptors_path=args.separate_test_atom_descriptors_path,
            bond_features_path=args.separate_test_bond_features_path,
            phase_features_path=args.separate_test_phase_features_path,
            smiles_columns=args.smiles_columns,
            logger=logger)
    if args.separate_val_path:
        val_data = get_data(
            path=args.separate_val_path,
            args=args,
            features_path=args.separate_val_features_path,
            atom_descriptors_path=args.separate_val_atom_descriptors_path,
            bond_features_path=args.separate_val_bond_features_path,
            phase_features_path=args.separate_val_phase_features_path,
            smiles_columns=args.smiles_columns,
            logger=logger)

    if args.separate_val_path and args.separate_test_path:
        train_data = data
    elif args.separate_val_path:
        train_data, _, test_data = split_data(data=data,
                                              split_type=args.split_type,
                                              sizes=(0.8, 0.0, 0.2),
                                              seed=args.seed,
                                              num_folds=args.num_folds,
                                              args=args,
                                              logger=logger)
    elif args.separate_test_path:
        train_data, val_data, _ = split_data(data=data,
                                             split_type=args.split_type,
                                             sizes=(0.8, 0.2, 0.0),
                                             seed=args.seed,
                                             num_folds=args.num_folds,
                                             args=args,
                                             logger=logger)
    else:
        train_data, val_data, test_data = split_data(
            data=data,
            split_type=args.split_type,
            sizes=args.split_sizes,
            seed=args.seed,
            num_folds=args.num_folds,
            args=args,
            logger=logger)

    if args.dataset_type == 'classification':
        class_sizes = get_class_sizes(data)
        debug('Class sizes')
        for i, task_class_sizes in enumerate(class_sizes):
            debug(
                f'{args.task_names[i]} '
                f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}'
            )

    if args.save_smiles_splits:
        save_smiles_splits(
            data_path=args.data_path,
            save_dir=args.save_dir,
            task_names=args.task_names,
            features_path=args.features_path,
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            smiles_columns=args.smiles_columns,
            logger=logger,
        )

    if args.features_scaling:
        features_scaler = train_data.normalize_features(replace_nan_token=0)
        val_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)
    else:
        features_scaler = None

    if args.atom_descriptor_scaling and args.atom_descriptors is not None:
        atom_descriptor_scaler = train_data.normalize_features(
            replace_nan_token=0, scale_atom_descriptors=True)
        val_data.normalize_features(atom_descriptor_scaler,
                                    scale_atom_descriptors=True)
        test_data.normalize_features(atom_descriptor_scaler,
                                     scale_atom_descriptors=True)
    else:
        atom_descriptor_scaler = None

    if args.bond_feature_scaling and args.bond_features_size > 0:
        bond_feature_scaler = train_data.normalize_features(
            replace_nan_token=0, scale_bond_features=True)
        val_data.normalize_features(bond_feature_scaler,
                                    scale_bond_features=True)
        test_data.normalize_features(bond_feature_scaler,
                                     scale_bond_features=True)
    else:
        bond_feature_scaler = None

    args.train_data_size = len(train_data)

    debug(
        f'Total size = {len(data):,} | '
        f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}'
    )

    # Initialize scaler and scale training targets by subtracting the mean and dividing by the standard deviation (regression only)
    if args.dataset_type == 'regression':
        debug('Fitting scaler')
        scaler = train_data.normalize_targets()
    elif args.dataset_type == 'spectra':
        debug(
            'Normalizing spectra and excluding spectra regions based on phase')
        args.spectra_phase_mask = load_phase_mask(args.spectra_phase_mask_path)
        for dataset in [train_data, test_data, val_data]:
            data_targets = normalize_spectra(
                spectra=dataset.targets(),
                phase_features=dataset.phase_features(),
                phase_mask=args.spectra_phase_mask,
                excluded_sub_value=None,
                threshold=args.spectra_target_floor,
            )
            dataset.set_targets(data_targets)
        scaler = None
    else:
        scaler = None

    # Get loss function
    loss_func = get_loss_func(args)

    # Set up test set evaluation
    test_smiles, test_targets = test_data.smiles(), test_data.targets()
    if args.dataset_type == 'multiclass':
        sum_test_preds = np.zeros(
            (len(test_smiles), args.num_tasks, args.multiclass_num_classes))
    else:
        sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))

    # Automatically determine whether to cache
    if len(data) <= args.cache_cutoff:
        set_cache_graph(True)
        num_workers = 0
    else:
        set_cache_graph(False)
        num_workers = args.num_workers

    # Create data loaders
    train_data_loader = MoleculeDataLoader(dataset=train_data,
                                           batch_size=args.batch_size,
                                           num_workers=num_workers,
                                           class_balance=args.class_balance,
                                           shuffle=True,
                                           seed=args.seed)
    val_data_loader = MoleculeDataLoader(dataset=val_data,
                                         batch_size=args.batch_size,
                                         num_workers=num_workers)
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=num_workers)

    if args.class_balance:
        debug(
            f'With class_balance, effective train size = {train_data_loader.iter_size:,}'
        )

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        # Tensorboard writer
        save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
        makedirs(save_dir)
        try:
            writer = SummaryWriter(log_dir=save_dir)
        except TypeError:
            # Fall back for SummaryWriter implementations that expect `logdir`.
            writer = SummaryWriter(logdir=save_dir)

        # Load/build model
        if args.checkpoint_paths is not None:
            debug(
                f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}'
            )
            model = load_checkpoint(args.checkpoint_paths[model_idx],
                                    logger=logger)
        else:
            debug(f'Building model {model_idx}')
            model = MoleculeModel(args)

        # Optionally load and freeze parameters from a pretrained checkpoint
        if args.checkpoint_frzn is not None:
            debug(
                f'Loading and freezing parameters from {args.checkpoint_frzn}.'
            )
            model = load_frzn_model(model=model,
                                    path=args.checkpoint_frzn,
                                    current_args=args,
                                    logger=logger)

        debug(model)

        if args.checkpoint_frzn is not None:
            debug(f'Number of unfrozen parameters = {param_count(model):,}')
            debug(f'Total number of parameters = {param_count_all(model):,}')
        else:
            debug(f'Number of parameters = {param_count_all(model):,}')

        if args.cuda:
            debug('Moving model to cuda')
        model = model.to(args.device)

        # Ensure that the model is saved in the correct location for evaluation if training for 0 epochs
        save_checkpoint(os.path.join(save_dir, MODEL_FILE_NAME), model, scaler,
                        features_scaler, atom_descriptor_scaler,
                        bond_feature_scaler, args)

        # Optimizers
        optimizer = build_optimizer(model, args)

        # Learning rate schedulers
        scheduler = build_lr_scheduler(optimizer, args)

        # Run training
        best_score = float('inf') if args.minimize_score else -float('inf')
        best_epoch, n_iter = 0, 0
        for epoch in trange(args.epochs):
            debug(f'Epoch {epoch}')
            n_iter = train(model=model,
                           data_loader=train_data_loader,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer)
            if isinstance(scheduler, ExponentialLR):
                scheduler.step()
            val_scores = evaluate(model=model,
                                  data_loader=val_data_loader,
                                  num_tasks=args.num_tasks,
                                  metrics=args.metrics,
                                  dataset_type=args.dataset_type,
                                  scaler=scaler,
                                  logger=logger)

            for metric, scores in val_scores.items():
                # Average validation score
                avg_val_score = np.nanmean(scores)
                debug(f'Validation {metric} = {avg_val_score:.6f}')
                writer.add_scalar(f'validation_{metric}', avg_val_score,
                                  n_iter)

                if args.show_individual_scores:
                    # Individual validation scores
                    for task_name, val_score in zip(args.task_names, scores):
                        debug(
                            f'Validation {task_name} {metric} = {val_score:.6f}'
                        )
                        writer.add_scalar(f'validation_{task_name}_{metric}',
                                          val_score, n_iter)

            # Save model checkpoint if improved validation score
            avg_val_score = np.nanmean(val_scores[args.metric])
            if (args.minimize_score and avg_val_score < best_score) or \
                    (not args.minimize_score and avg_val_score > best_score):
                best_score, best_epoch = avg_val_score, epoch
                save_checkpoint(os.path.join(save_dir,
                                             MODEL_FILE_NAME), model, scaler,
                                features_scaler, atom_descriptor_scaler,
                                bond_feature_scaler, args)

        # Evaluate on test set using model with best validation score
        info(
            f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}'
        )
        model = load_checkpoint(os.path.join(save_dir, MODEL_FILE_NAME),
                                device=args.device,
                                logger=logger)

        test_preds = predict(model=model,
                             data_loader=test_data_loader,
                             scaler=scaler)
        test_scores = evaluate_predictions(preds=test_preds,
                                           targets=test_targets,
                                           num_tasks=args.num_tasks,
                                           metrics=args.metrics,
                                           dataset_type=args.dataset_type,
                                           logger=logger)

        if len(test_preds) != 0:
            sum_test_preds += np.array(test_preds)

        # Average test score
        for metric, scores in test_scores.items():
            avg_test_score = np.nanmean(scores)
            info(f'Model {model_idx} test {metric} = {avg_test_score:.6f}')
            writer.add_scalar(f'test_{metric}', avg_test_score, 0)

            if args.show_individual_scores and args.dataset_type != 'spectra':
                # Individual test scores
                for task_name, test_score in zip(args.task_names, scores):
                    info(
                        f'Model {model_idx} test {task_name} {metric} = {test_score:.6f}'
                    )
                    writer.add_scalar(f'test_{task_name}_{metric}', test_score,
                                      n_iter)
        writer.close()

    # Evaluate ensemble on test set
    avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

    ensemble_scores = evaluate_predictions(preds=avg_test_preds,
                                           targets=test_targets,
                                           num_tasks=args.num_tasks,
                                           metrics=args.metrics,
                                           dataset_type=args.dataset_type,
                                           logger=logger)

    for metric, scores in ensemble_scores.items():
        # Average ensemble score
        avg_ensemble_test_score = np.nanmean(scores)
        info(f'Ensemble test {metric} = {avg_ensemble_test_score:.6f}')

        # Individual ensemble scores
        if args.show_individual_scores:
            for task_name, ensemble_score in zip(args.task_names, scores):
                info(
                    f'Ensemble test {task_name} {metric} = {ensemble_score:.6f}'
                )

    # Save scores
    with open(os.path.join(args.save_dir, 'test_scores.json'), 'w') as f:
        json.dump(ensemble_scores, f, indent=4, sort_keys=True)

    # Optionally save test preds
    if args.save_preds:
        test_preds_dataframe = pd.DataFrame(
            data={'smiles': test_data.smiles()})

        for i, task_name in enumerate(args.task_names):
            test_preds_dataframe[task_name] = [
                pred[i] for pred in avg_test_preds
            ]

        test_preds_dataframe.to_csv(os.path.join(args.save_dir,
                                                 'test_preds.csv'),
                                    index=False)

    return ensemble_scores
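
A hedged driver sketch for this function, assuming the Chemprop 1.x layout implied by the calls above (chemprop.args.TrainArgs, chemprop.data.get_data). In the released package, run_training is normally invoked through chemprop.train.cross_validate, which also fills in derived fields such as task_names before calling it; the paths and flags below are placeholders:

from chemprop.args import TrainArgs
from chemprop.data import get_data

args = TrainArgs().parse_args([
    '--data_path', 'regression_data.csv',
    '--dataset_type', 'regression',
    '--save_dir', 'checkpoints',
])
data = get_data(path=args.data_path, args=args)
scores = run_training(args=args, data=data)  # e.g. {'rmse': [0.85]}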
Example #4
def predict_and_save(args: PredictArgs, train_args: TrainArgs, test_data: MoleculeDataset,
                     task_names: List[str], num_tasks: int, test_data_loader: MoleculeDataLoader, full_data: MoleculeDataset,
                     full_to_valid_indices: dict, models: List[MoleculeModel], scalers: List[List[StandardScaler]],
                     return_invalid_smiles: bool = False):
    """
    Function to predict with a model and save the predictions to file.

    :param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
                 loading data and a model and making predictions.
    :param train_args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param test_data: A :class:`~chemprop.data.MoleculeDataset` containing valid datapoints.
    :param task_names: A list of task names.
    :param num_tasks: Number of tasks.
    :param test_data_loader: A :class:`~chemprop.data.MoleculeDataLoader` to load the test data.
    :param full_data: A :class:`~chemprop.data.MoleculeDataset` containing all (valid and invalid) datapoints.
    :param full_to_valid_indices: A dictionary mapping indices in the full dataset to indices in the valid dataset.
    :param models: A list or generator object of :class:`~chemprop.models.MoleculeModel`\ s.
    :param scalers: A list or generator object of :class:`~chemprop.features.scaler.StandardScaler` objects.
    :param return_invalid_smiles: Whether to return the placeholder "Invalid SMILES" for invalid SMILES; otherwise they are skipped in the returned predictions.
    :return: A list of lists of target predictions.
    """
    # Predict with each model individually and sum predictions
    if args.dataset_type == 'multiclass':
        sum_preds = np.zeros((len(test_data), num_tasks, args.multiclass_num_classes))
    else:
        sum_preds = np.zeros((len(test_data), num_tasks))
    if args.ensemble_variance or args.individual_ensemble_predictions:
        if args.dataset_type == 'multiclass':
            all_preds = np.zeros((len(test_data), num_tasks, args.multiclass_num_classes, len(args.checkpoint_paths)))
        else:
            all_preds = np.zeros((len(test_data), num_tasks, len(args.checkpoint_paths)))

    # Accumulate per-model predictions so the ensemble variance can be computed robustly.
    print(f'Predicting with an ensemble of {len(args.checkpoint_paths)} models')
    for index, (model, scaler_list) in enumerate(tqdm(zip(models, scalers), total=len(args.checkpoint_paths))):
        scaler, features_scaler, atom_descriptor_scaler, bond_feature_scaler = scaler_list

        # Normalize features
        if args.features_scaling or train_args.atom_descriptor_scaling or train_args.bond_feature_scaling:
            test_data.reset_features_and_targets()
            if args.features_scaling:
                test_data.normalize_features(features_scaler)
            if train_args.atom_descriptor_scaling and args.atom_descriptors is not None:
                test_data.normalize_features(atom_descriptor_scaler, scale_atom_descriptors=True)
            if train_args.bond_feature_scaling and args.bond_features_size > 0:
                test_data.normalize_features(bond_feature_scaler, scale_bond_features=True)

        # Make predictions
        model_preds = predict(
            model=model,
            data_loader=test_data_loader,
            scaler=scaler
        )
        if args.dataset_type == 'spectra':
            model_preds = normalize_spectra(
                spectra=model_preds,
                phase_features=test_data.phase_features(),
                phase_mask=args.spectra_phase_mask,
                excluded_sub_value=float('nan')
            )
        sum_preds += np.array(model_preds)
        if args.ensemble_variance or args.individual_ensemble_predictions:
            if args.dataset_type == 'multiclass':
                all_preds[:, :, :, index] = model_preds
            else:
                all_preds[:, :, index] = model_preds

    # Ensemble predictions
    avg_preds = sum_preds / len(args.checkpoint_paths)

    if args.ensemble_variance:
        if args.dataset_type == 'spectra':
            all_epi_uncs = roundrobin_sid(all_preds)
        else:
            all_epi_uncs = np.var(all_preds, axis=2)
            all_epi_uncs = all_epi_uncs.tolist()

    # Save predictions
    print(f'Saving predictions to {args.preds_path}')
    assert len(test_data) == len(avg_preds)
    if args.ensemble_variance:
        assert len(test_data) == len(all_epi_uncs)
    makedirs(args.preds_path, isfile=True)

    # Set multiclass column names, update num_tasks definition for multiclass
    if args.dataset_type == 'multiclass':
        task_names = [f'{name}_class_{i}' for name in task_names for i in range(args.multiclass_num_classes)]
        num_tasks = num_tasks * args.multiclass_num_classes

    # Copy predictions over to full_data
    for full_index, datapoint in enumerate(full_data):
        valid_index = full_to_valid_indices.get(full_index, None)
        preds = avg_preds[valid_index] if valid_index is not None else ['Invalid SMILES'] * num_tasks
        if args.ensemble_variance:
            if args.dataset_type == 'spectra':
                epi_uncs = all_epi_uncs[valid_index] if valid_index is not None else ['Invalid SMILES']
            else:
                epi_uncs = all_epi_uncs[valid_index] if valid_index is not None else ['Invalid SMILES'] * num_tasks
        if args.individual_ensemble_predictions:
            ind_preds = all_preds[valid_index] if valid_index is not None else [['Invalid SMILES'] * len(args.checkpoint_paths)] * num_tasks

        # Reshape multiclass to merge task and class dimension, with updated num_tasks
        if args.dataset_type == 'multiclass':
            if isinstance(preds, np.ndarray) and preds.ndim > 1:
                preds = preds.reshape(num_tasks)
                if args.individual_ensemble_predictions:
                    ind_preds = ind_preds.reshape(
                        (num_tasks, len(args.checkpoint_paths)))

        # If extra columns have been dropped, add back in SMILES columns
        if args.drop_extra_columns:
            datapoint.row = OrderedDict()

            smiles_columns = args.smiles_columns

            for column, smiles in zip(smiles_columns, datapoint.smiles):
                datapoint.row[column] = smiles

        # Add predictions columns
        for pred_name, pred in zip(task_names, preds):
            datapoint.row[pred_name] = pred
        if args.individual_ensemble_predictions:
            for pred_name, model_preds in zip(task_names, ind_preds):
                for idx, pred in enumerate(model_preds):
                    datapoint.row[pred_name + f'_model_{idx}'] = pred
        if args.ensemble_variance:
            if args.dataset_type == 'spectra':
                datapoint.row['epi_unc'] = epi_uncs
            else:
                for pred_name, epi_unc in zip(task_names, epi_uncs):
                    datapoint.row[pred_name + '_epi_unc'] = epi_unc

    # Save
    with open(args.preds_path, 'w') as f:
        writer = csv.DictWriter(f, fieldnames=full_data[0].row.keys())
        writer.writeheader()

        for datapoint in full_data:
            writer.writerow(datapoint.row)

    # Return predicted values
    avg_preds = avg_preds.tolist()
    
    if return_invalid_smiles:
        full_preds = []
        for full_index in range(len(full_data)):
            valid_index = full_to_valid_indices.get(full_index, None)
            preds = avg_preds[valid_index] if valid_index is not None else ['Invalid SMILES'] * num_tasks
            full_preds.append(preds)
        return full_preds
    else:
        return avg_preds
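
The full_to_valid_indices bookkeeping used throughout this function is the usual pattern for reporting a row for every input while only predicting on parseable molecules. A self-contained illustration with hypothetical data:

# Hypothetical inputs: two of four SMILES strings failed to parse.
full_smiles = ['CCO', 'not_a_smiles', 'c1ccccc1', '???']
valid_indices = [0, 2]  # positions that produced valid molecules
full_to_valid_indices = {full: valid for valid, full in enumerate(valid_indices)}

avg_preds = [[0.12], [0.98]]  # one prediction row per valid molecule
num_tasks = 1
full_preds = [
    avg_preds[full_to_valid_indices[i]]
    if i in full_to_valid_indices else ['Invalid SMILES'] * num_tasks
    for i in range(len(full_smiles))
]
# full_preds == [[0.12], ['Invalid SMILES'], [0.98], ['Invalid SMILES']]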