from logging import Logger
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

# ConveRTModel, TrainConfig and get_smoothing_labels are assumed to be
# provided by the surrounding project; their modules are not shown here.


class Trainer:
    def __init__(
        self,
        config: TrainConfig,
        model: ConveRTModel,
        train_dataloader: DataLoader,
        eval_dataloader: DataLoader,
        logger: Logger,
    ):
        """
        모델을 학습시키기 위한 로직을 관리하는 클래스입니다.
        """
        self.config = config
        self.model = model
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.logger = logger

        self.device, self.list_ids = self._prepare_device(config.n_gpu, logger)
        self.model.to(self.device)
        # if len(self.list_ids) > 1:
        #     self.model = nn.DataParallel(self.model, device_ids=self.list_ids)
        self.optimizer = Adam(model.parameters(), lr=config.learning_rate)
        self.criterion = (nn.NLLLoss()
                          if self.config.label_smoothing_value == 0.0
                          else nn.KLDivLoss(reduction="batchmean"))
        self.criterion.to(self.device)

        self.steps_per_epoch = len(train_dataloader)
        self.total_steps = self.steps_per_epoch * config.epoch

    def train(self):
        self.logger.info("========= Start of train config ========")
        self.logger.info(f"device                : {self.device}")
        self.logger.info(
            f"dataset length/ train : {len(self.train_dataloader.dataset)}")
        self.logger.info(
            f"dataset length/ test  : {len(self.eval_dataloader.dataset)}")
        self.logger.info(f"max sequence length   : {self.config.max_seq_len}")
        self.logger.info(
            f"train batch size      : {self.config.train_batch_size}")
        self.logger.info(
            f"label smoothing value : {self.config.label_smoothing_value}")
        self.logger.info(
            f"learning rate         : {self.config.learning_rate}")
        self.logger.info(f"dropout prob          : {self.config.dropout_prob}")
        self.logger.info(f"total epoch           : {self.config.epoch}")
        self.logger.info(f"steps per epoch       : {self.steps_per_epoch}")
        self.logger.info(f"total steps           : {self.total_steps}")
        self.logger.info("========= End of train config ========")
        global_step = 0
        for epoch in range(1, self.config.epoch + 1):
            loss_sum = 0.0
            for data in self.train_dataloader:
                self.model.train()
                batch_size = data[0].size()[0]
                global_step += 1

                self.optimizer.zero_grad()

                query = data[0].to(self.device)
                context = data[1].to(self.device)
                reply = data[2].to(self.device)

                outputs_q, outputs_c, outputs_qc = self.model(
                    query, context, reply)
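                # In-batch negatives: each output is a (B, B) score matrix over
                # the batch, where the i-th reply is the positive for the i-th
                # query/context; the hard target is therefore arange(B), and
                # with label smoothing it becomes a full distribution per row.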

                if self.config.label_smoothing_value == 0.0:
                    target_labels = torch.arange(batch_size).to(self.device)
                else:
                    target_labels = get_smoothing_labels(
                        batch_size,
                        self.config.label_smoothing_value).to(self.device)

                # ConveRT-style multi-task objective: match the reply against
                # the query alone, the context alone, and their combination,
                # then sum the three losses.
                loss_q = self.criterion(outputs_q, target_labels)
                loss_c = self.criterion(outputs_c, target_labels)
                loss_qc = self.criterion(outputs_qc, target_labels)
                total_loss = loss_q + loss_c + loss_qc

                total_loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                self.optimizer.step()

                loss_sum += total_loss.item()
                if global_step % self.config.train_log_interval == 0:
                    mean_loss = loss_sum / self.config.train_log_interval
                    self.logger.info(
                        f"Epoch {epoch} Step {global_step} Loss {mean_loss:.4f}"
                    )
                    loss_sum = 0.0
                if global_step % self.config.val_log_interval == 0:
                    self._validate(global_step)
                if global_step % self.config.save_interval == 0:
                    self._save_model(self.model, global_step)

    def _validate(self, global_step):
        self.model.eval()
        correct_top_k = [0] * 5
        total_instance_num = 0
        with torch.no_grad():
            for data in self.eval_dataloader:
                batch_size = data[0].size()[0]
                total_instance_num += batch_size

                query = data[0].to(self.device)
                context = data[1].to(self.device)
                candidates = data[2].to(self.device)

                outputs = self.model.validate_forward(query, context,
                                                      candidates)

                # (batch_size, k)
                _, arg_top_ks = torch.topk(outputs, k=5)
                # (k, batch_size)
                correct_tensor = arg_top_ks.transpose(0, 1).eq(0)
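                # the ground-truth reply is assumed to sit at candidate index 0,
                # so a hit at rank k means index 0 appears within the top-k scores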

                for k in range(5):
                    correct_top_k[k] += int(torch.sum(correct_tensor[:k + 1]))

            acc_at_1 = float(correct_top_k[0]) / total_instance_num
            acc_at_5 = float(correct_top_k[4]) / total_instance_num
            self.logger.info(
                f"[Validation] Hits@1  {acc_at_1:.4f} Hits@5 {acc_at_5:.4f}")

    def _prepare_device(self, n_gpu_use: int,
                        logger: Logger) -> Tuple[torch.device, List[int]]:
        """
        setup GPU device if available, move model into configured device
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            logger.warning(
                "Warning: There's no GPU available on this machine, "
                "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            logger.warning(
                f"Warning: The number of GPUs configured to use is {n_gpu_use}, "
                f"but only {n_gpu} are available on this machine.")
            n_gpu_use = n_gpu
        device = torch.device("cuda:0" if n_gpu_use > 0 else "cpu")
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_model(self, model: nn.Module, step: int):
        """Saves the model state dict to the path given by the config."""
        state_dict = (model.module.state_dict()
                      if isinstance(model, nn.DataParallel) else
                      model.state_dict())
        torch.save(state_dict,
                   f"{self.config.save_model_file_prefix}_step_{step}.pth")


# ---- Example 2 ----
import os
from glob import glob

import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision.utils import save_image

# ESRGAN, Discriminator and PerceptualLoss are assumed to be provided by the
# surrounding ESRGAN project; their modules are not shown here.


class Trainer:
    def __init__(self, config, data_loader):
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.num_epoch = config.num_epoch
        self.epoch = config.epoch
        self.image_size = config.image_size
        self.data_loader = data_loader
        self.checkpoint_dir = config.checkpoint_dir
        self.batch_size = config.batch_size
        self.sample_dir = config.sample_dir
        self.nf = config.nf
        self.scale_factor = config.scale_factor

        if config.is_perceptual_oriented:
            self.lr = config.p_lr
            self.content_loss_factor = config.p_content_loss_factor
            self.perceptual_loss_factor = config.p_perceptual_loss_factor
            self.adversarial_loss_factor = config.p_adversarial_loss_factor
            self.decay_iter = config.p_decay_iter
        else:
            self.lr = config.g_lr
            self.content_loss_factor = config.g_content_loss_factor
            self.perceptual_loss_factor = config.g_perceptual_loss_factor
            self.adversarial_loss_factor = config.g_adversarial_loss_factor
            self.decay_iter = config.g_decay_iter

        self.build_model()
        self.optimizer_generator = Adam(self.generator.parameters(),
                                        lr=self.lr,
                                        betas=(config.b1, config.b2),
                                        weight_decay=config.weight_decay)
        self.optimizer_discriminator = Adam(self.discriminator.parameters(),
                                            lr=self.lr,
                                            betas=(config.b1, config.b2),
                                            weight_decay=config.weight_decay)

        self.lr_scheduler_generator = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer_generator, self.decay_iter)
        self.lr_scheduler_discriminator = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer_discriminator, self.decay_iter)

    def train(self):
        total_step = len(self.data_loader)
        adversarial_criterion = nn.BCEWithLogitsLoss().to(self.device)
        content_criterion = nn.L1Loss().to(self.device)
        perception_criterion = PerceptualLoss().to(self.device)
        self.generator.train()
        self.discriminator.train()

        for epoch in range(self.epoch, self.num_epoch):
            if not os.path.exists(os.path.join(self.sample_dir, str(epoch))):
                os.makedirs(os.path.join(self.sample_dir, str(epoch)))

            for step, image in enumerate(self.data_loader):
                low_resolution = image['lr'].to(self.device)
                high_resolution = image['hr'].to(self.device)

                real_labels = torch.ones(
                    (high_resolution.size(0), 1)).to(self.device)
                fake_labels = torch.zeros(
                    (high_resolution.size(0), 1)).to(self.device)

                ##########################
                #   training generator   #
                ##########################
                self.optimizer_generator.zero_grad()
                fake_high_resolution = self.generator(low_resolution)

                score_real = self.discriminator(high_resolution)
                score_fake = self.discriminator(fake_high_resolution)
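                # Relativistic average GAN (as in ESRGAN): each logit is scored
                # relative to the mean logit of the opposite class, so the
                # generator is pushed to make fakes look more realistic than
                # the average real image rather than merely "real".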
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                adversarial_loss_rf = adversarial_criterion(
                    discriminator_rf, fake_labels)
                adversarial_loss_fr = adversarial_criterion(
                    discriminator_fr, real_labels)
                adversarial_loss = (adversarial_loss_fr +
                                    adversarial_loss_rf) / 2

                perceptual_loss = perception_criterion(high_resolution,
                                                       fake_high_resolution)
                content_loss = content_criterion(fake_high_resolution,
                                                 high_resolution)

                generator_loss = adversarial_loss * self.adversarial_loss_factor + \
                                 perceptual_loss * self.perceptual_loss_factor + \
                                 content_loss * self.content_loss_factor

                generator_loss.backward()
                self.optimizer_generator.step()

                ##########################
                # training discriminator #
                ##########################

                self.optimizer_discriminator.zero_grad()

                score_real = self.discriminator(high_resolution)
                score_fake = self.discriminator(fake_high_resolution.detach())
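                # fake_high_resolution is detached so gradients from the
                # discriminator loss do not flow back into the generator.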
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                adversarial_loss_rf = adversarial_criterion(
                    discriminator_rf, real_labels)
                adversarial_loss_fr = adversarial_criterion(
                    discriminator_fr, fake_labels)
                discriminator_loss = (adversarial_loss_fr +
                                      adversarial_loss_rf) / 2

                discriminator_loss.backward()
                self.optimizer_discriminator.step()

                # decay_iter lists iteration milestones, so both MultiStepLR
                # schedulers are stepped once per batch rather than per epoch
                self.lr_scheduler_generator.step()
                self.lr_scheduler_discriminator.step()
                if step % 1000 == 0:
                    print(
                        f"[Epoch {epoch}/{self.num_epoch}] [Batch {step}/{total_step}] "
                        f"[D loss {discriminator_loss.item():.4f}] [G loss {generator_loss.item():.4f}] "
                        f"[adversarial loss {adversarial_loss.item() * self.adversarial_loss_factor:.4f}] "
                        f"[perceptual loss {perceptual_loss.item() * self.perceptual_loss_factor:.4f}] "
                        f"[content loss {content_loss.item() * self.content_loss_factor:.4f}]")
                    if step % 5000 == 0:
                        result = torch.cat(
                            (high_resolution, fake_high_resolution), 2)
                        save_image(
                            result,
                            os.path.join(self.sample_dir, str(epoch),
                                         f"SR_{step}.png"))

            torch.save(
                self.generator.state_dict(),
                os.path.join(self.checkpoint_dir, f"generator_{epoch}.pth"))
            torch.save(
                self.discriminator.state_dict(),
                os.path.join(self.checkpoint_dir,
                             f"discriminator_{epoch}.pth"))

    def build_model(self):
        self.generator = ESRGAN(3, 3, 64,
                                scale_factor=self.scale_factor).to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.load_model()

    def load_model(self):
        print(f"[*] Load model from {self.checkpoint_dir}")
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        if not os.listdir(self.checkpoint_dir):
            print(f"[!] No checkpoint in {self.checkpoint_dir}")
            return

        generator = glob(
            os.path.join(self.checkpoint_dir,
                         f'generator_{self.epoch - 1}.pth'))
        discriminator = glob(
            os.path.join(self.checkpoint_dir,
                         f'discriminator_{self.epoch - 1}.pth'))

        if not generator or not discriminator:
            print(f"[!] No checkpoint in epoch {self.epoch - 1}")
            return

        self.generator.load_state_dict(torch.load(generator[0]))
        self.discriminator.load_state_dict(torch.load(discriminator[0]))
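

# Resuming is driven by config.epoch: training restarts at that epoch, and
# load_model() looks for generator_{epoch - 1}.pth and
# discriminator_{epoch - 1}.pth in checkpoint_dir. A minimal usage sketch,
# assuming a config object exposing the fields read above (the values are
# illustrative, not from the original code):
#
#     config.epoch = 10        # resume from the checkpoints saved at epoch 9
#     config.num_epoch = 200
#     trainer = Trainer(config, data_loader)
#     trainer.train()


# ---- Example 3 ----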
import time

import numpy as np
import torch as T
import torch.nn.functional as F
from torch.optim import Adam

# Agent, StochasticActor and Critic are assumed to be provided by the
# surrounding RL codebase; their modules are not shown here.


class GoalConditionedSAC(Agent):
    def __init__(self,
                 algo_params,
                 env,
                 transition_tuple=None,
                 path=None,
                 seed=-1):
        # environment
        self.env = env
        self.env.seed(seed)
        obs = self.env.reset()
        algo_params.update({
            'state_dim': obs['observation'].shape[0],
            'goal_dim': obs['desired_goal'].shape[0],
            'action_dim': self.env.action_space.shape[0],
            'action_max': self.env.action_space.high,
            'action_scaling': self.env.action_space.high[0],
            'init_input_means': None,
            'init_input_vars': None
        })
        # training args
        self.training_epochs = algo_params['training_epochs']
        self.training_cycles = algo_params['training_cycles']
        self.training_episodes = algo_params['training_episodes']
        self.testing_gap = algo_params['testing_gap']
        self.testing_episodes = algo_params['testing_episodes']
        self.saving_gap = algo_params['saving_gap']

        super(GoalConditionedSAC,
              self).__init__(algo_params,
                             transition_tuple=transition_tuple,
                             goal_conditioned=True,
                             path=path,
                             seed=seed)
        # torch
        self.network_dict.update({
            'actor':
            StochasticActor(self.state_dim + self.goal_dim,
                            self.action_dim,
                            log_std_min=-6,
                            log_std_max=1,
                            action_scaling=self.action_scaling).to(
                                self.device),
            'critic_1':
            Critic(self.state_dim + self.goal_dim + self.action_dim,
                   1).to(self.device),
            'critic_1_target':
            Critic(self.state_dim + self.goal_dim + self.action_dim,
                   1).to(self.device),
            'critic_2':
            Critic(self.state_dim + self.goal_dim + self.action_dim,
                   1).to(self.device),
            'critic_2_target':
            Critic(self.state_dim + self.goal_dim + self.action_dim,
                   1).to(self.device),
            'alpha':
            algo_params['alpha'],
            'log_alpha':
            T.tensor(np.log(algo_params['alpha']),
                     requires_grad=True,
                     device=self.device),
        })
        self.network_keys_to_save = ['actor', 'critic_1_target']
        self.actor_optimizer = Adam(self.network_dict['actor'].parameters(),
                                    lr=self.actor_learning_rate)
        self.critic_1_optimizer = Adam(
            self.network_dict['critic_1'].parameters(),
            lr=self.critic_learning_rate)
        self.critic_2_optimizer = Adam(
            self.network_dict['critic_2'].parameters(),
            lr=self.critic_learning_rate)
        self._soft_update(self.network_dict['critic_1'],
                          self.network_dict['critic_1_target'],
                          tau=1)
        self._soft_update(self.network_dict['critic_2'],
                          self.network_dict['critic_2_target'],
                          tau=1)
        self.target_entropy = -self.action_dim
        self.alpha_optimizer = Adam([self.network_dict['log_alpha']],
                                    lr=self.actor_learning_rate)
        # training args
        self.clip_value = algo_params['clip_value']
        self.actor_update_interval = algo_params['actor_update_interval']
        self.critic_target_update_interval = algo_params[
            'critic_target_update_interval']
        # statistic dict
        self.statistic_dict.update({
            'cycle_return': [],
            'cycle_success_rate': [],
            'epoch_test_return': [],
            'epoch_test_success_rate': [],
            'alpha': [],
            'policy_entropy': [],
        })

    def run(self, test=False, render=False, load_network_ep=None, sleep=0):
        # training setup uses a hierarchy of Epoch, Cycle and Episode
        #   following the HER paper: https://papers.nips.cc/paper/2017/hash/453fadbd8a1a3af50a9df4df899537b5-Abstract.html
        if test:
            if load_network_ep is not None:
                print("Loading network parameters...")
                self._load_network(ep=load_network_ep)
            print("Start testing...")
        else:
            print("Start training...")

        for epo in range(self.training_epochs):
            for cyc in range(self.training_cycles):
                cycle_return = 0
                cycle_success = 0
                for ep in range(self.training_episodes):
                    ep_return = self._interact(render, test, sleep=sleep)
                    cycle_return += ep_return
                    # heuristic success criterion for the sparse negative
                    # reward: an episode return above -50 counts as a success
                    if ep_return > -50:
                        cycle_success += 1

                self.statistic_dict['cycle_return'].append(
                    cycle_return / self.training_episodes)
                self.statistic_dict['cycle_success_rate'].append(
                    cycle_success / self.training_episodes)
                print(
                    "Epoch %i" % epo, "Cycle %i" % cyc, "avg. return %0.1f" %
                    (cycle_return / self.training_episodes),
                    "success rate %0.1f" %
                    (cycle_success / self.training_episodes))

            if (epo % self.testing_gap == 0) and (epo != 0) and (not test):
                test_return = 0
                test_success = 0
                for test_ep in range(self.testing_episodes):
                    ep_test_return = self._interact(render, test=True)
                    test_return += ep_test_return
                    if ep_test_return > -50:
                        test_success += 1
                self.statistic_dict['epoch_test_return'].append(
                    test_return / self.testing_episodes)
                self.statistic_dict['epoch_test_success_rate'].append(
                    test_success / self.testing_episodes)
                print(
                    "Epoch %i" % epo, "test avg. return %0.1f" %
                    (test_return / self.testing_episodes))

            if (epo % self.saving_gap == 0) and (epo != 0) and (not test):
                self._save_network(ep=epo)

        if not test:
            print("Finished training")
            print("Saving statistics...")
            x_label = f"Optimization epoch (per {self.optimizer_steps} steps)"
            self._plot_statistics(x_labels={
                'critic_loss': x_label,
                'actor_loss': x_label,
                'alpha': x_label,
                'policy_entropy': x_label
            },
                                  save_to_file=True)
        else:
            print("Finished testing")

    def _interact(self, render=False, test=False, sleep=0):
        done = False
        obs = self.env.reset()
        ep_return = 0
        new_episode = True
        # start a new episode
        while not done:
            if render:
                self.env.render()
            action = self._select_action(obs, test=test)
            new_obs, reward, done, info = self.env.step(action)
            time.sleep(sleep)
            ep_return += reward
            if not test:
                self._remember(obs['observation'],
                               obs['desired_goal'],
                               action,
                               new_obs['observation'],
                               new_obs['achieved_goal'],
                               reward,
                               1 - int(done),
                               new_episode=new_episode)
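                # achieved_goal is stored with each transition so the
                # hindsight buffer can relabel goals later (see
                # buffer.modify_episodes() in _learn())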
                if self.observation_normalization:
                    self.normalizer.store_history(
                        np.concatenate(
                            (new_obs['observation'], new_obs['achieved_goal']),
                            axis=0))
            obs = new_obs
            new_episode = False

        if not test:
            self.normalizer.update_mean()
            self._learn()
        return ep_return

    def _select_action(self, obs, test=False):
        inputs = np.concatenate((obs['observation'], obs['desired_goal']),
                                axis=0)
        inputs = self.normalizer(inputs)
        inputs = T.as_tensor(inputs, dtype=T.float).to(self.device)
        return self.network_dict['actor'].get_action(
            inputs, mean_pi=test).detach().cpu().numpy()

    def _learn(self, steps=None):
        if self.hindsight:
            self.buffer.modify_episodes()
        self.buffer.store_episodes()
        if len(self.buffer) < self.batch_size:
            return
        if steps is None:
            steps = self.optimizer_steps

        critic_losses = T.zeros(1, device=self.device)
        actor_losses = T.zeros(1, device=self.device)
        alphas = T.zeros(1, device=self.device)
        policy_entropies = T.zeros(1, device=self.device)
        for i in range(steps):
            if self.prioritised:
                batch, weights, inds = self.buffer.sample(self.batch_size)
                weights = T.as_tensor(weights, device=self.device).view(
                    self.batch_size, 1)
            else:
                batch = self.buffer.sample(self.batch_size)
                weights = T.ones(size=(self.batch_size, 1), device=self.device)
                inds = None

            actor_inputs = np.concatenate((batch.state, batch.desired_goal),
                                          axis=1)
            actor_inputs = self.normalizer(actor_inputs)
            actor_inputs = T.as_tensor(actor_inputs,
                                       dtype=T.float32,
                                       device=self.device)
            actions = T.as_tensor(batch.action,
                                  dtype=T.float32,
                                  device=self.device)
            critic_inputs = T.cat((actor_inputs, actions), dim=1)
            actor_inputs_ = np.concatenate(
                (batch.next_state, batch.desired_goal), axis=1)
            actor_inputs_ = self.normalizer(actor_inputs_)
            actor_inputs_ = T.as_tensor(actor_inputs_,
                                        dtype=T.float32,
                                        device=self.device)
            rewards = T.as_tensor(batch.reward,
                                  dtype=T.float32,
                                  device=self.device).unsqueeze(1)
            done = T.as_tensor(batch.done, dtype=T.float32,
                               device=self.device).unsqueeze(1)

            if self.discard_time_limit:
                # ignore time-limit termination: force the continuation flag
                # to 1 so bootstrapping is never cut off by truncated episodes
                done = done * 0 + 1

            with T.no_grad():
                actions_, log_probs_ = self.network_dict['actor'].get_action(
                    actor_inputs_, probs=True)
                critic_inputs_ = T.cat((actor_inputs_, actions_), dim=1)
                value_1_ = self.network_dict['critic_1_target'](critic_inputs_)
                value_2_ = self.network_dict['critic_2_target'](critic_inputs_)
                value_ = T.min(value_1_, value_2_) - (
                    self.network_dict['alpha'] * log_probs_)
                value_target = rewards + done * self.gamma * value_
                value_target = T.clamp(value_target, -self.clip_value, 0.0)
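                # soft Bellman target: r + gamma * (min(Q1', Q2') - alpha * log_pi),
                # masked by the stored continuation flag (1 - done); clamping to
                # [-clip_value, 0] matches the non-positive sparse reward, which
                # bounds the true return.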

            self.critic_1_optimizer.zero_grad()
            value_estimate_1 = self.network_dict['critic_1'](critic_inputs)
            critic_loss_1 = F.mse_loss(value_estimate_1,
                                       value_target.detach(),
                                       reduction='none')
            (critic_loss_1 * weights).mean().backward()
            self.critic_1_optimizer.step()

            if self.prioritised:
                assert inds is not None
                self.buffer.update_priority(
                    inds, np.abs(critic_loss_1.cpu().detach().numpy()))

            self.critic_2_optimizer.zero_grad()
            value_estimate_2 = self.network_dict['critic_2'](critic_inputs)
            critic_loss_2 = F.mse_loss(value_estimate_2,
                                       value_target.detach(),
                                       reduction='none')
            (critic_loss_2 * weights).mean().backward()
            self.critic_2_optimizer.step()

            critic_losses += critic_loss_1.detach().mean()

            if self.optim_step_count % self.critic_target_update_interval == 0:
                self._soft_update(self.network_dict['critic_1'],
                                  self.network_dict['critic_1_target'])
                self._soft_update(self.network_dict['critic_2'],
                                  self.network_dict['critic_2_target'])

            if self.optim_step_count % self.actor_update_interval == 0:
                self.actor_optimizer.zero_grad()
                new_actions, new_log_probs, entropy = self.network_dict[
                    'actor'].get_action(actor_inputs, probs=True, entropy=True)
                critic_eval_inputs = T.cat((actor_inputs, new_actions),
                                           dim=1).to(self.device)
                new_values = T.min(
                    self.network_dict['critic_1'](critic_eval_inputs),
                    self.network_dict['critic_2'](critic_eval_inputs))
                actor_loss = (self.network_dict['alpha'] * new_log_probs -
                              new_values).mean()
                actor_loss.backward()
                self.actor_optimizer.step()

                # automatic entropy-temperature tuning (SAC v2): adjust
                # log_alpha so the policy entropy tracks target_entropy
                self.alpha_optimizer.zero_grad()
                alpha_loss = (
                    self.network_dict['log_alpha'] *
                    (-new_log_probs - self.target_entropy).detach()).mean()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self.network_dict['alpha'] = self.network_dict[
                    'log_alpha'].exp()

                actor_losses += actor_loss.detach()
                alphas += self.network_dict['alpha'].detach()
                policy_entropies += entropy.detach().mean()

            self.optim_step_count += 1

        self.statistic_dict['critic_loss'].append(critic_losses / steps)
        self.statistic_dict['actor_loss'].append(actor_losses / steps)
        self.statistic_dict['alpha'].append(alphas / steps)
        self.statistic_dict['policy_entropy'].append(policy_entropies / steps)
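

# A minimal usage sketch, assuming a goal-conditioned Gym environment whose
# observations are dicts with 'observation', 'desired_goal' and 'achieved_goal'
# keys, and an algo_params dict providing every key read in __init__
# (learning rates, buffer/batch settings, 'alpha', 'clip_value', update
# intervals, epoch/cycle/episode counts, ...). The environment id below is
# illustrative, not taken from the original code.
#
#     import gym
#     env = gym.make('FetchReach-v1')
#     agent = GoalConditionedSAC(algo_params, env, path='results/fetch', seed=0)
#     agent.run(test=False, render=False)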