Пример #1
0
def experiment(variant):
    """Train BEAR on the CARLA lane-following env from a pre-filled buffer.

    Builds twin conv Q-functions (each pair sharing one CNN, with a separate
    CNN for the target pair), a conv Tanh-Gaussian policy and a conv VAE
    behaviour model. It pre-fills an ObsDictReplayBuffer with ``load_hdf5``
    and then runs ``TorchBatchRLAlgorithm`` with a video-saving post-train
    hook.

    Args:
        variant: experiment config dict. Keys read here: ``cnn_params``,
            ``qf_kwargs``, ``policy_kwargs``, ``cnn_vae_params``,
            ``eval_path_collector_kwargs``, ``expl_path_collector_kwargs``,
            ``replay_buffer_size``, ``trainer_kwargs``, ``batch_rl`` and
            ``algo_kwargs``. ``cnn_params`` and
            ``cnn_vae_params['conv_args']`` are updated in place with the
            image geometry.
    """
    env = gym.make('carla-lane-dict-v0')
    # One env instance is used for both exploration and evaluation.
    expl_env = eval_env = env

    _, img_width, img_height = eval_env.image_shape
    # The channel count reported by the env is deliberately overridden.
    num_channels = 3
    action_dim = int(np.prod(eval_env.action_space.shape))

    cnn_params = variant['cnn_params']
    cnn_params.update(
        input_width=img_width,
        input_height=img_height,
        input_channels=num_channels,
        added_fc_input_size=0,
        output_conv_channels=True,
        output_size=None,
    )

    def twin_q_functions(encoder):
        """Return two Q-functions that both read ``encoder`` + Flatten."""
        q_kwargs = copy.deepcopy(variant['qf_kwargs'])
        q_kwargs.update(
            obs_processor=nn.Sequential(encoder, Flatten()),
            output_size=1,
            input_size=action_dim + encoder.conv_output_flat_size,
        )
        first = MlpQfWithObsProcessor(**q_kwargs)
        second = MlpQfWithObsProcessor(**q_kwargs)
        return first, second

    qf1, qf2 = twin_q_functions(CNN(**cnn_params))
    target_qf1, target_qf2 = twin_q_functions(CNN(**cnn_params))

    policy_encoder = CNN(**cnn_params)
    policy = TanhGaussianPolicyAdapter(
        nn.Sequential(policy_encoder, Flatten()),
        policy_encoder.conv_output_flat_size,
        action_dim,
        **variant['policy_kwargs']
    )

    vae_params = variant['cnn_vae_params']
    vae_params['conv_args'].update(
        input_width=img_width,
        input_height=img_height,
        input_channels=num_channels,
    )
    vae_policy = ConvVAEPolicy(
        representation_size=vae_params['representation_size'],
        architecture=vae_params,
        action_dim=action_dim,
        input_channels=3,
        imsize=img_width,
    )

    observation_key = 'image'
    eval_collector_kwargs = variant['eval_path_collector_kwargs']
    # Neither evaluation collector is handed a policy at construction time.
    eval_path_collector = CustomObsDictPathCollector(
        eval_env,
        observation_key=observation_key,
        **eval_collector_kwargs
    )
    vae_eval_path_collector = CustomObsDictPathCollector(
        eval_env,
        observation_key=observation_key,
        **eval_collector_kwargs
    )

    replay_buffer = ObsDictReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
        observation_key=observation_key,
    )
    # Pre-fill the replay buffer before any training happens.
    load_hdf5(expl_env, replay_buffer)

    trainer = BEARTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        vae=vae_policy,
        **variant['trainer_kwargs']
    )

    expl_path_collector = ObsDictPathCollector(
        expl_env,
        policy,
        observation_key=observation_key,
        **variant['expl_path_collector_kwargs']
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        vae_evaluation_data_collector=vae_eval_path_collector,
        replay_buffer=replay_buffer,
        q_learning_alg=True,
        batch_rl=variant['batch_rl'],
        **variant['algo_kwargs']
    )
    algorithm.post_train_funcs.append(VideoSaveFunctionBullet(variant))

    algorithm.to(ptu.device)
    algorithm.train()
Пример #2
0
def _pointmass_fixed_goal_experiment(vae_latent_size,
                                     replay_buffer_size,
                                     cnn_kwargs,
                                     vae_kwargs,
                                     policy_kwargs,
                                     qf_kwargs,
                                     e2e_trainer_kwargs,
                                     sac_trainer_kwargs,
                                     algorithm_kwargs,
                                     eval_path_collector_kwargs=None,
                                     expl_path_collector_kwargs=None,
                                     **kwargs):
    """Run End2End SAC + VAE training on a fixed-goal image Point2D env.

    All four Q-functions (online and target) read observations through the
    encoder of one shared ConvVAE, while the policy has its own CNN. A
    ``VAETrainer`` and a ``SACTrainer`` are combined by
    ``End2EndSACTrainer`` and driven by ``TorchBatchRLAlgorithm``.

    Args:
        vae_latent_size: VAE representation size; also the encoder feature
            size fed to each Q-function MLP.
        replay_buffer_size: capacity of the EnvReplayBuffer.
        cnn_kwargs: extra kwargs for the policy CNN.
        vae_kwargs: extra kwargs for ConvVAE.
        policy_kwargs: extra kwargs for TanhGaussianPolicyAdapter.
        qf_kwargs: extra kwargs for every MlpQfWithObsProcessor.
        e2e_trainer_kwargs: extra kwargs for End2EndSACTrainer.
        sac_trainer_kwargs: extra kwargs for SACTrainer.
        algorithm_kwargs: extra kwargs for TorchBatchRLAlgorithm.
        eval_path_collector_kwargs: kwargs for the eval collector (None -> {}).
        expl_path_collector_kwargs: kwargs for the expl collector (None -> {}).
        **kwargs: accepted and ignored.
    """
    expl_path_collector_kwargs = (
        {} if expl_path_collector_kwargs is None
        else expl_path_collector_kwargs)
    eval_path_collector_kwargs = (
        {} if eval_path_collector_kwargs is None
        else eval_path_collector_kwargs)

    from multiworld.core.image_env import ImageEnv
    from multiworld.envs.pygame.point2d import Point2DEnv
    from multiworld.core.flat_goal_env import FlatGoalEnv

    point_env = Point2DEnv(
        images_are_rgb=True,
        render_onscreen=False,
        show_goal=False,
        ball_radius=2,
        render_size=48,
        fixed_goal=(0, 0),
    )
    image_env = ImageEnv(point_env,
                         imsize=point_env.render_size,
                         transpose=True,
                         normalize=True)
    env = FlatGoalEnv(image_env)
    input_width, input_height = env.image_shape
    action_dim = int(np.prod(env.action_space.shape))

    vae = ConvVAE(
        representation_size=vae_latent_size,
        input_channels=3,
        imsize=input_width,
        decoder_output_activation=nn.Sigmoid(),
        **vae_kwargs)
    encoder = Vae2Encoder(vae)

    def make_qf():
        # Each Q-function gets a fresh Sequential around the SAME encoder.
        processor = nn.Sequential(encoder, networks.Flatten())
        return networks.MlpQfWithObsProcessor(
            obs_processor=processor,
            output_size=1,
            input_size=action_dim + vae_latent_size,
            **qf_kwargs)

    qf1, qf2, target_qf1, target_qf2 = (make_qf() for _ in range(4))

    policy_cnn = networks.CNN(input_width=input_width,
                              input_height=input_height,
                              input_channels=3,
                              output_conv_channels=True,
                              output_size=None,
                              **cnn_kwargs)
    policy = TanhGaussianPolicyAdapter(
        nn.Sequential(policy_cnn, networks.Flatten()),
        policy_cnn.conv_output_flat_size,
        action_dim,
        **policy_kwargs)

    # One env instance is used for both exploration and evaluation.
    eval_env = expl_env = env

    eval_path_collector = MdpPathCollector(eval_env,
                                           MakeDeterministic(policy),
                                           **eval_path_collector_kwargs)
    replay_buffer = EnvReplayBuffer(replay_buffer_size, expl_env)

    vae_trainer = VAETrainer(vae)
    sac_trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs)
    trainer = End2EndSACTrainer(
        sac_trainer=sac_trainer,
        vae_trainer=vae_trainer,
        **e2e_trainer_kwargs,
    )

    expl_path_collector = MdpPathCollector(expl_env, policy,
                                           **expl_path_collector_kwargs)
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **algorithm_kwargs)
    algorithm.to(ptu.device)
    algorithm.train()
Пример #3
0
def experiment(variant):
    """Train SAC on the fixed-goal image Point2D env with PretrainedCNN encoders.

    Args:
        variant: experiment config dict. Keys read here: ``cnn_params``
            (updated in place with the env's image geometry),
            ``shared_qf_conv``, ``qf_kwargs``, ``policy_kwargs``,
            ``eval_path_collector_kwargs``, ``expl_path_collector_kwargs``,
            ``replay_buffer_size``, ``trainer_kwargs``, ``collection_mode``
            and ``algo_kwargs``.

    Raises:
        NotImplementedError: if ``variant['collection_mode']`` is neither
            ``'batch'`` nor ``'online'``. Previously this case fell through
            and crashed with UnboundLocalError on ``algorithm``.
    """
    # Imported only for its side effects (presumably registering the
    # multiworld pygame envs with gym -- confirm).
    import multiworld.envs.pygame  # noqa: F401
    env = gym.make('Point2DEnv-ImageFixedGoal-v0')
    input_width, input_height = env.image_shape

    action_dim = int(np.prod(env.action_space.shape))
    cnn_params = variant['cnn_params']
    cnn_params.update(
        input_width=input_width,
        input_height=input_height,
        input_channels=3,
        output_conv_channels=True,
        output_size=None,
    )

    def make_qf(cnn):
        """Build a Q-function whose observation encoder is ``cnn`` + Flatten."""
        return MlpQfWithObsProcessor(
            obs_processor=nn.Sequential(cnn, Flatten()),
            output_size=1,
            input_size=action_dim + cnn.conv_output_flat_size,
            **variant['qf_kwargs'])

    if variant['shared_qf_conv']:
        # qf1/qf2 share one CNN; target_qf1/target_qf2 share another.
        qf_cnn = PretrainedCNN(**cnn_params)
        qf1 = make_qf(qf_cnn)
        qf2 = make_qf(qf_cnn)
        target_qf_cnn = PretrainedCNN(**cnn_params)
        target_qf1 = make_qf(target_qf_cnn)
        target_qf2 = make_qf(target_qf_cnn)
    else:
        # Every Q-function (online and target) gets its own CNN.
        qf1 = make_qf(PretrainedCNN(**cnn_params))
        qf2 = make_qf(PretrainedCNN(**cnn_params))
        target_qf1 = make_qf(PretrainedCNN(**cnn_params))
        target_qf2 = make_qf(PretrainedCNN(**cnn_params))

    policy_cnn = PretrainedCNN(**cnn_params)
    policy = TanhGaussianPolicyAdapter(nn.Sequential(policy_cnn, Flatten()),
                                       policy_cnn.conv_output_flat_size,
                                       action_dim, **variant['policy_kwargs'])
    # One env instance is used for both exploration and evaluation.
    eval_env = expl_env = env

    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env, eval_policy, **variant['eval_path_collector_kwargs'])
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['trainer_kwargs'])

    collection_mode = variant['collection_mode']
    if collection_mode == 'batch':
        expl_path_collector = MdpPathCollector(
            expl_env, policy, **variant['expl_path_collector_kwargs'])
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            **variant['algo_kwargs'])
    elif collection_mode == 'online':
        expl_path_collector = MdpStepCollector(
            expl_env, policy, **variant['expl_path_collector_kwargs'])
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            **variant['algo_kwargs'])
    else:
        raise NotImplementedError(
            'Unsupported collection_mode: {!r}'.format(collection_mode))
    algorithm.to(ptu.device)
    algorithm.train()
Пример #4
0
def experiment(variant):
    """Train SAC in batch-RL mode on the CARLA lane-following env.

    Builds twin conv Q-functions (each pair sharing one CNN, with a separate
    CNN for the target pair) and a conv Tanh-Gaussian policy. It pre-fills
    an ObsDictReplayBuffer with ``load_hdf5`` and runs
    ``TorchBatchRLAlgorithm`` with ``batch_rl=True``, evaluating the
    deterministic policy.

    Args:
        variant: experiment config dict. Keys read here: ``cnn_params``
            (updated in place with the image geometry), ``qf_kwargs``,
            ``policy_kwargs``, ``eval_path_collector_kwargs``,
            ``replay_buffer_size``, ``trainer_kwargs`` and
            ``algorithm_kwargs``. ``algorithm_kwargs['max_path_length']`` is
            overwritten in place with the env's ``_max_episode_steps``.
    """
    import gym
    # Imported only for its side effects (presumably registering the CARLA
    # gym envs -- confirm).
    import d4rl.carla  # noqa: F401
    expl_env = gym.make('carla-lane-dict-v0')
    # One env instance is used for both exploration and evaluation.
    eval_env = expl_env

    num_channels, img_width, img_height = eval_env.image_shape
    action_dim = int(np.prod(eval_env.action_space.shape))

    cnn_params = variant['cnn_params']
    cnn_params.update(
        input_width=img_width,
        input_height=img_height,
        input_channels=num_channels,
        added_fc_input_size=0,
        output_conv_channels=True,
        output_size=None,
    )

    def make_conv_qfs():
        """Build one CNN and two Q-functions that share it (flattened)."""
        cnn = CNN(**cnn_params)
        qf_kwargs = copy.deepcopy(variant['qf_kwargs'])
        qf_kwargs['obs_processor'] = nn.Sequential(cnn, Flatten())
        qf_kwargs['output_size'] = 1
        qf_kwargs['input_size'] = action_dim + cnn.conv_output_flat_size
        first = MlpQfWithObsProcessor(**qf_kwargs)
        second = MlpQfWithObsProcessor(**qf_kwargs)
        return first, second

    qf1, qf2 = make_conv_qfs()
    target_qf1, target_qf2 = make_conv_qfs()

    policy_cnn = CNN(**cnn_params)
    policy_obs_processor = nn.Sequential(
        policy_cnn,
        Flatten(),
    )
    policy = TanhGaussianPolicyAdapter(policy_obs_processor,
                                       policy_cnn.conv_output_flat_size,
                                       action_dim, **variant['policy_kwargs'])

    eval_policy = MakeDeterministic(policy)
    observation_key = 'image'

    eval_path_collector = ObsDictPathCollector(
        eval_env,
        eval_policy,
        observation_key=observation_key,
        **variant['eval_path_collector_kwargs'])

    expl_path_collector = CustomObsDictPathCollector(
        expl_env,
        observation_key=observation_key,
    )

    replay_buffer = ObsDictReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
        observation_key=observation_key,
    )
    # Pre-fill the replay buffer before any training happens.
    load_hdf5(expl_env, replay_buffer)

    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         behavior_policy=None,
                         **variant['trainer_kwargs'])

    # Cap rollouts at the env's episode limit. The value must go into the
    # same dict that is unpacked into the algorithm below ('algorithm_kwargs');
    # writing it to any other key would silently have no effect.
    algorithm_kwargs = variant['algorithm_kwargs']
    algorithm_kwargs['max_path_length'] = expl_env._max_episode_steps
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        eval_both=True,
        batch_rl=True,
        **algorithm_kwargs)

    video_func = VideoSaveFunctionBullet(variant)
    algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train SAC on a roboverse pixel env, optionally feeding robot state.

    With ``variant['use_robot_state']`` true, every CNN gets the robot-state
    vector as extra fc input and ends in fc layers of sizes [400, 400] with
    a 200-wide output. Otherwise the flattened conv features are used
    directly.

    Args:
        variant: experiment config dict. Keys read here: ``env``,
            ``randomize_env``, ``obs``, ``use_robot_state``, ``cnn_params``
            (updated in place), ``qf_kwargs``, ``policy_kwargs``,
            ``eval_path_collector_kwargs``, ``expl_path_collector_kwargs``,
            ``replay_buffer_size``, ``trainer_kwargs``, ``collection_mode``
            and ``algo_kwargs``.

    Raises:
        NotImplementedError: if ``variant['obs']`` is not ``'pixels'`` or
            ``'pixels_debug'``, or ``variant['collection_mode']`` is not
            ``'batch'`` or ``'online'``.
    """
    base_env = roboverse.make(variant['env'],
                              gui=False,
                              randomize=variant['randomize_env'],
                              observation_mode=variant['obs'],
                              reward_type='shaped',
                              transpose_image=True)

    obs_mode = variant['obs']
    if obs_mode == 'pixels_debug':
        state_dims = 11
    elif obs_mode == 'pixels':
        state_dims = 4
    else:
        raise NotImplementedError

    use_state = variant['use_robot_state']
    expl_env = FlatEnv(base_env,
                       use_robot_state=use_state,
                       robot_state_dims=state_dims)
    # One env instance is used for both exploration and evaluation.
    eval_env = expl_env

    img_width, img_height = eval_env.image_shape
    num_channels = 3
    action_dim = int(np.prod(eval_env.action_space.shape))

    if use_state:
        head_params = dict(
            added_fc_input_size=state_dims,
            output_conv_channels=False,
            hidden_sizes=[400, 400],
            output_size=200,
        )
    else:
        head_params = dict(
            added_fc_input_size=0,
            output_conv_channels=True,
            output_size=None,
        )
    cnn_params = variant['cnn_params']
    cnn_params.update(
        input_width=img_width,
        input_height=img_height,
        input_channels=num_channels,
        **head_params
    )

    def as_obs_processor(net):
        # With robot state the CNN already ends in fc layers; otherwise
        # its conv output has to be flattened.
        return net if use_state else nn.Sequential(net, Flatten())

    def feature_size(net):
        return net.output_size if use_state else net.conv_output_flat_size

    def q_pair(net):
        """Two Q-functions sharing ``net`` as their observation encoder."""
        kwargs = copy.deepcopy(variant['qf_kwargs'])
        kwargs['obs_processor'] = as_obs_processor(net)
        kwargs['output_size'] = 1
        kwargs['input_size'] = action_dim + feature_size(net)
        first = MlpQfWithObsProcessor(**kwargs)
        second = MlpQfWithObsProcessor(**kwargs)
        return first, second

    qf1, qf2 = q_pair(CNN(**cnn_params))
    target_qf1, target_qf2 = q_pair(CNN(**cnn_params))

    policy_net = CNN(**cnn_params)
    policy = TanhGaussianPolicyAdapter(as_obs_processor(policy_net),
                                       feature_size(policy_net),
                                       action_dim,
                                       **variant['policy_kwargs'])

    eval_path_collector = MdpPathCollector(
        eval_env, MakeDeterministic(policy),
        **variant['eval_path_collector_kwargs'])
    replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['trainer_kwargs'])

    mode = variant['collection_mode']
    if mode == 'batch':
        collector_cls, algorithm_cls = MdpPathCollector, TorchBatchRLAlgorithm
    elif mode == 'online':
        collector_cls, algorithm_cls = MdpStepCollector, TorchOnlineRLAlgorithm
    else:
        raise NotImplementedError

    expl_path_collector = collector_cls(
        expl_env, policy, **variant['expl_path_collector_kwargs'])
    algorithm = algorithm_cls(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    algorithm.post_train_funcs.append(VideoSaveFunctionBullet(variant))

    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train SAC on the pixel-observation pointmass env.

    Builds twin conv Q-functions (each pair sharing one CNN, with a separate
    CNN for the target pair) and a conv Tanh-Gaussian policy, then runs
    batch or online collection depending on ``variant['collection_mode']``.

    Args:
        variant: experiment config dict. Keys read here: ``cnn_params``
            (updated in place), ``qf_kwargs``, ``policy_kwargs``,
            ``eval_path_collector_kwargs``, ``expl_path_collector_kwargs``,
            ``replay_buffer_size``, ``trainer_kwargs``, ``collection_mode``
            and ``algo_kwargs``.

    Raises:
        NotImplementedError: if ``variant['collection_mode']`` is neither
            ``'batch'`` nor ``'online'``. Previously this case fell through
            and crashed with UnboundLocalError on ``algorithm``.
    """
    expl_env = FlatEnv(PointmassBaseEnv(observation_mode='pixels',
                                        transpose_image=True),
                       use_robot_state=False)
    # One env instance is used for both exploration and evaluation.
    eval_env = expl_env

    # NOTE(review): image size is hard-coded rather than read from
    # eval_env.image_shape -- confirm it matches the env's rendered images.
    img_width, img_height = (48, 48)
    num_channels = 3

    action_dim = int(np.prod(eval_env.action_space.shape))
    cnn_params = variant['cnn_params']
    # NOTE(review): added_fc_input_size=4 although the env is built with
    # use_robot_state=False; with output_conv_channels=True the extra fc
    # input is presumably unused -- confirm against CNN.
    cnn_params.update(
        input_width=img_width,
        input_height=img_height,
        input_channels=num_channels,
        added_fc_input_size=4,
        output_conv_channels=True,
        output_size=None,
    )

    def make_conv_qfs():
        """Build one CNN and two Q-functions that share it (flattened)."""
        cnn = CNN(**cnn_params)
        qf_kwargs = copy.deepcopy(variant['qf_kwargs'])
        qf_kwargs['obs_processor'] = nn.Sequential(cnn, Flatten())
        qf_kwargs['output_size'] = 1
        qf_kwargs['input_size'] = action_dim + cnn.conv_output_flat_size
        first = MlpQfWithObsProcessor(**qf_kwargs)
        second = MlpQfWithObsProcessor(**qf_kwargs)
        return first, second

    qf1, qf2 = make_conv_qfs()
    target_qf1, target_qf2 = make_conv_qfs()

    policy_cnn = CNN(**cnn_params)
    policy_obs_processor = nn.Sequential(
        policy_cnn,
        Flatten(),
    )
    policy = TanhGaussianPolicyAdapter(policy_obs_processor,
                                       policy_cnn.conv_output_flat_size,
                                       action_dim, **variant['policy_kwargs'])

    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env, eval_policy, **variant['eval_path_collector_kwargs'])
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['trainer_kwargs'])

    collection_mode = variant['collection_mode']
    if collection_mode == 'batch':
        expl_path_collector = MdpPathCollector(
            expl_env, policy, **variant['expl_path_collector_kwargs'])
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            **variant['algo_kwargs'])
    elif collection_mode == 'online':
        expl_path_collector = MdpStepCollector(
            expl_env, policy, **variant['expl_path_collector_kwargs'])
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            **variant['algo_kwargs'])
    else:
        raise NotImplementedError(
            'Unsupported collection_mode: {!r}'.format(collection_mode))
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train SAC on softlearning's image-based 2D fork-reacher env.

    Observations are flat vectors of which ``image_dim`` entries are the
    image and ``non_image_dim`` entries are extra state. Every network routes
    them through ``Split(cnn, identity, image_dim)`` (presumably: CNN on the
    image part, identity on the rest -- confirm Split's contract), then
    ``FlattenEach`` and ``Concat``. Each MLP therefore sees
    ``cnn.conv_output_flat_size + non_image_dim`` features.

    Args:
        variant: experiment config dict. Keys read here: ``cnn_params``
            (updated in place), ``shared_qf_conv``, ``qf_kwargs``,
            ``policy_kwargs``, ``eval_path_collector_kwargs``,
            ``expl_path_collector_kwargs``, ``replay_buffer_size``,
            ``trainer_kwargs``, ``collection_mode`` and ``algo_kwargs``.

    Raises:
        NotImplementedError: if ``variant['collection_mode']`` is neither
            ``'batch'`` nor ``'online'``.
    """
    from softlearning.environments.gym.mujoco.image_pusher_2d import (
        ImageForkReacher2dEnv)

    env_kwargs = {
        'image_shape': (32, 32, 3),
        'arm_goal_distance_cost_coeff': 1.0,
        'arm_object_distance_cost_coeff': 0.0,
    }

    eval_env = ImageForkReacher2dEnv(**env_kwargs)
    expl_env = ImageForkReacher2dEnv(**env_kwargs)

    input_width, input_height, input_channels = eval_env.image_shape
    image_dim = input_width * input_height * input_channels

    action_dim = int(np.prod(eval_env.action_space.shape))
    cnn_params = variant['cnn_params']
    cnn_params.update(
        input_width=input_width,
        input_height=input_height,
        input_channels=input_channels,
        added_fc_input_size=4,
        output_conv_channels=True,
        output_size=None,
    )
    non_image_dim = int(np.prod(eval_env.observation_space.shape)) - image_dim

    def make_obs_processor(cnn):
        """Split the obs, encode the image part with ``cnn``, flatten, join."""
        return nn.Sequential(
            Split(cnn, identity, image_dim),
            FlattenEach(),
            Concat(),
        )

    def make_qf(obs_processor, cnn):
        """Q-function over ``obs_processor`` features plus the action."""
        qf_kwargs = copy.deepcopy(variant['qf_kwargs'])
        qf_kwargs['obs_processor'] = obs_processor
        qf_kwargs['output_size'] = 1
        qf_kwargs['input_size'] = (action_dim + cnn.conv_output_flat_size +
                                   non_image_dim)
        return MlpQfWithObsProcessor(**qf_kwargs)

    if variant['shared_qf_conv']:
        # qf1/qf2 share one processor; the target pair shares another.
        qf_cnn = CNN(**cnn_params)
        qf_obs_processor = make_obs_processor(qf_cnn)
        qf1 = make_qf(qf_obs_processor, qf_cnn)
        qf2 = make_qf(qf_obs_processor, qf_cnn)

        target_qf_cnn = CNN(**cnn_params)
        target_qf_obs_processor = make_obs_processor(target_qf_cnn)
        target_qf1 = make_qf(target_qf_obs_processor, target_qf_cnn)
        target_qf2 = make_qf(target_qf_obs_processor, target_qf_cnn)
    else:
        # Each Q-function gets its own CNN, wrapped in the same split /
        # flatten / concat processor as above. A bare conv CNN would hand the
        # MLP an unflattened conv tensor and drop the non-image features,
        # which does not match the declared input_size.
        def make_independent_qf():
            cnn = CNN(**cnn_params)
            return make_qf(make_obs_processor(cnn), cnn)

        qf1 = make_independent_qf()
        qf2 = make_independent_qf()
        target_qf1 = make_independent_qf()
        target_qf2 = make_independent_qf()

    policy_cnn = CNN(**cnn_params)
    policy = TanhGaussianPolicyAdapter(
        make_obs_processor(policy_cnn),
        policy_cnn.conv_output_flat_size + non_image_dim,
        action_dim, **variant['policy_kwargs'])

    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env, eval_policy, **variant['eval_path_collector_kwargs'])
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['trainer_kwargs'])

    collection_mode = variant['collection_mode']
    if collection_mode == 'batch':
        expl_path_collector = MdpPathCollector(
            expl_env, policy, **variant['expl_path_collector_kwargs'])
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            **variant['algo_kwargs'])
    elif collection_mode == 'online':
        expl_path_collector = MdpStepCollector(
            expl_env, policy, **variant['expl_path_collector_kwargs'])
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            **variant['algo_kwargs'])
    else:
        raise NotImplementedError(
            'Unsupported collection_mode: {!r}'.format(collection_mode))
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train image-based SAC on the softlearning ``Pusher2d-ImageReach-v0`` env.

    Builds CNN-backed Q-functions and a CNN-backed tanh-Gaussian policy, then
    trains them with ``SACTrainer`` using the algorithm selected by
    ``variant['collection_mode']``. When ``variant['shared_qf_conv']`` is
    true, ``qf1``/``qf2`` share one conv trunk and ``target_qf1``/
    ``target_qf2`` share another; otherwise every Q-network has its own.

    Args:
        variant: Experiment configuration. Keys read here: ``cnn_params``,
            ``shared_qf_conv``, ``qf_kwargs``, ``replay_buffer_size``,
            ``trainer_kwargs``, ``eval_path_collector_kwargs``,
            ``expl_path_collector_kwargs``, ``collection_mode`` and
            ``algo_kwargs``; optionally ``policy_kwargs`` (default ``{}``)
            and ``gpu_id`` (default ``0``).

    Raises:
        NotImplementedError: If ``collection_mode`` is not ``'batch'``,
            ``'online'`` or ``'parallel'``.
    """
    ptu.set_gpu_mode(True, variant.get('gpu_id', 0))

    from softlearning.environments.gym import register_image_reach
    register_image_reach()
    env = gym.make(
        'Pusher2d-ImageReach-v0',
        arm_goal_distance_cost_coeff=1.0,
        arm_object_distance_cost_coeff=0.0,
    )
    eval_env = expl_env = env

    input_width, input_height = env.image_shape
    action_dim = int(np.prod(env.action_space.shape))

    # Work on a copy so the caller's ``variant`` is not mutated in place.
    cnn_params = copy.deepcopy(variant['cnn_params'])
    cnn_params.update(
        input_width=input_width,
        input_height=input_height,
        input_channels=3,
        added_fc_input_size=4,
        output_conv_channels=True,
        output_size=None,
    )

    def make_qf(cnn):
        # The CNN is configured with output_conv_channels=True, so its output
        # is flattened before it is combined with the action; the MLP input
        # width is therefore conv_output_flat_size + action_dim.
        return MlpQfWithObsProcessor(
            obs_processor=nn.Sequential(cnn, Flatten()),
            output_size=1,
            input_size=action_dim + cnn.conv_output_flat_size,
            **variant['qf_kwargs']
        )

    # Construction order matches the original so parameter initialisation
    # consumes the RNG in the same sequence.
    if variant['shared_qf_conv']:
        qf_cnn = CNN(**cnn_params)
        qf1 = make_qf(qf_cnn)
        qf2 = make_qf(qf_cnn)
        target_qf_cnn = CNN(**cnn_params)
        target_qf1 = make_qf(target_qf_cnn)
        target_qf2 = make_qf(target_qf_cnn)
    else:
        qf1 = make_qf(CNN(**cnn_params))
        qf2 = make_qf(CNN(**cnn_params))
        target_qf1 = make_qf(CNN(**cnn_params))
        target_qf2 = make_qf(CNN(**cnn_params))

    policy_cnn = CNN(**cnn_params)
    policy = TanhGaussianPolicyAdapter(
        nn.Sequential(policy_cnn, Flatten()),
        policy_cnn.conv_output_flat_size,
        action_dim,
        **variant.get('policy_kwargs', {})
    )

    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
        **variant['eval_path_collector_kwargs']
    )
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs']
    )

    collection_mode = variant['collection_mode']
    if collection_mode == 'batch':
        expl_path_collector = MdpPathCollector(
            expl_env, policy, **variant['expl_path_collector_kwargs'])
        algorithm_cls = TorchBatchRLAlgorithm
    elif collection_mode == 'online':
        expl_path_collector = MdpStepCollector(
            expl_env, policy, **variant['expl_path_collector_kwargs'])
        algorithm_cls = TorchOnlineRLAlgorithm
    elif collection_mode == 'parallel':
        expl_path_collector = MdpPathCollector(
            expl_env, policy, **variant['expl_path_collector_kwargs'])
        algorithm_cls = TorchParallelRLAlgorithm
    else:
        # Without this, an unknown mode would leave ``algorithm`` unbound.
        raise NotImplementedError(
            'Unsupported collection_mode: {!r}'.format(collection_mode))

    algorithm = algorithm_cls(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
Пример #9
0
def experiment(variant):
    """Train image-based SAC on multiworld's ``Image48SawyerPushForwardEnv-v0``.

    Separate evaluation and exploration instances of the env are created and
    patched in place to transpose images and to use the dense
    ``'puck_distance'`` reward. ``qf1``/``qf2`` share one conv trunk,
    ``target_qf1``/``target_qf2`` share another, and the policy has its own.

    Args:
        variant: Experiment configuration. Keys read here: ``cnn_params``,
            ``qf_kwargs``, ``policy_kwargs``, ``replay_buffer_size``,
            ``trainer_kwargs``, ``eval_path_collector_kwargs``,
            ``expl_path_collector_kwargs``, ``collection_mode`` and
            ``algo_kwargs``.

    Raises:
        NotImplementedError: If ``collection_mode`` is not ``'batch'`` or
            ``'online'``.
    """
    from multiworld.envs.mujoco import register_goal_example_envs
    register_goal_example_envs()

    eval_env = gym.make('Image48SawyerPushForwardEnv-v0')
    expl_env = gym.make('Image48SawyerPushForwardEnv-v0')
    for env in (eval_env, expl_env):
        # HACK: patch the wrapped envs directly: transpose images, and use the
        # dense 'puck_distance' reward instead.
        env.wrapped_env.transpose = True
        env.wrapped_env.wrapped_env.reward_type = 'puck_distance'

    img_width, img_height = eval_env.image_shape
    num_channels = 3
    action_dim = int(np.prod(eval_env.action_space.shape))

    # Work on a copy so the caller's ``variant`` is not mutated in place.
    cnn_params = copy.deepcopy(variant['cnn_params'])
    cnn_params.update(
        input_width=img_width,
        input_height=img_height,
        input_channels=num_channels,
        added_fc_input_size=4,
        output_conv_channels=True,
        output_size=None,
    )

    def make_qf_pair():
        # Two Q-functions sharing one conv trunk; the conv feature maps are
        # flattened before being combined with the action.
        cnn = CNN(**cnn_params)
        qf_kwargs = copy.deepcopy(variant['qf_kwargs'])
        qf_kwargs['obs_processor'] = nn.Sequential(cnn, Flatten())
        qf_kwargs['output_size'] = 1
        qf_kwargs['input_size'] = action_dim + cnn.conv_output_flat_size
        return (MlpQfWithObsProcessor(**qf_kwargs),
                MlpQfWithObsProcessor(**qf_kwargs))

    qf1, qf2 = make_qf_pair()
    target_qf1, target_qf2 = make_qf_pair()

    policy_cnn = CNN(**cnn_params)
    policy = TanhGaussianPolicyAdapter(nn.Sequential(policy_cnn, Flatten()),
                                       policy_cnn.conv_output_flat_size,
                                       action_dim, **variant['policy_kwargs'])

    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env, eval_policy, **variant['eval_path_collector_kwargs'])
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['trainer_kwargs'])

    collection_mode = variant['collection_mode']
    if collection_mode == 'batch':
        expl_path_collector = MdpPathCollector(
            expl_env, policy, **variant['expl_path_collector_kwargs'])
        algorithm_cls = TorchBatchRLAlgorithm
    elif collection_mode == 'online':
        expl_path_collector = MdpStepCollector(
            expl_env, policy, **variant['expl_path_collector_kwargs'])
        algorithm_cls = TorchOnlineRLAlgorithm
    else:
        # Without this, an unknown mode would leave ``algorithm`` unbound.
        raise NotImplementedError(
            'Unsupported collection_mode: {!r}'.format(collection_mode))

    algorithm = algorithm_cls(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train image-based SAC on multiworld's ``Image48SawyerReachXYEnv-v1``.

    Observations are taken from the ``'image_proprio_observation'`` key of
    the env's dict observations. When ``variant['shared_qf_conv']`` is true,
    ``qf1``/``qf2`` share one conv trunk and ``target_qf1``/``target_qf2``
    share another; otherwise every Q-network has its own.

    Args:
        variant: Experiment configuration. Keys read here: ``cnn_params``,
            ``shared_qf_conv``, ``qf_kwargs``, ``policy_kwargs``,
            ``replay_buffer_size``, ``replay_buffer_kwargs``,
            ``trainer_kwargs``, ``eval_path_collector_kwargs``,
            ``expl_path_collector_kwargs``, ``collection_mode`` and
            ``algo_kwargs``.

    Raises:
        NotImplementedError: If ``collection_mode`` is not ``'batch'`` or
            ``'online'``.
    """
    import multiworld
    multiworld.register_all_envs()
    env = gym.make('Image48SawyerReachXYEnv-v1')
    eval_env = expl_env = env
    observation_key = 'image_proprio_observation'
    input_width, input_height = env.image_shape
    action_dim = int(np.prod(env.action_space.shape))

    # Work on a copy so the caller's ``variant`` is not mutated in place.
    cnn_params = copy.deepcopy(variant['cnn_params'])
    cnn_params.update(
        input_width=input_width,
        input_height=input_height,
        input_channels=3,
        # NOTE(review): 3 extra non-image inputs are declared here, but the
        # MLP input sizes below only count the conv output (no +3). Confirm
        # the proprio part of the observation is meant to be dropped.
        added_fc_input_size=3,
        output_conv_channels=True,
        output_size=None,
    )

    def make_qf(cnn):
        # Conv feature maps are flattened before being combined with the
        # action; the MLP input width is conv_output_flat_size + action_dim.
        return MlpQfWithObsProcessor(
            obs_processor=nn.Sequential(cnn, Flatten()),
            output_size=1,
            input_size=action_dim + cnn.conv_output_flat_size,
            **variant['qf_kwargs'])

    # Construction order matches the original so parameter initialisation
    # consumes the RNG in the same sequence.
    if variant['shared_qf_conv']:
        qf_cnn = CNN(**cnn_params)
        qf1 = make_qf(qf_cnn)
        qf2 = make_qf(qf_cnn)
        target_qf_cnn = CNN(**cnn_params)
        target_qf1 = make_qf(target_qf_cnn)
        target_qf2 = make_qf(target_qf_cnn)
    else:
        qf1 = make_qf(CNN(**cnn_params))
        qf2 = make_qf(CNN(**cnn_params))
        target_qf1 = make_qf(CNN(**cnn_params))
        target_qf2 = make_qf(CNN(**cnn_params))

    policy_cnn = CNN(**cnn_params)
    policy = TanhGaussianPolicyAdapter(nn.Sequential(policy_cnn, Flatten()),
                                       policy_cnn.conv_output_flat_size,
                                       action_dim, **variant['policy_kwargs'])

    eval_policy = MakeDeterministic(policy)
    eval_path_collector = ObsDictPathCollector(
        eval_env,
        eval_policy,
        observation_key=observation_key,
        **variant['eval_path_collector_kwargs'])
    replay_buffer = ObsDictReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
        observation_key=observation_key,
        **variant['replay_buffer_kwargs']
    )
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['trainer_kwargs'])

    collection_mode = variant['collection_mode']
    if collection_mode == 'batch':
        collector_cls = ObsDictPathCollector
        algorithm_cls = TorchBatchRLAlgorithm
    elif collection_mode == 'online':
        collector_cls = ObsDictStepCollector
        algorithm_cls = TorchOnlineRLAlgorithm
    else:
        # Without this, an unknown mode would leave ``algorithm`` unbound.
        raise NotImplementedError(
            'Unsupported collection_mode: {!r}'.format(collection_mode))

    expl_path_collector = collector_cls(
        expl_env,
        policy,
        observation_key=observation_key,
        **variant['expl_path_collector_kwargs'])
    algorithm = algorithm_cls(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Пример #11
0
def experiment(variant):
    """Train image-based SAC on a roboverse env from a pickled replay buffer.

    The env is created with ``roboverse.make`` (shaped reward, transposed
    images) and is used for both exploration and evaluation. The replay
    buffer is unpickled from ``variant['buffer']`` and
    ``variant['load_buffer']`` is forwarded to ``TorchBatchRLAlgorithm`` as
    ``batch_rl``. ``qf1``/``qf2`` share one conv trunk and
    ``target_qf1``/``target_qf2`` share another. A ``VideoSaveFunctionBullet``
    hook is appended to ``algorithm.post_train_funcs``.

    Args:
        variant: Experiment configuration. Keys read here: ``env``, ``obs``,
            ``cnn_params``, ``qf_kwargs``, ``policy_kwargs``,
            ``eval_path_collector_kwargs``, ``buffer``, ``trainer_kwargs``,
            ``load_buffer`` and ``algorithm_kwargs`` (plus whatever
            ``VideoSaveFunctionBullet`` reads).
    """
    expl_env = roboverse.make(variant['env'],
                              gui=False,
                              randomize=True,
                              observation_mode=variant['obs'],
                              reward_type='shaped',
                              transpose_image=True)
    eval_env = expl_env
    img_width, img_height = eval_env.image_shape
    num_channels = 3
    action_dim = int(np.prod(eval_env.action_space.shape))

    # Work on a copy so the caller's ``variant`` is not mutated in place.
    cnn_params = copy.deepcopy(variant['cnn_params'])
    cnn_params.update(
        input_width=img_width,
        input_height=img_height,
        input_channels=num_channels,
        added_fc_input_size=0,
        output_conv_channels=True,
        output_size=None,
    )

    def make_qf_pair():
        # Two Q-functions sharing one conv trunk; the conv feature maps are
        # flattened before being combined with the action.
        cnn = CNN(**cnn_params)
        qf_kwargs = copy.deepcopy(variant['qf_kwargs'])
        qf_kwargs['obs_processor'] = nn.Sequential(cnn, Flatten())
        qf_kwargs['output_size'] = 1
        qf_kwargs['input_size'] = action_dim + cnn.conv_output_flat_size
        return (MlpQfWithObsProcessor(**qf_kwargs),
                MlpQfWithObsProcessor(**qf_kwargs))

    qf1, qf2 = make_qf_pair()
    target_qf1, target_qf2 = make_qf_pair()

    policy_cnn = CNN(**cnn_params)
    policy = TanhGaussianPolicyAdapter(nn.Sequential(policy_cnn, Flatten()),
                                       policy_cnn.conv_output_flat_size,
                                       action_dim, **variant['policy_kwargs'])

    eval_policy = MakeDeterministic(policy)
    observation_key = 'image'

    eval_path_collector = ObsDictPathCollector(
        eval_env,
        eval_policy,
        observation_key=observation_key,
        **variant['eval_path_collector_kwargs'])

    expl_path_collector = CustomObsDictPathCollector(
        expl_env,
        observation_key=observation_key,
    )

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only point variant['buffer'] at trusted buffers.
    with open(variant['buffer'], 'rb') as f:
        replay_buffer = pickle.load(f)

    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         behavior_policy=None,
                         **variant['trainer_kwargs'])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        eval_both=True,
        batch_rl=variant['load_buffer'],
        **variant['algorithm_kwargs'])

    video_func = VideoSaveFunctionBullet(variant)
    algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Train image-based SAC on a procedural Kuka grasping env.

    A single ``FlatEnv(KukaGraspingProceduralEnv(...))`` instance is used for
    both exploration and evaluation, with observations read from the
    ``'image'`` key. ``qf1``/``qf2`` share one conv trunk and
    ``target_qf1``/``target_qf2`` share another. A
    ``VideoSaveFunctionBullet`` hook is appended to
    ``algorithm.post_train_funcs``.

    Args:
        variant: Experiment configuration. Keys read here: ``cnn_params``,
            ``qf_kwargs``, ``policy_kwargs``, ``replay_buffer_size``,
            ``trainer_kwargs``, ``eval_path_collector_kwargs``,
            ``expl_path_collector_kwargs``, ``collection_mode`` and
            ``algo_kwargs`` (plus whatever ``VideoSaveFunctionBullet`` reads).

    Raises:
        NotImplementedError: If ``collection_mode`` is not ``'batch'`` or
            ``'online'``.
    """
    env_params = dict(
        block_random=0.3,
        camera_random=0,
        simple_observations=False,
        continuous=True,
        remove_height_hack=True,
        render_mode="DIRECT",  # alternative: "GUI"
        num_objects=5,
        max_num_training_models=900,
        target=False,
        test=False,
    )
    expl_env = FlatEnv(KukaGraspingProceduralEnv(**env_params))
    eval_env = expl_env
    img_width, img_height = eval_env.image_shape
    num_channels = 3
    action_dim = int(np.prod(eval_env.action_space.shape))

    # Work on a copy so the caller's ``variant`` is not mutated in place.
    cnn_params = copy.deepcopy(variant['cnn_params'])
    cnn_params.update(
        input_width=img_width,
        input_height=img_height,
        input_channels=num_channels,
        added_fc_input_size=0,
        output_conv_channels=True,
        output_size=None,
    )

    def make_qf_pair():
        # Two Q-functions sharing one conv trunk; the conv feature maps are
        # flattened before being combined with the action.
        cnn = CNN(**cnn_params)
        qf_kwargs = copy.deepcopy(variant['qf_kwargs'])
        qf_kwargs['obs_processor'] = nn.Sequential(cnn, Flatten())
        qf_kwargs['output_size'] = 1
        qf_kwargs['input_size'] = action_dim + cnn.conv_output_flat_size
        return (MlpQfWithObsProcessor(**qf_kwargs),
                MlpQfWithObsProcessor(**qf_kwargs))

    qf1, qf2 = make_qf_pair()
    target_qf1, target_qf2 = make_qf_pair()

    policy_cnn = CNN(**cnn_params)
    policy = TanhGaussianPolicyAdapter(nn.Sequential(policy_cnn, Flatten()),
                                       policy_cnn.conv_output_flat_size,
                                       action_dim, **variant['policy_kwargs'])

    observation_key = 'image'
    eval_policy = MakeDeterministic(policy)
    eval_path_collector = ObsDictPathCollector(
        eval_env,
        eval_policy,
        observation_key=observation_key,
        **variant['eval_path_collector_kwargs'])
    replay_buffer = ObsDictReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
        observation_key=observation_key,
    )

    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['trainer_kwargs'])

    collection_mode = variant['collection_mode']
    if collection_mode == 'batch':
        collector_cls = ObsDictPathCollector
        algorithm_cls = TorchBatchRLAlgorithm
    elif collection_mode == 'online':
        collector_cls = ObsDictStepCollector
        algorithm_cls = TorchOnlineRLAlgorithm
    else:
        raise NotImplementedError

    expl_path_collector = collector_cls(
        expl_env,
        policy,
        observation_key=observation_key,
        **variant['expl_path_collector_kwargs'])
    algorithm = algorithm_cls(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    video_func = VideoSaveFunctionBullet(variant)
    algorithm.post_train_funcs.append(video_func)

    # To also save the replay buffer, uncomment:
    # dump_buffer_func = BufferSaveFunction(variant)
    # algorithm.post_train_funcs.append(dump_buffer_func)

    algorithm.to(ptu.device)
    algorithm.train()