Example No. 1
def main(batch_size=32,
         nr_filters=8,
         epochs=10,
         step_size=.001,
         decay_rate=.999995,
         model_path=Path('./pixelcnn.params')):
    loss, _ = PixelCNNPP(nr_filters=nr_filters)
    get_train_batches, test_batches = dataset(batch_size)
    key, init_key = random.split(PRNGKey(0))
    opt = Adam(exponential_decay(step_size, 1, decay_rate))
    state = opt.init(loss.init_parameters(next(test_batches), key=init_key))

    for epoch in range(epochs):
        for batch in get_train_batches():
            key, update_key = random.split(key)
            i = opt.get_step(state)

            state, train_loss = opt.update_and_get_loss(loss.apply,
                                                        state,
                                                        batch,
                                                        key=update_key,
                                                        jit=True)

            if i % 100 == 0 or i < 10:
                key, test_key = random.split(key)
                test_loss = loss.apply(opt.get_parameters(state),
                                       next(test_batches),
                                       key=test_key,
                                       jit=True)
                print(f"Epoch {epoch}, iteration {i}, "
                      f"train loss {train_loss:.3f}, "
                      f"test loss {test_loss:.3f} ")

        save(opt.get_parameters(state), model_path)
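
With decay_steps=1 the schedule multiplies the step size by decay_rate once per optimizer step. A minimal probe of that behaviour, assuming the jax.experimental.optimizers.exponential_decay convention lr(i) = step_size * decay_rate ** (i / decay_steps):

from jax.experimental import optimizers

schedule = optimizers.exponential_decay(step_size=.001, decay_steps=1,
                                         decay_rate=.999995)
print(schedule(0))        # 0.001
print(schedule(100_000))  # 0.001 * 0.999995 ** 100000, roughly 6.1e-4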
Example No. 2
    def testSgdVectorExponentialDecaySchedule(self):
        def loss(x):
            return np.dot(x, x)

        x0 = np.ones(2)
        step_schedule = optimizers.exponential_decay(0.1, 3, 2.)
        self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)
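
_CheckFuns is a private helper from JAX's own test suite. A self-contained sketch of the same setup using only the public optimizer API (the number of steps below is arbitrary):

import jax.numpy as jnp
from jax import grad
from jax.experimental import optimizers

def loss(x):
    return jnp.dot(x, x)

x0 = jnp.ones(2)
step_schedule = optimizers.exponential_decay(0.1, 3, 2.)
opt_init, opt_update, get_params = optimizers.sgd(step_schedule)

opt_state = opt_init(x0)
for i in range(10):
    grads = grad(loss)(get_params(opt_state))
    opt_state = opt_update(i, grads, opt_state)
print(get_params(opt_state))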
Example No. 3
def get_optimizer(optim_config):
    """ returns an ADAM optimizer with exponential learning-rate
  decay schedule specified in the config """

    learning_rate = optimizers.exponential_decay(
        optim_config['base_lr'], optim_config['lr_decay_steps'],
        optim_config['lr_decay_rate'])
    opt = optimizers.adam(learning_rate)

    return opt
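
A hypothetical optim_config showing the keys this helper reads (the values here are illustrative only); optimizers.adam returns the usual (init_fun, update_fun, get_params) triple, so the result unpacks into three functions:

optim_config = {
    'base_lr': 1e-3,         # initial learning rate
    'lr_decay_steps': 1000,  # steps per decay period
    'lr_decay_rate': 0.9,    # multiplicative decay per period
}
opt_init, opt_update, get_params = get_optimizer(optim_config)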
Example No. 4
def get_scheduler(lr, train_steps, name='constant'):
    name = name.lower()
    if name == 'constant':
        scheduler = optimizers.constant(lr)
    elif name == 'inverse_time_decay':
        decay_steps = int(train_steps // 5)
        scheduler = optimizers.inverse_time_decay(lr, decay_steps, 2)
    elif name == 'exponential_decay':
        decay_steps = int(train_steps // 3)
        scheduler = optimizers.exponential_decay(lr, decay_steps, 0.3)
    else:
        raise ValueError(f'Unsupported scheduler {name}. '
                         f'Supported schedulers: {supported_schedulers()}')
    print(f'Loaded scheduler {name}: {scheduler}')
    return scheduler
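
Each scheduler returned here is just a function from step index to learning rate, so it can be probed directly; the lr and train_steps values below are illustrative:

sched = get_scheduler(lr=1e-3, train_steps=300, name='exponential_decay')
for step in (0, 100, 200, 300):
    print(step, sched(step))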
Example No. 5
def schedule_maker(schedule_tuple, learn_rate):
    """
    Return a scheduler function given a tuple of the form:
        (sched_name, decay_steps, min_lr)

    This just wraps the existing JAX schedulers behind a simplified syntax.
    """
    sched_type = schedule_tuple[0]
    assert learn_rate >= 0
    assert sched_type in ['const', 'exp', 'poly', 'piecewise']

    if sched_type == 'const':
        # Constant learning rate
        sched_fun = jopt.constant(learn_rate)
    elif sched_type == 'exp':
        # Exponentially decaying learning rate
        sched_fun = jopt.exponential_decay(learn_rate, schedule_tuple[1], 0.5)
    elif sched_type == 'poly':
        # Harmonically decaying stepped learning rate
        sched_fun = jopt.inverse_time_decay(learn_rate,
                                            schedule_tuple[1],
                                            5,
                                            staircase=True)
    elif sched_type == 'piecewise':
        # Piecewise constant learning rate, drops by factor of 10 each time
        step_len = schedule_tuple[1]
        assert step_len > 0
        bounds = [step_len * i for i in range(1, 10)]
        values = [learn_rate * 10**(-i) for i in range(10)]
        sched_fun = jopt.piecewise_constant(bounds, values)

    def my_sched_fun(epoch):
        lr = sched_fun(epoch)
        if len(schedule_tuple) <= 2:
            return lr
        else:
            return jnp.maximum(lr, schedule_tuple[2])

    return my_sched_fun
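
For instance, an exponentially decaying schedule with a minimum learning-rate floor (jopt and jnp are assumed to be jax.experimental.optimizers and jax.numpy, as in the code above):

sched = schedule_maker(('exp', 1000, 1e-5), learn_rate=1e-2)
print(sched(0))       # 1e-2
print(sched(5000))    # 1e-2 * 0.5 ** 5, about 3.1e-4
print(sched(20000))   # the raw value is below 1e-5, so the floor kicks in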
Example No. 6
def optimize_lfads(key, init_params, hps, opt_hps, train_data_fun,
                   eval_data_fun):
    """Optimize the LFADS model and print batch based optimization data.

  This loop is at the cpu nonjax-numpy level.

  Arguments:
    init_params: a dict of parameters to be trained
    hps: dict of lfads model HPs
    opt_hps: dict of optimization HPs
    train_data_fun: function that takes a key and returns
      nexamples x time x ndims np array of data for training
    eval_data_fun: function that takes a key and returns
      nexamples x time x ndims np array of data for held out error
  Returns:
    a 2-tuple of (dict of trained parameters, dict of optimization details)"""

    # Begin optimization loop.
    all_tlosses = []
    all_elosses = []

    # Build some functions used in optimization.
    kl_warmup_fun = get_kl_warmup_fun(opt_hps)
    decay_fun = optimizers.exponential_decay(opt_hps['step_size'],
                                             opt_hps['decay_steps'],
                                             opt_hps['decay_factor'])

    opt_init, opt_update, get_params = optimizers.adam(step_size=decay_fun,
                                                       b1=opt_hps['adam_b1'],
                                                       b2=opt_hps['adam_b2'],
                                                       eps=opt_hps['adam_eps'])
    opt_state = opt_init(init_params)

    def update_w_gc(i, opt_state, hps, opt_hps, key, x_bxt, kl_warmup):
        """Update fun for gradients, includes gradient clipping."""
        params = get_params(opt_state)
        grads = grad(lfads.training_loss_jit)(params, hps, key, x_bxt,
                                              kl_warmup, opt_hps['keep_rate'])
        clipped_grads = optimizers.clip_grads(grads, opt_hps['max_grad_norm'])
        return opt_update(i, clipped_grads, opt_state)

    update_w_gc_jit = jit(update_w_gc, static_argnums=(2, 3))

    # Run the optimization, pausing every so often to collect data and
    # print status.
    batch_size = hps['batch_size']
    num_batches = opt_hps['num_batches']
    print_every = opt_hps['print_every']
    num_opt_loops = int(num_batches / print_every)
    params = get_params(opt_state)
    for oidx in range(num_opt_loops):
        batch_idx_start = oidx * print_every
        start_time = time.time()
        key, tkey, dtkey1, dtkey2, dekey1, dekey2 = \
            random.split(random.fold_in(key, oidx), 6)
        opt_state = optimize_core_jit(tkey, batch_idx_start, print_every,
                                      update_w_gc_jit, kl_warmup_fun,
                                      opt_state, hps, opt_hps, train_data_fun)
        batch_time = time.time() - start_time

        # Losses
        params = get_params(opt_state)
        batch_pidx = batch_idx_start + print_every
        kl_warmup = kl_warmup_fun(batch_idx_start)
        # Training loss
        #didxs = onp.random.randint(0, train_data.shape[0], batch_size)
        #x_bxt = train_data[didxs].astype(onp.float32)
        x_bxt = train_data_fun(dtkey1)
        tlosses = lfads.losses_jit(params, hps, dtkey2, x_bxt, kl_warmup, 1.0)

        # Evaluation loss
        #didxs = onp.random.randint(0, eval_data.shape[0], batch_size)
        #ex_bxt = eval_data[didxs].astype(onp.float32)
        ex_bxt = eval_data_fun(dekey1)
        elosses = lfads.losses_jit(params, hps, dekey2, ex_bxt, kl_warmup, 1.0)
        # Saving, printing.
        resps = softmax(params['prior']['resps'])
        rmin = onp.min(resps)
        rmax = onp.max(resps)
        rmean = onp.mean(resps)
        rstd = onp.std(resps)

        all_tlosses.append(tlosses)
        all_elosses.append(elosses)
        s1 = "Batches {}-{} in {:0.2f} sec, Step size: {:0.5f}"
        s2 = "    Training losses {:0.0f} = NLL {:0.0f} + KL {:0.1f},{:0.1f} + L2 {:0.2f} + II L2 {:0.2f} + <II> {:0.2f} "
        s3 = "        Eval losses {:0.0f} = NLL {:0.0f} + KL {:0.1f},{:0.1f} + L2 {:0.2f} + II L2 {:0.2f} + <II> {:0.2f} "
        s4 = "        Resps: min {:0.4f}, mean {:0.4f}, max {:0.4f}, std {:0.4f}"
        print(
            s1.format(batch_idx_start + 1, batch_pidx, batch_time,
                      decay_fun(batch_pidx)))
        print(
            s2.format(tlosses['total'], tlosses['nlog_p_xgz'],
                      tlosses['kl_prescale'], tlosses['kl'], tlosses['l2'],
                      tlosses['ii_l2'], tlosses['ii_tavg']))
        print(
            s3.format(elosses['total'], elosses['nlog_p_xgz'],
                      elosses['kl_prescale'], elosses['kl'], elosses['l2'],
                      elosses['ii_l2'], elosses['ii_tavg']))
        print(s4.format(rmin, rmean, rmax, rstd))

        tlosses_thru_training = utils.merge_losses_dicts(all_tlosses)
        elosses_thru_training = utils.merge_losses_dicts(all_elosses)
        optimizer_details = {
            'tlosses': tlosses_thru_training,
            'elosses': elosses_thru_training
        }

    return params, optimizer_details
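
The step size printed via decay_fun(batch_pidx) follows the usual jax.experimental.optimizers.exponential_decay rule; an equivalent schedule written out for reference (a sketch of the behaviour, not the library code itself):

def exponential_decay_sketch(step_size, decay_steps, decay_rate):
    # The step size is multiplied by decay_rate once every decay_steps
    # optimizer steps, interpolated smoothly in between.
    def schedule(i):
        return step_size * decay_rate ** (i / decay_steps)
    return schedule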
Example No. 7
    T_train = df.pop("week").values
    E_train = df.pop("arrest").values
    X_train = df.values

    return X_train, T_train, E_train


x_train, t_train, e_train = get_rossi_dataset()

model = Model([Dense(18), Relu])

model.compile(
    optimizer=optimizers.adam,
    optimizer_kwargs={
        "step_size": optimizers.exponential_decay(0.01, 10, 0.999)
    },
    loss=losses.NonParametric(),
)

model.fit(x_train, t_train, e_train, epochs=2, batch_size=32)

print(model.predict_survival_function(x_train[0], np.arange(0, 10)))

dump(model, "testsavefile")
model = load("testsavefile")

print(model.predict_survival_function(x_train[0], np.arange(0, 10)))

model.fit(
    x_train,
Example No. 8
def train(data_dict, train_dict, seed_dict, results_dict):
    """Train HM-NLICA model using a minibatch implementation of the algorithm
    described in the paper.

    Args:
        data_dict (dict.): dictionary of required data in the form of:
            {'x_data': observed signals (array),
             's_data': true latent component, for evaluation (array),
             'state_seq': true latent state sequence (array)}.
        train_dict (dict.): dictionary of variables related to optimization
            of form:
                {'mix_depth': num. layers in mixing/estimator MLP (int), for
                    example mix_depth=1 is linear ICA,
                 'hidden_size': num. hidden units per MLP layer (int),
                 'learning_rate': step size for optimizer (float),
                 'num_epochs': num. training epochs (int),
                 'subseq_len': length of time sequences in a minibatch (int),
                 'minib_size': num. sub-sequences in a minibatch (int),
                 'decay_rate': multiplier for decaying learning rate (float),
                 'decay_steps': num. epochs per which to decay lr (int)}.
        seed_dict (dict.): dictionary of seeds for reproducible stochasticity
            of form:
                {'est_mlp_seed': seed to initialize MLP parameters (int),
                 'est_distrib_seed': seed to initialize exp fam params (int)}.
        results_dict (dict.): stores data to save (see main.py).

    Returns:
        s_est (array): estimated independent components.
        sort_idx (array): best matching indices of components to true indices.
        results_dict (dict): to save all evaluation and training results.
        est_params (list): list of all estimated parameter arrays.
    """
    # unpack data
    x = data_dict['x_data']
    s_true = data_dict['s_data']
    state_seq = data_dict['state_seq']

    # set data dimensions
    N = x.shape[1]
    T = x.shape[0]
    K = len(np.unique(state_seq))

    # unpack training variables
    mix_depth = train_dict['mix_depth']
    hidden_size = train_dict['hidden_size']
    learning_rate = train_dict['learning_rate']
    num_epochs = train_dict['num_epochs']
    subseq_len = train_dict['subseq_len']
    minib_size = train_dict['minib_size']
    decay_rate = train_dict['decay_rate']
    decay_steps = train_dict['decay_steps']

    print("Training with N={n}, T={t}, K={k}\t"
          "mix_depth={md}".format(n=N, t=T, k=K, md=mix_depth))

    # initialize parameters for mlp function approximator
    key = jrandom.PRNGKey(seed_dict['est_mlp_seed'])
    layer_sizes = [N] + [hidden_size] * (mix_depth - 1) + [N]
    mlp_params = init_mlp_params(key, layer_sizes)

    # initialize parameters for estimating distribution parameters
    np.random.seed(seed_dict['est_distrib_seed'])
    mu_est = np.random.uniform(-5., 5., size=(K, N))
    var_est = np.random.uniform(1., 2., size=(K, N))
    D_est = np.zeros(shape=(K, N, N))
    for k in range(K):
        D_est[k] = np.diag(var_est[k])

    # initialize transition parameter estimates
    A_est = np.eye(K) + 0.05
    A_est = A_est / A_est.sum(1, keepdims=True)
    pi_est = A_est.sum(0) / A_est.sum()

    # set up optimizer
    schedule = optimizers.exponential_decay(learning_rate,
                                            decay_steps=decay_steps,
                                            decay_rate=decay_rate)
    opt_init, opt_update, get_params = optimizers.adam(schedule)

    # set up loss function and training step
    @jit
    def calc_loss(params, input_data, marginal_posteriors, mu_est, D_est,
                  num_subseqs):
        """Calculates the loss for gradient M-step for function estimator.
        """
        lp_x, lp_x_exc_J, lp_J, _ = mbatch_emission_likelihood(
            params, input_data, mu_est, D_est)
        expected_lp_x = jnp.sum(marginal_posteriors * lp_x, -1)
        # note correction for bias below
        return -expected_lp_x.mean() * num_subseqs

    @jit
    def training_step(iter_num, input_data, marginal_posteriors, mu_est, D_est,
                      opt_state, num_subseqs):
        """Performs gradient m-step on the function estimator
               MLP parameters.
        """
        params = get_params(opt_state)
        loss, g = value_and_grad(
            calc_loss, argnums=0)(params, input_data,
                                  lax.stop_gradient(marginal_posteriors),
                                  mu_est, D_est, num_subseqs)
        return loss, opt_update(iter_num, g, opt_state)

    # function to load subsequence data for minibatches
    @jit
    def get_subseq_data(orig_data, subseq_array_to_fill):
        """Collects all sub-sequences into an array.
        """
        subseq_data = subseq_array_to_fill
        num_subseqs = subseq_data.shape[0]
        subseq_len = subseq_data.shape[1]

        def body_fun(i, subseq_data):
            """Function to loop over.
            """
            subseq_i = lax.dynamic_slice_in_dim(orig_data, i, subseq_len)
            subseq_data = ops.index_update(subseq_data, ops.index[i, :, :],
                                           subseq_i)
            return subseq_data

        return lax.fori_loop(0, num_subseqs, body_fun, subseq_data)

    # set up minibatch training
    num_subseqs = T - subseq_len + 1
    assert num_subseqs >= minib_size
    num_full_minibs, remainder = divmod(num_subseqs, minib_size)
    num_minibs = num_full_minibs + bool(remainder)
    sub_data_holder = jnp.zeros((num_subseqs, subseq_len, N))
    sub_data = get_subseq_data(x, sub_data_holder)
    print("T: {t}\t"
          "subseq_len: {slen}\t"
          "minibatch size: {mbs}\t"
          "num minibatches: {nbs}".format(t=T,
                                          slen=subseq_len,
                                          mbs=minib_size,
                                          nbs=num_minibs))

    # initialize and train
    best_logl = -np.inf
    itercount = itertools.count()
    opt_state = opt_init(mlp_params)
    all_subseqs_idx = np.arange(num_subseqs)
    for epoch in range(num_epochs):
        tic = time.time()
        # shuffle subseqs for added stochasticity
        np.random.shuffle(all_subseqs_idx)
        sub_data = sub_data.copy()[all_subseqs_idx]
        # train over minibatches
        for batch in range(num_minibs):
            # select sub-sequence for current minibatch
            batch_data = sub_data[batch * minib_size:(batch + 1) * minib_size]

            # calculate emission likelihood using most recent parameters
            params = get_params(opt_state)
            logp_x, logp_x_exc_J, lpj, s_est = mbatch_emission_likelihood(
                params, batch_data, mu_est, D_est)

            # forward-backward algorithm
            marg_posteriors, pw_posteriors, scalers = mbatch_fwd_bwd_algo(
                logp_x, A_est, pi_est)

            # exact M-step for mean and variance
            mu_est, D_est, A_est, pi_est = mbatch_m_step(
                s_est, marg_posteriors, pw_posteriors)

            # SGD for mlp parameters
            loss, opt_state = training_step(next(itercount), batch_data,
                                            marg_posteriors, mu_est, D_est,
                                            opt_state, num_subseqs)

        # gather full data after each epoch for evaluation
        params_latest = get_params(opt_state)
        logp_x_all, _, _, s_est_all = emission_likelihood(
            params_latest, x, mu_est, D_est)
        _, _, scalers = forward_backward_algo(logp_x_all, A_est, pi_est)
        logl_all = np.log(scalers).sum()

        # viterbi to estimate state prediction
        est_seq = viterbi_algo(logp_x_all, A_est, pi_est)
        cluster_acc = clustering_acc(np.array(est_seq), np.array(state_seq))

        # evaluate correlation of estimated and true independent components
        mean_abs_corr, s_est_sorted, sort_idx = matching_sources_corr(
            np.array(s_est_all), np.array(s_true))

        # save results
        if logl_all > best_logl:
            best_logl = logl_all
            best_logl_corr = mean_abs_corr
            best_logl_acc = cluster_acc
            results_dict['results'].append({
                'best_logl': best_logl,
                'best_logl_corr': mean_abs_corr,
                'best_logl_acc': cluster_acc
            })

        results_dict['results'].append({
            'epoch': epoch,
            'logl': logl_all,
            'corr': mean_abs_corr,
            'acc': cluster_acc
        })
        # print them
        print("Epoch: [{0}/{1}]\t"
              "LogL: {logl:.2f}\t"
              "mean corr between s and s_est {corr:.2f}\t"
              "acc {acc:.2f}\t"
              "elapsed {time:.2f}".format(epoch,
                                          num_epochs,
                                          logl=logl_all,
                                          corr=mean_abs_corr,
                                          acc=cluster_acc,
                                          time=time.time() - tic))

    # pack data into tuples
    results_dict['results'].append({
        'best_logl': best_logl,
        'best_logl_corr': best_logl_corr,
        'best_logl_acc': best_logl_acc
    })
    est_params = (mu_est, D_est, A_est, est_seq)
    return s_est, sort_idx, results_dict, est_params
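
Note that training_step indexes the schedule with next(itercount), i.e. with the minibatch update count, while decay_steps is documented in epochs. If the intent is one decay period per decay_steps epochs, one option is to rescale decay_steps by the number of minibatch updates per epoch before building the schedule; a sketch under that assumption, not necessarily what the original authors intended:

updates_per_epoch = num_minibs  # minibatch updates per epoch
schedule = optimizers.exponential_decay(learning_rate,
                                        decay_steps=decay_steps * updates_per_epoch,
                                        decay_rate=decay_rate)
opt_init, opt_update, get_params = optimizers.adam(schedule)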
Example No. 9
 def testSgdVectorExponentialDecaySchedule(self):
   def loss(x, _): return np.dot(x, x)
   x0 = np.ones(2)
   num_iters = 100
   step_schedule = optimizers.exponential_decay(0.1, 3, 2.)
   self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_schedule)
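
With decay_rate=2. the "decay" schedule actually doubles the step size every 3 iterations; probing the schedule makes the values the test runs under explicit:

step_schedule = optimizers.exponential_decay(0.1, 3, 2.)
print([step_schedule(i) for i in (0, 3, 6, 9)])  # roughly [0.1, 0.2, 0.4, 0.8]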
Example No. 10
    latent_shape, init_encoder_params = encoder_init(next(keys),
                                                     (n_timesteps, n_neurons))
    decoded_shape, init_decoder_params = decoder_init(next(keys), latent_shape)
    init_params = init_encoder_params, init_decoder_params

    # Optimizer #


    def kl_warmup_fun(batch_idx):
        progress_frac = ((batch_idx - kl_warmup_start) /
                         (kl_warmup_end - kl_warmup_start))
        _warmup = np.where(batch_idx < kl_warmup_start, kl_min,
                           (kl_max - kl_min) * progress_frac + kl_min)
        return np.where(batch_idx > kl_warmup_end, kl_max, _warmup)

    decay_fun = optimizers.exponential_decay(STEP_SIZE, DECAY_STEPS,
                                             DECAY_FACTOR)
    # TODO: Check exponential_decay when using epochs / batches.

    opt_init, opt_update, get_params = optimizers.adam(step_size=decay_fun,
                                                       b1=0.9,
                                                       b2=0.999,
                                                       eps=1e-1)  # Seems big
    opt_state = opt_init(init_params)

    @jit
    def run_epoch(rng, _opt_state, epoch_idx):
        _rng, dat_keys = utils.keygen(rng, 1)
        _rng, batch_keys = utils.keygen(_rng, num_batches)

        # Randomize epoch data.
        epoch_data = random.shuffle(next(dat_keys), X_train, axis=0)
Example No. 11
def optimize_lfads(key,
                   init_params,
                   hps,
                   opt_hps,
                   train_data_fun,
                   eval_data_fun,
                   ncompleted_batches=0,
                   opt_state=None,
                   callback_fun=None,
                   do_print=True):
    """Optimize the LFADS model and print batch based optimization data.

  This loop is at the cpu nonjax-numpy level.

  Arguments:
    key: random.PRNGKey for randomness
    init_params: a dict of parameters to be trained
    hps: dict of lfads model HPs
    opt_hps: dict of optimization HPs
    train_data_fun: function that takes a key and returns
      nexamples x time x ndims np array of data for training
    eval_data_fun: function that takes a key and returns
      nexamples x time x ndims np array of data for held out error
    ncompleted_batches: (default 0), use this to restart training in the middle
      of the batch count. Used in tandem with opt_state (below).
    opt_state: (default None) 3-tuple (params, m - 1st moment, v - 2nd moment) 
      from jax.experimental.optimizers.adam (None value starts optimizer anew).
      The params in opt_state[0] will *override* the init_params argument.
    callback_fun: (default None) function that the optimize routine will call
      every print_every loops to do whatever the user wants, typically
      saving or reporting to a hyperparameter tuner.
      callback_fun parameters are
        (current_batch_idx:int, hps:dict, opt_hps:dict, 
         params:dict, opt_state:tuple,
         tlosses:dict, elosses:dict) 
    do_print: (default True), print loss information
  Returns:
    A 3-tuple of 
      (trained_params, 
       opt_details - dictionary of optimization losses through training, 
       (opt_state - a 3-tuple of trained params in odd pytree form, 
         m 1st moment, v 2nd moment)).
  """

    # Begin optimization loop.
    all_tlosses = []
    all_elosses = []

    # Build some functions used in optimization.
    kl_warmup_fun = get_kl_warmup_fun(opt_hps)
    decay_fun = optimizers.exponential_decay(opt_hps['step_size'],
                                             opt_hps['decay_steps'],
                                             opt_hps['decay_factor'])

    opt_init, opt_update, get_params = optimizers.adam(step_size=decay_fun,
                                                       b1=opt_hps['adam_b1'],
                                                       b2=opt_hps['adam_b2'],
                                                       eps=opt_hps['adam_eps'])
    print_every = opt_hps['print_every']
    if ncompleted_batches > 0:
        print('Starting batch count at %d.' % (ncompleted_batches))
        assert ncompleted_batches % print_every == 0
        opt_loop_start_idx = int(ncompleted_batches / print_every)
    else:
        opt_loop_start_idx = 0
    if opt_state is not None:
        print('Received opt_state, ignoring init_params argument.')
    else:
        opt_state = opt_init(init_params)

    def update_w_gc(i, opt_state, hps, opt_hps, key, x_bxt, kl_warmup):
        """Update fun for gradients, includes gradient clipping."""
        params = get_params(opt_state)
        grads = grad(lfads.training_loss_jit)(params, hps, key, x_bxt,
                                              kl_warmup, opt_hps['keep_rate'])
        clipped_grads = optimizers.clip_grads(grads, opt_hps['max_grad_norm'])
        return opt_update(i, clipped_grads, opt_state)

    update_w_gc_jit = jit(update_w_gc, static_argnums=(2, 3))

    # Run the optimization, pausing every so often to collect data and
    # print status.
    batch_size = hps['batch_size']
    num_batches = opt_hps['num_batches']
    assert num_batches % print_every == 0
    num_opt_loops = int(num_batches / print_every)
    params = get_params(opt_state)
    for oidx in range(opt_loop_start_idx, num_opt_loops):
        batch_idx_start = oidx * print_every
        start_time = time.time()
        key, tkey, dtkey1, dtkey2, dekey1, dekey2 = \
            random.split(random.fold_in(key, oidx), 6)
        opt_state = optimize_core_jit(tkey, batch_idx_start, print_every,
                                      update_w_gc_jit, kl_warmup_fun,
                                      opt_state, hps, opt_hps, train_data_fun)
        batch_time = time.time() - start_time

        # Losses
        params = get_params(opt_state)
        batch_pidx = batch_idx_start + print_every
        kl_warmup = kl_warmup_fun(batch_idx_start)
        # Training loss
        x_bxt = train_data_fun(dtkey1)
        tlosses = lfads.losses_jit(params, hps, dtkey2, x_bxt, kl_warmup, 1.0)

        # Evaluation loss
        ex_bxt = eval_data_fun(dekey1)
        elosses = lfads.losses_jit(params, hps, dekey2, ex_bxt, kl_warmup, 1.0)
        # Saving, printing.
        resps = softmax(params['prior']['resps'])
        rmin = onp.min(resps)
        rmax = onp.max(resps)
        rmean = onp.mean(resps)
        rstd = onp.std(resps)

        all_tlosses.append(tlosses)
        all_elosses.append(elosses)
        if do_print:
            s1 = "Batches {}-{} in {:0.2f} sec, Step size: {:0.5f}"
            s2 = "    Training losses {:0.0f} = NLL {:0.0f} + KL {:0.1f},{:0.1f} + L2 {:0.2f} + II L2 {:0.2f} + <II> {:0.2f} "
            s3 = "        Eval losses {:0.0f} = NLL {:0.0f} + KL {:0.1f},{:0.1f} + L2 {:0.2f} + II L2 {:0.2f} + <II> {:0.2f} "
            s4 = "        Resps: min {:0.4f}, mean {:0.4f}, max {:0.4f}, std {:0.4f}"
            print(
                s1.format(batch_idx_start + 1, batch_pidx, batch_time,
                          decay_fun(batch_pidx)))
            print(
                s2.format(tlosses['total'], tlosses['nlog_p_xgz'],
                          tlosses['kl_prescale'], tlosses['kl'], tlosses['l2'],
                          tlosses['ii_l2'], tlosses['ii_tavg']))
            print(
                s3.format(elosses['total'], elosses['nlog_p_xgz'],
                          elosses['kl_prescale'], elosses['kl'], elosses['l2'],
                          elosses['ii_l2'], elosses['ii_tavg']))
            print(s4.format(rmin, rmean, rmax, rstd))

        if callback_fun is not None:
            callback_fun(batch_pidx, hps, opt_hps, params, opt_state, tlosses,
                         elosses)

    tlosses_thru_training = utils.merge_losses_dicts(all_tlosses)
    elosses_thru_training = utils.merge_losses_dicts(all_elosses)
    optimizer_details = {
        'tlosses': tlosses_thru_training,
        'elosses': elosses_thru_training
    }

    return params, optimizer_details, opt_state
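
A sketch of how the restart arguments might be used: run once, keep the returned opt_state, then resume with a larger batch budget (names come from the signature above; the data functions and HP dicts are whatever the caller already has):

# First run: train for opt_hps['num_batches'] batches.
params, details, opt_state = optimize_lfads(key, init_params, hps, opt_hps,
                                            train_data_fun, eval_data_fun)

# Resume: raise the batch budget, report how many batches are already done
# (a multiple of print_every), and pass the optimizer state back in; the
# params inside opt_state override init_params.
completed = opt_hps['num_batches']
opt_hps['num_batches'] = 2 * completed
params, details, opt_state = optimize_lfads(key, init_params, hps, opt_hps,
                                            train_data_fun, eval_data_fun,
                                            ncompleted_batches=completed,
                                            opt_state=opt_state)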
Example No. 12
from jax.experimental.stax import Dense, Dropout, Tanh, Relu, randn
from jax.experimental import optimizers
import pandas as pd
import lifelike.losses as losses
from lifelike import Model
from lifelike.callbacks import *
from datasets.loaders import *


x_train, t_train, e_train = get_generated_churn_dataset()

model = Model([Dense(8), Relu, Dense(12), Relu, Dense(16), Relu])

model.compile(
    optimizer=optimizers.adam,
    optimizer_kwargs={"step_size": optimizers.exponential_decay(0.001, 1, 0.9995)},
    weight_l2=0.00,
    smoothing_l2=100.,
    loss=losses.NonParametric()
)

print(model)

model.fit(
    x_train,
    t_train,
    e_train,
    epochs=10000,
    batch_size=10000,
    validation_split=0.1,
    callbacks=[
Example No. 13
    data_loss = multiclass_xent(logits, batch['labels'])
    reg_loss = l2_pen * renn.norm(params)
    return data_loss + reg_loss


f_df = jax.value_and_grad(xent)


@jax.jit
def accuracy(params, batch):
    logits = apply_fun(params, batch['inputs'])
    predictions = jnp.argmax(logits, axis=1)
    return jnp.mean(predictions == batch['labels'])


learning_rate = optimizers.exponential_decay(2e-3, 1000, 0.8)
init_opt, update_opt, get_params = optimizers.adam(learning_rate)

state = init_opt(initial_params)
losses = []


@jax.jit
def step(k, opt_state, batch):
    params = get_params(opt_state)
    loss, gradients = f_df(params, batch)
    new_state = update_opt(k, gradients, opt_state)
    return new_state, loss


def test_acc(params):
Example No. 14
#     return memo[n]

# def reverse_fib(n):
#     """ Return the index of the greatest number from the Fibonacci sequence,
#     that is smaller than or equal to n. """
#     i = 0
#     while fib(i+1) <= n:
#         i += 1
#     return i

#-------------------- optimizer and LR schedule --------------------#
step_size = 1e-2
decay_rate = 0.65  # 0.65 ** 10 ≈ 0.013 ---> ten decay periods shrink the step size by roughly two orders of magnitude
decay_steps = 10
step_fn = optimizers.exponential_decay(step_size=step_size,
                                       decay_rate=decay_rate,
                                       decay_steps=decay_steps)
opt_init, opt_update, get_params = optimizers.nesterov(step_size=step_fn,
                                                       mass=0.9)
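
# Quick numeric check of the schedule defined above; the printed values are
# approximate (step_fn(i) = step_size * decay_rate ** (i / decay_steps)).
print(step_fn(0))    # 0.01
print(step_fn(50))   # ~1.2e-3  (0.01 * 0.65 ** 5)
print(step_fn(100))  # ~1.3e-4  (0.01 * 0.65 ** 10)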

#-------------------- params training utilities --------------------#
reg = 3e-5
clip_max_grad = 10.0

init_fun, apply_fun = model_fn()
apply_fun = jax.jit(apply_fun)


@jax.jit
def l2_regularizer(params, reg=reg):
    """ Return the L2 regularization loss. """
Example No. 15
def optimize_fps(rnn_fun, fp_candidates, hps, do_print=True):
    """Find fixed points of the rnn via optimization.

  This loop runs at the CPU (non-JAX) level.

  Arguments:
    rnn_fun : RNN one step update function for a single hidden state vector
      h_t -> h_t+1, for which the fixed point candidates are trained to be 
      fixed points
    fp_candidates: np array with shape (batch size, state dim) of hidden states 
      of RNN to start training for fixed points
    hps: fixed point hyperparameters
    do_print: Print useful information? 

  Returns:
    a 2-tuple of (np array of numerically optimized fixed points,
      dict of optimization details)"""

    total_fp_loss_fun = get_total_fp_loss_fun(rnn_fun)

    def get_update_fun(opt_update):
        """Update the parameters using gradient descent.

    Arguments:
      opt_update: a function that updates the parameters (from jax.optimizers)

    Returns:
      a function which updates the parameters according to the optimizer
    """
        def update(i, opt_state):
            params = optimizers.get_params(opt_state)
            grads = grad(total_fp_loss_fun)(params)
            return opt_update(i, grads, opt_state)

        return update

    # Build some functions used in optimization.
    decay_fun = optimizers.exponential_decay(hps['step_size'],
                                             hps['decay_steps'],
                                             hps['decay_factor'])
    opt_init, opt_update = optimizers.adam(step_size=decay_fun,
                                           b1=hps['adam_b1'],
                                           b2=hps['adam_b2'],
                                           eps=hps['adam_eps'])
    opt_state = opt_init(fp_candidates)
    update_fun = get_update_fun(opt_update)

    # Run the optimization, pausing every so often to collect data and
    # print status.
    batch_size = fp_candidates.shape[0]
    num_batches = hps['num_batches']
    print_every = hps['opt_print_every']
    num_opt_loops = int(num_batches / print_every)
    fps = optimizers.get_params(opt_state)
    fp_losses = []
    do_stop = False
    for oidx in range(num_opt_loops):
        if do_stop:
            break
        batch_idx_start = oidx * print_every
        start_time = time.time()
        opt_state = optimize_fp_core_jit(batch_idx_start, print_every,
                                         update_fun, opt_state)
        batch_time = time.time() - start_time

        # Training loss
        fps = optimizers.get_params(opt_state)
        batch_pidx = batch_idx_start + print_every
        total_fp_loss = total_fp_loss_fun(fps)
        fp_losses.append(total_fp_loss)

        # Saving, printing.
        if do_print:
            s = "    Batches {}-{} in {:0.2f} sec, Step size: {:0.5f}, Training loss {:0.5f}"
            print(
                s.format(batch_idx_start + 1, batch_pidx, batch_time,
                         decay_fun(batch_pidx), total_fp_loss))

        if total_fp_loss < hps['fp_opt_stop_tol']:
            do_stop = True
            if do_print:
                print(
                    'Stopping as mean training loss {:0.5f} is below tolerance {:0.5f}.'
                    .format(total_fp_loss, hps['fp_opt_stop_tol']))
        optimizer_details = {'fp_losses': fp_losses}
    return fps, optimizer_details
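
A hypothetical hps dict showing the keys this routine reads; the values are illustrative, not recommendations:

fp_hps = {
    'step_size': 0.2,         # initial Adam step size
    'decay_steps': 1000,      # steps per decay period
    'decay_factor': 0.99,     # multiplicative decay per period
    'adam_b1': 0.9,
    'adam_b2': 0.999,
    'adam_eps': 1e-5,
    'num_batches': 10000,     # total optimization steps
    'opt_print_every': 200,   # report every this many steps
    'fp_opt_stop_tol': 1e-6,  # stop once the mean fixed-point loss drops below this
}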
Example No. 16
def optimize_lfads(init_params, lfads_hps, lfads_opt_hps, train_data,
                   eval_data):
    """Optimize the LFADS model and print batch based optimization data.

  Arguments:
    init_params: a dict of parameters to be trained
    lfads_hps: dict of lfads model HPs
    lfads_opt_hps: dict of optimization HPs
    train_data: nexamples x time x ndims np array of data for training
    eval_data: nexamples x time x ndims np array of data for evaluation

  Returns:
    a 2-tuple of (dict of trained parameters, dict of optimization details)"""

    batch_size = lfads_hps['batch_size']
    num_batches = lfads_opt_hps['num_batches']
    print_every = lfads_opt_hps['print_every']

    # Build some functions used in optimization.
    kl_warmup_fun = get_kl_warmup_fun(lfads_opt_hps)
    decay_fun = optimizers.exponential_decay(lfads_opt_hps['step_size'],
                                             lfads_opt_hps['decay_steps'],
                                             lfads_opt_hps['decay_factor'])
    opt_init, opt_update = optimizers.adam(step_size=decay_fun,
                                           b1=lfads_opt_hps['adam_b1'],
                                           b2=lfads_opt_hps['adam_b2'],
                                           eps=lfads_opt_hps['adam_eps'])
    update_w_gc = get_update_w_gc_fun(init_params, opt_update)
    update_w_gc_jit = jit(update_w_gc, static_argnums=(2, 3))

    # Begin optimization loop.
    all_tlosses = []
    all_elosses = []
    start_time = time.time()
    opt_state = opt_init(init_params)
    for bidx in range(num_batches):
        kl_warmup = kl_warmup_fun(bidx)
        didxs = onp.random.randint(0, train_data.shape[0], batch_size)
        x_bxt = train_data[didxs].astype(onp.float32)
        key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
        opt_state = update_w_gc_jit(bidx, opt_state, lfads_hps, lfads_opt_hps,
                                    key, x_bxt, kl_warmup)

        if bidx % print_every == 0:
            params = optimizers.get_params(opt_state)

            # Training loss
            didxs = onp.random.randint(0, train_data.shape[0], batch_size)
            x_bxt = train_data[didxs].astype(onp.float32)
            key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
            tlosses = lfads.lfads_losses_jit(params, lfads_hps, key, x_bxt,
                                             kl_warmup, 1.0)

            # Evaluation loss
            key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
            didxs = onp.random.randint(0, eval_data.shape[0], batch_size)
            ex_bxt = eval_data[didxs].astype(onp.float32)
            # lfads_eval_losses_jit was commented out here because it was freezing.
            elosses = lfads.lfads_losses_jit(params, lfads_hps, key, ex_bxt,
                                             kl_warmup, 1.0)
            # Saving, printing.
            all_tlosses.append(tlosses)
            all_elosses.append(elosses)
            batch_time = time.time() - start_time
            s = "Batch {} in {:0.2f} sec, Step size: {:0.5f}, \
              Training loss {:0.0f}, Eval loss {:0.0f}"

            print(
                s.format(bidx, batch_time, decay_fun(bidx), tlosses['total'],
                         elosses['total']))
            start_time = time.time()

            tlosses_thru_training = utils.merge_losses_dicts(all_tlosses)
            elosses_thru_training = utils.merge_losses_dicts(all_elosses)
            optimizer_details = {
                'tlosses': tlosses_thru_training,
                'elosses': elosses_thru_training
            }
    return optimizers.get_params(opt_state), optimizer_details
Example No. 17
# Plot a few input/target examples to make sure things look sane.
do_plot = False
if do_plot:
    ntoplot = 10
    key, subkey = random.split(key, 2)
    skeys = random.split(subkey, ntoplot)
    inputs, targets = integrator.build_inputs_and_targets_jit(
        input_params, skeys)
    plot_batch(ntimesteps, inputs, targets)

### TRAINING

# Init some parameters for training.
key, subkey = random.split(key, 2)
init_params = rnn.random_vrnn_params(subkey, u, n, o, g=param_scale)
decay_fun = optimizers.exponential_decay(step_size, decay_steps, decay_factor)
opt_init, opt_update = optimizers.adam(decay_fun, adam_b1, adam_b2, adam_eps)
opt_state = opt_init(init_params)
# Run the optimization loop; the first jit'd calls will take a minute.
start_time = time.time()
for batch in range(num_batchs):
    key, subkey = random.split(key, 2)
    skeys = random.split(subkey, batch_size)
    inputs, targets = integrator.build_inputs_and_targets_jit(
        input_params, skeys)
    opt_state = rnn.update_w_gc_jit(batch, opt_state, opt_update, inputs,
                                    targets, max_grad_norm, l2reg)
    if batch % print_every == 0:
        params = optimizers.get_params(opt_state)
        train_loss = rnn.loss_jit(params, inputs, targets, l2reg)
        batch_time = time.time() - start_time
Example No. 18
def optimize_lfads(key, init_params, lfads_hps, lfads_opt_hps, train_data,
                   eval_data):
    """Optimize the LFADS model and print batch based optimization data.

  This loop is at the cpu nonjax-numpy level.

  Arguments:
    init_params: a dict of parameters to be trained
    lfads_hps: dict of lfads model HPs
    lfads_opt_hps: dict of optimization HPs
    train_data: nexamples x time x ndims np array of data for training
    eval_data: nexamples x time x ndims np array of data for evaluation

  Returns:
    a 2-tuple of (dict of trained parameters, dict of optimization details)"""

    # Begin optimization loop.
    all_tlosses = []
    all_elosses = []

    # Build some functions used in optimization.
    kl_warmup_fun = get_kl_warmup_fun(lfads_opt_hps)
    decay_fun = optimizers.exponential_decay(lfads_opt_hps['step_size'],
                                             lfads_opt_hps['decay_steps'],
                                             lfads_opt_hps['decay_factor'])

    opt_init, opt_update, get_params = optimizers.adam(
        step_size=decay_fun,
        b1=lfads_opt_hps['adam_b1'],
        b2=lfads_opt_hps['adam_b2'],
        eps=lfads_opt_hps['adam_eps'])
    opt_state = opt_init(init_params)

    def update_w_gc(i, opt_state, lfads_hps, lfads_opt_hps, key, x_bxt,
                    kl_warmup):
        """Update fun for gradients, includes gradient clipping."""
        params = get_params(opt_state)
        grads = grad(lfads.lfads_training_loss)(params, lfads_hps, key, x_bxt,
                                                kl_warmup,
                                                lfads_opt_hps['keep_rate'])
        clipped_grads = optimizers.clip_grads(grads,
                                              lfads_opt_hps['max_grad_norm'])
        return opt_update(i, clipped_grads, opt_state)

    # Run the optimization, pausing every so often to collect data and
    # print status.
    batch_size = lfads_hps['batch_size']
    num_batches = lfads_opt_hps['num_batches']
    print_every = lfads_opt_hps['print_every']
    num_opt_loops = int(num_batches / print_every)
    params = get_params(opt_state)
    for oidx in range(num_opt_loops):
        batch_idx_start = oidx * print_every
        start_time = time.time()
        key, tkey, dtkey, dekey = random.split(random.fold_in(key, oidx), 4)
        opt_state = optimize_lfads_core_jit(tkey, batch_idx_start, print_every,
                                            update_w_gc, kl_warmup_fun,
                                            opt_state, lfads_hps,
                                            lfads_opt_hps, train_data)
        batch_time = time.time() - start_time

        # Losses
        params = get_params(opt_state)
        batch_pidx = batch_idx_start + print_every
        kl_warmup = kl_warmup_fun(batch_idx_start)
        # Training loss
        didxs = onp.random.randint(0, train_data.shape[0], batch_size)
        x_bxt = train_data[didxs].astype(onp.float32)
        tlosses = lfads.lfads_losses_jit(params, lfads_hps, dtkey, x_bxt,
                                         kl_warmup, 1.0)

        # Evaluation loss
        didxs = onp.random.randint(0, eval_data.shape[0], batch_size)
        ex_bxt = eval_data[didxs].astype(onp.float32)
        elosses = lfads.lfads_losses_jit(params, lfads_hps, dekey, ex_bxt,
                                         kl_warmup, 1.0)
        # Saving, printing.
        all_tlosses.append(tlosses)
        all_elosses.append(elosses)
        s = "Batches {}-{} in {:0.2f} sec, Step size: {:0.5f}, Training loss {:0.0f}, Eval loss {:0.0f}"
        print(
            s.format(batch_idx_start + 1, batch_pidx, batch_time,
                     decay_fun(batch_pidx), tlosses['total'],
                     elosses['total']))

        tlosses_thru_training = utils.merge_losses_dicts(all_tlosses)
        elosses_thru_training = utils.merge_losses_dicts(all_elosses)
        optimizer_details = {
            'tlosses': tlosses_thru_training,
            'elosses': elosses_thru_training
        }

    return params, optimizer_details
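
The gradient-clipping update shared by these LFADS examples, reduced to a reusable sketch (loss_fn, max_grad_norm, and the extra loss arguments are placeholders):

from jax import grad
from jax.experimental import optimizers

def make_clipped_update(opt_update, get_params, loss_fn, max_grad_norm):
    """One optimizer step with the gradient norm clipped, as in update_w_gc above."""
    def update(i, opt_state, *loss_args):
        params = get_params(opt_state)
        grads = grad(loss_fn)(params, *loss_args)
        clipped_grads = optimizers.clip_grads(grads, max_grad_norm)
        return opt_update(i, clipped_grads, opt_state)
    return update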