Beispiel #1
0
# Register `env_creator` under the name "dm-<env_name>" so the environment can
# later be referred to by that string. (`env_name` and `env_creator` are
# defined earlier in the original file, outside this excerpt.)
register_env("dm-" + env_name, env_creator)


# Custom RLlib preprocessor that resizes image observations to a square frame.
class ImagePreproc(Preprocessor):
    """Resize each image observation to ``dim x dim`` pixels.

    The side length comes from the model option ``"dim"`` (84 when the option
    is absent), so the shape declared by ``_init_shape`` and the array actually
    produced by ``transform`` stay in agreement.

    NOTE(review): the declared shape always has 3 channels, which assumes RGB
    ``(H, W, 3)`` input. ``cv2.resize`` returns a 2-D array for a 2-D
    (grayscale) input, which would not match. Adjust ``_CHANNELS`` if frames
    are stacked or converted.
    """

    _DEFAULT_DIM = 84
    _CHANNELS = 3  # Adjust if stacking frames

    def _init_shape(self, obs_space, options):
        """Compute, store and return the post-processing observation shape.

        Args:
            obs_space: The raw observation space. It is not consulted, because
                every frame is resized to the same fixed size.
            options: Model options mapping. ``options["dim"]`` sets the output
                side length. A missing key, or ``None``/empty options, falls
                back to 84.

        Returns:
            tuple: ``(dim, dim, 3)``.
        """
        dim = (options or {}).get("dim", self._DEFAULT_DIM)
        self.shape = (dim, dim, self._CHANNELS)
        return self.shape

    def transform(self, observation):
        """Return ``observation`` resized to the height/width in ``self.shape``."""
        height, width = self.shape[0], self.shape[1]
        # cv2.resize takes dsize as (width, height), i.e. columns first. This is
        # the reverse of the (rows, cols, channels) order used by self.shape.
        return cv2.resize(observation, (width, height))


# Make ImagePreproc selectable by name. The model config selects it via
# "custom_preprocessor": "sq_im_84".
ModelCatalog.register_custom_preprocessor("sq_im_84", ImagePreproc)

config = {
    # Model and preprocessor options.
    "model": {
        "custom_model": model_name,
        "custom_options": {
            # Custom notes for the experiment
            "notes": {
                "args": vars(args)
            },
        },
        # NOTE:Wrappers are applied by RLlib if custom_preproc is NOT specified
        "custom_preprocessor": "sq_im_84",
        "dim": 84,
        "free_log_std": False,  # if args.discrete_actions else True,
Beispiel #2
0
        self.regressor.fit(X_scaled, y_pct)
        #print('Reg fit time:{:.2f}'.format(time.time()-t))

        #t = time.time()
        self.classifier.fit(X_scaled, y_sign)
        #print('Class fit time:{:.2f}'.format(time.time()-t))

        self.is_fitted = True

    def predict(self, obs):
        """Run both fitted estimators on a single observation.

        ``obs`` is converted to an array and flattened into one row, so each
        estimator sees exactly one sample.

        Returns:
            tuple: ``(regressor output, classifier output)`` for that sample.
            These are the models fitted on ``y_pct`` and ``y_sign``
            respectively.
        """
        # NOTE(review): fit() trains both models on ``X_scaled``, but ``obs``
        # is passed here without any scaling. Confirm whether the same
        # transform should be applied before predicting.
        single_row = np.array(obs).reshape(1, -1)
        pct_estimate = self.regressor.predict(single_row)[0]
        sign_estimate = self.classifier.predict(single_row)[0]
        return pct_estimate, sign_estimate


if __name__ == '__main__':
    # Make PredictiveMarketVariables available under the name 'mv_pred'.
    # The `options` dict below selects it through 'custom_preprocessor'.
    ModelCatalog.register_custom_preprocessor('mv_pred',
                                              PredictiveMarketVariables)
    # Market environment configured with order and snapshot data directories.
    # NOTE(review): these paths are relative. They presumably resolve against
    # the current working directory, not this script's location. Confirm how
    # MarketOrderEnv opens them.
    env = MarketOrderEnv(order_paths='../../data/feather/',
                         snapshot_paths='../../data/snap_json/',
                         max_sequence_skip=10000,
                         max_episode_time='20hours',
                         random_start=False)

    # Model options: choose the 'mv_pred' preprocessor and hand it its own
    # settings through 'custom_options'.
    options = {
        'custom_preprocessor': 'mv_pred',
        'custom_options': {
            # The names suggest look-back lengths for fast/slow MACD,
            # order-flow imbalance and mid-price features. Confirm their
            # meaning and units in PredictiveMarketVariables.
            'fast_macd_l': 1200,
            'slow_macd_l': 2400,
            'ofi_l': 1000,
            'mid_l': 1000
        }
    }