def gen_workload_list(selection):
    # type: (str) -> Iterable[Tuple[WTL, RunConfig]]
    """Select workloads based on the command line.

    :param selection: comma-separated canonical workload names of the form
        ``<name>_<batch_size>``. When empty (or None), every workload in
        ``WTL.known_workloads`` except a fixed blacklist is selected, paired
        with each batch size from its ``available_batch_sizes()``.
    :raises UsageError: if an entry of ``selection`` contains no ``_``.
    :returns: a lazy generator of ``(wtl, RunConfig(bs, bn, None))`` pairs,
        one per batch number returned by ``wtl.available_batch_nums(bs)``.
    """
    if not selection:
        # Workloads excluded from the default "run everything" selection.
        blacklist = frozenset({'speech', 'seq2seq', 'mnistlg', 'mnistsf', 'mnistcv'})
        # Filter on the workload key *before* enumerating batch sizes so that
        # available_batch_sizes() is never queried for excluded workloads.
        names = (
            (v, bs)
            for k, v in WTL.known_workloads.items()
            if k not in blacklist
            for bs in v.available_batch_sizes()
        )
    else:
        names = []
        # NOTE(review): unique(..., stable=True) presumably de-duplicates while
        # keeping first-seen order — confirm against its definition.
        for cname in unique(selection.split(','), stable=True):
            if '_' not in cname:
                raise UsageError(f"Not a canonical name: {cname}")
            # Split on the first '_' only: the remainder is the batch size.
            name, bs = cname.split('_', 1)
            # Presumably yields int(bs), falling back to the raw string when
            # int() raises ValueError — confirm against try_with_default.
            bs = try_with_default(int, bs, ValueError)(bs)
            names.append((WTL.from_name(name), bs))
    # Find all available batch_num with JCT and mem data
    return (
        (wtl, RunConfig(bs, bn, None))
        for wtl, bs in names
        for bn in wtl.available_batch_nums(bs)
    )
def find_geometry(w, field):
    """Look up geometry entry ``field`` for workload ``w``.

    If ``w.geometry[field]`` is already set (not None) it is returned directly.
    Otherwise, geometry recorded for the same batch size under each batch
    number from ``w.wtl.available_batch_nums(w.batch_size)`` is consulted in
    order. The first non-None value found is stored into ``w.geometry[field]``
    and returned. Returns None when no batch number has a value.

    :type w: Workload
    :type field: str
    """
    cached = w.geometry[field]
    if cached is not None:
        return cached

    # Fall back to geometry data recorded under another batch number.
    for batch_num in w.wtl.available_batch_nums(w.batch_size):
        run_cfg = RunConfig(w.batch_size, batch_num, None)
        geo = WTL.from_name(w.name).geometry(run_cfg, w.executor)
        value = geo[field]
        if value is None:
            continue
        w.geometry[field] = value
        return value

    return None