예제 #1
0
def create_motif(
    dh,
    full_hparam_dict,
    motif_info,
    level,
    num_channels,
):
    """Create the motif for ``level`` of a hierarchical search space.

    At level 1 the motif is a single primitive operation, chosen by the
    1-based index ``dh['operation']`` and instantiated immediately.

    At higher levels a substitution module named ``'Motif_Level_<level>'``
    is returned. Once its hyperparameter values are available, it builds a
    DAG with ``motif_info[level]['num_nodes']`` nodes. For every pair
    ``in_node < out_node``, the value stored under the JSON string of
    ``{"out_node_id": ..., "in_node_id": ...}`` selects the lower-level motif
    placed on that edge; non-positive ids add no edge. The incoming edges of
    each node are merged with ``combine_with_concat``. At level 2 a 1x1
    ``conv2d_cell`` is appended after the last node.

    Args:
        dh: At level 1, a dict holding ``'operation'``; at higher levels,
            the dict handed to ``mo.substitution_module``.
        full_hparam_dict: Per-level lists; entry ``op_id - 1`` of level
            ``level - 1`` is passed as ``dh`` to the child motif.
        motif_info: Per-level info; ``motif_info[level]['num_nodes']`` is
            the node count of the motif DAG.
        level: Hierarchy level of the motif to create.
        num_channels: Channel count passed to the conv/concat helpers.

    Returns:
        The ``(inputs, outputs)`` pair of the created module.
    """
    if level == 1:
        # Factories rather than instances, so only the selected primitive
        # is actually constructed.
        primitive_factories = (
            lambda: conv2d_cell(num_channels, 1),  # 1x1 conv of C channels
            lambda: depthwise_conv2d_cell(3),  # 3x3 depthwise conv
            lambda: separable_conv2d_cell(num_channels, 3),  # 3x3 sep conv
            lambda: max_pooling(3),  # 3x3 max pool
            lambda: average_pooling(3),  # 3x3 avg pool
            lambda: mo.identity(),
        )
        chosen = primitive_factories[dh['operation'] - 1]
        return chosen()

    def substitution_fn(dh):
        # Build the motif DAG once the values in ``dh`` are known; returns
        # (inputs of node 0's identity, outputs of the whole motif).
        num_nodes = motif_info[level]['num_nodes']
        node_ops = [[] for _ in range(num_nodes)]
        # Node 0 receives the motif input through a single identity.
        node_ops[0].append(mo.identity())
        node_results = []  # (concat_ins, concat_outs) for each node
        for dst in range(num_nodes):
            for src in range(dst):
                edge_key = ut.json_object_to_json_string({
                    "out_node_id": dst,
                    "in_node_id": src,
                })
                op_id = dh[edge_key]
                if op_id > 0:
                    if level == 2:
                        child_hparams = {'operation': op_id}
                    else:
                        child_hparams = full_hparam_dict[level - 1][op_id - 1]
                    child = create_motif(
                        child_hparams,
                        full_hparam_dict,
                        motif_info,
                        level - 1,
                        num_channels,
                    )
                    node_ops[dst].append(child)
                    child[0]['in'].connect(node_results[src][1]['out'])
            assert node_ops[dst]
            concat_ins, concat_outs = combine_with_concat(
                len(node_ops[dst]), num_channels)
            for idx, (_, op_outs) in enumerate(node_ops[dst]):
                op_outs['out'].connect(concat_ins['in%d' % idx])
            node_results.append((concat_ins, concat_outs))
        motif_out = node_results[-1][1]
        if level == 2:
            conv_ins, conv_outs = conv2d_cell(num_channels, 1)
            motif_out['out'].connect(conv_ins['in'])
            motif_out = conv_outs
        return node_ops[0][0][0], motif_out

    return mo.substitution_module('Motif_Level_%d' % level, substitution_fn,
                                  dh, ['in'], ['out'], None)
예제 #2
0
def cell_input_fn(filters):
    """Create the two-input front end of a cell.

    'In0' passes through an identity. 'In1' goes through
    ``conv2d(D([filters]), D([1]))`` wrapped by ``wrap_relu_batch_norm``.
    ``maybe_factorized_reduction(add_relu=True)`` receives In0 on its 'In0'
    and the processed In1 on its 'In1'; its output becomes 'Out0'. The
    processed In1 is also exposed directly as 'Out1'.

    Args:
        filters: First hyperparameter value given to ``conv2d`` (presumably
            the number of output filters -- TODO confirm against conv2d).

    Returns:
        ``(inputs, outputs)`` dicts keyed 'In0'/'In1' and 'Out0'/'Out1'.
    """
    prev_ins, prev_outs = mo.identity()
    cur_ins, cur_outs = wrap_relu_batch_norm(conv2d(D([filters]), D([1])))
    reduce_ins, reduce_outs = maybe_factorized_reduction(add_relu=True)
    # Feed both the raw In0 and the processed In1 into the reduction.
    reduce_ins['In0'].connect(prev_outs['Out'])
    reduce_ins['In1'].connect(cur_outs['Out'])
    cell_inputs = {'In0': prev_ins['In'], 'In1': cur_ins['In']}
    cell_outputs = {'Out0': reduce_outs['Out'], 'Out1': cur_outs['Out']}
    return cell_inputs, cell_outputs
예제 #3
0
def combine_unused(num_ins):
    """Concatenate ``num_ins`` inputs after a factorized reduction of each.

    Input i passes through its own identity and is fed to its own
    ``maybe_factorized_reduction()`` as 'In0'. The last input is fed to
    every reduction as 'In1'. Reduction i's output is wired to slot 'In<i>'
    of ``concat(num_ins)``. ``num_ins`` must be at least 1, because the last
    input is indexed.

    Returns:
        ``(inputs, outputs)``: a dict mapping 'In0'..'In<num_ins-1>' to the
        identity inputs, and the concat module's outputs.
    """
    identities = [mo.identity() for _ in range(num_ins)]
    reductions = [maybe_factorized_reduction() for _ in range(num_ins)]
    concat_ins, concat_outs = concat(num_ins)
    _, last_outs = identities[-1]
    exposed_inputs = {}
    for idx, ((ident_ins, ident_outs), reduction) in enumerate(
            zip(identities, reductions)):
        reduction[0]['In0'].connect(ident_outs['Out'])
        reduction[0]['In1'].connect(last_outs['Out'])
        exposed_inputs['In%d' % idx] = ident_ins['In']
        concat_ins['In%d' % idx].connect(reduction[1]['Out'])
    return exposed_inputs, concat_outs
예제 #4
0
def generate_search_space(num_nodes_per_cell, num_normal_cells,
                          num_reduction_cells, init_filters, stem_multiplier):
    """Build the full cell-based search space.

    Resets the module globals ``global_vars`` and ``hp_sharer`` and
    registers a shared 'drop_path_keep_prob' hyperparameter. It then
    chains a stem, ``num_normal_cells`` normal cells, and
    ``num_reduction_cells`` reduction cells spread evenly over them. Each
    reduction cell doubles ``filters`` and is inserted just before the
    normal cell at its position. An auxiliary-logits head is attached
    after the normal cell at ``aux_loss_idx``.

    NOTE(review): if two reductions map to the same position (e.g. more
    reductions than normal cells), only one reduction cell is created there,
    while ``total_cells`` still counts both.

    Returns:
        ``(inputs, outputs)``. Inputs are 'In0' (stem) and 'In1' (progress,
        also stored in ``global_vars['progress']``). Outputs are 'Out1'
        (relu -> global pool -> dropout -> fc_layer(D([10]))) and, only if
        the aux index falls inside the normal cells, 'Out0' (aux logits).
    """
    global global_vars, hp_sharer
    global_vars = {}
    hp_sharer = hp.HyperparameterSharer()
    hp_sharer.register('drop_path_keep_prob',
                       lambda: D([.7], name='drop_path_keep_prob'))
    stem_in, stem_out = stem(int(init_filters * stem_multiplier))
    progress_in, progress_out = mo.identity()
    global_vars['progress'] = progress_out['Out']
    make_normal_cell = create_cell_generator(num_nodes_per_cell, False)
    make_reduction_cell = create_cell_generator(num_nodes_per_cell, True)

    total_cells = num_normal_cells + num_reduction_cells
    # reduction_before[k] is True when a reduction cell precedes normal
    # cell k; reductions are placed at evenly spaced fractions.
    reduction_before = [False] * num_normal_cells
    for r in range(1, num_reduction_cells + 1):
        position = int(
            float(r) / (num_reduction_cells + 1) * num_normal_cells)
        reduction_before[position] = True

    inputs = [stem_out, stem_out]
    filters = init_filters
    aux_loss_idx = int(
        float(num_reduction_cells) /
        (num_reduction_cells + 1) * num_normal_cells) - 1

    outs = {}
    cells_created = 0.0
    for cell_idx, preceded_by_reduction in enumerate(reduction_before):
        if preceded_by_reduction:
            filters *= 2
            reduction_cell = make_reduction_cell(
                filters, (cells_created + 1) / total_cells)
            connect_new_cell(reduction_cell, inputs)
            cells_created += 1.0
        normal_cell = make_normal_cell(
            filters, (cells_created + 1) / total_cells)
        connect_new_cell(normal_cell, inputs)
        cells_created += 1.0
        if cell_idx == aux_loss_idx:
            aux_in, aux_out = aux_logits()
            aux_in['In'].connect(inputs[-1]['Out'])
            outs['Out0'] = aux_out['Out']

    head = [(None, inputs[-1]),
            relu(),
            global_pool2d(),
            dropout(D([1.0])),
            fc_layer(D([10]))]
    _, final_out = mo.siso_sequential(head)
    outs['Out1'] = final_out['Out']
    return {'In0': stem_in['In'], 'In1': progress_in['In']}, outs
예제 #5
0
    def substitution_fn(dh):
        """Wire a chain of ``num_nodes`` nodes plus optional skip edges.

        Node i (for i >= 1) always takes node i - 1 as an input. Every entry
        of ``dh`` with a truthy value adds an extra edge; its key is a JSON
        string whose "in_node_id" / "node_id" fields give the edge's source
        and destination. Each edge gets a fresh ``submotif_fn()``. When a
        node has more than one incoming edge, the edge outputs are merged
        with ``combine_with_concat``.

        Returns:
            The inputs of node 0's identity and the outputs of the last node.
        """
        # The leftover debug ``print`` statements were removed: they were
        # Python 2-only syntax and dumped internals on every substitution.
        node_id_to_node_ids_used = {i: [i - 1] for i in range(1, num_nodes)}
        for name, is_used in dh.items():
            if is_used:
                edge = ut.json_string_to_json_object(name)
                node_id_to_node_ids_used[edge["node_id"]].append(
                    edge["in_node_id"])
        # Sort so that sub-motifs are created in a deterministic order
        # regardless of the iteration order of ``dh``.
        for i in range(1, num_nodes):
            node_id_to_node_ids_used[i].sort()

        in_inputs, first_outputs = mo.identity()
        node_id_to_outputs = [first_outputs]
        for i in range(1, num_nodes):
            edge_outputs = []
            for j in node_id_to_node_ids_used[i]:
                sub_inputs, sub_outputs = submotif_fn()
                sub_inputs["in"].connect(node_id_to_outputs[j]["out"])
                edge_outputs.append(sub_outputs)

            # if necessary, concatenate the results going into a node
            num_edges = len(edge_outputs)
            if num_edges > 1:
                c_inputs, c_outputs = combine_with_concat(num_edges)
                for idx, sub_outputs in enumerate(edge_outputs):
                    c_inputs["in%d" % idx].connect(sub_outputs["out"])
            else:
                c_outputs = edge_outputs[0]
            node_id_to_outputs.append(c_outputs)

        return in_inputs, node_id_to_outputs[-1]
예제 #6
0
 def substitution_fn(dh):
     """Build the motif DAG from the edge values in ``dh``.

     Uses ``motif_info``, ``level``, ``full_hparam_dict`` and
     ``num_channels`` from the enclosing scope. For each pair
     ``in_node < out_node``, the value stored under the JSON key of
     ``{"out_node_id": out_node, "in_node_id": in_node}`` selects the child
     motif on that edge; non-positive ids add no edge. Each node
     concatenates its incoming edges with ``combine_with_concat``. At
     level 2, a 1x1 ``conv2d_cell`` follows the last node.

     Returns:
         The inputs of node 0's identity and the outputs of the motif.
     """
     num_nodes = motif_info[level]['num_nodes']
     incoming = [[] for _ in range(num_nodes)]
     # Node 0 receives the motif input through a single identity.
     incoming[0].append(mo.identity())
     merged = []  # (concat_ins, concat_outs) for each finished node
     for out_node in range(num_nodes):
         for in_node in range(out_node):
             op_id = dh[ut.json_object_to_json_string({
                 "out_node_id": out_node,
                 "in_node_id": in_node,
             })]
             if op_id > 0:
                 if level == 2:
                     sub_hparams = {'operation': op_id}
                 else:
                     sub_hparams = full_hparam_dict[level - 1][op_id - 1]
                 sub_motif = create_motif(sub_hparams, full_hparam_dict,
                                          motif_info, level - 1,
                                          num_channels)
                 incoming[out_node].append(sub_motif)
                 sub_motif[0]['in'].connect(merged[in_node][1]['out'])
         assert len(incoming[out_node]) > 0
         concat_ins, concat_outs = combine_with_concat(
             len(incoming[out_node]), num_channels)
         for slot, (_, edge_outs) in enumerate(incoming[out_node]):
             edge_outs['out'].connect(concat_ins['in%d' % slot])
         merged.append((concat_ins, concat_outs))
     final_outs = merged[-1][1]
     if level == 2:
         conv_ins, conv_outs = conv2d_cell(num_channels, 1)
         final_outs['out'].connect(conv_ins['in'])
         final_outs = conv_outs
     return incoming[0][0][0], final_outs
예제 #7
0
    def substitution_fn(dh):
        """Wire a NASBench-style cell from the edge values in ``dh``.

        Vertex 0 is the cell input, vertices ``1..num_nodes`` are interior
        nodes, and vertex ``num_nodes + 1`` is the cell output.
        ``dh['<in_id>_<out_id>']`` is truthy when edge in_id -> out_id is
        present. ``channels`` is split as evenly as possible among the
        interior vertices feeding the output. Each other interior vertex
        uses the max over the vertices it feeds. A direct input -> output
        edge is added to the output through a projection and ``add(2)``.
        If that direct edge is the only edge into the output, the cell is
        just the projected input.

        Raises:
            ValueError: if an interior vertex has outgoing but no incoming
                edges (or vice versa), if nothing reaches the output, or if
                the cell has more than 9 edges.
        """
        num_ins = [
            sum([dh['%d_%d' % (in_id, out_id)] for in_id in range(out_id)])
            for out_id in range(num_nodes + 2)
        ]
        num_outs = [
            sum([
                dh['%d_%d' % (in_id, out_id)]
                for out_id in range(in_id + 1, num_nodes + 2)
            ]) for in_id in range(num_nodes + 2)
        ]

        for i in range(1, num_nodes + 1):
            if num_outs[i] > 0 and num_ins[i] == 0:
                raise ValueError('Node exists with no path to input')
            if num_ins[i] > 0 and num_outs[i] == 0:
                raise ValueError('Node exists with no path to output')
        if num_ins[-1] == 0:
            raise ValueError('No path exists between input and output')

        if sum(num_ins) > 9:
            raise ValueError('More than 9 edges')

        input_to_output = dh['%d_%d' % (0, num_nodes + 1)]
        if input_to_output:
            # The direct input -> output edge is realised as a residual add
            # at the end, not as an input of the output vertex.
            num_ins[-1] -= 1

        if num_ins[-1] == 0:
            # Only the direct input -> output edge reaches the output. The
            # checks above force every interior vertex that has an edge to
            # chain forward into the output, so no interior vertex is in use
            # here and the cell is just the projected input. (Previously
            # this case divided by zero when splitting the channels below.)
            in_ins, in_outs = mo.identity()
            proj_in, proj_out = input_fn(channels)
            in_outs['out'].connect(proj_in['in'])
            return in_ins, proj_out

        # Compute the number of channels that each vertex outputs
        int_channels = channels // num_ins[-1]
        correction = channels % num_ins[-1]

        vertex_channels = [0] * (num_nodes + 2)
        vertex_channels[-1] = channels
        # Distribute channels as evenly as possible among vertices connected
        # to output vertex
        for in_id in range(1, num_nodes + 1):
            if dh['%d_%d' % (in_id, num_nodes + 1)]:
                vertex_channels[in_id] = int_channels
                if correction > 0:
                    vertex_channels[in_id] += 1
                    correction -= 1
        # For all other vertices, the number of channels they output is the
        # max of the channels output by all vertices they flow into. Walk
        # backwards so successors are resolved first; vertex num_nodes is
        # skipped because it can only feed the output.
        for in_id in range(num_nodes - 1, 0, -1):
            if not dh['%d_%d' % (in_id, num_nodes + 1)]:
                for out_id in range(in_id + 1, num_nodes + 1):
                    if dh['%d_%d' % (in_id, out_id)]:
                        vertex_channels[in_id] = max(vertex_channels[in_id],
                                                     vertex_channels[out_id])

        nodes = [mo.identity()]
        # Unused interior vertices (no outgoing edges) get no module.
        nodes += [
            node_fn(num_ins[i], i -
                    1, vertex_channels[i]) if num_outs[i] > 0 else None
            for i in range(1, num_nodes + 1)
        ]
        nodes.append(output_fn(num_ins[num_nodes + 1]))
        num_connected = [0] * (num_nodes + 2)
        # Project input vertex to correct dimensions if used
        for out_id in range(1, num_nodes + 1):
            if dh['%d_%d' % (0, out_id)]:
                proj_in, proj_out = input_fn(vertex_channels[out_id])
                nodes[0][1]['out'].connect(proj_in['in'])
                proj_out['out'].connect(
                    nodes[out_id][0]['in' + str(num_connected[out_id])])
                num_connected[out_id] += 1
        for in_id in range(1, num_nodes + 1):
            for out_id in range(in_id + 1, num_nodes + 2):
                if dh['%d_%d' % (in_id, out_id)]:
                    nodes[in_id][1]['out'].connect(
                        nodes[out_id][0]['in' + str(num_connected[out_id])])
                    num_connected[out_id] += 1
        if input_to_output:
            # Residual: output = output_vertex + projection(input).
            proj_in, proj_out = input_fn(channels)
            add_in, add_out = add(2)
            nodes[0][1]['out'].connect(proj_in['in'])
            proj_out['out'].connect(add_in['in0'])
            nodes[-1][1]['out'].connect(add_in['in1'])
            nodes[-1] = (add_in, add_out)

        return nodes[0][0], nodes[-1][1]