Beispiel #1
0
    def testRmspropVectorPiecewiseConstantSchedule(self):
        """Run the shared ``_CheckFuns`` helper on rmsprop with a vector input.

        The objective is the squared norm of a length-2 vector and the step
        size follows a piecewise-constant schedule with boundaries at
        iterations 25 and 75.
        """
        def loss(vec):
            # Quadratic objective: the vector dotted with itself.
            return np.dot(vec, vec)

        boundaries = [25, 75]
        step_sizes = [1.0, 0.5, 0.1]
        schedule = optimizers.piecewise_constant(boundaries, step_sizes)
        start_point = np.ones(2)
        self._CheckFuns(optimizers.rmsprop, loss, start_point, schedule)
Beispiel #2
0
def schedule_maker(schedule_tuple, learn_rate):
    """
    Build a learning-rate schedule from a compact tuple description.

    ``schedule_tuple`` has the form ``(sched_name, decay_steps, min_lr)``:

    * ``sched_name`` -- one of 'const', 'exp', 'poly' or 'piecewise'.
    * ``decay_steps`` -- step scale of the decay; not read for 'const'.
    * ``min_lr`` -- optional floor; when the tuple has more than two
      entries the returned rate is clamped from below to this value.

    ``learn_rate`` is the (non-negative) initial rate. The result is a
    function mapping an epoch/step index to a learning rate; it is a thin
    wrapper around the stock JAX schedule helpers.
    """
    kind = schedule_tuple[0]
    assert learn_rate >= 0
    assert kind in ['const', 'exp', 'poly', 'piecewise']

    def build_piecewise():
        # Piecewise constant rate: divided by 10 after every `period` steps,
        # for nine consecutive drops.
        period = schedule_tuple[1]
        assert period > 0
        bounds = [period * n for n in range(1, 10)]
        values = [learn_rate * 10**(-n) for n in range(10)]
        return jopt.piecewise_constant(bounds, values)

    # Builders are lazy so only the selected schedule touches the
    # tuple entries it actually needs.
    builders = {
        # Constant learning rate
        'const': lambda: jopt.constant(learn_rate),
        # Exponential decay with rate 0.5 per `decay_steps`
        'exp': lambda: jopt.exponential_decay(learn_rate, schedule_tuple[1],
                                              0.5),
        # Stepped inverse-time decay with rate 5
        'poly': lambda: jopt.inverse_time_decay(learn_rate,
                                                schedule_tuple[1],
                                                5,
                                                staircase=True),
        'piecewise': build_piecewise,
    }
    base_schedule = builders[kind]()

    def my_sched_fun(epoch):
        rate = base_schedule(epoch)
        if len(schedule_tuple) > 2:
            # Clamp to the requested minimum learning rate.
            return jnp.maximum(rate, schedule_tuple[2])
        return rate

    return my_sched_fun
Beispiel #3
0
def main(unused_argv):
    """Train a Wide ResNet on CIFAR-10 with L2 regularization and log to CSV.

    All settings are read from the module-level ``FLAGS`` object. The
    function loads the data through ``utils.load_data``, builds the network
    with ``utils.WideResnetnt``, trains it with ``optimizers.momentum`` under
    ``pmap``, and every ``FLAGS.meas_step`` steps evaluates train/test
    metrics. Measurements are written to a CSV file under ``wrnlogs/``;
    when ``FLAGS.checkpointing`` is set, weights are saved periodically via
    ``utils.save_weights``. With ``FLAGS.L2_sch`` the L2 coefficient is
    divided by ``FLAGS.L2dec`` whenever the training loss or accuracy has
    stopped improving over the last two measurements.

    Args:
        unused_argv: ignored; present to match the entry-point signature.
    """
    from jax.api import grad, jit, vmap, pmap, device_put

    # The following is required to use TPU Driver as JAX's backend.
    if FLAGS.TPU:
        config.FLAGS.jax_xla_backend = "tpu_driver"
        TPU_ADDR = os.environ['TPU_ADDR']
        config.FLAGS.jax_backend_target = "grpc://" + TPU_ADDR + ':8470'
    ndevices = xla_bridge.device_count()
    if not FLAGS.TPU:
        ndevices = 1

    # Every pmap below shares the axis name 'i' (used by lax.psum).
    pmap = partial(pmap, axis_name='i')

    # Setup some experiment parameters.
    meas_step = FLAGS.meas_step
    training_epochs = int(FLAGS.epochs)

    # In "physical" mode the epoch budget is rescaled by the learning rate
    # (and optionally by the L2 coefficient).
    tmult = 1.0
    if FLAGS.physical:
        tmult = FLAGS.lr
        if FLAGS.physicalL2:
            tmult = FLAGS.L2 * tmult
    if FLAGS.physical:
        training_epochs = 1 + int(FLAGS.epochs / tmult)

    print('Evolving for {:}e'.format(training_epochs))
    losst = FLAGS.losst
    learning_rate = FLAGS.lr
    batch_size_per_device = FLAGS.bs
    N = FLAGS.N
    K = FLAGS.K

    batch_size = batch_size_per_device * ndevices
    # NOTE(review): 50000 is presumably the CIFAR-10 training-set size.
    steps_per_epoch = 50000 // batch_size
    training_steps = training_epochs * steps_per_epoch

    # Filename from FLAGS
    filename = 'wrnL2_' + losst + '_n' + str(N) + '_k' + str(K)
    if FLAGS.momentum:
        filename += '_mom'
    if FLAGS.L2_sch:
        filename += '_L2sch' + '_decay' + str(FLAGS.L2dec) + '_del' + str(
            FLAGS.delay)
    if FLAGS.seed != 1:
        filename += 'seed' + str(FLAGS.seed)
    filename += '_L2' + str(FLAGS.L2)
    if FLAGS.std_wrn_sch:
        filename += '_stddec'
        if FLAGS.physical:
            filename += 'phys'
    else:
        filename += '_ctlr'
    if not FLAGS.augment:
        filename += '_noaug'
    if not FLAGS.mix:
        filename += '_nomixup'
    filename += '_bs' + str(batch_size) + '_lr' + str(learning_rate)
    if FLAGS.jobdir is not None:
        filedir = os.path.join('wrnlogs', FLAGS.jobdir)
    else:
        filedir = 'wrnlogs'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(filedir, exist_ok=True)
    # From here on `filedir` is the full path of the CSV log file.
    filedir = os.path.join(filedir, filename + '.csv')

    print('Saving log to ', filename)
    print('Found {} cores.'.format(ndevices))

    # Load CIFAR10 data and create a minimal pipeline.
    train_images, train_labels, test_images, test_labels = utils.load_data(
        'cifar10')
    train_images = np.reshape(train_images, (-1, 32, 32 * 3))
    train = (train_images, train_labels)
    test = (test_images, test_labels)
    # Size of the last label axis; presumably the number of classes.
    k = train_labels.shape[-1]
    train = utils.shard_data(train, ndevices)
    test = utils.shard_data(test, ndevices)

    # Create a Wide Resnet and replicate its parameters across the devices.
    initparams, f, _ = utils.WideResnetnt(N, K, k)

    # Loss and optimizer definitions
    l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
    l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))
    currL2 = FLAGS.L2
    # Per-device copy of the current L2 coefficient.
    L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))

    def xentr(params, images_and_labels):
        """Cross-entropy style loss: negative mean of log-softmax * labels."""
        images, labels = images_and_labels
        return -np.mean(stax.logsoftmax(f(params, images)) * labels)

    def mse(params, data_tuple):
        """MSE loss."""
        x, y = data_tuple
        return 0.5 * np.mean((y - f(params, x))**2)

    if losst == 'xentr':
        print('Using xentr')
        lossm = xentr
    else:
        print('Using mse')
        lossm = mse

    # Full objective: bare loss plus the L2 penalty.
    loss = lambda params, data, L2: lossm(params, data) + L2 * l2_reg(params)

    def accuracy(params, images_and_labels):
        """Fraction of examples whose argmax prediction matches the label."""
        images, labels = images_and_labels
        return np.mean(
            np.array(np.argmax(f(params, images), axis=1) == np.argmax(labels,
                                                                       axis=1),
                     dtype=np.float32))

    # Define optimizer
    if FLAGS.std_wrn_sch:
        # Step schedule: multiply the rate by 0.2 at 30%, 60% and 90% of
        # the training epochs.
        lr = learning_rate
        first_epoch = int(60 / 200 * training_epochs)
        learning_rate_fn = optimizers.piecewise_constant(
            np.array([1, 2, 3]) * first_epoch * steps_per_epoch,
            np.array([lr, lr * 0.2, lr * 0.2**2, lr * 0.2**3]))
    else:
        learning_rate_fn = optimizers.make_schedule(learning_rate)

    if FLAGS.momentum:
        momentum = 0.9
    else:
        momentum = 0

    # NOTE: update_step and evaluate close over grad_loss, apply_fn,
    # get_params and batch_fn, which are bound further down before first use.
    @pmap
    def update_step(step, state, batch_state, L2):
        """One optimizer step on a fresh minibatch, averaging grads over devices."""
        batch, batch_state = batch_fn(batch_state)
        params = get_params(state)
        dparams = grad_loss(params, batch, L2)
        dparams = tree_map(lambda x: lax.psum(x, 'i') / ndevices, dparams)
        return step + 1, apply_fn(step, dparams, state), batch_state

    @pmap
    def evaluate(state, data, L2):
        """Return (regularized loss, accuracy, bare loss, L2 norm) on `data`."""
        params = get_params(state)
        lossmm = lossm(params, data)
        l2mm = l2_reg(params)
        return lossmm + L2 * l2mm, accuracy(params, data), lossmm, l2mm

    # Initialization and loading
    _, params = initparams(random.PRNGKey(0), (-1, 32, 32, 3))
    replicate_array = lambda x: \
        np.broadcast_to(x, (ndevices,) + x.shape)
    replicated_params = tree_map(replicate_array, params)

    grad_loss = jit(grad(loss))
    init_fn, apply_fn, get_params = optimizers.momentum(
        learning_rate_fn, momentum)
    apply_fn = jit(apply_fn)
    key = random.PRNGKey(FLAGS.seed)

    batchinit_fn, batch_fn = utils.sharded_minibatcher(batch_size,
                                                       ndevices,
                                                       transform=FLAGS.augment,
                                                       k=k,
                                                       mix=FLAGS.mix)

    batch_state = pmap(batchinit_fn)(random.split(key, ndevices), train)
    state = pmap(init_fn)(replicated_params)

    if FLAGS.checkpointing:
        # Loading of checkpoint if available/provided.
        single_state = init_fn(params)
        i0, load_state, load_params, filename0, batch_stateb = utils.load_weights(
            filename,
            single_state,
            params,
            full_file=FLAGS.load_w,
            ndevices=ndevices)
        if i0 is not None:
            filename = filename0
            if batch_stateb is not None:
                batch_state = batch_stateb
            if load_params is not None:
                state = pmap(init_fn)(load_params)
            else:
                state = load_state
        else:
            i0 = 0
    else:
        i0 = 0

    if FLAGS.steps_from_load:
        training_steps = i0 + training_steps

    # The result is unused below; kept because the call itself runs
    # batch_fn once under pmap before training starts.
    batch_xs, _ = pmap(batch_fn)(batch_state)

    train_loss = []
    train_accuracy = []
    lrL = []
    test_loss = []
    test_accuracy = []
    test_L2, test_lm, train_lm, train_L2 = [], [], [], []
    L2_t = []
    # Step of the last L2 decay; start of the current refractory period.
    idel0 = i0
    # Running best values for the L2 schedule. None means "not yet seen";
    # they are first set at the second measurement, before the decay test
    # (which requires more than two measurements) can read them.
    maxacc = None
    minloss = None
    start = time.time()

    step = pmap(lambda x: x)(i0 * np.ones((ndevices, )))

    # Start training loop
    if FLAGS.checkpointing:
        print('Evolving for {:}e and saving every {:}s'.format(
            training_epochs, FLAGS.checkpointing))

    print(
        'Epoch\tLearning Rate\tTrain bareLoss\t L2_norm \tTest Loss\tTrain Error\tTest Error\tTime / Epoch'
    )

    for i in range(i0, training_steps):
        if i % meas_step == 0:
            # Make Measurement
            l, a, lm, L2m = evaluate(state, test, L2p)
            test_loss += [np.mean(l)]
            test_accuracy += [np.mean(a)]
            test_lm += [np.mean(lm)]
            test_L2 += [np.mean(L2m)]
            # Training metrics are measured on one freshly drawn minibatch;
            # the advanced batch state is discarded.
            train_batch, _ = pmap(batch_fn)(batch_state)
            l, a, lm, L2m = evaluate(state, train_batch, L2p)

            train_loss += [np.mean(l)]
            train_accuracy += [np.mean(a)]
            train_lm += [np.mean(lm)]
            train_L2 += [np.mean(L2m)]
            L2_t.append(currL2)
            lrL += [learning_rate_fn(i)]

            if FLAGS.L2_sch and i > FLAGS.delay / currL2 + idel0 and len(
                    train_lm) > 2 and ((minloss <= train_lm[-1]
                                        and minloss <= train_lm[-2]) or
                                       (maxacc >= train_accuracy[-1]
                                        and maxacc >= train_accuracy[-2])):
                # If AutoL2 is on and we are beyond the refractory period, decay if the loss or error have increased in the last two measurements.
                print('Decaying L2 to', currL2 / FLAGS.L2dec)
                currL2 = currL2 / FLAGS.L2dec
                L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))
                idel0 = i

            elif FLAGS.L2_sch and len(train_lm) >= 2:
                # Update the minimum values.
                if maxacc is None:
                    # First update: seed from the previous measurement.
                    maxacc, minloss = train_accuracy[-2], train_lm[-2]
                else:
                    maxacc = max(train_accuracy[-2], maxacc)
                    minloss = min(train_lm[-2], minloss)

            if i % (meas_step * 10) == 0 or i == i0:
                # Save measurements to csv
                epoch = batch_size * i / 50000
                dt = (time.time() - start) / (meas_step * 10) * steps_per_epoch
                print(('{}\t' + ('{: .4f}\t' * 7)).format(
                    epoch, learning_rate_fn(i), train_lm[-1], train_L2[-1],
                    test_loss[-1], train_accuracy[-1], test_accuracy[-1], dt))

                start = time.time()
                data = {
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'train_acc': train_accuracy,
                    'test_acc': test_accuracy
                }
                data['train_bareloss'] = train_lm
                data['train_L2'] = train_L2
                data['test_bareloss'] = test_lm
                data['test_L2'] = test_L2
                data['L2_t'] = L2_t
                df = pd.DataFrame(data)

                df['learning_rate'] = lrL
                df['width'] = K
                df['batch_size'] = batch_size
                df['step'] = i0 + onp.arange(0, len(train_loss)) * meas_step

                # The whole history is rewritten on every save.
                df.to_csv(filedir, index=False)

        if FLAGS.checkpointing:
            # SAVE MODEL
            if i % FLAGS.checkpointing == 0 and i > i0:

                os.makedirs('weights/', exist_ok=True)
                saveparams = tree_flatten(state[0])[0]
                if ndevices > 1:
                    # Keep only the first device's copy of each leaf.
                    saveparams = [el[0] for el in saveparams]
                saveparams = np.concatenate(
                    [el.reshape(-1) for el in saveparams])

                step0 = i
                print('Step', i)
                print('saving at', filename, step0, 'size:', saveparams.shape)

                utils.save_weights(filename, step0, saveparams, batch_state)

        # UPDATE
        step, state, batch_state = update_step(step, state, batch_state, L2p)

    print('Training done')

    if FLAGS.TPU:
        # Marker file named after the TPU address, containing the log path.
        with open('done/' + TPU_ADDR, 'w') as fp:
            fp.write(filedir)
Beispiel #4
0
def run():
    """
    Run the experiment.

    Builds the model and data iterators, restores optimizer state from
    ``parse_args.ckpt_path`` when that file exists, then trains with Adam
    under a piecewise-constant learning rate. Per-iteration timings,
    periodic evaluation summaries, parameter snapshots, checkpoints and a
    final metadata pickle are written under ``dirname``. All output files
    share a prefix built from ``reg``, ``reg_type``, ``lam``, ``lam_fro``
    and ``lam_kin``.
    """

    # init the model first so that jax gets enough GPU memory before TFDS
    forward, model = init_model(43)  # how do you sleep at night
    grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))

    ds_train, ds_test_eval, meta = init_data()
    num_batches = meta["num_batches"]
    num_test_batches = meta["num_test_batches"]

    lr_schedule = optimizers.piecewise_constant(
        boundaries=[9000, 12750],  # 300 epochs, 425 epochs
        values=[1e-3, 1e-4, 1e-5])
    opt_init, opt_update, get_params = optimizers.adam(step_size=lr_schedule)

    # Function that rebuilds an optimizer-state pytree from a flat vector.
    unravel_opt = ravel_pytree(opt_init(model["params"]))[1]
    if os.path.exists(parse_args.ckpt_path):
        # NOTE: pickle.load can execute arbitrary code; only point
        # ckpt_path at checkpoints written by this script.
        with open(parse_args.ckpt_path, 'rb') as ckpt_file:
            state_dict = pickle.load(ckpt_file)

        opt_state = unravel_opt(state_dict["opt_state"])

        load_itr = state_dict["itr"]
    else:
        init_params = model["params"]
        opt_state = opt_init(init_params)

        load_itr = 0

    @jax.jit
    def update(_itr, _opt_state, _key, _batch):
        """
        Update the params based on grad for current batch.
        """
        return opt_update(_itr, grad_fn(get_params(_opt_state), _batch, _key),
                          _opt_state)

    @jax.jit
    def sep_losses(_opt_state, _batch, _key):
        """
        Convenience function for calculating losses separately.
        """
        z, delta_logp, r2_regs, fro_regs, kin_regs = model["forward_all"](
            _key, get_params(_opt_state), _batch)
        loss_ = _loss_fn(z, delta_logp)
        r2_reg_ = _reg_loss_fn(r2_regs)
        fro_reg_ = _reg_loss_fn(fro_regs)
        kin_reg_ = _reg_loss_fn(kin_regs)
        total_loss_ = loss_ + lam * r2_reg_ + lam_fro * fro_reg_ + lam_kin * kin_reg_
        return total_loss_, loss_, r2_reg_, fro_reg_, kin_reg_

    def evaluate_loss(opt_state, _key, ds_eval):
        """
        Convenience function for evaluating loss over an evaluation iterator
        in smaller batches.

        Draws ``num_test_batches`` batches from ``ds_eval`` and returns the
        weighted averages of (total loss, loss, r2 reg, fro reg, kin reg,
        nfe), weighted by ``len`` of each batch.
        """
        sep_loss_aug_, sep_loss_, sep_loss_r2_reg_, sep_loss_fro_reg_, sep_loss_kin_reg_, nfe, bs = \
            [], [], [], [], [], [], []

        for _ in range(num_test_batches):
            _key, = jax.random.split(_key, num=1)
            test_batch = next(ds_eval)

            test_batch_loss_aug_, test_batch_loss_, \
            test_batch_loss_r2_reg_, test_batch_loss_fro_reg_, test_batch_loss_kin_reg_ = \
                sep_losses(opt_state, test_batch, _key)

            if count_nfe:
                nfe.append(model["nfe"](_key, get_params(opt_state),
                                        test_batch))
            else:
                nfe.append(0)

            sep_loss_aug_.append(test_batch_loss_aug_)
            sep_loss_.append(test_batch_loss_)
            sep_loss_r2_reg_.append(test_batch_loss_r2_reg_)
            sep_loss_fro_reg_.append(test_batch_loss_fro_reg_)
            sep_loss_kin_reg_.append(test_batch_loss_kin_reg_)
            # NOTE(review): assumes len(test_batch) is the batch size, i.e.
            # that a batch is a single array -- confirm against init_data.
            bs.append(len(test_batch))

        sep_loss_aug_ = jnp.array(sep_loss_aug_)
        sep_loss_ = jnp.array(sep_loss_)
        sep_loss_r2_reg_ = jnp.array(sep_loss_r2_reg_)
        sep_loss_fro_reg_ = jnp.array(sep_loss_fro_reg_)
        sep_loss_kin_reg_ = jnp.array(sep_loss_kin_reg_)
        nfe = jnp.array(nfe)
        bs = jnp.array(bs)

        return jnp.average(sep_loss_aug_, weights=bs), \
               jnp.average(sep_loss_, weights=bs), \
               jnp.average(sep_loss_r2_reg_, weights=bs), \
               jnp.average(sep_loss_fro_reg_, weights=bs), \
               jnp.average(sep_loss_kin_reg_, weights=bs), \
               jnp.average(nfe, weights=bs)

    itr = 0
    info = collections.defaultdict(dict)

    key = rng

    # Common prefix of every output file written below.
    out_prefix = "%s/reg_%s_%s_lam_%.18e_lam_fro_%.18e_lam_kin_%.18e" % (
        dirname, reg, reg_type, lam, lam_fro, lam_kin)

    for epoch in range(parse_args.nepochs):
        for _ in range(num_batches):
            # The key and data iterator advance even for skipped iterations,
            # so a resumed run replays the same sequence up to load_itr.
            key, = jax.random.split(key, num=1)
            batch = next(ds_train)

            itr += 1

            if itr <= load_itr:
                continue

            update_start = time.time()
            opt_state = update(itr, opt_state, key, batch)
            # Wait for the computation to finish so the timing is meaningful.
            tree_flatten(opt_state)[0][0].block_until_ready()
            update_end = time.time()
            time_str = "%d %.18f %d\n" % (itr, update_end - update_start,
                                          load_itr)
            with open(out_prefix + "_time.txt", "a") as outfile:
                outfile.write(time_str)

            if itr % parse_args.test_freq == 0:

                loss_aug_, loss_, loss_r2_reg_, loss_fro_reg_, loss_kin_reg_, nfe_ = \
                    evaluate_loss(opt_state, key, ds_test_eval)

                print_str = 'Iter {:04d} | Total (Regularized) Loss {:.6f} | Loss {:.6f} | ' \
                            'r {:.6f} | fro {:.6f} | kin {:.6f} | ' \
                            'NFE {:.6f}'.format(itr, loss_aug_, loss_, loss_r2_reg_, loss_fro_reg_, loss_kin_reg_, nfe_)

                print(print_str)

                with open(out_prefix + "_info.txt", "a") as outfile:
                    outfile.write(print_str + "\n")

                info[itr]["loss_aug"] = loss_aug_
                info[itr]["loss"] = loss_
                info[itr]["loss_r2_reg"] = loss_r2_reg_
                info[itr]["loss_fro_reg"] = loss_fro_reg_
                info[itr]["loss_kin_reg"] = loss_kin_reg_
                info[itr]["nfe"] = nfe_

            if itr % parse_args.save_freq == 0:
                param_filename = "%s_%d_fargs.pickle" % (out_prefix, itr)
                fargs = get_params(opt_state)
                with open(param_filename, "wb") as outfile:
                    pickle.dump(fargs, outfile)

            if itr % parse_args.ckpt_freq == 0:
                state_dict = {
                    "opt_state": ravel_pytree(opt_state)[0],
                    "itr": itr,
                }
                # only save ckpts if a directory has been made for them (allow easy switching between v1 and v2)
                try:
                    with open(parse_args.ckpt_path, 'wb') as outfile:
                        pickle.dump(state_dict, outfile)
                except IOError:
                    print("Unable to save ck.pt %d" % itr, file=sys.stderr)
    meta = {"info": info, "args": parse_args}
    with open("%s_%d_meta.pickle" % (out_prefix, itr), "wb") as outfile:
        pickle.dump(meta, outfile)