def _reset_batch_norm(self):
    for layer in self.batch_norm_layers:
        # to get properly initialized moving mean and moving variance weights
        # we initialize a new batch norm layer from the config of the existing
        # layer, build that layer, retrieve its reinitialized moving mean and
        # moving var weights and then delete the layer
        bn_config = layer.get_config()
        new_batch_norm = BatchNormalization(**bn_config)
        new_batch_norm.build(layer.input_shape)
        new_moving_mean, new_moving_var = new_batch_norm.get_weights()[-2:]
        # get rid of the new_batch_norm layer
        del new_batch_norm
        # get the trained gamma and beta from the current batch norm layer
        trained_weights = layer.get_weights()
        new_weights = []
        # get gamma if it exists
        if bn_config['scale']:
            new_weights.append(trained_weights.pop(0))
        # get beta if it exists
        if bn_config['center']:
            new_weights.append(trained_weights.pop(0))
        new_weights += [new_moving_mean, new_moving_var]
        # set weights to trained gamma and beta, reinitialized mean and variance
        layer.set_weights(new_weights)
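# The scale/center branches above exist because Keras orders BatchNormalization
# weights as [gamma?, beta?, moving_mean, moving_variance], dropping gamma when
# scale=False and beta when center=False. A minimal sketch (tf.keras assumed,
# not part of the original snippet) illustrating why get_weights()[-2:] always
# returns the moving statistics:
from tensorflow.keras.layers import BatchNormalization

bn = BatchNormalization(scale=False)           # no gamma variable
bn.build((None, 8))
assert len(bn.get_weights()) == 3              # [beta, moving_mean, moving_variance]
assert bn.get_weights()[-2].shape == (8,)      # moving_mean
assert bn.get_weights()[-1].shape == (8,)      # moving_variance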
def wider_bn(layer, start_dim, total_dim, n_add):
    """Get a new layer with wider batch normalization for the current layer.

    Args:
        layer: the batch normalization layer to widen
        start_dim: the dimension at which the new units are inserted
        total_dim: the total number of original dimensions to keep
        n_add: the number of dimensions (channels) to add

    Returns:
        The new, wider batch normalization layer.
    """
    weights = layer.get_weights()

    input_shape = list((None,) * layer.input_spec.ndim)
    input_shape[-1] = get_int_tuple(layer.gamma.shape)[0]
    input_shape[-1] += n_add

    # build a temporary layer over the added dimensions only, to obtain
    # freshly initialized weights for the new units
    temp_layer = BatchNormalization()
    add_input_shape = list(input_shape)
    add_input_shape[-1] = n_add
    temp_layer.build(tuple(add_input_shape))
    new_weights = temp_layer.get_weights()

    # splice the fresh weights into each of the trained weight vectors
    student_w = tuple()
    for weight, new_weight in zip(weights, new_weights):
        temp_w = weight.copy()
        temp_w = np.concatenate(
            (temp_w[:start_dim], new_weight, temp_w[start_dim:total_dim]))
        student_w += (temp_w,)

    new_layer = BatchNormalization()
    new_layer.build(input_shape)
    new_layer.set_weights(student_w)
    return new_layer
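# Hypothetical usage sketch (not from the source): widen a built 64-channel
# BatchNormalization layer by 16 freshly initialized channels inserted at
# position 32. Assumes a Keras 2 / tf.keras BatchNormalization (so that
# layer.input_spec.ndim and layer.gamma exist) and the module's own
# get_int_tuple helper.
old_bn = BatchNormalization()
old_bn.build((None, 64))                        # gamma, beta, moving stats: 64 entries each
new_bn = wider_bn(old_bn, start_dim=32, total_dim=64, n_add=16)
assert new_bn.get_weights()[0].shape == (80,)   # 32 kept + 16 new + 32 kept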
def build(self, input_shape):
    if self.hidden_layers:
        for layer in self.hidden_layers:
            layer.build(input_shape)
            layer.built = True
            input_shape = layer.get_output_shape_for(input_shape)
            norm = BatchNormalization(mode=2)
            norm.build(input_shape)
            norm.built = True
            input_shape = norm.get_output_shape_for(input_shape)
            self._layers.append(layer)
            self._layers.append(norm)
    layer = self.output_layer
    layer.build(input_shape)
    layer.built = True
    self._layers.append(layer)
    super(MLPClassifierLayer, self).build(input_shape)
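# For reference, an illustrative equivalent of the stack this build() assembles,
# written against the Keras 1 Sequential API (get_output_shape_for and the mode
# argument are Keras 1 only; the layer sizes, activations and the 10-way softmax
# output are placeholders, not taken from the source):
from keras.models import Sequential
from keras.layers import Dense
from keras.layers.normalization import BatchNormalization

model = Sequential()
model.add(Dense(128, activation='relu', input_dim=20))
model.add(BatchNormalization(mode=2))        # mode=2: feature-wise normalization per batch
model.add(Dense(64, activation='relu'))
model.add(BatchNormalization(mode=2))
model.add(Dense(10, activation='softmax'))   # plays the role of self.output_layer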
def _reset_batch_norm(self):
    for layer in self.batch_norm_layers:
        # to get properly initialized moving mean and moving variance weights
        # we initialize a new batch norm layer from the config of the existing
        # layer, build that layer, retrieve its reinitialized moving mean and
        # moving var weights and then delete the layer
        new_batch_norm = BatchNormalization(**layer.get_config())
        new_batch_norm.build(layer.input_shape)
        _, _, new_moving_mean, new_moving_var = new_batch_norm.get_weights()
        # get rid of the new_batch_norm layer
        del new_batch_norm
        # get the trained gamma and beta from the current batch norm layer
        trained_gamma, trained_beta, _, _ = layer.get_weights()
        # set weights to trained gamma and beta, reinitialized mean and variance
        layer.set_weights(
            [trained_gamma, trained_beta, new_moving_mean, new_moving_var])
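# Both reset variants iterate over a self.batch_norm_layers attribute that is
# defined elsewhere in the enclosing class. A minimal sketch of how it might be
# populated from a wrapped Keras model (the `model` attribute and this helper
# name are assumptions, not part of the snippets above):
from tensorflow.keras.layers import BatchNormalization

def _collect_batch_norm_layers(self):
    # collect every batch normalization layer so _reset_batch_norm can reset it
    self.batch_norm_layers = [
        layer for layer in self.model.layers
        if isinstance(layer, BatchNormalization)
    ]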