Exemplo n.º 1
0
    def __str__(self):
        """Render this insertion suggestion as a colorized, multi-line hint.

        The returned string has four lines joined by newlines:

        1. a message naming the token text to insert,
        2. the formatted source line the token belongs to,
        3. a bold green caret under the insertion column,
        4. the suggested token text, in green, under the caret.
        """
        t = Terminal()
        text = self.token.value

        # An earlier version read ``self.tokens[pos + 1]`` here without a
        # bounds check and never used the result, so a suggestion located at
        # the final token raised IndexError. The lookup has been removed.
        msg = "try inserting '{bold}{text}{normal}' ".format(
            bold=t.bold, text=text, normal=t.normal)

        line_tokens = get_token_line(self.pos, self.tokens)

        if self.insert_after:
            line = format_line(line_tokens)
            # Add an extra space BEFORE the insertion point:
            padding = ' ' * (2 + 1 + self.column)
        else:
            # Add an extra space AFTER insertion point;
            line = format_line(line_tokens,
                               insert_space_before=self.tokens[self.pos])
            padding = ' ' * (1 + self.column)

        arrow = padding + t.bold_green('^')
        suggestion = padding + t.green(text)

        return '\n'.join((msg, line, arrow, suggestion))
Exemplo n.º 2
0
def print_bugs_and_errors(bugs: Iterable[Bug],
                          errs: Iterable[ValidationError]) -> None:
    """Print one colored status line per bug and per validation problem.

    Each item of ``bugs`` is printed in green, prefixed with a check mark,
    using its ``repr``. Each entry returned by ``err.errors()`` for every
    item of ``errs`` is printed in red, prefixed with a cross, in the form
    ``<model title>.<location>: <message>``.

    :param bugs: objects to report as successfully handled.
    :param errs: validation errors to report; each must expose ``model``
        and ``errors()``.
    """
    t = Terminal()

    for bug in bugs:
        print(t.bold_green("✔ ") + t.green(repr(bug)))

    for err in errs:
        model = err.model
        assert isinstance(model, ModelMetaclass)
        title = model.schema()["title"]
        for sub_err in err.errors():
            # "loc" is a path of one or more components. Joining it (rather
            # than unpacking exactly one element) keeps single-component
            # output unchanged and no longer raises ValueError for longer
            # paths such as nested fields or list indices.
            loc = ".".join(str(part) for part in sub_err["loc"])
            msg = sub_err["msg"]
            print(t.bold_red("✘ ") + t.red(f"{title}.{loc}: {msg}"))
Exemplo n.º 3
0
    cycle = cycle + 1
    sleep(1)

  r_mppt_v = avg_volt / cycle;
  r_mppt_i = avg_current / cycle;

  if float(r_mppt_i) > float(current_max):
    result = t.bold_red('FAILED')
  elif float(r_mppt_i) < float(current_min):
    result = t.bold_red('FAILED')
  elif float(r_mppt_v) > float(voltage_max):
    result = t.bold_red('FAILED')
  elif float(r_mppt_v) < float(voltage_min):
    result = t.bold_red('FAILED')
  else:
    result = t.bold_green('PASSED')

  print 'Franz CH%s @ %sV, %sA....[%s]' %(ch, r_mppt_v, r_mppt_i, result)
  print ''

def config_acs(pfc_path):
  """Send the 'acs' power-on and 'esc on' commands through a shell.

  Opens a ``shell.Shell`` on ``pfc_path``, attaches a ``shell.Scoreboard``
  named 'acs' to it, sends ``power on acs``, prints the scoreboard's
  ``power_acs_enabled`` value, then sends ``acs esc on``.

  The fixed ``sleep`` calls between steps are presumably there to give the
  target time to react to each command -- TODO confirm the required delays.
  """
  sleep(5)
  tom = shell.Shell(pfc_path)
  sleep(1)
  sb = shell.Scoreboard(tom,'acs')
  sleep(1)
  tom.sendline('power on acs')
  sleep(3)
  # Echo the queried enable flag to stdout (Python 2 print statement).
  print sb.query('power_acs_enabled')
  sleep(1)
  tom.sendline('acs esc on')
Exemplo n.º 4
0
class Cli:
    """Interactive prompt that converts effort ratings into percentage shares.

    The user enters one line per person (``<name> <rating> <rating> ...``).
    Each person's ratings are summed and multiplied by the rate attached to
    that person's standing; the resulting points are then reported as
    truncated percentage shares of the overall total.
    """

    def __init__(self):
        """Load the organisation data and reset the running totals."""
        self.data = OrgData.load()
        self.total_points = 0
        self.personal_points = []
        self.terminal = Terminal()

    def run(self):
        """Show the known names, read ratings and print everyone's share."""
        self.list_names()
        self.print_prompt(
            "Please enter a person's name and their effort ratings, separated by spaces."
        )
        self.print_prompt(
            "Add a comma to the end of each line to continue to the next person."
        )
        ratings = self.get_ratings()
        self.calculate_points(ratings)
        self.calculate_shares()

    def list_names(self):
        """Print every known person's name on one ' | '-separated line."""
        names = (self.terminal.bold_yellow(person['name'])
                 for person in self.data.people)
        print('People: ' + ' | '.join(names))

    def get_ratings(self):
        """Read rating lines from stdin until one does not end with a comma.

        :return: the entered lines with all commas removed, in input order.
        """
        ratings = []
        while True:
            input_line = input('<- ')
            ratings.append(input_line.replace(',', ''))
            # Ignore trailing whitespace so "alice 3, " still continues.
            if not input_line.rstrip().endswith(','):
                break
        return ratings

    @staticmethod
    def _find_by_name(records, name):
        """Return the first mapping in ``records`` whose 'name' equals ``name``.

        :raises IndexError: if no record matches (same exception type the
            previous ``filter(...)[0]`` lookup produced).
        """
        for record in records:
            if record['name'] == name:
                return record
        raise IndexError(f"no entry named {name!r}")

    def calculate_points(self, ratings):
        """Accumulate weighted points for every rating line.

        Appends ``(name, points)`` to ``self.personal_points`` and adds the
        points to ``self.total_points``. Blank lines are skipped.

        :param ratings: iterable of ``"<name> <int> <int> ..."`` strings.
        :raises IndexError: if a name or its standing is unknown.
        :raises ValueError: if a rating is not an integer.
        """
        for rating in ratings:
            # split() without an argument collapses runs of whitespace, so
            # doubled or leading spaces no longer yield '' fragments that
            # made int('') fail.
            frags = rating.split()
            if not frags:
                continue
            name, *scores = frags
            standing = self._find_by_name(self.data.people, name)['standing']
            multiplier = float(
                self._find_by_name(self.data.standings, standing)['rate'])
            points = sum(map(int, scores)) * multiplier
            self.personal_points.append((name, points))
            self.total_points += points

    def calculate_shares(self):
        """Print each person's truncated percentage and the lost remainder.

        The remainder is the sum of the fractional parts dropped by
        truncation. When the total is zero every share is reported as 0
        instead of raising ZeroDivisionError.
        """
        remainder = 0.0
        for name, points in self.personal_points:
            if self.total_points:
                share = points / self.total_points * 100
            else:
                share = 0.0
            truncated_share = math.trunc(share)
            remainder += share - truncated_share
            self.print_result(name, truncated_share)
        self.print_result('Remainder', remainder)

    def print_prompt(self, prompt):
        """Print an instruction line in italics."""
        print(self.terminal.italic(prompt))

    def print_result(self, label, percentage):
        """Print ``<label> -> <percentage>%`` with colored parts."""
        left = self.terminal.bold_yellow(f"{label}")
        right = self.terminal.bold_green(f"{percentage}%")
        print(left, '->', right)
Exemplo n.º 5
0
class ProgressiveResult(TextTestResult):
    """Test result which updates a progress bar instead of printing dots

    Nose's ResultProxy will wrap it, and other plugins can still print
    stuff---but without smashing into my progress bar, care of my Plugin's
    stderr/out wrapping.

    """
    def __init__(self, cwd, total_tests, stream, config=None):
        """Set up the terminal wrapper and the progress bar.

        :arg cwd: Working directory, passed on to the traceback formatter
        :arg total_tests: Expected number of tests; 0 is treated as 1
        :arg stream: Output stream results are written to
        :arg config: Nose config; although it defaults to None, its
            ``options`` attribute is read unconditionally

        """
        super(ProgressiveResult, self).__init__(stream, None, 0, config=config)
        self._cwd = cwd
        self._options = config.options
        self._term = Terminal(stream=stream,
                              force_styling=config.options.with_styling)

        if self._term.is_a_tty or self._options.with_bar:
            # 1 in case test counting failed and returned 0
            self.bar = ProgressBar(total_tests or 1,
                                   self._term,
                                   config.options.bar_filled_color,
                                   config.options.bar_empty_color)
        else:
            self.bar = NullProgressBar()

        # Declare errorclass-savviness so ErrorClassPlugins don't monkeypatch
        # half my methods away:
        self.errorClasses = {}

    def startTest(self, test):
        """Update the progress bar."""
        super(ProgressiveResult, self).startTest(test)
        self.bar.update(nose_selector(test), self.testsRun)

    def _printTraceback(self, test, err):
        """Print a nicely formatted traceback.

        :arg err: exc_info()-style traceback triple
        :arg test: the test that precipitated this call

        """
        # Don't bind third item to a local var; that can create
        # circular refs which are expensive to collect. See the
        # sys.exc_info() docs.
        exception_type, exception_value = err[:2]
        # TODO: In Python 3, the traceback is attached to the exception
        # instance through the __traceback__ attribute. If the instance
        # is saved in a local variable that persists outside the except
        # block, the traceback will create a reference cycle with the
        # current frame and its dictionary of local variables. This will
        # delay reclaiming dead resources until the next cyclic garbage
        # collection pass.

        extracted_tb = extract_relevant_tb(
            err[2],
            exception_type,
            exception_type is test.failureException)
        test_frame_index = index_of_test_frame(
            extracted_tb,
            exception_type,
            exception_value,
            test)
        if test_frame_index:
            # We have a good guess at which frame is the test, so
            # trim everything until that. We don't care to see test
            # framework frames.
            extracted_tb = extracted_tb[test_frame_index:]

        with self.bar.dodging():
            self.stream.write(''.join(
                format_traceback(
                    extracted_tb,
                    exception_type,
                    exception_value,
                    self._cwd,
                    self._term,
                    self._options.function_color,
                    self._options.dim_color,
                    self._options.editor,
                    self._options.editor_shortcut_template)))

    def _printHeadline(self, kind, test, is_failure=True):
        """Output a 1-line error summary to the stream if appropriate.

        The line contains the kind of error and the pathname of the test.
        Non-failures are printed only when the ``show_advisories`` option is
        set, and are not rendered in bold.

        :arg kind: The (string) type of incident that precipitated this call
        :arg test: The test that precipitated this call
        :arg is_failure: Whether the incident counts as a failure

        """
        if is_failure or self._options.show_advisories:
            with self.bar.dodging():
                self.stream.writeln(
                        '\n' +
                        (self._term.bold if is_failure else '') +
                        '%s: %s' % (kind, nose_selector(test)) +
                        (self._term.normal if is_failure else ''))  # end bold

    def _recordAndPrintHeadline(self, test, error_class, artifact):
        """Record that an error-like thing occurred, and print a summary.

        Store ``artifact`` with the record.

        Return whether the test result is any sort of failure.

        """
        # We duplicate the errorclass handling from super rather than calling
        # it and monkeying around with showAll flags to keep it from printing
        # anything.
        #
        # The label and failure flag of the errorclass that actually MATCHED
        # are remembered explicitly. Reading the loop variables after the
        # loop would yield whatever entry happened to be iterated last, which
        # mislabels the headline when several errorclasses are registered.
        is_error_class = False
        matched_label = None
        matched_is_failure = False
        for cls, (storage, label, is_failure) in self.errorClasses.items():
            if isclass(error_class) and issubclass(error_class, cls):
                if is_failure:
                    test.passed = False
                storage.append((test, artifact))
                if not is_error_class:
                    # Headline uses the first matching errorclass's label.
                    matched_label = label
                matched_is_failure = matched_is_failure or bool(is_failure)
                is_error_class = True
        if not is_error_class:
            self.errors.append((test, artifact))
            test.passed = False

        is_any_failure = not is_error_class or matched_is_failure
        self._printHeadline(matched_label if is_error_class else 'ERROR',
                            test,
                            is_failure=is_any_failure)
        return is_any_failure

    def addSkip(self, test, reason):
        """Catch skipped tests in Python 2.7 and above.

        Though ``addSkip()`` is deprecated in the nose plugin API, it is very
        much not deprecated as a Python 2.7 ``TestResult`` method. In Python
        2.7, this will get called instead of ``addError()`` for skips.

        :arg reason: Text describing why the test was skipped

        """
        self._recordAndPrintHeadline(test, SkipTest, reason)
        # Python 2.7 users get a little bonus: the reason the test was skipped.
        if isinstance(reason, Exception):
            # Exceptions only carry ``.message`` on Python 2; fall back to
            # str() so Python 3 does not raise AttributeError here.
            reason = getattr(reason, 'message', None) or str(reason)
        if reason and self._options.show_advisories:
            with self.bar.dodging():
                self.stream.writeln(reason)

    def addError(self, test, err):
        """Record an error (or errorclass incident) and print it.

        :arg err: exc_info()-style traceback triple

        """
        # We don't read this, but some other plugin might conceivably expect it
        # to be there:
        excInfo = self._exc_info_to_string(err, test)
        is_failure = self._recordAndPrintHeadline(test, err[0], excInfo)
        if is_failure:
            self._printTraceback(test, err)

    def addFailure(self, test, err):
        """Record a test failure and print its headline and traceback."""
        super(ProgressiveResult, self).addFailure(test, err)
        self._printHeadline('FAIL', test)
        self._printTraceback(test, err)

    def printSummary(self, start, stop):
        """As a final summary, print number of tests, broken down by result."""
        def renderResultType(type, number, is_failure):
            """Return a rendering like '2 failures'.

            :arg type: A singular label, like "failure"
            :arg number: The number of tests with a result of that type
            :arg is_failure: Whether that type counts as a failure

            """
            # I'd rather hope for the best with plurals than totally punt on
            # being Englishlike:
            ret = '%s %s%s' % (number, type, 's' if number != 1 else '')
            if is_failure and number:
                ret = self._term.bold(ret)
            return ret

        # Summarize the special cases:
        counts = [('test', self.testsRun, False),
                  ('failure', len(self.failures), True),
                  ('error', len(self.errors), True)]
        # Support custom errorclasses as well as normal failures and errors.
        # Lowercase any all-caps labels, but leave the rest alone in case there
        # are hard-to-read camelCaseWordBreaks.
        counts.extend([(label.lower() if label.isupper() else label,
                        len(storage),
                        is_failure)
                        for (storage, label, is_failure) in
                            self.errorClasses.values() if len(storage)])
        summary = (', '.join(renderResultType(*a) for a in counts) +
                   ' in %.1fs' % (stop - start))

        # Erase progress bar. Bash doesn't clear the whole line when printing
        # the prompt, leaving a piece of the bar. Also, the prompt may not be
        # at the bottom of the terminal.
        self.bar.erase()
        self.stream.writeln()
        if self.wasSuccessful():
            self.stream.write(self._term.bold_green('OK!  '))
        self.stream.writeln(summary)
Exemplo n.º 6
0
class BasePlugin(Cmd):
    """ BasePlugin - the base class which all of our plugins should inherit from.
        It is meant to define all the necessary base functions for plugins. """

    prompt = '>> '
    ruler = '-'
    intro = banner()
    terminators = []

    CATEGORY_SHELL = to_bold_cyan('Shell Based Operations')
    CATEGORY_GENERAL = to_bold_cyan('General Commands')

    def __init__(self):
        """ Configures the shell: aliases, hidden commands, prompt, the
            alerter thread hooks and the post-parsing refresh hook """

        Cmd.__init__(self,
                     startup_script=read_config().get('STARTUP_SCRIPT', ''),
                     use_ipython=True)

        self.aliases.update({'exit': 'quit', 'help': 'help -v'})
        self.hidden_commands.extend([
            'load', 'pyscript', 'set', 'shortcuts', 'alias', 'unalias',
            'shell', 'macro'
        ])

        self.t = Terminal()
        self.selected_client = None

        self.prompt = self.get_prompt()
        self.allow_cli_args = False

        # Alerts Thread
        self._stop_thread = False
        self._seen_clients = set(Client.unique_client_ids())
        self._alert_thread = Thread()

        # Register the hook functions
        self.register_preloop_hook(self._alert_thread_preloop_hook)
        self.register_postloop_hook(self._alert_thread_postloop_hook)

        # Set the window title
        self.set_window_title('<< JSShell 2.0 >>')

        categorize([
            BasePlugin.do_help, BasePlugin.do_quit, BasePlugin.do_py,
            BasePlugin.do_ipy, BasePlugin.do_history, BasePlugin.do_edit
        ], BasePlugin.CATEGORY_GENERAL)

        self.register_postparsing_hook(
            self._refresh_client_data_post_parse_hook)

    def _alert_thread_preloop_hook(self) -> None:
        """ Start the alerter thread """

        self._stop_thread = False
        self._alert_thread = Thread(name='alerter',
                                    target=self._alert_function)
        self._alert_thread.start()

    def _alert_thread_postloop_hook(self) -> None:
        """ Stops the alerter thread """

        self._stop_thread = True

        if self._alert_thread.is_alive():
            self._alert_thread.join()

    def _alert_function(self) -> None:
        """ When the client list is larger than the one we know of
            alert the user that a new client has registered """

        while not self._stop_thread:
            if self.terminal_lock.acquire(blocking=False):
                # Release the lock even if the lookup or the alert raises;
                # otherwise the terminal lock would stay held forever.
                try:
                    current_clients = set(Client.unique_client_ids())
                    delta = current_clients - self._seen_clients

                    if delta:
                        self.async_alert(
                            self.t.bold_blue(' << new client registered >>'),
                            self.prompt)

                    self._seen_clients = current_clients
                finally:
                    self.terminal_lock.release()

            sleep(0.5)

    def print_error(self, text: str, end: str = '\n', start: str = '') -> None:
        """ Prints a formatted error message """

        self.poutput(start + self.t.bold_red('[-]') + ' ' + self.t.red(text),
                     end=end)

    def print_info(self, text: str, end: str = '\n', start: str = '') -> None:
        """ Prints a formatted informational message """

        self.poutput(start + self.t.bold_yellow('[!]') + ' ' +
                     self.t.yellow(text),
                     end=end)

    def print_ok(self, text: str, end: str = '\n', start: str = '') -> None:
        """ Prints a formatted success message """

        self.poutput(start + self.t.bold_green('[+]') + ' ' +
                     self.t.green(text),
                     end=end)

    def print_pairs(self,
                    title: str,
                    body: Dict[str, str],
                    just_return: bool = False,
                    colors: bool = True) -> Union[str, None]:
        """ Prints pairs of values with a certain title.

            When `just_return` is true the rendered text is returned instead
            of being paged to the user; otherwise None is returned. When
            `colors` is false no terminal styling is applied """

        if colors:
            data = [self.t.bold_white_underline(title)]
        else:
            data = [title]

        for key, value in body.items():
            k = key + ':'
            if colors:
                data.append(f' - {self.t.bold(k)} {value}')
            else:
                data.append(f' - {k} {value}')

        if just_return:
            return '\n'.join(data)

        self.ppaged('\n'.join(data))

    def select_client(self, client: Client) -> None:
        """ Handles the operation of selecting a new client """

        self.selected_client = client
        self.update_prompt()

    def _refresh_client_data_post_parse_hook(
            self, params: PostparsingData) -> PostparsingData:
        """ Refreshes the selected client data from the database. We do that because
            of `mongoengine`s behaviour, where if we set the current client, we do not track
            for modifications. This way, before every command is parsed we re-select the client """

        if self.selected_client:
            cid = self.selected_client.cid
            self.select_client(Client.objects(cid=cid).first())

        return params

    def get_prompt(self) -> str:
        """ Handles the operations of getting the prompt string """

        prompt = self.t.bold_cyan('>> ')

        if self.selected_client:
            client_id = self.t.bold_red(self.selected_client.cid)
            prompt = self.t.cyan(f"[Client #{client_id}]") + ' ' + prompt

        return prompt

    def update_prompt(self) -> None:
        """ Handles what is needed when updating the prompt """

        # `get_prompt` is a method of this class (see `__init__`), so it has
        # to be called through `self` rather than as a bare function.
        self.prompt = self.get_prompt()
Exemplo n.º 7
0
class Interface:
    """Text-mode menu front end that drives the SDNProbe helper programs.

    The class reads a ``KEY=VALUE`` config file, draws a boxed menu with
    Unicode line-drawing characters and runs the external build/generation/
    probing commands, echoing their output in color.

    Note: the subprocess handling treats pipe output as ``str`` and is
    therefore Python 2 code. ``print`` is always called with exactly one
    parenthesised argument, which behaves identically under Python 2.
    """

    def __init__(self, config_filename):
        """Remember the config path and prepare the drawing primitives.

        :param config_filename: path of the ``KEY=VALUE`` config file that
            ``read_config`` will load.
        """
        self.config_filename = config_filename
        self.config = collections.OrderedDict()
        self.term = Terminal()

        # Unicode box-drawing characters used for the frames.
        self.hline = u'\u2500'
        self.left_up = u'\u250c'
        self.right_up = u'\u2510'
        self.left_down = u'\u2514'
        self.right_down = u'\u2518'
        self.left_middle = u'\u251c'
        self.right_middle = u'\u2524'
        self.vline = u'\u2502'

        self.TITLE = [
            'SDNProbe: A Lightweight Tool for Securing SDN Data Plane with Active Probing'
            .center(self.term.width - 2)
        ]
        self.MENU = [
            'Show the configure file',
            'Generate test packets from topology file',
            'Generate topology graph for the controller as input',
            'Start probing', 'Exit (or press q)'
        ]
        self.menu_pos = 0

    def _topology_file(self):
        """Return the configured TOPOLOGY_FILE, or '' when it is not set."""
        return self.config.get('TOPOLOGY_FILE', '')

    def show_menu(self):
        """Draw the menu until the user picks an entry.

        :return: the 1-based number of the chosen entry; the exit key
            returns the number of the last entry ("Exit").
        """
        while True:
            os.system('clear')

            inner_width = self.term.width - 2
            print(self.left_up + self.hline * inner_width + self.right_up)
            print(self.vline + self.term.bold_cyan(self.TITLE[0]) + self.vline)
            print(self.left_middle + self.hline * inner_width +
                  self.right_middle)

            for i, menu in enumerate(self.MENU):
                label = '[' + str(i + 1) + '] ' + menu
                if i == self.menu_pos:
                    # Highlight the entry the cursor is on.
                    label = self.term.yellow_reverse(label)
                print(self.vline + ' ' + label +
                      ' ' * (self.term.width - len(menu) - 7) + self.vline)
            print(self.left_down + self.hline * inner_width + self.right_down)
            print('')

            press = self.getch()
            if press == 'up':
                self.menu_pos = (self.menu_pos - 1) % (len(self.MENU))
            elif press == 'down':
                self.menu_pos = (self.menu_pos + 1) % (len(self.MENU))
            elif press == 'exit':
                # "Exit" is the last menu entry.
                return len(self.MENU)
            elif press == 'enter':
                return self.menu_pos + 1

    def read_config(self):
        """Load ``KEY=VALUE`` lines from the config file into ``self.config``.

        Blank lines are skipped and only the first ``=`` separates key from
        value, so values may themselves contain ``=``. If the file cannot
        be opened, or a non-blank line has no ``=``, an error is printed
        and the program exits.
        """
        try:
            with open(self.config_filename, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    var, val = line.split('=', 1)
                    self.config[var] = val
        except IOError:
            self.print_IOError(self.config_filename)
            exit()
        except ValueError:
            # Malformed line (no '='). Other exceptions, such as
            # KeyboardInterrupt, are deliberately no longer swallowed.
            print(self.term.bold_red('[-] Error'))
            exit()

    def show_config(self):
        """Print the loaded configuration as a boxed table and validate it."""
        # ``or [0]`` keeps max() from failing when the config is empty.
        max_var_len = max([len(var) for var in self.config.keys()] or [0])
        max_val_len = max([len(val) for val in self.config.values()] or [0])
        width = max_var_len + 3 + max_val_len + 1

        print(self.term.bold_green(' [+] Config file: ' +
                                   self.config_filename))
        print(' ' + self.left_up + self.hline * width + self.right_up)
        for var, val in self.config.items():
            print(' ' + self.vline + ' ' + var.ljust(max_var_len) + ' = ' +
                  val.ljust(max_val_len) + self.vline)
        print(' ' + self.left_down + self.hline * width + self.right_down)

        self.check_config()

        self.show_coninue()

    def generate_test_packets(self):
        """Build the C module and run the test-packet generator.

        :return: False when the topology file is missing, otherwise None.
        """
        topology = self._topology_file()
        print(self.term.bold_green(' [+] Generating test packets from "' +
                                   topology + '"'))
        if not self.check_config():
            self.show_coninue()
            return False
        print('')

        p = Popen(['make', '-C', 'cmodule/'], stdout=PIPE)
        print(self.term.bold_green(' [+] Make file'))
        for line in iter(p.stdout.readline, b''):
            print(self.term.bold_magenta(' ' * 5 + line.strip()))
        p.wait()  # reap the child so it does not linger as a zombie
        print('')

        p = Popen(['cmodule/./generate_test_packets', topology], stdout=PIPE)
        print(self.term.bold_green(' [+] Start to generate test packets'))
        for line in iter(p.stdout.readline, b''):
            if line.startswith('  '):
                print(self.term.bold_magenta(' ' * 5 + line.strip()))
            elif line.strip() == '':
                pass
            else:
                print(self.term.bold_green(' [+] ' + line.strip()))
        p.wait()

        self.show_coninue()

    def generate_topology_graph(self):
        """Run the topology-graph generator script on the topology file.

        :return: False when the topology file is missing, otherwise None.
        """
        topology = self._topology_file()
        print(self.term.bold_green(' [+] Generating topology graph from "' +
                                   topology + '" and "' + topology +
                                   '.testpackets"'))
        if not self.check_config():
            self.show_coninue()
            return False
        p = Popen(['pymodule/generate_topology_graph.py', '-i', topology],
                  stdout=PIPE)
        for line in iter(p.stdout.readline, b''):
            print(self.term.bold_green(' [+] ' + line.rstrip()))
        p.wait()

        self.show_coninue()

    def start_probing(self):
        """Run the controller under ryu-manager and echo its stderr output."""
        # stdout was previously piped but never read, which can block the
        # child once the pipe buffer fills; it is discarded explicitly now.
        with open(os.devnull, 'w') as devnull:
            p = Popen(['ryu-manager', 'pymodule/controller.py'],
                      stdout=devnull,
                      stderr=PIPE)
            for line in iter(p.stderr.readline, b''):
                try:
                    if line.startswith('---'):
                        print(self.term.bold_magenta(' [+] ' + line.strip()))
                    elif line.startswith('  '):
                        print(self.term.bold_magenta(' ' * 5 + line.strip()))
                    else:
                        print(self.term.bold_green(' [+] ' + line.strip()))
                except Exception:
                    # Display is best-effort: a line that cannot be rendered
                    # is dropped, but KeyboardInterrupt still propagates.
                    pass
            p.wait()
        print('')

        self.show_coninue()

    def show_exit(self):
        """Print the exit banner."""
        print(self.term.bold_red(' [+] Exit...\n'))

    def show_coninue(self):
        """Print a 'press any key' notice and wait for one key press."""
        print('')
        print(self.term.bold_green(' [+] Press any key to continue...'))
        self.getch()

    def check_config(self):
        """Return True if the configured topology file exists.

        Prints an error and returns False when it does not exist or when
        TOPOLOGY_FILE is not configured at all.
        """
        topology = self._topology_file()
        if not os.path.exists(topology):
            self.print_IOError(topology)
            return False
        return True

    def print_IOError(self, filename):
        """Print a 'no such file' error for ``filename``."""
        print(self.term.bold_red(' [-] IOError: No such file: ' + filename))

    def getch(self):
        """Read one key press from stdin in raw mode.

        :return: 'enter' for carriage return, 'exit' for ``q``/``Q``,
            'up'/'down' for the ``ESC [ A`` / ``ESC [ B`` sequences, and
            False for anything else.
        """
        fd = sys.stdin.fileno()
        old_settings = termios.tcgetattr(fd)
        try:
            tty.setraw(fd)
            ch = sys.stdin.read(1)
            if ord(ch) == 13:
                return 'enter'
            elif ord(ch) == 113 or ord(ch) == 81:
                return 'exit'
            elif ord(ch) != 27:
                return False

            # ESC seen: only the "ESC [" prefix is recognised.
            ch = sys.stdin.read(1)
            if ord(ch) != 91:
                return False

            ch = sys.stdin.read(1)
            if ord(ch) == 65:
                return 'up'
            elif ord(ch) == 66:
                return 'down'
            else:
                return False
        finally:
            # Always restore the terminal settings saved above.
            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
Exemplo n.º 8
0
        action = agent.get_action(state)
        next_state, reward, done, info = env.step(action)
        agent.train((state, action, next_state, reward, done))

        state = next_state
        total_reward += reward

        print("reward:", reward, "state:", state, "action:", action,
              "episode:", episode, "steps:", steps, "total_reward:",
              total_reward, "epsilon:", agent.epsilon)
        # env.render()

        with tf.variable_scope('q_table', reuse=True):
            weights = agent.sess.run(tf.get_variable('kernel'))
            # print(weights)

        # time.sleep(.01)
        clear_output(wait=True)
        steps += 1

    print(t.bold_red("Fell into a hole.")) if reward == 0.0 else print(
        t.bold_green("Success!"))
    # time.sleep(.7)

    # how many times it fell into a hole vs success per episode
    agent.episode_plot.append(episode)
    agent.total_reward_plot.append(total_reward)
    agent.steps_plot.append(steps)

plot_seaborn(agent.episode_plot, agent.total_reward_plot, agent.steps_plot)
Exemplo n.º 9
0
class BaseMixin(Cmd):
    """Common command-line behaviour shared by the Mqtt-Pwn CLI mixins."""

    prompt = '>> '
    ruler = '-'
    intro = banner()

    CMD_CAT_BROKER_OP = 'Broker Related Operations'
    CMD_CAT_VICTIM_OP = 'Victim Related Operations'
    CMD_CAT_GENERAL = 'General Commands'

    variables_choices = ['victim', 'scan']

    def __init__(self):
        """Initialise the shell, its session state and its prompt."""

        Cmd.__init__(self, startup_script=config.STARTUP_SCRIPT)

        # 'exit' behaves like 'quit'; the listed built-ins are hidden.
        self.aliases.update({'exit': 'quit'})
        hidden = ['load', 'pyscript', 'set', 'shortcuts', 'alias', 'unalias',
                  'py']
        self.hidden_commands.extend(hidden)

        # Nothing is selected when the shell starts.
        self.current_victim = None
        self.mqtt_client = None
        self.current_scan = None

        self.t = Terminal()

        self.base_prompt = get_prompt(self)
        self.prompt = self.base_prompt

        general_commands = (
            BaseMixin.do_edit,
            BaseMixin.do_help,
            BaseMixin.do_history,
            BaseMixin.do_quit,
            BaseMixin.do_shell,
        )
        categorize(general_commands, BaseMixin.CMD_CAT_GENERAL)

    def print_error(self, text, end='\n', start=''):
        """Output ``text`` in red behind a bold red '[-]' marker."""

        line = start + self.t.bold_red('[-]')
        line = line + ' ' + self.t.red(text)
        self.poutput(line, end=end)

    def print_info(self, text, end='\n', start=''):
        """Output ``text`` in yellow behind a bold yellow '[!]' marker."""

        line = start + self.t.bold_yellow('[!]')
        line = line + ' ' + self.t.yellow(text)
        self.poutput(line, end=end)

    def print_ok(self, text, end='\n', start=''):
        """Output ``text`` in green behind a bold green '[+]' marker."""

        line = start + self.t.bold_green('[+]')
        line = line + ' ' + self.t.green(text)
        self.poutput(line, end=end)

    def print_pairs(self, title, body):
        """Output an underlined title, then one ' - key: value' row per item."""

        heading = self.t.bold_white_underline(title)
        self.poutput(heading)

        for name, content in body.items():
            label = self.t.bold(name + ':')
            self.poutput(' - {} {}'.format(label, content))

    def update_prompt(self):
        """Recompute the prompt from the current state."""

        refreshed = get_prompt(self)
        self.prompt = refreshed
Exemplo n.º 10
0
# Train for 300 episodes (the trailing numbers in the original note,
# 1500 and 10000, are presumably alternative episode counts).
for episode in range(300):
    state = env.reset()
    steps = 0

    # Step the environment until it reports the episode as finished.
    while True:
        action = agent.get_action(state)
        next_state, reward, done, info = env.step(action)
        agent.train((state, action, next_state, reward, done))

        state = next_state
        total_reward += reward

        print("reward:", reward,
              "state:", state,
              "action:", action,
              "episode:", episode,
              "steps:", steps,
              "total_reward:", total_reward,
              "epsilon:", agent.epsilon)
        env.render()

        # print(agent.q_table)

        time.sleep(.01)
        clear_output(wait=True)
        steps += 1

        if done:
            break

    # The last reward decides which outcome message is shown.
    if reward == 0.0:
        print(t.bold_red("Fell into a hole."))
    else:
        print(t.bold_green("Success!"))
    time.sleep(.7)

    # Record per-episode statistics for the plot below
    # (how many times it fell into a hole vs success per episode).
    agent.episode_plot.append(episode)
    agent.total_reward_plot.append(total_reward)
    agent.steps_plot.append(steps)

plot_seaborn(agent.episode_plot, agent.total_reward_plot, agent.steps_plot)