Ejemplo n.º 1
0
    def train(self,
              num_train_epochs: int,
              train_size: int,
              train: DataSet,
              test: DataSet,
              logdir: str,
              save_steps=100,
              patience=None):
        """Standard training loop with periodic eval, checkpoints and early stop.

        Variables are restored through ``objax.io.Checkpoint`` from ``logdir``
        and training resumes at the epoch index returned by ``restore``.

        Args:
            num_train_epochs: Total number of epochs of the run (the loop goes
                from the restored epoch up to this value).
            train_size: Number of samples per epoch; ``self.train_step`` is
                called once per ``self.params.batch`` samples.
            train: Training data; ``next`` is called on its iterator once per
                step.
            test: Evaluation data, or ``None`` to disable evaluation.
            logdir: Directory for checkpoints; TensorBoard logs go to
                ``logdir/tb``.
            save_steps: A checkpoint is saved after every ``save_steps``-th
                epoch; the final epoch is always checkpointed as well.
            patience: If not ``None``, stop (after saving a checkpoint) when an
                evaluation does not improve on the best accuracy and more than
                ``patience`` epochs have passed since the best one.
        """
        checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
        start_epoch, _ = checkpoint.restore(self.vars())
        train_iter = iter(train)
        progress = np.zeros(jax.local_device_count(), 'f')  # for multi-GPU

        # Accuracy is logged as a percentage (>= 0); starting below that range
        # guarantees the first evaluation is recorded as the best so far, so
        # early stopping never measures patience from epoch -1.
        best_acc = -1.0
        best_acc_epoch = -1
        # Epoch index covered by the most recent checkpoint (the restored one).
        last_saved_epoch = start_epoch

        with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
            for epoch in range(start_epoch, num_train_epochs):
                # Train
                summary = Summary()
                for step in range(0, train_size, self.params.batch):
                    # Fraction of the whole run completed so far, written into
                    # every slot of the per-device progress array.
                    progress[:] = (step + (epoch * train_size)) / (
                        num_train_epochs * train_size)
                    self.train_step(summary, next(train_iter), progress)

                # Eval (only on epochs where epoch % FLAGS.eval_steps == 0)
                accuracy, total = 0, 0
                if epoch % FLAGS.eval_steps == 0 and test is not None:
                    for data in test:
                        total += data['image'].shape[0]
                        preds = np.argmax(self.predict(data['image'].numpy()),
                                          axis=1)
                        accuracy += (preds == data['label'].numpy()).sum()

                if total:
                    accuracy /= total
                    summary.scalar('eval/accuracy', 100 * accuracy)
                    tensorboard.write(summary, step=(epoch + 1) * train_size)
                    eval_acc = summary['eval/accuracy']()
                    print('Epoch %04d  Loss %.2f  Accuracy %.2f' %
                          (epoch + 1, summary['losses/xe'](), eval_acc))

                    if eval_acc > best_acc:
                        best_acc = eval_acc
                        best_acc_epoch = epoch
                    elif (patience is not None
                          and epoch > best_acc_epoch + patience):
                        print("early stopping!")
                        checkpoint.save(self.vars(), epoch + 1)
                        return
                else:
                    # No eval this epoch, or the test set yielded no samples
                    # (which previously raised ZeroDivisionError).
                    print('Epoch %04d  Loss %.2f  Accuracy --' %
                          (epoch + 1, summary['losses/xe']()))

                if epoch % save_steps == save_steps - 1:
                    checkpoint.save(self.vars(), epoch + 1)
                    last_saved_epoch = epoch + 1

        # Persist the final weights when the run did not end on a save_steps
        # boundary (the early-stopping path above saves and returns itself).
        if last_saved_epoch < num_train_epochs:
            checkpoint.save(self.vars(), num_train_epochs)
Ejemplo n.º 2
0
    def train(self, num_train_epochs: int, train_size: int, train: DataSet,
              test: DataSet, logdir: str):
        """Train, evaluating and checkpointing after every epoch.

        Variables are restored through ``objax.io.Checkpoint`` from ``logdir``
        and training resumes at the epoch index returned by ``restore``.

        Args:
            num_train_epochs: Total number of epochs of the run (the loop goes
                from the restored epoch up to this value).
            train_size: Number of samples per epoch; ``self.train_step`` is
                called once per ``self.params.batch`` samples.
            train: Training data; ``next`` is called on its iterator once per
                step.
            test: Evaluation data, or ``None`` to skip evaluation.
            logdir: Directory for checkpoints; TensorBoard logs go to
                ``logdir/tb``.
        """
        checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=5, makedir=True)
        start_epoch, _ = checkpoint.restore(self.vars())
        train_iter = iter(train)
        progress = np.zeros(jax.local_device_count(), 'f')  # for multi-GPU

        with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
            for epoch in range(start_epoch, num_train_epochs):
                # Train and eval run inside replicate(); the checkpoint below
                # is written only after this context has exited.
                with self.vars().replicate():
                    # Train
                    summary = Summary()
                    loop = trange(0,
                                  train_size,
                                  self.params.batch,
                                  leave=False,
                                  unit='img',
                                  unit_scale=self.params.batch,
                                  desc='Epoch %d/%d' %
                                  (1 + epoch, num_train_epochs))
                    for step in loop:
                        # Fraction of the whole run completed so far.
                        progress[:] = (step + (epoch * train_size)) / (
                            num_train_epochs * train_size)
                        self.train_step(summary, next(train_iter), progress)

                    # Eval
                    accuracy, total = 0, 0
                    if test is not None:
                        for data in tqdm(test, leave=False, desc='Evaluating'):
                            total += data['image'].shape[0]
                            preds = np.argmax(
                                self.predict(data['image'].numpy()), axis=1)
                            accuracy += (preds == data['label'].numpy()).sum()

                    if total:
                        accuracy /= total
                        summary.scalar('eval/accuracy', 100 * accuracy)
                        print('Epoch %04d  Loss %.2f  Accuracy %.2f' %
                              (epoch + 1, summary['losses/xe'](),
                               summary['eval/accuracy']()))
                    else:
                        # No test data (None or empty): report the loss only
                        # instead of dividing by zero.
                        print('Epoch %04d  Loss %.2f  Accuracy --' %
                              (epoch + 1, summary['losses/xe']()))
                    tensorboard.write(summary, step=(epoch + 1) * train_size)

                checkpoint.save(self.vars(), epoch + 1)
Ejemplo n.º 3
0
    g, v = gv(x, xl)  # returns gradients, loss
    opt(lr, g)
    model_ema.update_ema()
    return v


# Wrap the training step and the EMA model's forward pass in objax.Jit so
# both are compiled before the loop below uses them.
train_op = objax.Jit(train_op)
predict = objax.Jit(model_ema)

# Training
print(model.vars())
print('Visualize results with: tensorboard --logdir "{}"'.format(logdir))
print("Disclaimer: This code demonstrates the DNNet class. "
      "For SOTA accuracy use a CNN instead.")
with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
    for epoch in range(num_train_epochs):
        # Train one epoch
        summary = Summary()
        loop = trange(0,
                      train_size,
                      batch,
                      leave=False,
                      unit='img',
                      unit_scale=batch,
                      desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
        for it in loop:
            sel = np.random.randint(size=(batch, ), low=0, high=train_size)
            x, xl = train.image[sel], train.label[sel]
            xl = one_hot(xl, nclass)
            v = train_op(x, xl)