Example #1
    def load(self, f=None):
        if self.has_checkpoint():
            # override argument with existing checkpoint
            f = self.get_checkpoint_file()
        if not f:
            # no checkpoint could be found
            self.logger.info(
                "No checkpoint found. Initializing model from scratch")
            return {}
        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        # self._load_model(checkpoint)
        if "model" in checkpoint:
            self.logger.info("Loading model from {}".format(f))
            load_state_dict(self.model, checkpoint.pop("model"))
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
        if "epoch" in checkpoint:
            self.logger.info("Checkpoint epoch: {}".format(checkpoint["epoch"]))

        # return any further checkpoint data
        return checkpoint
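
A hedged call-site sketch for the load() method above. Checkpointer, its constructor arguments, and the "iteration" key are illustrative assumptions, not from the source; only the empty-dict-on-fresh-start and leftover-dict behavior come from the method itself.

# Hedged usage sketch: Checkpointer and "iteration" are assumed names.
checkpointer = Checkpointer(model, optimizer, scheduler)
extra = checkpointer.load()             # {} when starting from scratch
start_iter = extra.get("iteration", 0)  # leftover keys seed the training loop
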
Example #2
def load_models(
    G: "nn.Module" = None,
    g_optimizer: "optim.Optimizer" = None,
    args: "tupperware" = None,
    tag: str = "latest",
    is_local_rank_0: bool = True,
) -> "Tuple[nn.Module, optim.Optimizer, int, int, float]":

    latest_path = args.ckpt_dir / args.save_filename_latest_G
    best_path = args.ckpt_dir / args.save_filename_G

    if tag == "latest":
        path = latest_path
        if not path.exists():
            path = best_path
            tag = "best"

    elif tag == "best":
        path = best_path
        if not path.exists():
            path = latest_path
            tag = "latest"

    else:
        raise ValueError(f"Unknown tag {tag}; expected 'latest' or 'best'.")

    # Defaults
    start_epoch = 0
    global_step = 0
    loss = 1e6

    if args.resume:
        if path.is_file():
            checkpoint = torch.load(path, map_location=torch.device("cpu"))

            if is_local_rank_0:
                logging.info(f"Loading checkpoint from {path} with tag {tag}.")
            load_state_dict(G, checkpoint["state_dict"])
            # G.load_state_dict(checkpoint["state_dict"])

            if not args.finetune:
                if g_optimizer and "optimizer" in checkpoint:
                    g_optimizer.load_state_dict(checkpoint["optimizer"])

                if "epoch" in checkpoint:
                    start_epoch = checkpoint["epoch"] - 1

                if "global_step" in checkpoint:
                    global_step = checkpoint["global_step"]

                if "loss" in checkpoint:
                    loss = checkpoint["loss"]

                    if is_local_rank_0:
                        logging.info(f"Model has loss of {loss}")

        else:
            if is_local_rank_0:
                logging.info(f"No checkpoint found  at {path} with tag {tag}.")

    return G, g_optimizer, global_step, start_epoch, loss
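
A hedged call-site sketch for load_models(). The attributes read off args (ckpt_dir, save_filename_latest_G, save_filename_G, resume, finetune) come from the function body above; generator and local_rank are illustrative names.

# Hedged usage sketch: `generator` and `local_rank` are assumed to exist.
G, g_optimizer, global_step, start_epoch, loss = load_models(
    G=generator,
    g_optimizer=g_optimizer,
    args=args,
    tag="latest",
    is_local_rank_0=(local_rank == 0),
)
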
Example #3
    def init_weights(self, pretrained=None):
        if pretrained is not None:
            if isinstance(pretrained, str) and os.path.isfile(pretrained):
                logger.info(
                    '=> loading pretrained model {}'.format(pretrained))
                pretrained_state_dict = torch.load(pretrained)
            else:
                logger.info('=> loading pretrained model from web')
                pretrained_state_dict = pretrained

            logger.info('=> init deconv weights from normal distribution')
            for name, m in self.deconv_layers.named_modules():
                if isinstance(m, nn.ConvTranspose2d):
                    logger.info(
                        '=> init {}.weight as normal(0, 0.001)'.format(name))
                    logger.info('=> init {}.bias as 0'.format(name))
                    nn.init.normal_(m.weight, std=0.001)
                    if self.deconv_with_bias:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    logger.info('=> init {}.weight as 1'.format(name))
                    logger.info('=> init {}.bias as 0'.format(name))
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
            logger.info('=> init final conv weights from normal distribution')
            for name, m in self.final_layer.named_modules():
                if isinstance(m, nn.Conv2d):
                    # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    logger.info(
                        '=> init {}.weight as normal(0, 0.001)'.format(name))
                    logger.info('=> init {}.bias as 0'.format(name))
                    nn.init.normal_(m.weight, std=0.001)
                    nn.init.constant_(m.bias, 0)
            #load_state_dict(self, pretrained_state_dict, prefix='resnet.')
            #load_state_dict(self, pretrained_state_dict, prefix='backbone.')
            load_state_dict(
                self,
                pretrained_state_dict,
                strict=False,
                ignored_layers=['final_layer.bias', 'final_layer.weight'],
                prefix=cfg.WEIGHTS_PREFIX,
                prefix_replace=cfg.WEIGHTS_PREFIX_REPLACE)
            #self.load_state_dict(pretrained_state_dict, strict=False)
        else:
            logger.info('=> init weights from normal distribution')
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    nn.init.normal_(m.weight, std=0.001)
                    # nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.ConvTranspose2d):
                    nn.init.normal_(m.weight, std=0.001)
                    if self.deconv_with_bias:
                        nn.init.constant_(m.bias, 0)
Example #4
    def _load_model(self, checkpoint, prefix="module.", prefix_replace=""):
        load_state_dict(self.model, checkpoint.pop("model"),
                        prefix=prefix, prefix_replace=prefix_replace)
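
The load_state_dict() helper called throughout these examples is project code, not the built-in torch.nn.Module.load_state_dict. As a minimal sketch of what its keyword arguments suggest it does (an illustration of the pattern, not the project's actual implementation), it might remap key prefixes and drop ignored layers before delegating to the built-in method:

# Minimal sketch, assuming the helper only remaps keys and filters layers.
def load_state_dict(model, state_dict, strict=False, ignored_layers=(),
                    prefix="", prefix_replace=""):
    remapped = {}
    for key, value in state_dict.items():
        if prefix and key.startswith(prefix):
            # e.g. prefix="module.", prefix_replace="" strips DataParallel keys
            key = prefix_replace + key[len(prefix):]
        if key in ignored_layers:
            continue  # skip layers the caller asked to ignore
        remapped[key] = value
    model.load_state_dict(remapped, strict=strict)
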
Example #5
    def load_model_only(self, f):
        checkpoint = self._load_file(f)
        if "model" in checkpoint:
            self.logger.info("Loading model from {}".format(f))
            load_state_dict(self.model, checkpoint.pop("model"))
Example #6
    def _load_model(self, checkpoint, no_head):
        load_state_dict(self.model, checkpoint.pop("model"), no_head)
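
All six loaders expect a checkpoint saved with torch.save as a plain dict. A hedged sketch of the matching save side, using the key names Examples #1, #5, and #6 read ("model", "optimizer", "scheduler", "epoch"); Example #2 instead reads "state_dict", "global_step", and "loss". The file name is illustrative.

import torch

# Hedged save-side sketch: key names mirror the loaders above.
state = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "epoch": epoch,
}
torch.save(state, "checkpoint.pth")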