def __init__(self, train_X=None, valid_X=None, test_X=None, args=None, logger=None):
    """Configure the GSN from ``args`` and compile its Theano functions.

    Parameters
    ----------
    train_X, valid_X, test_X : optional
        Datasets, normalised with ``raise_data_to_list`` and stored on the
        instance. When ``train_X`` is given, its first dataset determines
        the input dimensionality.
    args : dict, optional
        Hyper-parameters; most keys fall back to the module-level
        ``defaults`` dict. ``None`` is treated as an empty dict.
    logger : optional
        Forwarded to ``log.maybeLog`` for progress messages.

    Raises
    ------
    AssertionError
        If neither ``train_X`` nor ``args["input_size"]`` is provided.
    """
    if args is None:
        # Every setting below is read with args.get(); an empty dict means
        # "use the defaults" instead of failing on None.get().
        args = {}
    # Output logger
    self.logger = logger
    self.outdir = args.get("output_path", defaults["output_path"])
    if self.outdir[-1] != '/':
        self.outdir = self.outdir + '/'
    # Input data - make sure it is a list of shared datasets
    self.train_X = raise_data_to_list(train_X)
    self.valid_X = raise_data_to_list(valid_X)
    self.test_X  = raise_data_to_list(test_X)

    # variables from the dataset that are used for initialization and image reconstruction
    if train_X is None:
        self.N_input = args.get("input_size")
        if self.N_input is None:
            raise AssertionError("Please either specify input_size in the arguments or provide an example train_X for input dimensionality.")
    else:
        # Use the list-normalised data: indexing a bare (non-list) dataset
        # with [0] would select its first row instead of the first dataset.
        # Evaluating only the shape avoids materialising the whole dataset.
        self.N_input = int(self.train_X[0].shape.eval()[1])
    self.root_N_input = numpy.sqrt(self.N_input)

    self.is_image = args.get('is_image', defaults['is_image'])
    if self.is_image:
        # numpy.sqrt returns a float; the image dimensions are later used as
        # a tile shape for tile_raster_images, so keep the fallback integral.
        self.image_width  = args.get('width', int(self.root_N_input))
        self.image_height = args.get('height', int(self.root_N_input))

    #######################################
    # Network and training specifications #
    #######################################
    self.layers          = args.get('layers', defaults['layers'])  # number hidden layers
    self.walkbacks       = args.get('walkbacks', defaults['walkbacks'])  # number of walkbacks
    self.learning_rate   = theano.shared(cast32(args.get('learning_rate', defaults['learning_rate'])))  # learning rate
    self.init_learn_rate = cast32(args.get('learning_rate', defaults['learning_rate']))
    self.momentum        = theano.shared(cast32(args.get('momentum', defaults['momentum'])))  # momentum term
    self.annealing       = cast32(args.get('annealing', defaults['annealing']))  # exponential annealing coefficient
    self.noise_annealing = cast32(args.get('noise_annealing', defaults['noise_annealing']))  # exponential noise annealing coefficient
    self.batch_size      = args.get('batch_size', defaults['batch_size'])
    self.n_epoch         = args.get('n_epoch', defaults['n_epoch'])
    self.early_stop_threshold = args.get('early_stop_threshold', defaults['early_stop_threshold'])
    self.early_stop_length = args.get('early_stop_length', defaults['early_stop_length'])
    self.save_frequency  = args.get('save_frequency', defaults['save_frequency'])

    # Noise settings are shared variables so train() can anneal them in place.
    self.noiseless_h1           = args.get('noiseless_h1', defaults["noiseless_h1"])
    self.hidden_add_noise_sigma = theano.shared(cast32(args.get('hidden_add_noise_sigma', defaults["hidden_add_noise_sigma"])))
    self.input_salt_and_pepper  = theano.shared(cast32(args.get('input_salt_and_pepper', defaults["input_salt_and_pepper"])))
    self.input_sampling         = args.get('input_sampling', defaults["input_sampling"])
    self.vis_init               = args.get('vis_init', defaults['vis_init'])

    # layer sizes, from h0 to hK (h0 is the visible layer)
    self.layer_sizes = [self.N_input] + [args.get('hidden_size', defaults['hidden_size'])] * self.layers

    self.f_recon = None
    self.f_noise = None

    # Activation functions: an explicit callable wins over a named one,
    # which wins over the default.
    if args.get('hidden_activation') is not None:
        log.maybeLog(self.logger, 'Using specified activation for hiddens')
        self.hidden_activation = args.get('hidden_activation')
    elif args.get('hidden_act') is not None:
        self.hidden_activation = get_activation_function(args.get('hidden_act'))
        log.maybeLog(self.logger, 'Using {0!s} activation for hiddens'.format(args.get('hidden_act')))
    else:
        log.maybeLog(self.logger, "Using default activation for hiddens")
        self.hidden_activation = defaults['hidden_activation']

    # Visible layer activation
    if args.get('visible_activation') is not None:
        log.maybeLog(self.logger, 'Using specified activation for visible layer')
        self.visible_activation = args.get('visible_activation')
    elif args.get('visible_act') is not None:
        self.visible_activation = get_activation_function(args.get('visible_act'))
        log.maybeLog(self.logger, 'Using {0!s} activation for visible layer'.format(args.get('visible_act')))
    else:
        log.maybeLog(self.logger, 'Using default activation for visible layer')
        self.visible_activation = defaults['visible_activation']

    # Cost function!
    if args.get('cost_function') is not None:
        log.maybeLog(self.logger, '\nUsing specified cost function for training\n')
        self.cost_function = args.get('cost_function')
    elif args.get('cost_funct') is not None:
        self.cost_function = get_cost_function(args.get('cost_funct'))
        log.maybeLog(self.logger, 'Using {0!s} for cost function'.format(args.get('cost_funct')))
    else:
        log.maybeLog(self.logger, '\nUsing default cost function for training\n')
        self.cost_function = defaults['cost_function']

    ############################
    # Theano variables and RNG #
    ############################
    self.X   = T.fmatrix('X')  # for use in sampling
    self.MRG = RNG_MRG.MRG_RandomStreams(1)
    rng.seed(1)

    ###############
    # Parameters! #
    ###############
    # initialize a list of weights and biases based on layer_sizes for the GSN
    if args.get('weights_list') is None:
        # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
        self.weights_list = [get_shared_weights(self.layer_sizes[layer], self.layer_sizes[layer + 1], name="W_{0!s}_{1!s}".format(layer, layer + 1)) for layer in range(self.layers)]
    else:
        self.weights_list = args.get('weights_list')
    if args.get('bias_list') is None:
        # initialize each layer to 0's.
        self.bias_list = [get_shared_bias(self.layer_sizes[layer], name='b_' + str(layer)) for layer in range(self.layers + 1)]
    else:
        self.bias_list = args.get('bias_list')
    self.params = self.weights_list + self.bias_list

    #################
    # Build the GSN #
    #################
    log.maybeLog(self.logger, "\nBuilding GSN graphs for training and testing")
    # GSN for training - with noise
    add_noise = True
    p_X_chain, _ = build_gsn(self.X,
                             self.weights_list,
                             self.bias_list,
                             add_noise,
                             self.noiseless_h1,
                             self.hidden_add_noise_sigma,
                             self.input_salt_and_pepper,
                             self.input_sampling,
                             self.MRG,
                             self.visible_activation,
                             self.hidden_activation,
                             self.walkbacks,
                             self.logger)

    # GSN for reconstruction checks along the way - no noise
    add_noise = False
    p_X_chain_recon, _ = build_gsn(self.X,
                                   self.weights_list,
                                   self.bias_list,
                                   add_noise,
                                   self.noiseless_h1,
                                   self.hidden_add_noise_sigma,
                                   self.input_salt_and_pepper,
                                   self.input_sampling,
                                   self.MRG,
                                   self.visible_activation,
                                   self.hidden_activation,
                                   self.walkbacks,
                                   self.logger)

    #######################
    # Costs and gradients #
    #######################
    log.maybeLog(self.logger, 'Cost w.r.t p(X|...) at every step in the graph for the GSN')
    gsn_costs     = [self.cost_function(rX, self.X) for rX in p_X_chain]
    show_gsn_cost = gsn_costs[-1]  # for logging to show progress
    gsn_cost      = numpy.sum(gsn_costs)  # training objective: cost summed over every walkback step

    gsn_costs_recon     = [self.cost_function(rX, self.X) for rX in p_X_chain_recon]
    show_gsn_cost_recon = gsn_costs_recon[-1]

    log.maybeLog(self.logger, ["gsn params:", self.params])

    # Stochastic gradient descent with momentum: each buffer holds the
    # running (momentum-weighted) gradient for its parameter.
    gradient        = T.grad(gsn_cost, self.params)
    gradient_buffer = [theano.shared(numpy.zeros(param.get_value().shape, dtype='float32')) for param in self.params]
    m_gradient      = [self.momentum * gb + (cast32(1) - self.momentum) * g for (gb, g) in zip(gradient_buffer, gradient)]
    param_updates   = [(param, param - self.learning_rate * mg) for (param, mg) in zip(self.params, m_gradient)]
    # zip() is a lazy iterator in Python 3; it must be a list to concatenate.
    gradient_buffer_updates = list(zip(gradient_buffer, m_gradient))
    updates         = OrderedDict(param_updates + gradient_buffer_updates)

    ############
    # Sampling #
    ############
    # the input to the sampling function
    X_sample = T.fmatrix("X_sampling")
    self.network_state_input = [X_sample] + [T.fmatrix("H_sampling_" + str(i + 1)) for i in range(self.layers)]

    # "Output" state of the network (noisy)
    # initialized with input, then we apply updates
    self.network_state_output = [X_sample] + self.network_state_input[1:]
    visible_pX_chain = []

    # ONE update
    log.maybeLog(self.logger, "Performing one walkback in network state sampling.")
    update_layers(self.network_state_output,
                  self.weights_list,
                  self.bias_list,
                  visible_pX_chain,
                  True,
                  self.noiseless_h1,
                  self.hidden_add_noise_sigma,
                  self.input_salt_and_pepper,
                  self.input_sampling,
                  self.MRG,
                  self.visible_activation,
                  self.hidden_activation,
                  self.logger)

    #################################
    #     Create the functions      #
    #################################
    log.maybeLog(self.logger, "Compiling functions...")
    t = time.time()

    self.f_learn = theano.function(inputs=[self.X],
                                   updates=updates,
                                   outputs=show_gsn_cost,
                                   name='gsn_f_learn')

    self.f_cost = theano.function(inputs=[self.X],
                                  outputs=show_gsn_cost,
                                  name='gsn_f_cost')

    # used for checkpoints and testing - no noise in network
    self.f_recon = theano.function(inputs=[self.X],
                                   outputs=[show_gsn_cost_recon, p_X_chain_recon[-1]],
                                   name='gsn_f_recon')

    self.f_noise = theano.function(inputs=[self.X],
                                   outputs=salt_and_pepper(self.X, self.input_salt_and_pepper),
                                   name='gsn_f_noise')

    if self.layers == 1:
        self.f_sample = theano.function(inputs=[X_sample],
                                        outputs=visible_pX_chain[-1],
                                        name='gsn_f_sample_single_layer')
    else:
        # Theano warns about unused inputs here: per the original author,
        # the first odd-layer states are not read, since they are recomputed
        # directly from the even layers. Hence on_unused_input='warn'.
        self.f_sample = theano.function(inputs=self.network_state_input,
                                        outputs=self.network_state_output + visible_pX_chain,
                                        on_unused_input='warn',
                                        name='gsn_f_sample')

    log.maybeLog(self.logger, "Compiling done. Took " + make_time_units_string(time.time() - t) + ".\n")
Example #2
0
    def __init__(self,
                 train_X=None,
                 valid_X=None,
                 test_X=None,
                 args=None,
                 logger=None):
        """Configure the GSN from ``args`` and compile its Theano functions.

        Parameters
        ----------
        train_X, valid_X, test_X : optional
            Datasets, normalised with ``raise_data_to_list`` and stored on
            the instance. When ``train_X`` is given, its first dataset
            determines the input dimensionality.
        args : dict, optional
            Hyper-parameters; most keys fall back to the module-level
            ``defaults`` dict. ``None`` is treated as an empty dict.
        logger : optional
            Forwarded to ``log.maybeLog`` for progress messages.

        Raises
        ------
        AssertionError
            If neither ``train_X`` nor ``args["input_size"]`` is provided.
        """
        if args is None:
            # Every setting below is read with args.get(); an empty dict
            # means "use the defaults" instead of failing on None.get().
            args = {}
        # Output logger
        self.logger = logger
        self.outdir = args.get("output_path", defaults["output_path"])
        if self.outdir[-1] != '/':
            self.outdir = self.outdir + '/'
        # Input data - make sure it is a list of shared datasets
        self.train_X = raise_data_to_list(train_X)
        self.valid_X = raise_data_to_list(valid_X)
        self.test_X = raise_data_to_list(test_X)

        # variables from the dataset that are used for initialization and image reconstruction
        if train_X is None:
            self.N_input = args.get("input_size")
            if self.N_input is None:
                raise AssertionError(
                    "Please either specify input_size in the arguments or provide an example train_X for input dimensionality."
                )
        else:
            # Use the list-normalised data: indexing a bare (non-list)
            # dataset with [0] would select its first row instead of the
            # first dataset. Evaluating only the shape avoids materialising
            # the whole dataset.
            self.N_input = int(self.train_X[0].shape.eval()[1])
        self.root_N_input = numpy.sqrt(self.N_input)

        self.is_image = args.get('is_image', defaults['is_image'])
        if self.is_image:
            # numpy.sqrt returns a float; the image dimensions are later used
            # as a tile shape for tile_raster_images, so keep them integral.
            self.image_width = args.get('width', int(self.root_N_input))
            self.image_height = args.get('height', int(self.root_N_input))

        #######################################
        # Network and training specifications #
        #######################################
        self.layers = args.get('layers',
                               defaults['layers'])  # number hidden layers
        self.walkbacks = args.get('walkbacks',
                                  defaults['walkbacks'])  # number of walkbacks
        self.learning_rate = theano.shared(
            cast32(args.get('learning_rate',
                            defaults['learning_rate'])))  # learning rate
        self.init_learn_rate = cast32(
            args.get('learning_rate', defaults['learning_rate']))
        self.momentum = theano.shared(
            cast32(args.get('momentum',
                            defaults['momentum'])))  # momentum term
        self.annealing = cast32(args.get(
            'annealing',
            defaults['annealing']))  # exponential annealing coefficient
        self.noise_annealing = cast32(
            args.get('noise_annealing', defaults['noise_annealing'])
        )  # exponential noise annealing coefficient
        self.batch_size = args.get('batch_size', defaults['batch_size'])
        self.n_epoch = args.get('n_epoch', defaults['n_epoch'])
        self.early_stop_threshold = args.get('early_stop_threshold',
                                             defaults['early_stop_threshold'])
        self.early_stop_length = args.get('early_stop_length',
                                          defaults['early_stop_length'])
        self.save_frequency = args.get('save_frequency',
                                       defaults['save_frequency'])

        # Noise settings are shared variables so train() can anneal them.
        self.noiseless_h1 = args.get('noiseless_h1', defaults["noiseless_h1"])
        self.hidden_add_noise_sigma = theano.shared(
            cast32(
                args.get('hidden_add_noise_sigma',
                         defaults["hidden_add_noise_sigma"])))
        self.input_salt_and_pepper = theano.shared(
            cast32(
                args.get('input_salt_and_pepper',
                         defaults["input_salt_and_pepper"])))
        self.input_sampling = args.get('input_sampling',
                                       defaults["input_sampling"])
        self.vis_init = args.get('vis_init', defaults['vis_init'])

        self.layer_sizes = [self.N_input] + [
            args.get('hidden_size', defaults['hidden_size'])
        ] * self.layers  # layer sizes, from h0 to hK (h0 is the visible layer)

        self.f_recon = None
        self.f_noise = None

        # Activation functions: an explicit callable wins over a named one,
        # which wins over the default.
        if args.get('hidden_activation') is not None:
            log.maybeLog(self.logger, 'Using specified activation for hiddens')
            self.hidden_activation = args.get('hidden_activation')
        elif args.get('hidden_act') is not None:
            self.hidden_activation = get_activation_function(
                args.get('hidden_act'))
            log.maybeLog(
                self.logger, 'Using {0!s} activation for hiddens'.format(
                    args.get('hidden_act')))
        else:
            log.maybeLog(self.logger, "Using default activation for hiddens")
            self.hidden_activation = defaults['hidden_activation']

        # Visible layer activation
        if args.get('visible_activation') is not None:
            log.maybeLog(self.logger,
                         'Using specified activation for visible layer')
            self.visible_activation = args.get('visible_activation')
        elif args.get('visible_act') is not None:
            self.visible_activation = get_activation_function(
                args.get('visible_act'))
            log.maybeLog(
                self.logger, 'Using {0!s} activation for visible layer'.format(
                    args.get('visible_act')))
        else:
            log.maybeLog(self.logger,
                         'Using default activation for visible layer')
            self.visible_activation = defaults['visible_activation']

        # Cost function!
        if args.get('cost_function') is not None:
            log.maybeLog(self.logger,
                         '\nUsing specified cost function for training\n')
            self.cost_function = args.get('cost_function')
        elif args.get('cost_funct') is not None:
            self.cost_function = get_cost_function(args.get('cost_funct'))
            log.maybeLog(
                self.logger,
                'Using {0!s} for cost function'.format(args.get('cost_funct')))
        else:
            log.maybeLog(self.logger,
                         '\nUsing default cost function for training\n')
            self.cost_function = defaults['cost_function']

        ############################
        # Theano variables and RNG #
        ############################
        self.X = T.fmatrix('X')  # for use in sampling
        self.MRG = RNG_MRG.MRG_RandomStreams(1)
        rng.seed(1)

        ###############
        # Parameters! #
        ###############
        # initialize a list of weights and biases based on layer_sizes for the GSN
        if args.get('weights_list') is None:
            self.weights_list = [
                get_shared_weights(self.layer_sizes[layer],
                                   self.layer_sizes[layer + 1],
                                   name="W_{0!s}_{1!s}".format(
                                       layer, layer + 1))
                for layer in range(self.layers)
            ]  # initialize each layer to uniform sample from sqrt(6. / (n_in + n_out))
        else:
            self.weights_list = args.get('weights_list')
        if args.get('bias_list') is None:
            self.bias_list = [
                get_shared_bias(self.layer_sizes[layer],
                                name='b_' + str(layer))
                for layer in range(self.layers + 1)
            ]  # initialize each layer to 0's.
        else:
            self.bias_list = args.get('bias_list')
        self.params = self.weights_list + self.bias_list

        #################
        # Build the GSN #
        #################
        log.maybeLog(self.logger,
                     "\nBuilding GSN graphs for training and testing")
        # GSN for training - with noise
        add_noise = True
        p_X_chain, _ = build_gsn(
            self.X, self.weights_list, self.bias_list, add_noise,
            self.noiseless_h1, self.hidden_add_noise_sigma,
            self.input_salt_and_pepper, self.input_sampling, self.MRG,
            self.visible_activation, self.hidden_activation, self.walkbacks,
            self.logger)

        # GSN for reconstruction checks along the way - no noise
        add_noise = False
        p_X_chain_recon, _ = build_gsn(
            self.X, self.weights_list, self.bias_list, add_noise,
            self.noiseless_h1, self.hidden_add_noise_sigma,
            self.input_salt_and_pepper, self.input_sampling, self.MRG,
            self.visible_activation, self.hidden_activation, self.walkbacks,
            self.logger)

        #######################
        # Costs and gradients #
        #######################
        log.maybeLog(
            self.logger,
            'Cost w.r.t p(X|...) at every step in the graph for the GSN')
        gsn_costs = [self.cost_function(rX, self.X) for rX in p_X_chain]
        show_gsn_cost = gsn_costs[-1]  # for logging to show progress
        gsn_cost = numpy.sum(gsn_costs)  # objective: summed over every walkback step

        gsn_costs_recon = [
            self.cost_function(rX, self.X) for rX in p_X_chain_recon
        ]
        show_gsn_cost_recon = gsn_costs_recon[-1]

        log.maybeLog(self.logger, ["gsn params:", self.params])

        # Stochastic gradient descent with momentum: each buffer holds the
        # running (momentum-weighted) gradient for its parameter.
        gradient = T.grad(gsn_cost, self.params)
        gradient_buffer = [
            theano.shared(numpy.zeros(param.get_value().shape,
                                      dtype='float32'))
            for param in self.params
        ]
        m_gradient = [
            self.momentum * gb + (cast32(1) - self.momentum) * g
            for (gb, g) in zip(gradient_buffer, gradient)
        ]
        param_updates = [(param, param - self.learning_rate * mg)
                         for (param, mg) in zip(self.params, m_gradient)]
        # zip() is a lazy iterator in Python 3; it must be a list to concatenate.
        gradient_buffer_updates = list(zip(gradient_buffer, m_gradient))
        updates = OrderedDict(param_updates + gradient_buffer_updates)

        ############
        # Sampling #
        ############
        # the input to the sampling function
        X_sample = T.fmatrix("X_sampling")
        self.network_state_input = [X_sample] + [
            T.fmatrix("H_sampling_" + str(i + 1)) for i in range(self.layers)
        ]

        # "Output" state of the network (noisy)
        # initialized with input, then we apply updates
        self.network_state_output = [X_sample] + self.network_state_input[1:]
        visible_pX_chain = []

        # ONE update
        log.maybeLog(self.logger,
                     "Performing one walkback in network state sampling.")
        update_layers(self.network_state_output, self.weights_list,
                      self.bias_list, visible_pX_chain, True,
                      self.noiseless_h1, self.hidden_add_noise_sigma,
                      self.input_salt_and_pepper, self.input_sampling,
                      self.MRG, self.visible_activation,
                      self.hidden_activation, self.logger)

        #################################
        #     Create the functions      #
        #################################
        log.maybeLog(self.logger, "Compiling functions...")
        t = time.time()

        self.f_learn = theano.function(inputs=[self.X],
                                       updates=updates,
                                       outputs=show_gsn_cost,
                                       name='gsn_f_learn')

        self.f_cost = theano.function(inputs=[self.X],
                                      outputs=show_gsn_cost,
                                      name='gsn_f_cost')

        # used for checkpoints and testing - no noise in network
        self.f_recon = theano.function(
            inputs=[self.X],
            outputs=[show_gsn_cost_recon, p_X_chain_recon[-1]],
            name='gsn_f_recon')

        self.f_noise = theano.function(inputs=[self.X],
                                       outputs=salt_and_pepper(
                                           self.X, self.input_salt_and_pepper),
                                       name='gsn_f_noise')

        if self.layers == 1:
            self.f_sample = theano.function(inputs=[X_sample],
                                            outputs=visible_pX_chain[-1],
                                            name='gsn_f_sample_single_layer')
        else:
            # Theano warns about unused inputs here: per the original author,
            # the first odd-layer states are not read, since they are
            # recomputed directly from the even layers. Hence 'warn'.
            self.f_sample = theano.function(inputs=self.network_state_input,
                                            outputs=self.network_state_output +
                                            visible_pX_chain,
                                            on_unused_input='warn',
                                            name='gsn_f_sample')

        log.maybeLog(
            self.logger, "Compiling done. Took " +
            make_time_units_string(time.time() - t) + ".\n")
 def train(self, train_X=None, valid_X=None, test_X=None, continue_training=False):
     """Train the GSN with early stopping, periodic checkpoints and annealing.

     Parameters
     ----------
     train_X, valid_X, test_X : optional
         Datasets to use; any that is None falls back to the one given to
         ``__init__``. All are normalised with ``raise_data_to_list``.
     continue_training : bool
         When False, the learning rate is first reset to its initial value.

     Training stops after ``n_epoch`` epochs, or after ``early_stop_length``
     consecutive epochs in which the cost (summed validation cost if there is
     a validation set, otherwise summed training cost) fails to drop below
     ``best_cost * early_stop_threshold``. The best parameters seen are then
     restored and written to ``outdir``.

     Raises
     ------
     AssertionError
         If no training data was given here or at initialization.
     """
     log.maybeLog(self.logger, "\nTraining---------\n")
     if train_X is None:
         log.maybeLog(self.logger, "Training using data given during initialization of GSN.\n")
         train_X = self.train_X
         if train_X is None:
             log.maybeLog(self.logger, "\nPlease provide a training dataset!\n")
             raise AssertionError("Please provide a training dataset!")
     else:
         log.maybeLog(self.logger, "Training using data provided to training function.\n")
     if valid_X is None:
         valid_X = self.valid_X
     if test_X is None:
         test_X = self.test_X

     # From here on every *_X is either None or a list of datasets.
     train_X = raise_data_to_list(train_X)
     valid_X = raise_data_to_list(valid_X)
     test_X = raise_data_to_list(test_X)

     ############
     # TRAINING #
     ############
     log.maybeLog(self.logger, "-----------TRAINING GSN FOR {0!s} EPOCHS-----------".format(self.n_epoch))
     STOP = False
     counter = 0
     if not continue_training:
         self.learning_rate.set_value(self.init_learn_rate)  # learning rate
     times = []
     best_cost = float('inf')
     best_params = None
     patience = 0

     log.maybeLog(self.logger, ['train X size:',str(train_X[0].shape.eval())])
     if valid_X is not None:
         log.maybeLog(self.logger, ['valid X size:',str(valid_X[0].shape.eval())])
     if test_X is not None:
         log.maybeLog(self.logger, ['test X size:',str(test_X[0].shape.eval())])

     if self.vis_init:
         # Set the visible bias to the logit of the per-feature training mean.
         # numpy.clip is clip(a, a_min, a_max): clip the means into
         # [0.001, 0.9] (presumably to keep logit away from 0 and 1).
         mean_input = numpy.clip(train_X[0].get_value().mean(axis=0), 0.001, 0.9)
         # Match the bias variable's dtype so Theano accepts the new value.
         self.bias_list[0].set_value(numpy.asarray(logit(mean_input), dtype=self.bias_list[0].dtype))

     while not STOP:
         counter += 1
         t = time.time()
         log.maybeAppend(self.logger, [counter,'\t'])

         #train
         train_costs = data.apply_cost_function_to_dataset(self.f_learn, train_X, self.batch_size)
         log.maybeAppend(self.logger, ['Train:',trunc(numpy.mean(train_costs)), '\t'])

         #valid
         if valid_X is not None:
             valid_costs = data.apply_cost_function_to_dataset(self.f_cost, valid_X, self.batch_size)
             log.maybeAppend(self.logger, ['Valid:',trunc(numpy.mean(valid_costs)), '\t'])

         #test
         if test_X is not None:
             test_costs = data.apply_cost_function_to_dataset(self.f_cost, test_X, self.batch_size)
             log.maybeAppend(self.logger, ['Test:',trunc(numpy.mean(test_costs)), '\t'])

         #check for early stopping
         if valid_X is not None:
             cost = numpy.sum(valid_costs)
         else:
             cost = numpy.sum(train_costs)
         if cost < best_cost * self.early_stop_threshold:
             patience = 0
             best_cost = cost
             # save the parameters that made it the best
             best_params = save_params(self.params)
         else:
             patience += 1

         if counter >= self.n_epoch or patience >= self.early_stop_length:
             STOP = True
             if best_params is not None:
                 restore_params(self.params, best_params)
             # The restored parameters are written to disk by the checkpoint
             # block below, which always runs once STOP is set.

         timing = time.time() - t
         times.append(timing)

         log.maybeAppend(self.logger, 'time: '+make_time_units_string(timing)+'\t')

         log.maybeLog(self.logger, 'remaining: '+make_time_units_string((self.n_epoch - counter) * numpy.mean(times)))

         if (counter % self.save_frequency) == 0 or STOP:
             if self.is_image:
                 n_examples = 100
                 # The datasets are lists here; reconstruct examples from the
                 # first test dataset, falling back to valid, then train data.
                 image_source = next(x for x in (test_X, valid_X, train_X) if x is not None)
                 tests = image_source[0].get_value()[0:n_examples]
                 noisy_tests = self.f_noise(tests)
                 _, reconstructed = self.f_recon(noisy_tests)
                 # Interleave rows of 10 originals / 10 noisy / 10 reconstructions.
                 stacked = numpy.vstack([numpy.vstack([tests[i*10 : (i+1)*10], noisy_tests[i*10 : (i+1)*10], reconstructed[i*10 : (i+1)*10]]) for i in range(10)])
                 number_reconstruction = PIL.Image.fromarray(tile_raster_images(stacked, (self.image_height,self.image_width), (10,30)))

                 number_reconstruction.save(self.outdir+'gsn_image_reconstruction_epoch_'+str(counter)+'.png')

             #save gsn_params
             save_params_to_file(counter, self.params, self.outdir, self.logger)

         # ANNEAL! Learning rate and both noise levels decay geometrically.
         new_lr = self.learning_rate.get_value() * self.annealing
         self.learning_rate.set_value(new_lr)

         new_hidden_sigma = self.hidden_add_noise_sigma.get_value() * self.noise_annealing
         self.hidden_add_noise_sigma.set_value(new_hidden_sigma)

         new_salt_pepper = self.input_salt_and_pepper.get_value() * self.noise_annealing
         self.input_salt_and_pepper.set_value(new_salt_pepper)
Example #4
0
    def train(self,
              train_X=None,
              valid_X=None,
              test_X=None,
              continue_training=False):
        """Train the GSN with early stopping, periodic checkpoints and annealing.

        Parameters
        ----------
        train_X, valid_X, test_X : optional
            Datasets to use; any that is None falls back to the one given to
            ``__init__``. All are normalised with ``raise_data_to_list``.
        continue_training : bool
            When False, the learning rate is first reset to its initial value.

        Training stops after ``n_epoch`` epochs, or after
        ``early_stop_length`` consecutive epochs in which the cost (summed
        validation cost if there is a validation set, otherwise summed
        training cost) fails to drop below ``best_cost * early_stop_threshold``.
        The best parameters seen are then restored and written to ``outdir``.

        Raises
        ------
        AssertionError
            If no training data was given here or at initialization.
        """
        log.maybeLog(self.logger, "\nTraining---------\n")
        if train_X is None:
            log.maybeLog(
                self.logger,
                "Training using data given during initialization of GSN.\n")
            train_X = self.train_X
            if train_X is None:
                log.maybeLog(self.logger,
                             "\nPlease provide a training dataset!\n")
                raise AssertionError("Please provide a training dataset!")
        else:
            log.maybeLog(
                self.logger,
                "Training using data provided to training function.\n")
        if valid_X is None:
            valid_X = self.valid_X
        if test_X is None:
            test_X = self.test_X

        # From here on every *_X is either None or a list of datasets.
        train_X = raise_data_to_list(train_X)
        valid_X = raise_data_to_list(valid_X)
        test_X = raise_data_to_list(test_X)

        ############
        # TRAINING #
        ############
        log.maybeLog(
            self.logger,
            "-----------TRAINING GSN FOR {0!s} EPOCHS-----------".format(
                self.n_epoch))
        STOP = False
        counter = 0
        if not continue_training:
            self.learning_rate.set_value(self.init_learn_rate)  # learning rate
        times = []
        best_cost = float('inf')
        best_params = None
        patience = 0

        log.maybeLog(
            self.logger,
            ['train X size:', str(train_X[0].shape.eval())])
        if valid_X is not None:
            log.maybeLog(self.logger,
                         ['valid X size:',
                          str(valid_X[0].shape.eval())])
        if test_X is not None:
            log.maybeLog(
                self.logger,
                ['test X size:', str(test_X[0].shape.eval())])

        if self.vis_init:
            # Set the visible bias to the logit of the per-feature training
            # mean. numpy.clip is clip(a, a_min, a_max): clip the means into
            # [0.001, 0.9] (presumably to keep logit away from 0 and 1).
            mean_input = numpy.clip(train_X[0].get_value().mean(axis=0),
                                    0.001, 0.9)
            # Match the bias variable's dtype so Theano accepts the new value.
            self.bias_list[0].set_value(
                numpy.asarray(logit(mean_input),
                              dtype=self.bias_list[0].dtype))

        while not STOP:
            counter += 1
            t = time.time()
            log.maybeAppend(self.logger, [counter, '\t'])

            #train
            train_costs = data.apply_cost_function_to_dataset(
                self.f_learn, train_X, self.batch_size)
            log.maybeAppend(
                self.logger,
                ['Train:', trunc(numpy.mean(train_costs)), '\t'])

            #valid
            if valid_X is not None:
                valid_costs = data.apply_cost_function_to_dataset(
                    self.f_cost, valid_X, self.batch_size)
                log.maybeAppend(
                    self.logger,
                    ['Valid:', trunc(numpy.mean(valid_costs)), '\t'])

            #test
            if test_X is not None:
                test_costs = data.apply_cost_function_to_dataset(
                    self.f_cost, test_X, self.batch_size)
                log.maybeAppend(
                    self.logger,
                    ['Test:', trunc(numpy.mean(test_costs)), '\t'])

            #check for early stopping
            if valid_X is not None:
                cost = numpy.sum(valid_costs)
            else:
                cost = numpy.sum(train_costs)
            if cost < best_cost * self.early_stop_threshold:
                patience = 0
                best_cost = cost
                # save the parameters that made it the best
                best_params = save_params(self.params)
            else:
                patience += 1

            if counter >= self.n_epoch or patience >= self.early_stop_length:
                STOP = True
                if best_params is not None:
                    restore_params(self.params, best_params)
                # The restored parameters are written to disk by the
                # checkpoint block below, which always runs once STOP is set.

            timing = time.time() - t
            times.append(timing)

            log.maybeAppend(self.logger,
                            'time: ' + make_time_units_string(timing) + '\t')

            log.maybeLog(
                self.logger, 'remaining: ' + make_time_units_string(
                    (self.n_epoch - counter) * numpy.mean(times)))

            if (counter % self.save_frequency) == 0 or STOP:
                if self.is_image:
                    n_examples = 100
                    # The datasets are lists here; reconstruct examples from
                    # the first test dataset, falling back to valid, then
                    # train data.
                    image_source = next(x for x in (test_X, valid_X, train_X)
                                        if x is not None)
                    tests = image_source[0].get_value()[0:n_examples]
                    noisy_tests = self.f_noise(tests)
                    _, reconstructed = self.f_recon(noisy_tests)
                    # Interleave rows of 10 originals / 10 noisy / 10 reconstructions.
                    stacked = numpy.vstack([
                        numpy.vstack([
                            tests[i * 10:(i + 1) * 10],
                            noisy_tests[i * 10:(i + 1) * 10],
                            reconstructed[i * 10:(i + 1) * 10]
                        ]) for i in range(10)
                    ])
                    number_reconstruction = PIL.Image.fromarray(
                        tile_raster_images(
                            stacked, (self.image_height, self.image_width),
                            (10, 30)))

                    number_reconstruction.save(
                        self.outdir + 'gsn_image_reconstruction_epoch_' +
                        str(counter) + '.png')

                #save gsn_params
                save_params_to_file(counter, self.params, self.outdir,
                                    self.logger)

            # ANNEAL! Learning rate and both noise levels decay geometrically.
            new_lr = self.learning_rate.get_value() * self.annealing
            self.learning_rate.set_value(new_lr)

            new_hidden_sigma = (self.hidden_add_noise_sigma.get_value() *
                                self.noise_annealing)
            self.hidden_add_noise_sigma.set_value(new_hidden_sigma)

            new_salt_pepper = (self.input_salt_and_pepper.get_value() *
                               self.noise_annealing)
            self.input_salt_and_pepper.set_value(new_salt_pepper)