Example #1
import random
import time
from collections import OrderedDict
from copy import deepcopy
from glob import glob
from os import path

import numpy as np
import theano
import theano.tensor as T
from progressbar import Bar, Percentage, ProgressBar, Timer
from theano.tensor.shared_randomstreams import RandomStreams

floatX = theano.config.floatX

# Project-specific helpers are assumed importable from the surrounding
# codebase: MLP, load_model, unpack_sbn, load_data, print_profile, itemlist,
# check_bad_nums, update_dict_of_lists, SimpleMonitor, and the optimizer
# module `op`.

def classify(model_dir,
             n_inference_steps=20, n_inference_samples=20,
             dim_hs=[100], h_act='T.nnet.softplus',
             learning_rate=0.0001, learning_rate_schedule=None,
             dropout=0.1, batch_size=100, l2_decay=0.002,
             epochs=100,
             optimizer='rmsprop', optimizer_args=dict(),
             center_input=True, name='classifier'):
    out_path = model_dir

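    # Hyperparameters for the iterative inference procedure that refines the
    # SBN's approximate posterior (used by model.infer_q below).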
    inference_args = dict(
        inference_method='adaptive',
        inference_rate=0.1,
    )

    # ========================================================================
    print 'Loading model'

    model_file = glob(path.join(model_dir, '*best*npz'))[0]

    models, model_args = load_model(model_file, unpack_sbn, **inference_args)

    model = models['sbn']
    model.set_tparams()

    dataset = model_args['dataset']
    dataset_args = model_args['dataset_args']
    if dataset == 'mnist':
        dataset_args['binarize'] = True
        dataset_args['source'] = '/export/mialab/users/dhjelm/data/mnist.pkl.gz'

    train, valid, test = load_data(dataset, batch_size, batch_size, batch_size,
                                   **dataset_args)

    mlp_args = dict(
        dim_hs=dim_hs,
        h_act=h_act,
        dropout=dropout,
        out_act=train.acts['label']
    )

    X = T.matrix('x', dtype=floatX)
    Y = T.matrix('y', dtype=floatX)
    trng = RandomStreams(random.randint(0, 1000000))

    if center_input:
        print 'Centering input with train dataset mean image'
        X_mean = theano.shared(train.mean_image.astype(floatX), name='X_mean')
        X_i = X - X_mean
    else:
        X_i = X

    # ========================================================================
    print 'Loading MLP and forming graph'

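    # infer_q runs n_inference_steps of iterative refinement and returns the
    # sequence of approximate posteriors: qs[0] is the initial estimate and
    # qs[-1] the fully refined one.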
    (qs, i_costs), _, updates = model.infer_q(
            X_i, X, n_inference_steps, n_inference_samples=n_inference_samples)

    q0 = qs[0]
    qk = qs[-1]

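    # The inferred posteriors are treated as constants during backprop, so
    # classifier gradients do not flow back into the SBN (enforced via
    # consider_constant in T.grad below).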
    constants = [q0, qk]
    dim_in = model.dim_h
    dim_out = train.dims['label']

    mlp0_args = deepcopy(mlp_args)
    mlp0 = MLP(dim_in, dim_out, name='classifier_0', **mlp0_args)
    mlpk_args = deepcopy(mlp_args)
    mlpk = MLP(dim_in, dim_out, name='classifier_k', **mlpk_args)
    mlpx_args = deepcopy(mlp_args)
    mlpx = MLP(train.dims[dataset], dim_out, name='classifier_x', **mlpx_args)
    tparams = mlp0.set_tparams()
    tparams.update(**mlpk.set_tparams())
    tparams.update(**mlpx.set_tparams())

    print_profile(tparams)

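    # Three separate classifiers: one on the initial posterior, one on the
    # refined posterior, and one directly on the (centered) input as a
    # baseline.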
    p0 = mlp0(q0)
    pk = mlpk(qk)
    px = mlpx(X_i)

    # ========================================================================
    print 'Getting cost'

    cost0 = mlp0.neg_log_prob(Y, p0).sum(axis=0)
    costk = mlpk.neg_log_prob(Y, pk).sum(axis=0)
    costx = mlpx.neg_log_prob(Y, px).sum(axis=0)

    cost = cost0 + costk + costx
    extra_outs = []
    extra_outs_names = ['cost']

    if l2_decay > 0.:
        print 'Adding %.5f L2 weight decay' % l2_decay
        mlp0_l2_cost = mlp0.get_L2_weight_cost(l2_decay)
        mlpk_l2_cost = mlpk.get_L2_weight_cost(l2_decay)
        mlpx_l2_cost = mlpx.get_L2_weight_cost(l2_decay)
        cost += mlp0_l2_cost + mlpk_l2_cost + mlpx_l2_cost
        extra_outs += [mlp0_l2_cost, mlpk_l2_cost, mlpx_l2_cost]
        extra_outs_names += ['MLP0 L2 cost', 'MLPk L2 cost', 'MLPx L2 cost']

    # ========================================================================
    print 'Extra functions'
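    # With one-hot labels, (Y * (1 - p)).sum(1) equals 1 minus the
    # probability assigned to the true class, i.e. an expected error rate.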
    error0 = (Y * (1 - p0)).sum(1).mean()
    errork = (Y * (1 - pk)).sum(1).mean()
    errorx = (Y * (1 - px)).sum(1).mean()
    
    f_test_keys = ['Error 0', 'Error k', 'Error x', 'Cost 0', 'Cost k', 'Cost x']
    f_test = theano.function([X, Y], [error0, errork, errorx, cost0, costk, costx])
    
    # ========================================================================
    print 'Setting final tparams and save function'

    all_params = OrderedDict((k, v) for k, v in tparams.iteritems())

    excludes = []  # shared variables to exclude from training, if any
    tparams = OrderedDict((k, v)
        for k, v in tparams.iteritems()
        if (v not in updates.keys() and v not in excludes))

    print 'Learned model params: %s' % tparams.keys()
    print 'Saved params: %s' % all_params.keys()

    def save(outfile):
        d = dict((k, v.get_value()) for k, v in all_params.items())

        d.update(
            dim_in=dim_in,
            dim_out=dim_out,
            dataset=dataset, dataset_args=dataset_args,
            **mlp_args
        )
        np.savez(outfile, **d)

    # ========================================================================
    print 'Getting gradients.'
    grads = T.grad(cost, wrt=itemlist(tparams),
                   consider_constant=constants)

    # ========================================================================
    print 'Building optimizer'
    lr = T.scalar(name='lr')
    f_grad_shared, f_grad_updates = getattr(op, optimizer)(
        lr, tparams, grads, [X, Y], cost, extra_ups=updates,
        extra_outs=extra_outs, **optimizer_args)

    monitor = SimpleMonitor()

    try:
        epoch_t0 = time.time()
        s = 0
        e = 0

        widgets = ['Epoch {epoch} ({name}, '.format(epoch=e, name=name),
                   Timer(), '): ', Bar()]
        epoch_pbar = ProgressBar(widgets=widgets, maxval=train.n).start()
        training_time = 0

        while True:
            try:
                x, y = train.next()
                
                if train.pos == -1:
                    epoch_pbar.update(train.n)
                else:
                    epoch_pbar.update(train.pos)

            except StopIteration:
                print
                epoch_t1 = time.time()
                training_time += (epoch_t1 - epoch_t0)
                valid.reset()

                widgets = ['Validating: ',
                          Percentage(), ' (', Timer(), ')']
                pbar    = ProgressBar(widgets=widgets, maxval=valid.n).start()
                results_train = OrderedDict()
                results_valid = OrderedDict()
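                # Draw matched train and valid batches so both are monitored
                # with the same statistics; the loop ends when either
                # iterator is exhausted.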
                while True:
                    try:
                        x_valid, y_valid = valid.next()
                        x_train, y_train = train.next()

                        r_train = f_test(x_train, y_train)
                        r_valid = f_test(x_valid, y_valid)
                        results_i_train = dict((k, v) for k, v in zip(f_test_keys, r_train))
                        results_i_valid = dict((k, v) for k, v in zip(f_test_keys, r_valid))
                        update_dict_of_lists(results_train, **results_i_train)
                        update_dict_of_lists(results_valid, **results_i_valid)

                        if valid.pos == -1:
                            pbar.update(valid.n)
                        else:
                            pbar.update(valid.pos)

                    except StopIteration:
                        print
                        break

                def summarize(d):
                    for k, v in d.iteritems():
                        d[k] = np.mean(v)

                summarize(results_train)
                summarize(results_valid)

                monitor.update(**results_train)
                monitor.update(dt_epoch=(epoch_t1-epoch_t0),
                               training_time=training_time)
                monitor.update_valid(**results_valid)
                monitor.display()

                monitor.save(path.join(
                    out_path, '{name}_monitor.png').format(name=name))
                monitor.save_stats(path.join(
                    out_path, '{name}_monitor.npz').format(name=name))
                monitor.save_stats_valid(path.join(
                    out_path, '{name}_monitor_valid.npz').format(name=name))

                e += 1
                epoch_t0 = time.time()

                valid.reset()
                train.reset()

                if learning_rate_schedule is not None:
                    if e in learning_rate_schedule.keys():
                        learning_rate = learning_rate_schedule[e]
                        print 'Changing learning rate to %.5f' % learning_rate

                widgets = ['Epoch {epoch} ({name}, '.format(epoch=e, name=name),
                           Timer(), '): ', Bar()]
                epoch_pbar = ProgressBar(widgets=widgets, maxval=train.n).start()

                continue

            if e > epochs:
                break

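            # f_grad_shared computes the cost and caches the gradients in
            # shared variables; f_grad_updates then applies the parameter
            # update at the current learning rate.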
            rval = f_grad_shared(x, y)

            if check_bad_nums(rval, extra_outs_names):
                print rval
                print np.any(np.isnan(mlpk.W0.get_value()))
                print np.any(np.isnan(mlpk.b0.get_value()))
                print np.any(np.isnan(mlpk.W1.get_value()))
                print np.any(np.isnan(mlpk.b1.get_value()))
                raise ValueError('Bad number!')

            f_grad_updates(learning_rate)
            s += 1

    except KeyboardInterrupt:
        print 'Training interrupted'

    test.reset()

    widgets = ['Testing: ',
               Percentage(), ' (', Timer(), ')']
    pbar    = ProgressBar(widgets=widgets, maxval=test.n).start()
    results_test = OrderedDict()
    while True:
        try:
            x_test, y_test = test.next()
            r_test = f_test(x_test, y_test)
            results_i_test = dict((k, v) for k, v in zip(f_test_keys, r_test))
            update_dict_of_lists(results_test, **results_i_test)
            if test.pos == -1:
                pbar.update(test.n)
            else:
                pbar.update(test.pos)

        except StopIteration:
            print
            break

    def summarize(d):
        for k, v in d.iteritems():
            d[k] = np.mean(v)

    summarize(results_test)
    print 'Test results:'
    monitor.simple_display(results_test)

    if out_path is not None:
        outfile = path.join(out_path, '{name}_{t}.npz'.format(name=name, t=int(time.time())))
        last_outfile = path.join(out_path, '{name}_last.npz'.format(name=name))

        print 'Saving'
        save(outfile)
        save(last_outfile)
        print 'Done saving.'

    print 'Bye bye!'
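
A minimal usage sketch (the checkpoint directory below is hypothetical): the
directory passed to classify must contain a '*best*npz' file saved during SBN
training that load_model can unpack with unpack_sbn.

# Hypothetical invocation of the function above; the path and name are
# placeholders, all other arguments keep their defaults.
classify('/path/to/sbn_model_dir', name='sbn_classifier')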
Example #2
def test_make_mlp(dim_in=13, dim_h=17, dim_out=19, n_layers=2, h_act='T.nnet.softplus', out_act='T.nnet.sigmoid'):
    mlp = MLP(dim_in, dim_h, dim_out, n_layers, h_act=h_act, out_act=out_act)
    mlp.set_tparams()
    return mlp
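
A quick usage sketch for this test helper, assuming theano.tensor is imported
as T as in Example #1: the returned MLP is applied to a symbolic batch the
same way the classifiers are applied there (e.g. p0 = mlp0(q0)).

# Build the MLP and wire it to a symbolic input; p is the network's symbolic
# output for a batch of dim_in-dimensional inputs.
mlp = test_make_mlp()
X = T.matrix('x')
p = mlp(X)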