def main():
    """Roll out an expert policy in the crowd simulator and save the demonstrations.

    The function
      1. parses the command-line options,
      2. prepares ``--output_dir`` (on a fresh run the three config files are
         copied into it; otherwise the copies already stored there are used),
      3. configures logging to both stdout and ``output.log``,
      4. builds the environment, the robot and the expert policy named by
         ``il_policy`` in the ``demonstration`` section of the imitate config,
      5. runs ``il_episodes`` episodes through an ``Explorer`` with
         ``imitation_learning=True``, and
      6. writes the replay memory and the value returned by the explorer to
         ``--memory_dir`` with ``torch.save``.

    Raises:
        SystemExit: via ``parser.error`` when the target policy is not
            trainable or a config path is missing.
    """
    parser = argparse.ArgumentParser('Parse configuration file')
    parser.add_argument('--policy', type=str, default='orca')
    parser.add_argument('--env_config',
                        type=str,
                        default='../configs/env.config')
    parser.add_argument('--policy_config',
                        type=str,
                        default='../configs/policy.config')
    parser.add_argument('--imitate_config',
                        type=str,
                        default='../configs/demonstrate.config')
    parser.add_argument('--num_frames', type=int, default=1)
    parser.add_argument('--output_dir',
                        type=str,
                        default='../data/output/temp')
    parser.add_argument('--memory_dir',
                        type=str,
                        default='../data/memory/temp')
    parser.add_argument('--human_type',
                        type=str,
                        default='normal',
                        choices=['aggressive', 'normal', 'cautious'])
    parser.add_argument('--expert_file',
                        type=str,
                        default='../data/expert/rl_model.pth')
    parser.add_argument('--freeze', default=False, action='store_true')
    parser.add_argument('--resume', default=False, action='store_true')
    parser.add_argument('--gpu', default=False, action='store_true')
    parser.add_argument('--debug', default=False, action='store_true')
    args = parser.parse_args()

    # configure paths
    make_new_dir = True
    if os.path.exists(args.output_dir):
        key = input(
            'Output directory already exists! Overwrite the folder? (y/n)')
        if key == 'y' and not args.resume:
            shutil.rmtree(args.output_dir)
        else:
            # Keep the existing folder and read the config copies stored in
            # it instead of the files given on the command line.
            make_new_dir = False
            args.env_config = os.path.join(args.output_dir,
                                           os.path.basename(args.env_config))
            args.policy_config = os.path.join(
                args.output_dir, os.path.basename(args.policy_config))
            args.imitate_config = os.path.join(
                args.output_dir, os.path.basename(args.imitate_config))
    if make_new_dir:
        os.makedirs(args.output_dir)
        shutil.copy(args.env_config, args.output_dir)
        shutil.copy(args.policy_config, args.output_dir)
        shutil.copy(args.imitate_config, args.output_dir)
    log_file = os.path.join(args.output_dir, 'output.log')

    # configure logging
    mode = 'a' if args.resume else 'w'
    file_handler = logging.FileHandler(log_file, mode=mode)
    stdout_handler = logging.StreamHandler(sys.stdout)
    level = logging.INFO if not args.debug else logging.DEBUG
    logging.basicConfig(level=level,
                        handlers=[stdout_handler, file_handler],
                        format='%(asctime)s, %(levelname)s: %(message)s',
                        datefmt="%Y-%m-%d %H:%M:%S")
    repo = git.Repo(search_parent_directories=True)
    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and args.gpu else "cpu")
    logging.info(' =========== Demonstrate %s ============ ', args.policy)
    logging.info('Current git head hash code: %s', (repo.head.object.hexsha))
    logging.info('Using device: %s', device)

    # configure environment
    env_config = configparser.RawConfigParser()
    env_config.read(args.env_config)
    env = CrowdSim()
    env.configure(env_config)
    # 'normal' humans keep the configured parameters (mods stays None).
    env_mods = None
    if args.human_type == 'aggressive':
        env_mods = {'safety_space': 0.03, 'time_horizon': 1}
    elif args.human_type == 'cautious':
        env_mods = {'safety_space': 0.3, 'time_horizon': 8}
    modify_env_params(env, mods=env_mods)
    robot = Robot(env_config, 'robot')
    env.set_robot(robot)

    # multi-frame env
    if args.num_frames > 1:
        logging.info("stack %d frames", args.num_frames)
        env = FrameStack(env, args.num_frames)

    # configure policy
    policy = policy_factory[args.policy]()
    if not policy.trainable:
        parser.error('Policy has to be trainable')
    if args.policy_config is None:
        parser.error(
            'Policy config has to be specified for a trainable network')
    policy_config = configparser.RawConfigParser()
    policy_config.read(args.policy_config)

    # read training parameters
    if args.imitate_config is None:
        parser.error(
            'Train config has to be specified for a trainable network')
    imitate_config = configparser.RawConfigParser()
    imitate_config.read(args.imitate_config)

    # configure demonstration and explorer
    memory = ReplayMemory(
        10000000)  # sufficiently large to store expert demonstration

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.memory_dir, exist_ok=True)
    if robot.visible:
        demonstration_file = os.path.join(args.memory_dir,
                                          'data_imit_visible.pt')
    else:
        demonstration_file = os.path.join(args.memory_dir,
                                          'data_imit_invisible.pt')

    il_episodes = imitate_config.getint('demonstration', 'il_episodes')
    expert_policy = imitate_config.get('demonstration', 'il_policy')

    expert_policy = policy_factory[expert_policy]()
    expert_policy.configure(policy_config)
    expert_policy.set_device(device)
    expert_policy.set_env(env)
    expert_policy.multiagent_training = policy.multiagent_training

    counter = imitate_config.getint('demonstration', 'counter_start')
    env.set_counter(counter)

    # The same config value is used whether or not the robot is visible.
    expert_policy.safety_space = imitate_config.getfloat(
        'demonstration', 'safety_space')
    robot.set_policy(expert_policy)

    noise_explore = imitate_config.getfloat('demonstration', 'noise_explore')

    gamma = policy_config.getfloat('rl', 'gamma')
    explorer = Explorer(env,
                        robot,
                        device,
                        args.num_frames,
                        memory,
                        gamma=gamma,
                        target_policy=policy)

    if expert_policy.trainable:
        if args.expert_file is None:
            logging.warning(
                'Trainable policy is NOT specified with a model weights directory'
            )
        else:
            # Deserialise once and map the tensors onto the selected device so
            # that a checkpoint saved on a GPU also loads on a CPU-only run.
            state_dict = torch.load(args.expert_file, map_location=device)
            try:
                expert_policy.get_model().load_state_dict(state_dict)
                logging.info('Loaded policy from %s', args.expert_file)
            except RuntimeError as e:
                # load_state_dict reports key/shape mismatches as RuntimeError;
                # fall back to a non-strict load of the matching entries.
                logging.warning(e)
                expert_policy.get_model().load_state_dict(state_dict,
                                                          strict=False)
                logging.info('Loaded policy partially from %s',
                             args.expert_file)
        robot.policy.set_epsilon(-1.0)

    # data collection
    # NOTE(review): this overrides the multiagent_training value copied from
    # the target policy above, assuming robot.policy is the expert policy set
    # via robot.set_policy -- confirm this is intended.
    robot.policy.multiagent_training = True
    human_obsv = explorer.run_k_episodes(il_episodes,
                                         'train',
                                         update_memory=True,
                                         imitation_learning=True,
                                         noise_explore=noise_explore)

    torch.save(memory.memory, demonstration_file)
    logging.info('Save memory to %s', demonstration_file)

    obs_file = os.path.join(args.memory_dir, 'data_traj.pt')
    torch.save(human_obsv, obs_file)
    logging.info('Save #%d episodes of obsv to %s', len(human_obsv), obs_file)
# Example no. 2
# 0
def main():
    """Train a policy with imitation learning followed by reinforcement learning.

    The function parses the command-line options, prepares ``--output_dir``
    (copying the three config files into it on a fresh run, or re-reading the
    copies stored there otherwise), configures logging, builds the policy,
    environment, trainer and explorer, initialises the model (resumed RL
    weights, stored IL weights, or a fresh imitation-learning phase), runs the
    RL loop for ``train_episodes`` episodes and finishes with a test run.

    Raises:
        SystemExit: via ``parser.error`` when the policy is not trainable or a
            config path is missing.
        FileNotFoundError: when ``--resume`` is given but ``rl_model.pth`` is
            not present in the output directory.
    """
    parser = argparse.ArgumentParser('Parse configuration file')
    parser.add_argument('--env_config', type=str, default='configs/env.config')
    parser.add_argument('--policy', type=str, default='cadrl')
    parser.add_argument('--policy_config', type=str, default='configs/policy.config')
    parser.add_argument('--train_config', type=str, default='configs/train.config')
    parser.add_argument('--output_dir', type=str, default='data/output')
    parser.add_argument('--weights', type=str)
    # A store_true flag must default to False, otherwise it can never be
    # switched off and every run is treated as a resumed run.
    parser.add_argument('--resume', default=False, action='store_true')
    parser.add_argument('--gpu', default=False, action='store_true')
    parser.add_argument('--debug', default=False, action='store_true')
    args = parser.parse_args()

    # configure paths
    make_new_dir = True
    if os.path.exists(args.output_dir):
        key = input('Output directory already exists! Overwrite the folder? (y/n)')
        if key == 'y' and not args.resume:
            shutil.rmtree(args.output_dir)
        else:
            # Keep the existing folder and read the config copies stored in it.
            make_new_dir = False
            args.env_config = os.path.join(args.output_dir, os.path.basename(args.env_config))
            args.policy_config = os.path.join(args.output_dir, os.path.basename(args.policy_config))
            args.train_config = os.path.join(args.output_dir, os.path.basename(args.train_config))
    if make_new_dir:
        os.makedirs(args.output_dir)
        shutil.copy(args.env_config, args.output_dir)
        shutil.copy(args.policy_config, args.output_dir)
        shutil.copy(args.train_config, args.output_dir)
    log_file = os.path.join(args.output_dir, 'output.log')
    il_weight_file = os.path.join(args.output_dir, 'il_model.pth')
    rl_weight_file = os.path.join(args.output_dir, 'rl_model.pth')

    # configure logging
    mode = 'a' if args.resume else 'w'
    file_handler = logging.FileHandler(log_file, mode=mode)
    stdout_handler = logging.StreamHandler(sys.stdout)
    level = logging.INFO if not args.debug else logging.DEBUG
    logging.basicConfig(level=level, handlers=[stdout_handler, file_handler],
                        format='%(asctime)s, %(levelname)s: %(message)s', datefmt="%Y-%m-%d %H:%M:%S")
    repo = git.Repo(search_parent_directories=True)
    # Pass the hash as a logging argument: str.format() does not substitute
    # '%s', so the hash was previously never written to the log.
    logging.info('Current git head hash code: %s', repo.head.object.hexsha)
    device = torch.device("cuda:0" if torch.cuda.is_available() and args.gpu else "cpu")
    logging.info('Using device: %s', device)

    # configure policy
    policy = policy_factory[args.policy]()
    if not policy.trainable:
        parser.error('Policy has to be trainable')
    if args.policy_config is None:
        parser.error('Policy config has to be specified for a trainable network')
    policy_config = configparser.RawConfigParser()
    policy_config.read(args.policy_config)
    policy.configure(policy_config)
    policy.set_device(device)

    # configure environment
    env_config = configparser.RawConfigParser()
    env_config.read(args.env_config)
    env = gym.make('CrowdSim-v0')
    env.configure(env_config)
    robot = Robot(env_config, 'robot')
    env.set_robot(robot)

    # read training parameters
    if args.train_config is None:
        parser.error('Train config has to be specified for a trainable network')
    train_config = configparser.RawConfigParser()
    train_config.read(args.train_config)
    rl_learning_rate = train_config.getfloat('train', 'rl_learning_rate')
    train_batches = train_config.getint('train', 'train_batches')
    train_episodes = train_config.getint('train', 'train_episodes')
    sample_episodes = train_config.getint('train', 'sample_episodes')
    target_update_interval = train_config.getint('train', 'target_update_interval')
    evaluation_interval = train_config.getint('train', 'evaluation_interval')
    capacity = train_config.getint('train', 'capacity')
    epsilon_start = train_config.getfloat('train', 'epsilon_start')
    epsilon_end = train_config.getfloat('train', 'epsilon_end')
    epsilon_decay = train_config.getfloat('train', 'epsilon_decay')
    checkpoint_interval = train_config.getint('train', 'checkpoint_interval')

    # configure trainer and explorer
    memory = ReplayMemory(capacity)
    model = policy.get_model()
    batch_size = train_config.getint('trainer', 'batch_size')
    trainer = Trainer(model, memory, device, batch_size)
    explorer = Explorer(env, robot, device, memory, policy.gamma, target_policy=policy)

    # imitation learning
    if args.resume:
        if not os.path.exists(rl_weight_file):
            # Fail with an explicit message instead of logging and then
            # crashing inside torch.load on the same missing file.
            logging.error('RL weights does not exist')
            raise FileNotFoundError('RL weights does not exist: ' + rl_weight_file)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only run.
        model.load_state_dict(torch.load(rl_weight_file, map_location=device))
        # Checkpoints of the resumed run go to a separate file.
        rl_weight_file = os.path.join(args.output_dir, 'resumed_rl_model.pth')
        logging.info('Load reinforcement learning trained weights. Resume training')
    elif os.path.exists(il_weight_file):
        model.load_state_dict(torch.load(il_weight_file, map_location=device))
        logging.info('Load imitation learning trained weights.')
    else:
        il_episodes = train_config.getint('imitation_learning', 'il_episodes')
        il_policy = train_config.get('imitation_learning', 'il_policy')
        il_epochs = train_config.getint('imitation_learning', 'il_epochs')
        il_learning_rate = train_config.getfloat('imitation_learning', 'il_learning_rate')
        trainer.set_learning_rate(il_learning_rate)
        if robot.visible:
            safety_space = 0
        else:
            safety_space = train_config.getfloat('imitation_learning', 'safety_space')
        il_policy = policy_factory[il_policy]()
        il_policy.multiagent_training = policy.multiagent_training
        il_policy.safety_space = safety_space
        robot.set_policy(il_policy)
        explorer.run_k_episodes(il_episodes, 'train', update_memory=True, imitation_learning=True)
        trainer.optimize_epoch(il_epochs)
        torch.save(model.state_dict(), il_weight_file)
        logging.info('Finish imitation learning. Weights saved.')
        logging.info('Experience set size: %d/%d', len(memory), memory.capacity)
    explorer.update_target_model(model)

    # reinforcement learning
    policy.set_env(env)
    robot.set_policy(policy)
    robot.print_info()
    trainer.set_learning_rate(rl_learning_rate)
    # fill the memory pool with some RL experience
    if args.resume:
        robot.policy.set_epsilon(epsilon_end)
        explorer.run_k_episodes(100, 'train', update_memory=True, episode=0)
        logging.info('Experience set size: %d/%d', len(memory), memory.capacity)
    episode = 0
    while episode < train_episodes:
        # Linear decay from epsilon_start to epsilon_end over the first
        # epsilon_decay episodes; a resumed run stays at epsilon_end.
        if args.resume or episode >= epsilon_decay:
            epsilon = epsilon_end
        else:
            epsilon = epsilon_start + (epsilon_end - epsilon_start) / epsilon_decay * episode
        robot.policy.set_epsilon(epsilon)

        # evaluate the model
        if episode % evaluation_interval == 0:
            explorer.run_k_episodes(env.case_size['val'], 'val', episode=episode)

        # sample k episodes into memory and optimize over the generated memory
        explorer.run_k_episodes(sample_episodes, 'train', update_memory=True, episode=episode)
        trainer.optimize_batch(train_batches)
        episode += 1

        if episode % target_update_interval == 0:
            explorer.update_target_model(model)

        # episode is at least 1 here, so no separate non-zero check is needed.
        if episode % checkpoint_interval == 0:
            torch.save(model.state_dict(), rl_weight_file)

    # final test
    explorer.run_k_episodes(env.case_size['test'], 'test', episode=episode)
# Example no. 3
# 0
def main(args):
    """Train a policy with imitation learning followed by reinforcement learning.

    The run is driven by a Python config module: ``args.config`` is copied to
    ``<output_dir>/config.py`` on a fresh run and that copy is imported to
    obtain ``PolicyConfig``, ``EnvConfig`` and ``TrainConfig``. After the
    model is initialised (resumed RL weights, stored IL weights, or a fresh
    imitation-learning phase) the RL loop runs for ``train_episodes``
    episodes with periodic validation, optional testing and checkpointing,
    and the best validation model (if any) is saved as ``best_val.pth``
    before the final test run.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``randomseed``, ``output_dir``, ``overwrite``, ``resume``,
            ``config``, ``debug``, ``gpu`` and ``test_after_every_eval``.

    Raises:
        FileNotFoundError: when ``args.resume`` is set but ``rl_model.pth``
            is not present in the output directory.
    """
    set_random_seeds(args.randomseed)
    # configure paths
    make_new_dir = True
    if os.path.exists(args.output_dir):
        if args.overwrite:
            shutil.rmtree(args.output_dir)
        else:
            key = input(
                'Output directory already exists! Overwrite the folder? (y/n)')
            if key == 'y' and not args.resume:
                shutil.rmtree(args.output_dir)
            else:
                make_new_dir = False
    if make_new_dir:
        os.makedirs(args.output_dir)
        shutil.copy(args.config, os.path.join(args.output_dir, 'config.py'))

    # From here on always use the config copy stored in the output directory.
    args.config = os.path.join(args.output_dir, 'config.py')
    log_file = os.path.join(args.output_dir, 'output.log')
    il_weight_file = os.path.join(args.output_dir, 'il_model.pth')
    rl_weight_file = os.path.join(args.output_dir, 'rl_model.pth')

    spec = importlib.util.spec_from_file_location('config', args.config)
    if spec is None:
        # NOTE(review): `parser` is not defined in this function; this relies
        # on a module-level `parser` existing -- confirm against the caller.
        parser.error('Config file not found.')
    config = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config)

    # configure logging
    mode = 'a' if args.resume else 'w'
    file_handler = logging.FileHandler(log_file, mode=mode)
    stdout_handler = logging.StreamHandler(sys.stdout)
    level = logging.INFO if not args.debug else logging.DEBUG
    logging.basicConfig(level=level,
                        handlers=[stdout_handler, file_handler],
                        format='%(asctime)s, %(levelname)s: %(message)s',
                        datefmt="%Y-%m-%d %H:%M:%S")
    repo = git.Repo(search_parent_directories=True)
    logging.info('Current git head hash code: %s', repo.head.object.hexsha)
    logging.info('Current config content is :%s', config)
    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and args.gpu else "cpu")
    logging.info('Using device: %s', device)
    writer = SummaryWriter(log_dir=args.output_dir)

    # configure policy
    policy_config = config.PolicyConfig()
    policy = policy_factory[policy_config.name]()
    if not policy.trainable:
        # NOTE(review): see the note on `parser` above.
        parser.error('Policy has to be trainable')
    policy.configure(policy_config)
    policy.set_device(device)

    # configure environment
    env_config = config.EnvConfig(args.debug)
    env = gym.make('CrowdSim-v0')
    env.configure(env_config)
    robot = Robot(env_config, 'robot')
    robot.time_step = env.time_step
    env.set_robot(robot)

    # read training parameters
    train_config = config.TrainConfig(args.debug)
    rl_learning_rate = train_config.train.rl_learning_rate
    train_batches = train_config.train.train_batches
    train_episodes = train_config.train.train_episodes
    sample_episodes = train_config.train.sample_episodes
    target_update_interval = train_config.train.target_update_interval
    evaluation_interval = train_config.train.evaluation_interval
    capacity = train_config.train.capacity
    epsilon_start = train_config.train.epsilon_start
    epsilon_end = train_config.train.epsilon_end
    epsilon_decay = train_config.train.epsilon_decay
    checkpoint_interval = train_config.train.checkpoint_interval

    # configure trainer and explorer
    memory = ReplayMemory(capacity)
    model = policy.get_model()
    batch_size = train_config.trainer.batch_size
    optimizer = train_config.trainer.optimizer
    if policy_config.name == 'model_predictive_rl':
        trainer = MPRLTrainer(
            model,
            policy.state_predictor,
            memory,
            device,
            policy,
            writer,
            batch_size,
            optimizer,
            env.human_num,
            reduce_sp_update_frequency=train_config.train.
            reduce_sp_update_frequency,
            freeze_state_predictor=train_config.train.freeze_state_predictor,
            detach_state_predictor=train_config.train.detach_state_predictor,
            share_graph_model=policy_config.model_predictive_rl.
            share_graph_model)
    else:
        trainer = VNRLTrainer(model, memory, device, policy, batch_size,
                              optimizer, writer)
    explorer = Explorer(env,
                        robot,
                        device,
                        writer,
                        memory,
                        policy.gamma,
                        target_policy=policy)

    # imitation learning
    if args.resume:
        if not os.path.exists(rl_weight_file):
            # Fail with an explicit message instead of logging and then
            # crashing inside torch.load on the same missing file.
            logging.error('RL weights does not exist')
            raise FileNotFoundError('RL weights does not exist: ' +
                                    rl_weight_file)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only run.
        model.load_state_dict(torch.load(rl_weight_file, map_location=device))
        # Checkpoints of the resumed run are written under a separate name.
        rl_weight_file = os.path.join(args.output_dir, 'resumed_rl_model.pth')
        logging.info(
            'Load reinforcement learning trained weights. Resume training')
    elif os.path.exists(il_weight_file):
        model.load_state_dict(torch.load(il_weight_file, map_location=device))
        logging.info('Load imitation learning trained weights.')
    else:
        il_episodes = train_config.imitation_learning.il_episodes
        il_policy = train_config.imitation_learning.il_policy
        il_epochs = train_config.imitation_learning.il_epochs
        il_learning_rate = train_config.imitation_learning.il_learning_rate
        trainer.set_learning_rate(il_learning_rate)
        if robot.visible:
            safety_space = 0
        else:
            safety_space = train_config.imitation_learning.safety_space
        il_policy = policy_factory[il_policy]()
        il_policy.multiagent_training = policy.multiagent_training
        il_policy.safety_space = safety_space
        robot.set_policy(il_policy)
        explorer.run_k_episodes(il_episodes,
                                'train',
                                update_memory=True,
                                imitation_learning=True)
        trainer.optimize_epoch(il_epochs)
        policy.save_model(il_weight_file)
        logging.info('Finish imitation learning. Weights saved.')
        logging.info('Experience set size: %d/%d', len(memory),
                     memory.capacity)

    trainer.update_target_model(model)

    # reinforcement learning
    policy.set_env(env)
    robot.set_policy(policy)
    robot.print_info()
    trainer.set_learning_rate(rl_learning_rate)
    # fill the memory pool with some RL experience
    if args.resume:
        robot.policy.set_epsilon(epsilon_end)
        explorer.run_k_episodes(100, 'train', update_memory=True, episode=0)
        logging.info('Experience set size: %d/%d', len(memory),
                     memory.capacity)
    episode = 0
    best_val_reward = -1
    best_val_model = None

    # evaluate the model after imitation learning
    if episode % evaluation_interval == 0:
        logging.info(
            'Evaluate the model instantly after imitation learning on the validation cases'
        )
        explorer.run_k_episodes(env.case_size['val'], 'val', episode=episode)
        explorer.log('val', episode // evaluation_interval)

        if args.test_after_every_eval:
            explorer.run_k_episodes(env.case_size['test'],
                                    'test',
                                    episode=episode,
                                    print_failure=True)
            explorer.log('test', episode // evaluation_interval)

    while episode < train_episodes:
        # Linear decay from epsilon_start to epsilon_end over the first
        # epsilon_decay episodes; a resumed run stays at epsilon_end.
        if args.resume or episode >= epsilon_decay:
            epsilon = epsilon_end
        else:
            epsilon = epsilon_start + (
                epsilon_end - epsilon_start) / epsilon_decay * episode
        robot.policy.set_epsilon(epsilon)

        # sample k episodes into memory and optimize over the generated memory
        explorer.run_k_episodes(sample_episodes,
                                'train',
                                update_memory=True,
                                episode=episode)
        explorer.log('train', episode)

        trainer.optimize_batch(train_batches, episode)
        episode += 1

        if episode % target_update_interval == 0:
            trainer.update_target_model(model)
        # evaluate the model
        if episode % evaluation_interval == 0:
            _, _, _, reward, _ = explorer.run_k_episodes(env.case_size['val'],
                                                         'val',
                                                         episode=episode)
            explorer.log('val', episode // evaluation_interval)

            # The best model is only tracked on episodes that are also
            # checkpoint episodes.
            if episode % checkpoint_interval == 0 and reward > best_val_reward:
                best_val_reward = reward
                best_val_model = copy.deepcopy(policy.get_state_dict())
            # test after every evaluation to check how the generalization
            # performance evolves
            if args.test_after_every_eval:
                explorer.run_k_episodes(env.case_size['test'],
                                        'test',
                                        episode=episode,
                                        print_failure=True)
                explorer.log('test', episode // evaluation_interval)

        # episode is at least 1 here, so no separate non-zero check is needed.
        if episode % checkpoint_interval == 0:
            current_checkpoint = episode // checkpoint_interval - 1
            # splitext strips only the final extension; splitting on the first
            # '.' truncated the path whenever a directory name contained a dot
            # (e.g. '../data/out/rl_model.pth' became '').
            save_every_checkpoint_rl_weight_file = os.path.splitext(
                rl_weight_file)[0] + '_' + str(current_checkpoint) + '.pth'
            policy.save_model(save_every_checkpoint_rl_weight_file)

    # test with the best val model
    if best_val_model is not None:
        policy.load_state_dict(best_val_model)
        torch.save(best_val_model, os.path.join(args.output_dir,
                                                'best_val.pth'))
        logging.info('Save the best val model with the reward: %s',
                     best_val_reward)
    explorer.run_k_episodes(env.case_size['test'],
                            'test',
                            episode=episode,
                            print_failure=True)
# Example no. 4
# 0
def main():
    """Entry point for TLSGAN training.

    Parses the command line, selects (or creates) an ``output<N>`` run
    directory below ``--output_dir``, sets up logging, the policy and the
    environment, and then does one of three things:

    * ``--resume``: load ``rl_model.pth`` from the run directory;
    * ``il_model.pth`` exists: load it and collect extra experience;
    * otherwise: run imitation learning from scratch and save the weights.

    The function prompts on stdin (``input``) when earlier ``output*`` run
    directories are found.

    Raises:
        FileNotFoundError: ``--resume`` was given but the run directory has
            no ``rl_model.pth``.
        NotImplementedError: always, right after the imitation-learning
            stage. The reinforcement-learning stage below that point is
            kept in place but is currently unreachable.
    """
    parser = argparse.ArgumentParser('Parse configuration file')
    parser.add_argument('--env_config', type=str, default='configs/env.config')
    parser.add_argument('--policy_config',
                        type=str,
                        default='configs/policy.config')
    parser.add_argument('--train_config',
                        type=str,
                        default='configs/train.config')
    parser.add_argument('--output_dir', type=str, default='data')
    parser.add_argument('--resume', default=False, action='store_true')
    parser.add_argument('--gpu', default=False, action='store_true')
    parser.add_argument('--debug', default=False, action='store_true')
    parser.add_argument('--gpu_num', default="0", type=str)
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num

    ########## CONFIGURE PATHS ##########

    def copy_configs():
        # Snapshot the three config files into the selected run directory.
        for config_path in (args.env_config, args.policy_config,
                            args.train_config):
            shutil.copy(config_path, args.output_dir)

    # Create the base directory on a first run so that listing it cannot fail.
    os.makedirs(args.output_dir, exist_ok=True)
    outputs = os.listdir(args.output_dir)

    # Collect the numeric suffixes of existing 'output<N>' run directories
    # ('' for a bare 'output' directory).
    output_nums = [
        name[name.find("output") + 6:] for name in outputs if 'output' in name
        and os.path.isdir(os.path.join(args.output_dir, name))
    ]
    num = 0
    if output_nums:
        num = max(
            [int(suffix) if suffix != '' else 0 for suffix in output_nums])
        if num == 0:
            # The latest run is the un-numbered 'output' directory.
            num = ''
        key = input('Continue from last output ? (y/n)')
        if key == 'y' and not args.resume:
            # Reuse the latest run directory and the configs stored in it.
            args.output_dir = os.path.join(args.output_dir,
                                           "output" + str(num))
            args.env_config = os.path.join(args.output_dir,
                                           os.path.basename(args.env_config))
            args.policy_config = os.path.join(
                args.output_dir, os.path.basename(args.policy_config))
            args.train_config = os.path.join(
                args.output_dir, os.path.basename(args.train_config))
        else:
            key = input(
                'Overwrite last output directory instead of new ? (y/n)')
            if key == 'y' and not args.resume:
                # Wipe the latest run directory and start over in place.
                args.output_dir = os.path.join(args.output_dir,
                                               'output' + str(num))
                shutil.rmtree(args.output_dir)
                os.makedirs(args.output_dir)
            else:
                # Start a fresh run directory with the next free number.
                num = num + 1 if num else 1
                args.output_dir = os.path.join(args.output_dir,
                                               'output' + str(num))
                os.makedirs(args.output_dir)
            copy_configs()
    else:
        # No previous run: create the first, un-numbered run directory.
        args.output_dir = os.path.join(args.output_dir, 'output')
        os.makedirs(args.output_dir)
        copy_configs()

    log_file = os.path.join(args.output_dir, 'output.log')
    il_weight_file = os.path.join(args.output_dir, 'il_model.pth')
    rl_weight_file = os.path.join(args.output_dir, 'rl_model.pth')

    ########## CONFIGURE LOGGING ##########

    mode = 'a' if args.resume else 'w'
    logger = logging.getLogger('train_sgan')
    # Keep records out of the root logger's handlers; this logger has its
    # own file and stdout handlers attached below.
    logger.propagate = False
    level = logging.INFO if not args.debug else logging.DEBUG
    logger.setLevel(level)
    file_handler = logging.FileHandler(log_file, mode=mode)
    file_handler.setLevel(level)
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(level)
    formatter = logging.Formatter('%(asctime)s, %(levelname)s: %(message)s')
    stdout_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.addHandler(stdout_handler)

    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and args.gpu else "cpu")

    ########## CONFIGURE POLICY ##########

    policy = TLSGAN(device, logger)
    if args.policy_config is None:
        parser.error(
            'Policy config has to be specified for a trainable network')
    policy_config = configparser.RawConfigParser()
    policy_config.read(args.policy_config)

    policy.configure(policy_config)
    policy.set_device(device)

    ########## CONFIGURE ENVIRONMENT ##########

    env_config = configparser.RawConfigParser()
    env_config.read(args.env_config)
    env = gym.make('CrowdSim-v0')
    env.configure(env_config)
    robot = Robot(env_config, 'robot')
    env.set_robot(robot)

    ########## READ RL PARAMETERS ##########
    if args.train_config is None:
        parser.error(
            'Train config has to be specified for a trainable network')
    train_config = configparser.RawConfigParser()
    train_config.read(args.train_config)
    rl_learning_rate = train_config.getfloat('train', 'rl_learning_rate')
    train_batches = train_config.getint('train', 'train_batches')
    train_episodes = train_config.getint('train', 'train_episodes')
    sample_episodes = train_config.getint('train', 'sample_episodes')
    target_update_interval = train_config.getint('train',
                                                 'target_update_interval')
    evaluation_interval = train_config.getint('train', 'evaluation_interval')
    capacity = train_config.getint('train', 'capacity')
    epsilon_start = train_config.getfloat('train', 'epsilon_start')
    epsilon_end = train_config.getfloat('train', 'epsilon_end')
    epsilon_decay = train_config.getfloat('train', 'epsilon_decay')
    checkpoint_interval = train_config.getint('train', 'checkpoint_interval')

    ########## CONFIGURE MEMORY, TRAINER, EXPLORER ##########
    train_memory = ReplayMemory(capacity)
    val_memory = ReplayMemory(capacity)
    model = policy.get_model()
    batch_size = train_config.getint('trainer', 'batch_size')
    trainer = Trainer(model, train_memory, val_memory, device, batch_size,
                      args.output_dir, logger)
    explorer = Explorer(env, robot, policy.policy_learning, device,
                        train_memory, val_memory, policy.gamma, logger, policy,
                        policy.obs_len)

    logger.info('Using device: %s', device)

    ########## IMITATION LEARNING ##########
    if args.resume:
        if not os.path.exists(rl_weight_file):
            logger.error('RL weights does not exist')
            # Fail here with a clear message instead of inside torch.load.
            raise FileNotFoundError(
                'RL weight file not found: ' + rl_weight_file)
        model.load_state_dict(torch.load(rl_weight_file, map_location=device))
        # Checkpoints of the resumed run go to a separate file.
        rl_weight_file = os.path.join(args.output_dir, 'resumed_rl_model.pth')
        logger.info(
            'Load reinforcement learning trained weights. Resume training')
    elif os.path.exists(il_weight_file):
        model.load_state_dict(torch.load(il_weight_file, map_location=device),
                              strict=False)

        logger.info('Load imitation learning trained weights.')
        if robot.visible:
            safety_space = 0
        else:
            safety_space = train_config.getfloat('imitation_learning',
                                                 'safety_space')
        il_policy = train_config.get('imitation_learning', 'il_policy')
        il_policy = policy_factory[il_policy]()
        il_policy.multiagent_training = policy.multiagent_training
        il_policy.safety_space = safety_space
        robot.set_test_policy(il_policy)
        robot.set_policy(policy)
        policy.set_env(env)
        robot.policy.set_epsilon(epsilon_end)

        # Five identical collection passes with memory updates enabled.
        # NOTE(review): the two leading positional arguments presumably are
        # the train / validation episode counts -- confirm against Explorer.
        for _ in range(5):
            explorer.run_k_episodes(100,
                                    0,
                                    'val',
                                    update_memory=True,
                                    imitation_learning=True,
                                    with_render=False)

        logger.info('Experience training set size: %d/%d', len(train_memory),
                    train_memory.capacity)
    else:
        # Use dedicated names here so the RL 'train_episodes' read from the
        # [train] section above is not overwritten by the imitation count.
        il_train_episodes = train_config.getint('imitation_learning',
                                                'train_episodes')
        validation_episodes = train_config.getint('imitation_learning',
                                                  'validation_episodes')
        logger.info(
            'Starting imitation learning on %d training episodes and %d validation episodes',
            il_train_episodes, validation_episodes)

        il_policy = train_config.get('imitation_learning', 'il_policy')
        il_epochs = train_config.getint('imitation_learning', 'il_epochs')
        il_learning_rate = train_config.getfloat('imitation_learning',
                                                 'il_learning_rate')
        trainer.set_learning_rate(il_learning_rate)
        if robot.visible:
            safety_space = 0
        else:
            safety_space = train_config.getfloat('imitation_learning',
                                                 'safety_space')
        il_policy = policy_factory[il_policy]()
        il_policy.multiagent_training = policy.multiagent_training
        il_policy.safety_space = safety_space
        robot.set_test_policy(il_policy)
        robot.set_policy(il_policy)
        explorer.run_k_episodes(il_train_episodes,
                                validation_episodes,
                                'train',
                                update_memory=True,
                                imitation_learning=True,
                                with_render=False)
        trainer.optimize_epoch_il(il_epochs, robot, with_validation=True)
        torch.save(model.state_dict(), il_weight_file)
        logger.info('Finish imitation learning. Weights saved.')
        logger.info('Experience train_set size: %d/%d', len(train_memory),
                    train_memory.capacity)
        logger.info('Experience validation_set size: %d/%d', len(val_memory),
                    val_memory.capacity)

    ########## REINFORCEMENT LEARNING ##########
    # The RL stage is disabled: everything below this raise is unreachable
    # and is kept only as the intended implementation.
    raise NotImplementedError

    policy.set_env(env)
    robot.set_policy(policy)

    robot.print_info()
    trainer.set_learning_rate(rl_learning_rate)

    explorer.update_target_model(model)
    robot.policy.set_epsilon(epsilon_end)

    # RL uses its own replay memory, shared by the explorer and the trainer.
    rl_memory = ReplayMemory(capacity)
    explorer.train_memory = rl_memory
    trainer.train_memory = rl_memory

    episode = 0
    while episode < train_episodes:
        if args.resume:
            epsilon = epsilon_end
        else:
            # Linear schedule from epsilon_start to epsilon_end over the
            # first epsilon_decay episodes, constant afterwards.
            if episode < epsilon_decay:
                epsilon = epsilon_start + (
                    epsilon_end - epsilon_start) / epsilon_decay * episode
            else:
                epsilon = epsilon_end
        robot.policy.set_epsilon(epsilon)

        # evaluate the model
        if episode % evaluation_interval == 0:
            explorer.run_k_episodes(env.case_size['val'],
                                    0,
                                    'val',
                                    episode=episode,
                                    with_render=False)

        # sample k episodes into memory and optimize over the generated memory
        explorer.run_k_episodes(sample_episodes,
                                0,
                                'train',
                                update_memory=True,
                                episode=episode)
        trainer.optimize_batch_rl(train_batches)
        rl_memory.clear()

        episode += 1

        if episode % target_update_interval == 0:
            explorer.update_target_model(model)

        if episode != 0 and episode % checkpoint_interval == 0:
            torch.save(model.state_dict(), rl_weight_file)

    # final test
    explorer.run_k_episodes(env.case_size['test'], 0, 'test', episode=episode)
Ejemplo n.º 5
0
    policy.model.load_state_dict(torch.load(weights_path))

    # domain/env config (initially same as used in training)
    env_config = configparser.RawConfigParser()
    env_config.read('configs/env.config')
    if args.policy == 'orca':
        env = CrowdSim()
    else:
        env = CrowdSim_SF()
    env.configure(env_config)
    robot = Robot(env_config, 'robot')
    robot.set_policy(policy)
    env.set_robot(robot)

    # explorer, memory is not used as we will only use the explorer in test mode
    memory = ReplayMemory(100000)
    explorer = Explorer(env,
                        robot,
                        device,
                        memory,
                        policy.gamma,
                        target_policy=policy)

    policy.set_env(env)
    robot.set_policy(policy)

    mods = {
        'neighbor_dist': args.neighd,
        'max_neighbors': args.maxn,
        'time_horizon': args.timeh,
        'max_speed': 1
Ejemplo n.º 6
0
def main(args):
    """Train a crowd-navigation RL policy described by ``args.config``.

    The output directory is wiped and recreated, the config module is copied
    into it and imported, and logging, policy, environment, trainer and
    explorer are set up. The function then runs the epsilon-greedy training
    loop with periodic validation, optional testing and checkpointing.
    Running averages of the training statistics are written to
    ``data.txt`` and plotted to ``reward_record.jpg`` in the output
    directory. Finally the best validation model (if any) is saved as
    ``best_val.pth`` and evaluated on the test cases.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``randomseed``, ``output_dir``, ``config``, ``resume``,
            ``debug``, ``gpu``, ``re_collision``, ``re_arrival``,
            ``goal_weight``, ``safe_weight``, ``human_num`` and
            ``test_after_every_eval``.
    """
    set_random_seeds(args.randomseed)
    # configure paths
    # NOTE(review): an existing output directory is always removed,
    # whatever the value of an 'overwrite' option -- confirm this is
    # intended.
    if os.path.exists(args.output_dir):
        shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir)
    shutil.copy(args.config, os.path.join(args.output_dir, 'config.py'))

    # # insert the arguments from command line to the config file
    # with open(os.path.join(args.output_dir, 'config.py'), 'r') as fo:
    #     config_text = fo.read()
    # search_pairs = {r"gcn.X_dim = \d*": "gcn.X_dim = {}".format(args.X_dim),
    #                 r"gcn.num_layer = \d": "gcn.num_layer = {}".format(args.layers),
    #                 r"gcn.similarity_function = '\w*'": "gcn.similarity_function = '{}'".format(args.sim_func),
    #                 r"gcn.layerwise_graph = \w*": "gcn.layerwise_graph = {}".format(args.layerwise_graph),
    #                 r"gcn.skip_connection = \w*": "gcn.skip_connection = {}".format(args.skip_connection)}
    #
    # for find, replace in search_pairs.items():
    #     config_text = re.sub(find, replace, config_text)
    #
    # with open(os.path.join(args.output_dir, 'config.py'), 'w') as fo:
    #     fo.write(config_text)

    # From here on the copy inside the output directory is the config used.
    args.config = os.path.join(args.output_dir, 'config.py')
    log_file = os.path.join(args.output_dir, 'output.log')
    in_weight_file = os.path.join(args.output_dir, 'in_model.pth')
    il_weight_file = os.path.join(args.output_dir, 'il_model.pth')
    rl_weight_file = os.path.join(args.output_dir, 'rl_model.pth')

    spec = importlib.util.spec_from_file_location('config', args.config)
    if spec is None:
        # NOTE(review): 'parser' is not defined in this function; this
        # presumably relies on a module-level parser -- confirm.
        parser.error('Config file not found.')
    config = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config)

    # configure logging
    mode = 'a' if args.resume else 'w'
    file_handler = logging.FileHandler(log_file, mode=mode)
    stdout_handler = logging.StreamHandler(sys.stdout)
    level = logging.INFO if not args.debug else logging.DEBUG
    logging.basicConfig(level=level,
                        handlers=[stdout_handler, file_handler],
                        format='%(asctime)s, %(levelname)s: %(message)s',
                        datefmt="%Y-%m-%d %H:%M:%S")
    repo = git.Repo(search_parent_directories=True)
    logging.info('Current git head hash code: {}'.format(
        repo.head.object.hexsha))
    # Log the values actually passed to this function rather than a
    # module-level namespace.
    logging.info('Current random seed: {}'.format(args.randomseed))
    logging.info('Current safe_weight: {}'.format(args.safe_weight))
    logging.info('Current goal_weight: {}'.format(args.goal_weight))
    logging.info('Current re_collision: {}'.format(args.re_collision))
    logging.info('Current re_arrival: {}'.format(args.re_arrival))
    logging.info('Current config content is :{}'.format(config))
    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and args.gpu else "cpu")
    logging.info('Using device: %s', device)
    writer = SummaryWriter(log_dir=args.output_dir)

    # configure policy
    policy_config = config.PolicyConfig()
    policy = policy_factory[policy_config.name]()
    if not policy.trainable:
        # NOTE(review): same module-level 'parser' assumption as above.
        parser.error('Policy has to be trainable')
    policy.configure(policy_config, device)
    policy.set_device(device)

    # configure environment; command-line reward settings override the config
    env_config = config.EnvConfig(args.debug)
    env_config.reward.collision_penalty = args.re_collision
    env_config.reward.success_reward = args.re_arrival
    env_config.reward.goal_factor = args.goal_weight
    env_config.reward.discomfort_penalty_factor = args.safe_weight
    env_config.sim.human_num = args.human_num

    env = gym.make('CrowdSim-v0')
    env.configure(env_config)
    robot = Robot(env_config, 'robot')
    robot.time_step = env.time_step
    env.set_robot(robot)

    reward_estimator = Reward_Estimator()
    reward_estimator.configure(env_config)
    policy.reward_estimator = reward_estimator
    # for continuous action
    action_dim = env.action_space.shape[0]
    max_action = env.action_space.high
    min_action = env.action_space.low

    # read training parameters
    train_config = config.TrainConfig(args.debug)
    rl_learning_rate = train_config.train.rl_learning_rate
    train_batches = train_config.train.train_batches
    train_episodes = train_config.train.train_episodes
    sample_episodes = train_config.train.sample_episodes
    target_update_interval = train_config.train.target_update_interval
    evaluation_interval = train_config.train.evaluation_interval
    capacity = train_config.train.capacity
    epsilon_start = train_config.train.epsilon_start
    epsilon_end = train_config.train.epsilon_end
    epsilon_decay = train_config.train.epsilon_decay
    checkpoint_interval = train_config.train.checkpoint_interval

    # configure trainer and explorer
    memory = ReplayMemory(capacity)
    model = policy.get_model()
    batch_size = train_config.trainer.batch_size
    optimizer = train_config.trainer.optimizer
    print(policy_config.name)
    # These policies all train a state predictor and take the same trailing
    # constructor arguments; only the trainer class and the leading network
    # argument(s) differ.
    state_predictor_policies = ('model_predictive_rl', 'tree_search_rl',
                                'gat_predictive_rl', 'td3_rl')
    if policy_config.name in state_predictor_policies:
        if policy_config.name == 'td3_rl':
            policy.set_action(action_dim, max_action, min_action)
            trainer_cls = TD3RLTrainer
            networks = (policy.actor, policy.critic)
        elif policy_config.name == 'tree_search_rl':
            trainer_cls = TSRLTrainer
            networks = (model, )
        else:
            trainer_cls = MPRLTrainer
            networks = (model, )
        trainer = trainer_cls(
            *networks,
            policy.state_predictor,
            memory,
            device,
            policy,
            writer,
            batch_size,
            optimizer,
            env.human_num,
            reduce_sp_update_frequency=train_config.train.
            reduce_sp_update_frequency,
            freeze_state_predictor=train_config.train.freeze_state_predictor,
            detach_state_predictor=train_config.train.detach_state_predictor,
            share_graph_model=policy_config.model_predictive_rl.
            share_graph_model)
    else:
        trainer = VNRLTrainer(model, memory, device, policy, batch_size,
                              optimizer, writer)
    explorer = Explorer(env,
                        robot,
                        device,
                        writer,
                        memory,
                        policy.gamma,
                        target_policy=policy)
    # Save the initial (untrained) weights.
    policy.save_model(in_weight_file)
    # imitation learning
    # if args.resume:
    #     if not os.path.exists(rl_weight_file):
    #         logging.error('RL weights does not exist')
    #     policy.load_state_dict(torch.load(rl_weight_file))
    #     model = policy.get_model()
    #     rl_weight_file = os.path.join(args.output_dir, 'resumed_rl_model.pth')
    #     logging.info('Load reinforcement learning trained weights. Resume training')
    # elif os.path.exists(il_weight_file):
    #     policy.load_state_dict(torch.load(rl_weight_file))
    #     model = policy.get_model()
    #     logging.info('Load imitation learning trained weights.')
    # else:
    #     il_episodes = train_config.imitation_learning.il_episodes
    #     il_policy = train_config.imitation_learning.il_policy
    #     il_epochs = train_config.imitation_learning.il_epochs
    #     il_learning_rate = train_config.imitation_learning.il_learning_rate
    #     trainer.set_learning_rate(il_learning_rate)
    #     if robot.visible:
    #         safety_space = 0
    #     else:
    #         safety_space = train_config.imitation_learning.safety_space
    #     il_policy = policy_factory[il_policy]()
    #     il_policy.set_common_parameters(policy_config)
    #     il_policy.multiagent_training = policy.multiagent_training
    #     il_policy.safety_space = safety_space
    #     robot.set_policy(il_policy)
    #     explorer.run_k_episodes(il_episodes, 'train', update_memory=True, imitation_learning=True)
    #     trainer.optimize_epoch(il_epochs)
    #     policy.save_model(il_weight_file)
    #     logging.info('Finish imitation learning. Weights saved.')
    #     logging.info('Experience set size: %d/%d', len(memory), memory.capacity)

    trainer.update_target_model(model)

    # reinforcement learning
    policy.set_env(env)
    robot.set_policy(policy)
    robot.print_info()
    trainer.set_rl_learning_rate(rl_learning_rate)
    # fill the memory pool with some RL experience
    if args.resume:
        robot.policy.set_epsilon(epsilon_end)
        explorer.run_k_episodes(100, 'train', update_memory=True, episode=0)
        logging.info('Experience set size: %d/%d', len(memory),
                     memory.capacity)
    episode = 0
    best_val_return = -1
    best_val_model = None
    # evaluate the model after imitation learning

    if episode % evaluation_interval == 0:
        logging.info(
            'Evaluate the model instantly after imitation learning on the validation cases'
        )
        explorer.run_k_episodes(env.case_size['val'], 'val', episode=episode)
        explorer.log('val', episode // evaluation_interval)

        if args.test_after_every_eval:
            explorer.run_k_episodes(env.case_size['test'],
                                    'test',
                                    episode=episode,
                                    print_failure=True)
            explorer.log('test', episode // evaluation_interval)

    # Per-interval averages of the training statistics.
    reward_rec = []
    return_rec = []
    discom_tim_rec = []
    nav_time_rec = []
    total_time_rec = []
    # Sums accumulated since the last report.
    reward_in_last_interval = 0
    return_in_last_interval = 0
    nav_time__in_last_interval = 0
    discom_time_in_last_interval = 0
    total_time_in_last_interval = 0
    eps_count = 0
    # Number of training iterations averaged into one report line.
    interval = 100
    # The context manager closes data.txt even if training raises.
    with open(os.path.join(args.output_dir, 'data.txt'), 'w') as fw:
        print("%f %f %f %f %f" % (0, 0, 0, 0, 0), file=fw)
        while episode < train_episodes:
            if args.resume:
                epsilon = epsilon_end
            else:
                # Linear schedule from epsilon_start to epsilon_end over the
                # first epsilon_decay episodes, constant afterwards.
                if episode < epsilon_decay:
                    epsilon = epsilon_start + (
                        epsilon_end - epsilon_start) / epsilon_decay * episode
                else:
                    epsilon = epsilon_end
            robot.policy.set_epsilon(epsilon)

            # sample k episodes into memory and optimize over the generated memory
            _, _, nav_time, sum_reward, ave_return, discom_time, total_time = \
                explorer.run_k_episodes(sample_episodes, 'train', update_memory=True, episode=episode)
            eps_count = eps_count + 1
            reward_in_last_interval = reward_in_last_interval + sum_reward
            return_in_last_interval = return_in_last_interval + ave_return
            nav_time__in_last_interval = nav_time__in_last_interval + nav_time
            discom_time_in_last_interval = discom_time_in_last_interval + discom_time
            total_time_in_last_interval = total_time_in_last_interval + total_time
            if eps_count % interval == 0:
                # Turn the accumulated sums into averages over the interval.
                reward_rec.append(reward_in_last_interval / interval)
                return_rec.append(return_in_last_interval / interval)
                discom_tim_rec.append(discom_time_in_last_interval / interval)
                nav_time_rec.append(nav_time__in_last_interval / interval)
                total_time_rec.append(total_time_in_last_interval / interval)
                logging.info(
                    'Train in episode %d reward in last 100 episodes %f %f %f %f %f',
                    eps_count, reward_rec[-1], return_rec[-1],
                    discom_tim_rec[-1], nav_time_rec[-1], total_time_rec[-1])
                print("%f %f %f %f %f" %
                      (reward_rec[-1], return_rec[-1], discom_tim_rec[-1],
                       nav_time_rec[-1], total_time_rec[-1]),
                      file=fw)
                reward_in_last_interval = 0
                return_in_last_interval = 0
                nav_time__in_last_interval = 0
                discom_time_in_last_interval = 0
                total_time_in_last_interval = 0
                # Round the y-axis limits outwards to multiples of 5.
                min_reward = (np.min(reward_rec) // 5.0) * 5.0
                max_reward = (np.max(reward_rec) // 5.0 + 1) * 5.0
                pos = np.array(range(1, len(reward_rec) + 1)) * interval
                plt.plot(pos,
                         reward_rec,
                         color='r',
                         marker='.',
                         linestyle='dashed')
                plt.axis([0, eps_count, min_reward, max_reward])
                savefig(args.output_dir + "/reward_record.jpg")
            explorer.log('train', episode)

            trainer.optimize_batch(train_batches, episode)
            episode += 1

            if episode % target_update_interval == 0:
                trainer.update_target_model(model)
            # evaluate the model
            if episode % evaluation_interval == 0:
                _, _, _, _, average_return, _, _ = explorer.run_k_episodes(
                    env.case_size['val'], 'val', episode=episode)
                explorer.log('val', episode // evaluation_interval)

                # Only episodes that are also checkpoint episodes can become
                # the best validation model.
                if episode % checkpoint_interval == 0 and average_return > best_val_return:
                    best_val_return = average_return
                    best_val_model = copy.deepcopy(policy.get_state_dict())
                # test after every evaluation to check how the generalization performance evolves
                if args.test_after_every_eval:
                    explorer.run_k_episodes(env.case_size['test'],
                                            'test',
                                            episode=episode,
                                            print_failure=True)
                    explorer.log('test', episode // evaluation_interval)

            if episode != 0 and episode % checkpoint_interval == 0:
                # Checkpoints are numbered from 0: rl_model_0.pth, rl_model_1.pth, ...
                current_checkpoint = episode // checkpoint_interval - 1
                save_every_checkpoint_rl_weight_file = rl_weight_file.split(
                    '.')[0] + '_' + str(current_checkpoint) + '.pth'
                policy.save_model(save_every_checkpoint_rl_weight_file)

    # test with the best val model
    if best_val_model is not None:
        policy.load_state_dict(best_val_model)
        torch.save(best_val_model, os.path.join(args.output_dir,
                                                'best_val.pth'))
        logging.info('Save the best val model with the return: {}'.format(
            best_val_return))
    explorer.run_k_episodes(env.case_size['test'],
                            'test',
                            episode=episode,
                            print_failure=True)
Ejemplo n.º 7
0
def main():
    """Train a crowd-navigation policy: imitation learning, then RL.

    The function is a command-line entry point and proceeds in stages:

    1. Parse command-line arguments.
    2. Prepare ``args.output_dir``.  If it already exists the user is asked
       (via ``input``) whether to overwrite it; otherwise the config files
       already stored in the output directory are used.
    3. Configure logging (stdout + ``output.log``) and the torch device.
    4. Build the policy, environment, robot, replay memory, trainer and
       explorer from the policy / env / train config files, overriding some
       config options with command-line values.
    5. Initialise the model: load RL weights (``--resume``), load existing
       imitation-learning weights, or run imitation learning from scratch.
    6. Run the reinforcement-learning loop with periodic validation,
       target-model updates and checkpoints, followed by a final test run.

    Exits through ``parser.error`` (``SystemExit``) when the policy is not
    trainable, a required config path is missing, or ``--resume`` is given
    but no ``rl_model.pth`` exists in the output directory.
    """
    parser = argparse.ArgumentParser("Parse configuration file")
    parser.add_argument("--env_config", type=str, default="configs/env.config")
    parser.add_argument("--policy", type=str, default="cadrl")
    parser.add_argument("--policy_config",
                        type=str,
                        default="configs/policy.config")
    # parser.add_argument('--train_config', type=str, default='configs/train-test.config')
    parser.add_argument("--train_config",
                        type=str,
                        default="configs/train.config")
    parser.add_argument("--output_dir",
                        type=str,
                        default="data2/actenvcarl_alltf")
    # NOTE(review): --weights and --phase are parsed but not read anywhere in
    # this function.
    parser.add_argument("--weights", type=str)
    parser.add_argument("--resume", default=False, action="store_true")
    # NOTE(review): store_true combined with default=True means these two
    # flags are always True and cannot be switched off from the command line.
    parser.add_argument("--gpu", default=True, action="store_true")
    parser.add_argument("--debug", default=True, action="store_true")
    parser.add_argument("--phase", type=str, default="train")
    parser.add_argument("--test_policy_flag", type=str, default="1")
    parser.add_argument("--optimizer", type=str, default="SGD")
    parser.add_argument("--multi_process", type=str, default="average")

    # reward parameters
    parser.add_argument("--agent_timestep", type=float, default=0.4)
    parser.add_argument("--human_timestep", type=float, default=0.5)
    parser.add_argument("--reward_increment", type=float, default=4.0)
    parser.add_argument("--position_variance", type=float, default=4.0)
    parser.add_argument("--direction_variance", type=float, default=4.0)

    # visible or not
    # NOTE(review): args.visible is not read anywhere in this function.
    parser.add_argument("--visible", default=False, action="store_true")

    # act step cnt
    parser.add_argument("--act_steps", type=int, default=1)
    parser.add_argument("--act_fixed", default=False, action="store_true")

    args = parser.parse_args()

    agent_timestep = args.agent_timestep
    human_timestep = args.human_timestep
    reward_increment = args.reward_increment
    position_variance = args.position_variance
    direction_variance = args.direction_variance

    # agent_visable = args.visable

    optimizer_type = args.optimizer

    # configure paths
    make_new_dir = True
    if os.path.exists(args.output_dir):
        key = input(
            "Output directory already exists! Overwrite the folder? (y/n)")
        if key == "y" and not args.resume:
            # Start from scratch: the directory is recreated below.
            shutil.rmtree(args.output_dir)
        else:
            # Keep the directory and use the config copies stored inside it
            # instead of the paths given on the command line.
            make_new_dir = False
            args.env_config = os.path.join(args.output_dir,
                                           os.path.basename(args.env_config))
            args.policy_config = os.path.join(
                args.output_dir, os.path.basename(args.policy_config))
            args.train_config = os.path.join(
                args.output_dir, os.path.basename(args.train_config))
            # Source snapshots are refreshed even when the directory is kept.
            shutil.copy("utils/trainer.py", args.output_dir)
            shutil.copy("policy/" + args.policy + ".py", args.output_dir)

    if make_new_dir:
        # Snapshot the configs and the relevant source files next to the
        # results.
        os.makedirs(args.output_dir)
        shutil.copy(args.env_config, args.output_dir)
        shutil.copy(args.policy_config, args.output_dir)
        shutil.copy(args.train_config, args.output_dir)
        shutil.copy("utils/trainer.py", args.output_dir)
        shutil.copy("policy/" + args.policy + ".py", args.output_dir)
    log_file = os.path.join(args.output_dir, "output.log")
    il_weight_file = os.path.join(args.output_dir, "il_model.pth")
    rl_weight_file = os.path.join(args.output_dir, "rl_model.pth")

    # configure logging
    # Append to the existing log when resuming, otherwise start a fresh one.
    mode = "a" if args.resume else "w"
    file_handler = logging.FileHandler(log_file, mode=mode)
    stdout_handler = logging.StreamHandler(sys.stdout)
    level = logging.INFO if not args.debug else logging.DEBUG
    logging.basicConfig(
        level=level,
        handlers=[stdout_handler, file_handler],
        format="%(asctime)s, %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # repo = git.Repo(search_parent_directories=True)
    # logging.info('Current git head hash code: %s'.format(repo.head.object.hexsha))
    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and args.gpu else "cpu")
    logging.info("Using device: %s", device)

    # configure policy
    policy = policy_factory[args.policy]()
    # policy.set_phase(args.phase)
    print("policy type: ", type(policy))

    if not policy.trainable:
        parser.error("Policy has to be trainable")
    if args.policy_config is None:
        parser.error(
            "Policy config has to be specified for a trainable network")
    policy_config = configparser.RawConfigParser()
    policy_config.read(args.policy_config)

    # Command-line overrides for the policy config.
    # NOTE(review): the section name "actenvcarl" is hard-coded regardless of
    # --policy; set() raises NoSectionError if the config lacks that section.
    # NOTE(review): act_steps (int) and act_fixed (bool) are stored as
    # non-string values; RawConfigParser allows this, but getboolean() on a
    # real bool would fail -- confirm how the policy reads these options.
    policy_config.set("actenvcarl", "test_policy_flag", args.test_policy_flag)
    policy_config.set("actenvcarl", "multi_process", args.multi_process)
    policy_config.set("actenvcarl", "act_steps", args.act_steps)
    policy_config.set("actenvcarl", "act_fixed", args.act_fixed)

    policy.configure(policy_config)
    policy.set_device(device)

    # configure environment
    env_config = configparser.RawConfigParser()
    env_config.read(args.env_config)
    # Command-line overrides for the reward section (stored as floats).
    env_config.set("reward", "agent_timestep", agent_timestep)
    env_config.set("reward", "human_timestep", human_timestep)
    env_config.set("reward", "reward_increment", reward_increment)
    env_config.set("reward", "position_variance", position_variance)
    env_config.set("reward", "direction_variance", direction_variance)
    env = gym.make("CrowdSim-v0")
    env.configure(env_config)

    # Only used to obtain part of the environment configuration, such as the
    # radius and speed.
    robot = Robot(env_config, "robot")
    env.set_robot(robot)

    # read training parameters
    if args.train_config is None:
        parser.error(
            "Train config has to be specified for a trainable network")

    train_config = configparser.RawConfigParser()
    train_config.read(args.train_config)
    rl_learning_rate = train_config.getfloat("train", "rl_learning_rate")
    train_batches = train_config.getint("train", "train_batches")
    train_episodes = train_config.getint("train", "train_episodes")
    sample_episodes = train_config.getint("train", "sample_episodes")
    target_update_interval = train_config.getint("train",
                                                 "target_update_interval")
    evaluation_interval = train_config.getint("train", "evaluation_interval")
    capacity = train_config.getint("train", "capacity")
    epsilon_start = train_config.getfloat("train", "epsilon_start")
    epsilon_end = train_config.getfloat("train", "epsilon_end")
    epsilon_decay = train_config.getfloat("train", "epsilon_decay")
    checkpoint_interval = train_config.getint("train", "checkpoint_interval")

    # configure trainer and explorer
    memory = ReplayMemory(capacity)

    model = policy.get_model()

    batch_size = train_config.getint("trainer", "batch_size")

    trainer = Trainer(model, memory, device, batch_size)

    explorer = Explorer(env,
                        robot,
                        device,
                        memory,
                        policy.gamma,
                        target_policy=policy)

    # imitation learning
    if args.resume:
        if not os.path.exists(rl_weight_file):
            # Fail fast with a clear message instead of logging and then
            # crashing inside torch.load() on the missing file.
            logging.error("RL weights does not exist")
            parser.error("RL weights does not exist: " + rl_weight_file)
        model.load_state_dict(torch.load(rl_weight_file))
        rl_weight_file = os.path.join(args.output_dir, "resumed_rl_model.pth")
        logging.info(
            "Load reinforcement learning trained weights. Resume training")
    elif os.path.exists(il_weight_file):
        # Reuse imitation-learning weights from a previous run.
        model.load_state_dict(torch.load(il_weight_file))
        logging.info("Load imitation learning trained weights.")
    else:
        # NOTE(review): il_episodes is read from the config but the number of
        # demonstration episodes below is hard-coded to 200 -- confirm which
        # one is intended.
        il_episodes = train_config.getint("imitation_learning", "il_episodes")
        il_policy = train_config.get("imitation_learning", "il_policy")
        il_epochs = train_config.getint("imitation_learning", "il_epochs")
        il_learning_rate = train_config.getfloat("imitation_learning",
                                                 "il_learning_rate")
        trainer.set_learning_rate(il_learning_rate, optimizer_type)
        if robot.visible:
            safety_space = 0
        else:
            safety_space = train_config.getfloat("imitation_learning",
                                                 "safety_space")
        # Replace the policy name with an instance of the demonstrator policy.
        il_policy = policy_factory[il_policy]()
        il_policy.multiagent_training = policy.multiagent_training
        il_policy.safety_space = safety_space

        print("il_policy: ", type(il_policy))
        robot.set_policy(il_policy)
        # Collect demonstrations into the replay memory, then fit the model.
        explorer.run_k_episodes(200,
                                "train",
                                update_memory=True,
                                imitation_learning=True)

        trainer.optimize_epoch(il_epochs)

        torch.save(model.state_dict(), il_weight_file)
        logging.info("Finish imitation learning. Weights saved.")
        logging.info("Experience set size: %d/%d", len(memory),
                     memory.capacity)

    explorer.update_target_model(model)

    # reinforcement learning
    policy.set_env(env)
    robot.set_policy(policy)
    robot.print_info()
    trainer.set_learning_rate(rl_learning_rate, optimizer_type)
    # fill the memory pool with some RL experience
    if args.resume:
        robot.policy.set_epsilon(epsilon_end)
        explorer.run_k_episodes(100, "train", update_memory=True, episode=0)
        logging.info("Experience set size: %d/%d", len(memory),
                     memory.capacity)
    episode = 0
    while episode < train_episodes:
        # Exploration rate: fixed at epsilon_end when resuming, otherwise
        # linearly interpolated from epsilon_start to epsilon_end over the
        # first epsilon_decay episodes and constant afterwards.
        if args.resume:
            epsilon = epsilon_end
        else:
            if episode < epsilon_decay:
                epsilon = epsilon_start + (
                    epsilon_end - epsilon_start) / epsilon_decay * episode
            else:
                epsilon = epsilon_end
        robot.policy.set_epsilon(epsilon)

        # evaluate the model
        if episode % evaluation_interval == 0:
            explorer.run_k_episodes(env.case_size["val"],
                                    "val",
                                    episode=episode)

        # sample k episodes into memory and optimize over the generated memory
        explorer.run_k_episodes(sample_episodes,
                                "train",
                                update_memory=True,
                                episode=episode)
        trainer.optimize_batch(train_batches)
        episode += 1

        if episode % target_update_interval == 0:
            explorer.update_target_model(model)

        # episode was just incremented, so it is always >= 1 here.
        if episode != 0 and episode % checkpoint_interval == 0:
            # rl_weight_file_name = 'rl_model_' + str(episode)  + '.pth'
            # One checkpoint file per interval, named after the episode count.
            rl_weight_file_name = "rl_model_{:d}.pth".format(episode)
            rl_weight_file = os.path.join(args.output_dir, rl_weight_file_name)
            torch.save(model.state_dict(),
                       rl_weight_file,
                       _use_new_zipfile_serialization=False)

    # final test
    explorer.run_k_episodes(env.case_size["test"], "test", episode=episode)