示例#1
0
def main(unused_argv):
    """Interpolates between two sparse checkpoints and saves the results.

    Three networks are built with the same pruning mask (selected by
    `FLAGS.load_mask_from`): one with weights from `FLAGS.ckpt_start`, one
    with weights from `FLAGS.ckpt_end`, and a third network (also initialised
    from `FLAGS.ckpt_end`) that is passed to `interpolate` as `model_inter`.
    Both endpoints are evaluated with `test_model` on the evaluation split
    (train split if `FLAGS.eval_on_train`, otherwise test split) before
    interpolating. The value returned by `interpolate` is saved with `np.save`
    to `<FLAGS.logdir>/all_results`, and the operative gin config is written
    to `<FLAGS.logdir>/operative_config.gin`.

    Args:
        unused_argv: Positional command-line arguments (ignored).
    """
    init_timer = timer.Timer()
    init_timer.Start()
    if FLAGS.preload_gin_config:
        # Load default values from the original experiment, always the first
        # one. `skip_unknown=True` lets the preloaded file contain bindings for
        # configurables this binary does not register.
        with gin.unlock_config():
            gin.parse_config_file(FLAGS.preload_gin_config, skip_unknown=True)
        logging.info('Operative Gin configurations loaded from: %s',
                     FLAGS.preload_gin_config)
    # Parsed after the preloaded config so explicit configs/bindings win.
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

    data_train, data_test, info = utils.get_dataset()
    input_shape = info.features['image'].shape
    num_classes = info.features['label'].num_classes
    logging.info('Input Shape: %s', input_shape)
    logging.info('train samples: %s', info.splits['train'].num_examples)
    logging.info('test samples: %s', info.splits['test'].num_examples)
    data_eval = data_train if FLAGS.eval_on_train else data_test
    pruning_params = utils.get_pruning_params(mode='constant')
    # -1 passes no mask path; 0 / 1 take the mask from the start / end
    # checkpoint respectively.
    mask_load_dict = {-1: None, 0: FLAGS.ckpt_start, 1: FLAGS.ckpt_end}
    mask_path = mask_load_dict[FLAGS.load_mask_from]
    # Currently we interpolate only on the same sparse space.
    model_start = utils.get_network(pruning_params,
                                    input_shape,
                                    num_classes,
                                    mask_init_path=mask_path,
                                    weight_init_path=FLAGS.ckpt_start)
    # Route the summary through absl logging (as train.py does) instead of
    # printing to stdout, so it ends up in the log files.
    model_start.summary(print_fn=logging.info)
    model_end = utils.get_network(pruning_params,
                                  input_shape,
                                  num_classes,
                                  mask_init_path=mask_path,
                                  weight_init_path=FLAGS.ckpt_end)
    model_end.summary(print_fn=logging.info)

    # Create a third network for interpolation.
    model_inter = utils.get_network(pruning_params,
                                    input_shape,
                                    num_classes,
                                    mask_init_path=mask_path,
                                    weight_init_path=FLAGS.ckpt_end)
    logging.info('Performance at init (model_start):')
    test_model(model_start, data_eval)
    logging.info('Performance at init (model_end):')
    test_model(model_end, data_eval)
    all_results = interpolate(model_start=model_start,
                              model_end=model_end,
                              model_inter=model_inter,
                              d_set=data_eval)

    tf.io.gfile.makedirs(FLAGS.logdir)
    results_path = os.path.join(FLAGS.logdir, 'all_results')
    with tf.io.gfile.GFile(results_path, 'wb') as f:
        np.save(f, all_results)
    logging.info('Total runtime: %.3f s', init_timer.GetDuration())
    logconfigfile_path = os.path.join(FLAGS.logdir, 'operative_config.gin')
    with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
        f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())
示例#2
0
文件: train.py 项目: yaelandau22/rigl
def main(unused_argv):
  """Trains/evaluates a pruned network, or computes its Hessian spectrum.

  The behaviour is selected by `FLAGS.mode`:
    * 'train_eval': passes the train and test splits to `train_model`.
    * 'hessian': reloads the operative gin config previously saved in
      `FLAGS.logdir`, runs `test_model` on the test split and then `hessian`
      on the training split.
  Any other mode only builds the model and writes the config.

  Afterwards the operative gin config is written to `FLAGS.logdir`, named
  'operative_config.gin', or 'hessian_operative_config.gin' in hessian mode
  so that the config the Hessian run was preloaded from is not overwritten.

  Args:
    unused_argv: Positional command-line arguments (ignored).
  """
  tf.random.set_seed(FLAGS.seed)
  init_timer = timer.Timer()
  init_timer.Start()

  if FLAGS.mode == 'hessian':
    # Load default values from the original experiment.
    FLAGS.preload_gin_config = os.path.join(FLAGS.logdir,
                                            'operative_config.gin')

  # Maybe preload a gin config; the configs/bindings parsed below come after
  # it and can therefore override its values.
  if FLAGS.preload_gin_config:
    config_path = FLAGS.preload_gin_config
    gin.parse_config_file(config_path)
    logging.info('Gin configuration pre-loaded from: %s', config_path)

  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  ds_train, ds_test, info = utils.get_dataset()
  input_shape = info.features['image'].shape
  num_classes = info.features['label'].num_classes
  logging.info('Input Shape: %s', input_shape)
  logging.info('train samples: %s', info.splits['train'].num_examples)
  logging.info('test samples: %s', info.splits['test'].num_examples)

  pruning_params = utils.get_pruning_params()
  model = utils.get_network(pruning_params, input_shape, num_classes)
  model.summary(print_fn=logging.info)
  if FLAGS.mode == 'train_eval':
    train_model(model, ds_train, ds_test, FLAGS.logdir)
  elif FLAGS.mode == 'hessian':
    test_model(model, ds_test)
    hessian(model, ds_train, FLAGS.logdir)
  logging.info('Total runtime: %.3f s', init_timer.GetDuration())

  # Build the prefix separately: `+` binds tighter than a conditional
  # expression, so inlining it would drop 'operative_config.gin' in hessian
  # mode.
  config_prefix = 'hessian_' if FLAGS.mode == 'hessian' else ''
  logconfigfile_path = os.path.join(FLAGS.logdir,
                                    config_prefix + 'operative_config.gin')
  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
    f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())
示例#3
0
文件: train.py 项目: yaelandau22/rigl
def sparse_hessian_calculator(model,
                              data,
                              rows_at_once,
                              eigvals_path,
                              overwrite,
                              is_dense_spectrum=False):
  """Calculates the Hessian of the model parameters and returns its eigenvalues.

  Only layers that are instances of `utils.PRUNING_WRAPPER` contribute. For
  each such layer the kernel (restricted to its unmasked entries unless
  `is_dense_spectrum` is set) and the bias (always dense) contribute rows to
  the Hessian. Rows are computed in chunks via `get_rows`, converted to numpy,
  concatenated, and decomposed with `eigh` jitted for the CPU backend.

  Results are cached in `eigvals_path`: if the file exists and `overwrite` is
  False, the stored eigenvalues are loaded and returned without touching
  `data` or `model`.

  Args:
    model: Model exposing `.layers`; every layer with trainable variables is
      expected to be a pruning wrapper (asserted below).
    data: Dataset supporting `.batch(...)`, presumably a `tf.data.Dataset`
      yielding (x, y) pairs. Only its first batch of up to 100000 examples is
      used.
    rows_at_once: Approximate number of Hessian rows computed per `get_rows`
      call.
    eigvals_path: File where eigenvalues are cached (written with `np.save`).
    overwrite: If True, delete any cached eigenvalues and recompute them.
    is_dense_spectrum: If True, use all kernel entries instead of only the
      unmasked ones.

  Returns:
    numpy array: element 0 of `eigh`'s output (the eigenvalues).

  Raises:
    ValueError: if `data` is empty or no parameter rows remain to compute.
  """
  # Handle the cache first so a cache hit does not pay for reading the data.
  if tf.io.gfile.exists(eigvals_path) and overwrite:
    logging.info('Deleting existing Eigvals: %s', eigvals_path)
    tf.io.gfile.rmtree(eigvals_path)
  if tf.io.gfile.exists(eigvals_path):
    with tf.io.gfile.GFile(eigvals_path, 'rb') as f:
      eigvals = np.load(f)
    logging.info('Eigvals exists, skipping :%s', eigvals_path)
    return eigvals

  # Read all data at once as a single batch. Only the first batch is used, so
  # take it directly instead of materialising every batch in a list.
  first_batch = next(iter(data.batch(100000)), None)
  if first_batch is None:
    raise ValueError('Cannot compute the Hessian on an empty dataset.')
  x_batch, y_batch = first_batch

  # First lets create lists that indicate the valid dimension of each variable.
  # If we want to calculate sparse spectrum, then we have to omit masked
  # dimensions. Biases are dense, therefore have masks of 1's.
  masks = []
  variables = []
  layer_group_indices = []
  for layer in model.layers:
    if isinstance(layer, utils.PRUNING_WRAPPER):
      # TODO following the outcome of b/148083099, update following.
      # This code assumes the wrapper's weights are ordered
      # [kernel, bias, mask, ...].
      weight = layer.weights[0]
      variables.append(weight)

      mask = layer.weights[2]
      masks.append(mask)
      logging.info('Mask shape: %s', mask.shape)

      if is_dense_spectrum:
        n_params = tf.size(mask)
        layer_group_indices.append(tf.range(n_params))
      else:
        # Flat indices of the unpruned (mask == 1) kernel entries.
        fmask = tf.reshape(mask, [-1])
        indices = tf.where(tf.equal(fmask, 1))[:, 0]
        layer_group_indices.append(indices)
      # Add the bias mask of ones and all of its dimensions.
      bias = layer.weights[1]
      variables.append(bias)
      masks.append(tf.ones_like(bias))
      layer_group_indices.append(tf.range(tf.size(bias)))
    else:
      # For now we assume all parameterized layers are wrapped with
      # PruneLowMagnitude.
      assert not layer.trainable_variables

  result_all = []
  init_timer = timer.Timer()
  init_timer.Start()
  n_total = 0
  logging.info('Calculating Hessian...')
  for i, inds in enumerate(layer_group_indices):
    inds_np = inds.numpy()
    n_split = int(np.ceil(inds_np.size / rows_at_once))
    logging.info('Nsplit: %d', n_split)
    if n_split == 0:
      # Fully masked variable: it contributes no rows, and np.array_split
      # rejects zero sections.
      continue
    for c_slice in np.array_split(inds_np, n_split):
      res = get_rows(model, variables, masks, i, c_slice, x_batch, y_batch,
                     is_dense_spectrum)
      result_all.append(res.numpy())
      n_total += res.shape[0]
      # Row width equals the total number of Hessian columns.
      target_n = float(res.shape[1])
    # n_total / target_n is a fraction; scale it to the logged percentage.
    logging.info('%.3f %% ..', 100.0 * n_total / target_n)
  if not result_all:
    raise ValueError('No unmasked parameters found; the Hessian is empty.')
  # We convert in numpy so that it is on cpu automatically and we don't get OOM.
  c_hessian = np.concatenate(result_all, 0)
  logging.info('Total runtime for hessian: %.3f s', init_timer.GetDuration())
  init_timer.Start()
  eigens = jax.jit(eigh, backend='cpu')(c_hessian)
  eigvals = np.asarray(eigens[0])
  with tf.io.gfile.GFile(eigvals_path, 'wb') as f:
    np.save(f, eigvals)
  logging.info('EigVals saved: %s', eigvals_path)
  logging.info('Total runtime for eigvals: %.3f s', init_timer.GetDuration())
  return eigvals