Example #1
0
    def __init__(self):
        """Load dataset-generation settings from ``Config.json``.

        The file is read from the directory containing this module.  It
        provides the shot lists (``shots``), the diagnosis settings
        (``diagnosis.tags``, ``diagnosis.start_time`` in ms,
        ``diagnosis.sample_rate`` in kHz), the ``normalized`` flag and the
        output ``directory``.  The output directory is created if needed and
        the configuration is copied into it.  When ``normalized`` is true,
        per-tag normalisation parameters are fetched through ``Query``.
        """
        config_path = os.path.join(os.path.dirname(__file__), 'Config.json')
        with open(config_path, 'r', encoding='utf-8') as json_file:
            config = json.load(json_file)

        shots = config['shots']
        self._shots = {
            'train': shots['train'],
            'test': shots['test'],
            # Use a dedicated 'val' list when the config provides one;
            # otherwise fall back to the test shots, as this key always did.
            'val': shots.get('val', shots['test']),
        }

        diagnosis = config['diagnosis']
        self._tags = diagnosis['tags']
        # Start time of the signal window (ms).
        self._start_time = diagnosis['start_time']
        # Signal frequency after resampling (kHz).
        self._sample_rate = diagnosis['sample_rate']
        self._normalized = config['normalized']
        self._directory = config['directory']
        os.makedirs(self._directory, exist_ok=True)

        # Copy the configuration next to the generated data so the output
        # records the settings it was produced with.
        with open(os.path.join(self._directory, 'Config.json'), 'w',
                  encoding='utf-8') as json_file:
            json.dump(config, json_file, indent=4)

        if self._normalized:
            ddb = Query()
            self._normalize_param = ddb.get_normalize_parm(self._tags)
 def __init__(self, normalized=False):
     """Configure frame generation from ``DataSetConfig.ini`` and pick shots.

     Reads the ``[Diagnosis]`` section (``tags`` parsed as a Python literal;
     ``sample_rate``, ``frame_size`` and ``step`` as ints) and the output
     directory ``[path] npy``, creating that directory when it is missing.
     The shot list is every valid non-disruptive shot followed by the valid
     disruptive shots with CqTime >= 0.05 and IpFlat >= 110.

     :param normalized: when true, per-tag normalisation parameters are
         fetched through ``Query`` into ``self._normalize_param``.
     """
     parser = configparser.ConfigParser()
     parser.read(os.path.join(os.path.dirname(__file__), 'DataSetConfig.ini'))
     diagnosis = parser['Diagnosis']
     self._tags = ast.literal_eval(diagnosis['tags'])
     self._sample_rate = int(diagnosis['sample_rate'])
     self._frame_size = int(diagnosis['frame_size'])
     self._step = int(diagnosis['step'])
     self._npy_path = parser['path']['npy']
     if not os.path.exists(self._npy_path):
         os.makedirs(self._npy_path)

     ddb = Query()
     self._normalized = normalized
     if normalized:
         self._normalize_param = ddb.get_normalize_parm(self._tags)

     undisrupted_query = {'IsValidShot': True, 'IsDisrupt': False}
     disrupted_query = {
         'IsValidShot': True,
         'IsDisrupt': True,
         'CqTime': {"$gte": 0.05},
         'IpFlat': {'$gte': 110},
     }
     self._shots = ddb.query(undisrupted_query)
     self._shots += ddb.query(disrupted_query)
    def run(self):
        """Cut every shot in ``self._shots`` into frames saved as .npy files.

        For each shot, the signals returned by ``read_many`` for
        ``self._tags`` are cropped to the window from 0.05 s to the end time
        (``CqTime`` for disruptive shots, ``RampDownTime`` otherwise),
        optionally min-max normalised, and resampled to
        ``int((t1 * 1000 - 50) * self._sample_rate)`` points.  Frames of
        ``self._frame_size`` points, taken every ``self._step`` points, are
        written to ``<npy_path>/<shot>/x_<k>.npy``, with the matching slice of
        the label sequence from ``y(...)`` in ``y_<k>.npy``.  If more than half
        a frame is left over after the last full step, one extra frame aligned
        to the end of the signal is saved as well.

        Errors for a shot are printed with a traceback and the shot is skipped.
        """
        data_reader = Reader()
        ddb = Query()
        frame_size = self._frame_size
        step = self._step

        def save_frame(path, number, start, digs, labels):
            """Save the frame beginning at ``start`` as x_/y_<number>.npy."""
            stop = start + frame_size
            np.save(os.path.join(path, 'x_{}.npy'.format(number)),
                    digs[:, start:stop])
            np.save(os.path.join(path, 'y_{}.npy'.format(number)),
                    labels[start:stop])

        for shot in self._shots:
            print(shot)
            try:
                tags = ddb.tag(shot)
                # End of the usable window (s).
                t1 = tags['CqTime'] if tags['IsDisrupt'] else tags['RampDownTime']
                new_dig_length = int((t1 * 1000 - 50) * self._sample_rate)
                data = data_reader.read_many(shot, self._tags)
                digs = []
                for tag, (dig, times) in data.items():
                    dig = dig[(0.05 <= times) & (times <= t1)]
                    if self._normalized:
                        param = self._normalize_param[tag]
                        dig = (dig - param['min']) / (param['max'] - param['min'])
                    digs.append(signal.resample(dig, new_dig_length))
                digs = np.array(digs)
                y_ = y(new_dig_length, self._sample_rate, tags['IsDisrupt'])

                path = os.path.join(self._npy_path, '{}'.format(shot))
                os.makedirs(path, exist_ok=True)

                index = 0
                while index + frame_size <= new_dig_length:
                    save_frame(path, index // step, index, digs, y_)
                    index += step
                # Keep the leftover tail as one more frame aligned to the end
                # of the signal when over half a frame would otherwise be
                # dropped.  The length check prevents a negative slice start
                # (a short, misaligned frame) for shots shorter than a frame.
                if (new_dig_length >= frame_size
                        and new_dig_length - index > frame_size / 2):
                    save_frame(path, index // step,
                               new_dig_length - frame_size, digs, y_)
            except Exception as e:
                print(e)
                traceback.print_exc()
Example #4
0
    def get(self):
        """Write the shots available as npy data to ``log/ShotsInDataset.txt``.

        Only shots that have a directory under the npy path are listed, in
        ascending order, one per line as ``<shot> <used> <kind>``:

        * first, disruptive shots (valid, CqTime >= 0.15, IpFlat >= 110) with
          kind ``d``; ``used`` is 1 if the shot appears in
          ``log/ShotsUsed4Training.txt`` and 0 otherwise;
        * then non-disruptive shots (valid, IpFlat >= 110) with kind ``u``
          and ``used`` = 0.

        :return: None
        """
        # Shots used for training/testing.  A set makes each membership test
        # O(1); blank lines (e.g. a trailing newline) are ignored.
        with open(os.path.join('log', 'ShotsUsed4Training.txt'), 'r') as f:
            train_test_shots = {int(line) for line in f if line.strip()}

        # Sibling methods such as save_full_npy read ``self._npy_path``; fall
        # back to the ``npy_path`` attribute this method originally used.
        npy_path = getattr(self, '_npy_path', None)
        if npy_path is None:
            npy_path = self.npy_path

        ddb = Query()

        def shots_on_disk(query):
            """Return the sorted shots matching ``query`` that have npy data."""
            return sorted(
                shot for shot in ddb.query(query)
                if os.path.exists(os.path.join(npy_path, '{}'.format(shot))))

        disrupt_shots = shots_on_disk({
            'IsValidShot': True,
            'IsDisrupt': True,
            'CqTime': {"$gte": 0.15},
            'IpFlat': {'$gte': 110},
        })
        undisrupt_shots = shots_on_disk({
            'IsValidShot': True,
            'IsDisrupt': False,
            'IpFlat': {'$gte': 110},
        })

        # Query first and write afterwards, so a failed query does not leave a
        # truncated listing behind.
        with open(os.path.join('log', 'ShotsInDataset.txt'), 'w') as f:
            for shot in disrupt_shots:
                used = 1 if shot in train_test_shots else 0
                print('{} {} d'.format(shot, used), file=f)
            for shot in undisrupt_shots:
                print('{} 0 u'.format(shot), file=f)
    def save_full_npy(self, shots):
        """Save all frames of each shot, stacked, under ``<npy_path>/full``.

        For every shot, writes ``x_<shot>.npy`` (the frames stacked along a
        new first axis) and ``y_<shot>.npy`` (the label of the last sample of
        each frame).  Signals are cropped from 0.05 s to the end time
        (``CqTime`` for disruptive shots, ``RampDownTime`` otherwise),
        optionally min-max normalised and resampled before framing.

        Errors for a shot are printed with a traceback and the shot is skipped.
        """
        reader = Reader()
        query = Query()
        out_dir = os.path.join(self._npy_path, 'full')
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        print('####Start generate val DataSet####')
        for shot in shots:
            try:
                print(shot)

                info = query.tag(shot)
                end_time = info['CqTime'] if info['IsDisrupt'] else info['RampDownTime']
                length = int((end_time * 1000 - 50) * self._sample_rate)
                signals = reader.read_many(shot, self._tags)
                resampled = []
                for tag, (values, stamps) in signals.items():
                    values = values[(0.05 <= stamps) & (stamps <= end_time)]
                    if self._normalized:
                        params = self._normalize_param[tag]
                        values = (values - params['min']) / (params['max'] - params['min'])
                    resampled.append(signal.resample(values, length))
                resampled = np.array(resampled)
                label_seq = y(length, self._sample_rate, info['IsDisrupt'])

                frames = []
                frame_labels = []
                start = 0
                while start + self._frame_size <= length:
                    stop = start + self._frame_size
                    frames.append(resampled[:, start:stop])
                    # Each frame is labelled by its final sample.
                    frame_labels.append(label_seq[start:stop][-1])
                    start += self._step

                np.save(os.path.join(out_dir, 'x_{}.npy'.format(shot)),
                        np.array(frames))
                np.save(os.path.join(out_dir, 'y_{}.npy'.format(shot)),
                        np.array(frame_labels))
            except Exception as e:
                print(e)
                traceback.print_exc()
Example #6
0
    def generate(self):
        """Write each shot's resampled diagnostics to an HDF5 file.

        For every category in ``self._shots``, each shot is written to
        ``<directory>/<category>/<shot>.hdf5`` as a dataset named
        ``diagnosis`` holding one resampled row per signal returned by
        ``read_many`` for ``self._tags``.  The shot's tag dictionary from
        ``Query.tag`` is stored as dataset attributes.  Signals are cropped
        from ``self._start_time`` (ms) to the end time (``CqTime`` for
        disruptive shots, ``RampDownTime`` otherwise), optionally min-max
        normalised, and resampled to
        ``int((t1 * 1000 - start_time) * sample_rate)`` points.

        Errors for a shot are printed with a traceback and the shot is skipped.
        """
        data_reader = Reader()
        ddb = Query()
        start_s = self._start_time / 1000  # window start in seconds
        for category, shots in self._shots.items():
            category_dir = os.path.join(self._directory, category)
            os.makedirs(category_dir, exist_ok=True)
            for shot in shots:
                print(shot)
                try:
                    tags = ddb.tag(shot)
                    t1 = tags['CqTime'] if tags['IsDisrupt'] else tags['RampDownTime']
                    new_dig_length = int(
                        (t1 * 1000 - self._start_time) * self._sample_rate)
                    data = data_reader.read_many(shot, self._tags)
                    digs = []
                    for tag, (dig, times) in data.items():
                        dig = dig[(start_s <= times) & (times <= t1)]
                        # Min-max normalisation.
                        if self._normalized:
                            param = self._normalize_param[tag]
                            dig = (dig - param['min']) / (param['max'] - param['min'])
                        # Resample to the target length.
                        digs.append(signal.resample(dig, new_dig_length))
                    digs = np.array(digs)

                    # Explicit mode 'w' creates (or overwrites) the file.  Since
                    # h5py 3.0 the default mode is read-only 'r', so a bare
                    # h5py.File(path) cannot create new output.  The with
                    # block closes the file even if writing fails.
                    out_path = os.path.join(category_dir, '{}.hdf5'.format(shot))
                    with h5py.File(out_path, 'w') as f:
                        dataset = f.create_dataset('diagnosis', data=digs)
                        for key, value in tags.items():
                            dataset.attrs.create(key, value)
                except Exception as e:
                    print(e)
                    traceback.print_exc()
Example #7
0
    def plot_much(self,
                  Taglist=None,
                  Shotlist=None,
                  Savepath=None,
                  ShowDownTime=False,
                  ShowIpFlat=False,
                  xline=None,
                  yline=None):
        """Plot every tag of every shot and save one PNG per (tag, shot).

        Images are written to ``<root>/<tag[1:]>/<shot>.png``.  ``root`` is
        ``Savepath``, which must already exist, or ``<cwd>/plot``, which is
        created when needed.  The first character of each tag is dropped to
        form the directory name.  Shots whose data cannot be read or plotted
        are reported as "No data" and skipped.

        :param Taglist: iterable of tag names to plot (required).
        :param Shotlist: iterable of shot numbers to plot (required).
        :param Savepath: existing output directory; defaults to ``./plot``.
        :param ShowDownTime: for valid shots, draw a red vertical line at
            ``CqTime`` (disruptive) or ``RampDownTime`` (otherwise).
        :param ShowIpFlat: on the ip tag of valid shots, draw a black
            horizontal line at ``IpFlat``.
        :param xline: optional number; draw a vertical line at this x.
        :param yline: optional number; draw a horizontal line at this y.
        :raises TypeError: if ``Taglist`` or ``Shotlist`` is None.
        :raises ValueError: if ``Savepath`` does not exist or ``xline`` /
            ``yline`` is not a number.
        """
        if Taglist is None or Shotlist is None:
            raise TypeError('Taglist and Shotlist must be provided')
        # Materialise once: both lists are iterated more than once below.
        Taglist = list(Taglist)
        Shotlist = list(Shotlist)

        # Validate the reference lines up front.  Inside the per-shot try
        # block below, a bad value would be caught and misreported as
        # "No data".  ``is not None`` also lets a line at 0 be drawn.
        for name, value in (('xline', xline), ('yline', yline)):
            if value is not None and not isNum(value):
                raise ValueError('{} needs to be number '.format(name))

        if Savepath:
            root_path = Savepath
            if not os.path.exists(root_path):
                raise ValueError(
                    'No such saving path, you need to create one! ')
        else:
            root_path = os.path.join(os.getcwd(), "plot")
            print(root_path)
            os.makedirs(root_path, exist_ok=True)

        tag_dirs = {}
        for tag in Taglist:
            tag_dirs[tag] = os.path.join(root_path, tag[1:])
            os.makedirs(tag_dirs[tag], exist_ok=True)

        reader = Reader(root_path=self.hdf5path)
        db = Query()
        for tag in Taglist:
            tag_name = tag[1:]
            file_path = tag_dirs[tag]
            for n, shot in enumerate(Shotlist, start=1):
                print("Shot:{}".format(shot) + " Tag:{}  ".format(tag_name) +
                      "No.{}".format(n))
                try:
                    shot_info = db.tag(int(shot))
                    data = reader.read_one(int(shot), tag)
                    plt.figure(str(shot) + tag_name)
                    plt.plot(data[1], data[0], 'g')
                    if ShowDownTime and shot_info["IsValidShot"]:
                        key = "CqTime" if shot_info["IsDisrupt"] else "RampDownTime"
                        plt.axvline(round(shot_info[key], 3), c='r')
                    if (ShowIpFlat and tag == r"\ip"
                            and shot_info["IsValidShot"]):
                        plt.axhline(round(shot_info["IpFlat"], 3), c='k')
                    if xline is not None:
                        plt.axvline(round(xline, 3))
                    if yline is not None:
                        plt.axhline(round(yline, 3))
                    plt.savefig(os.path.join(file_path, "{}.png".format(shot)))
                except Exception:
                    print("Shot:{}".format(shot) +
                          " Tag:{}  ".format(tag_name) + "No data")
                finally:
                    plt.close()
Example #8
0
import os
import numpy as np
from DDB.Service import Query

# Shot-database query filters.
db = Query()

# Valid shots with IpFlat > 100 whose CqTime or RampDownTime exceeds 0.2.
valid_query = {
    'IsValidShot': True,
    'IpFlat': {'$gt': 100},
    "$or": [
        {'CqTime': {'$gt': 0.2}},
        {'RampDownTime': {'$gt': 0.2}},
    ],
}

# Valid disruptive shots with IpFlat > 100 and CqTime > 0.2.
break_query = {
    'IsValidShot': True,
    'IsDisrupt': True,
    'IpFlat': {'$gt': 100},
    'CqTime': {'$gt': 0.2},
}
Example #9
0
import DDB
from pymongo import MongoClient
from DDB.Service import Query
from DDB.Data import Reader
import random
from scipy import signal

# Shot selection: all valid non-disruptive shots, then valid disruptive
# shots with CqTime >= 0.05 and IpFlat >= 110.
ddb = Query()
shots = ddb.query({'IsValidShot': True, 'IsDisrupt': False})
my_query = {
    'IsValidShot': True,
    'IsDisrupt': True,
    'CqTime': {"$gte": 0.05},
    'IpFlat': {'$gte': 110},
}
shots += ddb.query(my_query)

# Connect to the output MongoDB configured for DDB.  The collection name is
# the configured one plus the suffix '归一化参数' ("normalisation parameters").
config = DDB.get_config()
database = config['output']
client = MongoClient(host=database['host'], port=int(database['port']))
db = client.get_database(database['database'])
param = db.get_collection(database['collection'] + '归一化参数')

# tags = [r'\Bt', r'\Ihfp', r'\Ivfp', r'\MA_POL_CA01T', r'\MA_POL_CA02T', r'\MA_POL_CA03T', r'\MA_POL_CA05T', r'\MA_POL_CA06T', r'\MA_POL_CA07T', r'\MA_POL_CA19T', r'\MA_POL_CA20T', r'\MA_POL_CA21T', r'\MA_POL_CA22T', r'\MA_POL_CA23T', r'\MA_POL_CA24T', r'\axuv_ca_01', r'\ip', r'\sxr_cb_024', r'\sxr_cc_049', r'\vs_c3_aa001', r'\vs_ha_aa001']
# reader = Reader(root_path='/nas/hdf5_new')
# result = dict()
# for tag in tags:
Example #10
0
    def load(self):
        """Load saved npy frames into balanced tf.data train/test datasets.

        Up to ``self._shots_len`` valid disruptive shots (CqTime >= 0.15,
        IpFlat >= 110) with a directory under ``self._npy_path`` are used.
        Their numbers are written to ``log/ShotsUsed4Training.txt``; an
        existing file is first renamed with a timestamp suffix.  Each
        ``x_<k>.npy`` frame is paired with ``y_<k>.npy`` and classified by the
        label of its last sample (> 0: disruptive, otherwise non-disruptive).

        Balancing: disruptive examples are repeated twice and up to three
        times as many non-disruptive examples are sampled (und/dis = 6/4).
        The mix is shuffled once, in a fixed order, and split into
        ``self._train_per`` / ``self._test_per`` fractions of its real size.

        :return: (train_dataset, test_dataset)
        """
        examples_und, labels_und = [], []
        examples_dis, labels_dis = [], []
        ddb = Query()
        shots = []

        # Non-disruptive shots are deliberately not used to choose training
        # shots; only the disruptive query below selects them.
        my_query = {
            'IsValidShot': True,
            'IsDisrupt': True,
            'CqTime': {"$gte": 0.15},
            'IpFlat': {'$gte': 110},
        }
        for shot in ddb.query(my_query):
            if os.path.exists(os.path.join(self._npy_path, '{}'.format(shot))):
                shots.append(shot)
                if len(shots) >= self._shots_len:
                    break

        os.makedirs('log', exist_ok=True)
        used_path = os.path.join('log', 'ShotsUsed4Training.txt')
        if os.path.exists(used_path):
            backup = os.path.join(
                'log', 'ShotsUsed4Training_{}.txt'.format(
                    time.strftime('%Y%m%d_%H%M%S', time.localtime())))
            os.rename(used_path, backup)
        with open(used_path, 'w') as f:
            for shot in shots:
                print(shot, file=f)

        for shot in shots:
            shot_dir = os.path.join(self._npy_path, '{}'.format(shot))
            # Frames are x_<k>.npy with labels in y_<k>.npy; sorting gives a
            # reproducible load order.
            for name in sorted(n for n in os.listdir(shot_dir)
                               if n.startswith('x_')):
                x = np.load(os.path.join(shot_dir, name))
                frame_labels = np.load(os.path.join(shot_dir, 'y' + name[1:]))
                last = frame_labels[-1]
                if last > 0:
                    examples_dis.append(x)
                    labels_dis.append(last)
                else:
                    examples_und.append(x)
                    labels_und.append(last)
        len_und = len(labels_und)
        len_dis = len(labels_dis)
        print('Length un_disruption: ', len_und, '\nLength disruption: ',
              len_dis)

        # Balancing strategy: repeat disruptive examples twice and randomly
        # sample up to 3x as many non-disruptive examples (und/dis = 6/4).
        # (An earlier strategy, repeating disruptive examples to match all
        # non-disruptive ones, is no longer used.)
        #
        # reshuffle_each_iteration=False keeps each shuffle order fixed across
        # iterations.  Otherwise every pass over train_dataset / test_dataset
        # draws a new permutation before take()/skip(), so the two sets
        # overlap (test-set leakage).
        dataset_und = tf.data.Dataset.from_tensor_slices(
            (examples_und, labels_und))
        dataset_dis = tf.data.Dataset.from_tensor_slices(
            (examples_dis, labels_dis))
        n_und = min(len_und, 3 * len_dis)
        dataset_und = dataset_und.shuffle(
            buffer_size=max(len_und, 1),
            reshuffle_each_iteration=False).take(n_und)
        dataset_dis = dataset_dis.repeat(2)
        # Real size of the mix; it is less than 5 * len_dis when there are
        # fewer than 3 * len_dis non-disruptive examples.
        total = n_und + 2 * len_dis
        dataset = dataset_und.concatenate(dataset_dis).shuffle(
            max(total, 1), reshuffle_each_iteration=False)
        n_train = int(total * self._train_per)
        n_test = int(total * self._test_per)
        train_dataset = dataset.take(n_train)
        test_dataset = dataset.skip(n_train).take(n_test)

        return train_dataset, test_dataset
    def run(self, data):
        """Detect a locked mode in one shot from Mirnov and exsad signals.

        The Mirnov signal is cropped to [0.05 s, end time], where the end time
        is ``CqTime`` for disruptive shots and ``RampDownTime`` otherwise.  It
        is low-pass filtered, and its dominant frequency is tracked with a
        sliding 1024-point Hann-windowed FFT (hop of 100 samples).  The exsad1
        and exsad7 signals are sampled at each window centre and combined as
        ``exsad1 * 100 / 2.35 - exsad7 * 100 / 1.79``.  Both traces are
        median-filtered (kernel 31).  In the second half of the windows, a
        point with dominant frequency below 1000 and an absolute combined
        exsad value above 10 marks a locked mode.

        :param data: mapping with ``'shot'`` plus the signals
            ``'\\MA_POL_CA01T'``, ``'\\exsad1'`` and ``'\\exsad7'``, each given
            as a two-row array ``[values, times]``.
        :return: dict with ``IsLockedMode`` and ``LockedModeTime`` (time of
            the last matching window, 0.0 if none).  For shots flagged
            ``IsRampUpDisrupt`` the default result is returned.  ``NoData``
            is set to True when a required signal is not a non-empty two-row
            array.
        """
        result = {
            'IsLockedMode': False,  # whether a locked mode occurred
            'LockedModeTime': 0.0  # time of the locked mode
        }

        def extract(key):
            """Return (values, times) for signal ``key``, or None if unusable."""
            arr = data[key]
            if arr.shape[0] == 2 and arr.shape[1] != 0:
                return arr[0], arr[1]
            return None

        shot = data['shot']
        tag = Query().tag(shot)

        if tag['IsRampUpDisrupt'] is True:
            return result
        end_t = tag['CqTime'] if tag['IsDisrupt'] is True else tag['RampDownTime']

        # Mirnov signal (raw strings: '\M' and '\e' are not valid escapes).
        mirnov_pair = extract(r'\MA_POL_CA01T')
        if mirnov_pair is None:
            result['NoData'] = True
            return result
        mirnov, times = mirnov_pair

        # Crop to [0.05 s, end_t], or to the end of the record if it is
        # shorter.
        start = np.where(times >= 0.05)[0][0]
        if times[-1] < end_t:
            end = len(times)
        else:
            end = np.where(times >= end_t)[0][0]
        mirnov = mirnov[start:end]
        times = times[start:end]

        sampling_rate = 250000  # Hz
        frame_size = 1024  # FFT points per window
        hop = 100  # samples between consecutive windows

        # 3rd-order Butterworth low-pass.  Normalised cutoff
        # Wn = 2 * cutoff / sampling_rate, with a cutoff of 10 kHz.
        wn = 2 * 10000 / sampling_rate
        b, a = signal.butter(3, wn, 'low')
        mirnov = signal.filtfilt(b, a, mirnov)

        # Sliding-window FFT: record the dominant frequency of each window.
        window = np.hanning(frame_size)
        freqs = np.linspace(0, sampling_rate / 2, int(frame_size / 2 + 1))
        max_sqe = []
        max_sqe_t = []
        i = 0
        while i < len(mirnov) - frame_size:
            # Multiply into a new array.  An in-place ``*=`` on the slice would
            # write the windowed values back into ``mirnov`` and corrupt the
            # overlapping windows that follow.
            frames = mirnov[i:i + frame_size] * window
            mirnov_fft = np.fft.rfft(frames) / frame_size
            mirnov_fft = 20 * np.log10(
                np.clip(np.abs(mirnov_fft), 1e-20, 1e100))
            max_sqe.append(freqs[np.argmax(mirnov_fft)])
            max_sqe_t.append(times[int(i + frame_size / 2)])
            i += hop

        exsad1_pair = extract(r'\exsad1')
        if exsad1_pair is None:
            result['NoData'] = True
            return result
        exsad1, exsad1_time = exsad1_pair

        exsad7_pair = extract(r'\exsad7')
        if exsad7_pair is None:
            result['NoData'] = True
            return result
        exsad7, exsad7_time = exsad7_pair

        # Signal shorter than one FFT window: there is nothing to analyse.
        if not max_sqe_t:
            return result

        # Median filter to smooth the dominant-frequency trace.
        max_sqe = signal.medfilt(max_sqe, 31)

        # Sample both exsad signals at the first point at or after each
        # window centre, then combine them.
        exsad_resample = []
        for t in max_sqe_t:
            index1 = np.where(exsad1_time >= t)[0][0]
            index2 = np.where(exsad7_time >= t)[0][0]
            exsad_resample.append(exsad1[index1] * 100 / 2.35 -
                                  exsad7[index2] * 100 / 1.79)
        exsad_resample = signal.medfilt(exsad_resample, 31)

        # Scan the second half of the windows; the last match wins.
        for k in range(len(max_sqe_t) // 2, len(max_sqe_t)):
            if max_sqe[k] < 1000 and np.fabs(exsad_resample[k]) > 10:
                result['IsLockedMode'] = True
                result['LockedModeTime'] = max_sqe_t[k]

        return result
# Disruptive training shots; TrainNormal is loaded earlier in the script.
TrainBreak = np.load(os.sep.join([root_path, "TrainData", "TrainBreak.npy"]))
TrainNormal, TrainBreak = list(TrainNormal), list(TrainBreak)
print(len(TrainNormal))
print(len(TrainBreak))
# Normal shots first, then disruptive ones.
TrainShot = [*TrainNormal, *TrainBreak]

# Output directory for the reduced-sampling HDF5 files.
save_path = os.sep.join([root_path, "ReduceSampling"])
if not os.path.exists(save_path):
    os.makedirs(save_path)

# Progress counter, failed-shot list, and data/database handles for the
# loop below.
n = 1
mistake = []
reader = Reader()
db = Query()
for shot in TrainShot:
    print("Shot:{}  ".format(shot) + "No.{}".format(n))
    n += 1
    try:
        shot_info = db.tag(int(shot))
        file = h5py.File(save_path + os.sep + r"{}.hdf5".format(shot))
        if not shot_info["IsDisrupt"]:
            DownTime = shot_info["RampDownTime"]
            for shottag in all_tags:
                dataset = reader.read_one(int(shot), shottag)
                data = dataset[0]
                time = dataset[1]
                data = data[time <= DownTime]
                time = time[time <= DownTime]
                data = data[time > 0.2]