def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
    """Perform a single optimisation step and record its metrics.

    Args:
        summary: collector that receives one ``scalar(name, float)`` entry per metric.
        data: batch mapping with ``'image'`` and ``'label'`` entries; each is converted
            with ``.numpy()`` before being handed to ``self.train_op``.
        progress: passed through unchanged as the first argument of ``self.train_op``.

    Raises:
        ValueError: as soon as a returned metric is NaN. Metrics preceding it in the
            returned mapping have already been recorded at that point.
    """
    images = data['image'].numpy()
    labels = data['label'].numpy()
    metrics = self.train_op(progress, images, labels)
    for name, value in metrics.items():
        # Abort on divergence instead of logging garbage for the rest of the run.
        if jn.isnan(value):
            raise ValueError('NaN, try reducing learning rate', name)
        summary.scalar(name, float(value))
def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str,
          save_steps=100, patience=None):
    """Standard training loop with periodic evaluation, checkpointing and optional early stopping.

    Args:
        num_train_epochs: total number of epochs. The loop starts at the epoch returned by
            ``checkpoint.restore`` for ``logdir``, so an interrupted run continues.
        train_size: number of training examples per epoch; each epoch runs
            ``range(0, train_size, self.params.batch)`` steps.
        train: training dataset. It is turned into one iterator up front and ``next`` is
            called once per step.
        test: evaluation dataset, or ``None`` to disable evaluation. When given, it is
            evaluated every ``FLAGS.eval_steps`` epochs.
        logdir: directory for checkpoints (``keep_ckpts=20``) and the ``tb`` summary dir.
        save_steps: save a checkpoint every ``save_steps`` epochs. The last epoch is
            always saved as well.
        patience: if not ``None``, stop (after saving a checkpoint) at an evaluation epoch
            whose accuracy did not improve and which lies more than ``patience`` epochs
            after the best evaluation so far.
    """
    checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
    start_epoch, _ = checkpoint.restore(self.vars())
    train_iter = iter(train)
    progress = np.zeros(jax.local_device_count(), 'f')  # for multi-GPU

    # -inf (not 0) so the first evaluation always becomes the baseline. Otherwise a 0%
    # first evaluation would leave best_acc_epoch at -1 and could trigger early stopping.
    best_acc = -float('inf')
    best_acc_epoch = -1

    with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
        for epoch in range(start_epoch, num_train_epochs):
            # Train
            summary = Summary()
            for step in range(0, train_size, self.params.batch):
                # Fraction of all training examples seen before this step; handed to
                # train_op (presumably for LR scheduling - confirm).
                progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
                self.train_step(summary, next(train_iter), progress)

            # Eval
            evaluate = epoch % FLAGS.eval_steps == 0 and test is not None
            if evaluate:
                accuracy, total = 0, 0
                for data in test:
                    total += data['image'].shape[0]
                    preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
                    accuracy += (preds == data['label'].numpy()).sum()
                # NOTE(review): an empty `test` leaves total == 0 and this division raises.
                accuracy /= total
                summary.scalar('eval/accuracy', 100 * accuracy)
                print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, summary['losses/xe'](),
                                                              summary['eval/accuracy']()))
            else:
                print('Epoch %04d Loss %.2f Accuracy --' % (epoch + 1, summary['losses/xe']()))

            # Write every epoch, not only evaluation epochs, so the training loss is never
            # dropped. This happens before any early-stop return below.
            tensorboard.write(summary, step=(epoch + 1) * train_size)

            if evaluate:
                eval_acc = summary['eval/accuracy']()
                if eval_acc > best_acc:
                    best_acc = eval_acc
                    best_acc_epoch = epoch
                elif patience is not None and epoch > best_acc_epoch + patience:
                    print("early stopping!")
                    checkpoint.save(self.vars(), epoch + 1)
                    return

            # Periodic checkpoint, plus one at the final epoch so the last weights are
            # saved even when num_train_epochs is not a multiple of save_steps.
            if epoch % save_steps == save_steps - 1 or epoch + 1 == num_train_epochs:
                checkpoint.save(self.vars(), epoch + 1)
def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str):
    """Run the training schedule, evaluating and checkpointing after every epoch.

    The loop starts at the epoch returned by restoring the newest checkpoint in *logdir*.
    Each epoch does the following:
      - runs ``range(0, train_size, self.params.batch)`` training steps inside
        ``self.vars().replicate()``;
      - measures argmax accuracy over the whole *test* set;
      - prints loss and accuracy;
      - writes the summary under ``<logdir>/tb``;
      - saves a checkpoint (``keep_ckpts=5``).
    """
    ckpt = objax.io.Checkpoint(logdir, keep_ckpts=5, makedir=True)
    first_epoch, _ = ckpt.restore(self.vars())
    batches = iter(train)
    progress = np.zeros(jax.local_device_count(), 'f')  # one slot per local device (multi-GPU)
    total_images = num_train_epochs * train_size

    with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
        for epoch in range(first_epoch, num_train_epochs):
            with self.vars().replicate():
                summary = Summary()
                batch = self.params.batch
                images_before = epoch * train_size
                bar = trange(0, train_size, batch, leave=False, unit='img', unit_scale=batch,
                             desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
                for step in bar:
                    # Fraction of the whole run's images consumed before this step.
                    progress[:] = (step + images_before) / total_images
                    self.train_step(summary, next(batches), progress)

                correct = seen = 0
                for data in tqdm(test, leave=False, desc='Evaluating'):
                    seen += data['image'].shape[0]
                    predicted = np.argmax(self.predict(data['image'].numpy()), axis=1)
                    correct += (predicted == data['label'].numpy()).sum()
                summary.scalar('eval/accuracy', 100 * (correct / seen))
                print('Epoch %04d Loss %.2f Accuracy %.2f'
                      % (epoch + 1, summary['losses/xe'](), summary['eval/accuracy']()))
                tensorboard.write(summary, step=(epoch + 1) * train_size)

            # Saved after leaving the replicate() context.
            ckpt.save(self.vars(), epoch + 1)
# Train for `num_train_epochs` epochs, then evaluate on the full test set and print
# loss/accuracy after each one. Relies on module-level names defined elsewhere:
# train, test, batch, train_size, nclass, one_hot, train_op, predict, Summary, trange.
for epoch in range(num_train_epochs):
    # Train one epoch
    summary = Summary()
    # ceil(train_size / batch) steps; the progress bar counts images (unit_scale=batch).
    loop = trange(0, train_size, batch, leave=False, unit='img', unit_scale=batch,
                  desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
    for it in loop:
        # Indices are drawn uniformly *with replacement* from [0, train_size), so an
        # "epoch" is not an exact pass over the data.
        sel = np.random.randint(size=(batch, ), low=0, high=train_size)
        x, xl = train.image[sel], train.label[sel]
        # Training labels are one-hot encoded into `nclass` classes before the step.
        xl = one_hot(xl, nclass)
        v = train_op(x, xl)
        # presumably v[0] is the cross-entropy loss (logged as 'losses/xe') - confirm against train_op
        summary.scalar('losses/xe', float(v[0]))

    # Eval
    accuracy = 0
    # Walk the test set in order in chunks of `batch`; the last chunk may be shorter.
    for it in trange(0, test.image.shape[0], batch, leave=False, desc='Evaluating'):
        x = test.image[it:it + batch]
        xl = test.label[it:it + batch]
        # Count correct predictions: argmax over axis 1 of predict(x) compared directly with
        # test.label (assumes integer class ids, not one-hot - confirm).
        accuracy += (np.argmax(predict(x), axis=1) == xl).sum()
    accuracy /= test.image.shape[0]
    summary.scalar('eval/accuracy', 100 * accuracy)
    print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, summary['losses/xe'](), summary['eval/accuracy']()))
# NOTE(review): presumably the body of an enclosing per-epoch loop that is not visible here.
# It reads `epoch` and updates `steps`, both bound outside this fragment. It also uses
# num_train_epochs, num_train_images, batch, train, test, nclass, one_hot, train_op and
# predict from elsewhere - confirm.
summary = Summary()
# ceil(num_train_images / batch) training steps; the progress bar counts images (unit_scale=batch).
loop = trange(0, num_train_images, batch, leave=False, unit='img', unit_scale=batch,
              desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
for it in loop:
    # Sample a batch of indices uniformly with replacement from [0, num_train_images).
    sel = np.random.randint(size=(batch, ), low=0, high=num_train_images)
    x, xl = train.image[sel], train.label[sel]
    # Training labels are one-hot encoded into `nclass` classes before the step.
    xl = one_hot(xl, nclass)
    v = train_op(x, xl)
    # presumably v[0] is the cross-entropy loss (logged as 'losses/xe') - confirm against train_op
    summary.scalar('losses/xe', float(v[0]))
    # Running count of optimisation steps; `steps` must already be bound before this runs.
    steps += 1
# Eval
accuracy = 0
# Walk the test set in order in chunks of `batch`; the last chunk may be shorter.
for it in trange(0, test.image.shape[0], batch, leave=False, desc='Evaluating'):
    x = test.image[it:it + batch]
    xl = test.label[it:it + batch]
    # Count correct predictions: argmax over axis 1 of predict(x) compared directly with
    # test.label (assumes integer class ids, not one-hot - confirm).
    accuracy += (np.argmax(predict(x), axis=1) == xl).sum()
accuracy /= test.image.shape[0]
summary.scalar('eval/accuracy', 100 * accuracy)