Example 1
 def evaluate(self, model: WaveRNN, val_set: Dataset) -> float:
     model.eval()
     val_loss = 0
     device = next(model.parameters()).device
     for i, (x, y, m) in enumerate(val_set, 1):
         x, m, y = x.to(device), m.to(device), y.to(device)
         with torch.no_grad():
             y_hat = model(x, m)
             if model.mode == 'RAW':
                 y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
             elif model.mode == 'MOL':
                 y = y.float()
             y = y.unsqueeze(-1)
             loss = self.loss_func(y_hat, y)
             val_loss += loss.item()
     return val_loss / len(val_set)
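Note: the RAW-mode transpose/unsqueeze above exists because F.cross_entropy expects logits shaped (N, C, d1, ...) and integer targets shaped (N, d1, ...). A minimal shape check, with illustrative dimensions (B, T, C standing for batch, timesteps, and 2**bits classes; none of these values come from the source):

import torch
import torch.nn.functional as F

B, T, C = 2, 8, 512                       # illustrative: batch, timesteps, 2**bits classes
y_hat = torch.randn(B, T, C)              # RAW-mode model output: (B, T, C)
y = torch.randint(0, C, (B, T))           # integer sample labels: (B, T)

# cross_entropy wants logits (N, C, d1, ...) and targets (N, d1, ...),
# hence the transpose to (B, C, T) plus the trailing unsqueeze on both.
loss = F.cross_entropy(y_hat.transpose(1, 2).unsqueeze(-1), y.unsqueeze(-1))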
Example 2
 def evaluate(self, model: WaveRNN, val_set: Dataset) -> float:
     model.eval()
     val_loss = 0
     device = next(model.parameters()).device
     for i, batch in enumerate(val_set, 1):
         batch = to_device(batch, device=device)
         x, y, m = batch['x'], batch['y'], batch['mel']
         with torch.no_grad():
             y_hat = model(x, m)
             if model.mode == 'RAW':
                 y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
             elif model.mode == 'MOL':
                 y = y.float()
             y = y.unsqueeze(-1)
             loss = self.loss_func(y_hat, y)
             val_loss += loss.item()
     return val_loss / len(val_set)
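Example 2 relies on a to_device helper that is not shown. A minimal sketch, assuming batches are flat dicts of tensors (the real helper may also handle nested structures):

import torch

def to_device(batch: dict, device: torch.device) -> dict:
    # Move every tensor in the batch dict to the target device;
    # pass non-tensor entries through unchanged.
    return {k: v.to(device) if torch.is_tensor(v) else v
            for k, v in batch.items()}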
Example 3
    def generate_samples(self, model: WaveRNN,
                         session: VocSession) -> Tuple[float, list]:
        """
        Generates audio samples to cherry-pick models. To evaluate audio quality
        we calculate the l1 distance between mels of predictions and targets.
        """
        model.eval()
        mel_losses = []
        gen_wavs = []
        device = next(model.parameters()).device
        for i, sample in enumerate(session.val_set_samples, 1):
            m, x = sample['mel'], sample['x']
            if i > self.train_cfg['num_gen_samples']:
                break
            x = x[0].numpy()
            bits = 16 if self.dsp.voc_mode == 'MOL' else self.dsp.bits
            if self.dsp.mu_law and self.dsp.voc_mode != 'MOL':
                x = DSP.decode_mu_law(x, 2**bits, from_labels=True)
            else:
                x = DSP.label_2_float(x, bits)
            gen_wav = model.generate(mels=m,
                                     batched=self.train_cfg['gen_batched'],
                                     target=self.train_cfg['target'],
                                     overlap=self.train_cfg['overlap'],
                                     mu_law=self.dsp.mu_law,
                                     silent=True)

            gen_wavs.append(gen_wav)
            y_mel = self.dsp.wav_to_mel(x.squeeze(), normalize=False)
            y_mel = torch.tensor(y_mel).to(device)
            y_hat_mel = self.dsp.wav_to_mel(gen_wav, normalize=False)
            y_hat_mel = torch.tensor(y_hat_mel).to(device)
            loss = F.l1_loss(y_hat_mel, y_mel)
            mel_losses.append(loss.item())

            self.writer.add_audio(tag=f'Validation_Samples/target_{i}',
                                  snd_tensor=x,
                                  global_step=model.step,
                                  sample_rate=self.dsp.sample_rate)
            self.writer.add_audio(tag=f'Validation_Samples/generated_{i}',
                                  snd_tensor=gen_wav,
                                  global_step=model.step,
                                  sample_rate=self.dsp.sample_rate)

        return sum(mel_losses) / len(mel_losses), gen_wavs[0]
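Example 3 decodes the integer targets with DSP.decode_mu_law / DSP.label_2_float (Example 4 below calls free-function versions of the same utilities), neither of which is shown here. A sketch of what these typically look like in WaveRNN codebases (standard mu-law expansion; exact signatures may differ per repo):

import numpy as np

def label_2_float(x, bits):
    # Map integer labels in [0, 2**bits - 1] to floats in [-1, 1].
    return 2 * x / (2**bits - 1.) - 1.

def decode_mu_law(y, mu, from_labels=True):
    # Inverse mu-law companding: expand compressed values back to linear audio.
    if from_labels:
        y = label_2_float(y, int(np.log2(mu)))
    mu = mu - 1
    return np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)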
Example 4
    def generate_samples(self, model: WaveRNN,
                         session: VocSession) -> Tuple[float, list]:
        """
        Generates audio samples to cherry-pick models. To evaluate audio quality
        we calculate the l1 distance between mels of predictions and targets.
        """
        model.eval()
        mel_losses = []
        gen_wavs = []
        device = next(model.parameters()).device
        for i, (m, x) in enumerate(session.val_set_samples, 1):
            if i > hp.voc_gen_num_samples:
                break
            x = x[0].numpy()
            bits = 16 if hp.voc_mode == 'MOL' else hp.bits
            if hp.mu_law and hp.voc_mode != 'MOL':
                x = decode_mu_law(x, 2**bits, from_labels=True)
            else:
                x = label_2_float(x, bits)
            gen_wav = model.generate(mels=m,
                                     save_path=None,
                                     batched=hp.voc_gen_batched,
                                     target=hp.voc_target,
                                     overlap=hp.voc_overlap,
                                     mu_law=hp.mu_law,
                                     silent=True)

            gen_wavs.append(gen_wav)
            y_mel = raw_melspec(x.squeeze())
            y_mel = torch.tensor(y_mel).to(device)
            y_hat_mel = raw_melspec(gen_wav)
            y_hat_mel = torch.tensor(y_hat_mel).to(device)
            loss = F.l1_loss(y_hat_mel, y_mel)
            mel_losses.append(loss.item())

            self.writer.add_audio(tag=f'Validation_Samples/target_{i}',
                                  snd_tensor=x,
                                  global_step=model.step,
                                  sample_rate=hp.sample_rate)
            self.writer.add_audio(tag=f'Validation_Samples/generated_{i}',
                                  snd_tensor=gen_wav,
                                  global_step=model.step,
                                  sample_rate=hp.sample_rate)

        return sum(mel_losses) / len(mel_losses), gen_wavs[0]
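raw_melspec in Example 4 is likewise undefined in this snippet; a plausible stand-in via librosa (all parameter values illustrative, not taken from the source):

import librosa
import numpy as np

def raw_melspec(wav, sr=22050, n_fft=2048, hop_length=275, n_mels=80):
    # Un-normalized log-mel spectrogram; the clip avoids log(0).
    m = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=n_fft,
                                       hop_length=hop_length, n_mels=n_mels)
    return np.log(np.clip(m, a_min=1e-5, a_max=None))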
Example 5
def main():

    # Parse Arguments
    parser = argparse.ArgumentParser(description='Train WaveRNN Vocoder')
    parser.add_argument('--lr',
                        '-l',
                        type=float,
                        help='[float] override hparams.py learning rate')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        help='[int] override hparams.py batch size')
    parser.add_argument('--force_train',
                        '-f',
                        action='store_true',
                        help='Forces the model to train past total steps')
    parser.add_argument('--gta',
                        '-g',
                        action='store_true',
                        help='train wavernn on GTA features')
    parser.add_argument(
        '--force_cpu',
        '-c',
        action='store_true',
        help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file',
                        metavar='FILE',
                        default='hparams.py',
                        help='The file to use for the hyperparameters')
    args = parser.parse_args()

    hp.configure(args.hp_file)  # load hparams from file
    if args.lr is None:
        args.lr = hp.voc_lr
    if args.batch_size is None:
        args.batch_size = hp.voc_batch_size

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    batch_size = args.batch_size
    force_train = args.force_train
    train_gta = args.gta
    lr = args.lr

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
        if batch_size % torch.cuda.device_count() != 0:
            raise ValueError(
                '`batch_size` must be evenly divisible by n_gpus!')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    print('\nInitialising Model...\n')

    # Instantiate WaveRNN Model
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).to(device)

    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    optimizer = optim.Adam(voc_model.parameters())
    restore_checkpoint('voc',
                       paths,
                       voc_model,
                       optimizer,
                       create_if_missing=True)

    train_set, test_set = get_vocoder_datasets(paths.data, batch_size,
                                               train_gta)

    total_steps = 10_000_000 if force_train else hp.voc_total_steps

    simple_table([
        ('Remaining', str(
            (total_steps - voc_model.get_step()) // 1000) + 'k Steps'),
        ('Batch Size', batch_size), ('LR', lr),
        ('Sequence Len', hp.voc_seq_len), ('GTA Train', train_gta)
    ])

    loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss

    voc_train_loop(paths, voc_model, loss_func, optimizer, train_set, test_set,
                   lr, total_steps)

    print('Training Complete.')
    print(
        'To continue training increase voc_total_steps in hparams.py or use --force_train'
    )
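The assert in Example 5 verifies that the conditioning network's upsampling stages multiply out to exactly one hop of audio samples per mel frame. A quick standalone check with illustrative values:

import numpy as np

voc_upsample_factors = (5, 5, 11)   # illustrative stage factors
hop_length = 275
# The cumulative product of the stage factors must end at hop_length.
assert np.cumprod(voc_upsample_factors)[-1] == hop_length   # 5 * 5 * 11 == 275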
Example 6
def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer,
                   train_set, test_set, lr, total_steps):
    # Use same device as model parameters
    device = next(model.parameters()).device

    for g in optimizer.param_groups:
        g['lr'] = lr

    total_iters = len(train_set)
    epochs = (total_steps - model.get_step()) // total_iters + 1

    for e in range(1, epochs + 1):

        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(train_set, 1):
            x, m, y = x.to(device), m.to(device), y.to(device)

            # Parallelize the model across GPUs using a workaround for a Python GC bug
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                y_hat = data_parallel_workaround(model, x, m)
            else:
                y_hat = model(x, m)

            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)

            elif model.mode == 'MOL':
                y = y.float()

            y = y.unsqueeze(-1)

            loss = loss_func(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            if hp.voc_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.voc_clip_grad_norm)
                if np.isnan(float(grad_norm)):
                    print('grad_norm was NaN!')
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / i

            speed = i / (time.time() - start)

            step = model.get_step()
            k = step // 1000

            if step % hp.voc_checkpoint_every == 0:
                gen_testset(model, test_set, hp.voc_gen_at_checkpoint,
                            hp.voc_gen_batched, hp.voc_target, hp.voc_overlap,
                            paths.voc_output)
                ckpt_name = f'wave_step{k}K'
                save_checkpoint('voc',
                                paths,
                                model,
                                optimizer,
                                name=ckpt_name,
                                is_silent=True)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            stream(msg)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('voc', paths, model, optimizer, is_silent=True)
        model.log(paths.voc_log, msg)
        print(' ')
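data_parallel_workaround in Example 6 is not defined in this snippet. In the WaveRNN codebase it is roughly a manual DataParallel forward pass that keeps references to the replicas and outputs alive between calls (the garbage-collection workaround mentioned in the comment); a sketch under that assumption:

import torch

_output_ref = None
_replicas_ref = None

def data_parallel_workaround(model, *inputs):
    # Manual replicate / scatter / apply / gather across all visible GPUs,
    # holding refs to outputs and replicas so they are not collected early.
    global _output_ref, _replicas_ref
    device_ids = list(range(torch.cuda.device_count()))
    replicas = torch.nn.parallel.replicate(model, device_ids)
    scattered = torch.nn.parallel.scatter(inputs, device_ids)
    replicas = replicas[:len(scattered)]
    outputs = torch.nn.parallel.parallel_apply(replicas, scattered)
    y_hat = torch.nn.parallel.gather(outputs, device_ids[0])
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat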
Example 7
    def train_session(self, model: WaveRNN, optimizer: Optimizer,
                      session: VocSession, train_gta: bool) -> None:
        current_step = model.get_step()
        training_steps = session.max_step - current_step
        total_iters = len(session.train_set)
        epochs = training_steps // total_iters + 1
        simple_table([('Steps', str(training_steps // 1000) + 'k'),
                      ('Batch Size', session.bs),
                      ('Learning Rate', session.lr),
                      ('Sequence Length', self.train_cfg['seq_len']),
                      ('GTA Training', train_gta)])
        for g in optimizer.param_groups:
            g['lr'] = session.lr

        loss_avg = Averager()
        duration_avg = Averager()
        device = next(
            model.parameters()).device  # use same device as model parameters

        for e in range(1, epochs + 1):
            for i, batch in enumerate(session.train_set, 1):
                start = time.time()
                model.train()
                batch = to_device(batch, device=device)
                x, y = batch['x'], batch['y']
                y_hat = model(x, batch['mel'])
                if model.mode == 'RAW':
                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
                elif model.mode == 'MOL':
                    y = batch['y'].float()
                y = y.unsqueeze(-1)

                loss = self.loss_func(y_hat, y)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.train_cfg['clip_grad_norm'])
                optimizer.step()
                loss_avg.add(loss.item())
                step = model.get_step()
                k = step // 1000

                duration_avg.add(time.time() - start)
                speed = 1. / duration_avg.get()
                msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {loss_avg.get():#.4} ' \
                      f'| {speed:#.2} steps/s | Step: {k}k | '

                if step % self.train_cfg['gen_samples_every'] == 0:
                    stream(msg + 'generating samples...')
                    gen_result = self.generate_samples(model, session)
                    if gen_result is not None:
                        mel_loss, gen_wav = gen_result
                        self.writer.add_scalar('Loss/generated_mel_l1',
                                               mel_loss, model.get_step())
                        self.track_top_models(mel_loss, gen_wav, model)

                if step % self.train_cfg['checkpoint_every'] == 0:
                    save_checkpoint(model=model,
                                    optim=optimizer,
                                    config=self.config,
                                    path=self.paths.voc_checkpoints /
                                    f'wavernn_step{k}k.pt')

                self.writer.add_scalar('Loss/train', loss, model.get_step())
                self.writer.add_scalar('Params/batch_size', session.bs,
                                       model.get_step())
                self.writer.add_scalar('Params/learning_rate', session.lr,
                                       model.get_step())

                stream(msg)

            val_loss = self.evaluate(model, session.val_set)
            self.writer.add_scalar('Loss/val', val_loss, model.get_step())
            save_checkpoint(model=model,
                            optim=optimizer,
                            config=self.config,
                            path=self.paths.voc_checkpoints /
                            'latest_model.pt')

            loss_avg.reset()
            duration_avg.reset()
            print(' ')
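Averager in Example 7 is another undefined helper; a minimal running-mean implementation consistent with its add() / get() / reset() usage:

class Averager:
    # Running mean of the values passed to add().
    def __init__(self):
        self.reset()

    def add(self, value: float) -> None:
        self.total += value
        self.count += 1

    def get(self) -> float:
        return self.total / self.count if self.count else 0.

    def reset(self) -> None:
        self.total = 0.
        self.count = 0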
Example 8
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).cuda()

    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    voc_model.restore(paths.voc_latest_weights)

    optimiser = optim.Adam(voc_model.parameters())

    train_set, test_set = get_vocoder_datasets(paths.data, batch_size,
                                               train_gta)

    total_steps = 10_000_000 if force_train else hp.voc_total_steps

    simple_table([
        ('Remaining', str(
            (total_steps - voc_model.get_step()) // 1000) + 'k Steps'),
        ('Batch Size', batch_size), ('LR', lr),
        ('Sequence Len', hp.voc_seq_len), ('GTA Train', train_gta)
    ])

    loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss
Example 9
def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer,
                   train_set, test_set, lr, total_steps):
    # Use same device as model parameters
    device = next(model.parameters()).device

    # set learning rate
    for g in optimizer.param_groups:
        g['lr'] = lr

    total_iters = len(train_set)
    epochs = (total_steps - model.get_step()) // total_iters + 1

    total_number_of_batches = len(train_set)

    # model_name_prefix is defined at module level in the original script
    writer = SummaryWriter("runs/{0}-{1}".format(
        model_name_prefix,
        datetime.now().strftime("%Y%m%d-%H%M%S")))
    scheduler = StepLR(optimizer, step_size=1, gamma=0.983)

    # EPOCH (the starting epoch, e.g. when resuming) also comes from the enclosing script
    for e in range(EPOCH, epochs + 1):

        start = time.time()
        running_loss = 0.
        avg_loss = 0

        for i, (x, y, m) in enumerate(train_set, 1):
            x, m, y = x.to(device), m.to(device), y.to(device)

            # Parallelize the model across GPUs using a workaround for a Python GC bug
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                y_hat = data_parallel_workaround(model, x, m)
            else:
                y_hat = model(x, m)

            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)

            elif model.mode == 'MOL':
                y = y.float()

            y = y.unsqueeze(-1)

            loss = loss_func(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            if hp.voc_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.voc_clip_grad_norm)
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / i

            speed = i / (time.time() - start)

            step = model.get_step()
            k = step // 1000

            # Write to tensorboard per batch
            writer.add_scalar('Epoch loss', loss.item(),
                              e * total_number_of_batches + i)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            stream(msg)
        """
        ####################### Testing ############################
        torch.cuda.empty_cache()
        loss_test = 0
        for _, (x_test, y_test, m_test) in enumerate(test_set, 1):
            x_test, m_test, y_test = x_test.to(device), m_test.to(device), y_test.to(device)
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                raise RuntimeError("Unsupported")
            else:
                y_test_hat = model(x_test, m_test)

            if model.mode == 'RAW':
                y_test_hat = y_test_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y_test = y_test.float()

            y_test = y_test.unsqueeze(-1)

            loss_test += loss_func(y_test_hat, y_test).item()
        avg_loss_test = loss_test / len(test_set)
        msg = f'| Epoch: {e}/{epochs} | Test-Loss: {loss_test:.4f} | Test-AvgLoss: {avg_loss_test:.4f} | '
        stream("\n")
        stream(msg)

        writer.add_scalar('Test loss', loss_test, e)
        writer.add_scalar('Average test loss', avg_loss_test, e)
        ############################################################
        """

        # Write to tensorboard per epoch
        writer.add_scalar('Running loss', running_loss, e)
        writer.add_scalar('Average loss', avg_loss, e)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('voc',
                        paths,
                        model,
                        optimizer,
                        name="{0}-epoch-{1}-loss-{2}".format(
                            model_name_prefix, e, avg_loss),
                        is_silent=True)
        model.log(paths.voc_log, msg)
        print(' ')
        scheduler.step()
        print('Epoch:', e, 'LR:', scheduler.get_last_lr())
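Example 9 replaces the fixed learning rate with StepLR(step_size=1, gamma=0.983), i.e. a geometric decay applied once per epoch: lr_e = lr_0 * 0.983**e. A quick check of the decay with an illustrative base rate:

lr0 = 1e-4                        # illustrative base learning rate
for e in (1, 50, 100, 200):
    print(e, lr0 * 0.983 ** e)    # roughly 0.18 * lr0 after 100 epochs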
Example 10
    # Instantiate WaveRNN Model
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.pase_feats,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).to(device)

    print(voc_model)
    trainable_params = list(voc_model.parameters())

    paths = GEnhancementPaths(hp.voc_model_id)

    # Load pase model
    print('Building PASE...')
    if hp.pase_cfg is not None:
        pase = PASEInjector(hp.pase_cfg,
                            hp.pase_ckpt,
                            hp.pase_ft,
                            hp.num_mels,
                            hp.pase_feats,
                            paths.voc_checkpoints,
                            global_mode=hp.global_pase,
                            stft_cfg=hp.stft_cfg,
                            stft_ckpt=hp.stft_ckpt)
Example 11
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        pad_val=hp.voc_pad_val,
                        mode=hp.voc_mode).cuda()

    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    voc_model.restore(paths.voc_latest_weights)

    optimizer = torch.optim.Adam(voc_model.parameters())

    train_set, test_set = get_vocoder_datasets(paths.data, batch_size, train_gta)

    total_steps = 10_000_000 if force_train else hp.voc_total_steps

    simple_table([('Remaining', str((total_steps - voc_model.get_step())//1000) + 'k Steps'),
                  ('Batch Size', batch_size),
                  ('Initial learning rate', init_lr),
                  ('Final learning rate', final_lr),
                  ('Sequence Len', hp.voc_seq_len),
                  ('GTA Train', train_gta)])

    loss_func = torch.nn.functional.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss

    voc_train_loop(voc_model, loss_func, optimizer, train_set, test_set, init_lr, final_lr, total_steps)
Example 12
def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer,
                   train_set, test_set, init_lr, final_lr, total_steps):
    # Use same device as model parameters
    device = next(model.parameters()).device

    # for g in optimizer.param_groups: g['lr'] = lr

    total_iters = len(train_set)
    epochs = (total_steps - model.get_step()) // total_iters + 1

    for e in range(1, epochs + 1):

        adjust_learning_rate(optimizer, e, epochs, init_lr,
                             final_lr)  # initial and final learning rates - Begee
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(train_set, 1):
            x, m, y = x.to(device), m.to(device), y.to(
                device)  # x/y: (Batch, sub_bands, T)

            #########################  MultiBand-WaveRNN   #########################
            if hp.voc_multiband:
                y0 = y[:, 0, :].squeeze(0).unsqueeze(
                    -1)  # y0/y1/y2/y3: (Batch, T, 1)
                y1 = y[:, 1, :].squeeze(0).unsqueeze(-1)
                y2 = y[:, 2, :].squeeze(0).unsqueeze(-1)
                y3 = y[:, 3, :].squeeze(0).unsqueeze(-1)

                y_hat = model(x, m)  # (Batch, T, num_classes, sub_bands)

                if model.mode == 'RAW':
                    y_hat0 = y_hat[:, :, :, 0].transpose(1, 2).unsqueeze(
                        -1)  # (Batch, num_classes, T, 1)
                    y_hat1 = y_hat[:, :, :, 1].transpose(1, 2).unsqueeze(-1)
                    y_hat2 = y_hat[:, :, :, 2].transpose(1, 2).unsqueeze(-1)
                    y_hat3 = y_hat[:, :, :, 3].transpose(1, 2).unsqueeze(-1)

                elif model.mode == 'MOL':
                    y0 = y0.float()
                    y1 = y1.float()
                    y2 = y2.float()
                    y3 = y3.float()

                loss = loss_func(y_hat0, y0) + loss_func(
                    y_hat1, y1) + loss_func(y_hat2, y2) + loss_func(
                        y_hat3, y3)

            #########################  MultiBand-WaveRNN   #########################

            optimizer.zero_grad()
            loss.backward()

            if hp.voc_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.voc_clip_grad_norm).cpu()
                if np.isnan(grad_norm):
                    print('grad_norm was NaN!')
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / i

            speed = i / (time.time() - start)

            step = model.get_step()
            k = step // 1000

            if step % hp.voc_checkpoint_every == 0:
                gen_testset(model, test_set, hp.voc_gen_at_checkpoint,
                            hp.voc_gen_batched, hp.voc_target, hp.voc_overlap,
                            paths.voc_output)
                ckpt_name = f'wave_step{k}K'
                save_checkpoint('voc',
                                paths,
                                model,
                                optimizer,
                                name=ckpt_name,
                                is_silent=True)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            stream(msg)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('voc', paths, model, optimizer, is_silent=True)
        model.log(paths.voc_log, msg)
        print(' ')
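adjust_learning_rate in Example 12 is not shown. Given the (optimizer, epoch, epochs, init_lr, final_lr) call site, a plausible, hypothetical implementation interpolates geometrically from init_lr at the first epoch to final_lr at the last:

def adjust_learning_rate(optimizer, epoch, epochs, init_lr, final_lr):
    # Hypothetical schedule: lr == init_lr at epoch 1, final_lr at the end.
    t = (epoch - 1) / max(epochs - 1, 1)
    lr = init_lr * (final_lr / init_lr) ** t
    for g in optimizer.param_groups:
        g['lr'] = lr
    return lr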