def forward(self, Y, h, c, outEncoder, teacher_force):
    # Y: (number of characters) x 256
    seq_len = Y.shape[0] - 1
    output_decoder = load_to_cuda(Variable(torch.zeros(seq_len, h.shape[1], 48)))
    if np.random.rand() > teacher_force:
        # teacher forcing: feed the ground-truth character at every step
        Y = self.embedding(Y)
        for i in range(len(Y) - 1):  # -1 because <sos> is not scored by the criterion
            h[0], c[0] = self.lstm1(Y[i], (h[0].clone(), c[0].clone()))
            h[1], c[1] = self.lstm2(h[0].clone(), (h[1].clone(), c[1].clone()))
            h[2], c[2] = self.lstm3(h[1].clone(), (h[2].clone(), c[2].clone()))
            h2 = h[2].clone()
            context = self.attention(h2, outEncoder, BATCH_SIZE)
            context = torch.bmm(
                context,
                outEncoder.view(outEncoder.shape[1], outEncoder.shape[0], -1))
            output_decoder[i] = self.MLP(
                torch.cat((h2, torch.squeeze(context, 1)), 1))
    else:
        # free running: feed the model's own previous prediction back in
        alphabet = Alphabet()
        Y_cur = self.embedding(
            load_to_cuda(Variable(torch.LongTensor([alphabet.ch2index('<sos>')])))
        ).view(1, self.hidden_size)
        for i in range(seq_len):  # fill every output step, matching the teacher-forced branch
            Y_cur = Y_cur.expand(BATCH_SIZE, self.hidden_size)
            h[0], c[0] = self.lstm1(Y_cur, (h[0].clone(), c[0].clone()))
            h[1], c[1] = self.lstm2(h[0].clone(), (h[1].clone(), c[1].clone()))
            h[2], c[2] = self.lstm3(h[1].clone(), (h[2].clone(), c[2].clone()))
            h2 = h[2].clone()
            context = self.attention(h2, outEncoder, BATCH_SIZE)
            context = torch.bmm(
                context,
                outEncoder.view(outEncoder.shape[1], outEncoder.shape[0], -1))
            output_decoder[i] = self.MLP(
                torch.cat((h2, torch.squeeze(context, 1)), 1))
            # greedy choice; note: the prediction of the first sample in the
            # batch is fed back to the whole batch
            argmax = torch.max(output_decoder[i][0], dim=0)
            Y_cur = self.embedding(
                load_to_cuda(Variable(torch.LongTensor([argmax[1][0].data[0]])))
            ).view(1, self.hidden_size)
    return output_decoder
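# For context: a minimal sketch of the module setup that forward() and
# evaluate() assume. The attribute names (embedding, lstm1..lstm3, attention,
# MLP) match the calls above, but the sizes (hidden_size=256, a 48-character
# alphabet) and the dot-product Attention module are assumptions for
# illustration, not taken from the original source.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Assumed dot-product attention: returns weights over encoder steps."""

    def forward(self, decoder_state, encoder_outputs, batch_size):
        # decoder_state: (batch, hidden); encoder_outputs: (batch, T, hidden)
        # (the calling code above reshapes outEncoder around this call)
        scores = torch.bmm(encoder_outputs, decoder_state.unsqueeze(2))
        weights = F.softmax(scores.squeeze(2), dim=1)  # (batch, T)
        return weights.unsqueeze(1)                    # (batch, 1, T)


class DecoderSketch(nn.Module):
    """Hypothetical constructor consistent with forward()/evaluate() above."""

    def __init__(self, vocab_size=48, hidden_size=256):
        super(DecoderSketch, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # three manually stepped LSTM cells, one per stacked layer
        self.lstm1 = nn.LSTMCell(hidden_size, hidden_size)
        self.lstm2 = nn.LSTMCell(hidden_size, hidden_size)
        self.lstm3 = nn.LSTMCell(hidden_size, hidden_size)
        self.attention = Attention()
        # maps [decoder state; attention context] to scores over the alphabet
        self.MLP = nn.Linear(2 * hidden_size, vocab_size)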
def evaluate(self, h, c, outEncoder, max_len=-1):
    # greedy decoding for inference; <sos> must not appear in the result
    h = h.view(h.shape[1], h.shape[0], -1).clone()
    c = c.view(c.shape[1], c.shape[0], -1).clone()
    seq_len = 50 if max_len == -1 else max_len  # 50 = default maximum length
    result = load_to_cuda(torch.FloatTensor(seq_len, 1, 48).zero_())
    if len(outEncoder.shape) != 3:
        print("encoderOut has the wrong number of dimensions")
        return result, result[0], False
    alphabet = Alphabet()
    listArgmax = []  # indices of the characters emitted so far
    Y_cur = self.embedding(
        load_to_cuda(Variable(torch.LongTensor([alphabet.ch2index('<sos>')])))
    ).view(1, self.hidden_size)
    for i in range(seq_len - 1):
        h[0], c[0] = self.lstm1(Y_cur, (h[0].clone(), c[0].clone()))
        h[1], c[1] = self.lstm2(h[0].clone(), (h[1].clone(), c[1].clone()))
        h[2], c[2] = self.lstm3(h[1].clone(), (h[2].clone(), c[2].clone()))
        context = self.attention(
            h[2].clone(),
            outEncoder.view(outEncoder.shape[1], outEncoder.shape[0], -1), 1)
        context = torch.bmm(context, outEncoder)
        char = self.MLP(
            torch.cat((h[2].clone(), context.view(1, self.hidden_size)), 1))
        result[i] = char.data
        argmax = torch.max(result[i][0], dim=0)
        listArgmax.append(argmax[1][0])
        if argmax[1][0] == alphabet.ch2index('<eos>'):
            seq_len = i + 1  # truncate the result at <eos>
            break
        Y_cur = self.embedding(
            load_to_cuda(Variable(torch.LongTensor([argmax[1][0]])))
        ).view(1, self.hidden_size)
    word = get_word(torch.LongTensor(listArgmax))
    return result[:seq_len], word, True
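# A hedged usage sketch of greedy inference with evaluate(). The helper name
# and the encoder return signature (outEncoder, h0, c0) are assumptions for
# illustration; only decoder.evaluate() itself comes from the code above.

def transcribe(encoder, decoder, frames):
    """Hypothetical inference helper; decoding stops early at <eos>."""
    encoder.eval()
    decoder.eval()
    outEncoder, h0, c0 = encoder(frames)  # assumed encoder interface
    result, word, ok = decoder.evaluate(h0, c0, outEncoder, max_len=30)
    return word if ok else None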
class LipsDataset(data.Dataset):
    """Lips custom dataset."""

    def __init__(self, frame_dir):
        self.frame_dir = frame_dir
        self.alphabet = Alphabet()
        # walk the frame directory recursively: every leaf directory
        # (one with no subdirectories) holds the frames of a single word
        self.words = []
        for root, dirs, files in os.walk(self.frame_dir):
            if not dirs:
                self.words.append(root)
        self.count = 0

    def __len__(self):
        return len(self.words)

    def __getitem__(self, index):
        # load every frame of the word; subtitle files start with '__'
        curr_dir = self.words[index]
        frames_list = [
            name for name in os.listdir(curr_dir) if not re.match(r'__', name)
        ]
        is_valid = len(frames_list) >= COUNT_FRAMES
        frames = np.zeros((len(frames_list), 120, 120))
        for count, frame in enumerate(frames_list):
            frames[count] = np.array(
                Image.open(os.path.join(curr_dir, frame))
                .convert(mode='L').getdata()).reshape((120, 120))
        frames = torch.from_numpy(frames)
        # split the frames into batches
        if is_valid:
            frames = make_batches(frames)
        # load the subtitles and map them to character indices
        subs_path = [
            name for name in os.listdir(curr_dir) if re.match(r'__', name)
        ][0]
        with open(os.path.join(curr_dir, subs_path), 'r') as subs_file:
            subs = str(json.loads(subs_file.read())['word']).lower()
        characters = [self.alphabet.ch2index('<sos>')]
        for ch in subs:
            if self.alphabet.ch2index(ch) is None:
                is_valid = False  # a character outside the alphabet
                break
            characters.append(self.alphabet.ch2index(ch))
        characters.append(self.alphabet.ch2index('<eos>'))
        targets = torch.LongTensor(characters)
        return frames, targets, is_valid
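# A minimal sketch of consuming LipsDataset. FRAME_DIR is assumed to be the
# root directory of the extracted frames. The dataset is iterated directly
# rather than through a DataLoader: targets are variable-length and
# make_batches() already groups the frames inside __getitem__, so the default
# collate function would not apply cleanly.

dataset = LipsDataset(FRAME_DIR)
for i in range(len(dataset)):
    frames, targets, is_valid = dataset[i]
    if not is_valid:  # too few frames, or a character outside the alphabet
        continue
    # frames: batched 120x120 grayscale crops; targets: <sos> ... <eos> indices
    print(frames.shape, targets.shape)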