Example #1
    def setup(self,
              load=None,
              load_state_dict=None,
              save_every=None,
              train_dir=None,
              report_every=50):
        expect(
            not (load is not None and load_state_dict is not None),
            "`load` and `load_state_dict` cannot be passed simultaneously.")
        if load is not None:
            self.load(load)
        else:
            assert self.model is not None
            if load_state_dict is not None:
                self._load_state_dict(load_state_dict)

            self.logger.info("param size = %f M",
                             utils.count_parameters(self.model) / 1.e6)
            self._parallelize()

        self.save_every = save_every
        self.train_dir = train_dir
        self.report_every = report_every

        expect(
            self.save_every is None or self.train_dir is not None,
            "when `save_every` is not None, make sure `train_dir` is not None")

        self._is_setup = True
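
The guards in this `setup` encode two usage constraints: `load` and `load_state_dict` are mutually exclusive, and `save_every` only makes sense together with `train_dir`. A minimal, self-contained sketch of the same checks, with hypothetical names (`check_setup_args` is not part of the original code):

def check_setup_args(load=None, load_state_dict=None,
                     save_every=None, train_dir=None):
    # `load` restores a full checkpoint, `load_state_dict` only weights;
    # passing both is ambiguous, so reject it
    if load is not None and load_state_dict is not None:
        raise ValueError(
            "`load` and `load_state_dict` cannot be passed simultaneously.")
    # periodic saving needs somewhere to save to
    if save_every is not None and train_dir is None:
        raise ValueError(
            "when `save_every` is not None, make sure `train_dir` is not None")

# ok: restore a checkpoint and save every 10 epochs into a train dir
check_setup_args(load="./ckpt", save_every=10, train_dir="./train")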
Example #2
    def load(self, path):
        # load the model
        m_path = os.path.join(path, "model.pt") if os.path.isdir(path) else path
        if not os.path.exists(m_path):
            m_path = os.path.join(path, "model_state.pt")
            self._load_state_dict(m_path)
        else:
            self.model = torch.load(m_path, map_location=torch.device("cpu"))
        self.model.to(self.device)
        self._parallelize()
        log_strs = ["model from {}".format(m_path)]

        # init the optimizer/scheduler
        self.optimizer = self._init_optimizer()
        self.scheduler = self._init_scheduler(self.optimizer, self.optimizer_scheduler_cfg)

        o_path = os.path.join(path, "optimizer.pt") if os.path.isdir(path) else None
        if o_path and os.path.exists(o_path):
            checkpoint = torch.load(o_path, map_location=torch.device("cpu"))
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            log_strs.append("optimizer from {}".format(o_path))
            self.last_epoch = checkpoint["epoch"]

        # load the scheduler state
        if self.scheduler is not None:
            s_path = os.path.join(path, "scheduler.pt") if os.path.isdir(path) else None
            if s_path and os.path.exists(s_path):
                self.scheduler.load_state_dict(torch.load(s_path, map_location=torch.device("cpu")))
                log_strs.append("scheduler from {}".format(s_path))

        self.logger.info("param size = %f M",
                         utils.count_parameters(self.model)/1.e6)
        self.logger.info("Loaded checkpoint from %s: %s", path, ", ".join(log_strs))
        self.logger.info("Last epoch: %d", self.last_epoch)
Example #3
    def load(self, path):
        del self.parallel_model
        # load the model
        m_path = os.path.join(path, "model.pt") if os.path.isdir(path) else path
        # load using cpu
        self.model = torch.load(m_path, map_location=torch.device("cpu"))
        # to device
        self.model.to(self.device)
        # maybe parallelize
        self._parallelize()
        log_strs = ["model from {}".format(m_path)]

        # init/load the optimizer
        o_path = os.path.join(path, "optimizer.pt") if os.path.isdir(path) else None
        if o_path and os.path.exists(o_path):
            # init according to the type of the saved optimizer, and then load
            checkpoint = torch.load(o_path, map_location=torch.device("cpu"))
            optimizer_state = checkpoint["optimizer"]
            if "t0" in optimizer_state["param_groups"][0]:
                # ASGD
                self.optimizer = torch.optim.ASGD(self.model.parameters(), lr=self.learning_rate,
                                                  t0=0, lambd=0., weight_decay=self.weight_decay)
            else:
                # SGD
                self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate,
                                                 weight_decay=self.weight_decay)
            self.optimizer.load_state_dict(optimizer_state)
            self.last_epoch = self.epoch = checkpoint["epoch"]
            log_strs.append("optimizer from {}".format(o_path))
        else:
            # just init an SGD optimizer
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate,
                                             weight_decay=self.weight_decay)

        # init/load the scheduler
        self.scheduler = self._init_scheduler(self.optimizer, self.optimizer_scheduler_cfg)
        if self.optimizer_scheduler_cfg is not None:
            s_path = os.path.join(path, "scheduler.pt") if os.path.isdir(path) else None
            if s_path and os.path.exists(s_path):
                self.scheduler.load_state_dict(torch.load(s_path, map_location=torch.device("cpu")))
                log_strs.append("scheduler from {}".format(s_path))

        self.logger.info("param size = %f M",
                         utils.count_parameters(self.model)/1.e6)
        self.logger.info("Loaded checkpoint from %s: %s", path, ", ".join(log_strs))
        self.logger.info("Last epoch: %d", self.last_epoch)
Example #4
    def setup(self, load=None, load_state_dict=None,
              save_every=None, train_dir=None, report_every=50):
        assert load_state_dict is None, "Currently not supported and tested."
        if load is not None:
            self.load(load)
        else:
            self.logger.info("param size = %f M",
                             utils.count_parameters(self.model)/1.e6)
            self._parallelize()

        self.save_every = save_every
        self.train_dir = train_dir
        self.report_every = report_every

        expect(self.save_every is None or self.train_dir is not None,
               "when `save_every` is not None, make sure `train_dir` is not None")

        self._is_setup = True
Example #5
    gpus=gpu_list,
    objective=objective,
)
# check that the trainer supports the data type
expect(_data_type in trainer.supported_data_types())

# start training
LOGGER.info("Start training.")
# trainer.setup(load, load_state_dict, save_every, train_dir)

if cfg["dataset_type"] == "cifar10":
    dummy_input = torch.rand([2, 3, 32, 32]).to(device)
elif cfg["dataset_type"] == "imagenet":
    dummy_input = torch.rand([2, 3, 224, 224]).to(device)
else:
    raise AssertionError("Dataset not supported")

output = trainer.model.forward(dummy_input)
dot = make_dot(output, params=dict(trainer.model.named_parameters()))

dot.format = "pdf"
dot.render("./test-torchviz")

flops = trainer.model.total_flops / 1.0e6
bi_flops = trainer.model.bi_flops / 1.0e6

model_params = utils.count_parameters(trainer.model, count_binary=True) / 1.0e6
print("param size = {} M | bi-param {} M".format(model_params[0],
                                                 model_params[1]))
print("flops = {} M | bi-flops {} M".format(flops, bi_flops))