예제 #1
0
    def gather_data(self, step, t_error, gann, cman: CustomCaseManager):
        """Record training/validation error on sampling steps.

        A step is sampled when it is a multiple of ``self.interval`` or when it
        is the final training step (``step + 1 == gann.options.steps``). On a
        sampled step:

        * the training error is appended via ``self.append_error`` to
          ``self.t_err``;
        * if the case manager has any validation cases, the network's error on
          them (``gann.do_testing``) is appended to ``self.v_err``;
        * ``self.updated`` is set to True.

        All other steps are ignored.

        :param step: Current training step (the ``step + 1`` final-step check
            suggests it is 0-based — confirm against the caller).
        :param t_error: Training error measured at ``step``.
        :param gann: Network object; ``gann.options.steps`` and
            ``gann.do_testing`` are used.
        :param cman: Case manager supplying the validation cases.
        """
        # `or` short-circuits, so gann.options.steps is only read off-interval.
        if not ((step % self.interval == 0) or (step + 1 == gann.options.steps)):
            return

        self.append_error(step, t_error, self.t_err)  # Standard training error

        # Fetch once so the emptiness check and the evaluation use the same list.
        validation_cases = cman.get_validation_cases()
        if len(validation_cases) > 0:
            self.append_error(step, gann.do_testing(validation_cases),
                              self.v_err)  # Validation error

        self.updated = True
예제 #2
0
    def get_glass_options(session_tracker):
        """Assemble the training ``Options`` for the Glass data set.

        Loads ``DataSets.Glass()``, wraps the cases in a ``CustomCaseManager``
        (``cfrac=1``, ``vfrac=0.1``, ``tfrac=0.1``) and returns an ``Options``
        object for a ReLU-hidden / softmax-output network optimised with Adam
        (learning rate 0.001), 1000 steps, minibatches of 75.

        :param session_tracker: Forwarded unchanged as the ``session_tracker``
            option.
        :return: The ``Options`` instance used throughout the program.
        """
        glass_cases = DataSets.Glass()
        learning_rate = 0.001
        case_manager = CustomCaseManager(glass_cases, None, cfrac=1, vfrac=0.1, tfrac=0.1)

        # NOTE(review): tf.losses.softmax_cross_entropy applies softmax to its
        # `logits` argument itself. If the framework feeds it the output of
        # o_activation_function (already softmax), softmax is applied twice.
        # Confirm how 'cost_function' is invoked.
        return Options(
            # [80, 80] is the hidden-layer specification handed to calc_net_dims.
            net_dims=calc_net_dims(glass_cases[0], [80, 80]),
            h_activation_function=tf.nn.relu,
            o_activation_function=tf.nn.softmax,
            cost_function=tf.losses.softmax_cross_entropy,
            learning_rate=learning_rate,
            weight_range=[-.5, .5],
            optimizer=tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999),
            case_manager=case_manager,
            minibatch_size=75,
            steps=1000,
            vint=10,
            session_tracker=session_tracker,
            # MAP options: map cases come from the case manager's testing set.
            map_case_count=10,
            map_case_func=case_manager.get_testing_cases,
        )
예제 #3
0
    def get_bit_counter_options(session_tracker):
        """Assemble the training ``Options`` for the bit-counter data set.

        Builds ``DataSets.Bit_Counter(500, 15)``, wraps the cases in a
        ``CustomCaseManager`` (``cfrac=1``, ``vfrac=0.1``, ``tfrac=0.1``) and
        returns an ``Options`` object for a ReLU-hidden / softmax-output
        network trained on mean squared error with Adagrad (learning rate
        0.1), 10000 steps, minibatches of 100.

        :param session_tracker: Forwarded unchanged as the ``session_tracker``
            option.
        :return: The ``Options`` instance used throughout the program.
        """
        counter_cases = DataSets.Bit_Counter(500, 15)
        learning_rate = 0.1
        case_manager = CustomCaseManager(counter_cases, None, cfrac=1, vfrac=0.1, tfrac=0.1)

        return Options(
            # [80, 80] is the hidden-layer specification handed to calc_net_dims.
            net_dims=calc_net_dims(counter_cases[0], [80, 80]),
            h_activation_function=tf.nn.relu,
            o_activation_function=tf.nn.softmax,
            cost_function=tf.losses.mean_squared_error,
            learning_rate=learning_rate,
            weight_range=[-.2, .2],
            # Adagrad's second positional parameter is initial_accumulator_value
            # (TF default 0.1); it is set explicitly to 0.001 here.
            optimizer=tf.train.AdagradOptimizer(learning_rate, initial_accumulator_value=0.001),
            case_manager=case_manager,
            minibatch_size=100,
            steps=10000,
            vint=100,
            session_tracker=session_tracker,
            # MAP options: map cases come from the case manager's testing set.
            map_case_count=10,
            map_case_func=case_manager.get_testing_cases,
        )