Example 1
import tensorflow as tf  # TF 1.x; tf.contrib was removed in TF 2.x

# model_builder comes from the surrounding NASBench codebase.
def _create_estimator(spec, config, model_dir,
                      num_train_images, num_sample_images=None):
  """Creates the TPUEstimator object."""
  # Estimator will save a checkpoint at the end of every train() call. Disable
  # automatic checkpoints by setting the time interval between checkpoints to
  # a very large value.
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      keep_checkpoint_max=3,    # Keeps ckpt at start, halfway, and end
      save_checkpoints_secs=2**30,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=config['tpu_iterations_per_loop'],
          num_shards=config['tpu_num_shards']))

  # This is a hack to allow PREDICT on a fixed batch on TPU. By replicating the
  # batch by the number of shards, this ensures each TPU core operates on the
  # entire fixed batch.
  if num_sample_images and config['use_tpu']:
    num_sample_images *= config['tpu_num_shards']

  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=config['use_tpu'],
      model_fn=model_builder.build_model_fn(
          spec, config, num_train_images),
      config=run_config,
      train_batch_size=config['batch_size'],
      eval_batch_size=config['batch_size'],
      predict_batch_size=num_sample_images)

  return estimator
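
A minimal call sketch for the function above. The config keys match the ones read inside the function, but the values, the paths, `spec`, and `train_input_fn` are placeholders, not part of the original snippet:

# Hypothetical usage; `spec` is a NASBench ModelSpec built elsewhere and
# `train_input_fn` is assumed to exist.
config = {
    'use_tpu': False,
    'tpu_iterations_per_loop': 100,
    'tpu_num_shards': 2,
    'batch_size': 256,
}
estimator = _create_estimator(spec, config, '/tmp/model_dir',
                              num_train_images=50000,
                              num_sample_images=16)
estimator.train(input_fn=train_input_fn, max_steps=1000)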
Example 2
import os
import shutil

import tensorflow as tf  # TF 1.x

# `nb` (a nasbench.api.NASBench instance), `cfg`, `model_spec`, `model_builder`
# and `cifar` are module-level names in the original file.
def keytotuple(key):
    """Returns (params, training_time, accuracy, flops) for a NASBench hash."""
    cur_network_data = nb.get_metrics_from_hash(key)
    # cur_network_data[0]: fixed statistics (adjacency, operations, parameter
    # count); cur_network_data[1]: computed statistics keyed by epoch budget.
    model = model_spec.ModelSpec(cur_network_data[0]['module_adjacency'],
                                 cur_network_data[0]['module_operations'])
    model_fn = model_builder.build_model_fn(model, cfg, 60000)
    # Start from a clean scratch model_dir so no stale checkpoint is restored.
    if os.path.exists('empty'):
        shutil.rmtree('empty')
    run_cfg = tf.contrib.tpu.RunConfig(
        model_dir='empty',
        keep_checkpoint_max=3,  # Keeps ckpt at start, halfway, and end
        save_checkpoints_secs=2**30)
    # A plain Estimator suffices on CPU/GPU; the TPUEstimator variant with a
    # TPUConfig (iterations_per_loop, num_shards) is only needed on TPU.
    estimator = tf.estimator.Estimator(model_fn, config=run_cfg, params=cfg)
    # CIFAR input pipeline; built here but not consumed below, since the
    # model is only profiled, not trained.
    input_train = cifar.CIFARInput('train', cfg)

    # Build the model graph once with placeholder inputs and statically count
    # its floating-point operations with the TF profiler.
    with tf.Graph().as_default() as g:
        features = tf.placeholder(tf.float32, [cfg['batch_size'], 32, 32, 3])
        labels = tf.placeholder(tf.int32, [cfg['batch_size']])
        _ = model_fn(features,
                     labels,
                     mode=tf.estimator.ModeKeys.TRAIN,
                     params=cfg)
        opts = tf.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.profiler.profile(g, cmd='op', options=opts)
        n_flops = flops.total_float_ops

    # Average final training time and test accuracy over the repeated runs
    # recorded at the 108-epoch budget.
    training_time_sum = 0.0
    acc_sum = 0.0
    params = cur_network_data[0]['trainable_parameters']
    count = 0
    for item in cur_network_data[1][108]:
        count += 1
        training_time_sum += item['final_training_time']
        acc_sum += item['final_test_accuracy']
    training_time = training_time_sum / count
    acc = acc_sum / count

    return (params, training_time, acc, n_flops)
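
A hypothetical driver for the function above, assuming `nb` is created with the NASBench API; the dataset path is a placeholder:

# Hypothetical usage; iterates every module hash in the benchmark.
from nasbench import api
nb = api.NASBench('nasbench_only108.tfrecord')
results = [keytotuple(key) for key in nb.hash_iterator()]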
Example 3
import logging
from pathlib import Path

import numpy as np
import tensorflow as tf  # TF 1.x; tf.enable_eager_execution() is needed so
                         # the TFRecordDataset below can be iterated directly

def prepare_kd_dataset(spec, config, model_path, dataset_files,
                       new_dataset_path, trainset_part_percentage):
    """Rewrites the dataset for knowledge distillation: each label is
    extended with the teacher model's predicted logits."""
    # _dummy_imput_fn_mnist/_dummy_imput_fn_cifar are defined elsewhere in
    # the original file.
    if config['dataset'] == 'mnist':
        _dummy_imput_fn = _dummy_imput_fn_mnist
    elif config['dataset'] == 'cifar':
        _dummy_imput_fn = _dummy_imput_fn_cifar
    else:
        raise ValueError('Unsupported config[\'dataset\'] = {}. '
                         'Only cifar and mnist are supported.'.format(
                             config['dataset']))

    Path(new_dataset_path).mkdir(parents=True, exist_ok=True)
    for filename in dataset_files:
        raw_dataset = tf.data.TFRecordDataset([filename])
        params = {'file': filename, 'use_KD': False}
        estimator = tf.contrib.tpu.TPUEstimator(
            use_tpu=False,
            model_fn=model_builder.build_model_fn(spec, config, None),
            config=tf.contrib.tpu.RunConfig(model_dir=model_path),
            params=params,
            train_batch_size=config['batch_size'],
            eval_batch_size=config['batch_size'],
            predict_batch_size=100)

        est_preds = estimator.predict(input_fn=_dummy_imput_fn,
                                      yield_single_examples=False)
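        # With yield_single_examples=False each element of est_preds is a
        # whole batch of predictions, hence the np.vstack below.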
        all_pred_logits_aug = []
        for preds in est_preds:
            all_pred_logits_aug.append(preds['logits'])
        if not all_pred_logits_aug:
            logging.error("%s: all_pred_logits_aug is empty", filename)
            continue  # np.vstack would raise on an empty list
        all_pred_logits_aug = np.vstack(all_pred_logits_aug)

        filename = Path(filename)
        name_postfix = '_KD'
        if trainset_part_percentage != 100:
            name_postfix += '_' + str(trainset_part_percentage)
        out_file = filename.with_name(filename.stem + name_postfix)
        out_file = out_file.with_suffix(".tfrecords")
        out_file = Path(new_dataset_path, out_file.name)
        with tf.io.TFRecordWriter(str(out_file)) as record_writer:
            for i, raw_record in enumerate(raw_dataset):
                # Keep only trainset_part_percentage of the shard; each
                # shard is assumed to hold 10000 examples.
                if i >= 10000 * (trainset_part_percentage / 100.0):
                    break

                example = tf.train.Example()
                example.ParseFromString(raw_record.numpy())
                img = example.features.feature['image']
                label = example.features.feature['label']
                if len(all_pred_logits_aug) <= i:
                    logging.error(
                        "all_pred_logits_aug has only {} entries; i={}".format(
                            len(all_pred_logits_aug), i))
                    continue
                preds = all_pred_logits_aug[i]
                # New label: the original integer class followed by the
                # teacher's logits.
                new_label = np.hstack((label.int64_list.value[0], preds))
                feat_preds = tf.train.Feature(float_list=tf.train.FloatList(
                    value=new_label))
                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        'image': img,
                        'label': feat_preds
                    }))
                record_writer.write(example.SerializeToString())
            logging.info("{} stored".format(out_file))