Example #1
import csv

import numpy as np
import tensorflow as tf

# `paths` is a project-local helper module that maps (directory, name) to log
# file locations; adjust this import to match your checkout's layout.
from foundations import paths


def read_log(directory, name='test', tail=0):
  """Reads logged data about the performance of a lottery ticket experiment.

  Args:
    directory: The directory where the log data for a particular experiment
      is stored.
    name: Whether to retrieve data from the "test", "train", or "validate"
      logs.
    tail: If nonzero, returns only the last tail entries in each run.

  Returns:
    A dictionary with three keys.
    'iteration' is a numpy array of the iterations at which data was collected.
    'loss' is a numpy array of loss values at the corresponding iteration.
    'accuracy' is a numpy array of accuracy values at the corresponding
      iteration.
  """
  output = {
      'iteration': [],
      'loss': [],
      'accuracy': [],
  }

  with tf.io.gfile.GFile(paths.log(directory, name)) as fp:
    reader = csv.reader(fp)
    for row in reader:
      # Rows written by write_log have the form:
      #   iteration,<it>,loss,<loss>,accuracy,<acc>
      # so the numeric values sit at the odd-numbered indices.
      output['iteration'].append(float(row[1]))
      output['loss'].append(float(row[3]))
      output['accuracy'].append(float(row[5]))

  # When tail == 0, the slice [-0:] covers the whole list, so every entry is
  # kept; otherwise only the last `tail` entries survive.
  output['iteration'] = np.array(output['iteration'][-tail:])
  output['loss'] = np.array(output['loss'][-tail:])
  output['accuracy'] = np.array(output['accuracy'][-tail:])

  return output
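
# A minimal usage sketch, assuming an experiment directory that already
# contains a "test" log; the path below is hypothetical.
log_data = read_log('/tmp/lottery/mnist_fc/trial1/run0', name='test', tail=10)
print('final accuracy: {:.4f}'.format(log_data['accuracy'][-1]))
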
def write_log(data, directory, name='test'):
  """Writes data about the performance of a lottery ticket experiment.

  Input data takes the same form as data returned by read_log. Writes a file
  in the format read by read_log.

  Args:
    data: The data to be written to the file. Takes the same form as the data
      returned by read_log.
    directory: The directory where the log data for a particular experiment is
      to be stored.
    name: What to call the data file itself.
  """
  with tf.io.gfile.GFile(paths.log(directory, name), 'w') as fp:
    for it, loss, acc in zip(data['iteration'], data['loss'],
                             data['accuracy']):
      fp.write(','.join(('iteration', str(it), 'loss', str(loss),
                         'accuracy', str(acc))))
      fp.write('\n')
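
# Round-trip sketch (hypothetical path): write_log emits one CSV row per
# entry, "iteration,<it>,loss,<loss>,accuracy,<acc>", which read_log parses
# back by position, so the two functions invert each other.
demo = {'iteration': [100.0, 200.0], 'loss': [2.31, 1.87],
        'accuracy': [0.12, 0.35]}
write_log(demo, '/tmp/lottery/demo')
restored = read_log('/tmp/lottery/demo')
assert list(restored['accuracy']) == demo['accuracy']
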
Example #3
# Relies on the same project-local helpers as above (`paths`, `save_restore`)
# plus a `ModelWgan` class that provides `sample_z`.
def train(sess, dataset, model, optimizer_fn, training_len, output_dir,
          **params):
    """Train a model on a dataset.

  Training continues until training_len iterations or epochs have taken place.

  Args:
    sess: A tensorflow session
    dataset: The dataset on which to train (a child of dataset_base.DatasetBase)
    model: The model to train (a child of model_base.ModelBase)
    optimizer_fn: A function that, when called, returns an instance of an
      optimizer object to be used to optimize the network.
    training_len: A tuple whose first value is the unit of measure
      ("epochs" or "iterations") and whose second value is the number of
      units for which the network should be trained.
    output_dir: The directory to which any output should be saved.
    **params: Other parameters.
      save_summaries is whether to save summary data.
      save_network is whether to save the network before and after training.
      test_interval is None if the test set should not be evaluated; otherwise,
        frequency (in iterations) at which the test set should be run.
      validate_interval is analogous to test_interval.

  Returns:
      A dictionary containing the weights before training and the weights after
      training.
  """
    # Create initial session parameters. The generic single-network step,
    #   optimize = optimizer_fn().minimize(model.loss)
    # is replaced here by separate Adam solvers for the discriminator and
    # the generator.
    D_solver = (tf.train.AdamOptimizer(
        learning_rate=model.lr, beta1=0.5).minimize(model.D_loss,
                                                    var_list=model.theta_D))
    G_solver = (tf.train.AdamOptimizer(
        learning_rate=model.lr, beta1=0.5).minimize(model.G_loss,
                                                    var_list=model.theta_G))
    sess.run(tf.global_variables_initializer())
    initial_weights = model.get_current_weights(sess)

    train_handle = dataset.get_train_handle(sess)
    test_handle = dataset.get_test_handle(sess)
    validate_handle = dataset.get_validate_handle(sess)

    # Optional operations to perform before training.
    if params.get('save_summaries', False):
        writer = tf.summary.FileWriter(paths.summaries(output_dir))
        D_train_file = tf.gfile.GFile(paths.log(output_dir, 'D_train'), 'w')
        G_train_file = tf.gfile.GFile(paths.log(output_dir, 'G_train'), 'w')
        test_file = tf.gfile.GFile(paths.log(output_dir, 'test'), 'w')
        validate_file = tf.gfile.GFile(paths.log(output_dir, 'validate'), 'w')

    if params.get('save_network', False):
        save_restore.save_network(paths.initial(output_dir), initial_weights)
        save_restore.save_network(paths.masks(output_dir), model.masks)

    # Helper functions to collect and record summaries.
    def record_summaries(iteration, records, fp):
        """Records summaries obtained from evaluating the network.

        Args:
          iteration: The current training iteration as an integer.
          records: A list of records to be written.
          fp: A file to which the records should be logged in an
            easier-to-parse format than the tensorflow summary files.
        """
        if params.get('save_summaries', False):
            log = ['iteration', str(iteration)]
            for record in records:
                # Log to tensorflow summaries for tensorboard.
                writer.add_summary(record, iteration)

                # Log to text file for convenience.
                summary_proto = tf.Summary()
                summary_proto.ParseFromString(record)
                value = summary_proto.value[0]
                log += [value.tag, str(value.simple_value)]
            fp.write(','.join(log) + '\n')

    def collect_test_summaries(iteration):
        if (params.get('save_summaries', False) and 'test_interval' in params
                and iteration % params['test_interval'] == 0):
            sess.run(dataset.test_initializer)
            records = sess.run(model.test_summaries,
                               {dataset.handle: test_handle})
            record_summaries(iteration, records, test_file)

    def collect_validate_summaries(iteration):
        if (params.get('save_summaries', False)
                and 'validate_interval' in params
                and iteration % params['validate_interval'] == 0):
            sess.run(dataset.validate_initializer)
            records = sess.run(model.validate_summaries,
                               {dataset.handle: validate_handle})
            record_summaries(iteration, records, validate_file)

    # Train for the specified number of epochs or iterations. This behavior
    # is encapsulated in a function so that it is possible to break out of
    # multiple loops simultaneously.
    def training_loop():
        """The main training loop encapsulated in a function."""
        iteration = 0
        epoch = 0
        while True:
            sess.run(dataset.train_initializer)
            epoch += 1

            # End training if we have passed the epoch limit.
            if training_len[0] == 'epochs' and epoch > training_len[1]:
                return

            # One training epoch.
            while True:
                try:
                    iteration += 1

                    # End training if we have passed the iteration limit.
                    if (training_len[0] == 'iterations'
                            and iteration > training_len[1]):
                        return

                    # Train. The original single-network step,
                    #   records = sess.run([optimize] + model.train_summaries,
                    #                      {dataset.handle: train_handle})[1:]
                    # is replaced by alternating discriminator and generator
                    # steps, each fed a fresh latent batch.
                    # TODO: make the latent batch size (32) configurable
                    # rather than hard-coded.
                    D_records = sess.run(
                        [D_solver] + model.D_train_summaries, {
                            dataset.handle: train_handle,
                            model.z: ModelWgan.sample_z(32, model.z_dim)
                        })[1:]
                    G_records = sess.run(
                        [G_solver] + model.G_train_summaries, {
                            dataset.handle: train_handle,
                            model.z: ModelWgan.sample_z(32, model.z_dim)
                        })[1:]

                    record_summaries(iteration, D_records, D_train_file)
                    record_summaries(iteration, G_records, G_train_file)


                    # Collect test and validation data if applicable.
                    # (Disabled in this variant.)
                    # collect_test_summaries(iteration)
                    # collect_validate_summaries(iteration)

                # End-of-epoch handling.
                except tf.errors.OutOfRangeError:
                    break

    # Run the training loop.
    training_loop()

    # Clean up.
    if params.get('save_summaries', False):
        D_train_file.close()
        G_train_file.close()
        test_file.close()
        validate_file.close()

    # Retrieve the final weights of the model.
    final_weights = model.get_current_weights(sess)
    if params.get('save_network', False):
        save_restore.save_network(paths.final(output_dir), final_weights)

    return initial_weights, final_weights
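
# Hypothetical call sketch: train the WGAN for five epochs, saving summaries
# and the before/after weights. `sess`, `dataset`, and `model` are assumed to
# be constructed elsewhere; `optimizer_fn` is required positionally but
# unused by this variant.
initial, final = train(
    sess, dataset, model,
    optimizer_fn=lambda: tf.train.AdamOptimizer(learning_rate=1e-4),
    training_len=('epochs', 5),
    output_dir='/tmp/lottery/wgan_run0',
    save_summaries=True,
    save_network=True)
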
Example #4
import os
import re
import subprocess
import sys

# `paths`, `constants`, and `avg_printer` are project-local modules assumed
# to be importable here.

exp_path = paths.experiment(constants.EXPERIMENT_PATH, sys.argv[1])
trial_nums = [
    int(re.findall(r'\d+', trial_dir)[0]) for trial_dir in os.listdir(exp_path)
]
print("Found trials numbered up to {}".format(max(trial_nums)))
for trial in range(1, max(trial_nums) + 1):
    trial_path = paths.trial(exp_path, trial)
    if not os.path.isdir(trial_path):
        print("Warning: skipping trial {}, does not exist".format(trial))
        continue

    # The last CSV field of a log's final line is the accuracy (see the
    # write_log format above), so `tail -n 1` plus a split recovers it.
    first_run_path = paths.run(trial_path, 0)
    first_run_train_acc = float(
        subprocess.check_output(
            ['tail', '-n', '1',
             paths.log(first_run_path,
                       'train')]).decode().strip().split(',')[-1])
    avg_printer.do_print(trial, '\tFirst run train acc: {}',
                         [first_run_train_acc])
    first_run_test_acc = float(
        subprocess.check_output(
            ['tail', '-n', '1',
             paths.log(first_run_path,
                       'test')]).decode().strip().split(',')[-1])
    avg_printer.do_print(trial, '\tFirst run test acc: {}',
                         [first_run_test_acc])

    # Materialize and sort the run numbers; under Python 3, map() returns an
    # iterator, which supports neither len() nor indexing.
    runs = sorted(int(run_dir) for run_dir in os.listdir(trial_path))
    second_last_run = runs[-2] if len(runs) > 1 else runs[0]
    second_last_path = paths.run(trial_path, second_last_run)

    second_last_train_acc = float(
        subprocess.check_output(
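
# A hypothetical pure-Python alternative to the `tail` subprocess calls used
# above (not part of the original script): read the last accuracy directly.
def last_accuracy(log_path):
    """Returns the accuracy field from the final line of a log file."""
    with open(log_path) as fp:
        last_line = fp.readlines()[-1]
    # Rows have the form: iteration,<it>,loss,<loss>,accuracy,<acc>
    return float(last_line.strip().split(',')[-1])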