    # Assumes `import torch.nn as nn` and an `iterable` helper (e.g. np.iterable) in scope.
    def __init__(self, num_input, num_hidden, num_action_space):
        super(Critic, self).__init__()

        self.s0 = None
        if iterable(num_input):
            len_num_input = len(num_input)
            if len_num_input == 2:
                self.s0 = nn.LSTM(num_input[1], num_hidden)
                # The LSTM front-end emits num_hidden features, and s1 below
                # consumes num_input, so rebind num_input to num_hidden here.
                num_input = num_hidden
            elif len_num_input == 1:
                num_input = num_input[0]

        self.s1 = nn.Sequential(
            nn.Linear(num_input + num_action_space, num_hidden),
            nn.LayerNorm(num_hidden), nn.ReLU(inplace=True))

        self.s2 = nn.Sequential(nn.Linear(num_hidden, num_hidden),
                                nn.LayerNorm(num_hidden),
                                nn.ReLU(inplace=True))

        self.value = nn.Linear(num_hidden, 1)
        # Scale down the output layer's initial weights so early value
        # estimates start near zero.
        self.value.weight.data.mul_(0.1)
        self.value.bias.data.mul_(0.1)
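
    # A minimal sketch of a matching forward pass, assuming `import torch` and
    # the s0/s1/s2/value modules defined above; `forward` and its argument
    # names are not from the original source.
    def forward(self, state, action):
        if self.s0 is not None:
            # state: (seq_len, batch, features) for the LSTM front-end;
            # keep only the last time step's output.
            state, _ = self.s0(state)
            state = state[-1]
        x = self.s1(torch.cat([state, action], dim=-1))
        x = self.s2(x)
        return self.value(x)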
Example #2
    # Assumes `import numpy as np` and `import matplotlib.pyplot as plt` in scope.
    def __init__(self,
                 tickers,
                 start_date,
                 num_days_iter,
                 today=None,
                 seed=None,
                 num_action_space=3,
                 render=False,
                 *args,
                 **kwargs):
        if seed is not None:  # `is not None` so that seed=0 still seeds the RNG
            np.random.seed(seed)
        if not iterable(tickers):
            tickers = [tickers]

        self.tickers = self._get_tickers(tickers, start_date, num_days_iter,
                                         today, num_action_space, *args,
                                         **kwargs)
        self.reset_game()

        if render:
            # ax_list comes back grouped in pairs: two axes per ticker row.
            # TODO: consider creating one axis per row and adding the second on demand.
            fig_height = 3 * len(self.tickers)
            self.fig, self.ax_list = plt.subplots(len(self.tickers),
                                                  2,
                                                  figsize=(10, fig_height))
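
# A hypothetical usage sketch for the environment above; the class name
# `Portfolio` and the ticker/date values are assumptions, not taken from the
# original source.
env = Portfolio(tickers=['AAPL', 'MSFT'],
                start_date='2015-01-01',
                num_days_iter=252,
                seed=42,
                render=False)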
Example #3
    def step(self, actions):
        if not iterable(actions): actions = [actions]
        assert len(self.tickers) == len(actions), f'{len(self.tickers)}, {len(actions)}'

        rewards, dones = zip(*(ticker.step(action)
                               for ticker, action in zip(self.tickers, actions)))

        # "score" is the portfolio-level sum of per-ticker rewards, not a
        # per-ticker signal, and the episode ends once any ticker is done.
        score = sum(rewards, 0.0)
        done = any(dones)

        return score, done
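
# A minimal episode-loop sketch built on step(); `env`, `sample_actions`, and
# the reset_game()/step() pairing are assumptions based on the snippets above.
def run_episode(env, sample_actions):
    env.reset_game()
    total_score, done = 0.0, False
    while not done:
        actions = sample_actions()      # one action per ticker
        score, done = env.step(actions)
        total_score += score
    return total_score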
Example #4
    def step(self, action_index):
        assert not iterable(action_index), f'{action_index}'
        assert 0 <= action_index < self.moves_available(), \
            f'action_index: {action_index}, moves_avail: {self.moves_available()}'

        # Select the per-ticker target positions that this joint action encodes.
        positions = self.position_df[action_index]

        # Map each target position back to its index in the discrete action_space.
        actions = [
            np.argwhere(np.isclose(self.action_space, position)).item()
            for position in positions
        ]
        return super(PortfolioDiscrete, self).step(actions)
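
# A standalone illustration of the position-to-index mapping used above; the
# example action_space and positions are made up for demonstration.
import numpy as np

action_space = np.array([-1.0, 0.0, 1.0])   # short / flat / long
positions = [1.0, -1.0, 0.0]                # one target position per ticker
actions = [np.argwhere(np.isclose(action_space, p)).item() for p in positions]
print(actions)  # [2, 0, 1]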
Example #5
    # Assumes `import torch.nn as nn` and an `iterable` helper (e.g. np.iterable) in scope.
    def __init__(self, num_input, num_hidden, num_action_space):
        super(Actor, self).__init__()

        self.s0 = None
        if iterable(num_input):
            len_num_input = len(num_input)
            if len_num_input == 2:
                self.s0 = nn.LSTM(num_input[1], num_hidden)
            elif len_num_input == 1:
                num_input = num_input[0]

        if self.s0 is None:
            # Only needed for flat inputs: with an LSTM front-end, its
            # num_hidden-sized output feeds s2 directly.
            self.s1 = nn.Sequential(nn.Linear(num_input, num_hidden),
                                    nn.LayerNorm(num_hidden),
                                    nn.ReLU(inplace=True))

        self.s2 = nn.Sequential(nn.Linear(num_hidden, num_hidden),
                                nn.LayerNorm(num_hidden),
                                nn.ReLU(inplace=True))
        self.out = nn.Linear(num_hidden, num_action_space)
        # Following https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py:
        # scale down the output layer's initial weights; we apply tanh after all.
        self.out.weight.data.mul_(0.1)
        self.out.bias.data.mul_(0.1)
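
    # A minimal sketch of a matching forward pass, assuming `import torch` and
    # the modules defined above; `forward` itself is not from the original source.
    def forward(self, state):
        if self.s0 is not None:
            # LSTM front-end: keep the last time step's output.
            state, _ = self.s0(state)
            x = state[-1]
        else:
            x = self.s1(state)
        x = self.s2(x)
        # tanh squashes actions into [-1, 1], matching the comment above.
        return torch.tanh(self.out(x))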