Example #1
import torch
import torch.optim as optim
from torchvision import datasets, transforms

# Net, GradientFunc, process and test are defined elsewhere in this module.


def run_main():
    """
    Trains a network and saves all batch gradients to a file: a pickled dict
    mapping parameter names to arrays, e.g. conv1.weight -> array [batch x chan x h x w].

    :return: None
    """
    device = torch.device("cuda")
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=True,
                       download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,)),
                       ])),
        batch_size=256,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,)),
                       ])),
        batch_size=512,
        shuffle=False)  # no need to shuffle the test set

    model = Net(5).to(device)
    model.train()

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    grad_func = GradientFunc()
    training_curve = []
    for epoch in range(1, 6):
        process(model, device, train_loader, optimizer, grad_func, n_steps=100)
        # save the gradients collected for this epoch (via grad_func)
        # evaluate on the test set and record the results
        loss, accuracy = test(model, device, test_loader)
        training_curve.append((loss, accuracy))
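The helpers Net, GradientFunc, process and test come from the surrounding module and are not shown. As a rough illustration of the gradient capture the docstring describes, here is a minimal sketch of what such a collector could look like, assuming process invokes it after each backward pass; the class layout, the per-parameter dict and the pickling step are assumptions, not the original implementation:

import pickle

import torch


class GradientFunc:
    """Hypothetical sketch: snapshots per-parameter gradients after each batch."""

    def __init__(self):
        self.grads = {}  # e.g. 'conv1.weight' -> list of per-batch gradient arrays

    def __call__(self, model):
        # called after loss.backward(); copy the gradients off the graph
        for name, param in model.named_parameters():
            if param.grad is not None:
                self.grads.setdefault(name, []).append(
                    param.grad.detach().cpu().numpy().copy())

    def save(self, filename):
        # pickle the accumulated {name: [gradients]} dict
        with open(filename, 'wb') as f:
            pickle.dump(self.grads, f)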
Example #2
def train(ini_file):
    ''' Performs training according to a .ini file.

    :param ini_file: (str) the path of the .ini file
    :return: the best validation accuracy and the corresponding training accuracy
    '''
    # read the configuration from the .ini file
    config = read_config(ini_file)
    # build the network, criterion and optimizer from the configuration
    model = Net(config['network']).to(device)
    criterion = Criterion(config['network'], device).to(device)
    # look the optimizer class up by name instead of calling eval()
    optimizer = getattr(optim, config['train']['optimizer'])(
        model.parameters(), lr=config['train']['learning_rate'])
    # construct the data loaders from the configuration;
    # each loader yields its whole dataset as a single batch
    train_dataset = MakeDataset(config['train']['h5_file'], is_train=True, device=device)
    test_dataset = MakeDataset(config['train']['h5_file'], is_train=False, device=device)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=len(train_dataset))
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=len(test_dataset))
    # training state (patience and models_dir are module-level settings)
    _best_acc = 0.70  # best training accuracy so far (initial threshold to beat)
    best_acc = 0.65   # best validation accuracy so far (initial threshold to beat)
    best_ep = 0
    flag = 0          # evaluations since the last improvement (early-stopping counter)
    _best_auc = 0
    best_auc = 0
    best_roc = None
    for epoch in range(1, config['train']['epochs'] + 1):
        # adjusts learning rate
        lr = adjust_learning_rate(optimizer, epoch,
                                  config['train']['learning_rate'],
                                  config['train']['lr_decay_rate'])
        # train step
        model.train()
        for X, y in train_loader:
            # makes predictions
            pred = model(X)
            train_loss = criterion(pred, y, model)
            train_FPR, train_TPR, train_ACC, train_roc, train_roc_auc, _, _, _, _ = Auc(pred, y)
            # updates parameters
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
        # valid step
        model.eval()
        for X, y in test_loader:
            # makes predictions
            with torch.no_grad():
                pred = model(X)
                valid_loss = criterion(pred, y, model)
                valid_FPR, valid_TPR, valid_ACC, valid_roc, valid_roc_auc, _, _, _, _ = Auc(pred, y)
                if valid_ACC > best_acc and train_ACC > _best_acc:
                    flag = 0
                    best_acc = valid_ACC
                    _best_acc = train_ACC
                    best_ep = epoch
                    best_auc = valid_roc_auc
                    _best_auc = train_roc_auc
                    best_roc = valid_roc
                    # save the best model
                    torch.save({
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch}, os.path.join(models_dir, os.path.basename(ini_file) + '.pth'))
                else:
                    flag += 1
                    if flag >= patience:
                        print('epoch: {}\t{:.8f}({:.8f})'.format(best_ep, _best_acc, best_acc))
                        if best_roc is not None:
                            plt.plot(best_roc[:, 0], best_roc[:, 1])
                            plt.title('ep:{}  AUC: {:.4f}({:.4f}) ACC: {:.4f}({:.4f})'.format(best_ep, _best_auc, best_auc, _best_acc, best_acc))
                            plt.show()
                        return best_acc, _best_acc
        # note: the train and valid loaders each hold exactly one batch,
        # so the losses printed below cover the full train/valid sets
        print('\rEpoch: {}\tLoss: {:.8f}({:.8f})\tACC: {:.8f}({:.8f})\tAUC: {}({})\tFPR: {:.8f}({:.8f})\tTPR: {:.8f}({:.8f})\tlr: {:g}\n'.format(
            epoch, train_loss.item(), valid_loss.item(), train_ACC, valid_ACC, train_roc_auc, valid_roc_auc, train_FPR, valid_FPR, train_TPR, valid_TPR, lr), end='', flush=False)
    return best_acc, _best_acc
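The adjust_learning_rate helper is not shown in this excerpt. A plausible minimal sketch, assuming a simple exponential decay per epoch (the schedule itself is an assumption; the original could equally use step or cosine decay):

def adjust_learning_rate(optimizer, epoch, base_lr, decay_rate):
    # hypothetical sketch: exponentially decay the learning rate each epoch
    lr = base_lr * (decay_rate ** (epoch - 1))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr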
Example #3
    def train(self, args):
        with open(args.train_list, 'r') as train_list_file:
            self.train_list = [line.strip() for line in train_list_file.readlines()]
        self.eval_file = args.eval_file
        self.num_train_sentences = args.num_train_sentences
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.max_epoch = args.max_epoch
        self.model_path = args.model_path
        self.log_path = args.log_path
        self.fig_path = args.fig_path
        self.eval_plot_num = args.eval_plot_num
        self.eval_steps = args.eval_steps
        self.resume_model = args.resume_model
        self.wav_path = args.wav_path
        self.train_wav_path = args.train_wav_path
        self.tool_path = args.tool_path

        # create a training dataset and an evaluation dataset
        trainSet = TrainingDataset(self.train_list,
                                   frame_size=self.frame_size,
                                   frame_shift=self.frame_shift)
        evalSet = EvalDataset(self.eval_file,
                              self.num_test_sentences)
        # trainSet = evalSet
        # create data loaders for training and evaluation
        train_loader = DataLoader(trainSet,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=16,
                                  collate_fn=TrainCollate())

        eval_loader = DataLoader(evalSet,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=EvalCollate())

        # create a network
        print('model', self.model_name)
        net = Net(device=self.device, L=self.frame_size, width=self.width)
        # net = torch.nn.DataParallel(net)
        net.to(self.device)
        print('Number of learnable parameters: %d' % numParams(net))
        print(net)

        criterion = mse_loss()
        criterion1 = stftm_loss(device=self.device)
        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
        self.lr_list = [0.0002] * 3 + [0.0001] * 6 + [0.00005] * 3 + [0.00001] * 3
        if self.resume_model:
            print('Resume model from "%s"' % self.resume_model)
            checkpoint = Checkpoint()
            checkpoint.load(self.resume_model)
            start_epoch = checkpoint.start_epoch
            start_iter = checkpoint.start_iter
            best_loss = checkpoint.best_loss
            net.load_state_dict(checkpoint.state_dict)
            optimizer.load_state_dict(checkpoint.optimizer)
        else:
            print('Training from scratch.')
            start_epoch = 0
            start_iter = 0
            best_loss = np.inf

        num_train_batches = self.num_train_sentences // self.batch_size
        total_train_batch = self.max_epoch * num_train_batches
        print('num_train_sentences', self.num_train_sentences)
        print('batches_per_epoch', num_train_batches)
        print('total_train_batch', total_train_batch)
        print('batch_size', self.batch_size)
        print('model_name', self.model_name)
        batch_timings = 0.
        counter = int(start_epoch * num_train_batches + start_iter)
        counter1 = 0
        print('counter', counter)
        ttime = 0.
        cnt = 0.
        iteration = 0
        print('best_loss', best_loss)
        for epoch in range(start_epoch, self.max_epoch):
            accu_train_loss = 0.0
            net.train()
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.lr_list[epoch]

            start = timeit.default_timer()
            # features: torch.Size([batch, 1, frames, frame_size])
            for i, (features, labels, nframes, feat_size, label_size, get_filename) in enumerate(train_loader):
                iteration += 1
                i += start_iter
                features, labels = features.to(self.device), labels.to(self.device)

                loss_mask = compLossMask(labels, nframes=nframes)

                # forward + backward + optimize
                optimizer.zero_grad()

                outputs = net(features)  # torch.Size([4, 1, 64256])

                # fbank feature extractor (could be hoisted outside the batch loop)
                feature_maker = Fbank(sample_rate=16000, n_fft=400, n_mels=40)
                loss_fbank = 0

                for t in range(len(get_filename)):
                    # read the noisy/clean pair; the with-block closes the file handle
                    with h5py.File(get_filename[t], 'r') as reader:
                        feature_asr = reader['noisy_raw'][:]
                        label_asr = reader['clean_raw'][:]

                    feat_asr_size = int(feat_size[t][0].item())
                    label_asr_size = int(label_size[t][0].item())

                    output_asr = self.train_asr_forward(feature_asr, net)
                    est_output_asr = output_asr[:feat_asr_size]
                    ideal_labels_asr = label_asr

                    # save the training wavs
                    est_path = os.path.join(self.train_wav_path, '{}_est.wav'.format(t + 1))
                    ideal_path = os.path.join(self.train_wav_path, '{}_ideal.wav'.format(t + 1))
                    sf.write(est_path, normalize_wav(est_output_asr)[0], self.srate)
                    sf.write(ideal_path, normalize_wav(ideal_labels_asr)[0], self.srate)

                    # read the wavs back and extract fbank features
                    est_sig = sb.dataio.dataio.read_audio(est_path).unsqueeze(0).to(self.device)
                    ideal_sig = sb.dataio.dataio.read_audio(ideal_path).unsqueeze(0).to(self.device)
                    est_sig_feats = feature_maker(est_sig)
                    ideal_sig_feats = feature_maker(ideal_sig)

                    # accumulate the fbank distance
                    loss_fbank += F.mse_loss(est_sig_feats, ideal_sig_feats)

                # average over the files in the batch and scale down
                loss_fbank /= 100 * len(get_filename)

                # trim the outputs to the label length
                outputs = outputs[:, :, :labels.shape[-1]]

                loss1 = criterion(outputs, labels, loss_mask, nframes)
                loss2 = criterion1(outputs, labels, loss_mask, nframes)

                # weighted sum of time-domain, stft-magnitude and fbank losses
                # (an earlier version used loss = 0.8 * loss1 + 0.2 * loss2)
                loss = 0.4 * loss1 + 0.1 * loss2 + 0.5 * loss_fbank

                loss.backward()
                optimizer.step()
                # calculate losses
                running_loss = loss.data.item()
                accu_train_loss += running_loss

                # log the accumulated training loss since the last evaluation
                summary.add_scalar('Train Loss', accu_train_loss, iteration)

                cnt += 1.
                counter += 1
                counter1 += 1

                del loss, loss_fbank, loss1, loss2, outputs, loss_mask, features, labels
                end = timeit.default_timer()
                curr_time = end - start
                ttime += curr_time
                mtime = ttime / counter1
                print(
                    'iter = {}/{}, epoch = {}/{}, loss = {:.5f}, time/batch = {:.5f}, mtime/batch = {:.5f}'.format(
                        i + 1,
                        num_train_batches, epoch + 1, self.max_epoch, running_loss, curr_time, mtime))
                start = timeit.default_timer()
                if (i + 1) % self.eval_steps == 0:
                    start = timeit.default_timer()

                    avg_train_loss = accu_train_loss / cnt

                    avg_eval_loss = self.validate(net, eval_loader, iteration)

                    net.train()

                    print('Epoch [%d/%d], Iter [%d/%d]  ( TrainLoss: %.4f | EvalLoss: %.4f )' % (
                        epoch + 1, self.max_epoch, i + 1, self.num_train_sentences // self.batch_size,
                        avg_train_loss,
                        avg_eval_loss))

                    is_best = avg_eval_loss < best_loss
                    best_loss = min(avg_eval_loss, best_loss)

                    checkpoint = Checkpoint(epoch, i, avg_train_loss, avg_eval_loss, best_loss, net.state_dict(),
                                            optimizer.state_dict())

                    model_name = self.model_name + '_latest.model'
                    best_model = self.model_name + '_best.model'
                    checkpoint.save(is_best, os.path.join(self.model_path, model_name),
                                    os.path.join(self.model_path, best_model))

                    logging(self.log_path, self.model_name + '_loss_log.txt', checkpoint, self.eval_steps)
                    # metric_logging(self.log_path, self.model_name +'_metric_log.txt', epoch+1, [avg_st, avg_sn, avg_pe])
                    accu_train_loss = 0.0
                    cnt = 0.
                if (i + 1) % num_train_batches == 0:
                    break

        avg_st, avg_sn, avg_pe = self.validate_with_metrics(net, eval_loader)
        net.train()
        print('#' * 50)
        print('')
        print('After {} epochs, the performance on the validation set is as follows:'.format(epoch + 1))
        print('')
        print('STOI: {:.4f}'.format(avg_st))
        print('SNR: {:.4f}'.format(avg_sn))
        print('PESQ: {:.4f}'.format(avg_pe))
        for param_group in optimizer.param_groups:
            print('learning_rate', param_group['lr'])
        print('')
        print('#' * 50)
        checkpoint = Checkpoint(epoch, 0, None, None, best_loss, net.state_dict(), optimizer.state_dict())
        # use an explicit best-model path here: best_model is only bound inside the
        # evaluation branch above and may not exist yet
        checkpoint.save(False, os.path.join(self.model_path, self.model_name + '-{}.model'.format(epoch + 1)),
                        os.path.join(self.model_path, self.model_name + '_best.model'))
        metric_logging(self.log_path, self.model_name + '_metric_log.txt', epoch, [avg_st, avg_sn, avg_pe])
        start_iter = 0
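A side note on the fbank term above: each estimate is written to disk with sf.write and immediately read back with read_audio only to run the Fbank extractor. Since Fbank operates on [batch, time] tensors, the disk round-trip can likely be skipped. A minimal sketch of the idea as a drop-in for that loop body, assuming normalize_wav(...)[0] returns a 1-D float array (that format is an assumption):

                    # hypothetical sketch: compute the fbank loss directly on tensors
                    est = torch.from_numpy(normalize_wav(est_output_asr)[0]).float()
                    ideal = torch.from_numpy(normalize_wav(ideal_labels_asr)[0]).float()
                    est_feats = feature_maker(est.unsqueeze(0).to(self.device))
                    ideal_feats = feature_maker(ideal.unsqueeze(0).to(self.device))
                    loss_fbank += F.mse_loss(est_feats, ideal_feats)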
Example #4
File: main.py  Project: wayneowen7/GSAPool
        f.write(str(test_acc))
        f.write('\r\n')


# training configuration
train_loader, val_loader, test_loader = data_builder(args)
model = Net(args).to(args.device)
optimizer = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             weight_decay=args.weight_decay)

# training steps
patience = 0
min_loss = args.min_loss
for epoch in range(args.epochs):
    model.train()
    for i, data in enumerate(train_loader):
        data = data.to(args.device)
        out = model(data)
        loss = F.nll_loss(out, data.y)
        print("Training loss: {}".format(loss.item()))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    val_acc, val_loss = test(model, val_loader)
    print("Epoch {}\tValidation loss: {}\taccuracy: {}".format(epoch, val_loss, val_acc))
    if val_loss < min_loss:
        torch.save(model.state_dict(), 'latest.pth')
        print("Model saved at epoch {}".format(epoch))
        min_loss = val_loss
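The test helper used for validation above is not shown in this excerpt. A minimal sketch of what it could look like for an NLL-trained classifier over this kind of loader; the body is an assumption, matching only the (accuracy, loss) return signature used above:

def test(model, loader):
    # hypothetical sketch: mean NLL loss and accuracy over a loader
    model.eval()
    correct, total, total_loss = 0, 0, 0.0
    with torch.no_grad():
        for data in loader:
            data = data.to(args.device)
            out = model(data)
            total_loss += F.nll_loss(out, data.y, reduction='sum').item()
            correct += (out.argmax(dim=1) == data.y).sum().item()
            total += data.y.size(0)
    return correct / total, total_loss / total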
Example #5
class Model:
    def __init__(self, path=current_folder, learning_rate=1e-3, batch_size=128):
        torch.manual_seed(12345)
        self.path = path
        self.batch_size = batch_size
        self.data_creator = DataCreator(self.batch_size)
        self.learning_rate = learning_rate
        try:
            self.net = torch.load(self.path + "/net.pth")
            print("--------------------------------\n"
                  "Models were loaded successfully! \n"
                  "--------------------------------")
        except FileNotFoundError:
            # no saved model on disk; start from a fresh network
            print("-----------------------\n"
                  "No models were loaded! \n"
                  "-----------------------")
            self.net = Net(input_dim=225, hidden_dim=450)
        self.net.cuda()

    def predict_signal(self, ticker):
        signals = ['SELL', 'BUY', 'HOLD']
        _, data = get_daily_data(ticker, compact=True)
        self.net.train(False)
        with torch.no_grad():
            # use the most recent observation as the network input
            x = torch.tensor(data.to_numpy()[-1]).float().cuda()
            output = F.softmax(self.net(x), dim=-1).cpu().numpy()
            signal_idx = np.argmax(output)
        return signals[int(signal_idx)], 100 * output[signal_idx]

    def test(self):

        losses = []
        accuracies = []
        buy_accuracies = []
        sell_accuracies = []
        hold_accuracies = []
        data_loader = self.data_creator.provide_testing_stock()
        criterion = nn.CrossEntropyLoss()
        self.net.train(False)
        with torch.no_grad():
            for i, (batch_x, batch_y) in enumerate(data_loader):
                batch_x = batch_x.float().cuda()
                batch_y = batch_y.long().cuda()

                output = self.net(batch_x)
                loss = criterion(output, batch_y)

                output_metric = np.argmax(F.softmax(output, dim=1).cpu().numpy(), axis=1)
                batch_size = batch_y.size()[0]
                batch_y = batch_y.cpu().numpy()
                # per-class "accuracy" below compares the binary is-this-class masks,
                # so true negatives also count as correct
                sell_mask_label = batch_y == 0
                sell_mask_output = output_metric == 0
                sell_accuracies.append(100 * (sell_mask_label == sell_mask_output).sum() / batch_size)
                buy_mask_label = batch_y == 1
                buy_mask_output = output_metric == 1
                buy_accuracies.append(100 * (buy_mask_label == buy_mask_output).sum() / batch_size)
                hold_mask_label = batch_y == 2
                hold_mask_output = output_metric == 2
                hold_accuracies.append(100 * (hold_mask_label == hold_mask_output).sum() / batch_size)
                losses.append(loss.item())
                accuracy = 100 * (output_metric == batch_y).sum() / batch_size
                accuracies.append(accuracy)
        print("Average loss: ", np.mean(losses))
        print("Average accuracy: ", np.mean(accuracies))
        print("Buy-Average accuracy: ", np.mean(buy_accuracies))
        print("Sell-Average accuracy: ", np.mean(sell_accuracies))
        print("Hold-Average accuracy: ", np.mean(hold_accuracies))

    def train(self, epochs):

        rocs_aucs = []
        baseline_rocs_aucs = []
        losses = []
        accuracies = []
        data_loader, class_weights = self.data_creator.provide_training_stock()
        # use the class weights from the data creator (previously fetched but unused);
        # assumes class_weights is a float tensor
        criterion = nn.CrossEntropyLoss(weight=class_weights.float().cuda())
        optimiser = optim.AdamW(self.net.parameters(), lr=self.learning_rate, weight_decay=1e-5, amsgrad=True)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, patience=220, min_lr=1e-9)
        self.net.train(True)
        pbar = tqdm(total=epochs)

        # train the network
        for epoch in range(epochs):

            for i, (batch_x, batch_y) in enumerate(data_loader):
                batch_x = batch_x.float().cuda()
                batch_y = batch_y.long().cuda()

                self.net.zero_grad()
                output = self.net(batch_x)
                loss = criterion(output, batch_y)
                loss.backward()
                optimiser.step()

                scheduler.step(loss.item())

                # Print some loss stats
                if i % 2 == 0:
                    output_metric = F.softmax(output.detach().cpu(), dim=1).numpy()
                    random_metric = relabel_data(np.random.choice([0, 1, 2], size=(1, self.batch_size), p=[1/3, 1/3, 1/3]))
                    label_metric = relabel_data(batch_y.detach().cpu().numpy())
                    losses.append((loss.item()))
                    rocs_aucs.append(roc_auc_score(label_metric, output_metric, multi_class='ovo'))
                    baseline_rocs_aucs.append(roc_auc_score(label_metric, random_metric, multi_class='ovo'))
                    # use the actual batch length: the final batch may be smaller than self.batch_size
                    accuracy = 100 * sum(1 if np.argmax(output_metric[k]) == np.argmax(label_metric[k]) else 0
                                         for k in range(len(output_metric))) / len(output_metric)
                    accuracies.append(accuracy)
            pbar.update(1)
        pbar.close()
        fig, axs = plt.subplots(1, 3)
        axs[0].plot(np.convolve(losses, (1/25)*np.ones(25), mode='valid'))
        axs[1].plot(np.convolve(rocs_aucs, (1/25)*np.ones(25), mode='valid'))
        axs[1].plot(np.convolve(baseline_rocs_aucs, (1/25)*np.ones(25), mode='valid'))
        axs[1].legend(['Net', 'Baseline'])
        axs[2].plot(np.convolve(accuracies, (1/25)*np.ones(25), mode='valid'))
        plt.show()

    def save(self):
        torch.save(self.net, self.path + "/net.pth")
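A note on persistence: torch.save(self.net, ...) pickles the whole module, which ties the checkpoint to the exact class definition and module path. A commonly preferred sketch using state_dict instead (the file name net_state.pth is an assumption):

    def save(self):
        # save only the parameters; robust across code refactors
        torch.save(self.net.state_dict(), self.path + "/net_state.pth")

    def load(self):
        # recreate the architecture, then restore the parameters
        self.net = Net(input_dim=225, hidden_dim=450)
        self.net.load_state_dict(torch.load(self.path + "/net_state.pth"))
        self.net.cuda()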
Example #6
    def train(self, args):
        with open(args.train_list, 'r') as train_list_file:
            self.train_list = [
                line.strip() for line in train_list_file.readlines()
            ]
        self.eval_file = args.eval_file
        self.num_train_sentences = args.num_train_sentences
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.max_epoch = args.max_epoch
        self.model_path = args.model_path
        self.log_path = args.log_path
        self.fig_path = args.fig_path
        self.eval_plot_num = args.eval_plot_num
        self.eval_steps = args.eval_steps
        self.resume_model = args.resume_model
        self.wav_path = args.wav_path
        self.tool_path = args.tool_path

        # create a training dataset and an evaluation dataset
        trainSet = TrainingDataset(self.train_list,
                                   frame_size=self.frame_size,
                                   frame_shift=self.frame_shift)
        evalSet = EvalDataset(self.eval_file, self.num_test_sentences)
        # trainSet = evalSet
        # create data loaders for training and evaluation
        train_loader = DataLoader(trainSet,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=16,
                                  collate_fn=TrainCollate())
        eval_loader = DataLoader(evalSet,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=EvalCollate())

        # create a network
        print('model', self.model_name)
        net = Net(device=self.device, L=self.frame_size, width=self.width)
        # net = torch.nn.DataParallel(net)
        net.to(self.device)
        print('Number of learnable parameters: %d' % numParams(net))
        print(net)

        criterion = mse_loss()
        criterion1 = stftm_loss(device=self.device)
        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
        self.lr_list = [0.0002] * 3 + [0.0001] * 6 + [0.00005] * 3 + [0.00001] * 3
        if self.resume_model:
            print('Resume model from "%s"' % self.resume_model)
            checkpoint = Checkpoint()
            checkpoint.load(self.resume_model)
            start_epoch = checkpoint.start_epoch
            start_iter = checkpoint.start_iter
            best_loss = checkpoint.best_loss
            net.load_state_dict(checkpoint.state_dict)
            optimizer.load_state_dict(checkpoint.optimizer)
        else:
            print('Training from scratch.')
            start_epoch = 0
            start_iter = 0
            best_loss = np.inf

        num_train_batches = self.num_train_sentences // self.batch_size
        total_train_batch = self.max_epoch * num_train_batches
        print('num_train_sentences', self.num_train_sentences)
        print('batches_per_epoch', num_train_batches)
        print('total_train_batch', total_train_batch)
        print('batch_size', self.batch_size)
        print('model_name', self.model_name)
        batch_timings = 0.
        counter = int(start_epoch * num_train_batches + start_iter)
        counter1 = 0
        print('counter', counter)
        ttime = 0.
        cnt = 0.
        print('best_loss', best_loss)
        for epoch in range(start_epoch, self.max_epoch):
            accu_train_loss = 0.0
            net.train()
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.lr_list[epoch]

            start = timeit.default_timer()
            for i, (features, labels, nframes) in enumerate(train_loader):
                i += start_iter
                features, labels = features.to(self.device), labels.to(self.device)

                loss_mask = compLossMask(labels, nframes=nframes)

                # forward + backward + optimize
                optimizer.zero_grad()

                outputs = net(features)
                outputs = outputs[:, :, :labels.shape[-1]]

                loss1 = criterion(outputs, labels, loss_mask, nframes)
                loss2 = criterion1(outputs, labels, loss_mask, nframes)

                loss = 0.8 * loss1 + 0.2 * loss2
                loss.backward()
                optimizer.step()
                # calculate losses
                running_loss = loss.data.item()
                accu_train_loss += running_loss

                cnt += 1.
                counter += 1
                counter1 += 1

                del loss, loss1, loss2, outputs, loss_mask, features, labels
                end = timeit.default_timer()
                curr_time = end - start
                ttime += curr_time
                mtime = ttime / counter1
                print(
                    'iter = {}/{}, epoch = {}/{}, loss = {:.5f}, time/batch = {:.5f}, mtime/batch = {:.5f}'
                    .format(i + 1, num_train_batches, epoch + 1,
                            self.max_epoch, running_loss, curr_time, mtime))
                start = timeit.default_timer()
                if (i + 1) % self.eval_steps == 0:
                    start = timeit.default_timer()

                    avg_train_loss = accu_train_loss / cnt

                    avg_eval_loss = self.validate(net, eval_loader)

                    net.train()

                    print(
                        'Epoch [%d/%d], Iter [%d/%d]  ( TrainLoss: %.4f | EvalLoss: %.4f )'
                        % (epoch + 1, self.max_epoch, i + 1,
                           self.num_train_sentences // self.batch_size,
                           avg_train_loss, avg_eval_loss))

                    is_best = avg_eval_loss < best_loss
                    best_loss = min(avg_eval_loss, best_loss)

                    checkpoint = Checkpoint(epoch, i, avg_train_loss,
                                            avg_eval_loss, best_loss,
                                            net.state_dict(),
                                            optimizer.state_dict())

                    model_name = self.model_name + '_latest.model'
                    best_model = self.model_name + '_best.model'
                    checkpoint.save(is_best,
                                    os.path.join(self.model_path, model_name),
                                    os.path.join(self.model_path, best_model))

                    logging(self.log_path, self.model_name + '_loss_log.txt',
                            checkpoint, self.eval_steps)
                    # metric_logging(self.log_path, self.model_name + '_metric_log.txt', epoch+1, [avg_st, avg_sn, avg_pe])
                    accu_train_loss = 0.0
                    cnt = 0.
                if (i + 1) % num_train_batches == 0:
                    break

            avg_st, avg_sn, avg_pe = self.validate_with_metrics(
                net, eval_loader)
            net.train()
            print('#' * 50)
            print('')
            print('After {} epochs, the performance on the validation set is as follows:'.format(epoch + 1))
            print('')
            print('STOI: {:.4f}'.format(avg_st))
            print('SNR: {:.4f}'.format(avg_sn))
            print('PESQ: {:.4f}'.format(avg_pe))
            for param_group in optimizer.param_groups:
                print('learning_rate', param_group['lr'])
            print('')
            print('#' * 50)
            checkpoint = Checkpoint(epoch, 0, None, None, best_loss,
                                    net.state_dict(), optimizer.state_dict())
            # use an explicit best-model path: best_model is only bound inside the
            # evaluation branch above and may not exist yet
            checkpoint.save(
                False,
                os.path.join(self.model_path,
                             self.model_name + '-{}.model'.format(epoch + 1)),
                os.path.join(self.model_path, self.model_name + '_best.model'))
            metric_logging(self.log_path, self.model_name + '_metric_log.txt',
                           epoch, [avg_st, avg_sn, avg_pe])
            start_iter = 0
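Both this example and Example #3 rely on compLossMask to keep zero-padded frames out of the losses; the helper itself is not shown. A minimal sketch of the idea, assuming nframes gives each utterance's valid length along the last dimension of labels (all shapes here are assumptions):

import torch


def compLossMask(labels, nframes):
    # hypothetical sketch: 1.0 over valid positions, 0.0 over padding
    mask = torch.zeros_like(labels)
    for b, n in enumerate(nframes):
        mask[b, ..., :int(n)] = 1.0
    return mask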