def train(model, args, device, writer):
    print('preparing data...')
    dataloader = make_dataloader(
        args.tr_clean_list,
        args.tr_noise_list,
        args.tr_rir_list,
        batch_size=args.batch_size,
        repeate=1,
        segement_length=8,
        sample_rate=args.sample_rate,
        num_workers=args.num_threads,
    )

    print_freq = 100
    num_batch = len(dataloader)
    params = model.get_params(args.weight_decay)
    optimizer = optim.Adam(params, lr=args.learn_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     factor=0.5,
                                                     patience=1,
                                                     verbose=True)

    if args.retrain:
        start_epoch, step = reload_model(model, optimizer, args.exp_dir,
                                         args.use_cuda)
    else:
        start_epoch, step = 0, 0
    print('---------PRERUN-----------')
    lr = get_learning_rate(optimizer)
    print('(Initialization)')
    val_loss, val_sisnr = validation(model, args, lr, -1, device)
    writer.add_scalar('Loss/Train', val_loss, step)
    writer.add_scalar('Loss/Cross-Validation', val_loss, step)

    writer.add_scalar('SISNR/Train', -val_sisnr, step)
    writer.add_scalar('SISNR/Cross-Validation', -val_sisnr, step)

    warmup_epoch = 6
    warmup_lr = args.learn_rate / (4 * warmup_epoch)

    for epoch in range(start_epoch, args.max_epoch):
        torch.manual_seed(args.seed + epoch)
        if args.use_cuda:
            torch.cuda.manual_seed(args.seed + epoch)
        model.train()
        loss_total = 0.0
        loss_print = 0.0

        sisnr_total = 0.0
        sisnr_print = 0.0
        '''
        if epoch == 0 and warmup_epoch > 0:
            print('Using warmup strategy; lr is set to {:.5f}'.format(warmup_lr))
            setup_lr(optimizer, warmup_lr)
            warmup_lr *= 4*(epoch+1)
        elif epoch == warmup_epoch:
            print('Warmup has ended; lr is set to {:.5f}'.format(args.learn_rate))
            setup_lr(optimizer, args.learn_rate)
        '''

        stime = time.time()
        lr = get_learning_rate(optimizer)
        for idx, data in enumerate(dataloader):
            torch.cuda.empty_cache()
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            model.zero_grad()
            est_spec, est_wav = data_parallel(model, (inputs, ))
            '''
            if epoch > 8:
                gth_spec, gth_wav = data_parallel(model, (labels,))
            else:
                gth_spec = data_parallel(model.stft, (labels))[0]
            '''
            #gth_spec = data_parallel(model.stft, (labels))
            #loss = model.loss(est_spec, gth_spec, loss_mode='MSE')
            #loss.backward()
            sisnr = model.loss(est_wav, labels, loss_mode='SI-SNR')
            sisnr.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

            step += 1

            #loss_total += loss.data.cpu()
            #loss_print += loss.data.cpu()

            sisnr_total += sisnr.data.cpu()
            sisnr_print += sisnr.data.cpu()

            loss_total = sisnr_total
            loss_print = sisnr_print
            del est_wav, est_spec
            if (idx + 1) % 3000 == 0:
                save_checkpoint(model, optimizer, -1, step, args.exp_dir)
            if (idx + 1) % print_freq == 0:
                eplashed = time.time() - stime
                speed_avg = eplashed / (idx + 1)
                loss_print_avg = loss_print / print_freq
                sisnr_print_avg = sisnr_print / print_freq
                print('Epoch {:3d}/{:3d} | batches {:5d}/{:5d} | lr {:1.4e} |'
                      '{:2.3f}s/batches | loss {:2.6f} |'
                      'SI-SNR {:2.4f} '.format(
                          epoch,
                          args.max_epoch,
                          idx + 1,
                          num_batch,
                          lr,
                          speed_avg,
                          loss_print_avg,
                          -sisnr_print_avg,
                      ))
                sys.stdout.flush()
                writer.add_scalar('Loss/Train', loss_print_avg, step)
                writer.add_scalar('SISNR/Train', -sisnr_print_avg, step)
                loss_print = 0.0
                sisnr_print = 0.0
        eplashed = time.time() - stime
        loss_total_avg = loss_total / num_batch
        sisnr_total_avg = sisnr_total / num_batch
        print('Training AVG.LOSS |'
              ' Epoch {:3d}/{:3d} | lr {:1.4e} |'
              ' {:2.3f}s/batch | time {:3.2f}mins |'
              ' loss {:2.6f} |'
              ' SISNR {:2.4f}|'.format(epoch + 1, args.max_epoch, lr,
                                       eplashed / num_batch, eplashed / 60.0,
                                       loss_total_avg.item(),
                                       -sisnr_total_avg.item()))
        val_loss, val_sisnr = validation(model, args, lr, epoch, device)
        writer.add_scalar('Loss/Cross-Validation', val_loss, step)
        writer.add_scalar('SISNR/Cross-Validation', -val_sisnr, step)
        writer.add_scalar('learn_rate', lr, step)
        if val_loss > scheduler.best:
            print('Rejected !!! The best is {:2.6f}'.format(scheduler.best))
        else:
            save_checkpoint(model,
                            optimizer,
                            epoch + 1,
                            step,
                            args.exp_dir,
                            mode='best_model')
        scheduler.step(val_loss)
        sys.stdout.flush()
        stime = time.time()
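

# get_learning_rate() and setup_lr() are used by the loop (and the commented-out
# warmup block) above but are not part of this listing. A minimal sketch of
# compatible helpers for a standard PyTorch optimizer (assumed behaviour; the
# project's own versions may differ):
def get_learning_rate(optimizer):
    """Return the learning rate of the first parameter group."""
    return optimizer.param_groups[0]['lr']


def setup_lr(optimizer, lr):
    """Set the same learning rate on every parameter group."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
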
def main(args):
    if args.name == 'sbr':
        experiment_dir = Path(
            os.path.join(BASE_DIR, 'save/sbr/' + get_run_name()))
    else:
        experiment_dir = Path(
            os.path.join(BASE_DIR,
                         'save/double_{}/'.format(args.name) + get_run_name()))

    experiment_dir.mkdir(parents=True, exist_ok=True)

    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)

    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    # writer = SummaryWriter(log_dir)

    config_f = open(os.path.join(log_dir, 'config.json'), 'w')
    json.dump(vars(args), config_f)
    config_f.close()

    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    num_class = DATASETS[args.dataset]['n_classes']

    if args.name == 'sbr':
        net_p = ngvnn.Net_Prev(pretraining=args.pretrain,
                               num_views=args.num_views)
        net_whole = ngvnn.Net_Whole(pretraining=args.pretrain)
    elif args.name == 'pointnet':
        import importlib
        model = importlib.import_module(args.model)
        net_p = model.get_encoder_model(num_class, normal_channel=False)
        net_whole = model.get_encoder_model(num_class, normal_channel=False)
    elif args.name == 'ngvnn':
        net_p = ngvnn.Net_Prev(pretraining=args.pretrain,
                               num_views=args.num_views)
        net_whole = ngvnn.Net_Prev(pretraining=args.pretrain,
                                   num_views=args.num_views)
    else:
        raise NotImplementedError('unsupported name: {}'.format(args.name))

    net_cls = ngvnn.Net_Classifier(nclasses=num_class)

    crt_cls = nn.CrossEntropyLoss().cuda()
    if args.triplet_type == 'tcl':
        center_embed = 512
        crt_tpl = custom_loss.TripletCenterLoss(margin=args.margin,
                                                center_embed=center_embed,
                                                num_classes=num_class).cuda()
        optim_centers = torch.optim.SGD(crt_tpl.parameters(), lr=0.1)
    else:
        from dataset.TripletSampler import HardestNegativeTripletSelector
        anchor_index = args.n_classes * args.n_samples
        crt_tpl = custom_loss.OnlineTripletLoss(
            args.margin,
            HardestNegativeTripletSelector(args.margin, False, anchor_index))
    criterion = [crt_cls, crt_tpl, args.w1, args.w2]
    # Load from checkpoint
    start_epoch = best_top1 = 0
    if args.resume:
        checkpoint = torch.load(args.resume)
        net_p.load_state_dict(checkpoint['net_p'])
        net_whole.load_state_dict(checkpoint['net_whole'])
        net_cls.load_state_dict(checkpoint['net_cls'])
        start_epoch = checkpoint['epoch']
        best_top1 = checkpoint['best_prec']

    net_whole = nn.DataParallel(net_whole).cuda()
    net_cls = nn.DataParallel(net_cls).cuda()
    net_p = nn.DataParallel(net_p).cuda()

    optim_shape = optim.SGD([{
        'params': net_p.parameters()
    }, {
        'params': net_cls.parameters()
    }],
                            lr=0.001,
                            momentum=0.9,
                            weight_decay=args.weight_decay)
    if args.name in ['sbr', 'ngvnn', 'pointnet']:
        base_param_ids = set(map(id, net_whole.module.features.parameters()))
        new_params = [
            p for p in net_whole.parameters() if id(p) not in base_param_ids
        ]
        param_groups = [{
            'params': net_whole.module.features.parameters(),
            'lr_mult': 0.1
        }, {
            'params': new_params,
            'lr_mult': 1.0
        }]

        optim_sketch = optim.SGD(param_groups,
                                 lr=0.001,
                                 momentum=0.9,
                                 weight_decay=args.weight_decay)
    else:
        optim_sketch = optim.SGD([{
            'params': net_p.parameters()
        }, {
            'params': net_cls.parameters()
        }],
                                 lr=0.001,
                                 momentum=0.9,
                                 weight_decay=args.weight_decay)

    if args.triplet_type == 'tcl':
        optimizer = (optim_sketch, optim_shape, optim_centers)
    else:
        optimizer = (optim_sketch, optim_shape)
    model = (net_whole, net_p, net_cls)

    # Schedule learning rate
    def adjust_lr(epoch, optimizer):
        step_size = 800 if args.pk_flag else 80  # 40
        lr = args.lr * (0.1**(epoch // step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    train_shape_loader, train_sketch_loader, test_shape_loader, test_sketch_loader = get_dataloader(
        args)
    # Start training
    top1 = 0.0
    best_epoch = -1
    best_metric = None
    '''TRAINING'''
    logger.info('Start training...')

    for epoch in range(start_epoch, args.epoch):
        # cls acc top1
        train_top1 = train(train_sketch_loader, train_shape_loader, model,
                           criterion, optimizer, epoch, args, logger)
        if train_top1 > 0.1:
            print("Test:")
            cur_metric = validate(test_sketch_loader, test_shape_loader, model,
                                  criterion, args, logger)
            top1 = cur_metric[3]  # mAP_feat_norm

        is_best = top1 > best_top1
        if is_best:
            best_epoch = epoch + 1
            best_metric = cur_metric
        best_top1 = max(top1, best_top1)

        # path_checkpoint = '{0}/model_latest.pth'.format(checkpoints_dir)
        # misc.save_checkpoint(checkpoint, path_checkpoint)

        if is_best:  # save checkpoint
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            checkpoint = {}
            checkpoint['epoch'] = epoch + 1
            checkpoint['current_prec'] = top1
            checkpoint['best_prec'] = best_top1
            checkpoint['net_p'] = net_p.module.state_dict()
            checkpoint['net_whole'] = net_whole.module.state_dict()
            checkpoint['net_cls'] = net_cls.module.state_dict()

            # torch.save(checkpoint, savepath)
            # path_checkpoint = '{0}/best_model.pth'.format(checkpoints_dir)
            misc.save_checkpoint(checkpoint, savepath)  #path_checkpoint)

        log_string(
            '\n * Finished epoch {:3d}  top1: {:5.3%}  best: {:5.3%}{} @epoch {}\n'
            .format(epoch, top1, best_top1, ' *' if is_best else '',
                    best_epoch), logger)

    logger.info('End of training...')

    log_string('Best metric {}'.format(best_metric), logger)

    return experiment_dir
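

# log_string() is called above but not defined in this listing. A minimal sketch
# of a compatible helper (assumed behaviour: echo to stdout and to the experiment
# logger; the original may differ):
def log_string(message, logger):
    print(message)
    logger.info(message)
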
Example #3
def train(model, args, device, writer):
    print('preparing data...')
    dataloader, dataset = make_loader(
        args.tr_list,
        args.batch_size,
        num_workers=args.num_threads,
    )
    print_freq = 100
    num_batch = len(dataloader)
    params = model.get_params(args.weight_decay)
    optimizer = optim.Adam(params, lr=args.learn_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', factor=0.5, patience=1, verbose=True)
    
    if args.retrain:
        start_epoch, step = reload_model(model, optimizer, args.exp_dir,
                                         args.use_cuda)
    else:
        start_epoch, step = 0, 0
    print('---------PRERUN-----------')
    lr = get_learning_rate(optimizer)
    print('(Initialization)')
    val_loss, val_sisnr = validation(model, args, lr, -1, device)
    writer.add_scalar('Loss/Train', val_loss, step)
    writer.add_scalar('Loss/Cross-Validation', val_loss, step)
    
    writer.add_scalar('SiSNR/Train', -val_sisnr, step)
    writer.add_scalar('SiSNR/Cross-Validation', -val_sisnr, step)

    for epoch in range(start_epoch, args.max_epoch):
        torch.manual_seed(args.seed + epoch)
        if args.use_cuda:
            torch.cuda.manual_seed(args.seed + epoch)
        model.train()
        sisnr_total = 0.0
        sisnr_print = 0.0
        mix_loss_total = 0.0 
        mix_loss_print = 0.0 
        amp_loss_total = 0.0 
        amp_loss_print = 0.0
        phase_loss_total = 0.0
        phase_loss_print = 0.0

        stime = time.time()
        lr = get_learning_rate(optimizer)
        for idx, data in enumerate(dataloader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            model.zero_grad()
            outputs, wav = data_parallel(model, (inputs,))
            loss = model.loss(outputs, labels, mode='Mix')
            loss[0].backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()
            step += 1
            sisnr = model.loss(wav, labels, mode='SiSNR')
            
            mix_loss_total += loss[0].data.cpu()
            mix_loss_print += loss[0].data.cpu()
            
            amp_loss_total += loss[1].data.cpu()
            amp_loss_print += loss[1].data.cpu()
            
            phase_loss_total += loss[2].data.cpu()
            phase_loss_print += loss[2].data.cpu()
            
            sisnr_print += sisnr.data.cpu()
            sisnr_total += sisnr.data.cpu()

            del outputs, labels, inputs, loss, wav
            if (idx+1) % 1000 == 0:
                save_checkpoint(model, optimizer, -1, step, args.exp_dir)
            if (idx + 1) % print_freq == 0:
                eplashed = time.time() - stime
                speed_avg = eplashed / (idx+1)
                mix_loss_print_avg = mix_loss_print / print_freq
                amp_loss_print_avg = amp_loss_print / print_freq
                phase_loss_print_avg = phase_loss_print / print_freq
                sisnr_print_avg = sisnr_print / print_freq
                print('Epoch {:3d}/{:3d} | batches {:5d}/{:5d} | lr {:1.4e} |'
                      '{:2.3f}s/batches '
                      '| Mixloss {:2.4f}'
                      '| AMPloss {:2.4f}'
                      '| Phaseloss {:2.4f}'
                      '| SiSNR {:2.4f}'
                      .format(
                          epoch, args.max_epoch, idx + 1, num_batch, lr,
                          speed_avg, 
                          mix_loss_print_avg,
                          amp_loss_print_avg,
                          phase_loss_print_avg,
                          -sisnr_print_avg
                    ))
                sys.stdout.flush()
                writer.add_scalar('SiSNR/Train', -sisnr_print_avg, step)
                writer.add_scalar('Loss/Train', mix_loss_print_avg, step)
                mix_loss_print = 0. 
                amp_loss_print = 0.
                phase_loss_print = 0. 
                sisnr_print = 0.

        eplashed = time.time() - stime
        mix_loss_total_avg = mix_loss_total / num_batch
        sisnr_total_avg = sisnr_total / num_batch
        print('Training AVG.LOSS |'
              ' Epoch {:3d}/{:3d} | lr {:1.4e} |'
              ' {:2.3f}s/batch | time {:3.2f}mins |'
              ' Mixloss {:2.4f}'
              ' SiSNR {:2.4f}'.format(epoch + 1, args.max_epoch, lr,
                                      eplashed / num_batch, eplashed / 60.0,
                                      mix_loss_total_avg, -sisnr_total_avg))
        val_loss, val_sisnr = validation(model, args, lr, epoch, device)
        writer.add_scalar('Loss/Cross-Validation', val_loss, step)
        writer.add_scalar('SiSNR/Cross-Validation', -val_sisnr, step)
        writer.add_scalar('learn_rate', lr, step) 
        if val_loss > scheduler.best:
            print('Rejected !!! The best is {:2.6f}'.format(scheduler.best))
        else:
            save_checkpoint(model, optimizer, epoch + 1, step, args.exp_dir, mode='best_model')
        scheduler.step(val_loss)
        sys.stdout.flush()
        stime = time.time()
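

# model.loss(..., mode='SiSNR') above is assumed to return a negative
# scale-invariant SNR (lower is better), which is why the loops log -sisnr as
# the SI-SNR in dB. A self-contained sketch of that metric (hypothetical helper,
# not the project's own implementation):
import torch


def si_snr_loss(estimate, target, eps=1e-8):
    """Negative SI-SNR in dB, averaged over the batch; inputs are (batch, samples)."""
    # remove the DC offset from both signals
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    target = target - target.mean(dim=-1, keepdim=True)
    # project the estimate onto the target signal
    dot = torch.sum(estimate * target, dim=-1, keepdim=True)
    target_energy = torch.sum(target ** 2, dim=-1, keepdim=True) + eps
    s_target = dot / target_energy * target
    e_noise = estimate - s_target
    ratio = torch.sum(s_target ** 2, dim=-1) / (torch.sum(e_noise ** 2, dim=-1) + eps)
    return -torch.mean(10 * torch.log10(ratio + eps))
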
Example #4
def train(model, args, device, writer):
    print('preparing data...')
    dataloader, dataset = make_loader(
        args.tr_list,
        args.batch_size,
        num_workers=args.num_threads,
        processer=Processer(
            win_len=args.win_len,
            win_inc=args.win_inc,
            left_context=args.left_context,
            right_context=args.right_context,
            fft_len=args.fft_len,
            window_type=args.win_type))
    print_freq = 100
    num_batch = len(dataloader)
    params = model.get_params(args.weight_decay)
    optimizer = optim.Adam(params, lr=args.learn_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', factor=0.5, patience=1, verbose=True)
    
    if args.retrain:
        start_epoch, step = reload_model(model, optimizer, args.exp_dir,
                                         args.use_cuda)
    else:
        start_epoch, step = 0, 0
    print('---------PRERUN-----------')
    lr = get_learning_rate(optimizer)
    print('(Initialization)')
    val_loss = validation(model, args, lr, -1, device)
    writer.add_scalar('Loss/Train', val_loss, step)
    writer.add_scalar('Loss/Cross-Validation', val_loss, step)

    for epoch in range(start_epoch, args.max_epoch):
        torch.manual_seed(args.seed + epoch)
        if args.use_cuda:
            torch.cuda.manual_seed(args.seed + epoch)
        model.train()
        loss_total = 0.0
        loss_print = 0.0
        stime = time.time()
        lr = get_learning_rate(optimizer)
        for idx, data in enumerate(dataloader):
            inputs, labels, lengths = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            lengths = lengths
            
            model.zero_grad()
            outputs, _ = data_parallel(model, (inputs, lengths))
            
            loss = model.loss(outputs, labels, lengths)
            
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()
            step += 1
            loss_total += loss.data.cpu()
            loss_print += loss.data.cpu()
            
            del lengths, outputs, labels, inputs, loss, _
            if (idx+1) % 3000 == 0:
                save_checkpoint(model, optimizer, epoch + 1, step, args.exp_dir)
            if (idx + 1) % print_freq == 0:
                eplashed = time.time() - stime
                speed_avg = eplashed / (idx+1)
                loss_print_avg = loss_print / print_freq
                print('Epoch {:3d}/{:3d} | batches {:5d}/{:5d} | lr {:1.4e} |'
                      '{:2.3f}s/batches | loss {:2.6f}'.format(
                          epoch, args.max_epoch, idx + 1, num_batch, lr,
                          speed_avg, loss_print_avg))
                sys.stdout.flush()
                writer.add_scalar('Loss/Train', loss_print_avg, step)
                loss_print = 0.0
        eplashed = time.time() - stime
        loss_total_avg = loss_total / num_batch
        print('Training AVG.LOSS |'
              ' Epoch {:3d}/{:3d} | lr {:1.4e} |'
              ' {:2.3f}s/batch | time {:3.2f}mins |'
              ' loss {:2.6f}'.format(epoch + 1, args.max_epoch, lr,
                                     eplashed / num_batch, eplashed / 60.0,
                                     loss_total_avg.item()))
        val_loss = validation(model, args, lr, epoch, device)
        writer.add_scalar('Loss/Cross-Validation', val_loss, step)
        writer.add_scalar('learn_rate', lr, step) 
        if val_loss > scheduler.best:
            print('Rejected !!! The best is {:2.6f}'.format(scheduler.best))
        else:
            save_checkpoint(model, optimizer, epoch + 1, step, args.exp_dir)
        scheduler.step(val_loss)
        sys.stdout.flush()
        stime = time.time()
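

# save_checkpoint()/reload_model() are used in every loop above but not shown in
# this listing. A minimal sketch of compatible helpers (assumed file layout and
# state-dict keys; the project's own versions may differ):
import os
import torch


def save_checkpoint(model, optimizer, epoch, step, exp_dir, mode='checkpoint'):
    """Persist model/optimizer state together with the epoch and step counters."""
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'step': step,
    }
    torch.save(state, os.path.join(exp_dir, '{}.pt'.format(mode)))


def reload_model(model, optimizer, exp_dir, use_cuda=True, mode='checkpoint'):
    """Restore the latest checkpoint; return (start_epoch, step)."""
    path = os.path.join(exp_dir, '{}.pt'.format(mode))
    if not os.path.isfile(path):
        return 0, 0
    state = torch.load(path, map_location='cuda' if use_cuda else 'cpu')
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    return state['epoch'], state['step']
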
Example #5
def train(model, args, device, writer):
    print('preparing data...')
    dataloader, dataset = make_loader(args.tr_list,
                                      args.batch_size,
                                      num_workers=args.num_threads,
                                      segment_length=1.5,
                                      sample_rate=args.sample_rate,
                                      processer=Processer(
                                          sample_rate=args.sample_rate, ))
    print_freq = 100
    num_batch = len(dataloader)
    params = model.get_params(args.weight_decay)
    optimizer = optim.Adam(params, lr=args.learn_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     factor=0.5,
                                                     patience=1,
                                                     verbose=True)

    if args.retrain:
        start_epoch, step = reload_model(model, optimizer, args.exp_dir,
                                         args.use_cuda)
    else:
        start_epoch, step = 0, 0
    print('---------PRERUN-----------')
    lr = get_learning_rate(optimizer)
    print('(Initialization)')
    val_loss, val_sisnr = 30, 30.  #validation(model, args, lr, -1, device)
    writer.add_scalar('Loss/Train', val_loss, step)
    writer.add_scalar('Loss/Cross-Validation', val_loss, step)

    writer.add_scalar('SISNR/Train', -val_sisnr, step)
    writer.add_scalar('SISNR/Cross-Validation', -val_sisnr, step)

    for epoch in range(start_epoch, args.max_epoch):
        torch.manual_seed(args.seed + epoch)
        if args.use_cuda:
            torch.cuda.manual_seed(args.seed + epoch)
        model.train()

        all_moniter = Moniter('All', num_batch, print_freq)
        sdr_moniter = Moniter('SDR', num_batch, print_freq)
        speaker_moniter = Moniter('Speaker', num_batch, print_freq)
        reg_moniter = Moniter('Reg', num_batch, print_freq)

        stime = time.time()
        lr = get_learning_rate(optimizer)
        for idx, data in enumerate(dataloader):
            inputs, labels, spkid = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            spkid = spkid.to(device)
            model.zero_grad()
            est_wav, speaker_loss, reg_loss = data_parallel(
                model, (inputs, spkid))
            speaker_loss = torch.mean(speaker_loss)
            reg_loss = torch.mean(reg_loss)

            sdr = model.loss(est_wav, labels, loss_mode='SDR')
            total_loss = sdr + 2 * speaker_loss + 0.3 * reg_loss
            total_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()
            all_moniter(total_loss)
            sdr_moniter(sdr)
            speaker_moniter(speaker_loss)
            reg_moniter(reg_loss)

            step += 1

            if (idx + 1) % 3000 == 0:
                save_checkpoint(model, optimizer, -1, step, args.exp_dir)
                #val_loss, val_sisnr= validation(model, args, lr, epoch, device)
                #scheduler.step(val_loss)
                #lr = get_learning_rate(optimizer)
            if (idx + 1) % print_freq == 0:
                eplashed = time.time() - stime
                speed_avg = eplashed / (idx + 1)
                log_str = ('Epoch {:3d}/{:3d} | batches {:5d}/{:5d} | lr {:1.4e} |'
                           '{:2.3f}s/batches |'
                           ' {:s} |'
                           ' {:s} |'
                           ' {:s} |'
                           ' {:s} |'.format(
                               epoch, args.max_epoch, idx + 1, num_batch, lr,
                               speed_avg, all_moniter.recent(),
                               sdr_moniter.recent(), speaker_moniter.recent(),
                               reg_moniter.recent()))
                writer.add_scalar('Loss/Train', all_moniter.rec_float(), step)
                writer.add_scalar('SISNR/Train', -sdr_moniter.rec_float(),
                                  step)
                print(log_str)
                all_moniter.reset()
                sdr_moniter.reset()
                speaker_moniter.reset()
                reg_moniter.reset()
                sys.stdout.flush()

        eplashed = time.time() - stime
        log_str = ('Training AVG.LOSS |'
                   ' Epoch {:3d}/{:3d} | lr {:1.4e} |'
                   ' {:2.3f}s/batch | time {:3.2f}mins |'
                   ' {:s} |'
                   ' {:s} |'
                   ' {:s} |'
                   ' {:s} |'.format(epoch + 1, args.max_epoch, lr,
                                    eplashed / num_batch, eplashed / 60.0,
                                    all_moniter.average(),
                                    sdr_moniter.average(),
                                    speaker_moniter.average(),
                                    reg_moniter.average()))
        print(log_str)
        val_loss, val_sisnr = validation(model, args, lr, epoch, device)
        writer.add_scalar('Loss/Cross-Validation', val_loss, step)
        writer.add_scalar('SISNR/Cross-Validation', -val_sisnr, step)
        writer.add_scalar('learn_rate', lr, step)
        if val_loss > scheduler.best:
            print('Rejected !!! The best is {:2.6f}'.format(scheduler.best))
        else:
            save_checkpoint(model,
                            optimizer,
                            epoch + 1,
                            step,
                            args.exp_dir,
                            mode='best_model')
        scheduler.step(val_loss)
        sys.stdout.flush()
        stime = time.time()
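

# The Moniter helper used above is not included in this listing. The sketch
# below matches the way it is called in the loop (__call__ with a scalar loss
# tensor, recent()/average() returning formatted strings, rec_float() returning
# the recent mean, reset() clearing the print window); the original class may
# differ:
class Moniter(object):
    def __init__(self, name, num_batch, print_freq):
        self.name = name
        self.num_batch = num_batch
        self.print_freq = print_freq
        self.recent_sum = 0.0
        self.total_sum = 0.0

    def __call__(self, value):
        value = float(value.detach().cpu())
        self.recent_sum += value
        self.total_sum += value

    def rec_float(self):
        return self.recent_sum / self.print_freq

    def recent(self):
        return '{:s} {:2.4f}'.format(self.name, self.rec_float())

    def average(self):
        return '{:s} {:2.4f}'.format(self.name, self.total_sum / self.num_batch)

    def reset(self):
        self.recent_sum = 0.0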