Example #1
 def _test_input_fn(self, data_dir):
     """
     Test that per_record_input_fn yields the expected batches from a dataset directory.
     :param data_dir: directory containing the input dataset.
     :return: None
     """
     batch_size = self.batch_size
     d = per_record_input_fn(data_dir,
                             self.test_metadata_file,
                             self.num_shards,
                             self.shard_index,
                             batch_size,
                             constants.TFRECORD)
     d_iter = tf.compat.v1.data.make_one_shot_iterator(d)
     item = d_iter.get_next()
     i = 0
     sparse_tensors = self.generate_sparse_tensors(weight_indices,
                                                   weight_values, 5, batch_size, self.shard_index)
     with self.session() as sess:
         sparse_tensors_val = sess.run(sparse_tensors)
         try:
             while True:
                 features, response = sess.run(item)
                 self.assertAllEqual(features['weight'].values, sparse_tensors_val[i].values)
                 self.assertAllEqual(features['weight'].indices, sparse_tensors_val[i].indices)
                 self.assertAllEqual(features['f1'], f1[i*batch_size:(i+1)*batch_size]
                                     + self.shard_index)
                 self.assertAllEqual(response['response'], labels[i*batch_size:(i+1)*batch_size]
                                     + self.shard_index)
                 i += 1
         except tf.errors.OutOfRangeError:
             pass
     self.assertEqual(i, num_records // self.batch_size)
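
The assertions above depend on module-level fixtures (weight_indices, weight_values, f1, labels, num_records) defined elsewhere in the test module. A hypothetical sketch of such fixtures, only to illustrate the shapes the assertions expect; the real values come from the test's own data-generation utilities:

import numpy as np

num_records = 10                                    # total records in the toy dataset
f1 = np.arange(num_records, dtype=np.float64)       # dense feature column
labels = np.random.randint(0, 2, size=num_records)  # binary responses
# Sparse 'weight' feature: per-record indices and values.
weight_indices = [np.array([0, 2, 4]) for _ in range(num_records)]
weight_values = [np.array([1.0, 0.5, 0.25]) for _ in range(num_records)]
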
Example #2
 def _get_num_iterations(self, input_files, metadata_file):
     """ Get the number of samples each worker assigned.
         This works for tfrecord only.
     :param input_files: a list of TFRecord files.
     :param metadata_file: the metadata associated with the TFRecord files.
     :return: number of iterations
     """
     start_time = time.time()
     assert self.data_format == constants.TFRECORD
     # Reset the default graph; this must be done before the main graph is built.
     tf1.reset_default_graph()
     num_iterations = 0
     dataset = per_record_input_fn(input_files,
                                   metadata_file,
                                   1,
                                   0,
                                   self.batch_size,
                                   self.data_format,
                                   build_features=False)
     data_iterator = tf1.data.make_initializable_iterator(dataset)
     next_item = data_iterator.get_next()
     with tf1.device('device:CPU:0'), tf1.Session() as sess:
         sess.run(data_iterator.initializer)
         while True:
             try:
                 sess.run(next_item)
                 num_iterations += 1
             except tf.errors.OutOfRangeError:
                 break
     end_time = time.time()
     logging(
         f'It took {end_time - start_time} seconds to count {num_iterations} batches '
         f'with batch size {self.batch_size}.')
     return num_iterations
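
Note that this helper counts batches by exhausting the assigned shard once in a throwaway TF1 session, so it costs one extra pass over the data. For comparison, a minimal eager-mode sketch of the same count, assuming an arbitrary batched tf.data.Dataset rather than the project's per_record_input_fn:

import tensorflow as tf

def count_batches(dataset):
    # Reduce over the dataset once, adding 1 per batch (eager mode).
    return int(dataset.reduce(tf.constant(0, dtype=tf.int64),
                              lambda count, _: count + 1).numpy())
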
    def predict(self,
                output_dir,
                input_data_path,
                metadata_file,
                checkpoint_path,
                execution_context,
                schema_params):
        # Override the predict method from the parent class.
        logging("Kicking off fixed effect LR predict")

        task_index = execution_context[constants.TASK_INDEX]
        num_workers = execution_context[constants.NUM_WORKERS]
        # Prediction uses a local server.
        self.server = tf1.train.Server.create_local_server()

        # Define the graph here and keep the session open so the scipy L-BFGS solver can repeatedly call _compute_loss_and_gradients.
        # Inference is conducted in local mode.
        with tf1.variable_scope('worker{}'.format(task_index)), tf1.device('device:CPU:0'):
            dataset = per_record_input_fn(input_data_path,
                                          metadata_file,
                                          num_workers,
                                          task_index,
                                          self.batch_size,
                                          self.data_format)
            x_placeholder = tf1.placeholder(tf1.float64, shape=[None])

            data_diter = tf1.data.make_one_shot_iterator(dataset)
            assigned_files = self._get_assigned_files(input_data_path, num_workers, task_index)
            data_num_iterations = self._get_num_iterations(assigned_files, metadata_file)
            sample_ids_op, labels_op, weights_op, scores_op, scores_and_offsets_op = self._inference_model_fn(
                data_diter,
                x_placeholder,
                data_num_iterations,
                schema_params)
            init_variables_op = tf1.global_variables_initializer()

        session_creator = tf1.train.ChiefSessionCreator(master=self.server.target)
        tf_session = tf1.train.MonitoredSession(session_creator=session_creator)
        tf_session.run(init_variables_op)

        predict_ops = (sample_ids_op, labels_op, weights_op, scores_op, scores_and_offsets_op)
        model_coefficients = self._load_model()
        self._run_inference(model_coefficients,
                            tf_session,
                            x_placeholder,
                            predict_ops,
                            task_index,
                            schema_params,
                            output_dir)
        logging("Snooze before closing the session")
        snooze_after_tf_session_closure(tf_session, self.delayed_exit_in_seconds)
        logging("Closed the session")
    def train(self, training_data_dir, validation_data_dir, metadata_file, checkpoint_path,
              execution_context, schema_params):
        """ Overwrite train method from parent class. """
        logging("Kicking off fixed effect LR LBFGS training")

        task_index = execution_context[constants.TASK_INDEX]
        num_workers = execution_context[constants.NUM_WORKERS]
        is_chief = execution_context[constants.IS_CHIEF]
        self._create_server(execution_context)

        assigned_train_files = self._get_assigned_files(training_data_dir, num_workers, task_index)
        if self.copy_to_local:
            train_input_dir = self.local_training_input_dir
            actual_train_files = copy_files(assigned_train_files, train_input_dir)
            # After copying the worker's shard to local storage, we no longer shard the local files.
            train_num_shards = 1
            train_shard_index = 0
        else:
            train_input_dir = self.training_data_dir
            actual_train_files = assigned_train_files
            train_num_shards = num_workers
            train_shard_index = task_index

        # Define the graph here and keep the session open so the scipy L-BFGS solver can repeatedly call _compute_loss_and_gradients.
        with tf1.variable_scope('worker{}'.format(task_index)), \
                tf1.device('job:worker/task:{}/device:CPU:0'.format(task_index)):

            # Define ops for training
            train_dataset = per_record_input_fn(train_input_dir,
                                                metadata_file,
                                                train_num_shards,
                                                train_shard_index,
                                                self.batch_size,
                                                self.data_format)
            train_diter = tf1.data.make_initializable_iterator(train_dataset)
            init_train_dataset_op = train_diter.initializer
            train_x_placeholder = tf1.placeholder(tf1.float64, shape=[None])
            train_num_iterations = self._get_num_iterations(actual_train_files, metadata_file)
            value_op, gradients_op = self._train_model_fn(train_diter,
                                                          train_x_placeholder,
                                                          num_workers,
                                                          self.num_features,
                                                          self.global_num_samples,
                                                          train_num_iterations,
                                                          schema_params)
            train_ops = (init_train_dataset_op, value_op, gradients_op)

            # Define ops for inference
            valid_dataset = per_record_input_fn(validation_data_dir,
                                                metadata_file,
                                                num_workers,
                                                task_index,
                                                self.batch_size,
                                                self.data_format)
            inference_x_placeholder = tf1.placeholder(tf1.float64, shape=[None])

            inference_train_data_diter = tf1.data.make_one_shot_iterator(train_dataset)
            train_sample_ids_op, train_labels_op, train_weights_op, train_prediction_score_op, \
                train_prediction_score_per_coordinate_op = self._inference_model_fn(
                    inference_train_data_diter,
                    inference_x_placeholder,
                    train_num_iterations,
                    schema_params)

            inference_validation_data_diter = tf1.data.make_one_shot_iterator(valid_dataset)
            assigned_validation_files = self._get_assigned_files(validation_data_dir, num_workers, task_index)
            validation_data_num_iterations = self._get_num_iterations(assigned_validation_files, metadata_file)
            valid_sample_ids_op, valid_labels_op, valid_weights_op, valid_prediction_score_op, \
                valid_prediction_score_per_coordinate_op = self._inference_model_fn(
                    inference_validation_data_diter,
                    inference_x_placeholder,
                    validation_data_num_iterations,
                    schema_params)

            if num_workers > 1:
                all_reduce_sync_op = collective_ops.all_reduce(
                    tf1.constant(0.0, tf1.float64),
                    num_workers,
                    FixedEffectLRModelLBFGS.TF_ALL_REDUCE_GROUP_KEY,
                    0,
                    merge_op='Add',
                    final_op='Id')

            init_variables_op = tf1.global_variables_initializer()

        session_creator = tf1.train.ChiefSessionCreator(master=self.server.target)
        tf_session = tf1.train.MonitoredSession(session_creator=session_creator)
        tf_session.run(init_variables_op)

        # load existing model if available
        logging("Try to load initial model coefficients...")
        prev_model = self._load_model(catch_exception=True)
        if prev_model is None or len(prev_model) != self.num_features + 1:
            logging("No initial model found, use all zeros instead.")
            x0 = np.zeros(self.num_features + 1)
        else:
            logging("Found a previous model,  loaded as the initial point for training")
            x0 = prev_model

        # Run all reduce warm up
        logging("All-reduce-warmup starts...")
        if num_workers > 1:
            start_time = time.time()
            tf_session.run([all_reduce_sync_op])
            logging("All-reduce-warmup --- {} seconds ---".format(time.time() - start_time))

        # Start training
        logging("Training starts...")
        start_time = time.time()
        self.model_coefficients, f_min, info = fmin_l_bfgs_b(
            func=self._compute_loss_and_gradients,
            x0=x0,
            approx_grad=False,
            m=self.num_correction_pairs,  # number of variable metric corrections; default is 10.
            factr=self.factor,            # controls precision; smaller is better.
            maxiter=self.max_iteration,
            args=(tf_session, train_x_placeholder, train_ops, task_index),
            disp=0)
        logging("Training --- {} seconds ---".format(time.time() - start_time))
        logging("\n------------------------------\nf_min: {}\nnum of funcalls: {}\ntask msg:"
                "{}\n------------------------------".format(f_min, info['funcalls'], info['task']))

        logging("Inference training data starts...")
        inference_training_data_ops = (train_sample_ids_op, train_labels_op, train_weights_op,
                                       train_prediction_score_op, train_prediction_score_per_coordinate_op)
        self._run_inference(self.model_coefficients,
                            tf_session,
                            inference_x_placeholder,
                            inference_training_data_ops,
                            task_index,
                            schema_params,
                            self.training_output_dir)

        logging("Inference validation data starts...")
        inference_validation_data_ops = (valid_sample_ids_op, valid_labels_op, valid_weights_op,
                                         valid_prediction_score_op, valid_prediction_score_per_coordinate_op)
        self._run_inference(self.model_coefficients,
                            tf_session,
                            inference_x_placeholder,
                            inference_validation_data_ops,
                            task_index,
                            schema_params,
                            self.validation_output_dir)

        # Final sync up and then reliably terminate all workers
        if num_workers > 1:
            tf_session.run([all_reduce_sync_op])

        snooze_after_tf_session_closure(tf_session, self.delayed_exit_in_seconds)

        if is_chief:
            self._save_model()

        # remove the cached training input files
        if self.copy_to_local:
            tf1.gfile.DeleteRecursively(self.local_training_input_dir)
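
The training loop above delegates optimization to SciPy's L-BFGS-B solver. A minimal sketch of the callback contract that fmin_l_bfgs_b expects when approx_grad=False, with a toy quadratic loss standing in for the distributed _compute_loss_and_gradients:

import numpy as np
from scipy.optimize import fmin_l_bfgs_b

def toy_loss_and_gradients(x, *args):
    # With approx_grad=False the callback must return (loss, gradient);
    # the extra positional arguments mirror the args tuple passed to fmin_l_bfgs_b.
    loss = 0.5 * float(np.dot(x, x))
    grad = x.copy()
    return loss, grad

x_opt, f_min, info = fmin_l_bfgs_b(func=toy_loss_and_gradients,
                                   x0=np.zeros(5),
                                   approx_grad=False,
                                   m=10,        # number of correction pairs
                                   factr=1e7,   # precision control; smaller is tighter
                                   maxiter=100,
                                   disp=0)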