Пример #1
0
def make_summary_table(overpasses, tz, twelvehour):
    """
    Make a summary data table to print to console.

    One row per overpass: AOS date and time (rounded to the nearest minute),
    duration as ``m:ss``, maximum elevation and pass type. Rows whose pass
    type is ``PassType.visible`` are styled green.

    :param overpasses: iterable of overpass objects exposing ``aos``, ``tca``,
        ``duration`` and an optional ``type``
    :param tz: timezone the AOS datetime is converted into before display
    :param twelvehour: if true, show a 12-hour clock with am/pm
    :return: a populated rich ``Table``
    """
    from datetime import timedelta

    table = Table()
    table.add_column(Align("Date", "center"), justify="right")
    table.add_column(Align("Duration", "center"), justify="right")
    table.add_column(Align("Max Elev", "center"), justify="right")
    table.add_column(Align("Type", "center"), justify='center')

    def get_min_sec_string(total_seconds: float) -> str:
        """
        Format a number of seconds as ``m:ss``, e.g. 185 -> "3:05".

        The total is rounded to whole seconds first, so 59.7 becomes "1:00"
        rather than "0:60".
        """
        minutes, seconds = divmod(int(round(total_seconds)), 60)
        return f"{minutes}:{seconds:02d}"

    for overpass in overpasses:
        aos = overpass.aos.dt.astimezone(tz)
        # Round to the nearest minute. datetime is immutable, so the rounded
        # value must be reassigned. Adding a timedelta (instead of
        # replace(minute=minute + 1)) also rolls hh:59 over into the next
        # hour/day instead of raising ValueError.
        date = aos.replace(second=0, microsecond=0)
        if aos.second >= 30:
            date += timedelta(minutes=1)
        # Take the day after rounding so it matches the displayed time.
        day = date.strftime("%x").lstrip('0')
        if twelvehour:
            time = date.strftime("%I:%M").lstrip("0")
            ampm = date.strftime('%p').lower()
            row = [f"{day:>8}  {time:>5} {ampm}"]
        else:
            time = date.strftime("%H:%M")
            row = [f"{day}  {time}"]

        row.append(get_min_sec_string(overpass.duration))
        row.append(f"{int(overpass.tca.elevation):2}\u00B0")
        fg = None
        if overpass.type:
            row.append(overpass.type.value)
            # Compare the enum member itself. `type.value == PassType.visible`
            # is only true if PassType is a str-mixin enum, whereas comparing
            # members works in either case (assumes type is a PassType member).
            if overpass.type == PassType.visible:
                fg = 'green'
        table.add_row(*row, style=fg)
    return table
Пример #2
0
    def get_renderable(self):
        """Return the renderable for the whole test view.

        If neither the tasks panel nor the metrics panel has anything to show
        yet, a centred "waiting" message is returned. Otherwise, a row layout
        ("running" on the left, "metrics" with ratio 2 on the right) is built.
        Only the slots whose panel produced content are shown, and the layout
        is wrapped in a Panel titled with the test name.
        """
        tasks = self.tasks_panel.get_renderable()
        metrics = self.metrics_panel.get_renderable()

        if tasks is None and metrics is None:
            waiting = Text(f"Waiting for test to start... ({self.test_name})")
            return Align(waiting, align="center")

        layout = Layout()
        layout.split_row(
            Layout(name="running", visible=False),
            Layout(name="metrics", ratio=2, visible=False),
        )

        # Reveal and fill each slot that has content (metrics first, then running).
        for slot, renderable in (("metrics", metrics), ("running", tasks)):
            if renderable is not None:
                layout[slot].visible = True
                layout[slot].update(renderable)

        return Panel(layout, title=self.test_name)
Пример #3
0
def test_align_width():
    """Align(width=30) wraps text to 30 columns and centres that block in 40."""
    con = Console(file=io.StringIO(), width=40)
    words = "Deep in the human unconscious is a pervasive need for a logical universe that makes sense. But the real universe is always one step beyond logic"
    con.print(Align(words, "center", width=30))
    wrapped = [
        "Deep in the human unconscious",
        "is a pervasive need for a",
        "logical universe that makes",
        "sense. But the real universe",
        "is always one step beyond",
        "logic",
    ]
    # 5 columns of left margin, then each line left-justified to 35 columns.
    expected = "".join(f"     {line:<35}\n" for line in wrapped)
    assert con.file.getvalue() == expected
Пример #4
0
def test_align_center_middle():
    """Two lines are centred horizontally and vertically in a 10x5 area."""
    con = Console(file=io.StringIO(), width=10)
    con.print(Align("foo\nbar", "center", vertical="middle"), height=5)
    output = con.file.getvalue()
    print(repr(output))
    blank = " " * 10 + "\n"
    assert output == blank + "   foo    \n" + "   bar    \n" + blank * 2
Пример #5
0
def test_align_bottom():
    """vertical="bottom" puts the text on the last of the requested lines."""
    con = Console(file=io.StringIO(), width=10)
    con.print(Align("foo", vertical="bottom"), height=5)
    blank = " " * 10 + "\n"
    expected = blank * 4 + "foo       \n"
    output = con.file.getvalue()
    print(repr(output))
    assert output == expected
Пример #6
0
def test_align_right_style():
    """The Align style is applied to both the padding and the text."""
    con = Console(
        file=io.StringIO(),
        width=10,
        color_system="truecolor",
        force_terminal=True,
    )
    con.print(Align("foo", "right", style="on blue"))
    on_blue, reset = "\x1b[44m", "\x1b[0m"
    assert con.file.getvalue() == f"{on_blue}       {reset}{on_blue}foo{reset}\n"
Пример #7
0
def make_detail_table(overpasses, tz, twelvehour):
    """
    Make a detailed data table to print to console.

    One row per overpass: the AOS date (mm/dd/yy), then time / elevation /
    direction for the start (AOS), maximum (TCA) and end (LOS) points, then
    the pass type. Rows whose pass type is ``PassType.visible`` are green.

    :param overpasses: iterable of overpass objects exposing ``aos``, ``tca``,
        ``los`` points and an optional ``type``
    :param tz: timezone each point's datetime is converted into
    :param twelvehour: if true, show times on a 12-hour clock with am/pm
    :return: a populated rich ``Table``
    """
    table = Table()
    table.add_column(Align("Date", 'center'), justify="right")
    for x in ('Start', 'Max', 'End'):
        table.add_column(Align(f"{x}\nTime", "center"), justify="right")
        table.add_column(Align(f"{x}\nEl", "center"), justify="right", width=4)
        table.add_column(Align(f"{x}\nAz", "center"), justify="right", width=4)
    table.add_column(Align("Type", "center"), justify='center')

    def point_string(point):
        """Return [time, elevation, direction] display strings for one pass point."""
        time = point.dt.astimezone(tz)
        if twelvehour:
            clock = (time.strftime("%I:%M:%S").lstrip("0") + ' ' +
                     time.strftime("%p").lower())
        else:
            clock = time.strftime("%H:%M:%S")
        return [
            clock,
            "{:>2}\u00B0".format(int(point.elevation)),
            "{:3}".format(point.direction),
        ]

    for overpass in overpasses:
        row = [overpass.aos.dt.astimezone(tz).strftime("%m/%d/%y")]
        row += point_string(overpass.aos)
        row += point_string(overpass.tca)
        row += point_string(overpass.los)
        fg = None
        if overpass.type:
            row.append(overpass.type.value)
            # Compare the enum member itself. `type.value == PassType.visible`
            # is only true if PassType is a str-mixin enum, whereas comparing
            # members works in either case (assumes type is a PassType member).
            if overpass.type == PassType.visible:
                fg = 'green'
        table.add_row(*row, style=fg)
    return table
Пример #8
0
def main(filename: str) -> None:
    """Print a proposed ``setup.cfg`` and reduced ``setup()`` for a project.

    Parses ``<filename>/setup.py`` (``filename`` is the project directory)
    and runs ``Analyzer`` over it. It prints what the analyzer collected as
    ``metadata`` and ``options``, followed by the analyzer's setup function
    unparsed and formatted with black.

    :raises RuntimeError: if the analyzer found no setup function.
    """
    setup_py = Path(filename) / "setup.py"
    # Parse raw bytes so the file's PEP 263 coding cookie (default UTF-8)
    # decides the encoding, not the locale default that text-mode open() uses.
    # Passing filename makes any SyntaxError point at the real file.
    tree = ast.parse(setup_py.read_bytes(), filename=str(setup_py))

    analyzer = Analyzer()
    analyzer.visit(tree)
    if analyzer.setup_function is None:
        raise RuntimeError("Invalid, no setup function found")

    print(Panel(Align("New [bold]setup.cfg", align="center")))
    print()
    print(analyzer.metadata)
    print()
    print(analyzer.options)
    print()
    print(Panel(Align("New setup() in [bold]setup.py", "center")))
    print()
    new_setup_py = black.format_str(ast.unparse(analyzer.setup_function),
                                    mode=black.Mode())
    print(Syntax(new_setup_py, "python", theme="ansi_light"))
def final_check(pgscatalog_df, plink_variants_df):
    """Show the keys the two inputs will be matched on and ask to proceed.

    It previews up to 10 keys from ``pgscatalog_df`` whose
    ``PGSCATALOG_KEY_COLUMN`` value also occurs in
    ``plink_variants_df[PLINK_KEY_COLUMN]``. The preview appears in both
    panes of the global ``layout``.

    In auto mode (``CONFIG`` is set) it continues without asking; otherwise
    the user is asked to confirm. On confirmation, both full tables are shown.
    On refusal, troubleshooting guidance is printed and the process exits
    with status 50.
    """
    import sys

    # Rows of the PGS Catalog file whose key also appears in the plink keys.
    overlaps = pgscatalog_df[PGSCATALOG_KEY_COLUMN].isin(
        plink_variants_df[PLINK_KEY_COLUMN])
    sample = pgscatalog_df.loc[overlaps, [PGSCATALOG_KEY_COLUMN]].head(10)
    layout["top"]["leftfile"].update(
        Align(
            render_file_table(sample, title="PRS WM file key"),
            align="center",
        ))
    layout["top"]["rightfile"].update(
        Align(
            render_file_table(sample, title="plink data key"),
            align="center",
        ))
    console.print(layout)
    printout(
        "-----------------\nPlease review the keys the data will be matched on"
    )
    if CONFIG is not None:
        is_to_run = True  # auto mode
    else:
        # this is intentionally not done using "ask()" due to a different logic
        is_to_run = Confirm.ask(
            "Do you confirm you want to match the files on these keys?"
        )  # manual mode
    if is_to_run:
        layout["top"]["leftfile"].update(
            render_file_table(pgscatalog_df, title="PRS WM file"))
        layout["top"]["rightfile"].update(
            render_file_table(plink_variants_df, title="plink data"))
        console.print(layout)
        printout("-----------------")
    else:
        printout(
            "You have decided to halt the execution due to the mismatch between the IDs in two files."
        )
        printout(
            "Please, review this troubleshooting guide, if you need an inspiration about how fixing this:"
        )
        print_error_files()
        # Use sys.exit rather than the site-provided `exit` helper, which is
        # intended for the interactive shell and is missing under `python -S`.
        sys.exit(50)
Пример #10
0
def generate_display(display: str) -> Align:
    """Generate a centred panel to display using the rich library.

    Args:
        display (str): The visualized values to display.

    Returns:
        rich.align.Align: The panel (titled "Devilizer", not expanded to the
        full width) wrapped in a centre alignment.
    """
    panel = Panel(display,
                  expand=False,
                  title="[bold blue]Devilizer[/bold blue]")
    return Align(panel, align="center")
Пример #11
0
def test_bad_align_legal():
    """Align accepts only lowercase left/center/right and valid vertical values."""
    for horizontal in ("left", "center", "right"):
        Align("foo", horizontal)

    for horizontal in (None, "middle", "", "LEFT"):
        with pytest.raises(ValueError):
            Align("foo", horizontal)
    with pytest.raises(ValueError):
        Align("foo", vertical="somewhere")
Пример #12
0
        def make_header_columns():
            """Yield the header cells, group by group.

            Groups, in order:
            - Epoch: "#", plus "LR" when ``learning_rate`` is set.
            - Training-loss total: one extra cell per target when
              ``multi_target``.
            - Validation-loss total: only when ``total_loss_validation`` is
              set.
            - One group per metric.

            Reads ``learning_rate``, ``col_width``, ``multi_target``,
            ``losses_training``, ``total_loss_validation``,
            ``losses_validation`` and ``metrics`` from the enclosing scope.
            """
            # Epoch / learning-rate group.
            epoch_headers = [Text("#", justify="right", style="bold")]
            if learning_rate is not None:
                epoch_headers.append(Text("LR", justify="right"))
            yield Columns(epoch_headers, align="center", width=6)

            # Training-loss group.
            train_total = Align(
                Text("Total", justify="right", style="bold red"),
                width=col_width,
                align="center",
            )
            if not multi_target:
                yield train_total
            else:
                train_targets = [
                    Align(Text(target, justify="right", style="red"),
                          width=col_width)
                    for target in losses_training
                ]
                yield Columns([train_total, *train_targets],
                              align="center",
                              width=col_width)

            # Validation-loss group, only when validation is tracked.
            if total_loss_validation is not None:
                val_total = Align(
                    Text("Total", justify="center", style="bold blue"),
                    width=col_width,
                    align="center",
                )
                if not multi_target:
                    yield val_total
                else:
                    val_targets = [
                        Align(Text(target, justify="center", style="blue"),
                              width=col_width)
                        for target in losses_validation
                    ]
                    yield Columns([val_total, *val_targets],
                                  align="center",
                                  width=col_width)

            # Metric groups: one cell per sub-metric for dict-valued metrics,
            # otherwise a single empty placeholder cell.
            if metrics is None:
                return
            for values in metrics.values():
                if not isinstance(values, dict):
                    yield Align(Text(""), width=col_width)
                    continue
                yield Columns(
                    [
                        Align(
                            Text(key, justify="center", style="purple"),
                            width=col_width,
                        ) for key in values
                    ],
                    align="center",
                    width=col_width,
                )
Пример #13
0
def dry_run_banner() -> Panel:
    """Build the borderless black-on-yellow banner shown for ``--dry-run``.

    The banner is as wide as the terminal, capped at 75 columns.
    """
    banner_width = min(get_terminal_size().columns, 75)
    message = """\
You are in [bold blue]dry-run mode[/].
Issues and cards will not be created.

Creation of objects from --create-missing will still occur
[dim](e.g. missing labels will be created if you answer 'yes')[/]\
"""
    centred = Align(message, "center", width=banner_width)
    return Panel(
        centred,
        title="--dry-run was passed",
        width=banner_width,
        style="black on yellow",
        box=Box("    \n" * 8),
    )
Пример #14
0
def run_ensemble_strategy(
    df, unique_trade_date, rebalance_window, validation_window
) -> None:
    """Walk-forward ensemble strategy that selects among PPO, A2C and DDPG.

    It steps through ``unique_trade_date`` in strides of
    ``rebalance_window``. For each step it:

    1. Picks a turbulence threshold from in-sample and recent turbulence.
    2. Trains A2C, PPO and DDPG on data from ``config.TRAINING_START_DATE``
       up to the start of the validation slice.
    3. Validates each model on the slice and keeps the one with the best
       validation Sharpe ratio.
    4. Trades with that model through ``DRL_prediction``.

    Args:
        df: Market data; must have ``datadate`` and ``turbulence`` columns
            (both are read here). ``datadate`` values are compared with
            ``config`` dates and ``- 1`` is applied to one, so they are
            presumably numeric (e.g. YYYYMMDD ints) — TODO confirm.
        unique_trade_date: Position-indexed sequence of dates; presumably the
            sorted unique trading dates of the trade period — TODO confirm.
        rebalance_window: Number of ``unique_trade_date`` entries between
            successive retraining steps.
        validation_window: Number of ``unique_trade_date`` entries in each
            validation slice.

    Returns:
        None. Per-window results are printed (live table, threshold), and the
        total runtime is printed at the end.
    """
    # The last state of the previous trading window is fed to the current
    # model as its initial state.
    last_state_ensemble = []

    # NOTE(review): these lists are filled in below but never returned or
    # used afterwards in this function.
    ppo_sharpe_list = []
    ddpg_sharpe_list = []
    a2c_sharpe_list = []

    model_used = []

    # In-sample turbulence: rows dated in
    # [TRAINING_START_DATE, VALIDATION_START_DATE - 1), keeping the first row
    # per date. Turbulence is presumably identical across the rows of a date
    # — TODO confirm. Its 90% quantile is the default threshold.
    insample_turbulence = df[
        (df.datadate < config.VALIDATION_START_DATE - 1)
        & (df.datadate >= config.TRAINING_START_DATE)
    ]
    insample_turbulence = insample_turbulence.drop_duplicates(subset=["datadate"])
    insample_turbulence_threshold = np.quantile(
        insample_turbulence.turbulence.values, 0.90
    )

    start = time.time()
    # Position bookkeeping for each step i:
    # - Validation slice: unique_trade_date[i - rebalance_window - validation_window]
    #   up to unique_trade_date[i - rebalance_window].
    # - Training ends where validation starts.
    # - DRL_prediction receives i and rebalance_window to trade the
    #   following window; its exact span is defined there.
    for i in range(
        rebalance_window + validation_window, len(unique_trade_date), rebalance_window
    ):
        # Only the first step starts without a previous trading state.
        if i - rebalance_window - validation_window == 0:
            # initial state
            initial = True
        else:
            # previous state
            initial = False

        # Tune the turbulence threshold from recent history: take the rows
        # ending at the last row whose date is the validation-slice start.
        end_date_index = df.index[
            df["datadate"]
            == unique_trade_date[i - rebalance_window - validation_window]
        ].to_list()[-1]
        # NOTE(review): the factor 30 presumably means "30 rows (tickers) per
        # date", making this a lookback of validation_window dates — TODO
        # confirm. The index *label* found above is used as an iloc
        # *position* below, which only matches a default RangeIndex. A
        # negative start would also make iloc count from the end of df.
        start_date_index = end_date_index - validation_window * 30 + 1

        historical_turbulence = df.iloc[start_date_index : (end_date_index + 1), :]

        historical_turbulence = historical_turbulence.drop_duplicates(
            subset=["datadate"]
        )

        historical_turbulence_mean = np.mean(historical_turbulence.turbulence.values)

        if historical_turbulence_mean > insample_turbulence_threshold:
            # Recent turbulence is above the in-sample 90% quantile, so the
            # market is treated as volatile. That quantile, the lower of the
            # two candidate thresholds, is used.
            turbulence_threshold = insample_turbulence_threshold
        else:
            # Otherwise relax the threshold to the in-sample maximum
            # (quantile 1.0). Only turbulence beyond anything seen in-sample
            # then exceeds it.
            turbulence_threshold = np.quantile(insample_turbulence.turbulence.values, 1)

        style = "[bold #31DDCF]"
        rprint(
            Align(
                f"{style}Turbulence Threshold:[/] {str(turbulence_threshold)}", "center"
            )
        )

        ############## Environment Setup starts ##############
        ## training env: TRAINING_START_DATE up to the validation-slice start
        train = data_split(
            df,
            start=config.TRAINING_START_DATE,
            end=unique_trade_date[i - rebalance_window - validation_window],
        )
        env_train = DummyVecEnv([lambda: StockEnvTrain(train)])

        ## validation env over the validation slice, using this step's threshold
        validation = data_split(
            df,
            start=unique_trade_date[i - rebalance_window - validation_window],
            end=unique_trade_date[i - rebalance_window],
        )
        env_val = DummyVecEnv(
            [
                lambda: StockEnvValidation(
                    validation, turbulence_threshold=turbulence_threshold, iteration=i
                )
            ]
        )

        # NOTE(review): this env and initial observation are shared by all
        # three validations below — confirm DRL_validation handles resets.
        obs_val = env_val.reset()
        ############## Environment Setup ends ##############

        ############## Training and Validation starts ##############
        # Live table with one row per model: name, validation Sharpe, training time.
        # NOTE(review): the start date in the title is hard-coded; presumably
        # it should mirror config.TRAINING_START_DATE — confirm.
        table = Table(
            title=f"Training from 20090000 to {unique_trade_date[i - rebalance_window - validation_window]}",
            expand=True,
        )
        table.add_column("Mode Name", justify="center")
        table.add_column("Sharpe Ratio")
        table.add_column("Training Time")

        with Live(table, auto_refresh=False) as live:

            # For each model: train, validate, then fetch its validation
            # Sharpe. get_validation_sharpe(i) presumably reads what
            # DRL_validation produced for step i — TODO confirm. Order matters.
            model_a2c, a2c_training_time = train_A2C(
                env_train, model_name="A2C_30k_dow_{}".format(i), timesteps=30000
            )
            DRL_validation(
                model=model_a2c,
                test_data=validation,
                test_env=env_val,
                test_obs=obs_val,
            )
            sharpe_a2c = get_validation_sharpe(i)
            table.add_row("A2C", str(sharpe_a2c), f"{a2c_training_time} minutes")
            live.update(table, refresh=True)

            model_ppo, ppo_training_time = train_PPO(
                env_train, model_name="PPO_100k_dow_{}".format(i), timesteps=100000
            )

            DRL_validation(
                model=model_ppo,
                test_data=validation,
                test_env=env_val,
                test_obs=obs_val,
            )
            sharpe_ppo = get_validation_sharpe(i)
            table.add_row("PPO", str(sharpe_ppo), f"{ppo_training_time} minutes")
            live.update(table, refresh=True)

            model_ddpg, ddpg_training_time = train_DDPG(
                env_train, model_name="DDPG_10k_dow_{}".format(i), timesteps=10000
            )
            DRL_validation(
                model=model_ddpg,
                test_data=validation,
                test_env=env_val,
                test_obs=obs_val,
            )
            sharpe_ddpg = get_validation_sharpe(i)
            table.add_row("DDPG", str(sharpe_ddpg), f"{ddpg_training_time} minutes")
            live.update(table, refresh=True)

            ppo_sharpe_list.append(sharpe_ppo)
            a2c_sharpe_list.append(sharpe_a2c)
            ddpg_sharpe_list.append(sharpe_ddpg)

        # Model selection by validation Sharpe ratio:
        # - PPO wins ties with either model.
        # - A2C must be strictly best.
        # - Otherwise DDPG.
        # `&` on these bool comparisons acts as a logical and.
        if (sharpe_ppo >= sharpe_a2c) & (sharpe_ppo >= sharpe_ddpg):
            model_ensemble = model_ppo
            model_used.append("PPO")
        elif (sharpe_a2c > sharpe_ppo) & (sharpe_a2c > sharpe_ddpg):
            model_ensemble = model_a2c
            model_used.append("A2C")
        else:
            model_ensemble = model_ddpg
            model_used.append("DDPG")
        ############## Training and Validation ends ##############

        ############## Trading starts ##############

        # Trade with the selected model. The returned value is passed back as
        # last_state on the next step.
        last_state_ensemble = DRL_prediction(
            df=df,
            model=model_ensemble,
            name="ensemble",
            last_state=last_state_ensemble,
            iter_num=i,
            unique_trade_date=unique_trade_date,
            rebalance_window=rebalance_window,
            turbulence_threshold=turbulence_threshold,
            initial=initial,
        )
        print("\n\n")
        ############## Trading ends ##############

    end = time.time()
    print("Ensemble Strategy took: ", (end - start) / 60, " minutes")
Пример #15
0
def test_align_fit():
    """Text exactly as wide as the console is printed unchanged."""
    con = Console(file=io.StringIO(), width=10)
    con.print(Align("foobarbaze", "center"))
    output = con.file.getvalue()
    assert output == "foobarbaze\n"
Пример #16
0
def test_render():
    """Left alignment prints the text with no padding after it."""
    con = Console(file=io.StringIO(), width=10)
    con.print(Align("foo", "left"))
    output = con.file.getvalue()
    assert output == "foo\n"
Пример #17
0
def test_align_right():
    """Right alignment left-pads the text to the console width."""
    con = Console(file=io.StringIO(), width=10)
    con.print(Align("foo", "right"))
    assert con.file.getvalue() == " " * 7 + "foo\n"
Пример #18
0
def test_repr():
    """repr() of Align works for every horizontal alignment."""
    for alignment in ("left", "center", "right"):
        repr(Align("foo", alignment))
Пример #19
0
        def make_columns():
            """Yield one progress row, group by group.

            Groups, in order:
            - Epoch, plus learning rate when one is set.
            - Training-loss total: one extra cell per target when
              ``multi_target``.
            - Validation-loss total: only when ``total_loss_validation`` is
              set.
            - One group per metric.

            The learning-rate cell is emitted only when ``learning_rate`` is
            not None. Formatting None with ``:1.4f`` would raise TypeError,
            and the header builder in this file already treats None as "no LR
            column". Reads ``epoch``, ``learning_rate``, ``col_width``,
            ``multi_target``, ``total_loss_training``, ``losses_training``,
            ``total_loss_validation``, ``losses_validation`` and ``metrics``
            from the enclosing scope.
            """
            # Epoch (and learning rate, when tracked).
            epoch_cells = [
                Align(Text(f"{epoch:3}", justify="right", style="bold"),
                      width=4),
            ]
            if learning_rate is not None:
                epoch_cells.append(
                    Align(Text(f"{learning_rate:1.4f}", justify="right"),
                          width=6))
            yield Columns(epoch_cells, align="center", width=6)

            # Training losses.
            text = Align(
                Text(f"{total_loss_training:3.3f}", style="bold red"),
                align="center",
                width=col_width,
            )
            if multi_target:
                columns = [text] + [
                    Align(
                        Text(f"{loss:3.3f}", justify="right", style="red"),
                        width=col_width,
                        align="right",
                    ) for loss in losses_training.values()
                ]
                yield Columns(columns, align="center", width=col_width)
            else:
                yield text

            # Validation losses, only when validation is tracked.
            if total_loss_validation is not None:
                text = Align(
                    Text(
                        f"{total_loss_validation:.3f}",
                        justify="center",
                        style="bold blue",
                    ),
                    width=col_width,
                    align="center",
                )
                if multi_target:
                    columns = [text] + [
                        Align(
                            Text(f"{loss:.3f}",
                                 justify="center",
                                 style="blue"),
                            width=col_width,
                            align="center",
                        ) for loss in losses_validation.values()
                    ]
                    yield Columns(columns, align="center", width=col_width)
                else:
                    yield text

            # Metrics: one cell per sub-metric for dict-valued metrics,
            # otherwise a single value cell.
            if metrics is not None:
                for values in metrics.values():
                    if isinstance(values, dict):
                        columns = [
                            Align(
                                Text(f"{v:.3f}",
                                     justify="center",
                                     style="purple"),
                                width=col_width,
                            ) for v in values.values()
                        ]
                        yield Columns(columns, align="center", width=col_width)
                    else:
                        yield Align(Text(f"{values:.3f}"),
                                    width=col_width,
                                    style="purple")
Пример #20
0
def test_align_no_pad():
    """With pad=False no trailing spaces follow the aligned text."""
    con = Console(file=io.StringIO(), width=10)
    for alignment in ("center", "left"):
        con.print(Align("foo", alignment, pad=False))
    assert con.file.getvalue() == "   foo\n" + "foo\n"
Пример #21
0
def test_measure():
    """Measuring a left-aligned "foo bar" at width 20 yields min 3, max 7."""
    con = Console(file=io.StringIO(), width=20)
    minimum, maximum = Measurement.get(con, Align("foo bar", "left"), 20)
    assert (minimum, maximum) == (3, 7)
Пример #22
0
def check(
    c,
    lint_=True,
    fixmes_=False,
    test_=True,
    coverage_=True,
    mypy_=True,
    black_=True,
    isort_=True,
    docs_=True,
    clean_=True,
):
    """Runs all checkers on the code.

    Each ``*_`` flag enables one task. Enabled tasks run in a fixed order,
    each preceded by a banner. Their exit codes are collected into a report
    table, followed by a commit / don't-commit message. With ``clean_``, the
    clean task runs silently at the end.

    Raises:
        Exit: always. The code is 1 if any task exited non-zero, else 0.
    """
    # (enabled, report key, banner text, runner) in execution order.
    tasks = (
        (lint_, "lint", "Running pylint...", lambda: lint(c)),
        (fixmes_, "FIXME's", "Running pylint (fixmes)...", lambda: fixmes(c)),
        (test_, "test", "Running tests...", lambda: test(c, verbose=False)),
        (coverage_, "coverage", "Reporting test coverage...", lambda: coverage(c)),
        (mypy_, "mypy", "Running mypy...", lambda: mypy(c)),
        (
            black_,
            "black",
            "Running black (formatting, just checking)...",
            lambda: black(c, check=True),
        ),
        (
            isort_,
            "isort",
            "Running isort (formatting, just checking)...",
            lambda: isort(c, check=True),
        ),
        (
            docs_,
            "docs",
            "Running mkdocs...",
            lambda: docs(c, build=True, verbose=False),
        ),
    )

    rule = "-" * 20
    results = {}
    for enabled, key, banner, run in tasks:
        if not enabled:
            continue
        print(rule)
        print(banner)
        print(rule)
        results[key] = run().exited

    status = 1 if any(results.values()) else 0

    report = Table(
        title="Report",
        title_style="bold white",
        show_header=True,
        header_style="bold white",
        show_footer=True,
        footer_style="bold white",
        show_lines=True,
        box=box.ROUNDED,
    )
    report.add_column("Task", "Summary")
    report.add_column("Result", f"[bold]{_code_to_stat(status, underline=True)}[/bold]")
    for task_name, code in results.items():
        report.add_row(task_name, _code_to_stat(code))

    print("\n")
    con.print(Align(report, "center"))

    exit_msg = (
        "Congratulations :sparkles::fireworks::sparkles: You may commit! :heavy_check_mark:"
        if status == 0
        else "Great code dude :+1:, but it could use some final touches. Don't commit just yet! :x:"
    )
    print(Align(f"[underline bold]{exit_msg}[/underline bold]", "center"))
    print("\n")

    if clean_:
        clean(c, silent=True)

    raise Exit(code=status)
Пример #23
0
def rich_format_help(
    *,
    obj: Union[click.Command, click.Group],
    ctx: click.Context,
    markup_mode: MarkupMode,
) -> None:
    """Print nicely formatted help text using rich (replaces click's format_help()).

    Based on original code from rich-cli, by @willmcgugan.
    https://github.com/Textualize/rich-cli/blob/8a2767c7a340715fc6fbf4930ace717b9b2fc5e5/src/rich_cli/__main__.py#L162-L236

    Output, in order:
    - The usage line.
    - The command's help text.
    - Argument panels, then option panels.
    - For multi-commands, the sub-command panels.
    - The epilog.

    Within each kind of panel, the default panel is printed first (even when
    empty) and custom panels follow in first-seen order. Hidden parameters
    and hidden commands are skipped.
    """
    console = _get_rich_console()

    def _default_first(groups, default_name):
        """Yield (panel name, members): default panel first, then the rest in order."""
        yield default_name, groups.get(default_name, [])
        for group_name, members in groups.items():
            if group_name != default_name:
                yield group_name, members

    # Usage line.
    console.print(Padding(highlighter(obj.get_usage(ctx)), 1),
                  style=STYLE_USAGE_COMMAND)

    # The command/group's own help text, padded on all sides except the top.
    if obj.help:
        help_text = _get_help_text(obj=obj, markup_mode=markup_mode)
        console.print(Padding(Align(help_text, pad=False), (0, 1, 1, 1)))

    # Bucket visible parameters by the panel they belong to.
    panel_to_arguments: DefaultDict[str, List[click.Argument]] = defaultdict(list)
    panel_to_options: DefaultDict[str, List[click.Option]] = defaultdict(list)
    for param in obj.get_params(ctx):
        if getattr(param, "hidden", False):
            continue
        if isinstance(param, click.Argument):
            target_panel = (getattr(param, _RICH_HELP_PANEL_NAME, None)
                            or ARGUMENTS_PANEL_TITLE)
            panel_to_arguments[target_panel].append(param)
        elif isinstance(param, click.Option):
            target_panel = (getattr(param, _RICH_HELP_PANEL_NAME, None)
                            or OPTIONS_PANEL_TITLE)
            panel_to_options[target_panel].append(param)

    for grouped, default_title in (
        (panel_to_arguments, ARGUMENTS_PANEL_TITLE),
        (panel_to_options, OPTIONS_PANEL_TITLE),
    ):
        for panel_name, params in _default_first(grouped, default_title):
            _print_options_panel(
                name=panel_name,
                params=params,
                ctx=ctx,
                markup_mode=markup_mode,
                console=console,
            )

    if isinstance(obj, click.MultiCommand):
        panel_to_commands: DefaultDict[str, List[click.Command]] = defaultdict(list)
        for command_name in obj.list_commands(ctx):
            command = obj.get_command(ctx, command_name)
            if command and not command.hidden:
                target_panel = (getattr(command, _RICH_HELP_PANEL_NAME, None)
                                or COMMANDS_PANEL_TITLE)
                panel_to_commands[target_panel].append(command)

        for panel_name, commands in _default_first(panel_to_commands,
                                                   COMMANDS_PANEL_TITLE):
            _print_commands_panel(
                name=panel_name,
                commands=commands,
                markup_mode=markup_mode,
                console=console,
            )

    # Epilog: collapse single newlines within paragraphs and join paragraphs
    # with a single newline.
    if obj.epilog:
        paragraphs = obj.epilog.split("\n\n")
        epilogue = "\n".join(p.replace("\n", " ").strip() for p in paragraphs)
        epilogue_text = _make_rich_rext(text=epilogue, markup_mode=markup_mode)
        console.print(Padding(Align(epilogue_text, pad=False), 1))
Пример #24
0
    def step(self, actions):
        """Advance the trading environment by one step (one day index).

        On the terminal step (``self.day`` has reached the last unique index
        label of ``self.df``):
        - Saves a plot and a CSV of the account-value history.
        - Prints a summary panel (assets, reward, cost, trades, Sharpe ratio).
        - Writes the per-step rewards to CSV.

        Otherwise:
        - Scales the actions and executes sells, then buys.
        - Advances to the next day and rebuilds the state.
        - Records the change in total asset value as the reward.

        Args:
            actions: Per-stock action array. It is multiplied by
                ``HMAX_NORMALIZE``, so it is presumably normalised to
                [-1, 1] — TODO confirm.

        Returns:
            Tuple ``(state, reward, terminal, info)`` with an empty ``info``
            dict.
        """
        # The episode ends once self.day reaches the last unique index label.
        self.terminal = self.day >= len(self.df.index.unique()) - 1

        if self.terminal:
            # Save the account-value history as a PNG plot and a CSV.
            plt.plot(self.asset_memory, "r")
            plt.savefig("results/account_value_trade_{}_{}.png".format(
                self.model_name, self.iteration))
            plt.close()
            df_total_value = pd.DataFrame(self.asset_memory)
            df_total_value.to_csv(
                "results/account_value_trade_{}_{}.csv".format(
                    self.model_name, self.iteration))
            # Total assets = state[0] (presumably cash) + sum(price * holding).
            # Layout per the rebuild in the non-terminal branch:
            # state[1:STOCK_DIM+1] are prices (adjcp);
            # state[STOCK_DIM+1:2*STOCK_DIM+1] are holdings.
            end_total_asset = self.state[0] + sum(
                np.array(self.state[1:(STOCK_DIM + 1)]) *
                np.array(self.state[(STOCK_DIM + 1):(STOCK_DIM * 2 + 1)]))

            df_total_value.columns = ["account_value"]
            df_total_value["daily_return"] = df_total_value.pct_change(1)
            # NOTE(review): sqrt(4) annualisation of per-step returns is
            # unusual (sqrt(252) is common for daily data) — confirm intent.
            sharpe = ((4**0.5) * df_total_value["daily_return"].mean() /
                      df_total_value["daily_return"].std())

            # End-of-episode summary panel. "Total Reward" is the final total
            # assets minus the initial account value (asset_memory[0]).
            style = "[bold #31DDCF]"
            rprint(
                Align(
                    Panel(
                        RenderGroup(
                            f"{style}Previous Total Asset:[/] {'$' + str(self.asset_memory[0])}",
                            f"{style}End Total Asset:[/] {'$' + str(end_total_asset)}",
                            f"""{style}Total Reward:[/] {'$' + str(self.state[0]
                    + sum(
                        np.array(self.state[1 : (STOCK_DIM + 1)])
                        * np.array(self.state[(STOCK_DIM + 1) : (STOCK_DIM * 2 + 1)])
                    )
                    - self.asset_memory[0])}""",
                            f"{style}Total Cost:[/] {'$' + str(self.cost)}",
                            f"{style}Total Trades:[/] {str(self.trades)}",
                            f"{style}Sharpe Ratio:[/] {str(sharpe)}",
                        ),
                        title=self.title,
                        expand=True,
                    ),
                    "center",
                ))

            # Save the per-step (unscaled) rewards.
            df_rewards = pd.DataFrame(self.rewards_memory)
            df_rewards.to_csv("results/account_rewards_trade_{}_{}.csv".format(
                self.model_name, self.iteration))

            return self.state, self.reward, self.terminal, {}

        else:
            # Scale actions by HMAX_NORMALIZE (presumably to share counts).
            actions = actions * HMAX_NORMALIZE
            # Turbulence guard: at or above the threshold, every stock gets
            # action -HMAX_NORMALIZE, i.e. all of them go through the sell path.
            if self.turbulence >= self.turbulence_threshold:
                actions = np.array([-HMAX_NORMALIZE] * STOCK_DIM)

            begin_total_asset = self.state[0] + sum(
                np.array(self.state[1:(STOCK_DIM + 1)]) *
                np.array(self.state[(STOCK_DIM + 1):(STOCK_DIM * 2 + 1)]))

            # Sells use the negative actions, most negative first. Buys use
            # the positive actions, largest first. All sells run before any buy.
            argsort_actions = np.argsort(actions)

            sell_index = argsort_actions[:np.where(actions < 0)[0].shape[0]]
            buy_index = argsort_actions[::-1][:np.where(
                actions > 0)[0].shape[0]]

            for index in sell_index:
                self._sell_stock(index, actions[index])

            for index in buy_index:
                self._buy_stock(index, actions[index])

            # Advance one day and load that day's rows.
            self.day += 1
            self.data = self.df.loc[self.day, :]
            self.turbulence = self.data["turbulence"].values[0]
            # New state:
            # [state[0], adjcp..., holdings..., macd..., rsi..., cci..., adx...].
            # state[0] and the holdings are carried over from self.state,
            # presumably as updated by _sell_stock/_buy_stock above — TODO confirm.
            self.state = (
                [self.state[0]] + self.data.adjcp.values.tolist() +
                list(self.state[(STOCK_DIM + 1):(STOCK_DIM * 2 + 1)]) +
                self.data.macd.values.tolist() +
                self.data.rsi.values.tolist() + self.data.cci.values.tolist() +
                self.data.adx.values.tolist())

            end_total_asset = self.state[0] + sum(
                np.array(self.state[1:(STOCK_DIM + 1)]) *
                np.array(self.state[(STOCK_DIM + 1):(STOCK_DIM * 2 + 1)]))
            self.asset_memory.append(end_total_asset)

            # Reward is the change in total assets over this step. The
            # unscaled value is logged; the scaled value is returned.
            self.reward = end_total_asset - begin_total_asset
            self.rewards_memory.append(self.reward)

            self.reward = self.reward * REWARD_SCALING

        return self.state, self.reward, self.terminal, {}