Example 1
File: train.py Project: geneing/TTS
def main(args): #pylint: disable=redefined-outer-name
    # Audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)

    if c.use_speaker_embedding:
        speakers = get_speakers(c.data_path, c.meta_file_train, c.dataset)
        if args.restore_path:
            prev_out_path = os.path.dirname(args.restore_path)
            speaker_mapping = load_speaker_mapping(prev_out_path)
        assert all(speaker in speaker_mapping
                   for speaker in speakers), "As of now, you cannot " \
                                             "introduce new speakers to " \
                                             "a previously trained model."
        else:
            speaker_mapping = {name: i
                               for i, name in enumerate(speakers)}
        save_speaker_mapping(OUT_PATH, speaker_mapping)
        num_speakers = len(speaker_mapping)
        print("Training with {} speakers: {}".format(num_speakers,
                                                     ", ".join(speakers)))
    else:
        num_speakers = 0

    model = setup_model(num_chars, num_speakers, c)

    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    #optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
    optimizer = Ranger(model.parameters(), lr=c.lr, weight_decay=c.wd)
    optimizer_gst = Ranger(model.textgst.parameters(), lr=c.lr, weight_decay=c.wd) if c.text_gst else None

    if c.stopnet and c.separate_stopnet:
        optimizer_st = Ranger(model.decoder.stopnet.parameters(), lr=c.lr)
    else:
        optimizer_st = None

    if c.loss_masking:
        criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"] else MSELossMasked()
    else:
        criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"] else nn.MSELoss()
    criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None
    criterion_gst = nn.L1Loss() if c.text_gst else None

    if args.restore_path:
        checkpoint = torch.load(args.restore_path)
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except Exception:
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(
            " > Model restored from step %d" % checkpoint['step'], flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model = model.cuda()
        criterion.cuda()
        if criterion_st:
            criterion_st.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    if c.lr_decay:
        scheduler = NoamLR(
            optimizer,
            warmup_steps=c.warmup_steps,
            last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        # set gradual training
        if c.gradual_training is not None:
            r, c.batch_size = gradual_training_scheduler(global_step, c)
            c.r = r
            model.decoder.set_r(r)
        print(" > Number of outputs per iteration:", model.decoder.r)

        train_loss, global_step = train(model, criterion, criterion_st,
                                        optimizer, optimizer_st, scheduler,
                                        ap, global_step, epoch, criterion_gst=criterion_gst, optimizer_gst=optimizer_gst)
        
        if epoch % 5 == 0:
            val_loss = evaluate(model, criterion, criterion_st, criterion_gst, ap, global_step, epoch)
            print(
                " | > Training Loss: {:.5f}   Validation Loss: {:.5f}".format(
                    train_loss, val_loss),
                flush=True)
            target_loss = train_loss
            if c.run_eval:
                target_loss = val_loss
            best_loss = save_best_model(model, optimizer, optimizer_st, optimizer_gst, target_loss, best_loss,
                                        OUT_PATH, global_step, epoch)
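
Example 1 adjusts the decoder reduction factor r and the batch size on the fly via gradual_training_scheduler, which is not shown above. The sketch below is a plausible reconstruction, assuming c.gradual_training is a list of [start_step, r, batch_size] entries sorted by start step (that schema is an assumption, not confirmed by this snippet):

from types import SimpleNamespace

def gradual_training_scheduler(global_step, c):
    # Walk the schedule and keep the last entry whose start step has
    # been reached; assumes entries are [start_step, r, batch_size].
    new_values = c.gradual_training[0]
    for values in c.gradual_training:
        if global_step >= values[0]:
            new_values = values
    return new_values[1], new_values[2]  # (r, batch_size)

# usage sketch: r shrinks as training progresses
c = SimpleNamespace(gradual_training=[[0, 7, 64], [10000, 5, 64], [50000, 2, 32]])
print(gradual_training_scheduler(0, c))      # (7, 64)
print(gradual_training_scheduler(60000, c))  # (2, 32)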
Example 2
def main(args):
    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)
    model = setup_model(num_chars, c, args.use_half)

    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    if args.use_half:
        print(' | > Use half mode')

    optimizer_eps = 1e-08 if not args.use_half else 1e-04
    optimizer = optim.Adam(model.parameters(),
                           lr=c.lr,
                           weight_decay=0,
                           eps=optimizer_eps)
    # optimizer = optim.SGD(model.parameters(), lr=c.lr, weight_decay=0)
    if c.stopnet and c.separate_stopnet:
        optimizer_st = optim.Adam(model.decoder.stopnet.parameters(),
                                  lr=c.lr,
                                  weight_decay=0,
                                  eps=optimizer_eps)
        # optimizer_st = optim.SGD(model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
    else:
        optimizer_st = None

    if c.loss_masking:
        criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
    else:
        criterion = nn.L1Loss() if c.model == "Tacotron" else nn.MSELoss()
    criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None

    if args.restore_path:
        checkpoint = torch.load(args.restore_path)
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if len(c.reinit_layers) > 0:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except Exception:
            print(" > Partial model initialization.")
            partial_init_flag = True
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        start_epoch = checkpoint['epoch']
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    # use half mode
    if args.use_half:
        model.half()
        for layer in model.modules():
            if isinstance(layer, torch.nn.BatchNorm1d):
                layer.float()

    if use_cuda:
        model = model.cuda()
        criterion.cuda()
        if criterion_st:
            criterion_st.cuda()
        if args.restore_path:
            # print(checkpoint['optimizer'])
            # print('---opt', optimizer)
            optimizer.load_state_dict(checkpoint['optimizer'])

    # DISTRIBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    # reset lr
    if args.reset_lr:
        for group in optimizer.param_groups:
            group['initial_lr'] = c.lr

    if c.lr_decay:
        scheduler = NoamLR(
            optimizer,
            warmup_steps=c.warmup_steps,
            last_epoch=args.restore_step - 1,
            use_half=args.use_half,
        )
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    for epoch in range(0, c.epochs):
        train_loss, current_step = train(model, criterion, criterion_st,
                                         optimizer, optimizer_st, scheduler,
                                         ap, epoch, args.use_half)
        if c.run_eval:
            val_loss = evaluate(model, criterion, criterion_st, ap,
                                current_step, epoch, args.use_half)
            print(
                " | > Training Loss: {:.5f}   Validation Loss: {:.5f}".format(
                    train_loss, val_loss),
                flush=True)
            target_loss = val_loss
        else:
            print(" | > Training Loss: {:.5f}".format(train_loss), flush=True)
            target_loss = train_loss
        best_loss = save_best_model(model, optimizer, target_loss, best_loss,
                                    OUT_PATH, current_step, epoch)
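
Example 2's half mode keeps BatchNorm in fp32 while the rest of the model runs in fp16, and widens Adam's eps to 1e-4 so the update does not underflow at half precision. A minimal, self-contained sketch of that cast (the BatchNorm2d branch is an addition for generality):

import torch
import torch.nn as nn

def to_half_keep_bn_fp32(model: nn.Module) -> nn.Module:
    # Cast parameters to fp16, but keep BatchNorm in fp32: its running
    # statistics are numerically fragile at half precision.
    model.half()
    for layer in model.modules():
        if isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d)):
            layer.float()
    return model

# usage sketch
net = to_half_keep_bn_fp32(nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8)))
print(next(net[0].parameters()).dtype)  # torch.float16
print(next(net[1].parameters()).dtype)  # torch.float32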
Example 3
def main(args):
    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)
    model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm)

    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
    optimizer_st = optim.Adam(
        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)

    criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
    criterion_st = nn.BCEWithLogitsLoss()

    if args.restore_path:
        checkpoint = torch.load(args.restore_path)
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if len(c.reinit_layers) > 0:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except Exception:
            print(" > Partial model initialization.")
            partial_init_flag = True
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        if use_cuda:
            model = model.cuda()
            criterion.cuda()
            criterion_st.cuda()
        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(
            " > Model restored from step %d" % checkpoint['step'], flush=True)
        start_epoch = checkpoint['epoch']
        # best_loss = checkpoint['postnet_loss']
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0
        if use_cuda:
            model = model.cuda()
            criterion.cuda()
            criterion_st.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    if c.lr_decay:
        scheduler = NoamLR(
            optimizer,
            warmup_steps=c.warmup_steps,
            last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    for epoch in range(0, c.epochs):
        train_loss, current_step = train(model, criterion, criterion_st,
                                         optimizer, optimizer_st, scheduler,
                                         ap, epoch)
        val_loss = evaluate(model, criterion, criterion_st, ap, current_step, epoch)
        print(
            " | > Training Loss: {:.5f}   Validation Loss: {:.5f}".format(
                train_loss, val_loss),
            flush=True)
        target_loss = train_loss
        if c.run_eval:
            target_loss = val_loss
        best_loss = save_best_model(model, optimizer, target_loss, best_loss,
                                    OUT_PATH, current_step, epoch)
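
Each example wraps the optimizer in NoamLR with last_epoch=args.restore_step - 1, so the warmup schedule resumes where the checkpoint left off. A minimal sketch of the schedule, mirroring the common Noam formulation (the project's exact implementation is an assumption):

import torch
from torch.optim.lr_scheduler import _LRScheduler

class NoamLR(_LRScheduler):
    # lr = base_lr * warmup^0.5 * min(step * warmup^-1.5, step^-0.5):
    # linear warmup for warmup_steps, then inverse-sqrt decay.
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = float(warmup_steps)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        return [
            base_lr * self.warmup_steps**0.5 *
            min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]

# usage sketch
opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
scheduler = NoamLR(opt, warmup_steps=4000)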
Example 4
File: train.py Project: wurde/TTS
def main(args):
    model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r)
    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
    optimizer_st = optim.Adam(
        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)

    criterion = L1LossMasked()
    criterion_st = nn.BCELoss()

    if args.restore_path:
        checkpoint = torch.load(args.restore_path)
        try:
            model.load_state_dict(checkpoint['model'])
        except Exception:
            model_dict = model.state_dict()
            # Partial initialization: layers that do not match between the old and new model are skipped.
            # 1. filter out unnecessary keys
            pretrained_dict = {
                k: v
                for k, v in checkpoint['model'].items() if k in model_dict
            }
            # 2. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict)
            # 3. load the new state dict
            model.load_state_dict(model_dict)
        if use_cuda:
            model = model.cuda()
            criterion.cuda()
            criterion_st.cuda()
        optimizer.load_state_dict(checkpoint['optimizer'])
        print(
            " > Model restored from step %d" % checkpoint['step'], flush=True)
        start_epoch = checkpoint['epoch']
        best_loss = checkpoint['linear_loss']
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0
        print("\n > Starting a new training", flush=True)
        if use_cuda:
            model = model.cuda()
            criterion.cuda()
            criterion_st.cuda()

    if c.lr_decay:
        scheduler = NoamLR(
            optimizer,
            warmup_steps=c.warmup_steps,
            last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print(" | > Model has {} parameters".format(num_params), flush=True)

    if not os.path.exists(CHECKPOINT_PATH):
        os.mkdir(CHECKPOINT_PATH)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    for epoch in range(0, c.epochs):
        train_loss, current_step = train(model, criterion, criterion_st,
                                         optimizer, optimizer_st,
                                         scheduler, ap, epoch)
        val_loss = evaluate(model, criterion, criterion_st, ap,
                            current_step)
        print(
            " | > Train Loss: {:.5f}   Validation Loss: {:.5f}".format(
                train_loss, val_loss),
            flush=True)
        best_loss = save_best_model(model, optimizer, train_loss, best_loss,
                                    OUT_PATH, current_step, epoch)
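
save_best_model is called at the end of every epoch in each example but never defined here. A plausible sketch matching this example's call signature; the checkpoint keys mirror the ones read back above (e.g. 'linear_loss'), while the file name is an assumption:

import os
import torch

def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step, epoch):
    # Overwrite the best checkpoint only when the target loss improves.
    if model_loss < best_loss:
        best_loss = model_loss
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'step': current_step,
            'epoch': epoch,
            'linear_loss': model_loss,
        }, os.path.join(out_path, 'best_model.pth.tar'))
        print(" > BEST MODEL ({0:.5f})".format(model_loss), flush=True)
    return best_loss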
Example 5
def main(args):
    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)
    model = Tacotron(num_chars=num_chars,
                     embedding_dim=c.embedding_size,
                     linear_dim=ap.num_freq,
                     mel_dim=ap.num_mels,
                     r=c.r,
                     memory_size=c.memory_size)

    optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
    optimizer_st = optim.Adam(model.decoder.stopnet.parameters(),
                              lr=c.lr,
                              weight_decay=0)

    criterion = L1LossMasked()
    criterion_st = nn.BCELoss()

    if args.restore_path:
        checkpoint = torch.load(args.restore_path)
        try:
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        except Exception:
            print(" > Partial model initialization.")
            partial_init_flag = True
            model_dict = model.state_dict()
            # Partial initialization: layers that do not match between the old and new model are skipped.
            # 1. filter out unnecessary keys
            pretrained_dict = {
                k: v
                for k, v in checkpoint['model'].items() if k in model_dict
            }
            # 2. filter out different size layers
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if v.numel() == model_dict[k].numel()
            }
            # 3. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict)
            # 4. load the new state dict
            model.load_state_dict(model_dict)
            print(" | > {} / {} layers are initialized".format(
                len(pretrained_dict), len(model_dict)))
        if use_cuda:
            model = model.cuda()
            criterion.cuda()
            criterion_st.cuda()
        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        start_epoch = checkpoint['epoch']
        best_loss = checkpoint['linear_loss']
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0
        if use_cuda:
            model = model.cuda()
            criterion.cuda()
            criterion_st.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    if c.lr_decay:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    for epoch in range(0, c.epochs):
        train_loss, current_step = train(model, criterion, criterion_st,
                                         optimizer, optimizer_st, scheduler,
                                         ap, epoch)
        val_loss = evaluate(model, criterion, criterion_st, ap, current_step,
                            epoch)
        print(" | > Training Loss: {:.5f}   Validation Loss: {:.5f}".format(
            train_loss, val_loss),
              flush=True)
        target_loss = train_loss
        if c.run_eval:
            target_loss = val_loss
        best_loss = save_best_model(model, optimizer, target_loss, best_loss,
                                    OUT_PATH, current_step, epoch)
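
Examples 1-3 call set_init_dict for partial restores; Example 5 inlines the same steps. A sketch of how the helper presumably packages them (the reinit_layers filtering follows the RuntimeError escape hatch in the other examples and is an assumption):

def set_init_dict(model_dict, checkpoint, c):
    # 1. keep only keys that exist in the current model
    pretrained_dict = {
        k: v for k, v in checkpoint['model'].items() if k in model_dict
    }
    # 2. drop tensors whose size changed between the two models
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items()
        if v.numel() == model_dict[k].numel()
    }
    # 3. drop layers the config asks to re-initialize (assumed behavior)
    if c.reinit_layers:
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if not any(layer in k for layer in c.reinit_layers)
        }
    # 4. overwrite matching entries and return the merged state dict
    model_dict.update(pretrained_dict)
    print(" | > {} / {} layers are initialized".format(
        len(pretrained_dict), len(model_dict)))
    return model_dict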
Example 6
def train(args,
          log_dir,
          checkpoint_path,
          trainloader,
          testloader,
          tensorboard,
          c,
          model_name,
          ap,
          cuda=True):
    padding_with_max_lenght = c.dataset['padding_with_max_lenght']
    if model_name == 'conv2d':
        model = conv2d(c)
    #elif model_name == 'voicesplit':
    else:
        raise Exception("The model '" + model_name + "' is not supported")

    if c.train_config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=c.train_config['learning_rate'],
            weight_decay=c.train_config['weight_decay'])
    else:
        raise Exception("The %s  not is a optimizer supported" %
                        c.train['optimizer'])

    step = 0
    if checkpoint_path is not None:
        print("Continue training from checkpoint: %s" % checkpoint_path)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        try:
            if c.train_config['reinit_layers']:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
            if cuda:
                model = model.cuda()
        except Exception:
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        try:
            optimizer.load_state_dict(checkpoint['optimizer'])
        except Exception:
            print(
                " > Optimizer state was not loaded from the checkpoint; "
                "this can happen if you changed the optimizer.")

        step = checkpoint['step']
    else:
        print("Starting new training run")
        step = 0

    if c.train_config['lr_decay']:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.train_config['warmup_steps'],
                           last_epoch=step - 1)
    else:
        scheduler = None
    # move the model to the GPU
    if cuda:
        model = model.cuda()

    # define loss function
    criterion = nn.BCELoss()
    eval_criterion = nn.BCELoss(reduction='sum')

    best_loss = float('inf')

    model.train()
    for epoch in range(c.train_config['epochs']):
        for feature, target in trainloader:
            if cuda:
                feature = feature.cuda()
                target = target.cuda()

            output = model(feature)

            # Calculate loss
            # adjust target dim
            if not padding_with_max_lenght:
                target = target[:, :output.shape[1], :]
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # update lr decay scheme
            if scheduler:
                scheduler.step()
            step += 1

            loss = loss.item()
            if loss > 1e8 or math.isnan(loss):
                print("Loss exploded to %.02f at step %d!" % (loss, step))
                break

            # write loss to tensorboard
            if step % c.train_config['summary_interval'] == 0:
                tensorboard.log_training(loss, step)
                print("Write summary at step %d" % step, ' Loss: ', loss)

            # save a checkpoint, then run validation and log to tensorboard
            if step % c.train_config['checkpoint_interval'] == 0:
                save_path = os.path.join(log_dir, 'checkpoint_%d.pt' % step)
                torch.save(
                    {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': step,
                        'config_str': str(c),
                    }, save_path)
                print("Saved checkpoint to: %s" % save_path)
                # run validation and save best checkpoint
                val_loss = validation(eval_criterion,
                                      ap,
                                      model,
                                      c,
                                      testloader,
                                      tensorboard,
                                      step,
                                      cuda=cuda)
                best_loss = save_best_checkpoint(log_dir, model, optimizer, c,
                                                 step, val_loss, best_loss)

        print('=================================================')
        print("Epoch %d End !" % epoch)
        print('=================================================')
        # run validation and save best checkpoint at end epoch
        val_loss = validation(eval_criterion,
                              ap,
                              model,
                              c,
                              testloader,
                              tensorboard,
                              step,
                              cuda=cuda)
        best_loss = save_best_checkpoint(log_dir, model, optimizer, c, step,
                                         val_loss, best_loss)
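
The TODO comments in Examples 1-3 note that model.cuda() must run before the optimizer state is restored, and Example 2 follows that order. A minimal sketch of a resume helper built on that rule (not from the source):

import torch

def resume_training_state(model, optimizer, checkpoint_path, cuda=True):
    # Load on CPU first, move the model to the GPU, *then* restore the
    # optimizer: load_state_dict casts optimizer state to the device of
    # the parameters it tracks, so the order matters.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    if cuda:
        model = model.cuda()
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint.get('step', 0)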
Example 7
def train(args,
          log_dir,
          checkpoint_path,
          trainloader,
          testloader,
          tensorboard,
          c,
          model_name,
          ap,
          cuda=True,
          model_params=None):
    loss1_weight = c.train_config['loss1_weight']
    use_mixup = False if 'mixup' not in c.model else c.model['mixup']
    if use_mixup:
        mixup_alpha = 1 if 'mixup_alpha' not in c.model else c.model[
            'mixup_alpha']
        mixup_augmenter = Mixup(mixup_alpha=mixup_alpha)
        print("Enable Mixup with alpha:", mixup_alpha)

    model = return_model(c, model_params)

    if c.train_config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=c.train_config['learning_rate'],
            weight_decay=c.train_config['weight_decay'])
    elif c.train_config['optimizer'] == 'adamw':
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=c.train_config['learning_rate'],
            weight_decay=c.train_config['weight_decay'])
    elif c.train_config['optimizer'] == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=c.train_config['learning_rate'],
                          weight_decay=c.train_config['weight_decay'])
    else:
        raise Exception("The %s  not is a optimizer supported" %
                        c.train['optimizer'])

    step = 0
    if checkpoint_path is not None:
        print("Continue training from checkpoint: %s" % checkpoint_path)
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])
        except Exception:
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        step = 0
    else:
        print("Starting new training run")
        step = 0

    if c.train_config['lr_decay']:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.train_config['warmup_steps'],
                           last_epoch=step - 1)
    else:
        scheduler = None
    # move the model to the GPU
    if cuda:
        model = model.cuda()

    # define loss function
    if use_mixup:
        criterion = Clip_BCE()
    else:
        criterion = nn.BCELoss()
    eval_criterion = nn.BCELoss(reduction='sum')

    best_loss = float('inf')

    # early stop definitions
    early_epochs = 0

    model.train()
    for epoch in range(c.train_config['epochs']):
        for feature, target in trainloader:

            if cuda:
                feature = feature.cuda()
                target = target.cuda()

            if use_mixup:
                batch_len = len(feature)
                if (batch_len % 2) != 0:
                    batch_len -= 1
                    feature = feature[:batch_len]
                    target = target[:batch_len]

                mixup_lambda = torch.FloatTensor(
                    mixup_augmenter.get_lambda(batch_len)).to(feature.device)
                output = model(feature[:batch_len], mixup_lambda)
                target = do_mixup(target, mixup_lambda)
            else:
                output = model(feature)
            # Calculate loss
            if c.dataset['class_balancer_batch'] and not use_mixup:
                idxs = (target == c.dataset['control_class'])
                loss_control = criterion(output[idxs], target[idxs])
                idxs = (target == c.dataset['patient_class'])
                loss_patient = criterion(output[idxs], target[idxs])
                loss = (loss_control + loss_patient) / 2
            else:
                loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # update lr decay scheme
            if scheduler:
                scheduler.step()
            step += 1

            loss = loss.item()
            if loss > 1e8 or math.isnan(loss):
                print("Loss exploded to %.02f at step %d!" % (loss, step))
                break

            # write loss to tensorboard
            if step % c.train_config['summary_interval'] == 0:
                tensorboard.log_training(loss, step)
                if c.dataset['class_balancer_batch'] and not use_mixup:
                    print("Write summary at step %d" % step, ' Loss: ', loss,
                          'Loss control:', loss_control.item(),
                          'Loss patient:', loss_patient.item())
                else:
                    print("Write summary at step %d" % step, ' Loss: ', loss)

            # save a checkpoint, then run validation and log to tensorboard
            if step % c.train_config['checkpoint_interval'] == 0:
                save_path = os.path.join(log_dir, 'checkpoint_%d.pt' % step)
                torch.save(
                    {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': step,
                        'config_str': str(c),
                    }, save_path)
                print("Saved checkpoint to: %s" % save_path)
                # run validation and save best checkpoint
                val_loss = validation(eval_criterion,
                                      ap,
                                      model,
                                      c,
                                      testloader,
                                      tensorboard,
                                      step,
                                      cuda=cuda,
                                      loss1_weight=loss1_weight)
                best_loss, _ = save_best_checkpoint(
                    log_dir, model, optimizer, c, step, val_loss, best_loss,
                    early_epochs
                    if c.train_config['early_stop_epochs'] != 0 else None)

        print('=================================================')
        print("Epoch %d End !" % epoch)
        print('=================================================')
        # run validation and save best checkpoint at end epoch
        val_loss = validation(eval_criterion,
                              ap,
                              model,
                              c,
                              testloader,
                              tensorboard,
                              step,
                              cuda=cuda,
                              loss1_weight=loss1_weight)
        best_loss, early_epochs = save_best_checkpoint(
            log_dir, model, optimizer, c, step, val_loss, best_loss,
            early_epochs if c.train_config['early_stop_epochs'] != 0 else None)
        if c.train_config['early_stop_epochs'] != 0:
            if early_epochs is not None:
                if early_epochs >= c.train_config['early_stop_epochs']:
                    break  # stop train
    return best_loss
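
Example 7's Mixup helper and do_mixup are used but not defined here. The sketch below follows the PANNs-style convention of mixing consecutive even/odd samples, which matches how the example trims the batch to an even length (the exact helpers are an assumption):

import numpy as np
import torch

class Mixup:
    def __init__(self, mixup_alpha, random_seed=1234):
        self.mixup_alpha = mixup_alpha
        self.random_state = np.random.RandomState(random_seed)

    def get_lambda(self, batch_size):
        # One Beta(alpha, alpha) draw per sample pair, stored as
        # [lam, 1 - lam, lam, 1 - lam, ...].
        lambdas = []
        for _ in range(batch_size // 2):
            lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha)
            lambdas.extend([lam, 1.0 - lam])
        return np.array(lambdas, dtype=np.float32)

def do_mixup(x, mixup_lambda):
    # out[i] = lam * x[2i] + (1 - lam) * x[2i + 1]; halves the batch.
    return (x[0::2].transpose(0, -1) * mixup_lambda[0::2] +
            x[1::2].transpose(0, -1) * mixup_lambda[1::2]).transpose(0, -1)

# usage sketch
aug = Mixup(mixup_alpha=1.0)
targets = torch.randn(4, 3)
lam = torch.from_numpy(aug.get_lambda(4))
print(do_mixup(targets, lam).shape)  # torch.Size([2, 3])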