Example 1
def generate_video(args):

    total_time = args.video_length * 100
    exp_path = os.path.join(DATA_DIR, "EXP_{:04d}".format(args.expID))
    if not os.path.exists(exp_path):
        raise FileNotFoundError(
            'experiment directory does not exist: {}'.format(exp_path))
    print('*** experiment folder found: {} ***'.format(exp_path))
    os.makedirs(VIDEO_DIR, exist_ok=True)

    # Retrieve MuJoCo XML files for visualization =====================================
    env_names = []
    args.graphs = dict()
    # existing envs
    if not args.custom_xml:
        for morphology in args.morphologies:
            env_names += [
                name[:-4] for name in os.listdir(XML_DIR)
                if '.xml' in name and morphology in name
            ]
        for name in env_names:
            args.graphs[name] = utils.getGraphStructure(
                os.path.join(XML_DIR, '{}.xml'.format(name)))
    # custom envs
    else:
        if os.path.isfile(args.custom_xml):
            assert '.xml' in os.path.basename(
                args.custom_xml), 'No XML file found.'
            name = os.path.basename(args.custom_xml)
            env_names.append(name[:-4])  # truncate the .xml suffix
            args.graphs[name[:-4]] = utils.getGraphStructure(args.custom_xml)
        elif os.path.isdir(args.custom_xml):
            for name in os.listdir(args.custom_xml):
                if '.xml' in name:
                    env_names.append(name[:-4])
                    args.graphs[name[:-4]] = utils.getGraphStructure(
                        os.path.join(args.custom_xml, name))
    env_names.sort()

    # Set up env and policy ================================================
    args.limb_obs_size, args.max_action = utils.registerEnvs(
        env_names, args.max_episode_steps, args.custom_xml)
    # determine the maximum number of children in all the envs
    if args.max_children is None:
        args.max_children = utils.findMaxChildren(env_names, args.graphs)
    # setup agent policy
    policy = TD3.TD3(args)

    try:
        cp.load_model_only(exp_path, policy)
    except Exception as e:
        raise Exception(
            'policy loading failed; check the policy params '
            '(hint 1: max_children must match the trained policy; '
            'hint 2: if the trained policy used torchfold, consider passing --disable_fold)'
        ) from e

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # visualize ===========================================================
    for env_name in env_names:
        # create env
        env = utils.makeEnvWrapper(env_name, seed=args.seed,
                                   obs_max_len=None)()
        policy.change_morphology(args.graphs[env_name])

        # create unique temp frame dir
        count = 0
        frame_dir = os.path.join(
            VIDEO_DIR, "frames_{}_{}_{}".format(args.expID, env_name, count))
        while os.path.exists(frame_dir):
            count += 1
            frame_dir = os.path.join(
                VIDEO_DIR,
                "frames_{}_{}_{}".format(args.expID, env_name, count))
        os.makedirs(frame_dir)
        # create video name without overwriting previously generated videos
        count = 0
        video_name = "%04d_%s_%d" % (args.expID, env_name, count)
        while os.path.exists("{}/{}.mp4".format(VIDEO_DIR, video_name)):
            count += 1
            video_name = "%04d_%s_%d" % (args.expID, env_name, count)

        # init env vars
        done = True
        print("-" * 50)
        time_step_counter = 0
        printProgressBar(0, total_time)

        while time_step_counter < total_time:
            printProgressBar(time_step_counter + 1,
                             total_time,
                             prefix=env_name)
            if done:
                obs = env.reset()
                done = False
                episode_reward = 0
            action = policy.select_action(np.array(obs))
            # perform action in the environment
            new_obs, reward, done, _ = env.step(action)
            episode_reward += reward
            # draw image of current frame
            image_data = env.sim.render(VIDEO_RESOLUATION[0],
                                        VIDEO_RESOLUATION[1],
                                        camera_name="track")
            img = Image.fromarray(image_data, "RGB")
            draw = ImageDraw.Draw(img)
            font = ImageFont.truetype('./misc/sans-serif.ttf', 24)
            draw.text((200, 10),
                      "Instant Reward: " + str(reward), (255, 0, 0),
                      font=font)
            draw.text((200, 35),
                      "Episode Reward: " + str(episode_reward), (255, 0, 0),
                      font=font)
            img.save(
                os.path.join(frame_dir, "frame-%.10d.png" % time_step_counter))

            obs = new_obs
            time_step_counter += 1

        # stitch the saved frames into a video; suppress ffmpeg console output
        subprocess.call([
            'ffmpeg', '-framerate', '50', '-y', '-i',
            os.path.join(frame_dir, 'frame-%010d.png'), '-r', '30', '-pix_fmt',
            'yuv420p',
            os.path.join(VIDEO_DIR, '{}.mp4'.format(video_name))
        ],
                        stdout=subprocess.DEVNULL,
                        stderr=subprocess.STDOUT)
        # remove the temporary frame directory
        subprocess.call(['rm', '-rf', frame_dir])
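
For reference, a minimal sketch of how generate_video might be driven from the command line. The flag names mirror the attributes read above, but the repository's actual argument parser is defined elsewhere, so the options and defaults here are assumptions.

# Hypothetical driver for generate_video; the real argument parser lives elsewhere in the repo.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--expID", type=int, required=True)        # experiment folder EXP_xxxx to load
    parser.add_argument("--video_length", type=int, default=10)    # scaled by 100 to get the number of frames
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--morphologies", nargs="*", default=["walker"])
    parser.add_argument("--custom_xml", type=str, default="")
    parser.add_argument("--max_episode_steps", type=int, default=1000)
    parser.add_argument("--max_children", type=int, default=None)
    args = parser.parse_args()
    generate_video(args)
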
Example 2
def train(args):

    # Set up directories ===========================================================
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(BUFFER_DIR, exist_ok=True)
    exp_name = "EXP_%04d" % (args.expID)
    exp_path = os.path.join(DATA_DIR, exp_name)
    rb_path = os.path.join(BUFFER_DIR, exp_name)
    os.makedirs(exp_path, exist_ok=True)
    os.makedirs(rb_path, exist_ok=True)
    # save arguments
    with open(os.path.join(exp_path, 'args.txt'), 'w+') as f:
        json.dump(args.__dict__, f, indent=2)

    # Retrieve MuJoCo XML files for training ========================================
    agent_name = args.agent_name
    envs_train_names = [agent_name]
    args.graphs = dict()
    # existing envs
    if not args.custom_xml:
        args.graphs[agent_name] = utils.getGraphStructure(
            os.path.join(XML_DIR, '{}.xml'.format(agent_name)))
    # custom envs are not handled in this single-agent variant

    num_envs_train = len(envs_train_names)
    print("#" * 50 + '\ntraining envs: {}\n'.format(envs_train_names) +
          "#" * 50)

    # Set up training env and policy ================================================
    args.limb_obs_size, args.max_action = utils.registerEnvs(
        envs_train_names, args.max_episode_steps, args.custom_xml)
    max_num_limbs = max(
        [len(args.graphs[env_name]) for env_name in envs_train_names])
    # create vectorized training env
    obs_max_len = max_num_limbs * args.limb_obs_size
    envs_train = [
        utils.makeEnvWrapper(name, obs_max_len, args.seed)
        for name in envs_train_names
    ]
    # envs_train = SubprocVecEnv(envs_train)  # vectorized env
    # set random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # determine the maximum number of children in all the training envs
    if args.max_children is None:
        args.max_children = utils.findMaxChildren(envs_train_names,
                                                  args.graphs)
    # setup agent policy
    policy = TD3.LifeLongTD3(args)

    # Create new training instance or load previous checkpoint ========================
    if cp.has_checkpoint(exp_path, rb_path):
        print("*** loading checkpoint from {} ***".format(exp_path))
        total_timesteps, episode_num, replay_buffer, num_samples, loaded_path = cp.load_checkpoint(
            exp_path, rb_path, policy, args)
        print("*** checkpoint loaded from {} ***".format(loaded_path))
    else:
        print("*** training from scratch ***")
        # init training vars
        total_timesteps = 0
        episode_num = 0
        num_samples = 0
        # a fresh replay buffer is created per environment inside the training loop below

    # Initialize training variables ================================================
    writer = SummaryWriter("%s/%s/" % (DATA_DIR, exp_name))
    s = time.time()
    # TODO: these variables may need to be moved inside the training loop
    timesteps_since_saving = 0
    this_training_timesteps = 0
    episode_timesteps = 0
    episode_reward = 0
    episode_reward_buffer = 0
    done = True

    # Start training ===========================================================
    for env_handle, env_name in zip(envs_train, envs_train_names):
        env = env_handle()
        obs = env.reset()
        replay_buffer = utils.ReplayBuffer(max_size=args.rb_max)
        policy.change_morphology(args.graphs[env_name])
        policy.graph = args.graphs[env_name]
        task_timesteps = 0
        done = False
        episode_timesteps = 0
        episode_reward = 0
        episode_reward_buffer = 0
        while task_timesteps < args.max_timesteps:
            # train and log after one episode for each env
            if done:
                # log updates and train policy
                if this_training_timesteps != 0:
                    policy.train(replay_buffer,
                                 episode_timesteps,
                                 args.batch_size,
                                 args.discount,
                                 args.tau,
                                 args.policy_noise,
                                 args.noise_clip,
                                 args.policy_freq,
                                 graphs=args.graphs,
                                 env_name=env_name)
                    # add to tensorboard display

                    writer.add_scalar('{}_episode_reward'.format(env_name),
                                      episode_reward, task_timesteps)
                    writer.add_scalar('{}_episode_len'.format(env_name),
                                      episode_timesteps, task_timesteps)
                    # print to console
                    print(
                        "-" * 50 +
                        "\nExpID: {}, FPS: {:.2f}, TotalT: {}, EpisodeNum: {}, SampleNum: {}, ReplayBSize: {}"
                        .format(args.expID, this_training_timesteps /
                                (time.time() -
                                 s), total_timesteps, episode_num, num_samples,
                                len(replay_buffer.storage)))
                    print("{} === EpisodeT: {}, Reward: {:.2f}".format(
                        env_name, episode_timesteps, episode_reward))
                    this_training_timesteps = 0
                    s = time.time()

                # save model and replay buffers
                if timesteps_since_saving >= args.save_freq:
                    timesteps_since_saving = 0
                    model_saved_path = cp.save_model(exp_path, policy,
                                                     total_timesteps,
                                                     episode_num, num_samples,
                                                     {env_name: replay_buffer},
                                                     envs_train_names, args)
                    print("*** model saved to {} ***".format(model_saved_path))
                    rb_saved_path = cp.save_replay_buffer(
                        rb_path, {env_name: replay_buffer})
                    print("*** replay buffers saved to {} ***".format(
                        rb_saved_path))

                # reset training variables
                obs = env.reset()
                done = False
                episode_reward = 0
                episode_timesteps = 0
                episode_num += 1
                # create reward buffer to store reward for one sub-env when it is not done
                episode_reward_buffer = 0

            # start sampling ===========================================================
            # sample actions randomly for a while, then according to the policy
            if task_timesteps < args.start_timesteps:
                action = np.random.uniform(low=env.action_space.low[0],
                                           high=env.action_space.high[0],
                                           size=max_num_limbs)
            else:
                # remove 0 padding of obs before feeding into the policy (trick for vectorized env)
                obs = np.array(obs[:args.limb_obs_size *
                                   len(args.graphs[env_name])])
                policy_action = policy.select_action(obs)
                if args.expl_noise != 0:
                    policy_action = (policy_action + np.random.normal(
                        0, args.expl_noise, size=policy_action.size)).clip(
                            env.action_space.low[0], env.action_space.high[0])
                # add 0-padding to ensure that size is the same for all envs
                action = np.append(
                    policy_action,
                    np.array([
                        0 for i in range(max_num_limbs - policy_action.size)
                    ]))

            # perform action in the environment
            new_obs, reward, done, _ = env.step(action)

            # add the instant reward to the cumulative buffer
            # if the env reports done at this moment, set episode_reward to the value accumulated in the buffer
            episode_reward_buffer += reward
            if done and episode_reward == 0:
                episode_reward = episode_reward_buffer
                episode_reward_buffer = 0
            writer.add_scalar('{}_instant_reward'.format(env_name), reward,
                              task_timesteps)
            done_bool = float(done)
            if episode_timesteps + 1 == args.max_episode_steps:
                done_bool = 0
                done = True
            # remove 0 padding before storing in the replay buffer (trick for vectorized env)
            num_limbs = len(args.graphs[env_name])
            obs = np.array(obs[:args.limb_obs_size * num_limbs])
            new_obs = np.array(new_obs[:args.limb_obs_size * num_limbs])
            action = np.array(action[:num_limbs])
            # insert transition in the replay buffer
            replay_buffer.add((obs, new_obs, action, reward, done_bool))
            num_samples += 1
            # do not increment episode_timesteps if the sub-env has been 'done'
            if not done:
                episode_timesteps += 1
                total_timesteps += 1
                task_timesteps += 1
                this_training_timesteps += 1
                timesteps_since_saving += 1

            obs = new_obs
        policy.next_task()

    # save checkpoint after training ===========================================================
    model_saved_path = cp.save_model(exp_path, policy, total_timesteps,
                                     episode_num, num_samples,
                                     {envs_train_names[-1]: replay_buffer},
                                     envs_train_names, args)
    print("*** training finished and model saved to {} ***".format(
        model_saved_path))
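
The loop above touches only two parts of the replay buffer interface: add() on an (obs, new_obs, action, reward, done_bool) tuple and a storage list whose length is logged. A minimal sketch consistent with that usage, assuming a ring buffer with uniform sampling; the actual utils.ReplayBuffer may differ.

# Minimal replay buffer matching the interface used above (an assumption, not utils.ReplayBuffer itself).
import numpy as np

class SimpleReplayBuffer(object):
    def __init__(self, max_size=1e6):
        self.storage = []              # list of (obs, new_obs, action, reward, done_bool) tuples
        self.max_size = int(max_size)
        self.ptr = 0

    def add(self, data):
        # once full, overwrite the oldest transition (ring buffer)
        if len(self.storage) == self.max_size:
            self.storage[self.ptr] = data
            self.ptr = (self.ptr + 1) % self.max_size
        else:
            self.storage.append(data)

    def sample(self, batch_size):
        idx = np.random.randint(0, len(self.storage), size=batch_size)
        obs, new_obs, action, reward, done = zip(*(self.storage[i] for i in idx))
        return (np.array(obs), np.array(new_obs), np.array(action),
                np.array(reward).reshape(-1, 1), np.array(done).reshape(-1, 1))
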
Example 3
from skeleton_graph_vae.graph_dataset import GraphDataset
import envs  # registers the custom MuJoCo environments with gym
import gym
import utils
from graph_vae import wrappers

# build the ant env and expose per-limb (modular) observations plus its MuJoCo XML
ant = gym.make("envs:ant-v0")

ant = wrappers.ModularEnvWrapper(ant)
xml = ant.xml
g = utils.getGraphStructure(xml)

# transition dataset keyed by the graph structure and the per-limb observation size (27)
ds = GraphDataset("data/ant.memory", g, 27)

s = ant.reset()

print(s.shape)
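
As a sanity check under the assumptions the training scripts already make (one graph entry per limb, and a per-limb observation size of 27 here), the flattened observation returned by reset() should factor into limbs x per-limb features:

# Sanity check (assumes one graph entry per limb and 27 observation features per limb, as above)
num_limbs = len(g)
limb_obs_size = 27
assert s.shape[0] == num_limbs * limb_obs_size, \
    "flattened obs ({}) != limbs ({}) x per-limb features ({})".format(
        s.shape[0], num_limbs, limb_obs_size)
print("{} limbs x {} features per limb = {} observation dims".format(
    num_limbs, limb_obs_size, s.shape[0]))
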
Example 4
def train(_run):
    # Set up directories ===========================================================
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(BUFFER_DIR, exist_ok=True)
    exp_name = args.expID
    exp_path = os.path.join(DATA_DIR, exp_name)
    rb_path = os.path.join(BUFFER_DIR, exp_name)
    os.makedirs(exp_path, exist_ok=True)
    os.makedirs(rb_path, exist_ok=True)
    # save arguments
    with open(os.path.join(exp_path, "args.txt"), "w+") as f:
        json.dump(args.__dict__, f, indent=2)

    # Retrieve MuJoCo XML files for training ========================================
    envs_train_names = []
    args.graphs = dict()
    # existing envs
    if not args.custom_xml:
        for morphology in args.morphologies:
            envs_train_names += [
                name[:-4] for name in os.listdir(XML_DIR)
                if ".xml" in name and morphology in name
            ]
        for name in envs_train_names:
            args.graphs[name] = utils.getGraphStructure(
                os.path.join(XML_DIR, "{}.xml".format(name)),
                args.observation_graph_type,
            )
    # custom envs
    else:
        if os.path.isfile(args.custom_xml):
            assert ".xml" in os.path.basename(
                args.custom_xml), "No XML file found."
            name = os.path.basename(args.custom_xml)
            envs_train_names.append(name[:-4])  # truncate the .xml suffix
            args.graphs[name[:-4]] = utils.getGraphStructure(
                args.custom_xml, args.observation_graph_type)
        elif os.path.isdir(args.custom_xml):
            for name in os.listdir(args.custom_xml):
                if ".xml" in name:
                    envs_train_names.append(name[:-4])
                    args.graphs[name[:-4]] = utils.getGraphStructure(
                        os.path.join(args.custom_xml, name),
                        args.observation_graph_type)

    envs_train_names.sort()
    num_envs_train = len(envs_train_names)
    print("#" * 50 + "\ntraining envs: {}\n".format(envs_train_names) +
          "#" * 50)

    # Set up training env and policy ================================================
    args.limb_obs_size, args.max_action = utils.registerEnvs(
        envs_train_names, args.max_episode_steps, args.custom_xml)
    max_num_limbs = max(
        [len(args.graphs[env_name]) for env_name in envs_train_names])
    # create vectorized training env
    obs_max_len = max_num_limbs * args.limb_obs_size
    envs_train = [
        utils.makeEnvWrapper(name, obs_max_len, args.seed)
        for name in envs_train_names
    ]

    envs_train = SubprocVecEnv(envs_train)  # vectorized env
    # set random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # determine the maximum number of children in all the training envs
    if args.max_children is None:
        args.max_children = utils.findMaxChildren(envs_train_names,
                                                  args.graphs)

    args.max_num_limbs = max_num_limbs
    # setup agent policy
    policy = TD3.TD3(args)

    # Create new training instance or load previous checkpoint ========================
    if cp.has_checkpoint(exp_path, rb_path):
        print("*** loading checkpoint from {} ***".format(exp_path))
        (
            total_timesteps,
            episode_num,
            replay_buffer,
            num_samples,
            loaded_path,
        ) = cp.load_checkpoint(exp_path, rb_path, policy, args)
        print("*** checkpoint loaded from {} ***".format(loaded_path))
    else:
        print("*** training from scratch ***")
        # init training vars
        total_timesteps = 0
        episode_num = 0
        num_samples = 0
        # different replay buffer for each env; avoid using too much memory if there are too many envs
        replay_buffer = dict()
        if num_envs_train > args.rb_max // 1e6:
            for name in envs_train_names:
                replay_buffer[name] = utils.ReplayBuffer(
                    max_size=args.rb_max // num_envs_train)
        else:
            for name in envs_train_names:
                replay_buffer[name] = utils.ReplayBuffer()

    # Initialize training variables ================================================
    writer = SummaryWriter("%s/%s/" % (DATA_DIR, exp_name))
    s = time.time()
    timesteps_since_saving = 0
    timesteps_since_saving_model_only = 0
    this_training_timesteps = 0
    collect_done = True
    episode_timesteps_list = [0 for i in range(num_envs_train)]
    done_list = [True for i in range(num_envs_train)]

    # Start training ===========================================================
    model_savings_so_far = 0
    while total_timesteps < args.max_timesteps:

        # train and log after one episode for each env
        if collect_done:
            # log updates and train policy
            if this_training_timesteps != 0:
                policy.train(
                    replay_buffer,
                    episode_timesteps_list,
                    args.batch_size,
                    args.discount,
                    args.tau,
                    args.policy_noise,
                    args.noise_clip,
                    args.policy_freq,
                    graphs=args.graphs,
                    envs_train_names=envs_train_names[:num_envs_train],
                )
                # add to tensorboard display
                for i in range(num_envs_train):
                    writer.add_scalar(
                        "{}_episode_reward".format(envs_train_names[i]),
                        episode_reward_list[i],
                        total_timesteps,
                    )
                    writer.add_scalar(
                        "{}_episode_len".format(envs_train_names[i]),
                        episode_timesteps_list[i],
                        total_timesteps,
                    )
                    if not args.debug:
                        ex.log_scalar(
                            f"{envs_train_names[i]}_episode_reward",
                            float(episode_reward_list[i]),
                            total_timesteps,
                        )
                        ex.log_scalar(
                            f"{envs_train_names[i]}_episode_len",
                            float(episode_timesteps_list[i]),
                            total_timesteps,
                        )
                if not args.debug:
                    ex.log_scalar(
                        "total_timesteps",
                        float(total_timesteps),
                        total_timesteps,
                    )
                # print to console
                print(
                    "-" * 50 +
                    "\nExpID: {}, FPS: {:.2f}, TotalT: {}, EpisodeNum: {}, SampleNum: {}, ReplayBSize: {}"
                    .format(
                        args.expID,
                        this_training_timesteps / (time.time() - s),
                        total_timesteps,
                        episode_num,
                        num_samples,
                        sum([
                            len(replay_buffer[name].storage)
                            for name in envs_train_names
                        ]),
                    ))
                for i in range(len(envs_train_names)):
                    print("{} === EpisodeT: {}, Reward: {:.2f}".format(
                        envs_train_names[i],
                        episode_timesteps_list[i],
                        episode_reward_list[i],
                    ))

            # save model and replay buffers
            if timesteps_since_saving >= args.save_freq:
                timesteps_since_saving = 0
                model_saved_path = cp.save_model(
                    exp_path,
                    policy,
                    total_timesteps,
                    episode_num,
                    num_samples,
                    replay_buffer,
                    envs_train_names,
                    args,
                    model_name=f"model_{model_savings_so_far}.pyth",
                )
                model_savings_so_far += 1
                print("*** model saved to {} ***".format(model_saved_path))
                if args.save_buffer:
                    rb_saved_path = cp.save_replay_buffer(
                        rb_path, replay_buffer)
                    print("*** replay buffers saved to {} ***".format(
                        rb_saved_path))

            # reset training variables
            obs_list = envs_train.reset()
            done_list = [False for i in range(num_envs_train)]
            episode_reward_list = [0 for i in range(num_envs_train)]
            episode_timesteps_list = [0 for i in range(num_envs_train)]
            episode_num += num_envs_train
            # create reward buffer to store reward for one sub-env when it is not done
            episode_reward_list_buffer = [0 for i in range(num_envs_train)]

        # start sampling ===========================================================
        # sample actions randomly for a while, then according to the policy
        if total_timesteps < args.start_timesteps * num_envs_train:
            action_list = [
                np.random.uniform(
                    low=envs_train.action_space.low[0],
                    high=envs_train.action_space.high[0],
                    size=max_num_limbs,
                ) for i in range(num_envs_train)
            ]
        else:
            action_list = []
            for i in range(num_envs_train):
                # dynamically change the graph structure of the modular policy
                policy.change_morphology(args.graphs[envs_train_names[i]])
                # remove 0 padding of obs before feeding into the policy (trick for vectorized env)
                obs = np.array(
                    obs_list[i][:args.limb_obs_size *
                                len(args.graphs[envs_train_names[i]])])
                policy_action = policy.select_action(obs)
                if args.expl_noise != 0:
                    policy_action = (policy_action + np.random.normal(
                        0, args.expl_noise, size=policy_action.size)).clip(
                            envs_train.action_space.low[0],
                            envs_train.action_space.high[0])
                # add 0-padding to ensure that size is the same for all envs
                policy_action = np.append(
                    policy_action,
                    np.array([
                        0 for i in range(max_num_limbs - policy_action.size)
                    ]),
                )
                action_list.append(policy_action)

        # perform action in the environment
        new_obs_list, reward_list, curr_done_list, _ = envs_train.step(
            action_list)

        # record if each env has ever been 'done'
        done_list = [
            done_list[i] or curr_done_list[i] for i in range(num_envs_train)
        ]

        for i in range(num_envs_train):
            # add the instant reward to the cumulative buffer
            # if a sub-env reports done at this moment, set its entry in episode_reward_list to the value accumulated in the buffer
            episode_reward_list_buffer[i] += reward_list[i]
            if curr_done_list[i] and episode_reward_list[i] == 0:
                episode_reward_list[i] = episode_reward_list_buffer[i]
                episode_reward_list_buffer[i] = 0
            done_bool = float(curr_done_list[i])
            if episode_timesteps_list[i] + 1 == args.max_episode_steps:
                done_bool = 0
                done_list[i] = True
            # remove 0 padding before storing in the replay buffer (trick for vectorized env)
            num_limbs = len(args.graphs[envs_train_names[i]])
            obs = np.array(obs_list[i][:args.limb_obs_size * num_limbs])
            new_obs = np.array(new_obs_list[i][:args.limb_obs_size *
                                               num_limbs])
            action = np.array(action_list[i][:num_limbs])
            # insert transition in the replay buffer
            replay_buffer[envs_train_names[i]].add(
                (obs, new_obs, action, reward_list[i], done_bool))
            num_samples += 1
            # do not increment episode_timesteps if the sub-env has been 'done'
            if not done_list[i]:
                episode_timesteps_list[i] += 1
                total_timesteps += 1
                this_training_timesteps += 1
                timesteps_since_saving += 1
                timesteps_since_saving_model_only += 1

        obs_list = new_obs_list
        collect_done = all(done_list)

    # save checkpoint after training ===========================================================
    model_saved_path = cp.save_model(
        exp_path,
        policy,
        total_timesteps,
        episode_num,
        num_samples,
        replay_buffer,
        envs_train_names,
        args,
    )
    print("*** training finished and model saved to {} ***".format(
        model_saved_path))
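
The zero-padding used throughout these loops is what lets morphologies of different sizes share one vectorized env: actions are padded up to max_num_limbs before env.step, and observations and actions are truncated back to the real limb count before entering the replay buffer. A standalone round-trip of that trick, with made-up sizes:

# Standalone illustration of the padding/unpadding trick used above (sizes are made up).
import numpy as np

max_num_limbs = 7      # largest morphology across the training envs
num_limbs = 4          # limbs of the current morphology
limb_obs_size = 3      # per-limb observation size (illustrative)

policy_action = np.random.uniform(-1, 1, size=num_limbs)
# pad with zeros so every env receives an action of length max_num_limbs
action = np.append(policy_action, np.zeros(max_num_limbs - policy_action.size))
assert action.size == max_num_limbs

padded_obs = np.random.randn(max_num_limbs * limb_obs_size)
# strip the padding again before storing the transition
obs = padded_obs[:limb_obs_size * num_limbs]
stored_action = action[:num_limbs]
assert obs.size == num_limbs * limb_obs_size and stored_action.size == num_limbs
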
def train(args):

    # Set up directories ===========================================================
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(BUFFER_DIR, exist_ok=True)
    exp_name = "EXP_%04d" % (args.expID)
    exp_path = os.path.join(DATA_DIR, exp_name)
    rb_path = os.path.join(BUFFER_DIR, exp_name)
    os.makedirs(exp_path, exist_ok=True)
    os.makedirs(rb_path, exist_ok=True)
    # save arguments
    with open(os.path.join(exp_path, 'args.txt'), 'w+') as f:
        json.dump(args.__dict__, f, indent=2)

    # Retrieve MuJoCo XML files for training ========================================
    envs_train_names = []
    args.graphs = dict()
    # existing envs
    if not args.custom_xml:
        if args.predefined_order:
            envs_train_names = [
                'walker_7_flipped', 'walker_7_main', 'walker_6_flipped',
                'walker_5_main', 'walker_4_main', 'walker_5_flipped',
                'walker_3_main', 'walker_6_main', 'walker_3_flipped',
                'walker_4_flipped'
            ]
        else:
            for morphology in args.morphologies:
                envs_train_names += [
                    name[:-4] for name in os.listdir(XML_DIR)
                    if '.xml' in name and morphology in name
                ]
            envs_train_names.sort()

        total_num_envs = len(envs_train_names)
        train_envs = envs_train_names[:int(args.train_ratio * total_num_envs)]
        test_envs = envs_train_names[int(args.train_ratio * total_num_envs):]
        envs_train_names = train_envs
        for name in envs_train_names:
            args.graphs[name] = utils.getGraphStructure(
                os.path.join(XML_DIR, '{}.xml'.format(name)))
    # custom envs
    else:
        if os.path.isfile(args.custom_xml):
            assert '.xml' in os.path.basename(
                args.custom_xml), 'No XML file found.'
            name = os.path.basename(args.custom_xml)
            envs_train_names.append(name[:-4])  # truncate the .xml suffix
            args.graphs[name[:-4]] = utils.getGraphStructure(args.custom_xml)
        elif os.path.isdir(args.custom_xml):
            for name in os.listdir(args.custom_xml):
                if '.xml' in name:
                    envs_train_names.append(name[:-4])
                    args.graphs[name[:-4]] = utils.getGraphStructure(
                        os.path.join(args.custom_xml, name))
    # envs_train_names.sort()
    num_envs_train = len(envs_train_names)
    print("#" * 50 + '\ntraining envs: {}\n'.format(envs_train_names) +
          "#" * 50)

    # Set up training env and policy ================================================
    args.limb_obs_size, args.max_action = utils.registerEnvs(
        envs_train_names, args.max_episode_steps, args.custom_xml)
    max_num_limbs = max(
        [len(args.graphs[env_name]) for env_name in envs_train_names])
    # create vectorized training env
    obs_max_len = max_num_limbs * args.limb_obs_size
    # envs_train = SubprocVecEnv(envs_train)  # vectorized env
    # set random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # determine the maximum number of children in all the training envs
    if args.max_children is None:
        args.max_children = utils.findMaxChildren(envs_train_names,
                                                  args.graphs)
    # setup agent policy
    policy = TD3.LifeLongTD3(args)

    # Create new training instance or load previous checkpoint ========================
    if cp.has_checkpoint(exp_path, rb_path):
        print("*** loading checkpoint from {} ***".format(exp_path))
        total_timesteps, episode_num, replay_buffer, num_samples, loaded_path = cp.load_checkpoint(
            exp_path, rb_path, policy, args)
        print("*** checkpoint loaded from {} ***".format(loaded_path))
    else:
        print("*** training from scratch ***")
        # init training vars
        total_timesteps = 0
        episode_num = 0
        num_samples = 0
        # different replay buffer for each env; avoid using too much memory if there are too many envs

    # Initialize training variables ================================================
    writer = SummaryWriter("%s/%s/" % (DATA_DIR, exp_name))
    s = time.time()
    # TODO: these variables may need to be moved inside the training loop
    timesteps_since_saving = 0
    this_training_timesteps = 0

    # Start training ===========================================================
    for env_name in envs_train_names:
        print("new env: {}".format(env_name))
        envs_train = []
        for i in range(args.num_parallel):
            envs_train.append(
                utils.makeEnvWrapperParallel(env_name, obs_max_len, args.seed))
        envs_train = SubprocVecEnv(envs_train, in_series=1)
        replay_buffer = utils.ReplayBuffer(max_size=args.rb_max)
        policy.change_morphology(args.graphs[env_name])
        policy.graph = args.graphs[env_name]
        task_timesteps = 0
        collect_done = True
        obs_list = envs_train.reset()
        done_list = [False for i in range(args.num_parallel)]
        episode_reward_list = [0 for i in range(args.num_parallel)]
        episode_timesteps_list = [0 for i in range(args.num_parallel)]
        # create reward buffer to store reward for one sub-env when it is not done
        episode_reward_list_buffer = [0 for i in range(args.num_parallel)]
        mean_reward = 0
        this_training_timesteps = 0
        while task_timesteps < args.max_timesteps:
            # train and log after one episode for each env
            if collect_done:
                # log updates and train policy
                if this_training_timesteps != 0:
                    mean_reward = np.array(episode_reward_list).mean()
                    if mean_reward < args.success_thres:
                        policy.train(replay_buffer,
                                     np.array(episode_timesteps_list).sum(),
                                     args.batch_size,
                                     args.discount,
                                     args.tau,
                                     args.policy_noise,
                                     args.noise_clip,
                                     args.policy_freq,
                                     graphs=args.graphs,
                                     env_name=env_name)
                    # add to tensorboard display
                    writer.add_scalar(
                        '{}_episode_mean_reward'.format(env_name), mean_reward,
                        task_timesteps)
                    # for i in range(args.num_parallel):
                    #     writer.add_scalar('{}_episode_reward_proc{}'.format(env_name,i), episode_reward_list[i], task_timesteps)
                    #     writer.add_scalar('{}_episode_len_proc{}'.format(env_name,i), episode_timesteps_list[i], task_timesteps)
                    # print to console
                    print(
                        "-" * 50 +
                        "\nExpID: {}, FPS: {:.2f}, TotalT: {}, taskT:{},  EpisodeNum: {}, SampleNum: {}, ReplayBSize: {}"
                        .format(args.expID, this_training_timesteps /
                                (time.time() - s), total_timesteps,
                                task_timesteps, episode_num, num_samples,
                                len(replay_buffer.storage)))
                    for i in range(args.num_parallel):
                        print("{} process {} === EpisodeT: {}, Reward: {:.2f}".
                              format(env_name, i, episode_timesteps_list[i],
                                     episode_reward_list[i]))
                    print("mean reward {}".format(mean_reward))
                    this_training_timesteps = 0
                    s = time.time()

                # save model and replay buffers
                if timesteps_since_saving >= args.save_freq:
                    timesteps_since_saving = 0
                    model_saved_path = cp.save_model_lifelong(
                        exp_path, policy, total_timesteps, episode_num,
                        num_samples, replay_buffer, env_name, args)
                    print("*** model saved to {} ***".format(model_saved_path))
                    rb_saved_path = cp.save_replay_buffer_lifelong(
                        rb_path, replay_buffer, env_name)
                    print("*** replay buffers saved to {} ***".format(
                        rb_saved_path))

                # reset training variables

                obs_list = envs_train.reset()
                done_list = [False for i in range(args.num_parallel)]
                episode_reward_list = [0 for i in range(args.num_parallel)]
                episode_timesteps_list = [0 for i in range(args.num_parallel)]
                episode_num += args.num_parallel
                # create reward buffer to store reward for one sub-env when it is not done
                episode_reward_list_buffer = [
                    0 for i in range(args.num_parallel)
                ]
                if mean_reward > args.success_thres:
                    print(
                        "satisfied training requirement, change to the next task"
                    )
                    break
            # start sampling ===========================================================
            # sample actions randomly for a while, then according to the policy
            if task_timesteps < args.start_timesteps:
                action_list = [
                    np.random.uniform(low=envs_train.action_space.low[0],
                                      high=envs_train.action_space.high[0],
                                      size=max_num_limbs)
                    for i in range(args.num_parallel)
                ]
            else:
                action_list = []
                for i in range(args.num_parallel):
                    # remove 0 padding of obs before feeding into the policy (trick for vectorized env)
                    obs = np.array(obs_list[i][:args.limb_obs_size *
                                               len(args.graphs[env_name])])
                    policy_action = policy.select_action(obs)
                    if args.expl_noise != 0:
                        policy_action = (policy_action + np.random.normal(
                            0, args.expl_noise, size=policy_action.size)).clip(
                                envs_train.action_space.low[0],
                                envs_train.action_space.high[0])
                    # add 0-padding to ensure that size is the same for all envs
                    policy_action = np.append(
                        policy_action,
                        np.array([
                            0
                            for i in range(max_num_limbs - policy_action.size)
                        ]))
                    action_list.append(policy_action)

            # perform action in the environment
            new_obs_list, reward_list, curr_done_list, _ = envs_train.step(
                action_list)

            # record if each env has ever been 'done'
            done_list = [
                done_list[i] or curr_done_list[i]
                for i in range(args.num_parallel)
            ]

            for i in range(args.num_parallel):
                # add the instant reward to the cumulative buffer
                # if a sub-env reports done at this moment, set its entry in episode_reward_list to the value accumulated in the buffer
                episode_reward_list_buffer[i] += reward_list[i]
                if curr_done_list[i] and episode_reward_list[i] == 0:
                    episode_reward_list[i] = episode_reward_list_buffer[i]
                    episode_reward_list_buffer[i] = 0
                writer.add_scalar(
                    '{}_instant_reward_pro{}'.format(env_name, i),
                    reward_list[i], task_timesteps)
                done_bool = float(curr_done_list[i])
                if episode_timesteps_list[i] + 1 == args.max_episode_steps:
                    done_bool = 0
                    done_list[i] = True
                # remove 0 padding before storing in the replay buffer (trick for vectorized env)
                num_limbs = len(args.graphs[env_name])
                obs = np.array(obs_list[i][:args.limb_obs_size * num_limbs])
                new_obs = np.array(new_obs_list[i][:args.limb_obs_size *
                                                   num_limbs])
                action = np.array(action_list[i][:num_limbs])
                # insert transition in the replay buffer
                replay_buffer.add(
                    (obs, new_obs, action, reward_list[i], done_bool))
                num_samples += 1
                # do not increment episode_timesteps if the sub-env has been 'done'
                if not done_list[i]:
                    episode_timesteps_list[i] += 1
                    task_timesteps += 1
                    total_timesteps += 1
                    this_training_timesteps += 1
                    timesteps_since_saving += 1

            obs_list = new_obs_list
            collect_done = all(done_list)

        model_saved_path = cp.save_model_lifelong(exp_path, policy,
                                                  total_timesteps, episode_num,
                                                  num_samples, replay_buffer,
                                                  env_name, args)
        policy.next_task()

    # save checkpoint after training ===========================================================
    model_saved_path = cp.save_model_lifelong(exp_path, policy,
                                              total_timesteps, episode_num,
                                              num_samples, replay_buffer,
                                              env_name, args)
    print("*** training finished and model saved to {} ***".format(
        model_saved_path))
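
One detail shared by every training loop above: when an episode ends only because it hit max_episode_steps, the transition is stored with done_bool = 0 so the critic still bootstraps from the next state, while the control-flow done flag is still set so the outer loop resets the episode. A compact restatement of that rule, assuming the same semantics:

# Restates the time-limit handling used in the loops above (same semantics, standalone form).
def timeout_aware_done(env_done, episode_timesteps, max_episode_steps):
    """Return (done_bool stored in the replay buffer, done used for control flow)."""
    done_bool = float(env_done)
    if episode_timesteps + 1 == max_episode_steps:
        # time-limit termination: keep bootstrapping from the next state,
        # but end the episode for the sampling loop
        done_bool = 0.0
        env_done = True
    return done_bool, env_done

# an episode truncated by the time limit rather than by the environment itself
print(timeout_aware_done(False, 999, 1000))   # -> (0.0, True)
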