Example #1
# Imports below assume the standard rlkit package layout (this project appears to
# use a modified rlkit fork); `env`, `stuff`, and `doc` are project-local modules
# whose import paths depend on the repository layout.
import copy

import rlkit.torch.pytorch_util as ptu
from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.networks import FlattenMlp
from rlkit.torch.sac.policies import MakeDeterministic, TanhGaussianPolicy
from rlkit.torch.sac.sac import SACTrainer
from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm


def experiment(args, variant):
    #eval_env = gym.make('FetchReach-v1')
    #expl_env = gym.make('FetchReach-v1')

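    # Build the custom DeepBuilderEnv and wrap it in NormalizedActions;
    # evaluation and exploration share the same underlying env instance.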
    core_env = env.DeepBuilderEnv(args.session_name, args.act_dim,
                                  args.box_dim, args.max_num_boxes,
                                  args.height_field_dim)
    eval_env = stuff.NormalizedActions(core_env)
    expl_env = stuff.NormalizedActions(core_env)
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    resumed = args.resume == 1

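    # When resuming, reload the saved variant and network parameters and skip the initial warm-up steps.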
    if resumed:
        variant, params = doc.load_rklit_file(args.session_name)
        variant['algorithm_kwargs']['min_num_steps_before_training'] = 0

    M = variant['layer_size']

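    # Twin Q-networks, their target copies, and the policy (standard SAC); reuse the saved networks when resuming.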
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ) if not resumed else params['trainer/qf1']

    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ) if not resumed else params['trainer/qf2']

    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ) if not resumed else params['trainer/target_qf1']

    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ) if not resumed else params['trainer/target_qf2']

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=[M, M],
    ) if not resumed else params['trainer/policy']

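    # Evaluation uses a deterministic version of the stochastic policy.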
    eval_policy = MakeDeterministic(
        policy) if not resumed else params['evaluation/policy']

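    # Path collectors roll out the policies in their respective environments.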
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )

    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )

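    # Separate replay buffers for exploration and evaluation; the eval buffer is
    # sized by the ratio of evaluation plays to exploration plays.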
    replay_buffer_expl = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )

    replay_buffer_eval = EnvReplayBuffer(
        int(variant['replay_buffer_size'] *
            (float(args.num_plays_eval) / float(args.num_plays_expl))),
        eval_env,
    )

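    # When resuming, restore the replay buffer contents saved with the session.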
    if resumed:
        replay_buffer_expl._actions = params['replay_buffer_expl/actions']
        replay_buffer_expl._env_infos = params['replay_buffer_expl/env_infos']
        replay_buffer_expl._next_obs = params['replay_buffer_expl/next_obs']
        replay_buffer_expl._observations = params[
            'replay_buffer_expl/observations']
        replay_buffer_expl._rewards = params['replay_buffer_expl/rewards']
        replay_buffer_expl._size = params['replay_buffer_expl/size']
        replay_buffer_expl._terminals = params['replay_buffer_expl/terminals']
        replay_buffer_expl._top = params['replay_buffer_expl/top']

        replay_buffer_eval._actions = params['replay_buffer_eval/actions']
        replay_buffer_eval._env_infos = params['replay_buffer_eval/env_infos']
        replay_buffer_eval._next_obs = params['replay_buffer_eval/next_obs']
        replay_buffer_eval._observations = params[
            'replay_buffer_eval/observations']
        replay_buffer_eval._rewards = params['replay_buffer_eval/rewards']
        replay_buffer_eval._size = params['replay_buffer_eval/size']
        replay_buffer_eval._terminals = params['replay_buffer_eval/terminals']
        replay_buffer_eval._top = params['replay_buffer_eval/top']

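    # Otherwise, optionally seed the exploration buffer with samples from another
    # session, skipping all-zero (empty) entries.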
    elif args.replay_add_sess_name != '':
        _, other_params = doc.load_rklit_file(args.replay_add_sess_name)
        num_samples = int(args.replay_add_num_samples)
        replay_buffer_expl._size = 0
        replay_buffer_expl._top = 0
        print("Loading " + str(num_samples) + " batch samples from session " +
              args.replay_add_sess_name)
        zeroes = []
        offset = 0
        for i in range(num_samples):
            act = other_params['replay_buffer_expl/actions'][i]
            obs = other_params['replay_buffer_expl/observations'][i]
            if (act.min() == 0.0 and act.max() == 0.0
                    and obs.min() == 0.0 and obs.max() == 0.0):
                zeroes.append(i)
                continue

            replay_buffer_expl._actions[offset] = copy.deepcopy(act.tolist())
            replay_buffer_expl._next_obs[offset] = copy.deepcopy(
                other_params['replay_buffer_expl/next_obs'][i].tolist())
            replay_buffer_expl._observations[offset] = copy.deepcopy(
                obs.tolist())
            replay_buffer_expl._rewards[offset] = copy.deepcopy(
                other_params['replay_buffer_expl/rewards'][i].tolist())
            replay_buffer_expl._terminals[offset] = copy.deepcopy(
                other_params['replay_buffer_expl/terminals'][i].tolist())
            replay_buffer_expl._size += 1
            replay_buffer_expl._top += 1
            offset += 1

        print(
            "Detected and ignored " + str(len(zeroes)) +
            " zero samples in replay buffer. Total num samples loaded into replay buffer: "
            + str(replay_buffer_expl._size))
        other_params = {}

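    # When resuming, offset the trainer's step counter so it continues from where the previous run stopped.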
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs'],
        starting_train_steps=0 if not resumed else
        (params['replay_buffer_expl/top'] *
         variant['algorithm_kwargs']['num_trains_per_train_loop']),
    )

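    # Batch RL algorithm that alternates exploration, evaluation, and training from the replay buffers.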
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer_eval=replay_buffer_eval,
        replay_buffer_expl=replay_buffer_expl,
        **variant['algorithm_kwargs'])

    algorithm.to(ptu.device)
    algorithm.train()

Example #2
def experiment(args, variant):
    core_env = env.DeepBuilderEnv(args.session_name, args.act_dim, args.box_dim, args.max_num_boxes, args.height_field_dim)
    eval_env = stuff.NormalizedActions(core_env)
    expl_env = stuff.NormalizedActions(core_env)
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    M = variant['layer_size']

    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )

    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )

    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )

    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=[M, M],
    )

    eval_policy = MakeDeterministic(policy)
    
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )

    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )
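
    # Exploration replay buffer with a hard-coded capacity; the evaluation buffer is kept small.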
    replay_buffer = EnvReplayBuffer(
        222726+21,
        expl_env,
    )

    replay_buffer_eval = EnvReplayBuffer(21, eval_env)

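    # Optionally preload the exploration buffer from up to two earlier sessions, skipping all-zero (empty) samples.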
    if args.replay_add_sess_name_1 != '':
        _, other_params = doc.load_rklit_file(args.replay_add_sess_name_1)
        num_samples = int(args.replay_add_num_samples_1)
        replay_buffer._size = 0
        replay_buffer._top = 0
        offset = 0
        print("Loading "+str(num_samples)+" batch samples from session " + args.replay_add_sess_name_1)
        for i in range(num_samples):
            act = other_params['replay_buffer_expl/actions'][i]
            obs = other_params['replay_buffer_expl/observations'][i]
            if not (act.min() == 0.0 and act.max() == 0.0
                    and obs.min() == 0.0 and obs.max() == 0.0):
                # Write at `offset` (not `i`) so samples stay contiguous when zero entries are skipped.
                replay_buffer._actions[offset] = act
                replay_buffer._next_obs[offset] = other_params['replay_buffer_expl/next_obs'][i]
                replay_buffer._observations[offset] = obs
                replay_buffer._rewards[offset] = other_params['replay_buffer_expl/rewards'][i]
                replay_buffer._terminals[offset] = other_params['replay_buffer_expl/terminals'][i]
                replay_buffer._size += 1
                replay_buffer._top += 1
                offset += 1

        if args.replay_add_sess_name_2 != '':
            _, other_params = doc.load_rklit_file(args.replay_add_sess_name_2)
            num_samples = int(args.replay_add_num_samples_2)
            print("Loading "+str(num_samples)+" batch samples from session " + args.replay_add_sess_name_2)
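            # This loop starts at index 21021, so the first 21021 samples of that session are skipped.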
            for i in range(21021, num_samples):      
                act = other_params['replay_buffer_expl/actions'][i]
                obs = other_params['replay_buffer_expl/observations'][i]   
                if not (act.min() == 0.0 and act.max() == 0.0
                        and obs.min() == 0.0 and obs.max() == 0.0):
                    replay_buffer._actions[offset] = act
                    replay_buffer._next_obs[offset] = other_params['replay_buffer_expl/next_obs'][i]
                    replay_buffer._observations[offset] = obs
                    replay_buffer._rewards[offset] = other_params['replay_buffer_expl/rewards'][i]
                    replay_buffer._terminals[offset] = other_params['replay_buffer_expl/terminals'][i]
                    replay_buffer._size += 1
                    replay_buffer._top += 1
                    offset += 1

        '''
        if args.replay_add_sess_name_3 != args.replay_add_sess_name_2:
            #_, other_params = doc.load_rklit_file(args.replay_add_sess_name_3)
            num_samples = int(args.replay_add_num_samples_3)
            print("Loading "+str(num_samples)+" batch samples from session " + args.replay_add_sess_name_3)
            for i in range(num_samples):     
                act = other_params['replay_buffer_eval/actions'][i]
                obs = other_params['replay_buffer_eval/observations'][i]   
                if not (act.min()== 0.0 and act.max() == 0.0 and obs.min() == 0.0 and obs.max() == 0.0):          
                    replay_buffer._actions[offset] = act
                    replay_buffer._next_obs[offset] = other_params['replay_buffer_eval/next_obs'][i]
                    replay_buffer._observations[offset] = obs
                    replay_buffer._rewards[offset] = other_params['replay_buffer_eval/rewards'][i]
                    replay_buffer._terminals[offset] = other_params['replay_buffer_eval/terminals'][i]
                    replay_buffer._size += 1
                    replay_buffer._top += 1
                    offset+=1
        '''
        del other_params

        print("Detected and removed "+str(replay_buffer._max_replay_buffer_size - replay_buffer._size)+" zero samples. Final size of replay buffer: " + str(replay_buffer._size))


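    # Standard SAC trainer over the twin Q-networks and tanh-Gaussian policy, followed by the batch RL loop.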
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs']
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer_expl=replay_buffer,
        replay_buffer_eval=replay_buffer_eval,
        **variant['algorithm_kwargs']
    )

    algorithm.to(ptu.device)
    algorithm.train()