def _dump_joined_items(self, matching_list):
    start_tm = time.time()
    prev_leader_idx = self._leader_restart_index + 1
    for item in matching_list:
        (fe, li, fi) = item
        if self._enable_negative_example_generator and li > prev_leader_idx:
            for example in self._negative_example_generator.generate(
                    fe, prev_leader_idx, li):
                builder = self._get_data_block_builder(True)
                assert builder is not None, "data block builder must not " \
                                            "be None before dumping"
                builder.append_item(example[0], example[1], example[2],
                                    None, True)
                if builder.check_data_block_full():
                    yield self._finish_data_block()
        prev_leader_idx = li + 1
        builder = self._get_data_block_builder(True)
        assert builder is not None, "data block builder must not " \
                                    "be None before dumping"
        builder.append_item(fe, li, fi, None, True)
        if builder.check_data_block_full():
            yield self._finish_data_block()
    metrics.emit_timer(name='attribution_joiner_dump_joined_items',
                       value=int(time.time() - start_tm),
                       tags=self._metrics_tags)
def _dump_joined_items(self, indexed_pairs):
    start_tm = time.time()
    for ip in indexed_pairs:
        if self._enable_negative_example_generator:
            for example in self._negative_example_generator.generate(
                    ip.fe, ip.li):
                builder = self._get_data_block_builder(True)
                assert builder is not None, "data block builder must not " \
                                            "be None before dumping"
                # example: (li, fi, item)
                builder.append_item(example[0], example[1], example[2],
                                    None, True, 0)
                if builder.check_data_block_full():
                    yield self._finish_data_block()
        builder = self._get_data_block_builder(True)
        assert builder is not None, "data block builder must not " \
                                    "be None before dumping"
        builder.append_item(ip.fe, ip.li, ip.fi, None, True, joined=1)
        if builder.check_data_block_full():
            yield self._finish_data_block()
    metrics.emit_timer(name='universal_joiner_dump_joined_items',
                       value=int(time.time() - start_tm),
                       tags=self._metrics_tags)
def _emit_dumper_metrics(self, file_index, dumped_index):
    dump_duration = time.time() - self._latest_dump_timestamp
    metrics.emit_timer(name='example_id_dump_duration',
                       value=int(dump_duration),
                       tags=self._metrics_tags)
    metrics.emit_store(name='example_dump_file_index',
                       value=file_index,
                       tags=self._metrics_tags)
    metrics.emit_store(name='example_id_dumped_index',
                       value=dumped_index,
                       tags=self._metrics_tags)
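# NOTE (editorial sketch): the functions in this section assume a project
# 'metrics' module exposing emit_timer(name=..., value=..., tags=...) /
# emit_timer(name, value), emit_store(name=..., value=..., tags=...), and a
# timer(func_name=..., tags=...) decorator (used in _client_daemon_fn below).
# The following is a hypothetical, logging-backed stand-in for a metrics.py
# with those call shapes, intended only for exercising the snippets in
# isolation; it is not the project's real metrics exporter.
import functools
import logging
import time


def emit_timer(name, value=None, tags=None):
    # Record a duration-style metric; in this stub it is only logged.
    logging.info("timer %s=%s tags=%s", name, value, tags or {})


def emit_store(name, value=None, tags=None):
    # Record a gauge-style metric; in this stub it is only logged.
    logging.info("store %s=%s tags=%s", name, value, tags or {})


def timer(func_name, tags=None):
    # Decorator form: time the wrapped call and emit it under func_name.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.time()
            try:
                return func(*args, **kwargs)
            finally:
                emit_timer(func_name, int(time.time() - start), tags)
        return wrapper
    return decorator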
def _dump_data_blocks(self):
    while self.need_dump():
        meta = self._get_next_data_block_meta()
        if meta is not None:
            start_tm = time.time()
            self._raw_data_visitor.active_visitor()
            self._dump_data_block_by_meta(meta)
            dump_duration = time.time() - start_tm
            metrics.emit_timer(name='data_block_dump_duration',
                               value=int(dump_duration),
                               tags=self._metrics_tags)
def _evit_stale_follower_cache(self):
    start_tm = time.time()
    reserved_items = self._evict_impl(self._follower_join_window,
                                      self._evict_if_useless)
    if len(reserved_items) < self._max_window_size:
        self._follower_join_window.reset(reserved_items, False)
        return
    reserved_items = self._evict_impl(reserved_items, self._evict_if_force)
    self._follower_join_window.reset(reserved_items, False)
    metrics.emit_timer(name='stream_joiner_evit_stale_follower_cache',
                       value=int(time.time() - start_tm),
                       tags=self._metrics_tags)
def _update_join_cache(self):
    start_tm = time.time()
    new_unjoined_example_ids = []
    for example_id in self._leader_unjoined_example_ids:
        if example_id in self._follower_example_cache:
            self._joined_cache[example_id] = \
                self._follower_example_cache[example_id]
        else:
            new_unjoined_example_ids.append(example_id)
    self._leader_unjoined_example_ids = new_unjoined_example_ids
    metrics.emit_timer(name='stream_joiner_update_join_cache',
                       value=int(time.time() - start_tm),
                       tags=self._metrics_tags)
def receive(self, iter_id, name):
    logging.debug('Data: Waiting to receive %s for iter %d.', name, iter_id)
    start_time = time.time()
    with self._condition:
        while (iter_id not in self._received_data) \
                or (name not in self._received_data[iter_id]):
            self._condition.wait()
        data = self._received_data[iter_id][name]
    duration = time.time() - start_time
    metrics.emit_timer('receive_timer', duration)
    logging.debug('Data: received %s for iter %d after %f sec.',
                  name, iter_id, duration)
    return tf.make_ndarray(data.tensor)
def _fill_follower_join_window(self, raw_data_finished):
    start_tm = time.time()
    start_pos = self._follower_join_window.size()
    follower_enough = self._fill_join_windows(self._follower_visitor,
                                              self._follower_join_window,
                                              self._follower_example_cache)
    end_pos = self._follower_join_window.size()
    eids = [(self._follower_join_window[idx][0],
             self._follower_join_window[idx][1].example_id)
            for idx in range(start_pos, end_pos)]
    self._joiner_stats.fill_follower_example_ids(eids)
    metrics.emit_timer(name='stream_joiner_fill_follower_join_window',
                       value=int(time.time() - start_tm),
                       tags=self._metrics_tags)
    return follower_enough or raw_data_finished
def _evit_stale_follower_cache(self):
    start_tm = time.time()
    tmp_sz = self._follower_join_window.size()
    reserved_items = self._evict_impl(self._follower_join_window,
                                      self._evict_if_useless)
    logging.debug("evict_if_useless %d to %d", tmp_sz, len(reserved_items))
    if len(reserved_items) < self._max_window_size:
        self._follower_join_window.reset(reserved_items, False)
        return
    tmp_sz = len(reserved_items)
    reserved_items = self._evict_impl(reserved_items, self._evict_if_force)
    logging.debug("evict_if_force %d to %d", tmp_sz, len(reserved_items))
    self._follower_join_window.reset(reserved_items, False)
    metrics.emit_timer(name='stream_joiner_evit_stale_follower_cache',
                       value=int(time.time() - start_tm),
                       tags=self._metrics_tags)
def _fill_follower_join_window(self, raw_data_finished):
    start_tm = time.time()
    idx = self._follower_join_window.size()
    filled_enough = self._fill_join_windows(self._follower_visitor,
                                            self._follower_join_window)
    eids = []
    while idx < self._follower_join_window.size():
        eids.append((self._follower_join_window[idx].index,
                     self._follower_join_window[idx].item.example_id))
        idx += 1
    self._joiner_stats.fill_follower_example_ids(eids)
    metrics.emit_timer(name='universal_joiner_fill_follower_join_window',
                       value=int(time.time() - start_tm),
                       tags=self._metrics_tags)
    return filled_enough or raw_data_finished
def _fill_follower_join_window(self):
    start_tm = time.time()
    idx = self._follower_join_window.size()
    filled_new_example = self._fill_join_windows(self._follower_visitor,
                                                 self._follower_join_window)
    eids = []
    while idx < self._follower_join_window.size():
        eids.append((self._follower_join_window[idx][0],
                     self._follower_join_window[idx][1].example_id))
        idx += 1
    self._joiner_stats.fill_follower_example_ids(eids)
    metrics.emit_timer(name='attribution_joiner_fill_follower_join_window',
                       value=int(time.time() - start_tm),
                       tags=self._metrics_tags)
    return filled_new_example
def make_processor(self, next_index):
    input_finished = False
    with self._lock:
        if next_index is None:
            return
        if self._check_index_rollback(next_index):
            self._batch_queue = []
            self._flying_item_count = 0
        if len(self._batch_queue) > 0:
            end_batch = self._batch_queue[-1]
            next_index = end_batch.begin_index + len(end_batch)
        input_finished = self._input_finished
    assert next_index >= 0, "the next index should be >= 0"
    end_batch = None
    batch_finished = False
    iter_round = 0
    processed_index = None
    start_tm = time.time()
    for batch, batch_finished in self._make_inner_generator(next_index):
        if batch is not None:
            if len(batch) > 0:
                latency_mn = '{}.produce.latency'.format(self.name())
                metrics.emit_timer(name=latency_mn,
                                   value=time.time() - start_tm,
                                   tags=self._get_metrics_tags())
                store_mn = '{}.produce.index'.format(self.name())
                metrics.emit_store(name=store_mn,
                                   value=batch.begin_index + len(batch) - 1,
                                   tags=self._get_metrics_tags())
                self._append_next_item_batch(batch)
                yield batch
                start_tm = time.time()
                self._update_last_index(batch.begin_index + len(batch) - 1)
            iter_round += 1
            processed_index = batch.begin_index + len(batch) - 1
            if iter_round % 16 == 0:
                logging.info("%s process to index %d",
                             self.name(), processed_index)
    if processed_index is not None:
        logging.info("%s process to index %d when round finished",
                     self.name(), processed_index)
    if input_finished and batch_finished:
        self._set_process_finished()
def _dump_joined_items(self):
    start_tm = time.time()
    prev_leader_idx = 0
    neg_samples = {}
    for (leader_idx, leader_item) in self._leader_join_window:
        if prev_leader_idx == 0:
            prev_leader_idx = leader_idx
        eid = leader_item.example_id
        if (eid not in self._follower_example_cache and
                eid not in self._joined_cache):
            if self._enable_negative_example_generator:
                neg_samples[leader_idx] = leader_item
            continue
        if eid not in self._joined_cache:
            self._joined_cache[eid] = self._follower_example_cache[eid]
        follower_example = self._joined_cache[eid]
        if (self._enable_negative_example_generator and
                leader_idx > prev_leader_idx):
            self._negative_example_generator.update(neg_samples)
            for example in self._negative_example_generator.generate(
                    follower_example[1], prev_leader_idx, leader_idx):
                builder = self._get_data_block_builder(True)
                assert builder is not None, "data block builder must not " \
                                            "be None before dumping"
                builder.append_item(example[0], example[1], example[2])
                self._optional_stats.update_stats(example[0],
                                                  kind='negative')
                if builder.check_data_block_full():
                    yield self._finish_data_block()
            neg_samples = {}
        builder = self._get_data_block_builder(True)
        assert builder is not None, "data block builder must not be " \
                                    "None before dumping"
        follower_idx, item = self._joined_cache[eid]
        builder.append_item(item, leader_idx, follower_idx)
        self._optional_stats.update_stats(item, kind='joined')
        prev_leader_idx = leader_idx
        if builder.check_data_block_full():
            yield self._finish_data_block()
    metrics.emit_timer(name='stream_joiner_dump_joined_items',
                       value=int(time.time() - start_tm),
                       tags=self._metrics_tags)
def _dump_joined_items(self):
    start_tm = time.time()
    for (li, le) in self._leader_join_window:
        eid = le.example_id
        if eid not in self._follower_example_cache and \
                eid not in self._joined_cache:
            continue
        if eid not in self._joined_cache:
            self._joined_cache[eid] = \
                self._follower_example_cache[eid]
        builder = self._get_data_block_builder(True)
        assert builder is not None, "data block builder must not " \
                                    "be None before dumping"
        fi, item = self._joined_cache[eid]
        builder.append_item(item, li, fi)
        if builder.check_data_block_full():
            yield self._finish_data_block()
    metrics.emit_timer(name='stream_joiner_dump_joined_items',
                       value=int(time.time() - start_tm),
                       tags=self._metrics_tags)
def _fill_leader_join_window(self, sync_example_id_finished):
    if not self._fill_leader_enough:
        start_tm = time.time()
        start_pos = self._leader_join_window.size()
        if not self._fill_join_windows(self._leader_visitor,
                                       self._leader_join_window, None):
            self._fill_leader_enough = sync_example_id_finished
        else:
            self._fill_leader_enough = True
        if self._fill_leader_enough:
            self._leader_unjoined_example_ids = \
                [item.example_id for _, item in self._leader_join_window]
        end_pos = self._leader_join_window.size()
        eids = [(self._leader_join_window[idx][0],
                 self._leader_join_window[idx][1].example_id)
                for idx in range(start_pos, end_pos)]
        self._joiner_stats.fill_leader_example_ids(eids)
        metrics.emit_timer(name='stream_joiner_fill_leader_join_window',
                           value=int(time.time() - start_tm),
                           tags=self._metrics_tags)
    return self._fill_leader_enough
def _receive(self, iter_id, name):
    logging.debug('Data: Waiting to receive %s for iter %d.', name, iter_id)
    start_time = time.time()
    with self._condition:
        while (iter_id not in self._received_data) \
                or (name not in self._received_data[iter_id]):
            if self._peer_next_iter_id > iter_id:
                msg = 'Peer committed without sending %s. ' \
                      'Please check model code' % name
                logging.fatal(msg)
                raise RuntimeError(msg)
            if not self._condition.wait(10):
                logging.warning(
                    'Data: Still waiting to receive %s for iter %d...',
                    name, iter_id)
        data = self._received_data[iter_id][name]
    duration = time.time() - start_time
    metrics.emit_timer('receive_timer', duration)
    logging.debug('Data: received %s for iter %d after %f sec.',
                  name, iter_id, duration)
    return data
def _dump_joined_items(self):
    start_tm = time.time()
    prev_leader_idx = 0
    neg_samples = {}
    for (li, le) in self._leader_join_window:
        if prev_leader_idx == 0:
            prev_leader_idx = li
        eid = le.example_id
        if eid not in self._follower_example_cache and \
                eid not in self._joined_cache:
            if self._enable_negative_example_generator:
                neg_samples[li] = le
            continue
        if eid not in self._joined_cache:
            self._joined_cache[eid] = \
                self._follower_example_cache[eid]
        fe = self._joined_cache[eid]
        if self._enable_negative_example_generator and li > prev_leader_idx:
            self._negative_example_generator.update(neg_samples)
            for example in self._negative_example_generator.generate(
                    fe[1], prev_leader_idx, li):
                builder = self._get_data_block_builder(True)
                assert builder is not None, "data block builder must not " \
                                            "be None before dumping"
                builder.append_item(example[0], example[1], example[2])
                if builder.check_data_block_full():
                    yield self._finish_data_block()
            neg_samples = {}
        builder = self._get_data_block_builder(True)
        assert builder is not None, "data block builder must not " \
                                    "be None before dumping"
        fi, item = self._joined_cache[eid]
        builder.append_item(item, li, fi)
        if builder.check_data_block_full():
            yield self._finish_data_block()
    metrics.emit_timer(name='stream_joiner_dump_joined_items',
                       value=int(time.time() - start_tm),
                       tags=self._metrics_tags)
def evaluate(self, input_fn, checkpoint_path=None):
    if not tf.train.latest_checkpoint(checkpoint_path):
        raise ValueError(
            "Could not find trained model at %s" % checkpoint_path)

    with tf.Graph().as_default():
        features, labels = self._get_features_and_labels_from_input_fn(
            input_fn, ModeKeys.EVAL)
        spec, model = self._get_model_spec(features, labels, ModeKeys.EVAL)

        # Track the average loss by default
        eval_metric_ops = spec.eval_metric_ops or {}
        if model_fn_lib.LOSS_METRIC_KEY not in eval_metric_ops:
            loss_metric = tf.metrics.mean(spec.loss)
            eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric

        # Create the real eval op
        update_ops, eval_dict = _extract_metric_update_ops(eval_metric_ops)
        update_ops.extend(model._train_ops)
        eval_op = tf.group(*update_ops)

        # Also track the global step
        if tf.GraphKeys.GLOBAL_STEP in eval_dict:
            raise ValueError(
                'Metric with name `global_step` is not allowed, because '
                'Estimator already defines a default metric with the '
                'same name.')
        eval_dict[tf.GraphKeys.GLOBAL_STEP] = \
            tf.train.get_or_create_global_step()

        # Prepare the session creator.
        scaffold = tf.train.Scaffold()
        session_creator = tf.train.ChiefSessionCreator(
            scaffold=scaffold,
            checkpoint_dir=checkpoint_path)

        # Prepare hooks
        all_hooks = list(spec.evaluation_hooks) or []
        final_ops_hook = tf.train.FinalOpsHook(eval_dict)
        all_hooks.append(final_ops_hook)

        # Evaluate over dataset
        self._bridge.connect()
        try:
            with tf.train.MonitoredSession(
                    session_creator=session_creator,
                    hooks=all_hooks) as sess:
                if not self._restore_datablock(DATA_CHECKPOINT_INIT_VALUE):
                    raise ValueError("Restore data checkpoint error")
                iter_id = 0
                while not sess.should_stop():
                    self._bridge.start(iter_id)
                    logging.debug('after bridge start.')
                    start_time = time.time()
                    sess.run(eval_op)
                    end_time = time.time()
                    metrics.emit_timer(name="iter_timer",
                                       value=end_time - start_time,
                                       tags={})
                    logging.debug('after session run.')
                    self._bridge.commit()
                    logging.debug('after bridge commit.')
                    iter_id += 1
        finally:
            self._bridge.terminate()

        # Print result
        logging.info('Metrics for iteration %d: %s',
                     iter_id, _dict_to_str(final_ops_hook.final_ops_values))
        return final_ops_hook.final_ops_values
def _client_daemon_fn(self):
    stop_event = threading.Event()
    generator = None
    channel = make_insecure_channel(self._remote_address,
                                    ChannelType.REMOTE,
                                    options=self._grpc_options,
                                    compression=self._compression)
    client = make_ready_client(channel, stop_event)

    lock = threading.Lock()
    resend_list = collections.deque()

    @metrics.timer(func_name="shutdown_fn", tags={})
    def shutdown_fn():
        with lock:
            while len(resend_list) > 0 or not self._transmit_queue.empty():
                logging.debug(
                    "Waiting for resend queue to be cleaned. "
                    "Resend queue size: %d", len(resend_list))
                lock.release()
                time.sleep(1)
                lock.acquire()

        stop_event.set()
        if generator is not None:
            generator.cancel()

    self._client_daemon_shutdown_fn = shutdown_fn

    while not stop_event.is_set():
        try:
            def iterator():
                with lock:
                    resend_msgs = list(resend_list)
                for item in resend_msgs:
                    logging.warning("Streaming resend message seq_num=%d",
                                    item.seq_num)
                    metrics.emit_store(name="resend_msg_seq_num",
                                       value=int(item.seq_num), tags={})
                    yield item
                while True:
                    item = self._transmit_queue.get()
                    with lock:
                        resend_list.append(item)
                    logging.debug("Streaming send message seq_num=%d",
                                  item.seq_num)
                    metrics.emit_store(name="send_msg_seq_num",
                                       value=int(item.seq_num), tags={})
                    yield item

            time_start = time.time()
            generator = client.StreamTransmit(iterator())
            time_end = time.time()
            metrics.emit_timer(name="one_StreamTransmit_spend",
                               value=int(time_end - time_start), tags={})
            for response in generator:
                if response.status.code == common_pb.STATUS_SUCCESS:
                    logging.debug("Message with seq_num=%d is confirmed",
                                  response.next_seq_num - 1)
                elif response.status.code == \
                        common_pb.STATUS_MESSAGE_DUPLICATED:
                    logging.debug(
                        "Resent Message with seq_num=%d is confirmed",
                        response.next_seq_num - 1)
                elif response.status.code == \
                        common_pb.STATUS_MESSAGE_MISSING:
                    raise RuntimeError("Message with seq_num=%d is missing!"
                                       % (response.next_seq_num - 1))
                else:
                    raise RuntimeError("Transmit failed with %d"
                                       % response.status.code)
                with lock:
                    while resend_list and \
                            resend_list[0].seq_num < response.next_seq_num:
                        resend_list.popleft()
                    min_seq_num_to_resend = resend_list[0].seq_num \
                        if resend_list else "NaN"
                    logging.debug(
                        "Resend queue size: %d, starting from seq_num=%s",
                        len(resend_list), min_seq_num_to_resend)
                    metrics.emit_store(name="sum_of_resend",
                                       value=int(len(resend_list)), tags={})
        except Exception as e:  # pylint: disable=broad-except
            if not stop_event.is_set():
                logging.warning("Bridge streaming broken: %s.", repr(e))
        finally:
            generator.cancel()
            channel.close()
            logging.warning(
                "Restarting streaming: resend queue size: %d, "
                "starting from seq_num=%s",
                len(resend_list),
                resend_list and resend_list[0].seq_num or "NaN")
            channel = make_insecure_channel(self._remote_address,
                                            ChannelType.REMOTE,
                                            options=self._grpc_options,
                                            compression=self._compression)
            client = make_ready_client(channel, stop_event)
            self._check_remote_heartbeat(client)
def train(self, input_fn, checkpoint_path=None,
          save_checkpoint_steps=None, save_checkpoint_secs=None):
    if self._cluster_spec is not None:
        device_fn = tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % self._worker_rank,
            merge_devices=True,
            cluster=self._cluster_spec)
        cluster_def = self._cluster_spec.as_cluster_def()
        local_address = self._cluster_spec.job_tasks('worker')[
            self._worker_rank]
        server = tf.train.Server(
            tf.train.ClusterSpec({'local': {0: local_address}}),
            job_name='local',
            task_index=0)
        target = 'grpc://' + local_address
    else:
        device_fn = None
        cluster_def = None
        target = None

    config = tf.ConfigProto(cluster_def=cluster_def)
    config.inter_op_parallelism_threads = 4
    config.intra_op_parallelism_threads = 4
    config.experimental.share_session_state_in_clusterspec_propagation \
        = True
    tf.config.set_soft_device_placement(False)

    with tf.Graph().as_default() as g:
        with tf.device(device_fn):
            features, labels = self._get_features_and_labels_from_input_fn(
                input_fn, ModeKeys.TRAIN)
            spec, _ = self._get_model_spec(features, labels, ModeKeys.TRAIN)

        # Explicitly add a Saver
        if not tf.get_collection(tf.GraphKeys.SAVERS):
            saver = tf.train.Saver(
                sharded=True,
                defer_build=True,
                save_relative_paths=True)  # Must set for portability
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)

        listener = DataCheckpointSaverListener(self._trainer_master,
                                               self._application_id)
        saver_hook = tf.estimator.CheckpointSaverHook(
            checkpoint_path,
            save_secs=save_checkpoint_secs,
            save_steps=save_checkpoint_steps,
            listeners=[listener])
        self._bridge.connect()
        try:
            with tf.train.MonitoredTrainingSession(
                    master=target,
                    config=config,
                    is_chief=(self._worker_rank == 0),
                    chief_only_hooks=[saver_hook],
                    checkpoint_dir=checkpoint_path,
                    save_checkpoint_steps=None,
                    save_checkpoint_secs=None,
                    hooks=spec.training_hooks) as sess:
                iter_id = 0
                data_checkpoint_value = None
                if hasattr(saver_hook, "data_checkpoint"):
                    data_checkpoint_value = saver_hook.data_checkpoint
                if not self._restore_datablock(data_checkpoint_value):
                    raise ValueError("Restore data checkpoint error")
                while not sess.should_stop():
                    self._bridge.start(iter_id)
                    logging.debug('after bridge start.')
                    start_time = time.time()
                    sess.run(spec.train_op, feed_dict={})
                    end_time = time.time()
                    metrics.emit_timer(name="iter_timer",
                                       value=end_time - start_time,
                                       tags={})
                    logging.debug('after session run.')
                    self._bridge.commit()
                    logging.debug('after bridge commit.')
                    iter_id += 1
        finally:
            self._bridge.terminate()

    return self
def _update_latest_dump_timestamp(self):
    data_block_dump_duration = time.time() - self._latest_dump_timestamp
    metrics.emit_timer(name='data_block_dump_duration',
                       value=int(data_block_dump_duration),
                       tags=self._metrics_tags)
    self._latest_dump_timestamp = time.time()
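# NOTE (editorial sketch): nearly every function above repeats the same
# instrumentation pattern: record start_tm = time.time(), do the work, then
# call metrics.emit_timer() with the truncated elapsed seconds. The helper
# below is a hypothetical context manager expressing that pattern; it is not
# part of the original code and assumes the same `metrics` facade the
# functions above rely on (or the stub sketched after _emit_dumper_metrics).
import contextlib
import time

import metrics  # assumed importable: the project facade or the stub above


@contextlib.contextmanager
def emit_timer_scope(name, tags=None):
    # Time the enclosed block and emit the elapsed whole seconds on exit,
    # mirroring the int(time.time() - start_tm) convention used above.
    start_tm = time.time()
    try:
        yield
    finally:
        metrics.emit_timer(name=name,
                           value=int(time.time() - start_tm),
                           tags=tags or {})


# Usage sketch (hypothetical):
#     with emit_timer_scope('stream_joiner_update_join_cache',
#                           self._metrics_tags):
#         ...  # body of _update_join_cache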