# Imports used by the training snippets below.
import os

import jax
import numpy as np
import objax
from absl import flags
from objax.jaxboard import Summary, SummaryWriter
from tqdm import tqdm, trange

FLAGS = flags.FLAGS


def train(self, num_train_epochs: int, train_size: int, train: DataSet,
          test: DataSet, logdir: str, save_steps=100, patience=None):
    """
    Completely standard training. Nothing interesting to see here.
    """
    checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
    start_epoch, last_ckpt = checkpoint.restore(self.vars())
    train_iter = iter(train)
    progress = np.zeros(jax.local_device_count(), 'f')  # for multi-GPU
    best_acc = 0
    best_acc_epoch = -1

    with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
        for epoch in range(start_epoch, num_train_epochs):
            # Train
            summary = Summary()
            loop = range(0, train_size, self.params.batch)
            for step in loop:
                progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
                self.train_step(summary, next(train_iter), progress)

            # Eval
            accuracy, total = 0, 0
            if epoch % FLAGS.eval_steps == 0 and test is not None:
                for data in test:
                    total += data['image'].shape[0]
                    preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
                    accuracy += (preds == data['label'].numpy()).sum()
                accuracy /= total
                summary.scalar('eval/accuracy', 100 * accuracy)
                tensorboard.write(summary, step=(epoch + 1) * train_size)
                print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1,
                                                              summary['losses/xe'](),
                                                              summary['eval/accuracy']()))
                if summary['eval/accuracy']() > best_acc:
                    best_acc = summary['eval/accuracy']()
                    best_acc_epoch = epoch
                elif patience is not None and epoch > best_acc_epoch + patience:
                    print('early stopping!')
                    checkpoint.save(self.vars(), epoch + 1)
                    return
            else:
                print('Epoch %04d Loss %.2f Accuracy --' % (epoch + 1, summary['losses/xe']()))

            if epoch % save_steps == save_steps - 1:
                checkpoint.save(self.vars(), epoch + 1)
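The loop above delegates the per-batch update to self.train_step, which is not shown. A minimal sketch of what it plausibly looks like, assuming self.train_op is a jitted update that consumes the progress array and a batch and returns a dict of named scalars, including the 'losses/xe' entry that train() reads back; the exact signature is an assumption, not taken from the source:

def train_step(self, summary, data, progress):
    # Hypothetical: self.train_op is assumed to return a dict of named
    # scalars such as {'losses/xe': ...} computed by the jitted update.
    kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
    for k, v in kv.items():
        if np.isnan(v):
            raise ValueError('NaN loss encountered, try lowering the learning rate', k)
        summary.scalar(k, float(v))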
def train(self, num_train_epochs: int, train_size: int, train: DataSet,
          test: DataSet, logdir: str):
    checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=5, makedir=True)
    start_epoch, last_ckpt = checkpoint.restore(self.vars())
    train_iter = iter(train)
    progress = np.zeros(jax.local_device_count(), 'f')  # for multi-GPU

    with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
        for epoch in range(start_epoch, num_train_epochs):
            with self.vars().replicate():
                # Train
                summary = Summary()
                loop = trange(0, train_size, self.params.batch,
                              leave=False, unit='img', unit_scale=self.params.batch,
                              desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
                for step in loop:
                    progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
                    self.train_step(summary, next(train_iter), progress)

                # Eval
                accuracy, total = 0, 0
                for data in tqdm(test, leave=False, desc='Evaluating'):
                    total += data['image'].shape[0]
                    preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
                    accuracy += (preds == data['label'].numpy()).sum()
                accuracy /= total
                summary.scalar('eval/accuracy', 100 * accuracy)
                print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1,
                                                              summary['losses/xe'](),
                                                              summary['eval/accuracy']()))
            tensorboard.write(summary, step=(epoch + 1) * train_size)
            checkpoint.save(self.vars(), epoch + 1)
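The with self.vars().replicate(): context is the objax multi-GPU idiom: ops built with objax.Parallel shard their input batch across local devices and require the module's variables to be broadcast to every device while the context is active; on exit, the replicated values are reduced back to a single copy. A self-contained illustration with a hypothetical stand-in model (not the one trained above):

import jax
import numpy as np
import objax

net = objax.nn.Sequential([objax.nn.Linear(784, 10)])  # stand-in model
parallel_predict = objax.Parallel(net, net.vars())     # shards inputs across devices

x = np.zeros((jax.local_device_count() * 8, 784), 'f')  # batch must split evenly
with net.vars().replicate():   # broadcast weights to all local devices
    y = parallel_predict(x)    # per-device results are concatenated back together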
@objax.Function.with_vars(model.vars() + model_ema.vars() + opt.vars())
def train_op(x, xl):
    g, v = gv(x, xl)  # returns gradients, loss
    opt(lr, g)
    model_ema.update_ema()
    return v


train_op = objax.Jit(train_op)  # Compile train_op to make it run faster.
predict = objax.Jit(model_ema)

# Training
print(model.vars())
print(f'Visualize results with: tensorboard --logdir "{logdir}"')
print('Disclaimer: This code demonstrates the DNNet class. For SOTA accuracy use a CNN instead.')

with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
    for epoch in range(num_train_epochs):
        # Train one epoch
        summary = Summary()
        loop = trange(0, train_size, batch,
                      leave=False, unit='img', unit_scale=batch,
                      desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
        for it in loop:
            sel = np.random.randint(size=(batch,), low=0, high=train_size)
            x, xl = train.image[sel], train.label[sel]
            xl = one_hot(xl, nclass)
            v = train_op(x, xl)
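The fragment above uses model, gv, opt, model_ema, lr, nclass, and one_hot without showing their construction. A plausible setup sketch under the objax conventions the snippet follows; the hyperparameter values, the DNNet layer sizes, and the one_hot helper are illustrative assumptions, not values from the source (DNNet itself is the class the snippet demonstrates):

import numpy as np
import objax

lr = 0.0002   # assumed learning rate
nclass = 10   # assumed number of classes (e.g. MNIST)

model = DNNet((32, nclass), objax.functional.leaky_relu)  # layer sizes assumed
model_ema = objax.optimizer.ExponentialMovingAverageModule(model, momentum=0.999)
opt = objax.optimizer.Adam(model.vars())

@objax.Function.with_vars(model.vars())
def loss(x, label):
    logit = model(x)  # forward pass on the raw (non-EMA) weights
    return objax.functional.loss.cross_entropy_logits(logit, label).mean()

gv = objax.GradValues(loss, model.vars())  # gv(x, xl) -> (gradients, [loss value])

def one_hot(labels, n):
    # Simple stand-in for whatever one-hot encoder the source uses.
    return (labels[:, None] == np.arange(n)[None, :]).astype('f')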