Example #1
def _load_stl10():
    def unflatten(images):
        # STL-10 stores flat column-major (N, 3, 96, 96) data; convert to NHWC.
        return np.transpose(images.reshape((-1, 3, 96, 96)), [0, 3, 2, 1])

    with tempfile.NamedTemporaryFile() as f:
        if tf.io.gfile.exists('stl10/stl10_binary.tar.gz'):
            # Reuse a locally cached copy of the archive when available.
            f = tf.io.gfile.GFile('stl10/stl10_binary.tar.gz', 'rb')
        else:
            # Otherwise download the archive into the temporary file.
            request.urlretrieve(URLS['stl10'], f.name)
        tar = tarfile.open(fileobj=f)
        train_x = tar.extractfile('stl10_binary/train_X.bin')
        train_y = tar.extractfile('stl10_binary/train_y.bin')
        test_x = tar.extractfile('stl10_binary/test_X.bin')
        test_y = tar.extractfile('stl10_binary/test_y.bin')
        unlabeled_x = tar.extractfile('stl10_binary/unlabeled_X.bin')
        train_set = {
            'images': np.frombuffer(train_x.read(), dtype=np.uint8),
            # STL-10 labels are 1-based on disk; shift them to 0-based.
            'labels': np.frombuffer(train_y.read(), dtype=np.uint8) - 1
        }
        test_set = {
            'images': np.frombuffer(test_x.read(), dtype=np.uint8),
            'labels': np.frombuffer(test_y.read(), dtype=np.uint8) - 1
        }
        _imgs = np.frombuffer(unlabeled_x.read(), dtype=np.uint8)
        unlabeled_set = {
            'images': _imgs,
            # The 100,000 unlabeled images get a dummy label of 0.
            'labels': np.zeros(100000, dtype=np.uint8)
        }
        fold_indices = tar.extractfile('stl10_binary/fold_indices.txt').read()

    train_set['images'] = _encode_png(unflatten(train_set['images']))
    test_set['images'] = _encode_png(unflatten(test_set['images']))
    unlabeled_set['images'] = _encode_png(unflatten(unlabeled_set['images']))
    return dict(
        train=train_set,
        test=test_set,
        unlabeled=unlabeled_set,
        files=[EasyDict(filename="stl10_fold_indices.txt", data=fold_indices)])
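
The helper _encode_png is referenced above but not shown. A minimal sketch, assuming it simply PNG-encodes each uint8 HWC image with TensorFlow (the real helper may differ):

import numpy as np
import tensorflow as tf


def _encode_png(images):
    # Encode each (H, W, C) uint8 image into PNG bytes.
    return [tf.io.encode_png(images[i]).numpy() for i in range(images.shape[0])]
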
Example #2
import os

import objax
import tensorflow_datasets as tfds
from objax.util import EasyDict

# Data: train has 1027 images, test has 256 images
# Each image is 300 x 300 x 3 bytes
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='horses_or_humans', batch_size=-1, data_dir=DATA_DIR))


def prepare(x, downscale=3):
    """Downscale images to 100x100x3 (for faster training), flatten and normalize them to [-1, 1]."""
    s = x.shape
    # Split height and width into (size // downscale, downscale) blocks...
    x = x.astype('f').reshape((s[0], s[1] // downscale, downscale, s[2] // downscale, downscale, s[3]))
    # ...average over each block (i.e. average pooling), then flatten and rescale bytes to [-1, 1].
    return x.mean((2, 4)).reshape((s[0], -1)) * (1 / 127.5) - 1


train = EasyDict(image=prepare(data['train']['image']), label=data['train']['label'])
test = EasyDict(image=prepare(data['test']['image']), label=data['test']['label'])
ndim = train.image.shape[-1]
del data

# Settings
lr = 0.0001  # learning rate
batch = 256
epochs = 20

# Model
model = objax.nn.Linear(ndim, 1)
opt = objax.optimizer.SGD(model.vars())
print(model.vars())
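
The example stops before the training loop. A minimal training-step sketch under the setup above, assuming binary labels in {0, 1} and a sigmoid cross-entropy loss (the loss choice is an assumption, not shown in the original):

def loss(x, y):
    # Single-logit binary classifier: squeeze the (batch, 1) output.
    logits = model(x)[:, 0]
    return objax.functional.loss.sigmoid_cross_entropy_logits(logits, y).mean()


gv = objax.GradValues(loss, model.vars())


def train_op(x, y):
    g, v = gv(x, y)  # gradients and loss value
    opt(lr, g)       # one SGD step
    return v


train_op = objax.Jit(train_op, gv.vars() + opt.vars())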

Example #3
def __init__(self, nclass: int, **kwargs):
    self.nclass = nclass
    self.params = EasyDict(kwargs)
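
EasyDict (from objax.util) is a dict subclass that also exposes its entries as attributes, which is why the examples read hyperparameters as self.params.batch. A small usage sketch:

params = EasyDict(lr=0.0001, batch=256)
assert params.lr == params['lr'] == 0.0001  # attribute and key access are equivalent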
Example #4
class TrainLoopFSL(objax.Module):
    model: objax.Module
    eval_op: Callable
    train_op: Callable

    def __init__(self, nclass: int, **kwargs):
        self.params = EasyDict(kwargs)
        self.nclass = nclass

    def serialize_model(self):
        # Overload it in your model if you need something different.
        return pickle.dumps(self.model)

    def print(self):
        print(self.model.vars())
        print('Byte size %d\n' % len(self.serialize_model()))
        print('Parameters'.center(79, '-'))
        for kv in sorted(self.params.items()):
            print('%-32s %s' % kv)

    def train_step(self, summary: objax.jaxboard.Summary, data: dict,
                   step: np.ndarray):
        kv = self.train_op(step, data['image'], data['label'])
        for k, v in kv.items():
            if jn.isnan(v):
                raise ValueError('NaN', k)
            summary.scalar(k, float(v))

    def eval(self,
             summary: objax.jaxboard.Summary,
             epoch: int,
             test: Dict[str, Iterable],
             valid: Optional[Iterable] = None):
        def get_accuracy(dataset: DataSet):
            accuracy, total, batch = 0, 0, None
            for data in tqdm(dataset, leave=False, desc='Evaluating'):
                x, y = data['image'].numpy(), data['label'].numpy()
                total += x.shape[0]
                batch = batch or x.shape[0]
                if x.shape[0] != batch:
                    # Pad the last batch by repeating its final image so the
                    # batch still splits evenly across GPUs; the padded
                    # predictions are discarded below.
                    x = np.concatenate([x] + [x[-1:]] * (batch - x.shape[0]))
                p = self.eval_op(x)[:y.shape[0]]
                accuracy += (np.argmax(p, axis=1) == y).sum()
            return accuracy / total if total else 0

        valid_accuracy = 0 if valid is None else get_accuracy(valid)
        summary.scalar('accuracy/valid', 100 * valid_accuracy)
        test_accuracy = {
            key: get_accuracy(value)
            for key, value in test.items()
        }
        to_print = []
        for key, value in sorted(test_accuracy.items()):
            summary.scalar('accuracy/%s' % key, 100 * value)
            to_print.append('Accuracy/%s %.2f' %
                            (key, summary['accuracy/%s' % key]()))
        print('Epoch %-4d  Loss %.2f  %s (Valid %.2f)' %
              (epoch + 1, summary['losses/xe'](), ' '.join(to_print),
               summary['accuracy/valid']()))

    def train(self,
              train_kimg: int,
              report_kimg: int,
              train: Iterable,
              valid: Iterable,
              test: Dict[str, Iterable],
              logdir: str,
              keep_ckpts: int,
              verbose: bool = True):
        if verbose:
            self.print()
            print()
            print('Training config'.center(79, '-'))
            print('%-20s %s' % ('Test sets:', sorted(test.keys())))
            print('%-20s %s' % ('Work directory:', logdir))
            print()
        model_path = os.path.join(logdir, 'model/latest.pickle')
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        ckpt = objax.io.Checkpoint(logdir=logdir, keep_ckpts=keep_ckpts)
        start_epoch = ckpt.restore(self.vars())[0]

        train_iter = iter(train)
        step_array = np.zeros(jax.local_device_count(),
                              'uint32')  # one step counter per device (multi-GPU)
        with objax.jaxboard.SummaryWriter(os.path.join(logdir,
                                                       'tb')) as tensorboard:
            for epoch in range(start_epoch, train_kimg // report_kimg):
                summary = objax.jaxboard.Summary()
                loop = trange(0,
                              report_kimg << 10,
                              self.params.batch,
                              leave=False,
                              unit='img',
                              unit_scale=self.params.batch,
                              desc='Epoch %d/%d' %
                              (1 + epoch, train_kimg // report_kimg))
                with self.vars().replicate():
                    for step in loop:
                        step_array[:] = step + (epoch * (report_kimg << 10))
                        self.train_step(summary,
                                        next(train_iter),
                                        step=step_array)

                    self.eval(summary, epoch, test, valid)

                tensorboard.write(summary,
                                  step=(epoch + 1) * report_kimg * 1024)
                ckpt.save(self.vars(), epoch + 1)
                with open(model_path, 'wb') as f:
                    f.write(self.serialize_model())
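
TrainLoopFSL leaves model, train_op and eval_op to subclasses. A hedged sketch of a minimal subclass (illustrative names, not from the original code; single-device Jit ops for simplicity, though the loop also works with objax.Parallel versions):

class LinearFSL(TrainLoopFSL):
    def __init__(self, nclass: int, ndim: int, **kwargs):
        super().__init__(nclass, **kwargs)
        self.model = objax.nn.Linear(ndim, nclass)
        self.opt = objax.optimizer.SGD(self.model.vars())

        def loss(x, y):
            return objax.functional.loss.cross_entropy_logits_sparse(self.model(x), y).mean()

        gv = objax.GradValues(loss, self.model.vars())

        def train_op(step, x, y):
            g, v = gv(x, y)
            self.opt(self.params.lr, g)
            return {'losses/xe': v[0]}  # key matches what eval() prints

        self.train_op = objax.Jit(train_op, gv.vars() + self.opt.vars())
        self.eval_op = objax.Jit(lambda x: objax.functional.softmax(self.model(x)),
                                 self.model.vars())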
Example #5
# Data
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))
train_size = len(data['train']['image'])
test_size = len(data['test']['image'])
train_shape = data['train']['image'].shape
image_size = train_shape[1] * train_shape[2] * train_shape[3]
nclass = len(np.unique(data['train']['label']))
flat_train_images = np.reshape(
    data['train']['image'].transpose(0, 3, 1, 2) / 127.5 - 1,
    (train_size, image_size))
flat_test_images = np.reshape(
    data['test']['image'].transpose(0, 3, 1, 2) / 127.5 - 1,
    (test_size, image_size))
test = EasyDict(image=flat_test_images, label=data['test']['label'])
train = EasyDict(image=flat_train_images, label=data['train']['label'])
del data

# Settings
lr = 0.0002
batch = 64
num_train_epochs = 40
dnn_layer_sizes = (image_size, 128, 10)
logdir = f'experiments/classify/img/mnist/filters{dnn_layer_sizes}'

# Model
model = DNNet(dnn_layer_sizes, leaky_relu)
model_ema = objax.optimizer.ExponentialMovingAverageModule(model,
                                                           momentum=0.999)
opt = objax.optimizer.Adam(model.vars())
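
model_ema keeps an exponentially averaged copy of the weights for evaluation. A hedged sketch of how it would typically be wired into the training step (the cross-entropy loss is an assumption, not shown above):

def loss(x, y):
    logits = model(x)
    return objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()


gv = objax.GradValues(loss, model.vars())


def train_op(x, y):
    g, v = gv(x, y)
    opt(lr, g)
    model_ema.update_ema()  # refresh the averaged weights after each step
    return v


train_op = objax.Jit(train_op, gv.vars() + opt.vars() + model_ema.vars())
predict_op = objax.Jit(model_ema, model_ema.vars())  # predict with averaged weights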
Example #6
def main(argv):
    del argv
    # Hide GPUs from TensorFlow so it does not reserve memory needed by JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    seed = FLAGS.seed
    if seed is None:
        import time
        # Derive a fresh seed from the clock when none is supplied.
        seed = np.random.randint(0, 1000000000)
        seed ^= int(time.time())

    args = EasyDict(arch=FLAGS.arch,
                    lr=FLAGS.lr,
                    batch=FLAGS.batch,
                    weight_decay=FLAGS.weight_decay,
                    augment=FLAGS.augment,
                    seed=seed)

    if FLAGS.tunename:
        logdir = '_'.join(sorted('%s=%s' % k for k in args.items()))
    elif FLAGS.expid is not None:
        logdir = "experiment-%d_%d" % (FLAGS.expid, FLAGS.num_experiments)
    else:
        logdir = "experiment-" + str(seed)
    logdir = os.path.join(FLAGS.logdir, logdir)

    # A run counts as complete once checkpoint 10 has been written.
    if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz" % 10)):
        print(f"run {FLAGS.expid} already completed.")
        return
    else:
        if os.path.exists(logdir):
            print(f"deleting run {FLAGS.expid} that did not complete.")
            shutil.rmtree(logdir)

    print(f"starting run {FLAGS.expid}.")
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    train, test, xs, ys, keep, nclass = get_data(seed)

    # Define the network and train_it
    tm = MemModule(network(FLAGS.arch),
                   nclass=nclass,
                   mnist=FLAGS.dataset == 'mnist',
                   epochs=FLAGS.epochs,
                   expid=FLAGS.expid,
                   num_experiments=FLAGS.num_experiments,
                   pkeep=FLAGS.pkeep,
                   save_steps=FLAGS.save_steps,
                   only_subset=FLAGS.only_subset,
                   **args)

    r = {}
    r.update(tm.params)

    with open(os.path.join(logdir, 'hparams.json'), "w") as f:
        f.write(json.dumps(tm.params))
    np.save(os.path.join(logdir, 'keep.npy'), keep)

    tm.train(FLAGS.epochs,
             len(xs),
             train,
             test,
             logdir,
             save_steps=FLAGS.save_steps,
             patience=FLAGS.patience)
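
The FLAGS used above are defined elsewhere in the script. A plausible sketch of the matching absl definitions (names taken from the usages above; the defaults and help strings are assumptions, and the remaining flags such as expid, epochs or pkeep follow the same pattern):

from absl import flags

flags.DEFINE_string('arch', 'cnn', 'Model architecture.')
flags.DEFINE_float('lr', 0.1, 'Learning rate.')
flags.DEFINE_integer('batch', 256, 'Batch size.')
flags.DEFINE_float('weight_decay', 0.0005, 'L2 regularization strength.')
flags.DEFINE_bool('augment', True, 'Use data augmentation.')
flags.DEFINE_integer('seed', None, 'Random seed; derived from the clock when unset.')
flags.DEFINE_string('logdir', 'experiments', 'Directory for logs and checkpoints.')
FLAGS = flags.FLAGS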
Example #7
    def __call__(self,
                 x,
                 training=False):  # x = (batch, colors, height, width)
        y = self.pre_conv(x)
        y = self.block1(y)
        y = self.block2(y)
        logits = self.post_conv(y).mean((2, 3))  # logits = (batch, nclass)
        if training:
            return logits
        return objax.functional.softmax(logits)


# Data
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))
train = EasyDict(image=data['train']['image'].transpose(0, 3, 1, 2) / 255,
                 label=data['train']['label'])
test = EasyDict(image=data['test']['image'].transpose(0, 3, 1, 2) / 255,
                label=data['test']['label'])
del data


def augment(x, shift=4):  # Shift all images in the batch by up to "shift" pixels in any direction.
    x_pad = np.pad(x, [[0, 0], [0, 0], [shift, shift], [shift, shift]])
    # Offsets in [0, 2 * shift] let the 28x28 crop move up to `shift` pixels
    # to either side of the padded center.
    rx, ry = np.random.randint(0, 2 * shift + 1, size=2)
    return x_pad[:, :, rx:rx + 28, ry:ry + 28]
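
A typical way to use augment is on each batch right before the train op. A minimal sampling-loop sketch (the loop and the batch setting are assumptions; the example's Settings section is truncated):

for it in range(0, train.image.shape[0], batch):
    sel = np.random.randint(0, train.image.shape[0], size=batch)
    x, y = augment(train.image[sel]), train.label[sel]
    # feed (x, y) to the training op here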


# Settings
Example #8
        for i in range(scales):
            ops.extend([
                Conv2D(nf(i), nf(i), 3), nl,
                Conv2D(nf(i), nf(i + 1), 3), nl,
                # Halve the spatial resolution after each pair of convolutions.
                partial(average_pool_2d, size=2, strides=2)
            ])
        # Final convolution yields per-class maps; global-average-pool to logits.
        ops.extend([Conv2D(nf(scales), nclass, 3), lambda x: x.mean((2, 3))])
        super().__init__(ops)


# Data
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))
inputs = data['train']['image']
labels = data['train']['label']
train = EasyDict(image=inputs.transpose(0, 3, 1, 2) / 127.5 - 1, label=labels)
test = EasyDict(image=data['test']['image'].transpose(0, 3, 1, 2) / 127.5 - 1,
                label=data['test']['label'])
num_train_images = train.image.shape[0]
nclass = len(np.unique(data['train']['label']))
del data, inputs, labels

# Settings
log_dir = args.log_dir
filters = 32
filters_max = 64

num_train_epochs = args.epochs
lr = args.lr
batch = args.batchsize
l2_norm_clip = args.l2_norm_clip
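
The l2_norm_clip setting suggests differentially private training. A hedged sketch of wiring per-example clipping and noise with objax's DP-SGD helper (the model, loss and noise_multiplier here are assumptions, not shown in the snippet):

def loss(x, y):
    # Assumed standard cross-entropy over the network's logits.
    return objax.functional.loss.cross_entropy_logits_sparse(model(x), y).mean()


gv = objax.privacy.dpsgd.PrivateGradValues(
    loss, model.vars(),
    noise_multiplier=1.1,       # assumed value; trades privacy for accuracy
    l2_norm_clip=l2_norm_clip,  # clip each per-example gradient to this L2 norm
    microbatch=1)               # clip at per-example granularity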