Ejemplo n.º 1
0
def _make_recovery_model_include_top(recovery_model: Layer, default_shape=None, input_shape=None, include_top=True, classes=1000, freeze_features=False):
    """Adapt a pretrained model's input shape and classification head in place.

    The "head" is the run of trailing layers, walked backwards from the last
    layer, that either have no parameters or are ``Dense`` layers. The walk
    stops at the first layer whose output rank exceeds 2.

    Args:
        recovery_model: Model whose layers are indexed, removed and appended
            in place.
        default_shape: Input shape (without the batch axis) the model was
            built for. When ``None`` it is taken from
            ``recovery_model._input_shape`` for a built model. Otherwise it
            is ``(3, 224, 224)`` on the pytorch backend and
            ``(224, 224, 3)`` elsewhere.
        input_shape: Requested input shape (without the batch axis). When it
            differs from ``default_shape``, the signature's input shape is
            updated and the head is rebuilt.
        include_top: When ``False`` the head is removed and ``class_names``
            is cleared.
        classes: Desired number of output classes.
        freeze_features: When ``True`` the whole model is made non-trainable,
            and then every parameterised head layer is made trainable again.

    Returns:
        The same ``recovery_model`` object, after modification.
    """
    if default_shape is None:
        if recovery_model.built:
            model_shape = recovery_model._input_shape
            default_shape = tuple(model_shape.dims[1:] if isinstance(model_shape, TensorShape) else model_shape)
        else:
            default_shape = (3, 224, 224) if get_backend() == 'pytorch' else (224, 224, 3)

    # Compare as tuples so that a list such as [3, 224, 224] counts as the same
    # shape as the tuple (3, 224, 224). Otherwise the head would be rebuilt for
    # no reason and the pretrained classifier weights discarded.
    size_change = input_shape is not None and tuple(input_shape) != tuple(default_shape)
    if size_change:
        # Requested shape with a leading (unknown) batch axis.
        dims = [None] + list(input_shape)
        if isinstance(recovery_model.signature, Signature):
            recovery_model._input_shape = TensorShape(dims)
            recovery_model.signature.inputs.value_list[0].shape = TensorShape(dims)
            recovery_model.signature.inputs.value_list[0].object_type = ObjectType.rgb

    if freeze_features:
        recovery_model.trainable = False
        idx = -1
        while (len(recovery_model[idx]._parameters) == 0 or isinstance(recovery_model[idx], Dense)) and len(recovery_model[idx].output_shape) >= 2:
            layer = recovery_model[idx]
            if layer.output_shape.rank > 2:
                break
            if len(layer._parameters) > 0:
                # Unfreeze only the parameterised head layers.
                layer.trainable = True
            idx -= 1

    if not include_top:
        # NOTE(review): this condition parses as
        # `no params or (Dense and len(output_shape) >= 2)`, unlike the
        # parenthesised condition of the freeze loop above. Kept as-is.
        while len(recovery_model[-1]._parameters) == 0 or isinstance(recovery_model[-1], Dense) and len(recovery_model[-1].output_shape) >= 2:
            layer = recovery_model[-1]
            if layer.output_shape.rank > 2:
                break
            recovery_model.remove_at(-1)
        recovery_model.class_names = []
    elif size_change:
        shp = TensorShape(dims)
        new_layers = []
        # Pop the head off the model, collecting replacements in original order.
        while len(recovery_model[-1]._parameters) == 0 or isinstance(recovery_model[-1], Dense) and len(recovery_model[-1].output_shape) >= 2:
            layer = recovery_model[-1]
            if layer.output_shape.rank > 2:
                break

            new_layer = copy.deepcopy(layer)
            if isinstance(layer, Dense):
                # Dense layers are re-created rather than copied, so they start
                # without the old (size-dependent) weights.
                if layer.num_filters == 1000 and classes != 1000:
                    new_layer = Dense(classes)
                    recovery_model.class_names = []
                else:
                    new_layer = Dense(new_layer.num_filters)
            new_layers.insert(0, new_layer)
            recovery_model.remove_at(-1)

        # Dummy forward pass through the truncated model to record the output
        # shape of what is now its last layer.
        out = recovery_model(to_tensor(shp.get_dummy_tensor()))
        recovery_model[-1].output_shape = tensor_to_shape(out, need_exclude_batch_axis=True)

        # Re-attach the head; Dense layers are named fc, fc1, fc2, ...
        fc_seq = 0
        for ly in new_layers:
            if isinstance(ly, Dense):
                recovery_model.add_module('fc' if fc_seq == 0 else 'fc{0}'.format(fc_seq), ly)
                fc_seq += 1
            else:
                recovery_model.add_module(ly.name, ly)

        if isinstance(recovery_model.signature, Signature):
            recovery_model.output_shape = TensorShape([None, classes])
            recovery_model.signature.outputs.value_list[0].shape = TensorShape([None, classes])
            recovery_model.signature.outputs.value_list[0].object_type = ObjectType.classification_label
    elif classes != 1000:
        # include_top=True with an unchanged input size: drop trailing
        # parameter-free layers, then swap the last Dense for a new classifier.
        while len(recovery_model[-1]._parameters) == 0 or isinstance(recovery_model[-1], Dense) and len(recovery_model[-1].output_shape) >= 2:
            m = recovery_model[-1]
            if isinstance(m, Dense):
                recovery_model[-1] = Dense(classes)
                recovery_model.add_module('softmax', SoftMax())
                break
            else:
                recovery_model.remove_at(-1)
        if isinstance(recovery_model.signature, Signature):
            recovery_model.output_shape = TensorShape([None, classes])
            recovery_model.signature.outputs.value_list[0].shape = TensorShape([None, classes])
            recovery_model.signature.outputs.value_list[0].object_type = ObjectType.classification_label
        recovery_model.class_names = []
    return recovery_model
Ejemplo n.º 2
0
def _make_recovery_model_include_top(recovery_model: Layer,
                                     default_shape=None,
                                     input_shape=None,
                                     include_top=True,
                                     classes=1000,
                                     freeze_features=True):
    """Adapt a pretrained model's input shape and classification head in place.

    The "head" is the run of trailing layers, walked backwards from the last
    layer, that have no parameters or are ``Dense`` layers. The walk stops at
    the first layer whose output rank exceeds 2.

    Args:
        recovery_model: Model whose layers are modified/removed in place.
        default_shape: Input shape (without the batch axis) the model was
            built for. When ``None`` it is taken from
            ``recovery_model._input_shape`` for a built model. Otherwise it
            is ``(3, 224, 224)`` on the pytorch backend and
            ``(224, 224, 3)`` elsewhere.
        input_shape: Requested input shape (without the batch axis). When it
            differs from ``default_shape``, parameterised head layers are
            reset and all cached module shapes are cleared.
        include_top: When ``False`` every head layer is removed.
        classes: Desired number of output classes. Only the final ``Dense``
            (the classifier) is re-sized to it.
        freeze_features: When ``True`` the whole model is made non-trainable,
            and the head layers that are kept unchanged are made trainable
            again.

    Returns:
        The same ``recovery_model`` object, after modification.
    """
    if default_shape is None:
        if recovery_model.built:
            model_shape = recovery_model._input_shape
            default_shape = tuple(model_shape.dims[1:] if isinstance(
                model_shape, TensorShape) else model_shape)
        else:
            default_shape = (3, 224, 224) if get_backend() == 'pytorch' else (224, 224, 3)

    # Compare as tuples so that a list such as [3, 224, 224] counts as the same
    # shape as the tuple (3, 224, 224). Otherwise the head would be reset for
    # no reason and the pretrained classifier weights discarded.
    size_change = input_shape is not None and tuple(input_shape) != tuple(default_shape)

    if freeze_features:
        recovery_model.trainable = False

    # Head surgery must run regardless of freeze_features. Previously it lived
    # inside the freeze branch, so include_top/classes were silently ignored
    # whenever freeze_features=False.
    idx = -1
    is_last_dense = True
    while (len(recovery_model[idx]._parameters) == 0
           or isinstance(recovery_model[idx], Dense)) and len(
               recovery_model[idx].output_shape) >= 2:
        layer = recovery_model[idx]
        if layer.output_shape.rank > 2:
            break

        # Walking backwards, the first Dense seen is the classifier. Clear the
        # flag so earlier Dense layers (e.g. VGG's fc1/fc2) keep their width.
        is_classifier = is_last_dense and isinstance(layer, Dense)
        if isinstance(layer, Dense):
            is_last_dense = False

        if not include_top:
            recovery_model.remove_at(idx)
            # Compensates the decrement below so idx stays at -1 after removal.
            idx += 1
        elif len(layer._parameters) > 0:
            if size_change or (is_classifier and classes != 1000):
                if is_classifier and hasattr(layer, 'num_filters') and layer.num_filters != classes:
                    layer.num_filters = classes
                # Drop the old weights; presumably the layer is lazily rebuilt
                # by the dummy forward pass below (TODO confirm).
                layer._built = False
                layer._parameters.clear()
            elif freeze_features:
                layer.trainable = True
        idx -= 1

    # Shape with a leading (unknown) batch axis for the dummy forward pass.
    dims = [None] + list(input_shape if size_change else default_shape)
    new_tensorshape = TensorShape(dims)
    if size_change:
        # Invalidate cached shapes so they are re-inferred for the new size.
        for module in recovery_model.modules():
            module._input_shape = None
            module._output_shape = None
    recovery_model.to(get_device())
    recovery_model(
        to_tensor(new_tensorshape.get_dummy_tensor(), device=get_device()))

    if isinstance(recovery_model.signature, Signature):
        recovery_model.signature.inputs.value_list[0].shape = TensorShape(dims)
        recovery_model.signature.inputs.value_list[0].object_type = ObjectType.rgb
    # Moved again after the forward pass, presumably so any parameters created
    # during it end up on the device too.
    recovery_model.to(get_device())
    return recovery_model