예제 #1
0
class CardiogramDataset(Dataset):
    """Dataset that pairs sample files with their labels.

    Features are read on access through
    ``DataHelper.get_features_from_txt``; only the file references and the
    labels are kept on the instance.
    """

    def __init__(self, list_files, labels):
        """Store the sample references and create the feature reader.

        :param list_files: file names of the training samples; the entry at
            a given index is passed to ``get_features_from_txt`` on access.
        :param labels: labels looked up with the same index as
            ``list_files``.
        """
        self.helper = DataHelper()
        self.list_files, self.labels = list_files, labels

    def __getitem__(self, index):
        """Return ``(features, label)`` for ``index`` as float tensors.

        The label is wrapped in a one-element list before being converted.
        """
        raw_features = self.helper.get_features_from_txt(self.list_files[index])
        return (
            torch.FloatTensor(raw_features),
            torch.FloatTensor([self.labels[index]]),
        )

    def __len__(self):
        """Return the number of sample files."""
        return len(self.list_files)
예제 #2
0
    if type(m)==nn.Linear or type(m)==nn.Conv2d:
        torch.nn.init.xavier_normal_(m.weight)

# Select the CUDA device with index 1.
# NOTE(review): the index is hard-coded; confirm the host actually exposes
# that device before relying on this script elsewhere.
device = torch.device('cuda:1')
# Run init_weights over net (presumably an nn.Module, in which case apply()
# also visits its submodules -- confirm), then hand net to the chosen device.
net.apply(init_weights)
net.to(device)

def get_files():
    """List the regular files located directly in the training-data directory.

    :return: entry names (not joined paths) from ``Constant.TRAIN_DATA_PATH``
        for which ``os.path.isfile`` is true, in the order ``os.listdir``
        reports them.
    """
    root = Constant.TRAIN_DATA_PATH
    names = []
    for entry in os.listdir(root):
        # Keep plain files only; anything else in the directory is ignored.
        if os.path.isfile(os.path.join(root, entry)):
            names.append(entry)
    return names

# Fetch the sample file list and the per-sample values for label index 0
# from the project helper.
helper = DataHelper()
files = helper.files
num_labels = helper.get_num_label_i(0)  # per the original note: indicator of whether the patient has arrhythmia #0
# Debug aid (disabled): print each file next to its label-0 value.
# for file, num_label_0 in zip(files, num_labels):
#     print(file+"\t"+str(num_label_0))



# Split into folds; the results are indexed by fold below. Entry 0 is used
# as the validation split.
datas, labels = k_fold(5, files, num_labels)
validation_data = datas[0]
validation_label = labels[0]

# Concatenate the entries at indices 1..4 into the training split.
# NOTE(review): only train_data is extended in the visible part of this
# loop; train_labels stays empty here -- confirm the rest of the loop body.
train_data = []
train_labels = []
for i in range(1, 5):
    train_data += datas[i]
예제 #3
0
class LSTM_RESHAPE(nn.Module):
    """Re-view a tensor so the sizes of dimensions 1 and 2 trade places.

    NOTE(review): ``view`` keeps the underlying element order and only
    changes how it is indexed; it does not move data the way
    ``transpose``/``permute`` would. Confirm that this is the intent.
    """

    def forward(self, x):
        """Return ``x`` viewed as ``(-1, x.shape[2], x.shape[1])``."""
        size_dim1, size_dim2 = x.size(1), x.size(2)
        return x.view(-1, size_dim2, size_dim1)


class LSTM(nn.Module):
    """Module that currently passes its input through untouched.

    ``self.lstm`` holds an ``nn.Sequential`` whose only stage is
    ``LSTM_RESHAPE``, but ``forward`` does not invoke it.
    """

    def __init__(self):
        super().__init__()
        reshape_stage = LSTM_RESHAPE()
        self.lstm = nn.Sequential(reshape_stage)

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x


# Obtain the two iterators returned by get_train_and_validation_iter().
helper = DataHelper()
train_iter, test_iter = helper.get_train_and_validation_iter()
# Keep the X of the first (X, y) pair only: the loop exits after one pass.
# If train_iter yields nothing, x is never assigned.
for X, y in train_iter:
    x = X
    break

# print(x.shape)
# lstm_reshape = LSTM_RESHAPE()
# x = lstm_reshape(x)
# print(x.shape)
#
#
# fcn = FCN()
# x = fcn(x)
# print("x.shape: {}".format(x.shape))
예제 #4
0
 def __init__(self, list_files, labels):
     """Create the ``DataHelper`` and keep the sample references and labels.

     :param list_files: file names of the training samples.
     :param labels: labels stored alongside ``list_files``.
     """
     self.helper = DataHelper()
     self.list_files, self.labels = list_files, labels
예제 #5
0
        for j in range(fold_num):

            r = random.randint(0, len(list_files) - 1)
            tmp_datas.append(list_files[r])
            tmp_labels.append(labels[r])
            del list_files[r]
            del labels[r]
        k_datas.append(tmp_datas)
        k_labels.append(tmp_labels)

    k_datas.append(list_files)
    k_labels.append(labels)
    return k_datas, k_labels


# Pull the sample file names and their labels from the project helper.
helper = DataHelper()
files, num_labels = helper.files, helper.num_labels

# Split into folds; the entry at index 0 becomes the validation split.
datas, labels = k_fold(5, files, num_labels)
validation_data, validation_label = datas[0], labels[0]

# The entries at indices 1..4 are concatenated into the training split.
train_data, train_labels = [], []
for i in range(1, 5):
    train_data.extend(datas[i])
    train_labels.extend(labels[i])

dataset = CardiogramDataset(train_data, train_labels)
# print(len(dataset))