Example #1
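A minimal sketch of the module-level setup this excerpt assumes; the import
paths for data_loader and AlexNet depend on the surrounding project and are
omitted here.

import os
import time
import logging

logger = logging.getLogger(__name__)
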
def run_imagenet_test():
  """ Runs the a test that trains a CNN to classify ImageNet data.
  Returns:
    A tuple containing the total elapsed time, and the average number of
    training iterations per second. """
  batch_size = 128
  # How many batches to keep loaded in VRAM at once.
  load_batches = 5

  # Learning rate hyperparameters.
  learning_rate = 0.00001
  decay_steps = 10000
  decay_rate = 1
  momentum = 0.9
  weight_decay = 0.0005

  # RMSProp hyperparameters (used only by the alternative trainer below).
  rho = 0.9
  epsilon = 1e-6

  # Where we save the network.
  save_file = "/home/theano/training_data/alexnet.pkl"
  synsets_save_file = "/home/theano/training_data/synsets.pkl"
  # Where we load the synsets to use from.
  synset_list = "/job_files/ilsvrc16_synsets.txt"
  # Where to load and save datasets.
  dataset_path = "/home/theano/training_data/ilsvrc16_dataset"
  # Where to cache image data.
  cache_path = "/home/theano/training_data/cache"
  # Where to save downloaded synset info.
  synset_dir = "/home/theano/training_data/synsets"

  data = data_loader.ImagenetLoader(batch_size, load_batches, cache_path,
                                    dataset_path, synset_dir, synset_list)
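  # Reuse previously saved synset assignments if they exist.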
  if os.path.exists(synsets_save_file):
    data.load(synsets_save_file)
  train = data.get_train_set()
  test = data.get_test_set()
  cpu_labels = data.get_non_shared_test_set()

  if os.path.exists(save_file):
    # Load from the file.
    print "Theano: Loading network from file..."
    network = AlexNet.load(save_file, train, test, batch_size,
                           learning_rate=learning_rate)

  else:
    # Build new network.
    network = AlexNet(train, test, batch_size,
                      patch_separation=batch_size * load_batches)

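    # Train with SGD + momentum by default; an RMSProp trainer is available as
    # an alternative (commented out below).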
    network.use_sgd_trainer(learning_rate, momentum=momentum,
                            weight_decay=weight_decay,
                            decay_rate=decay_rate,
                            decay_steps=decay_steps)
    #network.use_rmsprop_trainer(learning_rate, rho, epsilon,
    #                            decay_rate=decay_rate,
    #                            decay_steps=decay_steps)

  print "Theano: Starting ImageNet test..."

  accuracy = 0
  start_time = time.time()
  iterations = 0

  train_batch_index = 0
  test_batch_index = 0

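  # Main training loop: each pass trains on one batch, swapping fresh data into
  # VRAM when the loaded batches are exhausted; every 100 iterations it also
  # evaluates on the test set and checkpoints the network and synset data.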
  while iterations < 150000:
    logger.debug("Train index, size: %d, %d" % (train_batch_index,
                                                data.get_train_batch_size()))
    logger.debug("Test index, size: %d, %d" % (test_batch_index,
                                               data.get_test_batch_size()))

    # Swap in new data if we need to.
    if (train_batch_index + 1) * batch_size > data.get_train_batch_size():
      train_batch_index = 0
      logger.info("Getting train set.")
      train = data.get_train_set()
      logger.info("Got train set.")
    # Swap in new data if we need to.
    # The test data includes 10 patches per image, so one patch's worth of
    # unique images is a tenth of the loaded set.
    test_set_one_patch = data.get_test_batch_size() // 10
    if (test_batch_index + 1) * batch_size > test_set_one_patch:
      test_batch_index = 0
      logger.info("Getting test set.")
      test = data.get_test_set()
      cpu_labels = data.get_non_shared_test_set()[:]
      logger.info("Got test set.")

    if iterations % 100 == 0:
      # cpu_labels contains labels for every batch currently loaded in VRAM,
      # without duplicates for additional patches.
      label_index = test_batch_index * batch_size
      batch_labels = cpu_labels[label_index:label_index + batch_size]
      top_one, top_five = network.test(test_batch_index, batch_labels)
      logger.info("Step %d, testing top 1: %f, testing top 5: %f" % \
                  (iterations, top_one, top_five))

      test_batch_index += 1

    cost, rate, step = network.train(train_batch_index)
    logger.info("Training cost: %f, learning rate: %f, step: %d" % \
                (cost, rate, step))

    if iterations % 100 == 0:
      print "Saving network..."
      network.save(save_file)
      # Save synset data as well.
      data.save(synsets_save_file)

    iterations += 1
    train_batch_index += 1

  elapsed = time.time() - start_time
  speed = iterations / elapsed
  print("Theano: Ran %d training iterations. (%f iter/s)" % \
      (iterations, speed))
  print("Theano: Imagenet test completed in %f seconds." % (elapsed))

  data.exit_gracefully()

  return (elapsed, speed)
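
A usage sketch, assuming the module is run directly (the guard below is not
part of the original excerpt):

if __name__ == "__main__":
  elapsed, speed = run_imagenet_test()
  print("Finished in %f seconds (%f iterations/s)." % (elapsed, speed))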