def test_two_stage_model2():
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 128
    batch_reps = 1

    ###############################################
    # Setup some parameters for the TwoStageModel #
    ###############################################
    x_dim = Xtr.shape[1]
    z_dim = 50
    h_dim = 100
    x_type = 'bernoulli'

    # some HydraNet instances to build the TwoStageModel from
    xin_sym = T.matrix('xin_sym')
    xout_sym = T.matrix('xout_sym')

    ###############
    # p_h_given_z #
    ###############
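    # Each HydraNet below is a shared trunk (shared_config) plus one
    # fully-connected head per dict in output_config; with two heads, these
    # presumably parameterize the mean and log-variance of a diagonal
    # Gaussian over the layer's output.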
    params = {}
    shared_config = \
    [ {'layer_type': 'fc',
       'in_chans': z_dim,
       'out_chans': 100,
       'activation': tanh_actfun,
       'apply_bn': True}, \
      {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': 100,
       'activation': tanh_actfun,
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': h_dim,
       'activation': tanh_actfun,
       'apply_bn': False}, \
      {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': h_dim,
       'activation': tanh_actfun,
       'apply_bn': False} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    p_h_given_z = HydraNet(rng=rng, Xd=xin_sym,
            params=params, shared_param_dicts=None)
    p_h_given_z.init_biases(0.0)
    ###############
    # p_x_given_h #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'fc',
       'in_chans': h_dim,
       'out_chans': 200,
       'activation': tanh_actfun,
       'apply_bn': True}, \
      {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': 200,
       'activation': tanh_actfun,
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': x_dim,
       'activation': tanh_actfun,
       'apply_bn': False}, \
      {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': x_dim,
       'activation': tanh_actfun,
       'apply_bn': False} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    p_x_given_h = HydraNet(rng=rng, Xd=xin_sym,
            params=params, shared_param_dicts=None)
    p_x_given_h.init_biases(0.0)
    ###############
    # q_h_given_x #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'fc',
       'in_chans': x_dim,
       'out_chans': 200,
       'activation': tanh_actfun,
       'apply_bn': True}, \
      {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': 200,
       'activation': tanh_actfun,
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': h_dim,
       'activation': tanh_actfun,
       'apply_bn': False}, \
      {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': h_dim,
       'activation': tanh_actfun,
       'apply_bn': False} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    q_h_given_x = HydraNet(rng=rng, Xd=xin_sym,
            params=params, shared_param_dicts=None)
    q_h_given_x.init_biases(0.0)
    ###############
    # q_z_given_h #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'fc',
       'in_chans': h_dim,
       'out_chans': 100,
       'activation': tanh_actfun,
       'apply_bn': True}, \
      {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': 100,
       'activation': tanh_actfun,
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': z_dim,
       'activation': tanh_actfun,
       'apply_bn': False}, \
      {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': z_dim,
       'activation': tanh_actfun,
       'apply_bn': False} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    q_z_given_h = HydraNet(rng=rng, Xd=xin_sym,
            params=params, shared_param_dicts=None)
    q_z_given_h.init_biases(0.0)

    ##############################################################
    # Define parameters for the TwoStageModel, and initialize it #
    ##############################################################
    print("Building the TwoStageModel...")
    tsm_params = {}
    tsm_params['x_type'] = x_type
    tsm_params['obs_transform'] = 'sigmoid'
    TSM = TwoStageModel2(rng=rng, x_in=xin_sym, x_out=xout_sym,
            x_dim=x_dim, z_dim=z_dim, h_dim=h_dim,
            q_h_given_x=q_h_given_x,
            q_z_given_h=q_z_given_h,
            p_h_given_z=p_h_given_z,
            p_x_given_h=p_x_given_h,
            params=tsm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format("TSM2A_TEST")
    out_file = open(log_name, 'w')
    costs = [0. for i in range(10)]
    learn_rate = 0.001
    momentum = 0.9
    batch_idx = np.arange(batch_size) + tr_samples
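    # note: batch_idx starts one full pass beyond Xtr, so the first loop
    # iteration below immediately triggers a shuffle and index reset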
    for i in range(500000):
        scale = min(1.0, ((i+1) / 5000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        Xb = to_fX( Xtr.take(batch_idx, axis=0) )
        #Xb = binarize_data(Xtr.take(batch_idx, axis=0))
        # set sgd and objective function hyperparams for this update
        TSM.set_sgd_params(lr=scale*learn_rate,
                           mom_1=(scale*momentum), mom_2=0.98)
        TSM.set_train_switch(1.0)
        TSM.set_lam_nll(lam_nll=1.0)
        TSM.set_lam_kld(lam_kld_q2p=1.0, lam_kld_p2q=0.0)
        TSM.set_lam_l2w(1e-5)
        # perform a minibatch update and record the cost for this batch
        result = TSM.train_joint(Xb, Xb, batch_reps)
        costs = [(costs[j] + result[j]) for j in range(len(result))]
        if ((i % 500) == 0):
            costs = [(v / 500.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    joint_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_cost  : {0:.4f}".format(costs[1])
            str4 = "    kld_cost  : {0:.4f}".format(costs[2])
            str5 = "    reg_cost  : {0:.4f}".format(costs[3])
            str6 = "    nll       : {0:.4f}".format(np.mean(costs[4]))
            str7 = "    kld_z     : {0:.4f}".format(np.mean(costs[5]))
            str8 = "    kld_h     : {0:.4f}".format(np.mean(costs[6]))
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6, str7, str8])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if (((i % 5000) == 0) or ((i < 10000) and ((i % 1000) == 0))):
            # draw some independent random samples from the model
            samp_count = 300
            model_samps = TSM.sample_from_prior(samp_count)
            file_name = "TSM2A_SAMPLES_b{0:d}.png".format(i)
            utils.visualize_samples(model_samps, file_name, num_rows=15)
            # compute free energy estimate for validation samples
            Xva = row_shuffle(Xva)
            fe_terms = TSM.compute_fe_terms(Xva[0:5000], Xva[0:5000], 20)
            fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
            out_str = "    nll_bound : {0:.4f}".format(fe_mean)
            print(out_str)
            out_file.write(out_str+"\n")
            out_file.flush()
    return
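
# All of the training loops in this file share one scheduling pattern: a
# linear warm-up multiplier on the learning rate and momentum, a step decay
# every 10k updates, and an index window that walks over the (shuffled)
# training set. The sketch below isolates that pattern; it is illustrative
# only, with np.random.permutation standing in for this repo's row_shuffle.
def _schedule_and_cycling_sketch(X, n_updates=100, batch_size=128,
                                 base_lr=0.001, warmup=5000.0):
    rng = np.random.RandomState(0)
    tr_samples = X.shape[0]
    learn_rate = base_lr
    batch_idx = np.arange(batch_size) + tr_samples  # forces an initial shuffle
    for i in range(n_updates):
        scale = min(1.0, ((i + 1) / warmup))        # linear warm-up in (0, 1]
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95          # step decay
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            X = X[rng.permutation(tr_samples)]      # stand-in for row_shuffle
            batch_idx = np.arange(batch_size)
        Xb = X.take(batch_idx, axis=0)
        # ...one SGD update on Xb goes here, at rate (scale * learn_rate)...
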
def test_with_model_init():
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 200
    batch_reps = 1

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    obs_dim = Xtr.shape[1]
    z_dim = 20
    h_dim = 200
    ir_steps = 6
    init_scale = 1.0
    
    x_type = 'bernoulli'

    # some InfNet and HydraNet instances to build the MultiStageModel from
    x_in_sym = T.matrix('x_in_sym')
    x_out_sym = T.matrix('x_out_sym')

    #################
    # p_hi_given_si #
    #################
    params = {}
    shared_config = [obs_dim, 300, 300]
    top_config = [shared_config[-1], h_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_hi_given_si = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_hi_given_si.init_biases(0.2)
    ######################
    # p_sip1_given_si_hi #
    ######################
    params = {}
    shared_config = [h_dim, 300, 300]
    output_config = [obs_dim, obs_dim, obs_dim]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['activation'] = relu_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_sip1_given_si_hi = HydraNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_sip1_given_si_hi.init_biases(0.2)
    ################
    # p_s0_given_z #
    ################
    params = {}
    shared_config = [z_dim, 250, 250]
    top_config = [shared_config[-1], obs_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_s0_given_z = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_s0_given_z.init_biases(0.2)
    ###############
    # q_z_given_x #
    ###############
    params = {}
    shared_config = [obs_dim, 250, 250]
    top_config = [shared_config[-1], z_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    q_z_given_x = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    q_z_given_x.init_biases(0.2)
    ###################
    # q_hi_given_x_si #
    ###################
    params = {}
    shared_config = [(obs_dim + obs_dim), 500, 500]
    top_config = [shared_config[-1], h_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    q_hi_given_x_si = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    q_hi_given_x_si.init_biases(0.2)


    ################################################################
    # Define parameters for the MultiStageModel, and initialize it #
    ################################################################
    print("Building the MultiStageModel...")
    msm_params = {}
    msm_params['x_type'] = x_type
    msm_params['obs_transform'] = 'sigmoid'
    MSM = MultiStageModel(rng=rng, x_in=x_in_sym, x_out=x_out_sym, \
            p_s0_given_z=p_s0_given_z, \
            p_hi_given_si=p_hi_given_si, \
            p_sip1_given_si_hi=p_sip1_given_si_hi, \
            q_z_given_x=q_z_given_x, \
            q_hi_given_x_si=q_hi_given_x_si, \
            obs_dim=obs_dim, z_dim=z_dim, h_dim=h_dim, \
            ir_steps=ir_steps, params=msm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    out_file = open("MSM_A_RESULTS.txt", 'wb')
    costs = [0. for i in range(10)]
    learn_rate = 0.0003
    momentum = 0.9
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(250000):
        scale = min(1.0, ((i+1) / 3000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        # set sgd and objective function hyperparams for this update
        MSM.set_sgd_params(lr_1=scale*learn_rate, lr_2=scale*learn_rate, \
                mom_1=scale*momentum, mom_2=0.99)
        MSM.set_train_switch(1.0)
        MSM.set_lam_nll(lam_nll=1.0)
        MSM.set_lam_kld(lam_kld_z=1.0, lam_kld_q2p=0.8, lam_kld_p2q=0.2)
        MSM.set_lam_kld_l1l2(lam_kld_l1l2=1.0)
        MSM.set_lam_l2w(1e-4)
        MSM.set_drop_rate(0.0)
        MSM.q_hi_given_x_si.set_bias_noise(0.0)
        MSM.p_hi_given_si.set_bias_noise(0.0)
        MSM.p_sip1_given_si_hi.set_bias_noise(0.0)
        # perform a minibatch update and record the cost for this batch
        Xb_tr = to_fX( Xtr.take(batch_idx, axis=0) )
        result = MSM.train_joint(Xb_tr, Xb_tr, batch_reps)
        costs = [(costs[j] + result[j]) for j in range(len(result)-1)]
        if ((i % 500) == 0):
            costs = [(v / 500.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    joint_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_cost  : {0:.4f}".format(costs[1])
            str4 = "    kld_cost  : {0:.4f}".format(costs[2])
            str5 = "    reg_cost  : {0:.4f}".format(costs[3])
            joint_str = "\n".join([str1, str2, str3, str4, str5])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if (((i % 2000) == 0) or ((i < 10000) and ((i % 1000) == 0))):
            MSM.set_drop_rate(0.0)
            MSM.q_hi_given_x_si.set_bias_noise(0.0)
            MSM.p_hi_given_si.set_bias_noise(0.0)
            MSM.p_sip1_given_si_hi.set_bias_noise(0.0)
            # Get some validation samples for computing diagnostics
            Xva = row_shuffle(Xva)
            Xb_va = to_fX( Xva[0:2000] )
            # draw some independent random samples from the model
            samp_count = 200
            model_samps = MSM.sample_from_prior(samp_count)
            seq_len = len(model_samps)
            seq_samps = np.zeros((seq_len*samp_count, model_samps[0].shape[1]))
            idx = 0
            for s1 in range(samp_count):
                for s2 in range(seq_len):
                    seq_samps[idx] = model_samps[s2][s1]
                    idx += 1
            file_name = "MSM_A_SAMPLES_IND_b{0:d}.png".format(i)
            utils.visualize_samples(seq_samps, file_name, num_rows=20)
            # draw some conditional random samples from the model
            samp_count = 200
            Xs = np.vstack((Xb_tr[0:(samp_count // 4)], Xb_va[0:(samp_count // 4)]))
            Xs = np.repeat(Xs, 2, axis=0)
            # draw some conditional random samples from the model
            model_samps = MSM.sample_from_input(Xs, guided_decoding=False)
            model_samps.append(Xs)
            seq_len = len(model_samps)
            seq_samps = np.zeros((seq_len*samp_count, model_samps[0].shape[1]))
            idx = 0
            for s1 in range(samp_count): 
                for s2 in range(seq_len):
                    seq_samps[idx] = model_samps[s2][s1]
                    idx += 1
            file_name = "MSM_A_SAMPLES_CND_b{0:d}.png".format(i)
            utils.visualize_samples(seq_samps, file_name, num_rows=20)
            # compute information about posterior KLds on validation set
            raw_klds = MSM.compute_raw_klds(Xb_va, Xb_va)
            init_kld, q2p_kld, p2q_kld = raw_klds
            file_name = "MSM_A_H0_KLDS_b{0:d}.png".format(i)
            utils.plot_stem(np.arange(init_kld.shape[1]), \
                    np.mean(init_kld, axis=0), file_name)
            file_name = "MSM_A_HI_Q2P_KLDS_b{0:d}.png".format(i)
            utils.plot_stem(np.arange(q2p_kld.shape[1]), \
                    np.mean(q2p_kld, axis=0), file_name)
            file_name = "MSM_A_HI_P2Q_KLDS_b{0:d}.png".format(i)
            utils.plot_stem(np.arange(p2q_kld.shape[1]), \
                    np.mean(p2q_kld, axis=0), file_name)
            Xb_tr = to_fX( Xtr[0:2000] )
            fe_terms = MSM.compute_fe_terms(Xb_tr, Xb_tr, 30)
            fe_nll = np.mean(fe_terms[0])
            fe_kld = np.mean(fe_terms[1])
            fe_joint = fe_nll + fe_kld
            joint_str = "    vfe-tr: {0:.4f}, nll: ({1:.4f}, {2:.4f}, {3:.4f}), kld: ({4:.4f}, {5:.4f}, {6:.4f})".format( \
                    fe_joint, fe_nll, np.min(fe_terms[0]), np.max(fe_terms[0]), fe_kld, np.min(fe_terms[1]), np.max(fe_terms[1]))
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            fe_terms = MSM.compute_fe_terms(Xb_va, Xb_va, 30)
            fe_nll = np.mean(fe_terms[0])
            fe_kld = np.mean(fe_terms[1])
            fe_joint = fe_nll + fe_kld
            joint_str = "    vfe-va: {0:.4f}, nll: ({1:.4f}, {2:.4f}, {3:.4f}), kld: ({4:.4f}, {5:.4f}, {6:.4f})".format( \
                    fe_joint, fe_nll, np.min(fe_terms[0]), np.max(fe_terms[0]), fe_kld, np.min(fe_terms[1]), np.max(fe_terms[1]))
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
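
# The s1/s2 interleaving loops above (all refinement steps of sample 0 first,
# then sample 1, and so on) are equivalent to a transpose plus reshape; a
# small numpy restatement:
def _interleave_step_samples(model_samps):
    # model_samps: list of seq_len arrays, each of shape (samp_count, x_dim)
    stacked = np.asarray(model_samps)       # (seq_len, samp_count, x_dim)
    # reorder to (samp_count, seq_len, x_dim), then merge the first two axes
    return stacked.transpose(1, 0, 2).reshape(-1, stacked.shape[2])
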
def test_sgm_mnist(step_type='add', occ_dim=14, drop_prob=0.0, attention=False):
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 200

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    x_dim = Xtr.shape[1]
    writer_dim = 250
    reader_dim = 250
    dyn_dim = 250
    primary_dim = 500
    guide_dim = 500
    z_dim = 100
    n_iter = 20
    dp_int = int(100.0 * drop_prob)
    
    rnninits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }
    inits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }

    att_tag = "NA" # attention not implemented yet

    # reader MLP provides input to the dynamics LSTM update
    reader_mlp = MLP([Rectifier(), Rectifier(), None], \
                     [(x_dim + z_dim), reader_dim, reader_dim, 4*dyn_dim], \
                     name="reader_mlp", **inits)
    # writer MLP applies changes to the generation workspace
    writer_mlp = MLP([Rectifier(), Rectifier(), None], \
                     [(dyn_dim + z_dim), writer_dim, writer_dim, x_dim], \
                     name="writer_mlp", **inits)

    # MLPs for computing conditionals over z
    primary_policy = CondNet([Rectifier(), Rectifier()], \
                             [(dyn_dim + x_dim), primary_dim, primary_dim, z_dim], \
                             name="primary_policy", **inits)
    guide_policy = CondNet([Rectifier(), Rectifier()], \
                           [(dyn_dim + 2*x_dim), guide_dim, guide_dim, z_dim], \
                           name="guide_policy", **inits)
    # LSTM implementing the shared latent dynamics
    shared_dynamics = BiasedLSTM(dim=dyn_dim, ig_bias=2.0, fg_bias=2.0, \
                                 name="shared_dynamics", **rnninits)

    model = SeqGenModel(
                n_iter,
                step_type=step_type, # step_type can be 'add' or 'jump'
                reader_mlp=reader_mlp,
                writer_mlp=writer_mlp,
                primary_policy=primary_policy,
                guide_policy=guide_policy,
                shared_dynamics=shared_dynamics)
    model.initialize()

    # build the cost gradients, training function, samplers, etc.
    model.build_model_funcs()

    #model.load_model_params(f_name="TBSGM_IMP_MNIST_PARAMS_OD{}_DP{}_{}_{}.pkl".format(occ_dim, dp_int, step_type, att_tag))

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    print("Beginning to train the model...")
    out_file = open("TBSGM_IMP_MNIST_RESULTS_OD{}_DP{}_{}_{}.txt".format(occ_dim, dp_int, step_type, att_tag), 'wb')
    out_file.flush()
    costs = [0. for i in range(10)]
    learn_rate = 0.0002
    momentum = 0.5
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(250000):
        scale = min(1.0, ((i+1) / 1000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        if (i > 10000):
            momentum = 0.90
        else:
            momentum = 0.50
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        # set sgd and objective function hyperparams for this update
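        # (adding a python scalar to the shape-(1,) zero array below yields a
        # shape-(1,) float array, matching the model's shared hyperparameters)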
        zero_ary = np.zeros((1,))
        model.lr.set_value(to_fX(zero_ary + learn_rate))
        model.mom_1.set_value(to_fX(zero_ary + momentum))
        model.mom_2.set_value(to_fX(zero_ary + 0.99))

        # perform a minibatch update and record the cost for this batch
        Xb = to_fX(Xtr.take(batch_idx, axis=0))
        _, Xb, Mb = construct_masked_data(Xb, drop_prob=drop_prob, \
                                    occ_dim=occ_dim, data_mean=None)
        result = model.train_joint(Xb, Mb)

        costs = [(costs[j] + result[j]) for j in range(len(result))]
        if ((i % 200) == 0):
            costs = [(v / 200.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    total_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    nll_term  : {0:.4f}".format(costs[2])
            str5 = "    kld_q2p   : {0:.4f}".format(costs[3])
            str6 = "    kld_p2q   : {0:.4f}".format(costs[4])
            str7 = "    reg_term  : {0:.4f}".format(costs[5])
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6, str7])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if ((i % 1000) == 0):
            model.save_model_params("TBSGM_IMP_MNIST_PARAMS_OD{}_DP{}_{}_{}.pkl".format(occ_dim, dp_int, step_type, att_tag))
            # compute a small-sample estimate of NLL bound on validation set
            Xva = row_shuffle(Xva)
            Xb = to_fX(Xva[:5000])
            _, Xb, Mb = construct_masked_data(Xb, drop_prob=drop_prob, \
                                    occ_dim=occ_dim, data_mean=None)
            va_costs = model.compute_nll_bound(Xb, Mb)
            str1 = "    va_nll_bound : {}".format(va_costs[1])
            str2 = "    va_nll_term  : {}".format(va_costs[2])
            str3 = "    va_kld_q2p   : {}".format(va_costs[3])
            joint_str = "\n".join([str1, str2, str3])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            # draw some independent samples from the model
            Xb = to_fX(Xva[:100])
            _, Xb, Mb = construct_masked_data(Xb, drop_prob=drop_prob, \
                                    occ_dim=occ_dim, data_mean=None)
            samples, _ = model.do_sample(Xb, Mb)
            n_iter, N, D = samples.shape
            samples = samples.reshape( (n_iter, N, 28, 28) )
            for j in range(n_iter):
                img = img_grid(samples[j,:,:,:])
                img.save("TBSGM-IMP-MNIST-OD{0:d}-DP{1:d}-{2:s}-samples-{3:03d}.png".format(occ_dim, dp_int, step_type, j))
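
# construct_masked_data is used throughout with the signature
# construct_masked_data(X, drop_prob=..., occ_dim=..., data_mean=...) and
# returns (masked input, reconstruction target, visibility mask). Its real
# implementation lives elsewhere in this repo; the hypothetical sketch below
# only illustrates plausible semantics for the drop_prob path (random pixel
# occlusion) and is not the repo's code.
def _masked_data_sketch(X, drop_prob=0.5, data_mean=None):
    rng = np.random.RandomState(0)
    mask = (rng.rand(*X.shape) > drop_prob).astype(X.dtype)  # 1 = visible
    fill = 0.0 if data_mean is None else data_mean           # value for hidden pixels
    x_in = (mask * X) + ((1.0 - mask) * fill)
    return x_in, X, mask
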
def test_with_model_init():
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 200
    batch_reps = 1

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    obs_dim = Xtr.shape[1]
    z_dim = 20
    h_dim = 100
    x_type = 'bernoulli'

    # some InfNet instances to build the MultiStageModel from
    X_sym = T.matrix('X_sym')

    ########################
    # p_s0_obs_given_z_obs #
    ########################
    params = {}
    shared_config = [z_dim, 250, 250]
    top_config = [shared_config[-1], obs_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = 1.0
    params['lam_l2a'] = 1e-3
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_s0_obs_given_z_obs = InfNet(rng=rng, Xd=X_sym, \
            params=params, shared_param_dicts=None)
    p_s0_obs_given_z_obs.init_biases(0.2)
    #################
    # p_hi_given_si #
    #################
    params = {}
    shared_config = [obs_dim, 250, 250]
    top_config = [shared_config[-1], h_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = 1.0
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_hi_given_si = InfNet(rng=rng, Xd=X_sym, \
            params=params, shared_param_dicts=None)
    p_hi_given_si.init_biases(0.2)
    ######################
    # p_sip1_given_si_hi #
    ######################
    params = {}
    shared_config = [h_dim, 250, 250]
    top_config = [shared_config[-1], obs_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = 1.0
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_sip1_given_si_hi = InfNet(rng=rng, Xd=X_sym, \
            params=params, shared_param_dicts=None)
    p_sip1_given_si_hi.init_biases(0.2)
    ###############
    # q_z_given_x #
    ###############
    params = {}
    shared_config = [obs_dim, 250, 250]
    top_config = [shared_config[-1], z_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = 1.0
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    q_z_given_x = InfNet(rng=rng, Xd=X_sym, \
            params=params, shared_param_dicts=None)
    q_z_given_x.init_biases(0.2)
    ###################
    # q_hi_given_x_si #
    ###################
    params = {}
    shared_config = [(obs_dim + obs_dim), 250, 250]
    top_config = [shared_config[-1], h_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = 1.0
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    q_hi_given_x_si = InfNet(rng=rng, Xd=X_sym, \
            params=params, shared_param_dicts=None)
    q_hi_given_x_si.init_biases(0.2)


    ################################################################
    # Define parameters for the MultiStageModel, and initialize it #
    ################################################################
    print("Building the MultiStageModel...")
    msm_params = {}
    msm_params['x_type'] = x_type
    msm_params['obs_transform'] = 'sigmoid'
    MSM = MultiStageModel(rng=rng, x_in=X_sym, \
            p_s0_obs_given_z_obs=p_s0_obs_given_z_obs, \
            p_hi_given_si=p_hi_given_si, \
            p_sip1_given_si_hi=p_sip1_given_si_hi, \
            q_z_given_x=q_z_given_x, \
            q_hi_given_x_si=q_hi_given_x_si, \
            obs_dim=obs_dim, z_dim=z_dim, h_dim=h_dim, \
            model_init_obs=True, ir_steps=5, \
            params=msm_params)
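    # squash the empirical pixel means into [0.05, 0.95] before the logit, so
    # the derived output biases stay finite for always-0 / always-1 pixels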
    obs_mean = (0.9 * np.mean(Xtr, axis=0)) + 0.05
    obs_mean_logit = np.log(obs_mean / (1.0 - obs_mean))
    MSM.set_input_bias(-obs_mean)
    MSM.set_obs_bias(0.1*obs_mean_logit)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format("MSM_TEST")
    out_file = open(log_name, 'w')
    costs = [0. for i in range(10)]
    learn_rate = 0.0002
    momentum = 0.9
    for i in range(300000):
        scale = min(1.0, ((i+1) / 15000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # randomly sample a minibatch
        tr_idx = npr.randint(low=0,high=tr_samples,size=(batch_size,))
        Xb = Xtr.take(tr_idx, axis=0)
        #Xb = binarize_data(Xtr.take(tr_idx, axis=0))
        # set sgd and objective function hyperparams for this update
        MSM.set_sgd_params(lr_1=scale*learn_rate, lr_2=scale*learn_rate, \
                           mom_1=(scale*momentum), mom_2=0.98)
        MSM.set_train_switch(1.0)
        MSM.set_l1l2_weight(1.0)
        MSM.set_drop_rate(drop_rate=0.0)
        MSM.set_lam_nll(lam_nll=1.0)
        MSM.set_lam_kld(lam_kld_1=1.0, lam_kld_2=1.0)
        MSM.set_lam_l2w(1e-6)
        MSM.set_kzg_weight(0.05)
        # perform a minibatch update and record the cost for this batch
        result = MSM.train_joint(Xb, batch_reps)
        costs = [(costs[j] + result[j]) for j in range(len(result))]
        if ((i % 500) == 0):
            costs = [(v / 500.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    joint_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_cost  : {0:.4f}".format(costs[1])
            str4 = "    kld_cost  : {0:.4f}".format(costs[2])
            str5 = "    reg_cost  : {0:.4f}".format(costs[3])
            joint_str = "\n".join([str1, str2, str3, str4, str5])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if (((i % 2000) == 0) or ((i < 10000) and ((i % 1000) == 0))):
            Xva = row_shuffle(Xva)
            # draw some independent random samples from the model
            samp_count = 200
            model_samps = MSM.sample_from_prior(samp_count)
            seq_len = len(model_samps)
            seq_samps = np.zeros((seq_len*samp_count, model_samps[0].shape[1]))
            idx = 0
            for s1 in range(samp_count): 
                for s2 in range(seq_len):
                    seq_samps[idx] = model_samps[s2][s1]
                    idx += 1
            file_name = "MX_SAMPLES_b{0:d}.png".format(i)
            utils.visualize_samples(seq_samps, file_name, num_rows=20)
            # compute information about free-energy on validation set
            fe_terms = MSM.compute_fe_terms(binarize_data(Xva[0:5000]), 20)
            fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
            out_str = "    nll_bound : {0:.4f}".format(fe_mean)
            print(out_str)
            out_file.write(out_str+"\n")
            out_file.flush()

    return
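
# compute_fe_terms returns (presumably per-example) Monte Carlo estimates of
# the two variational free-energy terms; the scalar bound reported in these
# logs is just the sum of their means. A trivial numpy restatement:
def _nll_bound_from_fe_terms(fe_terms):
    nll_term, kld_term = fe_terms[0], fe_terms[1]
    return np.mean(nll_term) + np.mean(kld_term)
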
def test_imoold_generation_ft(step_type="add", attention=False):
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path="./data/")
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    # del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 250

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    x_dim = Xtr.shape[1]
    write_dim = 200
    enc_dim = 250
    dec_dim = 250
    mix_dim = 20
    z_dim = 100
    if attention:
        n_iter = 50
    else:
        n_iter = 16

    rnninits = {"weights_init": IsotropicGaussian(0.01), "biases_init": Constant(0.0)}
    inits = {"weights_init": IsotropicGaussian(0.01), "biases_init": Constant(0.0)}

    # setup the reader and writer
    if attention:
        read_N, write_N = (2, 5)  # resolution of reader and writer
        read_dim = 2 * read_N ** 2  # total number of "pixels" read by reader
        reader_mlp = AttentionReader2d(x_dim=x_dim, dec_dim=dec_dim, width=28, height=28, N=read_N, **inits)
        writer_mlp = AttentionWriter(input_dim=dec_dim, output_dim=x_dim, width=28, height=28, N=write_N, **inits)
        att_tag = "YA"
    else:
        read_dim = 2 * x_dim
        reader_mlp = Reader(x_dim=x_dim, dec_dim=dec_dim, **inits)
        writer_mlp = MLP([None, None], [dec_dim, write_dim, x_dim], name="writer_mlp", **inits)
        att_tag = "NA"

    # setup the infinite mixture initialization model
    mix_enc_mlp = CondNet([Tanh()], [x_dim, 250, mix_dim], name="mix_enc_mlp", **inits)
    mix_dec_mlp = MLP([Tanh(), Tanh()], [mix_dim, 250, (2 * enc_dim + 2 * dec_dim)], name="mix_dec_mlp", **inits)
    # setup the components of the sequential generative model
    enc_mlp_in = MLP([Identity()], [(read_dim + dec_dim), 4 * enc_dim], name="enc_mlp_in", **inits)
    dec_mlp_in = MLP([Identity()], [z_dim, 4 * dec_dim], name="dec_mlp_in", **inits)
    enc_mlp_out = CondNet([], [enc_dim, z_dim], name="enc_mlp_out", **inits)
    dec_mlp_out = CondNet([], [dec_dim, z_dim], name="dec_mlp_out", **inits)
    enc_rnn = BiasedLSTM(dim=enc_dim, ig_bias=2.0, fg_bias=2.0, name="enc_rnn", **rnninits)
    dec_rnn = BiasedLSTM(dim=dec_dim, ig_bias=2.0, fg_bias=2.0, name="dec_rnn", **rnninits)

    draw = IMoOLDrawModels(
        n_iter,
        step_type=step_type,  # step_type can be 'add' or 'jump'
        mix_enc_mlp=mix_enc_mlp,
        mix_dec_mlp=mix_dec_mlp,
        reader_mlp=reader_mlp,
        enc_mlp_in=enc_mlp_in,
        enc_mlp_out=enc_mlp_out,
        enc_rnn=enc_rnn,
        dec_mlp_in=dec_mlp_in,
        dec_mlp_out=dec_mlp_out,
        dec_rnn=dec_rnn,
        writer_mlp=writer_mlp,
    )
    draw.initialize()

    # build the cost gradients, training function, samplers, etc.
    draw.build_model_funcs()
    draw.build_extra_funcs()

    # load parameters from a pre-trained model into the compiled model
    draw.load_model_params(f_name="TBOLM_GEN_PARAMS_{}_{}.pkl".format(step_type, att_tag))

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    print("Beginning to fine-tune the model...")
    out_file = open("TBOLM_GEN_RESULTS_{}_{}_FT.txt".format(step_type, att_tag), "wb")
    costs = [0.0 for i in range(10)]
    learn_rate = 0.0001
    momentum = 0.9
    batch_idx = np.arange(batch_size) + va_samples
    for i in range(50000):
        if ((i + 1) % 2000) == 0:
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if np.max(batch_idx) >= va_samples:
            # we finished an "epoch", so we rejumble the training set
            Xva = row_shuffle(Xva)
            batch_idx = np.arange(batch_size)

        # set sgd and objective function hyperparams for this update
        zero_ary = np.zeros((1,))
        draw.lr.set_value(to_fX(zero_ary + learn_rate))
        draw.mom_1.set_value(to_fX(zero_ary + momentum))
        draw.mom_2.set_value(to_fX(zero_ary + 0.99))

        # perform a minibatch update and record the cost for this batch
        Xb = to_fX(Xva.take(batch_idx, axis=0))
        result = draw.train_var(Xb, Xb)  # only train variational parameters
        costs = [(costs[j] + result[j]) for j in range(len(result))]

        # diagnostics
        if (i % 200) == 0:
            costs = [(v / 200.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    total_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    nll_term  : {0:.4f}".format(costs[2])
            str5 = "    kld_q2p   : {0:.4f}".format(costs[3])
            str6 = "    kld_p2q   : {0:.4f}".format(costs[4])
            str7 = "    reg_term  : {0:.4f}".format(costs[5])
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6, str7])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if (i % 1000) == 0:
            # compute a small-sample estimate of NLL bound on validation set
            Xb = to_fX(Xva[:5000])
            va_costs = draw.compute_nll_bound(Xb, Xb)
            str1 = "    va_nll_bound : {}".format(va_costs[1])
            str2 = "    va_nll_term  : {}".format(va_costs[2])
            str3 = "    va_kld_q2p   : {}".format(va_costs[3])
            joint_str = "\n".join([str1, str2, str3])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
def test_one_stage_model():
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 128
    batch_reps = 1

    ###############################################
    # Setup some parameters for the OneStageModel #
    ###############################################
    x_dim = Xtr.shape[1]
    z_dim = 64
    x_type = 'bernoulli'
    xin_sym = T.matrix('xin_sym')

    ###############
    # p_x_given_z #
    ###############
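    # Decoder layout: two fc layers expand z into a 7x7x128 feature map, then
    # fractional-stride ('half') convolutions upsample 7x7 -> 14x14 -> 28x28.
    # shape_func_out / shape_func_in reshape between the flat fc representation
    # and the (batch, channels, rows, cols) layout the conv layers expect.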
    params = {}
    shared_config = \
    [ {'layer_type': 'fc',
       'in_chans': z_dim,
       'out_chans': 256,
       'activation': relu_actfun,
       'apply_bn': True}, \
      {'layer_type': 'fc',
       'in_chans': 256,
       'out_chans': 7*7*128,
       'activation': relu_actfun,
       'apply_bn': True,
       'shape_func_out': lambda x: T.reshape(x, (-1, 128, 7, 7))}, \
      {'layer_type': 'conv',
       'in_chans': 128, # in shape:  (batch, 128, 7, 7)
       'out_chans': 64, # out shape: (batch, 64, 14, 14)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'half',
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'conv',
       'in_chans': 64, # in shape:  (batch, 64, 14, 14)
       'out_chans': 1, # out shape: (batch, 1, 28, 28)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'half',
       'apply_bn': False,
       'shape_func_out': lambda x: T.flatten(x, 2)}, \
      {'layer_type': 'conv',
       'in_chans': 64,
       'out_chans': 1,
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'half',
       'apply_bn': False,
       'shape_func_out': lambda x: T.flatten(x, 2)} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    p_x_given_z = HydraNet(rng=rng, Xd=xin_sym, \
            params=params, shared_param_dicts=None)
    p_x_given_z.init_biases(0.0)
    ###############
    # q_z_given_x #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'conv',
       'in_chans': 1,   # in shape:  (batch, 784)
       'out_chans': 64, # out shape: (batch, 64, 14, 14)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'double',
       'apply_bn': True,
       'shape_func_in': lambda x: T.reshape(x, (-1, 1, 28, 28))}, \
      {'layer_type': 'conv',
       'in_chans': 64,   # in shape:  (batch, 64, 14, 14)
       'out_chans': 128, # out shape: (batch, 128, 7, 7)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'double',
       'apply_bn': True,
       'shape_func_out': lambda x: T.flatten(x, 2)}, \
      {'layer_type': 'fc',
       'in_chans': 128*7*7,
       'out_chans': 256,
       'activation': relu_actfun,
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'fc',
       'in_chans': 256,
       'out_chans': z_dim,
       'activation': relu_actfun,
       'apply_bn': False}, \
      {'layer_type': 'fc',
       'in_chans': 256,
       'out_chans': z_dim,
       'activation': relu_actfun,
       'apply_bn': False} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    q_z_given_x = HydraNet(rng=rng, Xd=xin_sym, \
            params=params, shared_param_dicts=None)
    q_z_given_x.init_biases(0.0)

    ##############################################################
    # Define parameters for the OneStageModel, and initialize it #
    ##############################################################
    print("Building the OneStageModel...")
    osm_params = {}
    osm_params['x_type'] = x_type
    osm_params['obs_transform'] = 'sigmoid'
    OSM = OneStageModel(rng=rng,
                        x_in=xin_sym,
                        x_dim=x_dim,
                        z_dim=z_dim,
                        p_x_given_z=p_x_given_z,
                        q_z_given_x=q_z_given_x,
                        params=osm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format("OSM_TEST")
    out_file = open(log_name, 'w')
    costs = [0. for i in range(10)]
    learn_rate = 0.0005
    momentum = 0.9
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(500000):
        scale = min(0.5, ((i + 1) / 5000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        Xb = to_fX(Xtr.take(batch_idx, axis=0))
        #Xb = binarize_data(Xtr.take(batch_idx, axis=0))
        # set sgd and objective function hyperparams for this update
        OSM.set_sgd_params(lr=scale*learn_rate, \
                           mom_1=(scale*momentum), mom_2=0.98)
        OSM.set_lam_nll(lam_nll=1.0)
        OSM.set_lam_kld(lam_kld=1.0)
        OSM.set_lam_l2w(1e-5)
        # perform a minibatch update and record the cost for this batch
        result = OSM.train_joint(Xb, batch_reps)
        costs = [(costs[j] + result[j]) for j in range(len(result))]
        if ((i % 500) == 0):
            costs = [(v / 500.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    joint_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_cost  : {0:.4f}".format(costs[1])
            str4 = "    kld_cost  : {0:.4f}".format(costs[2])
            str5 = "    reg_cost  : {0:.4f}".format(costs[3])
            joint_str = "\n".join([str1, str2, str3, str4, str5])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if (((i % 5000) == 0) or ((i < 10000) and ((i % 1000) == 0))):
            # draw some independent random samples from the model
            samp_count = 300
            model_samps = OSM.sample_from_prior(samp_count)
            file_name = "OSM_SAMPLES_b{0:d}.png".format(i)
            utils.visualize_samples(model_samps, file_name, num_rows=15)
            # compute free energy estimate for validation samples
            Xva = row_shuffle(Xva)
            fe_terms = OSM.compute_fe_terms(Xva[0:5000], 20)
            fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
            out_str = "    nll_bound : {0:.4f}".format(fe_mean)
            print(out_str)
            out_file.write(out_str + "\n")
            out_file.flush()
    return
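
# Assuming 'half' conv_stride means 2x fractional-stride upsampling and
# 'double' means stride-2 downsampling (as the shape comments in the configs
# above indicate), the spatial sizes can be sanity-checked without Theano:
def _conv_size_sketch(size, conv_stride):
    return (size * 2) if (conv_stride == 'half') else (size // 2)

assert _conv_size_sketch(7, 'half') == 14       # decoder: 7 -> 14 -> 28
assert _conv_size_sketch(14, 'half') == 28
assert _conv_size_sketch(28, 'double') == 14    # encoder: 28 -> 14 -> 7
assert _conv_size_sketch(14, 'double') == 7
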
def test_mnist_results(step_type='add',
                       imp_steps=6,
                       occ_dim=15,
                       drop_prob=0.0):
    #########################################
    # Format the result tag more thoroughly #
    #########################################
    dp_int = int(100.0 * drop_prob)
    result_tag = "{}GPSI_OD{}_DP{}_IS{}_{}_NA".format(RESULT_PATH, occ_dim, dp_int, imp_steps, step_type)

    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]

    batch_size = 250
    batch_reps = 1
    all_pix_mean = np.mean(np.mean(Xtr, axis=1))
    data_mean = to_fX( all_pix_mean * np.ones((Xtr.shape[1],)) )

    # Load parameters from a previously trained model
    print("Testing model load from file...")
    GPSI = load_gpsimputer_from_file(f_name="{}_PARAMS.pkl".format(result_tag), \
                                     rng=rng)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_FINAL_RESULTS_NEW.txt".format(result_tag)
    out_file = open(log_name, 'w')

    Xva = row_shuffle(Xva)
    # record an estimate of performance on the test set
    str0 = "GUIDED SAMPLE BOUND:"
    print(str0)
    xi, xo, xm = construct_masked_data(Xva[:5000], drop_prob=drop_prob, \
                                       occ_dim=occ_dim, data_mean=data_mean)
    nll_0, kld_0 = GPSI.compute_fe_terms(xi, xo, xm, sample_count=10, \
                                         use_guide_policy=True)
    xi, xo, xm = construct_masked_data(Xva[5000:], drop_prob=drop_prob, \
                                       occ_dim=occ_dim, data_mean=data_mean)
    nll_1, kld_1 = GPSI.compute_fe_terms(xi, xo, xm, sample_count=10, \
                                         use_guide_policy=True)
    nll = np.concatenate((nll_0, nll_1))
    kld = np.concatenate((kld_0, kld_1))
    vfe = np.mean(nll) + np.mean(kld)
    str1 = "    va_nll_bound : {}".format(vfe)
    str2 = "    va_nll_term  : {}".format(np.mean(nll))
    str3 = "    va_kld_q2p   : {}".format(np.mean(kld))
    joint_str = "\n".join([str0, str1, str2, str3])
    print(joint_str)
    out_file.write(joint_str+"\n")
    out_file.flush()
    # record an estimate of performance on the test set
    str0 = "UNGUIDED SAMPLE BOUND:"
    print(str0)
    xi, xo, xm = construct_masked_data(Xva[:5000], drop_prob=drop_prob, \
                                       occ_dim=occ_dim, data_mean=data_mean)
    nll_0, kld_0 = GPSI.compute_fe_terms(xi, xo, xm, sample_count=10, \
                                         use_guide_policy=False)
    xi, xo, xm = construct_masked_data(Xva[5000:], drop_prob=drop_prob, \
                                       occ_dim=occ_dim, data_mean=data_mean)
    nll_1, kld_1 = GPSI.compute_fe_terms(xi, xo, xm, sample_count=10, \
                                         use_guide_policy=False)
    nll = np.concatenate((nll_0, nll_1))
    kld = np.concatenate((kld_0, kld_1))
    str1 = "    va_nll_bound : {}".format(np.mean(nll))
    str2 = "    va_nll_term  : {}".format(np.mean(nll))
    str3 = "    va_kld_q2p   : {}".format(np.mean(kld))
    joint_str = "\n".join([str0, str1, str2, str3])
    print(joint_str)
    out_file.write(joint_str+"\n")
    out_file.flush()
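
# The two-chunk evaluation above (Xva[:5000], then Xva[5000:]) presumably just
# bounds memory use while still covering the whole set; a generic sketch of
# the pattern, where eval_fn returns per-example (nll, kld) arrays:
def _chunked_eval_sketch(X, eval_fn, chunk=5000):
    parts = [eval_fn(X[i:i + chunk]) for i in range(0, X.shape[0], chunk)]
    nll = np.concatenate([p[0] for p in parts])
    kld = np.concatenate([p[1] for p in parts])
    return np.mean(nll) + np.mean(kld)
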
def test_imoold_generation(step_type='add', attention=False):
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 200

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    x_dim = Xtr.shape[1]
    write_dim = 250
    enc_dim = 250
    dec_dim = 250
    mix_dim = 25
    z_dim = 100
    if attention:
        n_iter = 64
    else:
        n_iter = 32

    rnninits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }
    inits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }

    # setup the reader and writer
    if attention:
        read_N, write_N = (2, 5) # resolution of reader and writer
        read_dim = 2*read_N**2   # total number of "pixels" read by reader
        reader_mlp = AttentionReader2d(x_dim=x_dim, dec_dim=dec_dim,
                                 width=28, height=28,
                                 N=read_N, **inits)
        writer_mlp = AttentionWriter(input_dim=dec_dim, output_dim=x_dim,
                                 width=28, height=28,
                                 N=write_N, **inits)
        att_tag = "YA"
    else:
        read_dim = 2*x_dim
        reader_mlp = Reader(x_dim=x_dim, dec_dim=dec_dim, **inits)
        writer_mlp = MLP([None, None], [dec_dim, write_dim, x_dim], \
                         name="writer_mlp", **inits)
        att_tag = "NA"

    # setup the infinite mixture initialization model
    mix_enc_mlp = CondNet([Tanh()], [x_dim, 250, mix_dim], \
                          name="mix_enc_mlp", **inits)
    mix_dec_mlp = MLP([Tanh(), Tanh()], \
                      [mix_dim, 250, (2*enc_dim + 2*dec_dim)], \
                      name="mix_dec_mlp", **inits)
    # setup the components of the sequential generative model
    enc_mlp_in = MLP([Identity()], [(read_dim + dec_dim), 4*enc_dim], \
                     name="enc_mlp_in", **inits)
    dec_mlp_in = MLP([Identity()], [               z_dim, 4*dec_dim], \
                     name="dec_mlp_in", **inits)
    enc_mlp_out = CondNet([], [enc_dim, z_dim], name="enc_mlp_out", **inits)
    dec_mlp_out = CondNet([], [dec_dim, z_dim], name="dec_mlp_out", **inits)
    enc_rnn = BiasedLSTM(dim=enc_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="enc_rnn", **rnninits)
    dec_rnn = BiasedLSTM(dim=dec_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="dec_rnn", **rnninits)

    draw = IMoOLDrawModels(
                n_iter,
                step_type=step_type, # step_type can be 'add' or 'jump'
                mix_enc_mlp=mix_enc_mlp,
                mix_dec_mlp=mix_dec_mlp,
                reader_mlp=reader_mlp,
                enc_mlp_in=enc_mlp_in,
                enc_mlp_out=enc_mlp_out,
                enc_rnn=enc_rnn,
                dec_mlp_in=dec_mlp_in,
                dec_mlp_out=dec_mlp_out,
                dec_rnn=dec_rnn,
                writer_mlp=writer_mlp)
    draw.initialize()

    compile_start_time = time.time()

    # build the cost gradients, training function, samplers, etc.
    draw.build_model_funcs()

    compile_end_time = time.time()
    compile_minutes = (compile_end_time - compile_start_time) / 60.0
    print("THEANO COMPILE TIME (MIN): {}".format(compile_minutes))

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    print("Beginning to train the model...")
    out_file = open("TBOLM_GEN_RESULTS_{}_{}.txt".format(step_type, att_tag), 'wb')
    costs = [0. for i in range(10)]
    learn_rate = 0.00015
    momentum = 0.9
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(250000):
        scale = min(1.0, ((i+1) / 5000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
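        # NOTE: batch_idx starts out past the end of the training set, so
        # the reshuffle branch above also fires on the very first update;
        # after that it fires once per pass over the data.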

        # set sgd and objective function hyperparams for this update
        zero_ary = np.zeros((1,))
        draw.lr.set_value(to_fX(zero_ary + scale*learn_rate))
        draw.mom_1.set_value(to_fX(zero_ary + scale*momentum))
        draw.mom_2.set_value(to_fX(zero_ary + 0.98))

        # perform a minibatch update and record the cost for this batch
        Xb = to_fX(Xtr.take(batch_idx, axis=0))
        draw.set_rnn_noise(rnn_noise=0.02)
        result = draw.train_joint(Xb, Xb)
        costs = [(costs[j] + result[j]) for j in range(len(result))]

        # diagnostics
        if ((i % 200) == 0):
            costs = [(v / 200.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    total_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    nll_term  : {0:.4f}".format(costs[2])
            str5 = "    kld_q2p   : {0:.4f}".format(costs[3])
            str6 = "    kld_p2q   : {0:.4f}".format(costs[4])
            str7 = "    reg_term  : {0:.4f}".format(costs[5])
            str8 = "    step_klds : {0:s}".format(np.array_str(costs[6], precision=2))
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6, str7, str8])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if ((i % 1000) == 0):
            draw.save_model_params("TBOLM_GEN_PARAMS_{}_{}.pkl".format(step_type, att_tag))
            # compute a small-sample estimate of NLL bound on validation set
            Xva = row_shuffle(Xva)
            Xb = to_fX(Xva[:5000])
            draw.set_rnn_noise(rnn_noise=0.0)
            va_costs = draw.compute_nll_bound(Xb, Xb)
            str1 = "    va_nll_bound : {}".format(va_costs[1])
            str2 = "    va_nll_term  : {}".format(va_costs[2])
            str3 = "    va_kld_q2p   : {}".format(va_costs[3])
            joint_str = "\n".join([str1, str2, str3])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            # draw some independent samples from the model
            samples, x_logodds = draw.do_sample(16*16)
            utils.plot_kde_histogram(x_logodds[-1,:,:], "TBOLM-log_odds_hist.png", bins=30)
            n_iter, N, D = samples.shape
            samples = samples.reshape( (n_iter, N, 28, 28) )
            for j in range(n_iter):
                img = img_grid(samples[j,:,:,:])
                img.save("TBOLM-gen-samples-%03d.png" % (j,))
def test_lstm_structpred(step_type='add', use_pol=True, use_binary=False):
    ###########################################
    # Make a tag for identifying result files #
    ###########################################
    pol_tag = "P1" if use_pol else "P0"
    bin_tag = "B1" if use_binary else "B0"
    res_tag = "STRUCT_PRED_RESULTS/SP_LSTM_{}_{}_{}".format(step_type, pol_tag, bin_tag)

    if use_binary:
        ############################
        # Get binary training data #
        ############################
        rng = np.random.RandomState(1234)
        Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
        #Xtr = np.vstack((Xtr, Xva))
        #Xva = Xte
    else:
        ################################
        # Get continuous training data #
        ################################
        rng = np.random.RandomState(1234)
        dataset = 'data/mnist.pkl.gz'
        datasets = load_udm(dataset, as_shared=False, zero_mean=False)
        Xtr = datasets[0][0]
        Xva = datasets[1][0]
        Xte = datasets[2][0]
        #Xtr = np.concatenate((Xtr, Xva), axis=0)
        #Xva = Xte
        Xtr = to_fX(shift_and_scale_into_01(Xtr))
        Xva = to_fX(shift_and_scale_into_01(Xva))
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 200


    ########################################################
    # Split data into "observation" and "prediction" parts #
    ########################################################
    obs_cols = 14             # number of columns to observe
    pred_cols = 28 - obs_cols # number of columns to predict
    x_dim = obs_cols * 28     # dimensionality of observations
    y_dim = pred_cols * 28    # dimensionality of predictions
    Xtr, Ytr = img_split(Xtr, im_dim=(28, 28), split_col=obs_cols, transposed=True)
    Xva, Yva = img_split(Xva, im_dim=(28, 28), split_col=obs_cols, transposed=True)
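    # img_split presumably cuts each flattened 28x28 image at column
    # `split_col`, returning the observed columns as X and the remaining
    # columns as the prediction target Y (transposed=True hinting that the
    # split runs over rows of the transposed image); x_dim and y_dim above
    # are sized to match this split.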

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    read_dim = 128
    write_dim = 128
    mlp_dim = 128
    rnn_dim = 128
    z_dim = 64
    n_iter = 15

    rnninits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }
    inits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }

    # setup reader/writer models
    reader_mlp = MLP([Rectifier(), Tanh()], [x_dim, mlp_dim, read_dim],
                     name="reader_mlp", **inits)
    writer_mlp = MLP([Rectifier(), None], [rnn_dim, mlp_dim, y_dim],
                     name="writer_mlp", **inits)

    # setup submodels for processing LSTM inputs
    pol_inp_dim = y_dim + read_dim + rnn_dim
    var_inp_dim = y_dim + y_dim + read_dim + rnn_dim
    pol_mlp_in = MLP([Identity()], [pol_inp_dim, 4*rnn_dim],
                     name="pol_mlp_in", **inits)
    var_mlp_in = MLP([Identity()], [var_inp_dim, 4*rnn_dim],
                     name="var_mlp_in", **inits)
    dec_mlp_in = MLP([Identity()], [z_dim, 4*rnn_dim],
                     name="dec_mlp_in", **inits)

    # setup submodels for turning LSTM states into conditionals over z
    pol_mlp_out = CondNet([], [rnn_dim, z_dim], name="pol_mlp_out", **inits)
    var_mlp_out = CondNet([], [rnn_dim, z_dim], name="var_mlp_out", **inits)
    dec_mlp_out = CondNet([], [rnn_dim, z_dim], name="dec_mlp_out", **inits)

    # setup the LSTMs for primary policy, guide policy, and shared dynamics
    pol_rnn = BiasedLSTM(dim=rnn_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="pol_rnn", **rnninits)
    var_rnn = BiasedLSTM(dim=rnn_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="var_rnn", **rnninits)
    dec_rnn = BiasedLSTM(dim=rnn_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="dec_rnn", **rnninits)

    model = IRStructPredModel(
                n_iter,
                step_type=step_type,
                use_pol=use_pol,
                reader_mlp=reader_mlp,
                writer_mlp=writer_mlp,
                pol_mlp_in=pol_mlp_in,
                pol_mlp_out=pol_mlp_out,
                pol_rnn=pol_rnn,
                var_mlp_in=var_mlp_in,
                var_mlp_out=var_mlp_out,
                var_rnn=var_rnn,
                dec_mlp_in=dec_mlp_in,
                dec_mlp_out=dec_mlp_out,
                dec_rnn=dec_rnn)
    model.initialize()

    compile_start_time = time.time()

    # build the cost gradients, training function, samplers, etc.
    model.build_sampling_funcs()
    print("Testing model sampler...")
    # draw some independent samples from the model
    samp_count = 10
    samp_reps = 3
    x_in = Xtr[:samp_count,:].repeat(samp_reps, axis=0)
    y_in = Ytr[:samp_count,:].repeat(samp_reps, axis=0)
    x_samps, y_samps = model.sample_model(x_in, y_in, sample_source='p')
    # visualize sample prediction trajectories
    img_seq = seq_img_join(x_samps, y_samps, im_dim=(28,28), transposed=True)
    seq_len = len(img_seq)
    samp_count = img_seq[0].shape[0]
    seq_samps = np.zeros((seq_len*samp_count, img_seq[0].shape[1]))
    idx = 0
    for s1 in range(samp_count):
        for s2 in range(seq_len):
            seq_samps[idx] = img_seq[s2][s1]
            idx += 1
    file_name = "{0:s}_samples_b{1:d}.png".format(res_tag, 0)
    utils.visualize_samples(seq_samps, file_name, num_rows=samp_count)

    model.build_model_funcs()

    compile_end_time = time.time()
    compile_minutes = (compile_end_time - compile_start_time) / 60.0
    print("THEANO COMPILE TIME (MIN): {}".format(compile_minutes))

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    print("Beginning to train the model...")
    out_file = open("{}_results.txt".format(res_tag), 'wb')
    costs = [0. for i in range(10)]
    learn_rate = 0.0002
    momentum = 0.9
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(300000):
        scale = min(1.0, ((i+1) / 5000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr, Ytr = row_shuffle(Xtr, Ytr)
            batch_idx = np.arange(batch_size)
        # set sgd and objective function hyperparams for this update
        model.set_sgd_params(lr=scale*learn_rate, mom_1=scale*momentum, mom_2=0.98)
        model.set_lam_kld(lam_kld_q2p=1.0, lam_kld_p2q=0.1)
        model.set_grad_noise(grad_noise=0.02)
        # perform a minibatch update and record the cost for this batch
        Xb = to_fX(Xtr.take(batch_idx, axis=0))
        Yb = to_fX(Ytr.take(batch_idx, axis=0))
        result = model.train_joint(Xb, Yb)
        costs = [(costs[j] + result[j]) for j in range(len(result))]

        # diagnostics
        if ((i % 250) == 0):
            costs = [(v / 250.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    total_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    nll_term  : {0:.4f}".format(costs[2])
            str5 = "    kld_q2p   : {0:.4f}".format(costs[3])
            str6 = "    kld_p2q   : {0:.4f}".format(costs[4])
            str7 = "    reg_term  : {0:.4f}".format(costs[5])
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6, str7])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if ((i % 1000) == 0):
            model.save_model_params("{}_params.pkl".format(res_tag))
            # compute a small-sample estimate of NLL bound on validation set
            Xva, Yva = row_shuffle(Xva, Yva)
            Xb = to_fX(Xva[:5000])
            Yb = to_fX(Yva[:5000])
            va_costs = model.compute_nll_bound(Xb, Yb)
            str1 = "    va_nll_bound : {}".format(va_costs[1])
            str2 = "    va_nll_term  : {}".format(va_costs[2])
            str3 = "    va_kld_q2p   : {}".format(va_costs[3])
            joint_str = "\n".join([str1, str2, str3])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            # draw some independent samples from the model
            samp_count = 10
            samp_reps = 3
            x_in = Xva[:samp_count,:].repeat(samp_reps, axis=0)
            y_in = Yva[:samp_count,:].repeat(samp_reps, axis=0)
            x_samps, y_samps = model.sample_model(x_in, y_in, sample_source='p')
            # visualize sample prediction trajectories
            img_seq = seq_img_join(x_samps, y_samps, im_dim=(28,28), transposed=True)
            seq_len = len(img_seq)
            samp_count = img_seq[0].shape[0]
            seq_samps = np.zeros((seq_len*samp_count, img_seq[0].shape[1]))
            idx = 0
            for s1 in range(samp_count):
                for s2 in range(seq_len):
                    if use_binary:
                        seq_samps[idx] = binarize_data(img_seq[s2][s1])
                    else:
                        seq_samps[idx] = img_seq[s2][s1]
                    idx += 1
            file_name = "{0:s}_samples_b{1:d}.png".format(res_tag, i)
            utils.visualize_samples(seq_samps, file_name, num_rows=samp_count)
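
# Both sampling blocks in test_lstm_structpred flatten a per-step list of
# sample batches into one tall array, interleaving the steps of each sample
# so that visualize_samples renders one trajectory per row-group. A
# self-contained sketch of that reshuffling (illustrative only), assuming
# img_seq is a list of seq_len arrays of shape (samp_count, D):
import numpy as np  # repeated here so the sketch stands alone

def _interleave_sample_sequences(img_seq):
    seq_len = len(img_seq)
    samp_count = img_seq[0].shape[0]
    out = np.zeros((seq_len * samp_count, img_seq[0].shape[1]))
    idx = 0
    for s1 in range(samp_count):        # one row-group per sample...
        for s2 in range(seq_len):       # ...holding its full step sequence
            out[idx] = img_seq[s2][s1]
            idx += 1
    return out
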
def test_mnist(step_type='add',
               imp_steps=6,
               occ_dim=15,
               drop_prob=0.0):
    #########################################
    # Format the result tag more thoroughly #
    #########################################
    dp_int = int(100.0 * drop_prob)
    result_tag = "{}GPSI_OD{}_DP{}_IS{}_{}_NA".format(RESULT_PATH, occ_dim, dp_int, imp_steps, step_type)

    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]

    ########################################################
    # Alternative: continuous (non-binarized) MNIST loader #
    ########################################################
    # rng = np.random.RandomState(1234)
    # dataset = 'data/mnist.pkl.gz'
    # datasets = load_udm(dataset, as_shared=False, zero_mean=False)
    # Xtr = datasets[0][0]
    # Xva = datasets[1][0]
    # Xte = datasets[2][0]
    # # Merge validation set and training set, and test on test set.
    # #Xtr = np.concatenate((Xtr, Xva), axis=0)
    # #Xva = Xte
    # Xtr = to_fX(shift_and_scale_into_01(Xtr))
    # Xva = to_fX(shift_and_scale_into_01(Xva))
    # tr_samples = Xtr.shape[0]
    # va_samples = Xva.shape[0]
    batch_size = 200
    batch_reps = 1
    all_pix_mean = np.mean(np.mean(Xtr, axis=1))
    data_mean = to_fX( all_pix_mean * np.ones((Xtr.shape[1],)) )
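    # data_mean is the mean pixel value broadcast to a full image vector;
    # construct_masked_data below presumably uses it to fill in the pixels
    # removed by the occlusion mask.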

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    x_dim = Xtr.shape[1]
    s_dim = x_dim
    h_dim = 50
    z_dim = 100
    init_scale = 0.6

    x_in_sym = T.matrix('x_in_sym')
    x_out_sym = T.matrix('x_out_sym')
    x_mask_sym = T.matrix('x_mask_sym')

    ###############
    # p_h_given_x #
    ###############
    params = {}
    shared_config = [x_dim, 250]
    top_config = [shared_config[-1], h_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = tanh_actfun #relu_actfun
    params['init_scale'] = 'xg' #init_scale
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_h_given_x = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_h_given_x.init_biases(0.0)
    ################
    # p_s0_given_h #
    ################
    params = {}
    shared_config = [h_dim, 250]
    output_config = [s_dim, s_dim, s_dim]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['activation'] = tanh_actfun #relu_actfun
    params['init_scale'] = 'xg' #init_scale
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_s0_given_h = HydraNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_s0_given_h.init_biases(0.0)
    #################
    # p_zi_given_xi #
    #################
    params = {}
    shared_config = [(x_dim + x_dim), 500, 500]
    top_config = [shared_config[-1], z_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = tanh_actfun #relu_actfun
    params['init_scale'] = init_scale
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_zi_given_xi = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_zi_given_xi.init_biases(0.0)
    ###################
    # p_sip1_given_zi #
    ###################
    params = {}
    shared_config = [z_dim, 500, 500]
    output_config = [s_dim, s_dim, s_dim]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['activation'] = tanh_actfun #relu_actfun
    params['init_scale'] = init_scale
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_sip1_given_zi = HydraNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_sip1_given_zi.init_biases(0.0)
    ################
    # p_x_given_si #
    ################
    params = {}
    shared_config = [s_dim]
    output_config = [x_dim, x_dim]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['activation'] = tanh_actfun #relu_actfun
    params['init_scale'] = init_scale
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_x_given_si = HydraNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_x_given_si.init_biases(0.0)
    ###############
    # q_h_given_x #
    ###############
    params = {}
    shared_config = [x_dim, 250]
    top_config = [shared_config[-1], h_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = tanh_actfun #relu_actfun
    params['init_scale'] = 'xg' #init_scale
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    q_h_given_x = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    q_h_given_x.init_biases(0.0)
    #################
    # q_zi_given_xi #
    #################
    params = {}
    shared_config = [(x_dim + x_dim), 500, 500]
    top_config = [shared_config[-1], z_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = tanh_actfun #relu_actfun
    params['init_scale'] = init_scale
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    q_zi_given_xi = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    q_zi_given_xi.init_biases(0.0)

    ###########################################################
    # Define parameters for the GPSImputer, and initialize it #
    ###########################################################
    print("Building the GPSImputer...")
    gpsi_params = {}
    gpsi_params['x_dim'] = x_dim
    gpsi_params['h_dim'] = h_dim
    gpsi_params['z_dim'] = z_dim
    gpsi_params['s_dim'] = s_dim
    # switch between direct construction and construction via p_x_given_si
    gpsi_params['use_p_x_given_si'] = False
    gpsi_params['imp_steps'] = imp_steps
    gpsi_params['step_type'] = step_type
    gpsi_params['x_type'] = 'bernoulli'
    gpsi_params['obs_transform'] = 'sigmoid'
    GPSI = GPSImputerWI(rng=rng,
            x_in=x_in_sym, x_out=x_out_sym, x_mask=x_mask_sym, \
            p_h_given_x=p_h_given_x, \
            p_s0_given_h=p_s0_given_h, \
            p_zi_given_xi=p_zi_given_xi, \
            p_sip1_given_zi=p_sip1_given_zi, \
            p_x_given_si=p_x_given_si, \
            q_h_given_x=q_h_given_x, \
            q_zi_given_xi=q_zi_given_xi, \
            params=gpsi_params, \
            shared_param_dicts=None)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format(result_tag)
    out_file = open(log_name, 'w')
    costs = [0. for i in range(10)]
    learn_rate = 0.0002
    momentum = 0.5
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(250000):
        scale = min(1.0, ((i+1) / 5000.0))
        lam_scale = 1.0 - min(1.0, ((i+1) / 100000.0)) # decays from 1.0->0.0
        if (((i + 1) % 15000) == 0):
            learn_rate = learn_rate * 0.93
        if (i > 10000):
            momentum = 0.90
        else:
            momentum = 0.75
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        # set sgd and objective function hyperparams for this update
        GPSI.set_sgd_params(lr=scale*learn_rate, \
                            mom_1=scale*momentum, mom_2=0.98)
        GPSI.set_train_switch(1.0)
        GPSI.set_lam_nll(lam_nll=1.0)
        GPSI.set_lam_kld(lam_kld_p=0.05, lam_kld_q=0.95, \
                         lam_kld_g=(0.1 * lam_scale), lam_kld_s=(0.1 * lam_scale))
        GPSI.set_lam_l2w(1e-5)
        # perform a minibatch update and record the cost for this batch
        xb = to_fX( Xtr.take(batch_idx, axis=0) )
        xi, xo, xm = construct_masked_data(xb, drop_prob=drop_prob, \
                                        occ_dim=occ_dim, data_mean=data_mean)
        result = GPSI.train_joint(xi, xo, xm, batch_reps)
        # do diagnostics and general training tracking
        costs = [(costs[j] + result[j]) for j in range(len(result)-1)]
        if ((i % 250) == 0):
            costs = [(v / 250.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    joint_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    nll_cost  : {0:.4f}".format(costs[2])
            str5 = "    kld_cost  : {0:.4f}".format(costs[3])
            str6 = "    reg_cost  : {0:.4f}".format(costs[4])
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if ((i % 1000) == 0):
            Xva = row_shuffle(Xva)
            # record an estimate of performance on the test set
            xi, xo, xm = construct_masked_data(Xva[0:5000], drop_prob=drop_prob, \
                                               occ_dim=occ_dim, data_mean=data_mean)
            nll, kld = GPSI.compute_fe_terms(xi, xo, xm, sample_count=10)
            vfe = np.mean(nll) + np.mean(kld)
            str1 = "    va_nll_bound : {}".format(vfe)
            str2 = "    va_nll_term  : {}".format(np.mean(nll))
            str3 = "    va_kld_q2p   : {}".format(np.mean(kld))
            joint_str = "\n".join([str1, str2, str3])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
        if ((i % 2000) == 0):
            GPSI.save_to_file("{}_PARAMS.pkl".format(result_tag))
            # Get some validation samples for evaluating model performance
            xb = to_fX( Xva[0:100] )
            xi, xo, xm = construct_masked_data(xb, drop_prob=drop_prob, \
                                    occ_dim=occ_dim, data_mean=data_mean)
            xi = np.repeat(xi, 2, axis=0)
            xo = np.repeat(xo, 2, axis=0)
            xm = np.repeat(xm, 2, axis=0)
            # draw some sample imputations from the model
            samp_count = xi.shape[0]
            _, model_samps = GPSI.sample_imputer(xi, xo, xm, use_guide_policy=False)
            seq_len = len(model_samps)
            seq_samps = np.zeros((seq_len*samp_count, model_samps[0].shape[1]))
            idx = 0
            for s1 in range(samp_count):
                for s2 in range(seq_len):
                    seq_samps[idx] = model_samps[s2][s1]
                    idx += 1
            file_name = "{0:s}_samples_ng_b{1:d}.png".format(result_tag, i)
            utils.visualize_samples(seq_samps, file_name, num_rows=20)
            # show KLds and NLLs on a step-by-step basis
            xb = to_fX( Xva[0:1000] )
            xi, xo, xm = construct_masked_data(xb, drop_prob=drop_prob, \
                                    occ_dim=occ_dim, data_mean=data_mean)
            step_costs = GPSI.compute_per_step_cost(xi, xo, xm)
            step_nlls = step_costs[0]
            step_klds = step_costs[1]
            step_nums = np.arange(step_nlls.shape[0])
            file_name = "{0:s}_NLL_b{1:d}.png".format(result_tag, i)
            utils.plot_stem(step_nums, step_nlls, file_name)
            file_name = "{0:s}_KLD_b{1:d}.png".format(result_tag, i)
            utils.plot_stem(step_nums, step_klds, file_name)
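
# A rough, self-contained sketch of what construct_masked_data is assumed
# to do, inferred from how it is called above: given complete inputs xb,
# return (xi, xo, xm) where xm is a binary mask of visible pixels, xo is
# the complete target, and xi is xb with hidden pixels filled in by
# data_mean. The random-drop plus square-occlusion layout is a guess at
# the library's behavior, not its actual implementation:
import numpy as np  # repeated here so the sketch stands alone

def _construct_masked_data_sketch(xb, drop_prob, occ_dim, data_mean,
                                  rng=np.random):
    n, d = xb.shape
    im_w = int(np.sqrt(d))                        # assumes square images
    # drop pixels independently with probability drop_prob
    xm = (rng.uniform(size=(n, d)) > drop_prob).astype(xb.dtype)
    if occ_dim > 0:
        # knock out one random occ_dim x occ_dim square per image
        mask_img = xm.reshape((n, im_w, im_w))
        for k in range(n):
            r = rng.randint(0, im_w - occ_dim + 1)
            c = rng.randint(0, im_w - occ_dim + 1)
            mask_img[k, r:r+occ_dim, c:c+occ_dim] = 0.0
        xm = mask_img.reshape((n, d))
    xo = xb.copy()
    xi = (xm * xb) + ((1.0 - xm) * data_mean)     # fill hidden pixels
    return xi, xo, xm
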
def test_two_stage_model2():
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 128
    batch_reps = 1

    ###############################################
    # Setup some parameters for the TwoStageModel #
    ###############################################
    x_dim = Xtr.shape[1]
    z_dim = 50
    h_dim = 100
    x_type = 'bernoulli'

    # some InfNet instances to build the TwoStageModel from
    xin_sym = T.matrix('xin_sym')
    xout_sym = T.matrix('xout_sym')

    ###############
    # p_h_given_z #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'fc',
       'in_chans': z_dim,
       'out_chans': 100,
       'activation': tanh_actfun,
       'apply_bn': True}, \
      {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': 100,
       'activation': tanh_actfun,
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': h_dim,
       'activation': tanh_actfun,
       'apply_bn': False}, \
      {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': h_dim,
       'activation': tanh_actfun,
       'apply_bn': False} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    p_h_given_z = HydraNet(rng=rng,
                           Xd=xin_sym,
                           params=params,
                           shared_param_dicts=None)
    p_h_given_z.init_biases(0.0)
    ###############
    # p_x_given_h #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'fc',
       'in_chans': h_dim,
       'out_chans': 200,
       'activation': tanh_actfun,
       'apply_bn': True}, \
      {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': 200,
       'activation': tanh_actfun,
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': x_dim,
       'activation': tanh_actfun,
       'apply_bn': False}, \
      {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': x_dim,
       'activation': tanh_actfun,
       'apply_bn': False} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    p_x_given_h = HydraNet(rng=rng,
                           Xd=xin_sym,
                           params=params,
                           shared_param_dicts=None)
    p_x_given_h.init_biases(0.0)
    ###############
    # q_h_given_x #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'fc',
       'in_chans': x_dim,
       'out_chans': 200,
       'activation': tanh_actfun,
       'apply_bn': True}, \
      {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': 200,
       'activation': tanh_actfun,
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': h_dim,
       'activation': tanh_actfun,
       'apply_bn': False}, \
      {'layer_type': 'fc',
       'in_chans': 200,
       'out_chans': h_dim,
       'activation': tanh_actfun,
       'apply_bn': False} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    q_h_given_x = HydraNet(rng=rng,
                           Xd=xin_sym,
                           params=params,
                           shared_param_dicts=None)
    q_h_given_x.init_biases(0.0)
    ###############
    # q_z_given_h #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'fc',
       'in_chans': h_dim,
       'out_chans': 100,
       'activation': tanh_actfun,
       'apply_bn': True}, \
      {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': 100,
       'activation': tanh_actfun,
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': z_dim,
       'activation': tanh_actfun,
       'apply_bn': False}, \
      {'layer_type': 'fc',
       'in_chans': 100,
       'out_chans': z_dim,
       'activation': tanh_actfun,
       'apply_bn': False} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    q_z_given_h = HydraNet(rng=rng,
                           Xd=xin_sym,
                           params=params,
                           shared_param_dicts=None)
    q_z_given_h.init_biases(0.0)

    ##############################################################
    # Define parameters for the TwoStageModel, and initialize it #
    ##############################################################
    print("Building the TwoStageModel...")
    tsm_params = {}
    tsm_params['x_type'] = x_type
    tsm_params['obs_transform'] = 'sigmoid'
    TSM = TwoStageModel2(rng=rng,
                         x_in=xin_sym,
                         x_out=xout_sym,
                         x_dim=x_dim,
                         z_dim=z_dim,
                         h_dim=h_dim,
                         q_h_given_x=q_h_given_x,
                         q_z_given_h=q_z_given_h,
                         p_h_given_z=p_h_given_z,
                         p_x_given_h=p_x_given_h,
                         params=tsm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format("TSM2A_TEST")
    out_file = open(log_name, 'w')
    costs = [0. for i in range(10)]
    learn_rate = 0.001
    momentum = 0.9
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(500000):
        scale = min(1.0, ((i + 1) / 5000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        Xb = to_fX(Xtr.take(batch_idx, axis=0))
        #Xb = binarize_data(Xtr.take(batch_idx, axis=0))
        # set sgd and objective function hyperparams for this update
        TSM.set_sgd_params(lr=scale * learn_rate,
                           mom_1=(scale * momentum),
                           mom_2=0.98)
        TSM.set_train_switch(1.0)
        TSM.set_lam_nll(lam_nll=1.0)
        TSM.set_lam_kld(lam_kld_q2p=1.0, lam_kld_p2q=0.0)
        TSM.set_lam_l2w(1e-5)
        # perform a minibatch update and record the cost for this batch
        result = TSM.train_joint(Xb, Xb, batch_reps)
        costs = [(costs[j] + result[j]) for j in range(len(result))]
        if ((i % 500) == 0):
            costs = [(v / 500.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    joint_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_cost  : {0:.4f}".format(costs[1])
            str4 = "    kld_cost  : {0:.4f}".format(costs[2])
            str5 = "    reg_cost  : {0:.4f}".format(costs[3])
            str6 = "    nll       : {0:.4f}".format(np.mean(costs[4]))
            str7 = "    kld_z     : {0:.4f}".format(np.mean(costs[5]))
            str8 = "    kld_h     : {0:.4f}".format(np.mean(costs[6]))
            joint_str = "\n".join(
                [str1, str2, str3, str4, str5, str6, str7, str8])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if (((i % 5000) == 0) or ((i < 10000) and ((i % 1000) == 0))):
            # draw some independent random samples from the model
            samp_count = 300
            model_samps = TSM.sample_from_prior(samp_count)
            file_name = "TSM2A_SAMPLES_b{0:d}.png".format(i)
            utils.visualize_samples(model_samps, file_name, num_rows=15)
            # compute free energy estimate for validation samples
            Xva = row_shuffle(Xva)
            fe_terms = TSM.compute_fe_terms(Xva[0:5000], Xva[0:5000], 20)
            fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
            out_str = "    nll_bound : {0:.4f}".format(fe_mean)
            print(out_str)
            out_file.write(out_str + "\n")
            out_file.flush()
    return
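
# Sampling from the trained TwoStageModel2 prior is ancestral: draw
# z ~ N(0, I), map it through p_h_given_z to a Gaussian over h, draw h,
# then map h through p_x_given_h and apply the sigmoid obs_transform to
# get Bernoulli means over pixels. A numpy-only illustration, where mlp_h
# and mlp_x are hypothetical stand-ins for the two-headed
# (mean, log-variance) HydraNet outputs:
import numpy as np  # repeated here so the sketch stands alone

def _two_stage_prior_sample_sketch(mlp_h, mlp_x, z_dim, n_samples, rng):
    z = rng.normal(size=(n_samples, z_dim))
    h_mu, h_logvar = mlp_h(z)                      # Gaussian params for h
    h = h_mu + np.exp(0.5 * h_logvar) * rng.normal(size=h_mu.shape)
    x_logits, _ = mlp_x(h)                         # second head unused here
    return 1.0 / (1.0 + np.exp(-x_logits))         # Bernoulli means for x
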
def test_with_model_init():
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 250

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    x_dim = Xtr.shape[1]
    write_dim = 220
    enc_dim = 260
    dec_dim = 260
    mix_dim = 20
    z_dim = 100
    n_iter = 18
    
    rnninits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }
    inits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }

    # setup the reader and writer
    read_dim = 2*x_dim
    reader_mlp = Reader(x_dim=x_dim, dec_dim=dec_dim, **inits)
    writer_mlp = MLP([None, None], [dec_dim, write_dim, x_dim], \
                     name="writer_mlp", **inits)
    
    # setup the mixture weight sampler
    mix_enc_mlp = CondNet([Tanh()], [x_dim, 250, mix_dim], \
                          name="mix_enc_mlp", **inits)
    mix_dec_mlp = MLP([Tanh(), Tanh()], \
                      [mix_dim, 250, (2*enc_dim + 2*dec_dim)], \
                      name="mix_dec_mlp", **inits)
    # setup the components of the generative DRAW model
    enc_mlp_in = MLP([Identity()], [(read_dim + dec_dim), 4*enc_dim], \
                        name="enc_mlp_in", **inits)
    dec_mlp_in = MLP([Identity()], [               z_dim, 4*dec_dim], \
                        name="dec_mlp_in", **inits)
    enc_mlp_out = CondNet([], [enc_dim, z_dim], name="enc_mlp_out", **inits)
    dec_mlp_out = CondNet([], [dec_dim, z_dim], name="dec_mlp_out", **inits)
    enc_rnn = BiasedLSTM(dim=enc_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="enc_rnn", **rnninits)
    dec_rnn = BiasedLSTM(dim=dec_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="dec_rnn", **rnninits)
    enc_mlp_stop = MLP([Tanh(), None], [(x_dim + dec_dim), 500, 1], \
                       name="enc_mlp_stop", **inits)
    dec_mlp_stop = MLP([Tanh(), None], [dec_dim, 500, 1], \
                       name="dec_mlp_stop", **inits)

    draw = IMoESDrawModels(
                n_iter,
                step_type='add', # step_type can be 'add' or 'jump'
                mix_enc_mlp=mix_enc_mlp,
                mix_dec_mlp=mix_dec_mlp,
                reader_mlp=reader_mlp,
                writer_mlp=writer_mlp,
                enc_mlp_in=enc_mlp_in,
                enc_mlp_out=enc_mlp_out,
                enc_rnn=enc_rnn,
                enc_mlp_stop=enc_mlp_stop,
                dec_mlp_in=dec_mlp_in,
                dec_mlp_out=dec_mlp_out,
                dec_rnn=dec_rnn,
                dec_mlp_stop=dec_mlp_stop)
    draw.initialize()

    # some symbolic vars to represent various inputs/outputs
    x_in_sym = T.matrix('x_in_sym')
    x_out_sym = T.matrix('x_out_sym')

    # collect reconstructions of x produced by the IMoDRAW model
    vfe_cost, cost_all = draw.reconstruct(x_in_sym, x_out_sym)

    # grab handles for all the optimizable parameters in our cost
    cg = ComputationGraph([vfe_cost])
    joint_params = VariableFilter(roles=[PARAMETER])(cg.variables)

    # apply some l2 regularization to the model parameters
    reg_term = (1e-5 * sum([T.sum(p**2.0) for p in joint_params]))
    reg_term.name = "reg_term"

    # compute the full cost w.r.t. which we will optimize
    total_cost = vfe_cost + reg_term
    total_cost.name = "total_cost"

    # Get the gradient of the joint cost for all optimizable parameters
    print("Computing gradients of total_cost...")
    joint_grads = OrderedDict()
    grad_list = T.grad(total_cost, joint_params)
    for i, p in enumerate(joint_params):
        joint_grads[p] = grad_list[i]
    
    # shared var learning rate for generator and inferencer
    zero_ary = to_fX( np.zeros((1,)) )
    lr_shared = theano.shared(value=zero_ary, name='tbm_lr')
    # shared var momentum parameters for generator and inferencer
    mom_1_shared = theano.shared(value=zero_ary, name='tbm_mom_1')
    mom_2_shared = theano.shared(value=zero_ary, name='tbm_mom_2')
    # construct the updates for the generator and inferencer networks
    joint_updates = get_adam_updates(params=joint_params, \
            grads=joint_grads, alpha=lr_shared, \
            beta1=mom_1_shared, beta2=mom_2_shared, \
            mom2_init=1e-4, smoothing=1e-6, max_grad_norm=10.0)
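    # NOTE: the learning rate and momenta live in theano shared variables,
    # so the warm-up/decay schedule in the loop below can retune them via
    # set_value() on every update without recompiling train_joint.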

    # collect the outputs to return from this function
    outputs = [total_cost, vfe_cost, reg_term]
    # compile the theano function
    print("Compiling model training/update function...")
    train_joint = theano.function(inputs=[ x_in_sym, x_out_sym ], \
                                  outputs=outputs, updates=joint_updates)
    print("Compiling NLL bound estimator function...")
    compute_nll_bound = theano.function(inputs=[ x_in_sym, x_out_sym], \
                                        outputs=outputs)
    print("Compiling model sampler...")
    n_samples = T.iscalar("n_samples")
    samples = draw.sample(n_samples)
    do_sample = theano.function([n_samples], outputs=samples, allow_input_downcast=True)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    print("Beginning to train the model...")
    out_file = open("TBM_ES_RESULTS.txt", 'wb')
    costs = [0. for i in range(10)]
    learn_rate = 0.0002
    momentum = 0.9
    fresh_idx = np.arange(batch_size) + tr_samples
    for i in range(250000):
        scale = min(1.0, ((i+1) / 2500.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        fresh_idx += batch_size
        if (np.max(fresh_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            fresh_idx = np.arange(batch_size)
        batch_idx = fresh_idx
        # set sgd and objective function hyperparams for this update
        zero_ary = np.zeros((1,))
        lr_shared.set_value(to_fX(zero_ary + scale*learn_rate))
        mom_1_shared.set_value(to_fX(zero_ary + scale*momentum))
        mom_2_shared.set_value(to_fX(zero_ary + 0.99))

        # perform a minibatch update and record the cost for this batch
        Xb = to_fX( Xtr.take(batch_idx, axis=0) )
        result = train_joint(Xb, Xb)
        # aggregate costs over multiple minibatches
        costs = [(costs[j] + result[j]) for j in range(len(result))]
        if ((i % 200) == 0):
            # occasionally dump information about the costs
            costs = [(v / 200.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    total_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    reg_term  : {0:.4f}".format(costs[2])
            joint_str = "\n".join([str1, str2, str3, str4])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if ((i % 1000) == 0):
            # compute a small-sample estimate of NLL bound on validation set
            Xva = row_shuffle(Xva)
            Xb = to_fX(Xva[:5000])
            va_costs = compute_nll_bound(Xb, Xb)
            str1 = "    va_nll_bound : {}".format(va_costs[1])
            joint_str = "\n".join([str1])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            # draw some independent samples from the model
            samples = do_sample(16*16)
            n_iter, N, D = samples.shape
            samples = samples.reshape( (n_iter, N, 28, 28) )
            for j in range(n_iter):
                img = img_grid(samples[j,:,:,:])
                img.save("TBM-ES-samples-b%06d-%03d.png" % (i, j))
def test_mnist(step_type='add',
               rev_sched=None):
    #########################################
    # Format the result tag more thoroughly #
    #########################################
    result_tag = "{}AAA_SRRM_ST{}".format(RESULT_PATH, step_type)

    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 200

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    x_dim = Xtr.shape[1]
    s_dim = x_dim
    #s_dim = 300
    z_dim = 100
    init_scale = 0.66

    x_out_sym = T.matrix('x_out_sym')

    #################
    # p_zi_given_xi #
    #################
    params = {}
    shared_config = [(x_dim + x_dim), 500, 500]
    top_config = [shared_config[-1], z_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = tanh_actfun
    params['init_scale'] = init_scale
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_zi_given_xi = InfNet(rng=rng, Xd=x_out_sym, \
            params=params, shared_param_dicts=None)
    p_zi_given_xi.init_biases(0.0)
    ###################
    # p_sip1_given_zi #
    ###################
    params = {}
    shared_config = [z_dim, 500, 500]
    output_config = [s_dim, s_dim, s_dim]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['activation'] = tanh_actfun
    params['init_scale'] = init_scale
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_sip1_given_zi = HydraNet(rng=rng, Xd=x_out_sym, \
            params=params, shared_param_dicts=None)
    p_sip1_given_zi.init_biases(0.0)
    ################
    # p_x_given_si #
    ################
    params = {}
    shared_config = [s_dim, 500]
    output_config = [x_dim, x_dim]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['activation'] = tanh_actfun
    params['init_scale'] = init_scale
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_x_given_si = HydraNet(rng=rng, Xd=x_out_sym, \
            params=params, shared_param_dicts=None)
    p_x_given_si.init_biases(0.0)
    #################
    # q_zi_given_xi #
    #################
    params = {}
    shared_config = [(x_dim + x_dim), 500, 500]
    top_config = [shared_config[-1], z_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = tanh_actfun
    params['init_scale'] = init_scale
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    q_zi_given_xi = InfNet(rng=rng, Xd=x_out_sym, \
            params=params, shared_param_dicts=None)
    q_zi_given_xi.init_biases(0.0)

    #################################################
    # Setup a revelation schedule if none was given #
    #################################################
    # if rev_sched is None:
    #    rev_sched = [(10, 1.0)]
    # rev_masks = None
    p_masks = np.zeros((16,x_dim))
    p_masks[7] = npr.uniform(size=(1,x_dim)) < 0.25
    p_masks[-1] = np.ones((1,x_dim))
    p_masks = p_masks.astype(theano.config.floatX)
    q_masks = np.ones(p_masks.shape).astype(theano.config.floatX)
    rev_masks = [p_masks, q_masks]
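    # Mask schedule (assuming 1 = revealed pixel): over 16 steps the
    # primary policy (p_masks) sees nothing except at step 7, where a
    # random ~25% of pixels are revealed, and at the final step, where the
    # full image is revealed; the guide policy (q_masks) sees the complete
    # image at every step.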

    #########################################################
    # Define parameters for the SRRModel, and initialize it #
    #########################################################
    print("Building the SRRModel...")
    srrm_params = {}
    srrm_params['x_dim'] = x_dim
    srrm_params['z_dim'] = z_dim
    srrm_params['s_dim'] = s_dim
    srrm_params['use_p_x_given_si'] = False
    srrm_params['rev_sched'] = rev_sched
    srrm_params['rev_masks'] = rev_masks
    srrm_params['step_type'] = step_type
    srrm_params['x_type'] = 'bernoulli'
    srrm_params['obs_transform'] = 'sigmoid'
    SRRM = SRRModel(rng=rng,
            x_out=x_out_sym, \
            p_zi_given_xi=p_zi_given_xi, \
            p_sip1_given_zi=p_sip1_given_zi, \
            p_x_given_si=p_x_given_si, \
            q_zi_given_xi=q_zi_given_xi, \
            params=srrm_params, \
            shared_param_dicts=None)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format(result_tag)
    out_file = open(log_name, 'w')
    costs = [0. for i in range(10)]
    learn_rate = 0.00015
    momentum = 0.5
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(250000):
        scale = min(1.0, ((i+1) / 5000.0))
        lam_scale = 1.0 - min(1.0, ((i+1) / 50000.0)) # decays from 1.0->0.0
        if (((i + 1) % 15000) == 0):
            learn_rate = learn_rate * 0.93
        if (i > 10000):
            momentum = 0.95
        else:
            momentum = 0.80
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        # set sgd and objective function hyperparams for this update
        SRRM.set_sgd_params(lr=scale*learn_rate, \
                            mom_1=scale*momentum, mom_2=0.98)
        SRRM.set_train_switch(1.0)
        SRRM.set_lam_kld(lam_kld_p=0.0, lam_kld_q=1.0, \
                         lam_kld_g=0.0, lam_kld_s=0.0)
        SRRM.set_lam_l2w(1e-5)
        # perform a minibatch update and record the cost for this batch
        xb = to_fX( Xtr.take(batch_idx, axis=0) )
        result = SRRM.train_joint(xb)
        # do diagnostics and general training tracking
        costs = [(costs[j] + result[j]) for j in range(len(result)-1)]
        if ((i % 250) == 0):
            costs = [(v / 250.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    joint_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    nll_cost  : {0:.4f}".format(costs[2])
            str5 = "    kld_cost  : {0:.4f}".format(costs[3])
            str6 = "    reg_cost  : {0:.4f}".format(costs[4])
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if ((i % 1000) == 0):
            Xva = row_shuffle(Xva)
            # record an estimate of performance on the test set
            xb = Xva[0:5000]
            nll, kld = SRRM.compute_fe_terms(xb, sample_count=10)
            vfe = np.mean(nll) + np.mean(kld)
            str1 = "    va_nll_bound : {}".format(vfe)
            str2 = "    va_nll_term  : {}".format(np.mean(nll))
            str3 = "    va_kld_q2p   : {}".format(np.mean(kld))
            joint_str = "\n".join([str1, str2, str3])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            # draw some sample imputations from the model
            xo = Xva[0:100]
            samp_count = xo.shape[0]
            xm_seq, xi_seq, mi_seq = SRRM.sequence_sampler(xo, use_guide_policy=True)
            seq_len = len(xm_seq)
            seq_samps = np.zeros((seq_len*samp_count, xm_seq[0].shape[1]))
            ######
            # xm #
            ######
            idx = 0
            for s1 in range(samp_count):
                for s2 in range(seq_len):
                    seq_samps[idx] = xm_seq[s2,s1,:]
                    idx += 1
            file_name = "{0:s}_xm_samples_b{1:d}.png".format(result_tag, i)
            utils.visualize_samples(seq_samps, file_name, num_rows=20)
            ######
            # xi #
            ######
            idx = 0
            for s1 in range(samp_count):
                for s2 in range(seq_len):
                    seq_samps[idx] = xi_seq[s2,s1,:]
                    idx += 1
            file_name = "{0:s}_xi_samples_b{1:d}.png".format(result_tag, i)
            utils.visualize_samples(seq_samps, file_name, num_rows=20)
            ######
            # mi #
            ######
            idx = 0
            for s1 in range(samp_count):
                for s2 in range(seq_len):
                    seq_samps[idx] = mi_seq[s2,s1,:]
                    idx += 1
            file_name = "{0:s}_mi_samples_b{1:d}.png".format(result_tag, i)
            utils.visualize_samples(seq_samps, file_name, num_rows=20)
def test_with_model_init():
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    #dataset = 'data/mnist.pkl.gz'
    #datasets = load_udm(dataset, as_shared=False, zero_mean=False)
    #Xtr = datasets[0][0]
    #Xva = datasets[1][0]
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 200
    batch_reps = 1

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    x_dim = Xtr.shape[1]
    z_dim = 20
    h_dim = 50
    s_dim = 50
    init_scale = 1.0
    
    x_type = 'bernoulli'

    # some InfNet instances to build the TwoStageModel from
    x_in_sym = T.matrix('x_in_sym')
    x_out_sym = T.matrix('x_out_sym')

    ###############
    # p_h_given_s #
    ###############
    params = {}
    shared_config = [s_dim, 250, 250]
    top_config = [shared_config[-1], h_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = softplus_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_h_given_s = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_h_given_s.init_biases(0.2)
    #################
    # p_x_given_s_h #
    #################
    params = {}
    shared_config = [(s_dim + h_dim), 250, 250]
    top_config = [shared_config[-1], x_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = softplus_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_x_given_s_h = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_x_given_s_h.init_biases(0.2)
    ###############
    # p_s_given_z #
    ###############
    params = {}
    shared_config = [z_dim, 250]
    top_config = [shared_config[-1], s_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = softplus_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_s_given_z = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_s_given_z.init_biases(0.2)
    ###############
    # q_z_given_x #
    ###############
    params = {}
    shared_config = [x_dim, 250, 250]
    top_config = [shared_config[-1], z_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = softplus_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    q_z_given_x = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    q_z_given_x.init_biases(0.2)
    #################
    # q_h_given_x_s #
    #################
    params = {}
    shared_config = [(x_dim + s_dim), 500, 500]
    top_config = [shared_config[-1], h_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = softplus_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    q_h_given_x_s = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    q_h_given_x_s.init_biases(0.2)


    ##############################################################
    # Define parameters for the TwoStageModel, and initialize it #
    ##############################################################
    print("Building the TwoStageModel...")
    tsm_params = {}
    tsm_params['x_type'] = x_type
    tsm_params['obs_transform'] = 'sigmoid'
    TSM = TwoStageModel(rng=rng, \
            x_in=x_in_sym, x_out=x_out_sym, \
            p_s_given_z=p_s_given_z, \
            p_h_given_s=p_h_given_s, \
            p_x_given_s_h=p_x_given_s_h, \
            q_z_given_x=q_z_given_x, \
            q_h_given_x_s=q_h_given_x_s, \
            x_dim=x_dim, \
            z_dim=z_dim, s_dim=s_dim, h_dim=h_dim, \
            params=tsm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    out_file = open("TSM_A_RESULTS.txt", 'wb')
    costs = [0. for i in range(10)]
    learn_rate = 0.0003
    momentum = 0.5
    batch_idx = np.arange(batch_size) + tr_samples
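    # batch_idx is initialized past the end of Xtr on purpose, so the
    # first pass through the loop hits the "finished an epoch" branch
    # and shuffles the data before any training happens.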
    for i in range(250000):
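        # 'scale' ramps from ~0 to 1 over the first 3000 updates; it is
        # used below to warm up the learning rate, momentum, and the
        # l1/l2 KLD penalty at the start of training.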
        scale = min(1.0, ((i+1) / 3000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        if (i > 50000):
            momentum = 0.90
        else:
            momentum = 0.50
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        # train on the training set
        lam_kld = 1.0
        # set sgd and objective function hyperparams for this update
        TSM.set_sgd_params(lr_1=scale*learn_rate, lr_2=scale*learn_rate, \
                mom_1=scale*momentum, mom_2=0.99)
        TSM.set_train_switch(1.0)
        TSM.set_lam_nll(lam_nll=1.0)
        TSM.set_lam_kld(lam_kld_z=1.0, lam_kld_q2p=0.8, lam_kld_p2q=0.2)
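        # (going by the names, q2p/p2q weight the two directions of the
        # KL between the inference and generative distributions; 0.8/0.2
        # favors the usual KL(q||p) direction)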
        TSM.set_lam_kld_l1l2(lam_kld_l1l2=scale)
        TSM.set_lam_l2w(1e-4)
        TSM.set_drop_rate(0.0)
        TSM.q_h_given_x_s.set_bias_noise(0.0)
        TSM.p_h_given_s.set_bias_noise(0.0)
        TSM.p_x_given_s_h.set_bias_noise(0.0)
        # perform a minibatch update and record the cost for this batch
        Xb_tr = to_fX( Xtr.take(batch_idx, axis=0) )
        result = TSM.train_joint(Xb_tr, Xb_tr, batch_reps)
        costs = [(costs[j] + result[j]) for j in range(len(result)-1)]
        if ((i % 500) == 0):
            costs = [(v / 500.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    joint_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_cost  : {0:.4f}".format(costs[1])
            str4 = "    kld_cost  : {0:.4f}".format(costs[2])
            str5 = "    reg_cost  : {0:.4f}".format(costs[3])
            joint_str = "\n".join([str1, str2, str3, str4, str5])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if (((i % 2000) == 0) or ((i < 10000) and ((i % 1000) == 0))):
            TSM.set_drop_rate(0.0)
            TSM.q_h_given_x_s.set_bias_noise(0.0)
            TSM.p_h_given_s.set_bias_noise(0.0)
            TSM.p_x_given_s_h.set_bias_noise(0.0)
            # Get some validation samples for computing diagnostics
            Xva = row_shuffle(Xva)
            Xb_va = to_fX( Xva[0:2000] )
            # draw some independent random samples from the model
            samp_count = 500
            model_samps = TSM.sample_from_prior(samp_count)
            file_name = "TSM_A_SAMPLES_IND_b{0:d}.png".format(i)
            utils.visualize_samples(model_samps, file_name, num_rows=20)
            Xb_tr = to_fX( Xtr[0:2000] )
            fe_terms = TSM.compute_fe_terms(Xb_tr, Xb_tr, 30)
            fe_nll = np.mean(fe_terms[0])
            fe_kld = np.mean(fe_terms[1])
            fe_joint = fe_nll + fe_kld
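            # nll + kld is the per-example variational free energy, an
            # upper bound on negative log-likelihood under the model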
            joint_str = "    vfe-tr: {0:.4f}, nll: ({1:.4f}, {2:.4f}, {3:.4f}), kld: ({4:.4f}, {5:.4f}, {6:.4f})".format( \
                    fe_joint, fe_nll, np.min(fe_terms[0]), np.max(fe_terms[0]), fe_kld, np.min(fe_terms[1]), np.max(fe_terms[1]))
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            fe_terms = TSM.compute_fe_terms(Xb_va, Xb_va, 30)
            fe_nll = np.mean(fe_terms[0])
            fe_kld = np.mean(fe_terms[1])
            fe_joint = fe_nll + fe_kld
            joint_str = "    vfe-va: {0:.4f}, nll: ({1:.4f}, {2:.4f}, {3:.4f}), kld: ({4:.4f}, {5:.4f}, {6:.4f})".format( \
                    fe_joint, fe_nll, np.min(fe_terms[0]), np.max(fe_terms[0]), fe_kld, np.min(fe_terms[1]), np.max(fe_terms[1]))
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
def test_one_stage_model():
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 128
    batch_reps = 1

    ###############################################
    # Setup some parameters for the OneStageModel #
    ###############################################
    x_dim = Xtr.shape[1]
    z_dim = 64
    x_type = 'bernoulli'
    xin_sym = T.matrix('xin_sym')

    ###############
    # p_x_given_z #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'fc',
       'in_chans': z_dim,
       'out_chans': 256,
       'activation': relu_actfun,
       'apply_bn': True}, \
      {'layer_type': 'fc',
       'in_chans': 256,
       'out_chans': 7*7*128,
       'activation': relu_actfun,
       'apply_bn': True,
       'shape_func_out': lambda x: T.reshape(x, (-1, 128, 7, 7))}, \
      {'layer_type': 'conv',
       'in_chans': 128, # in shape:  (batch, 128, 7, 7)
       'out_chans': 64, # out shape: (batch, 64, 14, 14)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'half',
       'apply_bn': True} ]
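    # NOTE: going by the shape comments above, 'conv_stride': 'half'
    # appears to mean a fractionally-strided (transposed) convolution
    # that doubles spatial size (7x7 -> 14x14 -> 28x28), while 'double'
    # in q_z_given_x below halves it (28x28 -> 14x14 -> 7x7).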
    output_config = \
    [ {'layer_type': 'conv',
       'in_chans': 64, # in shape:  (batch, 64, 14, 14)
       'out_chans': 1, # out shape: (batch, 1, 28, 28)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'half',
       'apply_bn': False,
       'shape_func_out': lambda x: T.flatten(x, 2)}, \
      {'layer_type': 'conv',
       'in_chans': 64,
       'out_chans': 1,
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'half',
       'apply_bn': False,
       'shape_func_out': lambda x: T.flatten(x, 2)} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    p_x_given_z = HydraNet(rng=rng, Xd=xin_sym, \
            params=params, shared_param_dicts=None)
    p_x_given_z.init_biases(0.0)
    ###############
    # q_z_given_x #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'conv',
       'in_chans': 1,   # in shape:  (batch, 784)
       'out_chans': 64, # out shape: (batch, 64, 14, 14)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'double',
       'apply_bn': True,
       'shape_func_in': lambda x: T.reshape(x, (-1, 1, 28, 28))}, \
      {'layer_type': 'conv',
       'in_chans': 64,   # in shape:  (batch, 64, 14, 14)
       'out_chans': 128, # out shape: (batch, 128, 7, 7)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'double',
       'apply_bn': True,
       'shape_func_out': lambda x: T.flatten(x, 2)}, \
      {'layer_type': 'fc',
       'in_chans': 128*7*7,
       'out_chans': 256,
       'activation': relu_actfun,
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'fc',
       'in_chans': 256,
       'out_chans': z_dim,
       'activation': relu_actfun,
       'apply_bn': False}, \
      {'layer_type': 'fc',
       'in_chans': 256,
       'out_chans': z_dim,
       'activation': relu_actfun,
       'apply_bn': False} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    q_z_given_x = HydraNet(rng=rng, Xd=xin_sym, \
            params=params, shared_param_dicts=None)
    q_z_given_x.init_biases(0.0)


    ##############################################################
    # Define parameters for the OneStageModel, and initialize it #
    ##############################################################
    print("Building the OneStageModel...")
    osm_params = {}
    osm_params['x_type'] = x_type
    osm_params['obs_transform'] = 'sigmoid'
    OSM = OneStageModel(rng=rng, x_in=xin_sym,
            x_dim=x_dim, z_dim=z_dim,
            p_x_given_z=p_x_given_z,
            q_z_given_x=q_z_given_x,
            params=osm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format("OSM_TEST")
    out_file = open(log_name, 'wb')
    costs = [0. for i in range(10)]
    learn_rate = 0.0005
    momentum = 0.9
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(500000):
        scale = min(0.5, ((i+1) / 5000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        Xb = to_fX( Xtr.take(batch_idx, axis=0) )
        #Xb = binarize_data(Xtr.take(batch_idx, axis=0))
        # set sgd and objective function hyperparams for this update
        OSM.set_sgd_params(lr=scale*learn_rate, \
                           mom_1=(scale*momentum), mom_2=0.98)
        OSM.set_lam_nll(lam_nll=1.0)
        OSM.set_lam_kld(lam_kld=1.0)
        OSM.set_lam_l2w(1e-5)
        # perform a minibatch update and record the cost for this batch
        result = OSM.train_joint(Xb, batch_reps)
        costs = [(costs[j] + result[j]) for j in range(len(result))]
        if ((i % 500) == 0):
            costs = [(v / 500.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    joint_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_cost  : {0:.4f}".format(costs[1])
            str4 = "    kld_cost  : {0:.4f}".format(costs[2])
            str5 = "    reg_cost  : {0:.4f}".format(costs[3])
            joint_str = "\n".join([str1, str2, str3, str4, str5])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if (((i % 5000) == 0) or ((i < 10000) and ((i % 1000) == 0))):
            # draw some independent random samples from the model
            samp_count = 300
            model_samps = OSM.sample_from_prior(samp_count)
            file_name = "OSM_SAMPLES_b{0:d}.png".format(i)
            utils.visualize_samples(model_samps, file_name, num_rows=15)
            # compute free energy estimate for validation samples
            Xva = row_shuffle(Xva)
            fe_terms = OSM.compute_fe_terms(Xva[0:5000], 20)
            fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
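            # fe_terms[0]/fe_terms[1] are per-example NLL and KLD terms;
            # their means sum to the variational bound reported below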
            out_str = "    nll_bound : {0:.4f}".format(fe_mean)
            print(out_str)
            out_file.write(out_str+"\n")
            out_file.flush()
    return
def test_imocld_generation(step_type='add', attention=False):
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 250

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    x_dim = Xtr.shape[1]
    write_dim = 200
    enc_dim = 250
    dec_dim = 250
    mix_dim = 20
    z_dim = 100
    n_iter = 16

    rnninits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }
    inits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }

    att_tag = "NA"  # attention not tested yet

    # setup the reader and writer (shared by primary and guide policies)
    read_dim = 2 * x_dim  # dimension of output from reader_mlp
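    # (in DRAW-style models without attention, the reader typically
    # returns the input concatenated with a reconstruction-error image,
    # which would account for the 2 * x_dim output size)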
    reader_mlp = Reader(x_dim=x_dim, dec_dim=dec_dim, **inits)
    writer_mlp = MLP([None, None], [dec_dim, write_dim, x_dim], \
                     name="writer_mlp", **inits)

    # mlps for setting conditionals over z_mix
    mix_var_mlp = CondNet([Tanh()], [x_dim, 250, mix_dim], \
                          name="mix_var_mlp", **inits)
    mix_enc_mlp = CondNet([Tanh()], [x_dim, 250, mix_dim], \
                          name="mix_enc_mlp", **inits)
    # mlp for decoding z_mix into a distribution over initial LSTM states
    mix_dec_mlp = MLP([Tanh(), Tanh()], \
                      [mix_dim, 250, (2*enc_dim + 2*dec_dim + 2*enc_dim)], \
                      name="mix_dec_mlp", **inits)
    # mlps for processing inputs to LSTMs
    var_mlp_in = MLP([Identity()], [(read_dim + dec_dim), 4*enc_dim], \
                     name="var_mlp_in", **inits)
    enc_mlp_in = MLP([Identity()], [(read_dim + dec_dim), 4*enc_dim], \
                     name="enc_mlp_in", **inits)
    dec_mlp_in = MLP([Identity()], [               z_dim, 4*dec_dim], \
                     name="dec_mlp_in", **inits)
    # mlps for turning LSTM outputs into conditionals over z_gen
    var_mlp_out = CondNet([], [enc_dim, z_dim], name="var_mlp_out", **inits)
    enc_mlp_out = CondNet([], [enc_dim, z_dim], name="enc_mlp_out", **inits)
    # the LSTM modules for the variational, encoder, and decoder recurrences
    var_rnn = BiasedLSTM(dim=enc_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="var_rnn", **rnninits)
    enc_rnn = BiasedLSTM(dim=enc_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="enc_rnn", **rnninits)
    dec_rnn = BiasedLSTM(dim=dec_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="dec_rnn", **rnninits)

    draw = IMoCLDrawModels(
        n_iter,
        step_type=step_type,  # step_type can be 'add' or 'jump'
        reader_mlp=reader_mlp,
        writer_mlp=writer_mlp,
        mix_enc_mlp=mix_enc_mlp,
        mix_dec_mlp=mix_dec_mlp,
        mix_var_mlp=mix_var_mlp,
        enc_mlp_in=enc_mlp_in,
        enc_mlp_out=enc_mlp_out,
        enc_rnn=enc_rnn,
        dec_mlp_in=dec_mlp_in,
        dec_rnn=dec_rnn,
        var_mlp_in=var_mlp_in,
        var_mlp_out=var_mlp_out,
        var_rnn=var_rnn)
    draw.initialize()

    # build the cost gradients, training function, samplers, etc.
    draw.build_model_funcs()

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    print("Beginning to train the model...")
    out_file = open("TBCLM_GEN_RESULTS_{}_{}.txt".format(step_type, att_tag),
                    'wb')
    costs = [0. for i in range(10)]
    learn_rate = 0.0002
    momentum = 0.5
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(250000):
        scale = min(1.0, ((i + 1) / 1000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        if (i > 10000):
            momentum = 0.90
        else:
            momentum = 0.50
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        # set sgd and objective function hyperparams for this update
        zero_ary = np.zeros((1, ))
        draw.lr.set_value(to_fX(zero_ary + learn_rate))
        draw.mom_1.set_value(to_fX(zero_ary + momentum))
        draw.mom_2.set_value(to_fX(zero_ary + 0.99))
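        # adding a python scalar to the (1,) zero array is just a
        # convenient way to build a correctly-shaped, correctly-typed
        # value (via to_fX) for the shared variables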

        # perform a minibatch update and record the cost for this batch
        Xb = to_fX(Xtr.take(batch_idx, axis=0))
        Mb = 0.0 * Xb
        result = draw.train_joint(Xb, Mb)

        costs = [(costs[j] + result[j]) for j in range(len(result))]
        if ((i % 200) == 0):
            costs = [(v / 200.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    total_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    nll_term  : {0:.4f}".format(costs[2])
            str5 = "    kld_q2p   : {0:.4f}".format(costs[3])
            str6 = "    kld_p2q   : {0:.4f}".format(costs[4])
            str7 = "    reg_term  : {0:.4f}".format(costs[5])
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6, str7])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if ((i % 1000) == 0):
            draw.save_model_params("TBCLM_GEN_PARAMS_{}_{}.pkl".format(
                step_type, att_tag))
            # compute a small-sample estimate of NLL bound on validation set
            Xva = row_shuffle(Xva)
            Xb = to_fX(Xva[:5000])
            Mb = 0.0 * Xb
            va_costs = draw.compute_nll_bound(Xb, Mb)
            str1 = "    va_nll_bound : {}".format(va_costs[1])
            str2 = "    va_nll_term  : {}".format(va_costs[2])
            str3 = "    va_kld_q2p   : {}".format(va_costs[3])
            joint_str = "\n".join([str1, str2, str3])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
            # draw some independent samples from the model
            Xb = to_fX(Xva[:256])
            Mb = 0.0 * Xb
            samples, _ = draw.do_sample(Xb, Mb)
            n_iter, N, D = samples.shape
            samples = samples.reshape((n_iter, N, 28, 28))
            for j in xrange(n_iter):
                img = img_grid(samples[j, :, :, :])
                img.save("TBCLM-gen-samples-%03d.png" % (j, ))
def test_rldraw_classic(step_type='add', use_pol=True):
    ###########################################
    # Make a tag for identifying result files #
    ###########################################
    pol_tag = "yp" if use_pol else "np"
    res_tag = "TRLD_SPLIT_E002_{}_{}".format(step_type, pol_tag)

    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 200

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    x_dim = Xtr.shape[1]
    write_dim = 500
    rnn_dim = 500
    z_dim = 100
    n_iter = 20

    rnninits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }
    inits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }

    # setup reader/writer models
    read_dim = 2*x_dim
    reader_mlp = Reader(x_dim=x_dim, dec_dim=rnn_dim, **inits)
    writer_mlp = MLP([None, None], [rnn_dim, write_dim, x_dim],
                     name="writer_mlp", **inits)

    # setup submodels for processing LSTM inputs
    pol_mlp_in = MLP([Identity()], [rnn_dim, 4*rnn_dim],
                     name="pol_mlp_in", **inits)
    var_mlp_in = MLP([Identity()], [(x_dim + rnn_dim), 4*rnn_dim],
                     name="var_mlp_in", **inits)
    ent_mlp_in = MLP([Identity()], [(x_dim + rnn_dim), 4*rnn_dim],
                     name="ent_mlp_in", **inits)
    dec_mlp_in = MLP([Identity()], [z_dim, 4*rnn_dim],
                     name="dec_mlp_in", **inits)
    # setup submodels for turning LSTM states into conditionals over z
    pol_mlp_out = CondNet([], [rnn_dim, z_dim], name="pol_mlp_out", **inits)
    var_mlp_out = CondNet([], [rnn_dim, z_dim], name="var_mlp_out", **inits)
    ent_mlp_out = CondNet([], [rnn_dim, z_dim], name="ent_mlp_out", **inits)
    dec_mlp_out = CondNet([], [rnn_dim, z_dim], name="dec_mlp_out", **inits)
    # setup the LSTMs for primary policy, guide policy, and shared dynamics
    pol_rnn = BiasedLSTM(dim=rnn_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="pol_rnn", **rnninits)
    var_rnn = BiasedLSTM(dim=rnn_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="var_rnn", **rnninits)
    ent_rnn = BiasedLSTM(dim=rnn_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="ent_rnn", **rnninits)
    dec_rnn = BiasedLSTM(dim=rnn_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="dec_rnn", **rnninits)

    draw = RLDrawModel(
                n_iter,
                step_type=step_type, # step_type can be 'add' or 'jump'
                use_pol=use_pol,
                reader_mlp=reader_mlp,
                writer_mlp=writer_mlp,
                pol_mlp_in=pol_mlp_in,
                pol_mlp_out=pol_mlp_out,
                pol_rnn=pol_rnn,
                var_mlp_in=var_mlp_in,
                var_mlp_out=var_mlp_out,
                var_rnn=var_rnn,
                dec_mlp_in=dec_mlp_in,
                dec_mlp_out=dec_mlp_out,
                dec_rnn=dec_rnn,
                ent_mlp_in=ent_mlp_in,
                ent_mlp_out=ent_mlp_out,
                ent_rnn=ent_rnn)
    draw.initialize()

    compile_start_time = time.time()

    # build the cost gradients, training function, samplers, etc.
    draw.build_sampling_funcs()
    print("Testing model sampler...")
    # draw some independent samples from the model
    samples = draw.sample_model(Xtr[:65,:], sample_source='p')
    n_iter, N, D = samples.shape
    samples = samples.reshape( (n_iter, N, 28, 28) )
    for j in xrange(n_iter):
        img = img_grid(samples[j,:,:,:])
        img.save("%s_samples_%03d.png" % (res_tag, j))

    draw.build_model_funcs()

    compile_end_time = time.time()
    compile_minutes = (compile_end_time - compile_start_time) / 60.0
    print("THEANO COMPILE TIME (MIN): {}".format(compile_minutes))

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    print("Beginning to train the model...")
    out_file = open("{}_results.txt".format(res_tag), 'wb')
    costs = [0. for i in range(10)]
    learn_rate = 0.00015
    momentum = 0.9
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(300000):
        scale = min(1.0, ((i+1) / 5000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        # set sgd and objective function hyperparams for this update
        draw.set_sgd_params(lr=scale*learn_rate, mom_1=scale*momentum, mom_2=0.98)
        draw.set_lam_kld(lam_kld_q2p=1.0, lam_kld_p2q=0.0, lam_neg_ent=0.02)
        draw.set_grad_noise(grad_noise=0.02)
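        # (reading off the setter names: a small negative-entropy penalty
        # on the policy plus a bit of gradient noise, both acting as
        # regularizers)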
        # perform a minibatch update and record the cost for this batch
        Xb = to_fX(Xtr.take(batch_idx, axis=0))
        result = draw.train_joint(Xb)
        costs = [(costs[j] + result[j]) for j in range(len(result))]

        # diagnostics
        if ((i % 250) == 0):
            costs = [(v / 250.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    total_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    nll_term  : {0:.4f}".format(costs[2])
            str5 = "    kld_q2p   : {0:.4f}".format(costs[3])
            str6 = "    kld_p2q   : {0:.4f}".format(costs[4])
            str7 = "    neg_ent   : {0:.4f}".format(costs[5])
            str8 = "    reg_term  : {0:.4f}".format(costs[6])
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6, str7, str8])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if ((i % 1000) == 0):
            draw.save_model_params("{}_params.pkl".format(res_tag))
            # compute a small-sample estimate of NLL bound on validation set
            Xva = row_shuffle(Xva)
            Xb = to_fX(Xva[:5000])
            va_costs = draw.compute_nll_bound(Xb)
            str1 = "    va_nll_bound : {}".format(va_costs[1])
            str2 = "    va_nll_term  : {}".format(va_costs[2])
            str3 = "    va_kld_q2p   : {}".format(va_costs[3])
            str4 = "    va_neg_ent   : {}".format(va_costs[5])
            joint_str = "\n".join([str1, str2, str3, str4])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            # draw some independent samples from the model
            samples = draw.sample_model(Xb[:256,:], sample_source='p')
            n_iter, N, D = samples.shape
            samples = samples.reshape( (n_iter, N, 28, 28) )
            for j in xrange(n_iter):
                img = img_grid(samples[j,:,:,:])
                img.save("%s_samples_%03d.png" % (res_tag, j))
def test_lstm_structpred(step_type='add', use_pol=True, use_binary=False):
    ###########################################
    # Make a tag for identifying result files #
    ###########################################
    pol_tag = "P1" if use_pol else "P0"
    bin_tag = "B1" if use_binary else "B0"
    res_tag = "STRUCT_PRED_RESULTS/SP_LSTM_{}_{}_{}".format(
        step_type, pol_tag, bin_tag)

    if use_binary:
        ############################
        # Get binary training data #
        ############################
        rng = np.random.RandomState(1234)
        Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
        #Xtr = np.vstack((Xtr, Xva))
        #Xva = Xte
    else:
        ################################
        # Get continuous training data #
        ################################
        rng = np.random.RandomState(1234)
        dataset = 'data/mnist.pkl.gz'
        datasets = load_udm(dataset, as_shared=False, zero_mean=False)
        Xtr = datasets[0][0]
        Xva = datasets[1][0]
        Xte = datasets[2][0]
        #Xtr = np.concatenate((Xtr, Xva), axis=0)
        #Xva = Xte
        Xtr = to_fX(shift_and_scale_into_01(Xtr))
        Xva = to_fX(shift_and_scale_into_01(Xva))
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 200

    ########################################################
    # Split data into "observation" and "prediction" parts #
    ########################################################
    obs_cols = 14  # number of columns to observe
    pred_cols = 28 - obs_cols  # number of columns to predict
    x_dim = obs_cols * 28  # dimensionality of observations
    y_dim = pred_cols * 28  # dimensionality of predictions
    Xtr, Ytr = img_split(Xtr,
                         im_dim=(28, 28),
                         split_col=obs_cols,
                         transposed=True)
    Xva, Yva = img_split(Xva,
                         im_dim=(28, 28),
                         split_col=obs_cols,
                         transposed=True)
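    # after the split, Xtr/Xva hold the 14 observed columns of each
    # 28x28 digit and Ytr/Yva hold the remaining 14 columns to predict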

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    read_dim = 128
    write_dim = 128
    mlp_dim = 128
    rnn_dim = 128
    z_dim = 64
    n_iter = 15

    rnninits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }
    inits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }

    # setup reader/writer models
    reader_mlp = MLP([Rectifier(), Tanh()], [x_dim, mlp_dim, read_dim],
                     name="reader_mlp",
                     **inits)
    writer_mlp = MLP([Rectifier(), None], [rnn_dim, mlp_dim, y_dim],
                     name="writer_mlp",
                     **inits)

    # setup submodels for processing LSTM inputs
    pol_inp_dim = y_dim + read_dim + rnn_dim
    var_inp_dim = y_dim + y_dim + read_dim + rnn_dim
    pol_mlp_in = MLP([Identity()], [pol_inp_dim, 4 * rnn_dim],
                     name="pol_mlp_in",
                     **inits)
    var_mlp_in = MLP([Identity()], [var_inp_dim, 4 * rnn_dim],
                     name="var_mlp_in",
                     **inits)
    dec_mlp_in = MLP([Identity()], [z_dim, 4 * rnn_dim],
                     name="dec_mlp_in",
                     **inits)

    # setup submodels for turning LSTM states into conditionals over z
    pol_mlp_out = CondNet([], [rnn_dim, z_dim], name="pol_mlp_out", **inits)
    var_mlp_out = CondNet([], [rnn_dim, z_dim], name="var_mlp_out", **inits)
    dec_mlp_out = CondNet([], [rnn_dim, z_dim], name="dec_mlp_out", **inits)

    # setup the LSTMs for primary policy, guide policy, and shared dynamics
    pol_rnn = BiasedLSTM(dim=rnn_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="pol_rnn", **rnninits)
    var_rnn = BiasedLSTM(dim=rnn_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="var_rnn", **rnninits)
    dec_rnn = BiasedLSTM(dim=rnn_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="dec_rnn", **rnninits)

    model = IRStructPredModel(n_iter,
                              step_type=step_type,
                              use_pol=use_pol,
                              reader_mlp=reader_mlp,
                              writer_mlp=writer_mlp,
                              pol_mlp_in=pol_mlp_in,
                              pol_mlp_out=pol_mlp_out,
                              pol_rnn=pol_rnn,
                              var_mlp_in=var_mlp_in,
                              var_mlp_out=var_mlp_out,
                              var_rnn=var_rnn,
                              dec_mlp_in=dec_mlp_in,
                              dec_mlp_out=dec_mlp_out,
                              dec_rnn=dec_rnn)
    model.initialize()

    compile_start_time = time.time()

    # build the cost gradients, training function, samplers, etc.
    model.build_sampling_funcs()
    print("Testing model sampler...")
    # draw some independent samples from the model
    samp_count = 10
    samp_reps = 3
    x_in = Xtr[:10, :].repeat(samp_reps, axis=0)
    y_in = Ytr[:10, :].repeat(samp_reps, axis=0)
    x_samps, y_samps = model.sample_model(x_in, y_in, sample_source='p')
    # visualize sample prediction trajectories
    img_seq = seq_img_join(x_samps, y_samps, im_dim=(28, 28), transposed=True)
    seq_len = len(img_seq)
    samp_count = img_seq[0].shape[0]
    seq_samps = np.zeros((seq_len * samp_count, img_seq[0].shape[1]))
    idx = 0
    for s1 in range(samp_count):
        for s2 in range(seq_len):
            seq_samps[idx] = img_seq[s2][s1]
            idx += 1
    file_name = "{0:s}_samples_b{1:d}.png".format(res_tag, 0)
    utils.visualize_samples(seq_samps, file_name, num_rows=samp_count)

    model.build_model_funcs()

    compile_end_time = time.time()
    compile_minutes = (compile_end_time - compile_start_time) / 60.0
    print("THEANO COMPILE TIME (MIN): {}".format(compile_minutes))

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    print("Beginning to train the model...")
    out_file = open("{}_results.txt".format(res_tag), 'wb')
    costs = [0. for i in range(10)]
    learn_rate = 0.0002
    momentum = 0.9
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(300000):
        scale = min(1.0, ((i + 1) / 5000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr, Ytr = row_shuffle(Xtr, Ytr)
            batch_idx = np.arange(batch_size)
        # set sgd and objective function hyperparams for this update
        model.set_sgd_params(lr=scale * learn_rate,
                             mom_1=scale * momentum,
                             mom_2=0.98)
        model.set_lam_kld(lam_kld_q2p=1.0, lam_kld_p2q=0.1)
        model.set_grad_noise(grad_noise=0.02)
        # perform a minibatch update and record the cost for this batch
        Xb = to_fX(Xtr.take(batch_idx, axis=0))
        Yb = to_fX(Ytr.take(batch_idx, axis=0))
        result = model.train_joint(Xb, Yb)
        costs = [(costs[j] + result[j]) for j in range(len(result))]

        # diagnostics
        if ((i % 250) == 0):
            costs = [(v / 250.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    total_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    nll_term  : {0:.4f}".format(costs[2])
            str5 = "    kld_q2p   : {0:.4f}".format(costs[3])
            str6 = "    kld_p2q   : {0:.4f}".format(costs[4])
            str7 = "    reg_term  : {0:.4f}".format(costs[5])
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6, str7])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if ((i % 1000) == 0):
            model.save_model_params("{}_params.pkl".format(res_tag))
            # compute a small-sample estimate of NLL bound on validation set
            Xva, Yva = row_shuffle(Xva, Yva)
            Xb = to_fX(Xva[:5000])
            Yb = to_fX(Yva[:5000])
            va_costs = model.compute_nll_bound(Xb, Yb)
            str1 = "    va_nll_bound : {}".format(va_costs[1])
            str2 = "    va_nll_term  : {}".format(va_costs[2])
            str3 = "    va_kld_q2p   : {}".format(va_costs[3])
            joint_str = "\n".join([str1, str2, str3])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
            # draw some independent samples from the model
            samp_count = 10
            samp_reps = 3
            x_in = Xva[:samp_count, :].repeat(samp_reps, axis=0)
            y_in = Yva[:samp_count, :].repeat(samp_reps, axis=0)
            x_samps, y_samps = model.sample_model(x_in,
                                                  y_in,
                                                  sample_source='p')
            # visualize sample prediction trajectories
            img_seq = seq_img_join(x_samps,
                                   y_samps,
                                   im_dim=(28, 28),
                                   transposed=True)
            seq_len = len(img_seq)
            samp_count = img_seq[0].shape[0]
            seq_samps = np.zeros((seq_len * samp_count, img_seq[0].shape[1]))
            idx = 0
            for s1 in range(samp_count):
                for s2 in range(seq_len):
                    if use_binary:
                        seq_samps[idx] = binarize_data(img_seq[s2][s1])
                    else:
                        seq_samps[idx] = img_seq[s2][s1]
                    idx += 1
            file_name = "{0:s}_samples_b{1:d}.png".format(res_tag, i)
            utils.visualize_samples(seq_samps, file_name, num_rows=samp_count)
def test_imoold_generation(step_type="add", attention=False):
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path="./data/")
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    # del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 250

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    x_dim = Xtr.shape[1]
    write_dim = 200
    enc_dim = 250
    dec_dim = 250
    mix_dim = 25
    z_dim = 100
    if attention:
        n_iter = 64
    else:
        n_iter = 16

    rnninits = {"weights_init": IsotropicGaussian(0.01), "biases_init": Constant(0.0)}
    inits = {"weights_init": IsotropicGaussian(0.01), "biases_init": Constant(0.0)}

    # setup the reader and writer
    if attention:
        read_N, write_N = (2, 5)  # resolution of reader and writer
        read_dim = 2 * read_N ** 2  # total number of "pixels" read by reader
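        # an attention reader extracts an N x N glimpse from the image
        # and (presumably) from the error image, giving 2 * N**2 values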
        reader_mlp = AttentionReader2d(x_dim=x_dim, dec_dim=dec_dim, width=28, height=28, N=read_N, **inits)
        writer_mlp = AttentionWriter(input_dim=dec_dim, output_dim=x_dim, width=28, height=28, N=write_N, **inits)
        att_tag = "YA"
    else:
        read_dim = 2 * x_dim
        reader_mlp = Reader(x_dim=x_dim, dec_dim=dec_dim, **inits)
        writer_mlp = MLP([None, None], [dec_dim, write_dim, x_dim], name="writer_mlp", **inits)
        att_tag = "NA"

    # setup the infinite mixture initialization model
    mix_enc_mlp = CondNet([Tanh()], [x_dim, 250, mix_dim], name="mix_enc_mlp", **inits)
    mix_dec_mlp = MLP([Tanh(), Tanh()], [mix_dim, 250, (2 * enc_dim + 2 * dec_dim)], name="mix_dec_mlp", **inits)
    # setup the components of the sequential generative model
    enc_mlp_in = MLP([Identity()], [(read_dim + dec_dim), 4 * enc_dim], name="enc_mlp_in", **inits)
    dec_mlp_in = MLP([Identity()], [z_dim, 4 * dec_dim], name="dec_mlp_in", **inits)
    enc_mlp_out = CondNet([], [enc_dim, z_dim], name="enc_mlp_out", **inits)
    dec_mlp_out = CondNet([], [dec_dim, z_dim], name="dec_mlp_out", **inits)
    enc_rnn = BiasedLSTM(dim=enc_dim, ig_bias=2.0, fg_bias=2.0, name="enc_rnn", **rnninits)
    dec_rnn = BiasedLSTM(dim=dec_dim, ig_bias=2.0, fg_bias=2.0, name="dec_rnn", **rnninits)

    draw = IMoOLDrawModels(
        n_iter,
        step_type=step_type,  # step_type can be 'add' or 'jump'
        mix_enc_mlp=mix_enc_mlp,
        mix_dec_mlp=mix_dec_mlp,
        reader_mlp=reader_mlp,
        enc_mlp_in=enc_mlp_in,
        enc_mlp_out=enc_mlp_out,
        enc_rnn=enc_rnn,
        dec_mlp_in=dec_mlp_in,
        dec_mlp_out=dec_mlp_out,
        dec_rnn=dec_rnn,
        writer_mlp=writer_mlp,
    )
    draw.initialize()

    compile_start_time = time.time()

    # build the cost gradients, training function, samplers, etc.
    draw.build_model_funcs()

    compile_end_time = time.time()
    compile_minutes = (compile_end_time - compile_start_time) / 60.0
    print("THEANO COMPILE TIME (MIN): {}".format(compile_minutes))

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    print("Beginning to train the model...")
    out_file = open("TBOLM_GEN_RESULTS_{}_{}.txt".format(step_type, att_tag), "wb")
    costs = [0.0 for i in range(10)]
    learn_rate = 0.0002
    momentum = 0.5
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(250000):
        scale = min(1.0, ((i + 1) / 1000.0))
        if ((i + 1) % 10000) == 0:
            learn_rate = learn_rate * 0.95
        if i > 10000:
            momentum = 0.90
        else:
            momentum = 0.50
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if np.max(batch_idx) >= tr_samples:
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)

        # set sgd and objective function hyperparams for this update
        zero_ary = np.zeros((1,))
        draw.lr.set_value(to_fX(zero_ary + learn_rate))
        draw.mom_1.set_value(to_fX(zero_ary + momentum))
        draw.mom_2.set_value(to_fX(zero_ary + 0.99))

        # perform a minibatch update and record the cost for this batch
        Xb = to_fX(Xtr.take(batch_idx, axis=0))
        result = draw.train_joint(Xb, Xb)
        costs = [(costs[j] + result[j]) for j in range(len(result))]

        # diagnostics
        if (i % 200) == 0:
            costs = [(v / 200.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    total_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    nll_term  : {0:.4f}".format(costs[2])
            str5 = "    kld_q2p   : {0:.4f}".format(costs[3])
            str6 = "    kld_p2q   : {0:.4f}".format(costs[4])
            str7 = "    reg_term  : {0:.4f}".format(costs[5])
            str8 = "    step_klds : {0:s}".format(np.array_str(costs[6], precision=2))
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6, str7, str8])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if (i % 1000) == 0:
            draw.save_model_params("TBOLM_GEN_PARAMS_{}_{}.pkl".format(step_type, att_tag))
            # compute a small-sample estimate of NLL bound on validation set
            Xva = row_shuffle(Xva)
            Xb = to_fX(Xva[:5000])
            va_costs = draw.compute_nll_bound(Xb, Xb)
            str1 = "    va_nll_bound : {}".format(va_costs[1])
            str2 = "    va_nll_term  : {}".format(va_costs[2])
            str3 = "    va_kld_q2p   : {}".format(va_costs[3])
            joint_str = "\n".join([str1, str2, str3])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
            # draw some independent samples from the model
            samples, x_logodds = draw.do_sample(16 * 16)
            utils.plot_kde_histogram(x_logodds[-1, :, :], "TBOLM-log_odds_hist.png", bins=30)
            n_iter, N, D = samples.shape
            samples = samples.reshape((n_iter, N, 28, 28))
            for j in xrange(n_iter):
                img = img_grid(samples[j, :, :, :])
                img.save("TBOLM-gen-samples-%03d.png" % (j,))
def test_ddm_generation():
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 250

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    x_dim = Xtr.shape[1]
    enc_dim = 250
    dec_dim = 250
    mix_dim = 20
    z_dim = 100
    n_iter = 8
    
    rnninits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }
    inits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }

    # setup the infinite mixture initialization model
    mix_enc_mlp = CondNet([Tanh()], [x_dim, 250, mix_dim], \
                          name="mix_enc_mlp", **inits)
    mix_dec_mlp = MLP([Tanh(), Tanh()], \
                      [mix_dim, 250, (2*enc_dim + 2*dec_dim)], \
                      name="mix_dec_mlp", **inits)
    # setup the components of the sequential generative model
    enc_mlp_in = MLP([Identity()], [(x_dim + dec_dim + dec_dim), 4*enc_dim], \
                     name="enc_mlp_in", **inits)
    dec_mlp_in = MLP([Identity()], [z_dim, 4*dec_dim], \
                     name="dec_mlp_in", **inits)
    enc_mlp_out = CondNet([], [enc_dim, z_dim], name="enc_mlp_out", **inits)
    dec_mlp_out = CondNet([], [dec_dim, z_dim], name="dec_mlp_out", **inits)
    enc_rnn = BiasedLSTM(dim=enc_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="enc_rnn", **rnninits)
    dec_rnn = BiasedLSTM(dim=dec_dim, ig_bias=2.0, fg_bias=2.0, \
                         name="dec_rnn", **rnninits)
    # set up the transform from latent space to observation space
    s2x_mlp = TanhMLPwFFBP(dec_dim, [500], x_dim, name="s2x_mlp", **inits)

    draw = DriftDiffModel(
                n_iter,
                mix_enc_mlp=mix_enc_mlp,
                mix_dec_mlp=mix_dec_mlp,
                enc_mlp_in=enc_mlp_in,
                enc_mlp_out=enc_mlp_out,
                enc_rnn=enc_rnn,
                dec_mlp_in=dec_mlp_in,
                dec_mlp_out=dec_mlp_out,
                dec_rnn=dec_rnn,
                s2x_mlp=s2x_mlp)
    draw.initialize()

    # build the cost gradients, training function, samplers, etc.
    draw.build_model_funcs()

    #draw.load_model_params(f_name="TBDDM_GEN_PARAMS.pkl")

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    print("Beginning to train the model...")
    out_file = open("TBDDM_GEN_RESULTS.txt", 'wb')
    costs = [0. for i in range(10)]
    learn_rate = 0.0002
    momentum = 0.5
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(250000):
        scale = min(1.0, ((i+1) / 1000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        if (i > 10000):
            momentum = 0.90
        else:
            momentum = 0.50
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)

        # set sgd and objective function hyperparams for this update
        zero_ary = np.zeros((1,))
        draw.lr.set_value(to_fX(zero_ary + learn_rate))
        draw.mom_1.set_value(to_fX(zero_ary + momentum))
        draw.mom_2.set_value(to_fX(zero_ary + 0.99))

        # perform a minibatch update and record the cost for this batch
        Xb = to_fX(Xtr.take(batch_idx, axis=0))
        result = draw.train_joint(Xb, Xb)
        costs = [(costs[j] + result[j]) for j in range(len(result))]

        # diagnostics
        if ((i % 250) == 0):
            costs = [(v / 250.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    total_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_bound : {0:.4f}".format(costs[1])
            str4 = "    nll_term  : {0:.4f}".format(costs[2])
            str5 = "    kld_q2p   : {0:.4f}".format(costs[3])
            str6 = "    kld_p2q   : {0:.4f}".format(costs[4])
            str7 = "    reg_term  : {0:.4f}".format(costs[5])
            joint_str = "\n".join([str1, str2, str3, str4, str5, str6, str7])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if ((i % 500) == 0):
            draw.save_model_params("TBDDM_GEN_PARAMS.pkl")
            # compute a small-sample estimate of NLL bound on validation set
            Xva = row_shuffle(Xva)
            Xb = to_fX(Xva[:5000])
            va_costs = draw.compute_nll_bound(Xb, Xb)
            str1 = "    va_nll_bound : {}".format(va_costs[1])
            str2 = "    va_nll_term  : {}".format(va_costs[2])
            str3 = "    va_kld_q2p   : {}".format(va_costs[3])
            joint_str = "\n".join([str1, str2, str3])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            # draw some independent samples from the model
            samples = draw.do_sample(16*16)
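            # 16*16 = 256 samples, presumably sized to fill one 16x16
            # grid per frame in the img_grid visualizations below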
            n_iter, N, D = samples.shape
            samples = samples.reshape( (n_iter, N, 28, 28) )
            for j in xrange(n_iter):
                img = img_grid(samples[j,:,:,:])
                img.save("TBDDM-gen-samples-%03d.png" % (j,))