def data_setup():
    """Sets up logging, random seeds and corpus"""
    # global variables
    # Set the random seed manually for reproducibility.
    random.seed(g.args.seed)
    np.random.seed(g.args.seed)
    torch.manual_seed(g.args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(g.args.seed)
        torch.cuda.set_device(g.args.local_rank)
    g.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ###########################################################################
    # Load data
    ###########################################################################
    g.corpus = get_lm_corpus(g.args.data, g.args.dataset, use_bpe=g.args.bpe)
    g.ntokens = len(g.corpus.vocab)

    g.va_iter, g.te_iter = [
        g.corpus.get_dist_iterator(split,
                                   bsz=g.args.batch_size * 2,
                                   bptt=g.args.tgt_len,
                                   rank=util.get_global_rank(),
                                   max_rank=util.get_world_size(),
                                   device=g.device,
                                   ext_len=g.args.ext_len)
        for split in ('valid', 'test')
    ]
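
# Not part of data_setup itself: a minimal sketch of how the per-rank iterators created
# above are consumed elsewhere in this codebase. Each rank sees only its own shard, and
# batches come back as (data, target, seq_len) tuples, as in the training loop below.
# The helper name is illustrative only and does not exist in this codebase.
def _count_validation_tokens():
    total = 0
    for data, target, seq_len in g.va_iter:
        total += data.shape[0] * data.shape[1]  # tokens in this rank's batch
    return total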
def test_optimize():
    global log
    recv_bytes, transmit_bytes = util.network_bytes()

    device = 'cuda'
    fp16 = True
    dim = 2 ** 12  # multiple of 8, about 67MB matrix in fp32

    model = SimpleNet(args.num_layers, dim)
    model = model.to(device)
    if fp16:
        model = model.half()
        bytes_per_number = 2
    else:
        bytes_per_number = 4

    gradient_size = args.num_layers * (dim * dim) * bytes_per_number
    size_mb = gradient_size / 1e6

    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=util.get_world_size())

    model = DistributedDataParallel(model,
                                    device_ids=[args.local_rank],
                                    output_device=args.local_rank)

    optimizer = optim.SGD(model.parameters(), lr=0.01)

    x = torch.eye(dim)
    x = x.to(device)
    if fp16:
        x = x.half()

    time_list = []
    start_time = time.perf_counter()
    start_time0 = start_time
    for i in range(1):
        optimizer.zero_grad()
        output = model(x)

        def sqr(a):
            return a * a

        loss = sqr(output - x).sum()
        loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        elapsed_time_sec = (time.perf_counter() - start_time)
        start_time = time.perf_counter()
        elapsed_time_ms = elapsed_time_sec * 1000
        time_list.append(elapsed_time_ms)

        rate = size_mb / elapsed_time_sec
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(backend=args.dist_backend,
                                # init_method=args.dist_url,
                                # world_size=util.get_world_size()
                                )
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})")

    g.logger.info("creating new model")
    g.state = TrainState(args)

    g.state.model = MemTransformerLM(g.ntokens, args.n_layer, args.n_head,
                                     args.d_model, args.d_head, args.d_inner,
                                     args.dropout, args.dropatt,
                                     tie_weight=args.tied,
                                     d_embed=args.d_embed,
                                     div_val=args.div_val,
                                     tie_projs=g.tie_projs,
                                     pre_lnorm=args.pre_lnorm,
                                     tgt_len=args.tgt_len,
                                     ext_len=args.ext_len,
                                     mem_len=args.mem_len,
                                     cutoffs=g.cutoffs,
                                     same_length=args.same_length,
                                     attn_type=args.attn_type,
                                     clamp_len=args.clamp_len,
                                     sample_softmax=args.sample_softmax,
                                     freeze_below=args.freeze_below)
    g.state.model.to(g.device)
    optimizer_setup(g.state)

    if args.checkpoint:
        if args.checkpoint_secondary:
            g.logger.info(f"restoring extra checkpoint")
            util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                         args.checkpoint_secondary,
                                         args.optim_state_dict)
        g.logger.info(f"Restoring model from {args.checkpoint}" +
                      (f" and optimizer from {args.optim_state_dict}"
                       if args.optim_state_dict else ""))
        util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                     args.checkpoint, args.optim_state_dict)
    else:
        g.state.model.apply(weights_init)
        # ensure embedding init is not overridden by out_layer in case of weight sharing
        g.state.model.word_emb.apply(weights_init)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    if g.state.args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=g.state.args.static_loss_scale,
            dynamic_loss_scale=g.state.args.dynamic_loss_scale,
            dynamic_loss_args={'init_scale': 2 ** 16},
            verbose=False)

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        g.state.scheduler: LRFinder = LRFinder(optimizer,
                                               args.max_tokens,
                                               init_value=args.lr / 1e3)
    else:
        assert args.scheduler == 'constant'
        g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)
        g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
            model.train()

            log_tb('sizes/batch_size', args.batch_size)
            log_tb('sizes/seq_size', args.tgt_len)

            if g.state.partial_epoch:
                # reuse previously loaded tr_iter and states
                assert g.state.tr_iter is not None
                assert g.state.mems is not None
            else:
                g.state.tr_iter = g.corpus.get_dist_iterator(
                    'train',
                    rank=util.get_global_rank(),
                    max_rank=util.get_world_size(),
                    bsz=args.batch_size,
                    bptt=args.tgt_len,
                    device=g.device,
                    ext_len=args.ext_len,
                    skip_files=g.args.skip_files)
                g.state.mems = tuple()
            g.state.last_epoch = epoch

            log_start_time = time.time()
            tokens_per_epoch = 0
            for batch, (data, target, seq_len) in enumerate(g.state.tr_iter):
                # assert seq_len == data.shape[0]
                # for i in range(1, data.shape[0]):
                #     assert torch.all(torch.eq(data[i], target[i - 1]))
                #     break
                # print(g.state.token_count, data)

                if g.state.train_step % args.eval_interval == 0:
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-1',
                                     generate_text=False, reset_mems_interval=1)
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-2',
                                     generate_text=False, reset_mems_interval=2)
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-3',
                                     generate_text=False, reset_mems_interval=3)
                    evaluate_and_log(model, g.va_iter, 'val')
                    if g.va_custom_iter:
                        evaluate_and_log(g.state.model, g.va_custom_iter,
                                         g.args.valid_custom, generate_text=False)

                batch_total = torch.tensor(data.shape[1]).to(g.device)
                if args.local:  # TODO(y): factor out (need way to see if dist was inited)
                    batch_total = batch_total.sum()
                else:
                    batch_total = util.dist_sum_tensor(batch_total)  # global batch size
                batch_total = util.toscalar(batch_total)

                should_log = (g.state.train_step < args.verbose_log_steps) or \
                             (g.state.train_step + 1) % args.log_interval == 0

                model.zero_grad()

                ret = model(data, target, *g.state.mems)
                loss, g.state.mems = ret[0], ret[1:]

                loss: torch.Tensor = loss.float().mean().type_as(loss)
                with timeit('backwards', noop=not should_log):
                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()
                loss0 = util.toscalar(loss)
                util.record('loss', loss0)

                util.record('params', torch.sum(util.flat_param(model)).item())
                losses.append(loss0)
                accumulated_loss += loss0

                if args.fp16:
                    optimizer.clip_master_grads(args.clip)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

                # step-wise learning rate annealing
                if hasattr(optimizer, 'overflow') and optimizer.overflow:
                    g.logger.info("skipped iteration")
                else:
                    if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                        # linear warmup stage
                        if g.state.token_count < args.warmup_tokens:
                            curr_lr = args.lr * float(g.state.token_count) / args.warmup_tokens
                            optimizer.param_groups[0]['lr'] = curr_lr
                        elif args.scheduler == 'cosine':
                            # Divide by 1e6 for numerical stability.
                            g.state.scheduler.step(g.state.token_count // 1000 // 1000)
                    else:
                        g.state.scheduler.step(g.state.token_count)

                optimizer.step()
                g.state.train_step += 1

                consumed_tokens = data.shape[0] * data.shape[1]
                world_size = int(os.environ.get("WORLD_SIZE", "8"))
                if world_size > 8:  # correction factor for multiple machines
                    consumed_tokens = consumed_tokens * (world_size // 8)

                tokens_per_epoch += consumed_tokens
                g.state.token_count += consumed_tokens
                g.token_count = g.state.token_count
                if g.state.token_count >= args.max_tokens:
                    g.state.partial_epoch = True
                    raise StopIteration  # break out of parent train loop

                if should_log:
                    elapsed_time = time.time() - log_start_time
                    elapsed_steps = g.state.train_step - g.state.last_log_step

                    # compute average loss over last logging interval
                    cur_loss = accumulated_loss / elapsed_steps
                    cur_loss_mean = util.dist_mean(cur_loss)

                    log_str = f'| epoch {epoch:3d} step {g.state.train_step:>8d} ' \
                              f'| {batch:>6d} batches ' \
                              f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                              f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                              f'| loss {cur_loss:5.2f}'
                    if args.dataset in ['enwik8', 'text8']:
                        log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
                    else:
                        log_str += f' | ppl {math.exp(cur_loss):9.3f}'
                    g.logger.info(log_str)

                    log_tb('learning/epoch', epoch)
                    log_tb('_loss', cur_loss_mean)  # the most important thing
                    log_tb('learning/loss', cur_loss_mean)
                    log_tb('learning/ppl', math.exp(cur_loss_mean))

                    # currently step timings are not synchronized in multi-machine
                    # case (see #4). Can add torch.distributed.barrier() to get
                    # more accurate timings, but this may add slowness.
                    log_tb('times/step', 1000 * elapsed_time / elapsed_steps)

                    current_lr = optimizer.param_groups[0]['lr']
                    log_tb('learning/lr', current_lr)

                    # 32 is the "canonical" batch size
                    linear_scaling_factor = batch_total / 32  # TODO(y): merge logic from master
                    log_tb('learning/base_lr', current_lr / linear_scaling_factor)
                    if args.optim == 'lamb':
                        log_lamb_rs(optimizer, g.event_writer, g.state.token_count)

                    time_per_batch = elapsed_time / elapsed_steps
                    time_per_sample = time_per_batch / args.batch_size
                    time_per_token = time_per_sample / args.tgt_len
                    log_tb('times/batches_per_sec', 1 / time_per_batch)
                    log_tb('times/samples_per_sec', 1 / time_per_sample)
                    log_tb('times/tokens_per_sec', 1 / time_per_token)

                    if str(g.device) == 'cuda':
                        log_tb("memory/allocated_gb", torch.cuda.memory_allocated() / 1e9)
                        log_tb("memory/max_allocated_gb", torch.cuda.max_memory_allocated() / 1e9)
                        log_tb("memory/cached_gb", torch.cuda.memory_cached() / 1e9)
                        log_tb("memory/max_cached_gb", torch.cuda.max_memory_cached() / 1e9)

                    accumulated_loss = 0
                    log_start_time = time.time()
                    g.state.last_log_step = g.state.train_step

            if args.checkpoint_each_epoch:
                g.logger.info(f'Saving checkpoint for epoch {epoch}')
                util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{epoch}')
            if tokens_per_epoch == 0:
                logging.info("Zero tokens in last epoch, breaking")
                break

            g.state.partial_epoch = False

    except KeyboardInterrupt:
        g.logger.info('-' * 100)
        g.logger.info('Exiting from training early')
    except StopIteration:
        pass

    return losses
def main():
    global global_token_count, event_writer, train_step, train_loss, last_log_step, \
        best_val_loss, epoch, model

    if args.local_rank > 0:
        pass  # skip shutdown when rank is explicitly set + not zero rank
    else:
        os.system('shutdown -c')

    if not args.local:
        logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        logger.info(f"Distributed: success ({args.local_rank}/{dist.get_world_size()})")

    model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
                             args.d_head, args.d_inner, args.dropout,
                             args.dropatt, tie_weight=args.tied,
                             d_embed=args.d_embed, div_val=args.div_val,
                             tie_projs=tie_projs, pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len, ext_len=args.ext_len,
                             mem_len=args.mem_len, cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    # log model info
    n_all_param = sum([p.nelement() for p in model.parameters()])
    log_tb('sizes/params', n_all_param)
    n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    log_tb('sizes/non_emb_params', n_nonemb_param)
    logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # optimizer
    if args.optim.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.mom)
    elif args.optim.lower() == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd)
    else:
        assert args.optim.lower() == 'adam'
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         args.max_tokens // 1e6,
                                                         eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        scheduler = LRFinder(optimizer, args.max_tokens, init_value=args.lr / 1e3)
    elif args.scheduler == 'constant':
        scheduler = util.NoOp()

    model.apply(weights_init)
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(weights_init)

    if args.checkpoint:
        if global_rank == 0:
            util.restore_from_checkpoint(model=model, checkpoint_fn=args.checkpoint)

    model = model.to(device)
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2 ** 16},
                                   verbose=False)

    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if global_rank == 0:
        event_writer = SummaryWriter(args.logdir)
        event_writer.add_text('args', str(args))

    # test checkpoint writing
    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{0}')

    # Loop over epochs.
    train_step = 0
    train_loss = 0
    last_log_step = 0
    best_val_loss = None

    va_iter, te_iter = [
        corpus.get_dist_iterator(split, global_rank, max_rank,
                                 args.batch_size * 2, args.tgt_len,
                                 device=device, ext_len=args.ext_len)
        for split in ('valid', 'test')
    ]

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=1):
            train(va_iter, optimizer, scheduler)
    except KeyboardInterrupt:
        logger.info('-' * 100)
        logger.info('Exiting from training early')
    except StopIteration:
        pass

    # Eval one more time.
    evaluate_and_log(optimizer, va_iter, 'val', train_step=-1)

    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.logdir, 'model-best.pt')
    if os.path.exists(model_file):
        with open(model_file, 'rb') as model_f:
            with timeit('load'):
                if args.local:
                    model = torch.load(model_f)
                else:
                    model = torch.load(
                        model_f,
                        map_location=lambda storage, loc: storage.cuda(args.local_rank))
                    model = DistributedDataParallel(
                        model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)
    else:
        logger.warn('no model file, using current model for loss')

    # Run on test data.
    evaluate_and_log(optimizer, te_iter, 'test', -1)
                    type=int,
                    help='how long to wait before shutting down on error')

args = parser.parse_args()
args.tied = not args.not_tied

# global variables
global_timeit_dict = OrderedDict()
global_token_count = 0
event_writer = util.NoOp()
epoch = 0
train_step = 0

local_rank = args.local_rank
global_rank = util.get_global_rank()
max_rank = util.get_world_size()


class FileLogger:
    def __init__(self, output_dir: str, global_rank: int, local_rank: int):
        self.output_dir = output_dir
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)
        self.logger = FileLogger.get_logger(output_dir,
                                            global_rank=global_rank,
                                            local_rank=local_rank)

    def exception(self, *args_, **kwargs):
        return self.logger.exception(*args_, **kwargs)

    @staticmethod
def test_optimize():
    global log
    recv_bytes, transmit_bytes = util.network_bytes()

    device = 'cuda'
    dim = 2 ** 12  # multiple of 8, about 67MB matrix in fp32

    model = SimpleNet(args.num_layers, dim)
    model = model.to(device)
    if fp16:
        model = model.half()
        bytes_per_number = 2
    else:
        bytes_per_number = 4

    gradient_size = args.num_layers * (dim * dim) * bytes_per_number
    size_mb = gradient_size / 1e6

    log('initializing process group')
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=util.get_world_size())

    log('calling DDP')
    model = DistributedDataParallel(model,
                                    device_ids=[args.local_rank],
                                    output_device=args.local_rank)

    optimizer = optim.SGD(model.parameters(), lr=0.01)

    x = torch.eye(dim)
    x = x.to(device)
    if fp16:
        x = x.half()

    time_list = []

    # force initialization of NCCL
    dist.all_reduce(torch.ones(()).cuda())
    dist.barrier()

    log("Start timing")
    start_time = time.perf_counter()
    start_time0 = start_time
    for i in range(args.iters):
        optimizer.zero_grad()
        output = model(x)

        def sqr(a):
            return a * a

        loss = sqr(output - x).sum()
        loss.backward()
        optimizer.step()

        torch.cuda.synchronize()
        elapsed_time_sec = (time.perf_counter() - start_time)
        start_time = time.perf_counter()
        elapsed_time_ms = elapsed_time_sec * 1000
        time_list.append(elapsed_time_ms)

        rate = size_mb / elapsed_time_sec
        log('%03d/%d added %d MBs in %.1f ms: %.2f MB/second %.1f' % (
            i, args.iters, size_mb, elapsed_time_ms, rate, loss))

    del time_list[0]  # first measurement is off because of syncing
    min_time = np.min(time_list)
    median = np.median(time_list)
    log(f"min: {min_time:8.2f}, median: {median:8.2f}, mean: {np.mean(time_list):8.2f}")

    dist.barrier()
    elapsed_time = time.perf_counter() - start_time0
    recv_bytes1, transmit_bytes1 = util.network_bytes()
    log(f"Received {(recv_bytes1 - recv_bytes) / 1e9:.1f}, "
        f"transmitted {(transmit_bytes1 - transmit_bytes) / 1e9:.1f} "
        f"in {elapsed_time:.1f} seconds")
    log(f"predicted {gradient_size * args.iters / 1e9:.1f}")
    log(f"average observed bw: {(recv_bytes1 - recv_bytes) * 8 / elapsed_time / 1e9:.1f} Gbps")

    time_to_sync_buffer_sec = np.mean(time_list) / 1000
    effective_bw_gbps = gradient_size / time_to_sync_buffer_sec * 8 / 1e9
    log(f"average effective bw: {effective_bw_gbps} Gbps")
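
# Not part of the benchmark above: a minimal, self-contained sketch of the bandwidth
# arithmetic used at the end of test_optimize, handy for sanity-checking numbers offline.
# The function name and arguments are illustrative only and do not exist elsewhere in
# this codebase.
def effective_bandwidth_gbps(buffer_bytes, iteration_times_ms):
    """Convert per-iteration sync times (ms) into effective bandwidth in Gbps."""
    mean_time_sec = sum(iteration_times_ms) / len(iteration_times_ms) / 1000
    return buffer_bytes / mean_time_sec * 8 / 1e9

# Example: a 67 MB fp32 gradient buffer (dim=4096, one layer) synced in ~50 ms is
# roughly 10.7 Gbps of effective bandwidth.
# print(effective_bandwidth_gbps(4096 * 4096 * 4, [50.0]))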
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(f'Distributed initializing process group with '
                      f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(f"Distributed: success ({args.local_rank}/{dist.get_world_size()})")

    if args.load_state_fn:
        g.state = load_state(args.load_state_fn)
        g.logger.info(f"Restoring training from {args.load_state_fn}")
    else:
        g.logger.info("creating new model")
        g.state = TrainState(args)

        g.state.model = MemTransformerLM(g.ntokens, args.n_layer, args.n_head,
                                         args.d_model, args.d_head, args.d_inner,
                                         args.dropout, args.dropatt,
                                         tie_weight=args.tied,
                                         d_embed=args.d_embed,
                                         div_val=args.div_val,
                                         tie_projs=g.tie_projs,
                                         pre_lnorm=args.pre_lnorm,
                                         tgt_len=args.tgt_len,
                                         ext_len=args.ext_len,
                                         mem_len=args.mem_len,
                                         cutoffs=g.cutoffs,
                                         same_length=args.same_length,
                                         attn_type=args.attn_type,
                                         clamp_len=args.clamp_len,
                                         sample_softmax=args.sample_softmax)
        if args.checkpoint:
            util.restore_from_checkpoint(g.state.model, checkpoint_fn=args.checkpoint)
        else:
            g.state.model.apply(weights_init)
            # ensure embedding init is not overridden by out_layer in case of weight sharing
            g.state.model.word_emb.apply(weights_init)

        g.state.model.to(g.device)
        optimizer_setup(g.state)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if not g.args.load_state_fn:
        if args.scheduler == 'cosine':
            # Divide by 1e6 for numerical stability.
            g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
        elif args.scheduler == 'finder':
            g.state.scheduler: LRFinder = LRFinder(optimizer,
                                                   args.max_tokens,
                                                   init_value=args.lr / 1e3)
        else:
            assert args.scheduler == 'constant'
            g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)
        g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
def test_allreduce():
    global log
    recv_bytes, transmit_bytes = util.network_bytes()

    device = 'cuda'
    dim = 2 ** 12  # multiple of 8, about 67MB matrix in fp32

    if fp16:
        bytes_per_number = 2
    else:
        bytes_per_number = 4

    gradient_size = args.num_layers * (dim * dim) * bytes_per_number
    size_mb = gradient_size / 1e6

    log('initializing process group')
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=util.get_world_size())

    xs = [torch.ones((dim, dim)) for i in range(args.num_layers)]
    xs = [x.to(device) for x in xs]
    if fp16:
        xs = [x.half() for x in xs]

    time_list = []

    # force initialization of NCCL
    dist.all_reduce(torch.ones(()).cuda())
    dist.barrier()

    log("Start timing")
    start_time = time.perf_counter()
    start_time0 = start_time
    for i in range(args.iters):
        [dist.all_reduce(x, async_op=True) for x in xs]
        torch.cuda.synchronize()

        elapsed_time_sec = (time.perf_counter() - start_time)
        start_time = time.perf_counter()
        elapsed_time_ms = elapsed_time_sec * 1000
        time_list.append(elapsed_time_ms)

        rate = size_mb / elapsed_time_sec
        # could do barrier, but didn't have effect on timing
        # dist.barrier()
        new_result = xs[0]
        log('%03d/%d added %d MBs in %.1f ms: %.2f MB/second %.1f' % (
            i, args.iters, size_mb, elapsed_time_ms, rate, new_result[0, 0]))

    del time_list[0]  # first measurement is off because of syncing
    min_time = np.min(time_list)
    median = np.median(time_list)
    log(f"min: {min_time:8.2f}, median: {median:8.2f}, mean: {np.mean(time_list):8.2f}")

    dist.barrier()
    elapsed_time = time.perf_counter() - start_time0
    recv_bytes1, transmit_bytes1 = util.network_bytes()
    log(f"Received {(recv_bytes1 - recv_bytes) / 1e9:.1f}, "
        f"transmitted {(transmit_bytes1 - transmit_bytes) / 1e9:.1f} "
        f"in {elapsed_time:.1f} seconds")
    log(f"predicted {gradient_size * args.iters / 1e9:.1f}")
    log(f"average bw: {(recv_bytes1 - recv_bytes) * 8 / elapsed_time / 1e9:.1f} Gbps")
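
# Not part of the benchmark above: a small sketch, under the usual ring all-reduce cost
# model, relating buffer size to expected per-worker wire traffic. It can be used to
# compare the "predicted" number printed by test_allreduce with the measured network
# bytes. The helper is illustrative only and does not exist elsewhere in this codebase.
def ring_allreduce_wire_bytes(buffer_bytes, world_size):
    """Approximate bytes each worker sends (and receives) for one ring all-reduce."""
    return 2.0 * (world_size - 1) / world_size * buffer_bytes

# Example: a 67 MB buffer reduced across 16 workers moves roughly 126 MB per worker
# per all-reduce call.
# print(ring_allreduce_wire_bytes(4096 * 4096 * 4, 16) / 1e6)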