def view_checkpoints(model_dir, sigma, imsz=[28, 28], figid=101): """ checkpoint files should have a name matching the following: <model_dir>/checkpoint_<sigma>_<iter>.pdata """ prefix = '%s/checkpoint_%s' % (model_dir, str(sigma)) checkpoint_numbers = sorted([ int(fpath.split('.')[0].split('_')[-1]) for fpath in os.listdir(model_dir) if fpath.startswith('checkpoint_%s' % str(sigma)) ]) net = gen.StochasticGenerativeNet() plt.figure(figid, figsize=(10, 8)) ax = plt.subplot(111) for i in checkpoint_numbers: net.load_model_from_file(prefix + '_%d.pdata' % i) w = net.layers[-1].params.W.asarray() ax.cla() vt.bwpatchview(w[:400], imsz, int(np.sqrt(w[:400].shape[0])), rowmajor=True, gridintensity=1, ax=ax) plt.draw() plt.show() print 'Checkpoint %d' % i time.sleep(0.04)
def mnist_mmd_input_space(n_hids=[10, 64, 256, 256, 1024], sigma=[2, 5, 10, 20, 40, 80], learn_rate=2, momentum=0.9): """ n_hids: number of hidden units on all layers (top-down) in the generative network. sigma: a list of scales used for the kernel learn_rate, momentum: parameters for the learning process return: KDE log_likelihood on validation set. """ gnp.seed_rand(8) x_train, x_val, x_test = mnistio.load_data() print '' print 'Training data: %d x %d' % x_train.shape in_dim = n_hids[0] out_dim = x_train.shape[1] net = gen.StochasticGenerativeNet(in_dim, out_dim) for i in range(1, len(n_hids)): net.add_layer(n_hids[i], nonlin_type=ly.NONLIN_NAME_RELU, dropout=0) net.add_layer(0, nonlin_type=ly.NONLIN_NAME_SIGMOID, dropout=0) # place holder loss net.set_loss(ls.LOSS_NAME_MMDGEN, loss_after_nonlin=True, sigma=80, loss_weight=1000) print '' print '========' print 'Training' print '========' print '' print net print '' mmd_learner = gen.StochasticGenerativeNetLearner(net) mmd_learner.load_data(x_train) output_base = OUTPUT_BASE_DIR + '/mnist/input_space' #sigma = [2,5,10,20,40,80] sigma_weights = [1, 1, 1, 1, 1, 1] #learn_rate = 1 #momentum = 0.9 minibatch_size = 1000 n_sample_update_iters = 1 max_iters = 40000 i_checkpoint = 2000 output_dir = output_base + '/nhids_%s_sigma_%s_lr_%s_m_%s' % ('_'.join([ str(nh) for nh in n_hids ]), '_'.join([str(s) for s in sigma]), str(learn_rate), str(momentum)) print '' print '>>>> output_dir = %s' % output_dir print '' mmd_learner.set_output_dir(output_dir) #net.set_loss(ls.LOSS_NAME_MMDGEN_MULTISCALE, loss_after_nonlin=True, sigma=sigma, scale_weight=sigma_weights, loss_weight=1000) net.set_loss(ls.LOSS_NAME_MMDGEN_SQRT_GAUSSIAN, loss_after_nonlin=True, sigma=sigma, scale_weight=sigma_weights, loss_weight=1000) print '**********************************' print net.loss print '**********************************' print '' def f_checkpoint(i_iter, w): mmd_learner.save_checkpoint('%d' % i_iter) 
mmd_learner.train_sgd(minibatch_size=minibatch_size, n_samples_per_update=minibatch_size, n_sample_update_iters=n_sample_update_iters, learn_rate=learn_rate, momentum=momentum, weight_decay=0, learn_rate_schedule={10000: learn_rate / 10.0}, momentum_schedule={10000: 1 - (1 - momentum) / 10.0}, learn_rate_drop_iters=0, decrease_type='linear', adagrad_start_iter=0, max_iters=max_iters, iprint=100, i_exe=i_checkpoint, f_exe=f_checkpoint) mmd_learner.save_model() print '' print '====================' print 'Evaluating the model' print '====================' print '' log_prob, std, sigma = ev.kde_eval_mnist(net, x_val, verbose=False) test_log_prob, test_std, _ = ev.kde_eval_mnist(net, x_test, sigma_range=[sigma], verbose=False) print 'Validation: %.2f (%.2f)' % (log_prob, std) print 'Test : %.2f (%.2f)' % (test_log_prob, test_std) print '' write_config( output_dir + '/params_and_results.cfg', { 'n_hids': n_hids, 'sigma': sigma, 'sigma_weights': sigma_weights, 'learn_rate': learn_rate, 'momentum': momentum, 'minibatch_size': minibatch_size, 'n_sample_update_iters': n_sample_update_iters, 'max_iters': max_iters, 'i_checkpoint': i_checkpoint, 'val_log_prob': log_prob, 'val_std': std, 'test_log_prob': test_log_prob, 'test_std': test_std }) print '>>>> output_dir = %s' % output_dir print '' return log_prob
def get_tfd_input_space_model():
    """Return the best saved TFD input-space generative network.

    Loads the model parameters from BEST_TFD_INPUT_SPACE_MODEL.
    """
    model = gen.StochasticGenerativeNet()
    model.load_model_from_file(BEST_TFD_INPUT_SPACE_MODEL)
    return model