def miso_optional(fn, h_opt):
    """Create a two-input module that is either ``fn()`` or a pass-through.

    Args:
        fn: zero-argument callable building the optional sub-module.
        h_opt: hyperparameter bound to the name ``'opt'``; a truthy value
            selects ``fn()``, a falsy value selects
            ``MISOIdentity().get_io()``.

    Returns:
        Whatever ``mo.substitution_module`` returns for the
        ``"MISOOptional"`` module with inputs ``In0``/``In1`` and output
        ``Out``.
    """
    def substitution_fn(opt):
        # ``opt`` keeps its name: it matches the hyperparameter key below.
        if not opt:
            return MISOIdentity().get_io()
        return fn()

    hyperparameters = {'opt': h_opt}
    return mo.substitution_module("MISOOptional", hyperparameters,
                                  substitution_fn, ['In0', 'In1'], ['Out'],
                                  scope=None)
def create_motif(dh, full_hparam_dict, motif_info, level, num_channels):
    """Build one motif of a hierarchical search space.

    At ``level == 1`` the motif is a primitive operation selected by the
    1-based ``dh['operation']``. At higher levels the motif is a DAG of
    ``motif_info[level]['num_nodes']`` nodes. The edge between nodes is
    chosen by a hyperparameter whose value is 0 (no edge) or a 1-based id
    of a lower-level motif. The DAG is built lazily inside a substitution
    module.

    Args:
        dh: at level 1, a dict with key ``'operation'``; otherwise the
            hyperparameter dict handed to ``mo.substitution_module``.
        full_hparam_dict: per-level sequences of hyperparameter dicts;
            ``full_hparam_dict[level - 1][k - 1]`` describes lower-level
            motif ``k`` (used for levels above 2).
        motif_info: per-level dicts providing ``'num_nodes'``.
        level: hierarchy level to build.
        num_channels: channel count for 1x1/separable convs and concats.
    """
    if level == 1:
        primitives = (
            lambda: conv2d_cell(num_channels, 1),  # 1x1 conv of C channels
            lambda: depthwise_conv2d_cell(3),  # 3x3 depthwise conv
            lambda: separable_conv2d_cell(num_channels, 3),  # 3x3 sep conv
            lambda: max_pooling(3),  # 3x3 max pool
            lambda: average_pooling(3),  # 3x3 avg pool
            lambda: mo.identity(),
        )
        return primitives[dh['operation'] - 1]()

    def substitution_fn(dh):
        num_nodes = motif_info[level]['num_nodes']
        # node_ops[k] collects the (inputs, outputs) pairs feeding node k;
        # node 0 is fed only by an identity (the motif input).
        node_ops = [[] for _ in range(num_nodes)]
        node_ops[0].append(mo.identity())
        node_outputs = []
        for dst in range(num_nodes):
            incoming = node_ops[dst]
            for src in range(dst):
                op_id = dh[ut.json_object_to_json_string({
                    "out_node_id": dst,
                    "in_node_id": src,
                })]
                if op_id <= 0:
                    continue  # 0 means "no edge"
                if level == 2:
                    sub_hparams = {'operation': op_id}
                else:
                    sub_hparams = full_hparam_dict[level - 1][op_id - 1]
                edge = create_motif(sub_hparams, full_hparam_dict,
                                    motif_info, level - 1, num_channels)
                incoming.append(edge)
                edge[0]['in'].connect(node_outputs[src][1]['out'])
            assert (len(incoming) > 0)
            # Every node concatenates all of its incoming edges.
            concat_ins, concat_outs = combine_with_concat(
                len(incoming), num_channels)
            for port, (_, edge_outs) in enumerate(incoming):
                edge_outs['out'].connect(concat_ins['in%d' % port])
            node_outputs.append((concat_ins, concat_outs))

        result = node_outputs[-1][1]
        if level == 2:
            # Level-2 motifs finish with a 1x1 conv back to C channels.
            conv_ins, conv_outs = conv2d_cell(num_channels, 1)
            result['out'].connect(conv_ins['in'])
            result = conv_outs
        return node_ops[0][0][0], result

    return mo.substitution_module('Motif_Level_%d' % level, substitution_fn,
                                  dh, ['in'], ['out'], None)
def genetic_stage(input_fn, node_fn, output_fn, h_connections, num_nodes):
    """Build a stage whose DAG topology is selected by binary edge flags.

    Internal nodes are numbered ``1..num_nodes``. The hyperparameter named
    ``'i_j'`` (``i < j``) enables the edge from node ``i`` to node ``j``.
    At substitution time:

    * a node with no enabled in-edges reads from the stage input;
    * a node with no enabled out-edges feeds the stage output;
    * a node with neither is dropped;
    * if no edge is enabled at all, the stage is just ``input_fn()``.

    Args:
        input_fn: zero-argument callable returning the input pair.
        node_fn: callable taking a node's number of inputs (at least 1)
            and returning its ``(inputs, outputs)`` pair.
        output_fn: callable taking the number of inputs of the output node.
        h_connections: hyperparameters, one per node pair, in
            ``itertools.combinations(range(1, num_nodes + 1), 2)`` order.
        num_nodes: number of internal nodes.
    """
    def substitution_fn(dh):
        # num_ins[k] / num_outs[k] count enabled in-/out-edges of node k + 1.
        num_ins = [
            sum([dh['%d_%d' % (in_id, out_id)] for in_id in range(1, out_id)])
            for out_id in range(1, num_nodes + 1)
        ]
        num_outs = [
            sum([
                dh['%d_%d' % (in_id, out_id)]
                for out_id in range(in_id + 1, num_nodes + 1)
            ])
            for in_id in range(1, num_nodes + 1)
        ]
        if sum(num_ins) == 0:
            return input_fn()

        nodes = [input_fn()]
        nodes += [
            node_fn(max(num_ins[k], 1))
            if num_ins[k] > 0 or num_outs[k] > 0 else None
            for k in range(num_nodes)
        ]
        # The output node gets one input per surviving node without
        # out-edges.
        num_sinks = len([
            k for k in range(len(num_outs))
            if num_outs[k] == 0 and nodes[k + 1] is not None
        ])
        nodes.append(output_fn(num_sinks))

        num_connected = [0] * (num_nodes + 2)
        for in_id in range(1, num_nodes + 1):
            # Connect nodes with no input to original input
            if num_ins[in_id - 1] == 0 and nodes[in_id] is not None:
                nodes[0][1]['out'].connect(nodes[in_id][0]['in0'])
            # Connect nodes with no output to final output node
            if num_outs[in_id - 1] == 0 and nodes[in_id] is not None:
                nodes[in_id][1]['out'].connect(
                    nodes[-1][0]['in%d' % num_connected[-1]])
                num_connected[-1] += 1
            # Connect internal nodes
            for out_id in range(in_id + 1, num_nodes + 1):
                if dh['%d_%d' % (in_id, out_id)] == 1:
                    nodes[in_id][1]['out'].connect(
                        nodes[out_id][0]['in' + str(num_connected[out_id])])
                    num_connected[out_id] += 1
        return nodes[0][0], nodes[-1][1]

    name_to_hparam = {}
    pairs = itertools.combinations(range(1, num_nodes + 1), 2)
    for ix, pair in enumerate(pairs):
        name_to_hparam['%d_%d' % pair] = h_connections[ix]
    return mo.substitution_module('GeneticStage', substitution_fn,
                                  name_to_hparam, ['in'], ['out'], scope=None)
def miso_optional(fn, h_opt):
    """Create a two-input module that is either ``fn()`` or a pass-through.

    Args:
        fn: zero-argument callable building the optional sub-module.
        h_opt: hyperparameter stored under ``'opt'``; when its value is
            truthy the module is ``fn()``, otherwise it is
            ``MISOIdentity().get_io()``.

    Returns:
        Whatever ``mo.substitution_module`` returns for the
        ``"MISOOptional"`` module with inputs ``in0``/``in1`` and output
        ``out``.
    """
    def substitution_fn(dh):
        if not dh["opt"]:
            return MISOIdentity().get_io()
        return fn()

    input_names = ['in0', 'in1']
    output_names = ['out']
    return mo.substitution_module("MISOOptional", substitution_fn,
                                  {'opt': h_opt}, input_names, output_names,
                                  scope=None)
def motif(submotif_fn, num_nodes):
    """Build a chain of ``num_nodes`` nodes with optional skip connections.

    Node 0 is an identity on the motif input. Every node ``i >= 1`` always
    receives an edge from node ``i - 1``. The hyperparameter keyed by
    ``{"node_id": i, "in_node_id": j}`` (``j < i - 1``) adds an extra edge
    from node ``j``. Each edge is a fresh ``submotif_fn()``. A node with
    several incoming edges concatenates them. The motif output is the last
    node.

    Args:
        submotif_fn: zero-argument callable returning an ``(inputs,
            outputs)`` pair with an ``'in'`` input and ``'out'`` output.
        num_nodes: number of nodes including the identity node; must be >= 1.
    """
    assert num_nodes >= 1

    def substitution_fn(dh):
        # Incoming node ids per node, seeded with the mandatory chain edge.
        node_id_to_node_ids_used = {i: [i - 1] for i in range(1, num_nodes)}
        for name, v in dh.items():
            if v:
                d = ut.json_string_to_json_object(name)
                node_id_to_node_ids_used[d["node_id"]].append(d["in_node_id"])
        for i in range(1, num_nodes):
            node_id_to_node_ids_used[i] = sorted(node_id_to_node_ids_used[i])

        in_inputs, first_outputs = mo.identity()
        node_id_to_outputs = [first_outputs]
        for i in range(1, num_nodes):
            outputs_lst = []
            for j in node_id_to_node_ids_used[i]:
                sub_inputs, sub_outputs = submotif_fn()
                sub_inputs["in"].connect(node_id_to_outputs[j]["out"])
                outputs_lst.append(sub_outputs)
            # if necessary, concatenate the results going into a node
            if len(outputs_lst) > 1:
                c_inputs, c_outputs = combine_with_concat(len(outputs_lst))
                for idx, sub_outputs in enumerate(outputs_lst):
                    c_inputs["in%d" % idx].connect(sub_outputs["out"])
            else:
                c_outputs = outputs_lst[0]
            node_id_to_outputs.append(c_outputs)
        return in_inputs, node_id_to_outputs[-1]

    name_to_hyperp = {
        ut.json_object_to_json_string({
            "node_id": i,
            "in_node_id": j
        }): D([0, 1]) for i in range(1, num_nodes) for j in range(i - 1)
    }
    return mo.substitution_module("Motif", substitution_fn, name_to_hyperp,
                                  ["in"], ["out"], scope=None)
def enas_space(h_num_layers,
               out_filters,
               fn_first,
               fn_repeats,
               input_names,
               output_names,
               weight_sharer,
               scope=None):
    """Stack ``fn_first()`` followed by a searched number of repeated layers.

    Args:
        h_num_layers: hyperparameter giving the number of repeated layers
            (asserted to be positive at substitution time).
        out_filters: forwarded to every ``fn_repeats`` call.
        fn_first: zero-argument callable returning the initial
            ``(inputs, outputs)`` pair.
        fn_repeats: callable ``(inputs, outputs, layer_id, out_filters,
            weight_sharer)`` returning the next ``(inputs, outputs)`` pair;
            ``layer_id`` runs from 1 to the number of layers.
        input_names, output_names: port names of the substitution module.
        weight_sharer: forwarded to every ``fn_repeats`` call.
        scope: scope passed to ``mo.substitution_module``.
    """
    def substitution_fn(dh):
        num_layers = dh["num_layers"]
        assert num_layers > 0
        inputs, first_outputs = fn_first()
        layer_outputs = OrderedDict(first_outputs)
        for layer_id in range(1, num_layers + 1):
            inputs, layer_outputs = fn_repeats(inputs, layer_outputs,
                                               layer_id, out_filters,
                                               weight_sharer)
        # The module's single output is the highest-numbered 'outK' entry.
        last_key = 'out' + str(len(layer_outputs) - 1)
        return inputs, OrderedDict([('out', layer_outputs[last_key])])

    return mo.substitution_module('ENASModule', substitution_fn,
                                  {'num_layers': h_num_layers}, input_names,
                                  output_names, scope)
def cell(input_fn, node_fn, combine_fn, unused_combine_fn, num_nodes,
         hyperparameters):
    """Build a cell of ``num_nodes`` steps over two cell inputs.

    Hidden states start as the ``'Out0'``/``'Out1'`` outputs of
    ``input_fn()``. Step ``i`` reads two hidden-state indices from the
    hyperparameters ``'i_0'`` and ``'i_1'``. It transforms each state with
    ``node_fn(index, i, branch)``, merges the two results with
    ``combine_fn()``, and appends the merged state. All hidden states that
    were never selected are fed, in order, to ``unused_combine_fn(count)``.
    Its outputs become the cell outputs.

    Args:
        input_fn: zero-argument callable returning the cell input pair.
        node_fn: callable ``(state_index, step, branch)`` returning a pair
            with ``'In'``/``'Out'`` ports.
        combine_fn: zero-argument callable returning a pair with
            ``'In0'``/``'In1'``/``'Out'`` ports.
        unused_combine_fn: callable taking the number of unused states and
            returning a pair with ``'In<k>'`` inputs.
        num_nodes: number of steps.
        hyperparameters: flat sequence; element ``2*i + b`` becomes ``'i_b'``.
    """
    def substitution_fn(**dh):
        cell_ins, cell_outs = input_fn()
        hidden = [cell_outs['Out0'], cell_outs['Out1']]
        used = [False] * (num_nodes + 2)
        for step in range(num_nodes):
            # Get indices of hidden states to be combined
            src0 = dh[str(step) + '_0']
            src1 = dh[str(step) + '_1']
            # Transform hidden states
            branch0 = node_fn(src0, step, 0)
            branch1 = node_fn(src1, step, 1)
            branch0[0]['In'].connect(hidden[src0])
            branch1[0]['In'].connect(hidden[src1])
            used[src0] = True
            used[src1] = True
            # Combine hidden states
            merged = combine_fn()
            merged[0]['In0'].connect(branch0[1]['Out'])
            merged[0]['In1'].connect(branch1[1]['Out'])
            hidden.append(merged[1]['Out'])

        dangling = [state for state, was_used in zip(hidden, used)
                    if not was_used]
        final_ins, final_outs = unused_combine_fn(len(dangling))
        for port, state in enumerate(dangling):
            final_ins['In' + str(port)].connect(state)
        return cell_ins, final_outs

    name_to_hyperp = {}
    for flat_idx, hyperp in enumerate(hyperparameters):
        name_to_hyperp['%d_%d' % divmod(flat_idx, 2)] = hyperp
    return mo.substitution_module('Cell', name_to_hyperp, substitution_fn,
                                  ['In0', 'In1'], ['Out'], None)
def cell(input_fn, node_fn, output_fn, h_connections, num_nodes, channels):
    """Build a NASBench-style cell whose topology is a DAG of binary edges.

    Vertex 0 is the cell input, vertices ``1..num_nodes`` are interior, and
    vertex ``num_nodes + 1`` is the output. The hyperparameter ``'i_j'``
    (``i < j``) enables edge ``i -> j``. Edges leaving the input go through
    ``input_fn(c)`` projections. A direct input -> output edge is projected
    to ``channels`` and added to the output rather than fed into it. The
    output's ``channels`` are split as evenly as possible among interior
    vertices feeding the output. Other interior vertices use the max channel
    count of their successors.

    Args:
        input_fn: callable taking a channel count and returning a projection
            pair with ``'in'``/``'out'`` ports.
        node_fn: callable ``(num_inputs, vertex_index, num_channels)``
            returning an interior vertex pair.
        output_fn: callable taking the number of (non-direct) output inputs.
        h_connections: hyperparameters, one per vertex pair, in
            ``itertools.combinations(range(num_nodes + 2), 2)`` order.
        num_nodes: number of interior vertices.
        channels: output channel count of the cell.

    Raises (at substitution time):
        ValueError: if an interior vertex is only half-connected, the output
            is unreachable, or more than 9 edges are enabled.
    """
    def substitution_fn(dh):
        output_id = num_nodes + 1
        num_ins = [
            sum([dh['%d_%d' % (in_id, out_id)] for in_id in range(out_id)])
            for out_id in range(num_nodes + 2)
        ]
        num_outs = [
            sum([
                dh['%d_%d' % (in_id, out_id)]
                for out_id in range(in_id + 1, num_nodes + 2)
            ])
            for in_id in range(num_nodes + 2)
        ]
        for i in range(1, num_nodes + 1):
            if num_outs[i] > 0 and num_ins[i] == 0:
                raise ValueError('Node exists with no path to input')
            if num_ins[i] > 0 and num_outs[i] == 0:
                raise ValueError('Node exists with no path to output')
        if num_ins[-1] == 0:
            raise ValueError('No path exists between input and output')
        if sum(num_ins) > 9:
            raise ValueError('More than 9 edges')

        direct_edge = dh['%d_%d' % (0, output_id)]
        if direct_edge:
            # The direct edge is added after the output node, not fed into
            # it, so it does not count towards the output's in-degree.
            num_ins[-1] -= 1
        if num_ins[-1] == 0:
            # Only the direct input -> output edge is active, so every
            # interior vertex is isolated (checked above). The cell reduces
            # to a projection of its input. The channel split below would
            # otherwise divide by zero.
            return input_fn(channels)

        # Compute the number of channels that each vertex outputs
        int_channels = channels // num_ins[-1]
        correction = channels % num_ins[-1]
        vertex_channels = [0] * (num_nodes + 2)
        vertex_channels[-1] = channels
        # Distribute channels as evenly as possible among vertices connected
        # to output vertex
        for in_id in range(1, num_nodes + 1):
            if dh['%d_%d' % (in_id, output_id)]:
                vertex_channels[in_id] = int_channels
                if correction > 0:
                    vertex_channels[in_id] += 1
                    correction -= 1
        # For all other vertices, the number of channels they output is the
        # max of the channels outputted by all vertices they flow into.
        # (Vertex num_nodes can only flow into the output, so it is already
        # handled above.)
        for in_id in range(num_nodes - 1, 0, -1):
            if not dh['%d_%d' % (in_id, output_id)]:
                for out_id in range(in_id + 1, num_nodes + 1):
                    if dh['%d_%d' % (in_id, out_id)]:
                        vertex_channels[in_id] = max(vertex_channels[in_id],
                                                     vertex_channels[out_id])

        nodes = [mo.identity()]
        nodes += [
            node_fn(num_ins[i], i - 1, vertex_channels[i])
            if num_outs[i] > 0 else None for i in range(1, num_nodes + 1)
        ]
        nodes.append(output_fn(num_ins[output_id]))
        num_connected = [0] * (num_nodes + 2)

        # Project input vertex to correct dimensions if used
        for out_id in range(1, num_nodes + 1):
            if dh['%d_%d' % (0, out_id)]:
                proj_in, proj_out = input_fn(vertex_channels[out_id])
                nodes[0][1]['out'].connect(proj_in['in'])
                proj_out['out'].connect(
                    nodes[out_id][0]['in' + str(num_connected[out_id])])
                num_connected[out_id] += 1

        for in_id in range(1, num_nodes + 1):
            for out_id in range(in_id + 1, num_nodes + 2):
                if dh['%d_%d' % (in_id, out_id)]:
                    nodes[in_id][1]['out'].connect(
                        nodes[out_id][0]['in' + str(num_connected[out_id])])
                    num_connected[out_id] += 1

        if direct_edge:
            proj_in, proj_out = input_fn(channels)
            add_in, add_out = add(2)
            nodes[0][1]['out'].connect(proj_in['in'])
            proj_out['out'].connect(add_in['in0'])
            nodes[-1][1]['out'].connect(add_in['in1'])
            nodes[-1] = (add_in, add_out)

        return nodes[0][0], nodes[-1][1]

    name_to_hparam = {}
    for ix, hparam in enumerate(itertools.combinations(range(num_nodes + 2),
                                                       2)):
        name_to_hparam['%d_%d' % hparam] = h_connections[ix]
    return mo.substitution_module('NasbenchCell', substitution_fn,
                                  name_to_hparam, ['in'], ['out'], scope=None)
def conv_stage(filters, kernel_size, num_nodes):
    """Build a stage of convolution nodes wired by binary skip hyperparameters.

    The hyperparameter keyed by ``{"node_id": i, "in_node_id": j}``
    (``j < i``) enables an edge ``j -> i``. When no edge is enabled the
    stage is a single ``conv2d_cell``. Otherwise the stage is built as
    follows:

    * it starts with an input conv;
    * every node touching an enabled edge becomes a conv;
    * a node with enabled in-edges sums those inputs, and a node without
      them reads the input conv;
    * nodes with no enabled out-edges are summed into a final output conv;
    * nodes with no enabled edges at all are ignored.

    Args:
        filters, kernel_size: forwarded to every ``conv2d_cell`` call.
        num_nodes: number of internal nodes.
    """
    def sum_or_passthrough(node_id_to_outputs, in_ids):
        # Outputs dict carrying the sum of the listed nodes' outputs, or the
        # single node's own outputs when only one id is given.
        if len(in_ids) > 1:
            s_inputs, s_outputs = combine_with_sum(len(in_ids))
            for idx, j in enumerate(in_ids):
                s_inputs["in%d" % idx].connect(node_id_to_outputs[j]["out"])
            return s_outputs
        return node_id_to_outputs[in_ids[0]]

    def substitution_fn(dh):
        if not any(v for v in itervalues(dh)):
            # all zeros encoded.
            return conv2d_cell(filters, kernel_size)

        # Group the decoded edge descriptors by destination node.
        node_id_to_lst = {}
        for name, v in iteritems(dh):
            d = ut.json_string_to_json_object(name)
            d['use'] = v
            node_id_to_lst.setdefault(d["node_id"], []).append(d)

        node_ids_using_any = set()
        node_ids_used = set()
        for lst in itervalues(node_id_to_lst):
            for d in lst:
                if d['use']:
                    node_ids_using_any.add(d["node_id"])
                    node_ids_used.add(d["in_node_id"])
        node_ids_to_ignore = set(
            i for i in range(num_nodes)
            if i not in node_ids_using_any and i not in node_ids_used)

        # creating the stage
        in_inputs, in_outputs = conv2d_cell(filters, kernel_size)
        node_id_to_outputs = {}
        for i in range(num_nodes):
            if i in node_ids_to_ignore:
                continue
            if i in node_ids_using_any:
                in_ids = [
                    d['in_node_id'] for d in node_id_to_lst[i] if d['use']
                ]
                s_outputs = sum_or_passthrough(node_id_to_outputs, in_ids)
                n_inputs, n_outputs = conv2d_cell(filters, kernel_size)
                n_inputs["in"].connect(s_outputs["out"])
            else:
                n_inputs, n_outputs = conv2d_cell(filters, kernel_size)
                n_inputs["in"].connect(in_outputs['out'])
            node_id_to_outputs[i] = n_outputs

        # final connection to the output node
        out_ids = [
            i for i in range(num_nodes)
            if i not in node_ids_to_ignore and i not in node_ids_used
        ]
        s_outputs = sum_or_passthrough(node_id_to_outputs, out_ids)
        out_inputs, out_outputs = conv2d_cell(filters, kernel_size)
        out_inputs["in"].connect(s_outputs["out"])
        return in_inputs, out_outputs

    name_to_hyperp = {
        ut.json_object_to_json_string({
            "node_id": i,
            "in_node_id": j
        }): D([0, 1]) for i in range(num_nodes) for j in range(i)
    }
    return mo.substitution_module("ConvStage", substitution_fn, name_to_hyperp,
                                  ["in"], ["out"], scope=None)