Example #1
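This example defines TrainLoopFSL, a reusable objax training loop: a subclass supplies the model plus train_op/eval_op, and the loop handles progress reporting, padded-batch evaluation, TensorBoard summaries, checkpointing, and pickling the latest model each reporting epoch.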
import os
import pickle
from typing import Callable, Dict, Iterable, Optional

import jax
import jax.numpy as jn
import numpy as np
import objax
from tqdm import tqdm, trange

# EasyDict and the DataSet type come from the surrounding project
# (EasyDict is typically `from easydict import EasyDict`).


class TrainLoopFSL(objax.Module):
    model: objax.Module
    eval_op: Callable
    train_op: Callable

    def __init__(self, nclass: int, **kwargs):
        # Subclasses are expected to create self.model, self.train_op and
        # self.eval_op (see the sketch after this example).
        self.params = EasyDict(kwargs)
        self.nclass = nclass

    def serialize_model(self):
        # Override in a subclass if a different serialization is needed.
        return pickle.dumps(self.model)

    def print(self):
        print(self.model.vars())
        print('Byte size %d\n' % len(self.serialize_model()))
        print('Parameters'.center(79, '-'))
        for kv in sorted(self.params.items()):
            print('%-32s %s' % kv)

    def train_step(self, summary: objax.jaxboard.Summary, data: dict,
                   step: np.ndarray):
        kv = self.train_op(step, data['image'], data['label'])
        for k, v in kv.items():
            if jn.isnan(v):
                raise ValueError('NaN', k)
            summary.scalar(k, float(v))

    def eval(self,
             summary: objax.jaxboard.Summary,
             epoch: int,
             test: Dict[str, Iterable],
             valid: Optional[Iterable] = None):
        def get_accuracy(dataset: DataSet):
            accuracy, total, batch = 0, 0, None
            for data in tqdm(dataset, leave=False, desc='Evaluating'):
                x, y = data['image'].numpy(), data['label'].numpy()
                total += x.shape[0]
                batch = batch or x.shape[0]
                if x.shape[0] != batch:
                    # Pad a smaller final batch to full size so it still
                    # splits evenly across devices; the extra predictions
                    # are sliced off below.
                    x = np.concatenate([x] + [x[-1:]] * (batch - x.shape[0]))
                p = self.eval_op(x)[:y.shape[0]]
                accuracy += (np.argmax(p, axis=1) == y).sum()
            return accuracy / total if total else 0

        valid_accuracy = 0 if valid is None else get_accuracy(valid)
        summary.scalar('accuracy/valid', 100 * valid_accuracy)
        test_accuracy = {
            key: get_accuracy(value)
            for key, value in test.items()
        }
        to_print = []
        for key, value in sorted(test_accuracy.items()):
            summary.scalar('accuracy/%s' % key, 100 * value)
            to_print.append('Accuracy/%s %.2f' %
                            (key, summary['accuracy/%s' % key]()))
        print('Epoch %-4d  Loss %.2f  %s (Valid %.2f)' %
              (epoch + 1, summary['losses/xe'](), ' '.join(to_print),
               summary['accuracy/valid']()))

    def train(self,
              train_kimg: int,
              report_kimg: int,
              train: Iterable,
              valid: Iterable,
              test: Dict[str, Iterable],
              logdir: str,
              keep_ckpts: int,
              verbose: bool = True):
        if verbose:
            self.print()
            print()
            print('Training config'.center(79, '-'))
            print('%-20s %s' % ('Test sets:', sorted(test.keys())))
            print('%-20s %s' % ('Work directory:', logdir))
            print()
        model_path = os.path.join(logdir, 'model/latest.pickle')
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        ckpt = objax.io.Checkpoint(logdir=logdir, keep_ckpts=keep_ckpts)
        start_epoch = ckpt.restore(self.vars())[0]

        train_iter = iter(train)
        # One step-counter entry per local device, since the ops may run
        # with objax.Parallel across multiple GPUs.
        step_array = np.zeros(jax.local_device_count(), 'uint32')
        with objax.jaxboard.SummaryWriter(os.path.join(logdir,
                                                       'tb')) as tensorboard:
            for epoch in range(start_epoch, train_kimg // report_kimg):
                summary = objax.jaxboard.Summary()
                loop = trange(0,
                              report_kimg << 10,
                              self.params.batch,
                              leave=False,
                              unit='img',
                              unit_scale=self.params.batch,
                              desc='Epoch %d/%d' %
                              (1 + epoch, train_kimg // report_kimg))
                with self.vars().replicate():
                    for step in loop:
                        step_array[:] = step + (epoch * (report_kimg << 10))
                        self.train_step(summary,
                                        next(train_iter),
                                        step=step_array)

                    self.eval(summary, epoch, test, valid)

                tensorboard.write(summary,
                                  step=(epoch + 1) * report_kimg * 1024)
                ckpt.save(self.vars(), epoch + 1)
                with open(model_path, 'wb') as f:
                    f.write(self.serialize_model())
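A minimal sketch of a concrete subclass wiring up the pieces above. The network factory `make_net` and all hyperparameter values are hypothetical; the sketch uses objax.Parallel (matching the per-device step_array and vars().replicate() in train), with gradients averaged across devices via pmean.

class SimpleLoop(TrainLoopFSL):  # hypothetical subclass, for illustration
    def __init__(self, nclass: int, **kwargs):
        super().__init__(nclass, **kwargs)
        self.model = make_net(nclass)  # assumed: any objax.Module -> logits
        opt = objax.optimizer.Momentum(self.model.vars())

        @objax.Function.with_vars(self.model.vars())
        def loss(x, y):
            logits = self.model(x, training=True)
            return objax.functional.loss.cross_entropy_logits_sparse(
                logits, y).mean()

        gv = objax.GradValues(loss, self.model.vars())

        @objax.Function.with_vars(self.model.vars() + opt.vars())
        def train_op(step, x, y):
            del step  # a learning-rate schedule could use it
            g, v = gv(x, y)
            g = objax.functional.parallel.pmean(g)  # average grads over GPUs
            v = objax.functional.parallel.pmean(v)
            opt(self.params.lr, g)
            return {'losses/xe': v[0]}  # key matches the eval() printout

        @objax.Function.with_vars(self.model.vars())
        def eval_op(x):
            return self.model(x, training=False)

        self.train_op = objax.Parallel(train_op, reduce=lambda x: x[0])
        self.eval_op = objax.Parallel(eval_op)

The data pipeline is assumed to yield dicts with 'image' and 'label' tensors exposing .numpy() (e.g. a tf.data iterator), since both train_step and eval index batches that way.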
Example #2
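This example is an absl-style entry point: it derives a seed, names a log directory from the hyperparameters or experiment id, skips runs that already completed (and deletes ones that did not), then builds a MemModule and trains it on data from get_data.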
import json
import os
import shutil
import time

import numpy as np
import tensorflow as tf
from absl import flags

# EasyDict, MemModule, network and get_data come from the surrounding
# project; the flag definitions are sketched after this example.
FLAGS = flags.FLAGS


def main(argv):
    del argv
    # Hide the GPUs from TensorFlow so JAX keeps the device memory.
    tf.config.experimental.set_visible_devices([], "GPU")

    seed = FLAGS.seed
    if seed is None:
        seed = np.random.randint(0, 1000000000)
        seed ^= int(time.time())

    args = EasyDict(arch=FLAGS.arch,
                    lr=FLAGS.lr,
                    batch=FLAGS.batch,
                    weight_decay=FLAGS.weight_decay,
                    augment=FLAGS.augment,
                    seed=seed)

    if FLAGS.tunename:
        logdir = '_'.join(sorted('%s=%s' % k for k in args.items()))
    elif FLAGS.expid is not None:
        logdir = "experiment-%d_%d" % (FLAGS.expid, FLAGS.num_experiments)
    else:
        logdir = "experiment-" + str(seed)
    logdir = os.path.join(FLAGS.logdir, logdir)

    # If the last expected checkpoint already exists, treat the run as done.
    if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz" % 10)):
        print(f"run {FLAGS.expid} already completed.")
        return
    else:
        if os.path.exists(logdir):
            print(f"deleting run {FLAGS.expid} that did not complete.")
            shutil.rmtree(logdir)

    print(f"starting run {FLAGS.expid}.")
    os.makedirs(logdir, exist_ok=True)

    train, test, xs, ys, keep, nclass = get_data(seed)

    # Define the network and train it.
    tm = MemModule(network(FLAGS.arch),
                   nclass=nclass,
                   mnist=FLAGS.dataset == 'mnist',
                   epochs=FLAGS.epochs,
                   expid=FLAGS.expid,
                   num_experiments=FLAGS.num_experiments,
                   pkeep=FLAGS.pkeep,
                   save_steps=FLAGS.save_steps,
                   only_subset=FLAGS.only_subset,
                   **args)

    # Persist the hyperparameters and the keep mask for later analysis.
    with open(os.path.join(logdir, 'hparams.json'), 'w') as f:
        f.write(json.dumps(tm.params))
    np.save(os.path.join(logdir, 'keep.npy'), keep)

    tm.train(FLAGS.epochs,
             len(xs),
             train,
             test,
             logdir,
             save_steps=FLAGS.save_steps,
             patience=FLAGS.patience)
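The flags read by main() are defined elsewhere in the module. A hedged reconstruction of those definitions, inferred only from how main() uses FLAGS: every default below is a placeholder rather than the project's actual value, and the types of only_subset and patience are guesses.

from absl import app, flags

flags.DEFINE_string('arch', 'cnn32-3-max', 'Network architecture name.')
flags.DEFINE_float('lr', 0.1, 'Learning rate.')
flags.DEFINE_integer('batch', 256, 'Batch size.')
flags.DEFINE_float('weight_decay', 0.0005, 'Weight decay.')
flags.DEFINE_bool('augment', True, 'Whether to augment training data.')
flags.DEFINE_integer('seed', None, 'Random seed; derived from time if unset.')
flags.DEFINE_string('dataset', 'cifar10', 'Dataset name.')
flags.DEFINE_integer('epochs', 100, 'Number of training epochs.')
flags.DEFINE_integer('expid', None, 'Experiment id within num_experiments.')
flags.DEFINE_integer('num_experiments', 1, 'Total number of experiments.')
flags.DEFINE_float('pkeep', 0.5, 'Per-example probability of being kept.')
flags.DEFINE_integer('save_steps', 10, 'How often to save checkpoints.')
flags.DEFINE_integer('only_subset', None, 'Optional subset restriction.')
flags.DEFINE_integer('patience', None, 'Early-stopping patience.')
flags.DEFINE_string('logdir', 'experiments', 'Root directory for runs.')
flags.DEFINE_bool('tunename', False, 'Name the run after its hyperparameters.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)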