def __init__(self, observation_space, action_space, config):
        """Initialise the poker CFR policy.

        Args:
            observation_space: Observation space forwarded to ``Policy``; its
                ``original_space`` is used to build the preprocessor.
            action_space: Action space forwarded to ``Policy``.
            config: Policy config. Must contain a truthy ``"custom_preprocessor"``
                entry and an ``"env"`` entry equal to ``POKER_ENV``.

        Raises:
            ValueError: If no custom preprocessor is configured, or if the
                configured env is not ``POKER_ENV``.
        """
        Policy.__init__(self, observation_space=observation_space, action_space=action_space, config=config)

        # Use .get() so a missing key yields the descriptive ValueError below
        # rather than a bare KeyError.
        custom_preprocessor = config.get("custom_preprocessor")
        if not custom_preprocessor:
            raise ValueError("Custom preprocessor for PokerCFRPolicy needs to be specified on its passed config.")

        self.preprocessor = ModelCatalog.get_preprocessor_for_space(
            observation_space=self.observation_space.original_space,
            options={"custom_preprocessor": custom_preprocessor})

        env_id = config['env']
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if env_id != POKER_ENV:
            raise ValueError(f"PokerCFRPolicy only supports env {POKER_ENV!r}, got {env_id!r}.")
        self.policy_dict = None
Beispiel #2
0
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Build the Keras policy/value network for the Stratego model.

        The policy tower reads ``obs_space.original_space[self.pi_obs_key]`` and
        the value tower reads ``obs_space.original_space[self.vf_obs_key]``. The
        keys are chosen by ``custom_options['observation_mode']``. The two towers
        are built by separate calls to ``build_primary_layers``; this method reads
        ``vf_share_layers`` but does not use it to share layers. When
        ``use_lstm`` is set, per-step layers are wrapped in ``TimeDistributed``
        and an LSTM follows the dense layers of each tower.

        Args:
            obs_space: Observation space whose ``original_space`` holds the
                selected observation keys.
            action_space: Action space (forwarded to ``TFModelV2``).
            num_outputs: Number of policy logits. It is also the value-head size
                when ``custom_options['q_fn']`` is true.
            model_config: Model config, merged over ``DEFAULT_STRATEGO_MODEL_CONFIG``.
            name: Model name (forwarded to ``TFModelV2``).

        Raises:
            ValueError: If ``observation_mode`` is unsupported, or if
                ``vf_share_layers`` is set together with ``BOTH_OBSERVATIONS``.
            NotImplementedError: If a pre-trained policy file is combined with
                ``use_lstm``.
        """
        model_config = with_base_config(base_config=DEFAULT_STRATEGO_MODEL_CONFIG, extra_config=model_config)
        TFModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)

        print(model_config)

        observation_mode = model_config['custom_options']['observation_mode']
        if observation_mode == PARTIALLY_OBSERVABLE:
            self.pi_obs_key = 'partial_observation'
            self.vf_obs_key = 'partial_observation'
        elif observation_mode == FULLY_OBSERVABLE:
            self.pi_obs_key = 'full_observation'
            self.vf_obs_key = 'full_observation'
        elif observation_mode == BOTH_OBSERVATIONS:
            self.pi_obs_key = 'partial_observation'
            self.vf_obs_key = 'full_observation'
            # Policy and value consume different observations, so they cannot share layers.
            if model_config['vf_share_layers']:
                raise ValueError("vf_share_layers must be False when observation_mode is BOTH_OBSERVATIONS")
        else:
            # Raise explicitly: an `assert` is stripped under `python -O`, and the
            # missing obs keys would then surface later as an AttributeError.
            raise ValueError("policy observation_mode must be in [PARTIALLY_OBSERVABLE, FULLY_OBSERVABLE, BOTH_OBSERVATIONS]")

        if model_config["custom_preprocessor"]:
            print(obs_space)

            self.preprocessor = ModelCatalog.get_preprocessor_for_space(observation_space=self.obs_space.original_space,
                                                                        options=model_config)
        else:
            self.preprocessor = None
            logger.warning("No custom preprocessor for StrategoModel was specified.\n"
                           "Some tree search policies may not initialize their placeholders correctly without this.")

        self.use_lstm = model_config['use_lstm']
        self.lstm_cell_size = model_config['lstm_cell_size']
        self.vf_share_layers = model_config.get("vf_share_layers")
        self.mask_invalid_actions = model_config['custom_options']['mask_invalid_actions']

        conv_activation = get_activation_fn(model_config.get("conv_activation"))
        cnn_filters = model_config.get("conv_filters")
        fc_activation = get_activation_fn(model_config.get("fcnet_activation"))
        hiddens = model_config.get("fcnet_hiddens")

        if self.use_lstm:
            # LSTM state inputs: (h, c) for the policy tower, then (h, c) for the value tower.
            state_in = [tf.keras.layers.Input(shape=(self.lstm_cell_size,), name="pi_lstm_h"),
                        tf.keras.layers.Input(shape=(self.lstm_cell_size,), name="pi_lstm_c"),
                        tf.keras.layers.Input(shape=(self.lstm_cell_size,), name="vf_lstm_h"),
                        tf.keras.layers.Input(shape=(self.lstm_cell_size,), name="vf_lstm_c")]

            seq_lens_in = tf.keras.layers.Input(shape=(), name="lstm_seq_in")

            # The leading None is a variable-length time axis.
            self.pi_obs_inputs = tf.keras.layers.Input(
                shape=(None, *obs_space.original_space[self.pi_obs_key].shape), name="pi_observation")

            self.vf_obs_inputs = tf.keras.layers.Input(
                shape=(None, *obs_space.original_space[self.vf_obs_key].shape), name="vf_observation")
        else:
            state_in, seq_lens_in = None, None

            self.pi_obs_inputs = tf.keras.layers.Input(
                shape=obs_space.original_space[self.pi_obs_key].shape, name="pi_observation")

            self.vf_obs_inputs = tf.keras.layers.Input(
                shape=obs_space.original_space[self.vf_obs_key].shape, name="vf_observation")

        if cnn_filters is None:
            # Assuming board size will always remain the same for both pi and vf networks.
            # Drop the batch axis (and the time axis when an LSTM is used).
            if self.use_lstm:
                single_obs_input_shape = self.pi_obs_inputs.shape.as_list()[2:]
            else:
                single_obs_input_shape = self.pi_obs_inputs.shape.as_list()[1:]
            cnn_filters = _get_filter_config(single_obs_input_shape)

        def maybe_td(layer):
            """Wrap ``layer`` in ``TimeDistributed`` when the model uses an LSTM."""
            if self.use_lstm:
                return tf.keras.layers.TimeDistributed(layer=layer)
            return layer

        def build_primary_layers(prefix: str, obs_in: tf.Tensor, state_in: tf.Tensor):
            """Build one conv -> flatten -> dense (-> LSTM) tower.

            Kept in a function so it can be called separately for the policy and
            value towers. Returns ``(last_layer, state_out)``. ``state_out`` is the
            LSTM's returned states, or ``None`` without an LSTM.
            """
            _last_layer = obs_in

            for i, (out_size, kernel, stride) in enumerate(cnn_filters):
                _last_layer = maybe_td(tf.keras.layers.Conv2D(
                    filters=out_size,
                    kernel_size=kernel,
                    strides=stride,
                    activation=conv_activation,
                    padding="same",
                    name="{}_conv_{}".format(prefix, i)))(_last_layer)

            _last_layer = maybe_td(tf.keras.layers.Flatten())(_last_layer)

            for i, size in enumerate(hiddens):
                _last_layer = maybe_td(tf.keras.layers.Dense(
                    size,
                    name="{}_fc_{}".format(prefix, i),
                    activation=fc_activation,
                    kernel_initializer=normc_initializer(1.0)))(_last_layer)

            if self.use_lstm:
                _last_layer, *state_out = tf.keras.layers.LSTM(
                    units=self.lstm_cell_size,
                    return_sequences=True,
                    return_state=True,
                    name="{}_lstm".format(prefix))(
                    inputs=_last_layer,
                    mask=tf.sequence_mask(seq_lens_in),
                    initial_state=state_in)
            else:
                state_out = None

            return _last_layer, state_out

        if self.use_lstm:
            pi_state_in = state_in[:2]
            vf_state_in = state_in[2:]
        else:
            pi_state_in, vf_state_in = None, None

        policy_file_path = model_config['custom_options'].get('policy_keras_model_file_path')
        if policy_file_path is not None:
            if self.use_lstm:
                raise NotImplementedError

            pi_state_out = None
            # Use a pre-trained Keras model as the whole policy tower.
            self._pi_model = load_model(filepath=policy_file_path, compile=False)

            # Prefix the loaded layers' names (presumably to avoid clashes with the
            # value tower's layer names) and give the final layer the same name the
            # freshly-built policy head would have.
            for layer in self._pi_model.layers:
                layer._name = "pi_" + layer.name
            self._pi_model.layers[-1]._name = 'pi_unmasked_logits'

            self.unmasked_logits_out = self._pi_model(self.pi_obs_inputs)
        else:
            self._pi_model = None
            pi_last_layer, pi_state_out = build_primary_layers(prefix="pi", obs_in=self.pi_obs_inputs,
                                                               state_in=pi_state_in)

            self.unmasked_logits_out = maybe_td(tf.keras.layers.Dense(
                num_outputs,
                name="pi_unmasked_logits",
                activation=None,
                kernel_initializer=normc_initializer(0.01)))(pi_last_layer)

        vf_last_layer, vf_state_out = build_primary_layers(prefix="vf", obs_in=self.vf_obs_inputs,
                                                           state_in=vf_state_in)

        if self.use_lstm:
            state_out = [*pi_state_out, *vf_state_out]
        else:
            state_out = None

        self._use_q_fn = model_config['custom_options']['q_fn']

        # With a Q-function the value head has num_outputs units, otherwise one unit.
        value_out_size = num_outputs if self._use_q_fn else 1

        value_out = maybe_td(tf.keras.layers.Dense(
            value_out_size,
            name="vf_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01)))(vf_last_layer)

        model_inputs = [self.pi_obs_inputs, self.vf_obs_inputs]
        model_outputs = [self.unmasked_logits_out, value_out]
        if self.use_lstm:
            model_inputs += [seq_lens_in, *state_in]
            model_outputs += state_out

        self.base_model = tf.keras.Model(inputs=model_inputs, outputs=model_outputs)

        # summary() prints by itself and returns None; wrapping it in print()
        # emitted a stray "None" line.
        self.base_model.summary()

        self.register_variables(self.base_model.variables)
Beispiel #3
0
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, twin_q):
        """Build feed-forward policy and Q-function networks (SAC-style).

        Creates Keras models over shared input placeholders:
        ``self.pi_model`` (policy logits), ``self.main_q_model`` (one output per
        action) and, when ``twin_q`` is true, ``self.twin_q_model``. It also
        creates a ``log_alpha`` variable with ``alpha = exp(log_alpha)``.

        Args:
            obs_space: Observation space whose ``original_space`` holds the keys
                selected by ``custom_options['observation_mode']``.
            action_space: Action space; ``action_space.n`` sets the size of the
                logits and Q outputs.
            num_outputs: Forwarded to ``TFModelV2``. The heads here use
                ``action_space.n`` instead.
            model_config: Model config, merged over ``DEFAULT_STRATEGO_MODEL_CONFIG``.
            name: Model name (forwarded to ``TFModelV2``).
            twin_q: Whether to build a second, independent Q network.

        Raises:
            ValueError: If ``observation_mode`` is unsupported, if
                ``vf_share_layers`` is set with ``BOTH_OBSERVATIONS``, or if
                ``twin_q`` is requested without ``custom_options['q_fn']``.
            NotImplementedError: If ``use_lstm`` is set.
        """
        model_config = with_base_config(
            base_config=DEFAULT_STRATEGO_MODEL_CONFIG,
            extra_config=model_config)
        TFModelV2.__init__(self, obs_space, action_space, num_outputs,
                           model_config, name)

        print(model_config)

        observation_mode = model_config['custom_options']['observation_mode']
        if observation_mode == PARTIALLY_OBSERVABLE:
            self.pi_obs_key = 'partial_observation'
            self.vf_obs_key = 'partial_observation'
        elif observation_mode == FULLY_OBSERVABLE:
            self.pi_obs_key = 'full_observation'
            self.vf_obs_key = 'full_observation'
        elif observation_mode == BOTH_OBSERVATIONS:
            self.pi_obs_key = 'partial_observation'
            self.vf_obs_key = 'full_observation'
            # Policy and value consume different observations, so they cannot share layers.
            if model_config['vf_share_layers']:
                raise ValueError(
                    "vf_share_layers must be False when observation_mode is BOTH_OBSERVATIONS")
        else:
            # Raise explicitly: an `assert` is stripped under `python -O`.
            raise ValueError(
                "policy observation_mode must be in [PARTIALLY_OBSERVABLE, FULLY_OBSERVABLE, BOTH_OBSERVATIONS]")

        if model_config["custom_preprocessor"]:
            print(obs_space)

            self.preprocessor = ModelCatalog.get_preprocessor_for_space(
                observation_space=self.obs_space.original_space,
                options=model_config)
        else:
            self.preprocessor = None
            logger.warning(
                "No custom preprocessor for StrategoModel was specified.\n"
                "Some tree search policies may not initialize their placeholders correctly without this."
            )

        self.use_lstm = model_config['use_lstm']
        if self.use_lstm:
            raise NotImplementedError

        self.fake_lstm = model_config['custom_options'].get('fake_lstm', False)
        self.vf_share_layers = model_config.get("vf_share_layers")
        self.mask_invalid_actions = model_config['custom_options'][
            'mask_invalid_actions']
        self._use_q_fn = model_config['custom_options']['q_fn']

        self.twin_q = twin_q
        # A twin network only makes sense when the value head is a Q-function.
        if self.twin_q and not self._use_q_fn:
            raise ValueError("twin_q requires custom_options['q_fn'] to be True.")
        if self.twin_q and self.use_lstm:
            raise NotImplementedError
        self._sac_alpha = model_config.get("sac_alpha", False)

        # NOTE(review): this activation is applied to the Dense layers below;
        # confirm conv_activation (rather than fcnet_activation) is intended.
        conv_activation = get_activation_fn(
            model_config.get("conv_activation"))

        if self.use_lstm:
            raise NotImplementedError
        else:
            state_in, seq_lens_in = None, None

            self.pi_obs_inputs = tf.keras.layers.Input(
                shape=obs_space.original_space[self.pi_obs_key].shape,
                name="pi_observation")

            self.vf_obs_inputs = tf.keras.layers.Input(
                shape=obs_space.original_space[self.vf_obs_key].shape,
                name="vf_observation")

        def maybe_td(layer):
            """Wrap ``layer`` in ``TimeDistributed`` when the model uses an LSTM."""
            if self.use_lstm:
                return tf.keras.layers.TimeDistributed(layer=layer,
                                                       name=f"td_{layer.name}")
            return layer

        def build_primary_layers(prefix: str, obs_in: tf.Tensor,
                                 state_in: tf.Tensor):
            """Build one stack of Dense layers sized by ``fcnet_hiddens``.

            Called once per tower (policy, main Q, optional twin Q). Returns
            ``(last_layer, state_out)``, where ``state_out`` is ``state_in``
            passed through unchanged.
            """
            _last_layer = obs_in
            state_out = state_in
            for i, size in enumerate(model_config['fcnet_hiddens']):
                _last_layer = maybe_td(
                    tf.keras.layers.Dense(size,
                                          name="{}_fc_{}".format(prefix, i),
                                          activation=conv_activation,
                                          kernel_initializer=normc_initializer(
                                              1.0)))(_last_layer)

            return _last_layer, state_out

        if self.use_lstm:
            pi_state_in = state_in[:2]
            vf_state_in = state_in[2:]
        else:
            pi_state_in, vf_state_in = None, None

        self.main_vf_prefix = "main_vf" if self.twin_q else "vf"
        pi_last_layer, pi_state_out = build_primary_layers(
            prefix="pi", obs_in=self.pi_obs_inputs, state_in=pi_state_in)

        vf_last_layer, vf_state_out = build_primary_layers(
            prefix=self.main_vf_prefix,
            obs_in=self.vf_obs_inputs,
            state_in=vf_state_in)
        if self.twin_q:
            twin_vf_last_layer, twin_vf_state_out = build_primary_layers(
                prefix="twin_vf", obs_in=self.vf_obs_inputs, state_in=None)
        else:
            twin_vf_last_layer, twin_vf_state_out = None, None

        if self.use_lstm:
            raise NotImplementedError
        else:
            state_out = None

        # maybe_td must wrap the *layer*, not the tensor the layer produced;
        # wrapping the output tensor would break as soon as LSTMs were enabled.
        unmasked_logits_out = maybe_td(
            tf.keras.layers.Dense(
                action_space.n,
                name="{}_fc_{}".format('pi', 'unmasked_logits'),
                activation=None,
                kernel_initializer=normc_initializer(1.0)))(pi_last_layer)

        value_out = maybe_td(
            tf.keras.layers.Dense(
                action_space.n,
                name="{}_fc_{}".format(self.main_vf_prefix, 'q_out'),
                activation=None,
                kernel_initializer=normc_initializer(1.0)))(vf_last_layer)

        if self.twin_q:
            twin_value_out = maybe_td(
                tf.keras.layers.Dense(action_space.n,
                                      name="{}_fc_{}".format(
                                          'twin_vf', 'q_out'),
                                      activation=None,
                                      kernel_initializer=normc_initializer(
                                          1.0)))(twin_vf_last_layer)

        self.pi_model = tf.keras.Model(inputs=[self.pi_obs_inputs],
                                       outputs=[unmasked_logits_out])
        self.main_q_model = tf.keras.Model(inputs=[self.vf_obs_inputs],
                                           outputs=[value_out])

        # summary() prints by itself and returns None, so it is not wrapped in print().
        if self.twin_q:
            self.twin_q_model = tf.keras.Model(inputs=[self.vf_obs_inputs],
                                               outputs=[twin_value_out])
            self.twin_q_model.summary()
            self.register_variables(self.twin_q_model.variables)

        self.pi_model.summary()
        self.main_q_model.summary()

        self.register_variables(self.pi_model.variables)
        self.register_variables(self.main_q_model.variables)

        self.log_alpha = tf.Variable(0.0, dtype=tf.float32, name="log_alpha")
        self.alpha = tf.exp(self.log_alpha)
        self.register_variables([self.log_alpha])
Beispiel #4
0
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Build the fully-convolutional (optionally ConvLSTM) Stratego model.

        Each tower is a stack of ``Conv2D`` layers over the board observation.
        When ``use_lstm`` is set and ``fake_lstm`` is not, a ``ConvLSTM2D``
        follows. The policy head is convolutional and emits
        ``action_space.n // (rows * columns)`` logit channels per board cell.
        The value head is one of two forms:

        - a Q-function with the same per-cell layout, when
          ``custom_options['q_fn']`` is true;
        - otherwise a single value from a Dense layer.

        Args:
            obs_space: Observation space whose ``original_space`` holds the keys
                selected by ``custom_options['observation_mode']``. The first two
                dimensions of the policy observation are used as rows/columns.
            action_space: Action space; ``action_space.n`` must be a multiple of
                ``rows * columns``.
            num_outputs: Forwarded to ``TFModelV2``.
            model_config: Model config, merged over ``DEFAULT_STRATEGO_MODEL_CONFIG``.
            name: Model name (forwarded to ``TFModelV2``).

        Raises:
            ValueError: If ``observation_mode`` is unsupported, if
                ``vf_share_layers`` is set with ``BOTH_OBSERVATIONS``, or if
                ``action_space.n`` is not a multiple of ``rows * columns``.
            NotImplementedError: If more than one ConvLSTM layer is configured.
        """
        model_config = with_base_config(
            base_config=DEFAULT_STRATEGO_MODEL_CONFIG,
            extra_config=model_config)
        TFModelV2.__init__(self, obs_space, action_space, num_outputs,
                           model_config, name)

        print(model_config)

        observation_mode = model_config['custom_options']['observation_mode']
        if observation_mode == PARTIALLY_OBSERVABLE:
            self.pi_obs_key = 'partial_observation'
            self.vf_obs_key = 'partial_observation'
        elif observation_mode == FULLY_OBSERVABLE:
            self.pi_obs_key = 'full_observation'
            self.vf_obs_key = 'full_observation'
        elif observation_mode == BOTH_OBSERVATIONS:
            self.pi_obs_key = 'partial_observation'
            self.vf_obs_key = 'full_observation'
            # Policy and value consume different observations, so they cannot share layers.
            if model_config['vf_share_layers']:
                raise ValueError(
                    "vf_share_layers must be False when observation_mode is BOTH_OBSERVATIONS")
        else:
            # Raise explicitly: an `assert` is stripped under `python -O`.
            raise ValueError(
                "policy observation_mode must be in [PARTIALLY_OBSERVABLE, FULLY_OBSERVABLE, BOTH_OBSERVATIONS]")

        if model_config["custom_preprocessor"]:
            print(obs_space)

            self.preprocessor = ModelCatalog.get_preprocessor_for_space(
                observation_space=self.obs_space.original_space,
                options=model_config)
        else:
            self.preprocessor = None
            logger.warning(
                "No custom preprocessor for StrategoModel was specified.\n"
                "Some tree search policies may not initialize their placeholders correctly without this."
            )

        self.use_lstm = model_config['use_lstm']
        self.fake_lstm = model_config['custom_options'].get('fake_lstm')
        self.vf_share_layers = model_config.get("vf_share_layers")
        self.mask_invalid_actions = model_config['custom_options'][
            'mask_invalid_actions']

        conv_activation = get_activation_fn(
            model_config.get("conv_activation"))
        lstm_filters = model_config["custom_options"]['lstm_filters']
        cnn_filters = model_config.get("conv_filters")
        final_pi_filter_amt = model_config["custom_options"][
            "final_pi_filter_amt"]

        rows = obs_space.original_space[self.pi_obs_key].shape[0]
        columns = obs_space.original_space[self.pi_obs_key].shape[1]

        # The logit/Q heads emit the same number of channels at every board cell,
        # so the action count must split evenly across cells; the previous
        # int(n / cells) silently truncated and produced too few outputs.
        if action_space.n % (rows * columns) != 0:
            raise ValueError(
                f"action_space.n ({action_space.n}) must be divisible by "
                f"rows * columns ({rows} * {columns}).")
        action_filters_per_cell = action_space.n // (rows * columns)

        if self.use_lstm:
            if self.fake_lstm:
                # Dummy state; build_primary_layers passes it through unchanged.
                self._lstm_state_shape = (1, )
            else:
                self._lstm_state_shape = (rows, columns, lstm_filters[0][0])

            # State inputs: (h, c) for the policy tower, then (h, c) for the value tower.
            state_in = [
                tf.keras.layers.Input(shape=self._lstm_state_shape,
                                      name="pi_lstm_h"),
                tf.keras.layers.Input(shape=self._lstm_state_shape,
                                      name="pi_lstm_c"),
                tf.keras.layers.Input(shape=self._lstm_state_shape,
                                      name="vf_lstm_h"),
                tf.keras.layers.Input(shape=self._lstm_state_shape,
                                      name="vf_lstm_c")
            ]

            seq_lens_in = tf.keras.layers.Input(shape=(), name="lstm_seq_in")

            # The leading None is a variable-length time axis.
            self.pi_obs_inputs = tf.keras.layers.Input(
                shape=(None, *obs_space.original_space[self.pi_obs_key].shape),
                name="pi_observation")

            self.vf_obs_inputs = tf.keras.layers.Input(
                shape=(None, *obs_space.original_space[self.vf_obs_key].shape),
                name="vf_observation")
        else:
            state_in, seq_lens_in = None, None

            self.pi_obs_inputs = tf.keras.layers.Input(
                shape=obs_space.original_space[self.pi_obs_key].shape,
                name="pi_observation")

            self.vf_obs_inputs = tf.keras.layers.Input(
                shape=obs_space.original_space[self.vf_obs_key].shape,
                name="vf_observation")

        def maybe_td(layer):
            """Wrap ``layer`` in ``TimeDistributed`` when the model uses an LSTM."""
            if self.use_lstm:
                return tf.keras.layers.TimeDistributed(layer=layer,
                                                       name=f"td_{layer.name}")
            return layer

        def build_primary_layers(prefix: str, obs_in: tf.Tensor,
                                 state_in: tf.Tensor):
            """Build one Conv2D stack (plus an optional ConvLSTM2D) for a tower.

            Called separately for the policy and value towers. Returns
            ``(last_layer, state_out)``. ``state_out`` is the ConvLSTM's returned
            states, or ``state_in`` unchanged when no ConvLSTM is built.
            """
            _last_layer = obs_in

            for i, (out_size, kernel, stride) in enumerate(cnn_filters):
                _last_layer = maybe_td(
                    tf.keras.layers.Conv2D(filters=out_size,
                                           kernel_size=kernel,
                                           strides=stride,
                                           activation=conv_activation,
                                           padding="same",
                                           name="{}_conv_{}".format(
                                               prefix, i)))(_last_layer)

            state_out = state_in
            if self.use_lstm and not self.fake_lstm:
                for i, (out_size, kernel, stride) in enumerate(lstm_filters):
                    if i > 0:
                        raise NotImplementedError(
                            "Only single lstm layers are implemented right now"
                        )

                    # ConvLSTM2D consumes the whole sequence itself, so it is not wrapped in maybe_td.
                    _last_layer, *state_out = tf.keras.layers.ConvLSTM2D(
                        filters=out_size,
                        kernel_size=kernel,
                        strides=stride,
                        activation=conv_activation,
                        padding="same",
                        return_sequences=True,
                        return_state=True,
                        name="{}_convlstm".format(prefix))(
                            inputs=_last_layer,
                            mask=tf.sequence_mask(seq_lens_in),
                            initial_state=state_in)

            return _last_layer, state_out

        if self.use_lstm:
            pi_state_in = state_in[:2]
            vf_state_in = state_in[2:]
        else:
            pi_state_in, vf_state_in = None, None

        pi_last_layer, pi_state_out = build_primary_layers(
            prefix="pi", obs_in=self.pi_obs_inputs, state_in=pi_state_in)

        vf_last_layer, vf_state_out = build_primary_layers(
            prefix="vf", obs_in=self.vf_obs_inputs, state_in=vf_state_in)

        if self.use_lstm:
            state_out = [*pi_state_out, *vf_state_out]
        else:
            state_out = None

        pi_last_layer = maybe_td(
            tf.keras.layers.Conv2D(filters=final_pi_filter_amt,
                                   kernel_size=[3, 3],
                                   strides=1,
                                   activation=conv_activation,
                                   padding="same",
                                   name="{}_conv_{}".format(
                                       'pi', "last")))(pi_last_layer)

        print(
            f"action space n: {action_space.n}, rows: {rows}, columns: {columns}, filters: {action_filters_per_cell}"
        )

        unmasked_logits_out = maybe_td(
            tf.keras.layers.Conv2D(
                filters=action_filters_per_cell,
                kernel_size=[3, 3],
                strides=1,
                activation=None,
                padding="same",
                name="{}_conv_{}".format('pi',
                                         "unmasked_logits")))(pi_last_layer)

        self._use_q_fn = model_config['custom_options']['q_fn']

        if self._use_q_fn:
            # Q-function head: same per-cell channel layout as the policy logits.
            vf_last_layer = maybe_td(
                tf.keras.layers.Conv2D(filters=final_pi_filter_amt,
                                       kernel_size=[3, 3],
                                       strides=1,
                                       activation=conv_activation,
                                       padding="same",
                                       name="{}_conv_{}".format(
                                           'vf', "last")))(vf_last_layer)

            value_out = maybe_td(
                tf.keras.layers.Conv2D(
                    filters=action_filters_per_cell,
                    kernel_size=[3, 3],
                    strides=1,
                    activation=None,
                    padding="same",
                    name="{}_conv_{}".format('vf', "q_out")))(vf_last_layer)
        else:
            # Scalar value head: 1x1 conv down to one channel, flatten, Dense(1).
            vf_last_layer = maybe_td(
                tf.keras.layers.Conv2D(filters=1,
                                       kernel_size=[1, 1],
                                       strides=1,
                                       activation=conv_activation,
                                       padding="same",
                                       name="{}_conv_{}".format(
                                           'vf', "last")))(vf_last_layer)

            vf_last_layer = maybe_td(
                tf.keras.layers.Flatten(name="vf_flatten"))(vf_last_layer)

            value_out = maybe_td(
                tf.keras.layers.Dense(
                    units=1,
                    name="vf_out",
                    activation=None,
                    kernel_initializer=normc_initializer(0.01)))(vf_last_layer)

        model_inputs = [self.pi_obs_inputs, self.vf_obs_inputs]
        model_outputs = [unmasked_logits_out, value_out]

        if self.use_lstm:
            model_inputs += [seq_lens_in, *state_in]
            model_outputs += state_out

        self.base_model = tf.keras.Model(inputs=model_inputs,
                                         outputs=model_outputs)

        # summary() prints by itself and returns None; wrapping it in print()
        # emitted a stray "None" line.
        self.base_model.summary()

        self.register_variables(self.base_model.variables)
Beispiel #5
0
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 q_hiddens=None,
                 dueling=False,
                 num_atoms=1,
                 use_noisy=False,
                 v_min=-10.0,
                 v_max=10.0,
                 sigma0=0.5,
                 parameter_noise=False):

        if q_hiddens or dueling or num_atoms != 1 or use_noisy:
            raise NotImplementedError

        model_config = with_base_config(
            base_config=DEFAULT_STRATEGO_MODEL_CONFIG,
            extra_config=model_config)
        TFModelV2.__init__(self, obs_space, action_space, num_outputs,
                           model_config, name)

        print(model_config)

        observation_mode = model_config['custom_options']['observation_mode']
        if observation_mode == PARTIALLY_OBSERVABLE:
            self.vf_obs_key = 'partial_observation'
        elif observation_mode == FULLY_OBSERVABLE:
            self.vf_obs_key = 'full_observation'
        elif observation_mode == BOTH_OBSERVATIONS:
            raise ValueError(
                f"Using {BOTH_OBSERVATIONS} format doesn't make sense for a Q-network, there's no policy, just a Q-function"
            )

        else:
            assert False, "policy observation_mode must be in [PARTIALLY_OBSERVABLE, FULLY_OBSERVABLE, BOTH_OBSERVATIONS]"

        if model_config["custom_preprocessor"]:
            print(obs_space)

            self.preprocessor = ModelCatalog.get_preprocessor_for_space(
                observation_space=self.obs_space.original_space,
                options=model_config)
        else:
            self.preprocessor = None
            logger.warn(
                "No custom preprocessor for StrategoModel was specified.\n"
                "Some tree search policies may not initialize their placeholders correctly without this."
            )

        self.use_lstm = model_config['use_lstm']
        self.vf_share_layers = model_config.get("vf_share_layers")
        self.mask_invalid_actions = model_config['custom_options'][
            'mask_invalid_actions']

        conv_activation = get_activation_fn(
            model_config.get("conv_activation"))
        lstm_filters = model_config["custom_options"]['lstm_filters']
        cnn_filters = model_config.get("conv_filters")
        final_pi_filter_amt = model_config["custom_options"][
            "final_pi_filter_amt"]

        rows = obs_space.original_space[self.vf_obs_key].shape[0]
        colums = obs_space.original_space[self.vf_obs_key].shape[1]

        if self.use_lstm:
            self._lstm_state_shape = (rows, colums, lstm_filters[0][0])
            # self._lstm_state_shape = (64,)

        if self.use_lstm:
            state_in = [
                tf.keras.layers.Input(shape=self._lstm_state_shape,
                                      name="vf_lstm_h"),
                tf.keras.layers.Input(shape=self._lstm_state_shape,
                                      name="vf_lstm_c")
            ]

            seq_lens_in = tf.keras.layers.Input(shape=(), name="lstm_seq_in")

            self.vf_obs_inputs = tf.keras.layers.Input(
                shape=(None, *obs_space.original_space[self.vf_obs_key].shape),
                name="vf_observation")

        else:
            state_in, seq_lens_in = None, None

            self.vf_obs_inputs = tf.keras.layers.Input(
                shape=obs_space.original_space[self.vf_obs_key].shape,
                name="vf_observation")

        # if pi_cnn_filters is None:
        #     assert False
        #     # assuming board size will always remain the same for both pi and vf networks
        #     if self.use_lstm:
        #         single_obs_input_shape = self.pi_obs_inputs.shape.as_list()[2:]
        #     else:
        #         single_obs_input_shape = self.pi_obs_inputs.shape.as_list()[1:]
        #     pi_cnn_filters = _get_filter_config(single_obs_input_shape)
        #
        # if v_cnn_filters is None:
        #     assert False
        #     # assuming board size will always remain the same for both pi and vf networks
        #     if self.use_lstm:
        #         single_obs_input_shape = self.pi_obs_inputs.shape.as_list()[2:]
        #     else:
        #         single_obs_input_shape = self.pi_obs_inputs.shape.as_list()[1:]
        #     v_cnn_filters = _get_filter_config(single_obs_input_shape)

        def maybe_td(layer):
            """Wrap ``layer`` in ``TimeDistributed`` when the model is recurrent.

            With ``self.use_lstm`` the observation input carries an extra
            leading time axis, so per-frame layers must be applied
            step-by-step. The wrapper is named ``td_<layer.name>``.
            Without the LSTM, ``layer`` is returned untouched.
            """
            if not self.use_lstm:
                return layer
            wrapped_name = f"td_{layer.name}"
            return tf.keras.layers.TimeDistributed(layer=layer,
                                                   name=wrapped_name)

        def build_primary_layers(prefix: str, obs_in: tf.Tensor,
                                 state_in: tf.Tensor):
            """Build the convolutional trunk (and optional ConvLSTM) for one head.

            Kept as a function so the same trunk recipe can be instantiated
            once for a shared policy/vf or once per head with a distinct
            ``prefix``.

            Args:
                prefix: Prefix for every created layer name (e.g. "vf").
                obs_in: Keras input tensor that the first conv layer consumes.
                state_in: Recurrent state inputs passed as ``initial_state`` to
                    the ConvLSTM when ``self.use_lstm``. The caller passes a
                    list of Input tensors in that case and ``None`` otherwise,
                    despite the ``tf.Tensor`` annotation.

            Returns:
                ``(features, state_out)``. ``state_out`` is ``state_in``
                unchanged when no LSTM is used.

            Raises:
                NotImplementedError: Always raised on the LSTM path, which has
                    not been validated for this model yet.
            """
            features = obs_in

            for idx, (n_filters, kernel_size, stride) in enumerate(cnn_filters):
                conv = tf.keras.layers.Conv2D(filters=n_filters,
                                              kernel_size=kernel_size,
                                              strides=stride,
                                              activation=conv_activation,
                                              padding="same",
                                              name=f"{prefix}_conv_{idx}")
                features = maybe_td(conv)(features)

                if not parameter_noise:
                    continue
                # Normalizes over axes (1, 2); presumably the spatial w/h axes
                # of a (batch, w, h, channel) input -- TODO confirm layout.
                norm = tf.keras.layers.LayerNormalization(
                    axis=(1, 2), name=f"{prefix}_LayerNorm_{idx}")
                features = maybe_td(norm)(features)

            if not self.use_lstm:
                return features, state_in

            state_out = state_in
            for idx, (n_filters, kernel_size, stride) in enumerate(lstm_filters):
                if idx > 0:
                    raise NotImplementedError(
                        "Only single lstm layers are implemented right now")

                conv_lstm = tf.keras.layers.ConvLSTM2D(
                    filters=n_filters,
                    kernel_size=kernel_size,
                    strides=stride,
                    activation=conv_activation,
                    padding="same",
                    return_sequences=True,
                    return_state=True,
                    name=f"{prefix}_convlstm")
                features, *state_out = conv_lstm(
                    inputs=features,
                    mask=tf.sequence_mask(seq_lens_in),
                    initial_state=state_in)
                # The recurrent path is built but deliberately refused until
                # it has been checked for this model.
                raise NotImplementedError("havent checked lstms for q model")
            return features, state_out

        if self.use_lstm:
            vf_state_in = state_in[2:]
        else:
            pi_state_in, vf_state_in = None, None

        vf_last_layer, vf_state_out = build_primary_layers(
            prefix="vf", obs_in=self.vf_obs_inputs, state_in=vf_state_in)

        if self.use_lstm:
            state_out = vf_state_out
        else:
            state_out = None

        vf_last_layer = maybe_td(
            tf.keras.layers.Conv2D(filters=final_pi_filter_amt,
                                   kernel_size=[3, 3],
                                   strides=1,
                                   activation=conv_activation,
                                   padding="same",
                                   name="{}_conv_{}".format(
                                       'vf', "last")))(vf_last_layer)

        if parameter_noise:
            # assuming inputs shape (batch_size x w x h x channel)
            vf_last_layer = maybe_td(
                tf.keras.layers.LayerNormalization(
                    axis=(1, 2), name=f"vf_LayerNorm_last"))(vf_last_layer)

        print(
            f"action space n: {action_space.n}, rows: {rows}, columns: {colums}, filters: {int(action_space.n / (rows * colums))}"
        )

        unmasked_logits_out = maybe_td(
            tf.keras.layers.Conv2D(
                filters=int(action_space.n / (rows * colums)),
                kernel_size=[3, 3],
                strides=1,
                activation=None,
                padding="same",
                name="{}_conv_{}".format('vf',
                                         "unmasked_logits")))(vf_last_layer)

        # vf_last_layer = maybe_td(tf.keras.layers.Conv2D(
        #     filters=1,
        #     kernel_size=[1, 1],
        #     strides=1,
        #     activation=conv_activation,
        #     padding="same",
        #     name="{}_conv_{}".format('vf', "last")))(vf_last_layer)
        #
        # vf_last_layer = maybe_td(tf.keras.layers.Flatten(name="vf_flatten"))(vf_last_layer)
        #
        # value_out = maybe_td(tf.keras.layers.Dense(
        #     units=1,
        #     name="vf_out",
        #     activation=None,
        #     kernel_initializer=normc_initializer(0.01)))(vf_last_layer)

        model_inputs = [self.vf_obs_inputs]
        model_outputs = [unmasked_logits_out]

        if self.use_lstm:
            model_inputs += [seq_lens_in, *state_in]
            model_outputs += state_out

        self.base_model = tf.keras.Model(inputs=model_inputs,
                                         outputs=model_outputs)

        print(self.base_model.summary())

        self.register_variables(self.base_model.variables)