Exemplo n.º 1
0
    def _get_distance(self, targets, y_hats):
        """
        Sum the error distance and reference length over a batch of sequence pairs.

        Every ``(target, y_hat)`` pair is decoded to text with ``label_to_string``
        (using ``self.id2char`` and ``self.eos_id``) and scored with
        ``self.calc_error_rate``; the per-pair results are accumulated. Pairs are
        formed with ``zip``, so iteration stops at the shorter of the two inputs.

        Args:
            targets (torch.Tensor): set of ground truth
            y_hats (torch.Tensor): predicted y values (y_hat) by the model

        Returns: total_dist, total_length
            - **total_dist**: sum of the distances returned by ``calc_error_rate``
            - **total_length**: sum of the lengths returned by ``calc_error_rate``
        """
        def decode(labels):
            # Ground truth and prediction are decoded with the same vocabulary / EOS id.
            return label_to_string(labels, self.id2char, self.eos_id)

        dist_sum = length_sum = 0
        for reference, hypothesis in zip(targets, y_hats):
            pair_dist, pair_len = self.calc_error_rate(decode(reference), decode(hypothesis))
            dist_sum += pair_dist
            length_sum += pair_len

        return dist_sum, length_sum
Exemplo n.º 2
0
    def search(self, model: nn.Module, queue: Queue, device: str, print_every: int) -> float:
        """
        Greedy-decode every batch from ``queue`` and score it with ``self.metric``.

        Batches are pulled until one whose ``inputs`` has a zero-sized first
        dimension arrives (the end-of-data marker). For each batch the arg-max
        token of every decoding step is taken as the prediction, the decoded
        target / prediction strings are appended to ``self.target_list`` /
        ``self.predict_list``, and ``self.metric`` is called on the batch.

        Args:
            model (nn.Module): model to evaluate; called with
                ``teacher_forcing_ratio=0.0`` and ``return_decode_dict=False``
            queue (Queue): yields ``(inputs, targets, input_lengths, target_lengths)``
            device (str): device that inputs and targets are moved to
            print_every (int): log the CER every ``print_every`` batches;
                ``0`` disables logging (previously raised ZeroDivisionError)

        Returns: cer
            - **cer**: value returned by ``self.metric`` for the last batch
              (``0`` if the queue yielded no batch)
        """
        cer = 0
        timestep = 0

        model.eval()

        with torch.no_grad():
            while True:
                inputs, targets, input_lengths, target_lengths = queue.get()
                if inputs.shape[0] == 0:  # empty batch marks the end of the data
                    break

                inputs = inputs.to(device)
                targets = targets.to(device)

                output = model(inputs, input_lengths, teacher_forcing_ratio=0.0, return_decode_dict=False)
                logit = torch.stack(output, dim=1).to(device)
                pred = logit.max(-1)[1]

                # One device->host copy per batch instead of one per sentence.
                pred_cpu = pred.cpu().numpy()

                for idx in range(targets.size(0)):
                    self.target_list.append(label_to_string(targets[idx], id2char, EOS_token))
                    self.predict_list.append(label_to_string(pred_cpu[idx], id2char, EOS_token))

                # The first column of targets is excluded from scoring
                # (presumably a start-of-sequence token — confirm against the data loader).
                cer = self.metric(targets[:, 1:], pred)

                if print_every and timestep % print_every == 0:
                    logger.info('cer: {:.2f}'.format(cer))

                timestep += 1

        return cer
Exemplo n.º 3
0
    def search(self, model, queue, device, print_every):
        """
        Greedy-decode every batch from ``queue`` and score it with ``self.metric``.

        Batches are pulled until one whose ``inputs`` has a zero-sized first
        dimension arrives (the end-of-data marker). The model is run with
        ``self.language_model``; the arg-max token of every decoding step is the
        hypothesis. Decoded script / hypothesis strings are appended to
        ``self.target_list`` / ``self.hypothesis_list`` and ``self.metric`` is
        called on the batch.

        Args:
            model: model to evaluate; called with ``teacher_forcing_ratio=0.0``
                and must return a ``(outputs, _)`` pair
            queue: yields ``(inputs, scripts, input_lengths, target_lengths)``
            device: device that inputs and scripts are moved to
            print_every (int): log the CER every ``print_every`` batches;
                ``0`` disables logging (previously raised ZeroDivisionError)

        Returns: cer
            - **cer**: value returned by ``self.metric`` for the last batch
              (``0`` if the queue yielded no batch)
        """
        cer = 0
        timestep = 0

        model.eval()

        with torch.no_grad():
            while True:
                inputs, scripts, input_lengths, target_lengths = queue.get()
                if inputs.shape[0] == 0:  # empty batch marks the end of the data
                    break

                inputs = inputs.to(device)
                scripts = scripts.to(device)
                # Scoring skips the first token of each script
                # (presumably a start-of-sequence token — confirm against the data loader).
                targets = scripts[:, 1:]

                output, _ = model(inputs,
                                  input_lengths,
                                  teacher_forcing_ratio=0.0,
                                  language_model=self.language_model)

                logit = torch.stack(output, dim=1).to(device)
                hypothesis = logit.max(-1)[1]

                # One device->host copy per batch instead of one per sentence.
                hypothesis_cpu = hypothesis.cpu().numpy()

                for idx in range(targets.size(0)):
                    self.target_list.append(
                        label_to_string(scripts[idx], id2char, EOS_token))
                    self.hypothesis_list.append(
                        label_to_string(hypothesis_cpu[idx], id2char, EOS_token))

                cer = self.metric(targets, hypothesis)

                if print_every and timestep % print_every == 0:
                    logger.info('cer: {:.2f}'.format(cer))

                timestep += 1

        return cer
Exemplo n.º 4
0
def index():
    """
    Root view: transcribe an uploaded audio file, or show the homepage.

    On POST, if ``allowed_file`` accepts the uploaded ``file``, it is saved under
    ``app.config['UPLOAD_FOLDER']`` with a sanitized name, converted to
    ``AUDIO_TO_PLAY_PATH`` for playback, fed to ``model``, and the decoded text
    is rendered with ``uploaded.html``. The uploaded copy is always deleted
    afterwards, even if conversion or inference fails. GET requests and rejected
    files render ``homepage.html``.
    """
    # If hit play button
    if request.method == 'POST':
        # Remove the previous playback file so a stale one is never served.
        if os.path.isfile(AUDIO_TO_PLAY_PATH):
            os.remove(AUDIO_TO_PLAY_PATH)

        file = request.files['file']
        is_valid, extension = allowed_file(file.filename)  # check condition

        if is_valid:
            # Every later step must use the sanitized path the file was actually
            # saved to. The raw client filename may contain path components
            # ("../") or characters secure_filename strips. Using it would
            # read/delete the wrong file and leave the real upload behind.
            filename = secure_filename(file.filename)
            uploaded_file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            file.save(uploaded_file_path)

            try:
                # Convert format
                if extension.lower() == 'pcm':
                    pcm2wav(uploaded_file_path, AUDIO_TO_PLAY_PATH)
                elif extension.lower() == 'wav':
                    convert2pcm(uploaded_file_path, AUDIO_TO_PLAY_PATH)

                # Extract feature & inference by model (no autograd needed).
                spectrogram = parse_audio(uploaded_file_path)
                with torch.no_grad():
                    output = model(spectrogram.unsqueeze(0),
                                   torch.IntTensor([len(spectrogram)]),
                                   teacher_forcing_ratio=0.0)[0]
                    logit = torch.stack(output, dim=1).to(DEVICE)
                y_hat = logit.max(-1)[1]
                prediction = str(label_to_string(y_hat, id2char, EOS_token)[0])
            finally:
                # Never leave uploads behind, even when inference fails.
                if os.path.isfile(uploaded_file_path):
                    os.remove(uploaded_file_path)

            return render_template('uploaded.html',
                                   audio_path='.%s' % AUDIO_TO_PLAY_PATH,
                                   prediction=prediction)
    # Root page
    return render_template('homepage.html')
Exemplo n.º 5
0
def index():
    """
    Root view: transcribe an uploaded audio file and dispatch a voice command.

    On POST, if ``allowed_file`` accepts the uploaded ``file``:

    1. save it under ``app.config['UPLOAD_FOLDER']`` with a sanitized name;
    2. convert it to ``AUDIO_TO_PLAY_PATH``;
    3. run ``model`` on it;
    4. pass the decoded text to ``milestone``.

    The uploaded copy is always deleted afterwards, even if conversion or
    inference fails. Depending on the first character of the returned order, the
    rest is sent to the light socket, sent to the air-conditioner socket (whose
    ack is read and printed), or used to toggle the global ``show_graph`` flag.
    The result page is ``display_graph.html`` when ``show_graph`` is set,
    otherwise ``uploaded.html``. GET requests and rejected files render
    ``homepage.html``.
    """
    global show_graph  # reassigned below when a GRAPH order is recognised

    # If hit play button
    if request.method == 'POST':
        # Remove the previous playback file so a stale one is never served.
        if os.path.isfile(AUDIO_TO_PLAY_PATH):
            os.remove(AUDIO_TO_PLAY_PATH)

        file = request.files['file']
        is_valid, extension = allowed_file(file.filename)  # check condition

        if is_valid:
            # Every later step must use the sanitized path the file was actually
            # saved to. The raw client filename may contain path components
            # ("../") or characters secure_filename strips. Using it would
            # read/delete the wrong file and leave the real upload behind.
            filename = secure_filename(file.filename)
            uploaded_file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            file.save(uploaded_file_path)

            try:
                # Convert format
                if extension.lower() == 'pcm':
                    pcm2wav(uploaded_file_path, AUDIO_TO_PLAY_PATH)
                elif extension.lower() == 'wav':
                    convert2pcm(uploaded_file_path, AUDIO_TO_PLAY_PATH)

                # Extract feature & inference by model (no autograd needed).
                spectrogram = parse_audio(uploaded_file_path)
                with torch.no_grad():
                    output = model(spectrogram.unsqueeze(0),
                                   torch.IntTensor([len(spectrogram)]),
                                   teacher_forcing_ratio=0.0)[0]
                    logit = torch.stack(output, dim=1).to(DEVICE)
                y_hat = logit.max(-1)[1]
                prediction = label_to_string(y_hat, id2char, EOS_token)
            finally:
                # Never leave uploads behind, even when inference fails.
                if os.path.isfile(uploaded_file_path):
                    os.remove(uploaded_file_path)

            # Determine destination device & command
            order = milestone(prediction[0])
            print(order)

            if order is not None:
                # NOTE(review): LIGHT orders forward order[1:], the others use
                # order[3:]. Confirm this matches the format produced by milestone().
                if order[0] == str(LIGHT):
                    _send_device_order(light_socket, light_addr, order[1:])

                elif order[0] == str(AIRCONDITIONER):
                    _send_device_order(air_socket, fine_dust_addr, order[3:])
                    # Receive ack (acknowledge); blocks until the device answers.
                    ack = air_socket.recv(65535).decode()
                    print(ack)

                elif order[0] == str(GRAPH):
                    if order[3:] == str(SHOW):
                        show_graph = True
                    elif order[3:] == str(OFF):
                        show_graph = False

            # Play page: only the template differs between the two views.
            template = 'display_graph.html' if show_graph else 'uploaded.html'
            return render_template(template,
                                   audio_path='.%s' % AUDIO_TO_PLAY_PATH,
                                   prediction=str(prediction[0]))
    # Root page
    return render_template('homepage.html')


def _send_device_order(sock, addr, payload):
    """
    Send ``payload`` (str) over ``sock``, connecting to ``addr`` first if needed.

    ``connect`` on an already-connected socket raises ``OSError``, so that error
    is treated as "reuse the existing connection". Any genuine connection
    failure then surfaces from ``send``. Unlike a bare ``except:``, this no
    longer swallows non-socket errors such as ``KeyboardInterrupt``.
    """
    try:
        sock.connect(addr)
    except OSError:
        pass  # already connected (or connect failed; send below will report it)
    sock.send(payload.encode())
Exemplo n.º 6
0
        window_type='hamming').transpose(0, 1).numpy()
    feature_vector -= feature_vector.mean()
    feature_vector = Tensor(feature_vector).transpose(0, 1)

    return feature_vector


# Command-line script: transcribe one audio file with a pretrained model and print the text.
parser = argparse.ArgumentParser(description='Run Pretrain')
parser.add_argument('--model_path', type=str, default='../pretrain/model.pt',
                    help='path of the pretrained model checkpoint')
parser.add_argument('--audio_path',
                    type=str,
                    default='../pretrain/sample_audio.pcm',
                    help='path of the audio file to transcribe')
parser.add_argument('--device', type=str, default='cuda',
                    help='torch device to run inference on (e.g. "cuda" or "cpu")')
opt = parser.parse_args()

feature_vector = parse_audio(opt.audio_path, del_silence=True)
input_length = torch.IntTensor([len(feature_vector)])

model = load_test_model(opt, opt.device)
model.eval()

# Inference only: disable autograd so no graph is recorded.
with torch.no_grad():
    # parse_audio builds a CPU tensor; move it to the device the model was loaded on.
    output = model(inputs=feature_vector.unsqueeze(0).to(opt.device),
                   input_lengths=input_length,
                   teacher_forcing_ratio=0.0,
                   return_decode_dict=False)
    logit = torch.stack(output, dim=1).to(opt.device)
pred = logit.max(-1)[1]

sentence = label_to_string(pred.cpu().numpy(), id2char, EOS_token)
print(sentence)