Exemplo n.º 1
0
    def start_training(self, itt_no_improve, batch_size, target_sample_rate):
        """Train the vocoder over the full training set, epoch by epoch.

        Every 50 files and at the end of each epoch the dev set is
        synthesized and a model checkpoint is written to
        ``data/models/rnn_vocoder``.

        Args:
            itt_no_improve: iteration budget without improvement.
                NOTE(review): ``left_itt`` is never decremented below, so
                the outer loop runs until externally interrupted — confirm
                the intended early-stopping behavior.
            batch_size: mini-batch size forwarded to ``self.vocoder.learn``.
            target_sample_rate: sample rate used by ``self.synth_devset``.
        """
        # Hoisted out of the loops: the original re-ran `from random import
        # shuffle` every epoch and `import time` for every training file.
        import time
        from random import shuffle

        epoch = 1
        left_itt = itt_no_improve
        dio = DatasetIO()
        self._render_devset()
        sys.stdout.write("\n")
        # Store an initial checkpoint before any training happens.
        self.vocoder.store('data/models/rnn_vocoder')
        while left_itt > 0:
            sys.stdout.write("Starting epoch " + str(epoch) + "\n")
            sys.stdout.write("Shuffling training data\n")
            shuffle(self.trainset.files)
            file_index = 1
            total_loss = 0
            for file in self.trainset.files:
                sys.stdout.write("\t" + str(file_index) + "/" +
                                 str(len(self.trainset.files)) +
                                 " processing file " + file)
                sys.stdout.flush()
                # Each training item is a wav/mgc pair sharing a basename.
                wav_file = file + ".orig.wav"
                mgc_file = file + ".mgc.npy"
                mgc = np.load(mgc_file)
                file_index += 1
                data, sample_rate = dio.read_wave(wav_file)
                if self.use_ulaw:
                    # Only the discretized signal is used for training;
                    # the continuous u-law value is discarded.
                    [wave_disc, ulaw_cont] = dio.ulaw_encode(data)
                else:
                    wave_disc = dio.b16_enc(data)
                start = time.time()
                loss = self.vocoder.learn(wave_disc, mgc, batch_size)
                total_loss += loss
                stop = time.time()
                sys.stdout.write(' avg loss=' + str(loss) +
                                 " execution time=" + str(stop - start))
                sys.stdout.write('\n')
                sys.stdout.flush()
                # Periodic checkpoint: render the devset and store the model.
                if file_index % 50 == 0:
                    self.synth_devset(batch_size, target_sample_rate)
                    self.vocoder.store('data/models/rnn_vocoder')

            # End-of-epoch checkpoint.
            self.synth_devset(batch_size, target_sample_rate)
            self.vocoder.store('data/models/rnn_vocoder')

            epoch += 1
Exemplo n.º 2
0
    def start_training(self,
                       itt_no_improve,
                       batch_size,
                       target_sample_rate,
                       params=None):
        """Train the (optionally sparse) vocoder over the training set.

        For a sparse vocoder, sparsity starts at ``params.sparsity_step`` %
        and is raised by the same step every ``params.sparsity_increase``
        processed files, capped at ``params.sparsity_target`` %. Checkpoints
        go to ``data/models/rnn_vocoder_sparse`` (sparse) or
        ``data/models/rnn_vocoder`` (dense).

        Args:
            itt_no_improve: iteration budget without improvement.
                NOTE(review): ``left_itt`` is never decremented below, so
                the outer loop runs until externally interrupted — confirm
                the intended early-stopping behavior.
            batch_size: mini-batch size forwarded to ``self.vocoder.learn``.
            target_sample_rate: sample rate used by ``self.synth_devset``.
            params: object exposing ``sparsity_step``, ``sparsity_increase``
                and ``sparsity_target``; required only for sparse vocoders.
        """
        # Hoisted out of the loops: the original re-ran `from random import
        # shuffle` every epoch and `import time` for every training file.
        import time
        from random import shuffle

        epoch = 1
        left_itt = itt_no_improve
        dio = DatasetIO()
        self._render_devset()
        sys.stdout.write("\n")

        sparsity = 0
        if self.vocoder.sparse:
            print("Setting sparsity at: " + str(params.sparsity_step) + "%")
            sparsity = params.sparsity_step
            self.vocoder.rnnFine.set_sparsity(float(sparsity) / 100)
            self.vocoder.rnnCoarse.set_sparsity(float(sparsity) / 100)

        # The checkpoint path depends only on sparsity; compute it once
        # instead of repeating the if/else at every store site.
        if self.vocoder.sparse:
            model_path = 'data/models/rnn_vocoder_sparse'
        else:
            model_path = 'data/models/rnn_vocoder'
        self.vocoder.store(model_path)

        num_files = 0

        while left_itt > 0:
            sys.stdout.write("Starting epoch " + str(epoch) + "\n")
            sys.stdout.write("Shuffling training data\n")
            shuffle(self.trainset.files)
            file_index = 1
            total_loss = 0
            for file in self.trainset.files:
                num_files += 1

                # BUGFIX: run the sparsity schedule only for sparse models.
                # The original dereferenced `params` (default None) and used
                # `sparsity` — undefined in the dense path — unconditionally,
                # raising AttributeError/NameError for dense vocoders.
                if self.vocoder.sparse and num_files == params.sparsity_increase:
                    num_files = 0
                    if sparsity < params.sparsity_target:
                        # BUGFIX: clamp BEFORE applying. The original clamped
                        # the overshoot without calling set_sparsity, so the
                        # exact target sparsity was never applied.
                        sparsity = min(sparsity + params.sparsity_step,
                                       params.sparsity_target)
                        print("Setting sparsity at " + str(sparsity) + "%")
                        self.vocoder.rnnFine.set_sparsity(
                            float(sparsity) / 100)
                        self.vocoder.rnnCoarse.set_sparsity(
                            float(sparsity) / 100)

                sys.stdout.write("\t" + str(file_index) + "/" +
                                 str(len(self.trainset.files)) +
                                 " processing file " + file)
                sys.stdout.flush()
                # Each training item is a wav/mgc pair sharing a basename.
                wav_file = file + ".orig.wav"
                mgc_file = file + ".mgc.npy"
                mgc = np.load(mgc_file)
                file_index += 1
                data, sample_rate = dio.read_wave(wav_file)
                if self.use_ulaw:
                    # Only the discretized signal is used for training;
                    # the continuous u-law value is discarded.
                    [wave_disc, ulaw_cont] = dio.ulaw_encode(data)
                else:
                    wave_disc = dio.b16_enc(data)
                start = time.time()
                loss = self.vocoder.learn(wave_disc, mgc, batch_size)
                total_loss += loss
                stop = time.time()
                sys.stdout.write(' avg loss=' + str(loss) +
                                 " execution time=" + str(stop - start))
                sys.stdout.write('\n')
                sys.stdout.flush()
                # Periodic checkpoint: render the devset and store the model.
                if file_index % 50 == 0:
                    self.synth_devset(batch_size, target_sample_rate)
                    self.vocoder.store(model_path)

            # End-of-epoch checkpoint.
            self.synth_devset(batch_size, target_sample_rate)
            self.vocoder.store(model_path)

            epoch += 1