Ejemplo n.º 1
0
def load_checkpoint_from_http(
    model,
    filename,
    map_location=None,
):
    """Fetch weights from ``filename`` via ``load_from_http`` and load them into ``model``.

    Args:
        model: the model to load weights into; it is wrapped in a ``Checkpointer``.
        filename: location handed to ``load_from_http`` (presumably an HTTP(S)
            URL -- confirm against ``load_from_http``).
        map_location: forwarded unchanged to ``load_from_http``.

    Incompatible keys reported by ``Checkpointer._load_model`` are logged
    through the checkpointer; nothing is returned.
    """
    checkpointer = Checkpointer(model)
    # Log before fetching so that a slow or failing download is attributable
    # to this location in the logs.
    checkpointer.logger.info("[Checkpointer] Loading from {} ...".format(filename))
    checkpoint = load_from_http(filename, map_location=map_location)

    # _load_model is given a dict with the weights under the "model" key.
    # NOTE(review): assumes load_from_http returns the bare weights rather than
    # a dict that already has a "model" key -- confirm.
    incompatible = checkpointer._load_model(checkpoint={"model": checkpoint})

    # handle some existing subclasses that return None
    if incompatible is not None:
        checkpointer._log_incompatible_keys(incompatible)
Ejemplo n.º 2
0
    def test_load_reused_params(self) -> None:
        # Weights saved from a model that only has ``x`` must also end up in
        # ``y`` when, in the target model, ``y`` is the very same module as
        # ``x``. Logging the incompatible keys afterwards must emit no info
        # messages.
        class Model(nn.Module):
            def __init__(self, has_y: bool) -> None:
                super().__init__()
                self.x = nn.Linear(10, 10)
                if has_y:
                    # ``y`` aliases ``x``, so both names share one set of parameters.
                    self.y = self.x

        source = Model(has_y=False)
        with torch.no_grad():
            source.x.bias.fill_(5.0)  # pyre-ignore
        saved = {"model": source.state_dict()}

        target = Model(has_y=True)
        checkpointer = Checkpointer(target)
        mock_logger = MagicMock()
        checkpointer.logger = mock_logger

        result = checkpointer._load_model(saved)
        checkpointer._log_incompatible_keys(result)

        shared_bias = target.y.bias
        self.assertTrue(
            torch.allclose(shared_bias - 5.0, torch.zeros_like(shared_bias))
        )
        mock_logger.info.assert_not_called()