예제 #1
0
def main(unused_argv):
    """Evaluates networks interpolated between two checkpoints and saves results.

    Parses the gin configuration (optionally preloading the config of the
    original experiment first), builds networks whose weights are initialized
    from `FLAGS.ckpt_start` and `FLAGS.ckpt_end`, evaluates both endpoints,
    runs `interpolate` between them, and writes the interpolation results and
    the operative gin config to `FLAGS.logdir`.

    Args:
      unused_argv: Positional command-line arguments; ignored.
    """
    init_timer = timer.Timer()
    init_timer.Start()
    if FLAGS.preload_gin_config:
        # Load default values from the original experiment, always the first one.
        # skip_unknown=True skips configurables/imports not known to this binary
        # instead of raising.
        with gin.unlock_config():
            gin.parse_config_file(FLAGS.preload_gin_config, skip_unknown=True)
        logging.info('Operative Gin configurations loaded from: %s',
                     FLAGS.preload_gin_config)
    # Parsed after the preloaded config so these configs/bindings take
    # precedence.
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

    data_train, data_test, info = utils.get_dataset()
    input_shape = info.features['image'].shape
    num_classes = info.features['label'].num_classes
    logging.info('Input Shape: %s', input_shape)
    logging.info('train samples: %s', info.splits['train'].num_examples)
    logging.info('test samples: %s', info.splits['test'].num_examples)
    data_eval = data_train if FLAGS.eval_on_train else data_test
    pruning_params = utils.get_pruning_params(mode='constant')
    # Source of the sparsity mask, selected by FLAGS.load_mask_from:
    # -1 -> no mask checkpoint, 0 -> FLAGS.ckpt_start, 1 -> FLAGS.ckpt_end.
    mask_load_dict = {-1: None, 0: FLAGS.ckpt_start, 1: FLAGS.ckpt_end}
    mask_path = mask_load_dict[FLAGS.load_mask_from]
    # Currently we interpolate only on the same sparse space: all three
    # networks below are built with the same `mask_path`.
    model_start = utils.get_network(pruning_params,
                                    input_shape,
                                    num_classes,
                                    mask_init_path=mask_path,
                                    weight_init_path=FLAGS.ckpt_start)
    model_start.summary()
    model_end = utils.get_network(pruning_params,
                                  input_shape,
                                  num_classes,
                                  mask_init_path=mask_path,
                                  weight_init_path=FLAGS.ckpt_end)
    model_end.summary()

    # Create a third network for interpolation. Its weights start from
    # FLAGS.ckpt_end; presumably `interpolate` overwrites them at each
    # interpolation step (TODO confirm).
    model_inter = utils.get_network(pruning_params,
                                    input_shape,
                                    num_classes,
                                    mask_init_path=mask_path,
                                    weight_init_path=FLAGS.ckpt_end)
    logging.info('Performance at init (model_start):')
    test_model(model_start, data_eval)
    logging.info('Performance at init (model_end):')
    test_model(model_end, data_eval)
    all_results = interpolate(model_start=model_start,
                              model_end=model_end,
                              model_inter=model_inter,
                              d_set=data_eval)

    tf.io.gfile.makedirs(FLAGS.logdir)
    # np.save is given an open file object, so the data lands at exactly this
    # path (no '.npy' suffix is appended).
    results_path = os.path.join(FLAGS.logdir, 'all_results')
    with tf.io.gfile.GFile(results_path, 'wb') as f:
        np.save(f, all_results)
    logging.info('Total runtime: %.3f s', init_timer.GetDuration())
    logconfigfile_path = os.path.join(FLAGS.logdir, 'operative_config.gin')
    with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
        f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())
예제 #2
0
파일: train.py 프로젝트: yaelandau22/rigl
def main(unused_argv):
  """Trains/evaluates a network or computes its Hessian, per `FLAGS.mode`.

  In 'train_eval' mode the network is trained via `train_model`. In 'hessian'
  mode the operative gin config that a 'train_eval' run writes to
  `FLAGS.logdir` is preloaded; the model is then evaluated on the test split
  and `hessian` is run on the training split. Finally the operative gin config
  is written to `FLAGS.logdir`. In 'hessian' mode its filename gets a
  'hessian_' prefix, so the training run's config is not overwritten.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  tf.random.set_seed(FLAGS.seed)
  init_timer = timer.Timer()
  init_timer.Start()

  if FLAGS.mode == 'hessian':
    # Load default values from the original experiment, i.e. the operative
    # config that a 'train_eval' run writes to FLAGS.logdir (see the end of
    # this function).
    FLAGS.preload_gin_config = os.path.join(FLAGS.logdir,
                                            'operative_config.gin')

  # Maybe preload a gin config.
  if FLAGS.preload_gin_config:
    config_path = FLAGS.preload_gin_config
    gin.parse_config_file(config_path)
    logging.info('Gin configuration pre-loaded from: %s', config_path)

  # Parsed after the preloaded config so these configs/bindings take
  # precedence.
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  ds_train, ds_test, info = utils.get_dataset()
  input_shape = info.features['image'].shape
  num_classes = info.features['label'].num_classes
  logging.info('Input Shape: %s', input_shape)
  logging.info('train samples: %s', info.splits['train'].num_examples)
  logging.info('test samples: %s', info.splits['test'].num_examples)

  pruning_params = utils.get_pruning_params()
  model = utils.get_network(pruning_params, input_shape, num_classes)
  model.summary(print_fn=logging.info)
  if FLAGS.mode == 'train_eval':
    train_model(model, ds_train, ds_test, FLAGS.logdir)
  elif FLAGS.mode == 'hessian':
    test_model(model, ds_test)
    hessian(model, ds_train, FLAGS.logdir)
  logging.info('Total runtime: %.3f s', init_timer.GetDuration())

  # Build the prefix separately. In the one-expression form
  # `'hessian_' if c else '' + 'operative_config.gin'`, `+` binds tighter than
  # the conditional. Hessian mode would then write to a file named just
  # 'hessian_'.
  config_prefix = 'hessian_' if FLAGS.mode == 'hessian' else ''
  logconfigfile_path = os.path.join(FLAGS.logdir,
                                    config_prefix + 'operative_config.gin')
  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
    f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())