def generate_type_extension_backend(backend, declarations):
    """Generate the ``<backend>Type`` C++ sources for an extension backend.

    Builds a template environment describing ``backend`` and asks
    ``function_wrapper.create_extension_backend`` for the per-method
    declaration/definition snippets. It then writes ``<backend>Type.cpp``
    and ``<backend>Type.h`` through ``file_manager``. Finally it records
    the type id, one registration per entry of ``scalar_types``, the
    register-switch snippet and two include lines in the module-level
    ``top_env`` lists.

    Args:
        backend: backend name. It is used verbatim as both 'Backend' and
            'DeviceType', as the suffix of 'TypeID::', and as the prefix
            of the generated type name.
        declarations: forwarded unchanged to
            ``function_wrapper.create_extension_backend``.

    Returns:
        The environment dict that was used to render the templates.

    NOTE(review): this name is defined again further down in this file;
    in a single module only the last ``def`` stays bound after import.
    """
    type_name = '{}Type'.format(backend)
    env = {
        'Type': type_name,
        'Backend': backend,
        'DeviceType': backend,
        'is_extension_backend': True,
        'TypeID': 'TypeID::' + backend,
    }
    top_env['type_ids'].append(backend + ',')

    # The generated C++ snippets get their own names so they are not
    # confused with the ``declarations`` argument they are derived from.
    method_decls, method_defs = function_wrapper.create_extension_backend(
        env, declarations)
    env['type_method_declarations'] = method_decls
    env['type_method_definitions'] = method_defs

    file_manager.write(type_name + '.cpp', TYPE_EXTENSION_BACKEND_CPP, env)
    file_manager.write(type_name + '.h', TYPE_EXTENSION_BACKEND_H, env)

    # One registration per scalar type. Only the first field of each
    # entry is used, but the unpacking requires exactly five fields.
    for scalar_entry in scalar_types:
        scalar_name, _, _, _, _ = scalar_entry
        registration = TYPE_REGISTER.substitute(backend=backend,
                                                scalar_type=scalar_name,
                                                type_name=type_name)
        top_env['cpu_type_registrations'].append(registration)

    switch_snippet = EXTENSION_BACKEND_REGISTER_SWITCH.substitute(env)
    top_env['extension_backend_register_switches'].append(switch_snippet)

    angle_include = '#include <ATen/{}.h>'.format(type_name)
    top_env['extension_backend_headers'].append(angle_include)
    quoted_include = '#include "ATen/{}.h"'.format(type_name)
    top_env['cpu_type_headers'].append(quoted_include)
    return env
def generate_type_extension_backend(backend, declarations):
    """Generate the ``<backend>Type`` C++ sources for an extension backend.

    Fills a template environment for ``backend`` and obtains method
    declarations, definitions and function registrations from
    ``function_wrapper.create_extension_backend``. It then writes
    ``<backend>Type.cpp`` / ``<backend>Type.h`` via ``file_manager``.
    The type id, include lines and register-switch snippet are appended
    to the module-level ``top_env`` lists.

    Args:
        backend: backend name. It is used as the 'Backend' value, as the
            suffix of 'TypeID::', and as the prefix of the generated type
            name. Its 'DeviceType' is taken from ``backend_to_devicetype``.
        declarations: forwarded unchanged to
            ``function_wrapper.create_extension_backend``.

    Returns:
        None; all results are recorded through side effects.

    NOTE(review): this name is defined again further down in this file;
    in a single module only the last ``def`` stays bound after import.
    """
    type_name = '{}Type'.format(backend)
    env = {
        'Type': type_name,
        'Backend': backend,
        'DeviceType': backend_to_devicetype(backend),
        'TypeID': 'TypeID::' + backend,
    }
    top_env['type_ids'].append(backend + ',')

    # Keep the generated snippets under their own names rather than
    # rebinding the ``declarations`` parameter.
    method_decls, method_defs, method_registrations = (
        function_wrapper.create_extension_backend(env, declarations))
    env.update(
        type_method_declarations=method_decls,
        type_method_definitions=method_defs,
        function_registrations=method_registrations,
    )

    quoted_include = '#include "ATen/{}.h"'.format(type_name)
    top_env['cpu_type_headers'].append(quoted_include)

    file_manager.write(type_name + '.cpp', TYPE_EXTENSION_CPP, env)
    file_manager.write(type_name + '.h', TYPE_EXTENSION_H, env)

    switch_snippet = EXTENSION_BACKEND_REGISTER_SWITCH.substitute(env)
    top_env['extension_backend_register_switches'].append(switch_snippet)

    angle_include = '#include <ATen/{}.h>'.format(type_name)
    top_env['extension_backend_headers'].append(angle_include)
def generate_type_extension_backend(backend, declarations):
    """Generate the ``<backend>Type`` C++ sources for an extension backend.

    Fills a template environment for ``backend`` and obtains method
    declarations and definitions from
    ``function_wrapper.create_extension_backend``. It renders a single
    ``TYPE_REGISTER`` entry and writes ``<backend>Type.cpp`` /
    ``<backend>Type.h`` via ``file_manager``. The type id, include lines,
    registration and register-switch snippet are appended to the
    module-level ``top_env`` lists.

    Args:
        backend: backend name. It is used as the 'Backend' value, as the
            suffix of 'TypeID::', and as the prefix of the generated type
            name. Its 'DeviceType' is taken from ``backend_to_devicetype``.
        declarations: forwarded unchanged to
            ``function_wrapper.create_extension_backend``.

    Returns:
        None; all results are recorded through side effects.
    """
    type_name = '{}Type'.format(backend)
    env = {
        'Type': type_name,
        'Backend': backend,
        'DeviceType': backend_to_devicetype(backend),
        'TypeID': 'TypeID::' + backend,
    }
    top_env['type_ids'].append(backend + ',')

    # Keep the generated snippets under their own names rather than
    # rebinding the ``declarations`` parameter.
    method_decls, method_defs = function_wrapper.create_extension_backend(
        env, declarations)
    env.update(
        type_method_declarations=method_decls,
        type_method_definitions=method_defs,
    )

    # Rendered before any of the following list updates, matching the
    # point at which the registration is produced.
    registration = TYPE_REGISTER.substitute(backend=backend,
                                            type_name=type_name)
    quoted_include = '#include "ATen/{}.h"'.format(type_name)
    top_env['cpu_type_headers'].append(quoted_include)
    top_env['cpu_type_registrations'].append(registration)

    file_manager.write(type_name + '.cpp', TYPE_EXTENSION_CPP, env)
    file_manager.write(type_name + '.h', TYPE_EXTENSION_H, env)

    switch_snippet = EXTENSION_BACKEND_REGISTER_SWITCH.substitute(env)
    top_env['extension_backend_register_switches'].append(switch_snippet)

    angle_include = '#include <ATen/{}.h>'.format(type_name)
    top_env['extension_backend_headers'].append(angle_include)