Exemplo n.º 1
0
    def eval(self,
             summary: objax.jaxboard.Summary,
             epoch: int,
             test: Dict[str, Iterable],
             valid: Optional[Iterable] = None):
        """Measure classification accuracy, record it in `summary` and print a report line.

        Args:
            summary: receives 'accuracy/valid' and one 'accuracy/<key>' scalar per
                entry of `test`, all scaled to percent. Its 'losses/xe' entry is read
                back for the printed line, so it must already hold that metric.
            epoch: epoch index; printed as `epoch + 1`.
            test: mapping from a display name to an iterable of evaluation batches.
            valid: optional iterable of validation batches. When None, a validation
                accuracy of 0 is recorded and printed.
        """
        def get_accuracy(dataset: DataSet):
            """Return the fraction of examples whose argmax prediction equals the label (0 if empty)."""
            correct, total, batch = 0, 0, None
            for data in tqdm(dataset, leave=False, desc='Evaluating'):
                # NOTE(review): assumes each batch is a dict whose 'image'/'label'
                # values expose .numpy() (e.g. tf.data tensors) — confirm with caller.
                x, y = data['image'].numpy(), data['label'].numpy()
                total += x.shape[0]
                # The first batch's size is the reference size for padding.
                batch = batch or x.shape[0]
                if x.shape[0] != batch:
                    # Pad the last batch if it's smaller than expected (must divide properly on GPUs).
                    x = np.concatenate([x] + [x[-1:]] * (batch - x.shape[0]))
                # Discard the predictions produced for the padding rows.
                p = self.eval_op(x)[:y.shape[0]]
                correct += (np.argmax(p, axis=1) == y).sum()
            return correct / total if total else 0

        # Evaluate everything first (valid, then each test set), then record and report.
        valid_accuracy = 0 if valid is None else get_accuracy(valid)
        summary.scalar('accuracy/valid', 100 * valid_accuracy)
        test_accuracy = {key: get_accuracy(value) for key, value in test.items()}
        to_print = []
        for key, value in sorted(test_accuracy.items()):
            summary.scalar('accuracy/%s' % key, 100 * value)
            # Print the value as reduced by the summary rather than the raw one.
            to_print.append('Accuracy/%s %.2f' %
                            (key, summary['accuracy/%s' % key]()))
        print('Epoch %-4d  Loss %.2f  %s (Valid %.2f)' %
              (epoch + 1, summary['losses/xe'](), ' '.join(to_print),
               summary['accuracy/valid']()))
Exemplo n.º 2
0
 def train_step(self, summary: objax.jaxboard.Summary, data: dict,
                step: np.ndarray):
     """Run one training step and record every metric it returns in `summary`.

     Args:
         summary: receives one scalar (converted to a Python float) per metric
             returned by `self.train_op`.
         data: batch dict; its 'image' and 'label' entries are passed to
             `self.train_op`.
         step: passed through to `self.train_op` as its first argument.

     Raises:
         ValueError: with args ('NaN', name) for the first metric whose value is
             NaN. Metrics earlier in iteration order have already been recorded.
     """
     import math

     kv = self.train_op(step, data['image'], data['label'])
     for k, v in kv.items():
         # Convert to a host float once and test that. Calling jn.isnan(v) and
         # then float(v) costs two host round-trips if v lives on a device.
         value = float(v)
         if math.isnan(value):
             raise ValueError('NaN', k)
         summary.scalar(k, value)