コード例 #1
0
class SampleDataset(Dataset):
    """Dataset that turns tab-separated rows into image tensors with two extra feature planes.

    Every data row of the input file (the first line is skipped as a header)
    is parsed as:

    * column 0     -- integer label
    * column 1     -- not read
    * columns 2-3  -- two feature strings: a stage identifier containing a
      number, and a numeric value called ``degc``
    * columns 4+   -- float values handed to ``Spectrogram.spec_array``

    Args:
        file_pth: Path of the tab-separated text file to load.
        train: Accepted for interface compatibility; it does not select a
            split here, because every row of the file is loaded.
        random_state: Stored on the instance as ``self.random_state``; not
            otherwise used by this class.
    """

    # Height/width of the constant feature planes stacked onto the spectrogram.
    # NOTE(review): presumably equals the spatial size of the array returned by
    # ``Spectrogram.spec_array`` -- confirm against that helper.
    LAYER_SHAPE = (224, 224)

    def __init__(self, file_pth: str, train=True, random_state=None):
        super().__init__()

        self.sp = Spectrogram()

        self.file_pth = file_pth
        # Keep the caller's value instead of unconditionally discarding it.
        self.random_state = random_state

        # ``self.train`` holds the parsed rows (not the boolean flag):
        # a list of ``[freq, label, features]`` entries.
        self.train = self.read_data()

    def __getitem__(self, index):
        """Return ``(img, target)`` for the row at ``index``.

        ``img`` is the spectrogram array with the stage plane and the degc
        plane appended along axis 0, as a ``float32`` tensor. ``target`` is the
        row's integer label as a tensor.
        """
        freq, label, features = self.train[index]

        stage_layer, degc_layer = self.feature_layer(features)
        spectrogram = self.sp.spec_array(freq)
        # Concatenating on axis 0 requires the spectrogram to be 3-D with
        # trailing dimensions equal to LAYER_SHAPE.
        stacked = np.concatenate((spectrogram, stage_layer, degc_layer), axis=0)

        img = torch.tensor(stacked, dtype=torch.float32)
        target = torch.tensor(label)
        return img, target

    def __len__(self):
        """Return the number of loaded rows."""
        return len(self.train)

    def feature_layer(self, features):
        """Build two constant-valued planes from a row's feature strings.

        Args:
            features: Sequence whose first item is a string containing the
                stage number (the first run of digits is used) and whose
                second item is parseable as a float.

        Returns:
            Tuple ``(stage_layer, degc_layer)`` of arrays shaped
            ``(1,) + LAYER_SHAPE``; the first is ``int8``, the second
            ``float64``.

        Raises:
            ValueError: If the first feature contains no digits or the second
                is not a valid float.
        """
        stage_match = re.search(r"\d+", features[0])
        if stage_match is None:
            # Raise ValueError rather than letting an IndexError escape: an
            # IndexError from __getitem__ would silently end ``for`` iteration
            # over the dataset.
            raise ValueError(
                "no stage number found in {!r}".format(features[0])
            )
        stage = int(stage_match.group())
        degc = float(features[1])

        plane_shape = (1,) + self.LAYER_SHAPE
        # NOTE(review): int8 cannot represent stage numbers above 127.
        stage_layer = np.full(plane_shape, stage, dtype=np.int8)
        # ``np.float`` was removed in NumPy 1.24; float64 is what it aliased.
        degc_layer = np.full(plane_shape, degc, dtype=np.float64)

        return stage_layer, degc_layer

    def read_data(self):
        """Parse ``self.file_pth`` into a list of ``[freq, label, features]`` rows.

        The first line is treated as a header and skipped. Lines containing
        only whitespace are ignored.

        Raises:
            ValueError: If a label or frequency field is not numeric.
        """
        rows = []
        with open(self.file_pth, "r") as f:
            next(f, None)  # skip the header line (if any)
            for line in f:
                stripped = line.strip()
                if not stripped:
                    # A blank line (e.g. a trailing newline) is not a record.
                    continue
                fields = stripped.split('\t')
                label = int(fields[0])
                features = fields[2:4]
                freq = [float(value) for value in fields[4:]]

                rows.append([freq, label, features])

        return rows
コード例 #2
0
ファイル: main.py プロジェクト: wsy8029/nvh
class TestDataset(Dataset):
    """Dataset of spectrogram images built from one split of a tab-separated file.

    Every data row of the input file (the first line is skipped as a header)
    is parsed as an integer label in column 0 followed by float values in the
    remaining columns. The rows are divided with ``train_test_split`` and this
    instance serves either the first or the second part.

    Args:
        file_pth: Path of the tab-separated text file to load.
        train: If true, serve the first split returned by
            ``train_test_split``; otherwise serve the second one.
        random_state: Forwarded to ``train_test_split``. Pass the same value
            to the ``train=True`` and ``train=False`` instances; otherwise each
            instance shuffles independently and the two splits can overlap.
    """

    def __init__(self, file_pth: str, train=True, random_state=None):
        super().__init__()

        self.sp = Spectrogram()

        self.file_pth = file_pth
        # Keep the caller's value instead of unconditionally discarding it.
        self.random_state = random_state

        self.data = self.read_data()
        # NOTE(review): assumes ``train_test_split`` is scikit-learn's, which
        # accepts ``random_state`` -- confirm against the file's imports.
        datasets = train_test_split(self.data, random_state=random_state)
        # ``self.train`` holds the rows served by this instance; despite the
        # name it is the *test* part when ``train`` is false.
        self.train = datasets[0] if train else datasets[1]

    def __getitem__(self, index):
        """Return ``(img, target)`` for the row at ``index``.

        ``img`` is the output of ``Spectrogram.spec_array`` as a ``float32``
        tensor; ``target`` is the row's integer label as a tensor.
        """
        freq, label = self.train[index]
        img = torch.tensor(self.sp.spec_array(freq), dtype=torch.float32)
        target = torch.tensor(label)
        return img, target

    def __len__(self):
        """Return the number of rows in the split served by this instance."""
        return len(self.train)

    def read_data(self):
        """Parse ``self.file_pth`` into a list of ``[freq, label]`` rows.

        The first line is treated as a header and skipped. Lines containing
        only whitespace are ignored.

        Raises:
            ValueError: If a label or frequency field is not numeric.
        """
        rows = []

        with open(self.file_pth, "r") as f:
            next(f, None)  # skip the header line (if any)
            for line in f:
                stripped = line.strip()
                if not stripped:
                    # A blank line (e.g. a trailing newline) is not a record.
                    continue
                fields = stripped.split('\t')
                label = int(fields[0])
                freq = [float(value) for value in fields[1:]]

                rows.append([freq, label])

        return rows