예제 #1
0
def view_checkpoints(model_dir, sigma, imsz=[28, 28], figid=101):
    """
    Step through the saved checkpoints of a model in iteration order and
    display the last layer's weights for each one.

    checkpoint files should have a name matching the following:
    <model_dir>/checkpoint_<sigma>_<iter>.pdata

    model_dir: directory that is scanned for checkpoint files.
    sigma: value whose str() form appears in the checkpoint file names.
    imsz: patch size, passed through to vt.bwpatchview.
    figid: matplotlib figure number used for the display.
    """
    # The trailing underscore matters: without it sigma=2 would also pick up
    # files such as checkpoint_20_<iter>.pdata.
    name_prefix = 'checkpoint_%s_' % str(sigma)
    name_suffix = '.pdata'

    # Keep the real file name next to its iteration number so that exactly
    # the file found on disk is loaded later.
    checkpoints = []
    for fname in os.listdir(model_dir):
        if not (fname.startswith(name_prefix) and fname.endswith(name_suffix)):
            continue
        # Slice by prefix/suffix instead of splitting on '.', so a sigma
        # containing a dot (e.g. 0.5) does not break the parsing.
        iter_str = fname[len(name_prefix):-len(name_suffix)]
        if iter_str.isdigit():
            checkpoints.append((int(iter_str), fname))
    checkpoints.sort()

    net = gen.StochasticGenerativeNet()

    plt.figure(figid, figsize=(10, 8))
    ax = plt.subplot(111)

    for i, fname in checkpoints:
        net.load_model_from_file(os.path.join(model_dir, fname))
        w = net.layers[-1].params.W.asarray()
        # Only the first (at most) 400 rows of W are shown, on a grid whose
        # size is the integer square root of the number of rows shown.
        patches = w[:400]
        ax.cla()
        vt.bwpatchview(patches,
                       imsz,
                       int(np.sqrt(patches.shape[0])),
                       rowmajor=True,
                       gridintensity=1,
                       ax=ax)
        plt.draw()
        plt.show()
        print('Checkpoint %d' % i)
        # Short pause between frames.
        time.sleep(0.04)
예제 #2
0
파일: train.py 프로젝트: ykwon0407/gmmn
def mnist_mmd_input_space(n_hids=[10, 64, 256, 256, 1024],
                          sigma=[2, 5, 10, 20, 40, 80],
                          learn_rate=2,
                          momentum=0.9):
    """
    Train a stochastic generative network on MNIST with an MMD loss applied
    to the network output, then evaluate it with ev.kde_eval_mnist on the
    validation and test sets.

    n_hids: number of hidden units on all layers (top-down) in the generative
        network. n_hids[0] is used as the input dimension of the network.
    sigma: a list of scales used for the kernel
    learn_rate, momentum: parameters for the learning process

    The learner's output directory is set to a run-specific sub-directory of
    OUTPUT_BASE_DIR/mnist/input_space, and a params_and_results.cfg file with
    the hyper-parameters and evaluation results is written there.

    return: KDE log_likelihood on validation set.
    """
    gnp.seed_rand(8)

    x_train, x_val, x_test = mnistio.load_data()

    print('')
    print('Training data: %d x %d' % x_train.shape)

    in_dim = n_hids[0]
    out_dim = x_train.shape[1]

    net = gen.StochasticGenerativeNet(in_dim, out_dim)
    for n_units in n_hids[1:]:
        net.add_layer(n_units, nonlin_type=ly.NONLIN_NAME_RELU, dropout=0)
    net.add_layer(0, nonlin_type=ly.NONLIN_NAME_SIGMOID, dropout=0)

    # place holder loss; replaced by the real training loss further down
    net.set_loss(ls.LOSS_NAME_MMDGEN,
                 loss_after_nonlin=True,
                 sigma=80,
                 loss_weight=1000)

    print('')
    print('========')
    print('Training')
    print('========')
    print('')
    print(net)
    print('')

    mmd_learner = gen.StochasticGenerativeNetLearner(net)
    mmd_learner.load_data(x_train)

    output_base = OUTPUT_BASE_DIR + '/mnist/input_space'

    # One unit weight per kernel scale, so the weights always line up with
    # whatever sigma list the caller passes in.
    sigma_weights = [1] * len(sigma)

    minibatch_size = 1000
    n_sample_update_iters = 1
    max_iters = 40000
    i_checkpoint = 2000

    output_dir = output_base + '/nhids_%s_sigma_%s_lr_%s_m_%s' % (
        '_'.join([str(nh) for nh in n_hids]),
        '_'.join([str(s) for s in sigma]),
        str(learn_rate),
        str(momentum))

    print('')
    print('>>>> output_dir = %s' % output_dir)
    print('')

    mmd_learner.set_output_dir(output_dir)
    net.set_loss(ls.LOSS_NAME_MMDGEN_SQRT_GAUSSIAN,
                 loss_after_nonlin=True,
                 sigma=sigma,
                 scale_weight=sigma_weights,
                 loss_weight=1000)

    print('**********************************')
    print(net.loss)
    print('**********************************')
    print('')

    def f_checkpoint(i_iter, w):
        # Callback handed to train_sgd as f_exe; w is accepted but unused.
        mmd_learner.save_checkpoint('%d' % i_iter)

    # Both schedules carry a single entry keyed at 10000 (presumably the
    # iteration at which the new value takes effect -- confirm in train_sgd):
    # learn_rate is divided by 10 and (1 - momentum) is divided by 10.
    mmd_learner.train_sgd(minibatch_size=minibatch_size,
                          n_samples_per_update=minibatch_size,
                          n_sample_update_iters=n_sample_update_iters,
                          learn_rate=learn_rate,
                          momentum=momentum,
                          weight_decay=0,
                          learn_rate_schedule={10000: learn_rate / 10.0},
                          momentum_schedule={10000: 1 - (1 - momentum) / 10.0},
                          learn_rate_drop_iters=0,
                          decrease_type='linear',
                          adagrad_start_iter=0,
                          max_iters=max_iters,
                          iprint=100,
                          i_exe=i_checkpoint,
                          f_exe=f_checkpoint)

    mmd_learner.save_model()

    print('')
    print('====================')
    print('Evaluating the model')
    print('====================')
    print('')

    # Keep the third value returned by the validation run in its own name.
    # Rebinding `sigma` here would make the config below record this value
    # instead of the kernel scales used for training.
    log_prob, std, kde_sigma = ev.kde_eval_mnist(net, x_val, verbose=False)
    # The test run is restricted to the value obtained on the validation set.
    test_log_prob, test_std, _ = ev.kde_eval_mnist(net,
                                                   x_test,
                                                   sigma_range=[kde_sigma],
                                                   verbose=False)

    print('Validation: %.2f (%.2f)' % (log_prob, std))
    print('Test      : %.2f (%.2f)' % (test_log_prob, test_std))
    print('')

    write_config(
        output_dir + '/params_and_results.cfg', {
            'n_hids': n_hids,
            'sigma': sigma,
            'sigma_weights': sigma_weights,
            'learn_rate': learn_rate,
            'momentum': momentum,
            'minibatch_size': minibatch_size,
            'n_sample_update_iters': n_sample_update_iters,
            'max_iters': max_iters,
            'i_checkpoint': i_checkpoint,
            'kde_sigma': kde_sigma,
            'val_log_prob': log_prob,
            'val_std': std,
            'test_log_prob': test_log_prob,
            'test_std': test_std
        })

    print('>>>> output_dir = %s' % output_dir)
    print('')

    return log_prob
예제 #3
0
def get_tfd_input_space_model():
    """
    Create a default-constructed StochasticGenerativeNet, load into it the
    model stored at BEST_TFD_INPUT_SPACE_MODEL, and return the network.
    """
    model = gen.StochasticGenerativeNet()
    model.load_model_from_file(BEST_TFD_INPUT_SPACE_MODEL)
    return model