예제 #1
0
def train():
    """Build the environment, policy and normalizer, then run training.

    The policy input size is the first dimension of the observation space,
    unless ``hp.conv_input`` is set. In that case observations go through
    ``Normalizer.image_cnn``, so the size is measured by pushing a dummy
    frame through that CNN and taking the length of its output. When
    ``hp.train_from_previous_weights`` is set, saved policy and normalizer
    state are loaded before training starts.
    """
    hp = Hp()
    np.random.seed(hp.seed)
    env = OpenAIGym(hp.env_name)
    obs_shape = env.observation_space.shape

    if hp.conv_input:
        # All-ones uint8 frame spanning the first three observation dims
        # (presumably H, W, C -- unconfirmed); only len() of the CNN output
        # is used, so the pixel values themselves are irrelevant here.
        probe_frame = np.ones(
            [obs_shape[0], obs_shape[1], obs_shape[2]], dtype=np.uint8)
        # A throwaway normalizer built with 0 inputs, used only to reach
        # image_cnn.
        probe_normalizer = Normalizer(0, hp)
        nb_inputs = len(probe_normalizer.image_cnn(probe_frame))
    else:
        nb_inputs = obs_shape[0]

    nb_outputs = env.action_space.shape[0]
    policy = Policy(nb_inputs, nb_outputs, hp)
    normalizer = Normalizer(nb_inputs, hp)

    if hp.train_from_previous_weights:
        # Resume: restore the policy first, then the normalizer.
        policy.load()
        normalizer.load()

    Run(env, policy, normalizer, hp).train()
예제 #2
0
파일: evaluate.py 프로젝트: tie304/ARS-CNN
def evaluate(n_steps):
    """Rebuild the trained agent from saved weights and evaluate it.

    The input size is derived exactly as in training: the first
    observation dimension, or the length of the CNN output when
    ``hp.conv_input`` is set. The normalizer and policy weights are then
    loaded, and ``Run.evaluate`` is invoked.

    :param n_steps: Passed straight through to ``Run.evaluate``.
    """
    hp = Hp()
    np.random.seed(hp.seed)
    env = OpenAIGym(hp.env_name)
    obs_shape = env.observation_space.shape

    if hp.conv_input:
        # Dummy uint8 frame covering the first three observation dims; it
        # exists only so we can measure len() of the CNN's output.
        probe_frame = np.ones(
            [obs_shape[0], obs_shape[1], obs_shape[2]], dtype=np.uint8)
        probe_normalizer = Normalizer(0, hp)
        nb_inputs = len(probe_normalizer.image_cnn(probe_frame))
    else:
        nb_inputs = obs_shape[0]

    nb_outputs = env.action_space.shape[0]
    policy = Policy(nb_inputs, nb_outputs, hp)
    normalizer = Normalizer(nb_inputs, hp)

    # Restore trained state: normalizer first, then policy.
    normalizer.load()
    policy.load()

    Run(env, policy, normalizer, hp).evaluate(n_steps)
예제 #3
0
 def __init__(self):
     """Set up the top-level ACLSwitch shell.

     Installs ``self.signal_handler`` for SIGINT (Ctrl+C), runs the
     ``cmd.Cmd`` initialiser, sets the banner and prompt, and creates the
     policy and ACL helpers, each constructed with this shell and
     ``self._URL_ACLSW``.
     """
     signal.signal(signal.SIGINT, self.signal_handler)
     cmd.Cmd.__init__(self)
     self.intro = ("Welcome to the ACLSwitch command line "
                   "interface.\nType help or ? to list the "
                   "available commands.\n")
     self.prompt = "(ACLSwitch) "
     endpoint = self._URL_ACLSW
     self._policy = Policy(self, endpoint)
     self._acl = ACL(self, endpoint)
예제 #4
0
class ACLSwitchCLI(cmd.Cmd):
    """An interactive Command Line Interface (CLI) for ACLSwitch.

    The ``acl`` and ``policy`` commands hand control to helper objects via
    their ``cmdloop()`` method; ``status`` queries the ACLSwitch HTTP
    endpoint at ``_URL_ACLSW``; ``exit`` (or Ctrl+C) terminates the process.
    """

    MSG_ERR_ACLSW_CON = "ERROR: Unable to establish a connection with " \
                        "ACLSwitch."
    MSG_ERR_ACLSW_CON_LOST = "ERROR: Connection with ACLSwitch lost."
    _URL_ACLSW = "http://127.0.0.1:8080/aclswitch"
    _VERSION = "1.0.1"
    # Seconds to wait on ACLSwitch before giving up. Without an explicit
    # timeout, ``requests`` may block indefinitely, and the Timeout handler
    # in _fetch_status would never fire.
    _REQUEST_TIMEOUT = 10

    def __init__(self):
        """Initialise the main interface.

        Installs a SIGINT handler, initialises ``cmd.Cmd``, sets the banner
        and prompt, and builds the policy and ACL helpers.
        """
        # Register a handler for catching Ctrl+c
        signal.signal(signal.SIGINT, self.signal_handler)
        # Create and initialise CLI objects
        cmd.Cmd.__init__(self)
        self.intro = "Welcome to the ACLSwitch command line " \
                     "interface.\nType help or ? to list the " \
                     "available commands.\n"
        self.prompt = "(ACLSwitch) "
        self._policy = Policy(self, self._URL_ACLSW)
        self._acl = ACL(self, self._URL_ACLSW)

    # NOTE: the do_* docstrings below are shown verbatim by ``help <cmd>``
    # in cmd.Cmd, so they are user-facing text and deliberately unchanged.

    def do_acl(self, arg):
        """Present the user with different options to modify rules.
        """
        self._acl.cmdloop()

    def do_policy(self, arg):
        """Present the user with different options to modify policy domains.
        """
        self._policy.cmdloop()

    def do_status(self, arg):
        """Fetch some basic information from ACLSwitch.
        """
        info = self._fetch_status()
        if info is None:
            # _fetch_status has already printed the reason.
            return
        print("ACLSwitch CLI version: {0}".format(self._VERSION))
        print("ACLSwitch version: {0}".format(info["version"]))
        print("Number of ACL rules: {0}".format(info["num_rules"]))
        print("Number of policy domains: {0}".format(
            info["num_policies"]))
        print("Number of registered switches: {0}".format(
            info["num_switches"]))

    def do_exit(self, arg):
        """Close the program.
        """
        self._close_program()

    def _fetch_status(self):
        """Fetch some basic status information from ACLSwitch.

        Every failure (network error, timeout, non-200 status, body that is
        not valid JSON) prints a message and yields ``None``.

        :return: Information in a dict, None if error.
        """
        print("Fetching status information...")
        # Timeout is handled before ConnectionError on purpose:
        # requests.ConnectTimeout subclasses both, and a timeout should be
        # reported as a timeout rather than as a generic connection error.
        try:
            resp = requests.get(self._URL_ACLSW,
                                timeout=self._REQUEST_TIMEOUT)
        except requests.Timeout as err:
            print(cli_util.MSG_TIMEOUT + str(err))
            return None
        except requests.ConnectionError as err:
            print(cli_util.MSG_CON_ERR + str(err))
            return None
        except requests.HTTPError as err:
            # Kept defensively; requests.get itself does not raise this
            # unless raise_for_status() is called.
            print(cli_util.MSG_HTTP_ERR + str(err))
            return None
        except requests.TooManyRedirects as err:
            print(cli_util.MSG_REDIRECT_ERR + str(err))
            return None
        if resp.status_code != 200:
            print("Error fetching resource, HTTP {0} "
                  "returned.".format(resp.status_code))
            return None
        # A body that does not decode as JSON is also an error; honour the
        # "None if error" contract instead of letting ValueError crash the
        # CLI loop.
        try:
            return resp.json()
        except ValueError as err:
            print("Error decoding status information from ACLSwitch: "
                  "{0}".format(err))
            return None

    def signal_handler(self, sig, frame):
        """SIGINT callback; signal number and frame are ignored."""
        self._close_program()

    def _close_program(self):
        """Print a line break and terminate via ``sys.exit(0)``."""
        print("\n")
        sys.exit(0)