"""Launch the MineRL Malmo instance manager with DEBUG-level colored logging."""
import logging

import coloredlogs
import minerl  # noqa: F401  # NOTE(review): unused here; presumably kept for import side effects — confirm
from minerl.env.malmo import launch_instance_manager

if __name__ == "__main__":
    # Configure logging only when executed as a script, so that merely
    # importing this module does not reconfigure the root logger.
    coloredlogs.install(logging.DEBUG)
    launch_instance_manager()
# Example #2
def main():
    """Configure the environment and run the stage(s) selected by EVALUATION_STAGE.

    Sets Malmo's base port and ``CUDA_VISIBLE_DEVICES`` from FLAGS, builds the
    custom observation/action spaces, selects an agent factory, and then:

    * ``'all'`` or ``'training'``: creates ``FLAGS.logdir``, appends the current
      flags to a timestamped ``flags_*.cfg`` file there, and runs ``train.main``,
      reporting start/end/error through ``aicrowd_helper``.
    * ``'all'`` or ``'testing'``: runs ``test.main``, reporting through
      ``aicrowd_helper``. When ``EVALUATION_RUNNING_ON == 'local'``, the file at
      ``EXITED_SIGNAL_PATH`` is removed before the phase and touched after it
      (also when the phase raises).
    * ``'manager'``: launches the MineRL Malmo instance manager.

    Exceptions from ``train.main`` / ``test.main`` are caught, reported to
    ``aicrowd_helper`` and printed, so with ``'all'`` a failed training phase
    does not prevent the testing phase from running.
    """
    malmo_base_port = FLAGS.malmo_base_port
    # NOTE(review): os.environ only accepts str values; this assumes FLAGS.gpus
    # is always a string (never None) — confirm the flag's default.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpus
    malmo.InstanceManager.configure_malmo_base_port(malmo_base_port)

    observation_space = CustomObservationSpace(
        pov_resolution=FLAGS.pov_resolution,
        pov_color_space=FLAGS.pov_color_space)

    action_space = CustomActionSpace(
        num_camera_actions=FLAGS.num_camera_actions,
        camera_max_angle=FLAGS.camera_max_angle)

    def combined_actor_critic_agent():
        """Build one ResnetLSTMAgent configured from FLAGS."""
        return ResnetLSTMAgent(observation_space=observation_space,
                               action_space=action_space,
                               max_step_mul=FLAGS.max_step_mul,
                               core_hidden_size=FLAGS.lstm_hidden_size,
                               use_prev_actions=FLAGS.use_prev_actions,
                               action_embed_type=FLAGS.action_embed_type,
                               action_embed_size=FLAGS.action_embed_size)

    def separate_actor_critic_agent():
        """Wrap two independently built ResnetLSTMAgents as actor and critic."""
        return SeparateActorCriticWrapperAgent(
            actor=combined_actor_critic_agent(),
            critic=combined_actor_critic_agent())

    # train.main / test.main receive the factory itself, not an agent instance.
    if FLAGS.separate_actor_critic:
        agent_fn = separate_actor_critic_agent
    else:
        agent_fn = combined_actor_critic_agent

    log_dir = FLAGS.logdir

    # Training Phase
    if EVALUATION_STAGE in ('all', 'training'):
        # Only write out flags when training.
        pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        FLAGS.append_flags_into_file(f'{log_dir}/flags_{timestamp}.cfg')

        aicrowd_helper.training_start()
        try:
            train.main(log_dir=log_dir,
                       load_dir=FLAGS.loaddir,
                       observation_space=observation_space,
                       action_space=action_space,
                       max_step_mul=FLAGS.max_step_mul,
                       fixed_step_mul=FLAGS.fixed_step_mul,
                       step_mul=FLAGS.step_mul,
                       agent_fn=agent_fn,
                       seed=FLAGS.train_seed,
                       malmo_base_port=malmo_base_port)
            aicrowd_helper.training_end()
        except Exception:
            # Best effort: report and continue, so the testing phase still
            # runs when EVALUATION_STAGE == 'all'.
            aicrowd_helper.training_error()
            # format_exc() already includes the exception type and message.
            print(traceback.format_exc())

    # Testing Phase
    if EVALUATION_STAGE in ('all', 'testing'):
        running_locally = EVALUATION_RUNNING_ON == 'local'
        if running_locally:
            # Remove any existing signal file before inference starts.
            try:
                os.remove(EXITED_SIGNAL_PATH)
            except FileNotFoundError:
                pass
        try:
            aicrowd_helper.inference_start()
            try:
                test.main(log_dir=log_dir,
                          test_model=FLAGS.test_model,
                          observation_space=observation_space,
                          action_space=action_space,
                          fixed_step_mul=FLAGS.fixed_step_mul,
                          step_mul=FLAGS.step_mul,
                          agent_fn=agent_fn)
                aicrowd_helper.inference_end()
            except Exception:
                aicrowd_helper.inference_error()
                # format_exc() already includes the exception type and message.
                print(traceback.format_exc())
        finally:
            # Touch the signal file even if something escapes the phase
            # (e.g. aicrowd_helper itself raising, or an interrupt).
            # NOTE(review): presumably an external process waits for this
            # file after a local run — confirm.
            if running_locally:
                pathlib.Path(EXITED_SIGNAL_PATH).touch()

    # Launch instance manager
    if EVALUATION_STAGE == 'manager':
        # Imported lazily: only this stage needs it.
        from minerl.env.malmo import launch_instance_manager
        launch_instance_manager()