예제 #1
0
def emit_assignments(o, env):
    """Render and collect the output-assignment code for every return of an op.

    For each entry of ``o['returns']`` (enumerated from 0), this function:

    1. Picks the assignment template from ``RETURN_MAP``. The key is
       ``'at::Tensor'`` when ``value_is_tensor_type`` reports the return as
       tensor-typed, otherwise the return's own ``'type'``.
    2. Fills that template from ``env`` plus the return's index (``offset``)
       and its output expression from ``get_output(o, offset)``.
    3. Wraps the result in ``ASSIGN_CHECK_SIZE_TEMPLATE``, a
       ``if(OutputSize() > offset)`` guard.
    4. Appends the guarded snippet to ``env['assignments']``.

    Args:
        o: Operator description. ``o['returns']`` is iterated here, and ``o``
            itself is forwarded to ``get_output``.
        env: Template substitution environment, mutated in place. It must hold
            an ``'assignments'`` list whenever ``o['returns']`` is non-empty.

    Returns:
        None.
    """
    for offset, ret in enumerate(o['returns']):
        # Every return that value_is_tensor_type() accepts shares the single
        # 'at::Tensor' entry; other returns are keyed by their declared type.
        if value_is_tensor_type(ret):
            map_key = 'at::Tensor'
        else:
            map_key = ret['type']
        template = CT(RETURN_MAP[map_key])
        rendered = template.substitute(env, offset=offset, output=get_output(o, offset))
        guarded = ASSIGN_CHECK_SIZE_TEMPLATE.substitute(env, offset=offset, assignment=rendered)
        env['assignments'].append(guarded)
예제 #2
0
파일: gen_op.py 프로젝트: Guokr1991/pytorch
# C++ source template rendered once for every supported operator.
#
# Each operator has an integer ${key}, the same value CASE_TEMPLATE uses as
# its `case` label. ${name} only appears inside a C++ comment. The operator
# itself is implemented by a lambda stored in `run_op`.
#
# Non-tensor attributes are created by the code spliced into
# ${initialization}. The lambda uses the by-copy capture default `[=]`, so
# those values are copied into it. Inputs and outputs are accessed inside the
# lambda, in this order:
#   1. ${statements} runs.
#   2. The expression in ${invocation} is evaluated and stored in
#      `the_result`.
#   3. ${assignments} runs. This is presumably the list built by
#      emit_assignments and presumably copies values out of `the_result`.
# The lambda then returns true.
#
# Each implementation is emitted as a separate method annotated with
# C10_NOINLINE. This keeps it from being inlined into the ATenOp
# constructor, which would trigger pathological compile times.
#
# NOTE(review): `at::AutoNonVariableTypeMode guard;` presumably disables
# autograd/Variable dispatch for the duration of the call. Newer ATen
# releases may deprecate this guard; confirm against the ATen version in use.
IMPLEMENTATION_TEMPLATE = CT("""\
C10_NOINLINE void implementation_${key}() { // ${name}
    ${initialization}
    run_op = [=] {
        at::AutoNonVariableTypeMode guard;
        ${statements}
        auto the_result = ${invocation};
        ${assignments}
        return true;
    };
}
""")

# One `case` arm of the generated dispatch switch. It routes operator ${key}
# to its generated implementation_${key}() method (IMPLEMENTATION_TEMPLATE)
# and echoes ${name} as a C++ comment.
CASE_TEMPLATE = CT(
    "case ${key}: // ${name}\n"
    "  implementation_${key}();\n"
    "  break;\n"
)

# Guard placed around one rendered output assignment. The assignment runs
# only when `OutputSize() > ${offset}`, i.e. (presumably) when output slot
# ${offset} was actually requested for this op instance.
ASSIGN_CHECK_SIZE_TEMPLATE = CT(
    "  if(OutputSize() > ${offset}) {${assignment}}\n"
)