def main(unused_argv):
  """Interpolates between two checkpoints and saves the results.

  Parses the gin configuration, optionally pre-loading the original
  experiment's operative config first. It then builds three networks that
  share the same pruning mask:
    * `model_start`, with weights from FLAGS.ckpt_start;
    * `model_end`, with weights from FLAGS.ckpt_end;
    * `model_inter`, a third network passed to `interpolate` for the
      interpolated weights.

  The two end points are evaluated, `interpolate` is run, and its results
  plus the operative gin config are written to FLAGS.logdir.

  Args:
    unused_argv: Leftover command-line arguments; ignored.

  Raises:
    ValueError: If FLAGS.load_mask_from is not one of -1, 0 or 1.
  """
  init_timer = timer.Timer()
  init_timer.Start()

  if FLAGS.preload_gin_config:
    # Load default values from the original experiment, always the first one.
    # skip_unknown=True tolerates bindings this binary does not register.
    with gin.unlock_config():
      gin.parse_config_file(FLAGS.preload_gin_config, skip_unknown=True)
    logging.info('Operative Gin configurations loaded from: %s',
                 FLAGS.preload_gin_config)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

  data_train, data_test, info = utils.get_dataset()
  input_shape = info.features['image'].shape
  num_classes = info.features['label'].num_classes
  logging.info('Input Shape: %s', input_shape)
  logging.info('train samples: %s', info.splits['train'].num_examples)
  logging.info('test samples: %s', info.splits['test'].num_examples)
  data_eval = data_train if FLAGS.eval_on_train else data_test

  pruning_params = utils.get_pruning_params(mode='constant')

  # FLAGS.load_mask_from selects which checkpoint supplies the mask:
  # -1 -> no mask path, 0 -> FLAGS.ckpt_start, 1 -> FLAGS.ckpt_end.
  mask_load_dict = {-1: None, 0: FLAGS.ckpt_start, 1: FLAGS.ckpt_end}
  if FLAGS.load_mask_from not in mask_load_dict:
    raise ValueError('FLAGS.load_mask_from must be one of %s, got %r.' %
                     (sorted(mask_load_dict), FLAGS.load_mask_from))
  mask_path = mask_load_dict[FLAGS.load_mask_from]

  # Currently we interpolate only on the same sparse space, so all three
  # networks are built with the same mask.
  model_start = utils.get_network(
      pruning_params,
      input_shape,
      num_classes,
      mask_init_path=mask_path,
      weight_init_path=FLAGS.ckpt_start)
  model_start.summary()
  model_end = utils.get_network(
      pruning_params,
      input_shape,
      num_classes,
      mask_init_path=mask_path,
      weight_init_path=FLAGS.ckpt_end)
  model_end.summary()
  # Create a third network for interpolation.
  model_inter = utils.get_network(
      pruning_params,
      input_shape,
      num_classes,
      mask_init_path=mask_path,
      weight_init_path=FLAGS.ckpt_end)

  logging.info('Performance at init (model_start):')
  test_model(model_start, data_eval)
  logging.info('Performance at init (model_end):')
  test_model(model_end, data_eval)

  all_results = interpolate(
      model_start=model_start,
      model_end=model_end,
      model_inter=model_inter,
      d_set=data_eval)

  tf.io.gfile.makedirs(FLAGS.logdir)
  # np.save only appends '.npy' when given a path string; writing through a
  # file object keeps the file name exactly 'all_results'.
  results_path = os.path.join(FLAGS.logdir, 'all_results')
  with tf.io.gfile.GFile(results_path, 'wb') as f:
    np.save(f, all_results)

  logging.info('Total runtime: %.3f s', init_timer.GetDuration())
  logconfigfile_path = os.path.join(FLAGS.logdir, 'operative_config.gin')
  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
    f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())
def main(unused_argv):
  """Trains/evaluates a model or computes its Hessian, per FLAGS.mode.

  Mode behavior:
    * 'train_eval': calls `train_model` with the train and test splits.
    * 'hessian': pre-loads the operative gin config written to FLAGS.logdir
      by an earlier run, evaluates the model on the test split, then calls
      `hessian` on the train split.

  In either case the operative gin config of this run is written to
  FLAGS.logdir at the end. In hessian mode the file name is prefixed with
  'hessian_'.

  Args:
    unused_argv: Leftover command-line arguments; ignored.
  """
  tf.random.set_seed(FLAGS.seed)
  init_timer = timer.Timer()
  init_timer.Start()

  if FLAGS.mode == 'hessian':
    # Load default values from the original experiment.
    FLAGS.preload_gin_config = os.path.join(FLAGS.logdir,
                                            'operative_config.gin')

  # Maybe preload a gin config.
  if FLAGS.preload_gin_config:
    config_path = FLAGS.preload_gin_config
    gin.parse_config_file(config_path)
    logging.info('Gin configuration pre-loaded from: %s', config_path)

  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  ds_train, ds_test, info = utils.get_dataset()
  input_shape = info.features['image'].shape
  num_classes = info.features['label'].num_classes
  logging.info('Input Shape: %s', input_shape)
  logging.info('train samples: %s', info.splits['train'].num_examples)
  logging.info('test samples: %s', info.splits['test'].num_examples)

  pruning_params = utils.get_pruning_params()
  model = utils.get_network(pruning_params, input_shape, num_classes)
  model.summary(print_fn=logging.info)

  if FLAGS.mode == 'train_eval':
    train_model(model, ds_train, ds_test, FLAGS.logdir)
  elif FLAGS.mode == 'hessian':
    test_model(model, ds_test)
    hessian(model, ds_train, FLAGS.logdir)
  logging.info('Total runtime: %.3f s', init_timer.GetDuration())

  # Compute the prefix separately: in the expression
  # `a if c else '' + 'operative_config.gin'` the conditional binds looser
  # than `+`, so hessian runs wrote to a file named just 'hessian_'.
  # The prefix also keeps a hessian run from overwriting the
  # operative_config.gin it pre-loaded above.
  config_prefix = 'hessian_' if FLAGS.mode == 'hessian' else ''
  logconfigfile_path = os.path.join(FLAGS.logdir,
                                    config_prefix + 'operative_config.gin')
  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
    f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())