def __init__(self, frame_shape=None, game_inputs=None):

        if frame_shape is None:
            raise SerpentError("A 'frame_shape' tuple kwarg is required...")

        states_spec = {"type": "float", "shape": frame_shape}

        if game_inputs is None:
            raise SerpentError("A 'game_inputs' dict kwarg is required...")

        self.game_inputs = game_inputs
        self.game_inputs_mapping = self._generate_game_inputs_mapping()

        actions_spec = {"type": "int", "num_actions": len(self.game_inputs)}

        network_spec = [
            {"type": "conv2d", "size": 32, "window": 8, "stride": 4},
            {"type": "conv2d", "size": 64, "window": 4, "stride": 2},
            {"type": "conv2d", "size": 64, "window": 3, "stride": 1},
            {"type": "flatten"},
            {"type": "dense", "size": 512}
        ]

        self.agent = PPOAgent(
            states_spec=states_spec,
            actions_spec=actions_spec,
            batched_observe=128,
            scope="ppo",
            summary_spec=None,
            network_spec=network_spec,
            device=None,
            session_config=None,
            saver_spec=None,
            distributed_spec=None,
            discount=0.99,
            variable_noise=None,
            states_preprocessing_spec=None,
            explorations_spec=None,
            reward_preprocessing_spec=None,
            distributions_spec=None,
            entropy_regularization=1e-2,
            batch_size=128,
            keep_last_timestep=True,
            baseline_mode=None,
            baseline=None,
            baseline_optimizer=None,
            gae_lambda=None,
            likelihood_ratio_clipping=None,
            step_optimizer=None,
            optimization_steps=10
        )
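
The '_generate_game_inputs_mapping' helper is called throughout these examples but never shown. A minimal sketch of what it plausibly does, assuming 'self.game_inputs' is a dict keyed by action label (only the method name is confirmed by the source; the actual Serpent.AI implementation may differ):

    def _generate_game_inputs_mapping(self):
        # Hypothetical sketch: map each discrete action index emitted by
        # the agent back to its game input label.
        mapping = dict()

        for index, key in enumerate(self.game_inputs):
            mapping[index] = key

        return mapping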
Example #2
    def combine_game_inputs(self, combination):
        """ Combine game input axes in a single flattened collection

        Args:
        combination [list] -- A combination of valid game input axis keys
        """

        # Validation
        if not isinstance(combination, list):
            raise SerpentError("'combination' needs to be a list")

        for entry in combination:
            if isinstance(entry, list):
                for entry_item in entry:
                    if entry_item not in self.game_inputs:
                        raise SerpentError("'combination' entries need to be valid members of self.game_input...")
            else:
                if entry not in self.game_inputs:
                    raise SerpentError("'combination' entries need to be valid members of self.game_input...")

        # Concatenate Grouped Axes (if needed)
        game_input_axes = list()

        for entry in combination:
            if isinstance(entry, str):
                game_input_axes.append(self.game_inputs[entry])
            elif isinstance(entry, list):
                concatenated_game_input_axis = dict()

                for entry_item in entry:
                    concatenated_game_input_axis = {**concatenated_game_input_axis, **self.game_inputs[entry_item]}

                game_input_axes.append(concatenated_game_input_axis)

        # Combine Game Inputs
        game_inputs = dict()

        if not len(game_input_axes):
            return game_inputs

        for keys in itertools.product(*game_input_axes):
            compound_label = list()
            game_input = list()

            for index, key in enumerate(keys):
                compound_label.append(key)
                game_input += game_input_axes[index][key]

            game_inputs[" - ".join(compound_label)] = game_input

        return game_inputs
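
To illustrate the flattening performed above, here is a hypothetical standalone run of the same itertools.product logic over two small input axes (the axis names and key bindings are invented for the example):

import itertools

movement_axis = {"MOVE LEFT": ["A"], "MOVE RIGHT": ["D"]}
combat_axis = {"ATTACK": ["SPACE"], "BLOCK": ["SHIFT"]}

combined = dict()

# Cartesian product of the axis labels, concatenating the bound inputs
for keys in itertools.product(movement_axis, combat_axis):
    combined[" - ".join(keys)] = movement_axis[keys[0]] + combat_axis[keys[1]]

# combined == {
#     "MOVE LEFT - ATTACK": ["A", "SPACE"],
#     "MOVE LEFT - BLOCK": ["A", "SHIFT"],
#     "MOVE RIGHT - ATTACK": ["D", "SPACE"],
#     "MOVE RIGHT - BLOCK": ["D", "SHIFT"]
# }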
Example #3
    def after_launch(self):
        self.is_launched = True

        current_attempt = 1

        while current_attempt <= 100:
            self.window_id = self.window_controller.locate_window(self.window_name)

            if self.window_id not in [0, "0"]:
                break

            current_attempt += 1
            time.sleep(0.1)

        time.sleep(0.5)

        if self.window_id in [0, "0"]:
            raise SerpentError("Game window not found...")

        self.window_controller.move_window(self.window_id, 0, 0)

        self.dashboard_window_id = self.window_controller.locate_window("Serpent.AI Dashboard")

        # TODO: Test on Linux
        if self.dashboard_window_id is not None and self.dashboard_window_id not in [0, "0"]:
            self.window_controller.bring_window_to_top(self.dashboard_window_id)

        self.window_controller.focus_window(self.window_id)

        self.window_geometry = self.extract_window_geometry()

        print(self.window_geometry)
Example #4
    def __init__(self,
                 name,
                 game_inputs=None,
                 callbacks=None,
                 seed=420133769,
                 logger=Loggers.NOOP,
                 logger_kwargs=None):
        self.name = name

        if not isinstance(game_inputs, list):
            raise SerpentError("'game_inputs' should be list...")

        self.game_inputs = game_inputs
        self.game_inputs_mappings = self._generate_game_inputs_mappings()

        self.callbacks = callbacks or dict()

        self.current_state = None

        self.current_reward = 0
        self.cumulative_reward = 0

        self.analytics_client = AnalyticsClient(
            project_key=config["analytics"]["topic"])

        random.seed(seed)
        np.random.seed(seed)

        self.logger = Agent.logger_mapping[logger](logger_kwargs=logger_kwargs)
Example #5
    def __init__(self,
                 game_api=None,
                 input_controller=None,
                 bosses=None,
                 items=None):
        super().__init__("Boss Fight Environment",
                         game_api=game_api,
                         input_controller=input_controller)

        if not isinstance(bosses, list):
            raise SerpentError(
                "'bosses' is expected to be a list of Bosses|DoubleBosses|MiniBosses enum items..."
            )

        filtered_bosses = list()

        for boss in bosses:
            if isinstance(boss, (Bosses, DoubleBosses, MiniBosses)):
                filtered_bosses.append(boss)

        if not len(filtered_bosses):
            raise SerpentError(
                "'bosses' needs to contain at least 1 valid Bosses|DoubleBosses|MiniBosses enum item..."
            )

        self.bosses = filtered_bosses
        self.boss = None

        if items is None:
            items = list()

        if isinstance(items, list):
            filtered_items = list()

            for item in items:
                if isinstance(item, Items):
                    filtered_items.append(item)

        elif isinstance(items, dict):
            filtered_items = items

        self.items = filtered_items

        self.reset()
Example #6
    def _initialize_object_recognizer(self, name, backend, classes, model_path, **kwargs):
        if backend == ObjectRecognizers.LUMINOTH:
            from serpent.machine_learning.object_recognition.object_recognizers.luminoth_object_recognizer import LuminothObjectRecognizer

            return LuminothObjectRecognizer(name, classes=classes, model_path=model_path, **kwargs)
        else:
            raise SerpentError("The specified backend is invalid!")
Example #7
    def __init__(self, name, algorithm="ssd", classes=None, model_path=None, **kwargs):
        self.name = name

        self.model_path = model_path
        self.model = None

        if self.model_path is not None:
            self.classes = self._load_classes()

            self.algorithm = None
            self.model = self._load_model()
        else:
            if algorithm not in self.__class__.algorithms:
                raise SerpentError(f"Algorithm '{algorithm}' not implemented in {self.__class__.__name__}")

            self.algorithm = algorithm
            self.classes = classes
Example #8
    def generate_actions(self, state, **kwargs):
        if not isinstance(state, GameFrameBuffer):
            raise SerpentError(
                "RecorderAgent 'generate_actions' state should be a GameFrameBuffer"
            )

        self.game_frame_buffers.append(state)
        self.current_state = state

        actions = list()

        for game_inputs_item in self.game_inputs:
            if game_inputs_item["control_type"] == InputControlTypes.DISCRETE:
                label = random.choice(list(game_inputs_item["inputs"].keys()))
                action = list()

                actions.append((label, action, None))
            elif game_inputs_item["control_type"] == InputControlTypes.CONTINUOUS:
                label = game_inputs_item["name"]
                action = list()

                size = 1

                if "size" in game_inputs_item["inputs"]:
                    size = game_inputs_item["inputs"]["size"]

                if size == 1:
                    input_value = random.uniform(
                        game_inputs_item["inputs"]["minimum"],
                        game_inputs_item["inputs"]["maximum"])
                else:
                    input_value = list()

                    for i in range(size):
                        input_value.append(
                            random.uniform(
                                game_inputs_item["inputs"]["minimum"],
                                game_inputs_item["inputs"]["maximum"]))

                actions.append((label, action, input_value))

        return actions
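
For reference, each entry in the returned list is a (label, action, input_value) tuple. A hypothetical result for one discrete and one continuous axis (labels invented for illustration):

# Discrete axes contribute a randomly chosen label with no input value;
# continuous axes contribute the axis name plus a sampled float
# (or a list of floats when "size" > 1).
actions = [
    ("MOVE RIGHT", [], None),
    ("STEERING", [], 0.3172)
]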
Example #9
    def __init__(self,
                 name,
                 game_inputs=None,
                 callbacks=None,
                 seed=None,
                 window_geometry=None,
                 logger=Loggers.NOOP,
                 logger_kwargs=None):
        super().__init__(name,
                         game_inputs=game_inputs,
                         callbacks=callbacks,
                         seed=seed,
                         logger=logger,
                         logger_kwargs=logger_kwargs)

        if window_geometry is None or not isinstance(window_geometry, dict):
            raise SerpentError(
                "RecorderAgent expects a 'window_geometry' dict kwarg.")

        self.window_geometry = window_geometry

        self.game_frame_buffers = list()
        self.rewards = list()

        self.current_step = 0

        self.redis_client = StrictRedis(**config["redis"])

        InputRecorder.pause_input_recording()

        input_recorder_command = "serpent record_inputs"
        self.input_recorder_process = subprocess.Popen(
            shlex.split(input_recorder_command))

        signal.signal(signal.SIGINT, self._handle_signal)
        signal.signal(signal.SIGTERM, self._handle_signal)

        atexit.register(self._handle_signal, 15, None, False)
Example #10
from serpent.utilities import SerpentError

try:
    from kivy.app import App
    from kivy.core.window import Window

    from kivy.uix.widget import Widget
    from kivy.uix.image import Image
    from kivy.uix.label import Label

    from kivy.uix.floatlayout import FloatLayout
    from kivy.uix.gridlayout import GridLayout
    from kivy.uix.boxlayout import BoxLayout

    from kivy.clock import Clock
except ImportError:
    raise SerpentError(
        "Setup has not been been performed for the Dashboard module. Please run 'serpent setup dashboard'"
    )


class DashboardApp(App):
    def __init__(self, width=None, height=None):
        super().__init__()

        self.display_width, self.display_height = self._determine_fullscreen_resolution()

        if width is not None and height is not None:
            self.width = width
            self.height = height
        else:
            self.width = self.display_width
            self.height = self.display_height
Example #11
    def __init__(self,
                 name,
                 game_inputs=None,
                 callbacks=None,
                 input_shape=None,
                 input_type=None,
                 use_tensorboard=True,
                 tensorforce_kwargs=None):
        super().__init__(name, game_inputs=game_inputs, callbacks=callbacks)

        if input_shape is None or not isinstance(input_shape, tuple):
            raise SerpentError("'input_shape' should be a tuple...")

        if input_type is None or input_type not in ["bool", "int", "float"]:
            raise SerpentError(
                "'input_type' should be one of bool|int|float...")

        states_spec = {"type": input_type, "shape": input_shape}

        # TODO: Support multiple actions
        # TODO: Support continuous action spaces
        actions_spec = {"type": "int", "num_actions": len(self.game_inputs)}

        summary_spec = None

        if use_tensorboard:
            summary_spec = {
                "directory": "./tensorboard/",
                "steps": 50,
                "labels": [
                    "configuration", "gradients_scalar", "regularization",
                    "inputs", "losses", "variables"
                ]
            }

        default_network_spec = [
            {"type": "conv2d", "size": 32, "window": 8, "stride": 4},
            {"type": "conv2d", "size": 64, "window": 4, "stride": 2},
            {"type": "conv2d", "size": 64, "window": 3, "stride": 1},
            {"type": "flatten"},
            {"type": "dense", "size": 1024}
        ]

        agent_kwargs = dict(batch_size=1024,
                            batched_observe=1024,
                            network_spec=default_network_spec,
                            device=None,
                            session_config=None,
                            saver_spec=None,
                            distributed_spec=None,
                            discount=0.99,
                            variable_noise=None,
                            states_preprocessing_spec=None,
                            explorations_spec=None,
                            reward_preprocessing_spec=None,
                            distributions_spec=None,
                            entropy_regularization=0.01,
                            keep_last_timestep=True,
                            baseline_mode=None,
                            baseline=None,
                            baseline_optimizer=None,
                            gae_lambda=None,
                            likelihood_ratio_clipping=None,
                            step_optimizer=None,
                            optimization_steps=10)

        if isinstance(tensorforce_kwargs, dict):
            for key, value in tensorforce_kwargs.items():
                if key in agent_kwargs:
                    agent_kwargs[key] = value

        self.agent = TFPPOAgent(states_spec=states_spec,
                                actions_spec=actions_spec,
                                summary_spec=summary_spec,
                                scope="ppo",
                                **agent_kwargs)

        try:
            self.restore_model()
        except Exception:
            pass
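
Note that the tensorforce_kwargs merge above only accepts keys already present in the defaults, silently ignoring unknown ones. A minimal standalone illustration of this override pattern:

defaults = dict(batch_size=1024, discount=0.99)
overrides = dict(batch_size=256, not_a_real_kwarg=True)

# Only keys that exist in the defaults are copied over
for key, value in overrides.items():
    if key in defaults:
        defaults[key] = value

# defaults == {"batch_size": 256, "discount": 0.99}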
Example #12
    def __init__(self, frame_shape=None, game_inputs=None):

        if frame_shape is None:
            raise SerpentError("A 'frame_shape' tuple kwarg is required...")

        states_spec = {"type": "float", "shape": frame_shape}

        if game_inputs is None:
            raise SerpentError("A 'game_inputs' dict kwarg is required...")

        self.game_inputs = game_inputs
        self.game_inputs_mapping = self._generate_game_inputs_mapping()

        actions_spec = {"type": "int", "num_actions": len(self.game_inputs)}

        network_spec = [
            {"type": "conv2d", "size": 1, "window": 2, "stride": 1},
            {"type": "flatten"},
            # {"type": "dense", "size": 64},
            {"type": "dense", "size": 6}
        ]

        self.agent = PPOAgent(
            states=states_spec,
            actions=actions_spec,
            network=network_spec,

            batched_observe=256,
            batching_capacity=1000,
            # BatchAgent
            #keep_last_timestep=True,
            # PPOAgent
            step_optimizer=dict(
                type='adam',
                learning_rate=1e-4
            ),
            optimization_steps=10,
            # Model
            scope='ppo'
            # discount=0.97,
            # DistributionModel
            # distributions=None,
            # entropy_regularization=0.01,
            # PGModel
            # baseline_mode=None,
            # baseline=None,
            # baseline_optimizer=None,
            # gae_lambda=None,
            # PGLRModel
            # likelihood_ratio_clipping=None,
            # summary_spec=summary_spec,
            # distributed_spec=None,
            # More info
            # device=None,
            # session_config=None,
            # saver=None,
            # variable_noise=None,
            # states_preprocessing_spec=None,
            # explorations_spec=None,
            # reward_preprocessing_spec=None,
            # execution=None,
            # actions_exploration=None,
            # update_mode=None,
            # memory=None,
            # subsampling_fraction=0.1
        )
Example #13
    def __init__(self, frame_shape=None, game_inputs=None):

        if frame_shape is None:
            raise SerpentError("A 'frame_shape' tuple kwarg is required...")

        states_spec = {"type": "float", "shape": frame_shape}

        if game_inputs is None:
            raise SerpentError("A 'game_inputs' dict kwarg is required...")

        self.game_inputs = game_inputs
        self.game_inputs_mapping = self._generate_game_inputs_mapping()

        print('game inputs mapping:')
        print(self.game_inputs_mapping)
        actions_spec = {"type": "int", "num_values": len(self.game_inputs)}

        summary_spec = {
            "directory": "./board/",
            "steps": 50,
            "labels": [
                "configuration", "gradients_scalar", "regularization",
                "inputs", "losses", "variables"
            ]
        }

        network_spec = [
            {"type": "conv2d", "size": 16, "window": 8, "stride": 4},
            {"type": "conv2d", "size": 32, "window": 4, "stride": 2},
            {"type": "conv2d", "size": 32, "window": 3, "stride": 1},
            {"type": "flatten"},
            {"type": "dense", "size": 64}
        ]

        baseline_spec = {
            "type": "cnn",
            "conv_sizes": [32, 32],
            "dense_sizes": [32]
        }

        saver_spec = {
            "directory": os.path.join(os.getcwd(), "datasets", "t4androidmodel"),
            "seconds": 120
        }

        # memory_spec = {"type": "latest", "include_next_states": False, "capacity": 1000 * 1000}

        self.agent = PPOAgent(
            states=states_spec,
            actions=actions_spec,
            network=network_spec,
            # baseline_mode='states',
            # baseline=baseline_spec,
            summarizer=summary_spec,
            memory=10,
            update_mode=dict(unit='timesteps', batch_size=2),
            discount=0.97,
            saver=saver_spec
        )

        self.agent.initialize()
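
Once initialized, a Tensorforce agent of this vintage is typically driven with an act/observe loop. A minimal sketch, assuming get_frame and compute_reward are placeholders supplied by the game plugin (they are not Serpent.AI or Tensorforce APIs, and the exact agent API depends on the installed Tensorforce version):

def run_episode(agent, get_frame, compute_reward, max_steps=1000):
    state = get_frame()

    for step in range(max_steps):
        action = agent.act(state)  # discrete index into game_inputs_mapping
        state = get_frame()
        terminal = step == max_steps - 1

        agent.observe(reward=compute_reward(state), terminal=terminal)

        if terminal:
            break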
Example #14
    def __init__(self, frame_shape=None, game_inputs=None):

        if frame_shape is None:
            raise SerpentError("A 'frame_shape' tuple kwarg is required...")

        states_spec = {"type": "float", "shape": frame_shape}

        if game_inputs is None:
            raise SerpentError("A 'game_inputs' dict kwarg is required...")

        self.game_inputs = game_inputs
        self.game_inputs_mapping = self._generate_game_inputs_mapping()

        actions_spec = {"type": "int", "num_actions": len(self.game_inputs)}

        summary_spec = {
            "directory": "./board/",
            "steps": 50,
            "labels": [
                "configuration", "gradients_scalar", "regularization",
                "inputs", "losses", "variables"
            ]
        }

        network_spec = [
            {"type": "conv2d", "size": 32, "window": 8, "stride": 4},
            {"type": "conv2d", "size": 64, "window": 4, "stride": 2},
            {"type": "conv2d", "size": 64, "window": 3, "stride": 1},
            {"type": "flatten"},
            {"type": "dense", "size": 1024}
        ]

        self.agent = PPOAgent(
            states_spec=states_spec,
            actions_spec=actions_spec,
            network_spec=network_spec,
            summary_spec=summary_spec,
            batched_observe=2560,
            scope="ppo",
            device=None,
            saver_spec=None,
            distributed_spec=None,
            discount=0.97,
            variable_noise=None,
            states_preprocessing_spec=None,
            explorations_spec=None,
            reward_preprocessing_spec=None,
            distributions_spec=None,
            entropy_regularization=0.01,
            batch_size=2560,
            keep_last_timestep=True,
            baseline_mode=None,
            baseline=None,
            baseline_optimizer=None,
            gae_lambda=None,
            likelihood_ratio_clipping=None,
            step_optimizer=None,
            optimization_steps=10
        )
Example #15
import enum
import json

import numpy as np
import h5py

import skimage.io
import skimage.util

from serpent.utilities import SerpentError

try:
    import torch

    from serpent.machine_learning.reinforcement_learning.rainbow_dqn.rainbow_agent import RainbowAgent
    from serpent.machine_learning.reinforcement_learning.rainbow_dqn.replay_memory import ReplayMemory
except ImportError:
    raise SerpentError(
        "Setup has not been been performed for the ML module. Please run 'serpent setup ml'"
    )


class RainbowDQNAgentModes(enum.Enum):
    OBSERVE = 0
    TRAIN = 1
    EVALUATE = 2


# Adapted for Serpent.AI from https://github.com/Kaixhin/Rainbow


class RainbowDQNAgent(Agent):
    def __init__(self,
                 name,
Example #16
    def __init__(self,
                 name,
                 game_inputs=None,
                 callbacks=None,
                 seed=420133769,
                 rainbow_kwargs=None,
                 logger=Loggers.NOOP,
                 logger_kwargs=None):
        super().__init__(name,
                         game_inputs=game_inputs,
                         callbacks=callbacks,
                         seed=seed,
                         logger=logger,
                         logger_kwargs=logger_kwargs)

        if len(game_inputs) > 1:
            raise SerpentError(
                "RainbowDQNAgent only supports a single axis of game inputs.")

        if game_inputs[0]["control_type"] != InputControlTypes.DISCRETE:
            raise SerpentError(
                "RainbowDQNAgent only supports discrete input spaces")

        if torch.cuda.is_available():
            self.device = torch.device("cuda")

            torch.set_default_tensor_type("torch.cuda.FloatTensor")
            torch.backends.cudnn.enabled = False

            torch.cuda.manual_seed_all(seed)
        else:
            self.device = torch.device("cpu")
            torch.set_num_threads(1)

        torch.manual_seed(seed)

        agent_kwargs = dict(algorithm="Rainbow DQN",
                            replay_memory_capacity=100000,
                            history=4,
                            discount=0.99,
                            multi_step=3,
                            priority_weight=0.4,
                            priority_exponent=0.5,
                            atoms=51,
                            v_min=-10,
                            v_max=10,
                            batch_size=32,
                            hidden_size=1024,
                            noisy_std=0.1,
                            learning_rate=0.0000625,
                            adam_epsilon=1.5e-4,
                            max_grad_norm=10,
                            target_update=10000,
                            save_steps=5000,
                            observe_steps=50000,
                            max_steps=5000000,
                            model=f"datasets/rainbow_dqn_{self.name}.pth",
                            seed=seed)

        if isinstance(rainbow_kwargs, dict):
            for key, value in rainbow_kwargs.items():
                if key in agent_kwargs:
                    agent_kwargs[key] = value

        self.agent = RainbowAgent(len(self.game_inputs[0]["inputs"]),
                                  self.device,
                                  atoms=agent_kwargs["atoms"],
                                  v_min=agent_kwargs["v_min"],
                                  v_max=agent_kwargs["v_max"],
                                  batch_size=agent_kwargs["batch_size"],
                                  multi_step=agent_kwargs["multi_step"],
                                  discount=agent_kwargs["discount"],
                                  history=agent_kwargs["history"],
                                  hidden_size=agent_kwargs["hidden_size"],
                                  noisy_std=agent_kwargs["noisy_std"],
                                  learning_rate=agent_kwargs["learning_rate"],
                                  adam_epsilon=agent_kwargs["adam_epsilon"],
                                  max_grad_norm=agent_kwargs["max_grad_norm"],
                                  model=agent_kwargs["model"])

        self.replay_memory = ReplayMemory(
            agent_kwargs["replay_memory_capacity"],
            self.device,
            history=agent_kwargs["history"],
            discount=agent_kwargs["discount"],
            multi_step=agent_kwargs["multi_step"],
            priority_weight=agent_kwargs["priority_weight"],
            priority_exponent=agent_kwargs["priority_exponent"])

        self.priority_weight_increase = (1 - agent_kwargs["priority_weight"]) / (
            agent_kwargs["max_steps"] - agent_kwargs["observe_steps"]
        )

        self.target_update = agent_kwargs["target_update"]

        self.save_steps = agent_kwargs["save_steps"]
        self.observe_steps = agent_kwargs["observe_steps"]
        self.max_steps = agent_kwargs["max_steps"]

        self.remaining_observe_steps = self.observe_steps

        self.current_episode = 1
        self.current_step = 0

        self.current_action = -1

        self.observe_mode = "RANDOM"
        self.set_mode(RainbowDQNAgentModes.OBSERVE)

        self.model = agent_kwargs["model"]

        if os.path.isfile(self.model):
            self.observe_mode = "MODEL"
            self.restore_model()

        self.logger.log_hyperparams(agent_kwargs)

        if self._has_human_input_recording() and self.observe_mode == "RANDOM":
            self.add_human_observations_to_replay_memory()
Example #17
from serpent.utilities import SerpentError

try:
    from kivy.app import App
    from kivy.core.window import Window
    from kivy.core.image import Image as CoreImage

    from kivy.uix.widget import Widget
    from kivy.uix.image import Image
    from kivy.uix.label import Label

    from kivy.uix.floatlayout import FloatLayout
    from kivy.uix.gridlayout import GridLayout
    from kivy.uix.boxlayout import BoxLayout

    from kivy.clock import Clock
except ImportError:
    raise SerpentError(
        "Setup has not been been performed for the GUI module. Please run 'serpent setup gui'"
    )

from PIL import Image as PILImage

import io


class VisualDebuggerApp(App):
    def __init__(self, buckets=None):
        super().__init__()

        self.visual_debugger = VisualDebugger(buckets=buckets)
        self.canvas = None

    def build(self):
Example #18
import skimage.util
import skimage.transform
import skimage.measure
import skimage.io

from serpent.utilities import is_unix, is_windows
from serpent.utilities import SerpentError

try:
    if is_unix():
        import tesserocr
    elif is_windows():
        import pytesseract
except ImportError:
    raise SerpentError(
        "Setup has not been been performed for the OCR module. Please run 'serpent setup ocr'"
    )

import editdistance

from PIL import Image


def locate_string(query_string,
                  image,
                  fuzziness=0,
                  ocr_preset=None,
                  offset_x=0,
                  offset_y=0):
    images, text_regions = extract_ocr_candidates(
        image,
Example #19
    def __init__(self,
                 name,
                 game_inputs=None,
                 callbacks=None,
                 seed=420133769,
                 input_shape=None,
                 ppo_kwargs=None,
                 logger=Loggers.NOOP,
                 logger_kwargs=None):
        super().__init__(name,
                         game_inputs=game_inputs,
                         callbacks=callbacks,
                         seed=seed,
                         logger=logger,
                         logger_kwargs=logger_kwargs)

        if len(game_inputs) > 1:
            raise SerpentError(
                "PPOAgent only supports a single axis of game inputs.")

        if game_inputs[0]["control_type"] != InputControlTypes.DISCRETE:
            raise SerpentError("PPOAgent only supports discrete input spaces")

        if torch.cuda.is_available():
            self.device = torch.device("cuda")

            torch.set_default_tensor_type("torch.cuda.FloatTensor")
            torch.backends.cudnn.benchmark = True

            torch.cuda.manual_seed_all(seed)
        else:
            self.device = torch.device("cpu")
            torch.set_num_threads(1)

        torch.manual_seed(seed)

        agent_kwargs = dict(algorithm="PPO",
                            is_recurrent=False,
                            surrogate_objective_clip=0.2,
                            epochs=4,
                            batch_size=32,
                            value_loss_coefficient=0.5,
                            entropy_regularization_coefficient=0.01,
                            learning_rate=0.0001,
                            adam_epsilon=0.00001,
                            max_grad_norm=0.3,
                            memory_capacity=1024,
                            discount=0.99,
                            gae=False,
                            gae_tau=0.95,
                            save_steps=10000,
                            model=f"datasets/ppo_{self.name}.pth",
                            seed=seed)

        if isinstance(ppo_kwargs, dict):
            for key, value in ppo_kwargs.items():
                if key in agent_kwargs:
                    agent_kwargs[key] = value

        self.discount = agent_kwargs["discount"]

        self.gae = agent_kwargs["gae"]
        self.gae_tau = agent_kwargs["gae_tau"]

        input_shape = (4, input_shape[0], input_shape[1])  # 4x Grayscale OR Quantized

        self.actor_critic = Policy(input_shape,
                                   len(self.game_inputs[0]["inputs"]),
                                   agent_kwargs["is_recurrent"])

        if torch.cuda.is_available():
            self.actor_critic.cuda(device=self.device)

        self.agent = PPO(self.actor_critic,
                         agent_kwargs["surrogate_objective_clip"],
                         agent_kwargs["epochs"],
                         agent_kwargs["batch_size"],
                         agent_kwargs["value_loss_coefficient"],
                         agent_kwargs["entropy_regularization_coefficient"],
                         lr=agent_kwargs["learning_rate"],
                         eps=agent_kwargs["adam_epsilon"],
                         max_grad_norm=agent_kwargs["max_grad_norm"])

        self.storage = RolloutStorage(agent_kwargs["memory_capacity"], 1,
                                      input_shape,
                                      len(self.game_inputs[0]["inputs"]),
                                      self.actor_critic.state_size)

        if torch.cuda.is_available():
            self.storage.cuda(device=self.device)

        self.current_episode = 1
        self.current_step = 0

        self.mode = PPOAgentModes.TRAIN

        self.save_steps = agent_kwargs["save_steps"]

        self.model_path = agent_kwargs["model"]

        if os.path.isfile(self.model_path):
            self.restore_model()

        self.logger.log_hyperparams(agent_kwargs)
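
The gae and gae_tau settings above control generalized advantage estimation. A minimal sketch of the recursion they imply, assuming per-step rewards and value estimates are available (this is the standard GAE formula, not code from the source):

def compute_gae(rewards, values, discount=0.99, tau=0.95):
    # advantage[t] = delta[t] + discount * tau * advantage[t + 1]
    # where delta[t] = rewards[t] + discount * values[t + 1] - values[t].
    # 'values' must hold one more entry than 'rewards' (the bootstrap value).
    advantages = [0.0] * len(rewards)
    gae = 0.0

    for t in reversed(range(len(rewards))):
        delta = rewards[t] + discount * values[t + 1] - values[t]
        gae = delta + discount * tau * gae
        advantages[t] = gae

    return advantages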