Example #1
    def __init__(self, matches_uri, images_uri):

        # Ensure matches is a merged version with only one shard (otherwise we
        # run into memory problems).
        #shard_uris = py_pert.GetShardUris(matches_uri)
        #num_shards = len(shard_uris)
        #CHECK_EQ(num_shards, 1, 'expected merged matches pert with 1 shard but got %d shards' % num_shards)
        self.match_reader = py_pert.StringTableReader()
        CHECK(self.match_reader.Open(matches_uri))

        # open images table

        self.image_reader = py_pert.StringTableReader()
        CHECK(self.image_reader.Open(images_uri))

        return
Example #2
def test_proto_table():
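    # Round-trip test: write two serialized Person protos to a proto table,
    # pause for input, then read them back through the string-table interface.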
    filename = "local:///home/ubuntu/Desktop/test_proto_table"

    person = test_pb2.Person()
    person.first_name = 'foo'
    person.last_name = 'bar'

    writer = pert.ProtoTableWriter()
    writer.Open(person, filename, 10)
    writer.Add('key1', person.SerializeToString())
    writer.Add('key2', person.SerializeToString())
    writer.Close()

    print "press a key to continue",
    f = raw_input()

    reader = pert.StringTableReader()
    CHECK(reader.Open(filename))

    person = test_pb2.Person()
    for k, v in reader:
        person.ParseFromString(v)
        print "key %s person %s" % (k, person)

    return
Example #3
def LoadTide(tide_uri):
  objectid_to_object = {}
  LOG(INFO, 'starting to load tide dataset...')
  # load list of images that belong to each tide object    
  tide_reader = py_pert.StringTableReader()
  CHECK(tide_reader.Open(tide_uri))
  for k, v in tide_reader:
    tide_object = tide_pb2.Object()
    tide_object.ParseFromString(v)
    CHECK(tide_object.IsInitialized())
    objectid_to_object[tide_object.id] = tide_object
  return objectid_to_object
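
A minimal usage sketch for the loader above; the pert uri is a placeholder and the name printing is illustrative, not part of the original:

objectid_to_object = LoadTide('local:///path/to/objectid_to_object.pert')
for object_id, tide_object in objectid_to_object.iteritems():
  LOG(INFO, 'object %d: %s' % (object_id, tide_object.name))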
Example #4
    def test_basic(self):
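        # Build an index from the features pert, then run a k-nearest-neighbor
        # search for each image and score the returned neighbors.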
        index_uri = 'local://%s/index' % self.tmp_path
        cbirutil.CreateIndex(self.features_uri, index_uri)
        index = cbirutil.LoadIndex(index_uri)
        features_reader = py_pert.StringTableReader()
        CHECK(features_reader.Open(self.features_uri))

        feature_counts_table = cbirutil.LoadFeatureCountsTable(
            self.feature_counts_uri)

        k = 10
        for key, value in features_reader:
            image_id = iwutil.ParseUint64Key(key)
            query_features = iw_pb2.ImageFeatures()
            query_features.ParseFromString(value)
            print image_id
            print 'num_features: %d' % (len(query_features.descriptors))
            ok, neighbors = index.Search([image_id], k, query_features)
            CHECK(ok)
            CHECK_GT(len(neighbors.features), 0)
            scorer = cbirutil.CreateQueryScorer(feature_counts_table)

            ok, results = scorer.Run(neighbors)
            CHECK(ok)
            print results


#        ok, results = scorer.TestSmoothnessFilter(neighbors)
#        CHECK(ok)
#
#        fig = plt.figure()
#        for i, entry in enumerate(results.entries):
#          ax = fig.add_subplot(len(results.entries),1,i+1)
#          #ax.hist(entry.smoothness_scores, 10, cumulative = True)
#
#          hist, bins = np.histogram(entry.smoothness_scores, bins=10, range=(0,100), normed=False)
#          hist = np.cumsum(hist)
#          width = bins[1]-bins[0]
#          ax.bar(bins[:-1], hist, width=width)
#          ax.set_xlim(bins[0],bins[-1])
#          #ax.set_ylim(0, 100)
#          ax.set_title('query: %d candidate: %d' % (image_id, entry.image_id))
#
#
#      plt.show()

#ok, results = scorer.Run(cbirutil.GetCbirKeypoints(query_features), neighbors)
#CHECK(ok)
#print results

        return
Example #5
def main():
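    # Extract every positive image of each TIDE object and save it under
    # ./extracted/<object name>/<image id>.jpg.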

    base_uri = 'local:///home/ubuntu/Desktop/datasets/tide_v12'
    tide_uri = '%s/objectid_to_object.pert' % (base_uri)

    dataset = tide.TideDataset(tide_uri)

    print dataset

    pos_imageids = []

    imageid_to_objectname = {}

    for obj in dataset.objectid_to_object.itervalues():
        print obj.name
        pos_imageids.extend(obj.pos_image_ids)
        for image_id in obj.pos_image_ids:
            imageid_to_objectname[image_id] = obj.name

    # sort for efficient access to pert
    pos_imageids.sort()

    images_pert_uri = '%s/photoid_to_image.pert' % (base_uri)

    reader = py_pert.StringTableReader()
    CHECK(reader.Open(images_pert_uri))

    for image_id in pos_imageids:
        ok, data = reader.Find(py_base.Uint64ToKey(image_id))
        CHECK(ok)
        jpeg_image = iw_pb2.JpegImage()
        jpeg_image.ParseFromString(data)
        objectname = imageid_to_objectname[image_id]
        dirname = './extracted/%s' % (objectname)
        filename = '%s/%d.jpg' % (dirname, image_id)

        if not os.path.exists(dirname):
            os.makedirs(dirname)

        with open(filename, 'wb') as f:
            f.write(jpeg_image.data)

    return
Example #6
def main():
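    # Extract a hard-coded list of images by id and write each one out as
    # <image id>.jpg in the current directory.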

    images_pert_uri = 'local:///media/ebs/4a4b34/tide_v13/photoid_to_image.pert'
    images_to_extract = [
        2071492, 2087400, 2112291, 2102113, 2080088, 2083122, 2107730
    ]

    reader = py_pert.StringTableReader()
    CHECK(reader.Open(images_pert_uri))

    for image_id in images_to_extract:
        ok, data = reader.Find(py_strings.Uint64ToKey(image_id))
        CHECK(ok)
        jpeg_image = iw_pb2.JpegImage()
        jpeg_image.ParseFromString(data)
        filename = '%d.jpg' % (image_id)
        with open(filename, 'wb') as f:
            f.write(jpeg_image.data)

    return
Example #7
    def Run(self):
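        # Diagnostic: report process and module identity, then print basic
        # info about the pert at self.uri and dump its first few keys.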
        print 'pid: %s' % os.getpid()
        print 'id(py_pert): %s' % id(py_pert)
        ok, scheme, path = py_pert.ParseUri(self.uri)
        CHECK(ok)
        print 'path: %s' % path
        print 'exists: %s' % py_pert.Exists(self.uri)
        if py_pert.Exists(self.uri):
            print 'num shards: %s' % py_pert.GetNumShards(self.uri)
            reader = py_pert.StringTableReader()
            print 'about to open reader'
            CHECK(reader.Open(self.uri))
            print 'about to use reader'
            count = 0
            for k, v in reader:
                print k
                count += 1
                if count > 5:
                    break

        return True
Example #8
def test_string_table():
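    # Round-trip test: write two serialized Person protos to a string table,
    # then read them back and parse them.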
    filename = "local:///home/ubuntu/Desktop/test_string_table"

    person = test_pb2.Person()
    person.first_name = 'foo'
    person.last_name = 'bar'

    writer = pert.StringTableWriter()
    writer.Open(filename, 1)
    writer.Add('key1', person.SerializeToString())
    writer.Add('key2', person.SerializeToString())
    writer.Close()

    reader = pert.StringTableReader()
    CHECK(reader.Open(filename))

    for k, v in reader:
        my_person = test_pb2.Person()
        my_person.ParseFromString(v)
        print "key %s value %s" % (k, my_person)

    return
Example #9
 def test_basic(self):
     index_uri = 'local://%s/index' % self.tmp_path
     cbirutil.CreateIndex(self.features_uri, index_uri)
     index = cbirutil.LoadIndex(index_uri)
     features_reader = py_pert.StringTableReader()
     CHECK(features_reader.Open(self.features_uri))
     feature_counts_table = cbirutil.LoadFeatureCountsTable(
         self.feature_counts_uri)
     k = 10
     for key, value in features_reader:
         image_id = iwutil.ParseUint64Key(key)
         query_features = iw_pb2.ImageFeatures()
         query_features.ParseFromString(value)
         print image_id
         print 'num_features: %d' % (len(query_features.descriptors))
         ok, neighbors = index.Search([image_id], k, query_features)
         CHECK(ok)
         CHECK_GT(len(neighbors.features), 0)
         scorer = CreateQueryScorer(feature_counts_table)
         ok, results = scorer.Run(neighbors)
         CHECK(ok)
         print results
     return
Example #10
File: util.py Project: heathkh/iwct
 def Run(self):             
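   # Build a CBIR index (inria or ctis implementation) from the bag-of-words
   # pert, save it to a temp directory, then copy it to index_base_uri.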
   bow_uri = self.GetInput('bow').GetUri()
   reader = py_pert.StringTableReader()    
   CHECK(reader.Open(bow_uri))    
   visual_vocab_size = self.cbir_bow_params.visual_vocab_size
   num_docs = reader.Entries()
   index = None
   if self.cbir_bow_params.implementation == 'inria':
     index = py_inria.InriaIndex()
   elif self.cbir_bow_params.implementation == 'ctis':
     index = py_ctis.CtisIndex()
     index.StartCreate(visual_vocab_size, num_docs)
   else:
     LOG(FATAL, 'unexpected implementation: %s' % self.cbir_bow_params.implementation)
   
   #vv_uri = self.GetInput('visual_vocab').GetUri()
   temp_ivf_filepath = tempfile.mkdtemp()
       
   bag_of_words = bow_pb2.BagOfWords()
   progress = iwutil.MakeProgressBar(reader.Entries())
   for i, (key, value) in enumerate(reader):
     image_id = iwutil.KeyToUint64(key)
     bag_of_words.ParseFromString(value)
     index.Add(image_id, bag_of_words)
     progress.update(i)
   
   index.Save(temp_ivf_filepath)
   
   
   py_pert.Remove(self.index_base_uri)
   mr.CopyUri('local://' + temp_ivf_filepath, self.index_base_uri)
   CHECK(py_pert.Exists(self.index_base_uri + '/index.ivf'))
   CHECK(py_pert.Exists(self.index_base_uri + '/index.ivfids'))
   
   shutil.rmtree(temp_ivf_filepath, ignore_errors=True)
   return True     
Example #11
    def Run(self):
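        # Fold geometric match results into the iterative graph state: an edge
        # counts as successful if any of its matches has nfa < -20.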
        itergraph_state = LoadObjectFromUri(
            self.GetInput('prev_state').GetUri())
        if True:  # hack to put this block in its own scope to force release of memory resources
            reader = py_pert.StringTableReader()
            CHECK(reader.Open(self.GetInput('match_results').GetUri()))
            match_result = iw_pb2.GeometricMatchResult()
            num_entries = reader.Entries()
            if num_entries:
                pbar = iwutil.MakeProgressBar(num_entries)
                for i, (k, v) in enumerate(reader):
                    pbar.update(i)
                    match_result.ParseFromString(v)
                    success = any(match.nfa < -20 for match in match_result.matches)
                    if success:
                        itergraph_state.AddSuccesfulEdge(
                            match_result.image_a_id, match_result.image_b_id,
                            match_result.properties.score, self.phase)
                    else:
                        itergraph_state.AddFailedEdge(match_result.image_a_id,
                                                      match_result.image_b_id)
            print 'edges: %d' % (len(itergraph_state.edges))
        SaveObjectToUri(itergraph_state,
                        self.GetOutput('itergraph_state').GetUri())
        return


#def EvalImageGraph(itergraph_state_uri, tide_uri):
#  # compute edge stats
#  num_within_cluster, num_cross_cluster = CountCorrectIncorrectEdges(itergraph_state_uri, tide_uri)
#
#  # compute label prop performance
#  itergraph_state = LoadObjectFromUri(itergraph_state_uri)
#  eval_graph_uri = 'local://tmp/eval3_graph.pert'
#  itergraph_state.SaveAsEval3Graph(eval_graph_uri)
#
#  num_training_images = 100
#  num_trials = 4
#  evaluation = py_eval3.EvaluationRunner(eval_graph_uri, tide_uri, num_training_images, num_trials)
#  ok, result = evaluation.Run()
#  CHECK(ok)
#
#  return num_within_cluster, num_cross_cluster, result
#
#
#class EvalImageGraphFlow(core.Flow):
#  def __init__(self, base_uri, itergraph_state, tide):
#    super(EvalImageGraphFlow,self).__init__()
#    self.base_uri = base_uri
#    self.AddInput('itergraph_state', itergraph_state)
#    self.AddInput('tide', tide)
#
#    self.AddOutput('eval', core.FileResource(self, '%s/eval.txt' % (base_uri)))
#
#    return
#
#
#  def Run(self):
#    # compute edge stats
#    itergraph_state_uri = self.GetInput('itergraph_state').GetUri()
#    tide_uri = self.GetInput('tide').GetUri()
#
#    num_within_cluster, num_cross_cluster, result = EvalImageGraph(itergraph_state_uri, tide_uri)
#
#    lines = []
#    lines.append('num_within_cluster: %d' % (num_within_cluster))
#    lines.append('num_cross_cluster: %d' % (num_cross_cluster))
#    lines.append('fraction within cluster: %f' % (float(num_within_cluster)/ (num_within_cluster + num_cross_cluster)))
#    lines.append(str(result))
#    report = '\n'.join(lines)
#    print report
#
#    path = mr.UriToNfsPath(self.GetOutput('eval').GetUri())
#    f = open(path, 'w')
#    f.write(report)
#
#    return
#

#class BasicMatchBatchPlanningFlow(core.Flow):
#  def __init__(self, base_uri, candidates, max_batch_size, max_replication_factor, num_shards_features):
#    super(BasicMatchBatchPlanningFlow,self).__init__()
#
#    self.AddInput('candidates', candidates)
#    self.AddOutput('sorted_match_batches', core.PertResource(self, '%s/sorted_match_batches.pert' % (base_uri)))
#
#    self.max_batch_size = max_batch_size
#    self.max_replication_factor = max_replication_factor
#
#    self.num_shards_features = num_shards_features
#    self.match_groups = {}
#    self.num_replications = {}
#    return
#
#  def _GetNumReplications(self, image_id):
#    replications = 0
#    if image_id in self.num_replications:
#      replications = self.num_replications[image_id]
#    return replications
#
#  def _IncrementReplications(self, image_id):
#    if image_id not in self.num_replications:
#      self.num_replications[image_id] = 0
#    self.num_replications[image_id] += 1
#    return
#
#  def Run(self):
#    reader = py_pert.StringTableReader()
#    reader.Open(self.GetInput('candidates').GetUri())
#
#    self.match_groups = {}
#    num_selected_candidates = 0
#    prev_score = -1e6
#
#    widgets = [Percentage(), ' ', Bar(), ' ', ETA()]
#    pbar = ProgressBar(widgets=widgets, maxval=reader.Entries()).start()
#
#    num_edges_skipped_max_replication_constraint = 0
#    for i, (k, v) in enumerate(reader):
#      image_a_id, image_b_id = ParseUint64KeyPair(v)
#
#      # check precondition... pert is sorted by scores
#      score = iwutil.KeyToDouble(k)
#      CHECK_GE(score, prev_score)
#      prev_score = score
#
#      if image_a_id not in self.match_groups:
#        self.match_groups[image_a_id] = []
#
#      match_group_size = len(self.match_groups[image_a_id])
#
#      if match_group_size < self.max_batch_size:
#        # test max replication condition
#        num_replications = self._GetNumReplications(image_b_id)
#        if num_replications < self.max_replication_factor:
#          self._IncrementReplications(image_b_id)
#          self.match_groups[image_a_id].append(image_b_id)
#          num_selected_candidates += 1
#          pbar.update(num_selected_candidates)
#        else:
#          num_edges_skipped_max_replication_constraint += 1
#
#
#    print 'num_edges_skipped_max_replication_constraint: %d' % (num_edges_skipped_max_replication_constraint)
#
#    # write out the match plan (must be sorted by key for future join stage)
#    metadata_entries = []
#
#    for batch_id, (batch_primary_image, batch_image_ids) in enumerate(self.match_groups.iteritems()):
#
#      if not batch_image_ids:
#        continue
#
#      batch_name = iwutil.Uint64ToKey(batch_id)
#      match_batch_metadata = iw_pb2.MatchBatchMetadata()
#      match_batch_metadata.image_id = batch_primary_image
#      match_batch_metadata.batch_name = batch_name
#      match_batch_metadata.is_primary = True
#      metadata_entries.append( match_batch_metadata )
#
#      for image_id in batch_image_ids:
#        next_metadata = iw_pb2.MatchBatchMetadata()
#        next_metadata.image_id = image_id
#        next_metadata.batch_name = batch_name
#        next_metadata.is_primary = False
#        metadata_entries.append( next_metadata )
#
#    # image_id will be the key of output (since we are about to join by image_id), so need to sort by image_id
#    metadata_entries.sort(key= lambda m : iwutil.Uint64ToKey(m.image_id))
#    match_plan_writer = py_pert.ProtoTableWriter()
#    uri = self.GetOutput('sorted_match_batches').GetUri()
#    CHECK(match_plan_writer.Open(iw_pb2.MatchBatchMetadata(), uri, self.num_shards_features), 'failed to open %s' % (uri))  # to do join with features, must be sharded same way as features
#    for metadata in metadata_entries:
#      match_plan_writer.Add(iwutil.Uint64ToKey(metadata.image_id), metadata.SerializeToString())
#    match_plan_writer.Close()
#
#    return

#def EnsureParentPathExists(f):
#  d = os.path.dirname(f)
#  if not os.path.exists(d):
#    os.makedirs(d)
#  return

#def ParseUint64KeyPair(key):
#  CHECK_EQ(len(key), 16)
#  id_a = iwutil.KeyToUint64(key[0:8])
#  id_b = iwutil.KeyToUint64(key[8:16])
#  return id_a, id_b

#class TideObject(object):
#  def __init__(self):
#    self.id = None
#    self.name = None
#    self.image_ids = []
#    return
#
#  def LoadFromProto(self, id, tide_object_proto):
#    self.id = id
#    self.name = tide_object_proto.name
#    for photo in tide_object_proto.photos:
#      self.image_ids.append(photo.id)
#    return
#
#
#class TideDataset():
#  def __init__(self, tide_uri):
#    self.tideid_to_tideobject = {}
#    self.imageid_to_objectid = {}
#
#    # load list of images that belong to each tide object
#    tide_reader = py_pert.StringTableReader()
#    tide_reader.Open(tide_uri)
#    for index, (k, v) in enumerate(tide_reader):
#      tide_object = tide_pb2.Object()
#      tide_object.ParseFromString(v)
#      obj = TideObject()
#      obj.LoadFromProto(index, tide_object)
#      self.tideid_to_tideobject[obj.id] = obj
#
#    for tideid, tideobject in self.tideid_to_tideobject.iteritems():
#      object_id = tideobject.id
#      for image_id in tideobject.image_ids:
#        self.imageid_to_objectid[image_id] = object_id
#    return
#
#  def KnownImages(self, image_a_id, image_b_id):
#    return image_a_id in self.imageid_to_objectid and image_b_id in self.imageid_to_objectid
#
#  def EdgeWithinCluster(self, image_a_id, image_b_id):
#    CHECK(self.KnownImages(image_a_id, image_b_id))
#    object_a = self.imageid_to_objectid[image_a_id]
#    object_b = self.imageid_to_objectid[image_b_id]
#    return object_a == object_b
#

#
#def CountCorrectIncorrectEdges(itergraph_state_uri, tide_uri):
#  itergraph_state = LoadObjectFromUri(itergraph_state_uri)
#  tide = TideDataset(tide_uri)
#
#  num_within_cluster = 0
#  num_cross_cluster = 0
#
#  object_to_num_within_cluster = {}
#  for tide_id in tide.tideid_to_tideobject.iterkeys():
#    object_to_num_within_cluster[tide_id] = 0
#
#
#  for edge in itergraph_state.edges:
#    if tide.EdgeWithinCluster(edge.image_a_id, edge.image_b_id):
#      num_within_cluster += 1
#      object_to_num_within_cluster[tide.imageid_to_objectid[edge.image_a_id]] += 1
#    else:
#      num_cross_cluster += 1
#
#  for tide_id, num in object_to_num_within_cluster.iteritems():
#    print '%s: %d' % (tide.tideid_to_tideobject[tide_id].name, num)
#
#
#  return num_within_cluster, num_cross_cluster

#def GetPhaseBaseUri(base_uri,  phase):
#    return base_uri + '/phase%03d/' % (phase)
Example #12
    def Run(self):
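        # Select candidate edges subject to batch-size, vertex-degree, and
        # replication constraints, then write the match plan resharded to
        # match the features table.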
        LOG(
            INFO,
            'waiting to let running processes give up memory... I need a lot and may not get enough if we rush things...'
        )
        time.sleep(30)
        itergraph_state = LoadObjectFromUri(
            self.GetInput('prev_state').GetUri())
        reader = py_pert.StringTableReader()
        CHECK(reader.Open(self.GetInput('candidates').GetUri()))
        self.match_groups = {}
        num_selected_candidates = 0

        pbar = iwutil.MakeProgressBar(self.max_candidates_per_phase)
        num_edges_skipped_max_degree_constraint = 0
        num_edges_skipped_max_replication_constraint = 0
        prev_score = -float('inf')

        for ordering_key, candidate_pair_data in reader:
            image_a_id, image_b_id = iwutil.ParseUint64KeyPair(
                candidate_pair_data)
            if itergraph_state.PreviouslyAttempted(image_a_id, image_b_id):
                #print 'skipping previous attempted edge'
                continue
            # check precondition... candidates pert is sorted (increasing by rank or by negative cbir score)
            score = iwutil.KeyToDouble(ordering_key)
            CHECK_GE(score, prev_score)
            prev_score = score

            if image_a_id not in self.match_groups:
                self.match_groups[image_a_id] = []

            match_group_size = len(self.match_groups[image_a_id])

            if match_group_size < self.max_batch_size:
                # test vertex degree condition
                degree_a = itergraph_state.GetDegree(image_a_id)
                degree_b = itergraph_state.GetDegree(image_b_id)

                # version 1: skip candidate edge if either of the vertices has many edges
                #if degree_a < self.max_vertex_degree and degree_b < self.max_vertex_degree:

                # version 2: skip candidate edge only if both of the vertices have many edges
                if degree_a < self.max_vertex_degree or degree_b < self.max_vertex_degree:
                    # test max replication condition
                    num_replications = self._GetNumReplications(image_b_id)
                    if num_replications < self.max_replication_factor:
                        self._IncrementReplications(image_b_id)
                        self.match_groups[image_a_id].append(image_b_id)
                        num_selected_candidates += 1
                        pbar.update(num_selected_candidates)
                    else:
                        num_edges_skipped_max_replication_constraint += 1
                else:
                    num_edges_skipped_max_degree_constraint += 1

            if num_selected_candidates >= self.max_candidates_per_phase:
                break

        pbar.finish()

        print ''
        print ''
        print 'num_edges_skipped_max_replication_constraint: %d' % (
            num_edges_skipped_max_replication_constraint)
        print 'num_edges_skipped_max_degree_constraint: %d' % (
            num_edges_skipped_max_degree_constraint)
        print ''
        print ''

        # write out the match plan (must be sorted by key for future join stage)
        metadata_entries = []

        for batch_id, (batch_primary_image, batch_image_ids) in enumerate(
                self.match_groups.iteritems()):
            if len(batch_image_ids) == 0:
                continue
            batch_name = iwutil.Uint64ToKey(batch_id)
            CHECK(batch_name)
            CHECK(len(batch_name))
            match_batch_metadata = iw_pb2.MatchBatchMetadata()
            match_batch_metadata.image_id = batch_primary_image
            match_batch_metadata.batch_name = batch_name
            match_batch_metadata.is_primary = True
            metadata_entries.append(match_batch_metadata)

            for image_id in batch_image_ids:
                next_metadata = iw_pb2.MatchBatchMetadata()
                next_metadata.image_id = image_id
                next_metadata.batch_name = batch_name
                next_metadata.is_primary = False
                metadata_entries.append(next_metadata)

        # image_id will be the key of output, so need to sort by image_id
        metadata_entries.sort(key=lambda m: m.image_id)
        match_batches_uri = self.GetOutput('sorted_match_batches').GetUri()

        # TODO(heathkh): "closing" doesn't flush to disk... this is a bug!
        #    match_plan_writer = py_pert.ProtoTableWriter()
        #    num_shards_features = py_pert.GetNumShards(self.features.GetUri())
        #    CHECK(match_plan_writer.Open(iw_pb2.MatchBatchMetadata(), match_batches_uri, num_shards_features))
        #    for metadata in metadata_entries:
        #      CHECK(metadata.IsInitialized())
        #      key = iwutil.Uint64ToKey(metadata.image_id)
        #      CHECK(match_plan_writer.Add(key, metadata.SerializeToString()))
        #    match_plan_writer.Close()

        # TODO(kheath):   Work around for above bug is to run a MR stage to reshard
        tmp_match_batches_uri = self.GetOutput(
            'sorted_match_batches').GetUri() + '_to_be_sharded'
        match_plan_writer = py_pert.ProtoTableWriter()
        num_shards_features = py_pert.GetNumShards(self.features.GetUri())
        CHECK(
            match_plan_writer.Open(iw_pb2.MatchBatchMetadata(),
                                   tmp_match_batches_uri, 1))

        for metadata in metadata_entries:
            CHECK(metadata.IsInitialized())
            CHECK(
                match_plan_writer.Add(iwutil.Uint64ToKey(metadata.image_id),
                                      metadata.SerializeToString()))
        match_plan_writer.Close()

        # manually reshard
        pertedit_bin = 'pertedit'
        cmd = '%s --input %s --output %s --new_block_size_mb=10 --num_output_shards=%d' % (
            pertedit_bin, tmp_match_batches_uri, match_batches_uri,
            num_shards_features)
        print cmd
        CHECK_EQ(ExecuteCmd(cmd), 0)

        CHECK(py_pert.Exists(match_batches_uri))

        ok, fp = py_pert.GetShardSetFingerprint(match_batches_uri)
        CHECK(ok)
        CHECK_EQ(len(fp), 32)
        CHECK_NE(fp, 'd41d8cd98f00b204e9800998ecf8427e',
                 'invalid hash of empty string')

        return
Example #13
def CheckInputPreconditions(input_images_uri):
    reader = py_pert.StringTableReader()
    CHECK(reader.Open(input_images_uri))
    CHECK(reader.IsSorted(), 'input images must be sorted')
    # TODO(heathkh): also need to check that there are no repeated keys; doing this locally is slow since we must read all the data (see the sketch below).
    return
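
A minimal sketch of the repeated-key check mentioned in the TODO, assuming only that iterating the reader yields entries in key order (guaranteed here by the IsSorted() check above); the helper name is hypothetical:

def CheckNoRepeatedKeys(input_images_uri):
    # Slow local check: scan every entry and compare adjacent keys.
    reader = py_pert.StringTableReader()
    CHECK(reader.Open(input_images_uri))
    prev_key = None
    for key, _ in reader:
        CHECK(key != prev_key, 'repeated key: %s' % key)
        prev_key = key
    return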