示例#1
0
 def from_json_loads(json_data):
     """Build a CellNetworkInfo from already-decoded JSON data.

     Args:
         json_data: either a list, treated as the layer list of a single
             'master' cell, or a mapping that may hold the keys 'master',
             'normal' and 'reduction', each mapping to a layer list.

     Returns:
         A CellNetworkInfo with one LayerInfoList per cell that was present.
         Cell names are appended to ``cell_names`` in the order
         'master', 'normal', 'reduction'.
     """
     result = CellNetworkInfo()

     def _add_cell(name, cell_json):
         # Decode one cell's layer list and record the cell name.
         result[name] = LayerInfoList.from_json_loads(cell_json)
         result.cell_names.append(name)

     if isinstance(json_data, list):
         # A bare list is always the master cell, even when it is empty.
         _add_cell('master', json_data)
         return result

     for name in ('master', 'normal', 'reduction'):
         cell_json = json_data.get(name, None)
         # Missing keys and falsy values (None, empty list, ...) are skipped.
         if cell_json:
             _add_cell(name, cell_json)
     return result
示例#2
0
def darts_rnn_base_cell_info(next_id=0,
                             input_ids=[0, 1],
                             end_merge=LayerTypes.MERGE_WITH_CAT):
    """Return the LayerInfoList describing the DARTS recurrent cell.

    See the implementation of PetridishRNNCell for an example of how this
    list of info is used. The first two infos are x_and_h and init_layer;
    init_layer is a projected x_and_h multiplied with a gate. Each of the
    next eight layers applies one gated fully-connected operation to exactly
    one earlier layer. The final layer averages init_layer together with
    those eight intermediate layers.

    This is the DARTS cell as written in the paper.

    Args:
        next_id: base id; the new layers get ids ``next_id + 2`` through
            ``next_id + 10``.
        input_ids: ids of the two input layers ``[x_and_h, init_layer]``.
            Only read, never mutated.
        end_merge: accepted for signature compatibility with the other
            ``*_cell_info`` builders but ignored; this cell always ends with
            ``LayerTypes.MERGE_WITH_AVG``.

    Returns:
        A LayerInfoList built from 11 LayerInfo entries.
    """
    LT = LayerTypes
    x_and_h, init_layer = input_ids[0], input_ids[1]

    def node(offset):
        # Id of the layer at ``offset`` within this cell.
        return next_id + offset

    # DARTS genotype: (layer offset, id of its single input, gated FC op).
    genotype = [
        (2, init_layer, LT.FC_RELU_MUL_GATE),
        (3, node(2), LT.FC_RELU_MUL_GATE),
        (4, node(3), LT.FC_TANH_MUL_GATE),
        (5, node(4), LT.FC_RELU_MUL_GATE),
        (6, node(5), LT.FC_RELU_MUL_GATE),
        (7, node(2), LT.FC_IDEN_MUL_GATE),
        (8, node(6), LT.FC_RELU_MUL_GATE),
        (9, node(2), LT.FC_RELU_MUL_GATE),
    ]
    layers = [LayerInfo(x_and_h), LayerInfo(init_layer)]
    layers.extend(
        LayerInfo(node(offset),
                  inputs=[src],
                  operations=[op, LT.MERGE_WITH_SUM])
        for offset, src, op in genotype)

    # The output averages init_layer and every intermediate layer.
    # init_layer is referenced via input_ids[1] rather than next_id + 1, so
    # the reference stays valid when input_ids != [next_id, next_id + 1].
    avg_inputs = [init_layer] + [node(offset) for offset, _, _ in genotype]
    layers.append(
        LayerInfo(node(10),
                  inputs=avg_inputs,
                  operations=[LT.IDENTITY] * len(avg_inputs) +
                  [LT.MERGE_WITH_AVG]))
    return LayerInfoList(layers)
示例#3
0
def fully_connected_resnet_cell_info(next_id=0,
                                     input_ids=[0, 1],
                                     end_merge=LayerTypes.MERGE_WITH_CAT):
    """Return the layer list of a one-node fully-connected cell.

    The cell holds the two input layers plus one new layer
    (``next_id + 2``). That layer reads only the most recent input
    (``input_ids[1]``) through ``FC_SGMD_MUL_GATE`` and merges with
    ``MERGE_WITH_SUM``. The list is then passed through ``ensure_end_merge``
    together with ``end_merge``; presumably that helper appends a final
    merge layer when needed (confirm against its definition).
    """
    older_id, latest_id = input_ids[0], input_ids[1]
    layers = [LayerInfo(older_id), LayerInfo(latest_id)]
    layers.append(
        LayerInfo(next_id + 2,
                  inputs=[latest_id],
                  operations=[
                      LayerTypes.FC_SGMD_MUL_GATE, LayerTypes.MERGE_WITH_SUM
                  ]))
    return ensure_end_merge(LayerInfoList(layers), end_merge)
示例#4
0
def nasnata_reduction_cell_info(next_id=0,
                                input_ids=[0, 1],
                                end_merge=LayerTypes.MERGE_WITH_CAT):
    """Return the layer list of the NASNet-A reduction cell.

    The list holds the two input layers, then five new nodes
    (``next_id + 2`` .. ``next_id + 6``). Each node applies one operation per
    input and sums the branches with ``MERGE_WITH_SUM``. A final layer with
    id ``next_id + 7`` is built by ``cat_unused`` from the list and
    ``end_merge``; presumably it merges the nodes that no other node consumes
    (confirm against ``cat_unused``).
    """
    LT = LayerTypes
    older, latest = input_ids[0], input_ids[1]  # latest = most recent layer
    n2, n3, n4, n5, n6 = (next_id + k for k in range(2, 7))

    # (node id, its inputs, one operation per input); a SUM merge is appended.
    node_specs = [
        (n2, [latest, older], [LT.SEPARABLE_CONV_5_2, LT.SEPARABLE_CONV_7_2]),
        (n3, [latest, older], [LT.MAXPOOL_3x3, LT.SEPARABLE_CONV_7_2]),
        (n4, [latest, older], [LT.AVGPOOL_3x3, LT.SEPARABLE_CONV_5_2]),
        (n5, [n3, n2], [LT.IDENTITY, LT.AVGPOOL_3x3]),
        (n6, [n2, latest], [LT.SEPARABLE_CONV_3_2, LT.MAXPOOL_3x3]),
    ]

    cell = LayerInfoList()
    cell.extend([LayerInfo(older), LayerInfo(latest)] + [
        LayerInfo(node, inputs=srcs, operations=ops + [LT.MERGE_WITH_SUM])
        for node, srcs, ops in node_specs
    ])
    cell.append(cat_unused(cell, next_id + 7, end_merge))
    return cell
示例#5
0
def separable_resnet_cell_info(next_id=0,
                               input_ids=[0, 1],
                               end_merge=LayerTypes.MERGE_WITH_CAT):
    """Return the layer list of a one-node separable-conv residual cell.

    The new layer (``next_id + 2``) reads the most recent input twice: once
    through ``SEPARABLE_CONV_3_2`` and once through ``IDENTITY``. It sums the
    two branches. The list is finalized by ``ensure_end_merge`` with
    ``end_merge``.
    """
    latest_id = input_ids[1]
    layers = [LayerInfo(input_ids[0]), LayerInfo(latest_id)]
    layers.append(
        LayerInfo(next_id + 2,
                  inputs=[latest_id] * 2,
                  operations=[
                      LayerTypes.SEPARABLE_CONV_3_2, LayerTypes.IDENTITY,
                      LayerTypes.MERGE_WITH_SUM
                  ]))
    return ensure_end_merge(LayerInfoList(layers), end_merge)
示例#6
0
def basic_resnet_cell_info(next_id=0,
                           input_ids=[0, 1],
                           end_merge=LayerTypes.MERGE_WITH_CAT):
    """Return the layer list of a two-conv residual cell.

    Layer ``next_id + 2`` applies ``CONV_3`` to the most recent input.
    Layer ``next_id + 3`` applies ``CONV_3`` to that result and sums it with
    an ``IDENTITY`` shortcut from the most recent input. The list is
    finalized by ``ensure_end_merge`` with ``end_merge``.
    """
    ops = LayerTypes
    shortcut_id = input_ids[1]
    conv1_id, conv2_id = next_id + 2, next_id + 3
    layers = [
        LayerInfo(input_ids[0]),
        LayerInfo(shortcut_id),
        # Single-input layer, so no merge is needed.
        LayerInfo(conv1_id,
                  inputs=[shortcut_id],
                  operations=[ops.CONV_3, ops.MERGE_WITH_NOTHING]),
        # Second conv plus the identity shortcut, summed.
        LayerInfo(conv2_id,
                  inputs=[conv1_id, shortcut_id],
                  operations=[ops.CONV_3, ops.IDENTITY, ops.MERGE_WITH_SUM]),
    ]
    return ensure_end_merge(LayerInfoList(layers), end_merge)
示例#7
0
    def default_master(n_normal_inputs=2,
                       n_reduction_inputs=2,
                       num_cells=18,
                       num_reduction_layers=2,
                       num_init_reductions=0,
                       skip_reduction_layer_input=0,
                       use_aux_head=1):
        """Build the default 'master' LayerInfoList that chains cells together.

        The list starts with ``n_inputs`` plain input layers (ids
        ``0 .. n_inputs - 1``), followed by one layer per cell with
        consecutive ids. A cell layer feeds its preceding layers through
        IDENTITY operations into a final operation that is the string
        'reduction' or 'normal'. Presumably that string names the cell
        structure substituted at that position; confirm against the consumer
        of this list.

        Args:
            n_normal_inputs: number of immediately preceding layers a normal
                cell reads.
            n_reduction_inputs: number of immediately preceding layers a
                reduction cell reads.
            num_cells: added to the two reduction counts to get the total
                number of cell layers. Presumably this is the number of
                normal cells.
            num_reduction_layers: passed to ``calc_reduction_layers`` and
                added to the total cell count.
            num_init_reductions: passed to ``calc_reduction_layers`` and
                added to the total cell count. When non-zero, the network
                starts with ``n_reduction_inputs`` input layers instead of
                ``n_normal_inputs``.
            skip_reduction_layer_input: if truthy, a normal cell that directly
                follows a reduction cell (other than the first
                ``num_init_reductions`` cells) does not read layer
                ``layer_id - 2``, the layer right before the reduction.
            use_aux_head: if truthy and any reduction cells exist, the cell
                just before the last reduction cell gets ``aux_weight = 0.4``.

        Returns:
            The LayerInfoList. Its last layer always has ``aux_weight = 1.0``.
        """
        # Cell indices (0-based, counted from the first non-input layer)
        # that are reduction cells.
        reduction_layers = CellNetworkInfo.calc_reduction_layers(
            num_cells, num_reduction_layers, num_init_reductions)
        master = LayerInfoList()
        layer_id = 0
        n_inputs = n_normal_inputs if num_init_reductions == 0 else n_reduction_inputs
        for _ in range(n_inputs):
            master.append(LayerInfo(layer_id=layer_id))
            layer_id += 1

        # true_num_cells counts cells from the first non-input with 0-based index
        true_num_cells = num_cells + num_init_reductions + num_reduction_layers
        for ci in range(true_num_cells):
            # Cell ci is stored at master[ci + n_inputs] and has id layer_id.
            info = LayerInfo(layer_id)
            if ci in reduction_layers:
                # Reduction cell: read the n_reduction_inputs most recent layers.
                info.inputs = list(
                    range(layer_id - n_reduction_inputs, layer_id))
                n_in = len(info.inputs)
                info.operations = [LayerTypes.IDENTITY] * n_in + ['reduction']
                info.down_sampling = 1
            else:
                if (skip_reduction_layer_input and ci - 1 in reduction_layers
                        and ci > num_init_reductions):
                    # imagenet : do not take the input of regular reduction as skip connection.
                    # Still n_normal_inputs inputs: layers
                    # layer_id - n_normal_inputs - 1 .. layer_id - 3, plus the
                    # reduction output layer_id - 1; layer_id - 2 is skipped.
                    info.inputs = (list(
                        range(layer_id - n_normal_inputs - 1, layer_id - 2)) +
                                   [layer_id - 1])
                else:
                    # Normal cell: read the n_normal_inputs most recent layers.
                    info.inputs = list(
                        range(layer_id - n_normal_inputs, layer_id))
                n_in = len(info.inputs)
                info.operations = [LayerTypes.IDENTITY] * n_in + ['normal']
            master.append(info)
            layer_id += 1

        # aux_weight at the last cell before the last reduction
        # (cell index reduction_layers[-1] - 1, shifted by n_inputs into master).
        if use_aux_head and len(reduction_layers) > 0:
            master[reduction_layers[-1] - 1 + n_inputs].aux_weight = 0.4
        master[-1].aux_weight = 1.0
        return master