Example #1
def train_reinforce(args):
    '''
    Parse arguments and construct the objects needed to train a REINFORCE model with no baseline.
    '''
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    token_tables = op.build_token_tables()

    # initialize tensorboard for logging output
    from os import path
    train_logger = None
    if args.log_dir is not None:
        train_logger = tb.SummaryWriter(path.join(args.log_dir, 'train'),
                                        flush_secs=1)

    # Load Models
    policy = RobustFill(string_size=len(op.CHARACTER),
                        string_embedding_size=args.embedding_size,
                        decoder_inp_size=args.embedding_size,
                        hidden_size=args.hidden_size,
                        program_size=len(token_tables.op_token_table),
                        device=device)
    value = ValueNetwork(args.embedding_size, args.hidden_size).to(device)
    if args.continue_training_policy:
        policy.load_state_dict(
            torch.load(path.join(path.dirname(path.abspath(__file__)),
                                 args.checkpoint_filename),
                       map_location=device))
    elif args.continue_training:
        policy.load_state_dict(
            torch.load(path.join(path.dirname(path.abspath(__file__)),
                                 args.checkpoint_filename),
                       map_location=device))
        value.load_state_dict(
            torch.load(path.join(path.dirname(path.abspath(__file__)),
                                 args.val_checkpoint_filename),
                       map_location=device))
    policy = policy.to(device)
    value = value.to(device)
    # Initialize Optimizer
    if args.optimizer == 'sgd':
        pol_opt = optim.SGD(policy.parameters(), lr=args.lr)
        val_opt = optim.SGD(value.parameters(), lr=args.lr)
    else:
        pol_opt = optim.Adam(policy.parameters(), lr=args.lr)
        val_opt = optim.Adam(value.parameters(), lr=args.lr)

    # Load Environment
    env = RobustFillEnv()
    train_reinforce_(
        args,
        policy=policy,
        value=value,
        pol_opt=pol_opt,
        value_opt=val_opt,
        env=env,
        train_logger=train_logger,
        checkpoint_filename=args.checkpoint_filename,
        checkpoint_step_size=args.checkpoint_step_size,
        checkpoint_print_tensors=args.print_tensors,
    )
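
# --- Illustrative usage (not from the original source) ----------------------
# A minimal sketch of how `train_reinforce` might be driven. The argument
# names mirror the attributes read above (lr, optimizer, embedding_size, ...);
# the parser, flag types, and defaults here are assumptions for illustration,
# not the project's actual argument parser.
def build_example_args():
    import argparse
    parser = argparse.ArgumentParser(description='REINFORCE training (sketch)')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--optimizer', choices=['sgd', 'adam'], default='adam')
    parser.add_argument('--embedding_size', type=int, default=128)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--log_dir', default=None)
    parser.add_argument('--checkpoint_filename', default='checkpoint.pth')
    parser.add_argument('--val_checkpoint_filename', default='val_checkpoint.pth')
    parser.add_argument('--checkpoint_step_size', type=int, default=100)
    parser.add_argument('--print_tensors', action='store_true')
    parser.add_argument('--continue_training', action='store_true')
    parser.add_argument('--continue_training_policy', action='store_true')
    return parser.parse_args()

# train_reinforce(build_example_args())
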
class DecoupledWorker(mp.Process):
    def __init__(self, id, env, gamma, global_value_network,
                 global_policy_network, global_value_optimizer,
                 global_policy_optimizer, global_episode, GLOBAL_MAX_EPISODE):
        super(DecoupledWorker, self).__init__()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.name = "w%i" % id

        self.env = env
        self.env.seed(id)
        self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n

        self.gamma = gamma
        self.local_value_network = ValueNetwork(self.obs_dim, 1)
        self.local_policy_network = PolicyNetwork(self.obs_dim,
                                                  self.action_dim)

        self.global_value_network = global_value_network
        self.global_policy_network = global_policy_network
        self.global_episode = global_episode
        self.global_value_optimizer = global_value_optimizer
        self.global_policy_optimizer = global_policy_optimizer
        self.GLOBAL_MAX_EPISODE = GLOBAL_MAX_EPISODE

        # sync local networks with global networks
        self.sync_with_global()

    def get_action(self, state):
        state = torch.FloatTensor(state).to(self.device)
        logits = self.local_policy_network.forward(state)
        dist = F.softmax(logits, dim=0)
        probs = Categorical(dist)

        return probs.sample().cpu().detach().item()

    def compute_loss(self, trajectory):
        states = torch.FloatTensor([sars[0]
                                    for sars in trajectory]).to(self.device)
        actions = torch.LongTensor([sars[1] for sars in trajectory
                                    ]).view(-1, 1).to(self.device)
        rewards = torch.FloatTensor([sars[2]
                                     for sars in trajectory]).to(self.device)
        next_states = torch.FloatTensor([sars[3] for sars in trajectory
                                         ]).to(self.device)
        dones = torch.FloatTensor([sars[4] for sars in trajectory
                                   ]).view(-1, 1).to(self.device)

        # compute value target: the Monte-Carlo return
        # G_t = sum_i gamma^i * r_{t+i}, accumulated in reverse over the episode
        returns = []
        running_return = 0.0
        for r in reversed(rewards.tolist()):
            running_return = r + self.gamma * running_return
            returns.insert(0, running_return)
        value_targets = torch.FloatTensor(returns).view(-1, 1).to(self.device)

        # compute value loss
        values = self.local_value_network.forward(states)
        value_loss = F.mse_loss(values, value_targets.detach())

        # compute policy loss with entropy bonus
        logits = self.local_policy_network.forward(states)
        dists = F.softmax(logits, dim=1)
        probs = Categorical(dists)

        # compute entropy bonus: H(p) = -sum_i p_i * log(p_i) for each step
        entropy = []
        for dist in dists:
            entropy.append(-torch.sum(dist * torch.log(dist)))
        entropy = torch.stack(entropy).sum()

        advantage = value_targets - values
        policy_loss = -probs.log_prob(actions.view(actions.size(0))).view(
            -1, 1) * advantage.detach()
        policy_loss = policy_loss.mean() - 0.001 * entropy

        return value_loss, policy_loss

    def update_global(self, trajectory):
        value_loss, policy_loss = self.compute_loss(trajectory)

        self.global_value_optimizer.zero_grad()
        value_loss.backward()
        # propagate local gradients to global parameters
        for local_params, global_params in zip(
                self.local_value_network.parameters(),
                self.global_value_network.parameters()):
            global_params._grad = local_params._grad
        self.global_value_optimizer.step()

        self.global_policy_optimizer.zero_grad()
        policy_loss.backward()
        # propagate local gradients to global parameters
        for local_params, global_params in zip(
                self.local_policy_network.parameters(),
                self.global_policy_network.parameters()):
            global_params._grad = local_params._grad
        self.global_policy_optimizer.step()

    def sync_with_global(self):
        self.local_value_network.load_state_dict(
            self.global_value_network.state_dict())
        self.local_policy_network.load_state_dict(
            self.global_policy_network.state_dict())

    def run(self):
        state = self.env.reset()
        trajectory = []  # [[s, a, r, s', done], [], ...]
        episode_reward = 0

        while self.global_episode.value < self.GLOBAL_MAX_EPISODE:
            action = self.get_action(state)
            next_state, reward, done, _ = self.env.step(action)
            trajectory.append([state, action, reward, next_state, done])
            episode_reward += reward

            if done:
                with self.global_episode.get_lock():
                    self.global_episode.value += 1
                print(f"{self.name} | episode: {self.global_episode.value} "
                      f"{episode_reward}")

                self.update_global(trajectory)
                self.sync_with_global()

                trajectory = []
                episode_reward = 0
                state = self.env.reset()
            else:
                state = next_state
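
# --- Illustrative usage (not from the original source) ----------------------
# A minimal sketch of how DecoupledWorker processes might be launched, assuming
# a gym-style environment factory (`make_env` is a placeholder) and the
# ValueNetwork/PolicyNetwork classes used above. In practice the global
# optimizers are often shared-state variants (e.g. a SharedAdam); plain
# torch.optim.Adam is used here only to keep the sketch short.
def run_decoupled_a3c(make_env, num_workers=4, gamma=0.99, lr=1e-3,
                      global_max_episode=1000):
    probe_env = make_env()
    obs_dim = probe_env.observation_space.shape[0]
    action_dim = probe_env.action_space.n

    global_value_network = ValueNetwork(obs_dim, 1)
    global_policy_network = PolicyNetwork(obs_dim, action_dim)
    # share_memory() places the parameters in shared memory so every worker
    # process reads and updates the same tensors
    global_value_network.share_memory()
    global_policy_network.share_memory()

    global_value_optimizer = optim.Adam(global_value_network.parameters(), lr=lr)
    global_policy_optimizer = optim.Adam(global_policy_network.parameters(), lr=lr)
    global_episode = mp.Value('i', 0)

    workers = [DecoupledWorker(i, make_env(), gamma,
                               global_value_network, global_policy_network,
                               global_value_optimizer, global_policy_optimizer,
                               global_episode, global_max_episode)
               for i in range(num_workers)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()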