def _test_frequency_with_engine(device, workers, lower_bound_factor=0.8, every=1):
    """Check the ``Frequency`` metric against a known artificial throughput.

    An ``Engine`` is driven by an update step that sleeps for a fixed amount
    of time and reports the batch length as ``ntokens``. Every time the
    selected event fires, the ``"wps"`` metric must lie in the half-open
    interval ``(estimated * lower_bound_factor, estimated]``.

    Args:
        device: forwarded unchanged to ``Frequency(device=...)``.
        workers: divisor applied to the sleep time, the token budget and the
            batch size.
        lower_bound_factor: fraction of the estimated rate accepted as the
            (exclusive) lower bound.
        every: period passed to ``Events.ITERATION_COMPLETED(every=...)``;
            the metric is attached and asserted on that event.
    """
    sleep_seconds = 1.0 / workers
    token_budget = 400 // workers
    tokens_per_batch = 128 // workers

    # Rate expected when each step takes exactly ``sleep_seconds``.
    expected_rate = tokens_per_batch * workers / sleep_seconds
    min_rate = expected_rate * lower_bound_factor

    def _step(engine, batch):
        time.sleep(sleep_seconds)
        return {"ntokens": len(batch)}

    def _ntokens(output):
        return output["ntokens"]

    trainer = Engine(_step)
    frequency = Frequency(output_transform=_ntokens, device=device)
    trigger = Events.ITERATION_COMPLETED(every=every)
    frequency.attach(trainer, "wps", event_name=trigger)

    @trainer.on(trigger)
    def _check_rate(e):
        measured = e.state.metrics["wps"]
        assert min_rate < measured <= expected_rate, "{}: {} < {} < {}".format(
            e.state.iteration, min_rate, measured, expected_rate
        )

    # One batch per ``tokens_per_batch``-sized stride over the token budget.
    batches = []
    for start in range(0, token_budget, tokens_per_batch):
        batches.append([start] * tokens_per_batch)

    trainer.run(batches, max_epochs=1)
def _test_frequency_with_engine(workers=None, lower_bound_factor=0.8, every=1):
    """Check the ``Frequency`` metric against a known artificial throughput.

    An ``Engine`` is driven by an update step that sleeps for a fixed amount
    of time and reports the batch length as ``ntokens``. Every time the
    selected event fires, the ``"wps"`` metric must lie in the half-open
    interval ``(estimated * lower_bound_factor, estimated]``.

    Args:
        workers: divisor applied to the sleep time, the token budget and the
            batch size. Defaults to ``idist.get_world_size()`` when ``None``.
        lower_bound_factor: fraction of the estimated rate accepted as the
            (exclusive) lower bound.
        every: period passed to ``Events.ITERATION_COMPLETED(every=...)``;
            the metric is attached and asserted on that event.
    """
    if workers is None:
        workers = idist.get_world_size()

    artificial_time = 1.0 / workers  # seconds
    total_tokens = 400 // workers
    batch_size = 128 // workers

    # Rate expected when each step takes exactly ``artificial_time``.
    estimated_wps = batch_size * workers / artificial_time

    def update_fn(engine, batch):
        time.sleep(artificial_time)
        return {"ntokens": len(batch)}

    engine = Engine(update_fn)
    wps_metric = Frequency(output_transform=lambda x: x["ntokens"])
    event = Events.ITERATION_COMPLETED(every=every)
    wps_metric.attach(engine, "wps", event_name=event)

    @engine.on(event)
    def assert_wps(e):
        wps = e.state.metrics["wps"]
        # Skip iterations 2, 3, 4 if backend is Horovod on CUDA,
        # wps is abnormally low for these iterations
        # otherwise, other values of wps are OK
        if idist.model_name() == "horovod-dist" and e.state.iteration in (2, 3, 4):
            return
        assert estimated_wps * lower_bound_factor < wps <= estimated_wps, "{}: {} < {} < {}".format(
            e.state.iteration, estimated_wps * lower_bound_factor, wps, estimated_wps
        )

    data = [[i] * batch_size for i in range(0, total_tokens, batch_size)]
    # Horovod gets a second epoch so that iterations beyond the skipped
    # 2-4 window are still asserted; every other configuration needs one.
    max_epochs = 1 if idist.model_name() != "horovod-dist" else 2
    # Fix: pass the computed value instead of a hard-coded 2, which made
    # the conditional above dead code and doubled non-Horovod run time.
    engine.run(data, max_epochs=max_epochs)
def test_nondistributed_average():
    """Verify ``Frequency`` reports roughly tokens / elapsed seconds.

    The metric is reset, the test sleeps for a fixed time, a single update
    with a known token count is applied, and the computed value must fall
    strictly between 90% and 100% of the ideal rate.
    """
    sleep_seconds = 1
    token_count = 100

    # Ideal rate if exactly ``sleep_seconds`` elapsed between reset and compute.
    ceiling = token_count / sleep_seconds
    floor = ceiling * 0.9

    metric = Frequency()
    metric.reset()
    time.sleep(sleep_seconds)
    metric.update(token_count)

    assert floor < metric.compute() < ceiling
def _test_frequency_with_engine(device, workers):
    """Check the ``Frequency`` metric against a known artificial throughput.

    An ``Engine`` is driven by an update step that sleeps for a fixed amount
    of time and reports the batch length as ``ntokens``. After every
    completed iteration the ``'wps'`` metric must lie strictly between 85%
    and 100% of the estimated rate.

    Args:
        device: forwarded unchanged to ``Frequency(device=...)``.
        workers: divisor applied to the sleep time, the token budget and the
            batch size.
    """
    sleep_seconds = 0.1 / workers
    token_budget = 1200 // workers
    tokens_per_batch = 128 // workers

    # Rate expected when each step takes exactly ``sleep_seconds``.
    expected_rate = tokens_per_batch * workers / sleep_seconds
    min_rate = expected_rate * 0.85

    def _step(engine, batch):
        time.sleep(sleep_seconds)
        return {"ntokens": len(batch)}

    def _ntokens(output):
        return output["ntokens"]

    trainer = Engine(_step)
    frequency = Frequency(output_transform=_ntokens, device=device)
    frequency.attach(trainer, 'wps')

    @trainer.on(Events.ITERATION_COMPLETED)
    def _check_rate(e):
        measured = e.state.metrics['wps']
        assert min_rate < measured < expected_rate, "{}: {} < {} < {}".format(
            e.state.iteration, min_rate, measured, expected_rate
        )

    # One batch per ``tokens_per_batch``-sized stride over the token budget.
    batches = []
    for start in range(0, token_budget, tokens_per_batch):
        batches.append([start] * tokens_per_batch)

    trainer.run(batches, max_epochs=1)