Example #1
    def __call__(self, x, y, stride=10, **model_args):

        model = self.model

        inference_outs, _, updates = self.inference(x, y)
        i_costs = inference_outs['i_costs']

        qss = inference_outs['qss']

        if self.n_inference_steps > stride and stride != 0:
            steps = [0, 1] + range(stride, self.n_inference_steps, stride)
            steps = steps[:-1] + [self.n_inference_steps - 1]
        elif self.n_inference_steps > 0:
            steps = [0, self.n_inference_steps - 1]
        else:
            steps = [0]
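        # Illustrative schedule: with n_inference_steps=50 and stride=10 this
        # gives steps = [0, 1, 10, 20, 30, 49] -- the first two inference
        # steps, every stride-th step, and always the final step.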

        full_results = OrderedDict()
        full_results['i_cost'] = []
        samples = OrderedDict()
        for i in steps:
            qks = [qs[i] for qs in qss]
            results_k, samples_k, _ = model(x, y, qks, **model_args)
            samples_k['qs'] = qks
            update_dict_of_lists(full_results, **results_k)
            full_results['i_cost'].append(i_costs[i])
            update_dict_of_lists(samples, **samples_k)

        results = OrderedDict()
        for k, v in full_results.iteritems():
            results[k] = v[-1]
            results[k + '0'] = v[0]
            results['d_' + k] = v[0] - v[-1]

        return results, samples, full_results, updates
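
These examples lean on a small helper, update_dict_of_lists, which is assumed to append each keyword value to a per-key list (the results are later reduced with v[0], v[-1], or np.mean(v)); a minimal sketch of that behavior:

def update_dict_of_lists(d, **kwargs):
    # Append each value to the list stored under its key, creating the list
    # on first use.
    for k, v in kwargs.iteritems():
        if k in d:
            d[k].append(v)
        else:
            d[k] = [v]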
Example #2
def train(
    out_path='', name='', model_to_load=None, save_images=True,
    dim_h=None, dim_hs=None, center_input=True, prior='binomial',
    recognition_net=None, generation_net=None,

    learning_args=dict(),
    inference_args=dict(),
    inference_args_test=dict(),
    dataset_args=None):

    if dim_h is None:
        assert dim_hs is not None
        deep = True
    else:
        assert dim_hs is None
        deep = False

    # ========================================================================
    learning_args = init_learning_args(**learning_args)
    inference_args = init_inference_args(**inference_args)
    inference_args_test = init_inference_args(**inference_args_test)

    print 'Dataset args: %s' % pprint.pformat(dataset_args)
    print 'Learning args: %s' % pprint.pformat(learning_args)
    print 'Inference args: %s' % pprint.pformat(inference_args)
    print 'Inference args (test): %s' % pprint.pformat(inference_args_test)

    # ========================================================================
    print_section('Setting up data')
    train, valid, test = load_data(
        train_batch_size=learning_args['batch_size'],
        valid_batch_size=learning_args['valid_batch_size'],
        **dataset_args)

    # ========================================================================
    print_section('Setting model and variables')
    dim_in = train.dims[train.name]
    batch_size = learning_args['batch_size']

    X = T.matrix('x', dtype=floatX)
    X.tag.test_value = np.zeros((batch_size, dim_in), dtype=X.dtype)
    trng = get_trng()

    if center_input:
        print 'Centering input with train dataset mean image'
        X_mean = theano.shared(train.mean_image.astype(floatX), name='X_mean')
        X_i = X - X_mean
    else:
        X_i = X

    # ========================================================================
    print_section('Loading model and forming graph')

    if prior == 'gaussian':
        if deep:
            raise NotImplementedError()
        C = GBN
        PC = Gaussian
        unpack = unpack_gbn
        model_name = 'gbn'
    elif prior == 'binomial':
        if deep:
            C = DeepSBN
            unpack = unpack_deepsbn
        else:
            C = SBN
            unpack = unpack_sbn
        PC = Binomial
        model_name = 'sbn'
    elif prior == 'darn':
        if deep:
            raise NotImplementedError()
        C = SBN
        PC = AutoRegressor
        unpack = unpack_sbn
        model_name = 'sbn'
    else:
        raise ValueError(prior)

    if model_to_load is not None:
        models, _ = load_model(model_to_load, unpack,
                               distributions=train.distributions, dims=train.dims)
    else:
        if deep:
            model = C(dim_in, dim_hs, trng=trng)
        else:
            prior_model = PC(dim_h)
            mlps = C.mlp_factory(
                dim_h, train.dims, train.distributions,
                recognition_net=recognition_net,
                generation_net=generation_net)

            model = C(dim_in, dim_h, trng=trng, prior=prior_model, **mlps)

        models = OrderedDict()
        models[model.name] = model

    model = models[model_name]
    tparams = model.set_tparams(excludes=[])
    print_profile(tparams)

    # ==========================================================================
    print_section('Getting cost')

    inference_method = inference_args['inference_method']

    if inference_method is not None:
        inference = resolve_inference(model, deep=deep, **inference_args)
    else:
        inference = None

    if inference_method == 'momentum':
        if prior == 'binomial':
            raise NotImplementedError()
        i_results, constants, updates = inference.inference(X_i, X)
        qk = i_results['qk']
        results, samples, constants_m = model(
            X_i, X, qk, pass_gradients=inference_args['pass_gradients'],
            n_posterior_samples=learning_args['n_posterior_samples'])
        constants += constants_m
    elif inference_method == 'rws':
        results, _, constants = inference(
            X_i, X, n_posterior_samples=learning_args['n_posterior_samples'])
        updates = theano.OrderedUpdates()
    elif inference_method == 'air':
        if prior == 'gaussian':
            raise NotImplementedError()
        i_results, constants, updates = inference.inference(X_i, X)
        qk = i_results['qk']
        results, _, _ = model(
            X_i, X, qk, n_posterior_samples=learning_args['n_posterior_samples'])
    elif inference_method is None:
        if prior != 'gaussian':
            raise NotImplementedError()
        qk = None
        constants = []
        updates = theano.OrderedUpdates()
        results, samples, constants_m = model(
            X_i, X, qk, pass_gradients=inference_args['pass_gradients'],
            n_posterior_samples=learning_args['n_posterior_samples'])
        constants += constants_m
    else:
        raise ValueError(inference_method)

    cost = results.pop('cost')
    extra_outs = []
    extra_outs_keys = ['cost']

    l2_decay = learning_args['l2_decay']
    if l2_decay > 0.:
        print 'Adding %.5f L2 weight decay' % l2_decay
        l2_rval = model.l2_decay(l2_decay)
        cost += l2_rval.pop('cost')
        extra_outs += l2_rval.values()
        extra_outs_keys += l2_rval.keys()
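    # model.l2_decay(gamma) above is assumed to return a dict whose 'cost'
    # entry is gamma times the sum of squared weights, plus per-network
    # terms that are appended to the monitored outputs.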

    # ==========================================================================
    print_section('Test functions')
    # Test function with sampling
    inference_method_test = inference_args_test['inference_method']
    if inference_method_test is not None:
        inference = resolve_inference(model, deep=deep, **inference_args_test)
    else:
        inference = None

    if inference_method_test == 'momentum':
        if prior == 'binomial':
            raise NotImplementedError()
        results, samples, full_results, updates_s = inference(
            X_i, X,
            n_posterior_samples=learning_args['n_posterior_samples_test'])
        py = samples['py'][-1]
    elif inference_method_test == 'rws':
        results, samples, _ = inference(
            X_i, X, n_posterior_samples=learning_args['n_posterior_samples_test'])
        updates_s = theano.OrderedUpdates()
        py = samples['py']
    elif inference_method_test == 'air':
        results, samples, full_results, updates_s = inference(
            X_i, X, n_posterior_samples=learning_args['n_posterior_samples_test'])
        py = samples['py'][-1]
    elif inference_method_test is None:
        # Reuse `results` and `samples` from the training graph built above.
        updates_s = theano.OrderedUpdates()
        py = samples['py']
    else:
        raise ValueError(inference_method_test)

    f_test_keys = results.keys()
    f_test = theano.function([X], results.values(), updates=updates_s)
    if inference_method_test in ('momentum', 'air'):
        # `full_results` is only produced by the iterative inference methods.
        f_icost = theano.function([X], full_results['i_cost'], updates=updates_s)

    # ========================================================================
    print_section('Setting final tparams and save function')

    all_params = OrderedDict((k, v) for k, v in tparams.iteritems())

    # Keep only the parameters that are not already handled by inference updates.
    tparams = OrderedDict((k, v)
        for k, v in tparams.iteritems()
        if v not in updates.keys())

    print 'Learned model params: %s' % tparams.keys()
    print 'Saved params: %s' % all_params.keys()

    def save(tparams, outfile):
        d = dict((k, v.get_value()) for k, v in all_params.items())
        d.update(
            dim_in=dim_in,
            dim_h=dim_h,
            dim_hs=dim_hs,
            prior=prior,
            center_input=center_input,
            generation_net=generation_net,
            recognition_net=recognition_net,
            dataset_args=dataset_args
        )
        np.savez(outfile, **d)

    # ========================================================================
    print_section('Getting gradients.')
    grads = T.grad(cost, wrt=itemlist(tparams),
                   consider_constant=constants)

    # ========================================================================
    print_section('Building optimizer')
    lr = T.scalar(name='lr')
    optimizer = learning_args['optimizer']
    optimizer_args = learning_args['optimizer_args']
    f_grad_shared, f_grad_updates = getattr(op, optimizer)(
        lr, tparams, grads, [X], cost, extra_ups=updates,
        extra_outs=extra_outs, **optimizer_args)

    monitor = SimpleMonitor()

    # ========================================================================
    print_section('Actually running (main loop)')

    best_cost = float('inf')
    best_epoch = 0

    if out_path is not None:
        bestfile = path.join(out_path, '{name}_best.npz'.format(name=name))

    epochs = learning_args['epochs']
    learning_rate = learning_args['learning_rate']
    learning_rate_schedule = learning_args['learning_rate_schedule']
    valid_key = learning_args['valid_key']
    valid_sign = learning_args['valid_sign']
    try:
        epoch_t0 = time.time()
        s = 0
        e = 0

        widgets = ['Epoch {epoch} (training {name}, '.format(epoch=e, name=name),
                   Timer(), '): ', Bar()]
        epoch_pbar = ProgressBar(widgets=widgets, maxval=train.n).start()
        training_time = 0
        while True:
            try:
                x = train.next()[train.name]
                if train.pos == -1:
                    epoch_pbar.update(train.n)
                else:
                    epoch_pbar.update(train.pos)

            except StopIteration:
                print
                epoch_t1 = time.time()
                training_time += (epoch_t1 - epoch_t0)
                valid.reset()
                maxvalid = valid.n

                widgets = ['Validating: (%d posterior samples) '
                           % learning_args['n_posterior_samples_test'],
                           Percentage(), ' (', Timer(), ')']
                pbar    = ProgressBar(widgets=widgets, maxval=maxvalid).start()
                results_train = OrderedDict()
                results_valid = OrderedDict()
                while True:
                    try:
                        x_valid = valid.next()[train.name]
                        if valid.pos > maxvalid:
                            raise StopIteration
                        x_train = train.next()[train.name]
                        #print f_icost(x_valid)
                        r_train = f_test(x_train)
                        r_valid = f_test(x_valid)
                        results_i_train = dict((k, v) for k, v in zip(f_test_keys, r_train))
                        results_i_valid = dict((k, v) for k, v in zip(f_test_keys, r_valid))
                        update_dict_of_lists(results_train, **results_i_train)
                        update_dict_of_lists(results_valid, **results_i_valid)

                        if valid.pos == -1:
                            pbar.update(maxvalid)
                        else:
                            pbar.update(valid.pos)

                    except StopIteration:
                        print
                        break

                def summarize(d):
                    for k, v in d.iteritems():
                        d[k] = np.mean(v)

                summarize(results_train)
                summarize(results_valid)
                valid_value = results_valid[valid_key]
                if valid_sign == '-':
                    valid_value *= -1

                if valid_value < best_cost:
                    print 'Found best %s: %.2f' % (valid_key, valid_value)
                    best_cost = valid_value
                    best_epoch = e
                    if out_path is not None:
                        print 'Saving best to %s' % bestfile
                        save(tparams, bestfile)
                else:
                    print 'Best (%.2f) at epoch %d' % (best_cost, best_epoch)

                monitor.update(**results_train)
                monitor.update(dt_epoch=(epoch_t1-epoch_t0),
                               training_time=training_time)
                monitor.update_valid(**results_valid)
                monitor.display()

                monitor.save(path.join(
                    out_path, '{name}_monitor.png').format(name=name))
                monitor.save_stats(path.join(
                    out_path, '{name}_monitor.npz').format(name=name))
                monitor.save_stats_valid(path.join(
                    out_path, '{name}_monitor_valid.npz').format(name=name))

                e += 1
                epoch_t0 = time.time()

                valid.reset()
                train.reset()

                if learning_rate_schedule is not None:
                    if 'decay' in learning_rate_schedule.keys():
                        learning_rate /= learning_rate_schedule['decay']
                        print 'Changing learning rate to %.5f' % learning_rate
                    elif e in learning_rate_schedule.keys():
                        lr = learning_rate_schedule[e]
                        print 'Changing learning rate to %.5f' % lr
                        learning_rate = lr

                widgets = ['Epoch {epoch} ({name}, '.format(epoch=e, name=name),
                           Timer(), '): ', Bar()]
                epoch_pbar = ProgressBar(widgets=widgets, maxval=train.n).start()

                continue

            if e > epochs:
                break

            rval = f_grad_shared(x)
            # Warn on any bad (NaN/Inf) outputs, but only abort on a bad cost.
            check_bad_nums(rval, extra_outs_keys)
            if check_bad_nums(rval[:1], extra_outs_keys[:1]):
                print zip(extra_outs_keys, rval)
                print 'Dying, found bad cost... Sorry (bleh)'
                exit()
            f_grad_updates(learning_rate)
            s += 1

    except KeyboardInterrupt:
        print 'Training interrupted'

    if out_path is not None:
        outfile = path.join(out_path, '{name}_{t}.npz'.format(name=name, t=int(time.time())))
        last_outfile = path.join(out_path, '{name}_last.npz'.format(name=name))

        print 'Saving'
        save(tparams, outfile)
        save(tparams, last_outfile)
        print 'Done saving.'

    print 'Bye bye!'
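
A hypothetical invocation of this script; all argument values are illustrative, and init_learning_args / init_inference_args are assumed to fill in any remaining defaults:

train(
    out_path='outs/sbn_mnist', name='sbn_mnist',
    dim_h=200, prior='binomial',
    learning_args=dict(batch_size=100, valid_batch_size=100, epochs=100),
    inference_args=dict(inference_method='air'),
    inference_args_test=dict(inference_method='air'),
    dataset_args=dict(dataset='mnist'))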
Example #3
def compare(model_dirs,
            out_path,
            name=None,
            by_training_time=False,
            omit_deltas=True,
            **test_args):

    model_results = OrderedDict()
    valid_results = OrderedDict()

    for model_dir in model_dirs:
        n = model_dir.split('/')[-1]
        if n == '':
            n = model_dir.split('/')[-2]
        result_file = path.join(model_dir, '{name}_monitor.npz'.format(name=n))
        params = np.load(result_file)
        d = dict(params)
        update_dict_of_lists(model_results, **d)

        valid_file = path.join(model_dir,
                               '{name}_monitor_valid.npz'.format(name=n))
        params_valid = np.load(valid_file)
        d_valid = dict(params_valid)
        update_dict_of_lists(valid_results, **d_valid)
        update_dict_of_lists(model_results, name=n)

    if omit_deltas:
        model_results = OrderedDict((k, v)
                                    for k, v in model_results.iteritems()
                                    if not k.startswith('d_'))

    model_results.pop('dt_epoch')
    names = model_results.pop('name')
    training_times = model_results.pop('training_time')

    if name is None:
        name = '.'.join(names)

    out_dir = path.join(out_path, 'compare.' + name)
    if path.isfile(out_dir):
        raise ValueError('%s exists and is a file, not a directory' % out_dir)
    elif not path.isdir(out_dir):
        os.mkdir(path.abspath(out_dir))

    plt.clf()
    x = 3
    y = ((len(model_results) - 1) // x) + 1
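    # Three columns, so y = ceil(len(model_results) / 3) rows; e.g. 8 stats
    # give a 3x3 grid with one unused axis.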

    fig, axes = plt.subplots(y, x)
    fig.set_size_inches(15, 10)

    if by_training_time:
        xlabel = 'seconds'
        us = [tt - tt[0] for tt in training_times]
    else:
        us = [range(tt.shape[0]) for tt in training_times]
        xlabel = 'epochs'

    for j, (k, vs) in enumerate(model_results.iteritems()):
        ax = axes[j // x, j % x]
        for n, u, v in zip(names, us, vs):
            ax.plot(u[10:], v[10:], label=n)

        if k in valid_results.keys():
            for n, u, v in zip(names, us, valid_results[k]):
                ax.plot(u[10:], v[10:], label=n + '(v)')

        ax.set_ylabel(k)
        ax.set_xlabel(xlabel)
        ax.legend()
        ax.patch.set_alpha(0.5)

    plt.tight_layout()
    plt.savefig(path.join(out_dir, 'results.png'))
    plt.close()

    print 'Sampling from priors'

    results = OrderedDict()
    hps = OrderedDict()
    for model_dir in model_dirs:
        models, data_iter, name, exp_dict, mean_image, deep, inference_method = unpack_model_and_data(
            model_dir)
        sample_from_prior(models, data_iter, name, out_dir)
        rs = test(models,
                  data_iter,
                  name,
                  mean_image,
                  deep=deep,
                  inference_method=inference_method,
                  **test_args)
        update_dict_of_lists(results, **rs)
        update_dict_of_lists(hps, **exp_dict)

    columns = ['Stat'] + names
    data = [[k] + v for k, v in hps.iteritems()]
    data += [[k] + v for k, v in results.iteritems() if not k.startswith('d_')]

    with open(path.join(out_dir, 'summary.txt'), 'w+') as f:
        print >> f, tabulate(data, headers=columns)

    print tabulate(data, headers=columns)
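
A hypothetical call comparing two trained model directories; the paths and the extra test argument (forwarded to test() via **test_args) are illustrative:

compare(['outs/sbn_a', 'outs/sbn_b'], 'outs',
        by_training_time=True,
        n_posterior_samples=1000)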
Example #4
def test(models,
         data_iter,
         name,
         mean_image,
         deep=False,
         data_samples=10000,
         n_posterior_samples=1000,
         inference_args=None,
         inference_method=None,
         dx=100,
         calculate_true_likelihood=False,
         center_input=True,
         **extra_kwargs):

    model = models['main']
    tparams = model.set_tparams()
    data_iter.reset()

    X = T.matrix('x', dtype=floatX)

    if center_input:
        print 'Centering input with train dataset mean image'
        X_mean = theano.shared(mean_image, name='X_mean')
        X_i = X - X_mean
    else:
        X_i = X.copy()

    inference = resolve_inference(model,
                                  deep=deep,
                                  inference_method=inference_method,
                                  **inference_args)

    if inference_method == 'momentum':
        # `prior` is not passed to test(); assume the model exposes its prior
        # distribution (as built in the training scripts above).
        if isinstance(model.prior, Binomial):
            raise NotImplementedError()
        results, samples, full_results, updates = inference(
            X_i, X, n_posterior_samples=n_posterior_samples)
    elif inference_method == 'air':
        results, samples, full_results, updates = inference(
            X_i, X, n_posterior_samples=n_posterior_samples)
    else:
        raise ValueError(inference_method)

    f_test_keys = results.keys()
    f_test = theano.function([X], results.values(), updates=updates)
    widgets = ['Testing %s:' % name, Timer(), Bar()]
    pbar = ProgressBar(widgets=widgets, maxval=data_iter.n).start()
    rs = OrderedDict()
    while True:
        try:
            y = data_iter.next(batch_size=dx)[data_iter.name]
        except StopIteration:
            break
        r = f_test(y)
        rs_i = dict((k, v) for k, v in zip(f_test_keys, r))
        update_dict_of_lists(rs, **rs_i)

        if data_iter.pos == -1:
            pbar.update(data_iter.n)
        else:
            pbar.update(data_iter.pos)
    print

    def summarize(d):
        for k, v in d.iteritems():
            d[k] = np.mean(v)

    summarize(rs)

    return rs
Example #5
def train_model(
    out_path='', name='', load_last=False, model_to_load=None, save_images=True,

    learning_rate=0.0001, optimizer='rmsprop', optimizer_args=dict(),
    learning_rate_schedule=None,
    batch_size=100, valid_batch_size=100, test_batch_size=1000,
    max_valid=10000,
    epochs=100,

    dim_h=300, prior='logistic', pass_gradients=False,
    l2_decay=0.,

    input_mode=None,
    generation_net=None, recognition_net=None,
    excludes=['gaussian.log_sigma'],
    center_input=True,

    z_init=None,
    inference_method='momentum',
    inference_rate=.01,
    n_inference_steps=20,
    n_inference_steps_test=20,
    n_inference_samples=20,
    n_inference_samples_test=100,
    extra_inference_args=dict(),

    n_mcmc_samples=20,
    n_mcmc_samples_test=20,

    dataset=None, dataset_args=None,
    model_save_freq=1000, show_freq=100
    ):

    kwargs = dict(
        z_init=z_init,
        inference_method=inference_method,
        inference_rate=inference_rate,
        extra_inference_args=extra_inference_args
    )

    # ========================================================================
    print 'Dataset args: %s' % pprint.pformat(dataset_args)
    print 'Model args: %s' % pprint.pformat(kwargs)

    # ========================================================================
    print 'Setting up data'
    train, valid, test = load_data(dataset,
                                   batch_size,
                                   valid_batch_size,
                                   test_batch_size,
                                   **dataset_args)

    # ========================================================================
    print 'Setting model and variables'
    dim_in = train.dims[dataset]
    X = T.matrix('x', dtype=floatX)
    X.tag.test_value = np.zeros((batch_size, dim_in), dtype=X.dtype)
    trng = RandomStreams(random.randint(0, 1000000))

    # Corrupt the input into a new variable rather than rebinding X, so that X
    # remains a valid input to the theano functions compiled below.
    if input_mode == 'sample':
        print 'Sampling datapoints'
        X_c = trng.binomial(p=X, size=X.shape, n=1, dtype=X.dtype)
    elif input_mode == 'noise':
        print 'Adding noise to data points'
        X_c = X * trng.binomial(p=0.1, size=X.shape, n=1, dtype=X.dtype)
    else:
        X_c = X

    if center_input:
        print 'Centering input with train dataset mean image'
        X_mean = theano.shared(train.mean_image.astype(floatX), name='X_mean')
        X_i = X_c - X_mean
    else:
        X_i = X_c

    # ========================================================================
    print 'Loading model and forming graph'

    if prior == 'logistic':
        out_act = 'T.nnet.sigmoid'
    elif prior == 'darn':
        out_act = 'T.nnet.sigmoid'
    elif prior == 'gaussian':
        out_act = 'lambda x: x'
    else:
        raise ValueError('%s prior not known' % prior)
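    # The recognition net's output activation matches the prior: sigmoid for
    # the binary (logistic/darn) priors, identity for the real-valued Gaussian
    # prior.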

    if recognition_net is not None:
        input_layer = recognition_net.pop('input_layer')
        recognition_net['dim_in'] = train.dims[input_layer]
        recognition_net['dim_out'] = dim_h
        recognition_net['out_act'] = out_act
    if generation_net is not None:
        generation_net['dim_in'] = dim_h
        t = generation_net.get('type', None)
        if t is None or t == 'darn':
            generation_net['dim_out'] = train.dims[generation_net['output']]
            generation_net['out_act'] = train.acts[generation_net['output']]
        elif t == 'MMMLP':
            generation_net['graph']['outs'] = dict()
            for out in generation_net['graph']['outputs']:
                generation_net['graph']['outs'][out] = dict(
                    dim=train.dims[out],
                    act=train.acts[out]
                )
        else:
            raise ValueError()

    if model_to_load is not None:
        models, _ = load_model(model_to_load, unpack, **kwargs)
    elif load_last:
        model_file = glob(path.join(out_path, '*last.npz'))[0]
        models, _ = load_model(model_file, unpack, **kwargs)
    else:
        if prior == 'logistic':
            prior_model = Bernoulli(dim_h)
        elif prior == 'darn':
            prior_model = AutoRegressor(dim_h)
        elif prior == 'gaussian':
            prior_model = Gaussian(dim_h)
        else:
            raise ValueError('%s prior not known' % prior)

        mlps = SBN.mlp_factory(recognition_net=recognition_net,
                               generation_net=generation_net)

        if prior == 'logistic' or prior == 'darn':
            C = SBN
        elif prior == 'gaussian':
            C = GBN
        else:
            raise ValueError()

        kwargs.update(**mlps)
        model = C(recognition_net['dim_in'], dim_h, trng=trng, prior=prior_model, **kwargs)

        models = OrderedDict()
        models[model.name] = model

    if prior == 'logistic' or prior == 'darn':
        model = models['sbn']
    elif prior == 'gaussian':
        model = models['gbn']

    tparams = model.set_tparams(excludes=[])
    print_profile(tparams)

    # ========================================================================
    print 'Getting cost'
    results, updates, constants = model.inference(
        X_i, X, n_inference_steps=n_inference_steps, n_samples=n_mcmc_samples,
        n_inference_samples=n_inference_samples, pass_gradients=pass_gradients)

    cost = results.pop('cost')
    extra_outs = []
    extra_outs_names = ['cost']

    if l2_decay > 0.:
        print 'Adding %.5f L2 weight decay' % l2_decay
        rec_l2_cost = model.posterior.get_L2_weight_cost(l2_decay)
        gen_l2_cost = model.conditional.get_L2_weight_cost(l2_decay)
        cost += rec_l2_cost + gen_l2_cost
        extra_outs += [rec_l2_cost, gen_l2_cost]
        extra_outs_names += ['Rec net L2 cost', 'Gen net L2 cost']
        if prior == 'darn':
            print 'Adding autoregressor weight decay'
            ar_l2_cost = model.prior.get_L2_weight_cost(l2_decay)
            cost += ar_l2_cost
            extra_outs += [ar_l2_cost]
            extra_outs_names += ['AR L2 cost']

    # ========================================================================
    print 'Extra functions'
    # Test function with sampling
    results_s, samples, full_results, updates_s = model(
        X_i, X, n_samples=n_mcmc_samples_test,
        n_inference_steps=n_inference_steps_test,
        n_inference_samples=n_inference_samples_test)

    f_test_keys = results_s.keys()
    f_test = theano.function([X], results_s.values(), updates=updates_s)

    py_s = samples['py'][-1]
    (pd_s, d_hat_s), updates_c = concatenate_inputs(model, X, py_s)
    updates_s.update(updates_c)

    f_sample = theano.function([X], [pd_s, d_hat_s], updates=updates_s)

    # Sample from prior
    py_p, updates_p = model.sample_from_prior()
    f_prior = theano.function([], py_p, updates=updates_p)

    # ========================================================================
    print 'Setting final tparams and save function'

    all_params = OrderedDict((k, v) for k, v in tparams.iteritems())

    # Keep parameters that are not handled by inference updates and not excluded.
    tparams = OrderedDict((k, v)
        for k, v in tparams.iteritems()
        if (v not in updates.keys() and k not in excludes))

    print 'Learned model params: %s' % tparams.keys()
    print 'Saved params: %s' % all_params.keys()

    def save(tparams, outfile):
        d = dict((k, v.get_value()) for k, v in all_params.items())

        d.update(
            dim_in=dim_in,
            dim_h=dim_h,
            input_mode=input_mode,
            prior=prior,
            generation_net=generation_net, recognition_net=recognition_net,
            dataset=dataset, dataset_args=dataset_args
        )
        np.savez(outfile, **d)

    # ========================================================================
    print 'Getting gradients.'
    grads = T.grad(cost, wrt=itemlist(tparams),
                   consider_constant=constants)

    # ========================================================================
    print 'Building optimizer'
    lr = T.scalar(name='lr')
    f_grad_shared, f_grad_updates = getattr(op, optimizer)(
        lr, tparams, grads, [X], cost, extra_ups=updates,
        extra_outs=extra_outs, **optimizer_args)

    monitor = SimpleMonitor()

    # ========================================================================
    print 'Actually running (main loop)'

    best_cost = float('inf')
    best_epoch = 0

    if out_path is not None:
        bestfile = path.join(out_path, '{name}_best.npz'.format(name=name))

    try:
        epoch_t0 = time.time()
        s = 0
        e = 0

        widgets = ['Epoch {epoch} ({name}, '.format(epoch=e, name=name),
                   Timer(), '): ', Bar()]
        epoch_pbar = ProgressBar(widgets=widgets, maxval=train.n).start()
        training_time = 0
        while True:
            try:
                x, _ = train.next()
                if train.pos == -1:
                    epoch_pbar.update(train.n)
                else:
                    epoch_pbar.update(train.pos)

            except StopIteration:
                print
                epoch_t1 = time.time()
                training_time += (epoch_t1 - epoch_t0)
                valid.reset()
                maxvalid = min(max_valid, valid.n)

                widgets = ['Validating: (%d posterior samples) ' % n_mcmc_samples_test,
                           Percentage(), ' (', Timer(), ')']
                pbar    = ProgressBar(widgets=widgets, maxval=maxvalid).start()
                results_train = OrderedDict()
                results_valid = OrderedDict()
                while True:
                    try:
                        x_valid, _ = valid.next()
                        if valid.pos > max_valid:
                            raise StopIteration
                        x_train, _ = train.next()

                        r_train = f_test(x_train)
                        r_valid = f_test(x_valid)
                        results_i_train = dict((k, v) for k, v in zip(f_test_keys, r_train))
                        results_i_valid = dict((k, v) for k, v in zip(f_test_keys, r_valid))
                        update_dict_of_lists(results_train, **results_i_train)
                        update_dict_of_lists(results_valid, **results_i_valid)

                        if valid.pos == -1:
                            pbar.update(maxvalid)
                        else:
                            pbar.update(valid.pos)

                    except StopIteration:
                        print
                        break

                def summarize(d):
                    for k, v in d.iteritems():
                        d[k] = np.mean(v)

                summarize(results_train)
                summarize(results_valid)
                lower_bound = results_valid['lower_bound']

                if lower_bound < best_cost:
                    print 'Found best: %.2f' % lower_bound
                    best_cost = lower_bound
                    best_epoch = e
                    if out_path is not None:
                        print 'Saving best to %s' % bestfile
                        save(tparams, bestfile)
                else:
                    print 'Best (%.2f) at epoch %d' % (best_cost, best_epoch)

                monitor.update(**results_train)
                monitor.update(dt_epoch=(epoch_t1-epoch_t0),
                               training_time=training_time)
                monitor.update_valid(**results_valid)
                monitor.display()

                monitor.save(path.join(
                    out_path, '{name}_monitor.png').format(name=name))
                monitor.save_stats(path.join(
                    out_path, '{name}_monitor.npz').format(name=name))
                monitor.save_stats_valid(path.join(
                    out_path, '{name}_monitor_valid.npz').format(name=name))

                prior_file = path.join(out_path, 'samples_from_prior.png')
                print 'Saving prior samples'
                samples = f_prior()
                train.save_images(samples[:, None], prior_file, x_limit=10)

                e += 1
                epoch_t0 = time.time()

                valid.reset()
                train.reset()

                if learning_rate_schedule is not None:
                    if 'decay' in learning_rate_schedule.keys():
                        learning_rate /= learning_rate_schedule['decay']
                        print 'Changing learning rate to %.5f' % learning_rate
                    elif e in learning_rate_schedule.keys():
                        lr = learning_rate_schedule[e]
                        print 'Changing learning rate to %.5f' % lr
                        learning_rate = lr

                widgets = ['Epoch {epoch} ({name}, '.format(epoch=e, name=name),
                           Timer(), '): ', Bar()]
                epoch_pbar = ProgressBar(widgets=widgets, maxval=train.n).start()

                continue

            if e > epochs:
                break

            rval = f_grad_shared(x)

            if check_bad_nums(rval, extra_outs_names):
                raise ValueError('Bad number!')

            if save_images and s % model_save_freq == 0:
                try:
                    x_v, _ = valid.next()
                except StopIteration:
                    valid.reset()
                    x_v, _ = valid.next()

                pd_v, d_hat_v = f_sample(x_v)
                d_hat_s = np.concatenate([pd_v[:10],
                                          d_hat_v[1][None, :, :]], axis=0)
                d_hat_s = d_hat_s[:, :min(10, d_hat_s.shape[1] - 1)]
                train.save_images(d_hat_s, path.join(
                    out_path, '{name}_samples.png'.format(name=name)))

                pd_p = f_prior()
                train.save_images(
                    pd_p[:, None], path.join(
                        out_path,
                        '{name}_samples_from_prior.png'.format(name=name)),
                    x_limit=10
                )

            f_grad_updates(learning_rate)
            s += 1

    except KeyboardInterrupt:
        print 'Training interrupted'

    if out_path is not None:
        outfile = path.join(out_path, '{name}_{t}.npz'.format(name=name, t=int(time.time())))
        last_outfile = path.join(out_path, '{name}_last.npz'.format(name=name))

        print 'Saving'
        save(tparams, outfile)
        save(tparams, last_outfile)
        print 'Done saving.'

    print 'Bye bye!'
Example #6
def classify(model_dir,
             n_inference_steps=20, n_inference_samples=20,
             dim_hs=[100], h_act='T.nnet.softplus',
             learning_rate=0.0001, learning_rate_schedule=None,
             dropout=0.1, batch_size=100, l2_decay=0.002,
             epochs=100,
             optimizer='rmsprop', optimizer_args=dict(),
             center_input=True, name='classifier'):
    out_path = model_dir

    inference_args = dict(
        inference_method='adaptive',
        inference_rate=0.1,
    )

    # ========================================================================
    print 'Loading model'

    model_file = glob(path.join(model_dir, '*best*npz'))[0]

    models, model_args = load_model(model_file, unpack_sbn, **inference_args)

    model = models['sbn']
    model.set_tparams()

    dataset = model_args['dataset']
    dataset_args = model_args['dataset_args']
    if dataset == 'mnist':
        dataset_args['binarize'] = True
        dataset_args['source'] = '/export/mialab/users/dhjelm/data/mnist.pkl.gz'

    train, valid, test = load_data(dataset, batch_size, batch_size, batch_size,
                                   **dataset_args)

    mlp_args = dict(
        dim_hs=dim_hs,
        h_act=h_act,
        dropout=dropout,
        out_act=train.acts['label']
    )

    X = T.matrix('x', dtype=floatX)
    Y = T.matrix('y', dtype=floatX)
    trng = RandomStreams(random.randint(0, 1000000))

    if center_input:
        print 'Centering input with train dataset mean image'
        X_mean = theano.shared(train.mean_image.astype(floatX), name='X_mean')
        X_i = X - X_mean
    else:
        X_i = X

    # ========================================================================
    print 'Loading MLP and forming graph'

    (qs, i_costs), _, updates = model.infer_q(
            X_i, X, n_inference_steps, n_inference_samples=n_inference_samples)

    q0 = qs[0]
    qk = qs[-1]

    constants = [q0, qk]
    dim_in = model.dim_h
    dim_out = train.dims['label']

    mlp0_args = deepcopy(mlp_args)
    mlp0 = MLP(dim_in, dim_out, name='classifier_0', **mlp0_args)
    mlpk_args = deepcopy(mlp_args)
    mlpk = MLP(dim_in, dim_out, name='classifier_k', **mlpk_args)
    mlpx_args = deepcopy(mlp_args)
    mlpx = MLP(train.dims[str(dataset)], dim_out, name='classifier_x', **mlpx_args)
    tparams = mlp0.set_tparams()
    tparams.update(**mlpk.set_tparams())
    tparams.update(**mlpx.set_tparams())

    print_profile(tparams)

    p0 = mlp0(q0)
    pk = mlpk(qk)
    px = mlpx(X_i)
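    # Three classifiers are trained side by side on the same labels: one on the
    # initial posterior q0, one on the refined posterior qk, and one on the raw
    # (centered) input, presumably to measure how much iterative inference
    # improves the representation.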

    # ========================================================================
    print 'Getting cost'

    cost0 = mlp0.neg_log_prob(Y, p0).sum(axis=0)
    costk = mlpk.neg_log_prob(Y, pk).sum(axis=0)
    costx = mlpx.neg_log_prob(Y, px).sum(axis=0)

    cost = cost0 + costk + costx
    extra_outs = []
    extra_outs_names = ['cost']

    if l2_decay > 0.:
        print 'Adding %.5f L2 weight decay' % l2_decay
        mlp0_l2_cost = mlp0.get_L2_weight_cost(l2_decay)
        mlpk_l2_cost = mlpk.get_L2_weight_cost(l2_decay)
        mlpx_l2_cost = mlpx.get_L2_weight_cost(l2_decay)
        cost += mlp0_l2_cost + mlpk_l2_cost + mlpx_l2_cost
        extra_outs += [mlp0_l2_cost, mlpk_l2_cost, mlpx_l2_cost]
        extra_outs_names += ['MLP0 L2 cost', 'MLPk L2 cost', 'MLPx L2 cost']

    # ========================================================================
    print 'Extra functions'
    error0 = (Y * (1 - p0)).sum(1).mean()
    errork = (Y * (1 - pk)).sum(1).mean()
    errorx = (Y * (1 - px)).sum(1).mean()
    
    f_test_keys = ['Error 0', 'Error k', 'Error x', 'Cost 0', 'Cost k', 'Cost x']
    f_test = theano.function([X, Y], [error0, errork, errorx, cost0, costk, costx])
    
    # ========================================================================
    print 'Setting final tparams and save function'

    all_params = OrderedDict((k, v) for k, v in tparams.iteritems())

    # Keep only the parameters that are not already handled by inference updates.
    tparams = OrderedDict((k, v)
        for k, v in tparams.iteritems()
        if v not in updates.keys())

    print 'Learned model params: %s' % tparams.keys()
    print 'Saved params: %s' % all_params.keys()

    def save(tparams, outfile):
        d = dict((k, v.get_value()) for k, v in all_params.items())

        d.update(
            dim_in=dim_in,
            dim_out=dim_out,
            dataset=dataset, dataset_args=dataset_args,
            **mlp_args
        )
        np.savez(outfile, **d)

    # ========================================================================
    print 'Getting gradients.'
    grads = T.grad(cost, wrt=itemlist(tparams),
                   consider_constant=constants)

    # ========================================================================
    print 'Building optimizer'
    lr = T.scalar(name='lr')
    f_grad_shared, f_grad_updates = getattr(op, optimizer)(
        lr, tparams, grads, [X, Y], cost, extra_ups=updates,
        extra_outs=extra_outs, **optimizer_args)

    monitor = SimpleMonitor()

    try:
        epoch_t0 = time.time()
        s = 0
        e = 0

        widgets = ['Epoch {epoch} ({name}, '.format(epoch=e, name=name),
                   Timer(), '): ', Bar()]
        epoch_pbar = ProgressBar(widgets=widgets, maxval=train.n).start()
        training_time = 0

        while True:
            try:
                x, y = train.next()
                
                if train.pos == -1:
                    epoch_pbar.update(train.n)
                else:
                    epoch_pbar.update(train.pos)

            except StopIteration:
                print
                epoch_t1 = time.time()
                training_time += (epoch_t1 - epoch_t0)
                valid.reset()

                widgets = ['Validating: ',
                          Percentage(), ' (', Timer(), ')']
                pbar    = ProgressBar(widgets=widgets, maxval=valid.n).start()
                results_train = OrderedDict()
                results_valid = OrderedDict()
                while True:
                    try:
                        x_valid, y_valid = valid.next()
                        x_train, y_train = train.next()

                        r_train = f_test(x_train, y_train)
                        r_valid = f_test(x_valid, y_valid)
                        results_i_train = dict((k, v) for k, v in zip(f_test_keys, r_train))
                        results_i_valid = dict((k, v) for k, v in zip(f_test_keys, r_valid))
                        update_dict_of_lists(results_train, **results_i_train)
                        update_dict_of_lists(results_valid, **results_i_valid)

                        if valid.pos == -1:
                            pbar.update(valid.n)
                        else:
                            pbar.update(valid.pos)

                    except StopIteration:
                        print
                        break

                def summarize(d):
                    for k, v in d.iteritems():
                        d[k] = np.mean(v)

                summarize(results_train)
                summarize(results_valid)

                monitor.update(**results_train)
                monitor.update(dt_epoch=(epoch_t1-epoch_t0),
                               training_time=training_time)
                monitor.update_valid(**results_valid)
                monitor.display()

                monitor.save(path.join(
                    out_path, '{name}_monitor.png').format(name=name))
                monitor.save_stats(path.join(
                    out_path, '{name}_monitor.npz').format(name=name))
                monitor.save_stats_valid(path.join(
                    out_path, '{name}_monitor_valid.npz').format(name=name))

                e += 1
                epoch_t0 = time.time()

                valid.reset()
                train.reset()

                if learning_rate_schedule is not None:
                    if e in learning_rate_schedule.keys():
                        lr = learning_rate_schedule[e]
                        print 'Changing learning rate to %.5f' % lr
                        learning_rate = lr

                widgets = ['Epoch {epoch} ({name}, '.format(epoch=e, name=name),
                           Timer(), '): ', Bar()]
                epoch_pbar = ProgressBar(widgets=widgets, maxval=train.n).start()

                continue

            if e > epochs:
                break

            rval = f_grad_shared(x, y)

            if check_bad_nums(rval, extra_outs_names):
                print rval
                print np.any(np.isnan(mlpk.W0.get_value()))
                print np.any(np.isnan(mlpk.b0.get_value()))
                print np.any(np.isnan(mlpk.W1.get_value()))
                print np.any(np.isnan(mlpk.b1.get_value()))
                raise ValueError('Bad number!')

            f_grad_updates(learning_rate)
            s += 1

    except KeyboardInterrupt:
        print 'Training interrupted'

    test.reset()

    widgets = ['Testing: ',
               Percentage(), ' (', Timer(), ')']
    pbar    = ProgressBar(widgets=widgets, maxval=test.n).start()
    results_test = OrderedDict()
    while True:
        try:
            x_test, y_test = test.next()
            r_test = f_test(x_test, y_test)
            results_i_test = dict((k, v) for k, v in zip(f_test_keys, r_test))
            update_dict_of_lists(results_test, **results_i_test)
            if test.pos == -1:
                pbar.update(test.n)
            else:
                pbar.update(test.pos)

        except StopIteration:
            print
            break

    def summarize(d):
        for k, v in d.iteritems():
            d[k] = np.mean(v)

    summarize(results_test)
    print 'Test results:'
    monitor.simple_display(results_test)

    if out_path is not None:
        outfile = path.join(out_path, '{name}_{t}.npz'.format(name=name, t=int(time.time())))
        last_outfile = path.join(out_path, '{name}_last.npz'.format(name=name))

        print 'Saving'
        save(tparams, outfile)
        save(tparams, last_outfile)
        print 'Done saving.'

    print 'Bye bye!'