Example #1
def _get_nncf_graph_from_sequential(model: tf.keras.Model) -> NNCFGraph:
    nncf_graph = NNCFGraph()
    producer_layer = None
    model_config = model.get_config()
    for layer in model_config['layers']:
        layer_name = layer['config']['name']
        layer_type = _get_layer_type(layer)
        layer_dtype = _get_layer_dtype(layer)
        data_format = layer['config'].get('data_format')
        attrs = dict(type=layer_type,
                     dtype=layer_dtype,
                     data_format=data_format,
                     in_ports=[0],
                     out_ports=[0],
                     is_shared=False)
        if layer_type in GENERAL_CONV_LAYERS:
            module_attributes = _get_module_attributes(
                model.get_layer(layer_name), attrs)
            attrs.update({NNCFGraph.MODULE_ATTRIBUTES: module_attributes})

        nncf_graph.add_node(layer_name, **attrs)
        if producer_layer is not None:
            input_shape = _prepare_shape(
                model.get_layer(layer_name).input_shape)
            attr = {
                NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR: input_shape[0],
                NNCFGraph.IN_PORT_NAME_EDGE_ATTR: 0
            }
            nncf_graph.add_edge(producer_layer, layer_name, **attr)
        producer_layer = layer_name

    return nncf_graph
Example #2
def load_pretrained_layers(config: dict, my_model: tf.keras.Model):
    if "pretrained_layers" in config["model"]["hyper_params"]:
        pretrained_layers = config["model"]["hyper_params"][
            "pretrained_layers"]
        for pretrained_layer in pretrained_layers:
            logger.info(
                f"Load pretrained layer {pretrained_layer['layer_name']} into {pretrained_layer['target_layer_name']}"
            )
            # https://github.com/tensorflow/tensorflow/issues/32348
            pretrained_model: tf.keras.Model \
                = tf.keras.models.load_model(pretrained_layer["model_path"], compile=False)

            if (pretrained_layer['target_layer_name'] == "decoder") and \
                    isinstance(my_model.get_layer("decoder"), Decoder):
                logger.debug(
                    "Loading a Transformer Encoder into a Transformer Decoder")
                pretrained_encoder: Encoder = pretrained_model.get_layer(
                    "encoder")
                decoder: Decoder = my_model.get_layer("decoder")

                # Load weights from encoder like XLM paper
                # See https://github.com/facebookresearch/XLM/blob/master/src/model/transformer.py

                decoder.embedding.set_weights(
                    pretrained_encoder.embedding.get_weights())
                for i in range(len(decoder.dec_layers)):
                    decoder.dec_layers[i].mha1.set_weights(
                        pretrained_encoder.enc_layers[i].mha.get_weights())
                    decoder.dec_layers[i].mha2.set_weights(
                        pretrained_encoder.enc_layers[i].mha.get_weights())
                    decoder.dec_layers[i].ffn.set_weights(
                        pretrained_encoder.enc_layers[i].ffn.get_weights())
                    decoder.dec_layers[i].layernorm1.set_weights(
                        pretrained_encoder.enc_layers[i].layernorm1.
                        get_weights())
                    decoder.dec_layers[i].layernorm2.set_weights(
                        pretrained_encoder.enc_layers[i].layernorm1.
                        get_weights())
                    decoder.dec_layers[i].layernorm3.set_weights(
                        pretrained_encoder.enc_layers[i].layernorm2.
                        get_weights())

            else:
                for l in pretrained_layer['layer_name'].split("/"):
                    pretrained_model = pretrained_model.get_layer(l)

                weights = pretrained_model.get_weights()
                my_model.get_layer(
                    pretrained_layer['target_layer_name']).set_weights(weights)

    return my_model
Example #3
def freeze_layers_before(model: tf.keras.Model, layer_name: str):
    """Freezes layers of a Keras `model` before a given `layer_name` (excluded)."""

    freeze_before = model.get_layer(layer_name)
    index_freeze_before = model.layers.index(freeze_before)
    for layer in model.layers[:index_freeze_before]:
        layer.trainable = False
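
A minimal usage sketch (the MobileNetV2 backbone and the layer name below are illustrative assumptions, not part of the original snippet): freeze everything up to a chosen layer before fine-tuning the remainder.

import tensorflow as tf

# Illustrative only: any Keras model with named layers works the same way.
model = tf.keras.applications.MobileNetV2(weights=None, classes=10)
freeze_layers_before(model, "block_13_expand")  # layers before this one become non-trainable
print(sum(1 for layer in model.layers if not layer.trainable), "layers frozen")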
Example #4
def transfer_ema_weights(source_model: tf.keras.Model,
                         target_model: tf.keras.Model,
                         source_ema: Optional[
                             tf.train.ExponentialMovingAverage] = None,
                         layer_name_prefix: str = '') -> None:
    """Transfer weights from source_model to target_model. If source_ema is specified, then the exponential moving
    average in source_ema of each variable in source_model will be transferred to source_target.

    Args:
        source_model: the source to transfer weights from
        target_model: the target to transfer weights to
        source_ema: optional exponential moving average of source_model.variables
        layer_name_prefix: only layers whose names start with layer_name_prefix are transferred
    """
    for source_layer in source_model.layers:
        source_vars = source_layer.variables
        if source_layer.name.startswith(layer_name_prefix) and source_vars:
            try:
                target_layer = target_model.get_layer(name=source_layer.name)
            except ValueError:
                continue
            for source_var, target_var in zip(source_vars,
                                              target_layer.variables):
                if source_ema is not None:
                    transfer_var = source_ema.average(source_var)
                else:
                    transfer_var = source_var
                target_var.assign(transfer_var)
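
A minimal sketch of how this might be called, assuming the EMA object is maintained during training (layer sizes and names are illustrative):

import tensorflow as tf

train_model = tf.keras.Sequential([tf.keras.layers.Dense(8, name="dense")])
train_model.build((None, 4))
eval_model = tf.keras.models.clone_model(train_model)  # same layer names, fresh weights
eval_model.build((None, 4))

ema = tf.train.ExponentialMovingAverage(decay=0.999)
ema.apply(train_model.variables)  # creates shadow variables; call again after each update step

transfer_ema_weights(train_model, eval_model, source_ema=ema)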
Example #5
def _prepare_raw_nodes(model: tf.keras.Model) -> Dict:
    model_config = model.get_config()
    raw_nodes = Dict()
    for layer in model_config['layers']:
        layer_name = layer['name']
        layer_type = _get_layer_type(layer)
        layer_dtype = _get_layer_dtype(layer)
        data_format = layer['config'].get('data_format')
        model_layer = model.get_layer(layer_name)

        if layer['inbound_nodes']:
            is_shared = len(layer['inbound_nodes']) > 1
            for i, inbound_node in enumerate(layer['inbound_nodes']):
                input_shape = _prepare_shape(
                    model_layer.inbound_nodes[i].input_shapes)
                instance = raw_nodes[layer_name][i]
                instance['type'] = layer_type
                instance['dtype'] = layer_dtype
                instance['data_format'] = data_format
                instance['is_shared'] = is_shared
                instance['input_shape'] = input_shape
                instance['in_ports'] = list(range(len(inbound_node)))
                if not instance['out_ports']:
                    instance['out_ports'] = set()
                if layer_type in GENERAL_CONV_LAYERS:
                    module_attributes = _get_module_attributes(
                        model_layer, instance)
                    instance.update(
                        {NNCFGraph.MODULE_ATTRIBUTES: module_attributes})
                for parent_name, parent_instance_index, parent_out_ports, _ in inbound_node:
                    parent_instance = raw_nodes[parent_name][
                        parent_instance_index]
                    if parent_instance['out_ports']:
                        parent_instance['out_ports'].add(parent_out_ports)
                    else:
                        parent_instance['out_ports'] = {parent_out_ports}
        else:
            instance = raw_nodes[layer_name][0]
            instance['type'] = layer_type
            instance['dtype'] = layer_dtype
            instance['data_format'] = data_format
            instance['is_shared'] = False
            instance['in_ports'] = []
            instance['input_shape'] = _prepare_shape(model_layer.input_shape)
            if layer_type in GENERAL_CONV_LAYERS:
                module_attributes = _get_module_attributes(
                    model_layer, instance)
                instance.update(
                    {NNCFGraph.MODULE_ATTRIBUTES: module_attributes})

    outputs = model_config['output_layers']
    raw_nodes = _process_outputs(outputs, raw_nodes)

    for instance_dict in raw_nodes.values():
        for instance in instance_dict.values():
            instance['out_ports'] = sorted(list(instance['out_ports']))

    return raw_nodes
Example #6
def load_darknet_weights(model: tf.keras.Model, weights_file):
    """
    A function which takes raw darknet weights file and converts it to tensorflow Model using tf.train.Checkpoint api

    :param model: tf.keras.Model with the YOLOv3 architecture
    :param weights_file: pretrained Darknet weights file
    :return: returns void with weights converted to tensorflow checkpoint format
    """
    wf = open(weights_file, 'rb')
    major, minor, revision, seen, _ = np.fromfile(wf, dtype=np.int32, count=5)

    layers = YOLOV3_LAYER_LIST

    for layer_name in layers:
        sub_model = model.get_layer(layer_name)
        for i, layer in enumerate(sub_model.layers):
            if not layer.name.startswith('conv2d'):
                continue
            batch_norm = None
            if i + 1 < len(sub_model.layers) and \
                    sub_model.layers[i + 1].name.startswith('batch_norm'):
                batch_norm = sub_model.layers[i + 1]

            logging.info("{}/{} {}".format(sub_model.name, layer.name,
                                           'bn' if batch_norm else 'bias'))

            filters = layer.filters
            size = layer.kernel_size[0]
            in_dim = layer.input_shape[-1]

            if batch_norm is None:
                conv_bias = np.fromfile(wf, dtype=np.float32, count=filters)
            else:
                # darknet [beta, gamma, mean, variance]
                bn_weights = np.fromfile(wf,
                                         dtype=np.float32,
                                         count=4 * filters)
                # tf [gamma, beta, mean, variance]
                bn_weights = bn_weights.reshape((4, filters))[[1, 0, 2, 3]]

            # darknet shape (out_dim, in_dim, height, width)
            conv_shape = (filters, in_dim, size, size)
            conv_weights = np.fromfile(wf,
                                       dtype=np.float32,
                                       count=np.prod(conv_shape))
            # tf shape (height, width, in_dim, out_dim)
            conv_weights = conv_weights.reshape(conv_shape).transpose(
                [2, 3, 1, 0])

            if batch_norm is None:
                layer.set_weights([conv_weights, conv_bias])
            else:
                layer.set_weights([conv_weights])
                batch_norm.set_weights(bn_weights)

    assert len(wf.read()) == 0, 'failed to read all data'
    wf.close()
Example #7
    def __init__(self,
                 input_shape,
                 bgnet_input: tf.keras.Model,
                 output_channels=15):
        self.inputs = tf.keras.Input(shape=input_shape, name='bgnet_input')
        self.img_input = bgnet_input.layers[0]
        self.bgnet_input: tf.keras.Model = bgnet_input(self.inputs)
        self.bgnet_output_layer = bgnet_input.get_layer('output')
        self.output_channels = output_channels
        self.outputs = None
        self.model = self._build_model()
Example #8
def get_last_conv_ancestor(model: tf.keras.Model,
                           layer: tf.keras.layers.Layer):
    prev_name = layer.input.name.split('/')[0]
    try:
        layer = model.get_layer(name=prev_name)
    except ValueError:
        return None
    if type(layer) == tf.keras.layers.Conv2D:
        return layer
    else:
        return get_last_conv_ancestor(model, layer)
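
A small usage sketch, assuming a TF 2.x functional model whose symbolic tensor names are prefixed with the producing layer's name (the property the name-splitting above relies on); the toy model is illustrative:

import tensorflow as tf

inputs = tf.keras.Input((32, 32, 3))
x = tf.keras.layers.Conv2D(8, 3, name="conv_a")(inputs)
x = tf.keras.layers.GlobalAveragePooling2D(name="pool")(x)
outputs = tf.keras.layers.Dense(10, name="head")(x)
model = tf.keras.Model(inputs, outputs)

conv = get_last_conv_ancestor(model, model.get_layer("head"))
print(conv.name if conv is not None else "no Conv2D ancestor")  # expected: conv_a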
Example #9
def transfer_weights(source_model: tf.keras.Model,
                     target_model: tf.keras.Model,
                     is_cloned: bool = False,
                     layer_name_prefix: str = '',
                     beta: float = 0.0) -> None:
    """Linear beta-interpolation of weights from source_model to target_model.

    Can be used to maintain a shadow exponential moving average of source_model. Only weights of layers with the same
    name in both models and both starting with 'layer_name_prefix' are transferred.

    If target_model and source_model are clones and share the exact same topology, a significantly faster
    implementation is used. If is_cloned is False, this function assumes source_model is a topological sub-network of
    target_model; in that case missing layers in either target_model or source_model are silently ignored.

    Args:
        source_model: the source to transfer weights from
        target_model: the target to transfer weights to
        is_cloned: whether or not source and target are exact clones (significantly speeds up computation)
        layer_name_prefix: only layers whose names start with layer_name_prefix are transferred
        beta: value for linear interpolation; must be within [0.0, 1.0)

    Raises:
        ValueError: if beta exceeds interval [0.0, 1.0)
    """
    if not 0.0 <= beta < 1.0:
        raise ValueError(
            f'beta must be within [0.0, 1.0) but received beta={beta} instead')

    if is_cloned:  # same exact layer order and topology in both models
        for source_layer, target_layer in zip(source_model.layers,
                                              target_model.layers):
            if source_layer.name == target_layer.name and source_layer.name.startswith(
                    layer_name_prefix):
                for source_var, target_var in zip(source_layer.variables,
                                                  target_layer.variables):
                    delta_value = (1 - beta) * (target_var - source_var)
                    target_var.assign_sub(delta_value)
    else:  # iterate source_model.layers and transfer to target_layer, if target_layer exists
        for source_layer in source_model.layers:
            source_vars = source_layer.variables
            if source_layer.name.startswith(layer_name_prefix) and source_vars:
                try:
                    target_layer = target_model.get_layer(
                        name=source_layer.name)
                except ValueError:
                    continue
                for source_var, target_var in zip(source_vars,
                                                  target_layer.variables):
                    delta_value = (1 - beta) * (target_var - source_var)
                    target_var.assign_sub(delta_value)
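
A minimal sketch of the shadow-EMA use case from the docstring, assuming the shadow model is created with tf.keras.models.clone_model (layer sizes are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(8, name="dense")])
model.build((None, 4))
ema_model = tf.keras.models.clone_model(model)
ema_model.build((None, 4))
ema_model.set_weights(model.get_weights())  # start the shadow from identical weights

# ... after every optimizer step ...
transfer_weights(model, ema_model, is_cloned=True, beta=0.999)  # ema = 0.999 * ema + 0.001 * current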
Example #10
def freeze_unfreeze_layers(model: tf.keras.Model,
                           layers: List[str] = [],
                           freeze_mode: bool = True):
    '''Freeze or unfreeze layers of a given model.

        Args:
            model: the model whose layers will be frozen or unfrozen
            layers: a list of layer names to freeze or unfreeze;
                    an empty list applies the operation to all
                    layers of the given model
            freeze_mode: True to freeze the layers,
                         False to unfreeze them
    '''
    trainable = not freeze_mode
    if len(layers) == 0:
        for layer in model.layers:
            layer.trainable = trainable
        return
    for layer in layers:
        model.get_layer(layer).trainable = trainable
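
A brief usage sketch with hypothetical layer names:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, name="backbone_dense"),
    tf.keras.layers.Dense(10, name="head"),
])
freeze_unfreeze_layers(model, layers=["backbone_dense"], freeze_mode=True)  # freeze one named layer
# ... warm-up training of the head only ...
freeze_unfreeze_layers(model, freeze_mode=False)  # empty list: unfreeze every layer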
Example #11
    def _make_gradcam_heatmap(self,
                              img_array: np.ndarray,
                              model: tf.keras.Model,
                              layer_name: str,
                              pred_index: Optional[int] = None) -> np.ndarray:
        # function is taken from https://keras.io/examples/vision/grad_cam/

        # First, we create a model that maps the input image to the activations
        # of the last conv layer as well as the output predictions
        grad_model = tf.keras.models.Model(
            [model.inputs], [model.get_layer(layer_name).output, model.output])

        # Then, we compute the gradient of the top predicted class for our input image
        # with respect to the activations of the last conv layer
        with tf.GradientTape() as tape:
            last_conv_layer_output, preds = grad_model(img_array)
            if pred_index is None:
                pred_index = tf.argmax(preds[0])
            class_channel = preds[:, pred_index]

        # This is the gradient of the output neuron (top predicted or chosen)
        # with regard to the output feature map of the last conv layer
        grads = tape.gradient(class_channel, last_conv_layer_output)

        # This is a vector where each entry is the mean intensity of the gradient
        # over a specific feature map channel
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

        # We multiply each channel in the feature map array
        # by "how important this channel is" with regard to the top predicted class
        # then sum all the channels to obtain the heatmap class activation
        last_conv_layer_output = last_conv_layer_output[0]
        heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
        heatmap = tf.squeeze(heatmap)

        # For visualization purpose, we will also normalize the heatmap between 0 & 1
        heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
        return heatmap.numpy()
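
A usage sketch: the method is shown detached from its class, so `explainer` below is a stand-in for an instance of that (unnamed) class, and the tiny CNN is illustrative. The returned heatmap has the spatial size of the chosen conv layer and is typically upsampled before overlaying it on the input image.

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(8, 3, activation="relu", input_shape=(64, 64, 3), name="last_conv"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(5, activation="softmax"),
])
img = np.random.rand(1, 64, 64, 3).astype("float32")

heatmap = explainer._make_gradcam_heatmap(img, model, layer_name="last_conv")  # 'explainer' is hypothetical
heatmap = tf.image.resize(heatmap[..., tf.newaxis], (64, 64)).numpy()  # upsample for overlaying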
Example #12
def _set_feature_layes(base_model: tf.keras.Model, feature_specs: List[FeatureSpec]):
    outputs = [base_model.get_layer(spec.layer_name).output for spec in feature_specs]
    return tf.keras.Model(inputs=base_model.input, outputs=outputs)
Example #13
    def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
        """
        Computes necessary model transformations (pruning mask insertions) to enable pruning.

        :param model: The original uncompressed model.
        :return: The instance of the `TransformationLayout` class containing
            a list of pruning mask insertions.
        """
        converter = TFModelConverterFactory.create(model)
        self._graph = converter.convert()
        groups_of_nodes_to_prune = self._pruning_node_selector.create_pruning_groups(self._graph)

        transformations = TFTransformationLayout()
        shared_layers = set()

        self._pruned_layer_groups_info = Clusterization[PrunedLayerInfo](lambda x: x.layer_name)

        for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
            group_minfos = []
            for node in group.elements:
                layer_name = get_layer_identifier(node)
                layer = model.get_layer(layer_name)
                group_minfos.append(PrunedLayerInfo(node.node_name, layer_name, node.node_id,
                                                    is_prunable_depthwise_conv(node)))

                # Add output_mask to elements to run mask_propagation
                # and detect spec_nodes that will be pruned.
                # It should be done for all elements of shared layer.
                node.data['output_mask'] = TFNNCFTensor(tf.ones(node.layer_attributes.out_channels))
                if layer_name in shared_layers:
                    continue
                if node.is_shared():
                    shared_layers.add(layer_name)
                # Check that we need to prune weights in this op
                assert self._is_pruned_layer(layer)
                nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

                _, layer_info = converter.get_layer_info_for_node(node.node_name)
                for weight_def in node.metatype.weight_definitions:
                    transformations.register(
                        self._get_insertion_command_binary_mask(
                            layer_info.layer_name, weight_def.weight_attr_name)
                    )
                if node.metatype.bias_attr_name is not None and \
                        getattr(layer, node.metatype.bias_attr_name) is not None:
                    transformations.register(
                        self._get_insertion_command_binary_mask(
                            layer_info.layer_name, node.metatype.bias_attr_name)
                    )

            cluster = Cluster[PrunedLayerInfo](i, group_minfos, [n.node_id for n in group.elements])
            self._pruned_layer_groups_info.add_cluster(cluster)

        # Propagating masks across the graph to detect spec_nodes that will be pruned
        mask_propagator = MaskPropagationAlgorithm(self._graph, TF_PRUNING_OPERATOR_METATYPES,
                                                   TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # Add masks for all spec modules, because prunable batchnorm layers can be determined
        # at the moment of mask propagation
        types_spec_layers = [TFBatchNormalizationLayerMetatype] \
            if self._prune_batch_norms else []

        spec_nodes = self._graph.get_nodes_by_metatypes(types_spec_layers)
        for spec_node in spec_nodes:
            layer_name = get_layer_identifier(spec_node)
            layer = model.get_layer(layer_name)
            if spec_node.data['output_mask'] is None:
                # Skip elements that will not be pruned
                continue
            if layer_name in shared_layers:
                continue
            if spec_node.is_shared():
                shared_layers.add(layer_name)
            nncf_logger.info('Adding Weight Pruner in: %s', layer_name)

            _, layer_info = converter.get_layer_info_for_node(spec_node.node_name)
            for weight_def in spec_node.metatype.weight_definitions:
                if spec_node.metatype is TFBatchNormalizationLayerMetatype \
                        and not layer.scale and weight_def.weight_attr_name == 'gamma':
                    nncf_logger.debug('Fused gamma parameter encountered in BatchNormalization layer. '
                                      'Do not add mask to it.')
                    continue

                transformations.register(
                    self._get_insertion_command_binary_mask(
                        layer_info.layer_name, weight_def.weight_attr_name)
                )
            transformations.register(
                self._get_insertion_command_binary_mask(
                    layer_info.layer_name, spec_node.metatype.bias_attr_name)
            )
        return transformations
Example #14
def save_embedding_layer_weights(model: tf.keras.Model) -> None:
    embedding_layer = model.get_layer(EMBEDDING_MODEL_NAME)
    embedding_layer.save(TRAINED_EMBEDDING_MODEL_PATH)
    print("Saved model to", TRAINED_EMBEDDING_MODEL_PATH)
Example #15
    def get_transformation_layout(
            self, model: tf.keras.Model) -> TFTransformationLayout:
        """
        Computes necessary model transformations (pruning mask insertions) to enable pruning.

        :param model: The original uncompressed model.
        :return: The instance of the `TransformationLayout` class containing
            a list of pruning mask insertions.
        """
        self._graph = convert_keras_model_to_nncf_graph(model)
        groups_of_nodes_to_prune = self._pruning_node_selector.create_pruning_groups(
            self._graph)

        transformations = TFTransformationLayout()
        shared_layers = set()

        self._pruned_layer_groups_info = Clusterization('layer_name')

        for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
            group_minfos = []
            for node in group.nodes:
                layer_name = get_layer_identifier(node)
                layer = model.get_layer(layer_name)

                # Add output_mask to nodes to run mask_propagation
                # and detect spec_nodes that will be pruned.
                # It should be done for all nodes of shared layer.
                node.data['output_mask'] = tf.ones(
                    node.module_attributes.out_channels)
                if layer_name in shared_layers:
                    continue
                if is_shared(node):
                    shared_layers.add(layer_name)
                # Check that we need to prune weights in this op
                assert self._is_pruned_layer(layer)
                nncf_logger.info('Adding Weight Pruner in: %s', layer_name)
                for attr_name_key in [WEIGHT_ATTR_NAME, BIAS_ATTR_NAME]:
                    attr_name = LAYERS_WITH_WEIGHTS[
                        node.node_type][attr_name_key]
                    if getattr(layer, attr_name) is not None:
                        transformations.register(
                            self._get_insertion_command_binary_mask(
                                layer_name, attr_name))
                group_minfos.append(PrunedLayerInfo(layer_name, node.node_id))

            cluster = NodesCluster(i, group_minfos,
                                   [n.node_id for n in group.nodes])
            self._pruned_layer_groups_info.add_cluster(cluster)

        # Propagating masks across the graph to detect spec_nodes that will be pruned
        mask_propagator = MaskPropagationAlgorithm(
            self._graph, TF_PRUNING_OPERATOR_METATYPES)
        mask_propagator.mask_propagation()

        # Add masks for all spec modules, because prunable batchnorm layers can only be determined
        # at the moment of mask propagation
        types_spec_layers = list(SPECIAL_LAYERS_WITH_WEIGHTS)
        if not self._prune_batch_norms:
            types_spec_layers.remove('BatchNormalization')

        spec_nodes = self._graph.get_nodes_by_types(types_spec_layers)
        for spec_node in spec_nodes:
            layer_name = get_layer_identifier(spec_node)
            if spec_node.data['output_mask'] is None:
                # Skip nodes that will not be pruned
                continue
            if layer_name in shared_layers:
                continue
            if is_shared(spec_node):
                shared_layers.add(layer_name)
            nncf_logger.info('Adding Weight Pruner in: %s', layer_name)
            for attr_name_key in [WEIGHT_ATTR_NAME, BIAS_ATTR_NAME]:
                attr_name = SPECIAL_LAYERS_WITH_WEIGHTS[
                    spec_node.node_type][attr_name_key]
                transformations.register(
                    self._get_insertion_command_binary_mask(
                        layer_name, attr_name))
        return transformations
Example #16
    def gradcam(cls, x, model: tf.keras.Model, layer_name: str):
        y = model.predict(x)
        layer = model.get_layer(name=layer_name)
Example #17
def load_model_weights_from_checkpoint(model: tf.keras.Model,
                                       config: Dict[str, str],
                                       checkpoint_file: str,
                                       training: bool = False) -> None:
    '''Load trained official model from checkpoint.
    Args:
        model (tf.keras.Model): Built keras model.
        config (object) : Loaded configuration file.
        checkpoint_file (str): The path to the checkpoint files, should end with '.ckpt'.
        training (bool): If True, the MLM and NSP weights are loaded as well.
                     Otherwise, those parts are ignored.
    '''
    loader = checkpoint_loader(checkpoint_file)

    model.get_layer(name='Embedding-Token').set_weights([
        loader('bert/embeddings/word_embeddings'),
    ])
    model.get_layer(name='Embedding-Position').set_weights([
        loader('bert/embeddings/position_embeddings')
        [:config['max_position_embeddings'], :],
    ])
    model.get_layer(name='Embedding-Segment').set_weights([
        loader('bert/embeddings/token_type_embeddings'),
    ])
    model.get_layer(name='Embedding-Norm').set_weights([
        loader('bert/embeddings/LayerNorm/gamma'),
        loader('bert/embeddings/LayerNorm/beta'),
    ])
    for i in range(config['num_hidden_layers']):
        model.get_layer(
            name='Encoder-%d-MultiHeadSelfAttention' % (i + 1)).set_weights([
                loader('bert/encoder/layer_%d/attention/self/query/kernel' % i),
                loader('bert/encoder/layer_%d/attention/self/query/bias' % i),
                loader('bert/encoder/layer_%d/attention/self/key/kernel' % i),
                loader('bert/encoder/layer_%d/attention/self/key/bias' % i),
                loader('bert/encoder/layer_%d/attention/self/value/kernel' % i),
                loader('bert/encoder/layer_%d/attention/self/value/bias' % i),
                loader('bert/encoder/layer_%d/attention/output/dense/kernel' %
                       i),
                loader('bert/encoder/layer_%d/attention/output/dense/bias' % i),
            ])
        model.get_layer(
            name='Encoder-%d-MultiHeadSelfAttention-Norm' %
            (i + 1)).set_weights([
                loader(
                    'bert/encoder/layer_%d/attention/output/LayerNorm/gamma' %
                    i),
                loader('bert/encoder/layer_%d/attention/output/LayerNorm/beta' %
                       i),
            ])
        model.get_layer(name='Encoder-%d-FeedForward' % (i + 1)).set_weights([
            loader('bert/encoder/layer_%d/intermediate/dense/kernel' % i),
            loader('bert/encoder/layer_%d/intermediate/dense/bias' % i),
            loader('bert/encoder/layer_%d/output/dense/kernel' % i),
            loader('bert/encoder/layer_%d/output/dense/bias' % i),
        ])
        model.get_layer(
            name='Encoder-%d-FeedForward-Norm' % (i + 1)).set_weights([
                loader('bert/encoder/layer_%d/output/LayerNorm/gamma' % i),
                loader('bert/encoder/layer_%d/output/LayerNorm/beta' % i),
            ])
    if training:
        model.get_layer(name='MLM-Dense').set_weights([
            loader('cls/predictions/transform/dense/kernel'),
            loader('cls/predictions/transform/dense/bias'),
        ])
        model.get_layer(name='MLM-Norm').set_weights([
            loader('cls/predictions/transform/LayerNorm/gamma'),
            loader('cls/predictions/transform/LayerNorm/beta'),
        ])
        model.get_layer(name='MLM-Sim').set_weights([
            loader('cls/predictions/output_bias'),
        ])
        model.get_layer(name='NSP-Dense').set_weights([
            loader('bert/pooler/dense/kernel'),
            loader('bert/pooler/dense/bias'),
        ])
        model.get_layer(name='NSP').set_weights([
            np.transpose(loader('cls/seq_relationship/output_weights')),
            loader('cls/seq_relationship/output_bias'),
        ])