Example #1
import torch
from ignite.metrics import RootMeanSquaredError


def test_compute():
    rmse = RootMeanSquaredError()

    y_pred = torch.Tensor([[2.0], [-2.0]])
    y = torch.zeros(2)
    rmse.update((y_pred, y))
    assert rmse.compute() == 2.0  # sqrt((2**2 + (-2)**2) / 2)

    rmse.reset()
    y_pred = torch.Tensor([[3.0], [-3.0]])
    y = torch.zeros(2)
    rmse.update((y_pred, y))
    assert rmse.compute() == 3.0  # reset() cleared the accumulators: sqrt((3**2 + (-3)**2) / 2)
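Internally, ignite's RootMeanSquaredError accumulates a running sum of squared errors and an example count across update() calls; compute() returns the square root of their ratio. A minimal sketch of the equivalent math (not ignite's exact implementation), reusing the imports above:

def reference_rmse(y_pred, y):
    # sqrt(sum of squared errors / number of examples)
    squared_errors = torch.pow(y_pred - y.view_as(y_pred), 2)
    return (squared_errors.sum() / y.shape[0]).sqrt().item()

# reference_rmse(torch.Tensor([[2.0], [-2.0]]), torch.zeros(2)) == 2.0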
Example #2
import numpy as np
import pytest
import torch
from ignite.metrics import RootMeanSquaredError


def test_compute(n_times, test_data):
    rmse = RootMeanSquaredError()

    y_pred, y, batch_size = test_data
    rmse.reset()
    if batch_size > 1:
        # One extra iteration so a trailing partial batch is also fed to update().
        n_iters = y.shape[0] // batch_size + 1
        for i in range(n_iters):
            idx = i * batch_size
            rmse.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        rmse.update((y_pred, y))

    np_y = y.numpy().ravel()
    np_y_pred = y_pred.numpy().ravel()

    # Reference value: square root of the mean squared error, computed with numpy.
    np_res = np.sqrt(np.power((np_y - np_y_pred), 2.0).sum() / np_y.shape[0])
    res = rmse.compute()

    assert isinstance(res, float)
    assert pytest.approx(res) == np_res
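The n_times and test_data fixtures are not shown here; pytest has to supply them. A minimal sketch of what they could look like, reusing the imports above (the shapes, batch sizes and repetition count are illustrative assumptions, not the original fixtures):

@pytest.fixture(params=range(3))
def n_times(request):
    # Repeat the test with fresh random data.
    return request.param


@pytest.fixture(params=range(2))
def test_data(request):
    # Each case is (y_pred, y, batch_size); batch_size == 1 exercises the single-update path.
    return [
        (torch.rand(100), torch.rand(100), 1),
        (torch.rand(100), torch.rand(100), 16),
    ][request.param]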
Example #3
import ignite.metrics
import numpy as np
import pandas as pd
import torch

# RootMeanSquaredError comes from ignite.metrics; PearsonR, Mean, ScoreHistogram,
# DecoyBatch and pearson are assumed to be project-specific helpers imported from
# the surrounding codebase.
from ignite.metrics import RootMeanSquaredError


class LocalMetrics(ignite.metrics.Metric):
    METRICS = (
        'rmse',
        'pearson',
        'per_model_pearson',
    )
    FIGURES = (
        'hist',
    )

    def __init__(self, column, title=None, metrics=None, figures=None, output_transform=lambda x: x):
        self.column = column
        self.title = title if title is not None else ''
        self.metrics = set(metrics if metrics is not None else LocalMetrics.METRICS)
        self.figures = set(figures if figures is not None else LocalMetrics.FIGURES)
        self._rmse = RootMeanSquaredError()
        self._pearson = PearsonR()
        self._per_model_pearson = Mean()
        self._hist = ScoreHistogram(title=title)
        super(LocalMetrics, self).__init__(output_transform=output_transform)

    def reset(self):
        self._rmse.reset()
        self._pearson.reset()
        self._per_model_pearson.reset()
        self._hist.reset()

    def update(self, batch: DecoyBatch):
        # Skip native structures and ignore residues that don't have a ground-truth score
        non_native = np.repeat(np.char.not_equal(batch.decoy_name, 'native'),
                               repeats=batch.num_nodes_by_graph.cpu().numpy())
        has_score = torch.isfinite(batch.lddt).cpu().numpy()
        valid_scores = np.logical_and(non_native, has_score)

        # Used to uniquely identify a (protein, model) pair without using their str names
        target_model_id = batch.node_index_by_graph[valid_scores].cpu().numpy()
        node_preds = batch.node_features[valid_scores, self.column].detach().cpu().numpy()
        node_targets = batch.lddt[valid_scores].detach().cpu().numpy()

        # Streaming metrics on local scores (they expect torch tensors, not numpy arrays)
        self._rmse.update((torch.from_numpy(node_preds), torch.from_numpy(node_targets)))
        self._pearson.update((torch.from_numpy(node_preds), torch.from_numpy(node_targets)))

        # Per-model metrics: pandas is the easiest way to do the groupby.
        grouped = pd.DataFrame({
            'target_model': target_model_id,
            'preds': node_preds,
            'true': node_targets
        }).groupby('target_model')

        per_model_pearsons = grouped.apply(lambda df: pearson(df['preds'], df['true']))
        self._per_model_pearson.update(torch.from_numpy(per_model_pearsons.values))

        self._hist.update(node_preds, node_targets)

    def compute(self):
        metrics = {}
        figures = {}

        if 'rmse' in self.metrics:
            metrics['rmse'] = self._rmse.compute()
        if 'pearson' in self.metrics:
            metrics['pearson'] = self._pearson.compute()
        if 'per_model_pearson' in self.metrics:
            metrics['per_model_pearson'] = self._per_model_pearson.compute()

        if 'hist' in self.figures:
            extra_title = []
            if 'pearson' in self.metrics:
                extra_title.append(f'$R$        {metrics["pearson"]:.3f}')
            if 'per_model_pearson' in self.metrics:
                extra_title.append(f'$R_\\mathrm{{model}}$ {metrics["per_model_pearson"]:.3f}')
            figures['hist'] = self._hist.compute('\n'.join(extra_title))

        return {'metrics': metrics, 'figures': figures}

    def completed(self, engine, prefix):
        result = self.compute()
        for name, metric in result['metrics'].items():
            engine.state.metrics[prefix + '/' + name] = metric
        for name, fig in result['figures'].items():
            engine.state.figures[prefix + '/' + name] = fig
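One way to use LocalMetrics with an ignite Engine is to wire its lifecycle to engine events by hand, which also leaves room to initialise engine.state.figures (not a standard ignite state attribute, but required by completed()). A minimal sketch; the evaluator, its validation_step process function, the column index and the 'local' prefix are all illustrative assumptions:

from ignite.engine import Engine, Events

def validation_step(engine, batch):
    # Run the model on the batch and return the DecoyBatch with predictions attached.
    ...

evaluator = Engine(validation_step)
local_metrics = LocalMetrics(column=0, title='validation')

@evaluator.on(Events.STARTED)
def init_figures(engine):
    engine.state.figures = {}  # completed() writes figures here

@evaluator.on(Events.EPOCH_STARTED)
def reset_metrics(engine):
    local_metrics.reset()

@evaluator.on(Events.ITERATION_COMPLETED)
def update_metrics(engine):
    local_metrics.update(engine.state.output)

@evaluator.on(Events.EPOCH_COMPLETED)
def log_metrics(engine):
    local_metrics.completed(engine, prefix='local')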