Example #1
    def load_network(self, loaded_net=None):
        add_log = False
        if loaded_net is None:
            add_log = True
            if self.cfg.load.wandb_load_path is not None:
                self.cfg.load.network_chkpt_path = wandb.restore(
                    self.cfg.load.network_chkpt_path,
                    run_path=self.cfg.load.wandb_load_path,
                ).name
            loaded_net = torch.load(
                self.cfg.load.network_chkpt_path,
                map_location=torch.device(self.device),
            )
        loaded_clean_net = OrderedDict()  # strip the 'module.' prefix added by DDP
        for k, v in loaded_net.items():
            if k.startswith("module."):
                loaded_clean_net[k[7:]] = v
            else:
                loaded_clean_net[k] = v

        self.net.load_state_dict(loaded_clean_net,
                                 strict=self.cfg.load.strict_load)
        if is_logging_process() and add_log:
            self._logger.info("Checkpoint %s is loaded" %
                              self.cfg.load.network_chkpt_path)
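A minimal sketch of the key cleanup above, using a toy nn.Linear in place of the template's network: checkpoints saved from a DistributedDataParallel-wrapped model prefix every key with "module.", and load_network strips that prefix so the unwrapped module can load the weights.

from collections import OrderedDict

import torch

net = torch.nn.Linear(4, 2)
# Simulate a checkpoint saved from a DDP-wrapped model: every key gains a
# "module." prefix (e.g. "module.weight" instead of "weight").
ddp_style = OrderedDict(("module." + k, v) for k, v in net.state_dict().items())

# Same cleanup as load_network: drop the 7-character "module." prefix.
cleaned = OrderedDict(
    (k[7:] if k.startswith("module.") else k, v) for k, v in ddp_style.items()
)
net.load_state_dict(cleaned, strict=True)  # loads with no key mismatch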
Example #2
    def save_network(self, save_file=True):
        if is_logging_process():
            net = self.net.module if isinstance(self.net, DDP) else self.net
            state_dict = net.state_dict()
            # move parameters to CPU so the checkpoint loads on any machine
            for key, param in state_dict.items():
                state_dict[key] = param.to("cpu")
            if save_file:
                save_filename = "%s_%d.pt" % (self.cfg.name, self.step)
                save_path = osp.join(self.cfg.log.chkpt_dir, save_filename)
                torch.save(state_dict, save_path)
                if self.cfg.log.use_wandb:
                    wandb.save(save_path)
                self._logger.info("Saved network checkpoint to: %s" % save_path)
            return state_dict
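A minimal sketch (the toy layer and filename are placeholders) of the CPU transfer save_network performs before torch.save: writing CPU tensors keeps the checkpoint loadable on machines without the GPU it was trained on.

import torch

net = torch.nn.Linear(4, 2)
# Copy every parameter to CPU before serializing, as save_network does.
state_dict = {k: v.to("cpu") for k, v in net.state_dict().items()}
torch.save(state_dict, "toy_checkpoint.pt")

# The file now loads cleanly even on a CPU-only machine.
restored = torch.load("toy_checkpoint.pt", map_location=torch.device("cpu"))
net.load_state_dict(restored)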
Example #3
    def save_training_state(self):
        if is_logging_process():
            save_filename = "%s_%d.state" % (self.cfg.name, self.step)
            save_path = osp.join(self.cfg.log.chkpt_dir, save_filename)
            net_state_dict = self.save_network(False)
            state = {
                "model": net_state_dict,
                "optimizer": self.optimizer.state_dict(),
                "step": self.step,
                "epoch": self.epoch,
            }
            torch.save(state, save_path)
            if self.cfg.log.use_wandb:
                wandb.save(save_path)
            self._logger.info("Saved training state to: %s" % save_path)
Example #4
def train_model(cfg, model, train_loader, writer):
    logger = get_logger(cfg, os.path.basename(__file__))
    model.net.train()
    for model_input, model_target in train_loader:
        model.optimize_parameters(model_input, model_target)
        loss = model.log.loss_v
        model.step += 1

        if is_logging_process() and (loss > 1e8 or math.isnan(loss)):
            logger.error("Loss exploded to %.02f at step %d!" %
                         (loss, model.step))
            raise Exception("Loss exploded")

        if model.step % cfg.log.summary_interval == 0:
            if writer is not None:
                writer.logging_with_step(loss, model.step, "train_loss")
            if is_logging_process():
                logger.info("Train Loss %.04f at step %d" % (loss, model.step))
Example #5
def test_model(cfg, model, test_loader, writer):
    logger = get_logger(cfg, os.path.basename(__file__))
    model.net.eval()
    total_test_loss = 0
    test_loop_len = 0
    total_test_accuracy = 0
    total_test = 0
    with torch.no_grad():
        for model_input, target in test_loader:
            model.feed_data(input=model_input, GT=target)
            output = model.run_network()
            loss_v = model.loss_f(output, model.GT)
            _, predicted = torch.max(output, 1)
            # build scalar tensors so they can be all-reduced across GPUs below
            total_v = torch.tensor(float(target.size(0))).to(cfg.device)
            accuracy_v = torch.tensor(
                float((predicted == target.to(cfg.device)).sum().item())
            ).to(cfg.device)

            if cfg.dist.gpus > 0:
                # Aggregate loss_v from all GPUs; all_reduce sums it, so divide
                # by the number of GPUs to get the mean.
                torch.distributed.all_reduce(loss_v)
                loss_v /= torch.tensor(float(cfg.dist.gpus))

                # Aggregate the correct-prediction and sample counts from all
                # GPUs; each becomes the sum over GPUs.
                torch.distributed.all_reduce(accuracy_v)
                torch.distributed.all_reduce(total_v)

            total_test += total_v.to("cpu").item()
            total_test_loss += loss_v.to("cpu").item()
            total_test_accuracy += accuracy_v.to("cpu").item()

            test_loop_len += 1

        total_test_loss /= test_loop_len
        total_test_accuracy /= total_test

        if writer is not None:
            writer.logging_with_step(total_test_accuracy, model.step,
                                     "test_accuracy")
            writer.logging_with_step(total_test_loss, model.step, "test_loss")
            writer.logging_with_epoch(total_test_accuracy, model.step,
                                      model.epoch,
                                      "total_test_accuracy_per_epoch")
            writer.logging_with_epoch(total_test_loss, model.step, model.epoch,
                                      "test_loss_per_epoch")
        if is_logging_process():
            logger.info("Test Loss %.04f at step %d" %
                        (total_test_loss, model.step))
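A small numeric sketch (made-up counts) of why the loop all-reduces the raw correct and total counts instead of averaging per-GPU accuracies: with unevenly sized shards the two results differ.

correct = [45, 5]   # correct predictions on GPU 0 / GPU 1 (made-up counts)
totals = [50, 10]   # samples each GPU evaluated

overall = sum(correct) / sum(totals)                      # ~0.833, what the code computes
naive = sum(c / t for c, t in zip(correct, totals)) / 2   # 0.70, skewed by shard size
print(overall, naive)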
Example #6
    def load_training_state(self):
        if self.cfg.load.wandb_load_path is not None:
            self.cfg.load.resume_state_path = wandb.restore(
                self.cfg.load.resume_state_path,
                run_path=self.cfg.load.wandb_load_path,
            ).name
        resume_state = torch.load(
            self.cfg.load.resume_state_path,
            map_location=torch.device(self.device),
        )

        self.load_network(loaded_net=resume_state["model"])
        self.optimizer.load_state_dict(resume_state["optimizer"])
        self.step = resume_state["step"]
        self.epoch = resume_state["epoch"]
        if is_logging_process():
            self._logger.info("Resuming from training state: %s" %
                              self.cfg.load.resume_state_path)
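The load/save helpers above read only a handful of config fields; here is a minimal OmegaConf sketch of that cfg.load group, with placeholder values rather than the project's defaults.

from omegaconf import OmegaConf

cfg_load = OmegaConf.create(
    {
        "load": {
            "wandb_load_path": None,     # wandb run path to restore files from
            "network_chkpt_path": None,  # .pt file: load weights only (load_network)
            "resume_state_path": None,   # .state file: resume training (load_training_state)
            "strict_load": True,         # forwarded to load_state_dict
        }
    }
)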
Example #7
def test_model(cfg, model, test_loader, writer):
    logger = get_logger(cfg, os.path.basename(__file__))
    model.net.eval()
    total_test_loss = 0
    test_loop_len = 0
    with torch.no_grad():
        for model_input, model_target in test_loader:
            output = model.inference(model_input)
            loss_v = model.loss_f(output, model_target.to(cfg.device))
            if cfg.dist.gpus > 0:
                # Aggregate loss_v from all GPUs. loss_v is set as the sum of all GPUs' loss_v.
                torch.distributed.all_reduce(loss_v)
                loss_v /= torch.tensor(float(cfg.dist.gpus))
            total_test_loss += loss_v.to("cpu").item()
            test_loop_len += 1

        total_test_loss /= test_loop_len

        if writer is not None:
            writer.logging_with_step(total_test_loss, model.step, "test_loss")
        if is_logging_process():
            logger.info("Test Loss %.04f at step %d" %
                        (total_test_loss, model.step))
Example #8
def train_loop(rank, cfg):
    logger = get_logger(cfg, os.path.basename(__file__))
    if cfg.device == "cuda" and cfg.dist.gpus != 0:
        cfg.device = rank
        # turn off background generator when distributed run is on
        cfg.data.use_background_generator = False
        setup(cfg, rank)
        torch.cuda.set_device(cfg.device)

    # setup writer
    writer = None  # non-logging processes pass writer=None to train/test loops
    if is_logging_process():
        # set log/checkpoint dir
        os.makedirs(cfg.log.chkpt_dir, exist_ok=True)
        # set writer (tensorboard / wandb)
        writer = Writer(cfg, "tensorboard")
        cfg_str = OmegaConf.to_yaml(cfg)
        logger.info("Config:\n" + cfg_str)
        if cfg.data.train_dir == "" or cfg.data.test_dir == "":
            logger.error("train or test data directory cannot be empty.")
            raise Exception("Please specify directories of data")
        logger.info("Set up train process")
        logger.info("BackgroundGenerator is turned off when Distributed running is on")

        # download MNIST dataset before making dataloader
        # TODO: This is example code. You should change this part as you need
        _ = torchvision.datasets.MNIST(
            root=hydra.utils.to_absolute_path("dataset/meta"),
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True,
        )
        _ = torchvision.datasets.MNIST(
            root=hydra.utils.to_absolute_path("dataset/meta"),
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True,
        )
    # Sync dist processes (other ranks wait for the MNIST download on the logging process)
    if cfg.dist.gpus != 0:
        dist.barrier()

    # make dataloader
    if is_logging_process():
        logger.info("Making train dataloader...")
    train_loader = create_dataloader(cfg, DataloaderMode.train, rank)
    if is_logging_process():
        logger.info("Making test dataloader...")
    test_loader = create_dataloader(cfg, DataloaderMode.test, rank)

    # init Model
    net_arch = Net_arch(cfg)
    loss_f = torch.nn.CrossEntropyLoss()
    model = Model(cfg, net_arch, loss_f, rank)

    # load training state / network checkpoint
    if cfg.load.resume_state_path is not None:
        model.load_training_state()
    elif cfg.load.network_chkpt_path is not None:
        model.load_network()
    else:
        if is_logging_process():
            logger.info("Starting new training run.")

    try:
        if cfg.dist.gpus == 0 or cfg.data.divide_dataset_per_gpu:
            epoch_step = 1
        else:
            epoch_step = cfg.dist.gpus
        for model.epoch in itertools.count(model.epoch + 1, epoch_step):
            if model.epoch > cfg.num_epoch:
                break
            train_model(cfg, model, train_loader, writer)
            if model.epoch % cfg.log.chkpt_interval == 0:
                model.save_network()
                model.save_training_state()
            test_model(cfg, model, test_loader, writer)
        if is_logging_process():
            logger.info("End of Train")
    except Exception:
        if is_logging_process():
            logger.error(traceback.format_exc())
        else:
            traceback.print_exc()
    finally:
        if cfg.dist.gpus != 0:
            cleanup()
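These examples do not include the entry point that calls train_loop; since it takes the process rank as its first argument and calls setup/cleanup for distributed runs, a plausible launcher (a sketch, not the template's actual main) would use torch.multiprocessing.spawn for multi-GPU runs and call the function directly otherwise.

import torch.multiprocessing as mp

def launch(cfg):
    if cfg.dist.gpus > 0:
        # spawn passes the process index (the rank) as the first argument.
        mp.spawn(train_loop, args=(cfg,), nprocs=cfg.dist.gpus, join=True)
    else:
        train_loop(0, cfg)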