Code example #1
def evaluate_model(
    model: ModelInterface,
    batch: List[object],
    label: List[int],
    show_stats: bool = False
) -> Tuple[float, float, List[int]]:
    with torch.no_grad():
        pred = model.predict(batch)
        pred_label = pred.argmax(dim=1)
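        # boolean mask of samples predicted as class 1 (pred_label is 0/1 per sample)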
        index = pred_label.to(torch.bool)

        if show_stats:
            stats = torch.cat([
                pred[index].to('cpu'),
                torch.tensor(label, dtype=torch.float).reshape(-1, 1)[index]
            ], dim=1)
            log.debug(stats)

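    # assuming binary classification, pred[:, 1] is the positive-class score fed to both AUC metrics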
    roc_auc, prc_auc = util.evaluate_auc(label, pred[:, 1])
    return roc_auc, prc_auc, pred_label.tolist()
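A possible call site for evaluate_model above, assuming a trained ModelInterface and the util.separate_items helper used in the later examples (dev_items, model, and log are assumed names, not part of this listing):

# hypothetical usage sketch, not taken from the original project
batch, label = util.separate_items(dev_items)   # dev_items: List[Item]
roc_auc, prc_auc, pred_label = evaluate_model(model, batch, label, show_stats=True)
log.info(f'ROC-AUC: {roc_auc}, PRC-AUC: {prc_auc}')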
Code example #2
File: gui.py  Project: danenigma/speaker_recoginition
class Main(QMainWindow):
    CONV_INTERVAL = 0.4
    CONV_DURATION = 1.5
    CONV_FILTER_DURATION = CONV_DURATION
    FS = 8000
    TEST_DURATION = 3

    def __init__(self, parent=None):
        QWidget.__init__(self, parent)
        uic.loadUi("edytor.ui", self)
        self.statusBar()

        self.timer = QTimer(self)
        self.timer.timeout.connect(self.timer_callback)

        self.noiseButton.clicked.connect(self.noise_clicked)
        self.recording_noise = False
        self.loadNoise.clicked.connect(self.load_noise)

        self.enrollRecord.clicked.connect(self.start_enroll_record)
        self.stopEnrollRecord.clicked.connect(self.stop_enroll_record)
        self.enrollFile.clicked.connect(self.enroll_file)
        self.enroll.clicked.connect(self.do_enroll)
        self.startTrain.clicked.connect(self.start_train)
        self.dumpBtn.clicked.connect(self.dump)
        self.loadBtn.clicked.connect(self.load)

        self.recoRecord.clicked.connect(self.start_reco_record)
        self.stopRecoRecord.clicked.connect(self.stop_reco_record)
#        self.newReco.clicked.connect(self.new_reco)
        self.recoFile.clicked.connect(self.reco_file)
        self.recoInputFiles.clicked.connect(self.reco_files)

        #UI.init
        self.userdata = []
        self.loadUsers()
        self.Userchooser.currentIndexChanged.connect(self.showUserInfo)
        self.ClearInfo.clicked.connect(self.clearUserInfo)
        self.UpdateInfo.clicked.connect(self.updateUserInfo)
        self.UploadImage.clicked.connect(self.upload_avatar)
        #movie test
        self.movie = QMovie(u"image/recording.gif")
        self.movie.start()
        self.movie.stop()
        self.Animation.setMovie(self.movie)
        self.Animation_2.setMovie(self.movie)
        self.Animation_3.setMovie(self.movie)

        self.aladingpic = QPixmap(u"image/a_hello.png")
        self.Alading.setPixmap(self.aladingpic)
        self.Alading_conv.setPixmap(self.aladingpic)

        #default user image setting
        self.avatarname = "image/nouser.jpg"
        self.defaultimage = QPixmap(self.avatarname)
        self.Userimage.setPixmap(self.defaultimage)
        self.recoUserImage.setPixmap(self.defaultimage)
        self.convUserImage.setPixmap(self.defaultimage)
        self.load_avatar('avatar/')

        # Graph Window init
        self.graphwindow = GraphWindow()
        self.newname = ""
        self.lastname = ""
        self.Graph_button.clicked.connect(self.graphwindow.show)
        self.convRecord.clicked.connect(self.start_conv_record)
        self.convStop.clicked.connect(self.stop_conv)

        self.backend = ModelInterface()

        # debug
        QShortcut(QKeySequence("Ctrl+P"), self, self.printDebug)

        #init
        try:
            fs, signal = read_wav("bg.wav")
            self.backend.init_noise(fs, signal)
        except Exception:
            # the saved background-noise sample is optional; ignore a missing bg.wav
            pass


    ############ RECORD
    def start_record(self):
        self.pyaudio = pyaudio.PyAudio()
        self.status("Recording...")
        self.movie.start()
        self.Alading.setPixmap(QPixmap(u"image/a_thinking.png"))


        self.recordData = []
        self.stream = self.pyaudio.open(format=FORMAT, channels=1, rate=Main.FS,
                        input=True, frames_per_buffer=1)
        self.stopped = False
        self.reco_th = RecorderThread(self)
        self.reco_th.start()

        self.timer.start(1000)
        self.record_time = 0
        self.update_all_timer()

    def add_record_data(self, i):
        self.recordData.append(i)
        return self.stopped

    def timer_callback(self):
        self.record_time += 1
        self.status("Recording..." + time_str(self.record_time))
        self.update_all_timer()

    def stop_record(self):
        self.movie.stop()
        self.stopped = True
        self.reco_th.wait()
        self.timer.stop()
        self.stream.stop_stream()
        self.stream.close()
        self.pyaudio.terminate()
        self.status("Record stopeed")

    ############## conversation
    def start_conv_record(self):
        self.conv_result_list = []
        self.start_record()
        self.conv_now_pos = 0
        self.conv_timer = QTimer(self)
        self.conv_timer.timeout.connect(self.do_conversation)
        self.conv_timer.start(Main.CONV_INTERVAL * 1000)
        #reset
        self.graphwindow.wid.reset()

    def stop_conv(self):
        self.stop_record()
        self.conv_timer.stop()

    def do_conversation(self):
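        # Called every CONV_INTERVAL seconds while recording: take roughly the most
        # recent CONV_DURATION seconds of samples, run them through backend.filter,
        # and predict the current speaker.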
        interval_len = int(Main.CONV_INTERVAL * Main.FS)
        segment_len = int(Main.CONV_DURATION * Main.FS)
        self.conv_now_pos += interval_len
        to_filter = self.recordData[max([self.conv_now_pos - segment_len, 0]):
                                   self.conv_now_pos]
        signal = np.array(to_filter, dtype=NPDtype)
        label = None
        try:
            signal = self.backend.filter(Main.FS, signal)
            if len(signal) > 50:
                label = self.backend.predict(Main.FS, signal, True)
        except Exception as e:
            print traceback.format_exc()
            print str(e)

        global last_label_to_show
        label_to_show = label
        if label and self.conv_result_list:
            last_label = self.conv_result_list[-1]
            if last_label and last_label != label:
                label_to_show = last_label_to_show
        self.conv_result_list.append(label)

        print label_to_show, "label to show"
        last_label_to_show = label_to_show

        #ADD FOR GRAPH
        if label_to_show is None:
            label_to_show = 'Nobody'
        if len(NAMELIST) and NAMELIST[-1] != label_to_show:
            NAMELIST.append(label_to_show)
        self.convUsername.setText(label_to_show)
        self.Alading_conv.setPixmap(QPixmap(u"image/a_result.png"))
        self.convUserImage.setPixmap(self.get_avatar(label_to_show))


    ###### RECOGNIZE
    def start_reco_record(self):
        self.Alading.setPixmap(QPixmap(u"image/a_hello"))
        self.recoRecordData = np.array((), dtype=NPDtype)
        self.start_record()

    def stop_reco_record(self):
        self.stop_record()
        signal = np.array(self.recordData, dtype=NPDtype)
        self.reco_remove_update(Main.FS, signal)

    def reco_do_predict(self, fs, signal):
        label = self.backend.predict(fs, signal)
        if not label:
            label = "Nobody"
        print label
        self.recoUsername.setText(label)
        self.Alading.setPixmap(QPixmap(u"image/a_result.png"))
        self.recoUserImage.setPixmap(self.get_avatar(label))

        # TODO To Delete
        write_wav('reco.wav', fs, signal)

    def reco_remove_update(self, fs, signal):
        new_signal = self.backend.filter(fs, signal)
        print "After removed: {0} -> {1}".format(len(signal), len(new_signal))
        self.recoRecordData = np.concatenate((self.recoRecordData, new_signal))
        real_len = float(len(self.recoRecordData)) / Main.FS / Main.TEST_DURATION * 100
        if real_len > 100:
            real_len = 100

        self.reco_do_predict(fs, self.recoRecordData)


    def reco_file(self):
        fname = QFileDialog.getOpenFileName(self, "Open Wav File", "", "Files (*.wav)")
        print 'reco_file'
        if not fname:
            return
        self.status(fname)

        fs, signal = read_wav(fname)
        self.reco_do_predict(fs, signal)

    def reco_files(self):
        fnames = QFileDialog.getOpenFileNames(self, "Select Wav Files", "", "Files (*.wav)")
        print 'reco_files'
        for f in fnames:
            fs, sig = read_wav(f)
            newsig = self.backend.filter(fs, sig)
            label = self.backend.predict(fs, newsig)
            print f, label

    ########## ENROLL
    def start_enroll_record(self):
        self.enrollWav = None
        self.enrollFileName.setText("")
        self.start_record()

    def enroll_file(self):
        fname = QFileDialog.getOpenFileName(self, "Open Wav File", "", "Files (*.wav)")
        if not fname:
            return
        self.status(fname)
        self.enrollFileName.setText(fname)
        fs, signal = read_wav(fname)
        signal = monophonic(signal)
        self.enrollWav = (fs, signal)

    def stop_enroll_record(self):
        self.stop_record()
        print self.recordData[:300]
        signal = np.array(self.recordData, dtype=NPDtype)
        self.enrollWav = (Main.FS, signal)

        # TODO To Delete
        write_wav('enroll.wav', *self.enrollWav)

    def do_enroll(self):
        name = self.Username.text().trimmed()
        if not name:
            self.warn("Please Input Your Name")
            return
#        self.addUserInfo()
        new_signal = self.backend.filter(*self.enrollWav)
        print "After removed: {0} -> {1}".format(len(self.enrollWav[1]), len(new_signal))
        print "Enroll: {:.4f} seconds".format(float(len(new_signal)) / Main.FS)
        if len(new_signal) == 0:
            print "Error! Input is silent! Please enroll again"
            return
        self.backend.enroll(name, Main.FS, new_signal)

    def start_train(self):
        self.status("Training...")
        self.backend.train()
        self.status("Training Done.")

    ####### UI related
    def getWidget(self, splash):
        t = QtCore.QElapsedTimer()
        t.start()
        while (t.elapsed() < 800):
            msg = QtCore.QString("times = ") + QtCore.QString.number(t.elapsed())
            splash.showMessage(msg)
            QtCore.QCoreApplication.processEvents()

    def upload_avatar(self):
        fname = QFileDialog.getOpenFileName(self, "Open JPG File", "", "File (*.jpg)")
        if not fname:
            return
        self.avatarname = fname
        self.Userimage.setPixmap(QPixmap(fname))

    def loadUsers(self):
        with open("avatar/metainfo.txt") as db:
            for line in db:
                tmp = line.split()
                self.userdata.append(tmp)
                self.Userchooser.addItem(tmp[0])

    def showUserInfo(self):
        for user in self.userdata:
            if self.userdata.index(user) == self.Userchooser.currentIndex() - 1:
                self.Username.setText(user[0])
                self.Userage.setValue(int(user[1]))
                if user[2] == 'F':
                    self.Usersex.setCurrentIndex(1)
                else:
                    self.Usersex.setCurrentIndex(0)
                self.Userimage.setPixmap(self.get_avatar(user[0]))

    def updateUserInfo(self):
        userindex = self.Userchooser.currentIndex() - 1
        u = self.userdata[userindex]
        u[0] = unicode(self.Username.displayText())
        u[1] = self.Userage.value()
        if self.Usersex.currentIndex():
            u[2] = 'F'
        else:
            u[2] = 'M'
        with open("avatar/metainfo.txt","w") as db:
            for user in self.userdata:
                for i in range(3):
                    db.write(str(user[i]) + " ")
                db.write("\n")

    def writeuserdata(self):
        with open("avatar/metainfo.txt","w") as db:
            for user in self.userdata:
                for i in range (0,4):
                    db.write(str(user[i]) + " ")
                db.write("\n")

    def clearUserInfo(self):
        self.Username.setText("")
        self.Userage.setValue(0)
        self.Usersex.setCurrentIndex(0)
        self.Userimage.setPixmap(self.defaultimage)

    def addUserInfo(self):
        for user in self.userdata:
            if user[0] == unicode(self.Username.displayText()):
                return
        newuser = []
        newuser.append(unicode(self.Username.displayText()))
        newuser.append(self.Userage.value())
        if self.Usersex.currentIndex():
            newuser.append('F')
        else:
            newuser.append('M')
        if self.avatarname:
            shutil.copy(self.avatarname, 'avatar/' + newuser[0] + '.jpg')
        self.userdata.append(newuser)
        self.writeuserdata()
        self.Userchooser.addItem(unicode(self.Username.displayText()))


    ############# UTILS
    def warn(self, s):
        QMessageBox.warning(self, "Warning", s)

    def status(self, s=""):
        self.statusBar().showMessage(s)

    def update_all_timer(self):
        s = time_str(self.record_time)
        self.enrollTime.setText(s)
        self.recoTime.setText(s)
        self.convTime.setText(s)

    def dump(self):
        fname = QFileDialog.getSaveFileName(self, "Save Data to:", "", "")
        if fname:
            try:
                self.backend.dump(fname)
            except Exception as e:
                self.warn(str(e))
            else:
                self.status("Dumped to file: " + fname)

    def load(self):
        fname = QFileDialog.getOpenFileName(self, "Open Data File:", "", "")
        if fname:
            try:
                self.backend = ModelInterface.load(fname)
            except Exception as e:
                self.warn(str(e))
            else:
                self.status("Loaded from file: " + fname)

    def noise_clicked(self):
        self.recording_noise = not self.recording_noise
        if self.recording_noise:
            self.noiseButton.setText('Stop Recording Noise')
            self.start_record()
        else:
            self.noiseButton.setText('Recording Background Noise')
            self.stop_record()
            signal = np.array(self.recordData, dtype=NPDtype)
            wavfile.write("bg.wav", Main.FS, signal)
            self.backend.init_noise(Main.FS, signal)

    def load_noise(self):
        fname = QFileDialog.getOpenFileName(self, "Open Data File:", "", "Wav File  (*.wav)")
        if fname:
            fs, signal = read_wav(fname)
            self.backend.init_noise(fs, signal)

    def load_avatar(self, dirname):
        self.avatars = {}
        for f in glob.glob(dirname + '/*.jpg'):
            name = os.path.basename(f).split('.')[0]
            print f, name
            self.avatars[name] = QPixmap(f)

    def get_avatar(self, username):
        p = self.avatars.get(str(username), None)
        if p:
            return p
        else:
            return self.defaultimage

    def printDebug(self):
        for name, feat in self.backend.features.iteritems():
            print name, len(feat)
        print "GMMs",
        print len(self.backend.gmmset.gmms)
Code example #3
class Main(QMainWindow):
    CONV_INTERVAL = 0.4
    CONV_DURATION = 1.5
    CONV_FILTER_DURATION = CONV_DURATION
    FS = 8000
    TEST_DURATION = 3

    def __init__(self, parent=None):
        QWidget.__init__(self, parent)
        uic.loadUi("edytor.ui", self)
        self.statusBar()

        self.timer = QTimer(self)
        self.timer.timeout.connect(self.timer_callback)

        self.noiseButton.clicked.connect(self.noise_clicked)
        self.recording_noise = False
        self.loadNoise.clicked.connect(self.load_noise)

        self.enrollRecord.clicked.connect(self.start_enroll_record)
        self.stopEnrollRecord.clicked.connect(self.stop_enroll_record)
        self.enrollFile.clicked.connect(self.enroll_file)
        self.enroll.clicked.connect(self.do_enroll)
        self.startTrain.clicked.connect(self.start_train)
        self.dumpBtn.clicked.connect(self.dump)
        self.loadBtn.clicked.connect(self.load)

        self.recoRecord.clicked.connect(self.start_reco_record)
        self.stopRecoRecord.clicked.connect(self.stop_reco_record)
        #        self.newReco.clicked.connect(self.new_reco)
        self.recoFile.clicked.connect(self.reco_file)
        self.recoInputFiles.clicked.connect(self.reco_files)

        #UI.init
        self.userdata = []
        self.loadUsers()
        self.Userchooser.currentIndexChanged.connect(self.showUserInfo)
        self.ClearInfo.clicked.connect(self.clearUserInfo)
        self.UpdateInfo.clicked.connect(self.updateUserInfo)
        self.UploadImage.clicked.connect(self.upload_avatar)
        #movie test
        self.movie = QMovie(u"image/recording.gif")
        self.movie.start()
        self.movie.stop()
        self.Animation.setMovie(self.movie)
        self.Animation_2.setMovie(self.movie)
        self.Animation_3.setMovie(self.movie)

        self.aladingpic = QPixmap(u"image/a_hello.png")
        self.Alading.setPixmap(self.aladingpic)
        self.Alading_conv.setPixmap(self.aladingpic)

        #default user image setting
        self.avatarname = "image/nouser.jpg"
        self.defaultimage = QPixmap(self.avatarname)
        self.Userimage.setPixmap(self.defaultimage)
        self.recoUserImage.setPixmap(self.defaultimage)
        self.convUserImage.setPixmap(self.defaultimage)
        self.load_avatar('avatar/')

        # Graph Window init
        self.graphwindow = GraphWindow()
        self.newname = ""
        self.lastname = ""
        self.Graph_button.clicked.connect(self.graphwindow.show)
        self.convRecord.clicked.connect(self.start_conv_record)
        self.convStop.clicked.connect(self.stop_conv)

        self.backend = ModelInterface()

        # debug
        QShortcut(QKeySequence("Ctrl+P"), self, self.printDebug)

        #init
        try:
            fs, signal = wavfile.read("bg.wav")
            self.backend.init_noise(fs, signal)
        except Exception:
            # the saved background-noise sample is optional; ignore a missing bg.wav
            pass

    ############ RECORD
    def start_record(self):
        self.pyaudio = pyaudio.PyAudio()
        self.status("Recording...")
        self.movie.start()
        self.Alading.setPixmap(QPixmap(u"image/a_thinking.png"))

        self.recordData = []
        self.stream = self.pyaudio.open(format=FORMAT,
                                        channels=1,
                                        rate=Main.FS,
                                        input=True,
                                        frames_per_buffer=1)
        self.stopped = False
        self.reco_th = RecorderThread(self)
        self.reco_th.start()

        self.timer.start(1000)
        self.record_time = 0
        self.update_all_timer()

    def add_record_data(self, i):
        self.recordData.append(i)
        return self.stopped

    def timer_callback(self):
        self.record_time += 1
        self.status("Recording..." + time_str(self.record_time))
        self.update_all_timer()

    def stop_record(self):
        self.movie.stop()
        self.stopped = True
        self.reco_th.wait()
        self.timer.stop()
        self.stream.stop_stream()
        self.stream.close()
        self.pyaudio.terminate()
        self.status("Record stopeed")

    ############## conversation
    def start_conv_record(self):
        self.conv_result_list = []
        self.start_record()
        self.conv_now_pos = 0
        self.conv_timer = QTimer(self)
        self.conv_timer.timeout.connect(self.do_conversation)
        self.conv_timer.start(Main.CONV_INTERVAL * 1000)
        #reset
        self.graphwindow.wid.reset()

    def stop_conv(self):
        self.stop_record()
        self.conv_timer.stop()

    def do_conversation(self):
        interval_len = int(Main.CONV_INTERVAL * Main.FS)
        segment_len = int(Main.CONV_DURATION * Main.FS)
        self.conv_now_pos += interval_len
        to_filter = self.recordData[max([self.conv_now_pos -
                                         segment_len, 0]):self.conv_now_pos]
        signal = np.array(to_filter, dtype=NPDtype)
        label = None
        try:
            signal = self.backend.filter(Main.FS, signal)
            if len(signal) > 50:
                label = self.backend.predict(Main.FS, signal, True)
        except Exception as e:
            print traceback.format_exc()
            print str(e)

        global last_label_to_show
        label_to_show = label
        if label and self.conv_result_list:
            last_label = self.conv_result_list[-1]
            if last_label and last_label != label:
                label_to_show = last_label_to_show
        self.conv_result_list.append(label)

        print label_to_show, "label to show"
        last_label_to_show = label_to_show

        #ADD FOR GRAPH
        if label_to_show is None:
            label_to_show = 'Nobody'
        if len(NAMELIST) and NAMELIST[-1] != label_to_show:
            NAMELIST.append(label_to_show)
        self.convUsername.setText(label_to_show)
        self.Alading_conv.setPixmap(QPixmap(u"image/a_result.png"))
        self.convUserImage.setPixmap(self.get_avatar(label_to_show))

    ###### RECOGNIZE
    def start_reco_record(self):
        self.Alading.setPixmap(QPixmap(u"image/a_hello"))
        self.recoRecordData = np.array((), dtype=NPDtype)
        self.start_record()

    def stop_reco_record(self):
        self.stop_record()
        signal = np.array(self.recordData, dtype=NPDtype)
        self.reco_remove_update(Main.FS, signal)

    def reco_do_predict(self, fs, signal):
        label = self.backend.predict(fs, signal)
        if not label:
            label = "Nobody"
        print label
        self.recoUsername.setText(label)
        self.Alading.setPixmap(QPixmap(u"image/a_result.png"))
        self.recoUserImage.setPixmap(self.get_avatar(label))

        # TODO To Delete
        write_wav('reco.wav', fs, signal)

    def reco_remove_update(self, fs, signal):
        new_signal = self.backend.filter(fs, signal)
        print "After removed: {0} -> {1}".format(len(signal), len(new_signal))
        self.recoRecordData = np.concatenate((self.recoRecordData, new_signal))
        real_len = float(len(
            self.recoRecordData)) / Main.FS / Main.TEST_DURATION * 100
        if real_len > 100:
            real_len = 100

        self.reco_do_predict(fs, self.recoRecordData)

    def reco_file(self):
        fname = QFileDialog.getOpenFileName(self, "Open Wav File", "",
                                            "Files (*.wav)")
        print 'reco_file'
        if not fname:
            return
        self.status(fname)
        fs, signal = wavfile.read(fname)
        self.reco_do_predict(fs, signal)

    def reco_files(self):
        fnames = QFileDialog.getOpenFileNames(self, "Select Wav Files", "",
                                              "Files (*.wav)")
        print 'reco_files'
        for f in fnames:
            fs, sig = wavfile.read(f)
            newsig = self.backend.filter(fs, sig)
            label = self.backend.predict(fs, newsig)
            print f, label

    ########## ENROLL
    def start_enroll_record(self):
        self.enrollWav = None
        self.enrollFileName.setText("")
        self.start_record()

    def enroll_file(self):
        fname = QFileDialog.getOpenFileName(self, "Open Wav File", "",
                                            "Files (*.wav)")
        if not fname:
            return
        self.status(fname)
        self.enrollFileName.setText(fname)
        fs, signal = wavfile.read(fname)
        signal = monophonic(signal)
        self.enrollWav = (fs, signal)

    def stop_enroll_record(self):
        self.stop_record()
        print self.recordData[:300]
        signal = np.array(self.recordData, dtype=NPDtype)
        self.enrollWav = (Main.FS, signal)

        # TODO To Delete
        write_wav('enroll.wav', *self.enrollWav)

    def do_enroll(self):
        name = self.Username.text().trimmed()
        if not name:
            self.warn("Please Input Your Name")
            return

#        self.addUserInfo()
        new_signal = self.backend.filter(*self.enrollWav)
        print "After removed: {0} -> {1}".format(len(self.enrollWav[1]),
                                                 len(new_signal))
        print "Enroll: {:.4f} seconds".format(float(len(new_signal)) / Main.FS)
        self.backend.enroll(name, Main.FS, new_signal)

    def start_train(self):
        self.status("Training...")
        self.backend.train()
        self.status("Training Done.")

    ####### UI related
    def getWidget(self, splash):
        t = QtCore.QElapsedTimer()
        t.start()
        while (t.elapsed() < 800):
            msg = QtCore.QString("times = ") + QtCore.QString.number(
                t.elapsed())
            splash.showMessage(msg)
            QtCore.QCoreApplication.processEvents()

    def upload_avatar(self):
        fname = QFileDialog.getOpenFileName(self, "Open JPG File", "",
                                            "File (*.jpg)")
        if not fname:
            return
        self.avatarname = fname
        self.Userimage.setPixmap(QPixmap(fname))

    def loadUsers(self):
        with open("avatar/metainfo.txt") as db:
            for line in db:
                tmp = line.split()
                self.userdata.append(tmp)
                self.Userchooser.addItem(tmp[0])

    def showUserInfo(self):
        for user in self.userdata:
            if self.userdata.index(
                    user) == self.Userchooser.currentIndex() - 1:
                self.Username.setText(user[0])
                self.Userage.setValue(int(user[1]))
                if user[2] == 'F':
                    self.Usersex.setCurrentIndex(1)
                else:
                    self.Usersex.setCurrentIndex(0)
                self.Userimage.setPixmap(self.get_avatar(user[0]))

    def updateUserInfo(self):
        userindex = self.Userchooser.currentIndex() - 1
        u = self.userdata[userindex]
        u[0] = unicode(self.Username.displayText())
        u[1] = self.Userage.value()
        if self.Usersex.currentIndex():
            u[2] = 'F'
        else:
            u[2] = 'M'
        with open("avatar/metainfo.txt", "w") as db:
            for user in self.userdata:
                for i in range(3):
                    db.write(str(user[i]) + " ")
                db.write("\n")

    def writeuserdata(self):
        with open("avatar/metainfo.txt", "w") as db:
            for user in self.userdata:
                for i in range(0, 4):
                    db.write(str(user[i]) + " ")
                db.write("\n")

    def clearUserInfo(self):
        self.Username.setText("")
        self.Userage.setValue(0)
        self.Usersex.setCurrentIndex(0)
        self.Userimage.setPixmap(self.defaultimage)

    def addUserInfo(self):
        for user in self.userdata:
            if user[0] == unicode(self.Username.displayText()):
                return
        newuser = []
        newuser.append(unicode(self.Username.displayText()))
        newuser.append(self.Userage.value())
        if self.Usersex.currentIndex():
            newuser.append('F')
        else:
            newuser.append('M')
        if self.avatarname:
            shutil.copy(self.avatarname, 'avatar/' + newuser[0] + '.jpg')
        self.userdata.append(newuser)
        self.writeuserdata()
        self.Userchooser.addItem(unicode(self.Username.displayText()))

    ############# UTILS
    def warn(self, s):
        QMessageBox.warning(self, "Warning", s)

    def status(self, s=""):
        self.statusBar().showMessage(s)

    def update_all_timer(self):
        s = time_str(self.record_time)
        self.enrollTime.setText(s)
        self.recoTime.setText(s)
        self.convTime.setText(s)

    def dump(self):
        fname = QFileDialog.getSaveFileName(self, "Save Data to:", "", "")
        if fname:
            try:
                self.backend.dump(fname)
            except Exception as e:
                self.warn(str(e))
            else:
                self.status("Dumped to file: " + fname)

    def load(self):
        fname = QFileDialog.getOpenFileName(self, "Open Data File:", "", "")
        if fname:
            try:
                self.backend = ModelInterface.load(fname)
            except Exception as e:
                self.warn(str(e))
            else:
                self.status("Loaded from file: " + fname)

    def noise_clicked(self):
        self.recording_noise = not self.recording_noise
        if self.recording_noise:
            self.noiseButton.setText('Stop Recording Noise')
            self.start_record()
        else:
            self.noiseButton.setText('Recording Background Noise')
            self.stop_record()
            signal = np.array(self.recordData, dtype=NPDtype)
            wavfile.write("bg.wav", Main.FS, signal)
            self.backend.init_noise(Main.FS, signal)

    def load_noise(self):
        fname = QFileDialog.getOpenFileName(self, "Open Data File:", "",
                                            "Wav File  (*.wav)")
        if fname:
            fs, signal = wavfile.read(fname)
            self.backend.init_noise(fs, signal)

    def load_avatar(self, dirname):
        self.avatars = {}
        for f in glob.glob(dirname + '/*.jpg'):
            name = os.path.basename(f).split('.')[0]
            print f, name
            self.avatars[name] = QPixmap(f)

    def get_avatar(self, username):
        p = self.avatars.get(str(username), None)
        if p:
            return p
        else:
            return self.defaultimage

    def printDebug(self):
        for name, feat in self.backend.features.iteritems():
            print name, len(feat)
        print "GMMs",
        print len(self.backend.gmmset.gmms)

Code example #4
File: main.py  Project: EZlzh/property-prediction
def train(directory: Text,
          model_name: Text,
          batch_size: int,
          learning_rate: float,
          epsilon: float,
          cuda: bool,
          train_with_test: bool,
          min_iteration: int,
          max_iteration: int,
          ndrop: Optional[float] = None,
          **kwargs) -> None:
    # filter out options that were not set on the command line
    kwargs = util.dict_filter(kwargs, lambda k, v: v is not None)

    data_folder = Path(directory)
    assert data_folder.is_dir(), 'Invalid data folder'

    dev = require_device(cuda)
    for fold in sorted(data_folder.iterdir()):
        log.info(f'Processing "{fold}"...')

        # model & optimizer
        model_type = models.select(model_name)  # see models/__init__.py
        model = ModelInterface(model_type, dev, **kwargs)
        optimizer = torch.optim.Adam(params=model.inst.parameters(),
                                     lr=learning_rate)

        # load the fold
        raw = [
            util.load_csv(fold / name)
            for name in ['train.csv', 'test.csv', 'dev.csv']
        ]

        # let the model parse these molecules
        data = []
        for i in range(len(raw)):
            buf = []
            for smiles, activity in raw[i].items():
                obj = model.process(smiles)
                buf.append(Item(obj, activity))
            data.append(buf)
        log.debug(f'atom_map: {model.atom_map}')

        test_batch, _test_label = util.separate_items(data[1])
        test_label = torch.tensor(_test_label)

        # training phase
        train_data = data[0] + data[1] if train_with_test else data[0]

        # set up to randomly drop negative samples
        # see util.RandomIterator for details
        drop_prob = ndrop if ndrop is not None else 0
        drop_fn = lambda x: drop_prob if x.activity == 0 else 0
        data_ptr = util.RandomIterator(
            train_data, drop_fn=drop_fn if ndrop is not None else None)

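        # require at least min_iteration convergence checks before allowing an early stop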
        countdown = min_iteration
        min_loss = 1e99  # minimum loss observed so far
        sum_loss, batch_cnt = 0.0, 0
        for _ in range(max_iteration):
            # generate batch
            batch, _label = util.separate_items(data_ptr.iterate(batch_size))
            label = torch.tensor(_label)

            # train a mini-batch
            batch_loss = train_step(model, optimizer, batch, label)
            sum_loss += batch_loss
            batch_cnt += 1
            # log.debug(f'{batch_loss}, {sum_loss}')

            # convergence test
            if data_ptr.is_cycled():
                loss = sum_loss / batch_cnt
                pred = model.predict(test_batch)
                log.debug(
                    f'{util.stat_string(_test_label, pred)}. loss={loss},min={min_loss}'
                )

                if countdown <= 0 and abs(min_loss - loss) < epsilon:
                    log.debug('Converged.')
                    break

                countdown -= 1
                min_loss = min(min_loss, loss)
                sum_loss, batch_cnt = 0.0, 0

        # model evaluation on `dev.csv`
        roc_auc, prc_auc = evaluate_model(model, data[2])
        log.info(f'ROC-AUC: {roc_auc}')
        log.info(f'PRC-AUC: {prc_auc}')
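The loop above calls train_step, which is not included in this listing. Below is a minimal sketch of one mini-batch update, assuming ModelInterface.predict returns raw class scores of shape (N, 2) and that label is a 1-D tensor of class indices; both are assumptions, not confirmed by the source.

# hypothetical sketch of train_step; the project's real implementation is not shown here
import torch
import torch.nn.functional as F

def train_step(model, optimizer, batch, label):
    optimizer.zero_grad()
    pred = model.predict(batch)          # assumed raw class scores, shape (N, 2)
    loss = F.cross_entropy(pred, label)  # label: 1-D tensor of class indices
    loss.backward()
    optimizer.step()
    return loss.item()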
Code example #5
File: main.py  Project: EZlzh/property-prediction
def evaluate_model(model: ModelInterface,
                   data: List[Item]) -> Tuple[float, float]:
    batch, label = util.separate_items(data)
    pred = model.predict(batch)
    log.debug(f'final: {util.stat_string(label, pred)}')
    return util.evaluate_auc(label, pred)
Code example #6
            f"Hello {name}. Please input your voice {args.num_samples} times")
        with tempfile.TemporaryDirectory() as tempdir:
            i = 1
            while i <= args.num_samples:
                with sr.Microphone() as source:
                    audio = r.listen(source)
                # Generate random filename
                filename = os.path.join(
                    tempdir, name + "_" + str(uuid.uuid1()) + ".wav")
                with open(filename, "wb") as file:
                    file.write(audio.get_wav_data(convert_rate=16000))
                # enroll a file
                fs, signal = read_wav(filename)
                model.enroll(name, fs, signal)
                logger.info("wav file %s has been enrolled" % (filename))
                i += 1

        model.train()
        model.dump(args.model_path)
    else:
        model = ModelInterface.load(args.model_path)
        print("Please input your voice: ")
        with tempfile.TemporaryDirectory() as tempdir:
            with sr.Microphone() as source:
                audio = r.listen(source)
            filename = os.path.join(tempdir, str(uuid.uuid1()) + ".wav")
            with open(filename, "wb") as file:
                file.write(audio.get_wav_data(convert_rate=16000))
            fs, signal = read_wav(filename)
            pred, _ = model.predict(fs, signal)
            print(f"Your name is {pred}!")