示例#1
0
    def last_env_response(self) -> Union[List[EnvResponse], EnvResponse]:
        """
        Return the most recently stored environment response

        :return: the value held in self._last_env_response after it has been passed through squeeze_list;
                 per the return annotation this is either a single EnvResponse or a list of them
        """
        # NOTE(review): squeeze_list is defined elsewhere - presumably it unwraps a one-element list; confirm
        latest_response = self._last_env_response
        return squeeze_list(latest_response)
示例#2
0
    def predict(self, inputs, outputs=None, squeeze_output=True, initial_feed_dict=None):
        """
        Run a forward pass of the network using the given input
        :param inputs: The input for the network, converted to a feed dict by self.create_feed_dict
        :param outputs: What to fetch from the session, defaults to self.outputs when None
        :param squeeze_output: pass the fetched output through squeeze_list before returning it
        :param initial_feed_dict: optional extra feed entries; they are merged into the feed dict built from
                                  inputs (presumably overriding entries with the same key - confirm that
                                  create_feed_dict returns a plain dict)
        :return: The network output

        WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
        """
        feed = self.create_feed_dict(inputs)
        if initial_feed_dict:
            feed.update(initial_feed_dict)
        fetches = self.outputs if outputs is None else outputs

        # the middleware type is recognised by class name only
        is_lstm = self.middleware.__class__.__name__ == 'LSTMMiddleware'
        if not is_lstm:
            result = self.sess.run(fetches, feed)
        else:
            # feed the recurrent state kept from the previous call ...
            feed[self.middleware.c_in] = self.curr_rnn_c_in
            feed[self.middleware.h_in] = self.curr_rnn_h_in
            result, rnn_state = self.sess.run([fetches, self.middleware.state_out], feed_dict=feed)
            # ... and remember the new state for the next call
            self.curr_rnn_c_in, self.curr_rnn_h_in = rnn_state

        return squeeze_list(result) if squeeze_output else result
    def predict(
        self,
        inputs: Dict[str, np.ndarray],
        outputs: List[str] = None,
        squeeze_output: bool = True,
        initial_feed_dict: Dict[str,
                                np.ndarray] = None) -> Tuple[np.ndarray, ...]:
        """
        Run a forward pass of the network using the given input
        :param inputs: The input dictionary for the network. Key is name of the embedder.
        :param outputs: not supported yet - must be left as None, anything else fails the assertion below
        :param squeeze_output: when True, the converted outputs are passed through squeeze_list
        :param initial_feed_dict: not supported yet - must be left as None, anything else fails the
                                  assertion below
        :return: The network output, each element converted with its asnumpy() method

        WARNING: must only call once per state since each call is assumed by LSTM to be a new time step.
        """
        # both optional arguments exist only for signature compatibility and are rejected when supplied
        assert initial_feed_dict is None, "initial_feed_dict must be None"
        assert outputs is None, "outputs must be None"

        raw_outputs = self._predict(inputs)
        converted = [item.asnumpy() for item in raw_outputs]
        if not squeeze_output:
            return converted
        return squeeze_list(converted)
示例#4
0
 def __init__(self, name, parent, plot):
     """
     Set up the display state of a signal and locate its columns in the parent's data source

     :param name: the signal name; matched against the keys of parent.bokeh_source.data
     :param parent: object supplying the filename and bokeh_source attributes read here
     :param plot: stored on the instance as-is
     """
     self.name = name
     self.full_name = "{}/{}".format(parent.filename, self.name)
     self.plot = plot
     self.selected = False
     self.color = random.choice(Dark2[8])
     # nothing is drawn yet
     self.line = self.scatter = self.bands = None
     self.bokeh_source = parent.bokeh_source
     self.min_val = self.max_val = 0
     self.axis = 'default'

     # a column belongs to this signal when it is named exactly like the signal (and has no '/'),
     # or when everything before its last '/' equals the signal name
     self.sub_signals = [
         column for column in self.bokeh_source.data.keys()
         if ('/' not in column and column == self.name)
         or column.rpartition('/')[0] == self.name
     ]

     def pick(tag):
         # columns whose last '/'-separated component contains the tag, passed through squeeze_list
         # NOTE(review): squeeze_list presumably unwraps a one-element list - confirm
         return squeeze_list([
             column for column in self.sub_signals
             if tag in column.rpartition('/')[2]
         ])

     if len(self.sub_signals) > 1:
         self.mean_signal = pick('Mean')
         self.stdev_signal = pick('Stdev')
         self.min_signal = pick('Min')
         self.max_signal = pick('Max')
     else:
         # zero or one matching column: the signal name itself is used as the mean column
         self.mean_signal = squeeze_list(self.name)
         self.stdev_signal = None
         self.min_signal = None
         self.max_signal = None

     # bands are possible only when all four statistic columns were found
     self.has_bollinger_bands = bool(
         self.mean_signal and self.stdev_signal
         and self.min_signal and self.max_signal)
     self.show_bollinger_bands = False
     self.bollinger_bands_source = None
     self.update_range()