Example #1
def batch_norm(g, input, weight, bias, running_mean, running_var, training,
               momentum, eps, cudnn_enabled):

    if torch.is_autocast_enabled() and \
            not args_have_same_dtype([input, weight, bias, running_mean, running_var]) and \
            sym_help._export_onnx_opset_version < 15:
        return sym_help._onnx_opset_unsupported_detailed(
            "BatchNormalization", 14, 15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.")

    sym_help.check_training_mode(training, "batch_norm")
    weight, bias, running_mean, running_var = sym_help._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var)
    out = g.op("BatchNormalization",
               input,
               weight,
               bias,
               running_mean,
               running_var,
               epsilon_f=eps,
               momentum_f=1 - momentum,  # ONNX's momentum convention is inverted relative to PyTorch's
               training_mode_i=0 if not training else 1,
               outputs=1 if not training else 3)
    if not training:
        return out
    else:
        res, new_running_mean, new_running_var = out
        new_running_mean.setType(running_mean.type())
        new_running_var.setType(running_var.type())
        return res
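For orientation: this symbolic only runs during ONNX export, and the training branch (three outputs, training_mode_i=1) is reached when the exporter itself is put in training mode. A minimal sketch of triggering it, assuming the legacy TorchScript-based exporter and its `training` keyword (PyTorch 1.x-era API; the toy module below is illustrative only):

import io
import torch

model = torch.nn.BatchNorm1d(4).train()
x = torch.randn(2, 4)

buf = io.BytesIO()
torch.onnx.export(
    model, (x,), buf,
    opset_version=14,
    # TrainingMode.TRAINING makes `training` truthy in the symbolic above,
    # so BatchNormalization is emitted with training_mode_i=1 and 3 outputs.
    training=torch.onnx.TrainingMode.TRAINING,
)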
Example #2
def dropout(g, input, p, train):
    symbolic_helper.check_training_mode(train, "dropout")
    # If train is False, dropout is a no-op: just return the input unchanged.
    if not train:
        return input
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
    r, _ = g.op("Dropout", input, p, t, outputs=2)
    return r
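As a usage sketch (hedged; this assumes the same legacy exporter API with the `training` keyword): exporting a model containing `nn.Dropout` in training mode drives `train=True` into this symbolic and produces a stochastic ONNX Dropout node, while eval-mode export removes it from the graph entirely:

import io
import torch

model = torch.nn.Dropout(p=0.5)
x = torch.randn(3, 3)

# TRAINING: the symbolic emits Dropout(input, ratio, training_mode).
torch.onnx.export(model.train(), (x,), io.BytesIO(), opset_version=12,
                  training=torch.onnx.TrainingMode.TRAINING)

# EVAL: `train` is False, so the symbolic returns the input unchanged
# and no Dropout node appears in the exported graph.
torch.onnx.export(model.eval(), (x,), io.BytesIO(), opset_version=12,
                  training=torch.onnx.TrainingMode.EVAL)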
Example #3
def test_check_training_mode_warns_when(
    self,
    op_train_mode: int,
    export_mode: torch.onnx.TrainingMode,
):
    with self.assertWarnsRegex(
        UserWarning, f"ONNX export mode is set to {export_mode}"
    ):
        GLOBALS.training_mode = export_mode
        symbolic_helper.check_training_mode(op_train_mode, "testop")
Example #4
def dropout(g, input, p, train):
    sym_help.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op: if the node's train param is set to False, just return the input.
    if not train:
        return input
    warnings.warn("Dropout is a training op and should not be exported in inference mode. "
                  "For inference, make sure to call eval() on the model and to export it with param training=False.")
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(True))
    r, _ = g.op("Dropout", input, p, t, outputs=2)
    return r
Example #5
def _dropout_returns_masked_input_and_mask(
        g, input: torch._C.Value, p: float,
        train: bool) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
    symbolic_helper.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op. That is, if the node's
    # train param is set to False, dropout just returns its input.
    if not train:
        return input, None
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
    r, mask = g.op("Dropout", input, p, t, outputs=2)
    return r, mask
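The mask output is the point of this variant: it backs symbolics (presumably `dropout` and `native_dropout` in this opset file) that differ only in whether they surface the mask, and eval-mode callers get `(input, None)` instead of a Dropout node.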
Example #6
def batch_norm(g, input, weight, bias, running_mean, running_var, training,
               momentum, eps, cudnn_enabled):
    sym_help.check_training_mode(training, "batch_norm")
    weight, bias, running_mean, running_var = sym_help._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var)
    out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
               epsilon_f=eps,
               momentum_f=1 - momentum,  # ONNX's momentum convention is inverted relative to PyTorch's
               training_mode_i=0 if not training else 1,
               outputs=1 if not training else 3)
    if not training:
        return out
    else:
        res, new_running_mean, new_running_var = out
        new_running_mean.setType(running_mean.type())
        new_running_var.setType(running_var.type())
        return res
Example #7
def test_check_training_mode_does_not_warn_when(
    self, op_train_mode: int, export_mode: torch.onnx.TrainingMode
):
    GLOBALS.training_mode = export_mode
    self.assertNotWarn(
        lambda: symbolic_helper.check_training_mode(op_train_mode, "testop")
    )
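Read together, Examples #3 and #7 pin down the contract of check_training_mode: it warns only when the op's own training flag disagrees with a non-PRESERVE export mode. A hedged standalone illustration (GLOBALS and symbolic_helper are private modules whose locations vary across PyTorch releases; this assumes a 1.13-era layout):

import warnings
import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS  # private; path is version-dependent

# Export mode EVAL, but the op reports train=1: expect a mismatch warning.
GLOBALS.training_mode = torch.onnx.TrainingMode.EVAL
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    symbolic_helper.check_training_mode(1, "dropout")
assert any("ONNX export mode is set to" in str(w.message) for w in caught)

# Matching modes: no warning is expected.
GLOBALS.training_mode = torch.onnx.TrainingMode.TRAINING
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    symbolic_helper.check_training_mode(1, "dropout")
assert not caught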