Example no. 1 (score: 0)
 def forward(self,Y,h,c, outEncoder,teacher_force):# Y is the character count times 256 (embedded below)
     """Decoder forward pass over a 3-layer LSTM stack with attention.

     With probability ``teacher_force`` the decoder feeds back its own
     greedy prediction; otherwise it consumes the ground-truth sequence Y.

     Args:
         Y: target character-index sequence; assumes shape (seq, batch) with
            ``<sos>`` at position 0 — TODO confirm against caller.
         h, c: hidden/cell states; assumes layout (3 layers, batch, hidden) —
            TODO confirm (indexed as h[0..2] below).
         outEncoder: encoder outputs consumed by ``self.attention``;
            re-viewed before ``bmm``, so its exact layout is assumed —
            verify against the encoder.
         teacher_force: probability threshold compared to a uniform sample.

     Returns:
         output_decoder: (seq_len, batch, 48) scores; 48 is presumably the
         alphabet size — TODO confirm against ``Alphabet``.
     """
     if (np.random.rand()>teacher_force):
         # Teacher-forcing branch: feed the ground-truth characters.
         seq_len=Y.shape[0]-1
         output_decoder= load_to_cuda(torch.autograd.Variable(torch.zeros(seq_len, h.shape[1], 48)))
         Y = self.embedding(Y)
         for  i in range(len(Y)-1): # -1 because <sos> is not counted in the criterion
             # Run the embedded input through the 3-layer LSTM stack.
             # .clone() everywhere avoids in-place autograd errors when
             # writing back into the shared h/c tensors.
             h[0],c[0] = self.lstm1(Y[i],(h[0].clone(),c[0].clone()))
             h[1],c[1] = self.lstm2(h[0].clone(),(h[1].clone(),c[1].clone()))
             h[2],c[2] = self.lstm3(h[1].clone(),(h[2].clone(),c[2].clone()))
             h2=h[2].clone()
             # Attention weights over encoder outputs, then a weighted sum
             # via bmm against the re-viewed encoder output.
             context = self.attention(h2, outEncoder,BATCH_SIZE)
             context =  torch.bmm( context,outEncoder.view(outEncoder.shape[1],outEncoder.shape[0],-1) )
             # Classifier over [hidden ; context].
             output_decoder[i] = self.MLP(torch.cat( (h2,torch.squeeze(context,1)) ,1 ))    
     else:
         # Free-running branch: feed back the decoder's own argmax prediction.
         seq_len=Y.shape[0]-1
         output_decoder= load_to_cuda(torch.autograd.Variable(torch.zeros(seq_len, h.shape[1], 48)))
         alphabet = Alphabet()
         # Start from the embedded <sos> token.
         Y_cur = self.embedding( load_to_cuda(Variable(torch.LongTensor([alphabet.ch2index('<sos>')]))) ).view(1,self.hidden_size)
         for  i in range(seq_len-1):
             # Broadcast the single predicted embedding across the batch.
             Y_cur=Y_cur.expand(BATCH_SIZE,self.hidden_size)
             h[0],c[0] = self.lstm1(Y_cur,(h[0].clone(),c[0].clone()))
             h[1],c[1] = self.lstm2(h[0].clone(),(h[1].clone(),c[1].clone()))
             h[2],c[2] = self.lstm3(h[1].clone(),(h[2].clone(),c[2].clone()))
             h2 = h[2].clone()
             context = self.attention(h2, outEncoder,BATCH_SIZE)
             context = torch.bmm( context,outEncoder.view(outEncoder.shape[1],outEncoder.shape[0],-1) )
             output_decoder[i]  =  self.MLP(torch.cat( (h2,torch.squeeze(context,1)) ,1 ))
             # Greedy decode: take argmax of the first batch element only —
             # NOTE(review): only sample 0's prediction is fed back; confirm
             # this is intentional for batched free-running decoding.
             argmax = torch.max(output_decoder[i][0],dim=0)
             Y_cur=self.embedding( Variable(load_to_cuda(torch.LongTensor([argmax[1][0].data[0]]))) ).view(1,self.hidden_size)
     return output_decoder 
Example no. 2 (score: 0)
    def evaluate(self,h,c,outEncoder,max_len=-1): # <sos> must not appear in the return
        """Greedy inference for a single sequence (batch size 1).

        Decodes until ``<eos>`` or until the length cap is reached.

        Args:
            h, c: initial hidden/cell states; re-viewed below so that they
                index as (3 layers, ...) — exact incoming layout assumed,
                TODO confirm against caller.
            outEncoder: encoder outputs; must be 3-dimensional.
            max_len: maximum output length, or -1 for the default cap of 50.

        Returns:
            (result[:seq_len], word, ok): per-step scores (steps, 1, 48),
            the decoded string from ``get_word``, and a success flag
            (False only when ``outEncoder`` has the wrong rank).
        """
        h = h.view(h.shape[1],h.shape[0],-1).clone()
        c = c.view(c.shape[1],c.shape[0],-1).clone()
        if max_len==-1:
            seq_len = 50  # maximum decode length cap
        else:
            seq_len=max_len
        result = load_to_cuda(torch.FloatTensor(seq_len,1,48).zero_())
        # Guard: encoder output must be rank 3; bail out with a failure flag.
        if (len(outEncoder.shape))!=3:
            print("размерность encoderOut неправильная")
            return result, result[0], False
        alphabet = Alphabet()
        listArgmax=[]  # character indices produced so far
        # Start decoding from the embedded <sos> token.
        Y_cur = self.embedding( load_to_cuda(Variable(torch.LongTensor([alphabet.ch2index('<sos>')]))) ).view(1,self.hidden_size)
        for  i in range(seq_len-1):
            # 3-layer LSTM stack; .clone() avoids in-place autograd issues
            # when writing back into the shared h/c tensors.
            h[0],c[0] = self.lstm1(Y_cur,(h[0].clone(),c[0].clone()))
            h[1],c[1] = self.lstm2(h[0],(h[1].clone(),c[1].clone()))
            h[2],c[2] = self.lstm3(h[1].clone(),(h[2].clone(),c[2].clone()))
            # Attention over encoder outputs (batch size 1), then weighted sum.
            context = self.attention(h[2].clone(), outEncoder.view(outEncoder.shape[1],outEncoder.shape[0],-1),1)
            context = torch.bmm(context,outEncoder)         
            char = self.MLP( torch.cat( (h[2].clone(),context.view(1,self.hidden_size)),1 ) )
            result[i] = char.data
            # Greedy pick of the most likely character at this step.
            argmax = torch.max(result[i][0],dim=0)
            listArgmax.append(argmax[1][0])
            if argmax[1][0] == alphabet.ch2index('<eos>'):
               seq_len=i+1  # truncate result to the emitted length
               break
            # Feed the predicted character back as the next input.
            Y_cur=self.embedding( Variable(load_to_cuda(torch.LongTensor([argmax[1][0]]))) ).view(1,self.hidden_size)

        word=get_word(torch.LongTensor(listArgmax))
        return result[:seq_len],word, True        
Example no. 3 (score: 0)
class LipsDataset(data.Dataset):
    """Dataset of lip-frame sequences with word-level subtitle targets.

    Each sample is a leaf directory under ``frame_dir`` containing
    120x120 frame images plus one subtitle file whose name starts with
    ``__`` and holds JSON with a ``'word'`` key.
    """

    def __init__(self, frame_dir):
        """Index every leaf directory (one per word/video) under ``frame_dir``."""
        self.frame_dir = frame_dir
        self.alphabet = Alphabet()

        # Walk the tree; directories with no subdirectories are samples.
        self.words = []
        for root, dirs, files in os.walk(self.frame_dir):
            if not dirs:
                self.words.append(root)

        # Kept for backward compatibility; not used inside this class.
        self.count = 0

    def __len__(self):
        return len(self.words)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            (frames, targets, is_valid): frame tensor (batched via
            ``make_batches`` when valid), a LongTensor of character
            indices wrapped in <sos>/<eos>, and a validity flag.
            A sample is invalid when it has fewer than ``COUNT_FRAMES``
            frames or its subtitle contains a character unknown to the
            alphabet.
        """
        curr_dir = self.words[index]

        # Frame images are every file NOT prefixed with '__'.
        frames_list = [
            name for name in os.listdir(curr_dir) if not re.match(r'__', name)
        ]
        # Enough frames to form at least one batch?
        is_valid = len(frames_list) >= COUNT_FRAMES

        frames = np.zeros((len(frames_list), 120, 120))
        for i, frame_name in enumerate(frames_list):
            # Grayscale ('L') decode; flat pixel data reshaped to 120x120.
            # 'with' closes the underlying image file (the original leaked it).
            with Image.open(os.path.join(curr_dir, frame_name)) as img:
                frames[i] = np.array(
                    img.convert(mode='L').getdata()).reshape((120, 120))
        frames = torch.from_numpy(frames)

        # Split into batches only when the sample has enough frames.
        if is_valid:
            frames = make_batches(frames)

        # The subtitle file is the single '__'-prefixed file in the directory.
        subs_path = [
            name for name in os.listdir(curr_dir) if re.match(r'__', name)
        ][0]
        with open(os.path.join(curr_dir, subs_path), 'r') as subs_file:
            subs = str(json.loads(subs_file.read())['word']).lower()

        # Encode <sos> word <eos>; an unknown character invalidates the sample.
        characters = [self.alphabet.ch2index('<sos>')]
        for ch in subs:
            ch_index = self.alphabet.ch2index(ch)
            if ch_index is None:
                is_valid = False
                break
            characters.append(ch_index)
        characters.append(self.alphabet.ch2index('<eos>'))

        targets = torch.LongTensor(characters)
        return frames, targets, is_valid