def loss(
    self,
    step,
    player,
    # home_away_race,
    # upgrades,
    available_act,
    minimap,
    # screen,
    act_id,
    act_args,
    act_mask,
    old_logp,
    old_v,
    ret,
    adv,
):
    # evaluate the current policy: new pi(a|s) and its log-prob
    out = self.call(player, available_act, minimap, step=step)
    logp = self.logp_a(act_id, act_args, act_mask, available_act, out)

    # PPO clipped surrogate objective (expectation of grad-log terms)
    delta_pi = tf.exp(logp - old_logp)
    pg_loss_1 = delta_pi * adv
    pg_loss_2 = (
        tf.clip_by_value(delta_pi, 1 - self.clip_ratio, 1 + self.clip_ratio) * adv
    )
    pg_loss = -tf.reduce_mean(tf.minimum(pg_loss_1, pg_loss_2))

    # value loss, optionally clipped around the old value estimate
    if self.clip_value > 0:
        v_clip = old_v + tf.clip_by_value(
            out["value"] - old_v, -self.clip_value, self.clip_value
        )
        v_clip_loss = tf.square(v_clip - ret)
        v_loss = tf.square(out["value"] - ret)
        v_loss = tf.reduce_mean(tf.maximum(v_clip_loss, v_loss))
    else:
        v_loss = tf.reduce_mean(tf.square(out["value"] - ret))

    # diagnostics: entropy, approximate KL, and fraction of clipped ratios
    approx_entropy = tf.reduce_mean(
        compute_over_actions(entropy, out, available_act, act_mask),
        name="entropy",
    )
    tf.debugging.check_numerics(approx_entropy, "bad entropy")
    approx_kl = tf.reduce_mean(tf.square(old_logp - logp), name="kl")
    clip_frac = tf.reduce_mean(
        tf.cast(tf.greater(tf.abs(delta_pi - 1.0), self.clip_ratio), tf.float32),
        name="clip_frac",
    )

    with train_summary_writer.as_default():
        tf.summary.scalar("loss/pg_loss", pg_loss, step)
        tf.summary.scalar("loss/v_loss", v_loss, step)
        tf.summary.scalar("loss/approx_entropy", approx_entropy, step)
        tf.summary.scalar("stat/approx_kl", approx_kl, step)
        tf.summary.scalar("stat/clip_frac", clip_frac, step)

    return pg_loss + self.v_coef * v_loss - self.entropy_coef * approx_entropy
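# --- Hedged sketch: PPO clipped surrogate on toy tensors ---
# Not part of the model; a standalone sanity check of the same ratio-clipping
# logic used in `loss` above. The clip ratio of 0.2 is an assumed value for
# illustration only (the real one comes from `self.clip_ratio` / hparam).
def _ppo_clip_sketch():
    import tensorflow as tf

    old_logp = tf.constant([-1.0, -2.0, -0.5])
    new_logp = tf.constant([-0.5, -2.5, -0.5])
    adv = tf.constant([1.0, -1.0, 0.5])
    clip_ratio = 0.2

    ratio = tf.exp(new_logp - old_logp)  # pi_new(a|s) / pi_old(a|s)
    unclipped = ratio * adv
    clipped = tf.clip_by_value(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv
    # pessimistic (element-wise minimum) bound, negated because we minimize
    return -tf.reduce_mean(tf.minimum(unclipped, clipped))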
def train_step(
    self,
    step,
    player,
    available_act,
    minimap,
    # screen,
    act_id,
    act_args,
    act_mask,
    old_logp,
    old_v,
    ret,
    adv,
):
    with tf.GradientTape() as tape:
        ls = self.loss(
            step,
            player,
            available_act,
            minimap,
            # screen,
            act_id,
            act_args,
            act_mask,
            old_logp,
            old_v,
            ret,
            adv,
        )
    grad = tape.gradient(ls, self.trainable_variables)

    with train_summary_writer.as_default():
        norm_tmp = [tf.norm(g) for g in grad]
        tf.summary.scalar("batch/gradient_norm", tf.reduce_mean(norm_tmp), step)
        tf.summary.scalar("batch/gradient_norm_max", tf.reduce_max(norm_tmp), step)

    for g in grad:
        tf.debugging.check_numerics(g, "Bad grad {}".format(g))

    # clip gradients by global norm (https://arxiv.org/pdf/1211.5063.pdf)
    if self.max_grad_norm > 0.0:
        grad, _ = tf.clip_by_global_norm(grad, self.max_grad_norm)
    self.optimizer.apply_gradients(zip(grad, self.trainable_variables))
    return ls
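# --- Hedged sketch: gradient clipping by global norm ---
# Standalone illustration of the `tf.clip_by_global_norm` call in `train_step`
# above: when the combined L2 norm of all gradients exceeds the threshold,
# every gradient is rescaled by the same factor. The threshold 1.0 below is an
# assumed value for illustration (the real one comes from `self.max_grad_norm`).
def _global_norm_clip_sketch():
    import tensorflow as tf

    grads = [tf.constant([3.0, 4.0]), tf.constant([0.0, 12.0])]  # global norm = 13
    max_grad_norm = 1.0
    clipped, global_norm = tf.clip_by_global_norm(grads, max_grad_norm)
    # `global_norm` is the pre-clip norm (13.0); each tensor in `clipped` is
    # scaled by max_grad_norm / global_norm.
    return clipped, global_norm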
def train(
    env_name,
    batch_size,
    minibatch_size,
    updates,
    epochs,
    hparam,
    hp_summary_writer,
    save_model=False,
    load_path=None,
):
    """
    Main learning function.

    Args:
        batch_size: size of the buffer; it may contain multiple trajectories.
        minibatch_size: one batch is split into several minibatches, each of
            this size.
        epochs: number of training passes over a filled buffer per update;
            each pass iterates over all minibatches.
    """
    actor_critic = Actor_Critic(hparam)

    if load_path is not None:
        print("Loading model ...")
        load_path = osp.expanduser(load_path)
        ckpt = tf.train.Checkpoint(model=actor_critic)
        manager = tf.train.CheckpointManager(ckpt, load_path, max_to_keep=5)
        ckpt.restore(manager.latest_checkpoint)

    # set up the environment
    with SC2EnvWrapper(
        map_name=env_name,
        players=[sc2_env.Agent(sc2_env.Race.random)],
        agent_interface_format=sc2_env.parse_agent_interface_format(
            feature_minimap=MINIMAP_RES, feature_screen=MINIMAP_RES
        ),
        step_mul=FLAGS.step_mul,
        game_steps_per_episode=FLAGS.game_steps_per_episode,
        disable_fog=FLAGS.disable_fog,
    ) as env:
        actor_critic.set_act_spec(env.action_spec()[0])  # assume one agent

        def train_one_update(step, epochs, tracing_on):
            # initialize replay buffer
            buffer = Buffer(
                batch_size,
                minibatch_size,
                MINIMAP_RES,
                MINIMAP_RES,
                env.action_spec()[0],
            )
            # initial observation
            timestep = env.reset()
            step_type, reward, _, obs = timestep[0]
            obs = preprocess(obs)

            ep_ret = []  # episode returns (scores)
            ep_rew = 0

            # fill the buffer with recorded trajectories
            while True:
                tf_obs = (
                    tf.constant(each_obs, shape=(1, *each_obs.shape))
                    for each_obs in obs
                )
                val, act_id, arg_spatial, arg_nonspatial, logp_a = actor_critic.step(
                    *tf_obs
                )
                sc2act_args = translateActionToSC2(
                    arg_spatial, arg_nonspatial, MINIMAP_RES, MINIMAP_RES
                )
                act_mask = get_mask(act_id.numpy().item(), actor_critic.action_spec)
                buffer.add(
                    *obs,
                    act_id.numpy().item(),
                    sc2act_args,
                    act_mask,
                    logp_a.numpy().item(),
                    val.numpy().item()
                )
                step_type, reward, _, obs = env.step(
                    [actions.FunctionCall(act_id.numpy().item(), sc2act_args)]
                )[0]
                # print("action:{}: {} reward {}".format(act_id.numpy().item(), sc2act_args, reward))
                buffer.add_rew(reward)
                obs = preprocess(obs)
                ep_rew += reward

                if step_type == step_type.LAST or buffer.is_full():
                    if step_type == step_type.LAST:
                        buffer.finalize(0)
                    else:
                        # trajectory is cut off; bootstrap the last state with
                        # the estimated value
                        tf_obs = (
                            tf.constant(each_obs, shape=(1, *each_obs.shape))
                            for each_obs in obs
                        )
                        val, _, _, _, _ = actor_critic.step(*tf_obs)
                        buffer.finalize(val)
                        ep_rew += reward

                    ep_ret.append(ep_rew)
                    ep_rew = 0

                    if buffer.is_full():
                        break

                    # respawn env
                    env.render(True)
                    timestep = env.reset()
                    _, _, _, obs = timestep[0]
                    obs = preprocess(obs)

            # train in minibatches
            buffer.post_process()

            mb_loss = []
            for ep in range(epochs):
                buffer.shuffle()
                for ind in range(batch_size // minibatch_size):
                    (
                        player,
                        available_act,
                        minimap,
                        # screen,
                        act_id,
                        act_args,
                        act_mask,
                        logp,
                        val,
                        ret,
                        adv,
                    ) = buffer.minibatch(ind)
                    assert ret.shape == val.shape
                    assert logp.shape == adv.shape

                    if tracing_on:
                        tf.summary.trace_on(graph=True, profiler=False)

                    mb_loss.append(
                        actor_critic.train_step(
                            tf.constant(step, dtype=tf.int64),
                            player,
                            available_act,
                            minimap,
                            # screen,
                            act_id,
                            act_args,
                            act_mask,
                            logp,
                            val,
                            ret,
                            adv,
                        )
                    )
                    step += 1

                    if tracing_on:
                        tracing_on = False
                        with train_summary_writer.as_default():
                            tf.summary.trace_export(name="train_step", step=0)

            batch_loss = np.mean(mb_loss)

            return (
                batch_loss,
                ep_ret,
                buffer.batch_ret,
                np.asarray(buffer.batch_vals, dtype=np.float32),
            )

        num_train_per_update = epochs * (batch_size // minibatch_size)

        for i in range(updates):
            # only trace the graph on the very first update
            tracing_on = i == 0

            batch_loss, cumulative_rew, batch_ret, batch_vals = train_one_update(
                i * num_train_per_update, epochs, tracing_on
            )
            ev = explained_variance(batch_vals, batch_ret)

            with train_summary_writer.as_default():
                tf.summary.scalar(
                    "batch/cumulative_rewards", np.mean(cumulative_rew), step=i
                )
                tf.summary.scalar("batch/ev", ev, step=i)
                tf.summary.scalar("loss/batch_loss", batch_loss, step=i)
            with hp_summary_writer.as_default():
                tf.summary.scalar("rewards", np.mean(cumulative_rew), step=i)

            print("----------------------------")
            print(
                "update {0:2d} loss {1:.3f} mean_ep_ret {2:.3f}".format(
                    i, batch_loss, np.mean(cumulative_rew)
                )
            )
            print("----------------------------")

            # save model
            if save_model and i % 15 == 0:
                print("saving model ...")
                save_path = osp.expanduser(saved_model_dir)
                ckpt = tf.train.Checkpoint(model=actor_critic)
                manager = tf.train.CheckpointManager(ckpt, save_path, max_to_keep=3)
                manager.save()
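# --- Hedged sketch: explained variance diagnostic ---
# The `explained_variance(batch_vals, batch_ret)` helper used in `train` above
# is assumed to be the usual RL diagnostic: 1 - Var(returns - values) / Var(returns),
# where 1.0 means the value head predicts the returns perfectly and 0.0 means it
# does no better than a constant. A minimal NumPy version under that assumption:
def _explained_variance_sketch(values, returns):
    import numpy as np

    var_ret = np.var(returns)
    if var_ret == 0:
        return np.nan  # undefined when the returns are constant
    return 1.0 - np.var(returns - values) / var_ret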
def call(
    self,
    player,
    available_act,
    minimap,
    # screen,
    step=None,
):
    """
    Embedding of inputs.
    """

    """
    Scalar features
    These are embeddings of the scalar observations.
    """
    player = tf.stop_gradient(player)
    available_act = tf.stop_gradient(available_act)
    minimap = tf.stop_gradient(minimap)

    embed_player = self.embed_player(tf.stop_gradient(tf.math.log(player + 1.0)))
    # embed_race = self.embed_race(
    #     tf.reshape(tf.one_hot(home_away_race, depth=4), shape=[-1, 8])
    # )
    # embed_upgrades = self.embed_upgrads(upgrades)
    embed_available_act = self.embed_available_act(available_act)

    scalar_out = tf.concat([embed_player, embed_available_act], axis=-1)

    """
    Map features
    These are embeddings of the spatial (minimap/screen) observations.
    """

    def one_hot_map(obs, screen_on=False):
        assert len(obs.shape) == 4
        if screen_on:
            Features = features.SCREEN_FEATURES
        else:
            Features = features.MINIMAP_FEATURES
        out = []
        for ind, feature in enumerate(Features):
            if feature.type is features.FeatureType.CATEGORICAL:
                one_hot = tf.one_hot(obs[:, :, :, ind], depth=feature.scale)
            else:  # features.FeatureType.SCALAR
                # FIXME: different screen features have different scaling
                one_hot = (
                    tf.cast(obs[:, :, :, ind : ind + 1], dtype=tf.float32) / 255.0
                )
            out.append(one_hot)
        out = tf.concat(out, axis=-1)
        return out

    one_hot_minimap = tf.stop_gradient(one_hot_map(minimap))
    embed_minimap = self.embed_minimap(one_hot_minimap)
    embed_minimap = self.embed_minimap_2(embed_minimap)

    # one_hot_screen = tf.stop_gradient(one_hot_map(screen, screen_on=True))
    # embed_screen = self.embed_screen(one_hot_screen)
    # embed_screen = self.embed_screen_2(embed_screen)

    if step is not None:
        with train_summary_writer.as_default():
            tf.summary.image(
                "embed_minimap",
                tf.transpose(embed_minimap[2:3, :, :, :], (3, 1, 2, 0)),
                step=step,
                max_outputs=5,
            )
            # tf.summary.image(
            #     "embed_screen",
            #     tf.transpose(embed_screen[2:3, :, :, :], (3, 1, 2, 0)),
            #     step=step,
            #     max_outputs=5,
            # )

    # TODO: entity features

    """
    State representation
    """
    # core: broadcast the scalar embedding over the spatial grid and concatenate
    scalar_out_2d = tf.tile(
        tf.expand_dims(tf.expand_dims(scalar_out, 1), 2),
        [1, embed_minimap.shape[1], embed_minimap.shape[2], 1],
    )
    core_out = tf.concat([scalar_out_2d, embed_minimap], axis=-1, name="core")
    core_out_flat = self.flat(core_out)
    core_out_flat = self.core_fc(core_out_flat)

    """
    Decision output
    """
    # value
    value_out = self.value(core_out_flat)
    # action id
    action_id_out = self.action_id_layer(core_out_flat)
    action_id_out = self.action_id_gate(action_id_out, embed_available_act)
    # delay
    # delay_out = self.delay_logits(core_out_flat)
    # queued
    queued_out = self.queued_logits(core_out_flat)
    # selected units
    select_point_out = self.select_point_logits(core_out_flat)
    select_add_out = self.select_add_logits(core_out_flat)
    select_worker_out = self.select_worker_logits(core_out_flat)
    # target unit
    # target_unit_out = self.target_unit_logits(core_out_flat)
    # target location
    target_location_out = self.target_location_logits(core_out)
    target_location_out = self.target_location_flat(target_location_out)

    out = {
        "value": value_out,
        "action_id": action_id_out,
        "queued": queued_out,
        "select_point_act": select_point_out,
        "select_add": select_add_out,
        "select_worker": select_worker_out,
        "target_location": target_location_out,
    }
    return out
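# --- Hedged sketch: one-hot expansion of categorical map layers ---
# Standalone illustration of what `one_hot_map` in `call` above does for a
# CATEGORICAL feature layer: an integer plane of shape (N, H, W) becomes
# (N, H, W, depth) one-hot channels, which are then concatenated with the
# rescaled numeric layers. The depth of 4 is an assumed toy value; the real
# depth comes from `feature.scale` in pysc2's MINIMAP_FEATURES.
def _one_hot_layer_sketch():
    import tensorflow as tf

    layer = tf.constant([[[0, 1], [2, 3]]])  # shape (1, 2, 2), integer ids
    one_hot = tf.one_hot(layer, depth=4)     # shape (1, 2, 2, 4)
    return one_hot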