def main(argv):
    """Print the training command-line flags for every experiment run.

    Emits one line per combination of trial in [1, 20], level in [0, 30] and
    run in [1, 10]. Every line carries ``--masks`` (the masks directory of
    ``constants.run(trial, level)``) and ``--output_dir``
    (``constants.run(trial, level, FLAGS.experiment, run)``). The ``reuse`` and
    ``reuse_sign`` experiments additionally get ``--initialization_distribution``
    from ``constants.initialization(level)``; ``reuse_sign`` also gets
    ``--same_sign`` pointing at ``paths.initial(constants.run(trial, level))``.

    Args:
      argv: Command-line arguments; unused.
    """
    del argv  # Unused.
    experiment = FLAGS.experiment
    # Decide once which optional flags this experiment needs.
    adds_initialization = experiment in ('reuse', 'reuse_sign')
    adds_same_sign = experiment == 'reuse_sign'

    for trial in range(1, 21):
        for level in range(31):
            for run in range(1, 11):
                pieces = [
                    '--masks={}'.format(
                        paths.masks(constants.run(trial, level))),
                    ' --output_dir={}'.format(
                        constants.run(trial, level, experiment, run)),
                ]
                if adds_initialization:
                    pieces.append(' --initialization_distribution=' +
                                  constants.initialization(level))
                if adds_same_sign:
                    pieces.append(' --same_sign={}'.format(
                        paths.initial(constants.run(trial, level))))
                print(''.join(pieces))
def train(sess, dataset, model, optimizer_fn, training_len, output_dir,
          **params):
    """Train a GAN-style model with alternating discriminator/generator steps.

    Training continues until training_len iterations or epochs have taken
    place. Each iteration runs one discriminator update followed by one
    generator update, each fed a freshly sampled noise batch for ``model.z``.

    Args:
      sess: A tensorflow session.
      dataset: The dataset on which to train (a child of
        dataset_base.DatasetBase).
      model: The model to train. This function reads ``lr``, ``D_loss``,
        ``G_loss``, ``theta_D``, ``theta_G``, ``z``, ``z_dim``, ``masks``,
        ``D_train_summaries``, ``G_train_summaries`` and
        ``get_current_weights`` from it (and ``test_summaries`` /
        ``validate_summaries`` if evaluation is re-enabled).
      optimizer_fn: Accepted for interface compatibility but currently unused;
        both networks are optimized with
        ``tf.train.AdamOptimizer(learning_rate=model.lr, beta1=0.5)``.
      training_len: A tuple whose first value is the unit of measure
        ("epochs" or "iterations") and whose second value is the number of
        units for which the network should be trained.
      output_dir: The directory to which any output should be saved.
      **params: Other parameters.
        save_summaries: whether to save summary data (TensorBoard events plus
          comma-separated text logs).
        save_network: whether to save the network before and after training.
        test_interval: None/absent if the test set should not be evaluated;
          otherwise, frequency (in iterations) at which the test set should be
          run.
        validate_interval: analogous to test_interval.
        noise_batch_size: number of noise vectors sampled for ``model.z`` on
          each step (default 32).

    Returns:
      A tuple ``(initial_weights, final_weights)`` holding the weights before
      and after training, as returned by ``model.get_current_weights``.
    """
    # Separate optimizers for the discriminator and generator, each restricted
    # to its own variable list.
    d_solver = tf.train.AdamOptimizer(
        learning_rate=model.lr, beta1=0.5).minimize(
            model.D_loss, var_list=model.theta_D)
    g_solver = tf.train.AdamOptimizer(
        learning_rate=model.lr, beta1=0.5).minimize(
            model.G_loss, var_list=model.theta_G)

    sess.run(tf.global_variables_initializer())
    initial_weights = model.get_current_weights(sess)

    train_handle = dataset.get_train_handle(sess)
    test_handle = dataset.get_test_handle(sess)
    validate_handle = dataset.get_validate_handle(sess)

    save_summaries = params.get('save_summaries', False)
    noise_batch_size = params.get('noise_batch_size', 32)
    unit, length = training_len[0], training_len[1]

    # Populated only when save_summaries is set; closed in the finally below.
    writer = None
    log_files = {}

    def record_summaries(iteration, records, fp):
        """Records summaries obtained from evaluating the network.

        Does nothing unless save_summaries is set.

        Args:
          iteration: The current training iteration as an integer.
          records: A list of serialized summary records to be written.
          fp: A file to which the records should be logged in an
            easier-to-parse format than the tensorflow summary files.
        """
        if not save_summaries:
            return
        log = ['iteration', str(iteration)]
        for record in records:
            # Log to tensorflow summaries for tensorboard.
            writer.add_summary(record, iteration)
            # Log to text file for convenience.
            summary_proto = tf.Summary()
            summary_proto.ParseFromString(record)
            value = summary_proto.value[0]
            log += [value.tag, str(value.simple_value)]
        fp.write(','.join(log) + '\n')

    def interval_due(key, iteration):
        """Whether the optional interval ``params[key]`` fires at iteration."""
        interval = params.get(key)
        # A None (or 0) interval disables evaluation rather than raising in %.
        return (save_summaries and bool(interval) and
                iteration % interval == 0)

    def collect_test_summaries(iteration):
        if interval_due('test_interval', iteration):
            sess.run(dataset.test_initializer)
            records = sess.run(model.test_summaries,
                               {dataset.handle: test_handle})
            record_summaries(iteration, records, log_files['test'])

    def collect_validate_summaries(iteration):
        if interval_due('validate_interval', iteration):
            sess.run(dataset.validate_initializer)
            records = sess.run(model.validate_summaries,
                               {dataset.handle: validate_handle})
            record_summaries(iteration, records, log_files['validate'])

    def sample_noise():
        return ModelWgan.sample_z(noise_batch_size, model.z_dim)

    # Encapsulated in a function so that it is possible to break out of
    # multiple loops simultaneously.
    def training_loop():
        """The main training loop."""
        iteration = 0
        epoch = 0
        while True:
            sess.run(dataset.train_initializer)
            epoch += 1
            # End training if we have passed the epoch limit.
            if unit == 'epochs' and epoch > length:
                return

            # One training epoch.
            while True:
                step = iteration + 1
                # End training if we have passed the iteration limit.
                if unit == 'iterations' and step > length:
                    return
                try:
                    d_records = sess.run(
                        [d_solver] + model.D_train_summaries,
                        {dataset.handle: train_handle,
                         model.z: sample_noise()})[1:]
                    g_records = sess.run(
                        [g_solver] + model.G_train_summaries,
                        {dataset.handle: train_handle,
                         model.z: sample_noise()})[1:]
                except tf.errors.OutOfRangeError:
                    # End of epoch. The counter only advances for completed
                    # steps, so an exhausted iterator does not consume part of
                    # the iteration budget.
                    break
                iteration = step
                record_summaries(iteration, d_records,
                                 log_files.get('D_train'))
                record_summaries(iteration, g_records,
                                 log_files.get('G_train'))
                # NOTE: test/validation evaluation is currently disabled; call
                # collect_test_summaries(iteration) and
                # collect_validate_summaries(iteration) here to re-enable it.

    try:
        # Optional operations to perform before training.
        if save_summaries:
            writer = tf.summary.FileWriter(paths.summaries(output_dir))
            for log_name in ('D_train', 'G_train', 'test', 'validate'):
                log_files[log_name] = tf.gfile.GFile(
                    paths.log(output_dir, log_name), 'w')

        if params.get('save_network', False):
            save_restore.save_network(paths.initial(output_dir),
                                      initial_weights)
            save_restore.save_network(paths.masks(output_dir), model.masks)

        training_loop()
    finally:
        # Release the logs even if training raises; FileWriter.close() also
        # flushes any pending TensorBoard events to disk.
        for fp in log_files.values():
            fp.close()
        if writer is not None:
            writer.close()

    # Retrieve the final weights of the model.
    final_weights = model.get_current_weights(sess)
    if params.get('save_network', False):
        save_restore.save_network(paths.final(output_dir), final_weights)

    return initial_weights, final_weights
avg_printer.do_print(trial, '\tTest: {}', [second_last_test_acc]) avg_printer.do_print( trial, '\t{:^20s}{:^20s}{:^20s}{:^20s}'.format('Nonempty', 'Rows', 'Columns', 'Weights'), ()) if second_last_run == 0: # First runs don't have masks, so we have to use the initial weights to get the shapes of the layers weights_dir = paths.initial(second_last_path) for mask_name in sorted(os.listdir(weights_dir)): mask = np.load(os.path.join(weights_dir, mask_name)) avg_printer.do_print( trial, '\t{:^20s}'.format(mask_name) + '{:>10.2f}/{:<10.2f}{:>10.2f}/{:<10.2f}{:>10.2f}/{:<10.2f}', [ mask.shape[0], mask.shape[0], mask.shape[1], mask.shape[1], mask.shape[0] * mask.shape[1], mask.shape[0] * mask.shape[1] ]) else: masks_dir = paths.masks(second_last_path) for mask_name in sorted(os.listdir(masks_dir)): mask = np.load(os.path.join(masks_dir, mask_name)) avg_printer.do_print( trial, '\t{:^20s}'.format(mask_name) + '{:>10.2f}/{:<10.2f}{:>10.2f}/{:<10.2f}{:>10.2f}/{:<10.2f}', [ np.sum(np.sum(mask, axis=1) > 0), mask.shape[0], np.sum(np.sum(mask, axis=0) > 0), mask.shape[1], np.sum(mask), mask.shape[0] * mask.shape[1] ]) avg_printer.flush()