Example #1
    def train_step(self, summary: Summary, data: dict, progress: np.ndarray):
        # Run one optimization step; train_op returns a dict of named scalar
        # metrics, each of which is checked for NaN and logged to the summary.
        kv = self.train_op(progress, data['image'].numpy(),
                           data['label'].numpy())
        for k, v in kv.items():
            if jn.isnan(v):
                raise ValueError('NaN, try reducing learning rate', k)
            summary.scalar(k, float(v))
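For context, self.train_op here is expected to return a dict mapping metric names to scalars. A minimal sketch of a compatible train_op, assuming a standard Objax GradValues/optimizer setup (the model, loss, and cosine schedule below are illustrative assumptions, not taken from the original):

import jax.numpy as jn
import objax

model = objax.nn.Sequential([objax.nn.Linear(784, 10)])  # stand-in model
opt = objax.optimizer.Momentum(model.vars())

def loss_fn(x, y):
    logits = model(x)
    return objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()

gv = objax.GradValues(loss_fn, model.vars())

def train_op(progress, x, y):
    g, v = gv(x, y)                                # gradients and [loss]
    lr = 0.1 * jn.cos(progress[0] * (jn.pi / 2))   # decay lr as progress -> 1
    opt(lr, g)
    return {'losses/xe': v[0], 'monitors/lr': lr}  # dict consumed by train_step

train_op = objax.Jit(train_op, gv.vars() + opt.vars())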
Example #2
    def train(self,
              num_train_epochs: int,
              train_size: int,
              train: DataSet,
              test: DataSet,
              logdir: str,
              save_steps=100,
              patience=None):
        """
        Completely standard training. Nothing interesting to see here.
        """
        checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=20, makedir=True)
        start_epoch, last_ckpt = checkpoint.restore(self.vars())
        train_iter = iter(train)
        progress = np.zeros(jax.local_device_count(), 'f')  # for multi-GPU

        best_acc = 0
        best_acc_epoch = -1

        with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
            for epoch in range(start_epoch, num_train_epochs):
                # Train
                summary = Summary()
                loop = range(0, train_size, self.params.batch)
                for step in loop:
                    progress[:] = (step +
                                   (epoch * train_size)) / (num_train_epochs *
                                                            train_size)
                    self.train_step(summary, next(train_iter), progress)

                # Eval
                accuracy, total = 0, 0
                if epoch % FLAGS.eval_steps == 0 and test is not None:
                    for data in test:
                        total += data['image'].shape[0]
                        preds = np.argmax(self.predict(data['image'].numpy()),
                                          axis=1)
                        accuracy += (preds == data['label'].numpy()).sum()
                    accuracy /= total
                    summary.scalar('eval/accuracy', 100 * accuracy)
                    tensorboard.write(summary, step=(epoch + 1) * train_size)
                    print('Epoch %04d  Loss %.2f  Accuracy %.2f' %
                          (epoch + 1, summary['losses/xe'](),
                           summary['eval/accuracy']()))

                    # Early stopping: track the best accuracy seen so far and
                    # stop once it has not improved for `patience` epochs.
                    if summary['eval/accuracy']() > best_acc:
                        best_acc = summary['eval/accuracy']()
                        best_acc_epoch = epoch
                    elif patience is not None and epoch > best_acc_epoch + patience:
                        print("early stopping!")
                        checkpoint.save(self.vars(), epoch + 1)
                        return

                else:
                    print('Epoch %04d  Loss %.2f  Accuracy --' %
                          (epoch + 1, summary['losses/xe']()))

                if epoch % save_steps == save_steps - 1:
                    checkpoint.save(self.vars(), epoch + 1)
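A hypothetical call site for this method; the trainer and dataset names and all argument values are illustrative assumptions:

trainer.train(num_train_epochs=90,
              train_size=50000,           # e.g. the CIFAR-10 training set
              train=train_ds,
              test=test_ds,
              logdir='experiments/run0',
              save_steps=10,
              patience=20)                # stop once accuracy stalls for 20 epochs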
Example #3
    def train(self, num_train_epochs: int, train_size: int, train: DataSet,
              test: DataSet, logdir: str):
        checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=5, makedir=True)
        start_epoch, last_ckpt = checkpoint.restore(self.vars())
        train_iter = iter(train)
        progress = np.zeros(jax.local_device_count(), 'f')  # for multi-GPU

        with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
            for epoch in range(start_epoch, num_train_epochs):
                # Replicate the weights across local devices for this epoch;
                # they are gathered back when the block exits.
                with self.vars().replicate():
                    # Train
                    summary = Summary()
                    loop = trange(0,
                                  train_size,
                                  self.params.batch,
                                  leave=False,
                                  unit='img',
                                  unit_scale=self.params.batch,
                                  desc='Epoch %d/%d' %
                                  (1 + epoch, num_train_epochs))
                    for step in loop:
                        progress[:] = (step + (epoch * train_size)) / (
                            num_train_epochs * train_size)
                        self.train_step(summary, next(train_iter), progress)

                    # Eval
                    accuracy, total = 0, 0
                    for data in tqdm(test, leave=False, desc='Evaluating'):
                        total += data['image'].shape[0]
                        preds = np.argmax(self.predict(data['image'].numpy()),
                                          axis=1)
                        accuracy += (preds == data['label'].numpy()).sum()
                    accuracy /= total
                    summary.scalar('eval/accuracy', 100 * accuracy)
                    print('Epoch %04d  Loss %.2f  Accuracy %.2f' %
                          (epoch + 1, summary['losses/xe'](),
                           summary['eval/accuracy']()))
                    tensorboard.write(summary, step=(epoch + 1) * train_size)

                checkpoint.save(self.vars(), epoch + 1)
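The vars().replicate() context broadcasts the weights to every local device and gathers them back on exit; for it to pay off, train_op (and predict) would be wrapped in objax.Parallel rather than objax.Jit, so each device processes a shard of the batch. A minimal sketch of that wiring, with the model, loss, and learning-rate schedule assumed for illustration:

import jax
import numpy as np
import objax

model = objax.nn.Sequential([objax.nn.Linear(784, 10)])  # stand-in model
opt = objax.optimizer.Momentum(model.vars())

def loss_fn(x, y):
    return objax.functional.loss.cross_entropy_logits_sparse(model(x), y).mean()

gv = objax.GradValues(loss_fn, model.vars())

def train_op(progress, x, y):
    g, v = gv(x, y)
    g = objax.functional.parallel.pmean(g)  # average gradients across devices
    v = objax.functional.parallel.pmean(v)
    opt(0.1 * (1 - progress[0]), g)         # linear decay driven by progress
    return v

# Each device receives a slice of the inputs along axis 0; reduce picks one
# copy of the (now identical) per-device loss.
train_op = objax.Parallel(train_op, gv.vars() + opt.vars(),
                          reduce=lambda x: x[0])

progress = np.zeros(jax.local_device_count(), 'f')  # one entry per device
x = np.zeros((64, 784), 'f')  # dummy batch; must split evenly across devices
y = np.zeros((64,), 'i')
with model.vars().replicate():
    loss = train_op(progress, x, y)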
Example #4
    return v


train_op = objax.Jit(train_op)  # Compile train_op to make it run faster.
predict = objax.Jit(model_ema)  # JIT-compile the EMA model for inference.

# Training
print(model.vars())
print(f'Visualize results with: tensorboard --logdir "{logdir}"')
print(
    "Disclaimer: This code demonstrates the DNNet class. For SOTA accuracy use a CNN instead."
)
with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
    for epoch in range(num_train_epochs):
        # Train one epoch
        summary = Summary()
        loop = trange(0,
                      train_size,
                      batch,
                      leave=False,
                      unit='img',
                      unit_scale=batch,
                      desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
        for it in loop:
            sel = np.random.randint(size=(batch, ), low=0, high=train_size)
            x, xl = train.image[sel], train.label[sel]
            xl = one_hot(xl, nclass)
            v = train_op(x, xl)
            summary.scalar('losses/xe', float(v[0]))

        # Eval
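The one_hot helper called in the loop above (and in the next example) is not shown in these snippets; a minimal implementation consistent with the call site could be:

import numpy as np

def one_hot(labels, nclass):
    # Map integer labels of shape (batch,) to one-hot rows (batch, nclass).
    return (labels[:, None] == np.arange(nclass)[None, :]).astype('float32')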
Example #5
    opt(lr, g)
    return v


# gv.vars() contains model_vars.
# Unlike GradValues, a PrivateGradValues instance has its own internal
# variable: the key of its random number generator. When we jit train_op,
# this internal variable must be passed to Jit as well.
train_op = objax.Jit(train_op, gv.vars() + opt.vars())

# Training
with SummaryWriter(os.path.join(log_dir, 'tb')) as tensorboard:
    steps = 0  # Keep track of the number of iterations for privacy accounting.
    for epoch in range(num_train_epochs):
        # Train one epoch
        summary = Summary()
        loop = trange(0,
                      num_train_images,
                      batch,
                      leave=False,
                      unit='img',
                      unit_scale=batch,
                      desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
        for it in loop:
            sel = np.random.randint(size=(batch, ),
                                    low=0,
                                    high=num_train_images)
            x, xl = train.image[sel], train.label[sel]
            xl = one_hot(xl, nclass)
            v = train_op(x, xl)
            summary.scalar('losses/xe', float(v[0]))
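For context, the gv used in this example would be an objax.privacy.dpsgd.PrivateGradValues, which clips each example's gradient and adds Gaussian noise; the steps counter, together with the sampling rate batch / num_train_images and the noise multiplier, is what a differential-privacy accountant (e.g. TensorFlow Privacy's RDP accountant) consumes to report an (epsilon, delta) guarantee. A sketch of the construction; loss_fn and the hyperparameter values are illustrative assumptions:

import objax

# DP-SGD gradients: clip each per-example gradient to l2_norm_clip, then add
# Gaussian noise scaled by noise_multiplier (the values here are assumptions).
gv = objax.privacy.dpsgd.PrivateGradValues(loss_fn, model.vars(),
                                           noise_multiplier=1.1,
                                           l2_norm_clip=1.0,
                                           microbatch=1)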