def train_one_update(step, epochs, tracing_on):
    """Collect one batch of on-policy experience, then train on it in minibatches.

    Args:
        step: global minibatch counter, used as the TensorBoard summary step.
        epochs: number of passes over the collected batch.
        tracing_on: if True, trace the first train_step call and export its graph.

    Returns:
        Tuple of (mean minibatch loss, per-episode returns, batch returns,
        batch value estimates as a float32 array).
    """
    # initialize the rollout buffer that stores one batch of trajectories
    buffer = Buffer(
        batch_size,
        minibatch_size,
        MINIMAP_RES,
        MINIMAP_RES,
        env.action_spec()[0],
    )

    # initial observation
    timestep = env.reset()
    step_type, reward, _, obs = timestep[0]
    obs = preprocess(obs)

    ep_ret = []  # per-episode returns (scores)
    ep_rew = 0  # running return of the current episode

    # fill the buffer with recorded trajectories
    while True:
        # batch each observation component with a leading axis of size 1
        tf_obs = (
            tf.constant(each_obs, shape=(1, *each_obs.shape)) for each_obs in obs
        )
        val, act_id, arg_spatial, arg_nonspatial, logp_a = actor_critic.step(*tf_obs)
        sc2act_args = translateActionToSC2(
            arg_spatial, arg_nonspatial, MINIMAP_RES, MINIMAP_RES
        )
        act_mask = get_mask(act_id.numpy().item(), actor_critic.action_spec)
        buffer.add(
            *obs,
            act_id.numpy().item(),
            sc2act_args,
            act_mask,
            logp_a.numpy().item(),
            val.numpy().item(),
        )
        step_type, reward, _, obs = env.step(
            [actions.FunctionCall(act_id.numpy().item(), sc2act_args)]
        )[0]
        # print("action:{}: {} reward {}".format(act_id.numpy().item(), sc2act_args, reward))
        buffer.add_rew(reward)
        obs = preprocess(obs)
        ep_rew += reward

        if step_type == step_type.LAST or buffer.is_full():
            if step_type == step_type.LAST:
                # episode ended naturally: no bootstrapping needed
                buffer.finalize(0)
            else:
                # trajectory is cut off: bootstrap the last state with its
                # estimated value
                tf_obs = (
                    tf.constant(each_obs, shape=(1, *each_obs.shape))
                    for each_obs in obs
                )
                val, _, _, _, _ = actor_critic.step(*tf_obs)
                buffer.finalize(val.numpy().item())

            # the final reward was already accumulated into ep_rew above
            ep_ret.append(ep_rew)
            ep_rew = 0

            if buffer.is_full():
                break

            # respawn env for the next episode
            env.render(True)
            timestep = env.reset()
            _, _, _, obs = timestep[0]
            obs = preprocess(obs)

    # compute returns/advantages, then train in minibatches
    buffer.post_process()

    mb_loss = []
    for _ in range(epochs):
        buffer.shuffle()
        for ind in range(batch_size // minibatch_size):
            (
                player,
                available_act,
                minimap,
                # screen,
                act_id,
                act_args,
                act_mask,
                logp,
                val,
                ret,
                adv,
            ) = buffer.minibatch(ind)

            assert ret.shape == val.shape
            assert logp.shape == adv.shape

            if tracing_on:
                tf.summary.trace_on(graph=True, profiler=False)

            mb_loss.append(
                actor_critic.train_step(
                    tf.constant(step, dtype=tf.int64),
                    player,
                    available_act,
                    minimap,
                    # screen,
                    act_id,
                    act_args,
                    act_mask,
                    logp,
                    val,
                    ret,
                    adv,
                )
            )
            step += 1

            if tracing_on:
                # export the traced graph once, after the first train_step call
                tracing_on = False
                with train_summary_writer.as_default():
                    tf.summary.trace_export(name="train_step", step=0)

    batch_loss = np.mean(mb_loss)

    return (
        batch_loss,
        ep_ret,
        buffer.batch_ret,
        np.asarray(buffer.batch_vals, dtype=np.float32),
    )
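

# Usage sketch (hypothetical): a minimal driver loop for train_one_update.
# NUM_UPDATES and EPOCHS_PER_UPDATE are illustrative constants, not defined in
# this file; batch_size, minibatch_size, train_summary_writer, tf, and np are
# assumed to be the same module-level names the function above relies on.
def main_loop():
    step = 0  # global minibatch counter fed to train_one_update
    tracing_on = True  # trace and export the graph only on the first update
    for update in range(NUM_UPDATES):
        batch_loss, ep_ret, batch_ret, batch_vals = train_one_update(
            step, EPOCHS_PER_UPDATE, tracing_on
        )
        tracing_on = False
        # train_one_update advances its local copy of step once per minibatch,
        # so mirror that advance here to keep summary steps monotonic
        step += EPOCHS_PER_UPDATE * (batch_size // minibatch_size)
        with train_summary_writer.as_default():
            tf.summary.scalar("batch_loss", batch_loss, step=update)
            tf.summary.scalar("mean_episode_return", np.mean(ep_ret), step=update)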