def ReplaceDeviceTag(parallel_desc_symbol, device_tag, builder=None):
    """Build a parallel-desc symbol that copies a placement but swaps its device tag.

    A fresh ``placement_pb.ParallelConf`` is filled with ``device_tag`` and
    every ``device_name`` entry of ``parallel_desc_symbol.parallel_conf``.

    Args:
        parallel_desc_symbol: Source symbol. Its ``device_tag`` must differ
            from ``device_tag``.
        device_tag: Device tag written into the new ParallelConf.
        builder: Optional object exposing ``GetParallelDescSymbol``. When it
            is given, the new conf is resolved through it.

    Returns:
        ``builder.GetParallelDescSymbol(new_conf)`` when ``builder`` is given.
        Otherwise a ``symbol_util.ParallelDescSymbol`` built directly from
        the new conf.

    Raises:
        AssertionError: If the source symbol already carries ``device_tag``.
            This check is skipped under ``python -O``.
    """
    assert parallel_desc_symbol.device_tag != device_tag

    new_conf = placement_pb.ParallelConf()
    new_conf.device_tag = device_tag
    # Copy the device placement verbatim; only the tag changes.
    new_conf.device_name.extend(parallel_desc_symbol.parallel_conf.device_name)

    if builder is not None:
        return builder.GetParallelDescSymbol(new_conf)

    # NOTE(review): without a builder, the returned symbol reuses the source
    # symbol's id even though it describes a different conf. Confirm this
    # is intended.
    return symbol_util.ParallelDescSymbol(
        parallel_desc_symbol.symbol_id, new_conf, device_tag
    )
def GetParallelDescSymbol(self, parallel_conf):
    """Return the ParallelDescSymbol registered for ``parallel_conf``, creating it on first use.

    Symbols are cached in ``symbol_storage`` keyed by the serialized bytes of
    ``parallel_conf``. On a cache miss a new symbol id is obtained from
    ``self._NewSymbolId4ParallelConf``. The new symbol is then registered
    both under that id and under the serialized conf.

    Args:
        parallel_conf: A ParallelConf message (must support
            ``SerializeToString`` and expose ``device_tag``).

    Returns:
        The cached or newly created ``symbol_util.ParallelDescSymbol``.
    """
    tag = parallel_conf.device_tag
    conf_key = parallel_conf.SerializeToString()

    # Fast path: this exact conf has already been interned.
    if symbol_storage.HasSymbol4SerializedParallelConf(conf_key):
        return symbol_storage.GetSymbol4SerializedParallelConf(conf_key)

    # Presumably allocates (and possibly announces) a fresh symbol id for
    # this conf. TODO confirm against _NewSymbolId4ParallelConf.
    new_id = self._NewSymbolId4ParallelConf(parallel_conf)
    new_symbol = symbol_util.ParallelDescSymbol(new_id, parallel_conf, tag)

    # Register under both lookup keys so later calls hit the fast path.
    symbol_storage.SetSymbol4Id(new_id, new_symbol)
    symbol_storage.SetSymbol4SerializedParallelConf(conf_key, new_symbol)
    return new_symbol
def RandomParallelIdPerMachine(parallel_desc_symbol, device_tag=None, builder=None):
    """Build a parallel-desc symbol that keeps one randomly chosen device per machine.

    For every ``(machine_id, dev_ids)`` pair in
    ``parallel_desc_symbol.machine_id2device_id_list``, one entry of
    ``dev_ids`` is picked with ``random.randint``. It is recorded as the
    device name ``"<machine_id>:<dev_id>"``.

    Args:
        parallel_desc_symbol: Source symbol providing the machine -> devices
            mapping and, by default, the device tag.
        device_tag: Device tag for the new conf. When None, it falls back to
            ``parallel_desc_symbol.parallel_conf.device_tag``.
        builder: Optional object exposing ``GetParallelDescSymbol``. When it
            is given, the new conf is resolved through it.

    Returns:
        ``builder.GetParallelDescSymbol(new_conf)`` when ``builder`` is given.
        Otherwise a ``symbol_util.ParallelDescSymbol`` that reuses the source
        symbol's id.

    Raises:
        ValueError: From ``random.randint`` if some machine has an empty
            device list.
    """
    if device_tag is None:
        device_tag = parallel_desc_symbol.parallel_conf.device_tag
    # NOTE(review): if parallel_conf is a protobuf message, device_tag is a
    # str (possibly "") and never None, so this check likely cannot fire.
    # Confirm whether an empty-tag check was intended.
    assert device_tag is not None

    new_conf = placement_pb.ParallelConf()
    new_conf.device_tag = device_tag

    mapping = parallel_desc_symbol.machine_id2device_id_list
    chosen = [
        "%s:%s" % (machine_id, dev_ids[random.randint(0, len(dev_ids) - 1)])
        for machine_id, dev_ids in mapping.items()
    ]
    new_conf.device_name.extend(chosen)

    if builder is not None:
        return builder.GetParallelDescSymbol(new_conf)
    return symbol_util.ParallelDescSymbol(
        parallel_desc_symbol.symbol_id, new_conf, device_tag
    )