Example #1
0
    def execute(cls, request: JobTemplateRequest,
                user_info: UserInfo) -> Result[Job]:
        """Create a job and its tasks from a named job template.

        The template is looked up with ``cls.get(request.name)`` and, when
        that yields nothing, in ``TEMPLATES``. It is then rendered against
        ``request`` and validated before anything is saved.

        Returns the saved ``Job`` on success, or an ``Error`` when the
        template is unknown, rendering fails, a task config is invalid or has
        no pool, a notification cannot be created, or a task cannot be
        created.

        NOTE(review): the job is saved before its tasks are created, so a
        failing ``Task.create`` returns an ``Error`` while the saved job (and
        any tasks created so far) remain -- confirm this is intended.
        """
        # Prefer a stored template; fall back to the built-in TEMPLATES table.
        index = cls.get(request.name)
        if index is None:
            if request.name not in TEMPLATES:
                return Error(
                    code=ErrorCode.INVALID_REQUEST,
                    errors=["no such template: %s" % request.name],
                )
            base_template = TEMPLATES[request.name]
        else:
            base_template = index.template

        # render() reports failure by returning an Error instead of a template.
        template = render(request, base_template)
        if isinstance(template, Error):
            return template

        # Validate every task config before any state is persisted.
        try:
            for task_config in template.tasks:
                check_config(task_config)
                if task_config.pool is None:
                    return Error(code=ErrorCode.INVALID_REQUEST,
                                 errors=["pool not defined"])

        except TaskConfigError as err:
            return Error(code=ErrorCode.INVALID_REQUEST, errors=[str(err)])

        # Create a notification for each requested container whose type
        # matches a notification entry of the template.
        for notification_config in template.notifications:
            for task_container in request.containers:
                if task_container.type == notification_config.container_type:
                    notification = Notification.create(
                        task_container.name,
                        notification_config.notification.config,
                        True,
                    )
                    if isinstance(notification, Error):
                        return notification

        job = Job(config=template.job)
        job.save()

        # Tasks are created in template order so that later tasks can refer to
        # earlier ones by position.
        tasks: List[Task] = []
        for task_config in template.tasks:
            task_config.job_id = job.job_id
            if task_config.prereq_tasks:
                # pydantic verifies prereq_tasks in u128 form are index refs to
                # previously generated tasks
                task_config.prereq_tasks = [
                    tasks[x.int].task_id for x in task_config.prereq_tasks
                ]

            task = Task.create(config=task_config,
                               job_id=job.job_id,
                               user_info=user_info)
            if isinstance(task, Error):
                return task

            tasks.append(task)

        return job
Example #2
0
def get_os(region: Region, image: str) -> Union[Error, OS]:
    """Determine the operating system of a VM image.

    ``image`` is treated as a custom image when ``parse_resource_id`` yields
    a ``resource_group`` for it; otherwise it must be a marketplace reference
    of the form ``publisher:offer:sku:version``, where ``version`` may be
    ``latest``.

    Returns the ``OS`` member named by the image's OS type, or an ``Error``
    with ``ErrorCode.INVALID_IMAGE`` when the reference is malformed, no
    version can be found, or the lookup fails.
    """
    client = get_compute_client()
    parsed = parse_resource_id(image)
    if "resource_group" in parsed:
        try:
            name = client.images.get(
                parsed["resource_group"],
                parsed["name"]).storage_profile.os_disk.os_type.name
        except (ResourceNotFoundError, CloudError) as err:
            return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
    else:
        # A malformed reference used to escape as an uncaught ValueError from
        # tuple unpacking; report it as an invalid image instead.
        fields = image.split(":")
        if len(fields) != 4:
            return Error(
                code=ErrorCode.INVALID_IMAGE,
                errors=[
                    "invalid image reference, expected "
                    "publisher:offer:sku:version: %s" % image
                ],
            )
        publisher, offer, sku, version = fields
        try:
            if version == "latest":
                versions = client.virtual_machine_images.list(region,
                                                              publisher,
                                                              offer,
                                                              sku,
                                                              top=1)
                # An empty listing used to raise IndexError on ``[0]``.
                if not versions:
                    return Error(
                        code=ErrorCode.INVALID_IMAGE,
                        errors=["no versions found for image: %s" % image],
                    )
                version = versions[0].name
            name = client.virtual_machine_images.get(
                region, publisher, offer, sku,
                version).os_disk_image.operating_system.lower()
        except (ResourceNotFoundError, CloudError) as err:
            return Error(code=ErrorCode.INVALID_IMAGE, errors=[str(err)])
    return OS[name]
Example #3
0
def get(req: func.HttpRequest) -> func.HttpResponse:
    """Return the registration details for the agent named in the request.

    Responds with an error when the request cannot be parsed, when no node is
    registered for the machine id (status 404), or when the node's pool
    cannot be resolved.
    """
    context = "agent registration"

    parsed = parse_uri(AgentRegistrationGet, req)
    if isinstance(parsed, Error):
        return not_ok(parsed, context=context)

    node = Node.get_by_machine_id(parsed.machine_id)
    if node is None:
        missing_node = Error(
            code=ErrorCode.INVALID_REQUEST,
            errors=[
                "unable to find a registration associated with machine_id '%s'"
                % parsed.machine_id
            ],
        )
        return not_ok(missing_node, context=context, status_code=404)

    pool = Pool.get_by_name(node.pool_name)
    if isinstance(pool, Error):
        missing_pool = Error(
            code=ErrorCode.INVALID_REQUEST,
            errors=[
                "unable to find a pool associated with the provided machine_id"
            ],
        )
        return not_ok(missing_pool, context=context)

    return create_registration_response(node.machine_id, pool)
Example #4
0
def post(req: func.HttpRequest) -> func.HttpResponse:
    """Handle a node event posted by an agent.

    The envelope's event may be a ``NodeEvent`` or a bare ``NodeStateUpdate``
    / ``WorkerEvent``; bare events are wrapped in a ``NodeEvent`` first. The
    event is then dispatched to the matching handler.
    """

    def invalid_event():
        # Shared response for events that carry nothing we can dispatch.
        problem = Error(code=ErrorCode.INVALID_REQUEST,
                        errors=["invalid node event"])
        return not_ok(problem, context=ERROR_CONTEXT)

    envelope = parse_request(NodeEventEnvelope, req)
    if isinstance(envelope, Error):
        return not_ok(envelope, context=ERROR_CONTEXT)

    logging.info(
        "node event: machine_id: %s event: %s",
        envelope.machine_id,
        envelope.event,
    )

    payload = envelope.event
    if isinstance(payload, NodeEvent):
        event = payload
    elif isinstance(payload, NodeStateUpdate):
        event = NodeEvent(state_update=payload)
    elif isinstance(payload, WorkerEvent):
        event = NodeEvent(worker_event=payload)
    else:
        return invalid_event()

    if event.state_update:
        on_state_update(envelope.machine_id, event.state_update)
    elif event.worker_event:
        on_worker_event(envelope.machine_id, event.worker_event)
    else:
        return invalid_event()

    return ok(BoolResult(result=True))
def try_get_token_auth_header(request: func.HttpRequest) -> Union[Error, TokenData]:
    """Obtain the caller's token data from the Authorization header.

    The header must have the form ``Bearer <token>``. Returns a ``TokenData``
    built from the token's ``appid`` and ``oid`` claims, or an ``Error`` with
    ``ErrorCode.INVALID_REQUEST`` when the header is missing, blank, or
    malformed.
    """
    auth = request.headers.get("Authorization", None)
    # A whitespace-only header is truthy but splits to an empty list, which
    # used to raise IndexError on ``parts[0]``; treat it as a missing header.
    parts = auth.split() if auth else []
    if not parts:
        return Error(
            code=ErrorCode.INVALID_REQUEST, errors=["Authorization header is expected"]
        )

    if parts[0].lower() != "bearer":
        return Error(
            code=ErrorCode.INVALID_REQUEST,
            errors=["Authorization header must start with Bearer"],
        )

    if len(parts) == 1:
        return Error(code=ErrorCode.INVALID_REQUEST, errors=["Token not found"])

    if len(parts) > 2:
        return Error(
            code=ErrorCode.INVALID_REQUEST,
            errors=["Authorization header must be Bearer token"],
        )

    # This token has already been verified by the azure authentication layer
    # NOTE(review): ``verify=False`` is the PyJWT 1.x spelling; PyJWT 2.x
    # expects ``options={"verify_signature": False}`` -- confirm the pinned
    # version before changing it.
    token = jwt.decode(parts[1], verify=False)
    return TokenData(application_id=UUID(token["appid"]), object_id=UUID(token["oid"]))
Example #6
0
def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]:
    """Build a ``UserInfo`` from the request's access token.

    Returns an ``Error`` when no token is present, the token has no issuer,
    or the issuer is not one of the allowed tenants.
    """
    raw_token = get_auth_token(request)
    if raw_token is None:
        return Error(
            code=ErrorCode.INVALID_REQUEST,
            errors=["unable to find authorization token"],
        )

    # The JWT token has already been verified by the azure authentication layer,
    # but we need to verify the tenant is as we expect.
    claims = jwt.decode(raw_token, options={"verify_signature": False})

    if "iss" not in claims:
        return Error(code=ErrorCode.INVALID_REQUEST,
                     errors=["missing issuer from token"])

    allowed = get_allowed_tenants()
    if claims["iss"] not in allowed:
        logging.error("issuer not from allowed tenant: %s - %s", claims["iss"],
                      allowed)
        return Error(code=ErrorCode.INVALID_REQUEST,
                     errors=["unauthorized AAD issuer"])

    # Both identifiers are optional claims; absent ones are passed as None.
    application_id = None
    if "appid" in claims:
        application_id = UUID(claims["appid"])

    object_id = None
    if "oid" in claims:
        object_id = UUID(claims["oid"])

    return UserInfo(application_id=application_id,
                    object_id=object_id,
                    upn=claims.get("upn"))
Example #7
0
def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]:
    """Build a ``UserInfo`` from the bearer token in the Authorization header.

    Returns an ``Error`` when the header is missing, does not consist of
    exactly two parts, or does not use the ``Bearer`` scheme.
    """
    header = request.headers.get("Authorization", None)
    if not header:
        return Error(code=ErrorCode.INVALID_REQUEST,
                     errors=["Authorization header is expected"])

    pieces = header.split()
    if len(pieces) != 2:
        return Error(code=ErrorCode.INVALID_REQUEST,
                     errors=["Invalid authorization header"])

    scheme, encoded = pieces
    if scheme.lower() != "bearer":
        return Error(
            code=ErrorCode.INVALID_REQUEST,
            errors=["Authorization header must start with Bearer"],
        )

    # This token has already been verified by the azure authentication layer
    claims = jwt.decode(encoded, verify=False)

    # ``appid`` is required here; ``oid`` and ``upn`` are optional claims.
    application_id = UUID(claims["appid"])
    object_id = None
    if "oid" in claims:
        object_id = UUID(claims["oid"])

    return UserInfo(application_id=application_id,
                    object_id=object_id,
                    upn=claims.get("upn"))
Example #8
0
    def build_repro_script(self) -> Optional[Error]:
        """Generate the repro scripts for this VM and save them as blobs.

        Windows tasks get ``repro.ps1``; Linux tasks get ``repro.sh`` and
        ``repro-stdout.sh``. Each script is saved under ``<vm_id>/`` in the
        ``repro-scripts`` container.

        Returns ``None`` on success, or an ``Error`` when auth is missing, the
        task cannot be loaded, or the report is missing. Raises
        ``NotImplementedError`` for any other task OS.
        """
        if self.auth is None:
            return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing auth"])

        task = Task.get_by_task_id(self.task_id)
        if isinstance(task, Error):
            return task

        report = get_report(self.config.container, self.config.path)
        if report is None:
            return Error(code=ErrorCode.VM_CREATE_FAILED, errors=["missing report"])

        scripts = {}

        if task.os == OS.windows:
            authorized_keys = "$env:ProgramData/ssh/administrators_authorized_keys"
            debugger_loop = (
                'while (1) { cdb -server tcp:port=1337 -c "g" setup\\%s %s }'
                % (
                    task.config.task.target_exe,
                    report.input_blob.name,
                )
            )
            lines = [
                'Set-Content -Path %s -Value "%s"'
                % (authorized_keys, self.auth.public_key),
                ". C:\\onefuzz\\tools\\win64\\onefuzz.ps1",
                "Set-SetSSHACL",
                debugger_loop,
            ]
            scripts["repro.ps1"] = "\r\n".join(lines)
        elif task.os == OS.linux:
            target_exe = task.config.task.target_exe
            input_name = report.input_blob.name
            gdbserver = (
                "ASAN_OPTIONS='abort_on_error=1' gdbserver "
                "%s /onefuzz/setup/%s /onefuzz/downloaded/%s"
            )

            # Variant that listens on a TCP port, restarted in a loop.
            listening = gdbserver % ("localhost:1337", target_exe, input_name)
            scripts["repro.sh"] = "while :; do %s; done" % listening

            # Variant that talks over stdio ("-" in place of host:port).
            stdio = gdbserver % ("-", target_exe, input_name)
            scripts["repro-stdout.sh"] = "#!/bin/bash\n%s" % stdio
        else:
            raise NotImplementedError("invalid task os: %s" % task.os)

        for filename, content in scripts.items():
            save_blob(
                Container("repro-scripts"),
                "%s/%s" % (self.vm_id, filename),
                content,
                StorageType.config,
            )

        logging.info("saved repro script")
        return None
Example #9
0
def get(req: func.HttpRequest) -> func.HttpResponse:
    """Redirect the caller to a short-lived read URL for a corpus blob.

    Responds with an error when the request cannot be parsed or when the
    container or the file does not exist in corpus storage.
    """
    entry = parse_uri(FileEntry, req)
    if isinstance(entry, Error):
        return not_ok(entry, context="download")

    if not container_exists(entry.container, StorageType.corpus):
        bad_container = Error(code=ErrorCode.INVALID_REQUEST,
                              errors=["invalid container"])
        return not_ok(bad_container, context=entry.container)

    if not blob_exists(entry.container, entry.filename, StorageType.corpus):
        bad_filename = Error(code=ErrorCode.INVALID_REQUEST,
                             errors=["invalid filename"])
        return not_ok(bad_filename, context=entry.filename)

    # Read-only SAS URL requested with days=0, minutes=5.
    sas_url = get_file_sas_url(
        entry.container,
        entry.filename,
        StorageType.corpus,
        read=True,
        days=0,
        minutes=5,
    )
    return redirect(sas_url)
Example #10
0
def on_worker_event(machine_id: UUID, event: WorkerEvent) -> func.HttpResponse:
    """Apply a worker event reported by a node and record it.

    A ``running`` event marks the task running and the node busy (unless
    either is already shutting down / ready for reset) and saves the
    node-task link. A ``done`` event records a failure on the task when the
    exit status is unsuccessful, moves the task to stopping, marks the node
    done, and deletes the node-task link.

    Returns an OK response on success. Raises ``RequestException`` when the
    event is neither ``running`` nor ``done``.
    """
    if event.running:
        task_id = event.running.task_id
    elif event.done:
        task_id = event.done.task_id
    else:
        # Reject unknown events up front. Previously ``task_id`` was left
        # unbound here, so the lookup below crashed with UnboundLocalError
        # before the intended error could be raised.
        err = Error(
            code=ErrorCode.INVALID_REQUEST,
            errors=["invalid worker event type"],
        )
        raise RequestException(err)

    task = get_task_checked(task_id)
    node = get_node_checked(machine_id)
    node_task = NodeTasks(machine_id=machine_id,
                          task_id=task_id,
                          state=NodeTaskState.running)

    if event.running:
        if task.state not in TaskState.shutting_down():
            task.state = TaskState.running
        if node.state not in NodeState.ready_for_reset():
            node.state = NodeState.busy
        node_task.save()
        task.on_start()
    else:
        # event.done -- the only other possibility after the check above.
        #
        # Only record exit status if the task isn't already shutting down.
        #
        # It's ok for the agent to fail because resources vanish out from underneath
        # it during deletion.
        if task.state not in TaskState.shutting_down():
            exit_status = event.done.exit_status

            if not exit_status.success:
                logging.error("task failed: status = %s", exit_status)

                task.error = Error(
                    code=ErrorCode.TASK_FAILED,
                    errors=[
                        "task failed. exit_status = %s" % exit_status,
                        event.done.stdout,
                        event.done.stderr,
                    ],
                )

            task.state = TaskState.stopping
        if node.state not in NodeState.ready_for_reset():
            node.state = NodeState.done
        node_task.delete()

    task.save()
    node.save()
    task_event = TaskEvent(task_id=task_id,
                           machine_id=machine_id,
                           event_data=event)
    task_event.save()
    return ok(BoolResult(result=True))
Example #11
0
def post(req: func.HttpRequest) -> func.HttpResponse:
    """Create a scaleset in an existing managed pool.

    Responds with an error when the request cannot be parsed, the pool cannot
    be found or is unmanaged, the region is invalid, or the VM SKU is not
    available in the chosen region.
    """
    request = parse_request(ScalesetCreate, req)
    if isinstance(request, Error):
        return not_ok(request, context="ScalesetCreate")

    # Verify the pool exists
    pool = Pool.get_by_name(request.pool_name)
    if isinstance(pool, Error):
        return not_ok(pool, context=repr(request))

    if not pool.managed:
        unmanaged = Error(
            code=ErrorCode.UNABLE_TO_CREATE,
            errors=["scalesets can only be added to managed pools"],
        )
        return not_ok(unmanaged, context="scalesetcreate")

    # Fall back to the base region when the request does not name one.
    if request.region is None:
        region = get_base_region()
    elif request.region not in get_regions():
        bad_region = Error(code=ErrorCode.UNABLE_TO_CREATE,
                           errors=["invalid region"])
        return not_ok(bad_region, context="scalesetcreate")
    else:
        region = request.region

    if request.vm_sku not in list_available_skus(region):
        bad_sku = Error(
            code=ErrorCode.UNABLE_TO_CREATE,
            errors=[
                "The specified vm_sku '%s' is not available in the location '%s'"
                % (request.vm_sku, region)
            ],
        )
        return not_ok(bad_sku, context="scalesetcreate")

    created = Scaleset.create(
        pool_name=request.pool_name,
        vm_sku=request.vm_sku,
        image=request.image,
        region=region,
        size=request.size,
        spot_instances=request.spot_instances,
        tags=request.tags,
    )
    created.save()
    # don't return auths during create, only 'get' with include_auth
    created.auth = None
    return ok(created)
Example #12
0
def on_worker_event(machine_id: UUID, event: WorkerEvent) -> None:
    """Apply a worker event reported by a node and record it.

    A ``running`` event marks the task running and the node busy (unless
    either is already shutting down / ready for reset) and saves the
    node-task link. A ``done`` event deletes the node-task link, marks the
    task failed or stopping depending on the exit status, and sends the node
    for reimage.

    Raises ``NotImplementedError`` when the event is neither ``running`` nor
    ``done``.
    """
    if event.running:
        task_id = event.running.task_id
    elif event.done:
        task_id = event.done.task_id
    else:
        raise NotImplementedError

    task = get_task_checked(task_id)
    node = get_node_checked(machine_id)
    node_task = NodeTasks(
        machine_id=machine_id, task_id=task_id, state=NodeTaskState.running
    )

    if event.running:
        if task.state not in TaskState.shutting_down():
            task.state = TaskState.running
        if node.state not in NodeState.ready_for_reset():
            node.state = NodeState.busy
            node.save()
        node_task.save()

        # Start the clock for the task if it wasn't started already
        # (as happens in 1.0.0 agents)
        task.on_start()
    else:
        # event.done -- any other event was rejected above, so the former
        # trailing ``else`` that raised RequestException could never run and
        # has been removed.
        node_task.delete()

        exit_status = event.done.exit_status
        if not exit_status.success:
            logging.error("task failed. status:%s", exit_status)
            task.mark_failed(
                Error(
                    code=ErrorCode.TASK_FAILED,
                    errors=[
                        "task failed. exit_status:%s" % exit_status,
                        event.done.stdout,
                        event.done.stderr,
                    ],
                )
            )
        else:
            task.mark_stopping()

        node.to_reimage(done=True)

    task.save()

    task_event = TaskEvent(task_id=task_id, machine_id=machine_id, event_data=event)
    task_event.save()
Example #13
0
def associate_subnet(name: str, vnet: VirtualNetwork,
                     subnet: Subnet) -> Union[None, Error]:
    """Associate the network security group ``name`` with ``subnet``.

    Returns ``None`` when the update was requested, the association already
    exists, or the update failed only because of a concurrent request.
    Returns an ``Error`` when the NSG cannot be found, the NSG and the
    virtual network are in different regions, or the update fails.
    """
    resource_group = get_base_resource_group()
    nsg = get_nsg(name)
    if not nsg:
        return Error(
            code=ErrorCode.UNABLE_TO_FIND,
            errors=["cannot associate subnet. nsg %s not found" % name],
        )

    if nsg.location != vnet.location:
        # Report the location that was actually compared (the vnet's).
        # NOTE(review): the message previously read ``subnet.location``,
        # which the Azure ``Subnet`` model does not appear to define --
        # confirm.
        return Error(
            code=ErrorCode.UNABLE_TO_UPDATE,
            errors=[
                "subnet and nsg have to be in the same region.",
                "nsg %s %s, subnet: %s %s" %
                (nsg.name, nsg.location, subnet.name, vnet.location),
            ],
        )

    if subnet.network_security_group and subnet.network_security_group.id == nsg.id:
        logging.info("Subnet %s and NSG %s already associated, not updating",
                     subnet.name, name)
        return None

    logging.info("associating subnet %s with nsg: %s %s", subnet.name,
                 resource_group, name)

    subnet.network_security_group = nsg
    network_client = get_network_client()
    try:
        network_client.subnets.begin_create_or_update(resource_group,
                                                      vnet.name, subnet.name,
                                                      subnet)
    except (ResourceNotFoundError, CloudError) as err:
        if is_concurrent_request_error(str(err)):
            # The message must be ONE format string: the two halves used to
            # be passed as separate arguments, which made logging fail to
            # format the record.
            logging.debug(
                "associate NSG with subnet had conflicts "
                "with concurrent request, ignoring %s",
                err,
            )
            return None
        return Error(
            code=ErrorCode.UNABLE_TO_UPDATE,
            errors=[
                "Unable to associate nsg %s with subnet %s due to %s" % (
                    name,
                    subnet.name,
                    err,
                )
            ],
        )

    return None
Example #14
0
def post(req: func.HttpRequest) -> func.HttpResponse:
    """Create a new pool.

    Responds with an error when the request cannot be parsed, the admin check
    fails, a pool with the same name already exists, or the autoscale region
    or VM SKU is not valid.
    """
    request = parse_request(PoolCreate, req)
    if isinstance(request, Error):
        return not_ok(request, context="PoolCreate")

    admin_check = check_require_admins(req)
    if isinstance(admin_check, Error):
        return not_ok(admin_check, context="PoolCreate")

    existing = Pool.get_by_name(request.name)
    if isinstance(existing, Pool):
        duplicate = Error(
            code=ErrorCode.INVALID_REQUEST,
            errors=["pool with that name already exists"],
        )
        return not_ok(duplicate, context=repr(request))

    logging.info(request)

    autoscale = request.autoscale
    if autoscale:
        # Default the region on the request itself so the created pool
        # carries the resolved value.
        if autoscale.region is None:
            autoscale.region = get_base_region()
        elif autoscale.region not in get_regions():
            bad_region = Error(code=ErrorCode.UNABLE_TO_CREATE,
                               errors=["invalid region"])
            return not_ok(bad_region, context="poolcreate")

        region = autoscale.region

        if autoscale.vm_sku not in list_available_skus(region):
            bad_sku = Error(
                code=ErrorCode.UNABLE_TO_CREATE,
                errors=[
                    "vm_sku '%s' is not available in the location '%s'" %
                    (autoscale.vm_sku, region)
                ],
            )
            return not_ok(bad_sku, context="poolcreate")

    created = Pool.create(
        name=request.name,
        os=request.os,
        arch=request.arch,
        managed=request.managed,
        client_id=request.client_id,
        autoscale=request.autoscale,
    )
    return ok(set_config(created))
Example #15
0
def update_scale_in_protection(name: UUID, vm_id: UUID,
                               protect_from_scale_in: bool) -> Optional[Error]:
    """Set the scale-in protection flag on one scaleset VM instance.

    Returns ``None`` on success, and also when protection was being removed
    from an instance that no longer exists. Returns an ``Error`` when the
    instance id cannot be resolved, the instance cannot be fetched, or the
    update fails.
    """
    instance_id = get_instance_id(name, vm_id)
    if isinstance(instance_id, Error):
        return instance_id

    compute_client = get_compute_client()
    resource_group = get_base_resource_group()

    try:
        instance_vm = compute_client.virtual_machine_scale_set_vms.get(
            resource_group, name, instance_id)
    except (ResourceNotFoundError, CloudError):
        return Error(
            code=ErrorCode.UNABLE_TO_FIND,
            errors=["unable to find vm instance: %s:%s" % (name, instance_id)],
        )

    # Reuse the instance's existing policy when it has one so that only the
    # scale-in flag changes; otherwise start from a fresh policy.
    policy = instance_vm.protection_policy
    if policy is None:
        policy = VirtualMachineScaleSetVMProtectionPolicy(
            protect_from_scale_in=protect_from_scale_in)
    else:
        policy.protect_from_scale_in = protect_from_scale_in
    instance_vm.protection_policy = policy

    try:
        compute_client.virtual_machine_scale_set_vms.begin_update(
            resource_group, name, instance_id, instance_vm)
    except (ResourceNotFoundError, CloudError, HttpResponseError) as err:
        if isinstance(err, HttpResponseError):
            instance_not_found = (
                " is not an active Virtual Machine Scale Set VM instanceId.")
            if instance_not_found in str(err):
                current = instance_vm.protection_policy.protect_from_scale_in
                # Removing protection from a vanished instance is not a
                # failure worth reporting.
                if current is False and protect_from_scale_in == current:
                    logging.info(
                        "Tried to remove scale in protection on node %s but the instance no longer exists"  # noqa: E501
                        % instance_id)
                    return None
        return Error(
            code=ErrorCode.UNABLE_TO_UPDATE,
            errors=[
                "unable to set protection policy on: %s:%s" %
                (vm_id, instance_id)
            ],
        )

    logging.info("Successfully set scale in protection on node %s to %s" %
                 (vm_id, protect_from_scale_in))
    return None
Example #16
0
    def get_by_task_id(cls, task_id: UUID) -> Union[Error, "Task"]:
        """Return the single task with ``task_id``, or an ``Error``.

        An ``Error`` is returned when the search finds no task or finds more
        than one.
        """
        matches = cls.search(query={"task_id": [task_id]})
        if not matches:
            return Error(code=ErrorCode.INVALID_REQUEST,
                         errors=["unable to find task"])

        if len(matches) == 1:
            return matches[0]

        return Error(code=ErrorCode.INVALID_REQUEST,
                     errors=["error identifying task"])
Example #17
0
    def get_by_id(cls, scaleset_id: UUID) -> Union[Error, "Scaleset"]:
        """Return the single scaleset with ``scaleset_id``, or an ``Error``.

        An ``Error`` is returned when the search finds no scaleset or finds
        more than one.
        """
        matches = cls.search(query={"scaleset_id": [scaleset_id]})
        if not matches:
            return Error(code=ErrorCode.INVALID_REQUEST,
                         errors=["unable to find scaleset"])

        if len(matches) == 1:
            return matches[0]

        return Error(code=ErrorCode.INVALID_REQUEST,
                     errors=["error identifying scaleset"])
Example #18
0
    def get_by_name(cls, name: PoolName) -> Union[Error, "Pool"]:
        """Return the single pool called ``name``, or an ``Error``.

        An ``Error`` is returned when the search finds no pool or finds more
        than one.
        """
        matches = cls.search(query={"name": [name]})
        if not matches:
            return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find pool"])

        if len(matches) == 1:
            return matches[0]

        return Error(
            code=ErrorCode.INVALID_REQUEST, errors=["error identifying pool"]
        )
Example #19
0
    def get_by_id(cls, pool_id: UUID) -> Union[Error, "Pool"]:
        """Return the single pool with ``pool_id``, or an ``Error``.

        An ``Error`` is returned when the search finds no pool or finds more
        than one.
        """
        matches = cls.search(query={"pool_id": [pool_id]})
        if not matches:
            return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find pool"])

        if len(matches) == 1:
            return matches[0]

        return Error(
            code=ErrorCode.INVALID_REQUEST, errors=["error identifying pool"]
        )
Example #20
0
def dissociate_subnet(name: str, vnet: VirtualNetwork,
                      subnet: Subnet) -> Union[None, Error]:
    """Remove the network security group ``name`` from ``subnet``.

    Returns ``None`` when the subnet has no NSG, the update was requested, or
    the update failed only because of a concurrent request. Returns an
    ``Error`` when the NSG cannot be found, the subnet is associated with a
    different NSG, or the update fails.
    """
    if subnet.network_security_group is None:
        return None
    resource_group = get_base_resource_group()
    nsg = get_nsg(name)
    if not nsg:
        return Error(
            code=ErrorCode.UNABLE_TO_FIND,
            errors=["cannot update nsg rules. nsg %s not found" % name],
        )
    if nsg.id != subnet.network_security_group.id:
        return Error(
            code=ErrorCode.UNABLE_TO_UPDATE,
            errors=[
                "subnet is not associated with this nsg.",
                "nsg %s, subnet: %s, subnet.nsg: %s" % (
                    nsg.id,
                    subnet.name,
                    subnet.network_security_group.id,
                ),
            ],
        )

    logging.info("dissociating subnet %s with nsg: %s %s", subnet.name,
                 resource_group, name)

    subnet.network_security_group = None
    network_client = get_network_client()
    try:
        network_client.subnets.begin_create_or_update(resource_group,
                                                      vnet.name, subnet.name,
                                                      subnet)
    except (ResourceNotFoundError, CloudError) as err:
        if is_concurrent_request_error(str(err)):
            # The message must be ONE format string: the two halves used to
            # be passed as separate arguments, which made logging fail to
            # format the record.
            logging.debug(
                "dissociate nsg with subnet had conflicts with "
                "concurrent request, ignoring %s",
                err,
            )
            return None
        return Error(
            code=ErrorCode.UNABLE_TO_UPDATE,
            errors=[
                "Unable to dissociate nsg %s with subnet %s due to %s" % (
                    name,
                    subnet.name,
                    err,
                )
            ],
        )

    return None
Example #21
0
    def update_or_create(
        cls,
        region: Region,
        scaleset_id: UUID,
        machine_id: UUID,
        dst_port: int,
        duration: int,
    ) -> Union["ProxyForward", Error]:
        """Extend an existing port forward or allocate a new one.

        When a forward already exists for this region, scaleset, machine and
        destination port, its ``endtime`` is pushed ``duration`` hours out
        from now and it is returned. Otherwise each port in ``PORT_RANGES``
        is tried in turn until a new entry saves successfully.

        Returns the forward entry, or an ``Error`` with
        ``ErrorCode.UNABLE_TO_PORT_FORWARD`` when the node has no private IP
        or every candidate port is taken.
        """
        private_ip = get_scaleset_instance_ip(scaleset_id, machine_id)
        if not private_ip:
            return Error(
                code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["no private ip for node"]
            )

        entries = cls.search_forward(
            scaleset_id=scaleset_id,
            machine_id=machine_id,
            dst_port=dst_port,
            region=region,
        )
        if entries:
            entry = entries[0]
            entry.endtime = datetime.datetime.utcnow() + datetime.timedelta(
                hours=duration
            )
            entry.save()
            return entry

        # ``entries`` is empty from here on, so the former "skip ports already
        # in ``entries``" filter could never skip anything and was removed.
        # Port conflicts are detected by ``save(new=True)`` returning an Error.
        # NOTE(review): if the intent was to skip ports used by OTHER forwards
        # in the region, that needs a separate, broader search -- confirm.
        for port in PORT_RANGES:
            entry = cls(
                region=region,
                port=port,
                scaleset_id=scaleset_id,
                machine_id=machine_id,
                dst_ip=private_ip,
                dst_port=dst_port,
                endtime=datetime.datetime.utcnow() + datetime.timedelta(hours=duration),
            )
            result = entry.save(new=True)
            if isinstance(result, Error):
                logging.info("port is already used: %s", entry)
                continue

            return entry

        return Error(
            code=ErrorCode.UNABLE_TO_PORT_FORWARD, errors=["all forward ports used"]
        )
Example #22
0
def associate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
    """Associate the network security group ``name`` with ``nic``.

    Returns ``None`` when the update was requested, the association already
    exists, or the update failed only because of a concurrent request.
    Returns an ``Error`` when the NSG cannot be found, the NSG and the NIC
    are in different regions, or the update fails.
    """
    resource_group = get_base_resource_group()
    nsg = get_nsg(name)
    if not nsg:
        return Error(
            code=ErrorCode.UNABLE_TO_FIND,
            errors=["cannot associate nic. nsg %s not found" % name],
        )

    if nsg.location != nic.location:
        return Error(
            code=ErrorCode.UNABLE_TO_UPDATE,
            errors=[
                "network interface and nsg have to be in the same region.",
                "nsg %s %s, nic: %s %s" %
                (nsg.name, nsg.location, nic.name, nic.location),
            ],
        )

    if nic.network_security_group and nic.network_security_group.id == nsg.id:
        logging.info("NIC %s and NSG %s already associated, not updating",
                     nic.name, name)
        return None

    logging.info("associating nic %s with nsg: %s %s", nic.name,
                 resource_group, name)

    nic.network_security_group = nsg
    network_client = get_network_client()
    try:
        network_client.network_interfaces.begin_create_or_update(
            resource_group, nic.name, nic)
    except (ResourceNotFoundError, CloudError) as err:
        if is_concurrent_request_error(str(err)):
            # The message must be ONE format string: the two halves used to
            # be passed as separate arguments, which made logging fail to
            # format the record.
            logging.debug(
                "associate NSG with NIC had conflicts "
                "with concurrent request, ignoring %s",
                err,
            )
            return None
        return Error(
            code=ErrorCode.UNABLE_TO_UPDATE,
            errors=[
                "Unable to associate nsg %s with nic %s due to %s" % (
                    name,
                    nic.name,
                    err,
                )
            ],
        )

    return None
Example #23
0
def set_allowed(name: str,
                sources: NetworkSecurityGroupConfig) -> Union[None, Error]:
    """Replace the rules of the NSG called `name` with inbound allow rules.

    One allow-all-ports inbound rule is generated per entry in
    `sources.allowed_ips` followed by `sources.allowed_service_tags`, with
    consecutive priorities starting at 100.

    Args:
        name: Name passed to `get_nsg` to look up the NSG.
        sources: Configuration providing the allowed IPs and service tags.

    Returns:
        An Error when the NSG cannot be found or when more than 1000 sources
        are supplied; otherwise whatever `update_nsg` returns.
    """
    resource_group = get_base_resource_group()
    nsg = get_nsg(name)
    if not nsg:
        return Error(
            code=ErrorCode.UNABLE_TO_FIND,
            errors=["cannot update nsg rules. nsg %s not found" % name],
        )

    logging.info(
        "setting allowed incoming connection sources for nsg: %s %s",
        resource_group,
        name,
    )

    # NSG security rule priority range defined here:
    # https://docs.microsoft.com/en-us/azure/virtual-network/network-security-groups-overview
    first_priority = 100
    # NSG rules per NSG limits:
    # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/azure-subscription-service-limits?toc=/azure/virtual-network/toc.json#networking-limits
    rule_limit = 1000

    allowed = sources.allowed_ips + sources.allowed_service_tags
    allowed_count = len(allowed)
    if allowed_count > rule_limit:
        too_many = "too many rules provided %d. Max allowed: %d" % (
            allowed_count,
            rule_limit,
        )
        return Error(code=ErrorCode.INVALID_REQUEST, errors=[too_many])

    # Priorities run from `first_priority` upwards, one per source. The count
    # check above keeps them within the rule limit and below the maximum NSG
    # priority (4096).
    nsg.security_rules = [
        SecurityRule(
            name="Allow" + str(rule_priority),
            protocol="*",
            source_port_range="*",
            destination_port_range="*",
            source_address_prefix=allowed_source,
            destination_address_prefix="*",
            access=SecurityRuleAccess.ALLOW,
            priority=rule_priority,  # between 100 and 4096
            direction="Inbound",
        )
        for rule_priority, allowed_source in enumerate(allowed,
                                                       start=first_priority)
    ]
    return update_nsg(nsg)
Example #24
0
    def get_by_id(cls, webhook_id: UUID) -> Result["Webhook"]:
        """Look up the single webhook stored under `webhook_id`.

        Args:
            webhook_id: Identifier used as the `webhook_id` search key.

        Returns:
            The matching webhook when the search yields exactly one result.
            An INVALID_REQUEST Error when nothing matches, or when more than
            one entry matches and the webhook cannot be identified.
        """
        webhooks = cls.search(query={"webhook_id": [webhook_id]})
        if not webhooks:
            return Error(code=ErrorCode.INVALID_REQUEST,
                         errors=["unable to find webhook"])

        if len(webhooks) != 1:
            # This lookup is for webhooks, so the message names a webhook
            # rather than a Notification.
            return Error(
                code=ErrorCode.INVALID_REQUEST,
                errors=["error identifying webhook"],
            )
        return webhooks[0]
def check_require_admins_impl(
    config: InstanceConfig, user_info: UserInfo
) -> Optional[Error]:
    """Check whether the user may proceed under the admin requirement.

    Args:
        config: Instance configuration; `require_admin_privileges` and
            `admins` are read from it.
        user_info: Caller identity; `object_id` is matched against
            `config.admins`.

    Returns:
        None when admin privileges are not required, or when they are
        required and the user's `object_id` is listed in `config.admins`.
        An UNAUTHORIZED Error when admin privileges are required and either
        no admins are configured or the user is not one of them.
    """
    # The admin list only matters when admin privileges are required.
    if not config.require_admin_privileges:
        return None

    if config.admins is None:
        return Error(code=ErrorCode.UNAUTHORIZED, errors=["pool modification disabled"])

    if user_info.object_id in config.admins:
        return None

    return Error(code=ErrorCode.UNAUTHORIZED, errors=["not authorized to manage pools"])
Example #26
0
def dissociate_nic(name: str, nic: NetworkInterface) -> Union[None, Error]:
    """Detach the network security group called `name` from `nic`.

    Args:
        name: Name passed to `get_nsg` to look up the NSG.
        nic: Network interface to update; its `network_security_group`
            attribute is set to None before the update request is sent.

    Returns:
        None when the NIC has no NSG attached, when the update request was
        submitted, or when the request failed with what
        `is_concurrent_request_error` recognises as a concurrent-request
        conflict. An Error when the NSG cannot be found, when the NIC is
        attached to a different NSG, or when the update request fails for any
        other reason.
    """
    if nic.network_security_group is None:
        return None
    resource_group = get_base_resource_group()
    nsg = get_nsg(name)
    if not nsg:
        return Error(
            code=ErrorCode.UNABLE_TO_FIND,
            errors=["cannot update nsg rules. nsg %s not found" % name],
        )
    if nsg.id != nic.network_security_group.id:
        return Error(
            code=ErrorCode.UNABLE_TO_UPDATE,
            errors=[
                "network interface is not associated with this nsg.",
                "nsg %s, nic: %s, nic.nsg: %s" % (
                    nsg.id,
                    nic.name,
                    nic.network_security_group.id,
                ),
            ],
        )

    logging.info("dissociating nic %s with nsg: %s %s", nic.name,
                 resource_group, name)

    nic.network_security_group = None
    network_client = get_network_client()
    try:
        # The return value of begin_create_or_update is not inspected here.
        network_client.network_interfaces.begin_create_or_update(
            resource_group, nic.name, nic)
    except (ResourceNotFoundError, CloudError) as err:
        if is_concurrent_request_error(str(err)):
            # The message must be a single format string: logging treats
            # every positional argument after the first as a %-format
            # argument, so splitting the text across two literals made the
            # formatting fail and the record was never emitted.
            logging.debug(
                "dissociate nsg with nic had conflicts with "
                "concurrent request, ignoring %s",
                err,
            )
            return None
        return Error(
            code=ErrorCode.UNABLE_TO_UPDATE,
            errors=[
                "Unable to dissociate nsg %s with nic %s due to %s" % (
                    name,
                    nic.name,
                    err,
                )
            ],
        )

    return None
Example #27
0
def not_ok(error: Error,
           *,
           status_code: int = 400,
           context: Union[str, UUID]) -> HttpResponse:
    """Log a request error and wrap it in a JSON HTTP error response.

    Args:
        error: Error to report; its `json()` output is the response body.
        status_code: HTTP status to send; must be within 400..599.
        context: Label or identifier included in the log line.

    Returns:
        An HttpResponse carrying the serialized error with mimetype
        "application/json" and the given status code.

    Raises:
        Exception: If `status_code` is outside the range [400; 599].
    """
    if not 400 <= status_code <= 599:
        raise Exception(
            "status code %s is not in the expected range [400; 599]" %
            status_code)

    # Serialize once and reuse the text for both the log and the response.
    body = error.json()
    logging.error("request error - %s: %s", context, body)

    return HttpResponse(body,
                        status_code=status_code,
                        mimetype="application/json")
Example #28
0
    def get_by_id(cls, notification_id: UUID) -> Result["Notification"]:
        """Look up the single notification stored under `notification_id`.

        Returns the matching entry when the search yields exactly one result;
        otherwise an INVALID_REQUEST Error saying the notification was either
        not found or could not be identified.
        """
        matches = cls.search(query={"notification_id": [notification_id]})

        if not matches:
            problem = "unable to find Notification"
        elif len(matches) != 1:
            problem = "error identifying Notification"
        else:
            return matches[0]

        return Error(code=ErrorCode.INVALID_REQUEST, errors=[problem])
Example #29
0
def post(req: func.HttpRequest) -> func.HttpResponse:
    """Handle a task-creation request.

    The request body is parsed as a TaskConfig, the caller is identified from
    the JWT token, and the config is validated with `check_config`. When the
    request has a "dryrun" parameter the handler stops there and reports
    success. Otherwise the target job must exist and be in the `enabled` or
    `init` state, every prerequisite task must resolve, and the task is then
    created. Each failure is returned through `not_ok`; success through `ok`.
    """
    create_context = "task create"

    task_config = parse_request(TaskConfig, req)
    if isinstance(task_config, Error):
        return not_ok(task_config, context=create_context)

    caller = parse_jwt_token(req)
    if isinstance(caller, Error):
        return not_ok(caller, context=create_context)

    try:
        check_config(task_config)
    except TaskConfigError as config_err:
        invalid_config = Error(code=ErrorCode.INVALID_REQUEST,
                               errors=[str(config_err)])
        return not_ok(invalid_config, context=create_context)

    # A dry run only validates the request; nothing is looked up or created.
    if "dryrun" in req.params:
        return ok(BoolResult(result=True))

    parent_job = Job.get(task_config.job_id)
    if parent_job is None:
        missing_job = Error(code=ErrorCode.INVALID_REQUEST,
                            errors=["unable to find job"])
        return not_ok(missing_job, context=task_config.job_id)

    if parent_job.state not in (JobState.enabled, JobState.init):
        wrong_state = Error(
            code=ErrorCode.UNABLE_TO_ADD_TASK_TO_JOB,
            errors=[
                "unable to add a job in state: %s" % parent_job.state.name
            ],
        )
        return not_ok(wrong_state, context=parent_job.job_id)

    # Every prerequisite must resolve before the new task is created.
    for prereq_id in task_config.prereq_tasks or ():
        prereq_task = Task.get_by_task_id(prereq_id)
        if isinstance(prereq_task, Error):
            return not_ok(prereq_task, context="task create prerequisite")

    created = Task.create(config=task_config,
                          job_id=task_config.job_id,
                          user_info=caller)
    if isinstance(created, Error):
        return not_ok(created, context="task create invalid pool")
    return ok(created)
Example #30
0
 def get_error() -> Error:
     """Build the VM_CREATE_FAILED error for an unexpected identity count."""
     message = "The scaleset is expected to have exactly 1 user assigned identity"
     return Error(code=ErrorCode.VM_CREATE_FAILED, errors=[message])