Example 1
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Update losses.
    for key in loss_dict:
        total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)

    add_to_logging('forward')
    add_to_logging('backward')
    add_to_logging('allreduce')
    add_to_logging('optimizer')
    add_to_logging('batch generator')

    # Tensorboard values.
    if writer and torch.distributed.get_rank() == 0:
        writer.add_scalar('learning_rate', learning_rate, iteration)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
        if args.fp16:
            writer.add_scalar('loss_scale', loss_scale, iteration)
        normalizer = iteration % args.log_interval
        if normalizer == 0:
            normalizer = args.log_interval
        timers.write(timers_to_log, writer, iteration, normalizer=normalizer)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('iteration_time',
                              elapsed_time / args.log_interval, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
                                                       args.train_iters)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / args.log_interval)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        for key in total_loss_dict:
            avg = total_loss_dict[key].item() / args.log_interval
            log_string += ' {}: {:.6E} |'.format(key, avg)
            total_loss_dict[key] = 0.0
        if args.fp16:
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        print_rank_0(log_string)
        if report_memory_flag:
            report_memory('after {} iterations'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag
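
The function above is meant to be driven from a training loop that feeds its return value back in as report_memory_flag, so that memory usage is reported only once. A minimal sketch of such a call site, assuming hypothetical helpers train_step and get_learning_rate that are not part of the snippet:

total_loss_dict = {}
report_memory_flag = True
for iteration in range(1, args.train_iters + 1):
    # train_step stands in for the real forward/backward/optimizer routine.
    loss_dict = train_step(model, optimizer)
    report_memory_flag = training_log(
        loss_dict,
        total_loss_dict,
        get_learning_rate(optimizer),                # hypothetical helper
        iteration,
        optimizer.loss_scale if args.fp16 else 1.0,  # assumption about the fp16 optimizer
        report_memory_flag)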
Example 2
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Update losses.
    skipped_iters_key = 'skipped iterations'
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    got_nan_key = 'got nan'

    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan

    total_loss_dict[got_nan_key] = total_loss_dict.get(
        got_nan_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)
    add_to_logging('forward')
    add_to_logging('VocabParallelEmbedding forward reduce')
    add_to_logging('ColumnParallelLinear forward gather')
    add_to_logging('RowParallelLinear forward reduce')
    add_to_logging('backward')
    add_to_logging('backward-backward')
    add_to_logging('backward-allreduce')
    add_to_logging('backward-master-grad')
    add_to_logging('backward-clip-grad')
    add_to_logging('optimizer')
    add_to_logging('batch generator')
    add_to_logging('_reduce inside')
    add_to_logging('_gather inside')
    add_to_logging('CopyToModelParallelRegion BACKWARD _reduce')
    add_to_logging('ReduceFromModelParallelRegion SYMBOLIC _reduce')
    add_to_logging('ReduceFromModelParallelRegion FORWARD _reduce')
    add_to_logging('ScatterToModelParallelRegion BACKWARD _gather')
    add_to_logging('GatherFromModelParallelRegion SYMBOLIC _gather')
    add_to_logging('GatherFromModelParallelRegion FORWARD _gather')

    # Tensorboard values.
    if writer and torch.distributed.get_rank() == 0:
        writer.add_scalar('learning_rate', learning_rate, iteration)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
        if args.fp16:
            writer.add_scalar('loss_scale', loss_scale, iteration)
        normalizer = iteration % args.log_interval
        if normalizer == 0:
            normalizer = args.log_interval
        timers.write(timers_to_log, writer, iteration,
                     normalizer=normalizer)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('iteration_time',
                              elapsed_time / args.log_interval, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
                                                       args.train_iters)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / args.log_interval)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        num_iterations = max(
            1, args.log_interval - total_loss_dict[skipped_iters_key])
        for key in total_loss_dict:
            if key not in [skipped_iters_key, got_nan_key]:
                avg = total_loss_dict[key].item() / float(num_iterations)
                log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = 0.0
        if args.fp16:
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[got_nan_key])
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[got_nan_key] = 0
        print_rank_0(log_string)
        if report_memory_flag:
            report_memory('after {} iterations'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag
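
Examples 2 through 6 flag non-finite losses on skipped iterations with the three comparisons value == float('inf'), value == -float('inf') and value != value; the last test works because NaN is the only floating-point value that compares unequal to itself. An equivalent, arguably more readable check, shown only as a sketch (the snippets keep the explicit comparisons):

import math

def is_nan_or_inf(value):
    # math.isfinite returns False for +inf, -inf and NaN,
    # matching the three comparisons used above.
    return not math.isfinite(value)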
Example 3
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter, model=None):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Update losses.
    skipped_iters_key = 'skipped iterations'
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    got_nan_key = 'got nan'

    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan

    total_loss_dict[got_nan_key] = total_loss_dict.get(
        got_nan_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)
    add_to_logging('forward')
    add_to_logging('backward')
    add_to_logging('backward-backward')
    add_to_logging('backward-allreduce')
    add_to_logging('backward-master-grad')
    add_to_logging('backward-clip-grad')
    add_to_logging('optimizer')
    add_to_logging('batch generator')

    # Tensorboard values.
    if writer and torch.distributed.get_rank() == 0:
        writer.add_scalar('tokens', args.tokens, iteration)
        writer.add_scalar('learning_rate', learning_rate, iteration)
        writer.add_scalar('learning_rate/vs tokens', learning_rate, args.tokens)
        if args.curriculum_learning:
            writer.add_scalar('seqlen',
                args.curriculum_seqlen, iteration)
            writer.add_scalar('seqlen/vs tokens',
                args.curriculum_seqlen, args.tokens)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + '/vs tokens', loss_dict[key], args.tokens)
        if args.fp16:
            writer.add_scalar('loss_scale', loss_scale, iteration)
        normalizer = iteration % args.log_interval
        if normalizer == 0:
            normalizer = args.log_interval
        timers.write(timers_to_log, writer, iteration,
                     normalizer=normalizer)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('iteration_time',
                              elapsed_time / args.log_interval, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
                                                       args.train_iters)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / args.log_interval)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        num_iterations = max(
            1, args.log_interval - total_loss_dict[skipped_iters_key])
        for key in total_loss_dict:
            if key not in [skipped_iters_key, got_nan_key]:
                avg = total_loss_dict[key].item() / float(num_iterations)
                log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = 0.0
        if args.fp16:
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[got_nan_key])
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[got_nan_key] = 0
        print_rank_0(log_string)
        if report_memory_flag:
            report_memory('after {} iterations'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)
        flops_calculator(model, args, elapsed_time)

    return report_memory_flag
Example 4
def training_log(neox_args, timers, loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter, model, optimizer, noise_scale_logger):
    """Log training information such as losses, timing, etc."""

    # Update losses.
    skipped_iters_key = 'skipped iterations'
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    got_nan_key = 'got nan'

    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan

    total_loss_dict[got_nan_key] = total_loss_dict.get(
        got_nan_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)

    if not neox_args.is_pipe_parallel:
        add_to_logging('forward')
        add_to_logging('backward')
        add_to_logging('backward-backward')
        add_to_logging('backward-allreduce')
        add_to_logging('backward-master-grad')
        add_to_logging('backward-clip-grad')
        add_to_logging('optimizer')
        add_to_logging('batch generator')

        # Log timer info to tensorboard and wandb
        normalizer = iteration % neox_args.log_interval
        if normalizer == 0:
            normalizer = neox_args.log_interval
        if torch.distributed.get_rank() == 0:
            timers.write(names=timers_to_log, iteration=iteration, normalizer=normalizer)
    else:
        # with pipeline parallel, the megatron timers are overridden by the deepspeed ones.
        # Try to grab timer values from model engine. Only recently added to deeperspeed, so check that the engine
        # has that attribute first
        if hasattr(model, 'timer_values') and model.timer_values is not None:
            if model.wall_clock_breakdown() and model.global_steps % model.steps_per_print() == 0:
                timer_values = model.timer_values
                # deepspeed already logs to tensorboard / prints values, so just log to wandb
                if neox_args.use_wandb and torch.distributed.get_rank() == 0:
                    for key in timer_values:
                        tb_wandb_log(f"timers/{key}", timer_values[key], iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer)

    # write losses, lr, etc. every step
    tb_wandb_log('train/learning_rate', learning_rate, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer)
    for key in loss_dict:
        tb_wandb_log(f'train/{key.replace(" ", "_")}', loss_dict[key], iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer)
    if neox_args.fp16:
        tb_wandb_log(f'train/loss_scale', loss_scale, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer)

    # log gradient noise scale
    if neox_args.log_gradient_noise_scale:
        if noise_scale_logger.noise_scale is not None:
            tb_wandb_log(f'train/noise_scale', noise_scale_logger.noise_scale, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer)

    # (optional) Log optimizer states to wandb / tb every step
    if neox_args.log_optimizer_states:
        for k, v in optimizer.state_dict()['optimizer_state_dict']['state'].items():
            for ki, vi in v.items():  # step, module
                if ki != 'step':
                    opt_state_norm = torch.norm(vi) if hasattr(vi, 'dim') else vi
                    tb_wandb_log(f'optimizer_state_norms/{k}_{ki}', opt_state_norm, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer)

    # (optional) Log grad/param norms to wandb / tb every step
    if neox_args.log_grad_norm or neox_args.log_param_norm:
        if neox_args.log_grad_norm:
            model.store_gradients = True  # start storing gradients
        for i, (name, param) in enumerate(model.module.named_parameters()):
            if neox_args.log_grad_norm:
                if hasattr(model, 'stored_gradients') and model.stored_gradients is not None:
                    grad = model.stored_gradients[i]
                    if grad is not None:
                        tb_wandb_log(f'gradient_norms/{name}', torch.norm(grad), iteration,
                                     use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer, all_ranks=True)
            if neox_args.log_param_norm:
                tb_wandb_log(f'parameter_norms/{name}', torch.norm(param), iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer, all_ranks=True)

    if iteration % neox_args.log_interval == 0:
        # log other stuff every neox_args.log_interval iters
        elapsed_time = timers('interval time').elapsed()
        iteration_time = elapsed_time / neox_args.log_interval
        samples_per_sec = neox_args.train_batch_size / iteration_time
        log_string = ' samples/sec: {:.3f} |'.format(samples_per_sec)
        tb_wandb_log('runtime/samples_per_sec', samples_per_sec, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer)
        tb_wandb_log('runtime/iteration_time', iteration_time, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer)
        log_string += ' iteration {:8d}/{:8d} |'.format(iteration, neox_args.train_iters)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / neox_args.log_interval)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        num_iterations = max(
            1, neox_args.log_interval - total_loss_dict[skipped_iters_key])

        # log tflop / gpu
        flops_per_s_per_gpu = get_flops(neox_args=neox_args, model=model, iter_time_s=iteration_time)
        log_string += f' approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |'
        tb_wandb_log('runtime/flops_per_sec_per_gpu', flops_per_s_per_gpu, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer)

        for key in total_loss_dict:
            if key not in [skipped_iters_key, got_nan_key]:
                v = total_loss_dict[key].item() if hasattr(total_loss_dict[key], 'item') else total_loss_dict[key]
                avg = v / float(num_iterations)
                log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = 0.0
        if neox_args.precision == "fp16":
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[got_nan_key])
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[got_nan_key] = 0
        print_rank_0(log_string)
        if report_memory_flag:
            report_memory('after {} iterations'.format(iteration))
            report_memory_flag = False

        timers.log(timers_to_log, normalizer=neox_args.log_interval)

    return report_memory_flag
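
Example 4 routes every scalar through a tb_wandb_log helper whose definition is not included in the snippet. Purely as an assumption reconstructed from the call sites above, such a helper would take (key, value, iteration_no, use_wandb, tensorboard_writer, all_ranks), write the scalar to TensorBoard and optionally to Weights & Biases, and log only from rank 0 unless all_ranks is set:

import torch

def tb_wandb_log(key, value, iteration_no, use_wandb,
                 tensorboard_writer=None, all_ranks=False):
    # Skip non-zero ranks unless the caller explicitly asks for all ranks.
    if not all_ranks and torch.distributed.is_initialized() \
            and torch.distributed.get_rank() != 0:
        return
    if tensorboard_writer is not None:
        tensorboard_writer.add_scalar(key, value, iteration_no)
    if use_wandb:
        import wandb  # only required when wandb logging is enabled
        wandb.log({key: value}, step=iteration_no)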
Example 5
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key,
                                                         0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)

    add_to_logging('forward-compute')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-send-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('optimizer-copy-to-main-grad')
    add_to_logging('optimizer-unscale-and-check-inf')
    add_to_logging('optimizer-clip-main-grad')
    add_to_logging('optimizer-copy-main-to-model-params')
    add_to_logging('optimizer')
    add_to_logging('batch-generator')

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and is_last_rank():
        writer.add_scalar('learning-rate', learning_rate, iteration)
        writer.add_scalar('learning-rate vs samples', learning_rate,
                          args.consumed_train_samples)
        writer.add_scalar('batch-size', batch_size, iteration)
        writer.add_scalar('batch-size vs samples', batch_size,
                          args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        writer.add_scalar('loss-scale', loss_scale, iteration)
        writer.add_scalar('loss-scale vs samples', loss_scale,
                          args.consumed_train_samples)
        timers.write(timers_to_log,
                     writer,
                     iteration,
                     normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('iteration-time', elapsed_time_per_iteration,
                              iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
                                                       args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [
                    advanced_iters_key, skipped_iters_key, nan_iters_key
            ]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag
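
Example 5 derives the global batch size as micro_batch_size * data_parallel_size * num_microbatches. As an illustrative instance (the numbers are not from the snippet): a micro batch of 4 on 8 data-parallel replicas with 16 gradient-accumulation microbatches gives a logged 'global batch size' of 4 * 8 * 16 = 512 samples per iteration.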
Example 6
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter, model):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Update losses.
    skipped_iters_key = 'skipped iterations'
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    got_nan_key = 'got nan'

    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(key,
                                                       0.) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan

    total_loss_dict[got_nan_key] = total_loss_dict.get(got_nan_key,
                                                       0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)

    if args.pipe_parallel_size <= 0:
        add_to_logging('forward')
        add_to_logging('backward')
        add_to_logging('backward-backward')
        add_to_logging('backward-allreduce')
        add_to_logging('backward-master-grad')
        add_to_logging('backward-clip-grad')
        add_to_logging('optimizer')
        add_to_logging('batch generator')
    else:
        # with pipeline parallel, the megatron timers are overridden by the deepspeed ones.
        # Try to grab timer values from model engine. Only recently added to deeperspeed, so check that the engine
        # has that attribute first
        if hasattr(model, 'timer_values') and model.timer_values is not None:
            if model.wall_clock_breakdown(
            ) and model.global_steps % model.steps_per_print() == 0:
                timer_values = model.timer_values
                # deepspeed already logs to tensorboard / prints values, so just log to wandb
                if get_use_wandb() and torch.distributed.get_rank() == 0:
                    for key in timer_values:
                        wandb.log({key: timer_values[key]}, step=iteration)

    # Log timer info to tensorboard and wandb
    normalizer = iteration % args.log_interval
    if normalizer == 0:
        normalizer = args.log_interval
    if torch.distributed.get_rank() == 0:
        timers.write(names=timers_to_log,
                     iteration=iteration,
                     normalizer=normalizer)

    # wandb writer
    if get_use_wandb() and torch.distributed.get_rank() == 0:
        wandb.log({'learning_rate': learning_rate}, step=iteration)
        for key in loss_dict:
            wandb.log({key: loss_dict[key]}, step=iteration)
        if args.fp16:
            wandb.log({'loss_scale': loss_scale}, step=iteration)

    # Tensorboard values.
    if writer and torch.distributed.get_rank() == 0:
        writer.add_scalar('learning_rate', learning_rate, iteration)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
        if args.fp16:
            writer.add_scalar('loss_scale', loss_scale, iteration)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        iteration_time = elapsed_time / args.log_interval
        samples_per_sec = get_global_batch_size(args) / iteration_time
        log_string = ' samples/sec: {:.3f} |'.format(samples_per_sec)
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('samples/sec', samples_per_sec, iteration)
            writer.add_scalar('iteration_time', iteration_time, iteration)
        if get_use_wandb() and torch.distributed.get_rank() == 0:
            wandb.log({'samples/sec': samples_per_sec}, step=iteration)
            wandb.log({'iteration_time': iteration_time}, step=iteration)
        log_string += ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / args.log_interval)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        num_iterations = max(
            1, args.log_interval - total_loss_dict[skipped_iters_key])

        # calculate tflop / gpu
        flops_per_s_per_gpu = get_flops(model, iteration_time)
        log_string += f' approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |'
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('flops/s/gpu', flops_per_s_per_gpu, iteration)
        if get_use_wandb() and torch.distributed.get_rank() == 0:
            wandb.log({'flops/s/gpu': flops_per_s_per_gpu}, step=iteration)

        for key in total_loss_dict:
            if key not in [skipped_iters_key, got_nan_key]:
                v = total_loss_dict[key].item() if hasattr(
                    total_loss_dict[key], 'item') else total_loss_dict[key]
                avg = v / float(num_iterations)
                log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = 0.0
        if args.fp16:
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[got_nan_key])
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[got_nan_key] = 0
        print_rank_0(log_string)
        if report_memory_flag:
            report_memory('after {} iterations'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag
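
Examples 4 and 6 report throughput via get_flops and human_readable_flops, neither of which is defined in the snippets. As an assumption about the formatting helper only, a human_readable_flops that scales a raw FLOP/s figure into K/M/G/T/P units could look like this:

def human_readable_flops(num):
    # Walk up the SI-style prefixes until the value drops below 1000.
    for unit in ['', 'K', 'M', 'G', 'T', 'P']:
        if abs(num) < 1000.0:
            return '{:.2f} {}FLOPS'.format(num, unit)
        num /= 1000.0
    return '{:.2f} EFLOPS'.format(num)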