def test_model(model_config_dict, model_test_name):
    """Draw samples from a previously saved generator and save image grids.

    Loads generator parameters from the first '*.pkl' file found in
    ``samples_dir``, compiles a sampling function, then for each of
    ``model_config_dict['num_sampling']`` rounds draws uniform hidden codes,
    decodes them, and writes a 16x16 sample grid PNG into ``samples_dir``.

    Args:
        model_config_dict: configuration dict; keys read here are
            'min_num_gen_filters', 'num_sampling', 'hidden_distribution',
            'num_display' and 'hidden_size'.
        model_test_name: filename prefix for the saved sample grids.

    Raises:
        IOError: if ``samples_dir`` contains no '*.pkl' parameter file.
    """
    # glob is available at module scope (continue_train_model relies on it),
    # so the previous redundant local `import glob` was dropped.
    model_list = glob.glob(samples_dir + '/*.pkl')
    if not model_list:
        raise IOError('no saved model (*.pkl) found in ' + samples_dir)
    # load parameters (glob order is OS-dependent; assumes one model file)
    model_param_dicts = unpickle(model_list[0])

    # rebuild the generator network from the saved parameters
    generator_models = load_generator_model(min_num_gen_filters=model_config_dict['min_num_gen_filters'],
                                            model_params_dict=model_param_dicts)
    generator_function = generator_models[0]

    print('COMPILING SAMPLING FUNCTION')
    t = time()
    sampling_function = set_sampling_function(generator_function=generator_function)
    print('%.2f SEC ' % (time() - t))

    print('START SAMPLING')
    for s in xrange(model_config_dict['num_sampling']):
        print('{} sampling'.format(s))
        # hidden codes ~ Uniform(-d, d), d = hidden_distribution
        hidden_data  = floatX(np_rng.uniform(low=-model_config_dict['hidden_distribution'],
                                             high=model_config_dict['hidden_distribution'],
                                             size=(model_config_dict['num_display'], model_config_dict['hidden_size'])))
        sample_data = sampling_function(hidden_data)[0]
        # undo input preprocessing and reorder NCHW -> NHWC for image saving
        sample_data = inverse_transform(np.asarray(sample_data)).transpose([0, 2, 3, 1])
        save_as = samples_dir + '/' + model_test_name + '_SAMPLES(TRAIN){}.png'.format(s + 1)
        color_grid_vis(sample_data, (16, 16), save_as)
def continue_train_model(last_batch_idx,
                         data_stream,
                         energy_optimizer,
                         generator_optimizer,
                         model_config_dict,
                         model_test_name):
    """Resume adversarial training of the generator/energy model pair.

    Reloads generator and energy-model parameters from the first '*.pkl'
    file in ``samples_dir``, recompiles the update and sampling functions,
    then iterates over ``data_stream`` for ``model_config_dict['epochs']``
    epochs, skipping batches until ``last_batch_idx`` is reached so training
    continues from the saved point. Every 100 batches it saves sample
    images, energy curves (as .npy) and a model checkpoint back into
    ``samples_dir``.

    Args:
        last_batch_idx: global batch index at which to resume; earlier
            batches are consumed but not trained on.
        data_stream: object providing ``get_epoch_iterator()`` yielding
            batches whose first element is the input array.
        energy_optimizer: optimizer passed to the energy update function.
        generator_optimizer: optimizer passed to the generator update
            function.
        model_config_dict: configuration dict; keys read here are
            'min_num_gen_filters', 'expert_size', 'hidden_distribution',
            'num_display', 'hidden_size', 'epochs'.
        model_test_name: filename prefix for all artifacts written to
            ``samples_dir``.
    """
    # pick up the most recently saved parameter pickle
    # (glob order is OS-dependent; assumes a single checkpoint file)
    model_list = glob.glob(samples_dir +'/*.pkl')
    # load parameters
    model_param_dicts = unpickle(model_list[0])
    # rebuild the generator: [0] = apply function, [1] = trainable params
    generator_models = load_generator_model(min_num_gen_filters=model_config_dict['min_num_gen_filters'],
                                            model_params_dict=model_param_dicts)
    generator_function = generator_models[0]
    generator_params   = generator_models[1]

    # rebuild the energy model: [0] = feature fn, [1] = expert fn,
    # [2] = trainable params (norm/prior components are disabled below)
    energy_models = load_energy_model(num_experts=model_config_dict['expert_size'],
                                      model_params_dict=model_param_dicts)
    feature_function = energy_models[0]
    # norm_function    = energy_models[1]
    expert_function  = energy_models[1]
    # prior_function   = energy_models[3]
    energy_params    = energy_models[2]

    # compile functions
    print 'COMPILING MODEL UPDATER'
    t=time()
    # generator update: returns the updater plus its optimizer state params
    # (optimizer state is restored from the same pickle via init_param_dict)
    generator_updater, generator_optimizer_params = set_generator_update_function(energy_feature_function=feature_function,
                                                                                  # energy_norm_function=norm_function,
                                                                                  energy_expert_function=expert_function,
                                                                                  # energy_prior_function=prior_function,
                                                                                  generator_function=generator_function,
                                                                                  generator_params=generator_params,
                                                                                  generator_optimizer=generator_optimizer,
                                                                                  init_param_dict=model_param_dicts)
    # energy update: same structure as the generator updater above
    energy_updater, energy_optimizer_params = set_energy_update_function(energy_feature_function=feature_function,
                                                                         # energy_norm_function=norm_function,
                                                                         energy_expert_function=expert_function,
                                                                         # energy_prior_function=prior_function,
                                                                         generator_function=generator_function,
                                                                         energy_params=energy_params,
                                                                         energy_optimizer=energy_optimizer,
                                                                         init_param_dict=model_param_dicts)
    print '%.2f SEC '%(time()-t)
    print 'COMPILING SAMPLING FUNCTION'
    t=time()
    sampling_function = set_sampling_function(generator_function=generator_function)
    print '%.2f SEC '%(time()-t)

    # set fixed hidden data for sampling
    # (kept constant so periodic sample grids are comparable across batches)
    fixed_hidden_data  = floatX(np_rng.uniform(low=-model_config_dict['hidden_distribution'],
                                               high=model_config_dict['hidden_distribution'],
                                               size=(model_config_dict['num_display'], model_config_dict['hidden_size'])))

    print 'START TRAINING'
    # for each epoch
    input_energy_list = []
    sample_energy_list = []
    batch_count = 0
    for e in xrange(model_config_dict['epochs']):
        # train phase
        batch_iters = data_stream.get_epoch_iterator()
        # for each batch
        for b, batch_data in enumerate(batch_iters):
            # batch count up
            batch_count += 1
            # fast-forward through batches already trained before the restart
            if batch_count<last_batch_idx:
                continue

            # set update function inputs
            input_data   = transform(batch_data[0])
            num_data     = input_data.shape[0]
            # hidden codes ~ Uniform(-d, d), d = hidden_distribution
            hidden_data  = floatX(np_rng.uniform(low=-model_config_dict['hidden_distribution'],
                                                 high=model_config_dict['hidden_distribution'],
                                                 size=(num_data, model_config_dict['hidden_size'])))

            # generator step: small Gaussian noise is added to samples;
            # outputs [1] and [2] are per-sample entropy weights/costs
            noise_data      = floatX(np_rng.normal(scale=0.01, size=input_data.shape))
            update_input    = [hidden_data, noise_data]
            update_output   = generator_updater(*update_input)
            entropy_weights = update_output[1].mean()
            entropy_cost    = update_output[2].mean()

            # energy step: fresh noise, real + hidden inputs;
            # outputs [0]/[1] are mean energies of real and generated data
            noise_data      = floatX(np_rng.normal(scale=0.01, size=input_data.shape))
            update_input    = [input_data, hidden_data, noise_data]
            update_output   = energy_updater(*update_input)
            input_energy    = update_output[0].mean()
            sample_energy   = update_output[1].mean()

            input_energy_list.append(input_energy)
            sample_energy_list.append(sample_energy)

            # progress report every 10 batches
            if batch_count%10==0:
                print '================================================================'
                print 'BATCH ITER #{}'.format(batch_count), model_test_name
                print '================================================================'
                print '   TRAIN RESULTS'
                print '================================================================'
                print '     input energy     : ', input_energy_list[-1]
                print '----------------------------------------------------------------'
                print '     sample energy    : ', sample_energy_list[-1]
                print '----------------------------------------------------------------'
                print '     entropy weight   : ', entropy_weights
                print '----------------------------------------------------------------'
                print '     entropy cost     : ', entropy_cost
                print '================================================================'

            # checkpoint every 100 batches: sample grid, energy curves, model
            if batch_count%100==0:
                # sample data (fixed codes -> comparable grids over time)
                sample_data = sampling_function(fixed_hidden_data)[0]
                sample_data = np.asarray(sample_data)
                save_as = samples_dir + '/' + model_test_name + '_SAMPLES{}.png'.format(batch_count)
                # NCHW -> NHWC for image saving
                color_grid_vis(inverse_transform(sample_data).transpose([0,2,3,1]), (16, 16), save_as)
                np.save(file=samples_dir + '/' + model_test_name +'_input_energy',
                        arr=np.asarray(input_energy_list))
                np.save(file=samples_dir + '/' + model_test_name +'_sample_energy',
                        arr=np.asarray(sample_energy_list))

                # model + optimizer state in one pickle, overwritten each time
                save_as = samples_dir + '/' + model_test_name + '_MODEL.pkl'
                save_model(tensor_params_list=generator_params[0] + generator_params[1] + energy_params + generator_optimizer_params + energy_optimizer_params,
                           save_to=save_as)