def test_gaussian_distribution():
    with tf.Graph().as_default():
        logits = tf.Variable(initial_value=[[1, 1]],
                             trainable=True,
                             dtype=tf.float32)
        distribution = GaussianDistribution(
            logits,
            act_size=VECTOR_ACTION_SPACE,
            reparameterize=False,
            tanh_squash=False,
        )
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            output = sess.run(distribution.sample)
            for _ in range(10):
                output = sess.run(
                    [distribution.sample, distribution.log_probs])
                for out in output:
                    assert out.shape[1] == VECTOR_ACTION_SPACE[0]
                output = sess.run([distribution.total_log_probs])
                assert output[0].shape[0] == 1
            # Test entropy is correct
            log_std_tensor = tf.get_default_graph().get_tensor_by_name(
                "log_std/BiasAdd:0")
            feed_dict = {log_std_tensor: [[1.0, 1.0]]}
            entropy = sess.run([distribution.entropy], feed_dict=feed_dict)
            # Entropy with log_std of 1.0 should be 2.42
            assert pytest.approx(entropy[0], 0.01) == 2.42
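For reference, the 2.42 constant follows from the closed-form entropy of a
univariate Gaussian, 0.5 * ln(2 * pi * e) + log_std. A quick standalone check
(plain NumPy, independent of the distribution class):

import numpy as np

# Closed-form Gaussian entropy with log_std = 1.0:
# 0.5 * ln(2 * pi * e) + 1.0 ~= 1.419 + 1.0 ~= 2.42
entropy = 0.5 * np.log(2 * np.pi * np.e) + 1.0
assert abs(entropy - 2.42) < 0.01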
Example #2
def test_tanh_distribution():
    with tf.Graph().as_default():
        logits = tf.Variable(initial_value=[[0, 0]],
                             trainable=True,
                             dtype=tf.float32)
        distribution = GaussianDistribution(logits,
                                            act_size=VECTOR_ACTION_SPACE,
                                            reparameterize=False,
                                            tanh_squash=True)
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            output = sess.run(distribution.sample)
            for _ in range(10):
                output = sess.run(
                    [distribution.sample, distribution.log_probs])
                for out in output:
                    assert out.shape[1] == VECTOR_ACTION_SPACE[0]
                # Assert actions always stay within [-1, 1]
                action = output[0][0]
                for act in action:
                    assert -1 <= act <= 1
                output = sess.run([distribution.total_log_probs])
                assert output[0].shape[0] == 1
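The [-1, 1] bound holds by construction when tanh_squash=True, since the raw
Gaussian sample is passed through tanh. A minimal standalone illustration of
that property (not the class's internals):

import numpy as np

# tanh maps any real-valued sample into (-1, 1), so squashed actions can
# never leave the [-1, 1] range, no matter how large the raw sample is.
raw = np.random.randn(1000) * 10.0
squashed = np.tanh(raw)
assert np.all((squashed >= -1.0) & (squashed <= 1.0))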
Example #3
def convert_frozen_to_onnx(
    settings: SerializationSettings, frozen_graph_def: tf.GraphDef
) -> Any:
    # This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py

    inputs = _get_input_node_names(frozen_graph_def)
    outputs = _get_output_node_names(frozen_graph_def)
    logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")

    frozen_graph_def = tf_optimize(
        inputs, outputs, frozen_graph_def, fold_constant=True
    )

    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(frozen_graph_def, name="")
    with tf.Session(graph=tf_graph):
        g = process_tf_graph(
            tf_graph,
            input_names=inputs,
            output_names=outputs,
            opset=settings.onnx_opset,
        )

    onnx_graph = optimizer.optimize_graph(g)
    model_proto = onnx_graph.make_model(settings.brain_name)

    return model_proto
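A minimal usage sketch (hypothetical call site; ModelProto is a protobuf
message, so it serializes with SerializeToString):

# settings and frozen_graph_def are assumed to come from the surrounding
# export pipeline.
model_proto = convert_frozen_to_onnx(settings, frozen_graph_def)
with open(f"{settings.brain_name}.onnx", "wb") as f:
    f.write(model_proto.SerializeToString())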
Example #4
def test_ppo_model_cc_vector_rnn():
    tf.reset_default_graph()
    with tf.Session() as sess:
        with tf.variable_scope("FakeGraphScope"):
            memory_size = 128
            model = PPOModel(
                make_brain_parameters(discrete_action=False, visual_inputs=0),
                use_recurrent=True,
                m_size=memory_size,
            )
            init = tf.global_variables_initializer()
            sess.run(init)

            run_list = [
                model.output,
                model.all_log_probs,
                model.value,
                model.entropy,
                model.learning_rate,
                model.memory_out,
            ]
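            # batch_size (1) * sequence_length (2) matches the two rows fed to vector_in.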
            feed_dict = {
                model.batch_size: 1,
                model.sequence_length: 2,
                model.memory_in: np.zeros((1, memory_size), dtype=np.float32),
                model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
                                           [3, 4, 5, 3, 4, 5]]),
                model.epsilon: np.array([[0, 1]]),
            }
            sess.run(run_list, feed_dict=feed_dict)
Example #5
def test_average_gradients(mock_get_devices, dummy_config):
    tf.reset_default_graph()
    mock_get_devices.return_value = [
        "/device:GPU:0",
        "/device:GPU:1",
        "/device:GPU:2",
        "/device:GPU:3",
    ]

    trainer_parameters = dummy_config
    trainer_parameters["model_path"] = ""
    trainer_parameters["keep_checkpoints"] = 3
    brain = create_mock_brainparams()
    with tf.Session() as sess:
        policy = MultiGpuPPOPolicy(0, brain, trainer_parameters, False, False)
        var = tf.Variable(0)
        tower_grads = [
            [(tf.constant(0.1), var)],
            [(tf.constant(0.2), var)],
            [(tf.constant(0.3), var)],
            [(tf.constant(0.4), var)],
        ]
        avg_grads = policy.average_gradients(tower_grads)

        init = tf.global_variables_initializer()
        sess.run(init)
        run_out = sess.run(avg_grads)
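    # Element-wise mean of the tower gradients:
    # (0.1 + 0.2 + 0.3 + 0.4) / 4 == 0.25, paired with the variable's value 0.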
    assert run_out == [(0.25, 0)]
Example #6
def test_ppo_model_dc_visual():
    tf.reset_default_graph()
    with tf.Session() as sess:
        with tf.variable_scope("FakeGraphScope"):
            model = PPOModel(
                make_brain_parameters(discrete_action=True, visual_inputs=2))
            init = tf.global_variables_initializer()
            sess.run(init)

            run_list = [
                model.output,
                model.all_log_probs,
                model.value,
                model.entropy,
                model.learning_rate,
            ]
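            # Two visual inputs, each a batch of 2 RGB frames of shape 40x30x3.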
            feed_dict = {
                model.batch_size: 2,
                model.sequence_length: 1,
                model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
                                           [3, 4, 5, 3, 4, 5]]),
                model.visual_in[0]: np.ones([2, 40, 30, 3], dtype=np.float32),
                model.visual_in[1]: np.ones([2, 40, 30, 3], dtype=np.float32),
                model.action_masks: np.ones([2, 2], dtype=np.float32),
            }
            sess.run(run_list, feed_dict=feed_dict)
Example #7
def test_multicategorical_distribution():
    with tf.Graph().as_default():
        logits = tf.Variable(initial_value=[[0, 0]],
                             trainable=True,
                             dtype=tf.float32)
        action_masks = tf.Variable(
            initial_value=[[1 for _ in range(sum(DISCRETE_ACTION_SPACE))]],
            trainable=True,
            dtype=tf.float32,
        )
        distribution = MultiCategoricalDistribution(
            logits, act_size=DISCRETE_ACTION_SPACE, action_masks=action_masks)
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            output = sess.run(distribution.sample)
            for _ in range(10):
                sample, log_probs, entropy = sess.run([
                    distribution.sample, distribution.log_probs,
                    distribution.entropy
                ])
                assert len(log_probs[0]) == sum(DISCRETE_ACTION_SPACE)
                # Assert each sampled action index is within its branch's space
                assert len(sample[0]) == len(DISCRETE_ACTION_SPACE)
                for i, act in enumerate(sample[0]):
                    assert 0 <= act < DISCRETE_ACTION_SPACE[i]
                output = sess.run([distribution.total_log_probs])
                assert output[0].shape[0] == 1
                # Make sure entropy is correct
                assert entropy[0] > 3.8

            # Test masks
            mask = []
            for space in DISCRETE_ACTION_SPACE:
                mask.append(1)
                for _action_space in range(1, space):
                    mask.append(0)
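            # With only the first action unmasked in each branch, every
            # sampled index should be forced to 0.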
            for _ in range(10):
                sample, log_probs = sess.run(
                    [distribution.sample, distribution.log_probs],
                    feed_dict={action_masks: [mask]},
                )
                for act in sample[0]:
                    assert act >= 0 and act <= 1
                output = sess.run([distribution.total_log_probs])
                assert output[0].shape[0] == 1
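For context on the 3.8 bound: a uniform categorical over n actions has entropy
ln(n), and entropies add across independent branches. With two branches of 7
actions each (an illustrative assumption, not necessarily the test's actual
DISCRETE_ACTION_SPACE), the total is about 3.89:

import numpy as np

# Sum of per-branch uniform entropies; the branch sizes here are hypothetical.
branches = [7, 7]
total_entropy = sum(np.log(n) for n in branches)  # ~= 3.89
assert total_entropy > 3.8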
Example #8
    def __init__(
        self,
        seed: int,
        behavior_spec: BehaviorSpec,
        trainer_settings: TrainerSettings,
        model_path: str,
        load: bool = False,
    ):
        """
        Initializes the policy.
        :param seed: Random seed to use for TensorFlow.
        :param behavior_spec: The corresponding BehaviorSpec for this policy.
        :param trainer_settings: The trainer parameters.
        :param model_path: Where to load/save the model.
        :param load: If True, load model from model_path. Otherwise, create new model.
        """

        self.m_size = 0
        self.trainer_settings = trainer_settings
        self.network_settings: NetworkSettings = trainer_settings.network_settings
        # for ghost trainer save/load snapshots
        self.assign_phs: List[tf.Tensor] = []
        self.assign_ops: List[tf.Operation] = []

        self.inference_dict: Dict[str, tf.Tensor] = {}
        self.update_dict: Dict[str, tf.Tensor] = {}
        self.sequence_length = 1
        self.seed = seed
        self.behavior_spec = behavior_spec

        self.act_size = (list(behavior_spec.discrete_action_branches)
                         if behavior_spec.is_action_discrete() else
                         [behavior_spec.action_size])
        self.vec_obs_size = sum(shape[0]
                                for shape in behavior_spec.observation_shapes
                                if len(shape) == 1)
        self.vis_obs_size = sum(1 for shape in behavior_spec.observation_shapes
                                if len(shape) == 3)

        self.use_recurrent = self.network_settings.memory is not None
        self.memory_dict: Dict[str, np.ndarray] = {}
        self.num_branches = self.behavior_spec.action_size
        self.previous_action_dict: Dict[str, np.array] = {}
        self.normalize = self.network_settings.normalize
        self.use_continuous_act = behavior_spec.is_action_continuous()
        self.model_path = model_path
        self.initialize_path = self.trainer_settings.init_path
        self.keep_checkpoints = self.trainer_settings.keep_checkpoints
        self.graph = tf.Graph()
        self.sess = tf.Session(config=tf_utils.generate_session_config(),
                               graph=self.graph)
        self.saver: Optional[tf.Operation] = None
        self.seed = seed
        if self.network_settings.memory is not None:
            self.m_size = self.network_settings.memory.memory_size
            self.sequence_length = self.network_settings.memory.sequence_length
        self._initialize_tensorflow_references()
        self.load = load
Example #9
    def __init__(self, seed, brain, trainer_parameters, load=False):
        """
        Initializes the policy.
        :param seed: Random seed to use for TensorFlow.
        :param brain: The corresponding Brain for this policy.
        :param trainer_parameters: The trainer parameters.
        """
        self._version_number_ = 2
        self.m_size = 0

        # for ghost trainer save/load snapshots
        self.assign_phs = []
        self.assign_ops = []

        self.inference_dict = {}
        self.update_dict = {}
        self.sequence_length = 1
        self.seed = seed
        self.brain = brain

        self.act_size = brain.vector_action_space_size
        self.vec_obs_size = brain.vector_observation_space_size
        self.vis_obs_size = brain.number_visual_observations

        self.use_recurrent = trainer_parameters["use_recurrent"]
        self.memory_dict: Dict[str, np.ndarray] = {}
        self.num_branches = len(self.brain.vector_action_space_size)
        self.previous_action_dict: Dict[str, np.array] = {}
        self.normalize = trainer_parameters.get("normalize", False)
        self.use_continuous_act = brain.vector_action_space_type == "continuous"
        if self.use_continuous_act:
            self.num_branches = self.brain.vector_action_space_size[0]
        self.model_path = trainer_parameters["output_path"]
        self.initialize_path = trainer_parameters.get("init_path", None)
        self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
        self.graph = tf.Graph()
        self.sess = tf.Session(
            config=tf_utils.generate_session_config(), graph=self.graph
        )
        self.saver = None
        self.seed = seed
        if self.use_recurrent:
            self.m_size = trainer_parameters["memory_size"]
            self.sequence_length = trainer_parameters["sequence_length"]
            if self.m_size == 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is 0 even "
                    "though the trainer uses recurrent.".format(brain.brain_name)
                )
            elif self.m_size % 2 != 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is {1} "
                    "but it must be divisible by 2.".format(
                        brain.brain_name, self.m_size
                    )
                )
        self._initialize_tensorflow_references()
        self.load = load
Example #10
    def __init__(self, seed, brain, trainer_parameters):
        """
        Initializes the policy.
        :param seed: Random seed to use for TensorFlow.
        :param brain: The corresponding Brain for this policy.
        :param trainer_parameters: The trainer parameters.
        """
        self.m_size = None
        self.model = None
        self.inference_dict = {}
        self.update_dict = {}
        self.sequence_length = 1
        self.seed = seed
        self.brain = brain
        self.use_recurrent = trainer_parameters["use_recurrent"]
        self.memory_dict: Dict[str, np.ndarray] = {}
        self.reward_signals: Dict[str, "RewardSignal"] = {}
        self.num_branches = len(self.brain.vector_action_space_size)
        self.previous_action_dict: Dict[str, np.array] = {}
        self.normalize = trainer_parameters.get("normalize", False)
        self.use_continuous_act = brain.vector_action_space_type == "continuous"
        if self.use_continuous_act:
            self.num_branches = self.brain.vector_action_space_size[0]
        self.model_path = trainer_parameters["model_path"]
        self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
        self.graph = tf.Graph()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # For multi-GPU training, set allow_soft_placement to True so that
        # operations are automatically placed on an alternative device,
        # preventing exceptions when the requested device does not exist
        # or does not support the operation.
        config.allow_soft_placement = True
        self.sess = tf.Session(config=config, graph=self.graph)
        self.saver = None
        if self.use_recurrent:
            self.m_size = trainer_parameters["memory_size"]
            self.sequence_length = trainer_parameters["sequence_length"]
            if self.m_size == 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is 0 even "
                    "though the trainer uses recurrent.".format(brain.brain_name)
                )
            elif self.m_size % 4 != 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is {1} "
                    "but it must be divisible by 4.".format(
                        brain.brain_name, self.m_size
                    )
                )
Example #11
    def __init__(
        self,
        seed: int,
        behavior_spec: BehaviorSpec,
        trainer_settings: TrainerSettings,
        tanh_squash: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
        create_tf_graph: bool = True,
    ):
        """
        Initializes the policy.
        :param seed: Random seed to use for TensorFlow.
        :param behavior_spec: The corresponding BehaviorSpec for this policy.
        :param trainer_settings: The trainer parameters.
        """
        super().__init__(
            seed,
            behavior_spec,
            trainer_settings,
            tanh_squash,
            reparameterize,
            condition_sigma_on_obs,
        )
        if (
            self.behavior_spec.action_spec.continuous_size > 0
            and self.behavior_spec.action_spec.discrete_size > 0
        ):
            raise UnityPolicyException(
                "TensorFlow does not support mixed action spaces. Please run with the Torch framework."
            )
        # for ghost trainer save/load snapshots
        self.assign_phs: List[tf.Tensor] = []
        self.assign_ops: List[tf.Operation] = []
        self.update_dict: Dict[str, tf.Tensor] = {}
        self.inference_dict: Dict[str, tf.Tensor] = {}
        self.first_normalization_update: bool = False

        self.graph = tf.Graph()
        self.sess = tf.Session(
            config=tf_utils.generate_session_config(), graph=self.graph
        )
        self._initialize_tensorflow_references()
        self.grads = None
        self.update_batch: Optional[tf.Operation] = None
        self.trainable_variables: List[tf.Variable] = []
        self.rank = get_rank()
        if create_tf_graph:
            self.create_tf_graph()
Example #12
    def __init__(
        self,
        seed: int,
        behavior_spec: BehaviorSpec,
        trainer_settings: TrainerSettings,
        model_path: str,
        load: bool = False,
        tanh_squash: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
        create_tf_graph: bool = True,
    ):
        """
        Initializes the policy.
        :param seed: Random seed to use for TensorFlow.
        :param behavior_spec: The corresponding BehaviorSpec for this policy.
        :param trainer_settings: The trainer parameters.
        :param model_path: Where to load/save the model.
        :param load: If True, load model from model_path. Otherwise, create new model.
        """
        super().__init__(
            seed,
            behavior_spec,
            trainer_settings,
            model_path,
            load,
            tanh_squash,
            reparameterize,
            condition_sigma_on_obs,
        )
        # for ghost trainer save/load snapshots
        self.assign_phs: List[tf.Tensor] = []
        self.assign_ops: List[tf.Operation] = []
        self.update_dict: Dict[str, tf.Tensor] = {}
        self.inference_dict: Dict[str, tf.Tensor] = {}
        self.first_normalization_update: bool = False

        self.graph = tf.Graph()
        self.sess = tf.Session(config=tf_utils.generate_session_config(),
                               graph=self.graph)
        self.saver: Optional[tf.Operation] = None
        self._initialize_tensorflow_references()
        self.grads = None
        self.update_batch: Optional[tf.Operation] = None
        self.trainable_variables: List[tf.Variable] = []
        if create_tf_graph:
            self.create_tf_graph()
Example #13
def test_sac_model_cc_vector():
    tf.reset_default_graph()
    with tf.Session() as sess:
        with tf.variable_scope("FakeGraphScope"):
            model = SACModel(
                make_brain_parameters(discrete_action=False, visual_inputs=0)
            )
            init = tf.global_variables_initializer()
            sess.run(init)

            run_list = [model.output, model.value, model.entropy, model.learning_rate]
            feed_dict = {
                model.batch_size: 2,
                model.sequence_length: 1,
                model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
            }
            sess.run(run_list, feed_dict=feed_dict)
Example #14
def convert_frozen_to_onnx(settings: SerializationSettings,
                           frozen_graph_def: tf.GraphDef) -> Any:
    # This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py

    # Some constants in the graph need to be read by the inference system.
    # These aren't used by the model anywhere, so trying to make sure they propagate
    # through conversion and import is a losing battle. Instead, save them now,
    # so that we can add them back later.
    constant_values = {}
    for n in frozen_graph_def.node:
        if n.name in MODEL_CONSTANTS:
            val = n.attr["value"].tensor.int_val[0]
            constant_values[n.name] = val

    inputs = _get_input_node_names(frozen_graph_def)
    outputs = _get_output_node_names(frozen_graph_def)
    logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")

    frozen_graph_def = tf_optimize(inputs,
                                   outputs,
                                   frozen_graph_def,
                                   fold_constant=True)

    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(frozen_graph_def, name="")
    with tf.Session(graph=tf_graph):
        g = process_tf_graph(
            tf_graph,
            input_names=inputs,
            output_names=outputs,
            opset=settings.onnx_opset,
        )

    onnx_graph = optimizer.optimize_graph(g)
    model_proto = onnx_graph.make_model(settings.brain_name)

    # Save the constant values back the graph initializer.
    # This will ensure the importer gets them as global constants.
    constant_nodes = []
    for k, v in constant_values.items():
        constant_node = _make_onnx_node_for_constant(k, v)
        constant_nodes.append(constant_node)
    model_proto.graph.initializer.extend(constant_nodes)
    return model_proto
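For context, one plausible shape for the _make_onnx_node_for_constant helper
(an assumption about code that lives elsewhere in the codebase;
onnx.helper.make_tensor is the standard way to build an initializer entry):

from onnx import TensorProto, helper

def _make_onnx_node_for_constant(name: str, value: int) -> TensorProto:
    # graph.initializer entries are TensorProtos; emit a scalar int32.
    return helper.make_tensor(
        name=name, data_type=TensorProto.INT32, dims=[], vals=[value]
    )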
Example #15
def test_dc_bc_model():
    tf.reset_default_graph()
    with tf.Session() as sess:
        with tf.variable_scope("FakeGraphScope"):
            model = BehavioralCloningModel(
                make_brain_parameters(discrete_action=True, visual_inputs=0))
            init = tf.global_variables_initializer()
            sess.run(init)

            run_list = [model.sample_action, model.action_probs]
            feed_dict = {
                model.batch_size: 2,
                model.dropout_rate: 1.0,
                model.sequence_length: 1,
                model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
                                           [3, 4, 5, 3, 4, 5]]),
                model.action_masks: np.ones([2, 2]),
            }
            sess.run(run_list, feed_dict=feed_dict)
Example #16
    def _dict_to_tensorboard(self, name: str,
                             input_dict: Dict[str, Any]) -> str:
        """
        Convert a dict to a Tensorboard-encoded string.
        :param name: The name of the text.
        :param input_dict: A dictionary that will be displayed in a table on Tensorboard.
        """
        try:
            with tf.Session(config=generate_session_config()) as sess:
                s_op = tf.summary.text(
                    name,
                    tf.convert_to_tensor(
                        [[str(x), str(input_dict[x])] for x in input_dict]),
                )
                s = sess.run(s_op)
                return s
        except Exception:
            logger.warning("Could not write text summary for Tensorboard.")
            return ""
Example #17
def test_visual_cc_bc_model():
    tf.reset_default_graph()
    with tf.Session() as sess:
        with tf.variable_scope("FakeGraphScope"):
            model = BehavioralCloningModel(
                make_brain_parameters(discrete_action=False, visual_inputs=2))
            init = tf.global_variables_initializer()
            sess.run(init)

            run_list = [model.sample_action, model.policy]
            feed_dict = {
                model.batch_size: 2,
                model.sequence_length: 1,
                model.vector_in: np.array([[1, 2, 3, 1, 2, 3],
                                           [3, 4, 5, 3, 4, 5]]),
                model.visual_in[0]: np.ones([2, 40, 30, 3], dtype=np.float32),
                model.visual_in[1]: np.ones([2, 40, 30, 3], dtype=np.float32),
            }
            sess.run(run_list, feed_dict=feed_dict)
Example #18
    def __init__(self, seed, brain, trainer_parameters):
        """
        Initializes the policy.
        :param seed: Random seed to use for TensorFlow.
        :param brain: The corresponding Brain for this policy.
        :param trainer_parameters: The trainer parameters.
        """
        self.m_size = None
        self.model = None
        self.inference_dict = {}
        self.update_dict = {}
        self.sequence_length = 1
        self.seed = seed
        self.brain = brain
        self.use_recurrent = trainer_parameters["use_recurrent"]
        self.memory_dict: Dict[str, np.ndarray] = {}
        self.reward_signals: Dict[str, "RewardSignal"] = {}
        self.num_branches = len(self.brain.vector_action_space_size)
        self.previous_action_dict: Dict[str, np.array] = {}
        self.normalize = trainer_parameters.get("normalize", False)
        self.use_continuous_act = brain.vector_action_space_type == "continuous"
        if self.use_continuous_act:
            self.num_branches = self.brain.vector_action_space_size[0]
        self.model_path = trainer_parameters["model_path"]
        self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
        self.graph = tf.Graph()
        self.sess = tf.Session(config=tf_utils.generate_session_config(),
                               graph=self.graph)
        self.saver = None
        if self.use_recurrent:
            self.m_size = trainer_parameters["memory_size"]
            self.sequence_length = trainer_parameters["sequence_length"]
            if self.m_size == 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is 0 even "
                    "though the trainer uses recurrent.".format(
                        brain.brain_name))
            elif self.m_size % 4 != 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is {1} "
                    "but it must be divisible by 4.".format(
                        brain.brain_name, self.m_size))
Example #19
    def write_tensorboard_text(self, key: str, input_dict: Dict[str, Any]) -> None:
        """
        Saves text to Tensorboard.
        Note: Only works on tensorflow r1.2 or above.
        :param key: The name of the text.
        :param input_dict: A dictionary that will be displayed in a table on Tensorboard.
        """
        try:
            with tf.Session(config=tf_utils.generate_session_config()) as sess:
                s_op = tf.summary.text(
                    key,
                    tf.convert_to_tensor(
                        [[str(x), str(input_dict[x])] for x in input_dict]
                    ),
                )
                s = sess.run(s_op)
                self.stats_reporter.write_text(s, self.get_step)
        except Exception:
            LOGGER.info("Could not write text summary for Tensorboard.")
Example #20
    def __init__(
        self,
        seed: int,
        brain: BrainParameters,
        trainer_settings: TrainerSettings,
        model_path: str,
        load: bool = False,
    ):
        """
        Initializes the policy.
        :param seed: Random seed to use for TensorFlow.
        :param brain: The corresponding Brain for this policy.
        :param trainer_settings: The trainer parameters.
        :param model_path: Where to load/save the model.
        :param load: If True, load model from model_path. Otherwise, create new model.
        """

        self.m_size = 0
        self.trainer_settings = trainer_settings
        self.network_settings: NetworkSettings = trainer_settings.network_settings
        # for ghost trainer save/load snapshots
        self.assign_phs: List[tf.Tensor] = []
        self.assign_ops: List[tf.Operation] = []

        self.inference_dict: Dict[str, tf.Tensor] = {}
        self.update_dict: Dict[str, tf.Tensor] = {}
        self.sequence_length = 1
        self.seed = seed
        self.brain = brain

        self.act_size = brain.vector_action_space_size
        self.vec_obs_size = brain.vector_observation_space_size
        self.vis_obs_size = brain.number_visual_observations

        self.use_recurrent = self.network_settings.memory is not None
        self.memory_dict: Dict[str, np.ndarray] = {}
        self.num_branches = len(self.brain.vector_action_space_size)
        self.previous_action_dict: Dict[str, np.array] = {}
        self.normalize = self.network_settings.normalize
        self.use_continuous_act = brain.vector_action_space_type == "continuous"
        if self.use_continuous_act:
            self.num_branches = self.brain.vector_action_space_size[0]
        self.model_path = model_path
        self.initialize_path = self.trainer_settings.init_path
        self.keep_checkpoints = self.trainer_settings.keep_checkpoints
        self.graph = tf.Graph()
        self.sess = tf.Session(config=tf_utils.generate_session_config(),
                               graph=self.graph)
        self.saver: Optional[tf.Operation] = None
        self.seed = seed
        if self.network_settings.memory is not None:
            self.m_size = self.network_settings.memory.memory_size
            self.sequence_length = self.network_settings.memory.sequence_length
            if self.m_size == 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is 0 even "
                    "though the trainer uses recurrent.".format(
                        brain.brain_name))
            elif self.m_size % 2 != 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is {1} "
                    "but it must be divisible by 2.".format(
                        brain.brain_name, self.m_size))
        self._initialize_tensorflow_references()
        self.load = load
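The even-size requirement reflects how the recurrent encoder consumes memory:
the flat memory vector is split in half, one half seeding the LSTM cell state
and the other the hidden state. A sketch of that split (assuming the
half-and-half convention):

import numpy as np

# A flat memory vector of size m_size is split into two equal halves.
m_size = 128
memory_in = np.zeros((1, m_size), dtype=np.float32)
half = m_size // 2
cell_state, hidden_state = memory_in[:, :half], memory_in[:, half:]
assert cell_state.shape == hidden_state.shape == (1, half)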