Example #1
0
def load_widowx_real_robot():
    """Load the WidowX real-robot dataset as train/validation datasets.

    Reads a real-robot data file into an ``ObsDictReplayBuffer`` keyed on
    ``'image'``. Loading uses ``add_data_to_buffer_real_robot`` from rlkit
    (railrl-private). The function then wraps that single buffer in two
    ``WidowXDataset`` objects: one with ``train=True`` and one with
    ``train=False``.

    Returns:
        tuple: ``(train, val)`` WidowXDataset instances over the same buffer.
    """
    import os

    from rlkit.data_management.obs_dict_replay_buffer import ObsDictReplayBuffer
    from rlkit.misc.wx250_utils import add_data_to_buffer_real_robot, DummyEnv

    image_size = 64
    # Square RGB images; derive dataset dims from image_size so the env and
    # the datasets cannot drift apart.
    image_dims = (image_size, image_size, 3)
    expl_env = DummyEnv(image_size=image_size)
    data_folder_path = '/nfs/kun1/users/albert/realrobot_datasets'
    data_file_path = os.path.join(data_folder_path,
                                  'combined_2021-05-20_21_53_31.pkl')

    replay_buffer = ObsDictReplayBuffer(int(1E6),
                                        expl_env,
                                        observation_keys=['image'])
    # NOTE(review): validation_fraction is passed while no validation buffer
    # is given; the train/val split appears to come from WidowXDataset's
    # `train` flag instead -- confirm.
    add_data_to_buffer_real_robot(data_file_path,
                                  replay_buffer,
                                  validation_replay_buffer=None,
                                  validation_fraction=0.8)

    train = WidowXDataset(replay_buffer,
                          train=True,
                          normalize=False,
                          image_dims=image_dims)
    val = WidowXDataset(replay_buffer,
                        train=False,
                        normalize=False,
                        image_dims=image_dims)

    return train, val
Example #2
0
def experiment(variant):
    """Resume training from a saved checkpoint with online data collection.

    Loads ``itr_<checkpoint_epoch>.pkl`` from ``variant['checkpoint_dir']``.
    It reuses the checkpoint's trainer (and that trainer's policy) and its
    ``'evaluation/policy'``, then overrides the trainer's ``min_q_weight``
    with ``variant['trainer_kwargs']['min_q_weight']``. Training runs with
    ``TorchBatchRLAlgorithm`` on ``roboverse.make(variant['env'])``.

    If ``variant['online_data_only']`` is truthy, the replay buffer starts
    empty. Otherwise it is pre-filled by ``load_data_from_npy_chaining``.

    Raises:
        ValueError: if ``variant['trainer_kwargs']['min_q_weight']`` is not
            strictly positive.
    """
    trainer_kwargs = variant['trainer_kwargs']
    min_q_weight = trainer_kwargs['min_q_weight']
    # Validate the config before the expensive checkpoint load. An explicit
    # raise (instead of `assert`) keeps the check active under `python -O`.
    if not min_q_weight > 0.:
        raise ValueError(
            "trainer_kwargs['min_q_weight'] must be > 0, got {!r}".format(
                min_q_weight))

    checkpoint_filepath = os.path.join(
        variant['checkpoint_dir'],
        'itr_{}.pkl'.format(variant['checkpoint_epoch']))
    checkpoint = torch.load(checkpoint_filepath)

    # Envs are rebuilt rather than restored: using the pickled
    # checkpoint['evaluation/env'] / checkpoint['exploration/env'] does not
    # work for Bullet envs yet.
    eval_env = roboverse.make(variant['env'], transpose_image=True)
    expl_env = eval_env

    trainer = checkpoint['trainer/trainer']
    policy = trainer.policy
    eval_policy = checkpoint['evaluation/policy']
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )

    observation_key = 'image'
    # Capacity for 500 * 10 paths of max_path_length transitions each.
    online_buffer_size = 500 * 10 * variant['algorithm_kwargs'][
        'max_path_length']

    if variant['online_data_only']:
        replay_buffer = ObsDictReplayBuffer(online_buffer_size, expl_env,
                                            observation_key=observation_key)
    else:
        replay_buffer = load_data_from_npy_chaining(
            variant, expl_env, observation_key,
            extra_buffer_size=online_buffer_size)

    trainer.min_q_weight = min_q_weight

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        eval_both=False,
        batch_rl=False,
        **variant['algorithm_kwargs']
    )
    video_func = VideoSaveFunction(variant)
    algorithm.post_epoch_funcs.append(video_func)

    algorithm.to(ptu.device)
    algorithm.train()
Example #3
0
def load_data_from_npy_chaining(variant,
                                expl_env,
                                observation_key,
                                extra_buffer_size=100):
    """Build one replay buffer holding prior data followed by task data.

    Loads the pickled numpy datasets ``variant['prior_buffer']`` and
    ``variant['task_buffer']`` into a single ``ObsDictReplayBuffer``. The
    buffer is sized for both datasets plus ``extra_buffer_size`` further
    transitions. Rewards of the prior-data transitions are set to zero.

    If ``variant['biased_sampling']`` is truthy, the buffer uses biased
    sampling. The bias point is placed right after the offline data, and the
    probability of sampling before it is
    ``variant.get('before_bias_point_probability', 0.5)``.

    Args:
        variant: experiment config dict.
        expl_env: env passed to ``ObsDictReplayBuffer``.
        observation_key: observation key stored by the buffer.
        extra_buffer_size: additional capacity beyond the offline data.

    Returns:
        The filled ``ObsDictReplayBuffer``.
    """
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted
    # files.
    with open(variant['prior_buffer'], 'rb') as f:
        data_prior = np.load(f, allow_pickle=True)
    with open(variant['task_buffer'], 'rb') as f:
        data_task = np.load(f, allow_pickle=True)

    offline_size = get_buffer_size(data_prior) + get_buffer_size(data_task)
    buffer_size = offline_size + extra_buffer_size

    buffer_kwargs = {}
    if variant.get('biased_sampling', False):
        # Offline data fills the first `offline_size` slots; anything added
        # later (online data) lands after the bias point.
        bias_point = offline_size
        print('Setting bias point', bias_point)
        buffer_kwargs = dict(
            biased_sampling=True,
            bias_point=bias_point,
            before_bias_point_probability=variant.get(
                'before_bias_point_probability', 0.5),
        )
    replay_buffer = ObsDictReplayBuffer(
        buffer_size,
        expl_env,
        observation_key=observation_key,
        **buffer_kwargs
    )

    add_data_to_buffer(data_prior, replay_buffer)
    top = replay_buffer._top
    print('Prior data loaded from npy file', top)
    # Assign directly: `0.0 * rewards` would leave NaN/inf entries non-zero.
    replay_buffer._rewards[:top] = 0.0
    print('Zero-ed the rewards for prior data', top)

    add_data_to_buffer(data_task, replay_buffer)
    print('Task data loaded from npy file', replay_buffer._top)
    return replay_buffer
Example #4
0
def load_data_from_npy(variant,
                       expl_env,
                       observation_key,
                       extra_buffer_size=100):
    """Return an ObsDictReplayBuffer pre-filled from ``variant['buffer']``.

    The buffer's capacity is the number of transitions in the loaded data
    plus ``extra_buffer_size``, leaving room for further (e.g. online) data.
    """
    source_path = variant['buffer']
    # NOTE: allow_pickle=True unpickles arbitrary objects; trusted files only.
    with open(source_path, 'rb') as npy_file:
        offline_data = np.load(npy_file, allow_pickle=True)

    capacity = get_buffer_size(offline_data) + extra_buffer_size
    buffer = ObsDictReplayBuffer(capacity, expl_env,
                                 observation_key=observation_key)
    add_data_to_buffer(offline_data, buffer)
    print('Data loaded from npy file', buffer._top)
    return buffer
Example #5
0
def active_representation_learning_experiment(variant):
    """Train SAC on a representation-wrapped image env alongside its model.

    Builds ``variant['model_class'](**variant['model_kwargs'])``. It wraps
    ``variant['env_class'](**variant['env_kwargs'])`` first in an
    ``ImageEnv`` and then in a ``RepresentationWrappedEnv`` around that
    model. Training runs with ``ActiveRepresentationLearningAlgorithm``,
    using a ``SACTrainer`` plus a model trainer built from
    ``variant['model_trainer_class']``.

    Raises:
        ValueError: if ``'vae_training_schedule'`` is a top-level key of
            ``variant`` (it belongs in ``variant['algo_kwargs']``).
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import ObsDictReplayBuffer
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.arl.active_representation_learning_algorithm import \
        ActiveRepresentationLearningAlgorithm
    from rlkit.torch.arl.representation_wrappers import RepresentationWrappedEnv
    from multiworld.core.image_env import ImageEnv
    from rlkit.samplers.data_collector import MdpPathCollector

    preprocess_rl_variant(variant)

    model_class = variant.get('model_class')
    # `or {}` so a missing/None entry means "no kwargs" instead of `**None`.
    model_kwargs = variant.get('model_kwargs') or {}

    model = model_class(**model_kwargs)
    # NOTE(review): hard-coded overrides. imsize=48 is not tied to
    # variant['imsize'] used for the ImageEnv below -- confirm they agree.
    model.representation_size = 4
    model.imsize = 48
    # NOTE(review): stores the model object (not a path) under "vae_path";
    # presumably read by code outside this function -- confirm.
    variant["vae_path"] = model

    init_camera = variant.get("init_camera", None)
    env = variant["env_class"](**variant['env_kwargs'])
    image_env = ImageEnv(
        env,
        variant.get('imsize'),
        init_camera=init_camera,
        transpose=True,
        normalize=True,
    )
    env = RepresentationWrappedEnv(
        image_env,
        model,
    )

    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs'])
    else:
        uniform_dataset = None

    observation_key = variant.get('observation_key', 'latent_observation')
    obs_dim = env.observation_space.spaces[observation_key].low.size
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    def make_qf():
        # All four Q-networks share one architecture: (obs, action) -> scalar.
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    vae = env.vae

    replay_buffer = ObsDictReplayBuffer(env=env,
                                        **variant['replay_buffer_kwargs'])

    model_trainer_class = variant.get('model_trainer_class')
    model_trainer_kwargs = variant.get('model_trainer_kwargs') or {}
    model_trainer = model_trainer_class(
        model,
        **model_trainer_kwargs,
    )
    # Explicit raise (not `assert`) so the check survives `python -O`.
    if 'vae_training_schedule' in variant:
        raise ValueError("Just put it in algo_kwargs")
    max_path_length = variant['max_path_length']

    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['twin_sac_trainer_kwargs'])
    eval_path_collector = MdpPathCollector(
        env,
        MakeDeterministic(policy),
    )
    expl_path_collector = MdpPathCollector(
        env,
        policy,
    )

    algorithm = ActiveRepresentationLearningAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        model=model,
        model_trainer=model_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Example #6
0
def experiment(variant):
    """Train new Q-functions with SAC, starting from a checkpointed policy.

    Loads ``itr_<checkpoint_epoch>.pkl`` from ``variant['checkpoint_dir']``
    and reuses its ``'evaluation/policy'`` for exploration. Evaluation uses a
    ``MakeDeterministic`` wrapper of that policy. Two Q-functions and their
    targets are built fresh as ``ConcatCNN`` networks from
    ``variant['cnn_params']``, with a 48x48x3 input and the action appended
    as extra FC input.

    If ``variant['online_data_only']`` is truthy, the replay buffer starts
    empty. Otherwise it is pre-filled by ``load_data_from_npy_chaining``.
    """
    checkpoint_filepath = os.path.join(variant['checkpoint_dir'],
                                       'itr_{}.pkl'.format(
                                           variant['checkpoint_epoch']))
    checkpoint = torch.load(checkpoint_filepath)

    eval_env = roboverse.make(variant['env'], transpose_image=True)
    expl_env = eval_env

    action_dim = eval_env.action_space.low.size
    # Copy first: updating variant['cnn_params'] in place would leak these
    # overrides back into the caller's config dict.
    cnn_params = dict(variant['cnn_params'])
    cnn_params.update(
        input_width=48,
        input_height=48,
        input_channels=3,
        output_size=1,
        added_fc_input_size=action_dim,
    )
    qf1 = ConcatCNN(**cnn_params)
    qf2 = ConcatCNN(**cnn_params)
    target_qf1 = ConcatCNN(**cnn_params)
    target_qf2 = ConcatCNN(**cnn_params)

    policy = checkpoint['evaluation/policy']
    eval_policy = MakeDeterministic(policy)

    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )

    observation_key = 'image'
    # Capacity for 500 * 10 paths of max_path_length transitions each.
    online_buffer_size = 500 * 10 * variant['algorithm_kwargs'][
        'max_path_length']

    if variant['online_data_only']:
        replay_buffer = ObsDictReplayBuffer(online_buffer_size, expl_env,
                                            observation_key=observation_key)
    else:
        replay_buffer = load_data_from_npy_chaining(
            variant, expl_env, observation_key,
            extra_buffer_size=online_buffer_size)

    trainer_kwargs = variant['trainer_kwargs']
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **trainer_kwargs
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        eval_both=False,
        batch_rl=False,
        **variant['algorithm_kwargs']
    )
    video_func = VideoSaveFunction(variant)
    algorithm.post_epoch_funcs.append(video_func)

    algorithm.to(ptu.device)
    algorithm.train()
Example #7
0
def td3_experiment(variant):
    """Run TD3 with ``TorchBatchRLAlgorithm`` on a gym/multiworld env.

    Environments come from ``gym.make(variant['env_id'])`` when
    ``'env_id'`` is present. Otherwise they come from
    ``variant['env_class']`` with ``env_kwargs`` (exploration) and
    ``eval_env_kwargs``, falling back to ``env_kwargs`` (evaluation).
    Network input sizes are taken from the eval env's ``'observation'`` and
    ``'desired_goal'`` spaces.

    Raises:
        NotImplementedError: if ``variant['normalize']`` is truthy.
        ValueError: if ``variant['exploration_type']`` is not one of
            ``'ou'``, ``'gaussian'`` or ``'epsilon'``.
    """
    import gym
    # Not referenced by name; presumably imported for env-registration side
    # effects.
    import multiworld.envs.mujoco  # noqa: F401
    import multiworld.envs.pygame  # noqa: F401
    import rlkit.torch.pytorch_util as ptu
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy
    from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy
    from rlkit.exploration_strategies.ou_strategy import OUStrategy
    from rlkit.torch.td3.td3 import TD3
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.data_management.obs_dict_replay_buffer import (
        ObsDictReplayBuffer)
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from rlkit.samplers.data_collector.path_collector import ObsDictPathCollector

    if 'env_id' in variant:
        eval_env = gym.make(variant['env_id'])
        expl_env = gym.make(variant['env_id'])
    else:
        eval_env_kwargs = variant.get('eval_env_kwargs', variant['env_kwargs'])
        eval_env = variant['env_class'](**eval_env_kwargs)
        expl_env = variant['env_class'](**variant['env_kwargs'])

    # Key the replay buffer stores observations under (from the config).
    observation_key = variant['observation_key']
    if variant.get('normalize', False):
        raise NotImplementedError()

    replay_buffer = ObsDictReplayBuffer(
        env=eval_env,
        observation_key=observation_key,
        **variant['replay_buffer_kwargs'])
    obs_dim = eval_env.observation_space.spaces['observation'].low.size
    action_dim = eval_env.action_space.low.size
    goal_dim = eval_env.observation_space.spaces['desired_goal'].low.size

    exploration_type = variant['exploration_type']
    strategy_classes = {
        'ou': OUStrategy,
        'gaussian': GaussianStrategy,
        'epsilon': EpsilonGreedy,
    }
    if exploration_type not in strategy_classes:
        raise ValueError("Invalid type: " + exploration_type)
    es = strategy_classes[exploration_type](
        action_space=eval_env.action_space,
        **variant['es_kwargs'],
    )

    qf_input_size = obs_dim + action_dim + goal_dim
    policy_input_size = obs_dim + goal_dim

    def make_qf():
        return ConcatMlp(input_size=qf_input_size,
                         output_size=1,
                         **variant['qf_kwargs'])

    def make_policy():
        return TanhMlpPolicy(input_size=policy_input_size,
                             output_size=action_dim,
                             **variant['policy_kwargs'])

    # Same construction order as before; weight initialisation presumably
    # draws from the global RNG, so order affects the initial weights.
    qf1 = make_qf()
    qf2 = make_qf()
    policy = make_policy()
    target_qf1 = make_qf()
    target_qf2 = make_qf()
    target_policy = make_policy()
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    trainer = TD3(policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  target_qf1=target_qf1,
                  target_qf2=target_qf2,
                  target_policy=target_policy,
                  **variant['trainer_kwargs'])

    # The path collectors always read the 'observation' key, independent of
    # variant['observation_key'] used by the replay buffer above.
    # NOTE(review): the networks expect obs_dim + goal_dim inputs, but no
    # desired_goal_key is passed to the collectors -- confirm that the
    # policy input sizes match what the collectors feed.
    policy_observation_key = 'observation'
    eval_path_collector = ObsDictPathCollector(
        eval_env,
        policy,
        observation_key=policy_observation_key,
    )
    expl_path_collector = ObsDictPathCollector(
        expl_env,
        expl_policy,
        observation_key=policy_observation_key,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    # TODO: video saving is not wired up for this experiment.
    algorithm.to(ptu.device)
    algorithm.train()
Example #8
0
import cv2

from rlkit.data_management.obs_dict_replay_buffer import ObsDictReplayBuffer
from rlkit.misc.wx250_utils import add_data_to_buffer_real_robot, DummyEnv

parser = argparse.ArgumentParser(
    description='Run a saved VQ-VAE encoder on images from a real-robot '
                'replay buffer.')
parser.add_argument('--vqvae', type=str, required=True,
                    help='path to the model file loaded with torch.load')
parser.add_argument('--buffer', type=str, required=True,
                    help='path to the real-robot data file to load')
parser.add_argument('--index-range', type=int, nargs=2, default=[0, 2],
                    metavar=('START', 'STOP'),
                    help='buffer indices np.arange(START, STOP) to encode')
args = parser.parse_args()

image_size = 64
expl_env = DummyEnv(image_size=image_size)

replay_buffer = ObsDictReplayBuffer(int(1E6),
                                    expl_env,
                                    observation_keys=['image'])
add_data_to_buffer_real_robot(args.buffer,
                              replay_buffer,
                              validation_replay_buffer=None,
                              validation_fraction=0.8)

network = torch.load(args.vqvae)
# Reshape to (N, 3, image_size, image_size) for the conv encoder (assumes
# the buffer stores channel-first images, flattened -- TODO confirm).
img = replay_buffer._obs['image'][np.arange(*args.index_range)].reshape(
    -1, 3, image_size, image_size)
img = torch.FloatTensor(img)
img = img.cuda()

z_e = network.encoder(img)
z_e = network.pre_quantization_conv(z_e)
embedding_loss, z_q, perplexity, _, min_encoding_indices = network.vector_quantization(
Example #9
0
def state_td3bc_experiment(variant):
    """Run state-based TD3, or TD3+BC with demonstration buffers.

    Environments come from ``variant['env_id']`` (a multiworld gym env
    wrapped in ``MujocoGymToMultiEnv``) or from ``variant['env_class']``.
    Observations are read from the ``'state_observation'`` key. Goals are
    not fed to the networks (``goal_dim`` is 0).

    When ``variant.get('td3_bc', True)`` is truthy, a ``TD3BCTrainer`` with
    separate demo train/test buffers is used; otherwise a plain ``TD3``
    trainer is used. When ``load_demos`` / ``pretrain_policy`` /
    ``pretrain_rl`` are set, demo loading and pretraining run before
    training.

    Raises:
        ValueError: if ``variant['es']`` is not ``'ou'`` or ``'gauss_eps'``.
    """
    if variant.get('env_id', None):
        import gym
        import multiworld

        multiworld.register_all_envs()

        eval_env = MujocoGymToMultiEnv(gym.make(variant['env_id']))
        expl_env = MujocoGymToMultiEnv(gym.make(variant['env_id']))
    else:
        eval_env_kwargs = variant.get('eval_env_kwargs', variant['env_kwargs'])
        eval_env = variant['env_class'](**eval_env_kwargs)
        expl_env = variant['env_class'](**variant['env_kwargs'])

    observation_key = 'state_observation'
    es_strat = variant.get('es', 'ou')
    if es_strat == 'ou':
        es = OUStrategy(
            action_space=expl_env.action_space,
            max_sigma=variant['exploration_noise'],
            min_sigma=variant['exploration_noise'],
        )
    elif es_strat == 'gauss_eps':
        es = GaussianAndEpislonStrategy(
            action_space=expl_env.action_space,
            max_sigma=variant['exploration_noise'],
            min_sigma=variant['exploration_noise'],  # constant sigma
            epsilon=0,
        )
    else:
        raise ValueError("invalid exploration strategy provided")
    obs_dim = expl_env.observation_space.spaces['observation'].low.size
    # Goals are not part of the network input: the path collectors below
    # only pass observation_key (no desired_goal_key).
    goal_dim = 0
    action_dim = expl_env.action_space.low.size

    qf_input_size = obs_dim + goal_dim + action_dim
    policy_input_size = obs_dim + goal_dim

    def make_qf():
        return ConcatMlp(input_size=qf_input_size,
                         output_size=1,
                         **variant['qf_kwargs'])

    def make_policy():
        return TanhMlpPolicy(input_size=policy_input_size,
                             output_size=action_dim,
                             **variant['policy_kwargs'])

    # Same construction order as before; weight initialisation presumably
    # draws from the global RNG, so order affects the initial weights.
    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()
    policy = make_policy()
    target_policy = make_policy()
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = ObsDictReplayBuffer(
        env=eval_env,
        observation_key=observation_key,
        **variant['replay_buffer_kwargs'])

    if variant.get('td3_bc', True):
        # Demo buffers are only consumed by TD3BCTrainer, so they are only
        # allocated on this path.
        demo_buffer_size = variant['replay_buffer_kwargs']['max_size']
        demo_train_buffer = ObsDictReplayBuffer(
            env=eval_env,
            observation_key=observation_key,
            max_size=demo_buffer_size,
        )
        demo_test_buffer = ObsDictReplayBuffer(
            env=eval_env,
            observation_key=observation_key,
            max_size=demo_buffer_size,
        )
        td3_trainer = TD3BCTrainer(env=expl_env,
                                   policy=policy,
                                   qf1=qf1,
                                   qf2=qf2,
                                   replay_buffer=replay_buffer,
                                   demo_train_buffer=demo_train_buffer,
                                   demo_test_buffer=demo_test_buffer,
                                   target_qf1=target_qf1,
                                   target_qf2=target_qf2,
                                   target_policy=target_policy,
                                   **variant['td3_bc_trainer_kwargs'])
    else:
        td3_trainer = TD3(policy=policy,
                          qf1=qf1,
                          qf2=qf2,
                          target_qf1=target_qf1,
                          target_qf2=target_qf2,
                          target_policy=target_policy,
                          **variant['td3_trainer_kwargs'])
    trainer = td3_trainer
    eval_path_collector = ObsDictPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
    )
    expl_path_collector = ObsDictPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    if variant.get("save_video", True):
        if variant.get("presampled_goals", None):
            variant['image_env_kwargs'][
                'presampled_goals'] = load_local_or_remote_file(
                    variant['presampled_goals']).item()
        image_eval_env = ImageEnv(eval_env, **variant["image_env_kwargs"])
        image_eval_path_collector = ObsDictPathCollector(
            image_eval_env,
            policy,
            observation_key=observation_key,
        )
        image_expl_env = ImageEnv(expl_env, **variant["image_env_kwargs"])
        image_expl_path_collector = ObsDictPathCollector(
            image_expl_env,
            expl_policy,
            observation_key=observation_key,
        )
        video_func = VideoSaveFunction(
            image_eval_env,
            variant,
            image_expl_path_collector,
            image_eval_path_collector,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    # NOTE(review): these run on td3_trainer even when td3_bc is False;
    # presumably only TD3BCTrainer defines them -- confirm before enabling
    # them with plain TD3.
    if variant.get('load_demos', False):
        td3_trainer.load_demos()
    if variant.get('pretrain_policy', False):
        td3_trainer.pretrain_policy_with_bc()
    if variant.get('pretrain_rl', False):
        td3_trainer.pretrain_q_with_bc_data()
    algorithm.train()