def run(args):
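    """Train an option-critic agent on args.env, logging per-step losses and per-episode returns."""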
    env, is_atari = make_env(args.env)
    option_critic = OptionCriticConv if is_atari else OptionCriticFeatures
    device = torch.device(
        'cuda' if torch.cuda.is_available() and args.cuda else 'cpu')

    option_critic = option_critic(in_features=env.observation_space.shape[0],
                                  num_actions=env.action_space.n,
                                  num_options=args.num_options,
                                  temperature=args.temp,
                                  eps_start=args.epsilon_start,
                                  eps_min=args.epsilon_min,
                                  eps_decay=args.epsilon_decay,
                                  eps_test=args.optimal_eps,
                                  device=device)
    # Create a prime network for more stable Q values
    option_critic_prime = deepcopy(option_critic)

    optim = torch.optim.RMSprop(option_critic.parameters(),
                                lr=args.learning_rate)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    env.seed(args.seed)

    buffer = ReplayBuffer(capacity=args.max_history, seed=args.seed)
    logger = Logger(
        logdir=args.logdir,
        run_name=f"{option_critic.__class__.__name__}-{args.env}-{args.exp}-{time.ctime()}"
    )

    steps = 0
    if args.switch_goal:
        print(f"Current goal {env.goal}")
    while steps < args.max_steps_total:

        rewards = 0
        option_lengths = {opt: [] for opt in range(args.num_options)}

        obs = env.reset()
        state = option_critic.get_state(to_tensor(obs))
        greedy_option = option_critic.greedy_option(state)
        current_option = 0

        # Goal switching experiment: run for 1k episodes in fourrooms, then switch goals and run for
        # another 2k episodes. In option-critic, if the options carry any meaning, only the policy
        # over options should need to be fine-tuned (this is what we would hope).
        if args.switch_goal and logger.n_eps == 1000:
            torch.save(
                {
                    'model_params': option_critic.state_dict(),
                    'goal_state': env.goal
                }, f'models/option_critic_{args.seed}_1k')
            env.switch_goal()
            print(f"New goal {env.goal}")

        if args.switch_goal and logger.n_eps > 2000:
            torch.save(
                {
                    'model_params': option_critic.state_dict(),
                    'goal_state': env.goal
                }, f'models/option_critic_{args.seed}_2k')
            break

        done = False
        ep_steps = 0
        option_termination = True
        curr_op_len = 0
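        # Roll out the episode, re-selecting an option whenever the current one terminates.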
        while not done and ep_steps < args.max_steps_ep:
            epsilon = option_critic.epsilon

            if option_termination:
                option_lengths[current_option].append(curr_op_len)
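                # Epsilon-greedy over options: pick a random option with probability epsilon,
                # otherwise the option that is greedy w.r.t. the option-value estimates.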
                current_option = np.random.choice(
                    args.num_options
                ) if np.random.rand() < epsilon else greedy_option
                curr_op_len = 0

            action, logp, entropy = option_critic.get_action(
                state, current_option)

            next_obs, reward, done, _ = env.step(action)
            buffer.push(obs, current_option, reward, next_obs, done)

            old_state = state
            state = option_critic.get_state(to_tensor(next_obs))

            option_termination, greedy_option = option_critic.predict_option_termination(
                state, current_option)
            rewards += reward

            actor_loss, critic_loss = None, None
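            # Once the buffer holds more than a batch of transitions, update the actor from the
            # latest transition every step and the critic from a sampled batch every
            # args.update_frequency steps.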
            if len(buffer) > args.batch_size:
                actor_loss = actor_loss_fn(obs, current_option, logp, entropy,
                                           reward, done, next_obs, option_critic,
                                           option_critic_prime, args)
                loss = actor_loss

                if steps % args.update_frequency == 0:
                    data_batch = buffer.sample(args.batch_size)
                    critic_loss = critic_loss_fn(option_critic,
                                                 option_critic_prime,
                                                 data_batch, args)
                    loss += critic_loss

                optim.zero_grad()
                loss.backward()
                optim.step()

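                # Periodically sync the prime (target) network with the online network.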
                if steps % args.freeze_interval == 0:
                    option_critic_prime.load_state_dict(
                        option_critic.state_dict())

            # update global steps etc
            steps += 1
            ep_steps += 1
            curr_op_len += 1
            obs = next_obs

            logger.log_data(steps, actor_loss, critic_loss, entropy.item(),
                            epsilon)

        logger.log_episode(steps, rewards, option_lengths, ep_steps, epsilon)

# Example #2

def run(args):
    env = make_env(args.env)
    option_critic = OptionCriticFeatures
    device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')

    option_critic = option_critic(
        in_features=env.observation_space.shape[0],
        num_actions=env.action_space.n,
        num_options=args.num_options,
        temperature=args.temp,
        eps_start=args.epsilon_start,
        eps_min=args.epsilon_min,
        eps_decay=args.epsilon_decay,
        eps_test=args.optimal_eps,
        device=device,
        input_size=args.input_size,
        num_classes=args.num_classes,
        num_experts=args.num_experts,
        hidden_size=args.hidden_size,
        batch_size=args.batch_size,
        top_k=args.top_k
    )
    # Create a prime network for more stable Q values
    option_critic_prime = deepcopy(option_critic)

    optim = torch.optim.RMSprop(option_critic.parameters(), lr=args.learning_rate)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    env.seed(args.seed)

    buffer = ReplayBuffer(capacity=args.max_history, seed=args.seed)
    logger = Logger(logdir=args.logdir, run_name=f"{OptionCriticFeatures.__name__}-{args.env}-{args.exp}-{time.ctime()}")

    steps = 0
    print(f"Current goal {env.goal}")
    while steps < args.max_steps_total:

        rewards = 0
        option_lengths = {opt: [] for opt in range(args.num_options)}

        obs   = env.reset()
        state = option_critic.get_state(to_tensor(obs))
        greedy_option  = option_critic.greedy_option(state)
        current_option = 0

        done = False
        ep_steps = 0
        option_termination = True
        curr_op_len = 0
        while not done and ep_steps < args.max_steps_ep:
            epsilon = option_critic.epsilon

            if option_termination:
                option_lengths[current_option].append(curr_op_len)
                current_option = np.random.choice(args.num_options) if np.random.rand() < epsilon else greedy_option
                curr_op_len = 0
    
            action, logp, entropy = option_critic.get_action(state, current_option)

            next_obs, reward, done, _ = env.step(action)
            buffer.push(obs, current_option, reward, next_obs, done)

            old_state = state
            state = option_critic.get_state(to_tensor(next_obs))

            option_termination, greedy_option = option_critic.predict_option_termination(state, current_option)
            rewards += reward

            actor_loss, critic_loss = None, None
            if len(buffer) > args.batch_size:
                actor_loss = actor_loss_fn(obs, current_option, logp, entropy, reward,
                                           done, next_obs, option_critic, option_critic_prime, args)
                loss = actor_loss

                if steps % args.update_frequency == 0:
                    data_batch = buffer.sample(args.batch_size)
                    critic_loss = critic_loss_fn(option_critic, option_critic_prime, data_batch, args)
                    loss += critic_loss

                optim.zero_grad()
                loss.backward()
                optim.step()

                if steps % args.freeze_interval == 0:
                    option_critic_prime.load_state_dict(option_critic.state_dict())

            # update global steps etc
            steps += 1
            ep_steps += 1
            curr_op_len += 1
            obs = next_obs

            logger.log_data(steps, actor_loss, critic_loss, entropy.item(), epsilon)

        logger.log_episode(steps, rewards, option_lengths, ep_steps, epsilon)
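
# Example #3
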
class Runner(object):
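    """Bundles environment setup, option-critic training, periodic evaluation, and checkpointing."""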
    def __init__(self, args):
        # Set the random seed during training and deterministic cudnn

        self.args = args
        torch.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        random.seed(self.args.seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.autograd.set_detect_anomaly(True)
        device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')

        # Check if we are running evaluation alone
        self.evaling_checkpoint = args.eval_checkpoint != ""
        # Create Logger
        self.logger = Logger(logdir=args.logdir,
                             run_name=f"{OptionCriticFeatures.__name__}-{args.env}-{args.exp}-{time.ctime()}")

        # Load Env
        self.env = gym.make('beaverkitchen-v0')
        self.env.set_args(self.env.Args(
            human_play=False,
            level=args.level,
            headless=args.render_train and not args.render_eval,
            render_slowly=False))

        self.env.seed(self.args.seed)

        self.num_t_steps = 0
        self.num_e_steps = 0
        self.epsilon = self.args.epsilon_start
        self.state = self.env.reset()
        self.num_states = len(self.state)
        self.num_actions = len(self.env.sample_action())

        # Create Model
        self.option_critic = OptionCriticFeatures(
            env_name=self.args.env,
            in_features=self.num_states,
            num_actions=self.num_actions,
            num_options=self.args.num_options,
            temperature=self.args.temp,
            eps_start=self.args.epsilon_start,
            eps_min=self.args.epsilon_min,
            eps_decay=self.args.epsilon_decay,
            eps_test=self.args.optimal_eps,
            device=device
        )

        # Create a prime network for more stable Q values
        self.option_critic_prime = deepcopy(self.option_critic)

        self.optim = torch.optim.RMSprop(self.option_critic.parameters(), lr=self.args.learning_rate)

        self.replay_buffer = ReplayBuffer(capacity=self.args.max_history, seed=self.args.seed)

    def run(self):
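        """Alternate training episodes with periodic evaluation rollouts and model checkpoints."""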
        self.n_episodes = self.args.n_episodes if not self.args.eval_checkpoint else self.args.n_eval_episodes

        for ep in range(self.n_episodes):
            if not self.args.eval_checkpoint:
                train_return = self.run_episode(ep, train=True)

                timestamp = str(datetime.now())
                print("[{}] Episode: {}, Train Return: {}".format(timestamp, ep, train_return))
                self.logger.log_return("reward/train_return", train_return, ep)

            if ep % self.args.eval_freq == 0:
                # Output and plot eval episode results
                eval_returns = []
                for _ in range(self.args.n_eval_samples):
                    eval_return = self.run_episode(ep, train=False)
                    eval_returns.append(eval_return)

                eval_return = np.array(eval_returns).mean()

                timestamp = str(datetime.now())
                print("[{}] Episode: {}, Eval Return: {}".format(timestamp, ep, eval_return))
                self.logger.log_return("reward/eval_return", eval_return, ep)

            if ep % self.args.checkpoint_freq == 0:
                if not os.path.exists(self.args.modeldir):
                    os.makedirs(self.args.modeldir)

                model_dir = os.path.join(self.args.modeldir, "episode_{:05d}".format(ep))
                self.option_critic.save(model_dir)

        print("Done running...")

    def run_episode(self, ep, train=False):
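        """Run one episode and return its undiscounted return; learning updates happen only when train=True."""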
        option_lengths = {opt: [] for opt in range(self.args.num_options)}
        obs = self.env.reset()
        state = self.option_critic.get_state(to_tensor(obs))

        greedy_option = self.option_critic.greedy_option(state)
        current_option = 0

        done = False
        ep_steps = 0
        option_termination = True
        curr_op_len = 0
        rewards = 0
        cum_reward = 0

        while not done and ep_steps < self.args.max_steps_ep:
            epsilon = self.epsilon if train else 0.0

            if train:
                self.num_t_steps += 1
            else:
                self.num_e_steps += 1

            if option_termination:
                option_lengths[current_option].append(curr_op_len)
                current_option = np.random.choice(
                    self.args.num_options) if np.random.rand() < epsilon else greedy_option
                curr_op_len = 0

            action_idx, logp, entropy = self.option_critic.get_action(state, current_option)

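            # Convert the discrete action index into a one-hot vector before stepping the environment.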
            action = np.zeros(self.num_actions)
            action[int(action_idx)] = 1.0
            next_obs, reward, done, info = self.env.step(action)
            rewards += reward

            if train:
                self.replay_buffer.push(obs, current_option, reward, next_obs, done)

                old_state = state
                state = self.option_critic.get_state(to_tensor(next_obs))

                option_termination, greedy_option = self.option_critic.predict_option_termination(state, current_option)

                # Render domain
                if (self.args.render_train and train) or (self.args.render_eval and not train):
                    self.env.render()

                actor_loss, critic_loss = None, None
                if len(self.replay_buffer) > self.args.batch_size:
                    actor_loss = actor_loss_fn(obs, current_option, logp, entropy, reward, done,
                                               next_obs, self.option_critic, self.option_critic_prime,
                                               self.args)
                    loss = actor_loss

                    if self.num_t_steps % self.args.update_frequency == 0:
                        data_batch = self.replay_buffer.sample(self.args.batch_size)
                        critic_loss = critic_loss_fn(self.option_critic, self.option_critic_prime, data_batch,
                                                     self.args)
                        loss += critic_loss

                    self.optim.zero_grad()
                    loss.backward()
                    # Clip gradients before the optimizer step.
                    torch.nn.utils.clip_grad_norm_(self.option_critic.parameters(), self.args.clip_value)
                    self.optim.step()

                    if self.num_t_steps % self.args.freeze_interval == 0:
                        self.option_critic_prime.load_state_dict(self.option_critic.state_dict())

                self.logger.log_return("train/cum_reward", cum_reward, self.num_t_steps)
                self.logger.log_data(self.num_t_steps, actor_loss, critic_loss, entropy.item(), self.epsilon)

            # update global steps etc
            ep_steps += 1
            curr_op_len += 1
            obs = next_obs
            cum_reward += (self.args.gamma ** ep_steps) * reward
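            # Multiplicative epsilon decay, floored at epsilon_min.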
            self.epsilon = max(self.args.epsilon_min, self.epsilon * self.args.epsilon_decay)

            self.logger.log_return("error/volume_error_{}".format("train" if train else "eval"),
                                   float(info['volume_error']), self.num_t_steps if train else self.num_e_steps)

        self.logger.log_episode(self.num_t_steps, rewards, option_lengths, ep_steps, self.epsilon)

        return rewards