def _dump_joined_items(self, matching_list):
        start_tm = time.time()
        prev_leader_idx = self._leader_restart_index + 1
        for item in matching_list:
            (fe, li, fi) = item
            if self._enable_negative_example_generator and li > prev_leader_idx:
                for example in \
                    self._negative_example_generator.generate(
                        fe, prev_leader_idx, li):

                    builder = self._get_data_block_builder(True)
                    assert builder is not None, "data block builder must "\
                                                "not be None before dumping"
                    builder.append_item(example[0], example[1], example[2],
                                        None, True)
                    if builder.check_data_block_full():
                        yield self._finish_data_block()
            prev_leader_idx = li + 1

            builder = self._get_data_block_builder(True)
            assert builder is not None, "data block builder must "\
                                        "not be None before dumping"
            builder.append_item(fe, li, fi, None, True)
            if builder.check_data_block_full():
                yield self._finish_data_block()
        metrics.emit_timer(name='attribution_joiner_dump_joined_items',
                           value=int(time.time() - start_tm),
                           tags=self._metrics_tags)
    def _dump_joined_items(self, indexed_pairs):
        start_tm = time.time()
        for ip in indexed_pairs:
            if self._enable_negative_example_generator:
                for example in \
                    self._negative_example_generator.generate(ip.fe, ip.li):
                    builder = self._get_data_block_builder(True)
                    assert builder is not None, "data block builder must "\
                                                "not be None before dumping"
                    # example:  (li, fi, item)
                    builder.append_item(example[0], example[1],
                                        example[2], None, True, 0)
                    if builder.check_data_block_full():
                        yield self._finish_data_block()

            builder = self._get_data_block_builder(True)
            assert builder is not None, "data block builder must "\
                                        "not be None before dumping"
            builder.append_item(ip.fe, ip.li, ip.fi, None, True,
                                joined=1)
            if builder.check_data_block_full():
                yield self._finish_data_block()
        metrics.emit_timer(name='universal_joiner_dump_joined_items',
                           value=int(time.time()-start_tm),
                           tags=self._metrics_tags)
 def _emit_dumper_metrics(self, file_index, dumped_index):
     dump_duration = time.time() - self._latest_dump_timestamp
     metrics.emit_timer(name='example_id_dump_duration',
                        value=int(dump_duration),
                        tags=self._metrics_tags)
     metrics.emit_store(name='example_dump_file_index',
                        value=file_index,
                        tags=self._metrics_tags)
     metrics.emit_store(name='example_id_dumped_index',
                        value=dumped_index,
                        tags=self._metrics_tags)
Example #4
 def _dump_data_blocks(self):
     while self.need_dump():
         meta = self._get_next_data_block_meta()
         if meta is not None:
             start_tm = time.time()
             self._raw_data_visitor.active_visitor()
             self._dump_data_block_by_meta(meta)
             dump_duration = time.time() - start_tm
             metrics.emit_timer(name='data_block_dump_duration',
                                value=int(dump_duration),
                                tags=self._metrics_tags)
Example #5
 def _evit_stale_follower_cache(self):
     start_tm = time.time()
     reserved_items = self._evict_impl(self._follower_join_window,
                                       self._evict_if_useless)
     if len(reserved_items) < self._max_window_size:
         self._follower_join_window.reset(reserved_items, False)
         return
     reserved_items = self._evict_impl(reserved_items, self._evict_if_force)
     self._follower_join_window.reset(reserved_items, False)
     metrics.emit_timer(name='stream_joiner_evit_stale_follower_cache',
                        value=int(time.time() - start_tm),
                        tags=self._metrics_tags)
Example #6
 def _update_join_cache(self):
     start_tm = time.time()
     new_unjoined_example_ids = []
     for example_id in self._leader_unjoined_example_ids:
         if example_id in self._follower_example_cache:
             self._joined_cache[example_id] = \
                     self._follower_example_cache[example_id]
         else:
             new_unjoined_example_ids.append(example_id)
     self._leader_unjoined_example_ids = new_unjoined_example_ids
     metrics.emit_timer(name='stream_joiner_update_join_cache',
                        value=int(time.time() - start_tm),
                        tags=self._metrics_tags)
Example #7
 def receive(self, iter_id, name):
     logging.debug('Data: Waiting to receive %s for iter %d.', name,
                   iter_id)
     start_time = time.time()
     with self._condition:
         while (iter_id not in self._received_data) \
                 or (name not in self._received_data[iter_id]):
             self._condition.wait()
         data = self._received_data[iter_id][name]
     duration = time.time() - start_time
     metrics.emit_timer('receive_timer', duration)
     logging.debug('Data: received %s for iter %d after %f sec.', name,
                   iter_id, duration)
     return tf.make_ndarray(data.tensor)
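The receive example above uses metrics.emit_timer in its positional form, passing the metric name and a raw float duration without tags, in contrast to the keyword-argument style of the joiner examples. A minimal sketch of timing a blocking wait in that style follows; the name timed_wait is a placeholder and the metrics module is again passed in as an argument.

import time


def timed_wait(wait_fn, metrics):
    # Hypothetical sketch: time a blocking call and emit the raw float duration.
    start_time = time.time()
    wait_fn()
    # Positional name/value call, no int() truncation and no tags,
    # matching the receive() example above.
    metrics.emit_timer('receive_timer', time.time() - start_time)
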
Example #8
 def _fill_follower_join_window(self, raw_data_finished):
     start_tm = time.time()
     start_pos = self._follower_join_window.size()
     follower_enough = self._fill_join_windows(self._follower_visitor,
                                               self._follower_join_window,
                                               self._follower_example_cache)
     end_pos = self._follower_join_window.size()
     eids = [(self._follower_join_window[idx][0],
              self._follower_join_window[idx][1].example_id)
             for idx in range(start_pos, end_pos)]
     self._joiner_stats.fill_follower_example_ids(eids)
     metrics.emit_timer(name='stream_joiner_fill_follower_join_window',
                        value=int(time.time() - start_tm),
                        tags=self._metrics_tags)
     return follower_enough or raw_data_finished
Example #9
 def _evit_stale_follower_cache(self):
     start_tm = time.time()
     tmp_sz = self._follower_join_window.size()
     reserved_items = self._evict_impl(self._follower_join_window,
                                       self._evict_if_useless)
     logging.debug("evict_if_useless %d to %d", tmp_sz, len(reserved_items))
     if len(reserved_items) < self._max_window_size:
         self._follower_join_window.reset(reserved_items, False)
         return
     tmp_sz = len(reserved_items)
     reserved_items = self._evict_impl(reserved_items, self._evict_if_force)
     logging.debug("evict_if_force %d to %d", tmp_sz, len(reserved_items))
     self._follower_join_window.reset(reserved_items, False)
     metrics.emit_timer(name='stream_joiner_evit_stale_follower_cache',
                        value=int(time.time() - start_tm),
                        tags=self._metrics_tags)
    def _fill_follower_join_window(self, raw_data_finished):
        start_tm = time.time()
        idx = self._follower_join_window.size()
        filled_enough = self._fill_join_windows(self._follower_visitor,
                                                self._follower_join_window)
        eids = []
        while idx < self._follower_join_window.size():
            eids.append((self._follower_join_window[idx].index,
                         self._follower_join_window[idx].item.example_id))
            idx += 1

        self._joiner_stats.fill_follower_example_ids(eids)
        metrics.emit_timer(name=\
                           'universal_joiner_fill_follower_join_window',
                           value=int(time.time()-start_tm),
                           tags=self._metrics_tags)
        return filled_enough or raw_data_finished
Example #11
    def _fill_follower_join_window(self):
        start_tm = time.time()
        idx = self._follower_join_window.size()
        filled_new_example = self._fill_join_windows(
            self._follower_visitor, self._follower_join_window)
        eids = []
        while idx < self._follower_join_window.size():
            eids.append((self._follower_join_window[idx][0],
                         self._follower_join_window[idx][1].example_id))
            idx += 1

        self._joiner_stats.fill_follower_example_ids(eids)
        metrics.emit_timer(name=\
                           'attribution_joiner_fill_follower_join_window',
                           value=int(time.time()-start_tm),
                           tags=self._metrics_tags)
        return filled_new_example
Example #12
 def make_processor(self, next_index):
     input_finished = False
     with self._lock:
         if next_index is None:
             return
         if self._check_index_rollback(next_index):
             self._batch_queue = []
             self._flying_item_count = 0
         if len(self._batch_queue) > 0:
             end_batch = self._batch_queue[-1]
             next_index = end_batch.begin_index + len(end_batch)
         input_finished = self._input_finished
     assert next_index >= 0, "the next index should be >= 0"
     end_batch = None
     batch_finished = False
     iter_round = 0
     processed_index = None
     start_tm = time.time()
     for batch, batch_finished in self._make_inner_generator(next_index):
         if batch is not None:
             if len(batch) > 0:
                 latency_mn = '{}.produce.latency'.format(self.name())
                 metrics.emit_timer(name=latency_mn,
                                    value=time.time() - start_tm,
                                    tags=self._get_metrics_tags())
                 store_mn = '{}.produce.index'.format(self.name())
                 metrics.emit_store(name=store_mn,
                                     value=batch.begin_index + len(batch) - 1,
                                    tags=self._get_metrics_tags())
                 self._append_next_item_batch(batch)
                 yield batch
                 start_tm = time.time()
             self._update_last_index(batch.begin_index + len(batch) - 1)
             iter_round += 1
             processed_index = batch.begin_index + len(batch) - 1
             if iter_round % 16 == 0:
                 logging.info("%s process to index %d", self.name(),
                              processed_index)
     if processed_index is not None:
         logging.info("%s process to index %d when round finished",
                      self.name(), processed_index)
     if input_finished and batch_finished:
         self._set_process_finished()
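In make_processor above, start_tm is reset right after each yield, so the '{name}.produce.latency' timer measures the time spent producing each batch rather than the cumulative runtime, while '{name}.produce.index' stores the last index of the batch. Below is a minimal sketch of that per-batch pattern, assuming batch objects with a begin_index attribute and a length as in the example; the generator name and metric names are placeholders.

import time


def produce_with_metrics(batches, metrics, tags, name='processor'):
    # Hypothetical sketch: emit one latency timer and one index gauge per batch.
    start_tm = time.time()
    for batch in batches:
        metrics.emit_timer(name='{}.produce.latency'.format(name),
                           value=time.time() - start_tm,
                           tags=tags)
        metrics.emit_store(name='{}.produce.index'.format(name),
                           value=batch.begin_index + len(batch) - 1,
                           tags=tags)
        yield batch
        # Restart the clock once the consumer has taken the batch,
        # as make_processor does after its yield.
        start_tm = time.time()
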
Example #13
 def _dump_joined_items(self):
     start_tm = time.time()
     prev_leader_idx = 0
     neg_samples = {}
     for (leader_idx, leader_item) in self._leader_join_window:
         if prev_leader_idx == 0:
             prev_leader_idx = leader_idx
         eid = leader_item.example_id
         if (eid not in self._follower_example_cache
                 and eid not in self._joined_cache):
             if self._enable_negative_example_generator:
                 neg_samples[leader_idx] = leader_item
             continue
         if eid not in self._joined_cache:
             self._joined_cache[eid] = self._follower_example_cache[eid]
         follower_example = self._joined_cache[eid]
         if (self._enable_negative_example_generator
                 and leader_idx > prev_leader_idx):
             self._negative_example_generator.update(neg_samples)
             for example in self._negative_example_generator.generate(
                     follower_example[1], prev_leader_idx, leader_idx):
                 builder = self._get_data_block_builder(True)
                 assert builder is not None, "data block builder must not " \
                                             "be None before dumping"
                 builder.append_item(example[0], example[1], example[2])
                 self._optional_stats.update_stats(example[0],
                                                   kind='negative')
                 if builder.check_data_block_full():
                     yield self._finish_data_block()
             neg_samples = {}
         builder = self._get_data_block_builder(True)
         assert builder is not None, "data block builder must not be "\
                                     "None before dumping"
         follower_idx, item = self._joined_cache[eid]
         builder.append_item(item, leader_idx, follower_idx)
         self._optional_stats.update_stats(item, kind='joined')
         prev_leader_idx = leader_idx
         if builder.check_data_block_full():
             yield self._finish_data_block()
     metrics.emit_timer(name='stream_joiner_dump_joined_items',
                        value=int(time.time() - start_tm),
                        tags=self._metrics_tags)
Example #14
 def _dump_joined_items(self):
     start_tm = time.time()
     for (li, le) in self._leader_join_window:
         eid = le.example_id
         if eid not in self._follower_example_cache and \
                 eid not in self._joined_cache:
             continue
         if eid not in self._joined_cache:
             self._joined_cache[eid] = \
                     self._follower_example_cache[eid]
         builder = self._get_data_block_builder(True)
         assert builder is not None, "data block builder must "\
                                     "not be None before dumping"
         fi, item = self._joined_cache[eid]
         builder.append_item(item, li, fi)
         if builder.check_data_block_full():
             yield self._finish_data_block()
     metrics.emit_timer(name='stream_joiner_dump_joined_items',
                        value=int(time.time() - start_tm),
                        tags=self._metrics_tags)
Example #15
 def _fill_leader_join_window(self, sync_example_id_finished):
     if not self._fill_leader_enough:
         start_tm = time.time()
         start_pos = self._leader_join_window.size()
         if not self._fill_join_windows(self._leader_visitor,
                                        self._leader_join_window, None):
             self._fill_leader_enough = sync_example_id_finished
         else:
             self._fill_leader_enough = True
         if self._fill_leader_enough:
             self._leader_unjoined_example_ids = \
                 [item.example_id for _, item in self._leader_join_window]
         end_pos = self._leader_join_window.size()
         eids = [(self._leader_join_window[idx][0],
                  self._leader_join_window[idx][1].example_id)
                 for idx in range(start_pos, end_pos)]
         self._joiner_stats.fill_leader_example_ids(eids)
         metrics.emit_timer(name='stream_joiner_fill_leader_join_window',
                            value=int(time.time() - start_tm),
                            tags=self._metrics_tags)
     return self._fill_leader_enough
Example #16
 def _receive(self, iter_id, name):
     logging.debug('Data: Waiting to receive %s for iter %d.', name,
                   iter_id)
     start_time = time.time()
     with self._condition:
         while (iter_id not in self._received_data) \
                 or (name not in self._received_data[iter_id]):
             if self._peer_next_iter_id > iter_id:
                 msg = 'Peer committed without sending %s. ' \
                     'Please check model code'%name
                 logging.fatal(msg)
                 raise RuntimeError(msg)
             if not self._condition.wait(10):
                 logging.warning(
                     'Data: Still waiting to receive %s for iter %d...',
                     name, iter_id)
         data = self._received_data[iter_id][name]
     duration = time.time() - start_time
     metrics.emit_timer('receive_timer', duration)
     logging.debug('Data: received %s for iter %d after %f sec.', name,
                   iter_id, duration)
     return data
Example #17
 def _dump_joined_items(self):
     start_tm = time.time()
     prev_leader_idx = 0
     neg_samples = {}
     for (li, le) in self._leader_join_window:
         if prev_leader_idx == 0:
             prev_leader_idx = li
         eid = le.example_id
         if eid not in self._follower_example_cache and \
                 eid not in self._joined_cache:
             if self._enable_negative_example_generator:
                 neg_samples[li] = le
             continue
         if eid not in self._joined_cache:
             self._joined_cache[eid] = \
                     self._follower_example_cache[eid]
         fe = self._joined_cache[eid]
         if self._enable_negative_example_generator and li > prev_leader_idx:
             self._negative_example_generator.update(neg_samples)
             for example in \
                 self._negative_example_generator.generate(
                     fe[1], prev_leader_idx, li):
                 builder = self._get_data_block_builder(True)
                 assert builder is not None, "data block builder must "\
                                             "not be None before dumping"
                 builder.append_item(example[0], example[1], example[2])
                 if builder.check_data_block_full():
                     yield self._finish_data_block()
             neg_samples = {}
         builder = self._get_data_block_builder(True)
         assert builder is not None, "data block builder must "\
                                     "not be None before dumping"
         fi, item = self._joined_cache[eid]
         builder.append_item(item, li, fi)
         if builder.check_data_block_full():
             yield self._finish_data_block()
     metrics.emit_timer(name='stream_joiner_dump_joined_items',
                        value=int(time.time() - start_tm),
                        tags=self._metrics_tags)
Example #18
    def evaluate(self,
                 input_fn,
                 checkpoint_path=None):
        if not tf.train.latest_checkpoint(checkpoint_path):
            raise ValueError(
                "Could not find trained model at %s" % checkpoint_path)

        with tf.Graph().as_default():
            features, labels = self._get_features_and_labels_from_input_fn(
                input_fn, ModeKeys.EVAL)
            spec, model = self._get_model_spec(features, labels, ModeKeys.EVAL)

            # Track the average loss by default
            eval_metric_ops = spec.eval_metric_ops or {}
            if model_fn_lib.LOSS_METRIC_KEY not in eval_metric_ops:
                loss_metric = tf.metrics.mean(spec.loss)
                eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric

            # Create the real eval op
            update_ops, eval_dict = _extract_metric_update_ops(eval_metric_ops)
            update_ops.extend(model._train_ops)
            eval_op = tf.group(*update_ops)

            # Also track the global step
            if tf.GraphKeys.GLOBAL_STEP in eval_dict:
                raise ValueError(
                    'Metric with name `global_step` is not allowed, because '
                    'Estimator already defines a default metric with the '
                    'same name.')
            eval_dict[tf.GraphKeys.GLOBAL_STEP] = \
                tf.train.get_or_create_global_step()

            # Prepare the session creator.
            scaffold = tf.train.Scaffold()
            session_creator = tf.train.ChiefSessionCreator(
                scaffold=scaffold,
                checkpoint_dir=checkpoint_path)

            # Prepare hooks
            all_hooks = list(spec.evaluation_hooks) or []
            final_ops_hook = tf.train.FinalOpsHook(eval_dict)
            all_hooks.append(final_ops_hook)

            # Evaluate over dataset
            self._bridge.connect()
            try:
                with tf.train.MonitoredSession(
                    session_creator=session_creator, hooks=all_hooks) as sess:
                    if not self._restore_datablock(DATA_CHECKPOINT_INIT_VALUE):
                        raise ValueError("Restore data checkpoint error")
                    iter_id = 0
                    while not sess.should_stop():
                        self._bridge.start(iter_id)
                        logging.debug('after bridge start.')
                        start_time = time.time()
                        sess.run(eval_op)
                        end_time = time.time()
                        metrics.emit_timer(
                            name="iter_timer",
                            value=end_time-start_time,
                            tags={})
                        logging.debug('after session run.')
                        self._bridge.commit()
                        logging.debug('after bridge commit.')
                        iter_id += 1
            finally:
                self._bridge.terminate()

            # Print result
            logging.info('Metrics for iteration %d: %s',
                iter_id, _dict_to_str(final_ops_hook.final_ops_values))
            return final_ops_hook.final_ops_values
Example #19
    def _client_daemon_fn(self):
        stop_event = threading.Event()
        generator = None
        channel = make_insecure_channel(self._remote_address,
                                        ChannelType.REMOTE,
                                        options=self._grpc_options,
                                        compression=self._compression)
        client = make_ready_client(channel, stop_event)

        lock = threading.Lock()
        resend_list = collections.deque()

        @metrics.timer(func_name="shutdown_fn", tags={})
        def shutdown_fn():
            with lock:
                while len(resend_list) > 0 or not self._transmit_queue.empty():
                    logging.debug(
                        "Waiting for resend queue's being cleaned. "
                        "Resend queue size: %d", len(resend_list))
                    lock.release()
                    time.sleep(1)
                    lock.acquire()

            stop_event.set()
            if generator is not None:
                generator.cancel()

        self._client_daemon_shutdown_fn = shutdown_fn

        while not stop_event.is_set():
            try:

                def iterator():
                    with lock:
                        resend_msgs = list(resend_list)
                    for item in resend_msgs:
                        logging.warning("Streaming resend message seq_num=%d",
                                        item.seq_num)
                        metrics.emit_store(name="resend_msg_seq_num",
                                           value=int(item.seq_num),
                                           tags={})
                        yield item
                    while True:
                        item = self._transmit_queue.get()
                        with lock:
                            resend_list.append(item)
                        logging.debug("Streaming send message seq_num=%d",
                                      item.seq_num)
                        metrics.emit_store(name="send_msg_seq_num",
                                           value=int(item.seq_num),
                                           tags={})
                        yield item

                time_start = time.time()
                generator = client.StreamTransmit(iterator())
                time_end = time.time()
                metrics.emit_timer(name="one_StreamTransmit_spend",
                                   value=int(time_end - time_start),
                                   tags={})
                for response in generator:
                    if response.status.code == common_pb.STATUS_SUCCESS:
                        logging.debug(
                            "Message with seq_num=%d is "
                            "confirmed", response.next_seq_num - 1)
                    elif response.status.code == \
                        common_pb.STATUS_MESSAGE_DUPLICATED:
                        logging.debug(
                            "Resent Message with seq_num=%d is "
                            "confirmed", response.next_seq_num - 1)
                    elif response.status.code == \
                        common_pb.STATUS_MESSAGE_MISSING:
                        raise RuntimeError("Message with seq_num=%d is "
                                           "missing!" %
                                           (response.next_seq_num - 1))
                    else:
                        raise RuntimeError("Trainsmit failed with %d" %
                                           response.status.code)
                    with lock:
                        while resend_list and \
                                resend_list[0].seq_num < response.next_seq_num:
                            resend_list.popleft()
                        min_seq_num_to_resend = resend_list[0].seq_num \
                            if resend_list else "NaN"
                        logging.debug(
                            "Resend queue size: %d, starting from seq_num=%s",
                            len(resend_list), min_seq_num_to_resend)
                metrics.emit_store(name="sum_of_resend",
                                   value=int(len(resend_list)),
                                   tags={})
            except Exception as e:  # pylint: disable=broad-except
                if not stop_event.is_set():
                    logging.warning("Bridge streaming broken: %s.", repr(e))
            finally:
                generator.cancel()
                channel.close()
                logging.warning(
                    "Restarting streaming: resend queue size: %d, "
                    "starting from seq_num=%s", len(resend_list),
                    resend_list and resend_list[0].seq_num or "NaN")
                channel = make_insecure_channel(self._remote_address,
                                                ChannelType.REMOTE,
                                                options=self._grpc_options,
                                                compression=self._compression)
                client = make_ready_client(channel, stop_event)
                self._check_remote_heartbeat(client)
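Example #19 also shows the decorator form, @metrics.timer(func_name=..., tags=...), which times a whole function instead of an explicit start/stop pair around a block. The library's actual implementation is not shown on this page; the sketch below only illustrates how such a decorator could be layered on top of emit_timer, with the metrics object passed in explicitly.

import functools
import time


def timer(metrics, func_name, tags):
    # Hypothetical sketch of a timing decorator built on emit_timer;
    # the real metrics.timer implementation is not shown on this page.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            start_tm = time.time()
            try:
                return fn(*args, **kwargs)
            finally:
                metrics.emit_timer(name=func_name,
                                   value=int(time.time() - start_tm),
                                   tags=tags)
        return wrapper
    return decorator
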
Example #20
    def train(self,
              input_fn,
              checkpoint_path=None,
              save_checkpoint_steps=None,
              save_checkpoint_secs=None):
        if self._cluster_spec is not None:
            device_fn = tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % self._worker_rank,
                merge_devices=True,
                cluster=self._cluster_spec)
            cluster_def = self._cluster_spec.as_cluster_def()
            local_address = self._cluster_spec.job_tasks('worker')[
                self._worker_rank]
            server = tf.train.Server(tf.train.ClusterSpec(
                {'local': {
                    0: local_address
                }}),
                job_name='local',
                task_index=0)
            target = 'grpc://' + local_address
        else:
            device_fn = None
            cluster_def = None
            target = None

        config = tf.ConfigProto(cluster_def=cluster_def)
        config.inter_op_parallelism_threads = 4
        config.intra_op_parallelism_threads = 4
        config.experimental.share_session_state_in_clusterspec_propagation \
            = True
        tf.config.set_soft_device_placement(False)

        with tf.Graph().as_default() as g:
            with tf.device(device_fn):
                features, labels = self._get_features_and_labels_from_input_fn(
                    input_fn, ModeKeys.TRAIN)
                spec, _ = self._get_model_spec(features, labels, ModeKeys.TRAIN)

            # Explicitly add a Saver
            if not tf.get_collection(tf.GraphKeys.SAVERS):
                saver = tf.train.Saver(
                    sharded=True,
                    defer_build=True,
                    save_relative_paths=True)  # Must set for portability
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)

            listener = DataCheckpointSaverListener(self._trainer_master,
                                                   self._application_id)
            saver_hook = tf.estimator.CheckpointSaverHook(
                checkpoint_path, save_secs=save_checkpoint_secs,
                save_steps=save_checkpoint_steps, listeners=[listener])
            self._bridge.connect()

            try:
                with tf.train.MonitoredTrainingSession(
                    master=target,
                    config=config,
                    is_chief=(self._worker_rank == 0),
                    chief_only_hooks=[saver_hook],
                    checkpoint_dir=checkpoint_path,
                    save_checkpoint_steps=None,
                    save_checkpoint_secs=None,
                    hooks=spec.training_hooks) as sess:
                    iter_id = 0

                    data_checkpoint_value = None
                    if hasattr(saver_hook, "data_checkpoint"):
                        data_checkpoint_value = saver_hook.data_checkpoint
                    if not self._restore_datablock(data_checkpoint_value):
                        raise ValueError("Restore data checkpoint error")

                    while not sess.should_stop():
                        self._bridge.start(iter_id)
                        logging.debug('after bridge start.')
                        start_time = time.time()
                        sess.run(spec.train_op, feed_dict={})
                        end_time = time.time()
                        metrics.emit_timer(
                            name="iter_timer",
                            value=end_time-start_time,
                            tags={})
                        logging.debug('after session run.')
                        self._bridge.commit()
                        logging.debug('after bridge commit.')
                        iter_id += 1
            finally:
                self._bridge.terminate()

        return self
Example #21
 def _update_latest_dump_timestamp(self):
     data_block_dump_duration = time.time() - self._latest_dump_timestamp
     metrics.emit_timer(name='data_block_dump_duration',
                        value=int(data_block_dump_duration),
                        tags=self._metrics_tags)
     self._latest_dump_timestamp = time.time()
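The last example measures the interval between successive dumps rather than the duration of a single call: it emits the time elapsed since self._latest_dump_timestamp and then refreshes that attribute. Below is a minimal self-contained sketch of the same interval-between-events pattern; the class name and constructor arguments are placeholders.

import time


class DumpIntervalReporter:
    # Hypothetical sketch: report the elapsed time between successive dump events.
    def __init__(self, metrics, tags):
        self._metrics = metrics
        self._tags = tags
        self._latest_dump_timestamp = time.time()

    def update(self):
        now = time.time()
        self._metrics.emit_timer(name='data_block_dump_duration',
                                 value=int(now - self._latest_dump_timestamp),
                                 tags=self._tags)
        # Reset the reference point so the next call measures the next interval.
        self._latest_dump_timestamp = now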