def train(self):
        try:
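            # Emit NVTX ranges for autograd ops when pyprof profiling is enabled,
            # so external profilers can attribute GPU kernels to PyTorch operations.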
            with torch.autograd.profiler.emit_nvtx(
                    enabled=self.pyprof_enabled):
                for i in range(self.step + 1, self.final_steps + 1):
                    self.step = i
                    tprint(
                        "------------- TRAIN step : {} -------------".format(
                            i))

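                    # Open the CUDA profiler capture window at the requested iteration
                    # (for runs launched under nvprof/Nsight with profiling initially disabled).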
                    if self.nvprof_iter_start and i == self.nvprof_iter_start:
                        profiler.start()

                    with Nvtx("step #{}".format(self.step)):
                        loss, meta = self.do_step()

                    if self.nvprof_iter_end and i == self.nvprof_iter_end:
                        profiler.stop()

                    if self.lr_scheduler:
                        for param_group in self.optimizer.param_groups:
                            tprint("lr: {:06f}".format(param_group['lr']))
                        self.lr_scheduler.step(self.step)

                    if self.step % self.log_steps == 0:
                        self.log(loss, meta)

                    if self.ckpt_path and self.save_steps and i % self.save_steps == 0:
                        self.save()

            tprint("Training has been done.")
        except StopIteration:  # done by n_epochs
            tprint("Training has been done. (by n_epochs)")
        except KeyboardInterrupt:
            tprint("Training has been canceled.")
    def on_iteration_end(self):

        if self.step == 4:
            profiler.stop()
            logging.info(f"********************Stopping profiler at step: " +
                         str(self.step))

        if self.global_rank is None or self.global_rank == 0:
            step = self.step
            run_time = time.time() - self._last_iter_start
            logging.info(f"Step {self.step} time: {run_time} seconds")
Example #3
def train(train_loader, model, scheduler, optimizer, epoch, args):
    global iteration
    print("{} epoch: \t start training....".format(epoch))
    start = time.time()
    total_loss = []
    model.train()
    model.module.is_training = True
    model.module.freeze_bn()
    optimizer.zero_grad()

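    # Annotate autograd ops with NVTX ranges so the profiler can attribute kernels to PyTorch ops.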
    with torch.autograd.profiler.emit_nvtx():
        for idx, (images, annotations) in enumerate(train_loader):
            if iteration >= args.wramup:
                profiler.start()
            images = images.cuda().float()
            annotations = annotations.cuda()
            classification_loss, regression_loss = model([images, annotations])
            classification_loss = classification_loss.mean()
            regression_loss = regression_loss.mean()
            loss = classification_loss + regression_loss
            if bool(loss == 0):
                print('loss equal zero(0)')
                continue
            loss.backward()
            if (idx + 1) % args.grad_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                optimizer.step()
                optimizer.zero_grad()

            total_loss.append(loss.item())
            if iteration % 2 == 0:
                print('{} iteration: training ...'.format(iteration))
                ans = {
                    'epoch': epoch,
                    'iteration': iteration,
                    'cls_loss': classification_loss.item(),
                    'reg_loss': regression_loss.item(),
                    'mean_loss': np.mean(total_loss)
                }
                for key, value in ans.items():
                    print('    {:15s}: {}'.format(str(key), value))
            iteration += 1
            if iteration > args.iter:
                break
        profiler.stop()
        scheduler.step(np.mean(total_loss))
        result = {
            'time': time.time() - start,
            'loss': np.mean(total_loss)
        }
        for key, value in result.items():
            print('    {:15s}: {}'.format(str(key), value))
Example #4
    def iterate(self, src, tgt, update=True, training=True):
        """
        Performs one iteration of the training/validation.

        :param src: batch of examples from the source language
        :param tgt: batch of examples from the target language
        :param update: if True: optimizer does update of the weights
        :param training: if True: executes optimizer
        """
        pyprof2.init()
        src, src_length = src
        tgt, tgt_length = tgt
        src = src.to(self.device)
        tgt = tgt.to(self.device)
        src_length = src_length.to(self.device)

        num_toks = {}
        num_toks['tgt'] = int(sum(tgt_length - 1))
        num_toks['src'] = int(sum(src_length))

        with torch.autograd.profiler.emit_nvtx():
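            # Capture this single forward/backward pass in the CUDA profiler.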
            profiler.start()

            if self.batch_first:
                output = self.model(src, src_length, tgt[:, :-1])
                tgt_labels = tgt[:, 1:]
                T, B = output.size(1), output.size(0)
            else:
                output = self.model(src, src_length, tgt[:-1])
                tgt_labels = tgt[1:]
                T, B = output.size(0), output.size(1)

            loss = self.criterion(output.view(T * B, -1),
                                  tgt_labels.contiguous().view(-1))

            loss_per_batch = loss.item()
            loss /= (B * self.iter_size)

            if training:
                self.fp_optimizer.step(loss, self.optimizer, self.scheduler,
                                       update)

            loss_per_token = loss_per_batch / num_toks['tgt']
            loss_per_sentence = loss_per_batch / B

            profiler.stop()

        print('You can stop now')
        exit()

        return loss_per_token, loss_per_sentence, num_toks
Example #5
def main():
	args = parseArgs()

	pyprof2.init()
	pyprof2.wrap(fused_adam_cuda, 'adam')

	N = args.b
	C = 3
	H = d[args.m]['H']
	W = d[args.m]['W']
	opts = d[args.m]['opts']
	classes = 1000

	net = getattr(models, args.m)
	net = net(**opts).cuda().half()
	net.train()

	x = torch.rand(N, C, H, W).cuda().half()
	target = torch.empty(N, dtype=torch.long).random_(classes).cuda()

	criterion = nn.CrossEntropyLoss().cuda()
	if (args.o == "sgd"):
		optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)
	elif (args.o == "adam"):
		optimizer = FusedAdam(net.parameters())
		#optimizer = FP16_Optimizer(optimizer)
	else:
		assert False

	# Warm up without the profiler
	for i in range(2):
		output = net(x)
		loss = criterion(output, target)
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

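	# Profile exactly one training iteration, annotated with NVTX ranges for op attribution.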
	with torch.autograd.profiler.emit_nvtx():
		profiler.start()
		output = net(x)
		loss = criterion(output, target)
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		profiler.stop()
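
The examples above share the same capture pattern: warm up without profiling, then wrap one iteration in torch.autograd.profiler.emit_nvtx() together with an explicit torch.cuda.profiler start/stop window. For reference, here is a minimal self-contained sketch of that pattern. It is not taken from any of the listed projects and assumes the script is launched with profiling initially disabled, e.g. nvprof --profile-from-start off or nsys profile -c cudaProfilerApi:

import torch
import torch.nn as nn
import torch.cuda.profiler as profiler

model = nn.Linear(1024, 1024).cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(64, 1024, device='cuda')
target = torch.randn(64, 1024, device='cuda')

# Warm up without the profiler so one-time initialization is not captured.
for _ in range(2):
    optimizer.zero_grad()
    loss = criterion(model(x), target)
    loss.backward()
    optimizer.step()

# Capture exactly one iteration; emit_nvtx() adds NVTX ranges so the
# profiler can attribute kernels to the PyTorch ops that launched them.
with torch.autograd.profiler.emit_nvtx():
    profiler.start()
    optimizer.zero_grad()
    loss = criterion(model(x), target)
    loss.backward()
    optimizer.step()
    profiler.stop()
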
Example #6
def train_one_epoch(model, dataloader, optimizer, args, epoch):
    with torch.autograd.profiler.emit_nvtx():
        model.train()
        tloss = 0.
        tcnt = 0.
        st_time = time.time()
        i = 0
        with tqdm(dataloader, desc='Train Ep ' + str(epoch),
                  mininterval=60) as tq:
            for batch in tq:
                if i == iter_to_capture:
                    profiler.start()
                pred = model(batch)
                nll_loss = F.nll_loss(pred.view(-1, pred.shape[-1]),
                                      batch['tgt_text'].view(-1),
                                      ignore_index=0)
                loss = nll_loss
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                optimizer.step()
                loss = loss.item()
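                # A NaN loss is the only value not equal to itself; abort if training diverged.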
                if loss != loss:
                    raise ValueError('NaN appear')
                tloss += loss * len(batch['tgt_text'])
                tcnt += len(batch['tgt_text'])
                tq.set_postfix({'loss': tloss / tcnt}, refresh=False)
                if i == iter_to_capture:
                    profiler.stop()
                i = i + 1
                if i == 1:
                    break
        print('Train Ep ', str(epoch), 'AVG Loss ', tloss / tcnt, 'Steps ',
              tcnt, 'Time ',
              time.time() - st_time, 'GPU',
              torch.cuda.max_memory_cached() / 1024.0 / 1024.0 / 1024.0)
        torch.save(model, args.save_model + str(epoch % 100))
def train(model, loss_fn, optimizer, data_loader_train, data_loader_test,
          scaled_lr):
    """Train and evaluate the model

    Args:
        model (dlrm): The DLRM model to train
        loss_fn (torch.nn.Module): Loss function
        optimizer (torch.optim.Optimizer): Optimizer for the model parameters
        data_loader_train (torch.utils.data.DataLoader): Training data loader
        data_loader_test (torch.utils.data.DataLoader): Test data loader
        scaled_lr (float): Base learning rate, already scaled for the global batch size
    """
    model.train()
    prefetching_enabled = is_data_prefetching_enabled()
    base_device = FLAGS.base_device
    print_freq = FLAGS.print_freq
    steps_per_epoch = len(data_loader_train)

    checkpoint_writer = make_serial_checkpoint_writer(
        embedding_indices=range(len(get_categorical_feature_sizes(FLAGS))),
        config=FLAGS.flag_values_dict())

    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    if prefetching_enabled:
        data_stream = torch.cuda.Stream()

    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()

    timer.click()

    for epoch in range(FLAGS.epochs):
        input_pipeline = iter(data_loader_train)

        if prefetching_enabled:
            input_pipeline = prefetcher(input_pipeline, data_stream)

        for step, batch in enumerate(input_pipeline):
            global_step = steps_per_epoch * epoch + step
            numerical_features, categorical_features, click = batch

            utils.lr_step(optimizer,
                          num_warmup_iter=FLAGS.warmup_steps,
                          current_step=global_step + 1,
                          base_lr=scaled_lr,
                          warmup_factor=FLAGS.warmup_factor,
                          decay_steps=FLAGS.decay_steps,
                          decay_start_step=FLAGS.decay_start_step)

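            # In prof-train mode, open the CUDA profiler window once the benchmark warmup steps are done.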
            if FLAGS.mode == 'prof-train' and global_step == FLAGS.benchmark_warmup_steps:
                profiler.start()

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                if FLAGS.mode == 'prof-train':
                    profiler.stop()
                print(
                    f"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                break

            if prefetching_enabled:
                torch.cuda.synchronize()

            output = model(numerical_features,
                           categorical_features).squeeze().float()

            loss = loss_fn(output, click.squeeze())

            # Setting grad to None is faster than zero_grad()
            for param_group in optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = None

            if FLAGS.amp:
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            if step % print_freq == 0 and step > 0:
                loss_value = loss.item()

                timer.click()

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(loss=loss_value,
                                         lr=optimizer.param_groups[0]["lr"])
                else:
                    unscale_factor = FLAGS.loss_scale if FLAGS.amp else 1
                    metric_logger.update(
                        loss=loss_value / unscale_factor,
                        step_time=timer.measured / FLAGS.print_freq,
                        lr=optimizer.param_groups[0]["lr"] * unscale_factor)

                if global_step < FLAGS.benchmark_warmup_steps:
                    print(
                        f'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]'
                    )
                    continue

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

            if (global_step % test_freq == 0 and global_step > 0
                    and global_step / steps_per_epoch >= FLAGS.test_after):
                loss, auc, test_step_time = evaluate(model, loss_fn,
                                                     data_loader_test)
                print(
                    f"Epoch {epoch} step {step}. Test loss {loss:.5f}, auc {auc:.6f}"
                )

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)
                    maybe_save_checkpoint(checkpoint_writer, model,
                                          FLAGS.save_checkpoint_path)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    stop_time = time()
                    run_time_s = int(stop_time - start_time)
                    print(
                        f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        f"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        f"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    return

    stop_time = time()
    run_time_s = int(stop_time - start_time)

    print(f"Finished training in {run_time_s}s.")

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    if 'test_step_time' in locals():
        avg_test_throughput = FLAGS.test_batch_size / test_step_time
        results['average_test_throughput'] = avg_test_throughput

    dllogger.log(data=results, step=tuple())
    def on_train_end(self, trainer, pl_module):
        if self.profile:
            profiler.stop()
        self.log()
Example #9
def main():
    args = parse_args()

    assert (torch.cuda.is_available())
    assert args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    assert args.grad_accumulation_steps >= 1
    assert args.batch_size % args.grad_accumulation_steps == 0
    batch_size = args.batch_size // args.grad_accumulation_steps

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    use_dali = args.dali_device in ('cpu', 'gpu')
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - use simple padders instead
        if train_features_kw['pad_to_max_duration']:
            train_feat_proc = BaseFeatures(
                pad_align=train_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=train_features_kw['max_duration'],
                sample_rate=train_features_kw['sample_rate'],
                window_size=train_features_kw['window_size'],
                window_stride=train_features_kw['window_stride'])
            train_features_kw['pad_to_max_duration'] = False
        else:
            train_feat_proc = None

        if val_features_kw['pad_to_max_duration']:
            val_feat_proc = BaseFeatures(
                pad_align=val_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=val_features_kw['max_duration'],
                sample_rate=val_features_kw['sample_rate'],
                window_size=val_features_kw['window_size'],
                window_stride=val_features_kw['window_stride'])
            val_features_kw['pad_to_max_duration'] = False
        else:
            val_feat_proc = None

        train_loader = DaliDataLoader(
            gpu_id=args.local_rank,
            dataset_path=args.dataset_dir,
            config_data=train_dataset_kw,
            config_features=train_features_kw,
            json_names=args.train_manifests,
            batch_size=batch_size,
            grad_accumulation_steps=args.grad_accumulation_steps,
            pipeline_type="train",
            device_type=args.dali_device,
            symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        train_dataset_kw, train_features_kw = config.input(cfg, 'train')
        train_dataset = AudioDataset(args.dataset_dir, args.train_manifests,
                                     symbols, **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)
        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset_kw, val_features_kw = config.input(cfg, 'val')
        val_dataset = AudioDataset(args.dataset_dir, args.val_manifests,
                                   symbols, **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)
        val_feat_proc = FilterbankFeatures(**val_features_kw)

        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    if train_feat_proc is not None:
        train_feat_proc.cuda()
    if val_feat_proc is not None:
        val_feat_proc.cuda()

    steps_per_epoch = len(train_loader) // args.grad_accumulation_steps

    # set up the model
    model = Jasper(encoder_kw=config.encoder(cfg),
                   decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.cuda()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        optimizer = Novograd(model.parameters(), **kw)
    elif args.optimizer == "adamw":
        optimizer = AdamW(model.parameters(), **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    adjust_lr = lambda step, epoch, optimizer: lr_policy(
        step,
        epoch,
        args.lr,
        optimizer,
        steps_per_epoch=steps_per_epoch,
        warmup_epochs=args.warmup_epochs,
        hold_epochs=args.hold_epochs,
        num_epochs=args.epochs,
        policy=args.lr_policy,
        min_lr=args.min_lr,
        exp_gamma=args.lr_exp_gamma)

    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'Jasper',
                                args.keep_milestones)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    if args.pyprof:
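        # Enter emit_nvtx() manually so NVTX annotation stays active for the whole training loop.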
        torch.autograd.profiler.emit_nvtx().__enter__()
        profiler.start()

    # training loop
    model.train()

    # pre-allocate
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(
                f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            feat = torch.randn(batch_size, n_feats, n_frames, device='cuda')
            feat_lens = torch.ones(batch_size, device='cuda').fill_(n_frames)
            txt = torch.randint(high=len(symbols) - 1,
                                size=(batch_size, 100),
                                device='cuda')
            txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
            with torch.cuda.amp.autocast(enabled=args.amp):
                log_probs, enc_lens = model(feat, feat_lens)
                del feat
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            loss.backward()
            model.zero_grad()
    torch.cuda.empty_cache()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            train_loader.sampler.set_epoch(epoch)

        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0
        epoch_start_time = time.time()
        epoch_eval_time = 0

        for batch in train_loader:

            if accumulated_batches == 0:
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens)

            # Use no_sync() to skip redundant gradient all-reduce on intermediate accumulation steps
            if (multi_gpu and
                    accumulated_batches + 1 < args.grad_accumulation_steps):
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                with torch.cuda.amp.autocast(enabled=args.amp):
                    log_probs, enc_lens = model(feat, feat_lens)

                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                    loss /= args.grad_accumulation_steps

                if multi_gpu:
                    reduced_loss = reduce_tensor(loss.data, world_size)
                else:
                    reduced_loss = loss

                if torch.isnan(reduced_loss).any():
                    print_once('WARNING: loss is NaN; skipping update')
                    continue
                else:
                    step_loss += reduced_loss.item()
                    step_utts += batch[0].size(0) * world_size
                    epoch_utts += batch[0].size(0) * world_size
                    accumulated_batches += 1

                    scaler.scale(loss).backward()

            if accumulated_batches % args.grad_accumulation_steps == 0:
                epoch_loss += step_loss
                scaler.step(optimizer)
                scaler.update()

                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()

                apply_ema(model, ema_model, args.ema)

                if step % args.log_frequency == 0:
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens,
                                                    symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f'  Decoded:   {pred_utt[:90]}')
                        print_once(f'  Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log(
                        (epoch, step % steps_per_epoch
                         or steps_per_epoch, steps_per_epoch), step, 'train', {
                             'loss': step_loss,
                             'wer': 100.0 * wer,
                             'throughput': step_utts / step_time,
                             'took': step_time,
                             'lrate': optimizer.param_groups[0]['lr']
                         })

                step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    tik = time.time()
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        checkpointer.save(model,
                                          ema_model,
                                          optimizer,
                                          scaler,
                                          epoch,
                                          step,
                                          best_wer,
                                          is_best=True)
                        best_wer = wer
                    epoch_eval_time += time.time() - tik

                step += 1
                accumulated_batches = 0
                # end of step

            # The DALI iterator needs to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log(
            (epoch, ), None, 'train_avg', {
                'throughput': epoch_utts / epoch_time,
                'took': epoch_time,
                'loss': epoch_loss
            })
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                              best_wer)

        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    if args.pyprof:
        profiler.stop()
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
                 ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

        checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                          best_wer)
    flush_log()
def train(args, trainer, epoch_itr):
    """Train the model for one epoch."""

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()

    # update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    begin = time.time()

    # reset meters
    DLLogger.flush()
    trainer.get_throughput_meter().reset()

    for i, sample in enumerate(itr):
        # Profiling ---------
        if trainer.get_num_updates() == args.profiler_start_iter:
            profiler.start()
        # -------------------

        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample,
                               update_params=False,
                               last_step=(i == len(itr) - 1))
            continue
        else:
            trainer.train_step(sample,
                               update_params=True,
                               last_step=(i == len(itr) - 1))

        # Profiling ---------
        if trainer.get_num_updates() == args.profiler_start_iter + args.profiler_steps:
            profiler.stop()
        # -------------------

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_throughput_meter().reset()
            reset_perf_meters()

        if (i + 1) % args.log_interval == 0:
            DLLogger.flush()

        if trainer.get_num_updates() >= max_update:
            break

    print('Epoch time:', time.time() - begin)

    # Print epoch stats and reset training meters
    DLLogger.log(step=trainer.get_num_updates(),
                 data={'speed': trainer.get_throughput_meter().avg},
                 verbosity=0)
    DLLogger.flush()
Example #11
def train(model, loss_fn, optimizer, data_loader_train, data_loader_test,
          scaled_lr):
    """Train and evaluate the model

    Args:
        model (dlrm): The DLRM model to train
        loss_fn (torch.nn.Module): Loss function
        optimizer (torch.optim.Optimizer): Optimizer for the model parameters
        data_loader_train (torch.utils.data.DataLoader): Training data loader
        data_loader_test (torch.utils.data.DataLoader): Test data loader
        scaled_lr (float): Base learning rate, already scaled for the global batch size
    """
    model.train()
    base_device = FLAGS.base_device
    print_freq = FLAGS.print_freq
    steps_per_epoch = len(data_loader_train)

    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time',
        utils.SmoothedValue(window_size=print_freq, fmt='{avg:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    for epoch in range(FLAGS.epochs):

        batch_iter = iter(data_loader_train)
        for step in range(len(data_loader_train)):
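            # Capture a single CUDA profiler window around step 10 of each epoch
            # (started here, stopped after the same step below).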
            if step == 10:
                profiler.start()

            timer.click()

            global_step = steps_per_epoch * epoch + step

            numerical_features, categorical_features, click = next(batch_iter)

            categorical_features = categorical_features.to(base_device).to(
                torch.long)
            numerical_features = numerical_features.to(base_device)
            click = click.to(base_device).to(torch.float32)

            utils.lr_step(optimizer,
                          num_warmup_iter=FLAGS.warmup_steps,
                          current_step=global_step + 1,
                          base_lr=scaled_lr,
                          warmup_factor=FLAGS.warmup_factor,
                          decay_steps=FLAGS.decay_steps,
                          decay_start_step=FLAGS.decay_start_step)

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    F"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                break

            output = model(numerical_features,
                           categorical_features).squeeze().float()

            loss = loss_fn(output, click.squeeze())

            optimizer.zero_grad()
            if FLAGS.fp16:
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            loss_value = loss.item()

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if global_step < FLAGS.benchmark_warmup_steps:
                metric_logger.update(loss=loss_value,
                                     lr=optimizer.param_groups[0]["lr"])
            else:
                unscale_factor = FLAGS.loss_scale if FLAGS.fp16 else 1
                metric_logger.update(loss=loss_value / unscale_factor,
                                     step_time=timer.measured,
                                     lr=optimizer.param_groups[0]["lr"] *
                                     unscale_factor)

            if step % print_freq == 0 and step > 0:
                if global_step < FLAGS.benchmark_warmup_steps:
                    print(
                        F'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]'
                    )
                    continue

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

            if ((global_step + 1) % test_freq == 0 and global_step > 0
                    and global_step / steps_per_epoch >= FLAGS.test_after):
                loss, auc, test_step_time = evaluate(model, loss_fn,
                                                     data_loader_test)
                print(
                    F"Epoch {epoch} step {step}. Test loss {loss:.5f}, auc {auc:.6f}"
                )

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)
                    maybe_save_checkpoint(model, FLAGS.save_checkpoint_path)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    stop_time = time()
                    run_time_s = int(stop_time - start_time)
                    print(
                        F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    return

            if step == 10:
                profiler.stop()

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    if 'test_step_time' in locals():
        avg_test_throughput = FLAGS.test_batch_size / test_step_time
        results['average_test_throughput'] = avg_test_throughput

    dllogger.log(data=results, step=tuple())
def train_epoch(epoch, args, model, optimizer, lr_scheduler, train_sampler,
                train_loader, v_psnr, v_ssim, v_ie, v_loss, block):
    # Average loss calculator.
    loss_values = utils.AverageMeter()

    # This will ensure the data is shuffled each epoch.
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)

    # Get number of batches in one epoch.
    num_batches = len(train_loader) if args.train_n_batches < 0 \
        else args.train_n_batches

    with torch.autograd.profiler.emit_nvtx():
        global_index = 0
        for i, batch in enumerate(train_loader):
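            # Start the CUDA profiler at the iteration given by args.prof_iter;
            # it is stopped again after this same iteration below.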
            if i == args.prof_iter:
                profiler.start()

            # Set global index.
            global_index = epoch * num_batches + i

            # Move one step.
            loss, outputs, _ = train_step(
                batch, model, optimizer, block, args,
                ((global_index + 1) % args.print_freq == 0))

            # Update the loss accumulator.
            loss_values.update(loss.data.item(), outputs.size(0))

            # Summary writer.
            if (global_index + 1) % args.print_freq == 0:

                # Reduce the loss.
                if args.world_size > 1:
                    t_loss_gpu = torch.Tensor([loss_values.val]).cuda()
                    torch.distributed.all_reduce(t_loss_gpu)
                    t_loss = t_loss_gpu.item() / args.world_size
                else:
                    t_loss = loss_values.val

                # Write to tensorboard.
                write_summary(global_index,
                              lr_scheduler.get_lr()[0], t_loss, v_loss, v_psnr,
                              v_ssim, v_ie, args)

                # And reset the loss accumulator.
                loss_values.reset()

                # Print some output.
                dict2print = {
                    'iter': global_index,
                    'epoch': str(epoch) + '/' + str(args.epochs),
                    'batch': str(i + 1) + '/' + str(num_batches)
                }
                str2print = ' '.join(key + " : " + str(dict2print[key])
                                     for key in dict2print)
                str2print += ' trainLoss:' + ' %1.3f' % t_loss
                str2print += ' valLoss' + ' %1.3f' % v_loss
                str2print += ' valPSNR' + ' %1.3f' % v_psnr
                str2print += ' lr:' + ' %1.6f' % (lr_scheduler.get_lr()[0])
                block.log(str2print)

            if i == args.prof_iter:
                profiler.stop()

            # Break the training loop if we have reached the maximum number of batches.
            if (i + 1) >= num_batches:
                break

        # Advance Learning rate.
        lr_scheduler.step()

        return global_index
def main():

    args = parser.parse_args()
    pyprof.nvtx.init()

    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on device {device}")

    corpus = Corpus(args.data)

    eval_batch_size = 10
    train_data = batchify(corpus.train, args.batch_size, device)
    val_data = batchify(corpus.valid, eval_batch_size, device)
    test_data = batchify(corpus.test, eval_batch_size, device)
    print(f"train_data.shape={train_data.shape}")
    print(f"val_data.shape={val_data.shape}")
    print(f"test_data.shape={test_data.shape}")

    ntokens = len(corpus.dictionary)
    print(f"ntokens={ntokens}")
    # model = model.TransformerModel(
    model = TransformerModel(
        ntokens,
        args.emsize,
        args.nhead,
        args.nhid,
        args.nlayers,
        args.bptt,
        args.dropout,
    ).cuda()
    # ).to(device)
    print(model)
    criterion = nn.CrossEntropyLoss()
    print(criterion)
    print(f"Using tokens={ntokens}, emsize={args.emsize}, nhid={args.emsize}")
    print(f"""ntokens={ntokens}, emsize={args.emsize}, 
    nhead={args.nhead}, nhid={args.nhid}, nlayers={args.nlayers}, 
    bpttt={args.bptt}, dropout={args.dropout}
    """)

    iter_to_capture = 1

    # Loop over epochs.
    lr = args.lr
    best_val_loss = None

    # At any point you can hit Ctrl + C to break out of training early.

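    # Keep NVTX emission active for the whole run; the explicit profiler window
    # below captures a single batch of the first epoch.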
    with torch.autograd.profiler.emit_nvtx():
        for epoch in range(1, args.epochs + 1):
            epoch_start_time = time.time()
            model.train()
            total_loss = 0.0
            start_time = time.time()
            ntokens = len(corpus.dictionary)
            for batch, i in enumerate(
                    range(0,
                          train_data.size(0) - 1, args.bptt)):
                data, targets = get_batch(train_data, i, args)
                # TODO: Use language modelling abstraction with torchtext
                model.zero_grad()
                if (epoch == 1) and (batch == iter_to_capture):
                    profiler.start()
                output = model(data)
                loss = criterion(output.view(-1, ntokens), targets)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                for p in model.parameters():
                    p.data.add_(p.grad.data, alpha=-lr)
                # TODO: Use an optimizer
                if (epoch == 1) and (batch == iter_to_capture):
                    profiler.stop()
                total_loss += loss.item()
                if batch % args.log_interval == 0 and batch > 0:
                    cur_loss = total_loss / args.log_interval
                    elapsed = time.time() - start_time
                    print(
                        "| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | "
                        "loss {:5.2f} | ppl {:8.2f}".format(
                            epoch,
                            batch,
                            len(train_data) // args.bptt,
                            lr,
                            elapsed * 1000 / args.log_interval,
                            cur_loss,
                            math.exp(cur_loss),
                        ))
                    total_loss = 0
                    start_time = time.time()
            val_loss = evaluate(model, val_data, corpus, criterion, args)
            print("-" * 89)
            print(
                "| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | "
                "valid ppl {:8.2f}".format(
                    epoch,
                    (time.time() - epoch_start_time),
                    val_loss,
                    math.exp(val_loss),
                ))
            print("-" * 89)
            # Save the model if the validation loss is the best we've seen so far.
            if not best_val_loss or val_loss < best_val_loss:
                best_val_loss = val_loss
            else:
                # Anneal the learning rate if no improvement has been seen in the validation dataset.
                lr /= 4.0

    # Run on test data.
    test_loss = evaluate(model, test_data, corpus, criterion, args)
    print("=" * 89)
    print("| End of training | test loss {:5.2f} | test ppl {:8.2f}".format(
        test_loss, math.exp(test_loss)))
    print("=" * 89)
Example #14
def main():
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if args.p_arpabet > 0.0:
        cmudict.initialize(args.cmudict_path, keep_ambiguous=True)

    distributed_run = args.world_size > 1

    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    if args.local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')

    logger.init(log_fpath,
                args.output,
                enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    attention_kl_loss = AttentionBinarizationLoss()

    # Store pitch mean/std as params to translate from Hz during inference
    model.pitch_mean[0] = args.pitch_mean
    model.pitch_std[0] = args.pitch_std

    kw = dict(lr=args.learning_rate,
              betas=(0.9, 0.98),
              eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError(f'Unknown optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    start_epoch = [1]
    start_iter = [0]

    assert args.checkpoint_path is None or args.resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(args, model, ema_model, optimizer, scaler, start_epoch,
                        start_iter, model_config, ch_fpath)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    criterion = FastPitchLoss(
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
        attn_loss_scale=args.attn_loss_scale)

    collate_fn = TTSCollate()

    if args.local_rank == 0:
        prepare_tmp(args.pitch_online_dir)

    trainset = TTSDataset(audiopaths_and_text=args.training_files,
                          **vars(args))
    valset = TTSDataset(audiopaths_and_text=args.validation_files,
                        **vars(args))

    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
    train_loader = DataLoader(trainset,
                              num_workers=4,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=True,
                              persistent_workers=True,
                              drop_last=True,
                              collate_fn=collate_fn)

    if args.ema_decay:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    model.train()

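    # With pyprof enabled, keep NVTX emission active for the entire run and open the profiler window right away.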
    if args.pyprof:
        torch.autograd.profiler.emit_nvtx().__enter__()
        profiler.start()

    epoch_loss = []
    epoch_mel_loss = []
    epoch_num_frames = []
    epoch_frames_per_sec = []
    epoch_time = []

    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()

        epoch_loss += [0.0]
        epoch_mel_loss += [0.0]
        epoch_num_frames += [0]
        epoch_frames_per_sec += [0.0]

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}
        iter_start_time = None

        epoch_iter = 0
        num_iters = len(train_loader) // args.grad_accumulation
        for batch in train_loader:

            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                if iter_start_time is None:
                    iter_start_time = time.perf_counter()

                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)

                model.zero_grad(set_to_none=True)

            x, y, num_frames = batch_to_gpu(batch)

            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x)
                loss, meta = criterion(y_pred, y)

                if (args.kl_loss_start_epoch is not None
                        and epoch >= args.kl_loss_start_epoch):

                    if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
                        print('Begin hard_attn loss')

                    _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
                    binarization_loss = attention_kl_loss(attn_hard, attn_soft)
                    kl_weight = min(
                        (epoch - args.kl_loss_start_epoch) /
                        args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
                    meta['kl_loss'] = binarization_loss.clone().detach(
                    ) * kl_weight
                    loss += kl_weight * binarization_loss

                else:
                    meta['kl_loss'] = torch.zeros_like(loss)
                    kl_weight = 0
                    binarization_loss = 0

                loss /= args.grad_accumulation

            meta = {k: v / args.grad_accumulation for k, v in meta.items()}

            if args.amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {
                    k: reduce_tensor(v, args.world_size)
                    for k, v in meta.items()
                }
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accumulated_steps % args.grad_accumulation == 0:

                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)
                    optimizer.step()

                if args.ema_decay > 0.0:
                    apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)

                iter_time = time.perf_counter() - iter_start_time
                iter_mel_loss = iter_meta['mel_loss'].item()
                iter_kl_loss = iter_meta['kl_loss'].item()
                epoch_frames_per_sec[-1] += iter_num_frames / iter_time
                epoch_loss[-1] += iter_loss
                epoch_num_frames[-1] += iter_num_frames
                epoch_mel_loss[-1] += iter_mel_loss

                logger.log(
                    (epoch, epoch_iter, num_iters),
                    tb_total_steps=total_iter,
                    subset='train',
                    data=OrderedDict([
                        ('loss', iter_loss), ('mel_loss', iter_mel_loss),
                        ('kl_loss', iter_kl_loss), ('kl_weight', kl_weight),
                        ('frames/s', iter_num_frames / iter_time),
                        ('took', iter_time),
                        ('lrate', optimizer.param_groups[0]['lr'])
                    ]),
                )

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}
                iter_start_time = time.perf_counter()

        # Finished epoch
        epoch_loss[-1] /= epoch_iter
        epoch_mel_loss[-1] /= epoch_iter
        epoch_time += [time.perf_counter() - epoch_start_time]
        iter_start_time = None

        logger.log(
            (epoch, ),
            tb_total_steps=None,
            subset='train_avg',
            data=OrderedDict([('loss', epoch_loss[-1]),
                              ('mel_loss', epoch_mel_loss[-1]),
                              ('frames/s',
                               epoch_num_frames[-1] / epoch_time[-1]),
                              ('took', epoch_time[-1])]),
        )

        validate(model, epoch, total_iter, criterion, valset, args.batch_size,
                 collate_fn, distributed_run, batch_to_gpu)

        if args.ema_decay > 0:
            validate(ema_model,
                     epoch,
                     total_iter,
                     criterion,
                     valset,
                     args.batch_size,
                     collate_fn,
                     distributed_run,
                     batch_to_gpu,
                     ema=True)

        maybe_save_checkpoint(args, model, ema_model, optimizer, scaler, epoch,
                              total_iter, model_config)
        logger.flush()

    # Finished training
    if args.pyprof:
        profiler.stop()
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)

    if len(epoch_loss) > 0:
        # Was trained - average the last 20 measurements
        last_ = lambda l: np.asarray(l[-20:])
        epoch_loss = last_(epoch_loss)
        epoch_mel_loss = last_(epoch_mel_loss)
        epoch_num_frames = last_(epoch_num_frames)
        epoch_time = last_(epoch_time)
        logger.log(
            (),
            tb_total_steps=None,
            subset='train_avg',
            data=OrderedDict([('loss', epoch_loss.mean()),
                              ('mel_loss', epoch_mel_loss.mean()),
                              ('frames/s',
                               epoch_num_frames.sum() / epoch_time.sum()),
                              ('took', epoch_time.mean())]),
        )

    validate(model, None, total_iter, criterion, valset, args.batch_size,
             collate_fn, distributed_run, batch_to_gpu)
def main():
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    distributed_run = args.world_size > 1

    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    if args.local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')

    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    # Store pitch mean/std as params to translate from Hz during inference
    with open(args.pitch_mean_std_file, 'r') as f:
        stats = json.load(f)
    model.pitch_mean[0] = stats['mean']
    model.pitch_std[0] = stats['std']

    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError(f'Unknown optimizer: {args.optimizer}')

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True)

    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    start_epoch = [1]
    start_iter = [0]

    assert args.checkpoint_path is None or args.resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(args.local_rank, model, ema_model, optimizer, start_epoch,
                        start_iter, model_config, args.amp, ch_fpath,
                        args.world_size)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    criterion = loss_functions.get_loss_function('FastPitch',
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale)

    collate_fn = data_functions.get_collate_function('FastPitch')
    trainset = data_functions.get_data_loader('FastPitch',
                                              audiopaths_and_text=args.training_files,
                                              **vars(args))
    valset = data_functions.get_data_loader('FastPitch',
                                            audiopaths_and_text=args.validation_files,
                                            **vars(args))
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(trainset, num_workers=16, shuffle=shuffle,
                              sampler=train_sampler, batch_size=args.batch_size,
                              pin_memory=False, drop_last=True,
                              collate_fn=collate_fn)

    batch_to_gpu = data_functions.get_batch_to_gpu('FastPitch')

    if args.ema_decay:
        ema_model_weight_list, model_weight_list, overflow_buf_for_ema = init_multi_tensor_ema(model, ema_model)
    else:
        ema_model_weight_list, model_weight_list, overflow_buf_for_ema = None, None, None

    model.train()

    if args.pyprof:
        emit_nvtx_ctx = torch.autograd.profiler.emit_nvtx()
        emit_nvtx_ctx.__enter__()
        profiler.start()

    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()

        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}

        epoch_iter = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps
        for batch in train_loader:

            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                iter_start_time = time.perf_counter()

                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)

                model.zero_grad()

            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x, use_gt_durations=True)
            loss, meta = criterion(y_pred, y)

            loss /= args.gradient_accumulation_steps
            meta = {k: v / args.gradient_accumulation_steps
                    for k, v in meta.items()}

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {k: reduce_tensor(v, args.world_size) for k,v in meta.items()}
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accumulated_steps % args.gradient_accumulation_steps == 0:

                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.grad_clip_thresh)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)

                optimizer.step()
                if args.ema_decay > 0:
                    apply_multi_tensor_ema(model_weight_list, ema_model_weight_list, args.ema_decay, overflow_buf_for_ema)

                iter_time = time.perf_counter() - iter_start_time
                iter_mel_loss = iter_meta['mel_loss'].item()
                epoch_frames_per_sec += iter_num_frames / iter_time
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                epoch_mel_loss += iter_mel_loss

                logger.log((epoch, epoch_iter, num_iters),
                           tb_total_steps=total_iter,
                           subset='train',
                           data=OrderedDict([
                               ('loss', iter_loss),
                               ('mel_loss', iter_mel_loss),
                               ('frames/s', iter_num_frames / iter_time),
                               ('took', iter_time),
                               ('lrate', optimizer.param_groups[0]['lr'])]),
                )

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}

        # Finished epoch
        epoch_time = time.perf_counter() - epoch_start_time

        logger.log((epoch,),
                   tb_total_steps=None,
                   subset='train_avg',
                   data=OrderedDict([
                       ('loss', epoch_loss / epoch_iter),
                       ('mel_loss', epoch_mel_loss / epoch_iter),
                       ('frames/s', epoch_num_frames / epoch_time),
                       ('took', epoch_time)]),
        )

        validate(model, epoch, total_iter, criterion, valset, args.batch_size,
                 collate_fn, distributed_run, batch_to_gpu,
                 use_gt_durations=True)

        if args.ema_decay > 0:
            validate(ema_model, epoch, total_iter, criterion, valset,
                     args.batch_size, collate_fn, distributed_run, batch_to_gpu,
                     use_gt_durations=True, ema=True)

        if (epoch > 0 and args.epochs_per_checkpoint > 0 and
            (epoch % args.epochs_per_checkpoint == 0) and args.local_rank == 0):

            checkpoint_path = os.path.join(
                args.output, f"FastPitch_checkpoint_{epoch}.pt")
            save_checkpoint(args.local_rank, model, ema_model, optimizer, epoch,
                            total_iter, model_config, args.amp, checkpoint_path)
        logger.flush()

    # Finished training
    if args.pyprof:
        profiler.stop()
        emit_nvtx_ctx.__exit__(None, None, None)

    logger.log((),
               tb_total_steps=None,
               subset='train_avg',
               data=OrderedDict([
                   ('loss', epoch_loss / epoch_iter),
                   ('mel_loss', epoch_mel_loss / epoch_iter),
                   ('frames/s', epoch_num_frames / epoch_time),
                   ('took', epoch_time)]),
    )
    validate(model, None, total_iter, criterion, valset, args.batch_size,
             collate_fn, distributed_run, batch_to_gpu, use_gt_durations=True)

    if (epoch > 0 and args.epochs_per_checkpoint > 0 and
        (epoch % args.epochs_per_checkpoint != 0) and args.local_rank == 0):
        checkpoint_path = os.path.join(
            args.output, f"FastPitch_checkpoint_{epoch}.pt")
        save_checkpoint(args.local_rank, model, ema_model, optimizer, epoch,
                        total_iter, model_config, args.amp, checkpoint_path)
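
A short, hedged sketch of the profiling pattern used throughout the FastPitch example above: pyprof.init() registers NVTX wrappers, torch.autograd.profiler.emit_nvtx() annotates autograd operators, and torch.cuda.profiler.start()/stop() delimit the capture window for an external profiler. The toy model and the step indices below are illustrative, and the snippet assumes the (now archived) NVIDIA PyProf package is installed and the script is launched under a profiler that honours the CUDA profiler API, e.g. nvprof --profile-from-start off or nsys profile --capture-range=cudaProfilerApi.

import torch
import torch.cuda.profiler as profiler
import pyprof  # assumption: the archived NVIDIA PyProf package is available

pyprof.init()

model = torch.nn.Linear(256, 256).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

with torch.autograd.profiler.emit_nvtx():
    for step in range(20):
        if step == 5:
            profiler.start()   # open the capture window after a short warm-up
        x = torch.randn(64, 256, device='cuda')
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == 10:
            profiler.stop()    # close the window; later steps are not captured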
Example #16
    # Start profiler
    profiler.start()

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            #self.conv1 = GINConv(dataset.num_node_features, dataset.num_classes)
            #self.conv2 = SAGEConv(16, dataset.num_classes)
            nn1 = Sequential(Linear(dataset.num_node_features, 16), ReLU(),
                             Linear(16, dataset.num_classes))
            #nn2 = Sequential(Linear)
            self.conv1 = GINConv(nn1)

        def forward(self, data):
            x, edge_index = data.x, data.edge_index

            x = self.conv1(x, edge_index)
            #x = F.relu(x)
            #x = F.dropout(x, training=self.training)
            #x = self.conv2(x, edge_index)
            #return x
            return F.log_softmax(x, dim=1)

    device = torch.device('cuda')
    model = Net().to(device)

    data = dataset[0].to(device)
    out = model(data)

    profiler.stop()
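
The GINConv snippet above only runs a single forward pass inside the profiled region. For reference, a minimal training sketch for the same one-layer GIN model is given below; it assumes a torch_geometric Planetoid-style dataset (Cora is used purely for illustration) that exposes data.train_mask and data.y, and the optimizer settings are illustrative rather than taken from the original script.

import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GINConv


class GIN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        mlp = Sequential(Linear(num_features, 16), ReLU(),
                         Linear(16, num_classes))
        self.conv1 = GINConv(mlp)

    def forward(self, data):
        x = self.conv1(data.x, data.edge_index)
        return F.log_softmax(x, dim=1)


dataset = Planetoid(root='/tmp/Cora', name='Cora')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GIN(dataset.num_node_features, dataset.num_classes).to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()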
Example #17
def main(args):
    exp_start_time = time.time()
    global best_prec1
    best_prec1 = 0

    args.distributed = False
    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.local_rank = int(os.environ["LOCAL_RANK"])

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend="nccl", init_method="env://")
        args.world_size = torch.distributed.get_world_size()

    if args.amp and args.fp16:
        print("Please use only one of the --fp16/--amp flags")
        exit(1)

    if args.seed is not None:
        print("Using seed = {}".format(args.seed))
        torch.manual_seed(args.seed + args.local_rank)
        torch.cuda.manual_seed(args.seed + args.local_rank)
        np.random.seed(seed=args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)

        def _worker_init_fn(id):
            np.random.seed(seed=args.seed + args.local_rank + id)
            random.seed(args.seed + args.local_rank + id)

    else:

        def _worker_init_fn(id):
            pass

    if args.fp16:
        assert (torch.backends.cudnn.enabled
                ), "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    if args.optimizer_batch_size < 0:
        batch_size_multiplier = 1
    else:
        tbs = args.world_size * args.batch_size
        if args.optimizer_batch_size % tbs != 0:
            print(
                "Warning: simulated batch size {} is not divisible by actual batch size {}"
                .format(args.optimizer_batch_size, tbs))
        batch_size_multiplier = int(args.optimizer_batch_size / tbs)
        print("BSM: {}".format(batch_size_multiplier))

    pretrained_weights = None
    if args.pretrained_weights:
        if os.path.isfile(args.pretrained_weights):
            print("=> loading pretrained weights from '{}'".format(
                args.pretrained_weights))
            pretrained_weights = torch.load(args.pretrained_weights)
        else:
            print("=> no pretrained weights found at '{}'".format(args.resume))

    start_epoch = 0
    # optionally resume from a checkpoint
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(args.gpu))
            start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model_state = checkpoint["state_dict"]
            optimizer_state = checkpoint["optimizer"]
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            model_state = None
            optimizer_state = None
    else:
        model_state = None
        optimizer_state = None

    loss = nn.CrossEntropyLoss
    if args.mixup > 0.0:
        loss = lambda: NLLMultiLabelSmooth(args.label_smoothing)
    elif args.label_smoothing > 0.0:
        loss = lambda: LabelSmoothing(args.label_smoothing)

    memory_format = (torch.channels_last if args.memory_format == "nhwc" else
                     torch.contiguous_format)

    model_and_loss = ModelAndLoss(
        (args.arch, args.model_config, args.num_classes),
        loss,
        pretrained_weights=pretrained_weights,
        cuda=True,
        fp16=args.fp16,
        memory_format=memory_format,
    )
    if args.sync_batch_norm:
        model_and_loss = apex.parallel.convert_syncbn_model(model_and_loss)

    # Create data loaders and optimizers as needed
    if args.data_backend == "pytorch":
        get_train_loader = get_pytorch_train_loader
        get_val_loader = get_pytorch_val_loader
    elif args.data_backend == "dali-gpu":
        get_train_loader = get_dali_train_loader(dali_cpu=False)
        get_val_loader = get_dali_val_loader()
    elif args.data_backend == "dali-cpu":
        get_train_loader = get_dali_train_loader(dali_cpu=True)
        get_val_loader = get_dali_val_loader()
    elif args.data_backend == "syntetic":
        get_val_loader = get_syntetic_loader
        get_train_loader = get_syntetic_loader

    train_loader, train_loader_len = get_train_loader(
        args.data,
        args.batch_size,
        args.num_classes,
        args.mixup > 0.0,
        start_epoch=start_epoch,
        workers=args.workers,
        fp16=args.fp16,
        memory_format=memory_format,
    )
    if args.mixup != 0.0:
        train_loader = MixUpWrapper(args.mixup, train_loader)

    val_loader, val_loader_len = get_val_loader(
        args.data,
        args.batch_size,
        args.num_classes,
        False,
        workers=args.workers,
        fp16=args.fp16,
        memory_format=memory_format,
    )

    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        logger = log.Logger(
            args.print_freq,
            [
                dllogger.StdOutBackend(dllogger.Verbosity.DEFAULT,
                                       step_format=log.format_step),
                dllogger.JSONStreamBackend(
                    dllogger.Verbosity.VERBOSE,
                    os.path.join(args.workspace, args.raport_file),
                ),
            ],
            start_epoch=start_epoch - 1,
        )

    else:
        logger = log.Logger(args.print_freq, [], start_epoch=start_epoch - 1)

    logger.log_parameter(args.__dict__, verbosity=dllogger.Verbosity.DEFAULT)

    optimizer = get_optimizer(
        list(model_and_loss.model.named_parameters()),
        args.fp16,
        args.lr,
        args.momentum,
        args.weight_decay,
        nesterov=args.nesterov,
        bn_weight_decay=args.bn_weight_decay,
        state=optimizer_state,
        static_loss_scale=args.static_loss_scale,
        dynamic_loss_scale=args.dynamic_loss_scale,
    )

    if args.lr_schedule == "step":
        lr_policy = lr_step_policy(args.lr, [30, 60, 80],
                                   0.1,
                                   args.warmup,
                                   logger=logger)
    elif args.lr_schedule == "cosine":
        lr_policy = lr_cosine_policy(args.lr,
                                     args.warmup,
                                     args.epochs,
                                     logger=logger)
    elif args.lr_schedule == "linear":
        lr_policy = lr_linear_policy(args.lr,
                                     args.warmup,
                                     args.epochs,
                                     logger=logger)

    if args.amp:
        model_and_loss, optimizer = amp.initialize(
            model_and_loss,
            optimizer,
            opt_level="O1",
            loss_scale="dynamic"
            if args.dynamic_loss_scale else args.static_loss_scale,
        )

    if args.distributed:
        model_and_loss.distributed()

    model_and_loss.load_model_state(model_state)

    profiler.start()
    train_loop(
        model_and_loss,
        optimizer,
        lr_policy,
        train_loader,
        val_loader,
        args.fp16,
        logger,
        should_backup_checkpoint(args),
        use_amp=args.amp,
        batch_size_multiplier=batch_size_multiplier,
        start_epoch=start_epoch,
        end_epoch=(start_epoch +
                   args.run_epochs) if args.run_epochs != -1 else args.epochs,
        best_prec1=best_prec1,
        prof=args.prof,
        skip_training=args.evaluate,
        skip_validation=args.training_only,
        save_checkpoints=args.save_checkpoints and not args.evaluate,
        checkpoint_dir=args.workspace,
        checkpoint_filename=args.checkpoint_filename,
    )
    profiler.stop()
    exp_duration = time.time() - exp_start_time
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        logger.end()
    print("Experiment ended")
    print(exp_duration)
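
The image-classification example above chooses between step, cosine and linear learning-rate policies. As a reference point, here is a minimal sketch of a cosine schedule with linear warmup in the spirit of lr_cosine_policy; the function name and the exact warmup handling are illustrative and may differ from the repository's implementation.

import math


def cosine_lr_with_warmup(base_lr, warmup_epochs, total_epochs, epoch):
    # Linear warmup to base_lr, then cosine decay towards zero.
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


# Usage sketch: set the learning rate at the start of each epoch.
# for epoch in range(args.epochs):
#     lr = cosine_lr_with_warmup(args.lr, args.warmup, args.epochs, epoch)
#     for group in optimizer.param_groups:
#         group['lr'] = lr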
Example #18
def main():
    args = parse_args()
    init_distributed(args)

    if args.local_rank == 0:
        dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                                           filename=args.log_path),
                                dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)])
    else:
        dllogger.init(backends=[])

    dllogger.log(data=vars(args), step='PARAMETER')

    if args.seed is not None:
        torch.manual_seed(args.seed)

    print("Saving results to {}".format(args.checkpoint_dir))
    if not os.path.exists(args.checkpoint_dir) and args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    # sync workers before timing
    if args.distributed:
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
    torch.cuda.synchronize()

    main_start_time = time.time()

    train_ratings = torch.load(args.data+'/train_ratings.pt', map_location=torch.device('cuda:{}'.format(args.local_rank)))
    test_ratings = torch.load(args.data+'/test_ratings.pt', map_location=torch.device('cuda:{}'.format(args.local_rank)))
    test_negs = torch.load(args.data+'/test_negatives.pt', map_location=torch.device('cuda:{}'.format(args.local_rank)))

    valid_negative = test_negs.shape[1]

    nb_maxs = torch.max(train_ratings, 0)[0]
    nb_users = nb_maxs[0].item() + 1
    nb_items = nb_maxs[1].item() + 1

    all_test_users = test_ratings.shape[0]

    test_users, test_items, dup_mask, real_indices = dataloading.create_test_data(test_ratings, test_negs, args)

    # make pytorch memory behavior more consistent later
    torch.cuda.empty_cache()

    # Create model
    model = NeuMF(nb_users, nb_items,
                  mf_dim=args.factors,
                  mlp_layer_sizes=args.layers,
                  dropout=args.dropout)

    optimizer = FusedAdam(model.parameters(), lr=args.learning_rate,
                          betas=(args.beta1, args.beta2), eps=args.eps)

    criterion = nn.BCEWithLogitsLoss(reduction='none') # use torch.mean() with dim later to avoid copy to host
    # Move model and loss to GPU
    model = model.cuda()
    criterion = criterion.cuda()

    local_batch = args.batch_size // args.world_size
    #traced_criterion = torch.jit.trace(criterion.forward,
    #                                   (torch.rand(local_batch,1),torch.rand(local_batch,1)))
    traced_criterion = criterion

    pyprof.init()
    #import importlib
    #pyprof.wrap(importlib.import_module(__name__), "traced_criterion")
    #pyprof.wrap(traced_criterion, "__call__")

    if args.opt_level == "O2":
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level,
                                          keep_batchnorm_fp32=False, loss_scale='dynamic')

    if args.distributed:
        model = DDP(model)

    print(model)
    print("{} parameters".format(utils.count_parameters(model)))

    if args.load_checkpoint_path:
        state_dict = torch.load(args.load_checkpoint_path)
        state_dict = {k.replace('module.', '') : v for k,v in state_dict.items()}
        model.load_state_dict(state_dict)

    if args.mode == 'test':
        start = time.time()
        hr, ndcg = val_epoch(model, test_users, test_items, dup_mask, real_indices, args.topk,
                             samples_per_user=valid_negative + 1,
                             num_user=all_test_users, distributed=args.distributed)
        val_time = time.time() - start
        eval_size = all_test_users * (valid_negative + 1)
        eval_throughput = eval_size / val_time

        dllogger.log(step=tuple(), data={'best_eval_throughput' : eval_throughput,
                                         'hr@10' : hr})
        return
    
    max_hr = 0
    best_epoch = 0
    train_throughputs, eval_throughputs = [], []

    with torch.autograd.profiler.emit_nvtx():

        for epoch in range(args.epochs):
        
            begin = time.time()
        
            epoch_users, epoch_items, epoch_label = dataloading.prepare_epoch_train_data(train_ratings, nb_items, args)
            num_batches = len(epoch_users)
            for i in range(num_batches // args.grads_accumulated):

                if i == 10:
                    profiler.start()

                for j in range(args.grads_accumulated):
                    batch_idx = (args.grads_accumulated * i) + j
                    user = epoch_users[batch_idx]
                    item = epoch_items[batch_idx]
                    label = epoch_label[batch_idx].view(-1,1)
        
                    outputs = model(user, item)
                    nvtx.range_push("layer:Loss")
                    loss = traced_criterion(outputs, label).float()
                    nvtx.range_pop()
                    nvtx.range_push("layer:Mean")
                    loss = torch.mean(loss.view(-1), 0)
                    nvtx.range_pop()
        
                    if args.opt_level == "O2":
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                nvtx.range_push("layer:Adam")
                optimizer.step()
                nvtx.range_pop()

                if i == 10:
                    profiler.stop()
        
                for p in model.parameters():
                    p.grad = None
        
            del epoch_users, epoch_items, epoch_label
            train_time = time.time() - begin
            begin = time.time()
        
            epoch_samples = len(train_ratings) * (args.negative_samples + 1)
            train_throughput = epoch_samples / train_time
            train_throughputs.append(train_throughput)
        
            hr, ndcg = val_epoch(model, test_users, test_items, dup_mask, real_indices, args.topk,
                                 samples_per_user=valid_negative + 1,
                                 num_user=all_test_users, epoch=epoch, distributed=args.distributed)
        
            val_time = time.time() - begin
        
        
            eval_size = all_test_users * (valid_negative + 1)
            eval_throughput = eval_size / val_time
            eval_throughputs.append(eval_throughput)
        
            dllogger.log(step=(epoch,),
                         data = {'train_throughput': train_throughput,
                                 'hr@10': hr,
                                 'train_epoch_time': train_time,
                                 'validation_epoch_time': val_time,
                                 'eval_throughput': eval_throughput})
        
            if hr > max_hr and args.local_rank == 0:
                max_hr = hr
                best_epoch = epoch
                save_checkpoint_path = os.path.join(args.checkpoint_dir, 'model.pth')
                print("New best hr! Saving the model to: ", save_checkpoint_path)
                torch.save(model.state_dict(), save_checkpoint_path)
                best_model_timestamp = time.time()
        
            if args.threshold is not None:
                if hr >= args.threshold:
                    print("Hit threshold of {}".format(args.threshold))
                    break

    if args.local_rank == 0:
        dllogger.log(data={'best_train_throughput': max(train_throughputs),
                           'best_eval_throughput': max(eval_throughputs),
                           'mean_train_throughput': np.mean(train_throughputs),
                           'mean_eval_throughput': np.mean(eval_throughputs),
                           'best_accuracy': max_hr,
                           'best_epoch': best_epoch,
                           'time_to_target': time.time() - main_start_time,
                           'time_to_best_model': best_model_timestamp - main_start_time},
                     step=tuple())
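
The NCF example above wraps individual layers in NVTX ranges (nvtx.range_push / nvtx.range_pop) so that the loss, the mean reduction and the optimizer step appear as named regions on the profiler timeline. A minimal, self-contained sketch of that annotation pattern follows; the range names and the toy computation are illustrative.

import torch
import torch.cuda.nvtx as nvtx

x = torch.randn(1024, 1024, device='cuda')
w = torch.randn(1024, 1024, device='cuda')

nvtx.range_push("layer:MatMul")
y = x @ w
nvtx.range_pop()

nvtx.range_push("layer:Mean")
out = y.mean()
nvtx.range_pop()

torch.cuda.synchronize()
print(out.item())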