Example #1
0
def main(config):
    teacher = Teacher(config['teacher_path']).to(config['device'])

    net = Network(**config['model_args']).to(config['device'])
    data_train, data_val = get_dataset(config['source'])(**config['data_args'])

    optim = torch.optim.Adam(net.parameters(), **config['optimizer_args'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optim,
        milestones=[mult * config['max_epoch'] for mult in [0.5, 0.75]],
        gamma=0.5)

    wandb.init(project='task-distillation-07',
               config=config,
               id=config['run_name'],
               resume='auto')
    wandb.save(str(Path(wandb.run.dir) / '*.t7'))

    if wandb.run.resumed:
        resume_project(net, optim, scheduler, config)
    else:
        wandb.run.summary['step'] = 0
        wandb.run.summary['epoch'] = 0
        wandb.run.summary['best_epoch'] = 0

    resume_epoch = max(wandb.run.summary['epoch'],
                       wandb.run.summary['best_epoch'])

    for epoch in tqdm.tqdm(range(resume_epoch + 1, config['max_epoch'] + 1),
                           desc='epoch',
                           position=0):
        wandb.run.summary['epoch'] = epoch

        checkpoint_project(net, optim, scheduler, config)

        loss_train = train_or_eval(teacher, net, data_train, optim, True,
                                   config)
        print(loss_train)

        with torch.no_grad():
            loss_val = train_or_eval(teacher, net, data_val, None, False,
                                     config)
        wandb.log({'train/loss_epoch': loss_train, 'val/loss_epoch': loss_val})

        if loss_val < wandb.run.summary.get('best_val_loss', np.inf):
            wandb.run.summary['best_val_loss'] = loss_val
            wandb.run.summary['best_epoch'] = epoch
            torch.save(net.state_dict(), Path(wandb.run.dir) / 'model_best.t7')

        if epoch % 10 == 0:
            torch.save(net.state_dict(),
                       Path(wandb.run.dir) / ('model_%03d.t7' % epoch))
Example #2
0
class DQNAgent:

    def __init__(
        self,
        env: UnityEnvironment,
        memory_size: int,
        batch_size: int,
        target_update: int,
        epsilon_decay: float = 1 / 2000,
        max_epsilon: float = 1.0,
        min_epsilon: float = 0.1,
        gamma: float = 0.99,
        ):
        self.brain_name = env.brain_names[0]
        self.brain = env.brains[self.brain_name]
        env_info = env.reset(train_mode=True)[self.brain_name]
        self.env = env
        action_size = self.brain.vector_action_space_size
        state = env_info.vector_observations[0]
        state_size = len(state)
        
        self.obs_dim = state_size
        self.action_dim = 1

        self.memory = ReplayBuffer(self.obs_dim, self.action_dim, memory_size, batch_size)


        self.batch_size = batch_size
        self.target_update = target_update
        self.epsilon_decay = epsilon_decay
        self.max_epsilon = max_epsilon
        self.min_epsilon = min_epsilon
        self.gamma = gamma
        self.epsilon = max_epsilon

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        
        self.dqn = Network(self.obs_dim, self.action_dim)
        self.dqn_target = Network(self.obs_dim, self.action_dim)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        self.optimizer = optim.Adam(self.dqn.parameters(), lr=5e-5)

        self.transition = list()

        self.is_test = False

    def select_action(self, state: np.ndarray) -> np.int64:
        """ Select an action given input """
        if self.epsilon > np.random.random():
            selected_action = np.random.random_integers(0, self.action_dim-1)
        else:
            selected_action = self.dqn(
                torch.FloatTensor(state).to(self.device)
            )
            selected_action = np.argmax(selected_action.detach().cpu().numpy())

        
        if not self.is_test:
            self.transition = [state, selected_action]
        
        return selected_action

    def step(self, action: np.int64) -> Tuple[np.ndarray, np.float64, bool]:
        "Take an action and return environment response"
        env_info = self.env.step(action)[self.brain_name]
        next_state = env_info.vector_observations[0]   
        reward = env_info.rewards[0]                   
        done = env_info.local_done[0]
    
        if not self.is_test:
            self.transition += [reward, next_state, done]
            self.memory.store(*self.transition)

        return next_state, reward, done

    def update_model(self) -> torch.Tensor:
        """ Update model by gradient descent"""
        samples = self.memory.sample_batch()
        loss = self._compute_dqn_loss(samples)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self, num_episode: int, max_iteration: int=1000, plotting_interval: int=400):
        """  train the agent """
        self.is_test = False

        env_info = self.env.reset(train_mode=True)[self.brain_name]
        state = env_info.vector_observations[0]

        update_cnt = 0
        epsilons = []
        losses = []
        avg_losses= []
        scores = []
        avg_scores = []

        for episode in range(num_episode):
            env_info = self.env.reset(train_mode=True)[self.brain_name]
            state = env_info.vector_observations[0]
            score = 0
            for iter in range(max_iteration):
                action = self.select_action(state)
                next_state, reward, done = self.step(action)
                state = next_state
                score += reward
                if done:
                    break

                if len(self.memory) > self.batch_size:
                    loss = self.update_model()
                    losses.append(loss)
                    update_cnt += 1

            avg_losses.append(np.mean(losses))
            losses = []
            self.epsilon = max(
                self.min_epsilon, self.epsilon - (
                    self.max_epsilon - self.min_epsilon
                ) * self.epsilon_decay
            )
            epsilons.append(self.epsilon)
            
            if update_cnt % self.target_update == 0:
                self._target_hard_update()
            scores.append(score)
            epsilons.append(self.epsilon)

            if episode >= 100:
                avg_scores.append(np.mean(scores[-100:]))
            self._plot(episode, scores, avg_scores, avg_losses, epsilons)
        torch.save(self.dqn.state_dict(), "model_weight/dqn.pt")



    def test(self):
        """ Test agent """
        self.is_test = True
        env_info = self.env.reset(train_mode=False)[self.brain_name]
        state = env_info.vector_observations[0]
        done = False
        score = 0

        while not done:
            action = self.select_action(state)
            next_state, reward, done = self.step(action)

            state = next_state
            score += reward

        print("score: ", score)
        self.env.close()


    def _compute_dqn_loss(self, samples: Dict[str, np.ndarray], gamma: float=0.99) -> torch.Tensor:
        """ Compute and return DQN loss"""
        gamma = self.gamma
        device = self.device
        state = torch.FloatTensor(samples["obs"]).to(device)
        next_state = torch.FloatTensor(samples["next_obs"]).to(device)
        action = torch.LongTensor(samples["acts"]).reshape(-1, 1).to(device)
        reward = torch.FloatTensor(samples["rews"]).reshape(-1, 1).to(device)
        done = torch.FloatTensor(samples["done"]).reshape(-1, 1).to(device)
        
        curr_q_value = self.dqn(state).gather(1, action)
            
        next_q_value = self.dqn_target(next_state).max(dim=1, keepdim=True)[0].detach()
        mask = 1 - done
        target = (reward + gamma * next_q_value * mask).to(device)
        loss = F.smooth_l1_loss(curr_q_value, target)

        return loss


    def _target_hard_update(self):
        """ update target network """
        self.dqn_target.load_state_dict(self.dqn.state_dict())

    def _plot(
        self,
        episode :int,
        scores: List[float],
        avg_scores: List[float],
        losses: List[float],
        epsilons: List[float]
    ):
        """ Plot the training process"""
        plt.figure(figsize=(20, 5))
        plt.subplot(141)
        if len(avg_scores) > 0:
            plt.title("Average reward per 100 episodes. Score: %s" % (avg_scores[-1]))
        else:
            plt.title("Average reward over 100 episodes.")
        plt.plot([100 + i for i in range(len(avg_scores))], avg_scores)
        plt.subplot(142)
        plt.title("episode %s. Score: %s" % (episode, np.mean(scores[-10:])))
        plt.plot(scores)
        plt.subplot(143)
        plt.title('Loss')
        plt.plot(losses)
        plt.subplot(144)
        plt.title('epsilons')
        plt.plot(epsilons)
        plt.savefig('plots/dqn_result.png')
Example #3
0
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    assert args.dataset in ["psdb_train",
                            "psdb_test"], "Unknown dataset: %s" % args.dataset
    dataset = PSDB(args.dataset)
    dataloader = DataLoader(dataset, batch_size=1, sampler=PSSampler(dataset))
    logging.info("Loaded dataset: %s" % args.dataset)

    # Initialize model
    net = Network()
    if args.weights:
        state_dict = torch.load(args.weights)
        net.load_state_dict(
            {k: v
             for k, v in state_dict.items() if k in net.state_dict()})
        logging.info("Loaded pretrained model from: %s" % args.weights)

    # Initialize optimizer
    lr = cfg.TRAIN.LEARNING_RATE
    weight_decay = cfg.TRAIN.WEIGHT_DECAY
    params = []
    for k, v in net.named_parameters():
        if v.requires_grad:
            if "BN" in k:
                params += [{"params": [v], "lr": lr, "weight_decay": 0}]
            elif "bias" in k:
                params += [{"params": [v], "lr": 2 * lr, "weight_decay": 0}]
            else:
                params += [{
                    "params": [v],