예제 #1
0
파일: tfkeras.py 프로젝트: ypeleg/keras-swa
    def _reset_batch_norm(self):
        """Reinitialize the running statistics of every batch-norm layer.

        For each layer in ``self.batch_norm_layers``, a fresh
        ``BatchNormalization`` layer is created from the layer's config and
        built with the layer's input shape. This yields properly initialized
        statistics (moving mean / moving variance).

        The trained ``gamma`` and ``beta`` of the existing layer are kept
        (whichever exist, depending on ``scale``/``center``). Every other
        weight is overwritten with the fresh layer's value.

        Weights are matched by variable identity, not by position. So this
        works whether or not the layer has ``gamma``/``beta``, and it does not
        assume the moving statistics are the last two weights.

        NOTE(review): with ``renorm=True``, tf.keras layers carry additional
        statistic variables. Those are reinitialized here as well; confirm
        that is the intended behavior.
        """
        for layer in self.batch_norm_layers:
            # Build a throwaway layer from the same config purely to obtain
            # properly initialized statistics, then drop it.
            new_batch_norm = BatchNormalization(**layer.get_config())
            new_batch_norm.build(layer.input_shape)
            fresh_weights = new_batch_norm.get_weights()
            del new_batch_norm

            # Learned affine parameters to preserve. getattr covers layers
            # where gamma/beta are absent or set to None (scale=False /
            # center=False).
            learned_vars = [
                var
                for var in (getattr(layer, 'gamma', None),
                            getattr(layer, 'beta', None))
                if var is not None
            ]

            # get_weights() follows the order of layer.weights, and both layers
            # come from the same config, so the three lists line up entry for
            # entry. Compare variables with `is`: `==` on tensors is
            # elementwise, not identity.
            new_weights = [
                trained if any(var is keep for keep in learned_vars) else fresh
                for var, trained, fresh in zip(
                    layer.weights, layer.get_weights(), fresh_weights)
            ]

            # Trained gamma/beta + reinitialized statistics.
            layer.set_weights(new_weights)
예제 #2
0
    def _reset_batch_norm(self):
        """Reinitialize the running statistics of every batch-norm layer.

        For each layer in ``self.batch_norm_layers``, a fresh
        ``BatchNormalization`` layer is created from the layer's config and
        built with the layer's input shape. This yields properly initialized
        statistics (moving mean / moving variance).

        The trained ``gamma`` and ``beta`` of the existing layer are kept
        (whichever exist, depending on ``scale``/``center``). Every other
        weight is overwritten with the fresh layer's value.

        Unlike a fixed 4-way unpacking of ``get_weights()``, weights are
        matched by variable identity. Layers built with ``scale=False`` or
        ``center=False`` (fewer than four weights) are therefore handled
        instead of raising ``ValueError``.

        NOTE(review): with ``renorm=True``, tf.keras layers carry additional
        statistic variables. Those are reinitialized here as well; confirm
        that is the intended behavior.
        """
        for layer in self.batch_norm_layers:
            # Build a throwaway layer from the same config purely to obtain
            # properly initialized statistics, then drop it.
            new_batch_norm = BatchNormalization(**layer.get_config())
            new_batch_norm.build(layer.input_shape)
            fresh_weights = new_batch_norm.get_weights()
            del new_batch_norm

            # Learned affine parameters to preserve. getattr covers layers
            # where gamma/beta are absent or set to None (scale=False /
            # center=False).
            learned_vars = [
                var
                for var in (getattr(layer, 'gamma', None),
                            getattr(layer, 'beta', None))
                if var is not None
            ]

            # get_weights() follows the order of layer.weights, and both layers
            # come from the same config, so the three lists line up entry for
            # entry. Compare variables with `is`: `==` on tensors is
            # elementwise, not identity.
            new_weights = [
                trained if any(var is keep for keep in learned_vars) else fresh
                for var, trained, fresh in zip(
                    layer.weights, layer.get_weights(), fresh_weights)
            ]

            # Trained gamma/beta + reinitialized statistics.
            layer.set_weights(new_weights)