예제 #1
0
파일: video.py 프로젝트: tyarkoni/featureX
    def _extract(self, stim):
        """Compute total dense optical flow for each frame of *stim*.

        Farneback flow is estimated between consecutive grayscale frames;
        each frame's flow-magnitude sum becomes a single 'total_flow'
        feature value. The first frame is compared against itself, so its
        flow is zero by construction.
        """
        verify_dependencies(['cv2'])
        totals, frame_onsets, frame_durations = [], [], []
        prev_gray = None
        for frame_stim in stim:
            gray = cv2.cvtColor(frame_stim.data, cv2.COLOR_BGR2GRAY)
            if prev_gray is None:
                prev_gray = gray  # first frame: compare against itself
            field = cv2.calcOpticalFlowFarneback(
                prev_gray, gray, None, self.pyr_scale, self.levels,
                self.winsize, self.iterations, self.poly_n,
                self.poly_sigma, self.flags)
            # Per-pixel flow magnitude from the (dx, dy) vector field.
            magnitude = np.sqrt((field ** 2).sum(2))

            if self.show:
                cv2.imshow('frame', magnitude.astype('int8'))
                cv2.waitKey(1)

            prev_gray = gray
            totals.append(magnitude.sum())
            frame_onsets.append(frame_stim.onset)
            frame_durations.append(frame_stim.duration)

        return ExtractorResult(totals, stim, self, features=['total_flow'],
                               onsets=frame_onsets,
                               durations=frame_durations)
예제 #2
0
파일: audio.py 프로젝트: tyarkoni/featureX
 def __init__(self, feature=None, hop_length=512, **librosa_kwargs):
     """Initialize a librosa-based feature extractor.

     Args:
         feature (str): optional librosa feature name; when truthy it is
             stored as this extractor's feature.
         hop_length (int): number of samples between successive frames.
         librosa_kwargs: extra keyword arguments forwarded to librosa.
     """
     verify_dependencies(['librosa'])
     self.hop_length = hop_length
     self.librosa_kwargs = librosa_kwargs
     if feature:
         self._feature = feature
     super(LibrosaFeatureExtractor, self).__init__()
예제 #3
0
    def __init__(self, api_key=None, model='general-v1.3', min_value=None,
                 max_concepts=None, select_concepts=None, rate_limit=None,
                 batch_size=None):
        """Configure the Clarifai API client and target model.

        Args:
            api_key (str): Clarifai key; falls back to the CLARIFAI_API_KEY
                environment variable.
            model (str): name of the Clarifai model to load.
            min_value, max_concepts, select_concepts: prediction filters
                forwarded to the API at query time.
            rate_limit: forwarded to the superclass rate limiter.
            batch_size: accepted for interface compatibility; not used
                here -- presumably consumed by superclass batching logic.

        Raises:
            ValueError: if no API key can be resolved.
        """
        verify_dependencies(['clarifai_client'])
        if api_key is None:
            try:
                api_key = os.environ['CLARIFAI_API_KEY']
            except KeyError:
                # FIX: message previously read "Clarifai API API_KEY".
                raise ValueError("A valid Clarifai API_KEY "
                                 "must be passed the first time a Clarifai "
                                 "extractor is initialized.")

        self.api_key = api_key
        try:
            self.api = clarifai_client.ClarifaiApp(api_key=api_key)
            self.model = self.api.models.get(model)
        except clarifai_client.ApiError as e:
            # Best-effort setup: log and continue with null client/model.
            # FIX: logging.warn is deprecated; use logging.warning.
            logging.warning(str(e))
            self.api = None
            self.model = None
        self.model_name = model
        self.min_value = min_value
        self.max_concepts = max_concepts
        self.select_concepts = select_concepts
        if select_concepts:
            select_concepts = listify(select_concepts)
            self.select_concepts = [clarifai_client.Concept(concept_name=n)
                                    for n in select_concepts]
        super(ClarifaiAPIExtractor, self).__init__(rate_limit=rate_limit)
예제 #4
0
파일: google.py 프로젝트: tyarkoni/featureX
    def __init__(self, discovery_file=None, api_version='v1', max_results=100,
                 num_retries=3, rate_limit=None, **kwargs):
        """Set up Google API credentials and build the service client.

        Args:
            discovery_file (str): path to a JSON service-account key; falls
                back to the GOOGLE_APPLICATION_CREDENTIALS env variable.
            api_version (str): Google API version string.
            max_results (int): maximum results to request per call.
            num_retries (int): number of retries for API requests.
            rate_limit: forwarded to the superclass rate limiter.
            kwargs: forwarded to the superclass initializer.

        Raises:
            ValueError: if no credentials are provided or discoverable.
        """
        verify_dependencies(['googleapiclient', 'google_auth'])
        if discovery_file is None:
            if 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ:
                raise ValueError("No Google application credentials found. "
                                 "A JSON service account key must be either "
                                 "passed as the discovery_file argument, or "
                                 "set in the GOOGLE_APPLICATION_CREDENTIALS "
                                 "environment variable.")
            discovery_file = os.environ['GOOGLE_APPLICATION_CREDENTIALS']

        self.discovery_file = discovery_file
        try:
            self.credentials = google_auth.service_account.Credentials\
                .from_service_account_file(discovery_file)
            self.service = googleapiclient.discovery.build(
                self.api_name, api_version, credentials=self.credentials,
                discoveryServiceUrl=DISCOVERY_URL)
        except Exception as e:
            # Best-effort setup: log and continue with null client.
            # FIX: logging.warn is deprecated; use logging.warning.
            logging.warning(str(e))
            self.credentials = None
            self.service = None
        self.max_results = max_results
        self.num_retries = num_retries
        self.api_version = api_version
        super(GoogleAPITransformer, self).__init__(rate_limit=rate_limit,
                                                   **kwargs)
예제 #5
0
파일: api.py 프로젝트: tyarkoni/featureX
    def __init__(self, consumer_key=None, consumer_secret=None,
                 access_token_key=None, access_token_secret=None,
                 rate_limit=None):
        """Build a Twitter API client from explicit or environment creds.

        If any of the four credentials is missing, ALL four are reloaded
        from the TWITTER_* environment variables.

        Raises:
            ValueError: if credentials are missing from both arguments and
                the environment.
        """
        verify_dependencies(['twitter'])
        provided = (consumer_key, consumer_secret,
                    access_token_key, access_token_secret)
        if any(cred is None for cred in provided):
            try:
                consumer_key = os.environ['TWITTER_CONSUMER_KEY']
                consumer_secret = os.environ['TWITTER_CONSUMER_SECRET']
                access_token_key = os.environ['TWITTER_ACCESS_TOKEN_KEY']
                access_token_secret = os.environ['TWITTER_ACCESS_TOKEN_SECRET']
            except KeyError:
                raise ValueError("Valid Twitter API credentials "
                                 "must be passed the first time a TweetStim "
                                 "is initialized.")

        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.access_token_key = access_token_key
        self.access_token_secret = access_token_secret
        self.api = twitter.Api(consumer_key=consumer_key,
                               consumer_secret=consumer_secret,
                               access_token_key=access_token_key,
                               access_token_secret=access_token_secret)
        super(TweetStimFactory, self).__init__(rate_limit=rate_limit)
예제 #6
0
    def summary(self, stdout=True, plot=False):
        '''
        Displays diagnostics to the user.

        Args:
            stdout (bool): print results to the console
            plot (bool): use Seaborn to plot results
        '''
        if stdout:
            print('Collinearity summary:')
            print(pd.concat([self.results['Eigenvalues'],
                             self.results['ConditionIndices'],
                             self.results['VIFs'],
                             self.results['CorrelationMatrix']],
                            axis=1))

            print('Outlier summary:')
            print(self.results['RowMahalanobisDistances'])
            print(self.results['ColumnMahalanobisDistances'])

            print('Validity summary:')
            print(self.results['Variances'])

        if plot:
            # FIX: verify_dependencies expects a list of module names --
            # every other call site in this project passes a list.
            verify_dependencies(['seaborn'])
            for key, result in self.results.items():
                if key == 'CorrelationMatrix':
                    ax = plt.axes()
                    sns.heatmap(result, cmap='Blues', ax=ax)
                    ax.set_title(key)
                    # FIX: seaborn no longer exposes sns.plt; call the
                    # matplotlib pyplot module directly (as the else
                    # branch already does).
                    plt.show()
                else:
                    result.plot(kind='bar', title=key)
                    plt.show()
예제 #7
0
파일: video.py 프로젝트: tyarkoni/featureX
    def _filter(self, video):
        """Return a VideoFrameCollectionStim with a subsample of frames.

        Sampling is controlled by whichever of self.every / self.hertz /
        self.top_n is set: take every Nth frame, resample to a target
        frequency, or keep the N frames with the largest inter-frame
        pixel change.

        Raises:
            TypeError: if *video* is not a complete VideoStim.
            ValueError: if no sampling criterion is configured.
        """
        if not isinstance(video, VideoStim):
            raise TypeError('Currently, frame sampling is only supported for '
                            'complete VideoStim inputs.')

        if self.every is not None:
            new_idx = range(video.n_frames)[::self.every]
        elif self.hertz is not None:
            interval = video.fps / float(self.hertz)
            new_idx = list(np.arange(0, video.n_frames, interval).astype(int))
        elif self.top_n is not None:
            verify_dependencies(['cv2'])
            diffs = []
            for i, img in enumerate(video.frames):
                if i == 0:
                    last = img
                    continue
                pixel_diffs = cv2.sumElems(cv2.absdiff(last.data, img.data))
                diffs.append(sum(pixel_diffs))
                last = img
            # NOTE(review): diffs[i] measures the change *into* frame i+1,
            # but the sorted indices are used directly as frame indices --
            # confirm the off-by-one is intended before changing.
            new_idx = sorted(range(len(diffs)),
                             key=lambda i: diffs[i],
                             reverse=True)[:self.top_n]
        else:
            # FIX: previously fell through to a NameError on new_idx;
            # fail with an explicit, catchable error instead.
            raise ValueError('One of every, hertz, or top_n must be set '
                             'to sample frames.')

        return VideoFrameCollectionStim(filename=video.filename,
                                        clip=video.clip,
                                        frame_index=new_idx)
예제 #8
0
파일: image.py 프로젝트: tyarkoni/featureX
    def __init__(self, **face_recognition_kwargs):
        """Bind the face_recognition API function named by self._feature,
        pre-applying any extractor-level keyword arguments."""
        verify_dependencies(['face_recognition'])
        self.face_recognition_kwargs = face_recognition_kwargs
        api_func = getattr(face_recognition.api, self._feature)
        self.func = partial(api_func, **face_recognition_kwargs)
        super(FaceRecognitionFeatureExtractor, self).__init__()
예제 #9
0
파일: api.py 프로젝트: tyarkoni/featureX
 def check_valid_keys(self):
     """Return True if the stored Twitter credentials are valid.

     Credentials are verified against the Twitter API; on failure the
     error is logged and False is returned rather than raising.
     """
     verify_dependencies(['twitter'])
     try:
         self.api.VerifyCredentials()
         return True
     except twitter.error.TwitterError as e:
         # FIX: logging.warn is deprecated; use logging.warning.
         logging.warning(str(e))
         return False
예제 #10
0
 def _query_api(self, objects):
     """Send *objects* to the Clarifai model, applying the configured
     output filters, and return the raw list of outputs."""
     verify_dependencies(['clarifai_client'])
     config = clarifai_client.ModelOutputConfig(
         min_value=self.min_value,
         max_concepts=self.max_concepts,
         select_concepts=self.select_concepts)
     output_info = clarifai_client.ModelOutputInfo(output_config=config)
     response = self.model.predict(objects, model_output_info=output_info)
     return response['outputs']
예제 #11
0
파일: wit.py 프로젝트: tyarkoni/featureX
    def _convert(self, audio):
        """Transcribe *audio* with the configured SpeechRecognition
        backend and wrap the transcript in a ComplexTextStim."""
        verify_dependencies(['sr'])
        with audio.get_filename() as filename:
            with sr.AudioFile(filename) as source:
                clip = self.recognizer.record(source)

        # recognize_method names the backend-specific recognizer call.
        recognize = getattr(self.recognizer, self.recognize_method)
        return ComplexTextStim(text=recognize(clip, self.api_key))
예제 #12
0
파일: models.py 프로젝트: tyarkoni/featureX
 def __init__(self, weights=None, num_predictions=5):
     """Load the Keras InceptionV3 model.

     Args:
         weights: weight set to load; defaults to 'imagenet'.
         num_predictions (int): number of top class predictions to keep.
     """
     verify_dependencies(['tensorflow'])
     super(TensorFlowKerasInceptionV3Extractor, self).__init__()
     self.weights = 'imagenet' if weights is None else weights
     self.num_predictions = num_predictions
     # Instantiating the model also downloads the weights to a cache.
     self.model = tf.keras.applications.inception_v3.InceptionV3(
         weights=self.weights)
예제 #13
0
def test_google_language_api_entity_sentiment_extractor():
    verify_dependencies(['googleapiclient'])
    extractor = GoogleLanguageAPIEntitySentimentExtractor()
    text_stim = TextStim(join(TEXT_DIR, 'sample_text_with_entities.txt'))
    df = extractor.transform(text_stim).to_df(timing=False, object_id='auto')
    # Matches the plain entity extractor, plus sentiment columns
    assert df.shape == (10, 11)
    assert df['text'][8] == 'phones'
    assert df['type'][8] == 'CONSUMER_GOOD'
    assert 'sentiment_score' in df.columns
    assert df['sentiment_score'][8] > 0.6  # 'love their ... phones'
예제 #14
0
def test_google_language_api_category_extractor():
    verify_dependencies(['googleapiclient'])
    extractor = GoogleLanguageAPITextCategoryExtractor()
    text_stim = TextStim(join(TEXT_DIR, 'sample_text_with_entities.txt'))
    df = extractor.transform(text_stim).to_df(timing=False, object_id='auto')
    assert df.shape == (1, 4)
    for category in ('category_/Computers & Electronics', 'category_/News'):
        assert category in df.columns
        assert df[category][0] > 0.3
    assert df['language'][0] == 'en'
예제 #15
0
파일: wit.py 프로젝트: tyarkoni/featureX
 def __init__(self, api_key=None, rate_limit=None):
     """Set up a SpeechRecognition-backed converter.

     Args:
         api_key (str): backend API key; falls back to the environment
             variable named in self.env_keys[0].
         rate_limit: forwarded to the superclass rate limiter.

     Raises:
         ValueError: if no API key can be found.
     """
     verify_dependencies(['sr'])
     if api_key is None:
         try:
             api_key = os.environ[self.env_keys[0]]
         except KeyError:
             raise ValueError("A valid API key must be passed when a"
                              " SpeechRecognitionAPIConverter is initialized.")
     self.api_key = api_key
     self.recognizer = sr.Recognizer()
     super(SpeechRecognitionAPIConverter, self).__init__(rate_limit=rate_limit)
예제 #16
0
파일: image.py 프로젝트: tyarkoni/featureX
    def _extract(self, stim):
        """Estimate image sharpness as the normalized maximum of the
        absolute Laplacian of the grayscale image.

        Approach taken from
        http://stackoverflow.com/questions/7765810/is-there-a-way-to-detect-if-an-image-is-blurry?lq=1
        """
        verify_dependencies(['cv2'])
        gray = cv2.cvtColor(stim.data, cv2.COLOR_BGR2GRAY)
        laplacian = cv2.Laplacian(gray, 3)
        sharpness = np.max(cv2.convertScaleAbs(laplacian)) / 255.0
        return ExtractorResult(np.array([[sharpness]]), stim, self,
                               features=['sharpness'])
예제 #17
0
파일: revai.py 프로젝트: tyarkoni/featureX
 def __init__(self, access_token=None, timeout=1000, request_rate=5):
     """Configure the Rev AI transcription client.

     Args:
         access_token (str): Rev AI token; falls back to the
             REVAI_ACCESS_TOKEN environment variable.
         timeout (int): maximum seconds to wait for a transcription job.
         request_rate (int): seconds between job-status polls.

     Raises:
         ValueError: if no access token can be found.
     """
     verify_dependencies(['rev_ai_client'])
     if access_token is None:
         try:
             access_token = os.environ['REVAI_ACCESS_TOKEN']
         except KeyError:
             raise ValueError("A valid API key must be passed when a "
                              "RevAISpeechAPIConverter is initialized.")
     self.timeout = timeout
     self.request_rate = request_rate
     self.access_token = access_token
     self.client = rev_ai_client.RevAiAPIClient(access_token)
     super(RevAISpeechAPIConverter, self).__init__()
예제 #18
0
def test_google_language_api_entity_extractor():
    verify_dependencies(['googleapiclient'])
    extractor = GoogleLanguageAPIEntityExtractor()
    text_stim = TextStim(join(TEXT_DIR, 'sample_text_with_entities.txt'))
    df = extractor.transform(text_stim).to_df(timing=False, object_id='auto')
    assert df.shape == (10, 9)
    assert df['text'][0] == 'Google'
    assert df['type'][0] == 'ORGANIZATION'
    assert 0.0 < df['salience'][0] < 0.5
    assert df['begin_char_index'][4] == 165.0
    assert df['end_char_index'][4] == 172.0
    assert df['text'][4] == 'Android'
    assert df['type'][4] == 'CONSUMER_GOOD'
예제 #19
0
def test_google_language_api_sentiment_extractor():
    verify_dependencies(['googleapiclient'])
    extractor = GoogleLanguageAPISentimentExtractor()
    text_stim = TextStim(join(TEXT_DIR, 'scandal.txt'))
    df = extractor.transform(text_stim).to_df(timing=False, object_id='auto')
    assert df.shape == (12, 7)
    assert 'sentiment_magnitude' in df.columns
    assert 'text' in df.columns
    # Whole-document sentiment (last row) should be near neutral.
    assert -0.3 < df['sentiment_score'][11] < 0.3
    assert df['begin_char_index'][7] == 565.0
    assert df['end_char_index'][7] == 672.0
    assert df['sentiment_magnitude'][7] > 0.6
    assert df['sentiment_score'][7] > 0.6
예제 #20
0
파일: ibm.py 프로젝트: tyarkoni/featureX
 def __init__(self, username=None, password=None, resolution='words',
              rate_limit=None, model='en-US'):
     """Set up IBM speech-to-text credentials and options.

     Args:
         username, password (str): IBM credentials; if either is missing
             both are loaded from IBM_USERNAME / IBM_PASSWORD.
         resolution (str): transcript granularity, 'words' or 'phrases'.
         rate_limit: forwarded to the superclass rate limiter.
         model (str): IBM language model identifier.

     Raises:
         ValueError: if credentials cannot be resolved.
     """
     verify_dependencies(['sr'])
     if username is None or password is None:
         try:
             username = os.environ['IBM_USERNAME']
             password = os.environ['IBM_PASSWORD']
         except KeyError:
             raise ValueError("A valid API key must be passed when a "
                              "SpeechRecognitionConverter is initialized.")
     self.username = username
     self.password = password
     self.resolution = resolution
     self.model = model
     self.recognizer = sr.Recognizer()
     super(IBMSpeechAPIConverter, self).__init__(rate_limit=rate_limit)
예제 #21
0
    def _extract(self, stims):
        """Batch-query the Clarifai image API for *stims* and return one
        ExtractorResult per input stim."""
        verify_dependencies(['clarifai_client'])

        # ExitStack keeps every stim's filename context open at once
        # until the single batched API call completes.
        with ExitStack() as stack:
            imgs = []
            for s in stims:
                if s.url:
                    img = clarifai_client.Image(url=s.url)
                else:
                    fname = stack.enter_context(s.get_filename())
                    img = clarifai_client.Image(filename=fname)
                imgs.append(img)
            outputs = self._query_api(imgs)

        return [ExtractorResult(resp, stims[i], self)
                for i, resp in enumerate(outputs)]
예제 #22
0
파일: text.py 프로젝트: tyarkoni/featureX
    def save(self, path):
        """Write the stim's elements to *path*.

        Paths ending in 'srt' are written as SubRip subtitles; anything
        else is written as a tab-separated onset/text/duration table.
        """
        if path.endswith('srt'):
            verify_dependencies(['pysrt'])
            from pysrt import SubRipFile, SubRipItem
            from datetime import time

            srt_file = SubRipFile()
            for elem in self._elements:
                start = time(*self._to_tup(elem.onset))
                end = time(*self._to_tup(elem.onset + elem.duration))
                srt_file.append(SubRipItem(0, start, end, elem.text))
            srt_file.save(path)
        else:
            with open(path, 'w') as f:
                f.write('onset\ttext\tduration\n')
                for elem in self._elements:
                    row = '{}\t{}\t{}\n'.format(elem.onset, elem.text,
                                                elem.duration)
                    f.write(row)
예제 #23
0
파일: revai.py 프로젝트: tyarkoni/featureX
    def _convert(self, audio):
        """Transcribe *audio* via the Rev AI API into a ComplexTextStim.

        Submits the audio (by URL or local file), polls job status every
        self.request_rate seconds until completion or self.timeout, then
        builds one word-level TextStim per transcribed element.
        """
        verify_dependencies(['rev_ai'])
        msg = "Beginning audio transcription with a timeout of %fs. Even for "\
              "small audios, full transcription may take awhile." % self.timeout
        logging.warning(msg)

        if audio.url:
            job = self.client.submit_job_url(audio.url)
        else:
            with audio.get_filename() as filename:
                job = self.client.submit_job_local_file(filename)

        # Poll until the job leaves IN_PROGRESS or the timeout elapses.
        operation_start = time.time()
        response = self.client.get_job_details(job.id)
        while (response.status == rev_ai.JobStatus.IN_PROGRESS) and \
              (time.time() - operation_start) < self.timeout:
            response = self.client.get_job_details(job.id)
            time.sleep(self.request_rate)

        # NOTE(review): on timeout we only warn and still attempt to fetch
        # the transcript below -- confirm this best-effort fallthrough is
        # intended rather than raising.
        if (time.time() - operation_start) >= self.timeout:
            msg = "Conversion reached the timeout limit of %fs." % self.timeout
            logging.warning(msg)

        if response.status == rev_ai.JobStatus.FAILED:
            raise Exception('API failed: %s' % response.failure_detail)

        result = self.client.get_transcript_object(job.id)

        # Flatten monologues into word-level TextStims; elements whose
        # type_ is not 'text' (e.g. punctuation) are skipped.
        elements = []
        order = 0
        for m in result.monologues:
            for e in m.elements:
                if e.type_ == 'text':
                    start = e.timestamp
                    end = e.end_timestamp
                    elements.append(TextStim(text=e.value,
                                             onset=start,
                                             duration=end-start,
                                             order=order))
                    order += 1

        return ComplexTextStim(elements=elements, onset=audio.onset)
예제 #24
0
def test_google_language_api_extractor():
    verify_dependencies(['googleapiclient'])
    extractor = GoogleLanguageAPIExtractor(features=['classifyText',
                                                     'extractEntities'])
    # Should fail because too few tokens
    with pytest.raises(googleapiclient.errors.HttpError):
        extractor.transform(TextStim(text='hello world'))

    text_stim = TextStim(join(TEXT_DIR, 'scandal.txt'))
    df = extractor.transform(text_stim).to_df(timing=False, object_id='auto')
    assert df.shape == (43, 10)
    assert 'category_/Books & Literature' in df.columns
    assert df['category_/Books & Literature'][0] > 0.5
    irene = df[df['text'] == 'Irene Adler']
    assert (irene['type'] == 'PERSON').all()
    assert not irene['metadata_wikipedia_url'].isna().any()
    # Document row shouldn't have entity features, and vice versa
    assert np.isnan(df.iloc[0]['text'])
    assert np.isnan(df.iloc[1]['category_/Books & Literature']).all()
예제 #25
0
파일: text.py 프로젝트: tyarkoni/featureX
    def _from_srt(self, filename):
        """Populate self._elements from an SRT subtitle file.

        Each subtitle row becomes a TextStim whose onset/duration are
        derived from the SRT start/end timestamps, with internal
        whitespace collapsed to single spaces.
        """
        verify_dependencies(['pysrt'])

        for i, row in enumerate(pysrt.open(filename)):
            onset = self._to_sec(tuple(row.start))
            duration = self._to_sec(tuple(row.end)) - onset
            # FIX: raw string for the regex ('\s' is an invalid escape
            # in a plain string literal).
            text = re.sub(r'\s+', ' ', row.text)
            # Simplified: the original built a pandas DataFrame only to
            # iterate it row-by-row; construct the stims directly.
            self._elements.append(TextStim(filename, text=text, onset=onset,
                                           duration=duration, order=i))
예제 #26
0
파일: ibm.py 프로젝트: tyarkoni/featureX
    def _convert(self, audio):
        """Transcribe *audio* via the IBM speech API into a ComplexTextStim.

        Depending on self.resolution, emits one TextStim per word
        ('words') or one per final transcript phrase ('phrases').

        Raises:
            Exception: if the API response contains no 'results' key.
        """
        verify_dependencies(['sr'])

        with audio.get_filename() as filename:
            with sr.AudioFile(filename) as source:
                clip = self.recognizer.record(source)

        _json = self._query_api(clip)
        if 'results' in _json:
            results = _json['results']
        else:
            raise Exception(
                'received invalid results from API: {0}'.format(str(_json)))
        elements = []
        order = 0
        for result in results:
            if result['final'] is True:
                timestamps = result['alternatives'][0]['timestamps']
                # FIX: compare strings with ==, not 'is' -- identity
                # comparison with a literal only worked via CPython
                # string interning and raises a SyntaxWarning on 3.8+.
                if self.resolution == 'words':
                    # timestamps entries are [text, start, end] triples.
                    for entry in timestamps:
                        text = entry[0]
                        start = entry[1]
                        end = entry[2]
                        elements.append(TextStim(text=text,
                                                 onset=start,
                                                 duration=end-start,
                                                 order=order))
                        order += 1
                elif self.resolution == 'phrases':
                    text = result['alternatives'][0]['transcript']
                    start = timestamps[0][1]
                    end = timestamps[-1][2]
                    elements.append(TextStim(text=text,
                                             onset=start,
                                             duration=end-start,
                                             order=order))
                    order += 1
        return ComplexTextStim(elements=elements, onset=audio.onset)
예제 #27
0
def test_google_language_api_syntax_extractor():
    verify_dependencies(['googleapiclient'])
    extractor = GoogleLanguageAPISyntaxExtractor()
    text_stim = TextStim(join(TEXT_DIR, 'sample_text_with_entities.txt'))
    df = extractor.transform(text_stim).to_df(timing=False, object_id='auto')
    assert df.shape == (32, 20)

    def rows(token):
        return df[df['text'] == token]

    assert (rows('his')['person'] == 'THIRD').all()
    assert (rows('his')['gender'] == 'MASCULINE').all()
    assert (rows('his')['case'] == 'GENITIVE').all()
    assert (rows('their')['person'] == 'THIRD').all()
    assert (rows('their')['number'] == 'PLURAL').all()
    assert (rows('love')['tag'] == 'VERB').all()
    assert (rows('love')['mood'] == 'INDICATIVE').all()
    assert (rows('headquartered')['tense'] == 'PAST').all()
    assert (rows('headquartered')['lemma'] == 'headquarter').all()
    google = rows('Google')
    assert (google['proper'] == 'PROPER').all()
    assert (google['tag'] == 'NOUN').all()
    assert (google['dependency_label'] == 'NSUBJ').all()
    assert (google['dependency_headTokenIndex'] == 7).all()
예제 #28
0
 def _extract(self, stim):
     """Run the Clarifai video model on *stim* and wrap the raw response
     in an ExtractorResult."""
     verify_dependencies(['clarifai_client'])
     with stim.get_filename() as fname:
         outputs = self._query_api([clarifai_client.Video(filename=fname)])
     return ExtractorResult(outputs, stim, self)
예제 #29
0
파일: image.py 프로젝트: undarmaa/pliers
 def _convert(self, stim):
     """OCR the image with pytesseract and return the recognized text as
     a TextStim with the source stim's timing."""
     verify_dependencies(['pytesseract'])
     img = Image.fromarray(stim.data)
     return TextStim(text=pytesseract.image_to_string(img),
                     onset=stim.onset, duration=stim.duration)
예제 #30
0
 def __init__(self, **kwargs):
     """Restrict allowed models to Indico's text APIs, then defer the
     rest of initialization to the superclass."""
     verify_dependencies(['indicoio'])
     self.allowed_models = indicoio.TEXT_APIS.keys()
     super(IndicoAPITextExtractor, self).__init__(**kwargs)
예제 #31
0
 def __init__(self, embedding_file, binary=False, prefix='embedding_dim'):
     """Load a word2vec-format embedding model.

     Args:
         embedding_file (str): path to the word2vec embedding file.
         binary (bool): whether the file is in binary word2vec format.
         prefix (str): feature-name prefix for embedding dimensions.
     """
     verify_dependencies(['keyedvectors'])
     self.prefix = prefix
     self.wvModel = keyedvectors.KeyedVectors.load_word2vec_format(
         embedding_file, binary=binary)
     super(WordEmbeddingExtractor, self).__init__()
예제 #32
0
 def _extract(self, stim):
     """Query the Clarifai video API for *stim* and return the raw
     outputs wrapped in an ExtractorResult."""
     verify_dependencies(['clarifai_client'])
     with stim.get_filename() as filename:
         video = clarifai_client.Video(filename=filename)
         outputs = self._query_api([video])
     return ExtractorResult(outputs, stim, self)
예제 #33
0
 def __init__(self, api_key=None, models=None, rate_limit=None):
     """Restrict allowed models to Indico's image APIs, then defer the
     rest of initialization to the superclass."""
     verify_dependencies(['indicoio'])
     self.allowed_models = indicoio.IMAGE_APIS.keys()
     super(IndicoAPIImageExtractor, self).__init__(api_key=api_key,
                                                   models=models,
                                                   rate_limit=rate_limit)
예제 #34
0
파일: graph.py 프로젝트: tyarkoni/featureX
    def draw(self, filename, color=True):
        ''' Render a plot of the graph via pygraphviz.

        Args:
            filename (str): Path to save the generated image to.
            color (bool): If True, will color graph nodes based on their type,
                otherwise will draw a black-and-white graph.

        Raises:
            RuntimeError: if run() has not been called yet (no _results).
        '''
        verify_dependencies(['pgv'])
        if not hasattr(self, '_results'):
            raise RuntimeError("Graph cannot be drawn before it is executed. "
                               "Try calling run() first.")

        g = pgv.AGraph(directed=True)
        # 'set312' is a 12-color Brewer scheme; fill colors are indexed
        # 1-12 via the "% 12 + 1" arithmetic below.
        g.node_attr['colorscheme'] = 'set312'

        for elem in self._results:
            if not hasattr(elem, 'history'):
                continue
            log = elem.history

            # Walk the transformation history from newest to oldest,
            # emitting (source stim) -> (transformer) -> (result stim)
            # nodes and edges for each step.
            # NOTE(review): log is indexed positionally (log[2], log[5],
            # log[6], log[7]) -- presumably stim name, result name,
            # transformer name, and transformer params; confirm against
            # the history/log record definition.
            while log:
                # Configure nodes
                source_from = log.parent[6] if log.parent else ''
                s_node = hash((source_from, log[2]))
                # NOTE(review): stim_list is not defined anywhere in this
                # method or visible scope -- this looks like a NameError
                # waiting to happen; verify where it should come from.
                s_color = stim_list.index(log[2])
                s_color = s_color % 12 + 1

                t_node = hash((log[6], log[7]))
                t_style = 'filled,' if color else ''
                t_style += 'dotted' if log.implicit else ''
                if log[6].endswith('Extractor'):
                    t_color = '#0082c8'
                elif log[6].endswith('Filter'):
                    t_color = '#e6194b'
                else:
                    t_color = '#3cb44b'

                r_node = hash((log[6], log[5]))
                r_color = stim_list.index(log[5])
                r_color = r_color % 12 + 1

                # Add nodes
                if color:
                    g.add_node(s_node, label=log[2], shape='ellipse',
                               style='filled', fillcolor=s_color)
                    g.add_node(t_node, label=log[6], shape='box',
                               style=t_style, fillcolor=t_color)
                    g.add_node(r_node, label=log[5], shape='ellipse',
                               style='filled', fillcolor=r_color)
                else:
                    g.add_node(s_node, label=log[2], shape='ellipse')
                    g.add_node(t_node, label=log[6], shape='box',
                               style=t_style)
                    g.add_node(r_node, label=log[5], shape='ellipse')

                # Add edges
                g.add_edge(s_node, t_node, style=t_style)
                g.add_edge(t_node, r_node, style=t_style)
                log = log.parent

        g.draw(filename, prog='dot')
예제 #35
0
    def __init__(self, method="averageWordEmbedding", embedding="glove",
                 dimensionality=300, corpus="6B", content_only=True,
                 binary=False, stopWords=None, unk_vector=None, layer=None):
        """Facade that instantiates the text-embedding extractor named by
        *method* and stores it as self.semvector_object.

        Args:
            method (str): one of averageWordEmbedding, skipthought, sif,
                doc2vec, elmo, dan, bert (case-insensitive).
            embedding, dimensionality, corpus, content_only, binary,
                stopWords: forwarded to embedding-based extractors.
            unk_vector: accepted for interface compatibility; not used
                here -- TODO confirm whether it should be forwarded.
            layer: forwarded to the Elmo extractor only.

        Raises:
            ValueError: if *method* is not a string or not supported.
        """
        verify_dependencies(['keyedvectors'])
        verify_dependencies(['doc2vecVectors'])

        # Validate parameter types before dispatching.
        if not isinstance(method, str):
            raise ValueError('Method should be string (default is '
                             'AverageWordEmbedding) or select'
                             ' from the list provided in README.')

        self.method = method
        key = method.lower()

        if key == "averagewordembedding":
            self.semvector_object = AverageEmbeddingExtractor(
                embedding=embedding, dimensionality=dimensionality,
                corpus=corpus, content_only=content_only, binary=binary,
                stopWords=stopWords)
            self.semvector_object._loadModel()
        elif key == 'skipthought':
            self.semvector_object = SkipThoughtExtractor()
            self.semvector_object._loadModel()
        elif key == 'sif':
            self.semvector_object = SmoothInverseFrequencyExtractor(
                embedding=embedding, dimensionality=dimensionality,
                corpus=corpus, content_only=content_only,
                stopWords=stopWords)
            self.semvector_object._loadModel()
        elif key == 'doc2vec':
            self.semvector_object = Doc2vecExtractor()
            self.semvector_object._loadModel()
        elif key == 'elmo':
            self.semvector_object = ElmoExtractor(layer=layer)
        elif key == 'dan':
            self.semvector_object = DANExtractor()
        elif key == 'bert':
            self.semvector_object = BertExtractor()
        else:
            raise ValueError('Method: "' + method + '" is not supported. '
                             'Default is AverageWordEmbedding or select'
                             ' from the list provided in README.')

        super(DirectTextExtractorInterface, self).__init__()