# Example #1
# 0
def forward_preserve_state(module: DeviceAwareModule,
                           inputs: List[torch.Tensor]) -> torch.Tensor:
    """
    Perform a forward pass of the given module on a list of input tensors, without leaving side effects on the
    module's training/evaluation mode or on the backend random state.

    The random state of the backend libraries is snapshotted before the forward pass and restored afterwards, to
    avoid reproducibility issues. The module is temporarily put into evaluation mode for inference, and the
    training flag of the module and of each of its submodules is restored afterwards. Both restorations happen
    even if the forward pass raises an exception.

    :param module: Callable torch module.
    :param inputs: List of input torch tensors, passed to ``module.forward`` as positional arguments. They are
        moved to the GPU first if ``module.is_model_on_gpu()`` returns True.
    :return: The output of ``module.forward``, computed under ``torch.no_grad()``.
    :raises RuntimeError: If ``inputs`` is not a list.
    """
    if not isinstance(inputs, list):
        raise RuntimeError("Inputs object has to be a list of torch tensors")

    if module.is_model_on_gpu():
        inputs = [input_tensor.cuda() for input_tensor in inputs]

    # Record the training flag of every submodule, not just the top-level one: calling module.train() afterwards
    # would otherwise switch submodules that were deliberately kept in eval mode (e.g. frozen layers) to train mode.
    training_flags = [(submodule, submodule.training) for submodule in module.modules()]
    random_state = RandomStateSnapshot.snapshot_random_state()

    # Set the model in evaluation mode and perform a forward pass without tracking gradients.
    module.eval()
    try:
        with torch.no_grad():
            output = module.forward(*inputs)
    finally:
        # modules() yields parents before their children, and train(mode) recurses into children, so applying
        # the recorded flags in this order leaves every submodule with exactly its own previous flag.
        for submodule, was_training in training_flags:
            submodule.train(was_training)
        # Restore the backend random state captured before the forward pass.
        random_state.restore_random_state()

    return output
    def __init__(self, model: DeviceAwareModule, model_config: ScalarModelBase,
                 epoch: int, pipeline_id: int) -> None:
        """
        Store a model recovered from a checkpoint together with its metadata, and switch the model to
        evaluation mode.

        :param model: Model recovered from the checkpoint.
        :param model_config: Model configuration information, forwarded to the base class constructor.
        :param epoch: Epoch of the checkpoint which was recovered.
        :param pipeline_id: ID for this pipeline (useful for ensembles).
        """
        super().__init__(model_config)
        self.model, self.epoch, self.pipeline_id = model, epoch, pipeline_id

        # Evaluation mode is required here: layers such as batchnorm behave differently in training mode, so
        # leaving the model in training mode would give results that differ from those obtained during training.
        model.eval()