def download(location=locations.DATASET_LOCATION):
  """Downloads the dataset and saves it to `location` as a dict of arrays."""
  d = {}
  (d['x_train'], d['y_train']), (d['x_test'], d['y_test']) = dataset.load_data()
  save_restore.save_network(location, d)
import tensorflow as tf

# `paths`, `save_restore`, and `ModelWgan` are assumed to be importable from
# the surrounding package.


def train(sess, dataset, model, optimizer_fn, training_len, output_dir,
          **params):
  """Train a model on a dataset.

  Training continues until training_len iterations or epochs have taken place.

  Args:
    sess: A tensorflow session.
    dataset: The dataset on which to train (a child of
      dataset_base.DatasetBase).
    model: The model to train (a child of model_base.ModelBase).
    optimizer_fn: A function that, when called, returns an instance of an
      optimizer object to be used to optimize the network. Currently unused:
      this WGAN variant builds its own Adam solvers for the discriminator and
      the generator.
    training_len: A tuple whose first value is the unit of measure
      ("epochs" or "iterations") and whose second value is the number of units
      for which the network should be trained.
    output_dir: The directory to which any output should be saved.
    **params: Other parameters.
      save_summaries is whether to save summary data.
      save_network is whether to save the network before and after training.
      test_interval is None if the test set should not be evaluated; otherwise,
        frequency (in iterations) at which the test set should be run.
      validate_interval is analogous to test_interval.

  Returns:
    A tuple of two dictionaries: the model weights before training and the
    weights after training.
  """
  # Create initial session parameters. Separate Adam solvers optimize the
  # discriminator and generator against their respective variable lists.
  D_solver = tf.train.AdamOptimizer(
      learning_rate=model.lr, beta1=0.5).minimize(
          model.D_loss, var_list=model.theta_D)
  G_solver = tf.train.AdamOptimizer(
      learning_rate=model.lr, beta1=0.5).minimize(
          model.G_loss, var_list=model.theta_G)

  sess.run(tf.global_variables_initializer())
  initial_weights = model.get_current_weights(sess)

  train_handle = dataset.get_train_handle(sess)
  test_handle = dataset.get_test_handle(sess)
  validate_handle = dataset.get_validate_handle(sess)

  # Optional operations to perform before training.
  if params.get('save_summaries', False):
    writer = tf.summary.FileWriter(paths.summaries(output_dir))
    D_train_file = tf.gfile.GFile(paths.log(output_dir, 'D_train'), 'w')
    G_train_file = tf.gfile.GFile(paths.log(output_dir, 'G_train'), 'w')
    test_file = tf.gfile.GFile(paths.log(output_dir, 'test'), 'w')
    validate_file = tf.gfile.GFile(paths.log(output_dir, 'validate'), 'w')
  else:
    # Placeholders so the record_summaries call sites below stay well-defined
    # when summaries are disabled (record_summaries is a no-op in that case).
    writer = None
    D_train_file = G_train_file = test_file = validate_file = None

  if params.get('save_network', False):
    save_restore.save_network(paths.initial(output_dir), initial_weights)
    save_restore.save_network(paths.masks(output_dir), model.masks)

  # Helper functions to collect and record summaries.
  def record_summaries(iteration, records, fp):
    """Records summaries obtained from evaluating the network.

    Args:
      iteration: The current training iteration as an integer.
      records: A list of records to be written.
      fp: A file to which the records should be logged in an easier-to-parse
        format than the tensorflow summary files.
    """
    if params.get('save_summaries', False):
      log = ['iteration', str(iteration)]
      for record in records:
        # Log to tensorflow summaries for tensorboard.
        writer.add_summary(record, iteration)

        # Log to text file for convenience.
        summary_proto = tf.Summary()
        summary_proto.ParseFromString(record)
        value = summary_proto.value[0]
        log += [value.tag, str(value.simple_value)]
      fp.write(','.join(log) + '\n')

  def collect_test_summaries(iteration):
    if (params.get('save_summaries', False) and
        params.get('test_interval') is not None and
        iteration % params['test_interval'] == 0):
      sess.run(dataset.test_initializer)
      records = sess.run(model.test_summaries, {dataset.handle: test_handle})
      record_summaries(iteration, records, test_file)

  def collect_validate_summaries(iteration):
    if (params.get('save_summaries', False) and
        params.get('validate_interval') is not None and
        iteration % params['validate_interval'] == 0):
      sess.run(dataset.validate_initializer)
      records = sess.run(model.validate_summaries,
                         {dataset.handle: validate_handle})
      record_summaries(iteration, records, validate_file)

  # Train for the specified number of epochs. This behavior is encapsulated
  # in a function so that it is possible to break out of multiple loops
  # simultaneously.
  def training_loop():
    """The main training loop encapsulated in a function."""
    iteration = 0
    epoch = 0
    while True:
      sess.run(dataset.train_initializer)
      epoch += 1

      # End training if we have passed the epoch limit.
      if training_len[0] == 'epochs' and epoch > training_len[1]:
        return

      # One training epoch.
      while True:
        try:
          iteration += 1

          # End training if we have passed the iteration limit.
          if training_len[0] == 'iterations' and iteration > training_len[1]:
            return

          # Train the discriminator and the generator, each on a fresh batch
          # of noise samples.
          # TODO: make the batch size configurable rather than hard-coding 32.
          D_records = sess.run(
              [D_solver] + model.D_train_summaries, {
                  dataset.handle: train_handle,
                  model.z: ModelWgan.sample_z(32, model.z_dim)
              })[1:]
          G_records = sess.run(
              [G_solver] + model.G_train_summaries, {
                  dataset.handle: train_handle,
                  model.z: ModelWgan.sample_z(32, model.z_dim)
              })[1:]
          record_summaries(iteration, D_records, D_train_file)
          record_summaries(iteration, G_records, G_train_file)

          # Collect test and validation data if applicable.
          collect_test_summaries(iteration)
          collect_validate_summaries(iteration)

        # End of epoch handling.
        except tf.errors.OutOfRangeError:
          break

  # Run the training loop.
  training_loop()

  # Clean up.
  if params.get('save_summaries', False):
    D_train_file.close()
    G_train_file.close()
    test_file.close()
    validate_file.close()

  # Retrieve the final weights of the model.
  final_weights = model.get_current_weights(sess)
  if params.get('save_network', False):
    save_restore.save_network(paths.final(output_dir), final_weights)

  return initial_weights, final_weights
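# A minimal usage sketch, not part of the original file. It shows how the
# documented arguments of `train` fit together. `build_dataset` and
# `build_wgan` are hypothetical factories standing in for the project's real
# dataset and ModelWgan constructors, and the output path is a placeholder.
def example_train_run(output_dir='/tmp/wgan_example'):
  dataset = build_dataset()  # hypothetical: a dataset_base.DatasetBase child
  model = build_wgan()  # hypothetical: a ModelWgan instance
  with tf.Session() as sess:
    return train(
        sess,
        dataset,
        model,
        optimizer_fn=None,  # unused: train builds its own WGAN solvers
        training_len=('iterations', 50000),
        output_dir=output_dir,
        save_summaries=True,
        save_network=True,
        test_interval=1000,
        validate_interval=1000)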
def download(location=locations.MNIST_LOCATION):
  """Downloads MNIST and saves it to `location` as a dict of numpy arrays."""
  d = {}
  (d['x_train'], d['y_train']), (d['x_test'], d['y_test']) = mnist.load_data()
  save_restore.save_network(location, d)
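# A minimal sketch of reading the saved arrays back, assuming save_restore
# exposes a `restore_network` counterpart to the `save_network` call above;
# if the project names that helper differently, substitute accordingly.
def example_load(location=locations.MNIST_LOCATION):
  d = save_restore.restore_network(location)
  return (d['x_train'], d['y_train']), (d['x_test'], d['y_test'])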