Example #1
    def learn(self):
        exp_replay = self._fill_replay_buffer()
        state = self.env.reset()
        opt = torch.optim.Adam(self.agent.parameters(), lr=1e-5)

        for step in range(self.total_steps):
            if not utils.is_enough_ram():
                print('less than 100 MB of RAM available, freezing')
                print('make sure everything is OK, then send a KeyboardInterrupt to continue')
                try:
                    while True:
                        pass
                except KeyboardInterrupt:
                    pass

            self.agent.epsilon = utils.linear_decay(self.init_epsilon, self.final_epsilon, step, self.decay_steps)

            # play
            _, state = play_and_record(state, self.agent, self.env, exp_replay, self.timesteps_per_epoch)

            # train
            obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(self.batch_size)

            # loss
            loss = compute_td_loss(obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch,
                                   self.agent, self.target_network, self.device, gamma=0.99, check_shapes=False)

            loss.backward()
            grad_norm = nn.utils.clip_grad_norm_(self.agent.parameters(), self.max_grad_norm)
            opt.step()
            opt.zero_grad()

            if step % self.refresh_target_network_freq == 0:
                self.target_network.load_state_dict(self.agent.state_dict())

            if step % self.loss_freq == 0:
                td_loss = loss.item()
                self.writer.add_scalar('Train/Loss', td_loss, step)
                self.writer.add_scalar('Train/GradNorm', float(grad_norm), step)

            if step % self.eval_freq == 0:
                mean_reward, solved_games = evaluate(self.env_creator(seed=step), self.agent, n_games=15, greedy=True)
                self.writer.add_scalar('Train/MeanReward', mean_reward, step)
                self.writer.add_scalar('Train/SolvedGames', solved_games, step)

                initial_state_q_values = self.agent.get_qvalues([self.env_creator(seed=step).reset()])
                self.writer.add_scalar('Train/QValues', np.max(initial_state_q_values), step)
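
For reference, utils.linear_decay is used by most of these examples to anneal epsilon but is never shown. A minimal sketch consistent with how it is called above (the body is an assumption, not the original implementation):

def linear_decay(init_val, final_val, step, total_steps):
    """Linearly interpolate from init_val to final_val over total_steps, then hold final_val."""
    if step >= total_steps:
        return final_val
    return init_val + (final_val - init_val) * step / total_steps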
Example #2
    def train(self, trainset, valset, model):
        loader = self._create_loader(trainset, training=True)
        model.set_roberta_grad(False)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.base_lr)

        base_steps = len(trainset) // self.cfg.train_batch_size * self.cfg.base_epochs
        fine_steps = len(trainset) // self.cfg.train_batch_size * self.cfg.fine_epochs
        warmup_steps = int(base_steps * self.cfg.warmup)
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lambda step: linear_decay_with_warmup(step, base_steps, warmup_steps))

        for epoch in range(self.cfg.base_epochs+self.cfg.fine_epochs):
            model.train()
            model.zero_grad()
            if epoch == self.cfg.base_epochs:
                print('Start training Roberta')
                model.set_roberta_grad(True)
                optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fine_lr)
                scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                    lambda step: linear_decay(step, fine_steps))
            
            train_loss = 0.
            train_acc = 0
            for i, batch in enumerate(tqdm(loader)):
                self.to_cuda(batch)
                labels = batch['labels']
                
                logits = model(batch)
                loss = self.criterion(logits, labels) / self.cfg.grad_accum
                loss.backward()
                if (i+1) % self.cfg.grad_accum == 0:
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                
                train_loss += loss.item()
                train_acc += (torch.argmax(logits, dim=1)==labels).float().sum().item()
                    
            train_loss /= len(loader)
            train_acc /= len(trainset)

            val_loss, val_acc = self.validate(valset, model)
            
            print(f'[Epoch {epoch+1}] train_loss: {train_loss:.4f} train_acc: {train_acc:.2%}',
                  f'val_loss: {val_loss:.4f}, val_acc: {val_acc:.2%}')
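
Example #2 drives its LambdaLR schedulers with linear_decay_with_warmup and linear_decay, neither of which is shown. A plausible sketch of both multipliers, assuming the usual LambdaLR convention that the returned factor scales the base learning rate (the bodies are assumptions):

def linear_decay_with_warmup(step, total_steps, warmup_steps):
    # Ramp the factor from 0 to 1 over warmup_steps, then decay linearly back to 0.
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps
    return max(total_steps - step, 0) / max(total_steps - warmup_steps, 1)

def linear_decay(step, total_steps):
    # Warmup-free variant used once Roberta is unfrozen.
    return max(total_steps - step, 0) / max(total_steps, 1)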
Example #3
    def run(self):
        self.net.train()
        loss_history = deque(maxlen=50)
        reward_history = deque(maxlen=50)
        step_history = deque(maxlen=50)
        episode = episode_step = 0
        episode_included = False
        episode_reward = 0.0

        s = self.env.reset()
        for step in range(self.max_timesteps):
            self.net.epsilon = utils.linear_decay(self.initial_epsilon,
                                                  self.final_epsilon, step,
                                                  self.max_epsilon_decay_steps)

            qvalues = self.net.get_qvalues([s])

            if isinstance(self.net, LVAE):
                qvalues, _, _ = qvalues

            action = self.net.sample_actions(qvalues)[0]

            next_s, r, done, _ = self.env.step(action)
            episode_step += 1
            episode_reward += r

            self.exp_replay.add(s, action, r, next_s, done)

            if done:
                s = self.env.reset()
                episode += 1
                episode_included = True
                reward_history.append(episode_reward)
                step_history.append(episode_step)
                episode_step = episode_reward = 0
            else:
                s = next_s

            # don't train network until replay buffer hits minimum required size
            if len(self.exp_replay) > self.exp_replay_min_size:
                loss_history.append(self.train())
                if step % self.hard_update_frequency == 0:
                    self.target_net.load_state_dict(self.net.state_dict())

                # log info to cmd and tensorboard once training has started
                if episode % self.log_interval == 0 and episode_included:
                    info = {}
                    info['steps'] = step
                    info['episode'] = episode
                    info['episode_loss'] = np.mean(list(loss_history))
                    info['episode_reward'] = np.mean(list(reward_history))
                    info['episode_step'] = np.mean(list(step_history))
                    info['epsilon'] = self.net.epsilon
                    info['exp_replay_size'] = len(self.exp_replay)
                    info['lr'] = self.opt.param_groups[0]['lr']
                    utils.log(info, self.summary_writer, step=step)

                # save and evaluate model, log evaluated reward to cmd
                if episode % self.save_interval == 0 and episode_included:
                    info = {}
                    evaluated_reward = self.evaluate()
                    ckpt = self.save_model(step=step)

                    if evaluated_reward > self.best_evaluated_reward:
                        self.best_evaluated_reward = evaluated_reward
                        shutil.copyfile(
                            ckpt, os.path.join(self.model_dir,
                                               'model_best.pth'))

                    # log the network parameters to tensorboard
                    for name, param in self.net.named_parameters():
                        self.summary_writer.add_histogram(
                            name,
                            param.clone().cpu().data.numpy(), step)

                    info['eval/reward'] = evaluated_reward
                    utils.log(info, self.summary_writer, step=step)

            episode_included = False

            if self.max_episode is not None and episode >= self.max_episode:
                break

        print('Training completed...')
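
Example #3 delegates the actual optimization to self.train(), which is not part of the snippet. A minimal sketch of one update step consistent with how it is used above; the attribute names batch_size, max_grad_norm and opt, the compute_td_loss helper, and the import of torch.nn as nn are all assumptions borrowed from the other examples:

    def train(self):
        # Sample a batch from the replay buffer, take one gradient step,
        # and return the scalar loss so the caller can log it.
        obs, actions, rewards, next_obs, dones = self.exp_replay.sample(self.batch_size)
        loss = compute_td_loss(obs, actions, rewards, next_obs, dones,
                               self.net, self.target_net, self.device, gamma=0.99)
        self.opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
        self.opt.step()
        return loss.item()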
Example #4
def train(env, agent, target_network, exp_replay, loss_func, device, lr=1e-4, 
          total_steps=3 * 10**6, verbose_steps=3 * 10 ** 5, batch_size=32,
          decay_steps=1 * 10**6, init_epsilon=1.0, final_epsilon=0.1, timesteps_per_epoch=1,
          max_grad_norm=50, loss_freq=50, refresh_target_network_freq=5000, eval_freq=5000):
    stop_plotting = False
    
    mean_rw_history = []
    td_loss_history = []
    grad_norm_history = []
    initial_state_v_history = []

    opt = torch.optim.Adam(agent.parameters(), lr=lr)
    
    state = env.reset()
    for step in trange(total_steps + 1):
        agent.epsilon = utils.linear_decay(init_epsilon, final_epsilon, step, decay_steps)

        # play
        _, state = utils.play_and_record(state, agent, env, exp_replay, timesteps_per_epoch)

        # train
        states, actions, rewards, next_states, is_done = exp_replay.sample(batch_size)
        loss = loss_func(states, actions, rewards, next_states, is_done,
                         agent, target_network, device)

        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
        opt.step()
        opt.zero_grad()

        if step % loss_freq == 0:
            td_loss_history.append(loss.item())
            grad_norm_history.append(float(grad_norm))

        if step % refresh_target_network_freq == 0:
            # Load agent weights into target_network
            target_network.load_state_dict(agent.state_dict())

        if step == verbose_steps:
            print("Stopping plotting to reduce training time.")
            stop_plotting = True
        if step % eval_freq == 0:
            # eval the agent
            mean_rw_history.append(utils.evaluate(
                make_env(seed=step), agent, n_games=3, greedy=True, t_max=1000)
            )
            initial_state_q_values = agent.get_qvalues(
                [make_env(seed=step).reset()]
            )
            initial_state_v_history.append(np.max(initial_state_q_values))
            if not stop_plotting:
                clear_output(True)
                print("buffer size = %i, epsilon = %.5f" %
                    (len(exp_replay), agent.epsilon))

                plt.figure(figsize=[16, 9])
                plt.subplot(2, 2, 1)
                plt.title("Mean reward per episode")
                plt.plot(mean_rw_history)
                plt.grid()

                assert not np.isnan(td_loss_history[-1])
                plt.subplot(2, 2, 2)
                plt.title("TD loss history (smoothened)")
                plt.plot(utils.smoothen(td_loss_history))
                plt.grid()

                plt.subplot(2, 2, 3)
                plt.title("Initial state V")
                plt.plot(initial_state_v_history)
                plt.grid()

                plt.subplot(2, 2, 4)
                plt.title("Grad norm history (smoothened)")
                plt.plot(utils.smoothen(grad_norm_history))
                plt.grid()

                plt.show()

    return {'reward_history': mean_rw_history, 
            'td_loss_history': td_loss_history, 
            'grad_norm_history': grad_norm_history,
            'initial_state_v_history': initial_state_v_history}
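
Examples #1, #4 and #5 all compute their objective with compute_td_loss, whose body is not included here. A simplified sketch of a standard one-step DQN TD loss matching that call signature (Smooth L1 as the error measure and the tensor conversions are assumptions; the check_shapes flag from Example #1 is omitted):

import numpy as np
import torch
import torch.nn.functional as F

def compute_td_loss(states, actions, rewards, next_states, is_done,
                    agent, target_network, device, gamma=0.99):
    states = torch.tensor(np.asarray(states), device=device, dtype=torch.float32)
    actions = torch.tensor(actions, device=device, dtype=torch.int64)
    rewards = torch.tensor(rewards, device=device, dtype=torch.float32)
    next_states = torch.tensor(np.asarray(next_states), device=device, dtype=torch.float32)
    is_done = torch.tensor(is_done, device=device, dtype=torch.float32)

    # Q(s, a) for the actions actually taken
    q_selected = agent(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Bootstrapped target r + gamma * max_a' Q_target(s', a'), zeroed on terminal states
    with torch.no_grad():
        next_q = target_network(next_states).max(dim=1).values
        target = rewards + gamma * next_q * (1.0 - is_done)

    return F.smooth_l1_loss(q_selected, target)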
Example #5
opt = torch.optim.Adam(agent.parameters(), lr=learning_rate)

mean_rw_history = []
td_loss_history = []
grad_norm_history = []
initial_state_v_history = []

print("Starts training on {}".format(next(agent.parameters()).device))

# populate the buffer with 128 samples
init_size = 128
play_and_record(state, agent, env, exp_replay, init_size)

for step in range(total_steps):
    agent.epsilon = utils.linear_decay(init_epsilon, final_epsilon, step,
                                       decay_steps)

    # play for T timesteps and add the experience to the replay buffer
    _, state = play_and_record(state, agent, env, exp_replay, T)

    b_idx, obses_t, actions, rewards, obses_tp1, dones, weights = exp_replay.sample(
        batch_size)

    # td loss for each sample
    td_loss = compute_td_loss(states=obses_t,
                              actions=actions,
                              rewards=rewards,
                              next_states=obses_tp1,
                              is_done=dones,
                              agent=agent,
                              target_network=target_network,