def test_horovod_allreduce_cpu_gpu_error(self):
    """Test that the allreduce raises an error if different ranks try to
    perform reduction on CPU and GPU."""
    # Only do this test if there are GPUs available.
    if not torch.cuda.is_available():
        return

    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # This test does not apply if there is only one worker.
    if size == 1:
        return

    # Same dimensions, but the tensor lives on the GPU on even ranks
    # and on the CPU on odd ranks.
    dims = [17] * 3
    if rank % 2 == 0:
        tensor = torch.cuda.FloatTensor(*dims)
    else:
        tensor = torch.FloatTensor(*dims)

    try:
        hvd.allreduce(tensor)
        assert False, 'hvd.allreduce did not throw error'
    except torch.FatalError:
        pass
def test_horovod_allreduce_error(self):
    """Test that the allreduce raises an error if different ranks try to
    send tensors of different rank or dimension."""
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # This test does not apply if there is only one worker.
    if size == 1:
        return

    # Same rank, different dimension
    torch.manual_seed(1234)
    dims = [17 + rank] * 3
    tensor = torch.FloatTensor(*dims).random_(-100, 100)
    try:
        hvd.allreduce(tensor)
        assert False, 'hvd.allreduce did not throw error'
    except torch.FatalError:
        pass

    # Same number of elements, different rank
    torch.manual_seed(1234)
    if rank == 0:
        dims = [17, 23 * 57]
    else:
        dims = [17, 23, 57]
    tensor = torch.FloatTensor(*dims).random_(-100, 100)
    try:
        hvd.allreduce(tensor)
        assert False, 'hvd.allreduce did not throw error'
    except torch.FatalError:
        pass
def test_horovod_allreduce_average(self):
    """Test that the allreduce correctly averages 1D, 2D, 3D tensors."""
    hvd.init()
    size = hvd.size()
    dtypes = [torch.IntTensor, torch.LongTensor,
              torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor,
                   torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
    dims = [1, 2, 3]
    for dtype, dim in itertools.product(dtypes, dims):
        torch.manual_seed(1234)
        tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
        tensor = tensor.type(dtype)
        averaged = hvd.allreduce(tensor, average=True)
        max_difference = averaged.data.sub(tensor).max()

        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in [torch.IntTensor, torch.LongTensor,
                                  torch.cuda.IntTensor, torch.cuda.LongTensor]:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            break

        assert max_difference <= threshold, 'hvd.allreduce produces incorrect results'
def test_horovod_allreduce_grad(self):
    """Test the correctness of the allreduce gradient."""
    hvd.init()
    size = hvd.size()
    dtypes = [torch.IntTensor, torch.LongTensor,
              torch.FloatTensor, torch.DoubleTensor]
    if torch.cuda.is_available():
        dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor,
                   torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
    dims = [1, 2, 3]
    for dtype, dim in itertools.product(dtypes, dims):
        torch.manual_seed(1234)
        tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
        tensor = tensor.type(dtype)
        tensor = torch.autograd.Variable(tensor, requires_grad=True)
        summed = hvd.allreduce(tensor, average=False)

        summed.backward(torch.ones([17] * dim))
        grad_out = tensor.grad.data.numpy()

        expected = np.ones([17] * dim) * size
        err = np.linalg.norm(expected - grad_out)
        self.assertLess(err, 0.00000001,
                        "gradient %s differs from expected %s, "
                        "error: %s" % (grad_out, expected, str(err)))
def test_horovod_allreduce_type_error(self):
    """Test that the allreduce raises an error if different ranks try to
    send tensors of different type."""
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    # This test does not apply if there is only one worker.
    if size == 1:
        return

    # Same dimensions, but the tensor is an IntTensor on even ranks
    # and a FloatTensor on odd ranks.
    dims = [17] * 3
    if rank % 2 == 0:
        tensor = torch.IntTensor(*dims)
    else:
        tensor = torch.FloatTensor(*dims)

    try:
        hvd.allreduce(tensor)
        assert False, 'hvd.allreduce did not throw error'
    except torch.FatalError:
        pass
def train_once(
        sess,
        step,
        ops,
        names=None,
        gen_feed_dict_fn=None,
        deal_results_fn=None,
        interval_steps=100,
        eval_ops=None,
        eval_names=None,
        gen_eval_feed_dict_fn=None,
        deal_eval_results_fn=melt.print_results,
        valid_interval_steps=100,
        print_time=True,
        print_avg_loss=True,
        model_dir=None,
        log_dir=None,
        is_start=False,
        num_steps_per_epoch=None,
        metric_eval_fn=None,
        metric_eval_interval_steps=0,
        summary_excls=None,
        fixed_step=None,  # for epoch only, in case you change batch size
        eval_loops=1,
        learning_rate=None,
        learning_rate_patience=None,
        learning_rate_decay_factor=None,
        num_epochs=None,
        model_path=None,
        use_horovod=False,
):
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ
    if use_horovod:
        if FLAGS.torch:
            import horovod.torch as hvd
        else:
            import horovod.tensorflow as hvd

    #is_start = False  # force not to evaluate at first step
    #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
    timer = gezi.Timer()
    if print_time:
        if not hasattr(train_once, 'timer'):
            train_once.timer = Timer()
            train_once.eval_timer = Timer()
            train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = (fixed_step or step) / num_steps_per_epoch if num_steps_per_epoch else -1
    if not num_epochs:
        epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
    else:
        epoch_str = 'epoch:%.3f/%d' % (epoch, num_epochs) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.2f' % (epoch))

    info = IO()
    stop = False

    if eval_names is None:
        if names:
            eval_names = ['eval/' + x for x in names]

    if names:
        names = ['train/' + x for x in names]

    if eval_names:
        eval_names = ['eval/' + x for x in eval_names]

    is_eval_step = is_start or valid_interval_steps and step % valid_interval_steps == 0
    summary_str = []

    eval_str = ''
    if is_eval_step:
        # deal with summary
        if log_dir:
            if not hasattr(train_once, 'summary_op'):
                #melt.print_summary_ops()
                if summary_excls is None:
                    train_once.summary_op = tf.summary.merge_all()
                else:
                    summary_ops = []
                    for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                        for summary_excl in summary_excls:
                            if not summary_excl in op.name:
                                summary_ops.append(op)
                    print('filtered summary_ops:')
                    for op in summary_ops:
                        print(op)
                    train_once.summary_op = tf.summary.merge(summary_ops)

                #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

        # if eval_ops are passed in, this should be rank 0
        if eval_ops:
            #if deal_eval_results_fn is None and eval_names is not None:
            #    deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
            for i in range(eval_loops):
                eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn()
                #eval_feed_dict.update(feed_dict)

                # if using horovod, let each rank issue the same sess.run!
                if not log_dir or train_once.summary_op is None \
                        or gezi.env_has('EVAL_NO_SUMMARY') or use_horovod:
                    #if not log_dir or train_once.summary_op is None:
                    eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
                else:
                    eval_results = sess.run(eval_ops + [train_once.summary_op],
                                            feed_dict=eval_feed_dict)
                    summary_str = eval_results[-1]
                    eval_results = eval_results[:-1]
                eval_loss = gezi.get_singles(eval_results)
                #timer_.print()
                eval_stop = False
                if use_horovod:
                    # dummy allreduce acts as a barrier so all ranks stay in sync
                    sess.run(hvd.allreduce(tf.constant(0)))

                #if not use_horovod or hvd.local_rank() == 0:
                # @TODO user print should also use logging as a must?
                #print(gezi.now_time(), epoch_str, 'eval_step: %d' % step, 'eval_metrics:', end='')
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                #if not use_horovod or hvd.rank() == 0:
                #    logging.info2('{} eval_step:{} eval_metrics:{}'.format(
                #        epoch_str, step, melt.parse_results(eval_loss, eval_names_)))
                eval_str = 'valid:{}'.format(melt.parse_results(eval_loss, eval_names_))

                # if deal_eval_results_fn is not None:
                #     eval_stop = deal_eval_results_fn(eval_results)

                assert len(eval_loss) > 0
                if eval_stop is True:
                    stop = True
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                if not use_horovod or hvd.rank() == 0:
                    melt.set_global('eval_loss', melt.parse_results(eval_loss, eval_names_))
    elif interval_steps != valid_interval_steps:
        #print()
        pass

    metric_evaluate = False

    # if metric_eval_fn is not None \
    #         and (is_start \
    #              or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
    #              or (metric_eval_interval_steps \
    #                  and step % metric_eval_interval_steps == 0)):
    #     metric_evaluate = True

    if metric_eval_fn is not None \
            and ((is_start or metric_eval_interval_steps
                  and step % metric_eval_interval_steps == 0) or model_path):
        metric_evaluate = True

    if 'EVFIRST' in os.environ:
        if os.environ['EVFIRST'] == '0':
            if is_start:
                metric_evaluate = False
        else:
            if is_start:
                metric_evaluate = True

    if step == 0 or 'QUICK' in os.environ:
        metric_evaluate = False

    #print('------------1step', step, 'pre metric_evaluate', metric_evaluate, hvd.rank())
    if metric_evaluate:
        # if use_horovod:
        #     print('------------metric evaluate step', step, model_path, hvd.rank())
        if not model_path or 'model_path' not in inspect.getargspec(metric_eval_fn).args:
            metric_eval_fn_ = metric_eval_fn
        else:
            metric_eval_fn_ = lambda: metric_eval_fn(model_path=model_path)

        try:
            l = metric_eval_fn_()
            if isinstance(l, tuple):
                num_returns = len(l)
                if num_returns == 2:
                    evaluate_results, evaluate_names = l
                    evaluate_summaries = None
                else:
                    assert num_returns == 3, 'metric_eval_fn should return 2 or 3 values'
                    evaluate_results, evaluate_names, evaluate_summaries = l
            else:  # returned a dict of name -> value
                evaluate_names, evaluate_results = tuple(zip(*l.items()))
                evaluate_summaries = None
        except Exception:
            logging.info('Do nothing for metric eval fn with exception:\n',
                         traceback.format_exc())

        if not use_horovod or hvd.rank() == 0:
            #logging.info2('{} valid_step:{} {}:{}'.format(
            #    epoch_str, step,
            #    'valid_metrics' if model_path is None else 'epoch_valid_metrics',
            #    melt.parse_results(evaluate_results, evaluate_names)))
            logging.info2('{} valid_step:{} {}:{}'.format(
                epoch_str, step, 'valid_metrics',
                melt.parse_results(evaluate_results, evaluate_names)))

        if learning_rate is not None and (learning_rate_patience and learning_rate_patience > 0):
            assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
            valid_loss = evaluate_results[0]
            if not hasattr(train_once, 'min_valid_loss'):
                train_once.min_valid_loss = valid_loss
                train_once.deacy_steps = []
                train_once.patience = 0
            else:
                if valid_loss < train_once.min_valid_loss:
                    train_once.min_valid_loss = valid_loss
                    train_once.patience = 0
                else:
                    train_once.patience += 1
                    logging.info2('{} valid_step:{} patience:{}'.format(
                        epoch_str, step, train_once.patience))

                    if learning_rate_patience and train_once.patience >= learning_rate_patience:
                        lr_op = ops[1]
                        lr = sess.run(lr_op) * learning_rate_decay_factor
                        train_once.deacy_steps.append(step)
                        logging.info2(
                            '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
                            .format(epoch_str, step, learning_rate_decay_factor,
                                    ','.join(map(str, train_once.deacy_steps))))
                        sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
                        train_once.patience = 0
                        train_once.min_valid_loss = valid_loss

    if ops is not None:
        #if deal_results_fn is None and names is not None:
        #    deal_results_fn = lambda x: melt.print_results(x, names)

        feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
        # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
        #print('---------------ops', ops)
        if eval_ops is not None or not log_dir or not hasattr(train_once, 'summary_op') \
                or train_once.summary_op is None or use_horovod:
            feed_dict[K.learning_phase()] = 1
            results = sess.run(ops, feed_dict=feed_dict)
        else:
            ## TODO why below ?
            #try:
            feed_dict[K.learning_phase()] = 1
            results = sess.run(ops + [train_once.summary_op], feed_dict=feed_dict)
            summary_str = results[-1]
            results = results[:-1]
            # except Exception:
            #     logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
            #     results = sess.run(ops, feed_dict=feed_dict)

        #print('------------results', results)

        # #--------trace debug
        # if step == 210:
        #     run_metadata = tf.RunMetadata()
        #     results = sess.run(
        #         ops,
        #         feed_dict=feed_dict,
        #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        #         run_metadata=run_metadata)
        #     from tensorflow.python.client import timeline
        #     trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        #     trace_file = open('timeline.ctf.json', 'w')
        #     trace_file.write(trace.generate_chrome_trace_format())

        # results[0] assumed to be train_op, results[1] to be learning_rate
        learning_rate = results[1]
        results = results[2:]

        #@TODO should support avg loss and other avg evaluations like test..
        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
            #assume results[0] as train_op return, results[1] as loss
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)

        steps_per_second = None
        instances_per_second = None
        hours_per_epoch = None

        #step += 1
        #if is_start or interval_steps and step % interval_steps == 0:
        interval_ok = not use_horovod or hvd.local_rank() == 0
        if interval_steps and step % interval_steps == 0 and interval_ok:
            train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                duration_str = 'duration:{:.2f} '.format(duration)
                melt.set_global('duration', '%.2f' % duration)
                #info.write(duration_str)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
                instances_per_second = interval_steps * batch_size / elapsed
                gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(num_gpus)
                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch)
                # info.write('elapsed:[{:.2f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.6f}]'.format(
                #     elapsed, batch_size, gpu_info, steps_per_second,
                #     instances_per_second, epoch_time_info, learning_rate))
                info.write('elap:[{:.2f}] batch:[{}] {} lr:[{:.6f}]'.format(
                    elapsed, batch_size, epoch_time_info, learning_rate))

            if print_avg_loss:
                #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
                names_ = melt.adjust_names(train_average_loss, names)
                #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
                info.write(' train:{} '.format(melt.parse_results(train_average_loss, names_)))
                #info.write('train_avg_loss: {} '.format(train_average_loss))
            info.write(eval_str)

            #print(gezi.now_time(), epoch_str, 'train_step:%d' % step, info.getvalue(), end=' ')
            logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step, info.getvalue()))

        if deal_results_fn is not None:
            stop = deal_results_fn(results)

    summary_strs = gezi.to_list(summary_str)
    if metric_evaluate:
        if evaluate_summaries is not None:
            summary_strs += evaluate_summaries

    if step > 1:
        if is_eval_step:
            # deal with summary
            if log_dir:
                summary = tf.Summary()
                if eval_ops is None:
                    if train_once.summary_op is not None:
                        for summary_str in summary_strs:
                            train_once.summary_writer.add_summary(summary_str, step)
                else:
                    for summary_str in summary_strs:
                        train_once.summary_writer.add_summary(summary_str, step)
                    suffix = 'valid' if not eval_names else ''
                    # loss/valid
                    melt.add_summarys(summary, eval_results, eval_names_, suffix=suffix)

                if ops is not None:
                    try:
                        # loss/train_avg
                        melt.add_summarys(summary, train_average_loss, names_, suffix='train_avg')
                    except Exception:
                        pass
                    ##optimizer has done this also
                    melt.add_summary(summary, learning_rate, 'learning_rate')
                    melt.add_summary(summary, melt.batch_size(), 'batch_size', prefix='other')
                    melt.add_summary(summary, melt.epoch(), 'epoch', prefix='other')
                    if steps_per_second:
                        melt.add_summary(summary, steps_per_second, 'steps_per_second', prefix='perf')
                    if instances_per_second:
                        melt.add_summary(summary, instances_per_second, 'instances_per_second', prefix='perf')
                    if hours_per_epoch:
                        melt.add_summary(summary, hours_per_epoch, 'hours_per_epoch', prefix='perf')

                if metric_evaluate:
                    #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
                    prefix = 'step_eval'
                    if model_path:
                        prefix = 'eval'
                        valid_interval_epochs = 1.
                        try:
                            valid_interval_epochs = FLAGS.valid_interval_epochs
                        except Exception:
                            pass
                        if not hasattr(train_once, 'epoch_step'):
                            train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
                                int(melt.epoch() * 10) / int(valid_interval_epochs * 10))
                        else:
                            train_once.epoch_step += 1
                        step = train_once.epoch_step
                    # eval/loss eval/auc ..
                    melt.add_summarys(summary, evaluate_results, evaluate_names, prefix=prefix)

                train_once.summary_writer.add_summary(summary, step)
                train_once.summary_writer.flush()
            return stop
        elif metric_evaluate and log_dir:
            summary = tf.Summary()
            for summary_str in summary_strs:
                train_once.summary_writer.add_summary(summary_str, step)
            #summary.ParseFromString(evaluate_summaries)
            summary_writer = train_once.summary_writer
            prefix = 'step_eval'
            if model_path:
                prefix = 'eval'
                if not hasattr(train_once, 'epoch_step'):
                    ## TODO.. restart will get 1 again..
                    #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
                    #epoch_step += 1
                    #train_once.epoch_step = sess.run(epoch_step)
                    valid_interval_epochs = 1.
                    try:
                        valid_interval_epochs = FLAGS.valid_interval_epochs
                    except Exception:
                        pass
                    train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
                        int(melt.epoch() * 10) / int(valid_interval_epochs * 10))
                    logging.info('train_once epoch start step is', train_once.epoch_step)
                else:
                    #epoch_step += 1
                    train_once.epoch_step += 1
                step = train_once.epoch_step
            #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
            melt.add_summarys(summary, evaluate_results, evaluate_names, prefix=prefix)
            summary_writer.add_summary(summary, step)
            summary_writer.flush()
def metric_average(val, name):
    tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()
def barrier(self) -> None:
    # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
    # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
    hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
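# A hedged sketch of the same dummy-allreduce barrier as a free function
# (hvd_barrier is a hypothetical helper, not part of the Horovod API).
# Because allreduce is a collective operation, every rank blocks in the
# call until all ranks have entered it, which is what makes it usable as
# a barrier, e.g. to fence rank-0-only checkpoint writes.
import torch
import horovod.torch as hvd


def hvd_barrier(name: str = "barrier") -> None:
    # The reduced value is discarded; the call exists only for its
    # synchronizing side effect. Assumes hvd.init() was already called.
    hvd.allreduce(torch.tensor(0, device="cpu"), name=name)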
def train_loop(
    run_id,
    dataset_dir,
    ckpt_run_dir,
    output_dir,
    validation_only=False,
    use_cuda=False,
    light_target=False,
    seed=42,
):
    """Train loop"""
    train_epochs = 10
    math_mode = "fp16"

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Dataset arguments
    train_global_batch_size = 2**17  # Global batch size
    max_bs = 2**13  # Max batch size for used hardware
    update_freq = int(max(1, train_global_batch_size // (max_bs * world_size)))
    max_tokens = int(train_global_batch_size // (world_size * update_freq))
    max_source_positions, max_target_positions = 80, 80
    seq_len_multiple = 2
    left_pad = (True, False)
    lang = ("en", "de")

    # specific arch
    model_args = deepcopy(DEFAULT_TRANSFORMER_ARCH)
    model_args["max_source_positions"] = max_source_positions
    model_args["max_target_positions"] = max_target_positions
    model_args["share_all_embeddings"] = True
    model_args["dropout"] = 0.1
    model_args["softmax_type"] = "fast_fill"

    lr = 1.976e-3
    optimizer_args = {
        "lr": lr,
        "eps": 1e-9,
        "betas": (0.9, 0.98),
    }
    scheduler_args = {
        "base_lr": lr,
        "warmup_init_lr": 0.0,
        "warmup_steps": 1000,
    }
    loss_scaling_fp16 = {
        "init_scale": 2.0**7,
        "scale_factor": 2,
        "scale_window": 2000,
    }
    criterion_args = {"smoothing": 0.1, "fast_xentropy": True}

    # Horovod stuff
    use_horovod = (math_mode == "fp16") and dist.get_backend() == dist.Backend.MPI
    if use_horovod:
        hvd.init()
        logger.info("Using horovod rank={}".format(hvd.rank()))
        tensor = torch.tensor([1])
        res = hvd.allreduce(tensor, op=hvd.Sum)
        assert res[0] == world_size

    # Load train and validation datasets
    train_set = WMT17Dataset(
        dataset_dir,
        download=True,
        train=True,
        shuffle=True,
        lang=lang,
        left_pad=left_pad,
        max_positions=(max_source_positions, max_target_positions),
        seq_len_multiple=seq_len_multiple,
    )
    validation_set = WMT17Dataset(
        dataset_dir,
        download=False,
        test=True,
        shuffle=True,
        lang=lang,
        left_pad=left_pad,
        max_positions=(max_source_positions, max_target_positions),
        seq_len_multiple=seq_len_multiple,
    )
    src_dict, trg_dict = train_set.src_dict, train_set.trg_dict

    train_batches = get_batches(train_set, max_tokens=max_tokens, bsz_mult=8,
                                shuffle=True, seed=seed)
    val_batches = get_batches(validation_set, max_tokens=max_tokens, bsz_mult=8,
                              shuffle=False)
    train_batches = equalize_batches(train_batches, world_size, seed=seed)

    # Partition by rank
    train_batches = partition_dataset_by_rank(train_batches, rank, world_size)
    val_batches = partition_dataset_by_rank(val_batches, rank, world_size)

    total_train_points = sum(len(b) for b in train_batches)

    validate_every = update_freq * round(
        len(train_batches) * 0.30 / update_freq)  # Validate every 30%
    assert (validate_every % update_freq) == 0

    logger.info("Using {} total train points, {} batches".format(
        total_train_points, len(train_batches)))

    train_loader = DataLoader(
        train_set,
        num_workers=1,
        pin_memory=False,
        collate_fn=train_set.collater,
        batch_sampler=train_batches,
    )
    val_loader = DataLoader(
        validation_set,
        num_workers=1,
        pin_memory=False,
        collate_fn=validation_set.collater,
        batch_sampler=val_batches,
    )

    model = TransformerModel(Arguments(model_args), src_dict, trg_dict)
    criterion = LabelSmoothing(padding_idx=src_dict.pad(), **criterion_args)

    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    fp_optimizer, optimizer, model = build_optimizer(
        model,
        optimizer_args,
        math_mode=math_mode,
        scaling_args=loss_scaling_fp16,
        use_horovod=use_horovod,
        use_cuda=use_cuda,
    )

    scheduler = SQRTTimeDecayLRWithWarmup(optimizer, **scheduler_args)

    metrics = [BLEUScore(use_raw=True)]

    checkpointer = Checkpointer(ckpt_run_dir=ckpt_run_dir, rank=rank,
                                freq=CheckpointFreq.BEST)

    translator = SequenceGenerator(
        model,
        src_dict=deepcopy(src_dict),
        trg_dict=deepcopy(trg_dict),
        beam_size=4,
        stop_early=True,
        normalize_scores=True,
        len_penalty=0.6,
        sampling=False,
        sampling_topk=-1,
        minlen=1,
    )

    if not validation_only:
        if light_target:
            goal = task4_time_to_bleu_goal(20)
        else:
            goal = task4_time_to_bleu_goal(25)

        num_batches_per_device_train = len(train_loader)
        tracker = Tracker(metrics, run_id, rank, goal=goal)

        dist.barrier()
        tracker.start()

        for epoch in range(0, train_epochs):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            model.train()
            tracker.train()

            iter_sample_size = 0
            for batch_idx, sample in enumerate(train_loader):
                tracker.batch_start()
                sample = prepare_batch(sample, use_cuda=use_cuda)
                tracker.record_batch_load()

                is_last = batch_idx == len(train_loader)
                update = (batch_idx % update_freq) == update_freq - 1
                init = (batch_idx % update_freq) == 0

                # Clear gradients in the optimizer.
                if init:
                    fp_optimizer.zero_grad()
                    iter_sample_size = 0
                    tracker.record_batch_init()

                # Compute the output
                output = model(**sample["net_input"])
                tracker.record_batch_fwd_pass()

                loss, sample_size = compute_loss(sample, output, criterion)
                loss_per_sample = loss.item() / sample_size
                iter_sample_size += sample_size
                tracker.record_batch_comp_loss()

                # Backprop
                fp_optimizer.backward_loss(loss)
                tracker.record_batch_backprop()

                if update or is_last:
                    # Get batch size over all workers
                    full_bs = get_full_batch_size(iter_sample_size,
                                                  world_size=world_size,
                                                  use_cuda=use_cuda)

                    updated = opt_step(
                        fp_optimizer,
                        tracker,
                        full_bs,
                        update_freq,
                        math_mode,
                        world_size,
                    )

                    if updated:
                        scheduler.step()

                tracker.batch_end()

                record_train_batch_stats(
                    batch_idx=batch_idx,
                    loss=loss_per_sample,
                    output=torch.Tensor([0]),
                    metric_results={},
                    tracker=tracker,
                    num_batches_per_device_train=num_batches_per_device_train,
                )

                if (batch_idx + 1) % validate_every == 0:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    metric_values, loss = validation_round(
                        val_loader,
                        metrics,
                        criterion,
                        translator,
                        tracker=tracker,
                        use_cuda=use_cuda,
                    )
                    record_validation_stats(metric_values, loss, tracker, rank)
                    if tracker.goal_reached:
                        break

                    model.train()
                    tracker.train()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            metric_values, loss = validation_round(
                val_loader,
                metrics,
                criterion,
                translator,
                tracker=tracker,
                use_cuda=use_cuda,
            )
            is_best = record_validation_stats(metric_values, loss, tracker, rank)

            checkpointer.save(
                tracker,
                model,
                optimizer,
                scheduler,
                tracker.current_epoch,
                is_best,
            )

            tracker.epoch_end()

            if tracker.goal_reached:
                print("Goal Reached!")
                time.sleep(10)
                return
    else:
        cecf = CheckpointsEvaluationControlFlow(
            ckpt_dir=ckpt_run_dir,
            rank=rank,
            world_size=world_size,
            checkpointer=checkpointer,
            model=model,
            epochs=train_epochs,
            loss_function=criterion,
            metrics=metrics,
            use_cuda=use_cuda,
            dtype="fp32",
            max_batch_per_epoch=None,
        )

        train_stats = cecf.evaluate_by_epochs(train_loader)
        with open(os.path.join(output_dir, "train_stats.json"), "w") as f:
            json.dump(train_stats, f)
def metric_sum_hvd(val, name):
    tensor = torch.tensor(val)
    # average=False makes the allreduce return the sum across workers
    sum_tensor = hvd.allreduce(tensor, name=name, average=False)
    return sum_tensor.item()
def evaluate(args):
    # initialize Horovod library
    hvd.init()
    # Horovod limits CPU threads to be used per worker
    torch.set_num_threads(1)

    if hvd.local_rank() == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ])

    # create an evaluation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False,
                                     num_replicas=hvd.size(), rank=hvd.rank())
    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    multiprocessing_context = None
    if hasattr(mp, "_supports_context") and mp._supports_context \
            and "forkserver" in mp.get_all_start_methods():
        multiprocessing_context = "forkserver"
    # sliding window inference needs to input 1 image in every iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
        multiprocessing_context=multiprocessing_context,
    )
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False,
                             sigmoid=True, reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{hvd.local_rank()}")
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    if hvd.rank() == 0:
        # load model parameters for evaluation
        model.load_state_dict(torch.load("final_model.pth"))
    # Horovod broadcasts parameters
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    model.eval()
    with torch.no_grad():
        # define PyTorch Tensor to record metrics result at each GPU
        # the first value is `sum` of all dice metric, the second value is `count` of not_nan items
        metric = torch.zeros(2, dtype=torch.float, device=device)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels).squeeze()
            metric[0] += value * dice_metric.not_nans
            metric[1] += dice_metric.not_nans
        # synchronizes all processes and reduces results
        print(f"metric in rank {hvd.rank()}: sum={metric[0].item()}, count={metric[1].item()}")
        avg_metric = hvd.allreduce(metric, name="mean_dice")
        if hvd.rank() == 0:
            print(f"average metric: sum={avg_metric[0].item()}, count={avg_metric[1].item()}")
            print("evaluation metric:", (avg_metric[0] / avg_metric[1]).item())
def average_value(self, val, name):
    avg_tensor = hvd.allreduce(val, name=name)
    return avg_tensor
def update(self, val):
    import horovod.torch as hvd
    self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
    self.n += 1
def update(self, val):
    self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
    self.n += 1
def metric_average(val, name):
    tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()
def metric_sum(value):
    return hvd.allreduce(torch.tensor(value), op=hvd.Sum).item()
def metric_ave(value):
    return hvd.allreduce(torch.tensor(value)).item()
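# A hedged usage sketch for aggregation helpers like metric_sum / metric_ave
# above (global_accuracy and its arguments are hypothetical, not taken from
# the snippets). hvd.allreduce averages across workers by default, while
# op=hvd.Sum adds, so raw counts should be summed and divided only once.
import torch
import horovod.torch as hvd


def global_accuracy(local_correct: int, local_total: int) -> float:
    # Sum the per-worker counters across all Horovod ranks, then compute a
    # single global ratio. Assumes hvd.init() was already called.
    correct = hvd.allreduce(torch.tensor(local_correct), op=hvd.Sum).item()
    total = hvd.allreduce(torch.tensor(local_total), op=hvd.Sum).item()
    return correct / max(total, 1)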
def update(self, val, delta_n=1):
    import horovod.torch as hvd
    val *= delta_n
    self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
    self.count += delta_n
def avg(self):
    import horovod.torch as hvd
    if not self.synced:
        self.sum = hvd.allreduce(self.sum, name=self.name)
        self.synced = True
    return self.sum / self.count
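# A hedged sketch of the meter class that update()/avg() methods like the
# two above appear to belong to (the class name, constructor, and the
# locally-accumulating update() are assumptions, not from the snippets).
# Values are accumulated locally and synchronized once in avg(), which
# assumes every worker contributes a comparable number of samples.
import torch


class HorovodAverageMeter:
    def __init__(self, name: str):
        self.name = name
        self.sum = torch.zeros(1)
        self.count = 0
        self.synced = False

    def update(self, val: torch.Tensor, delta_n: int = 1) -> None:
        # Accumulate locally; the cross-worker allreduce is deferred to avg().
        self.sum += val.detach().cpu() * delta_n
        self.count += delta_n

    def avg(self) -> torch.Tensor:
        import horovod.torch as hvd
        if not self.synced:
            # Average the accumulated sums across all Horovod workers once.
            self.sum = hvd.allreduce(self.sum, name=self.name)
            self.synced = True
        return self.sum / self.count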
def metric_average(val, name):
    tensor = torch.FloatTensor([val])
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.data[0]
def run(i_run, options, train_data, valid_data, test_data, model, optimizer,
        handles, outfile):
    train_dataloader, train_idx = create_train_dataset(
        options, data_tensor_dict=train_data)
    valid_dataloader, valid_idx = create_valid_test_dataset(
        options, data_tensor_dict=valid_data)
    test_dataloader, test_idx = create_valid_test_dataset(
        options, data_tensor_dict=test_data)

    total_steps = get_train_step(len(train_data['idx']), 1, options.batchsize,
                                 hvd.size())
    train_step = 0
    best = {'val_acc': 0.0, 'epoch': 0}

    for g in optimizer.param_groups:
        g['lr'] = options.lr
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.1,
        patience=options.lr_patience,
        threshold=0.0001,
        threshold_mode='rel',
        cooldown=0,
        min_lr=1e-2,
        eps=1e-08,
        verbose=True)

    model.reset_parameters()
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    train_epoch = 0
    for epoch in range(options.epochs):
        # train
        t0 = time.time()
        total_cla_loss, train_acc = train(train_dataloader, model, optimizer,
                                          total_steps, train_idx,
                                          handles.train_label_handle)
        t1 = time.time()

        # valid
        valid_acc = valid(valid_dataloader, model, len(valid_idx), valid_idx,
                          handles)
        t2 = time.time()
        scheduler.step(valid_acc)

        if valid_acc < best['val_acc']:
            if hvd.rank() == 0:
                print(
                    'run=%02d, epoch=%03d, loss=%.4f, train_acc=%.2f%%, valid_acc=%.2f%%, train_time=%.2fs, valid_time=%.2fs'
                    % (i_run, epoch, total_cla_loss, train_acc, valid_acc,
                       t1 - t0, t2 - t1))
            if epoch > best['epoch'] + options.stop_patience:
                break
        else:
            # test
            test_acc = test(test_dataloader, model, len(test_idx), test_idx,
                            handles)
            t3 = time.time()
            if hvd.rank() == 0:
                print(
                    'run=%02d, epoch=%03d, loss=%.4f, train_acc=%.2f%%, valid_acc=%.2f%%, test_acc=%.2f%%, train_time=%.2fs, valid_time=%.2fs, test_time=%.2fs'
                    % (i_run, epoch, total_cla_loss, train_acc, valid_acc,
                       test_acc, t1 - t0, t2 - t1, t3 - t2))
            best['val_acc'] = valid_acc
            best['loss'] = total_cla_loss
            best['test_acc'] = test_acc
            best['epoch'] = epoch
            best['train_acc'] = train_acc

    # dummy allreduce keeps all ranks in sync before reporting the best result
    hvd.allreduce(torch.tensor(0))
    if hvd.rank() == 0:
        print(
            '[BEST] run=%02d, epoch=%03d, loss=%.4f, train_acc=%.2f%%, valid_acc=%.2f%%, test_acc=%.2f%%'
            % (i_run, best['epoch'], best['loss'], best['train_acc'],
               best['val_acc'], best['test_acc']))
        print(
            '[BEST] epoch=%03d, loss=%.4f, train_acc=%.2f%%, valid_acc=%.2f%%, test_acc=%.2f%%'
            % (best['epoch'], best['loss'], best['train_acc'], best['val_acc'],
               best['test_acc']),
            file=outfile)
        outfile.flush()
    return best