def _create_function(ctx, funtion_proto, batch_size):
    """Build the function instance described by ``funtion_proto``.

    Args:
        ctx: Context forwarded to the function constructor.
        funtion_proto: Function description whose ``type`` selects the
            constructor. (The parameter's spelling is kept unchanged so
            existing keyword callers keep working.)
        batch_size: Prepended as the leading dimension of a Reshape's
            target shape.

    Returns:
        The constructed function instance: ``F.Reshape`` for "Reshape",
        otherwise whatever ``_create_function_instance`` returns.

    Raises:
        NotImplementedError: For "RepeatStart"/"RepeatEnd" ("Repeat not
            supported.") and "RecurrentOutput"/"RecurrentInput"/"Delay"
            ("Recurrent not supported.").
    """
    # TODO: arrange weight name for NNC
    # Repeat/recurrent constructs cannot be built here, so reject them before
    # any constructor is touched.
    unsupported = {
        "RepeatStart": "Repeat not supported.",
        "RepeatEnd": "Repeat not supported.",
        "RecurrentOutput": "Recurrent not supported.",
        "RecurrentInput": "Recurrent not supported.",
        "Delay": "Recurrent not supported.",
    }
    function_type = funtion_proto.type
    if function_type in unsupported:
        raise NotImplementedError(unsupported[function_type])

    if function_type == "Reshape":
        # Open question: is the resulting shape right when batch_size == -1?
        reshape_shape = (batch_size,) + \
            tuple(funtion_proto.reshape_param.shape.dim)
        return F.Reshape(ctx, shape=reshape_shape)

    return _create_function_instance(ctx, funtion_proto)
def _create_function(ctx, inputs, function_proto, batch_size):
    """Build the function instance described by ``function_proto``.

    Args:
        ctx: Context forwarded to the function constructor.
        inputs: Input variables, passed to ``resolve_reshape_params`` /
            ``resolve_broadcast_params`` to work out target shapes.
        function_proto: Function description whose ``type`` selects the
            constructor.
        batch_size: Forwarded to the shape-resolution helpers.

    Returns:
        The constructed function instance: ``F.Reshape`` for "Reshape",
        ``F.Broadcast`` for "Broadcast", otherwise whatever
        ``_create_function_instance`` returns.

    Raises:
        NotImplementedError: For "RepeatStart"/"RepeatEnd" ("Repeat not
            supported.") and "RecurrentOutput"/"RecurrentInput"/"Delay"
            ("Recurrent not supported.").
    """
    # TODO: arrange weight name for NNC
    # Repeat/recurrent constructs cannot be built here, so reject them before
    # any constructor is touched.
    unsupported = {
        "RepeatStart": "Repeat not supported.",
        "RepeatEnd": "Repeat not supported.",
        "RecurrentOutput": "Recurrent not supported.",
        "RecurrentInput": "Recurrent not supported.",
        "Delay": "Recurrent not supported.",
    }
    function_type = function_proto.type
    if function_type in unsupported:
        raise NotImplementedError(unsupported[function_type])

    if function_type == "Reshape":
        # Open question: is the resolved shape right when batch_size == -1?
        reshape_shape = resolve_reshape_params(
            inputs, function_proto, batch_size)
        return F.Reshape(
            ctx, shape=reshape_shape,
            inplace=function_proto.reshape_param.inplace)

    if function_type == 'Broadcast':
        shape = resolve_broadcast_params(inputs, function_proto, batch_size)
        return F.Broadcast(ctx, shape=shape)

    return _create_function_instance(ctx, function_proto)
def _create_function(ctx, network, f, variable_index):
    """Instantiate one unrolled copy of function ``f`` for the given step.

    Names are made unique per step by appending ``_<repeat_id>[<index>]``
    for each entry of ``variable_index``, paired positionally with
    ``f.repeat_id`` (assumes ``len(f.repeat_id) >= len(variable_index)``;
    an IndexError is raised otherwise). The last entry of
    ``variable_index`` is treated as the current step ``t``. The earlier
    entries identify the enclosing scopes.

    Args:
        ctx: Context forwarded to every function constructor.
        network: Object providing a ``variables`` mapping (indexed by the
            resolved names) and ``batch_size`` (used for Reshape/Broadcast).
        f: Function description: ``name``, ``type``, ``input``, ``output``,
            ``repeat_id`` and the type-specific ``repeat_param`` /
            ``recurrent_param`` fields.
        variable_index: Step indices, one per enclosing repeat scope
            (presumably ints; they are formatted with ``str``).

    Returns:
        ``(function, input_variable_names, output_variable_names)``.
        ``function`` carries ``name``, ``function_instance``, ``inputs``,
        ``outputs`` and ``persistent`` attributes. For a RecurrentInput at
        any step other than 0 the result is ``(None, None, None)``.

    Raises:
        AssertionError: (via ``assert``) when a Repeat/Recurrent/Delay
            function has an unexpected number of inputs or outputs.
        KeyError: when a resolved name is not in ``network.variables``.
    """
    # Suffix for this exact step (every level), and for the enclosing scope
    # (every level except the last).
    variable_index_name = ''.join([
        '_' + f.repeat_id[index] + '[' + str(i) + ']'
        for index, i in enumerate(variable_index)
    ])
    variable_index_low_level_name = ''.join([
        '_' + f.repeat_id[index] + '[' + str(i) + ']'
        for index, i in enumerate(variable_index[:-1])
    ])
    function_name = f.name + variable_index_name

    if f.type == "RepeatStart":
        # RepeatStart takes input variable and t-1 variable
        assert (len(f.input) == 2)
        if variable_index[-1] == 0:
            # Input variable if t == 0
            input_variable_names = [
                f.input[0] if f.input[0] in network.variables else
                f.input[0] + variable_index_low_level_name
            ]
        else:
            # t-1 variable if t > 0
            input_variable_names = [
                f.input[1] + variable_index_low_level_name + '_' +
                f.repeat_param.repeat_id + '[' +
                str(variable_index[-1] - 1) + ']'
            ]
    elif f.type == "RepeatEnd":
        assert (len(f.input) == 1)
        # Reads the name tagged with the last iteration index (times - 1).
        input_variable_names = [
            f.input[0] + variable_index_name + '_' +
            f.repeat_param.repeat_id + '[' +
            str(f.repeat_param.times - 1) + ']'
        ]
    elif f.type == "RecurrentInput":
        if variable_index[-1] > 0:
            # Create single split function for single RecurrentInput
            return None, None, None
        # The Split and its names use the enclosing-scope suffix only.
        function_name = f.name + variable_index_low_level_name
        # NOTE(review): not read again for RecurrentInput (its output names
        # use variable_index_low_level_name directly).
        variable_index_name = variable_index_low_level_name
        input_variable_names = [
            v_name if v_name in network.variables else
            v_name + variable_index_low_level_name
            for v_name in f.input
        ]
    elif f.type == "RecurrentOutput":
        assert (len(f.input) == 1)
        # One input per step 0..length-1 of recurrent_param.repeat_id.
        input_variable_names = [
            f.input[0] + variable_index_name + '_' +
            f.recurrent_param.repeat_id + '[' + str(v_index) + ']'
            for v_index in range(f.recurrent_param.length)
        ]
    elif f.type == "Delay":
        assert (len(f.input) == 2)  # Delay takes t-1 variable and initial value
        if variable_index[-1] == 0:
            # Initial value if t == 0
            input_variable_names = [
                f.input[1] if f.input[1] in network.variables else
                f.input[1] + variable_index_low_level_name
            ]
        else:
            # t-1 variable if t > 0
            input_variable_names = [
                f.input[0] + variable_index_low_level_name + '_' +
                f.recurrent_param.repeat_id + '[' +
                str(variable_index[-1] - 1) + ']'
            ]
    else:
        # Replace '{<repeat_id>}' placeholders in input names with '[<index>]'.
        v_names = []
        for v_name in f.input:
            for index, i in enumerate(variable_index):
                v_name = v_name.replace('{' + f.repeat_id[index] + '}',
                                        '[' + str(i) + ']')
            v_names.append(v_name)
        # Resolution order (the conditional expressions nest to the right):
        # exact name, else name + per-step suffix if that exists, else
        # name + enclosing-scope suffix.
        input_variable_names = [
            v_name if v_name in network.variables else
            v_name + variable_index_name
            if v_name + variable_index_name in network.variables else
            v_name + variable_index_low_level_name
            for v_name in v_names
        ]
    inputs = [network.variables[v_name] for v_name in input_variable_names]

    if f.type == "RecurrentInput":
        assert (len(inputs) == 1)
        assert (len(f.output) == 1)
        # One output name per index along recurrent_param.axis of the input.
        output_variable_names = [
            f.output[0] + variable_index_low_level_name + '_' +
            f.recurrent_param.repeat_id + '[' + str(v_index) + ']'
            for v_index in range(inputs[0].shape[f.recurrent_param.axis])
        ]
    else:
        # Use the per-step name when it exists, otherwise the bare name.
        output_variable_names = [
            v_name + variable_index_name
            if v_name + variable_index_name in network.variables else
            v_name
            for v_name in f.output
        ]

    outputs = [network.variables[v_name] for v_name in output_variable_names]

    # Cleared below for the Identity links (RepeatStart/RepeatEnd/Delay).
    # NOTE(review): what "persistent" means is up to the caller; confirm there.
    persistent = True
    if f.type == "Reshape":
        shape = resolve_reshape_params(inputs, f, network.batch_size)
        function_instance = F.Reshape(ctx, shape=shape, inplace=True)
    elif f.type == "RepeatStart":
        function_instance = F.Identity(ctx)
        persistent = False
    elif f.type == "RepeatEnd":
        function_instance = F.Identity(ctx)
        persistent = False
    elif f.type == "RecurrentOutput":
        function_instance = F.Stack(ctx, axis=f.recurrent_param.axis)
    elif f.type == "RecurrentInput":
        function_instance = F.Split(ctx, axis=f.recurrent_param.axis)
    elif f.type == "Delay":
        function_instance = F.Identity(ctx)
        persistent = False
    elif f.type == "Broadcast":
        shape = resolve_broadcast_params(inputs, f, network.batch_size)
        function_instance = F.Broadcast(ctx, shape)
    else:
        function_instance = _create_function_instance(ctx, f)

    # Prepare link structure
    class Function:
        pass

    function = Function()
    function.name = function_name
    function.function_instance = function_instance
    function.inputs = list(inputs)
    function.outputs = list(outputs)
    function.persistent = persistent

    return function, input_variable_names, output_variable_names
def _create_function(ctx, network, f, variable_index):
    """Instantiate one unrolled copy of function ``f`` for the given step.

    Names are made unique per step by appending ``_<repeat_id>[<index>]``
    for each entry of ``variable_index``, paired positionally with
    ``f.repeat_id``. The last entry of ``variable_index`` is the current
    step. The earlier entries identify the enclosing scopes.

    Args:
        ctx: Context forwarded to every function constructor.
        network: Object providing a ``variables`` mapping (indexed by the
            resolved names) and ``batch_size`` (used for Reshape).
        f: Function description: ``name``, ``type``, ``input``, ``output``,
            ``repeat_id`` and the type-specific ``repeat_param`` /
            ``recurrent_param`` / ``reshape_param`` fields.
        variable_index: Step indices, one per enclosing repeat scope
            (presumably ints; they are formatted with ``str``).

    Returns:
        ``(function, input_variable_names, output_variable_names)``.
        ``function`` carries ``name``, ``function_instance``, ``inputs`` and
        ``outputs``. For a RecurrentInput at any step other than 0 the
        result is ``(None, None, None)``, since a single Split serves all
        steps.

    Raises:
        AssertionError: (via ``assert``) when a Repeat/Recurrent/Delay
            function has an unexpected number of inputs or outputs.
        KeyError: when a resolved name is not in ``network.variables``.
    """
    def tag(repeat_id, index):
        # Name fragment identifying one step of one repeat scope.
        return '_' + repeat_id + '[' + str(index) + ']'

    def scope_suffix(indices):
        return ''.join(tag(f.repeat_id[depth], idx)
                       for depth, idx in enumerate(indices))

    full_suffix = scope_suffix(variable_index)
    outer_suffix = scope_suffix(variable_index[:-1])
    function_name = f.name + full_suffix
    kind = f.type

    if kind == "RecurrentInput" and variable_index[-1] > 0:
        # Create single split function for single RecurrentInput.
        return None, None, None

    if kind == "RepeatStart":
        # Step 0 reads the input variable; later steps read the t-1 variable.
        assert len(f.input) == 2
        step = variable_index[-1]
        if step == 0:
            source = f.input[0]
            input_variable_names = [
                source if source in network.variables
                else source + outer_suffix]
        else:
            input_variable_names = [
                f.input[1] + outer_suffix
                + tag(f.repeat_param.repeat_id, step - 1)]
    elif kind == "RepeatEnd":
        assert len(f.input) == 1
        # Reads the name tagged with the last iteration index (times - 1).
        repeat = f.repeat_param
        input_variable_names = [
            f.input[0] + full_suffix + tag(repeat.repeat_id, repeat.times - 1)]
    elif kind == "RecurrentInput":
        # Only step 0 gets here; the Split uses the enclosing-scope suffix.
        function_name = f.name + outer_suffix
        input_variable_names = [
            name if name in network.variables else name + outer_suffix
            for name in f.input]
    elif kind == "RecurrentOutput":
        assert len(f.input) == 1
        recurrent = f.recurrent_param
        input_variable_names = [
            f.input[0] + full_suffix + tag(recurrent.repeat_id, k)
            for k in range(recurrent.length)]
    elif kind == "Delay":
        # Delay takes the t-1 variable and an initial value (used at step 0).
        assert len(f.input) == 2
        step = variable_index[-1]
        if step == 0:
            initial = f.input[1]
            input_variable_names = [
                initial if initial in network.variables
                else initial + outer_suffix]
        else:
            input_variable_names = [
                f.input[0] + outer_suffix
                + tag(f.recurrent_param.repeat_id, step - 1)]
    else:
        def resolve(name):
            # Exact name first, then the per-step name, else enclosing scope.
            if name in network.variables:
                return name
            if name + full_suffix in network.variables:
                return name + full_suffix
            return name + outer_suffix

        input_variable_names = [resolve(name) for name in f.input]

    inputs = [network.variables[name] for name in input_variable_names]

    if kind == "RecurrentInput":
        assert len(inputs) == 1
        assert len(f.output) == 1
        recurrent = f.recurrent_param
        # One output name per index along recurrent_param.axis of the input.
        output_variable_names = [
            f.output[0] + outer_suffix + tag(recurrent.repeat_id, k)
            for k in range(inputs[0].shape[recurrent.axis])]
    else:
        output_variable_names = [
            name + full_suffix if name + full_suffix in network.variables
            else name
            for name in f.output]

    outputs = [network.variables[name] for name in output_variable_names]

    if kind == "Reshape":
        # The batch dimension is prepended to the stored target shape.
        reshape_shape = (network.batch_size,) + \
            tuple(f.reshape_param.shape.dim)
        function_instance = F.Reshape(ctx, shape=reshape_shape)
    elif kind in ("RepeatStart", "RepeatEnd", "Delay"):
        function_instance = F.Identity(ctx)
    elif kind == "RecurrentOutput":
        function_instance = F.Stack(ctx, axis=f.recurrent_param.axis)
    elif kind == "RecurrentInput":
        function_instance = F.Split(ctx, axis=f.recurrent_param.axis)
    else:
        function_instance = _create_function_instance(ctx, f)

    # Plain attribute container linking the instance to its variables.
    class Function:
        pass

    function = Function()
    function.name = function_name
    function.function_instance = function_instance
    function.inputs = list(inputs)
    function.outputs = list(outputs)
    return function, input_variable_names, output_variable_names