Example #1
# Assumed helpers from the surrounding project (not shown in this snippet):
# flatten and vgrad are presumably autograd's flatten and value_and_grad;
# get_num_datapoints and split_into_batches are data utilities.
def make_gradfun(run_inference, recognize, loglike, pgm_prior, data,
                 batch_size, num_samples, natgrad_scale=1., callback=None):
    flat = lambda struct: flatten(struct)[0]  # parameter structure -> flat vector
    _, unflat = flatten(pgm_prior)            # flat vector -> parameter structure
    num_datapoints = get_num_datapoints(data)
    data_batches, num_batches = split_into_batches(data, batch_size)
    get_batch = lambda i: data_batches[i % num_batches]
    saved = lambda: None  # cheap mutable namespace for stashing inference stats

    def mc_elbo(pgm_params, loglike_params, recogn_params, i):
        # Monte Carlo ELBO estimate on minibatch i: the likelihood and local KL
        # terms are scaled by num_batches to unbiasedly estimate the full-data
        # sums, and the result is normalized per data point.
        nn_potentials = recognize(recogn_params, get_batch(i))
        samples, saved.stats, global_kl, local_kl = \
            run_inference(pgm_prior, pgm_params, nn_potentials, num_samples)
        return (num_batches * loglike(loglike_params, samples, get_batch(i))
                - global_kl - num_batches * local_kl) / num_datapoints

    def gradfun(params, i):
        pgm_params, loglike_params, recogn_params = params
        def objective(lik_recogn_params):
            loglike_params, recogn_params = lik_recogn_params
            return -mc_elbo(pgm_params, loglike_params, recogn_params, i)
        val, (loglike_grad, recogn_grad) = vgrad(objective)((loglike_params, recogn_params))
        # this expression for pgm_natgrad drops a term that can be computed using
        # the function autograd.misc.fixed_points.fixed_point
        pgm_natgrad = -natgrad_scale / num_datapoints * \
            (flat(pgm_prior) + num_batches*flat(saved.stats) - flat(pgm_params))
        grad = unflat(pgm_natgrad), loglike_grad, recogn_grad
        if callback: callback(i, val, params, grad)
        return grad

    return gradfun
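
The num_batches factors in mc_elbo are the usual unbiased minibatch scaling: a per-batch sum multiplied by the number of batches estimates the full-dataset sum. A minimal, self-contained NumPy sketch of that convention (illustrative only, not part of the snippet above):

import numpy as np

rng = np.random.RandomState(0)
per_point_terms = rng.randn(1000)      # stand-in for per-datapoint log-likelihoods
num_batches = 10
batches = np.split(per_point_terms, num_batches)

full_sum = per_point_terms.sum()
estimates = [num_batches * b.sum() for b in batches]
print(full_sum, np.mean(estimates))    # identical: the scaled batch sum is unbiased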
Example #2
# Assumed structure-aware helpers from the surrounding project: add, scale,
# div, square, sqrt, add_scalar, zeros_like, and concat operate elementwise
# on (possibly nested) parameter structures.
def adam(data, val_and_grad, callback=None):
    num_datapoints = get_num_datapoints(data)

    def adam(allparams,
             nat_stepsize,
             stepsize,
             num_epochs,
             seq_len,
             num_seqs=None,
             b1=0.9,
             b2=0.999,
             eps=1e-8,
             num_samples=1):
        natparams, params = allparams[:1], allparams[1:]
        m = zeros_like(params)
        v = zeros_like(params)
        i = 0
        accumulate = lambda rho, a, b: add(scale(1 - rho, a), scale(rho, b))  # EMA: rho*old + (1-rho)*new

        for epoch in range(num_epochs):
            vals = []
            batches, num_batches = split_into_batches(data, seq_len, num_seqs)
            for y in batches:
                val, grad = scale(
                    1. / num_datapoints,
                    val_and_grad(y, num_batches, num_samples, *allparams))
                natgrad, grad = grad[:1], grad[1:]

                m = accumulate(b1, grad, m)  # first moment estimate
                v = accumulate(b2, square(grad), v)  # second moment estimate
                mhat = scale(1. / (1 - b1**(i + 1)), m)  # bias correction
                vhat = scale(1. / (1 - b2**(i + 1)), v)
                update = scale(stepsize, div(mhat, add_scalar(eps, sqrt(vhat))))

                natparams = add(natparams, scale(nat_stepsize, natgrad))
                params = add(params, update)
                allparams = concat(natparams, params)
                vals.append(val)
                i += 1

                if callback: callback(epoch, vals, natgrad, allparams)

        return allparams

    return adam
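
The inner loop above is the standard Adam update written with the project's structure-aware helpers. A self-contained NumPy version of the same rule, minimizing a toy quadratic (all names here are illustrative; the snippet above adds its update, following the sign convention of its val_and_grad, while this demo subtracts to minimize):

import numpy as np

def adam_demo(grad, x0, stepsize=0.1, num_steps=200, b1=0.9, b2=0.999, eps=1e-8):
    x = x0.astype(float)
    m, v = np.zeros_like(x), np.zeros_like(x)
    for i in range(num_steps):
        g = grad(x)
        m = b1 * m + (1 - b1) * g         # first moment estimate
        v = b2 * v + (1 - b2) * g**2      # second moment estimate
        mhat = m / (1 - b1**(i + 1))      # bias correction
        vhat = v / (1 - b2**(i + 1))
        x = x - stepsize * mhat / (eps + np.sqrt(vhat))
    return x

print(adam_demo(lambda x: 2 * (x - 3.), np.zeros(2)))  # converges toward [3., 3.]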
Example #3
def adadelta(data, val_and_grad, callback=None):
    num_datapoints = get_num_datapoints(data)

    def adadelta(allparams,
                 nat_stepsize,
                 num_epochs,
                 seq_len,
                 num_seqs=None,
                 rho=0.95,
                 epsilon=1e-6,
                 num_samples=1,
                 permute=True):
        natparams, params = allparams[:1], allparams[1:]
        sum_gsq = zeros_like(params)  # accumulated sq. grads
        sum_usq = zeros_like(params)  # accumulated sq. updates
        accumulate = lambda a, b: add(scale(rho, a), scale(1 - rho, b))

        for epoch in range(num_epochs):
            vals = []
            batches, num_batches = split_into_batches(data, seq_len, num_seqs)
            for y in batches:
                val, grad = scale(
                    1. / num_datapoints,
                    val_and_grad(y, num_batches, num_samples, *allparams))
                natgrad, grad = grad[:1], grad[1:]
                sum_gsq = accumulate(sum_gsq, square(grad))
                diag_scaling = div(sqrt(add_scalar(epsilon, sum_usq)),
                                   sqrt(add_scalar(epsilon, sum_gsq)))
                update = mul(diag_scaling, grad)
                sum_usq = accumulate(sum_usq, square(update))

                natparams = add(natparams, scale(nat_stepsize, natgrad))
                params = add(params, update)
                allparams = concat(natparams, params)
                vals.append(val)

                if callback: callback(epoch, vals, natgrad, allparams)
        return allparams

    return adadelta
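
Adadelta needs no explicit learning rate: diag_scaling above is the ratio of the RMS of recent updates to the RMS of recent gradients. A self-contained NumPy sketch of the same rule on a toy quadratic (illustrative names; minimizing, whereas the snippet adds its update):

import numpy as np

def adadelta_demo(grad, x0, num_steps=500, rho=0.95, epsilon=1e-6):
    x = x0.astype(float)
    sum_gsq = np.zeros_like(x)   # accumulated sq. grads
    sum_usq = np.zeros_like(x)   # accumulated sq. updates
    for _ in range(num_steps):
        g = grad(x)
        sum_gsq = rho * sum_gsq + (1 - rho) * g**2
        update = np.sqrt(epsilon + sum_usq) / np.sqrt(epsilon + sum_gsq) * g
        sum_usq = rho * sum_usq + (1 - rho) * update**2
        x = x - update
    return x

print(adadelta_demo(lambda x: 2 * (x - 3.), np.zeros(2)))  # moves from 0 toward the minimizer at 3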
Example #4
File: svae.py  Project: lfywork/svae
def make_gradfun(run_inference,
                 recognize,
                 loglike,
                 pgm_prior,
                 pgm_expectedstats,
                 data,
                 batch_size,
                 num_samples,
                 natgrad_scale=1.,
                 callback=None):
    flat = lambda struct: flatten(struct)[0]  # parameter structure -> flat vector
    _, unflat = flatten(pgm_prior)            # flat vector -> parameter structure
    num_datapoints = get_num_datapoints(data)
    data_batches, num_batches = split_into_batches(data, batch_size)
    get_batch = lambda i: data_batches[i % num_batches]
    saved = lambda: None  # cheap mutable namespace for stashing inference stats

    def mc_elbo(pgm_params, pgm_stats, loglike_params, recogn_params, i):
        nn_potentials = recognize(recogn_params, get_batch(i))
        samples, saved.stats, global_kl, local_kl = \
            run_inference(pgm_prior, pgm_params, pgm_stats, nn_potentials, num_samples)
        return (num_batches * loglike(loglike_params, samples, get_batch(i)) -
                global_kl - num_batches * local_kl) / num_datapoints

    def gradfun(params, i):
        pgm_params, loglike_params, recogn_params = params
        def objective(stats_lik_recogn):
            pgm_stats, loglike_params, recogn_params = stats_lik_recogn
            return -mc_elbo(pgm_params, pgm_stats, loglike_params, recogn_params, i)
        pgm_stats = pgm_expectedstats(pgm_params)
        val, (pgm_stats_grad, loglike_grad, recogn_grad) = vgrad(objective)(
            (pgm_stats, loglike_params, recogn_params))
        # unlike Example #1, the flat(pgm_stats_grad) term accounts for
        # differentiating through pgm_expectedstats (cf. the term dropped there)
        pgm_natgrad = -natgrad_scale / num_datapoints * \
            (flat(pgm_prior) + num_batches*(flat(saved.stats) + flat(pgm_stats_grad)) - flat(pgm_params))
        grad = unflat(pgm_natgrad), loglike_grad, recogn_grad
        if callback: callback(i, val, params, grad)
        return grad

    return gradfun
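
A hypothetical driver loop for gradfun (init_params, stepsize, and num_iters are illustrative assumptions, not part of the snippet). The objective inside gradfun is the negative ELBO, so a plain SGD-style step subtracts the returned gradient:

params = init_params
for i in range(num_iters):
    grad = gradfun(params, i)
    flat_params, unflatten = flatten(params)   # same flatten as in the snippet
    flat_grad, _ = flatten(grad)
    params = unflatten(flat_params - stepsize * flat_grad)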