def load_network(self, loaded_net=None):
    add_log = False
    if loaded_net is None:
        add_log = True
        if self.cfg.load.wandb_load_path is not None:
            self.cfg.load.network_chkpt_path = wandb.restore(
                self.cfg.load.network_chkpt_path,
                run_path=self.cfg.load.wandb_load_path,
            ).name
        loaded_net = torch.load(
            self.cfg.load.network_chkpt_path,
            map_location=torch.device(self.device),
        )
    loaded_clean_net = OrderedDict()  # remove unnecessary 'module.' prefix
    for k, v in loaded_net.items():
        if k.startswith("module."):
            loaded_clean_net[k[7:]] = v
        else:
            loaded_clean_net[k] = v

    self.net.load_state_dict(loaded_clean_net, strict=self.cfg.load.strict_load)
    if is_logging_process() and add_log:
        self._logger.info(
            "Checkpoint %s is loaded" % self.cfg.load.network_chkpt_path
        )
def save_network(self, save_file=True):
    if is_logging_process():
        net = self.net.module if isinstance(self.net, DDP) else self.net
        state_dict = net.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.to("cpu")
        if save_file:
            save_filename = "%s_%d.pt" % (self.cfg.name, self.step)
            save_path = osp.join(self.cfg.log.chkpt_dir, save_filename)
            torch.save(state_dict, save_path)
            if self.cfg.log.use_wandb:
                wandb.save(save_path)
            if is_logging_process():
                self._logger.info("Saved network checkpoint to: %s" % save_path)
        return state_dict
def save_training_state(self):
    if is_logging_process():
        save_filename = "%s_%d.state" % (self.cfg.name, self.step)
        save_path = osp.join(self.cfg.log.chkpt_dir, save_filename)
        net_state_dict = self.save_network(False)
        state = {
            "model": net_state_dict,
            "optimizer": self.optimizer.state_dict(),
            "step": self.step,
            "epoch": self.epoch,
        }
        torch.save(state, save_path)
        if self.cfg.log.use_wandb:
            wandb.save(save_path)
        if is_logging_process():
            self._logger.info("Saved training state to: %s" % save_path)
def train_model(cfg, model, train_loader, writer):
    logger = get_logger(cfg, os.path.basename(__file__))
    model.net.train()
    for model_input, model_target in train_loader:
        model.optimize_parameters(model_input, model_target)
        loss = model.log.loss_v
        model.step += 1

        if is_logging_process() and (loss > 1e8 or math.isnan(loss)):
            logger.error("Loss exploded to %.02f at step %d!" % (loss, model.step))
            raise Exception("Loss exploded")

        if model.step % cfg.log.summary_interval == 0:
            if writer is not None:
                writer.logging_with_step(loss, model.step, "train_loss")
            if is_logging_process():
                logger.info("Train Loss %.04f at step %d" % (loss, model.step))
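# Hypothetical sketch only: train_model above relies on Model.optimize_parameters
# and model.log.loss_v, which are defined elsewhere in the repo. Assuming a
# conventional PyTorch update step, the contract looks roughly like this
# (self.net, self.loss_f, self.optimizer, self.device, self.log match attributes
# used in the surrounding code; the body itself is an assumption, not the repo's code).
def optimize_parameters(self, model_input, model_target):
    self.optimizer.zero_grad()
    output = self.net(model_input.to(self.device))
    loss_v = self.loss_f(output, model_target.to(self.device))
    loss_v.backward()
    self.optimizer.step()
    # train_model reads this back as model.log.loss_v
    self.log.loss_v = loss_v.item()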
def test_model(cfg, model, test_loader, writer):
    logger = get_logger(cfg, os.path.basename(__file__))
    model.net.eval()
    total_test_loss = 0
    test_loop_len = 0
    total_test_accuracy = 0
    total_test = 0
    with torch.no_grad():
        for model_input, target in test_loader:
            model.feed_data(input=model_input, GT=target)
            output = model.run_network()
            loss_v = model.loss_f(output, model.GT)
            _, predicted = torch.max(output.data, 1)
            total_v = torch.tensor(int(target.size(0))).to(cfg.device)
            accuracy_v = torch.tensor(
                float((predicted == target.to(cfg.device)).sum().item())
            ).to(cfg.device)

            if cfg.dist.gpus > 0:
                # Aggregate loss_v from all GPUs. all_reduce leaves loss_v as the sum
                # of all GPUs' loss_v, so divide it back down to a mean.
                torch.distributed.all_reduce(loss_v)
                loss_v /= torch.tensor(float(cfg.dist.gpus))
                # Aggregate accuracy_v and total_v from all GPUs; both stay as sums.
                torch.distributed.all_reduce(accuracy_v)
                torch.distributed.all_reduce(total_v)

            total_test += total_v.to("cpu").item()
            total_test_loss += loss_v.to("cpu").item()
            total_test_accuracy += accuracy_v.to("cpu").item()
            test_loop_len += 1

    total_test_loss /= test_loop_len
    total_test_accuracy /= total_test

    if writer is not None:
        writer.logging_with_step(total_test_accuracy, model.step, "test_accuracy")
        writer.logging_with_step(total_test_loss, model.step, "test_loss")
        writer.logging_with_epoch(
            total_test_accuracy, model.step, model.epoch, "total_test_accuracy_per_epoch"
        )
        writer.logging_with_epoch(
            total_test_loss, model.step, model.epoch, "test_loss_per_epoch"
        )
    if is_logging_process():
        logger.info("Test Loss %.04f at step %d" % (total_test_loss, model.step))
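# Reduction semantics relied on above, shown in isolation (a sketch, assuming a
# process group has already been initialized, as setup() does): all_reduce defaults
# to a SUM across ranks and modifies the tensor in place, which is why the summed
# loss is divided back down to a mean while the correct/total counts stay as sums.
import torch
import torch.distributed as dist

loss_v = torch.tensor(0.7, device="cuda")      # this rank's mean batch loss
dist.all_reduce(loss_v, op=dist.ReduceOp.SUM)  # in place: now the sum over all ranks
loss_v /= dist.get_world_size()                # back to a mean across ranks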
def load_training_state(self):
    if self.cfg.load.wandb_load_path is not None:
        self.cfg.load.resume_state_path = wandb.restore(
            self.cfg.load.resume_state_path,
            run_path=self.cfg.load.wandb_load_path,
        ).name
    resume_state = torch.load(
        self.cfg.load.resume_state_path,
        map_location=torch.device(self.device),
    )

    self.load_network(loaded_net=resume_state["model"])
    self.optimizer.load_state_dict(resume_state["optimizer"])
    self.step = resume_state["step"]
    self.epoch = resume_state["epoch"]
    if is_logging_process():
        self._logger.info(
            "Resuming from training state: %s" % self.cfg.load.resume_state_path
        )
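# Sketch of the save/resume round trip implied by save_training_state and
# load_training_state. The concrete .state path below is hypothetical; the
# filename pattern "<cfg.name>_<step>.state" comes from save_training_state above.
model.save_training_state()                                  # writes <chkpt_dir>/<name>_<step>.state
model.cfg.load.resume_state_path = "chkpt/mnist_1000.state"  # hypothetical resume target
model.load_training_state()                                  # restores net, optimizer, step, epoch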
def test_model(cfg, model, test_loader, writer):
    logger = get_logger(cfg, os.path.basename(__file__))
    model.net.eval()
    total_test_loss = 0
    test_loop_len = 0
    with torch.no_grad():
        for model_input, model_target in test_loader:
            output = model.inference(model_input)
            loss_v = model.loss_f(output, model_target.to(cfg.device))

            if cfg.dist.gpus > 0:
                # Aggregate loss_v from all GPUs. loss_v is set as the sum of all GPUs' loss_v.
                torch.distributed.all_reduce(loss_v)
                loss_v /= torch.tensor(float(cfg.dist.gpus))

            total_test_loss += loss_v.to("cpu").item()
            test_loop_len += 1

    total_test_loss /= test_loop_len

    if writer is not None:
        writer.logging_with_step(total_test_loss, model.step, "test_loss")
    if is_logging_process():
        logger.info("Test Loss %.04f at step %d" % (total_test_loss, model.step))
def train_loop(rank, cfg):
    logger = get_logger(cfg, os.path.basename(__file__))
    if cfg.device == "cuda" and cfg.dist.gpus != 0:
        cfg.device = rank
        # turn off background generator when distributed run is on
        cfg.data.use_background_generator = False
        setup(cfg, rank)
        torch.cuda.set_device(cfg.device)

    # setup writer
    # non-logging processes keep writer as None; train_model/test_model check for this
    writer = None
    if is_logging_process():
        # set log/checkpoint dir
        os.makedirs(cfg.log.chkpt_dir, exist_ok=True)
        # set writer (tensorboard / wandb)
        writer = Writer(cfg, "tensorboard")
        cfg_str = OmegaConf.to_yaml(cfg)
        logger.info("Config:\n" + cfg_str)
        if cfg.data.train_dir == "" or cfg.data.test_dir == "":
            logger.error("train or test data directory cannot be empty.")
            raise Exception("Please specify directories of data")
        logger.info("Set up train process")
        logger.info("BackgroundGenerator is turned off when Distributed running is on")

    # download MNIST dataset before making dataloader
    # TODO: This is example code. You should change this part as you need
    _ = torchvision.datasets.MNIST(
        root=hydra.utils.to_absolute_path("dataset/meta"),
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    _ = torchvision.datasets.MNIST(
        root=hydra.utils.to_absolute_path("dataset/meta"),
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    # sync dist processes (because of downloading the MNIST dataset)
    if cfg.dist.gpus != 0:
        dist.barrier()

    # make dataloader
    if is_logging_process():
        logger.info("Making train dataloader...")
    train_loader = create_dataloader(cfg, DataloaderMode.train, rank)
    if is_logging_process():
        logger.info("Making test dataloader...")
    test_loader = create_dataloader(cfg, DataloaderMode.test, rank)

    # init Model
    net_arch = Net_arch(cfg)
    loss_f = torch.nn.CrossEntropyLoss()
    model = Model(cfg, net_arch, loss_f, rank)

    # load training state / network checkpoint
    if cfg.load.resume_state_path is not None:
        model.load_training_state()
    elif cfg.load.network_chkpt_path is not None:
        model.load_network()
    else:
        if is_logging_process():
            logger.info("Starting new training run.")

    try:
        if cfg.dist.gpus == 0 or cfg.data.divide_dataset_per_gpu:
            epoch_step = 1
        else:
            epoch_step = cfg.dist.gpus
        for model.epoch in itertools.count(model.epoch + 1, epoch_step):
            if model.epoch > cfg.num_epoch:
                break
            train_model(cfg, model, train_loader, writer)
            if model.epoch % cfg.log.chkpt_interval == 0:
                model.save_network()
                model.save_training_state()
            test_model(cfg, model, test_loader, writer)
        if is_logging_process():
            logger.info("End of Train")
    except Exception:
        if is_logging_process():
            logger.error(traceback.format_exc())
        else:
            traceback.print_exc()
    finally:
        if cfg.dist.gpus != 0:
            cleanup()
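# Minimal launcher sketch (not necessarily the repo's actual entry point):
# train_loop(rank, cfg) matches the fn(i, *args) calling convention of
# torch.multiprocessing.spawn, so a multi-GPU run can be started like this.
import torch.multiprocessing as mp

def distributed_run(fn, cfg):
    if cfg.dist.gpus == 0:
        fn(rank=0, cfg=cfg)  # single-process (CPU or single-GPU) run
    else:
        mp.spawn(fn, args=(cfg,), nprocs=cfg.dist.gpus, join=True)

# usage (assumed): distributed_run(train_loop, cfg)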