Example #1
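
        # Context assumed from the snippet: this function lives inside a larger
        # training script that provides, at an enclosing scope, `env` (a PySC2
        # environment), `actor_critic` (the policy/value network), `Buffer`,
        # `preprocess`, `translateActionToSC2`, `get_mask`,
        # `train_summary_writer`, and the constants `batch_size`,
        # `minibatch_size`, and `MINIMAP_RES`, plus `import numpy as np`,
        # `import tensorflow as tf`, and `from pysc2.lib import actions`.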
        def train_one_update(step, epochs, tracing_on):
            # initialize the trajectory buffer for one batch of on-policy data
            buffer = Buffer(
                batch_size,
                minibatch_size,
                MINIMAP_RES,
                MINIMAP_RES,
                env.action_spec()[0],
            )

            # initial observation
            timestep = env.reset()
            step_type, reward, _, obs = timestep[0]
            obs = preprocess(obs)

            ep_ret = []  # returns of episodes completed during this batch
            ep_rew = 0  # running return of the current episode

            # collect trajectories until the buffer holds a full batch
            while True:
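                # add a leading batch dimension to each observation array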
                tf_obs = (
                    tf.constant(each_obs, shape=(1, *each_obs.shape))
                    for each_obs in obs
                )

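                # one forward pass: state value, sampled action id, spatial and
                # non-spatial argument samples, and the action's log-probability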
                val, act_id, arg_spatial, arg_nonspatial, logp_a = actor_critic.step(
                    *tf_obs
                )

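                # translate the sampled argument tensors into SC2 call arguments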
                sc2act_args = translateActionToSC2(
                    arg_spatial, arg_nonspatial, MINIMAP_RES, MINIMAP_RES
                )

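                # mask of argument heads valid for this action id
                # (role inferred from the `get_mask` helper's name)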
                act_mask = get_mask(act_id.numpy().item(), actor_critic.action_spec)
                buffer.add(
                    *obs,
                    act_id.numpy().item(),
                    sc2act_args,
                    act_mask,
                    logp_a.numpy().item(),
                    val.numpy().item()
                )
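                # step the environment with the chosen SC2 function call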
                step_type, reward, _, obs = env.step(
                    [actions.FunctionCall(act_id.numpy().item(), sc2act_args)]
                )[0]
                # print("action:{}: {} reward {}".format(act_id.numpy().item(), sc2act_args, reward))
                buffer.add_rew(reward)
                obs = preprocess(obs)

                ep_rew += reward

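                # episode ended or the batch is full: finish this trajectory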
                if step_type == step_type.LAST or buffer.is_full():
                    if step_type == step_type.LAST:
                        buffer.finalize(0)
                    else:
                        # trajectory is cut off, bootstrap last state with estimated value
                        tf_obs = (
                            tf.constant(each_obs, shape=(1, *each_obs.shape))
                            for each_obs in obs
                        )
                        val, _, _, _, _ = actor_critic.step(*tf_obs)
                        buffer.finalize(val.numpy().item())

                    ep_ret.append(ep_rew)
                    ep_rew = 0

                    if buffer.is_full():
                        break

                    # episode ended but the batch is not full: reset and keep collecting
                    env.render(True)
                    timestep = env.reset()
                    _, _, _, obs = timestep[0]
                    obs = preprocess(obs)

            # finish batch bookkeeping (returns and advantages) before training
            buffer.post_process()

            mb_loss = []
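            # several passes over the same on-policy batch, reshuffled each epoch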
            for ep in range(epochs):
                buffer.shuffle()

                for ind in range(batch_size // minibatch_size):
                    (
                        player,
                        available_act,
                        minimap,
                        # screen,
                        act_id,
                        act_args,
                        act_mask,
                        logp,
                        val,
                        ret,
                        adv,
                    ) = buffer.minibatch(ind)

                    assert ret.shape == val.shape
                    assert logp.shape == adv.shape
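                    # optionally trace the tf.function graph once for TensorBoard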
                    if tracing_on:
                        tf.summary.trace_on(graph=True, profiler=False)

                    mb_loss.append(
                        actor_critic.train_step(
                            tf.constant(step, dtype=tf.int64),
                            player,
                            available_act,
                            minimap,
                            # screen,
                            act_id,
                            act_args,
                            act_mask,
                            logp,
                            val,
                            ret,
                            adv,
                        )
                    )
                    step += 1

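                    # export the traced graph after the first minibatch, then stop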
                    if tracing_on:
                        tracing_on = False
                        with train_summary_writer.as_default():
                            tf.summary.trace_export(name="train_step", step=0)

            batch_loss = np.mean(mb_loss)

            return (
                batch_loss,
                ep_ret,
                buffer.batch_ret,
                np.asarray(buffer.batch_vals, dtype=np.float32),
            )
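
A minimal driver sketch for context; `n_updates` and `epochs_per_update` are illustrative names, not from the original source, and the outer loop advances `step` by the number of minibatch updates performed inside `train_one_update`:

        step = 0
        n_updates = 100  # hypothetical
        epochs_per_update = 3  # hypothetical
        for update in range(n_updates):
            batch_loss, ep_ret, batch_ret, batch_vals = train_one_update(
                step, epochs_per_update, tracing_on=(update == 0)
            )
            step += epochs_per_update * (batch_size // minibatch_size)
            print(
                "update {:d}: loss {:.4f}, mean episode return {:.1f}".format(
                    update,
                    batch_loss,
                    float(np.mean(ep_ret)) if ep_ret else float("nan"),
                )
            )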