Example #1
def evaluate(args):
    env = gym.make(args.env)
    env_params = get_env_params(env, args)
    env.close()

    agent = PPOAgent(args, env_params)
    agent.load_model(load_model_remark=args.load_model_remark)

    parent_conn, child_conn = Pipe()
    worker = AtariEnvironment(args.env,
                              1,
                              child_conn,
                              is_render=True,
                              max_episode_step=args.max_episode_step)
    worker.start()

    for i_episode in range(100):
        obs = worker.reset()
        while True:
            obs = np.expand_dims(obs, axis=0)
            action = agent.choose_action(obs / 255)

            parent_conn.send(action[0])
            obs_, r, done, info = parent_conn.recv()

            obs = obs_

            if done:
                break
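All of these examples follow the same parent/child Pipe pattern: the training process keeps parent_conn, a worker process owns the actual gym environment and keeps child_conn, the parent sends an action, and the worker steps the environment and sends the transition back. The AtariEnvironment / EnvWorker classes are project-specific and not reproduced on this page; the sketch below is a minimal, hypothetical worker illustrating the protocol (the EnvWorker name, the 4-tuple payload, and the automatic reset on done are assumptions, not the exact implementation used above).

import gym
from multiprocessing import Pipe, Process


class EnvWorker(Process):
    """Hypothetical worker: owns one gym env and serves step requests over a Pipe."""

    def __init__(self, env_id, child_conn):
        super().__init__()
        self.env_id = env_id
        self.child_conn = child_conn

    def run(self):
        env = gym.make(self.env_id)
        obs = env.reset()
        self.child_conn.send(obs)              # hand the initial observation to the parent
        while True:
            action = self.child_conn.recv()    # block until the parent sends an action
            if action is None:                 # convention: None asks the worker to shut down
                break
            obs, reward, done, info = env.step(action)
            if done:
                obs = env.reset()              # transparently start a new episode
            self.child_conn.send((obs, reward, done, info))
        env.close()


# Parent-side usage of this sketch:
# parent_conn, child_conn = Pipe()
# worker = EnvWorker('PongNoFrameskip-v4', child_conn)
# worker.start()
# first_obs = parent_conn.recv()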
Example #2
def main():
    args = get_args()
    device = torch.device('cuda' if args.cuda else 'cpu')

    env = gym.make(args.env_name)

    input_size = env.observation_space.shape  # 4
    output_size = env.action_space.n  # 2

    if 'Breakout' in args.env_name:
        output_size -= 1

    env.close()

    is_render = True
    model_path = os.path.join(args.save_dir, args.env_name + '.model')
    if not os.path.exists(model_path):
        print("Model file not found")
        return
    num_worker = 1
    sticky_action = False

    model = CnnActorCriticNetwork(input_size, output_size, args.use_noisy_net)
    model = model.to(device)

    if args.cuda:
        model.load_state_dict(torch.load(model_path))
    else:
        model.load_state_dict(torch.load(model_path, map_location='cpu'))

    parent_conn, child_conn = Pipe()
    work = AtariEnvironment(
        args.env_name,
        is_render,
        0,
        child_conn,
        sticky_action=sticky_action,
        p=args.sticky_action_prob,
        max_episode_steps=args.max_episode_steps)
    work.start()

    # states = np.zeros([num_worker, 4, 84, 84])
    states = torch.zeros(num_worker, 4, 84, 84)

    while True:
        actions = get_action(model, device, torch.div(states, 255.))

        parent_conn.send(actions)

        next_states = []
        next_state, reward, done, real_done, log_reward = parent_conn.recv()
        next_states.append(next_state)
        states = torch.from_numpy(np.stack(next_states))
        states = states.type(torch.FloatTensor)
Example #3
    def play(self):
        parent, child = Pipe()
        if flag.ENV == "MR":
            env = montezuma_revenge_env.MontezumaRevenge(0, child, 1, 0, 18000)
        env.start()
        self.current_observation = np.zeros((4, 84, 84))

        while True:
            observation_tensor = torch.from_numpy(
                np.expand_dims(self.current_observation, 0)).float().to(
                self.device)

            predicted_action, value1, value2 = self.model.step(
                observation_tensor / 255)
            parent.send(predicted_action[0])
            self.current_observation, rew, done = parent.recv()
Example #4
class SC2Environment(environment.Environment):
    def __init__(self, env_args):
        super(SC2Environment, self).__init__()
        env = partial(make_sc2env, **env_args)
        self.conn, child_conn = Pipe()
        self.proc = Process(target=worker,
                            args=(child_conn, CloudpickleWrapper(env)))
        self.proc.start()
        self.reset()

    @staticmethod
    def get_action_size():
        return len(FUNCTIONS)

    def reset(self):
        self.conn.send([COMMAND_RESET, None])
        return [self.conn.recv()]

    def close(self):
        self.conn.send([COMMAND_TERMINATE, None])
        self.conn.close()
        self.proc.join()
        print("SC2 environment closed")

    def step(self, actions):
        self.conn.send([COMMAND_STEP, actions])
        obs = self.conn.recv()
        return [obs], obs.reward, obs.last()
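SC2Environment delegates all real work to a separate process that runs a worker function and interprets [COMMAND, payload] messages. That worker is not shown in this snippet; the following is a minimal sketch under the assumptions that the COMMAND_* constants are plain integers, that CloudpickleWrapper exposes the wrapped constructor as its x attribute (as in the OpenAI Baselines helper of the same name), and that the pysc2 environment returns a list with one TimeStep per agent.

COMMAND_RESET, COMMAND_STEP, COMMAND_TERMINATE = 0, 1, 2   # assumed encoding


def worker(conn, env_fn_wrapper):
    """Hypothetical child loop: build the SC2 env and serve RESET/STEP/TERMINATE commands."""
    env = env_fn_wrapper.x()                  # unwrap and call the pickled env constructor
    try:
        while True:
            command, payload = conn.recv()
            if command == COMMAND_RESET:
                conn.send(env.reset()[0])     # single-agent setup: send the first TimeStep
            elif command == COMMAND_STEP:
                conn.send(env.step(payload)[0])
            elif command == COMMAND_TERMINATE:
                break
    finally:
        env.close()
        conn.close()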
Example #5
class DummyServer(INeuralNetworkAPI, IFlightControl):
    def __init__(self, **kwargs):
        self.handler_conn, server_conn = Pipe()
        self.handler = HandlerProcess(server_conn=server_conn, **kwargs)
        self.handler.start()

    def forward(self, batch: TikTensor) -> None:
        pass

        # self.handler_conn.send(
        #     (
        #         "forward",
        #         {"keys": [a.id for a in batch], "data": torch.stack([torch.from_numpy(a.as_numpy()) for a in batch])},
        #     )
        # )

    def active_children(self):
        self.handler_conn.send(("active_children", {}))

    def listen(self, timeout: float = 10) -> Union[None, Tuple[str, dict]]:
        if self.handler_conn.poll(timeout=timeout):
            answer = self.handler_conn.recv()
            logger.debug("got answer: %s", answer)
            return answer
        else:
            return None

    def shutdown(self):
        self.handler_conn.send(SHUTDOWN)
        got_shutdown_answer = False
        while self.handler.is_alive():
            if self.handler_conn.poll(timeout=2):
                answer = self.handler_conn.recv()
                if answer == SHUTDOWN_ANSWER:
                    got_shutdown_answer = True

        assert got_shutdown_answer
Example #6
def main(run_id=0, checkpoint=None, save_interval=1000):
    print({section: dict(config[section]) for section in config.sections()})

    train_method = default_config['TrainMethod']

    # Create environment
    env_id = default_config['EnvID']
    env_type = default_config['EnvType']

    if env_type == 'mario':
        print('Mario environment not fully implemented - thomaseh')
        raise NotImplementedError
        env = BinarySpaceToDiscreteSpaceEnv(
            gym_super_mario_bros.make(env_id), COMPLEX_MOVEMENT)
    elif env_type == 'atari':
        env = gym.make(env_id)
    else:
        raise NotImplementedError
    input_size = env.observation_space.shape  # 4
    output_size = env.action_space.n  # 2

    if 'Breakout' in env_id:
        output_size -= 1

    env.close()

    # Load configuration parameters
    is_load_model = checkpoint is not None
    is_render = False
    model_path = 'models/{}_{}_run{}_model'.format(env_id, train_method, run_id)
    if train_method == 'RND':
        predictor_path = 'models/{}_{}_run{}_pred'.format(env_id, train_method, run_id)
        target_path = 'models/{}_{}_run{}_target'.format(env_id, train_method, run_id)
    elif train_method == 'generative':
        predictor_path = 'models/{}_{}_run{}_vae'.format(env_id, train_method, run_id)
   

    writer = SummaryWriter(comment='_{}_{}_run{}'.format(env_id, train_method, run_id))

    use_cuda = default_config.getboolean('UseGPU')
    use_gae = default_config.getboolean('UseGAE')
    use_noisy_net = default_config.getboolean('UseNoisyNet')

    lam = float(default_config['Lambda'])
    num_worker = int(default_config['NumEnv'])

    num_step = int(default_config['NumStep'])
    num_rollouts = int(default_config['NumRollouts'])
    num_pretrain_rollouts = int(default_config['NumPretrainRollouts'])

    ppo_eps = float(default_config['PPOEps'])
    epoch = int(default_config['Epoch'])
    mini_batch = int(default_config['MiniBatch'])
    batch_size = int(num_step * num_worker / mini_batch)
    learning_rate = float(default_config['LearningRate'])
    entropy_coef = float(default_config['Entropy'])
    gamma = float(default_config['Gamma'])
    int_gamma = float(default_config['IntGamma'])
    clip_grad_norm = float(default_config['ClipGradNorm'])
    ext_coef = float(default_config['ExtCoef'])
    int_coef = float(default_config['IntCoef'])

    sticky_action = default_config.getboolean('StickyAction')
    action_prob = float(default_config['ActionProb'])
    life_done = default_config.getboolean('LifeDone')

    reward_rms = RunningMeanStd()
    obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))
    pre_obs_norm_step = int(default_config['ObsNormStep'])
    discounted_reward = RewardForwardFilter(int_gamma)

    if train_method == 'RND':
        agent = RNDAgent
    elif train_method == 'generative':
        agent = GenerativeAgent
    else:
        raise NotImplementedError

    if default_config['EnvType'] == 'atari':
        env_type = AtariEnvironment
    elif default_config['EnvType'] == 'mario':
        env_type = MarioEnvironment
    else:
        raise NotImplementedError

    # Initialize agent
    agent = agent(
        input_size,
        output_size,
        num_worker,
        num_step,
        gamma,
        lam=lam,
        learning_rate=learning_rate,
        ent_coef=entropy_coef,
        clip_grad_norm=clip_grad_norm,
        epoch=epoch,
        batch_size=batch_size,
        ppo_eps=ppo_eps,
        use_cuda=use_cuda,
        use_gae=use_gae,
        use_noisy_net=use_noisy_net
    )

    # Load pre-existing model
    if is_load_model:
        print('load model...')
        if use_cuda:
            agent.model.load_state_dict(torch.load(model_path))
            if train_method == 'RND':
                agent.rnd.predictor.load_state_dict(torch.load(predictor_path))
                agent.rnd.target.load_state_dict(torch.load(target_path))
            elif train_method == 'generative':
                agent.vae.load_state_dict(torch.load(predictor_path))
        else:
            agent.model.load_state_dict(
                torch.load(model_path, map_location='cpu'))
            if train_method == 'RND':
                agent.rnd.predictor.load_state_dict(
                    torch.load(predictor_path, map_location='cpu'))
                agent.rnd.target.load_state_dict(
                    torch.load(target_path, map_location='cpu'))
            elif train_method == 'generative':
                agent.vae.load_state_dict(torch.load(predictor_path, map_location='cpu'))
        print('load finished!')

    # Create workers to run in environments
    works = []
    parent_conns = []
    child_conns = []
    for idx in range(num_worker):
        parent_conn, child_conn = Pipe()
        work = env_type(
            env_id, is_render, idx, child_conn, sticky_action=sticky_action,
            p=action_prob, life_done=life_done,
        )
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    states = np.zeros([num_worker, 4, 84, 84], dtype='float32')

    sample_episode = 0
    sample_rall = 0
    sample_step = 0
    sample_env_idx = 0
    sample_i_rall = 0
    global_update = 0
    global_step = 0

    # Initialize observation normalizers
    print('Initializing observation normalization parameters...')
    next_obs = np.zeros([num_worker * num_step, 1, 84, 84])
    for step in range(num_step * pre_obs_norm_step):
        actions = np.random.randint(0, output_size, size=(num_worker,))

        for parent_conn, action in zip(parent_conns, actions):
            parent_conn.send(action)

        for idx, parent_conn in enumerate(parent_conns):
            s, r, d, rd, lr, _ = parent_conn.recv()
            next_obs[(step % num_step) * num_worker + idx, 0, :, :] = s[3, :, :]

        if (step % num_step) == num_step - 1:
            next_obs = np.stack(next_obs)
            obs_rms.update(next_obs)
            next_obs = np.zeros([num_worker * num_step, 1, 84, 84])
    print('Finished initializing observation normalization.')

    # Initialize stats dict
    stats = {
        'total_reward': [],
        'ep_length': [],
        'num_updates': [],
        'frames_seen': [],
    }

    # Main training loop
    while True:
        total_state = np.zeros([num_worker * num_step, 4, 84, 84], dtype='float32')
        total_next_obs = np.zeros([num_worker * num_step, 1, 84, 84])
        total_reward, total_done, total_next_state, total_action, \
            total_int_reward, total_ext_values, total_int_values, total_policy, \
            total_policy_np = [], [], [], [], [], [], [], [], []

        # Step 1. n-step rollout (collect data)
        for step in range(num_step):
            actions, value_ext, value_int, policy = agent.get_action(states/255.)

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            next_obs = np.zeros([num_worker, 1, 84, 84])
            next_states = np.zeros([num_worker, 4, 84, 84])
            rewards, dones, real_dones, log_rewards = [], [], [], []
            for idx, parent_conn in enumerate(parent_conns):
                s, r, d, rd, lr, stat = parent_conn.recv()
                next_states[idx] = s
                rewards.append(r)
                dones.append(d)
                real_dones.append(rd)
                log_rewards.append(lr)
                next_obs[idx, 0] = s[3, :, :]
                total_next_obs[idx * num_step + step, 0] = s[3, :, :]

                if rd:
                    stats['total_reward'].append(stat[0])
                    stats['ep_length'].append(stat[1])
                    stats['num_updates'].append(global_update)
                    stats['frames_seen'].append(global_step)

            rewards = np.hstack(rewards)
            dones = np.hstack(dones)
            real_dones = np.hstack(real_dones)

            # Compute total reward = intrinsic reward + external reward
            next_obs -= obs_rms.mean
            next_obs /= np.sqrt(obs_rms.var)
            next_obs.clip(-5, 5, out=next_obs)
            intrinsic_reward = agent.compute_intrinsic_reward(next_obs)
            intrinsic_reward = np.hstack(intrinsic_reward)
            sample_i_rall += intrinsic_reward[sample_env_idx]

            for idx, state in enumerate(states):
                total_state[idx * num_step + step] = state
            total_int_reward.append(intrinsic_reward)
            total_reward.append(rewards)
            total_done.append(dones)
            total_action.append(actions)
            total_ext_values.append(value_ext)
            total_int_values.append(value_int)
            total_policy.append(policy)
            total_policy_np.append(policy.cpu().numpy())

            states = next_states[:, :, :, :]

            sample_rall += log_rewards[sample_env_idx]

            sample_step += 1
            if real_dones[sample_env_idx]:
                sample_episode += 1
                writer.add_scalar('data/reward_per_epi', sample_rall, sample_episode)
                writer.add_scalar('data/reward_per_rollout', sample_rall, global_update)
                writer.add_scalar('data/step', sample_step, sample_episode)
                sample_rall = 0
                sample_step = 0
                sample_i_rall = 0

        # calculate last next value
        _, value_ext, value_int, _ = agent.get_action(np.float32(states) / 255.)
        total_ext_values.append(value_ext)
        total_int_values.append(value_int)
        # --------------------------------------------------

        total_reward = np.stack(total_reward).transpose().clip(-1, 1)
        total_action = np.stack(total_action).transpose().reshape([-1])
        total_done = np.stack(total_done).transpose()
        total_ext_values = np.stack(total_ext_values).transpose()
        total_int_values = np.stack(total_int_values).transpose()
        total_logging_policy = np.vstack(total_policy_np)

        # Step 2. calculate intrinsic reward
        # running mean intrinsic reward
        total_int_reward = np.stack(total_int_reward).transpose()
        total_reward_per_env = np.array([discounted_reward.update(reward_per_step) for reward_per_step in
                                         total_int_reward.T])
        mean, std, count = np.mean(total_reward_per_env), np.std(total_reward_per_env), len(total_reward_per_env)
        reward_rms.update_from_moments(mean, std ** 2, count)

        # normalize intrinsic reward
        total_int_reward /= np.sqrt(reward_rms.var)
        writer.add_scalar('data/int_reward_per_epi', np.sum(total_int_reward) / num_worker, sample_episode)
        writer.add_scalar('data/int_reward_per_rollout', np.sum(total_int_reward) / num_worker, global_update)
        # -------------------------------------------------------------------------------------------

        # log the max action probability
        writer.add_scalar('data/max_prob', softmax(total_logging_policy).max(1).mean(), sample_episode)

        # Step 3. make target and advantage
        # calculate extrinsic reward targets and advantages
        ext_target, ext_adv = make_train_data(total_reward,
                                              total_done,
                                              total_ext_values,
                                              gamma,
                                              num_step,
                                              num_worker)

        # calculate intrinsic reward targets and advantages
        # non-episodic (dones are ignored)
        int_target, int_adv = make_train_data(total_int_reward,
                                              np.zeros_like(total_int_reward),
                                              total_int_values,
                                              int_gamma,
                                              num_step,
                                              num_worker)

        # add ext adv and int adv
        total_adv = int_adv * int_coef + ext_adv * ext_coef
        # -----------------------------------------------

        # Step 4. update obs normalize param
        obs_rms.update(total_next_obs)
        # -----------------------------------------------

        # Step 5. Training!
        total_state /= 255.
        total_next_obs -= obs_rms.mean
        total_next_obs /= np.sqrt(obs_rms.var)
        total_next_obs.clip(-5, 5, out=total_next_obs)

        agent.train_model(total_state, ext_target, int_target, total_action,
                          total_adv, total_next_obs, total_policy)

        global_step += (num_worker * num_step)
        global_update += 1
        if global_update % save_interval == 0:
            print('Saving model at global step={}, num rollouts={}.'.format(
                global_step, global_update))
            torch.save(agent.model.state_dict(), model_path + "_{}.pt".format(global_update))
            if train_method == 'RND':
                torch.save(agent.rnd.predictor.state_dict(), predictor_path + '_{}.pt'.format(global_update))
                torch.save(agent.rnd.target.state_dict(), target_path + '_{}.pt'.format(global_update))
            elif train_method == 'generative':
                torch.save(agent.vae.state_dict(), predictor_path + '_{}.pt'.format(global_update))

            # Save stats to pickle file
            with open('models/{}_{}_run{}_stats_{}.pkl'.format(env_id, train_method, run_id, global_update),'wb') as f:
                pickle.dump(stats, f)

        if global_update == num_rollouts + num_pretrain_rollouts:
            print('Finished Training.')
            break
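This example and Examples #7, #8, and #10 rely on two small helpers whose definitions are not included here: RunningMeanStd, used to normalize observations and intrinsic rewards, and RewardForwardFilter, which keeps a per-environment discounted sum of intrinsic rewards so their scale can be estimated. Below is a minimal sketch in the style of the OpenAI Baselines versions that RND codebases typically mirror; the exact implementations in each project may differ.

import numpy as np


class RunningMeanStd:
    """Tracks a running mean/variance with Welford-style batched updates."""

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon

    def update(self, x):
        self.update_from_moments(np.mean(x, axis=0), np.var(x, axis=0), x.shape[0])

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
        self.mean, self.var, self.count = new_mean, m_2 / tot_count, tot_count


class RewardForwardFilter:
    """Keeps a per-environment discounted sum of intrinsic rewards (used to estimate their scale)."""

    def __init__(self, gamma):
        self.rewems = None
        self.gamma = gamma

    def update(self, rews):
        if self.rewems is None:
            self.rewems = rews
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems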
Example #7
def main():

    args = parse_arguments()

    train_method = args.train_method
    env_id = args.env_id
    env_type = args.env_type

    if env_type == 'atari':
        env = gym.make(env_id)
    else:
        raise NotImplementedError

    input_size = env.observation_space.shape  # 4
    output_size = env.action_space.n  # 2

    env.close()

    is_load_model = False
    is_render = False
    os.makedirs('models', exist_ok=True)
    model_path = 'models/{}.model'.format(env_id)
    predictor_path = 'models/{}.pred'.format(env_id)
    target_path = 'models/{}.target'.format(env_id)

    results_dir = os.path.join('outputs', args.env_id)
    os.makedirs(results_dir, exist_ok=True)
    logger = Logger(results_dir)
    writer = SummaryWriter(
        os.path.join(results_dir, 'tensorboard', args.env_id))

    use_cuda = args.use_gpu
    use_gae = args.use_gae
    use_noisy_net = args.use_noisynet
    lam = args.lam
    num_worker = args.num_env
    num_step = args.num_step
    ppo_eps = args.ppo_eps
    epoch = args.epoch
    mini_batch = args.minibatch
    batch_size = int(num_step * num_worker / mini_batch)
    learning_rate = args.learning_rate
    entropy_coef = args.entropy
    gamma = args.gamma
    int_gamma = args.int_gamma
    clip_grad_norm = args.clip_grad_norm
    ext_coef = args.ext_coef
    int_coef = args.int_coef
    sticky_action = args.sticky_action
    action_prob = args.action_prob
    life_done = args.life_done
    pre_obs_norm_step = args.obs_norm_step

    reward_rms = RunningMeanStd()
    obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))
    discounted_reward = RewardForwardFilter(int_gamma)

    agent = RNDAgent

    if args.env_type == 'atari':
        env_type = AtariEnvironment
    else:
        raise NotImplementedError

    agent = agent(input_size,
                  output_size,
                  num_worker,
                  num_step,
                  gamma,
                  lam=lam,
                  learning_rate=learning_rate,
                  ent_coef=entropy_coef,
                  clip_grad_norm=clip_grad_norm,
                  epoch=epoch,
                  batch_size=batch_size,
                  ppo_eps=ppo_eps,
                  use_cuda=use_cuda,
                  use_gae=use_gae,
                  use_noisy_net=use_noisy_net)

    logger.info('Initializing workers...')
    works = []
    parent_conns = []
    child_conns = []
    for idx in range(num_worker):
        parent_conn, child_conn = Pipe()
        work = env_type(env_id,
                        is_render,
                        idx,
                        child_conn,
                        sticky_action=sticky_action,
                        p=action_prob,
                        life_done=life_done)
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    states = np.zeros([num_worker, 4, 84, 84])

    sample_episode = 0
    sample_rall = 0
    sample_step = 0
    sample_env_idx = 0
    sample_i_rall = 0
    global_update = 0
    global_step = 0

    # normalize obs
    logger.info('Initializing observation normalization parameters...')
    next_obs = []
    for step in range(num_step * pre_obs_norm_step):
        actions = np.random.randint(0, output_size, size=(num_worker, ))

        for parent_conn, action in zip(parent_conns, actions):
            parent_conn.send(action)

        for parent_conn in parent_conns:
            s, r, d, rd, lr, nr = parent_conn.recv()
            next_obs.append(s[3, :, :].reshape([1, 84, 84]))

        if len(next_obs) % (num_step * num_worker) == 0:
            next_obs = np.stack(next_obs)
            obs_rms.update(next_obs)
            next_obs = []
    logger.info('Finished initializing observation normalization.')

    while True:
        logger.info('Iteration: {}'.format(global_update))
        #####################################################################################################
        total_state, total_reward, total_done, total_next_state, \
            total_action, total_int_reward, total_next_obs, total_ext_values, \
            total_int_values, total_policy, total_policy_np, total_num_rooms = \
            [], [], [], [], [], [], [], [], [], [], [], []
        #####################################################################################################
        global_step += (num_worker * num_step)
        global_update += 1

        # Step 1. n-step rollout
        for _ in range(num_step):
            actions, value_ext, value_int, policy = agent.get_action(
                np.float32(states) / 255.)

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            #################################################################################################
            next_states, rewards, dones, real_dones, log_rewards, next_obs, num_rooms = \
                [], [], [], [], [], [], []
            #################################################################################################
            for parent_conn in parent_conns:
                s, r, d, rd, lr, nr = parent_conn.recv()
                next_states.append(s)
                rewards.append(r)
                dones.append(d)
                real_dones.append(rd)
                log_rewards.append(lr)
                #############################################################################################
                num_rooms.append(nr)
                #############################################################################################
                next_obs.append(s[3, :, :].reshape([1, 84, 84]))

            next_states = np.stack(next_states)
            rewards = np.hstack(rewards)
            dones = np.hstack(dones)
            real_dones = np.hstack(real_dones)
            next_obs = np.stack(next_obs)
            #################################################################################################
            num_rooms = np.hstack(num_rooms)
            #################################################################################################

            # total reward = int reward + ext Reward
            intrinsic_reward = agent.compute_intrinsic_reward(
                ((next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5))
            intrinsic_reward = np.hstack(intrinsic_reward)
            sample_i_rall += intrinsic_reward[sample_env_idx]

            total_next_obs.append(next_obs)
            total_int_reward.append(intrinsic_reward)
            total_state.append(states)
            total_reward.append(rewards)
            total_done.append(dones)
            total_action.append(actions)
            total_ext_values.append(value_ext)
            total_int_values.append(value_int)
            total_policy.append(policy)
            total_policy_np.append(policy.cpu().numpy())
            #####################################################################################################
            total_num_rooms.append(num_rooms)
            #####################################################################################################

            states = next_states[:, :, :, :]

            sample_rall += log_rewards[sample_env_idx]

            sample_step += 1
            if real_dones[sample_env_idx]:
                sample_episode += 1
                writer.add_scalar('data/returns_vs_frames', sample_rall,
                                  global_step)
                writer.add_scalar('data/lengths_vs_frames', sample_step,
                                  global_step)
                writer.add_scalar('data/reward_per_epi', sample_rall,
                                  sample_episode)
                writer.add_scalar('data/reward_per_rollout', sample_rall,
                                  global_update)
                writer.add_scalar('data/step', sample_step, sample_episode)
                sample_rall = 0
                sample_step = 0
                sample_i_rall = 0

        # calculate last next value
        _, value_ext, value_int, _ = agent.get_action(
            np.float32(states) / 255.)
        total_ext_values.append(value_ext)
        total_int_values.append(value_int)
        # --------------------------------------------------

        total_state = np.stack(total_state).transpose([1, 0, 2, 3, 4]).reshape(
            [-1, 4, 84, 84])
        total_reward = np.stack(total_reward).transpose().clip(-1, 1)
        total_action = np.stack(total_action).transpose().reshape([-1])
        total_done = np.stack(total_done).transpose()
        total_next_obs = np.stack(total_next_obs).transpose(
            [1, 0, 2, 3, 4]).reshape([-1, 1, 84, 84])
        total_ext_values = np.stack(total_ext_values).transpose()
        total_int_values = np.stack(total_int_values).transpose()
        total_logging_policy = np.vstack(total_policy_np)
        #####################################################################################################
        total_num_rooms = np.stack(total_num_rooms).transpose().reshape(-1)
        total_done_cal = total_done.reshape(-1)
        if np.any(total_done_cal):
            avg_num_rooms = np.mean(total_num_rooms[total_done_cal])
        else:
            avg_num_rooms = 0
        #####################################################################################################

        # Step 2. calculate intrinsic reward
        # running mean intrinsic reward
        total_int_reward = np.stack(total_int_reward).transpose()
        total_reward_per_env = np.array([
            discounted_reward.update(reward_per_step)
            for reward_per_step in total_int_reward.T
        ])
        mean, std, count = np.mean(total_reward_per_env), np.std(
            total_reward_per_env), len(total_reward_per_env)
        reward_rms.update_from_moments(mean, std**2, count)

        # normalize intrinsic reward
        total_int_reward /= np.sqrt(reward_rms.var)
        writer.add_scalar('data/int_reward_per_epi',
                          np.sum(total_int_reward) / num_worker,
                          sample_episode)
        writer.add_scalar('data/int_reward_per_rollout',
                          np.sum(total_int_reward) / num_worker, global_update)
        #####################################################################################################
        writer.add_scalar('data/avg_num_rooms_per_iteration', avg_num_rooms,
                          global_update)
        writer.add_scalar('data/avg_num_rooms_per_step', avg_num_rooms,
                          global_step)
        #####################################################################################################
        # -------------------------------------------------------------------------------------------

        # log the max action probability
        writer.add_scalar('data/max_prob',
                          softmax(total_logging_policy).max(1).mean(),
                          sample_episode)

        # Step 3. make target and advantage
        # calculate extrinsic reward targets and advantages
        ext_target, ext_adv = make_train_data(total_reward, total_done,
                                              total_ext_values, gamma,
                                              num_step, num_worker)

        # calculate intrinsic reward targets and advantages
        # non-episodic (dones are ignored)
        int_target, int_adv = make_train_data(total_int_reward,
                                              np.zeros_like(total_int_reward),
                                              total_int_values, int_gamma,
                                              num_step, num_worker)

        # add ext adv and int adv
        total_adv = int_adv * int_coef + ext_adv * ext_coef
        # -----------------------------------------------

        # Step 4. update obs normalize param
        obs_rms.update(total_next_obs)
        # -----------------------------------------------

        # Step 5. Training!
        agent.train_model(
            np.float32(total_state) / 255., ext_target, int_target,
            total_action, total_adv,
            ((total_next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(
                -5, 5), total_policy)

        if global_update % 1000 == 0:
            torch.save(agent.model.state_dict(),
                       'models/{}-{}.model'.format(env_id, global_update))
            logger.info('Global step: {}'.format(global_step))
            torch.save(agent.model.state_dict(), model_path)
            torch.save(agent.rnd.predictor.state_dict(), predictor_path)
            torch.save(agent.rnd.target.state_dict(), target_path)
Example #8
def main():
    print({section: dict(config[section]) for section in config.sections()})
    train_method = default_config['TrainMethod']
    env_id = default_config['EnvID']
    env_type = default_config['EnvType']

    if env_type == 'mario':
        env = BinarySpaceToDiscreteSpaceEnv(gym_super_mario_bros.make(env_id),
                                            COMPLEX_MOVEMENT)
    elif env_type == 'atari':
        env = gym.make(env_id)
    else:
        raise NotImplementedError
    input_size = env.observation_space.shape  # 4
    output_size = env.action_space.n  # 2

    if 'Breakout' in env_id:
        output_size -= 1

    env.close()

    is_load_model = False
    is_render = False
    model_path = 'models/{}.model'.format(env_id)
    icm_path = 'models/{}.icm'.format(env_id)

    writer = SummaryWriter()

    use_cuda = default_config.getboolean('UseGPU')
    use_gae = default_config.getboolean('UseGAE')
    use_noisy_net = default_config.getboolean('UseNoisyNet')

    lam = float(default_config['Lambda'])
    num_worker = 32

    num_step = 128

    ppo_eps = float(default_config['PPOEps'])
    epoch = int(default_config['Epoch'])
    mini_batch = int(default_config['MiniBatch'])
    batch_size = 256
    learning_rate = float(default_config['LearningRate'])
    entropy_coef = float(default_config['Entropy'])
    gamma = float(default_config['Gamma'])
    eta = float(default_config['ETA'])
    clip_grad_norm = float(default_config['ClipGradNorm'])

    reward_rms = RunningMeanStd()
    obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))

    pre_obs_norm_step = int(default_config['ObsNormStep'])
    discounted_reward = RewardForwardFilter(gamma)

    agent = ICMAgent

    if default_config['EnvType'] == 'atari':
        env_type = AtariEnvironment
    elif default_config['EnvType'] == 'mario':
        env_type = MarioEnvironment
    else:
        raise NotImplementedError

    agent = agent(input_size,
                  output_size,
                  num_worker,
                  num_step,
                  gamma,
                  lam=lam,
                  learning_rate=learning_rate,
                  ent_coef=entropy_coef,
                  clip_grad_norm=clip_grad_norm,
                  epoch=epoch,
                  batch_size=batch_size,
                  ppo_eps=ppo_eps,
                  eta=eta,
                  use_cuda=use_cuda,
                  use_gae=use_gae,
                  use_noisy_net=use_noisy_net)

    if is_load_model:
        if use_cuda:
            agent.model.load_state_dict(torch.load(model_path))
        else:
            agent.model.load_state_dict(
                torch.load(model_path, map_location='cpu'))

    works = []
    parent_conns = []
    child_conns = []
    for idx in range(num_worker):
        parent_conn, child_conn = Pipe()
        work = env_type(env_id, is_render, idx, child_conn)
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    states = np.zeros([num_worker, 4, 84, 84])

    sample_episode = 0
    sample_rall = 0
    sample_step = 0
    sample_env_idx = 0
    sample_i_rall = 0
    global_update = 0
    global_step = 0

    # normalize obs
    print('Initializing observation normalization parameters...')
    next_obs = []
    steps = 0
    while steps < pre_obs_norm_step:
        steps += num_worker
        actions = np.random.randint(0, output_size, size=(num_worker, ))

        for parent_conn, action in zip(parent_conns, actions):
            parent_conn.send(action)

        for parent_conn in parent_conns:
            s, r, d, rd, lr = parent_conn.recv()
            next_obs.append(s[3, :, :].reshape([1, 84, 84]))

    next_obs = np.stack(next_obs)
    obs_rms.update(next_obs)
    print('Finished initializing observation normalization.')

    while True:
        total_state, total_reward, total_done, total_next_state, total_action, total_int_reward, total_next_obs, total_values, total_policy = \
            [], [], [], [], [], [], [], [], []
        global_step += (num_worker * num_step)
        global_update += 1

        # Step 1. n-step rollout
        for _ in range(num_step):
            actions, value, policy = agent.get_action(
                (np.float32(states) - obs_rms.mean) / np.sqrt(obs_rms.var))

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            next_states, rewards, dones, real_dones, log_rewards, next_obs = [], [], [], [], [], []
            for parent_conn in parent_conns:
                s, r, d, rd, lr = parent_conn.recv()
                next_states.append(s)
                rewards.append(r)
                dones.append(d)
                real_dones.append(rd)
                log_rewards.append(lr)

            next_states = np.stack(next_states)
            rewards = np.hstack(rewards)
            dones = np.hstack(dones)
            real_dones = np.hstack(real_dones)

            # total reward = int reward
            intrinsic_reward = agent.compute_intrinsic_reward(
                (states - obs_rms.mean) / np.sqrt(obs_rms.var),
                (next_states - obs_rms.mean) / np.sqrt(obs_rms.var), actions)
            intrinsic_reward = np.hstack(intrinsic_reward)
            sample_i_rall += intrinsic_reward[sample_env_idx]

            total_int_reward.append(intrinsic_reward)
            total_state.append(states)
            total_next_state.append(next_states)
            total_reward.append(rewards)
            total_done.append(dones)
            total_action.append(actions)
            total_values.append(value)
            total_policy.append(policy)

            states = next_states[:, :, :, :]

            sample_rall += log_rewards[sample_env_idx]

            sample_step += 1
            if real_dones[sample_env_idx]:
                sample_episode += 1
                writer.add_scalar('data/reward_per_epi', sample_rall,
                                  sample_episode)
                writer.add_scalar('data/reward_per_rollout', sample_rall,
                                  global_update)
                writer.add_scalar('data/step', sample_step, sample_episode)
                sample_rall = 0
                sample_step = 0
                sample_i_rall = 0

        # calculate last next value
        _, value, _ = agent.get_action(
            (np.float32(states) - obs_rms.mean) / np.sqrt(obs_rms.var))
        total_values.append(value)
        # --------------------------------------------------

        total_state = np.stack(total_state).transpose([1, 0, 2, 3, 4]).reshape(
            [-1, 4, 84, 84])
        total_next_state = np.stack(total_next_state).transpose(
            [1, 0, 2, 3, 4]).reshape([-1, 4, 84, 84])
        total_action = np.stack(total_action).transpose().reshape([-1])
        total_reward = np.stack(total_reward).transpose()
        total_done = np.stack(total_done).transpose()
        total_values = np.stack(total_values).transpose()
        total_logging_policy = np.vstack(total_policy)

        # Step 2. calculate intrinsic reward
        # running mean intrinsic reward
        total_int_reward = np.stack(total_int_reward).transpose()
        total_reward_per_env = np.array([
            discounted_reward.update(reward_per_step)
            for reward_per_step in total_int_reward.T
        ])
        mean, std, count = np.mean(total_reward_per_env), np.std(
            total_reward_per_env), len(total_reward_per_env)
        reward_rms.update_from_moments(mean, std**2, count)

        # normalize intrinsic reward
        total_int_reward /= np.sqrt(reward_rms.var)
        writer.add_scalar('data/int_reward_per_epi',
                          np.sum(total_int_reward) / num_worker,
                          sample_episode)
        writer.add_scalar('data/int_reward_per_rollout',
                          np.sum(total_int_reward) / num_worker, global_update)
        # -------------------------------------------------------------------------------------------

        # log the max action probability
        writer.add_scalar('data/max_prob',
                          softmax(total_logging_policy).max(1).mean(),
                          sample_episode)

        # Step 3. make target and advantage
        target, adv = make_train_data_icm(total_int_reward,
                                          np.zeros_like(total_int_reward),
                                          total_values, gamma, num_step,
                                          num_worker)

        adv = (adv - np.mean(adv)) / (np.std(adv) + 1e-8)
        # -----------------------------------------------

        # Step 5. Training!
        print('training')
        agent.train_model(
            (np.float32(total_state) - obs_rms.mean) / np.sqrt(obs_rms.var),
            (np.float32(total_next_state) - obs_rms.mean) /
            np.sqrt(obs_rms.var), target, total_action, adv, total_policy)

        if global_step % (num_worker * num_step * 100) == 0:
            print('Global step: {}'.format(global_step))
            torch.save(agent.model.state_dict(), model_path)
            torch.save(agent.icm.state_dict(), icm_path)
Example #9
def main():
    env = gym.make(args.env_name)
    env.seed(500)
    torch.manual_seed(500)

    img_shape = env.observation_space.shape
    num_actions = env.action_space.n - 1
    print('image size:', img_shape)
    print('action size:', num_actions)

    net = FuN(num_actions, args, device)
    optimizer = optim.RMSprop(net.parameters(), lr=0.00025, eps=0.01)
    writer = SummaryWriter('logs')

    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)

    workers = []
    parent_conns = []
    child_conns = []

    for i in range(args.num_envs):
        parent_conn, child_conn = Pipe()
        worker = EnvWorker(args.env_name, args.render, child_conn)
        worker.start()
        workers.append(worker)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    net.to(device)
    net.train()

    global_steps = 0
    score = np.zeros(args.num_envs)
    count = 0
    grad_norm = 0

    histories = torch.zeros([args.num_envs, 3, 84, 84]).to(device)

    m_hx = torch.zeros(args.num_envs, num_actions * 16).to(device)
    m_cx = torch.zeros(args.num_envs, num_actions * 16).to(device)
    m_lstm = (m_hx, m_cx)

    w_hx = torch.zeros(args.num_envs, num_actions * 16).to(device)
    w_cx = torch.zeros(args.num_envs, num_actions * 16).to(device)
    w_lstm = (w_hx, w_cx)

    goals_horizon = torch.zeros(args.num_envs, args.horizon + 1,
                                num_actions * 16).to(device)

    while True:
        count += 1
        memory = Memory()
        global_steps += (args.num_envs * args.num_step)

        # gather samples from the environment
        for i in range(args.num_step):
            # TODO: think about net output
            net_output = net(histories.to(device), m_lstm, w_lstm,
                             goals_horizon)
            policies, goal, goals_horizon, m_lstm, w_lstm, m_value, w_value_ext, w_value_int, m_state = net_output

            actions = get_action(policies, num_actions)

            # send action to each worker environment and get state information
            next_histories, rewards, masks, dones = [], [], [], []

            for i, (parent_conn,
                    action) in enumerate(zip(parent_conns, actions)):
                parent_conn.send(action)
                next_history, reward, dead, done = parent_conn.recv()
                next_histories.append(next_history)
                rewards.append(reward)
                masks.append(1 - dead)
                dones.append(done)

                if dead:
                    m_hx_mask = torch.ones(args.num_envs,
                                           num_actions * 16).to(device)
                    m_hx_mask[i, :] = m_hx_mask[i, :] * 0
                    m_cx_mask = torch.ones(args.num_envs,
                                           num_actions * 16).to(device)
                    m_cx_mask[i, :] = m_cx_mask[i, :] * 0
                    m_hx, m_cx = m_lstm
                    m_hx = m_hx * m_hx_mask
                    m_cx = m_cx * m_cx_mask
                    m_lstm = (m_hx, m_cx)

                    w_hx_mask = torch.ones(args.num_envs,
                                           num_actions * 16).to(device)
                    w_hx_mask[i, :] = w_hx_mask[i, :] * 0
                    w_cx_mask = torch.ones(args.num_envs,
                                           num_actions * 16).to(device)
                    w_cx_mask[i, :] = w_cx_mask[i, :] * 0
                    w_hx, w_cx = w_lstm
                    w_hx = w_hx * w_hx_mask
                    w_cx = w_cx * w_cx_mask
                    w_lstm = (w_hx, w_cx)

                    goal_init = torch.zeros(args.horizon + 1,
                                            num_actions * 16).to(device)
                    goals_horizon[i] = goal_init

            score += np.array(rewards)  # accumulate each environment's reward (rewards is still a Python list here)

            # if agent in first environment dies, print and log score
            for i in range(args.num_envs):
                if dones[i]:
                    entropy = -policies * torch.log(policies + 1e-5)
                    entropy = entropy.mean().data.cpu()
                    print(
                        'global steps {} | score: {} | entropy: {:.4f} | grad norm: {:.3f} '
                        .format(global_steps, score[i], entropy, grad_norm))
                    if i == 0:
                        writer.add_scalar('log/score', score[i], global_steps)
                    score[i] = 0

            next_histories = torch.Tensor(next_histories).to(device)
            rewards = np.hstack(rewards)
            masks = np.hstack(masks)
            memory.push(histories, next_histories, actions, rewards, masks,
                        goal, policies, m_lstm, w_lstm, m_value, w_value_ext,
                        w_value_int, m_state)
            histories = next_histories

        # Train every args.num_step
        if (global_steps % args.num_step) == 0:  # Need to fix logic
            transitions = memory.sample()
            loss, grad_norm = train_model(net, optimizer, transitions, args)
            m_hx, m_cx = m_lstm
            m_lstm = (m_hx.detach(), m_cx.detach())
            w_hx, w_cx = w_lstm
            w_lstm = (w_hx.detach(), w_cx.detach())
            goals_horizon = goals_horizon.detach()
            # avg_loss.append(loss.cpu().data)

        if count % args.save_interval == 0:
            ckpt_path = args.save_path + 'model.pt'
            torch.save(net.state_dict(), ckpt_path)
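The get_action helper used in this FuN example is not shown. Given that policies is a batch of softmax outputs with shape [num_envs, num_actions], a plausible minimal sketch (an assumption, not the author's code) simply samples one discrete action per environment:

import torch
from torch.distributions import Categorical


def get_action(policies, num_actions):
    """Hypothetical sampler: one discrete action per environment from softmax policies."""
    # policies: [num_envs, num_actions], rows sum to 1; num_actions kept only for signature parity
    dist = Categorical(probs=policies.detach().cpu())
    return dist.sample().numpy()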
Example #10
def main():
    args = get_args()
    device = torch.device('cuda' if args.cuda else 'cpu')

    env = gym.make(args.env_name)

    input_size = env.observation_space.shape  # 4
    output_size = env.action_space.n  # 2

    if 'Breakout' in args.env_name:
        output_size -= 1

    env.close()

    is_render = False
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    model_path = os.path.join(args.save_dir, args.env_name + '.model')
    predictor_path = os.path.join(args.save_dir, args.env_name + '.pred')
    target_path = os.path.join(args.save_dir, args.env_name + '.target')

    writer = SummaryWriter(log_dir=args.log_dir)

    reward_rms = RunningMeanStd()
    obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))
    discounted_reward = RewardForwardFilter(args.ext_gamma)

    model = CnnActorCriticNetwork(input_size, output_size, args.use_noisy_net)
    rnd = RNDModel(input_size, output_size)
    model = model.to(device)
    rnd = rnd.to(device)
    optimizer = optim.Adam(list(model.parameters()) +
                           list(rnd.predictor.parameters()),
                           lr=args.lr)

    if args.load_model:
        if args.cuda:
            model.load_state_dict(torch.load(model_path))
        else:
            model.load_state_dict(torch.load(model_path, map_location='cpu'))

    works = []
    parent_conns = []
    child_conns = []
    for idx in range(args.num_worker):
        parent_conn, child_conn = Pipe()
        work = AtariEnvironment(args.env_name,
                                is_render,
                                idx,
                                child_conn,
                                sticky_action=args.sticky_action,
                                p=args.sticky_action_prob,
                                max_episode_steps=args.max_episode_steps)
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    states = np.zeros([args.num_worker, 4, 84, 84])

    sample_env_index = 0  # Sample Environment index to log
    sample_episode = 0
    sample_rall = 0
    sample_step = 0
    sample_i_rall = 0
    global_update = 0
    global_step = 0

    # normalize observation
    print('Initializing observation normalization...')
    next_obs = []
    for step in range(args.num_step * args.pre_obs_norm_steps):
        actions = np.random.randint(0, output_size, size=(args.num_worker, ))

        for parent_conn, action in zip(parent_conns, actions):
            parent_conn.send(action)

        for parent_conn in parent_conns:
            next_state, reward, done, realdone, log_reward = parent_conn.recv()
            next_obs.append(next_state[3, :, :].reshape([1, 84, 84]))

        if len(next_obs) % (args.num_step * args.num_worker) == 0:
            next_obs = np.stack(next_obs)
            obs_rms.update(next_obs)
            next_obs = []

    print('Training...')
    while True:
        total_state, total_reward, total_done, total_next_state, total_action, total_int_reward, total_next_obs, total_ext_values, total_int_values, total_action_probs = [], [], [], [], [], [], [], [], [], []
        global_step += (args.num_worker * args.num_step)
        global_update += 1

        # Step 1. n-step rollout
        for _ in range(args.num_step):
            actions, value_ext, value_int, action_probs = get_action(
                model, device,
                np.float32(states) / 255.)

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            next_states, rewards, dones, real_dones, log_rewards, next_obs = [], [], [], [], [], []
            for parent_conn in parent_conns:
                next_state, reward, done, real_done, log_reward = parent_conn.recv(
                )
                next_states.append(next_state)
                rewards.append(reward)
                dones.append(done)
                real_dones.append(real_done)
                log_rewards.append(log_reward)
                next_obs.append(next_state[3, :, :].reshape([1, 84, 84]))

            next_states = np.stack(next_states)
            rewards = np.hstack(rewards)
            dones = np.hstack(dones)
            real_dones = np.hstack(real_dones)
            next_obs = np.stack(next_obs)

            # total reward = int reward + ext Reward
            intrinsic_reward = compute_intrinsic_reward(
                rnd, device,
                ((next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5))
            intrinsic_reward = np.hstack(intrinsic_reward)
            sample_i_rall += intrinsic_reward[sample_env_index]

            total_next_obs.append(next_obs)
            total_int_reward.append(intrinsic_reward)
            total_state.append(states)
            total_reward.append(rewards)
            total_done.append(dones)
            total_action.append(actions)
            total_ext_values.append(value_ext)
            total_int_values.append(value_int)
            total_action_probs.append(action_probs)

            states = next_states[:, :, :, :]

            sample_rall += log_rewards[sample_env_index]

            sample_step += 1
            if real_dones[sample_env_index]:
                sample_episode += 1
                writer.add_scalar('data/reward_per_epi', sample_rall,
                                  sample_episode)
                writer.add_scalar('data/reward_per_rollout', sample_rall,
                                  global_update)
                writer.add_scalar('data/step', sample_step, sample_episode)
                sample_rall = 0
                sample_step = 0
                sample_i_rall = 0

        # calculate last next value
        _, value_ext, value_int, _ = get_action(model, device,
                                                np.float32(states) / 255.)
        total_ext_values.append(value_ext)
        total_int_values.append(value_int)
        # --------------------------------------------------

        total_state = np.stack(total_state).transpose([1, 0, 2, 3, 4]).reshape(
            [-1, 4, 84, 84])
        total_reward = np.stack(total_reward).transpose().clip(-1, 1)
        total_action = np.stack(total_action).transpose().reshape([-1])
        total_done = np.stack(total_done).transpose()
        total_next_obs = np.stack(total_next_obs).transpose(
            [1, 0, 2, 3, 4]).reshape([-1, 1, 84, 84])
        total_ext_values = np.stack(total_ext_values).transpose()
        total_int_values = np.stack(total_int_values).transpose()
        total_logging_action_probs = np.vstack(total_action_probs)

        # Step 2. calculate intrinsic reward
        # running mean intrinsic reward
        total_int_reward = np.stack(total_int_reward).transpose()
        total_reward_per_env = np.array([
            discounted_reward.update(reward_per_step)
            for reward_per_step in total_int_reward.T
        ])
        mean, std, count = np.mean(total_reward_per_env), np.std(
            total_reward_per_env), len(total_reward_per_env)
        reward_rms.update_from_moments(mean, std**2, count)

        # normalize intrinsic reward
        total_int_reward /= np.sqrt(reward_rms.var)
        writer.add_scalar('data/int_reward_per_epi',
                          np.sum(total_int_reward) / args.num_worker,
                          sample_episode)
        writer.add_scalar('data/int_reward_per_rollout',
                          np.sum(total_int_reward) / args.num_worker,
                          global_update)
        # -------------------------------------------------------------------------------------------

        # log the max action probability
        writer.add_scalar('data/max_prob',
                          total_logging_action_probs.max(1).mean(),
                          sample_episode)

        # Step 3. make target and advantage
        # calculate extrinsic reward targets and advantages
        ext_target, ext_adv = make_train_data(total_reward, total_done,
                                              total_ext_values, args.ext_gamma,
                                              args.gae_lambda, args.num_step,
                                              args.num_worker, args.use_gae)

        # calculate intrinsic reward targets and advantages
        # non-episodic (dones are ignored)
        int_target, int_adv = make_train_data(total_int_reward,
                                              np.zeros_like(total_int_reward),
                                              total_int_values, args.int_gamma,
                                              args.gae_lambda, args.num_step,
                                              args.num_worker, args.use_gae)

        # add ext adv and int adv
        total_adv = int_adv * args.int_coef + ext_adv * args.ext_coef
        # -----------------------------------------------

        # Step 4. update obs normalize param
        obs_rms.update(total_next_obs)
        # -----------------------------------------------

        # Step 5. Training!
        train_model(args, device, output_size, model, rnd, optimizer,
                    np.float32(total_state) / 255., ext_target, int_target,
                    total_action, total_adv,
                    ((total_next_obs - obs_rms.mean) /
                     np.sqrt(obs_rms.var)).clip(-5, 5), total_action_probs)

        if global_step % (args.num_worker * args.num_step *
                          args.save_interval) == 0:
            print('Global step: {}'.format(global_step))
            torch.save(model.state_dict(), model_path)
            torch.save(rnd.predictor.state_dict(), predictor_path)
            torch.save(rnd.target.state_dict(), target_path)
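This example calls two module-level helpers, get_action and compute_intrinsic_reward, that are defined elsewhere in its repository. The sketch below shows plausible minimal versions, assuming the actor-critic model returns (action_logits, ext_value, int_value) and that the RND intrinsic reward is the squared error between the frozen target network and the trained predictor; treat both as illustrative rather than the exact code.

import torch
import torch.nn.functional as F


def get_action(model, device, state):
    """Hypothetical helper: assumes model(state) returns (action_logits, ext_value, int_value)."""
    state = torch.FloatTensor(state).to(device)
    logits, value_ext, value_int = model(state)
    probs = F.softmax(logits, dim=-1)
    actions = torch.multinomial(probs, num_samples=1).squeeze(1)
    return (actions.cpu().numpy(),
            value_ext.squeeze(-1).detach().cpu().numpy(),
            value_int.squeeze(-1).detach().cpu().numpy(),
            probs.detach().cpu().numpy())


def compute_intrinsic_reward(rnd, device, next_obs):
    """RND intrinsic reward: prediction error of the predictor against the frozen target network."""
    next_obs = torch.FloatTensor(next_obs).to(device)
    target_feature = rnd.target(next_obs)
    predict_feature = rnd.predictor(next_obs)
    return (target_feature - predict_feature).pow(2).sum(1).mul(0.5).detach().cpu().numpy()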
Example #11
def train(args):

    torch.multiprocessing.set_start_method('forkserver')

    num_envs = args.num_envs
    num_workers = args.num_workers
    total_envs = num_workers * num_envs
    game_name = args.env_name
    max_train_steps = args.max_train_steps
    n_steps = args.n_steps
    init_lr = args.lr
    gamma = args.gamma
    clip_grad_norm = args.clip_grad_norm
    num_action = gym.make(game_name).action_space.n
    image_size = 84
    n_stack = 4

    model = paac_ff(min_act=num_action).cuda()

    x = Variable(torch.zeros(total_envs, n_stack, image_size, image_size),
                 volatile=True).cuda()
    xs = [
        Variable(torch.zeros(total_envs, n_stack, image_size,
                             image_size)).cuda() for i in range(n_steps)
    ]

    share_reward = [
        Variable(torch.zeros(total_envs)).cuda() for _ in range(n_steps)
    ]
    share_mask = [
        Variable(torch.zeros(total_envs)).cuda() for _ in range(n_steps)
    ]
    constant_one = torch.ones(total_envs).cuda()

    optimizer = optim.Adam(model.parameters(), lr=init_lr)

    workers = []
    parent_conns = []
    child_conns = []
    for i in range(num_workers):
        parent_conn, child_conn = Pipe()
        w = worker(i, num_envs, game_name, n_stack, child_conn, args)
        w.start()
        workers.append(w)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    new_s = np.zeros((total_envs, n_stack, image_size, image_size))

    for global_step in range(1, max_train_steps + 1):

        cache_v_series = []
        entropies = []
        sampled_log_probs = []

        for step in range(n_steps):

            xs[step].data.copy_(torch.from_numpy(new_s))
            v, pi = model(xs[step])
            cache_v_series.append(v)

            sampling_action = pi.data.multinomial(1)

            log_pi = (pi + 1e-12).log()
            entropy = -(log_pi * pi).sum(1)
            sampled_log_prob = log_pi.gather(
                1, Variable(sampling_action)).squeeze()
            sampled_log_probs.append(sampled_log_prob)
            entropies.append(entropy)

            send_action = sampling_action.squeeze().cpu().numpy()
            send_action = np.split(send_action, num_workers)

            # send actions, then receive the next states
            for parent_conn, action in zip(parent_conns, send_action):
                parent_conn.send(action)

            batch_s, batch_r, batch_mask = [], [], []
            for parent_conn in parent_conns:
                s, r, mask = parent_conn.recv()
                batch_s.append(s)
                batch_r.append(r)
                batch_mask.append(mask)

            new_s = np.vstack(batch_s)
            r = np.hstack(batch_r).clip(-1, 1)  # clip reward
            mask = np.hstack(batch_mask)

            share_reward[step].data.copy_(torch.from_numpy(r))
            share_mask[step].data.copy_(torch.from_numpy(mask))

        x.data.copy_(torch.from_numpy(new_s))
        v, _ = model(x)  # v is volatile
        R = Variable(v.data.clone())
        v_loss = 0.0
        policy_loss = 0.0
        entropy_loss = 0.0

        for i in reversed(range(n_steps)):

            R = share_reward[i] + gamma * share_mask[i] * R  # discounted return, masked at episode ends
            advantage = R - cache_v_series[i]
            v_loss += advantage.pow(2).mul(0.5).mean()

            policy_loss -= sampled_log_probs[i].mul(advantage.detach()).mean()
            entropy_loss -= entropies[i].mean()

        total_loss = policy_loss + entropy_loss.mul(0.02) + v_loss * 0.5
        total_loss = total_loss.mul(1 / (n_steps))

        # adjust learning rate
        new_lr = init_lr - (global_step / max_train_steps) * init_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr

        optimizer.zero_grad()
        total_loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(),
                                                  clip_grad_norm)

        optimizer.step()

        if global_step % 10000 == 0:
            torch.save(model.state_dict(), './model/model_%s.pth' % game_name)

    for parent_conn in parent_conns:
        parent_conn.send(None)

    for w in workers:
        w.join()
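# A small, self-contained illustration (toy numbers, not the training code) of
# the backward recursion in the loop above: R_t = r_t + gamma * mask_t * R_{t+1},
# where mask is 0 at an episode boundary so the bootstrapped value is cut off.
import numpy as np

rewards = np.array([0.0, 1.0, 0.0, 1.0])  # hypothetical clipped rewards
masks = np.array([1.0, 1.0, 0.0, 1.0])    # 0 where the episode ended
gamma, bootstrap = 0.99, 0.5              # bootstrap = V(s_{n_steps})

R = bootstrap
returns = np.zeros_like(rewards)
for t in reversed(range(len(rewards))):
    R = rewards[t] + gamma * masks[t] * R
    returns[t] = R
print(returns)  # targets against which advantage = R - V(s_t) is formed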
Ejemplo n.º 12
0
class Worker(Process):
    def __init__(self, worker_id, args):
        super().__init__()

        self.id = worker_id
        self.args = args
        # pipe_master is used by the master process, pipe_worker by this worker
        self.pipe_master, self.pipe_worker = Pipe()
        self.exit_event = Event()

        # determine n_e, the number of environments this worker handles
        q, r = divmod(args.n_e, args.n_w)

        if r:
            print('Warning: n_e % n_w != 0')

        if worker_id == args.n_w - 1:
            self.n_e = n_e = q + r
        else:
            self.n_e = n_e = q

        print('[Worker %d] n_e = %d' % (self.id, n_e))

        self.env_start = worker_id * q
        self.env_slice = slice(self.env_start, self.env_start + n_e)
        self.env_range = range(self.env_start, self.env_start + n_e)
        self.envs = None

        self.start()

    def make_environments(self):
        envs = []

        for _ in range(self.n_e):
            envs.append(gym.make(self.args.env, hack='train'))

        return envs

    def put_shared_tensors(self, actions, obs, rewards, terminals):
        assert (actions.is_shared() and obs.is_shared()
                and rewards.is_shared() and terminals.is_shared())

        self.pipe_master.send((actions, obs, rewards, terminals))

    def get_shared_tensors(self):
        actions, obs, rewards, terminals = self.pipe_worker.recv()
        assert (actions.is_shared() and obs.is_shared()
                and rewards.is_shared() and terminals.is_shared())
        return actions, obs, rewards, terminals

    def set_step_done(self):
        self.pipe_worker.send_bytes(b'1')

    def wait_step_done(self):
        self.pipe_master.recv_bytes(1)

    def set_action_done(self):
        self.pipe_master.send_bytes(b'1')

    def wait_action_done(self):
        self.pipe_worker.recv_bytes(1)

    def run(self):
        preprocess = PAACNet.preprocess

        envs = self.envs = self.make_environments()
        env_start = self.env_start
        t_max = self.args.t_max
        t = 0
        dones = [False] * self.args.n_e

        # get shared tensor
        actions, obs, rewards, terminals = self.get_shared_tensors()

        for i, env in enumerate(envs, start=env_start):
            obs[i] = preprocess(env.reset())

        self.set_step_done()

        while not self.exit_event.is_set():
            self.wait_action_done()

            for i, env in enumerate(envs, start=env_start):
                if not dones[i]:
                    ob, reward, done, info = env.step(actions[i])
                else:
                    ob, reward, done, info = env.reset(), 0, False, None

                obs[i] = preprocess(ob)
                rewards[t, i] = reward
                terminals[t, i] = dones[i] = done

            self.set_step_done()

            t += 1

            if t == t_max:
                t = 0
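# Minimal, self-contained demo (not the repository's code) of the pattern the
# Worker above relies on: tensors moved to shared memory are sent once over a
# Pipe, after which both processes mutate them in place and synchronise with
# 1-byte handshake messages instead of copying observations through the pipe.
import torch
from torch.multiprocessing import Pipe, Process


def child(pipe):
    obs = pipe.recv()          # shared tensor, received exactly once
    assert obs.is_shared()
    obs += 1                   # in-place write, visible to the master process
    pipe.send_bytes(b'1')      # "step done" handshake


if __name__ == '__main__':
    master_end, worker_end = Pipe()
    obs = torch.zeros(4, 84, 84).share_memory_()
    p = Process(target=child, args=(worker_end,))
    p.start()
    master_end.send(obs)       # analogous to put_shared_tensors()
    master_end.recv_bytes(1)   # analogous to wait_step_done()
    print(obs.mean())          # tensor(1.) -- written by the child process
    p.join()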
Ejemplo n.º 13
0
class OnlineVaeOffpolicyAlgorithm(TorchBatchRLAlgorithm):
    def __init__(self,
                 vae,
                 vae_trainer,
                 *base_args,
                 vae_save_period=1,
                 vae_training_schedule=vae_schedules.never_train,
                 oracle_data=False,
                 parallel_vae_train=True,
                 vae_min_num_steps_before_training=0,
                 uniform_dataset=None,
                 dataset_path=None,
                 rl_offpolicy_num_training_steps=0,
                 **base_kwargs):
        super().__init__(*base_args, **base_kwargs)
        assert isinstance(self.replay_buffer, OnlineVaeRelabelingBuffer)
        self.vae = vae
        self.vae_trainer = vae_trainer
        self.vae_trainer.model = self.vae
        self.vae_save_period = vae_save_period
        self.vae_training_schedule = vae_training_schedule
        self.oracle_data = oracle_data

        self.parallel_vae_train = parallel_vae_train
        self.vae_min_num_steps_before_training = vae_min_num_steps_before_training
        self.uniform_dataset = uniform_dataset

        self._vae_training_process = None
        self._update_subprocess_vae_thread = None
        self._vae_conn_pipe = None

        self.dataset_path = dataset_path
        if self.dataset_path:
            self.load_dataset(dataset_path)

        # train Q and policy rl_offpolicy_num_training_steps times
        self.rl_offpolicy_num_training_steps = rl_offpolicy_num_training_steps

    def pretrain(self):
        for _ in range(self.rl_offpolicy_num_training_steps):
            train_data = self.replay_buffer.random_batch(self.batch_size)
            self.trainer.train(train_data)

    def load_dataset(self, dataset_path):
        dataset = load_local_or_remote_file(dataset_path)
        dataset = dataset.item()

        observations = dataset['observations']
        actions = dataset['actions']

        # dataset['observations'].shape # (2000, 50, 6912)
        # dataset['actions'].shape # (2000, 50, 2)
        # dataset['env'].shape # (2000, 6912)
        N, H, imlength = observations.shape

        self.vae.eval()
        for n in range(N):
            x0 = ptu.from_numpy(dataset['env'][n:n + 1, :] / 255.0)
            x = ptu.from_numpy(observations[n, :, :] / 255.0)
            latents = self.vae.encode(x, x0, distrib=False)

            r1, r2 = self.vae.latent_sizes
            conditioning = latents[0, r1:]
            goal = torch.cat(
                [ptu.randn(self.vae.latent_sizes[0]), conditioning])
            goal = ptu.get_numpy(goal)  # latents[-1, :]

            latents = ptu.get_numpy(latents)
            latent_delta = latents - goal
            distances = np.zeros((H - 1, 1))
            for i in range(H - 1):
                distances[i, 0] = np.linalg.norm(latent_delta[i + 1, :])

            terminals = np.zeros((H - 1, 1))
            # terminals[-1, 0] = 1
            path = dict(
                observations=[],
                actions=actions[n, :H - 1, :],
                next_observations=[],
                rewards=-distances,
                terminals=terminals,
            )

            for t in range(H - 1):
                # reward = -np.linalg.norm(latent_delta[i, :])

                obs = dict(
                    latent_observation=latents[t, :],
                    latent_achieved_goal=latents[t, :],
                    latent_desired_goal=goal,
                )
                next_obs = dict(
                    latent_observation=latents[t + 1, :],
                    latent_achieved_goal=latents[t + 1, :],
                    latent_desired_goal=goal,
                )

                path['observations'].append(obs)
                path['next_observations'].append(next_obs)

            # import ipdb; ipdb.set_trace()
            self.replay_buffer.add_path(path)

    def _end_epoch(self):
        timer.start_timer('vae training')
        self._train_vae(self.epoch)
        timer.stop_timer('vae training')
        super()._end_epoch()

    def _get_diagnostics(self):
        vae_log = self._get_vae_diagnostics().copy()
        vae_log.update(super()._get_diagnostics())
        return vae_log

    def to(self, device):
        self.vae.to(device)
        super().to(device)

    """
    VAE-specific Code
    """

    def _train_vae(self, epoch):
        if self.parallel_vae_train and self._vae_training_process is None:
            self.init_vae_training_subprocess()
        should_train, amount_to_train = self.vae_training_schedule(epoch)
        rl_start_epoch = int(self.min_num_steps_before_training /
                             (self.num_expl_steps_per_train_loop *
                              self.num_train_loops_per_epoch))
        if should_train:  # or epoch <= (rl_start_epoch - 1):
            if self.parallel_vae_train:
                assert self._vae_training_process.is_alive()
                # Make sure the last vae update has finished before starting
                # another one
                if self._update_subprocess_vae_thread is not None:
                    self._update_subprocess_vae_thread.join()
                self._update_subprocess_vae_thread = Thread(
                    target=OnlineVaeAlgorithm.
                    update_vae_in_training_subprocess,
                    args=(self, epoch, ptu.device))
                self._update_subprocess_vae_thread.start()
                self._vae_conn_pipe.send((amount_to_train, epoch))
            else:
                _train_vae(self.vae_trainer, epoch, self.replay_buffer,
                           amount_to_train)
                self.replay_buffer.refresh_latents(epoch)
                _test_vae(
                    self.vae_trainer,
                    epoch,
                    self.replay_buffer,
                    vae_save_period=self.vae_save_period,
                    uniform_dataset=self.uniform_dataset,
                )

    def _get_vae_diagnostics(self):
        return add_prefix(
            self.vae_trainer.get_diagnostics(),
            prefix='vae_trainer/',
        )

    def _cleanup(self):
        if self.parallel_vae_train:
            self._vae_conn_pipe.close()
            self._vae_training_process.terminate()

    def init_vae_training_subprocess(self):
        assert isinstance(self.replay_buffer, SharedObsDictRelabelingBuffer)

        self._vae_conn_pipe, process_pipe = Pipe()
        self._vae_training_process = Process(
            target=subprocess_train_vae_loop,
            args=(
                process_pipe,
                self.vae,
                self.vae.state_dict(),
                self.replay_buffer,
                self.replay_buffer.get_mp_info(),
                ptu.device,
            ))
        self._vae_training_process.start()
        self._vae_conn_pipe.send(self.vae_trainer)

    def update_vae_in_training_subprocess(self, epoch, device):
        self.vae.__setstate__(self._vae_conn_pipe.recv())
        self.vae.to(device)
        _test_vae(
            self.vae_trainer,
            epoch,
            self.replay_buffer,
            vae_save_period=self.vae_save_period,
            uniform_dataset=self.uniform_dataset,
        )
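# Toy, self-contained illustration (made-up latents, no real VAE) of how
# load_dataset() above converts a latent trajectory into rewards: each step is
# rewarded with the negative distance between the *next* latent and the goal.
import numpy as np

H, latent_dim = 4, 3
latents = np.arange(H * latent_dim, dtype=np.float64).reshape(H, latent_dim)
goal = latents[-1]                   # pretend the sampled goal is the last latent

latent_delta = latents - goal
distances = np.linalg.norm(latent_delta[1:], axis=1, keepdims=True)  # (H-1, 1)
rewards = -distances
print(rewards)  # 0 for the final transition, more negative for earlier steps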
Ejemplo n.º 14
0
def main():
    if 'NAME' in os.environ.keys():
        NAME = os.environ['NAME']
    else:
        raise ValueError('set NAME via env variable')

    try:
        env_settings = json.load(open(default_config['CarIntersectConfigPath'], 'r'))
    except ValueError:
        env_settings = yaml.safe_load(open(default_config['CarIntersectConfigPath'], 'r'))

    if 'home-test' not in NAME:
        wandb.init(
            project='CarRacing_RND',
            reinit=True,
            name=f'rnd_{NAME}',
            config={'env_config': env_settings, 'agent_config': default_config},
        )

    # print({section: dict(config[section]) for section in config.sections()})
    train_method = default_config['TrainMethod']

    env_id = default_config['EnvID']
    # env_type = default_config['EnvType']

    # if env_type == 'mario':
    #     env = BinarySpaceToDiscreteSpaceEnv(gym_super_mario_bros.make(env_id), COMPLEX_MOVEMENT)
    # elif env_type == 'atari':
    #     env = gym.make(env_id)
    # else:
    #     raise NotImplementedError

    seed = np.random.randint(0, 2 ** 16 - 1)

    print(f'use name : {NAME}')
    print(f"use env config : {default_config['CarIntersectConfigPath']}")
    print(f'use seed : {seed}')
    print(f"use device : {os.environ['DEVICE']}")

    os.chdir('..')
    env = makeCarIntersect(env_settings)
    eval_env = create_eval_env(makeCarIntersect(env_settings))

    # input_size = env.observation_space.shape  # 4
    input_size = env.observation_space.shape
    assert isinstance(env.action_space, gym.spaces.Box)
    action_size = env.action_space.shape[0]  # 2

    env.close()

    is_load_model = True
    is_render = False
    # model_path = 'models/{}.model'.format(NAME)
    # predictor_path = 'models/{}.pred'.format(NAME)
    # target_path = 'models/{}.target'.format(NAME)

    # writer = SummaryWriter()

    use_cuda = default_config.getboolean('UseGPU')
    use_gae = default_config.getboolean('UseGAE')
    use_noisy_net = default_config.getboolean('UseNoisyNet')

    lam = float(default_config['Lambda'])
    num_worker = int(default_config['NumEnv'])

    num_step = int(default_config['NumStep'])

    ppo_eps = float(default_config['PPOEps'])
    epoch = int(default_config['Epoch'])
    mini_batch = int(default_config['MiniBatch'])
    batch_size = int(num_step * num_worker / mini_batch)
    learning_rate = float(default_config['LearningRate'])
    entropy_coef = float(default_config['Entropy'])
    gamma = float(default_config['Gamma'])
    int_gamma = float(default_config['IntGamma'])
    clip_grad_norm = float(default_config['ClipGradNorm'])
    ext_coef = float(default_config['ExtCoef'])
    int_coef = float(default_config['IntCoef'])

    sticky_action = default_config.getboolean('StickyAction')
    action_prob = float(default_config['ActionProb'])
    life_done = default_config.getboolean('LifeDone')

    reward_rms = RunningMeanStd()
    obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))
    pre_obs_norm_step = int(default_config['ObsNormStep'])
    discounted_reward = RewardForwardFilter(int_gamma)

    agent = RNDAgent(
        input_size,
        action_size,
        num_worker,
        num_step,
        gamma,
        lam=lam,
        learning_rate=learning_rate,
        ent_coef=entropy_coef,
        clip_grad_norm=clip_grad_norm,
        epoch=epoch,
        batch_size=batch_size,
        ppo_eps=ppo_eps,
        use_cuda=use_cuda,
        use_gae=use_gae,
        use_noisy_net=use_noisy_net,
        device=os.environ['DEVICE'],
    )

    # if is_load_model:
    #     print('load model...')
    #     if use_cuda:
    #         agent.model.load_state_dict(torch.load(model_path))
    #         agent.rnd.predictor.load_state_dict(torch.load(predictor_path))
    #         agent.rnd.target.load_state_dict(torch.load(target_path))
    #     else:
    #         agent.model.load_state_dict(torch.load(model_path, map_location='cpu'))
    #         agent.rnd.predictor.load_state_dict(torch.load(predictor_path, map_location='cpu'))
    #         agent.rnd.target.load_state_dict(torch.load(target_path, map_location='cpu'))
    #     print('load finished!')

    works = []
    parent_conns = []
    child_conns = []
    for idx in range(num_worker):
        parent_conn, child_conn = Pipe()
        work = AtariEnvironment(env_id, is_render, idx, child_conn, sticky_action=sticky_action, p=action_prob,
                        life_done=life_done, settings=env_settings)
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    os.chdir('rnd_continues')

    states = np.zeros([num_worker, 4, 84, 84])

    sample_episode = 0
    sample_rall = 0
    sample_step = 0
    sample_env_idx = 0
    sample_i_rall = 0
    global_update = 0
    global_step = 0

    logger = Logger(None, use_console=True, use_wandb=True, log_interval=1)

    print('Test evaluator:')
    evaluate_and_log(
        eval_env=eval_env,
        action_get_method=lambda eval_state: agent.get_action(
            np.tile(np.float32(eval_state), (1, 4, 1, 1)) / 255.
        )[0][0].cpu().numpy(),
        logger=logger,
        log_animation=False,
        exp_class='RND',
        exp_name=NAME,
        debug=True,
    )
    print('end evaluator test.')

    # normalize obs
    print('Start initializing observation normalization parameters...')

    # print('ALERT! pass section')
    # assert 'home-test' in NAME
    next_obs = []
    for step in range(num_step * pre_obs_norm_step):
        actions = np.random.uniform(-1, 1, size=(num_worker, action_size))

        for parent_conn, action in zip(parent_conns, actions):
            parent_conn.send(action)

        for parent_conn in parent_conns:
            s, r, d, rd, lr = parent_conn.recv()
            next_obs.append(s[3, :, :].reshape([1, 84, 84]))

        if len(next_obs) % (num_step * num_worker) == 0:
            next_obs = np.stack(next_obs)
            obs_rms.update(next_obs)
            next_obs = []
    print('Finished initializing observation normalization parameters.')

    while True:
        total_state, total_reward, total_done, total_next_state, total_action, total_int_reward, total_next_obs, total_ext_values, total_int_values, total_policy_log_prob, total_policy_log_prob_np = \
            [], [], [], [], [], [], [], [], [], [], []

        # Step 1. n-step rollout
        for _ in range(num_step):
            global_step += num_worker
            # actions, value_ext, value_int, policy = agent.get_action(np.float32(states) / 255.)
            actions, value_ext, value_int, policy_log_prob = agent.get_action(np.float32(states) / 255.)

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action.cpu().numpy())

            next_states, rewards, dones, real_dones, log_rewards, next_obs = [], [], [], [], [], []
            for parent_conn in parent_conns:
                s, r, d, rd, lr = parent_conn.recv()
                next_states.append(s)
                rewards.append(r)
                dones.append(d)
                real_dones.append(rd)
                log_rewards.append(lr)
                next_obs.append(s[3, :, :].reshape([1, 84, 84]))

            next_states = np.stack(next_states)
            rewards = np.hstack(rewards)
            dones = np.hstack(dones)
            real_dones = np.hstack(real_dones)
            next_obs = np.stack(next_obs)

            # total reward = intrinsic reward + extrinsic reward
            intrinsic_reward = agent.compute_intrinsic_reward(
                ((next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5))
            intrinsic_reward = np.hstack(intrinsic_reward)
            sample_i_rall += intrinsic_reward[sample_env_idx]

            total_next_obs.append(next_obs)
            total_int_reward.append(intrinsic_reward)
            total_state.append(states)
            total_reward.append(rewards)
            total_done.append(dones)
            total_action.append(actions.cpu().numpy())
            total_ext_values.append(value_ext)
            total_int_values.append(value_int)

            # total_policy.append(policy)
            # total_policy_np.append(policy.cpu().numpy())

            total_policy_log_prob.extend(policy_log_prob.cpu().numpy())

            states = next_states[:, :, :, :]

            sample_rall += log_rewards[sample_env_idx]

            sample_step += 1
            if real_dones[sample_env_idx]:
                sample_episode += 1
                # writer.add_scalar('data/reward_per_epi', sample_rall, sample_episode)
                # writer.add_scalar('data/reward_per_rollout', sample_rall, global_update)
                # writer.add_scalar('data/step', sample_step, sample_episode)
                logger.log_it({
                    'reward_per_episode': sample_rall,
                    'intrinsic_reward': sample_i_rall,
                    'episode_steps': sample_step,
                    'global_step_cnt': global_step,
                    'updates_cnt': global_update,
                })
                logger.publish_logs(step=global_step)
                sample_rall = 0
                sample_step = 0
                sample_i_rall = 0

        # calculate last next value
        _, value_ext, value_int, _ = agent.get_action(np.float32(states) / 255.)
        total_ext_values.append(value_ext)
        total_int_values.append(value_int)
        # --------------------------------------------------

        total_state = np.stack(total_state).transpose([1, 0, 2, 3, 4]).reshape([-1, 4, 84, 84])
        total_reward = np.stack(total_reward).transpose().clip(-1, 1)

        # total_action = np.stack(total_action).transpose().reshape([-1, action_size])
        total_action = np.array(total_action).reshape((-1, action_size))
        # total_log_prob_old = np.array(total_policy_log_prob).reshape((-1))

        total_done = np.stack(total_done).transpose()
        total_next_obs = np.stack(total_next_obs).transpose([1, 0, 2, 3, 4]).reshape([-1, 1, 84, 84])
        total_ext_values = np.stack(total_ext_values).transpose()
        total_int_values = np.stack(total_int_values).transpose()
        # total_logging_policy = np.vstack(total_policy_np)

        # Step 2. calculate intrinsic reward
        # running mean intrinsic reward
        total_int_reward = np.stack(total_int_reward).transpose()
        total_reward_per_env = np.array([discounted_reward.update(reward_per_step) for reward_per_step in
                                         total_int_reward.T])
        mean, std, count = np.mean(total_reward_per_env), np.std(total_reward_per_env), len(total_reward_per_env)
        reward_rms.update_from_moments(mean, std ** 2, count)

        # normalize intrinsic reward
        total_int_reward /= np.sqrt(reward_rms.var)
        # writer.add_scalar('data/int_reward_per_epi', np.sum(total_int_reward) / num_worker, sample_episode)
        # writer.add_scalar('data/int_reward_per_rollout', np.sum(total_int_reward) / num_worker, global_update)
        # -------------------------------------------------------------------------------------------

        # log the max action probability
        # writer.add_scalar('data/max_prob', softmax(total_logging_policy).max(1).mean(), sample_episode)

        # Step 3. make target and advantage
        # calculate extrinsic returns and advantages
        ext_target, ext_adv = make_train_data(total_reward,
                                              total_done,
                                              total_ext_values,
                                              gamma,
                                              num_step,
                                              num_worker)

        # calculate intrinsic returns and advantages
        # non-episodic: intrinsic returns ignore episode boundaries (dones are all zero)
        int_target, int_adv = make_train_data(total_int_reward,
                                              np.zeros_like(total_int_reward),
                                              total_int_values,
                                              int_gamma,
                                              num_step,
                                              num_worker)

        # add ext adv and int adv
        total_adv = int_adv * int_coef + ext_adv * ext_coef
        # -----------------------------------------------

        # Step 4. update obs normalize param
        obs_rms.update(total_next_obs)
        # -----------------------------------------------

        global_update += 1
        # Step 5. Training!
        agent.train_model(np.float32(total_state) / 255., ext_target, int_target, total_action,
                          total_adv, ((total_next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5),
                          total_policy_log_prob)

        # if global_step % (num_worker * num_step * 100) == 0:
        #     print('Now Global Step :{}'.format(global_step))
        #     torch.save(agent.model.state_dict(), model_path)
        #     torch.save(agent.rnd.predictor.state_dict(), predictor_path)
        #     torch.save(agent.rnd.target.state_dict(), target_path)

        if global_update % 100 == 0:
            evaluate_and_log(
                eval_env=eval_env,
                action_get_method=lambda eval_state: agent.get_action(
                    np.tile(np.float32(eval_state), (1, 4, 1, 1)) / 255.
                )[0][0].cpu().numpy(),
                logger=logger,
                log_animation=True,
                exp_class='RND',
                exp_name=NAME,
            )
            logger.publish_logs(step=global_step)
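# `discounted_reward` above is a RewardForwardFilter: per environment it keeps
# a discounted running sum of intrinsic rewards, whose batch statistics feed
# reward_rms.update_from_moments(). A minimal sketch, assumed to match how it
# is called above (one call per rollout step with a vector of per-worker
# rewards); the real implementation may differ.
import numpy as np

class RewardForwardFilter:
    def __init__(self, gamma):
        self.rewems = None   # running discounted sum, one entry per worker
        self.gamma = gamma

    def update(self, rews):
        # rews: np.ndarray of shape (num_worker,), one rollout step of rewards
        if self.rewems is None:
            self.rewems = rews
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems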
Ejemplo n.º 15
0
    def meta_fit(self, meta_dataset_generator):

        catchable_sigs = set(signal.Signals) - {signal.SIGKILL, signal.SIGSTOP}
        for sig in catchable_sigs:
            signal.signal(
                sig,
                receive_signal)  # Substitute handler of choice for `print`
        LOGGER.debug('My PID: %s' % os.getpid())

        self.timer.begin('main training')
        mp.set_start_method('spawn', force=True)

        # >>> BUG: OSError: Too many open files
        # >>> SOLVED by `ulimit -HSn 4096`
        # Now we change all the queues to pipes
        self.timer.begin('build data pipeline')
        # every 10 epochs produce one validation pass
        train_data_reservoir = [
            queue.Queue(32 * 10) for i in range(len(self.devices))
        ]
        valid_data_reservoir = [
            queue.Queue(200) for i in range(len(self.devices))
        ]
        meta_valid_reservoir = [
            queue.Queue(self.eval_tasks) for i in range(self.total_exp)
        ]
        train_recv, valid_recv = [], []
        train_send, valid_send = [], []
        for i in range(len(self.devices)):
            recv, send = Pipe(True)
            # activate the first handshake
            recv.send(True)
            train_recv.append(recv)
            train_send.append(send)
            recv, send = Pipe(True)
            # activate the first handshake
            recv.send(True)
            valid_recv.append(recv)
            valid_send.append(send)

        def apply_device_to_hp(hp, device):
            hp['device'] = 'cuda:{}'.format(device)
            return hp

        self.timer.end('build data pipeline')

        self.timer.begin('build main proc pipeline')
        clsnum = get_base_class_number(meta_dataset_generator)
        LOGGER.info('base class number detected', clsnum)
        procs = [
            mp.Process(target=run_exp,
                       args=(self.modules[i].MyMetaLearner,
                             apply_device_to_hp(self.hp[i], dev),
                             train_recv[i], valid_recv[i], clsnum))
            for i, dev in enumerate(self.devices)
        ]
        for p in procs:
            p.daemon = True
            p.start()

        self.timer.end('build main proc pipeline')
        LOGGER.info('build data',
                    self.timer.query_time_by_name('build data pipeline'),
                    'build proc',
                    self.timer.query_time_by_name('build main proc pipeline'))
        label_meta_valid = []

        data_generation = True

        self.timer.begin('prepare dataset')
        meta_train_dataset = meta_dataset_generator.meta_train_pipeline.batch(
            1)
        meta_train_generator = cycle(iter(meta_train_dataset))
        meta_valid_dataset = meta_dataset_generator.meta_valid_pipeline.batch(
            1)
        meta_valid_generator = cycle(iter(meta_valid_dataset))
        self.timer.end('prepare dataset')
        LOGGER.info('prepare dataset',
                    self.timer.query_time_by_name('prepare dataset'))

        valid_ens_data_load_number = 0

        def generate_data():
            # manage data globally
            while data_generation:
                for i in range(32 * 10):
                    # load train
                    if not data_generation:
                        break
                    data_train = process_task_batch(next(meta_train_generator),
                                                    device=torch.device('cpu'),
                                                    with_origin_label=True)
                    for dr in train_data_reservoir:
                        try:
                            dr.put_nowait(data_train)
                        except:
                            pass
                    time.sleep(0.0001)

                for i in range(200):
                    # load valid
                    if not data_generation:
                        break
                    data_valid = process_task_batch(next(meta_valid_generator),
                                                    device=torch.device('cpu'),
                                                    with_origin_label=False)
                    for dr in valid_data_reservoir:
                        try:
                            dr.put_nowait(data_valid)
                        except:
                            pass
                    if random.random() < 0.1:
                        for dr in meta_valid_reservoir:
                            try:
                                if dr.qsize() < self.eval_tasks:
                                    valid_ens_data_load_number += 1
                                    dr.put_nowait([
                                        data_valid[0][0], data_valid[0][1],
                                        data_valid[1][0]
                                    ])
                                    label_meta_valid.extend(
                                        data_valid[1][1].tolist())
                            except:
                                pass
                    time.sleep(0.0001)

        def put_data_train_passive(i):
            while data_generation:
                try:
                    if train_send[i].recv():
                        supp, quer = train_data_reservoir[i].get()
                        data = self.modules[i].process_data(
                            supp, quer, True, self.hp[i])
                        train_send[i].send(data)
                    else:
                        return
                except:
                    pass

        def put_data_valid_passive(i):
            while data_generation:
                try:
                    if valid_send[i].recv():
                        supp, quer = valid_data_reservoir[i].get()
                        data = self.modules[i].process_data(
                            supp, quer, False, self.hp[i])
                        valid_send[i].send(data)
                    else:
                        return
                except:
                    pass

        thread_pool = [threading.Thread(target=generate_data)] + \
            [threading.Thread(target=put_data_train_passive, args=(i,)) for i in range(self.total_exp)] + \
            [threading.Thread(target=put_data_valid_passive, args=(i,)) for i in range(self.total_exp)]

        for th in thread_pool:
            th.daemon = True
            th.start()

        try:
            # leave about 10 min for decoding of the test set
            for p in procs:
                p.join(max(self.timer.time_left() - 60 * 10, 0.1))

            self.timer.begin('clear env')
            # terminate any proc that is out of time
            LOGGER.info('Main meta-train is done',
                        '' if self.timer.time_left() > 60 else 'time out exit')
            LOGGER.info('time left', self.timer.time_left(), 's')
            for p in procs:
                if p.is_alive():
                    p.terminate()

            data_generation = False
            # in case anything is still blocking
            for q in train_data_reservoir + valid_data_reservoir:
                if q.empty():
                    q.put(False)
            for s in train_recv + valid_recv:
                s.send(False)
            for s in train_send + train_recv + valid_send + valid_recv:
                s.close()
            for p in thread_pool:
                p.join()
            self.timer.end('clear env')
            LOGGER.info('clear env',
                        self.timer.query_time_by_name('clear env'))

            self.timer.end('main training')
        except Exception:
            LOGGER.info('error occured in main process')
            traceback.print_exc()

        LOGGER.info(
            'spawn total {} meta valid tasks. main training time {}'.format(
                valid_ens_data_load_number,
                self.timer.query_time_by_name('main training')))

        self.timer.begin('load learner')

        self.meta_learners = [None] * self.total_exp

        def load_model(args):
            module, hp, i = args
            self.meta_learners[i] = module.load_model(hp)

        pool = [
            threading.Thread(target=load_model,
                             args=((self.modules[i], self.hp[i], i), ))
            for i in range(self.total_exp)
        ]
        for p in pool:
            p.daemon = True
            p.start()
        for p in pool:
            p.join()

        self.timer.end('load learner')
        LOGGER.info('load learner done, time spent',
                    self.timer.query_time_by_name('load learner'))

        if not isinstance(self.ensemble, int):
            # instead of just weighted sum, we plan to use stacking
            procs = []
            reses = [None] * len(self.meta_learners)

            self.timer.begin('validation')

            recv_list, sent_list = [], []
            for i in range(self.total_exp):
                r, s = Pipe(True)
                r.send(True)
                recv_list.append(r)
                sent_list.append(s)

            pool = mp.Pool(self.total_exp)
            procs = pool.starmap_async(
                predict, [(self.meta_learners[i], recv_list[i],
                           self.eval_tasks, self.hp[i]['device'], {
                               'time_fired': time.time(),
                               'taskid': i
                           }) for i in range(self.total_exp)])

            # start sub-threads to pass data
            def pass_meta_data(i):
                for _ in range(self.eval_tasks):
                    if sent_list[i].recv():
                        # LOGGER.info(i, 'fire data signal get')
                        sent_list[i].send(meta_valid_reservoir[i].get())
                        # LOGGER.info(i, 'data is sent')

            threads = [
                threading.Thread(target=pass_meta_data, args=(i, ))
                for i in range(self.total_exp)
            ]
            for t in threads:
                t.daemon = True
                t.start()

            for _ in range(self.eval_tasks - valid_ens_data_load_number):
                data_valid = next(meta_valid_generator)
                data_valid = process_task_batch(data_valid,
                                                device=torch.device('cpu'),
                                                with_origin_label=False)
                label_meta_valid.extend(data_valid[1][1].tolist())
                for dr in meta_valid_reservoir:
                    dr.put(
                        [data_valid[0][0], data_valid[0][1], data_valid[1][0]])
                # LOGGER.info('put data!')
            # LOGGER.info('all data done!')

            # now we can receive data
            for t in threads:
                t.join()
            reses = [sent_list[i].recv()['res'] for i in range(self.total_exp)]
            # every res in reses is a np.array of shape (eval_task * WAY * QUERY) * WAY
            ENS_VALID_TASK = 50
            ENS_VALID_ELEMENT = ENS_VALID_TASK * 5 * 19
            reses_test_list = [
                deepcopy(res[-ENS_VALID_ELEMENT:]) for res in reses
            ]

            self.timer.end('validation')
            LOGGER.info('valid data predict done',
                        self.timer.query_time_by_name('validation'))

            weight = [1.] * len(self.meta_learners)
            labels = np.array(label_meta_valid, dtype=int)  # 19000
            acc_o = ((np.array(weight)[:, None, None] / sum(weight) *
                      np.array(reses)).sum(axis=0).argmax(
                          axis=1) == labels).mean()
            reses = np.array(reses, dtype=float).transpose((1, 0, 2))
            reses_test = reses[-ENS_VALID_ELEMENT:].reshape(
                ENS_VALID_ELEMENT, -1)
            reses = reses[:-ENS_VALID_ELEMENT]
            reses = reses.reshape(len(reses), -1)
            labels_test = labels[-ENS_VALID_ELEMENT:]
            labels = labels[:-ENS_VALID_ELEMENT]
            LOGGER.info('voting result', acc_o)

            self.timer.begin('ensemble')

            # mp.set_start_method('fork', True)
            result = pool.map(
                ensemble_on_data,
                [
                    (GBMEnsembler(), reses, labels, 'gbm'),
                    (GLMEnsembler(), reses, labels, 'glm'),
                    (NBEnsembler(), reses, labels, 'nb'),
                    (RFEnsembler(), reses, labels, 'rf')  # tends to over-fit on simple datasets
                ])

            # test the ensemble model
            def acc(logit, label):
                return (logit.argmax(axis=1) == label).mean()

            res_test = [x[0]._predict(reses_test) for x in result]
            acc_test = [acc(r, labels_test) for r in res_test]
            acc_single_test = [
                acc(np.array(r), labels_test) for r in reses_test_list
            ]
            LOGGER.info('ensemble test', 'gbm', 'glm', 'nb', 'rf', acc_test)
            LOGGER.info('single test', acc_single_test)

            if max(acc_test) > max(acc_single_test):
                LOGGER.info("will use ensemble model")
                #idx_acc_max = np.argmax([x[1] for x in result])
                idx_acc_max = np.argmax(acc_test)
                self.timer.end('ensemble')
                print('best ensembler', ['gbm', 'glm', 'nb',
                                         'rf'][idx_acc_max], 'acc',
                      acc_test[idx_acc_max])
                print('ensemble done, time cost',
                      self.timer.query_time_by_name('ensemble'))

                # currently we use mean of output as ensemble
                return MyLearner(self.meta_learners,
                                 result[idx_acc_max][0],
                                 timers=self.timer)
            else:
                LOGGER.info("will use single model")
                idx_acc_max = np.argmax(acc_single_test)
                self.timer.end('ensemble')
                print('best single model id', idx_acc_max)
                print('ensemble done, time cost',
                      self.timer.query_time_by_name('ensemble'))

                # return only the best meta learner
                return MyLearner([self.meta_learners[idx_acc_max]], 0,
                                 self.timer)
        return MyLearner([self.meta_learners[self.ensemble]],
                         0,
                         timers=self.timer)
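# Toy illustration (made-up logits) of the weighted-voting accuracy `acc_o`
# computed above: per-model class scores are averaged with normalized weights,
# then the argmax over classes is compared against the labels.
import numpy as np

reses = np.array([                  # (n_models, n_samples, n_classes)
    [[0.7, 0.3], [0.2, 0.8]],
    [[0.6, 0.4], [0.6, 0.4]],
])
labels = np.array([0, 1])
weight = [1., 1.]

avg = (np.array(weight)[:, None, None] / sum(weight) * reses).sum(axis=0)
acc_o = (avg.argmax(axis=1) == labels).mean()
print(acc_o)  # 1.0 here: both toy samples are voted correctly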
Ejemplo n.º 16
0
    def meta_fit(self, meta_dataset_generator):

        with tf.device('/cpu:0'):
            LOGGER.debug('My PID: %s' % os.getpid())

            self.timer.begin('main training')
            mp.set_start_method('spawn', force=True)
            
            self.timer.begin('build data pipeline')

            # these reservoirs are used to send data to the sub-processes
            train_data_process_reservoir = [queue.Queue(self.train_cache_size) for i in range(len(self.devices))]
            valid_data_process_reservoir = [queue.Queue(self.valid_cache_size) for i in range(len(self.devices))]
            
            meta_valid_reservoir = [queue.Queue(self.eval_tasks) for i in range(self.total_exp)]

            # these reservoirs are used only to store the extracted data
            train_data_extract_reservoir = [queue.Queue(self.train_cache_size) for i in range(len(self.devices))]
            valid_data_extract_reservoir = [queue.Queue(self.valid_cache_size) for i in range(len(self.devices))]

            if self.fix_valid:
                valid_data_cache = [[] for _ in range(len(self.devices))]
                valid_data_pointer = [0 for _ in range(len(self.devices))]
            
            train_recv, valid_recv = [], []
            train_send, valid_send = [], []
            for i in range(len(self.devices)):
                recv, send = Pipe(True)
                # activate the first handshake
                recv.send(True)
                train_recv.append(recv)
                train_send.append(send)
                recv, send = Pipe(True)
                # activate the first handshake
                recv.send(True)
                valid_recv.append(recv)
                valid_send.append(send)

            def apply_device_to_hp(hp, device):
                hp['device'] = 'cuda:{}'.format(device)
                return hp
            
            self.timer.end('build data pipeline')

            self.timer.begin('build main proc pipeline')
            clsnum = get_base_class_number(meta_dataset_generator)
            LOGGER.info('base class number detected', clsnum)
            procs = [mp.Process(
                target=run_exp,
                args=(
                    self.modules[i].MyMetaLearner,
                    apply_device_to_hp(self.hp[i], dev),
                    train_recv[i], valid_recv[i],
                    clsnum, 
                    self.modules[i].process_data if self.process_protocol != 'process-in-main' else None
                )
            ) for i, dev in enumerate(self.devices)]

            for p in procs: p.daemon = True; p.start()

            self.timer.end('build main proc pipeline')
            LOGGER.info('build data', self.timer.query_time_by_name('build data pipeline'), 'build proc', self.timer.query_time_by_name('build main proc pipeline'))
            label_meta_valid = []

            data_generation = True

            self.timer.begin('prepare dataset')
            meta_train_dataset = meta_dataset_generator.meta_train_pipeline.batch(1)
            meta_train_generator = iter(meta_train_dataset)
            meta_valid_dataset = meta_dataset_generator.meta_valid_pipeline.batch(1)
            meta_valid_generator = iter(meta_valid_dataset)
            self.timer.end('prepare dataset')
            LOGGER.info('prepare dataset', self.timer.query_time_by_name('prepare dataset'))

            global valid_ens_data_load_number
            valid_ens_data_load_number = 0

            def train_pipe_fill():
                while data_generation:
                    data_train = process_task_batch(next(meta_train_generator), device=torch.device('cpu'), with_origin_label=True)
                    for dr in train_data_extract_reservoir:
                        try: dr.put_nowait(data_train)
                        except: pass
                    time.sleep(0.001)
            
            def valid_pipe_fill():
                global valid_ens_data_load_number
                while data_generation:
                    data_valid = process_task_batch(next(meta_valid_generator), device=torch.device('cpu'), with_origin_label=False)
                    for dr in valid_data_extract_reservoir:
                        try: dr.put_nowait(data_valid)
                        except: pass
                        if random.random() < 0.1 and valid_ens_data_load_number < self.eval_tasks:
                            # fill the meta-valid
                            valid_ens_data_load_number += 1
                            label_meta_valid.extend(data_valid[1][1].tolist())
                            for dr in meta_valid_reservoir:
                                dr.put([data_valid[0][0], data_valid[0][1], data_valid[1][0]])
                    time.sleep(0.001)

            def put_data_train_passive(i):
                while data_generation:
                    try:
                        if train_send[i].recv(): train_send[i].send(train_data_process_reservoir[i].get())
                        else: return
                    except: pass

            def put_data_valid_passive(i):
                while data_generation:
                    try:
                        if valid_send[i].recv():
                            if self.fix_valid:
                                if len(valid_data_cache[i]) == self.hp[i]['eval_tasks']:
                                    # retrieve the ith element
                                    data = valid_data_cache[i][valid_data_pointer[i]]
                                    valid_data_pointer[i] = (valid_data_pointer[i] + 1) % self.hp[i]['eval_tasks']
                                    valid_send[i].send(data)
                                else:
                                    # fill the cache
                                    data = valid_data_process_reservoir[i].get()
                                    valid_data_cache[i].append(data)
                                    valid_send[i].send(data)
                            else:
                                valid_send[i].send(valid_data_process_reservoir[i].get())

                        else: return
                    except: pass
            
            def process_data(i, train=True):
                while data_generation:
                    extract_ = train_data_extract_reservoir[i] if train else valid_data_extract_reservoir[i]
                    process_ = train_data_process_reservoir[i] if train else valid_data_process_reservoir[i]
                    data = extract_.get()
                    if data is False: break
                    if self.process_protocol == 'process-in-main':
                        data = self.modules[i].process_data(data[0], data[1], train, apply_device_to_hp(self.hp[i], self.devices[i]))
                    process_.put(data)
            
            thread_pool = [threading.Thread(target=train_pipe_fill), threading.Thread(target=valid_pipe_fill)] + \
                [threading.Thread(target=put_data_train_passive, args=(i,)) for i in range(self.total_exp)] + \
                [threading.Thread(target=put_data_valid_passive, args=(i,)) for i in range(self.total_exp)] + \
                [threading.Thread(target=process_data, args=(i, train)) for i, train in itertools.product(range(self.total_exp), [True, False])]
            
            for th in thread_pool: th.daemon = True; th.start()

            try:
                # we leave about 20 min for decoding of test
                for p in procs: p.join(max(self.timer.time_left() - 60 * 20, 0.1))
            
                self.timer.begin('clear env')
                # terminate proc that is out-of-time
                LOGGER.info('Main meta-train is done', '' if self.timer.time_left() > 60 else 'time out exit')
                LOGGER.info('time left', self.timer.time_left(), 's')
                for p in procs:
                    if p.is_alive():
                        p.terminate()
                
                LOGGER.info('all process terminated')

                data_generation = False
                
                LOGGER.info('send necessary messages in case of block')
                # solve the pipe block
                try:
                    for s in train_recv + valid_recv: s.send(False)
                    for s in train_send + train_recv + valid_send + valid_recv: s.close()
                except:
                    LOGGER.error('weird, this should not raise any errors, but it just did')
                
                # unblock the extract reservoirs
                for q in train_data_extract_reservoir + valid_data_extract_reservoir:
                    if q.empty():
                        q.put(False)

                for q in train_data_process_reservoir + valid_data_process_reservoir:
                    if q.full():
                        q.get()
                    elif q.empty():
                        q.put(False)

                LOGGER.info('wait for all data thread')
                for p in thread_pool: p.join()
                LOGGER.info('wait for sub process to exit')
                for p in procs: p.join()
                self.timer.end('clear env')
                LOGGER.info('clear env', self.timer.query_time_by_name('clear env'))
                
                self.timer.end('main training')
            except Exception:
                LOGGER.info('error occured in main process')
                traceback.print_exc()

            LOGGER.info('spawn total {} meta valid tasks. main training time {}'.format(valid_ens_data_load_number, self.timer.query_time_by_name('main training')))
            
            self.timer.begin('load learner')

            self.meta_learners = [None] * self.total_exp

            def load_model(args):
                module, hp, i = args
                self.meta_learners[i] = module.load_model(hp)

            pool = [threading.Thread(target=load_model, args=((self.modules[i], self.hp[i], i), )) for i in range(self.total_exp)]
            for p in pool: p.daemon=True; p.start()
            for p in pool: p.join()

            self.timer.end('load learner')
            LOGGER.info('load learner done, time spent', self.timer.query_time_by_name('load learner'))
            
            if not isinstance(self.ensemble, int):
                # auto-ensemble by exhaustive search
                procs = []
                reses = [None] * len(self.meta_learners)
                
                self.timer.begin('validation')
                
                recv_list, sent_list = [], []
                for i in range(self.total_exp):
                    r, s = Pipe(True)
                    r.send(True)
                    recv_list.append(r)
                    sent_list.append(s)

                processes = [mp.Process(target=predict, args=(
                    self.meta_learners[i],
                    recv_list[i],
                    self.eval_tasks,
                    self.hp[i]['device'],
                    {
                        'time_fired': time.time(),
                        'taskid': i
                    }
                )) for i in range(self.total_exp)]

                for p in processes: p.daemon = True; p.start()
                
                # start sub thread to pass data
                def pass_meta_data(i):
                    for _ in range(self.eval_tasks):
                        if sent_list[i].recv():
                            sent_list[i].send(meta_valid_reservoir[i].get())
                
                threads = [threading.Thread(target=pass_meta_data, args=(i, )) for i in range(self.total_exp)]
                for t in threads: t.daemon = True; t.start()
                
                for _ in range(self.eval_tasks - valid_ens_data_load_number):
                    data_valid = next(meta_valid_generator)
                    data_valid = process_task_batch(data_valid, device=torch.device('cpu'), with_origin_label=False)
                    label_meta_valid.extend(data_valid[1][1].tolist())
                    for dr in meta_valid_reservoir:
                        dr.put([data_valid[0][0], data_valid[0][1], data_valid[1][0]])
                    # LOGGER.info('put data!')
                LOGGER.info('all data done!')
                LOGGER.info(len(label_meta_valid))
                
                # now we can receive data
                for t in threads: t.join()
                reses = [sent_list[i].recv()['res'] for i in range(self.total_exp)]
                for send in sent_list:
                    send.send(True)
                # for p in processes: p.join()
                # every res in reses is a np.array of shape (eval_task * WAY * QUERY) * WAY
                ENS_VALID_TASK = 100
                ENS_VALID_ELEMENT = ENS_VALID_TASK * 5 * 19
                reses_test_list = [deepcopy(res[-ENS_VALID_ELEMENT:]) for res in reses]

                self.timer.end('validation')
                LOGGER.info('valid data predict done', self.timer.query_time_by_name('validation'))
                
                weight = [1.] * len(self.meta_learners)
                labels = np.array(label_meta_valid, dtype=int)  # 19000
                acc_o = ((np.array(weight)[:, None, None] / sum(weight) * np.array(reses)).sum(axis=0).argmax(axis=1) == labels).astype(float).mean()
                reses = np.array(reses, dtype=float).transpose((1, 0, 2))
                reses_test = reses[-ENS_VALID_ELEMENT:].reshape(ENS_VALID_ELEMENT, -1)
                reses = reses[:-ENS_VALID_ELEMENT]
                reses = reses.reshape(len(reses), -1)
                labels_test = labels[-ENS_VALID_ELEMENT:]
                labels = labels[:-ENS_VALID_ELEMENT]
                LOGGER.info('voting result', acc_o)

                self.timer.begin('ensemble')

                # mp.set_start_method('fork', True)
                pool = mp.Pool(3)
                result = pool.map(ensemble_on_data, [
                    # (GBMEnsembler(), reses, labels, 'gbm'),  # currently gbm has problems when saving/loading
                    (GLMEnsembler(), reses, labels, 'glm'),
                    (NBEnsembler(), reses, labels, 'nb'),
                    (RFEnsembler(), reses, labels, 'rf')  # tends to over-fit on simple datasets
                ])

                # test the ensemble model
                def acc(logit, label):
                    return (logit.argmax(axis=1) == label).mean()
                res_test = [x[0]._predict(reses_test) for x in result]
                acc_test = [acc(r, labels_test) for r in res_test]
                acc_single_test = [acc(np.array(r), labels_test) for r in reses_test_list]
                LOGGER.info('ensemble test', 'glm', 'nb', 'rf', acc_test)
                LOGGER.info('single test', acc_single_test)

                if max(acc_test) > max(acc_single_test):
                    LOGGER.info("will use ensemble model")
                    #idx_acc_max = np.argmax([x[1] for x in result])
                    idx_acc_max = np.argmax(acc_test)
                    self.timer.end('ensemble')
                    print('best ensembler', ['glm', 'nb', 'rf'][idx_acc_max], 'acc', acc_test[idx_acc_max])
                    print('ensemble done, time cost', self.timer.query_time_by_name('ensemble'))

                    return MyLearner(self.meta_learners, result[idx_acc_max][0], timers=self.timer)
                else:
                    LOGGER.info("will use single model")
                    idx_acc_max = np.argmax(acc_single_test)
                    self.timer.end('ensemble')
                    print('best single model id', idx_acc_max)
                    print('ensemble done, time cost', self.timer.query_time_by_name('ensemble'))

                    # return only the best meta learner
                    return MyLearner([self.meta_learners[idx_acc_max]], 0, self.timer)
            return MyLearner([self.meta_learners[self.ensemble]], 0, timers=self.timer)
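# Minimal, self-contained demo (not the competition code) of the pull-style
# Pipe handshake used by put_data_*_passive / pass_meta_data in the last two
# examples: the consumer asks for data by sending True, the feeder replies
# with one item, and (in this simplified sketch) a None sentinel marks the end
# of the stream instead of the False message used in the originals.
from multiprocessing import Pipe
from threading import Thread


def feeder(conn, items):
    for item in items:
        if not conn.recv():      # consumer signalled it is done
            return
        conn.send(item)
    conn.recv()                  # absorb the final request
    conn.send(None)              # sentinel: nothing left to send


if __name__ == '__main__':
    consumer_end, feeder_end = Pipe(True)
    t = Thread(target=feeder, args=(feeder_end, ['task-0', 'task-1']))
    t.start()

    while True:
        consumer_end.send(True)  # request the next item
        item = consumer_end.recv()
        if item is None:
            break
        print('got', item)
    t.join()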