Example #1
def run_rolodex(input_file, output_file):
	"""
	This program takes an input file of personal information in multiple formats.
	It normalizes every valid entry and dumps the sorted result into an output file.
	@param string input_file	input file name with path from current dir
	@param string output_file	output file name with path from current dir

	"""
	persons = []
	error_indices = []
	normalizer = Normalizer()

	with open(input_file) as infile:
		for line_number, line in enumerate(infile, start=1):
			try:
				person = normalizer.normalize(line.rstrip())
				persons.append(person)
			except NormalizationException:
				error_indices.append(line_number)

	sorted_persons = sorted(persons, key=str)

	output_dict = {
			"entries": sorted_persons,
			"errors": error_indices
		}
	with open(output_file, 'w') as outfile:
		json.dump(output_dict, outfile, indent=2, sort_keys=True)
	logging.info("Completed, please check output file.")
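A minimal usage sketch for run_rolodex above; the Normalizer and NormalizationException classes come from the example's own codebase and are only assumed here, and the file names are made up:

import logging

logging.basicConfig(level=logging.INFO)

# Normalize the mixed-format entries in rolodex_input.txt and write the sorted
# entries, plus the line numbers of unparseable rows, to rolodex_output.json.
run_rolodex('rolodex_input.txt', 'rolodex_output.json')

# Assuming the normalized entries are JSON-serializable, the output file holds
# an object of the form {"entries": [...], "errors": [...]}.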
Example #2
    def plot_shot_old(self,shot,save_fig=True,normalize=True,truth=None,prediction=None,P_thresh_opt=None,prediction_type='',extra_filename=''):
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf['paths']['normalizer_path'] = self.conf['paths']['normalizer_path']
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        if shot.previously_saved(self.shots_dir):
            shot.restore(self.shots_dir)
            t_disrupt = shot.t_disrupt
            is_disruptive =  shot.is_disruptive
            if normalize:
                self.normalizer.apply(shot)

            use_signals = self.saved_conf['paths']['use_signals']
            f,axarr = plt.subplots(len(use_signals)+1,1,sharex=True,figsize=(13,13))#, squeeze=False)
            #plt.title(prediction_type)
            #all files must agree on T_warning due to output of truth vs. normalized shot ttd.
            assert(np.all(shot.ttd.flatten() == truth.flatten()))
            for i,sig in enumerate(use_signals):
                num_channels = sig.num_channels
                ax = axarr[i]
                sig_arr = shot.signals_dict[sig]
                if num_channels == 1:
                    ax.plot(sig_arr[:,0],label = sig.description)
                else:
                    ax.imshow(sig_arr[:,:].T, aspect='auto', label = sig.description + " (profile)")
                    ax.set_ylim([0,num_channels])
                ax.legend(loc='best',fontsize=8)
                plt.setp(ax.get_xticklabels(),visible=False)
                plt.setp(ax.get_yticklabels(),fontsize=7)
                f.subplots_adjust(hspace=0)
                #print(sig)
                #print('min: {}, max: {}'.format(np.min(sig_arr), np.max(sig_arr)))
            ax = axarr[-1]
            if self.pred_ttd:
                ax.semilogy((-truth+0.0001),label='ground truth')
                ax.plot(-prediction+0.0001,'g',label='neural net prediction')
                ax.axhline(-P_thresh_opt,color='k',label='trigger threshold')
            else:
                ax.plot((truth+0.001),label='ground truth')
                ax.plot(prediction,'g',label='neural net prediction')
                ax.axhline(P_thresh_opt,color='k',label='trigger threshold')
            #ax.set_ylim([1e-5,1.1e0])
            ax.set_ylim([-2,2])
            if len(truth)-self.T_max_warn >= 0:
                ax.axvline(len(truth)-self.T_max_warn,color='r',label='min warning time')
            ax.axvline(len(truth)-self.T_min_warn,color='r',label='max warning time')
            ax.set_xlabel('T [ms]')
            #ax.legend(loc = 'lower left',fontsize=10)
            plt.setp(ax.get_yticklabels(),fontsize=7)
            # ax.grid()           
            if save_fig:
                plt.savefig('sig_fig_{}{}.png'.format(shot.number,extra_filename),bbox_inches='tight')
               # np.savez('sig_{}{}.npz'.format(shot.number,extra_filename),shot=shot,T_min_warn=self.T_min_warn,T_max_warn=self.T_max_warn,prediction=prediction,truth=truth,use_signals=use_signals,P_thresh=P_thresh_opt)
            plt.close()
        else:
            print("Shot hasn't been processed")
Example #3
    def save_shot(self, shot, P_thresh_opt=0, extra_filename=""):
        if self.normalizer is None:
            if self.conf is not None:
                self.saved_conf["paths"]["normalizer_path"] = self.conf[
                    "paths"]["normalizer_path"]
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        shot.restore(self.shots_dir)
        # t_disrupt = shot.t_disrupt
        # is_disruptive = shot.is_disruptive
        self.normalizer.apply(shot)

        pred, truth, is_disr = self.get_pred_truth_disr_by_shot(shot)
        use_signals = self.saved_conf["paths"]["use_signals"]
        np.savez(
            "sig_{}{}.npz".format(shot.number, extra_filename),
            shot=shot,
            T_min_warn=self.T_min_warn,
            T_max_warn=self.T_max_warn,
            prediction=pred,
            truth=truth,
            use_signals=use_signals,
            P_thresh=P_thresh_opt,
        )
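A note on reading back the archive written by save_shot: np.savez pickles the shot and use_signals objects into object arrays, so np.load needs allow_pickle=True to retrieve them. A small sketch, with a made-up shot number in the filename:

import numpy as np

data = np.load("sig_123456.npz", allow_pickle=True)
print(data["prediction"].shape, data["truth"].shape, data["P_thresh"])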
Example #4
def kmeans(request):
    k = 2
    global norm, champs, name_champ
    norm = Normalizer()
    workpath = os.path.dirname(os.path.abspath(__file__))  # directory containing this .py file
    datafile = os.path.join(workpath, 'dataset/spambase.data.txt')
    norm.load_csv(datafile)
    champs = []
    name_champ = []
    if request.method == 'GET':
        if request.GET['nb'] == '3':
            champs.append(int(request.GET['champs1']))
            champs.append(int(request.GET['champs2']))
            champs.append(int(request.GET['champs3']))
            name_champ.append(nomChamp[champs[0]])
            name_champ.append(nomChamp[champs[1]])
            name_champ.append(nomChamp[champs[2]])
        elif (request.GET['nb'] == '2'):
            champs.append(int(request.GET['champs1']))
            champs.append(int(request.GET['champs2']))
            name_champ.append(nomChamp[champs[0]])
            name_champ.append(nomChamp[champs[1]])

        global kMeanClusterer
        kMeanClusterer = KMeanClusterer(k, datafile, champs)
        kMeanClusterer.assignement()
        centroids = []
        clusters = []
        for i in range(k):
            centroids.append(kMeanClusterer.getCluster(i).getCentroid())
            #centroids.append(kMeanClusterer.getCluster(i).normalizeCentroid(0.0, 1.0, len(champs)))
        for i in range(k):
            clusters.append(kMeanClusterer.getCluster(i).getPoints())
            #clusters.append(norm.normalization(kMeanClusterer.getCluster(i).getPoints(), 0.0, 1.0, len(champs)))


        splitedData = norm.get_splitedData(champs)
        spams = splitedData[0]
        nospams = splitedData[1]

        html = render_to_string('kmeans.html', {'k': len(champs), 'centroids': centroids, 'clusters': clusters,
                                                'spams': spams, 'no_spams': nospams, 'nomChamps': name_champ})
        return HttpResponse(html)
    else:
        form = DocumentForm() # An empty, unbound form
        return redirect('index.html', {'form': form})
Example #5
def get_stock_data():
    print "Loading data"
    # Import the data from a CSV file coming from Yahoo Finance.
    data_csv = pd.read_csv('all_stocks_5yr.csv')  # Change file path as needed

    # Define a function to remove any NaNs and Infs from an array
    def remove_nan_inf(a):
        s = a[~np.isnan(a).any(axis=1)]
        t = s[~np.isinf(s).any(axis=1)]
        return t

    # Define a Numpy array to hold the information for the NNs.
    n_features = 5
    data = np.empty((data_csv.shape[0], n_features))
    L = 1259  # Number of observations per subject (more info in README)

    print "Loading data"

    # Load the data and normalize it per subject (modify if necessary)
    for i in tqdm(range(int(float(data_csv.shape[0]) / L))):
        holder_array = np.empty(
            (data.shape[1], L))  # Array to temporarily hold information
        for info, j in zip(['open', 'close', 'high', 'low', 'volume'],
                           range(n_features)):
            for k in range(L):
                holder_array[j][k] = np.array(data_csv[info][i * L + k],
                                              dtype=np.float)

        holder_array = remove_nan_inf(holder_array.T)
        scaled_info = Normalizer().fit_transform(holder_array)
        for j in range(scaled_info.shape[0]):
            data[i * L + j] = scaled_info[j]

    # Set batch size, input shape and train/test prop.
    batch_size = 1
    train_test_prop = 0.8
    time_step = 1

    # Organize the data so that inputs are 3D and work with LSTMs.
    data_train = data[:int(data.shape[0] * train_test_prop)]
    data_test = data[int(data.shape[0] * train_test_prop):]
    x_train = data_train[:-1].reshape(data_train[:-1].shape[0], time_step,
                                      data_train[:-1].shape[1])
    y_train = data_train[1:][:, 0]
    x_test = data_test[:-1].reshape(data_test[:-1].shape[0], time_step,
                                    data_test[:-1].shape[1])
    y_test = data_test[1:][:, 0]

    print "Data successfully loaded"

    return (batch_size, x_train, x_test, y_train, y_test)
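The Normalizer used in this example is presumably sklearn.preprocessing.Normalizer, which rescales each sample (row) to unit norm rather than standardizing columns. A tiny sketch of that behaviour with made-up numbers:

import numpy as np
from sklearn.preprocessing import Normalizer

X = np.array([[3.0, 4.0],
              [1.0, 0.0]])
# Each row is divided by its own L2 norm, so every row ends up with length 1.
print(Normalizer().fit_transform(X))
# [[0.6 0.8]
#  [1.  0. ]]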
Example #6
def train(train_env_id: str,
          eval_env_id: str,
          logdir: str,
          cfg: ExperimentConfig,
          save_path: str,
          pretrain_path: Optional[str] = None) -> DDPGAgent:
    pretrain = torch.load(pretrain_path) \
               if pretrain_path is not None \
               else None
    env = set_env_metadata(train_env_id, cfg)
    train_env = make_vec_env(train_env_id,
                             num_envs=cfg.episodes_per_cycle,
                             no_timeout=True,
                             seed=cfg.seed)
    eval_env = make_vec_env(eval_env_id,
                            num_envs=cfg.num_eval_envs,
                            no_timeout=True,
                            seed=cfg.seed + 100)
    replay = HERReplayBuffer(cfg=cfg)
    tf_logger = TensorboardLogger(logdir)
    actor = ActorNet(obs_dim=cfg.obs_dim,
                     goal_dim=cfg.goal_dim,
                     action_dim=cfg.action_dim,
                     action_range=cfg.action_range,
                     zero_last=(pretrain_path is not None))
    critic = CriticNet(obs_dim=cfg.obs_dim,
                       goal_dim=cfg.goal_dim,
                       action_dim=cfg.action_dim,
                       action_range=cfg.action_range)
    normalizer = Normalizer(cfg.obs_dim+cfg.goal_dim) \
                 if pretrain is None                  \
                 else pretrain.normalizer
    agent = DDPGAgent(cfg=cfg,
                      actor=actor,
                      critic=critic,
                      normalizer=normalizer,
                      reward_fn=env.compute_reward,
                      pretrain=getattr(pretrain, 'actor', None))
    engine = DDPGEngine(cfg=cfg,
                        agent=agent,
                        train_env=train_env,
                        eval_env=eval_env,
                        replay=replay,
                        tf_logger=tf_logger)
    engine.train()

    env.close()
    train_env.close()
    eval_env.close()
    torch.save(agent, save_path)
    return agent
Example #7
    def generate(self, edgeCount, tfidf = False, window_size = 0, degree = False, closeness = False, groups= False):
        parser = XMLDataframeParser()
        text = parser.getText("./data/smokingRecords.xml")
        parser.addFeatureFromText(text, "HISTORY OF PRESENT ILLNESS :", "", True, True, "illness")
        df = parser.getDataframe()
        df_xml = parser.removeEmptyEntries(df, "illness")
        normalizer = Normalizer()
        if tfidf:
            if window_size == 0:
                vectorizer = TfidfVectorizer(tokenizer = lambda text: normalizer.normalize(text, True, False), ngram_range = (2, 2))
                mostFreq2Grams = self.get_first_n_words(vectorizer, df_xml.illness, edgeCount)
            else:
                vectorizer = TfidfVectorizer(analyzer = lambda text: self.custom_analyser(text, 2, int(window_size)))
                mostFreq2Grams = self.get_first_n_words(vectorizer, normalizer.normalizeArray(df_xml.illness, True, False), edgeCount)
        else:
            if window_size == 0:
                vectorizer = CountVectorizer(tokenizer = lambda text: normalizer.normalize(text, True, False), ngram_range = (2, 2))
                mostFreq2Grams = self.get_first_n_words(vectorizer, df_xml.illness, edgeCount)
            else:
                vectorizer = CountVectorizer(analyzer = lambda text: self.custom_analyser(text, 2, int(window_size)))
                mostFreq2Grams = self.get_first_n_words(vectorizer, normalizer.normalizeArray(df_xml.illness, True, False), edgeCount)
        df_graph = self.create_dataframe(mostFreq2Grams)
        GF = nx.from_pandas_edgelist(df_graph, 'Node1', 'Node2', ["Weight"])
        

        if degree:
            # calculate degree centrality
            degree_centrality = nx.degree_centrality(GF)
            nx.set_node_attributes(GF, degree_centrality, "degree_centrality")
            
        if closeness:
            # calculate closeness centrality    
            closeness_centrality = nx.closeness_centrality(GF) 
            nx.set_node_attributes(GF, closeness_centrality, "closeness_centrality")

        if groups:
            # calculate partitions
            partition = community.best_partition(GF)
            nx.set_node_attributes(GF, partition, "group")

        payload = json_graph.node_link_data(GF)
        return payload
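The payload returned above is networkx's node-link JSON representation of the graph. A tiny sketch of what that structure looks like for a toy two-node graph (the node names are made up):

import networkx as nx
from networkx.readwrite import json_graph

G = nx.Graph()
G.add_edge("chest", "pain", Weight=3)

# node_link_data() produces a JSON-serializable dict with "nodes" and "links"
# lists plus graph-level flags, matching what generate() returns as payload.
print(json_graph.node_link_data(G))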
Example #8
    def __init__(self, params):
        print(datetime.now().strftime(params.time_format), 'Starting..')
        # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # verbose off(info, warning)
        random.seed(params.seed)
        np.random.seed(params.seed)
        tf.set_random_seed(params.seed)

        print("A GPU is{} available".format(
            "" if tf.test.is_gpu_available() else " NOT"))

        stm_dict = dict()
        stm_dict['params'] = params

        FLAGS.model_dir = './biobert_ner/pretrainedBERT/'
        FLAGS.bert_config_file = './biobert_ner/conf/bert_config.json'
        FLAGS.vocab_file = './biobert_ner/conf/vocab.txt'
        FLAGS.init_checkpoint = \
            './biobert_ner/pretrainedBERT/pubmed_pmc_470k/biobert_model.ckpt'

        FLAGS.ip = params.ip
        FLAGS.port = params.port

        FLAGS.gnormplus_home = params.gnormplus_home
        FLAGS.gnormplus_host = params.gnormplus_host
        FLAGS.gnormplus_port = params.gnormplus_port

        FLAGS.tmvar2_home = params.tmvar2_home
        FLAGS.tmvar2_host = params.tmvar2_host
        FLAGS.tmvar2_port = params.tmvar2_port

        # import pprint
        # pprint.PrettyPrinter().pprint(FLAGS.__flags)

        stm_dict['biobert'] = BioBERT(FLAGS)

        stm_dict['gnormplus_home'] = params.gnormplus_home
        stm_dict['gnormplus_host'] = params.gnormplus_host
        stm_dict['gnormplus_port'] = params.gnormplus_port

        stm_dict['tmvar2_home'] = params.tmvar2_home
        stm_dict['tmvar2_host'] = params.tmvar2_host
        stm_dict['tmvar2_port'] = params.tmvar2_port

        stm_dict['max_word_len'] = params.max_word_len
        stm_dict['ner_model'] = params.ner_model
        stm_dict['n_pmid_limit'] = params.n_pmid_limit
        stm_dict['time_format'] = params.time_format
        stm_dict['available_formats'] = params.available_formats

        if not os.path.exists('./output'):
            os.mkdir('output')
        else:
            # delete prev. version outputs
            delete_files('./output')

        #delete_files(os.path.join(params.gnormplus_home, 'input'))
        #delete_files(os.path.join(params.tmvar2_home, 'input'))
        #ceshi_pdf = '/home/yg/work/BioNLP/NER_RE_DB/pdf_data/z.R.scholar/data/rs993419_DNMT3B/A_new_and_a_reclassified_ICF_patient_without_mutations_in_DNMT3B_and_its_interacting_proteins_SUMO‐1_and_UBC9.pdf'
        ner = NER()
        ner.stm_dict = stm_dict
        ner.normalizer = Normalizer()
        input_db_path = '/home/yg/work/BioNLP/NER_RE_DB/pdf_data/z.R.scholar/instance/paper.sqlite'
        pdf_data_path = '/home/yg/work/BioNLP/NER_RE_DB/pdf_data/z.R.scholar/data'
        res = ner.run_pipeline(input_db_path, pdf_data_path)
        stm_dict['biobert'].close()
        return

        print("Task started")
        res = ceshi_run(stm_dict)
        print('Below are the results of this test run:')
        print(res)
        stm_dict['biobert'].close()
Example #9
def main(p):
    start = time.time()

    # Select the files whose names end with 'json.gz'
    file_name_list = filter(lambda x: x.endswith('json.gz'), os.listdir(p))

    # TODO: check whether there are exactly 24 files (glob module)

    for file_name in file_name_list:
        with open(os.path.join(p, file_name), 'r') as f:
            raw_json_file = gzip.GzipFile(fileobj=f)

            record_cleaner = Cleaner()
            record_grouper = Grouper(db)
            record_normalizer = Normalizer(db)
            mongo_helper = MongoHelper(db)
            counter = ActorCounter()
            evaluater = Evaluater()

            # Data cleaning
            record_cleaner.set_dirty_data(raw_json_file)
            record_cleaner.clean()
            clean_record = record_cleaner.get_clean_data()
            log.log('clean record %s' % len(clean_record))
            # Data processing

            # Grouping
            record_grouper.set_records(clean_record)
            record_grouper.group()
            record_actor_exist = record_grouper.get_group_1()
            record_actor_new = record_grouper.get_group_2()
            log.log('record_actor_exist: %s' % len(record_actor_exist))
            log.log('record_actor_new: %s' % len(record_actor_new))


            # Process the records whose actor already exists
            log.log('Begin processing actor-exist records...')
            # We only need to delete each record's actor_attributes
            for record in record_actor_exist:
                del record['actor_attributes']
            log.log('Finished.')


            # Process the records whose actor does not exist yet
            record_normalizer.set_records(record_actor_new)
            record_normalizer.normalize()
            record_actor_new = record_normalizer.get_record_actor_new()
            new_actors = record_normalizer.get_new_actors()

            # Write today's newly added actors to the database
            actors = new_actors.values()
            mongo_helper.insert_new_actors(actors)

            # For the new actors, update the corresponding counters in Redis
            counter.count_actor_list(actors)

            # Compute the val of each record
            evaluater.set_records(record_actor_exist)
            evaluater.evaluate()
            val_actor_exist = evaluater.get_val_cache()

            evaluater.set_records(record_actor_new)
            evaluater.evaluate()
            val_actor_new = evaluater.get_val_cache()

            # Insert the records into the database
            mongo_helper.insert_new_reocrds(record_actor_new)
            mongo_helper.insert_new_reocrds(record_actor_exist)

            # Update today's per-user val increments in the database
            mongo_helper.update_val(val_actor_new)
            mongo_helper.update_val(val_actor_exist)

            record_cleaner.free_mem()
            del record_cleaner
            del record_grouper
            del record_normalizer
            del mongo_helper
            del counter
            del evaluater

    # Generate the CSV file
    util.grcount2csv()

    end = time.time()
    log.log('total: %s s' % (end - start))
Example #10
def index(request):
    # Handle file upload
    if request.method == 'POST':
        form = DocumentForm(request.POST, request.FILES)
        if form.is_valid():
            handle_uploaded_file(request.FILES['docfile'], request.FILES['docfile'].name)
            # Redirect to the document list after POST
            norm = Normalizer()
            data_save = []
            workpath = os.path.dirname(os.path.abspath(__file__))  # directory containing this .py file
            data = norm.load_csv(os.path.join(workpath, 'dataset/'+request.FILES['docfile'].name))

            for line in data:
                try:
                    data_save.append(line)
                except IndexError:
                    pass

            '''
            data_normalized = norm.normalization(data_save, 0.0, 1.0, 58)
            stats = norm.statistics(data_normalized, 58)
            '''
            normalizedData = norm.normalization()
            normSplitedData = norm.split(normalizedData)
            normNospams = normSplitedData[1]
            normSpams = normSplitedData[0]
            stats = norm.stats(normSpams, normNospams)

            spam = []
            for i in range(0, 58):
                line = []
                line.append(i)
                for j in range(0, 4):
                    line.append(stats[0][j][i])
                spam.append(line)

            no_spam = []
            for i in range(0, 58):
                line = []
                line.append(i)
                for j in range(0, 4):
                    line.append(stats[1][j][i])
                no_spam.append(line)

            new_stats = []
            for i in range(0, 58):
                line = []
                line.append(i)
                for j in range(0, 4):
                    line.append(stats[0][j][i])
                for j in range(0, 4):
                    line.append(stats[1][j][i])
                new_stats.append(line)

            global nomChamp
            nomChamp = ['word_freq_make', 'word_freq_address', 'word_freq_all',
                        'word_freq_3d', 'word_freq_our', 'word_freq_over', 'word_freq_remove',
                        'word_freq_internet', 'word_freq_order', 'word_freq_mail',
                        'word_freq_receive', 'word_freq_will', 'word_freq_people',
                        'word_freq_report', 'word_freq_addresses', 'word_freq_free',
                        'word_freq_business', 'word_freq_email', 'word_freq_you',
                        'word_freq_credit', 'word_freq_your', 'word_freq_font',
                        'word_freq_000', 'word_freq_money', 'word_freq_hp', 'word_freq_hpl',
                        'word_freq_george', 'word_freq_650', 'word_freq_lab',
                        'word_freq_labs', 'word_freq_telnet', 'word_freq_857',
                        'word_freq_data', 'word_freq_415', 'word_freq_85',
                        'word_freq_technology', 'word_freq_1999', 'word_freq_parts',
                        'word_freq_pm', 'word_freq_direct', 'word_freq_cs',
                        'word_freq_meeting', 'word_freq_original', 'word_freq_project',
                        'word_freq_re', 'word_freq_edu', 'word_freq_table',
                        'word_freq_conference', 'char_freq_semi', 'char_freq_lparen',
                        'char_freq_lbrack', 'char_freq_bang', 'char_freq_dollar',
                        'char_freq_hash', 'capital_run_length_average',
                        'capital_run_length_longest', 'capital_run_length_total',
                        'spam']

            stats_names = zip(nomChamp, new_stats)
            return render(request, 'stats.html', {'data': stats_names})
    else:
        form = DocumentForm() # An empty, unbound form
        return render(request, 'index.html', {'form': form})
Example #11
def main(p):
    start = time.time()

    # Select the files whose names end with 'json.gz'
    file_name_list = filter(lambda x: x.endswith('json.gz'), os.listdir(p))

    # TODO: check whether there are exactly 24 files (glob module)

    for file_name in file_name_list:
        with open(os.path.join(p, file_name), 'r') as f:
            raw_json_file = gzip.GzipFile(fileobj=f)

            record_cleaner = Cleaner()
            record_grouper = Grouper(db)
            record_normalizer = Normalizer(db)
            mongo_helper = MongoHelper(db)
            counter = ActorCounter()
            evaluater = Evaluater()

            # Data cleaning
            record_cleaner.set_dirty_data(raw_json_file)
            record_cleaner.clean()
            clean_record = record_cleaner.get_clean_data()
            log.log('clean record %s' % len(clean_record))
            # Data processing

            # Grouping
            record_grouper.set_records(clean_record)
            record_grouper.group()
            record_actor_exist = record_grouper.get_group_1()
            record_actor_new = record_grouper.get_group_2()
            log.log('record_actor_exist: %s' % len(record_actor_exist))
            log.log('record_actor_new: %s' % len(record_actor_new))

            # Process the records whose actor already exists
            log.log('Begin processing actor-exist records...')
            # We only need to delete each record's actor_attributes
            for record in record_actor_exist:
                del record['actor_attributes']
            log.log('Finished.')

            # Process the records whose actor does not exist yet
            record_normalizer.set_records(record_actor_new)
            record_normalizer.normalize()
            record_actor_new = record_normalizer.get_record_actor_new()
            new_actors = record_normalizer.get_new_actors()

            # Write today's newly added actors to the database
            actors = new_actors.values()
            mongo_helper.insert_new_actors(actors)

            # For the new actors, update the corresponding counters in Redis
            counter.count_actor_list(actors)

            # Compute the val of each record
            evaluater.set_records(record_actor_exist)
            evaluater.evaluate()
            val_actor_exist = evaluater.get_val_cache()

            evaluater.set_records(record_actor_new)
            evaluater.evaluate()
            val_actor_new = evaluater.get_val_cache()

            # Insert the records into the database
            mongo_helper.insert_new_reocrds(record_actor_new)
            mongo_helper.insert_new_reocrds(record_actor_exist)

            # Update today's per-user val increments in the database
            mongo_helper.update_val(val_actor_new)
            mongo_helper.update_val(val_actor_exist)

            record_cleaner.free_mem()
            del record_cleaner
            del record_grouper
            del record_normalizer
            del mongo_helper
            del counter
            del evaluater

    # Generate the CSV file
    util.grcount2csv()

    end = time.time()
    log.log('total: %s s' % (end - start))
Example #12
    def plot_shot(self,shot,save_fig=True,normalize=True,truth=None,prediction=None,P_thresh_opt=None,prediction_type='',extra_filename=''):
        print('plotting shot,',shot,prediction_type,prediction.shape)
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf['paths']['normalizer_path'] = self.conf['paths']['normalizer_path']
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)
    
        if shot.previously_saved(self.shots_dir):
            shot.restore(self.shots_dir)
            if shot.signals_dict is not None: #make sure shot was saved with data
                t_disrupt = shot.t_disrupt
                is_disruptive =  shot.is_disruptive
                if normalize:
                    self.normalizer.apply(shot)
    
                use_signals = self.saved_conf['paths']['use_signals']
                all_signals = self.saved_conf['paths']['all_signals']
                fontsize= 18
                lower_lim = 0 #len(pred)
                plt.close()
                colors = ['b','green','red','c','m','orange','k','y']
                lss = ["-","--"]
                #f,axarr = plt.subplots(len(use_signals)+1,1,sharex=True,figsize=(10,15))#, squeeze=False)
                f,axarr = plt.subplots(4+1,1,sharex=True,figsize=(18,18))#,squeeze=False)#, squeeze=False)
                #plt.title(prediction_type)
                #assert(np.all(shot.ttd.flatten() == truth.flatten()))
                xx = range((prediction.shape[0]))
                j=0 #list(reversed(range(len(pred))))
                j1=0
                p0=0
                for i,sig_target in enumerate(all_signals):
                    if sig_target.description== 'n1 finite frequency signals': 
#'Locked mode amplitude':
                       target_plot=shot.signals_dict[sig_target]##[:,0]
                       target_plot=target_plot[:,0]
                       print(target_plot.shape)
                    elif sig_target.description== 'Locked mode amplitude':
                       lm_plot=shot.signals_dict[sig_target]##[:,0]
                       lm_plot=lm_plot[:,0]
                for i,sig in enumerate(use_signals):
                    num_channels = sig.num_channels
                    sig_arr = shot.signals_dict[sig]
                    legend=[]
                    if num_channels == 1:
                        j=i//7
                        ax = axarr[j]
        #                 if j == 0:
                        ax.plot(xx,sig_arr[:,0],linewidth=2,color=colors[i%7],label=sig.description)#,linestyle=lss[j],color=colors[j])
        #                 else:
        #                     ax.plot(xx,sig_arr[:,0],linewidth=2)#,linestyle=lss[j],color=colors[j],label = labels[sig])
                        #if np.min(sig_arr[:,0]) < -100000:
                        if j==0:
                          ax.set_ylim([-3,15])
                          ax.set_yticks([0,5,10])
                        else:
                          ax.set_ylim([-15,15])
                          ax.set_yticks([-10,-5,0,5,10])
        #                 ax.set_ylabel(labels[sig],size=fontsize)
                        ax.legend()
                    else:
                        j=-2-j1
                        j1+=1
                        ax = axarr[j]
                        ax.imshow(sig_arr[:,:].T, aspect='auto', label = sig.description,cmap="inferno" )
                        ax.set_ylim([0,num_channels])
                        ax.text(lower_lim+200, 45, sig.description, bbox={'facecolor': 'white', 'pad': 10},fontsize=fontsize)
                        ax.set_yticks([0,num_channels/2])
                        ax.set_yticklabels(["0","0.5"])
                        ax.set_ylabel("$\\rho$",size=fontsize)
                    ax.legend(loc="center left",labelspacing=0.1,bbox_to_anchor=(1,0.5),fontsize=fontsize,frameon=False)
                   # ax.axvline(len(truth)-self.T_min_warn,color='r')
                   # ax.axvline(p0,linestyle='--',color='darkgrey')
                    plt.setp(ax.get_xticklabels(),visible=False)
                    plt.setp(ax.get_yticklabels(),fontsize=fontsize)
                    f.subplots_adjust(hspace=0)
                    #print(sig)
                    #print('min: {}, max: {}'.format(np.min(sig_arr), np.max(sig_arr)))
                ax = axarr[-1] 
                #         ax.semilogy((-truth+0.0001),label='ground truth')
                #         ax.plot(-prediction+0.0001,'g',label='neural net prediction')
                #         ax.axhline(-P_thresh_opt,color='k',label='trigger threshold')
        #         nn = np.min(pred)
        #        ax.plot(xx,truth,'g',label='target',linewidth=2)
        #         ax.axhline(0.4,linestyle="--",color='k',label='threshold')
                print('predictions shape:',prediction.shape)
                print('truth shape:',truth.shape)
                
     #           prediction=prediction[:,0]
                prediction=prediction#-1.5
                prediction[prediction<0]=0.0
                minii,maxii= np.amin(prediction),np.amax(prediction)
                lm_plot_max=np.amax(lm_plot)
                lm_plot=lm_plot/lm_plot_max*maxii
                truth_plot_max=np.amax(truth[:,1])
                truth_plot=truth[:,1]/truth_plot_max*maxii

                print('******************************************************')
                print('******************************************************')
                print('******************************************************')
                print('Truth_plot',truth_plot[:-10])
                print('lm_plot',lm_plot[:-10])
                print('******************************************************')
                target_plot_max=np.amax(target_plot)
                target_plot=target_plot/target_plot_max*maxii
                ax.plot(xx,truth_plot,'yellow',label='truth')
                ax.plot(xx,lm_plot,'pink',label='Locked mode amplitude')
                ax.plot(xx,target_plot,'cyan',label='n1rms')
                #ax.plot(xx,truth,'pink',label='target')
                ax.plot(xx,prediction[:,0],'blue',label='FRNN-U predicted n=1 mode ',linewidth=2)
                ax.plot(xx,prediction[:,1],'red',label='FRNN-U predicted locked mode ',linewidth=2)
                #ax.axhline(P_thresh_opt,linestyle="--",color='k',label='threshold',zorder=2)
                #ax.axvline(p0,linestyle='--',color='darkgrey')
    #            ax.set_ylim([np.amin(prediction,truth),np.amax(prediction,truth)])
                ax.set_ylim([0,maxii])
                print('predictions:',shot.number,prediction)
#                ax.set_ylim([np.min([prediction,target_plot,lm_plot]),np.max([prediction,target_plot,lm_plot])])
                #ax.set_yticks([0,1])
                #if p0>0:
                #  ax.scatter(xx[k],p,s=300,marker='*',color='r',zorder=3)
                
                # if len(truth)-T_max_warn >= 0:
                #     ax.axvline(len(truth)-T_max_warn,color='r')#,label='max warning time')
        #        ax.axvline(len(truth)-self.T_min_warn,color='r',linewidth=0.5)#,label='min warning time')
                ax.set_xlabel('T [ms]',size=fontsize)
                # ax.axvline(2400)
                ax.legend(loc="center left",labelspacing=0.1,bbox_to_anchor=(1,0.5),fontsize=fontsize+2,frameon=False)
                plt.setp(ax.get_yticklabels(),fontsize=fontsize)
                plt.setp(ax.get_xticklabels(),fontsize=fontsize)
                # plt.xlim(0,200)
                plt.xlim([lower_lim,len(truth)])
        #         plt.savefig("{}.png".format(num),dpi=200,bbox_inches="tight")
                if save_fig:
                    plt.savefig('sig_fig_{}{}.png'.format(shot.number,extra_filename),bbox_inches='tight')
                    #np.savez('sig_{}{}.npz'.format(shot.number,extra_filename),shot=shot,T_min_warn=self.T_min_warn,T_max_warn=self.T_max_warn,prediction=prediction,truth=truth,use_signals=use_signals,P_thresh=P_thresh_opt)
                #plt.show()
        else:
            print("Shot hasn't been processed")
Example #13
 def _normalize(response):
     callback(
         Normalizer(self, data_type, response,
                    **self._normalizer_options).normalize())
Example #14
from normalize import Normalizer
from recognize import LetterRecognizer
from model import LetterClassifier

# creating the recognizer
# it uses `tesserocr` to provide us with very rough letter boxes
# because those boxes are pretty rough, we have to normalize them first
# by a rough letter box I mean a box with 2 or more letters in it (or maybe none at all)
rec = LetterRecognizer(image_path='./res/table_real.jpg')
letters_raw = rec.find_letters()

# creating the normalizer (cuts all the letters into small images and pads them to the same size)
normalizer = Normalizer(letter_sequence=letters_raw)
letters_norm = normalizer.normalized_letters()

# creating SGD classifier model
classifier = LetterClassifier(gamma=1e-4)
classifier.load(dump_filename='model.dmp')

# resulting in labels
labels = classifier.predict(letters_norm)
Example #15
            self.convert_cont_spaces,
            self.strip
        ]
        _text = text
        for func in funcs:
            _text = func(_text)
        return _text


if __name__ == '__main__':
    import sys
    import traceback
    from normalize import Normalizer

    fd = open(sys.argv[1]) if len(sys.argv) >= 2 else sys.stdin
    ps = JaWikiPreprocess()
    norm = Normalizer()

    for _line in (_.strip() for _ in fd):
        for line in _line.split("。"):
            try:
                conv = ps.execute(line)
                if conv:
                    print(norm.normalize(conv + "。"))
            except KeyboardInterrupt:
                exit()
            except:
                traceback.print_exc()

    # for file in jawiki-latest*.txt; do python
    #       ~/Projects/cabocha/jawiki_preprocess.py $file >preprocess/$file.pre.txt;
    #   done
Example #16
def build_dataset(slide_dir,
                  output_dir,
                  projects,
                  background=0.2,
                  size=255,
                  reject_rate=0.1,
                  ignore_repeat=False):
    proceed = None
    train_path = os.path.join(output_dir, "train.h5")
    val_path = os.path.join(output_dir, "val.h5")
    test_path = os.path.join(output_dir, "test.h5")

    if (os.path.isfile(train_path) and os.path.isfile(val_path)
            and os.path.isfile(test_path)):
        while not (proceed == "C" or proceed == "A" or proceed == "R"
                   or proceed == "Q"):
            print(
                """A dataset already exists in this directory. Do you want to \n
                    - Continue to build the dataset [C] \n
                    - Reset the dataset [R] \n
                    - Quit [Q]
                """)
            proceed = input().upper()
            if proceed == "R":
                os.remove(train_path)
                os.remove(val_path)
                os.remove(test_path)

    if proceed == "C":
        train_data = load_set_data(train_path)
        val_data = load_set_data(val_path)
        test_data = load_set_data(test_path)

        train_h5 = h5py.File(train_path, 'a')
        val_h5 = h5py.File(val_path, 'a')
        test_h5 = h5py.File(test_path, 'a')
    elif proceed == "R" or proceed is None:
        if projects is None:
            raise ValueError("Missing list of projects to download.")
        data = get_projects_info(projects)

        train_h5 = h5py.File(train_path, 'a')
        val_h5 = h5py.File(val_path, 'a')
        test_h5 = h5py.File(test_path, 'a')

        all_cases = list(data['case to images'].keys())
        shuffle(all_cases)

        #split to train and val+test
        train_len = int(0.8 * len(all_cases))
        train_set = all_cases[:train_len]
        all_cases = all_cases[train_len:]

        #split val+test into val and test
        val_len = int(0.5 * len(all_cases))
        val_set = all_cases[:val_len]
        test_set = all_cases[val_len:]

        train_data = split_to_sets(train_set, data, train_path)
        val_data = split_to_sets(val_set, data, val_path)
        test_data = split_to_sets(test_set, data, test_path)

    if proceed != "Q":
        dataset = [(list(train_data["image to sample"].keys()), train_h5),
                   (list(val_data["image to sample"].keys()), val_h5),
                   (list(test_data["image to sample"].keys()), test_h5)]

        # train_images = ["TCGA-44-7671-01A-01-BS1.914604a2-de9c-404d-9fa5-23fbd0b76da3.svs"]
        # val_images = ["TCGA-FF-8041-01A-01-TS1.b8b69ce3-a325-4864-a5b0-43c450347bc9.svs"]
        # test_images = ["TCGA-G8-6326-01A-01-TS1.e0eb24da-6293-4ecb-8345-b70149c84d1e.svs"]

        # # # train_images = []
        # val_images = []
        # test_images = []

        # dataset = [
        #     (train_images, train_h5),
        #     (val_images, val_h5),
        #     (test_images, test_h5)
        # ]

        normalizer = Normalizer()
        for images, h5_file in dataset:
            image_h5_file = h5_file.require_group("images")

            for filename in images:
                if proceed != "C" or ".".join(
                        filename.split(".")[:-1]) not in image_h5_file:
                    download_image(filename, slide_dir)

                    Tile(slide_loc=os.path.join(slide_dir, filename),
                         set_hdf5_file=image_h5_file,
                         normalizer=normalizer,
                         background=background,
                         size=size,
                         reject_rate=reject_rate,
                         ignore_repeat=ignore_repeat)

            h5_file.close()

        normalizer.normalize_dir(output_dir)
Example #17
def run_pipeline():

    #get training data
    training_data = pd.read_csv('worldbank-data/WDI_Data.csv')
    training_data.set_index(['Country Name', 'Indicator Name'], inplace=True)

    #convert to panel
    panel = training_data.to_panel()
    panel.drop(['Indicator Code', 'Country Code'], axis=0, inplace=True)
    panel = panel.swapaxes(0, 1)

    indicators_to_use = [
        'Agriculture, value added (% of GDP)',
        'Industry, value added (% of GDP)',
        'Services, etc., value added (% of GDP)',
        'Domestic credit provided by financial sector (% of GDP)',
        'GDP growth (annual %)', 'GDP (current US$)', 'Expense (% of GDP)',
        'Inflation, consumer prices (annual %)',
        'Inflation, GDP deflator (annual %)',
        'Total debt service (% of exports of goods, services and primary income)',
        'Current account balance (BoP, current US$)',
        'External balance on goods and services (% of GDP)',
        'Health expenditure, total (% of GDP)', 'Tax revenue (% of GDP)',
        'Gross capital formation (% of GDP)', 'Gross savings (% of GDP)',
        'Net investment in nonfinancial assets (% of GDP)',
        'Bank capital to assets ratio (%)',
        'Bank nonperforming loans to total gross loans (%)',
        'Broad money (% of GDP)',
        'Commercial bank branches (per 100,000 adults)',
        'Deposit interest rate (%)', 'Real interest rate (%)',
        'Risk premium on lending (lending rate minus treasury bill rate, %)',
        'Total reserves (includes gold, current US$)',
        'Unemployment, total (% of total labor force) (modeled ILO estimate)',
        'Interest rate spread (lending rate minus deposit rate, %)'
    ]
    print len(indicators_to_use), 'indicators used'
    panel = panel[:, :, indicators_to_use]

    target_variables = [
        'Agriculture, value added (% of GDP)',
        'Industry, value added (% of GDP)',
        'Services, etc., value added (% of GDP)', 'GDP growth (annual %)',
        'Inflation, GDP deflator (annual %)',
        'Gross capital formation (% of GDP)', 'Gross savings (% of GDP)',
        'Bank capital to assets ratio (%)',
        'Bank nonperforming loans to total gross loans (%)',
        'Deposit interest rate (%)', 'Real interest rate (%)',
        'Risk premium on lending (lending rate minus treasury bill rate, %)',
        'Unemployment, total (% of total labor force) (modeled ILO estimate)',
        'Interest rate spread (lending rate minus deposit rate, %)'
    ]
    #drop countries with mostly missing data, such as Samoa, Lesotho, and so on.
    useful_countries = []
    for country in panel.axes[0]:
        if find_null_percentage(panel[country, :, :]) < 0.7:
            useful_countries.append(country)
    panel = panel.ix[useful_countries, :, :]

    normalizer = Normalizer(panel)
    normalized_panel = normalizer.normalize(panel)

    # #visualize normalization:
    # for indicator in normalized_panel.axes[2]:
    #     plot_hist(indicator, [panel, normalized_panel])

    # select train data
    years_to_validate = 1
    years_to_predict = 10
    years_train = generate_year_list(stop=2016 - years_to_validate)
    years_val = generate_year_list(start=2016 - years_to_validate + 1)
    years_predict = generate_year_list(start=2017,
                                       stop=2016 + years_to_predict)
    train_panel = normalized_panel[:, years_train, :].copy()

    # fill missing values:
    # either banal mean or median filling
    # or sampling with a generative bidirectional LSTM - see https://arxiv.org/abs/1306.1091

    generative_model = dense_generative_model(train_panel,
                                              hidden_layers=[120],
                                              epochs=100)
    sampled_filled_values = iterative_fill(generative_model,
                                           train_panel,
                                           normalizer,
                                           iterations=50,
                                           burn_in=10)
    train_panel.update(sampled_filled_values, overwrite=False)
    # or
    # train_panel.fillna(0, inplace=True)
    # or
    # train_panel = iterative_fill_bLSTM(train_panel)
    # or
    # filled_panel = fill_missing_bLSTM(train_panel, epochs=100)
    # train_panel.update(filled_panel, overwrite=False)
    # or
    # interpolate(train_panel)

    # create 1-step-ahead model
    epochs = 200
    hl = [100, 100]
    print "ARCHITECTURE:", hl
    print 'EPOCHS:', epochs
    X_train = train_panel[:, years_train, :][:, :-1, :]
    y_train = train_panel[:, years_train, :][:, 1:, :]
    model = dense_gradient_model(X_train,
                                 y_train,
                                 hidden_layers=hl,
                                 d=0.2,
                                 patience=50,
                                 epochs=epochs)

    # finally, predict
    for start, year in enumerate(years_val + years_predict):
        predictions = model.predict(train_panel[:,
                                                start + 1:, :].values)[:,
                                                                       -1, :]
        train_panel = train_panel.swapaxes(0, 1)
        new_year_df = pd.DataFrame(data=predictions,
                                   index=train_panel.axes[1],
                                   columns=y_train.axes[2])
        train_panel[year] = new_year_df
        train_panel = train_panel.swapaxes(0, 1)
    print "score:", rmse(
        normalized_panel[:, years_val, target_variables].values,
        train_panel[:, years_val, target_variables].values)

    #revert to original scale and distributions
    train_panel = normalizer.renormalize(train_panel)

    #convert to dataframe, and write relevant information to file
    target_countries = ['Bulgaria', 'Cyprus', 'Albania']
    train_panel = train_panel.swapaxes(0, 1)
    df = train_panel[:, target_countries,
                     target_variables].to_frame(filter_observations=False)
    df.to_csv('Predictions.csv')
Example #18
    def __init__(self, params):
        print(datetime.now().strftime(params.time_format), 'Starting..')
        # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # verbose off(info, warning)
        random.seed(params.seed)
        np.random.seed(params.seed)
        tf.set_random_seed(params.seed)

        print("A GPU is{} available".format(
            "" if tf.test.is_gpu_available() else " NOT"))

        stm_dict = dict()
        stm_dict['params'] = params

        FLAGS.model_dir = './biobert_ner/pretrainedBERT/'
        FLAGS.bert_config_file = './biobert_ner/conf/bert_config.json'
        FLAGS.vocab_file = './biobert_ner/conf/vocab.txt'
        FLAGS.init_checkpoint = \
            './biobert_ner/pretrainedBERT/pubmed_pmc_470k/biobert_model.ckpt'

        FLAGS.ip = params.ip
        FLAGS.port = params.port

        FLAGS.gnormplus_home = params.gnormplus_home
        FLAGS.gnormplus_host = params.gnormplus_host
        FLAGS.gnormplus_port = params.gnormplus_port

        FLAGS.tmvar2_home = params.tmvar2_home
        FLAGS.tmvar2_host = params.tmvar2_host
        FLAGS.tmvar2_port = params.tmvar2_port

        # import pprint
        # pprint.PrettyPrinter().pprint(FLAGS.__flags)

        stm_dict['biobert'] = BioBERT(FLAGS)

        stm_dict['gnormplus_home'] = params.gnormplus_home
        stm_dict['gnormplus_host'] = params.gnormplus_host
        stm_dict['gnormplus_port'] = params.gnormplus_port

        stm_dict['tmvar2_home'] = params.tmvar2_home
        stm_dict['tmvar2_host'] = params.tmvar2_host
        stm_dict['tmvar2_port'] = params.tmvar2_port

        stm_dict['max_word_len'] = params.max_word_len
        stm_dict['ner_model'] = params.ner_model
        stm_dict['n_pmid_limit'] = params.n_pmid_limit
        stm_dict['time_format'] = params.time_format
        stm_dict['available_formats'] = params.available_formats

        if not os.path.exists('./output'):
            os.mkdir('output')
        else:
            # delete prev. version outputs
            delete_files('./output')

        delete_files(os.path.join(params.gnormplus_home, 'input'))
        delete_files(os.path.join(params.tmvar2_home, 'input'))

        print(datetime.now().strftime(params.time_format),
              'Starting server at http://{}:{}'.format(params.ip, params.port))

        # https://stackoverflow.com/a/18445168
        GetHandler.stm_dict = stm_dict
        GetHandler.normalizer = Normalizer()

        # https://docs.python.org/3.6/library/socketserver.html#asynchronous-mixins
        # https://stackoverflow.com/a/14089457
        server = ThreadedHTTPServer((params.ip, params.port), GetHandler)
        server.serve_forever()
Example #19
custom_path = None
if only_predict:
    custom_path = sys.argv[1]
    print("predicting using path {}".format(custom_path))

#####################################################
#                   PREPROCESSING                   #
#####################################################
# TODO(KGF): check tuple unpack
(shot_list_train, shot_list_validate, shot_list_test) = guarantee_preprocessed(conf)

#####################################################
#                   NORMALIZATION                   #
#####################################################

nn = Normalizer(conf)
nn.train()
loader = Loader(conf, nn)
print("...done")
print(
    "Training on {} shots, testing on {} shots".format(
        len(shot_list_train), len(shot_list_test)
    )
)


#####################################################
#                    TRAINING                       #
#####################################################
# train(conf,shot_list_train,loader)
if not only_predict:
Example #20
    def plot_shot(
        self,
        shot,
        save_fig=True,
        normalize=True,
        truth=None,
        prediction=None,
        P_thresh_opt=None,
        prediction_type="",
        extra_filename="",
    ):
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf["paths"]["normalizer_path"] = self.conf[
                    "paths"]["normalizer_path"]
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        if shot.previously_saved(self.shots_dir):
            shot.restore(self.shots_dir)
            if shot.signals_dict is not None:  # make sure shot was saved with data
                # t_disrupt = shot.t_disrupt
                # is_disruptive = shot.is_disruptive
                if normalize:
                    self.normalizer.apply(shot)

                use_signals = self.saved_conf["paths"]["use_signals"]
                fontsize = 18
                lower_lim = 0  # len(pred)
                plt.close()
                colors = ["b", "green", "red", "c", "m", "orange", "k", "y"]
                # lss = ["-", "--"]
                f, axarr = plt.subplots(
                    4 + 1, 1, sharex=True,
                    figsize=(18, 15))  # ,squeeze=False)#, squeeze=False)
                plt.title(prediction_type)
                assert np.all(shot.ttd.flatten() == truth.flatten())
                xx = range(len(prediction))
                j = 0  # list(reversed(range(len(pred))))
                j1 = 0
                p0 = 0
                for k, p in enumerate(prediction):
                    if p > P_thresh_opt:
                        p0 = k
                        break

                for i, sig in enumerate(use_signals):
                    num_channels = sig.num_channels
                    sig_arr = shot.signals_dict[sig]
                    # legend = []
                    if num_channels == 1:
                        j = i // 7
                        ax = axarr[j]
                        #                 if j == 0:
                        ax.plot(
                            xx,
                            sig_arr[:, 0],
                            linewidth=2,
                            color=colors[i % 7],
                            label=sig.description,
                        )  # ,linestyle=lss[j],color=colors[j])
                        if np.min(sig_arr[:, 0]) < -100000:
                            ax.set_ylim([-6, 6])
                            ax.set_yticks([-5, 0, 5])
                        else:
                            ax.set_ylim([-2, 11])
                            ax.set_yticks([0, 5, 10])
                        #                 ax.set_ylabel(labels[sig],size=fontsize)
                        ax.legend()
                    else:
                        j = -2 - j1
                        j1 += 1
                        ax = axarr[j]
                        ax.imshow(
                            sig_arr[:, :].T,
                            aspect="auto",
                            label=sig.description,
                            cmap="inferno",
                        )
                        ax.set_ylim([0, num_channels])
                        ax.text(
                            lower_lim + 200,
                            45,
                            sig.description,
                            bbox={
                                "facecolor": "white",
                                "pad": 10
                            },
                            fontsize=fontsize,
                        )
                        ax.set_yticks([0, num_channels / 2])
                        ax.set_yticklabels(["0", "0.5"])
                        ax.set_ylabel("$\\rho$", size=fontsize)
                    ax.legend(
                        loc="center left",
                        labelspacing=0.1,
                        bbox_to_anchor=(1, 0.5),
                        fontsize=fontsize,
                        frameon=False,
                    )
                    ax.axvline(len(truth) - self.T_min_warn, color="r")
                    ax.axvline(p0, linestyle="--", color="darkgrey")
                    plt.setp(ax.get_xticklabels(), visible=False)
                    plt.setp(ax.get_yticklabels(), fontsize=fontsize)
                    f.subplots_adjust(hspace=0)
                    # print(sig)
                    # print('min: {}, max: {}'.format(np.min(sig_arr), np.max(sig_arr)))
                ax = axarr[-1]
                #         ax.semilogy((-truth+0.0001),label='ground truth')
                #         ax.plot(-prediction+0.0001,'g',label='neural net prediction')
                #         ax.axhline(-P_thresh_opt,color='k',label='trigger threshold')
                #         nn = np.min(pred)
                #        ax.plot(xx,truth,'g',label='target',linewidth=2)
                #         ax.axhline(0.4,linestyle="--",color='k',label='threshold')

                ax.plot(
                    xx,
                    prediction,
                    "b",
                    label="RNN output--Disruption score",
                    linewidth=2,
                    zorder=1,
                )
                ax.axhline(P_thresh_opt,
                           linestyle="--",
                           color="k",
                           label="threshold",
                           zorder=2)
                ax.axvline(p0, linestyle="--", color="darkgrey")
                ax.set_ylim([min(prediction), max(prediction)])
                ax.set_yticks([0, 1])
                if p0 > 0:
                    ax.scatter(xx[k],
                               p,
                               s=300,
                               marker="*",
                               color="r",
                               zorder=3)

                ax.axvline(len(truth) - self.T_min_warn,
                           color="r",
                           linewidth=0.5)  # ,label='min warning time')
                ax.set_xlabel("T [ms]", size=fontsize)
                # ax.axvline(2400)
                ax.legend(
                    loc="center left",
                    labelspacing=0.1,
                    bbox_to_anchor=(1, 0.5),
                    fontsize=fontsize + 2,
                    frameon=False,
                )
                plt.setp(ax.get_yticklabels(), fontsize=fontsize)
                plt.setp(ax.get_xticklabels(), fontsize=fontsize)
                # plt.xlim(0,200)
                plt.xlim([lower_lim, len(truth)])
                #         plt.savefig("{}.png".format(num),dpi=200,bbox_inches="tight")
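                # Optionally write the figure to disk together with an .npz archive
                # of the arrays needed to re-plot it later.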
                if save_fig:
                    plt.savefig(
                        "sig_fig_{}{}.png".format(shot.number, extra_filename),
                        bbox_inches="tight",
                    )
                    np.savez(
                        "sig_{}{}.npz".format(shot.number, extra_filename),
                        shot=shot,
                        T_min_warn=self.T_min_warn,
                        T_max_warn=self.T_max_warn,
                        prediction=prediction,
                        truth=truth,
                        use_signals=use_signals,
                        P_thresh=P_thresh_opt,
                    )
                # plt.show()
        else:
            print("Shot hasn't been processed")
Example No. 21
0
    def plot_shot(
        self,
        shot,
        save_fig=True,
        normalize=True,
        truth=None,
        prediction=None,
        P_thresh_opt=None,
        prediction_type="",
        extra_filename="",
    ):
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf["paths"]["normalizer_path"] = self.conf[
                    "paths"]["normalizer_path"]
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

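        # Only shots that were previously preprocessed and saved to disk (with
        # their signal data attached) are restored and plotted.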
        if shot.previously_saved(self.shots_dir):
            shot.restore(self.shots_dir)
            if shot.signals_dict is not None:
                # make sure shot was saved with data
                # t_disrupt = shot.t_disrupt
                # is_disruptive = shot.is_disruptive
                if normalize:
                    self.normalizer.apply(shot)

                use_signals = self.saved_conf["paths"]["use_signals"]
                fontsize = 15
                lower_lim = 0  # len(pred)
                plt.close()
                # colors = ["b", "k"]
                # lss = ["-", "--"]
                f, axarr = plt.subplots(len(use_signals) + 1,
                                        1,
                                        sharex=True,
                                        figsize=(10, 15))
                plt.title(prediction_type)
                assert np.all(shot.ttd.flatten() == truth.flatten())
                xx = range(len(prediction))  # list(reversed(range(len(pred))))
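                # One subplot per input signal: scalar signals are drawn as line
                # plots, multi-channel profile signals as heatmaps over rho.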
                for i, sig in enumerate(use_signals):
                    ax = axarr[i]
                    num_channels = sig.num_channels
                    sig_arr = shot.signals_dict[sig]
                    if num_channels == 1:
                        ax.plot(xx, sig_arr[:, 0], linewidth=2)
                        ax.plot([], linestyle="none", label=sig.description)
                        if np.min(sig_arr[:, 0]) < 0:
                            ax.set_ylim([-6, 6])
                            ax.set_yticks([-5, 0, 5])
                        else:
                            ax.set_ylim([0, 8])
                            ax.set_yticks([0, 5])
                    else:
                        ax.imshow(
                            sig_arr[:, :].T,
                            aspect="auto",
                            label=sig.description,
                            cmap="inferno",
                        )
                        ax.set_ylim([0, num_channels])
                        ax.text(
                            lower_lim + 200,
                            45,
                            sig.description,
                            bbox={
                                "facecolor": "white",
                                "pad": 10
                            },
                            fontsize=fontsize - 5,
                        )
                        ax.set_yticks([0, num_channels / 2])
                        ax.set_yticklabels(["0", "0.5"])
                        ax.set_ylabel("$\\rho$", size=fontsize)
                    ax.legend(loc="best",
                              labelspacing=0.1,
                              fontsize=fontsize,
                              frameon=False)
                    ax.axvline(len(truth) - self.T_min_warn,
                               color="r",
                               linewidth=0.5)
                    plt.setp(ax.get_xticklabels(), visible=False)
                    plt.setp(ax.get_yticklabels(), fontsize=fontsize)
                    f.subplots_adjust(hspace=0)
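                # Bottom panel: ground-truth target, RNN output, and the alarm
                # threshold, with a red line at the minimum warning time.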
                ax = axarr[-1]
                # ax.semilogy((-truth+0.0001),label='ground truth')
                # ax.plot(-prediction+0.0001,'g',label='neural net prediction')
                # ax.axhline(-P_thresh_opt,color='k',label='trigger threshold')
                # nn = np.min(pred)
                ax.plot(xx, truth, "g", label="target", linewidth=2)
                # ax.axhline(0.4,linestyle="--",color='k',label='threshold')
                ax.plot(xx, prediction, "b", label="RNN output", linewidth=2)
                ax.axhline(P_thresh_opt,
                           linestyle="--",
                           color="k",
                           label="threshold")
                ax.set_ylim([-2, 2])
                ax.set_yticks([-1, 0, 1])
                # if len(truth)-T_max_warn >= 0:
                # ax.axvline(len(truth)-T_max_warn,color='r')#,label='max
                # warning time')
                # ,label='min warning time')
                ax.axvline(len(truth) - self.T_min_warn,
                           color="r",
                           linewidth=0.5)
                ax.set_xlabel("T [ms]", size=fontsize)
                # ax.axvline(2400)
                ax.legend(
                    loc=(0.5, 0.7),
                    fontsize=fontsize - 5,
                    labelspacing=0.1,
                    frameon=False,
                )
                plt.setp(ax.get_yticklabels(), fontsize=fontsize)
                plt.setp(ax.get_xticklabels(), fontsize=fontsize)
                # plt.xlim(0,200)
                plt.xlim([lower_lim, len(truth)])
                #         plt.savefig("{}.png".format(num),dpi=200,bbox_inches="tight")
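                # As above, optionally save the figure and an .npz bundle of the
                # arrays used to produce it.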
                if save_fig:
                    plt.savefig("sig_fig_{}{}.png".format(
                        shot.number, extra_filename)),
                    #        bbox_inches='tight')
                    np.savez(
                        "sig_{}{}.npz".format(shot.number, extra_filename),
                        shot=shot,
                        T_min_warn=self.T_min_warn,
                        T_max_warn=self.T_max_warn,
                        prediction=prediction,
                        truth=truth,
                        use_signals=use_signals,
                        P_thresh=P_thresh_opt,
                    )
                # plt.show()
        else:
            print("Shot hasn't been processed")