Code Example #1
    def __init__(self, msl_alt, url, username, password):
        self.msl_alt = msl_alt
        self.url = url
        self.username = username
        self.password = password

        self.client = Client(url, username, password)
Code Example #2
File: upload_targets.py  Project: jesseaster/interop
def main(url, username, password, target_filepath, imagery_dir):
    """Program to load targets and upload via interoperability.

    Args:
        url: The interoperability URL.
        username: The interoperability username.
        password: The interoperability password.
        target_filepath: Filepath to the tab-delimited target file.
        imagery_dir: Base to form paths to imagery files.
    """
    # Load target details.
    targets = load_target_file(target_filepath)

    # Form full imagery filepaths.
    targets = [(t, os.path.join(imagery_dir, i)) for t, i in targets]
    # Validate filepath for each image.
    for _, image_filepath in targets:
        if not os.path.exists(image_filepath):
            raise ValueError('Could not find imagery file: %s' %
                             image_filepath)
    # Validate type of each image.
    for _, image_filepath in targets:
        image_type = imghdr.what(image_filepath)
        if image_type not in ['jpeg', 'png']:
            raise ValueError('Invalid imagery type: %s' % image_type)

    # Create client and upload targets.
    client = Client(url, username, password)
    for target, image_filepath in targets:
        image_data = None
        with open(image_filepath, 'rb') as f:
            image_data = f.read()
        target = client.post_target(target)
        client.put_target_image(target.id, image_data)
Code Example #3
File: upload_targets.py  Project: APTRG/interop
def main(url, username, password, target_filepath, imagery_dir):
    """Program to load targets and upload via interoperability.

    Args:
        url: The interoperability URL.
        username: The interoperability username.
        password: The interoperability password.
        target_filepath: Filepath to the tab-delimited target file.
        imagery_dir: Base to form paths to imagery files.
    """
    # Load target details.
    targets = load_target_file(target_filepath)

    # Form full imagery filepaths.
    targets = [(t, os.path.join(imagery_dir, i)) for t, i in targets]
    # Validate filepath for each image.
    for _, image_filepath in targets:
        if not os.path.exists(image_filepath):
            raise ValueError('Could not find imagery file: %s' %
                             image_filepath)
    # Validate type of each image.
    for _, image_filepath in targets:
        image_type = imghdr.what(image_filepath)
        if image_type not in ['jpeg', 'png']:
            raise ValueError('Invalid imagery type: %s' % image_type)

    # Create client and upload targets.
    client = Client(url, username, password)
    for target, image_filepath in targets:
        image_data = None
        with open(image_filepath, 'rb') as f:
            image_data = f.read()
        target = client.post_target(target)
        client.put_target_image(target.id, image_data)
Code Example #4
    def post_telemetry(self, location, heading):
        """
        Post the drone's telemetry information.

        :param location:    The location to post.
        :type location:     Location
        :param heading:     The UAV's heading.
        :type heading:      Float
        """
        missing_data = True

        while missing_data:
            try:
                telem_upload_data = Telemetry(
                    location.get_lat(), location.get_lon(),
                    location.get_alt() + GCSSettings.MSL_ALT, heading)
                missing_data = False
            except ConnectionError:
                sleep(GCSSettings.INTEROP_DISCONNECT_RETRY_RATE)

                try:
                    self.client = Client(GCSSettings.INTEROP_URL,
                                         GCSSettings.INTEROP_USERNAME,
                                         GCSSettings.INTEROP_PASSWORD)
                except Exception:
                    print("Failed to connect to Interop, retrying...")

        self.client.post_telemetry(telem_upload_data)
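The reconnect-and-retry loop above recurs in examples #6, #9, #11, and #17 below. A minimal sketch of factoring it into a helper method on the same class (hypothetical name _retry_interop; it assumes the GCSSettings constants, Client, sleep, and ConnectionError already imported by the surrounding code):

    def _retry_interop(self, operation):
        """Call operation() until it succeeds, reconnecting on ConnectionError."""
        while True:
            try:
                return operation()
            except ConnectionError:
                sleep(GCSSettings.INTEROP_DISCONNECT_RETRY_RATE)
                try:
                    # Rebuild the client before the next attempt.
                    self.client = Client(GCSSettings.INTEROP_URL,
                                         GCSSettings.INTEROP_USERNAME,
                                         GCSSettings.INTEROP_PASSWORD)
                except Exception:
                    print("Failed to connect to Interop, retrying...")

    # For example, the telemetry upload above could then reduce to:
    #     self._retry_interop(lambda: self.client.post_telemetry(telem_upload_data))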
Code Example #5
class ApiBridge(object):
    def __init__(self):
        # client = interop.Client(url='http://172.17.0.1:8000', username='******',password='******')
        # missions = client.get_missions()
        # print(missions)
        # stationary_obstacles, moving_obstacles = client.get_obstacles()
        # print(stationary_obstacles, moving_obstacles)
        server = os.getenv('TEST_INTEROP_SERVER', 'http://localhost:8000')
        username = os.getenv('TEST_INTEROP_USER', 'testuser')
        password = os.getenv('TEST_INTEROP_USER_PASS', 'testpass')
        admin_username = os.getenv('TEST_INTEROP_ADMIN', 'testadmin')
        admin_password = os.getenv('TEST_INTEROP_ADMIN_PASS', 'testpass')
        """Create a logged in Client."""
        # Create an admin client to clear cache.
        self.admin_client = Client(server, admin_username, admin_password)
        self.admin_client.get('/api/clear_cache')
        # Test rest with non-admin clients.
        self.client = Client(server, username, password)
        self.async_client = AsyncClient(server, username, password)

    def getObsta(self):
        """Test getting missions."""
        async_future = self.async_client.get_obstacles()
        async_stationary, async_moving = async_future.result()
        return async_stationary, async_moving

    def getMis(self):
        """Test getting missions."""
        # missions = self.client.get_missions()
        async_missions = self.async_client.get_missions().result()
        return async_missions
Code Example #6
    def get_active_mission(self):
        """
        Get the active mission and return it. If no missions are active,
        then this function will return None.

        :return:        Active Mission.
        :return type:   Mission / None
        """
        missing_data = True

        while missing_data:
            try:
                missions = self.client.get_missions()
                missing_data = False
            except ConnectionError:
                sleep(GCSSettings.INTEROP_DISCONNECT_RETRY_RATE)  # todo: 0.5

                try:
                    self.client = Client(GCSSettings.INTEROP_URL,
                                         GCSSettings.INTEROP_USERNAME,
                                         GCSSettings.INTEROP_PASSWORD)
                except Exception:
                    print("Failed to connect to Interop, retrying...")

        for mission in missions:
            if mission.active:
                return mission

        return None
Code Example #7
class InteropClientConverter:
    def __init__(self, msl_alt, url, username, password):
        self.msl_alt = msl_alt
        self.url = url
        self.username = username
        self.password = password

        self.client = Client(url, username, password)

    def post_telemetry(self, location, heading):
        """
        Post the drone's telemetry information

        :param location: The location to post
        :type location: Location
        :param heading: The UAV's heading
        :type heading: Float
        """
        telem_upload_data = Telemetry(location.get_lat(), location.get_lon(),
                                      location.get_alt() + self.msl_alt,
                                      heading)

        self.client.post_telemetry(telem_upload_data)

    def get_obstacles(self):
        """
        Return the obstacles.

        Returned in the format: [StationaryObstacle], [MovingObstacle]
        """
        stationary_obstacles, moving_obstacles = self.client.get_obstacles()

        return stationary_obstacles, moving_obstacles

    def get_active_mission(self):
        """
        Get the active mission and return it. If no missions are active, return
        None

        Returned in the format: Mission
        """
        missions = self.client.get_missions()

        for mission in missions:
            if mission.active:
                return mission

        return None

    def get_client(self):
        """
        Returns the client.
        """
        return self.client
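A brief usage sketch for the converter above (hypothetical values; assumes a reachable interop server and an MSL altitude offset in meters):

converter = InteropClientConverter(
    msl_alt=100.0,                      # assumed launch-site elevation above MSL (meters)
    url='http://localhost:8000',        # placeholder interop server URL
    username='testuser',
    password='testpass')

mission = converter.get_active_mission()
if mission is not None:
    print('Active mission found:', mission)

stationary, moving = converter.get_obstacles()
print('%d stationary / %d moving obstacles' % (len(stationary), len(moving)))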
Code Example #8
def myfunc():
    server = os.getenv('TEST_INTEROP_SERVER', 'http://localhost:8000')
    username = os.getenv('TEST_INTEROP_USER', 'testuser')
    password = os.getenv('TEST_INTEROP_USER_PASS', 'testpass')
    admin_username = os.getenv('TEST_INTEROP_ADMIN', 'testadmin')
    admin_password = os.getenv('TEST_INTEROP_ADMIN_PASS', 'testpass')
    admin_client = Client(server, admin_username, admin_password)
    admin_client.get('/api/clear_cache')
    async_client = AsyncClient(server, username, password)
    async_missions = async_client.get_missions().result()
    return async_missions
Code Example #9
    def post_manual_standard_target(self, target, image_file_path):
        """
        POST a standard ODLC object to the interoperability server.

        :param target: The ODLC target object.
        :type target: JSON, with the form:
        {
            "latitude" : float,
            "longitude" : float,
            "orientation" : string,
            "shape" : string,
            "background_color" : string,
            "alphanumeric" : string,
            "alphanumeric_color" : string,
        }

        :param image_file_path: The ODLC target image file name.
        :type image_file_path:  String

        :return: ID of the posted target.
        """
        odlc_target = Odlc(
            type="standard",
            autonomous=False,
            latitude=target["latitude"],
            longitude=target["longitude"],
            orientation=target["orientation"],
            shape=target["shape"],
            background_color=target["background_color"],
            alphanumeric=target["alphanumeric"],
            alphanumeric_color=target["alphanumeric_color"],
            description='Flint Hill School -- ODLC Standard Target Submission')

        missing_data = True

        while missing_data:
            try:
                returned_odlc = self.client.post_odlc(odlc_target)
                missing_data = False
            except ConnectionError:
                sleep(GCSSettings.INTEROP_DISCONNECT_RETRY_RATE)  # todo: 0.5

                try:
                    self.client = Client(GCSSettings.INTEROP_URL,
                                         GCSSettings.INTEROP_USERNAME,
                                         GCSSettings.INTEROP_PASSWORD)
                except Exception:
                    print("Failed to connect to Interop, retrying...")

        with open(image_file_path, 'rb') as img_file:
            self.client.post_odlc_image(returned_odlc.id, img_file.read())
Code Example #10
    def __init__(self):
        # client = interop.Client(url='http://172.17.0.1:8000', username='******',password='******')
        # missions = client.get_missions()
        # print(missions)
        # stationary_obstacles, moving_obstacles = client.get_obstacles()
        # print(stationary_obstacles, moving_obstacles)
        server = os.getenv('TEST_INTEROP_SERVER', 'http://localhost:8000')
        username = os.getenv('TEST_INTEROP_USER', 'testuser')
        password = os.getenv('TEST_INTEROP_USER_PASS', 'testpass')
        admin_username = os.getenv('TEST_INTEROP_ADMIN', 'testadmin')
        admin_password = os.getenv('TEST_INTEROP_ADMIN_PASS', 'testpass')
        """Create a logged in Client."""
        # Create an admin client to clear cache.
        self.admin_client = Client(server, admin_username, admin_password)
        self.admin_client.get('/api/clear_cache')
        # Test rest with non-admin clients.
        self.client = Client(server, username, password)
        self.async_client = AsyncClient(server, username, password)
Code Example #11
    def get_obstacles(self):
        """
        Return the obstacles.

        :return: [StationaryObstacle], [MovingObstacle]
        """
        missing_data = True

        while missing_data:
            try:
                stationary_obstacles, moving_obstacles = self.client.get_obstacles(
                )
                missing_data = False
            except ConnectionError:
                sleep(GCSSettings.INTEROP_DISCONNECT_RETRY_RATE)  # todo: 0.5

                try:
                    self.client = Client(GCSSettings.INTEROP_URL,
                                         GCSSettings.INTEROP_USERNAME,
                                         GCSSettings.INTEROP_PASSWORD)
                except Exception:
                    print("Failed to connect to Interop, retrying...")

        return stationary_obstacles, moving_obstacles
Code Example #12
def main():

    parser = argparse.ArgumentParser(
        description='ABES-AART AUVSI Ground Station')
    parser.add_argument('--url',
                        required=True,
                        help='URL for interoperability.')
    parser.add_argument('--username',
                        required=True,
                        help='Username for interoperability.')
    parser.add_argument('--password',
                        required=True,
                        help='Password for interoperability.')
    parser.add_argument(
        '--spoofTelem',
        help=
        'Select this for random testing; press a key to send telemetry to the server once'
    )
    parser.add_argument(
        '--mavlink',
        help='Required argument for accessing and forwarding telemetry data')
    args = parser.parse_args()
    print(args)
    client = Client(args.url, args.username,
                    args.password)  #TODO : CREATE ARGS BASED

    activeMission = getActiveMission(client)
    print(activeMission)

    if args.spoofTelem:
        spoofTelem(client)
    else:
        proxy_mavlink(args.mavlink, client)

    while True:
        if raw_input():
            break
        else:
            #testWaypoints(client,activeMission)
            testTelemtry(client)
Code Example #13
File: get_missions.py  Project: matcheydj/interop
def main(url, username, password):
    """Program to get and print the mission details."""
    client = Client(url, username, password)
    missions = client.get_missions()
    print(missions)
Code Example #14
def main():
    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stdout,
        format='%(asctime)s: %(name)s: %(levelname)s: %(message)s')

    # Parse command line args.
    parser = argparse.ArgumentParser(description='AUVSI SUAS Interop CLI.')
    parser.add_argument('--url',
                        required=True,
                        help='URL for interoperability.')
    parser.add_argument('--username',
                        required=True,
                        help='Username for interoperability.')
    parser.add_argument('--password', help='Password for interoperability.')

    subparsers = parser.add_subparsers(help='Sub-command help.')

    subparser = subparsers.add_parser('missions', help='Get missions.')
    subparser.set_defaults(func=missions)

    subparser = subparsers.add_parser(
        'odlcs',
        help='Upload odlcs.',
        description='''Download or upload odlcs to/from the interoperability
server.

Without extra arguments, this prints all odlcs that have been uploaded to the
server.

With --odlc_dir, this uploads new odlcs to the server.

This tool searches for odlc JSON and images files within --odlc_dir
conforming to the 2017 Object File Format and uploads the odlc
characteristics and thumbnails to the interoperability server.

There is no deduplication logic. Odlcs will be uploaded multiple times, as
unique odlcs, if the tool is run multiple times.''',
        formatter_class=argparse.RawDescriptionHelpFormatter)
    subparser.set_defaults(func=odlcs)
    subparser.add_argument(
        '--odlc_dir',
        help='Enables odlc upload. Directory containing odlc data.')
    subparser.add_argument(
        '--team_id',
        help='''The username of the team on whose behalf to submit odlcs.
Must be admin user to specify.''')
    subparser.add_argument(
        '--actionable_override',
        help='''Manually sets all the odlcs in the odlc dir to be
actionable. Must be admin user to specify.''')

    subparser = subparsers.add_parser('probe', help='Send dummy requests.')
    subparser.set_defaults(func=probe)
    subparser.add_argument('--interop_time',
                           type=float,
                           default=1.0,
                           help='Time between sent requests (sec).')

    subparser = subparsers.add_parser(
        'mavlink',
        help='''Receive MAVLink GLOBAL_POSITION_INT packets and
forward as telemetry to interop server.''')
    subparser.set_defaults(func=mavlink)
    subparser.add_argument(
        '--device',
        type=str,
        help='pymavlink device name to read from. E.g. tcp:localhost:8080.')

    # Parse args, get password if not provided.
    args = parser.parse_args()
    if args.password:
        password = args.password
    else:
        password = getpass.getpass('Interoperability Password: ')

    # Create client and dispatch subcommand.
    client = Client(args.url, args.username, password)
    args.func(args, client)
Code Example #15
    def __init__(self):
        self.client = Client(GCSSettings.INTEROP_URL,
                             GCSSettings.INTEROP_USERNAME,
                             GCSSettings.INTEROP_PASSWORD)
Code Example #16
class InteropClientConverter(object):
    def __init__(self):
        self.client = Client(GCSSettings.INTEROP_URL,
                             GCSSettings.INTEROP_USERNAME,
                             GCSSettings.INTEROP_PASSWORD)

    def post_telemetry(self, location, heading):
        """
        Post the drone's telemetry information

        :param location: The location to post
        :type location: Location
        :param heading: The UAV's heading
        :type heading: Float
        """
        telem_upload_data = Telemetry(location.get_lat(), location.get_lon(),
                                      location.get_alt() + GCSSettings.MSL_ALT,
                                      heading)

        self.client.post_telemetry(telem_upload_data)

    def get_obstacles(self):
        """
        Return the obstacles.

        Returned in the format: [StationaryObstacle], [MovingObstacle]
        """
        stationary_obstacles, moving_obstacles = self.client.get_obstacles()

        return stationary_obstacles, moving_obstacles

    def get_active_mission(self):
        """
        Get the active mission and return it. If no missions are active, return
        None

        :return: Active Mission
        :return type: Mission / None
        """
        missions = self.client.get_missions()

        for mission in missions:
            if mission.active:
                return mission

        return None

    def post_manual_standard_target(self, target, image_file_path):
        """
        POST a standard ODLC object to the interoperability server.

        :param target: The ODLC target object
        :type target: JSON, with the form:
        {
            "latitude" : float,
            "longitude" : float,
            "orientation" : string,
            "shape" : string,
            "background_color" : string,
            "alphanumeric" : string,
            "alphanumeric_color" : string,
        }

        :param image_file_path: The ODLC target image file name
        :type image_file_path: String

        :return: ID of the posted target
        """
        odlc_target = Odlc(
            type="standard",
            autonomous=False,
            latitude=target["latitude"],
            longitude=target["longitude"],
            orientation=target["orientation"],
            shape=target["shape"],
            background_color=target["background_color"],
            alphanumeric=target["alphanumeric"],
            alphanumeric_color=target["alphanumeric_color"],
            description='Flint Hill School -- ODLC Standard Target Submission')

        returned_odlc = self.client.post_odlc(odlc_target)

        with open(image_file_path, 'rb') as img_file:
            self.client.post_odlc_image(returned_odlc.id, img_file.read())

    def post_manual_emergent_target(self, target, image_file_path):
        """
        POST an emergent ODLC object to the interoperability server.

        :param target: The ODLC target object
        :type target: JSON, with the form:
        {
            "latitude" : float,
            "longitude" : float,
            "emergent_description" : String
        }

        :param image_file_path: The ODLC target image file name
        :type image_file_path: String

        :return: ID of the posted target
        """
        odlc_target = Odlc(
            type="emergent",
            autonomous=False,
            latitude=target["latitude"],
            longitude=target["longitude"],
            description='Flint Hill School -- ODLC Emergent Target Submission: '
            + str(target["emergent_description"]))

        returned_odlc = self.client.post_odlc(odlc_target)

        with open(image_file_path, 'rb') as img_file:
            self.client.post_odlc_image(returned_odlc.id, img_file.read())

    def post_autonomous_target(self, target_info):
        """
        POST a standard ODLC object to the interoperability server.

        :param target_info: The ODLC target info object
        :type target_info: JSON, where target_info["target"][0] has the form:
        {
            "current_crop_path" : string,
            "latitude" : float,
            "longitude" : float,
            "target_orientation" : string,
            "target_shape_type" : string,
            "target_shape_color" : string,
            "target_letter" : string,
            "target_letter_color" : string,
        }

        :return: ID of the posted target
        """

        image_file_path = target_info["target"][0]["current_crop_path"]
        odlc_target = Odlc(
            type="standard",
            autonomous=True,
            latitude=target_info["target"][0]["latitude"],
            longitude=target_info["target"][0]["longitude"],
            orientation=target_info["target"][0]["target_orientation"],
            shape=target_info["target"][0]["target_shape_type"],
            background_color=target_info["target"][0]["target_shape_color"],
            alphanumeric=target_info["target"][0]["target_letter"],
            alphanumeric_color=target_info["target"][0]["target_letter_color"],
            description="Flint Hill School -- ODLC Standard Target Submission")

        returned_odlc = self.client.post_odlc(odlc_target)

        with open(image_file_path, 'rb') as img_file:
            self.client.post_odlc_image(returned_odlc.id, img_file.read())
Code Example #17
    def post_manual_emergent_target(self, target, image_file_path):
        """
        POST an emergent ODLC object to the interoperability server.

        :param target: The ODLC target object.
        :type target: JSON, with the form of
            {
                "latitude" : float,
                "longitude" : float,
                "emergent_description" : String
            }

        :param image_file_path: The ODLC target image file name.
        :type image_file_path:  String

        :return: ID of the posted target.
        """
        odlc_target = Odlc(
            type="emergent",
            autonomous=False,
            latitude=target["latitude"],
            longitude=target["longitude"],
            description='Flint Hill School -- ODLC Emergent Target Submission: '
            + str(target["emergent_description"]))

        # TODO: remove after testing
        """
        returned_odlc = self.client.post_odlc(odlc_target)

        with open(image_file_path) as img_file:
            self.client.post_odlc_image(returned_odlc.id, img_file.read())
        """
        missing_data = True

        while missing_data:
            try:
                returned_odlc = self.client.post_odlc(odlc_target)
                missing_data = False
            except ConnectionError:
                sleep(GCSSettings.INTEROP_DISCONNECT_RETRY_RATE)

                try:
                    self.client = Client(GCSSettings.INTEROP_URL,
                                         GCSSettings.INTEROP_USERNAME,
                                         GCSSettings.INTEROP_PASSWORD)
                except Exception:
                    print("Failed to connect to Interop, retrying...")

        missing_data = True

        with open(image_file_path, 'rb') as img_file:
            while missing_data:
                try:
                    self.client.post_odlc_image(returned_odlc.id,
                                                img_file.read())
                    missing_data = False
                except Exception:
                    sleep(GCSSettings.INTEROP_DISCONNECT_RETRY_RATE)

                    try:
                        self.client = Client(GCSSettings.INTEROP_URL,
                                             GCSSettings.INTEROP_USERNAME,
                                             GCSSettings.INTEROP_PASSWORD)
                    except Exception:
                        print("Failed to connect to Interop, retrying...")
Code Example #18
class TargetClassifier():
    def __init__(self, userid, password, checkpoint_dir, server_url=None):
        # Store Look Up Tables
        self.shapes = {
            0: 'n/a',
            1: 'circle',
            2: 'cross',
            3: 'heptagon',
            4: 'hexagon',
            5: 'octagon',
            6: 'pentagon',
            7: 'quarter_circle',
            8: 'rectangle',
            9: 'semicircle',
            10: 'square',
            11: 'star',
            12: 'trapezoid',
            13: 'triangle'
        }
        self.alphanums = {
            0: 'n/a',
            1: 'A',
            2: 'B',
            3: 'C',
            4: 'D',
            5: 'E',
            6: 'F',
            7: 'G',
            8: 'H',
            9: 'I',
            10: 'J',
            11: 'K',
            12: 'L',
            13: 'M',
            14: 'N',
            15: 'O',
            16: 'P',
            17: 'Q',
            18: 'R',
            19: 'S',
            20: 'T',
            21: 'U',
            22: 'V',
            23: 'W',
            24: 'X',
            25: 'Y',
            26: 'Z',
            27: '0',
            28: '1',
            29: '2',
            30: '3',
            31: '4',
            32: '5',
            33: '6',
            34: '7',
            35: '8',
            36: '9'
        }
        self.colors = {
            0: 'n/a',
            1: 'white',
            2: 'black',
            3: 'gray',
            4: 'red',
            5: 'blue',
            6: 'green',
            7: 'yellow',
            8: 'purple',
            9: 'brown',
            10: 'orange'
        }

        # Store userid
        self.userid = userid

        # IMPORTANT! Put updated mean standard values here
        self.mean = np.array([83.745, 100.718, 115.504])  # R, G, B
        self.stddev = np.array([53.764, 52.350, 59.265])  # R, G, B

        # Counters/trackers for interop
        self.target_id = 2  # Start at target #2

        # Interoperability client
        if server_url is not None:
            self.interop = Client(server_url, userid, password)
        else:
            self.interop = None
            print('Warning: No interop server connection')

        # Logging mode
        if LOGGING:
            self.logging_counter = 0

        # Build TensorFlow graphs
        assert os.path.isdir(checkpoint_dir)
        # Shape graph
        self.shape_graph = tf.Graph()
        with self.shape_graph.as_default():
            self.inputs_shape = tf.placeholder(
                tf.float32,
                shape=[None, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS])
            self.logits_shape = wideresnet.inference(
                self.inputs_shape, 14,
                scope='shapes')  # 13 shapes + background
            variable_averages = tf.train.ExponentialMovingAverage(
                wideresnet.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)

            self.shape_sess = tf.Session()  # graph=self.shape_graph
            #shape_saver = tf.train.Saver()
            shape_ckpt = tf.train.get_checkpoint_state(
                os.path.join(checkpoint_dir, 'shape'))
            if shape_ckpt and shape_ckpt.model_checkpoint_path:
                print('Reading shape model parameters from %s' %
                      shape_ckpt.model_checkpoint_path)
                #shape_saver.restore(self.shape_sess, self.shape_ckpt.model_checkpoint_path)
                saver.restore(self.shape_sess,
                              shape_ckpt.model_checkpoint_path)
            else:
                print(
                    'Error restoring parameters for shape. Ensure checkpoint is stored in ${checkpoint_dir}/shape/'
                )
                # sys.exit(1)

        # Shape color graph
        self.shape_color_graph = tf.Graph()
        with self.shape_color_graph.as_default():
            self.inputs_shape_color = tf.placeholder(
                tf.float32,
                shape=[None, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS])
            self.logits_shape_color = shallow_mlp.inference(
                self.inputs_shape_color, 11,
                scope='shape_color')  # 10 shape_colors + background
            variable_averages = tf.train.ExponentialMovingAverage(
                shallow_mlp.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)

            self.shape_color_sess = tf.Session(
            )  # graph=self.shape_color_graph
            #shape_color_saver = tf.train.Saver()
            shape_color_ckpt = tf.train.get_checkpoint_state(
                os.path.join(checkpoint_dir, 'shape_color'))
            if shape_color_ckpt and shape_color_ckpt.model_checkpoint_path:
                print('Reading shape_color model parameters from %s' %
                      shape_color_ckpt.model_checkpoint_path)
                #shape_color_saver.restore(self.shape_color_sess, self.shape_color_ckpt.model_checkpoint_path)
                saver.restore(self.shape_color_sess,
                              shape_color_ckpt.model_checkpoint_path)
            else:
                print(
                    'Error restoring parameters for shape_color. Ensure checkpoint is stored in ${checkpoint_dir}/shape_color/'
                )
                # sys.exit(1)

        # Alphanum graph
        self.alphanum_graph = tf.Graph()
        with self.alphanum_graph.as_default():
            self.inputs_alphanum = tf.placeholder(
                tf.float32,
                shape=[None, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS])
            self.logits_alphanum = wideresnet.inference(
                self.inputs_alphanum, 37,
                scope='alphanums')  # 36 alphanums + background
            variable_averages = tf.train.ExponentialMovingAverage(
                wideresnet.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)

            self.alphanum_sess = tf.Session()
            #alphanum_saver = tf.train.Saver()
            alphanum_ckpt = tf.train.get_checkpoint_state(
                os.path.join(checkpoint_dir, 'alphanum'))
            if alphanum_ckpt and alphanum_ckpt.model_checkpoint_path:
                print('Reading alphanum model parameters from %s' %
                      alphanum_ckpt.model_checkpoint_path)
                #alphanum_saver.restore(self.alphanum_sess, self.alphanum_ckpt.model_checkpoint_path)
                saver.restore(self.alphanum_sess,
                              alphanum_ckpt.model_checkpoint_path)
            else:
                print(
                    'Error restoring parameters for alphanum. Ensure checkpoint is stored in ${checkpoint_dir}/alphanum/'
                )
                # sys.exit(1)

        # Alphanum color graph
        self.alphanum_color_graph = tf.Graph()
        with self.alphanum_color_graph.as_default():
            self.inputs_alphanum_color = tf.placeholder(
                tf.float32,
                shape=[None, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS])
            self.logits_alphanum_color = mlp.inference(
                self.inputs_alphanum_color, 11,
                scope='letter_color')  # 10 alphanum_colors + background
            variable_averages = tf.train.ExponentialMovingAverage(
                mlp.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)

            self.alphanum_color_sess = tf.Session()
            #alphanum_color_saver = tf.train.Saver()
            alphanum_color_ckpt = tf.train.get_checkpoint_state(
                os.path.join(checkpoint_dir, 'alphanum_color'))
            if alphanum_color_ckpt and alphanum_color_ckpt.model_checkpoint_path:
                print('Reading alphanum_color model parameters from %s' %
                      alphanum_color_ckpt.model_checkpoint_path)
                #alphanum_color_saver.restore(self.alphanum_color_sess, self.alphanum_color_ckpt.model_checkpoint_path)
                saver.restore(self.alphanum_color_sess,
                              alphanum_color_ckpt.model_checkpoint_path)
            else:
                print(
                    'Error restoring parameters for alphanum_color. Ensure checkpoint is stored in ${checkpoint_dir}/alphanum_color/'
                )
                # sys.exit(1)

    def preprocess_image(self, image):
        ''' Preprocess image for classification
			Args:
				image: np.array containing raw input image
			Returns:
				image: np.array of size [1, width, height, depth]
		'''
        im = image.copy()

        # Change from BGR (OpenCV) to RGB
        b = im[:, :, 0].copy()
        im[:, :, 0] = im[:, :, 2]  # Put red channel in [:,:,0]
        im[:, :, 2] = b  # Put blue channel in [:,:,2]

        # Resize image as necessary
        if (np.greater(im.shape[:2], [IMAGE_SIZE, IMAGE_SIZE]).any()):
            # Scale down
            im = cv2.resize(im,
                            dsize=(IMAGE_SIZE, IMAGE_SIZE),
                            interpolation=cv2.INTER_AREA)
        elif (np.less(im.shape[:2], [IMAGE_SIZE, IMAGE_SIZE]).any()):
            # Scale up
            im = cv2.resize(im,
                            dsize=(IMAGE_SIZE, IMAGE_SIZE),
                            interpolation=cv2.INTER_CUBIC)

        # MeanStd normalization
        im = np.subtract(im, self.mean)
        im = np.divide(im, self.stddev)
        # Pad dimensions from 3-D to 4-D if necessary
        if len(im.shape) == 3:
            im = np.expand_dims(im, axis=0)
        return im

    def preprocess_image_hsv(self, image):
        ''' Preprocess image for classification
			Args:
				image: np.array containing raw input image
			Returns:
				image: np.array of size [1, width, height, depth]
		'''
        im = image.copy()

        # Resize image as necessary
        if (np.greater(im.shape[:2], [IMAGE_SIZE, IMAGE_SIZE]).any()):
            # Scale down
            im = cv2.resize(im,
                            dsize=(IMAGE_SIZE, IMAGE_SIZE),
                            interpolation=cv2.INTER_AREA)
        elif (np.less(im.shape[:2], [IMAGE_SIZE, IMAGE_SIZE]).any()):
            # Scale up
            im = cv2.resize(im,
                            dsize=(IMAGE_SIZE, IMAGE_SIZE),
                            interpolation=cv2.INTER_CUBIC)

        # Change from BGR to HSV
        im = cv2.cvtColor(im, cv2.COLOR_BGR2HSV)
        im = im.astype(np.float32)

        # Scale to [-1,1]
        im[:, :, 0] = np.subtract(im[:, :, 0], 89.5)  # Hue
        im[:, :, 0] = np.divide(im[:, :, 0], 89.5)  # Hue
        im[:, :, 1] = np.subtract(im[:, :, 1], 127.5)  # Saturation
        im[:, :, 1] = np.divide(im[:, :, 1], 127.5)  # Saturation
        im[:, :, 2] = np.subtract(im[:, :, 2], 127.5)  # Value
        im[:, :, 2] = np.divide(im[:, :, 2], 127.5)  # Value

        if len(im.shape) == 3:
            im = np.expand_dims(im, axis=0)

        return im

    def classify_shape(self, image):
        ''' Extract the shape of the target
				Args: The preprocessed input image, of shape [1, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS]
			Returns:
				str: The classified shape, in human readable text
		'''
        try:
            predictions = self.shape_sess.run(
                [self.logits_shape], feed_dict={self.inputs_shape: image})
            class_out = np.argmax(predictions)
            confidence = np.max(predictions)
            if confidence >= 0.50:
                return self.shapes[class_out]
            else:
                print('Shape %s rejected at confidence %f' %
                      (self.shapes[class_out], confidence))
                return None
        # If checkpoint not loaded, ignore error and return None
        except tf.errors.FailedPreconditionError:
            return None

    def classify_shape_color(self, image):
        ''' Extract the shape color of the target
				Args: The input image
				Returns:
					str: The classified color, in human readable text
		'''
        try:
            predictions = self.shape_color_sess.run(
                [self.logits_shape_color],
                feed_dict={self.inputs_shape_color: image})
            class_out = np.argmax(predictions)
            confidence = np.max(predictions)
            if confidence >= 0.50:
                return self.colors[class_out]
            else:
                print('Shape color %s rejected at confidence %f' %
                      (self.colors[class_out], confidence))
                return None
        # If checkpoint not loaded, ignore error and return None
        except tf.errors.FailedPreconditionError:
            return None

    def classify_letter(self, image):
        ''' Extract the letter (alphanumeric) of the target
				Args: The input image
				Returns: 
					str: The classified letter, in human readable text
					int: Amount rotated clockwise, in degrees
		'''
        try:
            rot = 0
            class_out_dict = {}
            image = image.copy().squeeze()
            (h, w) = image.shape[:2]
            center = (w / 2, h / 2)
            while (rot < 360):
                # Rotate image clockwise by rot degrees
                M = cv2.getRotationMatrix2D(center, rot, 1.0)
                image_rot = cv2.warpAffine(image, M, (w, h))
                image_rot = np.expand_dims(image_rot, axis=0)
                predictions = self.alphanum_sess.run(
                    [self.logits_alphanum],
                    feed_dict={self.inputs_alphanum: image_rot})
                class_out_dict[np.max(predictions)] = (
                    np.argmax(predictions), rot
                )  # TODO: Handle duplicate confidences
                rot += 22.5  # 22.5 degree stride
            confidence = max(
                class_out_dict)  # Maximum confidence from classifications
            #class_out = np.argmax(predictions)
            class_out, rot_out = class_out_dict[confidence]
            if confidence >= 0.50:
                return self.alphanums[class_out], rot_out
            else:
                print('Letter %s rejected at confidence %f' %
                      (self.alphanums[class_out], confidence))
                return None, None
        # If checkpoint not loaded, ignore error and return None
        except tf.errors.FailedPreconditionError:
            return None, None

    def classify_letter_color(self, image):
        ''' Extract the letter color of the target
				Args: The input image
				Returns:
					str: The classified color, in human readable text
		'''
        try:
            predictions = self.alphanum_color_sess.run(
                [self.logits_alphanum_color],
                feed_dict={self.inputs_alphanum_color: image})
            class_out = np.argmax(predictions)
            confidence = np.max(predictions)
            if confidence >= 0.50:
                return self.colors[class_out]
            else:
                print('Letter color %s rejected at confidence %f' %
                      (self.colors[class_out], confidence))
                return None
        # If checkpoint not loaded, ignore error and return None
        except tf.errors.FailedPreconditionError:
            return None

    def check_valid(self, packet):
        ''' Check whether the prepared output packet is valid
				Args:
					dict: dictionary (JSON) of proposed output packet
				Returns:
					bool: True if packet is valid, False if not
		'''
        # FIXME: Delete this part
        labels = [
            "shape", "alphanumeric", "backgorund_color", "alphanumeric_color"
        ]
        for key, value in packet.iteritems():
            # Background class, flagged "n/a" in our translation key
            #if (value == "n/a") and key != "description":
            #	return False
            if (value != "n/a" and value != None) and key in labels:
                print(value)
                return True
            # Background and alphanumeric color should never be the same
            #if packet['background_color'] == packet['alphanumeric_color']:
            #	return False
            # TODO: Check for valid lat/lon

        #return True
        return False

    def check_duplicate(self, target):
        ''' Utility function to check if target has already been submitted
			Args:
				target: Target to check
			Returns:
				retval: bool, True if duplicate exists
		'''
        if not self.interop:
            return None
        targetLocation = (target.latitude, target.longitude)
        targets = self.interop.get_targets()
        for t in targets:
            tLocation = (t.latitude, t.longitude)
            if self._calc_distance(targetLocation, tLocation) < 0.00015:
                return True
        return False

    def _calc_distance(self, a, b):
        ''' Utility function to calculate the distance between two arrays
			Args:
				a: an array of numbers
				b: an array of numbers
			Returns:
				distance: absolute Euclidean distance between a and b
		'''
        a = np.array(a)
        b = np.array(b)
        assert (a.shape == b.shape)
        return np.sqrt(np.sum(np.power(np.subtract(a, b), 2)))

    def classify_and_maybe_transmit(self, image, location, orientation):
        ''' Main worker function for image classification. Transmits depending on validity
			Args:
				image: np.array of size [width, height, depth]
				location: tuple of GPS coordinates as (lat, lon)
				orientation: degree value in range [-180, 180],
							 where 0 represents due north and 90 represents due east
		'''
        if image is None:
            return False

        image_orig = image

        image = self.preprocess_image(image)
        imageHSV = self.preprocess_image_hsv(image_orig)

        # Run respective image classifiers
        shape = self.classify_shape(image)
        background_color = self.classify_shape_color(imageHSV)
        alphanumeric, rot = self.classify_letter(image)
        alphanumeric_color = self.classify_letter_color(imageHSV)
        latitude, longitude = location

        # Debugging only
        if DEBUG and orientation is None:
            orientation = 0
        if DEBUG and (latitude, longitude) == (None, None):
            latitude, longitude = (0, 0)

        # Extract orientation
        if orientation is not None and rot is not None:
            orientation += rot
            orientation = degToOrientation(orientation)
        else:
            orientation = None

        if DEBUG or VERBOSE:
            print 'Shape =', shape
            print 'Shape Color =', background_color
            print 'Alphanumeric =', alphanumeric
            print 'Alphanum Color =', alphanumeric_color
            print 'Lat, Lon =', latitude, ',', longitude
            print 'Orientation = ', orientation

        packet = {
            "user": self.userid,
            "type": "standard",
            "latitude": latitude,
            "longitude": longitude,
            "orientation": orientation,
            "shape": shape,
            "background_color": background_color,
            "alphanumeric": alphanumeric,
            "alphanumeric_color": alphanumeric_color,
            "description": None,
            "autonomous": True
        }

        if LOGGING:
            if not os.path.exists('processed'):
                os.mkdir('processed')
            savepath = 'processed/img_' + str(self.logging_counter)
            with open(savepath + '.json', 'w') as outfile:
                json.dump(packet, outfile)
            cv2.imwrite(savepath + '.jpg', image_orig)
            self.logging_counter += 1

        # Check for false positives or otherwise invalid targets
        packet_valid = self.check_valid(packet)
        if packet_valid:
            packet["id"] = self.target_id
            json_packet = json.dumps(packet)
            if self.interop is not None:
                if not os.path.exists('transmit'):
                    os.mkdir('transmit')
                savepath = 'transmit/img_' + str(self.target_id)
                with open(savepath + '.json', 'w') as outfile:
                    json.dump(packet, outfile)
                cv2.imwrite(savepath + '.jpg', image_orig)
                # Transmit data to interop server
                target = Target(id=self.target_id,
                                user=self.userid,
                                type='standard',
                                latitude=latitude,
                                longitude=longitude,
                                orientation=orientation,
                                shape=shape,
                                background_color=background_color,
                                alphanumeric=alphanumeric,
                                alphanumeric_color=alphanumeric_color,
                                description=None,
                                autonomous=True)
                if not self.check_duplicate(target):
                    try:
                        print('Transmitting target %d info' % self.target_id)
                        self.interop.post_target(target)
                    except Exception as e:
                        print(e)
                    # Transmit image to interop server
                    with open('transmit/img_%d.jpg' % self.target_id, 'rb') as f:
                        im_data = f.read()
                        try:
                            print('Transmitting target %d image' %
                                  self.target_id)
                            self.interop.post_target_image(
                                self.target_id, im_data)
                        except Exception as e:
                            print(e)
                else:
                    print(
                        'INFO: Duplicate target detected at (%f, %f) lat/lon' %
                        (latitude, longitude))

                # TODO (optional): build database of detected targets, correct mistakes
            self.target_id += 1
        else:
            print('INFO: An invalid target was discarded')
        return packet_valid, packet
Code Example #19
File: interop_cli.py  Project: sedflix/interop
def main():
    # Setup logging
    logging.basicConfig(
        level=logging.DEBUG,
        stream=sys.stdout,
        format='%(asctime)s: %(name)s: %(levelname)s: %(message)s')

    # Parse command line args.
    parser = argparse.ArgumentParser(description='AUVSI SUAS Interop CLI.')
    parser.add_argument('--url',
                        required=True,
                        help='URL for interoperability.')
    parser.add_argument('--username',
                        required=True,
                        help='Username for interoperability.')
    parser.add_argument('--password', help='Password for interoperability.')
    subparsers = parser.add_subparsers(help='Sub-command help.')

    subparser = subparsers.add_parser('missions', help='Get missions.')
    subparser.set_defaults(func=missions)

    subparser = subparsers.add_parser(
        'targets',
        help='Upload targets.',
        description='''Download or upload targets to/from the interoperability
server.

Without extra arguments, this prints all targets that have been uploaded to the
server.

With --target_dir, this uploads new targets to the server.

This tool searches for target JSON and images files within --target_dir
conforming to the 2017 Object File Format and uploads the target
characteristics and thumbnails to the interoperability server.

Alternatively, if --legacy_filepath is specified, that file is parsed as the
legacy 2016 tab-delimited target file format. Image paths referenced in the
file are relative to --target_dir.

There is no deduplication logic. Targets will be uploaded multiple times, as
unique targets, if the tool is run multiple times.''',
        formatter_class=argparse.RawDescriptionHelpFormatter)
    subparser.set_defaults(func=targets)
    subparser.add_argument(
        '--legacy_filepath',
        help='Target file in the legacy 2016 tab-delimited format.')
    subparser.add_argument(
        '--target_dir',
        help='Enables target upload. Directory containing target data.')

    subparser = subparsers.add_parser('probe', help='Send dummy requests.')
    subparser.set_defaults(func=probe)
    subparser.add_argument('--interop_time',
                           type=float,
                           default=1.0,
                           help='Time between sent requests (sec).')

    # Parse args, get password if not provided.
    args = parser.parse_args()
    if args.password:
        password = args.password
    else:
        password = getpass.getpass('Interoperability Password: ')

    # Create client and dispatch subcommand.
    client = Client(args.url, args.username, password)
    args.func(args, client)
Code Example #20
from interop import Client

c = Client('http://10.10.130.10:80', 'losangeles', '9886843481')

with open('test.png', 'rb') as f:
    c.put_target_image(1, f.read())

Code Example #21
    def __init__(self, userid, password, checkpoint_dir, server_url=None):
        # Store Look Up Tables
        self.shapes = {
            0: 'n/a',
            1: 'circle',
            2: 'cross',
            3: 'heptagon',
            4: 'hexagon',
            5: 'octagon',
            6: 'pentagon',
            7: 'quarter_circle',
            8: 'rectangle',
            9: 'semicircle',
            10: 'square',
            11: 'star',
            12: 'trapezoid',
            13: 'triangle'
        }
        self.alphanums = {
            0: 'n/a',
            1: 'A',
            2: 'B',
            3: 'C',
            4: 'D',
            5: 'E',
            6: 'F',
            7: 'G',
            8: 'H',
            9: 'I',
            10: 'J',
            11: 'K',
            12: 'L',
            13: 'M',
            14: 'N',
            15: 'O',
            16: 'P',
            17: 'Q',
            18: 'R',
            19: 'S',
            20: 'T',
            21: 'U',
            22: 'V',
            23: 'W',
            24: 'X',
            25: 'Y',
            26: 'Z',
            27: '0',
            28: '1',
            29: '2',
            30: '3',
            31: '4',
            32: '5',
            33: '6',
            34: '7',
            35: '8',
            36: '9'
        }
        self.colors = {
            0: 'n/a',
            1: 'white',
            2: 'black',
            3: 'gray',
            4: 'red',
            5: 'blue',
            6: 'green',
            7: 'yellow',
            8: 'purple',
            9: 'brown',
            10: 'orange'
        }

        # Store userid
        self.userid = userid

        # IMPORTANT! Put updated mean and standard deviation values here
        self.mean = np.array([83.745, 100.718, 115.504])  # R, G, B
        self.stddev = np.array([53.764, 52.350, 59.265])  # R, G, B

        # Counters/trackers for interop
        self.target_id = 2  # Start at target #2

        # Interoperability client
        if server_url is not None:
            self.interop = Client(server_url, userid, password)
        else:
            self.interop = None
            print('Warning: No interop server connection')

        # Logging mode
        if LOGGING:
            self.logging_counter = 0

        # Build TensorFlow graphs
        assert os.path.isdir(checkpoint_dir)
        # Shape graph
        self.shape_graph = tf.Graph()
        with self.shape_graph.as_default():
            self.inputs_shape = tf.placeholder(
                tf.float32,
                shape=[None, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS])
            self.logits_shape = wideresnet.inference(
                self.inputs_shape, 14,
                scope='shapes')  # 13 shapes + background
            variable_averages = tf.train.ExponentialMovingAverage(
                wideresnet.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)

            self.shape_sess = tf.Session()  # graph=self.shape_graph
            #shape_saver = tf.train.Saver()
            shape_ckpt = tf.train.get_checkpoint_state(
                os.path.join(checkpoint_dir, 'shape'))
            if shape_ckpt and shape_ckpt.model_checkpoint_path:
                print('Reading shape model parameters from %s' %
                      shape_ckpt.model_checkpoint_path)
                #shape_saver.restore(self.shape_sess, self.shape_ckpt.model_checkpoint_path)
                saver.restore(self.shape_sess,
                              shape_ckpt.model_checkpoint_path)
            else:
                print(
                    'Error restoring parameters for shape. Ensure checkpoint is stored in ${checkpoint_dir}/shape/'
                )
                # sys.exit(1)

        # Shape color graph
        self.shape_color_graph = tf.Graph()
        with self.shape_color_graph.as_default():
            self.inputs_shape_color = tf.placeholder(
                tf.float32,
                shape=[None, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS])
            self.logits_shape_color = shallow_mlp.inference(
                self.inputs_shape_color, 11,
                scope='shape_color')  # 10 shape_colors + background
            variable_averages = tf.train.ExponentialMovingAverage(
                shallow_mlp.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)

            self.shape_color_sess = tf.Session(
            )  # graph=self.shape_color_graph
            #shape_color_saver = tf.train.Saver()
            shape_color_ckpt = tf.train.get_checkpoint_state(
                os.path.join(checkpoint_dir, 'shape_color'))
            if shape_color_ckpt and shape_color_ckpt.model_checkpoint_path:
                print('Reading shape_color model parameters from %s' %
                      shape_color_ckpt.model_checkpoint_path)
                #shape_color_saver.restore(self.shape_color_sess, self.shape_color_ckpt.model_checkpoint_path)
                saver.restore(self.shape_color_sess,
                              shape_color_ckpt.model_checkpoint_path)
            else:
                print(
                    'Error restoring parameters for shape_color. Ensure checkpoint is stored in ${checkpoint_dir}/shape_color/'
                )
                # sys.exit(1)

        # Alphanum graph
        self.alphanum_graph = tf.Graph()
        with self.alphanum_graph.as_default():
            self.inputs_alphanum = tf.placeholder(
                tf.float32,
                shape=[None, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS])
            self.logits_alphanum = wideresnet.inference(
                self.inputs_alphanum, 37,
                scope='alphanums')  # 36 alphanums + background
            variable_averages = tf.train.ExponentialMovingAverage(
                wideresnet.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)

            self.alphanum_sess = tf.Session()
            #alphanum_saver = tf.train.Saver()
            alphanum_ckpt = tf.train.get_checkpoint_state(
                os.path.join(checkpoint_dir, 'alphanum'))
            if alphanum_ckpt and alphanum_ckpt.model_checkpoint_path:
                print('Reading alphanum model parameters from %s' %
                      alphanum_ckpt.model_checkpoint_path)
                #alphanum_saver.restore(self.alphanum_sess, self.alphanum_ckpt.model_checkpoint_path)
                saver.restore(self.alphanum_sess,
                              alphanum_ckpt.model_checkpoint_path)
            else:
                print(
                    'Error restoring parameters for alphanum. Ensure checkpoint is stored in ${checkpoint_dir}/alphanum/'
                )
                # sys.exit(1)

        # Alphanum color graph
        self.alphanum_color_graph = tf.Graph()
        with self.alphanum_color_graph.as_default():
            self.inputs_alphanum_color = tf.placeholder(
                tf.float32,
                shape=[None, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS])
            self.logits_alphanum_color = mlp.inference(
                self.inputs_alphanum_color, 11,
                scope='letter_color')  # 10 alphanum_colors + background
            variable_averages = tf.train.ExponentialMovingAverage(
                mlp.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)

            self.alphanum_color_sess = tf.Session()
            #alphanum_color_saver = tf.train.Saver()
            alphanum_color_ckpt = tf.train.get_checkpoint_state(
                os.path.join(checkpoint_dir, 'alphanum_color'))
            if alphanum_color_ckpt and alphanum_color_ckpt.model_checkpoint_path:
                print('Reading alphanum_color model parameters from %s' %
                      alphanum_color_ckpt.model_checkpoint_path)
                #alphanum_color_saver.restore(self.alphanum_color_sess, self.alphanum_color_ckpt.model_checkpoint_path)
                saver.restore(self.alphanum_color_sess,
                              alphanum_color_ckpt.model_checkpoint_path)
            else:
                print(
                    'Error restoring parameters for alphanum_color. Ensure checkpoint is stored in ${checkpoint_dir}/alphanum_color/'
                )
Code Example #22
File: APItest.py  Project: wuyou33/QuadrotorModelling
# client = interop.Client(url='http://172.17.0.1:8000', username='******',password='******')
# missions = client.get_missions()
# print(missions)

# stationary_obstacles, moving_obstacles = client.get_obstacles()

# print(stationary_obstacles, moving_obstacles)

server = os.getenv('TEST_INTEROP_SERVER', 'http://localhost:8000')
username = os.getenv('TEST_INTEROP_USER', 'testuser')
password = os.getenv('TEST_INTEROP_USER_PASS', 'testpass')
admin_username = os.getenv('TEST_INTEROP_ADMIN', 'testadmin')
admin_password = os.getenv('TEST_INTEROP_ADMIN_PASS', 'testpass')
"""Create a logged in Client."""
# Create an admin client to clear cache.
admin_client = Client(server, admin_username, admin_password)
admin_client.get('/api/clear_cache')

# Test rest with non-admin clients.
client = Client(server, username, password)
async_client = AsyncClient(server, username, password)
"""Test getting missions."""
missions = client.get_missions()
async_missions = async_client.get_missions().result()

# # Check one mission returned.
# self.assertEqual(1, len(missions))
# self.assertEqual(1, len(async_missions))
# # Check a few fields.
# self.assertTrue(missions[0].active)
# self.assertTrue(async_missions[0].active)
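A minimal sketch of running the commented-out checks above as plain assertions (assumes the test server is configured with exactly one active mission):

# Check one mission is returned by both clients, and that it is active.
assert len(missions) == 1
assert len(async_missions) == 1
assert missions[0].active
assert async_missions[0].active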