Beispiel #1
0
def cd_stats(rbm, v0_vmap, visible_units, hidden_units, context_units=None, k=1, mean_field_for_stats=None, mean_field_for_gibbs=None, persistent_vmap=None):
    """Return (persistent) contrastive divergence statistics for an RBM.

    Runs ``k`` block-Gibbs steps (hidden -> visible -> hidden) inside
    ``theano.scan``.  The chain starts either from the hidden sample inferred
    from ``v0_vmap`` (plain CD-k) or from the hidden states stored in
    ``persistent_vmap`` (persistent CD).

    :param rbm: model providing ``complete_units_list`` and ``complete_vmap``.
    :param v0_vmap: vmap (unit -> Theano expression) holding the data.
    :param visible_units: units that are resampled in the negative phase.
    :param hidden_units: units inferred from the visibles.
    :param context_units: units whose values are taken from ``v0_vmap`` and
        merged unchanged into every vmap during sampling. Defaults to none.
    :param k: number of Gibbs steps performed by scan.
    :param mean_field_for_stats: units for which 'mean_field' should be used
        to compute statistics, rather than 'sample'. Defaults to none.
    :param mean_field_for_gibbs: units for which 'mean_field' should be used
        during Gibbs sampling, rather than 'sample'. Defaults to none.
    :param persistent_vmap: optional map from each hidden unit to the variable
        holding its persistent chain state (used as a Theano update target,
        so presumably shared variables). When given, the chain starts from
        these states and an update writing the final Gibbs sample back is
        added.
    :returns: a ``Stats`` object built from the scan updates, holding the
        vmaps under 'data', 'model', 'data_activation' and 'model_activation'.
    """
    # Avoid shared mutable default arguments: fresh lists on every call.
    if context_units is None:
        context_units = []
    if mean_field_for_stats is None:
        mean_field_for_stats = []
    if mean_field_for_gibbs is None:
        mean_field_for_gibbs = []

    # complete units lists
    visible_units = rbm.complete_units_list(visible_units)
    hidden_units = rbm.complete_units_list(hidden_units)
    context_units = rbm.complete_units_list(context_units)
    n_input, n_latent = len(visible_units), len(hidden_units)

    # complete the supplied vmap
    v0_vmap = rbm.complete_vmap(v0_vmap)

    # extract the context vmap, because we will need to merge it into all other vmaps
    context_vmap = dict((u, v0_vmap[u]) for u in context_units)

    # positive phase: hidden activations / stats / samples given the data
    h0_activation_vmap = dict((h, h.activation(v0_vmap)) for h in hidden_units)
    h0_stats_vmap, h0_gibbs_vmap = gibbs_step(rbm, v0_vmap, hidden_units, mean_field_for_stats, mean_field_for_gibbs)

    # add context
    h0_activation_vmap.update(context_vmap)
    h0_gibbs_vmap.update(context_vmap)
    h0_stats_vmap.update(context_vmap)

    # scan requires a function that takes and returns Theano expressions, so
    # vmaps cannot be passed in or out; each value's identity is encoded by its
    # position in the argument / return lists instead.
    def gibbs_hvh(*args):
        # the recurrent arguments are the current hidden samples, in hidden_units order
        h_gibbs_vmap = dict(zip(hidden_units, args))

        # h -> v
        v1_in_vmap = h_gibbs_vmap.copy()
        v1_in_vmap.update(context_vmap) # add context

        v1_activation_vmap = dict((v, v.activation(v1_in_vmap)) for v in visible_units)
        v1_stats_vmap, v1_gibbs_vmap = gibbs_step(rbm, v1_in_vmap, visible_units, mean_field_for_stats, mean_field_for_gibbs)

        # v -> h
        h1_in_vmap = v1_gibbs_vmap.copy()
        h1_in_vmap.update(context_vmap) # add context

        h1_activation_vmap = dict((h, h.activation(h1_in_vmap)) for h in hidden_units)
        h1_stats_vmap, h1_gibbs_vmap = gibbs_step(rbm, h1_in_vmap, hidden_units, mean_field_for_stats, mean_field_for_gibbs)

        # flatten in a fixed order: visible (activation, stats, gibbs), then
        # hidden (activation, stats, gibbs). The hidden gibbs values come last so
        # that they line up with the recurrent entries of outputs_info.
        return ([v1_activation_vmap[u] for u in visible_units] +
                [v1_stats_vmap[u] for u in visible_units] +
                [v1_gibbs_vmap[u] for u in visible_units] +
                [h1_activation_vmap[u] for u in hidden_units] +
                [h1_stats_vmap[u] for u in hidden_units] +
                [h1_gibbs_vmap[u] for u in hidden_units])

    # support for persistent CD: start from the stored chain instead of the data
    if persistent_vmap is None:
        chain_start = [h0_gibbs_vmap[u] for u in hidden_units]
    else:
        chain_start = [persistent_vmap[u] for u in hidden_units]

    # outputs_info maps scan outputs onto the next iteration's inputs: only the
    # hidden gibbs samples are fed back (becoming the arguments of gibbs_hvh);
    # 'None' marks outputs that are not used in the next iteration.
    outputs_info = [None] * (n_input * 3) + [None] * (n_latent * 2) + list(chain_start)

    exp_output_all_list, theano_updates = theano.scan(gibbs_hvh, outputs_info=outputs_info, n_steps=k)
    # we only need the final outcomes, not intermediary values
    exp_output_list = [out[-1] for out in exp_output_all_list]

    # reconstruct vmaps from the exp_output_list (same order as gibbs_hvh returns).
    vk_activation_vmap = dict(zip(visible_units, exp_output_list[0:1*n_input]))
    vk_stats_vmap = dict(zip(visible_units, exp_output_list[1*n_input:2*n_input]))
    vk_gibbs_vmap = dict(zip(visible_units, exp_output_list[2*n_input:3*n_input]))
    hk_activation_vmap = dict(zip(hidden_units, exp_output_list[3*n_input:3*n_input+1*n_latent]))
    hk_stats_vmap = dict(zip(hidden_units, exp_output_list[3*n_input+1*n_latent:3*n_input+2*n_latent]))
    hk_gibbs_vmap = dict(zip(hidden_units, exp_output_list[3*n_input+2*n_latent:3*n_input+3*n_latent]))

    # add the Theano updates for the persistent CD states:
    if persistent_vmap is not None:
        for u, v in persistent_vmap.items():
            theano_updates[v] = hk_gibbs_vmap[u] # this should be the gibbs vmap, and not the stats vmap!

    activation_data_vmap = v0_vmap.copy() # TODO: this doesn't really make sense to have in an activation vmap!
    activation_data_vmap.update(h0_activation_vmap)
    activation_model_vmap = vk_activation_vmap.copy()
    activation_model_vmap.update(context_vmap)
    activation_model_vmap.update(hk_activation_vmap)

    stats = Stats(theano_updates) # create a new stats object

    # store the computed stats in a dictionary of vmaps.
    stats_data_vmap = v0_vmap.copy()
    stats_data_vmap.update(h0_stats_vmap)
    stats_model_vmap = vk_stats_vmap.copy()
    stats_model_vmap.update(context_vmap)
    stats_model_vmap.update(hk_stats_vmap)
    stats.update({
      'data': stats_data_vmap,
      'model': stats_model_vmap,
    })

    stats['data_activation'] = activation_data_vmap
    stats['model_activation'] = activation_model_vmap

    return stats
Beispiel #2
0
def pt_stats(rbm, v0_vmap, visible_units, hidden_units, persistent_vmap, beta, k=1, m=1):
    """Returns stats for parallel tempering given a list of inverse temperatures beta.

    Runs ``k`` tempered block-Gibbs steps (hidden -> visible -> hidden) inside
    ``theano.scan`` on the persistent chains stored in ``persistent_vmap``.
    Activations are rescaled by ``beta`` before sampling. After each step, one
    adjacent pair of chains ``(i, i + 1)`` is picked at random, and the pair's
    states are exchanged with probability
    ``min(1, exp((beta_i - beta_{i+1}) * (E_i - E_{i+1})))``.

    :param rbm: model providing ``h``, ``v``, ``energy``,
        ``sample_from_activation`` and ``mean_field_from_activation``.
    :param v0_vmap: vmap holding the batch of training data.
    :param visible_units: units resampled in the negative phase.
    :param hidden_units: units inferred from the visibles.
    :param persistent_vmap: map from each hidden unit to its persistent chain
        state. It must contain ``rbm.h``; its number of rows is taken as the
        number of chains. At least two chains are needed for the swap proposal.
    :param beta: inverse temperatures, indexed as ``beta[i][0]``. This is
        presumably a column of shape (N_chains, 1) — TODO confirm.
    :param k: number of sampling steps.
    :param m: number of chains (the first ``m``) used for model statistics.
    :returns: a ``Stats`` object built from the scan updates, with keys
        'data', 'model', 'data_activation', 'model_activation' and
        'accepted'. 'accepted' holds, for every one of the ``k`` steps, the
        index of the swapped pair, or -1 if the swap was rejected.
    """
    # TODO: use an additional shared variable in scan's output info to count
    # the switch acceptances.

    # number of parallel chains (one per inverse temperature)
    N_chains = persistent_vmap[rbm.h].shape[0]

    # complete units lists
    visible_units = rbm.complete_units_list(visible_units)
    hidden_units = rbm.complete_units_list(hidden_units)
    n_input, n_latent = len(visible_units), len(hidden_units)

    # complete the supplied vmap
    v0_vmap = rbm.complete_vmap(v0_vmap)

    # compute data dependent gradient component (hidden mean field given the data)
    h0_activation_vmap = dict((h, h.activation(v0_vmap)) for h in hidden_units)
    h0_stats_vmap = rbm.mean_field_from_activation(h0_activation_vmap)

    # scan requires a function that takes and returns Theano expressions, so
    # vmaps cannot be passed in or out; each value's identity is encoded by its
    # position in the argument / return lists instead.
    def gibbs_hvh(*args):
        # the recurrent arguments are the current hidden chain states, in hidden_units order
        h0_gibbs_vmap = dict(zip(hidden_units, args))

        # tempered h -> v step
        v1_in_vmap = h0_gibbs_vmap.copy()
        v1_activation_vmap = dict((v, v.activation(v1_in_vmap)) for v in visible_units)
        v1_scaled_activation_vmap = rescale_activations(v1_activation_vmap, beta)
        v1_gibbs_vmap_preswap = rbm.sample_from_activation(v1_scaled_activation_vmap)

        # tempered v -> h step
        h1_in_vmap = v1_gibbs_vmap_preswap.copy()
        h1_activation_vmap = dict((h, h.activation(h1_in_vmap)) for h in hidden_units)
        h1_scaled_activation_vmap = rescale_activations(h1_activation_vmap, beta)
        h1_gibbs_vmap_preswap = rbm.sample_from_activation(h1_scaled_activation_vmap)

        # Merge hidden and visible samples into one joint configuration for the
        # energy computation. Copy first: assigning into h1_gibbs_vmap_preswap
        # directly would mutate it, and it is reused below for the swap.
        samples = h1_gibbs_vmap_preswap.copy()
        samples[rbm.v] = v1_gibbs_vmap_preswap[rbm.v]

        # Select a candidate pair (rv_i, rv_i + 1) for replica exchange; the
        # range of random_integers is inclusive, giving N_chains - 1 possible pairs.
        rv_i = rng.random_integers(low=0, high=N_chains-2)
        rv_u = rng.uniform()

        chain_pair = manipulate_vmap(samples, lambda x: x[rv_i:rv_i+2, :])

        # compute relevant energy scores (untempered energies)
        e_pair = rbm.energy(chain_pair) # vector with two elements

        # Metropolis acceptance test for exchanging the two replicas
        b1, b2 = beta[rv_i][0], beta[rv_i+1][0]
        r = T.exp((b1 - b2) * (e_pair[0] - e_pair[1]))
        comparison = T.le(rv_u, r)
        # index of the swapped pair if accepted, otherwise -1
        accepted = T.cast(T.switch(comparison, rv_i, -1), dtype=theano.config.floatX)

        def swap(x):
            # exchange rows rv_i and rv_i + 1
            return T.concatenate((x[:rv_i, :], x[rv_i+1:rv_i+2, :],
                                  x[rv_i:rv_i+1, :], x[rv_i+2:, :]))

        v1_gibbs_vmap_swap = manipulate_vmap(v1_gibbs_vmap_preswap, swap)
        h1_gibbs_vmap_swap = manipulate_vmap(h1_gibbs_vmap_preswap, swap)
        v1_gibbs_vmap = switch_vmap(comparison, v1_gibbs_vmap_swap, v1_gibbs_vmap_preswap)
        h1_gibbs_vmap = switch_vmap(comparison, h1_gibbs_vmap_swap, h1_gibbs_vmap_preswap)

        # only use the first m chains for stats
        h1_stats_vmap = manipulate_vmap(h1_gibbs_vmap, lambda x: x[:m, :])
        v1_stats_vmap = manipulate_vmap(v1_gibbs_vmap, lambda x: x[:m, :])

        # flatten in a fixed order: visible (activation, stats, gibbs), hidden
        # (activation, stats, gibbs), then the acceptance indicator. The hidden
        # gibbs values line up with the recurrent entries of outputs_info.
        return ([v1_activation_vmap[u] for u in visible_units] +
                [v1_stats_vmap[u] for u in visible_units] +
                [v1_gibbs_vmap[u] for u in visible_units] +
                [h1_activation_vmap[u] for u in hidden_units] +
                [h1_stats_vmap[u] for u in hidden_units] +
                [h1_gibbs_vmap[u] for u in hidden_units] +
                [accepted])

    chain_start = [persistent_vmap[u] for u in hidden_units]

    # outputs_info maps scan outputs onto the next iteration's inputs: only the
    # hidden gibbs samples are fed back (becoming the arguments of gibbs_hvh);
    # 'None' marks outputs that are not used in the next iteration.
    outputs_info = [None] * (n_input * 3) + [None] * (n_latent * 2) + list(chain_start) + [None]

    exp_output_all_list, theano_updates = theano.scan(gibbs_hvh, outputs_info=outputs_info, n_steps=k)
    # chain states and statistics: only the final step is needed
    exp_output_list = [out[-1] for out in exp_output_all_list]
    # swap acceptances: keep the full per-step history
    accepted = exp_output_all_list[-1]

    # TODO: this messes things up when I use larger batches of data
    # reconstruct vmaps from the exp_output_list (same order as gibbs_hvh returns).
    vk_activation_vmap = dict(zip(visible_units, exp_output_list[0:1*n_input]))
    vk_stats_vmap = dict(zip(visible_units, exp_output_list[1*n_input:2*n_input]))
    hk_activation_vmap = dict(zip(hidden_units, exp_output_list[3*n_input:3*n_input+1*n_latent]))
    hk_stats_vmap = dict(zip(hidden_units, exp_output_list[3*n_input+1*n_latent:3*n_input+2*n_latent]))
    hk_gibbs_vmap = dict(zip(hidden_units, exp_output_list[3*n_input+2*n_latent:3*n_input+3*n_latent]))

    # add the Theano updates for the persistent chain states
    # (persistent_vmap is required here, so no None check is needed):
    for u, v in persistent_vmap.items():
        theano_updates[v] = hk_gibbs_vmap[u] # this should be the gibbs vmap, and not the stats vmap!

    activation_data_vmap = v0_vmap.copy() # TODO: this doesn't really make sense to have in an activation vmap!
    activation_data_vmap.update(h0_activation_vmap)
    activation_model_vmap = vk_activation_vmap.copy()
    activation_model_vmap.update(hk_activation_vmap)

    stats = Stats(theano_updates) # create a new stats object

    # store the computed stats in a dictionary of vmaps.
    stats_data_vmap = v0_vmap.copy()
    stats_data_vmap.update(h0_stats_vmap)
    stats_model_vmap = vk_stats_vmap.copy()
    stats_model_vmap.update(hk_stats_vmap)
    stats.update({
      'data': stats_data_vmap,
      'model': stats_model_vmap,
    })

    stats['data_activation'] = activation_data_vmap
    stats['model_activation'] = activation_model_vmap
    stats['accepted'] = accepted

    return stats