def _test_input_fn(self, data_dir):
    """
    Test training dataset.

    Builds a TFRecord dataset for this shard and verifies every batch
    against the expected sparse weights, dense features and labels.

    :param data_dir: directory containing the input records.
    :return: None
    """
    batch_size = self.batch_size
    dataset = per_record_input_fn(data_dir, self.test_metadata_file, self.num_shards,
                                  self.shard_index, batch_size, constants.TFRECORD)
    next_batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
    # Expected sparse tensors for the 'weight' feature, one per batch.
    expected_sparse = self.generate_sparse_tensors(weight_indices, weight_values, 5,
                                                   batch_size, self.shard_index)
    batch_idx = 0
    with self.session() as sess:
        expected_sparse_val = sess.run(expected_sparse)
        while True:
            try:
                features, response = sess.run(next_batch)
            except tf.errors.OutOfRangeError:
                # Dataset exhausted.
                break
            lo = batch_idx * batch_size
            hi = lo + batch_size
            self.assertAllEqual(features['weight'].values,
                                expected_sparse_val[batch_idx].values)
            self.assertAllEqual(features['weight'].indices,
                                expected_sparse_val[batch_idx].indices)
            self.assertAllEqual(features['f1'], f1[lo:hi] + self.shard_index)
            self.assertAllEqual(response['response'], labels[lo:hi] + self.shard_index)
            batch_idx += 1
    # Every full batch must have been seen exactly once.
    self.assertEqual(batch_idx, num_records // self.batch_size)
def _get_num_iterations(self, input_files, metadata_file):
    """
    Get the number of batches assigned to this worker.

    Works for TFRecord input only: builds a throwaway, unsharded dataset over
    the given files and iterates it once, counting batches, so the cost is a
    full pass over the assigned files.

    :param input_files: a list of TFRecord files.
    :param metadata_file: the metadata associated with the TFRecord files.
    :return: number of iterations (batches of size self.batch_size).
    :raises ValueError: if self.data_format is not TFRECORD.
    """
    start_time = time.time()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if self.data_format != constants.TFRECORD:
        raise ValueError('_get_num_iterations supports {} only, got {}'.format(
            constants.TFRECORD, self.data_format))
    # Reset the default graph; this must happen before the main graph is built.
    tf1.reset_default_graph()
    num_iterations = 0
    dataset = per_record_input_fn(input_files, metadata_file, 1, 0, self.batch_size,
                                  self.data_format, build_features=False)
    data_iterator = tf1.data.make_initializable_iterator(dataset)
    next_item = data_iterator.get_next()
    # Counting is pure bookkeeping; pin it to CPU.
    with tf1.device('device:CPU:0'), tf1.Session() as sess:
        sess.run(data_iterator.initializer)
        while True:
            try:
                sess.run(next_item)
                num_iterations += 1
            except tf.errors.OutOfRangeError:
                break
    end_time = time.time()
    logging(
        f'It took {end_time - start_time} seconds to count {num_iterations} batches '
        f'with batch size {self.batch_size}.')
    return num_iterations
def predict(self, output_dir, input_data_path, metadata_file, checkpoint_path,
            execution_context, schema_params):
    """
    Overwrite predict method from parent class.

    Runs inference in local mode: builds the inference graph, loads the model
    coefficients and writes scored records to output_dir.

    :param output_dir: directory where inference results are written.
    :param input_data_path: path to the input records to score.
    :param metadata_file: metadata associated with the input records.
    :param checkpoint_path: unused here; kept for interface compatibility.
    :param execution_context: dict with task index / number of workers.
    :param schema_params: schema parameters for input parsing.
    :return: None
    """
    logging("Kicking off fixed effect LR predict")
    task_index = execution_context[constants.TASK_INDEX]
    num_workers = execution_context[constants.NUM_WORKERS]
    # Prediction uses a local server.
    self.server = tf1.train.Server.create_local_server()
    # Define the graph here and keep the session open so the scipy L-BFGS solver
    # can repeatedly call _compute_loss_and_gradients. Inference runs locally.
    with tf1.variable_scope('worker{}'.format(task_index)), tf1.device('device:CPU:0'):
        dataset = per_record_input_fn(input_data_path, metadata_file, num_workers,
                                      task_index, self.batch_size, self.data_format)
        x_placeholder = tf1.placeholder(tf1.float64, shape=[None])
        data_diter = tf1.data.make_one_shot_iterator(dataset)
        assigned_files = self._get_assigned_files(input_data_path, num_workers, task_index)
        # BUGFIX: _get_num_iterations requires the metadata file as its second
        # argument; the previous call passed the file list only (TypeError).
        data_num_iterations = self._get_num_iterations(assigned_files, metadata_file)
        sample_ids_op, labels_op, weights_op, scores_op, scores_and_offsets_op = \
            self._inference_model_fn(data_diter, x_placeholder,
                                     data_num_iterations, schema_params)
        init_variables_op = tf1.global_variables_initializer()
    session_creator = tf1.train.ChiefSessionCreator(master=self.server.target)
    tf_session = tf1.train.MonitoredSession(session_creator=session_creator)
    tf_session.run(init_variables_op)
    predict_ops = (sample_ids_op, labels_op, weights_op, scores_op, scores_and_offsets_op)
    model_coefficients = self._load_model()
    self._run_inference(model_coefficients, tf_session, x_placeholder, predict_ops,
                        task_index, schema_params, output_dir)
    logging("Snooze before closing the session")
    snooze_after_tf_session_closure(tf_session, self.delayed_exit_in_seconds)
    logging("Closed the session")
def train(self, training_data_dir, validation_data_dir, metadata_file, checkpoint_path,
          execution_context, schema_params):
    """
    Overwrite train method from parent class.

    Builds the distributed training graph, optimizes coefficients with scipy's
    L-BFGS solver (each iteration calls _compute_loss_and_gradients through the
    open session), then runs inference on training and validation data and
    saves the model on the chief worker.

    :param training_data_dir: directory with the training records.
    :param validation_data_dir: directory with the validation records.
    :param metadata_file: metadata associated with the records.
    :param checkpoint_path: unused here; kept for interface compatibility.
    :param execution_context: dict with task index / number of workers / chief flag.
    :param schema_params: schema parameters for input parsing.
    :return: None
    """
    logging("Kicking off fixed effect LR LBFGS training")
    task_index = execution_context[constants.TASK_INDEX]
    num_workers = execution_context[constants.NUM_WORKERS]
    is_chief = execution_context[constants.IS_CHIEF]
    self._create_server(execution_context)
    assigned_train_files = self._get_assigned_files(training_data_dir, num_workers, task_index)
    if self.copy_to_local:
        train_input_dir = self.local_training_input_dir
        actual_train_files = copy_files(assigned_train_files, train_input_dir)
        # After copying the worker's shard to local disk, the local files are
        # not sharded again.
        train_num_shards = 1
        train_shard_index = 0
    else:
        # NOTE(review): reads self.training_data_dir rather than the
        # training_data_dir parameter — confirm the two are always the same.
        train_input_dir = self.training_data_dir
        actual_train_files = assigned_train_files
        train_num_shards = num_workers
        train_shard_index = task_index
    # Define the graph here and keep the session open so the scipy L-BFGS
    # solver can repeatedly call _compute_loss_and_gradients.
    with tf1.variable_scope('worker{}'.format(task_index)), \
            tf1.device('job:worker/task:{}/device:CPU:0'.format(task_index)):
        # Ops for training.
        train_dataset = per_record_input_fn(train_input_dir, metadata_file, train_num_shards,
                                            train_shard_index, self.batch_size,
                                            self.data_format)
        train_diter = tf1.data.make_initializable_iterator(train_dataset)
        init_train_dataset_op = train_diter.initializer
        train_x_placeholder = tf1.placeholder(tf1.float64, shape=[None])
        # BUGFIX: _get_num_iterations requires the metadata file as its second
        # argument; the previous call passed the file list only (TypeError).
        train_num_iterations = self._get_num_iterations(actual_train_files, metadata_file)
        value_op, gradients_op = self._train_model_fn(
            train_diter, train_x_placeholder, num_workers, self.num_features,
            self.global_num_samples, train_num_iterations, schema_params)
        train_ops = (init_train_dataset_op, value_op, gradients_op)
        # Ops for inference.
        valid_dataset = per_record_input_fn(validation_data_dir, metadata_file, num_workers,
                                            task_index, self.batch_size, self.data_format)
        inference_x_placeholder = tf1.placeholder(tf1.float64, shape=[None])
        inference_train_data_diter = tf1.data.make_one_shot_iterator(train_dataset)
        train_sample_ids_op, train_labels_op, train_weights_op, train_prediction_score_op, \
            train_prediction_score_per_coordinate_op = self._inference_model_fn(
                inference_train_data_diter, inference_x_placeholder,
                train_num_iterations, schema_params)
        inference_validation_data_diter = tf1.data.make_one_shot_iterator(valid_dataset)
        assigned_validation_files = self._get_assigned_files(validation_data_dir,
                                                             num_workers, task_index)
        # BUGFIX: pass the metadata file here as well.
        validation_data_num_iterations = self._get_num_iterations(assigned_validation_files,
                                                                  metadata_file)
        valid_sample_ids_op, valid_labels_op, valid_weights_op, valid_prediction_score_op, \
            valid_prediction_score_per_coordinate_op = self._inference_model_fn(
                inference_validation_data_diter, inference_x_placeholder,
                validation_data_num_iterations, schema_params)
        if num_workers > 1:
            # Scalar all-reduce used both to warm up and to synchronize workers.
            all_reduce_sync_op = collective_ops.all_reduce(
                tf1.constant(0.0, tf1.float64), num_workers,
                FixedEffectLRModelLBFGS.TF_ALL_REDUCE_GROUP_KEY, 0,
                merge_op='Add', final_op='Id')
        init_variables_op = tf1.global_variables_initializer()
    session_creator = tf1.train.ChiefSessionCreator(master=self.server.target)
    tf_session = tf1.train.MonitoredSession(session_creator=session_creator)
    tf_session.run(init_variables_op)
    # Load an existing model if available.
    logging("Try to load initial model coefficients...")
    prev_model = self._load_model(catch_exception=True)
    if prev_model is None or len(prev_model) != self.num_features + 1:
        logging("No initial model found, use all zeros instead.")
        x0 = np.zeros(self.num_features + 1)
    else:
        logging("Found a previous model, loaded as the initial point for training")
        x0 = prev_model
    # Run all-reduce warm up.
    logging("All-reduce-warmup starts...")
    if num_workers > 1:
        start_time = time.time()
        tf_session.run([all_reduce_sync_op])
        logging("All-reduce-warmup --- {} seconds ---".format(time.time() - start_time))
    # Start training.
    logging("Training starts...")
    start_time = time.time()
    self.model_coefficients, f_min, info = fmin_l_bfgs_b(
        func=self._compute_loss_and_gradients,
        x0=x0,
        approx_grad=False,
        m=self.num_correction_pairs,  # number of variable-metric corrections; default is 10
        factr=self.factor,  # controls precision; smaller means more precise
        maxiter=self.max_iteration,
        args=(tf_session, train_x_placeholder, train_ops, task_index),
        disp=0)
    logging("Training --- {} seconds ---".format(time.time() - start_time))
    logging("\n------------------------------\nf_min: {}\nnum of funcalls: {}\ntask msg:"
            "{}\n------------------------------".format(f_min, info['funcalls'], info['task']))
    logging("Inference training data starts...")
    inference_training_data_ops = (train_sample_ids_op, train_labels_op, train_weights_op,
                                   train_prediction_score_op,
                                   train_prediction_score_per_coordinate_op)
    self._run_inference(self.model_coefficients, tf_session, inference_x_placeholder,
                        inference_training_data_ops, task_index, schema_params,
                        self.training_output_dir)
    logging("Inference validation data starts...")
    inference_validation_data_ops = (valid_sample_ids_op, valid_labels_op, valid_weights_op,
                                     valid_prediction_score_op,
                                     valid_prediction_score_per_coordinate_op)
    self._run_inference(self.model_coefficients, tf_session, inference_x_placeholder,
                        inference_validation_data_ops, task_index, schema_params,
                        self.validation_output_dir)
    # Final sync-up, then reliably terminate all workers.
    if num_workers > 1:
        tf_session.run([all_reduce_sync_op])
    snooze_after_tf_session_closure(tf_session, self.delayed_exit_in_seconds)
    if is_chief:
        self._save_model()
    # Remove the cached training input files.
    if self.copy_to_local:
        tf1.gfile.DeleteRecursively(self.local_training_input_dir)