def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
    kv = self.train_op(progress, data['image'].numpy(), data['label'].numpy())
    for k, v in kv.items():
        if jn.isnan(v):
            raise ValueError('NaN, try reducing learning rate', k)
        summary.scalar(k, float(v))
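# Added sketch (not from the original): train_step above expects self.train_op to
# return a dict of named scalar losses. A minimal, assumed construction of such a
# train_op with Objax; `model`, `opt`, `loss_fn`, and the linearly decayed
# learning rate are illustrative stand-ins, not part of the original code.
import objax

def make_train_op(model, opt, loss_fn, base_lr=0.1):
    gv = objax.GradValues(loss_fn, model.vars())

    def train_op(progress, x, y):
        g, v = gv(x, y)                       # gradients and loss value(s)
        opt(base_lr * (1 - progress[0]), g)   # linearly decayed learning rate
        return {'losses/xe': v[0]}            # named scalars, as train_step expects

    # Jit needs the optimizer and gradient variables in addition to the model's.
    return objax.Jit(train_op, gv.vars() + opt.vars())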
def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet,
          logdir: str, save_steps=100, patience=None):
    """Completely standard training. Nothing interesting to see here."""
    checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
    start_epoch, last_ckpt = checkpoint.restore(self.vars())
    train_iter = iter(train)
    progress = np.zeros(jax.local_device_count(), 'f')  # for multi-GPU
    best_acc = 0
    best_acc_epoch = -1

    with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
        for epoch in range(start_epoch, num_train_epochs):
            # Train
            summary = Summary()
            loop = range(0, train_size, self.params.batch)
            for step in loop:
                progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
                self.train_step(summary, next(train_iter), progress)

            # Eval
            accuracy, total = 0, 0
            if epoch % FLAGS.eval_steps == 0 and test is not None:
                for data in test:
                    total += data['image'].shape[0]
                    preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
                    accuracy += (preds == data['label'].numpy()).sum()
                accuracy /= total
                summary.scalar('eval/accuracy', 100 * accuracy)
                tensorboard.write(summary, step=(epoch + 1) * train_size)
                print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1,
                                                              summary['losses/xe'](),
                                                              summary['eval/accuracy']()))
                if summary['eval/accuracy']() > best_acc:
                    best_acc = summary['eval/accuracy']()
                    best_acc_epoch = epoch
                elif patience is not None and epoch > best_acc_epoch + patience:
                    print('early stopping!')
                    checkpoint.save(self.vars(), epoch + 1)
                    return
            else:
                print('Epoch %04d Loss %.2f Accuracy --' % (epoch + 1,
                                                            summary['losses/xe']()))

            if epoch % save_steps == save_steps - 1:
                checkpoint.save(self.vars(), epoch + 1)
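# Added usage sketch (assumed, not from the original): how the train() method
# above might be invoked. `trainer` and the data pipelines are hypothetical.
# With patience=10, training stops once eval accuracy has not improved for 10
# epochs past the best one, saving a final checkpoint before returning.
trainer.train(num_train_epochs=100,
              train_size=50000,
              train=train_data,   # yields {'image': ..., 'label': ...} batches
              test=test_data,
              logdir='experiments/run1',
              save_steps=20,
              patience=10)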
def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet,
          logdir: str):
    checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=5, makedir=True)
    start_epoch, last_ckpt = checkpoint.restore(self.vars())
    train_iter = iter(train)
    progress = np.zeros(jax.local_device_count(), 'f')  # for multi-GPU

    with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
        for epoch in range(start_epoch, num_train_epochs):
            with self.vars().replicate():
                # Train
                summary = Summary()
                loop = trange(0, train_size, self.params.batch, leave=False,
                              unit='img', unit_scale=self.params.batch,
                              desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
                for step in loop:
                    progress[:] = (step + (epoch * train_size)) / (num_train_epochs * train_size)
                    self.train_step(summary, next(train_iter), progress)

                # Eval
                accuracy, total = 0, 0
                for data in tqdm(test, leave=False, desc='Evaluating'):
                    total += data['image'].shape[0]
                    preds = np.argmax(self.predict(data['image'].numpy()), axis=1)
                    accuracy += (preds == data['label'].numpy()).sum()
                accuracy /= total
                summary.scalar('eval/accuracy', 100 * accuracy)
                print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1,
                                                              summary['losses/xe'](),
                                                              summary['eval/accuracy']()))
                tensorboard.write(summary, step=(epoch + 1) * train_size)
            checkpoint.save(self.vars(), epoch + 1)
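# Added sketch (assumed, not from the original): the replicated loop above pairs
# with a train_op built via objax.Parallel, which shards each batch across local
# devices while vars().replicate() keeps per-device weight copies in sync.
# `model`, `opt`, `loss_fn`, and `lr` are illustrative stand-ins.
gv = objax.GradValues(loss_fn, model.vars())

def train_op(progress, x, y):
    g, v = gv(x, y)
    g = [objax.functional.parallel.pmean(gi) for gi in g]  # average grads over devices
    loss = objax.functional.parallel.pmean(v[0])           # and the loss as well
    opt(lr * (1 - progress[0]), g)
    return {'losses/xe': loss}

train_op = objax.Parallel(train_op, reduce=lambda x: x[0], vc=gv.vars() + opt.vars())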
    return v


train_op = objax.Jit(train_op)  # Compile train_op to make it run faster.
predict = objax.Jit(model_ema)

# Training
print(model.vars())
print(f'Visualize results with: tensorboard --logdir "{logdir}"')
print('Disclaimer: This code demonstrates the DNNet class. For SOTA accuracy use a CNN instead.')
with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
    for epoch in range(num_train_epochs):
        # Train one epoch
        summary = Summary()
        loop = trange(0, train_size, batch, leave=False, unit='img', unit_scale=batch,
                      desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
        for it in loop:
            sel = np.random.randint(size=(batch,), low=0, high=train_size)
            x, xl = train.image[sel], train.label[sel]
            xl = one_hot(xl, nclass)
            v = train_op(x, xl)
            summary.scalar('losses/xe', float(v[0]))

        # Eval
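        # Added sketch (assumed): the excerpt cuts off at the eval step. A typical
        # eval block completing this loop, reusing the jitted `predict` defined
        # above and a `test` split with the same .image/.label layout as `train`.
        accuracy = 0
        for it in trange(0, test.image.shape[0], batch, leave=False, desc='Evaluating'):
            x, xl = test.image[it:it + batch], test.label[it:it + batch]
            accuracy += (np.argmax(predict(x), axis=1) == xl).sum()
        accuracy /= test.image.shape[0]
        summary.scalar('eval/accuracy', 100 * accuracy)
        tensorboard.write(summary, step=(epoch + 1) * train_size)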
    opt(lr, g)
    return v


# gv.vars() contains model_vars.
# Unlike GradValues, PrivateGradValues has its own internal variable: the key of
# the random number generator. When we jit train_op, that internal variable must
# be passed to Jit as well.
train_op = objax.Jit(train_op, gv.vars() + opt.vars())

# Training
with SummaryWriter(os.path.join(log_dir, 'tb')) as tensorboard:
    steps = 0  # Keep track of the number of iterations for privacy accounting.
    for epoch in range(num_train_epochs):
        # Train one epoch
        summary = Summary()
        loop = trange(0, num_train_images, batch, leave=False, unit='img', unit_scale=batch,
                      desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
        for it in loop:
            sel = np.random.randint(size=(batch,), low=0, high=num_train_images)
            x, xl = train.image[sel], train.label[sel]
            xl = one_hot(xl, nclass)
            v = train_op(x, xl)
            summary.scalar('losses/xe', float(v[0]))
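# Added sketch (assumed, not from the original): how the `gv` used above is
# typically constructed with Objax's DP-SGD, and how the `steps` counter feeds
# privacy accounting. `sigma`, `grad_clip`, `delta`, and `loss_fn` are
# illustrative names; in the full example, `steps` would be incremented once per
# call to train_op.
gv = objax.privacy.dpsgd.PrivateGradValues(loss_fn, model.vars(),
                                           noise_multiplier=sigma,
                                           l2_norm_clip=grad_clip,
                                           microbatch=1)

# After training, the step count can be converted into an (epsilon, delta)
# guarantee, e.g. with objax.privacy.dpsgd.analyze_dp (if available in your
# Objax version); q is the sampling probability of each example per step.
eps = objax.privacy.dpsgd.analyze_dp(q=batch / num_train_images,
                                     noise_multiplier=sigma,
                                     steps=steps, delta=delta)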