def main(unused_argv):
  init_timer = timer.Timer()
  init_timer.Start()
  if FLAGS.preload_gin_config:
    # Load default values from the original experiment, always the first one.
    with gin.unlock_config():
      gin.parse_config_file(FLAGS.preload_gin_config, skip_unknown=True)
    logging.info('Operative Gin configurations loaded from: %s',
                 FLAGS.preload_gin_config)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  data_train, data_test, info = utils.get_dataset()
  input_shape = info.features['image'].shape
  num_classes = info.features['label'].num_classes
  logging.info('Input Shape: %s', input_shape)
  logging.info('train samples: %s', info.splits['train'].num_examples)
  logging.info('test samples: %s', info.splits['test'].num_examples)
  data_eval = data_train if FLAGS.eval_on_train else data_test
  pruning_params = utils.get_pruning_params(mode='constant')
  mask_load_dict = {-1: None, 0: FLAGS.ckpt_start, 1: FLAGS.ckpt_end}
  mask_path = mask_load_dict[FLAGS.load_mask_from]
  # Currently we interpolate only on the same sparse space.
  model_start = utils.get_network(
      pruning_params,
      input_shape,
      num_classes,
      mask_init_path=mask_path,
      weight_init_path=FLAGS.ckpt_start)
  model_start.summary()
  model_end = utils.get_network(
      pruning_params,
      input_shape,
      num_classes,
      mask_init_path=mask_path,
      weight_init_path=FLAGS.ckpt_end)
  model_end.summary()
  # Create a third network for interpolation.
  model_inter = utils.get_network(
      pruning_params,
      input_shape,
      num_classes,
      mask_init_path=mask_path,
      weight_init_path=FLAGS.ckpt_end)
  logging.info('Performance at init (model_start):')
  test_model(model_start, data_eval)
  logging.info('Performance at init (model_end):')
  test_model(model_end, data_eval)
  all_results = interpolate(
      model_start=model_start,
      model_end=model_end,
      model_inter=model_inter,
      d_set=data_eval)
  tf.io.gfile.makedirs(FLAGS.logdir)
  results_path = os.path.join(FLAGS.logdir, 'all_results')
  with tf.io.gfile.GFile(results_path, 'wb') as f:
    np.save(f, all_results)
  logging.info('Total runtime: %.3f s', init_timer.GetDuration())
  logconfigfile_path = os.path.join(FLAGS.logdir, 'operative_config.gin')
  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
    f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())
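
# The `interpolate` and `test_model` helpers used above are defined elsewhere
# in this module. The sketch below only illustrates the kind of
# linear-interpolation loop `interpolate` is expected to perform between two
# same-sparsity checkpoints; the helper name, the `n_points` argument, and the
# use of a compiled Keras model's `evaluate` are assumptions, not the original
# API.
def _linear_interpolation_sketch(model_start, model_end, model_inter, d_set,
                                 n_points=11):
  """Evaluates models whose weights are convex combinations of two solutions."""
  start_weights = model_start.get_weights()
  end_weights = model_end.get_weights()
  results = []
  for alpha in np.linspace(0., 1., n_points):
    # Same-shape (and, here, same-mask) checkpoints let us mix the weight
    # lists element-wise.
    model_inter.set_weights([(1. - alpha) * w_s + alpha * w_e
                             for w_s, w_e in zip(start_weights, end_weights)])
    # Assumes `model_inter` is compiled with a loss and an accuracy metric and
    # that `d_set` is an unbatched tf.data.Dataset.
    loss, acc = model_inter.evaluate(d_set.batch(256), verbose=0)
    results.append((alpha, loss, acc))
  return np.array(results)
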
def main(unused_argv):
  tf.random.set_seed(FLAGS.seed)
  init_timer = timer.Timer()
  init_timer.Start()
  if FLAGS.mode == 'hessian':
    # Load default values from the original experiment.
    FLAGS.preload_gin_config = os.path.join(FLAGS.logdir,
                                            'operative_config.gin')
  # Maybe preload a gin config.
  if FLAGS.preload_gin_config:
    config_path = FLAGS.preload_gin_config
    gin.parse_config_file(config_path)
    logging.info('Gin configuration pre-loaded from: %s', config_path)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  ds_train, ds_test, info = utils.get_dataset()
  input_shape = info.features['image'].shape
  num_classes = info.features['label'].num_classes
  logging.info('Input Shape: %s', input_shape)
  logging.info('train samples: %s', info.splits['train'].num_examples)
  logging.info('test samples: %s', info.splits['test'].num_examples)
  pruning_params = utils.get_pruning_params()
  model = utils.get_network(pruning_params, input_shape, num_classes)
  model.summary(print_fn=logging.info)
  if FLAGS.mode == 'train_eval':
    train_model(model, ds_train, ds_test, FLAGS.logdir)
  elif FLAGS.mode == 'hessian':
    test_model(model, ds_test)
    hessian(model, ds_train, FLAGS.logdir)
  logging.info('Total runtime: %.3f s', init_timer.GetDuration())
  # Parenthesize the conditional so the 'hessian_' prefix is actually joined
  # with the file name (the conditional expression binds looser than `+`).
  logconfigfile_path = os.path.join(
      FLAGS.logdir,
      ('hessian_' if FLAGS.mode == 'hessian' else '') + 'operative_config.gin')
  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
    f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())
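
# A hypothetical invocation (the script name and paths are placeholders, not a
# documented entry point); the flags referenced above (--mode, --logdir,
# --gin_config, --gin_bindings, --seed) are the ones this `main` consumes:
#
#   python train.py --mode=train_eval --logdir=/tmp/sparse_exp \
#     --gin_config=path/to/config.gin --seed=8
#
# A subsequent run with --mode=hessian preloads the operative_config.gin saved
# in --logdir and writes a hessian_operative_config.gin alongside it.
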
def sparse_hessian_calculator(model, data, rows_at_once, eigvals_path,
                              overwrite, is_dense_spectrum=False):
  """Calculates the Hessian of the model parameters. Biases are dense."""
  # Read all data at once.
  x_batch, y_batch = list(data.batch(100000))[0]
  if tf.io.gfile.exists(eigvals_path) and overwrite:
    logging.info('Deleting existing Eigvals: %s', eigvals_path)
    tf.io.gfile.rmtree(eigvals_path)
  if tf.io.gfile.exists(eigvals_path):
    with tf.io.gfile.GFile(eigvals_path, 'rb') as f:
      eigvals = np.load(f)
    logging.info('Eigvals exist, skipping: %s', eigvals_path)
    return eigvals
  # First, let's create lists that indicate the valid dimensions of each
  # variable. If we want to calculate the sparse spectrum, we have to omit
  # masked dimensions. Biases are dense and therefore have masks of 1's.
  masks = []
  variables = []
  layer_group_indices = []
  for l in model.layers:
    if isinstance(l, utils.PRUNING_WRAPPER):
      # TODO: following the outcome of b/148083099, update the following.
      # Add the weight, the mask and the valid dimensions.
      weight = l.weights[0]
      variables.append(weight)
      mask = l.weights[2]
      masks.append(mask)
      logging.info(mask.shape)
      if is_dense_spectrum:
        n_params = tf.size(mask)
        layer_group_indices.append(tf.range(n_params))
      else:
        fmask = tf.reshape(mask, [-1])
        indices = tf.where(tf.equal(fmask, 1))[:, 0]
        layer_group_indices.append(indices)
      # Add the bias mask of ones and all of its dimensions.
      bias = l.weights[1]
      variables.append(bias)
      masks.append(tf.ones_like(bias))
      layer_group_indices.append(tf.range(tf.size(bias)))
    else:
      # For now we assume all parameterized layers are wrapped with
      # PruneLowMagnitude.
      assert not l.trainable_variables
  result_all = []
  init_timer = timer.Timer()
  init_timer.Start()
  n_total = 0
  logging.info('Calculating Hessian...')
  for i, inds in enumerate(layer_group_indices):
    n_split = np.ceil(tf.size(inds).numpy() / rows_at_once)
    logging.info('Nsplit: %d', n_split)
    for c_slice in np.array_split(inds.numpy(), n_split):
      res = get_rows(model, variables, masks, i, c_slice, x_batch, y_batch,
                     is_dense_spectrum)
      result_all.append(res.numpy())
      n_total += res.shape[0]
      target_n = float(res.shape[1])
      logging.info('%.3f %% ..', 100. * n_total / target_n)
  # We convert in numpy so that it is on CPU automatically and we don't get
  # OOM.
  c_hessian = np.concatenate(result_all, 0)
  logging.info('Total runtime for hessian: %.3f s', init_timer.GetDuration())
  init_timer.Start()
  eigens = jax.jit(eigh, backend='cpu')(c_hessian)
  eigvals = np.asarray(eigens[0])
  with tf.io.gfile.GFile(eigvals_path, 'wb') as f:
    np.save(f, eigvals)
  logging.info('EigVals saved: %s', eigvals_path)
  logging.info('Total runtime for eigvals: %.3f s', init_timer.GetDuration())
  return eigvals
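
# `get_rows` is defined elsewhere in this module. The sketch below illustrates
# one way such a helper can compute a block of Hessian rows with nested
# gradient tapes: selected entries of the per-group gradient are
# differentiated a second time, and the resulting columns are restricted to
# unmasked dimensions (the sparse-spectrum case). The function name, the
# argument order and the sparse softmax cross-entropy loss are assumptions,
# not the original implementation.
def _hessian_rows_sketch(model, variables, masks, group_idx, row_indices,
                         x_batch, y_batch):
  """Returns Hessian rows for `row_indices` of variable group `group_idx`."""
  with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as inner_tape:
      logits = model(x_batch, training=False)
      loss = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(
              y_batch, logits, from_logits=True))
    grads = inner_tape.gradient(loss, variables)
    # The gradient entries (Hessian rows) we want second derivatives of.
    selected = tf.gather(tf.reshape(grads[group_idx], [-1]), row_indices)
  # d(selected)/d(variables): one Jacobian block per variable. The while_loop
  # path is slower than pfor but more broadly compatible; unconnected blocks
  # become zeros instead of None.
  jac_blocks = outer_tape.jacobian(
      selected, variables,
      unconnected_gradients=tf.UnconnectedGradients.ZERO,
      experimental_use_pfor=False)
  row_blocks = []
  for block, mask in zip(jac_blocks, masks):
    # Flatten to [n_rows, n_params] and keep only unmasked columns so the
    # resulting matrix matches the sparse support.
    flat = tf.reshape(block, [tf.shape(selected)[0], -1])
    keep = tf.where(tf.equal(tf.reshape(mask, [-1]), 1))[:, 0]
    row_blocks.append(tf.gather(flat, keep, axis=1))
  return tf.concat(row_blocks, axis=1)
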