def make_summary_table(overpasses, tz, twelvehour):
    """Make a summary data table to print to console.

    One row per overpass: the AOS date/time (converted to ``tz`` and rounded
    to the nearest minute), the pass duration as ``m:ss``, the maximum
    elevation in degrees and the pass type. Rows whose type value equals
    ``PassType.visible``'s value are styled green.

    Args:
        overpasses: Iterable of overpass objects; reads ``aos.dt``,
            ``duration`` (presumably seconds — TODO confirm), ``tca.elevation``
            and ``type``.
        tz: tzinfo the AOS time is converted to.
        twelvehour: If true, show a 12-hour time with am/pm.

    Returns:
        The populated rich ``Table``.
    """
    from datetime import timedelta

    table = Table()
    table.add_column(Align("Date", "center"), justify="right")
    table.add_column(Align("Duration", "center"), justify="right")
    table.add_column(Align("Max Elev", "center"), justify="right")
    table.add_column(Align("Type", "center"), justify="center")

    def get_min_sec_string(total_seconds: int) -> str:
        """Format a number of seconds as ``m:ss`` with zero-padded seconds."""
        # Round once up front so e.g. 59.6 s gives "1:00", never "0:60".
        nmin, nsec = divmod(int(round(total_seconds)), 60)
        return f"{nmin}:{nsec:02d}"

    for overpass in overpasses:
        row = []
        date = overpass.aos.dt.astimezone(tz)
        # Round to the nearest minute. datetime.replace() returns a new object
        # (and rejects minute=60), so truncate and add a timedelta instead.
        rounded = date.replace(second=0, microsecond=0)
        if date.second >= 30:
            rounded += timedelta(minutes=1)
        date = rounded
        # Take the day after rounding in case rounding crossed midnight.
        day = date.strftime("%x").lstrip("0")
        if twelvehour:
            time = date.strftime("%I:%M").lstrip("0")
            ampm = date.strftime("%p").lower()
            row.append(f"{day:>8} {time:>5} {ampm}")
        else:
            time = date.strftime("%H:%M")
            row.append(f"{day} {time}")
        # TODO: a brightness column was sketched here (overpass.brightness) but never enabled.
        row.append(get_min_sec_string(overpass.duration))
        row.append(f"{int(overpass.tca.elevation):2}\u00B0")
        if overpass.type:
            row.append(overpass.type.value)
            # Compare underlying values: comparing a value to the enum member
            # itself only matches if PassType is a str-mixin enum.
            fg = "green" if overpass.type.value == PassType.visible.value else None
        else:
            fg = None
        table.add_row(*row, style=fg)
    return table
def get_renderable(self):
    """Return the renderable for the whole test view.

    While neither the tasks panel nor the metrics panel has anything to show,
    a centred "waiting" message is returned. Otherwise a Panel titled with the
    test name wraps a row layout: a "running" column and a "metrics" column of
    ratio 2, each made visible only when its sub-panel produced a renderable.
    """
    running = self.tasks_panel.get_renderable()
    metrics = self.metrics_panel.get_renderable()
    if running is None and metrics is None:
        waiting = Text(f"Waiting for test to start... ({self.test_name})")
        return Align(waiting, align="center")

    layout = Layout()
    layout.split_row(
        Layout(name="running", visible=False),
        Layout(name="metrics", ratio=2, visible=False),
    )
    # Reveal each slot only if it has content (metrics first, as before).
    for slot, renderable in (("metrics", metrics), ("running", running)):
        if renderable is None:
            continue
        layout[slot].visible = True
        layout[slot].update(renderable)
    return Panel(layout, title=self.test_name)
def test_align_width():
    """Text wrapped in a 30-wide centred Align inside a 40-column console."""
    # NOTE(review): runs of spaces in the expected string look collapsed;
    # confirm against actual rich output.
    buffer = io.StringIO()
    console = Console(file=buffer, width=40)
    words = "Deep in the human unconscious is a pervasive need for a logical universe that makes sense. But the real universe is always one step beyond logic"
    expected = " Deep in the human unconscious \n is a pervasive need for a \n logical universe that makes \n sense. But the real universe \n is always one step beyond \n logic \n"
    console.print(Align(words, "center", width=30))
    assert buffer.getvalue() == expected
def test_align_center_middle():
    """Two lines centred horizontally and vertically in a 10x5 area."""
    # NOTE(review): runs of spaces in the expected string look collapsed;
    # confirm against actual rich output.
    console = Console(file=io.StringIO(), width=10)
    renderable = Align("foo\nbar", "center", vertical="middle")
    console.print(renderable, height=5)
    expected = " \n foo \n bar \n \n \n"
    output = console.file.getvalue()
    print(repr(output))
    assert output == expected
def test_align_bottom():
    """A single line pushed to the bottom of a 5-line-high area."""
    # NOTE(review): runs of spaces in the expected string look collapsed;
    # confirm against actual rich output.
    console = Console(file=io.StringIO(), width=10)
    renderable = Align("foo", vertical="bottom")
    console.print(renderable, height=5)
    expected = " \n \n \n \nfoo \n"
    output = console.file.getvalue()
    print(repr(output))
    assert output == expected
def test_align_right_style():
    """Right alignment with a background style emits styled padding then styled text."""
    # NOTE(review): the styled padding segment looks whitespace-collapsed;
    # confirm against actual rich output.
    console = Console(
        file=io.StringIO(),
        width=10,
        color_system="truecolor",
        force_terminal=True,
    )
    console.print(Align("foo", "right", style="on blue"))
    output = console.file.getvalue()
    assert output == "\x1b[44m \x1b[0m\x1b[44mfoo\x1b[0m\n"
def make_detail_table(overpasses, tz, twelvehour):
    """Make a detailed data table to print to console.

    One row per overpass: the AOS date (``%m/%d/%y`` in ``tz``), then time,
    elevation and direction for each of the start (``aos``), maximum (``tca``)
    and end (``los``) points, then the pass type. Rows whose type value equals
    ``PassType.visible``'s value are styled green.

    Args:
        overpasses: Iterable of overpass objects; reads ``aos``, ``tca``,
            ``los`` (each with ``dt``, ``elevation``, ``direction``) and ``type``.
        tz: tzinfo the point times are converted to.
        twelvehour: If true, show 12-hour times with am/pm.

    Returns:
        The populated rich ``Table``.
    """
    table = Table()
    table.add_column(Align("Date", "center"), justify="right")
    for x in ("Start", "Max", "End"):
        table.add_column(Align(f"{x}\nTime", "center"), justify="right")
        table.add_column(Align(f"{x}\nEl", "center"), justify="right", width=4)
        table.add_column(Align(f"{x}\nAz", "center"), justify="right", width=4)
    table.add_column(Align("Type", "center"), justify="center")

    def point_string(point):
        """Return the [time, elevation, direction] cells for one pass point."""
        time = point.dt.astimezone(tz)
        if twelvehour:
            time_str = time.strftime("%I:%M:%S").lstrip("0") + " " + time.strftime("%p").lower()
        else:
            time_str = time.strftime("%H:%M:%S")
        return [
            time_str,
            "{:>2}\u00B0".format(int(point.elevation)),
            "{:3}".format(point.direction),
        ]

    for overpass in overpasses:
        row = [overpass.aos.dt.astimezone(tz).strftime("%m/%d/%y")]
        # TODO: a brightness column was sketched here (overpass.brightness) but never enabled.
        row += point_string(overpass.aos)
        row += point_string(overpass.tca)
        row += point_string(overpass.los)
        if overpass.type:
            row.append(overpass.type.value)
            # Compare underlying values: comparing a value to the enum member
            # itself only matches if PassType is a str-mixin enum.
            fg = "green" if overpass.type.value == PassType.visible.value else None
        else:
            fg = None
        table.add_row(*row, style=fg)
    return table
def main(filename: str) -> None:
    """Print new setup.cfg contents and a reformatted setup() call for a project.

    Args:
        filename: Path of the project directory; ``setup.py`` is read from it.

    Raises:
        RuntimeError: If ``Analyzer`` found no setup function in ``setup.py``.
    """
    setup_py = Path(filename) / "setup.py"
    # Parse from bytes so ast honours a PEP 263 coding cookie (default UTF-8)
    # instead of decoding with the platform's locale encoding as open() does.
    tree = ast.parse(setup_py.read_bytes(), filename=str(setup_py))
    analyzer = Analyzer()
    analyzer.visit(tree)
    if analyzer.setup_function is None:
        raise RuntimeError("Invalid, no setup function found")

    print(Panel(Align("New [bold]setup.cfg", align="center")))
    print()
    print(analyzer.metadata)
    print()
    print(analyzer.options)
    print()
    print(Panel(Align("New setup() in [bold]setup.py", "center")))
    print()
    new_setup_py = black.format_str(ast.unparse(analyzer.setup_function), mode=black.Mode())
    print(Syntax(new_setup_py, "python", theme="ansi_light"))
def final_check(pgscatalog_df, plink_variants_df):
    """Show the keys the two files will be matched on and confirm before continuing.

    Up to 10 keys from ``PGSCATALOG_KEY_COLUMN`` that also occur in the plink
    ``PLINK_KEY_COLUMN`` are rendered into both "top" panes of the module-level
    ``layout``. If ``CONFIG`` is set (auto mode) the match is accepted
    automatically; otherwise the user is asked via ``Confirm.ask``. On
    acceptance both full tables are rendered; on refusal troubleshooting
    information is printed and the process exits with status 50.

    Args:
        pgscatalog_df: DataFrame containing ``PGSCATALOG_KEY_COLUMN``.
        plink_variants_df: DataFrame containing ``PLINK_KEY_COLUMN``.
    """
    import sys

    # Keys present in both files; show a small sample for review. Both panes
    # display this same overlapping sample.
    overlaps = pgscatalog_df[PGSCATALOG_KEY_COLUMN].isin(plink_variants_df[PLINK_KEY_COLUMN])
    sample = pgscatalog_df.loc[overlaps, [PGSCATALOG_KEY_COLUMN]].head(10)
    layout["top"]["leftfile"].update(
        Align(
            render_file_table(sample, title="PRS WM file key"),
            align="center",
        ))
    layout["top"]["rightfile"].update(
        Align(
            render_file_table(sample, title="plink data key"),
            align="center",
        ))
    console.print(layout)
    printout("-----------------\nPlease review the keys the data will be matched on")

    if CONFIG is not None:
        is_to_run = True  # auto mode
    else:
        # manual mode; this is intentionally not done using "ask()" due to a different logic
        is_to_run = Confirm.ask("Do you confirm you want to match the files on these keys?")

    if not is_to_run:
        printout(
            "You have decided to halt the execution due to the mismatch between the IDs in two files."
        )
        printout(
            "Please, review this troubleshooting guide, if you need an inspiration about how fixing this:"
        )
        print_error_files()
        # sys.exit rather than the site-injected exit(), which is not
        # guaranteed to exist (e.g. under ``python -S`` or in frozen apps).
        sys.exit(50)

    layout["top"]["leftfile"].update(render_file_table(pgscatalog_df, title="PRS WM file"))
    layout["top"]["rightfile"].update(render_file_table(plink_variants_df, title="plink data"))
    console.print(layout)
    printout("-----------------")
def generate_display(display: str) -> Align:
    """Generate a centred, titled panel to display using the rich library.

    Args:
        display (str): The visualized values to display.

    Returns:
        rich.align.Align: A horizontally centred ``Align`` wrapping a
        non-expanding ``Panel`` titled "Devilizer". (The previous annotation
        claimed a ``Panel`` was returned, which was never the case.)
    """
    panel = Panel(display, expand=False, title="[bold blue]Devilizer[/bold blue]")
    return Align(panel, align="center")
def test_bad_align_legal():
    """Align accepts only lowercase left/center/right and rejects anything else."""
    for legal in ("left", "center", "right"):
        Align("foo", legal)
    for illegal in (None, "middle", "", "LEFT"):
        with pytest.raises(ValueError):
            Align("foo", illegal)
    # An unknown vertical alignment is rejected as well.
    with pytest.raises(ValueError):
        Align("foo", vertical="somewhere")
def make_header_columns():
    """Yield the header renderables for the training progress table.

    In order: an epoch ("#") header plus an "LR" header when ``learning_rate``
    is not None; a training-loss "Total" header (grouped with one header per
    training-loss name when ``multi_target``); the same for validation losses
    when ``total_loss_validation`` is not None; and for each metric either its
    sub-metric names (dict values) or an empty cell.

    Reads ``learning_rate``, ``col_width``, ``multi_target``,
    ``losses_training``, ``total_loss_validation``, ``losses_validation`` and
    ``metrics`` from the enclosing scope.
    """

    def loss_header(justify, colour, losses):
        # "Total" header, followed by one header per loss name when multi-target.
        total = Align(
            Text("Total", justify=justify, style=f"bold {colour}"),
            width=col_width,
            align="center",
        )
        if not multi_target:
            return total
        cells = [total]
        cells.extend(
            Align(Text(n, justify=justify, style=colour), width=col_width)
            for n in losses.keys()
        )
        return Columns(cells, align="center", width=col_width)

    epoch_cells = [Text("#", justify="right", style="bold")]
    if learning_rate is not None:
        epoch_cells.append(Text("LR", justify="right"))
    yield Columns(epoch_cells, align="center", width=6)

    yield loss_header("right", "red", losses_training)

    if total_loss_validation is not None:
        yield loss_header("center", "blue", losses_validation)

    if metrics is None:
        return
    for _name, values in metrics.items():
        if not isinstance(values, dict):
            yield Align(Text(""), width=col_width)
            continue
        sub_headers = [
            Align(Text(n, justify="center", style="purple"), width=col_width)
            for n in values.keys()
        ]
        yield Columns(sub_headers, align="center", width=col_width)
def dry_run_banner() -> Panel:
    """Build the yellow banner shown when --dry-run is passed.

    The panel is at most 75 columns wide (narrower on small terminals) and is
    framed by an all-blank box, so only the background colour outlines it.
    """
    width = min(get_terminal_size().columns, 75)
    # Plain implicit concatenation: the previous literal contained stray "\ "
    # escapes, which are invalid (SyntaxWarning on 3.12+) and rendered a
    # literal backslash in the banner.
    message = (
        "You are in [bold blue]dry-run mode[/]. Issues and cards will not be created. "
        "Creation of objects from --create-missing will still occur "
        "[dim](e.g. missing labels will be created if you answer 'yes')[/]"
    )
    return Panel(
        Align(message, "center", width=width),
        title="--dry-run was passed",
        width=width,
        style="black on yellow",
        # rich's Box needs 8 lines of exactly 4 characters; one-character
        # lines make Box() raise ValueError while unpacking.
        box=Box("    \n" * 8),
    )
def run_ensemble_strategy(
    df, unique_trade_date, rebalance_window, validation_window
) -> None:
    """Ensemble Strategy that combines PPO, A2C and DDPG.

    For every rebalance window: pick a turbulence threshold from historical
    turbulence, train A2C, PPO and DDPG on data up to the window, validate
    each on the following validation window, and trade the next rebalance
    window with the model that had the best validation Sharpe ratio. The
    final state of each trading window seeds the next one.

    Args:
        df: DataFrame with at least ``datadate`` and ``turbulence`` columns.
        unique_trade_date: Sequence of trade dates, indexed by window position.
        rebalance_window: Number of trade dates per trading window.
        validation_window: Number of trade dates per validation window.
    """
    # for ensemble model, it's necessary to feed the last state
    # of the previous model to the current model as the initial state
    last_state_ensemble = []

    ppo_sharpe_list = []
    ddpg_sharpe_list = []
    a2c_sharpe_list = []

    model_used = []

    # based on the analysis of the in-sample data
    insample_turbulence = df[
        (df.datadate < config.VALIDATION_START_DATE - 1)
        & (df.datadate >= config.TRAINING_START_DATE)
    ]
    insample_turbulence = insample_turbulence.drop_duplicates(subset=["datadate"])
    insample_turbulence_threshold = np.quantile(
        insample_turbulence.turbulence.values, 0.90
    )

    start = time.time()
    for i in range(
        rebalance_window + validation_window, len(unique_trade_date), rebalance_window
    ):
        window_start = i - rebalance_window - validation_window
        # initial state is empty only for the very first window
        initial = window_start == 0

        # Tuning turbulence index based on historical data
        # Turbulence lookback window is one quarter
        end_date_index = df.index[
            df["datadate"] == unique_trade_date[window_start]
        ].to_list()[-1]
        # "* 30" presumably accounts for 30 rows (tickers) per date — TODO confirm.
        # Clamp at 0: a negative iloc start would wrap around to the end of df.
        start_date_index = max(end_date_index - validation_window * 30 + 1, 0)

        historical_turbulence = df.iloc[start_date_index : (end_date_index + 1), :]
        historical_turbulence = historical_turbulence.drop_duplicates(
            subset=["datadate"]
        )
        historical_turbulence_mean = np.mean(historical_turbulence.turbulence.values)

        if historical_turbulence_mean > insample_turbulence_threshold:
            # Volatile market: the historical mean exceeds the 90% quantile of
            # in-sample turbulence, so cap turbulence at that quantile.
            turbulence_threshold = insample_turbulence_threshold
        else:
            # Calm market: relax the threshold to the in-sample maximum.
            turbulence_threshold = np.quantile(insample_turbulence.turbulence.values, 1)

        style = "[bold #31DDCF]"
        rprint(
            Align(
                f"{style}Turbulence Threshold:[/] {str(turbulence_threshold)}", "center"
            )
        )

        ############## Environment Setup starts ##############
        ## training env
        train = data_split(
            df,
            start=config.TRAINING_START_DATE,
            end=unique_trade_date[window_start],
        )
        env_train = DummyVecEnv([lambda: StockEnvTrain(train)])

        ## validation env
        validation = data_split(
            df,
            start=unique_trade_date[window_start],
            end=unique_trade_date[i - rebalance_window],
        )
        env_val = DummyVecEnv(
            [
                lambda: StockEnvValidation(
                    validation, turbulence_threshold=turbulence_threshold, iteration=i
                )
            ]
        )
        obs_val = env_val.reset()
        ############## Environment Setup ends ##############

        ############## Training and Validation starts ##############
        # Title uses the configured training start (it was hard-coded as
        # "20090000", which could disagree with the data actually used).
        table = Table(
            title=f"Training from {config.TRAINING_START_DATE} to {unique_trade_date[window_start]}",
            expand=True,
        )
        table.add_column("Mode Name", justify="center")
        table.add_column("Sharpe Ratio")
        table.add_column("Training Time")

        with Live(table, auto_refresh=False) as live:
            model_a2c, a2c_training_time = train_A2C(
                env_train, model_name="A2C_30k_dow_{}".format(i), timesteps=30000
            )
            DRL_validation(
                model=model_a2c,
                test_data=validation,
                test_env=env_val,
                test_obs=obs_val,
            )
            sharpe_a2c = get_validation_sharpe(i)
            table.add_row("A2C", str(sharpe_a2c), f"{a2c_training_time} minutes")
            live.update(table, refresh=True)

            model_ppo, ppo_training_time = train_PPO(
                env_train, model_name="PPO_100k_dow_{}".format(i), timesteps=100000
            )
            DRL_validation(
                model=model_ppo,
                test_data=validation,
                test_env=env_val,
                test_obs=obs_val,
            )
            sharpe_ppo = get_validation_sharpe(i)
            table.add_row("PPO", str(sharpe_ppo), f"{ppo_training_time} minutes")
            live.update(table, refresh=True)

            model_ddpg, ddpg_training_time = train_DDPG(
                env_train, model_name="DDPG_10k_dow_{}".format(i), timesteps=10000
            )
            DRL_validation(
                model=model_ddpg,
                test_data=validation,
                test_env=env_val,
                test_obs=obs_val,
            )
            sharpe_ddpg = get_validation_sharpe(i)
            table.add_row("DDPG", str(sharpe_ddpg), f"{ddpg_training_time} minutes")
            live.update(table, refresh=True)

        ppo_sharpe_list.append(sharpe_ppo)
        a2c_sharpe_list.append(sharpe_a2c)
        ddpg_sharpe_list.append(sharpe_ddpg)

        # Model Selection based on sharpe ratio (ties favour PPO, then DDPG)
        if sharpe_ppo >= sharpe_a2c and sharpe_ppo >= sharpe_ddpg:
            model_ensemble = model_ppo
            model_used.append("PPO")
        elif sharpe_a2c > sharpe_ppo and sharpe_a2c > sharpe_ddpg:
            model_ensemble = model_a2c
            model_used.append("A2C")
        else:
            model_ensemble = model_ddpg
            model_used.append("DDPG")
        ############## Training and Validation ends ##############

        ############## Trading starts ##############
        last_state_ensemble = DRL_prediction(
            df=df,
            model=model_ensemble,
            name="ensemble",
            last_state=last_state_ensemble,
            iter_num=i,
            unique_trade_date=unique_trade_date,
            rebalance_window=rebalance_window,
            turbulence_threshold=turbulence_threshold,
            initial=initial,
        )
        print("\n\n")
        ############## Trading ends ##############

    end = time.time()
    print("Ensemble Strategy took: ", (end - start) / 60, " minutes")
def test_align_fit():
    """Text exactly as wide as the console is printed unpadded."""
    buffer = io.StringIO()
    Console(file=buffer, width=10).print(Align("foobarbaze", "center"))
    assert buffer.getvalue() == "foobarbaze\n"
def test_render():
    """Left alignment renders the text at the start of the line."""
    buffer = io.StringIO()
    Console(file=buffer, width=10).print(Align("foo", "left"))
    assert buffer.getvalue() == "foo\n"
def test_align_right():
    """Right alignment pushes the text to the end of the line."""
    # NOTE(review): the leading padding in the expected string looks
    # whitespace-collapsed; confirm against actual rich output.
    buffer = io.StringIO()
    Console(file=buffer, width=10).print(Align("foo", "right"))
    assert buffer.getvalue() == " foo\n"
def test_repr():
    """repr() works for every horizontal alignment."""
    for how in ("left", "center", "right"):
        repr(Align("foo", how))
def make_columns():
    """Yield the value renderables for one row of the training progress table.

    In order: the epoch (plus the learning rate when it is not None); the
    training-loss total (grouped with each per-loss value when
    ``multi_target``); the validation losses likewise when
    ``total_loss_validation`` is not None; and for each metric either its
    sub-metric values (dict values) or the single value.

    Reads ``epoch``, ``learning_rate``, ``col_width``, ``total_loss_training``,
    ``multi_target``, ``losses_training``, ``total_loss_validation``,
    ``losses_validation`` and ``metrics`` from the enclosing scope.
    """
    epoch_cells = [Align(Text(f"{epoch:3}", justify="right", style="bold"), width=4)]
    # learning_rate may be None; formatting None with ":1.4f" raises
    # TypeError, so only add the LR cell when a rate is known.
    if learning_rate is not None:
        epoch_cells.append(Align(Text(f"{learning_rate:1.4f}", justify="right"), width=6))
    yield Columns(epoch_cells, align="center", width=6)

    # Training losses
    text = Align(
        Text(f"{total_loss_training:3.3f}", style="bold red"),
        align="center",
        width=col_width,
    )
    if multi_target:
        columns = [text] + [
            Align(
                Text(f"{l:3.3f}", justify="right", style="red"),
                width=col_width,
                align="right",
            )
            for _, l in losses_training.items()
        ]
        yield Columns(columns, align="center", width=col_width)
    else:
        yield text

    # Validation losses
    if total_loss_validation is not None:
        text = Align(
            Text(
                f"{total_loss_validation:.3f}",
                justify="center",
                style="bold blue",
            ),
            width=col_width,
            align="center",
        )
        if multi_target:
            columns = [text] + [
                Align(
                    Text(f"{l:.3f}", justify="center", style="blue"),
                    width=col_width,
                    align="center",
                )
                for _, l in losses_validation.items()
            ]
            yield Columns(columns, align="center", width=col_width)
        else:
            yield text

    # Metrics
    if metrics is not None:
        for name, values in metrics.items():
            if isinstance(values, dict):
                columns = [
                    Align(
                        Text(f"{v:.3f}", justify="center", style="purple"),
                        width=col_width,
                    )
                    for _, v in values.items()
                ]
                yield Columns(columns, align="center", width=col_width)
            else:
                yield Align(Text(f"{values:.3f}"), width=col_width, style="purple")
def test_align_no_pad():
    """With pad=False no trailing padding is emitted for centre or left alignment."""
    # NOTE(review): the leading padding in the expected string looks
    # whitespace-collapsed; confirm against actual rich output.
    buffer = io.StringIO()
    console = Console(file=buffer, width=10)
    for how in ("center", "left"):
        console.print(Align("foo", how, pad=False))
    assert buffer.getvalue() == " foo\nfoo\n"
def test_measure():
    """An Align's measurement matches its text: min is the longest word, max the whole line."""
    console = Console(file=io.StringIO(), width=20)
    minimum, maximum = Measurement.get(console, Align("foo bar", "left"), 20)
    assert minimum == 3
    assert maximum == 7
def check(
    c,
    lint_=True,
    fixmes_=False,
    test_=True,
    coverage_=True,
    mypy_=True,
    black_=True,
    isort_=True,
    docs_=True,
    clean_=True,
):
    """Runs all checkers on the code.

    Each enabled checker is announced with a banner, run, and its ``exited``
    code recorded. A summary table and a verdict message are then printed,
    ``clean`` optionally runs, and ``Exit`` is raised with code 1 if any
    checker returned a non-zero code, otherwise 0.
    """
    separator = "-" * 20
    # (enabled?, banner, result key, runner) in execution order.
    steps = (
        (lint_, "Running pylint...", "lint", lambda: lint(c)),
        (fixmes_, "Running pylint (fixmes)...", "FIXME's", lambda: fixmes(c)),
        (test_, "Running tests...", "test", lambda: test(c, verbose=False)),
        (coverage_, "Reporting test coverage...", "coverage", lambda: coverage(c)),
        (mypy_, "Running mypy...", "mypy", lambda: mypy(c)),
        (black_, "Running black (formatting, just checking)...", "black",
         lambda: black(c, check=True)),
        (isort_, "Running isort (formatting, just checking)...", "isort",
         lambda: isort(c, check=True)),
        (docs_, "Running mkdocs...", "docs",
         lambda: docs(c, build=True, verbose=False)),
    )

    results = {}
    for enabled, banner, key, run in steps:
        if not enabled:
            continue
        print(separator)
        print(banner)
        print(separator)
        results[key] = run().exited

    result = int(any(results.values()))

    summary = Table(
        title="Report",
        title_style="bold white",
        show_header=True,
        header_style="bold white",
        show_footer=True,
        footer_style="bold white",
        show_lines=True,
        box=box.ROUNDED,
    )
    summary.add_column("Task", "Summary")
    summary.add_column("Result", f"[bold]{_code_to_stat(result, underline=True)}[/bold]")
    for task_name, code in results.items():
        summary.add_row(task_name, _code_to_stat(code))
    print("\n")
    con.print(Align(summary, "center"))

    if result:
        exit_msg = (
            "Great code dude :+1:, but it could use some final touches. "
            "Don't commit just yet! :x:"
        )
    else:
        exit_msg = (
            "Congratulations :sparkles::fireworks::sparkles: "
            "You may commit! :heavy_check_mark:"
        )
    print(Align(f"[underline bold]{exit_msg}[/underline bold]", "center"))
    print("\n")

    if clean_:
        clean(c, silent=True)
    raise Exit(code=result)
def rich_format_help(
    *,
    obj: Union[click.Command, click.Group],
    ctx: click.Context,
    markup_mode: MarkupMode,
) -> None:
    """Print nicely formatted help text using rich.

    Based on original code from rich-cli, by @willmcgugan.
    https://github.com/Textualize/rich-cli/blob/8a2767c7a340715fc6fbf4930ace717b9b2fc5e5/src/rich_cli/__main__.py#L162-L236

    Replacement for the click function format_help().
    Takes a command or group and builds the help text output.

    Output order: usage, help text, argument panels, option panels, command
    panels (groups only), epilogue. Within each kind the default-titled panel
    is printed first, then the custom panels in first-seen order.
    """
    console = _get_rich_console()

    def in_display_order(by_panel, default_title):
        # Default panel first (even if empty), then every other panel.
        yield default_title, by_panel.get(default_title, [])
        for panel_name, items in by_panel.items():
            if panel_name != default_title:
                yield panel_name, items

    # Usage line
    console.print(Padding(highlighter(obj.get_usage(ctx)), 1), style=STYLE_USAGE_COMMAND)

    # Command / group help, padded, if there is any
    if obj.help:
        help_text = _get_help_text(obj=obj, markup_mode=markup_mode)
        console.print(Padding(Align(help_text, pad=False), (0, 1, 1, 1)))

    # Group visible parameters by their rich help panel
    arguments_by_panel: DefaultDict[str, List[click.Argument]] = defaultdict(list)
    options_by_panel: DefaultDict[str, List[click.Option]] = defaultdict(list)
    for param in obj.get_params(ctx):
        if getattr(param, "hidden", False):
            continue
        if isinstance(param, click.Argument):
            target = getattr(param, _RICH_HELP_PANEL_NAME, None) or ARGUMENTS_PANEL_TITLE
            arguments_by_panel[target].append(param)
        elif isinstance(param, click.Option):
            target = getattr(param, _RICH_HELP_PANEL_NAME, None) or OPTIONS_PANEL_TITLE
            options_by_panel[target].append(param)

    for by_panel, default_title in (
        (arguments_by_panel, ARGUMENTS_PANEL_TITLE),
        (options_by_panel, OPTIONS_PANEL_TITLE),
    ):
        for panel_name, params in in_display_order(by_panel, default_title):
            _print_options_panel(
                name=panel_name,
                params=params,
                ctx=ctx,
                markup_mode=markup_mode,
                console=console,
            )

    if isinstance(obj, click.MultiCommand):
        commands_by_panel: DefaultDict[str, List[click.Command]] = defaultdict(list)
        for command_name in obj.list_commands(ctx):
            command = obj.get_command(ctx, command_name)
            if command and not command.hidden:
                target = getattr(command, _RICH_HELP_PANEL_NAME, None) or COMMANDS_PANEL_TITLE
                commands_by_panel[target].append(command)

        for panel_name, commands in in_display_order(commands_by_panel, COMMANDS_PANEL_TITLE):
            _print_commands_panel(
                name=panel_name,
                commands=commands,
                markup_mode=markup_mode,
                console=console,
            )

    # Epilogue: single line breaks become spaces, blank-line paragraph breaks become single newlines
    if obj.epilog:
        paragraphs = obj.epilog.split("\n\n")
        epilogue = "\n".join(p.replace("\n", " ").strip() for p in paragraphs)
        epilogue_text = _make_rich_rext(text=epilogue, markup_mode=markup_mode)
        console.print(Padding(Align(epilogue_text, pad=False), 1))
def step(self, actions):
    """Advance the trading environment by one day.

    If the current day is the last one, the asset history is plotted and
    written to ``results/``, a summary panel (assets, reward, cost, trades,
    Sharpe ratio) is printed, and the state is returned unchanged. Otherwise
    the scaled actions are executed (everything is sold when turbulence is at
    or above the threshold), the environment moves to the next day, and the
    reward is the change in total asset value times ``REWARD_SCALING``.

    Returns:
        ``(state, reward, terminal, {})``.
    """

    def total_asset():
        # state[0] plus the element-wise product of the two STOCK_DIM slices
        # (presumably cash + adjusted close prices * share holdings).
        prices = np.array(self.state[1:(STOCK_DIM + 1)])
        holdings = np.array(self.state[(STOCK_DIM + 1):(STOCK_DIM * 2 + 1)])
        return self.state[0] + sum(prices * holdings)

    self.terminal = self.day >= len(self.df.index.unique()) - 1

    if not self.terminal:
        actions = actions * HMAX_NORMALIZE
        if self.turbulence >= self.turbulence_threshold:
            actions = np.array([-HMAX_NORMALIZE] * STOCK_DIM)
        begin_total_asset = total_asset()

        # Most negative actions are sold first, most positive bought first.
        ranked = np.argsort(actions)
        n_sell = np.count_nonzero(actions < 0)
        n_buy = np.count_nonzero(actions > 0)
        to_sell = ranked[:n_sell]
        to_buy = ranked[::-1][:n_buy]
        for idx in to_sell:
            self._sell_stock(idx, actions[idx])
        for idx in to_buy:
            self._buy_stock(idx, actions[idx])

        # Load the next day's data and rebuild the state vector.
        self.day += 1
        self.data = self.df.loc[self.day, :]
        self.turbulence = self.data["turbulence"].values[0]
        carried_holdings = list(self.state[(STOCK_DIM + 1):(STOCK_DIM * 2 + 1)])
        self.state = (
            [self.state[0]]
            + self.data.adjcp.values.tolist()
            + carried_holdings
            + self.data.macd.values.tolist()
            + self.data.rsi.values.tolist()
            + self.data.cci.values.tolist()
            + self.data.adx.values.tolist()
        )

        end_total_asset = total_asset()
        self.asset_memory.append(end_total_asset)
        self.reward = end_total_asset - begin_total_asset
        self.rewards_memory.append(self.reward)
        self.reward = self.reward * REWARD_SCALING
        return self.state, self.reward, self.terminal, {}

    # Terminal day: persist history and print a summary.
    plt.plot(self.asset_memory, "r")
    plt.savefig(f"results/account_value_trade_{self.model_name}_{self.iteration}.png")
    plt.close()
    df_total_value = pd.DataFrame(self.asset_memory)
    df_total_value.to_csv(f"results/account_value_trade_{self.model_name}_{self.iteration}.csv")
    end_total_asset = total_asset()
    df_total_value.columns = ["account_value"]
    df_total_value["daily_return"] = df_total_value.pct_change(1)
    daily_return = df_total_value["daily_return"]
    sharpe = (4**0.5) * daily_return.mean() / daily_return.std()

    style = "[bold #31DDCF]"
    summary = RenderGroup(
        f"{style}Previous Total Asset:[/] ${str(self.asset_memory[0])}",
        f"{style}End Total Asset:[/] ${str(end_total_asset)}",
        f"{style}Total Reward:[/] ${str(end_total_asset - self.asset_memory[0])}",
        f"{style}Total Cost:[/] ${str(self.cost)}",
        f"{style}Total Trades:[/] {str(self.trades)}",
        f"{style}Sharpe Ratio:[/] {str(sharpe)}",
    )
    rprint(Align(Panel(summary, title=self.title, expand=True), "center"))

    df_rewards = pd.DataFrame(self.rewards_memory)
    df_rewards.to_csv(f"results/account_rewards_trade_{self.model_name}_{self.iteration}.csv")
    return self.state, self.reward, self.terminal, {}