コード例 #1
0
ファイル: server_test.py プロジェクト: spotify/luigi
    def test_api_cors_headers_disallow_any_no_matched_allowed_origins(self):
        """A simple request from an origin outside the allow-list gets no Allow-Origin header."""
        cfg = get_config()
        cfg.set('cors', 'allow_any_origin', 'false')
        cfg.set('cors', 'allowed_origins', '["foo", "bar"]')

        response = self.fetch('/api/graph', headers={'Origin': 'foobar'})
        received = dict(response.headers)

        # 'foobar' matches neither 'foo' nor 'bar', so the origin must not be echoed.
        self.assertIsNone(received.get('Access-Control-Allow-Origin'))
コード例 #2
0
ファイル: notifications.py プロジェクト: akaul/luigi
def send_email(subject, message, sender, recipients, image_png=None):
    """
    Send a notification e-mail through SES or SMTP, unless suppressed.

    The message is always logged at DEBUG level first. Nothing is sent when
    there are no recipients, or when running from a tty / in DEBUG mode
    unless ``[email] force-send`` is true. ``[email] type = ses`` selects
    SES; any other value uses SMTP.

    Each entry of ``recipients`` may hold several comma-separated addresses
    (e.g. an ``error-email`` list from client.cfg).
    """
    config = configuration.get_config()

    subject = _prefix(subject)
    logger.debug("Emailing:\n"
                 "-------------\n"
                 "To: %s\n"
                 "From: %s\n"
                 "Subject: %s\n"
                 "Message:\n"
                 "%s\n"
                 "-------------", recipients, sender, subject, message)
    if not recipients or recipients == (None,):
        return
    if (sys.stdout.isatty() or DEBUG) and (not config.getboolean('email', 'force-send', False)):
        logger.info("Not sending email when running from a tty or in debug mode")
        return

    # Split comma-separated entries, strip surrounding whitespace and drop
    # empty pieces: "a@x.com, b@x.com," must become ["a@x.com", "b@x.com"],
    # not ["a@x.com", " b@x.com", ""].
    recipients = [address.strip()
                  for entry in recipients
                  for address in entry.split(',')
                  if address.strip()]
    if not recipients:
        # Only blanks/commas were configured; there is nobody to mail.
        return

    if config.get('email', 'type', None) == "ses":
        send_email_ses(config, sender, subject, message, recipients, image_png)
    else:
        send_email_smtp(config, sender, subject, message, recipients, image_png)
コード例 #3
0
 def parameters(self):
     """
     Build the parameter dict handed to the pigscript.

     Connection and database come from the ``[mongodb]`` config section; the
     item-item and user-item collection names are the task's
     ``mongodb_output_collection_name`` with ``II`` / ``UI`` suffixes.
     """
     cfg = configuration.get_config()
     params = {
         'CONN': cfg.get('mongodb', 'mongo_conn'),
         'DB': cfg.get('mongodb', 'mongo_db'),
     }
     params['II_COLLECTION'] = '%s_%s' % (self.mongodb_output_collection_name, 'II')
     params['UI_COLLECTION'] = '%s_%s' % (self.mongodb_output_collection_name, 'UI')
     params['OUTPUT_PATH'] = self.output_base_path
     return params
コード例 #4
0
ファイル: server_test.py プロジェクト: spotify/luigi
    def test_api_cors_headers_disallow_any(self):
        """A simple request from an allow-listed origin has that origin echoed back."""
        cfg = get_config()
        cfg.set('cors', 'allow_any_origin', 'false')
        cfg.set('cors', 'allowed_origins', '["foo", "bar"]')

        response = self.fetch('/api/graph', headers={'Origin': 'foo'})
        received = dict(response.headers)

        self.assertEqual('foo', received['Access-Control-Allow-Origin'])
コード例 #5
0
ファイル: gitrec-luigi.py プロジェクト: mortardata/gitrec
 def parameters(self):
     """
     Build the parameter dict handed to the pigscript.

     Every output table name is suffixed with ``self.date_string``; the AWS
     credentials are read from the ``[dynamodb]`` config section.
     """
     table_templates = (
         ('USER_CONTRIB_TABLE', 'github_sqrt_contrib_recs_%s'),
         ('GRAVATAR_TABLE', 'github_sqrt_user_gravatar_ids_%s'),
         ('USER_INTEREST_TABLE', 'github_sqrt_user_interest_recs_%s'),
         ('ITEM_ITEM_TABLE', 'github_sqrt_repo_recs_%s'),
     )
     params = {'OUTPUT_PATH': self.output_base_path}
     for param_name, template in table_templates:
         params[param_name] = template % self.date_string

     cfg = configuration.get_config()
     params['AWS_ACCESS_KEY_ID'] = cfg.get('dynamodb', 'aws_access_key_id')
     params['AWS_SECRET_ACCESS_KEY'] = cfg.get('dynamodb', 'aws_secret_access_key')
     return params
コード例 #6
0
ファイル: mongo-luigi.py プロジェクト: CC213/mortar-recsys
 def parameters(self):
     """
     Return the parameters passed to Mortar when this pigscript starts.

     Connection, database and input collection come from the ``[mongodb]``
     config section; ``OUTPUT_PATH`` is the task's ``output_base_path``.
     """
     cfg = configuration.get_config()
     mongo_options = (
         ('CONN', 'mongo_conn'),
         ('DB', 'mongo_db'),
         ('COLLECTION', 'mongo_input_collection'),
     )
     params = {}
     for param_name, option in mongo_options:
         params[param_name] = cfg.get('mongodb', option)
     params['OUTPUT_PATH'] = self.output_base_path
     return params
コード例 #7
0
    def __init__(self, *args, **kwargs):
        """
        Fill in the Vertica export settings from config when not given explicitly.

        ``vertica_credentials`` and ``vertica_schema`` fall back to the
        ``credentials`` / ``schema`` options of the ``[vertica-export]``
        config section. ``vertica_enabled`` is True only when both are set.
        """
        super(OptionalVerticaMixin, self).__init__(*args, **kwargs)

        if not self.vertica_credentials:
            self.vertica_credentials = get_config().get('vertica-export', 'credentials', None)

        if not self.vertica_schema:
            self.vertica_schema = get_config().get('vertica-export', 'schema', None)

        # Coerce to a real boolean. The bare ``and`` expression stored the
        # schema string (or the falsy credentials value) in a flag attribute.
        self.vertica_enabled = bool(self.vertica_credentials and self.vertica_schema)
コード例 #8
0
ファイル: dbms-luigi.py プロジェクト: CC213/mortar-recsys
 def parameters(self):
     """
     Build the parameter dict for the PostgreSQL-backed pigscript.

     Host, port, database name and user come from the ``[postgres]`` config
     section. The item-item / user-item table names are the task's
     ``table_name_prefix`` followed by ``ii`` / ``ui``.
     """
     pg = configuration.get_config()
     host_and_port = '%s:%s' % (pg.get('postgres', 'host'), pg.get('postgres', 'port'))
     db_name = pg.get('postgres', 'dbname')
     db_user = pg.get('postgres', 'user')
     return {
         'DATABASE_DRIVER': 'org.postgresql.Driver',
         'DATABASE_TYPE': 'postgresql',
         'DATABASE_HOST': host_and_port,
         'DATABASE_NAME': db_name,
         'DATABASE_USER': db_user,
         'II_TABLE': '%s%s' % (self.table_name_prefix, 'ii'),
         'UI_TABLE': '%s%s' % (self.table_name_prefix, 'ui'),
         'OUTPUT_PATH': self.output_base_path,
     }
コード例 #9
0
ファイル: spark.py プロジェクト: nzlinus/luigi
    def run(self):
        """
        Submit the job to YARN through ``org.apache.spark.deploy.yarn.Client``.

        Any job argument equal to the output path is replaced by a temporary
        HDFS path. That temporary output is moved into place only when the
        process exits 0 and YARN does not report FAILED.

        Raises:
            SparkJobError: if YARN reports FAILED, or if the process exits
                non-zero (then with the captured stderr as ``err``).
        """
        warnings.warn("The use of SparkJob is deprecated. Please use SparkSubmitTask or PySparkTask.", stacklevel=2)
        original_output_path = self.output().path
        path_no_slash = original_output_path[:-2] if original_output_path.endswith("/*") else original_output_path
        # endswith() instead of [-1] so an empty path does not raise IndexError.
        path_no_slash = original_output_path[:-1] if original_output_path.endswith("/") else path_no_slash
        # Integer bounds: randrange() rejects float arguments on Python >= 3.12.
        tmp_output = luigi.hdfs.HdfsTarget(path_no_slash + "-luigi-tmp-%09d" % random.randrange(0, 10 ** 10))

        args = ["org.apache.spark.deploy.yarn.Client"]
        args += ["--jar", self.jar()]
        args += ["--class", self.job_class()]

        for a in self.job_args():
            if a == original_output_path:
                # pass temporary output path to job args
                logger.info("Using temp path: %s for path %s", tmp_output.path, original_output_path)
                args += ["--args", tmp_output.path]
            else:
                args += ["--args", str(a)]

        # Popen and the " ".join() below both require strings, so numeric
        # settings (e.g. an int worker count) are converted explicitly.
        optional_flags = (
            ("--num-workers", self.spark_workers),
            ("--master-memory", self.spark_master_memory),
            ("--worker-memory", self.spark_worker_memory),
            ("--queue", self.queue),
        )
        for flag, value in optional_flags:
            if value is not None:
                args += [flag, str(value)]

        env = os.environ.copy()
        env["SPARK_JAR"] = configuration.get_config().get("spark", "spark-jar")
        env["HADOOP_CONF_DIR"] = configuration.get_config().get("spark", "hadoop-conf-dir")
        env["MASTER"] = "yarn-client"
        spark_class = configuration.get_config().get("spark", "spark-class")

        logger.info("Running: %s %s", spark_class, " ".join(args))
        # stderr is spooled to a temp file; ``with`` closes it on every exit path.
        with tempfile.TemporaryFile() as temp_stderr:
            proc = subprocess.Popen(
                [spark_class] + args, stdout=subprocess.PIPE, stderr=temp_stderr, env=env, close_fds=True
            )

            return_code, final_state, app_id = self.track_progress(proc)
            if return_code == 0 and final_state != "FAILED":
                tmp_output.move(path_no_slash)
            elif final_state == "FAILED":
                raise SparkJobError("Spark job failed: see yarn logs for %s" % app_id)
            else:
                temp_stderr.seek(0)
                # "replace" so undecodable bytes cannot mask the real failure.
                errors = temp_stderr.read().decode("utf8", "replace")
                logger.error(errors)
                raise SparkJobError("Spark job failed", err=errors)
コード例 #10
0
ファイル: server_test.py プロジェクト: spotify/luigi
    def test_api_preflight_cors_headers_disabled(self):
        """With CORS disabled, a preflight request receives none of the CORS response headers."""
        get_config().set('cors', 'enabled', 'false')
        response = self.fetch('/api/graph', method='OPTIONS', headers={'Origin': 'foo'})
        received = dict(response.headers)

        forbidden = (
            'Access-Control-Allow-Headers',
            'Access-Control-Allow-Methods',
            'Access-Control-Allow-Origin',
            'Access-Control-Max-Age',
            'Access-Control-Allow-Credentials',
            'Access-Control-Expose-Headers',
        )
        for header_name in forbidden:
            self.assertNotIn(header_name, received)
コード例 #11
0
ファイル: server_test.py プロジェクト: spotify/luigi
    def test_api_preflight_cors_headers_disallow_any_no_matched_allowed_origins(self):
        """A preflight from an origin outside the allow-list receives none of the CORS headers."""
        cfg = get_config()
        cfg.set('cors', 'allow_any_origin', 'false')
        cfg.set('cors', 'allowed_origins', '["foo", "bar"]')
        response = self.fetch('/api/graph', method='OPTIONS', headers={'Origin': 'foobar'})
        received = dict(response.headers)

        forbidden = (
            'Access-Control-Allow-Headers',
            'Access-Control-Allow-Methods',
            'Access-Control-Allow-Origin',
            'Access-Control-Max-Age',
            'Access-Control-Allow-Credentials',
            'Access-Control-Expose-Headers',
        )
        for header_name in forbidden:
            self.assertNotIn(header_name, received)
コード例 #12
0
ファイル: spark.py プロジェクト: nirmeshk/luigi
    def run(self):
        """
        Submit the job to YARN through ``org.apache.spark.deploy.yarn.Client``.

        Any job argument equal to the output path is replaced by a temporary
        HDFS path. That temporary output is moved into place only when the
        process exits 0 and YARN does not report FAILED.

        Raises:
            SparkJobError: if YARN reports FAILED, or if the process exits
                non-zero (then with the captured stderr as ``err``).
        """
        warnings.warn("The use of SparkJob is deprecated. Please use SparkSubmitTask or PySparkTask.", stacklevel=2)
        original_output_path = self.output().path
        path_no_slash = original_output_path[:-2] if original_output_path.endswith('/*') else original_output_path
        # endswith() instead of [-1] so an empty path does not raise IndexError.
        path_no_slash = original_output_path[:-1] if original_output_path.endswith('/') else path_no_slash
        # Integer bounds: randrange() rejects float arguments on Python >= 3.12.
        tmp_output = luigi.contrib.hdfs.HdfsTarget(path_no_slash + '-luigi-tmp-%09d' % random.randrange(0, 10 ** 10))

        args = ['org.apache.spark.deploy.yarn.Client']
        args += ['--jar', self.jar()]
        args += ['--class', self.job_class()]

        for a in self.job_args():
            if a == original_output_path:
                # pass temporary output path to job args
                logger.info('Using temp path: %s for path %s', tmp_output.path, original_output_path)
                args += ['--args', tmp_output.path]
            else:
                args += ['--args', str(a)]

        # Popen and list2cmdline() both require strings, so numeric settings
        # (e.g. an int worker count) are converted explicitly.
        optional_flags = (
            ('--num-workers', self.spark_workers),
            ('--master-memory', self.spark_master_memory),
            ('--worker-memory', self.spark_worker_memory),
            ('--queue', self.queue),
        )
        for flag, value in optional_flags:
            if value is not None:
                args += [flag, str(value)]

        env = os.environ.copy()
        env['SPARK_JAR'] = configuration.get_config().get('spark', 'spark-jar')
        env['HADOOP_CONF_DIR'] = configuration.get_config().get('spark', 'hadoop-conf-dir')
        env['MASTER'] = 'yarn-client'
        spark_class = configuration.get_config().get('spark', 'spark-class')

        logger.info('Running: %s %s', spark_class, subprocess.list2cmdline(args))
        # stderr is spooled to a temp file; ``with`` closes it on every exit path.
        with tempfile.TemporaryFile() as temp_stderr:
            proc = subprocess.Popen([spark_class] + args, stdout=subprocess.PIPE,
                                    stderr=temp_stderr, env=env, close_fds=True)

            return_code, final_state, app_id = self.track_progress(proc)
            if return_code == 0 and final_state != 'FAILED':
                tmp_output.move(path_no_slash)
            elif final_state == 'FAILED':
                raise SparkJobError('Spark job failed: see yarn logs for %s' % app_id)
            else:
                temp_stderr.seek(0)
                # 'replace' so undecodable bytes cannot mask the real failure.
                errors = temp_stderr.read().decode('utf8', 'replace')
                logger.error(errors)
                raise SparkJobError('Spark job failed', err=errors)
コード例 #13
0
 def _set_tables(self):
     """
     Point the recsys client at the newly built item-item / user-item tables.

     PUTs a JSON body with ``ii_table`` and ``ui_table`` (from
     ``self.table_names()``) to ``self._client_update_endpoint()``, using
     HTTP basic auth with the ``email`` / ``password`` options of the
     ``[recsys]`` config section.

     Raises:
         requests.HTTPError: via ``raise_for_status()`` on an error response.
     """
     headers = {'Accept': 'application/json',
                'Accept-Encoding': 'gzip',
                'Content-Type': 'application/json',
                'User-Agent': 'mortar-luigi'}
     url = self._client_update_endpoint()
     # Look the names up once instead of calling table_names() per key.
     table_names = self.table_names()
     body = {'ii_table': table_names['ii_table'],
             'ui_table': table_names['ui_table']}
     cfg = configuration.get_config()
     auth = HTTPBasicAuth(cfg.get('recsys', 'email'),
                          cfg.get('recsys', 'password'))
     # Lazy %-args: the message is only formatted if INFO is enabled.
     logger.info('Setting new tables to %s at %s', body, url)
     response = requests.put(url, data=json.dumps(body), auth=auth, headers=headers)
     response.raise_for_status()
コード例 #14
0
ファイル: server_test.py プロジェクト: spotify/luigi
    def test_api_preflight_cors_headers_disallow_any(self):
        """A preflight from an allow-listed origin gets the default CORS headers and echoes the origin."""
        cfg = get_config()
        cfg.set('cors', 'allow_any_origin', 'false')
        cfg.set('cors', 'allowed_origins', '["foo", "bar"]')
        response = self.fetch('/api/graph', method='OPTIONS', headers={'Origin': 'foo'})
        received = dict(response.headers)
        defaults = self._default_cors

        self.assertEqual(defaults.allowed_headers, received['Access-Control-Allow-Headers'])
        self.assertEqual(defaults.allowed_methods, received['Access-Control-Allow-Methods'])
        self.assertEqual('foo', received['Access-Control-Allow-Origin'])
        self.assertEqual(str(defaults.max_age), received['Access-Control-Max-Age'])
        # Credentials and exposed headers are not configured in this test.
        self.assertIsNone(received.get('Access-Control-Allow-Credentials'))
        self.assertIsNone(received.get('Access-Control-Expose-Headers'))
コード例 #15
0
ファイル: server_test.py プロジェクト: spotify/luigi
    def test_api_preflight_cors_headers_all_response_headers(self):
        """With credentials and exposed headers configured, a preflight gets every CORS header."""
        cfg = get_config()
        cfg.set('cors', 'allow_credentials', 'true')
        cfg.set('cors', 'exposed_headers', 'foo, bar')
        response = self.fetch('/api/graph', method='OPTIONS', headers={'Origin': 'foo'})
        received = dict(response.headers)
        defaults = self._default_cors

        self.assertEqual(defaults.allowed_headers, received['Access-Control-Allow-Headers'])
        self.assertEqual(defaults.allowed_methods, received['Access-Control-Allow-Methods'])
        # allow_any_origin is left at 'true' here, so the wildcard is expected.
        self.assertEqual('*', received['Access-Control-Allow-Origin'])
        self.assertEqual(str(defaults.max_age), received['Access-Control-Max-Age'])
        self.assertEqual('true', received['Access-Control-Allow-Credentials'])
        self.assertEqual('foo, bar', received['Access-Control-Expose-Headers'])
コード例 #16
0
ファイル: server_test.py プロジェクト: spotify/luigi
    def setUp(self):
        """Reset the ``[cors]`` section, snapshot its defaults, then enable permissive CORS."""
        super(ServerTest, self).setUp()
        cfg = get_config()
        # Clear leftover [cors] settings before capturing them in _default_cors.
        cfg.remove_section('cors')
        self._default_cors = luigi.server.cors()

        for option in ('enabled', 'allow_any_origin', 'allow_null_origin'):
            cfg.set('cors', option, 'true')
コード例 #17
0
ファイル: s3.py プロジェクト: roverdotcom/luigi
 def _get_s3_config(self, key=None):
     """
     Read the ``[s3]`` config section.

     Values that parse as integers (ports etc.) are converted to ``int`` so
     callers need not know each option's type.

     :param key: if given, return only that option's (converted) value, or
         ``None`` when the option is absent.
     :return: without ``key``, a dict of the options set in ``[s3]`` itself,
         excluding values inherited unchanged from ``[DEFAULT]``. Returns
         ``{}`` when there is no ``[s3]`` section (even if ``key`` was given).
     """
     parser = configuration.get_config()
     defaults = dict(parser.defaults())
     try:
         raw = dict(parser.items('s3'))
     except NoSectionError:
         return {}

     def _coerce(value):
         # Best-effort int conversion; non-numeric (or None) values pass through.
         try:
             return int(value)
         except (TypeError, ValueError):
             return value

     if key:
         return _coerce(raw[key]) if key in raw else None
     # Compare the *raw* strings with [DEFAULT]. Comparing after the int
     # conversion made every numeric default look overridden ('80' != 80).
     return {k: _coerce(v) for k, v in raw.items()
             if k not in defaults or v != defaults[k]}
コード例 #18
0
ファイル: s3.py プロジェクト: arunpn/semafor-parsing
 def _get_s3_config(self, key):
     """
     Return option ``key`` from the ``[s3]`` config section.

     ``None`` is returned when either the section or the option is missing.
     """
     try:
         value = configuration.get_config().get('s3', key)
     except (NoSectionError, NoOptionError):
         value = None
     return value
コード例 #19
0
 def __init__(self):
     """Load the PayPal API settings from the ``[paypal]`` config section."""
     # Named ``config`` rather than ``configuration`` to avoid shadowing that module name.
     config = get_config()
     section = 'paypal'
     self.partner = config.get(section, 'partner', 'PayPal')  # falls back to 'PayPal'
     self.vendor = config.get(section, 'vendor')
     self.password = config.get(section, 'password')
     self.user = config.get(section, 'user', None)  # optional
     self.url = config.get(section, 'url')
コード例 #20
0
ファイル: notifications.py プロジェクト: cpapazian/luigi
def send_error_email(subject, message):
    """
    Report an error by e-mail or, failing that, through SNS.

    ``[core] error-email`` takes precedence. Otherwise ``[core]
    error-sns-topic`` is used. When neither is configured, only an
    informational log line is written.
    """
    config = configuration.get_config()
    receiver = config.get('core', 'error-email', None)
    sns_topic = config.get('core', 'error-sns-topic', None)

    if receiver:
        sender = config.get('core', 'email-sender', DEFAULT_CLIENT_EMAIL)
        logger.info("Sending warning email to %r", receiver)
        send_email(subject=subject, message=message,
                   sender=sender, recipients=(receiver,))
        return

    if sns_topic:
        logger.info("Sending warning via %r", sns_topic)
        send_sns_notification(subject=subject, message=message, topic=sns_topic)
        return

    logger.info("Skipping error email. Set `error-email` in the `core` "
                "section of the luigi config file to receive error "
                "emails.")
コード例 #21
0
ファイル: notifications.py プロジェクト: cpapazian/luigi
def send_sns_notification(subject, message, topic):
    """
    Publish ``message`` to the SNS ``topic`` through boto.

    Region and credentials come from the ``[sns]`` config section; the region
    falls back to ``us-east-1``. The subject is truncated to 100 characters.
    """
    import boto.sns  # imported here so boto is only needed when SNS is used
    config = configuration.get_config()
    region = config.get('sns', 'region', 'us-east-1')
    access_key = config.get('sns', 'AWS_ACCESS_KEY', None)
    secret_key = config.get('sns', 'AWS_SECRET_KEY', None)
    connection = boto.sns.connect_to_region(region,
                                            aws_access_key_id=access_key,
                                            aws_secret_access_key=secret_key)
    connection.publish(topic, message, subject[:100])
コード例 #22
0
    def run(self):
        """
        Fetch the PayPal SettlementReport for ``self.date`` (whole day, GMT) and write its rows.

        The report status is polled every 5 seconds. ``PaypalTimeoutError`` is
        raised once ``[paypal] timeout`` seconds have passed (two hours if
        unset); a negative timeout disables the limit.
        """
        self.remove_output_on_overwrite()

        day = self.date.isoformat()
        report_response = PaypalReportRequest(
            "SettlementReport",
            processor="PayPal",
            start_date=day + " 00:00:00",
            end_date=day + " 23:59:59",
            timezone="GMT",
        ).execute()
        report_id = report_response.report_id

        still_running = report_response.is_running
        timeout = get_config().getint("paypal", "timeout", 60 * 60 * 2)
        start_time = time.time()
        while still_running:
            if timeout >= 0 and time.time() >= start_time + timeout:
                raise PaypalTimeoutError(start_time)
            time.sleep(5)
            still_running = PaypalReportResultsRequest(report_id=report_id).execute().is_running

        metadata = PaypalReportMetadataRequest(report_id=report_id).execute()

        with self.output().open("w") as output_tsv_file:
            # PayPal numbers result pages from 1.
            for page_number in range(1, metadata.num_pages + 1):
                page = PaypalReportDataRequest(report_id=report_id, page_num=page_number).execute()
                for row in page.rows:
                    self.write_transaction_record(row, output_tsv_file)
コード例 #23
0
ファイル: notifications.py プロジェクト: DrValani/luigi
def send_email(subject, message, sender, recipients, image_png=None):
    """
    Send a notification through the configured backend, unless suppressed.

    Nothing is sent when there are no recipients or when ``_email_disabled()``
    is true. ``[email] type`` selects the backend (``ses``, ``sendgrid``,
    ``smtp`` or ``sns``); an unknown or unset type falls back to SMTP.
    Each recipient entry may hold several comma-separated addresses.
    """
    config = configuration.get_config()
    backends = {'ses': send_email_ses,
                'sendgrid': send_email_sendgrid,
                'smtp': send_email_smtp,
                'sns': send_email_sns}

    subject = _prefix(subject)
    if not recipients or recipients == (None,) or _email_disabled():
        return

    # Flatten "a@x, b@x" style entries into individual, trimmed addresses.
    cleaned = [address.strip()
               for entry in recipients
               for address in entry.split(',')
               if address.strip()]

    backend = backends.get(config.get('email', 'type', None), send_email_smtp)
    backend(config, sender, subject, message, cleaned, image_png)
コード例 #24
0
ファイル: db_task_history.py プロジェクト: TrueCar/luigi
 def __init__(self):
     """
     Connect to the task-history database and create its tables.

     The SQLAlchemy connection URL is read from ``[task_history] db_connection``.
     """
     db_url = configuration.get_config().get('task_history', 'db_connection')
     self.engine = sqlalchemy.create_engine(db_url)
     # expire_on_commit=False keeps loaded objects readable after a commit.
     self.session_factory = sqlalchemy.orm.sessionmaker(bind=self.engine, expire_on_commit=False)
     Base.metadata.create_all(self.engine)
     self.tasks = {}  # task_id -> TaskRecord
コード例 #25
0
ファイル: notifications_test.py プロジェクト: TodayTix/luigi
    def test_sends_smtp_email_exceptions(self):
        """
        Call notifications.send_email_smtp when it cannot connect to the SMTP
        server (socket.error). Check that the error does not escape, that no
        e-mail is generated and that nothing is sent.
        """
        smtp_kws = {"host": "my.smtp.local",
                    "port": 999,
                    "local_hostname": "ptms",
                    "timeout": 1200}

        with mock.patch('smtplib.SMTP') as SMTP:
            with mock.patch('luigi.notifications.generate_email') as generate_email:
                SMTP.side_effect = socket.error()
                generate_email.return_value \
                    .as_string.return_value = self.mocked_email_msg

                try:
                    notifications.send_email_smtp(configuration.get_config(),
                                                  *self.notification_args)
                except socket.error:
                    self.fail("send_email_smtp() raised exception unexpectedly")

                SMTP.assert_called_once_with(**smtp_kws)
                self.assertFalse(generate_email.called)
                # ``sendmail`` lives on the connection (SMTP's return value). The
                # old ``SMTP.sendemail`` typo named an auto-created mock
                # attribute, so that assertion could never fail.
                self.assertFalse(SMTP.return_value.sendmail.called)
コード例 #26
0
 def __init__(self):
     """Load the PayPal API settings from the ``[paypal]`` config section."""
     # Named ``config`` rather than ``configuration`` to avoid shadowing that module name.
     config = get_config()
     section = "paypal"
     self.partner = config.get(section, "partner", "PayPal")  # falls back to "PayPal"
     self.vendor = config.get(section, "vendor")
     self.password = config.get(section, "password")
     self.user = config.get(section, "user", None)  # optional
     self.url = config.get(section, "url")
コード例 #27
0
ファイル: scheduler.py プロジェクト: minkyoungkook/luigi
    def __init__(self, config=None, resources=None, task_history_impl=None, **kwargs):
        """
        Keyword Arguments:
        :param config: an object of class "scheduler" or None (in which case one is built as ``scheduler(**kwargs)``)
        :param resources: a dict of str->int constraints; defaults to the ``[resources]`` config section
        :param task_history_impl: ignore config and use this object as the task history
        :param kwargs: forwarded to ``scheduler(...)`` when ``config`` is not given
        """
        self._config = config or scheduler(**kwargs)
        self._state = SimpleTaskState(self._config.state_path)

        if task_history_impl:
            self._task_history = task_history_impl
        elif self._config.record_task_history:
            from luigi import db_task_history  # Needs sqlalchemy, thus imported here

            self._task_history = db_task_history.DbTaskHistory()
        else:
            self._task_history = history.NopHistory()
        self._resources = resources or configuration.get_config().getintdict(
            "resources"
        )  # TODO: Can we make this a Parameter?
        self._make_task = functools.partial(
            Task,
            disable_failures=self._config.disable_failures,
            disable_hard_timeout=self._config.disable_hard_timeout,
            disable_window=self._config.disable_window,
        )
コード例 #28
0
ファイル: webhdfs.py プロジェクト: RUNDSP/luigi
 def get_config(self, key):
     """
     Return option ``key`` from the ``[hdfs]`` config section.

     :raises RuntimeError: if the option cannot be read (e.g. it is missing).
     """
     config = configuration.get_config()
     try:
         return config.get('hdfs', key)
     except Exception:
         # Narrowed from a bare ``except:`` so that KeyboardInterrupt and
         # SystemExit are no longer reported as a missing config option.
         raise RuntimeError("You must specify %s in the [hdfs] section of "
                            "the luigi client.cfg file" % key)
コード例 #29
0
ファイル: paypal.py プロジェクト: edx/edx-analytics-pipeline
    def run(self):
        """
        Fetch the PayPal SettlementReport for ``self.date`` (whole day, GMT) and write its rows.

        The report status is polled every 5 seconds. ``PaypalTimeoutError`` is
        raised once ``[paypal] timeout`` seconds have passed (two hours if
        unset); a negative timeout disables the limit. An empty report raises
        unless ``is_empty_transaction_allowed`` is set.
        """
        self.remove_output_on_overwrite()

        day = self.date.isoformat()
        report_response = PaypalReportRequest(
            'SettlementReport',
            processor='PayPal',
            start_date=day + ' 00:00:00',
            end_date=day + ' 23:59:59',
            timezone='GMT'
        ).execute()
        report_id = report_response.report_id

        still_running = report_response.is_running
        timeout = get_config().getint('paypal', 'timeout', 60 * 60 * 2)
        start_time = time.time()
        while still_running:
            if timeout >= 0 and time.time() >= start_time + timeout:
                raise PaypalTimeoutError(start_time)
            time.sleep(5)
            still_running = PaypalReportResultsRequest(report_id=report_id).execute().is_running

        metadata = PaypalReportMetadataRequest(report_id=report_id).execute()

        if metadata.num_rows < 1 and not self.is_empty_transaction_allowed:
            raise Exception('No transactions to process.')

        with self.output().open('w') as output_tsv_file:
            # PayPal numbers result pages from 1.
            for page_number in range(1, metadata.num_pages + 1):
                page = PaypalReportDataRequest(report_id=report_id, page_num=page_number).execute()
                for row in page.rows:
                    self.write_transaction_record(row, output_tsv_file)
コード例 #30
0
ファイル: spark.py プロジェクト: genba/luigi
 def run(self):
     """
     Run the job with ``spark-submit`` on YARN in client mode.

     The ``spark-submit`` executable comes from ``[spark] spark-submit``
     (default ``spark-submit``). Optional executor/driver settings are only
     passed when not ``None``.

     Raises:
         SparkJobError: if YARN reports FAILED, or if spark-submit exits
             non-zero (then with its captured stderr text as ``err``).
     """
     spark_submit = configuration.get_config().get('spark', 'spark-submit',
                                                   'spark-submit')
     options = [
         '--class', self.job_class(),
         '--master', 'yarn-client',
     ]
     if self.num_executors is not None:
         options += ['--num-executors', self.num_executors]
     if self.driver_memory is not None:
         options += ['--driver-memory', self.driver_memory]
     if self.executor_memory is not None:
         options += ['--executor-memory', self.executor_memory]
     if self.executor_cores is not None:
         options += ['--executor-cores', self.executor_cores]
     dependency_jars = self.dependency_jars()
     # Truthiness instead of ``!= []``: an empty tuple used to add "--jars ''".
     if dependency_jars:
         options += ['--jars', ','.join(dependency_jars)]
     args = [spark_submit] + options + self.spark_options() + \
         [self.jar()] + list(self.job_args())
     # A real list: on Python 3 ``map`` is a lazy iterator, so the log line
     # showed "<map object ...>" instead of the command.
     args = [str(arg) for arg in args]
     env = os.environ.copy()
     logger.info('Running: %r', args)
     # stderr is spooled to a temp file; ``with`` closes it on every exit path.
     with tempfile.TemporaryFile() as temp_stderr:
         proc = subprocess.Popen(args, stdout=subprocess.PIPE,
                                 stderr=temp_stderr, env=env, close_fds=True)
         return_code, final_state, app_id = self.track_progress(proc)
         if final_state == 'FAILED':
             raise SparkJobError('Spark job failed: see yarn logs for {0}'
                                 .format(app_id))
         elif return_code != 0:
             temp_stderr.seek(0)
             # Decode to text (the file is binary) so the log and error are readable.
             errors = temp_stderr.read().decode('utf8', 'replace')
             logger.error(errors)
             raise SparkJobError('Spark job failed', err=errors)
コード例 #31
0
 def run(self):
     """
     Run the job with ``spark-submit``.

     The ``spark-submit`` executable comes from ``[spark] spark-submit``
     (default ``spark-submit``). Executor, driver, deploy-mode, queue and
     master settings are only passed when not ``None``.

     Raises:
         SparkJobError: if YARN reports FAILED, or if spark-submit exits
             non-zero (then with its captured stderr text as ``err``).
     """
     spark_submit = configuration.get_config().get('spark', 'spark-submit',
                                                   'spark-submit')
     options = [
         '--class',
         self.job_class(),
     ]
     optional_flags = (
         ('--num-executors', self.num_executors),
         ('--driver-memory', self.driver_memory),
         ('--executor-memory', self.executor_memory),
         ('--executor-cores', self.executor_cores),
         ('--deploy-mode', self.deploy_mode),
         ('--queue', self.queue),
         ('--master', self.spark_master),
     )
     for flag, value in optional_flags:
         if value is not None:
             options += [flag, value]
     dependency_jars = self.dependency_jars()
     # Truthiness instead of ``!= []``: an empty tuple used to add "--jars ''".
     if dependency_jars:
         options += ['--jars', ','.join(dependency_jars)]
     args = [spark_submit] + options + self.spark_options() + \
         [self.jar()] + list(self.job_args())
     # A real list: on Python 3 ``map`` is a lazy iterator, so the log line
     # showed "<map object ...>" instead of the command.
     args = [str(arg) for arg in args]
     env = os.environ.copy()
     logger.info('Running: %r', args)
     # stderr is spooled to a temp file; ``with`` closes it on every exit path.
     with tempfile.TemporaryFile() as temp_stderr:
         proc = subprocess.Popen(args,
                                 stdout=subprocess.PIPE,
                                 stderr=temp_stderr,
                                 env=env,
                                 close_fds=True)
         return_code, final_state, app_id = self.track_progress(proc)
         if final_state == 'FAILED':
             raise SparkJobError(
                 'Spark job failed: see yarn logs for {0}'.format(app_id))
         elif return_code != 0:
             temp_stderr.seek(0)
             # The temp file is binary: "".join() over its bytes lines raised
             # TypeError on Python 3, so decode explicitly.
             errors = temp_stderr.read().decode('utf8', 'replace')
             logger.error(errors)
             raise SparkJobError('Spark job failed', err=errors)
コード例 #32
0
class LDAParams(luigi.ExternalTask):
    """
    Simulate parameters for an LDA model

    This generates the beta and theta parameters that are the basis for the rest of the simulation.

    All arguments are plain ``luigi.Parameter`` values and must be strings:
    they are concatenated with ``"".join`` to build the output id and passed
    verbatim on the Rscript command line.

    Arguments:
      D (str): How many samples are there in this experiment?
      V (str): How many terms are there across samples?
      K (str): How many topics are there?
      alpha0 (str): What is the true theta parameter prior used in
        generating data?
      gamma0 (str): What is the true beta parameter prior used in generating
        data?
    """
    D = luigi.Parameter()
    V = luigi.Parameter()
    K = luigi.Parameter()
    alpha0 = luigi.Parameter()
    gamma0 = luigi.Parameter()

    # Evaluated once, when the class body runs (i.e. at import time).
    conf = configuration.get_config()

    def _gen_id(self):
        """Return the id that names this parameter set's output files."""
        # NOTE(review): values are joined without a separator, so e.g.
        # D="1", V="23" and D="12", V="3" hash identically. Kept unchanged
        # because the id names files that other code may already rely on.
        return hash_string("".join(
            [self.D, self.V, self.K, self.alpha0, self.gamma0]))

    def run(self):
        """Invoke the configured R parameter script to write beta/theta for this id."""
        gen_id = self._gen_id()
        run_cmd = [
            "Rscript",
            self.conf.get("expers", "param_script"),
            self.conf.get("expers", "output_dir"), gen_id, self.D, self.V,
            self.K, self.alpha0, self.gamma0
        ]
        run_and_check(run_cmd)

    def output(self):
        """Return the beta and theta feather files under ``[expers] output_dir``."""
        gen_id = self._gen_id()
        output_dir = self.conf.get("expers", "output_dir")
        return [
            luigi.LocalTarget(
                os.path.join(output_dir, "beta-" + gen_id + ".feather")),
            luigi.LocalTarget(
                os.path.join(output_dir, "theta-" + gen_id + ".feather"))
        ]
コード例 #33
0
ファイル: notifications_test.py プロジェクト: sabtra73/luigi
    def test_sns_subject_is_shortened(self):
        """
        Call notifications.send_email_sns with a Subject longer than 100 chars
        and check that it is cut to a length of at most 100 chars.
        """

        long_subject = 'Luigi: SanityCheck(regexPattern=aligned-source\\|data-not-older\\|source-chunks-compl,'\
                       'mailFailure=False, mongodb=mongodb://localhost/stats) FAILED'
        # Guard the fixture: the test is meaningless if the subject is already short.
        self.assertGreater(len(long_subject), 100)

        with mock.patch('boto3.resource') as res:
            notifications.send_email_sns(configuration.get_config(),
                                         self.sender, long_subject, self.message,
                                         self.recipients, self.image_png)

            SNS = res.return_value
            SNS.Topic.assert_called_once_with(self.recipients[0])
            called_subj = SNS.Topic.return_value.publish.call_args[1]['Subject']
            self.assertLessEqual(len(called_subj), 100,
                                 "Subject can be max 100 chars long! Found {}.".format(len(called_subj)))
コード例 #34
0
def send_error_email(subject, message):
    """
    Email ``message`` to the address configured as ``[core] error-email``.

    The sender is ``[core] email-sender`` (falling back to
    ``DEFAULT_CLIENT_EMAIL``). When no error-email is configured nothing is
    sent; an informational message is logged instead.
    """
    config = configuration.get_config()
    receiver = config.get('core', 'error-email', None)
    if not receiver:
        logger.info("Skipping error email. Set `error-email` in the `core` "
                    "section of the luigi config file to receive error "
                    "emails.")
        return

    sender = config.get('core', 'email-sender', DEFAULT_CLIENT_EMAIL)
    logger.info("Sending warning email to %r", receiver)
    send_email(subject=subject, message=message, sender=sender,
               recipients=(receiver, ))
コード例 #35
0
ファイル: notifications_test.py プロジェクト: sabtra73/luigi
    def test_sends_ses_email(self):
        """
        Call notifications.send_email_ses with the fixture parameters and
        check that the boto3 SES client sends exactly one raw email.
        """

        with mock.patch('boto3.client') as client_factory, \
                mock.patch('luigi.notifications.generate_email') as fake_generate:
            fake_message = fake_generate.return_value
            fake_message.as_string.return_value = self.mocked_email_msg

            notifications.send_email_ses(configuration.get_config(),
                                         *self.notification_args)

            ses_client = client_factory.return_value
            ses_client.send_raw_email.assert_called_once_with(
                Source=self.sender,
                Destinations=self.recipients,
                RawMessage={'Data': self.mocked_email_msg})
コード例 #36
0
def create_spawner(turbsim_exe, fast_exe, turbsim_base_file, fast_base_file,
                   fast_version, runner_type, turbsim_working_dir,
                   fast_working_dir, outdir, prereq_outdir):
    """Create a FAST simulation spawner backed by TurbSim wind generation.

    All arguments are checked (via ``validate_file``/``validate_dir`` and a
    ``fast_version`` lookup) before any global state is changed. The shared
    luigi configuration is then updated with the executable path, runner type
    and working directory of ``WindGenerationTask`` and
    ``FastSimulationTask``.

    :param turbsim_exe: Location of TurbSim executable
    :param fast_exe: Location of FAST executable
    :param turbsim_base_file: Baseline TurbSim input file (typically `TurbSim.inp`)
        from which wind file generation tasks are spawned
    :param fast_base_file: FAST input file (typically `.fst`) to which all parameter
        edits are made and from which simulations are spawned
    :param fast_version: Major version of FAST {'v7', 'v8'}
    :param runner_type: Runner type written to the luigi config for both task
        types (e.g. `process`)
    :param turbsim_working_dir: Directory in which TurbSim wind generation tasks are executed
    :param fast_working_dir: Directory in which FAST simulations are executed.
        Note that the discon.dll must be in this directory
    :param outdir: Root output directory for spawning and thus where simulation outputs are located
    :param prereq_outdir: Root output directory for prerequisite tasks (i.e. wind file generation),
        joined onto `outdir`
    :returns: `FastSimulationSpawner` object
    :raises KeyError: if `fast_version` is not 'v7' or 'v8'
    """
    validate_file(turbsim_exe, 'turbsim_exe')
    validate_file(fast_exe, 'fast_exe')
    validate_file(turbsim_base_file, 'turbsim_base_file')
    validate_file(fast_base_file, 'fast_base_file')
    validate_dir(turbsim_working_dir, 'turbsim_working_dir')
    validate_dir(fast_working_dir, 'fast_working_dir')

    fast_input_cls = {'v7': Fast7Input, 'v8': Fast8Input}
    if fast_version not in fast_input_cls:
        # Reject unknown versions before the global luigi config is mutated
        # and the TurbSim input is parsed. KeyError matches the exception the
        # former late dict lookup raised.
        raise KeyError('Unsupported fast_version {!r}; expected one of {}'.format(
            fast_version, sorted(fast_input_cls)))

    luigi_config = configuration.get_config()

    luigi_config.set(WindGenerationTask.__name__, '_exe_path', turbsim_exe)
    luigi_config.set(WindGenerationTask.__name__, '_runner_type', runner_type)
    luigi_config.set(WindGenerationTask.__name__, '_working_dir',
                     turbsim_working_dir)
    luigi_config.set(FastSimulationTask.__name__, '_exe_path', fast_exe)
    luigi_config.set(FastSimulationTask.__name__, '_runner_type', runner_type)
    luigi_config.set(FastSimulationTask.__name__, '_working_dir',
                     fast_working_dir)

    wind_spawner = TurbsimSpawner(TurbsimInput.from_file(turbsim_base_file))
    return FastSimulationSpawner(
        fast_input_cls[fast_version].from_file(fast_base_file), wind_spawner,
        path.join(outdir, prereq_outdir))
コード例 #37
0
    def __init__(self, url='http://localhost:8082/', connect_timeout=None):
        """Set up an RPC client for the scheduler at ``url``.

        :param url: scheduler base URL (a trailing '/' is stripped).
            ``http+unix://`` URLs require ``HAS_UNIX_SOCKET``.
        :param connect_timeout: connect timeout; when None, taken from
            ``[core] rpc-connect-timeout`` (default 10.0).

        Retry count and wait come from ``[core] rpc-retry-attempts``
        (default 3) and ``[core] rpc-retry-wait`` (default 30). A
        ``RequestsFetcher`` is used when ``HAS_REQUESTS`` is true, otherwise a
        ``URLLibFetcher``.
        """
        assert not url.startswith('http+unix://') or HAS_UNIX_SOCKET, (
            'You need to install requests-unixsocket for Unix socket support.')

        self._url = url.rstrip('/')
        config = configuration.get_config()

        self._connect_timeout = (
            config.getfloat('core', 'rpc-connect-timeout', 10.0)
            if connect_timeout is None else connect_timeout)

        self._rpc_retry_attempts = config.getint(
            'core', 'rpc-retry-attempts', 3)
        self._rpc_retry_wait = config.getint('core', 'rpc-retry-wait', 30)

        self._fetcher = (RequestsFetcher(requests.Session()) if HAS_REQUESTS
                         else URLLibFetcher())
コード例 #38
0
class Tweets_String(SparkSubmitTask):
    """Spark submit task running ``tweets_string.py`` for one sighting date.

    The Spark app receives the path produced by ``ReadContainer`` and the
    path of this task's JSON output under ``/home/dpa_worker/model_data/``.
    """
    sighting_date = luigi.DateParameter()
    # Read from the luigi config once, when the class body executes.
    bucket = configuration.get_config().get('etl', 'bucket')

    def requires(self):
        return ReadContainer()

    @property
    def name(self):
        return 'Tweets_String'

    def app_options(self):
        # Input and output paths are passed to the Spark app as arguments.
        return [self.input().path, self.output().path]

    @property
    def app(self):
        return 'tweets_string.py'

    def output(self):
        # Zero-padded YYYYMMDD. The former unpadded year/month/day
        # concatenation mapped different dates to the same file (2017-01-11
        # and 2017-11-01 both became "2017111tweets.json"), so one date's
        # output could make the other look complete.
        return luigi.LocalTarget(
            '/home/dpa_worker/model_data/{:%Y%m%d}tweets.json'.format(
                self.sighting_date))
コード例 #39
0
ファイル: scheduler.py プロジェクト: sabtra73/luigi
    def __init__(self, config=None, resources=None, task_history_impl=None, **kwargs):
        """
        Keyword Arguments:
        :param config: an object of class "scheduler", or None, in which case
            ``scheduler(**kwargs)`` is constructed
        :param resources: a dict of str->int constraints; when falsy, the
            ``[resources]`` section of the luigi config is used
        :param task_history_impl: ignore config and use this object as the task history
        """
        self._config = config or scheduler(**kwargs)
        self._state = SimpleTaskState(self._config.state_path)

        if task_history_impl:
            history_backend = task_history_impl
        elif self._config.record_task_history:
            # db_task_history needs sqlalchemy, so it is only imported on demand.
            from luigi import db_task_history
            history_backend = db_task_history.DbTaskHistory()
        else:
            history_backend = history.NopHistory()
        self._task_history = history_backend

        # TODO: Can we make this a Parameter?
        self._resources = resources or configuration.get_config().getintdict('resources')
        retry_policy = self._config._get_retry_policy()
        self._make_task = functools.partial(Task, retry_policy=retry_policy)
        self._worker_requests = {}
コード例 #40
0
def create_manifest_target(manifest_id, targets):
    """Build a manifest target listing ``targets``, typed to match its URL.

    The manifest lives at ``<[CONFIG_SECTION] path>/<manifest_id>.manifest``.
    When running locally it must be a local file target. On a real Hadoop
    cluster it must be an HDFS file so the input format can read it. Luigi
    expresses that difference through inheritance, so the target class is
    assembled at runtime. It combines ``ManifestInputTargetMixin`` (which
    holds the manifest logic) with whichever file-target class
    ``get_target_class_from_url`` returns for the manifest URL.
    """
    manifest_root = configuration.get_config().get(CONFIG_SECTION, 'path')
    manifest_url = url_path_join(manifest_root, manifest_id + '.manifest')

    target_base, base_args, base_kwargs = get_target_class_from_url(manifest_url)

    class ManifestInputTarget(ManifestInputTargetMixin, target_base):
        pass

    return ManifestInputTarget.from_existing_targets(targets, *base_args, **base_kwargs)
コード例 #41
0
def send_error_email(subject, message, additional_recipients=None):
    """
    Send an error email to the recipients from ``_email_recipients``.

    ``_email_recipients(additional_recipients)`` decides the recipient list
    (presumably the configured error-email plus any extras). If it yields
    nothing, no email is sent and a message is logged instead. The sender is
    ``[core] email-sender`` (default ``DEFAULT_CLIENT_EMAIL``).
    """
    config = configuration.get_config()
    recipients = _email_recipients(additional_recipients)
    if recipients:
        sender = config.get('core', 'email-sender', DEFAULT_CLIENT_EMAIL)
        logger.info("Sending warning email to %r", recipients)
        send_email(subject=subject,
                   message=message,
                   sender=sender,
                   recipients=recipients)
    else:
        # Adjacent literals are concatenated, so each piece needs its own
        # trailing space (previously rendered as "`owner_email`in the task").
        logger.info(
            "Skipping error email. Set `error-email` in the `core` "
            "section of the luigi config file or override `owner_email` "
            "in the task to receive error emails.")
コード例 #42
0
ファイル: json_input_file.py プロジェクト: BitBloomTech/spawn
def create_spawner(task_exe, working_dir, base_file, runner_type):
    """
    Creates spawner that creates tasks taking a single JSON input file as command line argument

    :param task_exe: Path of executable to run. If None, input files will be written but no tasks run
    :param working_dir: Working directory of task execution. If None, ``'.'`` is used
    :param base_file: Baseline JSON file on which to make parameter editions and additions. If None, parameter additions
     will be made onto an empty input
    :param runner_type: Runner type written to the luigi config for :class:`SimulationTask`
    :return: :class:`SingleInputFileSpawner` object
    """
    if task_exe is not None:
        validate_file(task_exe, 'task_exe')
    if working_dir is not None:
        validate_dir(working_dir, 'working_dir')
    if base_file is not None:
        validate_file(base_file, 'base_file')

    if task_exe is None:
        task_exe = ''
    if working_dir is None:
        working_dir = '.'

    logger = logging.getLogger(__name__)
    logger.info("Creating Single input file spawner with JSON input")
    logger.info("task_exe = %s", task_exe)
    logger.info("working_dir = %s", working_dir)
    logger.info("base_file = %s", base_file)
    logger.info("runner_type = %s", runner_type)

    luigi_config = configuration.get_config()
    luigi_config.set(SimulationTask.__name__, '_exe_path', task_exe)
    luigi_config.set(SimulationTask.__name__, '_working_dir', working_dir)
    luigi_config.set(SimulationTask.__name__, '_runner_type', runner_type)

    if base_file is not None:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's
        # locale encoding (e.g. cp1252 on Windows). 'utf-8-sig' also accepts
        # a leading BOM, as written by some Windows editors.
        with open(base_file, 'r', encoding='utf-8-sig') as fp:
            params = json.load(fp)
    else:
        params = {}
    sim_input = JsonSimulationInput(params, indent=2)
    return SingleInputFileSpawner(sim_input, 'input.json')
コード例 #43
0
ファイル: compranet.py プロジェクト: rsanchezavalos/compranet
class IngestPipeline(luigi.WrapperTask):
    """
    Wrapper task that runs the ingestion of every configured pipeline-task.

    Input args (from luigi.cfg):
        pipelines: list of the pipeline-tasks to run, read from
            ``[IngestPipeline] pipelines`` through ``parse_cfg_list``.
    """

    year_month = luigi.Parameter()
    conf = configuration.get_config()
    pipelines = parse_cfg_list(conf.get("IngestPipeline", "pipelines"))

    def requires(self):
        """Yield one UpdateDB task per configured pipeline for ``year_month``."""
        for task_name in self.pipelines:
            yield UpdateDB(pipeline_task=task_name,
                           year_month=self.year_month)
コード例 #44
0
def get_pgdict_from_cfg():
    """
    Load Postgres connection settings from the luigi config file.

    Reads ``host``, ``database``, ``user`` and ``password`` from the
    ``[postgres]`` section.

    :returns: dict with keys PGUSER, PGPASSWORD, PGHOST and PGDATABASE, or
        None if the configuration could not be read.
    """
    try:
        cfg = configuration.get_config()
        pghost = cfg.get('postgres', 'host')
        pgdb = cfg.get('postgres', 'database')
        pguser = cfg.get('postgres', 'user')
        pgpassword = cfg.get('postgres', 'password')
    except Exception:
        # Best-effort lookup: callers get None when the config can't be read.
        # Unlike a bare ``except:``, this does not swallow
        # KeyboardInterrupt/SystemExit.
        return None

    return {
        'PGUSER': pguser,
        'PGPASSWORD': pgpassword,
        'PGHOST': pghost,
        'PGDATABASE': pgdb,
    }
コード例 #45
0
    def should_include_answer(self, answer):
        """Determine if a problem "part" should be included in the distribution.

        :param answer: mapping with an optional ``'response_type'`` key.
        :returns: True only when the response type is listed in the
            comma-separated ``[answer-distribution] valid_response_types``
            option (or its built-in default).
        """
        response_type = answer.get('response_type')

        # For problems which only have old responses, we don't
        # have information about whether to include their answers.
        if response_type is None:
            return False

        # Read valid types from the client.cfg file. The 3rd argument below
        # sets a default in case the config file is somehow misread, but to
        # change the list, please update client.cfg.
        valid_type_str = get_config().get(
            'answer-distribution', 'valid_response_types',
            'choiceresponse,optionresponse,multiplechoiceresponse,numericalresponse,stringresponse,formularesponse'
        )

        # Strip each entry so "choiceresponse, optionresponse" also matches.
        valid_types = {entry.strip() for entry in valid_type_str.split(',')}
        return response_type in valid_types
コード例 #46
0
def convert_tasks_to_manifest_if_necessary(input_tasks):  # pylint: disable=invalid-name
    """
    Provide a manifest for the input paths if there are too many of them.

    The configuration section "manifest" (``CONFIG_SECTION``) can contain a
    "threshold" option. When the number of input targets reaches (is greater
    than or equal to) a positive threshold, this function returns a single
    URLManifestTask over all target paths instead of the original
    input_tasks. A missing or non-positive threshold disables the manifest.

    :param input_tasks: a task or structure of tasks (flattened with
        ``task.flatten``).
    :returns: ``[URLManifestTask(...)]`` or the flattened input tasks.
    """
    all_input_tasks = task.flatten(input_tasks)
    targets = task.flatten(task.getpaths(all_input_tasks))
    threshold = configuration.get_config().getint(CONFIG_SECTION, 'threshold',
                                                  -1)
    if threshold <= 0:
        log.debug(
            'Directly processing %d inputs since the manifest threshold is disabled (%d)',
            len(targets), threshold)
        return all_input_tasks

    if len(targets) >= threshold:
        log.debug(
            'Using manifest since %d inputs are greater than or equal to the threshold %d',
            len(targets), threshold)
        return [URLManifestTask(urls=[target.path for target in targets])]

    log.debug(
        'Directly processing files since %d inputs are less than the threshold %d',
        len(targets), threshold)
    return all_input_tasks
コード例 #47
0
    def __init__(self,
                 auth_url=None,
                 client_id=None,
                 client_secret=None,
                 oauth_username=None,
                 oauth_password=None,
                 token_type=None):
        """Create a REST API client configured from ``[edx-rest-api]``.

        A truthy explicit argument wins over the matching config option.
        ``client_id``, ``client_secret`` and ``auth_url`` are read from the
        config with no default. ``oauth_username``/``oauth_password`` fall
        back to None. ``token_type`` defaults to ``'jwt'``. Every response on
        the session is passed to ``log_response_hook``.
        """
        self._expires_at = None
        self._session = requests.Session()
        self._session.hooks = {'response': log_response_hook}

        config = configuration.get_config()
        section = 'edx-rest-api'

        def from_config(explicit, option, *default):
            # Only consult the config when no truthy value was supplied.
            return explicit or config.get(section, option, *default)

        self.client_id = from_config(client_id, 'client_id')
        self.client_secret = from_config(client_secret, 'client_secret')
        self.auth_url = from_config(auth_url, 'auth_url')
        self.token_type = token_type or 'jwt'
        self.oauth_username = from_config(oauth_username, 'oauth_username', None)
        self.oauth_password = from_config(oauth_password, 'oauth_password', None)
コード例 #48
0
class CreateSchema(luigi.Task):
    """
    Create a schema in the sedesol DB.

    The schema name comes from the ``schema_name`` parameter. The
    ``CREATE SCHEMA IF NOT EXISTS`` statement is handed to
    ``pg_sed.QueryString`` as this task's requirement. ``run`` then touches
    a marker file under ``[etl] logging_path`` to record completion.
    """
    schema_name = luigi.Parameter()
    conf = configuration.get_config()
    # %-template of the marker file path, filled in with the schema name.
    output_param = (conf.get('etl', 'logging_path') + '%s_schema_created.log')

    def requires(self):
        # NOTE(review): the name is interpolated into the SQL unquoted, so it
        # must be a trusted, valid identifier.
        return pg_sed.QueryString(
            "CREATE SCHEMA IF NOT EXISTS %s;" % self.schema_name)

    def run(self):
        # Touch the marker file: create it if absent, keep any content.
        with open(self.output_param % self.schema_name, 'a'):
            pass

    def output(self):
        return luigi.LocalTarget(self.output_param % self.schema_name)
コード例 #49
0
class CreateIndices(luigi.Task):
    """
    Execute the queries in the pipeline's ``clean_table_indexer`` SQL file.

    The queries have to run in sequence, because existing indices must be
    dropped before new ones are created. The QueryString() class can't be
    used either, because it has no requires() and would try to build the
    indices before the clean tables are even created.
    """
    pipeline_task = luigi.Parameter()
    conf = configuration.get_config()
    logging_path = conf.get('etl', 'logging_path')

    def requires(self):
        return RawToCleanSchema(self.pipeline_task)

    def _marker_path(self):
        """Path of the empty log file whose existence marks completion."""
        return "%s%s_clean_indices_created.log" % (self.logging_path,
                                                   self.pipeline_task)

    def run(self):
        query_filename = self.conf.get(self.pipeline_task,
                                       'clean_table_indexer')
        # Read (and close) the SQL file before opening a DB connection.
        with open(query_filename, 'r') as query_file:
            all_queries = query_file.read().split(';')

        connection = pg_sed.db_connection()
        try:
            connection.autocommit = True
            cursor = connection.cursor()
            for query in all_queries:
                # Skip blank fragments, including the '' left after a final
                # ';' (''.isspace() is False, so the old check ran it).
                if not query.strip():
                    continue
                logger.info("Executing query: %s", query)
                cursor.execute(query)
                connection.commit()
        finally:
            # Close the connection even when a query fails.
            connection.close()

        open(self._marker_path(), 'a').close()  # touch the output file

    def output(self):
        return luigi.LocalTarget(self._marker_path())
コード例 #50
0
ファイル: notifications.py プロジェクト: leochencipher/luigi
def send_email(subject, message, sender, recipients, image_png=None):
    """Send ``message`` to ``recipients`` through the configured backend.

    The subject is passed through ``_prefix``. Nothing is sent when there
    are no recipients, or when stdout is a tty / DEBUG is set and
    ``[email] force-send`` is false. Each recipient entry may hold several
    comma-separated addresses. ``[email] type`` selects the backend:
    "ses", "sendgrid", or SMTP otherwise.
    """
    config = configuration.get_config()

    subject = _prefix(subject)
    logger.debug(
        "Emailing:\n"
        "-------------\n"
        "To: %s\n"
        "From: %s\n"
        "Subject: %s\n"
        "Message:\n"
        "%s\n"
        "-------------", recipients, sender, subject, message)
    if not recipients or recipients == (None, ):
        return
    if (sys.stdout.isatty() or
            DEBUG) and (not config.getboolean('email', 'force-send', False)):
        logger.info(
            "Not sending email when running from a tty or in debug mode")
        return

    # Clean the recipients lists to allow multiple error-email addresses,
    # comma separated in client.cfg. Whitespace around each address is
    # stripped and empty entries (e.g. from "a@x.com, b@y.com,") dropped.
    recipients = [address.strip()
                  for entry in recipients
                  for address in entry.split(',')
                  if address.strip()]
    if not recipients:
        return

    email_sender_type = config.get('email', 'type', None)
    if email_sender_type == "ses":
        send_email_ses(config, sender, subject, message, recipients, image_png)
    elif email_sender_type == "sendgrid":
        send_email_sendgrid(config, sender, subject, message, recipients,
                            image_png)
    else:
        send_email_smtp(config, sender, subject, message, recipients,
                        image_png)
コード例 #51
0
    def requires(self):
        """Yield an ExternalURL for each CyberSource payments file to import.

        For every merchant in ``cybersource_merchant_ids`` the dates come from
        ``date_interval.Custom(interval_start, end)``. ``interval_start`` is
        read from the ``cybersource:<id>`` config section. ``end`` is
        ``import_date``, capped by ``merchant_close_date`` when one is
        configured. Each URL is
        ``<warehouse_path>/payments/dt=<date>/cybersource_<id>.tsv``.
        """
        config = get_config()
        for merchant_id in self.cybersource_merchant_ids:
            section = 'cybersource:' + merchant_id
            start = luigi.DateParameter().parse(
                config.get(section, 'interval_start'))
            end = self.import_date

            close_date_text = config.get(section, 'merchant_close_date', '')
            if close_date_text:
                close_date = luigi.DateParameter().parse(close_date_text)
                end = min(self.import_date, close_date)

            # Same file name for every date of this merchant.
            tsv_name = "cybersource_{}.tsv".format(merchant_id)
            for day in date_interval.Custom(start, end):
                yield ExternalURL(url=url_path_join(
                    self.warehouse_path, 'payments',
                    'dt=' + day.isoformat(), tsv_name))
コード例 #52
0
    def job_runner(self):
        """Build a Hadoop job runner that ships the virtualenv as an archive.

        ``venv_path`` is packed with ``_create_venv_archive`` unless it
        already ends in ``.zip``. HDFS paths must be pre-zipped, otherwise
        ValueError is raised. The configured ``[hadoop] python-executable``
        is remembered in ``self.old_python_executable``. It is then replaced
        in the global config by the same executable inside the unpacked
        archive.
        """
        config = configuration.get_config()
        archive_path = self.venv_path
        if not archive_path.lower().endswith(".zip"):
            if archive_path.startswith('hdfs://'):
                raise ValueError(
                    "Cannot automatically compress a venv located on HDFS"
                )
            archive_path = self._create_venv_archive(archive_path)

        base_python = config.get('hadoop', 'python-executable', 'python')
        self.old_python_executable = base_python
        # Link name the archive is unpacked under: file name up to its first '.'.
        link_name = archive_path.split('/')[-1].split('.')[0]
        # NOTE(review): this mutates the global config; if it is not restored
        # elsewhere, a later call would read the already-prefixed value.
        config.set('hadoop', 'python-executable',
                   "{}/bin/{}".format(link_name, base_python))

        return TYDefaultHadoopJobRunner(
            archives=["{}#{}".format(archive_path, link_name)],
            libjars=self.libjars,
        )
コード例 #53
0
class UnigramParams(luigi.ExternalTask):
    """
    Simulate parameters for a Unigram model

    This generates the parameters mu[t] for a particular instance of the
    dynamic unigram model by running the ``[expers] param_script`` R script,
    which writes ``mu-<gen_id>.feather`` to ``[expers] output_dir``.

    Arguments:
      D (int): How many samples are there in this experiment?
      V (int): How many terms are there across samples?
      sigma0 (float): What is the true sigma random walk size parameter used in
      generating the data?

    All arguments are luigi string parameters; they are joined as strings to
    build the output identifier.
    """
    D = luigi.Parameter()
    V = luigi.Parameter()
    sigma0 = luigi.Parameter()

    conf = configuration.get_config()

    def _gen_id(self):
        """Hash identifying this parameter combination in output file names."""
        # NOTE(review): values are concatenated without a separator, so
        # distinct combinations can collide (e.g. D="1", V="23" vs D="12",
        # V="3"). Left unchanged to keep existing output names stable.
        return hash_string("".join([self.D, self.V, self.sigma0]))

    # NOTE(review): defining run() on an ExternalTask makes luigi treat this
    # as a regular, runnable task — confirm that is intended.
    def run(self):
        run_cmd = [
            "Rscript",
            self.conf.get("expers", "param_script"),
            self.conf.get("expers", "output_dir"), self._gen_id(), self.D,
            self.V, self.sigma0
        ]
        run_and_check(run_cmd)

    def output(self):
        output_dir = self.conf.get("expers", "output_dir")
        return [
            luigi.LocalTarget(
                os.path.join(output_dir, "mu-" + self._gen_id() + ".feather"))
        ]
コード例 #54
0
class MicrobiomePred(luigi.WrapperTask):
    """Wrapper that evaluates every configured ensemble and experiment.

    Each entry of ``pf.values_from_conf(conf, "ensemble")`` gets an
    EnsembleEval task. Each entry of ``pf.values_from_conf(conf,
    "experiment")`` gets one CVEval task per fold label: "all", "all-cv",
    and 1 .. k_folds - 1.
    """
    conf = configuration.get_config()

    def requires(self):
        ps_path = os.path.join(self.conf.get("paths", "project_dir"),
                               self.conf.get("paths", "phyloseq"))

        ensembles = pf.values_from_conf(self.conf, "ensemble")
        tasks = [EnsembleEval(name) for name in ensembles.keys()]

        experiments = pf.values_from_conf(self.conf, "experiment")
        for exp_name in experiments.keys():
            spec = experiments[exp_name]
            fold_labels = ["all", "all-cv"] + list(range(1, spec["k_folds"]))
            tasks.extend(
                CVEval(ps_path, spec["preprocessing"],
                       str(spec["validation_prop"]),
                       str(spec["k_folds"]), spec["features"],
                       spec["model"], str(fold), spec["metrics"])
                for fold in fold_labels)

        return tasks
コード例 #55
0
class LoadCVIndices(luigi.WrapperTask):
    """
    Create CV indices for each current subset of the semantic table.

    The subsets come from ``pg_sed.unique_experiment_subsets`` applied to the
    pipeline's ``experiments`` file; each gets a LoadCVIndex task keyed on the
    subset's first element.
    """
    pipeline_task = luigi.Parameter()
    conf = configuration.get_config()

    def requires(self):
        experiments_file = self.conf.get(self.pipeline_task, "experiments")
        subset_rows = pg_sed.unique_experiment_subsets(
            experiments_file, self.conf, self.pipeline_task)
        return [LoadCVIndex(self.pipeline_task, subset_table=row[0])
                for row in subset_rows]
コード例 #56
0
ファイル: run_models.py プロジェクト: valencig/sedesol-public
class RunModels(luigi.WrapperTask):
    """
    Run, evaluate, and load many models.

    Models are run for every model / parameter / subset configuration
    listed in the pipeline's experiments JSON file (path taken from
    luigi.cfg), and the resulting models are saved to the database.
    """
    pipeline_task = luigi.Parameter()
    conf = configuration.get_config()
    k_folds = conf.get("shared", "k_folds")

    def requires(self):
        """
        Create one RunModel task per experiment / parameter setting / fold.

        Each experiment's ``model.params`` values are parsed with
        ``ast.literal_eval`` and expanded with ``ParameterGrid``. This is
        heavily influenced by https://github.com/rayidghani/magicloops/
        """
        with open(self.conf.get(self.pipeline_task, "experiments")) as handle:
            experiments = json.load(handle)

        tasks = []
        for exper_id, exper in experiments.items():
            grid = ParameterGrid({
                name: ast.literal_eval(raw)
                for name, raw in exper["model"]["params"].items()
            })
            tasks.extend(
                RunModel(exper_id, str(theta), str(fold), self.pipeline_task)
                for theta in grid
                for fold in range(int(self.k_folds)))

        return tasks
コード例 #57
0
class LoadNeo4j(luigi.WrapperTask):
    """
    Task that builds the Neo4j infrastructure needed to run queries and models.

    It requires LoadNeo4jTalentos and LoadNeo4jProyectos. Once both are done
    it writes an empty marker file, ``LoadedBothNeo4j``, under
    ``[etl] logging_path``.
    """

    conf = configuration.get_config()
    logging_path = conf.get('etl', 'logging_path')

    def requires(self):
        return LoadNeo4jTalentos(), LoadNeo4jProyectos()

    def run(self):
        # Finally: create an (empty) file that logs that both loads into
        # Neo4j have completed.
        with open(os.path.join(self.logging_path, "LoadedBothNeo4j"), "a"):
            pass

    def output(self):
        """
        The output is a log file indicating that both databases were loaded into Neo4j successfully.
        """
        marker = os.path.join(self.logging_path, "LoadedBothNeo4j")
        return luigi.LocalTarget(marker)
コード例 #58
0
 def run(self):
     """Submit the Spark job with ``spark-submit`` on YARN and wait for it.

     The binary comes from ``[spark] spark-submit`` (default
     ``spark-submit``). Executor/driver options and ``--py-files`` are only
     added when set. Progress is followed with ``self.track_progress``.
     stderr goes to a temporary file so it can be reported on failure.

     :raises SparkJobError: if YARN reports the application as FAILED, or if
         spark-submit exits with a non-zero return code.
     """
     spark_submit = configuration.get_config().get('spark', 'spark-submit',
                                                   'spark-submit')
     options = ['--master', 'yarn-client']
     if self.num_executors is not None:
         options += ['--num-executors', self.num_executors]
     if self.driver_memory is not None:
         options += ['--driver-memory', self.driver_memory]
     if self.executor_memory is not None:
         options += ['--executor-memory', self.executor_memory]
     if self.executor_cores is not None:
         options += ['--executor-cores', self.executor_cores]
     py_files = self.py_files()
     if py_files:
         options += ['--py-files', ','.join(py_files)]
     # A real list: on Python 3 a bare map() would be logged below as
     # "<map object ...>" instead of the command line.
     args = [str(arg) for arg in
             [spark_submit] + options + self.spark_options() +
             [self.program()] + list(self.job_args())]
     env = os.environ.copy()
     logger.info('Running: %s', repr(args))
     # The temporary stderr file is closed (and removed) on every exit path.
     with tempfile.TemporaryFile() as temp_stderr:
         proc = subprocess.Popen(args,
                                 stdout=subprocess.PIPE,
                                 stderr=temp_stderr,
                                 env=env,
                                 close_fds=True)
         return_code, final_state, app_id = self.track_progress(proc)
         if final_state == 'FAILED':
             # Format the message here; passing app_id as a second argument
             # left "%s" unformatted.
             raise SparkJobError(
                 'Spark job failed: see yarn logs for %s' % app_id)
         if return_code != 0:
             temp_stderr.seek(0)
             # 'replace' keeps undecodable bytes from masking the real error.
             errors = temp_stderr.read().decode('utf8', 'replace')
             logger.error(errors)
             raise SparkJobError('Spark job failed', err=errors)
コード例 #59
0
class UnigramData(luigi.Task):
    """
    Simulate data according to a Unigram model

    Runs the ``[expers] sim_script`` R script on the mu parameters produced
    by UnigramParams and writes ``x-<gen_id>.feather`` to
    ``[expers] output_dir``.

    Arguments:
      D (int): How many samples are there in this experiment?
      N (int): How many words are there in each sample?
      V (int): How many terms are there across samples?
      sigma0 (float): What is the true sigma random walk size parameter used in
      generating the data?

    All arguments are luigi string parameters; they are joined as strings to
    build the output identifier.
    """
    D = luigi.Parameter()
    N = luigi.Parameter()
    V = luigi.Parameter()
    sigma0 = luigi.Parameter()

    conf = configuration.get_config()

    def _gen_id(self):
        """Hash identifying this parameter combination in output file names."""
        # NOTE(review): values are concatenated without a separator, so
        # distinct combinations can collide (e.g. N="1", V="23" vs N="12",
        # V="3"). Left unchanged to keep existing output names stable.
        return hash_string("".join([self.D, self.N, self.V, self.sigma0]))

    def requires(self):
        return UnigramParams(self.D, self.V, self.sigma0)

    def run(self):
        # Only the file's path is passed to R. Read it from the target rather
        # than opening (and never closing) a handle just to get its `.name`.
        mu_path = self.input()[0].path
        run_cmd = [
            "Rscript",
            self.conf.get("expers", "sim_script"),
            self.conf.get("expers", "output_dir"), self._gen_id(), self.N,
            mu_path
        ]
        run_and_check(run_cmd)

    def output(self):
        output_dir = self.conf.get("expers", "output_dir")
        return luigi.LocalTarget(
            os.path.join(output_dir, "x-" + self._gen_id() + ".feather"))
コード例 #60
0
 def pig_home(self):
     """Return ``[pig] home`` from the luigi config (default ``/usr/share/pig``)."""
     config = configuration.get_config()
     return config.get('pig', 'home', '/usr/share/pig')