Example #1
            scaler.update()

            if scheduler is not None:
                scheduler.step()

            # increment the sample/batch counters and save a checkpoint when needed
            new_num_samples_treated = num_samples_treated + batch[0].shape[0]
            num_batches_treated += 1

            if new_num_samples_treated > max_num_samples_to_train_on:
                state_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch,
                    'num_samples_treated': new_num_samples_treated,
                    "scaler": scaler.state_dict(),
                    'num_batches_treated': num_batches_treated
                }
                torch.save(
                    state_dict,
                    os.path.join(checkpoint_dir,
                                 f'{args.net}-{new_num_samples_treated}.ckpt'))
                print("training ended")
                exit()

            if (num_samples_treated // args.samples_before_ckpt) != \
                    (new_num_samples_treated // args.samples_before_ckpt):
                state_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch,
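
The excerpt above is cut from the middle of a mixed-precision training loop. Below is a minimal, self-contained sketch (not the original script) of the standard torch.cuda.amp pattern it relies on: the loss is scaled before backward(), scaler.step() unscales the gradients and calls optimizer.step() only if they are finite, and scaler.update() adjusts the loss scale. The model, data, optimizer and scheduler here are hypothetical stand-ins.

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

# Hypothetical stand-ins for the objects the excerpt assumes
model = nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
criterion = nn.MSELoss()
scaler = GradScaler()

for step in range(10):
    inputs = torch.randn(8, 16, device="cuda")
    targets = torch.randn(8, 1, device="cuda")
    optimizer.zero_grad()
    with autocast():
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()  # backward pass on the scaled loss
    scaler.step(optimizer)         # unscale grads, skip the step on inf/NaN, else optimizer.step()
    scaler.update()                # adjust the loss scale for the next iteration
    if scheduler is not None:
        scheduler.step()
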
Example #2
class CustomMTSAC(MTSAC):
    def __init__(
            self,
            policy,
            qf1,
            qf2,
            replay_buffer,
            env_spec,
            sampler,
            train_task_sampler,
            *,
            num_tasks,
            gradient_steps_per_itr,
            task_update_frequency=1,
            max_episode_length_eval=None,
            fixed_alpha=None,
            target_entropy=None,
            initial_log_entropy=0.,
            discount=0.99,
            buffer_batch_size=64,
            min_buffer_size=10000,
            target_update_tau=5e-3,
            policy_lr=3e-4,
            qf_lr=3e-4,
            reward_scale=1.0,
            optimizer=torch.optim.Adam,
            num_evaluation_episodes=5,
            # added
            fp16=False,
            log_per_task=False,
            share_train_eval_env=False):

        super().__init__(
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            replay_buffer=replay_buffer,
            env_spec=env_spec,
            sampler=sampler,
            test_sampler=sampler,  # not used, for compatibility
            train_task_sampler=train_task_sampler,
            num_tasks=num_tasks,
            gradient_steps_per_itr=gradient_steps_per_itr,
            max_episode_length_eval=max_episode_length_eval,
            fixed_alpha=fixed_alpha,
            target_entropy=target_entropy,
            initial_log_entropy=initial_log_entropy,
            discount=discount,
            buffer_batch_size=buffer_batch_size,
            min_buffer_size=min_buffer_size,
            target_update_tau=target_update_tau,
            policy_lr=policy_lr,
            qf_lr=qf_lr,
            reward_scale=reward_scale,
            optimizer=optimizer,
            steps_per_epoch=1,
            num_evaluation_episodes=num_evaluation_episodes,
        )
        self._train_task_sampler = train_task_sampler
        self._task_update_frequency = task_update_frequency
        self._fp16 = fp16
        self._log_per_task = log_per_task
        self._total_envsteps = 0

        # scalers for fp16
        # TODO: don't initialize gradscalers if not using fp16
        # Also don't save and/or restore
        self._gs_qf1 = GradScaler()
        self._gs_qf2 = GradScaler()
        self._gs_policy = GradScaler()
        self._gs_alpha = GradScaler()

        # get updates for evaluation
        self.eval_env_updates = self.resample_environment(force_update=True)
        self.share_train_eval_env = share_train_eval_env
        if self.share_train_eval_env:
            logging.warn("WARNING: Sharing train and eval environments")

        # Fix bug with the alpha optimizer
        self._use_automatic_entropy_tuning = fixed_alpha is None
        if self._use_automatic_entropy_tuning:
            self._alpha_optimizer = optimizer([self._log_alpha],
                                              lr=self._policy_lr)

    def state_dict(self):
        return {
            # parameters
            "policy": self.policy.state_dict(),
            "qf1": self._qf1.state_dict(),
            "qf2": self._qf2.state_dict(),
            "target_qf1": self._target_qf1.state_dict(),
            "target_qf2": self._target_qf2.state_dict(),
            "log_alpha": self._log_alpha,

            # scalers
            "gs_qf1": self._gs_qf1.state_dict(),
            "gs_qf2": self._gs_qf2.state_dict(),
            "gs_policy": self._gs_policy.state_dict(),
            "gs_alpha": self._gs_alpha.state_dict(),

            # optimizers
            "policy_optimizer": self._policy_optimizer.state_dict(),
            "qf1_optimizer": self._qf1_optimizer.state_dict(),
            "qf2_optimizer": self._qf2_optimizer.state_dict(),
            "alpha_optimizer": self._alpha_optimizer.state_dict(),

            # other variables
            "replay_buffer": self.replay_buffer,
            "total_envsteps": self._total_envsteps,
        }

    def load_env_state(self, env_state):
        self.eval_env_updates = env_state

    def load_state(self, state):
        # parameters
        self.policy.load_state_dict(state["policy"])
        self._qf1.load_state_dict(state["qf1"])
        self._qf2.load_state_dict(state["qf2"])
        self._target_qf1.load_state_dict(state["target_qf1"])
        self._target_qf2.load_state_dict(state["target_qf2"])
        self._log_alpha.data = state["log_alpha"]

        # scalers
        self._gs_qf1.load_state_dict(state["gs_qf1"])
        self._gs_qf2.load_state_dict(state["gs_qf2"])
        self._gs_policy.load_state_dict(state["gs_policy"])
        self._gs_alpha.load_state_dict(state["gs_alpha"])

        # optimizers
        self._policy_optimizer.load_state_dict(state["policy_optimizer"])
        self._qf1_optimizer.load_state_dict(state["qf1_optimizer"])
        self._qf2_optimizer.load_state_dict(state["qf2_optimizer"])
        self._alpha_optimizer.load_state_dict(state["alpha_optimizer"])

        # other variables
        self.replay_buffer = state["replay_buffer"]
        self._total_envsteps = state["total_envsteps"]

    def get_updated_policy(self, policy_hook=None):
        with torch.no_grad():
            updated_policy = copy.deepcopy(self.policy)
        updated_policy.eval()
        # attach hooks
        if policy_hook:
            policy_hook(updated_policy)

        return updated_policy

    def update_buffer(self, trajectories):
        """Update Buffer"""

        self._total_envsteps += sum(trajectories.lengths)
        path_returns = []
        for path in trajectories.to_list():
            self.replay_buffer.add_path(
                dict(observation=path["observations"],
                     action=path["actions"],
                     reward=path["rewards"].reshape(-1, 1),
                     next_observation=path["next_observations"],
                     terminal=np.array([
                         step_type == StepType.TERMINAL
                         for step_type in path["step_types"]
                     ]).reshape(-1, 1)))
            path_returns.append(sum(path["rewards"]))

        self.episode_rewards.append(np.mean(path_returns))

    def resample_environment(self, epoch=0, force_update=False):
        """
        TODO: fix env update in sampler

        Intended behavior:
        if epoch % self._task_update_frequency == 0 or force_update:
            return self._train_task_sampler.sample(self._num_tasks)
        """
        # TODO: remove first line to allow force update
        if epoch % self._task_update_frequency == 0 or force_update:
            return self._train_task_sampler.sample(self._num_tasks)

    def run_epoch(self, epoch, env_steps_per_epoch):
        """
        Run one epoch, which consists of N sample collections and N training
        steps. Each training step, in turn, consists of M gradient steps with
        batch size B.

        The total number of samples consumed by the algorithm in an epoch is
        N * M * B (steps * gradient_steps * batch_size).

        Collected samples are only used to update the buffer; they have no
        direct influence on the number of gradient steps or the batch size.

        Returns:
            float: The average return in last epoch cycle.

        """
        t0 = time()

        env_updates = (self.eval_env_updates if self.share_train_eval_env else
                       self.resample_environment(epoch))

        new_trajectories = self._sampler.obtain_samples(
            num_samples=env_steps_per_epoch,
            agent_update=self.get_updated_policy(),
            env_updates=env_updates,
        )
        self.update_buffer(new_trajectories)
        t1 = time()
        total_losses = self.run_step()
        time_to_collect_samples = t1 - t0
        time_to_update_gradient = time() - t1

        log_dict = self._log_statistics(*total_losses)

        # TODO: switch to logger.debug once logger is fixed
        logging.warn(f"Time to collect samples: {time_to_collect_samples:.2f}")
        logging.warn(f"Time to update gradient: {time_to_update_gradient:.2f}")

        return log_dict

    def run_step(self):
        """
        Run one training step, which is composed of M gradient steps

        For each of the M gradient steps:
        - sample a batch from the replay buffer
        - perform one gradient step on all three networks (policy, qf1 and qf2)
        """

        total_losses = [0, 0, 0]
        for _ in range(self._gradient_steps):
            if self.replay_buffer.n_transitions_stored >= self._min_buffer_size:
                samples = as_torch_dict(
                    self.replay_buffer.sample_transitions(
                        self._buffer_batch_size))
                policy_loss, qf1_loss, qf2_loss = self.optimize_policy(samples)
                total_losses[0] += policy_loss
                total_losses[1] += qf1_loss
                total_losses[2] += qf2_loss
                self._update_targets()

        # Normalize losses by the total number of gradient steps
        total_losses = [loss / self._gradient_steps for loss in total_losses]

        return total_losses

    def _evaluate_policy(self, epoch, policy_hook=None):
        """Evaluate the performance of the policy via deterministic sampling.

            Statistics such as (average) discounted return and success rate are
            recorded.

        Args:
            epoch (int): The current training epoch.

        Returns:
            tuple: the per-episode undiscounted returns, a dict of evaluation
                metrics, and any data collected by the policy hook.

        """
        t0 = time()

        # Collect episodes for evaluation
        eval_trajectories, policy_hook_data = self._sampler.obtain_exact_episodes(
            n_eps_per_worker=self._num_evaluation_episodes,
            agent_update=self.get_updated_policy(policy_hook=policy_hook),
            env_updates=self.eval_env_updates,
        )

        # Log performance
        undiscounted_returns, log_dict = log_multitask_performance(
            epoch,
            batch=eval_trajectories,
            discount=self._discount,
            log_per_task=self._log_per_task)
        log_dict["average_return"] = np.mean(undiscounted_returns)

        logging.warn(f"Time to evaluate policy: {time()-t0:.2f}")

        return undiscounted_returns, log_dict, policy_hook_data

    def _log_statistics(self, policy_loss, qf1_loss, qf2_loss):
        """Record training statistics to dowel such as losses and returns.

        Args:
            policy_loss (torch.Tensor): loss from actor/policy network.
            qf1_loss (torch.Tensor): loss from 1st qf/critic network.
            qf2_loss (torch.Tensor): loss from 2nd qf/critic network.

        """
        log_dict = {}
        with torch.no_grad():
            log_dict["AlphaTemperature/mean"] = self._log_alpha.exp().mean(
            ).item()
        log_dict["Policy/Loss"] = policy_loss.item()
        log_dict["QF/{}".format("Qf1Loss")] = float(qf1_loss)
        log_dict["QF/{}".format("Qf2Loss")] = float(qf2_loss)
        log_dict[
            "ReplayBuffer/buffer_size"] = self.replay_buffer.n_transitions_stored
        log_dict["Average/TrainAverageReturn"] = np.mean(self.episode_rewards)
        log_dict["TotalEnvSteps"] = self._total_envsteps

        return log_dict

    def _get_log_alpha(self, samples_data):
        """Return the value of log_alpha.
        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.
        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`
        Raises:
            ValueError: If the number of tasks, num_tasks passed to
                this algorithm doesn't match the length of the task
                one-hot id in the observation vector.
        Returns:
            torch.Tensor: log_alpha. Shape is (buffer_batch_size,).
        """
        obs = samples_data["observation"]
        log_alpha = self._log_alpha
        one_hots = obs[:, -self._num_tasks:]

        if (log_alpha.shape[0] != one_hots.shape[1]
                or one_hots.shape[1] != self._num_tasks
                or log_alpha.shape[0] != self._num_tasks):
            raise ValueError(
                "The number of tasks in the environment does "
                "not match self._num_tasks. Are you sure that you passed "
                "The correct number of tasks?")

        with autocast(enabled=self._fp16):
            return torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze()

    def _temperature_objective(self, log_pi, samples_data):
        """Compute the temperature/alpha coefficient loss.
        Args:
            log_pi(torch.Tensor): log probability of actions that are sampled
                from the replay buffer. Shape is (1, buffer_batch_size).
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.
        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`
        Returns:
            torch.Tensor: the temperature/alpha coefficient loss.
        """
        alpha_loss = 0

        with autocast(enabled=self._fp16):
            if self._use_automatic_entropy_tuning:
                alpha_loss = (-(self._get_log_alpha(samples_data)) *
                              (log_pi.detach() + self._target_entropy)).mean()

            return alpha_loss

    def _actor_objective(self, samples_data, new_actions, log_pi_new_actions):
        """Compute the Policy/Actor loss.
        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.
            new_actions (torch.Tensor): Actions resampled from the policy
                based on the observations, obs, which were sampled from the
                replay buffer. Shape is (action_dim, buffer_batch_size).
            log_pi_new_actions (torch.Tensor): Log probability of the new
                actions on the TanhNormal distributions that they were sampled
                from. Shape is (1, buffer_batch_size).
        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`
        Returns:
            torch.Tensor: loss from the Policy/Actor.
        """
        obs = samples_data["observation"]

        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data).exp()

        with autocast(enabled=self._fp16):
            min_q_new_actions = torch.min(self._qf1(obs, new_actions),
                                          self._qf2(obs, new_actions))

            policy_objective = ((alpha * log_pi_new_actions) -
                                min_q_new_actions.flatten()).mean()

            return policy_objective

    def _critic_objective(self, samples_data):
        """Compute the Q-function/critic loss.
        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.
        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`
        Returns:
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.
        """
        obs = samples_data["observation"]
        actions = samples_data["action"]
        rewards = samples_data["reward"].flatten()
        terminals = samples_data["terminal"].flatten()
        next_obs = samples_data["next_observation"]

        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data).exp()

        with autocast(enabled=self._fp16):
            q1_pred = self._qf1(obs, actions)
            q2_pred = self._qf2(obs, actions)

            new_next_actions_dist = self.policy(next_obs)[0]
            new_next_actions_pre_tanh, new_next_actions = (
                new_next_actions_dist.rsample_with_pre_tanh_value())
            new_log_pi = new_next_actions_dist.log_prob(
                value=new_next_actions,
                pre_tanh_value=new_next_actions_pre_tanh)

            target_q_values = torch.min(
                self._target_qf1(next_obs, new_next_actions),
                self._target_qf2(next_obs, new_next_actions)).flatten() - (
                    alpha * new_log_pi)

            with torch.no_grad():
                q_target = rewards * self._reward_scale + (
                    1. - terminals) * self._discount * target_q_values

            qf1_loss = F.mse_loss(q1_pred.flatten(), q_target)
            qf2_loss = F.mse_loss(q2_pred.flatten(), q_target)

            return qf1_loss, qf2_loss

    def optimize_policy(self, samples_data):
        """Optimize the policy q_functions, and temperature coefficient. Rezero
        model weights (if applicable) after each optimizer step.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: loss from actor/policy network after optimization.
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        if self._fp16:
            return self.optimize_policy_with_autocast(samples_data)

        obs = samples_data["observation"]
        qf1_loss, qf2_loss = self._critic_objective(samples_data)

        self._qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self._qf1_optimizer.step()
        self._qf1.apply(rezero_weights)

        self._qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self._qf2_optimizer.step()
        self._qf2.apply(rezero_weights)

        action_dists = self.policy(obs)[0]
        new_actions_pre_tanh, new_actions = (
            action_dists.rsample_with_pre_tanh_value())
        log_pi_new_actions = action_dists.log_prob(
            value=new_actions, pre_tanh_value=new_actions_pre_tanh)

        policy_loss = self._actor_objective(samples_data, new_actions,
                                            log_pi_new_actions)
        self._policy_optimizer.zero_grad()
        policy_loss.backward()

        self._policy_optimizer.step()
        self.policy.apply(rezero_weights)

        if self._use_automatic_entropy_tuning:
            alpha_loss = self._temperature_objective(log_pi_new_actions,
                                                     samples_data)
            self._alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self._alpha_optimizer.step()

        return policy_loss, qf1_loss, qf2_loss

    def optimize_policy_with_autocast(self, samples_data):
        """Optimize the policy q_functions, and temperature coefficient. Rezero
        model weights (if applicable) after each optimizer step.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: loss from actor/policy network after optimization.
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        obs = samples_data["observation"]

        qf1_loss, qf2_loss = self._critic_objective(samples_data)

        self._qf1_optimizer.zero_grad()
        self._gs_qf1.scale(qf1_loss).backward()
        self._gs_qf1.step(self._qf1_optimizer)
        self._gs_qf1.update()
        self._qf1.apply(rezero_weights)

        self._qf2_optimizer.zero_grad()
        self._gs_qf2.scale(qf2_loss).backward()
        self._gs_qf2.step(self._qf2_optimizer)
        self._gs_qf2.update()
        self._qf2.apply(rezero_weights)

        with autocast():
            action_dists = self.policy(obs)[0]
            new_actions_pre_tanh, new_actions = (
                action_dists.rsample_with_pre_tanh_value())
            log_pi_new_actions = action_dists.log_prob(
                value=new_actions, pre_tanh_value=new_actions_pre_tanh)

        policy_loss = self._actor_objective(samples_data, new_actions,
                                            log_pi_new_actions)

        self._policy_optimizer.zero_grad()
        self._gs_policy.scale(policy_loss).backward()
        self._gs_policy.step(self._policy_optimizer)
        self._gs_policy.update()
        self.policy.apply(rezero_weights)

        if self._use_automatic_entropy_tuning:
            alpha_loss = self._temperature_objective(log_pi_new_actions,
                                                     samples_data)

            self._alpha_optimizer.zero_grad()
            self._gs_alpha.scale(alpha_loss).backward()
            self._gs_alpha.step(self._alpha_optimizer)
            self._gs_alpha.update()

        return policy_loss, qf1_loss, qf2_loss

    def shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        self._sampler.shutdown_worker()
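
CustomMTSAC keeps one GradScaler per optimizer, so a skipped step (caused by inf/NaN gradients) in one network does not disturb the loss scale of the others. Below is a minimal, self-contained sketch of that per-optimizer scale/step/update pattern; the toy networks and optimizers are hypothetical and not part of the original code.

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

# Toy stand-ins: one optimizer and one GradScaler per network, mirroring
# _gs_qf1 / _gs_qf2 / _gs_policy / _gs_alpha above.
net_a, net_b = nn.Linear(8, 1).cuda(), nn.Linear(8, 1).cuda()
opt_a = torch.optim.Adam(net_a.parameters(), lr=3e-4)
opt_b = torch.optim.Adam(net_b.parameters(), lr=3e-4)
gs_a, gs_b = GradScaler(), GradScaler()

x = torch.randn(32, 8, device="cuda")
with autocast():
    loss_a = net_a(x).pow(2).mean()
    loss_b = net_b(x).pow(2).mean()

# Each loss goes through its own scaler: scale -> backward -> step -> update.
for loss, opt, gs in ((loss_a, opt_a, gs_a), (loss_b, opt_b, gs_b)):
    opt.zero_grad()
    gs.scale(loss).backward()
    gs.step(opt)
    gs.update()
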
Example #3
class GenericTrainingManager:

    def __init__(self, params):
        self.type = None
        self.is_master = False
        self.params = params
        self.models = {}
        self.begin_time = None
        self.dataset = None
        self.paths = None
        self.latest_epoch = -1
        self.latest_batch = 0
        self.total_batch = 0
        self.latest_train_metrics = dict()
        self.latest_valid_metrics = dict()
        self.phase = None
        self.max_mem_usage_by_epoch = list()

        self.scaler = None
        self.optimizer = None
        self.lr_scheduler = None
        self.best = None
        self.writer = None

        reset_optimizer = "reset_optimizer" in self.params["training_params"] and self.params["training_params"]["reset_optimizer"]

        self.init_hardware_config()
        self.init_paths()
        self.load_dataset()
        self.load_model(reset_optimizer)

    def init_paths(self):
        ## Create output folders
        output_path = os.path.join("outputs", self.params["training_params"]["output_folder"])
        os.makedirs(output_path, exist_ok=True)
        checkpoints_path = os.path.join(output_path, "checkpoints")
        os.makedirs(checkpoints_path, exist_ok=True)
        results_path = os.path.join(output_path, "results")
        os.makedirs(results_path, exist_ok=True)

        self.paths = {
            "results": results_path,
            "checkpoints": checkpoints_path,
            "output_folder": output_path
        }

    def load_dataset(self):
        self.params["dataset_params"]["use_ddp"] = self.params["training_params"]["use_ddp"]
        self.params["dataset_params"]["batch_size"] = self.params["training_params"]["batch_size"]
        self.params["dataset_params"]["num_gpu"] = self.params["training_params"]["nb_gpu"]
        self.dataset = DatasetManager(self.params["dataset_params"])
        if self.dataset.charset:
            self.params["model_params"]["vocab_size"] = len(self.dataset.charset)

    def init_hardware_config(self):
        # Debug mode
        if self.params["training_params"]["force_cpu"]:
            self.params["training_params"]["use_ddp"] = False
            self.params["training_params"]["use_amp"] = False
        # Manage Distributed Data Parallel & GPU usage
        self.manual_seed = self.params["training_params"].get("manual_seed", 1111)
        self.ddp_config = {
            "master": self.params["training_params"]["use_ddp"] and self.params["training_params"]["ddp_rank"] == 0,
            "address": "localhost" if "ddp_addr" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_addr"],
            "port": "11111" if "ddp_port" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_port"],
            "backend": "nccl" if "ddp_backend" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_backend"],
            "rank": self.params["training_params"]["ddp_rank"],
        }
        self.is_master = self.ddp_config["master"] or not self.params["training_params"]["use_ddp"]
        if self.params["training_params"]["force_cpu"]:
            self.device = "cpu"
        else:
            if self.params["training_params"]["use_ddp"]:
                self.device = torch.device(self.ddp_config["rank"])
                self.params["dataset_params"]["ddp_rank"] = self.ddp_config["rank"]
                self.launch_ddp()
            else:
                self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Print GPU info
        # global
        if (self.params["training_params"]["use_ddp"] and self.ddp_config["master"]) or not self.params["training_params"]["use_ddp"]:
            print("##################")
            print("Available GPUS: {}".format(self.params["training_params"]["nb_gpu"]))
            for i in range(self.params["training_params"]["nb_gpu"]):
                print("Rank {}: {} {}".format(i, torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i)))
            print("##################")
        # local
        print("Local GPU:")
        if self.device != "cpu":
            print("Rank {}: {} {}".format(self.params["training_params"]["ddp_rank"], torch.cuda.get_device_name(), torch.cuda.get_device_properties(self.device)))
        else:
            print("WORKING ON CPU !\n")
        print("##################")

    def load_model(self, reset_optimizer=False):
        self.params["model_params"]["use_amp"] = self.params["training_params"]["use_amp"]
        # Instantiate models
        for model_name in self.params["model_params"]["models"].keys():
            self.models[model_name] = self.params["model_params"]["models"][model_name](self.params["model_params"])
            self.models[model_name].to(self.device)  # To GPU or CPU

        # Instantiate optimizer
        self.reset_optimizer()
        if "lr_scheduler" in self.params["training_params"] and self.params["training_params"]["lr_scheduler"]:
            self.lr_scheduler = self.params["training_params"]["lr_scheduler"]["type"](self.optimizer, gamma=self.params["training_params"]["lr_scheduler"]["gamma"])

        self.scaler = GradScaler(enabled=self.params["training_params"]["use_amp"])

        # Load previous weights
        checkpoint = None
        if self.params["training_params"]["load_epoch"] in ("best", "last"):
            for filename in os.listdir(self.paths["checkpoints"]):
                # Continue training
                if self.params["training_params"]["load_epoch"] in filename:
                    checkpoint_path = os.path.join(self.paths["checkpoints"], filename)
                    checkpoint = torch.load(checkpoint_path)
                    self.load_save_info(checkpoint)
                    self.latest_epoch = checkpoint["epoch"]
                    self.best = checkpoint["best"]
                    self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
                    # Make model compatible with Distributed Data Parallel if used
                    if self.params["training_params"]["use_ddp"]:
                        for model_name in self.models.keys():
                            self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]])
                    # Load model weights from past training
                    for model_name in self.models.keys():
                        self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(model_name)])
                    # Load optimizer state from past training
                    if not reset_optimizer:
                        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                    # Load optimizer scheduler config from past training if used
                    if "lr_scheduler" in self.params["training_params"] and self.params["training_params"]["lr_scheduler"] and "lr_scheduler_state_dict" in checkpoint.keys():
                        self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
                    break

        # Print the number of epochs trained so far with this model
        if self.is_master:
            print("LOADED EPOCH: {}\n".format(self.latest_epoch), flush=True)

        # New training
        if not checkpoint:
            # Weights initialization
            for model_name in self.models.keys():
                self.models[model_name].apply(self.weights_init)
            # Handle transfer learning instructions
            if self.params["model_params"]["transfer_learning"]:
                # Iterate over models
                for model_name in self.params["model_params"]["transfer_learning"].keys():
                    state_dict_name, path, learnable, strict = self.params["model_params"]["transfer_learning"][model_name]
                    # Loading pretrained weights file
                    checkpoint = torch.load(path)
                    try:
                        # Load pretrained weights for model
                        self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(state_dict_name)], strict=strict)
                        print("transfered weights for {}".format(state_dict_name), flush=True)
                    except RuntimeError as e:
                        print(e, flush=True)
                        # if loading fails, try to load each part of the model separately (useful if only a few layers differ)
                        for key in checkpoint["{}_state_dict".format(state_dict_name)].keys():
                            try:
                                self.models[model_name].load_state_dict({key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False)
                            except RuntimeError as e:
                                print(e, flush=True)
                    # Set parameters as non-trainable
                    if not learnable:
                        self.set_model_learnable(self.models[model_name], False)

            # make the model compatible with Distributed Data Parallel if used
            if self.params["training_params"]["use_ddp"]:
                for model_name in self.models.keys():
                    self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]])
            return

    @staticmethod
    def set_model_learnable(model, learnable=True):
        for p in list(model.parameters()):
            p.requires_grad = learnable

    def save_model(self, epoch, name, keep_weights=False):
        """
        Save model weights
        """
        if not self.is_master:
            return
        to_del = []
        for filename in os.listdir(self.paths["checkpoints"]):
            if name in filename:
                to_del.append(os.path.join(self.paths["checkpoints"], filename))
        path = os.path.join(self.paths["checkpoints"], "{}_{}.pt".format(name, epoch))
        content = {
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch,
            "scaler_state_dict": self.scaler.state_dict(),
            'best': self.best,
        }
        if self.lr_scheduler:
            content["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict()
        content = self.add_save_info(content)
        for model_name in self.models.keys():
            content["{}_state_dict".format(model_name)] = self.models[model_name].state_dict()
        torch.save(content, path)
        if not keep_weights:
            for path_to_del in to_del:
                if path_to_del != path:
                    os.remove(path_to_del)

    def reset_optimizer(self):
        """
        Re-instantiate the optimizer (resetting its state and learning rate)
        """
        parameters = list()
        for model_name in self.models.keys():
            parameters += list(self.models[model_name].parameters())
        self.optimizer = self.params["training_params"]["optimizer"]["class"]\
            (parameters, **self.params["training_params"]["optimizer"]["args"])


    @staticmethod
    def weights_init(m):
        """
        Weights initialization for model training from scratch
        """
        if isinstance(m, Conv2d) or isinstance(m, Linear):
            if m.weight is not None:
                kaiming_uniform_(m.weight, nonlinearity="relu")
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, InstanceNorm2d):
            if m.weight is not None:
                ones_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def save_params(self):
        """
        Output text file containing a summary of all hyperparameters chosen for the training
        """
        def compute_nb_params(module):
            return sum([np.prod(p.size()) for p in list(module.parameters())])

        def class_to_str_dict(my_dict):
            for key in my_dict.keys():
                if callable(my_dict[key]):
                    my_dict[key] = my_dict[key].__name__
                elif isinstance(my_dict[key], np.ndarray):
                    my_dict[key] = my_dict[key].tolist()
                elif isinstance(my_dict[key], dict):
                    my_dict[key] = class_to_str_dict(my_dict[key])
            return my_dict

        path = os.path.join(self.paths["results"], "params")
        if os.path.isfile(path):
            return
        params = copy.deepcopy(self.params)
        params = class_to_str_dict(params)
        total_params = 0
        for model_name in self.models.keys():
            current_params = compute_nb_params(self.models[model_name])
            params["model_params"]["models"][model_name] = [params["model_params"]["models"][model_name], "{:,}".format(current_params)]
            total_params += current_params
        params["model_params"]["total_params"] = "{:,}".format(total_params)

        params["hardware"] = dict()
        if self.device != "cpu":
            for i in range(self.params["training_params"]["nb_gpu"]):
                params["hardware"][str(i)] = "{} {}".format(torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i))
        else:
            params["hardware"]["0"] = "CPU"
        with open(path, 'w') as f:
            json.dump(params, f, indent=4)

    def update_memory_consumption(self):
        self.max_mem_usage_by_epoch.append(torch.cuda.max_memory_allocated())
        torch.cuda.reset_max_memory_allocated()
        with open(os.path.join(self.paths["results"], "memory.txt"), 'a') as f:
            current = round(self.max_mem_usage_by_epoch[-1]/1e9, 2)
            max = round(np.max(self.max_mem_usage_by_epoch)/1e9, 2)
            min = round(np.min(self.max_mem_usage_by_epoch)/1e9, 2)
            median = round(np.median(self.max_mem_usage_by_epoch)/1e9, 2)
            mean = round(np.mean(self.max_mem_usage_by_epoch)/1e9, 2)
            f.write("E{} - Current: {} Go - Max: {} Go - Min: {} Go - Mean: {} Go - Median: {} Go\n".format(
                self.latest_epoch, current, max, min, mean, median))

    @staticmethod
    def init_metrics(metrics_name):
        """
        Initialization of the metrics specified in metrics_name
        """
        metrics = {
            "nb_samples": 0,
            "weights": 0,
            "names": list(),
            "ids": list(),
        }
        for metric_name in metrics_name:
            if metric_name == "cer":
                metrics["nb_chars"] = 0
                metrics[metric_name] = list()
                continue
            elif metric_name == "wer":
                metrics["nb_words"] = 0
            elif metric_name in ["pred", "proba", "cer_force_len"]:
                metrics[metric_name] = list()
                continue
            elif metric_name == "diff_len":
                metrics[metric_name] = None
                continue
            metrics[metric_name] = 0
        return metrics

    @staticmethod
    def update_metrics(metrics, batch_metrics):
        """
        Add batch metrics to the metrics
        """
        for key in batch_metrics.keys():
            if key in ["diff_len", ]:
                if metrics[key] is None:
                    metrics[key] = batch_metrics[key]
                else:
                    metrics[key] = np.concatenate([metrics[key], batch_metrics[key]], axis=0)
            elif key in ["pred", ]:
                if len(metrics[key]) == 0:
                    metrics[key] = batch_metrics[key]
                else:
                    for i in range(len(metrics[key])):
                        metrics[key][i] += batch_metrics[key][i]
            else:
                metrics[key] += batch_metrics[key]
        return metrics

    def get_display_values(self, metrics, metrics_name, num_batch):
        """
        Format metric values for shell display purposes
        """
        display_values = {}
        for metric_name in metrics_name:
            if metric_name in ["cer", "cer_force_len", ]:
                edit = np.sum(metrics[metric_name])
                display_values[metric_name] = round(edit / metrics["nb_chars"], 4)
            elif metric_name == "wer":
                display_values[metric_name] = round(metrics[metric_name] / metrics["nb_words"], 4)
            elif metric_name in ["f_measure", "precision", "recall", "IoU", "mAP", "pp_f_measure", "pp_precision", "pp_recall", "pp_IoU", "pp_mAP"]:
                display_values[metric_name] = round(metrics[metric_name] / metrics["weights"], 4)
            elif metric_name in ["diff_len", ]:
                display_values[metric_name] = np.round(np.mean(np.abs(metrics[metric_name])), 3)
            elif metric_name in ["time", "pred", "probas", "nb_max_len", "worst_cer", ]:
                continue
            elif metric_name in ["loss", "loss_ctc", "loss_ce", "loss_ce_end", "loss_mse"]:
                display_values[metric_name] = round(metrics[metric_name] / self.latest_batch, 4)
            else:
                display_values[metric_name] = round(metrics[metric_name] / metrics["nb_samples"], 4)
        return display_values

    def backward_loss(self, loss, retain_graph=False):
        self.scaler.scale(loss).backward(retain_graph=retain_graph)

    def step_optimizer(self):
        self.scaler.step(self.optimizer)
        self.scaler.update()

    def train(self):
        # init tensorboard file and output param summary file
        if self.is_master:
            self.writer = SummaryWriter(self.paths["results"])
            self.save_params()
        # init variables
        self.begin_time = time()
        focus_metric_name = self.params["training_params"]["focus_metric"]
        nb_epochs = self.params["training_params"]["max_nb_epochs"]
        interval_save_weights = self.params["training_params"]["interval_save_weights"]
        metrics_name = self.params["training_params"]["train_metrics"]
        display_values = None
        # perform epochs
        for num_epoch in range(self.latest_epoch+1, nb_epochs):
            self.phase = "train"
            # Check maximum training time stop condition
            if self.params["training_params"]["max_training_time"] and time() - self.begin_time > self.params["training_params"]["max_training_time"]:
                break
            # set models trainable
            for model_name in self.models.keys():
                self.models[model_name].train()
            self.latest_epoch = num_epoch
            # init epoch metrics values
            metrics = self.init_metrics(metrics_name)
            t = tqdm(self.dataset.train_loader)
            t.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
            # iterates over mini-batch data
            for ind_batch, batch_data in enumerate(t):
                self.latest_batch = ind_batch + 1
                self.total_batch += 1
                # train on batch data and compute metrics
                batch_metrics = self.train_batch(batch_data, metrics_name)
                batch_metrics["names"] = batch_data["names"]
                batch_metrics["ids"] = batch_data["ids"]
                # Merge metrics if Distributed Data Parallel is used
                if self.params["training_params"]["use_ddp"]:
                    batch_metrics = self.merge_ddp_metrics(batch_metrics)
                # Update learning rate via scheduler if one is used
                if self.lr_scheduler and ind_batch % self.params["training_params"]["lr_scheduler"]["step_interval"] == 0:
                    self.lr_scheduler.step()
                # Add batch metrics values to epoch metrics values
                metrics = self.update_metrics(metrics, batch_metrics)
                display_values = self.get_display_values(metrics, metrics_name, ind_batch)
                t.set_postfix(values=str(display_values))
            # log metrics in tensorboard file
            if self.is_master:
                for key in display_values.keys():
                    self.writer.add_scalar('{}_{}'.format(self.params["dataset_params"]["train"]["name"], key), display_values[key], num_epoch)
            self.latest_train_metrics = display_values

            # evaluate and compute metrics for valid sets
            if self.params["training_params"]["eval_on_valid"] and num_epoch % self.params["training_params"]["eval_on_valid_interval"] == 0:
                for valid_set_name in self.dataset.valid_loaders.keys():
                    # evaluate set and compute metrics
                    eval_values = self.evaluate(valid_set_name)
                    self.latest_valid_metrics = eval_values
                    # log valid metrics in tensorboard file
                    if self.is_master:
                        for key in eval_values.keys():
                            self.writer.add_scalar('{}_{}'.format(valid_set_name, key), eval_values[key], num_epoch)
                        if valid_set_name == self.params["training_params"]["set_name_focus_metric"] and (self.best is None or \
                                (eval_values[focus_metric_name] < self.best and self.params["training_params"]["expected_metric_value"] == "low") or\
                                (eval_values[focus_metric_name] > self.best and self.params["training_params"]["expected_metric_value"] == "high")):
                            self.save_model(epoch=num_epoch, name="best")
                            self.best = eval_values[focus_metric_name]

            ## save model weights
            if self.is_master:
                self.save_model(epoch=num_epoch, name="last")
                self.update_memory_consumption()
                if interval_save_weights and num_epoch % interval_save_weights == 0:
                    self.save_model(epoch=num_epoch, name="weigths", keep_weights=True)
                self.writer.flush()

    def evaluate(self, set_name, **kwargs):
        self.phase = "eval"
        loader = self.dataset.valid_loaders[set_name]
        # Set models in eval mode
        for model_name in self.models.keys():
            self.models[model_name].eval()
        metrics_name = self.params["training_params"]["eval_metrics"]
        display_values = None
        # initialize epoch metrics
        metrics = self.init_metrics(metrics_name)
        t = tqdm(loader)
        t.set_description("Evaluation E{}".format(self.latest_epoch))
        with torch.no_grad():
            # iterate over batch data
            for ind_batch, batch_data in enumerate(t):
                self.latest_batch = ind_batch + 1
                # eval batch data and compute metrics
                batch_metrics = self.evaluate_batch(batch_data, metrics_name)
                batch_metrics["names"] = batch_data["names"]
                batch_metrics["ids"] = batch_data["ids"]
                # merge metrics values if Distributed Data Parallel is used
                if self.params["training_params"]["use_ddp"]:
                    batch_metrics = self.merge_ddp_metrics(batch_metrics)
                # add batch metrics to epoch metrics
                metrics = self.update_metrics(metrics, batch_metrics)
                display_values = self.get_display_values(metrics, metrics_name, ind_batch)
                t.set_postfix(values=str(display_values))
        return display_values

    def predict(self, custom_name, sets_list, metrics_name, output=False):
        self.phase = "predict"
        metrics_name = metrics_name.copy()
        self.dataset.generate_test_loader(custom_name, sets_list)
        loader = self.dataset.test_loaders[custom_name]
        # Set models in eval mode
        for model_name in self.models.keys():
            self.models[model_name].eval()
        pred_time_metric = False
        if "time" in metrics_name:
            metrics_name.remove("time")
            pred_time_metric = True
        # initialize epoch metrics
        metrics = self.init_metrics(metrics_name)
        t = tqdm(loader)
        t.set_description("Prediction")
        begin_time = time()
        with torch.no_grad():
            for ind_batch, batch_data in enumerate(t):
                # iterates over batch data
                self.latest_batch = ind_batch + 1
                # eval batch data and compute metrics
                batch_metrics = self.evaluate_batch(batch_data, metrics_name)
                batch_metrics["names"] = batch_data["names"]
                batch_metrics["ids"] = batch_data["ids"]
                # merge batch metrics if Distributed Data Parallel is used
                if self.params["training_params"]["use_ddp"]:
                    batch_metrics = self.merge_ddp_metrics(batch_metrics)
                # add batch metrics to epoch metrics
                metrics = self.update_metrics(metrics, batch_metrics)
                display_values = self.get_display_values(metrics, metrics_name, ind_batch)
                t.set_postfix(values=str(display_values))
        pred_time = time() - begin_time
        # add time metric values if requested
        if pred_time_metric:
            metrics["total_time"] = np.round(pred_time, 3)
            metrics["sample_time"] = np.round(pred_time / len(self.dataset.test_datasets[custom_name]), 4)
        # output metrics values if requested
        if output:
            for name in ["probas", ]:
                if name in metrics.keys():
                    path = os.path.join(self.paths["results"], "{}_{}_{}.txt".format(name, custom_name, self.latest_epoch))
                    info = "\n".join(metrics[name])
                    with open(path, "w") as f:
                        f.write(info)
                    del metrics[name]
            self.output(metrics, custom_name)

    def launch_ddp(self):
        """
        Initialize Distributed Data Parallel system
        """
        mp.set_start_method('fork', force=True)
        os.environ['MASTER_ADDR'] = self.ddp_config["address"]
        os.environ['MASTER_PORT'] = str(self.ddp_config["port"])
        dist.init_process_group(self.ddp_config["backend"], rank=self.ddp_config["rank"], world_size=self.params["training_params"]["nb_gpu"])
        torch.cuda.set_device(self.ddp_config["rank"])
        random.seed(self.manual_seed)
        np.random.seed(self.manual_seed)
        torch.manual_seed(self.manual_seed)
        torch.cuda.manual_seed(self.manual_seed)

    def merge_ddp_metrics(self, metrics):
        """
        Merge metrics when Distributed Data Parallel is used
        """
        for metric_name in metrics.keys():
            if metric_name in ["wer", "wer_force_len", "nb_samples", "nb_words", "nb_chars", "nb_max_len",
                               "f_measure", "precision", "recall", "IoU", "mAP", "pp_f_measure", "pp_precision", "pp_recall", "pp_IoU", "pp_mAP"]:
                metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name])
            elif metric_name in ["loss", "loss_ce", "loss_ctc", "loss_ce_end"]:
                metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name], average=True)
            elif metric_name in ["diff_len", "cer", "cer_force_len", "ids"]:
                metrics[metric_name] = self.cat_ddp_metric(metrics[metric_name])
        return metrics

    def sum_ddp_metric(self, metric, average=False):
        """
        Sum metrics for Distributed Data Parallel
        """
        total = torch.tensor(metric).to(self.device)
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        if average:
            # true_divide is not in-place, so reassign the averaged value
            total = total.true_divide(dist.get_world_size())
        return total.item()

    def cat_ddp_metric(self, metric):
        """
        Concatenate metrics for Distributed Data Parallel
        """
        tensor = torch.tensor(metric).unsqueeze(0).to(self.device)
        res = [torch.zeros(tensor.size()).long().to(self.device) for _ in range(dist.get_world_size())]
        dist.all_gather(res, tensor)
        return list(torch.cat(res, dim=0).flatten().cpu().numpy())

    @staticmethod
    def cleanup():
        dist.destroy_process_group()

    def train_batch(self, batch_data, metric_names):
        raise NotImplementedError

    def evaluate_batch(self, batch_data, metric_names):
        raise NotImplementedError

    def output_pred(self, pred, set_name):
        raise NotImplementedError

    def add_checkpoint_info(self, load_mode="last", **kwargs):
        for filename in os.listdir(self.paths["checkpoints"]):
            if load_mode in filename:
                checkpoint_path = os.path.join(self.paths["checkpoints"], filename)
                checkpoint = torch.load(checkpoint_path)
                for key in kwargs.keys():
                    checkpoint[key] = kwargs[key]
                torch.save(checkpoint, checkpoint_path)
                return
        # No matching checkpoint found: create one
        self.save_model(self.latest_epoch, "last")

    def output(self, metrics, set_name):
        """
        Output metrics in text file
        """
        path = os.path.join(self.paths["results"], "predict_{}_{}.txt".format(set_name, self.latest_epoch))
        with open(path, "w") as f:
            for metric_name in metrics.keys():
                if metric_name in ["cer", "cer_force_len"]:
                    edit = np.sum(metrics[metric_name])
                    value = round(edit / metrics["nb_chars"], 4)
                elif metric_name in ["wer", ]:
                    value = round(metrics[metric_name] / metrics["nb_words"], 4)
                elif metric_name in ["loss_ce", ]:
                    value = round(metrics[metric_name] / metrics["nb_samples"], 4)
                elif metric_name in ["total_time", "sample_time", "total_output_time", "sample_output_time"]:
                    value = metrics[metric_name]
                elif metric_name in ["nb_samples", "nb_words", "nb_chars", "nb_max_len"]:
                    value = metrics[metric_name]
                elif metric_name in ["diff_len", ]:
                    f.write("{}: {}\n".format(metric_name, sorted(list(metrics[metric_name]))))
                    f.write("{}-mean_abs: {}\n".format(metric_name, np.mean(np.abs(metrics[metric_name]))))
                    continue
                elif metric_name in ["worst_cer", ]:
                    m = metric_name.split("_")[-1]
                    value = [[c, id] for c, id in zip(metrics[m], metrics["ids"])]
                    value = sorted(value, key=lambda x: x[0], reverse=True)
                    value = value[:50]
                else:
                    continue
                f.write("{}: {}\n".format(metric_name, value))

    def load_save_info(self, info_dict):
        """
        Load curriculum info from saved model info
        """
        if "curriculum_config" in info_dict.keys():
            self.dataset.train_dataset.curriculum_config = info_dict["curriculum_config"]

    def add_save_info(self, info_dict):
        """
        Add curriculum info to model info to be saved
        """
        info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config
        return info_dict
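
GenericTrainingManager stores the GradScaler state next to the model and optimizer states in its checkpoints (see save_model and load_model above). Below is a minimal sketch of that save/restore round trip; the model, optimizer and "last_0.pt" filename are hypothetical stand-ins, not the original code.

import torch
from torch import nn
from torch.cuda.amp import GradScaler

# Hypothetical minimal setup
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler(enabled=torch.cuda.is_available())

# Save: the scaler state (loss scale, growth tracker) is checkpointed
# together with the model and optimizer states.
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scaler_state_dict": scaler.state_dict(),
    "epoch": 0,
}, "last_0.pt")

# Load: restoring all three lets a resumed AMP run continue with the same
# loss scale instead of re-growing it from the default value.
checkpoint = torch.load("last_0.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scaler.load_state_dict(checkpoint["scaler_state_dict"])
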
Example #4
class Trainer:
    def __init__(
        self,
        name="default",
        results_dir="results",
        models_dir="models",
        base_dir="./",
        optimizer="adam",
        latent_dim=256,
        image_size=128,
        fmap_max=512,
        transparent=False,
        batch_size=4,
        gp_weight=10,
        gradient_accumulate_every=1,
        attn_res_layers=[],
        sle_spatial=False,
        disc_output_size=5,
        antialias=False,
        lr=2e-4,
        lr_mlp=1.0,
        ttur_mult=1.0,
        save_every=1000,
        evaluate_every=1000,
        trunc_psi=0.6,
        aug_prob=None,
        aug_types=["translation", "cutout"],
        dataset_aug_prob=0.0,
        calculate_fid_every=None,
        is_ddp=False,
        rank=0,
        world_size=1,
        log=False,
        amp=False,
        *args,
        **kwargs,
    ):
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.config_path = self.models_dir / name / ".config.json"

        assert is_power_of_two(
            image_size
        ), "image size must be a power of 2 (64, 128, 256, 512, 1024)"
        assert all(
            map(is_power_of_two, attn_res_layers)
        ), "resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)"

        self.optimizer = optimizer
        self.latent_dim = latent_dim
        self.image_size = image_size
        self.fmap_max = fmap_max
        self.transparent = transparent

        self.aug_prob = aug_prob
        self.aug_types = aug_types

        self.lr = lr
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.gp_weight = gp_weight

        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.generator_top_k_gamma = 0.99
        self.generator_top_k_frac = 0.5

        self.attn_res_layers = attn_res_layers
        self.sle_spatial = sle_spatial
        self.disc_output_size = disc_output_size
        self.antialias = antialias

        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None
        self.last_fid = None

        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every

        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size

        self.syncbatchnorm = is_ddp

        self.amp = amp
        self.G_scaler = None
        self.D_scaler = None
        if self.amp:
            self.G_scaler = GradScaler()
            self.D_scaler = GradScaler()

    @property
    def image_extension(self):
        return "jpg" if not self.transparent else "png"

    @property
    def checkpoint_num(self):
        return floor(self.steps // self.save_every)

    def init_GAN(self):
        args, kwargs = self.GAN_params

        # set some global variables before instantiating GAN

        global norm_class
        global Blur

        norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
        Blur = nn.Identity if not self.antialias else Blur

        # handle bugs when
        # switching from multi-gpu back to single gpu

        if self.syncbatchnorm and not self.is_ddp:
            import torch.distributed as dist

            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = "12355"
            dist.init_process_group("nccl", rank=0, world_size=1)

        # instantiate GAN

        self.GAN = LightweightGAN(
            optimizer=self.optimizer,
            lr=self.lr,
            latent_dim=self.latent_dim,
            attn_res_layers=self.attn_res_layers,
            sle_spatial=self.sle_spatial,
            image_size=self.image_size,
            ttur_mult=self.ttur_mult,
            fmap_max=self.fmap_max,
            disc_output_size=self.disc_output_size,
            transparent=self.transparent,
            rank=self.rank,
            *args,
            **kwargs,
        )

        if self.is_ddp:
            ddp_kwargs = {
                "device_ids": [self.rank],
                "output_device": self.rank,
                "find_unused_parameters": True,
            }

            self.G_ddp = DDP(self.GAN.G, **ddp_kwargs)
            self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
            self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)

    def write_config(self):
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        config = (
            self.config()
            if not self.config_path.exists()
            else json.loads(self.config_path.read_text())
        )
        self.image_size = config["image_size"]
        self.transparent = config["transparent"]
        self.syncbatchnorm = config["syncbatchnorm"]
        self.disc_output_size = config["disc_output_size"]
        self.attn_res_layers = config.pop("attn_res_layers", [])
        self.sle_spatial = config.pop("sle_spatial", False)
        self.optimizer = config.pop("optimizer", "adam")
        self.fmap_max = config.pop("fmap_max", 512)
        del self.GAN
        self.init_GAN()

    def config(self):
        return {
            "image_size": self.image_size,
            "transparent": self.transparent,
            "syncbatchnorm": self.syncbatchnorm,
            "disc_output_size": self.disc_output_size,
            "optimizer": self.optimizer,
            "attn_res_layers": self.attn_res_layers,
            "sle_spatial": self.sle_spatial,
        }

    def set_data_src(self, folder):
        self.dataset = ImageDataset(
            folder,
            self.image_size,
            transparent=self.transparent,
            aug_prob=self.dataset_aug_prob,
        )
        sampler = (
            DistributedSampler(
                self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True
            )
            if self.is_ddp
            else None
        )
        dataloader = DataLoader(
            self.dataset,
            num_workers=math.ceil(NUM_CORES / self.world_size),
            batch_size=math.ceil(self.batch_size / self.world_size),
            sampler=sampler,
            shuffle=not self.is_ddp,
            drop_last=True,
            pin_memory=True,
        )
        self.loader = cycle(dataloader)

        # auto set augmentation prob for user if dataset is detected to be low
        num_samples = len(self.dataset)
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(
                f"autosetting augmentation probability to {round(self.aug_prob * 100)}%"
            )

    def train(self):
        assert exists(
            self.loader
        ), "You must first initialize the data source with `.set_data_src(<folder of images>)`"
        device = torch.device(f"cuda:{self.rank}")

        if not exists(self.GAN):
            self.init_GAN()

        self.GAN.train()
        total_disc_loss = torch.zeros([], device=device)
        total_gen_loss = torch.zeros([], device=device)

        batch_size = math.ceil(self.batch_size / self.world_size)

        # image_size = self.GAN.image_size
        latent_dim = self.GAN.latent_dim

        aug_prob = default(self.aug_prob, 0)
        aug_types = self.aug_types
        aug_kwargs = {"prob": aug_prob, "types": aug_types}

        G = self.GAN.G if not self.is_ddp else self.G_ddp
        # D = self.GAN.D if not self.is_ddp else self.D_ddp
        D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp

        apply_gradient_penalty = self.steps % 4 == 0

        # amp related contexts and functions

        amp_context = autocast if self.amp else null_context

        def backward(amp, loss, scaler):
            if amp:
                return scaler.scale(loss).backward()
            loss.backward()

        def optimizer_step(amp, optimizer, scaler):
            if amp:
                scaler.step(optimizer)
                scaler.update()
                return
            optimizer.step()

        backward = partial(backward, self.amp)
        optimizer_step = partial(optimizer_step, self.amp)

        # train discriminator
        self.GAN.D_opt.zero_grad()
        for i in gradient_accumulate_contexts(
            self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, G]
        ):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
            image_batch = next(self.loader).cuda(self.rank)
            image_batch.requires_grad_()

            with amp_context():
                generated_images = G(latents)
                fake_output, fake_output_32x32, _ = D_aug(
                    generated_images.detach(), detach=True, **aug_kwargs
                )

                real_output, real_output_32x32, real_aux_loss = D_aug(
                    image_batch, calc_aux_loss=True, **aug_kwargs
                )

                real_output_loss = real_output
                fake_output_loss = fake_output

                divergence = hinge_loss(real_output_loss, fake_output_loss)
                divergence_32x32 = hinge_loss(real_output_32x32, fake_output_32x32)
                disc_loss = divergence + divergence_32x32

                aux_loss = real_aux_loss
                disc_loss = disc_loss + aux_loss

            if apply_gradient_penalty:
                outputs = [real_output, real_output_32x32]
                outputs = (
                    list(map(self.D_scaler.scale, outputs)) if self.amp else outputs
                )

                scaled_gradients = torch_grad(
                    outputs=outputs,
                    inputs=image_batch,
                    grad_outputs=list(
                        map(
                            lambda t: torch.ones(t.size(), device=image_batch.device),
                            outputs,
                        )
                    ),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True,
                )[0]

                inv_scale = (1.0 / self.D_scaler.get_scale()) if self.amp else 1.0
                gradients = scaled_gradients * inv_scale

                with amp_context():
                    gradients = gradients.reshape(batch_size, -1)
                    gp = self.gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

                    if not torch.isnan(gp):
                        disc_loss = disc_loss + gp
                        self.last_gp_loss = gp.clone().detach().item()

            with amp_context():
                disc_loss = disc_loss / self.gradient_accumulate_every

            disc_loss.register_hook(raise_if_nan)
            backward(disc_loss, self.D_scaler)
            total_disc_loss += divergence

        self.last_recon_loss = aux_loss.item()
        self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every)
        optimizer_step(self.GAN.D_opt, self.D_scaler)

        # train generator

        self.GAN.G_opt.zero_grad()

        for i in gradient_accumulate_contexts(
            self.gradient_accumulate_every, self.is_ddp, ddps=[G, D_aug]
        ):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

            with amp_context():
                generated_images = G(latents)
                fake_output, fake_output_32x32, _ = D_aug(
                    generated_images, **aug_kwargs
                )
                fake_output_loss = fake_output.mean(dim=1) + fake_output_32x32.mean(
                    dim=1
                )

                epochs = (
                    self.steps * batch_size * self.gradient_accumulate_every
                ) / len(self.dataset)
                k_frac = max(
                    self.generator_top_k_gamma ** epochs, self.generator_top_k_frac
                )
                k = math.ceil(batch_size * k_frac)

                if k != batch_size:
                    fake_output_loss, _ = fake_output_loss.topk(k=k, largest=False)

                loss = fake_output_loss.mean()
                gen_loss = loss

                gen_loss = gen_loss / self.gradient_accumulate_every
            gen_loss.register_hook(raise_if_nan)
            backward(gen_loss, self.G_scaler)
            total_gen_loss += loss

        self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every)
        optimizer_step(self.GAN.G_opt, self.G_scaler)

        # calculate moving averages

        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors

        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(
                f"NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}"
            )
            self.load(self.checkpoint_num)
            raise NanException

        del total_disc_loss
        del total_gen_loss

        # periodically save results

        if self.is_main:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

            if self.steps % self.evaluate_every == 0 or (
                self.steps % 100 == 0 and self.steps < 20000
            ):
                self.evaluate(floor(self.steps / self.evaluate_every))

            if (
                exists(self.calculate_fid_every)
                and self.steps % self.calculate_fid_every == 0
                and self.steps != 0
            ):
                num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size)
                fid = self.calculate_fid(num_batches)
                self.last_fid = fid

                with open(
                    str(self.results_dir / self.name / "fid_scores.txt"), "a"
                ) as f:
                    f.write(f"{self.steps},{fid}\n")

        self.steps += 1

    @torch.no_grad()
    def evaluate(self, num=0, num_image_tiles=8, trunc=1.0):
        self.GAN.eval()

        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        # image_size = self.GAN.image_size

        # latents and noise

        latents = torch.randn((num_rows ** 2, latent_dim)).cuda(self.rank)

        # regular

        generated_images = self.generate_truncated(self.GAN.G, latents)
        torchvision.utils.save_image(
            generated_images,
            str(self.results_dir / self.name / f"{str(num)}.{ext}"),
            nrow=num_rows,
        )

        # moving averages

        generated_images = self.generate_truncated(self.GAN.GE, latents)
        torchvision.utils.save_image(
            generated_images,
            str(self.results_dir / self.name / f"{str(num)}-ema.{ext}"),
            nrow=num_rows,
        )

    @torch.no_grad()
    def calculate_fid(self, num_batches):
        torch.cuda.empty_cache()

        real_path = str(self.results_dir / self.name / "fid_real") + "/"
        fake_path = str(self.results_dir / self.name / "fid_fake") + "/"

        # remove any existing files used for fid calculation and recreate directories
        rmtree(real_path, ignore_errors=True)
        rmtree(fake_path, ignore_errors=True)
        os.makedirs(real_path)
        os.makedirs(fake_path)

        for batch_num in tqdm(
            range(num_batches), desc="calculating FID - saving reals"
        ):
            real_batch = next(self.loader)
            for k in range(real_batch.size(0)):
                torchvision.utils.save_image(
                    real_batch[k, :, :, :],
                    real_path + "{}.png".format(k + batch_num * self.batch_size),
                )

        # generate a bunch of fake images in results / name / fid_fake
        self.GAN.eval()
        ext = self.image_extension

        latent_dim = self.GAN.latent_dim
        # image_size = self.GAN.image_size

        for batch_num in tqdm(
            range(num_batches), desc="calculating FID - saving generated"
        ):
            # latents and noise
            latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank)

            # moving averages
            generated_images = self.generate_truncated(self.GAN.GE, latents)

            for j in range(generated_images.size(0)):
                torchvision.utils.save_image(
                    generated_images[j, :, :, :],
                    str(
                        Path(fake_path)
                        / f"{str(j + batch_num * self.batch_size)}-ema.{ext}"
                    ),
                )

        return fid_score.calculate_fid_given_paths(
            [real_path, fake_path], 256, True, 2048
        )

    @torch.no_grad()
    def generate_truncated(self, G, style, trunc_psi=0.75, num_image_tiles=8):
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0.0, 1.0)

    @torch.no_grad()
    def generate_interpolation(
        self, num=0, num_image_tiles=8, trunc=1.0, num_steps=100, save_frames=False
    ):
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        # image_size = self.GAN.image_size

        # latents and noise

        latents_low = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)
        latents_high = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)

        ratios = torch.linspace(0.0, 8.0, num_steps)

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            generated_images = self.generate_truncated(self.GAN.GE, interp_latents)
            images_grid = torchvision.utils.make_grid(generated_images, nrow=num_rows)
            pil_image = transforms.ToPILImage()(images_grid.cpu())

            if self.transparent:
                background = Image.new("RGBA", pil_image.size, (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)

            frames.append(pil_image)

        frames[0].save(
            str(self.results_dir / self.name / f"{str(num)}.gif"),
            save_all=True,
            append_images=frames[1:],
            duration=80,
            loop=0,
            optimize=True,
        )

        if save_frames:
            folder_path = self.results_dir / self.name / f"{str(num)}"
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f"{str(ind)}.{ext}"))

    def print_log(self):
        data = [
            ("G", self.g_loss),
            ("D", self.d_loss),
            ("GP", self.last_gp_loss),
            ("SS", self.last_recon_loss),
            ("FID", self.last_fid),
        ]

        data = [d for d in data if exists(d[1])]
        log = " | ".join(map(lambda n: f"{n[0]}: {n[1]:.2f}", data))
        print(log)

    def model_name(self, num):
        return str(self.models_dir / self.name / f"model_{num}.pt")

    def init_folders(self):
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        rmtree(str(self.config_path), True)
        self.init_folders()

    def save(self, num):
        save_data = {"GAN": self.GAN.state_dict(), "version": __version__}

        if self.amp:
            save_data = {
                **save_data,
                "G_scaler": self.G_scaler.state_dict(),
                "D_scaler": self.D_scaler.state_dict(),
            }

        torch.save(save_data, self.model_name(num))
        self.write_config()

    def load(self, num=-1):
        self.load_config()

        name = num
        if num == -1:
            file_paths = [
                p for p in Path(self.models_dir / self.name).glob("model_*.pt")
            ]
            saved_nums = sorted(map(lambda x: int(x.stem.split("_")[1]), file_paths))
            if len(saved_nums) == 0:
                return
            name = saved_nums[-1]
            print(f"continuing from previous epoch - {name}")

        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        if "version" in load_data and self.is_main:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data["GAN"])
        except Exception as e:
            print(
                "unable to load save model. please try downgrading the package to the version specified by the saved model"
            )
            raise e

        if self.amp:
            if "G_scaler" in load_data:
                self.G_scaler.load_state_dict(load_data["G_scaler"])
            if "D_scaler" in load_data:
                self.D_scaler.load_state_dict(load_data["D_scaler"])
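
The AMP handling in `Trainer.train()` above follows the standard torch.cuda.amp pattern: the loss is scaled before `backward()`, the gradients used for the gradient penalty are divided by `scaler.get_scale()`, and `scaler.step()`/`scaler.update()` replace the plain optimizer step. A minimal sketch of that pattern with a toy model and optimizer standing in for the GAN (assumes a CUDA device; not the original code):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 1).cuda()
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
scaler = GradScaler()

x = torch.randn(4, 16, device="cuda")
with autocast():
    loss = model(x).mean()

scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(opt)               # unscales gradients, skips the step if they contain inf/NaN
scaler.update()                # adjusts the scale factor for the next iteration

# Checkpointing, as done in Trainer.save() / Trainer.load() above:
state = scaler.state_dict()
scaler.load_state_dict(state)
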
Example #5
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    logger = get_logger(args.logging_file)
    logger.info("Use GPU: {} for training".format(args.gpu))

    args.rank = args.rank * ngpus_per_node + gpu
    torch.distributed.init_process_group(backend="nccl",
                                         init_method=args.dist_url,
                                         world_size=args.world_size,
                                         rank=args.rank)

    epochs = args.epochs
    input_size = args.input_size
    resume_epoch = args.resume_epoch
    initializer = KaimingInitializer()
    zero_gamma = ZeroLastGamma()
    mix_precision_training = args.mix_precision_training
    is_first_rank = True if args.rank % ngpus_per_node == 0 else False

    batches_per_epoch = args.num_training_samples // (args.batch_size *
                                                       ngpus_per_node)
    lr = 0.1 * (args.batch_size * ngpus_per_node //
                32) if args.lr == 0 else args.lr

    model = get_model(models, args.model)

    model.apply(initializer)
    if args.last_gamma:
        model.apply(zero_gamma)
        logger.info('Apply zero last gamma init.')

    if is_first_rank and args.model_info:
        summary(model, torch.rand((1, 3, input_size, input_size)))

    parameters = model.parameters() if not args.no_wd else no_decay_bias(model)
    if args.sgd_gc:
        logger.info('Use SGD_GC optimizer.')
        optimizer = SGD_GC(parameters,
                           lr=lr,
                           momentum=args.momentum,
                           weight_decay=args.wd,
                           nesterov=True)
    else:
        optimizer = optim.SGD(parameters,
                              lr=lr,
                              momentum=args.momentum,
                              weight_decay=args.wd,
                              nesterov=True)

    lr_scheduler = CosineWarmupLr(optimizer,
                                  batches_per_epoch,
                                  epochs,
                                  base_lr=args.lr,
                                  warmup_epochs=args.warmup_epochs)

    # dropblock_scheduler = DropBlockScheduler(model, batches_per_epoch, epochs)

    if args.lookahead:
        optimizer = Lookahead(optimizer)
        logger.info('Use lookahead optimizer.')

    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)
    args.num_workers = int(
        (args.num_workers + ngpus_per_node - 1) / ngpus_per_node)

    if args.mix_precision_training and is_first_rank:
        logger.info('Train with FP16.')

    scaler = GradScaler(enabled=args.mix_precision_training)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    Loss = nn.CrossEntropyLoss().cuda(args.gpu) if not args.label_smoothing else \
        LabelSmoothingLoss(args.classes, smoothing=0.1).cuda(args.gpu)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if args.autoaugment:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            ImageNetPolicy(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            # Cutout(),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.ToTensor(),
            normalize,
        ])

    val_transform = transforms.Compose([
        transforms.Resize(int(input_size / 0.875)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = ImageNet(args.data_path,
                         split='train',
                         transform=train_transform)
    val_set = ImageNet(args.data_path, split='val', transform=val_transform)

    train_sampler = DistributedSampler(train_set)
    train_loader = DataLoader(train_set,
                              args.batch_size,
                              False,
                              pin_memory=True,
                              num_workers=args.num_workers,
                              drop_last=True,
                              sampler=train_sampler)
    val_loader = DataLoader(val_set,
                            args.batch_size,
                            False,
                            pin_memory=True,
                            num_workers=args.num_workers,
                            drop_last=False)

    if resume_epoch > 0:
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.resume_param, map_location=loc)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scaler.load_state_dict(checkpoint['scaler'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        print("Finish loading resume param.")

    torch.backends.cudnn.benchmark = True

    top1_acc = metric.Accuracy(name='Top1 Accuracy')
    top5_acc = metric.TopKAccuracy(top=5, name='Top5 Accuracy')
    loss_record = metric.NumericalCost(name='Loss')

    for epoch in range(resume_epoch, epochs):
        tic = time.time()
        train_sampler.set_epoch(epoch)
        if not args.mixup:
            train_one_epoch(model, train_loader, Loss, optimizer, epoch,
                            lr_scheduler, logger, top1_acc, loss_record,
                            scaler, args)
        else:
            train_one_epoch_mixup(model, train_loader, Loss, optimizer, epoch,
                                  lr_scheduler, logger, loss_record, scaler,
                                  args)
        train_speed = int(args.num_training_samples // (time.time() - tic))
        if is_first_rank:
            logger.info(
                'Finish one epoch speed: {} samples/s'.format(train_speed))
        test(model, val_loader, Loss, epoch, logger, top1_acc, top5_acc,
             loss_record, args)

        if args.rank % ngpus_per_node == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
            }
            torch.save(
                checkpoint, '{}/{}_{}_{:.5}.pt'.format(args.save_dir,
                                                       args.model, epoch,
                                                       top1_acc.get()))
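
`train_one_epoch` and `train_one_epoch_mixup` are not shown in this excerpt; below is a hypothetical sketch of a single fp16 step they might perform with the `GradScaler` created in `main_worker` (the function name, signature, and per-batch scheduler step are assumptions, not the original implementation):

from torch.cuda.amp import autocast

def train_step(model, data, target, criterion, optimizer, lr_scheduler, scaler, args):
    # Assumed per-batch step; CosineWarmupLr is built from batches_per_epoch above,
    # so it is presumably stepped once per iteration.
    data = data.cuda(args.gpu, non_blocking=True)
    target = target.cuda(args.gpu, non_blocking=True)

    optimizer.zero_grad()
    with autocast(enabled=args.mix_precision_training):
        output = model(data)
        loss = criterion(output, target)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    lr_scheduler.step()
    return loss.detach()
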
Example #6
    def run(
        cls,
        model: AbsESPnetModel,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        train_iter_factory: AbsIterFactory,
        valid_iter_factory: AbsIterFactory,
        plot_attention_iter_factory: Optional[AbsIterFactory],
        trainer_options,
        distributed_option: DistributedOption,
    ) -> None:
        """Perform training. This method performs the main process of training."""
        assert check_argument_types()
        # NOTE(kamo): Don't check the type more strictly as far as trainer_options is concerned
        assert is_dataclass(trainer_options), type(trainer_options)
        assert len(optimizers) == len(schedulers), (len(optimizers),
                                                    len(schedulers))

        if isinstance(trainer_options.keep_nbest_models, int):
            keep_nbest_models = [trainer_options.keep_nbest_models]
        else:
            if len(trainer_options.keep_nbest_models) == 0:
                logging.warning("No keep_nbest_models is given. Change to [1]")
                trainer_options.keep_nbest_models = [1]
            keep_nbest_models = trainer_options.keep_nbest_models

        output_dir = Path(trainer_options.output_dir)
        reporter = Reporter()
        if trainer_options.use_amp:
            if V(torch.__version__) < V("1.6.0"):
                raise RuntimeError(
                    "Require torch>=1.6.0 for  Automatic Mixed Precision")
            if trainer_options.sharded_ddp:
                if fairscale is None:
                    raise RuntimeError(
                        "Requiring fairscale. Do 'pip install fairscale'")
                scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
            else:
                scaler = GradScaler()
        else:
            scaler = None

        if trainer_options.resume and (output_dir / "checkpoint.pth").exists():
            cls.resume(
                checkpoint=output_dir / "checkpoint.pth",
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                reporter=reporter,
                scaler=scaler,
                ngpu=trainer_options.ngpu,
            )

        start_epoch = reporter.get_epoch() + 1
        if start_epoch == trainer_options.max_epoch + 1:
            logging.warning(
                f"The training has already reached at max_epoch: {start_epoch}"
            )

        if distributed_option.distributed:
            if trainer_options.sharded_ddp:
                dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
                    module=model,
                    sharded_optimizer=optimizers,
                )
            else:
                dp_model = torch.nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=(
                        # Perform multi-Process with multi-GPUs
                        [torch.cuda.current_device()]
                        if distributed_option.ngpu == 1
                        # Perform single-Process with multi-GPUs
                        else None),
                    output_device=(torch.cuda.current_device()
                                   if distributed_option.ngpu == 1 else None),
                    find_unused_parameters=trainer_options.unused_parameters,
                )
        elif distributed_option.ngpu > 1:
            dp_model = torch.nn.parallel.DataParallel(
                model,
                device_ids=list(range(distributed_option.ngpu)),
            )
        else:
            # NOTE(kamo): DataParallel also should work with ngpu=1,
            # but for debuggability it's better to keep this block.
            dp_model = model

        if trainer_options.use_tensorboard and (
                not distributed_option.distributed
                or distributed_option.dist_rank == 0):
            from torch.utils.tensorboard import SummaryWriter

            train_summary_writer = SummaryWriter(
                str(output_dir / "tensorboard" / "train"))
            valid_summary_writer = SummaryWriter(
                str(output_dir / "tensorboard" / "valid"))
        else:
            train_summary_writer = None
            valid_summary_writer = None

        start_time = time.perf_counter()
        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
            if iepoch != start_epoch:
                logging.info(
                    "{}/{}epoch started. Estimated time to finish: {}".format(
                        iepoch,
                        trainer_options.max_epoch,
                        humanfriendly.format_timespan(
                            (time.perf_counter() - start_time) /
                            (iepoch - start_epoch) *
                            (trainer_options.max_epoch - iepoch + 1)),
                    ))
            else:
                logging.info(
                    f"{iepoch}/{trainer_options.max_epoch}epoch started")
            set_all_random_seed(trainer_options.seed + iepoch)

            reporter.set_epoch(iepoch)
            # 1. Train and validation for one-epoch
            with reporter.observe("train") as sub_reporter:
                all_steps_are_invalid = cls.train_one_epoch(
                    model=dp_model,
                    optimizers=optimizers,
                    schedulers=schedulers,
                    iterator=train_iter_factory.build_iter(iepoch),
                    reporter=sub_reporter,
                    scaler=scaler,
                    summary_writer=train_summary_writer,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            with reporter.observe("valid") as sub_reporter:
                cls.validate_one_epoch(
                    model=dp_model,
                    iterator=valid_iter_factory.build_iter(iepoch),
                    reporter=sub_reporter,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )
            if not distributed_option.distributed or distributed_option.dist_rank == 0:
                # att_plot doesn't support distributed
                if plot_attention_iter_factory is not None:
                    with reporter.observe("att_plot") as sub_reporter:
                        cls.plot_attention(
                            model=model,
                            output_dir=output_dir / "att_ws",
                            summary_writer=train_summary_writer,
                            iterator=plot_attention_iter_factory.build_iter(
                                iepoch),
                            reporter=sub_reporter,
                            options=trainer_options,
                        )

            # 2. LR Scheduler step
            for scheduler in schedulers:
                if isinstance(scheduler, AbsValEpochStepScheduler):
                    scheduler.step(
                        reporter.get_value(
                            *trainer_options.val_scheduler_criterion))
                elif isinstance(scheduler, AbsEpochStepScheduler):
                    scheduler.step()
            if trainer_options.sharded_ddp:
                for optimizer in optimizers:
                    if isinstance(optimizer, fairscale.optim.oss.OSS):
                        optimizer.consolidate_state_dict()

            if not distributed_option.distributed or distributed_option.dist_rank == 0:
                # 3. Report the results
                logging.info(reporter.log_message())
                if trainer_options.use_matplotlib:
                    reporter.matplotlib_plot(output_dir / "images")
                if train_summary_writer is not None:
                    reporter.tensorboard_add_scalar(train_summary_writer,
                                                    key1="train")
                    reporter.tensorboard_add_scalar(valid_summary_writer,
                                                    key1="valid")
                if trainer_options.use_wandb:
                    reporter.wandb_log()

                # 4. Save/Update the checkpoint
                torch.save(
                    {
                        "model":
                        model.state_dict(),
                        "reporter":
                        reporter.state_dict(),
                        "optimizers": [o.state_dict() for o in optimizers],
                        "schedulers": [
                            s.state_dict() if s is not None else None
                            for s in schedulers
                        ],
                        "scaler":
                        scaler.state_dict() if scaler is not None else None,
                    },
                    output_dir / "checkpoint.pth",
                )

                # 5. Save and log the model and update the link to the best model
                torch.save(model.state_dict(),
                           output_dir / f"{iepoch}epoch.pth")

                # Creates a sym link latest.pth -> {iepoch}epoch.pth
                p = output_dir / "latest.pth"
                if p.is_symlink() or p.exists():
                    p.unlink()
                p.symlink_to(f"{iepoch}epoch.pth")

                _improved = []
                for _phase, k, _mode in trainer_options.best_model_criterion:
                    # e.g. _phase, k, _mode = "train", "loss", "min"
                    if reporter.has(_phase, k):
                        best_epoch = reporter.get_best_epoch(_phase, k, _mode)
                        # Creates sym links if it's the best result
                        if best_epoch == iepoch:
                            p = output_dir / f"{_phase}.{k}.best.pth"
                            if p.is_symlink() or p.exists():
                                p.unlink()
                            p.symlink_to(f"{iepoch}epoch.pth")
                            _improved.append(f"{_phase}.{k}")
                if len(_improved) == 0:
                    logging.info("There are no improvements in this epoch")
                else:
                    logging.info("The best model has been updated: " +
                                 ", ".join(_improved))

                log_model = (trainer_options.wandb_model_log_interval > 0
                             and iepoch %
                             trainer_options.wandb_model_log_interval == 0)
                if log_model and trainer_options.use_wandb:
                    import wandb

                    logging.info("Logging Model on this epoch :::::")
                    artifact = wandb.Artifact(
                        name=f"model_{wandb.run.id}",
                        type="model",
                        metadata={"improved": _improved},
                    )
                    artifact.add_file(str(output_dir / f"{iepoch}epoch.pth"))
                    aliases = [
                        f"epoch-{iepoch}",
                        "best" if best_epoch == iepoch else "",
                    ]
                    wandb.log_artifact(artifact, aliases=aliases)

                # 6. Remove the model files excluding n-best epoch and latest epoch
                _removed = []
                # Get the union set of the n-best among multiple criterion
                nbests = set().union(*[
                    set(
                        reporter.sort_epochs(ph, k, m)
                        [:max(keep_nbest_models)])
                    for ph, k, m in trainer_options.best_model_criterion
                    if reporter.has(ph, k)
                ])

                # Generated n-best averaged model
                if (trainer_options.nbest_averaging_interval > 0
                        and iepoch % trainer_options.nbest_averaging_interval
                        == 0):
                    average_nbest_models(
                        reporter=reporter,
                        output_dir=output_dir,
                        best_model_criterion=trainer_options.
                        best_model_criterion,
                        nbest=keep_nbest_models,
                        suffix=f"till{iepoch}epoch",
                    )

                for e in range(1, iepoch):
                    p = output_dir / f"{e}epoch.pth"
                    if p.exists() and e not in nbests:
                        p.unlink()
                        _removed.append(str(p))
                if len(_removed) != 0:
                    logging.info("The model files were removed: " +
                                 ", ".join(_removed))

            # 7. If any updating haven't happened, stops the training
            if all_steps_are_invalid:
                logging.warning(
                    "The gradients at all steps are invalid in this epoch. "
                    f"Something seems wrong. This training was stopped at {iepoch}epoch"
                )
                break

            # 8. Check early stopping
            if trainer_options.patience is not None:
                if reporter.check_early_stopping(
                        trainer_options.patience,
                        *trainer_options.early_stopping_criterion):
                    break

        else:
            logging.info(
                f"The training was finished at {trainer_options.max_epoch} epochs "
            )

        # Generated n-best averaged model
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            average_nbest_models(
                reporter=reporter,
                output_dir=output_dir,
                best_model_criterion=trainer_options.best_model_criterion,
                nbest=keep_nbest_models,
            )
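
`cls.resume` is not shown in this excerpt; below is a hypothetical sketch of how the scaler state written to `checkpoint.pth` in step 4 above could be restored (only the checkpoint layout follows the `torch.save` call; the helper itself is an assumption):

import torch

def restore_scaler(checkpoint_path, scaler):
    # Load only the "scaler" entry written by the torch.save(...) call in step 4.
    states = torch.load(checkpoint_path, map_location="cpu")
    if scaler is not None and states.get("scaler") is not None:
        scaler.load_state_dict(states["scaler"])
    return scaler
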
Example #7
class Trainer():
    def __init__(
        self,
        name = 'default',
        results_dir = 'results',
        models_dir = 'models',
        base_dir = './',
        optimizer = 'adam',
        num_workers = None,
        latent_dim = 256,
        image_size = 128,
        num_image_tiles = 8,
        fmap_max = 512,
        transparent = False,
        greyscale = False,
        batch_size = 4,
        gp_weight = 10,
        gradient_accumulate_every = 1,
        attn_res_layers = [],
        freq_chan_attn = False,
        disc_output_size = 5,
        dual_contrast_loss = False,
        antialias = False,
        lr = 2e-4,
        lr_mlp = 1.,
        ttur_mult = 1.,
        save_every = 1000,
        evaluate_every = 1000,
        aug_prob = None,
        aug_types = ['translation', 'cutout'],
        dataset_aug_prob = 0.,
        calculate_fid_every = None,
        calculate_fid_num_images = 12800,
        clear_fid_cache = False,
        is_ddp = False,
        rank = 0,
        world_size = 1,
        log = False,
        amp = False,
        *args,
        **kwargs
    ):
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.fid_dir = base_dir / 'fid' / name

        self.config_path = self.models_dir / name / '.config.json'

        assert is_power_of_two(image_size), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        assert all(map(is_power_of_two, attn_res_layers)), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'

        assert not (dual_contrast_loss and disc_output_size > 1), 'discriminator output size cannot be greater than 1 if using dual contrastive loss'

        self.image_size = image_size
        self.num_image_tiles = num_image_tiles

        self.latent_dim = latent_dim
        self.fmap_max = fmap_max
        self.transparent = transparent
        self.greyscale = greyscale

        assert (int(self.transparent) + int(self.greyscale)) < 2, 'you can only set either transparency or greyscale'

        self.aug_prob = aug_prob
        self.aug_types = aug_types

        self.lr = lr
        self.optimizer = optimizer
        self.num_workers = num_workers
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.gp_weight = gp_weight

        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.attn_res_layers = attn_res_layers
        self.freq_chan_attn = freq_chan_attn

        self.disc_output_size = disc_output_size
        self.antialias = antialias

        self.dual_contrast_loss = dual_contrast_loss

        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None
        self.last_fid = None

        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every
        self.calculate_fid_num_images = calculate_fid_num_images
        self.clear_fid_cache = clear_fid_cache

        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size

        self.syncbatchnorm = is_ddp

        self.amp = amp
        self.G_scaler = GradScaler(enabled = self.amp)
        self.D_scaler = GradScaler(enabled = self.amp)

    @property
    def image_extension(self):
        return 'jpg' if not self.transparent else 'png'

    @property
    def checkpoint_num(self):
        return floor(self.steps // self.save_every)
        
    def init_GAN(self):
        args, kwargs = self.GAN_params

        # set some global variables before instantiating GAN

        global norm_class
        global Blur

        norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
        Blur = nn.Identity if not self.antialias else Blur

        # handle bugs when
        # switching from multi-gpu back to single gpu

        if self.syncbatchnorm and not self.is_ddp:
            import torch.distributed as dist
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = '12355'
            dist.init_process_group('nccl', rank=0, world_size=1)

        # instantiate GAN

        self.GAN = LightweightGAN(
            optimizer=self.optimizer,
            lr = self.lr,
            latent_dim = self.latent_dim,
            attn_res_layers = self.attn_res_layers,
            freq_chan_attn = self.freq_chan_attn,
            image_size = self.image_size,
            ttur_mult = self.ttur_mult,
            fmap_max = self.fmap_max,
            disc_output_size = self.disc_output_size,
            transparent = self.transparent,
            greyscale = self.greyscale,
            rank = self.rank,
            *args,
            **kwargs
        )

        if self.is_ddp:
            ddp_kwargs = {'device_ids': [self.rank], 'output_device': self.rank, 'find_unused_parameters': True}

            self.G_ddp = DDP(self.GAN.G, **ddp_kwargs)
            self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
            self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)

    def write_config(self):
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
        self.image_size = config['image_size']
        self.transparent = config['transparent']
        self.syncbatchnorm = config['syncbatchnorm']
        self.disc_output_size = config['disc_output_size']
        self.greyscale = config.pop('greyscale', False)
        self.attn_res_layers = config.pop('attn_res_layers', [])
        self.freq_chan_attn = config.pop('freq_chan_attn', False)
        self.optimizer = config.pop('optimizer', 'adam')
        self.fmap_max = config.pop('fmap_max', 512)
        del self.GAN
        self.init_GAN()

    def config(self):
        return {
            'image_size': self.image_size,
            'transparent': self.transparent,
            'greyscale': self.greyscale,
            'syncbatchnorm': self.syncbatchnorm,
            'disc_output_size': self.disc_output_size,
            'optimizer': self.optimizer,
            'attn_res_layers': self.attn_res_layers,
            'freq_chan_attn': self.freq_chan_attn
        }

    def set_data_src(self, folder):
        num_workers = default(self.num_workers, math.ceil(NUM_CORES / self.world_size))
        self.dataset = ImageDataset(folder, self.image_size, transparent = self.transparent, greyscale = self.greyscale, aug_prob = self.dataset_aug_prob)
        sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True) if self.is_ddp else None
        dataloader = DataLoader(self.dataset, num_workers = num_workers, batch_size = math.ceil(self.batch_size / self.world_size), sampler = sampler, shuffle = not self.is_ddp, drop_last = True, pin_memory = True)
        self.loader = cycle(dataloader)

        # auto set augmentation prob for user if dataset is detected to be low
        num_samples = len(self.dataset)
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%')

    def train(self):
        assert exists(self.loader), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'
        device = torch.device(f'cuda:{self.rank}')

        if not exists(self.GAN):
            self.init_GAN()

        self.GAN.train()
        total_disc_loss = torch.zeros([], device=device)
        total_gen_loss = torch.zeros([], device=device)

        batch_size = math.ceil(self.batch_size / self.world_size)

        image_size = self.GAN.image_size
        latent_dim = self.GAN.latent_dim

        aug_prob   = default(self.aug_prob, 0)
        aug_types  = self.aug_types
        aug_kwargs = {'prob': aug_prob, 'types': aug_types}

        G = self.GAN.G if not self.is_ddp else self.G_ddp
        D = self.GAN.D if not self.is_ddp else self.D_ddp
        D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp

        apply_gradient_penalty = self.steps % 4 == 0

        # amp related contexts and functions

        amp_context = autocast if self.amp else null_context

        # discriminator loss fn

        if self.dual_contrast_loss:
            D_loss_fn = dual_contrastive_loss
        else:
            D_loss_fn = hinge_loss

        # train discriminator

        self.GAN.D_opt.zero_grad()
        for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, G]):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
            image_batch = next(self.loader).cuda(self.rank)
            image_batch.requires_grad_()

            with amp_context():
                with torch.no_grad():
                    generated_images = G(latents)

                fake_output, fake_output_32x32, _ = D_aug(generated_images, detach = True, **aug_kwargs)

                real_output, real_output_32x32, real_aux_loss = D_aug(image_batch,  calc_aux_loss = True, **aug_kwargs)

                real_output_loss = real_output
                fake_output_loss = fake_output

                divergence = D_loss_fn(real_output_loss, fake_output_loss)
                divergence_32x32 = D_loss_fn(real_output_32x32, fake_output_32x32)
                disc_loss = divergence + divergence_32x32

                aux_loss = real_aux_loss
                disc_loss = disc_loss + aux_loss

            if apply_gradient_penalty:
                outputs = [real_output, real_output_32x32]
                outputs = list(map(self.D_scaler.scale, outputs)) if self.amp else outputs

                scaled_gradients = torch_grad(outputs=outputs, inputs=image_batch,
                                       grad_outputs=list(map(lambda t: torch.ones(t.size(), device = image_batch.device), outputs)),
                                       create_graph=True, retain_graph=True, only_inputs=True)[0]

                inv_scale = safe_div(1., self.D_scaler.get_scale()) if self.amp else 1.

                if inv_scale != float('inf'):
                    gradients = scaled_gradients * inv_scale

                    with amp_context():
                        gradients = gradients.reshape(batch_size, -1)
                        gp =  self.gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

                        if not torch.isnan(gp):
                            disc_loss = disc_loss + gp
                            self.last_gp_loss = gp.clone().detach().item()

            with amp_context():
                disc_loss = disc_loss / self.gradient_accumulate_every

            disc_loss.register_hook(raise_if_nan)
            self.D_scaler.scale(disc_loss).backward()
            total_disc_loss += divergence

        self.last_recon_loss = aux_loss.item()
        self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every)
        self.D_scaler.step(self.GAN.D_opt)
        self.D_scaler.update()

        # generator loss fn

        if self.dual_contrast_loss:
            G_loss_fn = dual_contrastive_loss
            G_requires_calc_real = True
        else:
            G_loss_fn = gen_hinge_loss
            G_requires_calc_real = False

        # train generator

        self.GAN.G_opt.zero_grad()

        for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[G, D_aug]):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

            if G_requires_calc_real:
                image_batch = next(self.loader).cuda(self.rank)
                image_batch.requires_grad_()

            with amp_context():
                generated_images = G(latents)

                fake_output, fake_output_32x32, _ = D_aug(generated_images, **aug_kwargs)
                real_output, real_output_32x32, _ = D_aug(image_batch, **aug_kwargs) if G_requires_calc_real else (None, None, None)

                loss = G_loss_fn(fake_output, real_output)
                loss_32x32 = G_loss_fn(fake_output_32x32, real_output_32x32)

                gen_loss = loss + loss_32x32

                gen_loss = gen_loss / self.gradient_accumulate_every

            gen_loss.register_hook(raise_if_nan)
            self.G_scaler.scale(gen_loss).backward()
            total_gen_loss += loss 

        self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every)
        self.G_scaler.step(self.GAN.G_opt)
        self.G_scaler.update()

        # calculate moving averages

        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors

        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}')
            self.load(self.checkpoint_num)
            raise NanException

        del total_disc_loss
        del total_gen_loss

        # periodically save results

        if self.is_main:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

            if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 20000):
                self.evaluate(floor(self.steps / self.evaluate_every), num_image_tiles = self.num_image_tiles)

            if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
                num_batches = math.ceil(self.calculate_fid_num_images / self.batch_size)
                fid = self.calculate_fid(num_batches)
                self.last_fid = fid

                with open(str(self.results_dir / self.name / f'fid_scores.txt'), 'a') as f:
                    f.write(f'{self.steps},{fid}\n')

        self.steps += 1

    @torch.no_grad()
    def evaluate(self, num = 0, num_image_tiles = 4):
        self.GAN.eval()

        ext = self.image_extension
        num_rows = num_image_tiles
    
        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        # latents and noise

        latents = torch.randn((num_rows ** 2, latent_dim)).cuda(self.rank)

        # regular

        generated_images = self.generate_(self.GAN.G, latents)
        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows)
        
        # moving averages

        generated_images = self.generate_(self.GAN.GE, latents)
        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows)

    @torch.no_grad()
    def generate(self, num=0, num_image_tiles=4, checkpoint=None, types=['default', 'ema']):
        self.GAN.eval()

        latent_dim = self.GAN.latent_dim
        dir_name = self.name + str('-generated-') + str(checkpoint)
        dir_full = Path().absolute() / self.results_dir / dir_name
        ext = self.image_extension

        if not dir_full.exists():
            os.mkdir(dir_full)

        # regular
        if 'default' in types:
            for i in tqdm(range(num_image_tiles), desc='Saving generated default images'):
                latents = torch.randn((1, latent_dim)).cuda(self.rank)
                generated_image = self.generate_(self.GAN.G, latents)
                path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}.{ext}')
                torchvision.utils.save_image(generated_image[0], path, nrow=1)

        # moving averages
        if 'ema' in types:
            for i in tqdm(range(num_image_tiles), desc='Saving generated EMA images'):
                latents = torch.randn((1, latent_dim)).cuda(self.rank)
                generated_image = self.generate_(self.GAN.GE, latents)
                path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}-ema.{ext}')
                torchvision.utils.save_image(generated_image[0], path, nrow=1)

        return dir_full

    @torch.no_grad()
    def show_progress(self, num_images=4, types=['default', 'ema']):
        checkpoints = self.get_checkpoints()
        assert exists(checkpoints), 'cannot find any checkpoints to create a training progress video for'

        dir_name = self.name + '-progress'
        dir_full = Path().absolute() / self.results_dir / dir_name
        ext = self.image_extension
        latents = None

        zfill_length = math.ceil(math.log10(len(checkpoints)))

        if not dir_full.exists():
            os.mkdir(dir_full)

        for checkpoint in tqdm(checkpoints, desc='Generating progress images'):
            self.load(checkpoint, print_version=False)
            self.GAN.eval()

            if checkpoint == 0:
                latents = torch.randn((num_images, self.GAN.latent_dim)).cuda(self.rank)

            # regular
            if 'default' in types:
                generated_image = self.generate_(self.GAN.G, latents)
                path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}.{ext}')
                torchvision.utils.save_image(generated_image, path, nrow=num_images)

            # moving averages
            if 'ema' in types:
                generated_image = self.generate_(self.GAN.GE, latents)
                path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}-ema.{ext}')
                torchvision.utils.save_image(generated_image, path, nrow=num_images)

    @torch.no_grad()
    def calculate_fid(self, num_batches):
        from pytorch_fid import fid_score
        torch.cuda.empty_cache()

        real_path = self.fid_dir / 'real'
        fake_path = self.fid_dir / 'fake'

        # remove any existing files used for fid calculation and recreate directories
        if not real_path.exists() or self.clear_fid_cache:
            rmtree(real_path, ignore_errors=True)
            os.makedirs(real_path)

            for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'):
                real_batch = next(self.loader)
                for k, image in enumerate(real_batch.unbind(0)):
                    ind = k + batch_num * self.batch_size
                    torchvision.utils.save_image(image, real_path / f'{ind}.png')

        # generate a bunch of fake images in results / name / fid_fake

        rmtree(fake_path, ignore_errors=True)
        os.makedirs(fake_path)

        self.GAN.eval()
        ext = self.image_extension

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        for batch_num in tqdm(range(num_batches), desc='calculating FID - saving generated'):
            # latents and noise
            latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank)

            # moving averages
            generated_images = self.generate_(self.GAN.GE, latents)

            for j, image in enumerate(generated_images.unbind(0)):
                ind = j + batch_num * self.batch_size
                torchvision.utils.save_image(image, str(fake_path / f'{str(ind)}-ema.{ext}'))

        return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, latents.device, 2048)

    @torch.no_grad()
    def generate_(self, G, style, num_image_tiles = 8):
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0., 1.)

    @torch.no_grad()
    def generate_interpolation(self, num = 0, num_image_tiles = 8, num_steps = 100, save_frames = False):
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        # latents and noise

        latents_low = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)
        latents_high = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)

        ratios = torch.linspace(0., 8., num_steps)

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            generated_images = self.generate_(self.GAN.GE, interp_latents)
            images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
            pil_image = transforms.ToPILImage()(images_grid.cpu())
            
            if self.transparent:
                background = Image.new('RGBA', pil_image.size, (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)
                
            frames.append(pil_image)

        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)

        if save_frames:
            folder_path = (self.results_dir / self.name / f'{str(num)}')
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f'{str(ind)}.{ext}'))

    def print_log(self):
        data = [
            ('G', self.g_loss),
            ('D', self.d_loss),
            ('GP', self.last_gp_loss),
            ('SS', self.last_recon_loss),
            ('FID', self.last_fid)
        ]

        data = [d for d in data if exists(d[1])]
        log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
        print(log)

    def model_name(self, num):
        return str(self.models_dir / self.name / f'model_{num}.pt')

    def init_folders(self):
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        rmtree(str(self.fid_dir), True)
        rmtree(str(self.config_path), True)
        self.init_folders()

    def save(self, num):
        save_data = {
            'GAN': self.GAN.state_dict(),
            'version': __version__,
            'G_scaler': self.G_scaler.state_dict(),
            'D_scaler': self.D_scaler.state_dict()
        }

        torch.save(save_data, self.model_name(num))
        self.write_config()

    def load(self, num=-1, print_version=True):
        self.load_config()

        name = num
        if num == -1:
            checkpoints = self.get_checkpoints()

            if not exists(checkpoints):
                return

            name = checkpoints[-1]
            print(f'continuing from previous epoch - {name}')

        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        if print_version and 'version' in load_data and self.is_main:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data['GAN'])
        except Exception as e:
            print('unable to load saved model. please try downgrading the package to the version specified by the saved model')
            raise e

        if 'G_scaler' in load_data:
            self.G_scaler.load_state_dict(load_data['G_scaler'])
        if 'D_scaler' in load_data:
            self.D_scaler.load_state_dict(load_data['D_scaler'])

    def get_checkpoints(self):
        file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')]
        saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))

        if len(saved_nums) == 0:
            return None

        return saved_nums
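
The snippet above relies on several small helpers (`exists`, `NanException`, `raise_if_nan`, `slerp`, `evaluate_in_chunks`) that are defined elsewhere in the project. A plausible sketch of them, consistent with how they are called here but not necessarily the project's exact code:

import torch

def exists(val):
    # treat "set" as "not None"
    return val is not None

class NanException(Exception):
    pass

def raise_if_nan(t):
    # gradient hook used above: abort the step as soon as a NaN shows up
    if torch.isnan(t):
        raise NanException

def slerp(val, low, high):
    # spherical interpolation between two batches of latent vectors
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    return (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + \
           (torch.sin(val * omega) / so).unsqueeze(1) * high

def evaluate_in_chunks(max_batch_size, model, *args):
    # run the model on the inputs in chunks of max_batch_size and re-assemble the output
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)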
Example #8
def distributed_worker(device, ngpus_per_node, args):
    torch.cuda.set_device(device)
    cudnn.benchmark = True
    print('%s: Use GPU: %d for training' % (time.ctime(), args.gpu_no[device]))

    rank = args.rank * ngpus_per_node + device
    batch_size = int(args.batch_size / ngpus_per_node)
    num_workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node)

    # init process for distributed training
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=rank)

    # load network
    network, optimizer, scheduler, loss_calculator = load_network(args, device)
    if device == 0:
        summary(network, input_size=(3, 512, 512))

    # load dataset
    dataset = load_dataset(args)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             sampler=sampler,
                                             collate_fn=dataset.collate_fn)

    # gradient scaler for automatic mixed precision
    scaler = GradScaler() if args.amp else None

    # training
    for epoch in range(args.start_epoch, args.end_epoch):
        sampler.set_epoch(epoch)

        # train one epoch
        train_step(dataloader, network, loss_calculator, optimizer, scheduler,
                   scaler, epoch, device, args)

        # adjust learning rate
        scheduler.step()

        # save network
        if rank % ngpus_per_node == 0:
            torch.save(
                {
                    'epoch':
                    epoch + 1,
                    'state_dict':
                    network.module.state_dict()
                    if hasattr(network, 'module') else network.state_dict(),
                    'optimizer':
                    optimizer.state_dict(),
                    'scheduler':
                    scheduler.state_dict(),
                    'scaler':
                    scaler.state_dict() if scaler is not None else None,
                    'loss_log':
                    loss_calculator.log
                },
                os.path.join(args.save_path,
                             'check_point_%d.pth' % (epoch + 1)))

    return None
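
A hypothetical launcher for `distributed_worker` (the real entry point is not part of this snippet): one process is spawned per local GPU, and each process receives its local index as the first argument. `args.num_nodes` is an assumption here; the other fields (`args.rank`, `args.world_size`, `args.dist_url`, ...) are the ones used by the worker above.

import torch
import torch.multiprocessing as mp

def launch(args):
    ngpus_per_node = torch.cuda.device_count()
    # total number of processes across all nodes (args.num_nodes is assumed)
    args.world_size = ngpus_per_node * args.num_nodes
    # each spawned process calls distributed_worker(local_index, ngpus_per_node, args)
    mp.spawn(distributed_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))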
Example #9
        # if epoch%2==0:
        print(str(epoch) + ' ' + curr_train_stage + " loss: " + str(mean_volume_loss) + " time: " + str(mean_time))
        if os.path.isfile(main_folder+'exit_file.txt'):
            torch.cuda.empty_cache()
            sys.exit(0)

        if epoch % 25 == 0:
            torch.save({
                'epoch': epoch,
                'args': args,
                'args_SLNet': argsSLNet,
                'statistics': stats,
                'model_state_dict': net_get_params(net).state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'loss': mean_volume_loss},
                save_folder + '/model_')  # + str(epoch)
        if epoch % 50 == 0:
            torch.save({
                'epoch': epoch,
                'args': args,
                'args_SLNet': argsSLNet,
                'statistics': stats,
                'model_state_dict': net_get_params(net).state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'loss': mean_volume_loss},
                save_folder + '/model_' + str(epoch))
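
To resume from the periodic checkpoint written above, the saved state dicts can be loaded back in mirror-image fashion. A minimal sketch, assuming `net`, `optimizer` and `scaler` have already been constructed with the same settings used for training and that `save_folder` matches the directory used above:

import torch

checkpoint = torch.load(save_folder + '/model_', map_location='cpu')
net_get_params(net).load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scaler.load_state_dict(checkpoint['scaler_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # continue counting from the saved epoch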

    
Example #10
def training(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Dataloader setting
    write_log(logger, "Load data...")
    gc.disable()
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset_dict = {
        'train': torchvision.datasets.CIFAR10(root='./dataset/cifar10', 
            train=True, download=False, transform=transform),
        'valid': torchvision.datasets.CIFAR10(root='./dataset/cifar10', 
            train=False, download=False, transform=transform)
    }
    dataloader_dict = {
        'train': DataLoader(dataset_dict['train'], drop_last=True,
                            batch_size=args.batch_size, shuffle=True, pin_memory=True,
                            num_workers=args.num_workers),
        'valid': DataLoader(dataset_dict['valid'], drop_last=False,
                            batch_size=args.batch_size, shuffle=False, pin_memory=True,
                            num_workers=args.num_workers)
    }
    gc.enable()
    write_log(logger, f"Total number of trainingsets  iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}")

    #===================================#
    #===========Model setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, "Instantiating models...")
    model = Vision_Transformer(n_classes=10, img_size=32, patch_size=16)
    model.train()
    model = model.to(device)

    # 2) Optimizer setting
    # optimizer = AdamW(model.parameters(), lr=args.lr, eps=1e-8)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=1e-8)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    scaler = GradScaler()

    # 3) Model resume
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(os.path.join(args.model_path, 'checkpoint.pth.tar'), map_location='cpu')
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        scaler.load_state_dict(checkpoint['scaler'])
        model = model.train()
        model = model.to(device)
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    best_val_acc = 0

    write_log(logger, 'Train start!')

    for epoch in range(start_epoch, args.num_epochs):

        train_epoch(args, epoch, model, dataloader_dict['train'], optimizer, scheduler, scaler, logger, device)
        val_loss, val_acc = valid_epoch(args, model, dataloader_dict['valid'], device)

        val_loss /= len(dataloader_dict['valid'])
        val_acc /= len(dataloader_dict['valid'])
        write_log(logger, 'Validation Loss: %3.3f' % val_loss)
        write_log(logger, 'Validation Accuracy: %3.2f%%' % val_acc)
        if val_acc > best_val_acc:
            write_log(logger, 'Checkpoint saving...')
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'scaler': scaler.state_dict()
            }, os.path.join(args.model_path, 'checkpoint.pth.tar'))
            best_val_acc = val_acc
            best_epoch = epoch
        else:
            else_log = f'Epoch {best_epoch} is still the best with accuracy {round(best_val_acc, 2)}%...'
            write_log(logger, else_log)

    # 4) Report the best results
    print(f'Best Epoch: {best_epoch}')
    print(f'Best Accuracy: {round(best_val_acc, 2)}')
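
The `train_epoch` function called above is not shown in this snippet. A minimal sketch of what it might look like, assuming a standard cross-entropy objective and a per-step scheduler (the real implementation lives elsewhere in the project):

import torch.nn as nn
from torch.cuda.amp import autocast

def train_epoch(args, epoch, model, dataloader, optimizer, scheduler, scaler, logger, device):
    criterion = nn.CrossEntropyLoss()
    model.train()
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast():
            logits = model(images)
            loss = criterion(logits, labels)
        # scale the loss, step through the scaler, then update the scale factor
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()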
Example #11
                'L_attr': L_attr.item(),
                'L_rec': L_rec.item()
            }, niter)
        writer.add_scalars('Train/Adversarial losses', {
            'Generator': lossG.item(),
            'Discriminator': lossD.item()
        }, niter)
    print(
        f'niter: {niter} (epoch: {epoch} {iteration}/{len(train_dataloader)})')
    print(
        f'    lossD: {lossD.item()} lossG: {lossG.item()} batch_time: {batch_time}s'
    )
    print(
        f'    L_adv: {L_adv.item()} L_id: {L_id.item()} L_attr: {L_attr.item()} L_rec: {L_rec.item()}'
    )
    if iteration % 1000 == 0:
        torch.save(G.state_dict(), './saved_models/G_latest.pth')
        torch.save(D.state_dict(), './saved_models/D_latest.pth')
        torch.save(opt_G.state_dict(), './saved_models/optG_latest.pth')
        torch.save(opt_D.state_dict(), './saved_models/optD_latest.pth')
        torch.save(scaler.state_dict(), './saved_models/scaler_latest.pth')
        with open('./saved_models/niter.pkl', 'wb') as f:
            pickle.dump(niter, f)
    if (niter + 1) % 10000 == 0:
        torch.save(G.state_dict(),
                   f'./saved_models/G_iteration_{niter + 1}.pth')
        torch.save(D.state_dict(),
                   f'./saved_models/D_iteration_{niter + 1}.pth')
        with open(f'./saved_models/niter_{niter + 1}.pkl', 'wb') as f:
            pickle.dump(niter, f)
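
A minimal resume sketch that mirrors the saving logic above, assuming `G`, `D`, `opt_G`, `opt_D` and `scaler` are constructed exactly as during training and that the paths match the ones used above:

import pickle
import torch

G.load_state_dict(torch.load('./saved_models/G_latest.pth'))
D.load_state_dict(torch.load('./saved_models/D_latest.pth'))
opt_G.load_state_dict(torch.load('./saved_models/optG_latest.pth'))
opt_D.load_state_dict(torch.load('./saved_models/optD_latest.pth'))
scaler.load_state_dict(torch.load('./saved_models/scaler_latest.pth'))
with open('./saved_models/niter.pkl', 'rb') as f:
    niter = pickle.load(f)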
Example #12
class Learner(object):
    def __init__(self,
                 model,
                 optimizer,
                 loss_func,
                 name="",
                 scheduler=None,
                 device='cpu'):
        self.model = model
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.scheduler = scheduler
        self.scaler = None
        self.device = device
        self.metric = None
        self.name = name
        self.log = {}
        self.eth = 0.99
        self.do_autocast = False

    def init_amp(self,
                 init_scale=65536.0,
                 growth_factor=2.0,
                 backoff_factor=0.5,
                 growth_interval=2000,
                 enabled=True,
                 do_autocast=True):
        self.do_autocast = do_autocast
        if GradScaler is not None:
            self.scaler = GradScaler(init_scale=init_scale,
                                     growth_factor=growth_factor,
                                     backoff_factor=backoff_factor,
                                     growth_interval=growth_interval,
                                     enabled=enabled)

    def get_y(self, batch):
        # get Y from Batch, the default is batch[-1] but you can overwrite it
        return batch[-1]

    def get_inds(self, batch):
        # get sample indices from batch, the default is batch[-1] but you can overwrite it
        return batch[-1]

    def get_x(self, batch):
        # get x from Batch, the default is batch[:-1] but you can overwrite it
        if isinstance(batch, (list, tuple)):
            return batch[:-1]
        else:
            return [batch]

    def run_model(self, model, batch):
        return model(*(x.to(self.device) for x in self.get_x(batch)))

    def calc_loss(self, y_pred, y_true):
        return self.loss_func(y_pred, y_true.to(self.device))

    def one_cycle(self, batch, train=True, do_step=True):
        device = self.device
        self.preprocess_batch(batch, train)
        y_true = self.get_y(batch)
        if autocast is None:
            y_pred = self.run_model(self.model, batch)
            loss = self.calc_loss(y_pred, y_true)
            loss_item = 0 if np.isnan(loss.item()) else loss.item()
        else:
            with autocast(self.do_autocast):
                y_pred = self.run_model(self.model, batch)
                loss = self.calc_loss(y_pred, y_true)
                loss_item = 0 if np.isnan(loss.item()) else loss.item()
        if train:
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()
            if do_step:
                if self.scaler is not None:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()
            if np.isnan(loss.item()):
                print('got loss = nan')
            loss_item = 0 if np.isnan(loss.item()) else loss.item()
        return loss_item if train else (loss_item, y_pred.to('cpu').detach())

    def one_training_epoch(self, dl, accumulation_steps=1):
        device = self.device
        torch.cuda.empty_cache()
        avg_loss = 0.
        lossf = 0.
        self.model = self.model.train()
        self.model.zero_grad()
        tk0 = notebook.tqdm(dl)
        for i, batch in enumerate(tk0):
            do_step = (i + 1) % accumulation_steps == 0
            loss_item = self.one_cycle(batch, train=True, do_step=do_step)
            e = min(self.eth, 1 - 1.0 / (i + 1.0))
            lossf = e * lossf + (1 - e) * loss_item
            tk0.set_postfix(loss=lossf)
            avg_loss += loss_item / len(dl)
        tk0.disable = False
        tk0.set_postfix(loss=avg_loss)
        tk0.disable = True
        return avg_loss

    def agg_tta(self, y):
        return np.stack(y,0).mean(0) if not isinstance(y[0],tuple)\
               else tuple(np.stack([yy[i] for yy in y],0).mean(0) for i in range(len(y[0])))

    def preprocess_batch(self, batch, train=True):
        return (batch)

    def one_eval_epoch(self, dl, tta=1):
        device = self.device
        avg_loss = 0.
        avg_accuracy = 0.
        lossf = 0
        self.model = self.model.eval()
        predss = []
        with torch.no_grad():
            for t in range(tta):
                pred_list = []
                true_list = []
                tk0 = notebook.tqdm(dl)
                for i, batch in enumerate(tk0):
                    loss_item, y_pred = self.one_cycle(batch,
                                                       train=False,
                                                       do_step=False)
                    pred_list.append(y_pred.to('cpu').numpy() if not isinstance(y_pred,tuple) else\
                        tuple(y.to('cpu').numpy() for y in y_pred))
                    y_batch = self.get_y(batch)
                    true_list.append(y_batch.to('cpu').numpy() if not isinstance(y_batch,tuple) else\
                        tuple(y.to('cpu').numpy() for y in y_batch))
                    e = min(self.eth, 1 - 1.0 / (i + 1.0))
                    lossf = e * lossf + (1 - e) * loss_item
                    tk0.set_postfix(loss=lossf)
                    avg_loss += loss_item / len(dl)
#                 y_true=np.concatenate(true_list,0)
                y_true=np.concatenate(true_list,0) if not isinstance(true_list[0],tuple) else\
                    tuple(np.concatenate([p[i] for p in true_list],0) for i in range(len(true_list[0])))
                predss.append(np.concatenate(pred_list,0) if not isinstance(pred_list[0],tuple) else\
                    tuple(np.concatenate([p[i] for p in pred_list],0) for i in range(len(pred_list[0]))))

            preds = self.agg_tta(predss) if tta > 1 else predss[0]
            m = dict() if self.metric is None else self.metric(preds, y_true)
        tk0.disable = False
        tk0.set_postfix(loss=avg_loss, **m)
        tk0.disable = True
        return avg_loss, m

    def send_log(self, **kwargs):
        log = {'model': self.name}
        log.update(kwargs)
        try:
            sandesh.send(log)
        except Exception:
            print(log)

    def save2log(self, **kwargs):
        for key in kwargs.keys():
            if key not in self.log:
                self.log[key] = []
            self.log[key].append(kwargs[key])

    def evaluate(self, ds, num_workers=8, tta=1, dl_args={'shuffle': False}):
        dl = D.DataLoader(ds, num_workers=num_workers, **dl_args)
        return self.one_eval_epoch(dl, tta=tta)

    def fit(self,
            num_epoches,
            train_ds,
            validate_ds=None,
            batch_size=None,
            lr=None,
            accumulation_steps=1,
            num_workers=8,
            send_log=True,
            eval_batch=None,
            reset_best=False,
            make_best=True,
            tta=1,
            train_dl_args={'shuffle': True},
            val_dl_args={'shuffle': False},
            save_checkpoint='best',
            path=''):
        if batch_size is not None:
            train_dl_args['batch_size'] = batch_size
            val_dl_args['batch_size'] = batch_size
        if eval_batch is not None:
            val_dl_args['batch_size'] = eval_batch

        tq = notebook.tqdm(range(num_epoches))
        if lr is not None:
            self.set_lr(lr)
        if reset_best or not hasattr(self, 'best_metric'):
            self.best_model = None
            self.best_metric = np.inf
        for k, epoch in enumerate(tq):
            self.on_epoch_begin(epoch,
                                train_ds=train_ds,
                                validate_ds=validate_ds)
            dl = D.DataLoader(train_ds,
                              num_workers=num_workers,
                              **train_dl_args)
            if next(self.model.parameters()).device != torch.device('cpu'):
                torch.cuda.empty_cache()
            tavg_loss = self.one_training_epoch(
                dl, accumulation_steps=accumulation_steps)
            #             dl=D.DataLoader(validate_ds, batch_size=batch_size if eval_batch is None else eval_batch,
            #                              num_workers=num_workers,**val_dl_args)
            if validate_ds is not None:
                avg_loss, metric = self.evaluate(validate_ds,
                                                 num_workers=num_workers,
                                                 dl_args=val_dl_args,
                                                 tta=tta)
            else:
                avg_loss = tavg_loss
                metric = {}
            if send_log:
                self.send_log(epoch=epoch,
                              tloss=tavg_loss,
                              loss=avg_loss,
                              **metric)
            self.save2log(epoch=epoch,
                          tloss=tavg_loss,
                          loss=avg_loss,
                          **metric)
            m = avg_loss if 'metric' not in metric.keys() else metric['metric']
            if save_checkpoint == 'last':
                self.save_checkpoint(path)
            if m < self.best_metric:
                self.best_metric = m
                self.best_model = copy.deepcopy(self.model.state_dict())
                tq.set_postfix(best_metric=self.best_metric)
                if save_checkpoint == 'best':
                    self.save_checkpoint(path)
            self.on_epoch_end(epoch)

        print('best metric:', self.best_metric)
        if make_best:
            self.model.load_state_dict(self.best_model)

    def save_model(self, path, name=None):
        name = self.name if name is None else name
        torch.save(self.model.state_dict(), f'{path}{name}')

    def load_model(self, path, name=None, map_location=None):
        name = self.name if name is None else name
        self.model.load_state_dict(
            torch.load(f'{path}{name}', map_location=map_location))

    def save_checkpoint(self, path, name=None):
        name = self.name + '.chk' if name is None else name
        checkpoint = {
            'model': self.model.state_dict(),
            'best_model': self.best_model,
            'best_metric': self.best_metric,
            'optimizer': self.optimizer.state_dict(),
            'log': self.log
        }
        if self.scaler:
            checkpoint['scaler'] = self.scaler.state_dict()
        torch.save(checkpoint, f'{path}{name}')

    def load_checkpoint(self, path, name=None):
        name = self.name + '.chk' if name is None else name
        checkpoint = torch.load(f'{path}{name}')
        self.model.load_state_dict(checkpoint['model'])
        self.best_model = checkpoint['best_model']
        self.best_metric = checkpoint['best_metric']
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.log = checkpoint['log']
        if 'scaler' in checkpoint.keys():
            self.scaler = GradScaler()
            self.scaler.load_state_dict(checkpoint['scaler'])
        else:
            self.scaler = None

    def set_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def on_epoch_begin(self, *args, **kargs):
        pass

    def on_epoch_end(self, *args, **kargs):
        pass

    def predict(self,
                ds,
                batch_size=None,
                num_workers=8,
                dl_args={'shuffle': False},
                return_inds=False,
                return_true=False,
                verbose=True,
                do_eval=True):
        device = self.device
        if batch_size is not None:
            dl_args['batch_size'] = batch_size
        dl = D.DataLoader(ds, num_workers=num_workers, **dl_args)
        pred_list = []
        inds_list = []
        true_list = []
        if do_eval:
            self.model = self.model.eval()
        with torch.no_grad():
            tk0 = notebook.tqdm(dl) if verbose else dl
            for i, batch in enumerate(tk0):
                if autocast is None:
                    y_pred = self.run_model(self.model, batch)
                else:
                    with autocast(self.scaler is not None):
                        y_pred = self.run_model(self.model, batch)
                if return_inds:
                    inds_list.append(self.get_inds(batch).to('cpu').numpy())
                if return_true:
                    yb = self.get_y(batch)
                    true_list.append(yb.to('cpu').numpy() if not isinstance(yb,tuple) else\
                                 tuple(y.to('cpu').numpy() for y in yb))
                pred_list.append(y_pred.to('cpu').numpy() if not isinstance(y_pred,tuple) else\
                                 tuple(y.to('cpu').numpy() for y in y_pred))
        pred = np.concatenate(pred_list,0) if not isinstance(pred_list[0],tuple) else\
                tuple(np.concatenate([p[i] for p in pred_list],0) for i in range(len(pred_list[0])))
        out = ()
        if return_inds:
            out = out + (np.concatenate(inds_list, 0), )
        if return_true:
            rt=np.concatenate(true_list,0) if not isinstance(true_list[0],tuple) else\
                    tuple(np.concatenate([p[i] for p in true_list],0) for i in range(len(true_list[0])))
            out = out + (rt, )

        return pred if len(out) == 0 else (pred, ) + out
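
A hypothetical usage sketch of the `Learner` above, assuming its own imports (`numpy as np`, `from torch.utils import data as D`, `from tqdm import notebook`, `GradScaler`, `autocast`) are in place. The toy model and datasets are placeholders:

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

model = nn.Linear(10, 2).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
learner = Learner(model, optimizer, nn.CrossEntropyLoss(), name='demo', device='cuda')
learner.init_amp()  # enables GradScaler and autocast inside one_cycle

train_ds = TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,)))
valid_ds = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))

learner.fit(5, train_ds, validate_ds=valid_ds, batch_size=32,
            num_workers=0, save_checkpoint='best', path='./')
preds = learner.predict(valid_ds, batch_size=64, num_workers=0)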
Example #13
class BaseModel(nn.Module):
    """ BaseModel

    This is the BaseModel used by all classifiers in this package. This base class provides a basic loop for fitting on a dataset and some convenience functions for storing and loading checkpoints. Each classifier is expected to provide

    - `def forward(self, X)`: A forward method which predicts / applies the model to the given batch `X`. Since a BaseModel inherits from `nn.Module`, please use `self.training` to distinguish between training and testing.

    - `def prepare_backward(self, data, target, weights = None)`: A method which computes the loss for calling backward as well as additional statistics, such as running accuracy. The arguments are:

        - `data`: The examples in this batch
        - `target`: This is the corresponding target batch 
        - `weights`: These are the corresponding per-example weights, if required.
    
        The `prepare_backward` function should return a dictionary with three fields: `prediction`, `backward` and `metrics`. The `prediction` field stores the individual predictions for the batch (in the same order). The `backward` field is used to perform the gradient step and `metrics` is used to store any metrics which should be reported. Formally, the following `backward` call is used:
        
            backward = self.prepare_backward(data, target, weights)
            loss = backward["backward"].mean()
            loss.backward()
        
        Note that the prediction / loss / metrics should be given for each individual example in the batch. __Do not reduce / sum / mean the loss etc manually__. This happens automatically later on. An example would be:

            d = {
                # apply the model
                "prediction" : self(data),  

                # compute the loss
                "backward" : self.loss_function(self(data), target), 

                # Compute some metrics
                "metrics" :
                {
                    "loss" : self.loss_function(self(data), target).detach(), 
                    "accuracy" : 100.0*(self(data).argmax(1) == target).type(self.get_float_type()) 
                } 
            }

    The base class also supports storing and loading of checkpoints. To do so, the implementing class must take care of its own parameters / objects by overriding `restore_state` and `get_state`. Note that these functions __must__ call the respective functions from the base class:

        def restore_state(self,checkpoint):
            # Restore base state
            super().restore_state(checkpoint)

            # Extract any additional parameters from the checkpoint dictionary
            self.my_param = checkpoint["my_param"]

        def get_state(self):
            # Get base state
            state = super().get_state()
            
            return {
                **state,
                "my_param":self.my_param
            } 
        
    This class already expects a fair number of parameters. Thus, it is best to use `args` and `kwargs` to pass parameters between c'tors. The following pattern is used as a best practice to implement a new classifier:

        class MyClass(Model):
            def __init__(self, my_param, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.my_param = my_param

    Attributes:
        optimizer (dict): Dictionary of optimizer and its parameters. This dictionary is expected to have at least two entries

            - `method`: The actual optimizer to be used, e.g. `torch.optim.SGD`
            - `epochs`: The number of epochs used for optimization. If this is not provided, it will be set to 1
            
            Any additional field will be passed to the optimizer object: 

                optimizer_method = optimizer.pop("method")
                epochs = optimizer.pop("epochs", 1)
                the_optimizer = optimizer_method(model.parameters(), **optimizer)
            
            An example would be

                optimizer = {
                    "method" : torch.optim.SGD,
                    "lr" : 1e-2,
                    "epochs" : 150
                }

        scheduler (dict): Dictionary of learning rate scheduler and its parameters. This can be `None` if no scheduling is desired. Otherwise, it is expected to contain a `method` field which is the scheduler. Any additional field will be used to create this object

                scheduler_method = scheduler.pop("method")
                the_scheduler = scheduler_method(the_optimizer, **scheduler)

            An example would be

                scheduler = {
                    "method" : torch.optim.lr_scheduler.StepLR,
                    "step_size" : 25,
                    "gamma": 0.5
                }

        loss_function (function): The loss function which should be minimized. Technically this class does not make use of this function, but only stores it for sub-classes.
        
        base_estimator (function): The (base) neural network which should be trained. Technically this class does not make use of this field, but only stores it for sub-classes.
        
        training_file (str, optional): Filename used to store metrics during training. Is only used if `out_path` is not None. Defaults to "trainings.jsonl"
        
        seed (long): Random seed used in any randomization process
        
        verbose (bool): If `true`, prints the progress of each epoch including metrics via `tqdm` else disables it 
        
        out_path (str, optional): Path to the folder where training metrics should be stored. If no path is given, nothing is stored. Defaults to `None`
        
        test_data (optional): Test data which can be used to compute statistics every `eval_every` epochs. It should be compatible with PyTorch `DataLoader`, e.g. this should be a `torch.utils.data.Dataset` or a numpy array / PyTorch tensor:
        
                test_loader = torch.utils.data.DataLoader(self.test_data, **self.loader_cfg)
        
            Defaults to `None`, which means no additional metrics are computed besides the one already obtained on the training data.
        
        loader (dict, optional): Dictionary of loader parameters which are passed to `torch.utils.data.DataLoader`:
            
                train_loader = torch.utils.data.DataLoader(
                    data,
                    **self.loader
                ) 
        
            The loader is used for both, the training data and `test_data` if supplied. The loader can be `None` which defaults to:
        
                self.loader = {'num_workers': 1, 'pin_memory': True, 'batch_size':128} 
        
        eval_every (int, optional): Evaluates metrics on the test_data every `eval_every` epochs, if `test_data` is provided. Defaults to 5. If this is `None`, no additional metrics are computed.
        
        store_every (int, optional): Stores a checkpoint of the model every `store_every` epochs. If this is `None` no checkpoints are stored
        
        device (str, optional): The device which is used to execute the model. Should be compatible to PyTorch's keywords. Defaults to "cuda"
        
        use_amp (bool): If `true` uses mixed precision provided by PyTorch, else not.
    """
    def __init__(self,
                 optimizer,
                 scheduler,
                 loss_function,
                 base_estimator,
                 training_file="training.jsonl",
                 seed=None,
                 verbose=True,
                 out_path=None,
                 test_data=None,
                 eval_every=5,
                 store_every=None,
                 device="cuda",
                 loader=None,
                 use_amp=False,
                 *args,
                 **kwargs):
        super().__init__()

        if isinstance(
                base_estimator,
                types.LambdaType) and base_estimator.__name__ == "<lambda>":
            print(
                "Warning: base_estimator is a lambda function in Models.py - This is fine, unless you want to store checkpoints of your model. This will likely fail since unnamed functions cannot be pickled. Consider naming it."
            )

        if optimizer is not None:
            optimizer_copy = copy.deepcopy(optimizer)
            self.optimizer_method = optimizer_copy.pop("method")
            if "epochs" in optimizer_copy:
                self.epochs = optimizer_copy.pop("epochs")
            else:
                self.epochs = 1

            self.optimizer_cfg = optimizer_copy
        else:
            self.optimizer_cfg = None

        if scheduler is not None:
            scheduler_copy = copy.deepcopy(scheduler)
            self.scheduler_method = scheduler_copy.pop("method")
            self.scheduler_cfg = scheduler_copy
        else:
            self.scheduler_cfg = None

        if loader is not None:
            self.loader_cfg = loader
        else:
            self.loader_cfg = {
                'num_workers': 1,
                'pin_memory': True,
                'batch_size': 128
            }

        self.base_estimator = base_estimator
        self.loss_function = loss_function
        self.verbose = verbose
        self.out_path = out_path
        self.test_data = test_data
        self.seed = seed
        self.eval_every = eval_every
        self.store_every = store_every
        self.training_file = training_file
        self.cur_epoch = 0
        self.resume_from_checkpoint = False
        self.device = device
        self.use_amp = use_amp

        if self.seed is not None:
            np.random.seed(self.seed)
            random.seed(self.seed)
            torch.manual_seed(self.seed)
            # if you are using GPU
            if self.device != "cpu":
                torch.cuda.manual_seed(self.seed)
                torch.cuda.manual_seed_all(self.seed)

    def get_float_type(self):
        if self.device == "cpu":
            return torch.FloatTensor
        else:
            return torch.cuda.FloatTensor

    def restore_state(self, checkpoint):
        self.optimizer_method = checkpoint["optimizer_method"]
        self.optimizer_cfg = checkpoint["optimizer_cfg"]
        self.scheduler_method = checkpoint["scheduler_method"]
        self.scheduler_cfg = checkpoint["scheduler_cfg"]
        self.loader_cfg = checkpoint["loader_cfg"]
        self.scheduler = checkpoint["scheduler"]
        self.base_estimator = checkpoint["base_estimator"]
        self.loss_function = checkpoint["loss_function"]
        self.verbose = checkpoint["verbose"]
        self.out_path = checkpoint["out_path"]
        self.test_data = checkpoint["test_data"]
        self.seed = checkpoint["seed"]
        self.eval_every = checkpoint["eval_every"]
        self.store_every = checkpoint["store_every"]
        self.training_file = checkpoint["training_file"]
        self.cur_epoch = checkpoint["cur_epoch"]
        self.epochs = checkpoint["epochs"]
        self.resume_from_checkpoint = True
        self.device = checkpoint["device"]
        self.use_amp = checkpoint["use_amp"]
        self.scaler = GradScaler(enabled=self.use_amp)

        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        if self.seed is not None:
            np.random.seed(self.seed)
            random.seed(self.seed)
            torch.manual_seed(self.seed)
            # if you are using GPU
            if self.device != "cpu":
                torch.cuda.manual_seed(self.seed)
                torch.cuda.manual_seed_all(self.seed)

        self.load_state_dict(checkpoint['state_dict'])

        # Load the model to the correct device _before_ we init the optimizer
        # https://github.com/pytorch/pytorch/issues/2830
        self.to(self.device)

        self.optimizer = self.optimizer_method(self.parameters(),
                                               **self.optimizer_cfg)
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.scheduler_method is not None:
            self.scheduler = self.scheduler_method(self.optimizer,
                                                   **self.scheduler_cfg)
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            self.scheduler = None

        if self.loader_cfg is None:
            self.loader_cfg = {
                'num_workers': 1,
                'pin_memory': True,
                'batch_size': 128
            }

    def restore_checkoint(self, path):
        # https://github.com/pytorch/pytorch/issues/2830
        checkpoint = torch.load(path, map_location=self.device)
        self.restore_state(checkpoint)

    def get_state(self):
        return {
            "optimizer_method": self.optimizer_method,
            "optimizer_cfg": self.optimizer_cfg,
            "loader_cfg": self.loader_cfg,
            "scheduler_method": self.scheduler_method,
            "scheduler_cfg": self.scheduler_cfg,
            "scheduler": self.scheduler,
            "base_estimator": self.base_estimator,
            "loss_function": self.loss_function,
            "verbose": self.verbose,
            "out_path": self.out_path,
            "test_data": self.test_data,
            "seed": self.seed,
            "device": self.device,
            "eval_every": self.eval_every,
            "store_every": self.store_every,
            "training_file": self.training_file,
            'cur_epoch': self.cur_epoch,
            'epochs': self.epochs,
            'use_amp': self.use_amp,
            'scaler_state_dict': self.scaler.state_dict()
        }

    def store_checkpoint(self):
        state = self.get_state()
        torch.save(
            state,
            os.path.join(self.out_path, 'model_{}.tar'.format(self.cur_epoch)))

    @abstractmethod
    def forward(self, X):
        pass

    @abstractmethod
    def prepare_backward(self, data, target, weights=None):
        pass

    def fit(self, data):
        if not self.resume_from_checkpoint:
            self.optimizer = self.optimizer_method(self.parameters(),
                                                   **self.optimizer_cfg)

            if self.scheduler_method is not None:
                self.scheduler = self.scheduler_method(self.optimizer,
                                                       **self.scheduler_cfg)
            else:
                self.scheduler = None

            self.scaler = GradScaler(enabled=self.use_amp)

            if self.out_path is not None:
                outfile = open(self.out_path + "/" + self.training_file, "w",
                               1)
        else:
            if self.out_path is not None:
                outfile = open(self.out_path + "/" + self.training_file, "a",
                               1)

        train_loader = torch.utils.data.DataLoader(data,
                                                   shuffle=True,
                                                   **self.loader_cfg)

        self.to(self.device)

        self.train()
        for epoch in range(self.cur_epoch, self.epochs):
            self.cur_epoch = epoch + 1
            metrics = {}
            example_cnt = 0

            with tqdm(total=len(train_loader.dataset),
                      ncols=150,
                      disable=not self.verbose) as pbar:
                self.batch_cnt = 0
                for batch in train_loader:
                    if len(batch) == 1:
                        data = batch
                    else:
                        data = batch[0]

                    data = data.to(self.device)
                    data = Variable(data)

                    if len(batch) > 1:
                        target = batch[1]
                        target = target.to(self.device)
                        target = Variable(target)
                    else:
                        target = None

                    if len(batch) > 2:
                        weights = batch[2]
                        weights = weights.to(self.device)
                        weights = Variable(weights)
                    else:
                        weights = None

                    example_cnt += data.shape[0]

                    self.optimizer.zero_grad()

                    # We assume that prepare_backward computes the appropriate loss and possible some statistics
                    # the user wants to store / output. To do so, prepare_backward should return a dictionary with
                    # three fields. An example is given below. Note that the prediction / loss / metrics should be
                    # given for each individual example in the batch.
                    #    !!!! Do not reduce / sum / mean the loss etc manually !!!!
                    # This is done afterwards in this code.
                    #
                    # d = {
                    #     "prediction" : self(data),
                    #     "backward" : self.loss_function(self(data), target),
                    #     "metrics" :
                    #     {
                    #         "loss" : self.loss_function(self(data), target),
                    #         "accuracy" : 100.0*(self(data).argmax(1) == target).type(self.get_float_type())
                    #     }
                    # }
                    with autocast(enabled=self.use_amp):
                        backward = self.prepare_backward(data, target, weights)
                        loss = backward["backward"].mean()

                    for key, val in backward["metrics"].items():
                        metrics[key] = metrics.get(key, 0) + val.sum().item()

                    self.scaler.scale(loss).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                    mstr = ""
                    for key, val in metrics.items():
                        mstr += "{} {:2.4f} ".format(key, val / example_cnt)

                    pbar.update(data.shape[0])
                    desc = '[{}/{}] {}'.format(epoch, self.epochs - 1, mstr)
                    pbar.set_description(desc)
                    self.batch_cnt += 1

                if self.scheduler is not None:
                    self.scheduler.step()

                #torch.cuda.empty_cache()

                if self.out_path is not None:
                    out_dict = {}

                    mstr = ""
                    for key, val in metrics.items():
                        out_dict["train_" + key] = val / example_cnt
                        mstr += "{} {:2.4f} ".format(key, val / example_cnt)

                    if self.store_every and self.store_every > 0 and (
                            epoch % self.store_every) == 0:
                        self.store_checkpoint()

                    if self.test_data and self.eval_every and self.eval_every > 0 and (
                            epoch % self.eval_every) == 0:
                        # This is basically a different version of apply_in_batches but using the "new" prepare_backward interface
                        # for evaluating the test data. Maybe we should refactor this at some point and / or apply_in_batches
                        # is not really needed anymore as its own function?
                        # TODO Check if refactoring might be interestring here
                        self.eval()

                        test_metrics = {}
                        test_loader = torch.utils.data.DataLoader(
                            self.test_data, **self.loader_cfg)

                        for batch in test_loader:
                            test_data = batch[0]
                            test_target = batch[1]
                            test_data, test_target = test_data.to(
                                self.device), test_target.to(self.device)
                            test_data, test_target = Variable(
                                test_data), Variable(test_target)
                            with torch.no_grad():
                                backward = self.prepare_backward(
                                    test_data, test_target)

                            for key, val in backward["metrics"].items():
                                test_metrics[key] = test_metrics.get(
                                    key, 0) + val.sum().item()

                        self.train()
                        for key, val in test_metrics.items():
                            out_dict["test_" +
                                     key] = val / len(test_loader.dataset)
                            mstr += "test {} {:2.4f} ".format(
                                key, val / len(test_loader.dataset))

                    desc = '[{}/{}] {}'.format(epoch, self.epochs - 1, mstr)
                    pbar.set_description(desc)

                    out_dict["epoch"] = epoch
                    out_file_content = json.dumps(out_dict,
                                                  sort_keys=True) + "\n"
                    outfile.write(out_file_content)

            if hasattr(train_loader.dataset, "end_of_epoch"):
                train_loader.dataset.end_of_epoch()
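
A hypothetical subclass sketch following the contract described in the docstring above. The name `SimpleClassifier` is a placeholder and not part of the package; it assumes `base_estimator` is a factory that builds the network (as the lambda warning in `__init__` suggests) and that `loss_function` is created with `reduction='none'` so it yields per-example losses:

import torch
import torch.nn as nn

class SimpleClassifier(BaseModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # build the wrapped network from the supplied factory
        self.layers = self.base_estimator()

    def forward(self, X):
        return self.layers(X)

    def prepare_backward(self, data, target, weights=None):
        output = self(data)
        # per-example losses; reduction happens in the base class (backward["backward"].mean())
        loss = self.loss_function(output, target)
        if weights is not None:
            loss = loss * weights
        return {
            "prediction": output,
            "backward": loss,
            "metrics": {
                "loss": loss.detach(),
                "accuracy": 100.0 * (output.argmax(1) == target).type(self.get_float_type())
            }
        }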
Example #14
class TrainerLoop:
    def __init__(
        self,
        config: DictConfig,
        model: FlyModel,
        train_dataloader_fn: Callable,
        valid_dataloader_fn: Callable = None,
        test_dataloader_fn: Callable = None
    ):
        """
        Args:
            config: FlyConfig dictionary
            model: must be FlyModel
            dataloader_fn: a Callable function which returns dataloaders
        """
        assert isinstance(model, FlyModel)
        self.config = config
        self.model = model

        # For distributed
        self.rank = int(os.environ.get("RANK", 0))
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.distributed_training = (self.world_size > 1)

        if self.distributed_training and not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
            assert torch.distributed.is_initialized()

        if self.distributed_training and torch.distributed.is_initialized():
            self.node_rank = os.environ.get("NODE_RANK", "N/A")
            logger.info(
                f"Initialized Rank:{torch.distributed.get_rank()} Local-rank: {self.local_rank} on Node:{self.node_rank} Node-name:{socket.gethostname()}"
            )

        logger.info("TrainerLoop is initializing!")

        # set cuda device
        if config.training.num_gpus_per_node > 0:
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device("cuda", self.local_rank)
        else:
            self.device = torch.device("cpu")

        # Setup the dataloaders
        self.train_dataloader = train_dataloader_fn() if train_dataloader_fn else None
        # only rank 0 sets up the validation and test dataloaders
        if self.rank == 0:
            self.validation_dataloader: Iterable = valid_dataloader_fn() if valid_dataloader_fn else None
            self.test_dataloader = test_dataloader_fn() if test_dataloader_fn else None

        # Setup callback handler
        self.callback_handler = CallbackHandler(
            config, trainer=self, callbacks=[], verbose=config.training.logging.level == "DEBUG"
        )

        # constants
        self.fp16 = config.training.fp16
        self.gradient_accumulation_batches = config.training.gradient_accumulation_batches

        self.setup_training_constants()

        # local variables
        self.global_batch_count = 0
        self.global_step_count = 0
        self.epochs_trained = 0
        self.local_step_count = 0

        # Configure optimizers
        self.optimizers, self.schedulers = self.configure_optimizers()

        # Model is sent to GPU or CPU
        self.model = move_to_device(self.model, self.device)

        # Mixed-Precision
        if self.fp16:
            if self.config.training.num_gpus_per_node == 0:
                raise NotImplementedError("For mixed precision training, you need to use GPU!")
            self.configure_fp16()

        # Distributed Training
        if self.world_size > 1:
            self.configure_ddp()

        # Configure all callbacks
        self.configure_callbacks()
        self.callback_handler.fire_event(Events.INITIALIZE)

        # make sure the model has access to trainer info
        self.model.set_trainer(self)

    def setup_training_constants(self):
        self.total_num_update_steps = int(self.config.training.total_num.update_steps)
        self.total_num_batches = self.total_num_update_steps * int(self.gradient_accumulation_batches)
        self.total_num_epochs = int(self.config.training.total_num.epochs)

        # check if training in epoch or update_steps
        if self.total_num_update_steps < 0 and self.total_num_epochs < 0:
            raise NotImplementedError("config.training.total_num.updated_steps must be larger than 0")
        elif self.total_num_update_steps > 0 and self.total_num_epochs > 0:
            raise NotImplementedError(
                "Please only set either config.training.total_num.updated_steps or config.training.total_num.epochs greater than 0"
            )
        elif self.total_num_update_steps > 0 and self.total_num_epochs < 0:
            self.training_in_epoch = False
        elif self.total_num_update_steps < 0 and self.total_num_epochs > 0:
            self.training_in_epoch = True

        # get the number of batches in the dataloader for one epoch
        try:
            self.epoch_num_batches = len(self.train_dataloader)
        except TypeError:
            logger.warning("Cannot determine the length of train_dataloader!")
            self.epoch_num_batches = None

        if self.training_in_epoch:
            if self.epoch_num_batches is not None:
                self.total_num_batches = self.epoch_num_batches * self.total_num_epochs
                self.total_num_update_steps = self.total_num_batches // self.gradient_accumulation_batches
                self.epoch_num_update_steps = self.epoch_num_batches // self.gradient_accumulation_batches
            else:
                # this is set to wait until the epoch finishes first
                self.total_num_update_steps = sys.maxsize

    def configure_optimizers(self):
        return self.model.configure_optimizers(self.total_num_update_steps)

    def configure_callbacks(self):
        # Resume callback runs for all ranks
        self.resume_callback = Resume(self.config)
        self.add_callback(self.resume_callback)

        # For logging and inference, use rank 0 by default
        if self.rank == 0:
            self.log_callback = TrainLogger(self.config)
            self.add_callback(self.log_callback)

            self.eval_callback = Evaluation(self.config)
            self.add_callback(self.eval_callback)

            if self.config.training.console:
                self.console_callback = Console(self.config)
                self.add_callback(self.console_callback)

            self.checkpoint_callback = Checkpoint(self.config)
            self.add_callback(self.checkpoint_callback)

    def configure_fp16(self):
        self.loss_scaler = GradScaler()

    def configure_ddp(self):
        """
        Default distributed training uses reducer for simplicity. 
        """
        # Distributed training (should be after apex fp16 initialization)
        self.distributed_training = True
        self.reducer = Reducer(self.model)
        # for param in self.model.parameters():
        #     dist.broadcast(param.data, 0)

        # self.model = DistributedDataParallel(self.model, delay_allreduce=True)
        # trainer.model = torch.nn.parallel.DistributedDataParallel(
        #     trainer.model, device_ids=[trainer.rank], output_device=trainer.rank, find_unused_parameters=True
        # )

    def train(self):
        # Training begins
        self.callback_handler.fire_event(Events.TRAIN_BEGIN)

        while True:
            self.callback_handler.fire_event(Events.EPOCH_BEGIN)
            self.train_epoch()
            self.callback_handler.fire_event(Events.EPOCH_END)
            self.epochs_trained += 1

            if self.training_in_epoch:
                if self.epochs_trained >= self.total_num_epochs:
                    break
            else:
                if self.global_step_count < self.total_num_update_steps:
                    continue
                else:
                    break

        # Training ends
        self.callback_handler.fire_event(Events.TRAIN_END)

    def train_epoch(self):
        self.optimizer = self.optimizers[0]
        self.scheduler = self.schedulers[0]

        self.local_step_count = 0

        if self.train_dataloader is None:
            return

        for batch in self.train_dataloader:
            self.callback_handler.fire_event(Events.BATCH_BEGIN)

            batch = move_to_device(batch, self.device)
            output = self.backward_batch(batch)

            # Update the model
            if (self.global_batch_count + 1) % self.gradient_accumulation_batches == 0:
                # Update the model with optimizer
                self.step_update(self.model, self.optimizer, self.scheduler)
                self.global_step_count += 1
                self.local_step_count += 1

            self.callback_handler.fire_event(Events.BATCH_END)

            if self.global_step_count >= self.total_num_update_steps:
                break

            self.global_batch_count += 1

    def backward_batch(self, batch):
        self.model.train()
        with torch.cuda.amp.autocast(enabled=self.fp16):
            output = self.model(batch)

        # get the loss from output
        if hasattr(output, "loss"):
            loss = output.loss
        elif isinstance(output, dict):
            loss = output["loss"]
        else:
            raise ValueError("The model output must expose a `loss` attribute or a 'loss' key!")

        if self.gradient_accumulation_batches > 1:
            loss = loss / self.gradient_accumulation_batches

        self.loss_backward(loss)
        return output

    def step_update(self, model, optimizer, scheduler=None):
        """
            self.loss_scaler is defined in `configure_fp16`
        """
        self.callback_handler.fire_event(Events.STEP_BEGIN)
        # collect gradient
        if self.distributed_training:
            self.reducer.reduce()

        gradient_clip = self.config.training.optimization.max_gradient_norm
        # Gradient Clipping
        if gradient_clip > 0:
            if self.fp16:
                self.loss_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        # Update the model
        if self.fp16:
            self.loss_scaler.step(optimizer)
            self.loss_scaler.update()
        else:
            optimizer.step()
        # Step learning rate
        if scheduler:
            scheduler.step()
        # Gradient to zero
        optimizer.zero_grad()
        self.callback_handler.fire_event(Events.STEP_END)

    def loss_backward(self, loss):
        self.callback_handler.fire_event(Events.BACKWARD_BEGIN)
        # Loss backward
        if self.fp16:
            self.loss_scaler.scale(loss).backward()
        else:
            loss.backward()
        self.callback_handler.fire_event(Events.BACKWARD_END)

    def validate(self):
        # Start Validation
        self.model.eval()
        self.model.reset_evaluation_metrics()
        self.callback_handler.fire_event(Events.VALIDATE_BEGIN)
        # No gradient is needed for validation
        with torch.no_grad():
            pbar = tqdm.tqdm(self.validation_dataloader)
            pbar.mininterval = 2.0
            for batch in pbar:
                # send to cuda device
                batch = move_to_device(batch, self.device)
                self.model.predict(batch)

        self.callback_handler.fire_event(Events.VALIDATE_END)

    def test(self):
        # Start Testing
        self.model.eval()
        self.model.reset_evaluation_metrics()
        self.callback_handler.fire_event(Events.TEST_BEGIN)
        # No gradient is needed for test
        with torch.no_grad():
            pbar = tqdm.tqdm(self.test_dataloader)
            pbar.mininterval = 2.0
            for batch in pbar:
                # send to cuda device
                batch = move_to_device(batch, self.device)
                self.model.predict(batch)

        self.callback_handler.fire_event(Events.TEST_END)

    def set_model_state(self, model_state_dict):
        self.model.load_state_dict(model_state_dict)

    def get_model_state(self):
        return self.model.state_dict()

    def set_trainer_state(self, trainer_state_dict):
        self.epochs_trained = trainer_state_dict["epochs_trained"]
        self.global_step_count = trainer_state_dict["global_step_count"]
        self.local_step_count = trainer_state_dict["local_step_count"]

        # Resume the training state
        if self.config.training.resume.resume:
            # Scheduler States
            if self.config.training.resume.resume_scheduler:
                for idx, scheduler in enumerate(self.schedulers):
                    try:
                        scheduler.load_state_dict(trainer_state_dict["schedulers_state_dict"][idx])
                    except Exception:
                        if self.rank == 0:
                            logger.warning(f"Cannot Load Scheduler {idx}'s State!")

            if self.config.training.resume.resume_optimizer:
                for idx, optimizer in enumerate(self.optimizers):
                    try:
                        optimizer.load_state_dict(trainer_state_dict["optimizers_state_dict"][idx])
                    except Exception:
                        if self.rank == 0:
                            logger.warning(f"Cannot Load Optimizer {idx}'s State!")

            # restore amp (GradScaler) state
            if self.fp16:
                self.loss_scaler.load_state_dict(trainer_state_dict["amp_state_dict"])

            # Random States
            if self.config.training.resume.resume_rng_state:
                torch.set_rng_state(trainer_state_dict["cpu_rng_state"])
                trainer_state_dict["cuda_rng_state"] = trainer_state_dict["cuda_rng_state"][:torch.cuda.device_count()]
                torch.cuda.set_rng_state_all(trainer_state_dict["cuda_rng_state"])

            # All Callbacks
            for callback in self.callback_handler.callbacks:
                try:
                    callback.load_state_dict(trainer_state_dict[str(type(callback))])
                except Exception:
                    logger.error(f"{type(callback)} seems not to exist in the checkpoint state!")

    def get_trainer_state(self):
        trainer_state_dict = {
            "epochs_trained": self.epochs_trained,
            "global_step_count": self.global_step_count,
            "local_step_count": self.local_step_count,
            "optimizers_state_dict": [optimizer.state_dict() for optimizer in self.optimizers],
            "schedulers_state_dict": [scheduler.state_dict() for scheduler in self.schedulers],
            "cpu_rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state_all(),
        }

        # save amp states
        if self.fp16:
            trainer_state_dict["amp_state_dict"] = self.loss_scaler.state_dict()

        # All Callbacks
        for callback in self.callback_handler.callbacks:
            trainer_state_dict[str(type(callback))] = callback.state_dict()

        return trainer_state_dict

    def add_callback(self, callback: Callback):
        self.callback_handler.add_callback(callback)


# def get_lr(optimizer):
#     for param_group in optimizer.param_groups:
#         return param_group['lr']

# def get_log_variable(x):
#     if isinstance(x, torch.Tensor):
#         x = x.detach()
#         return x.item()
#     else:
#         raise NotImplementedError
Example #15
class BaseTrainer:
    def __init__(self, dist, rank, config, resume, only_validation, model,
                 loss_function, optimizer):
        self.color_tool = colorful
        self.color_tool.use_style("solarized")

        model = DistributedDataParallel(model.to(rank), device_ids=[rank])
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        # DistributedDataParallel (DDP)
        self.rank = rank
        self.dist = dist

        # Automatic mixed precision (AMP)
        self.use_amp = config["meta"]["use_amp"]
        self.scaler = GradScaler(enabled=self.use_amp)

        # Acoustics
        self.acoustic_config = config["acoustics"]

        # Supported STFT
        n_fft = self.acoustic_config["n_fft"]
        hop_length = self.acoustic_config["hop_length"]
        win_length = self.acoustic_config["win_length"]

        self.torch_stft = partial(stft,
                                  n_fft=n_fft,
                                  hop_length=hop_length,
                                  win_length=win_length)
        self.torch_istft = partial(istft,
                                   n_fft=n_fft,
                                   hop_length=hop_length,
                                   win_length=win_length)
        self.librosa_stft = partial(librosa.stft,
                                    n_fft=n_fft,
                                    hop_length=hop_length,
                                    win_length=win_length)
        self.librosa_istft = partial(librosa.istft,
                                     hop_length=hop_length,
                                     win_length=win_length)

        # Trainer.train in the config
        self.train_config = config["trainer"]["train"]
        self.epochs = self.train_config["epochs"]
        self.save_checkpoint_interval = self.train_config[
            "save_checkpoint_interval"]
        self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"]
        assert self.save_checkpoint_interval >= 1, "Check the 'save_checkpoint_interval' parameter in the config. It should be at least one."

        # Trainer.validation in the config
        self.validation_config = config["trainer"]["validation"]
        self.validation_interval = self.validation_config[
            "validation_interval"]
        self.save_max_metric_score = self.validation_config[
            "save_max_metric_score"]
        assert self.validation_interval >= 1, "Check the 'validation_interval' parameter in the config. It should be at least one."

        # Trainer.visualization in the config
        self.visualization_config = config["trainer"]["visualization"]

        # In the 'train.py' file, if the 'resume' item is 'True', we will update the following args:
        self.start_epoch = 1
        self.best_score = -np.inf if self.save_max_metric_score else np.inf
        self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute(
        ) / config["meta"]["experiment_name"]
        self.checkpoints_dir = self.save_dir / "checkpoints"
        self.logs_dir = self.save_dir / "logs"

        if resume:
            self._resume_checkpoint()

        # Debug validation, which skips training
        self.only_validation = only_validation

        if config["meta"]["preloaded_model_path"]:
            self._preload_model(Path(config["preloaded_model_path"]))

        if self.rank == 0:
            prepare_empty_dir([self.checkpoints_dir, self.logs_dir],
                              resume=resume)

            self.writer = SummaryWriter(self.logs_dir.as_posix(),
                                        max_queue=5,
                                        flush_secs=30)
            self.writer.add_text(
                tag="Configuration",
                text_string=f"<pre>  \n{toml.dumps(config)}  \n</pre>",
                global_step=1)

            print(self.color_tool.cyan("The configurations are as follows: "))
            print(self.color_tool.cyan("=" * 40))
            print(self.color_tool.cyan(toml.dumps(config)[:-1]))  # except "\n"
            print(self.color_tool.cyan("=" * 40))

            with open(
                (self.save_dir /
                 f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(),
                    "w") as handle:
                toml.dump(config, handle)

            self._print_networks([self.model])

    def _preload_model(self, model_path):
        """
        Preload model parameters (in "*.tar" format) at the start of experiment.

        Args:
            model_path (Path): The file path of the *.tar file
        """
        model_path = model_path.expanduser().absolute()
        assert model_path.exists(
        ), f"The file {model_path.as_posix()} does not exist. Please check the path."

        model_checkpoint = torch.load(model_path.as_posix(),
                                      map_location="cpu")
        self.model.load_state_dict(model_checkpoint["model"], strict=False)
        self.model.to(self.rank)

        if self.rank == 0:
            print(
                f"Model preloaded successfully from {model_path.as_posix()}.")

    def _resume_checkpoint(self):
        """
        Resume the experiment from the latest checkpoint.
        """
        latest_model_path = self.checkpoints_dir.expanduser().absolute(
        ) / "latest_model.tar"
        assert latest_model_path.exists(
        ), f"{latest_model_path} does not exist, can not load latest checkpoint."

        # Make sure all processes (GPUs) do not start loading before the saving is finished.
        # see https://stackoverflow.com/questions/59760328/how-does-torch-distributed-barrier-work
        self.dist.barrier()

        # Load it on the CPU and later move it to the device with .to(device)
        # This may be slightly slower than using map_location="cuda:<...>"
        # https://stackoverflow.com/questions/61642619/pytorch-distributed-data-parallel-confusion
        checkpoint = torch.load(latest_model_path.as_posix(),
                                map_location="cpu")

        self.start_epoch = checkpoint["epoch"] + 1
        self.best_score = checkpoint["best_score"]
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scaler.load_state_dict(checkpoint["scaler"])

        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            self.model.module.load_state_dict(checkpoint["model"])
        else:
            self.model.load_state_dict(checkpoint["model"])

        # self.model.to(self.rank)

        if self.rank == 0:
            print(
                f"Model checkpoint loaded. Training will begin at {self.start_epoch} epoch."
            )

    def _save_checkpoint(self, epoch, is_best_epoch=False):
        """
        Save checkpoint to "<save_dir>/<config name>/checkpoints" directory, which consists of:
            - epoch
            - best metric score in historical epochs
            - optimizer parameters
            - model parameters

        Args:
            is_best_epoch (bool): If the model achieves the best metric score in the current epoch (is_best_epoch=True),
                                the checkpoint will be saved as "<save_dir>/checkpoints/best_model.tar".
        """
        print(f"\t Saving {epoch} epoch model checkpoint...")

        state_dict = {
            "epoch": epoch,
            "best_score": self.best_score,
            "optimizer": self.optimizer.state_dict(),
            "scaler": self.scaler.state_dict()
        }

        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            state_dict["model"] = self.model.module.state_dict()
        else:
            state_dict["model"] = self.model.state_dict()

        # Saved in "latest_model.tar"
        # Contains all checkpoint information, including the optimizer parameters, the model parameters, etc.
        # New checkpoint will overwrite the older one.
        torch.save(state_dict,
                   (self.checkpoints_dir / "latest_model.tar").as_posix())

        # "model_{epoch_number}.pth"
        # Contains only model.
        torch.save(state_dict["model"],
                   (self.checkpoints_dir /
                    f"model_{str(epoch).zfill(4)}.pth").as_posix())

        # If the model gets the best metric score in the current epoch (i.e., "is_best_epoch=True"),
        # the model checkpoint will be saved as "best_model.tar".
        # The newer best-scored checkpoint will overwrite the older one.
        if is_best_epoch:
            print(
                self.color_tool.red(
                    f"\t Found a best score in the {epoch} epoch, saving..."))
            torch.save(state_dict,
                       (self.checkpoints_dir / "best_model.tar").as_posix())

    def _is_best_epoch(self, score, save_max_metric_score=True):
        """
        Check if the current model got the best metric score
        """
        if save_max_metric_score and score >= self.best_score:
            self.best_score = score
            return True
        elif not save_max_metric_score and score <= self.best_score:
            self.best_score = score
            return True
        else:
            return False

    @staticmethod
    def _print_networks(models: list):
        print(
            f"This project contains {len(models)} models, the number of the parameters is: "
        )

        params_of_all_networks = 0
        for idx, model in enumerate(models, start=1):
            params_of_network = 0
            for param in model.parameters():
                params_of_network += param.numel()

            print(f"\tNetwork {idx}: {params_of_network / 1e6} million.")
            params_of_all_networks += params_of_network

        print(
            f"The amount of parameters in the project is {params_of_all_networks / 1e6} million."
        )

    def _set_models_to_train_mode(self):
        self.model.train()

    def _set_models_to_eval_mode(self):
        self.model.eval()

    def spec_audio_visualization(self,
                                 noisy,
                                 enhanced,
                                 clean,
                                 name,
                                 epoch,
                                 mark=""):
        self.writer.add_audio(f"{mark}_Speech/{name}_Noisy",
                              noisy,
                              epoch,
                              sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Enhanced",
                              enhanced,
                              epoch,
                              sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Clean",
                              clean,
                              epoch,
                              sample_rate=16000)

        # Visualize the spectrogram of noisy speech, clean speech, and enhanced speech
        noisy_mag, _ = librosa.magphase(
            self.librosa_stft(noisy, n_fft=320, hop_length=160,
                              win_length=320))
        enhanced_mag, _ = librosa.magphase(
            self.librosa_stft(enhanced,
                              n_fft=320,
                              hop_length=160,
                              win_length=320))
        clean_mag, _ = librosa.magphase(
            self.librosa_stft(clean, n_fft=320, hop_length=160,
                              win_length=320))
        fig, axes = plt.subplots(3, 1, figsize=(6, 6))
        for k, mag in enumerate([noisy_mag, enhanced_mag, clean_mag]):
            axes[k].set_title(f"mean: {np.mean(mag):.3f}, "
                              f"std: {np.std(mag):.3f}, "
                              f"max: {np.max(mag):.3f}, "
                              f"min: {np.min(mag):.3f}")
            librosa.display.specshow(librosa.amplitude_to_db(mag),
                                     cmap="magma",
                                     y_axis="linear",
                                     ax=axes[k],
                                     sr=16000)
        plt.tight_layout()
        self.writer.add_figure(f"{mark}_Spectrogram/{name}", fig, epoch)

    def metrics_visualization(self,
                              noisy_list,
                              clean_list,
                              enhanced_list,
                              metrics_list,
                              epoch,
                              num_workers=10,
                              mark=""):
        """
        Compute metrics on the validation dataset in parallel.

        Notes:
            1. You can register other metrics, but the STOI and WB_PESQ metrics must be present. These two metrics are
             used for checking if the current epoch is a "best epoch."
            2. If you want to use a new metric, you must register it in the "util.metrics" file.
        """
        assert "STOI" in metrics_list and "WB_PESQ" in metrics_list, "'STOI' and 'WB_PESQ' must be present."

        # Check if the metric is registered in "util.metrics" file.
        for i in metrics_list:
            assert i in metrics.REGISTERED_METRICS.keys(
            ), f"{i} is not registered, please check 'util.metrics' file."

        stoi_mean = 0.0
        wb_pesq_mean = 0.0
        for metric_name in metrics_list:
            score_on_noisy = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est)
                for ref, est in zip(clean_list, noisy_list))
            score_on_enhanced = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est)
                for ref, est in zip(clean_list, enhanced_list))

            # Add the mean value of the metric to tensorboard
            mean_score_on_noisy = np.mean(score_on_noisy)
            mean_score_on_enhanced = np.mean(score_on_enhanced)
            self.writer.add_scalars(f"{mark}_Validation/{metric_name}", {
                "Noisy": mean_score_on_noisy,
                "Enhanced": mean_score_on_enhanced
            }, epoch)

            if metric_name == "STOI":
                stoi_mean = mean_score_on_enhanced

            if metric_name == "WB_PESQ":
                wb_pesq_mean = transform_pesq_range(mean_score_on_enhanced)

        return (stoi_mean + wb_pesq_mean) / 2

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            if self.rank == 0:
                print(
                    self.color_tool.yellow(
                        f"{'=' * 15} {epoch} epoch {'=' * 15}"))
                print("[0 seconds] Begin training...")

            # [debug validation] Only run validation (only use the first GPU (process))
            # inference + calculating metrics + saving checkpoints
            if self.only_validation and self.rank == 0:
                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(
                        metric_score,
                        save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

                # Skip the following regular training, saving checkpoints, and validation
                continue

            # Regular training
            timer = ExecutionTime()
            self._set_models_to_train_mode()
            self._train_epoch(epoch)

            #  Regular save checkpoints
            if self.rank == 0 and self.save_checkpoint_interval != 0 and (
                    epoch % self.save_checkpoint_interval == 0):
                self._save_checkpoint(epoch)

            # Regular validation
            if self.rank == 0 and (epoch % self.validation_interval == 0):
                print(
                    f"[{timer.duration()} seconds] Training has finished, validation is in progress..."
                )

                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(
                        metric_score,
                        save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

            print(f"[{timer.duration()} seconds] This epoch is finished.")

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _validation_epoch(self, epoch):
        raise NotImplementedError
Example #16
    class Fp16OptimizerHook(OptimizerHook):
        """FP16 optimizer hook (using PyTorch's implementation).

        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
        to take care of the optimization procedure.

        Args:
            loss_scale (float | str | dict): Scale factor configuration.
                If loss_scale is a float, static loss scaling will be used with
                the specified scale. If loss_scale is a string, it must be
                'dynamic', then dynamic loss scaling will be used.
                It can also be a dict containing arguments of GradScaler.
                Defaults to 512. For PyTorch >= 1.6, mmcv uses the official
                implementation of GradScaler. If you use a dict version of
                loss_scale to create GradScaler, please refer to:
                https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
                for the parameters.

        Examples:
            >>> loss_scale = dict(
            ...     init_scale=65536.0,
            ...     growth_factor=2.0,
            ...     backoff_factor=0.5,
            ...     growth_interval=2000
            ... )
            >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
        """

        def __init__(self,
                     grad_clip=None,
                     coalesce=True,
                     bucket_size_mb=-1,
                     loss_scale=512.,
                     distributed=True):
            self.grad_clip = grad_clip
            self.coalesce = coalesce
            self.bucket_size_mb = bucket_size_mb
            self.distributed = distributed
            self._scale_update_param = None
            if loss_scale == 'dynamic':
                self.loss_scaler = GradScaler()
            elif isinstance(loss_scale, float):
                self._scale_update_param = loss_scale
                self.loss_scaler = GradScaler(init_scale=loss_scale)
            elif isinstance(loss_scale, dict):
                self.loss_scaler = GradScaler(**loss_scale)
            else:
                raise ValueError('loss_scale must be of type float, dict, or '
                                 f'"dynamic", got {loss_scale}')

        def before_run(self, runner):
            """Preparing steps before Mixed Precision Training."""
            # wrap model mode to fp16
            wrap_fp16_model(runner.model)
            # resume from state dict
            if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
                scaler_state_dict = runner.meta['fp16']['loss_scaler']
                self.loss_scaler.load_state_dict(scaler_state_dict)

        def copy_grads_to_fp32(self, fp16_net, fp32_weights):
            """Copy gradients from fp16 model to fp32 weight copy."""
            for fp32_param, fp16_param in zip(fp32_weights,
                                              fp16_net.parameters()):
                if fp16_param.grad is not None:
                    if fp32_param.grad is None:
                        fp32_param.grad = fp32_param.data.new(
                            fp32_param.size())
                    fp32_param.grad.copy_(fp16_param.grad)

        def copy_params_to_fp16(self, fp16_net, fp32_weights):
            """Copy updated params from fp32 weight copy to fp16 model."""
            for fp16_param, fp32_param in zip(fp16_net.parameters(),
                                              fp32_weights):
                fp16_param.data.copy_(fp32_param.data)

        def after_train_iter(self, runner):
            """Backward optimization steps for Mixed Precision Training. For
            dynamic loss scaling, please refer to
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.

            1. Scale the loss by a scale factor.
            2. Backward the loss to obtain the gradients.
            3. Unscale the optimizer’s gradient tensors.
            4. Call optimizer.step() and update scale factor.
            5. Save loss_scaler state_dict for resume purpose.
            """
            # clear grads of last iteration
            runner.model.zero_grad()
            runner.optimizer.zero_grad()

            self.loss_scaler.scale(runner.outputs['loss']).backward()
            self.loss_scaler.unscale_(runner.optimizer)
            # grad clip
            if self.grad_clip is not None:
                grad_norm = self.clip_grads(runner.model.parameters())
                if grad_norm is not None:
                    # Add grad norm to the logger
                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                             runner.outputs['num_samples'])
            # backward and update scaler
            self.loss_scaler.step(runner.optimizer)
            self.loss_scaler.update(self._scale_update_param)

            # save state_dict of loss_scaler
            runner.meta.setdefault(
                'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
Example #17
def main(args):
    # ensures that weight initializations are all the same
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    logging = utils.Logger(args.global_rank, args.save)
    writer = utils.Writer(args.global_rank, args.save)

    # Get data loaders.
    train_queue, valid_queue, num_classes = datasets.get_loaders(args)
    args.num_total_iter = len(train_queue) * args.epochs
    warmup_iters = len(train_queue) * args.warmup_epochs
    swa_start = len(train_queue) * (args.epochs - 1)

    arch_instance = utils.get_arch_cells(args.arch_instance)

    model = AutoEncoder(args, writer, arch_instance)
    model = model.cuda()

    logging.info('args = %s', args)
    logging.info('param size = %fM ', utils.count_parameters_in_M(model))
    logging.info('groups per scale: %s, total_groups: %d',
                 model.groups_per_scale, sum(model.groups_per_scale))

    if args.fast_adamax:
        # Fast adamax has the same functionality as torch.optim.Adamax, except it is faster.
        cnn_optimizer = Adamax(model.parameters(),
                               args.learning_rate,
                               weight_decay=args.weight_decay,
                               eps=1e-3)
    else:
        cnn_optimizer = torch.optim.Adamax(model.parameters(),
                                           args.learning_rate,
                                           weight_decay=args.weight_decay,
                                           eps=1e-3)

    cnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        cnn_optimizer,
        float(args.epochs - args.warmup_epochs - 1),
        eta_min=args.learning_rate_min)
    grad_scalar = GradScaler(2**10)

    num_output = utils.num_output(args.dataset)
    bpd_coeff = 1. / np.log(2.) / num_output

    # if load
    checkpoint_file = os.path.join(args.save, 'checkpoint.pt')
    if args.cont_training:
        logging.info('loading the model.')
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        init_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        model = model.cuda()
        cnn_optimizer.load_state_dict(checkpoint['optimizer'])
        grad_scalar.load_state_dict(checkpoint['grad_scalar'])
        cnn_scheduler.load_state_dict(checkpoint['scheduler'])
        global_step = checkpoint['global_step']
    else:
        global_step, init_epoch = 0, 0

    for epoch in range(init_epoch, args.epochs):  # epochs cycle
        # update lrs.
        if args.distributed:
            train_queue.sampler.set_epoch(global_step + args.seed)
            valid_queue.sampler.set_epoch(0)

        if epoch > args.warmup_epochs:
            cnn_scheduler.step()

        # Logging.
        logging.info('epoch %d', epoch)

        # Training.
        train_nelbo, global_step = train(train_queue, model, cnn_optimizer,
                                         grad_scalar, global_step,
                                         warmup_iters, writer, logging)
        logging.info('train_nelbo %f', train_nelbo)
        writer.add_scalar('train/nelbo', train_nelbo, global_step)

        model.eval()
        # generate samples less frequently
        eval_freq = 1 if args.epochs <= 50 else 20
        if epoch % eval_freq == 0 or epoch == (args.epochs - 1):
            with torch.no_grad():
                num_samples = 16
                n = int(np.floor(np.sqrt(num_samples)))
                for t in [0.7, 0.8, 0.9, 1.0]:
                    logits = model.sample(num_samples, t)
                    output = model.decoder_output(logits)
                    output_img = output.mean if isinstance(
                        output, torch.distributions.bernoulli.Bernoulli
                    ) else output.sample(t)
                    output_tiled = utils.tile_image(output_img, n)
                    writer.add_image('generated_%0.1f' % t, output_tiled,
                                     global_step)

            valid_neg_log_p, valid_nelbo = test(valid_queue,
                                                model,
                                                num_samples=10,
                                                args=args,
                                                logging=logging)
            logging.info('valid_nelbo %f', valid_nelbo)
            logging.info('valid neg log p %f', valid_neg_log_p)
            logging.info('valid bpd elbo %f', valid_nelbo * bpd_coeff)
            logging.info('valid bpd log p %f', valid_neg_log_p * bpd_coeff)
            writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch)
            writer.add_scalar('val/nelbo', valid_nelbo, epoch)
            writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff,
                              epoch)
            writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch)

        save_freq = int(np.ceil(args.epochs / 100))
        if epoch % save_freq == 0 or epoch == (args.epochs - 1):
            if args.global_rank == 0:
                logging.info('saving the model.')
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': cnn_optimizer.state_dict(),
                        'global_step': global_step,
                        'args': args,
                        'arch_instance': arch_instance,
                        'scheduler': cnn_scheduler.state_dict(),
                        'grad_scalar': grad_scalar.state_dict()
                    }, checkpoint_file)

    # Final validation
    valid_neg_log_p, valid_nelbo = test(valid_queue,
                                        model,
                                        num_samples=1000,
                                        args=args,
                                        logging=logging)
    logging.info('final valid nelbo %f', valid_nelbo)
    logging.info('final valid neg log p %f', valid_neg_log_p)
    writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch + 1)
    writer.add_scalar('val/nelbo', valid_nelbo, epoch + 1)
    writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch + 1)
    writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch + 1)
    writer.close()
Example #18
class Trainer:
    def __init__(
        self,
        config: DictConfig,
        model: FlyModel,
        name: str = "Trainer1",
        *args,
        **kwargs,
    ):
        """
        One trainer has one model
        Args:
            config: FlyConfig dictionary
            model: must be FlyModel
            dataloader_fn: a Callable function which returns dataloaders
        """
        logger.info("TrainerLoop is initializing!")
        if not isinstance(model, FlyModel):
            logger.warn("model is not defined as FlyModel")
        self.config = config
        self.model = model
        self.trainer_name = name

        # class properties
        self.rank = None
        self.local_rank = None
        self.node_rank = None
        self.world_size = None
        self.distributed_training = None
        self.device = None
        self.gradient_accumulation_batches = None
        self.callback_handler = None
        self.optimizers = []
        self.schedulers = []

        self.init_distributed_environment()

        # make sure the model has access to trainer info
        self.model.set_trainer(self)

        self.callback_handler = CallbackHandler(
            config,
            trainer=self,
            callbacks=[],
            verbose=config.logging.level == "DEBUG")

        # Configure all callbacks
        self.configure_callbacks(config)
        self.callback_handler.fire_event(Events.INITIALIZE)

    def init_distributed_environment(self):
        # For distributed
        self.rank = int(os.environ.get("RANK", 0))
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.distributed_training = self.world_size > 1

        # TODO: add error message when num_gpus is set, but distributed training is False here

        if self.distributed_training and not torch.distributed.is_initialized(
        ):
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")
            assert torch.distributed.is_initialized()

        if self.distributed_training and torch.distributed.is_initialized():
            self.node_rank = os.environ.get("NODE_RANK", "N/A")
            logger.info(
                f"Initialized Rank:{torch.distributed.get_rank()} Local-rank: {self.local_rank} on Node:{self.node_rank} Node-name:{socket.gethostname()}"
            )
            torch.cuda.set_device(self.local_rank)

    def init_device(self, config):
        # set cuda device
        if config.num_gpus_per_node > 0:
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device("cuda", self.local_rank)
        else:
            self.device = torch.device("cpu")

    def init_fp16(self, config):
        if config.num_gpus_per_node == 0:
            raise NotImplementedError(
                "For mixed precision training, you need to use GPU!")
        self.loss_scaler = GradScaler()

    def init_training_constants(self, config):
        self.total_num_update_steps = int(config.total_num.update_steps)
        self.total_num_batches = self.total_num_update_steps * int(
            self.gradient_accumulation_batches)
        self.total_num_epochs = int(config.total_num.epochs)

        # check if training in epoch or update_steps
        if self.total_num_update_steps < 0 and self.total_num_epochs < 0:
            raise NotImplementedError(
                "Either config.total_num.update_steps or config.total_num.epochs must be larger than 0")
        elif self.total_num_update_steps > 0 and self.total_num_epochs > 0:
            raise NotImplementedError(
                "Please set only one of config.total_num.update_steps and config.total_num.epochs greater than 0"
            )
        elif self.total_num_update_steps > 0 and self.total_num_epochs < 0:
            self.training_in_epoch = False
        elif self.total_num_update_steps < 0 and self.total_num_epochs > 0:
            self.training_in_epoch = True

        # get the number of batches in the dataloader for one epoch
        try:
            self.epoch_num_batches = len(self.train_dataloader)
        except TypeError:
            logger.warning("Cannot determine the length of train_dataloader!")
            self.epoch_num_batches = None

        if self.training_in_epoch:
            if self.epoch_num_batches is not None:
                self.total_num_batches = self.epoch_num_batches * self.total_num_epochs
                self.total_num_update_steps = (
                    self.total_num_batches //
                    self.gradient_accumulation_batches)
                self.epoch_num_update_steps = (
                    self.epoch_num_batches //
                    self.gradient_accumulation_batches)
            else:
                # this is set to wait until the epoch finishes first
                self.total_num_update_steps = sys.maxsize

    def configure_optimizers(self,
                             config,
                             total_num_update_steps=None,
                             optimizers=None,
                             schedulers=None):
        if optimizers is not None and schedulers is not None:
            self.optimizers, self.schedulers = optimizers, schedulers
        elif total_num_update_steps is not None:
            self.optimizers, self.schedulers = self.model.configure_optimizers(
                config, total_num_update_steps)
        else:
            raise ValueError("Please provide the correct argument!")
        return self.optimizers, self.schedulers

    def configure_callbacks(self, config):
        # Resume callback runs for all ranks
        if config.resume.enabled:
            self.resume_callback = Resume(config)
            self.add_callback(self.resume_callback)

        self.log_callback = TrainLogger(config)
        self.add_callback(self.log_callback)

        self.eval_callback = Evaluation(config)
        self.add_callback(self.eval_callback)

        # For logging and inference, use rank 0 by default
        if self.rank == 0:
            if config.console:
                self.console_callback = Console(config)
                self.add_callback(self.console_callback)

            if config.checkpointing.enabled:
                self.checkpoint_callback = Checkpoint(config)
                self.add_callback(self.checkpoint_callback)

    def init_distributed_model(self, model):
        """
        Default distributed training uses reducer for simplicity. 
        """
        logger.info("Reducer is intialized!")
        # Distributed training (should be after apex fp16 initialization)
        self.reducer = Reducer(model)
        # for param in self.model.parameters():
        #     dist.broadcast(param.data, 0)

    def train(
        self,
        config,
        train_dataloader,
        validation_dataloader=None,
        test_dataloader=None,
        configure_optimizers=True,
        stage_name: str = "Stage1",
        *args,
        **kwargs,
    ):
        self.config = config
        self.stage_name = stage_name

        # Model is sent to GPU or CPU
        self.init_device(config)
        # self.optimizers, self.schedulers = self.configure_optimizers()

        self.gradient_accumulation_batches = config.gradient_accumulation_batches
        self.max_gradient_norm = config.optimization.max_gradient_norm
        self.fp16 = config.fp16
        self.model = move_to_device(self.model, self.device)
        self.model.device = self.device
        if self.fp16:
            self.init_fp16(config)

        if self.distributed_training:
            self.init_distributed_model(self.model)

        self.total_num_update_steps = 0
        self.total_num_batches = 0
        self.total_num_epochs = 0
        self.epoch_num_batches = 0
        self.global_batch_count = 0
        self.global_step_count = 0
        self.epochs_trained = 0
        self.local_step_count = 0

        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.test_dataloader = test_dataloader

        self.init_training_constants(config)

        if configure_optimizers or len(self.optimizers) == 0:
            self.configure_optimizers(config, self.total_num_update_steps)

        # Training begins
        self.callback_handler.fire_event(Events.TRAIN_BEGIN)

        while True:
            self.callback_handler.fire_event(Events.EPOCH_BEGIN)
            self.train_epoch()
            self.callback_handler.fire_event(Events.EPOCH_END)
            self.epochs_trained += 1

            if self.training_in_epoch:
                if self.epochs_trained >= self.total_num_epochs:
                    break
            else:
                if self.global_step_count < self.total_num_update_steps:
                    continue
                else:
                    break

        # Training ends
        self.callback_handler.fire_event(Events.TRAIN_END)

    def train_epoch(self):
        self.optimizer = self.optimizers[0]
        self.scheduler = self.schedulers[0]

        self.local_step_count = 0

        if self.train_dataloader is None:
            return

        for batch in self.train_dataloader:
            self.callback_handler.fire_event(Events.BATCH_BEGIN)

            batch = move_to_device(batch, self.device)
            output = self.backward_batch(batch)

            # Update the model
            if (self.global_batch_count +
                    1) % self.gradient_accumulation_batches == 0:
                # Update the model with optimizer
                self.step_update(self.model, self.optimizer, self.scheduler)
                self.global_step_count += 1
                self.local_step_count += 1

            self.callback_handler.fire_event(Events.BATCH_END)

            if self.global_step_count >= self.total_num_update_steps:
                break

            self.global_batch_count += 1

    def backward_batch(self, batch):
        self.model.train()
        with torch.cuda.amp.autocast(enabled=self.fp16):
            output = self.model(batch)

        # get the loss from output
        if hasattr(output, "loss"):
            loss = output.loss
        elif isinstance(output, dict):
            loss = output["loss"]
        else:
            raise ValueError("The model output must expose a `loss` attribute or a 'loss' key!")

        if self.gradient_accumulation_batches > 1:
            loss = loss / self.gradient_accumulation_batches

        self.loss_backward(loss)
        return output

    def step_update(self, model, optimizer, scheduler=None):
        """
            self.loss_scaler is defined in `configure_fp16`
        """
        self.callback_handler.fire_event(Events.STEP_BEGIN)
        # collect gradient
        if self.distributed_training:
            self.reducer.reduce()

        gradient_clip = self.max_gradient_norm
        # Gradient Clipping
        if gradient_clip > 0:
            if self.fp16:
                self.loss_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               gradient_clip)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               gradient_clip)
        # Update the model
        if self.fp16:
            self.loss_scaler.step(optimizer)
            self.loss_scaler.update()
        else:
            optimizer.step()
        # Step learning rate
        if scheduler:
            scheduler.step()
        # Gradient to zero
        optimizer.zero_grad()
        self.callback_handler.fire_event(Events.STEP_END)

    def loss_backward(self, loss):
        self.callback_handler.fire_event(Events.BACKWARD_BEGIN)
        # Loss backward
        if self.fp16:
            self.loss_scaler.scale(loss).backward()
        else:
            loss.backward()
        self.callback_handler.fire_event(Events.BACKWARD_END)

    def validate(self, dataloader):
        # Start Validation
        self.model.reset_evaluation_metrics()
        self.callback_handler.fire_event(Events.VALIDATE_BEGIN)
        self.model.validation_loop(dataloader)
        self.callback_handler.fire_event(Events.VALIDATE_END)

    def test(self, dataloader):
        # Start Testing
        self.model.reset_evaluation_metrics()
        self.callback_handler.fire_event(Events.TEST_BEGIN)
        self.model.test_loop(dataloader)
        self.callback_handler.fire_event(Events.TEST_END)

    def set_model_state(self, model_state_dict):
        self.model.load_state_dict(model_state_dict)

    def get_model_state(self):
        return self.model.state_dict()

    def set_trainer_state(self, trainer_state_dict):
        self.trainer_name = trainer_state_dict["trainer_name"]
        self.trainer_stage = trainer_state_dict["trainer_stage"]
        self.epochs_trained = trainer_state_dict["epochs_trained"]
        self.global_step_count = trainer_state_dict["global_step_count"]
        self.local_step_count = trainer_state_dict["local_step_count"]

        # All Callbacks
        for callback in self.callback_handler.callbacks:
            try:
                callback.load_state_dict(trainer_state_dict[str(
                    type(callback))])
            except Exception:
                logger.error(
                    f"{type(callback)} seems not to exist in the checkpoint state!"
                )

        # Resume the training state
        # Scheduler States
        for idx, scheduler in enumerate(self.schedulers):
            try:
                scheduler.load_state_dict(
                    trainer_state_dict["schedulers_state_dict"][idx])
            except Exception:
                if self.rank == 0:
                    logger.warning(f"Cannot Load Scheduler {idx}'s State!")

        for idx, optimizer in enumerate(self.optimizers):
            try:
                optimizer.load_state_dict(
                    trainer_state_dict["optimizers_state_dict"][idx])
            except Exception:
                if self.rank == 0:
                    logger.warning(f"Cannot Load Optimizer {idx}'s State!")

        # restore amp (GradScaler) state
        try:
            if self.fp16:
                self.loss_scaler.load_state_dict(
                    trainer_state_dict["amp_state_dict"])
        except Exception:
            logger.warning("Cannot Load Loss Scaler State!")

        # Random States
        torch.set_rng_state(trainer_state_dict["cpu_rng_state"])
        torch.cuda.set_rng_state_all(
            trainer_state_dict["cuda_rng_state"][:torch.cuda.device_count()])

    def get_trainer_state(self):
        trainer_state_dict = {
            "trainer_name": self.trainer_name,
            "stage_name": self.stage_name,
            "epochs_trained": self.epochs_trained,
            "global_step_count": self.global_step_count,
            "local_step_count": self.local_step_count,
            "optimizers_state_dict": [optimizer.state_dict() for optimizer in self.optimizers],
            "schedulers_state_dict": [scheduler.state_dict() for scheduler in self.schedulers],
            "cpu_rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state_all(),
        }

        # save amp states
        if self.fp16:
            trainer_state_dict["amp_state_dict"] = self.loss_scaler.state_dict(
            )

        # All Callbacks
        for callback in self.callback_handler.callbacks:
            trainer_state_dict[str(type(callback))] = callback.state_dict()

        return trainer_state_dict

    def add_callback(self, callback: Callback):
        self.callback_handler.add_callback(callback)
Example #19
class ClassificationTask(ClassyTask):
    """Basic classification training task.

    This task encapsulates all of the components and steps needed to
    train a classifier using a :class:`classy_vision.trainer.ClassyTrainer`.

    Assumes a train / test phase per each epoch and that the datasets
    have the same API as the map-style Dataset class in
    `torch.utils.data.dataset <https://pytorch.org/docs/stable/data.html
    #torch.utils.data.Dataset>`_ (in particular, this task makes use of
    the len).  If you are using an `IterableDataset <https://pytorch.org/docs/
    stable/data.html#torch.utils.data.IterableDataset>`_ then a custom task
    may be appropriate.


    :var loss: Loss (see :class:`classy_vision.losses.ClassyLoss`) function used
        for computing the loss in each forward pass
    :var datasets: Mapping from a ``phase_type`` in ["train", "test"]
        to dataset used for training (or testing)
    :var meters: List of meters (see :class:`classy_vision.meters.ClassyMeter`)
        to calculate during training
    :var num_epochs: Number of epochs (passes over dataset) to train
    :var test_only: Used to only run the test phase
    :var base_model: Model to be trained, unwrapped in DDP or DP wrappers
    :var optimizer: Optimizer used in train step
    :var optimizer_schedulers: Dictionary. Key is the name of the optimizer
        option (e.g. lr), value is a ClassyParamScheduler
    :var checkpoint: Serializable dict which represents state in training
    :var phases: List of phase specific information, e.g. if phase is
        train / test.
    :var hooks: List of hooks to apply during training
    :var train: Phase type, if true it means we are training,
        false means testing
    :var distributed_model: Base model, but wrapped in DDP (DistributedDataParallel)
    :var phase_idx: Current phase id, first phase is 0, if task has not started
        training then returns -1
    :var train_phase_idx: Only counts train phases
    :var num_updates: Number of total parameter updates applied to model
        by the optimizer
    :var data_iterator: Iterator which can be used to obtain batches
    :var losses: Loss curve
    :var perf_log: list of training speed measurements, to be logged
    :var clip_grad_norm: maximum gradient norm (default None)
    :var simulated_global_batchsize: batch size simulated via gradient accumulation
    :var optimizer_period: apply optimizer after this many steps; derived from
        simulated_global_batchsize, default 1.
    """

    def __init__(self):
        """Constructs a ClassificationTask"""
        super().__init__()

        self.base_loss = None
        self.datasets = {}
        self.meters = []
        self.num_epochs = 1
        self.test_phase_period = 1
        self.train_phases_per_epoch = 0
        self.test_only = False
        self.base_model = None
        self.optimizer = None
        self.optimizer_schedulers = {}
        self.checkpoint_dict = None
        self.checkpoint_path = None
        self.phases = []
        self.hooks = []
        self.train = True
        self.distributed_model = None
        self.distributed_loss = None
        self.phase_idx = -1
        self.train_phase_idx = -1
        self.num_updates = 0
        self.dataloader = None
        self.data_iterator = None
        self.losses = []
        self.broadcast_buffers_mode: BroadcastBuffersMode = (
            BroadcastBuffersMode.BEFORE_EVAL
        )
        self.amp_args = None
        self.amp_type = None
        self.amp_grad_scaler = None
        self.mixup_transform = None
        self.perf_log = []
        self.last_batch = None
        self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
        self.find_unused_parameters = False
        self.use_gpu = torch.cuda.is_available()
        self.dataloader_mp_context = "spawn"
        self.bn_weight_decay = False
        self._train_only = True
        self.clip_grad_norm = None
        self.simulated_global_batchsize = None
        self.optimizer_period = 1
        self.ddp_bucket_cap_mb = 25
        self.use_sharded_ddp = False
        self.fp16_grad_compress = False

    def set_use_sharded_ddp(self, use_sharded_ddp: bool):
        self.use_sharded_ddp = use_sharded_ddp
        if self.use_sharded_ddp:
            logging.info("Using Sharded DDP")
        return self

    def set_use_gpu(self, use_gpu: bool):
        self.use_gpu = use_gpu

        assert (
            not self.use_gpu or torch.cuda.is_available()
        ), "CUDA required to train on GPUs"

        return self

    def set_clip_grad_norm(self, clip_grad_norm: Optional[float]):
        """Sets maximum gradient norm.

        None means gradient clipping is disabled. Defaults to None."""
        self.clip_grad_norm = clip_grad_norm
        if clip_grad_norm is None:
            logging.info("Disabled gradient norm clipping.")
        else:
            logging.info(
                f"Enabled gradient norm clipping with threshold: {clip_grad_norm}"
            )
        return self

    def set_simulated_global_batchsize(self, simulated_global_batchsize: Optional[int]):
        """Sets a simulated batch size by gradient accumulation.

        Gradient accumulation adds up gradients from multiple minibatches and
        steps the optimizer every N train_steps, where N is optimizer_period.
        When enabled, the very last train_steps might end up not updating the
        model, depending on the number of total steps. None means gradient
        accumulation is disabled. Defaults to None."""
        self.simulated_global_batchsize = simulated_global_batchsize
        return self

    def set_checkpoint(self, checkpoint_path: str):
        """Sets checkpoint on task.

        Args:
            checkpoint_path: The path to load the checkpoint from. Can be a file or a
            directory. See :func:`load_checkpoint` for more information.
        """
        self.checkpoint_path = checkpoint_path
        return self

    def _set_checkpoint_dict(self, checkpoint_dict: Dict[str, Any]):
        """Sets the checkpoint dict in the task. Only used for testing.

        Args:
            checkpoint_dict: A serializable dict representing current task state
        """
        self.checkpoint_dict = checkpoint_dict
        return self

    def set_num_epochs(self, num_epochs: Union[int, float]):
        """Set number of epochs to be run.

        Args:
           num_epochs: Number of epochs to run task
        """
        self.num_epochs = num_epochs
        return self

    def set_test_phase_period(self, test_phase_period: int):
        """Set the period of test phase.

        Args:
            test_phase_period: The period of test phase
        """
        self.test_phase_period = test_phase_period
        return self

    def set_dataset(self, dataset: ClassyDataset, phase_type: str):
        """Set dataset for phase type on task

        Args:
            dataset: ClassyDataset for returning samples.
            phase_type: str must be one of "train" or "test"
        """
        assert phase_type in [
            "train",
            "test",
        ], "phase_type must be in ['train', 'test']"
        self.datasets[phase_type] = dataset
        if phase_type == "train":
            self.train_phases_per_epoch = getattr(dataset, "phases_per_epoch", 1)
        else:
            self._train_only = False
        return self

    def set_dataloader_mp_context(self, dataloader_mp_context: Optional[str]):
        """Set the multiprocessing context used by the dataloader.

        The context can be either 'spawn', 'fork', 'forkserver' or None (uses the
        default context). See
        https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_context
        for more details."""

        self.dataloader_mp_context = dataloader_mp_context
        return self

    def set_optimizer(self, optimizer: ClassyOptimizer):
        """Set optimizer for task

        Args:
            optimizer: optimizer for task
        """
        self.optimizer = optimizer
        return self

    def set_loss(self, loss: ClassyLoss):
        """Set loss function for task

        Args:
            loss: loss for task
        """
        self.base_loss = loss
        return self

    def set_meters(self, meters: List["ClassyMeter"]):
        """Set meters for task

        Args:
            meters: list of meters to compute during training
        """
        self.meters = meters
        return self

    def set_distributed_options(
        self,
        broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.BEFORE_EVAL,
        batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED,
        batch_norm_sync_group_size: int = 0,
        find_unused_parameters: bool = False,
        bucket_cap_mb: int = 25,
        fp16_grad_compress: bool = False,
    ):
        """Set distributed options.

        Args:
            broadcast_buffers_mode: Broadcast buffers mode. See
                :class:`BroadcastBuffersMode` for options.
            batch_norm_sync_mode: Batch normalization synchronization mode. See
                :class:`BatchNormSyncMode` for options.
            batch_norm_sync_group_size: Group size to use for synchronized batch norm.
                0 means that the stats are synchronized across all replicas. For
                efficient synchronization, set it to the number of GPUs in a node (
                usually 8).
            find_unused_parameters: See
                :class:`torch.nn.parallel.DistributedDataParallel` for information.
            bucket_cap_mb: See
                :class:`torch.nn.parallel.DistributedDataParallel` for information.
        Raises:
            RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex
                is not installed.
        """
        self.broadcast_buffers_mode = broadcast_buffers_mode

        if batch_norm_sync_group_size > 0:
            if not batch_norm_sync_mode == BatchNormSyncMode.APEX:
                # this should ideally work with PyTorch Sync BN as well, but it
                # fails while initializing DDP for some reason.
                raise ValueError(
                    "batch_norm_sync_group_size can be > 0 only when "
                    "Apex Synchronized Batch Normalization is being used."
                )
        self.batch_norm_sync_group_size = batch_norm_sync_group_size

        if batch_norm_sync_mode == BatchNormSyncMode.DISABLED:
            logging.info("Synchronized Batch Normalization is disabled")
        else:
            if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available:
                raise RuntimeError("apex is not installed")
            msg = f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}"
            if self.batch_norm_sync_group_size > 0:
                msg += f" and group size {batch_norm_sync_group_size}"
            logging.info(msg)
        self.batch_norm_sync_mode = batch_norm_sync_mode

        if find_unused_parameters:
            logging.info("Enabling find_unused_parameters in DDP")

        self.find_unused_parameters = find_unused_parameters
        self.ddp_bucket_cap_mb = bucket_cap_mb

        if fp16_grad_compress:
            if get_torch_version() < [1, 8]:
                raise RuntimeError(
                    "FP16 grad compression is only supported since PyTorch 1.8"
                )
            logging.info("Enabling FP16 grad compression")
        self.fp16_grad_compress = fp16_grad_compress

        return self

    def set_hooks(self, hooks: List["ClassyHook"]):
        """Set hooks for task

        Args:
            hooks: List of hooks to apply during training
        """
        from classy_vision.hooks import ClassyHook

        assert isinstance(hooks, list)
        assert all(isinstance(hook, ClassyHook) for hook in hooks)
        assert len({hook.name() for hook in hooks}) == len(
            hooks
        ), "Cannot have repeated hooks of the same class"
        # TODO (zyan3): we move checkpoint hook to the end of the list because some hooks
        # may change the state of the model, and we want to save changed state in the checkpoint.
        # This is temporary fix.
        non_checkpoint_hooks = [
            hook for hook in hooks if not isinstance(hook, CheckpointHook)
        ]
        checkpoint_hooks = [hook for hook in hooks if isinstance(hook, CheckpointHook)]
        hooks = non_checkpoint_hooks + checkpoint_hooks
        self.hooks = hooks
        return self

    def set_model(self, model: ClassyModel):
        """Set model for task

        Args:
            model: Model to be trained
        """
        self.base_model = model
        return self

    def set_test_only(self, test_only: bool):
        """Set test only flag

        Args:
            test_only: If true, only test phases will be run
        """
        self.test_only = test_only
        return self

    def set_bn_weight_decay(self, bn_weight_decay: bool):
        assert type(bn_weight_decay) == bool

        self.bn_weight_decay = bn_weight_decay
        return self

    def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
        """Disable / enable apex.amp and set the automatic mixed precision parameters.

        apex.amp can be utilized for mixed / half precision training.

        Args:
            amp_args: Dictionary containing arguments to be passed to
            amp.initialize. Set to None to disable amp.  To enable mixed
            precision training, pass amp_args={"opt_level": "O1"} here.
            See https://nvidia.github.io/apex/amp.html for more info.

        Raises:
            RuntimeError: If opt_level is not None and apex is not installed.

        Warning: apex needs to be installed to utilize this feature.
        """
        self.amp_args = amp_args

        if amp_args is None:
            logging.info("AMP disabled")
        else:
            # Check that the requested AMP type is known
            try:
                self.amp_type = AmpType[self.amp_args["amp_type"].upper()]
            except KeyError:
                logging.info("AMP type not specified, defaulting to Apex")
                self.amp_type = AmpType.APEX

            # Check for CUDA availability, required for both Apex and Pytorch AMP
            if not torch.cuda.is_available():
                raise RuntimeError(
                    "AMP is required but CUDA is not supported, cannot enable AMP"
                )

            # Check for Apex availability
            if self.amp_type == AmpType.APEX and not apex_available:
                raise RuntimeError(
                    "Apex AMP is required but Apex is not installed, cannot enable AMP"
                )

            if self.use_sharded_ddp:
                if self.amp_type == AmpType.APEX:
                    raise RuntimeError(
                        "ShardedDDP has been requested, which is incompatible with Apex AMP"
                    )

                if not fairscale_available:
                    raise RuntimeError(
                        "ShardedDDP has been requested, but fairscale is not installed in the current environment"
                    )

            # Set Torch AMP grad scaler, used to prevent gradient underflow
            if self.amp_type == AmpType.PYTORCH:

                if self.use_sharded_ddp:
                    logging.info("Using ShardedGradScaler to manage Pytorch AMP")
                    self.amp_grad_scaler = ShardedGradScaler()
                else:
                    self.amp_grad_scaler = TorchGradScaler()

            logging.info(f"AMP enabled with args {amp_args}")
        return self

    def set_mixup_transform(self, mixup_transform: Optional["MixupTransform"]):
        """Disable / enable mixup transform for data augmentation

        Args:
            mixup_transform: a callable object which performs mixup data augmentation
        """
        self.mixup_transform = mixup_transform
        if mixup_transform is None:
            logging.info("mixup disabled")
        else:
            logging.info("mixup enabled")
        return self

    def set_optimizer_schedulers(self, schedulers):
        self.optimizer_schedulers = schedulers
        return self

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
        """Instantiates a ClassificationTask from a configuration.

        Args:
            config: A configuration for a ClassificationTask.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            A ClassificationTask instance.
        """
        test_only = config.get("test_only", False)
        if not test_only:
            # TODO Make distinction between epochs and phases in optimizer clear
            train_phases_per_epoch = config["dataset"]["train"].get(
                "phases_per_epoch", 1
            )

            optimizer_config = config["optimizer"]
            optimizer_config["num_epochs"] = (
                config["num_epochs"] * train_phases_per_epoch
            )
            optimizer = build_optimizer(optimizer_config)
            param_schedulers = build_optimizer_schedulers(optimizer_config)

        datasets = {}
        phase_types = ["train", "test"]
        for phase_type in phase_types:
            if phase_type in config["dataset"]:
                datasets[phase_type] = build_dataset(config["dataset"][phase_type])
        loss = build_loss(config["loss"])
        amp_args = config.get("amp_args")
        meters = build_meters(config.get("meters", {}))
        model = build_model(config["model"])

        mixup_transform = None
        if config.get("mixup") is not None:
            assert "alpha" in config["mixup"], "key alpha is missing in mixup dict"
            mixup_transform = MixupTransform(
                config["mixup"]["alpha"], config["mixup"].get("num_classes")
            )

        # hooks config is optional
        hooks_config = config.get("hooks")
        hooks = []
        if hooks_config is not None:
            hooks = build_hooks(hooks_config)

        distributed_config = config.get("distributed", {})
        distributed_options = {
            "broadcast_buffers_mode": BroadcastBuffersMode[
                distributed_config.get("broadcast_buffers", "before_eval").upper()
            ],
            "batch_norm_sync_mode": BatchNormSyncMode[
                distributed_config.get("batch_norm_sync_mode", "disabled").upper()
            ],
            "batch_norm_sync_group_size": distributed_config.get(
                "batch_norm_sync_group_size", 0
            ),
            "find_unused_parameters": distributed_config.get(
                "find_unused_parameters", False
            ),
            "bucket_cap_mb": distributed_config.get("bucket_cap_mb", 25),
            "fp16_grad_compress": distributed_config.get("fp16_grad_compress", False),
        }

        task = (
            cls()
            .set_num_epochs(config["num_epochs"])
            .set_test_phase_period(config.get("test_phase_period", 1))
            .set_loss(loss)
            .set_test_only(test_only)
            .set_model(model)
            .set_meters(meters)
            .set_amp_args(amp_args)
            .set_mixup_transform(mixup_transform)
            .set_distributed_options(**distributed_options)
            .set_hooks(hooks)
            .set_bn_weight_decay(config.get("bn_weight_decay", False))
            .set_clip_grad_norm(config.get("clip_grad_norm"))
            .set_simulated_global_batchsize(config.get("simulated_global_batchsize"))
            .set_use_sharded_ddp(config.get("use_sharded_ddp", False))
        )

        if not test_only:
            task.set_optimizer(optimizer)
            task.set_optimizer_schedulers(param_schedulers)

        use_gpu = config.get("use_gpu")
        if use_gpu is not None:
            task.set_use_gpu(use_gpu)

        for phase_type in datasets:
            task.set_dataset(datasets[phase_type], phase_type)

        # NOTE: this is a private member and only meant to be used for
        # logging/debugging purposes. See __repr__ implementation
        task._config = config

        return task

    @property
    def num_batches_per_phase(self):
        """Returns number of batches in current phase iterator"""
        return len(self.data_iterator)

    @property
    def model(self):
        """Returns model used in training (can be wrapped with DDP)"""
        return (
            self.distributed_model if is_distributed_training_run() else self.base_model
        )

    @property
    def loss(self):
        """Returns loss used in training (can be wrapped with DDP)"""
        return self.distributed_loss if self.distributed_loss else self.base_loss

    @property
    def phase_type(self):
        """Returns current phase type. String with value "train" or "test" """
        return "train" if self.train else "test"

    @property
    def eval_phase_idx(self):
        """Returns current evaluation phase"""
        return self.phase_idx - self.train_phase_idx - 1

    def get_total_training_phases(self):
        """
        Returns the total number of "train" phases in the task
        """
        num_training_phases = 0
        for phase in self.phases:
            if phase["train"] is True:
                num_training_phases += 1
        return num_training_phases

    def get_total_test_phases(self):
        """
        Returns the total number of "test" phases in the task
        """
        num_test_phases = 0
        for phase in self.phases:
            if phase["train"] is False:
                num_test_phases += 1
        return num_test_phases

    def _build_phases(self):
        """Returns list of phases from config.

        These phases will look like:
        {
          train: is this a train or test phase?
          optimizer: optimizer settings
        }

        - If this is a test only run, then only test phases will be
        generated
        - If this is a training run with both train and test datasets, then x phases =
          x train phases + x test phases, interleaved. If test_phase_period > 1, test
          phases are only added after test_phase_period train phases. The last phase is
          always a test phase.
        - If this is a training run with only a train dataset, then x phases = x train
          phases.
        """
        if not self.test_only:
            phases = [
                {"train": True}
                for _ in range(math.ceil(self.train_phases_per_epoch * self.num_epochs))
            ]

            if self._train_only:
                return phases

            final_phases = []
            for i, phase in enumerate(phases):
                final_phases.append(phase)
                if (i + 1) % self.test_phase_period == 0:
                    final_phases.append({"train": False})
            if final_phases[-1]["train"]:
                final_phases.append({"train": False})
            return final_phases

        return [{"train": False} for _ in range(self.num_epochs)]

    def build_dataloader_from_dataset(self, dataset, **kwargs):
        """Builds a dataloader from the provided dataset

        Args:
            dataset: A ClassyDataset
            kwargs: Additional kwargs to pass during dataloader construction for
                derived classes
        """
        return dataset.iterator(
            phase_type=self.phase_type,
            current_phase_id=self.train_phase_idx if self.train else 0,
            pin_memory=self.use_gpu and torch.cuda.device_count() > 1,
            multiprocessing_context=mp.get_context(self.dataloader_mp_context),
            **kwargs,
        )

    def build_dataloaders_for_current_phase(self):
        """Builds dataloader(s) for the current phase.

        Deriving classes can override this method to support custom behavior, like
        supporting multiple dataloaders in parallel.
        """
        self.dataloader = self.build_dataloader_from_dataset(
            self.datasets[self.phase_type]
        )

    def prepare_optimizer(self, optimizer, model, loss=None):
        bn_params, other_params = split_batchnorm_params(model)
        if loss is not None:
            bn_params_loss, params_loss = split_batchnorm_params(loss)
            bn_params = bn_params + bn_params_loss
            other_params = other_params + params_loss

        bn_schedulers = self.optimizer_schedulers.copy()
        if not self.bn_weight_decay:
            bn_schedulers["weight_decay"] = 0

        param_groups = [{"params": other_params, **self.optimizer_schedulers}]
        if len(bn_params) > 0:
            param_groups.append({"params": bn_params, **bn_schedulers})
        self.optimizer.set_param_groups(param_groups)

    def prepare(self):
        """Prepares task for training, populates all derived attributes """

        self.phases = self._build_phases()
        self.train = False if self.test_only else self.train

        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
            sync_bn_process_group = apex.parallel.create_syncbn_process_group(
                self.batch_norm_sync_group_size
            )
            self.base_model = apex.parallel.convert_syncbn_model(
                self.base_model, process_group=sync_bn_process_group
            )

        # move the model and loss to the right device
        if self.use_gpu:
            self.base_model, self.base_loss = copy_model_to_gpu(
                self.base_model, self.base_loss
            )
        else:
            self.base_loss.cpu()
            self.base_model.cpu()

        if self.optimizer is not None:
            self.prepare_optimizer(
                optimizer=self.optimizer, model=self.base_model, loss=self.base_loss
            )

        if self.amp_args is not None:
            if self.amp_type == AmpType.APEX:
                # Initialize apex.amp. This updates the model and the PyTorch
                # optimizer (if training, which is wrapped by the ClassyOptimizer
                # in self.optimizer). Note that this must happen before loading
                # the checkpoint, because there is amp state to be restored.
                if self.optimizer is None:
                    self.base_model = apex.amp.initialize(
                        self.base_model, optimizers=None, **self.amp_args
                    )
                else:
                    self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                        self.base_model, self.optimizer.optimizer, **self.amp_args
                    )

        if self.simulated_global_batchsize is not None:
            if self.simulated_global_batchsize % self.get_global_batchsize() != 0:
                raise ValueError(
                    f"Global batch size ({self.get_global_batchsize()}) must divide "
                    f"simulated_global_batchsize ({self.simulated_global_batchsize})"
                )
        else:
            self.simulated_global_batchsize = self.get_global_batchsize()

        self.optimizer_period = (
            self.simulated_global_batchsize // self.get_global_batchsize()
        )
        if self.optimizer_period > 1:
            logging.info(
                f"Using gradient accumulation with a period of {self.optimizer_period}"
            )

        if self.checkpoint_path:
            self.checkpoint_dict = load_and_broadcast_checkpoint(self.checkpoint_path)

        classy_state_dict = (
            None
            if self.checkpoint_dict is None
            else self.checkpoint_dict["classy_state_dict"]
        )

        if classy_state_dict is not None:
            state_load_success = update_classy_state(self, classy_state_dict)
            assert (
                state_load_success
            ), "Update classy state from checkpoint was unsuccessful."

        self.init_distributed_data_parallel_model()

    def init_distributed_data_parallel_model(self):
        """
        Initialize
        `torch.nn.parallel.distributed.DistributedDataParallel <https://pytorch.org/
        docs/stable/nn.html#distributeddataparallel>`_.

        Needed for distributed training. This is where a model should be wrapped by DDP.
        """
        if not is_distributed_training_run():
            return
        assert (
            self.distributed_model is None
        ), "init_ddp_non_elastic must only be called once"

        broadcast_buffers = (
            self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS
        )

        if self.use_sharded_ddp:
            if not isinstance(self.optimizer, ZeRO):
                raise ValueError(
                    "ShardedDataParallel engine should only be used in conjunction with ZeRO optimizer"
                )
            from fairscale.nn.data_parallel import ShardedDataParallel

            # Replace the original DDP wrap by the shard-aware ShardedDDP
            self.distributed_model = ShardedDataParallel(
                module=self.base_model,
                sharded_optimizer=self.optimizer.optimizer,
                broadcast_buffers=broadcast_buffers,
            )
        else:
            self.distributed_model = init_distributed_data_parallel_model(
                self.base_model,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )
            if self.fp16_grad_compress:

                from torch.distributed.algorithms import ddp_comm_hooks

                # FP16 hook is stateless and only takes a process group as the state.
                # We use the default process group so we set the state to None.
                process_group = None
                self.distributed_model.register_comm_hook(
                    process_group,
                    ddp_comm_hooks.default_hooks.fp16_compress_hook,
                )
        if (
            isinstance(self.base_loss, ClassyLoss)
            and self.base_loss.has_learned_parameters()
        ):
            logging.info("Initializing distributed loss")
            self.distributed_loss = init_distributed_data_parallel_model(
                self.base_loss,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )

    @property
    def where(self):
        """Returns the proportion of training that has completed. If in test
        only mode, returns proportion of testing completed

        Returned value is a float in the range [0, 1)
        """
        current_step = self.num_updates / self.get_global_batchsize()
        num_phases = (
            self.get_total_test_phases()
            if self.test_only
            else self.get_total_training_phases()
        )

        if self.num_batches_per_phase <= 0:
            raise RuntimeError("No batches to read. Is the dataset empty?")

        num_steps = num_phases * self.num_batches_per_phase
        where = current_step / num_steps

        return where

    def get_classy_state(self, deep_copy: bool = False):
        """Returns serialiable state of task

        Args:
            deep_copy: If true, does a deep copy of state before returning.
        """
        optimizer_state = {}
        if self.optimizer is not None:
            optimizer_state = self.optimizer.get_classy_state()

        classy_state_dict = {
            "train": self.train,
            "base_model": self.base_model.get_classy_state(),
            "meters": [meter.get_classy_state() for meter in self.meters],
            "optimizer": optimizer_state,
            "phase_idx": self.phase_idx,
            "train_phase_idx": self.train_phase_idx,
            "num_updates": self.num_updates,
            "losses": self.losses,
            "hooks": {hook.name(): hook.get_classy_state() for hook in self.hooks},
            "loss": {},
        }
        if "train" in self.datasets and self._is_checkpointable_dataset(
            self.datasets["train"]
        ):
            classy_state_dict["train_dataset_iterator"] = self.datasets[
                "train"
            ].get_classy_state()

        if isinstance(self.base_loss, ClassyLoss):
            classy_state_dict["loss"] = self.base_loss.get_classy_state()
        if self.amp_args is not None:
            if self.amp_type == AmpType.APEX:
                classy_state_dict["amp"] = apex.amp.state_dict()

            elif self.amp_grad_scaler is not None:
                classy_state_dict["amp"] = self.amp_grad_scaler.state_dict()

        if deep_copy:
            classy_state_dict = copy.deepcopy(classy_state_dict)
        return classy_state_dict

    def set_classy_state(self, state):
        """Set task state

        Args:
            state: Dict containing state of a task
        """
        self.train = False if self.test_only else state["train"]
        self.base_model.set_classy_state(state["base_model"])

        if self.test_only:
            # if we're only testing, just need the state of the model to be updated
            return

        self.phase_idx = state["phase_idx"]
        self.num_updates = state["num_updates"]
        self.train_phase_idx = state["train_phase_idx"]
        self.losses = state["losses"]
        for meter, meter_state in zip(self.meters, state["meters"]):
            meter.set_classy_state(meter_state)

        if self.optimizer is not None:
            self.optimizer.set_classy_state(state["optimizer"])
        if state.get("loss") and isinstance(self.base_loss, ClassyLoss):
            self.base_loss.set_classy_state(state["loss"])

        if "amp" in state:
            if self.amp_type == AmpType.APEX:
                apex.amp.load_state_dict(state["amp"])
            else:
                self.amp_grad_scaler.load_state_dict(state["amp"])

        for hook in self.hooks:
            # we still want to be able to run when new hooks are added or old
            # hooks are removed
            if hook.name() in state["hooks"]:
                hook.set_classy_state(state["hooks"][hook.name()])
            else:
                logging.warning(f"No state found for hook: {hook.name()}")

        if "train" in self.datasets and self._is_checkpointable_dataset(
            self.datasets["train"]
        ):
            self.datasets["train"].set_classy_state(state.get("train_dataset_iterator"))

    @staticmethod
    def _is_checkpointable_dataset(dataset):
        return hasattr(dataset, "get_classy_state") and hasattr(
            dataset, "set_classy_state"
        )

    def eval_step(self):
        self.last_batch = None

        # Process next sample
        with Timer() as timer:
            sample = next(self.data_iterator)

        assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
            f"Returned sample [{sample}] is not a map with 'input' and"
            + "'target' keys"
        )

        target = sample["target"]
        if self.use_gpu:
            sample = recursive_copy_to_gpu(sample, non_blocking=True)

        # Optional Pytorch AMP context
        torch_amp_context = (
            torch.cuda.amp.autocast()
            if self.amp_type == AmpType.PYTORCH
            else contextlib.suppress()
        )

        with torch.no_grad(), torch_amp_context:
            output = self.model(sample["input"])

            local_loss = self.compute_loss(output, sample)

            loss = local_loss.detach().clone()

            self.check_inf_nan(loss)

            self.losses.append(loss.data.cpu().item())

            self.update_meters(output, sample)

        # Move some data to the task so hooks get a chance to access it
        self.last_batch = LastBatchInfo(
            loss=loss,
            output=output,
            target=target,
            sample=sample,
            step_data={"sample_fetch_time": timer.elapsed_time},
        )

    def check_inf_nan(self, loss):
        if loss == float("inf") or loss == float("-inf") or loss != loss:
            raise FloatingPointError(f"Loss is infinity or NaN: {loss}")

    def _should_do_step(self):
        """Tells if we will be performing an optimizer step.

        Returns True always if there is no gradient accumulation. With gradient
        accumulation returns True only when the gradients will be synchronized and we
        will be performing an optimizer step.
        """
        update_idx = self.num_updates // self.get_global_batchsize()
        return (update_idx % self.optimizer_period) == self.optimizer_period - 1

    def train_step(self):
        """Train step to be executed in train loop."""

        self.last_batch = None

        # Process next sample
        with Timer() as timer:
            sample = next(self.data_iterator)

        assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
            f"Returned sample [{sample}] is not a map with 'input' and"
            + "'target' keys"
        )

        # Copy sample to GPU
        target = sample["target"]
        if self.use_gpu:
            sample = recursive_copy_to_gpu(sample, non_blocking=True)

        if self.mixup_transform is not None:
            sample = self.mixup_transform(sample)

        # Optional Pytorch AMP context
        torch_amp_context = (
            torch.cuda.amp.autocast()
            if self.amp_type == AmpType.PYTORCH
            else contextlib.suppress()
        )

        # only sync with DDP when we need to perform an optimizer step
        # an optimizer step can be skipped if gradient accumulation is enabled
        do_step = self._should_do_step()
        ctx_mgr_model = (
            self.distributed_model.no_sync()
            if self.distributed_model is not None and not do_step
            else contextlib.suppress()
        )
        ctx_mgr_loss = (
            self.distributed_loss.no_sync()
            if self.distributed_loss is not None and not do_step
            else contextlib.suppress()
        )

        with ctx_mgr_model, ctx_mgr_loss:
            # Forward pass
            with torch.enable_grad(), torch_amp_context:
                output = self.model(sample["input"])

                local_loss = self.compute_loss(output, sample)
                loss = local_loss.detach().clone()
                self.losses.append(loss.data.cpu().item())

                self.update_meters(output, sample)

            # Backwards pass + optimizer step
            self.run_optimizer(local_loss)

        self.num_updates += self.get_global_batchsize()

        # Move some data to the task so hooks get a chance to access it
        self.last_batch = LastBatchInfo(
            loss=loss,
            output=output,
            target=target,
            sample=sample,
            step_data={"sample_fetch_time": timer.elapsed_time},
        )

    def compute_loss(self, model_output, sample):
        return self.loss(model_output, sample["target"])

    def run_optimizer(self, loss):
        """Runs backwards pass and update the optimizer"""

        self.check_inf_nan(loss)

        # Gradient accumulation logic. We always set optimizer_period, even
        # if gradient accumulation is disabled. Assumes all batches have the
        # same size
        update_idx = self.num_updates // self.get_global_batchsize()
        do_zero_grad = (update_idx % self.optimizer_period) == 0
        do_step = self._should_do_step()

        if do_zero_grad:
            self.optimizer.zero_grad()

        if self.amp_type == AmpType.APEX:
            with apex.amp.scale_loss(loss, self.optimizer.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.amp_type == AmpType.PYTORCH:
            self.amp_grad_scaler.scale(loss).backward()
        else:
            loss.backward()

        if do_step:
            # Handle gradient accumulation related gradient rescaling
            if self.optimizer_period != 1:
                self._rescale_gradients(1 / self.optimizer_period)

            # Clipping must happen after grad accumulation
            if self.clip_grad_norm is not None:
                self._clip_gradients(self.clip_grad_norm)

            if self.amp_type == AmpType.PYTORCH:
                # If using mixed precision, handle underflow-related scaling
                # See https://pytorch.org/docs/stable/amp.html#gradient-scaling
                # for context
                self.amp_grad_scaler.step(self.optimizer, where=self.where)
                self.amp_grad_scaler.update()
            else:
                self.optimizer.step(where=self.where)

    def _rescale_gradients(self, scale):
        for param in master_params(self.optimizer):
            if param.grad is not None:
                param.grad.data.mul_(scale)

    def _clip_gradients(self, max_norm):
        nn.utils.clip_grad_norm_(master_params(self.optimizer), max_norm)

    def update_meters(self, model_output, sample):
        target = sample["target"].detach().cpu()
        model_output = model_output.detach().cpu()

        # Update meters
        for meter in self.meters:
            meter.update(model_output, target, is_train=self.train)

    def synchronize_losses(self):
        """Average the losses across the different replicas"""

        # Average losses across nodes
        losses_tensor = torch.tensor(self.losses)
        synchronized_losses_tensor = all_reduce_mean(losses_tensor)
        self.losses = synchronized_losses_tensor.tolist()

    def advance_phase(self):
        """Performs bookkeeping / task updates between phases

        Increments phase idx, resets meters, resets loss history,
        resets counters, shuffles dataset, rebuilds iterators, and
        sets the train / test state for phase.
        """
        logging.debug("Advancing phase")
        # Reset meters for next phase / epoch
        for meter in self.meters:
            meter.reset()

        # Reset loss history for next epoch
        self.losses = []

        # Setup new phase
        self.phase_idx += 1
        phase = self.phases[self.phase_idx]
        self.train = True if phase["train"] else False
        if self.train:
            self.train_phase_idx += 1

        # Re-build dataloader & re-create iterator anytime membership changes.
        self.build_dataloaders_for_current_phase()
        self.create_data_iterators()
        # Set up pytorch module in train vs eval mode, update optimizer.
        self._set_model_train_mode()

    def done_training(self):
        """Stop condition for training"""
        return self.phase_idx + 1 >= len(self.phases)

    def create_data_iterators(self):
        """Creates data iterator(s) for the current phase."""
        # Delete iterator explicitly so that all dataloader processes
        # are cleaned up.
        del self.data_iterator
        self.data_iterator = iter(self.dataloader)

    def _set_model_train_mode(self):
        """Set train mode for model"""
        phase = self.phases[self.phase_idx]
        self.base_model.train(phase["train"])
        self.base_loss.train(phase["train"])

        if (
            self.broadcast_buffers_mode == BroadcastBuffersMode.BEFORE_EVAL
            and not self.train
        ):
            self._broadcast_buffers()

    def _broadcast_buffers(self):
        """Explicitly synchronize buffers across all devices."""
        if self.distributed_model is None:
            return
        buffers = list(self.base_model.buffers())
        if len(buffers) > 0:
            logging.info("Synchronizing buffers before evaluation.")
            for buffer in buffers:
                broadcast(buffer, 0, group=self.distributed_model.process_group)

    # TODO: Functions below should be better abstracted into the dataloader
    # abstraction
    def get_batchsize_per_replica(self):
        """Return local replica's batchsize for dataset (e.g. batchsize per GPU)"""
        return self.datasets[self.phase_type].get_batchsize_per_replica()

    def get_global_batchsize(self):
        """Return global batchsize across all trainers"""
        return self.datasets[self.phase_type].get_global_batchsize()

    def on_start(self):
        for hook in self.hooks:
            hook.on_start(self)

    def on_phase_start(self):
        self.phase_start_time_total = time.perf_counter()

        self.advance_phase()

        for hook in self.hooks:
            hook.on_phase_start(self)

        self.phase_start_time_train = time.perf_counter()

    def on_phase_end(self):
        self.log_phase_end("train")

        if self.train:
            self.optimizer.on_epoch(where=self.where)

        logging.debug("Syncing losses on phase end...")
        self.synchronize_losses()
        logging.debug("...losses synced")

        logging.debug("Syncing meters on phase end...")
        for meter in self.meters:
            meter.sync_state()
        logging.debug("...meters synced")
        barrier()

        for hook in self.hooks:
            hook.on_phase_end(self)
        self.perf_log = []

        self.log_phase_end("total")

        if hasattr(self.datasets[self.phase_type], "on_phase_end"):
            self.datasets[self.phase_type].on_phase_end()

    def on_end(self):
        for hook in self.hooks:
            hook.on_end(self)

    def log_phase_end(self, tag):
        if not self.train:
            return

        start_time = (
            self.phase_start_time_train
            if tag == "train"
            else self.phase_start_time_total
        )
        phase_duration = time.perf_counter() - start_time
        im_per_sec = (
            self.get_global_batchsize() * self.num_batches_per_phase
        ) / phase_duration
        self.perf_log.append(
            {
                "tag": tag,
                "phase_idx": self.train_phase_idx,
                "epoch_duration": phase_duration,
                "im_per_sec": im_per_sec,
            }
        )

    def __repr__(self):
        if hasattr(self, "_config"):
            config = json.dumps(self._config, indent=4)
            return f"{super().__repr__()} initialized with config:\n{config}"

        return super().__repr__()
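
The gradient-accumulation logic above (optimizer_period derived from simulated_global_batchsize, zero_grad at the start of a window, rescaling and a single optimizer step at its end) can be condensed into a stand-alone sketch. This is a simplified illustration under assumed names (model, loader, loss_fn, optimizer, optimizer_period); it is not the ClassificationTask implementation itself.

import torch

def train_with_accumulation(model, loader, loss_fn, optimizer, optimizer_period):
    model.train()
    for step_idx, (inputs, targets) in enumerate(loader):
        if step_idx % optimizer_period == 0:
            optimizer.zero_grad()              # start of an accumulation window

        loss = loss_fn(model(inputs), targets)
        loss.backward()                        # gradients accumulate across steps

        if step_idx % optimizer_period == optimizer_period - 1:
            # Rescale so the update corresponds to an average over the window,
            # mirroring _rescale_gradients above.
            for param in model.parameters():
                if param.grad is not None:
                    param.grad.mul_(1.0 / optimizer_period)
            optimizer.step()                   # one step per simulated global batch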
Example #20
class DeepvacTrain(Deepvac):
    def __init__(self, deepvac_config):
        deepvac_config.is_forward_only = False
        super(DeepvacTrain, self).__init__(deepvac_config)
        self.initTrainParameters()
        self.initTrainContext()

    def setTrainContext(self):
        self.is_train = True
        self.is_val = False
        self.phase = 'TRAIN'
        self.dataset = self.train_dataset
        self.loader = self.train_loader
        self.batch_size = self.conf.train.batch_size
        self.net.train()

    def setValContext(self):
        self.is_train = False
        self.is_val = True
        self.phase = 'VAL'
        self.dataset = self.val_dataset
        self.loader = self.val_loader
        self.batch_size = self.conf.val.batch_size
        self.net.eval()

    def initTrainContext(self):
        self.scheduler = None
        self.initOutputDir()
        self.initSummaryWriter()
        self.initCriterion()
        self.initOptimizer()
        self.initScheduler()
        self.initCheckpoint()
        self.initTrainLoader()
        self.initValLoader()

    def initTrainParameters(self):
        self.dataset = None
        self.loader = None
        self.target = None
        self.epoch = 0
        self.step = 0
        self.iter = 0
        # Creates a GradScaler once at the beginning of training.
        self.scaler = GradScaler()
        self.train_time = AverageMeter()
        self.load_data_time = AverageMeter()
        self.data_cpu2gpu_time = AverageMeter()
        self._mandatory_member_name = [
            'train_dataset', 'val_dataset', 'train_loader', 'val_loader',
            'net', 'criterion', 'optimizer'
        ]

    def initOutputDir(self):
        if self.conf.output_dir != 'output' and self.conf.output_dir != './output':
            LOG.logW(
                "According deepvac standard, you should set config.output_dir to [output] rather than [{}]."
                .format(self.conf.output_dir))

        self.output_dir = '{}/{}'.format(self.conf.output_dir, self.branch)
        LOG.logI('model save dir: {}'.format(self.output_dir))
        #for DDP race condition
        os.makedirs(self.output_dir, exist_ok=True)

    def initSummaryWriter(self):
        event_dir = "{}/{}".format(self.conf.log_dir, self.branch)
        self.writer = SummaryWriter(event_dir)
        if not self.conf.tensorboard_port:
            return
        from tensorboard import program
        tensorboard = program.TensorBoard()
        self.conf.tensorboard_ip = '0.0.0.0' if self.conf.tensorboard_ip is None else self.conf.tensorboard_ip
        tensorboard.configure(argv=[
            None, '--host',
            str(self.conf.tensorboard_ip), '--logdir', event_dir, "--port",
            str(self.conf.tensorboard_port)
        ])
        try:
            url = tensorboard.launch()
            LOG.logI('Tensorboard at {} '.format(url))
        except Exception as e:
            LOG.logE(e.msg)

    def initCriterion(self):
        self.criterion = torch.nn.CrossEntropyLoss()
        LOG.logW(
            "You should reimplement initCriterion() to initialize self.criterion, unless CrossEntropyLoss() is exactly what you need"
        )

    def initCheckpoint(self):
        if not self.conf.checkpoint_suffix or self.conf.checkpoint_suffix == "":
            LOG.logI('Omit the checkpoint file since not specified...')
            return
        LOG.logI('Load checkpoint from {} folder'.format(self.output_dir))
        self.net.load_state_dict(
            torch.load(self.output_dir +
                       '/model__{}'.format(self.conf.checkpoint_suffix),
                       map_location=self.device))
        state_dict = torch.load(
            self.output_dir +
            '/checkpoint__{}'.format(self.conf.checkpoint_suffix),
            map_location=self.device)
        self.optimizer.load_state_dict(state_dict['optimizer'])
        if self.scheduler:
            self.scheduler.load_state_dict(state_dict['scheduler'])
        if self.conf.amp:
            LOG.logI(
                "Will load scaler from checkpoint since you enabled amp, make sure the checkpoint was saved with amp enabled."
            )
            try:
                self.scaler.load_state_dict(state_dict["scaler"])
            except Exception:
                LOG.logI(
                    "Checkpoint was saved without amp enabled, so a fresh GradScaler will be used instead."
                )
                self.scaler = GradScaler()

        self.epoch = state_dict['epoch']
        if self.conf.ema:
            self.ema.load_state_dict(state_dict['ema'])

    def initScheduler(self):
        if isinstance(self.conf.lr_step, list):
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer, self.conf.lr_step, self.conf.lr_factor)
        elif isinstance(self.conf.lr_step, Callable):
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=self.conf.lr_step)
        else:
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, self.conf.lr_step, self.conf.lr_factor)
        LOG.logW(
            "You should reimplement initScheduler() to initialize self.scheduler, unless lr_scheduler.StepLR() or lr_scheduler.MultiStepLR() is exactly what you need"
        )

    def initTrainLoader(self):
        self.train_loader = None
        LOG.logE(
            "You must reimplement initTrainLoader() to initialize self.train_loader",
            exit=True)

    def initValLoader(self):
        self.val_loader = None
        LOG.logE(
            "You must reimplement initTrainLoader() to initialize self.val_loader",
            exit=True)

    def initOptimizer(self):
        self.initSgdOptimizer()
        LOG.logW(
            "You should reimplement initOptimizer() to initialize self.optimizer, unless SGD is exactly what you need"
        )

    def initSgdOptimizer(self):
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=self.conf.lr,
                                   momentum=self.conf.momentum,
                                   weight_decay=self.conf.weight_decay,
                                   nesterov=self.conf.nesterov)

    def initAdamOptimizer(self):
        self.optimizer = optim.Adam(
            self.net.parameters(),
            lr=self.conf.lr,
            betas=self.conf.betas if self.conf.betas else (0.9, 0.999),
            weight_decay=self.conf.weight_decay
            if self.conf.weight_decay else 0)
        for group in self.optimizer.param_groups:
            group.setdefault('initial_lr', group['lr'])

    def initRmspropOptimizer(self):
        self.optimizer = optim.RMSprop(
            self.net.parameters(),
            lr=self.conf.lr,
            momentum=self.conf.momentum,
            weight_decay=self.conf.weight_decay,
            # alpha=self.conf.rmsprop_alpha,
            # centered=self.conf.rmsprop_centered
        )

    def addScalar(self, tag, value, step):
        self.writer.add_scalar(tag, value, step)

    def addImage(self, tag, image, step):
        self.writer.add_image(tag, image, step)

    @syszux_once
    def addGraph(self, input):
        try:
            self.writer.add_graph(self.net, input)
        except Exception:
            LOG.logW(
                "Tensorboard addGraph failed. Your network forward may take more than one parameter."
            )
            LOG.logW("It seems you need to reimplement the preIter function.")

    def earlyIter(self):
        self.feedSample()
        self.feedTarget()

    def feedSample(self):
        self.sample = self.sample.to(self.device)

    def feedTarget(self):
        self.target = self.target.to(self.device)

    def preIter(self):
        pass

    def postIter(self):
        pass

    def preEpoch(self):
        pass

    def postEpoch(self):
        pass

    def doForward(self):
        self.output = self.net(self.sample)

    def doLoss(self):
        self.loss = self.criterion(self.output, self.target)

    def doBackward(self):
        if self.conf.amp:
            self.scaler.scale(self.loss).backward()
        else:
            self.loss.backward()

    def doOptimize(self):
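        # nominal_batch_factor implements gradient accumulation: the optimizer
        # only steps (and gradients are only cleared) every nominal_batch_factor
        # iterations; other iterations only accumulate gradients via doBackward().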
        if self.iter % self.conf.nominal_batch_factor != 0:
            return
        if self.conf.amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.optimizer.zero_grad()
        if self.conf.ema:
            self.updateEMA()

    def doLog(self):
        if self.step % self.conf.log_every != 0:
            return
        self.addScalar('{}/Loss'.format(self.phase), self.loss.item(),
                       self.iter)
        self.addScalar('{}/LoadDataTime(secs/batch)'.format(self.phase),
                       self.load_data_time.val, self.iter)
        self.addScalar('{}/DataCpu2GpuTime(secs/batch)'.format(self.phase),
                       self.data_cpu2gpu_time.val, self.iter)
        self.addScalar('{}/TrainTime(secs/batch)'.format(self.phase),
                       self.train_time.val, self.iter)
        LOG.logI('{}: [{}][{}/{}] [Loss:{}  Lr:{}]'.format(
            self.phase, self.epoch, self.step, self.loader_len,
            self.loss.item(), self.optimizer.param_groups[0]['lr']))

    def saveState(self, current_time):
        file_partial_name = '{}__acc_{}__epoch_{}__step_{}__lr_{}'.format(
            current_time, self.accuracy, self.epoch, self.step,
            self.optimizer.param_groups[0]['lr'])
        state_file = '{}/model__{}.pth'.format(self.output_dir,
                                               file_partial_name)
        checkpoint_file = '{}/checkpoint__{}.pth'.format(
            self.output_dir, file_partial_name)
        output_trace_file = '{}/trace__{}.pt'.format(self.output_dir,
                                                     file_partial_name)
        output_script_file = '{}/script__{}.pt'.format(self.output_dir,
                                                       file_partial_name)
        output_onnx_file = '{}/onnx__{}.onnx'.format(self.output_dir,
                                                     file_partial_name)
        output_ncnn_file = '{}/ncnn__{}.bin'.format(self.output_dir,
                                                    file_partial_name)
        output_coreml_file = '{}/coreml__{}.mlmodel'.format(
            self.output_dir, file_partial_name)
        #save state_dict
        net = self.ema if self.conf.ema else self.net
        torch.save(net.state_dict(), state_file)
        #save checkpoint
        torch.save(
            {
                'optimizer': self.optimizer.state_dict(),
                'epoch': self.epoch,
                'scheduler':
                self.scheduler.state_dict() if self.scheduler else None,
                'ema': self.ema.state_dict() if self.conf.ema else None,
                'scaler': self.scaler.state_dict() if self.conf.amp else None
            }, checkpoint_file)

        self.exportTorchViaTrace(self.sample, output_trace_file)
        self.exportTorchViaScript(output_script_file)
        self.exportONNX(self.sample, output_onnx_file)
        self.exportNCNN(self.sample, output_ncnn_file)
        self.exportCoreML(self.sample, output_coreml_file)
        #tensorboard
        self.addScalar('{}/Accuracy'.format(self.phase), self.accuracy,
                       self.iter)

    def processTrain(self):
        self.setTrainContext()
        self.step = 0
        LOG.logI('Phase {} started...'.format(self.phase))
        self.loader_len = len(self.loader)
        save_every = self.loader_len // self.conf.save_num
        save_list = list(range(0, self.loader_len + 1, save_every))
        self.save_list = save_list[1:-1]
        LOG.logI('Model will be saved on step {} and the epoch end.'.format(
            self.save_list))
        self.addScalar('{}/LR'.format(self.phase),
                       self.optimizer.param_groups[0]['lr'], self.epoch)
        self.preEpoch()
        self.train_time.reset()
        self.load_data_time.reset()
        self.data_cpu2gpu_time.reset()

        iter_tick = time.time()
        for i, (sample, target) in enumerate(self.loader):
            self.load_data_time.update(time.time() - iter_tick)
            self.step = i
            self.target = target
            self.sample = sample
            self.preIter()
            feed_sample_tick = time.time()
            self.earlyIter()
            self.data_cpu2gpu_time.update(time.time() - feed_sample_tick)
            self.addGraph(self.sample)
            with autocast(enabled=self.conf.amp if self.conf.amp else False):
                self.doForward()
                self.doLoss()
            self.doBackward()
            self.doOptimize()
            self.doLog()
            self.postIter()
            self.iter += 1
            self.train_time.update(time.time() - iter_tick)
            if self.step in self.save_list:
                self.processVal()
                self.setTrainContext()
            iter_tick = time.time()

        self.addScalar('{}/TrainTime(hours/epoch)'.format(self.phase),
                       round(self.train_time.sum / 3600, 2), self.epoch)
        self.addScalar(
            '{}/AverageBatchTrainTime(secs/epoch)'.format(self.phase),
            self.train_time.avg, self.epoch)
        self.addScalar(
            '{}/AverageBatchLoadDataTime(secs/epoch)'.format(self.phase),
            self.load_data_time.avg, self.epoch)
        self.addScalar(
            '{}/AverageBatchDataCpu2GpuTime(secs/epoch)'.format(self.phase),
            self.data_cpu2gpu_time.avg, self.epoch)

        self.postEpoch()
        if self.scheduler:
            self.scheduler.step()

    def processVal(self, smoke=False):
        self.setValContext()
        LOG.logI('Phase {} started...'.format(self.phase))
        with torch.no_grad():
            self.preEpoch()
            for i, (sample, target) in enumerate(self.loader):
                self.target = target
                self.sample = sample
                self.preIter()
                self.earlyIter()
                self.doForward()
                self.doLoss()
                self.smokeTestForExport3rd()
                self.postIter()
                if smoke:
                    return
                LOG.logI('{}: [{}][{}/{}]'.format(self.phase, self.epoch, i,
                                                  len(self.loader)))
            self.postEpoch()
        self.saveState(self.getTime())

    def processAccept(self):
        self.setValContext()

    def process(self):
        self.auditConfig()
        self.iter = 0
        epoch_start = self.epoch
        if self.conf.ema:
            self.ema_updates = self.epoch * len(
                self.train_loader) // self.conf.nominal_batch_factor
        self.processVal(smoke=True)
        self.optimizer.zero_grad()
        for epoch in range(epoch_start, self.conf.epoch_num):
            self.epoch = epoch
            LOG.logI('Epoch {} started...'.format(self.epoch))
            self.processTrain()
            self.processVal()
            self.processAccept()

    def __call__(self):
        self.process()
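# Minimal usage sketch (hypothetical names; the concrete subclass and config are
# not part of the excerpt above): the trainer is constructed with a config object
# and invoked directly, which runs process() -- a smoke validation pass followed
# by alternating train/validation epochs.
#
#   trainer = MyClassifierTrain(conf)
#   trainer()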
Example #21
0
                if i % args.logfreq == 0:
                    niter = epoch*len(train_loader)+i
                    tb_writer.add_scalar('Train/Loss', loss_reducer(runningLoss), niter)
                    wandb.log({"Epoch":epoch, "TrainLoss":loss_reducer(runningLoss)})#, step=niter)
                    # tensorboard_images(tb_writer, inp, out.detach(), gt, epoch, 'train')
                    runningLoss = []
            
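            # periodically checkpoint model/optimizer/AMP-scaler state and, for most
            # model ids, export the current model to ONNX and upload it via wandb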
            if args.finetune or (epoch % args.savefreq == 0):              
                checkpoint = {
                    'epoch': epoch,
                    'iterations': (epoch+1)*len(train_loader),
                    'best_loss': best_loss,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'AMPScaler': scaler.state_dict()         
                }
                torch.save(checkpoint, os.path.join(save_path, trainID+".pth.tar"))
                if args.modelid != 9:
                    torch.onnx.export(model, images, trainID+".onnx", input_names=["LRCurrTP"], output_names=["SuperResolvedCurrTP"])
                    wandb.save(trainID+".onnx")

            tb_writer.add_scalar('Train/EpochLoss', loss_reducer(train_loss), epoch)
            wandb.log({"TrainEpochLoss":loss_reducer(train_loss)})#, step=epoch)

            #Validate
            if val_loader:
                model.eval()
                with torch.no_grad():
                    runningLoss = []
                    val_loss = []
Example #22
0
def main(args):
    comm = MPI.COMM_WORLD
    world_size = comm.Get_size()
    rank = comm.Get_rank()
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(args.master_port)
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
    device = torch.device("cuda")
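    # rank and world size come from MPI; MASTER_ADDR/MASTER_PORT define the NCCL
    # rendezvous, and each process binds to the GPU whose index equals its rank
    # (single-node setup)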

    logger = None
    tb_logger = None
    if rank == 0:
        os.makedirs(args.save_path, exist_ok=True)
        os.makedirs(args.tensorboard_log_dir, exist_ok=True)
        tb_logger = SummaryWriter(
            f"{args.tensorboard_log_dir}/{args.model_name}")

        logger = logging.getLogger(__name__)
        logger.setLevel(logging.DEBUG)
        handler = TqdmLoggingHandler()
        handler.setFormatter(logging.Formatter(" %(asctime)s - %(message)s"))
        logger.addHandler(handler)
        logger.propagate = False

    write_log(logger, "Load data")

    def load_data(args):
        gc.disable()
        with open(f"{args.preprocessed_data_path}/hanja_korean_word2id.pkl",
                  "rb") as f:
            data = pickle.load(f)
            hanja_word2id = data['hanja_word2id']
            korean_word2id = data['korean_word2id']

        with open(f"{args.preprocessed_data_path}/preprocessed_train.pkl",
                  "rb") as f:
            data = pickle.load(f)
            train_hanja_indices = data['hanja_indices']
            train_korean_indices = data['korean_indices']
            train_additional_hanja_indices = data['additional_hanja_indices']

        with open(f"{args.preprocessed_data_path}/preprocessed_valid.pkl",
                  "rb") as f:
            data = pickle.load(f)
            valid_hanja_indices = data['hanja_indices']
            valid_korean_indices = data['korean_indices']
            valid_additional_hanja_indices = data['additional_hanja_indices']

        gc.enable()
        write_log(logger, "Finished loading data!")
        return (hanja_word2id, korean_word2id, train_hanja_indices,
                train_korean_indices, train_additional_hanja_indices,
                valid_hanja_indices, valid_korean_indices,
                valid_additional_hanja_indices)

    # load data
    (hanja_word2id, korean_word2id, train_hanja_indices, train_korean_indices,
     train_additional_hanja_indices, valid_hanja_indices, valid_korean_indices,
     valid_additional_hanja_indices) = load_data(args)
    hanja_vocab_num = len(hanja_word2id)
    korean_vocab_num = len(korean_word2id)

    hk_dataset = HanjaKoreanDataset(train_hanja_indices,
                                    train_korean_indices,
                                    min_len=args.min_len,
                                    src_max_len=args.src_max_len,
                                    trg_max_len=args.trg_max_len)
    hk_sampler = DistributedSampler(hk_dataset,
                                    num_replicas=world_size,
                                    rank=rank)
    hk_loader = DataLoader(hk_dataset,
                           drop_last=True,
                           batch_size=args.hk_batch_size,
                           sampler=hk_sampler,
                           num_workers=args.num_workers,
                           prefetch_factor=4,
                           pin_memory=True)
    write_log(logger, f"hanja-korean: {len(hk_dataset)}, {len(hk_loader)}")

    h_dataset = HanjaDataset(train_hanja_indices,
                             train_additional_hanja_indices,
                             hanja_word2id,
                             min_len=args.min_len,
                             src_max_len=args.src_max_len)
    h_sampler = DistributedSampler(h_dataset,
                                   num_replicas=world_size,
                                   rank=rank)
    h_loader = DataLoader(h_dataset,
                          drop_last=True,
                          batch_size=args.h_batch_size,
                          sampler=h_sampler,
                          num_workers=args.num_workers,
                          prefetch_factor=4,
                          pin_memory=True)
    write_log(logger, f"hanja: {len(h_dataset)}, {len(h_loader)}")

    hk_valid_dataset = HanjaKoreanDataset(valid_hanja_indices,
                                          valid_korean_indices,
                                          min_len=args.min_len,
                                          src_max_len=args.src_max_len,
                                          trg_max_len=args.trg_max_len)
    hk_valid_sampler = DistributedSampler(hk_valid_dataset,
                                          num_replicas=world_size,
                                          rank=rank)
    hk_valid_loader = DataLoader(hk_valid_dataset,
                                 drop_last=True,
                                 batch_size=args.hk_batch_size,
                                 sampler=hk_valid_sampler)
    write_log(
        logger,
        f"hanja-korean-valid: {len(hk_valid_dataset)}, {len(hk_valid_loader)}")

    h_valid_dataset = HanjaDataset(valid_hanja_indices,
                                   valid_additional_hanja_indices,
                                   hanja_word2id,
                                   min_len=args.min_len,
                                   src_max_len=args.src_max_len)
    h_valid_sampler = DistributedSampler(h_valid_dataset,
                                         num_replicas=world_size,
                                         rank=rank)
    h_valid_loader = DataLoader(h_valid_dataset,
                                drop_last=True,
                                batch_size=args.h_batch_size,
                                sampler=h_valid_sampler)
    write_log(logger, f"hanja: {len(h_valid_dataset)}, {len(h_valid_loader)}")

    del (train_hanja_indices, train_korean_indices,
         train_additional_hanja_indices, valid_hanja_indices,
         valid_korean_indices, valid_additional_hanja_indices)

    write_log(logger, "Build model")
    model = Transformer(hanja_vocab_num,
                        korean_vocab_num,
                        pad_idx=args.pad_idx,
                        bos_idx=args.bos_idx,
                        eos_idx=args.eos_idx,
                        src_max_len=args.src_max_len,
                        trg_max_len=args.trg_max_len,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dropout=args.dropout,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        num_decoder_layer=args.num_decoder_layer,
                        num_mask_layer=args.num_mask_layer).to(device)
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[device],
                                                find_unused_parameters=True)
    for param in model.parameters():
        dist.broadcast(param.data, 0)

    dist.barrier()
    write_log(
        logger,
        f"Total Parameters: {sum([p.nelement() for p in model.parameters()])}")

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
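    # the grouping above excludes biases and LayerNorm parameters from weight decay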
    optimizer = Ralamb(params=optimizer_grouped_parameters, lr=args.lr)

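    # one scheduler step is taken per optimizer update, so the total iteration
    # count is scaled down by the gradient-accumulation factor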
    total_iters = round(
        len(hk_loader) / args.num_grad_accumulate * args.epochs)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, round(total_iters * args.warmup_ratio), total_iters)
    scaler = GradScaler()

    start_epoch = 0
    if args.resume:

        def load_states():
            checkpoint = torch.load(
                f'{args.save_path}/{args.model_name}_ckpt.pt',
                map_location='cpu')
            # resume from the last completed epoch; the training loop below
            # already starts at start_epoch + 1
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            scaler.load_state_dict(checkpoint['scaler'])
            return start_epoch

        start_epoch = load_states()

    write_log(logger, f"Training start - Total iter: {total_iters}\n")
    iter_num = round(len(hk_loader) / args.num_grad_accumulate)
    global_step = start_epoch * iter_num
    hk_iter = iter(hk_loader)
    h_iter = iter(h_loader)
    model.train()
    tgt_mask = Transformer.generate_square_subsequent_mask(
        args.trg_max_len - 1, device)
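    # causal (subsequent-position) mask for the decoder, built once since the
    # decoder input length is fixed at trg_max_len - 1 (targets are shifted by one)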

    # validation
    validate(model, tgt_mask, h_valid_loader, hk_valid_loader, rank, logger,
             tb_logger, 0, device)

    for epoch in range(start_epoch + 1, args.epochs + 1):
        while True:
            start = time.time()
            finish_epoch = False
            trans_top5, trans_loss, mask_top5, mask_loss = 0.0, 0.0, 0.0, 0.0

            if args.train_reconstruct:
                optimizer.zero_grad(set_to_none=True)
                for _ in range(args.num_grad_accumulate):
                    try:
                        src_sequences, trg_sequences = next(h_iter)
                    except StopIteration:
                        h_sampler.set_epoch(epoch)
                        h_iter = iter(h_loader)
                        src_sequences, trg_sequences = next(h_iter)

                    trg_sequences = trg_sequences.to(device)
                    src_sequences = src_sequences.to(device)
                    non_pad = trg_sequences != args.pad_idx
                    trg_sequences = trg_sequences[non_pad].contiguous().view(
                        -1)

                    with autocast():
                        predicted = model.module.reconstruct_predict(
                            src_sequences, masked_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(
                            predicted,
                            trg_sequences) / args.num_grad_accumulate

                    scaler.scale(loss).backward()

                    if global_step % args.print_freq == 0:
                        mask_top5 += accuracy(predicted, trg_sequences,
                                              5) / args.num_grad_accumulate
                        mask_loss += loss.detach().item()

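                # average the accumulated gradients across ranks by hand, then take a
                # single AMP-scaled optimizer step for this group of micro-batches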
                for param in model.parameters():
                    if param.grad is not None:
                        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                        param.grad.data = param.grad.data / world_size

                scaler.step(optimizer)
                scaler.update()

            if args.train_translate:
                optimizer.zero_grad(set_to_none=True)
                for _ in range(args.num_grad_accumulate):
                    try:
                        src_sequences, trg_sequences = next(hk_iter)
                    except StopIteration:
                        hk_sampler.set_epoch(epoch)
                        hk_iter = iter(hk_loader)
                        src_sequences, trg_sequences = next(hk_iter)
                        finish_epoch = True

                    trg_sequences = trg_sequences.to(device)
                    trg_sequences_target = trg_sequences[:, 1:]
                    src_sequences = src_sequences.to(device)
                    non_pad = trg_sequences_target != args.pad_idx
                    trg_sequences_target = trg_sequences_target[
                        non_pad].contiguous().view(-1)

                    with autocast():
                        predicted = model(src_sequences,
                                          trg_sequences[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(
                            predicted,
                            trg_sequences_target) / args.num_grad_accumulate

                    scaler.scale(loss).backward()

                    if global_step % args.print_freq == 0:
                        trans_top5 += accuracy(predicted, trg_sequences_target,
                                               5) / args.num_grad_accumulate
                        trans_loss += loss.detach().item()

                for param in model.parameters():
                    if param.grad is not None:
                        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                        param.grad.data = param.grad.data / world_size

                scaler.step(optimizer)
                scaler.update()

            scheduler.step()

            # Print status
            if global_step % args.print_freq == 0:
                if args.train_reconstruct:
                    mask_top5 = torch.tensor([mask_top5], device=device)
                    mask_loss = torch.tensor([mask_loss], device=device)
                    dist.all_reduce(mask_top5, op=dist.ReduceOp.SUM)
                    dist.all_reduce(mask_loss, op=dist.ReduceOp.SUM)
                    mask_top5 = (mask_top5 / world_size).item()
                    mask_loss = (mask_loss / world_size).item()

                if args.train_translate:
                    trans_top5 = torch.tensor([trans_top5], device=device)
                    trans_loss = torch.tensor([trans_loss], device=device)
                    dist.all_reduce(trans_top5, op=dist.ReduceOp.SUM)
                    dist.all_reduce(trans_loss, op=dist.ReduceOp.SUM)
                    trans_top5 = (trans_top5 / world_size).item()
                    trans_loss = (trans_loss / world_size).item()

                if rank == 0:
                    batch_time = time.time() - start
                    write_log(
                        logger,
                        f'[{global_step}/{total_iters}, {epoch}]\tIter time: {batch_time:.3f}\t'
                        f'Trans loss: {trans_loss:.3f}\tMask_loss: {mask_loss:.3f}\t'
                        f'Trans@5: {trans_top5:.3f}\tMask@5: {mask_top5:.3f}')

                    tb_logger.add_scalar('loss/translate', trans_loss,
                                         global_step)
                    tb_logger.add_scalar('loss/mask', mask_loss, global_step)
                    tb_logger.add_scalar('top5/translate', trans_top5,
                                         global_step)
                    tb_logger.add_scalar('top5/mask', mask_top5, global_step)
                    tb_logger.add_scalar('batch/time', batch_time, global_step)
                    tb_logger.add_scalar('batch/lr',
                                         optimizer.param_groups[0]['lr'],
                                         global_step)

            global_step += 1
            if finish_epoch:
                break

        # validation
        validate(model, tgt_mask, h_valid_loader, hk_valid_loader, rank,
                 logger, tb_logger, epoch, device)
        # save model
        if rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'model': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'scaler': scaler.state_dict()
                }, f'{args.save_path}/{args.model_name}_ckpt.pt')
            write_log(logger, f"***** {epoch}th model updated! *****")