scaler.update()
if scheduler is not None:
    scheduler.step()

# increment the global step and save a checkpoint if need be
new_num_samples_treated = num_samples_treated + batch[0].shape[0]
num_batches_treated += 1

if new_num_samples_treated > max_num_samples_to_train_on:
    state_dict = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'num_samples_treated': new_num_samples_treated,
        "scaler": scaler.state_dict(),
        'num_batches_treated': num_batches_treated
    }
    torch.save(
        state_dict,
        os.path.join(checkpoint_dir, f'{args.net}-{new_num_samples_treated}.ckpt'))
    print("training ended")
    exit()

if (num_samples_treated // args.samples_before_ckpt) != \
        (new_num_samples_treated // args.samples_before_ckpt):
    state_dict = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
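# --- Illustrative sketch (not part of the script above) ---------------------
# The fragment above follows the standard torch.cuda.amp recipe: run the
# forward pass under autocast, backprop a scaled loss, let the GradScaler
# step the optimizer and update its scale, then persist scaler.state_dict()
# next to the model/optimizer state so a resumed run keeps its loss-scale
# history. Model, optimizer and checkpoint keys below are placeholders.
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 4).cuda()          # placeholder model (needs a GPU)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler()

def train_step(batch, target):
    optimizer.zero_grad()
    with autocast():                            # mixed fp16/fp32 forward pass
        loss = torch.nn.functional.cross_entropy(model(batch), target)
    scaler.scale(loss).backward()               # backward on the scaled loss
    scaler.step(optimizer)                      # unscales grads, skips step on inf/nan
    scaler.update()                             # adapts the loss scale for next step
    return loss.item()

def save_checkpoint(path, epoch):
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scaler": scaler.state_dict(),          # same key convention as above
        "epoch": epoch,
    }, path)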
class CustomMTSAC(MTSAC): def __init__( self, policy, qf1, qf2, replay_buffer, env_spec, sampler, train_task_sampler, *, num_tasks, gradient_steps_per_itr, task_update_frequency=1, max_episode_length_eval=None, fixed_alpha=None, target_entropy=None, initial_log_entropy=0., discount=0.99, buffer_batch_size=64, min_buffer_size=10000, target_update_tau=5e-3, policy_lr=3e-4, qf_lr=3e-4, reward_scale=1.0, optimizer=torch.optim.Adam, num_evaluation_episodes=5, # added fp16=False, log_per_task=False, share_train_eval_env=False): super().__init__( policy=policy, qf1=qf1, qf2=qf2, replay_buffer=replay_buffer, env_spec=env_spec, sampler=sampler, test_sampler=sampler, # not used, for compatibility train_task_sampler=train_task_sampler, num_tasks=num_tasks, gradient_steps_per_itr=gradient_steps_per_itr, max_episode_length_eval=max_episode_length_eval, fixed_alpha=fixed_alpha, target_entropy=target_entropy, initial_log_entropy=initial_log_entropy, discount=discount, buffer_batch_size=buffer_batch_size, min_buffer_size=min_buffer_size, target_update_tau=target_update_tau, policy_lr=policy_lr, qf_lr=qf_lr, reward_scale=reward_scale, optimizer=optimizer, steps_per_epoch=1, num_evaluation_episodes=num_evaluation_episodes, ) self._train_task_sampler = train_task_sampler self._task_update_frequency = task_update_frequency self._fp16 = fp16 self._log_per_task = log_per_task self._total_envsteps = 0 # scalers for fp16 # TODO: don't initialize gradscalers if not using fp16 # Also don't save and/or restore self._gs_qf1 = GradScaler() self._gs_qf2 = GradScaler() self._gs_policy = GradScaler() self._gs_alpha = GradScaler() # get updates for evaluation self.eval_env_updates = self.resample_environment(force_update=True) self.share_train_eval_env = share_train_eval_env if self.share_train_eval_env: logging.warn("WARNING: Sharing train and eval environments") # Fix bug with alpha with optimizer self._use_automatic_entropy_tuning = fixed_alpha is None if self._use_automatic_entropy_tuning: self._alpha_optimizer = optimizer([self._log_alpha], lr=self._policy_lr) def state_dict(self): return { # parameters "policy": self.policy.state_dict(), "qf1": self._qf1.state_dict(), "qf2": self._qf2.state_dict(), "target_qf1": self._target_qf1.state_dict(), "target_qf2": self._target_qf2.state_dict(), "log_alpha": self._log_alpha, # scalers "gs_qf1": self._gs_qf1.state_dict(), "gs_qf2": self._gs_qf2.state_dict(), "gs_policy": self._gs_policy.state_dict(), "gs_alpha": self._gs_alpha.state_dict(), # optimizers "policy_optimizer": self._policy_optimizer.state_dict(), "qf1_optimizer": self._qf1_optimizer.state_dict(), "qf2_optimizer": self._qf2_optimizer.state_dict(), "alpha_optimizer": self._alpha_optimizer.state_dict(), # other variables "replay_buffer": self.replay_buffer, "total_envsteps": self._total_envsteps, } def load_env_state(self, env_state): self.eval_env_updates = env_state def load_state(self, state): # parameters self.policy.load_state_dict(state["policy"]) self._qf1.load_state_dict(state["qf1"]) self._qf2.load_state_dict(state["qf2"]) self._target_qf1.load_state_dict(state["target_qf1"]) self._target_qf2.load_state_dict(state["target_qf2"]) self._log_alpha.data = state["log_alpha"] # scalers self._gs_qf1.load_state_dict(state["gs_qf1"]) self._gs_qf2.load_state_dict(state["gs_qf2"]) self._gs_policy.load_state_dict(state["gs_policy"]) self._gs_alpha.load_state_dict(state["gs_alpha"]) # optimizers self._policy_optimizer.load_state_dict(state["policy_optimizer"]) self._qf1_optimizer.load_state_dict(state["qf1_optimizer"]) 
self._qf2_optimizer.load_state_dict(state["qf2_optimizer"]) self._alpha_optimizer.load_state_dict(state["alpha_optimizer"]) # other variables self.replay_buffer = state["replay_buffer"] self._total_envsteps = state["total_envsteps"] def get_updated_policy(self, policy_hook=None): with torch.no_grad(): updated_policy = copy.deepcopy(self.policy) updated_policy.eval() # attach hooks if policy_hook: policy_hook(updated_policy) return updated_policy def update_buffer(self, trajectories): """Update Buffer""" self._total_envsteps += sum(trajectories.lengths) path_returns = [] for path in trajectories.to_list(): self.replay_buffer.add_path( dict(observation=path["observations"], action=path["actions"], reward=path["rewards"].reshape(-1, 1), next_observation=path["next_observations"], terminal=np.array([ step_type == StepType.TERMINAL for step_type in path["step_types"] ]).reshape(-1, 1))) path_returns.append(sum(path["rewards"])) self.episode_rewards.append(np.mean(path_returns)) def resample_environment(self, epoch=0, force_update=False): """ TODO: fix env update in sampler Intended behavior: if epoch % self._task_update_frequency == 0 or force_update: return self._train_task_sampler.sample(self._num_tasks) """ # TODO: remove first line to allow force update if epoch % self._task_update_frequency == 0 or force_update: return self._train_task_sampler.sample(self._num_tasks) def run_epoch(self, epoch, env_steps_per_epoch): """ Run one epoch, which is composed of one N sample collections and N training steps. Each training step in their turn is composed of M gradient steps of batch size B Total number of samples used by the algorithm in a epoch is given by N * M * B (steps * gradient_steps * batch size) Samples collected are only used to update the buffer, and there is no direct influence on number of gradient steps or batch size. Returns: float: The average return in last epoch cycle. """ t0 = time() env_updates = (self.eval_env_updates if self.share_train_eval_env else self.resample_environment(epoch)) new_trajectories = self._sampler.obtain_samples( num_samples=env_steps_per_epoch, agent_update=self.get_updated_policy(), env_updates=env_updates, ) self.update_buffer(new_trajectories) t1 = time() total_losses = self.run_step() time_to_collect_samples = t1 - t0 time_to_update_gradient = time() - t1 log_dict = self._log_statistics(*total_losses) # TODO: switch to logger.debug once logger is fixed logging.warn(f"Time to collect samples: {time_to_collect_samples:.2f}") logging.warn(f"Time to update gradient: {time_to_update_gradient:.2f}") return log_dict def run_step(self): """ Run one training step, which is composed of M gradient steps For M gradients steps: - sample a batch from buffer - perform one gradient step in all three networks (policy, qf1 and qf2) """ total_losses = [0, 0, 0] for _ in range(self._gradient_steps): if self.replay_buffer.n_transitions_stored >= self._min_buffer_size: samples = as_torch_dict( self.replay_buffer.sample_transitions( self._buffer_batch_size)) policy_loss, qf1_loss, qf2_loss = self.optimize_policy(samples) total_losses[0] += policy_loss total_losses[1] += qf1_loss total_losses[2] += qf2_loss self._update_targets() # Normalize losses by total of gradient updates total_losses = [loss / self._gradient_steps for loss in total_losses] return total_losses def _evaluate_policy(self, epoch, policy_hook=None): """Evaluate the performance of the policy via deterministic sampling. Statistics such as (average) discounted return and success rate are recorded. 
Args: epoch (int): The current training epoch. Returns: float: The average return across self._num_evaluation_episodes episodes """ t0 = time() # Collect episodes for evaluation eval_trajectories, policy_hook_data = self._sampler.obtain_exact_episodes( n_eps_per_worker=self._num_evaluation_episodes, agent_update=self.get_updated_policy(policy_hook=policy_hook), env_updates=self.eval_env_updates, ) # Log performance undiscounted_returns, log_dict = log_multitask_performance( epoch, batch=eval_trajectories, discount=self._discount, log_per_task=self._log_per_task) log_dict["average_return"] = np.mean(undiscounted_returns) logging.warn(f"Time to evaluate policy: {time()-t0:.2f}") return undiscounted_returns, log_dict, policy_hook_data def _log_statistics(self, policy_loss, qf1_loss, qf2_loss): """Record training statistics to dowel such as losses and returns. Args: policy_loss (torch.Tensor): loss from actor/policy network. qf1_loss (torch.Tensor): loss from 1st qf/critic network. qf2_loss (torch.Tensor): loss from 2nd qf/critic network. """ log_dict = {} with torch.no_grad(): log_dict["AlphaTemperature/mean"] = self._log_alpha.exp().mean( ).item() log_dict["Policy/Loss"] = policy_loss.item() log_dict["QF/{}".format("Qf1Loss")] = float(qf1_loss) log_dict["QF/{}".format("Qf2Loss")] = float(qf2_loss) log_dict[ "ReplayBuffer/buffer_size"] = self.replay_buffer.n_transitions_stored log_dict["Average/TrainAverageReturn"] = np.mean(self.episode_rewards) log_dict["TotalEnvSteps"] = self._total_envsteps return log_dict def _get_log_alpha(self, samples_data): """Return the value of log_alpha. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Raises: ValueError: If the number of tasks, num_tasks passed to this algorithm doesn't match the length of the task one-hot id in the observation vector. Returns: torch.Tensor: log_alpha. shape is (1, self.buffer_batch_size) """ obs = samples_data["observation"] log_alpha = self._log_alpha one_hots = obs[:, -self._num_tasks:] if (log_alpha.shape[0] != one_hots.shape[1] or one_hots.shape[1] != self._num_tasks or log_alpha.shape[0] != self._num_tasks): raise ValueError( "The number of tasks in the environment does " "not match self._num_tasks. Are you sure that you passed " "The correct number of tasks?") with autocast(enabled=self._fp16): return torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze() def _temperature_objective(self, log_pi, samples_data): """Compute the temperature/alpha coefficient loss. Args: log_pi(torch.Tensor): log probability of actions that are sampled from the replay buffer. Shape is (1, buffer_batch_size). samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: the temperature/alpha coefficient loss. 
""" alpha_loss = 0 with autocast(enabled=self._fp16): if self._use_automatic_entropy_tuning: alpha_loss = (-(self._get_log_alpha(samples_data)) * (log_pi.detach() + self._target_entropy)).mean() return alpha_loss def _actor_objective(self, samples_data, new_actions, log_pi_new_actions): """Compute the Policy/Actor loss. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. new_actions (torch.Tensor): Actions resampled from the policy based based on the Observations, obs, which were sampled from the replay buffer. Shape is (action_dim, buffer_batch_size). log_pi_new_actions (torch.Tensor): Log probability of the new actions on the TanhNormal distributions that they were sampled from. Shape is (1, buffer_batch_size). Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: loss from the Policy/Actor. """ obs = samples_data["observation"] with torch.no_grad(): alpha = self._get_log_alpha(samples_data).exp() with autocast(enabled=self._fp16): min_q_new_actions = torch.min(self._qf1(obs, new_actions), self._qf2(obs, new_actions)) policy_objective = ((alpha * log_pi_new_actions) - min_q_new_actions.flatten()).mean() return policy_objective def _critic_objective(self, samples_data): """Compute the Q-function/critic loss. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: loss from 1st q-function after optimization. torch.Tensor: loss from 2nd q-function after optimization. """ obs = samples_data["observation"] actions = samples_data["action"] rewards = samples_data["reward"].flatten() terminals = samples_data["terminal"].flatten() next_obs = samples_data["next_observation"] with torch.no_grad(): alpha = self._get_log_alpha(samples_data).exp() with autocast(enabled=self._fp16): q1_pred = self._qf1(obs, actions) q2_pred = self._qf2(obs, actions) new_next_actions_dist = self.policy(next_obs)[0] new_next_actions_pre_tanh, new_next_actions = ( new_next_actions_dist.rsample_with_pre_tanh_value()) new_log_pi = new_next_actions_dist.log_prob( value=new_next_actions, pre_tanh_value=new_next_actions_pre_tanh) target_q_values = torch.min( self._target_qf1(next_obs, new_next_actions), self._target_qf2(next_obs, new_next_actions)).flatten() - ( alpha * new_log_pi) with torch.no_grad(): q_target = rewards * self._reward_scale + ( 1. - terminals) * self._discount * target_q_values qf1_loss = F.mse_loss(q1_pred.flatten(), q_target) qf2_loss = F.mse_loss(q2_pred.flatten(), q_target) return qf1_loss, qf2_loss def optimize_policy(self, samples_data): """Optimize the policy q_functions, and temperature coefficient. Rezero model weights (if applicable) after each optimizer step. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. 
Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: loss from actor/policy network after optimization. torch.Tensor: loss from 1st q-function after optimization. torch.Tensor: loss from 2nd q-function after optimization. """ if self._fp16: return self.optimize_policy_with_autocast(samples_data) obs = samples_data["observation"] qf1_loss, qf2_loss = self._critic_objective(samples_data) self._qf1_optimizer.zero_grad() qf1_loss.backward() self._qf1_optimizer.step() self._qf1.apply(rezero_weights) self._qf2_optimizer.zero_grad() qf2_loss.backward() self._qf2_optimizer.step() self._qf2.apply(rezero_weights) action_dists = self.policy(obs)[0] new_actions_pre_tanh, new_actions = ( action_dists.rsample_with_pre_tanh_value()) log_pi_new_actions = action_dists.log_prob( value=new_actions, pre_tanh_value=new_actions_pre_tanh) policy_loss = self._actor_objective(samples_data, new_actions, log_pi_new_actions) self._policy_optimizer.zero_grad() policy_loss.backward() self._policy_optimizer.step() self.policy.apply(rezero_weights) if self._use_automatic_entropy_tuning: alpha_loss = self._temperature_objective(log_pi_new_actions, samples_data) self._alpha_optimizer.zero_grad() alpha_loss.backward() self._alpha_optimizer.step() return policy_loss, qf1_loss, qf2_loss def optimize_policy_with_autocast(self, samples_data): """Optimize the policy q_functions, and temperature coefficient. Rezero model weights (if applicable) after each optimizer step. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: loss from actor/policy network after optimization. torch.Tensor: loss from 1st q-function after optimization. torch.Tensor: loss from 2nd q-function after optimization. 
""" obs = samples_data["observation"] qf1_loss, qf2_loss = self._critic_objective(samples_data) self._qf1_optimizer.zero_grad() self._gs_qf1.scale(qf1_loss).backward() self._gs_qf1.step(self._qf1_optimizer) self._gs_qf1.update() self._qf1.apply(rezero_weights) self._qf2_optimizer.zero_grad() self._gs_qf2.scale(qf2_loss).backward() self._gs_qf2.step(self._qf2_optimizer) self._gs_qf2.update() self._qf2.apply(rezero_weights) with autocast(): action_dists = self.policy(obs)[0] new_actions_pre_tanh, new_actions = ( action_dists.rsample_with_pre_tanh_value()) log_pi_new_actions = action_dists.log_prob( value=new_actions, pre_tanh_value=new_actions_pre_tanh) policy_loss = self._actor_objective(samples_data, new_actions, log_pi_new_actions) self._policy_optimizer.zero_grad() self._gs_policy.scale(policy_loss).backward() self._gs_policy.step(self._policy_optimizer) self._gs_policy.update() self.policy.apply(rezero_weights) if self._use_automatic_entropy_tuning: alpha_loss = self._temperature_objective(log_pi_new_actions, samples_data) self._alpha_optimizer.zero_grad() self._gs_alpha.scale(alpha_loss).backward() self._gs_alpha.step(self._alpha_optimizer) self._gs_alpha.update() return policy_loss, qf1_loss, qf2_loss def shutdown_worker(self): """Shutdown Plotter and Sampler workers.""" self._sampler.shutdown_worker()
class GenericTrainingManager: def __init__(self, params): self.type = None self.is_master = False self.params = params self.models = {} self.begin_time = None self.dataset = None self.paths = None self.latest_epoch = -1 self.latest_batch = 0 self.total_batch = 0 self.latest_train_metrics = dict() self.latest_valid_metrics = dict() self.phase = None self.max_mem_usage_by_epoch = list() self.scaler = None self.optimizer = None self.lr_scheduler = None self.best = None self.writer = None reset_optimizer = "reset_optimizer" in self.params["training_params"] and self.params["training_params"]["reset_optimizer"] self.init_hardware_config() self.init_paths() self.load_dataset() self.load_model(reset_optimizer) def init_paths(self): ## Create output folders output_path = os.path.join("outputs", self.params["training_params"]["output_folder"]) os.makedirs(output_path, exist_ok=True) checkpoints_path = os.path.join(output_path, "checkpoints") os.makedirs(checkpoints_path, exist_ok=True) results_path = os.path.join(output_path, "results") os.makedirs(results_path, exist_ok=True) self.paths = { "results": results_path, "checkpoints": checkpoints_path, "output_folder": output_path } def load_dataset(self): self.params["dataset_params"]["use_ddp"] = self.params["training_params"]["use_ddp"] self.params["dataset_params"]["batch_size"] = self.params["training_params"]["batch_size"] self.params["dataset_params"]["num_gpu"] = self.params["training_params"]["nb_gpu"] self.dataset = DatasetManager(self.params["dataset_params"]) if self.dataset.charset: self.params["model_params"]["vocab_size"] = len(self.dataset.charset) def init_hardware_config(self): # Debug mode if self.params["training_params"]["force_cpu"]: self.params["training_params"]["use_ddp"] = False self.params["training_params"]["use_amp"] = False # Manage Distributed Data Parallel & GPU usage self.manual_seed = 1111 if "manual_seed" not in self.params["training_params"].keys() else \ self.params["training_params"]["manual_seed"] self.ddp_config = { "master": self.params["training_params"]["use_ddp"] and self.params["training_params"]["ddp_rank"] == 0, "address": "localhost" if "ddp_addr" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_addr"], "port": "11111" if "ddp_port" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_port"], "backend": "nccl" if "ddp_backend" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_backend"], "rank": self.params["training_params"]["ddp_rank"], } self.is_master = self.ddp_config["master"] or not self.params["training_params"]["use_ddp"] if self.params["training_params"]["force_cpu"]: self.device = "cpu" else: if self.params["training_params"]["use_ddp"]: self.device = torch.device(self.ddp_config["rank"]) self.params["dataset_params"]["ddp_rank"] = self.ddp_config["rank"] self.launch_ddp() else: self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Print GPU info # global if (self.params["training_params"]["use_ddp"] and self.ddp_config["master"]) or not self.params["training_params"]["use_ddp"]: print("##################") print("Available GPUS: {}".format(self.params["training_params"]["nb_gpu"])) for i in range(self.params["training_params"]["nb_gpu"]): print("Rank {}: {} {}".format(i, torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i))) print("##################") # local print("Local GPU:") if self.device != "cpu": print("Rank {}: {} 
{}".format(self.params["training_params"]["ddp_rank"], torch.cuda.get_device_name(), torch.cuda.get_device_properties(self.device))) else: print("WORKING ON CPU !\n") print("##################") def load_model(self, reset_optimizer=False): self.params["model_params"]["use_amp"] = self.params["training_params"]["use_amp"] # Instanciate Model for model_name in self.params["model_params"]["models"].keys(): self.models[model_name] = self.params["model_params"]["models"][model_name](self.params["model_params"]) self.models[model_name].to(self.device) # To GPU or CPU # Instanciate optimizer self.reset_optimizer() if "lr_scheduler" in self.params["training_params"] and self.params["training_params"]["lr_scheduler"]: self.lr_scheduler = self.params["training_params"]["lr_scheduler"]["type"](self.optimizer, gamma=self.params["training_params"]["lr_scheduler"]["gamma"]) self.scaler = GradScaler(enabled=self.params["training_params"]["use_amp"]) # Load previous weights checkpoint = None if self.params["training_params"]["load_epoch"] in ("best", "last"): for filename in os.listdir(self.paths["checkpoints"]): # Continue training if self.params["training_params"]["load_epoch"] in filename: checkpoint_path = os.path.join(self.paths["checkpoints"], filename) checkpoint = torch.load(checkpoint_path) self.load_save_info(checkpoint) self.latest_epoch = checkpoint["epoch"] self.best = checkpoint["best"] self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) # Make model compatible with Distributed Data Parallel if used if self.params["training_params"]["use_ddp"]: for model_name in self.models.keys(): self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]]) # Load model weights from past training for model_name in self.models.keys(): self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(model_name)]) # Load optimizer state from past training if not reset_optimizer: self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) # Load optimizer scheduler config from past training if used if "lr_scheduler" in self.params["training_params"] and self.params["training_params"]["lr_scheduler"] and "lr_scheduler_state_dict" in checkpoint.keys(): self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"]) break # Print the number of trained epoch so far with the model if self.is_master: print("LOADED EPOCH: {}\n".format(self.latest_epoch), flush=True) # New training if not checkpoint: # Weights initialization for model_name in self.models.keys(): self.models[model_name].apply(self.weights_init) # Handle transfer learning instructions if self.params["model_params"]["transfer_learning"]: # Iterates over models for model_name in self.params["model_params"]["transfer_learning"].keys(): state_dict_name, path, learnable, strict = self.params["model_params"]["transfer_learning"][model_name] # Loading pretrained weights file checkpoint = torch.load(path) try: # Load pretrained weights for model self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(state_dict_name)], strict=strict) print("transfered weights for {}".format(state_dict_name), flush=True) except RuntimeError as e: print(e, flush=True) # if error, try to load each parts of the model (useful if only few layers are different) for key in checkpoint["{}_state_dict".format(state_dict_name)].keys(): try: self.models[model_name].load_state_dict({key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False) except RuntimeError as e: print(e, flush=True) # Set 
parameters no trainable if not learnable: self.set_model_learnable(self.models[model_name], False) # make the model compatible with Distributed Data Parallel if used if self.params["training_params"]["use_ddp"]: for model_name in self.models.keys(): self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]]) return @staticmethod def set_model_learnable(model, learnable=True): for p in list(model.parameters()): p.requires_grad = learnable def save_model(self, epoch, name, keep_weights=False): """ Save model weights """ if not self.is_master: return to_del = [] for filename in os.listdir(self.paths["checkpoints"]): if name in filename: to_del.append(os.path.join(self.paths["checkpoints"], filename)) path = os.path.join(self.paths["checkpoints"], "{}_{}.pt".format(name, epoch)) content = { 'optimizer_state_dict': self.optimizer.state_dict(), 'epoch': epoch, "scaler_state_dict": self.scaler.state_dict(), 'best': self.best, } if self.lr_scheduler: content["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict() content = self.add_save_info(content) for model_name in self.models.keys(): content["{}_state_dict".format(model_name)] = self.models[model_name].state_dict() torch.save(content, path) if not keep_weights: for path_to_del in to_del: if path_to_del != path: os.remove(path_to_del) def reset_optimizer(self): """ Reset optimizer learning rate """ parameters = list() for model_name in self.models.keys(): parameters += list(self.models[model_name].parameters()) self.optimizer = self.params["training_params"]["optimizer"]["class"]\ (parameters, **self.params["training_params"]["optimizer"]["args"]) @staticmethod def weights_init(m): """ Weights initialization for model training from scratch """ if isinstance(m, Conv2d) or isinstance(m, Linear): if m.weight is not None: kaiming_uniform_(m.weight, nonlinearity="relu") if m.bias is not None: zeros_(m.bias) elif isinstance(m, InstanceNorm2d): if m.weight is not None: ones_(m.weight) if m.bias is not None: zeros_(m.bias) def save_params(self): """ Output text file containing a summary of all hyperparameters chosen for the training """ def compute_nb_params(module): return sum([np.prod(p.size()) for p in list(module.parameters())]) def class_to_str_dict(my_dict): for key in my_dict.keys(): if callable(my_dict[key]): my_dict[key] = my_dict[key].__name__ elif isinstance(my_dict[key], np.ndarray): my_dict[key] = my_dict[key].tolist() elif isinstance(my_dict[key], dict): my_dict[key] = class_to_str_dict(my_dict[key]) return my_dict path = os.path.join(self.paths["results"], "params") if os.path.isfile(path): return params = copy.deepcopy(self.params) params = class_to_str_dict(params) total_params = 0 for model_name in self.models.keys(): current_params = compute_nb_params(self.models[model_name]) params["model_params"]["models"][model_name] = [params["model_params"]["models"][model_name], "{:,}".format(current_params)] total_params += current_params params["model_params"]["total_params"] = "{:,}".format(total_params) params["hardware"] = dict() if self.device != "cpu": for i in range(self.params["training_params"]["nb_gpu"]): params["hardware"][str(i)] = "{} {}".format(torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i)) else: params["hardware"]["0"] = "CPU" with open(path, 'w') as f: json.dump(params, f, indent=4) def update_memory_consumption(self): self.max_mem_usage_by_epoch.append(torch.cuda.max_memory_allocated()) torch.cuda.reset_max_memory_allocated() with open(os.path.join(self.paths["results"], 
"memory.txt"), 'a') as f: current = round(self.max_mem_usage_by_epoch[-1]/1e9, 2) max = round(np.max(self.max_mem_usage_by_epoch)/1e9, 2) min = round(np.min(self.max_mem_usage_by_epoch)/1e9, 2) median = round(np.median(self.max_mem_usage_by_epoch)/1e9, 2) mean = round(np.mean(self.max_mem_usage_by_epoch)/1e9, 2) f.write("E{} - Current: {} Go - Max: {} Go - Min: {} Go - Mean: {} Go - Median: {} Go\n".format( self.latest_epoch, current, max, min, mean, median)) @staticmethod def init_metrics(metrics_name): """ Initialization of the metrics specified in metrics_name """ metrics = { "nb_samples": 0, "weights": 0, "names": list(), "ids": list(), } for metric_name in metrics_name: if metric_name == "cer": metrics["nb_chars"] = 0 metrics[metric_name] = list() continue elif metric_name == "wer": metrics["nb_words"] = 0 elif metric_name in ["pred", "proba", "cer_force_len"]: metrics[metric_name] = list() continue elif metric_name == "diff_len": metrics[metric_name] = None continue metrics[metric_name] = 0 return metrics @staticmethod def update_metrics(metrics, batch_metrics): """ Add batch metrics to the metrics """ for key in batch_metrics.keys(): if key in ["diff_len", ]: if metrics[key] is None: metrics[key] = batch_metrics[key] else: metrics[key] = np.concatenate([metrics[key], batch_metrics[key]], axis=0) elif key in ["pred", ]: if len(metrics[key]) == 0: metrics[key] = batch_metrics[key] else: for i in range(len(metrics[key])): metrics[key][i] += batch_metrics[key][i] else: metrics[key] += batch_metrics[key] return metrics def get_display_values(self, metrics, metrics_name, num_batch): """ format metrics values for shell display purposes """ display_values = {} for metric_name in metrics_name: if metric_name in ["cer", "cer_force_len", ]: edit = np.sum(metrics[metric_name]) display_values[metric_name] = round(edit / metrics["nb_chars"], 4) elif metric_name == "wer": display_values[metric_name] = round(metrics[metric_name] / metrics["nb_words"], 4) elif metric_name in ["f_measure", "precision", "recall", "IoU", "mAP", "pp_f_measure", "pp_precision", "pp_recall", "pp_IoU", "pp_mAP"]: display_values[metric_name] = round(metrics[metric_name] / metrics["weights"], 4) elif metric_name in ["diff_len", ]: display_values[metric_name] = np.round(np.mean(np.abs(metrics[metric_name])), 3) elif metric_name in ["time", "pred", "probas", "nb_max_len", "worst_cer", ]: continue elif metric_name in ["loss", "loss_ctc", "loss_ce", "loss_ce_end", "loss_mse"]: display_values[metric_name] = round(metrics[metric_name] / self.latest_batch, 4) else: display_values[metric_name] = round(metrics[metric_name] / metrics["nb_samples"], 4) return display_values def backward_loss(self, loss, retain_graph=False): self.scaler.scale(loss).backward(retain_graph=retain_graph) def step_optimizer(self): self.scaler.step(self.optimizer) self.scaler.update() def train(self): # init tensorboard file and output param summary file if self.is_master: self.writer = SummaryWriter(self.paths["results"]) self.save_params() # init variables self.begin_time = time() focus_metric_name = self.params["training_params"]["focus_metric"] nb_epochs = self.params["training_params"]["max_nb_epochs"] interval_save_weights = self.params["training_params"]["interval_save_weights"] metrics_name = self.params["training_params"]["train_metrics"] display_values = None # perform epochs for num_epoch in range(self.latest_epoch+1, nb_epochs): self.phase = "train" # Check maximum training time stop condition if self.params["training_params"]["max_training_time"] 
and time() - self.begin_time > self.params["training_params"]["max_training_time"]: break # set models trainable for model_name in self.models.keys(): self.models[model_name].train() self.latest_epoch = num_epoch # init epoch metrics values metrics = self.init_metrics(metrics_name) t = tqdm(self.dataset.train_loader) t.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs)) # iterates over mini-batch data for ind_batch, batch_data in enumerate(t): self.latest_batch = ind_batch + 1 self.total_batch += 1 # train on batch data and compute metrics batch_metrics = self.train_batch(batch_data, metrics_name) batch_metrics["names"] = batch_data["names"] batch_metrics["ids"] = batch_data["ids"] # Merge metrics if Distributed Data Parallel is used if self.params["training_params"]["use_ddp"]: batch_metrics = self.merge_ddp_metrics(batch_metrics) # Update learning rate via scheduler if one is used if self.lr_scheduler and ind_batch % self.params["training_params"]["lr_scheduler"]["step_interval"] == 0: self.lr_scheduler.step() # Add batch metrics values to epoch metrics values metrics = self.update_metrics(metrics, batch_metrics) display_values = self.get_display_values(metrics, metrics_name, ind_batch) t.set_postfix(values=str(display_values)) # log metrics in tensorboard file if self.is_master: for key in display_values.keys(): self.writer.add_scalar('{}_{}'.format(self.params["dataset_params"]["train"]["name"], key), display_values[key], num_epoch) self.latest_train_metrics = display_values # evaluate and compute metrics for valid sets if self.params["training_params"]["eval_on_valid"] and num_epoch % self.params["training_params"]["eval_on_valid_interval"] == 0: for valid_set_name in self.dataset.valid_loaders.keys(): # evaluate set and compute metrics eval_values = self.evaluate(valid_set_name) self.latest_valid_metrics = eval_values # log valid metrics in tensorboard file if self.is_master: for key in eval_values.keys(): self.writer.add_scalar('{}_{}'.format(valid_set_name, key), eval_values[key], num_epoch) if valid_set_name == self.params["training_params"]["set_name_focus_metric"] and (self.best is None or \ (eval_values[focus_metric_name] < self.best and self.params["training_params"]["expected_metric_value"] == "low") or\ (eval_values[focus_metric_name] > self.best and self.params["training_params"]["expected_metric_value"] == "high")): self.save_model(epoch=num_epoch, name="best") self.best = eval_values[focus_metric_name] ## save model weights if self.is_master: self.save_model(epoch=num_epoch, name="last") self.update_memory_consumption() if interval_save_weights and num_epoch % interval_save_weights == 0: self.save_model(epoch=num_epoch, name="weigths", keep_weights=True) self.writer.flush() def evaluate(self, set_name, **kwargs): self.phase = "eval" loader = self.dataset.valid_loaders[set_name] # Set models in eval mode for model_name in self.models.keys(): self.models[model_name].eval() metrics_name = self.params["training_params"]["eval_metrics"] display_values = None # initialize epoch metrics metrics = self.init_metrics(metrics_name) t = tqdm(loader) t.set_description("Evaluation E{}".format(self.latest_epoch)) with torch.no_grad(): # iterate over batch data for ind_batch, batch_data in enumerate(t): self.latest_batch = ind_batch + 1 # eval batch data and compute metrics batch_metrics = self.evaluate_batch(batch_data, metrics_name) batch_metrics["names"] = batch_data["names"] batch_metrics["ids"] = batch_data["ids"] # merge metrics values if Distributed Data Parallel is used 
if self.params["training_params"]["use_ddp"]: batch_metrics = self.merge_ddp_metrics(batch_metrics) # add batch metrics to epoch metrics metrics = self.update_metrics(metrics, batch_metrics) display_values = self.get_display_values(metrics, metrics_name, ind_batch) t.set_postfix(values=str(display_values)) return display_values def predict(self, custom_name, sets_list, metrics_name, output=False): self.phase = "predict" metrics_name = metrics_name.copy() self.dataset.generate_test_loader(custom_name, sets_list) loader = self.dataset.test_loaders[custom_name] # Set models in eval mode for model_name in self.models.keys(): self.models[model_name].eval() pred_time_metric = False if "time" in metrics_name: metrics_name.remove("time") pred_time_metric = True # initialize epoch metrics metrics = self.init_metrics(metrics_name) t = tqdm(loader) t.set_description("Prediction") begin_time = time() with torch.no_grad(): for ind_batch, batch_data in enumerate(t): # iterates over batch data self.latest_batch = ind_batch + 1 # eval batch data and compute metrics batch_metrics = self.evaluate_batch(batch_data, metrics_name) batch_metrics["names"] = batch_data["names"] batch_metrics["ids"] = batch_data["ids"] # merge batch metrics if Distributed Data Parallel is used if self.params["training_params"]["use_ddp"]: batch_metrics = self.merge_ddp_metrics(batch_metrics) # add batch metrics to epoch metrics metrics = self.update_metrics(metrics, batch_metrics) display_values = self.get_display_values(metrics, metrics_name, ind_batch) t.set_postfix(values=str(display_values)) pred_time = time() - begin_time # add time metric values if requested if pred_time_metric: metrics["total_time"] = np.round(pred_time, 3) metrics["sample_time"] = np.round(pred_time / len(self.dataset.test_datasets[custom_name]), 4) # output metrics values if requested if output: for name in ["probas", ]: if name in metrics.keys(): path = os.path.join(self.paths["results"], "{}_{}_{}.txt".format(name, custom_name, self.latest_epoch)) info = "\n".join(metrics[name]) with open(path, "w") as f: f.write(info) del metrics[name] self.output(metrics, custom_name) def launch_ddp(self): """ Initialize Distributed Data Parallel system """ mp.set_start_method('fork', force=True) os.environ['MASTER_ADDR'] = self.ddp_config["address"] os.environ['MASTER_PORT'] = str(self.ddp_config["port"]) dist.init_process_group(self.ddp_config["backend"], rank=self.ddp_config["rank"], world_size=self.params["training_params"]["nb_gpu"]) torch.cuda.set_device(self.ddp_config["rank"]) random.seed(self.manual_seed) np.random.seed(self.manual_seed) torch.manual_seed(self.manual_seed) torch.cuda.manual_seed(self.manual_seed) def merge_ddp_metrics(self, metrics): """ Merge metrics when Distributed Data Parallel is used """ for metric_name in metrics.keys(): if metric_name in ["wer", "wer_force_len", "nb_samples", "nb_words", "nb_chars", "nb_max_len", "f_measure", "precision", "recall", "IoU", "mAP", "pp_f_measure", "pp_precision", "pp_recall", "pp_IoU", "pp_mAP"]: metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name]) elif metric_name in ["loss", "loss_ce", "loss_ctc", "loss_ce_end"]: metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name], average=True) elif metric_name in ["diff_len", "cer", "cer_force_len", "ids"]: metrics[metric_name] = self.cat_ddp_metric(metrics[metric_name]) return metrics def sum_ddp_metric(self, metric, average=False): """ Sum metrics for Distributed Data Parallel """ sum = torch.tensor(metric).to(self.device) 
dist.all_reduce(sum, op=dist.ReduceOp.SUM) if average: sum.true_divide(dist.get_world_size()) return sum.item() def cat_ddp_metric(self, metric): """ Concatenate metrics for Distributed Data Parallel """ tensor = torch.tensor(metric).unsqueeze(0).to(self.device) res = [torch.zeros(tensor.size()).long().to(self.device) for _ in range(dist.get_world_size())] dist.all_gather(res, tensor) return list(torch.cat(res, dim=0).flatten().cpu().numpy()) @staticmethod def cleanup(): dist.destroy_process_group() def train_batch(self, batch_data, metric_names): raise NotImplementedError def evaluate_batch(self, batch_data, metric_names): raise NotImplementedError def output_pred(self, pred, set_name): raise NotImplementedError def add_checkpoint_info(self, load_mode="last", **kwargs): for filename in os.listdir(self.paths["checkpoints"]): if load_mode in filename: checkpoint_path = os.path.join(self.paths["checkpoints"], filename) checkpoint = torch.load(checkpoint_path) for key in kwargs.keys(): checkpoint[key] = kwargs[key] torch.save(checkpoint, checkpoint_path) return self.save_model(self.latest_epoch, "last") def output(self, metrics, set_name): """ Output metrics in text file """ path = os.path.join(self.paths["results"], "predict_{}_{}.txt".format(set_name, self.latest_epoch)) with open(path, "w") as f: for metric_name in metrics.keys(): if metric_name in ["cer", "cer_force_len"]: edit = np.sum(metrics[metric_name]) value = round(edit / metrics["nb_chars"], 4) elif metric_name in ["wer", ]: value = round(metrics[metric_name] / metrics["nb_words"], 4) elif metric_name in ["loss_ce", ]: value = round(metrics[metric_name] / metrics["nb_samples"], 4) elif metric_name in ["total_time", "sample_time", "total_output_time", "sample_output_time"]: value = metrics[metric_name] elif metric_name in ["nb_samples", "nb_words", "nb_chars", "nb_max_len"]: value = metrics[metric_name] elif metric_name in ["diff_len", ]: f.write("{}: {}\n".format(metric_name, sorted(list(metrics[metric_name])))) f.write("{}-mean_abs: {}\n".format(metric_name, np.mean(np.abs(metrics[metric_name])))) continue elif metric_name in ["worst_cer", ]: m = metric_name.split("_")[-1] value = [[c, id] for c, id in zip(metrics[m], metrics["ids"])] value = sorted(value, key=lambda x: x[0], reverse=True) value = value[:50] else: continue f.write("{}: {}\n".format(metric_name, value)) def load_save_info(self, info_dict): """ Load curriculum info from saved model info """ if "curriculum_config" in info_dict.keys(): self.dataset.train_dataset.curriculum_config = info_dict["curriculum_config"] def add_save_info(self, info_dict): """ Add curriculum info to model info to be saved """ info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config return info_dict
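# --- Illustrative sketch (assumed subclass, not from the code above) --------
# GenericTrainingManager leaves train_batch / evaluate_batch abstract; a
# concrete trainer is expected to run its forward pass under autocast and to
# route backward/step through self.backward_loss and self.step_optimizer so
# the GradScaler owns the loss scaling. Batch keys ("imgs", "labels"), the
# model key ("model") and the returned metric names are assumptions here.
from torch.cuda.amp import autocast

class ExampleTrainer(GenericTrainingManager):
    def train_batch(self, batch_data, metric_names):
        x = batch_data["imgs"].to(self.device)
        y = batch_data["labels"].to(self.device)
        self.optimizer.zero_grad()
        with autocast(enabled=self.params["training_params"]["use_amp"]):
            loss = self.models["model"](x, y)   # assume the model returns its loss
        self.backward_loss(loss)                # scaler.scale(loss).backward()
        self.step_optimizer()                   # scaler.step(...) + scaler.update()
        return {"nb_samples": x.size(0), "loss": loss.item()}

    def evaluate_batch(self, batch_data, metric_names):
        x = batch_data["imgs"].to(self.device)
        y = batch_data["labels"].to(self.device)
        with autocast(enabled=self.params["training_params"]["use_amp"]):
            loss = self.models["model"](x, y)
        return {"nb_samples": x.size(0), "loss": loss.item()}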
class Trainer: def __init__( self, name="default", results_dir="results", models_dir="models", base_dir="./", optimizer="adam", latent_dim=256, image_size=128, fmap_max=512, transparent=False, batch_size=4, gp_weight=10, gradient_accumulate_every=1, attn_res_layers=[], sle_spatial=False, disc_output_size=5, antialias=False, lr=2e-4, lr_mlp=1.0, ttur_mult=1.0, save_every=1000, evaluate_every=1000, trunc_psi=0.6, aug_prob=None, aug_types=["translation", "cutout"], dataset_aug_prob=0.0, calculate_fid_every=None, is_ddp=False, rank=0, world_size=1, log=False, amp=False, *args, **kwargs, ): self.GAN_params = [args, kwargs] self.GAN = None self.name = name base_dir = Path(base_dir) self.base_dir = base_dir self.results_dir = base_dir / results_dir self.models_dir = base_dir / models_dir self.config_path = self.models_dir / name / ".config.json" assert is_power_of_two( image_size ), "image size must be a power of 2 (64, 128, 256, 512, 1024)" assert all( map(is_power_of_two, attn_res_layers) ), "resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)" self.optimizer = optimizer self.latent_dim = latent_dim self.image_size = image_size self.fmap_max = fmap_max self.transparent = transparent self.aug_prob = aug_prob self.aug_types = aug_types self.lr = lr self.ttur_mult = ttur_mult self.batch_size = batch_size self.gradient_accumulate_every = gradient_accumulate_every self.gp_weight = gp_weight self.evaluate_every = evaluate_every self.save_every = save_every self.steps = 0 self.generator_top_k_gamma = 0.99 self.generator_top_k_frac = 0.5 self.attn_res_layers = attn_res_layers self.sle_spatial = sle_spatial self.disc_output_size = disc_output_size self.antialias = antialias self.d_loss = 0 self.g_loss = 0 self.last_gp_loss = None self.last_recon_loss = None self.last_fid = None self.init_folders() self.loader = None self.dataset_aug_prob = dataset_aug_prob self.calculate_fid_every = calculate_fid_every self.is_ddp = is_ddp self.is_main = rank == 0 self.rank = rank self.world_size = world_size self.syncbatchnorm = is_ddp self.amp = amp self.G_scaler = None self.D_scaler = None if self.amp: self.G_scaler = GradScaler() self.D_scaler = GradScaler() @property def image_extension(self): return "jpg" if not self.transparent else "png" @property def checkpoint_num(self): return floor(self.steps // self.save_every) def init_GAN(self): args, kwargs = self.GAN_params # set some global variables before instantiating GAN global norm_class global Blur norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d Blur = nn.Identity if not self.antialias else Blur # handle bugs when # switching from multi-gpu back to single gpu if self.syncbatchnorm and not self.is_ddp: import torch.distributed as dist os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" dist.init_process_group("nccl", rank=0, world_size=1) # instantiate GAN self.GAN = LightweightGAN( optimizer=self.optimizer, lr=self.lr, latent_dim=self.latent_dim, attn_res_layers=self.attn_res_layers, sle_spatial=self.sle_spatial, image_size=self.image_size, ttur_mult=self.ttur_mult, fmap_max=self.fmap_max, disc_output_size=self.disc_output_size, transparent=self.transparent, rank=self.rank, *args, **kwargs, ) if self.is_ddp: ddp_kwargs = { "device_ids": [self.rank], "output_device": self.rank, "find_unused_parameters": True, } self.G_ddp = DDP(self.GAN.G, **ddp_kwargs) self.D_ddp = DDP(self.GAN.D, **ddp_kwargs) self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs) def write_config(self): 
self.config_path.write_text(json.dumps(self.config())) def load_config(self): config = ( self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text()) ) self.image_size = config["image_size"] self.transparent = config["transparent"] self.syncbatchnorm = config["syncbatchnorm"] self.disc_output_size = config["disc_output_size"] self.attn_res_layers = config.pop("attn_res_layers", []) self.sle_spatial = config.pop("sle_spatial", False) self.optimizer = config.pop("optimizer", "adam") self.fmap_max = config.pop("fmap_max", 512) del self.GAN self.init_GAN() def config(self): return { "image_size": self.image_size, "transparent": self.transparent, "syncbatchnorm": self.syncbatchnorm, "disc_output_size": self.disc_output_size, "optimizer": self.optimizer, "attn_res_layers": self.attn_res_layers, "sle_spatial": self.sle_spatial, } def set_data_src(self, folder): self.dataset = ImageDataset( folder, self.image_size, transparent=self.transparent, aug_prob=self.dataset_aug_prob, ) sampler = ( DistributedSampler( self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True ) if self.is_ddp else None ) dataloader = DataLoader( self.dataset, num_workers=math.ceil(NUM_CORES / self.world_size), batch_size=math.ceil(self.batch_size / self.world_size), sampler=sampler, shuffle=not self.is_ddp, drop_last=True, pin_memory=True, ) self.loader = cycle(dataloader) # auto set augmentation prob for user if dataset is detected to be low num_samples = len(self.dataset) if not exists(self.aug_prob) and num_samples < 1e5: self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6) print( f"autosetting augmentation probability to {round(self.aug_prob * 100)}%" ) def train(self): assert exists( self.loader ), "You must first initialize the data source with `.set_data_src(<folder of images>)`" device = torch.device(f"cuda:{self.rank}") if not exists(self.GAN): self.init_GAN() self.GAN.train() total_disc_loss = torch.zeros([], device=device) total_gen_loss = torch.zeros([], device=device) batch_size = math.ceil(self.batch_size / self.world_size) # image_size = self.GAN.image_size latent_dim = self.GAN.latent_dim aug_prob = default(self.aug_prob, 0) aug_types = self.aug_types aug_kwargs = {"prob": aug_prob, "types": aug_types} G = self.GAN.G if not self.is_ddp else self.G_ddp # D = self.GAN.D if not self.is_ddp else self.D_ddp D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp apply_gradient_penalty = self.steps % 4 == 0 # amp related contexts and functions amp_context = autocast if self.amp else null_context def backward(amp, loss, scaler): if amp: return scaler.scale(loss).backward() loss.backward() def optimizer_step(amp, optimizer, scaler): if amp: scaler.step(optimizer) scaler.update() return optimizer.step() backward = partial(backward, self.amp) optimizer_step = partial(optimizer_step, self.amp) # train discriminator self.GAN.D_opt.zero_grad() for i in gradient_accumulate_contexts( self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, G] ): latents = torch.randn(batch_size, latent_dim).cuda(self.rank) image_batch = next(self.loader).cuda(self.rank) image_batch.requires_grad_() with amp_context(): generated_images = G(latents) fake_output, fake_output_32x32, _ = D_aug( generated_images.detach(), detach=True, **aug_kwargs ) real_output, real_output_32x32, real_aux_loss = D_aug( image_batch, calc_aux_loss=True, **aug_kwargs ) real_output_loss = real_output fake_output_loss = fake_output divergence = hinge_loss(real_output_loss, fake_output_loss) divergence_32x32 = 
hinge_loss(real_output_32x32, fake_output_32x32) disc_loss = divergence + divergence_32x32 aux_loss = real_aux_loss disc_loss = disc_loss + aux_loss if apply_gradient_penalty: outputs = [real_output, real_output_32x32] outputs = ( list(map(self.D_scaler.scale, outputs)) if self.amp else outputs ) scaled_gradients = torch_grad( outputs=outputs, inputs=image_batch, grad_outputs=list( map( lambda t: torch.ones(t.size(), device=image_batch.device), outputs, ) ), create_graph=True, retain_graph=True, only_inputs=True, )[0] inv_scale = (1.0 / self.D_scaler.get_scale()) if self.amp else 1.0 gradients = scaled_gradients * inv_scale with amp_context(): gradients = gradients.reshape(batch_size, -1) gp = self.gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean() if not torch.isnan(gp): disc_loss = disc_loss + gp self.last_gp_loss = gp.clone().detach().item() with amp_context(): disc_loss = disc_loss / self.gradient_accumulate_every disc_loss.register_hook(raise_if_nan) backward(disc_loss, self.D_scaler) total_disc_loss += divergence self.last_recon_loss = aux_loss.item() self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every) optimizer_step(self.GAN.D_opt, self.D_scaler) # train generator self.GAN.G_opt.zero_grad() for i in gradient_accumulate_contexts( self.gradient_accumulate_every, self.is_ddp, ddps=[G, D_aug] ): latents = torch.randn(batch_size, latent_dim).cuda(self.rank) with amp_context(): generated_images = G(latents) fake_output, fake_output_32x32, _ = D_aug( generated_images, **aug_kwargs ) fake_output_loss = fake_output.mean(dim=1) + fake_output_32x32.mean( dim=1 ) epochs = ( self.steps * batch_size * self.gradient_accumulate_every ) / len(self.dataset) k_frac = max( self.generator_top_k_gamma ** epochs, self.generator_top_k_frac ) k = math.ceil(batch_size * k_frac) if k != batch_size: fake_output_loss, _ = fake_output_loss.topk(k=k, largest=False) loss = fake_output_loss.mean() gen_loss = loss gen_loss = gen_loss / self.gradient_accumulate_every gen_loss.register_hook(raise_if_nan) backward(gen_loss, self.G_scaler) total_gen_loss += loss self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every) optimizer_step(self.GAN.G_opt, self.G_scaler) # calculate moving averages if self.is_main and self.steps % 10 == 0 and self.steps > 20000: self.GAN.EMA() if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2: self.GAN.reset_parameter_averaging() # save from NaN errors if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)): print( f"NaN detected for generator or discriminator. 
Loading from checkpoint #{self.checkpoint_num}" ) self.load(self.checkpoint_num) raise NanException del total_disc_loss del total_gen_loss # periodically save results if self.is_main: if self.steps % self.save_every == 0: self.save(self.checkpoint_num) if self.steps % self.evaluate_every == 0 or ( self.steps % 100 == 0 and self.steps < 20000 ): self.evaluate(floor(self.steps / self.evaluate_every)) if ( exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0 ): num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size) fid = self.calculate_fid(num_batches) self.last_fid = fid with open( str(self.results_dir / self.name / "fid_scores.txt"), "a" ) as f: f.write(f"{self.steps},{fid}\n") self.steps += 1 @torch.no_grad() def evaluate(self, num=0, num_image_tiles=8, trunc=1.0): self.GAN.eval() ext = self.image_extension num_rows = num_image_tiles latent_dim = self.GAN.latent_dim # image_size = self.GAN.image_size # latents and noise latents = torch.randn((num_rows ** 2, latent_dim)).cuda(self.rank) # regular generated_images = self.generate_truncated(self.GAN.G, latents) torchvision.utils.save_image( generated_images, str(self.results_dir / self.name / f"{str(num)}.{ext}"), nrow=num_rows, ) # moving averages generated_images = self.generate_truncated(self.GAN.GE, latents) torchvision.utils.save_image( generated_images, str(self.results_dir / self.name / f"{str(num)}-ema.{ext}"), nrow=num_rows, ) @torch.no_grad() def calculate_fid(self, num_batches): torch.cuda.empty_cache() real_path = str(self.results_dir / self.name / "fid_real") + "/" fake_path = str(self.results_dir / self.name / "fid_fake") + "/" # remove any existing files used for fid calculation and recreate directories rmtree(real_path, ignore_errors=True) rmtree(fake_path, ignore_errors=True) os.makedirs(real_path) os.makedirs(fake_path) for batch_num in tqdm( range(num_batches), desc="calculating FID - saving reals" ): real_batch = next(self.loader) for k in range(real_batch.size(0)): torchvision.utils.save_image( real_batch[k, :, :, :], real_path + "{}.png".format(k + batch_num * self.batch_size), ) # generate a bunch of fake images in results / name / fid_fake self.GAN.eval() ext = self.image_extension latent_dim = self.GAN.latent_dim # image_size = self.GAN.image_size for batch_num in tqdm( range(num_batches), desc="calculating FID - saving generated" ): # latents and noise latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank) # moving averages generated_images = self.generate_truncated(self.GAN.GE, latents) for j in range(generated_images.size(0)): torchvision.utils.save_image( generated_images[j, :, :, :], str( Path(fake_path) / f"{str(j + batch_num * self.batch_size)}-ema.{ext}" ), ) return fid_score.calculate_fid_given_paths( [real_path, fake_path], 256, True, 2048 ) @torch.no_grad() def generate_truncated(self, G, style, trunc_psi=0.75, num_image_tiles=8): generated_images = evaluate_in_chunks(self.batch_size, G, style) return generated_images.clamp_(0.0, 1.0) @torch.no_grad() def generate_interpolation( self, num=0, num_image_tiles=8, trunc=1.0, num_steps=100, save_frames=False ): self.GAN.eval() ext = self.image_extension num_rows = num_image_tiles latent_dim = self.GAN.latent_dim # image_size = self.GAN.image_size # latents and noise latents_low = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank) latents_high = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank) ratios = torch.linspace(0.0, 8.0, num_steps) frames = [] for ratio in tqdm(ratios): 
interp_latents = slerp(ratio, latents_low, latents_high) generated_images = self.generate_truncated(self.GAN.GE, interp_latents) images_grid = torchvision.utils.make_grid(generated_images, nrow=num_rows) pil_image = transforms.ToPILImage()(images_grid.cpu()) if self.transparent: background = Image.new("RGBA", pil_image.size, (255, 255, 255)) pil_image = Image.alpha_composite(background, pil_image) frames.append(pil_image) frames[0].save( str(self.results_dir / self.name / f"{str(num)}.gif"), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True, ) if save_frames: folder_path = self.results_dir / self.name / f"{str(num)}" folder_path.mkdir(parents=True, exist_ok=True) for ind, frame in enumerate(frames): frame.save(str(folder_path / f"{str(ind)}.{ext}")) def print_log(self): data = [ ("G", self.g_loss), ("D", self.d_loss), ("GP", self.last_gp_loss), ("SS", self.last_recon_loss), ("FID", self.last_fid), ] data = [d for d in data if exists(d[1])] log = " | ".join(map(lambda n: f"{n[0]}: {n[1]:.2f}", data)) print(log) def model_name(self, num): return str(self.models_dir / self.name / f"model_{num}.pt") def init_folders(self): (self.results_dir / self.name).mkdir(parents=True, exist_ok=True) (self.models_dir / self.name).mkdir(parents=True, exist_ok=True) def clear(self): rmtree(str(self.models_dir / self.name), True) rmtree(str(self.results_dir / self.name), True) rmtree(str(self.config_path), True) self.init_folders() def save(self, num): save_data = {"GAN": self.GAN.state_dict(), "version": __version__} if self.amp: save_data = { **save_data, "G_scaler": self.G_scaler.state_dict(), "D_scaler": self.D_scaler.state_dict(), } torch.save(save_data, self.model_name(num)) self.write_config() def load(self, num=-1): self.load_config() name = num if num == -1: file_paths = [ p for p in Path(self.models_dir / self.name).glob("model_*.pt") ] saved_nums = sorted(map(lambda x: int(x.stem.split("_")[1]), file_paths)) if len(saved_nums) == 0: return name = saved_nums[-1] print(f"continuing from previous epoch - {name}") self.steps = name * self.save_every load_data = torch.load(self.model_name(name)) if "version" in load_data and self.is_main: print(f"loading from version {load_data['version']}") try: self.GAN.load_state_dict(load_data["GAN"]) except Exception as e: print( "unable to load save model. please try downgrading the package to the version specified by the saved model" ) raise e if self.amp: if "G_scaler" in load_data: self.G_scaler.load_state_dict(load_data["G_scaler"]) if "D_scaler" in load_data: self.D_scaler.load_state_dict(load_data["D_scaler"])
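# --- Illustrative sketch (standalone, requires a CUDA device) ---------------
# The AMP-safe gradient penalty used in Trainer.train(): gradients taken
# through an autocast forward pass are computed on scaled activations, so the
# discriminator outputs are scaled with the GradScaler before autograd.grad
# and the resulting gradients are multiplied by 1/scale before forming the
# penalty. The toy discriminator, batch and gp_weight below are placeholders.
import torch
from torch.cuda.amp import GradScaler, autocast

D = torch.nn.Linear(8, 1).cuda()
scaler = GradScaler()
gp_weight = 10.0

images = torch.randn(4, 8, device="cuda", requires_grad=True)
with autocast():
    real_out = D(images)

scaled_out = scaler.scale(real_out)                 # scale before autograd.grad
scaled_grads = torch.autograd.grad(
    outputs=scaled_out,
    inputs=images,
    grad_outputs=torch.ones_like(scaled_out),
    create_graph=True,
)[0]
grads = scaled_grads * (1.0 / scaler.get_scale())   # undo the loss scale
gp = gp_weight * ((grads.reshape(images.shape[0], -1).norm(2, dim=1) - 1) ** 2).mean()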
def main_worker(gpu, ngpus_per_node, args): args.gpu = gpu logger = get_logger(args.logging_file) logger.info("Use GPU: {} for training".format(args.gpu)) args.rank = args.rank * ngpus_per_node + gpu torch.distributed.init_process_group(backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=args.rank) epochs = args.epochs input_size = args.input_size resume_epoch = args.resume_epoch initializer = KaimingInitializer() zero_gamma = ZeroLastGamma() mix_precision_training = args.mix_precision_training is_first_rank = True if args.rank % ngpus_per_node == 0 else False batches_pre_epoch = args.num_training_samples // (args.batch_size * ngpus_per_node) lr = 0.1 * (args.batch_size * ngpus_per_node // 32) if args.lr == 0 else args.lr model = get_model(models, args.model) model.apply(initializer) if args.last_gamma: model.apply(zero_gamma) logger.info('Apply zero last gamma init.') if is_first_rank and args.model_info: summary(model, torch.rand((1, 3, input_size, input_size))) parameters = model.parameters() if not args.no_wd else no_decay_bias(model) if args.sgd_gc: logger.info('Use SGD_GC optimizer.') optimizer = SGD_GC(parameters, lr=lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) else: optimizer = optim.SGD(parameters, lr=lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = CosineWarmupLr(optimizer, batches_pre_epoch, epochs, base_lr=args.lr, warmup_epochs=args.warmup_epochs) # dropblock_scheduler = DropBlockScheduler(model, batches_pre_epoch, epochs) if args.lookahead: optimizer = Lookahead(optimizer) logger.info('Use lookahead optimizer.') torch.cuda.set_device(args.gpu) model.cuda(args.gpu) args.num_workers = int( (args.num_workers + ngpus_per_node - 1) / ngpus_per_node) if args.mix_precision_training and is_first_rank: logger.info('Train with FP16.') scaler = GradScaler(enabled=args.mix_precision_training) model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) Loss = nn.CrossEntropyLoss().cuda(args.gpu) if not args.label_smoothing else \ LabelSmoothingLoss(args.classes, smoothing=0.1).cuda(args.gpu) normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if args.autoaugment: train_transform = transforms.Compose([ transforms.RandomResizedCrop(input_size), transforms.RandomHorizontalFlip(), ImageNetPolicy, transforms.ToTensor(), normalize, ]) else: train_transform = transforms.Compose([ transforms.RandomResizedCrop(input_size), # Cutout(), transforms.RandomHorizontalFlip(), transforms.ColorJitter(0.4, 0.4, 0.4), transforms.ToTensor(), normalize, ]) val_transform = transforms.Compose([ transforms.Resize(int(input_size / 0.875)), transforms.CenterCrop(input_size), transforms.ToTensor(), normalize, ]) train_set = ImageNet(args.data_path, split='train', transform=train_transform) val_set = ImageNet(args.data_path, split='val', transform=val_transform) train_sampler = DistributedSampler(train_set) train_loader = DataLoader(train_set, args.batch_size, False, pin_memory=True, num_workers=args.num_workers, drop_last=True, sampler=train_sampler) val_loader = DataLoader(val_set, args.batch_size, False, pin_memory=True, num_workers=args.num_workers, drop_last=False) if resume_epoch > 0: loc = 'cuda:{}'.format(args.gpu) checkpoint = torch.load(args.resume_param, map_location=loc) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) scaler.load_state_dict(checkpoint['scaler']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) print("Finish 
loading resume param.") torch.backends.cudnn.benchmark = True top1_acc = metric.Accuracy(name='Top1 Accuracy') top5_acc = metric.TopKAccuracy(top=5, name='Top5 Accuracy') loss_record = metric.NumericalCost(name='Loss') for epoch in range(resume_epoch, epochs): tic = time.time() train_sampler.set_epoch(epoch) if not args.mixup: train_one_epoch(model, train_loader, Loss, optimizer, epoch, lr_scheduler, logger, top1_acc, loss_record, scaler, args) else: train_one_epoch_mixup(model, train_loader, Loss, optimizer, epoch, lr_scheduler, logger, loss_record, scaler, args) train_speed = int(args.num_training_samples // (time.time() - tic)) if is_first_rank: logger.info( 'Finish one epoch speed: {} samples/s'.format(train_speed)) test(model, val_loader, Loss, epoch, logger, top1_acc, top5_acc, loss_record, args) if args.rank % ngpus_per_node == 0: checkpoint = { 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scaler': scaler.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), } torch.save( checkpoint, '{}/{}_{}_{:.5}.pt'.format(args.save_dir, args.model, epoch, top1_acc.get()))
def run( cls, model: AbsESPnetModel, optimizers: Sequence[torch.optim.Optimizer], schedulers: Sequence[Optional[AbsScheduler]], train_iter_factory: AbsIterFactory, valid_iter_factory: AbsIterFactory, plot_attention_iter_factory: Optional[AbsIterFactory], trainer_options, distributed_option: DistributedOption, ) -> None: """Perform training. This method performs the main process of training.""" assert check_argument_types() # NOTE(kamo): Don't check the type more strictly as far trainer_options assert is_dataclass(trainer_options), type(trainer_options) assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers)) if isinstance(trainer_options.keep_nbest_models, int): keep_nbest_models = [trainer_options.keep_nbest_models] else: if len(trainer_options.keep_nbest_models) == 0: logging.warning("No keep_nbest_models is given. Change to [1]") trainer_options.keep_nbest_models = [1] keep_nbest_models = trainer_options.keep_nbest_models output_dir = Path(trainer_options.output_dir) reporter = Reporter() if trainer_options.use_amp: if V(torch.__version__) < V("1.6.0"): raise RuntimeError( "Require torch>=1.6.0 for Automatic Mixed Precision") if trainer_options.sharded_ddp: if fairscale is None: raise RuntimeError( "Requiring fairscale. Do 'pip install fairscale'") scaler = fairscale.optim.grad_scaler.ShardedGradScaler() else: scaler = GradScaler() else: scaler = None if trainer_options.resume and (output_dir / "checkpoint.pth").exists(): cls.resume( checkpoint=output_dir / "checkpoint.pth", model=model, optimizers=optimizers, schedulers=schedulers, reporter=reporter, scaler=scaler, ngpu=trainer_options.ngpu, ) start_epoch = reporter.get_epoch() + 1 if start_epoch == trainer_options.max_epoch + 1: logging.warning( f"The training has already reached at max_epoch: {start_epoch}" ) if distributed_option.distributed: if trainer_options.sharded_ddp: dp_model = fairscale.nn.data_parallel.ShardedDataParallel( module=model, sharded_optimizer=optimizers, ) else: dp_model = torch.nn.parallel.DistributedDataParallel( model, device_ids=( # Perform multi-Process with multi-GPUs [torch.cuda.current_device()] if distributed_option.ngpu == 1 # Perform single-Process with multi-GPUs else None), output_device=(torch.cuda.current_device() if distributed_option.ngpu == 1 else None), find_unused_parameters=trainer_options.unused_parameters, ) elif distributed_option.ngpu > 1: dp_model = torch.nn.parallel.DataParallel( model, device_ids=list(range(distributed_option.ngpu)), ) else: # NOTE(kamo): DataParallel also should work with ngpu=1, # but for debuggability it's better to keep this block. dp_model = model if trainer_options.use_tensorboard and ( not distributed_option.distributed or distributed_option.dist_rank == 0): from torch.utils.tensorboard import SummaryWriter train_summary_writer = SummaryWriter( str(output_dir / "tensorboard" / "train")) valid_summary_writer = SummaryWriter( str(output_dir / "tensorboard" / "valid")) else: train_summary_writer = None start_time = time.perf_counter() for iepoch in range(start_epoch, trainer_options.max_epoch + 1): if iepoch != start_epoch: logging.info( "{}/{}epoch started. Estimated time to finish: {}".format( iepoch, trainer_options.max_epoch, humanfriendly.format_timespan( (time.perf_counter() - start_time) / (iepoch - start_epoch) * (trainer_options.max_epoch - iepoch + 1)), )) else: logging.info( f"{iepoch}/{trainer_options.max_epoch}epoch started") set_all_random_seed(trainer_options.seed + iepoch) reporter.set_epoch(iepoch) # 1. 
Train and validation for one-epoch with reporter.observe("train") as sub_reporter: all_steps_are_invalid = cls.train_one_epoch( model=dp_model, optimizers=optimizers, schedulers=schedulers, iterator=train_iter_factory.build_iter(iepoch), reporter=sub_reporter, scaler=scaler, summary_writer=train_summary_writer, options=trainer_options, distributed_option=distributed_option, ) with reporter.observe("valid") as sub_reporter: cls.validate_one_epoch( model=dp_model, iterator=valid_iter_factory.build_iter(iepoch), reporter=sub_reporter, options=trainer_options, distributed_option=distributed_option, ) if not distributed_option.distributed or distributed_option.dist_rank == 0: # att_plot doesn't support distributed if plot_attention_iter_factory is not None: with reporter.observe("att_plot") as sub_reporter: cls.plot_attention( model=model, output_dir=output_dir / "att_ws", summary_writer=train_summary_writer, iterator=plot_attention_iter_factory.build_iter( iepoch), reporter=sub_reporter, options=trainer_options, ) # 2. LR Scheduler step for scheduler in schedulers: if isinstance(scheduler, AbsValEpochStepScheduler): scheduler.step( reporter.get_value( *trainer_options.val_scheduler_criterion)) elif isinstance(scheduler, AbsEpochStepScheduler): scheduler.step() if trainer_options.sharded_ddp: for optimizer in optimizers: if isinstance(optimizer, fairscale.optim.oss.OSS): optimizer.consolidate_state_dict() if not distributed_option.distributed or distributed_option.dist_rank == 0: # 3. Report the results logging.info(reporter.log_message()) if trainer_options.use_matplotlib: reporter.matplotlib_plot(output_dir / "images") if train_summary_writer is not None: reporter.tensorboard_add_scalar(train_summary_writer, key1="train") reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid") if trainer_options.use_wandb: reporter.wandb_log() # 4. Save/Update the checkpoint torch.save( { "model": model.state_dict(), "reporter": reporter.state_dict(), "optimizers": [o.state_dict() for o in optimizers], "schedulers": [ s.state_dict() if s is not None else None for s in schedulers ], "scaler": scaler.state_dict() if scaler is not None else None, }, output_dir / "checkpoint.pth", ) # 5. Save and log the model and update the link to the best model torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pth") # Creates a sym link latest.pth -> {iepoch}epoch.pth p = output_dir / "latest.pth" if p.is_symlink() or p.exists(): p.unlink() p.symlink_to(f"{iepoch}epoch.pth") _improved = [] for _phase, k, _mode in trainer_options.best_model_criterion: # e.g. 
_phase, k, _mode = "train", "loss", "min" if reporter.has(_phase, k): best_epoch = reporter.get_best_epoch(_phase, k, _mode) # Creates sym links if it's the best result if best_epoch == iepoch: p = output_dir / f"{_phase}.{k}.best.pth" if p.is_symlink() or p.exists(): p.unlink() p.symlink_to(f"{iepoch}epoch.pth") _improved.append(f"{_phase}.{k}") if len(_improved) == 0: logging.info("There are no improvements in this epoch") else: logging.info("The best model has been updated: " + ", ".join(_improved)) log_model = (trainer_options.wandb_model_log_interval > 0 and iepoch % trainer_options.wandb_model_log_interval == 0) if log_model and trainer_options.use_wandb: import wandb logging.info("Logging Model on this epoch :::::") artifact = wandb.Artifact( name=f"model_{wandb.run.id}", type="model", metadata={"improved": _improved}, ) artifact.add_file(str(output_dir / f"{iepoch}epoch.pth")) aliases = [ f"epoch-{iepoch}", "best" if best_epoch == iepoch else "", ] wandb.log_artifact(artifact, aliases=aliases) # 6. Remove the model files excluding n-best epoch and latest epoch _removed = [] # Get the union set of the n-best among multiple criterion nbests = set().union(*[ set( reporter.sort_epochs(ph, k, m) [:max(keep_nbest_models)]) for ph, k, m in trainer_options.best_model_criterion if reporter.has(ph, k) ]) # Generated n-best averaged model if (trainer_options.nbest_averaging_interval > 0 and iepoch % trainer_options.nbest_averaging_interval == 0): average_nbest_models( reporter=reporter, output_dir=output_dir, best_model_criterion=trainer_options. best_model_criterion, nbest=keep_nbest_models, suffix=f"till{iepoch}epoch", ) for e in range(1, iepoch): p = output_dir / f"{e}epoch.pth" if p.exists() and e not in nbests: p.unlink() _removed.append(str(p)) if len(_removed) != 0: logging.info("The model files were removed: " + ", ".join(_removed)) # 7. If any updating haven't happened, stops the training if all_steps_are_invalid: logging.warning( "The gradients at all steps are invalid in this epoch. " f"Something seems wrong. This training was stopped at {iepoch}epoch" ) break # 8. Check early stopping if trainer_options.patience is not None: if reporter.check_early_stopping( trainer_options.patience, *trainer_options.early_stopping_criterion): break else: logging.info( f"The training was finished at {trainer_options.max_epoch} epochs " ) # Generated n-best averaged model if not distributed_option.distributed or distributed_option.dist_rank == 0: average_nbest_models( reporter=reporter, output_dir=output_dir, best_model_criterion=trainer_options.best_model_criterion, nbest=keep_nbest_models, )
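cls.resume(...) is not shown above; the helper below is only a sketch of what restoring the checkpoint.pth written in step 4 could look like, matching its keys ("model", "reporter", "optimizers", "schedulers", "scaler"). It is an illustration, not ESPnet's actual implementation:

import torch

def resume_from_checkpoint(checkpoint, model, optimizers, schedulers, reporter, scaler):
    states = torch.load(checkpoint, map_location="cpu")
    model.load_state_dict(states["model"])
    reporter.load_state_dict(states["reporter"])
    for optimizer, state in zip(optimizers, states["optimizers"]):
        optimizer.load_state_dict(state)
    for scheduler, state in zip(schedulers, states["schedulers"]):
        if scheduler is not None and state is not None:
            scheduler.load_state_dict(state)
    if scaler is not None and states.get("scaler") is not None:
        scaler.load_state_dict(states["scaler"])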
class Trainer(): def __init__( self, name = 'default', results_dir = 'results', models_dir = 'models', base_dir = './', optimizer = 'adam', num_workers = None, latent_dim = 256, image_size = 128, num_image_tiles = 8, fmap_max = 512, transparent = False, greyscale = False, batch_size = 4, gp_weight = 10, gradient_accumulate_every = 1, attn_res_layers = [], freq_chan_attn = False, disc_output_size = 5, dual_contrast_loss = False, antialias = False, lr = 2e-4, lr_mlp = 1., ttur_mult = 1., save_every = 1000, evaluate_every = 1000, aug_prob = None, aug_types = ['translation', 'cutout'], dataset_aug_prob = 0., calculate_fid_every = None, calculate_fid_num_images = 12800, clear_fid_cache = False, is_ddp = False, rank = 0, world_size = 1, log = False, amp = False, *args, **kwargs ): self.GAN_params = [args, kwargs] self.GAN = None self.name = name base_dir = Path(base_dir) self.base_dir = base_dir self.results_dir = base_dir / results_dir self.models_dir = base_dir / models_dir self.fid_dir = base_dir / 'fid' / name self.config_path = self.models_dir / name / '.config.json' assert is_power_of_two(image_size), 'image size must be a power of 2 (64, 128, 256, 512, 1024)' assert all(map(is_power_of_two, attn_res_layers)), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)' assert not (dual_contrast_loss and disc_output_size > 1), 'discriminator output size cannot be greater than 1 if using dual contrastive loss' self.image_size = image_size self.num_image_tiles = num_image_tiles self.latent_dim = latent_dim self.fmap_max = fmap_max self.transparent = transparent self.greyscale = greyscale assert (int(self.transparent) + int(self.greyscale)) < 2, 'you can only set either transparency or greyscale' self.aug_prob = aug_prob self.aug_types = aug_types self.lr = lr self.optimizer = optimizer self.num_workers = num_workers self.ttur_mult = ttur_mult self.batch_size = batch_size self.gradient_accumulate_every = gradient_accumulate_every self.gp_weight = gp_weight self.evaluate_every = evaluate_every self.save_every = save_every self.steps = 0 self.attn_res_layers = attn_res_layers self.freq_chan_attn = freq_chan_attn self.disc_output_size = disc_output_size self.antialias = antialias self.dual_contrast_loss = dual_contrast_loss self.d_loss = 0 self.g_loss = 0 self.last_gp_loss = None self.last_recon_loss = None self.last_fid = None self.init_folders() self.loader = None self.dataset_aug_prob = dataset_aug_prob self.calculate_fid_every = calculate_fid_every self.calculate_fid_num_images = calculate_fid_num_images self.clear_fid_cache = clear_fid_cache self.is_ddp = is_ddp self.is_main = rank == 0 self.rank = rank self.world_size = world_size self.syncbatchnorm = is_ddp self.amp = amp self.G_scaler = GradScaler(enabled = self.amp) self.D_scaler = GradScaler(enabled = self.amp) @property def image_extension(self): return 'jpg' if not self.transparent else 'png' @property def checkpoint_num(self): return floor(self.steps // self.save_every) def init_GAN(self): args, kwargs = self.GAN_params # set some global variables before instantiating GAN global norm_class global Blur norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d Blur = nn.Identity if not self.antialias else Blur # handle bugs when # switching from multi-gpu back to single gpu if self.syncbatchnorm and not self.is_ddp: import torch.distributed as dist os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=0, world_size=1) # instantiate GAN 
self.GAN = LightweightGAN( optimizer=self.optimizer, lr = self.lr, latent_dim = self.latent_dim, attn_res_layers = self.attn_res_layers, freq_chan_attn = self.freq_chan_attn, image_size = self.image_size, ttur_mult = self.ttur_mult, fmap_max = self.fmap_max, disc_output_size = self.disc_output_size, transparent = self.transparent, greyscale = self.greyscale, rank = self.rank, *args, **kwargs ) if self.is_ddp: ddp_kwargs = {'device_ids': [self.rank], 'output_device': self.rank, 'find_unused_parameters': True} self.G_ddp = DDP(self.GAN.G, **ddp_kwargs) self.D_ddp = DDP(self.GAN.D, **ddp_kwargs) self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs) def write_config(self): self.config_path.write_text(json.dumps(self.config())) def load_config(self): config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text()) self.image_size = config['image_size'] self.transparent = config['transparent'] self.syncbatchnorm = config['syncbatchnorm'] self.disc_output_size = config['disc_output_size'] self.greyscale = config.pop('greyscale', False) self.attn_res_layers = config.pop('attn_res_layers', []) self.freq_chan_attn = config.pop('freq_chan_attn', False) self.optimizer = config.pop('optimizer', 'adam') self.fmap_max = config.pop('fmap_max', 512) del self.GAN self.init_GAN() def config(self): return { 'image_size': self.image_size, 'transparent': self.transparent, 'greyscale': self.greyscale, 'syncbatchnorm': self.syncbatchnorm, 'disc_output_size': self.disc_output_size, 'optimizer': self.optimizer, 'attn_res_layers': self.attn_res_layers, 'freq_chan_attn': self.freq_chan_attn } def set_data_src(self, folder): num_workers = default(self.num_workers, math.ceil(NUM_CORES / self.world_size)) self.dataset = ImageDataset(folder, self.image_size, transparent = self.transparent, greyscale = self.greyscale, aug_prob = self.dataset_aug_prob) sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True) if self.is_ddp else None dataloader = DataLoader(self.dataset, num_workers = num_workers, batch_size = math.ceil(self.batch_size / self.world_size), sampler = sampler, shuffle = not self.is_ddp, drop_last = True, pin_memory = True) self.loader = cycle(dataloader) # auto set augmentation prob for user if dataset is detected to be low num_samples = len(self.dataset) if not exists(self.aug_prob) and num_samples < 1e5: self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6) print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%') def train(self): assert exists(self.loader), 'You must first initialize the data source with `.set_data_src(<folder of images>)`' device = torch.device(f'cuda:{self.rank}') if not exists(self.GAN): self.init_GAN() self.GAN.train() total_disc_loss = torch.zeros([], device=device) total_gen_loss = torch.zeros([], device=device) batch_size = math.ceil(self.batch_size / self.world_size) image_size = self.GAN.image_size latent_dim = self.GAN.latent_dim aug_prob = default(self.aug_prob, 0) aug_types = self.aug_types aug_kwargs = {'prob': aug_prob, 'types': aug_types} G = self.GAN.G if not self.is_ddp else self.G_ddp D = self.GAN.D if not self.is_ddp else self.D_ddp D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp apply_gradient_penalty = self.steps % 4 == 0 # amp related contexts and functions amp_context = autocast if self.amp else null_context # discriminator loss fn if self.dual_contrast_loss: D_loss_fn = dual_contrastive_loss else: D_loss_fn = hinge_loss # train discriminator 
self.GAN.D_opt.zero_grad() for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, G]): latents = torch.randn(batch_size, latent_dim).cuda(self.rank) image_batch = next(self.loader).cuda(self.rank) image_batch.requires_grad_() with amp_context(): with torch.no_grad(): generated_images = G(latents) fake_output, fake_output_32x32, _ = D_aug(generated_images, detach = True, **aug_kwargs) real_output, real_output_32x32, real_aux_loss = D_aug(image_batch, calc_aux_loss = True, **aug_kwargs) real_output_loss = real_output fake_output_loss = fake_output divergence = D_loss_fn(real_output_loss, fake_output_loss) divergence_32x32 = D_loss_fn(real_output_32x32, fake_output_32x32) disc_loss = divergence + divergence_32x32 aux_loss = real_aux_loss disc_loss = disc_loss + aux_loss if apply_gradient_penalty: outputs = [real_output, real_output_32x32] outputs = list(map(self.D_scaler.scale, outputs)) if self.amp else outputs scaled_gradients = torch_grad(outputs=outputs, inputs=image_batch, grad_outputs=list(map(lambda t: torch.ones(t.size(), device = image_batch.device), outputs)), create_graph=True, retain_graph=True, only_inputs=True)[0] inv_scale = safe_div(1., self.D_scaler.get_scale()) if self.amp else 1. if inv_scale != float('inf'): gradients = scaled_gradients * inv_scale with amp_context(): gradients = gradients.reshape(batch_size, -1) gp = self.gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean() if not torch.isnan(gp): disc_loss = disc_loss + gp self.last_gp_loss = gp.clone().detach().item() with amp_context(): disc_loss = disc_loss / self.gradient_accumulate_every disc_loss.register_hook(raise_if_nan) self.D_scaler.scale(disc_loss).backward() total_disc_loss += divergence self.last_recon_loss = aux_loss.item() self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every) self.D_scaler.step(self.GAN.D_opt) self.D_scaler.update() # generator loss fn if self.dual_contrast_loss: G_loss_fn = dual_contrastive_loss G_requires_calc_real = True else: G_loss_fn = gen_hinge_loss G_requires_calc_real = False # train generator self.GAN.G_opt.zero_grad() for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[G, D_aug]): latents = torch.randn(batch_size, latent_dim).cuda(self.rank) if G_requires_calc_real: image_batch = next(self.loader).cuda(self.rank) image_batch.requires_grad_() with amp_context(): generated_images = G(latents) fake_output, fake_output_32x32, _ = D_aug(generated_images, **aug_kwargs) real_output, real_output_32x32, _ = D_aug(image_batch, **aug_kwargs) if G_requires_calc_real else (None, None, None) loss = G_loss_fn(fake_output, real_output) loss_32x32 = G_loss_fn(fake_output_32x32, real_output_32x32) gen_loss = loss + loss_32x32 gen_loss = gen_loss / self.gradient_accumulate_every gen_loss.register_hook(raise_if_nan) self.G_scaler.scale(gen_loss).backward() total_gen_loss += loss self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every) self.G_scaler.step(self.GAN.G_opt) self.G_scaler.update() # calculate moving averages if self.is_main and self.steps % 10 == 0 and self.steps > 20000: self.GAN.EMA() if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2: self.GAN.reset_parameter_averaging() # save from NaN errors if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)): print(f'NaN detected for generator or discriminator. 
Loading from checkpoint #{self.checkpoint_num}') self.load(self.checkpoint_num) raise NanException del total_disc_loss del total_gen_loss # periodically save results if self.is_main: if self.steps % self.save_every == 0: self.save(self.checkpoint_num) if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 20000): self.evaluate(floor(self.steps / self.evaluate_every), num_image_tiles = self.num_image_tiles) if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0: num_batches = math.ceil(self.calculate_fid_num_images / self.batch_size) fid = self.calculate_fid(num_batches) self.last_fid = fid with open(str(self.results_dir / self.name / f'fid_scores.txt'), 'a') as f: f.write(f'{self.steps},{fid}\n') self.steps += 1 @torch.no_grad() def evaluate(self, num = 0, num_image_tiles = 4): self.GAN.eval() ext = self.image_extension num_rows = num_image_tiles latent_dim = self.GAN.latent_dim image_size = self.GAN.image_size # latents and noise latents = torch.randn((num_rows ** 2, latent_dim)).cuda(self.rank) # regular generated_images = self.generate_(self.GAN.G, latents) torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows) # moving averages generated_images = self.generate_(self.GAN.GE, latents) torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows) @torch.no_grad() def generate(self, num=0, num_image_tiles=4, checkpoint=None, types=['default', 'ema']): self.GAN.eval() latent_dim = self.GAN.latent_dim dir_name = self.name + str('-generated-') + str(checkpoint) dir_full = Path().absolute() / self.results_dir / dir_name ext = self.image_extension if not dir_full.exists(): os.mkdir(dir_full) # regular if 'default' in types: for i in tqdm(range(num_image_tiles), desc='Saving generated default images'): latents = torch.randn((1, latent_dim)).cuda(self.rank) generated_image = self.generate_(self.GAN.G, latents) path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}.{ext}') torchvision.utils.save_image(generated_image[0], path, nrow=1) # moving averages if 'ema' in types: for i in tqdm(range(num_image_tiles), desc='Saving generated EMA images'): latents = torch.randn((1, latent_dim)).cuda(self.rank) generated_image = self.generate_(self.GAN.GE, latents) path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}-ema.{ext}') torchvision.utils.save_image(generated_image[0], path, nrow=1) return dir_full @torch.no_grad() def show_progress(self, num_images=4, types=['default', 'ema']): checkpoints = self.get_checkpoints() assert exists(checkpoints), 'cannot find any checkpoints to create a training progress video for' dir_name = self.name + str('-progress') dir_full = Path().absolute() / self.results_dir / dir_name ext = self.image_extension latents = None zfill_length = math.ceil(math.log10(len(checkpoints))) if not dir_full.exists(): os.mkdir(dir_full) for checkpoint in tqdm(checkpoints, desc='Generating progress images'): self.load(checkpoint, print_version=False) self.GAN.eval() if checkpoint == 0: latents = torch.randn((num_images, self.GAN.latent_dim)).cuda(self.rank) # regular if 'default' in types: generated_image = self.generate_(self.GAN.G, latents) path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}.{ext}') torchvision.utils.save_image(generated_image, path, nrow=num_images) # moving averages if 'ema' in types: generated_image = 
self.generate_(self.GAN.GE, latents) path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}-ema.{ext}') torchvision.utils.save_image(generated_image, path, nrow=num_images) @torch.no_grad() def calculate_fid(self, num_batches): from pytorch_fid import fid_score torch.cuda.empty_cache() real_path = self.fid_dir / 'real' fake_path = self.fid_dir / 'fake' # remove any existing files used for fid calculation and recreate directories if not real_path.exists() or self.clear_fid_cache: rmtree(real_path, ignore_errors=True) os.makedirs(real_path) for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'): real_batch = next(self.loader) for k, image in enumerate(real_batch.unbind(0)): ind = k + batch_num * self.batch_size torchvision.utils.save_image(image, real_path / f'{ind}.png') # generate a bunch of fake images in results / name / fid_fake rmtree(fake_path, ignore_errors=True) os.makedirs(fake_path) self.GAN.eval() ext = self.image_extension latent_dim = self.GAN.latent_dim image_size = self.GAN.image_size for batch_num in tqdm(range(num_batches), desc='calculating FID - saving generated'): # latents and noise latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank) # moving averages generated_images = self.generate_(self.GAN.GE, latents) for j, image in enumerate(generated_images.unbind(0)): ind = j + batch_num * self.batch_size torchvision.utils.save_image(image, str(fake_path / f'{str(ind)}-ema.{ext}')) return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, latents.device, 2048) @torch.no_grad() def generate_(self, G, style, num_image_tiles = 8): generated_images = evaluate_in_chunks(self.batch_size, G, style) return generated_images.clamp_(0., 1.) @torch.no_grad() def generate_interpolation(self, num = 0, num_image_tiles = 8, num_steps = 100, save_frames = False): self.GAN.eval() ext = self.image_extension num_rows = num_image_tiles latent_dim = self.GAN.latent_dim image_size = self.GAN.image_size # latents and noise latents_low = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank) latents_high = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank) ratios = torch.linspace(0., 8., num_steps) frames = [] for ratio in tqdm(ratios): interp_latents = slerp(ratio, latents_low, latents_high) generated_images = self.generate_(self.GAN.GE, interp_latents) images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows) pil_image = transforms.ToPILImage()(images_grid.cpu()) if self.transparent: background = Image.new('RGBA', pil_image.size, (255, 255, 255)) pil_image = Image.alpha_composite(background, pil_image) frames.append(pil_image) frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True) if save_frames: folder_path = (self.results_dir / self.name / f'{str(num)}') folder_path.mkdir(parents=True, exist_ok=True) for ind, frame in enumerate(frames): frame.save(str(folder_path / f'{str(ind)}.{ext}')) def print_log(self): data = [ ('G', self.g_loss), ('D', self.d_loss), ('GP', self.last_gp_loss), ('SS', self.last_recon_loss), ('FID', self.last_fid) ] data = [d for d in data if exists(d[1])] log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data)) print(log) def model_name(self, num): return str(self.models_dir / self.name / f'model_{num}.pt') def init_folders(self): (self.results_dir / self.name).mkdir(parents=True, exist_ok=True) (self.models_dir / self.name).mkdir(parents=True, exist_ok=True) def 
clear(self): rmtree(str(self.models_dir / self.name), True) rmtree(str(self.results_dir / self.name), True) rmtree(str(self.fid_dir), True) rmtree(str(self.config_path), True) self.init_folders() def save(self, num): save_data = { 'GAN': self.GAN.state_dict(), 'version': __version__, 'G_scaler': self.G_scaler.state_dict(), 'D_scaler': self.D_scaler.state_dict() } torch.save(save_data, self.model_name(num)) self.write_config() def load(self, num=-1, print_version=True): self.load_config() name = num if num == -1: checkpoints = self.get_checkpoints() if not exists(checkpoints): return name = checkpoints[-1] print(f'continuing from previous epoch - {name}') self.steps = name * self.save_every load_data = torch.load(self.model_name(name)) if print_version and 'version' in load_data and self.is_main: print(f"loading from version {load_data['version']}") try: self.GAN.load_state_dict(load_data['GAN']) except Exception as e: print('unable to load save model. please try downgrading the package to the version specified by the saved model') raise e if 'G_scaler' in load_data: self.G_scaler.load_state_dict(load_data['G_scaler']) if 'D_scaler' in load_data: self.D_scaler.load_state_dict(load_data['D_scaler']) def get_checkpoints(self): file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')] saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths)) if len(saved_nums) == 0: return None return saved_nums
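A minimal usage sketch for the Trainer above, assuming a folder of training images at ./images and a single GPU; the values are placeholders rather than recommended settings:

trainer = Trainer(name='demo', image_size=128, batch_size=4, amp=True)
trainer.load(-1)                      # resume from the latest model_*.pt, if any exists
trainer.set_data_src('./images')
for _ in range(1000):
    trainer.train()                   # one D/G step; saves and evaluates periodically
    if trainer.steps % 50 == 0:
        trainer.print_log()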
def distributed_worker(device, ngpus_per_node, args): torch.cuda.set_device(device) cudnn.benchmark = True print('%s: Use GPU: %d for training' % (time.ctime(), args.gpu_no[device])) rank = args.rank * ngpus_per_node + device batch_size = int(args.batch_size / ngpus_per_node) num_workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node) # init process for distributed training dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=rank) # load network network, optimizer, scheduler, loss_calculator = load_network(args, device) if device == 0: summary(network, input_size=(3, 512, 512)) # load dataset dataset = load_dataset(args) sampler = torch.utils.data.distributed.DistributedSampler(dataset) dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=sampler, collate_fn=dataset.collate_fn) # gradient scaler for automatic mixed precision scaler = GradScaler() if args.amp else None # training for epoch in range(args.start_epoch, args.end_epoch): sampler.set_epoch(epoch) # train one epoch train_step(dataloader, network, loss_calculator, optimizer, scheduler, scaler, epoch, device, args) # adjust learning rate scheduler.step() # save network if rank % ngpus_per_node == 0: torch.save( { 'epoch': epoch + 1, 'state_dict': network.module.state_dict() if hasattr(network, 'module') else network.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'scaler': scaler.state_dict() if scaler is not None else None, 'loss_log': loss_calculator.log }, os.path.join(args.save_path, 'check_point_%d.pth' % (epoch + 1))) return None
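train_step itself is not shown; the sketch below only illustrates how the optional scaler created above (None when AMP is off) might be handled inside a single optimization step. Names are assumptions:

from torch.cuda.amp import autocast

def optimization_step(network, loss_calculator, optimizer, scaler, images, targets):
    optimizer.zero_grad()
    with autocast(enabled=scaler is not None):
        loss = loss_calculator(network(images), targets)
    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    return loss.detach()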
# if epoch%2==0: print(str(epoch) + ' ' + curr_train_stage + " loss: " + str(mean_volume_loss) + " time: " + str(mean_time)) if os.path.isfile(main_folder+'exit_file.txt'): torch.cuda.empty_cache() sys.exit(0) if epoch%25==0: torch.save({ 'epoch': epoch, 'args' : args, 'args_SLNet' : argsSLNet, 'statistics' : stats, 'model_state_dict': net_get_params(net).state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scaler_state_dict' : scaler.state_dict(), 'loss': mean_volume_loss}, save_folder + '/model_')#+str(epoch)) if epoch%50==0: torch.save({ 'epoch': epoch, 'args' : args, 'args_SLNet' : argsSLNet, 'statistics' : stats, 'model_state_dict': net_get_params(net).state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scaler_state_dict' : scaler.state_dict(), 'loss': mean_volume_loss}, save_folder + '/model_'+str(epoch))
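The checkpoint dictionary written every 25/50 epochs above can be restored with a helper along these lines; net_get_params and the net/optimizer/scaler passed in are assumed to be the same objects used by the training script, and restore_latest is a hypothetical name:

import torch

def restore_latest(path, net, optimizer, scaler):
    checkpoint = torch.load(path, map_location='cpu')
    net_get_params(net).load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    return checkpoint['epoch'] + 1    # epoch to resume from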
def training(args): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #===================================# #==============Logging==============# #===================================# logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) handler = TqdmLoggingHandler() handler.setFormatter(logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S")) logger.addHandler(handler) logger.propagate = False #===================================# #============Data Load==============# #===================================# # 1) Dataloader setting write_log(logger, "Load data...") gc.disable() transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) dataset_dict = { 'train': torchvision.datasets.CIFAR10(root='./dataset/cifar10', train=True, download=False, transform=transform), 'valid': torchvision.datasets.CIFAR10(root='./dataset/cifar10', train=False, download=False, transform=transform) } dataloader_dict = { 'train': DataLoader(dataset_dict['train'], drop_last=True, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.num_workers), 'valid': DataLoader(dataset_dict['valid'], drop_last=False, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.num_workers) } gc.enable() write_log(logger, f"Total number of trainingsets iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}") #===================================# #===========Model setting===========# #===================================# # 1) Model initiating write_log(logger, "Instantiating models...") model = Vision_Transformer(n_classes=10, img_size=32, patch_size=16) model.train() model = model.to(device) # 2) Optimizer setting # optimizer = AdamW(model.parameters(), lr=args.lr, eps=1e-8) optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=1e-8) scheduler = shceduler_select(optimizer, dataloader_dict, args) scaler = GradScaler() # 2) Model resume start_epoch = 0 if args.resume: checkpoint = torch.load(os.path.join(args.model_path, 'checkpoint.pth.tar'), map_location='cpu') start_epoch = checkpoint['epoch'] + 1 model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) scheduler.load_state_dict(checkpoint['scheduler']) model = model.train() model = model.to(device) del checkpoint #===================================# #=========Model Train Start=========# #===================================# best_val_acc = 0 write_log(logger, 'Train start!') for epoch in range(start_epoch, args.num_epochs): train_epoch(args, epoch, model, dataloader_dict['train'], optimizer, scheduler, scaler, logger, device) val_loss, val_acc = valid_epoch(args, model, dataloader_dict['valid'], device) val_loss /= len(dataloader_dict['valid']) val_acc /= len(dataloader_dict['valid']) write_log(logger, 'Validation Loss: %3.3f' % val_loss) write_log(logger, 'Validation Accuracy: %3.2f%%' % val_acc) if val_acc > best_val_acc: write_log(logger, 'Checkpoint saving...') torch.save({ 'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'scaler': scaler.state_dict() }, f'checkpoint.pth.tar') best_val_acc = val_acc best_epoch = epoch else: else_log = f'Still {best_epoch} epoch accuracy({round(best_val_acc, 2)})% is better...' write_log(logger, else_log) # 3) print(f'Best Epoch: {best_epoch}') print(f'Best Accuracy: {round(best_val_acc, 2)}')
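One note on the resume branch above: the checkpoint stores a 'scaler' entry, but only the model, optimizer and scheduler are reloaded. If the AMP loss-scale state should survive a restart, it could be restored inside the `if args.resume:` block, before `del checkpoint` (sketch, using the same key as the save call):

if 'scaler' in checkpoint:
    scaler.load_state_dict(checkpoint['scaler'])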
'L_attr': L_attr.item(), 'L_rec': L_rec.item() }, niter) writer.add_scalars('Train/Adversarial losses', { 'Generator': lossG.item(), 'Discriminator': lossD.item() }, niter) print( f'niter: {niter} (epoch: {epoch} {iteration}/{len(train_dataloader)})') print( f' lossD: {lossD.item()} lossG: {lossG.item()} batch_time: {batch_time}s' ) print( f' L_adv: {L_adv.item()} L_id: {L_id.item()} L_attr: {L_attr.item()} L_rec: {L_rec.item()}' ) if iteration % 1000 == 0: torch.save(G.state_dict(), './saved_models/G_latest.pth') torch.save(D.state_dict(), './saved_models/D_latest.pth') torch.save(opt_G.state_dict(), './saved_models/optG_latest.pth') torch.save(opt_D.state_dict(), './saved_models/optD_latest.pth') torch.save(scaler.state_dict(), './saved_models/scaler_latest.pth') with open('./saved_models/niter.pkl', 'wb') as f: pickle.dump(niter, f) if (niter + 1) % 10000 == 0: torch.save(G.state_dict(), f'./saved_models/G_iteration_{niter + 1}.pth') torch.save(D.state_dict(), f'./saved_models/D_iteration_{niter + 1}.pth') with open(f'./saved_models/niter_{niter + 1}.pkl', 'wb') as f: pickle.dump(niter, f)
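Restoring the *_latest.pth files written above so training can resume at the same iteration might look like the sketch below; G, D, opt_G, opt_D and scaler are assumed to be constructed the same way as during training:

import pickle
import torch

G.load_state_dict(torch.load('./saved_models/G_latest.pth', map_location='cpu'))
D.load_state_dict(torch.load('./saved_models/D_latest.pth', map_location='cpu'))
opt_G.load_state_dict(torch.load('./saved_models/optG_latest.pth'))
opt_D.load_state_dict(torch.load('./saved_models/optD_latest.pth'))
scaler.load_state_dict(torch.load('./saved_models/scaler_latest.pth'))
with open('./saved_models/niter.pkl', 'rb') as f:
    niter = pickle.load(f)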
class Learner(object): def __init__(self, model, optimizer, loss_func, name="", scheduler=None, device='cpu'): self.model = model self.optimizer = optimizer self.loss_func = loss_func self.scheduler = scheduler self.scaler = None self.device = device self.metric = None self.name = name self.log = {} self.eth = 0.99 self.do_autocast = False def init_amp(self, init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True, do_autocast=True): self.do_autocast = do_autocast if GradScaler is not None: self.scaler = GradScaler(init_scale=init_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, enabled=True) def get_y(self, batch): # get Y from Batch, the default is batch[-1] but you can overwrite it return batch[-1] def get_inds(self, batch): # get Y from Batch, the default is batch[-1] but you can overwrite it return batch[-1] def get_x(self, batch): # get x from Batch, the default is batch[:-1] but you can overwrite it if isinstance(batch, (list, tuple)): return batch[:-1] else: return [batch] def run_model(self, model, batch): return model(*(x.to(self.device) for x in self.get_x(batch))) def calc_loss(self, y_pred, y_true): return self.loss_func(y_pred, y_true.to(self.device)) def one_cycle(self, batch, train=True, do_step=True): device = self.device self.preprocess_batch(batch, train) y_true = self.get_y(batch) if autocast is None: y_pred = self.run_model(self.model, batch) loss = self.calc_loss(y_pred, y_true) loss_item = 0 if np.isnan(loss.item()) else loss.item() else: with autocast(self.do_autocast): y_pred = self.run_model(self.model, batch) loss = self.calc_loss(y_pred, y_true) loss_item = 0 if np.isnan(loss.item()) else loss.item() if train: if self.scaler is not None: self.scaler.scale(loss).backward() else: loss.backward() if do_step: if self.scaler is not None: self.scaler.step(self.optimizer) self.scaler.update() else: self.optimizer.step() if self.scheduler is not None: self.scheduler.step() self.optimizer.zero_grad() if np.isnan(loss.item()): print('got loss = nan') loss_item = 0 if np.isnan(loss.item()) else loss.item() return loss_item if train else (loss_item, y_pred.to('cpu').detach()) def one_training_epoch(self, dl, accumulation_steps=1): device = self.device torch.cuda.empty_cache() avg_loss = 0. lossf = 0. self.model = self.model.train() self.model.zero_grad() tk0 = notebook.tqdm(dl) for i, batch in enumerate(tk0): do_step = (i + 1) % accumulation_steps == 0 loss_item = self.one_cycle(batch, train=True, do_step=do_step) e = min(self.eth, 1 - 1.0 / (i + 1.0)) lossf = e * lossf + (1 - e) * loss_item tk0.set_postfix(loss=lossf) avg_loss += loss_item / len(dl) tk0.disable = False tk0.set_postfix(loss=avg_loss) tk0.disable = True return avg_loss def agg_tta(self, y): return np.stack(y,0).mean(0) if not isinstance(y[0],tuple)\ else tuple(np.stack([yy[i] for yy in y],0).mean(0) for i in range(len(y[0]))) def preprocess_batch(self, batch, train=True): return (batch) def one_eval_epoch(self, dl, tta=1): device = self.device avg_loss = 0. avg_accuracy = 0. 
lossf = 0 self.model = self.model.eval() predss = [] with torch.no_grad(): for t in range(tta): pred_list = [] true_list = [] tk0 = notebook.tqdm(dl) for i, batch in enumerate(tk0): loss_item, y_pred = self.one_cycle(batch, train=False, do_step=False) pred_list.append(y_pred.to('cpu').numpy() if not isinstance(y_pred,tuple) else\ tuple(y.to('cpu').numpy() for y in y_pred)) y_batch = self.get_y(batch) true_list.append(y_batch.to('cpu').numpy() if not isinstance(y_batch,tuple) else\ tuple(y.to('cpu').numpy() for y in y_batch)) e = min(self.eth, 1 - 1.0 / (i + 1.0)) lossf = e * lossf + (1 - e) * loss_item tk0.set_postfix(loss=lossf) avg_loss += loss_item / len(dl) # y_true=np.concatenate(true_list,0) y_true=np.concatenate(true_list,0) if not isinstance(true_list[0],tuple) else\ tuple(np.concatenate([p[i] for p in true_list],0) for i in range(len(true_list[0]))) predss.append(np.concatenate(pred_list,0) if not isinstance(pred_list[0],tuple) else\ tuple(np.concatenate([p[i] for p in pred_list],0) for i in range(len(pred_list[0])))) preds = self.agg_tta(predss, 0) if tta > 1 else predss[0] m = dict() if self.metric is None else self.metric(preds, y_true) tk0.disable = False tk0.set_postfix(loss=avg_loss, **m) tk0.disable = True return avg_loss, m def send_log(self, **kwargs): log = {'model': self.name} log.update(kwargs) try: sandesh.send(log) except: print(log) def save2log(self, **kwargs): for key in kwargs.keys(): if key not in self.log: self.log[key] = [] self.log[key].append(kwargs[key]) def evaluate(self, ds, num_workers=8, tta=1, dl_args={'shuffle': False}): dl = D.DataLoader(ds, num_workers=num_workers, **dl_args) return self.one_eval_epoch(dl, tta=tta) def fit(self, num_epoches, train_ds, validate_ds=None, batch_size=None, lr=None, accumulation_steps=1, num_workers=8, send_log=True, eval_batch=None, reset_best=False, make_best=True, tta=1, train_dl_args={'shuffle': True}, val_dl_args={'shuffle': False}, save_checkpoint='best', path=''): if batch_size is not None: train_dl_args['batch_size'] = batch_size val_dl_args['batch_size'] = batch_size if eval_batch is not None: val_dl_args['batch_size'] = eval_batch tq = notebook.tqdm(range(num_epoches)) if lr is not None: self.set_lr(lr) if reset_best or not hasattr(self, 'best_metric'): self.best_model = None self.best_metric = np.inf for k, epoch in enumerate(tq): self.on_epoch_begin(epoch, train_ds=train_ds, validate_ds=validate_ds) dl = D.DataLoader(train_ds, num_workers=num_workers, **train_dl_args) if next(self.model.parameters()).device != torch.device('cpu'): torch.cuda.empty_cache() tavg_loss = self.one_training_epoch( dl, accumulation_steps=accumulation_steps) # dl=D.DataLoader(validate_ds, batch_size=batch_size if eval_batch is None else eval_batch, # num_workers=num_workers,**val_dl_args) if validate_ds is not None: avg_loss, metric = self.evaluate(validate_ds, num_workers=num_workers, dl_args=val_dl_args, tta=tta) else: avg_loss = tavg_loss metric = {} if send_log: self.send_log(epoch=epoch, tloss=tavg_loss, loss=avg_loss, **metric) self.save2log(epoch=epoch, tloss=tavg_loss, loss=avg_loss, **metric) m = avg_loss if 'metric' not in metric.keys() else metric['metric'] if save_checkpoint == 'last': self.save_checkpoint(path) if m < self.best_metric: self.best_metric = m self.best_model = copy.deepcopy(self.model.state_dict()) tq.set_postfix(best_metric=self.best_metric) if save_checkpoint == 'best': self.save_checkpoint(path) self.on_epoch_end(epoch) print('best metric:', self.best_metric) if make_best: 
self.model.load_state_dict(self.best_model) def save_model(self, path, name=None): name = self.name if name is None else name torch.save(self.model.state_dict(), f'{path}{name}') def load_model(self, path, name=None, map_location=None): name = self.name if name is None else name self.model.load_state_dict( torch.load(f'{path}{name}', map_location=map_location)) def save_checkpoint(self, path, name=None): name = self.name + '.chk' if name is None else name checkpoint = { 'model': self.model.state_dict(), 'best_model': self.best_model, 'best_metric': self.best_metric, 'model': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'log': self.log } if self.scaler: checkpoint['scaler'] = self.scaler.state_dict() torch.save(checkpoint, f'{path}{name}') def load_checkpoint(self, path, name=None): name = self.name + '.chk' if name is None else name + '.chk' checkpoint = torch.load(f'{path}{name}') self.model.load_state_dict(checkpoint['model']) self.best_model = checkpoint['best_model'] self.best_metric = checkpoint['best_metric'] self.optimizer.load_state_dict(checkpoint['optimizer']) self.log = checkpoint['log'] if 'scaler' in checkpoint.keys(): self.scaler = GradScaler() self.scaler.load_state_dict(checkpoint['scaler']) else: self.scaler = None def set_lr(self, lr): for param_group in self.optimizer.param_groups: param_group['lr'] = lr def on_epoch_begin(self, *args, **kargs): pass def on_epoch_end(self, *args, **kargs): pass def predict(self, ds, batch_size=None, num_workers=8, dl_args={'shuffle': False}, return_inds=False, return_true=False, verbose=True, do_eval=True): device = self.device if batch_size is not None: dl_args['batch_size'] = batch_size dl = D.DataLoader(ds, num_workers=num_workers, **dl_args) pred_list = [] inds_list = [] true_list = [] if do_eval: self.model = self.model.eval() with torch.no_grad(): tk0 = notebook.tqdm(dl) if verbose else dl for i, batch in enumerate(tk0): if autocast is None: y_pred = self.run_model(self.model, batch) else: with autocast(self.scaler is not None): y_pred = self.run_model(self.model, batch) if return_inds: inds_list.append(self.get_inds(batch).to('cpu').numpy()) if return_true: yb = self.get_y(batch) true_list.append(yb.to('cpu').numpy() if not isinstance(yb,tuple) else\ tuple(y.to('cpu').numpy() for y in yb)) pred_list.append(y_pred.to('cpu').numpy() if not isinstance(y_pred,tuple) else\ tuple(y.to('cpu').numpy() for y in y_pred)) pred = np.concatenate(pred_list,0) if not isinstance(pred_list[0],tuple) else\ tuple(np.concatenate([p[i] for p in pred_list],0) for i in range(len(pred_list[0]))) out = () if return_inds: out = out + (np.concatenate(inds_list, 0), ) if return_true: rt=np.concatenate(true_list,0) if not isinstance(true_list[0],tuple) else\ tuple(np.concatenate([p[i] for p in true_list],0) for i in range(len(true_list[0]))) out = out + (rt, ) return pred if len(out) == 0 else (pred, ) + out
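A minimal usage sketch for the Learner above; the model, loss and the train_dataset / valid_dataset objects are placeholders, and init_amp() turns on the GradScaler / autocast path:

import torch
import torch.nn as nn

model = nn.Linear(16, 2).cuda()
learner = Learner(model,
                  optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
                  loss_func=nn.CrossEntropyLoss(),
                  name='demo',
                  device='cuda')
learner.init_amp()                       # creates the GradScaler and enables autocast
learner.fit(num_epoches=3,
            train_ds=train_dataset,      # placeholder torch Dataset
            validate_ds=valid_dataset,   # placeholder torch Dataset
            batch_size=32,
            save_checkpoint='best',
            path='./')
learner.load_checkpoint('./')            # restores weights, optimizer, log and scaler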
class BaseModel(nn.Module): """ BaseModel This is the BaseModel used by all classifiers in this package. This base class provides a basic loop for fitting on a dataset and some convenience functions for storing and loading chckpoints. Each classifier is expected to provide - `def forward(self, X)`: A forward method which predicts / applies the model to the given batch `X`. Since a BaseModel inherits from `nn.Module` please use `self.train` to distinguish between training and testing. - `def prepare_backward(self, data, target, weights = None)`: A method which computes the loss for calling backward as well as additional statistics, such as running accuracy. The arguments are: - `data`: The examples in this batch - `target`: This is the corresponding target batch - `weights`: This is the corresponding weights per example if required. The `prepare_backward` function should return a dictionary with three fields `prediction`, `backward` and `metrics`. The `prediction` field stores the individual predictions for the batch (in the same order). The `backward` field is used to perform the gradient step and `metrics` is used to store any metrics which should be written reported. Formally, the following `backward` call is used: backward = self.prepare_backward(data, target, weights) loss = backward["backward"].mean() loss.backward() Note that the prediction / loss / metrics should be given for each individual example in the batch. __Do not reduce / sum / mean the loss etc manually__. This happens automatically later on. An example would be: d = { # apply the model "prediction" : self(data), # compute the loss "backward" : self.loss_function(self(data), target), # Compute some metrics "metrics" : { "loss" : self.loss_function(self(data), target).detach(), "accuracy" : 100.0*(self(data).argmax(1) == target).type(self.get_float_type()) } } The bas class also supports storing and loading of checkpoints. To do so, the implementing class must take care of its parameters / object by overriding `restore_state` and `get_state`. Note that thse function __must__ call the respective functions frm the base class: def restore_state(self,checkpoint): # Restore base state super().restore_state(checkpoint) # Extract and parameters from the chechpoint dictionary self.my_param = checkpoint["my_param"] def get_state(self): # Get base state state = super().get_state() return { **state, "my_param":self.my_param } This class already expects a fair amount of parameters. Thus, it is best to use `args` and `kwargs` to pass parameters between c'tors. The following pattern is used as a best-practice to implement new classifier: class MyClass(Model): def __init__(self, my_param, *args, **kwargs): super().__init__(*args, **kwargs) self.my_param Attributes: optimizer (dict): Dictionary of optimizer and its parameters. This dictionary is expected to have at-least two entries - `method`: The actual optimizer to be used, e.g. `torch.optim.SGD` - `epochs`: The number of epochs used for optimization. If this is not provided, it will be set to 1 Any additional field will be passed to the optimizer object: optimizer_method = optimizer.pop("method") epochs = optimizer.pop("epochs", 1) the_optimizer = optimizer_method(model.parameters(), **optimizer) An example would be optimizer = { "method" : torch.optim.SGD, "lr" : 1e-2, "epochs" : 150 } scheduler (dict): Dictionary of learning rate scheduler and its parameters. This can be `None` if no scheduling is desired. Otherwise, its expected to contain a `method` field which is the scheduler. 
Any additional field will be used to create the this object scheduler_method = scheduler.pop("method") the_scheduler = scheduler_method(the_optimizer, **scheduler) An example would be scheduler = { "method" : torch.optim.lr_scheduler.StepLR, "step_size" : 25, "gamma": 0.5 } loss_function (function): The loss function which should be minimized. Technically this class does not make use of this function, but only stores it for sub-classes. base_estimator (function): The (base) neural network which should be trained. Technically this class does not make use of this field, but only stores it for sub-classes. training_file (str, optional): Filename used to store metrics during training. Is only used if `out_path` is not None. Defaults to "trainings.jsonl" seed (long): Random seed for involved in any randomization process verbose (bool): If `true`, prints the progress of each epoch including metrics via `tqdm` else disables it out_path (str, optional): Path to the folder where training metrics should be stored. If no path is given, nothing is stored. Defaults to `None` test_data (optional): Test data which can be used to compute statistics every `eval_every` epochs. It should be compatible with PyTorch `DataLoader`, e.g. this should be a `torch.utils.data.Dataset` or a numpy error / PyTorch tensor: test_loader = torch.utils.data.DataLoader(self.test_data, **self.loader_cfg) Defaults to `None`, which means no additional metrics are computed besides the one already obtained on the training data. loader (dict, optional): Dictionary of loader parameters which are passed to `torch.utils.data.DataLoader`: train_loader = torch.utils.data.DataLoader( data, **self.loader ) The loader is used for both, the training data and `test_data` if supplied. The loader can be `None` which defaults to: self.loader = {'num_workers': 1, 'pin_memory': True, 'batch_size':128} eval_every (int, optional): Evaluates metrics on the test_data every `eval_every` epochs, if `test_data` is provided. Defaults to 5. If this is `None` no additonal metrics are computed. store_every (int, optional): Stores a checkpoint of the model every `store_every` epochs. If this is `None` no checkpoints are stored device (str, optional): The device which is used to execute the model. Should be compatible to PyTorch's keywords. Defaults to "cuda" use_amp (bool): If `true` uses mixed precision provided by PyTorch, else not. """ def __init__(self, optimizer, scheduler, loss_function, base_estimator, training_file="training.jsonl", seed=None, verbose=True, out_path=None, test_data=None, eval_every=5, store_every=None, device="cuda", loader=None, use_amp=False, *args, **kwargs): super().__init__() if isinstance( base_estimator, types.LambdaType) and base_estimator.__name__ == "<lambda>": print( "Warning: base_estimator is a lambda function in Models.py - This is fine, unless you want to store checkpoints of your model. This will likely fail since unnamed functions cannot be pickled. Consider naming it." 
) if optimizer is not None: optimizer_copy = copy.deepcopy(optimizer) self.optimizer_method = optimizer_copy.pop("method") if "epochs" in optimizer_copy: self.epochs = optimizer_copy.pop("epochs") else: self.epochs = 1 self.optimizer_cfg = optimizer_copy else: self.optimizer_cfg = None if scheduler is not None: scheduler_copy = copy.deepcopy(scheduler) self.scheduler_method = scheduler_copy.pop("method") self.scheduler_cfg = scheduler_copy else: self.scheduler_cfg = None if loader is not None: self.loader_cfg = loader else: self.loader_cfg = { 'num_workers': 1, 'pin_memory': True, 'batch_size': 128 } self.base_estimator = base_estimator self.loss_function = loss_function self.verbose = verbose self.out_path = out_path self.test_data = test_data self.seed = seed self.eval_every = eval_every self.store_every = store_every self.training_file = training_file self.cur_epoch = 0 self.resume_from_checkpoint = False self.device = device self.use_amp = use_amp if self.seed is not None: np.random.seed(self.seed) random.seed(self.seed) torch.manual_seed(self.seed) # if you are using GPU if self.device != "cpu": torch.cuda.manual_seed(self.seed) torch.cuda.manual_seed_all(self.seed) def get_float_type(self): if self.device == "cpu": return torch.FloatTensor else: return torch.cuda.FloatTensor def restore_state(self, checkpoint): self.optimizer_method = checkpoint["optimizer_method"] self.optimizer_cfg = checkpoint["optimizer_cfg"] self.scheduler_method = checkpoint["scheduler_method"] self.scheduler_cfg = checkpoint["scheduler_cfg"] self.loader_cfg = checkpoint["loader_cfg"] self.scheduler = checkpoint["scheduler"] self.base_estimator = checkpoint["base_estimator"] self.loss_function = checkpoint["loss_function"] self.verbose = checkpoint["verbose"] self.out_path = checkpoint["out_path"] self.test_data = checkpoint["test_data"] self.seed = checkpoint["seed"] self.eval_every = checkpoint["eval_every"] self.store_every = checkpoint["store_every"] self.training_file = checkpoint["training_file"] self.cur_epoch = checkpoint["cur_epoch"] self.epochs = checkpoint["epochs"] self.resume_from_checkpoint = True self.device = checkpoint["device"] self.use_amp = checkpoint["use_amp"] self.scaler = GradScaler(enabled=self.use_amp) self.scaler.load_state_dict(checkpoint['scaler_state_dict']) if self.seed is not None: np.random.seed(self.seed) random.seed(self.seed) torch.manual_seed(self.seed) # if you are using GPU if self.device != "cpu": torch.cuda.manual_seed(self.seed) torch.cuda.manual_seed_all(self.seed) self.load_state_dict(checkpoint['state_dict']) # Load the model to the correct device _before_ we init the optimizer # https://github.com/pytorch/pytorch/issues/2830 self.to(self.device) self.optimizer = self.optimizer_method(self.parameters(), **self.optimizer_cfg) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) if self.scheduler_method is not None: self.scheduler = self.scheduler_method(self.optimizer, **self.scheduler_cfg) self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) else: self.scheduler = None if self.loader_cfg is None: self.loader_cfg = { 'num_workers': 1, 'pin_memory': True, 'batch_size': 128 } def restore_checkoint(self, path): # https://github.com/pytorch/pytorch/issues/2830 checkpoint = torch.load(path, map_location=self.device) self.restore_state(checkpoint) def get_state(self): return { "optimizer_method": self.optimizer_method, "optimizer_cfg": self.optimizer_cfg, "loader_cfg": self.loader_cfg, "scheduler_method": self.scheduler_method, "scheduler_cfg": 
self.scheduler_cfg, "scheduler": self.scheduler, "base_estimator": self.base_estimator, "loss_function": self.loss_function, "verbose": self.verbose, "out_path": self.out_path, "test_data": self.test_data, "seed": self.seed, "device": self.device, "eval_every": self.eval_every, "store_every": self.store_every, "training_file": self.training_file, 'cur_epoch': self.cur_epoch, 'epochs': self.epochs, 'use_amp': self.use_amp, 'scaler_state_dict': self.scaler.state_dict() } def store_checkpoint(self): state = self.get_state() torch.save( state, os.path.join(self.out_path, 'model_{}.tar'.format(self.cur_epoch))) @abstractmethod def forward(self, X): pass @abstractmethod def prepare_backward(self, data, target, weights=None): pass def fit(self, data): if not self.resume_from_checkpoint: self.optimizer = self.optimizer_method(self.parameters(), **self.optimizer_cfg) if self.scheduler_method is not None: self.scheduler = self.scheduler_method(self.optimizer, **self.scheduler_cfg) else: self.scheduler = None self.scaler = GradScaler(enabled=self.use_amp) if self.out_path is not None: outfile = open(self.out_path + "/" + self.training_file, "w", 1) else: if self.out_path is not None: outfile = open(self.out_path + "/" + self.training_file, "a", 1) train_loader = torch.utils.data.DataLoader(data, shuffle=True, **self.loader_cfg) self.to(self.device) self.train() for epoch in range(self.cur_epoch, self.epochs): self.cur_epoch = epoch + 1 metrics = {} example_cnt = 0 with tqdm(total=len(train_loader.dataset), ncols=150, disable=not self.verbose) as pbar: self.batch_cnt = 0 for batch in train_loader: if len(batch) == 1: data = batch else: data = batch[0] data = data.to(self.device) data = Variable(data) if len(batch) > 1: target = batch[1] target = target.to(self.device) target = Variable(target) else: target = None if len(batch) > 2: weights = batch[2] weights = weights.to(self.device) weights = Variable(weights) else: weights = None example_cnt += data.shape[0] self.optimizer.zero_grad() # We assume that prepare_backward computes the appropriate loss and possible some statistics # the user wants to store / output. To do so, prepare_backward should return a dictionary with # three fields. An example is given below. Note that the prediction / loss / metrics should be # given for each individual example in the batch. # !!!! Do not reduce / sum / mean the loss etc manually !!!! # This is done afterwards in this code. 
# # d = { # "prediction" : self(data), # "backward" : self.loss_function(self(data), target), # "metrics" : # { # "loss" : self.loss_function(self(data), target), # "accuracy" : 100.0*(self(data).argmax(1) == target).type(self.get_float_type()) # } # } with autocast(enabled=self.use_amp): backward = self.prepare_backward(data, target, weights) loss = backward["backward"].mean() for key, val in backward["metrics"].items(): metrics[key] = metrics.get(key, 0) + val.sum().item() self.scaler.scale(loss).backward() self.scaler.step(self.optimizer) self.scaler.update() mstr = "" for key, val in metrics.items(): mstr += "{} {:2.4f} ".format(key, val / example_cnt) pbar.update(data.shape[0]) desc = '[{}/{}] {}'.format(epoch, self.epochs - 1, mstr) pbar.set_description(desc) self.batch_cnt += 1 if self.scheduler is not None: self.scheduler.step() #torch.cuda.empty_cache() if self.out_path is not None: out_dict = {} mstr = "" for key, val in metrics.items(): out_dict["train_" + key] = val / example_cnt mstr += "{} {:2.4f} ".format(key, val / example_cnt) if self.store_every and self.store_every > 0 and ( epoch % self.store_every) == 0: self.store_checkpoint() if self.test_data and self.eval_every and self.eval_every > 0 and ( epoch % self.eval_every) == 0: # This is basically a different version of apply_in_batches but using the "new" prepare_backward interface # for evaluating the test data. Maybe we should refactor this at some point and / or apply_in_batches # is not really needed anymore as its own function? # TODO Check if refactoring might be interestring here self.eval() test_metrics = {} test_loader = torch.utils.data.DataLoader( self.test_data, **self.loader_cfg) for batch in test_loader: test_data = batch[0] test_target = batch[1] test_data, test_target = test_data.to( self.device), test_target.to(self.device) test_data, test_target = Variable( test_data), Variable(test_target) with torch.no_grad(): backward = self.prepare_backward( test_data, test_target) for key, val in backward["metrics"].items(): test_metrics[key] = test_metrics.get( key, 0) + val.sum().item() self.train() for key, val in test_metrics.items(): out_dict["test_" + key] = val / len(test_loader.dataset) mstr += "test {} {:2.4f} ".format( key, val / len(test_loader.dataset)) desc = '[{}/{}] {}'.format(epoch, self.epochs - 1, mstr) pbar.set_description(desc) out_dict["epoch"] = epoch out_file_content = json.dumps(out_dict, sort_keys=True) + "\n" outfile.write(out_file_content) if hasattr(train_loader.dataset, "end_of_epoch"): train_loader.dataset.end_of_epoch()
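For orientation, the following is a minimal sketch (not part of the original code) of a sub-class that satisfies the abstract `forward` / `prepare_backward` interface that `fit` relies on. The class name `ClassifierModel`, the stand-in base-class name `Model`, and the assumption that `base_estimator()` builds and returns an `nn.Module` are illustrative only; the one firm requirement documented above is that `prepare_backward` returns un-reduced, per-example values under "backward" and "metrics".

import torch.nn as nn

class ClassifierModel(Model):          # `Model` stands in for the base class defined above
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layers_ = self.base_estimator()   # assumed to build and return an nn.Module

    def forward(self, X):
        return self.layers_(X)

    def prepare_backward(self, data, target, weights=None):
        output = self(data)
        # self.loss_function is assumed to be per-example, e.g.
        # nn.CrossEntropyLoss(reduction="none"); `fit` calls .mean() itself.
        loss = self.loss_function(output, target)
        if weights is not None:
            loss = loss * weights
        return {
            "prediction": output,
            "backward": loss,
            "metrics": {
                "loss": loss,
                "accuracy": 100.0 * (output.argmax(1) == target).type(self.get_float_type()),
            },
        }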
class TrainerLoop: def __init__( self, config: DictConfig, model: FlyModel, train_dataloader_fn: Callable, valid_dataloader_fn: Callable = None, test_dataloader_fn: Callable = None ): """ Args: config: FlyConfig dictionary model: must be FlyModel dataloader_fn: a Callable function which returns dataloaders """ assert isinstance(model, FlyModel) self.config = config self.model = model # For distributed self.rank = int(os.environ.get("RANK", 0)) self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) self.world_size = int(os.environ.get("WORLD_SIZE", 1)) self.distributed_training = (self.world_size > 1) if self.distributed_training and not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend='nccl', init_method='env://') assert torch.distributed.is_initialized() if self.distributed_training and not torch.distributed.is_initialized(): self.node_rank = os.environ.get("NODE_RANK", "N/A") logger.info( f"Initialized Rank:{torch.distributed.get_rank()} Locak-rank: {self.local_rank} on Node:{self.node_rank} Node-name:{socket.gethostname()}" ) logger.info("TrainerLoop is initializing!") # set cuda device if config.training.num_gpus_per_node > 0: torch.cuda.set_device(self.local_rank) self.device = torch.device("cuda", self.local_rank) else: self.device = torch.device("cpu") # Setup the dataloders self.train_dataloader = train_dataloader_fn() if train_dataloader_fn else None # only rank 0 can setup validation and test dataloder if self.rank == 0: self.validation_dataloader: Iterable = valid_dataloader_fn() if valid_dataloader_fn else None self.test_dataloader = test_dataloader_fn() if test_dataloader_fn else None # Setup callback handler self.callback_handler = CallbackHandler( config, trainer=self, callbacks=[], verbose=config.training.logging.level == "DEBUG" ) # constants self.fp16 = config.training.fp16 self.gradient_accumulation_batches = config.training.gradient_accumulation_batches self.setup_training_constants() # local variables self.global_batch_count = 0 self.global_step_count = 0 self.epochs_trained = 0 self.local_step_count = 0 # Configure optimizers self.optimizers, self.schedulers = self.model.configure_optimizers(self.total_num_update_steps) self.optimizers, self.schedulers = self.configure_optimizers() # Model is sent to GPU or CPU self.model = move_to_device(self.model, self.device) # Mixed-Precision if self.fp16: if self.config.training.num_gpus_per_node == 0: raise NotImplementedError("For mixed precision training, you need to use GPU!") self.configure_fp16() # Distributed Training if self.world_size > 1: self.configure_ddp() # Configure all callbacks self.configure_callbacks() self.callback_handler.fire_event(Events.INITIALIZE) # make sure the model has access to trainer info self.model.set_trainer(self) def setup_training_constants(self): self.total_num_update_steps = int(self.config.training.total_num.update_steps) self.total_num_batches = self.total_num_update_steps * int(self.gradient_accumulation_batches) self.total_num_epochs = int(self.config.training.total_num.epochs) # check if training in epoch or update_steps if self.total_num_update_steps < 0 and self.total_num_epochs < 0: raise NotImplementedError("config.training.total_num.updated_steps must be larger than 0") elif self.total_num_update_steps > 0 and self.total_num_epochs > 0: raise NotImplementedError( "Please only set either config.training.total_num.updated_steps or config.training.total_num.epochs greater than 0" ) elif self.total_num_update_steps > 0 and self.total_num_epochs < 0: 
self.training_in_epoch = False elif self.total_num_update_steps < 0 and self.total_num_epochs > 0: self.training_in_epoch = True # get the number of batches in the dataloader for one epoch try: self.epoch_num_batches = len(self.train_dataloader) except TypeError: logger.warning("Cannot determine the length of train_dataloader!") self.epoch_num_batches = None if self.training_in_epoch: if self.epoch_num_batches is not None: self.total_num_batches = self.epoch_num_batches * self.total_num_epochs self.total_num_update_steps = self.total_num_batches // self.gradient_accumulation_batches self.epoch_num_update_steps = self.epoch_num_batches // self.gradient_accumulation_batches else: # this is set to wait until the epoch finishes first self.total_num_update_steps = sys.maxsize def configure_optimizers(self): return self.model.configure_optimizers(self.total_num_update_steps) def configure_callbacks(self): # Resume callback runs for all ranks self.resume_callback = Resume(self.config) self.add_callback(self.resume_callback) # For logging and inference, use rank 0 by default if self.rank == 0: self.log_callback = TrainLogger(self.config) self.add_callback(self.log_callback) self.eval_callback = Evaluation(self.config) self.add_callback(self.eval_callback) if self.config.training.console: self.console_callback = Console(self.config) self.add_callback(self.console_callback) self.checkpoint_callback = Checkpoint(self.config) self.add_callback(self.checkpoint_callback) def configure_fp16(self): self.loss_scaler = GradScaler() def configure_ddp(self): """ Default distributed training uses reducer for simplicity. """ # Distributed training (should be after apex fp16 initialization) self.distributed_training = True self.reducer = Reducer(self.model) # for param in self.model.parameters(): # dist.broadcast(param.data, 0) # self.model = DistributedDataParallel(self.model, delay_allreduce=True) # trainer.model = torch.nn.parallel.DistributedDataParallel( # trainer.model, device_ids=[trainer.rank], output_device=trainer.rank, find_unused_parameters=True # ) def train(self): # Training begins self.callback_handler.fire_event(Events.TRAIN_BEGIN) while True: self.callback_handler.fire_event(Events.EPOCH_BEGIN) self.train_epoch() self.callback_handler.fire_event(Events.EPOCH_END) self.epochs_trained += 1 if self.training_in_epoch: if self.epochs_trained >= self.total_num_epochs: break else: if self.global_step_count < self.total_num_update_steps: continue else: break # Training ends self.callback_handler.fire_event(Events.TRAIN_END) def train_epoch(self): self.optimizer = self.optimizers[0] self.scheduler = self.schedulers[0] self.local_step_count = 0 if self.train_dataloader is None: return for batch in self.train_dataloader: self.callback_handler.fire_event(Events.BATCH_BEGIN) batch = move_to_device(batch, self.device) output = self.backward_batch(batch) # Update the model if (self.global_batch_count + 1) % self.gradient_accumulation_batches == 0: # Update the model with optimizer self.step_update(self.model, self.optimizer, self.scheduler) self.global_step_count += 1 self.local_step_count += 1 self.callback_handler.fire_event(Events.BATCH_END) if self.global_step_count >= self.total_num_update_steps: break self.global_batch_count += 1 def backward_batch(self, batch): self.model.train() with torch.cuda.amp.autocast(self.fp16): output = self.model(batch) # get the loss from output if hasattr(output, "loss"): loss = output.loss elif isinstance(output, dict): loss = output["loss"] if 
self.gradient_accumulation_batches > 1: loss = loss / self.gradient_accumulation_batches self.loss_backward(loss) return output def step_update(self, model, optimizer, scheduler=None): """ self.loss_scaler is defined in `configure_fp16` """ self.callback_handler.fire_event(Events.STEP_BEGIN) # collect gradient if self.distributed_training: self.reducer.reduce() gradient_clip = self.config.training.optimization.max_gradient_norm # Gradient Clipping if gradient_clip > 0: if self.fp16: self.loss_scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) else: torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) # Update the model if self.fp16: self.loss_scaler.step(optimizer) self.loss_scaler.update() else: optimizer.step() # Step learning rate if scheduler: scheduler.step() # Gradient to zero optimizer.zero_grad() self.callback_handler.fire_event(Events.STEP_END) def loss_backward(self, loss): self.callback_handler.fire_event(Events.BACKWARD_BEGIN) # Loss backward if self.fp16: self.loss_scaler.scale(loss).backward() else: loss.backward() self.callback_handler.fire_event(Events.BACKWARD_END) def validate(self): # Start Validation self.model.eval() self.model.reset_evaluation_metrics() self.callback_handler.fire_event(Events.VALIDATE_BEGIN) # No gradient is needed for validation with torch.no_grad(): pbar = tqdm.tqdm(self.validation_dataloader) pbar.mininterval = 2.0 for batch in pbar: # send to cuda device batch = move_to_device(batch, self.device) self.model.predict(batch) self.callback_handler.fire_event(Events.VALIDATE_END) def test(self): # Start Testing self.model.eval() self.model.reset_evaluation_metrics() self.callback_handler.fire_event(Events.TEST_BEGIN) # No gradient is needed for test with torch.no_grad(): pbar = tqdm.tqdm(self.test_dataloader) pbar.mininterval = 2.0 for batch in pbar: # send to cuda device batch = move_to_device(batch, self.device) self.model.predict(batch) self.callback_handler.fire_event(Events.TEST_END) def set_model_state(self, model_state_dict): self.model.load_state_dict(model_state_dict) def get_model_state(self): return self.model.state_dict() def set_trainer_state(self, trainer_state_dict): self.epochs_trained = trainer_state_dict["epochs_trained"] self.global_step_count = trainer_state_dict["global_step_count"] self.local_step_count = trainer_state_dict["local_step_count"] # Resume the training state if self.config.training.resume.resume: # Scheduler States if self.config.training.resume.resume_scheduler: for idx, scheduler in enumerate(self.schedulers): try: scheduler.load_state_dict(trainer_state_dict["schedulers_state_dict"][idx]) except: if self.rank == 0: logger.warning(f"Cannot Load Scheduler {idx}'s State!") if self.config.training.resume.resume_optimizer: for idx, optimizer in enumerate(self.optimizers): try: optimizer.load_state_dict(trainer_state_dict["optimizers_state_dict"][idx]) except: if self.rank == 0: logger.warning(f"Cannot Load Optimizer {idx}'s State!") # save amp states if self.fp16: self.loss_scaler.load_state_dict(trainer_state_dict["amp_state_dict"]) # Random States if self.config.training.resume.resume_rng_state: torch.set_rng_state(trainer_state_dict["cpu_rng_state"]) trainer_state_dict["cuda_rng_state"] = trainer_state_dict["cuda_rng_state"][:torch.cuda.device_count()] torch.cuda.set_rng_state_all(trainer_state_dict["cuda_rng_state"]) # All Callbacks for callback in self.callback_handler.callbacks: try: callback.load_state_dict(trainer_state_dict[str(type(callback))]) except: 
logger.error(f"{type(callback)} seems not to exist in the checkpoint state!") def get_trainer_state(self): trainer_state_dict = { "epochs_trained": self.epochs_trained, "global_step_count": self.global_step_count, "local_step_count": self.local_step_count, "optimizers_state_dict": [optimizer.state_dict() for optimizer in self.optimizers], "schedulers_state_dict": [scheduler.state_dict() for scheduler in self.schedulers], "cpu_rng_state": torch.get_rng_state(), "cuda_rng_state": torch.cuda.get_rng_state_all(), } # save amp states if self.fp16: trainer_state_dict["amp_state_dict"] = self.loss_scaler.state_dict() # All Callbacks for callback in self.callback_handler.callbacks: trainer_state_dict[str(type(callback))] = callback.state_dict() return trainer_state_dict def add_callback(self, callback: Callback): self.callback_handler.add_callback(callback) # def get_lr(optimizer): # for param_group in optimizer.param_groups: # return param_group['lr'] # def get_log_variable(x): # if isinstance(x, torch.Tensor): # x = x.detach() # return x.item() # else: # raise NotImplementedError
class BaseTrainer: def __init__(self, dist, rank, config, resume, only_validation, model, loss_function, optimizer): self.color_tool = colorful self.color_tool.use_style("solarized") model = DistributedDataParallel(model.to(rank), device_ids=[rank]) self.model = model self.optimizer = optimizer self.loss_function = loss_function # DistributedDataParallel (DDP) self.rank = rank self.dist = dist # Automatic mixed precision (AMP) self.use_amp = config["meta"]["use_amp"] self.scaler = GradScaler(enabled=self.use_amp) # Acoustics self.acoustic_config = config["acoustics"] # Supported STFT n_fft = self.acoustic_config["n_fft"] hop_length = self.acoustic_config["hop_length"] win_length = self.acoustic_config["win_length"] self.torch_stft = partial(stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length) self.torch_istft = partial(istft, n_fft=n_fft, hop_length=hop_length, win_length=win_length) self.librosa_stft = partial(librosa.stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length) self.librosa_istft = partial(librosa.istft, hop_length=hop_length, win_length=win_length) # Trainer.train in the config self.train_config = config["trainer"]["train"] self.epochs = self.train_config["epochs"] self.save_checkpoint_interval = self.train_config[ "save_checkpoint_interval"] self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"] assert self.save_checkpoint_interval >= 1, "Check the 'save_checkpoint_interval' parameter in the config. It should be large than one." # Trainer.validation in the config self.validation_config = config["trainer"]["validation"] self.validation_interval = self.validation_config[ "validation_interval"] self.save_max_metric_score = self.validation_config[ "save_max_metric_score"] assert self.validation_interval >= 1, "Check the 'validation_interval' parameter in the config. It should be large than one." # Trainer.visualization in the config self.visualization_config = config["trainer"]["visualization"] # In the 'train.py' file, if the 'resume' item is 'True', we will update the following args: self.start_epoch = 1 self.best_score = -np.inf if self.save_max_metric_score else np.inf self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute( ) / config["meta"]["experiment_name"] self.checkpoints_dir = self.save_dir / "checkpoints" self.logs_dir = self.save_dir / "logs" if resume: self._resume_checkpoint() # Debug validation, which skips training self.only_validation = only_validation if config["meta"]["preloaded_model_path"]: self._preload_model(Path(config["preloaded_model_path"])) if self.rank == 0: prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume) self.writer = SummaryWriter(self.logs_dir.as_posix(), max_queue=5, flush_secs=30) self.writer.add_text( tag="Configuration", text_string=f"<pre> \n{toml.dumps(config)} \n</pre>", global_step=1) print(self.color_tool.cyan("The configurations are as follows: ")) print(self.color_tool.cyan("=" * 40)) print(self.color_tool.cyan(toml.dumps(config)[:-1])) # except "\n" print(self.color_tool.cyan("=" * 40)) with open( (self.save_dir / f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(), "w") as handle: toml.dump(config, handle) self._print_networks([self.model]) def _preload_model(self, model_path): """ Preload model parameters (in "*.tar" format) at the start of experiment. Args: model_path (Path): The file path of the *.tar file """ model_path = model_path.expanduser().absolute() assert model_path.exists( ), f"The file {model_path.as_posix()} is not exist. 
please check path." model_checkpoint = torch.load(model_path.as_posix(), map_location="cpu") self.model.load_state_dict(model_checkpoint["model"], strict=False) self.model.to(self.rank) if self.rank == 0: print( f"Model preloaded successfully from {model_path.as_posix()}.") def _resume_checkpoint(self): """ Resume the experiment from the latest checkpoint. """ latest_model_path = self.checkpoints_dir.expanduser().absolute( ) / "latest_model.tar" assert latest_model_path.exists( ), f"{latest_model_path} does not exist, can not load latest checkpoint." # Make sure all processes (GPUs) do not start loading before the saving is finished. # see https://stackoverflow.com/questions/59760328/how-does-torch-distributed-barrier-work self.dist.barrier() # Load it on the CPU and later use .to(device) on the model # Maybe slightly slow than use map_location="cuda:<...>" # https://stackoverflow.com/questions/61642619/pytorch-distributed-data-parallel-confusion checkpoint = torch.load(latest_model_path.as_posix(), map_location="cpu") self.start_epoch = checkpoint["epoch"] + 1 self.best_score = checkpoint["best_score"] self.optimizer.load_state_dict(checkpoint["optimizer"]) self.scaler.load_state_dict(checkpoint["scaler"]) if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): self.model.module.load_state_dict(checkpoint["model"]) else: self.model.load_state_dict(checkpoint["model"]) # self.model.to(self.rank) if self.rank == 0: print( f"Model checkpoint loaded. Training will begin at {self.start_epoch} epoch." ) def _save_checkpoint(self, epoch, is_best_epoch=False): """ Save checkpoint to "<save_dir>/<config name>/checkpoints" directory, which consists of: - epoch - best metric score in historical epochs - optimizer parameters - model parameters Args: is_best_epoch (bool): In the current epoch, if the model get a best metric score (is_best_epoch=True), the checkpoint of model will be saved as "<save_dir>/checkpoints/best_model.tar". """ print(f"\t Saving {epoch} epoch model checkpoint...") state_dict = { "epoch": epoch, "best_score": self.best_score, "optimizer": self.optimizer.state_dict(), "scaler": self.scaler.state_dict() } if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): state_dict["model"] = self.model.module.state_dict() else: state_dict["model"] = self.model.state_dict() # Saved in "latest_model.tar" # Contains all checkpoint information, including the optimizer parameters, the model parameters, etc. # New checkpoint will overwrite the older one. torch.save(state_dict, (self.checkpoints_dir / "latest_model.tar").as_posix()) # "model_{epoch_number}.pth" # Contains only model. torch.save(state_dict["model"], (self.checkpoints_dir / f"model_{str(epoch).zfill(4)}.pth").as_posix()) # If the model get a best metric score (means "is_best_epoch=True") in the current epoch, # the model checkpoint will be saved as "best_model.tar" # The newer best-scored checkpoint will overwrite the older one. 
if is_best_epoch: print( self.color_tool.red( f"\t Found a best score in the {epoch} epoch, saving...")) torch.save(state_dict, (self.checkpoints_dir / "best_model.tar").as_posix()) def _is_best_epoch(self, score, save_max_metric_score=True): """ Check if the current model got the best metric score """ if save_max_metric_score and score >= self.best_score: self.best_score = score return True elif not save_max_metric_score and score <= self.best_score: self.best_score = score return True else: return False @staticmethod def _print_networks(models: list): print( f"This project contains {len(models)} models, the number of the parameters is: " ) params_of_all_networks = 0 for idx, model in enumerate(models, start=1): params_of_network = 0 for param in model.parameters(): params_of_network += param.numel() print(f"\tNetwork {idx}: {params_of_network / 1e6} million.") params_of_all_networks += params_of_network print( f"The amount of parameters in the project is {params_of_all_networks / 1e6} million." ) def _set_models_to_train_mode(self): self.model.train() def _set_models_to_eval_mode(self): self.model.eval() def spec_audio_visualization(self, noisy, enhanced, clean, name, epoch, mark=""): self.writer.add_audio(f"{mark}_Speech/{name}_Noisy", noisy, epoch, sample_rate=16000) self.writer.add_audio(f"{mark}_Speech/{name}_Enhanced", enhanced, epoch, sample_rate=16000) self.writer.add_audio(f"{mark}_Speech/{name}_Clean", clean, epoch, sample_rate=16000) # Visualize the spectrogram of noisy speech, clean speech, and enhanced speech noisy_mag, _ = librosa.magphase( self.librosa_stft(noisy, n_fft=320, hop_length=160, win_length=320)) enhanced_mag, _ = librosa.magphase( self.librosa_stft(enhanced, n_fft=320, hop_length=160, win_length=320)) clean_mag, _ = librosa.magphase( self.librosa_stft(clean, n_fft=320, hop_length=160, win_length=320)) fig, axes = plt.subplots(3, 1, figsize=(6, 6)) for k, mag in enumerate([noisy_mag, enhanced_mag, clean_mag]): axes[k].set_title(f"mean: {np.mean(mag):.3f}, " f"std: {np.std(mag):.3f}, " f"max: {np.max(mag):.3f}, " f"min: {np.min(mag):.3f}") librosa.display.specshow(librosa.amplitude_to_db(mag), cmap="magma", y_axis="linear", ax=axes[k], sr=16000) plt.tight_layout() self.writer.add_figure(f"{mark}_Spectrogram/{name}", fig, epoch) def metrics_visualization(self, noisy_list, clean_list, enhanced_list, metrics_list, epoch, num_workers=10, mark=""): """ Get metrics on validation dataset by paralleling. Notes: 1. You can register other metrics, but STOI and WB_PESQ metrics must be existence. These two metrics are used for checking if the current epoch is a "best epoch." 2. If you want to use a new metric, you must register it in "util.metrics" file. """ assert "STOI" in metrics_list and "WB_PESQ" in metrics_list, "'STOI' and 'WB_PESQ' must be existence." # Check if the metric is registered in "util.metrics" file. for i in metrics_list: assert i in metrics.REGISTERED_METRICS.keys( ), f"{i} is not registered, please check 'util.metrics' file." 
stoi_mean = 0.0 wb_pesq_mean = 0.0 for metric_name in metrics_list: score_on_noisy = Parallel(n_jobs=num_workers)( delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est) for ref, est in zip(clean_list, noisy_list)) score_on_enhanced = Parallel(n_jobs=num_workers)( delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est) for ref, est in zip(clean_list, enhanced_list)) # Add the mean value of the metric to tensorboard mean_score_on_noisy = np.mean(score_on_noisy) mean_score_on_enhanced = np.mean(score_on_enhanced) self.writer.add_scalars(f"{mark}_Validation/{metric_name}", { "Noisy": mean_score_on_noisy, "Enhanced": mean_score_on_enhanced }, epoch) if metric_name == "STOI": stoi_mean = mean_score_on_enhanced if metric_name == "WB_PESQ": wb_pesq_mean = transform_pesq_range(mean_score_on_enhanced) return (stoi_mean + wb_pesq_mean) / 2 def train(self): for epoch in range(self.start_epoch, self.epochs + 1): if self.rank == 0: print( self.color_tool.yellow( f"{'=' * 15} {epoch} epoch {'=' * 15}")) print("[0 seconds] Begin training...") # [debug validation] Only run validation (only use the first GPU (process)) # inference + calculating metrics + saving checkpoints if self.only_validation and self.rank == 0: self._set_models_to_eval_mode() metric_score = self._validation_epoch(epoch) if self._is_best_epoch( metric_score, save_max_metric_score=self.save_max_metric_score): self._save_checkpoint(epoch, is_best_epoch=True) # Skip the following regular training, saving checkpoints, and validation continue # Regular training timer = ExecutionTime() self._set_models_to_train_mode() self._train_epoch(epoch) # Regular save checkpoints if self.rank == 0 and self.save_checkpoint_interval != 0 and ( epoch % self.save_checkpoint_interval == 0): self._save_checkpoint(epoch) # Regular validation if self.rank == 0 and (epoch % self.validation_interval == 0): print( f"[{timer.duration()} seconds] Training has finished, validation is in progress..." ) self._set_models_to_eval_mode() metric_score = self._validation_epoch(epoch) if self._is_best_epoch( metric_score, save_max_metric_score=self.save_max_metric_score): self._save_checkpoint(epoch, is_best_epoch=True) print(f"[{timer.duration()} seconds] This epoch is finished.") def _train_epoch(self, epoch): raise NotImplementedError def _validation_epoch(self, epoch): raise NotImplementedError
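`_train_epoch` and `_validation_epoch` are left abstract above. A hypothetical subclass might fill in `_train_epoch` roughly as follows; the `self.train_dataloader` attribute and the `(noisy, clean)` batch format are assumptions, while the AMP and clipping pattern simply reuses what `BaseTrainer` already sets up (`self.scaler`, `self.use_amp`, `self.clip_grad_norm_value`).

import torch
from torch.cuda.amp import autocast

class EnhancementTrainer(BaseTrainer):
    def _train_epoch(self, epoch):
        for noisy, clean in self.train_dataloader:        # assumed attribute / batch format
            noisy, clean = noisy.to(self.rank), clean.to(self.rank)
            self.optimizer.zero_grad()
            with autocast(enabled=self.use_amp):
                enhanced = self.model(noisy)
                loss = self.loss_function(enhanced, clean)
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)          # unscale before gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.clip_grad_norm_value)
            self.scaler.step(self.optimizer)
            self.scaler.update()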
class Fp16OptimizerHook(OptimizerHook): """FP16 optimizer hook (using PyTorch's implementation). If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, to take care of the optimization procedure. Args: loss_scale (float | str | dict): Scale factor configuration. If loss_scale is a float, static loss scaling will be used with the specified scale. If loss_scale is a string, it must be 'dynamic', then dynamic loss scaling will be used. It can also be a dict containing arguments of GradScalar. Defaults to 512. For Pytorch >= 1.6, mmcv uses official implementation of GradScaler. If you use a dict version of loss_scale to create GradScaler, please refer to: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler for the parameters. Examples: >>> loss_scale = dict( ... init_scale=65536.0, ... growth_factor=2.0, ... backoff_factor=0.5, ... growth_interval=2000 ... ) >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale) """ def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1, loss_scale=512., distributed=True): self.grad_clip = grad_clip self.coalesce = coalesce self.bucket_size_mb = bucket_size_mb self.distributed = distributed self._scale_update_param = None if loss_scale == 'dynamic': self.loss_scaler = GradScaler() elif isinstance(loss_scale, float): self._scale_update_param = loss_scale self.loss_scaler = GradScaler(init_scale=loss_scale) elif isinstance(loss_scale, dict): self.loss_scaler = GradScaler(**loss_scale) else: raise ValueError('loss_scale must be of type float, dict, or ' f'"dynamic", got {loss_scale}') def before_run(self, runner): """Preparing steps before Mixed Precision Training.""" # wrap model mode to fp16 wrap_fp16_model(runner.model) # resume from state dict if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']: scaler_state_dict = runner.meta['fp16']['loss_scaler'] self.loss_scaler.load_state_dict(scaler_state_dict) def copy_grads_to_fp32(self, fp16_net, fp32_weights): """Copy gradients from fp16 model to fp32 weight copy.""" for fp32_param, fp16_param in zip(fp32_weights, fp16_net.parameters()): if fp16_param.grad is not None: if fp32_param.grad is None: fp32_param.grad = fp32_param.data.new( fp32_param.size()) fp32_param.grad.copy_(fp16_param.grad) def copy_params_to_fp16(self, fp16_net, fp32_weights): """Copy updated params from fp32 weight copy to fp16 model.""" for fp16_param, fp32_param in zip(fp16_net.parameters(), fp32_weights): fp16_param.data.copy_(fp32_param.data) def after_train_iter(self, runner): """Backward optimization steps for Mixed Precision Training. For dynamic loss scaling, please refer to https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler. 1. Scale the loss by a scale factor. 2. Backward the loss to obtain the gradients. 3. Unscale the optimizer’s gradient tensors. 4. Call optimizer.step() and update scale factor. 5. Save loss_scaler state_dict for resume purpose. 
""" # clear grads of last iteration runner.model.zero_grad() runner.optimizer.zero_grad() self.loss_scaler.scale(runner.outputs['loss']).backward() self.loss_scaler.unscale_(runner.optimizer) # grad clip if self.grad_clip is not None: grad_norm = self.clip_grads(runner.model.parameters()) if grad_norm is not None: # Add grad norm to the logger runner.log_buffer.update({'grad_norm': float(grad_norm)}, runner.outputs['num_samples']) # backward and update scaler self.loss_scaler.step(runner.optimizer) self.loss_scaler.update(self._scale_update_param) # save state_dict of loss_scaler runner.meta.setdefault( 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
def main(args): # ensures that weight initializations are all the same torch.manual_seed(args.seed) np.random.seed(args.seed) torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) logging = utils.Logger(args.global_rank, args.save) writer = utils.Writer(args.global_rank, args.save) # Get data loaders. train_queue, valid_queue, num_classes = datasets.get_loaders(args) args.num_total_iter = len(train_queue) * args.epochs warmup_iters = len(train_queue) * args.warmup_epochs swa_start = len(train_queue) * (args.epochs - 1) arch_instance = utils.get_arch_cells(args.arch_instance) model = AutoEncoder(args, writer, arch_instance) model = model.cuda() logging.info('args = %s', args) logging.info('param size = %fM ', utils.count_parameters_in_M(model)) logging.info('groups per scale: %s, total_groups: %d', model.groups_per_scale, sum(model.groups_per_scale)) if args.fast_adamax: # Fast adamax has the same functionality as torch.optim.Adamax, except it is faster. cnn_optimizer = Adamax(model.parameters(), args.learning_rate, weight_decay=args.weight_decay, eps=1e-3) else: cnn_optimizer = torch.optim.Adamax(model.parameters(), args.learning_rate, weight_decay=args.weight_decay, eps=1e-3) cnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( cnn_optimizer, float(args.epochs - args.warmup_epochs - 1), eta_min=args.learning_rate_min) grad_scalar = GradScaler(2**10) num_output = utils.num_output(args.dataset) bpd_coeff = 1. / np.log(2.) / num_output # if load checkpoint_file = os.path.join(args.save, 'checkpoint.pt') if args.cont_training: logging.info('loading the model.') checkpoint = torch.load(checkpoint_file, map_location='cpu') init_epoch = checkpoint['epoch'] model.load_state_dict(checkpoint['state_dict']) model = model.cuda() cnn_optimizer.load_state_dict(checkpoint['optimizer']) grad_scalar.load_state_dict(checkpoint['grad_scalar']) cnn_scheduler.load_state_dict(checkpoint['scheduler']) global_step = checkpoint['global_step'] else: global_step, init_epoch = 0, 0 for epoch in range(init_epoch, args.epochs): # epochs cycle # update lrs. if args.distributed: train_queue.sampler.set_epoch(global_step + args.seed) valid_queue.sampler.set_epoch(0) if epoch > args.warmup_epochs: cnn_scheduler.step() # Logging. logging.info('epoch %d', epoch) # Training. 
train_nelbo, global_step = train(train_queue, model, cnn_optimizer, grad_scalar, global_step, warmup_iters, writer, logging) logging.info('train_nelbo %f', train_nelbo) writer.add_scalar('train/nelbo', train_nelbo, global_step) model.eval() # generate samples less frequently eval_freq = 1 if args.epochs <= 50 else 20 if epoch % eval_freq == 0 or epoch == (args.epochs - 1): with torch.no_grad(): num_samples = 16 n = int(np.floor(np.sqrt(num_samples))) for t in [0.7, 0.8, 0.9, 1.0]: logits = model.sample(num_samples, t) output = model.decoder_output(logits) output_img = output.mean if isinstance( output, torch.distributions.bernoulli.Bernoulli ) else output.sample(t) output_tiled = utils.tile_image(output_img, n) writer.add_image('generated_%0.1f' % t, output_tiled, global_step) valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=10, args=args, logging=logging) logging.info('valid_nelbo %f', valid_nelbo) logging.info('valid neg log p %f', valid_neg_log_p) logging.info('valid bpd elbo %f', valid_nelbo * bpd_coeff) logging.info('valid bpd log p %f', valid_neg_log_p * bpd_coeff) writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch) writer.add_scalar('val/nelbo', valid_nelbo, epoch) writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch) writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch) save_freq = int(np.ceil(args.epochs / 100)) if epoch % save_freq == 0 or epoch == (args.epochs - 1): if args.global_rank == 0: logging.info('saving the model.') torch.save( { 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': cnn_optimizer.state_dict(), 'global_step': global_step, 'args': args, 'arch_instance': arch_instance, 'scheduler': cnn_scheduler.state_dict(), 'grad_scalar': grad_scalar.state_dict() }, checkpoint_file) # Final validation valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=1000, args=args, logging=logging) logging.info('final valid nelbo %f', valid_nelbo) logging.info('final valid neg log p %f', valid_neg_log_p) writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch + 1) writer.add_scalar('val/nelbo', valid_nelbo, epoch + 1) writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch + 1) writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch + 1) writer.close()
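The `train` and `test` functions called here are defined elsewhere and not shown in this excerpt. Below is a deliberately simplified, hypothetical sketch of an inner loop that is compatible with the `train(...)` call signature and with the `grad_scalar` GradScaler; the real implementation additionally handles learning-rate warm-up over `warmup_iters`, KL balancing, regularization terms, and TensorBoard logging via `writer`, and the assumed `model(x)` return value (a per-batch negative ELBO) is an assumption rather than the actual AutoEncoder API.

from torch.cuda.amp import autocast

def train(train_queue, model, cnn_optimizer, grad_scalar, global_step,
          warmup_iters, writer, logging):
    model.train()
    total_nelbo, num_batches = 0.0, 0
    for x in train_queue:
        x = x[0] if isinstance(x, (list, tuple)) else x   # drop labels if present
        x = x.cuda(non_blocking=True)
        cnn_optimizer.zero_grad()
        with autocast():
            nelbo = model(x)              # assumed: returns the negative ELBO for the batch
            loss = nelbo.mean()
        grad_scalar.scale(loss).backward()
        grad_scalar.step(cnn_optimizer)
        grad_scalar.update()
        total_nelbo += loss.item()
        num_batches += 1
        global_step += 1
    return total_nelbo / max(num_batches, 1), global_step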
class Trainer: def __init__( self, config: DictConfig, model: FlyModel, name: str = "Trainer1", *args, **kwargs, ): """ One trainer has one model Args: config: FlyConfig dictionary model: must be FlyModel dataloader_fn: a Callable function which returns dataloaders """ logger.info("TrainerLoop is initializing!") if not isinstance(model, FlyModel): logger.warn("model is not defined as FlyModel") self.config = config self.model = model self.trainer_name = name # class properties self.rank = None self.local_rank = None self.node_rank = None self.world_size = None self.distributed_training = None self.device = None self.gradient_accumulation_batches = None self.callback_handler = None self.optimizers = [] self.schedulers = [] self.init_distributed_environment() # make sure the model has access to trainer info self.model.set_trainer(self) self.callback_handler = CallbackHandler( config, trainer=self, callbacks=[], verbose=config.logging.level == "DEBUG") # Configure all callbacks self.configure_callbacks(config) self.callback_handler.fire_event(Events.INITIALIZE) def init_distributed_environment(self): # For distributed self.rank = int(os.environ.get("RANK", 0)) self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) self.world_size = int(os.environ.get("WORLD_SIZE", 1)) self.distributed_training = self.world_size > 1 # TODO: add error message when num_gpus is set, but distributed training is False here if self.distributed_training and not torch.distributed.is_initialized( ): torch.distributed.init_process_group(backend="nccl", init_method="env://") assert torch.distributed.is_initialized() if self.distributed_training and not torch.distributed.is_initialized( ): self.node_rank = os.environ.get("NODE_RANK", "N/A") logger.info( f"Initialized Rank:{torch.distributed.get_rank()} Locak-rank: {self.local_rank} on Node:{self.node_rank} Node-name:{socket.gethostname()}" ) torch.cuda.set_device(distributed.get_rank()) def init_device(self, config): # set cuda device if config.num_gpus_per_node > 0: torch.cuda.set_device(self.local_rank) self.device = torch.device("cuda", self.local_rank) else: self.device = torch.device("cpu") def init_fp16(self, config): if config.num_gpus_per_node == 0: raise NotImplementedError( "For mixed precision training, you need to use GPU!") self.loss_scaler = GradScaler() def init_training_constants(self, config): self.total_num_update_steps = int(config.total_num.update_steps) self.total_num_batches = self.total_num_update_steps * int( self.gradient_accumulation_batches) self.total_num_epochs = int(config.total_num.epochs) # check if training in epoch or update_steps if self.total_num_update_steps < 0 and self.total_num_epochs < 0: raise NotImplementedError( "config.total_num.updated_steps must be larger than 0") elif self.total_num_update_steps > 0 and self.total_num_epochs > 0: raise NotImplementedError( "Please only set either config.total_num.updated_steps or config.total_num.epochs greater than 0" ) elif self.total_num_update_steps > 0 and self.total_num_epochs < 0: self.training_in_epoch = False elif self.total_num_update_steps < 0 and self.total_num_epochs > 0: self.training_in_epoch = True # get the number of batches in the dataloader for one epoch try: self.epoch_num_batches = len(self.train_dataloader) except TypeError: logger.warning("Cannot determine the length of train_dataloader!") self.epoch_num_batches = None if self.training_in_epoch: if self.epoch_num_batches is not None: self.total_num_batches = self.epoch_num_batches * self.total_num_epochs 
self.total_num_update_steps = ( self.total_num_batches // self.gradient_accumulation_batches) self.epoch_num_update_steps = ( self.epoch_num_batches // self.gradient_accumulation_batches) else: # this is set to wait until the epoch finishes first self.total_num_update_steps = sys.maxsize def configure_optimizers(self, config, total_num_update_steps=None, optimizers=None, schedulers=None): if optimizers is not None and schedulers is not None: self.optimizers, self.schedulers = optimizers, schedulers elif total_num_update_steps is not None: self.optimizers, self.schedulers = self.model.configure_optimizers( config, total_num_update_steps) else: raise ValueError("Please provide the correct argument!") return self.optimizers, self.schedulers def configure_callbacks(self, config): # Resume callback runs for all ranks if config.resume.enabled: self.resume_callback = Resume(config) self.add_callback(self.resume_callback) self.log_callback = TrainLogger(config) self.add_callback(self.log_callback) self.eval_callback = Evaluation(config) self.add_callback(self.eval_callback) # For logging and inference, use rank 0 by default if self.rank == 0: if config.console: self.console_callback = Console(config) self.add_callback(self.console_callback) if config.checkpointing.enabled: self.checkpoint_callback = Checkpoint(config) self.add_callback(self.checkpoint_callback) def init_distributed_model(self, model): """ Default distributed training uses reducer for simplicity. """ logger.info("Reducer is intialized!") # Distributed training (should be after apex fp16 initialization) self.reducer = Reducer(model) # for param in self.model.parameters(): # dist.broadcast(param.data, 0) def train( self, config, train_dataloader, validation_dataloader=None, test_dataloader=None, configure_optimizers=True, stage_name: str = "Stage1", *args, **kwargs, ): self.config = config self.stage_name = stage_name # Model is sent to GPU or CPU self.init_device(config) # self.optimizers, self.schedulers = self.configure_optimizers() self.gradient_accumulation_batches = config.gradient_accumulation_batches self.max_gradient_norm = config.optimization.max_gradient_norm self.fp16 = config.fp16 self.model = move_to_device(self.model, self.device) self.model.device = self.device self.init_fp16(config) if self.distributed_training: self.init_distributed_model(self.model) self.total_num_update_steps = 0 self.total_num_batches = 0 self.total_num_epochs = 0 self.epoch_num_batches = 0 self.global_batch_count = 0 self.global_step_count = 0 self.epochs_trained = 0 self.local_step_count = 0 self.train_dataloader = train_dataloader self.validation_dataloader = validation_dataloader self.test_dataloader = test_dataloader self.init_training_constants(config) if configure_optimizers or len(self.optimizers) == 0: self.configure_optimizers(config, self.total_num_update_steps) # Training begins self.callback_handler.fire_event(Events.TRAIN_BEGIN) while True: self.callback_handler.fire_event(Events.EPOCH_BEGIN) self.train_epoch() self.callback_handler.fire_event(Events.EPOCH_END) self.epochs_trained += 1 if self.training_in_epoch: if self.epochs_trained >= self.total_num_epochs: break else: if self.global_step_count < self.total_num_update_steps: continue else: break # Training ends self.callback_handler.fire_event(Events.TRAIN_END) def train_epoch(self): self.optimizer = self.optimizers[0] self.scheduler = self.schedulers[0] self.local_step_count = 0 if self.train_dataloader is None: return for batch in self.train_dataloader: 
self.callback_handler.fire_event(Events.BATCH_BEGIN) batch = move_to_device(batch, self.device) output = self.backward_batch(batch) # Update the model if (self.global_batch_count + 1) % self.gradient_accumulation_batches == 0: # Update the model with optimizer self.step_update(self.model, self.optimizer, self.scheduler) self.global_step_count += 1 self.local_step_count += 1 self.callback_handler.fire_event(Events.BATCH_END) if self.global_step_count >= self.total_num_update_steps: break self.global_batch_count += 1 def backward_batch(self, batch): self.model.train() with torch.cuda.amp.autocast(self.fp16): output = self.model(batch) # get the loss from output if hasattr(output, "loss"): loss = output.loss elif isinstance(output, dict): loss = output["loss"] if self.gradient_accumulation_batches > 1: loss = loss / self.gradient_accumulation_batches self.loss_backward(loss) return output def step_update(self, model, optimizer, scheduler=None): """ self.loss_scaler is defined in `configure_fp16` """ self.callback_handler.fire_event(Events.STEP_BEGIN) # collect gradient if self.distributed_training: self.reducer.reduce() gradient_clip = self.max_gradient_norm # Gradient Clipping if gradient_clip > 0: if self.fp16: self.loss_scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) else: torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) # Update the model if self.fp16: self.loss_scaler.step(optimizer) self.loss_scaler.update() else: optimizer.step() # Step learning rate if scheduler: scheduler.step() # Gradient to zero optimizer.zero_grad() self.callback_handler.fire_event(Events.STEP_END) def loss_backward(self, loss): self.callback_handler.fire_event(Events.BACKWARD_BEGIN) # Loss backward if self.fp16: self.loss_scaler.scale(loss).backward() else: loss.backward() self.callback_handler.fire_event(Events.BACKWARD_END) def validate(self, dataloader): # Start Validation self.model.reset_evaluation_metrics() self.callback_handler.fire_event(Events.VALIDATE_BEGIN) self.model.validation_loop(dataloader) self.callback_handler.fire_event(Events.VALIDATE_END) def test(self, dataloader): # Start Testing self.model.reset_evaluation_metrics() self.callback_handler.fire_event(Events.TEST_BEGIN) self.model.test_loop(dataloader) self.callback_handler.fire_event(Events.TEST_END) def set_model_state(self, model_state_dict): self.model.load_state_dict(model_state_dict) def get_model_state(self): return self.model.state_dict() def set_trainer_state(self, trainer_state_dict): self.trainer_name = trainer_state_dict["trainer_name"] self.trainer_stage = trainer_state_dict["trainer_stage"] self.epochs_trained = trainer_state_dict["epochs_trained"] self.global_step_count = trainer_state_dict["global_step_count"] self.local_step_count = trainer_state_dict["local_step_count"] # All Callbacks for callback in self.callback_handler.callbacks: try: callback.load_state_dict(trainer_state_dict[str( type(callback))]) except: logger.error( f"{type(callback)} seems not to exist in the checkpoint state!" 
) # Resume the training state # Scheduler States for idx, scheduler in enumerate(self.schedulers): try: scheduler.load_state_dict( trainer_state_dict["schedulers_state_dict"][idx]) except: if self.rank == 0: logger.warning(f"Cannot Load Scheduler {idx}'s State!") for idx, optimizer in enumerate(self.optimizers): try: optimizer.load_state_dict( trainer_state_dict["optimizers_state_dict"][idx]) except: if self.rank == 0: logger.warning(f"Cannot Load Optimizer {idx}'s State!") # save amp states try: if self.fp16: self.loss_scaler.load_state_dict( trainer_state_dict["amp_state_dict"]) except: logger.warning(f"Cannot Load Loss Scaler State!") # Random States torch.set_rng_state(trainer_state_dict["cpu_rng_state"]) torch.cuda.set_rng_state_all( trainer_state_dict["cuda_rng_state"][:torch.cuda.device_count()]) def get_trainer_state(self): trainer_state_dict = { "trainer_name": self.trainer_name, "stage_name": self.stage_name, "epochs_trained": self.epochs_trained, "global_step_count": self.global_step_count, "local_step_count": self.local_step_count, "optimizers_state_dict": [optimizer.state_dict() for optimizer in self.optimizers], "schedulers_state_dict": [scheduler.state_dict() for scheduler in self.schedulers], "cpu_rng_state": torch.get_rng_state(), "cuda_rng_state": torch.cuda.get_rng_state_all(), } # save amp states if self.fp16: trainer_state_dict["amp_state_dict"] = self.loss_scaler.state_dict( ) # All Callbacks for callback in self.callback_handler.callbacks: trainer_state_dict[str(type(callback))] = callback.state_dict() return trainer_state_dict def add_callback(self, callback: Callback): self.callback_handler.add_callback(callback)
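To make the intended call pattern of this `Trainer` concrete, here is an illustrative usage sketch. None of the helpers (`load_config`, `MyFlyModel`, `build_*_dataloader`) exist in this excerpt and the config fields are assumptions; the point is that one trainer wraps one `FlyModel`, `train` receives the per-stage config plus dataloaders, and `get_model_state` / `get_trainer_state` expose everything a checkpoint callback needs to persist.

config = load_config("config.yaml")          # hypothetical: returns a DictConfig
model = MyFlyModel(config)                   # hypothetical FlyModel subclass
trainer = Trainer(config, model, name="Trainer1")

trainer.train(
    config.training,                         # assumed per-stage training sub-config
    train_dataloader=build_train_dataloader(config),      # hypothetical factory
    validation_dataloader=build_valid_dataloader(config), # hypothetical factory
    stage_name="Stage1",
)

checkpoint = {
    "model": trainer.get_model_state(),      # model parameters
    "trainer": trainer.get_trainer_state(),  # optimizers, schedulers, RNG and callback state
}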
class ClassificationTask(ClassyTask): """Basic classification training task. This task encapsulates all of the components and steps needed to train a classifier using a :class:`classy_vision.trainer.ClassyTrainer`. Assumes a train / test phase per epoch and that the datasets have the same API as the map-style Dataset class in `torch.utils.data.dataset <https://pytorch.org/docs/stable/data.html #torch.utils.data.Dataset>`_ (in particular, this task makes use of the dataset length). If you are using an `IterableDataset <https://pytorch.org/docs/ stable/data.html#torch.utils.data.IterableDataset>`_ then a custom task may be appropriate. :var loss: Loss (see :class:`classy_vision.losses.ClassyLoss`) function used for computing the loss in each forward pass :var datasets: Mapping from a ``phase_type`` in ["train", "test"] to dataset used for training (or testing) :var meters: List of meters (see :class:`classy_vision.meters.ClassyMeter`) to calculate during training :var num_epochs: Number of epochs (passes over dataset) to train :var test_only: Used to only run the test phase :var base_model: Model to be trained, not wrapped in DDP or DP wrappers :var optimizer: Optimizer used in train step :var optimizer_schedulers: Dictionary. Key is the name of the optimizer option (e.g. lr), value is a ClassyParamScheduler :var checkpoint: Serializable dict which represents state in training :var phases: List of phase specific information, e.g. if phase is train / test. :var hooks: List of hooks to apply during training :var train: Phase type, if true it means we are training, false means testing :var distributed_model: Base model, but wrapped in DDP (DistributedDataParallel) :var phase_idx: Current phase id, first phase is 0, if task has not started training then returns -1 :var train_phase_idx: Only counts train phases :var num_updates: Number of total parameter updates applied to model by the optimizer :var data_iterator: Iterator which can be used to obtain batches :var losses: Loss curve :var perf_log: list of training speed measurements, to be logged :var clip_grad_norm: maximum gradient norm (default None) :var simulated_global_batchsize: batch size simulated via gradient accumulation :var optimizer_period: apply optimizer after this many steps; derived from simulated_global_batchsize, default 1.
""" def __init__(self): """Constructs a ClassificationTask""" super().__init__() self.base_loss = None self.datasets = {} self.meters = [] self.num_epochs = 1 self.test_phase_period = 1 self.train_phases_per_epoch = 0 self.test_only = False self.base_model = None self.optimizer = None self.optimizer_schedulers = {} self.checkpoint_dict = None self.checkpoint_path = None self.phases = [] self.hooks = [] self.train = True self.distributed_model = None self.distributed_loss = None self.phase_idx = -1 self.train_phase_idx = -1 self.num_updates = 0 self.dataloader = None self.data_iterator = None self.losses = [] self.broadcast_buffers_mode: BroadcastBuffersMode = ( BroadcastBuffersMode.BEFORE_EVAL ) self.amp_args = None self.amp_type = None self.amp_grad_scaler = None self.mixup_transform = None self.perf_log = [] self.last_batch = None self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED self.find_unused_parameters = False self.use_gpu = torch.cuda.is_available() self.dataloader_mp_context = "spawn" self.bn_weight_decay = False self._train_only = True self.clip_grad_norm = None self.simulated_global_batchsize = None self.optimizer_period = 1 self.ddp_bucket_cap_mb = 25 self.use_sharded_ddp = False self.fp16_grad_compress = False def set_use_sharded_ddp(self, use_sharded_ddp: bool): self.use_sharded_ddp = use_sharded_ddp if self.use_sharded_ddp: logging.info("Using Sharded DDP") return self def set_use_gpu(self, use_gpu: bool): self.use_gpu = use_gpu assert ( not self.use_gpu or torch.cuda.is_available() ), "CUDA required to train on GPUs" return self def set_clip_grad_norm(self, clip_grad_norm: Optional[float]): """Sets maximum gradient norm. None means gradient clipping is disabled. Defaults to None.""" self.clip_grad_norm = clip_grad_norm if clip_grad_norm is None: logging.info("Disabled gradient norm clipping.") else: logging.info( f"Enabled gradient norm clipping with threshold: {clip_grad_norm}" ) return self def set_simulated_global_batchsize(self, simulated_global_batchsize: Optional[int]): """Sets a simulated batch size by gradient accumulation. Gradient accumulation adds up gradients from multiple minibatches and steps the optimizer every N train_steps, where N is optimizer_period. When enabled, the very last train_steps might end up not updating the model, depending on the number of total steps. None means gradient accumulation is disabled. Defaults to None.""" self.simulated_global_batchsize = simulated_global_batchsize return self def set_checkpoint(self, checkpoint_path: str): """Sets checkpoint on task. Args: checkpoint_path: The path to load the checkpoint from. Can be a file or a directory. See :func:`load_checkpoint` for more information. """ self.checkpoint_path = checkpoint_path return self def _set_checkpoint_dict(self, checkpoint_dict: Dict[str, Any]): """Sets the checkpoint dict in the task. Only used for testing. Args: checkpoint_dict: A serializable dict representing current task state """ self.checkpoint_dict = checkpoint_dict return self def set_num_epochs(self, num_epochs: Union[int, float]): """Set number of epochs to be run. Args: num_epochs: Number of epochs to run task """ self.num_epochs = num_epochs return self def set_test_phase_period(self, test_phase_period: int): """Set the period of test phase. Args: test_phase_period: The period of test phase """ self.test_phase_period = test_phase_period return self def set_dataset(self, dataset: ClassyDataset, phase_type: str): """Set dataset for phase type on task Args: dataset: ClassyDataset for returning samples. 
phase_type: str must be one of "train" or "test" """ assert phase_type in [ "train", "test", ], "phase_type must be in ['train', 'test']" self.datasets[phase_type] = dataset if phase_type == "train": self.train_phases_per_epoch = getattr(dataset, "phases_per_epoch", 1) else: self._train_only = False return self def set_dataloader_mp_context(self, dataloader_mp_context: Optional[str]): """Set the multiprocessing context used by the dataloader. The context can be either 'spawn', 'fork', 'forkserver' or None (uses the default context). See https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_context for more details.""" self.dataloader_mp_context = dataloader_mp_context return self def set_optimizer(self, optimizer: ClassyOptimizer): """Set optimizer for task Args: optimizer: optimizer for task """ self.optimizer = optimizer return self def set_loss(self, loss: ClassyLoss): """Set loss function for task Args: loss: loss for task """ self.base_loss = loss return self def set_meters(self, meters: List["ClassyMeter"]): """Set meters for task Args: meters: list of meters to compute during training """ self.meters = meters return self def set_distributed_options( self, broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.BEFORE_EVAL, batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED, batch_norm_sync_group_size: int = 0, find_unused_parameters: bool = False, bucket_cap_mb: int = 25, fp16_grad_compress: bool = False, ): """Set distributed options. Args: broadcast_buffers_mode: Broadcast buffers mode. See :class:`BroadcastBuffersMode` for options. batch_norm_sync_mode: Batch normalization synchronization mode. See :class:`BatchNormSyncMode` for options. batch_norm_sync_group_size: Group size to use for synchronized batch norm. 0 means that the stats are synchronized across all replicas. For efficient synchronization, set it to the number of GPUs in a node ( usually 8). find_unused_parameters: See :class:`torch.nn.parallel.DistributedDataParallel` for information. bucket_cap_mb: See :class:`torch.nn.parallel.DistributedDataParallel` for information. Raises: RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex is not installed. """ self.broadcast_buffers_mode = broadcast_buffers_mode if batch_norm_sync_group_size > 0: if not batch_norm_sync_mode == BatchNormSyncMode.APEX: # this should ideally work with PyTorch Sync BN as well, but it # fails while initializing DDP for some reason. raise ValueError( "batch_norm_sync_group_size can be > 0 only when " "Apex Synchronized Batch Normalization is being used." 
) self.batch_norm_sync_group_size = batch_norm_sync_group_size if batch_norm_sync_mode == BatchNormSyncMode.DISABLED: logging.info("Synchronized Batch Normalization is disabled") else: if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available: raise RuntimeError("apex is not installed") msg = f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}" if self.batch_norm_sync_group_size > 0: msg += f" and group size {batch_norm_sync_group_size}" logging.info(msg) self.batch_norm_sync_mode = batch_norm_sync_mode if find_unused_parameters: logging.info("Enabling find_unused_parameters in DDP") self.find_unused_parameters = find_unused_parameters self.ddp_bucket_cap_mb = bucket_cap_mb if fp16_grad_compress: if get_torch_version() < [1, 8]: raise RuntimeError( "FP16 grad compression is only supported since PyTorch 1.8" ) logging.info("Enabling FP16 grad compression") self.fp16_grad_compress = fp16_grad_compress return self def set_hooks(self, hooks: List["ClassyHook"]): """Set hooks for task Args: hooks: List of hooks to apply during training """ from classy_vision.hooks import ClassyHook assert isinstance(hooks, list) assert all(isinstance(hook, ClassyHook) for hook in hooks) assert len({hook.name() for hook in hooks}) == len( hooks ), "Cannot have repeated hooks of the same class" # TODO (zyan3): we move checkpoint hook to the end of the list because some hooks # may change the state of the model, and we want to save changed state in the checkpoint. # This is temporary fix. non_checkpoint_hooks = [ hook for hook in hooks if not isinstance(hook, CheckpointHook) ] checkpoint_hooks = [hook for hook in hooks if isinstance(hook, CheckpointHook)] hooks = non_checkpoint_hooks + checkpoint_hooks self.hooks = hooks return self def set_model(self, model: ClassyModel): """Set model for task Args: model: Model to be trained """ self.base_model = model return self def set_test_only(self, test_only: bool): """Set test only flag Args: test_only: If true, only test phases will be run """ self.test_only = test_only return self def set_bn_weight_decay(self, bn_weight_decay: bool): assert type(bn_weight_decay) == bool self.bn_weight_decay = bn_weight_decay return self def set_amp_args(self, amp_args: Optional[Dict[str, Any]]): """Disable / enable apex.amp and set the automatic mixed precision parameters. apex.amp can be utilized for mixed / half precision training. Args: amp_args: Dictionary containing arguments to be passed to amp.initialize. Set to None to disable amp. To enable mixed precision training, pass amp_args={"opt_level": "O1"} here. See https://nvidia.github.io/apex/amp.html for more info. Raises: RuntimeError: If opt_level is not None and apex is not installed. Warning: apex needs to be installed to utilize this feature. 
""" self.amp_args = amp_args if amp_args is None: logging.info("AMP disabled") else: # Check that the requested AMP type is known try: self.amp_type = AmpType[self.amp_args["amp_type"].upper()] except KeyError: logging.info("AMP type not specified, defaulting to Apex") self.amp_type = AmpType.APEX # Check for CUDA availability, required for both Apex and Pytorch AMP if not torch.cuda.is_available(): raise RuntimeError( "AMP is required but CUDA is not supported, cannot enable AMP" ) # Check for Apex availability if self.amp_type == AmpType.APEX and not apex_available: raise RuntimeError( "Apex AMP is required but Apex is not installed, cannot enable AMP" ) if self.use_sharded_ddp: if self.amp_type == AmpType.APEX: raise RuntimeError( "ShardedDDP has been requested, which is incompatible with Apex AMP" ) if not fairscale_available: raise RuntimeError( "ShardedDDP has been requested, but fairscale is not installed in the current environment" ) # Set Torch AMP grad scaler, used to prevent gradient underflow elif self.amp_type == AmpType.PYTORCH: if self.use_sharded_ddp: logging.info("Using ShardedGradScaler to manage Pytorch AMP") self.amp_grad_scaler = ShardedGradScaler() else: self.amp_grad_scaler = TorchGradScaler() logging.info(f"AMP enabled with args {amp_args}") return self def set_mixup_transform(self, mixup_transform: Optional["MixupTransform"]): """Disable / enable mixup transform for data augmentation Args:: mixup_transform: a callable object which performs mixup data augmentation """ self.mixup_transform = mixup_transform if mixup_transform is None: logging.info("mixup disabled") else: logging.info("mixup enabled") return self def set_optimizer_schedulers(self, schedulers): self.optimizer_schedulers = schedulers return self @classmethod def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask": """Instantiates a ClassificationTask from a configuration. Args: config: A configuration for a ClassificationTask. See :func:`__init__` for parameters expected in the config. Returns: A ClassificationTask instance. 
""" test_only = config.get("test_only", False) if not test_only: # TODO Make distinction between epochs and phases in optimizer clear train_phases_per_epoch = config["dataset"]["train"].get( "phases_per_epoch", 1 ) optimizer_config = config["optimizer"] optimizer_config["num_epochs"] = ( config["num_epochs"] * train_phases_per_epoch ) optimizer = build_optimizer(optimizer_config) param_schedulers = build_optimizer_schedulers(optimizer_config) datasets = {} phase_types = ["train", "test"] for phase_type in phase_types: if phase_type in config["dataset"]: datasets[phase_type] = build_dataset(config["dataset"][phase_type]) loss = build_loss(config["loss"]) amp_args = config.get("amp_args") meters = build_meters(config.get("meters", {})) model = build_model(config["model"]) mixup_transform = None if config.get("mixup") is not None: assert "alpha" in config["mixup"], "key alpha is missing in mixup dict" mixup_transform = MixupTransform( config["mixup"]["alpha"], config["mixup"].get("num_classes") ) # hooks config is optional hooks_config = config.get("hooks") hooks = [] if hooks_config is not None: hooks = build_hooks(hooks_config) distributed_config = config.get("distributed", {}) distributed_options = { "broadcast_buffers_mode": BroadcastBuffersMode[ distributed_config.get("broadcast_buffers", "before_eval").upper() ], "batch_norm_sync_mode": BatchNormSyncMode[ distributed_config.get("batch_norm_sync_mode", "disabled").upper() ], "batch_norm_sync_group_size": distributed_config.get( "batch_norm_sync_group_size", 0 ), "find_unused_parameters": distributed_config.get( "find_unused_parameters", False ), "bucket_cap_mb": distributed_config.get("bucket_cap_mb", 25), "fp16_grad_compress": distributed_config.get("fp16_grad_compress", False), } task = ( cls() .set_num_epochs(config["num_epochs"]) .set_test_phase_period(config.get("test_phase_period", 1)) .set_loss(loss) .set_test_only(test_only) .set_model(model) .set_meters(meters) .set_amp_args(amp_args) .set_mixup_transform(mixup_transform) .set_distributed_options(**distributed_options) .set_hooks(hooks) .set_bn_weight_decay(config.get("bn_weight_decay", False)) .set_clip_grad_norm(config.get("clip_grad_norm")) .set_simulated_global_batchsize(config.get("simulated_global_batchsize")) .set_use_sharded_ddp(config.get("use_sharded_ddp", False)) ) if not test_only: task.set_optimizer(optimizer) task.set_optimizer_schedulers(param_schedulers) use_gpu = config.get("use_gpu") if use_gpu is not None: task.set_use_gpu(use_gpu) for phase_type in datasets: task.set_dataset(datasets[phase_type], phase_type) # NOTE: this is a private member and only meant to be used for # logging/debugging purposes. See __repr__ implementation task._config = config return task @property def num_batches_per_phase(self): """Returns number of batches in current phase iterator""" return len(self.data_iterator) @property def model(self): """Returns model used in training (can be wrapped with DDP)""" return ( self.distributed_model if is_distributed_training_run() else self.base_model ) @property def loss(self): """Returns loss used in training (can be wrapped with DDP)""" return self.distributed_loss if self.distributed_loss else self.base_loss @property def phase_type(self): """Returns current phase type. 
String with value "train" or "test" """ return "train" if self.train else "test" @property def eval_phase_idx(self): """Returns current evaluation phase""" return self.phase_idx - self.train_phase_idx - 1 def get_total_training_phases(self): """ Returns the total number of "train" phases in the task """ num_training_phases = 0 for phase in self.phases: if phase["train"] is True: num_training_phases += 1 return num_training_phases def get_total_test_phases(self): """ Returns the total number of "test" phases in the task """ num_test_phases = 0 for phase in self.phases: if phase["train"] is False: num_test_phases += 1 return num_test_phases def _build_phases(self): """Returns list of phases from config. These phases will look like: { train: is this a train or test phase? optimizer: optimizer settings } - If this is a test only run, then only test phases will be generated - If this is a training run with both train and test datasets, then x phases = x train phases + x test phases, interleaved. If test_phase_period > 1, test phases are only added after test_phase_period train phases. The last phase is always a test phase. - If this is a training run with only a train dataset, then x phases = x train phases. """ if not self.test_only: phases = [ {"train": True} for _ in range(math.ceil(self.train_phases_per_epoch * self.num_epochs)) ] if self._train_only: return phases final_phases = [] for i, phase in enumerate(phases): final_phases.append(phase) if (i + 1) % self.test_phase_period == 0: final_phases.append({"train": False}) if final_phases[-1]["train"]: final_phases.append({"train": False}) return final_phases return [{"train": False} for _ in range(self.num_epochs)] def build_dataloader_from_dataset(self, dataset, **kwargs): """Builds a dataloader from the provided dataset Args: dataset: A ClassyDataset kwargs: Additional kwargs to pass during dataloader construction for derived classes """ return dataset.iterator( phase_type=self.phase_type, current_phase_id=self.train_phase_idx if self.train else 0, pin_memory=self.use_gpu and torch.cuda.device_count() > 1, multiprocessing_context=mp.get_context(self.dataloader_mp_context), **kwargs, ) def build_dataloaders_for_current_phase(self): """Builds dataloader(s) for the current phase. Deriving classes can override this method to support custom behavior, like supporting multiple dataloaders in parallel. 
""" self.dataloader = self.build_dataloader_from_dataset( self.datasets[self.phase_type] ) def prepare_optimizer(self, optimizer, model, loss=None): bn_params, other_params = split_batchnorm_params(model) if loss is not None: bn_params_loss, params_loss = split_batchnorm_params(loss) bn_params = bn_params + bn_params_loss other_params = other_params + params_loss bn_schedulers = self.optimizer_schedulers.copy() if not self.bn_weight_decay: bn_schedulers["weight_decay"] = 0 param_groups = [{"params": other_params, **self.optimizer_schedulers}] if len(bn_params) > 0: param_groups.append({"params": bn_params, **bn_schedulers}) self.optimizer.set_param_groups(param_groups) def prepare(self): """Prepares task for training, populates all derived attributes """ self.phases = self._build_phases() self.train = False if self.test_only else self.train if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH: self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model) elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX: sync_bn_process_group = apex.parallel.create_syncbn_process_group( self.batch_norm_sync_group_size ) self.base_model = apex.parallel.convert_syncbn_model( self.base_model, process_group=sync_bn_process_group ) # move the model and loss to the right device if self.use_gpu: self.base_model, self.base_loss = copy_model_to_gpu( self.base_model, self.base_loss ) else: self.base_loss.cpu() self.base_model.cpu() if self.optimizer is not None: self.prepare_optimizer( optimizer=self.optimizer, model=self.base_model, loss=self.base_loss ) if self.amp_args is not None: if self.amp_type == AmpType.APEX: # Initialize apex.amp. This updates the model and the PyTorch optimizer ( # if training, which is wrapped by the ClassyOptimizer in self.optimizer). # Please note this must happen before loading the checkpoint, cause # there's amp state to be restored. if self.optimizer is None: self.base_model = apex.amp.initialize( self.base_model, optimizers=None, **self.amp_args ) else: self.base_model, self.optimizer.optimizer = apex.amp.initialize( self.base_model, self.optimizer.optimizer, **self.amp_args ) if self.simulated_global_batchsize is not None: if self.simulated_global_batchsize % self.get_global_batchsize() != 0: raise ValueError( f"Global batch size ({self.get_global_batchsize()}) must divide " f"simulated_global_batchsize ({self.simulated_global_batchsize})" ) else: self.simulated_global_batchsize = self.get_global_batchsize() self.optimizer_period = ( self.simulated_global_batchsize // self.get_global_batchsize() ) if self.optimizer_period > 1: logging.info( f"Using gradient accumulation with a period of {self.optimizer_period}" ) if self.checkpoint_path: self.checkpoint_dict = load_and_broadcast_checkpoint(self.checkpoint_path) classy_state_dict = ( None if self.checkpoint_dict is None else self.checkpoint_dict["classy_state_dict"] ) if classy_state_dict is not None: state_load_success = update_classy_state(self, classy_state_dict) assert ( state_load_success ), "Update classy state from checkpoint was unsuccessful." self.init_distributed_data_parallel_model() def init_distributed_data_parallel_model(self): """ Initialize `torch.nn.parallel.distributed.DistributedDataParallel <https://pytorch.org/ docs/stable/nn.html#distributeddataparallel>`_. Needed for distributed training. This is where a model should be wrapped by DDP. 
""" if not is_distributed_training_run(): return assert ( self.distributed_model is None ), "init_ddp_non_elastic must only be called once" broadcast_buffers = ( self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS ) if self.use_sharded_ddp: if not isinstance(self.optimizer, ZeRO): raise ValueError( "ShardedDataParallel engine should only be used in conjunction with ZeRO optimizer" ) from fairscale.nn.data_parallel import ShardedDataParallel # Replace the original DDP wrap by the shard-aware ShardedDDP self.distributed_model = ShardedDataParallel( module=self.base_model, sharded_optimizer=self.optimizer.optimizer, broadcast_buffers=broadcast_buffers, ) else: self.distributed_model = init_distributed_data_parallel_model( self.base_model, broadcast_buffers=broadcast_buffers, find_unused_parameters=self.find_unused_parameters, bucket_cap_mb=self.ddp_bucket_cap_mb, ) if self.fp16_grad_compress: from torch.distributed.algorithms import ddp_comm_hooks # FP16 hook is stateless and only takes a process group as the state. # We use the default process group so we set the state to None. process_group = None self.distributed_model.register_comm_hook( process_group, ddp_comm_hooks.default_hooks.fp16_compress_hook, ) if ( isinstance(self.base_loss, ClassyLoss) and self.base_loss.has_learned_parameters() ): logging.info("Initializing distributed loss") self.distributed_loss = init_distributed_data_parallel_model( self.base_loss, broadcast_buffers=broadcast_buffers, find_unused_parameters=self.find_unused_parameters, bucket_cap_mb=self.ddp_bucket_cap_mb, ) @property def where(self): """Returns the proportion of training that has completed. If in test only mode, returns proportion of testing completed Returned value is a float in the range [0, 1) """ current_step = self.num_updates / self.get_global_batchsize() num_phases = ( self.get_total_test_phases() if self.test_only else self.get_total_training_phases() ) if self.num_batches_per_phase <= 0: raise RuntimeError("No batches to read. Is the dataset empty?") num_steps = num_phases * self.num_batches_per_phase where = current_step / num_steps return where def get_classy_state(self, deep_copy: bool = False): """Returns serialiable state of task Args: deep_copy: If true, does a deep copy of state before returning. 
""" optimizer_state = {} if self.optimizer is not None: optimizer_state = self.optimizer.get_classy_state() classy_state_dict = { "train": self.train, "base_model": self.base_model.get_classy_state(), "meters": [meter.get_classy_state() for meter in self.meters], "optimizer": optimizer_state, "phase_idx": self.phase_idx, "train_phase_idx": self.train_phase_idx, "num_updates": self.num_updates, "losses": self.losses, "hooks": {hook.name(): hook.get_classy_state() for hook in self.hooks}, "loss": {}, } if "train" in self.datasets and self._is_checkpointable_dataset( self.datasets["train"] ): classy_state_dict["train_dataset_iterator"] = self.datasets[ "train" ].get_classy_state() if isinstance(self.base_loss, ClassyLoss): classy_state_dict["loss"] = self.base_loss.get_classy_state() if self.amp_args is not None: if self.amp_type == AmpType.APEX: classy_state_dict["amp"] = apex.amp.state_dict() elif self.amp_grad_scaler is not None: classy_state_dict["amp"] = self.amp_grad_scaler.state_dict() if deep_copy: classy_state_dict = copy.deepcopy(classy_state_dict) return classy_state_dict def set_classy_state(self, state): """Set task state Args: state: Dict containing state of a task """ self.train = False if self.test_only else state["train"] self.base_model.set_classy_state(state["base_model"]) if self.test_only: # if we're only testing, just need the state of the model to be updated return self.phase_idx = state["phase_idx"] self.num_updates = state["num_updates"] self.train_phase_idx = state["train_phase_idx"] self.losses = state["losses"] for meter, meter_state in zip(self.meters, state["meters"]): meter.set_classy_state(meter_state) if self.optimizer is not None: self.optimizer.set_classy_state(state["optimizer"]) if state.get("loss") and isinstance(self.base_loss, ClassyLoss): self.base_loss.set_classy_state(state["loss"]) if "amp" in state: if self.amp_type == AmpType.APEX: apex.amp.load_state_dict(state["amp"]) else: self.amp_grad_scaler.load_state_dict(state["amp"]) for hook in self.hooks: # we still want to be able to run when new hooks are added or old # hooks are removed if hook.name() in state["hooks"]: hook.set_classy_state(state["hooks"][hook.name()]) else: logging.warning(f"No state found for hook: {hook.name()}") if "train" in self.datasets and self._is_checkpointable_dataset( self.datasets["train"] ): self.datasets["train"].set_classy_state(state.get("train_dataset_iterator")) @staticmethod def _is_checkpointable_dataset(dataset): return hasattr(dataset, "get_classy_state") and hasattr( dataset, "set_classy_state" ) def eval_step(self): self.last_batch = None # Process next sample with Timer() as timer: sample = next(self.data_iterator) assert isinstance(sample, dict) and "input" in sample and "target" in sample, ( f"Returned sample [{sample}] is not a map with 'input' and" + "'target' keys" ) target = sample["target"] if self.use_gpu: sample = recursive_copy_to_gpu(sample, non_blocking=True) # Optional Pytorch AMP context torch_amp_context = ( torch.cuda.amp.autocast() if self.amp_type == AmpType.PYTORCH else contextlib.suppress() ) with torch.no_grad(), torch_amp_context: output = self.model(sample["input"]) local_loss = self.compute_loss(output, sample) loss = local_loss.detach().clone() self.check_inf_nan(loss) self.losses.append(loss.data.cpu().item()) self.update_meters(output, sample) # Move some data to the task so hooks get a chance to access it self.last_batch = LastBatchInfo( loss=loss, output=output, target=target, sample=sample, step_data={"sample_fetch_time": 
timer.elapsed_time}, ) def check_inf_nan(self, loss): if loss == float("inf") or loss == float("-inf") or loss != loss: raise FloatingPointError(f"Loss is infinity or NaN: {loss}") def _should_do_step(self): """Tells if we will be performing an optimizer step. Returns True always if there is no gradient accumulation. With gradient accumulation returns True only when the gradients will be synchronized and we will be performing an optimizer step. """ update_idx = self.num_updates // self.get_global_batchsize() return (update_idx % self.optimizer_period) == self.optimizer_period - 1 def train_step(self): """Train step to be executed in train loop.""" self.last_batch = None # Process next sample with Timer() as timer: sample = next(self.data_iterator) assert isinstance(sample, dict) and "input" in sample and "target" in sample, ( f"Returned sample [{sample}] is not a map with 'input' and" + "'target' keys" ) # Copy sample to GPU target = sample["target"] if self.use_gpu: sample = recursive_copy_to_gpu(sample, non_blocking=True) if self.mixup_transform is not None: sample = self.mixup_transform(sample) # Optional Pytorch AMP context torch_amp_context = ( torch.cuda.amp.autocast() if self.amp_type == AmpType.PYTORCH else contextlib.suppress() ) # only sync with DDP when we need to perform an optimizer step # an optimizer step can be skipped if gradient accumulation is enabled do_step = self._should_do_step() ctx_mgr_model = ( self.distributed_model.no_sync() if self.distributed_model is not None and not do_step else contextlib.suppress() ) ctx_mgr_loss = ( self.distributed_loss.no_sync() if self.distributed_loss is not None and not do_step else contextlib.suppress() ) with ctx_mgr_model, ctx_mgr_loss: # Forward pass with torch.enable_grad(), torch_amp_context: output = self.model(sample["input"]) local_loss = self.compute_loss(output, sample) loss = local_loss.detach().clone() self.losses.append(loss.data.cpu().item()) self.update_meters(output, sample) # Backwards pass + optimizer step self.run_optimizer(local_loss) self.num_updates += self.get_global_batchsize() # Move some data to the task so hooks get a chance to access it self.last_batch = LastBatchInfo( loss=loss, output=output, target=target, sample=sample, step_data={"sample_fetch_time": timer.elapsed_time}, ) def compute_loss(self, model_output, sample): return self.loss(model_output, sample["target"]) def run_optimizer(self, loss): """Runs backwards pass and update the optimizer""" self.check_inf_nan(loss) # Gradient accumulation logic. We always set optimizer_period, even # if gradient accumulation is disabled. 
# Assumes all batches have the same size update_idx = self.num_updates // self.get_global_batchsize() do_zero_grad = (update_idx % self.optimizer_period) == 0 do_step = self._should_do_step() if do_zero_grad: self.optimizer.zero_grad() if self.amp_type == AmpType.APEX: with apex.amp.scale_loss(loss, self.optimizer.optimizer) as scaled_loss: scaled_loss.backward() elif self.amp_type == AmpType.PYTORCH: self.amp_grad_scaler.scale(loss).backward() else: loss.backward() if do_step: # Handle gradient accumulation related gradient rescaling if self.optimizer_period != 1: self._rescale_gradients(1 / self.optimizer_period) # Clipping must happen after grad accumulation if self.clip_grad_norm is not None: self._clip_gradients(self.clip_grad_norm) if self.amp_type == AmpType.PYTORCH: # If using mixed precision, handle underflow-related scaling # See https://pytorch.org/docs/stable/amp.html#gradient-scaling # for context self.amp_grad_scaler.step(self.optimizer, where=self.where) self.amp_grad_scaler.update() else: self.optimizer.step(where=self.where) def _rescale_gradients(self, scale): for param in master_params(self.optimizer): if param.grad is not None: param.grad.data.mul_(scale) def _clip_gradients(self, max_norm): nn.utils.clip_grad_norm_(master_params(self.optimizer), max_norm) def update_meters(self, model_output, sample): target = sample["target"].detach().cpu() model_output = model_output.detach().cpu() # Update meters for meter in self.meters: meter.update(model_output, target, is_train=self.train) def synchronize_losses(self): """Average the losses across the different replicas""" # Average losses across nodes losses_tensor = torch.tensor(self.losses) synchronized_losses_tensor = all_reduce_mean(losses_tensor) self.losses = synchronized_losses_tensor.tolist() def advance_phase(self): """Performs bookkeeping / task updates between phases Increments phase idx, resets meters, resets loss history, resets counters, shuffles dataset, rebuilds iterators, and sets the train / test state for phase. """ logging.debug("Advancing phase") # Reset meters for next phase / epoch for meter in self.meters: meter.reset() # Reset loss history for next epoch self.losses = [] # Setup new phase self.phase_idx += 1 phase = self.phases[self.phase_idx] self.train = True if phase["train"] else False if self.train: self.train_phase_idx += 1 # Re-build dataloader & re-create iterator anytime membership changes. self.build_dataloaders_for_current_phase() self.create_data_iterators() # Set up pytorch module in train vs eval mode, update optimizer. self._set_model_train_mode() def done_training(self): """Stop condition for training""" return self.phase_idx + 1 >= len(self.phases) def create_data_iterators(self): """Creates data iterator(s) for the current phase.""" # Delete iterator explicitly so that all dataloader processes # are cleaned up.
del self.data_iterator self.data_iterator = iter(self.dataloader) def _set_model_train_mode(self): """Set train mode for model""" phase = self.phases[self.phase_idx] self.base_model.train(phase["train"]) self.base_loss.train(phase["train"]) if ( self.broadcast_buffers_mode == BroadcastBuffersMode.BEFORE_EVAL and not self.train ): self._broadcast_buffers() def _broadcast_buffers(self): """Explicitly synchronize buffers across all devices.""" if self.distributed_model is None: return buffers = list(self.base_model.buffers()) if len(buffers) > 0: logging.info("Synchronizing buffers before evaluation.") for buffer in buffers: broadcast(buffer, 0, group=self.distributed_model.process_group) # TODO: Functions below should be better abstracted into the dataloader # abstraction def get_batchsize_per_replica(self): """Return local replica's batchsize for dataset (e.g. batchsize per GPU)""" return self.datasets[self.phase_type].get_batchsize_per_replica() def get_global_batchsize(self): """Return global batchsize across all trainers""" return self.datasets[self.phase_type].get_global_batchsize() def on_start(self): for hook in self.hooks: hook.on_start(self) def on_phase_start(self): self.phase_start_time_total = time.perf_counter() self.advance_phase() for hook in self.hooks: hook.on_phase_start(self) self.phase_start_time_train = time.perf_counter() def on_phase_end(self): self.log_phase_end("train") if self.train: self.optimizer.on_epoch(where=self.where) logging.debug("Syncing losses on phase end...") self.synchronize_losses() logging.debug("...losses synced") logging.debug("Syncing meters on phase end...") for meter in self.meters: meter.sync_state() logging.debug("...meters synced") barrier() for hook in self.hooks: hook.on_phase_end(self) self.perf_log = [] self.log_phase_end("total") if hasattr(self.datasets[self.phase_type], "on_phase_end"): self.datasets[self.phase_type].on_phase_end() def on_end(self): for hook in self.hooks: hook.on_end(self) def log_phase_end(self, tag): if not self.train: return start_time = ( self.phase_start_time_train if tag == "train" else self.phase_start_time_total ) phase_duration = time.perf_counter() - start_time im_per_sec = ( self.get_global_batchsize() * self.num_batches_per_phase ) / phase_duration self.perf_log.append( { "tag": tag, "phase_idx": self.train_phase_idx, "epoch_duration": phase_duration, "im_per_sec": im_per_sec, } ) def __repr__(self): if hasattr(self, "_config"): config = json.dumps(self._config, indent=4) return f"{super().__repr__()} initialized with config:\n{config}" return super().__repr__()
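# --- Illustrative sketch (not part of ClassificationTask) --------------------
# ClassificationTask.run_optimizer above interleaves gradient accumulation,
# optional gradient clipping, and torch.cuda.amp scaling behind its own
# ClassyOptimizer wrapper. As a standalone reference, the plain PyTorch AMP
# step with clipping (unscaling before clipping, per the PyTorch AMP docs)
# looks roughly like this. The model, data, and hyper-parameters below are
# hypothetical placeholders and assume a CUDA device is available.
import torch
import torch.nn as nn


def amp_clip_step_sketch():
    model = nn.Linear(16, 4).cuda()                      # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()
    inputs = torch.randn(8, 16, device="cuda")
    targets = torch.randint(0, 4, (8,), device="cuda")

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()        # backward on the scaled loss
    scaler.unscale_(optimizer)           # unscale so clipping sees true grads
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)               # skips the step if grads are inf/nan
    scaler.update()                      # adjust the loss scale for next iter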
class DeepvacTrain(Deepvac): def __init__(self, deepvac_config): deepvac_config.is_forward_only = False super(DeepvacTrain, self).__init__(deepvac_config) self.initTrainParameters() self.initTrainContext() def setTrainContext(self): self.is_train = True self.is_val = False self.phase = 'TRAIN' self.dataset = self.train_dataset self.loader = self.train_loader self.batch_size = self.conf.train.batch_size self.net.train() def setValContext(self): self.is_train = False self.is_val = True self.phase = 'VAL' self.dataset = self.val_dataset self.loader = self.val_loader self.batch_size = self.conf.val.batch_size self.net.eval() def initTrainContext(self): self.scheduler = None self.initOutputDir() self.initSummaryWriter() self.initCriterion() self.initOptimizer() self.initScheduler() self.initCheckpoint() self.initTrainLoader() self.initValLoader() def initTrainParameters(self): self.dataset = None self.loader = None self.target = None self.epoch = 0 self.step = 0 self.iter = 0 # Creates a GradScaler once at the beginning of training. self.scaler = GradScaler() self.train_time = AverageMeter() self.load_data_time = AverageMeter() self.data_cpu2gpu_time = AverageMeter() self._mandatory_member_name = [ 'train_dataset', 'val_dataset', 'train_loader', 'val_loader', 'net', 'criterion', 'optimizer' ] def initOutputDir(self): if self.conf.output_dir != 'output' and self.conf.output_dir != './output': LOG.logW( "According to the deepvac standard, you should set config.output_dir to [output] rather than [{}]." .format(self.conf.output_dir)) self.output_dir = '{}/{}'.format(self.conf.output_dir, self.branch) LOG.logI('model save dir: {}'.format(self.output_dir)) #for DDP race condition os.makedirs(self.output_dir, exist_ok=True) def initSummaryWriter(self): event_dir = "{}/{}".format(self.conf.log_dir, self.branch) self.writer = SummaryWriter(event_dir) if not self.conf.tensorboard_port: return from tensorboard import program tensorboard = program.TensorBoard() self.conf.tensorboard_ip = '0.0.0.0' if self.conf.tensorboard_ip is None else self.conf.tensorboard_ip tensorboard.configure(argv=[ None, '--host', str(self.conf.tensorboard_ip), '--logdir', event_dir, "--port", str(self.conf.tensorboard_port) ]) try: url = tensorboard.launch() LOG.logI('Tensorboard at {} '.format(url)) except Exception as e: LOG.logE(str(e)) def initCriterion(self): self.criterion = torch.nn.CrossEntropyLoss() LOG.logW( "You should reimplement initCriterion() to initialize self.criterion, unless CrossEntropyLoss() is exactly what you need" ) def initCheckpoint(self): if not self.conf.checkpoint_suffix or self.conf.checkpoint_suffix == "": LOG.logI('Omitting the checkpoint file since none was specified...') return LOG.logI('Load checkpoint from {} folder'.format(self.output_dir)) self.net.load_state_dict( torch.load(self.output_dir + '/model__{}'.format(self.conf.checkpoint_suffix), map_location=self.device)) state_dict = torch.load( self.output_dir + '/checkpoint__{}'.format(self.conf.checkpoint_suffix), map_location=self.device) self.optimizer.load_state_dict(state_dict['optimizer']) if self.scheduler: self.scheduler.load_state_dict(state_dict['scheduler']) if self.conf.amp: LOG.logI( "Will load scaler from checkpoint since you enabled amp; make sure the checkpoint was saved with amp enabled." ) try: self.scaler.load_state_dict(state_dict["scaler"]) except: LOG.logI( "checkpoint was saved without amp enabled, so using a fresh GradScaler instead."
) self.scaler = GradScaler() self.epoch = state_dict['epoch'] if self.conf.ema: self.ema.load_state_dict(state_dict['ema']) def initScheduler(self): if isinstance(self.conf.lr_step, list): self.scheduler = torch.optim.lr_scheduler.MultiStepLR( self.optimizer, self.conf.lr_step, self.conf.lr_factor) elif isinstance(self.conf.lr_step, Callable): self.scheduler = torch.optim.lr_scheduler.LambdaLR( self.optimizer, lr_lambda=self.conf.lr_step) else: self.scheduler = torch.optim.lr_scheduler.StepLR( self.optimizer, self.conf.lr_step, self.conf.lr_factor) LOG.logW( "You should reimplement initScheduler() to initialize self.scheduler, unless lr_scheduler.StepLR() or lr_scheduler.MultiStepLR() is exactly what you need" ) def initTrainLoader(self): self.train_loader = None LOG.logE( "You must reimplement initTrainLoader() to initialize self.train_loader", exit=True) def initValLoader(self): self.val_loader = None LOG.logE( "You must reimplement initValLoader() to initialize self.val_loader", exit=True) def initOptimizer(self): self.initSgdOptimizer() LOG.logW( "You should reimplement initOptimizer() to initialize self.optimizer, unless SGD is exactly what you need" ) def initSgdOptimizer(self): self.optimizer = optim.SGD(self.net.parameters(), lr=self.conf.lr, momentum=self.conf.momentum, weight_decay=self.conf.weight_decay, nesterov=self.conf.nesterov) def initAdamOptimizer(self): self.optimizer = optim.Adam( self.net.parameters(), lr=self.conf.lr, betas=self.conf.betas if self.conf.betas else (0.9, 0.999), weight_decay=self.conf.weight_decay if self.conf.weight_decay else 0) for group in self.optimizer.param_groups: group.setdefault('initial_lr', group['lr']) def initRmspropOptimizer(self): self.optimizer = optim.RMSprop( self.net.parameters(), lr=self.conf.lr, momentum=self.conf.momentum, weight_decay=self.conf.weight_decay, # alpha=self.conf.rmsprop_alpha, # centered=self.conf.rmsprop_centered ) def addScalar(self, tag, value, step): self.writer.add_scalar(tag, value, step) def addImage(self, tag, image, step): self.writer.add_image(tag, image, step) @syszux_once def addGraph(self, input): try: self.writer.add_graph(self.net, input) except: LOG.logW( "Tensorboard addGraph failed. Does your network forward take more than one parameter?"
) LOG.logW("Seems you need reimplement preIter function.") def earlyIter(self): self.feedSample() self.feedTarget() def feedSample(self): self.sample = self.sample.to(self.device) def feedTarget(self): self.target = self.target.to(self.device) def preIter(self): pass def postIter(self): pass def preEpoch(self): pass def postEpoch(self): pass def doForward(self): self.output = self.net(self.sample) def doLoss(self): self.loss = self.criterion(self.output, self.target) def doBackward(self): if self.conf.amp: self.scaler.scale(self.loss).backward() else: self.loss.backward() def doOptimize(self): if self.iter % self.conf.nominal_batch_factor != 0: return if self.conf.amp: self.scaler.step(self.optimizer) self.scaler.update() else: self.optimizer.step() self.optimizer.zero_grad() if self.conf.ema: self.updateEMA() def doLog(self): if self.step % self.conf.log_every != 0: return self.addScalar('{}/Loss'.format(self.phase), self.loss.item(), self.iter) self.addScalar('{}/LoadDataTime(secs/batch)'.format(self.phase), self.load_data_time.val, self.iter) self.addScalar('{}/DataCpu2GpuTime(secs/batch)'.format(self.phase), self.data_cpu2gpu_time.val, self.iter) self.addScalar('{}/TrainTime(secs/batch)'.format(self.phase), self.train_time.val, self.iter) LOG.logI('{}: [{}][{}/{}] [Loss:{} Lr:{}]'.format( self.phase, self.epoch, self.step, self.loader_len, self.loss.item(), self.optimizer.param_groups[0]['lr'])) def saveState(self, current_time): file_partial_name = '{}__acc_{}__epoch_{}__step_{}__lr_{}'.format( current_time, self.accuracy, self.epoch, self.step, self.optimizer.param_groups[0]['lr']) state_file = '{}/model__{}.pth'.format(self.output_dir, file_partial_name) checkpoint_file = '{}/checkpoint__{}.pth'.format( self.output_dir, file_partial_name) output_trace_file = '{}/trace__{}.pt'.format(self.output_dir, file_partial_name) output_script_file = '{}/script__{}.pt'.format(self.output_dir, file_partial_name) output_onnx_file = '{}/onnx__{}.onnx'.format(self.output_dir, file_partial_name) output_ncnn_file = '{}/ncnn__{}.bin'.format(self.output_dir, file_partial_name) output_coreml_file = '{}/coreml__{}.mlmodel'.format( self.output_dir, file_partial_name) #save state_dict net = self.ema if self.conf.ema else self.net torch.save(net.state_dict(), state_file) #save checkpoint torch.save( { 'optimizer': self.optimizer.state_dict(), 'epoch': self.epoch, 'scheduler': self.scheduler.state_dict() if self.scheduler else None, 'ema': self.ema.state_dict() if self.conf.ema else None, 'scaler': self.scaler.state_dict() if self.conf.amp else None }, checkpoint_file) self.exportTorchViaTrace(self.sample, output_trace_file) self.exportTorchViaScript(output_script_file) self.exportONNX(self.sample, output_onnx_file) self.exportNCNN(self.sample, output_ncnn_file) self.exportCoreML(self.sample, output_coreml_file) #tensorboard self.addScalar('{}/Accuracy'.format(self.phase), self.accuracy, self.iter) def processTrain(self): self.setTrainContext() self.step = 0 LOG.logI('Phase {} started...'.format(self.phase)) self.loader_len = len(self.loader) save_every = self.loader_len // self.conf.save_num save_list = list(range(0, self.loader_len + 1, save_every)) self.save_list = save_list[1:-1] LOG.logI('Model will be saved on step {} and the epoch end.'.format( self.save_list)) self.addScalar('{}/LR'.format(self.phase), self.optimizer.param_groups[0]['lr'], self.epoch) self.preEpoch() self.train_time.reset() self.load_data_time.reset() self.data_cpu2gpu_time.reset() iter_tick = time.time() for i, (sample, target) in 
enumerate(self.loader): self.load_data_time.update(time.time() - iter_tick) self.step = i self.target = target self.sample = sample self.preIter() feed_sample_tick = time.time() self.earlyIter() self.data_cpu2gpu_time.update(time.time() - feed_sample_tick) self.addGraph(self.sample) with autocast(enabled=self.conf.amp if self.conf.amp else False): self.doForward() self.doLoss() self.doBackward() self.doOptimize() self.doLog() self.postIter() self.iter += 1 self.train_time.update(time.time() - iter_tick) if self.step in self.save_list: self.processVal() self.setTrainContext() iter_tick = time.time() self.addScalar('{}/TrainTime(hours/epoch)'.format(self.phase), round(self.train_time.sum / 3600, 2), self.epoch) self.addScalar( '{}/AverageBatchTrainTime(secs/epoch)'.format(self.phase), self.train_time.avg, self.epoch) self.addScalar( '{}/AverageBatchLoadDataTime(secs/epoch)'.format(self.phase), self.load_data_time.avg, self.epoch) self.addScalar( '{}/AverageBatchDataCpu2GpuTime(secs/epoch)'.format(self.phase), self.data_cpu2gpu_time.avg, self.epoch) self.postEpoch() if self.scheduler: self.scheduler.step() def processVal(self, smoke=False): self.setValContext() LOG.logI('Phase {} started...'.format(self.phase)) with torch.no_grad(): self.preEpoch() for i, (sample, target) in enumerate(self.loader): self.target = target self.sample = sample self.preIter() self.earlyIter() self.doForward() self.doLoss() self.smokeTestForExport3rd() self.postIter() if smoke: return LOG.logI('{}: [{}][{}/{}]'.format(self.phase, self.epoch, i, len(self.loader))) self.postEpoch() self.saveState(self.getTime()) def processAccept(self): self.setValContext() def process(self): self.auditConfig() self.iter = 0 epoch_start = self.epoch if self.conf.ema: self.ema_updates = self.epoch * len( self.train_loader) // self.conf.nominal_batch_factor self.processVal(smoke=True) self.optimizer.zero_grad() for epoch in range(epoch_start, self.conf.epoch_num): self.epoch = epoch LOG.logI('Epoch {} started...'.format(self.epoch)) self.processTrain() self.processVal() self.processAccept() def __call__(self): self.process()
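# --- Illustrative sketch (not part of DeepvacTrain) ---------------------------
# DeepvacTrain.doBackward/doOptimize above step the optimizer only when
# self.iter is a multiple of conf.nominal_batch_factor, accumulating scaled
# gradients in between so the effective batch size grows by that factor. A
# standalone version of that accumulation pattern with autocast/GradScaler
# might look like the following; model, loader, and accum_steps are
# hypothetical placeholders, and (as a common variant, not the exact deepvac
# behavior) the loss is divided by the window size so the accumulated gradient
# approximates the large-batch average.
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast


def train_with_accumulation(model, loader, optimizer, accum_steps=4):
    scaler = GradScaler()
    criterion = nn.CrossEntropyLoss()
    optimizer.zero_grad()
    for i, (sample, target) in enumerate(loader, start=1):
        sample, target = sample.cuda(), target.cuda()
        with autocast():
            loss = criterion(model(sample), target) / accum_steps
        scaler.scale(loss).backward()        # gradients accumulate across iters
        if i % accum_steps == 0:             # step only on accumulation boundaries
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()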
if i % args.logfreq == 0: niter = epoch*len(train_loader)+i tb_writer.add_scalar('Train/Loss', loss_reducer(runningLoss), niter) wandb.log({"Epoch":epoch, "TrainLoss":loss_reducer(runningLoss)})#, step=niter) # tensorboard_images(tb_writer, inp, out.detach(), gt, epoch, 'train') runningLoss = [] if args.finetune or (epoch % args.savefreq == 0): checkpoint = { 'epoch': epoch, 'iterations': (epoch+1)*len(train_loader), 'best_loss': best_loss, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'AMPScaler': scaler.state_dict() } torch.save(checkpoint, os.path.join(save_path, trainID+".pth.tar")) if args.modelid != 9: torch.onnx.export(model, images, trainID+".onnx", input_names=["LRCurrTP"], output_names=["SuperResolvedCurrTP"]) wandb.save(trainID+".onnx") tb_writer.add_scalar('Train/EpochLoss', loss_reducer(train_loss), epoch) wandb.log({"TrainEpochLoss":loss_reducer(train_loss)})#, step=epoch) #Validate if val_loader: model.eval() with torch.no_grad(): runningLoss = [] val_loss = []
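# --- Illustrative sketch (resuming from the checkpoint format saved above) ----
# The snippet above stores model, optimizer, and GradScaler state under the
# 'state_dict', 'optimizer', and 'AMPScaler' keys, alongside 'epoch' and
# 'best_loss'. Restoring such a checkpoint so that AMP loss scaling continues
# where it left off could look like this; the path and the model/optimizer
# objects passed in are hypothetical placeholders.
import torch
from torch.cuda.amp import GradScaler


def resume_from_checkpoint(model, optimizer, ckpt_path):
    scaler = GradScaler()
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scaler.load_state_dict(checkpoint['AMPScaler'])   # keeps the current loss scale
    start_epoch = checkpoint['epoch'] + 1
    best_loss = checkpoint['best_loss']
    return scaler, start_epoch, best_loss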
def main(args): comm = MPI.COMM_WORLD world_size = comm.Get_size() rank = comm.Get_rank() os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(args.master_port) torch.cuda.set_device(rank) dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) device = torch.device("cuda") logger = None tb_logger = None if rank == 0: if not os.path.exists(args.save_path): os.mkdir(args.save_path) if not os.path.exists(args.tensorboard_log_dir): os.mkdir(args.tensorboard_log_dir) tb_logger = SummaryWriter( f"{args.tensorboard_log_dir}/{args.model_name}") logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) handler = TqdmLoggingHandler() handler.setFormatter(logging.Formatter(" %(asctime)s - %(message)s")) logger.addHandler(handler) logger.propagate = False write_log(logger, "Load data") def load_data(args): gc.disable() with open(f"{args.preprocessed_data_path}/hanja_korean_word2id.pkl", "rb") as f: data = pickle.load(f) hanja_word2id = data['hanja_word2id'] korean_word2id = data['korean_word2id'] with open(f"{args.preprocessed_data_path}/preprocessed_train.pkl", "rb") as f: data = pickle.load(f) train_hanja_indices = data['hanja_indices'] train_korean_indices = data['korean_indices'] train_additional_hanja_indices = data['additional_hanja_indices'] with open(f"{args.preprocessed_data_path}/preprocessed_valid.pkl", "rb") as f: data = pickle.load(f) valid_hanja_indices = data['hanja_indices'] valid_korean_indices = data['korean_indices'] valid_additional_hanja_indices = data['additional_hanja_indices'] gc.enable() write_log(logger, "Finished loading data!") return (hanja_word2id, korean_word2id, train_hanja_indices, train_korean_indices, train_additional_hanja_indices, valid_hanja_indices, valid_korean_indices, valid_additional_hanja_indices) # load data (hanja_word2id, korean_word2id, train_hanja_indices, train_korean_indices, train_additional_hanja_indices, valid_hanja_indices, valid_korean_indices, valid_additional_hanja_indices) = load_data(args) hanja_vocab_num = len(hanja_word2id) korean_vocab_num = len(korean_word2id) hk_dataset = HanjaKoreanDataset(train_hanja_indices, train_korean_indices, min_len=args.min_len, src_max_len=args.src_max_len, trg_max_len=args.trg_max_len) hk_sampler = DistributedSampler(hk_dataset, num_replicas=world_size, rank=rank) hk_loader = DataLoader(hk_dataset, drop_last=True, batch_size=args.hk_batch_size, sampler=hk_sampler, num_workers=args.num_workers, prefetch_factor=4, pin_memory=True) write_log(logger, f"hanja-korean: {len(hk_dataset)}, {len(hk_loader)}") h_dataset = HanjaDataset(train_hanja_indices, train_additional_hanja_indices, hanja_word2id, min_len=args.min_len, src_max_len=args.src_max_len) h_sampler = DistributedSampler(h_dataset, num_replicas=world_size, rank=rank) h_loader = DataLoader(h_dataset, drop_last=True, batch_size=args.h_batch_size, sampler=h_sampler, num_workers=args.num_workers, prefetch_factor=4, pin_memory=True) write_log(logger, f"hanja: {len(h_dataset)}, {len(h_loader)}") hk_valid_dataset = HanjaKoreanDataset(valid_hanja_indices, valid_korean_indices, min_len=args.min_len, src_max_len=args.src_max_len, trg_max_len=args.trg_max_len) hk_valid_sampler = DistributedSampler(hk_valid_dataset, num_replicas=world_size, rank=rank) hk_valid_loader = DataLoader(hk_valid_dataset, drop_last=True, batch_size=args.hk_batch_size, sampler=hk_valid_sampler) write_log( logger, f"hanja-korean-valid: {len(hk_valid_dataset)}, {len(hk_valid_loader)}") h_valid_dataset = HanjaDataset(valid_hanja_indices, 
valid_additional_hanja_indices, hanja_word2id, min_len=args.min_len, src_max_len=args.src_max_len) h_valid_sampler = DistributedSampler(h_valid_dataset, num_replicas=world_size, rank=rank) h_valid_loader = DataLoader(h_valid_dataset, drop_last=True, batch_size=args.h_batch_size, sampler=h_valid_sampler) write_log(logger, f"hanja: {len(h_valid_dataset)}, {len(h_valid_loader)}") del (train_hanja_indices, train_korean_indices, train_additional_hanja_indices, valid_hanja_indices, valid_korean_indices, valid_additional_hanja_indices) write_log(logger, "Build model") model = Transformer(hanja_vocab_num, korean_vocab_num, pad_idx=args.pad_idx, bos_idx=args.bos_idx, eos_idx=args.eos_idx, src_max_len=args.src_max_len, trg_max_len=args.trg_max_len, d_model=args.d_model, d_embedding=args.d_embedding, n_head=args.n_head, dropout=args.dropout, dim_feedforward=args.dim_feedforward, num_encoder_layer=args.num_encoder_layer, num_decoder_layer=args.num_decoder_layer, num_mask_layer=args.num_mask_layer).to(device) model = nn.parallel.DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True) for param in model.parameters(): dist.broadcast(param.data, 0) dist.barrier() write_log( logger, f"Total Parameters: {sum([p.nelement() for p in model.parameters()])}") no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], "weight_decay": args.weight_decay }, { "params": [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], "weight_decay": 0.0 }, ] optimizer = Ralamb(params=optimizer_grouped_parameters, lr=args.lr) total_iters = round( len(hk_loader) / args.num_grad_accumulate * args.epochs) scheduler = get_cosine_schedule_with_warmup( optimizer, round(total_iters * args.warmup_ratio), total_iters) scaler = GradScaler() start_epoch = 0 if args.resume: def load_states(): checkpoint = torch.load( f'{args.save_path}/{args.model_name}_ckpt.pt', map_location='cpu') start_epoch = checkpoint['epoch'] + 1 model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) scheduler.load_state_dict(checkpoint['scheduler']) scaler.load_state_dict(checkpoint['scaler']) return start_epoch start_epoch = load_states() write_log(logger, f"Training start - Total iter: {total_iters}\n") iter_num = round(len(hk_loader) / args.num_grad_accumulate) global_step = start_epoch * iter_num hk_iter = iter(hk_loader) h_iter = iter(h_loader) model.train() tgt_mask = Transformer.generate_square_subsequent_mask( args.trg_max_len - 1, device) # validation validate(model, tgt_mask, h_valid_loader, hk_valid_loader, rank, logger, tb_logger, 0, device) for epoch in range(start_epoch + 1, args.epochs + 1): while True: start = time.time() finish_epoch = False trans_top5, trans_loss, mask_top5, mask_loss = 0.0, 0.0, 0.0, 0.0 if args.train_reconstruct: optimizer.zero_grad(set_to_none=True) for _ in range(args.num_grad_accumulate): try: src_sequences, trg_sequences = next(h_iter) except StopIteration: h_sampler.set_epoch(epoch) h_iter = iter(h_loader) src_sequences, trg_sequences = next(h_iter) trg_sequences = trg_sequences.to(device) src_sequences = src_sequences.to(device) non_pad = trg_sequences != args.pad_idx trg_sequences = trg_sequences[non_pad].contiguous().view( -1) with autocast(): predicted = model.module.reconstruct_predict( src_sequences, masked_position=non_pad) predicted = predicted.view(-1, predicted.size(-1)) loss = label_smoothing_loss( 
predicted, trg_sequences) / args.num_grad_accumulate scaler.scale(loss).backward() if global_step % args.print_freq == 0: mask_top5 += accuracy(predicted, trg_sequences, 5) / args.num_grad_accumulate mask_loss += loss.detach().item() for param in model.parameters(): if param.grad is not None: dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) param.grad.data = param.grad.data / world_size scaler.step(optimizer) scaler.update() if args.train_translate: optimizer.zero_grad(set_to_none=True) for _ in range(args.num_grad_accumulate): try: src_sequences, trg_sequences = next(hk_iter) except StopIteration: hk_sampler.set_epoch(epoch) hk_iter = iter(hk_loader) src_sequences, trg_sequences = next(hk_iter) finish_epoch = True trg_sequences = trg_sequences.to(device) trg_sequences_target = trg_sequences[:, 1:] src_sequences = src_sequences.to(device) non_pad = trg_sequences_target != args.pad_idx trg_sequences_target = trg_sequences_target[ non_pad].contiguous().view(-1) with autocast(): predicted = model(src_sequences, trg_sequences[:, :-1], tgt_mask, non_pad_position=non_pad) predicted = predicted.view(-1, predicted.size(-1)) loss = label_smoothing_loss( predicted, trg_sequences_target) / args.num_grad_accumulate scaler.scale(loss).backward() if global_step % args.print_freq == 0: trans_top5 += accuracy(predicted, trg_sequences_target, 5) / args.num_grad_accumulate trans_loss += loss.detach().item() for param in model.parameters(): if param.grad is not None: dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) param.grad.data = param.grad.data / world_size scaler.step(optimizer) scaler.update() scheduler.step() # Print status if global_step % args.print_freq == 0: if args.train_reconstruct: mask_top5 = torch.cuda.FloatTensor([mask_top5]) mask_loss = torch.cuda.FloatTensor([mask_loss]) dist.all_reduce(mask_top5, op=dist.ReduceOp.SUM) dist.all_reduce(mask_loss, op=dist.ReduceOp.SUM) mask_top5 = (mask_top5 / world_size).item() mask_loss = (mask_loss / world_size).item() if args.train_translate: trans_top5 = torch.cuda.FloatTensor([trans_top5]) trans_loss = torch.cuda.FloatTensor([trans_loss]) dist.all_reduce(trans_top5, op=dist.ReduceOp.SUM) dist.all_reduce(trans_loss, op=dist.ReduceOp.SUM) trans_top5 = (trans_top5 / world_size).item() trans_loss = (trans_loss / world_size).item() if rank == 0: batch_time = time.time() - start write_log( logger, f'[{global_step}/{total_iters}, {epoch}]\tIter time: {batch_time:.3f}\t' f'Trans loss: {trans_loss:.3f}\tMask_loss: {mask_loss:.3f}\t' f'Trans@5: {trans_top5:.3f}\tMask@5: {mask_top5:.3f}') tb_logger.add_scalar('loss/translate', trans_loss, global_step) tb_logger.add_scalar('loss/mask', mask_loss, global_step) tb_logger.add_scalar('top5/translate', trans_top5, global_step) tb_logger.add_scalar('top5/mask', mask_top5, global_step) tb_logger.add_scalar('batch/time', batch_time, global_step) tb_logger.add_scalar('batch/lr', optimizer.param_groups[0]['lr'], global_step) global_step += 1 if finish_epoch: break # validation validate(model, tgt_mask, h_valid_loader, hk_valid_loader, rank, logger, tb_logger, epoch, device) # save model if rank == 0: torch.save( { 'epoch': epoch, 'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'scaler': scaler.state_dict() }, f'{args.save_path}/{args.model_name}_ckpt.pt') write_log(logger, f"***** {epoch}th model updated! *****")
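# --- Illustrative sketch (not part of main() above) ----------------------------
# The training loop above synchronizes gradients by hand (all_reduce followed by
# a divide by world_size) instead of relying on DDP's automatic averaging, and
# only then calls scaler.step/update. In isolation that pattern looks roughly
# like this; model, loss_fn, batch, and world_size are hypothetical placeholders
# and a NCCL process group is assumed to be initialized already.
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast


def manual_allreduce_step(model, optimizer, scaler, loss_fn, batch, world_size):
    optimizer.zero_grad(set_to_none=True)
    with autocast():
        loss = loss_fn(model, batch)
    scaler.scale(loss).backward()                      # per-rank scaled gradients
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data.div_(world_size)           # average across ranks
    scaler.step(optimizer)                             # unscales internally, skips on inf/nan
    scaler.update()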