Ejemplo n.º 1
0
def dropout(g, input, p, train):
    """Export ``aten::dropout`` as an ONNX ``Dropout`` node.

    Args:
        g: graph context used to emit ONNX nodes.
        input: the value dropout is applied to.
        p: drop probability. It is wrapped in a ``Constant`` and fed as the
            ``ratio`` input (presumably already parsed to a Python float by
            the caller/decorator -- TODO confirm).
        train: the op's ``train`` flag, checked against the current export
            mode by ``sym_help.assert_training_mode``.

    Returns:
        ``input`` unchanged when exporting in eval mode. Otherwise the data
        output of the emitted ``Dropout`` node (the mask output is discarded).
    """
    sym_help.assert_training_mode(train, "dropout")
    # In eval mode dropout is a no-op, so no node is emitted and the input
    # is passed straight through.
    if not sym_help._training_mode:
        return input
    ratio = g.op("Constant", value_t=torch.tensor(p))
    # Dropout-12 treats a missing ``training_mode`` input as False. In that
    # case it ignores ``ratio`` and behaves as identity. Pass True explicitly
    # so the exported training graph actually drops elements.
    training_mode = g.op("Constant", value_t=torch.tensor(True))
    r, _ = g.op("Dropout", input, ratio, training_mode, outputs=2)
    return r
Ejemplo n.º 2
0
def batch_norm(g, input, weight, bias, running_mean, running_var, training,
               momentum, eps, cudnn_enabled):
    """Export ``aten::batch_norm`` as an ONNX ``BatchNormalization`` node.

    A missing ``weight`` / ``bias`` is replaced by a constant of ones / zeros.
    "Missing" means either Python ``None`` or a value that
    ``sym_help._is_none`` reports as none. The constant's length is dimension
    1 of the input, and its scalar type matches the input.

    When exporting in eval mode, a single-output node that consumes the
    running statistics is emitted.

    When exporting in training mode:

    * a ``True`` constant is passed as an extra input and five outputs are
      requested;
    * the updated running statistics take the types of ``running_mean`` /
      ``running_var``;
    * the saved mean/variance outputs get a ``batch_norm_dead_output-``
      debug-name prefix;
    * only the normalized result is returned.

    ``training`` is only checked via ``sym_help.assert_training_mode``.
    ``cudnn_enabled`` is accepted for signature compatibility and unused.

    Returns:
        The value holding the normalized output.
    """
    sym_help.assert_training_mode(training, "batch_norm")
    input_sizes = input.type().sizes()

    def _per_channel_constant(fill):
        # Stand-in affine parameter: ``fill`` repeated once per entry of
        # input dimension 1, cast to the input's scalar type.
        assert len(input_sizes) > 1
        values = torch.tensor([fill] * input_sizes[1])
        tensor_type = 'torch.' + input.type().scalarType() + 'Tensor'
        return g.op("Constant", value_t=values.type(tensor_type))

    if weight is None or sym_help._is_none(weight):
        weight = _per_channel_constant(1.)
    if bias is None or sym_help._is_none(bias):
        bias = _per_channel_constant(0.)

    if not sym_help._training_mode:
        # Eval export: normalize using the supplied running statistics.
        return g.op("BatchNormalization", input, weight, bias,
                    running_mean, running_var,
                    epsilon_f=eps,
                    # ONNX weights the running statistic by ``momentum``,
                    # while PyTorch weights the batch statistic by it.
                    momentum_f=1 - momentum,
                    outputs=1)

    training_flag = g.op("Constant", value_t=torch.tensor(True))
    outputs = g.op("BatchNormalization", input, weight, bias,
                   running_mean, running_var, training_flag,
                   epsilon_f=eps,
                   momentum_f=1 - momentum,
                   outputs=5)
    (normalized, new_running_mean, new_running_var,
     saved_mean, saved_var) = outputs

    for updated, source in ((new_running_mean, running_mean),
                            (new_running_var, running_var)):
        updated.setType(source.type())
    for dead in (saved_mean, saved_var):
        dead.setDebugName("batch_norm_dead_output-" + dead.debugName())
    return normalized