示例#1
0
def miso_optional(fn, h_opt):
    """Wrap ``fn`` in a substitution module that can be bypassed.

    The module exposes inputs ``In0``/``In1`` and output ``Out``. When the
    substitution runs, a truthy value for the hyperparameter ``h_opt``
    (registered under the name ``'opt'``) selects the graph returned by
    ``fn()``; a falsy value selects the inputs/outputs of a fresh
    ``MISOIdentity``.

    Args:
        fn: Zero-argument callable returning an ``(inputs, outputs)`` pair.
        h_opt: Hyperparameter deciding whether ``fn`` is used.

    Returns:
        Whatever ``mo.substitution_module`` returns.
    """
    def substitution_fn(opt):
        # NOTE(review): this API variant presumably passes hyperparameter
        # values as keyword arguments named after the dict keys ('opt').
        if not opt:
            return MISOIdentity().get_io()
        return fn()

    hyperparams = {'opt': h_opt}
    input_names = ['In0', 'In1']
    output_names = ['Out']
    return mo.substitution_module("MISOOptional", hyperparams,
                                  substitution_fn, input_names, output_names,
                                  scope=None)
def create_motif(
    dh,
    full_hparam_dict,
    motif_info,
    level,
    num_channels,
):
    """Build a hierarchical motif at the given level.

    At level 1, ``dh['operation']`` (1-based) selects one primitive:
    1x1 conv, 3x3 depthwise conv, 3x3 separable conv, 3x3 max pool,
    3x3 avg pool or identity. The primitive's ``(inputs, outputs)`` pair is
    returned directly.

    At higher levels a substitution module named ``Motif_Level_<level>`` is
    returned. When substituted it wires ``motif_info[level]['num_nodes']``
    nodes into a DAG:

    * node 0 is an identity;
    * for every pair ``in_node < out_node``, the value stored in ``dh`` under
      the JSON string of ``{"out_node_id": ..., "in_node_id": ...}`` selects
      an operation (a non-positive value means no edge);
    * every node concatenates the operations feeding it;
    * at level 2, a trailing 1x1 conv follows the last node.

    Args:
        dh: Hyperparameter values (level 1) or this motif's hyperparameters
            (higher levels).
        full_hparam_dict: Per-level lists of hyperparameter dicts. Entry
            ``[level - 1][op_id - 1]`` configures a lower-level motif.
        motif_info: Per-level dicts holding at least ``'num_nodes'``.
        level: Motif level (1 = primitive operations).
        num_channels: Channel count used by convolutions and concatenations.
    """
    if level == 1:
        # Level-1 motifs are primitive operations chosen by a 1-based id.
        primitive_builders = [
            lambda: conv2d_cell(num_channels, 1),  # 1x1 conv, C channels
            lambda: depthwise_conv2d_cell(3),  # 3x3 depthwise conv
            lambda: separable_conv2d_cell(num_channels, 3),  # 3x3 sep conv
            lambda: max_pooling(3),  # 3x3 max pool
            lambda: average_pooling(3),  # 3x3 avg pool
            lambda: mo.identity(),  # identity
        ]
        op_index = dh['operation'] - 1
        return primitive_builders[op_index]()

    def substitution_fn(dh):
        num_nodes = motif_info[level]['num_nodes']
        # node_ops[k]: (inputs, outputs) of every operation feeding node k.
        node_ops = [[] for _ in range(num_nodes)]
        node_ops[0].append(mo.identity())
        # node_merges[k]: (inputs, outputs) of the concat closing node k.
        node_merges = []
        for dst in range(num_nodes):
            for src in range(dst):
                edge_key = ut.json_object_to_json_string({
                    "out_node_id": dst,
                    "in_node_id": src,
                })
                op_id = dh[edge_key]
                if op_id > 0:
                    if level == 2:
                        child_hparams = {'operation': op_id}
                    else:
                        child_hparams = full_hparam_dict[level - 1][op_id - 1]
                    child = create_motif(child_hparams, full_hparam_dict,
                                         motif_info, level - 1, num_channels)
                    node_ops[dst].append(child)
                    child[0]['in'].connect(node_merges[src][1]['out'])
            assert (len(node_ops[dst]) > 0)
            merge_ins, merge_outs = combine_with_concat(
                len(node_ops[dst]), num_channels)
            for port, (_, op_outs) in enumerate(node_ops[dst]):
                op_outs['out'].connect(merge_ins['in%d' % port])
            node_merges.append((merge_ins, merge_outs))

        final_outs = node_merges[-1][1]
        if level == 2:
            # Level-2 motifs end with a 1x1 conv after the last node.
            conv_ins, conv_outs = conv2d_cell(num_channels, 1)
            final_outs['out'].connect(conv_ins['in'])
            final_outs = conv_outs
        return node_ops[0][0][0], final_outs

    return mo.substitution_module('Motif_Level_%d' % level, substitution_fn,
                                  dh, ['in'], ['out'], None)
示例#3
0
def genetic_stage(input_fn, node_fn, output_fn, h_connections, num_nodes):
    """Build a Genetic-CNN-style stage as a substitution module.

    Interior nodes are numbered ``1..num_nodes``. ``h_connections`` holds one
    hyperparameter per node pair ``(in_id, out_id)`` with ``in_id < out_id``,
    in ``itertools.combinations`` order. Each is registered under the name
    ``'<in_id>_<out_id>'``. Edge values are summed to obtain node degrees and
    compared with ``1`` to decide whether an edge exists, so 0/1 values are
    expected.

    When substituted:

    * if no edge is set, the stage is just ``input_fn()``;
    * nodes without any edge are dropped;
    * every remaining node is built with ``node_fn(max(in_degree, 1))``;
    * nodes without incoming edges read the input node's ``'out'`` on their
      ``'in0'`` port;
    * nodes without outgoing edges feed ``output_fn(k)``, where ``k`` is the
      number of such nodes.

    Args:
        input_fn: Zero-argument callable returning ``(inputs, outputs)`` for
            the stage input node.
        node_fn: ``node_fn(num_inputs)`` returning ``(inputs, outputs)`` with
            ports ``'in0'``, ``'in1'``, ... and ``'out'``.
        output_fn: ``output_fn(num_inputs)`` returning ``(inputs, outputs)``
            for the stage output node.
        h_connections: Edge hyperparameters, one per node pair (see above).
        num_nodes: Number of interior nodes.

    Returns:
        The substitution module built by ``mo.substitution_module``.
    """

    def substitution_fn(dh):
        # num_ins[k] / num_outs[k] are the in- and out-degree of node k + 1.
        num_ins = [
            sum([dh['%d_%d' % (in_id, out_id)]
                 for in_id in range(1, out_id)])
            for out_id in range(1, num_nodes + 1)
        ]
        num_outs = [
            sum([
                dh['%d_%d' % (in_id, out_id)]
                for out_id in range(in_id + 1, num_nodes + 1)
            ])
            for in_id in range(1, num_nodes + 1)
        ]
        # No edges at all: the stage reduces to its input node.
        if sum(num_ins) == 0:
            return input_fn()

        # nodes[0] is the input node, nodes[1..num_nodes] the interior nodes
        # (None when isolated) and nodes[-1] the output node.
        nodes = [input_fn()]
        nodes += [
            node_fn(max(num_ins[i], 1))
            if num_ins[i] > 0 or num_outs[i] > 0 else None
            for i in range(num_nodes)
        ]
        nodes.append(
            output_fn(
                len([
                    i for i in range(len(num_outs))
                    if num_outs[i] == 0 and nodes[i + 1] is not None
                ])))

        # Next free 'in<k>' port index of each node.
        num_connected = [0] * (num_nodes + 2)
        for in_id in range(1, num_nodes + 1):
            # Connect nodes with no input to original input
            if num_ins[in_id - 1] == 0 and nodes[in_id] is not None:
                nodes[0][1]['out'].connect(nodes[in_id][0]['in0'])
            # Connect nodes with no output to final output node
            if num_outs[in_id - 1] == 0 and nodes[in_id] is not None:
                nodes[in_id][1]['out'].connect(nodes[-1][0]['in%d' %
                                                            num_connected[-1]])
                num_connected[-1] += 1

            # Connect internal nodes
            for out_id in range(in_id + 1, num_nodes + 1):
                if dh['%d_%d' % (in_id, out_id)] == 1:
                    nodes[in_id][1]['out'].connect(
                        nodes[out_id][0]['in' + str(num_connected[out_id])])
                    num_connected[out_id] += 1
        return nodes[0][0], nodes[-1][1]

    name_to_hparam = {}
    for ix, hparam in enumerate(
            itertools.combinations(range(1, num_nodes + 1), 2)):
        name_to_hparam['%d_%d' % hparam] = h_connections[ix]
    return mo.substitution_module('GeneticStage',
                                  substitution_fn,
                                  name_to_hparam, ['in'], ['out'],
                                  scope=None)
示例#4
0
def miso_optional(fn, h_opt):
    """Wrap ``fn`` in a substitution module that can be bypassed.

    The module exposes inputs ``in0``/``in1`` and output ``out``. When the
    substitution runs, a truthy value for ``h_opt`` (registered under the
    name ``'opt'``) selects the graph returned by ``fn()``. Otherwise the
    inputs/outputs of a fresh ``MISOIdentity`` are used.

    Args:
        fn: Zero-argument callable returning an ``(inputs, outputs)`` pair.
        h_opt: Hyperparameter deciding whether ``fn`` is used.

    Returns:
        Whatever ``mo.substitution_module`` returns.
    """
    def substitution_fn(dh):
        use_fn = dh["opt"]
        if use_fn:
            return fn()
        return MISOIdentity().get_io()

    return mo.substitution_module(
        "MISOOptional",
        substitution_fn,
        {'opt': h_opt},
        ['in0', 'in1'],
        ['out'],
        scope=None,
    )
示例#5
0
def motif(submotif_fn, num_nodes):
    """Build a chain-shaped motif with optional skip connections.

    Node 0 is an identity. Each node ``i`` in ``1..num_nodes - 1`` always
    reads from node ``i - 1``. It may additionally read from any node
    ``j < i - 1``, controlled by a ``D([0, 1])`` hyperparameter keyed by the
    JSON string of ``{"node_id": i, "in_node_id": j}``.

    Every incoming edge gets its own ``submotif_fn()`` instance. When a node
    has several incoming edges, their outputs are merged with
    ``combine_with_concat``. The motif's output is the last node's output.

    Args:
        submotif_fn: Zero-argument callable returning ``(inputs, outputs)``
            with ports ``'in'`` and ``'out'``.
        num_nodes: Number of nodes including the identity input node; must
            be at least 1.

    Returns:
        The substitution module built by ``mo.substitution_module``.
    """
    assert num_nodes >= 1

    def substitution_fn(dh):
        # For each node, the (sorted) ids of the nodes it reads from.
        node_id_to_node_ids_used = {i: [i - 1] for i in range(1, num_nodes)}
        for name, v in dh.items():
            if v:
                d = ut.json_string_to_json_object(name)
                node_id_to_node_ids_used[d["node_id"]].append(d["in_node_id"])
        for i in range(1, num_nodes):
            node_id_to_node_ids_used[i] = sorted(node_id_to_node_ids_used[i])

        (inputs, outputs) = mo.identity()
        node_id_to_outputs = [outputs]
        in_inputs = inputs
        for i in range(1, num_nodes):
            node_ids_used = node_id_to_node_ids_used[i]
            num_edges = len(node_ids_used)

            # One fresh submotif per incoming edge.
            outputs_lst = []
            for j in node_ids_used:
                sub_inputs, sub_outputs = submotif_fn()
                sub_inputs["in"].connect(node_id_to_outputs[j]["out"])
                outputs_lst.append(sub_outputs)

            # If necessary, concatenate the results going into a node.
            if num_edges > 1:
                c_inputs, c_outputs = combine_with_concat(num_edges)
                for idx, sub_outputs in enumerate(outputs_lst):
                    c_inputs["in%d" % idx].connect(sub_outputs["out"])
            else:
                c_outputs = outputs_lst[0]
            node_id_to_outputs.append(c_outputs)

        return in_inputs, node_id_to_outputs[-1]

    name_to_hyperp = {
        ut.json_object_to_json_string({
            "node_id": i,
            "in_node_id": j
        }): D([0, 1]) for i in range(1, num_nodes) for j in range(i - 1)
    }
    return mo.substitution_module(
        "Motif", substitution_fn, name_to_hyperp, ["in"], ["out"], scope=None)
示例#6
0
def enas_space(h_num_layers,
               out_filters,
               fn_first,
               fn_repeats,
               input_names,
               output_names,
               weight_sharer,
               scope=None):
    """Build a stack of repeated layers as a substitution module.

    The single hyperparameter ``h_num_layers`` (registered as
    ``'num_layers'``) sets how many times ``fn_repeats`` is applied after
    ``fn_first``. The substituted graph exposes a single output ``'out'``. It
    is taken from the outputs of the last ``fn_repeats`` call under the key
    ``'out<len(outputs) - 1>'``.

    Args:
        h_num_layers: Hyperparameter for the number of repeated layers; its
            value must be positive.
        out_filters: Forwarded to every ``fn_repeats`` call.
        fn_first: Zero-argument callable returning ``(inputs, outputs)`` for
            the first layer.
        fn_repeats: Called as ``fn_repeats(inputs, outputs, layer_index,
            out_filters, weight_sharer)``. Returns the new
            ``(inputs, outputs)``. ``layer_index`` runs from 1 up to the
            number of layers.
        input_names: Input port names of the substitution module.
        output_names: Output port names of the substitution module.
        weight_sharer: Forwarded to every ``fn_repeats`` call.
        scope: Forwarded to ``mo.substitution_module``.
    """
    def substitution_fn(dh):
        num_layers = dh["num_layers"]
        assert num_layers > 0
        inputs, first_outputs = fn_first()
        layer_outputs = OrderedDict(first_outputs)
        for layer_idx in range(1, num_layers + 1):
            inputs, layer_outputs = fn_repeats(inputs, layer_outputs,
                                               layer_idx, out_filters,
                                               weight_sharer)
        last_key = 'out' + str(len(layer_outputs) - 1)
        return inputs, OrderedDict({'out': layer_outputs[last_key]})

    hyperparams = {'num_layers': h_num_layers}
    return mo.substitution_module('ENASModule', substitution_fn, hyperparams,
                                  input_names, output_names, scope)
示例#7
0
def cell(input_fn, node_fn, combine_fn, unused_combine_fn, num_nodes,
         hyperparameters):
    """Build a cell that repeatedly combines pairs of hidden states.

    The hidden states start as outputs ``'Out0'``/``'Out1'`` of
    ``input_fn()``. Each of the ``num_nodes`` nodes appends one new hidden
    state.

    For node ``k``, hyperparameters ``'k_0'`` and ``'k_1'`` (taken from
    ``hyperparameters[2k]`` and ``hyperparameters[2k + 1]``) pick two existing
    hidden states by index. Each chosen state goes through
    ``node_fn(index, k, branch)`` (port ``'In'`` -> ``'Out'``). The two
    results are merged by ``combine_fn()`` (ports ``'In0'``, ``'In1'`` ->
    ``'Out'``).

    Hidden states never picked are merged by ``unused_combine_fn(count)``
    through ports ``'In0'``, ``'In1'``, ..., whose outputs become the cell's
    outputs.

    Args:
        input_fn: Zero-argument callable returning the cell's
            ``(inputs, outputs)``.
        node_fn: Builds the transform applied to a chosen hidden state.
        combine_fn: Builds the two-input merge for each node.
        unused_combine_fn: Builds the merge over all unused hidden states.
        num_nodes: Number of nodes in the cell.
        hyperparameters: Flat list of index hyperparameters; presumably of
            length ``2 * num_nodes`` — TODO confirm with callers.
    """
    def substitution_fn(**dh):
        cell_ins, cell_outs = input_fn()
        hidden = [cell_outs['Out0'], cell_outs['Out1']]
        consumed = [False] * (num_nodes + 2)
        for node_idx in range(num_nodes):
            # Indices of the two hidden states combined by this node.
            src0 = dh[str(node_idx) + '_0']
            src1 = dh[str(node_idx) + '_1']

            branch0 = node_fn(src0, node_idx, 0)
            branch1 = node_fn(src1, node_idx, 1)
            branch0[0]['In'].connect(hidden[src0])
            branch1[0]['In'].connect(hidden[src1])
            consumed[src0] = consumed[src1] = True

            merged = combine_fn()
            merged[0]['In0'].connect(branch0[1]['Out'])
            merged[0]['In1'].connect(branch1[1]['Out'])
            hidden.append(merged[1]['Out'])

        leftovers = [state for state, used in zip(hidden, consumed)
                     if not used]
        final_ins, final_outs = unused_combine_fn(len(leftovers))
        for port, state in enumerate(leftovers):
            final_ins['In' + str(port)].connect(state)
        return cell_ins, final_outs

    name_to_hyperp = {
        '%d_%d' % divmod(pos, 2): hparam
        for pos, hparam in enumerate(hyperparameters)
    }
    return mo.substitution_module('Cell', name_to_hyperp, substitution_fn,
                                  ['In0', 'In1'], ['Out'], None)
示例#8
0
def cell(input_fn, node_fn, output_fn, h_connections, num_nodes, channels,
         max_edges=9):
    """Build a NASBench-style cell as a substitution module.

    Vertices are numbered ``0`` (input), ``1..num_nodes`` (interior) and
    ``num_nodes + 1`` (output). ``h_connections`` holds one hyperparameter per
    vertex pair ``(i, j)`` with ``i < j``, in ``itertools.combinations``
    order. Each is registered under the name ``'<i>_<j>'``; a truthy value
    adds the edge ``i -> j``.

    Channel allocation:

    * ``channels`` is split as evenly as possible among the interior vertices
      feeding the output;
    * every other interior vertex uses the maximum channel count of its
      successors;
    * edges leaving the input vertex go through ``input_fn`` projections;
    * a direct input -> output edge is added to the output via ``add(2)``.

    Args:
        input_fn: ``input_fn(c)`` returns ``(ins, outs)`` with ports ``'in'``
            and ``'out'``. It projects the cell input to ``c`` channels.
        node_fn: ``node_fn(num_inputs, node_index, c)`` builds an interior
            vertex with ports ``'in0'``, ``'in1'``, ... and ``'out'``.
        output_fn: ``output_fn(num_inputs)`` builds the output vertex.
        h_connections: Edge hyperparameters (see above).
        num_nodes: Number of interior vertices.
        channels: Number of output channels of the cell.
        max_edges: Maximum number of edges allowed in a sampled graph.
            Defaults to 9, the previously hard-coded limit.

    Raises:
        ValueError: At substitution time, if an interior vertex has edges on
            only one side, if no path reaches the output, or if the graph
            has more than ``max_edges`` edges.
    """
    def substitution_fn(dh):
        # num_ins[v] / num_outs[v]: in- and out-degree of vertex v.
        num_ins = [
            sum([dh['%d_%d' % (in_id, out_id)] for in_id in range(out_id)])
            for out_id in range(num_nodes + 2)
        ]
        num_outs = [
            sum([
                dh['%d_%d' % (in_id, out_id)]
                for out_id in range(in_id + 1, num_nodes + 2)
            ]) for in_id in range(num_nodes + 2)
        ]

        for i in range(1, num_nodes + 1):
            if num_outs[i] > 0 and num_ins[i] == 0:
                raise ValueError('Node exists with no path to input')
            if num_ins[i] > 0 and num_outs[i] == 0:
                raise ValueError('Node exists with no path to output')
        if num_ins[-1] == 0:
            raise ValueError('No path exists between input and output')

        if sum(num_ins) > max_edges:
            raise ValueError('More than %d edges' % max_edges)

        # A direct input -> output edge is handled separately (its projection
        # is added to the output below), so it is not an output_fn input.
        if dh['%d_%d' % (0, num_nodes + 1)]:
            num_ins[-1] -= 1
        # The direct edge was the only edge into the output. Given the checks
        # above, every interior vertex is then isolated, so the cell is just
        # the input projected to `channels`. Without this early return, the
        # channel split below would divide by zero.
        if num_ins[-1] == 0:
            return input_fn(channels)

        # Compute the number of channels that each vertex outputs
        int_channels = channels // num_ins[-1]
        correction = channels % num_ins[-1]

        vertex_channels = [0] * (num_nodes + 2)
        vertex_channels[-1] = channels
        # Distribute channels as evenly as possible among vertices connected
        # to output vertex
        for in_id in range(1, num_nodes + 1):
            if dh['%d_%d' % (in_id, num_nodes + 1)]:
                vertex_channels[in_id] = int_channels
                if correction > 0:
                    vertex_channels[in_id] += 1
                    correction -= 1
        # For all other vertices, the number of channels they output is the max
        # of the channels outputted by all vertices they flow into. Vertex
        # num_nodes can only feed the output, so it is covered above.
        for in_id in range(num_nodes - 1, 0, -1):
            if not dh['%d_%d' % (in_id, num_nodes + 1)]:
                for out_id in range(in_id + 1, num_nodes + 1):
                    if dh['%d_%d' % (in_id, out_id)]:
                        vertex_channels[in_id] = max(vertex_channels[in_id],
                                                     vertex_channels[out_id])

        # nodes[v] is the (inputs, outputs) pair of vertex v, or None when an
        # interior vertex is unused.
        nodes = [mo.identity()]
        nodes += [
            node_fn(num_ins[i], i -
                    1, vertex_channels[i]) if num_outs[i] > 0 else None
            for i in range(1, num_nodes + 1)
        ]
        nodes.append(output_fn(num_ins[num_nodes + 1]))
        # Next free 'in<k>' port index of each vertex.
        num_connected = [0] * (num_nodes + 2)
        # Project input vertex to correct dimensions if used
        for out_id in range(1, num_nodes + 1):
            if dh['%d_%d' % (0, out_id)]:
                proj_in, proj_out = input_fn(vertex_channels[out_id])
                nodes[0][1]['out'].connect(proj_in['in'])
                proj_out['out'].connect(
                    nodes[out_id][0]['in' + str(num_connected[out_id])])
                num_connected[out_id] += 1
        for in_id in range(1, num_nodes + 1):
            for out_id in range(in_id + 1, num_nodes + 2):
                if dh['%d_%d' % (in_id, out_id)]:
                    nodes[in_id][1]['out'].connect(
                        nodes[out_id][0]['in' + str(num_connected[out_id])])
                    num_connected[out_id] += 1
        if dh['%d_%d' % (0, num_nodes + 1)]:
            # Residual-style direct edge: add the projected input to the
            # output vertex's result.
            proj_in, proj_out = input_fn(channels)
            add_in, add_out = add(2)
            nodes[0][1]['out'].connect(proj_in['in'])
            proj_out['out'].connect(add_in['in0'])
            nodes[-1][1]['out'].connect(add_in['in1'])
            nodes[-1] = (add_in, add_out)

        return nodes[0][0], nodes[-1][1]

    name_to_hparam = {}
    for ix, hparam in enumerate(itertools.combinations(range(num_nodes + 2),
                                                       2)):
        name_to_hparam['%d_%d' % hparam] = h_connections[ix]
    return mo.substitution_module('NasbenchCell',
                                  substitution_fn,
                                  name_to_hparam, ['in'], ['out'],
                                  scope=None)
示例#9
0
def conv_stage(filters, kernel_size, num_nodes):
    """Build a Genetic-CNN stage of ``conv2d_cell`` nodes.

    One ``D([0, 1])`` hyperparameter is created per pair
    ``in_node_id < node_id`` of the ``num_nodes`` interior nodes. It is keyed
    by the JSON string of ``{"node_id": ..., "in_node_id": ...}``; a truthy
    value means node ``in_node_id`` feeds node ``node_id``.

    When substituted:

    * if no edge is active, the stage is a single ``conv2d_cell``;
    * otherwise an input conv and an output conv are created;
    * interior nodes without any active edge are skipped;
    * nodes without incoming edges read from the input conv;
    * nodes whose output is not consumed by another node feed the output
      conv;
    * multiple inputs to a node (or to the output conv) are merged with
      ``combine_with_sum``.

    Args:
        filters: Filter count passed to every ``conv2d_cell``.
        kernel_size: Kernel size passed to every ``conv2d_cell``.
        num_nodes: Number of interior nodes.

    Returns:
        The substitution module built by ``mo.substitution_module``.
    """
    def substitution_fn(dh):
        any_in_stage = any(v for v in itervalues(dh))

        if any_in_stage:
            # Group the edge descriptions by destination node.
            node_id_to_lst = {}
            for name, v in iteritems(dh):
                d = ut.json_string_to_json_object(name)
                d['use'] = v
                node_id_to_lst.setdefault(d["node_id"], []).append(d)

            # Nodes with at least one active incoming edge, and nodes whose
            # output is consumed by at least one other node.
            node_ids_using_any = set()
            node_ids_used = set()
            for lst in itervalues(node_id_to_lst):
                for d in lst:
                    if d['use']:
                        node_ids_using_any.add(d["node_id"])
                        node_ids_used.add(d["in_node_id"])

            # Nodes without any active edge take no part in the stage.
            node_ids_to_ignore = set([
                i for i in range(num_nodes)
                if i not in node_ids_using_any and i not in node_ids_used
            ])

            # creating the stage
            (in_inputs, in_outputs) = conv2d_cell(filters, kernel_size)
            node_id_to_outputs = {}
            for i in range(num_nodes):
                if i not in node_ids_to_ignore:
                    if i in node_ids_using_any:
                        in_ids = []
                        for d in node_id_to_lst[i]:
                            if d['use']:
                                in_ids.append(d['in_node_id'])
                        # collecting the inputs for the sum combiner.
                        num_inputs = len(in_ids)
                        if num_inputs > 1:
                            (s_inputs,
                             s_outputs) = combine_with_sum(num_inputs)
                            for idx, j in enumerate(in_ids):
                                j_outputs = node_id_to_outputs[j]
                                s_inputs["in%d" % idx].connect(
                                    j_outputs["out"])
                        else:
                            j = in_ids[0]
                            s_outputs = node_id_to_outputs[j]

                        (n_inputs,
                         n_outputs) = conv2d_cell(filters, kernel_size)
                        n_inputs["in"].connect(s_outputs["out"])
                    else:
                        # No incoming edges: read from the input conv.
                        (n_inputs,
                         n_outputs) = conv2d_cell(filters, kernel_size)
                        n_inputs["in"].connect(in_outputs['out'])
                    node_id_to_outputs[i] = n_outputs

            # final connection to the output node
            in_ids = []
            for i in range(num_nodes):
                if i not in node_ids_to_ignore and i not in node_ids_used:
                    in_ids.append(i)
            num_inputs = len(in_ids)
            if num_inputs > 1:
                (s_inputs, s_outputs) = combine_with_sum(num_inputs)
                for idx, j in enumerate(in_ids):
                    j_outputs = node_id_to_outputs[j]
                    s_inputs["in%d" % idx].connect(j_outputs["out"])
            else:
                j = in_ids[0]
                s_outputs = node_id_to_outputs[j]

            (out_inputs, out_outputs) = conv2d_cell(filters, kernel_size)
            out_inputs["in"].connect(s_outputs["out"])
            return (in_inputs, out_outputs)

        else:
            # all zeros encoded.
            return conv2d_cell(filters, kernel_size)

    name_to_hyperp = {
        ut.json_object_to_json_string({
            "node_id": i,
            "in_node_id": j
        }): D([0, 1])
        for i in range(num_nodes) for j in range(i)
    }
    return mo.substitution_module("ConvStage",
                                  substitution_fn,
                                  name_to_hyperp, ["in"], ["out"],
                                  scope=None)