Example #1
def main(args):
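    # NOTE: excerpt from the mrl project's training script. It assumes the
    # file's module-level preamble: imports such as `numpy as np`,
    # `torch.nn as nn`, `time`, `mrl` and the modules referenced below, plus
    # a base `config` object defined outside this function.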

    # 4. Update the config with args, and make the agent name.
    if args.num_envs is None:
        import multiprocessing as mp
        args.num_envs = max(mp.cpu_count() - 1, 1)

    merge_args_into_config(args, config)

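    # With sparse rewards in [-1, 0], returns are bounded below by
    # -1/(1 - gamma) when gamma < 1, and by roughly -env_max_step in the
    # undiscounted case, so Q-targets are clipped to that range.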
    if config.gamma < 1.:
        config.clip_target_range = (np.round(-(1 / (1 - config.gamma)), 2), 0.)
    elif config.gamma == 1:
        config.clip_target_range = (np.round(-args.env_max_step - 5, 2), 0.)

    if args.sparse_reward_shaping:
        config.clip_target_range = (-np.inf, np.inf)

    config.agent_name = make_agent_name(
        config, [
            'env', 'alg', 'her', 'layers', 'seed', 'tb', 'ag_curiosity',
            'eexplore', 'first_visit_succ', 'dg_score_multiplier', 'alpha'
        ],
        prefix=args.prefix)

    # 5. Setup / add basic modules to the config
    config.update(
        dict(
            trainer=StandardTrain(),
            evaluation=EpisodicEval(),
            policy=ActorPolicy(),
            logger=Logger(),
            state_normalizer=Normalizer(MeanStdNormalizer()),
            replay=OnlineHERBuffer(),
        ))

    config.prioritized_mode = args.prioritized_mode
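    # 'mep' presumably stands for maximum-entropy prioritization, implemented
    # by the entropy-prioritized HER buffer below.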
    if config.prioritized_mode == 'mep':
        config.prioritized_replay = EntropyPrioritizedOnlineHERBuffer()

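    # Kernel density estimate over achieved goals ('ag'), used for entropy
    # logging and by the density-based curiosity modules below.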
    if not args.no_ag_kde:
        config.ag_kde = RawKernelDensity('ag',
                                         optimize_every=1,
                                         samples=10000,
                                         kernel=args.kde_kernel,
                                         bandwidth=args.bandwidth,
                                         log_entropy=True)
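    # Goal-selection (curiosity) modules: density models over achieved ('ag')
    # and desired ('dg') goals plus a strategy for picking behavioral goals.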
    if args.ag_curiosity is not None:
        config.dg_kde = RawKernelDensity('dg',
                                         optimize_every=500,
                                         samples=10000,
                                         kernel='tophat',
                                         bandwidth=0.2)
        config.ag_kde_tophat = RawKernelDensity('ag',
                                                optimize_every=100,
                                                samples=10000,
                                                kernel='tophat',
                                                bandwidth=0.2,
                                                tag='_tophat')
        if args.transition_to_dg:
            config.alpha_curiosity = CuriosityAlphaMixtureModule()
        if 'rnd' in args.ag_curiosity:
            config.ag_rnd = RandomNetworkDensity('ag')
        if 'flow' in args.ag_curiosity:
            config.ag_flow = FlowDensity('ag')

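        # The Q-cutoff presumably filters out candidate goals whose Q-values
        # suggest they are not (yet) reachable under the current policy.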
        use_qcutoff = not args.no_cutoff

        if args.ag_curiosity == 'minq':
            config.ag_curiosity = QAchievedGoalCuriosity(
                max_steps=args.env_max_step,
                num_sampled_ags=args.num_sampled_ags,
                use_qcutoff=use_qcutoff,
                keep_dg_percent=args.keep_dg_percent)
        elif args.ag_curiosity == 'randq':
            config.ag_curiosity = QAchievedGoalCuriosity(
                max_steps=args.env_max_step,
                randomize=True,
                num_sampled_ags=args.num_sampled_ags,
                use_qcutoff=use_qcutoff,
                keep_dg_percent=args.keep_dg_percent)
        elif args.ag_curiosity == 'minkde':
            config.ag_curiosity = DensityAchievedGoalCuriosity(
                max_steps=args.env_max_step,
                num_sampled_ags=args.num_sampled_ags,
                use_qcutoff=use_qcutoff,
                keep_dg_percent=args.keep_dg_percent)
        elif args.ag_curiosity == 'minrnd':
            config.ag_curiosity = DensityAchievedGoalCuriosity(
                'ag_rnd',
                max_steps=args.env_max_step,
                num_sampled_ags=args.num_sampled_ags,
                use_qcutoff=use_qcutoff,
                keep_dg_percent=args.keep_dg_percent)
        elif args.ag_curiosity == 'minflow':
            config.ag_curiosity = DensityAchievedGoalCuriosity(
                'ag_flow',
                max_steps=args.env_max_step,
                num_sampled_ags=args.num_sampled_ags,
                use_qcutoff=use_qcutoff,
                keep_dg_percent=args.keep_dg_percent)
        elif args.ag_curiosity == 'randkde':
            config.ag_curiosity = DensityAchievedGoalCuriosity(
                alpha=args.alpha,
                max_steps=args.env_max_step,
                randomize=True,
                num_sampled_ags=args.num_sampled_ags,
                use_qcutoff=use_qcutoff,
                keep_dg_percent=args.keep_dg_percent)
        elif args.ag_curiosity == 'randrnd':
            config.ag_curiosity = DensityAchievedGoalCuriosity(
                'ag_rnd',
                alpha=args.alpha,
                max_steps=args.env_max_step,
                num_sampled_ags=args.num_sampled_ags,
                use_qcutoff=use_qcutoff,
                keep_dg_percent=args.keep_dg_percent)
        elif args.ag_curiosity == 'randflow':
            config.ag_curiosity = DensityAchievedGoalCuriosity(
                'ag_flow',
                alpha=args.alpha,
                max_steps=args.env_max_step,
                num_sampled_ags=args.num_sampled_ags,
                use_qcutoff=use_qcutoff,
                keep_dg_percent=args.keep_dg_percent)
        elif args.ag_curiosity == 'goaldisc':
            config.success_predictor = GoalSuccessPredictor(
                batch_size=args.succ_bs,
                history_length=args.succ_hl,
                optimize_every=args.succ_oe)
            config.ag_curiosity = SuccessAchievedGoalCuriosity(
                max_steps=args.env_max_step,
                use_qcutoff=use_qcutoff,
                keep_dg_percent=args.keep_dg_percent)
        elif args.ag_curiosity == 'entropygainscore':
            config.bg_kde = RawKernelDensity('bg',
                                             optimize_every=args.env_max_step,
                                             samples=10000,
                                             kernel=args.kde_kernel,
                                             bandwidth=args.bandwidth,
                                             log_entropy=True)
            config.bgag_kde = RawJointKernelDensity(
                ['bg', 'ag'],
                optimize_every=args.env_max_step,
                samples=10000,
                kernel=args.kde_kernel,
                bandwidth=args.bandwidth,
                log_entropy=True)
            config.ag_curiosity = EntropyGainScoringGoalCuriosity(
                max_steps=args.env_max_step,
                use_qcutoff=use_qcutoff,
                keep_dg_percent=args.keep_dg_percent)
        else:
            raise NotImplementedError

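    # Exploration noise added to the (deterministic) policy's actions.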
    if args.noise_type.lower() == 'gaussian':
        noise_type = GaussianProcess
    elif args.noise_type.lower() == 'ou':
        noise_type = OrnsteinUhlenbeckProcess
    else:
        raise ValueError('Unsupported noise type: {}'.format(args.noise_type))
    config.action_noise = ContinuousActionNoise(noise_type,
                                                std=ConstantSchedule(
                                                    args.action_noise))

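    # Off-policy RL algorithm.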
    if args.alg.lower() == 'ddpg':
        config.algorithm = DDPG()
    elif args.alg.lower() == 'td3':
        config.algorithm = TD3()
        config.target_network_update_freq *= 2
    elif args.alg.lower() == 'dqn':
        config.algorithm = DQN()
        config.policy = QValuePolicy()
        config.qvalue_lr = config.critic_lr
        config.qvalue_weight_decay = config.actor_weight_decay
        config.double_q = True
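        # Epsilon-greedy exploration: presumably anneals the random-action
        # probability from 1.0 down to config.eexplore over roughly 1e5 steps.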
        config.random_action_prob = LinearSchedule(1.0, config.eexplore, 1e5)
    else:
        raise NotImplementedError

    # 6. Setup / add the environments and networks (which depend on the environment) to the config
    env, eval_env = make_env(args)
    if args.first_visit_done:
        env1, eval_env1 = env, eval_env
        env = lambda: FirstVisitDoneWrapper(env1())
        eval_env = lambda: FirstVisitDoneWrapper(eval_env1())
    if args.first_visit_succ:
        config.first_visit_succ = True

    config.train_env = EnvModule(env, num_envs=args.num_envs, seed=args.seed)
    config.eval_env = EnvModule(eval_env,
                                num_envs=args.num_eval_envs,
                                name='eval_env',
                                seed=args.seed + 1138)

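    # Networks are wrapped in lambdas so they are built lazily, once the
    # environment's state/goal/action dimensions are known.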
    e = config.eval_env
    if args.alg.lower() == 'dqn':
        config.qvalue = PytorchModel(
            'qvalue', lambda: Critic(
                FCBody(e.state_dim + e.goal_dim, args.layers, nn.LayerNorm,
                       make_activ(config.activ)), e.action_dim))
    else:
        config.actor = PytorchModel(
            'actor', lambda: Actor(
                FCBody(e.state_dim + e.goal_dim, args.layers, nn.LayerNorm,
                       make_activ(config.activ)), e.action_dim, e.max_action))
        config.critic = PytorchModel(
            'critic', lambda: Critic(
                FCBody(e.state_dim + e.goal_dim + e.action_dim, args.layers,
                       nn.LayerNorm, make_activ(config.activ)), 1))
        if args.alg.lower() == 'td3':
            config.critic2 = PytorchModel(
                'critic2', lambda: Critic(
                    FCBody(e.state_dim + e.goal_dim + e.action_dim,
                           args.layers, nn.LayerNorm,
                           make_activ(config.activ)), 1))

    if args.ag_curiosity == 'goaldisc':
        config.goal_discriminator = PytorchModel(
            'goal_discriminator', lambda: Critic(
                FCBody(e.state_dim + e.goal_dim, args.layers, nn.LayerNorm,
                       make_activ(config.activ)), 1))

    if args.reward_module == 'env':
        config.goal_reward = GoalEnvReward()
    elif args.reward_module == 'intrinsic':
        config.goal_reward = NeighborReward()
        config.neighbor_embedding_network = PytorchModel(
            'neighbor_embedding_network',
            lambda: FCBody(e.goal_dim, (256, 256)))
    else:
        raise ValueError('Unsupported reward module: {}'.format(
            args.reward_module))

    if config.eval_env.goal_env:
        if not (args.first_visit_done or args.first_visit_succ):
            config.never_done = True  # NOTE: This is important in the standard Goal environments, which are never done

    # 7. Make the agent and run the training loop.
    agent = mrl.config_to_agent(config)

    if args.visualize_trained_agent:
        print("Loading agent at epoch {}".format(0))
        agent.load('checkpoint')

        if args.intrinsic_visualization:
            agent.eval_mode()
            agent.train(10000, render=True, dont_optimize=True)

        else:
            agent.eval_mode()
            env = agent.eval_env

            for _ in range(10000):
                print("NEW EPISODE")
                state = env.reset()
                env.render()
                done = False
                while not done:
                    time.sleep(0.02)
                    action = agent.policy(state)
                    state, reward, done, info = env.step(action)
                    env.render()
                    print(reward[0])
    else:
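        # Achieved-goal ('ag') and behavioral-goal ('bg') sub-buffers of the
        # replay buffer, sampled below for embedding visualization.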
        ag_buffer = agent.replay_buffer.buffer.BUFF.buffer_ag
        bg_buffer = agent.replay_buffer.buffer.BUFF.buffer_bg

        # EVALUATE
        res = np.mean(agent.eval(num_episodes=30).rewards)
        agent.logger.log_color('Initial test reward (30 eps):',
                               '{:.2f}'.format(res))

        for epoch in range(int(args.max_steps // args.epoch_len)):
            t = time.time()
            agent.train(num_steps=args.epoch_len)

            # VISUALIZE GOALS
            if args.save_embeddings:
                sample_idxs = np.random.choice(len(ag_buffer),
                                               size=min(
                                                   len(ag_buffer),
                                                   args.epoch_len),
                                               replace=False)
                last_idxs = np.arange(max(0,
                                          len(ag_buffer) - args.epoch_len),
                                      len(ag_buffer))
                agent.logger.add_embedding('rand_ags',
                                           ag_buffer.get_batch(sample_idxs))
                agent.logger.add_embedding('last_ags',
                                           ag_buffer.get_batch(last_idxs))
                agent.logger.add_embedding('last_bgs',
                                           bg_buffer.get_batch(last_idxs))

            # EVALUATE
            res = np.mean(agent.eval(num_episodes=30).rewards)
            agent.logger.log_color('Test reward (30 eps):',
                                   '{:.2f}'.format(res))
            agent.logger.log_color('Epoch time:',
                                   '{:.2f}'.format(time.time() - t),
                                   color='yellow')

            print("Saving agent at epoch {}".format(epoch))
            agent.save('checkpoint')
Example #2
File: train.py Project: spitis/mrl
def main(args):
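    # NOTE: same structure as Example #1; this variant additionally assumes
    # module-level imports of `torch` and `os` (used for the thread limits
    # and checkpoint paths below).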

    # 4. Update the config with args, and make the agent name.
    if args.num_envs is None:
        import multiprocessing as mp
        args.num_envs = max(mp.cpu_count() - 1, 1)
    merge_args_into_config(args, config)

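    # Cap PyTorch's intra-/inter-op CPU threads so the vectorized environments
    # and the learner don't oversubscribe cores.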
    torch.set_num_threads(min(8, args.num_envs))
    torch.set_num_interop_threads(min(8, args.num_envs))

    if config.gamma < 1.:
        config.clip_target_range = (np.round(-(1 / (1 - config.gamma)), 2), 0.)
    elif config.gamma == 1:
        config.clip_target_range = (np.round(-args.env_max_step - 5, 2), 0.)
    if args.sparse_reward_shaping or 'sac' in args.alg.lower():
        config.clip_target_range = (-np.inf, np.inf)

    config.agent_name = make_agent_name(config, ['env', 'alg', 'tb', 'seed'],
                                        prefix=args.prefix)

    # 6. Setup environments & add them to config, so modules can refer to them if need be
    env, eval_env = make_env(args)
    if args.first_visit_done:
        env1, eval_env1 = env, eval_env
        # Terminates the training episode on "done"
        env = lambda: FirstVisitDoneWrapper(env1())
        eval_env = lambda: FirstVisitDoneWrapper(eval_env1())
    if args.first_visit_succ:
        config.first_visit_succ = True  # Continues the training episode on "done", but counts it as if "done" (gamma = 0)
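    # Dict-observation environments: select which observation keys make up the
    # state ('modalities') and which make up the goal ('goal_modalities').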
    if 'dictpush' in args.env.lower():
        config.modalities = ['gripper', 'object', 'relative']
        if 'reach' in args.env.lower():
            config.goal_modalities = ['gripper_goal', 'object_goal']
        else:
            config.goal_modalities = ['desired_goal']
        config.achieved_goal = GoalEnvAchieved()
    config.train_env = EnvModule(env,
                                 num_envs=args.num_envs,
                                 seed=args.seed,
                                 modalities=config.modalities,
                                 goal_modalities=config.goal_modalities)
    config.eval_env = EnvModule(eval_env,
                                num_envs=args.num_eval_envs,
                                name='eval_env',
                                seed=args.seed + 1138,
                                modalities=config.modalities,
                                goal_modalities=config.goal_modalities)

    # 7. Setup / add modules to the config

    # Base Modules
    config.update(
        dict(
            trainer=StandardTrain(),
            evaluation=EpisodicEval(),
            policy=ActorPolicy(),
            logger=Logger(),
            state_normalizer=Normalizer(MeanStdNormalizer()),
            replay=OnlineHERBuffer(),
        ))

    # Goal Selection Modules
    if args.ag_curiosity is not None:
        config.ag_kde = RawKernelDensity('ag',
                                         optimize_every=4,
                                         samples=2000,
                                         kernel=args.kde_kernel,
                                         bandwidth=args.bandwidth,
                                         log_entropy=True)
        config.dg_kde = RawKernelDensity('dg',
                                         optimize_every=500,
                                         samples=5000,
                                         kernel='tophat',
                                         bandwidth=0.2)
        config.ag_kde_tophat = RawKernelDensity('ag',
                                                optimize_every=100,
                                                samples=5000,
                                                kernel='tophat',
                                                bandwidth=0.2,
                                                tag='_tophat')
        if args.transition_to_dg:
            config.alpha_curiosity = CuriosityAlphaMixtureModule()

        use_qcutoff = not args.no_cutoff

        if args.ag_curiosity == 'minkde':
            config.ag_curiosity = DensityAchievedGoalCuriosity(
                max_steps=args.env_max_step,
                num_sampled_ags=args.num_sampled_ags,
                use_qcutoff=use_qcutoff,
                keep_dg_percent=args.keep_dg_percent)
        else:
            raise NotImplementedError

    # Action Noise Modules
    if args.noise_type.lower() == 'gaussian':
        noise_type = GaussianProcess
    elif args.noise_type.lower() == 'ou':
        noise_type = OrnsteinUhlenbeckProcess
    else:
        raise ValueError('Unsupported noise type: {}'.format(args.noise_type))
    config.action_noise = ContinuousActionNoise(noise_type,
                                                std=ConstantSchedule(
                                                    args.action_noise))

    # Algorithm Modules
    if args.alg.lower() == 'ddpg':
        config.algorithm = DDPG()
    elif args.alg.lower() == 'td3':
        config.algorithm = TD3()
        config.target_network_update_freq *= 2
    elif args.alg.lower() == 'sac':
        config.algorithm = SAC()
    elif args.alg.lower() == 'dqn':
        config.algorithm = DQN()
        config.policy = QValuePolicy()
        config.qvalue_lr = config.critic_lr
        config.qvalue_weight_decay = config.actor_weight_decay
        config.double_q = True
        config.random_action_prob = LinearSchedule(1.0, config.eexplore, 1e5)
    else:
        raise NotImplementedError

    # 7. Actor/Critic Networks
    e = config.eval_env
    if args.alg.lower() == 'dqn':
        config.qvalue = PytorchModel(
            'qvalue', lambda: Critic(
                FCBody(e.state_dim + e.goal_dim, args.layers, nn.Identity,
                       make_activ(config.activ)), e.action_dim))
    else:
        config.actor = PytorchModel(
            'actor', lambda: Actor(
                FCBody(e.state_dim + e.goal_dim, args.layers, nn.Identity,
                       make_activ(config.activ)), e.action_dim, e.max_action))
        config.critic = PytorchModel(
            'critic', lambda: Critic(
                FCBody(e.state_dim + e.goal_dim + e.action_dim, args.layers,
                       nn.Identity, make_activ(config.activ)), 1))
        if args.alg.lower() in ['td3', 'sac']:
            config.critic2 = PytorchModel(
                'critic2', lambda: Critic(
                    FCBody(e.state_dim + e.goal_dim + e.action_dim,
                           args.layers, nn.Identity,
                           make_activ(config.activ)), 1))
        if args.alg.lower() == 'sac':
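            # SAC needs a stochastic (squashed-Gaussian) actor and a matching
            # sampling policy, so the deterministic actor/policy are replaced.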
            del config.actor
            config.actor = PytorchModel(
                'actor', lambda: StochasticActor(
                    FCBody(e.state_dim + e.goal_dim, args.layers, nn.Identity,
                           make_activ(config.activ)),
                    e.action_dim, e.max_action, log_std_bounds=(-20, 2)))
            del config.policy
            config.policy = StochasticActorPolicy()

    # 8. Reward modules
    if args.reward_module == 'env':
        config.goal_reward = GoalEnvReward()
    elif args.reward_module == 'intrinsic':
        config.goal_reward = NeighborReward()
        config.neighbor_embedding_network = PytorchModel(
            'neighbor_embedding_network',
            lambda: FCBody(e.goal_dim, (256, 256)))
    else:
        raise ValueError('Unsupported reward module: {}'.format(
            args.reward_module))

    if config.eval_env.goal_env:
        if not (args.first_visit_done or args.first_visit_succ):
            config.never_done = True  # NOTE: This is important in the standard Goal environments, which are never done

    # 9. Make the agent
    agent = mrl.config_to_agent(config)

    if args.checkpoint_dir is not None:
        # If a checkpoint has been initialized load it.
        if os.path.exists(os.path.join(args.checkpoint_dir, 'INITIALIZED')):
            agent.load_from_checkpoint(args.checkpoint_dir)

    # 10.A Visualize a trained agent
    if args.visualize_trained_agent:
        print("Loading agent at epoch {}".format(0))
        agent.load('checkpoint')

        if args.intrinsic_visualization:
            agent.eval_mode()
            agent.train(10000, render=True, dont_optimize=True)

        else:
            agent.eval_mode()
            env = agent.eval_env

            for _ in range(10000):
                print("NEW EPISODE")
                state = env.reset()
                env.render()
                done = False
                while not done:
                    time.sleep(0.02)
                    action = agent.policy(state)
                    state, reward, done, info = env.step(action)
                    env.render()
                    print(reward[0])

    # 10.B Or run the training loop
    else:
        ag_buffer = agent.replay_buffer.buffer.BUFF.buffer_ag
        bg_buffer = agent.replay_buffer.buffer.BUFF.buffer_bg

        # EVALUATE
        res = np.mean(agent.eval(num_episodes=30).rewards)
        agent.logger.log_color('Initial test reward (30 eps):',
                               '{:.2f}'.format(res))

        for epoch in range(int(args.max_steps // args.epoch_len)):
            t = time.time()
            agent.train(num_steps=args.epoch_len)

            # VISUALIZE GOALS
            if args.save_embeddings:
                sample_idxs = np.random.choice(len(ag_buffer),
                                               size=min(
                                                   len(ag_buffer),
                                                   args.epoch_len),
                                               replace=False)
                last_idxs = np.arange(max(0,
                                          len(ag_buffer) - args.epoch_len),
                                      len(ag_buffer))
                agent.logger.add_embedding('rand_ags',
                                           ag_buffer.get_batch(sample_idxs))
                agent.logger.add_embedding('last_ags',
                                           ag_buffer.get_batch(last_idxs))
                agent.logger.add_embedding('last_bgs',
                                           bg_buffer.get_batch(last_idxs))

            # EVALUATE
            res = np.mean(agent.eval(num_episodes=30).rewards)
            agent.logger.log_color('Test reward (30 eps):',
                                   '{:.2f}'.format(res))
            agent.logger.log_color('Epoch time:',
                                   '{:.2f}'.format(time.time() - t),
                                   color='yellow')

            print("Saving agent at epoch {}".format(epoch))
            agent.save('checkpoint')

            # Also save to checkpoint if a checkpoint_dir is specified.
            if args.checkpoint_dir is not None:
                agent.save_checkpoint(args.checkpoint_dir)