Example #1
    def testTrainMaskingRewardMultipleEpisodesRewardOnFirst(self):
        # Test that train reacts correctly to experience when there are:
        #   * Multiple MDP episodes
        #   * Rewards on the ts.StepType.FIRST transitions
        #
        # F, L, M = ts.StepType.{FIRST, MID, LAST} in the chart below.
        #
        # Experience looks like this:
        # Trajectories: (F, L) -> (L, F) -> (F, L) -> (L, F)
        # observation : [1, 2]    [1, 2]    [1, 2]    [1, 2]
        # action      :   [0]       [1]       [2]       [3]
        # reward      :    3         0         4         0
        # ~is_boundary:    1         0         1         0
        # is_last     :    1         0         1         0
        # valid reward:   3*1       0*0       4*1       0*0
        #
        # The second and fourth actions and rewards should be masked out because
        # they fall on boundary (step_type=(L, F)) transitions.
        #
        # The expected_loss is > 0.0 in this case, matching the expected_loss of the
        # testMaskingRewardMultipleEpisodesRewardOnFirst policy_gradient_loss test.
        agent = reinforce_agent.ReinforceAgent(
            self._time_step_spec,
            self._action_spec,
            actor_network=DummyActorNet(self._obs_spec,
                                        self._action_spec,
                                        unbounded_actions=True),
            optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
            use_advantage_loss=False,
            normalize_returns=False,
        )

        step_type = tf.constant([
            ts.StepType.FIRST, ts.StepType.LAST, ts.StepType.FIRST,
            ts.StepType.LAST
        ])
        next_step_type = tf.constant([
            ts.StepType.LAST, ts.StepType.FIRST, ts.StepType.LAST,
            ts.StepType.FIRST
        ])
        reward = tf.constant([3, 0, 4, 0], dtype=tf.float32)
        discount = tf.constant([1, 0, 1, 0], dtype=tf.float32)
        observations = tf.constant([[1, 2], [1, 2], [1, 2], [1, 2]],
                                   dtype=tf.float32)
        actions = tf.constant([[0], [1], [2], [3]], dtype=tf.float32)

        experience = nest_utils.batch_nested_tensors(
            trajectory.Trajectory(step_type, observations, actions, (),
                                  next_step_type, reward, discount))

        # Rewards on the StepType.FIRST should be counted.
        expected_loss = 12.2091741562

        if tf.executing_eagerly():
            loss = lambda: agent.train(experience)
        else:
            loss = agent.train(experience)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        loss_info = self.evaluate(loss)
        self.assertAllClose(loss_info.loss, expected_loss)
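
For reference, the masking arithmetic described in the comment block above can be reproduced in isolation. The following is an illustrative sketch (not part of the test) that mirrors the "valid reward" row of the chart:

import tensorflow as tf

# Transitions 2 and 4 sit on an episode boundary (step_type == LAST), so their
# rewards are zeroed out before the REINFORCE loss is computed.
reward = tf.constant([3.0, 0.0, 4.0, 0.0])
not_boundary = tf.constant([1.0, 0.0, 1.0, 0.0])  # the ~is_boundary row of the chart
valid_reward = reward * not_boundary              # -> [3.0, 0.0, 4.0, 0.0]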
Example #2
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=1000,
        fc_layers=(100, ),
        use_tf_functions=True,
        # Params for collect
        collect_episodes_per_iteration=2,
        replay_buffer_capacity=2000,
        # Params for train
        learning_rate=1e-3,
        gradient_clipping=None,
        normalize_returns=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=100,
        # Params for checkpoints, summaries, and logging
        log_interval=100,
        summary_interval=100,
        summaries_flush_secs=1,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for Reinforce."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=fc_layers)

        tf_agent = reinforce_agent.ReinforceAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            normalize_returns=normalize_returns,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            # Wrap collection in a tf.function to speed it up.
            collect_driver.run = common.function(collect_driver.run)
            # Wrap training in a tf.function to speed it up.
            tf_agent.train = common.function(tf_agent.train)

        # Compute evaluation metrics.
        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        # TODO(b/126590894): Move this functionality into eager_compute_summaries
        if eval_metrics_callback is not None:
            eval_metrics_callback(metrics, global_step.numpy())

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            experience = replay_buffer.gather_all()
            total_loss = tf_agent.train(experience)
            replay_buffer.clear()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            if global_step_val % eval_interval == 0:
                metrics = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                # TODO(b/126590894): Move this functionality into
                # eager_compute_summaries.
                if eval_metrics_callback is not None:
                    eval_metrics_callback(metrics, global_step_val)
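
A hypothetical invocation of this `train_eval` (the output directory and argument values below are illustrative, not taken from the original source):

# Illustrative only: train REINFORCE on CartPole-v0, writing summaries under /tmp.
train_eval('/tmp/reinforce_cartpole', num_iterations=1000, eval_interval=100)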
Example #3
def experiment(
        # Environment hyperparameters
        v_n=2,
        v_k=2,
        v_seed=43,
        seed_rest=42,
        do_transform=True,
        time_limit=20,

        # Agent hyperparameters
        num_iterations=500,  # @param {type:"integer"}
        collect_episodes_per_iteration=5,  # @param {type:"integer"}
        replay_buffer_capacity=1000,  # @param {type:"integer"}
        fc_layer_params=(),
        learning_rate=1e-3,  # @param {type:"number"}
        log_interval=25,  # @param {type:"integer"}
        num_eval_episodes=10,  # @param {type:"integer"}
        eval_interval=10,  # @param {type:"integer"}

        # Decoder: a linear transformation from observations to features
        l2_decoder_coeff=1e-3,  # to prevent weight explosion from the l1 loss

        # Model of the environment
        l1coeff=1e-2,
        model_train_epochs=10,

        # Curiosity parameters
        alpha=1.0,
        curiosity_interval=20,
        sparsity=0.5,
        name="experiment"):

    tf.random.set_seed(seed_rest)

    f = open(name + ".out.%d.txt" % time.time(), "w")

    def print_exp(s):
        # print(str(s))
        f.write(str(s) + "\n")

    def savefig(n):
        plt.savefig(name + "." + str(n) + ".png", bbox_inches="tight")

    # Two copies of the decoder layer; their weights are synced between the
    # environment model and the agent.

    decoder_layer = tf.keras.layers.Dense(
        v_n,
        input_shape=(v_k, ),
        activation=None,
        use_bias=False,
        kernel_initializer='random_normal',
        kernel_regularizer=tf.keras.regularizers.l2(l2_decoder_coeff))

    decoder_layer_agent = tf.keras.layers.Dense(
        v_n,
        input_shape=(v_k, ),
        activation=None,
        use_bias=False,
        kernel_initializer='random_normal',
        kernel_regularizer=tf.keras.regularizers.l2(l2_decoder_coeff))

    decoder = tf.keras.Sequential([decoder_layer])

    pruning_params = {
        'pruning_schedule': pruning_sched.ConstantSparsity(sparsity, 0),
        'block_size': (1, 1),
        'block_pooling_type': 'AVG'
    }

    env_model = tf.keras.Sequential([
        m_passthrough_action(decoder, v_k, v_n),
        tf.keras.layers.InputLayer(
            input_shape=(v_k + v_n, )),  # input: [state, one-hot action]
        prune.prune_low_magnitude(
            tf.keras.layers.Dense(
                v_k, kernel_regularizer=tf.keras.regularizers.l1(l1coeff)),
            **pruning_params)  # output: state
    ])

    env_model.compile('adam', 'mse')

    # Creating a curiosity-wrapped environment

    def get_env(add_curiosity_reward=True):
        """Return a copy of the environment."""
        env = VectorIncrementEnvironmentTFAgents(v_n=v_n,
                                                 v_k=v_k,
                                                 v_seed=v_seed,
                                                 do_transform=do_transform)
        env = wrappers.TimeLimit(env, time_limit)
        if add_curiosity_reward:
            env = CuriosityWrapper(env, env_model, alpha=alpha)
        env = tf_py_environment.TFPyEnvironment(env)
        return env

    train_env = get_env(add_curiosity_reward=True)
    eval_env = get_env(add_curiosity_reward=False)

    tf.random.set_seed(seed_rest)

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        train_env.observation_spec(),
        train_env.action_spec(),
        fc_layer_params=fc_layer_params,
        activation_fn=tf.keras.activations.relu,
        preprocessing_layers=decoder_layer_agent
        # for features: add preprocessing_layers=[...]
    )

    optimizer = tf.compat.v1.train.GradientDescentOptimizer(
        learning_rate=learning_rate)

    train_step_counter = tf.compat.v2.Variable(0)

    tf_agent = reinforce_agent.ReinforceAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        actor_network=actor_net,
        optimizer=optimizer,
        normalize_returns=True,
        train_step_counter=train_step_counter)
    tf_agent.initialize()

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_capacity)

    curiosity_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=1000000)  # should never overflow

    # Take the copied layer (with the actual trained weights) from the actor network.
    decoder_layer_agent = actor_net.layers[0].layers[0]

    # (Optional) Optimize by wrapping some of the code in a graph using TF function.
    tf_agent.train = common.function(tf_agent.train)

    # Reset the train step
    tf_agent.train_step_counter.assign(0)

    # Evaluate the agent's policy once before training.
    avg_return = compute_avg_return(eval_env, tf_agent.policy,
                                    num_eval_episodes)
    train_avg_return = compute_avg_return(eval_env, tf_agent.collect_policy,
                                          num_eval_episodes)
    returns = [avg_return]
    train_returns = [train_avg_return]
    curiosity_loss = []

    #global w_count
    global w_history
    w_history = []

    def print_weights():
        """Show weights."""
        global w_history
        w1 = decoder_layer_agent.get_weights()[0]
        w2 = env_model.layers[2].get_weights()
        print_exp("DECODER:" + str(w1))
        print_exp("MODEL:" + str(w2))
        w_history.append([w1, w2])
        #w_count += 1

    for _ in range(num_iterations):

        # Collect a few episodes using collect_policy and save to the replay buffer.
        collect_episode(train_env, tf_agent.collect_policy,
                        collect_episodes_per_iteration,
                        [replay_buffer, curiosity_replay_buffer])

        # Use data from the buffer and update the agent's network.
        experience = replay_buffer.gather_all()
        train_loss = tf_agent.train(experience)
        replay_buffer.clear()

        print_exp("Agent train step")
        print_weights()

        step = tf_agent.train_step_counter.numpy()

        if step % log_interval == 0:
            print_exp('step = {0}: loss = {1}'.format(step, train_loss.loss))

        if step % eval_interval == 0:
            avg_return = compute_avg_return(eval_env, tf_agent.policy,
                                            num_eval_episodes)
            train_avg_return = compute_avg_return(train_env,
                                                  tf_agent.collect_policy,
                                                  num_eval_episodes)
            print_exp(
                'step = {0}: Average Return = {1}, Train (curiosity) Average Return = {2}'
                .format(step, avg_return, train_avg_return))
            returns.append(avg_return)
            train_returns.append(train_avg_return)

        if step % curiosity_interval == 0:
            xs, ys = buffer_to_dataset(curiosity_replay_buffer, v_n)

            # setting weights from the agent to the model...
            decoder_layer.set_weights(decoder_layer_agent.get_weights())
            history = env_model.fit(
                xs,
                ys,
                epochs=model_train_epochs,
                verbose=0,
                callbacks=[pruning_callbacks.UpdatePruningStep()])

            # setting weights from the model to the agent...
            decoder_layer_agent.set_weights(decoder_layer.get_weights())

            #plt.title("Loss")
            #plt.plot(history.history['loss'])
            curiosity_loss += list(history.history['loss'])
            #plt.xlabel("Epochs")
            #plt.show()
            curiosity_replay_buffer.clear()

            print_exp("Model train step")
            print_weights()

    print_exp("Curiosity loss")
    print_exp(curiosity_loss)

    plt.title("Curiosity loss")
    plt.plot(curiosity_loss)
    savefig("curiosity-loss")
    #plt.show()

    steps = range(0, num_iterations + 1, eval_interval)
    fig = plt.figure()
    fig.patch.set_facecolor('lightgreen')
    plt.title("Returns with added curiosity reward in training")
    plt.plot(steps, returns, label="eval")
    plt.plot(steps, train_returns, label="train")
    plt.ylabel('Average Return')
    plt.legend()
    plt.xlabel('Step')
    savefig("return")
    #plt.show()

    # Evaluating the model

    curiosity_replay_buffer.clear()
    collect_episode(eval_env, tf_agent.collect_policy, 25,
                    [curiosity_replay_buffer])
    collect_episode(eval_env, tf_agent.policy, 25, [curiosity_replay_buffer])
    xs, ys = buffer_to_dataset(curiosity_replay_buffer, v_n)
    eval_loss = env_model.evaluate(xs, ys)
    print_exp(eval_loss)

    print_exp(env_model.weights)

    pickle.dump([curiosity_loss, returns, train_returns, eval_loss],
                open(name + ".history.pkl", "wb"))
    pickle.dump(w_history, open(name + ".weights.pkl", "wb"))
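
A hypothetical call to `experiment` (the hyperparameter values below are illustrative); each run writes its log, plots, and pickled histories next to the given `name`:

# Illustrative only: sweep the curiosity weight alpha over a few values.
for alpha in (0.0, 0.5, 1.0):
    experiment(alpha=alpha, num_iterations=500, name="experiment-alpha-%g" % alpha)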
Example #4
def train_eval(
        root_dir,
        num_iterations=int(1e09),
        actor_fc_layers=(100, ),
        value_net_fc_layers=(100, ),
        use_value_network=False,
        # Params for collect
        collect_episodes_per_iteration=30,
        replay_buffer_capacity=2000,
        # Params for train
        learning_rate=1e-3,
        gamma=0.9,
        gradient_clipping=None,
        normalize_returns=True,
        value_estimation_loss_coef=0.2,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=2000,
        policy_checkpoint_interval=1000,
        rb_checkpoint_interval=4000,
        log_interval=100,
        summary_interval=100,
        summaries_flush_secs=1,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for Reinforce."""

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        eval_py_env = load_env()
        eval_py_env2 = load_env()
        tf_env = tf_py_environment.TFPyEnvironment(load_env())

        # TODO(b/127870767): Handle distributions without gin.
        actor_net = masked_networks.MaskedActorDistributionNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers)

        if use_value_network:
            value_net = masked_networks.MaskedValueNetwork(
                tf_env.time_step_spec().observation,
                fc_layer_params=value_net_fc_layers)

        tf_agent = reinforce_agent.ReinforceAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            value_network=value_net if use_value_network else None,
            value_estimation_loss_coef=value_estimation_loss_coef,
            gamma=gamma,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            normalize_returns=normalize_returns,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)
        eval_py_policy_custom_return = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        experience = replay_buffer.gather_all()
        train_op = tf_agent.train(experience)
        clear_rb_op = replay_buffer.clear()

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            # TODO(b/126239733): Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            # Compute evaluation metrics.
            global_step_call = sess.make_callable(global_step)
            global_step_val = global_step_call()
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            clear_rb_call = sess.make_callable(clear_rb_op)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                total_loss, _ = train_step_call()
                clear_rb_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
                    print(
                        'AVG RETURN:',
                        compute_avg_return(eval_py_env2,
                                           eval_py_policy_custom_return))
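
A hypothetical entry point for this graph-mode variant; because it builds `tf.compat.v1` placeholders and runs inside a `Session`, eager execution has to be disabled first (the path and iteration count below are illustrative):

# Illustrative only: run the graph-mode train/eval loop.
tf.compat.v1.disable_eager_execution()
train_eval('/tmp/reinforce_masked', num_iterations=10000)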
Example #5
"""

actor_net = actor_distribution_network.ActorDistributionNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)
"""We also need an `optimizer` to train the network we just created, and a `train_step_counter` variable to keep track of how many times the network was updated."""

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.compat.v2.Variable(0)

tf_agent = reinforce_agent.ReinforceAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=actor_net,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter)
tf_agent.initialize()
"""## Policies

In TF-Agents, policies represent the standard notion of policies in RL: given a `time_step`, produce an action or a distribution over actions. The main method is `policy_step = policy.action(time_step)`, where `policy_step` is a named tuple `PolicyStep(action, state, info)`. `policy_step.action` is the `action` to be applied to the environment, `state` represents the state for stateful (RNN) policies, and `info` may contain auxiliary information such as log probabilities of the actions.

Agents contain two policies: the main policy that is used for evaluation/deployment (agent.policy) and another policy that is used for data collection (agent.collect_policy).
"""

eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy
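
"""As a quick sanity check, a policy can be queried directly for an action. This is a minimal sketch that assumes the `train_env` created earlier in this tutorial; the `example_*` names are illustrative."""

example_time_step = train_env.reset()
example_policy_step = eval_policy.action(example_time_step)
print(example_policy_step.action)  # the action that would be applied to the environment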
"""## Metrics and Evaluation