# Code example #1
    def __init__(self, *args, **kwargs):
        """Set up device placement, optional loss + optimizer, and LR-decay params.

        Reads from self.config: "test_mode", "no_restore_keys", "loss",
        "num_steps", "decay_start" (plus whatever get_learning_rate uses).
        """
        super().__init__(*args, **kwargs)
        # Prefer the GPU whenever one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.test_mode = self.config.get("test_mode", False)
        if self.test_mode:
            self.model.eval()
        self.submodules = ["model"]
        # Comma-separated list of checkpoint keys to skip when restoring.
        self.no_restore_keys = retrieve(
            self.config, 'no_restore_keys', default='').split(',')

        self.learning_rate = get_learning_rate(self.config)
        self.logger.info("learning_rate: {}".format(self.learning_rate))
        if "loss" in self.config:
            self.loss = get_obj_from_str(self.config["loss"])(self.config)
            self.loss.to(self.device)
            self.submodules.append("loss")
            # A single optimizer covers both model and loss parameters.
            trainable = itertools.chain(self.model.parameters(),
                                        self.loss.parameters())
            self.optimizer = torch.optim.Adam(trainable,
                                              lr=self.learning_rate,
                                              betas=(0.5, 0.9))
            self.submodules.append("optimizer")

        self.num_steps = retrieve(self.config, "num_steps", default=0)
        # By default, decay only starts after all steps — i.e. never.
        self.decay_start = retrieve(self.config, "decay_start",
                                    default=self.num_steps)
# Code example #2
 def __init__(self, *args, **kwargs):
     """Initialize like the base class, but give the loss parameters their
     own learning rate (base lr scaled by config "loss_lr_factor")."""
     super().__init__(*args, **kwargs)
     base_lr = get_learning_rate(self.config)
     self.learning_rate = base_lr
     self.loss_lr_factor = retrieve(self.config, "loss_lr_factor", default=1.0)
     self.loss_lr = self.loss_lr_factor * base_lr
     self.logger.info("learning_rate: {}".format(self.learning_rate))
     self.logger.info("loss learning_rate: {}".format(self.loss_lr))
     if "loss" in self.config:
         self.loss = get_obj_from_str(self.config["loss"])(self.config)
         self.loss.to(self.device)
         self.submodules.append("loss")
         # Two parameter groups: the model group inherits the default lr,
         # the loss group overrides it with loss_lr.
         param_groups = [
             {"params": self.model.parameters()},
             {"params": self.loss.parameters(), "lr": self.loss_lr},
         ]
         self.optimizer = torch.optim.Adam(
             param_groups, lr=self.learning_rate, betas=(0.5, 0.9))
         self.submodules.append("optimizer")
# Code example #3
    def __init__(self, *args, **kwargs):
        """Set up device placement, an optional pretrained-model override,
        eval mode, and the list of checkpoint keys to skip on restore.

        Reads from self.config: "test_mode", "pretrained_key",
        "pretrained_model", "no_restore_keys".
        """
        super().__init__(*args, **kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.test_mode = self.config.get("test_mode", False)
        self.submodules = ["model"]

        if "pretrained_key" in self.config:
            # Replace whatever model the base class built with a pretrained one.
            # NOTE(review): assumes "pretrained_model" is always set alongside
            # "pretrained_key" — otherwise this raises KeyError. Confirm with callers.
            self.model = get_obj_from_str(self.config["pretrained_model"]).from_pretrained(
                self.config["pretrained_key"])
            self.logger.info("Loaded pretrained model from key {}".format(self.config["pretrained_key"]))
            self.logger.info("Warning: This overrides any model specified as 'model' in the edflow config.")
            self.model.eval()

        # Fix: reuse the flag cached above instead of re-reading the config
        # a second time (the original computed self.test_mode and ignored it).
        if self.test_mode:
            # in eval mode
            self.model.eval()
        self.model.to(self.device)
        # Comma-separated list of checkpoint keys to skip when restoring.
        self.do_not_restore_keys = retrieve(self.config, 'no_restore_keys', default='').split(',')
# Code example #4
def _get_state(config):
    """Instantiate and return the dataset class named in config["dataset"]."""
    dataset_cls = get_obj_from_str(config["dataset"])
    return dataset_cls(config)