def experiment(state, outdir_base='./'):
    rng.seed(1) #seed the numpy random generator  
    R.seed(1) #seed the other random generator (for reconstruction function indices)
    # Initialize output directory and files
    data.mkdir_p(outdir_base)
    outdir = outdir_base + "/" + state.dataset + "/"
    data.mkdir_p(outdir)
    logger = Logger(outdir)
    logger.log("----------MODEL 2, {0!s}-----------\n".format(state.dataset))
    gsn_train_convergence = outdir+"gsn_train_convergence.csv"
    gsn_valid_convergence = outdir+"gsn_valid_convergence.csv"
    gsn_test_convergence  = outdir+"gsn_test_convergence.csv"
    train_convergence = outdir+"train_convergence.csv"
    valid_convergence = outdir+"valid_convergence.csv"
    test_convergence  = outdir+"test_convergence.csv"
    init_empty_file(gsn_train_convergence)
    init_empty_file(gsn_valid_convergence)
    init_empty_file(gsn_test_convergence)
    init_empty_file(train_convergence)
    init_empty_file(valid_convergence)
    init_empty_file(test_convergence)
    
    #load parameters from config file if this is a test
    config_filename = outdir+'config'
    if state.test_model and 'config' in os.listdir(outdir):
        config_vals = load_from_config(config_filename)
        for CV in config_vals:
            logger.log(CV)
            if CV.startswith('test'):
                logger.log('Do not override testing switch')
                continue        
            try:
                exec('state.'+CV) in globals(), locals()
            except:
                exec('state.'+CV.split('=')[0]+"='"+CV.split('=')[1]+"'") in globals(), locals()
    else:
        # Save the current configuration
        # Useful for logs/experiments
        logger.log('Saving config')
        with open(config_filename, 'w') as f:
            f.write(str(state))

    logger.log(state)
    
    ####################################################
    # Load the data, train = train+valid, and sequence #
    ####################################################
    artificial = False
    if state.dataset == 'MNIST_1' or state.dataset == 'MNIST_2' or state.dataset == 'MNIST_3':
        (train_X, train_Y), (valid_X, valid_Y), (test_X, test_Y) = data.load_mnist(state.data_path)
        train_X = numpy.concatenate((train_X, valid_X))
        train_Y = numpy.concatenate((train_Y, valid_Y))
        artificial = True
        try:
            dataset = int(state.dataset.split('_')[1])
        except:
            logger.log("ERROR: artificial dataset number not recognized. Input was "+str(state.dataset))
            raise AssertionError("artificial dataset number not recognized. Input was "+str(state.dataset))
    else:
        logger.log("ERROR: dataset not recognized.")
        raise AssertionError("dataset not recognized.")
    
    train_X = theano.shared(train_X)
    train_Y = theano.shared(train_Y)
    valid_X = theano.shared(valid_X)
    valid_Y = theano.shared(valid_Y) 
    test_X = theano.shared(test_X)
    test_Y = theano.shared(test_Y) 
   
    if artificial:
        logger.log('Sequencing MNIST data...')
        logger.log(['train set size:',len(train_Y.eval())])
        logger.log(['valid set size:',len(valid_Y.eval())])
        logger.log(['test set size:',len(test_Y.eval())])
        data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
        logger.log(['train set size:',len(train_Y.eval())])
        logger.log(['valid set size:',len(valid_Y.eval())])
        logger.log(['test set size:',len(test_Y.eval())])
        logger.log('Sequencing done.\n')
    
    
    N_input =   train_X.eval().shape[1]
    root_N_input = numpy.sqrt(N_input)  
    
    # Network and training specifications
    layers      = state.layers # number hidden layers
    walkbacks   = state.walkbacks # number of walkbacks 
    layer_sizes = [N_input] + [state.hidden_size] * layers # layer sizes, from h0 to hK (h0 is the visible layer)
    
    learning_rate = theano.shared(cast32(state.learning_rate))  # learning rate
    annealing     = cast32(state.annealing) # exponential annealing coefficient
    momentum      = theano.shared(cast32(state.momentum)) # momentum term 

    ##############
    # PARAMETERS #
    ##############
    #gsn
    weights_list = [get_shared_weights(layer_sizes[i], layer_sizes[i+1], name="W_{0!s}_{1!s}".format(i,i+1)) for i in range(layers)] # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
    bias_list    = [get_shared_bias(layer_sizes[i], name='b_'+str(i)) for i in range(layers + 1)] # initialize each layer to 0's.
    
    #recurrent
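    # W_u_h{layer}: projects the recurrent hidden state u into each odd GSN layer,
    # W_u_u: recurrent state-to-state transition, W_x_u: input-to-recurrent projection, b_u: recurrent bias.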
    recurrent_to_gsn_weights_list = [get_shared_weights(state.recurrent_hidden_size, layer_sizes[layer], name="W_u_h{0!s}".format(layer)) for layer in range(layers+1) if (layer%2) != 0]
    W_u_u = get_shared_weights(state.recurrent_hidden_size, state.recurrent_hidden_size, name="W_u_u")
    W_x_u = get_shared_weights(N_input, state.recurrent_hidden_size, name="W_x_u")
    recurrent_bias = get_shared_bias(state.recurrent_hidden_size, name='b_u')
    
    #lists for use with gradients
    gsn_params = weights_list + bias_list
    u_params   = [W_u_u, W_x_u, recurrent_bias]
    params     = gsn_params + recurrent_to_gsn_weights_list + u_params
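    # gsn_params are pretrained on their own in train_GSN(); the full params list
    # (GSN weights/biases plus the recurrent connections) is what train_RNN_GSN() updates jointly.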
    
    ###########################################################
    # load initial parameters of gsn to speed up my debugging #
    ###########################################################
    params_to_load = 'gsn_params.pkl'
    initialized_gsn = False
    if os.path.isfile(params_to_load):
        logger.log("\nLoading existing GSN parameters\n")
        loaded_params = cPickle.load(open(params_to_load,'rb'))
        [p.set_value(lp.get_value(borrow=False)) for lp, p in zip(loaded_params[:len(weights_list)], weights_list)]
        [p.set_value(lp.get_value(borrow=False)) for lp, p in zip(loaded_params[len(weights_list):], bias_list)]
        initialized_gsn = True
    
    
    ############################
    # Theano variables and RNG #
    ############################
    MRG = RNG_MRG.MRG_RandomStreams(1)
    X = T.fmatrix('X') #single (batch) for training gsn
    Xs = T.fmatrix(name="Xs") #sequence for training rnn-gsn
    
 
    ########################
    # ACTIVATION FUNCTIONS #
    ########################
    # hidden activation
    if state.hidden_act == 'sigmoid':
        logger.log('Using sigmoid activation for hiddens')
        hidden_activation = T.nnet.sigmoid
    elif state.hidden_act == 'rectifier':
        logger.log('Using rectifier activation for hiddens')
        hidden_activation = lambda x : T.maximum(cast32(0), x)
    elif state.hidden_act == 'tanh':
        logger.log('Using hyperbolic tangent activation for hiddens')
        hidden_activation = lambda x : T.tanh(x)
    else:
        logger.log("ERROR: Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid".format(state.hidden_act))
        raise AssertionError("Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid".format(state.hidden_act))
    
    # visible activation
    if state.visible_act == 'sigmoid':
        logger.log('Using sigmoid activation for visible layer')
        visible_activation = T.nnet.sigmoid
    elif state.visible_act == 'softmax':
        logger.log('Using softmax activation for visible layer')
        visible_activation = T.nnet.softmax
    else:
        logger.log("ERROR: Did not recognize visible activation {0!s}, please use sigmoid or softmax".format(state.visible_act))
        raise AssertionError("Did not recognize visible activation {0!s}, please use sigmoid or softmax".format(state.visible_act))
    
    # recurrent activation
    if state.recurrent_hidden_act == 'sigmoid':
        logger.log('Using sigmoid activation for recurrent hiddens')
        recurrent_hidden_activation = T.nnet.sigmoid
    elif state.recurrent_hidden_act == 'rectifier':
        logger.log('Using rectifier activation for recurrent hiddens')
        recurrent_hidden_activation = lambda x : T.maximum(cast32(0), x)
    elif state.recurrent_hidden_act == 'tanh':
        logger.log('Using hyperbolic tangent activation for recurrent hiddens')
        recurrent_hidden_activation = lambda x : T.tanh(x)
    else:
        logger.log("ERROR: Did not recognize recurrent hidden activation {0!s}, please use tanh, rectifier, or sigmoid".format(state.recurrent_hidden_act))
        raise AssertionError("Did not recognize recurrent hidden activation {0!s}, please use tanh, rectifier, or sigmoid".format(state.recurrent_hidden_act))
    
    logger.log("\n")
    
    ####################
    #  COST FUNCTIONS  #
    ####################
    if state.cost_funct == 'binary_crossentropy':
        logger.log('Using binary cross-entropy cost!')
        cost_function = lambda x,y: T.mean(T.nnet.binary_crossentropy(x,y))
    elif state.cost_funct == 'square':
        logger.log("Using square error cost!")
        #cost_function = lambda x,y: T.log(T.mean(T.sqr(x-y)))
        cost_function = lambda x,y: T.log(T.sum(T.pow((x-y),2)))
    else:
        logger.log("ERROR: Did not recognize cost function {0!s}, please use binary_crossentropy or square".format(state.cost_funct))
        raise AssertionError("Did not recognize cost function {0!s}, please use binary_crossentropy or square".format(state.cost_funct))
    
    logger.log("\n")  
        
    ################################################
    #  COMPUTATIONAL GRAPH HELPER METHODS FOR GSN  #
    ################################################
    def update_layers(hiddens, p_X_chain, noisy = True):
        logger.log('odd layer updates')
        update_odd_layers(hiddens, noisy)
        logger.log('even layer updates')
        update_even_layers(hiddens, p_X_chain, noisy)
        logger.log('done full update.\n')
        
    def update_layers_reverse(hiddens, p_X_chain, noisy = True):
        logger.log('even layer updates')
        update_even_layers(hiddens, p_X_chain, noisy)
        logger.log('odd layer updates')
        update_odd_layers(hiddens, noisy)
        logger.log('done full update.\n')
        
    # Odd layer update function
    # just a loop over the odd layers
    def update_odd_layers(hiddens, noisy):
        for i in range(1, len(hiddens), 2):
            logger.log(['updating layer',i])
            simple_update_layer(hiddens, None, i, add_noise = noisy)
    
    # Even layer update
    # p_X_chain is given to append the p(X|...) at each full update (one update = odd update + even update)
    def update_even_layers(hiddens, p_X_chain, noisy):
        for i in range(0, len(hiddens), 2):
            logger.log(['updating layer',i])
            simple_update_layer(hiddens, p_X_chain, i, add_noise = noisy)
    
    # The layer update function
    # hiddens   :   list containing the symbolic theano variables [visible, hidden1, hidden2, ...]
    #               layer_update will modify this list inplace
    # p_X_chain :   list containing the successive p(X|...) at each update
    #               update_layer will append to this list
    # add_noise     : pre and post activation gaussian noise
    
    def simple_update_layer(hiddens, p_X_chain, i, add_noise=True):   
        # Compute the dot product, whatever layer
        # If the visible layer X
        if i == 0:
            logger.log('using '+str(weights_list[i])+'.T')
            hiddens[i]  =   T.dot(hiddens[i+1], weights_list[i].T) + bias_list[i]           
        # If the top layer
        elif i == len(hiddens)-1:
            logger.log(['using',weights_list[i-1]])
            hiddens[i]  =   T.dot(hiddens[i-1], weights_list[i-1]) + bias_list[i]
        # Otherwise in-between layers
        else:
            logger.log("using {0!s} and {1!s}.T".format(weights_list[i-1], weights_list[i]))
            # next layer        :   hiddens[i+1], assigned weights : W_i
            # previous layer    :   hiddens[i-1], assigned weights : W_(i-1)
            hiddens[i]  =   T.dot(hiddens[i+1], weights_list[i].T) + T.dot(hiddens[i-1], weights_list[i-1]) + bias_list[i]
    
        # Add pre-activation noise if NOT input layer
        if i==1 and state.noiseless_h1:
            logger.log('>>NO noise in first hidden layer')
            add_noise   =   False
    
        # pre activation noise            
        if i != 0 and add_noise:
            logger.log(['Adding pre-activation gaussian noise for layer', i])
            hiddens[i] = add_gaussian_noise(hiddens[i], state.hidden_add_noise_sigma)
       
        # ACTIVATION!
        if i == 0:
            logger.log('{} activation for visible layer'.format(state.visible_act))
            hiddens[i] = visible_activation(hiddens[i])
        else:
            logger.log(['Hidden units {} activation for layer'.format(state.hidden_act), i])
            hiddens[i] = hidden_activation(hiddens[i])
    
        # post activation noise
        # why is there post activation noise? Because there is already pre-activation noise, this just doubles the amount of noise between each activation of the hiddens.  
        if i != 0 and add_noise:
            logger.log(['Adding post-activation gaussian noise for layer', i])
            hiddens[i]  =   add_gaussian_noise(hiddens[i], state.hidden_add_noise_sigma)
    
        # build the reconstruction chain if updating the visible layer X
        if i == 0:
            # if input layer -> append p(X|...)
            p_X_chain.append(hiddens[i])
            
            # sample from p(X|...) - SAMPLING NEEDS TO BE CORRECT FOR INPUT TYPES I.E. FOR BINARY MNIST SAMPLING IS BINOMIAL. real-valued inputs should be gaussian
            if state.input_sampling:
                logger.log('Sampling from input')
                sampled = MRG.binomial(p = hiddens[i], size=hiddens[i].shape, dtype='float32')
            else:
                logger.log('>>NO input sampling')
                sampled = hiddens[i]
            # add noise
            sampled = salt_and_pepper(sampled, state.input_salt_and_pepper)
            
            # set input layer
            hiddens[i] = sampled
                
    ##############################################
    #    Build the training graph for the GSN    #
    ##############################################
    # the loop step for the rnn-gsn, return the sample and the costs
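    # Given x_t and the previous recurrent state u_tm1, the odd GSN layers are initialized from u_tm1
    # (through the W_u_h weights) and the even layers to zeros, then `walkbacks` reverse updates
    # reconstruct x_t. The cost is summed over every p(X|...) in the chain; the last one is reported.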
    def create_gsn_reverse(x_t, u_tm1, noiseflag=True):
        chain = []
        # init hiddens from the u
        hiddens_t = [T.zeros_like(x_t)]
        for layer, w in enumerate(weights_list):
            layer = layer+1
            # if this is an even layer, just append zeros
            if layer%2 == 0:
                hiddens_t.append(T.zeros_like(T.dot(hiddens_t[-1], w)))
            # if it is an odd layer, use the rnn to determine the layer
            else:
                hiddens_t.append(hidden_activation(T.dot(u_tm1, recurrent_to_gsn_weights_list[layer/2]) + bias_list[layer]))
                
        for i in range(walkbacks):
            logger.log("Reverse Walkback {!s}/{!s} for RNN-GSN".format(i+1,walkbacks))
            update_layers_reverse(hiddens_t, chain, noiseflag)
        
        x_sample  = chain[-1]
        costs     = [cost_function(rX, x_t) for rX in chain]
        show_cost = costs[-1]
        cost      = T.sum(costs)
        
        return x_sample, cost, show_cost
        
    # the GSN graph for the rnn-gsn
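    # A zero state u0 is prepended so the first x in the sequence is predicted from an empty history;
    # scan then pairs each x_t with u_{t-1} and runs create_gsn_reverse over the whole sequence.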
    def build_gsn_given_u(xs, u, noiseflag=True):
        logger.log("Creating recurrent gsn step scan.\n")
        u0 = T.zeros((1,state.recurrent_hidden_size))
        if u is None:
            u = u0
        else:
            u = T.concatenate([u0,u]) #add the initial u condition to the list of u's created from the recurrent scan op.
        (samples, costs, show_costs), updates = theano.scan(lambda x_t, u_tm1: create_gsn_reverse(x_t, u_tm1, noiseflag),
                                                            sequences = [xs, u])
        cost = T.sum(costs)
        show_cost = T.mean(show_costs)
        last_sample = samples[-1]
        
        return last_sample, cost, show_cost, updates
    
    def build_gsn_given_u0(x, u0, noiseflag=True):
        x_sample, _, _ = create_gsn_reverse(x, u0, noiseflag)
        return x_sample
    
    # the GSN graph for initial GSN training
    def build_gsn_graph(x, noiseflag):
        p_X_chain = []
        if noiseflag:
            X_init = salt_and_pepper(x, state.input_salt_and_pepper)
        else:
            X_init = x
        # init hiddens with zeros
        hiddens = [X_init]
        for w in weights_list:
            hiddens.append(T.zeros_like(T.dot(hiddens[-1], w)))
        # The layer update scheme
        logger.log(["Building the gsn graph :", walkbacks,"updates"])
        for i in range(walkbacks):
            logger.log("GSN Walkback {!s}/{!s}".format(i+1,walkbacks))
            update_layers(hiddens, p_X_chain, noisy=noiseflag)
            
        return p_X_chain
    
    '''Build the actual gsn training graph'''
    p_X_chain_gsn = build_gsn_graph(X, noiseflag=True)
    
    
    ##############################################
    #  Build the training graph for the RNN-GSN  #
    ##############################################
    # Deterministic recurrence: compute the new recurrent state u_t from the current input x_t and the previous state u_tm1.
    def recurrent_step(x_t, u_tm1):
        ua_t = T.dot(x_t, W_x_u) + T.dot(u_tm1, W_u_u) + recurrent_bias
        u_t = recurrent_hidden_activation(ua_t)
        return ua_t, u_t
    
    logger.log("\nCreating recurrent step scan.")
    # For training, the deterministic recurrence is used to compute all the
    # {u_t, 1 <= t <= T} given Xs. Conditional GSNs can then be trained
    # in batches using those parameters.
    u0 = T.zeros((state.recurrent_hidden_size,))  # initial value for the RNN hidden units
    (_, u), updates_recurrent = theano.scan(lambda x_t, u_tm1: recurrent_step(x_t, u_tm1),
                                   sequences=Xs,
                                   outputs_info=[None, u0])
    
    _, cost, show_cost, updates_gsn = build_gsn_given_u(Xs, u, noiseflag=True)
    
    updates_recurrent.update(updates_gsn)
    
    updates_train = updates_recurrent
    updates_cost = updates_recurrent
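    # the merged scan updates (including the MRG random-stream states used for the GSN noise)
    # are needed by both the learning and the cost functions, so the same dictionary is reused for each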
    
    ################################################
    #  Build the checkpoint graph for the RNN-GSN  #
    ################################################
    # Used to generate the next predicted output given all previous inputs - starting with nothing
    # When there is no X history
    x_sample_R_init = build_gsn_given_u0(X, u0, noiseflag=False)
    # When there is some number of Xs history
    x_sample_R, _, _, updates_gsn_R = build_gsn_given_u(Xs, u, noiseflag=False)
        

    #############
    #   COSTS   #
    #############
    logger.log("")    
    logger.log('Cost w.r.t p(X|...) at every step in the graph')
    
    gsn_costs     = [cost_function(rX, X) for rX in p_X_chain_gsn]
    gsn_show_cost = gsn_costs[-1]
    gsn_cost      = numpy.sum(gsn_costs)
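    # summing over the whole chain gives a training signal at every walkback step;
    # only the final step's cost (gsn_show_cost) is reported for monitoring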
            

    ###################################
    # GRADIENTS AND FUNCTIONS FOR GSN #
    ###################################
    logger.log(["params:",params])
    
    logger.log("creating functions...")
    start_functions_time = time.time()
    
    gradient_gsn        = T.grad(gsn_cost, gsn_params)      
    gradient_buffer_gsn = [theano.shared(numpy.zeros(param.get_value().shape, dtype='float32')) for param in gsn_params]
    
    # momentum: keep an exponential moving average of the gradient in a buffer and step against the smoothed value
    m_gradient_gsn    = [momentum * gb + (cast32(1) - momentum) * g for (gb, g) in zip(gradient_buffer_gsn, gradient_gsn)]
    param_updates_gsn = [(param, param - learning_rate * mg) for (param, mg) in zip(gsn_params, m_gradient_gsn)]
    gradient_buffer_updates_gsn = zip(gradient_buffer_gsn, m_gradient_gsn)
        
    grad_updates_gsn = OrderedDict(param_updates_gsn + gradient_buffer_updates_gsn)
    
    f_cost_gsn = theano.function(inputs  = [X], 
                                 outputs = gsn_show_cost, 
                                 on_unused_input='warn')

    f_learn_gsn = theano.function(inputs  = [X],
                                  updates = grad_updates_gsn,
                                  outputs = gsn_show_cost,
                                  on_unused_input='warn')
    
    #######################################
    # GRADIENTS AND FUNCTIONS FOR RNN-GSN #
    #######################################
    # if we are not using Hessian-free training create the normal sgd functions
    if state.hf == 0:
        gradient      = T.grad(cost, params)      
        gradient_buffer = [theano.shared(numpy.zeros(param.get_value().shape, dtype='float32')) for param in params]
        
        m_gradient    = [momentum * gb + (cast32(1) - momentum) * g for (gb, g) in zip(gradient_buffer, gradient)]
        param_updates = [(param, param - learning_rate * mg) for (param, mg) in zip(params, m_gradient)]
        gradient_buffer_updates = zip(gradient_buffer, m_gradient)
            
        updates = OrderedDict(param_updates + gradient_buffer_updates)
        updates_train.update(updates)
    
        f_learn = theano.function(inputs  = [Xs],
                                  updates = updates_train,
                                  outputs = show_cost,
                                  on_unused_input='warn')
        
                
        f_cost  = theano.function(inputs  = [Xs], 
                                  updates = updates_cost,
                                  outputs = show_cost, 
                                  on_unused_input='warn')
    
    logger.log("Training/cost functions done.")
    compilation_time = time.time() - start_functions_time
    # Show the compile time with appropriate easy-to-read units.
    if compilation_time < 60:
        logger.log(["Compilation took",compilation_time,"seconds.\n\n"])
    elif compilation_time < 3600:
        logger.log(["Compilation took",compilation_time/60,"minutes.\n\n"])
    else:
        logger.log(["Compilation took",compilation_time/3600,"hours.\n\n"])
    
    ############################################################################################
    # Denoise some numbers : show number, noisy number, predicted number, reconstructed number #
    ############################################################################################   
    # Recompile the graph without noise for reconstruction function
    # The layer update scheme
    logger.log("Creating graph for noisy reconstruction function at checkpoints during training.")
    f_recon_init = theano.function(inputs=[X], outputs=x_sample_R_init, on_unused_input='warn')
    f_recon = theano.function(inputs=[Xs], outputs=x_sample_R, updates=updates_gsn_R)
    
    # Now do the same but for the GSN in the initial run
    p_X_chain_R = build_gsn_graph(X, noiseflag=False)
    f_recon_gsn = theano.function(inputs=[X], outputs = p_X_chain_R[-1])

    logger.log("Done compiling all functions.")
    compilation_time = time.time() - start_functions_time
    # Show the compile time with appropriate easy-to-read units.
    if compilation_time < 60:
        logger.log(["Total compilation took",compilation_time,"seconds.\n\n"])
    elif compilation_time < 3600:
        logger.log(["Total compilation took",compilation_time/60,"minutes.\n\n"])
    else:
        logger.log(["Total compilation took",compilation_time/3600,"hours.\n\n"])

    ############
    # Sampling #
    ############
    # a function to add salt and pepper noise
    f_noise = theano.function(inputs = [X], outputs = salt_and_pepper(X, state.input_salt_and_pepper))
    # the input to the sampling function
    X_sample = T.fmatrix("X_sampling")
    network_state_input     =   [X_sample] + [T.fmatrix("H_sampling_"+str(i+1)) for i in range(layers)]
   
    # "Output" state of the network (noisy)
    # initialized with input, then we apply updates
    
    network_state_output    =   [X_sample] + network_state_input[1:]

    visible_pX_chain        =   []

    # ONE update
    logger.log("Performing one walkback in network state sampling.")
    update_layers(network_state_output, visible_pX_chain, noisy=True)

    if layers == 1: 
        f_sample_simple = theano.function(inputs = [X_sample], outputs = visible_pX_chain[-1])
    
    
    # on_unused_input='warn' because the first odd-layer inputs are not used - they are computed directly from the even layers
    f_sample2   =   theano.function(inputs = network_state_input, outputs = network_state_output + visible_pX_chain, on_unused_input='warn')

    def sample_some_numbers_single_layer():
        x0    =   test_X.get_value()[:1]
        samples = [x0]
        x  =   f_noise(x0)
        for i in range(399):
            x = f_sample_simple(x)
            samples.append(x)
            x = numpy.random.binomial(n=1, p=x, size=x.shape).astype('float32')
            x = f_noise(x)
        return numpy.vstack(samples)
            
    def sampling_wrapper(NSI):
        # * is the "splat" operator: It takes a list as input, and expands it into actual positional arguments in the function call.
        out             =   f_sample2(*NSI)
        NSO             =   out[:len(network_state_output)]
        vis_pX_chain    =   out[len(network_state_output):]
        return NSO, vis_pX_chain

    def sample_some_numbers(N=400):
        # The network's initial state
        init_vis        =   test_X.get_value()[:1]

        noisy_init_vis  =   f_noise(init_vis)

        network_state   =   [[noisy_init_vis] + [numpy.zeros((1,len(b.get_value())), dtype='float32') for b in bias_list[1:]]]

        visible_chain   =   [init_vis]

        noisy_h0_chain  =   [noisy_init_vis]

        for i in range(N-1):
           
            # feed the last state into the network, compute new state, and obtain visible units expectation chain 
            net_state_out, vis_pX_chain =   sampling_wrapper(network_state[-1])

            # append to the visible chain
            visible_chain   +=  vis_pX_chain

            # append state output to the network state chain
            network_state.append(net_state_out)
            
            noisy_h0_chain.append(net_state_out[0])

        return numpy.vstack(visible_chain), numpy.vstack(noisy_h0_chain)
    
    def plot_samples(epoch_number, leading_text):
        to_sample = time.time()
        if layers == 1:
            # one layer model
            V = sample_some_numbers_single_layer()
        else:
            V, H0 = sample_some_numbers()
        img_samples =   PIL.Image.fromarray(tile_raster_images(V, (root_N_input,root_N_input), (20,20)))
        
        fname       =   outdir+leading_text+'samples_epoch_'+str(epoch_number)+'.png'
        img_samples.save(fname) 
        logger.log('Took ' + str(time.time() - to_sample) + ' to sample 400 numbers')
   
    #############################
    # Save the model parameters #
    #############################
    def save_params_to_file(name, n, gsn_params):
        print 'saving parameters...'
        save_path = outdir+name+'_params_epoch_'+str(n)+'.pkl'
        f = open(save_path, 'wb')
        try:
            cPickle.dump(gsn_params, f, protocol=cPickle.HIGHEST_PROTOCOL)
        finally:
            f.close()
            
    def save_params(params):
        values = [param.get_value(borrow=True) for param in params]
        return values
    
    def restore_params(params, values):
        for i in range(len(params)):
            params[i].set_value(values[i])

    ################
    # GSN TRAINING #
    ################
    def train_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        logger.log("\n-----------TRAINING GSN------------\n")
        
        # TRAINING
        n_epoch     =   state.n_epoch
        batch_size  =   state.gsn_batch_size
        STOP        =   False
        counter     =   0
        learning_rate.set_value(cast32(state.learning_rate))  # learning rate
        times = []
        best_cost = float('inf')
        best_params = None
        patience = 0
                    
        logger.log(['train X size:',str(train_X.shape.eval())])
        logger.log(['valid X size:',str(valid_X.shape.eval())])
        logger.log(['test X size:',str(test_X.shape.eval())])
        
        if state.vis_init:
            bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))
    
        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            logger.log('Testing : skip training')
            STOP    =   True
    
        while not STOP:
            counter += 1
            t = time.time()
            logger.append([counter,'\t'])
                
            #shuffle the data
            data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
                
            #train
            train_costs = []
            for i in xrange(len(train_X.get_value(borrow=True)) / batch_size):
                x = train_X.get_value()[i * batch_size : (i+1) * batch_size]
                cost = f_learn_gsn(x)
                train_costs.append([cost])
            train_costs = numpy.mean(train_costs)
            # record it
            logger.append(['Train:',trunc(train_costs),'\t'])
            with open(gsn_train_convergence,'a') as f:
                f.write("{0!s},".format(train_costs))
                f.write("\n")
    
    
            #valid
            valid_costs = []
            for i in xrange(len(valid_X.get_value(borrow=True)) / batch_size):
                x = valid_X.get_value()[i * batch_size : (i+1) * batch_size]
                cost = f_cost_gsn(x)
                valid_costs.append([cost])                    
            valid_costs = numpy.mean(valid_costs)
            # record it
            logger.append(['Valid:',trunc(valid_costs), '\t'])
            with open(gsn_valid_convergence,'a') as f:
                f.write("{0!s},".format(valid_costs))
                f.write("\n")
    
    
            #test
            test_costs = []
            for i in xrange(len(test_X.get_value(borrow=True)) / batch_size):
                x = test_X.get_value()[i * batch_size : (i+1) * batch_size]
                cost = f_cost_gsn(x)
                test_costs.append([cost])                
            test_costs = numpy.mean(test_costs)
            # record it 
            logger.append(['Test:',trunc(test_costs), '\t'])
            with open(gsn_test_convergence,'a') as f:
                f.write("{0!s},".format(test_costs))
                f.write("\n")
            
            
            #check for early stopping
            cost = numpy.sum(valid_costs)
            if cost < best_cost*state.early_stop_threshold:
                patience = 0
                best_cost = cost
                # save the parameters that made it the best
                best_params = save_params(gsn_params)
            else:
                patience += 1
    
            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                if best_params is not None:
                    restore_params(gsn_params, best_params)
                save_params_to_file('gsn', counter, gsn_params)
    
            timing = time.time() - t
            times.append(timing)
    
            logger.append(['time:', trunc(timing)])
            
            logger.log(['remaining:', trunc((n_epoch - counter) * numpy.mean(times) / 60 / 60), 'hrs'])
    
            if (counter % state.save_frequency) == 0 or STOP is True:
                n_examples = 100
                random_idx = numpy.array(R.sample(range(len(test_X.get_value(borrow=True))), n_examples))
                numbers = test_X.get_value(borrow=True)[random_idx]
                noisy_numbers = f_noise(test_X.get_value(borrow=True)[random_idx])
                reconstructed = f_recon_gsn(noisy_numbers) 
                # Concatenate stuff
                stacked = numpy.vstack([numpy.vstack([numbers[i*10 : (i+1)*10], noisy_numbers[i*10 : (i+1)*10], reconstructed[i*10 : (i+1)*10]]) for i in range(10)])
                number_reconstruction = PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,30)))
                    
                number_reconstruction.save(outdir+'gsn_number_reconstruction_epoch_'+str(counter)+'.png')
        
                #sample_numbers(counter, 'seven')
                plot_samples(counter, 'gsn')
        
                #save gsn_params
                save_params_to_file('gsn', counter, gsn_params)
         
            # ANNEAL!
            new_lr = learning_rate.get_value() * annealing
            learning_rate.set_value(new_lr)

        
        # 10k samples
        print 'Generating 10,000 samples'
        samples, _  =   sample_some_numbers(N=10000)
        f_samples   =   outdir+'samples.npy'
        numpy.save(f_samples, samples)
        print 'saved digits'
        
    def train_RNN_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        # If we are using Hessian-free training
        if state.hf == 1:
            pass
#         gradient_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=5000)
#         cg_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=1000)
#         valid_dataset = hf_sequence_dataset([valid_X.get_value()], batch_size=None, number_batches=1000)
#         
#         s = x_samples
#         costs = [cost, show_cost]
#         hf_optimizer(params, [Xs], s, costs, u, ua).train(gradient_dataset, cg_dataset, initial_lambda=1.0, preconditioner=True, validation=valid_dataset)
        
        # If we are using SGD training
        else:
            logger.log("\n-----------TRAINING RNN-GSN------------\n")
            # TRAINING
            n_epoch     =   state.n_epoch
            batch_size  =   state.batch_size
            STOP        =   False
            counter     =   0
            learning_rate.set_value(cast32(state.learning_rate))  # learning rate
            times = []
            best_cost = float('inf')
            best_params = None
            patience = 0
                        
            logger.log(['train X size:',str(train_X.shape.eval())])
            logger.log(['valid X size:',str(valid_X.shape.eval())])
            logger.log(['test X size:',str(test_X.shape.eval())])
            
            if state.vis_init:
                bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))
        
            if state.test_model:
                # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
                logger.log('Testing : skip training')
                STOP    =   True
        
            while not STOP:
                counter += 1
                t = time.time()
                logger.append([counter,'\t'])
                    
                #shuffle the data
                data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
                     
                #train
                train_costs = []
                for i in xrange(len(train_X.get_value(borrow=True)) / batch_size):
                    xs = train_X.get_value(borrow=True)[i * batch_size : (i+1) * batch_size]
                    cost = f_learn(xs)
                    train_costs.append([cost])
                train_costs = numpy.mean(train_costs)
                # record it
                logger.append(['Train:',trunc(train_costs),'\t'])
                with open(train_convergence,'a') as f:
                    f.write("{0!s},".format(train_costs))
                    f.write("\n")
         
         
                #valid
                valid_costs = []
                for i in xrange(len(valid_X.get_value(borrow=True)) / batch_size):
                    xs = valid_X.get_value(borrow=True)[i * batch_size : (i+1) * batch_size]
                    cost = f_cost(xs)
                    valid_costs.append([cost])                    
                valid_costs = numpy.mean(valid_costs)
                # record it
                logger.append(['Valid:',trunc(valid_costs), '\t'])
                with open(valid_convergence,'a') as f:
                    f.write("{0!s},".format(valid_costs))
                    f.write("\n")
         
         
                #test
                test_costs = []
                for i in xrange(len(test_X.get_value(borrow=True)) / batch_size):
                    xs = test_X.get_value(borrow=True)[i * batch_size : (i+1) * batch_size]
                    cost = f_cost(xs)
                    test_costs.append([cost])                
                test_costs = numpy.mean(test_costs)
                # record it 
                logger.append(['Test:',trunc(test_costs), '\t'])
                with open(test_convergence,'a') as f:
                    f.write("{0!s},".format(test_costs))
                    f.write("\n")
                 
                 
                #check for early stopping
                cost = numpy.sum(valid_costs)
                if cost < best_cost*state.early_stop_threshold:
                    patience = 0
                    best_cost = cost
                    # save the parameters that made it the best
                    best_params = save_params(params)
                else:
                    patience += 1
         
                if counter >= n_epoch or patience >= state.early_stop_length:
                    STOP = True
                    if best_params is not None:
                        restore_params(params, best_params)
                    save_params_to_file('all', counter, params)
         
                timing = time.time() - t
                times.append(timing)
         
                logger.append(['time:', trunc(timing)])
                 
                logger.log(['remaining:', trunc((n_epoch - counter) * numpy.mean(times) / 60 / 60), 'hrs'])
        
                if (counter % state.save_frequency) == 0 or STOP is True:
                    n_examples = 100
                    nums = test_X.get_value(borrow=True)[range(n_examples)]
                    noisy_nums = f_noise(test_X.get_value(borrow=True)[range(n_examples)])
                    reconstructions = []
                    for i in xrange(0, len(noisy_nums)):
                        if i == 0:
                            recon = f_recon_init(noisy_nums[:i+1])
                        else:
                            recon = f_recon(noisy_nums[max(0,(i+1)-batch_size):i+1])
                        reconstructions.append(recon)
                    reconstructed = numpy.array(reconstructions)

                    # Concatenate stuff
                    stacked = numpy.vstack([numpy.vstack([nums[i*10 : (i+1)*10], noisy_nums[i*10 : (i+1)*10], reconstructed[i*10 : (i+1)*10]]) for i in range(10)])
                    number_reconstruction = PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,30)))
                        
                    number_reconstruction.save(outdir+'rnngsn_number_reconstruction_epoch_'+str(counter)+'.png')
            
                    #sample_numbers(counter, 'seven')
                    plot_samples(counter, 'rnngsn')
            
                    #save params
                    save_params_to_file('all', counter, params)
             
                # ANNEAL!
                new_lr = learning_rate.get_value() * annealing
                learning_rate.set_value(new_lr)
    
            
            # 10k samples
            print 'Generating 10,000 samples'
            samples, _  =   sample_some_numbers(N=10000)
            f_samples   =   outdir+'samples.npy'
            numpy.save(f_samples, samples)
            print 'saved digits'
            
            
            
    #####################
    # STORY 2 ALGORITHM #
    #####################
    # train the GSN parameters first to get a good baseline (if not loaded from parameter .pkl file)
    if initialized_gsn is False:
        train_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
    # train the entire RNN-GSN
    train_RNN_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
    def train(self, batch_size=100, num_epochs=300):
        '''Train the model via stochastic gradient descent (SGD) on sequenced MNIST data.

batch_size : integer
  Training sequences will be split into subsequences of at most this size
  before applying the SGD updates.
num_epochs : integer
  Number of epochs (passes over the training set) performed. The user can
  safely interrupt training with Ctrl+C at any time.'''

        (train_X, train_Y), (valid_X, valid_Y), (test_X, test_Y) = data.load_mnist("../datasets/")
        train_X = numpy.concatenate((train_X, valid_X))
        train_Y = numpy.concatenate((train_Y, valid_Y))

        print 'Sequencing MNIST data...'
        print 'train set size:', train_X.shape
        print 'valid set size:', valid_X.shape
        print 'test set size:', test_X.shape

        train_X = theano.shared(train_X)
        train_Y = theano.shared(train_Y)
        valid_X = theano.shared(valid_X)
        valid_Y = theano.shared(valid_Y)
        test_X = theano.shared(test_X)
        test_Y = theano.shared(test_Y)

        data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset=4)

        print 'train set size:', train_X.shape.eval()
        print 'valid set size:', valid_X.shape.eval()
        print 'test set size:', test_X.shape.eval()
        print 'Sequencing done.'
        print

        N_input = train_X.eval().shape[1]
        self.root_N_input = numpy.sqrt(N_input)

        times = []

        try:
            for epoch in xrange(num_epochs):
                t = time.time()
                print 'Epoch %i/%i : ' % (epoch + 1, num_epochs)
                # sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
                accuracy = []
                costs = []
                crossentropy = []
                tests = []
                test_acc = []

                for i in range(len(train_X.get_value(borrow=True)) / batch_size):
                    t0=time.time()
                    xs = train_X.get_value(borrow=True)[(i * batch_size): ((i + 1) * batch_size)]
                    acc, cost, cross = self.train_function(xs)
                    accuracy.append(acc)
                    costs.append(cost)
                    crossentropy.append(cross)
                    print time.time()-t0
                print 'Train',numpy.mean(accuracy), 'cost', numpy.mean(costs), 'cross', numpy.mean(crossentropy),
                    
                for i in range(len(test_X.get_value(borrow=True)) / batch_size):
                    xs = test_X.get_value(borrow=True)[(i * batch_size): ((i + 1) * batch_size)]
                    acc, cost = self.test_function(xs)
                    test_acc.append(acc)
                    tests.append(cost)

                print '\t Test_acc', numpy.mean(test_acc), "cross", numpy.mean(tests)

                timing = time.time() - t
                times.append(timing)
                print 'time : ', trunc(timing),
                print 'remaining: ', (num_epochs - (epoch + 1)) * numpy.mean(times) / 60 / 60, 'hrs'
                sys.stdout.flush()

                # new learning rate
                new_lr = self.lr.get_value() * self.annealing
                self.lr.set_value(new_lr)

        except KeyboardInterrupt:
            print 'Interrupted by user.'
def experiment(state, outdir_base='./'):
    rng.seed(1)  # seed the numpy random generator
    # Initialize output directory and files
    data.mkdir_p(outdir_base)
    outdir = outdir_base + "/" + state.dataset + "/"
    data.mkdir_p(outdir)
    logfile = outdir + "log.txt"
    with open(logfile, 'w') as f:
        f.write("MODEL 2, {0!s}\n\n".format(state.dataset))
    train_convergence_pre = outdir + "train_convergence_pre.csv"
    train_convergence_post = outdir + "train_convergence_post.csv"
    valid_convergence_pre = outdir + "valid_convergence_pre.csv"
    valid_convergence_post = outdir + "valid_convergence_post.csv"
    test_convergence_pre = outdir + "test_convergence_pre.csv"
    test_convergence_post = outdir + "test_convergence_post.csv"

    print
    print "----------MODEL 2, {0!s}--------------".format(state.dataset)
    print

    # load parameters from config file if this is a test
    config_filename = outdir + 'config'
    if state.test_model and 'config' in os.listdir(outdir):
        config_vals = load_from_config(config_filename)
        for CV in config_vals:
            print CV
            if CV.startswith('test'):
                print 'Do not override testing switch'
                continue
            try:
                exec('state.' + CV) in globals(), locals()
            except:
                exec('state.' + CV.split('=')[0] + "='" + CV.split('=')[1] + "'") in globals(), locals()
    else:
        # Save the current configuration
        # Useful for logs/experiments
        print 'Saving config'
        with open(config_filename, 'w') as f:
            f.write(str(state))

    print state
    # Load the data, train = train+valid, and sequence
    artificial = False
    if state.dataset == 'MNIST_1' or state.dataset == 'MNIST_2' or state.dataset == 'MNIST_3':
        (train_X, train_Y), (valid_X, valid_Y), (test_X, test_Y) = data.load_mnist(state.data_path)
        train_X = numpy.concatenate((train_X, valid_X))
        train_Y = numpy.concatenate((train_Y, valid_Y))
        artificial = True
        try:
            dataset = int(state.dataset.split('_')[1])
        except:
            raise AssertionError("artificial dataset number not recognized. Input was " + state.dataset)
    else:
        raise AssertionError("dataset not recognized.")

    train_X = theano.shared(train_X)
    train_Y = theano.shared(train_Y)
    valid_X = theano.shared(valid_X)
    valid_Y = theano.shared(valid_Y)
    test_X = theano.shared(test_X)
    test_Y = theano.shared(test_Y)

    if artificial:
        print 'Sequencing MNIST data...'
        print 'train set size:', len(train_Y.eval())
        print 'valid set size:', len(valid_Y.eval())
        print 'test set size:', len(test_Y.eval())
        data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
        print 'train set size:', len(train_Y.eval())
        print 'valid set size:', len(valid_Y.eval())
        print 'test set size:', len(test_Y.eval())
        print 'Sequencing done.'
        print

    N_input = train_X.eval().shape[1]
    root_N_input = numpy.sqrt(N_input)

    # Network and training specifications
    layers = state.layers  # number hidden layers
    walkbacks = state.walkbacks  # number of walkbacks
    layer_sizes = [N_input] + [state.hidden_size] * layers  # layer sizes, from h0 to hK (h0 is the visible layer)
    learning_rate = theano.shared(cast32(state.learning_rate))  # learning rate
    annealing = cast32(state.annealing)  # exponential annealing coefficient
    momentum = theano.shared(cast32(state.momentum))  # momentum term

    # PARAMETERS : weights list and bias list.
    # initialize a list of weights and biases based on layer_sizes
    weights_list = [get_shared_weights(layer_sizes[i], layer_sizes[i + 1], name="W_{0!s}_{1!s}".format(i, i + 1)) for i
                    in range(layers)]  # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
    recurrent_weights_list = [
        get_shared_weights(layer_sizes[i + 1], layer_sizes[i], name="V_{0!s}_{1!s}".format(i + 1, i)) for i in
        range(layers)]  # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
    bias_list = [get_shared_bias(layer_sizes[i], name='b_' + str(i)) for i in
                 range(layers + 1)]  # initialize each layer to 0's.

    # Theano variables and RNG
    MRG = RNG_MRG.MRG_RandomStreams(1)
    X = T.fmatrix('X')
    Xs = [T.fmatrix(name="X_initial") if i == 0 else T.fmatrix(name="X_" + str(i + 1)) for i in range(walkbacks + 1)]
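    # Xs holds walkbacks+1 consecutive input frames; at every walkback the graph predicts the next frame,
    # so the costs below compare predicted_X_chain[i] against Xs[i+1].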
    hiddens_input = [X] + [T.fmatrix(name="h_" + str(i + 1)) for i in range(layers)]
    hiddens_output = hiddens_input[:1] + hiddens_input[1:]

    # Check variables for bad inputs and stuff
    if state.batch_size > len(Xs):
        warnings.warn(
            "Batch size should not be bigger than walkbacks+1 (len(Xs)) unless you know what you're doing. You need to know the sequence length beforehand.")
    if state.batch_size <= 0:
        raise AssertionError("batch size cannot be <= 0")

    ''' F PROP '''
    if state.hidden_act == 'sigmoid':
        print 'Using sigmoid activation for hiddens'
        hidden_activation = T.nnet.sigmoid
    elif state.hidden_act == 'rectifier':
        print 'Using rectifier activation for hiddens'
        hidden_activation = lambda x: T.maximum(cast32(0), x)
    elif state.hidden_act == 'tanh':
        print 'Using hyperbolic tangent activation for hiddens'
        hidden_activation = lambda x: T.tanh(x)
    else:
        raise AssertionError("Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid".format(
            state.hidden_act))

    if state.visible_act == 'sigmoid':
        print 'Using sigmoid activation for visible layer'
        visible_activation = T.nnet.sigmoid
    elif state.visible_act == 'softmax':
        print 'Using softmax activation for visible layer'
        visible_activation = T.nnet.softmax
    else:
        raise AssertionError(
            "Did not recognize visible activation {0!s}, please use sigmoid or softmax".format(state.visible_act))

    def update_layers(hiddens, p_X_chain, Xs, sequence_idx, noisy=True, sampling=True):
        print 'odd layer updates'
        update_odd_layers(hiddens, noisy)
        print 'even layer updates'
        update_even_layers(hiddens, p_X_chain, Xs, sequence_idx, noisy, sampling)
        print 'done full update.'
        print
        # choose the correct output for hidden_outputs based on batch_size and walkbacks (this is due to an issue with batches, see note in run_story2.py)
        if state.batch_size <= len(Xs) and sequence_idx == state.batch_size - 1:
            return hiddens
        else:
            return None

    # Odd layer update function
    # just a loop over the odd layers
    def update_odd_layers(hiddens, noisy):
        for i in range(1, len(hiddens), 2):
            print 'updating layer', i
            simple_update_layer(hiddens, None, None, None, i, add_noise=noisy)

    # Even layer update
    # p_X_chain is given to append the p(X|...) at each full update (one update = odd update + even update)
    def update_even_layers(hiddens, p_X_chain, Xs, sequence_idx, noisy, sampling):
        for i in range(0, len(hiddens), 2):
            print 'updating layer', i
            simple_update_layer(hiddens, p_X_chain, Xs, sequence_idx, i, add_noise=noisy, input_sampling=sampling)

    # The layer update function
    # hiddens   :   list containing the symbolic theano variables [visible, hidden1, hidden2, ...]
    #               layer_update will modify this list inplace
    # p_X_chain :   list containing the successive p(X|...) at each update
    #               update_layer will append to this list
    # add_noise     : pre and post activation gaussian noise

    def simple_update_layer(hiddens, p_X_chain, Xs, sequence_idx, i, add_noise=True, input_sampling=True):
        # Compute the dot product, whatever layer
        # If the visible layer X
        if i == 0:
            print 'using', recurrent_weights_list[i]
            hiddens[i] = (T.dot(hiddens[i + 1], recurrent_weights_list[i]) + bias_list[i])
        # If the top layer
        elif i == len(hiddens) - 1:
            print 'using', weights_list[i - 1]
            hiddens[i] = T.dot(hiddens[i - 1], weights_list[i - 1]) + bias_list[i]
        # Otherwise in-between layers
        else:
            # next layer        :   hiddens[i+1], assigned weights : W_i
            # previous layer    :   hiddens[i-1], assigned weights : W_(i-1)
            print "using {0!s} and {1!s}".format(weights_list[i - 1], recurrent_weights_list[i])
            hiddens[i] = T.dot(hiddens[i + 1], recurrent_weights_list[i]) + T.dot(hiddens[i - 1], weights_list[i - 1]) + \
                         bias_list[i]

        # Add pre-activation noise if NOT input layer
        if i == 1 and state.noiseless_h1:
            print '>>NO noise in first hidden layer'
            add_noise = False

        # pre activation noise            
        if i != 0 and add_noise:
            print 'Adding pre-activation gaussian noise for layer', i
            hiddens[i] = add_gaussian_noise(hiddens[i], state.hidden_add_noise_sigma)

        # ACTIVATION!
        if i == 0:
            print '{!s} activation for visible layer X'.format(state.visible_act)
            hiddens[i] = visible_activation(hiddens[i])
        else:
            print 'Hidden units {!s} activation for layer'.format(state.hidden_act), i
            hiddens[i] = hidden_activation(hiddens[i])

            # post activation noise
            # why is there post activation noise? Because there is already pre-activation noise, this just doubles the amount of noise between each activation of the hiddens.
        #         if i != 0 and add_noise:
        #             print 'Adding post-activation gaussian noise for layer', i
        #             hiddens[i]  =   add_gaussian(hiddens[i], state.hidden_add_noise_sigma)

        # build the reconstruction chain if updating the visible layer X
        if i == 0:
            # if input layer -> append p(X|...)
            p_X_chain.append(hiddens[i])  # what the predicted next input should be

            if sequence_idx + 1 < len(Xs):
                next_input = Xs[sequence_idx + 1]
                # sample from p(X|...) - SAMPLING NEEDS TO BE CORRECT FOR INPUT TYPES I.E. FOR BINARY MNIST SAMPLING IS BINOMIAL. real-valued inputs should be gaussian
                if input_sampling:
                    print 'Sampling from input'
                    sampled = MRG.binomial(p=next_input, size=next_input.shape, dtype='float32')
                else:
                    print '>>NO input sampling'
                    sampled = next_input
                # add noise
                sampled = salt_and_pepper(sampled, state.input_salt_and_pepper)

                # DOES INPUT SAMPLING MAKE SENSE FOR SEQUENTIAL? - not really since it was used in walkbacks which was gibbs.
                # set input layer
                hiddens[i] = sampled

    def build_graph(hiddens, Xs, noisy=True, sampling=True):
        predicted_X_chain = []  # the visible layer that gets generated at each update_layers run
        H_chain = []  # either None or hiddens that gets generated at each update_layers run, this is used to determine what the correct hiddens_output should be
        print "Building the graph :", walkbacks, "updates"
        for i in range(walkbacks):
            print "Forward Prediction {!s}/{!s}".format(i + 1, walkbacks)
            H_chain.append(update_layers(hiddens, predicted_X_chain, Xs, i, noisy, sampling))
        return predicted_X_chain, H_chain

    '''Build the main training graph'''
    # corrupt x
    hiddens_output[0] = salt_and_pepper(hiddens_output[0], state.input_salt_and_pepper)
    # build the computation graph and the generated visible layers and appropriate hidden_output
    predicted_X_chain, H_chain = build_graph(hiddens_output, Xs, noisy=True, sampling=state.input_sampling)
    #     predicted_X_chain, H_chain = build_graph(hiddens_output, Xs, noisy=False, sampling=state.input_sampling) #testing one-hot without noise


    # choose the correct output for hiddens_output (this is due to the issue with batches - see note in run_story2.py)
    # this finds the not-None element of H_chain and uses that for hiddens_output
    h_empty = [True if h is None else False for h in H_chain]
    if False in h_empty:  # if there was a not-None element
        hiddens_output = H_chain[h_empty.index(False)]  # set hiddens_output to the appropriate element from H_chain

    ######################
    # COST AND GRADIENTS #
    ######################
    print
    if state.cost_funct == 'binary_crossentropy':
        print 'Using binary cross-entropy cost!'
        cost_function = lambda x, y: T.mean(T.nnet.binary_crossentropy(x, y))
    elif state.cost_funct == 'square':
        print "Using square error cost!"
        cost_function = lambda x, y: T.mean(T.sqr(x - y))
    else:
        raise AssertionError(
            "Did not recognize cost function {0!s}, please use binary_crossentropy or square".format(state.cost_funct))
    print 'Cost w.r.t p(X|...) at every step in the graph'

    costs = [cost_function(predicted_X_chain[i], Xs[i + 1]) for i in range(len(predicted_X_chain))]
    # outputs for the functions
    show_COSTs = [costs[0]] + [costs[-1]]

    # cost for the gradient
    # care more about the immediate next predictions rather than the future - use exponential decay
    #     COST = T.sum(costs)
    COST = T.sum([T.exp(-i / T.ceil(walkbacks / 3)) * costs[i] for i in range(len(costs))])
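    # e.g. with walkbacks = 6, walkbacks / 3 = 2 (Python 2 integer division), so the weights are
    # exp(-i / 2): 1.0, 0.61, 0.37, 0.22, ... - later predictions contribute less to the gradient.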

    params = weights_list + recurrent_weights_list + bias_list
    print "params:", params

    print "creating functions..."
    gradient = T.grad(COST, params)

    gradient_buffer = [theano.shared(numpy.zeros(param.get_value().shape, dtype='float32')) for param in params]

    m_gradient = [momentum * gb + (cast32(1) - momentum) * g for (gb, g) in zip(gradient_buffer, gradient)]
    param_updates = [(param, param - learning_rate * mg) for (param, mg) in zip(params, m_gradient)]
    gradient_buffer_updates = zip(gradient_buffer, m_gradient)

    updates = OrderedDict(param_updates + gradient_buffer_updates)
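    # These updates implement SGD with an exponential-moving-average momentum buffer:
    #   m <- momentum * m + (1 - momentum) * grad
    #   param <- param - learning_rate * m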

    # odd layer h's not used from input -> calculated directly from even layers (starting with h_0) since the odd layers are updated first.
    f_cost = theano.function(inputs=hiddens_input + Xs,
                             outputs=hiddens_output + show_COSTs,
                             on_unused_input='warn')

    f_learn = theano.function(inputs=hiddens_input + Xs,
                              updates=updates,
                              outputs=hiddens_output + show_COSTs,
                              on_unused_input='warn')

    print "functions done."
    print

    #############
    # Denoise some numbers  :   show number, noisy number, reconstructed number
    #############
    import random as R
    R.seed(1)
    # a function to add salt and pepper noise
    f_noise = theano.function(inputs=[X], outputs=salt_and_pepper(X, state.input_salt_and_pepper))

    # Recompile the graph without noise for reconstruction function - the input x_recon is already going to be noisy, and this is to test on a simulated 'real' input.
    X_recon = T.fvector("X_recon")
    Xs_recon = [T.fvector("Xs_recon")]
    hiddens_R_input = [X_recon] + [T.fvector(name="h_recon_" + str(i + 1)) for i in range(layers)]
    hiddens_R_output = hiddens_R_input[:1] + hiddens_R_input[1:]
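    # hiddens_R_output starts as a shallow copy of hiddens_R_input (same theano variables, new list),
    # so build_graph can overwrite its entries in place without clobbering the function inputs.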

    # The layer update scheme
    print "Creating graph for noisy reconstruction function at checkpoints during training."
    p_X_chain_R, H_chain_R = build_graph(hiddens_R_output, Xs_recon, noisy=False)

    # choose the correct output from H_chain for hidden_outputs based on batch_size and walkbacks
    # choose the correct output for hiddens_output
    h_empty = [True if h is None else False for h in H_chain_R]
    if False in h_empty:  # if there was a set of hiddens output from the batch_size-1 element of the chain
        hiddens_R_output = H_chain_R[
            h_empty.index(False)]  # extract out the not-None element from the list if it exists
    #     if state.batch_size <= len(Xs_recon):
    #         for i in range(len(hiddens_R_output)):
    #             hiddens_R_output[i] = H_chain_R[state.batch_size - 1][i]

    f_recon = theano.function(inputs=hiddens_R_input + Xs_recon,
                              outputs=hiddens_R_output + [p_X_chain_R[0], p_X_chain_R[-1]],
                              on_unused_input="warn")

    ############
    # Sampling #
    ############

    # the input to the sampling function
    X_sample = T.fmatrix("X_sampling")
    network_state_input = [X_sample] + [T.fmatrix("H_sampling_" + str(i + 1)) for i in range(layers)]

    # "Output" state of the network (noisy)
    # initialized with input, then we apply updates

    network_state_output = [X_sample] + network_state_input[1:]

    visible_pX_chain = []

    # ONE update
    print "Performing one walkback in network state sampling."
    _ = update_layers(network_state_output, visible_pX_chain, [X_sample], 0, noisy=True)

    if layers == 1:
        f_sample_simple = theano.function(inputs=[X_sample], outputs=visible_pX_chain[-1])

    # WHY IS THERE A WARNING????
    # because the first odd layers are not used -> directly computed FROM THE EVEN layers
    # unused input = warn
    f_sample2 = theano.function(inputs=network_state_input, outputs=network_state_output + visible_pX_chain,
                                on_unused_input='warn')

    def sample_some_numbers_single_layer():
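        # Single-layer sampling loop: take the visible expectation from f_sample_simple, binarize it
        # with a binomial draw, re-corrupt it with salt-and-pepper noise, and repeat. The returned
        # stack has 400 rows: the seed digit plus 399 sampled expectations.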
        x0 = test_X.get_value()[:1]
        samples = [x0]
        x = f_noise(x0)
        for i in range(399):
            x = f_sample_simple(x)
            samples.append(x)
            x = numpy.random.binomial(n=1, p=x, size=x.shape).astype('float32')
            x = f_noise(x)
        return numpy.vstack(samples)

    def sampling_wrapper(NSI):
        # * is the "splat" operator: It takes a list as input, and expands it into actual positional arguments in the function call.
        out = f_sample2(*NSI)
        NSO = out[:len(network_state_output)]
        vis_pX_chain = out[len(network_state_output):]
        return NSO, vis_pX_chain
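    # usage sketch: out_state, vis = sampling_wrapper(network_state[-1]), where network_state[-1] is
    # [visible] + [one zero-initialized array per hidden layer] (as built in sample_some_numbers below);
    # out_state can then be fed back in as the next network state.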

    def sample_some_numbers(N=400):
        # The network's initial state
        init_vis = test_X.get_value()[:1]

        noisy_init_vis = f_noise(init_vis)

        network_state = [
            [noisy_init_vis] + [numpy.zeros((1, len(b.get_value())), dtype='float32') for b in bias_list[1:]]]

        visible_chain = [init_vis]

        noisy_h0_chain = [noisy_init_vis]

        for i in range(N - 1):
            # feed the last state into the network, compute new state, and obtain visible units expectation chain
            net_state_out, vis_pX_chain = sampling_wrapper(network_state[-1])

            # append to the visible chain
            visible_chain += vis_pX_chain

            # append state output to the network state chain
            network_state.append(net_state_out)

            noisy_h0_chain.append(net_state_out[0])

        return numpy.vstack(visible_chain), numpy.vstack(noisy_h0_chain)

    def plot_samples(epoch_number, iteration):
        to_sample = time.time()
        if layers == 1:
            # one layer model
            V = sample_some_numbers_single_layer()
        else:
            V, H0 = sample_some_numbers()
        img_samples = PIL.Image.fromarray(tile_raster_images(V, (root_N_input, root_N_input), (20, 20)))

        fname = outdir + 'samples_iteration_' + str(iteration) + '_epoch_' + str(epoch_number) + '.png'
        img_samples.save(fname)
        print 'Took ' + str(time.time() - to_sample) + ' to sample 400 numbers'

    ##############
    # Inpainting #
    ##############
    def inpainting(digit):
        # The network's initial state

        # NOISE INIT
        init_vis = cast32(numpy.random.uniform(size=digit.shape))

        # noisy_init_vis  =   f_noise(init_vis)
        # noisy_init_vis  =   cast32(numpy.random.uniform(size=init_vis.shape))

        # INDEXES FOR VISIBLE AND NOISY PART
        noise_idx = (numpy.arange(N_input) % root_N_input < (root_N_input / 2))
        fixed_idx = (numpy.arange(N_input) % root_N_input > (root_N_input / 2))
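        # With row-major flattened images (e.g. 28x28 MNIST), index % root_N_input is the column index:
        # noise_idx selects the left half of each image and fixed_idx the right half, which reset_vis
        # clamps back to the target digit at every step.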

        # function to re-init the visible to the same noise

        # FUNCTION TO RESET HALF VISIBLE TO DIGIT
        def reset_vis(V):
            V[0][fixed_idx] = digit[0][fixed_idx]
            return V

        # INIT DIGIT : NOISE and RESET HALF TO DIGIT
        init_vis = reset_vis(init_vis)

        network_state = [[init_vis] + [numpy.zeros((1, len(b.get_value())), dtype='float32') for b in bias_list[1:]]]

        visible_chain = [init_vis]

        noisy_h0_chain = [init_vis]

        for i in range(49):
            # feed the last state into the network, compute new state, and obtain visible units expectation chain
            net_state_out, vis_pX_chain = sampling_wrapper(network_state[-1])

            # reset half the digit
            net_state_out[0] = reset_vis(net_state_out[0])
            vis_pX_chain[0] = reset_vis(vis_pX_chain[0])

            # append to the visible chain
            visible_chain += vis_pX_chain

            # append state output to the network state chain
            network_state.append(net_state_out)

            noisy_h0_chain.append(net_state_out[0])

        return numpy.vstack(visible_chain), numpy.vstack(noisy_h0_chain)

    def save_params_to_file(name, n, params, iteration):
        print 'saving parameters...'
        save_path = outdir + name + '_params_iteration_' + str(iteration) + '_epoch_' + str(n) + '.pkl'
        f = open(save_path, 'wb')
        try:
            cPickle.dump(params, f, protocol=cPickle.HIGHEST_PROTOCOL)
        finally:
            f.close()

    ################
    # GSN TRAINING #
    ################
    def train_recurrent_GSN(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        print '----------------------------------------'
        print 'TRAINING GSN FOR ITERATION', iteration
        with open(logfile, 'a') as f:
            f.write("--------------------------\nTRAINING GSN FOR ITERATION {0!s}\n".format(iteration))

        # TRAINING
        n_epoch = state.n_epoch
        batch_size = state.batch_size
        STOP = False
        counter = 0
        if iteration == 0:
            learning_rate.set_value(cast32(state.learning_rate))  # learning rate
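            # only reset the learning rate on the first iteration; later iterations keep the
            # annealed value carried over from the previous training run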
        times = []
        best_cost = float('inf')
        patience = 0

        print 'learning rate:', learning_rate.get_value()

        print 'train X size:', str(train_X.shape.eval())
        print 'valid X size:', str(valid_X.shape.eval())
        print 'test X size:', str(test_X.shape.eval())

        train_costs = []
        valid_costs = []
        test_costs = []
        train_costs_post = []
        valid_costs_post = []
        test_costs_post = []

        if state.vis_init:
            bias_list[0].set_value(logit(numpy.clip(0.9, 0.001, train_X.get_value().mean(axis=0))))

        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            print 'Testing : skip training'
            STOP = True

        while not STOP:
            counter += 1
            t = time.time()
            print counter, '\t',
            with open(logfile, 'a') as f:
                f.write("{0!s}\t".format(counter))
            # shuffle the data
            data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)

            # train
            # init hiddens
            #             hiddens = [(T.zeros_like(train_X[:batch_size]).eval())]
            #             for i in range(len(weights_list)):
            #                 # init with zeros
            #                 hiddens.append(T.zeros_like(T.dot(hiddens[i], weights_list[i])).eval())
            hiddens = [T.zeros((batch_size, layer_size)).eval() for layer_size in layer_sizes]
            train_cost = []
            train_cost_post = []
            for i in range(len(train_X.get_value(borrow=True)) / batch_size):
                xs = [train_X.get_value(borrow=True)[
                      (i * batch_size) + sequence_idx: ((i + 1) * batch_size) + sequence_idx] for sequence_idx in
                      range(len(Xs))]
                xs, hiddens = fix_input_size(xs, hiddens)
                hiddens[0] = xs[0]
                _ins = hiddens + xs
                _outs = f_learn(*_ins)
                hiddens = _outs[:len(hiddens)]
                cost = _outs[-2]
                cost_post = _outs[-1]
                train_cost.append(cost)
                train_cost_post.append(cost_post)

            train_cost = numpy.mean(train_cost)
            train_costs.append(train_cost)
            train_cost_post = numpy.mean(train_cost_post)
            train_costs_post.append(train_cost_post)
            print 'Train : ', trunc(train_cost), trunc(train_cost_post), '\t',
            with open(logfile, 'a') as f:
                f.write("Train : {0!s} {1!s}\t".format(trunc(train_cost), trunc(train_cost_post)))
            with open(train_convergence_pre, 'a') as f:
                f.write("{0!s},".format(train_cost))
            with open(train_convergence_post, 'a') as f:
                f.write("{0!s},".format(train_cost_post))

            # valid
            # init hiddens
            hiddens = [T.zeros((batch_size, layer_size)).eval() for layer_size in layer_sizes]
            valid_cost = []
            valid_cost_post = []
            for i in range(len(valid_X.get_value(borrow=True)) / batch_size):
                xs = [valid_X.get_value(borrow=True)[
                      (i * batch_size) + sequence_idx: ((i + 1) * batch_size) + sequence_idx] for sequence_idx in
                      range(len(Xs))]
                xs, hiddens = fix_input_size(xs, hiddens)
                hiddens[0] = xs[0]
                _ins = hiddens + xs
                _outs = f_cost(*_ins)
                hiddens = _outs[:-2]
                cost = _outs[-2]
                cost_post = _outs[-1]
                valid_cost.append(cost)
                valid_cost_post.append(cost_post)

            valid_cost = numpy.mean(valid_cost)
            valid_costs.append(valid_cost)
            valid_cost_post = numpy.mean(valid_cost_post)
            valid_costs_post.append(valid_cost_post)
            print 'Valid : ', trunc(valid_cost), trunc(valid_cost_post), '\t',
            with open(logfile, 'a') as f:
                f.write("Valid : {0!s} {1!s}\t".format(trunc(valid_cost), trunc(valid_cost_post)))
            with open(valid_convergence_pre, 'a') as f:
                f.write("{0!s},".format(valid_cost))
            with open(valid_convergence_post, 'a') as f:
                f.write("{0!s},".format(valid_cost_post))

            # test
            # init hiddens
            hiddens = [T.zeros((batch_size, layer_size)).eval() for layer_size in layer_sizes]
            test_cost = []
            test_cost_post = []
            for i in range(len(test_X.get_value(borrow=True)) / batch_size):
                xs = [test_X.get_value(borrow=True)[
                      (i * batch_size) + sequence_idx: ((i + 1) * batch_size) + sequence_idx] for sequence_idx in
                      range(len(Xs))]
                xs, hiddens = fix_input_size(xs, hiddens)
                hiddens[0] = xs[0]
                _ins = hiddens + xs
                _outs = f_cost(*_ins)
                hiddens = _outs[:-2]
                cost = _outs[-2]
                cost_post = _outs[-1]
                test_cost.append(cost)
                test_cost_post.append(cost_post)

            test_cost = numpy.mean(test_cost)
            test_costs.append(test_cost)
            test_cost_post = numpy.mean(test_cost_post)
            test_costs_post.append(test_cost_post)
            print 'Test  : ', trunc(test_cost), trunc(test_cost_post), '\t',
            with open(logfile, 'a') as f:
                f.write("Test : {0!s} {1!s}\t".format(trunc(test_cost), trunc(test_cost_post)))
            with open(test_convergence_pre, 'a') as f:
                f.write("{0!s},".format(test_cost))
            with open(test_convergence_post, 'a') as f:
                f.write("{0!s},".format(test_cost_post))

            # check for early stopping
            cost = train_cost
            if cost < best_cost * state.early_stop_threshold:
                patience = 0
                best_cost = cost
            else:
                patience += 1
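            # patience-based early stopping: training ends once the epoch limit is reached or the
            # train cost fails to improve on best_cost * early_stop_threshold for early_stop_length epochs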

            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                save_params_to_file('gsn', counter, params, iteration)

            timing = time.time() - t
            times.append(timing)

            print 'time : ', trunc(timing),

            print 'remaining: ', trunc((n_epoch - counter) * numpy.mean(times) / 60 / 60), 'hrs',

            print 'B : ', [trunc(abs(b.get_value(borrow=True)).mean()) for b in bias_list],

            print 'W : ', [trunc(abs(w.get_value(borrow=True)).mean()) for w in weights_list],

            print 'V : ', [trunc(abs(v.get_value(borrow=True)).mean()) for v in recurrent_weights_list]

            with open(logfile, 'a') as f:
                f.write("MeanVisB : {0!s}\t".format(trunc(bias_list[0].get_value().mean())))

            with open(logfile, 'a') as f:
                f.write("W : {0!s}\t".format(str([trunc(abs(w.get_value(borrow=True)).mean()) for w in weights_list])))

            with open(logfile, 'a') as f:
                f.write("Time : {0!s} seconds\n".format(trunc(timing)))

            if (counter % state.save_frequency) == 0:
                # Checking reconstruction
                nums = test_X.get_value()[range(100)]
                noisy_nums = f_noise(test_X.get_value()[range(100)])
                reconstructed_prediction = []
                reconstructed_prediction_end = []
                # init reconstruction hiddens
                hiddens = [T.zeros(layer_size).eval() for layer_size in layer_sizes]
                for num in noisy_nums:
                    hiddens[0] = num
                    for i in range(len(hiddens)):
                        if len(hiddens[i].shape) == 2 and hiddens[i].shape[0] == 1:
                            hiddens[i] = hiddens[i][0]
                    _ins = hiddens + [num]
                    _outs = f_recon(*_ins)
                    hiddens = _outs[:len(hiddens)]
                    [reconstructed_1, reconstructed_n] = _outs[len(hiddens):]
                    reconstructed_prediction.append(reconstructed_1)
                    reconstructed_prediction_end.append(reconstructed_n)

                with open(logfile, 'a') as f:
                    f.write("\n")
                for i in range(len(nums)):
                    if len(reconstructed_prediction[i].shape) == 2 and reconstructed_prediction[i].shape[0] == 1:
                        reconstructed_prediction[i] = reconstructed_prediction[i][0]
                    print nums[i].tolist(), "->", reconstructed_prediction[i].tolist()
                    with open(logfile, 'a') as f:
                        f.write("{0!s} -> {1!s}\n".format(nums[i].tolist(),
                                                          [trunc(n) if n > 0.0001 else trunc(0.00000000000000000) for n
                                                           in reconstructed_prediction[i].tolist()]))
                with open(logfile, 'a') as f:
                    f.write("\n")

                #                 # Concatenate stuff
                #                 stacked = numpy.vstack([numpy.vstack([nums[i*10 : (i+1)*10], noisy_nums[i*10 : (i+1)*10], reconstructed_prediction[i*10 : (i+1)*10], reconstructed_prediction_end[i*10 : (i+1)*10]]) for i in range(10)])
                #                 numbers_reconstruction = PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,40)))
                #                 numbers_reconstruction.save(outdir+'gsn_number_reconstruction_iteration_'+str(iteration)+'_epoch_'+str(counter)+'.png')
                #
                #                 #sample_numbers(counter, 'seven')
                #                 plot_samples(counter, iteration)
                #
                #                 #save params
                #                 save_params_to_file('gsn', counter, params, iteration)

            # ANNEAL!
            new_lr = learning_rate.get_value() * annealing
            learning_rate.set_value(new_lr)

        # 10k samples
        print 'Generating 10,000 samples'
        samples, _ = sample_some_numbers(N=10000)
        f_samples = outdir + 'samples.npy'
        numpy.save(f_samples, samples)
        print 'saved digits'

    #####################
    # STORY 2 ALGORITHM #
    #####################
    for iter in range(state.max_iterations):
        train_recurrent_GSN(iter, train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
Example #5
def experiment(state, outdir_base='./'):
    rng.seed(1)  # seed the numpy random generator
    R.seed(1)  # seed the other random generator (for reconstruction function indices)
    # Initialize the output directories and files
    data.mkdir_p(outdir_base)
    outdir = outdir_base + "/" + state.dataset + "/"
    data.mkdir_p(outdir)
    logger = Logger(outdir)
    train_convergence = outdir + "train_convergence.csv"
    valid_convergence = outdir + "valid_convergence.csv"
    test_convergence = outdir + "test_convergence.csv"
    regression_train_convergence = outdir + "regression_train_convergence.csv"
    regression_valid_convergence = outdir + "regression_valid_convergence.csv"
    regression_test_convergence = outdir + "regression_test_convergence.csv"
    init_empty_file(train_convergence)
    init_empty_file(valid_convergence)
    init_empty_file(test_convergence)
    init_empty_file(regression_train_convergence)
    init_empty_file(regression_valid_convergence)
    init_empty_file(regression_test_convergence)

    logger.log("----------MODEL 1, {0!s}--------------\n\n".format(state.dataset))

    # load parameters from config file if this is a test
    config_filename = outdir + 'config'
    if state.test_model and 'config' in os.listdir(outdir):
        config_vals = load_from_config(config_filename)
        for CV in config_vals:
            logger.log(CV)
            if CV.startswith('test'):
                logger.log('Do not override testing switch')
                continue
            try:
                exec('state.' + CV) in globals(), locals()
            except:
                exec('state.' + CV.split('=')[0] + "='" + CV.split('=')[1] + "'") in globals(), locals()
    else:
        # Save the current configuration
        # Useful for logs/experiments
        logger.log('Saving config')
        with open(config_filename, 'w') as f:
            f.write(str(state))

    logger.log(state)

    ####################################################
    # Load the data, train = train+valid, and sequence #
    ####################################################
    artificial = False  # internal flag to see if the dataset is one of my artificially-sequenced MNIST varieties.
    if state.dataset == 'MNIST_1' or state.dataset == 'MNIST_2' or state.dataset == 'MNIST_3' or state.dataset == 'MNIST_4':
        (train_X, train_Y), (valid_X, valid_Y), (test_X, test_Y) = data.load_mnist(state.data_path)
        train_X = numpy.concatenate((train_X, valid_X))
        train_Y = numpy.concatenate((train_Y, valid_Y))
        artificial = True
        try:
            dataset = int(state.dataset.split('_')[1])
        except:
            raise AssertionError("artificial dataset number not recognized. Input was " + state.dataset)
    else:
        raise AssertionError("dataset not recognized.")

    # transfer the datasets into theano shared variables
    train_X = theano.shared(train_X)
    train_Y = theano.shared(train_Y)
    valid_X = theano.shared(valid_X)
    valid_Y = theano.shared(valid_Y)
    test_X = theano.shared(test_X)
    test_Y = theano.shared(test_Y)

    if artificial:  # if it my MNIST sequence, appropriately sequence it.
        logger.log('Sequencing MNIST data...')
        logger.log(['train set size:', len(train_Y.eval())])
        logger.log(['valid set size:', len(valid_Y.eval())])
        logger.log(['test set size:', len(test_Y.eval())])
        data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
        logger.log(['train set size:', len(train_Y.eval())])
        logger.log(['valid set size:', len(valid_Y.eval())])
        logger.log(['test set size:', len(test_Y.eval())])
        logger.log('Sequencing done.\n')

    # variables from the dataset that are used for initialization and image reconstruction
    N_input = train_X.eval().shape[1]
    root_N_input = numpy.sqrt(N_input)

    # Network and training specifications
    layers = state.layers  # number hidden layers
    walkbacks = state.walkbacks  # number of walkbacks
    sequence_window_size = state.sequence_window_size  # number of previous hidden states to consider for the regression
    layer_sizes = [N_input] + [state.hidden_size] * layers  # layer sizes, from h0 to hK (h0 is the visible layer)
    learning_rate = theano.shared(cast32(state.learning_rate))  # learning rate
    regression_learning_rate = theano.shared(cast32(state.learning_rate))  # learning rate
    annealing = cast32(state.annealing)  # exponential annealing coefficient
    momentum = theano.shared(cast32(state.momentum))  # momentum term

    # Theano variables and RNG
    X = T.fmatrix('X')  # for use in sampling
    Xs = [T.fmatrix(name="X_t") if i == 0 else T.fmatrix(name="X_{t-" + str(i) + "}") for i in range(
        sequence_window_size + 1)]  # for use in training - need one X variable for each input in the sequence history window, and what the current one should be
    Xs_recon = [T.fvector(name="Xrecon_t") if i == 0 else T.fvector(name="Xrecon_{t-" + str(i) + "}") for i in range(
        sequence_window_size + 1)]  # for use in reconstruction checkpoints - need one X variable for each input in the sequence history window, and what the current one should be
    # sequence_graph_output_index = T.lscalar("i")
    MRG = RNG_MRG.MRG_RandomStreams(1)

    ##############
    # PARAMETERS #
    ##############
    # initialize a list of weights and biases based on layer_sizes for the GSN
    weights_list = [
        get_shared_weights(layer_sizes[layer], layer_sizes[layer + 1], name="W_{0!s}_{1!s}".format(layer, layer + 1))
        for layer in range(layers)]  # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
    bias_list = [get_shared_bias(layer_sizes[layer], name='b_' + str(layer)) for layer in
                 range(layers + 1)]  # initialize each layer to 0's.
    # parameters for the regression - only need them for the odd layers in the network!
    regression_weights_list = [
        [get_shared_regression_weights(state.hidden_size, name="V_{t-" + str(window + 1) + "}_layer" + str(layer)) for
         layer in range(layers + 1) if (layer % 2) != 0] for window in
        range(sequence_window_size)]  # initialize to identity matrix the size of hidden layer.
    regression_bias_list = [get_shared_bias(state.hidden_size, name='vb_' + str(layer)) for layer in range(layers + 1)
                            if (layer % 2) != 0]  # initialize to 0's.
    # need initial biases (tau) as well for when there aren't sequence_window_size hiddens in the history.
    tau_list = [
        [get_shared_bias(state.hidden_size, name='tau_{t-' + str(window + 1) + "}_layer" + str(layer)) for layer in
         range(layers + 1) if (layer % 2) != 0] for window in range(sequence_window_size)]

    ###########################################################
    # load initial parameters of gsn to speed up my debugging #
    ###########################################################
    params_to_load = 'gsn_params.pkl'
    initialized_gsn = False
    if os.path.isfile(params_to_load):
        logger.log("\nLoading existing GSN parameters")
        loaded_params = cPickle.load(open(params_to_load, 'rb'))  # binary mode for the pickled parameters
        [p.set_value(lp.get_value(borrow=False)) for lp, p in zip(loaded_params[:len(weights_list)], weights_list)]
        [p.set_value(lp.get_value(borrow=False)) for lp, p in zip(loaded_params[len(weights_list):], bias_list)]
        initialized_gsn = True

    ########################
    # ACTIVATION FUNCTIONS #
    ########################
    if state.hidden_act == 'sigmoid':
        logger.log('Using sigmoid activation for hiddens')
        hidden_activation = T.nnet.sigmoid
    elif state.hidden_act == 'rectifier':
        logger.log('Using rectifier activation for hiddens')
        hidden_activation = lambda x: T.maximum(cast32(0), x)
    elif state.hidden_act == 'tanh':
        logger.log('Using hyperbolic tangent activation for hiddens')
        hidden_activation = lambda x: T.tanh(x)
    else:
        logger.log("Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid".format(
            state.hidden_act))
        raise AssertionError("Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid".format(
            state.hidden_act))

    if state.visible_act == 'sigmoid':
        logger.log('Using sigmoid activation for visible layer')
        visible_activation = T.nnet.sigmoid
    elif state.visible_act == 'softmax':
        logger.log('Using softmax activation for visible layer')
        visible_activation = T.nnet.softmax
    else:
        logger.log(
            "Did not recognize visible activation {0!s}, please use sigmoid or softmax".format(state.visible_act))
        raise AssertionError(
            "Did not recognize visible activation {0!s}, please use sigmoid or softmax".format(state.visible_act))

    ###############################################
    # COMPUTATIONAL GRAPH HELPER METHODS FOR TGSN #
    ###############################################
    def update_layers(hiddens, p_X_chain, noisy=True):
        logger.log('odd layer updates')
        update_odd_layers(hiddens, noisy)
        logger.log('even layer updates')
        update_even_layers(hiddens, p_X_chain, noisy)
        logger.log('done full update.\n')

    def update_layers_reverse(hiddens, p_X_chain, noisy=True):
        logger.log('even layer updates')
        update_even_layers(hiddens, p_X_chain, noisy)
        logger.log('odd layer updates')
        update_odd_layers(hiddens, noisy)
        logger.log('done full update.\n')

    # Odd layer update function
    # just a loop over the odd layers
    def update_odd_layers(hiddens, noisy):
        for i in range(1, len(hiddens), 2):
            logger.log(['updating layer', i])
            simple_update_layer(hiddens, None, i, add_noise=noisy)

    # Even layer update
    # p_X_chain is given to append the p(X|...) at each full update (one update = odd update + even update)
    def update_even_layers(hiddens, p_X_chain, noisy):
        for i in range(0, len(hiddens), 2):
            logger.log(['updating layer', i])
            simple_update_layer(hiddens, p_X_chain, i, add_noise=noisy)

    # The layer update function
    # hiddens   :   list containing the symbolic theano variables [visible, hidden1, hidden2, ...]
    #               layer_update will modify this list inplace
    # p_X_chain :   list containing the successive p(X|...) at each update
    #               update_layer will append to this list
    # i         :   the current layer being updated
    # add_noise :   pre (and post) activation gaussian noise flag
    def simple_update_layer(hiddens, p_X_chain, i, add_noise=True):
        # Compute the dot product, whatever layer
        # If the visible layer X
        if i == 0:
            logger.log('using ' + str(weights_list[i]) + '.T')
            hiddens[i] = T.dot(hiddens[i + 1], weights_list[i].T) + bias_list[i]
            # If the top layer
        elif i == len(hiddens) - 1:
            logger.log(['using', weights_list[i - 1]])
            hiddens[i] = T.dot(hiddens[i - 1], weights_list[i - 1]) + bias_list[i]
        # Otherwise in-between layers
        else:
            logger.log(["using {0!s} and {1!s}.T".format(weights_list[i - 1], weights_list[i])])
            # next layer        :   hiddens[i+1], assigned weights : W_i
            # previous layer    :   hiddens[i-1], assigned weights : W_(i-1)
            hiddens[i] = T.dot(hiddens[i + 1], weights_list[i].T) + T.dot(hiddens[i - 1], weights_list[i - 1]) + \
                         bias_list[i]

        # Add pre-activation noise if NOT input layer
        if i == 1 and state.noiseless_h1:
            logger.log('>>NO noise in first hidden layer')
            add_noise = False

        # pre activation noise            
        if i != 0 and add_noise:
            logger.log(['Adding pre-activation gaussian noise for layer', i])
            hiddens[i] = add_gaussian_noise(hiddens[i], state.hidden_add_noise_sigma)

        # ACTIVATION!
        if i == 0:
            logger.log('{} activation for visible layer'.format(state.visible_act))
            hiddens[i] = visible_activation(hiddens[i])
        else:
            logger.log(['Hidden units {} activation for layer'.format(state.hidden_act), i])
            hiddens[i] = hidden_activation(hiddens[i])

            # post activation noise
            # why is there post activation noise? Because there is already pre-activation noise, this just doubles the amount of noise between each activation of the hiddens.
        #         if i != 0 and add_noise:
        #             logger.log(['Adding post-activation gaussian noise for layer', i])
        #             hiddens[i]  =   add_gaussian(hiddens[i], state.hidden_add_noise_sigma)

        # build the reconstruction chain if updating the visible layer X
        if i == 0:
            # if input layer -> append p(X|...)
            p_X_chain.append(hiddens[i])

            # sample from p(X|...) - SAMPLING NEEDS TO BE CORRECT FOR INPUT TYPES I.E. FOR BINARY MNIST SAMPLING IS BINOMIAL. real-valued inputs should be gaussian
            if state.input_sampling:
                logger.log('Sampling from input')
                sampled = MRG.binomial(p=hiddens[i], size=hiddens[i].shape, dtype='float32')
            else:
                logger.log('>>NO input sampling')
                sampled = hiddens[i]
            # add noise
            sampled = salt_and_pepper(sampled, state.input_salt_and_pepper)

            # set input layer
            hiddens[i] = sampled

    def perform_regression_step(hiddens, sequence_history):
        logger.log(["Sequence history length:", len(sequence_history)])
        # only need to work over the odd layers of the hiddens
        odd_layers = [i for i in range(len(hiddens)) if (i % 2) != 0]
        # depending on the size of the sequence history, it could be 0, 1, 2, 3, ... sequence_window_size
        for (hidden_index, regression_index) in zip(odd_layers, range(len(odd_layers))):
            terms_used = []
            sequence_terms = []
            for history_index in range(sequence_window_size):
                if history_index < len(sequence_history):
                    # dot product with history term
                    sequence_terms.append(T.dot(sequence_history[history_index][regression_index],
                                                regression_weights_list[history_index][regression_index]))
                    terms_used.append(regression_weights_list[history_index][regression_index])
                else:
                    # otherwise, no history for necessary spot, so use the tau
                    sequence_terms.append(tau_list[history_index][regression_index])
                    terms_used.append(tau_list[history_index][regression_index])

            if len(sequence_terms) > 0:
                sequence_terms.append(regression_bias_list[regression_index])
                terms_used.append(regression_bias_list[regression_index])
                logger.log(["REGRESSION for hidden layer {0!s} using:".format(hidden_index), terms_used])
                hiddens[hidden_index] = numpy.sum(sequence_terms)
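                # i.e. the predicted odd hidden state is sum_k V_{t-k} h_{t-k} + b, with tau_{t-k}
                # standing in for V h whenever fewer than sequence_window_size frames of history exist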

    def build_gsn_graph(x, noiseflag):
        p_X_chain = []
        if noiseflag:
            X_init = salt_and_pepper(x, state.input_salt_and_pepper)
        else:
            X_init = x
        # init hiddens with zeros
        hiddens = [X_init]
        for w in weights_list:
            hiddens.append(T.zeros_like(T.dot(hiddens[-1], w)))
        # The layer update scheme
        logger.log(["Building the gsn graph :", walkbacks, "updates"])
        for i in range(walkbacks):
            logger.log("GSN Walkback {!s}/{!s}".format(i + 1, walkbacks))
            update_layers(hiddens, p_X_chain, noisy=noiseflag)

        return p_X_chain

    def build_sequence_graph(xs, noiseflag):
        predicted_X_chains = []
        p_X_chains = []
        sequence_history = []
        # The layer update scheme
        logger.log(["Building the regression graph :", len(Xs), "updates"])
        for x_index in range(len(xs)):
            x = xs[x_index]
            # Predict what the current X should be
            ''' hidden layer init '''
            pred_hiddens = [T.zeros_like(x)]
            for w in weights_list:
                # init with zeros
                pred_hiddens.append(T.zeros_like(T.dot(pred_hiddens[-1], w)))
            logger.log("Performing regression step!")
            perform_regression_step(pred_hiddens, sequence_history)  # do the regression!
            logger.log("\n")

            predicted_X_chain = []
            for i in range(walkbacks):
                logger.log("Prediction Walkback {!s}/{!s}".format(i + 1, walkbacks))
                update_layers_reverse(pred_hiddens, predicted_X_chain,
                                      noisy=False)  # no noise in the prediction because x_prediction can't be recovered from x anyway
            predicted_X_chains.append(predicted_X_chain)

            # Now do the actual GSN step and add it to the sequence history
            # corrupt x if noisy
            if noiseflag:
                X_init = salt_and_pepper(x, state.input_salt_and_pepper)
            else:
                X_init = x
            ''' hidden layer init '''
            hiddens = [T.zeros_like(x)]
            for w in weights_list:
                # init with zeros
                hiddens.append(T.zeros_like(T.dot(hiddens[-1], w)))
            #             # substitute some of the zero layers for what was predicted - need to advance the prediction by 1 layer so it is the evens
            #             update_even_layers(pred_hiddens,[],noisy=False)
            #             for i in [layer for layer in range(len(hiddens)) if (layer%2 == 0)]:
            #                 hiddens[i] = pred_hiddens[i]
            hiddens[0] = X_init

            chain = []
            for i in range(walkbacks):
                logger.log("GSN walkback {!s}/{!s}".format(i + 1, walkbacks))
                update_layers(hiddens, chain, noisy=noiseflag)
            # Append the p_X_chain
            p_X_chains.append(chain)
            # Append the odd layers of the hiddens to the sequence history
            sequence_history.append([hiddens[layer] for layer in range(len(hiddens)) if (layer % 2) != 0])


            # select the prediction and reconstruction from the lists
        #         prediction_chain = T.stacklists(predicted_X_chains)[sequence_graph_output_index]
        #         reconstruction_chain = T.stacklists(p_X_chains)[sequence_graph_output_index]
        return predicted_X_chains, p_X_chains
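    # Each frame therefore contributes two chains: a noiseless prediction walkback seeded only from the
    # regressed hiddens (no access to the current frame), and a standard noisy GSN walkback on the frame
    # itself, whose odd hidden layers are what get pushed into the sequence history.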

    ##############################################
    #    Build the training graph for the GSN    #
    ##############################################
    logger.log("\nBuilding GSN graphs")
    p_X_chain_init = build_gsn_graph(X, noiseflag=True)
    predicted_X_chain_gsns, p_X_chains = build_sequence_graph(Xs, noiseflag=True)
    predicted_X_chain_gsn = predicted_X_chain_gsns[-1]
    p_X_chain = p_X_chains[-1]

    ###############################################
    # Build the training graph for the regression #
    ###############################################
    logger.log("\nBuilding regression graph")
    # no noise! noise is only used as regularization for GSN stage
    predicted_X_chains_regression, _ = build_sequence_graph(Xs, noiseflag=False)
    predicted_X_chain = predicted_X_chains_regression[-1]

    ######################
    # COST AND GRADIENTS #
    ######################
    if state.cost_funct == 'binary_crossentropy':
        logger.log('\nUsing binary cross-entropy cost!')
        cost_function = lambda x, y: T.mean(T.nnet.binary_crossentropy(x, y))
    elif state.cost_funct == 'square':
        logger.log("\nUsing square error cost!")
        # cost_function = lambda x,y: T.log(T.mean(T.sqr(x-y)))
        cost_function = lambda x, y: T.log(T.sum(T.pow((x - y), 2)))
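        # note: this variant is the log of the summed squared error over the batch, not a per-example MSE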
    else:
        raise AssertionError(
            "Did not recognize cost function {0!s}, please use binary_crossentropy or square".format(state.cost_funct))

    logger.log('Cost w.r.t p(X|...) at every step in the graph for the TGSN')
    gsn_costs_init = [cost_function(rX, X) for rX in p_X_chain_init]
    show_gsn_cost_init = gsn_costs_init[-1]
    gsn_cost_init = numpy.sum(gsn_costs_init)
    gsn_init_mse = T.mean(T.sqr(p_X_chain_init[-1] - X), axis=0)
    gsn_init_error = T.mean(gsn_init_mse)

    # gsn_costs     = T.mean(T.mean(T.nnet.binary_crossentropy(p_X_chain, T.stacklists(Xs)[sequence_graph_output_index]),2),1)
    gsn_costs = [cost_function(rX, Xs[-1]) for rX in predicted_X_chain_gsn]
    show_gsn_cost = gsn_costs[-1]
    gsn_cost = T.sum(gsn_costs)
    gsn_mse = T.mean(T.sqr(predicted_X_chain_gsn[-1] - Xs[-1]), axis=0)
    gsn_error = T.mean(gsn_mse)

    gsn_params = weights_list + bias_list
    logger.log(["gsn params:", gsn_params])

    # l2 regularization
    # regression_regularization_cost = T.sum([T.sum(recurrent_weights ** 2) for recurrent_weights in regression_weights_list])
    regression_regularization_cost = 0
    regression_costs = [cost_function(rX, Xs[-1]) for rX in predicted_X_chain]
    show_regression_cost = regression_costs[-1]
    regression_cost = T.sum(regression_costs) + state.regularize_weight * regression_regularization_cost
    regression_mse = T.mean(T.sqr(predicted_X_chain[-1] - Xs[-1]), axis=0)
    regression_error = T.mean(regression_mse)

    # only using the odd layers update -> even-indexed parameters in the list because it starts at v1
    # need to flatten the regression list -> couldn't immediately find the python method so here is the implementation
    regression_weights_flattened = []
    for weights in regression_weights_list:
        regression_weights_flattened.extend(weights)
    tau_flattened = []
    for tau in tau_list:
        tau_flattened.extend(tau)

    regression_params = regression_weights_flattened + regression_bias_list  # + tau_flattened

    logger.log(["regression params:", regression_params])

    logger.log("creating functions...")
    t = time.time()

    gradient_init = T.grad(gsn_cost_init, gsn_params)
    gradient_buffer_init = [theano.shared(numpy.zeros(param.get_value().shape, dtype='float32')) for param in
                            gsn_params]
    m_gradient_init = [momentum * gb + (cast32(1) - momentum) * g for (gb, g) in
                       zip(gradient_buffer_init, gradient_init)]
    param_updates_init = [(param, param - learning_rate * mg) for (param, mg) in zip(gsn_params, m_gradient_init)]
    gradient_buffer_updates_init = zip(gradient_buffer_init, m_gradient_init)
    updates_init = OrderedDict(param_updates_init + gradient_buffer_updates_init)

    gsn_f_learn_init = theano.function(inputs=[X],
                                       updates=updates_init,
                                       outputs=[show_gsn_cost_init, gsn_init_error])

    gsn_f_cost_init = theano.function(inputs=[X],
                                      outputs=[show_gsn_cost_init, gsn_init_error])

    gradient = T.grad(gsn_cost, gsn_params)
    gradient_buffer = [theano.shared(numpy.zeros(param.get_value().shape, dtype='float32')) for param in gsn_params]
    m_gradient = [momentum * gb + (cast32(1) - momentum) * g for (gb, g) in zip(gradient_buffer, gradient)]
    param_updates = [(param, param - learning_rate * mg) for (param, mg) in zip(gsn_params, m_gradient)]
    gradient_buffer_updates = zip(gradient_buffer, m_gradient)

    updates = OrderedDict(param_updates + gradient_buffer_updates)

    gsn_f_cost = theano.function(inputs=Xs,
                                 outputs=[show_gsn_cost, gsn_error])

    gsn_f_learn = theano.function(inputs=Xs,
                                  updates=updates,
                                  outputs=[show_gsn_cost, gsn_error])

    regression_gradient = T.grad(regression_cost, regression_params)
    regression_gradient_buffer = [theano.shared(numpy.zeros(rparam.get_value().shape, dtype='float32')) for rparam in
                                  regression_params]
    regression_m_gradient = [momentum * rgb + (cast32(1) - momentum) * rg for (rgb, rg) in
                             zip(regression_gradient_buffer, regression_gradient)]
    regression_param_updates = [(rparam, rparam - regression_learning_rate * rmg) for (rparam, rmg) in
                                zip(regression_params, regression_m_gradient)]
    regression_gradient_buffer_updates = zip(regression_gradient_buffer, regression_m_gradient)

    regression_updates = OrderedDict(regression_param_updates + regression_gradient_buffer_updates)

    regression_f_cost = theano.function(inputs=Xs,
                                        outputs=[show_regression_cost, regression_error])

    regression_f_learn = theano.function(inputs=Xs,
                                         updates=regression_updates,
                                         outputs=[show_regression_cost, regression_error])
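    # At this point there are three learn/cost pairs: gsn_f_learn_init/gsn_f_cost_init operate on a
    # single batch X (used at iteration 0), while gsn_f_learn/gsn_f_cost and
    # regression_f_learn/regression_f_cost operate on the full sequence of inputs Xs.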

    logger.log("functions done. took " + make_time_units_string(time.time() - t) + ".\n")

    ############################################################################################
    # Denoise some numbers : show number, noisy number, predicted number, reconstructed number #
    ############################################################################################   
    # Recompile the graph without injected noise for the reconstruction function (same layer update scheme)
    logger.log("Creating graph for noisy reconstruction function at checkpoints during training.")
    predicted_X_chains_R, p_X_chains_R = build_sequence_graph(Xs_recon, noiseflag=False)
    predicted_X_chain_R = predicted_X_chains_R[-1]
    p_X_chain_R = p_X_chains_R[-1]
    f_recon = theano.function(inputs=Xs_recon, outputs=[predicted_X_chain_R[-1], p_X_chain_R[-1]])
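    # f_recon takes the window of reconstruction inputs (Xs_recon) and returns the final next-frame
    # prediction and the final reconstruction produced by the noise-free graph built above.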

    # Now do the same but for the GSN in the initial run
    p_X_chain_R = build_gsn_graph(X, noiseflag=False)
    f_recon_init = theano.function(inputs=[X], outputs=p_X_chain_R[-1])

    ############
    # Sampling #
    ############
    f_noise = theano.function(inputs=[X], outputs=salt_and_pepper(X, state.input_salt_and_pepper))
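    # f_noise corrupts a batch with salt-and-pepper noise; presumably salt_and_pepper sets a random
    # fraction (state.input_salt_and_pepper) of the pixels to 0 or 1.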
    # the input to the sampling function
    network_state_input = [X] + [T.fmatrix() for i in range(layers)]

    # "Output" state of the network (noisy)
    # initialized with input, then we apply updates
    # network_state_output    =   network_state_input

    network_state_output = [X] + network_state_input[1:]

    visible_pX_chain = []

    # ONE update
    logger.log("Performing one walkback in network state sampling.")
    update_layers(network_state_output, visible_pX_chain, noisy=True)

    if layers == 1:
        f_sample_simple = theano.function(inputs=[X], outputs=visible_pX_chain[-1])

    # on_unused_input='warn': the first odd-layer inputs are never read because the odd layers
    # are computed directly FROM THE EVEN layers, so the unused-input warning is expected.
    f_sample2 = theano.function(inputs=network_state_input, outputs=network_state_output + visible_pX_chain,
                                on_unused_input='warn')

    def sample_some_numbers_single_layer():
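        # Start from one test digit, then repeatedly: add salt-and-pepper noise, run a single
        # denoising step, and re-binarize by sampling each pixel from a Bernoulli with the
        # predicted probability. The 400 visible states collected along the way are returned.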
        x0 = test_X.get_value()[7:8]
        samples = [x0]
        x = f_noise(x0)
        for i in range(399):
            x = f_sample_simple(x)
            samples.append(x)
            x = numpy.random.binomial(n=1, p=x, size=x.shape).astype('float32')
            x = f_noise(x)
        return numpy.vstack(samples)

    def sampling_wrapper(NSI):
        # * is the "splat" operator: It takes a list as input, and expands it into actual positional arguments in the function call.
        out = f_sample2(*NSI)
        NSO = out[:len(network_state_output)]
        vis_pX_chain = out[len(network_state_output):]
        return NSO, vis_pX_chain

    def sample_some_numbers(N=400):
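        # Multi-layer sampling: run the network for N-1 steps, feeding each full network state back
        # in through f_sample2 and collecting the visible expectations p(X|...) produced at every step.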
        # The network's initial state
        init_vis = test_X.get_value()[7:8]

        noisy_init_vis = f_noise(init_vis)

        network_state = [
            [noisy_init_vis] + [numpy.zeros((1, len(b.get_value())), dtype='float32') for b in bias_list[1:]]]

        visible_chain = [init_vis]

        noisy_h0_chain = [noisy_init_vis]

        for i in range(N - 1):
            # feed the last state into the network, compute new state, and obtain visible units expectation chain
            net_state_out, vis_pX_chain = sampling_wrapper(network_state[-1])

            # append to the visible chain
            visible_chain += vis_pX_chain

            # append state output to the network state chain
            network_state.append(net_state_out)

            noisy_h0_chain.append(net_state_out[0])

        return numpy.vstack(visible_chain), numpy.vstack(noisy_h0_chain)

    def plot_samples(epoch_number, iteration):
        to_sample = time.time()
        if layers == 1:
            # one layer model
            V = sample_some_numbers_single_layer()
        else:
            V, _ = sample_some_numbers()
        img_samples = PIL.Image.fromarray(tile_raster_images(V, (root_N_input, root_N_input), (20, 20)))

        fname = outdir + 'samples_iteration_' + str(iteration) + '_epoch_' + str(epoch_number) + '.png'
        img_samples.save(fname)
        logger.log('Took ' + str(time.time() - to_sample) + ' to sample 400 numbers')

    #############################
    # Save the model parameters #
    #############################
    def save_params_to_file(name, n, gsn_params, iteration):
        logger.log('saving parameters...')
        save_path = outdir + name + '_params_iteration_' + str(iteration) + '_epoch_' + str(n) + '.pkl'
        f = open(save_path, 'wb')
        try:
            cPickle.dump(gsn_params, f, protocol=cPickle.HIGHEST_PROTOCOL)
        finally:
            f.close()

    def save_params(params):
        values = [param.get_value(borrow=True) for param in params]
        return values

    def restore_params(params, values):
        for i in range(len(params)):
            params[i].set_value(values[i])
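    # save_params / restore_params snapshot and roll back the raw shared-variable values.
    # Hypothetical usage sketch (the names below are illustrative, not part of this script):
    #   snapshot = save_params(gsn_params)
    #   ...train for an epoch...
    #   if validation_got_worse: restore_params(gsn_params, snapshot)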

    ################
    # GSN TRAINING #
    ################
    def train_GSN(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        logger.log('----------------TRAINING GSN FOR ITERATION ' + str(iteration) + "--------------\n")

        # TRAINING
        n_epoch = state.n_epoch
        batch_size = state.batch_size
        STOP = False
        counter = 0
        if iteration == 0:
            learning_rate.set_value(cast32(state.learning_rate))  # learning rate
        times = []
        best_cost = float('inf')
        best_params = None
        patience = 0

        logger.log(['learning rate:', learning_rate.get_value()])

        logger.log(['train X size:', str(train_X.shape.eval())])
        logger.log(['valid X size:', str(valid_X.shape.eval())])
        logger.log(['test X size:', str(test_X.shape.eval())])

        if state.vis_init:
            # clip the mean pixel activations away from 0 and 1 before taking the logit
            bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))

        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            logger.log('Testing : skip training')
            STOP = True

        while not STOP:
            counter += 1
            t = time.time()
            logger.append([counter, '\t'])

            # shuffle the data
            # data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)

            # train
            train_costs = []
            train_errors = []
            if iteration == 0:
                for i in range(len(train_X.get_value(borrow=True)) / batch_size):
                    x = train_X.get_value(borrow=True)[i * batch_size: (i + 1) * batch_size]
                    cost, error = gsn_f_learn_init(x)
                    train_costs.append([cost])
                    train_errors.append([error])
            else:
                for i in range(len(train_X.get_value(borrow=True)) / batch_size):
                    xs = [train_X.get_value(borrow=True)[
                          (i * batch_size) + sequence_idx: ((i + 1) * batch_size) + sequence_idx] for sequence_idx in
                          range(len(Xs))]
                    xs, _ = fix_input_size(xs)
                    _ins = xs  # + [sequence_window_size]
                    cost, error = gsn_f_learn(*_ins)
                    train_costs.append(cost)
                    train_errors.append(error)

            train_costs = numpy.mean(train_costs)
            train_errors = numpy.mean(train_errors)
            logger.append(['Train: ', trunc(train_costs), trunc(train_errors), '\t'])
            with open(train_convergence, 'a') as f:
                f.write("{0!s},".format(train_costs))
                f.write("\n")

            # valid
            valid_costs = []
            if iteration == 0:
                for i in range(len(valid_X.get_value(borrow=True)) / batch_size):
                    x = valid_X.get_value(borrow=True)[i * batch_size: (i + 1) * batch_size]
                    cost, _ = gsn_f_cost_init(x)
                    valid_costs.append([cost])
            else:
                for i in range(len(valid_X.get_value(borrow=True)) / batch_size):
                    xs = [valid_X.get_value(borrow=True)[
                          (i * batch_size) + sequence_idx: ((i + 1) * batch_size) + sequence_idx] for sequence_idx in
                          range(len(Xs))]
                    xs, _ = fix_input_size(xs)
                    _ins = xs  # + [sequence_window_size]
                    costs, _ = gsn_f_cost(*_ins)
                    valid_costs.append(costs)

            valid_costs = numpy.mean(valid_costs)
            logger.append(['Valid: ', trunc(valid_costs), '\t'])
            with open(valid_convergence, 'a') as f:
                f.write("{0!s},".format(valid_costs))
                f.write("\n")

            # test
            test_costs = []
            test_errors = []
            if iteration == 0:
                for i in range(len(test_X.get_value(borrow=True)) / batch_size):
                    x = test_X.get_value(borrow=True)[i * batch_size: (i + 1) * batch_size]
                    cost, error = gsn_f_cost_init(x)
                    test_costs.append([cost])
                    test_errors.append([error])
            else:
                for i in range(len(test_X.get_value(borrow=True)) / batch_size):
                    xs = [test_X.get_value(borrow=True)[
                          (i * batch_size) + sequence_idx: ((i + 1) * batch_size) + sequence_idx] for sequence_idx in
                          range(len(Xs))]
                    xs, _ = fix_input_size(xs)
                    _ins = xs  # + [sequence_window_size]
                    costs, errors = gsn_f_cost(*_ins)
                    test_costs.append(costs)
                    test_errors.append(errors)

            test_costs = numpy.mean(test_costs)
            test_errors = numpy.mean(test_errors)
            logger.append(['Test: ', trunc(test_costs), trunc(test_errors), '\t'])
            with open(test_convergence, 'a') as f:
                f.write("{0!s},".format(test_costs))
                f.write("\n")

            # check for early stopping: the validation cost must beat best_cost by the relative
            # factor state.early_stop_threshold, otherwise patience accumulates until it reaches
            # state.early_stop_length
            cost = numpy.sum(valid_costs)
            if cost < best_cost * state.early_stop_threshold:
                patience = 0
                best_cost = cost
                # save the parameters that made it the best
                best_params = save_params(gsn_params)
            else:
                patience += 1

            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                if best_params is not None:
                    restore_params(gsn_params, best_params)
                save_params_to_file('gsn', counter, gsn_params, iteration)
                logger.log(["next learning rate should be", learning_rate.get_value() * annealing])

            timing = time.time() - t
            times.append(timing)

            logger.append('time: ' + make_time_units_string(timing))

            logger.log('remaining: ' + make_time_units_string((n_epoch - counter) * numpy.mean(times)))

            if (counter % state.save_frequency) == 0 or STOP is True:
                n_examples = 100
                if iteration == 0:
                    random_idx = numpy.array(R.sample(range(len(test_X.get_value())), n_examples))
                    numbers = test_X.get_value()[random_idx]
                    noisy_numbers = f_noise(test_X.get_value()[random_idx])
                    reconstructed = f_recon_init(noisy_numbers)
                    # Concatenate stuff
                    stacked = numpy.vstack([numpy.vstack(
                        [numbers[i * 10: (i + 1) * 10], noisy_numbers[i * 10: (i + 1) * 10],
                         reconstructed[i * 10: (i + 1) * 10]]) for i in range(10)])
                    number_reconstruction = PIL.Image.fromarray(
                        tile_raster_images(stacked, (root_N_input, root_N_input), (10, 30)))
                else:
                    n_examples = n_examples + sequence_window_size
                    # Checking reconstruction
                    # grab 100 numbers in the sequence from the test set
                    nums = test_X.get_value()[range(n_examples)]
                    noisy_nums = f_noise(test_X.get_value()[range(n_examples)])

                    reconstructed_prediction = []
                    reconstructed = []
                    for i in range(n_examples):
                        if i >= sequence_window_size:
                            xs = [noisy_nums[i - x] for x in range(len(Xs))]
                            xs.reverse()
                            _ins = xs  # + [sequence_window_size]
                            _outs = f_recon(*_ins)
                            prediction = _outs[0]
                            reconstruction = _outs[1]
                            reconstructed_prediction.append(prediction)
                            reconstructed.append(reconstruction)
                    nums = nums[sequence_window_size:]
                    noisy_nums = noisy_nums[sequence_window_size:]
                    reconstructed_prediction = numpy.array(reconstructed_prediction)
                    reconstructed = numpy.array(reconstructed)

                    # Concatenate stuff
                    stacked = numpy.vstack([numpy.vstack([nums[i * 10: (i + 1) * 10], noisy_nums[i * 10: (i + 1) * 10],
                                                          reconstructed_prediction[i * 10: (i + 1) * 10],
                                                          reconstructed[i * 10: (i + 1) * 10]]) for i in range(10)])
                    number_reconstruction = PIL.Image.fromarray(
                        tile_raster_images(stacked, (root_N_input, root_N_input), (10, 40)))

                # epoch_number    =   reduce(lambda x,y : x + y, ['_'] * (4-len(str(counter)))) + str(counter)
                number_reconstruction.save(
                    outdir + 'gsn_number_reconstruction_iteration_' + str(iteration) + '_epoch_' + str(
                        counter) + '.png')

                # sample_numbers(counter, 'seven')
                plot_samples(counter, iteration)

                # save gsn_params
                save_params_to_file('gsn', counter, gsn_params, iteration)

            # ANNEAL!
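            # exponential decay: each epoch multiplies the learning rate by the annealing
            # coefficient, so after t epochs it has been scaled by annealing**t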
            new_lr = learning_rate.get_value() * annealing
            learning_rate.set_value(new_lr)

        # 10k samples
        logger.log('Generating 10,000 samples')
        samples, _ = sample_some_numbers(N=10000)
        f_samples = outdir + 'samples.npy'
        numpy.save(f_samples, samples)
        logger.log('saved digits')

    #######################
    # REGRESSION TRAINING #
    #######################        
    def train_regression(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        logger.log('-------------TRAINING REGRESSION FOR ITERATION {0!s}-------------'.format(iteration))

        # TRAINING
        n_epoch = state.n_epoch
        batch_size = state.batch_size
        STOP = False
        counter = 0
        best_cost = float('inf')
        best_params = None
        patience = 0
        if iteration == 0:
            regression_learning_rate.set_value(cast32(state.learning_rate))  # learning rate
        times = []

        logger.log(['learning rate:', regression_learning_rate.get_value()])

        logger.log(['train X size:', str(train_X.shape.eval())])
        logger.log(['valid X size:', str(valid_X.shape.eval())])
        logger.log(['test X size:', str(test_X.shape.eval())])

        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            logger.log('Testing : skip training')
            STOP = True

        while not STOP:
            counter += 1
            t = time.time()
            logger.append([counter, '\t'])

            # shuffle the data
            # data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)

            # train
            train_costs = []
            train_errors = []
            for i in range(len(train_X.get_value(borrow=True)) / batch_size):
                xs = [train_X.get_value(borrow=True)[
                      (i * batch_size) + sequence_idx: ((i + 1) * batch_size) + sequence_idx] for sequence_idx in
                      range(len(Xs))]
                xs, _ = fix_input_size(xs)
                _ins = xs  # + [sequence_window_size]
                cost, error = regression_f_learn(*_ins)
                # print trunc(cost)
                # print [numpy.asarray(a) for a in f_check(*_ins)]
                train_costs.append(cost)
                train_errors.append(error)

            train_costs = numpy.mean(train_costs)
            train_errors = numpy.mean(train_errors)
            logger.append(['rTrain: ', trunc(train_costs), trunc(train_errors), '\t'])
            with open(regression_train_convergence, 'a') as f:
                f.write("{0!s},".format(train_costs))
                f.write("\n")

            # valid
            valid_costs = []
            for i in range(len(valid_X.get_value(borrow=True)) / batch_size):
                xs = [valid_X.get_value(borrow=True)[
                      (i * batch_size) + sequence_idx: ((i + 1) * batch_size) + sequence_idx] for sequence_idx in
                      range(len(Xs))]
                xs, _ = fix_input_size(xs)
                _ins = xs  # + [sequence_window_size]
                cost, _ = regression_f_cost(*_ins)
                valid_costs.append(cost)

            valid_costs = numpy.mean(valid_costs)
            logger.append(['rValid: ', trunc(valid_costs), '\t'])
            with open(regression_valid_convergence, 'a') as f:
                f.write("{0!s},".format(valid_costs))
                f.write("\n")

            # test
            test_costs = []
            test_errors = []
            for i in range(len(test_X.get_value(borrow=True)) / batch_size):
                xs = [test_X.get_value(borrow=True)[
                      (i * batch_size) + sequence_idx: ((i + 1) * batch_size) + sequence_idx] for sequence_idx in
                      range(len(Xs))]
                xs, _ = fix_input_size(xs)
                _ins = xs  # + [sequence_window_size]
                cost, error = regression_f_cost(*_ins)
                test_costs.append(cost)
                test_errors.append(error)

            test_costs = numpy.mean(test_costs)
            test_errors = numpy.mean(test_errors)
            logger.append(['rTest: ', trunc(test_costs), trunc(test_errors), '\t'])
            with open(regression_test_convergence, 'a') as f:
                f.write("{0!s},".format(test_costs))
                f.write("\n")

            # check for early stopping
            cost = numpy.sum(valid_costs)
            if cost < best_cost * state.early_stop_threshold:
                patience = 0
                best_cost = cost
                # keep the best params so far
                best_params = save_params(regression_params)
            else:
                patience += 1

            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                if best_params is not None:
                    restore_params(regression_params, best_params)
                save_params_to_file('regression', counter, regression_params, iteration)
                logger.log(["next learning rate should be", regression_learning_rate.get_value() * annealing])

            timing = time.time() - t
            times.append(timing)

            logger.append('time: ' + make_time_units_string(timing))

            logger.log('remaining: ' + make_time_units_string((n_epoch - counter) * numpy.mean(times)))

            if (counter % state.save_frequency) == 0 or STOP is True:
                n_examples = 100 + sequence_window_size
                # Checking reconstruction
                # grab 100 numbers in the sequence from the test set
                nums = test_X.get_value()[range(n_examples)]
                noisy_nums = f_noise(test_X.get_value()[range(n_examples)])

                reconstructed_prediction = []
                reconstructed = []
                for i in range(n_examples):
                    if i >= sequence_window_size:
                        xs = [noisy_nums[i - x] for x in range(len(Xs))]
                        xs.reverse()
                        _ins = xs  # + [sequence_window_size]
                        _outs = f_recon(*_ins)
                        prediction = _outs[0]
                        reconstruction = _outs[1]
                        reconstructed_prediction.append(prediction)
                        reconstructed.append(reconstruction)
                nums = nums[sequence_window_size:]
                noisy_nums = noisy_nums[sequence_window_size:]
                reconstructed_prediction = numpy.array(reconstructed_prediction)
                reconstructed = numpy.array(reconstructed)

                # Concatenate stuff
                stacked = numpy.vstack([numpy.vstack([nums[i * 10: (i + 1) * 10], noisy_nums[i * 10: (i + 1) * 10],
                                                      reconstructed_prediction[i * 10: (i + 1) * 10],
                                                      reconstructed[i * 10: (i + 1) * 10]]) for i in range(10)])

                number_reconstruction = PIL.Image.fromarray(
                    tile_raster_images(stacked, (root_N_input, root_N_input), (10, 40)))
                # epoch_number    =   reduce(lambda x,y : x + y, ['_'] * (4-len(str(counter)))) + str(counter)
                number_reconstruction.save(
                    outdir + 'regression_number_reconstruction_iteration_' + str(iteration) + '_epoch_' + str(
                        counter) + '.png')

                # save gsn_params
                save_params_to_file('regression', counter, regression_params, iteration)

            # ANNEAL!
            new_r_lr = regression_learning_rate.get_value() * annealing
            regression_learning_rate.set_value(new_r_lr)

    #####################
    # STORY 1 ALGORITHM #
    #####################
    # alternate training the gsn and training the regression
    for iteration in range(state.max_iterations):
        # if iteration is 0 and initialized_gsn is False:
        #     train_regression(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
        # else:
        #     train_GSN(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
        #     train_regression(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
        train_GSN(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
        train_regression(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y)

    def train_recurrent_GSN(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        print '----------------------------------------'
        print 'TRAINING GSN FOR ITERATION', iteration
        with open(logfile, 'a') as f:
            f.write(
                "--------------------------\nTRAINING GSN FOR ITERATION {0!s}\n"
                .format(iteration))

        # TRAINING
        n_epoch = state.n_epoch
        batch_size = state.batch_size
        STOP = False
        counter = 0
        if iteration == 0:
            learning_rate.set_value(cast32(
                state.learning_rate))  # learning rate
        times = []
        best_cost = float('inf')
        patience = 0

        print 'learning rate:', learning_rate.get_value()

        print 'train X size:', str(train_X.shape.eval())
        print 'valid X size:', str(valid_X.shape.eval())
        print 'test X size:', str(test_X.shape.eval())

        train_costs = []
        valid_costs = []
        test_costs = []
        train_costs_post = []
        valid_costs_post = []
        test_costs_post = []

        if state.vis_init:
            # clip the mean pixel activations away from 0 and 1 before taking the logit
            bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))

        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            print 'Testing : skip training'
            STOP = True

        while not STOP:
            counter += 1
            t = time.time()
            print counter, '\t',
            with open(logfile, 'a') as f:
                f.write("{0!s}\t".format(counter))
            #shuffle the data
            data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y,
                                     test_X, test_Y, dataset, rng)

            #train
            #init hiddens
            #             hiddens = [(T.zeros_like(train_X[:batch_size]).eval())]
            #             for i in range(len(weights_list)):
            #                 # init with zeros
            #                 hiddens.append(T.zeros_like(T.dot(hiddens[i], weights_list[i])).eval())
            hiddens = [
                T.zeros((batch_size, layer_size)).eval()
                for layer_size in layer_sizes
            ]
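            # The recurrent hidden states are zero-initialized once per epoch and then carried across
            # minibatches: f_learn returns the updated hiddens, which seed the next minibatch below.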
            train_cost = []
            train_cost_post = []
            for i in range(len(train_X.get_value(borrow=True)) / batch_size):
                xs = [
                    train_X.get_value(
                        borrow=True)[(i * batch_size) +
                                     sequence_idx:((i + 1) * batch_size) +
                                     sequence_idx]
                    for sequence_idx in range(len(Xs))
                ]
                xs, hiddens = fix_input_size(xs, hiddens)
                hiddens[0] = xs[0]
                _ins = hiddens + xs
                _outs = f_learn(*_ins)
                hiddens = _outs[:len(hiddens)]
                cost = _outs[-2]
                cost_post = _outs[-1]
                train_cost.append(cost)
                train_cost_post.append(cost_post)

            train_cost = numpy.mean(train_cost)
            train_costs.append(train_cost)
            train_cost_post = numpy.mean(train_cost_post)
            train_costs_post.append(train_cost_post)
            print 'Train : ', trunc(train_cost), trunc(train_cost_post), '\t',
            with open(logfile, 'a') as f:
                f.write("Train : {0!s} {1!s}\t".format(trunc(train_cost),
                                                       trunc(train_cost_post)))
            with open(train_convergence_pre, 'a') as f:
                f.write("{0!s},".format(train_cost))
            with open(train_convergence_post, 'a') as f:
                f.write("{0!s},".format(train_cost_post))

            #valid
            #init hiddens
            hiddens = [
                T.zeros((batch_size, layer_size)).eval()
                for layer_size in layer_sizes
            ]
            valid_cost = []
            valid_cost_post = []
            for i in range(len(valid_X.get_value(borrow=True)) / batch_size):
                xs = [
                    valid_X.get_value(
                        borrow=True)[(i * batch_size) +
                                     sequence_idx:((i + 1) * batch_size) +
                                     sequence_idx]
                    for sequence_idx in range(len(Xs))
                ]
                xs, hiddens = fix_input_size(xs, hiddens)
                hiddens[0] = xs[0]
                _ins = hiddens + xs
                _outs = f_cost(*_ins)
                hiddens = _outs[:-2]
                cost = _outs[-2]
                cost_post = _outs[-1]
                valid_cost.append(cost)
                valid_cost_post.append(cost_post)

            valid_cost = numpy.mean(valid_cost)
            valid_costs.append(valid_cost)
            valid_cost_post = numpy.mean(valid_cost_post)
            valid_costs_post.append(valid_cost_post)
            print 'Valid : ', trunc(valid_cost), trunc(valid_cost_post), '\t',
            with open(logfile, 'a') as f:
                f.write("Valid : {0!s} {1!s}\t".format(trunc(valid_cost),
                                                       trunc(valid_cost_post)))
            with open(valid_convergence_pre, 'a') as f:
                f.write("{0!s},".format(valid_cost))
            with open(valid_convergence_post, 'a') as f:
                f.write("{0!s},".format(valid_cost_post))

            #test
            #init hiddens
            hiddens = [
                T.zeros((batch_size, layer_size)).eval()
                for layer_size in layer_sizes
            ]
            test_cost = []
            test_cost_post = []
            for i in range(len(test_X.get_value(borrow=True)) / batch_size):
                xs = [
                    test_X.get_value(
                        borrow=True)[(i * batch_size) +
                                     sequence_idx:((i + 1) * batch_size) +
                                     sequence_idx]
                    for sequence_idx in range(len(Xs))
                ]
                xs, hiddens = fix_input_size(xs, hiddens)
                hiddens[0] = xs[0]
                _ins = hiddens + xs
                _outs = f_cost(*_ins)
                hiddens = _outs[:-2]
                cost = _outs[-2]
                cost_post = _outs[-1]
                test_cost.append(cost)
                test_cost_post.append(cost_post)

            test_cost = numpy.mean(test_cost)
            test_costs.append(test_cost)
            test_cost_post = numpy.mean(test_cost_post)
            test_costs_post.append(test_cost_post)
            print 'Test  : ', trunc(test_cost), trunc(test_cost_post), '\t',
            with open(logfile, 'a') as f:
                f.write("Test : {0!s} {1!s}\t".format(trunc(test_cost),
                                                      trunc(test_cost_post)))
            with open(test_convergence_pre, 'a') as f:
                f.write("{0!s},".format(test_cost))
            with open(test_convergence_post, 'a') as f:
                f.write("{0!s},".format(test_cost_post))

            #check for early stopping
            cost = train_cost
            if cost < best_cost * state.early_stop_threshold:
                patience = 0
                best_cost = cost
            else:
                patience += 1

            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                save_params_to_file('gsn', counter, params, iteration)

            timing = time.time() - t
            times.append(timing)

            print 'time : ', trunc(timing),

            print 'remaining: ', trunc(
                (n_epoch - counter) * numpy.mean(times) / 60 / 60), 'hrs',

            print 'B : ', [
                trunc(abs(b.get_value(borrow=True)).mean()) for b in bias_list
            ],

            print 'W : ', [
                trunc(abs(w.get_value(borrow=True)).mean())
                for w in weights_list
            ],

            print 'V : ', [
                trunc(abs(v.get_value(borrow=True)).mean())
                for v in recurrent_weights_list
            ]

            with open(logfile, 'a') as f:
                f.write("MeanVisB : {0!s}\t".format(
                    trunc(bias_list[0].get_value().mean())))

            with open(logfile, 'a') as f:
                f.write("W : {0!s}\t".format(
                    str([
                        trunc(abs(w.get_value(borrow=True)).mean())
                        for w in weights_list
                    ])))

            with open(logfile, 'a') as f:
                f.write("Time : {0!s} seconds\n".format(trunc(timing)))

            if (counter % state.save_frequency) == 0:
                # Checking reconstruction
                nums = test_X.get_value()[range(100)]
                noisy_nums = f_noise(test_X.get_value()[range(100)])
                reconstructed_prediction = []
                reconstructed_prediction_end = []
                #init reconstruction hiddens
                hiddens = [
                    T.zeros(layer_size).eval() for layer_size in layer_sizes
                ]
                for num in noisy_nums:
                    hiddens[0] = num
                    for i in range(len(hiddens)):
                        if len(hiddens[i].shape) == 2 and hiddens[i].shape[0] == 1:
                            hiddens[i] = hiddens[i][0]
                    _ins = hiddens + [num]
                    _outs = f_recon(*_ins)
                    hiddens = _outs[:len(hiddens)]
                    [reconstructed_1, reconstructed_n] = _outs[len(hiddens):]
                    reconstructed_prediction.append(reconstructed_1)
                    reconstructed_prediction_end.append(reconstructed_n)

                with open(logfile, 'a') as f:
                    f.write("\n")
                for i in range(len(nums)):
                    if len(reconstructed_prediction[i].shape) == 2 and reconstructed_prediction[i].shape[0] == 1:
                        reconstructed_prediction[i] = reconstructed_prediction[i][0]
                    print nums[i].tolist(), "->", reconstructed_prediction[i].tolist()
                    with open(logfile, 'a') as f:
                        f.write("{0!s} -> {1!s}\n".format(
                            nums[i].tolist(), [
                                trunc(n)
                                if n > 0.0001 else trunc(0.00000000000000000)
                                for n in reconstructed_prediction[i].tolist()
                            ]))
                with open(logfile, 'a') as f:
                    f.write("\n")

#                 # Concatenate stuff
#                 stacked = numpy.vstack([numpy.vstack([nums[i*10 : (i+1)*10], noisy_nums[i*10 : (i+1)*10], reconstructed_prediction[i*10 : (i+1)*10], reconstructed_prediction_end[i*10 : (i+1)*10]]) for i in range(10)])
#                 numbers_reconstruction = PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,40)))
#                 numbers_reconstruction.save(outdir+'gsn_number_reconstruction_iteration_'+str(iteration)+'_epoch_'+str(counter)+'.png')
#
#                 #sample_numbers(counter, 'seven')
#                 plot_samples(counter, iteration)
#
#                 #save params
#                 save_params_to_file('gsn', counter, params, iteration)

            # ANNEAL!
            new_lr = learning_rate.get_value() * annealing
            learning_rate.set_value(new_lr)

        # 10k samples
        print 'Generating 10,000 samples'
        samples, _ = sample_some_numbers(N=10000)
        f_samples = outdir + 'samples.npy'
        numpy.save(f_samples, samples)
        print 'saved digits'

    def train_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        logger.log("\n-----------TRAINING GSN------------\n")

        # TRAINING
        n_epoch = state.n_epoch
        batch_size = state.gsn_batch_size
        STOP = False
        counter = 0
        learning_rate.set_value(cast32(state.learning_rate))  # learning rate
        times = []
        best_cost = float('inf')
        best_params = None
        patience = 0

        logger.log(['train X size:', str(train_X.shape.eval())])
        logger.log(['valid X size:', str(valid_X.shape.eval())])
        logger.log(['test X size:', str(test_X.shape.eval())])

        if state.vis_init:
            # clip the mean pixel activations away from 0 and 1 before taking the logit
            bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))

        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            logger.log('Testing : skip training')
            STOP = True

        while not STOP:
            counter += 1
            t = time.time()
            logger.append([counter, '\t'])

            #shuffle the data
            data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y,
                                     test_X, test_Y, dataset, rng)

            #train
            train_costs = []
            for i in xrange(len(train_X.get_value(borrow=True)) / batch_size):
                x = train_X.get_value()[i * batch_size:(i + 1) * batch_size]
                cost = f_learn_gsn(x)
                train_costs.append([cost])
            train_costs = numpy.mean(train_costs)
            # record it
            logger.append(['Train:', trunc(train_costs), '\t'])
            with open(gsn_train_convergence, 'a') as f:
                f.write("{0!s},".format(train_costs))
                f.write("\n")

            #valid
            valid_costs = []
            for i in xrange(len(valid_X.get_value(borrow=True)) / batch_size):
                x = valid_X.get_value()[i * batch_size:(i + 1) * batch_size]
                cost = f_cost_gsn(x)
                valid_costs.append([cost])
            valid_costs = numpy.mean(valid_costs)
            # record it
            logger.append(['Valid:', trunc(valid_costs), '\t'])
            with open(gsn_valid_convergence, 'a') as f:
                f.write("{0!s},".format(valid_costs))
                f.write("\n")

            #test
            test_costs = []
            for i in xrange(len(test_X.get_value(borrow=True)) / batch_size):
                x = test_X.get_value()[i * batch_size:(i + 1) * batch_size]
                cost = f_cost_gsn(x)
                test_costs.append([cost])
            test_costs = numpy.mean(test_costs)
            # record it
            logger.append(['Test:', trunc(test_costs), '\t'])
            with open(gsn_test_convergence, 'a') as f:
                f.write("{0!s},".format(test_costs))
                f.write("\n")

            #check for early stopping
            cost = numpy.sum(valid_costs)
            if cost < best_cost * state.early_stop_threshold:
                patience = 0
                best_cost = cost
                # save the parameters that made it the best
                best_params = save_params(gsn_params)
            else:
                patience += 1

            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                if best_params is not None:
                    restore_params(gsn_params, best_params)
                save_params_to_file('gsn', counter, gsn_params)

            timing = time.time() - t
            times.append(timing)

            logger.append('time: ' + make_time_units_string(timing) + '\t')

            logger.log('remaining: ' +
                       make_time_units_string((n_epoch - counter) *
                                              numpy.mean(times)))

            if (counter % state.save_frequency) == 0 or STOP is True:
                n_examples = 100
                random_idx = numpy.array(
                    R.sample(range(len(test_X.get_value(borrow=True))),
                             n_examples))
                numbers = test_X.get_value(borrow=True)[random_idx]
                noisy_numbers = f_noise(
                    test_X.get_value(borrow=True)[random_idx])
                reconstructed = f_recon_gsn(noisy_numbers)
                # Concatenate stuff
                stacked = numpy.vstack([
                    numpy.vstack([
                        numbers[i * 10:(i + 1) * 10],
                        noisy_numbers[i * 10:(i + 1) * 10],
                        reconstructed[i * 10:(i + 1) * 10]
                    ]) for i in range(10)
                ])
                number_reconstruction = PIL.Image.fromarray(
                    tile_raster_images(stacked, (root_N_input, root_N_input),
                                       (10, 30)))

                number_reconstruction.save(outdir +
                                           'gsn_number_reconstruction_epoch_' +
                                           str(counter) + '.png')

                #sample_numbers(counter, 'seven')
                plot_samples(counter, 'gsn')

                #save gsn_params
                save_params_to_file('gsn', counter, gsn_params)

            # ANNEAL!
            new_lr = learning_rate.get_value() * annealing
            learning_rate.set_value(new_lr)

        # 10k samples
        print 'Generating 10,000 samples'
        samples, _ = sample_some_numbers(N=10000)
        f_samples = outdir + 'samples.npy'
        numpy.save(f_samples, samples)
        print 'saved digits'

    def train_RNN_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        # If we are using Hessian-free training
        if state.hessian_free == 1:
            pass
#         gradient_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=5000)
#         cg_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=1000)
#         valid_dataset = hf_sequence_dataset([valid_X.get_value()], batch_size=None, number_batches=1000)
#         
#         s = x_samples
#         costs = [cost, show_cost]
#         hf_optimizer(params, [Xs], s, costs, u, ua).train(gradient_dataset, cg_dataset, initial_lambda=1.0, preconditioner=True, validation=valid_dataset)
        
        # If we are using SGD training
        else:
            # Define the re-used loops for f_learn and f_cost
            def apply_cost_function_to_dataset(function, dataset):
                costs = []
                for i in xrange(len(dataset.get_value(borrow=True)) / batch_size):
                    xs = dataset.get_value(borrow=True)[i * batch_size : (i+1) * batch_size]
                    cost = function(xs)
                    costs.append([cost])
                return numpy.mean(costs)
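            # Note: apply_cost_function_to_dataset closes over batch_size, which is assigned just
            # below (before the first call inside the training loop), and averages the compiled
            # Theano function over all full minibatches of the given dataset.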
            
            logger.log("\n-----------TRAINING RNN-GSN------------\n")
            # TRAINING
            n_epoch     =   state.n_epoch
            batch_size  =   state.batch_size
            STOP        =   False
            counter     =   0
            learning_rate.set_value(cast32(state.learning_rate))  # learning rate
            times = []
            best_cost = float('inf')
            best_params = None
            patience = 0
                        
            logger.log(['train X size:',str(train_X.shape.eval())])
            logger.log(['valid X size:',str(valid_X.shape.eval())])
            logger.log(['test X size:',str(test_X.shape.eval())])
            
            if state.vis_init:
                # clip the mean pixel activations away from 0 and 1 before taking the logit
                bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))
        
            if state.test_model:
                # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
                logger.log('Testing : skip training')
                STOP    =   True
        
            while not STOP:
                counter += 1
                t = time.time()
                logger.append([counter,'\t'])
                    
                #shuffle the data
                data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
                     
                #train
                train_costs = apply_cost_function_to_dataset(f_learn, train_X)
                # record it
                logger.append(['Train:',trunc(train_costs),'\t'])
                with open(train_convergence,'a') as f:
                    f.write("{0!s},".format(train_costs))
                    f.write("\n")
         
         
                #valid
                valid_costs = apply_cost_function_to_dataset(f_cost, valid_X)
                # record it
                logger.append(['Valid:',trunc(valid_costs), '\t'])
                with open(valid_convergence,'a') as f:
                    f.write("{0!s},".format(valid_costs))
                    f.write("\n")
         
         
                #test
                test_costs = apply_cost_function_to_dataset(f_cost, test_X)
                # record it 
                logger.append(['Test:',trunc(test_costs), '\t'])
                with open(test_convergence,'a') as f:
                    f.write("{0!s},".format(test_costs))
                    f.write("\n")
                 
                 
                #check for early stopping
                cost = numpy.sum(valid_costs)
                if cost < best_cost*state.early_stop_threshold:
                    patience = 0
                    best_cost = cost
                    # save the parameters that made it the best
                    best_params = save_params(params)
                else:
                    patience += 1
         
                if counter >= n_epoch or patience >= state.early_stop_length:
                    STOP = True
                    if best_params is not None:
                        restore_params(params, best_params)
                    save_params_to_file('all', counter, params)
         
                timing = time.time() - t
                times.append(timing)
         
                logger.append('time: '+make_time_units_string(timing)+'\t')
            
                logger.log('remaining: '+make_time_units_string((n_epoch - counter) * numpy.mean(times)))
        
                if (counter % state.save_frequency) == 0 or STOP is True:
                    n_examples = 100
                    nums = test_X.get_value(borrow=True)[range(n_examples)]
                    noisy_nums = f_noise(test_X.get_value(borrow=True)[range(n_examples)])
                    reconstructions = []
                    for i in xrange(0, len(noisy_nums)):
                        recon = f_recon(noisy_nums[max(0,(i+1)-batch_size):i+1])
                        reconstructions.append(recon)
                    reconstructed = numpy.array(reconstructions)

                    # Concatenate stuff
                    stacked = numpy.vstack([numpy.vstack([nums[i*10 : (i+1)*10], noisy_nums[i*10 : (i+1)*10], reconstructed[i*10 : (i+1)*10]]) for i in range(10)])
                    number_reconstruction = PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,30)))
                        
                    number_reconstruction.save(outdir+'rnngsn_number_reconstruction_epoch_'+str(counter)+'.png')
            
                    #sample_numbers(counter, 'seven')
                    plot_samples(counter, 'rnngsn')
            
                    #save params
                    save_params_to_file('all', counter, params)
             
                # ANNEAL!
                new_lr = learning_rate.get_value() * annealing
                learning_rate.set_value(new_lr)
    
            
            # 10k samples
            print 'Generating 10,000 samples'
            samples, _  =   sample_some_numbers(N=10000)
            f_samples   =   outdir+'samples.npy'
            numpy.save(f_samples, samples)
            print 'saved digits'


def experiment(state, outdir_base='./'):
    rng.seed(1)  #seed the numpy random generator
    R.seed(1)  #seed the other random generator (for reconstruction function indices)
    # Initialize output directory and files
    data.mkdir_p(outdir_base)
    outdir = outdir_base + "/" + state.dataset + "/"
    data.mkdir_p(outdir)
    logger = Logger(outdir)
    logger.log("----------MODEL 2, {0!s}-----------\n".format(state.dataset))
    gsn_train_convergence = outdir + "gsn_train_convergence.csv"
    gsn_valid_convergence = outdir + "gsn_valid_convergence.csv"
    gsn_test_convergence = outdir + "gsn_test_convergence.csv"
    train_convergence = outdir + "train_convergence.csv"
    valid_convergence = outdir + "valid_convergence.csv"
    test_convergence = outdir + "test_convergence.csv"
    init_empty_file(gsn_train_convergence)
    init_empty_file(gsn_valid_convergence)
    init_empty_file(gsn_test_convergence)
    init_empty_file(train_convergence)
    init_empty_file(valid_convergence)
    init_empty_file(test_convergence)

    #load parameters from config file if this is a test
    config_filename = outdir + 'config'
    if state.test_model and 'config' in os.listdir(outdir):
        config_vals = load_from_config(config_filename)
        for CV in config_vals:
            logger.log(CV)
            if CV.startswith('test'):
                logger.log('Do not override testing switch')
                continue
            try:
                exec('state.' + CV) in globals(), locals()
            except:
                exec('state.' + CV.split('=')[0] + "='" + CV.split('=')[1] +
                     "'") in globals(), locals()
    else:
        # Save the current configuration
        # Useful for logs/experiments
        logger.log('Saving config')
        with open(config_filename, 'w') as f:
            f.write(str(state))

    logger.log(state)

    ####################################################
    # Load the data, train = train+valid, and sequence #
    ####################################################
    artificial = False
    if state.dataset == 'MNIST_1' or state.dataset == 'MNIST_2' or state.dataset == 'MNIST_3':
        (train_X, train_Y), (valid_X, valid_Y), (test_X, test_Y) = data.load_mnist(state.data_path)
        train_X = numpy.concatenate((train_X, valid_X))
        train_Y = numpy.concatenate((train_Y, valid_Y))
        artificial = True
        try:
            dataset = int(state.dataset.split('_')[1])
        except:
            logger.log(
                "ERROR: artificial dataset number not recognized. Input was " +
                str(state.dataset))
            raise AssertionError(
                "artificial dataset number not recognized. Input was " +
                str(state.dataset))
    else:
        logger.log("ERROR: dataset not recognized.")
        raise AssertionError("dataset not recognized.")

    train_X = theano.shared(train_X)
    train_Y = theano.shared(train_Y)
    valid_X = theano.shared(valid_X)
    valid_Y = theano.shared(valid_Y)
    test_X = theano.shared(test_X)
    test_Y = theano.shared(test_Y)

    if artificial:
        logger.log('Sequencing MNIST data...')
        logger.log(['train set size:', len(train_Y.eval())])
        logger.log(['valid set size:', len(valid_Y.eval())])
        logger.log(['test set size:', len(test_Y.eval())])
        data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X,
                                 test_Y, dataset, rng)
        logger.log(['train set size:', len(train_Y.eval())])
        logger.log(['valid set size:', len(valid_Y.eval())])
        logger.log(['test set size:', len(test_Y.eval())])
        logger.log('Sequencing done.\n')

    N_input = train_X.eval().shape[1]
    root_N_input = numpy.sqrt(N_input)

    # Network and training specifications
    layers = state.layers  # number hidden layers
    walkbacks = state.walkbacks  # number of walkbacks
    layer_sizes = [N_input] + [state.hidden_size] * layers  # layer sizes, from h0 to hK (h0 is the visible layer)
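    # e.g. with 28x28 MNIST inputs, layers=3, and hidden_size=1500 (hypothetical values),
    # layer_sizes would be [784, 1500, 1500, 1500]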

    learning_rate = theano.shared(cast32(state.learning_rate))  # learning rate
    annealing = cast32(state.annealing)  # exponential annealing coefficient
    momentum = theano.shared(cast32(state.momentum))  # momentum term

    ##############
    # PARAMETERS #
    ##############
    #gsn
    weights_list = [
        get_shared_weights(layer_sizes[i],
                           layer_sizes[i + 1],
                           name="W_{0!s}_{1!s}".format(i, i + 1))
        for i in range(layers)
    ]  # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
    bias_list = [
        get_shared_bias(layer_sizes[i], name='b_' + str(i))
        for i in range(layers + 1)
    ]  # initialize each layer to 0's.

    #recurrent
    recurrent_to_gsn_weights_list = [
        get_shared_weights(state.recurrent_hidden_size,
                           layer_sizes[layer],
                           name="W_u_h{0!s}".format(layer))
        for layer in range(layers + 1) if (layer % 2) != 0
    ]
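    # The recurrent hidden state u projects only into the odd GSN layers (h1, h3, ...); the even
    # layers are computed from their neighbors during the GSN's own odd/even update passes.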
    W_u_u = get_shared_weights(state.recurrent_hidden_size,
                               state.recurrent_hidden_size,
                               name="W_u_u")
    W_x_u = get_shared_weights(N_input,
                               state.recurrent_hidden_size,
                               name="W_x_u")
    recurrent_bias = get_shared_bias(state.recurrent_hidden_size, name='b_u')

    #lists for use with gradients
    gsn_params = weights_list + bias_list
    u_params = [W_u_u, W_x_u, recurrent_bias]
    params = gsn_params + recurrent_to_gsn_weights_list + u_params

    ###########################################################
    # load initial parameters of gsn to speed up my debugging #
    ###########################################################
    params_to_load = 'gsn_params.pkl'
    initialized_gsn = False
    if os.path.isfile(params_to_load):
        logger.log("\nLoading existing GSN parameters\n")
        loaded_params = cPickle.load(open(params_to_load, 'rb'))
        [
            p.set_value(lp.get_value(borrow=False))
            for lp, p in zip(loaded_params[:len(weights_list)], weights_list)
        ]
        [
            p.set_value(lp.get_value(borrow=False))
            for lp, p in zip(loaded_params[len(weights_list):], bias_list)
        ]
        initialized_gsn = True

    ############################
    # Theano variables and RNG #
    ############################
    MRG = RNG_MRG.MRG_RandomStreams(1)
    X = T.fmatrix('X')  #single (batch) for training gsn
    Xs = T.fmatrix(name="Xs")  #sequence for training rnn-gsn

    ########################
    # ACTIVATION FUNCTIONS #
    ########################
    # hidden activation
    if state.hidden_act == 'sigmoid':
        logger.log('Using sigmoid activation for hiddens')
        hidden_activation = T.nnet.sigmoid
    elif state.hidden_act == 'rectifier':
        logger.log('Using rectifier activation for hiddens')
        hidden_activation = lambda x: T.maximum(cast32(0), x)
    elif state.hidden_act == 'tanh':
        logger.log('Using hyperbolic tangent activation for hiddens')
        hidden_activation = lambda x: T.tanh(x)
    else:
        logger.log(
            "ERROR: Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid"
            .format(state.hidden_act))
        raise AssertionError(
            "Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid"
            .format(state.hidden_act))

    # visible activation
    if state.visible_act == 'sigmoid':
        logger.log('Using sigmoid activation for visible layer')
        visible_activation = T.nnet.sigmoid
    elif state.visible_act == 'softmax':
        logger.log('Using softmax activation for visible layer')
        visible_activation = T.nnet.softmax
    else:
        logger.log(
            "ERROR: Did not recognize visible activation {0!s}, please use sigmoid or softmax"
            .format(state.visible_act))
        raise AssertionError(
            "Did not recognize visible activation {0!s}, please use sigmoid or softmax"
            .format(state.visible_act))

    # recurrent activation
    if state.recurrent_hidden_act == 'sigmoid':
        logger.log('Using sigmoid activation for recurrent hiddens')
        recurrent_hidden_activation = T.nnet.sigmoid
    elif state.recurrent_hidden_act == 'rectifier':
        logger.log('Using rectifier activation for recurrent hiddens')
        recurrent_hidden_activation = lambda x: T.maximum(cast32(0), x)
    elif state.recurrent_hidden_act == 'tanh':
        logger.log('Using hyperbolic tangent activation for recurrent hiddens')
        recurrent_hidden_activation = lambda x: T.tanh(x)
    else:
        logger.log(
            "ERROR: Did not recognize recurrent hidden activation {0!s}, please use tanh, rectifier, or sigmoid"
            .format(state.recurrent_hidden_act))
        raise AssertionError(
            "Did not recognize recurrent hidden activation {0!s}, please use tanh, rectifier, or sigmoid"
            .format(state.recurrent_hidden_act))

    logger.log("\n")

    ####################
    #  COST FUNCTIONS  #
    ####################
    if state.cost_funct == 'binary_crossentropy':
        logger.log('Using binary cross-entropy cost!')
        cost_function = lambda x, y: T.mean(T.nnet.binary_crossentropy(x, y))
    elif state.cost_funct == 'square':
        logger.log("Using square error cost!")
        #cost_function = lambda x,y: T.log(T.mean(T.sqr(x-y)))
        cost_function = lambda x, y: T.log(T.sum(T.pow((x - y), 2)))
    else:
        logger.log(
            "ERROR: Did not recognize cost function {0!s}, please use binary_crossentropy or square"
            .format(state.cost_funct))
        raise AssertionError(
            "Did not recognize cost function {0!s}, please use binary_crossentropy or square"
            .format(state.cost_funct))

    logger.log("\n")

    ################################################
    #  COMPUTATIONAL GRAPH HELPER METHODS FOR GSN  #
    ################################################
    def update_layers(hiddens, p_X_chain, noisy=True):
        logger.log('odd layer updates')
        update_odd_layers(hiddens, noisy)
        logger.log('even layer updates')
        update_even_layers(hiddens, p_X_chain, noisy)
        logger.log('done full update.\n')

    def update_layers_reverse(hiddens, p_X_chain, noisy=True):
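        # Reverse-order update (even layers first, then odd): used by the RNN-GSN step, where
        # the odd hidden layers are seeded from the recurrent state u, so the visible
        # reconstruction p(X|...) is computed from them before the odd layers are refreshed.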
        logger.log('even layer updates')
        update_even_layers(hiddens, p_X_chain, noisy)
        logger.log('odd layer updates')
        update_odd_layers(hiddens, noisy)
        logger.log('done full update.\n')

    # Odd layer update function
    # just a loop over the odd layers
    def update_odd_layers(hiddens, noisy):
        for i in range(1, len(hiddens), 2):
            logger.log(['updating layer', i])
            simple_update_layer(hiddens, None, i, add_noise=noisy)

    # Even layer update
    # p_X_chain is given to append the p(X|...) at each full update (one update = odd update + even update)
    def update_even_layers(hiddens, p_X_chain, noisy):
        for i in range(0, len(hiddens), 2):
            logger.log(['updating layer', i])
            simple_update_layer(hiddens, p_X_chain, i, add_noise=noisy)

    # The layer update function
    # hiddens   :   list containing the symbolic theano variables [visible, hidden1, hidden2, ...]
    #               layer_update will modify this list inplace
    # p_X_chain :   list containing the successive p(X|...) at each update
    #               update_layer will append to this list
    # add_noise     : pre and post activation gaussian noise

    def simple_update_layer(hiddens, p_X_chain, i, add_noise=True):
        # Compute the dot product, whatever layer
        # If the visible layer X
        if i == 0:
            logger.log('using ' + str(weights_list[i]) + '.T')
            hiddens[i] = T.dot(hiddens[i + 1],
                               weights_list[i].T) + bias_list[i]
        # If the top layer
        elif i == len(hiddens) - 1:
            logger.log(['using', weights_list[i - 1]])
            hiddens[i] = T.dot(hiddens[i - 1],
                               weights_list[i - 1]) + bias_list[i]
        # Otherwise in-between layers
        else:
            logger.log("using {0!s} and {1!s}.T".format(
                weights_list[i - 1], weights_list[i]))
            # next layer        :   hiddens[i+1], assigned weights : W_i
            # previous layer    :   hiddens[i-1], assigned weights : W_(i-1)
            hiddens[i] = T.dot(hiddens[i + 1], weights_list[i].T) + T.dot(
                hiddens[i - 1], weights_list[i - 1]) + bias_list[i]

        # optionally keep the first hidden layer noiseless
        if i == 1 and state.noiseless_h1:
            logger.log('>>NO noise in first hidden layer')
            add_noise = False

        # pre activation noise
        if i != 0 and add_noise:
            logger.log(['Adding pre-activation gaussian noise for layer', i])
            hiddens[i] = add_gaussian_noise(hiddens[i],
                                            state.hidden_add_noise_sigma)

        # ACTIVATION!
        if i == 0:
            logger.log('{} activation for visible layer'.format(
                state.visible_act))
            hiddens[i] = visible_activation(hiddens[i])
        else:
            logger.log([
                'Hidden units {} activation for layer'.format(
                    state.hidden_act), i
            ])
            hiddens[i] = hidden_activation(hiddens[i])

        # post activation noise
        # post-activation noise: together with the pre-activation noise above, this doubles the amount of noise injected between each hidden activation
        if i != 0 and add_noise:
            logger.log(['Adding post-activation gaussian noise for layer', i])
            hiddens[i] = add_gaussian_noise(hiddens[i],
                                            state.hidden_add_noise_sigma)

        # build the reconstruction chain if updating the visible layer X
        if i == 0:
            # if input layer -> append p(X|...)
            p_X_chain.append(hiddens[i])

            # sample from p(X|...) - the sampling distribution must match the input type, e.g. binomial for binary MNIST; real-valued inputs should use gaussian sampling
            if state.input_sampling:
                logger.log('Sampling from input')
                sampled = MRG.binomial(p=hiddens[i],
                                       size=hiddens[i].shape,
                                       dtype='float32')
            else:
                logger.log('>>NO input sampling')
                sampled = hiddens[i]
            # add noise
            sampled = salt_and_pepper(sampled, state.input_salt_and_pepper)

            # set input layer
            hiddens[i] = sampled

    ##############################################
    #    Build the training graph for the GSN    #
    ##############################################
    # the loop step for the rnn-gsn, return the sample and the costs
    def create_gsn_reverse(x_t, u_tm1, noiseflag=True):
        chain = []
        # init hiddens from the u
        hiddens_t = [T.zeros_like(x_t)]
        for layer, w in enumerate(weights_list):
            layer = layer + 1
            # if this is an even layer, just append zeros
            if layer % 2 == 0:
                hiddens_t.append(T.zeros_like(T.dot(hiddens_t[-1], w)))
            # if it is an odd layer, use the rnn to determine the layer
            else:
                hiddens_t.append(
                    hidden_activation(
                        T.dot(u_tm1, recurrent_to_gsn_weights_list[layer // 2])
                        + bias_list[layer]))

        for i in range(walkbacks):
            logger.log("Reverse Walkback {!s}/{!s} for RNN-GSN".format(
                i + 1, walkbacks))
            update_layers_reverse(hiddens_t, chain, noiseflag)

        x_sample = chain[-1]
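        # accumulate the reconstruction cost over every walkback step; report only the last one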
        costs = [cost_function(rX, x_t) for rX in chain]
        show_cost = costs[-1]
        cost = T.sum(costs)

        return x_sample, cost, show_cost

    # the GSN graph for the rnn-gsn
    def build_gsn_given_u(xs, u, noiseflag=True):
        logger.log("Creating recurrent gsn step scan.\n")
        u0 = T.zeros((1, state.recurrent_hidden_size))
        if u is None:
            u = u0
        else:
            u = T.concatenate(
                [u0, u]
            )  #add the initial u condition to the list of u's created from the recurrent scan op.
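        # u now starts with u0, so the scan below pairs x_t with u_tm1: each frame is
        # reconstructed from the recurrent state computed up to the previous frame.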
        (samples, costs, show_costs), updates = theano.scan(
            lambda x_t, u_tm1: create_gsn_reverse(x_t, u_tm1, noiseflag),
            sequences=[xs, u])
        cost = T.sum(costs)
        show_cost = T.mean(show_costs)
        last_sample = samples[-1]

        return last_sample, cost, show_cost, updates

    def build_gsn_given_u0(x, u0, noiseflag=True):
        x_sample, _, _ = create_gsn_reverse(x, u0, noiseflag)
        return x_sample

    # the GSN graph for initial GSN training
    def build_gsn_graph(x, noiseflag):
        p_X_chain = []
        if noiseflag:
            X_init = salt_and_pepper(x, state.input_salt_and_pepper)
        else:
            X_init = x
        # init hiddens with zeros
        hiddens = [X_init]
        for w in weights_list:
            hiddens.append(T.zeros_like(T.dot(hiddens[-1], w)))
        # The layer update scheme
        logger.log(["Building the gsn graph :", walkbacks, "updates"])
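        # each walkback is one full update (odd hiddens, then even layers) and appends one
        # reconstruction p(X|...) to the chain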
        for i in range(walkbacks):
            logger.log("GSN Walkback {!s}/{!s}".format(i + 1, walkbacks))
            update_layers(hiddens, p_X_chain, noisy=noiseflag)

        return p_X_chain

    '''Build the actual gsn training graph'''
    p_X_chain_gsn = build_gsn_graph(X, noiseflag=True)

    ##############################################
    #  Build the training graph for the RNN-GSN  #
    ##############################################
    # Deterministic recurrence: given x_t and the previous state u_tm1, compute the new state u_t (generation without an input x_t is not implemented here)
    def recurrent_step(x_t, u_tm1):
        ua_t = T.dot(x_t, W_x_u) + T.dot(u_tm1, W_u_u) + recurrent_bias
        u_t = recurrent_hidden_activation(ua_t)
        return ua_t, u_t

    logger.log("\nCreating recurrent step scan.")
    # For training, the deterministic recurrence is used to compute all the
    # {h_t, 1 <= t <= T} given Xs. Conditional GSNs can then be trained
    # in batches using those parameters.
    u0 = T.zeros((state.recurrent_hidden_size,))  # initial value for the RNN hidden units
    (_, u), updates_recurrent = theano.scan(
        lambda x_t, u_tm1: recurrent_step(x_t, u_tm1),
        sequences=Xs,
        outputs_info=[None, u0])
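    # u holds the deterministic recurrent states u_1..u_T for the whole sequence; they are
    # passed to build_gsn_given_u below to seed each conditional GSN step.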

    _, cost, show_cost, updates_gsn = build_gsn_given_u(Xs, u, noiseflag=True)

    updates_recurrent.update(updates_gsn)

    # copy so the SGD parameter updates added to updates_train below do not leak into the cost-only function
    updates_train = OrderedDict(updates_recurrent)
    updates_cost = OrderedDict(updates_recurrent)

    ################################################
    #  Build the checkpoint graph for the RNN-GSN  #
    ################################################
    # Used to generate the next predicted output given all previous inputs - starting with nothing
    # When there is no X history
    x_sample_R_init = build_gsn_given_u0(X, u0, noiseflag=False)
    # When there is some number of Xs history
    x_sample_R, _, _, updates_gsn_R = build_gsn_given_u(Xs, u, noiseflag=False)

    #############
    #   COSTS   #
    #############
    logger.log("")
    logger.log('Cost w.r.t p(X|...) at every step in the graph')

    gsn_costs = [cost_function(rX, X) for rX in p_X_chain_gsn]
    gsn_show_cost = gsn_costs[-1]
    gsn_cost = numpy.sum(gsn_costs)

    ###################################
    # GRADIENTS AND FUNCTIONS FOR GSN #
    ###################################
    logger.log(["params:", params])

    logger.log("creating functions...")
    start_functions_time = time.time()

    gradient_gsn = T.grad(gsn_cost, gsn_params)
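    # SGD with momentum: keep an exponential moving average of the gradient in a shared
    # buffer (buf <- momentum*buf + (1-momentum)*grad) and step each parameter by -learning_rate*buf.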
    gradient_buffer_gsn = [
        theano.shared(numpy.zeros(param.get_value().shape, dtype='float32'))
        for param in gsn_params
    ]

    m_gradient_gsn = [
        momentum * gb + (cast32(1) - momentum) * g
        for (gb, g) in zip(gradient_buffer_gsn, gradient_gsn)
    ]
    param_updates_gsn = [(param, param - learning_rate * mg)
                         for (param, mg) in zip(gsn_params, m_gradient_gsn)]
    gradient_buffer_updates_gsn = zip(gradient_buffer_gsn, m_gradient_gsn)

    grad_updates_gsn = OrderedDict(param_updates_gsn +
                                   gradient_buffer_updates_gsn)

    f_cost_gsn = theano.function(inputs=[X],
                                 outputs=gsn_show_cost,
                                 on_unused_input='warn')

    f_learn_gsn = theano.function(inputs=[X],
                                  updates=grad_updates_gsn,
                                  outputs=gsn_show_cost,
                                  on_unused_input='warn')

    #######################################
    # GRADIENTS AND FUNCTIONS FOR RNN-GSN #
    #######################################
    # if we are not using Hessian-free training create the normal sgd functions
    if state.hf == 0:
        gradient = T.grad(cost, params)
        gradient_buffer = [
            theano.shared(numpy.zeros(param.get_value().shape,
                                      dtype='float32')) for param in params
        ]

        m_gradient = [
            momentum * gb + (cast32(1) - momentum) * g
            for (gb, g) in zip(gradient_buffer, gradient)
        ]
        param_updates = [(param, param - learning_rate * mg)
                         for (param, mg) in zip(params, m_gradient)]
        gradient_buffer_updates = zip(gradient_buffer, m_gradient)

        updates = OrderedDict(param_updates + gradient_buffer_updates)
        updates_train.update(updates)

        f_learn = theano.function(inputs=[Xs],
                                  updates=updates_train,
                                  outputs=show_cost,
                                  on_unused_input='warn')

        f_cost = theano.function(inputs=[Xs],
                                 updates=updates_cost,
                                 outputs=show_cost,
                                 on_unused_input='warn')

    logger.log("Training/cost functions done.")
    compilation_time = time.time() - start_functions_time
    # Show the compile time with appropriate easy-to-read units.
    if compilation_time < 60:
        logger.log(["Compilation took", compilation_time, "seconds.\n\n"])
    elif compilation_time < 3600:
        logger.log(["Compilation took", compilation_time / 60, "minutes.\n\n"])
    else:
        logger.log(["Compilation took", compilation_time / 3600, "hours.\n\n"])

    ############################################################################################
    # Denoise some numbers : show number, noisy number, predicted number, reconstructed number #
    ############################################################################################
    # Recompile the graph without noise for reconstruction function
    # The layer update scheme
    logger.log(
        "Creating graph for noisy reconstruction function at checkpoints during training."
    )
    f_recon_init = theano.function(inputs=[X],
                                   outputs=x_sample_R_init,
                                   on_unused_input='warn')
    f_recon = theano.function(inputs=[Xs],
                              outputs=x_sample_R,
                              updates=updates_gsn_R)

    # Now do the same but for the GSN in the initial run
    p_X_chain_R = build_gsn_graph(X, noiseflag=False)
    f_recon_gsn = theano.function(inputs=[X], outputs=p_X_chain_R[-1])

    logger.log("Done compiling all functions.")
    compilation_time = time.time() - start_functions_time
    # Show the total time with appropriate easy-to-read units.
    if compilation_time < 60:
        logger.log(["Total time:", compilation_time, "seconds.\n\n"])
    elif compilation_time < 3600:
        logger.log(["Total time:", compilation_time / 60, "minutes.\n\n"])
    else:
        logger.log(["Total time:", compilation_time / 3600, "hours.\n\n"])

    ############
    # Sampling #
    ############
    # a function to add salt and pepper noise
    f_noise = theano.function(inputs=[X],
                              outputs=salt_and_pepper(
                                  X, state.input_salt_and_pepper))
    # the input to the sampling function
    X_sample = T.fmatrix("X_sampling")
    network_state_input = [X_sample] + [
        T.fmatrix("H_sampling_" + str(i + 1)) for i in range(layers)
    ]

    # "Output" state of the network (noisy)
    # initialized with input, then we apply updates

    network_state_output = [X_sample] + network_state_input[1:]

    visible_pX_chain = []

    # ONE update
    logger.log("Performing one walkback in network state sampling.")
    update_layers(network_state_output, visible_pX_chain, noisy=True)

    if layers == 1:
        f_sample_simple = theano.function(inputs=[X_sample],
                                          outputs=visible_pX_chain[-1])

    # on_unused_input='warn' because the first odd hidden layers are never read as inputs -
    # they are recomputed directly from the even layers
    f_sample2 = theano.function(inputs=network_state_input,
                                outputs=network_state_output +
                                visible_pX_chain,
                                on_unused_input='warn')

    def sample_some_numbers_single_layer():
        x0 = test_X.get_value()[:1]
        samples = [x0]
        x = f_noise(x0)
        for i in range(399):
            x = f_sample_simple(x)
            samples.append(x)
            x = numpy.random.binomial(n=1, p=x, size=x.shape).astype('float32')
            x = f_noise(x)
        return numpy.vstack(samples)

    def sampling_wrapper(NSI):
        # * is the "splat" operator: It takes a list as input, and expands it into actual positional arguments in the function call.
        out = f_sample2(*NSI)
        NSO = out[:len(network_state_output)]
        vis_pX_chain = out[len(network_state_output):]
        return NSO, vis_pX_chain

    def sample_some_numbers(N=400):
        # The network's initial state
        init_vis = test_X.get_value()[:1]

        noisy_init_vis = f_noise(init_vis)

        network_state = [[noisy_init_vis] + [
            numpy.zeros((1, len(b.get_value())), dtype='float32')
            for b in bias_list[1:]
        ]]

        visible_chain = [init_vis]

        noisy_h0_chain = [noisy_init_vis]

        for i in range(N - 1):

            # feed the last state into the network, compute new state, and obtain visible units expectation chain
            net_state_out, vis_pX_chain = sampling_wrapper(network_state[-1])

            # append to the visible chain
            visible_chain += vis_pX_chain

            # append state output to the network state chain
            network_state.append(net_state_out)

            noisy_h0_chain.append(net_state_out[0])

        return numpy.vstack(visible_chain), numpy.vstack(noisy_h0_chain)

    def plot_samples(epoch_number, leading_text):
        to_sample = time.time()
        if layers == 1:
            # one layer model
            V = sample_some_numbers_single_layer()
        else:
            V, H0 = sample_some_numbers()
        img_samples = PIL.Image.fromarray(
            tile_raster_images(V, (root_N_input, root_N_input), (20, 20)))

        fname = outdir + leading_text + 'samples_epoch_' + str(
            epoch_number) + '.png'
        img_samples.save(fname)
        logger.log('Took ' + str(time.time() - to_sample) +
                   ' to sample 400 numbers')

    #############################
    # Save the model parameters #
    #############################
    def save_params_to_file(name, n, gsn_params):
        print 'saving parameters...'
        save_path = outdir + name + '_params_epoch_' + str(n) + '.pkl'
        f = open(save_path, 'wb')
        try:
            cPickle.dump(gsn_params, f, protocol=cPickle.HIGHEST_PROTOCOL)
        finally:
            f.close()

    def save_params(params):
        values = [param.get_value(borrow=True) for param in params]
        return values

    def restore_params(params, values):
        for i in range(len(params)):
            params[i].set_value(values[i])

    ################
    # GSN TRAINING #
    ################
    def train_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        logger.log("\n-----------TRAINING GSN------------\n")

        # TRAINING
        n_epoch = state.n_epoch
        batch_size = state.gsn_batch_size
        STOP = False
        counter = 0
        learning_rate.set_value(cast32(state.learning_rate))  # learning rate
        times = []
        best_cost = float('inf')
        best_params = None
        patience = 0

        logger.log(['train X size:', str(train_X.shape.eval())])
        logger.log(['valid X size:', str(valid_X.shape.eval())])
        logger.log(['test X size:', str(test_X.shape.eval())])

        if state.vis_init:
            # clip the mean pixel values away from 0 and 1 so the logit stays finite
            bias_list[0].set_value(
                logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))

        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            logger.log('Testing : skip training')
            STOP = True

        while not STOP:
            counter += 1
            t = time.time()
            logger.append([counter, '\t'])

            #shuffle the data
            data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y,
                                     test_X, test_Y, dataset, rng)

            #train
            train_costs = []
            for i in xrange(len(train_X.get_value(borrow=True)) / batch_size):
                x = train_X.get_value()[i * batch_size:(i + 1) * batch_size]
                cost = f_learn_gsn(x)
                train_costs.append([cost])
            train_costs = numpy.mean(train_costs)
            # record it
            logger.append(['Train:', trunc(train_costs), '\t'])
            with open(gsn_train_convergence, 'a') as f:
                f.write("{0!s},".format(train_costs))
                f.write("\n")

            #valid
            valid_costs = []
            for i in xrange(len(valid_X.get_value(borrow=True)) / batch_size):
                x = valid_X.get_value()[i * batch_size:(i + 1) * batch_size]
                cost = f_cost_gsn(x)
                valid_costs.append([cost])
            valid_costs = numpy.mean(valid_costs)
            # record it
            logger.append(['Valid:', trunc(valid_costs), '\t'])
            with open(gsn_valid_convergence, 'a') as f:
                f.write("{0!s},".format(valid_costs))
                f.write("\n")

            #test
            test_costs = []
            for i in xrange(len(test_X.get_value(borrow=True)) / batch_size):
                x = test_X.get_value()[i * batch_size:(i + 1) * batch_size]
                cost = f_cost_gsn(x)
                test_costs.append([cost])
            test_costs = numpy.mean(test_costs)
            # record it
            logger.append(['Test:', trunc(test_costs), '\t'])
            with open(gsn_test_convergence, 'a') as f:
                f.write("{0!s},".format(test_costs))
                f.write("\n")

            #check for early stopping
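            # an epoch counts as an improvement only if the validation cost drops below
            # best_cost * early_stop_threshold; otherwise patience is consumed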
            cost = numpy.sum(valid_costs)
            if cost < best_cost * state.early_stop_threshold:
                patience = 0
                best_cost = cost
                # save the parameters that made it the best
                best_params = save_params(gsn_params)
            else:
                patience += 1

            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                if best_params is not None:
                    restore_params(gsn_params, best_params)
                save_params_to_file('gsn', counter, gsn_params)

            timing = time.time() - t
            times.append(timing)

            logger.append(['time:', trunc(timing)])

            logger.log([
                'remaining:',
                trunc((n_epoch - counter) * numpy.mean(times) / 60 / 60), 'hrs'
            ])

            if (counter % state.save_frequency) == 0 or STOP is True:
                n_examples = 100
                random_idx = numpy.array(
                    R.sample(range(len(test_X.get_value(borrow=True))),
                             n_examples))
                numbers = test_X.get_value(borrow=True)[random_idx]
                noisy_numbers = f_noise(
                    test_X.get_value(borrow=True)[random_idx])
                reconstructed = f_recon_gsn(noisy_numbers)
                # Concatenate stuff
                stacked = numpy.vstack([
                    numpy.vstack([
                        numbers[i * 10:(i + 1) * 10],
                        noisy_numbers[i * 10:(i + 1) * 10],
                        reconstructed[i * 10:(i + 1) * 10]
                    ]) for i in range(10)
                ])
                number_reconstruction = PIL.Image.fromarray(
                    tile_raster_images(stacked, (root_N_input, root_N_input),
                                       (10, 30)))

                number_reconstruction.save(outdir +
                                           'gsn_number_reconstruction_epoch_' +
                                           str(counter) + '.png')

                #sample_numbers(counter, 'seven')
                plot_samples(counter, 'gsn')

                #save gsn_params
                save_params_to_file('gsn', counter, gsn_params)

            # ANNEAL!
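            # exponential decay: multiply the learning rate by the annealing coefficient once per epoch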
            new_lr = learning_rate.get_value() * annealing
            learning_rate.set_value(new_lr)

        # 10k samples
        print 'Generating 10,000 samples'
        samples, _ = sample_some_numbers(N=10000)
        f_samples = outdir + 'samples.npy'
        numpy.save(f_samples, samples)
        print 'saved digits'

    def train_RNN_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        # If we are using Hessian-free training
        if state.hf == 1:
            pass
            # gradient_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=5000)
            # cg_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=1000)
            # valid_dataset = hf_sequence_dataset([valid_X.get_value()], batch_size=None, number_batches=1000)
            #
            # s = x_samples
            # costs = [cost, show_cost]
            # hf_optimizer(params, [Xs], s, costs, u, ua).train(gradient_dataset, cg_dataset, initial_lambda=1.0, preconditioner=True, validation=valid_dataset)

        # If we are using SGD training
        else:
            logger.log("\n-----------TRAINING RNN-GSN------------\n")
            # TRAINING
            n_epoch = state.n_epoch
            batch_size = state.batch_size
            STOP = False
            counter = 0
            learning_rate.set_value(cast32(
                state.learning_rate))  # learning rate
            times = []
            best_cost = float('inf')
            best_params = None
            patience = 0

            logger.log(['train X size:', str(train_X.shape.eval())])
            logger.log(['valid X size:', str(valid_X.shape.eval())])
            logger.log(['test X size:', str(test_X.shape.eval())])

            if state.vis_init:
                # clip the mean pixel values away from 0 and 1 so the logit stays finite
                bias_list[0].set_value(
                    logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))

            if state.test_model:
                # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
                logger.log('Testing : skip training')
                STOP = True

            while not STOP:
                counter += 1
                t = time.time()
                logger.append([counter, '\t'])

                #shuffle the data
                data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y,
                                         test_X, test_Y, dataset, rng)

                #train
                train_costs = []
                for i in xrange(
                        len(train_X.get_value(borrow=True)) / batch_size):
                    xs = train_X.get_value(
                        borrow=True)[i * batch_size:(i + 1) * batch_size]
                    cost = f_learn(xs)
                    train_costs.append([cost])
                train_costs = numpy.mean(train_costs)
                # record it
                logger.append(['Train:', trunc(train_costs), '\t'])
                with open(train_convergence, 'a') as f:
                    f.write("{0!s},".format(train_costs))
                    f.write("\n")

                #valid
                valid_costs = []
                for i in xrange(
                        len(valid_X.get_value(borrow=True)) / batch_size):
                    xs = valid_X.get_value(
                        borrow=True)[i * batch_size:(i + 1) * batch_size]
                    cost = f_cost(xs)
                    valid_costs.append([cost])
                valid_costs = numpy.mean(valid_costs)
                # record it
                logger.append(['Valid:', trunc(valid_costs), '\t'])
                with open(valid_convergence, 'a') as f:
                    f.write("{0!s},".format(valid_costs))
                    f.write("\n")

                #test
                test_costs = []
                for i in xrange(
                        len(test_X.get_value(borrow=True)) / batch_size):
                    xs = test_X.get_value(borrow=True)[i * batch_size:(i + 1) *
                                                       batch_size]
                    cost = f_cost(xs)
                    test_costs.append([cost])
                test_costs = numpy.mean(test_costs)
                # record it
                logger.append(['Test:', trunc(test_costs), '\t'])
                with open(test_convergence, 'a') as f:
                    f.write("{0!s},".format(test_costs))
                    f.write("\n")

                #check for early stopping
                cost = numpy.sum(valid_costs)
                if cost < best_cost * state.early_stop_threshold:
                    patience = 0
                    best_cost = cost
                    # save the parameters that made it the best
                    best_params = save_params(params)
                else:
                    patience += 1

                if counter >= n_epoch or patience >= state.early_stop_length:
                    STOP = True
                    if best_params is not None:
                        restore_params(params, best_params)
                    save_params_to_file('all', counter, params)

                timing = time.time() - t
                times.append(timing)

                logger.append(['time:', trunc(timing)])

                logger.log([
                    'remaining:',
                    trunc((n_epoch - counter) * numpy.mean(times) / 60 / 60),
                    'hrs'
                ])

                if (counter % state.save_frequency) == 0 or STOP is True:
                    n_examples = 100
                    nums = test_X.get_value(borrow=True)[range(n_examples)]
                    noisy_nums = f_noise(
                        test_X.get_value(borrow=True)[range(n_examples)])
                    reconstructions = []
                    for i in xrange(0, len(noisy_nums)):
                        if i == 0:
                            recon = f_recon_init(noisy_nums[:i + 1])
                        else:
                            recon = f_recon(
                                noisy_nums[max(0, (i + 1) - batch_size):i + 1])
                        reconstructions.append(recon)
                    reconstructed = numpy.array(reconstructions)

                    # Concatenate stuff
                    stacked = numpy.vstack([
                        numpy.vstack([
                            nums[i * 10:(i + 1) * 10],
                            noisy_nums[i * 10:(i + 1) * 10],
                            reconstructed[i * 10:(i + 1) * 10]
                        ]) for i in range(10)
                    ])
                    number_reconstruction = PIL.Image.fromarray(
                        tile_raster_images(stacked,
                                           (root_N_input, root_N_input),
                                           (10, 30)))

                    number_reconstruction.save(
                        outdir + 'rnngsn_number_reconstruction_epoch_' +
                        str(counter) + '.png')

                    #sample_numbers(counter, 'seven')
                    plot_samples(counter, 'rnngsn')

                    #save params
                    save_params_to_file('all', counter, params)

                # ANNEAL!
                new_lr = learning_rate.get_value() * annealing
                learning_rate.set_value(new_lr)

            # 10k samples
            print 'Generating 10,000 samples'
            samples, _ = sample_some_numbers(N=10000)
            f_samples = outdir + 'samples.npy'
            numpy.save(f_samples, samples)
            print 'saved digits'

    #####################
    # STORY 2 ALGORITHM #
    #####################
    # train the GSN parameters first to get a good baseline (if not loaded from parameter .pkl file)
    if initialized_gsn is False:
        train_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
    # train the entire RNN-GSN
    train_RNN_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y)


def experiment(state, outdir_base='./'):
    rng.seed(1)  #seed the numpy random generator
    R.seed(1)  #seed the other random generator (for reconstruction function indices)
    # Initialize output directory and files
    data.mkdir_p(outdir_base)
    outdir = outdir_base + "/" + state.dataset + "/"
    data.mkdir_p(outdir)
    logger = Logger(outdir)
    logger.log("----------MODEL 2, {0!s}-----------\n".format(state.dataset))
    if state.initialize_gsn:
        gsn_train_convergence = outdir + "gsn_train_convergence.csv"
        gsn_valid_convergence = outdir + "gsn_valid_convergence.csv"
        gsn_test_convergence = outdir + "gsn_test_convergence.csv"
    train_convergence = outdir + "train_convergence.csv"
    valid_convergence = outdir + "valid_convergence.csv"
    test_convergence = outdir + "test_convergence.csv"
    if state.initialize_gsn:
        init_empty_file(gsn_train_convergence)
        init_empty_file(gsn_valid_convergence)
        init_empty_file(gsn_test_convergence)
    init_empty_file(train_convergence)
    init_empty_file(valid_convergence)
    init_empty_file(test_convergence)

    #load parameters from config file if this is a test
    config_filename = outdir + 'config'
    if state.test_model and 'config' in os.listdir(outdir):
        config_vals = load_from_config(config_filename)
        for CV in config_vals:
            logger.log(CV)
            if CV.startswith('test'):
                logger.log('Do not override testing switch')
                continue
            try:
                exec('state.' + CV) in globals(), locals()
            except:
                exec('state.' + CV.split('=')[0] + "='" + CV.split('=')[1] +
                     "'") in globals(), locals()
    else:
        # Save the current configuration
        # Useful for logs/experiments
        logger.log('Saving config')
        with open(config_filename, 'w') as f:
            f.write(str(state))

    logger.log(state)

    ####################################################
    # Load the data, train = train+valid, and sequence #
    ####################################################
    artificial = False
    if state.dataset == 'MNIST_1' or state.dataset == 'MNIST_2' or state.dataset == 'MNIST_3':
        (train_X,
         train_Y), (valid_X,
                    valid_Y), (test_X,
                               test_Y) = data.load_mnist(state.data_path)
        train_X = numpy.concatenate((train_X, valid_X))
        train_Y = numpy.concatenate((train_Y, valid_Y))
        artificial = True
        try:
            dataset = int(state.dataset.split('_')[1])
        except:
            logger.log(
                "ERROR: artificial dataset number not recognized. Input was " +
                str(state.dataset))
            raise AssertionError(
                "artificial dataset number not recognized. Input was " +
                str(state.dataset))
    else:
        logger.log("ERROR: dataset not recognized.")
        raise AssertionError("dataset not recognized.")

    # transfer the datasets into theano shared variables
    train_X, train_Y = data.shared_dataset((train_X, train_Y), borrow=True)
    valid_X, valid_Y = data.shared_dataset((valid_X, valid_Y), borrow=True)
    test_X, test_Y = data.shared_dataset((test_X, test_Y), borrow=True)

    if artificial:
        logger.log('Sequencing MNIST data...')
        logger.log(['train set size:', len(train_Y.eval())])
        logger.log(['valid set size:', len(valid_Y.eval())])
        logger.log(['test set size:', len(test_Y.eval())])
        data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X,
                                 test_Y, dataset, rng)
        logger.log(['train set size:', len(train_Y.eval())])
        logger.log(['valid set size:', len(valid_Y.eval())])
        logger.log(['test set size:', len(test_Y.eval())])
        logger.log('Sequencing done.\n')

    N_input = train_X.eval().shape[1]
    root_N_input = numpy.sqrt(N_input)
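    # root_N_input is the side length of the (square) input images, used for tiling
    # sample and reconstruction images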

    # Network and training specifications
    layers = state.layers  # number hidden layers
    walkbacks = state.walkbacks  # number of walkbacks
    layer_sizes = [
        N_input
    ] + [state.hidden_size
         ] * layers  # layer sizes, from h0 to hK (h0 is the visible layer)

    learning_rate = theano.shared(cast32(state.learning_rate))  # learning rate
    annealing = cast32(state.annealing)  # exponential annealing coefficient
    momentum = theano.shared(cast32(state.momentum))  # momentum term

    ##############
    # PARAMETERS #
    ##############
    #gsn
    weights_list = [
        get_shared_weights(layer_sizes[i],
                           layer_sizes[i + 1],
                           name="W_{0!s}_{1!s}".format(i, i + 1))
        for i in range(layers)
    ]  # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
    bias_list = [
        get_shared_bias(layer_sizes[i], name='b_' + str(i))
        for i in range(layers + 1)
    ]  # initialize each layer to 0's.

    #recurrent
    recurrent_to_gsn_bias_weights_list = [
        get_shared_weights(state.recurrent_hidden_size,
                           layer_sizes[layer],
                           name="W_u_b{0!s}".format(layer))
        for layer in range(layers + 1)
    ]
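    # one recurrent-to-bias matrix per GSN layer (visible + hiddens): the recurrent state
    # u_tm1 modulates every GSN bias at each time step (see recurrent_step below)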
    W_u_u = get_shared_weights(state.recurrent_hidden_size,
                               state.recurrent_hidden_size,
                               name="W_u_u")
    W_x_u = get_shared_weights(N_input,
                               state.recurrent_hidden_size,
                               name="W_x_u")
    recurrent_bias = get_shared_bias(state.recurrent_hidden_size, name='b_u')

    #lists for use with gradients
    gsn_params = weights_list + bias_list
    u_params = [W_u_u, W_x_u, recurrent_bias]
    params = gsn_params + recurrent_to_gsn_bias_weights_list + u_params

    ###########################################################
    #           load initial parameters of gsn                #
    ###########################################################
    train_gsn_first = False
    if state.initialize_gsn:
        params_to_load = 'gsn_params.pkl'
        if not os.path.isfile(params_to_load):
            train_gsn_first = True
        else:
            logger.log("\nLoading existing GSN parameters\n")
            loaded_params = cPickle.load(open(params_to_load, 'rb'))
            for lp, p in zip(loaded_params[:len(weights_list)], weights_list):
                p.set_value(lp.get_value(borrow=False))
            for lp, p in zip(loaded_params[len(weights_list):], bias_list):
                p.set_value(lp.get_value(borrow=False))

    ############################
    # Theano variables and RNG #
    ############################
    MRG = RNG_MRG.MRG_RandomStreams(1)
    X = T.fmatrix('X')  #single (batch) for training gsn
    Xs = T.fmatrix(name="Xs")  #sequence for training rnn-gsn

    ########################
    # ACTIVATION FUNCTIONS #
    ########################
    # hidden activation
    if state.hidden_act == 'sigmoid':
        logger.log('Using sigmoid activation for hiddens')
        hidden_activation = T.nnet.sigmoid
    elif state.hidden_act == 'rectifier':
        logger.log('Using rectifier activation for hiddens')
        hidden_activation = lambda x: T.maximum(cast32(0), x)
    elif state.hidden_act == 'tanh':
        logger.log('Using hyperbolic tangent activation for hiddens')
        hidden_activation = lambda x: T.tanh(x)
    else:
        logger.log(
            "ERROR: Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid"
            .format(state.hidden_act))
        raise NotImplementedError(
            "Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid"
            .format(state.hidden_act))

    # visible activation
    if state.visible_act == 'sigmoid':
        logger.log('Using sigmoid activation for visible layer')
        visible_activation = T.nnet.sigmoid
    elif state.visible_act == 'softmax':
        logger.log('Using softmax activation for visible layer')
        visible_activation = T.nnet.softmax
    else:
        logger.log(
            "ERROR: Did not recognize visible activation {0!s}, please use sigmoid or softmax"
            .format(state.visible_act))
        raise NotImplementedError(
            "Did not recognize visible activation {0!s}, please use sigmoid or softmax"
            .format(state.visible_act))

    # recurrent activation
    if state.recurrent_hidden_act == 'sigmoid':
        logger.log('Using sigmoid activation for recurrent hiddens')
        recurrent_hidden_activation = T.nnet.sigmoid
    elif state.recurrent_hidden_act == 'rectifier':
        logger.log('Using rectifier activation for recurrent hiddens')
        recurrent_hidden_activation = lambda x: T.maximum(cast32(0), x)
    elif state.recurrent_hidden_act == 'tanh':
        logger.log('Using hyperbolic tangent activation for recurrent hiddens')
        recurrent_hidden_activation = lambda x: T.tanh(x)
    else:
        logger.log(
            "ERROR: Did not recognize recurrent hidden activation {0!s}, please use tanh, rectifier, or sigmoid"
            .format(state.recurrent_hidden_act))
        raise NotImplementedError(
            "Did not recognize recurrent hidden activation {0!s}, please use tanh, rectifier, or sigmoid"
            .format(state.recurrent_hidden_act))

    logger.log("\n")

    ####################
    #  COST FUNCTIONS  #
    ####################
    if state.cost_funct == 'binary_crossentropy':
        logger.log('Using binary cross-entropy cost!')
        cost_function = lambda x, y: T.mean(T.nnet.binary_crossentropy(x, y))
    elif state.cost_funct == 'square':
        logger.log("Using square error cost!")
        #cost_function = lambda x,y: T.log(T.mean(T.sqr(x-y)))
        cost_function = lambda x, y: T.log(T.sum(T.pow((x - y), 2)))
    else:
        logger.log(
            "ERROR: Did not recognize cost function {0!s}, please use binary_crossentropy or square"
            .format(state.cost_funct))
        raise NotImplementedError(
            "Did not recognize cost function {0!s}, please use binary_crossentropy or square"
            .format(state.cost_funct))

    logger.log("\n")

    ##############################################
    #    Build the training graph for the GSN    #
    ##############################################
    if train_gsn_first:
        '''Build the actual gsn training graph'''
        p_X_chain_gsn, _ = generative_stochastic_network.build_gsn(
            X, weights_list, bias_list, True, state.noiseless_h1,
            state.hidden_add_noise_sigma, state.input_salt_and_pepper,
            state.input_sampling, MRG, visible_activation, hidden_activation,
            walkbacks, logger)
        # now without noise
        p_X_chain_gsn_recon, _ = generative_stochastic_network.build_gsn(
            X, weights_list, bias_list, False, state.noiseless_h1,
            state.hidden_add_noise_sigma, state.input_salt_and_pepper,
            state.input_sampling, MRG, visible_activation, hidden_activation,
            walkbacks, logger)

    ##############################################
    #  Build the training graph for the RNN-GSN  #
    ##############################################
    # Deterministic recurrence: given x_t and the previous state u_tm1, compute the new state u_t and the time-dependent GSN biases (generation without an input x_t is not implemented here)
    def recurrent_step(x_t, u_tm1):
        bv_t = bias_list[0] + T.dot(u_tm1,
                                    recurrent_to_gsn_bias_weights_list[0])
        bh_t = T.concatenate([
            bias_list[i + 1] +
            T.dot(u_tm1, recurrent_to_gsn_bias_weights_list[i + 1])
            for i in range(layers)
        ],
                             axis=0)
        generate = x_t is None
        if generate:
            pass
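            # generative mode (x_t is None) is left unimplemented; during training x_t
            # always comes from the Xs sequence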
        ua_t = T.dot(x_t, W_x_u) + T.dot(u_tm1, W_u_u) + recurrent_bias
        u_t = recurrent_hidden_activation(ua_t)
        return None if generate else [ua_t, u_t, bv_t, bh_t]

    logger.log("\nCreating recurrent step scan.")
    # For training, the deterministic recurrence is used to compute all the
    # {h_t, 1 <= t <= T} given Xs. Conditional GSNs can then be trained
    # in batches using those parameters.
    u0 = T.zeros((state.recurrent_hidden_size,))  # initial value for the RNN hidden units
    (ua, u, bv_t, bh_t), updates_recurrent = theano.scan(
        fn=lambda x_t, u_tm1, *_: recurrent_step(x_t, u_tm1),
        sequences=Xs,
        outputs_info=[None, u0, None, None],
        non_sequences=params)
    # put the bias_list together from hiddens and visible biases
    #b_list = [bv_t.flatten(2)] + [bh_t.dimshuffle((1,0,2))[i] for i in range(len(weights_list))]
    b_list = [bv_t] + [
        (bh_t.T[i * state.hidden_size:(i + 1) * state.hidden_size]).T
        for i in range(layers)
    ]
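    # bh_t stacks all hidden-layer biases along the feature axis (layers * hidden_size per
    # time step); slicing out each hidden_size chunk and transposing back recovers the
    # per-layer bias sequences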

    _, cost, show_cost = generative_stochastic_network.build_gsn_scan(
        Xs, weights_list, b_list, True, state.noiseless_h1,
        state.hidden_add_noise_sigma, state.input_salt_and_pepper,
        state.input_sampling, MRG, visible_activation, hidden_activation,
        walkbacks, cost_function, logger)
    x_sample_recon, _, _ = generative_stochastic_network.build_gsn_scan(
        Xs, weights_list, b_list, False, state.noiseless_h1,
        state.hidden_add_noise_sigma, state.input_salt_and_pepper,
        state.input_sampling, MRG, visible_activation, hidden_activation,
        walkbacks, cost_function, logger)

    # copy so the SGD parameter updates added to updates_train below do not leak into the cost-only function
    updates_train = OrderedDict(updates_recurrent)
    #updates_train.update(updates_gsn)
    updates_cost = OrderedDict(updates_recurrent)

    #updates_recon = updates_recurrent
    #updates_recon.update(updates_gsn_recon)

    #############
    #   COSTS   #
    #############
    logger.log("")
    logger.log('Cost w.r.t p(X|...) at every step in the graph')

    if train_gsn_first:
        gsn_costs = [cost_function(rX, X) for rX in p_X_chain_gsn]
        gsn_show_cost = gsn_costs[-1]
        gsn_cost = numpy.sum(gsn_costs)

    ###################################
    # GRADIENTS AND FUNCTIONS FOR GSN #
    ###################################
    logger.log(["params:", params])

    logger.log("creating functions...")
    start_functions_time = time.time()

    if train_gsn_first:
        gradient_gsn = T.grad(gsn_cost, gsn_params)
        gradient_buffer_gsn = [
            theano.shared(numpy.zeros(param.get_value().shape,
                                      dtype='float32')) for param in gsn_params
        ]

        m_gradient_gsn = [
            momentum * gb + (cast32(1) - momentum) * g
            for (gb, g) in zip(gradient_buffer_gsn, gradient_gsn)
        ]
        param_updates_gsn = [(param, param - learning_rate * mg)
                             for (param, mg) in zip(gsn_params, m_gradient_gsn)
                             ]
        gradient_buffer_updates_gsn = zip(gradient_buffer_gsn, m_gradient_gsn)

        grad_updates_gsn = OrderedDict(param_updates_gsn +
                                       gradient_buffer_updates_gsn)

        logger.log("gsn cost...")
        f_cost_gsn = theano.function(inputs=[X],
                                     outputs=gsn_show_cost,
                                     on_unused_input='warn')

        logger.log("gsn learn...")
        f_learn_gsn = theano.function(inputs=[X],
                                      updates=grad_updates_gsn,
                                      outputs=gsn_show_cost,
                                      on_unused_input='warn')

    #######################################
    # GRADIENTS AND FUNCTIONS FOR RNN-GSN #
    #######################################
    # if we are not using Hessian-free training create the normal sgd functions
    if state.hessian_free == 0:
        gradient = T.grad(cost, params)
        gradient_buffer = [
            theano.shared(numpy.zeros(param.get_value().shape,
                                      dtype='float32')) for param in params
        ]

        m_gradient = [
            momentum * gb + (cast32(1) - momentum) * g
            for (gb, g) in zip(gradient_buffer, gradient)
        ]
        param_updates = [(param, param - learning_rate * mg)
                         for (param, mg) in zip(params, m_gradient)]
        gradient_buffer_updates = zip(gradient_buffer, m_gradient)

        updates = OrderedDict(param_updates + gradient_buffer_updates)
        updates_train.update(updates)

        logger.log("rnn-gsn learn...")
        f_learn = theano.function(inputs=[Xs],
                                  updates=updates_train,
                                  outputs=show_cost,
                                  on_unused_input='warn')

        logger.log("rnn-gsn cost...")
        f_cost = theano.function(inputs=[Xs],
                                 updates=updates_cost,
                                 outputs=show_cost,
                                 on_unused_input='warn')

    logger.log("Training/cost functions done.")
    compilation_time = time.time() - start_functions_time
    # Show the compile time with appropriate easy-to-read units.
    logger.log("Compilation took " + make_time_units_string(compilation_time) +
               ".\n\n")

    ############################################################################################
    # Denoise some numbers : show number, noisy number, predicted number, reconstructed number #
    ############################################################################################
    # Recompile the graph without noise for reconstruction function
    # The layer update scheme
    logger.log(
        "Creating graph for noisy reconstruction function at checkpoints during training."
    )
    f_recon = theano.function(inputs=[Xs], outputs=x_sample_recon[-1])

    # Now do the same but for the GSN in the initial run
    if train_gsn_first:
        f_recon_gsn = theano.function(inputs=[X],
                                      outputs=p_X_chain_gsn_recon[-1])

    logger.log("Done compiling all functions.")
    compilation_time = time.time() - start_functions_time
    # Show the total time with appropriate easy-to-read units.
    logger.log("Total time: " + make_time_units_string(compilation_time) +
               ".\n\n")

    ############
    # Sampling #
    ############
    # a function to add salt and pepper noise
    f_noise = theano.function(inputs=[X],
                              outputs=salt_and_pepper(
                                  X, state.input_salt_and_pepper))
    # the input to the sampling function
    X_sample = T.fmatrix("X_sampling")
    network_state_input = [X_sample] + [
        T.fmatrix("H_sampling_" + str(i + 1)) for i in range(layers)
    ]

    # "Output" state of the network (noisy)
    # initialized with input, then we apply updates

    network_state_output = [X_sample] + network_state_input[1:]

    visible_pX_chain = []

    # ONE update
    logger.log("Performing one walkback in network state sampling.")
    generative_stochastic_network.update_layers(
        network_state_output, weights_list, bias_list, visible_pX_chain, True,
        state.noiseless_h1, state.hidden_add_noise_sigma,
        state.input_salt_and_pepper, state.input_sampling, MRG,
        visible_activation, hidden_activation, logger)

    if layers == 1:
        f_sample_simple = theano.function(inputs=[X_sample],
                                          outputs=visible_pX_chain[-1])

    # on_unused_input='warn' because the first odd hidden layers are never read as inputs -
    # they are recomputed directly from the even layers
    f_sample2 = theano.function(inputs=network_state_input,
                                outputs=network_state_output +
                                visible_pX_chain,
                                on_unused_input='warn')

    def sample_some_numbers_single_layer():
        x0 = test_X.get_value()[:1]
        samples = [x0]
        x = f_noise(x0)
        for i in range(399):
            x = f_sample_simple(x)
            samples.append(x)
            x = numpy.random.binomial(n=1, p=x, size=x.shape).astype('float32')
            x = f_noise(x)
        return numpy.vstack(samples)

    def sampling_wrapper(NSI):
        # * is the "splat" operator: It takes a list as input, and expands it into actual positional arguments in the function call.
        out = f_sample2(*NSI)
        NSO = out[:len(network_state_output)]
        vis_pX_chain = out[len(network_state_output):]
        return NSO, vis_pX_chain

    def sample_some_numbers(N=400):
        # The network's initial state
        init_vis = test_X.get_value()[:1]

        noisy_init_vis = f_noise(init_vis)

        network_state = [[noisy_init_vis] + [
            numpy.zeros((1, len(b.get_value())), dtype='float32')
            for b in bias_list[1:]
        ]]

        visible_chain = [init_vis]

        noisy_h0_chain = [noisy_init_vis]

        for i in range(N - 1):

            # feed the last state into the network, compute new state, and obtain visible units expectation chain
            net_state_out, vis_pX_chain = sampling_wrapper(network_state[-1])

            # append to the visible chain
            visible_chain += vis_pX_chain

            # append state output to the network state chain
            network_state.append(net_state_out)

            noisy_h0_chain.append(net_state_out[0])

        return numpy.vstack(visible_chain), numpy.vstack(noisy_h0_chain)

    def plot_samples(epoch_number, leading_text):
        to_sample = time.time()
        if layers == 1:
            # one layer model
            V = sample_some_numbers_single_layer()
        else:
            V, H0 = sample_some_numbers()
        img_samples = PIL.Image.fromarray(
            tile_raster_images(V, (root_N_input, root_N_input), (20, 20)))

        fname = outdir + leading_text + 'samples_epoch_' + str(
            epoch_number) + '.png'
        img_samples.save(fname)
        logger.log('Took ' + str(time.time() - to_sample) +
                   ' to sample 400 numbers')

    #############################
    # Save the model parameters #
    #############################
    def save_params_to_file(name, n, gsn_params):
        pass
        # print 'saving parameters...'
        # save_path = outdir+name+'_params_epoch_'+str(n)+'.pkl'
        # f = open(save_path, 'wb')
        # try:
        #     cPickle.dump(gsn_params, f, protocol=cPickle.HIGHEST_PROTOCOL)
        # finally:
        #     f.close()

    def save_params(params):
        values = [param.get_value(borrow=True) for param in params]
        return values

    def restore_params(params, values):
        for i in range(len(params)):
            params[i].set_value(values[i])

    ################
    # GSN TRAINING #
    ################
    def train_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        logger.log("\n-----------TRAINING GSN------------\n")

        # TRAINING
        n_epoch = state.n_epoch
        batch_size = state.gsn_batch_size
        STOP = False
        counter = 0
        learning_rate.set_value(cast32(state.learning_rate))  # learning rate
        times = []
        best_cost = float('inf')
        best_params = None
        patience = 0

        logger.log(['train X size:', str(train_X.shape.eval())])
        logger.log(['valid X size:', str(valid_X.shape.eval())])
        logger.log(['test X size:', str(test_X.shape.eval())])

        if state.vis_init:
            # clip the mean pixel values away from 0 and 1 so the logit stays finite
            bias_list[0].set_value(
                logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))

        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            logger.log('Testing : skip training')
            STOP = True

        while not STOP:
            counter += 1
            t = time.time()
            logger.append([counter, '\t'])

            #shuffle the data
            data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y,
                                     test_X, test_Y, dataset, rng)

            #train
            train_costs = []
            for i in xrange(len(train_X.get_value(borrow=True)) / batch_size):
                x = train_X.get_value()[i * batch_size:(i + 1) * batch_size]
                cost = f_learn_gsn(x)
                train_costs.append([cost])
            train_costs = numpy.mean(train_costs)
            # record it
            logger.append(['Train:', trunc(train_costs), '\t'])
            with open(gsn_train_convergence, 'a') as f:
                f.write("{0!s},".format(train_costs))
                f.write("\n")

            #valid
            valid_costs = []
            for i in xrange(len(valid_X.get_value(borrow=True)) / batch_size):
                x = valid_X.get_value()[i * batch_size:(i + 1) * batch_size]
                cost = f_cost_gsn(x)
                valid_costs.append([cost])
            valid_costs = numpy.mean(valid_costs)
            # record it
            logger.append(['Valid:', trunc(valid_costs), '\t'])
            with open(gsn_valid_convergence, 'a') as f:
                f.write("{0!s},".format(valid_costs))
                f.write("\n")

            #test
            test_costs = []
            for i in xrange(len(test_X.get_value(borrow=True)) / batch_size):
                x = test_X.get_value()[i * batch_size:(i + 1) * batch_size]
                cost = f_cost_gsn(x)
                test_costs.append([cost])
            test_costs = numpy.mean(test_costs)
            # record it
            logger.append(['Test:', trunc(test_costs), '\t'])
            with open(gsn_test_convergence, 'a') as f:
                f.write("{0!s},".format(test_costs))
                f.write("\n")

            #check for early stopping
            cost = numpy.sum(valid_costs)
            if cost < best_cost * state.early_stop_threshold:
                patience = 0
                best_cost = cost
                # save the parameters that made it the best
                best_params = save_params(gsn_params)
            else:
                patience += 1
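            # Illustrative example (the actual values come from state): with
            # early_stop_threshold = 0.9995 and early_stop_length = 30, the validation
            # cost must drop below 0.9995 * best_cost to reset patience; 30 consecutive
            # epochs without such an improvement (or hitting n_epoch) stops training and
            # restores the best parameters found so far.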

            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                if best_params is not None:
                    restore_params(gsn_params, best_params)
                save_params_to_file('gsn', counter, gsn_params)

            timing = time.time() - t
            times.append(timing)

            logger.append('time: ' + make_time_units_string(timing) + '\t')

            logger.log('remaining: ' +
                       make_time_units_string((n_epoch - counter) *
                                              numpy.mean(times)))

            if (counter % state.save_frequency) == 0 or STOP is True:
                n_examples = 100
                random_idx = numpy.array(
                    R.sample(range(len(test_X.get_value(borrow=True))),
                             n_examples))
                numbers = test_X.get_value(borrow=True)[random_idx]
                noisy_numbers = f_noise(
                    test_X.get_value(borrow=True)[random_idx])
                reconstructed = f_recon_gsn(noisy_numbers)
                # Concatenate stuff
                stacked = numpy.vstack([
                    numpy.vstack([
                        numbers[i * 10:(i + 1) * 10],
                        noisy_numbers[i * 10:(i + 1) * 10],
                        reconstructed[i * 10:(i + 1) * 10]
                    ]) for i in range(10)
                ])
                number_reconstruction = PIL.Image.fromarray(
                    tile_raster_images(stacked, (root_N_input, root_N_input),
                                       (10, 30)))

                number_reconstruction.save(outdir +
                                           'gsn_number_reconstruction_epoch_' +
                                           str(counter) + '.png')

                #sample_numbers(counter, 'seven')
                plot_samples(counter, 'gsn')

                #save gsn_params
                save_params_to_file('gsn', counter, gsn_params)

            # ANNEAL!
            new_lr = learning_rate.get_value() * annealing
            learning_rate.set_value(new_lr)
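            # The learning rate decays geometrically: after k epochs it equals
            # state.learning_rate * annealing**k. For instance (illustrative value only),
            # annealing = 0.995 roughly halves the rate every ~139 epochs.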

        # 10k samples
        print 'Generating 10,000 samples'
        samples, _ = sample_some_numbers(N=10000)
        f_samples = outdir + 'samples.npy'
        numpy.save(f_samples, samples)
        print 'saved digits'

    def train_RNN_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        # If we are using Hessian-free training
        if state.hessian_free == 1:
            pass
#         gradient_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=5000)
#         cg_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=1000)
#         valid_dataset = hf_sequence_dataset([valid_X.get_value()], batch_size=None, number_batches=1000)
#
#         s = x_samples
#         costs = [cost, show_cost]
#         hf_optimizer(params, [Xs], s, costs, u, ua).train(gradient_dataset, cg_dataset, initial_lambda=1.0, preconditioner=True, validation=valid_dataset)

        # If we are using SGD training
        else:
            # Define the re-used loops for f_learn and f_cost
            def apply_cost_function_to_dataset(function, dataset):
                costs = []
                for i in xrange(
                        len(dataset.get_value(borrow=True)) / batch_size):
                    xs = dataset.get_value(
                        borrow=True)[i * batch_size:(i + 1) * batch_size]
                    cost = function(xs)
                    costs.append([cost])
                return numpy.mean(costs)

            logger.log("\n-----------TRAINING RNN-GSN------------\n")
            # TRAINING
            n_epoch = state.n_epoch
            batch_size = state.batch_size
            STOP = False
            counter = 0
            learning_rate.set_value(cast32(
                state.learning_rate))  # learning rate
            times = []
            best_cost = float('inf')
            best_params = None
            patience = 0

            logger.log(['train X size:', str(train_X.shape.eval())])
            logger.log(['valid X size:', str(valid_X.shape.eval())])
            logger.log(['test X size:', str(test_X.shape.eval())])

            if state.vis_init:
                bias_list[0].set_value(
                    logit(numpy.clip(train_X.get_value().mean(axis=0),
                                     0.001, 0.9)))

            if state.test_model:
                # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
                logger.log('Testing : skip training')
                STOP = True

            while not STOP:
                counter += 1
                t = time.time()
                logger.append([counter, '\t'])

                #shuffle the data
                data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y,
                                         test_X, test_Y, dataset, rng)

                #train
                train_costs = apply_cost_function_to_dataset(f_learn, train_X)
                # record it
                logger.append(['Train:', trunc(train_costs), '\t'])
                with open(train_convergence, 'a') as f:
                    f.write("{0!s},".format(train_costs))
                    f.write("\n")

                #valid
                valid_costs = apply_cost_function_to_dataset(f_cost, valid_X)
                # record it
                logger.append(['Valid:', trunc(valid_costs), '\t'])
                with open(valid_convergence, 'a') as f:
                    f.write("{0!s},".format(valid_costs))
                    f.write("\n")

                #test
                test_costs = apply_cost_function_to_dataset(f_cost, test_X)
                # record it
                logger.append(['Test:', trunc(test_costs), '\t'])
                with open(test_convergence, 'a') as f:
                    f.write("{0!s},".format(test_costs))
                    f.write("\n")

                #check for early stopping
                cost = numpy.sum(valid_costs)
                if cost < best_cost * state.early_stop_threshold:
                    patience = 0
                    best_cost = cost
                    # save the parameters that made it the best
                    best_params = save_params(params)
                else:
                    patience += 1

                if counter >= n_epoch or patience >= state.early_stop_length:
                    STOP = True
                    if best_params is not None:
                        restore_params(params, best_params)
                    save_params_to_file('all', counter, params)

                timing = time.time() - t
                times.append(timing)

                logger.append('time: ' + make_time_units_string(timing) + '\t')

                logger.log('remaining: ' +
                           make_time_units_string((n_epoch - counter) *
                                                  numpy.mean(times)))

                if (counter % state.save_frequency) == 0 or STOP is True:
                    n_examples = 100
                    nums = test_X.get_value(borrow=True)[range(n_examples)]
                    noisy_nums = f_noise(
                        test_X.get_value(borrow=True)[range(n_examples)])
                    reconstructions = []
                    for i in xrange(0, len(noisy_nums)):
                        recon = f_recon(noisy_nums[max(0, (i + 1) -
                                                       batch_size):i + 1])
                        reconstructions.append(recon)
                    reconstructed = numpy.array(reconstructions)

                    # Concatenate stuff
                    stacked = numpy.vstack([
                        numpy.vstack([
                            nums[i * 10:(i + 1) * 10],
                            noisy_nums[i * 10:(i + 1) * 10],
                            reconstructed[i * 10:(i + 1) * 10]
                        ]) for i in range(10)
                    ])
                    number_reconstruction = PIL.Image.fromarray(
                        tile_raster_images(stacked,
                                           (root_N_input, root_N_input),
                                           (10, 30)))

                    number_reconstruction.save(
                        outdir + 'rnngsn_number_reconstruction_epoch_' +
                        str(counter) + '.png')

                    #sample_numbers(counter, 'seven')
                    plot_samples(counter, 'rnngsn')

                    #save params
                    save_params_to_file('all', counter, params)

                # ANNEAL!
                new_lr = learning_rate.get_value() * annealing
                learning_rate.set_value(new_lr)

            # 10k samples
            print 'Generating 10,000 samples'
            samples, _ = sample_some_numbers(N=10000)
            f_samples = outdir + 'samples.npy'
            numpy.save(f_samples, samples)
            print 'saved digits'

    #####################
    # STORY 2 ALGORITHM #
    #####################
    # train the GSN parameters first to get a good baseline (if not loaded from parameter .pkl file)
    if train_gsn_first:
        train_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
    # train the entire RNN-GSN
    train_RNN_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
def train_regression(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
     print '-------------------------------------------'
     print 'TRAINING RECURRENT REGRESSION FOR ITERATION',iteration
     with open(logfile,'a') as f:
         f.write("--------------------------\nTRAINING RECURRENT REGRESSION FOR ITERATION {0!s}\n".format(iteration))
     
     # TRAINING
     n_epoch     =   state.n_epoch
     batch_size  =   state.batch_size
     STOP        =   False
     counter     =   0
     if iteration == 0:
         recurrent_learning_rate.set_value(cast32(state.learning_rate))  # learning rate
     times = []
     best_cost = float('inf')
     patience = 0
         
     print 'learning rate:',recurrent_learning_rate.get_value()
     
     print 'train X size:',str(train_X.shape.eval())
     print 'valid X size:',str(valid_X.shape.eval())
     print 'test X size:',str(test_X.shape.eval())
 
     train_costs =   []
     valid_costs =   []
     test_costs  =   []
     
     if state.vis_init:
         bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0),0.001,0.9)))
 
     if state.test_model:
         # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
         print 'Testing : skip training'
         STOP    =   True
 
 
     while not STOP:
         counter += 1
         t = time.time()
         print counter,'\t',
         
         #shuffle the data
         data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
         
         #train
         #init recurrent hiddens as zero
         recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
         train_cost = []
         for i in range(len(train_X.get_value(borrow=True)) / batch_size):
             x = train_X.get_value()[i * batch_size : (i+1) * batch_size]
             x1 = train_X.get_value()[(i * batch_size) + 1 : ((i+1) * batch_size) + 1]
             [x,x1], recurrent_hiddens = fix_input_size([x,x1], recurrent_hiddens)
             _ins = recurrent_hiddens + [x,x1]
             _outs = recurrent_f_learn(*_ins)
             recurrent_hiddens = _outs[:len(recurrent_hiddens)]
             cost = _outs[-1]
             train_cost.append(cost)
             
         train_cost = numpy.mean(train_cost) 
         train_costs.append(train_cost)
         print 'rTrain : ',trunc(train_cost), '\t',
         with open(logfile,'a') as f:
             f.write("rTrain : {0!s}\t".format(trunc(train_cost)))
         with open(recurrent_train_convergence,'a') as f:
             f.write("{0!s},".format(train_cost))
 
         #valid
         #init recurrent hiddens as zero
         recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
         valid_cost  =  []
         for i in range(len(valid_X.get_value(borrow=True)) / batch_size):
             x = valid_X.get_value()[i * batch_size : (i+1) * batch_size]
             x1 = valid_X.get_value()[(i * batch_size) + 1 : ((i+1) * batch_size) + 1]
             [x,x1], recurrent_hiddens = fix_input_size([x,x1], recurrent_hiddens)
             _ins = recurrent_hiddens + [x,x1]
             _outs = f_cost(*_ins)
             recurrent_hiddens = _outs[:len(recurrent_hiddens)]
             cost = _outs[-1]
             valid_cost.append(cost)
                 
         valid_cost = numpy.mean(valid_cost) 
         valid_costs.append(valid_cost)
         print 'rValid : ', trunc(valid_cost), '\t',
         with open(logfile,'a') as f:
             f.write("rValid : {0!s}\t".format(trunc(valid_cost)))
         with open(recurrent_valid_convergence,'a') as f:
             f.write("{0!s},".format(valid_cost))
 
         #test
         recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
         test_cost  =   []
         for i in range(len(test_X.get_value(borrow=True)) / batch_size):
             x = test_X.get_value()[i * batch_size : (i+1) * batch_size]
             x1 = test_X.get_value()[(i * batch_size) + 1 : ((i+1) * batch_size) + 1]
             [x,x1], recurrent_hiddens = fix_input_size([x,x1], recurrent_hiddens)
             _ins = recurrent_hiddens + [x,x1]
             _outs = f_cost(*_ins)
             recurrent_hiddens = _outs[:len(recurrent_hiddens)]
             cost = _outs[-1]
             test_cost.append(cost)
             
         test_cost = numpy.mean(test_cost) 
         test_costs.append(test_cost)
         print 'rTest  : ', trunc(test_cost), '\t',
         with open(logfile,'a') as f:
             f.write("rTest : {0!s}\t".format(trunc(test_cost)))
         with open(recurrent_test_convergence,'a') as f:
             f.write("{0!s},".format(test_cost))
 
         #check for early stopping
         cost = train_cost
         if iteration != 0:
             cost = cost + train_cost
         if cost < best_cost*state.early_stop_threshold:
             patience = 0
             best_cost = cost
         else:
             patience += 1
             
         timing = time.time() - t
         times.append(timing)
 
         print 'time : ', trunc(timing),
         
         print 'remaining: ', trunc((n_epoch - counter) * numpy.mean(times) / 60 / 60), 'hrs'
         
         with open(logfile,'a') as f:
             f.write("B : {0!s}\t".format(str([trunc(vb.get_value().mean()) for vb in recurrent_bias_list])))
             
         with open(logfile,'a') as f:
             f.write("W : {0!s}\t".format(str([trunc(abs(v.get_value(borrow=True)).mean()) for v in recurrent_weights_list_encode])))
         
         with open(logfile,'a') as f:
             f.write("V : {0!s}\t".format(str([trunc(abs(v.get_value(borrow=True)).mean()) for v in recurrent_weights_list_decode])))
             
         with open(logfile,'a') as f:
             f.write("Time : {0!s} seconds\n".format(trunc(timing)))
                 
         if (counter % state.save_frequency) == 0:
             # Checking reconstruction
             nums = test_X.get_value()[range(100)]
             noisy_nums = f_noise(test_X.get_value()[range(100)])
             reconstructed = []
             reconstructed_prediction = []
             #init recurrent hiddens as zero
             recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
             for num in noisy_nums:
                 _ins = recurrent_hiddens + [num]
                 _outs = f_recon(*_ins)
                 recurrent_hiddens = _outs[:len(recurrent_hiddens)]
                 [recon,recon_pred] = _outs[len(recurrent_hiddens):]
                 reconstructed.append(recon)
                 reconstructed_prediction.append(recon_pred)
             # Concatenate stuff
             stacked = numpy.vstack([numpy.vstack([nums[i*10 : (i+1)*10], noisy_nums[i*10 : (i+1)*10], reconstructed[i*10 : (i+1)*10], reconstructed_prediction[i*10 : (i+1)*10]]) for i in range(10)])
             
             number_reconstruction   =   PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,40)))
             #epoch_number    =   reduce(lambda x,y : x + y, ['_'] * (4-len(str(counter)))) + str(counter)
             number_reconstruction.save(outdir+'recurrent_number_reconstruction_iteration_'+str(iteration)+'_epoch_'+str(counter)+'.png')
     
             #sample_numbers(counter, 'seven')
             plot_samples(counter, iteration)
     
             #save params
             save_params_to_file('recurrent', counter, params, iteration)
      
         # ANNEAL!
         new_r_lr = recurrent_learning_rate.get_value() * annealing
         recurrent_learning_rate.set_value(new_r_lr)
 
     # if test
 
     # 10k samples
     print 'Generating 10,000 samples'
     samples, _  =   sample_some_numbers(N=10000)
     f_samples   =   outdir+'samples.npy'
     numpy.save(f_samples, samples)
     print 'saved digits'
def experiment(state, outdir_base='./'):
    rng.seed(1) #seed the numpy random generator  
    # create the output directories and log/result files
    data.mkdir_p(outdir_base)
    outdir = outdir_base + "/" + state.dataset + "/"
    data.mkdir_p(outdir)
    logfile = outdir+"log.txt"
    with open(logfile,'w') as f:
        f.write("MODEL 3, {0!s}\n\n".format(state.dataset))
    train_convergence_pre = outdir+"train_convergence_pre.csv"
    train_convergence_post = outdir+"train_convergence_post.csv"
    valid_convergence_pre = outdir+"valid_convergence_pre.csv"
    valid_convergence_post = outdir+"valid_convergence_post.csv"
    test_convergence_pre = outdir+"test_convergence_pre.csv"
    test_convergence_post = outdir+"test_convergence_post.csv"
    recurrent_train_convergence = outdir+"recurrent_train_convergence.csv"
    recurrent_valid_convergence = outdir+"recurrent_valid_convergence.csv"
    recurrent_test_convergence = outdir+"recurrent_test_convergence.csv"
    
    print
    print "----------MODEL 3--------------"
    print
    #load parameters from config file if this is a test
    config_filename = outdir+'config'
    if state.test_model and 'config' in os.listdir(outdir):
        config_vals = load_from_config(config_filename)
        for CV in config_vals:
            print CV
            if CV.startswith('test'):
                print 'Do not override testing switch'
                continue        
            try:
                exec('state.'+CV) in globals(), locals()
            except:
                exec('state.'+CV.split('=')[0]+"='"+CV.split('=')[1]+"'") in globals(), locals()
    else:
        # Save the current configuration
        # Useful for logs/experiments
        print 'Saving config'
        with open(config_filename, 'w') as f:
            f.write(str(state))


    print state
    # Load the data
    artificial = False #flag for using my artificially-sequenced mnist datasets
    if state.dataset == 'MNIST_1' or state.dataset == 'MNIST_2' or state.dataset == 'MNIST_3':
        (train_X, train_Y), (valid_X, valid_Y), (test_X, test_Y) = data.load_mnist(state.data_path)
        train_X = numpy.concatenate((train_X, valid_X))
        train_Y = numpy.concatenate((train_Y, valid_Y))
        artificial = True
        try:
            dataset = int(state.dataset.split('_')[1])
        except:
            raise AssertionError("artificial dataset number not recognized. Input was "+state.dataset)
    else:
        raise AssertionError("dataset not recognized.")
    
    #make shared variables for better use of the gpu
    train_X = theano.shared(train_X)
    train_Y = theano.shared(train_Y)
    valid_X = theano.shared(valid_X)
    valid_Y = theano.shared(valid_Y) 
    test_X = theano.shared(test_X)
    test_Y = theano.shared(test_Y) 
   
    if artificial: #run the appropriate artificial sequencing of mnist data
        print 'Sequencing MNIST data...'
        print 'train set size:',len(train_Y.eval())
        print 'valid set size:',len(valid_Y.eval())
        print 'test set size:',len(test_Y.eval())
        data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
        print 'train set size:',len(train_Y.eval())
        print 'valid set size:',len(valid_Y.eval())
        print 'test set size:',len(test_Y.eval())
        print 'Sequencing done.'
        print
    
    N_input =   train_X.eval().shape[1]
    root_N_input = numpy.sqrt(N_input)
    
    # Theano variables and RNG
    X       = T.fmatrix("X")
    X1      = T.fmatrix("X1")
    H       = T.fmatrix("Hrecurr_visible")
    MRG = RNG_MRG.MRG_RandomStreams(1)
    
    
    # Network and training specifications
    layers          =   state.layers # number hidden layers
    walkbacks       =   state.walkbacks # number of walkbacks 
    recurrent_layers =   state.recurrent_layers # number recurrent hidden layers
    recurrent_walkbacks =   state.recurrent_walkbacks # number of recurrent walkbacks 
    layer_sizes     =   [N_input] + [state.hidden_size] * layers # layer sizes, from h0 to hK (h0 is the visible layer)
    print 'layer_sizes:', layer_sizes
    recurrent_layer_sizes = [int(state.hidden_size*numpy.ceil(layers/2.0))] + [state.recurrent_hidden_size] * recurrent_layers # cast to int so it can be used as a shape
    print 'recurrent_sizes',recurrent_layer_sizes
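    # The recurrent GSN's visible layer is the concatenation of the GSN's odd hidden
    # layers (see build_graph below), hence hidden_size * ceil(layers / 2).
    # Illustrative example (not the actual config): layers = 3 and hidden_size = 1500
    # give two odd layers and a recurrent visible size of 3000.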
    learning_rate   =   theano.shared(cast32(state.learning_rate))  # learning rate
    recurrent_learning_rate   =   theano.shared(cast32(state.learning_rate))  # learning rate
    annealing       =   cast32(state.annealing) # exponential annealing coefficient
    momentum        =   theano.shared(cast32(state.momentum)) # momentum term 
    
    recurrent_hiddens_input = [H] + [T.fmatrix(name="hrecurr_"+str(i+1)) for i in range(recurrent_layers)]
    recurrent_hiddens_output = recurrent_hiddens_input[:1] + recurrent_hiddens_input[1:]

    # PARAMETERS : weights list and bias list.
    # initialize a list of weights and biases based on layer_sizes: these are theta_gsn parameters
    weights_list    =   [get_shared_weights(layer_sizes[i], layer_sizes[i+1], name="W_{0!s}_{1!s}".format(i,i+1)) for i in range(layers)] # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
    bias_list       =   [get_shared_bias(layer_sizes[i], name='b_'+str(i)) for i in range(layers + 1)] # initialize each layer to 0's.
    # parameters for recurrent part
    #recurrent weights initial visible layer is the even layers of the network below it: these are theta_transition parameters
    recurrent_weights_list_encode = [get_shared_weights(recurrent_layer_sizes[i], recurrent_layer_sizes[i+1], name="U_{0!s}_{1!s}".format(i,i+1)) for i in range(recurrent_layers)] #untied weights in the recurrent layers
    recurrent_weights_list_decode = [get_shared_weights(recurrent_layer_sizes[i+1], recurrent_layer_sizes[i], name="V_{0!s}_{1!s}".format(i+1,i)) for i in range(recurrent_layers)]
    recurrent_bias_list = [get_shared_bias(recurrent_layer_sizes[i], name='vb_'+str(i)) for i in range(recurrent_layers+1)] # initialize to 0's.

 
    ''' F PROP '''
    if state.act == 'sigmoid':
        print 'Using sigmoid activation'
        hidden_activation = T.nnet.sigmoid
    elif state.act == 'rectifier':
        print 'Using rectifier activation'
        hidden_activation = lambda x : T.maximum(cast32(0), x)
    elif state.act == 'tanh':
        print 'Using tanh activation'
        hidden_activation = lambda x : T.tanh(x)
        
    print 'Using sigmoid activation for visible layer'
    visible_activation = T.nnet.sigmoid 
  
        
    def update_layers(hiddens, p_X_chain, noisy = True):
        print 'odd layer updates'
        update_odd_layers(hiddens, noisy)
        print 'even layer updates'
        update_even_layers(hiddens, p_X_chain, noisy)
        print 'done full update.'
        print
        
    def update_layers_reverse_order(hiddens, p_X_chain, noisy = True):
        print 'even layer updates'
        update_even_layers(hiddens, p_X_chain, noisy)
        print 'odd layer updates'
        update_odd_layers(hiddens, noisy)
        print 'done full update.'
        print
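    # Illustrative walkthrough (assuming layers = 3): hiddens == [X, h1, h2, h3];
    # an odd update recomputes h1 and h3, an even update recomputes X and h2, and one
    # walkback is the composition of the two (odd then even, or reversed as above).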
        
    # Odd layer update function
    # just a loop over the odd layers
    def update_odd_layers(hiddens, noisy):
        for i in range(1, len(hiddens), 2):
            print 'updating layer',i
            simple_update_layer(hiddens, None, i, add_noise = noisy)
    
    # Even layer update
    # p_X_chain is given to append the p(X|...) at each full update (one update = odd update + even update)
    def update_even_layers(hiddens, p_X_chain, noisy):
        for i in range(0, len(hiddens), 2):
            print 'updating layer',i
            simple_update_layer(hiddens, p_X_chain, i, add_noise = noisy)
    
    # The layer update function
    # hiddens   :   list containing the symbolic theano variables [visible, hidden1, hidden2, ...]
    #               layer_update will modify this list inplace
    # p_X_chain :   list containing the successive p(X|...) at each update
    #               update_layer will append to this list
    # add_noise     : pre and post activation gaussian noise
    
    def simple_update_layer(hiddens, p_X_chain, i, add_noise=True):                               
        # Compute the dot product, whatever layer        
        # If the visible layer X
        if i == 0:
            print 'using '+str(weights_list[i])+'.T'
            hiddens[i]  =   T.dot(hiddens[i+1], weights_list[i].T) + bias_list[i]           
        # If the top layer
        elif i == len(hiddens)-1:
            print 'using',weights_list[i-1]
            hiddens[i]  =   T.dot(hiddens[i-1], weights_list[i-1]) + bias_list[i]
        # Otherwise in-between layers
        else:
            print "using {0!s} and {1!s}.T".format(weights_list[i-1], weights_list[i])
            # next layer        :   hiddens[i+1], assigned weights : W_i
            # previous layer    :   hiddens[i-1], assigned weights : W_(i-1)
            hiddens[i]  =   T.dot(hiddens[i+1], weights_list[i].T) + T.dot(hiddens[i-1], weights_list[i-1]) + bias_list[i]
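        # In other words, the pre-activation of an in-between layer i is
        #   hiddens[i+1] . W_i^T + hiddens[i-1] . W_(i-1) + b_i,
        # with only the single available neighbour used at the visible (i == 0) and
        # top (i == len(hiddens)-1) layers.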
    
        # Add pre-activation noise if NOT input layer
        if i==1 and state.noiseless_h1:
            print '>>NO noise in first hidden layer'
            add_noise   =   False
    
        # pre activation noise            
        if i != 0 and add_noise:
            print 'Adding pre-activation gaussian noise for layer', i
            hiddens[i]  =   add_gaussian_noise(hiddens[i], state.hidden_add_noise_sigma)
       
        # ACTIVATION!
        if i == 0:
            print 'Sigmoid units activation for visible layer X'
            hiddens[i]  =   visible_activation(hiddens[i])
        else:
            print 'Hidden units {} activation for layer'.format(state.act), i
            hiddens[i]  =   hidden_activation(hiddens[i])
    
        # post activation noise
        # (disabled) pre-activation noise is already applied above, so adding noise here
        # as well would simply double the noise between each hidden activation.
#         if i != 0 and add_noise:
#             print 'Adding post-activation gaussian noise for layer', i
#             hiddens[i]  =   add_gaussian(hiddens[i], state.hidden_add_noise_sigma)
    
        # build the reconstruction chain if updating the visible layer X
        if i == 0:
            # if input layer -> append p(X|...)
            p_X_chain.append(hiddens[i])
            
            # sample from p(X|...) - SAMPLING NEEDS TO BE CORRECT FOR INPUT TYPES I.E. FOR BINARY MNIST SAMPLING IS BINOMIAL. real-valued inputs should be gaussian
            if state.input_sampling:
                print 'Sampling from input'
                sampled     =   sample_visibles(hiddens[i])
            else:
                print '>>NO input sampling'
                sampled     =   hiddens[i]
            # add noise
            sampled     =   salt_and_pepper(sampled, state.input_salt_and_pepper)
            
            # set input layer
            hiddens[i]  =   sampled
            
    def update_recurrent_layers(hiddens, p_X_chain, noisy = True):
        print 'odd layer updates'
        update_odd_recurrent_layers(hiddens, noisy)
        print 'even layer updates'
        update_even_recurrent_layers(hiddens, p_X_chain, noisy)
        print 'done full update.'
        print
        
    # Odd layer update function
    # just a loop over the odd layers
    def update_odd_recurrent_layers(hiddens, noisy):
        for i in range(1, len(hiddens), 2):
            print 'updating layer',i
            simple_update_recurrent_layer(hiddens, None, i, add_noise = noisy)
    
    # Even layer update
    # p_X_chain is given to append the p(X|...) at each full update (one update = odd update + even update)
    def update_even_recurrent_layers(hiddens, p_X_chain, noisy):
        for i in range(0, len(hiddens), 2):
            print 'updating layer',i
            simple_update_recurrent_layer(hiddens, p_X_chain, i, add_noise = noisy)
    
    # The layer update function
    # hiddens   :   list containing the symbolic theano variables [visible, hidden1, hidden2, ...]
    #               layer_update will modify this list inplace
    # p_X_chain :   list containing the successive p(X|...) at each update
    #               update_layer will append to this list
    # add_noise     : pre and post activation gaussian noise
    
    def simple_update_recurrent_layer(hiddens, p_X_chain, i, add_noise=True):                               
        # Compute the dot product, whatever layer        
        # If the visible layer X
        if i == 0:
            print 'using '+str(recurrent_weights_list_decode[i])
            hiddens[i]  =   T.dot(hiddens[i+1], recurrent_weights_list_decode[i]) + recurrent_bias_list[i]           
        # If the top layer
        elif i == len(hiddens)-1:
            print 'using',recurrent_weights_list_encode[i-1]
            hiddens[i]  =   T.dot(hiddens[i-1], recurrent_weights_list_encode[i-1]) + recurrent_bias_list[i]
        # Otherwise in-between layers
        else:
            print "using {0!s} and {1!s}".format(recurrent_weights_list_encode[i-1], recurrent_weights_list_decode[i])
            # next layer        :   hiddens[i+1], assigned weights : W_i
            # previous layer    :   hiddens[i-1], assigned weights : W_(i-1)
            hiddens[i]  =   T.dot(hiddens[i+1], recurrent_weights_list_decode[i]) + T.dot(hiddens[i-1], recurrent_weights_list_encode[i-1]) + recurrent_bias_list[i]
    
        # Add pre-activation noise if NOT input layer
        if i==1 and state.noiseless_h1:
            print '>>NO noise in first hidden layer'
            add_noise   =   False
    
        # pre activation noise
        if i != 0 and add_noise:
            print 'Adding pre-activation gaussian noise for layer', i
            hiddens[i]  =   add_gaussian_noise(hiddens[i], state.hidden_add_noise_sigma)
       
        # ACTIVATION!
        print 'Recurrent hidden units {} activation for layer'.format(state.act), i
        hiddens[i]  =   hidden_activation(hiddens[i])
    
        # post activation noise
        # (disabled) pre-activation noise is already applied above, so adding noise here
        # as well would simply double the noise between each hidden activation.
#         if i != 0 and add_noise:
#             print 'Adding post-activation gaussian noise for layer', i
#             hiddens[i]  =   add_gaussian(hiddens[i], state.hidden_add_noise_sigma)
    
        # build the reconstruction chain if updating the visible layer X
        if i == 0:
            # if input layer -> append p(X|...)
            p_X_chain.append(hiddens[i])
            
            # sample from p(X|...) - SAMPLING NEEDS TO BE CORRECT FOR INPUT TYPES I.E. FOR BINARY MNIST SAMPLING IS BINOMIAL. real-valued inputs should be gaussian
            if state.input_sampling:
                print 'Sampling from input'
                sampled     =   sample_hiddens(hiddens[i])
            else:
                print '>>NO input sampling'
                sampled     =   hiddens[i]
            # add noise
            sampled     =   salt_and_pepper(sampled, state.input_salt_and_pepper)
            
            # set input layer
            hiddens[i]  =   sampled
            
            
    def sample_hiddens(hiddens):
        return MRG.multinomial(pvals = hiddens, dtype='float32')
    
    def sample_visibles(visibles):
        return MRG.binomial(p = visibles, size=visibles.shape, dtype='float32')
    
    def build_gsn(hiddens, p_X_chain, noiseflag):
        print "Building the gsn graph :", walkbacks,"updates"
        for i in range(walkbacks):
            print "Walkback {!s}/{!s}".format(i+1,walkbacks)
            update_layers(hiddens, p_X_chain, noisy=noiseflag)
            
    def build_gsn_reverse(hiddens, p_X_chain, noiseflag):
        print "Building the gsn graph reverse layer update order:", walkbacks,"updates"
        for i in range(walkbacks):
            print "Walkback {!s}/{!s}".format(i+1,walkbacks)
            update_layers_reverse_order(hiddens, p_X_chain, noisy=noiseflag)
    
    def build_recurrent_gsn(recurrent_hiddens, p_H_chain, noiseflag):
        #recurrent_hiddens is a list that will be appended to for each of the walkbacks. Used because I need the immediate next set of hidden states to carry through when using the functions - trying not to break the sequences.
        print "Building the recurrent gsn graph :", recurrent_walkbacks,"updates"
        for i in range(recurrent_walkbacks):
            print "Recurrent walkback {!s}/{!s}".format(i+1,recurrent_walkbacks)
            update_recurrent_layers(recurrent_hiddens, p_H_chain, noisy=noiseflag)
        
        
        
    def build_graph(hiddens, recurrent_hiddens, noiseflag, prediction_index=0):
        p_X_chain = []
        # (keep the recurrent_hiddens passed in; its visible layer is replaced below)
        p_H_chain = []
        p_X1_chain = []
        # The layer update scheme
        print "Building the model graph :", walkbacks*2 + recurrent_walkbacks,"updates"

        # First, build the GSN for the given input.
        build_gsn(hiddens, p_X_chain, noiseflag)
        
        # Next, use the recurrent GSN to predict future hidden states
        # the recurrent hiddens base layer only consists of the odd layers from the gsn - this is because the gsn is constructed by half its layers every time
        recurrent_hiddens[0] = T.concatenate([hiddens[i] for i in range(1,len(hiddens),2)], axis=1)
        if noiseflag:
            recurrent_hiddens[0] = salt_and_pepper(recurrent_hiddens[0], state.input_salt_and_pepper)
        # Build the recurrent gsn predicting the next hidden states of future input gsn's
        build_recurrent_gsn(recurrent_hiddens, p_H_chain, noiseflag)
        
        #for every next predicted hidden states H, restore the odd layers of the hiddens from what they were predicted to be by the recurrent gsn
        for predicted_H in p_H_chain:
            index_accumulator = 0
            for i in range(1,len(hiddens),2):
                hiddens[i] = predicted_H[:, index_accumulator:index_accumulator + layer_sizes[i]]
                index_accumulator += layer_sizes[i]
            build_gsn_reverse(hiddens, p_X1_chain, noiseflag)
        
        return hiddens, recurrent_hiddens, p_X_chain, p_H_chain, p_X1_chain
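    # Note on the chains returned by build_graph: p_X_chain holds `walkbacks`
    # reconstructions of X, p_H_chain holds `recurrent_walkbacks` predicted
    # recurrent-hidden states, and p_X1_chain holds walkbacks * recurrent_walkbacks
    # reconstructions of the next input X1 (one reverse-order GSN pass per prediction).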
   
    
    ''' Corrupt X '''
    X_corrupt   = salt_and_pepper(X, state.input_salt_and_pepper)

    ''' hidden layer init '''
    hiddens     = [X_corrupt]
    
    print "Hidden units initialization"
    for w in weights_list:
        # init with zeros
        print "Init hidden units at zero before creating the graph"
        print
        hiddens.append(T.zeros_like(T.dot(hiddens[-1], w)))

    hiddens, recurrent_hiddens_output, p_X_chain, p_H_chain, p_X1_chain = build_graph(hiddens, recurrent_hiddens_output, noiseflag=True)
    

    # COST AND GRADIENTS    
    print
    print 'Cost w.r.t p(X|...) at every step in the graph'
    COSTS_pre        =   [T.mean(T.nnet.binary_crossentropy(rX, X)) for rX in p_X_chain]
    show_COST_pre    =   COSTS_pre[-1]
    COST_pre         =   numpy.sum(COSTS_pre)
    COSTS_post       =   [T.mean(T.nnet.binary_crossentropy(rX1, X1)) for rX1 in p_X1_chain]
    show_COST_post   =   COSTS_post[-1]
    COST_post        =   numpy.sum(COSTS_post)
    COSTS            =   COSTS_pre + COSTS_post
    COST             =   numpy.sum(COSTS)
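    # i.e. COST = sum_k mean(BCE(p_X_chain[k], X)) + sum_k mean(BCE(p_X1_chain[k], X1)):
    # the reconstruction cost of the current input plus the prediction cost of the next
    # input, accumulated over every walkback step in the graph.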
        
    params           =   weights_list + bias_list
    print "params:",params

    recurrent_params = recurrent_weights_list_encode + recurrent_weights_list_decode + recurrent_bias_list
    print "recurrent params:", recurrent_params   
    
     
    
    print "creating functions..."
    gradient_init        =   T.grad(COST_pre, params)
                 
    gradient_buffer_init =   [theano.shared(numpy.zeros(param.get_value().shape, dtype='float32')) for param in params]
     
    m_gradient_init      =   [momentum * gb + (cast32(1) - momentum) * g for (gb, g) in zip(gradient_buffer_init, gradient_init)]
    param_updates_init   =   [(param, param - learning_rate * mg) for (param, mg) in zip(params, m_gradient_init)]
    gradient_buffer_updates_init = zip(gradient_buffer_init, m_gradient_init)
         
    updates_init         =   OrderedDict(param_updates_init + gradient_buffer_updates_init)
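    # Momentum-smoothed SGD: each step keeps a running average of the gradient,
    #   m_t = momentum * m_(t-1) + (1 - momentum) * g_t,
    # and the parameters are updated with param <- param - learning_rate * m_t.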
    
    
    gradient        =   T.grad(COST, params)
                
    gradient_buffer =   [theano.shared(numpy.zeros(param.get_value().shape, dtype='float32')) for param in params]
    
    m_gradient      =   [momentum * gb + (cast32(1) - momentum) * g for (gb, g) in zip(gradient_buffer, gradient)]
    param_updates   =   [(param, param - learning_rate * mg) for (param, mg) in zip(params, m_gradient)]
    gradient_buffer_updates = zip(gradient_buffer, m_gradient)
        
    updates         =   OrderedDict(param_updates + gradient_buffer_updates)
    
    
    f_cost          =   theano.function(inputs  = recurrent_hiddens_input + [X, X1], 
                                        outputs = recurrent_hiddens_output + [show_COST_pre, show_COST_post], 
                                        on_unused_input='warn')

    f_learn         =   theano.function(inputs  = recurrent_hiddens_input + [X, X1], 
                                        updates = updates, 
                                        outputs = recurrent_hiddens_output + [show_COST_pre, show_COST_post],
                                        on_unused_input='warn')
    
    f_learn_init    =   theano.function(inputs  = [X], 
                                        updates = updates_init, 
                                        outputs = [show_COST_pre],
                                        on_unused_input='warn')
       
    
    
    recurrent_gradient        =   T.grad(COST_post, recurrent_params)
    recurrent_gradient_buffer =   [theano.shared(numpy.zeros(param.get_value().shape, dtype='float32')) for param in recurrent_params]
    recurrent_m_gradient      =   [momentum * gb + (cast32(1) - momentum) * g for (gb, g) in zip(recurrent_gradient_buffer, recurrent_gradient)]
    recurrent_param_updates   =   [(param, param - recurrent_learning_rate * mg) for (param, mg) in zip(recurrent_params, recurrent_m_gradient)]
    recurrent_gradient_buffer_updates = zip(recurrent_gradient_buffer, recurrent_m_gradient)
        
    recurrent_updates         =   OrderedDict(recurrent_param_updates + recurrent_gradient_buffer_updates)
    

    recurrent_f_learn         =   theano.function(inputs  = recurrent_hiddens_input + [X,X1],
                                                  updates = recurrent_updates,
                                                  outputs = recurrent_hiddens_output + [show_COST_post],
                                                  on_unused_input='warn')

    print "functions done."
    print
    
    #############
    # Denoise some numbers  :   show number, noisy number, reconstructed number
    #############
    import random as R
    R.seed(1)
    # Grab 100 random indices from test_X
    random_idx      =   numpy.array(R.sample(range(len(test_X.get_value())), 100))
    numbers         =   test_X.get_value()[random_idx]
    
    f_noise         =   theano.function(inputs = [X], outputs = salt_and_pepper(X, state.input_salt_and_pepper))
    noisy_numbers   =   f_noise(test_X.get_value()[random_idx])
    #noisy_numbers   =   salt_and_pepper(numbers, state.input_salt_and_pepper)

    # Recompile the graph without noise for reconstruction function
    X_recon          = T.fvector("X_recon")
    hiddens_R        = [X_recon]
    hiddens_R_input  = [T.fvector(name="h_recon_visible")] + [T.fvector(name="h_recon_"+str(i+1)) for i in range(layers)]
    hiddens_R_output = hiddens_R_input[:1] + hiddens_R_input[1:]

    for w in weights_list:
        hiddens_R.append(T.zeros_like(T.dot(hiddens_R[-1], w)))

    # The layer update scheme
    print "Creating graph for noisy reconstruction function at checkpoints during training."
    hiddens_R, recurrent_hiddens_output, p_X_chain_R, p_H_chain_R, p_X1_chain_R = build_graph(hiddens_R, hiddens_R_output, noiseflag=False)

    f_recon = theano.function(inputs = hiddens_R_input+[X_recon], 
                              outputs = hiddens_R_output+[p_X_chain_R[-1] ,p_X1_chain_R[-1]], 
                              on_unused_input="warn")


    ############
    # Sampling #
    ############
    
    # the input to the sampling function
    network_state_input     =   [X] + [T.fmatrix() for i in range(layers)]
   
    # "Output" state of the network (noisy)
    # initialized with input, then we apply updates
    #network_state_output    =   network_state_input
    
    network_state_output    =   [X] + network_state_input[1:]

    visible_pX_chain        =   []

    # ONE update
    print "Performing one walkback in network state sampling."
    update_layers(network_state_output, visible_pX_chain, noisy=True)

    if layers == 1: 
        f_sample_simple = theano.function(inputs = [X], outputs = visible_pX_chain[-1])
    
    
    # on_unused_input='warn' is needed because the first odd layers are never read as
    # inputs -- they are computed directly from the even layers.
    f_sample2   =   theano.function(inputs = network_state_input, outputs = network_state_output + visible_pX_chain, on_unused_input='warn')

    def sample_some_numbers_single_layer():
        x0    =   test_X.get_value()[:1]
        samples = [x0]
        x  =   f_noise(x0)
        for i in range(399):
            x = f_sample_simple(x)
            samples.append(x)
            x = numpy.random.binomial(n=1, p=x, size=x.shape).astype('float32')
            x = f_noise(x)
        return numpy.vstack(samples)
            
    def sampling_wrapper(NSI):
        # * is the "splat" operator: It takes a list as input, and expands it into actual positional arguments in the function call.
        out             =   f_sample2(*NSI)
        NSO             =   out[:len(network_state_output)]
        vis_pX_chain    =   out[len(network_state_output):]
        return NSO, vis_pX_chain

    def sample_some_numbers(N=400):
        # The network's initial state
        init_vis        =   test_X.get_value()[:1]

        noisy_init_vis  =   f_noise(init_vis)

        network_state   =   [[noisy_init_vis] + [numpy.zeros((1,len(b.get_value())), dtype='float32') for b in bias_list[1:]]]

        visible_chain   =   [init_vis]

        noisy_h0_chain  =   [noisy_init_vis]

        for i in range(N-1):
           
            # feed the last state into the network, compute new state, and obtain visible units expectation chain 
            net_state_out, vis_pX_chain =   sampling_wrapper(network_state[-1])

            # append to the visible chain
            visible_chain   +=  vis_pX_chain

            # append state output to the network state chain
            network_state.append(net_state_out)
            
            noisy_h0_chain.append(net_state_out[0])

        return numpy.vstack(visible_chain), numpy.vstack(noisy_h0_chain)
    
    def plot_samples(epoch_number, iteration):
        to_sample = time.time()
        if layers == 1:
            # one layer model
            V = sample_some_numbers_single_layer()
        else:
            V, H0 = sample_some_numbers()
        img_samples =   PIL.Image.fromarray(tile_raster_images(V, (root_N_input,root_N_input), (20,20)))
        
        fname       =   outdir+'samples_iteration_'+str(iteration)+'_epoch_'+str(epoch_number)+'.png'
        img_samples.save(fname) 
        print 'Took ' + str(time.time() - to_sample) + ' to sample 400 numbers'
   
    ##############
    # Inpainting #
    ##############
    def inpainting(digit):
        # The network's initial state

        # NOISE INIT
        init_vis    =   cast32(numpy.random.uniform(size=digit.shape))

        #noisy_init_vis  =   f_noise(init_vis)
        #noisy_init_vis  =   cast32(numpy.random.uniform(size=init_vis.shape))

        # INDEXES FOR VISIBLE AND NOISY PART
        noise_idx = (numpy.arange(N_input) % root_N_input < (root_N_input/2))
        fixed_idx = (numpy.arange(N_input) % root_N_input > (root_N_input/2))
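        # For 28x28 MNIST (root_N_input = 28) this splits each row by column index:
        # columns 0-13 are treated as noise and re-sampled, while columns 15-27 stay
        # clamped to the digit; as written, the middle column (14) falls in neither mask.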
        # function to re-init the visible to the same noise

        # FUNCTION TO RESET HALF VISIBLE TO DIGIT
        def reset_vis(V):
            V[0][fixed_idx] =   digit[0][fixed_idx]
            return V
        
        # INIT DIGIT : NOISE and RESET HALF TO DIGIT
        init_vis = reset_vis(init_vis)

        network_state   =   [[init_vis] + [numpy.zeros((1,len(b.get_value())), dtype='float32') for b in bias_list[1:]]]

        visible_chain   =   [init_vis]

        noisy_h0_chain  =   [init_vis]

        for i in range(49):
           
            # feed the last state into the network, compute new state, and obtain visible units expectation chain 
            net_state_out, vis_pX_chain =   sampling_wrapper(network_state[-1])


            # reset half the digit
            net_state_out[0] = reset_vis(net_state_out[0])
            vis_pX_chain[0]  = reset_vis(vis_pX_chain[0])

            # append to the visible chain
            visible_chain   +=  vis_pX_chain

            # append state output to the network state chain
            network_state.append(net_state_out)
            
            noisy_h0_chain.append(net_state_out[0])

        return numpy.vstack(visible_chain), numpy.vstack(noisy_h0_chain)

    def save_params_to_file(name, n, params, iteration):
        print 'saving parameters...'
        save_path = outdir+name+'_params_iteration_'+str(iteration)+'_epoch_'+str(n)+'.pkl'
        f = open(save_path, 'wb')
        try:
            cPickle.dump(params, f, protocol=cPickle.HIGHEST_PROTOCOL)
        finally:
            f.close() 


    ################
    # GSN TRAINING #
    ################
    def train_GSN(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        print '----------------------------------------'
        print 'TRAINING GSN FOR ITERATION',iteration
        with open(logfile,'a') as f:
            f.write("--------------------------\nTRAINING GSN FOR ITERATION {0!s}\n".format(iteration))
        
        # TRAINING
        n_epoch     =   state.n_epoch
        batch_size  =   state.batch_size
        STOP        =   False
        counter     =   0
        if iteration == 0:
            learning_rate.set_value(cast32(state.learning_rate))  # learning rate
        times = []
        best_cost = float('inf')
        patience = 0
            
        print 'learning rate:',learning_rate.get_value()
        
        print 'train X size:',str(train_X.shape.eval())
        print 'valid X size:',str(valid_X.shape.eval())
        print 'test X size:',str(test_X.shape.eval())
    
        pre_train_costs =   []
        pre_valid_costs =   []
        pre_test_costs  =   []
        post_train_costs =   []
        post_valid_costs =   []
        post_test_costs  =   []
        
        if state.vis_init:
            bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0),0.001,0.9)))
    
        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            print 'Testing : skip training'
            STOP    =   True
    
    
        while not STOP:
            counter += 1
            t = time.time()
            print counter,'\t',
            with open(logfile,'a') as f:
                f.write("{0!s}\t".format(counter))
            
            #shuffle the data
            data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
            
            #train
            #init recurrent hiddens as zero
            recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
            pre_train_cost = []
            post_train_cost = []
            if iteration == 0:
                for i in range(len(train_X.get_value(borrow=True)) / batch_size):
                    x = train_X.get_value()[i * batch_size : (i+1) * batch_size]
                    pre = f_learn_init(x)
                    pre_train_cost.append(pre)
                    post_train_cost.append(-1)
            else:
                for i in range(len(train_X.get_value(borrow=True)) / batch_size):
                    x = train_X.get_value()[i * batch_size : (i+1) * batch_size]
                    x1 = train_X.get_value()[(i * batch_size) + 1 : ((i+1) * batch_size) + 1]
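                    # x1 is the same batch shifted forward by one frame, so each x[t]
                    # is paired with its successor x[t+1]; the recurrent part learns to
                    # predict the hidden state that reconstructs x1.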
                    [x,x1], recurrent_hiddens = fix_input_size([x,x1], recurrent_hiddens)
                    _ins = recurrent_hiddens + [x,x1]
                    _outs = f_learn(*_ins)
                    recurrent_hiddens = _outs[:len(recurrent_hiddens)]
                    pre = _outs[-2]
                    post = _outs[-1]
                    pre_train_cost.append(pre)
                    post_train_cost.append(post)
                    
                                    
            pre_train_cost = numpy.mean(pre_train_cost) 
            pre_train_costs.append(pre_train_cost)
            post_train_cost = numpy.mean(post_train_cost) 
            post_train_costs.append(post_train_cost)
            print 'Train : ',trunc(pre_train_cost),trunc(post_train_cost), '\t',
            with open(logfile,'a') as f:
                f.write("Train : {0!s} {1!s}\t".format(trunc(pre_train_cost),trunc(post_train_cost)))
            with open(train_convergence_pre,'a') as f:
                f.write("{0!s},".format(pre_train_cost))
            with open(train_convergence_post,'a') as f:
                f.write("{0!s},".format(post_train_cost))
    
            #valid
            #init recurrent hiddens as zero
            recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
            pre_valid_cost  =   []    
            post_valid_cost  =  []
            for i in range(len(valid_X.get_value(borrow=True)) / batch_size):
                x = valid_X.get_value()[i * batch_size : (i+1) * batch_size]
                x1 = valid_X.get_value()[(i * batch_size) + 1 : ((i+1) * batch_size) + 1]
                [x,x1], recurrent_hiddens = fix_input_size([x,x1], recurrent_hiddens)
                _ins = recurrent_hiddens + [x,x1]
                _outs = f_cost(*_ins)
                recurrent_hiddens = _outs[:len(recurrent_hiddens)]
                pre = _outs[-2]
                post = _outs[-1]
                pre_valid_cost.append(pre)
                post_valid_cost.append(post)
                    
            pre_valid_cost = numpy.mean(pre_valid_cost) 
            pre_valid_costs.append(pre_valid_cost)
            post_valid_cost = numpy.mean(post_valid_cost) 
            post_valid_costs.append(post_valid_cost)
            print 'Valid : ', trunc(pre_valid_cost),trunc(post_valid_cost), '\t',
            with open(logfile,'a') as f:
                f.write("Valid : {0!s} {1!s}\t".format(trunc(pre_valid_cost),trunc(post_valid_cost)))
            with open(valid_convergence_pre,'a') as f:
                f.write("{0!s},".format(pre_valid_cost))
            with open(valid_convergence_post,'a') as f:
                f.write("{0!s},".format(post_valid_cost))
    
            #test
            #init recurrent hiddens as zero
            recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
            pre_test_cost  =   []
            post_test_cost  =   []
            for i in range(len(test_X.get_value(borrow=True)) / batch_size):
                x = test_X.get_value()[i * batch_size : (i+1) * batch_size]
                x1 = test_X.get_value()[(i * batch_size) + 1 : ((i+1) * batch_size) + 1]
                [x,x1], recurrent_hiddens = fix_input_size([x,x1], recurrent_hiddens)
                _ins = recurrent_hiddens + [x,x1]
                _outs = f_cost(*_ins)
                recurrent_hiddens = _outs[:len(recurrent_hiddens)]
                pre = _outs[-2]
                post = _outs[-1]
                pre_test_cost.append(pre)
                post_test_cost.append(post)
                
            pre_test_cost = numpy.mean(pre_test_cost) 
            pre_test_costs.append(pre_test_cost)
            post_test_cost = numpy.mean(post_test_cost) 
            post_test_costs.append(post_test_cost)
            print 'Test  : ', trunc(pre_test_cost),trunc(post_test_cost), '\t',
            with open(logfile,'a') as f:
                f.write("Test : {0!s} {1!s}\t".format(trunc(pre_test_cost),trunc(post_test_cost)))
            with open(test_convergence_pre,'a') as f:
                f.write("{0!s},".format(pre_test_cost))
            with open(test_convergence_post,'a') as f:
                f.write("{0!s},".format(post_test_cost))
    
            #check for early stopping
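            # early stopping: the patience counter resets whenever the cost drops below
            # best_cost * state.early_stop_threshold, otherwise it grows until the limit is hit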
            cost = pre_train_cost
            if iteration != 0:
                cost = cost + post_train_cost
            if cost < best_cost*state.early_stop_threshold:
                patience = 0
                best_cost = cost
            else:
                patience += 1
    
            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                save_params_to_file('gsn', counter, params, iteration)
                print "next learning rate should be", learning_rate.get_value() * annealing
                
            timing = time.time() - t
            times.append(timing)
    
            print 'time : ', trunc(timing),
            
            print 'remaining: ', trunc((n_epoch - counter) * numpy.mean(times) / 60 / 60), 'hrs'
                    
            if (counter % state.save_frequency) == 0:
                # Checking reconstruction
                nums = test_X.get_value()[range(100)]
                noisy_nums = f_noise(test_X.get_value()[range(100)])
                reconstructed = []
                reconstructed_prediction = []
                #init recurrent hiddens as zero
                recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
                for num in noisy_nums:
                    _ins = recurrent_hiddens + [num]
                    _outs = f_recon(*_ins)
                    recurrent_hiddens = _outs[:len(recurrent_hiddens)]
                    [recon,recon_pred] = _outs[len(recurrent_hiddens):]
                    reconstructed.append(recon)
                    reconstructed_prediction.append(recon_pred)
                # Concatenate stuff
                stacked = numpy.vstack([numpy.vstack([nums[i*10 : (i+1)*10], noisy_nums[i*10 : (i+1)*10], reconstructed[i*10 : (i+1)*10], reconstructed_prediction[i*10 : (i+1)*10]]) for i in range(10)])
            
                number_reconstruction   =   PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,40)))
                #epoch_number    =   reduce(lambda x,y : x + y, ['_'] * (4-len(str(counter)))) + str(counter)
                number_reconstruction.save(outdir+'gsn_number_reconstruction_iteration_'+str(iteration)+'_epoch_'+str(counter)+'.png')
        
                #sample_numbers(counter, 'seven')
                plot_samples(counter, iteration)
        
                #save params
                save_params_to_file('gsn', counter, params, iteration)
         
            # ANNEAL!
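            # exponential decay: multiply the learning rate by the annealing coefficient once per epoch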
            new_lr = learning_rate.get_value() * annealing
            learning_rate.set_value(new_lr)
    
        # training finished (or skipped when only testing): generate samples
    
        # 10k samples
        print 'Generating 10,000 samples'
        samples, _  =   sample_some_numbers(N=10000)
        f_samples   =   outdir+'samples.npy'
        numpy.save(f_samples, samples)
        print 'saved digits'
    
    
        # parzen
#         print 'Evaluating parzen window'
#         import likelihood_estimation_parzen
#         likelihood_estimation_parzen.main(0.20,'mnist') 
    
        # Inpainting
        '''
        print 'Inpainting'
        test_X  =   test_X.get_value()
    
        numpy.random.seed(2)
        test_idx    =   numpy.arange(len(test_Y.get_value(borrow=True)))
    
        for Iter in range(10):
    
            numpy.random.shuffle(test_idx)
            test_X = test_X[test_idx]
            test_Y = test_Y[test_idx]
    
            digit_idx = [(test_Y==i).argmax() for i in range(10)]
            inpaint_list = []
    
            for idx in digit_idx:
                DIGIT = test_X[idx:idx+1]
                V_inpaint, H_inpaint = inpainting(DIGIT)
                inpaint_list.append(V_inpaint)
    
            INPAINTING  =   numpy.vstack(inpaint_list)
    
            plot_inpainting =   PIL.Image.fromarray(tile_raster_images(INPAINTING, (root_N_input,root_N_input), (10,50)))
    
            fname   =   'inpainting_'+str(Iter)+'_iteration_'+str(iteration)+'.png'
            #fname   =   os.path.join(state.model_path, fname)
    
            plot_inpainting.save(fname)
    '''        
            
            
            
            
    def train_regression(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        print '-------------------------------------------'
        print 'TRAINING RECURRENT REGRESSION FOR ITERATION',iteration
        with open(logfile,'a') as f:
            f.write("--------------------------\nTRAINING RECURRENT REGRESSION FOR ITERATION {0!s}\n".format(iteration))
        
        # TRAINING
        n_epoch     =   state.n_epoch
        batch_size  =   state.batch_size
        STOP        =   False
        counter     =   0
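        # reset the recurrent learning rate only on the very first iteration so that
        # annealing carries over between successive calls to train_regression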
        if iteration == 0:
            recurrent_learning_rate.set_value(cast32(state.learning_rate))  # learning rate
        times = []
        best_cost = float('inf')
        patience = 0
            
        print 'learning rate:',recurrent_learning_rate.get_value()
        
        print 'train X size:',str(train_X.shape.eval())
        print 'valid X size:',str(valid_X.shape.eval())
        print 'test X size:',str(test_X.shape.eval())
    
        train_costs =   []
        valid_costs =   []
        test_costs  =   []
        
        if state.vis_init:
            # clip the mean to [0.001, 0.9] before the logit so it stays finite
            bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))
    
        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            print 'Testing : skip training'
            STOP    =   True
    
    
        while not STOP:
            counter += 1
            t = time.time()
            print counter,'\t',
            
            #shuffle the data
            data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
            
            #train
            #init recurrent hiddens as zero
            recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
            train_cost = []
            for i in range(len(train_X.get_value(borrow=True)) / batch_size):
                x = train_X.get_value()[i * batch_size : (i+1) * batch_size]
                x1 = train_X.get_value()[(i * batch_size) + 1 : ((i+1) * batch_size) + 1]
                [x,x1], recurrent_hiddens = fix_input_size([x,x1], recurrent_hiddens)
                _ins = recurrent_hiddens + [x,x1]
                _outs = recurrent_f_learn(*_ins)
                recurrent_hiddens = _outs[:len(recurrent_hiddens)]
                cost = _outs[-1]
                train_cost.append(cost)
                
            train_cost = numpy.mean(train_cost) 
            train_costs.append(train_cost)
            print 'rTrain : ',trunc(train_cost), '\t',
            with open(logfile,'a') as f:
                f.write("rTrain : {0!s}\t".format(trunc(train_cost)))
            with open(recurrent_train_convergence,'a') as f:
                f.write("{0!s},".format(train_cost))
    
            #valid
            #init recurrent hiddens as zero
            recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
            valid_cost  =  []
            for i in range(len(valid_X.get_value(borrow=True)) / batch_size):
                x = valid_X.get_value()[i * batch_size : (i+1) * batch_size]
                x1 = valid_X.get_value()[(i * batch_size) + 1 : ((i+1) * batch_size) + 1]
                [x,x1], recurrent_hiddens = fix_input_size([x,x1], recurrent_hiddens)
                _ins = recurrent_hiddens + [x,x1]
                _outs = f_cost(*_ins)
                recurrent_hiddens = _outs[:len(recurrent_hiddens)]
                cost = _outs[-1]
                valid_cost.append(cost)
                    
            valid_cost = numpy.mean(valid_cost) 
            valid_costs.append(valid_cost)
            print 'rValid : ', trunc(valid_cost), '\t',
            with open(logfile,'a') as f:
                f.write("rValid : {0!s}\t".format(trunc(valid_cost)))
            with open(recurrent_valid_convergence,'a') as f:
                f.write("{0!s},".format(valid_cost))
    
            #test
            recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
            test_cost  =   []
            for i in range(len(test_X.get_value(borrow=True)) / batch_size):
                x = test_X.get_value()[i * batch_size : (i+1) * batch_size]
                x1 = test_X.get_value()[(i * batch_size) + 1 : ((i+1) * batch_size) + 1]
                [x,x1], recurrent_hiddens = fix_input_size([x,x1], recurrent_hiddens)
                _ins = recurrent_hiddens + [x,x1]
                _outs = f_cost(*_ins)
                recurrent_hiddens = _outs[:len(recurrent_hiddens)]
                cost = _outs[-1]
                test_cost.append(cost)
                
            test_cost = numpy.mean(test_cost) 
            test_costs.append(test_cost)
            print 'rTest  : ', trunc(test_cost), '\t',
            with open(logfile,'a') as f:
                f.write("rTest : {0!s}\t".format(trunc(test_cost)))
            with open(recurrent_test_convergence,'a') as f:
                f.write("{0!s},".format(test_cost))
    
            #check for early stopping
            # only a single cost here (unlike the GSN loop, which sums pre- and post-costs)
            cost = train_cost
            if cost < best_cost*state.early_stop_threshold:
                patience = 0
                best_cost = cost
            else:
                patience += 1

            # stop when out of epochs or patience runs out (mirrors the GSN training loop above)
            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                
            timing = time.time() - t
            times.append(timing)
    
            print 'time : ', trunc(timing),
            
            print 'remaining: ', trunc((n_epoch - counter) * numpy.mean(times) / 60 / 60), 'hrs'
            
            with open(logfile,'a') as f:
                f.write("B : {0!s}\t".format(str([trunc(vb.get_value().mean()) for vb in recurrent_bias_list])))
                
            with open(logfile,'a') as f:
                f.write("W : {0!s}\t".format(str([trunc(abs(v.get_value(borrow=True)).mean()) for v in recurrent_weights_list_encode])))
            
            with open(logfile,'a') as f:
                f.write("V : {0!s}\t".format(str([trunc(abs(v.get_value(borrow=True)).mean()) for v in recurrent_weights_list_decode])))
                
            with open(logfile,'a') as f:
                f.write("Time : {0!s} seconds\n".format(trunc(timing)))
                    
            if (counter % state.save_frequency) == 0:
                # Checking reconstruction
                nums = test_X.get_value()[range(100)]
                noisy_nums = f_noise(test_X.get_value()[range(100)])
                reconstructed = []
                reconstructed_prediction = []
                #init recurrent hiddens as zero
                recurrent_hiddens = [T.zeros((batch_size,recurrent_layer_size)).eval() for recurrent_layer_size in recurrent_layer_sizes]
                for num in noisy_nums:
                    _ins = recurrent_hiddens + [num]
                    _outs = f_recon(*_ins)
                    recurrent_hiddens = _outs[:len(recurrent_hiddens)]
                    [recon,recon_pred] = _outs[len(recurrent_hiddens):]
                    reconstructed.append(recon)
                    reconstructed_prediction.append(recon_pred)
                # Concatenate stuff
                stacked = numpy.vstack([numpy.vstack([nums[i*10 : (i+1)*10], noisy_nums[i*10 : (i+1)*10], reconstructed[i*10 : (i+1)*10], reconstructed_prediction[i*10 : (i+1)*10]]) for i in range(10)])
                
                number_reconstruction   =   PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,40)))
                #epoch_number    =   reduce(lambda x,y : x + y, ['_'] * (4-len(str(counter)))) + str(counter)
                number_reconstruction.save(outdir+'recurrent_number_reconstruction_iteration_'+str(iteration)+'_epoch_'+str(counter)+'.png')
        
                #sample_numbers(counter, 'seven')
                plot_samples(counter, iteration)
        
                #save params
                save_params_to_file('recurrent', counter, params, iteration)
         
            # ANNEAL!
            new_r_lr = recurrent_learning_rate.get_value() * annealing
            recurrent_learning_rate.set_value(new_r_lr)
    
        # training finished (or skipped when only testing): generate samples
    
        # 10k samples
        print 'Generating 10,000 samples'
        samples, _  =   sample_some_numbers(N=10000)
        f_samples   =   outdir+'samples.npy'
        numpy.save(f_samples, samples)
        print 'saved digits'
            
            
            
            
    #####################
    # STORY 3 ALGORITHM #
    #####################
    for iteration in range(state.max_iterations):
        train_GSN(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
        train_regression(iteration, train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
    def train(self, batch_size=100, num_epochs=300):
        '''Train the RNN-RBM via stochastic gradient descent (SGD) on the
sequenced MNIST dataset.

batch_size : integer
  Training sequences will be split into subsequences of at most this size
  before applying the SGD updates.
num_epochs : integer
  Number of epochs (passes over the training set) performed. The user can
  safely interrupt training with Ctrl+C at any time.'''

        (train_X,
         train_Y), (valid_X,
                    valid_Y), (test_X,
                               test_Y) = data.load_mnist("../datasets/")
        train_X = numpy.concatenate((train_X, valid_X))
        train_Y = numpy.concatenate((train_Y, valid_Y))

        print 'Sequencing MNIST data...'
        print 'train set size:', train_X.shape
        print 'valid set size:', valid_X.shape
        print 'test set size:', test_X.shape

        train_X = theano.shared(train_X)
        train_Y = theano.shared(train_Y)
        valid_X = theano.shared(valid_X)
        valid_Y = theano.shared(valid_Y)
        test_X = theano.shared(test_X)
        test_Y = theano.shared(test_Y)

        data.sequence_mnist_data(train_X,
                                 train_Y,
                                 valid_X,
                                 valid_Y,
                                 test_X,
                                 test_Y,
                                 dataset=4)

        print 'train set size:', train_X.shape.eval()
        print 'valid set size:', valid_X.shape.eval()
        print 'test set size:', test_X.shape.eval()
        print 'Sequencing done.'
        print

        N_input = train_X.eval().shape[1]
        self.root_N_input = numpy.sqrt(N_input)
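        # for MNIST, N_input is 784, so root_N_input is the image side length (28)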

        times = []

        try:
            for epoch in xrange(num_epochs):
                t = time.time()
                print 'Epoch %i/%i : ' % (epoch + 1, num_epochs)
                # sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
                accuracy = []
                costs = []
                crossentropy = []
                tests = []
                test_acc = []

                for i in range(
                        len(train_X.get_value(borrow=True)) / batch_size):
                    t0 = time.time()
                    xs = train_X.get_value(
                        borrow=True)[(i * batch_size):((i + 1) * batch_size)]
                    acc, cost, cross = self.train_function(xs)
                    accuracy.append(acc)
                    costs.append(cost)
                    crossentropy.append(cross)
                    print time.time() - t0

                print 'Train', numpy.mean(accuracy), 'cost', numpy.mean(
                    costs), 'cross', numpy.mean(crossentropy),

                for i in range(
                        len(test_X.get_value(borrow=True)) / batch_size):
                    xs = test_X.get_value(
                        borrow=True)[(i * batch_size):((i + 1) * batch_size)]
                    acc, cost = self.test_function(xs)
                    test_acc.append(acc)
                    tests.append(cost)

                print '\t Test_acc', numpy.mean(test_acc), "cross", numpy.mean(
                    tests)

                timing = time.time() - t
                times.append(timing)
                print 'time : ', trunc(timing),
                print 'remaining: ', (
                    num_epochs -
                    (epoch + 1)) * numpy.mean(times) / 60 / 60, 'hrs'
                sys.stdout.flush()

                #new learning rate
                new_lr = self.lr.get_value() * self.annealing
                self.lr.set_value(new_lr)

        except KeyboardInterrupt:
            print 'Interrupted by user.'
def experiment(state, outdir_base='./'):
    rng.seed(1) #seed the numpy random generator  
    R.seed(1) #seed the other random generator (for reconstruction function indices)
    # Initialize output directory and files
    data.mkdir_p(outdir_base)
    outdir = outdir_base + "/" + state.dataset + "/"
    data.mkdir_p(outdir)
    logger = Logger(outdir)
    logger.log("----------MODEL 2, {0!s}-----------\n".format(state.dataset))
    if state.initialize_gsn:
        gsn_train_convergence = outdir+"gsn_train_convergence.csv"
        gsn_valid_convergence = outdir+"gsn_valid_convergence.csv"
        gsn_test_convergence  = outdir+"gsn_test_convergence.csv"
    train_convergence = outdir+"train_convergence.csv"
    valid_convergence = outdir+"valid_convergence.csv"
    test_convergence  = outdir+"test_convergence.csv"
    if state.initialize_gsn:
        init_empty_file(gsn_train_convergence)
        init_empty_file(gsn_valid_convergence)
        init_empty_file(gsn_test_convergence)
    init_empty_file(train_convergence)
    init_empty_file(valid_convergence)
    init_empty_file(test_convergence)
    
    #load parameters from config file if this is a test
    config_filename = outdir+'config'
    if state.test_model and 'config' in os.listdir(outdir):
        config_vals = load_from_config(config_filename)
        for CV in config_vals:
            logger.log(CV)
            if CV.startswith('test'):
                logger.log('Do not override testing switch')
                continue        
            try:
                exec('state.'+CV) in globals(), locals()
            except:
                exec('state.'+CV.split('=')[0]+"='"+CV.split('=')[1]+"'") in globals(), locals()
    else:
        # Save the current configuration
        # Useful for logs/experiments
        logger.log('Saving config')
        with open(config_filename, 'w') as f:
            f.write(str(state))

    logger.log(state)
    
    ####################################################
    # Load the data, train = train+valid, and sequence #
    ####################################################
    artificial = False
    if state.dataset == 'MNIST_1' or state.dataset == 'MNIST_2' or state.dataset == 'MNIST_3':
        (train_X, train_Y), (valid_X, valid_Y), (test_X, test_Y) = data.load_mnist(state.data_path)
        train_X = numpy.concatenate((train_X, valid_X))
        train_Y = numpy.concatenate((train_Y, valid_Y))
        artificial = True
        try:
            dataset = int(state.dataset.split('_')[1])
        except:
            logger.log("ERROR: artificial dataset number not recognized. Input was "+str(state.dataset))
            raise AssertionError("artificial dataset number not recognized. Input was "+str(state.dataset))
    else:
        logger.log("ERROR: dataset not recognized.")
        raise AssertionError("dataset not recognized.")
    
    # transfer the datasets into theano shared variables
    train_X, train_Y = data.shared_dataset((train_X, train_Y), borrow=True)
    valid_X, valid_Y = data.shared_dataset((valid_X, valid_Y), borrow=True)
    test_X, test_Y   = data.shared_dataset((test_X, test_Y), borrow=True)
   
    if artificial:
        logger.log('Sequencing MNIST data...')
        logger.log(['train set size:',len(train_Y.eval())])
        logger.log(['valid set size:',len(valid_Y.eval())])
        logger.log(['test set size:',len(test_Y.eval())])
        data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
        logger.log(['train set size:',len(train_Y.eval())])
        logger.log(['valid set size:',len(valid_Y.eval())])
        logger.log(['test set size:',len(test_Y.eval())])
        logger.log('Sequencing done.\n')
    
    
    N_input =   train_X.eval().shape[1]
    root_N_input = numpy.sqrt(N_input)  
    
    # Network and training specifications
    layers      = state.layers # number hidden layers
    walkbacks   = state.walkbacks # number of walkbacks 
    layer_sizes = [N_input] + [state.hidden_size] * layers # layer sizes, from h0 to hK (h0 is the visible layer)
    
    learning_rate = theano.shared(cast32(state.learning_rate))  # learning rate
    annealing     = cast32(state.annealing) # exponential annealing coefficient
    momentum      = theano.shared(cast32(state.momentum)) # momentum term 

    ##############
    # PARAMETERS #
    ##############
    #gsn
    weights_list = [get_shared_weights(layer_sizes[i], layer_sizes[i+1], name="W_{0!s}_{1!s}".format(i,i+1)) for i in range(layers)] # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
    bias_list    = [get_shared_bias(layer_sizes[i], name='b_'+str(i)) for i in range(layers + 1)] # initialize each layer to 0's.
    
    #recurrent
    recurrent_to_gsn_bias_weights_list = [get_shared_weights(state.recurrent_hidden_size, layer_sizes[layer], name="W_u_b{0!s}".format(layer)) for layer in range(layers+1)]
    W_u_u = get_shared_weights(state.recurrent_hidden_size, state.recurrent_hidden_size, name="W_u_u")
    W_x_u = get_shared_weights(N_input, state.recurrent_hidden_size, name="W_x_u")
    recurrent_bias = get_shared_bias(state.recurrent_hidden_size, name='b_u')
    
    #lists for use with gradients
    gsn_params = weights_list + bias_list
    u_params   = [W_u_u, W_x_u, recurrent_bias]
    params     = gsn_params + recurrent_to_gsn_bias_weights_list + u_params
    
    ###########################################################
    #           load initial parameters of gsn                #
    ###########################################################
    train_gsn_first = False
    if state.initialize_gsn:
        params_to_load = 'gsn_params.pkl'
        if not os.path.isfile(params_to_load):
            train_gsn_first = True 
        else:
            logger.log("\nLoading existing GSN parameters\n")
            loaded_params = cPickle.load(open(params_to_load,'r'))
            [p.set_value(lp.get_value(borrow=False)) for lp, p in zip(loaded_params[:len(weights_list)], weights_list)]
            [p.set_value(lp.get_value(borrow=False)) for lp, p in zip(loaded_params[len(weights_list):], bias_list)]
    
    
    ############################
    # Theano variables and RNG #
    ############################
    MRG = RNG_MRG.MRG_RandomStreams(1)
    X = T.fmatrix('X') #single (batch) for training gsn
    Xs = T.fmatrix(name="Xs") #sequence for training rnn-gsn
    
 
    ########################
    # ACTIVATION FUNCTIONS #
    ########################
    # hidden activation
    if state.hidden_act == 'sigmoid':
        logger.log('Using sigmoid activation for hiddens')
        hidden_activation = T.nnet.sigmoid
    elif state.hidden_act == 'rectifier':
        logger.log('Using rectifier activation for hiddens')
        hidden_activation = lambda x : T.maximum(cast32(0), x)
    elif state.hidden_act == 'tanh':
        logger.log('Using hyperbolic tangent activation for hiddens')
        hidden_activation = lambda x : T.tanh(x)
    else:
        logger.log("ERROR: Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid".format(state.hidden_act))
        raise NotImplementedError("Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid".format(state.hidden_act))
    
    # visible activation
    if state.visible_act == 'sigmoid':
        logger.log('Using sigmoid activation for visible layer')
        visible_activation = T.nnet.sigmoid
    elif state.visible_act == 'softmax':
        logger.log('Using softmax activation for visible layer')
        visible_activation = T.nnet.softmax
    else:
        logger.log("ERROR: Did not recognize visible activation {0!s}, please use sigmoid or softmax".format(state.visible_act))
        raise NotImplementedError("Did not recognize visible activation {0!s}, please use sigmoid or softmax".format(state.visible_act))
    
    # recurrent activation
    if state.recurrent_hidden_act == 'sigmoid':
        logger.log('Using sigmoid activation for recurrent hiddens')
        recurrent_hidden_activation = T.nnet.sigmoid
    elif state.recurrent_hidden_act == 'rectifier':
        logger.log('Using rectifier activation for recurrent hiddens')
        recurrent_hidden_activation = lambda x : T.maximum(cast32(0), x)
    elif state.recurrent_hidden_act == 'tanh':
        logger.log('Using hyperbolic tangent activation for recurrent hiddens')
        recurrent_hidden_activation = lambda x : T.tanh(x)
    else:
        logger.log("ERROR: Did not recognize recurrent hidden activation {0!s}, please use tanh, rectifier, or sigmoid".format(state.recurrent_hidden_act))
        raise NotImplementedError("Did not recognize recurrent hidden activation {0!s}, please use tanh, rectifier, or sigmoid".format(state.recurrent_hidden_act))
    
    logger.log("\n")
    
    ####################
    #  COST FUNCTIONS  #
    ####################
    if state.cost_funct == 'binary_crossentropy':
        logger.log('Using binary cross-entropy cost!')
        cost_function = lambda x,y: T.mean(T.nnet.binary_crossentropy(x,y))
    elif state.cost_funct == 'square':
        logger.log("Using square error cost!")
        #cost_function = lambda x,y: T.log(T.mean(T.sqr(x-y)))
        cost_function = lambda x,y: T.log(T.sum(T.pow((x-y),2)))
    else:
        logger.log("ERROR: Did not recognize cost function {0!s}, please use binary_crossentropy or square".format(state.cost_funct))
        raise NotImplementedError("Did not recognize cost function {0!s}, please use binary_crossentropy or square".format(state.cost_funct))
    
    logger.log("\n")  
        
    
    ##############################################
    #    Build the training graph for the GSN    #
    ##############################################       
    if train_gsn_first:
        '''Build the actual gsn training graph'''
        p_X_chain_gsn, _ = generative_stochastic_network.build_gsn(X,
                                      weights_list,
                                      bias_list,
                                      True,
                                      state.noiseless_h1,
                                      state.hidden_add_noise_sigma,
                                      state.input_salt_and_pepper,
                                      state.input_sampling,
                                      MRG,
                                      visible_activation,
                                      hidden_activation,
                                      walkbacks,
                                      logger)
        # now without noise
        p_X_chain_gsn_recon, _ = generative_stochastic_network.build_gsn(X,
                                      weights_list,
                                      bias_list,
                                      False,
                                      state.noiseless_h1,
                                      state.hidden_add_noise_sigma,
                                      state.input_salt_and_pepper,
                                      state.input_sampling,
                                      MRG,
                                      visible_activation,
                                      hidden_activation,
                                      walkbacks,
                                      logger)
    
    ##############################################
    #  Build the training graph for the RNN-GSN  #
    ##############################################
    # Given x_t, the deterministic recurrence computes u_t and the GSN biases for step t.
    # Generation (x_t is None) would first sample x_t from the model, but is not implemented here.
    def recurrent_step(x_t, u_tm1):
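        # bv_t / bh_t: visible- and hidden-layer GSN biases for step t, conditioned on the
        # previous recurrent state u_tm1 through the recurrent-to-GSN weight matrices;
        # u_t: the new recurrent hidden state computed from x_t and u_tm1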
        bv_t = bias_list[0] + T.dot(u_tm1, recurrent_to_gsn_bias_weights_list[0])
        bh_t = T.concatenate([bias_list[i+1] + T.dot(u_tm1, recurrent_to_gsn_bias_weights_list[i+1]) for i in range(layers)],axis=0)
        generate = x_t is None
        if generate:
            pass
        ua_t = T.dot(x_t, W_x_u) + T.dot(u_tm1, W_u_u) + recurrent_bias
        u_t = recurrent_hidden_activation(ua_t)
        return None if generate else [ua_t, u_t, bv_t, bh_t]
    
    logger.log("\nCreating recurrent step scan.")
    # For training, the deterministic recurrence is used to compute all the
    # {h_t, 1 <= t <= T} given Xs. Conditional GSNs can then be trained
    # in batches using those parameters.
    u0 = T.zeros((state.recurrent_hidden_size,))  # initial value for the RNN hidden units
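    # scan the deterministic recurrence over every frame of Xs, threading u_t through
    # outputs_info; the per-step biases (bv_t, bh_t) are collected for the GSN graph below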
    (ua, u, bv_t, bh_t), updates_recurrent = theano.scan(fn=lambda x_t, u_tm1, *_: recurrent_step(x_t, u_tm1),
                                                       sequences=Xs,
                                                       outputs_info=[None, u0, None, None],
                                                       non_sequences=params)
    # put the bias_list together from hiddens and visible biases
    #b_list = [bv_t.flatten(2)] + [bh_t.dimshuffle((1,0,2))[i] for i in range(len(weights_list))]
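    # bh_t comes out of the scan with all hidden-layer biases concatenated along the feature
    # axis, so slice it back into one block of width state.hidden_size per layer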
    b_list = [bv_t] + [(bh_t.T[i*state.hidden_size:(i+1)*state.hidden_size]).T for i in range(layers)]
    
    _, cost, show_cost = generative_stochastic_network.build_gsn_scan(Xs, weights_list, b_list, True, state.noiseless_h1, state.hidden_add_noise_sigma, state.input_salt_and_pepper, state.input_sampling, MRG, visible_activation, hidden_activation, walkbacks, cost_function, logger)
    x_sample_recon, _, _ = generative_stochastic_network.build_gsn_scan(Xs, weights_list, b_list, False, state.noiseless_h1, state.hidden_add_noise_sigma, state.input_salt_and_pepper, state.input_sampling, MRG, visible_activation, hidden_activation, walkbacks, cost_function, logger)
    
    updates_train = updates_recurrent
    #updates_train.update(updates_gsn)
    updates_cost = updates_recurrent
    
    #updates_recon = updates_recurrent
    #updates_recon.update(updates_gsn_recon)
        

    #############
    #   COSTS   #
    #############
    logger.log("")    
    logger.log('Cost w.r.t p(X|...) at every step in the graph')
    
    if train_gsn_first:
        gsn_costs     = [cost_function(rX, X) for rX in p_X_chain_gsn]
        gsn_show_cost = gsn_costs[-1]
        gsn_cost      = numpy.sum(gsn_costs)
            

    ###################################
    # GRADIENTS AND FUNCTIONS FOR GSN #
    ###################################
    logger.log(["params:",params])
    
    logger.log("creating functions...")
    start_functions_time = time.time()
    
    if train_gsn_first:
        gradient_gsn        = T.grad(gsn_cost, gsn_params)      
        gradient_buffer_gsn = [theano.shared(numpy.zeros(param.get_value().shape, dtype='float32')) for param in gsn_params]
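        # momentum as an exponential moving average of the gradient: the buffer above is blended
        # with the new gradient below, and the parameters take an SGD step along the smoothed direction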
        
        m_gradient_gsn    = [momentum * gb + (cast32(1) - momentum) * g for (gb, g) in zip(gradient_buffer_gsn, gradient_gsn)]
        param_updates_gsn = [(param, param - learning_rate * mg) for (param, mg) in zip(gsn_params, m_gradient_gsn)]
        gradient_buffer_updates_gsn = zip(gradient_buffer_gsn, m_gradient_gsn)
            
        grad_updates_gsn = OrderedDict(param_updates_gsn + gradient_buffer_updates_gsn)
        
        logger.log("gsn cost...")
        f_cost_gsn = theano.function(inputs  = [X], 
                                     outputs = gsn_show_cost, 
                                     on_unused_input='warn')
        
        logger.log("gsn learn...")
        f_learn_gsn = theano.function(inputs  = [X],
                                      updates = grad_updates_gsn,
                                      outputs = gsn_show_cost,
                                      on_unused_input='warn')
    
    #######################################
    # GRADIENTS AND FUNCTIONS FOR RNN-GSN #
    #######################################
    # if we are not using Hessian-free training create the normal sgd functions
    if state.hessian_free == 0:
        gradient      = T.grad(cost, params)      
        gradient_buffer = [theano.shared(numpy.zeros(param.get_value().shape, dtype='float32')) for param in params]
        
        m_gradient    = [momentum * gb + (cast32(1) - momentum) * g for (gb, g) in zip(gradient_buffer, gradient)]
        param_updates = [(param, param - learning_rate * mg) for (param, mg) in zip(params, m_gradient)]
        gradient_buffer_updates = zip(gradient_buffer, m_gradient)
            
        updates = OrderedDict(param_updates + gradient_buffer_updates)
        updates_train.update(updates)
    
        logger.log("rnn-gsn learn...")
        f_learn = theano.function(inputs  = [Xs],
                                  updates = updates_train,
                                  outputs = show_cost,
                                  on_unused_input='warn')
        
        logger.log("rnn-gsn cost...")
        f_cost  = theano.function(inputs  = [Xs],
                                  updates = updates_cost,
                                  outputs = show_cost, 
                                  on_unused_input='warn')
    
    logger.log("Training/cost functions done.")
    compilation_time = time.time() - start_functions_time
    # Show the compile time with appropriate easy-to-read units.
    logger.log("Compilation took "+make_time_units_string(compilation_time)+".\n\n")
    
    ############################################################################################
    # Denoise some numbers : show number, noisy number, predicted number, reconstructed number #
    ############################################################################################   
    # Compile a reconstruction function from the noiseless graph so that checkpoints during
    # training can show how well noisy inputs are denoised
    logger.log("Creating graph for noisy reconstruction function at checkpoints during training.")
    f_recon = theano.function(inputs=[Xs], outputs=x_sample_recon[-1])
    
    # Now do the same but for the GSN in the initial run
    if train_gsn_first:
        f_recon_gsn = theano.function(inputs=[X], outputs = p_X_chain_gsn_recon[-1])

    logger.log("Done compiling all functions.")
    compilation_time = time.time() - start_functions_time
    # Show the compile time with appropriate easy-to-read units.
    logger.log("Total time took "+make_time_units_string(compilation_time)+".\n\n")

    ############
    # Sampling #
    ############
    # a function to add salt and pepper noise
    f_noise = theano.function(inputs = [X], outputs = salt_and_pepper(X, state.input_salt_and_pepper))
    # the input to the sampling function
    X_sample = T.fmatrix("X_sampling")
    network_state_input     =   [X_sample] + [T.fmatrix("H_sampling_"+str(i+1)) for i in range(layers)]
   
    # "Output" state of the network (noisy)
    # initialized with input, then we apply updates
    
    network_state_output    =   [X_sample] + network_state_input[1:]

    visible_pX_chain        =   []

    # ONE update
    logger.log("Performing one walkback in network state sampling.")
    generative_stochastic_network.update_layers(network_state_output,
                      weights_list,
                      bias_list,
                      visible_pX_chain, 
                      True,
                      state.noiseless_h1,
                      state.hidden_add_noise_sigma,
                      state.input_salt_and_pepper,
                      state.input_sampling,
                      MRG,
                      visible_activation,
                      hidden_activation,
                      logger)

    if layers == 1: 
        f_sample_simple = theano.function(inputs = [X_sample], outputs = visible_pX_chain[-1])
    
    
    # on_unused_input='warn' is expected here: the odd-layer inputs are not read directly
    # because they are recomputed from the even layers inside update_layers
    f_sample2   =   theano.function(inputs = network_state_input, outputs = network_state_output + visible_pX_chain, on_unused_input='warn')

    def sample_some_numbers_single_layer():
        x0    =   test_X.get_value()[:1]
        samples = [x0]
        x  =   f_noise(x0)
        for i in range(399):
            x = f_sample_simple(x)
            samples.append(x)
            x = numpy.random.binomial(n=1, p=x, size=x.shape).astype('float32')
            x = f_noise(x)
        return numpy.vstack(samples)
            
    def sampling_wrapper(NSI):
        # * is the "splat" operator: It takes a list as input, and expands it into actual positional arguments in the function call.
        out             =   f_sample2(*NSI)
        NSO             =   out[:len(network_state_output)]
        vis_pX_chain    =   out[len(network_state_output):]
        return NSO, vis_pX_chain

    def sample_some_numbers(N=400):
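        # Build a sampling chain of roughly N steps: repeatedly feed the last network state
        # back through f_sample2 and collect the visible expectations along the way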
        # The network's initial state
        init_vis        =   test_X.get_value()[:1]

        noisy_init_vis  =   f_noise(init_vis)

        network_state   =   [[noisy_init_vis] + [numpy.zeros((1,len(b.get_value())), dtype='float32') for b in bias_list[1:]]]

        visible_chain   =   [init_vis]

        noisy_h0_chain  =   [noisy_init_vis]

        for i in range(N-1):
           
            # feed the last state into the network, compute new state, and obtain visible units expectation chain 
            net_state_out, vis_pX_chain =   sampling_wrapper(network_state[-1])

            # append to the visible chain
            visible_chain   +=  vis_pX_chain

            # append state output to the network state chain
            network_state.append(net_state_out)
            
            noisy_h0_chain.append(net_state_out[0])

        return numpy.vstack(visible_chain), numpy.vstack(noisy_h0_chain)
    
    def plot_samples(epoch_number, leading_text):
        to_sample = time.time()
        if layers == 1:
            # one layer model
            V = sample_some_numbers_single_layer()
        else:
            V, H0 = sample_some_numbers()
        img_samples =   PIL.Image.fromarray(tile_raster_images(V, (root_N_input,root_N_input), (20,20)))
        
        fname       =   outdir+leading_text+'samples_epoch_'+str(epoch_number)+'.png'
        img_samples.save(fname) 
        logger.log('Took ' + str(time.time() - to_sample) + ' to sample 400 numbers')
   
    #############################
    # Save the model parameters #
    #############################
    def save_params_to_file(name, n, gsn_params):
        pass
#         print 'saving parameters...'
#         save_path = outdir+name+'_params_epoch_'+str(n)+'.pkl'
#         f = open(save_path, 'wb')
#         try:
#             cPickle.dump(gsn_params, f, protocol=cPickle.HIGHEST_PROTOCOL)
#         finally:
#             f.close()
            
    def save_params(params):
        values = [param.get_value(borrow=True) for param in params]
        return values
    
    def restore_params(params, values):
        for i in range(len(params)):
            params[i].set_value(values[i])

    ################
    # GSN TRAINING #
    ################
    def train_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        logger.log("\n-----------TRAINING GSN------------\n")
        
        # TRAINING
        n_epoch     =   state.n_epoch
        batch_size  =   state.gsn_batch_size
        STOP        =   False
        counter     =   0
        learning_rate.set_value(cast32(state.learning_rate))  # learning rate
        times = []
        best_cost = float('inf')
        best_params = None
        patience = 0
                    
        logger.log(['train X size:',str(train_X.shape.eval())])
        logger.log(['valid X size:',str(valid_X.shape.eval())])
        logger.log(['test X size:',str(test_X.shape.eval())])
        
        if state.vis_init:
            # clip the mean to [0.001, 0.9] before the logit so it stays finite
            bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))
    
        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            logger.log('Testing : skip training')
            STOP    =   True
    
        while not STOP:
            counter += 1
            t = time.time()
            logger.append([counter,'\t'])
                
            #shuffle the data
            data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
                
            #train
            train_costs = []
            for i in xrange(len(train_X.get_value(borrow=True)) / batch_size):
                x = train_X.get_value()[i * batch_size : (i+1) * batch_size]
                cost = f_learn_gsn(x)
                train_costs.append([cost])
            train_costs = numpy.mean(train_costs)
            # record it
            logger.append(['Train:',trunc(train_costs),'\t'])
            with open(gsn_train_convergence,'a') as f:
                f.write("{0!s},".format(train_costs))
                f.write("\n")
    
    
            #valid
            valid_costs = []
            for i in xrange(len(valid_X.get_value(borrow=True)) / batch_size):
                x = valid_X.get_value()[i * batch_size : (i+1) * batch_size]
                cost = f_cost_gsn(x)
                valid_costs.append([cost])                    
            valid_costs = numpy.mean(valid_costs)
            # record it
            logger.append(['Valid:',trunc(valid_costs), '\t'])
            with open(gsn_valid_convergence,'a') as f:
                f.write("{0!s},".format(valid_costs))
                f.write("\n")
    
    
            #test
            test_costs = []
            for i in xrange(len(test_X.get_value(borrow=True)) / batch_size):
                x = test_X.get_value()[i * batch_size : (i+1) * batch_size]
                cost = f_cost_gsn(x)
                test_costs.append([cost])                
            test_costs = numpy.mean(test_costs)
            # record it 
            logger.append(['Test:',trunc(test_costs), '\t'])
            with open(gsn_test_convergence,'a') as f:
                f.write("{0!s},".format(test_costs))
                f.write("\n")
            
            
            #check for early stopping
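            # early stopping on the validation cost: snapshot the parameters whenever it improves
            # so the best ones can be restored once training stops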
            cost = numpy.sum(valid_costs)
            if cost < best_cost*state.early_stop_threshold:
                patience = 0
                best_cost = cost
                # save the parameters that made it the best
                best_params = save_params(gsn_params)
            else:
                patience += 1
    
            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                if best_params is not None:
                    restore_params(gsn_params, best_params)
                save_params_to_file('gsn', counter, gsn_params)
    
            timing = time.time() - t
            times.append(timing)
    
            logger.append('time: '+make_time_units_string(timing)+'\t')
            
            logger.log('remaining: '+make_time_units_string((n_epoch - counter) * numpy.mean(times)))
    
            if (counter % state.save_frequency) == 0 or STOP is True:
                n_examples = 100
                random_idx = numpy.array(R.sample(range(len(test_X.get_value(borrow=True))), n_examples))
                numbers = test_X.get_value(borrow=True)[random_idx]
                noisy_numbers = f_noise(test_X.get_value(borrow=True)[random_idx])
                reconstructed = f_recon_gsn(noisy_numbers) 
                # Concatenate stuff
                stacked = numpy.vstack([numpy.vstack([numbers[i*10 : (i+1)*10], noisy_numbers[i*10 : (i+1)*10], reconstructed[i*10 : (i+1)*10]]) for i in range(10)])
                number_reconstruction = PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,30)))
                    
                number_reconstruction.save(outdir+'gsn_number_reconstruction_epoch_'+str(counter)+'.png')
        
                #sample_numbers(counter, 'seven')
                plot_samples(counter, 'gsn')
        
                #save gsn_params
                save_params_to_file('gsn', counter, gsn_params)
         
            # ANNEAL!
            new_lr = learning_rate.get_value() * annealing
            learning_rate.set_value(new_lr)

        
        # 10k samples
        print 'Generating 10,000 samples'
        samples, _  =   sample_some_numbers(N=10000)
        f_samples   =   outdir+'samples.npy'
        numpy.save(f_samples, samples)
        print 'saved digits'
        
    def train_RNN_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        # If we are using Hessian-free training
        if state.hessian_free == 1:
            pass
#         gradient_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=5000)
#         cg_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=1000)
#         valid_dataset = hf_sequence_dataset([valid_X.get_value()], batch_size=None, number_batches=1000)
#         
#         s = x_samples
#         costs = [cost, show_cost]
#         hf_optimizer(params, [Xs], s, costs, u, ua).train(gradient_dataset, cg_dataset, initial_lambda=1.0, preconditioner=True, validation=valid_dataset)
        
        # If we are using SGD training
        else:
            # Define the re-used loops for f_learn and f_cost
            def apply_cost_function_to_dataset(function, dataset):
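                # iterate over full minibatches of the shared dataset, apply the Theano
                # function to each slice, and return the mean cost across batches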
                costs = []
                for i in xrange(len(dataset.get_value(borrow=True)) / batch_size):
                    xs = dataset.get_value(borrow=True)[i * batch_size : (i+1) * batch_size]
                    cost = function(xs)
                    costs.append([cost])
                return numpy.mean(costs)
            
            logger.log("\n-----------TRAINING RNN-GSN------------\n")
            # TRAINING
            n_epoch     =   state.n_epoch
            batch_size  =   state.batch_size
            STOP        =   False
            counter     =   0
            learning_rate.set_value(cast32(state.learning_rate))  # learning rate
            times = []
            best_cost = float('inf')
            best_params = None
            patience = 0
                        
            logger.log(['train X size:',str(train_X.shape.eval())])
            logger.log(['valid X size:',str(valid_X.shape.eval())])
            logger.log(['test X size:',str(test_X.shape.eval())])
            
            if state.vis_init:
                # clip the mean to [0.001, 0.9] before the logit so it stays finite
                bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001, 0.9)))
        
            if state.test_model:
                # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
                logger.log('Testing : skip training')
                STOP    =   True
        
            while not STOP:
                counter += 1
                t = time.time()
                logger.append([counter,'\t'])
                    
                #shuffle the data
                data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
                     
                #train
                train_costs = apply_cost_function_to_dataset(f_learn, train_X)
                # record it
                logger.append(['Train:',trunc(train_costs),'\t'])
                with open(train_convergence,'a') as f:
                    f.write("{0!s},".format(train_costs))
                    f.write("\n")
         
         
                #valid
                valid_costs = apply_cost_function_to_dataset(f_cost, valid_X)
                # record it
                logger.append(['Valid:',trunc(valid_costs), '\t'])
                with open(valid_convergence,'a') as f:
                    f.write("{0!s},".format(valid_costs))
                    f.write("\n")
         
         
                #test
                test_costs = apply_cost_function_to_dataset(f_cost, test_X)
                # record it 
                logger.append(['Test:',trunc(test_costs), '\t'])
                with open(test_convergence,'a') as f:
                    f.write("{0!s},".format(test_costs))
                    f.write("\n")
                 
                 
                #check for early stopping
                cost = numpy.sum(valid_costs)
                if cost < best_cost*state.early_stop_threshold:
                    patience = 0
                    best_cost = cost
                    # save the parameters that made it the best
                    best_params = save_params(params)
                else:
                    patience += 1
         
                if counter >= n_epoch or patience >= state.early_stop_length:
                    STOP = True
                    if best_params is not None:
                        restore_params(params, best_params)
                    save_params_to_file('all', counter, params)
         
                timing = time.time() - t
                times.append(timing)
         
                logger.append('time: '+make_time_units_string(timing)+'\t')
            
                logger.log('remaining: '+make_time_units_string((n_epoch - counter) * numpy.mean(times)))
        
                if (counter % state.save_frequency) == 0 or STOP is True:
                    n_examples = 100
                    nums = test_X.get_value(borrow=True)[range(n_examples)]
                    noisy_nums = f_noise(test_X.get_value(borrow=True)[range(n_examples)])
                    reconstructions = []
                    for i in xrange(0, len(noisy_nums)):
                        recon = f_recon(noisy_nums[max(0,(i+1)-batch_size):i+1])
                        reconstructions.append(recon)
                    reconstructed = numpy.array(reconstructions)

                    # Concatenate stuff
                    stacked = numpy.vstack([numpy.vstack([nums[i*10 : (i+1)*10], noisy_nums[i*10 : (i+1)*10], reconstructed[i*10 : (i+1)*10]]) for i in range(10)])
                    number_reconstruction = PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,30)))
                        
                    number_reconstruction.save(outdir+'rnngsn_number_reconstruction_epoch_'+str(counter)+'.png')
            
                    #sample_numbers(counter, 'seven')
                    plot_samples(counter, 'rnngsn')
            
                    #save params
                    save_params_to_file('all', counter, params)
             
                # ANNEAL!
                new_lr = learning_rate.get_value() * annealing
                learning_rate.set_value(new_lr)
    
            
            # 10k samples
            print 'Generating 10,000 samples'
            samples, _  =   sample_some_numbers(N=10000)
            f_samples   =   outdir+'samples.npy'
            numpy.save(f_samples, samples)
            print 'saved digits'
            
    
    #####################
    # STORY 2 ALGORITHM #
    #####################
    # train the GSN parameters first to get a good baseline (if not loaded from parameter .pkl file)
    if train_gsn_first:
        train_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
    # train the entire RNN-GSN
    train_RNN_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y)
    def train_RNN_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        # If we are using Hessian-free training
        if state.hessian_free == 1:
            pass
#         gradient_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=5000)
#         cg_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=1000)
#         valid_dataset = hf_sequence_dataset([valid_X.get_value()], batch_size=None, number_batches=1000)
#
#         s = x_samples
#         costs = [cost, show_cost]
#         hf_optimizer(params, [Xs], s, costs, u, ua).train(gradient_dataset, cg_dataset, initial_lambda=1.0, preconditioner=True, validation=valid_dataset)

        # If we are using SGD training
        else:
            # Define the re-used loops for f_learn and f_cost
            def apply_cost_function_to_dataset(function, dataset):
                costs = []
                for i in xrange(
                        len(dataset.get_value(borrow=True)) / batch_size):
                    xs = dataset.get_value(
                        borrow=True)[i * batch_size:(i + 1) * batch_size]
                    cost = function(xs)
                    costs.append([cost])
                return numpy.mean(costs)

            logger.log("\n-----------TRAINING RNN-GSN------------\n")
            # TRAINING
            n_epoch = state.n_epoch
            batch_size = state.batch_size
            STOP = False
            counter = 0
            learning_rate.set_value(cast32(
                state.learning_rate))  # learning rate
            times = []
            best_cost = float('inf')
            best_params = None
            patience = 0

            logger.log(['train X size:', str(train_X.shape.eval())])
            logger.log(['valid X size:', str(valid_X.shape.eval())])
            logger.log(['test X size:', str(test_X.shape.eval())])

            if state.vis_init:
                # clip the mean to [0.001, 0.9] before the logit so it stays finite
                bias_list[0].set_value(
                    logit(
                        numpy.clip(train_X.get_value().mean(axis=0), 0.001,
                                   0.9)))

            if state.test_model:
                # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
                logger.log('Testing : skip training')
                STOP = True

            while not STOP:
                counter += 1
                t = time.time()
                logger.append([counter, '\t'])

                #shuffle the data
                data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y,
                                         test_X, test_Y, dataset, rng)

                #train
                train_costs = apply_cost_function_to_dataset(f_learn, train_X)
                # record it
                logger.append(['Train:', trunc(train_costs), '\t'])
                with open(train_convergence, 'a') as f:
                    f.write("{0!s},".format(train_costs))
                    f.write("\n")

                #valid
                valid_costs = apply_cost_function_to_dataset(f_cost, valid_X)
                # record it
                logger.append(['Valid:', trunc(valid_costs), '\t'])
                with open(valid_convergence, 'a') as f:
                    f.write("{0!s},".format(valid_costs))
                    f.write("\n")

                #test
                test_costs = apply_cost_function_to_dataset(f_cost, test_X)
                # record it
                logger.append(['Test:', trunc(test_costs), '\t'])
                with open(test_convergence, 'a') as f:
                    f.write("{0!s},".format(test_costs))
                    f.write("\n")

                #check for early stopping
                cost = numpy.sum(valid_costs)
                if cost < best_cost * state.early_stop_threshold:
                    patience = 0
                    best_cost = cost
                    # save the parameters that made it the best
                    best_params = save_params(params)
                else:
                    patience += 1

                if counter >= n_epoch or patience >= state.early_stop_length:
                    STOP = True
                    if best_params is not None:
                        restore_params(params, best_params)
                    save_params_to_file('all', counter, params)

                timing = time.time() - t
                times.append(timing)

                logger.append('time: ' + make_time_units_string(timing) + '\t')

                logger.log('remaining: ' +
                           make_time_units_string((n_epoch - counter) *
                                                  numpy.mean(times)))

                if (counter % state.save_frequency) == 0 or STOP is True:
                    n_examples = 100
                    nums = test_X.get_value(borrow=True)[range(n_examples)]
                    noisy_nums = f_noise(
                        test_X.get_value(borrow=True)[range(n_examples)])
                    reconstructions = []
                    for i in xrange(0, len(noisy_nums)):
                        recon = f_recon(noisy_nums[max(0, (i + 1) -
                                                       batch_size):i + 1])
                        reconstructions.append(recon)
                    reconstructed = numpy.array(reconstructions)

                    # Concatenate stuff
                    stacked = numpy.vstack([
                        numpy.vstack([
                            nums[i * 10:(i + 1) * 10],
                            noisy_nums[i * 10:(i + 1) * 10],
                            reconstructed[i * 10:(i + 1) * 10]
                        ]) for i in range(10)
                    ])
                    number_reconstruction = PIL.Image.fromarray(
                        tile_raster_images(stacked,
                                           (root_N_input, root_N_input),
                                           (10, 30)))

                    number_reconstruction.save(
                        outdir + 'rnngsn_number_reconstruction_epoch_' +
                        str(counter) + '.png')

                    #sample_numbers(counter, 'seven')
                    plot_samples(counter, 'rnngsn')

                    #save params
                    save_params_to_file('all', counter, params)

                # ANNEAL!
                new_lr = learning_rate.get_value() * annealing
                learning_rate.set_value(new_lr)

            # 10k samples
            print 'Generating 10,000 samples'
            samples, _ = sample_some_numbers(N=10000)
            f_samples = outdir + 'samples.npy'
            numpy.save(f_samples, samples)
            print 'saved digits'
    def train_GSN(train_X, train_Y, valid_X, valid_Y, test_X, test_Y):
        logger.log("\n-----------TRAINING GSN------------\n")
        
        # TRAINING
        n_epoch     =   state.n_epoch
        batch_size  =   state.gsn_batch_size
        STOP        =   False
        counter     =   0
        learning_rate.set_value(cast32(state.learning_rate))  # learning rate
        times = []
        best_cost = float('inf')
        best_params = None
        patience = 0
                    
        logger.log(['train X size:',str(train_X.shape.eval())])
        logger.log(['valid X size:',str(valid_X.shape.eval())])
        logger.log(['test X size:',str(test_X.shape.eval())])
        
        if state.vis_init:
            bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0),0.001,0.9)))
    
        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            logger.log('Testing : skip training')
            STOP    =   True
    
        while not STOP:
            counter += 1
            t = time.time()
            logger.append([counter,'\t'])
                
            #shuffle the data
            data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, dataset, rng)
                
            #train
            train_costs = []
            for i in xrange(len(train_X.get_value(borrow=True)) / batch_size):
                x = train_X.get_value()[i * batch_size : (i+1) * batch_size]
                cost = f_learn_gsn(x)
                train_costs.append([cost])
            train_costs = numpy.mean(train_costs)
            # record it
            logger.append(['Train:',trunc(train_costs),'\t'])
            with open(gsn_train_convergence,'a') as f:
                f.write("{0!s},".format(train_costs))
                f.write("\n")
    
    
            #valid
            valid_costs = []
            for i in xrange(len(valid_X.get_value(borrow=True)) / batch_size):
                x = valid_X.get_value()[i * batch_size : (i+1) * batch_size]
                cost = f_cost_gsn(x)
                valid_costs.append([cost])                    
            valid_costs = numpy.mean(valid_costs)
            # record it
            logger.append(['Valid:',trunc(valid_costs), '\t'])
            with open(gsn_valid_convergence,'a') as f:
                f.write("{0!s},".format(valid_costs))
                f.write("\n")
    
    
            #test
            test_costs = []
            for i in xrange(len(test_X.get_value(borrow=True)) / batch_size):
                x = test_X.get_value()[i * batch_size : (i+1) * batch_size]
                cost = f_cost_gsn(x)
                test_costs.append([cost])                
            test_costs = numpy.mean(test_costs)
            # record it 
            logger.append(['Test:',trunc(test_costs), '\t'])
            with open(gsn_test_convergence,'a') as f:
                f.write("{0!s},".format(test_costs))
                f.write("\n")
            
            
            #check for early stopping
            cost = numpy.sum(valid_costs)
            if cost < best_cost*state.early_stop_threshold:
                patience = 0
                best_cost = cost
                # save the parameters that made it the best
                best_params = save_params(gsn_params)
            else:
                patience += 1
    
            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                if best_params is not None:
                    restore_params(gsn_params, best_params)
                save_params_to_file('gsn', counter, gsn_params)
    
            timing = time.time() - t
            times.append(timing)
    
            logger.append('time: '+make_time_units_string(timing)+'\t')
            
            logger.log('remaining: '+make_time_units_string((n_epoch - counter) * numpy.mean(times)))
    
            if (counter % state.save_frequency) == 0 or STOP is True:
                n_examples = 100
                random_idx = numpy.array(R.sample(range(len(test_X.get_value(borrow=True))), n_examples))
                numbers = test_X.get_value(borrow=True)[random_idx]
                noisy_numbers = f_noise(test_X.get_value(borrow=True)[random_idx])
                reconstructed = f_recon_gsn(noisy_numbers) 
                # Concatenate stuff
                stacked = numpy.vstack([numpy.vstack([numbers[i*10 : (i+1)*10], noisy_numbers[i*10 : (i+1)*10], reconstructed[i*10 : (i+1)*10]]) for i in range(10)])
                number_reconstruction = PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,30)))
                    
                number_reconstruction.save(outdir+'gsn_number_reconstruction_epoch_'+str(counter)+'.png')
        
                #sample_numbers(counter, 'seven')
                plot_samples(counter, 'gsn')
        
                #save gsn_params
                save_params_to_file('gsn', counter, gsn_params)
         
            # ANNEAL!
            new_lr = learning_rate.get_value() * annealing
            learning_rate.set_value(new_lr)

        
        # 10k samples
        print 'Generating 10,000 samples'
        samples, _  =   sample_some_numbers(N=10000)
        f_samples   =   outdir+'samples.npy'
        numpy.save(f_samples, samples)
        print 'saved digits'
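
# Illustrative sketch of the annealing schedule used by the loops above: each epoch the
# shared learning rate is multiplied by the `annealing` coefficient
# (new_lr = learning_rate.get_value() * annealing), i.e. an exponential decay.
# A minimal stand-alone version, assuming a starting rate lr0 such as 0.25 and a
# per-epoch coefficient such as 0.99:
def annealed_learning_rate(lr0, annealing, epoch):
    """Learning rate after `epoch` multiplicative annealing steps."""
    return lr0 * (annealing ** epoch)
# e.g. annealed_learning_rate(0.25, 0.99, 100) is roughly 0.092, about 37% of lr0.
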
def experiment(state, outdir_base='./'):
    rng.seed(1)  #seed the numpy random generator
    # Initialize output directory and files
    data.mkdir_p(outdir_base)
    outdir = outdir_base + "/" + state.dataset + "/"
    data.mkdir_p(outdir)
    logfile = outdir + "log.txt"
    with open(logfile, 'w') as f:
        f.write("MODEL 2, {0!s}\n\n".format(state.dataset))
    train_convergence_pre = outdir + "train_convergence_pre.csv"
    train_convergence_post = outdir + "train_convergence_post.csv"
    valid_convergence_pre = outdir + "valid_convergence_pre.csv"
    valid_convergence_post = outdir + "valid_convergence_post.csv"
    test_convergence_pre = outdir + "test_convergence_pre.csv"
    test_convergence_post = outdir + "test_convergence_post.csv"

    print
    print "----------MODEL 2, {0!s}--------------".format(state.dataset)
    print

    #load parameters from config file if this is a test
    config_filename = outdir + 'config'
    if state.test_model and 'config' in os.listdir(outdir):
        config_vals = load_from_config(config_filename)
        for CV in config_vals:
            print CV
            if CV.startswith('test'):
                print 'Do not override testing switch'
                continue
            try:
                exec('state.' + CV) in globals(), locals()
            except:
                exec('state.' + CV.split('=')[0] + "='" + CV.split('=')[1] +
                     "'") in globals(), locals()
    else:
        # Save the current configuration
        # Useful for logs/experiments
        print 'Saving config'
        with open(config_filename, 'w') as f:
            f.write(str(state))

    print state
    # Load the data, train = train+valid, and sequence
    artificial = False
    if state.dataset == 'MNIST_1' or state.dataset == 'MNIST_2' or state.dataset == 'MNIST_3':
        (train_X,
         train_Y), (valid_X,
                    valid_Y), (test_X,
                               test_Y) = data.load_mnist(state.data_path)
        train_X = numpy.concatenate((train_X, valid_X))
        train_Y = numpy.concatenate((train_Y, valid_Y))
        artificial = True
        try:
            dataset = int(state.dataset.split('_')[1])
        except:
            raise AssertionError(
                "artificial dataset number not recognized. Input was " +
                state.dataset)
    else:
        raise AssertionError("dataset not recognized.")

    train_X = theano.shared(train_X)
    train_Y = theano.shared(train_Y)
    valid_X = theano.shared(valid_X)
    valid_Y = theano.shared(valid_Y)
    test_X = theano.shared(test_X)
    test_Y = theano.shared(test_Y)

    if artificial:
        print 'Sequencing MNIST data...'
        print 'train set size:', len(train_Y.eval())
        print 'valid set size:', len(valid_Y.eval())
        print 'test set size:', len(test_Y.eval())
        data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X,
                                 test_Y, dataset, rng)
        print 'train set size:', len(train_Y.eval())
        print 'valid set size:', len(valid_Y.eval())
        print 'test set size:', len(test_Y.eval())
        print 'Sequencing done.'
        print

    N_input = train_X.eval().shape[1]
    root_N_input = numpy.sqrt(N_input)

    # Network and training specifications
    layers = state.layers  # number hidden layers
    walkbacks = state.walkbacks  # number of walkbacks
    layer_sizes = [
        N_input
    ] + [state.hidden_size
         ] * layers  # layer sizes, from h0 to hK (h0 is the visible layer)
    learning_rate = theano.shared(cast32(state.learning_rate))  # learning rate
    annealing = cast32(state.annealing)  # exponential annealing coefficient
    momentum = theano.shared(cast32(state.momentum))  # momentum term

    # PARAMETERS : weights list and bias list.
    # initialize a list of weights and biases based on layer_sizes
    weights_list = [
        get_shared_weights(layer_sizes[i],
                           layer_sizes[i + 1],
                           name="W_{0!s}_{1!s}".format(i, i + 1))
        for i in range(layers)
    ]  # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
    recurrent_weights_list = [
        get_shared_weights(layer_sizes[i + 1],
                           layer_sizes[i],
                           name="V_{0!s}_{1!s}".format(i + 1, i))
        for i in range(layers)
    ]  # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
    bias_list = [
        get_shared_bias(layer_sizes[i], name='b_' + str(i))
        for i in range(layers + 1)
    ]  # initialize each layer to 0's.

    # Theano variables and RNG
    MRG = RNG_MRG.MRG_RandomStreams(1)
    X = T.fmatrix('X')
    Xs = [
        T.fmatrix(name="X_initial") if i == 0 else T.fmatrix(name="X_" +
                                                             str(i + 1))
        for i in range(walkbacks + 1)
    ]
    hiddens_input = [X] + [
        T.fmatrix(name="h_" + str(i + 1)) for i in range(layers)
    ]
    hiddens_output = hiddens_input[:1] + hiddens_input[1:]

    # Check variables for bad inputs and stuff
    if state.batch_size > len(Xs):
        warnings.warn(
            "Batch size should not be bigger than walkbacks+1 (len(Xs)) unless you know what you're doing. You need to know the sequence length beforehand."
        )
    if state.batch_size <= 0:
        raise AssertionError("batch size cannot be <= 0")
    ''' F PROP '''
    if state.hidden_act == 'sigmoid':
        print 'Using sigmoid activation for hiddens'
        hidden_activation = T.nnet.sigmoid
    elif state.hidden_act == 'rectifier':
        print 'Using rectifier activation for hiddens'
        hidden_activation = lambda x: T.maximum(cast32(0), x)
    elif state.hidden_act == 'tanh':
        print 'Using hyperbolic tangent activation for hiddens'
        hidden_activation = lambda x: T.tanh(x)
    else:
        raise AssertionError(
            "Did not recognize hidden activation {0!s}, please use tanh, rectifier, or sigmoid"
            .format(state.hidden_act))

    if state.visible_act == 'sigmoid':
        print 'Using sigmoid activation for visible layer'
        visible_activation = T.nnet.sigmoid
    elif state.visible_act == 'softmax':
        print 'Using softmax activation for visible layer'
        visible_activation = T.nnet.softmax
    else:
        raise AssertionError(
            "Did not recognize visible activation {0!s}, please use sigmoid or softmax"
            .format(state.visible_act))

    def update_layers(hiddens,
                      p_X_chain,
                      Xs,
                      sequence_idx,
                      noisy=True,
                      sampling=True):
        print 'odd layer updates'
        update_odd_layers(hiddens, noisy)
        print 'even layer updates'
        update_even_layers(hiddens, p_X_chain, Xs, sequence_idx, noisy,
                           sampling)
        print 'done full update.'
        print
        # choose the correct output for hidden_outputs based on batch_size and walkbacks (this is due to an issue with batches, see note in run_story2.py)
        if state.batch_size <= len(
                Xs) and sequence_idx == state.batch_size - 1:
            return hiddens
        else:
            return None
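    # Illustrative note: layers are refreshed in an alternating block order -- all
    # odd-indexed hiddens first, then all even-indexed ones (the visible layer is index
    # 0). For layers = 3 the hiddens list is [X, h1, h2, h3]; update_odd_layers touches
    # h1 and h3, update_even_layers touches X and h2, so each full update only ever
    # mixes a layer with its immediate neighbours.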

    # Odd layer update function
    # just a loop over the odd layers
    def update_odd_layers(hiddens, noisy):
        for i in range(1, len(hiddens), 2):
            print 'updating layer', i
            simple_update_layer(hiddens, None, None, None, i, add_noise=noisy)

    # Even layer update
    # p_X_chain is given to append the p(X|...) at each full update (one update = odd update + even update)
    def update_even_layers(hiddens, p_X_chain, Xs, sequence_idx, noisy,
                           sampling):
        for i in range(0, len(hiddens), 2):
            print 'updating layer', i
            simple_update_layer(hiddens,
                                p_X_chain,
                                Xs,
                                sequence_idx,
                                i,
                                add_noise=noisy,
                                input_sampling=sampling)

    # The layer update function
    # hiddens   :   list containing the symbolic theano variables [visible, hidden1, hidden2, ...]
    #               layer_update will modify this list inplace
    # p_X_chain :   list containing the successive p(X|...) at each update
    #               update_layer will append to this list
    # add_noise     : pre and post activation gaussian noise

    def simple_update_layer(hiddens,
                            p_X_chain,
                            Xs,
                            sequence_idx,
                            i,
                            add_noise=True,
                            input_sampling=True):
        # Compute the dot product, whatever layer
        # If the visible layer X
        if i == 0:
            print 'using', recurrent_weights_list[i]
            hiddens[i] = (T.dot(hiddens[i + 1], recurrent_weights_list[i]) +
                          bias_list[i])
        # If the top layer
        elif i == len(hiddens) - 1:
            print 'using', weights_list[i - 1]
            hiddens[i] = T.dot(hiddens[i - 1],
                               weights_list[i - 1]) + bias_list[i]
        # Otherwise in-between layers
        else:
            # next layer        :   hiddens[i+1], assigned weights : W_i
            # previous layer    :   hiddens[i-1], assigned weights : W_(i-1)
            print "using {0!s} and {1!s}".format(weights_list[i - 1],
                                                 recurrent_weights_list[i])
            hiddens[i] = T.dot(
                hiddens[i + 1], recurrent_weights_list[i]) + T.dot(
                    hiddens[i - 1], weights_list[i - 1]) + bias_list[i]

        # Add pre-activation noise if NOT input layer
        if i == 1 and state.noiseless_h1:
            print '>>NO noise in first hidden layer'
            add_noise = False

        # pre activation noise
        if i != 0 and add_noise:
            print 'Adding pre-activation gaussian noise for layer', i
            hiddens[i] = add_gaussian_noise(hiddens[i],
                                            state.hidden_add_noise_sigma)

        # ACTIVATION!
        if i == 0:
            print 'Sigmoid units activation for visible layer X'
            hiddens[i] = visible_activation(hiddens[i])
        else:
            print 'Hidden units {} activation for layer'.format(state.act), i
            hiddens[i] = hidden_activation(hiddens[i])

        # post activation noise
        # why is there post activation noise? Because there is already pre-activation noise, this just doubles the amount of noise between each activation of the hiddens.
#         if i != 0 and add_noise:
#             print 'Adding post-activation gaussian noise for layer', i
#             hiddens[i]  =   add_gaussian(hiddens[i], state.hidden_add_noise_sigma)

        # build the reconstruction chain if updating the visible layer X
        if i == 0:
            # if input layer -> append p(X|...)
            p_X_chain.append(
                hiddens[i])  #what the predicted next input should be

            if sequence_idx + 1 < len(Xs):
                next_input = Xs[sequence_idx + 1]
                # sample from p(X|...) - SAMPLING NEEDS TO BE CORRECT FOR INPUT TYPES I.E. FOR BINARY MNIST SAMPLING IS BINOMIAL. real-valued inputs should be gaussian
                if input_sampling:
                    print 'Sampling from input'
                    sampled = MRG.binomial(p=next_input,
                                           size=next_input.shape,
                                           dtype='float32')
                else:
                    print '>>NO input sampling'
                    sampled = next_input
                # add noise
                sampled = salt_and_pepper(sampled, state.input_salt_and_pepper)

                # DOES INPUT SAMPLING MAKE SENSE FOR SEQUENTIAL? - not really since it was used in walkbacks which was gibbs.
                # set input layer
                hiddens[i] = sampled

    def build_graph(hiddens, Xs, noisy=True, sampling=True):
        predicted_X_chain = [
        ]  # the visible layer that gets generated at each update_layers run
        H_chain = [
        ]  # either None or hiddens that gets generated at each update_layers run, this is used to determine what the correct hiddens_output should be
        print "Building the graph :", walkbacks, "updates"
        for i in range(walkbacks):
            print "Forward Prediction {!s}/{!s}".format(i + 1, walkbacks)
            H_chain.append(
                update_layers(hiddens, predicted_X_chain, Xs, i, noisy,
                              sampling))
        return predicted_X_chain, H_chain

    '''Build the main training graph'''
    # corrupt x
    hiddens_output[0] = salt_and_pepper(hiddens_output[0],
                                        state.input_salt_and_pepper)
    # build the computation graph and the generated visible layers and appropriate hidden_output
    predicted_X_chain, H_chain = build_graph(hiddens_output,
                                             Xs,
                                             noisy=True,
                                             sampling=state.input_sampling)
    #     predicted_X_chain, H_chain = build_graph(hiddens_output, Xs, noisy=False, sampling=state.input_sampling) #testing one-hot without noise

    # choose the correct output for hiddens_output (this is due to the issue with batches - see note in run_story2.py)
    # this finds the not-None element of H_chain and uses that for hiddens_output
    h_empty = [True if h is None else False for h in H_chain]
    if False in h_empty:  # if there was a not-None element
        hiddens_output = H_chain[h_empty.index(
            False
        )]  # set hiddens_output to the appropriate element from H_chain

    ######################
    # COST AND GRADIENTS #
    ######################
    print
    if state.cost_funct == 'binary_crossentropy':
        print 'Using binary cross-entropy cost!'
        cost_function = lambda x, y: T.mean(T.nnet.binary_crossentropy(x, y))
    elif state.cost_funct == 'square':
        print "Using square error cost!"
        cost_function = lambda x, y: T.mean(T.sqr(x - y))
    else:
        raise AssertionError(
            "Did not recognize cost function {0!s}, please use binary_crossentropy or square"
            .format(state.cost_funct))
    print 'Cost w.r.t p(X|...) at every step in the graph'

    costs = [
        cost_function(predicted_X_chain[i], Xs[i + 1])
        for i in range(len(predicted_X_chain))
    ]
    # outputs for the functions
    show_COSTs = [costs[0]] + [costs[-1]]

    # cost for the gradient
    # care more about the immediate next predictions rather than the future - use exponential decay
    #     COST = T.sum(costs)
    COST = T.sum([
        T.exp(-i / T.ceil(walkbacks / 3)) * costs[i] for i in range(len(costs))
    ])
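    # Worked example of the decay weights: assuming walkbacks is a plain Python int
    # under Python 2, walkbacks / 3 is integer division, so for walkbacks = 5 the
    # divisor is ceil(5 / 3) = ceil(1) = 1 and the per-step weights are
    # exp(-i) ~= 1.00, 0.37, 0.14, 0.05, 0.02; with true division (walkbacks / 3.0)
    # they would be exp(-i / 2.0) ~= 1.00, 0.61, 0.37, 0.22, 0.14. Either way the
    # earliest predictions dominate the gradient.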

    params = weights_list + recurrent_weights_list + bias_list
    print "params:", params

    print "creating functions..."
    gradient = T.grad(COST, params)

    gradient_buffer = [
        theano.shared(numpy.zeros(param.get_value().shape, dtype='float32'))
        for param in params
    ]

    m_gradient = [
        momentum * gb + (cast32(1) - momentum) * g
        for (gb, g) in zip(gradient_buffer, gradient)
    ]
    param_updates = [(param, param - learning_rate * mg)
                     for (param, mg) in zip(params, m_gradient)]
    gradient_buffer_updates = zip(gradient_buffer, m_gradient)

    updates = OrderedDict(param_updates + gradient_buffer_updates)
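    # Illustrative note: this is SGD with an exponentially averaged gradient rather
    # than classical momentum -- each step the buffer becomes
    # momentum * buffer + (1 - momentum) * grad and the parameter moves by
    # -learning_rate * buffer, e.g. in plain NumPy:
    #     buf = momentum * buf + (1.0 - momentum) * grad
    #     param -= learning_rate * buf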

    #odd layer h's not used from input -> calculated directly from even layers (starting with h_0) since the odd layers are updated first.
    f_cost = theano.function(inputs=hiddens_input + Xs,
                             outputs=hiddens_output + show_COSTs,
                             on_unused_input='warn')

    f_learn = theano.function(inputs=hiddens_input + Xs,
                              updates=updates,
                              outputs=hiddens_output + show_COSTs,
                              on_unused_input='warn')

    print "functions done."
    print

    #############
    # Denoise some numbers  :   show number, noisy number, reconstructed number
    #############
    import random as R
    R.seed(1)
    # a function to add salt and pepper noise
    f_noise = theano.function(inputs=[X],
                              outputs=salt_and_pepper(
                                  X, state.input_salt_and_pepper))

    # Recompile the graph without noise for reconstruction function - the input x_recon is already going to be noisy, and this is to test on a simulated 'real' input.
    X_recon = T.fvector("X_recon")
    Xs_recon = [T.fvector("Xs_recon")]
    hiddens_R_input = [X_recon] + [
        T.fvector(name="h_recon_" + str(i + 1)) for i in range(layers)
    ]
    hiddens_R_output = hiddens_R_input[:1] + hiddens_R_input[1:]

    # The layer update scheme
    print "Creating graph for noisy reconstruction function at checkpoints during training."
    p_X_chain_R, H_chain_R = build_graph(hiddens_R_output,
                                         Xs_recon,
                                         noisy=False)

    # choose the correct output from H_chain for hidden_outputs based on batch_size and walkbacks
    # choose the correct output for hiddens_output
    h_empty = [True if h is None else False for h in H_chain_R]
    if False in h_empty:  # if there was a set of hiddens output from the batch_size-1 element of the chain
        hiddens_R_output = H_chain_R[h_empty.index(
            False
        )]  # extract out the not-None element from the list if it exists
#     if state.batch_size <= len(Xs_recon):
#         for i in range(len(hiddens_R_output)):
#             hiddens_R_output[i] = H_chain_R[state.batch_size - 1][i]

    f_recon = theano.function(inputs=hiddens_R_input + Xs_recon,
                              outputs=hiddens_R_output +
                              [p_X_chain_R[0], p_X_chain_R[-1]],
                              on_unused_input="warn")

    ############
    # Sampling #
    ############

    # the input to the sampling function
    X_sample = T.fmatrix("X_sampling")
    network_state_input = [X_sample] + [
        T.fmatrix("H_sampling_" + str(i + 1)) for i in range(layers)
    ]

    # "Output" state of the network (noisy)
    # initialized with input, then we apply updates

    network_state_output = [X_sample] + network_state_input[1:]

    visible_pX_chain = []

    # ONE update
    print "Performing one walkback in network state sampling."
    _ = update_layers(network_state_output,
                      visible_pX_chain, [X_sample],
                      0,
                      noisy=True)

    if layers == 1:
        f_sample_simple = theano.function(inputs=[X_sample],
                                          outputs=visible_pX_chain[-1])

    # WHY IS THERE A WARNING????
    # because the first odd layers are not used -> directly computed FROM THE EVEN layers
    # unused input = warn
    f_sample2 = theano.function(inputs=network_state_input,
                                outputs=network_state_output +
                                visible_pX_chain,
                                on_unused_input='warn')

    def sample_some_numbers_single_layer():
        x0 = test_X.get_value()[:1]
        samples = [x0]
        x = f_noise(x0)
        for i in range(399):
            x = f_sample_simple(x)
            samples.append(x)
            x = numpy.random.binomial(n=1, p=x, size=x.shape).astype('float32')
            x = f_noise(x)
        return numpy.vstack(samples)

    def sampling_wrapper(NSI):
        # * is the "splat" operator: It takes a list as input, and expands it into actual positional arguments in the function call.
        out = f_sample2(*NSI)
        NSO = out[:len(network_state_output)]
        vis_pX_chain = out[len(network_state_output):]
        return NSO, vis_pX_chain

    def sample_some_numbers(N=400):
        # The network's initial state
        init_vis = test_X.get_value()[:1]

        noisy_init_vis = f_noise(init_vis)

        network_state = [[noisy_init_vis] + [
            numpy.zeros((1, len(b.get_value())), dtype='float32')
            for b in bias_list[1:]
        ]]

        visible_chain = [init_vis]

        noisy_h0_chain = [noisy_init_vis]

        for i in range(N - 1):

            # feed the last state into the network, compute new state, and obtain visible units expectation chain
            net_state_out, vis_pX_chain = sampling_wrapper(network_state[-1])

            # append to the visible chain
            visible_chain += vis_pX_chain

            # append state output to the network state chain
            network_state.append(net_state_out)

            noisy_h0_chain.append(net_state_out[0])

        return numpy.vstack(visible_chain), numpy.vstack(noisy_h0_chain)

    def plot_samples(epoch_number, iteration):
        to_sample = time.time()
        if layers == 1:
            # one layer model
            V = sample_some_numbers_single_layer()
        else:
            V, H0 = sample_some_numbers()
        img_samples = PIL.Image.fromarray(
            tile_raster_images(V, (root_N_input, root_N_input), (20, 20)))

        fname = outdir + 'samples_iteration_' + str(
            iteration) + '_epoch_' + str(epoch_number) + '.png'
        img_samples.save(fname)
        print 'Took ' + str(time.time() - to_sample) + ' to sample 400 numbers'

    ##############
    # Inpainting #
    ##############
    def inpainting(digit):
        # The network's initial state

        # NOISE INIT
        init_vis = cast32(numpy.random.uniform(size=digit.shape))

        #noisy_init_vis  =   f_noise(init_vis)
        #noisy_init_vis  =   cast32(numpy.random.uniform(size=init_vis.shape))

        # INDEXES FOR VISIBLE AND NOISY PART
        noise_idx = (numpy.arange(N_input) % root_N_input < (root_N_input / 2))
        fixed_idx = (numpy.arange(N_input) % root_N_input > (root_N_input / 2))

        # function to re-init the visible to the same noise

        # FUNCTION TO RESET HALF VISIBLE TO DIGIT
        def reset_vis(V):
            V[0][fixed_idx] = digit[0][fixed_idx]
            return V

        # INIT DIGIT : NOISE and RESET HALF TO DIGIT
        init_vis = reset_vis(init_vis)

        network_state = [[init_vis] + [
            numpy.zeros((1, len(b.get_value())), dtype='float32')
            for b in bias_list[1:]
        ]]

        visible_chain = [init_vis]

        noisy_h0_chain = [init_vis]

        for i in range(49):

            # feed the last state into the network, compute new state, and obtain visible units expectation chain
            net_state_out, vis_pX_chain = sampling_wrapper(network_state[-1])

            # reset half the digit
            net_state_out[0] = reset_vis(net_state_out[0])
            vis_pX_chain[0] = reset_vis(vis_pX_chain[0])

            # append to the visible chain
            visible_chain += vis_pX_chain

            # append state output to the network state chain
            network_state.append(net_state_out)

            noisy_h0_chain.append(net_state_out[0])

        return numpy.vstack(visible_chain), numpy.vstack(noisy_h0_chain)
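    # Illustrative note: inpainting clamps the columns selected by fixed_idx (roughly
    # the right half of each image row, since the test is on the pixel's column index
    # modulo root_N_input) back to the target digit after every update, while the
    # remaining pixels are free to be filled in by the model; the clamp itself is just
    # V[0][fixed_idx] = digit[0][fixed_idx].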

    def save_params_to_file(name, n, params, iteration):
        print 'saving parameters...'
        save_path = outdir + name + '_params_iteration_' + str(
            iteration) + '_epoch_' + str(n) + '.pkl'
        f = open(save_path, 'wb')
        try:
            cPickle.dump(params, f, protocol=cPickle.HIGHEST_PROTOCOL)
        finally:
            f.close()

    ################
    # GSN TRAINING #
    ################
    def train_recurrent_GSN(iteration, train_X, train_Y, valid_X, valid_Y,
                            test_X, test_Y):
        print '----------------------------------------'
        print 'TRAINING GSN FOR ITERATION', iteration
        with open(logfile, 'a') as f:
            f.write(
                "--------------------------\nTRAINING GSN FOR ITERATION {0!s}\n"
                .format(iteration))

        # TRAINING
        n_epoch = state.n_epoch
        batch_size = state.batch_size
        STOP = False
        counter = 0
        if iteration == 0:
            learning_rate.set_value(cast32(
                state.learning_rate))  # learning rate
        times = []
        best_cost = float('inf')
        patience = 0

        print 'learning rate:', learning_rate.get_value()

        print 'train X size:', str(train_X.shape.eval())
        print 'valid X size:', str(valid_X.shape.eval())
        print 'test X size:', str(test_X.shape.eval())

        train_costs = []
        valid_costs = []
        test_costs = []
        train_costs_post = []
        valid_costs_post = []
        test_costs_post = []

        if state.vis_init:
            bias_list[0].set_value(
                logit(numpy.clip(train_X.get_value().mean(axis=0), 0.001,
                                 0.9)))

        if state.test_model:
            # If testing, do not train and go directly to generating samples, parzen window estimation, and inpainting
            print 'Testing : skip training'
            STOP = True

        while not STOP:
            counter += 1
            t = time.time()
            print counter, '\t',
            with open(logfile, 'a') as f:
                f.write("{0!s}\t".format(counter))
            #shuffle the data
            data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y,
                                     test_X, test_Y, dataset, rng)

            #train
            #init hiddens
            #             hiddens = [(T.zeros_like(train_X[:batch_size]).eval())]
            #             for i in range(len(weights_list)):
            #                 # init with zeros
            #                 hiddens.append(T.zeros_like(T.dot(hiddens[i], weights_list[i])).eval())
            hiddens = [
                T.zeros((batch_size, layer_size)).eval()
                for layer_size in layer_sizes
            ]
            train_cost = []
            train_cost_post = []
            for i in range(len(train_X.get_value(borrow=True)) / batch_size):
                xs = [
                    train_X.get_value(
                        borrow=True)[(i * batch_size) +
                                     sequence_idx:((i + 1) * batch_size) +
                                     sequence_idx]
                    for sequence_idx in range(len(Xs))
                ]
                xs, hiddens = fix_input_size(xs, hiddens)
                hiddens[0] = xs[0]
                _ins = hiddens + xs
                _outs = f_learn(*_ins)
                hiddens = _outs[:len(hiddens)]
                cost = _outs[-2]
                cost_post = _outs[-1]
                train_cost.append(cost)
                train_cost_post.append(cost_post)

            train_cost = numpy.mean(train_cost)
            train_costs.append(train_cost)
            train_cost_post = numpy.mean(train_cost_post)
            train_costs_post.append(train_cost_post)
            print 'Train : ', trunc(train_cost), trunc(train_cost_post), '\t',
            with open(logfile, 'a') as f:
                f.write("Train : {0!s} {1!s}\t".format(trunc(train_cost),
                                                       trunc(train_cost_post)))
            with open(train_convergence_pre, 'a') as f:
                f.write("{0!s},".format(train_cost))
            with open(train_convergence_post, 'a') as f:
                f.write("{0!s},".format(train_cost_post))

            #valid
            #init hiddens
            hiddens = [
                T.zeros((batch_size, layer_size)).eval()
                for layer_size in layer_sizes
            ]
            valid_cost = []
            valid_cost_post = []
            for i in range(len(valid_X.get_value(borrow=True)) / batch_size):
                xs = [
                    valid_X.get_value(
                        borrow=True)[(i * batch_size) +
                                     sequence_idx:((i + 1) * batch_size) +
                                     sequence_idx]
                    for sequence_idx in range(len(Xs))
                ]
                xs, hiddens = fix_input_size(xs, hiddens)
                hiddens[0] = xs[0]
                _ins = hiddens + xs
                _outs = f_cost(*_ins)
                hiddens = _outs[:-2]
                cost = _outs[-2]
                cost_post = _outs[-1]
                valid_cost.append(cost)
                valid_cost_post.append(cost_post)

            valid_cost = numpy.mean(valid_cost)
            valid_costs.append(valid_cost)
            valid_cost_post = numpy.mean(valid_cost_post)
            valid_costs_post.append(valid_cost_post)
            print 'Valid : ', trunc(valid_cost), trunc(valid_cost_post), '\t',
            with open(logfile, 'a') as f:
                f.write("Valid : {0!s} {1!s}\t".format(trunc(valid_cost),
                                                       trunc(valid_cost_post)))
            with open(valid_convergence_pre, 'a') as f:
                f.write("{0!s},".format(valid_cost))
            with open(valid_convergence_post, 'a') as f:
                f.write("{0!s},".format(valid_cost_post))

            #test
            #init hiddens
            hiddens = [
                T.zeros((batch_size, layer_size)).eval()
                for layer_size in layer_sizes
            ]
            test_cost = []
            test_cost_post = []
            for i in range(len(test_X.get_value(borrow=True)) / batch_size):
                xs = [
                    test_X.get_value(
                        borrow=True)[(i * batch_size) +
                                     sequence_idx:((i + 1) * batch_size) +
                                     sequence_idx]
                    for sequence_idx in range(len(Xs))
                ]
                xs, hiddens = fix_input_size(xs, hiddens)
                hiddens[0] = xs[0]
                _ins = hiddens + xs
                _outs = f_cost(*_ins)
                hiddens = _outs[:-2]
                cost = _outs[-2]
                cost_post = _outs[-1]
                test_cost.append(cost)
                test_cost_post.append(cost_post)

            test_cost = numpy.mean(test_cost)
            test_costs.append(test_cost)
            test_cost_post = numpy.mean(test_cost_post)
            test_costs_post.append(test_cost_post)
            print 'Test  : ', trunc(test_cost), trunc(test_cost_post), '\t',
            with open(logfile, 'a') as f:
                f.write("Test : {0!s} {1!s}\t".format(trunc(test_cost),
                                                      trunc(test_cost_post)))
            with open(test_convergence_pre, 'a') as f:
                f.write("{0!s},".format(test_cost))
            with open(test_convergence_post, 'a') as f:
                f.write("{0!s},".format(test_cost_post))

            #check for early stopping
            cost = train_cost
            if cost < best_cost * state.early_stop_threshold:
                patience = 0
                best_cost = cost
            else:
                patience += 1

            if counter >= n_epoch or patience >= state.early_stop_length:
                STOP = True
                save_params_to_file('gsn', counter, params, iteration)

            timing = time.time() - t
            times.append(timing)

            print 'time : ', trunc(timing),

            print 'remaining: ', trunc(
                (n_epoch - counter) * numpy.mean(times) / 60 / 60), 'hrs',

            print 'B : ', [
                trunc(abs(b.get_value(borrow=True)).mean()) for b in bias_list
            ],

            print 'W : ', [
                trunc(abs(w.get_value(borrow=True)).mean())
                for w in weights_list
            ],

            print 'V : ', [
                trunc(abs(v.get_value(borrow=True)).mean())
                for v in recurrent_weights_list
            ]

            with open(logfile, 'a') as f:
                f.write("MeanVisB : {0!s}\t".format(
                    trunc(bias_list[0].get_value().mean())))

            with open(logfile, 'a') as f:
                f.write("W : {0!s}\t".format(
                    str([
                        trunc(abs(w.get_value(borrow=True)).mean())
                        for w in weights_list
                    ])))

            with open(logfile, 'a') as f:
                f.write("Time : {0!s} seconds\n".format(trunc(timing)))

            if (counter % state.save_frequency) == 0:
                # Checking reconstruction
                nums = test_X.get_value()[range(100)]
                noisy_nums = f_noise(test_X.get_value()[range(100)])
                reconstructed_prediction = []
                reconstructed_prediction_end = []
                #init reconstruction hiddens
                hiddens = [
                    T.zeros(layer_size).eval() for layer_size in layer_sizes
                ]
                for num in noisy_nums:
                    hiddens[0] = num
                    for i in range(len(hiddens)):
                        if len(hiddens[i].shape
                               ) == 2 and hiddens[i].shape[0] == 1:
                            hiddens[i] = hiddens[i][0]
                    _ins = hiddens + [num]
                    _outs = f_recon(*_ins)
                    hiddens = _outs[:len(hiddens)]
                    [reconstructed_1, reconstructed_n] = _outs[len(hiddens):]
                    reconstructed_prediction.append(reconstructed_1)
                    reconstructed_prediction_end.append(reconstructed_n)

                with open(logfile, 'a') as f:
                    f.write("\n")
                for i in range(len(nums)):
                    if len(
                            reconstructed_prediction[i].shape
                    ) == 2 and reconstructed_prediction[i].shape[0] == 1:
                        reconstructed_prediction[i] = reconstructed_prediction[
                            i][0]
                    print nums[i].tolist(
                    ), "->", reconstructed_prediction[i].tolist()
                    with open(logfile, 'a') as f:
                        f.write("{0!s} -> {1!s}\n".format(
                            nums[i].tolist(), [
                                trunc(n)
                                if n > 0.0001 else trunc(0.00000000000000000)
                                for n in reconstructed_prediction[i].tolist()
                            ]))
                with open(logfile, 'a') as f:
                    f.write("\n")

#                 # Concatenate stuff
#                 stacked = numpy.vstack([numpy.vstack([nums[i*10 : (i+1)*10], noisy_nums[i*10 : (i+1)*10], reconstructed_prediction[i*10 : (i+1)*10], reconstructed_prediction_end[i*10 : (i+1)*10]]) for i in range(10)])
#                 numbers_reconstruction = PIL.Image.fromarray(tile_raster_images(stacked, (root_N_input,root_N_input), (10,40)))
#                 numbers_reconstruction.save(outdir+'gsn_number_reconstruction_iteration_'+str(iteration)+'_epoch_'+str(counter)+'.png')
#
#                 #sample_numbers(counter, 'seven')
#                 plot_samples(counter, iteration)
#
#                 #save params
#                 save_params_to_file('gsn', counter, params, iteration)

            # ANNEAL!
            new_lr = learning_rate.get_value() * annealing
            learning_rate.set_value(new_lr)

        # 10k samples
        print 'Generating 10,000 samples'
        samples, _ = sample_some_numbers(N=10000)
        f_samples = outdir + 'samples.npy'
        numpy.save(f_samples, samples)
        print 'saved digits'

    #####################
    # STORY 2 ALGORITHM #
    #####################
    for iter in range(state.max_iterations):
        train_recurrent_GSN(iter, train_X, train_Y, valid_X, valid_Y, test_X,
                            test_Y)
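
# Illustrative sketch: f_noise above wraps the repo's salt_and_pepper() helper (defined
# elsewhere in the file). A common NumPy formulation of salt-and-pepper corruption,
# given here only as an assumption about what that helper typically does -- with
# probability `rate` a pixel is replaced by a random 0 or 1:
import numpy

def salt_and_pepper_numpy(X, rate, rng=numpy.random):
    keep = rng.binomial(n=1, p=1.0 - rate, size=X.shape)  # 1 = keep the original pixel
    salt = rng.binomial(n=1, p=0.5, size=X.shape)         # replacement value where dropped
    return X * keep + (1 - keep) * salt
# Usage sketch: noisy = salt_and_pepper_numpy(batch.astype('float32'), 0.4)
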
Example #18
    def train(self, train_X=None, train_Y=None, valid_X=None, valid_Y=None, test_X=None, test_Y=None, is_artificial=False, artificial_sequence=1, continue_training=False):
        log.maybeLog(self.logger, "\nTraining---------\n")
        if train_X is None:
            log.maybeLog(self.logger, "Training using data given during initialization of RNN-GSN.\n")
            train_X = self.train_X
            train_Y = self.train_Y
            if train_X is None:
                log.maybeLog(self.logger, "\nPlease provide a training dataset!\n")
                raise AssertionError("Please provide a training dataset!")
        else:
            log.maybeLog(self.logger, "Training using data provided to training function.\n")
        if valid_X is None:
            valid_X = self.valid_X
            valid_Y = self.valid_Y
        if test_X is None:
            test_X  = self.test_X
            test_Y  = self.test_Y
            
        ##########################################################
        # Train the GSN first to get good weights initialization #
        ##########################################################
        if self.train_gsn_first:
            log.maybeLog(self.logger, "\n\n----------Initially training the GSN---------\n\n")
            init_gsn = generative_stochastic_network.GSN(train_X=train_X, valid_X=valid_X, test_X=test_X, args=self.gsn_args, logger=self.logger)
            init_gsn.train()
    
        #############################
        # Save the model parameters #
        #############################
        def save_params_to_file(name, n, gsn_params):
            pass
            print 'saving parameters...'
            save_path = self.outdir+name+'_params_epoch_'+str(n)+'.pkl'
            f = open(save_path, 'wb')
            try:
                cPickle.dump(gsn_params, f, protocol=cPickle.HIGHEST_PROTOCOL)
            finally:
                f.close()
                
        def save_params(params):
            values = [param.get_value(borrow=True) for param in params]
            return values
        
        def restore_params(params, values):
            for i in range(len(params)):
                params[i].set_value(values[i])
    
        
        #########################################
        # If we are using Hessian-free training #
        #########################################
        if self.hessian_free:
            pass
#         gradient_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=5000)
#         cg_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=1000)
#         valid_dataset = hf_sequence_dataset([valid_X.get_value()], batch_size=None, number_batches=1000)
#         
#         s = x_samples
#         costs = [cost, show_cost]
#         hf_optimizer(params, [Xs], s, costs, u, ua).train(gradient_dataset, cg_dataset, initial_lambda=1.0, preconditioner=True, validation=valid_dataset)
        
        ################################
        # If we are using SGD training #
        ################################
        else:
            log.maybeLog(self.logger, "\n-----------TRAINING RNN-GSN------------\n")
            # TRAINING
            STOP        =   False
            counter     =   0
            if not continue_training:
                self.learning_rate.set_value(self.init_learn_rate)  # learning rate
            times = []
            best_cost = float('inf')
            best_params = None
            patience = 0
                        
            log.maybeLog(self.logger, ['train X size:',str(train_X.shape.eval())])
            if valid_X is not None:
                log.maybeLog(self.logger, ['valid X size:',str(valid_X.shape.eval())])
            if test_X is not None:
                log.maybeLog(self.logger, ['test X size:',str(test_X.shape.eval())])
            
            if self.vis_init:
                self.bias_list[0].set_value(logit(numpy.clip(train_X.get_value().mean(axis=0),0.001,0.9)))
        
            while not STOP:
                counter += 1
                t = time.time()
                log.maybeAppend(self.logger, [counter,'\t'])
                    
                if is_artificial:
                    data.sequence_mnist_data(train_X, train_Y, valid_X, valid_Y, test_X, test_Y, artificial_sequence, rng)
                     
                #train
                train_costs = data.apply_cost_function_to_dataset(self.f_learn, train_X, self.batch_size)
                # record it
                log.maybeAppend(self.logger, ['Train:',trunc(train_costs),'\t'])
         
         
                #valid
                valid_costs = data.apply_cost_function_to_dataset(self.f_cost, valid_X, self.batch_size)
                # record it
                log.maybeAppend(self.logger, ['Valid:',trunc(valid_costs), '\t'])
         
         
                #test
                test_costs = data.apply_cost_function_to_dataset(self.f_cost, test_X, self.batch_size)
                # record it 
                log.maybeAppend(self.logger, ['Test:',trunc(test_costs), '\t'])
                 
                 
                #check for early stopping
                cost = numpy.sum(valid_costs)
                if cost < best_cost*self.early_stop_threshold:
                    patience = 0
                    best_cost = cost
                    # save the parameters that made it the best
                    best_params = save_params(self.params)
                else:
                    patience += 1
         
                if counter >= self.n_epoch or patience >= self.early_stop_length:
                    STOP = True
                    if best_params is not None:
                        restore_params(self.params, best_params)
                    save_params_to_file('all', counter, self.params)
         
                timing = time.time() - t
                times.append(timing)
         
                log.maybeAppend(self.logger, 'time: '+make_time_units_string(timing)+'\t')
            
                log.maybeLog(self.logger, 'remaining: '+make_time_units_string((self.n_epoch - counter) * numpy.mean(times)))
        
                if (counter % self.save_frequency) == 0 or STOP is True:
                    n_examples = 100
                    nums = test_X.get_value(borrow=True)[range(n_examples)]
                    noisy_nums = self.f_noise(test_X.get_value(borrow=True)[range(n_examples)])
                    reconstructions = []
                    for i in xrange(0, len(noisy_nums)):
                        recon = self.f_recon(noisy_nums[max(0,(i+1)-self.batch_size):i+1])
                        reconstructions.append(recon)
                    reconstructed = numpy.array(reconstructions)

                    # Concatenate stuff
                    stacked = numpy.vstack([numpy.vstack([nums[i*10 : (i+1)*10], noisy_nums[i*10 : (i+1)*10], reconstructed[i*10 : (i+1)*10]]) for i in range(10)])
                    number_reconstruction = PIL.Image.fromarray(tile_raster_images(stacked, (self.root_N_input,self.root_N_input), (10,30)))
                        
                    number_reconstruction.save(self.outdir+'rnngsn_number_reconstruction_epoch_'+str(counter)+'.png')
                    
                    #save params
                    save_params_to_file('all', counter, self.params)
             
                # ANNEAL!
                new_lr = self.learning_rate.get_value() * self.annealing
                self.learning_rate.set_value(new_lr)
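
# Illustrative sketch: this refactored train() delegates batching to
# data.apply_cost_function_to_dataset(). Assuming it matches the inline helper of the
# same name earlier in this document, it splits the data into fixed-size minibatches,
# applies the compiled Theano function to each, and returns the mean per-batch cost:
import numpy

def apply_cost_function_to_dataset_sketch(function, data_values, batch_size):
    """Average `function` (a per-minibatch cost) over a 2-D array of examples."""
    costs = [function(data_values[i * batch_size:(i + 1) * batch_size])
             for i in range(len(data_values) // batch_size)]
    return numpy.mean(costs)
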
Example #19
    def train(self,
              train_X=None,
              train_Y=None,
              valid_X=None,
              valid_Y=None,
              test_X=None,
              test_Y=None,
              is_artificial=False,
              artificial_sequence=1,
              continue_training=False):
        log.maybeLog(self.logger, "\nTraining---------\n")
        if train_X is None:
            log.maybeLog(
                self.logger,
                "Training using data given during initialization of RNN-GSN.\n"
            )
            train_X = self.train_X
            train_Y = self.train_Y
            if train_X is None:
                log.maybeLog(self.logger,
                             "\nPlease provide a training dataset!\n")
                raise AssertionError("Please provide a training dataset!")
        else:
            log.maybeLog(
                self.logger,
                "Training using data provided to training function.\n")
        if valid_X is None:
            valid_X = self.valid_X
            valid_Y = self.valid_Y
        if test_X is None:
            test_X = self.test_X
            test_Y = self.test_Y

        ##########################################################
        # Train the GSN first to get good weights initialization #
        ##########################################################
        if self.train_gsn_first:
            log.maybeLog(
                self.logger,
                "\n\n----------Initially training the GSN---------\n\n")
            init_gsn = generative_stochastic_network.GSN(train_X=train_X,
                                                         valid_X=valid_X,
                                                         test_X=test_X,
                                                         args=self.gsn_args,
                                                         logger=self.logger)
            init_gsn.train()

        #############################
        # Save the model parameters #
        #############################
        def save_params_to_file(name, n, gsn_params):
            pass
            print 'saving parameters...'
            save_path = self.outdir + name + '_params_epoch_' + str(n) + '.pkl'
            f = open(save_path, 'wb')
            try:
                cPickle.dump(gsn_params, f, protocol=cPickle.HIGHEST_PROTOCOL)
            finally:
                f.close()

        def save_params(params):
            values = [param.get_value(borrow=True) for param in params]
            return values

        def restore_params(params, values):
            for i in range(len(params)):
                params[i].set_value(values[i])

        #########################################
        # If we are using Hessian-free training #
        #########################################
        if self.hessian_free:
            pass
#         gradient_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=5000)
#         cg_dataset = hf_sequence_dataset([train_X.get_value()], batch_size=None, number_batches=1000)
#         valid_dataset = hf_sequence_dataset([valid_X.get_value()], batch_size=None, number_batches=1000)
#
#         s = x_samples
#         costs = [cost, show_cost]
#         hf_optimizer(params, [Xs], s, costs, u, ua).train(gradient_dataset, cg_dataset, initial_lambda=1.0, preconditioner=True, validation=valid_dataset)

        ################################
        # If we are using SGD training #
        ################################
        else:
            log.maybeLog(self.logger,
                         "\n-----------TRAINING RNN-GSN------------\n")
            # TRAINING
            STOP = False
            counter = 0
            if not continue_training:
                self.learning_rate.set_value(
                    self.init_learn_rate)  # learning rate
            times = []
            best_cost = float('inf')
            best_params = None
            patience = 0

            log.maybeLog(
                self.logger,
                ['train X size:', str(train_X.shape.eval())])
            if valid_X is not None:
                log.maybeLog(self.logger,
                             ['valid X size:',
                              str(valid_X.shape.eval())])
            if test_X is not None:
                log.maybeLog(
                    self.logger,
                    ['test X size:', str(test_X.shape.eval())])

            if self.vis_init:
                self.bias_list[0].set_value(
                    logit(
                        numpy.clip(train_X.get_value().mean(axis=0), 0.001,
                                   0.9)))

            while not STOP:
                counter += 1
                t = time.time()
                log.maybeAppend(self.logger, [counter, '\t'])

                if is_artificial:
                    data.sequence_mnist_data(train_X, train_Y, valid_X,
                                             valid_Y, test_X, test_Y,
                                             artificial_sequence, rng)

                # Training pass: run f_learn (cost plus parameter updates) over
                # minibatches of the training set
                train_costs = data.apply_cost_function_to_dataset(
                    self.f_learn, train_X, self.batch_size)
                log.maybeAppend(self.logger,
                                ['Train:', trunc(train_costs), '\t'])

                # Validation pass: evaluate f_cost without updating parameters
                valid_costs = data.apply_cost_function_to_dataset(
                    self.f_cost, valid_X, self.batch_size)
                log.maybeAppend(self.logger,
                                ['Valid:', trunc(valid_costs), '\t'])

                # Test pass: evaluate f_cost without updating parameters
                test_costs = data.apply_cost_function_to_dataset(
                    self.f_cost, test_X, self.batch_size)
                log.maybeAppend(self.logger,
                                ['Test:', trunc(test_costs), '\t'])

                # Early stopping bookkeeping: patience resets only when the summed
                # validation cost drops below early_stop_threshold * best_cost
                cost = numpy.sum(valid_costs)
                if cost < best_cost * self.early_stop_threshold:
                    patience = 0
                    best_cost = cost
                    # snapshot the parameters that achieved the best validation cost
                    best_params = save_params(self.params)
                else:
                    patience += 1

                if counter >= self.n_epoch or patience >= self.early_stop_length:
                    STOP = True
                    if best_params is not None:
                        restore_params(self.params, best_params)
                    save_params_to_file('all', counter, self.params)

                timing = time.time() - t
                times.append(timing)

                log.maybeAppend(
                    self.logger,
                    'time: ' + make_time_units_string(timing) + '\t')

                log.maybeLog(
                    self.logger, 'remaining: ' + make_time_units_string(
                        (self.n_epoch - counter) * numpy.mean(times)))

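                # Every save_frequency epochs (and on the final epoch) dump a
                # checkpoint and a reconstruction image: corrupt the first 100 test
                # examples with f_noise, reconstruct them with f_recon, and tile
                # originals / noisy inputs / reconstructions into one PNG.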
                if (counter % self.save_frequency) == 0 or STOP:
                    n_examples = 100
                    nums = test_X.get_value(borrow=True)[:n_examples]
                    noisy_nums = self.f_noise(nums)
                    reconstructions = []
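                    # Each example is reconstructed from a sliding window of up to
                    # batch_size preceding noisy examples, since f_recon appears to
                    # condition on the recurrent context built from earlier frames.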
                    for i in xrange(0, len(noisy_nums)):
                        recon = self.f_recon(
                            noisy_nums[max(0, (i + 1) - self.batch_size):i +
                                       1])
                        reconstructions.append(recon)
                    reconstructed = numpy.array(reconstructions)

                    # Tile the images in groups of 10: for each group of 10 test
                    # examples, stack the originals, their noisy versions, and the
                    # reconstructions
                    stacked = numpy.vstack([
                        numpy.vstack([
                            nums[i * 10:(i + 1) * 10],
                            noisy_nums[i * 10:(i + 1) * 10],
                            reconstructed[i * 10:(i + 1) * 10]
                        ]) for i in range(10)
                    ])
                    number_reconstruction = PIL.Image.fromarray(
                        tile_raster_images(
                            stacked, (self.root_N_input, self.root_N_input),
                            (10, 30)))

                    number_reconstruction.save(
                        self.outdir + 'rnngsn_number_reconstruction_epoch_' +
                        str(counter) + '.png')

                    # save the current parameters
                    save_params_to_file('all', counter, self.params)

                # Anneal the learning rate by multiplying it by self.annealing
                # after every epoch
                new_lr = self.learning_rate.get_value() * self.annealing
                self.learning_rate.set_value(new_lr)