def setup(self, load=None, load_state_dict=None, save_every=None, train_dir=None,
          report_every=50):
    expect(not (load is not None and load_state_dict is not None),
           "`load` and `load_state_dict` cannot be passed simultaneously.")
    if load is not None:
        self.load(load)
    else:
        assert self.model is not None
        if load_state_dict is not None:
            self._load_state_dict(load_state_dict)
        self.logger.info("param size = %f M",
                         utils.count_parameters(self.model) / 1.e6)
        self._parallelize()

    self.save_every = save_every
    self.train_dir = train_dir
    self.report_every = report_every
    expect(self.save_every is None or self.train_dir is not None,
           "when `save_every` is not None, make sure `train_dir` is not None")
    self._is_setup = True
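# Usage sketch for the `setup` contract above (the `trainer` instance and paths are
# hypothetical; `load` restores a full checkpoint directory while `load_state_dict`
# restores model weights only, the two are mutually exclusive, and `save_every`
# requires `train_dir`):
#
#   trainer.setup(load="./results/final_train",
#                 save_every=20,
#                 train_dir="./results/final_train_cont",
#                 report_every=50)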
def load(self, path):
    # load the model
    m_path = os.path.join(path, "model.pt") if os.path.isdir(path) else path
    if not os.path.exists(m_path):
        m_path = os.path.join(path, "model_state.pt")
        self._load_state_dict(m_path)
    else:
        self.model = torch.load(m_path, map_location=torch.device("cpu"))
    self.model.to(self.device)
    self._parallelize()
    log_strs = ["model from {}".format(m_path)]

    # init the optimizer/scheduler
    self.optimizer = self._init_optimizer()
    self.scheduler = self._init_scheduler(self.optimizer, self.optimizer_scheduler_cfg)

    # load the optimizer state, if a saved one exists
    o_path = os.path.join(path, "optimizer.pt") if os.path.isdir(path) else None
    if o_path and os.path.exists(o_path):
        checkpoint = torch.load(o_path, map_location=torch.device("cpu"))
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        log_strs.append("optimizer from {}".format(o_path))
        self.last_epoch = checkpoint["epoch"]

    # load the scheduler state, if a saved one exists
    if self.scheduler is not None:
        s_path = os.path.join(path, "scheduler.pt") if os.path.isdir(path) else None
        if s_path and os.path.exists(s_path):
            self.scheduler.load_state_dict(
                torch.load(s_path, map_location=torch.device("cpu")))
            log_strs.append("scheduler from {}".format(s_path))

    self.logger.info("param size = %f M",
                     utils.count_parameters(self.model) / 1.e6)
    self.logger.info("Loaded checkpoint from %s: %s", path, ", ".join(log_strs))
    self.logger.info("Last epoch: %d", self.last_epoch)
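# The reads above imply a checkpoint directory layout of `model.pt` (a full nn.Module),
# `optimizer.pt` (a dict with "optimizer" and "epoch" entries), and `scheduler.pt`
# (a scheduler state dict). A hedged sketch of the matching writer; the trainer's real
# `save` method is not shown here, so the function name below is illustrative:
import os
import torch

def save_checkpoint_sketch(path, model, optimizer, scheduler, epoch):
    # whole module, read back later via torch.load(..., map_location="cpu")
    torch.save(model, os.path.join(path, "model.pt"))
    # optimizer state plus the epoch counter consumed by `load`
    torch.save({"optimizer": optimizer.state_dict(), "epoch": epoch},
               os.path.join(path, "optimizer.pt"))
    if scheduler is not None:
        # scheduler state dict, only written when a scheduler is configured
        torch.save(scheduler.state_dict(), os.path.join(path, "scheduler.pt"))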
def load(self, path):
    del self.parallel_model

    # load the model
    m_path = os.path.join(path, "model.pt") if os.path.isdir(path) else path
    # load using cpu
    self.model = torch.load(m_path, map_location=torch.device("cpu"))
    # to device
    self.model.to(self.device)
    # maybe parallelize
    self._parallelize()
    log_strs = ["model from {}".format(m_path)]

    # init/load the optimizer
    o_path = os.path.join(path, "optimizer.pt") if os.path.isdir(path) else None
    if o_path and os.path.exists(o_path):
        # init according to the type of the saved optimizer, and then load
        checkpoint = torch.load(o_path, map_location=torch.device("cpu"))
        optimizer_state = checkpoint["optimizer"]
        if "t0" in optimizer_state["param_groups"][0]:
            # ASGD
            self.optimizer = torch.optim.ASGD(self.model.parameters(), lr=self.learning_rate,
                                              t0=0, lambd=0., weight_decay=self.weight_decay)
        else:
            # SGD
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate,
                                             weight_decay=self.weight_decay)
        self.optimizer.load_state_dict(optimizer_state)
        self.last_epoch = self.epoch = checkpoint["epoch"]
        log_strs.append("optimizer from {}".format(o_path))
    else:
        # just init an SGD optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate,
                                         weight_decay=self.weight_decay)

    # init/load the scheduler
    self.scheduler = self._init_scheduler(self.optimizer, self.optimizer_scheduler_cfg)
    if self.optimizer_scheduler_cfg is not None:
        s_path = os.path.join(path, "scheduler.pt") if os.path.isdir(path) else None
        if s_path and os.path.exists(s_path):
            self.scheduler.load_state_dict(
                torch.load(s_path, map_location=torch.device("cpu")))
            log_strs.append("scheduler from {}".format(s_path))

    self.logger.info("param size = %f M",
                     utils.count_parameters(self.model) / 1.e6)
    self.logger.info("Loaded checkpoint from %s: %s", path, ", ".join(log_strs))
    self.logger.info("Last epoch: %d", self.last_epoch)
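# Why the "t0" key identifies ASGD in the branch above: torch.optim.ASGD records its `t0`
# hyper-parameter in every param group, while torch.optim.SGD does not. A standalone check,
# not part of the trainer:
import torch

_params = [torch.nn.Parameter(torch.zeros(1))]
assert "t0" in torch.optim.ASGD(_params, lr=0.1).state_dict()["param_groups"][0]
assert "t0" not in torch.optim.SGD(_params, lr=0.1).state_dict()["param_groups"][0]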
def setup(self, load=None, load_state_dict=None, save_every=None, train_dir=None,
          report_every=50):
    assert load_state_dict is None, "Currently not supported and tested."
    if load is not None:
        self.load(load)
    else:
        self.logger.info("param size = %f M",
                         utils.count_parameters(self.model) / 1.e6)
        self._parallelize()

    self.save_every = save_every
    self.train_dir = train_dir
    self.report_every = report_every
    expect(self.save_every is None or self.train_dir is not None,
           "when `save_every` is not None, make sure `train_dir` is not None")
    self._is_setup = True
    gpus=gpu_list,
    objective=objective,
    )

# check trainer support for data type
expect(_data_type in trainer.supported_data_types())

# start training
LOGGER.info("Start training.")
# trainer.setup(load, load_state_dict, save_every, train_dir)
if cfg["dataset_type"] == "cifar10":
    dummy_input = torch.rand([2, 3, 32, 32]).to(device)
elif cfg["dataset_type"] == "imagenet":
    dummy_input = torch.rand([2, 3, 224, 224]).to(device)
else:
    raise AssertionError("Dataset not supported")
output = trainer.model.forward(dummy_input)
dot = make_dot(output, params=dict(trainer.model.named_parameters()))
dot.format = "pdf"
dot.render("./test-torchviz")

flops = trainer.model.total_flops / 1.0e6
bi_flops = trainer.model.bi_flops / 1.0e6
model_params = utils.count_parameters(trainer.model, count_binary=True) / 1.0e6
print("param size = {} M | bi-param {} M".format(model_params[0], model_params[1]))
print("flops = {} M | bi-flops {} M".format(flops, bi_flops))
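# The graph dump in the script above uses `make_dot` from the torchviz package; a minimal
# self-contained version of the same call (the Linear model here is only illustrative, and
# rendering requires the graphviz binary to be installed):
import torch
from torchviz import make_dot

_model = torch.nn.Linear(8, 2)
_out = _model(torch.rand(2, 8))
_dot = make_dot(_out, params=dict(_model.named_parameters()))
_dot.format = "pdf"
_dot.render("./linear-graph")  # writes ./linear-graph.pdf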