Example #1
0
 def _get_evaluation_data(self):
     """Build the evaluation data by calling the Zoo function
     ``createMiniBatchRDDFromTFDatasetEval``.

     Every element of ``self.rdd`` is reduced to its first item (``x[0]``).
     That mapped RDD is passed along with this object's init op name, table
     init op, output names/types and shard index op name.

     :return: the result of ``.value().toJavaRDD()`` on the value returned
              by ``callZooFunc``.
     """
     # Keep only the first item of each element; presumably the serialized
     # payload of a (payload, ...) pair -- confirm against the RDD builder.
     first_items = self.rdd.map(lambda element: element[0])
     java_value = callZooFunc("float", "createMiniBatchRDDFromTFDatasetEval",
                              first_items,
                              self.init_op_name,
                              self.table_init_op,
                              self.output_names,
                              self.output_types,
                              self.shard_index_op_name)
     return java_value.value().toJavaRDD()
Example #2
0
 def _get_training_data(self):
     """Wrap the training data in a ``FeatureSet``.

     Calls the Zoo function ``createTFDataFeatureSet`` with the following,
     in order:

     - the first item of each element of ``self.rdd``;
     - the graph metadata held on this object (init op name, table init op,
       output names/types, shard index op name);
     - the configured ``inter_threads`` / ``intra_threads`` values.

     :return: a ``FeatureSet`` built from the returned JVM value.
     """
     first_items = self.rdd.map(lambda element: element[0])
     java_value = callZooFunc(
         "float", "createTFDataFeatureSet",
         first_items,
         self.init_op_name, self.table_init_op,
         self.output_names, self.output_types,
         self.shard_index_op_name,
         self.inter_threads, self.intra_threads)
     return FeatureSet(jvalue=java_value)
Example #3
0
 def _get_validation_data(self):
     """Wrap the validation data in a ``FeatureSet``, if there is any.

     :return: ``None`` when ``self.validation_dataset`` is ``None``.
              Otherwise, a ``FeatureSet`` built by calling the Zoo function
              ``createTFDataFeatureSet`` on the first item of each element
              of ``self.val_rdd``.
     """
     if self.validation_dataset is None:
         return None
     val_first_items = self.val_rdd.map(lambda element: element[0])
     # NOTE(review): inter_threads / intra_threads are not forwarded here,
     # although other createTFDataFeatureSet calls may pass them -- confirm
     # validation is meant to use the JVM-side defaults.
     java_value = callZooFunc("float", "createTFDataFeatureSet",
                              val_first_items,
                              self.init_op_name, self.table_init_op,
                              self.output_names, self.output_types,
                              self.shard_index_op_name)
     return FeatureSet(jvalue=java_value)
Example #4
0
    def _get_evaluation_data(self):
        """Build the evaluation data by calling the Zoo function
        ``createMiniBatchRDDFromTFDatasetEval``.

        The call also receives the number of leaves in the flattened
        ``self.tensor_structure[0]``.

        :return: the result of ``.value().toJavaRDD()`` on the value
                 returned by ``callZooFunc``.
        """
        # Count of leaf entries in the first component of the tensor
        # structure (presumably the features -- confirm against the caller).
        num_feature_leaves = len(nest.flatten(self.tensor_structure[0]))
        first_items = self.rdd.map(lambda element: element[0])
        java_value = callZooFunc("float", "createMiniBatchRDDFromTFDatasetEval",
                                 first_items,
                                 self.init_op_name,
                                 self.table_init_op,
                                 self.output_names,
                                 self.output_types,
                                 self.shard_index_op_name,
                                 num_feature_leaves)
        return java_value.value().toJavaRDD()
Example #5
0
 def _get_prediction_data(self):
     """Build the prediction data by calling the Zoo function
     ``createMiniBatchRDDFromTFDataset``.

     This is an internal invariant check (an ``assert``): prediction must
     not drop a trailing partial batch, so ``self.drop_remainder`` has to be
     falsy here.

     :return: the result of ``.value().toJavaRDD()`` on the value returned
              by ``callZooFunc``.
     """
     assert not self.drop_remainder, (
         "sanity check: drop_remainder should be false in this case,"
         " otherwise please report a bug")
     first_items = self.rdd.map(lambda element: element[0])
     java_value = callZooFunc(
         "float", "createMiniBatchRDDFromTFDataset",
         first_items,
         self.init_op_name, self.table_init_op,
         self.output_names, self.output_types,
         self.shard_index_op_name)
     return java_value.value().toJavaRDD()