Example #1
    def load_checkpoint(
        self, config, checkpoint_path, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load model checkpoint and set up internals.

        Args:
            config (Coqpit): model configuration.
            checkpoint_path (str): path to checkpoint file.
            eval (bool): whether to load model for evaluation.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        # TODO: set r in run-time by taking it from the new config
        if "r" in state:
            # set r from the state (for compatibility with older checkpoints)
            self.decoder.set_r(state["r"])
        elif "config" in state:
            # set r from config used at training time (for inference)
            self.decoder.set_r(state["config"]["r"])
        else:
            # set r from the new config (for new-models)
            self.decoder.set_r(config.r)
        if eval:
            self.eval()
            print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
            assert not self.training
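A minimal usage sketch for this variant (the model class and checkpoint path are illustrative assumptions; only `load_checkpoint` and `decoder.r` come from the example above):

model = Tacotron(config)  # hypothetical model class exposing this method
model.load_checkpoint(config, "checkpoint_50000.pth.tar", eval=True)
print(model.decoder.r)  # reduction factor resolved by the fallback chain above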
Example #2
 def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
     state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
     self.load_state_dict(state["model"])
     if eval:
         self.eval()
         assert not self.training
         self.remove_weight_norm()
Example #3
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data_train
    global meta_data_eval

    ap = AudioProcessor(**c.audio)
    model = setup_speaker_encoder_model(c)

    optimizer = RAdam(model.parameters(), lr=c.lr)

    # pylint: disable=redefined-outer-name
    meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False)

    data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)

    if c.loss == "ge2e":
        criterion = GE2ELoss(loss_method="softmax")
    elif c.loss == "angleproto":
        criterion = AngleProtoLoss()
    elif c.loss == "softmaxproto":
        criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_speakers)
    else:
        raise ValueError("`%s` is not a supported loss." % c.loss)

    if args.restore_path:
        checkpoint = load_fsspec(args.restore_path)
        try:
            model.load_state_dict(checkpoint["model"])

            if "criterion" in checkpoint:
                criterion.load_state_dict(checkpoint["criterion"])

        except (KeyError, RuntimeError):
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model.load_state_dict(model_dict)
            del model_dict
        for group in optimizer.param_groups:
            group["lr"] = c.lr

        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    if c.lr_decay:
        scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if use_cuda:
        model = model.cuda()
        criterion.cuda()

    global_step = args.restore_step
    _, global_step = train(model, optimizer, scheduler, criterion, data_loader, global_step)
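The `if/elif` chain that picks the criterion could also be a dispatch table, which keeps the supported loss names in one place. A sketch under the same assumptions (the three loss classes, `c`, and `num_speakers` from the script above):

LOSSES = {
    "ge2e": lambda: GE2ELoss(loss_method="softmax"),
    "angleproto": lambda: AngleProtoLoss(),
    "softmaxproto": lambda: SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_speakers),
}
if c.loss not in LOSSES:
    raise ValueError("`%s` is not a supported loss." % c.loss)
criterion = LOSSES[c.loss]()

The lambdas defer construction, so only the selected loss class is ever instantiated.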
Example #4
 def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
     state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
     self.load_state_dict(state["model"])
     if use_cuda:
         self.cuda()
     if eval:
         self.eval()
         assert not self.training
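Hypothetical usage of this simpler variant (the `model` instance and checkpoint path are illustrative):

model.load_checkpoint(config, "best_model.pth.tar", eval=True, use_cuda=torch.cuda.is_available())
assert not model.training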
Example #5
    def load_checkpoint(self,
                        config: Coqpit,
                        checkpoint_path: str,
                        eval: bool = False,
                        use_cuda: bool = False,
                        criterion=None):
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        try:
            self.load_state_dict(state["model"])
        except (KeyError, RuntimeError) as error:
            # in eval mode, re-raise the error
            if eval:
                raise error

            print(" > Partial model initialization.")
            model_dict = self.state_dict()
            model_dict = set_init_dict(model_dict, state["model"], config)
            self.load_state_dict(model_dict)
            del model_dict

        # load the criterion for restore_path
        if criterion is not None and "criterion" in state:
            try:
                criterion.load_state_dict(state["criterion"])
            except (KeyError, RuntimeError) as error:
                print(" > Criterion load ignored because of:", error)

        # instantiate and load the criterion for the encoder classifier at inference time
        if (eval and criterion is None and "criterion" in state and getattr(
                config, "map_classid_to_classname", None) is not None):
            criterion = self.get_criterion(
                config, len(config.map_classid_to_classname))
            criterion.load_state_dict(state["criterion"])

        if use_cuda:
            self.cuda()
            if criterion is not None:
                criterion = criterion.cuda()

        if eval:
            self.eval()
            assert not self.training

        if not eval:
            return criterion, state["step"]
        return criterion
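Note the asymmetric return values: a training restore returns `(criterion, step)`, while an eval load returns only the criterion. A usage sketch (the `model` instance and paths are illustrative):

# Resume training: recover both the criterion state and the step counter.
criterion, restore_step = model.load_checkpoint(config, "checkpoint_10000.pth.tar", eval=False, criterion=criterion)

# Inference: the classifier criterion is rebuilt from the checkpoint when
# `config.map_classid_to_classname` is set.
criterion = model.load_checkpoint(config, "best_model.pth.tar", eval=True, use_cuda=True)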
Example #6
 def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
     state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
     self.load_state_dict(state["model"])
     if eval:
         self.eval()
         assert not self.training
         if self.config.model_params.use_weight_norm:
             self.remove_weight_norm()
         betas = np.linspace(
             config["test_noise_schedule"]["min_val"],
             config["test_noise_schedule"]["max_val"],
             config["test_noise_schedule"]["num_steps"],
         )
         self.compute_noise_level(betas)
     else:
         betas = np.linspace(
             config["train_noise_schedule"]["min_val"],
             config["train_noise_schedule"]["max_val"],
             config["train_noise_schedule"]["num_steps"],
         )
         self.compute_noise_level(betas)
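The noise schedule is a plain linear ramp: `np.linspace` produces `num_steps` beta values evenly spaced between `min_val` and `max_val`. With hypothetical config values:

import numpy as np

# e.g. {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 4} (illustrative values)
betas = np.linspace(1e-6, 1e-2, 4)
print(betas)  # approx. [1.00e-06 3.33e-03 6.67e-03 1.00e-02]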
Example #7
    def load_checkpoint(
            self,
            config: Coqpit,
            checkpoint_path: str,
            eval: bool = False,  # pylint: disable=unused-argument, redefined-builtin
    ) -> None:
        """Load a GAN checkpoint and initialize model parameters.

        Args:
            config (Coqpit): Model config.
            checkpoint_path (str): Checkpoint file path.
            eval (bool, optional): If True, load the model for inference. Defaults to False.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        # band-aid for older than v0.0.15 GAN models
        if "model_disc" in state:
            self.model_g.load_checkpoint(config, checkpoint_path, eval)
        else:
            self.load_state_dict(state["model"])
            if eval:
                self.model_d = None
                if hasattr(self.model_g, "remove_weight_norm"):
                    self.model_g.remove_weight_norm()
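A usage sketch for this GAN wrapper (the wrapper class name and path are assumptions): after an eval load, the discriminator is dropped and weight norm is removed from the generator.

gan = GAN(config)  # hypothetical wrapper holding model_g / model_d
gan.load_checkpoint(config, "checkpoint_200000.pth.tar", eval=True)
assert gan.model_d is None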
Example #8
parser = argparse.ArgumentParser()
parser.add_argument(
    "--torch_model_path",
    type=str,
    help="Path to the torch model to be converted to TF.")
parser.add_argument(
    "--config_path",
    type=str,
    help="Path to config file of torch model.")
parser.add_argument(
    "--output_path",
    type=str,
    help="path to output file including file name to save TF model.")
args = parser.parse_args()

# load model config
config_path = args.config_path
c = load_config(config_path)
num_speakers = 0

# init torch model
model = setup_generator(c)
checkpoint = load_fsspec(args.torch_model_path,
                         map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model.remove_weight_norm()
state_dict = model.state_dict()

# init tf model
model_tf = setup_tf_generator(c)

common_suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
# get tf_model graph by passing an input
# B x D x T
dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32)
mel_pred = model_tf(dummy_input, training=False)

# get tf variables
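The `common_suffix` constant exists because TF2 object-based checkpoints name variables with a trailing `/.ATTRIBUTES/VARIABLE_VALUE`; a small illustration of stripping it back to a plain variable name (the key is hypothetical):

common_suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
raw_name = "model/layers/0/kernel" + common_suffix  # hypothetical checkpoint key
plain_name = raw_name[: -len(common_suffix)]
print(plain_name)  # model/layers/0/kernel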
Example #9
def get_last_checkpoint(path: str) -> Tuple[str, str]:
    """Get latest checkpoint or/and best model in path.

    It is based on globbing for `*.pth.tar` and the RegEx
    `(checkpoint|best_model)_([0-9]+)`.

    Args:
        path: Path to files to be compared.

    Raises:
        ValueError: If no checkpoint or best_model files are found.

    Returns:
        Path to the last checkpoint.
        Path to the best model.
    """
    fs = fsspec.get_mapper(path).fs
    file_names = fs.glob(os.path.join(path, "*.pth.tar"))
    scheme = urlparse(path).scheme
    if scheme:  # scheme is not preserved in fs.glob, add it back
        file_names = [scheme + "://" + file_name for file_name in file_names]
    last_models = {}
    last_model_nums = {}
    for key in ["checkpoint", "best_model"]:
        last_model_num = None
        last_model = None
        # go through all the checkpoint files and find
        # the one with the largest model number suffix.
        for file_name in file_names:
            match = re.search(f"{key}_([0-9]+)", file_name)
            if match is not None:
                model_num = int(match.groups()[0])
                if last_model_num is None or model_num > last_model_num:
                    last_model_num = model_num
                    last_model = file_name

        # if there is no checkpoint found above
        # find the checkpoint with the latest
        # modification date.
        key_file_names = [fn for fn in file_names if key in fn]
        if last_model is None and len(key_file_names) > 0:
            last_model = max(key_file_names, key=os.path.getctime)
            last_model_num = load_fsspec(last_model)["step"]

        if last_model is not None:
            last_models[key] = last_model
            last_model_nums[key] = last_model_num

    # check what models were found
    if not last_models:
        raise ValueError(f"No models found in continue path {path}!")
    if "checkpoint" not in last_models:  # no checkpoint just best model
        last_models["checkpoint"] = last_models["best_model"]
    elif "best_model" not in last_models:  # no best model
        # this shouldn't happen, but let's handle it just in case
        last_models["best_model"] = last_models["checkpoint"]
    # finally check if last best model is more recent than checkpoint
    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
        last_models["checkpoint"] = last_models["best_model"]

    return last_models["checkpoint"], last_models["best_model"]
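Hypothetical usage (the run directory is illustrative); both return values point at the same file when only one kind of model is found:

last_ckpt, best_model = get_last_checkpoint("output/my_run")
print(last_ckpt)   # e.g. output/my_run/checkpoint_90000.pth.tar
print(best_model)  # e.g. output/my_run/best_model_88000.pth.tar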