コード例 #1
0
ファイル: ensemble_train.py プロジェクト: kyangt/dcase
    def load_data_engin(self, train_addr, valid_addr):
        """Create the training and validation ``Data_Engin`` loaders.

        Both loaders receive the same feature-extraction and batching
        settings taken from ``self``; they differ only in ``address``.

        Args:
            train_addr: address handed to the training ``Data_Engin``.
            valid_addr: address handed to the validation ``Data_Engin``.

        Side effects:
            Calls ``torch.cuda.empty_cache()`` first, then stores
            ``self.train_addr`` / ``self.valid_addr`` and the two new
            loaders on ``self.train`` / ``self.valid``.
        """
        torch.cuda.empty_cache()
        self.train_addr, self.valid_addr = train_addr, valid_addr

        # Settings common to both splits, gathered once so the two
        # constructor calls cannot drift apart.
        shared_settings = {
            'method': self.method,
            'mono': self.mono,
            'spectra_type': self.spectra_type,
            'device': self.device,
            'batch_size': self.batch_size,
            'fs': self.fs,
            'n_fft': self.n_fft,
            'n_mels': self.n_mels,
            'win_len': self.win_len,
            'hop_len': self.hop_len,
        }

        self.train = Data_Engin(address=self.train_addr, **shared_settings)
        self.valid = Data_Engin(address=self.valid_addr, **shared_settings)
コード例 #2
0
    # Per-model results, keyed by the model's ``idx`` converted to str.
    acc = {}
    target = {}
    prediction = {}

    # STFT framing settings that the training and evaluation loaders set
    # explicitly. They are forwarded only when a model config records them;
    # configs without these keys keep Data_Engin's own defaults (previous
    # behaviour).
    optional_framing_keys = ('win_len', 'hop_len')

    for model in models:
        torch.cuda.empty_cache()

        idx = str(model['idx'])
        print('idx:', idx)

        # Without this, a model trained with a non-default window/hop would
        # be evaluated on differently framed features.
        framing = {key: model[key]
                   for key in optional_framing_keys if key in model}

        test = Data_Engin(
            method=model['method'],
            mono=model['mono'],
            address='./dataset/dcase/evaluation_setup/modify_evaluate.csv',
            spectra_type=model['spectra_type'],
            device=device,
            batch_size=model['batch_size'],
            fs=model['fs'],
            n_fft=model['n_fft'],
            n_mels=model['n_mels'],
            **framing)

        network = load_model(model['network_address'],
                             network_type=model['network'],
                             no_class=no_class)

        acc[idx], target[idx], prediction[idx] = infer(network=network,
                                                       valid_data_engine=test)

    # Show the first ten predictions of every model for a quick eyeball check.
    for item in prediction.values():
        print(item[:10])
コード例 #3
0
    # Ten acoustic-scene labels, in alphabetical order. Earlier revisions
    # used a seven-class event label set and an unsorted ordering of these
    # same ten scenes.
    # NOTE(review): each name's position presumably has to match the label
    # index used at training time -- confirm before reordering.
    classes = ['airport',
               'bus',
               'metro',
               'metro_station',
               'park',
               'public_square',
               'shopping_mall',
               'street_pedestrian',
               'street_traffic',
               'tram']
    no_class = len(classes)

    # Loader for the evaluation split. The feature settings are hard-coded;
    # presumably they must match those the network was trained with -- confirm.
    # NOTE(review): hop_len=204 next to win_len=2048 looks unusual; verify it
    # is intentional rather than a truncated value.
    test = Data_Engin(method='post', mono='diff',
                      address='./dataset/dcase/evaluation_setup/modify_evaluate.csv',
                      spectra_type='mel_spectrum', device=device,
                      batch_size=16, fs=48000,
                      n_fft=2048, n_mels=128,
                      win_len=2048, hop_len=204)

    # Single VGG_M2 backbone. An ENSEMBLE of VGG_M and DCASE_PAST was
    # previously tried at this point.
    network = VGG_M2(no_class=no_class)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")