Exemplo n.º 1
0
def abstractToMatlab(event_list, output_file):
    """Write an event list to a MATLAB ``.mat`` file.

    The layout is taken from ``config_data["matlab"]["default"]`` when the
    MATLAB config is enabled, otherwise it is asked interactively. It must be
    one of the first three entries of ``cte.MATLAB_TYPES_ADMITTED``:

    * ``[0]`` -- one struct whose fields hold x, y, pol and ts;
    * ``[1]`` -- one n x 4 matrix whose columns are x, y, pol, ts;
    * ``[2]`` -- four separate variables, one per event parameter.

    :param event_list: iterable of events exposing ``x``, ``y``, ``pol`` and
        ``ts`` attributes.
    :param output_file: path of the ``.mat`` file to write.
    :raises ValueError: if the selected struct type has no writer.
    """
    ui = UI().objectUI
    ui.showMessage("Starting to write mat file", "w")
    file_data = {}
    c = Config()

    if c.matlab:
        struct_type = c.config_data["matlab"]["default"]
    else:
        struct_type = ui.chooseWindow(
            "Which struct type do you prefer? ", cte.MATLAB_TYPES_ADMITTED)

    # Only the first three admitted types have a writer below. Anything else
    # used to fall through every branch and silently save an empty .mat file.
    supported_types = tuple(cte.MATLAB_TYPES_ADMITTED[:3])
    if struct_type not in supported_types:
        raise ValueError(
            "Unsupported MATLAB struct type {!r}; expected one of {!r}".format(
                struct_type, supported_types))

    # Column-wise copies of the event fields, all converted with double().
    # Timestamps go through secsToNsecs first.
    arrays = [[], [], [], []]
    for ev in event_list:
        arrays[0].append(double(ev.x))
        arrays[1].append(double(ev.y))
        arrays[2].append(double(ev.pol))
        arrays[3].append(double(secsToNsecs(ev.ts)))

    # 1 struct
    if struct_type == supported_types[0]:
        if c.matlab:
            struct_name = c.config_data["matlab"]["1 struct"]["struct_name"]
            data_names = c.config_data["matlab"]["1 struct"]["names"]
        else:
            struct_name = ui.simpleInput(
                "What is the name of the struct?: ")
            data_names = ui.multiInputsWindow(
                "How are the names of these parameters",
                cte.MATLAB_STRUCT_NAMES)

        file_data[struct_name] = {}
        for arr, name in zip(arrays, data_names):
            file_data[struct_name][name] = arr

    # Matrix nx4
    elif struct_type == supported_types[1]:
        if c.matlab:
            struct_name = c.config_data["matlab"]["Matrix nx4"]["struct_name"]
        else:
            struct_name = ui.simpleInput(
                "What is the name of the struct?: ")

        file_data[struct_name] = column_stack(
            (arrays[0], arrays[1], arrays[2], arrays[3]))

    # 4 structs (one for each event's parameter)
    elif struct_type == supported_types[2]:
        if c.matlab:
            struct_names = c.config_data["matlab"]["4 structs"]["names"]
        else:
            struct_names = ui.multiInputsWindow(
                "How are the names of these structs", cte.MATLAB_STRUCT_NAMES)

        for arr, name in zip(arrays, struct_names):
            file_data[name] = arr

    ui.showMessage(
        "Starting to save the data (no progress bar available)", "w")
    sio.savemat(output_file, file_data, oned_as="column")
    ui.showMessage("Finishing writing the mat file", "c")
Exemplo n.º 2
0
def abstractToRosbag(event_list, output_file):
    """Write an event list to a ROS bag file.

    Each event becomes one ``_Event`` message (x, y, polarity, ts) written
    to a single topic. The topic comes from ``config_data["rosbag"]["topic"]``
    when the rosbag config is enabled, otherwise it is asked interactively.

    :param event_list: sized sequence of events exposing ``x``, ``y``,
        ``pol`` and ``ts`` (seconds) attributes.
    :param output_file: path of the bag file to create.
    """
    ui = UI().objectUI
    ui.showMessage("Starting to write bag file", "w")
    bag = rosbag.Bag(output_file, "w")
    # Always close the bag, even if the prompt or a write fails, so the file
    # handle is not leaked.
    try:
        c = Config()

        if c.rosbag:
            topic = c.config_data["rosbag"]["topic"]
        else:
            topic = ui.simpleInput(
                "Introduce the name of the topic where the events are going to be write: "
            )

        num_progress = getNumProgress(len(event_list))
        for i, event in enumerate(event_list):
            # Advance the progress bar once every num_progress events.
            if i % num_progress == 0:
                ui.sumProgress()

            e = _Event()
            e.x = event.x
            e.y = event.y
            e.polarity = event.pol
            e.ts = rospy.Time.from_sec(event.ts)

            bag.write(topic, e)

        ui.sumProgress(True)
    finally:
        bag.close()
    ui.showMessage("Finishing writing the bag file", "c")
Exemplo n.º 3
0
 def __init__(self):
     """Read the DB connection settings from Config and open a Mongo client."""
     settings = Config()
     self._config = settings
     self._db_server = settings.get_db_server()
     self._db_port = settings.get_db_port()
     self._db_database = settings.get_db_database()
     # The client is built from the host/port read just above.
     self._client = MongoClient(self._db_server, self._db_port)
     self._date_helper = DateHelper()
Exemplo n.º 4
0
def create_two_layer_dense_model():
    """Build a TwoLayerDenseModel from the two-layer dense attack test config."""
    model_config = Config(
        config_dict=None,
        standard_key_list=TwoLayerDenseModel.standard_key_list,
    )
    json_path = CONFIG_PATH + '/model/testTwoLayerDenseAttackConfig.json'
    model_config.load_config(path=json_path)
    return TwoLayerDenseModel(config=model_config)
Exemplo n.º 5
0
    def __init__(self):
        """Create the helpers and the Trello card controller used by the crawler."""
        logger.info("Initializing Trello crawler")

        factory = FactoryController()
        self._config = Config()
        self._request_helper = RequestHelper()
        self._date_helper = DateHelper()

        # The card controller is obtained from the factory created above.
        self._trello_card_controller = factory.get_trello_card()
Exemplo n.º 6
0
def main():
    """Set up a terminal UI and the default config, then run ``testTypes``
    for an aedat -> aedat conversion."""
    UI("terminal")
    Config("src/config/config.json")

    # Single-file alternative:
    # testFileAndType("data/aedat/aedat4/Cars_sequence.aedat4", "aedat")
    testTypes("aedat", "aedat")
Exemplo n.º 7
0
def create_data():
    """Return a FISData instance loaded from the FIS attack test config.

    The JSON config's ``FILE_PATH`` is overridden to point at
    ``DATASET1_PATH/Attack.csv`` before the data is loaded.
    """
    data_conf = Config(standard_key_list=FISData.standard_key_list,
                       config_dict=None)
    raw_json = utils.load_json(
        file_path=CONFIG_PATH + '/data/testFisAttackDataConfig.json')
    raw_json['FILE_PATH'] = DATASET1_PATH + '/Attack.csv'
    data_conf.load_config(path=None, json_dict=raw_json)

    fis_data = FISData(config=data_conf)
    fis_data.load_data()
    return fis_data
Exemplo n.º 8
0
 def __init__(self, config, data=None):
     """Initialise model state.

     :param config: model Config. A falsy value is replaced by an empty
         Config (no standard keys, empty dict) after the base class has
         been initialised with the original value.
     :param data: optional dataset attached to the model.
     """
     super(Model, self).__init__(config)
     self.config = config
     if not self.config:
         self.config = Config(standard_key_list=[], config_dict={})
     self.data = data
     # Graph handles, filled in later by subclasses.
     self.input = self.delta_state_output = None
     # Snapshot bookkeeping: variables and the ops that save/restore them.
     self.snapshot_var, self.save_snapshot_op, self.load_snapshot_op = [], [], []
Exemplo n.º 9
0
def abstractToAedat(event_list, output_file):
    """Write an event list to an AEDAT file of the selected version.

    The version comes from ``config_data["aedat"]["version"]`` when the
    AEDAT config is enabled, otherwise it is asked interactively. The first
    three entries of ``cte.AEDAT_ACCEPTED_VERSIONS`` dispatch to
    ``abstractToAedat2``, ``abstractToAedat3`` and ``abstractToAedat4``
    respectively.

    :param event_list: events to write.
    :param output_file: path of the file to create.
    :raises ValueError: if the selected version has no writer. Previously
        such a version was silently ignored and no file was produced.
    """
    c = Config()

    if c.aedat:
        version = c.config_data["aedat"]["version"]
    else:
        version = UI().objectUI.chooseWindow("Choose a version: ",
                                             cte.AEDAT_ACCEPTED_VERSIONS)

    accepted = cte.AEDAT_ACCEPTED_VERSIONS
    if version == accepted[0]:
        abstractToAedat2(event_list, output_file)
    elif version == accepted[1]:
        abstractToAedat3(event_list, output_file)
    elif version == accepted[2]:
        abstractToAedat4(event_list, output_file)
    else:
        raise ValueError(
            "Unsupported AEDAT version {!r}; expected one of {!r}".format(
                version, tuple(accepted[:3])))
Exemplo n.º 10
0
    def __init__(self, db_name: str = 'data_main'):
        """Constructor for Mongo()

        Builds a ``pymongo.MongoClient`` from ``Config().MONGO_URI`` and
        selects the ``db_name`` database on it.

        :param db_name: default database overwrite, defaults to 'data_main'
        :type db_name: str, optional
        """
        import re

        config = Config()

        client = pymongo.MongoClient(config.MONGO_URI)
        # Never write credentials to the logs: mask the password part of a
        # "scheme://user:password@host" URI before logging it.
        safe_uri = re.sub(r"(://[^:@/]*:)[^@]*@", r"\1****@",
                          str(config.MONGO_URI))
        logger.info("Connected to MongoDB with %s", safe_uri)
        db = client.get_database(db_name)
        logger.info("Connected to database %s", db_name)

        self._client = client
        self._db = db
Exemplo n.º 11
0
    def __init__(self, cost_fn=None, config=None):
        """Set up environment-status codes, sample buffers and the cost function.

        :param cost_fn: optional cost function stored on the instance.
        :param config: optional Config. When the base class leaves
            ``self.config`` as None, a Config with the three
            ``*_ENVIRONMENT_STATUS`` keys is created.
        """
        super().__init__(config=config)
        if self.config is None:
            self.config = Config(standard_key_list=[
                'REAL_ENVIRONMENT_STATUS',
                'CYBER_ENVIRONMENT_STATUS',
                'TEST_ENVIRONMENT_STATUS',
            ])
        # Fixed numeric codes for each environment status.
        status_codes = (
            ('REAL_ENVIRONMENT_STATUS', 1),
            ('CYBER_ENVIRONMENT_STATUS', 0),
            ('TEST_ENVIRONMENT_STATUS', 2),
        )
        for status_key, code in status_codes:
            self.config.config_dict[status_key] = code

        self._test_data, self._cyber_data, self._real_data = (
            SamplerData(), SamplerData(), SamplerData())
        self.cost_fn = cost_fn

        self.data = None

        self._env_status = None

        # Start in the real environment (code 1 above).
        self.env_status = 1
Exemplo n.º 12
0
def main():
    """CLI entry point: parse arguments, set up config and UI, open the main window."""
    (input_file, output_file, input_type, output_type,
     use_config, config_path, ui_type) = parseArguments()

    # Load the config file only when it was requested.
    if use_config:
        Config(config_path)

    # Instantiate the requested front-end; other values are left alone.
    if ui_type in ("graphic", "terminal"):
        UI(ui_type)

    try:
        UI().objectUI.initialWindow(convert, input_file, output_file,
                                    input_type, output_type, use_config,
                                    config_path)
    except Exception as e:
        # Top-level boundary: report any failure through the active UI.
        UI().objectUI.errorWindow(e)
Exemplo n.º 13
0
def rosbagToAbstract(input_file):
    """Read events from a ROS bag file.

    The events topic comes from ``config_data["rosbag"]["topic"]`` when the
    rosbag config is enabled, otherwise it is chosen interactively from the
    bag's topics. Messages whose type name contains "EventArray" are expanded
    into their ``events``. Any other message is treated as a single event.

    :param input_file: path of the bag file to read.
    :returns: list of ``Event`` objects (x, y, polarity, combined timestamp).
    """
    ui = UI().objectUI
    ui.showMessage("Starting to read bag file", "w")
    bag = rosbag.Bag(input_file)
    # Close the bag even if reading fails part-way.
    try:
        c = Config()

        if c.rosbag:
            topic = c.config_data["rosbag"]["topic"]
        else:
            topics = bag.get_type_and_topic_info().topics
            topic = ui.chooseWindow(
                "Which is the topic that contains the events?: ", topics)

        event_list = []

        num_progress = getNumProgress(bag.get_message_count(topic))
        for i, (_, msg, _) in enumerate(bag.read_messages(topics=topic)):
            # Advance the progress bar once every num_progress messages.
            if i % num_progress == 0:
                ui.sumProgress()

            if "EventArray" in str(type(msg)):  # msg._type
                aux_list = msg.events
            else:
                aux_list = [msg]

            for event in aux_list:
                event_list.append(
                    Event(event.x, event.y, event.polarity,
                          combine(event.ts.secs, event.ts.nsecs)))
    finally:
        bag.close()

    ui.sumProgress(True)
    ui.showMessage("Finishing reading the bag file", "c")
    return event_list
Exemplo n.º 14
0
    def initialWindow(self,
                      convert,
                      input_file="",
                      output_file=cte.OUTPUT_DEFAULT,
                      input_type="",
                      output_type="",
                      use_config=False,
                      config_path=cte.CONFIG_PATH):
        """Show the main converter window and run its event loop.

        The form is pre-filled from the arguments. Empty input/output types
        default to the extension of the corresponding file. The loop runs
        until the window is closed. CONVERT validates the form, optionally
        loads the selected config file, then calls
        ``convert(input_file, output_file, input_type, output_type)``.

        :param convert: callable performing the conversion.
        :param input_file: initial input path.
        :param output_file: initial output path.
        :param input_type: initial input type (derived from ``input_file``
            when empty).
        :param output_type: initial output type (derived from
            ``output_file`` when empty).
        :param use_config: initial state of the "use config file" radio.
        :param config_path: initial config file path.
        """
        input_type_aux = input_type if input_type != "" else getExtension(
            input_file)
        output_type_aux = output_type if output_type != "" else getExtension(
            output_file)

        input_layout = [
            [
                sg.Text("Input path",
                        size=(12, 1),
                        text_color="white",
                        background_color=grey_green),
                sg.In(size=(35, 1),
                      key="INPUT FILE",
                      enable_events=True,
                      default_text=input_file),
                sg.FileBrowse()
            ],
            [
                sg.Text("Detected input type",
                        size=(12, 1),
                        text_color="white",
                        background_color=grey_green),
                sg.Combo(cte.ADMITTED_TYPES,
                         key="INPUT TYPE",
                         default_value=input_type_aux,
                         size=(10, 1))
            ],
        ]
        output_layout = [
            [
                sg.Text("Output path",
                        size=(12, 1),
                        text_color="white",
                        background_color=grey_green),
                sg.In(size=(35, 1),
                      key="OUTPUT FILE",
                      enable_events=True,
                      default_text=output_file),
                sg.FileBrowse()
            ],
            [
                sg.Text("Output type",
                        size=(12, 1),
                        text_color="white",
                        background_color=grey_green),
                sg.Combo(cte.ADMITTED_TYPES,
                         key="OUTPUT TYPE",
                         default_value=output_type_aux,
                         size=(10, 1),
                         enable_events=True)
            ],
        ]

        config_layout = [
            [
                sg.Text("Config file path",
                        size=(12, 1),
                        text_color="white",
                        background_color=grey_green),
                sg.In(size=(35, 1),
                      key="CONFIG PATH",
                      enable_events=True,
                      default_text=config_path),
                sg.FileBrowse(),
            ],
            [
                sg.Radio('Use config file',
                         'CONFIG RADIO',
                         key="USE CONFIG",
                         default=use_config,
                         background_color=grey_green),
                sg.Radio('Do not use config file',
                         'CONFIG RADIO',
                         key="NOT USE CONFIG",
                         default=not use_config,
                         background_color=grey_green)
            ],
        ]

        buttons_layout = [[sg.Button("RESET"), sg.Button("CONVERT")]]

        layout = [[
            input_layout, [[sg.HSeparator()]], output_layout,
            [[sg.HSeparator()]], config_layout, [[sg.HSeparator()]],
            buttons_layout
        ]]

        self.window = sg.Window("Event Converter",
                                layout,
                                size=(1000, 500),
                                margins=(30, 35),
                                resizable=True,
                                font=(font_family, font_size),
                                background_color=grey_green,
                                button_color=dark_green)

        while True:
            event, values = self.window.read()
            if event == "Exit" or event == sg.WIN_CLOSED:
                break

            elif event == "RESET":
                self.window.Element('INPUT FILE').Update(value="")
                self.window.Element('INPUT TYPE').Update(value="")
                self.window.Element('OUTPUT FILE').Update(value="")
                self.window.Element('OUTPUT TYPE').Update(value="")
                self.window.Element('CONFIG PATH').Update(value="")
                self.window.Element('NOT USE CONFIG').Update(value=True)

            elif event == "INPUT FILE":
                # Auto-detect the input type from a recognised extension.
                ext = getExtension(values["INPUT FILE"])
                if ext in cte.ADMITTED_TYPES:
                    self.window.Element('INPUT TYPE').Update(value=ext)

            elif event == "OUTPUT FILE":
                # Auto-detect the output type from a recognised extension.
                ext = getExtension(values["OUTPUT FILE"])
                if ext in cte.ADMITTED_TYPES:
                    self.window.Element('OUTPUT TYPE').Update(value=ext)

            elif event == "OUTPUT TYPE":
                # Rewrite the output path's extension to the chosen type.
                ext = values["OUTPUT TYPE"]
                out_path = values["OUTPUT FILE"]
                # With no output path there is nothing to rewrite. The old
                # code indexed out_path[-1] here and crashed with IndexError
                # (e.g. right after RESET).
                if out_path:
                    tam = len(getExtension(out_path))
                    if tam > 0:
                        out_path = out_path[:-tam]
                    if not out_path.endswith("."):
                        out_path += "."
                    out_path += ext
                    self.window.Element('OUTPUT FILE').Update(value=out_path)

            elif event == "CONVERT":
                i_f = values["INPUT FILE"]
                i_t = values["INPUT TYPE"]
                o_f = values["OUTPUT FILE"]
                o_t = values["OUTPUT TYPE"]
                u_c = values["USE CONFIG"]
                c_p = values["CONFIG PATH"]

                if not os.path.isfile(i_f):
                    sg.popup_error("The input file does not exists")
                elif u_c and not os.path.isfile(c_p):
                    sg.popup_error("The config file does not exists")
                elif i_t not in cte.ADMITTED_TYPES:
                    sg.popup_error("The input type is not supported")
                elif o_t not in cte.ADMITTED_TYPES:
                    sg.popup_error("The output type is not supported")
                else:
                    if u_c:
                        Config(c_p)
                    convert(i_f, o_f, i_t, o_t)
                    sg.popup('CONVERSION FINISHED!!!')

        self.window.close()
Exemplo n.º 15
0
                                '/fixedOutputModelKey.json')

    def __init__(self, config):
        """Initialise the fixed-output model; all setup happens in the base class."""
        super().__init__(config)

    def predict(self, sess=None, state=None):
        """Return a constant action built from the config; ``sess`` and ``state`` are ignored.

        The action is a zero vector of length ``ACTION_SPACE[0]`` whose first
        three slots hold ``F1``, ``PROB_SAMPLE_ON_REAL`` and
        ``PROB_TRAIN_ON_REAL``.
        """
        settings = self.config.config_dict
        fixed_action = [0] * settings['ACTION_SPACE'][0]
        fixed_action[0] = settings['F1']
        fixed_action[1] = settings['PROB_SAMPLE_ON_REAL']
        fixed_action[2] = settings['PROB_TRAIN_ON_REAL']
        return np.array(fixed_action)

    def reset(self):
        """No-op: a fixed-output model has no state to reset."""

    def update(self, *args, **kwargs):
        """No-op: a fixed-output model does not learn; all arguments are ignored."""

    def print_log_queue(self, status):
        """No-op: this model keeps no log queue to print."""


if __name__ == '__main__':
    from conf import CONFIG

    # Smoke test: build the model from its test config and print one action.
    model_conf = Config(standard_key_list=FixedOutputModel.key_list)
    model_conf.load_config(
        path=CONFIG + '/baselineTrainerModelTestConfig.json')
    model = FixedOutputModel(config=model_conf)
    print(model.predict(state=1))
Exemplo n.º 16
0
        return loss, optimizer

    def update(self, sess, q_label, state, action):
        """Run one training step in ``sess`` and return ``(loss, gradients)``.

        Feeds the Q targets, states and actions, with ``is_training`` set
        to True, and evaluates the loss, the optimisation op and the
        gradients together.
        """
        fetch_list = [self.loss, self.optimize_loss, self.gradients]
        feeds = {
            self.q_label: q_label,
            self.state: state,
            self.action: action,
            self.is_training: True,
        }
        loss_value, _, gradient_values = sess.run(fetches=fetch_list,
                                                  feed_dict=feeds)
        return loss_value, gradient_values

    def eval_tensor(self):
        """Not supported by this critic.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError


if __name__ == '__main__':
    from src.config.config import Config
    from configuration import CONFIG_PATH

    # Smoke test: build the critic, initialise its variables, print its params.
    critic_conf = Config(standard_key_list=DenseCritic.standard_key_list)
    critic_conf.load_config(path=CONFIG_PATH + '/testCriticConfig.json')
    critic = DenseCritic(config=critic_conf)
    with tf.Session() as sess, sess.as_default():
        tl.layers.initialize_global_variables(sess)
        critic.net.print_params()
        sess = tf.get_default_session()
        sess.run(self.variables_initializer)
        self.ddpg_model.actor_optimizer.sync()
        self.ddpg_model.critic_optimizer.sync()
        sess.run(self.ddpg_model.target_init_updates)
        self.env_status = self.config.config_dict['REAL_ENVIRONMENT_STATUS']
        super().init()

    def store_one_sample(self, state, next_state, action, reward, done, *arg, **kwargs):
        """Append one transition to ``self.memory``; extra arguments are ignored."""
        transition = dict(
            obs0=state,
            obs1=next_state,
            action=action,
            reward=reward,
            terminal1=done,
        )
        self.memory.append(**transition)


if __name__ == '__main__':
    from config import CONFIG

    # Smoke test: init the DDPG model, round-trip a snapshot and a checkpoint.
    ddpg_conf = Config(standard_key_list=DDPGModel.key_list)
    ddpg_conf.load_config(path=CONFIG + '/targetModelTestConfig.json')

    model = DDPGModel(config=ddpg_conf)
    session = tf.Session()
    with session.as_default():
        model.init()
        model.load_snapshot()
        model.save_snapshot()
        model.save_model(path='/home/linsen/.tmp/ddpg-model.ckpt', global_step=1)
        model.load_model(file='/home/linsen/.tmp/ddpg-model.ckpt-1')
Exemplo n.º 18
0
 def __init__(self, *args, **kwargs):
     """Load ``~/tower.yml`` into ``self.config``, then initialise the base class."""
     # os.environ['HOME'] raises KeyError where HOME is unset (e.g. on
     # Windows or under some service managers). expanduser uses HOME when
     # it is set and otherwise falls back to the platform's own lookup.
     self.config = Config(os.path.join(os.path.expanduser('~'), 'tower.yml'))
     super().__init__(*args, **kwargs)
Exemplo n.º 19
0
        if env_status == self.config.config_dict['REAL_ENVIRONMENT_STATUS']:
            memory = self.real_data_memory
        elif env_status == self.config.config_dict['CYBER_ENVIRONMENT_STATUS']:
            memory = self.simulation_data_memory
        else:
            raise ValueError('Wrong Environment status')

        length = memory.nb_entries
        return length >= sample_count


if __name__ == '__main__':
    # Manual smoke test for TrpoModel: build it from a JSON config, then set
    # its environment status to the cyber environment inside a session.

    from src.config.config import Config

    con = Config(standard_key_list=TrpoModel.key_list)
    # NOTE(review): absolute, machine-specific path; this only works on the
    # original author's machine.
    con.load_config(
        path='/home/dls/CAP/intelligenttrainerframework/conf/modelNetworkConfig/targetModelTestConfig.json')
    a = TrpoModel(config=con, action_bound=([-1], [1]), obs_bound=([-1], [1]))
    import tensorflow as tf
    import os

    # Expose only GPU 0 to TensorFlow. NOTE(review): this is set after the
    # model is constructed; confirm that is early enough to take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    # allow_growth: allocate GPU memory on demand rather than up front.
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(config=tf_config)
    with sess.as_default():
        a.init()
        # NOTE(review): the loop looks truncated in this excerpt; as shown it
        # only re-assigns the same status 500 times.
        for i in range(500):
            a.env_status = a.config.config_dict['CYBER_ENVIRONMENT_STATUS']
Exemplo n.º 20
0
            n_units=self.config.config_dict['LSTM_DENSE_LAYER_2_UNIT'],
            act=tf.nn.tanh,
            name=name_prefix + 'LSTM_DENSE_LAYER_2')

        net = tl.layers.DenseLayer(
            layer=lstm_fc2,
            n_units=self.config.config_dict['MERGED_LAYER_1_UNIT'],
            act=tf.nn.relu,
            name=name_prefix + 'MERGED_DENSE_LAYER_1')
        net = tl.layers.DenseLayer(
            layer=net,
            n_units=self.config.config_dict['MERGED_LAYER_2_UNIT'],
            act=tf.nn.relu,
            name=name_prefix + 'MERGED_DENSE_LAYER_2')
        return net


if __name__ == '__main__':
    from src.config.config import Config
    from configuration import CONFIG_PATH

    # Smoke test: build the LSTM critic, initialise variables, print params.
    critic_conf = Config(config_dict=None,
                         standard_key_list=LSTMCritic.standard_key_list)
    critic_conf.load_config(path=CONFIG_PATH + '/testLSTMCriticconfig.json')
    critic = LSTMCritic(config=critic_conf)
    with tf.Session() as sess, sess.as_default():
        tl.layers.initialize_global_variables(sess)
        critic.net.print_params()
def load_config(key_list, config_path):
    """Create a Config for ``key_list`` and populate it from ``config_path``."""
    loaded = Config(standard_key_list=key_list)
    loaded.load_config(path=config_path)
    return loaded
Exemplo n.º 22
0
 def __init__(self):
     """Read the JWT secret and algorithm from Config and set up a DateHelper."""
     logger.info("Initialize JWTHelper")
     settings = Config()
     self._secret = settings.get_jwt_secret()
     self._algorithm = settings.get_jwt_algorithm()
     self._date_helper = DateHelper()
Exemplo n.º 23
0
def load_config(key_list, config_path, update_dict):
    """Load a Config from ``config_path``, then apply ``update_dict`` to it.

    ``update_dict`` is passed to ``update_config_dict``, which presumably
    overrides the loaded entries (helper not visible here).
    """
    loaded = Config(standard_key_list=key_list)
    loaded.load_config(path=config_path)
    update_config_dict(config=loaded, dict=update_dict)
    return loaded
Exemplo n.º 24
0
        if attachments:
            if not isinstance(attachments, list):
                attachments = [attachments]
            self.logger.info("Found {} attachment. Processing".format(len(attachments)))
            for attachment in attachments:
                self.logger.info("Attaching \"{}\"".format(attachment))
                self.add_attachment(msg, attachment)
                self.logger.debug("Attached \"{}\"".format(attachment))
        return msg

    def send(self, msg, max_retries=3):
        """Send ``msg`` through the mail server.

        On ``SMTPServerDisconnected`` the client reconnects and tries again,
        at most ``max_retries`` times. The previous recursive retry had no
        limit, so a server that kept dropping the connection ended in a
        ``RecursionError``.

        :param msg: message passed to ``self.mailserver.send_message``.
        :param max_retries: number of reconnect-and-retry attempts.
        :raises SMTPServerDisconnected: if sending still fails after
            ``max_retries`` reconnections.
        """
        attempts = 0
        while True:
            try:
                self.mailserver.send_message(msg)
            except SMTPServerDisconnected:
                if attempts >= max_retries:
                    raise
                attempts += 1
                self.logger.warning("Mail server disconnected. Reconnecting.")
                self.connect()
            else:
                self.logger.info("Mail sent to the {} recipients".format(len(self.C["mail.recipients"])))
                return


if __name__ == '__main__':
    from src.config.config import Config

    # Manual smoke test: compose a mail with two attachments and send it.
    mail_config = Config("mail.yaml")
    client = MailClient(config=mail_config)
    test_mail = client.compose_mail(
        "test mail",
        "this is a test mail. \n Please ignore the content",
        attachments=["attachments/1.txt", "attachments/2.txt"],
    )
    client.send(test_mail)
Exemplo n.º 25
0
        self.data_list = temp_data
        self.state_list = []
        self.output_list = []
        for sample in self.data_list:
            self.state_list.append(sample['STATE'])
            self.output_list.append(sample['OUTPUT'])
        pass
        self.state_list = np.array(self.state_list)
        self.output_list = np.array(self.output_list)

    def return_batch_data(self, index, size):
        """Return the ``index``-th mini-batch of ``size`` samples.

        :returns: ``(states, outputs)``, the slices
            ``[index * size, index * size + size)`` of ``self.state_list``
            and ``self.output_list``. The last batch is shorter when the data
            does not divide evenly.
        """
        start = index * size
        stop = start + size
        return self.state_list[start:stop], self.output_list[start:stop]


if __name__ == '__main__':
    from src.config.config import Config
    from src.configuration.standard_key_list import CONFIG_STANDARD_KEY_LIST
    from src.configuration import CONFIG_PATH

    # Smoke test: load the FIS attack data, overriding its CSV path, and shuffle it.
    data_conf = Config(standard_key_list=FISData.standard_key_list, config_dict=None)
    raw_config = utils.load_json(file_path=CONFIG_PATH + '/testFisAttackDataConfig.json')
    raw_config['FILE_PATH'] = DATASET1_PATH + '/Attack.csv'
    data_conf.load_config(path=None, json_dict=raw_config)

    fis_data = FISData(config=data_conf)
    fis_data.load_data()
    fis_data.shuffle_data()




Exemplo n.º 26
0
            self.torcs_vision_list.put(ob.img)
        img = self.return_latest_state_from_vision_list()
        return img, reward, done

    def reset(self, relauch=False):
        """Reset the TORCS environment and refill the vision buffer.

        The buffer is cleared and then padded with the reset observation's
        image until it holds ``VISION_DATA_LENGTH`` frames.

        :param relauch: forwarded to the environment as ``relaunch``.
        :returns: the stacked state from ``return_latest_state_from_vision_list``.
        """
        first_obs = self.torcs_env.reset(relaunch=relauch)
        frames = self.torcs_vision_list
        frames.clear()
        target_len = self.config.config_dict['VISION_DATA_LENGTH']
        while len(frames) < target_len:
            frames.put(first_obs.img)
        return self.return_latest_state_from_vision_list()

    def end(self):
        """Shut down the TORCS environment, then drop all buffered frames."""
        env, frames = self.torcs_env, self.torcs_vision_list
        env.end()
        frames.clear()

    def return_latest_state_from_vision_list(self):
        """Return the most recent frames from the vision buffer.

        Calls ``slice_queue`` with ``right = len(buffer) - 1`` and
        ``left = right - VISION_DATA_LENGTH``. Whether these bounds are
        inclusive is decided by ``slice_queue`` (not visible here).
        """
        last_index = len(self.torcs_vision_list) - 1
        first_index = last_index - self.config.config_dict['VISION_DATA_LENGTH']
        return slice_queue(q=self.torcs_vision_list,
                           left=first_index,
                           right=last_index)


if __name__ == '__main__':
    from src.config.config import Config
    from configuration import CONFIG_PATH

    # Smoke test: construct the TORCS environment from its test config.
    env_conf = Config(standard_key_list=TorcsEnvironment.standard_key_list)
    env_conf.load_config(path=CONFIG_PATH + '/testTorcsEnvironmentConfig.json')
    env = TorcsEnvironment(config=env_conf)
Exemplo n.º 27
0
import logging

from flask.json import jsonify

from src.config.config import Config

# Module-level configuration instance, created once at import time.
config = Config()

# Default strftime-style format strings.
# NOTE(review): the datetime format separates minutes and seconds with "."
# ("%H:%M.%S"), not ":"; confirm this is intended before relying on it.
DEFAULT_DATE_FORMAT: str = "%Y-%m-%d"
DEFAULT_DATETIME_FORMAT: str = "%Y-%m-%d %H:%M.%S"


def get_logger(name: str, file_name=None) -> logging.Logger:
    """Get Logger with proper config

    :param name: The name to use, usually __name__
    :type name: str
    :return: The Logger
    :rtype: Logger
    """

    # if config.ENV.lower() == 'local':
    #     # If local, do not write to file
    #     logging.basicConfig(level=config.LOGGING_LEVEL)
    # else:
    #     # If not local, write to file
    #     logging.basicConfig(filename='main.log', level=config.LOGGING_LEVEL)

    if file_name:
        logging.basicConfig(level=config.LOGGING_LEVEL,
                            filename=file_name,
Exemplo n.º 28
0
        # TODO HOW TO ELEGANT CHANGE THIS
        if self.model and hasattr(self.model, 'print_log_queue') and callable(
                self.model.print_log_queue):
            self.model.print_log_queue(status=status)

    def reset(self):
        """Reset the agent's base-class state, then reset its model."""
        super().reset()
        model = self.model
        model.reset()

    def get_trpo_step_count(self):
        """Return the TRPO model's ``step_count``, or None for any other model type."""
        # Imported lazily, presumably to avoid a circular import at load time.
        from src.model.trpoModel.trpoModel import TrpoModel
        if isinstance(self.model, TrpoModel):
            return self.model.step_count
        return None


if __name__ == '__main__':
    from conf import CONFIG
    from src.model.ddpgModel.ddpgModel import DDPGModel

    # Smoke test: build a DDPG model and wrap it in a TargetAgent.
    agent_conf = Config(standard_key_list=TargetAgent.key_list)
    agent_conf.load_config(path=CONFIG + '/ddpgAgentTestConfig.json')

    ddpg_conf = Config(standard_key_list=DDPGModel.key_list)
    ddpg_conf.load_config(path=CONFIG + '/targetModelTestConfig.json')

    ddpg = DDPGModel(config=ddpg_conf)

    agent = TargetAgent(config=agent_conf, real_env=2, cyber_env=1, model=ddpg)
Exemplo n.º 29
0
            n_units=self.config.config_dict['DENSE_LAYER_2_UNIT'],
            act=tf.nn.leaky_relu,
            name=name_prefix + 'DENSE_LAYER_2')

        net = tl.layers.DenseLayer(
            layer=net,
            n_units=self.config.config_dict['OUTPUT_DIM'],
            act=tf.nn.softmax,
            name=name_prefix + 'OUTPUT_LAYER')

        return net

    def create_training_method(self):
        """Build the training objective and optimizer.

        :returns: ``(loss, optimizer)``. ``loss`` is the mean squared error
            between ``self.label`` and the network outputs. ``optimizer`` is
            an Adam optimizer using ``LEARNING_RATE`` and is not yet applied
            to the loss.
        """
        # An L2 weight-decay term over self.var_list existed previously and is
        # currently disabled.
        squared_error = tf.square(self.label - self.net.outputs)
        mse_loss = tf.reduce_mean(squared_error)
        adam = tf.train.AdamOptimizer(self.config.config_dict['LEARNING_RATE'])
        return mse_loss, adam


if __name__ == '__main__':
    from src.config.config import Config
    from src.configuration import CONFIG_PATH

    # Smoke test: construct the dense model from its test config.
    model_conf = Config(config_dict=None,
                        standard_key_list=DenseModel.standard_key_list)
    model_conf.load_config(path=CONFIG_PATH + '/testDenseConfig.json')
    actor = DenseModel(config=model_conf)
                           self.action_means: self.action_scalar.means,
                           self.output_means: self.delta_scalar.means,
                           self.output_vars: np.sqrt(self.delta_scalar.vars)
                       })

        return utl.squeeze_array(res,
                                 dim=1 +
                                 len(self.config.config_dict['STATE_SPACE']))

    def init(self):
        """Run this model's variable initializer in the default session, then the base init."""
        default_session = tf.get_default_session()
        default_session.run(self.variables_initializer)
        super().init()


if __name__ == '__main__':
    from conf import CONFIG

    # Smoke test: init the model, round-trip a snapshot and a checkpoint.
    model_conf = Config(standard_key_list=DynamicsEnvMlpModel.key_list)
    model_conf.load_config(path=CONFIG + '/dynamicsEnvMlpModelTestConfig.json')
    model = DynamicsEnvMlpModel(config=model_conf)
    # Dummy input batches; the calls below do not use them.
    state_input = np.zeros(shape=[10, 20])
    action_input = np.zeros(shape=[10, 6])
    session = tf.Session()
    with session.as_default():
        model.init()
        model.load_snapshot()
        model.save_snapshot()
        model.save_model(path='/home/linsen/.tmp/model.ckpt', global_step=1)
        model.load_model(file='/home/linsen/.tmp/model.ckpt-1')