def download(self, item):
    """Download the stream behind ``item`` into the configured folder.

    Shows a dialog and returns when no downloads folder is configured.
    Returns without doing anything else when the item cannot be resolved.
    Subtitles, when the resolved stream has any, are saved next to the
    video as ``<title>.srt``.

    :param item: mapping with at least the ``'url'`` and ``'title'`` keys
    """
    target_dir = self.settings['downloads']
    if target_dir == '':
        xbmcgui.Dialog().ok(self.provider.name, xbmcutil.__lang__(30009))
        return

    stream = self.resolve(item['url'])
    if not stream:
        return

    if 'headers' not in stream:
        stream['headers'] = {}
    xbmcutil.reportUsage(self.addon_id, self.addon_id + '/download')

    # Path separators in the title would otherwise create sub-directories.
    name = item['title'].replace('/', '_').replace('\\', '_')

    subtitles = stream['subs']
    if subtitles != '':
        subtitle_path = os.path.join(target_dir, name + '.srt')
        util.save_to_file(subtitles, subtitle_path, stream['headers'])

    # No dot (or only a leading one) means there is no extension yet.
    if name.find('.') < 1:
        name += '.mp4'

    source_url = self.provider._url(stream['url'])
    xbmcutil.download(self.addon,
                      name,
                      source_url,
                      os.path.join(target_dir, name),
                      headers=stream['headers'])
Пример #2
0
def get_address(latitude, longitude, googlemaps_api_key, force=True, no_cache=False):
    """Return a formatted address for a latitude/longitude pair.

    The module-level ``addresses`` cache is consulted first. The Google Maps
    Geocoding web service is queried only on a cache miss and only when
    ``no_cache`` is true. A successful answer is stored in ``addresses``, and
    the cache is written to ``address.json`` whenever its size is a multiple
    of 25.

    :param latitude: latitude of the point
    :param longitude: longitude of the point
    :param googlemaps_api_key: value sent as the ``key`` query parameter
    :param force: fallback selector when no address is available; true
        returns ``'[lat,lng]'``, false returns an empty string
    :param no_cache: when true, query the web service on a cache miss
    :return: the formatted address, or the fallback chosen by ``force``
    """
    latlng = '{0},{1}'.format(latitude, longitude)
    if latlng in addresses:
        return addresses[latlng]['results'][0]['formatted_address']
    if no_cache:
        try:
            url = "https://maps.googleapis.com/maps/api/geocode/json?latlng={lat},{lng}&key={api_key}".format(
                lat=latitude,
                lng=longitude,
                api_key=googlemaps_api_key
            )
            log_google.info('Buscando endereco {0},{1} (Google Maps API)...'.format(latitude, longitude))
            raw_data = requests.get(url)
            data = json.loads(raw_data.text)
            log_google.info('Endereco obtido.')
            # The key can be present with an empty list (no match); treat
            # that as "no address" instead of raising IndexError.
            results = data.get('results')
            if results:
                formatted_address = results[0]['formatted_address']
                addresses[latlng] = data
                if len(addresses) % 25 == 0:
                    save_to_file('address.json', addresses)
                return formatted_address
        except Exception:
            # Best effort: any lookup failure falls through to the fallback.
            # ``Exception`` (not a bare except) lets Ctrl-C / SystemExit pass.
            traceback.print_exc()
    if force:
        return '[{0},{1}]'.format(latitude, longitude)
    else:
        return ''
Пример #3
0
 def save(self):
     """Write ``self.json_obj`` to ``self.fpath`` as UTF-8 encoded JSON.

     The values stored under ``jkey_quetions`` and ``jkey_images`` are
     converted to lists before the object is serialised.
     """
     for key in (jkey_quetions, jkey_images):
         self.json_obj[key] = list(self.json_obj[key])

     serialised = json.dumps(self.json_obj,
                             sort_keys=True,
                             indent=4,
                             ensure_ascii=False)
     # The return value of luni() is not used here.
     util.luni(serialised)
     encoded = serialised.encode('utf-8')
     util.save_to_file(self.fpath, encoded)
Пример #4
0
def main():
    """Build one stereogram per entry of ``FIGURES`` and save it.

    For figure ``fig`` the depth map is ``b_<fig>.png`` in ``BGFOLDER``, the
    pattern is ``f_<fig>.png`` in ``FGFOLDER`` and the result is written to
    ``<fig>.png`` in ``OUTFOLDER``. The colour comes from ``COLORMAP`` at the
    figure's index.
    """
    for index, figure in enumerate(FIGURES):
        background = os.path.join(BGFOLDER, 'b_{}.png'.format(figure))
        foreground = os.path.join(FGFOLDER, 'f_{}.png'.format(figure))
        destination = os.path.join(OUTFOLDER, '{}.png'.format(figure))

        stereogram = make_stereogram(background, foreground, COLORMAP[index])
        util.save_to_file(stereogram, destination)
Пример #5
0
 def download(self,url,name):
     """Resolve ``url`` and download it as ``name`` into the downloads folder.

     Shows a dialog and returns when no downloads folder is configured.
     Returns silently when the URL cannot be resolved. Subtitles of the
     resolved stream, if any, are saved as ``<name>.srt``.
     """
     target_dir = self.settings['downloads']
     if target_dir == '':
         xbmcgui.Dialog().ok(self.provider.name, xbmcutil.__lang__(30009))
         return

     stream = self.resolve(url)
     if not stream:
         return

     xbmcutil.reportUsage(self.addon_id, self.addon_id + '/download')
     subtitles = stream['subs']
     if subtitles != '':
         util.save_to_file(subtitles, os.path.join(target_dir, name + '.srt'))
     xbmcutil.download(self.addon,
                       name,
                       self.provider._url(stream['url']),
                       os.path.join(target_dir, name))
Пример #6
0
def get_processed_tweets(inputfilepath, outputfilepath = None, format = "JSON"):
    """Load tweets from ``inputfilepath`` and optionally save them.

    :param inputfilepath: path of the raw tweet file
    :param outputfilepath: when truthy, the loaded tweets are also written
        there through ``util.save_to_file``
    :param format: exactly ``"xml"`` selects ``readxml``; every other value
        (including the default) selects ``readjson``
    :return: whatever the selected reader returned
    """
    reader = readxml if format == "xml" else readjson
    loaded = reader(inputfilepath)

    if outputfilepath:
        util.save_to_file(loaded, outputfilepath)
    return loaded
Пример #7
0
 def ask_for_captcha(self,params):
     """Show the captcha image to the user and return the typed code.

     ``params['img']`` is saved as ``captcha.png`` in the add-on profile
     folder, displayed in a ``CaptchaDialog``, and then a keyboard prompt is
     opened. Returns the entered text, or ``None`` when the keyboard prompt
     was not confirmed.
     """
     profile_dir = unicode(xbmc.translatePath(self.addon.getAddonInfo('profile')))
     image_path = os.path.join(profile_dir, 'captcha.png')
     util.save_to_file(params['img'], image_path)

     dialog = CaptchaDialog('captcha-dialog.xml',
                            xbmcutil.__addon__.getAddonInfo('path'),
                            'default',
                            '0')
     dialog.image = image_path
     xbmc.sleep(3000)
     dialog.doModal()
     del dialog

     keyboard = xbmc.Keyboard('', self.addon.getLocalizedString(200), False)
     keyboard.doModal()
     if not keyboard.isConfirmed():
         return None
     print('got code ' + keyboard.getText())
     return keyboard.getText()
 def ask_for_captcha(self, params):
     """Ask the user to solve a captcha manually.

     Saves ``params['img']`` to ``captcha.png`` inside the add-on profile
     directory, shows it in a ``CaptchaDialog`` and then prompts with the
     on-screen keyboard.

     :return: the text typed by the user, or ``None`` if the keyboard was
         not confirmed
     """
     addon_profile = self.addon.getAddonInfo('profile')
     captcha_file = os.path.join(unicode(xbmc.translatePath(addon_profile)),
                                 'captcha.png')
     util.save_to_file(params['img'], captcha_file)

     addon_path = xbmcutil.__addon__.getAddonInfo('path')
     captcha_dialog = CaptchaDialog('captcha-dialog.xml', addon_path,
                                    'default', '0')
     captcha_dialog.image = captcha_file
     xbmc.sleep(3000)
     captcha_dialog.doModal()
     del captcha_dialog

     prompt = xbmc.Keyboard('', self.addon.getLocalizedString(200), False)
     prompt.doModal()
     if prompt.isConfirmed():
         print('got code ' + prompt.getText())
         return prompt.getText()
     return None
Пример #9
0
def build_word_map():
    """Build, save and return a word -> index map for the training trees.

    Every line of ``trees/train.txt`` is lower-cased and parsed with
    ``ParentedTree.fromstring``. Each distinct leaf token gets an index in
    first-seen order, and ``UNK`` is assigned the index equal to the number
    of distinct tokens. The map is saved to ``WORD_MAP_FILENAME``.
    """
    print("Building word map...")
    with open("trees/train.txt", "r") as handle:
        parsed = [ParentedTree.fromstring(row.lower()) for row in handle]

    print("Counting words...")
    frequencies = defaultdict(int)
    for parsed_tree in parsed:
        for leaf in parsed_tree.leaves():
            frequencies[leaf] += 1

    word_map = {word: index for index, word in enumerate(frequencies)}
    # The unknown-word token takes the next free index.
    word_map[UNK] = len(frequencies)
    util.save_to_file(word_map, WORD_MAP_FILENAME)
    return word_map
 def solve_captcha(self,params):
     """Try to solve the audio captcha remotely, else ask the user.

     ``params['snd']`` is saved as ``sound.wav`` in the add-on profile folder
     and its bytes are POSTed to the solving service. When the request fails
     or the service answers with an empty body, the call falls back to
     ``ask_for_captcha``.

     :param params: mapping with the ``'snd'`` entry (plus whatever
         ``ask_for_captcha`` needs for the fallback)
     :return: the response body, or the result of ``ask_for_captcha``
     """
     snd = os.path.join(unicode(xbmc.translatePath(self.addon.getAddonInfo('profile'))),'sound.wav')
     util.save_to_file(params['snd'], snd)
     data = ''
     try:
         # Context manager so the file handle is released deterministically.
         with open(snd, 'rb') as sound_file:
             sndfile = sound_file.read()
         url = 'http://m217-io.appspot.com/ulozto'
         headers = {'Content-Type': 'audio/wav'}
         req = urllib2.Request(url, sndfile, headers)
         response = urllib2.urlopen(req)
         try:
             data = response.read()
         finally:
             response.close()
     except urllib2.URLError:
         # URLError is the base class of HTTPError, so this also covers
         # connection-level failures (DNS, refused connection, ...), which
         # should fall back to manual entry rather than crash.
         traceback.print_exc()
         data = ''
     if not data:
         return self.ask_for_captcha(params)
     return data
Пример #11
0
 def download(self,item):
     """Download the stream behind ``item`` into the configured folder.

     Shows a dialog and returns when no downloads folder is configured.
     Returns silently when the item cannot be resolved. Subtitles of the
     resolved stream, if any, are saved as ``<title>.srt``.

     :param item: mapping with at least the ``'url'`` and ``'title'`` keys
     """
     target_dir = self.settings['downloads']
     if target_dir == '':
         xbmcgui.Dialog().ok(self.provider.name, xbmcutil.__lang__(30009))
         return

     stream = self.resolve(item['url'])
     if not stream:
         return

     if 'headers' not in stream:
         stream['headers'] = {}
     xbmcutil.reportUsage(self.addon_id, self.addon_id + '/download')

     # Path separators in the title would otherwise create sub-directories.
     name = item['title'].replace('/', '_').replace('\\', '_')

     subtitles = stream['subs']
     if subtitles != '':
         util.save_to_file(subtitles, os.path.join(target_dir, name + '.srt'))

     # No dot (or only a leading one) means there is no extension yet.
     if name.find('.') < 1:
         name += '.mp4'

     xbmcutil.download(self.addon,
                       name,
                       self.provider._url(stream['url']),
                       os.path.join(target_dir, name),
                       headers=stream['headers'])
Пример #12
0
def main(infiles, outfile, outformat, outdir, clobber, headers, recursive,
         minimise):
    """Convert the given .ibw inputs to ``outformat``.

    Exits with status 1 when no input is given, when no .ibw file is found,
    or when ``outfile`` is set although more than one input file was found.

    :param infiles: input paths given by the caller
    :param outfile: explicit output file name (single input only)
    :param outformat: ``"dump"`` prints each file to stdout; any other value
        is passed on to ``get_outpath``
    :param outdir: output directory, passed on to ``get_outpath``
    :param clobber: when true, files are opened with mode ``"w"``, else ``"x"``
    :param headers: forwarded to ``util.save_to_file`` as ``csv_headers``
    :param recursive: when true, directories among ``infiles`` are expanded
        with ``recurse_subdirs``
    :param minimise: forwarded to ``util.save_to_file`` as ``json_mini``
    """
    if not infiles:
        click.secho("No input files specified. Aborting.", fg="red")
        sys.exit(1)

    # Get/set writemode (x => create or ask, w => overwrite)
    writemode = "w" if clobber else "x"

    # Get the full list of input files
    if recursive:
        _infiles = list(filter(is_ibw, infiles))
        inpaths = filter(os.path.isdir, infiles)
        infiles = _infiles + recurse_subdirs(inpaths)
    else:
        infiles = util.flatten(map(list_ibw, infiles))

    # Check for errors
    # Compare by value: ``is 0`` tests object identity and only happens to
    # work because CPython caches small ints.
    if len(infiles) == 0:
        click.secho("No .ibw files found", fg="red")
        sys.exit(1)
    if len(infiles) > 1 and outfile:
        click.secho("Output filename cannot be " +
                    "specified for multiple input files",
                    fg="red")
        sys.exit(1)

    # Iterate through input files and do action
    if outformat == "dump":
        for infile in infiles:
            extractors.ibw2stdout(infile)  # prints to stdout
    else:
        with click.progressbar(infiles, width=0) as bar:
            for infile in bar:
                outpath = get_outpath(infile, outfile, outformat, outdir)
                data = extractors.ibw2dict(infile)
                util.save_to_file(data,
                                  outpath,
                                  mode=writemode,
                                  csv_headers=headers,
                                  json_mini=minimise)
Пример #13
0
def set_apache_vhost_directive(
    vhost,
    server_name,
    directive,
    new_value,
    config_files=get_available_apache_config_files()):
    """
    Update/add the directive in vhost config

    An existing directive line is replaced in place; a missing one is
    inserted just before the VHost's end line. The file is then written back
    through ``util.save_to_file`` with ``backup=True``.

    :param vhost: string representing a vhost
    :param server_name: The name of the virtual host (in case of an IP-base VHost this should be "None")
    :param directive: the name of the directive to be written or added
    :param new_value: the value of the directive to be written or added
    :param config_files: configuration files to search. Note that the default
        is evaluated once, when this function is defined, not on every call.
    :raises Exception: if the VHost is not found, or if the directive occurs
        more than once in it
    """
    vhost_config = get_vhost_config(config_files, vhost, server_name)
    if vhost_config is None:
        raise Exception('No VHost "{}" found'.format(vhost))

    new_line_content = '\t\t{} {}\n'.format(directive, new_value)
    directive_elements = vhost_config.xpath(directive)
    if len(directive_elements) > 1:
        # Separate name so the ``config_files`` parameter is not shadowed.
        offending_files = {
            d.xpath('ancestor::ConfigFile/Path')[0].text
            for d in directive_elements
        }
        raise Exception(
            'Expected at most 1 occurrence of directive "{}" in VHost "{}", found {} occurrences in configuration file(s) {}.'
            .format(directive, vhost, len(directive_elements),
                    ','.join(offending_files)))

    config_path = vhost_config.xpath('ancestor::ConfigFile/Path')[0].text
    # Close the handle before the same path is rewritten below.
    with open(config_path, 'r') as config_file:
        config_content = config_file.readlines()

    if len(directive_elements) == 0:
        # StartLine/EndLine are used as 1-based line numbers, hence the -1.
        config_content.insert(
            int(vhost_config.find('EndLine').text) - 1, new_line_content)
    else:
        directive_element = directive_elements[0]
        config_content[int(directive_element.find('StartLine').text) -
                       1] = new_line_content
    util.save_to_file("".join(config_content), config_path, backup=True)
Пример #14
0
def get_url(url, google_api_key, force=True):
    """Return a shortened version of ``url``, using a cache when possible.

    The module-level ``urls`` cache is consulted first. On a miss, and only
    when both ``google_api_key`` and ``force`` are truthy, the Google URL
    shortener endpoint is called. A successful answer is cached and the
    cache is written to ``urls.json``.

    :param url: the long URL
    :param google_api_key: key for the shortener API; falsy skips the request
    :param force: true returns the original ``url`` when no short URL is
        available; false returns an empty string and skips the request
    :return: the short URL, or the fallback chosen by ``force``
    """
    global urls
    if url in urls:
        return urls[url]

    if google_api_key and force:
        log_google.info('usando Google API Shortner ')
        try:
            # NOTE(review): confirm this endpoint is still served; every
            # failure simply ends in the fallback below.
            url_api = "https://www.googleapis.com/urlshortener/v1/url?key={key}".format(key=google_api_key)
            post = {'longUrl': url}
            raw_data = requests.post(url_api, json=post)
            data = json.loads(raw_data.text)
            log_google.info(data)
            urls[url] = data['id']
            save_to_file('urls.json', urls)
            return data['id']
        except Exception:
            # Best effort: ``Exception`` (not a bare except) so that Ctrl-C
            # and SystemExit are not swallowed.
            traceback.print_exc()
    if force:
        return url
    else:
        return ''
Пример #15
0
 def solve_captcha(self, params):
     """Try to solve the audio captcha remotely, else ask the user.

     ``params['snd']`` is saved as ``sound.wav`` in the add-on profile folder
     and its bytes are POSTed to the solving service, using the
     ``ulozto_captcha_timeout`` setting as the timeout. On an HTTP error, a
     socket timeout or an empty answer the call falls back to
     ``ask_for_captcha``.

     :param params: mapping with the ``'snd'`` entry (plus whatever
         ``ask_for_captcha`` needs for the fallback)
     :return: the response body, or the result of ``ask_for_captcha``
     """
     snd = os.path.join(
         unicode(xbmc.translatePath(self.addon.getAddonInfo('profile'))),
         'sound.wav')
     util.save_to_file(params['snd'], snd)
     # Context manager so the file handle is released deterministically.
     with open(snd, 'rb') as sound_file:
         sndfile = sound_file.read()
     url = 'http://m217-io.appspot.com/ulozto'
     headers = {'Content-Type': 'audio/wav'}
     req = Request(url, sndfile, headers)
     response = None
     try:
         response = urlopen(req,
                            timeout=int(
                                __settings__("ulozto_captcha_timeout")))
         data = response.read()
     except (HTTPError, socket.timeout):
         traceback.print_exc()
         data = ''
     finally:
         # Explicit None check: do not rely on the response's truthiness.
         if response is not None:
             response.close()
     if not data:
         return self.ask_for_captcha(params)
     return data
Пример #16
0
def createNewTask():
    """Create a new task

       1. Interact with user to grab the task content
       2. Schedule the task
       3. Put the task record in file.
    """
    #TODO: Now the task content is not allowed to contain any comma
    # since I'm using comma as delimiter in schdule_task().
    # This needs to be improved.
    task_content = raw_input('Please input task:')
    options = {
        1: '2 Hours later',
        2: 'Tomorrow Evening',
        3: 'Tomorrow Morning',
        4: 'Tomorrow Afternoon'
    }

    print('Please select task scheduled time:')
    # Sorted so the menu order does not depend on dict iteration order.
    for i, option in sorted(options.items()):
        print('{}. {}'.format(i, option))
    while True:
        try:
            user_option = int(raw_input('Your Option:'))
        except ValueError:
            # Non-numeric input: ask again instead of crashing.
            print('Invalid Option')
            continue
        #TODO: Rerange schedule time from 1-4
        # NOTE(review): the menu shows 1-4 but 0-3 is what is accepted here;
        # left unchanged because scheduleTask() receives this value as is.
        if not 0 <= user_option < 4:
            print('Invalid Option')
            continue
        break

    task_id, schedule_time = scheduleTask(user_option, task_content)

    #TODO: Create an LaterTask object and put it in the queue.
    print('[{}] will be scheduled at {}'.format(task_content, schedule_time))
    task_list = load_from_file(TASK_LIST)
    newTask = LaterTask(task_id, schedule_time, task_content)
    task_list.append(newTask)
    save_to_file(task_list, TASK_LIST)
Пример #17
0
 def download(self,url,name):
     """Resolve ``url`` and download it as ``name`` into the downloads folder.

     Shows a dialog and returns when no downloads folder is configured.
     Returns silently when the URL cannot be resolved. Subtitles of the
     resolved stream, if any, are saved as ``<name>.srt``.

     Streams whose URL contains ``munkvideo`` (not at position 0) are fetched
     here in 8 KiB chunks with a ``Referer`` header and stored as
     ``<name>.mp4``. Everything else is handed to ``xbmcutil.download``.
     """
     downloads = self.settings['downloads']
     if '' == downloads:
         xbmcgui.Dialog().ok(self.provider.name,xbmcutil.__lang__(30009))
         return
     stream = self.resolve(url)
     if stream:
         xbmcutil.reportUsage(self.addon_id,self.addon_id+'/download')
         if not stream['subs'] == '':
             util.save_to_file(stream['subs'],os.path.join(downloads,name+'.srt'))
         if stream['url'].find('munkvideo') > 0:
             # we have to handle this download a special way
             filename = xbmc.makeLegalFilename(os.path.join(downloads,name+'.mp4'))
             icon = os.path.join(__addon__.getAddonInfo('path'),'icon.png')
             output = open(filename,'wb')
             response = None
             try:
                 req = urllib2.Request(stream['url'],headers={'Referer':'me'}) # that special way
                 response = urllib2.urlopen(req)
                 data = response.read(8192)
                 xbmc.executebuiltin('XBMC.Notification(%s,%s,3000,%s)' % (xbmc.getLocalizedString(13413).encode('utf-8'),filename,icon))
                 while len(data) > 0:
                     output.write(data)
                     data = response.read(8192)
                 # Close before notifying so the file is complete on disk
                 # by the time the user sees the message.
                 response.close()
                 output.close()
                 if xbmc.Player().isPlaying():
                     xbmc.executebuiltin('XBMC.Notification(%s,%s,8000,%s)' % (xbmc.getLocalizedString(20177),filename,icon))
                 else:
                     xbmcgui.Dialog().ok(xbmc.getLocalizedString(20177),filename)
             except Exception:
                 # ``Exception`` (not a bare except) so that Ctrl-C and
                 # SystemExit are not swallowed.
                 traceback.print_exc()
                 xbmc.executebuiltin('XBMC.Notification(%s,%s,5000,%s)' % (xbmc.getLocalizedString(257),filename,icon))
                 xbmcgui.Dialog().ok(filename,xbmc.getLocalizedString(257))
             finally:
                 # Release both handles on every path; closing twice is
                 # harmless after the success path above.
                 if response is not None:
                     response.close()
                 output.close()
         else:
             xbmcutil.download(self.addon,name,self.provider._url(stream['url']),os.path.join(downloads,name))
Пример #18
0
    def fetch(self,
              start_page=1,
              end_page=1000,
              verbose=False,
              save_page_to_file=False,
              save_path=DIR_PATH + 'lib/pickled/'):
        """Fetch killboard pages and insert their events into the database.

        Pages ``start_page`` .. ``end_page`` (inclusive) are requested one by
        one through ``self.kb.get``; the first empty page stops the loop.

        :param start_page: first page number to request
        :param end_page: last page number to request
        :param verbose: print progress messages
        :param save_page_to_file: also save every page via
            ``util.save_to_file`` as ``<save_path><page>.p``
        :param save_path: prefix of the saved page files
        """
        # Neither counter is read anywhere in this method.
        count = 0
        inserted = 0
        for page_no in range(start_page, end_page + 1):
            page = self.kb.get('page/%d/' % page_no)
            if not len(page):
                if verbose:
                    print('Found empty page... stopping fetch')
                break

            if save_page_to_file:
                util.save_to_file(page, save_path + '%d.p' % page_no)

            committed = 0
            for killmail in page:
                was_new = self.kbdb.insert_kill(killmail)
                if was_new:
                    committed += 1
                count += 1

            if verbose:
                summary = 'Finished page %d | committed %d new killmails to db'
                print(summary % (page_no, committed))
def install_apache_ssl_cert(pem_cert_key_path, site, restart_apache=False):
    """Install the certificate and key from a PEM file for one Apache VHost.

    Certificates and keys are parsed from ``pem_cert_key_path``; CAs come
    from ``util.parse_cas``. When the VHost uses one file for both the
    certificate and the key, everything is written to that file. Otherwise
    certificates plus CAs go to the certificate file and the keys go to the
    key file.

    :param pem_cert_key_path: PEM file holding the certificate(s) and key(s)
    :param site: mapping with the ``'VHost'`` and ``'ServerName'`` entries
    :param restart_apache: call ``reload_apache()`` afterwards when true
    :raises Exception: if the PEM file has no certificates or no keys
    """
    vhost = site['VHost']
    Logger.info(
        'Installing SSL certificate for virtual host at {VHost}'.format(
            **site))

    server_name = site['ServerName']
    cert_target = apache_util.get_apache_ssl_cert_path(vhost, server_name)
    key_target = apache_util.get_apache_ssl_key_path(vhost, server_name)

    certs = util.parse_certs(pem_cert_key_path, Logger)
    if not certs:
        message = "No X.509 certs found in {} received by KeyTalk client"
        raise Exception(message.format(pem_cert_key_path))

    keys = util.parse_keys(pem_cert_key_path, Logger)
    if not keys:
        message = "No X.509 keys found in {} received by KeyTalk client"
        raise Exception(message.format(pem_cert_key_path))

    cas = util.parse_cas(Logger)
    ca_count = len(cas)

    if util.same_file(cert_target, key_target):
        # Single combined file: certificates, keys and CAs together.
        Logger.debug("Saving SSL certificate with key and {} CAs to {}".format(
            ca_count, cert_target))
        util.save_to_file('\n'.join(certs + keys + cas), cert_target)
    else:
        first_cert = OpenSSL.crypto.load_certificate(
            OpenSSL.crypto.FILETYPE_PEM, certs[0])
        Logger.debug(
            "Saving SSL certificates (serial: {}) and {} CAs to {}".format(
                first_cert.get_serial_number(), ca_count, cert_target))
        util.save_to_file('\n'.join(certs + cas), cert_target)
        Logger.debug("Saving SSL key to " + key_target)
        util.save_to_file('\n'.join(keys), key_target)

    # ask Apache to gracefully reload key material
    if restart_apache:
        reload_apache()
Пример #20
0
def filter_tweets(tweets,
                  outputfilepath=None,
                  art=False,
                  frequency=False,
                  terms_to_remove=None):
    """
    Clean, de-duplicate and filter tweets.

    The ``text`` attribute of the given tweet objects is rewritten in place;
    the list passed in is not modified.

    :param tweets: list of Entity.Tweet, required

    :param outputfilepath: string, optional, default: None
        Location path for saving the filtered tweets

    :param art: boolean, optional, default: False
        Remove articles, pronouns and prepositions

    :param frequency: boolean, optional, default: False,
        Remove less used words, i.e. words whose share of all word
        occurrences is below MINIMUM_FREQUENCY

    :param terms_to_remove: list of string, optional, default: None
        List of terms to remove from each tweet

    :return tweets: list of tweets cleaned and filtered

    :raises Exception: if no tweet remains after filtering
    """

    if terms_to_remove:  # Remove search terms
        for tweet in tweets:
            tweet.text = remove_search_terms(tweet.text.lower(),
                                             terms_to_remove)

    # Clean tweets
    for tweet in tweets:
        tweet.text = clean(tweet.text.lower())

    # Remove duplicated tweets, keeping the first occurrence
    seen = set()
    deduped = []
    for tweet in tweets:
        if tweet.text not in seen:
            seen.add(tweet.text)
            deduped.append(tweet)
    tweets = deduped

    if art:  # Remove articles, pronouns and prepositions
        for tweet in tweets:
            tweet.text = " ".join(removerArtProPre(tweet.text))

    if frequency:  # Remove less used words
        # Count whitespace-separated words. Iterating the text directly
        # would count single characters.
        counts = defaultdict(int)
        for tweet in tweets:
            for word in tweet.text.split():
                counts[word] += 1
        # float() keeps the ratio a true division on Python 2 as well.
        total = float(sum(counts.values()))

        # Keep every word at or above the threshold. Computing the keep-set
        # directly also covers the case where no word is below it.
        frequent = set(word for word, occurrences in counts.items()
                       if occurrences / total >= MINIMUM_FREQUENCY)

        for tweet in tweets:
            tweet.text = ' '.join(
                word for word in tweet.text.split() if word in frequent)

    # Remove empty tweets
    tweets = [tweet for tweet in tweets if len(tweet.text) != 0]

    if not tweets:
        raise Exception('There is no remaining tweet after filtering')

    if outputfilepath:
        util.save_to_file(tweets, outputfilepath)

    return tweets
Пример #21
0
 def __to_file(self):
     """Write ``self.data`` to ``self.file`` through ``util.save_to_file``."""
     payload, destination = self.data, self.file
     util.save_to_file(payload, destination)
Пример #22
0
def save_csv_from_ratings(csv, user_name, out_file_name=CSV_RATINGS):
    """Extract a user's ratings from ``csv`` and save them.

    :param csv: source passed on to ``get_ratings_from_csv_file``
    :param user_name: user whose ratings are extracted
    :param out_file_name: destination handed to ``save_to_file``
    :return: list with the first element of every rating entry
    """
    message = "Create CSV file {} for user {}."
    log(message.format(out_file_name, user_name))

    user_ratings = get_ratings_from_csv_file(csv, user_name)
    save_to_file(user_ratings, out_file_name)

    first_columns = []
    for entry in user_ratings:
        first_columns.append(entry[0])
    return first_columns
Пример #23
0
def save_csv_from_watchlist(csv, user_name, out_file_name=CSV_WATCHLIST):
    """Extract a user's watchlist from ``csv`` and save it.

    :param csv: source passed on to ``get_watchlist_from_csv_file``
    :param user_name: user whose watchlist is extracted
    :param out_file_name: destination handed to ``save_to_file``
    :return: list with the first element of every watchlist entry
    """
    message = "Create watchlist file {} for user {}."
    log(message.format(out_file_name, user_name))

    entries = get_watchlist_from_csv_file(csv, user_name)
    save_to_file(entries, out_file_name)

    first_columns = []
    for entry in entries:
        first_columns.append(entry[0])
    return first_columns
Пример #24
0
    # X_train, y_train = RandomOverSampler('minority').fit_sample(X_train, y_train)
    # print(X_train.shape, y_train.shape)

    # X_train, y_train = SMOTE().fit_sample(X_train, y_train)

    # fs_model = feature_selection(X_train, y_train)
    # X_train = fs_model.transform(X_train)
    # X_test = fs_model.transform(X_test)

    if train_flag:
        models, names = get_models()
        estimators = train_predict(models, names, X_train, y_train, X_test,
                                   y_test)
        for estimator, name in zip(estimators, names):
            util.save_to_file(predicted[['user_id', 'brand_id']],
                              estimator.predict(predicted.drop(cols, axis=1)),
                              '_'.join(['1452983', '2cii', name]) + '.txt')

    else:
        model_params = {
            'LogisticRegression': (
                LogisticRegression(penalty='l2'),
                {
                    # 'penalty':['l2', 'l1'], l2没用
                    # 'C': np.arange(0.2, 0.5, 0.1)
                    'C': np.arange(0.2, 0.5, 0.1)
                }),
            'DecisionTreeClassifier': (DecisionTreeClassifier(), {
                'class_weight': ['balanced', None]
            }),
            'KNeighborsClassifier': (KNeighborsClassifier(), {
Пример #25
0
def train(models):
    """Cross-validate, fit and evaluate every model, then save the results.

    For each ``(name, model)`` pair the model is cross-validated on the
    training data, fitted, and used to predict the test data. The test
    predictions are written through ``util.save_to_file`` and a summary table
    of all models is written as CSV below ``params.output``.

    A model that raises anywhere inside the ``try`` block is reported with
    ``print`` and skipped.

    :param models: iterable of ``(name, model)`` pairs
    """
    # Months 2-4 are requested for training, month 5 for testing.
    train, test = get_data(params.ci_train_file_name, params.ci_test_file_name, train_month=[2, 3, 4],
                                test_month=[5])
    print(test['target'].value_counts())
    print(train['target'].value_counts())
    # train = util.get_undersample_data2(train)
    # Only the training frame loses 'user_id'; the test frame keeps it and
    # it is excluded again at prediction time below.
    train.drop('user_id', inplace=True, axis=1)
    (X_train, y_train), (X_test, y_test) = util.get_X_y(train), util.get_X_y(test)

    print(test['target'].value_counts())
    print(train['target'].value_counts())

    # Parallel per-model lists; they become the columns of the summary table.
    results = []
    names = []
    elapsed = []
    auc = []
    precision = []
    corrects, errors=[], []
    corrects_value_counts, errors_value_counts=[],[]

    scoring = {'recall': 'recall', 'accuracy': 'accuracy'}  # , 'auc': 'roc_auc'
    for name, model in models:
        # NOTE(review): recent scikit-learn versions reject random_state when
        # shuffle is False -- confirm against the pinned version. This call
        # is outside the try block, so such an error would not be swallowed.
        kfold = KFold(n_splits=3, random_state=78)
        print('=' * 30)
        try:
            cv_results = cross_validate(model, X_train, y_train, cv=kfold, scoring=scoring)

            for score in scoring:
                msg = "%s: %s mean:%f std:(%f)" % (
                name, score, cv_results['test_' + score].mean(), cv_results['test_' + score].std())
                print(msg)

            # The timer covers fit, predict and both test metrics, but not
            # the cross-validation above.
            start_time = time.time()

            model.fit(X_train, y_train)
            test_pred = model.predict(X_test.loc[:, X_test.columns != 'user_id'])
            test_prec = precision_score(y_test, test_pred, average='micro')

            test_auc = roc_auc_score(y_test, test_pred)
            print('test precision: %f roc_auc: %f' % (test_prec, test_auc))

            elapsed_time = time.time() - start_time

            # These five lists are extended before the correct/error
            # statistics are computed; an exception further down leaves them
            # longer than the four lists appended to below.
            results.append(cv_results)
            auc.append(test_auc)
            precision.append(test_prec)
            elapsed.append(elapsed_time)
            names.append(name)

            # Positions where the prediction matches / differs from the target.
            correct = np.where(np.array(y_test['target'].tolist()) == test_pred)[0]
            error = np.where(np.array(y_test['target'].tolist()) != test_pred)[0]

            corrects.append(len(correct))
            errors.append(len(error))
            corrects_value_counts.append(y_test['target'].iloc[correct].value_counts().to_dict())
            errors_value_counts.append(y_test['target'].iloc[error].value_counts().to_dict())

            print('correct: {} errors: {}'.format(len(correct), len(error)))
            print('predict correct value counts:')
            print(y_test['target'].iloc[correct].value_counts())
            print('predict error value counts:')
            print(y_test['target'].iloc[error].value_counts())


            util.save_to_file(X_test[['user_id']], test_pred,
                               '_'.join(['1452983', '2ci', name]) + '.txt')

        except Exception as ex:
            # Best effort: report the failure and continue with the next model.
            print(ex)

    # One row per model (after the transpose), sorted by AUC, best first.
    statics = pd.DataFrame([names, auc, precision, elapsed, corrects, errors, corrects_value_counts, errors_value_counts]).T
    statics.columns=['name', 'auc', 'precision', 'time', 'correct', 'error', 'corrects_value_counts', 'errors_value_counts']
    statics.sort_values(by=['auc'], ascending=False, inplace=True)
    statics.to_csv(params.output + '_'.join(['1452983', '2ci', 'statics']) + '.txt', index=False)
    print('end')
Пример #26
0
import util

from collections import defaultdict
import pandas as pd
from time import time


def time_detection(algos, file):
    """Run every detector 100 times on one image and collect the durations.

    :param algos: objects offering ``__name__`` and
        ``object_detection_api(path)``, the latter returning an
        ``(objects, duration)`` pair
    :param file: image file name inside ``./../src/imgs/``
    :return: ``defaultdict(list)`` mapping each detector's ``__name__`` to
        its 100 reported durations
    """
    image_path = f'./../src/imgs/{file}'
    runs_per_algo = 100

    times = defaultdict(list)
    for algo in algos:
        remaining = runs_per_algo
        while remaining:
            _detected, duration = algo.object_detection_api(image_path)
            times[algo.__name__].append(duration)
            remaining -= 1
        print(f'{algo.__name__} done')
    return times


if __name__ == '__main__':
    # Benchmark every detector on every sample image, print statistics for
    # each detector and save the raw timings under the image's base name.
    # util.get_system_info()
    files = [
        'one.jpg',
        'two.jpg',
        'three.webp',
        'four.webp',
        'five.png',
        'six.jpg',
    ]
    algos = [FasterRCNN(), Hog(), Yolo()]

    for file in files:
        # Text before the first dot, e.g. 'one' for 'one.jpg'.
        base_name = file.partition('.')[0]
        times = time_detection(algos, file)
        for name, data in times.items():
            df = pd.DataFrame(data, columns=['Data'])
            util.calc_stats(df, name)
        # util.show_boxplot(times, file.split('.')[0].capitalize())
        util.save_to_file(times, base_name)
Пример #27
0
        X_train, y_train)
    print(X_train.shape, y_train.shape)

    fs_model = feature_selection(X_train, y_train)
    print(X_train.shape)
    X_train = fs_model.transform(X_train)
    X_test = fs_model.transform(X_test)
    X_predict = fs_model.transform(predicted.drop(cols, axis=1))
    print(X_train.shape)

    if train_flag:
        models, names = get_models()
        estimators = train_predict(models, names, X_train, y_train, X_test,
                                   y_test)
        for estimator, name in zip(estimators, names):
            util.save_to_file(predicted[cols], estimator.predict(X_predict),
                              '_'.join(['1452983', '2b', name]) + '.txt')

    else:
        model_params = {
            'LogisticRegression': (
                LogisticRegression(penalty='l1'),
                {
                    # 'penalty': ['l2', 'l1'], l2貌似没用
                    'C': np.arange(0.4, 1, 0.2)  #l2:score不变
                }),
            'DecisionTreeClassifier': (DecisionTreeClassifier(), {
                'class_weight': ['balanced', None]
            })
        }
        train, _ = get_data(params.ci_train_file_name)
        tuning(train,