def __init__(self, args, front_sink_addr):
    super().__init__()
    self.port = args.port_out
    self.exit_flag = multiprocessing.Event()
    self.logger = set_logger(colored('SINK', 'green'), True)
    self.front_sink_addr = front_sink_addr
    self.verbose = True
    self.show_tokens_to_client = False
    self.is_ready = multiprocessing.Event()
def optimize():
    logger = set_logger(colored('OPTIMIZER', 'magenta'))
    arg = get_run_args(get_optimizer_args_parser)
    temporary_file_name, config = optimize_graph(arg)
    optimized_graph_filepath = os.path.join(arg.model_dir, "optimized_graph.pbtxt")
    tf.gfile.Rename(temporary_file_name, optimized_graph_filepath, overwrite=True)
    logger.info(f"Serialized graph to {optimized_graph_filepath}")
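# A minimal sketch of how the graph serialized above could be loaded back into a
# TensorFlow 1.x session. It assumes the file written by optimize_graph() is a
# binary-serialized GraphDef (despite the .pbtxt suffix); the helper name and
# path handling here are illustrative only, not code from this repo.
import tensorflow as tf

def load_optimized_graph(graph_path):
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # import with an empty name scope so tensor names stay unchanged
        tf.import_graph_def(graph_def, name='')
    return graph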
def __init__(self, id, args, worker_address_list, sink_address, device_id):
    super().__init__()  # added by Xu Haiming
    self.args = args
    self.worker_id = id
    self.device_id = device_id
    self.logger = set_logger(colored('WORKER-%d' % self.worker_id, 'yellow'), args.verbose)
    self.daemon = True
    self.exit_flag = multiprocessing.Event()
    self.worker_address = worker_address_list
    self.num_concurrent_socket = len(self.worker_address)
    self.sink_address = sink_address
    self.verbose = True
    self.show_tokens_to_client = False
    self.is_ready = multiprocessing.Event()
def __init__(self, args):
    super().__init__()
    self.logger = set_logger(colored('VENTILATOR', 'magenta'), args.verbose)
    # self.model_dir = args.model_dir
    # self.max_seq_len = args.max_seq_len
    self.num_worker = args.num_worker
    self.max_batch_size = args.max_batch_size
    self.num_concurrent_socket = max(8, args.num_worker * 2)  # optimize concurrency for multiple clients
    self.port = args.port
    self.args = args
    self.status_args = {k: (v if k != 'pooling_strategy' else v.value) for k, v in sorted(vars(args).items())}
    self.status_static = {
        'python_version': sys.version,
        'pyzmq_version': zmq.pyzmq_version(),
        'zmq_version': zmq.zmq_version(),
        'server_start_time': str(datetime.now()),
    }
    self.processes = []
    self.logger.info('freeze, optimize and export graph, could take a while...')
    self.is_ready = threading.Event()
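# A minimal sketch (an assumption, not code from this repo) of how the two status
# dicts above could be merged and serialized when a client asks for the server
# configuration (ServerCmd.show_config); jsonapi is pyzmq's bundled JSON helper
# and build_status_payload is a hypothetical name.
from zmq.utils import jsonapi

def build_status_payload(status_args, status_static):
    # merge the immutable runtime info with the parsed CLI arguments
    return jsonapi.dumps({**status_args, **status_static})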
def gen():
    # get a fresh logger inside the process for multiprocessing compatibility
    logger = set_logger(colored('WORKER-%d' % self.worker_id, 'yellow'), self.verbose)
    poller = zmq.Poller()
    for sock in socks:
        poller.register(sock, zmq.POLLIN)
    logger.info('ready and listening!')
    self.is_ready.set()
    while not self.exit_flag.is_set():
        events = dict(poller.poll())
        for sock_idx, sock in enumerate(socks):
            if sock in events:
                client_id, raw_msg = sock.recv_multipart()
                msg = jsonapi.loads(raw_msg)  # the payload is a list of raw sentences
                logger.info('new job\tsocket: %d\tsize: %d\tclient: %s' % (sock_idx, len(msg), client_id))
                # segment every sentence with jieba before feeding it to the model
                msg = [" ".join(jieba.lcut(sen.lower())) for sen in msg]
                yield {
                    'client_id': client_id,
                    'inputs': msg,
                }
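# A minimal sketch (hypothetical, for illustration) of the message format the
# generator above expects on each worker socket: a two-frame multipart message
# of [client identity, JSON-encoded list of raw sentences]. push_job and the
# example identity are assumptions, not part of this repo.
import zmq
from zmq.utils import jsonapi

def push_job(push_sock, client_id, sentences):
    # client_id is a bytes identity such as b'client-1#req-7'
    push_sock.send_multipart([client_id, jsonapi.dumps(sentences)])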
def _run(self, sink_embed, sink_token, *receivers):
    # Windows does not support logger in MP environment, thus get a new logger
    # inside the process for better compatibility
    logger = set_logger(colored('WORKER-%d' % self.worker_id, 'yellow'), self.verbose)

    for sock, addr in zip(receivers, self.worker_address):
        sock.connect(addr)
    sink_embed.connect(self.sink_address)
    sink_token.connect(self.sink_address)

    # PyTorch / fairseq inference code below
    if self.args.max_tokens is None and self.args.max_sentences is None:
        self.args.max_sentences = 1
    assert not self.args.sampling or self.args.nbest == self.args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    use_cuda = torch.cuda.is_available() and not self.args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(self.args)

    # Load ensemble
    print('| loading model(s) from {}'.format(self.args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        self.args.path.split(':'),
        arg_overrides=eval(self.args.model_overrides),
        task=task,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if self.args.no_beamable_mm else self.args.beam,
            need_attn=self.args.print_alignment,
        )
        if self.args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(self.args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(self.args)
    bpe = encoders.build_bpe(self.args)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(self.args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    start_id = 0
    for token in self.input_fn_builder(receivers, sink_token)():
        inputs = token["inputs"]
        results = []
        for batch in self.make_batches(inputs, self.args, task, max_positions, encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            translations = task.inference_step(generator, models, sample)
            for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((start_id + id, src_tokens_i, hypos))

        # sort output to match input order
        res = []
        for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, self.args.remove_bpe)
            res_each = []
            for hypo in hypos[:min(len(hypos), self.args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=self.args.remove_bpe,
                )
                hypo_str = decode_fn(hypo_str)
                res_each.append(hypo_str)
            res.append(res_each)

        start_id += len(inputs)
        send_PAGE(src=sink_embed, dest=token['client_id'], res=res, req_id=ServerCmd.data_restokens)
        logger.info('job done\tclient: %s' % (token['client_id']))
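# The translation loop above relies on self.make_batches(), which is not shown
# in this excerpt. Below is a sketch modeled on fairseq's interactive.py
# (assuming a fairseq 0.8/0.9-era API); the actual helper in this repo may differ.
from collections import namedtuple
import torch

Batch = namedtuple('Batch', 'ids src_tokens src_lengths')

def make_batches(lines, args, task, max_positions, encode_fn):
    # encode each raw line with the tokenizer/BPE pipeline, then map to token ids
    tokens = [
        task.source_dictionary.encode_line(encode_fn(line), add_if_not_exist=False).long()
        for line in lines
    ]
    lengths = torch.LongTensor([t.numel() for t in tokens])
    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(tokens, lengths),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            ids=batch['id'],
            src_tokens=batch['net_input']['src_tokens'],
            src_lengths=batch['net_input']['src_lengths'],
        )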
def _run(self, receiver, frontend, sender):
    receiver_addr = auto_bind(receiver)
    frontend.connect(self.front_sink_addr)
    sender.bind('tcp://*:%d' % self.port)

    pending_jobs = defaultdict(lambda: SinkJob())  # type: Dict[str, SinkJob]

    poller = zmq.Poller()
    poller.register(frontend, zmq.POLLIN)
    poller.register(receiver, zmq.POLLIN)

    # send worker receiver address back to frontend
    frontend.send(receiver_addr.encode('ascii'))

    # Windows does not support logger in MP environment, thus get a new logger
    # inside the process for better compatibility
    logger = set_logger(colored('SINK', 'green'), self.verbose)
    logger.info('ready')
    self.is_ready.set()

    while not self.exit_flag.is_set():
        socks = dict(poller.poll())
        if socks.get(receiver) == zmq.POLLIN:
            msg = receiver.recv_multipart()
            job_id = msg[0]
            # parse job_id and partial_id
            job_info = job_id.split(b'@')
            job_id = job_info[0]
            partial_id = int(job_info[1]) if len(job_info) == 2 else 0

            if msg[3] == ServerCmd.data_token:
                x = jsonapi.loads(msg[1])
                pending_jobs[job_id].add_token(x, partial_id)
            elif msg[3] == ServerCmd.data_restokens:
                x = jsonapi.loads(msg[1])
                pending_jobs[job_id].add_token(x, partial_id)
            else:
                logger.error('received a wrongly-formatted request (expected 4 frames, got %d)' % len(msg))
                logger.error('\n'.join('field %d: %s' % (idx, k) for idx, k in enumerate(msg)), exc_info=True)

            logger.info('collect %s %s (E:%d/T:%d/A:%d)' % (
                msg[3], job_id,
                pending_jobs[job_id].progress_embeds,
                pending_jobs[job_id].progress_tokens,
                pending_jobs[job_id].checksum))

        if socks.get(frontend) == zmq.POLLIN:
            client_addr, msg_type, msg_info, req_id = frontend.recv_multipart()
            if msg_type == ServerCmd.new_job:
                job_info = client_addr + b'#' + req_id
                # register a new job
                pending_jobs[job_info].checksum = int(msg_info)
                logger.info('job register\tsize: %d\tjob id: %s' % (int(msg_info), job_info))
                if len(pending_jobs[job_info]._pending_embeds) > 0 \
                        and pending_jobs[job_info].final_ndarray is None:
                    pending_jobs[job_info].add_embed(None, 0)
            elif msg_type == ServerCmd.show_config:
                time.sleep(0.1)  # dirty fix of slow-joiner: sleep so that client receiver can connect
                logger.info('send config\tclient %s' % client_addr)
                sender.send_multipart([client_addr, msg_info, req_id])

        # check if there are finished jobs, then send them back to the clients
        finished = [(k, v) for k, v in pending_jobs.items() if v.is_done]
        for job_info, tmp in finished:
            client_addr, req_id = job_info.split(b'#')
            x_info = tmp.result
            sender.send_multipart([client_addr, x_info, b'', req_id])
            logger.info('send back\tsize: %d\tjob id: %s' % (tmp.checksum, job_info))
            # release the job
            tmp.clear()
            pending_jobs.pop(job_info)
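# SinkJob is defined elsewhere in the repo; this is a minimal, hypothetical
# sketch of just the interface the loop above relies on for the token path
# (add_token / is_done / result / clear and the progress counters). The real
# class also buffers embeddings (_pending_embeds, final_ndarray, add_embed),
# which is omitted here.
from zmq.utils import jsonapi

class SinkJobSketch:
    def __init__(self):
        self.checksum = 0          # expected number of items, set at job registration
        self.tokens = []           # collected (partial_id, payload) pairs
        self.progress_tokens = 0
        self.progress_embeds = 0

    def add_token(self, data, partial_id):
        self.tokens.append((partial_id, data))
        self.progress_tokens += len(data)

    @property
    def is_done(self):
        return self.checksum > 0 and self.progress_tokens >= self.checksum

    @property
    def result(self):
        # concatenate partial results in submission order and serialize for the client
        merged = [x for _, chunk in sorted(self.tokens, key=lambda t: t[0]) for x in chunk]
        return jsonapi.dumps(merged)

    def clear(self):
        self.tokens.clear()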