-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
134 lines (102 loc) · 4.38 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#%%
import torch.utils.data as data
import os, os.path
import re
from PIL import Image
import numpy as np
import torch
import json
from config import *
from alphabet import Alphabet
class LipsDataset(data.Dataset):
    """Dataset of lip-region video frames, one sample per word directory.

    Each leaf directory under ``frame_dir`` holds the greyscale 120x120
    frames of one spoken word plus a single JSON subtitle file whose name
    starts with ``__``.  ``__getitem__`` returns ``(frames, targets,
    is_valid)``: ``targets`` is a LongTensor of alphabet indices wrapped in
    <sos>/<eos>, and ``is_valid`` flags samples that have at least
    COUNT_FRAMES frames and a fully encodable word.
    """
    def __init__(self, frame_dir):
        self.frame_dir = frame_dir
        self.alphabet = Alphabet()
        # Walk the tree end-to-end and collect every leaf directory (one with
        # no sub-directories); each leaf holds the frames of a single word.
        self.words = []
        for root, dirs, files in os.walk(self.frame_dir):
            if not dirs:
                self.words.append(root)
        self.count = 0  # legacy attribute, kept for compatibility (unused)
    def __len__(self):
        return len(self.words)
    def __getitem__(self, index):
        """Load all frames and the subtitle word for sample ``index``."""
        curr_dir = self.words[index]
        # Frame files are everything not prefixed with '__'.
        # BUG FIX: sort the listing -- os.listdir returns entries in arbitrary
        # filesystem order, and the frames must be stacked in temporal order.
        frames_list = sorted(name for name in os.listdir(curr_dir)
                             if not re.match(r'__', name))
        # The sample is only usable when there are enough frames for at least
        # one COUNT_FRAMES-long window.
        is_valid = len(frames_list) >= COUNT_FRAMES
        frames = np.zeros((len(frames_list), 120, 120))
        for count, frame_name in enumerate(frames_list):
            # Open via a context manager so the file handle is released
            # promptly (a bare Image.open leaks descriptors, which matters
            # with num_workers > 0); np.asarray reads the pixel grid directly
            # instead of the slow getdata()/reshape round-trip.
            with Image.open(os.path.join(curr_dir, frame_name)) as img:
                frames[count] = np.asarray(img.convert(mode='L'))
        frames = torch.from_numpy(frames)
        # Slice the clip into overlapping COUNT_FRAMES-long windows.
        if is_valid:
            frames = make_batches(frames)
        # Load the subtitles: the single '__'-prefixed JSON file in the dir.
        subs_path = [name for name in os.listdir(curr_dir) if re.match(r'__', name)][0]
        with open(os.path.join(curr_dir, subs_path), 'r') as subs_file:
            subs = str(json.loads(subs_file.read())['word']).lower()
        characters = list()
        characters.append(self.alphabet.ch2index('<sos>'))
        for ch in subs:
            if self.alphabet.ch2index(ch) is None:
                # Word contains a character outside the alphabet -> unusable.
                is_valid = False
                break
            characters.append(self.alphabet.ch2index(ch))
        characters.append(self.alphabet.ch2index('<eos>'))
        targets = torch.LongTensor(characters)
        return frames, targets, is_valid
def collate_fn(data):
    """Collate ``(frames, targets, is_valid)`` samples into padded batches.

    BUG FIX: the ``is_valid`` flag produced by LipsDataset was previously
    unpacked and then ignored, so invalid samples -- whose frames were never
    windowed by make_batches and therefore have a different rank -- crashed
    the padded assignment below.  Invalid samples are now dropped first.

    Returns:
        batch_frames: LongTensor (B, max_frames_len, COUNT_FRAMES, 120, 120),
            zero-padded along the frames dimension.
        batch_targets: LongTensor (B, max_targets_len), zero-padded.
    """
    valid = [(frames, targets) for frames, targets, is_valid in data if is_valid]
    if not valid:
        # Every sample in the batch was invalid: return empty tensors so the
        # caller can detect and skip the batch.
        empty = torch.zeros(0, 0).long()
        return empty, empty.clone()
    frames, targets = zip(*valid)
    targets_lengths = [len(target) for target in targets]
    batch_targets = torch.zeros(len(targets), max(targets_lengths)).long()
    for i, target in enumerate(targets):
        end = targets_lengths[i]
        batch_targets[i, :end] = target[:end]
    frames_lengths = [frame.shape[0] for frame in frames]
    # Trailing dims (COUNT_FRAMES, 120, 120) are inferred from the data
    # rather than hard-coded from config.
    # NOTE(review): casting the frame batch to long truncates float pixel
    # values; pixels are integral 0..255 here so nothing is lost, but confirm
    # the model converts back to float downstream.
    batch_frames = torch.zeros(len(frames), max(frames_lengths), *frames[0].shape[1:]).long()
    for i, frame in enumerate(frames):
        end = frames_lengths[i]
        batch_frames[i, :end] = frame[:end]
    return batch_frames, batch_targets
def get_loader(frame_dir):
    """Build the training DataLoader over the lips dataset rooted at frame_dir."""
    dataset = LipsDataset(frame_dir)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=12,
        drop_last=True,
    )
    return loader
def get_loader_evaluate(frame_dir):
    """Build a one-sample-at-a-time DataLoader for evaluation (default collation)."""
    eval_loader = torch.utils.data.DataLoader(
        dataset=LipsDataset(frame_dir),
        num_workers=4,
    )
    return eval_loader
def make_batches(data_tensor, COUNT_FRAMES=None):
    """Slice a frame sequence into overlapping sliding windows.

    Args:
        data_tensor: tensor shaped (num_frames, *frame_shape),
            e.g. (N, 120, 120).
        COUNT_FRAMES: window length.  When None, falls back to the
            module-level config constant (resolved lazily, so passing an
            explicit length works even without the config in scope).

    Returns:
        FloatTensor shaped (num_frames - COUNT_FRAMES + 1, COUNT_FRAMES,
        *frame_shape), where window i holds frames i .. i+COUNT_FRAMES-1.
    """
    if COUNT_FRAMES is None:
        # The parameter shadows the config constant, so fetch it explicitly.
        COUNT_FRAMES = globals()['COUNT_FRAMES']
    new_size = data_tensor.shape[0] - COUNT_FRAMES + 1
    # BUG FIX: the window length was hard-coded as 5 both here and in the
    # slice below, breaking whenever COUNT_FRAMES != 5.  The per-frame shape
    # is now taken from the input instead of being fixed to 120x120.
    new_data_tensor = torch.zeros(new_size, COUNT_FRAMES, *data_tensor.shape[1:])
    for i in range(new_size):
        new_data_tensor[i] = data_tensor[i:i + COUNT_FRAMES]
    return new_data_tensor