def expand_data_generator_by_sampling_rect(X, y, config):
    """Expand every segment of ``X`` into the rectangles sampled from it.

    Each segment is passed to ``Sample_rectangle_from_spectrogram``. Every
    rectangle it returns is appended to the output data. Each rectangle gets
    the label of the segment it came from (``y[i]`` for the i-th segment of
    ``X``).

    Args:
        X: iterable of segments (any iterable; no ``len`` required).
        y: indexable labels aligned positionally with ``X``.
        config: forwarded unchanged to ``Sample_rectangle_from_spectrogram``.

    Returns:
        ``(new_data, new_label)``: two flat lists of equal length.
    """
    new_data = []
    new_label = []
    for index, segment in enumerate(X):
        expanded_iq = Sample_rectangle_from_spectrogram(iq_mat=segment,
                                                        config=config)
        new_data.extend(expanded_iq)
        # Repeat the parent segment's label once per sampled rectangle.
        new_label.extend([y[index]] * len(expanded_iq))

    return new_data, new_label
# Example no. 2
# 0
def test_model(model, sub_path, SRC_DIR, config, BEST_RESULT_DIR):
    """Predict on the test set and write a ``segment_id``/``prediction`` CSV.

    The test data is loaded through ``DataSetParser`` and prepared the same way
    for every SNR subset:

    * with rectangle augmentation, it is expanded via
      ``expand_test_by_sampling_rect``;
    * for sequential experiments ('LSTM'/'tcn' in ``config.exp_name``), axes
      1 and 2 are swapped;
    * otherwise, a trailing axis is added.

    One score per selected test segment is then written to ``sub_path``.

    Args:
        model: object exposing ``predict(x, batch_size=...)`` returning an
            array-like (Keras-style; assumed from usage).
        sub_path: destination path of the submission CSV.
        SRC_DIR: directory to ``chdir`` into before loading data. This is a
            process-wide side effect.
        config: experiment configuration. Reads ``with_rect_augmentation``,
            ``with_preprocess_rect_augmentation``, ``exp_name``, ``snr_type``
            ('all', 'low', anything else means high SNR), ``learn_background``
            and ``background_implicit_inference``.
            ``rect_augment_gap_doppler_burst_from_edge`` may be temporarily
            decreased and is always restored.
        BEST_RESULT_DIR: unused; kept for interface compatibility.
    """
    os.chdir(SRC_DIR)
    test_dataloader = DataSetParser(stable_mode=False, read_test_only=True, config=config)
    X_test = test_dataloader.get_dataset_test_allsnr()
    use_rect = config.with_rect_augmentation or config.with_preprocess_rect_augmentation
    if use_rect:
        X_augmented_test = expand_test_by_sampling_rect(data=X_test, config=config)
    # swap axes for sequential data
    elif re.search('LSTM|tcn', config.exp_name, re.IGNORECASE):
        X_test = X_test.swapaxes(1, 2)
    else:
        X_test = np.expand_dims(X_test, axis=-1)

    metadata = test_dataloader.test_data[1]
    segment_list = []
    result_list_temp = []

    if config.snr_type == 'all':
        segment_list = metadata['segment_id']
        if use_rect:
            for test_index, sampled_list_x in enumerate(X_augmented_test):
                result_list_temp.extend(_predict_on_sampled_rectangles(
                    model, sampled_list_x, X_test[test_index], test_index, config))
        else:
            result_list_temp = model.predict(X_test).flatten().tolist()
    else:
        wanted_snr = 'LowSNR' if config.snr_type == 'low' else 'HighSNR'
        samples = X_augmented_test if use_rect else X_test
        selected_samples = []
        rows = zip(samples, metadata['snr_type'], metadata['segment_id'])
        for test_index, (sample, snr_type, segment_id) in enumerate(rows):
            if snr_type != wanted_snr:
                continue
            # Record the id in every mode so ids and predictions stay aligned.
            segment_list.append(segment_id)
            if use_rect:
                result_list_temp.extend(_predict_on_sampled_rectangles(
                    model, sample, X_test[test_index], test_index, config))
            else:
                selected_samples.append(sample)
        if not use_rect and selected_samples:
            # X_test already has its final per-sample shape (see above), so no
            # extra axis is added here, matching the 'all' path.
            x = np.array(selected_samples)
            result_list_temp = model.predict(x, batch_size=x.shape[0]).flatten().tolist()

    if config.learn_background:
        result_list = _collapse_background_predictions(
            result_list_temp, config.background_implicit_inference)
    else:
        result_list = result_list_temp

    submission = pd.DataFrame()
    submission['segment_id'] = segment_list
    submission['prediction'] = result_list
    submission['prediction'] = submission['prediction'].astype('float')
    # Save submission
    submission.to_csv(sub_path, index=False)


def _predict_on_sampled_rectangles(model, sampled_rects, raw_segment, test_index, config):
    """Average the model output over the rectangles sampled from one test segment.

    If ``sampled_rects`` is empty, ``config.rect_augment_gap_doppler_burst_from_edge``
    is decreased by one. Sampling is then retried on ``raw_segment`` until at
    least one rectangle is obtained. The original gap is restored afterwards,
    even if sampling raises.

    Returns:
        A list with one averaged value per model output unit: a single value
        for a one-output model, three values for the learn_background path.
    """
    prev_gap_doppler_burst = config.rect_augment_gap_doppler_burst_from_edge
    try:
        # NOTE(review): this loop never ends if sampling keeps failing; confirm
        # that shrinking the gap always eventually yields a rectangle.
        while not sampled_rects:
            # We didn't manage to sample a rectangle with the current doppler burst gap.
            config.rect_augment_gap_doppler_burst_from_edge -= 1
            print('Reducing the doppler burst gap for test_index sample {} '.format(test_index))
            sampled_rects = Sample_rectangle_from_spectrogram(raw_segment, config)
    finally:
        config.rect_augment_gap_doppler_burst_from_edge = prev_gap_doppler_burst

    print('Sampled {} rectangles for test_index sample {} '.format(len(sampled_rects), test_index))
    x = np.expand_dims(np.array(sampled_rects), axis=-1)
    predictions = np.asarray(model.predict(x, batch_size=x.shape[0]), dtype=np.float64)
    # Average each output unit separately across rectangles. Flattening first
    # would blend the three learn_background outputs into a single number.
    return predictions.reshape(x.shape[0], -1).mean(axis=0).tolist()


def _collapse_background_predictions(flat_predictions, implicit_inference):
    """Reduce 3-column model outputs to one score per segment.

    ``flat_predictions`` is reshaped into rows of three values (presumably
    [class 0, class 1, background] — confirm against the model head).

    * With ``implicit_inference``, the third column is added to the second.
    * Otherwise, the first two columns are divided by ``1 - third``.

    Pairs summing to more than 1 are rescaled to sum to 1. Each row then becomes
    ``p0`` if ``p0 > p1``, else ``1 - p1``.
    """
    rows = np.array(flat_predictions).reshape((-1, 3))
    if implicit_inference:
        y_pred_2 = np.array([[y[0], y[1] + y[2]] for y in rows])
    else:
        y_pred_2 = np.array([[y[0] / (1 - y[2]), y[1] / (1 - y[2])] for y in rows])
    # numeric correction
    y_pred_2 = np.array([y / (y[0] + y[1]) if y[0] + y[1] > 1 else y for y in y_pred_2])
    return [y[0] if y[0] > y[1] else 1 - y[1] for y in y_pred_2]