def eval(self, summary: objax.jaxboard.Summary, epoch: int, test: Dict[str, Iterable],
         valid: Optional[Iterable] = None):
    """Measure classification accuracy, record it in ``summary`` and print a one-line report.

    Args:
        summary: receives ``accuracy/valid`` and one ``accuracy/<name>`` scalar per test set,
            all expressed in percent. The printed report reads these values back from
            ``summary``. It also reads ``summary['losses/xe']``, so that entry is expected
            to exist already.
        epoch: epoch index; the report prints ``epoch + 1``.
        test: mapping from a dataset name to an iterable of batches.
        valid: optional iterable of validation batches. When ``None``, a validation
            accuracy of 0 is recorded.
    """
    def get_accuracy(dataset: DataSet):
        """Return the fraction of correctly classified examples in ``dataset`` (0 if it is empty)."""
        # Each batch is a dict whose 'image' and 'label' entries expose ``.numpy()``
        # (presumably framework tensors — TODO confirm against the data pipeline).
        correct, total, batch = 0, 0, None
        for data in tqdm(dataset, leave=False, desc='Evaluating'):
            x, y = data['image'].numpy(), data['label'].numpy()
            size = x.shape[0]
            total += size
            # The reference batch size comes from the first non-empty batch seen.
            batch = batch or size
            if size < batch:
                # Pad the last batch if it's smaller than expected (must divide properly on GPUs)
                # by repeating its final example; the padded predictions are sliced off below.
                x = np.concatenate([x, np.repeat(x[-1:], batch - size, axis=0)])
            p = self.eval_op(x)[:y.shape[0]]
            correct += (np.argmax(p, axis=1) == y).sum()
        return correct / total if total else 0

    valid_accuracy = 0 if valid is None else get_accuracy(valid)
    summary.scalar('accuracy/valid', 100 * valid_accuracy)
    test_accuracy = {key: get_accuracy(value) for key, value in test.items()}
    to_print = []
    for key, value in sorted(test_accuracy.items()):
        summary.scalar('accuracy/%s' % key, 100 * value)
        to_print.append('Accuracy/%s %.2f' % (key, summary['accuracy/%s' % key]()))
    print('Epoch %-4d Loss %.2f %s (Valid %.2f)' % (epoch + 1, summary['losses/xe'](), ' '.join(to_print),
                                                     summary['accuracy/valid']()))
def train_step(self, summary: objax.jaxboard.Summary, data: dict, step: np.ndarray):
    """Run one training step through ``self.train_op`` and record the metrics it returns.

    Args:
        summary: receives one scalar (converted with ``float``) per metric returned by ``train_op``.
        data: batch dict; its ``'image'`` and ``'label'`` entries are forwarded to ``train_op``.
        step: forwarded to ``train_op`` as its first argument.

    Raises:
        ValueError: with args ``('NaN', name)`` as soon as a returned metric is NaN. Metrics
            that came earlier in the returned mapping's iteration order are already recorded
            in ``summary`` by then.
    """
    images, labels = data['image'], data['label']
    metrics = self.train_op(step, images, labels)
    for name, value in metrics.items():
        # Stop immediately rather than record a NaN metric.
        if jn.isnan(value):
            raise ValueError('NaN', name)
        summary.scalar(name, float(value))