def proto1_to_proto2(message, body):
    """Converts Task message protocol 1 arguments to protocol 2.

    Returns tuple of ``(body, headers, already_decoded_status, utc)``
    """
    # Protocol 1 carries args/kwargs inside the body; both must be present.
    try:
        args = body['args']
        kwargs = body['kwargs']
    except KeyError:
        raise InvalidTaskError('Message does not have args/kwargs')
    # Probe for a mapping interface; anything else raises AttributeError.
    try:
        kwargs.items
    except AttributeError:
        raise InvalidTaskError(
            'Task keyword arguments must be a mapping',
        )
    body.update(
        argsrepr=saferepr(args),
        kwargsrepr=saferepr(kwargs),
        headers=message.headers,
    )
    # Protocol 1 called groups "tasksets"; expose them under the new key.
    if 'taskset' in body:
        body['group'] = body['taskset']
    return (args, kwargs), body, True, body.get('utc', True)
def proto1_to_proto2(message, body):
    """Convert Task message protocol 1 arguments to protocol 2.

    Returns:
        Tuple: of ``(body, headers, already_decoded_status, utc)``
    """
    try:
        args = body.get("args", ())
        kwargs = body.get("kwargs", {})
        kwargs.items  # probe: raises AttributeError unless kwargs is a mapping
    except KeyError:
        raise InvalidTaskError("Message does not have args/kwargs")
    except AttributeError:
        raise InvalidTaskError("Task keyword arguments must be a mapping")
    body.update(
        argsrepr=saferepr(args),
        kwargsrepr=saferepr(kwargs),
        headers=message.headers,
    )
    # Old protocol used "taskset" for what protocol 2 calls "group".
    if "taskset" in body:
        body["group"] = body["taskset"]
    embed = {
        "callbacks": body.get("callbacks"),
        "errbacks": body.get("errbacks"),
        "chord": body.get("chord"),
        "chain": None,
    }
    return (args, kwargs, embed), body, True, body.get("utc", True)
def proto1_to_proto2(message, body):
    """Convert Task message protocol 1 arguments to protocol 2.

    Returns:
        Tuple: of ``(body, headers, already_decoded_status, utc)``
    """
    try:
        args, kwargs = body.get('args', ()), body.get('kwargs', {})
        # Accessing .items verifies kwargs quacks like a mapping.
        kwargs.items
    except KeyError:
        raise InvalidTaskError('Message does not have args/kwargs')
    except AttributeError:
        raise InvalidTaskError('Task keyword arguments must be a mapping')
    body.update(
        argsrepr=saferepr(args),
        kwargsrepr=saferepr(kwargs),
        headers=message.headers,
    )
    try:
        body['group'] = body['taskset']
    except KeyError:
        pass
    # Canvas-related fields move into the protocol 2 "embed" section.
    embed = {key: body.get(key) for key in ('callbacks', 'errbacks', 'chord')}
    embed['chain'] = None
    return (args, kwargs, embed), body, True, body.get('utc', True)
def dispose_compute_resource(self, task_data, **kwargs):
    """Celery task: tear down the compute resource tied to a job.

    The resource id comes from ``task_data['compute_resource_id']``, or is
    derived from the job record when absent.  The actual termination is
    not implemented yet (see TODO below).
    """
    from ..models import Job, ComputeResource
    if task_data is None:
        raise InvalidTaskError("task_data is None")
    compute_resource_id = task_data.get('compute_resource_id', None)
    # result = task_data.get('result')
    if not compute_resource_id:
        # No explicit id supplied: look it up via the job instead.
        job_id = task_data.get('job_id')
        job = Job.objects.get(id=job_id)
        compute_resource_id = job.compute_resource.id
    compute = ComputeResource.objects.get(id=compute_resource_id)
    # NOTE(review): `self.status`/`self.save()` look like they were meant
    # to be `compute.status`/`compute.save()` — confirm what `self` is
    # bound to here (celery bound task vs. model method).
    self.status = ComputeResource.STATUS_TERMINATING
    self.save()
    # TODO: Terminate the compute resource
    # (different depending on cloud provider, resource type)
    raise NotImplementedError()
    ################################################################
    # NOTE: everything below is unreachable until the raise above is
    # replaced with a real implementation.
    self.status = ComputeResource.STATUS_DECOMMISSIONED
    self.save()
    return task_data
def __init__(self, body, on_ack=noop, hostname=None, eventer=None,
             app=None, connection_errors=None, request_dict=None,
             delivery_info=None, task=None, **opts):
    """Build a worker request from a decoded task message body.

    NOTE(review): this uses Python 2 ``except ..., exc`` syntax, so the
    surrounding module cannot run on Python 3 as written.  Also,
    ``expires`` is read but never used in this excerpt and no ``else``
    branch sets ``self.eta`` when ``eta`` is None — this looks like a
    truncated copy of a longer ``__init__``; confirm against the source.
    """
    self.app = app or app_or_default(app)
    name = self.name = body['task']
    self.id = body['id']
    self.args = body.get('args', [])
    self.kwargs = body.get('kwargs', {})
    try:
        # Probe that kwargs behaves like a mapping.
        self.kwargs.items
    except AttributeError:
        raise InvalidTaskError('Task keyword arguments is not a mapping')
    if NEEDS_KWDICT:
        self.kwargs = kwdict(self.kwargs)
    eta = body.get('eta')
    expires = body.get('expires')
    utc = self.utc = body.get('utc', False)
    self.on_ack = on_ack
    self.hostname = hostname or socket.gethostname()
    self.eventer = eventer
    self.connection_errors = connection_errors or ()
    self.task = task or self.app.tasks[name]
    self.acknowledged = self._already_revoked = False
    self.time_start = self.worker_pid = self._terminate_on_ack = None
    self._tzlocal = None
    # timezone means the message is timezone-aware, and the only timezone
    # supported at this point is UTC.
    if eta is not None:
        try:
            self.eta = maybe_iso8601(eta)
        except (AttributeError, ValueError), exc:
            raise InvalidTaskError('invalid eta value %r: %s' % (
                eta, exc,
            ))
        if utc:
            self.eta = maybe_make_aware(self.eta, self.tzlocal)
def register(self, task):
    """Register a task in the task registry.

    The task will be automatically instantiated if not already an
    instance.  The task's name must be configured before registration.
    """
    if task.name is None:
        raise InvalidTaskError(
            'Task "class {0}" must specify name'.format(
                task.__class__.__name__))
    # Classes are instantiated on registration; instances are kept as-is.
    entry = inspect.isclass(task) and task() or task
    self[task.name] = entry
def hybrid_to_proto2(message, body):
    """Create a fresh protocol 2 message from a hybrid protocol 1/2 message.

    Args:
        message: raw transport message (unused; kept for signature
            compatibility with the other converters).
        body: decoded protocol 1 body (a dict).

    Returns:
        Tuple of ``((args, kwargs, embed), headers, already_decoded, utc)``.

    Raises:
        InvalidTaskError: if ``kwargs`` is not a mapping.
    """
    try:
        args, kwargs = body.get('args', ()), body.get('kwargs', {})
        kwargs.items  # pylint: disable=pointless-statement
    except KeyError:
        raise InvalidTaskError('Message does not have args/kwargs')
    except AttributeError:
        raise InvalidTaskError('Task keyword arguments must be a mapping')
    headers = {
        'lang': body.get('lang'),
        'task': body.get('task'),
        'id': body.get('id'),
        'root_id': body.get('root_id'),
        'parent_id': body.get('parent_id'),
        'group': body.get('group'),
        'meth': body.get('meth'),
        'shadow': body.get('shadow'),
        'eta': body.get('eta'),
        'expires': body.get('expires'),
        # (fixed) protocol 2 consumers expect an int retry count and a
        # 2-tuple time limit, not None — default them explicitly.
        'retries': body.get('retries', 0),
        'timelimit': body.get('timelimit', (None, None)),
        'argsrepr': body.get('argsrepr'),
        'kwargsrepr': body.get('kwargsrepr'),
        'origin': body.get('origin'),
    }
    embed = {
        'callbacks': body.get('callbacks'),
        'errbacks': body.get('errbacks'),
        'chord': body.get('chord'),
        'chain': None,
    }
    # (fixed) removed leftover error-level debug logging (octolog.error)
    # that fired on every converted message.
    return (args, kwargs, embed), headers, True, body.get('utc', True)
def inference_task(task_id, run_path, package_path, kwargs):
    """Run the full PGM inference pipeline for one task, then package it.

    Creates ``run_path``, writes the pairwise-interaction and observation
    inputs, builds (or receives) a factorgraph, runs inference, and
    finally archives the run directory to ``package_path`` as a zip.

    Raises:
        InvalidTaskError: if any pipeline stage returns a non-"0" code.
    """
    os.mkdir(run_path)
    meta_info = {"task_id": task_id, "task_kwargs": kwargs}
    # Pairwise Interaction
    pi_file = kwargs["pi_file"]
    write_pairwise_interaction(run_path, pi_file)
    split_return_code = split_pairwise_interaction(run_path)
    # NOTE(review): return codes are compared as the string "0" — the
    # helpers apparently yield str codes; confirm.
    if split_return_code != "0":
        raise InvalidTaskError()
    # Observation
    obs_file = kwargs["obs_file"]
    run_type = kwargs["task_type"]
    write_observation(run_path, obs_file, run_type)
    # Logical (|| Learnt) Factorgraph
    # The generator yields: command string, return code, runtime.
    lfg_generator = generate_logical_factorgraph(run_path)
    meta_info["lfg_command"] = next(lfg_generator)
    lfg_return_code = next(lfg_generator)
    meta_info["lfg_runtime"] = next(lfg_generator)
    # return_code = generate_logical_factorgraph(run_path)
    if lfg_return_code != "0":
        raise InvalidTaskError()
    # If a learnt factorgraph is provided, use it for inference instead
    lfg_file = kwargs["lfg_file"]
    if lfg_file != "":
        write_learnt_factorgraph(run_path, lfg_file)
        fg_name = "learnt.fg"
    else:
        fg_name = "logical.fg"
    # Inference
    number_states = kwargs["number_states"]
    inference_generator = inference(run_path, number_states, fg_name)
    meta_info["pgm_command"] = next(inference_generator)
    pgm_return_code = next(inference_generator)
    meta_info["pgm_runtime"] = next(inference_generator)
    # pgm_return_code = inference(run_path, number_states, fg_name)
    if pgm_return_code != "0":
        raise InvalidTaskError()
    # Package
    write_info_file(run_path, meta_info)
    shutil.make_archive(package_path, "zip", root_dir=run_path)
def register(self, task):
    """Register a task in the task registry.

    The task will be automatically instantiated if not already an
    instance.  The ``.name`` attribute must be set before registering.
    """
    if task.name is None:
        raise InvalidTaskError(
            'Task class {0!r} must specify .name attribute'.format(
                type(task).__name__))
    # Instantiate classes; leave ready-made instances untouched.
    entry = inspect.isclass(task) and task() or task
    add_autoretry_behaviour(entry)
    self[entry.name] = entry
def learning_task(task_id, run_path, package_path, kwargs):
    """Run the PGM parameter-learning pipeline for one task, then package it.

    Mirrors ``inference_task`` but runs the learning stage instead of
    inference, archiving the run directory to ``package_path`` as a zip.

    Raises:
        InvalidTaskError: if any pipeline stage returns a non-"0" code.
    """
    os.mkdir(run_path)
    meta_info = {"task_id": task_id, "task_kwargs": kwargs}
    # Pairwise interaction, split .pi file
    pi_file = kwargs["pi_file"]
    write_pairwise_interaction(run_path, pi_file)
    split_return_code = split_pairwise_interaction(run_path)
    # NOTE(review): return codes are compared as the string "0"; confirm
    # the helpers yield str codes.
    if split_return_code != "0":
        raise InvalidTaskError()
    # Observation
    obs_file = kwargs["obs_file"]
    run_type = kwargs["task_type"]
    write_observation(run_path, obs_file, run_type)
    # Factorgraph
    # The generator yields: command string, return code, runtime.
    lfg_generator = generate_logical_factorgraph(run_path)
    meta_info["lfg_command"] = next(lfg_generator)
    lfg_return_code = next(lfg_generator)
    meta_info["lfg_runtime"] = next(lfg_generator)
    # return_code = generate_logical_factorgraph(run_path)
    if lfg_return_code != "0":
        # self.update_state(state=states.FAILURE, meta="reason for failure")
        raise InvalidTaskError()
    # Learning
    number_states = kwargs["number_states"]
    change_limit = kwargs["change_limit"]
    max_iterations = kwargs["max_iterations"]
    learning_generator = learning(run_path, number_states, change_limit,
                                  max_iterations)
    meta_info["pgm_command"] = next(learning_generator)
    pgm_return_code = next(learning_generator)
    meta_info["pgm_runtime"] = next(learning_generator)
    # pgm_return_code = learning(run_path, number_states, change_limit, max_iterations)
    if pgm_return_code != "0":
        raise InvalidTaskError()
    # Package to down
    write_info_file(run_path, meta_info)
    shutil.make_archive(package_path, "zip", root_dir=run_path)
def test_receive_message_InvalidTaskError(self, error):
    # A strategy raising InvalidTaskError should be logged as an
    # invalid task message, not crash the consumer loop.
    consumer = self.LoopConsumer()
    consumer.blueprint.state = RUN
    consumer.steps.pop()
    message = self.create_task_message(
        Mock(), self.foo_task.name,
        args=(1, 2), kwargs='foobarbaz', id=1)
    consumer.update_strategies()
    strategy_mock = consumer.strategies[self.foo_task.name] = Mock(
        name='strategy')
    strategy_mock.side_effect = InvalidTaskError()
    on_message = self._get_on_message(consumer)
    on_message(message)
    error.assert_called()
    assert 'Received invalid task message' in error.call_args[0][0]
def hybrid_to_proto2(message, body):
    """Create a fresh protocol 2 message from a hybrid protocol 1/2 message."""
    try:
        args = body.get("args", ())
        kwargs = body.get("kwargs", {})
        kwargs.items  # probe: raises AttributeError unless kwargs is a mapping
    except KeyError:
        raise InvalidTaskError("Message does not have args/kwargs")
    except AttributeError:
        raise InvalidTaskError("Task keyword arguments must be a mapping")
    get = body.get
    headers = {
        "lang": get("lang"),
        "task": get("task"),
        "id": get("id"),
        "root_id": get("root_id"),
        "parent_id": get("parent_id"),
        "group": get("group"),
        "meth": get("meth"),
        "shadow": get("shadow"),
        "eta": get("eta"),
        "expires": get("expires"),
        "retries": get("retries", 0),
        "timelimit": get("timelimit", (None, None)),
        "argsrepr": get("argsrepr"),
        "kwargsrepr": get("kwargsrepr"),
        "origin": get("origin"),
    }
    # Canvas-related fields live in the protocol 2 "embed" section.
    embed = {
        "callbacks": get("callbacks"),
        "errbacks": get("errbacks"),
        "chord": get("chord"),
        "chain": None,
    }
    return (args, kwargs, embed), headers, True, get("utc", True)
def set_job_status(self, task_data=None, **kwargs):
    """Celery task: update a Job's status.

    If the job has finished and its compute resource is disposable with
    no other running jobs, the resource is disposed of as well.  The
    update and the dispose check run inside one DB transaction.
    """
    from ..models import Job, File
    if task_data is None:
        raise InvalidTaskError("task_data is None")
    job_id = task_data.get('job_id')
    status = task_data.get('status')
    with transaction.atomic():
        job = Job.objects.get(id=job_id)
        job.status = status
        job.save()
        # Tear down an idle disposable compute resource once the job is done.
        if (job.done and
                job.compute_resource and
                job.compute_resource.disposable and
                not job.compute_resource.running_jobs()):
            job.compute_resource.dispose()
    return task_data
def test_receive_message_InvalidTaskError(self, error):
    # Build a consumer whose strategy for foo_task always raises
    # InvalidTaskError and verify the failure is logged.
    consumer = _MyKombuConsumer(self.buffer.put, timer=self.timer,
                                app=self.app)
    consumer.blueprint.state = RUN
    consumer.event_dispatcher = mock_event_dispatcher()
    consumer.steps.pop()
    consumer.controller = consumer.app.WorkController()
    consumer.pool = consumer.controller.pool = Mock()
    message = create_task_message(
        Mock(), self.foo_task.name,
        args=(1, 2), kwargs='foobarbaz', id=1)
    consumer.update_strategies()
    consumer.event_dispatcher = mock_event_dispatcher()
    strategy_mock = consumer.strategies[self.foo_task.name] = Mock(
        name='strategy')
    strategy_mock.side_effect = InvalidTaskError()
    on_message = self._get_on_message(consumer)
    on_message(message)
    self.assertTrue(error.called)
    self.assertIn('Received invalid task message', error.call_args[0][0])
def generate_time_spent_per_user_report(courses, *args, **kwargs):
    """
    Returns the time spent per user data for the given courses.

    Args:
        courses: Course ids list.
        date: Date string for querying in Google BigQuery.
            Date format: '%Y-%m-%d' e.g. '2019-01-01'
    Returns:
        Dict with 'time_spent_per_user_data' containing time spent per user report data.
    """
    date_format = '%Y-%m-%d'
    report_data = {}
    try:
        date_field = datetime.strptime(kwargs.get('date', ''), date_format)
    except ValueError:
        # Surface a JSON payload the calling view can return directly.
        raise InvalidTaskError(
            json.dumps({
                'data': {
                    'status': FAILURE,
                    'result': 'date field was not provided or has an invalid format.',
                },
                'status': status.HTTP_400_BAD_REQUEST,
            }),
        )
    for course_id in courses:
        try:
            course_key = CourseKey.from_string(course_id)
        except InvalidKeyError:
            # Malformed course ids are skipped silently.
            continue
        time_spent_per_user_report = GenerateTimeSpentPerUserReport(
            users=get_enrolled_users(course_key),
            course_key=course_key,
            query_date=date_field.strftime(date_format),
        )
        report_data[course_id] = time_spent_per_user_report.generate_report_data()
    return report_data
def generate_activity_completion_report(courses, *args, **kwargs):
    """Return the activity completion report data for the given courses.

    Args:
        courses: list of course id strings.
        **kwargs: validated by ``ActivityCompletionReportSerializer``.

    Returns:
        Dict mapping course id -> completion report data.  Course ids
        that fail to parse are skipped.

    Raises:
        InvalidTaskError: if kwargs fail serializer validation; the
            message is a JSON payload usable directly by the view.
    """
    data = {}
    serialized_data = ActivityCompletionReportSerializer(data=kwargs)
    if not serialized_data.is_valid():
        # Raises the error containing the JsonResponse parameters
        # to be used in the view.
        raise InvalidTaskError(
            json.dumps({
                'data': {
                    'status': FAILURE,
                    'result': serialized_data.errors,
                },
                'status': status.HTTP_400_BAD_REQUEST,
            })
        )
    required_block_ids = serialized_data.data.get('required_activity_ids', [])
    block_types = serialized_data.data.get('block_types', [])
    # NOTE(review): an empty list is an odd default for a score — confirm
    # GenerateCompletionReport handles a non-numeric passing_score.
    passing_score = serialized_data.data.get('passing_score', [])
    for course_id in courses:
        try:
            course_key = CourseKey.from_string(course_id)
        except InvalidKeyError:
            continue
        # (fixed) removed a redundant second CourseKey.from_string() call
        # that re-parsed the id outside the try/except guard above.
        completion_report = GenerateCompletionReport(
            get_enrolled_users(course_key),
            course_key,
            required_block_ids,
            block_types,
            passing_score,
        )
        data[course_id] = completion_report.generate_report_data()
    return data
def __init__(self, message, on_ack=noop, hostname=None, eventer=None,
             app=None, connection_errors=None, request_dict=None,
             task=None, on_reject=noop, body=None, headers=None,
             decoded=False, utc=True,
             maybe_make_aware=maybe_make_aware,
             maybe_iso8601=maybe_iso8601, **opts):
    """Build a worker request from an incoming protocol 2 task message.

    Headers/body default to the message's own; ``decoded`` indicates the
    payload is already deserialized.  Raises InvalidTaskError for
    unparseable ``eta``/``expires`` values.
    """
    self._message = message
    self._request_dict = message.headers if headers is None else headers
    self._body = message.body if body is None else body
    self._app = app
    self._utc = utc
    self._decoded = decoded
    if decoded:
        # Already-decoded payloads carry no content type/encoding.
        self._content_type = self._content_encoding = None
    else:
        self._content_type, self._content_encoding = (
            message.content_type, message.content_encoding,
        )
    self.__payload = self._body if self._decoded else message.payload
    self.id = self._request_dict["id"]
    self._type = self.name = self._request_dict["task"]
    # A "shadow" header overrides the displayed task name.
    if "shadow" in self._request_dict:
        self.name = self._request_dict["shadow"] or self.name
    self._root_id = self._request_dict.get("root_id")
    self._parent_id = self._request_dict.get("parent_id")
    timelimit = self._request_dict.get("timelimit", None)
    if timelimit:
        self.time_limits = timelimit
    self._argsrepr = self._request_dict.get("argsrepr", "")
    self._kwargsrepr = self._request_dict.get("kwargsrepr", "")
    self._on_ack = on_ack
    self._on_reject = on_reject
    self._hostname = hostname or gethostname()
    self._eventer = eventer
    self._connection_errors = connection_errors or ()
    self._task = task or self._app.tasks[self._type]
    # timezone means the message is timezone-aware, and the only timezone
    # supported at this point is UTC.
    eta = self._request_dict.get("eta")
    if eta is not None:
        try:
            eta = maybe_iso8601(eta)
        except (AttributeError, ValueError, TypeError) as exc:
            raise InvalidTaskError("invalid ETA value {0!r}: {1}".format(
                eta, exc))
        self._eta = maybe_make_aware(eta, self.tzlocal)
    else:
        self._eta = None
    expires = self._request_dict.get("expires")
    if expires is not None:
        try:
            expires = maybe_iso8601(expires)
        except (AttributeError, ValueError, TypeError) as exc:
            raise InvalidTaskError(
                "invalid expires value {0!r}: {1}".format(expires, exc))
        self._expires = maybe_make_aware(expires, self.tzlocal)
    else:
        self._expires = None
    delivery_info = message.delivery_info or {}
    properties = message.properties or {}
    self._delivery_info = {
        "exchange": delivery_info.get("exchange"),
        "routing_key": delivery_info.get("routing_key"),
        "priority": properties.get("priority"),
        "redelivered": delivery_info.get("redelivered"),
    }
    self._request_dict.update({
        "reply_to": properties.get("reply_to"),
        "correlation_id": properties.get("correlation_id"),
        "hostname": self._hostname,
        "delivery_info": self._delivery_info,
    })
    # this is a reference pass to avoid memory usage burst
    self._request_dict["args"], self._request_dict[
        "kwargs"], _ = self.__payload
    self._args = self._request_dict["args"]
    self._kwargs = self._request_dict["kwargs"]
def __init__(self, body, on_ack=noop, hostname=None, eventer=None,
             app=None, connection_errors=None, request_dict=None,
             message=None, task=None, on_reject=noop, **opts):
    """Build a worker request from a decoded protocol 1 task body.

    Raises InvalidTaskError when kwargs is not a mapping or when
    ``eta``/``expires`` fail to parse as ISO-8601.
    """
    self.app = app
    name = self.name = body['task']
    self.id = body['id']
    self.args = body.get('args', [])
    self.kwargs = body.get('kwargs', {})
    try:
        # Probe that kwargs behaves like a mapping.
        self.kwargs.items
    except AttributeError:
        raise InvalidTaskError('Task keyword arguments is not a mapping')
    if NEEDS_KWDICT:
        self.kwargs = kwdict(self.kwargs)
    eta = body.get('eta')
    expires = body.get('expires')
    utc = self.utc = body.get('utc', False)
    self.on_ack = on_ack
    self.on_reject = on_reject
    self.hostname = hostname or socket.gethostname()
    self.eventer = eventer
    self.connection_errors = connection_errors or ()
    self.task = task or self.app.tasks[name]
    self.acknowledged = self._already_revoked = False
    self.time_start = self.worker_pid = self._terminate_on_ack = None
    self._apply_result = None
    self._tzlocal = None
    # timezone means the message is timezone-aware, and the only timezone
    # supported at this point is UTC.
    if eta is not None:
        try:
            self.eta = maybe_iso8601(eta)
        except (AttributeError, ValueError, TypeError) as exc:
            raise InvalidTaskError('invalid eta value {0!r}: {1}'.format(
                eta, exc))
        if utc:
            self.eta = maybe_make_aware(self.eta, self.tzlocal)
    else:
        self.eta = None
    if expires is not None:
        try:
            self.expires = maybe_iso8601(expires)
        except (AttributeError, ValueError, TypeError) as exc:
            raise InvalidTaskError(
                'invalid expires value {0!r}: {1}'.format(expires, exc))
        if utc:
            self.expires = maybe_make_aware(self.expires, self.tzlocal)
    else:
        self.expires = None
    if message:
        # Fold transport-level delivery details into the request body.
        delivery_info = message.delivery_info or {}
        properties = message.properties or {}
        body.update({
            'headers': message.headers,
            'reply_to': properties.get('reply_to'),
            'correlation_id': properties.get('correlation_id'),
            'delivery_info': {
                'exchange': delivery_info.get('exchange'),
                'routing_key': delivery_info.get('routing_key'),
                'priority': properties.get(
                    'priority', delivery_info.get('priority')),
                'redelivered': delivery_info.get('redelivered'),
            }
        })
    else:
        body.update(DEFAULT_FIELDS)
    self.request_dict = body
def estimate_job_tarball_size(self, task_data=None, **kwargs):
    """Celery task: measure the gzipped-tarball size of a job directory.

    Streams ``tar -czf -`` through ``wc --bytes`` on the compute node via
    SSH and stores the byte count in ``job.params['tarball_size']`` and in
    the returned ``task_data['result']``.
    """
    if task_data is None:
        raise InvalidTaskError("task_data is None")
    from ..models import Job
    _init_fabric_env()
    environment = task_data.get('environment', {})
    job_id = task_data.get('job_id')
    job = Job.objects.get(id=job_id)
    master_ip = job.compute_resource.host
    gateway = job.compute_resource.gateway_server
    queue_type = job.compute_resource.queue_type
    private_key = job.compute_resource.private_key
    remote_username = job.compute_resource.extra.get('username', None)
    job_path = job.abs_path_on_compute
    message = "No message."
    task_result = dict()
    try:
        with fabsettings(gateway=gateway,
                         host_string=master_ip,
                         user=remote_username,
                         key=private_key):
            with cd(job_path):
                with shell_env(**environment):
                    # if queue_type == 'slurm':
                    #     result = run(f"scancel {job.remote_id}")
                    # else:
                    #     result = run(f"kill {job.remote_id}")
                    # NOTE: If running tar -czf is too slow / too much extra I/O load,
                    # we could use the placeholder heuristic of
                    # f`du -bc --max-depth=0 "{job_path}"` * 0.66 for RNAsik runs,
                    # stored in job metadata. Or add proper sizes to every File.metdata
                    # and derive it from a query.
                    result = run(
                        f'tar -czf - --directory "{job_path}" . | wc --bytes')
                    if result.succeeded:
                        tarball_size = int(result.stdout.strip())
                        with transaction.atomic():
                            # Re-fetch to avoid clobbering concurrent updates.
                            job = Job.objects.get(id=job_id)
                            job.params['tarball_size'] = tarball_size
                            job.save()
                        task_result['tarball_size'] = tarball_size
                    else:
                        task_result['stdout'] = result.stdout.strip()
                        task_result['stderr'] = result.stderr.strip()
    except BaseException as e:
        if hasattr(e, 'message'):
            message = e.message
        # Mark the celery task failed before re-raising.
        self.update_state(state=states.FAILURE, meta=message)
        raise e
    task_data.update(result=task_result)
    return task_data
def start_job(self, task_data=None, **kwargs):
    """Celery task: stage scripts/config to the compute node and launch a job.

    Renders run/kill shell scripts, uploads them plus auth headers and the
    pipeline config over SSH, launches ``run_job.sh`` with nohup, records
    the remote PID, and updates the Job status to RUNNING or FAILED.
    """
    from ..models import Job
    if task_data is None:
        raise InvalidTaskError("task_data is None")
    job_id = task_data.get('job_id')
    job = Job.objects.get(id=job_id)
    result = task_data.get('result')
    master_ip = job.compute_resource.host
    gateway = job.compute_resource.gateway_server
    webhook_notify_url = ''
    # secret = None
    environment = task_data.get('environment', {})
    job_auth_header = task_data.get('job_auth_header', '')
    # environment.update(JOB_ID=job_id)
    _init_fabric_env()
    private_key = job.compute_resource.private_key
    remote_username = job.compute_resource.extra.get('username', None)
    job_script_template_vars = dict(environment)
    job_script_template_vars['JOB_AUTH_HEADER'] = job_auth_header
    # Rendered scripts are uploaded as in-memory byte streams.
    job_script = BytesIO(
        render_to_string('job_scripts/run_job.sh',
                         context=job_script_template_vars).encode('utf-8'))
    kill_script = BytesIO(
        render_to_string('job_scripts/kill_job.sh',
                         context=job_script_template_vars).encode('utf-8'))
    curl_headers = BytesIO(b"%s\n" % job_auth_header.encode('utf-8'))
    config_json = BytesIO(json.dumps(job.params).encode('utf-8'))
    remote_id = None
    message = "Failure, without exception."
    try:
        with fabsettings(
                gateway=gateway,
                host_string=master_ip,
                user=remote_username,
                key=private_key,
                # key_filename=expanduser("~/.ssh/id_rsa"),
        ):
            working_dir = job.abs_path_on_compute
            input_dir = join(working_dir, 'input')
            output_dir = join(working_dir, 'output')
            job_script_path = join(input_dir, 'run_job.sh')
            kill_script_path = join(working_dir, 'kill_job.sh')
            for d in [working_dir, input_dir, output_dir]:
                result = run(f'mkdir -p {d} && chmod 700 {d}')
            result = put(job_script, job_script_path, mode=0o700)
            result = put(kill_script, kill_script_path, mode=0o700)
            result = put(curl_headers,
                         join(working_dir, '.private_request_headers'),
                         mode=0o600)
            result = put(config_json,
                         join(input_dir, 'pipeline_config.json'),
                         mode=0o600)
            with cd(working_dir):
                with shell_env(**environment):
                    # NOTE: We can't sbatch the run_job.sh script due to
                    # the local aria2c RPC daemon launched by laxydl
                    # In the future, we may have a DataTransferHost where
                    # the data staging steps run, then we could launch
                    # run_job.sh via sbatch.
                    # if job.compute_resource.queue_type == 'slurm':
                    #     result = run(f"sbatch --parsable "
                    #                  f'--job-name="laxy:{job_id}" '
                    #                  f"--output output/run_job.out "
                    #                  f"{job_script_path} "
                    #                  f" >>slurm.jids")
                    #     remote_id = run(str("head -1 slurm.jids"))
                    # The job script is always run locally on the compute
                    # node (not sbatched), but will itself send jobs
                    # to the queue.
                    result = run(f"nohup bash -l -c '"
                                 f"{job_script_path} & "
                                 f"echo $! >>job.pids"
                                 f"' >output/run_job.out")
                    remote_id = run(str("head -1 job.pids"))
            succeeded = result.succeeded
    except BaseException as e:
        succeeded = False
        if hasattr(e, 'message'):
            message = e.message
        if hasattr(e, '__traceback__'):
            tb = e.__traceback__
            message = '%s - Traceback: %s' % (message, ''.join(
                traceback.format_list(traceback.extract_tb(tb))))
        else:
            message = repr(e)
    # A failed launch on a disposable resource tears the resource down.
    if not succeeded and job.compute_resource.disposable:
        job.compute_resource.dispose()
    job_status = Job.STATUS_RUNNING if succeeded else Job.STATUS_FAILED
    # Re-fetch before saving to avoid overwriting concurrent changes.
    job = Job.objects.get(id=job_id)
    job.status = job_status
    job.remote_id = remote_id
    job.save()
    # if webhook_notify_url:
    #     job_status = Job.STATUS_STARTING if succeeded else Job.STATUS_FAILED
    #     resp = request_with_retries(
    #         'PATCH', callback_url,
    #         json={'status': job_status},
    #         headers={'Authorization': secret},
    #     )
    if not succeeded:
        self.update_state(state=states.FAILURE, meta=message)
        raise Exception(message)
        # raise Ignore()
    task_data.update(result=result)
    return task_data
self.eta = maybe_iso8601(eta) except (AttributeError, ValueError), exc: raise InvalidTaskError('invalid eta value %r: %s' % ( eta, exc, )) if utc: self.eta = maybe_make_aware(self.eta, self.tzlocal) else: self.eta = None if expires is not None: try: self.expires = maybe_iso8601(expires) except (AttributeError, ValueError), exc: raise InvalidTaskError('invalid expires value %r: %s' % ( expires, exc, )) if utc: self.expires = maybe_make_aware(self.expires, self.tzlocal) else: self.expires = None delivery_info = {} if delivery_info is None else delivery_info self.delivery_info = { 'exchange': delivery_info.get('exchange'), 'routing_key': delivery_info.get('routing_key'), 'priority': delivery_info.get('priority'), } # amqplib transport adds the channel here for some reason, so need # to remove it.
def index_remote_files(self, task_data=None, **kwargs):
    """Celery task: (re)index input/output files of a job on its compute node.

    Recursively lists the remote input/ and output/ directories over SSH
    and creates/updates File objects in the job's filesets.  With
    ``clobber`` set, existing fileset contents are removed first.
    """
    if task_data is None:
        raise InvalidTaskError("task_data is None")
    job_id = task_data.get('job_id')
    job = Job.objects.get(id=job_id)
    clobber = task_data.get('clobber', False)
    compute_resource = job.compute_resource
    if compute_resource is not None:
        master_ip = compute_resource.host
        gateway = compute_resource.gateway_server
    else:
        logger.info(f"Not indexing files for {job_id}, no compute_resource.")
        return task_data
    job.log_event('JOB_INFO', 'Indexing all files (backend task)')
    environment = task_data.get('environment', {})
    # environment.update(JOB_ID=job_id)
    _init_fabric_env()
    private_key = job.compute_resource.private_key
    remote_username = job.compute_resource.extra.get('username', None)
    compute_id = job.compute_resource.id
    message = "No message."

    def create_update_file_objects(remote_path, fileset=None,
                                   prefix_path='', location_base=''):
        """
        Returns a list of (unsaved) File objects from a recursive 'find'
        of a remote directory. If a file of the same path exists in the
        FileSet, update the file object location (if unset) rather than
        create a new one.

        :param fileset:
        :type fileset:
        :param prefix_path:
        :type prefix_path:
        :param remote_path: Path on the remote server.
        :type remote_path: str
        :param location_base: Prefix of location URL (eg sftp://127.0.0.1/XxX/)
        :type location_base: str
        :return: A list of File objects
        :rtype: List[File]
        """
        with cd(remote_path):
            filepaths = remote_list_files('.')
            urls = [(f'{location_base}/{fpath}', fpath)
                    for fpath in filepaths]
            file_objs = []
            for location, filepath in urls:
                fname = Path(filepath).name
                fpath = Path(prefix_path) / Path(filepath).parent
                # NOTE(review): if `fileset` is falsy, `f` is never bound
                # before the `if not f` test below — confirm callers always
                # pass a fileset (both call sites here do).
                if fileset:
                    f = fileset.get_file_by_path(Path(fpath) / Path(fname))
                if not f:
                    f = File(location=location, owner=job.owner,
                             name=fname, path=fpath)
                elif not f.location:
                    f.location = location
                f.owner = job.owner
                file_objs.append(f)
            return file_objs

    try:
        with fabsettings(
                gateway=gateway,
                host_string=master_ip,
                user=remote_username,
                key=private_key,
                # key_filename=expanduser("~/.ssh/id_rsa"),
        ):
            working_dir = job.abs_path_on_compute
            input_dir = os.path.join(working_dir, 'input')
            output_dir = os.path.join(working_dir, 'output')
            output_files = create_update_file_objects(
                output_dir,
                fileset=job.output_files,
                prefix_path='output',
                location_base=laxy_sftp_url(job, 'output'),
            )
            job.output_files.path = 'output'
            job.output_files.owner = job.owner
            if clobber:
                job.output_files.remove(job.output_files, delete=True)
            job.output_files.add(output_files)
            # TODO: This should really be done at job start, or once input data
            # has been staged on the compute node.
            input_files = create_update_file_objects(
                input_dir,
                fileset=job.input_files,
                prefix_path='input',
                location_base=laxy_sftp_url(job, 'input'))
            job.input_files.path = 'input'
            job.input_files.owner = job.owner
            if clobber:
                job.input_files.remove(job.input_files, delete=True)
            job.input_files.add(input_files)
        succeeded = True
    except BaseException as e:
        succeeded = False
        if hasattr(e, 'message'):
            message = e.message
        self.update_state(state=states.FAILURE, meta=message)
        raise e
    # job_status = Job.STATUS_RUNNING if succeeded else Job.STATUS_FAILED
    # job = Job.objects.get(id=job_id)
    # job.status = job_status
    # job.save()
    # if not succeeded:
    #     self.update_state(state=states.FAILURE, meta=message)
    #     raise Exception(message)
    #     # raise Ignore()
    return task_data
def trace_task(uuid, args, kwargs, request=None):
    """Execute the task body, routing result/error through the full
    pre-run / trace / post-run pipeline.

    Fix: the callback-group loop below called ``group.apply_async`` (the
    imported :class:`group` class) instead of the loop variable
    ``group_``, so multi-callback group signatures were never applied.

    # R - is the possibly prepared return value.
    # I - is the Info object.
    # T - runtime
    # Rstr - textual representation of return value
    # retval - is the always unmodified return value.
    # state - is the resulting task state.

    This function is very long because we have unrolled all the calls
    for performance reasons, and because the function is so long
    we want the main variables (I, and R) to stand out visually from the
    the rest of the variables, so breaking PEP8 is worth it ;)
    """
    R = I = T = Rstr = retval = state = None
    task_request = None
    time_start = monotonic()
    try:
        try:
            kwargs.items  # probe: kwargs must be a mapping
        except AttributeError:
            raise InvalidTaskError(
                'Task keyword arguments is not a mapping')
        push_task(task)
        task_request = Context(request or {}, args=args,
                               called_directly=False, kwargs=kwargs)
        push_request(task_request)
        try:
            # -*- PRE -*-
            if prerun_receivers:
                send_prerun(sender=task, task_id=uuid, task=task,
                            args=args, kwargs=kwargs)
            loader_task_init(uuid, task)
            if track_started:
                store_result(
                    uuid, {'pid': pid, 'hostname': hostname}, STARTED,
                    request=task_request,
                )
            # -*- TRACE -*-
            try:
                R = retval = fun(*args, **kwargs)
                state = SUCCESS
            except Reject as exc:
                I, R = Info(REJECTED, exc), ExceptionInfo(internal=True)
                state, retval = I.state, I.retval
                I.handle_reject(task, task_request)
            except Ignore as exc:
                I, R = Info(IGNORED, exc), ExceptionInfo(internal=True)
                state, retval = I.state, I.retval
                I.handle_ignore(task, task_request)
            except Retry as exc:
                I, R, state, retval = on_error(
                    task_request, exc, uuid, RETRY, call_errbacks=False,
                )
            except Exception as exc:
                I, R, state, retval = on_error(task_request, exc, uuid)
            except BaseException as exc:
                raise
            else:
                try:
                    # callback tasks must be applied before the result is
                    # stored, so that result.children is populated.

                    # groups are called inline and will store trail
                    # separately, so need to call them separately
                    # so that the trail's not added multiple times :(
                    # (Issue #1936)
                    callbacks = task.request.callbacks
                    if callbacks:
                        if len(task.request.callbacks) > 1:
                            sigs, groups = [], []
                            for sig in callbacks:
                                sig = signature(sig, app=app)
                                if isinstance(sig, group):
                                    groups.append(sig)
                                else:
                                    sigs.append(sig)
                            for group_ in groups:
                                # (fixed) was group.apply_async(...), which
                                # called the group *class*, not this group.
                                group_.apply_async((retval, ))
                            if sigs:
                                group(sigs).apply_async((retval, ))
                        else:
                            signature(callbacks[0], app=app).delay(retval)
                    if publish_result:
                        store_result(
                            uuid, retval, SUCCESS, request=task_request,
                        )
                except EncodeError as exc:
                    I, R, state, retval = on_error(task_request, exc, uuid)
                else:
                    if task_on_success:
                        task_on_success(retval, uuid, args, kwargs)
                    if success_receivers:
                        send_success(sender=task, result=retval)
                    if _does_info:
                        T = monotonic() - time_start
                        Rstr = truncate(safe_repr(R), 256)
                        info(LOG_SUCCESS, {
                            'id': uuid,
                            'name': name,
                            'return_value': Rstr,
                            'runtime': T,
                        })
            # -* POST *-
            if state not in IGNORE_STATES:
                if task_request.chord:
                    on_chord_part_return(task, state, R)
                if task_after_return:
                    task_after_return(
                        state, retval, uuid, args, kwargs, None,
                    )
        finally:
            try:
                if postrun_receivers:
                    send_postrun(sender=task, task_id=uuid, task=task,
                                 args=args, kwargs=kwargs,
                                 retval=retval, state=state)
            finally:
                pop_task()
                pop_request()
                if not eager:
                    try:
                        backend_cleanup()
                        loader_cleanup()
                    except (KeyboardInterrupt, SystemExit, MemoryError):
                        raise
                    except Exception as exc:
                        logger.error('Process cleanup failed: %r', exc,
                                     exc_info=True)
    except MemoryError:
        raise
    except Exception as exc:
        if eager:
            raise
        R = report_internal_error(task, exc)
        if task_request is not None:
            I, _, _, _ = on_error(task_request, exc, uuid)
    return trace_ok_t(R, I, T, Rstr)
def test_on_task_InvalidTaskError(self):
    """on_task reports an invalid task when the strategy raises."""
    consumer, handler, message, strategy = self.task_context(
        self.add.s(2, 2))
    error = InvalidTaskError()
    strategy.side_effect = error
    handler(message)
    consumer.on_invalid_task.assert_called_with(None, message, error)
def __init__(self, message, on_ack=noop,
             hostname=None, eventer=None, app=None,
             connection_errors=None, request_dict=None,
             task=None, on_reject=noop, body=None,
             headers=None, decoded=False, utc=True,
             maybe_make_aware=maybe_make_aware,
             maybe_iso8601=maybe_iso8601, **opts):
    """Build a worker Request from an incoming task *message*.

    Headers and body default to those of the message; ``decoded``
    indicates the payload was already deserialized so content type and
    encoding are irrelevant.  Raises :exc:`InvalidTaskError` for
    malformed ``eta``/``expires`` values.
    """
    if headers is None:
        headers = message.headers
    if body is None:
        body = message.body
    self.app = app
    self.message = message
    self.body = body
    self.utc = utc
    self._decoded = decoded
    if decoded:
        # payload is already deserialized; no transport encoding applies.
        self.content_type = self.content_encoding = None
    else:
        self.content_type, self.content_encoding = (
            message.content_type, message.content_encoding,
        )
    self.id = headers['id']
    # NOTE: intentionally shadows the ``type`` builtin for brevity here.
    type = self.type = self.name = headers['task']
    self.root_id = headers.get('root_id')
    self.parent_id = headers.get('parent_id')
    if 'shadow' in headers:
        # a shadow name overrides the display name, not the task type.
        self.name = headers['shadow'] or self.name
    timelimit = headers.get('timelimit', None)
    if timelimit:
        self.time_limits = timelimit
    self.argsrepr = headers.get('argsrepr', '')
    self.kwargsrepr = headers.get('kwargsrepr', '')
    self.on_ack = on_ack
    self.on_reject = on_reject
    self.hostname = hostname or gethostname()
    self.eventer = eventer
    self.connection_errors = connection_errors or ()
    self.task = task or self.app.tasks[type]

    # timezone means the message is timezone-aware, and the only timezone
    # supported at this point is UTC.
    eta = headers.get('eta')
    if eta is not None:
        try:
            eta = maybe_iso8601(eta)
        except (AttributeError, ValueError, TypeError) as exc:
            raise InvalidTaskError(
                'invalid ETA value {0!r}: {1}'.format(eta, exc))
        self.eta = maybe_make_aware(eta, self.tzlocal)
    else:
        self.eta = None

    expires = headers.get('expires')
    if expires is not None:
        try:
            expires = maybe_iso8601(expires)
        except (AttributeError, ValueError, TypeError) as exc:
            raise InvalidTaskError(
                'invalid expires value {0!r}: {1}'.format(expires, exc))
        self.expires = maybe_make_aware(expires, self.tzlocal)
    else:
        self.expires = None

    delivery_info = message.delivery_info or {}
    properties = message.properties or {}
    # NOTE(review): when ``headers`` defaulted from ``message.headers``
    # this update mutates the message's header dict in place — confirm
    # callers do not rely on the original headers being unmodified.
    headers.update({
        'reply_to': properties.get('reply_to'),
        'correlation_id': properties.get('correlation_id'),
        'delivery_info': {
            'exchange': delivery_info.get('exchange'),
            'routing_key': delivery_info.get('routing_key'),
            'priority': properties.get('priority'),
            'redelivered': delivery_info.get('redelivered'),
        }
    })
    self.request_dict = headers
def __init__(self, message, on_ack=noop,
             hostname=None, eventer=None, app=None,
             connection_errors=None, request_dict=None,
             task=None, on_reject=noop, body=None,
             headers=None, decoded=False, utc=True,
             maybe_make_aware=maybe_make_aware,
             maybe_iso8601=maybe_iso8601, **opts):
    """Build a worker Request from an incoming task *message*.

    Headers are copied (never mutating ``message.headers``) into
    ``self._request_dict``, which is then enriched with delivery and
    reply metadata.  Raises :exc:`InvalidTaskError` for malformed
    ``eta``/``expires`` values.
    """
    self._message = message
    # copy so the message's own header dict is never mutated.
    self._request_dict = (message.headers.copy() if headers is None
                          else headers.copy())
    self._body = message.body if body is None else body
    self._app = app
    self._utc = utc
    self._decoded = decoded
    if decoded:
        # payload is already deserialized; no transport encoding applies.
        self._content_type = self._content_encoding = None
    else:
        self._content_type, self._content_encoding = (
            message.content_type, message.content_encoding,
        )
    # decoded payload: destructured into (args, kwargs, embed) below.
    self.__payload = self._body if self._decoded else message.payload
    self.id = self._request_dict['id']
    self._type = self.name = self._request_dict['task']
    if 'shadow' in self._request_dict:
        # a shadow name overrides the display name, not the task type.
        self.name = self._request_dict['shadow'] or self.name
    self._root_id = self._request_dict.get('root_id')
    self._parent_id = self._request_dict.get('parent_id')
    timelimit = self._request_dict.get('timelimit', None)
    if timelimit:
        self.time_limits = timelimit
    self._argsrepr = self._request_dict.get('argsrepr', '')
    self._kwargsrepr = self._request_dict.get('kwargsrepr', '')
    self._on_ack = on_ack
    self._on_reject = on_reject
    self._hostname = hostname or gethostname()
    self._eventer = eventer
    self._connection_errors = connection_errors or ()
    self._task = task or self._app.tasks[self._type]
    self._ignore_result = self._request_dict.get('ignore_result', False)

    # timezone means the message is timezone-aware, and the only timezone
    # supported at this point is UTC.
    eta = self._request_dict.get('eta')
    if eta is not None:
        try:
            eta = maybe_iso8601(eta)
        except (AttributeError, ValueError, TypeError) as exc:
            raise InvalidTaskError(f'invalid ETA value {eta!r}: {exc}')
        self._eta = maybe_make_aware(eta, self.tzlocal)
    else:
        self._eta = None

    expires = self._request_dict.get('expires')
    if expires is not None:
        try:
            expires = maybe_iso8601(expires)
        except (AttributeError, ValueError, TypeError) as exc:
            raise InvalidTaskError(
                f'invalid expires value {expires!r}: {exc}')
        self._expires = maybe_make_aware(expires, self.tzlocal)
    else:
        self._expires = None

    delivery_info = message.delivery_info or {}
    properties = message.properties or {}
    self._delivery_info = {
        'exchange': delivery_info.get('exchange'),
        'routing_key': delivery_info.get('routing_key'),
        'priority': properties.get('priority'),
        'redelivered': delivery_info.get('redelivered', False),
    }
    self._request_dict.update({
        'properties': properties,
        'reply_to': properties.get('reply_to'),
        'correlation_id': properties.get('correlation_id'),
        'hostname': self._hostname,
        'delivery_info': self._delivery_info
    })
    # this is a reference pass to avoid memory usage burst
    self._request_dict['args'], self._request_dict[
        'kwargs'], _ = self.__payload
    self._args = self._request_dict['args']
    self._kwargs = self._request_dict['kwargs']
def trace_task(uuid, args, kwargs, request=None):
    """Execute *fun* for task *uuid*, tracing state, signals and callbacks.

    Closure over ``build_tracer`` locals (``task``, ``fun``, ``app``,
    signal senders, backend helpers, etc.).  Returns a ``trace_ok_t``
    tuple ``(R, I, T, Rstr)``: the (possibly wrapped) return value, the
    ``Info`` object on failure, the runtime in seconds, and the repr of
    the result.
    """
    # R - is the possibly prepared return value.
    # I - is the Info object.
    # T - runtime
    # Rstr - textual representation of return value
    # retval - is the always unmodified return value.
    # state - is the resulting task state.

    # This function is very long because we've unrolled all the calls
    # for performance reasons, and because the function is so long
    # we want the main variables (I, and R) to stand out visually from the
    # the rest of the variables, so breaking PEP8 is worth it ;)
    R = I = T = Rstr = retval = state = None
    task_request = None
    time_start = monotonic()
    try:
        try:
            # cheap mapping check: raises AttributeError for non-mappings.
            kwargs.items
        except AttributeError:
            raise InvalidTaskError(
                'Task keyword arguments is not a mapping')

        task_request = Context(request or {}, args=args,
                               called_directly=False, kwargs=kwargs)

        # skip redelivered messages whose result is already SUCCESS
        # (opt-in via the deduplicate_successful_tasks setting).
        redelivered = (task_request.delivery_info
                       and task_request.delivery_info.get('redelivered', False))
        if deduplicate_successful_tasks and redelivered:
            if task_request.id in successful_requests:
                return trace_ok_t(R, I, T, Rstr)
            r = AsyncResult(task_request.id, app=app)

            try:
                state = r.state
            except BackendGetMetaError:
                # backend unavailable: fall through and run the task.
                pass
            else:
                if state == SUCCESS:
                    info(LOG_IGNORED, {
                        'id': task_request.id,
                        'name': get_task_name(task_request, name),
                        'description': 'Task already completed successfully.'
                    })
                    return trace_ok_t(R, I, T, Rstr)

        push_task(task)
        root_id = task_request.root_id or uuid
        task_priority = task_request.delivery_info.get('priority') if \
            inherit_parent_priority else None
        push_request(task_request)
        try:
            # -*- PRE -*-
            if prerun_receivers:
                send_prerun(sender=task, task_id=uuid, task=task,
                            args=args, kwargs=kwargs)
            loader_task_init(uuid, task)
            if track_started:
                task.backend.store_result(
                    uuid, {'pid': pid, 'hostname': hostname}, STARTED,
                    request=task_request,
                )

            # -*- TRACE -*-
            try:
                R = retval = fun(*args, **kwargs)
                state = SUCCESS
            except Reject as exc:
                I, R = Info(REJECTED, exc), ExceptionInfo(internal=True)
                state, retval = I.state, I.retval
                I.handle_reject(task, task_request)
                traceback_clear(exc)
            except Ignore as exc:
                I, R = Info(IGNORED, exc), ExceptionInfo(internal=True)
                state, retval = I.state, I.retval
                I.handle_ignore(task, task_request)
                traceback_clear(exc)
            except Retry as exc:
                I, R, state, retval = on_error(
                    task_request, exc, uuid, RETRY, call_errbacks=False)
                traceback_clear(exc)
            except Exception as exc:
                I, R, state, retval = on_error(task_request, exc, uuid)
                traceback_clear(exc)
            except BaseException:
                # SystemExit/KeyboardInterrupt must propagate untouched.
                raise
            else:
                try:
                    # callback tasks must be applied before the result is
                    # stored, so that result.children is populated.

                    # groups are called inline and will store trail
                    # separately, so need to call them separately
                    # so that the trail's not added multiple times :(
                    # (Issue #1936)
                    callbacks = task.request.callbacks
                    if callbacks:
                        if len(task.request.callbacks) > 1:
                            # split into plain signatures vs. groups so
                            # each kind can be dispatched appropriately.
                            sigs, groups = [], []
                            for sig in callbacks:
                                sig = signature(sig, app=app)
                                if isinstance(sig, group):
                                    groups.append(sig)
                                else:
                                    sigs.append(sig)
                            for group_ in groups:
                                group_.apply_async(
                                    (retval,),
                                    parent_id=uuid, root_id=root_id,
                                    priority=task_priority
                                )
                            if sigs:
                                group(sigs, app=app).apply_async(
                                    (retval,),
                                    parent_id=uuid, root_id=root_id,
                                    priority=task_priority
                                )
                        else:
                            signature(callbacks[0], app=app).apply_async(
                                (retval,),
                                parent_id=uuid, root_id=root_id,
                                priority=task_priority
                            )

                    # execute first task in chain
                    chain = task_request.chain
                    if chain:
                        _chsig = signature(chain.pop(), app=app)
                        _chsig.apply_async(
                            (retval,), chain=chain,
                            parent_id=uuid, root_id=root_id,
                            priority=task_priority
                        )
                    task.backend.mark_as_done(
                        uuid, retval, task_request, publish_result,
                    )
                except EncodeError as exc:
                    # result could not be serialized by the backend.
                    I, R, state, retval = on_error(task_request, exc, uuid)
                else:
                    Rstr = saferepr(R, resultrepr_maxsize)
                    T = monotonic() - time_start
                    if task_on_success:
                        task_on_success(retval, uuid, args, kwargs)
                    if success_receivers:
                        send_success(sender=task, result=retval)
                    if _does_info:
                        info(LOG_SUCCESS, {
                            'id': uuid,
                            'name': get_task_name(task_request, name),
                            'return_value': Rstr,
                            'runtime': T,
                        })

            # -* POST *-
            if state not in IGNORE_STATES:
                if task_after_return:
                    task_after_return(
                        state, retval, uuid, args, kwargs, None,
                    )
        finally:
            try:
                if postrun_receivers:
                    send_postrun(sender=task, task_id=uuid, task=task,
                                 args=args, kwargs=kwargs,
                                 retval=retval, state=state)
            finally:
                # always unwind the task/request stacks, even on error.
                pop_task()
                pop_request()
                if not eager:
                    try:
                        task.backend.process_cleanup()
                        loader_cleanup()
                    except (KeyboardInterrupt, SystemExit, MemoryError):
                        raise
                    except Exception as exc:
                        logger.error('Process cleanup failed: %r', exc,
                                     exc_info=True)
    except MemoryError:
        raise
    except Exception as exc:
        # internal (worker-side) error, not a task failure.
        _signal_internal_error(task, uuid, args, kwargs, request, exc)
        if eager:
            raise
        R = report_internal_error(task, exc)
        if task_request is not None:
            I, _, _, _ = on_error(task_request, exc, uuid)
    return trace_ok_t(R, I, T, Rstr)