def load(self, f=None):
    if self.has_checkpoint():
        # override argument with existing checkpoint
        f = self.get_checkpoint_file()
    if not f:
        # no checkpoint could be found
        self.logger.info("No checkpoint found. Initializing model from scratch")
        return {}
    self.logger.info("Loading checkpoint from {}".format(f))
    checkpoint = self._load_file(f)
    # self._load_model(checkpoint)
    if "model" in checkpoint:
        self.logger.info("Loading model from {}".format(f))
        load_state_dict(self.model, checkpoint.pop("model"))
    if "optimizer" in checkpoint and self.optimizer:
        self.logger.info("Loading optimizer from {}".format(f))
        self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
    if "scheduler" in checkpoint and self.scheduler:
        self.logger.info("Loading scheduler from {}".format(f))
        self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
    if "epoch" in checkpoint:
        self.logger.info("Checkpoint epoch: {}".format(checkpoint["epoch"]))
    # return any further checkpoint data (e.g. iteration/epoch counters)
    return checkpoint
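# Usage sketch for load() above. The Checkpointer constructor signature,
# train_one_epoch, and max_epochs are hypothetical names for illustration;
# only load()'s contract (returning leftover checkpoint data such as
# "epoch", or {} when starting from scratch) comes from the code above.
checkpointer = Checkpointer(model, optimizer, scheduler, save_dir="checkpoints")
extra = checkpointer.load()              # resumes from the latest checkpoint if one exists
start_epoch = extra.get("epoch", 0)      # {} is returned when no checkpoint was found
for epoch in range(start_epoch, max_epochs):
    train_one_epoch(model, optimizer)
    scheduler.step()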
def load_models(
    G: "nn.Module" = None,
    g_optimizer: "optim" = None,
    args: "tupperware" = None,
    tag: str = "latest",
    is_local_rank_0: bool = True,
) -> "Tuple[nn.Module, optim, int, int, float]":
    latest_path = args.ckpt_dir / args.save_filename_latest_G
    best_path = args.ckpt_dir / args.save_filename_G

    # fall back to the other checkpoint when the requested one is missing
    if tag == "latest":
        path = latest_path
        if not path.exists():
            path = best_path
            tag = "best"
    elif tag == "best":
        path = best_path
        if not path.exists():
            path = latest_path
            tag = "latest"

    # defaults when training from scratch
    start_epoch = 0
    global_step = 0
    loss = 1e6

    if args.resume:
        if path.is_file():
            checkpoint = torch.load(path, map_location=torch.device("cpu"))

            if is_local_rank_0:
                logging.info(f"Loading checkpoint from {path} with tag {tag}.")

            load_state_dict(G, checkpoint["state_dict"])
            # G.load_state_dict(checkpoint["state_dict"])

            if not args.finetune:
                if g_optimizer and "optimizer" in checkpoint:
                    g_optimizer.load_state_dict(checkpoint["optimizer"])

                if "epoch" in checkpoint:
                    start_epoch = checkpoint["epoch"] - 1

                if "global_step" in checkpoint:
                    global_step = checkpoint["global_step"]

                if "loss" in checkpoint:
                    loss = checkpoint["loss"]
                    if is_local_rank_0:
                        logging.info(f"Model has loss of {loss}")
        else:
            if is_local_rank_0:
                logging.info(f"No checkpoint found at {path} with tag {tag}.")

    return G, g_optimizer, global_step, start_epoch, loss
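# A companion save sketch inferred from the keys load_models() reads
# ("state_dict", "optimizer", "epoch", "global_step", "loss"). The function
# name and the DistributedDataParallel unwrapping are assumptions, not
# taken from the source.
import torch
import torch.nn as nn

def save_models(G, g_optimizer, epoch, global_step, loss, path):
    # unwrap DDP so the saved keys carry no "module." prefix
    state_dict = (
        G.module.state_dict()
        if isinstance(G, nn.parallel.DistributedDataParallel)
        else G.state_dict()
    )
    torch.save(
        {
            "state_dict": state_dict,
            "optimizer": g_optimizer.state_dict(),
            "epoch": epoch,  # load_models() resumes at checkpoint["epoch"] - 1
            "global_step": global_step,
            "loss": loss,
        },
        path,
    )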
def init_weights(self, pretrained=None):
    if pretrained is not None:
        if isinstance(pretrained, str) and os.path.isfile(pretrained):
            logger.info('=> loading pretrained model {}'.format(pretrained))
            pretrained_state_dict = torch.load(pretrained)
        else:
            logger.info('=> loading pretrained model from web')
            pretrained_state_dict = pretrained

        logger.info('=> init deconv weights from normal distribution')
        for name, m in self.deconv_layers.named_modules():
            if isinstance(m, nn.ConvTranspose2d):
                logger.info('=> init {}.weight as normal(0, 0.001)'.format(name))
                logger.info('=> init {}.bias as 0'.format(name))
                nn.init.normal_(m.weight, std=0.001)
                if self.deconv_with_bias:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                logger.info('=> init {}.weight as 1'.format(name))
                logger.info('=> init {}.bias as 0'.format(name))
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        logger.info('=> init final conv weights from normal distribution')
        for name, m in self.final_layer.named_modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                logger.info('=> init {}.weight as normal(0, 0.001)'.format(name))
                logger.info('=> init {}.bias as 0'.format(name))
                nn.init.normal_(m.weight, std=0.001)
                nn.init.constant_(m.bias, 0)

        # load_state_dict(self, pretrained_state_dict, prefix='resnet.')
        # load_state_dict(self, pretrained_state_dict, prefix='backbone.')
        load_state_dict(
            self, pretrained_state_dict, strict=False,
            ignored_layers=['final_layer.bias', 'final_layer.weight'],
            prefix=cfg.WEIGHTS_PREFIX,
            prefix_replace=cfg.WEIGHTS_PREFIX_REPLACE)
        # self.load_state_dict(pretrained_state_dict, strict=False)
    else:
        logger.info('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.normal_(m.weight, std=0.001)
                # nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
                if self.deconv_with_bias:
                    nn.init.constant_(m.bias, 0)
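# The loaders in this section call a project-specific load_state_dict()
# helper rather than nn.Module.load_state_dict() directly. Its real
# implementation is not shown here; the sketch below is a plausible
# reconstruction of the behavior implied by the call sites (prefix
# stripping/replacement, skipping listed layers, optional strict matching)
# and should be read as an assumption, not the project's actual helper.
def load_state_dict(model, state_dict, strict=True,
                    ignored_layers=(), prefix="", prefix_replace=""):
    remapped = {}
    for key, value in state_dict.items():
        # rewrite key prefixes, e.g. "module." -> "" or "resnet." -> "backbone."
        if prefix and key.startswith(prefix):
            key = prefix_replace + key[len(prefix):]
        # drop layers the caller wants re-initialized (e.g. a replaced final layer)
        if key in ignored_layers:
            continue
        remapped[key] = value
    # delegate the actual copy to PyTorch; strict=False tolerates
    # missing or unexpected keys
    model.load_state_dict(remapped, strict=strict)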
def _load_model(self, checkpoint, prefix="module.", prefix_replace=""):
    load_state_dict(self.model, checkpoint.pop("model"),
                    prefix=prefix, prefix_replace=prefix_replace)
def load_model_only(self, f):
    checkpoint = self._load_file(f)
    if "model" in checkpoint:
        self.logger.info("Loading model from {}".format(f))
        load_state_dict(self.model, checkpoint.pop("model"))
def _load_model(self, checkpoint, no_head):
    load_state_dict(self.model, checkpoint.pop("model"), no_head)
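# _load_model() above forwards a positional no_head flag, implying a helper
# variant that can skip the model's head when loading backbone weights. A
# minimal sketch of that behavior, assuming head parameters are
# identifiable by a "head" substring in their keys (an assumption; the
# real naming convention and helper name are not shown in the source):
def load_state_dict_skipping_head(model, state_dict, no_head):
    if no_head:
        # drop head weights so they stay freshly initialized
        state_dict = {k: v for k, v in state_dict.items() if "head" not in k}
    model.load_state_dict(state_dict, strict=False)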