示例#1
0
def train(ini_file):
    """Train the model described by an .ini configuration file.

    :param ini_file: (str) path of the .ini configuration file
    :return: tuple ``(best_acc, _best_acc)`` — the best validation accuracy
             and the training accuracy at that same epoch
    """
    # read the run configuration from the .ini file
    config = read_config(ini_file)
    # build network / criterion / optimizer from the configuration
    model = Net(config['network']).to(device)
    criterion = Criterion(config['network'], device).to(device)
    # resolve the optimizer class by name instead of eval():
    # eval() on config text is a code-injection risk and hides typos
    optimizer = getattr(optim, config['train']['optimizer'])(
        model.parameters(), lr=config['train']['learning_rate'])
    # build the data loaders; each split is served as ONE full batch,
    # so the per-epoch loops below run exactly once per epoch
    train_dataset = MakeDataset(config['train']['h5_file'], is_train=True, device=device)
    test_dataset = MakeDataset(config['train']['h5_file'], is_train=False, device=device)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=len(train_dataset))
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=len(test_dataset))
    # training state: a checkpoint is saved only when BOTH the validation
    # accuracy and the training accuracy exceed their current bests
    _best_acc = 0.70      # best training accuracy so far (initial threshold)
    best_acc = 0.65       # best validation accuracy so far (initial threshold)
    best_ep = 0           # epoch of the best checkpoint
    flag = 0              # epochs without improvement (early-stopping counter)
    _best_auc = 0         # training AUC at the best epoch
    best_auc = 0          # validation AUC at the best epoch
    best_roc = None       # validation ROC curve at the best epoch
    for epoch in range(1, config['train']['epochs'] + 1):
        # adjust the learning rate for this epoch
        lr = adjust_learning_rate(optimizer, epoch,
                                  config['train']['learning_rate'],
                                  config['train']['lr_decay_rate'])
        # train step (single full-batch iteration)
        model.train()
        for X, y in train_loader:
            # makes predictions
            pred = model(X)
            train_loss = criterion(pred, y, model)
            train_FPR, train_TPR, train_ACC, train_roc, train_roc_auc, _, _, _, _ = Auc(pred, y)
            # updates parameters
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
        # valid step (single full-batch iteration)
        model.eval()
        for X, y in test_loader:
            with torch.no_grad():
                pred = model(X)
                valid_loss = criterion(pred, y, model)
                valid_FPR, valid_TPR, valid_ACC, valid_roc, valid_roc_auc, _, _, _, _ = Auc(pred, y)
                if valid_ACC > best_acc and train_ACC > _best_acc:
                    flag = 0
                    best_acc = valid_ACC
                    _best_acc = train_ACC
                    best_ep = epoch
                    best_auc = valid_roc_auc
                    _best_auc = train_roc_auc
                    best_roc = valid_roc
                    # save the best model; os.path.basename() is portable,
                    # unlike splitting on a hard-coded Windows separator
                    torch.save({
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch}, os.path.join(models_dir, os.path.basename(ini_file) + '.pth'))
                else:
                    flag += 1
                    # early stopping after `patience` epochs without improvement
                    if flag >= patience:
                        print('epoch: {}\t{:.8f}({:.8f})'.format(best_ep, _best_acc, best_acc))
                        if best_roc is not None:
                            plt.plot(best_roc[:, 0], best_roc[:, 1])
                            plt.title('ep:{}  AUC: {:.4f}({:.4f}) ACC: {:.4f}({:.4f})'.format(best_ep, _best_auc, best_auc, _best_acc, best_acc))
                            plt.show()
                        return best_acc, _best_acc
        # note: train loader and valid loader both hold exactly one batch,
        # so the train_*/valid_* values below are whole-epoch metrics
        print('\rEpoch: {}\tLoss: {:.8f}({:.8f})\tACC: {:.8f}({:.8f})\tAUC: {}({})\tFPR: {:.8f}({:.8f})\tTPR: {:.8f}({:.8f})\tlr: {:g}\n'.format(
            epoch, train_loss.item(), valid_loss.item(), train_ACC, valid_ACC, train_roc_auc, valid_roc_auc, train_FPR, valid_FPR, train_TPR, valid_TPR, lr), end='', flush=False)
    return best_acc, _best_acc
示例#2
0
    def test(self, args):
        """Run the trained model over the test list and write predictions.

        For every entry in ``args.test_list`` this reads the corresponding
        mixture HDF5 file, runs the network on each utterance, and stores
        the ideal and estimated signals as ``*_s_ideal.dat`` /
        ``*_s_est.dat`` HDF5 files under ``args.prediction_path``.

        :param args: namespace with ``test_list``, ``model_name``,
                     ``model_file``, ``test_mixture_path`` and
                     ``prediction_path`` attributes.
        """
        with open(args.test_list, 'r') as test_list_file:
            self.test_list = [line.strip() for line in test_list_file.readlines()]
        self.model_name = args.model_name
        self.model_file = args.model_file
        self.test_mixture_path = args.test_mixture_path
        self.prediction_path = args.prediction_path

        # create the network and load the trained weights
        print('model', self.model_name)
        net = Net(device=self.device, L=self.frame_size, width=self.width)
        net.to(self.device)
        print('Number of learnable parameters: %d' % numParams(net))
        print(net)
        net.eval()
        print('Load model from "%s"' % self.model_file)
        checkpoint = Checkpoint()
        checkpoint.load(self.model_file)
        net.load_state_dict(checkpoint.state_dict)
        with torch.no_grad():
            for i in range(len(self.test_list)):
                # read the mixture for resynthesis
                filename_input = self.test_list[i].split('/')[-1]
                start1 = timeit.default_timer()
                print('{}/{}, Started working on {}.'.format(i + 1, len(self.test_list), self.test_list[i]))
                print('')
                filename_mix = filename_input.replace('.samp', '_mix.dat')
                filename_s_ideal = filename_input.replace('.samp', '_s_ideal.dat')
                filename_s_est = filename_input.replace('.samp', '_s_est.dat')
                # context managers guarantee the HDF5 handles are closed even
                # when an utterance raises (the originals leaked on error)
                with h5py.File(os.path.join(self.test_mixture_path, filename_mix), 'r') as f_mix, \
                        h5py.File(os.path.join(self.prediction_path, filename_s_ideal), 'w') as f_s_ideal, \
                        h5py.File(os.path.join(self.prediction_path, filename_s_est), 'w') as f_s_est:
                    # a test dataset over the utterances of this file
                    testSet = EvalDataset(os.path.join(self.test_mixture_path, self.test_list[i]),
                                          self.num_test_sentences)
                    # data loader for test: one utterance per batch, in order
                    test_loader = DataLoader(testSet,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             collate_fn=EvalCollate())

                    accu_test_loss = 0.0  # sum of per-utterance MSE losses
                    ttime = 0.            # total inference time over utterances
                    mtime = 0.            # mean inference time per utterance
                    cnt = 0.              # number of utterances processed
                    for k, (mix_raw, cln_raw) in enumerate(test_loader):
                        start = timeit.default_timer()
                        est_s = self.eval_forward(mix_raw, net)
                        # trim the estimate to the mixture's length
                        est_s = est_s[:mix_raw.size]
                        # NOTE(review): 'mix' is read but never used below —
                        # kept so a missing key still fails loudly; confirm.
                        mix = f_mix[str(k)][:]

                        ideal_s = cln_raw

                        f_s_ideal.create_dataset(str(k), data=ideal_s.astype(np.float32), chunks=True)
                        f_s_est.create_dataset(str(k), data=est_s.astype(np.float32), chunks=True)

                        # evaluation loss: MSE between estimate and ideal signal
                        test_loss = np.mean((est_s - ideal_s) ** 2)
                        accu_test_loss += test_loss
                        cnt += 1
                        end = timeit.default_timer()
                        curr_time = end - start
                        ttime += curr_time
                        # plain running mean; the original refolded the current
                        # sample a second time, double-weighting it
                        mtime = ttime / cnt
                        print('{}/{}, test_loss = {:.4f}, time/utterance = {:.4f}, '
                              'mtime/utternace = {:.4f}'.format(k + 1, self.num_test_sentences, test_loss, curr_time,
                                                                mtime))

                    # NOTE(review): avg_test_loss is computed but not reported;
                    # it fed a progress bar that has been removed — confirm.
                    avg_test_loss = accu_test_loss / cnt
                end1 = timeit.default_timer()
                print('********** Finisehe working on {}. time taken = {:.4f} **********'.format(filename_input,
                                                                                                 end1 - start1))
                print('')