Exemple #1
0
def run_main():
    """
    Train a network and save the per-batch gradients of every epoch to disk.

    Everything is written under ``gradients/test``:

    * ``weights.pth`` -- the model's initial ``state_dict`` (before any step).
    * ``epoch_<k>.pkl`` (k = 1..5) -- the pickled object returned by
      ``record_grad`` for that epoch.  Per the original author this is a dict
      like ``{"conv1.weight": array [batch x chan x h x w], ...}``;
      ``record_grad`` is defined elsewhere, so confirm there.
    * ``train_original.txt`` -- one ``loss accuracy`` row per epoch, as
      returned by ``test``.

    Uses CUDA when available and falls back to the CPU otherwise.

    :return: None
    """
    gradient_save_path = "gradients/test"
    os.makedirs(gradient_save_path, exist_ok=True)

    # Fall back to the CPU so the script also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # MNIST mean / std normalisation; both transforms are stateless, so one
    # pipeline can be shared by the train and test datasets.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, ))
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=mnist_transform),
        batch_size=256,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=mnist_transform),
        batch_size=256,
        shuffle=True)

    model = Net(5).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    # save initial weights
    torch.save(model.state_dict(),
               os.path.join(gradient_save_path, "weights.pth"))
    training_curve = []
    for epoch in range(1, 6):
        gradients = record_grad(model,
                                device,
                                train_loader,
                                optimizer,
                                n_steps=300)
        # Save the epoch's gradients.  The context manager guarantees the file
        # handle is closed, even if pickling fails.
        epoch_path = os.path.join(gradient_save_path,
                                  "epoch_{}.pkl".format(epoch))
        with open(epoch_path, "wb") as grad_file:
            pickle.dump(gradients, grad_file)
        # Evaluate on the test set and record this epoch's point on the curve.
        loss, accuracy = test(model, device, test_loader)
        training_curve.append([loss, accuracy])
    np.savetxt(os.path.join(gradient_save_path, "train_original.txt"),
               np.array(training_curve))
Exemple #2
0
def train(ini_file):
    ''' Performs training according to .ini file

    Each loader yields the whole dataset as a single batch, so every epoch is
    one full-batch optimizer step followed by one validation pass.  Whenever
    both the validation accuracy and the training accuracy exceed their
    current best, the model / optimizer / epoch are saved to
    ``models_dir/<ini file name>.pth``.  Otherwise a counter is incremented;
    once it reaches ``patience`` the best ROC curve (if any) is plotted and
    training stops early.

    Relies on the module-level names ``device``, ``models_dir`` and
    ``patience``.

    :param ini_file: (String) the path of .ini file
    :return: (best_acc, _best_acc) -- the best validation accuracy and the
        training accuracy recorded together with it.  If no epoch beats the
        initial floors (0.65 validation / 0.70 training) those floors are
        returned.
    '''
    # reads configuration from .ini file
    config = read_config(ini_file)
    # builds network|criterion|optimizer based on configuration
    model = Net(config['network']).to(device)
    criterion = Criterion(config['network'], device).to(device)
    # Look the optimizer class up by name instead of eval()-ing a string built
    # from the config file: identical for names such as "Adam" or "SGD", but it
    # cannot execute arbitrary code.
    optimizer_cls = getattr(optim, config['train']['optimizer'])
    optimizer = optimizer_cls(model.parameters(),
                              lr=config['train']['learning_rate'])
    # constructs data loaders based on configuration
    train_dataset = MakeDataset(config['train']['h5_file'], is_train=True, device=device)
    test_dataset = MakeDataset(config['train']['h5_file'], is_train=False, device=device)
    # batch_size == dataset size, so each loader yields exactly one batch.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=len(train_dataset))
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=len(test_dataset))
    # The checkpoint is named after the .ini file.  Split on both '\\' and '/'
    # so the directory part is dropped on Windows and POSIX alike (joining a
    # full absolute path would make os.path.join ignore models_dir).
    ini_name = ini_file.replace('\\', '/').split('/')[-1]
    # training
    # Best-so-far trackers.  The accuracy entries start at fixed floors, so a
    # checkpoint is only written once both accuracies exceed them.
    _best_acc = 0.70   # training accuracy at the best epoch
    best_acc = 0.65    # best validation accuracy
    best_ep = 0
    flag = 0           # consecutive validation checks without improvement
    _best_auc = 0
    best_auc = 0
    best_roc = None
    for epoch in range(1, config['train']['epochs'] + 1):
        # adjusts learning rate
        lr = adjust_learning_rate(optimizer, epoch,
                                  config['train']['learning_rate'],
                                  config['train']['lr_decay_rate'])
        # train step
        model.train()
        for X, y in train_loader:
            # makes predictions
            pred = model(X)
            train_loss = criterion(pred, y, model)
            train_FPR, train_TPR, train_ACC, train_roc, train_roc_auc, _, _, _, _ = Auc(pred, y)
            # updates parameters
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
        # valid step
        model.eval()
        for X, y in test_loader:
            # makes predictions
            with torch.no_grad():
                pred = model(X)
                valid_loss = criterion(pred, y, model)
                valid_FPR, valid_TPR, valid_ACC, valid_roc, valid_roc_auc, _, _, _, _ = Auc(pred, y)
                if valid_ACC > best_acc and train_ACC > _best_acc:
                    flag = 0
                    best_acc = valid_ACC
                    _best_acc = train_ACC
                    best_ep = epoch
                    best_auc = valid_roc_auc
                    _best_auc = train_roc_auc
                    best_roc = valid_roc
                    # saves the best model
                    torch.save({
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch}, os.path.join(models_dir, ini_name + '.pth'))
                else:
                    flag += 1
                    if flag >= patience:
                        # Early stop: report and plot the best ROC curve seen.
                        print('epoch: {}\t{:.8f}({:.8f})'.format(best_ep, _best_acc, best_acc))
                        if best_roc is not None:
                            plt.plot(best_roc[:, 0], best_roc[:, 1])
                            plt.title('ep:{}  AUC: {:.4f}({:.4f}) ACC: {:.4f}({:.4f})'.format(best_ep, _best_auc, best_auc, _best_acc, best_acc))
                            plt.show()
                        return best_acc, _best_acc
        # Both loaders have exactly one batch, so the train_*/valid_* values
        # left over from the loops above describe the whole epoch.
        print('\rEpoch: {}\tLoss: {:.8f}({:.8f})\tACC: {:.8f}({:.8f})\tAUC: {}({})\tFPR: {:.8f}({:.8f})\tTPR: {:.8f}({:.8f})\tlr: {:g}\n'.format(
            epoch, train_loss.item(), valid_loss.item(), train_ACC, valid_ACC, train_roc_auc, valid_roc_auc, train_FPR, valid_FPR, train_TPR, valid_TPR, lr), end='', flush=False)
    return best_acc, _best_acc
    def train(self, args):
        """Train ``Net`` with a weighted MSE + STFT-magnitude + fbank loss.

        Copies the run configuration from ``args`` onto ``self``, builds the
        training / evaluation data loaders, optionally resumes from
        ``args.resume_model`` and trains from ``start_epoch`` up to
        ``self.max_epoch``.  Every ``self.eval_steps`` iterations the model is
        validated and ``<model_name>_latest.model`` / ``<model_name>_best.model``
        checkpoints are written.  At the end of every epoch STOI / SNR / PESQ
        are computed and ``<model_name>-<epoch>.model`` is saved.

        Besides ``args`` this reads ``self.frame_size``, ``self.frame_shift``,
        ``self.num_test_sentences``, ``self.model_name``, ``self.device``,
        ``self.width`` and ``self.srate``, which must already be set.

        :param args: parsed options providing train_list, eval_file,
            num_train_sentences, batch_size, lr, max_epoch, model_path,
            log_path, fig_path, eval_plot_num, eval_steps, resume_model,
            wav_path, train_wav_path and tool_path.
        """
        with open(args.train_list, 'r') as train_list_file:
            self.train_list = [line.strip() for line in train_list_file.readlines()]
        self.eval_file = args.eval_file
        self.num_train_sentences = args.num_train_sentences
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.max_epoch = args.max_epoch
        self.model_path = args.model_path
        self.log_path = args.log_path
        self.fig_path = args.fig_path
        self.eval_plot_num = args.eval_plot_num
        self.eval_steps = args.eval_steps
        self.resume_model = args.resume_model
        self.wav_path = args.wav_path
        self.train_wav_path = args.train_wav_path
        self.tool_path = args.tool_path

        # create a training dataset and an evaluation dataset
        trainSet = TrainingDataset(self.train_list,
                                   frame_size=self.frame_size,
                                   frame_shift=self.frame_shift)
        evalSet = EvalDataset(self.eval_file,
                              self.num_test_sentences)
        # create data loaders for training and evaluation
        train_loader = DataLoader(trainSet,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=16,
                                  collate_fn=TrainCollate())

        eval_loader = DataLoader(evalSet,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=EvalCollate())

        # create a network
        print('model', self.model_name)
        net = Net(device=self.device, L=self.frame_size, width=self.width)
        net.to(self.device)
        print('Number of learnable parameters: %d' % numParams(net))
        print(net)

        criterion = mse_loss()
        criterion1 = stftm_loss(device=self.device)
        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
        # Learning rate per epoch (indexed by epoch below): 15 entries, so
        # max_epoch must not exceed len(self.lr_list).
        self.lr_list = [0.0002] * 3 + [0.0001] * 6 + [0.00005] * 3 + [0.00001] * 3
        if self.resume_model:
            print('Resume model from "%s"' % self.resume_model)
            checkpoint = Checkpoint()
            checkpoint.load(self.resume_model)
            start_epoch = checkpoint.start_epoch
            start_iter = checkpoint.start_iter
            best_loss = checkpoint.best_loss
            net.load_state_dict(checkpoint.state_dict)
            optimizer.load_state_dict(checkpoint.optimizer)
        else:
            print('Training from scratch.')
            start_epoch = 0
            start_iter = 0
            best_loss = np.inf

        num_train_batches = self.num_train_sentences // self.batch_size
        total_train_batch = self.max_epoch * num_train_batches
        print('num_train_sentences', self.num_train_sentences)
        print('batches_per_epoch', num_train_batches)
        print('total_train_batch', total_train_batch)
        print('batch_size', self.batch_size)
        print('model_name', self.model_name)
        counter = int(start_epoch * num_train_batches + start_iter)
        counter1 = 0
        print('counter', counter)
        ttime = 0.
        cnt = 0.
        iteration = 0
        print('best_loss', best_loss)

        # Checkpoint file names, defined before the loop so the end-of-epoch
        # save works even when no eval step happened during the epoch.
        model_name = self.model_name + '_latest.model'
        best_model = self.model_name + '_best.model'
        # The filterbank extractor is constructed with fixed arguments, so
        # build it once rather than once per batch.
        feature_maker = Fbank(sample_rate=16000, n_fft=400, n_mels=40)

        for epoch in range(start_epoch, self.max_epoch):
            accu_train_loss = 0.0
            net.train()
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.lr_list[epoch]

            start = timeit.default_timer()
            for i, (features, labels, nframes, feat_size, label_size, get_filename) in enumerate(
                    train_loader):  # features:torch.Size([4, 1, 250, 512])
                iteration += 1
                # Offset so iteration numbers continue after a resume.
                i += start_iter
                features, labels = features.to(self.device), labels.to(self.device)  # torch.Size([4, 1, 250, 512])

                loss_mask = compLossMask(labels, nframes=nframes)

                # forward + backward + optimize
                optimizer.zero_grad()

                outputs = net(features)  # torch.Size([4, 1, 64256])

                # Filterbank loss over the utterances named in this batch.
                # NOTE(review): the estimate is written to a wav file and read
                # back from disk, so loss_fbank is not connected to net's
                # autograd graph -- it changes the reported loss but adds no
                # gradient.  Confirm this is intended.
                loss_fbank = 0
                for t, filename in enumerate(get_filename):
                    # The context manager closes the HDF5 file once the raw
                    # arrays have been copied out.
                    with h5py.File(filename, 'r') as reader:
                        feature_asr = reader['noisy_raw'][:]
                        label_asr = reader['clean_raw'][:]

                    feat_asr_size = int(feat_size[t][0].item())

                    output_asr = self.train_asr_forward(feature_asr, net)
                    est_output_asr = output_asr[:feat_asr_size]
                    ideal_labels_asr = label_asr

                    # Save the training wavs (the same file names are reused
                    # for every batch).
                    est_path = os.path.join(self.train_wav_path, '{}_est.wav'.format(t + 1))
                    ideal_path = os.path.join(self.train_wav_path, '{}_ideal.wav'.format(t + 1))
                    sf.write(est_path, normalize_wav(est_output_asr)[0], self.srate)
                    sf.write(ideal_path, normalize_wav(ideal_labels_asr)[0], self.srate)

                    # read wav
                    est_sig = sb.dataio.dataio.read_audio(est_path).unsqueeze(axis=0).to(self.device)
                    ideal_sig = sb.dataio.dataio.read_audio(ideal_path).unsqueeze(axis=0).to(self.device)
                    est_sig_feats = feature_maker(est_sig)
                    ideal_sig_feats = feature_maker(ideal_sig)

                    # fbank_loss: mean squared error between the filterbank
                    # features (reduction='mean' replaces the deprecated
                    # positional size_average=True).
                    loss_fbank += F.mse_loss(est_sig_feats, ideal_sig_feats, reduction='mean')

                loss_fbank /= 100 * len(get_filename)

                outputs = outputs[:, :, :labels.shape[-1]]

                loss1 = criterion(outputs, labels, loss_mask, nframes)
                loss2 = criterion1(outputs, labels, loss_mask, nframes)

                loss = 0.4 * loss1 + 0.1 * loss2 + 0.5 * loss_fbank

                loss.backward()
                optimizer.step()
                # calculate losses
                running_loss = loss.data.item()
                accu_train_loss += running_loss

                # train-loss show
                summary.add_scalar('Train Loss', accu_train_loss, iteration)

                cnt += 1.
                counter += 1
                counter1 += 1

                del loss, loss_fbank, loss1, loss2, outputs, loss_mask, features, labels
                end = timeit.default_timer()
                curr_time = end - start
                ttime += curr_time
                mtime = ttime / counter1
                print(
                    'iter = {}/{}, epoch = {}/{}, loss = {:.5f}, time/batch = {:.5f}, mtime/batch = {:.5f}'.format(
                        i + 1,
                        num_train_batches, epoch + 1, self.max_epoch, running_loss, curr_time, mtime))
                start = timeit.default_timer()
                if (i + 1) % self.eval_steps == 0:
                    start = timeit.default_timer()

                    avg_train_loss = accu_train_loss / cnt

                    avg_eval_loss = self.validate(net, eval_loader, iteration)

                    net.train()

                    print('Epoch [%d/%d], Iter [%d/%d]  ( TrainLoss: %.4f | EvalLoss: %.4f )' % (
                        epoch + 1, self.max_epoch, i + 1, self.num_train_sentences // self.batch_size,
                        avg_train_loss,
                        avg_eval_loss))

                    is_best = avg_eval_loss < best_loss
                    best_loss = avg_eval_loss if is_best else best_loss

                    checkpoint = Checkpoint(epoch, i, avg_train_loss, avg_eval_loss, best_loss, net.state_dict(),
                                            optimizer.state_dict())

                    checkpoint.save(is_best, os.path.join(self.model_path, model_name),
                                    os.path.join(self.model_path, best_model))

                    logging(self.log_path, self.model_name + '_loss_log.txt', checkpoint, self.eval_steps)
                    accu_train_loss = 0.0
                    cnt = 0.

                    net.train()
                if (i + 1) % num_train_batches == 0:
                    break

            # End-of-epoch evaluation and checkpoint.  This belongs inside the
            # epoch loop: it reports and saves per epoch, and start_iter must
            # be reset before the next epoch begins.
            avg_st, avg_sn, avg_pe = self.validate_with_metrics(net, eval_loader)
            net.train()
            print('#' * 50)
            print('')
            print('After {} epoch the performance on validation score is a s follows:'.format(epoch + 1))
            print('')
            print('STOI: {:.4f}'.format(avg_st))
            print('SNR: {:.4f}'.format(avg_sn))
            print('PESQ: {:.4f}'.format(avg_pe))
            for param_group in optimizer.param_groups:
                print('learning_rate', param_group['lr'])
            print('')
            print('#' * 50)
            checkpoint = Checkpoint(epoch, 0, None, None, best_loss, net.state_dict(), optimizer.state_dict())
            checkpoint.save(False, os.path.join(self.model_path, self.model_name + '-{}.model'.format(epoch + 1)),
                            os.path.join(self.model_path, best_model))
            metric_logging(self.log_path, self.model_name + '_metric_log.txt', epoch, [avg_st, avg_sn, avg_pe])
            # Only the first (possibly resumed) epoch is offset by start_iter.
            start_iter = 0
Exemple #4
0
                      )

            _, y_pred = torch.max(output.data, 1)
            total += y_test.size(0)
            correct += (y_pred == y_test).sum().item()
            compute_confusion_matrix(cm, y_test, y_pred)

        accuracy = 100 * correct / total
        print(f"Accuracy: {accuracy}%")
        print(f"Confusion matrix:\n {cm}")

        # writer.add_scalar("test_accuracy", accuracy, epoch * batch_idx + batch_idx)
        # writer.add_scalar("on_epoch_test_loss", total_loss, epoch * batch_idx + batch_idx)

        return accuracy, cm


if __name__ == '__main__':
    # Train for 100 epochs, testing after each.  Whenever test accuracy
    # improves, save the weights and a confusion-matrix image tagged with the
    # epoch and accuracy (./weights and ./images/test must already exist).
    best_accuracy = 0.0
    for epoch in range(100):
        train(epoch)
        epoch_accuracy, epoch_cm = test()
        if best_accuracy < epoch_accuracy:
            best_accuracy = epoch_accuracy
            torch.save(net.state_dict(), f"./weights/weights_epoch{epoch}_accuracy{best_accuracy}.pth")
            plt.matshow(epoch_cm)
            plt.colorbar()
            plt.savefig(f"./images/test/cm_epoch{epoch}_accuracy{best_accuracy}.png")
            # matshow opens a new figure on every improvement; close it so
            # figures do not accumulate over the run.
            plt.close()
    # Final evaluation of the last epoch's weights (not necessarily the best).
    test()
Exemple #5
0
# Train with early stopping on validation loss, keeping the best weights in
# 'latest.pth', then reload those weights and evaluate them on the test set.
min_loss = args.min_loss
# Number of consecutive epochs without improvement.  It must be initialised
# here: otherwise the first non-improving epoch raises NameError at
# `patience += 1`.
patience = 0
for epoch in range(args.epochs):
    model.train()
    for data in train_loader:
        data = data.to(args.device)
        # Clear stale gradients before this batch's backward pass.
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out, data.y)
        print("Training loss:{}".format(loss.item()))
        loss.backward()
        optimizer.step()
    val_acc, val_loss = test(model, val_loader)
    print("Validation loss:{}\taccuracy:{}".format(val_loss, val_acc))
    print("Epoch{}".format(epoch))
    if val_loss < min_loss:
        torch.save(model.state_dict(), 'latest.pth')
        print("Model saved at epoch{}".format(epoch))
        min_loss = val_loss
        patience = 0
    else:
        patience += 1
    if patience > args.patience:
        break

#test step
# Note: torch.load raises if no epoch ever beat args.min_loss (nothing saved).
model = Net(args).to(args.device)
model.load_state_dict(torch.load('latest.pth', map_location=args.device))
test_acc, test_loss = test(model, test_loader)
print("Test accuarcy:{}".format(test_acc))
save_result(test_acc, args.save_path)
Exemple #6
0
    def train(self, args):
        """Train ``Net`` with a weighted MSE + STFT-magnitude loss.

        Copies the run configuration from ``args`` onto ``self``, builds the
        training / evaluation data loaders, optionally resumes from
        ``args.resume_model`` and trains from ``start_epoch`` up to
        ``self.max_epoch``.  Every ``self.eval_steps`` iterations the model is
        validated and ``<model_name>_latest.model`` / ``<model_name>_best.model``
        checkpoints are written.  At the end of every epoch STOI / SNR / PESQ
        are computed and ``<model_name>-<epoch>.model`` is saved.

        Besides ``args`` this reads ``self.frame_size``, ``self.frame_shift``,
        ``self.num_test_sentences``, ``self.model_name``, ``self.device`` and
        ``self.width``, which must already be set.

        :param args: parsed options providing train_list, eval_file,
            num_train_sentences, batch_size, lr, max_epoch, model_path,
            log_path, fig_path, eval_plot_num, eval_steps, resume_model,
            wav_path and tool_path.
        """
        with open(args.train_list, 'r') as train_list_file:
            self.train_list = [
                line.strip() for line in train_list_file.readlines()
            ]
        self.eval_file = args.eval_file
        self.num_train_sentences = args.num_train_sentences
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.max_epoch = args.max_epoch
        self.model_path = args.model_path
        self.log_path = args.log_path
        self.fig_path = args.fig_path
        self.eval_plot_num = args.eval_plot_num
        self.eval_steps = args.eval_steps
        self.resume_model = args.resume_model
        self.wav_path = args.wav_path
        self.tool_path = args.tool_path

        # create a training dataset and an evaluation dataset
        trainSet = TrainingDataset(self.train_list,
                                   frame_size=self.frame_size,
                                   frame_shift=self.frame_shift)
        evalSet = EvalDataset(self.eval_file, self.num_test_sentences)
        # create data loaders for training and evaluation
        train_loader = DataLoader(trainSet,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=16,
                                  collate_fn=TrainCollate())
        eval_loader = DataLoader(evalSet,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=EvalCollate())

        # create a network
        print('model', self.model_name)
        net = Net(device=self.device, L=self.frame_size, width=self.width)
        net.to(self.device)
        print('Number of learnable parameters: %d' % numParams(net))
        print(net)

        criterion = mse_loss()
        criterion1 = stftm_loss(device=self.device)
        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
        # Learning rate per epoch (indexed by epoch below): 15 entries, so
        # max_epoch must not exceed len(self.lr_list).
        self.lr_list = ([0.0002] * 3 + [0.0001] * 6 + [0.00005] * 3 +
                        [0.00001] * 3)
        if self.resume_model:
            print('Resume model from "%s"' % self.resume_model)
            checkpoint = Checkpoint()
            checkpoint.load(self.resume_model)
            start_epoch = checkpoint.start_epoch
            start_iter = checkpoint.start_iter
            best_loss = checkpoint.best_loss
            net.load_state_dict(checkpoint.state_dict)
            optimizer.load_state_dict(checkpoint.optimizer)
        else:
            print('Training from scratch.')
            start_epoch = 0
            start_iter = 0
            best_loss = np.inf

        num_train_batches = self.num_train_sentences // self.batch_size
        total_train_batch = self.max_epoch * num_train_batches
        print('num_train_sentences', self.num_train_sentences)
        print('batches_per_epoch', num_train_batches)
        print('total_train_batch', total_train_batch)
        print('batch_size', self.batch_size)
        print('model_name', self.model_name)
        counter = int(start_epoch * num_train_batches + start_iter)
        counter1 = 0
        print('counter', counter)
        ttime = 0.
        cnt = 0.
        print('best_loss', best_loss)

        # Checkpoint file names, defined before the loop so the end-of-epoch
        # save works even when no eval step happened during the epoch
        # (e.g. eval_steps > batches per epoch).
        model_name = self.model_name + '_latest.model'
        best_model = self.model_name + '_best.model'

        for epoch in range(start_epoch, self.max_epoch):
            accu_train_loss = 0.0
            net.train()
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.lr_list[epoch]

            start = timeit.default_timer()
            for i, (features, labels, nframes) in enumerate(train_loader):
                # Offset so iteration numbers continue after a resume.
                i += start_iter
                features, labels = features.to(self.device), labels.to(
                    self.device)

                loss_mask = compLossMask(labels, nframes=nframes)

                # forward + backward + optimize
                optimizer.zero_grad()

                outputs = net(features)
                outputs = outputs[:, :, :labels.shape[-1]]

                loss1 = criterion(outputs, labels, loss_mask, nframes)
                loss2 = criterion1(outputs, labels, loss_mask, nframes)

                loss = 0.8 * loss1 + 0.2 * loss2
                loss.backward()
                optimizer.step()
                # calculate losses
                running_loss = loss.data.item()
                accu_train_loss += running_loss

                cnt += 1.
                counter += 1
                counter1 += 1

                del loss, loss1, loss2, outputs, loss_mask, features, labels
                end = timeit.default_timer()
                curr_time = end - start
                ttime += curr_time
                mtime = ttime / counter1
                print(
                    'iter = {}/{}, epoch = {}/{}, loss = {:.5f}, time/batch = {:.5f}, mtime/batch = {:.5f}'
                    .format(i + 1, num_train_batches, epoch + 1,
                            self.max_epoch, running_loss, curr_time, mtime))
                start = timeit.default_timer()
                if (i + 1) % self.eval_steps == 0:
                    start = timeit.default_timer()

                    avg_train_loss = accu_train_loss / cnt

                    avg_eval_loss = self.validate(net, eval_loader)

                    net.train()

                    print(
                        'Epoch [%d/%d], Iter [%d/%d]  ( TrainLoss: %.4f | EvalLoss: %.4f )'
                        % (epoch + 1, self.max_epoch, i + 1,
                           self.num_train_sentences // self.batch_size,
                           avg_train_loss, avg_eval_loss))

                    is_best = avg_eval_loss < best_loss
                    best_loss = avg_eval_loss if is_best else best_loss

                    checkpoint = Checkpoint(epoch, i, avg_train_loss,
                                            avg_eval_loss, best_loss,
                                            net.state_dict(),
                                            optimizer.state_dict())

                    checkpoint.save(is_best,
                                    os.path.join(self.model_path, model_name),
                                    os.path.join(self.model_path, best_model))

                    logging(self.log_path, self.model_name + '_loss_log.txt',
                            checkpoint, self.eval_steps)
                    accu_train_loss = 0.0
                    cnt = 0.

                    net.train()
                if (i + 1) % num_train_batches == 0:
                    break

            # End-of-epoch evaluation with perceptual metrics + epoch checkpoint.
            avg_st, avg_sn, avg_pe = self.validate_with_metrics(
                net, eval_loader)
            net.train()
            print('#' * 50)
            print('')
            print(
                'After {} epoch the performance on validation score is a s follows:'
                .format(epoch + 1))
            print('')
            print('STOI: {:.4f}'.format(avg_st))
            print('SNR: {:.4f}'.format(avg_sn))
            print('PESQ: {:.4f}'.format(avg_pe))
            for param_group in optimizer.param_groups:
                print('learning_rate', param_group['lr'])
            print('')
            print('#' * 50)
            checkpoint = Checkpoint(epoch, 0, None, None, best_loss,
                                    net.state_dict(), optimizer.state_dict())
            checkpoint.save(
                False,
                os.path.join(self.model_path,
                             self.model_name + '-{}.model'.format(epoch + 1)),
                os.path.join(self.model_path, best_model))
            metric_logging(self.log_path, self.model_name + '_metric_log.txt',
                           epoch, [avg_st, avg_sn, avg_pe])
            # Only the first (possibly resumed) epoch is offset by start_iter;
            # keep it an int so iteration indices stay integral.
            start_iter = 0