# Beispiel #1
# 0
    def _run_one_epoch(self, epoch, cross_valid=False):
        """Run one training (or validation) epoch with a tqdm progress bar.

        Uses gradient accumulation: `loss.backward()` runs every batch, but
        the optimizer only steps (and gradients are only clipped/zeroed)
        every `self.batch_per_step` batches.

        Args:
            epoch: 0-based epoch index, used for the progress-bar label.
            cross_valid: if True, iterate `self.cv_loader` and skip all
                gradient/optimizer work; otherwise train on `self.tr_loader`.

        Returns:
            Average loss over the iterations of this epoch.
        """
        total_loss = 0

        data_loader = self.tr_loader if not cross_valid else self.cv_loader

        with tqdm(total=len(data_loader.dataset), unit="iter") as t:
            for i, data in enumerate(data_loader):
                self.batches += 1

                # Plain if/else instead of a conditional expression evaluated
                # only for its side effect.
                if not cross_valid:
                    t.set_description(f"Epoch {epoch}")
                else:
                    t.set_description("Validation")

                padded_mixture, mixture_lengths, padded_source = data
                if self.use_cuda:
                    padded_mixture = padded_mixture.cuda()
                    mixture_lengths = mixture_lengths.cuda()
                    padded_source = padded_source.cuda()
                estimate_source = self.model(padded_mixture)
                loss, max_snr, estimate_source, reorder_estimate_source = cal_loss(padded_source, estimate_source, mixture_lengths)

                if not cross_valid:
                    # Accumulate gradients; step only every batch_per_step.
                    loss.backward()

                    if self.batches % self.batch_per_step == 0:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
                        self.optimizer.step()
                        self.optimizer.zero_grad()

                total_loss += loss.item()

                postfix = f"Iter {i + 1} | Steps {self.batches // self.batch_per_step} | Average Loss {(total_loss / (i + 1)):.{3}} | Loss {loss.item():.{6}}"
                t.set_postfix_str(postfix)
                t.update(padded_mixture.size(0))

        # NOTE: like the sibling implementations, this raises NameError if the
        # loader is empty (`i` would be unbound).
        return total_loss / (i + 1)
    def _run_one_epoch(self, epoch, cross_valid=False):
        """Iterate one epoch over the train (or cross-validation) loader.

        Args:
            epoch: 0-based epoch index (printed 1-based).
            cross_valid: if True, use `self.cv_loader` and skip optimizer
                updates.

        Returns:
            Mean loss over all processed batches.
        """
        start = time.time()
        total_loss = 0

        data_loader = self.cv_loader if cross_valid else self.tr_loader

        for batch_idx, (mix_batch, src_batch, pad_mask) in enumerate(data_loader):
            if self.use_cuda:
                mix_batch = mix_batch.cuda()
                src_batch = src_batch.cuda()
                pad_mask = pad_mask.cuda()

            estimate_source = self.model(mix_batch)
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(src_batch, estimate_source, pad_mask)

            # Optimizer work only happens during training.
            if not cross_valid:
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.max_norm)
                self.optimizer.step()

            total_loss += loss.item()

            if batch_idx % self.print_freq == 0:
                elapsed_ms = 1000 * (time.time() - start) / (batch_idx + 1)
                print('Epoch {0:3d} | Iter {1:5d} | Average Loss {2:3.3f} | '
                      'Current Loss {3:3.6f} | {4:5.1f} ms/batch'.format(
                          epoch + 1, batch_idx + 1,
                          total_loss / (batch_idx + 1), loss.item(),
                          elapsed_ms),
                      flush=True)

        return total_loss / (batch_idx + 1)
# Beispiel #3
# 0
    def _run_one_epoch(self, epoch, cross_valid=False):
        """Run one training (or validation) epoch, optionally plotting the
        per-iteration loss to a visdom window.

        Args:
            epoch: 0-based epoch index (printed 1-based).
            cross_valid: if True, iterate `self.cv_loader` and skip all
                optimizer work (and visdom plotting).

        Returns:
            Average loss over all batches of this epoch.
        """
        start = time.time()
        total_loss = 0

        data_loader = self.tr_loader if not cross_valid else self.cv_loader

        # visualizing loss using visdom
        if self.visdom_epoch and not cross_valid:
            vis_opts_epoch = dict(title=self.visdom_id + " epoch " +
                                  str(epoch),
                                  ylabel='Loss',
                                  xlabel='Epoch')
            vis_window_epoch = None
            vis_iters = torch.arange(1, len(data_loader) + 1)
            # Pre-allocated (uninitialized) buffer, filled batch-by-batch.
            vis_iters_loss = torch.Tensor(len(data_loader))

        for i, (data) in enumerate(data_loader):
            padded_mixture, mixture_lengths, padded_source = data
            if self.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()
            estimate_source = self.model(padded_mixture)
            # cal_loss also returns the estimates reordered to the best
            # source permutation (presumably a PIT-style loss — see cal_loss).
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            if not cross_valid:
                self.optimizer.zero_grad()
                loss.backward()
                # Clip gradient norm to stabilize training.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.max_norm)
                self.optimizer.step()

            total_loss += loss.item()

            if i % self.print_freq == 0:
                print('Epoch {0} | Iter {1} | Average Loss {2:.3f} | '
                      'Current Loss {3:.6f} | {4:.1f} ms/batch'.format(
                          epoch + 1, i + 1, total_loss / (i + 1), loss.item(),
                          1000 * (time.time() - start) / (i + 1)),
                      flush=True)

            # visualizing loss using visdom: create the window on the first
            # plot, then replace its contents on later updates.
            if self.visdom_epoch and not cross_valid:
                vis_iters_loss[i] = loss.item()
                if i % self.print_freq == 0:
                    x_axis = vis_iters[:i + 1]
                    y_axis = vis_iters_loss[:i + 1]
                    if vis_window_epoch is None:
                        vis_window_epoch = self.vis.line(X=x_axis,
                                                         Y=y_axis,
                                                         opts=vis_opts_epoch)
                    else:
                        self.vis.line(X=x_axis,
                                      Y=y_axis,
                                      win=vis_window_epoch,
                                      update='replace')

        return total_loss / (i + 1)
# Beispiel #4
# 0
def evaluate(args):
    """Evaluate a trained Conv-TasNet checkpoint on a dataset.

    Prints per-utterance SDRi (when `args.cal_sdr`) and SI-SNRi, followed
    by the dataset averages.

    Args:
        args: namespace with model_path, data_dir, batch_size, sample_rate,
            use_cuda and cal_sdr attributes.
    """
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Load model
    model = ConvTasNet.load_model(args.model_path)
    print(model)
    model.eval()
    # BUG FIX: honor the CLI flag instead of the hard-coded `if True:`,
    # which unconditionally moved the model to GPU and broke CPU-only runs.
    if args.use_cuda:
        model.cuda()

    # Load data (segment=-1: evaluate full utterances, no segmentation)
    dataset = AudioDataset(args.data_dir,
                           args.batch_size,
                           sample_rate=args.sample_rate,
                           segment=-1)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for i, data in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()
            # Forward
            estimate_source = model(padded_mixture)  # [B, C, T]
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flat
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: use reorder estimate source
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                total_cnt += 1
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi /
                                                        total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi /
                                                      total_cnt))
    def validation_step(self, batch, batch_nb):
        """Compute the validation loss for one batch (Lightning hook).

        Returns a dict with the scalar 'val_loss' for this batch.
        """
        padded_mixture = batch['audio_input']
        mixture_lengths = batch['lengths']
        padded_source = batch['audio_targets']

        estimate_source = self.forward(padded_mixture)
        loss, _, _, _ = cal_loss(padded_source, estimate_source,
                                 mixture_lengths)

        return {'val_loss': loss.item()}
    def _run_one_epoch(self, epoch, cross_valid=False):
        """Run one training (or validation) epoch.

        Unlike the sibling implementations, each loaded batch is randomly
        subsampled down to `self.batch_size` segments before the forward
        pass.  numpy.random.randint samples with replacement, so the same
        segment can be drawn more than once in an iteration.

        Args:
            epoch: 0-based epoch index (printed 1-based).
            cross_valid: if True, iterate `self.cv_loader` and skip all
                optimizer work.

        Returns:
            Average loss over all iterations of this epoch.
        """
        start = time.time()
        total_loss = 0

        data_loader = self.tr_loader if not cross_valid else self.cv_loader
        print('data_loader.len {}'.format(len(data_loader)))
        for i, (data) in enumerate(data_loader):
            padded_mixture_, mixture_lengths_, padded_source_ = data
            # Draw batch_size random segment indices from this loader batch.
            seg_idx = numpy.random.randint(0, padded_mixture_.shape[0], self.batch_size)
            padded_mixture = padded_mixture_[seg_idx, :]
            mixture_lengths = mixture_lengths_[seg_idx]
            padded_source = padded_source_[seg_idx, :]
            # print('seg_idx {}'.format(seg_idx))
            # print('padded_mixture_ {}'.format(padded_mixture_))
            # print('padded_source_ {}'.format(padded_source_))
            # print('padded_mixture {}'.format(padded_mixture))
            # print('padded_source {}'.format(padded_source))
            # print('mixture_lengths {}'.format(mixture_lengths))
            # print('padded_mixture.shape {}'.format(padded_mixture.shape))
            if self.use_cuda:
                # padded_mixture = padded_mixture.cuda()
                # mixture_lengths = mixture_lengths.cuda()
                # padded_source = padded_source.cuda()
                # NOTE(review): `device` is not defined in this method —
                # presumably a module-level global; confirm it exists at
                # file scope.
                padded_mixture = padded_mixture.to(device)
                mixture_lengths = mixture_lengths.to(device)
                padded_source = padded_source.to(device)
            #print('padded_mixture.shape {}'.format(padded_mixture.shape))
            estimate_source = self.model(padded_mixture)
            #print('estimate_source.shape {}'.format(estimate_source.shape))
            # cal_loss also returns the estimates reordered to the best
            # source permutation.
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            if not cross_valid:
                self.optimizer.zero_grad()
                loss.backward()
                # Clip gradient norm to stabilize training.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.max_norm)
                self.optimizer.step()

            total_loss += loss.item()

            if i % self.print_freq == 0:
                print('Epoch {0} | Iter {1} | Average Loss {2:.3f} | '
                      'Current Loss {3:.6f} | {4:.1f} ms/batch'.format(
                    epoch + 1, i + 1, total_loss / (i + 1),
                    loss.item(), 1000 * (time.time() - start) / (i + 1)),
                    flush=True)

        return total_loss / (i + 1)
    def training_step(self, batch, batch_nb):
        """Compute the training loss for one batch (Lightning hook).

        Returns a dict with the loss tensor and a 'log' entry so the
        framework records 'train_loss'.
        """
        mixture = batch['audio_input']
        lengths = batch['lengths']
        targets = batch['audio_targets']

        estimate = self.forward(mixture)
        loss, _, _, _ = cal_loss(targets, estimate, lengths)

        return {'loss': loss, 'log': {'train_loss': loss}}
    def _run_one_epoch(self, epoch, cross_valid=False):
        """Run one training (or validation) epoch for the multi-mic model.

        Args:
            epoch: 0-based epoch index (printed 1-based).
            cross_valid: if True, iterate `self.cv_loader` and skip all
                optimizer work.

        Returns:
            Average loss over all batches of this epoch.
        """
        start = time.time()
        total_loss = 0

        data_loader = self.tr_loader if not cross_valid else self.cv_loader

        # Constant "none mic" flag passed to the model.  The original code
        # rebuilt a throwaway torch.rand(2, 4, 32000) tensor every iteration
        # solely to borrow its dtype before casting to long; a single long
        # zero tensor is equivalent and allocation-free.
        none_mic = torch.zeros(1, dtype=torch.long)

        for i, data in enumerate(data_loader):

            padded_mixture, mixture_lengths, padded_source = data

            if self.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()

            estimate_source = self.model(padded_mixture, none_mic)

            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)

            if not cross_valid:
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.max_norm)
                self.optimizer.step()

            total_loss += loss.item()

            if i % self.print_freq == 0:
                #optim_state = self.optimizer.state_dict()
                #print('Learning rate adjusted to: {lr:.6f}'.format(lr=optim_state['param_groups'][0]['lr']))
                print('Epoch {0:3d} | Iter {1:5d} | Average Loss {2:3.3f} | '
                      'Current Loss {3:3.6f} | {4:5.1f} ms/batch'.format(
                          epoch + 1, i + 1, total_loss / (i + 1), loss.item(),
                          1000 * (time.time() - start) / (i + 1)),
                      flush=True)

        # Drop references to the last batch's tensors before (optionally)
        # clearing the CUDA cache to keep peak memory down between epochs.
        del padded_mixture, mixture_lengths, padded_source, \
            loss, max_snr, estimate_source, reorder_estimate_source

        if self.use_cuda:
            torch.cuda.empty_cache()

        return total_loss / (i + 1)
def evaluate(args):
    """Evaluate a trained DPTNet checkpoint.

    Prints per-utterance SDRi (when `args.cal_sdr`) and SI-SNRi, followed
    by the dataset averages.

    Args:
        args: namespace with N, C, L, H, K, B model hyper-parameters plus
            model_path, data_dir, batch_size, sample_rate, use_cuda and
            cal_sdr attributes.
    """
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Load model
    model = DPTNet(args.N, args.C, args.L, args.H, args.K, args.B)

    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()

    model_info = torch.load(args.model_path)

    # Strip any DataParallel "module." prefix from the checkpoint keys.
    state_dict = OrderedDict()
    for k, v in model_info['model_state_dict'].items():
        name = k.replace("module.", "")  # remove 'module.'
        state_dict[name] = v
    model.load_state_dict(state_dict)

    print(model)
    # BUG FIX: switch to inference mode. Every sibling evaluate() calls
    # eval(); without it, dropout/batch-norm layers keep their training
    # behavior and skew the reported metrics.
    model.eval()

    # Load data (segment=-1: evaluate full utterances, no segmentation)
    dataset = AudioDataset(args.data_dir,
                           args.batch_size,
                           sample_rate=args.sample_rate,
                           segment=-1)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for i, data in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()
            # Forward
            estimate_source = model(padded_mixture)  # [B, C, T]
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flat
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: use reorder estimate source
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                total_cnt += 1
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi /
                                                        total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi /
                                                      total_cnt))
def evaluate(args):
    """Evaluate a trained FaSNet-TAC multi-microphone checkpoint.

    Prints per-utterance SDRi (when `args.cal_sdr`) and SI-SNRi, prints the
    averages, and saves the per-utterance metric arrays to 'sisnr.npy' and
    'sdr.npy'.

    Args:
        args: namespace with the FaSNet_TAC hyper-parameters plus
            model_path, sample_rate, mic, use_cuda and cal_sdr attributes.
    """
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Load model
    model = FaSNet_TAC(enc_dim=args.enc_dim,
                       feature_dim=args.feature_dim,
                       hidden_dim=args.hidden_dim,
                       layer=args.layer,
                       segment_size=args.segment_size,
                       nspk=args.nspk,
                       win_len=args.win_len,
                       context_len=args.context_len,
                       sr=args.sample_rate)

    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()

    model_info = torch.load(args.model_path)
    try:
        model.load_state_dict(model_info['model_state_dict'])
    # BUG FIX: load_state_dict raises RuntimeError (not KeyError) on
    # mismatched keys, so the original `except KeyError` could never
    # trigger the "module." prefix fallback.  Catch both.
    except (KeyError, RuntimeError):
        state_dict = OrderedDict()
        for k, v in model_info['model_state_dict'].items():
            name = k.replace("module.", "")  # remove 'module.'
            state_dict[name] = v
        model.load_state_dict(state_dict)

    print(model)
    model.eval()

    # Load data
    dataset = AudioDataset('test',
                           batch_size=1,
                           sample_rate=args.sample_rate,
                           nmic=args.mic)
    data_loader = EvalAudioDataLoader(dataset, batch_size=1, num_workers=8)

    # Constant "none mic" flag; hoisted out of the loop (the original
    # rebuilt a throwaway torch.rand(2, 6, 32000) tensor every batch just
    # to borrow its dtype before casting to long).
    none_mic = torch.zeros(1, dtype=torch.long)

    sisnr_array = []
    sdr_array = []
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data

            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()

            # Forward
            estimate_source = model(padded_mixture, none_mic)  # [M, C, T]

            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)

            # Use the first microphone channel as the reference mixture.
            M, _, T = padded_mixture.shape
            mixture_ref = torch.chunk(padded_mixture, args.mic,
                                      dim=1)[0]  #[M, ch, T] -> [M, 1, T]
            mixture_ref = mixture_ref.view(M, T)  #[M, 1, T] -> [M, T]

            mixture = remove_pad(mixture_ref, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: use reorder estimate source
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)

            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    sdr_array.append(avg_SDRi)
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                sisnr_array.append(avg_SISNRi)
                total_cnt += 1
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi /
                                                        total_cnt))

    np.save('sisnr.npy', np.array(sisnr_array))
    np.save('sdr.npy', np.array(sdr_array))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi /
                                                      total_cnt))
def evaluate(args):
    """Evaluate a Conv-TasNet variant whose forward pass also returns
    speaker embeddings.

    Prints per-utterance SDRi (when `args.cal_sdr`) and SI-SNRi plus the
    averages.  Returns the list of estimated speaker counts — currently
    always empty, because the count-estimation code is disabled.

    Args:
        args: namespace with model_path, data_dir, batch_size, sample_rate,
            use_cuda and cal_sdr attributes.
    """
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0
    numberEsti = []
    # Load model
    model = ConvTasNet.load_model(args.model_path)
    model.eval()
    if args.use_cuda:
        model.cuda(0)

    # Load data
    dataset = AudioDataset(args.data_dir, args.batch_size,
                           sample_rate=args.sample_rate, segment=2)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for i, data in enumerate(data_loader):
            print(i)
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda(0)
                mixture_lengths = mixture_lengths.cuda(0)
                # BUG FIX: padded_source was left on the CPU while the
                # model output lived on the GPU, so cal_loss below would
                # fail with a device mismatch.
                padded_source = padded_source.cuda(0)
            # Forward: the model returns estimates and speaker embeddings.
            estimate_source, s_embed = model(padded_mixture)  # [B, C, T], [B, N, K, E]
            # Disabled speaker-count estimation (kept for reference):
            # embeddings = s_embed[0].data.cpu().numpy()
            # embedding = (embeddings.reshape((1, -1, 20)))[0]
            # number = sourceNumEsti2(embedding)
            # numberEsti.append(number)
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flat
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: use reorder estimate source
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                total_cnt += 1

    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi / total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi / total_cnt))
    # NOTE(review): hard-coded checkpoint label left by the original author.
    print("speaker:2 ./ClustertrainTFSE1New/final_paper_2_3_2chobatch6.pth.tar")

    return numberEsti
# Beispiel #12
# 0
    def _run_one_epoch(self, epoch, cross_valid=False):
        """Run one training (or validation) epoch for a model that also
        returns speaker embeddings, optionally plotting loss via visdom.

        Args:
            epoch: 0-based epoch index (printed 1-based).
            cross_valid: if True, iterate `self.cv_loader` and skip all
                optimizer work (and visdom plotting).

        Returns:
            Average loss over all batches of this epoch.
        """
        start = time.time()
        total_loss = 0

        data_loader = self.tr_loader if not cross_valid else self.cv_loader

        # visualizing loss using visdom
        if self.visdom_epoch and not cross_valid:
            vis_opts_epoch = dict(title=self.visdom_id + " epoch " + str(epoch),
                                  ylabel='Loss', xlabel='Epoch')
            vis_window_epoch = None
            vis_iters = torch.arange(1, len(data_loader) + 1)
            # Pre-allocated (uninitialized) buffer, filled batch-by-batch.
            vis_iters_loss = torch.Tensor(len(data_loader))

        for i, (data) in enumerate(data_loader):
            padded_mixture, mixture_lengths, padded_source= data #  ,c_t 
            if self.use_cuda:
                padded_mixture = padded_mixture.cuda(0)
             #   print(padded_mixture)
                mixture_lengths = mixture_lengths.cuda(0)
              #  print(mixture_lengths)
                padded_source = padded_source.cuda(0)
              #  print(padded_source)
            # Model returns both source estimates and speaker embeddings;
            # the embeddings are only used by the disabled regularizer below.
            estimate_source,s_embed = self.model(padded_mixture) # ,c_t 
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Disabled experimental eigenvalue-gap regularizer on the
            # embedding covariance (kept for reference):
            '''
            s_embed = s_embed.view(s_embed.shape[0],-1,20)
            s_embed_cov = torch.bmm(s_embed.permute(0, 2, 1),s_embed)/s_embed.shape[1] #[b,20,20]
            eig_vec = (torch.zeros(s_embed.shape[0])).cuda()
            for j in range(s_embed_cov.shape[0]):
                e,v = torch.symeig(s_embed_cov[j],eigenvectors=True) #[b,20]
                e_norm = e/torch.norm(torch.abs(e))
                eig_vec[j] = e_norm[18]-e_norm[17]
            loss2 = torch.mean(eig_vec)
            '''
           # print(loss2)
            # No-op while the loss2 regularizer above is disabled.
            loss = loss #-loss2*0.001
            if not cross_valid:
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.max_norm)
                self.optimizer.step()

            total_loss += loss.item()

            if i % self.print_freq == 0:
                print('Epoch {0} | Iter {1} | Average Loss {2:.3f} | '
                      'Current Loss {3:.6f} | {4:.1f} ms/batch'.format(
                          epoch + 1, i + 1, total_loss / (i + 1),
                          loss.item(), 1000 * (time.time() - start) / (i + 1)),
                      flush=True)

            # visualizing loss using visdom
            if self.visdom_epoch and not cross_valid:
                vis_iters_loss[i] = loss.item()
                if i % self.print_freq == 0:
                    x_axis = vis_iters[:i+1]
                    y_axis = vis_iters_loss[:i+1]
                    if vis_window_epoch is None:
                        vis_window_epoch = self.vis.line(X=x_axis, Y=y_axis,
                                                         opts=vis_opts_epoch)
                    else:
                        self.vis.line(X=x_axis, Y=y_axis, win=vis_window_epoch,
                                      update='replace')

        return total_loss / (i + 1)
    def _run_one_epoch(self, epoch, cross_valid=False):
        """Run one training (or validation) epoch for a model that returns
        intermediate-layer estimates, adding an auxiliary loss on every
        `self.args.loss_every`-th intermediate output.

        Args:
            epoch: 0-based epoch index (printed 1-based).
            cross_valid: if True, iterate `self.cv_loader` and skip all
                optimizer work (and visdom plotting).

        Returns:
            Average (combined) loss over all batches of this epoch.
        """
        start = time.time()
        total_loss = 0

        data_loader = self.tr_loader if not cross_valid else self.cv_loader

        # visualizing loss using visdom
        if self.visdom_epoch and not cross_valid:
            vis_opts_epoch = dict(title=self.visdom_id + " epoch " +
                                  str(epoch),
                                  ylabel='Loss',
                                  xlabel='Epoch')
            vis_window_epoch = None
            vis_iters = torch.arange(1, len(data_loader) + 1)
            # Pre-allocated (uninitialized) buffer, filled batch-by-batch.
            vis_iters_loss = torch.Tensor(len(data_loader))

        for i, (data) in enumerate(data_loader):
            padded_mixture, mixture_lengths, padded_source = data

            if self.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()

            # Model returns the final estimates plus a list of
            # intermediate-layer outputs used for deep supervision below.
            estimate_source, est_mask_f_all = self.model(padded_mixture)

            loss, max_snr, estimate_source, reorder_estimate_source = cal_loss(
                padded_source, estimate_source, mixture_lengths)
            #
            # import librosa
            # import numpy as np
            # print('estimate_source ' + str(estimate_source.shape))
            # print('padded_source ' + str(padded_source.shape))
            #
            #
            # wav = padded_mixture[0,:].detach().float().cpu().numpy()
            # wav = wav / np.max(wav)
            # librosa.output.write_wav('/private/home/eliyan/mix.wav', y=wav, sr=8000)
            #
            #
            # wav = estimate_source[0,0,:].detach().float().cpu().numpy()
            # wav = wav / np.max(wav)
            # librosa.output.write_wav('/private/home/eliyan/gt.wav', y=wav, sr=8000)
            #
            # wav2 = padded_source[0,0,:].detach().cpu().numpy()
            # librosa.output.write_wav('/private/home/eliyan/pred.wav', y=wav2, sr=8000)
            #
            # wav = estimate_source[0,1,:].detach().float().cpu().numpy()
            # wav = wav / np.max(wav)
            # librosa.output.write_wav('/private/home/eliyan/gt1.wav', y=wav, sr=8000)
            #
            # wav2 = padded_source[0,1,:].detach().cpu().numpy()
            # librosa.output.write_wav('/private/home/eliyan/pred1.wav', y=wav2, sr=8000)
            # exit()

            # Accumulate an auxiliary loss on every loss_every-th
            # intermediate output, then average the combined loss over the
            # number of contributing terms (final loss + iik auxiliaries).
            loss_ss_f_all = 0.0
            iik = 0
            for ii in range(len(est_mask_f_all)):
                if (ii + 1) % self.args.loss_every == 0:

                    loss_ss_f_ii = cal_loss(padded_source, est_mask_f_all[ii],
                                            mixture_lengths)
                    loss_ss_f_all += loss_ss_f_ii[0]
                    iik += 1

            loss += loss_ss_f_all
            loss /= (iik + 1)

            if not cross_valid:
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.max_norm)
                self.optimizer.step()

            total_loss += loss.item()

            if i % self.print_freq == 0:
                print('Epoch {0} | Iter {1} | Average Loss {2:.3f} | '
                      'Current Loss {3:.6f} | {4:.1f} ms/batch'.format(
                          epoch + 1, i + 1, total_loss / (i + 1), loss.item(),
                          1000 * (time.time() - start) / (i + 1)),
                      flush=True)

            # visualizing loss using visdom
            if self.visdom_epoch and not cross_valid:
                vis_iters_loss[i] = loss.item()
                if i % self.print_freq == 0:
                    x_axis = vis_iters[:i + 1]
                    y_axis = vis_iters_loss[:i + 1]
                    if vis_window_epoch is None:
                        vis_window_epoch = self.vis.line(X=x_axis,
                                                         Y=y_axis,
                                                         opts=vis_opts_epoch)
                    else:
                        self.vis.line(X=x_axis,
                                      Y=y_axis,
                                      win=vis_window_epoch,
                                      update='replace')

        return total_loss / (i + 1)