コード例 #1
0
def main(
        env_name, n_epoch, learning_rate, gamma, n_hidden,
        seed_val=0, max_steps=1000
):
    '''Train an A2C agent on a gym environment and save its results.

    Each epoch collects one rollout with ``run``, computes returns with
    ``compute_returns(..., normalize=True)``, and takes a single Adam step on
    the sum of the policy and value losses from ``compute_a2c_loss``.

    Args:
        env_name: environment id passed to ``gym.make``.
        n_epoch: number of training epochs (one rollout + one optimizer step each).
        learning_rate: Adam learning rate.
        gamma: discount factor forwarded to ``run`` and ``compute_returns``.
        n_hidden: hidden size forwarded to ``Agent``.
        seed_val: seed for the env, NumPy and PyTorch RNGs.
        max_steps: step limit forwarded to ``run``.

    Side effects:
        Prints progress every 10 epochs. Writes the weights to
        ``../log/agent-{env_name}.pth`` and the learning curve to
        ``../figs/lc-{env_name}.png``, creating those directories if needed.
        The env is closed once training ends, even if training raises.
    '''
    import os

    # define env; seed every RNG source for reproducibility
    env = gym.make(env_name)
    env.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    # define agent (the action-space type reported by get_env_info is unused here)
    state_dim, n_actions, _ = get_env_info(env)
    agent = Agent(state_dim, n_hidden, n_actions)
    optimizer = torch.optim.Adam(agent.parameters(), lr=learning_rate)
    # per-epoch logs
    log_step = np.zeros((n_epoch,))
    log_return = np.zeros((n_epoch,))
    log_loss_v = np.zeros((n_epoch,))
    log_loss_p = np.zeros((n_epoch,))
    try:
        for i in range(n_epoch):
            cumulative_reward, step, probs, rewards, values = run(
                agent, env, gamma=gamma, max_steps=max_steps
            )
            # update weights on the combined policy + value loss
            returns = compute_returns(rewards, gamma=gamma, normalize=True)
            loss_policy, loss_value = compute_a2c_loss(probs, values, returns)
            loss = loss_policy + loss_value
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # log message
            log_step[i] = step
            log_return[i] = cumulative_reward
            log_loss_v[i] = loss_value.item()
            log_loss_p[i] = loss_policy.item()
            if i % 10 == 0:
                print(
                    'Epoch : %.3d | R: %.2f, steps: %4d | L: pi: %.2f, V: %.2f' %
                    (i, log_return[i], log_step[i], log_loss_p[i], log_loss_v[i])
                )
    finally:
        # release env resources (e.g. render windows) even if training raises
        env.close()

    # save weights; make sure the target directory exists first
    ckpt_fname = f'../log/agent-{env_name}.pth'
    os.makedirs(os.path.dirname(ckpt_fname), exist_ok=True)
    torch.save(agent.state_dict(), ckpt_fname)

    # show learning curve: return, steps
    fig, axes = plt.subplots(2, 1, figsize=(7, 7), sharex=True)
    axes[0].plot(log_return)
    axes[1].plot(log_step)
    axes[0].set_title(f'Learning curve: {env_name}')
    axes[0].set_ylabel('Return')
    axes[1].set_ylabel('#steps')
    axes[1].set_xlabel('Epoch')
    sns.despine()
    fig.tight_layout()
    fig_fname = f'../figs/lc-{env_name}.png'
    os.makedirs(os.path.dirname(fig_fname), exist_ok=True)
    fig.savefig(fig_fname, dpi=120)
コード例 #2
0
def main(env_name, n_hidden):
    '''Render the behavior of a saved checkpoint.

    Builds an ``Agent`` with the given hidden size, loads
    ``../log/agent-{env_name}.pth`` into it, and runs it once with
    ``render=True``. The env is closed afterwards, even if the run raises.

    Args:
        env_name: environment id passed to ``gym.make``.
        n_hidden: hidden size; must match the one used when the checkpoint
            was saved.
    '''
    # `.env` takes the inner env out of gym.make's wrapper; presumably this
    # drops the episode time limit so `run` controls episode length — confirm
    env = gym.make(env_name).env
    state_dim, n_actions, _ = get_env_info(env)
    agent = Agent(state_dim, n_hidden, n_actions)
    # the agent is built on CPU here; map_location lets a checkpoint saved
    # from GPU tensors load on a CPU-only machine
    agent.load_state_dict(
        torch.load(f'../log/agent-{env_name}.pth', map_location='cpu')
    )
    agent.eval()
    try:
        # run's outputs are not needed for a visual check
        _cumulative_reward, _step, _probs, _rewards, _values = run(
            agent, env, render=True
        )
    finally:
        # close the render window / env resources
        env.close()
コード例 #3
0
def spot_get_market_summary(*, timeout=10):
    '''GET ``/api/<spot api version>/market_summary`` (sent without auth headers).

    Args:
        timeout: seconds passed to ``requests.get``. Without a timeout, a
            stalled server would block this call indefinitely.

    Returns:
        The decoded JSON body on success. Returns ``{}`` when the request
        fails, the server answers with an HTTP error status, or the body is
        not valid JSON; in each of those cases the error is printed.
    '''
    url = "/api/{0}/market_summary".format(get_spot_api_version())
    env = get_env_info()
    ret = {}
    try:
        resp = requests.get(get_spot_full_url(env["API_HOST"], url),
                            timeout=timeout)
        resp.raise_for_status()
        # decode inside the try so a non-JSON body also yields {}
        ret = resp.json()
    except HTTPError as http_err:
        print("HTTP error occurred: {0}".format(http_err))
    except Exception as err:
        print("Other error occurred: {0}".format(err))
    return ret
コード例 #4
0
def futures_get_trades(params, *, timeout=10):
    '''GET ``/api/<futures api version>/trades`` with ``params`` as the query string.

    Args:
        params: query parameters forwarded to ``requests.get``.
        timeout: seconds passed to ``requests.get``. Without a timeout, a
            stalled server would block this call indefinitely.

    Returns:
        The decoded JSON body on success. Returns ``{}`` when the request
        fails, the server answers with an HTTP error status, or the body is
        not valid JSON; in each of those cases the error is printed.
    '''
    url = "/api/{0}/trades".format(get_futures_api_version())
    env = get_env_info()
    ret = {}
    try:
        resp = requests.get(get_futures_full_url(env["API_HOST"], url),
                            params=params,
                            timeout=timeout)
        resp.raise_for_status()
        # decode inside the try so a non-JSON body also yields {}
        ret = resp.json()
    except HTTPError as http_err:
        print("HTTP error occurred: {0}".format(http_err))
    except Exception as err:
        print("Other error occurred: {0}".format(err))
    return ret
コード例 #5
0
def earn_get_orders(*, timeout=10):
    '''GET ``/api/<spot api version>/invest/orders`` with signed auth headers.

    Headers come from ``gen_headers`` using ``API_KEY`` / ``API_SECRET_KEY``
    from ``get_env_info()``.

    Args:
        timeout: seconds passed to ``requests.get``. Without a timeout, a
            stalled server would block this call indefinitely.

    Returns:
        The decoded JSON body on success. Returns ``{}`` when the request
        fails, the server answers with an HTTP error status, or the body is
        not valid JSON; in each of those cases the error is printed.
    '''
    url = "/api/{0}/invest/orders".format(get_spot_api_version())
    env = get_env_info()
    headers = gen_headers(env["API_KEY"], env["API_SECRET_KEY"], url)
    ret = {}
    try:
        resp = requests.get(get_spot_full_url(env["API_HOST"], url),
                            headers=headers,
                            timeout=timeout)
        resp.raise_for_status()
        # decode inside the try so a non-JSON body also yields {}
        ret = resp.json()
    except HTTPError as http_err:
        print("HTTP error occurred: {0}".format(http_err))
    except Exception as err:
        print("Other error occurred: {0}".format(err))
    return ret
コード例 #6
0
def otc_accept_quote(quote_id, *, timeout=10):
    '''POST ``/api/<otc api version>/accept/<quote_id>`` with signed auth headers.

    Args:
        quote_id: identifier interpolated into the URL path.
        timeout: seconds passed to ``requests.post``. Without a timeout, a
            stalled server would block this call indefinitely.
            NOTE(review): after a timeout this POST's server-side outcome is
            unknown; presumably the quote may still have been accepted — confirm.

    Returns:
        The decoded JSON body on success. Returns ``{}`` when the request
        fails, the server answers with an HTTP error status, or the body is
        not valid JSON; in each of those cases the error is printed.
    '''
    url = "/api/{0}/accept/{1}".format(get_otc_api_version(), quote_id)
    env = get_env_info()
    headers = gen_headers(env["API_KEY"], env["API_SECRET_KEY"], url)
    ret = {}
    try:
        resp = requests.post(get_otc_full_url(env["API_HOST"], url),
                             headers=headers,
                             timeout=timeout)
        resp.raise_for_status()
        # decode inside the try so a non-JSON body also yields {}
        ret = resp.json()
    except HTTPError as http_err:
        print("HTTP error occurred: {0}".format(http_err))
    except Exception as err:
        print("Other error occurred: {0}".format(err))
    return ret
コード例 #7
0
def spot_place_limit_order(data, *, timeout=10):
    '''POST ``data`` as JSON to ``/api/<spot api version>/order`` with signed headers.

    The signature from ``gen_headers`` covers ``json.dumps(data)``, while the
    request body is serialized by ``requests`` via ``json=data``.

    Args:
        data: JSON-serializable order payload.
        timeout: seconds passed to ``requests.post``. Without a timeout, a
            stalled server would block this call indefinitely.

    Returns:
        The decoded JSON body on success. Returns ``{}`` when the request
        fails, the server answers with an HTTP error status, or the body is
        not valid JSON; in each of those cases the error is printed.
    '''
    url = "/api/{0}/order".format(get_spot_api_version())
    env = get_env_info()
    headers = gen_headers(env["API_KEY"], env["API_SECRET_KEY"], url,
                          json.dumps(data))
    ret = {}
    try:
        resp = requests.post(get_spot_full_url(env["API_HOST"], url),
                             json=data,
                             headers=headers,
                             timeout=timeout)
        resp.raise_for_status()
        # decode inside the try so a non-JSON body also yields {}
        ret = resp.json()
    except HTTPError as http_err:
        print("HTTP error occurred: {0}".format(http_err))
    except Exception as err:
        print("Other error occurred: {0}".format(err))
    return ret
コード例 #8
0
def futures_change_settlement_currency(data, *, timeout=10):
    '''POST ``data`` as JSON to ``/api/<futures api version>/settle_in`` with signed headers.

    Args:
        data: JSON-serializable payload; its ``json.dumps`` form is signed.
        timeout: seconds passed to ``requests.post``. Without a timeout, a
            stalled server would block this call indefinitely.

    Returns:
        The raw ``requests.Response`` on success, or ``{}`` when the request
        fails or the server answers with an HTTP error status (the error is
        printed).
    '''
    url = "/api/{0}/settle_in".format(get_futures_api_version())
    env = get_env_info()
    headers = gen_headers(env["API_KEY"], env["API_SECRET_KEY"], url,
                          json.dumps(data))
    ret = {}
    try:
        resp = requests.post(
            get_futures_full_url(env["API_HOST"], url),
            json=data,
            headers=headers,
            timeout=timeout,
        )
        resp.raise_for_status()
    except HTTPError as http_err:
        print("HTTP error occurred: {0}".format(http_err))
    except Exception as err:
        print("Other error occurred: {0}".format(err))
    else:
        # NOTE(review): this returns the Response object rather than
        # resp.json(), unlike the other API helpers in this module. It is
        # kept as-is because callers may rely on Response attributes —
        # confirm which type callers expect before changing.
        ret = resp
    return ret
コード例 #9
0
def spot_get_wallet_address(currency, *, timeout=10):
    '''GET ``/api/<spot api version>/user/wallet/address`` for ``currency`` with signed headers.

    Args:
        currency: sent as the ``currency`` query parameter.
        timeout: seconds passed to ``requests.get``. Without a timeout, a
            stalled server would block this call indefinitely.

    Returns:
        The decoded JSON body on success. Returns ``{}`` when the request
        fails, the server answers with an HTTP error status, or the body is
        not valid JSON; in each of those cases the error is printed.
    '''
    url = "/api/{0}/user/wallet/address".format(get_spot_api_version())
    env = get_env_info()
    headers = gen_headers(env["API_KEY"], env["API_SECRET_KEY"], url)
    ret = {}
    params = {"currency": currency}
    try:
        resp = requests.get(
            get_spot_full_url(env["API_HOST"], url),
            params=params,
            headers=headers,
            timeout=timeout,
        )
        resp.raise_for_status()
        # decode inside the try so a non-JSON body also yields {}
        ret = resp.json()
    except HTTPError as http_err:
        print("HTTP error occurred: {0}".format(http_err))
    except Exception as err:
        print("Other error occurred: {0}".format(err))
    return ret
コード例 #10
0
def earn_get_history(page_number=1, page_size=10, *, timeout=10):
    '''GET one page of ``/api/<spot api version>/invest/history`` with signed headers.

    Args:
        page_number: 1-based page index, sent as ``pageNumber``.
        page_size: items per page, sent as ``pageSize``.
        timeout: seconds passed to ``requests.get``. Without a timeout, a
            stalled server would block this call indefinitely.

    Returns:
        The decoded JSON body on success. Returns ``{}`` when the request
        fails, the server answers with an HTTP error status, or the body is
        not valid JSON; in each of those cases the error is printed.
    '''
    url = "/api/{0}/invest/history".format(get_spot_api_version())
    env = get_env_info()
    headers = gen_headers(env["API_KEY"], env["API_SECRET_KEY"], url)
    params = {
        "pageNumber": page_number,
        "pageSize": page_size,
    }
    ret = {}
    try:
        resp = requests.get(
            get_spot_full_url(env["API_HOST"], url),
            params=params,
            headers=headers,
            timeout=timeout,
        )
        resp.raise_for_status()
        # decode inside the try so a non-JSON body also yields {}
        ret = resp.json()
    except HTTPError as http_err:
        print("HTTP error occurred: {0}".format(http_err))
    except Exception as err:
        print("Other error occurred: {0}".format(err))
    return ret
コード例 #11
0
def main():
    '''Start the bot with the token returned by ``utils.get_env_info('DISCORD_TOKEN')``.'''
    discord_token = utils.get_env_info('DISCORD_TOKEN')
    bot.run(discord_token)
コード例 #12
0
    print(message)


def on_error(ws, error):
    '''Websocket error callback: write the error's string form to stdout.'''
    print(str(error))


def on_close(ws, close_status_code, close_msg):
    '''Websocket close callback: print a fixed banner; the status code and message are ignored.'''
    banner = "### socket closed ###"
    print(banner)


def on_open(ws):
    '''Websocket open callback: send a JSON ``subscribe`` request for ``orderBookApi:BTCPFC_0``.'''
    subscription = dict(op="subscribe", args=["orderBookApi:BTCPFC_0"])
    ws.send(json.dumps(subscription))


if __name__ == "__main__":
    # Uncomment to turn on websocket-client's debug trace output:
    # websocket.enableTrace(True)
    env = get_env_info()
    ws_url = get_futures_ws_url(env["WS_HOST"])
    # callback wiring for the connection's lifecycle events
    callbacks = dict(
        on_open=on_open,
        on_message=on_message,
        on_error=on_error,
        on_close=on_close,
    )
    ws = websocket.WebSocketApp(ws_url, **callbacks)
    ws.run_forever()