Example 1
    def toStr(self, protoObj):
        """Used for pretty printing a result from the API."""
        return text_format.MessageToString(protoObj)
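Text format round-trips: whatever `MessageToString` produces can be fed back through `text_format.Parse`. A small self-contained sketch, using `Duration` purely as an example message type:

from google.protobuf import text_format
from google.protobuf.duration_pb2 import Duration

original = Duration(seconds=5, nanos=250)
as_text = text_format.MessageToString(original)    # "seconds: 5\nnanos: 250\n"
restored = text_format.Parse(as_text, Duration())  # parse back into a new message
assert restored == original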
Example 2
    def _AssertProtoDictEquals(self,
                               expected_dict,
                               actual_dict,
                               verbose=False,
                               update_goldens=False,
                               additional_missing_object_message='',
                               api_version=2):
        """Diff given dicts of protobufs and report differences a readable way.

    Args:
      expected_dict: a dict of TFAPIObject protos constructed from golden files.
      actual_dict: a ict of TFAPIObject protos constructed by reading from the
        TF package linked to the test.
      verbose: Whether to log the full diffs, or simply report which files were
        different.
      update_goldens: Whether to update goldens when there are diffs found.
      additional_missing_object_message: Message to print when a symbol is
        missing.
      api_version: TensorFlow API version to test.
    """
        diffs = []
        verbose_diffs = []

        expected_keys = set(expected_dict.keys())
        actual_keys = set(actual_dict.keys())
        only_in_expected = expected_keys - actual_keys
        only_in_actual = actual_keys - expected_keys
        all_keys = expected_keys | actual_keys

        # This will be populated below.
        updated_keys = []

        for key in all_keys:
            diff_message = ''
            verbose_diff_message = ''
            # First check if the key is not found in one or the other.
            if key in only_in_expected:
                diff_message = 'Object %s expected but not found (removed). %s' % (
                    key, additional_missing_object_message)
                verbose_diff_message = diff_message
            elif key in only_in_actual:
                diff_message = 'New object %s found (added).' % key
                verbose_diff_message = diff_message
            else:
                # Do not truncate diff
                self.maxDiff = None  # pylint: disable=invalid-name
                # Now we can run an actual proto diff.
                try:
                    self.assertProtoEquals(expected_dict[key],
                                           actual_dict[key])
                except AssertionError as e:
                    updated_keys.append(key)
                    diff_message = 'Change detected in python object: %s.' % key
                    verbose_diff_message = str(e)

            # All difference cases covered above. If any difference found, add to the
            # list.
            if diff_message:
                diffs.append(diff_message)
                verbose_diffs.append(verbose_diff_message)

        # If diffs are found, handle them based on flags.
        if diffs:
            diff_count = len(diffs)
            logging.error(self._test_readme_message)
            logging.error('%d differences found between API and golden.',
                          diff_count)
            messages = verbose_diffs if verbose else diffs
            for i in range(diff_count):
                print('Issue %d\t: %s' % (i + 1, messages[i]), file=sys.stderr)

            if update_goldens:
                # Write files if requested.
                logging.warning(self._update_golden_warning)

                # If the keys are only in expected, some objects are deleted.
                # Remove files.
                for key in only_in_expected:
                    filepath = _KeyToFilePath(key, api_version)
                    file_io.delete_file(filepath)

                # If the files are only in actual (current library), these are new
                # modules. Write them to files. Also record all updates in files.
                for key in only_in_actual | set(updated_keys):
                    filepath = _KeyToFilePath(key, api_version)
                    file_io.write_string_to_file(
                        filepath,
                        text_format.MessageToString(actual_dict[key]))
            else:
                # Fail if we cannot fix the test by updating goldens.
                self.fail('%d differences found between API and golden.' %
                          diff_count)

        else:
            logging.info('No differences found between API and golden.')
Example 3
def main():
    """main
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--net', type=str, required=True, help='net prototxt')
    parser.add_argument('--weight', type=str, required=True, help='net weight')
    args = parser.parse_args()
    print(args)

    net = caffe_pb2.NetParameter()
    text_format.Merge(open(args.net, 'r').read(), net)
    weight = caffe_pb2.NetParameter()
    weight.ParseFromString(open(args.weight, 'rb').read())

    # remove useless layers
    net_layers = [layer.name for layer in net.layer]
    weight_layers = [layer.name for layer in weight.layer]
    not_used_layers = [
        layer for layer in weight_layers if layer not in net_layers
    ]

    for layer_name in not_used_layers:
        weight.layer.remove(get_layer(weight, layer_name))

    # search conv-batchnorm-scale pattern
    remove_list = []
    i = 0
    while i < len(net_layers):
        if net.layer[i].type == 'Convolution':
            conv_param = net.layer[i].convolution_param
            conv_name = net.layer[i].name
            conv_layer = get_layer(weight, conv_name)
            i += 1
            if conv_param.group != 1:
                continue
            if i < len(net_layers) and net.layer[i].type == 'BatchNorm':
                batch_norm_param = net.layer[i].batch_norm_param
                batch_norm_name = net.layer[i].name
                batch_norm_layer = get_layer(weight, batch_norm_name)
                i += 1
                if net.layer[i - 1].HasField(
                        'batch_norm_param'
                ) and not batch_norm_param.use_global_stats:
                    continue
                if i < len(net_layers) and net.layer[i].type == 'Scale':
                    scale_param = net.layer[i].scale_param
                    scale_name = net.layer[i].name
                    scale_layer = get_layer(weight, scale_name)
                    i += 1

                    print('fuse (%s, %s, %s)' %
                          (conv_name, batch_norm_name, scale_name))
                    # weight, bias
                    conv_weight = convert_blob_to_array(conv_layer.blobs[0])
                    if conv_param.bias_term:
                        conv_bias = convert_blob_to_array(conv_layer.blobs[1])
                    else:
                        channels = conv_param.num_output
                        conv_bias = np.zeros(channels, dtype=np.float32)
                    # mean, std
                    mean = convert_blob_to_array(batch_norm_layer.blobs[0])
                    std = convert_blob_to_array(batch_norm_layer.blobs[1])
                    scale_factor = convert_blob_to_array(
                        batch_norm_layer.blobs[2])
                    mean /= scale_factor
                    std /= scale_factor
                    # scale, shift
                    scale = convert_blob_to_array(scale_layer.blobs[0])
                    if scale_param.bias_term:
                        shift = convert_blob_to_array(scale_layer.blobs[1])
                    else:
                        shift = np.zeros(scale.shape, dtype=np.float32)
                    # eps
                    if batch_norm_param is not None:
                        eps = batch_norm_param.eps
                    else:
                        eps = 1e-5
                    # fuse
                    conv_weight, conv_bias = fuse(conv_weight, conv_bias, mean,
                                                  std, scale, shift, eps)

                    conv_layer.blobs.pop()
                    if conv_param.bias_term:
                        conv_layer.blobs.pop()
                    else:
                        conv_param.bias_term = True
                    conv_layer.blobs.extend([
                        convert_array_to_blob(conv_weight),
                        convert_array_to_blob(conv_bias)
                    ])

                    # remove batchnorm and scale layer
                    weight.layer.remove(batch_norm_layer)
                    weight.layer.remove(scale_layer)
                    net.layer[i - 3].top[0] = net.layer[i - 1].top[0]
                    remove_list.extend([net.layer[i - 1], net.layer[i - 2]])
        else:
            i += 1

    for layer in remove_list:
        net.layer.remove(layer)
    # save
    out_net = '_nobn'.join(os.path.splitext(args.net))
    out_weight = '_nobn'.join(os.path.splitext(args.weight))
    with open(out_net, 'w') as fout:
        fout.write(text_format.MessageToString(net))
    with open(out_weight, 'wb') as fout:
        fout.write(weight.SerializeToString())
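The `fuse` helper is not shown above. A minimal sketch of the standard batch-norm folding it presumably performs (note that Caffe's BatchNorm blobs store the running variance even though the caller names it `std`; the per-channel broadcasting is my assumption):

import numpy as np


def fuse(weight, bias, mean, var, scale, shift, eps):
    """Folds BatchNorm (mean/var) and Scale (scale/shift) into conv weight/bias.

    weight has shape (out_channels, ...); mean, var, scale, shift and bias are
    all 1-D arrays of length out_channels.
    """
    factor = scale / np.sqrt(var + eps)  # per-output-channel gain
    new_weight = weight * factor.reshape(-1, *([1] * (weight.ndim - 1)))
    new_bias = (bias - mean) * factor + shift
    return new_weight, new_bias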
Example 4
def do_test(request):
  response = conformance_pb2.ConformanceResponse()

  if request.message_type == "conformance.FailureSet":
    failure_set = conformance_pb2.FailureSet()
    failures = []
    # TODO(gerbens): Remove, this is a hack to detect if the old vs new
    # parser is used by the cpp code. Relying on a bug in the old parser.
    hack_proto = test_messages_proto2_pb2.TestAllTypesProto2()
    old_parser = True
    try:
      hack_proto.ParseFromString(b"\322\002\001")
    except message.DecodeError as e:
      old_parser = False
    if old_parser:
      # the string above is one of the failing conformance test strings of the
      # old parser. If we succeed the c++ implementation is using the old
      # parser so we add the list of failing conformance tests.
      failures = [
          "Required.Proto3.ProtobufInput.PrematureEofInDelimitedDataForKnownNonRepeatedValue.MESSAGE",
          "Required.Proto3.ProtobufInput.PrematureEofInDelimitedDataForKnownRepeatedValue.MESSAGE",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.BOOL",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.DOUBLE",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.ENUM",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.FIXED32",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.FIXED64",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.FLOAT",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.INT32",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.INT64",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.SFIXED32",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.SFIXED64",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.SINT32",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.SINT64",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.UINT32",
          "Required.Proto3.ProtobufInput.PrematureEofInPackedField.UINT64",
          "Required.Proto2.ProtobufInput.PrematureEofInDelimitedDataForKnownNonRepeatedValue.MESSAGE",
          "Required.Proto2.ProtobufInput.PrematureEofInDelimitedDataForKnownRepeatedValue.MESSAGE",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.BOOL",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.DOUBLE",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.ENUM",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.FIXED32",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.FIXED64",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.FLOAT",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.INT32",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.INT64",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.SFIXED32",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.SFIXED64",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.SINT32",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.SINT64",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.UINT32",
          "Required.Proto2.ProtobufInput.PrematureEofInPackedField.UINT64",
      ]
    for x in failures:
      failure_set.failure.append(x)
    response.protobuf_payload = failure_set.SerializeToString()
    return response

  isProto3 = (request.message_type == "protobuf_test_messages.proto3.TestAllTypesProto3")
  isJson = (request.WhichOneof('payload') == 'json_payload')
  isProto2 = (request.message_type == "protobuf_test_messages.proto2.TestAllTypesProto2")

  if (not isProto3) and (not isJson) and (not isProto2):
    raise ProtocolError("Protobuf request doesn't have specific payload type")

  test_message = test_messages_proto2_pb2.TestAllTypesProto2() if isProto2 else \
    test_messages_proto3_pb2.TestAllTypesProto3()

  try:
    if request.WhichOneof('payload') == 'protobuf_payload':
      try:
        test_message.ParseFromString(request.protobuf_payload)
      except message.DecodeError as e:
        response.parse_error = str(e)
        return response

    elif request.WhichOneof('payload') == 'json_payload':
      try:
        ignore_unknown_fields = \
            request.test_category == \
                conformance_pb2.JSON_IGNORE_UNKNOWN_PARSING_TEST
        json_format.Parse(request.json_payload, test_message,
                          ignore_unknown_fields)
      except Exception as e:
        response.parse_error = str(e)
        return response

    elif request.WhichOneof('payload') == 'text_payload':
      try:
        text_format.Parse(request.text_payload, test_message)
      except Exception as e:
        response.parse_error = str(e)
        return response

    else:
      raise ProtocolError("Request didn't have payload.")

    if request.requested_output_format == conformance_pb2.UNSPECIFIED:
      raise ProtocolError("Unspecified output format")

    elif request.requested_output_format == conformance_pb2.PROTOBUF:
      response.protobuf_payload = test_message.SerializeToString()

    elif request.requested_output_format == conformance_pb2.JSON:
      try:
        response.json_payload = json_format.MessageToJson(test_message)
      except Exception as e:
        response.serialize_error = str(e)
        return response

    elif request.requested_output_format == conformance_pb2.TEXT_FORMAT:
      response.text_payload = text_format.MessageToString(test_message)

  except Exception as e:
    response.runtime_error = str(e)

  return response
Example 5
    def test_standalone(self):
        """Sample test for run_local_database.py as a standalone process."""

        topology = vttest_pb2.VTTestTopology()
        keyspace = topology.keyspaces.add(name='test_keyspace')
        keyspace.replica_count = 2
        keyspace.rdonly_count = 1
        keyspace.shards.add(name='-80')
        keyspace.shards.add(name='80-')
        topology.keyspaces.add(name='redirect', served_from='test_keyspace')

        # launch a backend database based on the provided topology and schema
        port = environment.reserve_ports(1)
        args = [
            environment.run_local_database,
            '--port',
            str(port),
            '--proto_topo',
            text_format.MessageToString(topology, as_one_line=True),
            '--schema_dir',
            os.path.join(environment.vtroot, 'test', 'vttest_schema'),
            '--web_dir',
            environment.vtroot + '/web/vtctld',
        ]
        sp = subprocess.Popen(args,
                              stdin=subprocess.PIPE,
                              stdout=subprocess.PIPE)
        config = json.loads(sp.stdout.readline())

        # gather the vars for the vtgate process
        url = 'http://localhost:%d/debug/vars' % config['port']
        f = urllib.urlopen(url)
        data = f.read()
        f.close()
        json_vars = json.loads(data)
        self.assertIn('vtcombo', json_vars['cmdline'][0])

        # build the vtcombo address and protocol
        protocol = protocols_flavor().vttest_protocol()
        if protocol == 'grpc':
            vtgate_addr = 'localhost:%d' % config['grpc_port']
        else:
            vtgate_addr = 'localhost:%d' % config['port']
        conn_timeout = 30.0
        utils.pause('Paused test after vtcombo was started.\n'
                    'For manual testing, connect to vtgate at: %s '
                    'using protocol: %s.\n'
                    'Press enter to continue.' % (vtgate_addr, protocol))

        # Remember the current timestamp after we sleep for a bit, so we
        # can use it for UpdateStream later.
        time.sleep(2)
        before_insert = long(time.time())

        # Connect to vtgate.
        conn = vtgate_client.connect(protocol, vtgate_addr, conn_timeout)

        # Insert a row.
        row_id = 123
        keyspace_id = get_keyspace_id(row_id)
        cursor = conn.cursor(tablet_type='master',
                             keyspace='test_keyspace',
                             keyspace_ids=[pack_kid(keyspace_id)],
                             writable=True)
        cursor.begin()
        insert = ('insert into test_table (id, msg, keyspace_id) values (:id, '
                  ':msg, :keyspace_id)')
        bind_variables = {
            'id': row_id,
            'msg': 'test %s' % row_id,
            'keyspace_id': keyspace_id,
        }
        cursor.execute(insert, bind_variables)
        cursor.commit()

        # Read the row back.
        cursor.execute('select * from test_table where id=:id', {'id': row_id})
        result = cursor.fetchall()
        self.assertEqual(result[0][1], 'test 123')

        # try to insert again, see if we get the right integrity error exception
        # (this is meant to test vtcombo properly returns exceptions, and to a
        # lesser extent that the python client converts it properly)
        cursor.begin()
        with self.assertRaises(dbexceptions.IntegrityError):
            cursor.execute(insert, bind_variables)
        cursor.rollback()

        # Insert a bunch of rows with long msg values.
        bind_variables['msg'] = 'x' * 64
        id_start = 1000
        rowcount = 500
        cursor.begin()
        for i in xrange(id_start, id_start + rowcount):
            bind_variables['id'] = i
            bind_variables['keyspace_id'] = get_keyspace_id(i)
            cursor.execute(insert, bind_variables)
        cursor.commit()
        cursor.close()

        # Try to fetch a large number of rows, from a rdonly
        # (more than one streaming result packet).
        stream_cursor = conn.cursor(
            tablet_type='rdonly',
            keyspace='test_keyspace',
            keyspace_ids=[pack_kid(keyspace_id)],
            cursorclass=vtgate_cursor.StreamVTGateCursor)
        stream_cursor.execute('select * from test_table where id >= :id_start',
                              {'id_start': id_start})
        self.assertEqual(rowcount, len(list(stream_cursor.fetchall())))
        stream_cursor.close()

        # try to read a row using the redirected keyspace, to a replica this time
        row_id = 123
        keyspace_id = get_keyspace_id(row_id)
        cursor = conn.cursor(tablet_type='replica',
                             keyspace='redirect',
                             keyspace_ids=[pack_kid(keyspace_id)])
        cursor.execute('select * from test_table where id=:id', {'id': row_id})
        result = cursor.fetchall()
        self.assertEqual(result[0][1], 'test 123')
        cursor.close()

        # Try to get the update stream from the connection. This makes
        # sure that part works as well.
        count = 0
        for (event, _) in conn.update_stream('test_keyspace',
                                             topodata_pb2.MASTER,
                                             timestamp=before_insert,
                                             shard='-80'):
            for statement in event.statements:
                if statement.table_name == 'test_table':
                    count += 1
            if count == rowcount + 1:
                # We're getting the initial value, plus the 500 updates.
                break

        # Insert a sentinel value into the second shard.
        row_id = 0x8100000000000000
        keyspace_id = get_keyspace_id(row_id)
        cursor = conn.cursor(tablet_type='master',
                             keyspace='test_keyspace',
                             keyspace_ids=[pack_kid(keyspace_id)],
                             writable=True)
        cursor.begin()
        bind_variables = {
            'id': row_id,
            'msg': 'test %s' % row_id,
            'keyspace_id': keyspace_id,
        }
        cursor.execute(insert, bind_variables)
        cursor.commit()
        cursor.close()

        # Try to connect to an update stream on the other shard.
        # We may get some random update stream events, but we should not get any
        # event that's related to the first shard. Only events related to
        # the Insert we just did.
        found = False
        for (event, _) in conn.update_stream('test_keyspace',
                                             topodata_pb2.MASTER,
                                             timestamp=before_insert,
                                             shard='80-'):
            for statement in event.statements:
                self.assertEqual(statement.table_name, 'test_table')
                fields, rows = proto3_encoding.convert_stream_event_statement(
                    statement)
                self.assertEqual(fields[0], 'id')
                self.assertEqual(rows[0][0], row_id)
                found = True
            if found:
                break

        # Clean up the connection
        conn.close()

        # Test we can connect to vtcombo for vtctl actions
        protocol = protocols_flavor().vtctl_python_client_protocol()
        if protocol == 'grpc':
            vtgate_addr = 'localhost:%d' % config['grpc_port']
        else:
            vtgate_addr = 'localhost:%d' % config['port']
        out, _ = utils.run(environment.binary_args('vtctlclient') + [
            '-vtctl_client_protocol',
            protocol,
            '-server',
            vtgate_addr,
            '-stderrthreshold',
            '0',
            'ListAllTablets',
            'test',
        ],
                           trap_output=True)
        num_master = 0
        num_replica = 0
        num_rdonly = 0
        num_dash_80 = 0
        num_80_dash = 0
        for line in out.splitlines():
            parts = line.split()
            self.assertEqual(parts[1], 'test_keyspace',
                             'invalid keyspace in line: %s' % line)
            if parts[3] == 'master':
                num_master += 1
            elif parts[3] == 'replica':
                num_replica += 1
            elif parts[3] == 'rdonly':
                num_rdonly += 1
            else:
                self.fail('invalid tablet type in line: %s' % line)
            if parts[2] == '-80':
                num_dash_80 += 1
            elif parts[2] == '80-':
                num_80_dash += 1
            else:
                self.fail('invalid shard name in line: %s' % line)
        self.assertEqual(num_master, 2)
        self.assertEqual(num_replica, 2)
        self.assertEqual(num_rdonly, 2)
        self.assertEqual(num_dash_80, 3)
        self.assertEqual(num_80_dash, 3)

        # and we're done, clean-up process
        sp.stdin.write('\n')
        sp.wait()
Example 6
            def build_example(label, param_dict_real, zip_path_label):
                """Build the model with parameter values set in param_dict_real.

        Args:
          label: Label of the model
          param_dict_real: Parameter dictionary (arguments to the factories
            make_graph and make_test_inputs)
          zip_path_label: Filename in the zip

        Returns:
          (tflite_model_binary, report) where tflite_model_binary is the
          serialized flatbuffer as a string and report is a dictionary with
          keys `tflite_converter_log` (log of conversion), `tf_log` (log of tf
          conversion), `converter` (a string of success status of the
          conversion), `tf` (a string success status of the conversion).
        """

                np.random.seed(RANDOM_SEED)
                report = {
                    "tflite_converter": report_lib.NOTRUN,
                    "tf": report_lib.FAILED
                }

                # Build graph
                report["tf_log"] = ""
                report["tflite_converter_log"] = ""
                tf.reset_default_graph()

                with tf.Graph().as_default():
                    with tf.device("/cpu:0"):
                        try:
                            inputs, outputs = make_graph(param_dict_real)
                            inputs = [x for x in inputs if x is not None]
                        except (tf.errors.UnimplementedError,
                                tf.errors.InvalidArgumentError, ValueError):
                            report["tf_log"] += traceback.format_exc()
                            return None, report

                    sess = tf.Session()
                    try:
                        baseline_inputs, baseline_outputs = (make_test_inputs(
                            param_dict_real, sess, inputs, outputs))
                        baseline_inputs = [
                            x for x in baseline_inputs if x is not None
                        ]
                        # Converts baseline inputs/outputs to maps. The signature input and
                        # output names are set to be the same as the tensor names.
                        input_names = [
                            _normalize_input_name(x.name) for x in inputs
                        ]
                        output_names = [
                            _normalize_output_name(x.name) for x in outputs
                        ]
                        baseline_input_map = dict(
                            zip(input_names, baseline_inputs))
                        baseline_output_map = dict(
                            zip(output_names, baseline_outputs))
                    except (tf.errors.UnimplementedError,
                            tf.errors.InvalidArgumentError, ValueError):
                        report["tf_log"] += traceback.format_exc()
                        return None, report
                    report["tflite_converter"] = report_lib.FAILED
                    report["tf"] = report_lib.SUCCESS

                    # Builds a saved model with the default signature key.
                    input_names, tensor_info_inputs = _get_tensor_info(
                        inputs, "input_", _normalize_input_name)
                    output_tensors, tensor_info_outputs = _get_tensor_info(
                        outputs, "output_", _normalize_output_name)
                    input_tensors = [(name, t.shape, t.dtype)
                                     for name, t in zip(input_names, inputs)]

                    inference_signature = (
                        tf.saved_model.signature_def_utils.build_signature_def(
                            inputs=tensor_info_inputs,
                            outputs=tensor_info_outputs,
                            method_name="op_test"))
                    saved_model_dir = tempfile.mkdtemp("op_test")
                    saved_model_tags = [tf.saved_model.tag_constants.SERVING]
                    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
                    builder = tf.saved_model.builder.SavedModelBuilder(
                        saved_model_dir)
                    builder.add_meta_graph_and_variables(
                        sess,
                        saved_model_tags,
                        signature_def_map={
                            signature_key: inference_signature,
                        },
                        strip_default_attrs=True)
                    builder.save(as_text=False)
                    # pylint: disable=g-long-ternary
                    graph_def = freeze_graph(
                        sess,
                        tf.global_variables() + inputs +
                        outputs) if use_frozen_graph else sess.graph_def

                if "split_tflite_lstm_inputs" in param_dict_real:
                    extra_convert_options.split_tflite_lstm_inputs = param_dict_real[
                        "split_tflite_lstm_inputs"]
                tflite_model_binary, converter_log = options.tflite_convert_function(
                    options,
                    saved_model_dir,
                    input_tensors,
                    output_tensors,
                    extra_convert_options=extra_convert_options,
                    test_params=param_dict_real)
                report["tflite_converter"] = (report_lib.SUCCESS if
                                              tflite_model_binary is not None
                                              else report_lib.FAILED)
                report["tflite_converter_log"] = converter_log

                if options.save_graphdefs:
                    zipinfo = zipfile.ZipInfo(zip_path_label + ".pbtxt")
                    archive.writestr(zipinfo,
                                     text_format.MessageToString(graph_def),
                                     zipfile.ZIP_DEFLATED)

                if tflite_model_binary:
                    if options.make_edgetpu_tests:
                        # Set proper min max values according to input dtype.
                        baseline_input_map, baseline_output_map = generate_inputs_outputs(
                            tflite_model_binary, min_value=0, max_value=255)
                    zipinfo = zipfile.ZipInfo(zip_path_label + ".bin")
                    archive.writestr(zipinfo, tflite_model_binary,
                                     zipfile.ZIP_DEFLATED)

                    example = {
                        "inputs": baseline_input_map,
                        "outputs": baseline_output_map
                    }

                    example_fp = StringIO()
                    write_examples(example_fp, [example])
                    zipinfo = zipfile.ZipInfo(zip_path_label + ".inputs")
                    archive.writestr(zipinfo, example_fp.getvalue(),
                                     zipfile.ZIP_DEFLATED)

                    example_fp2 = StringIO()
                    write_test_cases(example_fp2, zip_path_label + ".bin",
                                     [example])
                    zipinfo = zipfile.ZipInfo(zip_path_label + "_tests.txt")
                    archive.writestr(zipinfo, example_fp2.getvalue(),
                                     zipfile.ZIP_DEFLATED)

                    zip_manifest_label = zip_path_label + " " + label
                    if zip_path_label == label:
                        zip_manifest_label = zip_path_label

                    zip_manifest.append(zip_manifest_label + "\n")

                return tflite_model_binary, report
Example 7
def _test_user_op_graph(test_case, is_cuda):
    test_case.assertTrue(oneflow.framework.env_util.HasAllMultiClientEnvVars())

    x0 = flow.tensor(np.random.rand(20, 30), dtype=flow.float32)
    weight0 = flow.tensor(np.random.rand(30, 50), dtype=flow.float32)
    x1 = flow.tensor(np.random.rand(50, 70), dtype=flow.float32)

    if is_cuda:
        x0 = x0.to(device=flow.device("cuda"))
        weight0 = weight0.to(device=flow.device("cuda"))
        x1 = x1.to(device=flow.device("cuda"))

    # NOTE(chengcheng): this tiny net is:
    #    x0 * weight0 -> out0
    #    relu(out0) -> y0
    #    y0 * x1 -> out1
    #    relu(out1) -> y1

    session = session_ctx.GetDefaultSession()
    test_case.assertTrue(isinstance(session, MultiClientSession))
    session.TryInit()

    with oneflow._oneflow_internal.lazy_mode.guard(True):

        oneflow._oneflow_internal.JobBuildAndInferCtx_Open(
            "cc_test_user_op_expr_job_with_cuda" + str(is_cuda))
        job_conf = oneflow.core.job.job_conf_pb2.JobConfigProto()
        job_conf.job_name = "cc_test_user_op_expr_job_with_cuda" + str(is_cuda)
        job_conf.predict_conf.SetInParent()
        c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)

        x0_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf()
        x0_conf.in_0 = "in_0"
        x0_conf.out_0 = "out_0"
        x0_conf_str = text_format.MessageToString(x0_conf)
        x0_op = oneflow._oneflow_internal.one.FeedInputOpExpr(
            "cc_Input_0", x0_conf_str, ["in_0"], ["out_0"])

        x1_conf = oneflow.core.operator.op_conf_pb2.FeedInputOpConf()
        x1_conf.in_0 = "in_0"
        x1_conf.out_0 = "out_0"
        x1_conf_str = text_format.MessageToString(x1_conf)
        x1_op = oneflow._oneflow_internal.one.FeedInputOpExpr(
            "cc_Input_1", x1_conf_str, ["in_0"], ["out_0"])

        weight0_conf = oneflow.core.operator.op_conf_pb2.FeedVariableOpConf()
        weight0_conf.in_0 = "in_0"
        weight0_conf.out_0 = "out_0"
        weight0_conf_str = text_format.MessageToString(weight0_conf)
        weight0_op = oneflow._oneflow_internal.one.FeedVariableOpExpr(
            "cc_Variable_0", weight0_conf_str, ["in_0"], ["out_0"])
        output_conf = oneflow.core.operator.op_conf_pb2.FetchOutputOpConf()
        output_conf.in_0 = "in_0"
        output_conf.out_0 = "out_0"
        output_conf_str = text_format.MessageToString(output_conf)
        output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr(
            "cc_Output_0", output_conf_str, ["in_0"], ["out_0"])

        x0_lazy_tensor = _C.dispatch_feed_input(x0_op, x0)
        x1_lazy_tensor = _C.dispatch_feed_input(x1_op, x1)
        weight0_lazy_tensor = _C.dispatch_feed_input(weight0_op, weight0)

        test_case.assertEqual(x0_lazy_tensor.shape, (20, 30))
        test_case.assertTrue(x0_lazy_tensor.is_lazy)

        test_case.assertEqual(weight0_lazy_tensor.shape, (30, 50))
        test_case.assertTrue(weight0_lazy_tensor.is_lazy)
        test_case.assertEqual(x1_lazy_tensor.shape, (50, 70))
        test_case.assertTrue(x1_lazy_tensor.is_lazy)

        out0 = flow._C.matmul(x0_lazy_tensor, weight0_lazy_tensor)
        test_case.assertEqual(out0.shape, (20, 50))
        test_case.assertTrue(out0.is_lazy)

        y0 = flow._C.relu(out0)
        test_case.assertEqual(y0.shape, (20, 50))
        test_case.assertTrue(y0.is_lazy)

        out1 = flow._C.matmul(y0, x1_lazy_tensor)
        test_case.assertEqual(out1.shape, (20, 70))
        test_case.assertTrue(out1.is_lazy)

        y1 = flow._C.relu(out1)
        test_case.assertEqual(y1.shape, (20, 70))
        test_case.assertTrue(y1.is_lazy)

        eager_output = _C.dispatch_fetch_output(output_op, y1)
        test_case.assertEqual(eager_output.shape, (20, 70))
        test_case.assertTrue(not eager_output.is_lazy)

        if is_cuda:
            test_case.assertTrue(x0_lazy_tensor.is_cuda)
            test_case.assertTrue(x1_lazy_tensor.is_cuda)
            test_case.assertTrue(weight0_lazy_tensor.is_cuda)
            test_case.assertTrue(out0.is_cuda)
            test_case.assertTrue(y0.is_cuda)
            test_case.assertTrue(out1.is_cuda)
            test_case.assertTrue(y1.is_cuda)

        oneflow._oneflow_internal.JobBuildAndInferCtx_Close()
Example 8
    start = args.start
    end = args.end
    num_attr = end - start

    net_file = args.prototxt
    n = caffe_pb2.NetParameter()
    text_format.Merge(open(net_file).read(), n)
    n.layer[-6].inner_product_param.num_output = num_attr

    new_dir = os.path.dirname(net_file) + "/train_net_dir"
    new_file = new_dir + "/test_net_{}_{}.prototxt".format(start, end)

    with open(new_file, 'w+') as new_prototxt_file:
        new_prototxt_file.write(unicode(text_format.MessageToString(n)))

    args.prototxt = new_file
    # print args.prototxt
    # print num_attr
    # time.sleep(10)
    net = caffe.Net(new_file, args.caffemodel, caffe.TEST)
    net.name = os.path.splitext(os.path.basename(args.caffemodel))[0]

    if args.db == 'RAP':
        """Load RAP database"""
        from utils.rap_db import RAP
        db = RAP(os.path.join('data', 'dataset', args.db), args.par_set_id)
    else:
        """Load PETA dayanse"""
        from utils.peta_db import PETA
Example 9
def CheckAndCompleteUserOpConf(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    new_op_conf = oneflow_api.CheckAndCompleteUserOpConf(serialized_op_conf)
    return text_format.Parse(new_op_conf, op_conf_util.OperatorConf())
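Several of the OneFlow one-liners in this collection (Examples 9, 16, 18, 20, 22, 24 and 26) share the same shape: serialize a proto to its text form, hand the string to a C++ entry point, and optionally parse a text reply back into a new message. A hedged sketch of that round-trip pattern (`call_with_text_proto` is a hypothetical helper, not part of the OneFlow API):

from google.protobuf import text_format


def call_with_text_proto(call, request_proto, response_proto=None):
    """Serializes request_proto to text, invokes `call`, and parses the reply.

    `call` is any function taking a text-format string (e.g. an oneflow_api
    entry point). If response_proto is None, the raw return value is passed
    through unchanged.
    """
    reply = call(text_format.MessageToString(request_proto))
    if response_proto is None:
        return reply
    return text_format.Parse(reply, response_proto)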
Example 10
def save_experiment(filename, expt):
    file_write_safe(filename, text_format.MessageToString(expt))
Example 11
        def RemoteCall(*args, **kwargs):
            """Dynamically calls a remote API and returns the result value."""
            func_msg = self.GetApi(api_name)
            if not func_msg:
                raise MirrorObjectError("api %s unknown" % api_name)

            logging.debug("remote call %s.%s", self._parent_path, api_name)
            logging.info("remote call %s%s", api_name, args)
            if args:
                for arg_msg, value_msg in zip(func_msg.arg, args):
                    logging.debug("arg msg %s", arg_msg)
                    logging.debug("value %s", value_msg)
                    if value_msg is not None:
                        self.ArgToPb(arg_msg, value_msg)

                logging.debug("final msg %s", func_msg)
            else:
                # TODO: use kwargs
                for arg in func_msg.arg:
                    # TODO: handle other
                    if (arg.type == CompSpecMsg.TYPE_SCALAR
                            and arg.scalar_type == "pointer"):
                        arg.scalar_value.pointer = 0
                logging.debug(func_msg)

            if self._parent_path:
                func_msg.parent_path = self._parent_path

            if self._interface_id is not None:
                func_msg.hidl_interface_id = self._interface_id

            if isinstance(self._if_spec_msg,
                          CompSpecMsg.ComponentSpecificationMessage):
                if self._if_spec_msg.component_class:
                    logging.info("component_class %s",
                                 self._if_spec_msg.component_class)
                    if self._if_spec_msg.component_class == CompSpecMsg.HAL_CONVENTIONAL_SUBMODULE:
                        submodule_name = self._if_spec_msg.original_data_structure_name
                        if submodule_name.endswith("*"):
                            submodule_name = submodule_name[:-1]
                        func_msg.submodule_name = submodule_name
            result = self._client.CallApi(
                text_format.MessageToString(func_msg), self.__caller_uid)
            logging.debug(result)
            if (isinstance(result, tuple) and len(result) == 2
                    and isinstance(result[1], dict)
                    and "coverage" in result[1]):
                self._last_raw_code_coverage_data = result[1]["coverage"]
                result = result[0]

            if (result and isinstance(result,
                                      CompSpecMsg.VariableSpecificationMessage)
                    and result.type == CompSpecMsg.TYPE_HIDL_INTERFACE):
                if result.hidl_interface_id <= -1:
                    return None
                nested_interface_id = result.hidl_interface_id
                nested_interface_name = result.predefined_type.split("::")[-1]
                logging.debug("Nested interface name: %s",
                              nested_interface_name)
                nested_interface = self.GetHidlNestedInterface(
                    nested_interface_name, nested_interface_id)
                return nested_interface
            return result
Example 12
def to_pbtxt_file(output_path, spec):
    """Saves a spec encoded as a struct_pb2.StructuredValue in a pbtxt file."""
    spec_proto = to_proto(spec)
    with tf.io.gfile.GFile(output_path, "wb") as f:
        f.write(text_format.MessageToString(spec_proto))
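The reverse direction is a symmetric read followed by `text_format.Parse`. A minimal sketch (the `from_pbtxt_file` name and the empty-message argument are illustrative assumptions):

import tensorflow as tf
from google.protobuf import text_format


def from_pbtxt_file(input_path, message):
    """Reads a pbtxt file back into `message` and returns it.

    `message` is an empty protobuf of the expected type, e.g.
    struct_pb2.StructuredValue() for files written by to_pbtxt_file above.
    """
    with tf.io.gfile.GFile(input_path, "r") as f:
        return text_format.Parse(f.read(), message)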
Example 13
def commit_portal_manifest(etcd, portal_manifest):
    etcd_key = os.path.join(portal_manifest.name, 'manifest')
    etcd.set_data(etcd_key, text_format.MessageToString(portal_manifest))
Example 14
def commit_data_source(etcd, data_source):
    etcd_key = os.path.join(data_source.data_source_meta.name, 'master')
    etcd.set_data(etcd_key, text_format.MessageToString(data_source))
Example 15
def populate_experiment(run_config,
                        hparams,
                        pipeline_config_path,
                        train_steps=None,
                        eval_steps=None,
                        model_fn_creator=create_model_fn,
                        **kwargs):
  """Populates an `Experiment` object.

  Args:
    run_config: A `RunConfig`.
    hparams: A `HParams`.
    pipeline_config_path: A path to a pipeline config file.
    train_steps: Number of training steps. If None, the number of training steps
      is set from the `TrainConfig` proto.
    eval_steps: Number of evaluation steps per evaluation cycle. If None, the
      number of evaluation steps is set from the `EvalConfig` proto.
    model_fn_creator: A function that creates a `model_fn` for `Estimator`.
      Follows the signature:

      * Args:
        * `detection_model_fn`: Function that returns `DetectionModel` instance.
        * `configs`: Dictionary of pipeline config objects.
        * `hparams`: `HParams` object.
      * Returns:
        `model_fn` for `Estimator`.

    **kwargs: Additional keyword arguments for configuration override.

  Returns:
    An `Experiment` that defines all aspects of training, evaluation, and
    export.
  """
  configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
  configs = config_util.merge_external_params_with_configs(
      configs,
      hparams,
      train_steps=train_steps,
      eval_steps=eval_steps,
      **kwargs)
  model_config = configs['model']
  train_config = configs['train_config']
  train_input_config = configs['train_input_config']
  eval_config = configs['eval_config']
  eval_input_config = configs['eval_input_config']

  if train_steps is None:
    train_steps = train_config.num_steps if train_config.num_steps else None

  if eval_steps is None:
    eval_steps = eval_config.num_examples if eval_config.num_examples else None

  detection_model_fn = functools.partial(
      model_builder.build, model_config=model_config)

  # Create the input functions for TRAIN/EVAL.
  train_input_fn = inputs.create_train_input_fn(
      train_config=train_config,
      train_input_config=train_input_config,
      model_config=model_config)
  eval_input_fn = inputs.create_eval_input_fn(
      eval_config=eval_config,
      eval_input_config=eval_input_config,
      model_config=model_config)

  export_strategies = [
      tf.contrib.learn.utils.saved_model_export_utils.make_export_strategy(
          serving_input_fn=inputs.create_predict_input_fn(
              model_config=model_config))
  ]

  estimator = tf.estimator.Estimator(
      model_fn=model_fn_creator(detection_model_fn, configs, hparams),
      config=run_config)

  if run_config.is_chief:
    # Store the final pipeline config for traceability.
    pipeline_config_final = config_util.create_pipeline_proto_from_configs(
        configs)
    pipeline_config_final_path = os.path.join(estimator.model_dir,
                                              'pipeline.config')
    config_text = text_format.MessageToString(pipeline_config_final)
    with tf.gfile.Open(pipeline_config_final_path, 'wb') as f:
      tf.logging.info('Writing as-run pipeline config file to %s',
                      pipeline_config_final_path)
      f.write(config_text)

  return tf.contrib.learn.Experiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      train_steps=train_steps,
      eval_steps=eval_steps,
      export_strategies=export_strategies,
      eval_delay_secs=120,)
Example 16
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    add_and_infer = oneflow_api.CurJobBuildAndInferCtx_AddAndInferMirroredOp
    op_attribute_str = add_and_infer(serialized_op_conf)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
Example 17
        return mapped

    data_in = open(args.i, 'r') if args.i else sys.stdin
    df = pd.read_csv(data_in, sep='\t', names=COLUMNS)
    for row in (df.loc[i] for i in df.index):
        datapoint = row_to_dict(row, tokenize)
        try:
            serialized = ''
            if args.o == 'json':
                serialized = json.dumps(datapoint) + '\n'
            elif args.o == 'proto' or args.o == 'proto_text':
                proto = datapoint_to_proto_as_embeddings(datapoint) \
                        if as_embeddings else \
                           datapoint_to_proto_as_words(datapoint)
                serialized = length_prefix_proto(proto) if args.o == 'proto' \
                             else text_format.MessageToString(proto)
            elif args.o == 'token_list':
                serialized = datapoint_to_tokens(datapoint) + ' '
            else:
                print('Unknown output format "%s"' % args.o, file=sys.stderr)
                sys.exit(2)
            assert serialized

            # Write to stdout as text. Length-prefixed protos come back as
            # bytes, so decode them first (encoding errors are ignored).
            if isinstance(serialized, bytes):
                serialized = str(serialized, 'utf-8', 'ignore')
            sys.stdout.write(serialized)
        except IOError as e:
            if e.errno == errno.EPIPE:
                sys.exit(0)
            raise e
Example 18
def CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair(lbi_and_uuid):
    serialized = str(text_format.MessageToString(lbi_and_uuid))
    oneflow_api.CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair(serialized)
Example 19
#!/usr/bin/env python3

import socket
import sys

from google.protobuf import text_format

sys.path.append("../src/plugins")

from events_pb2 import *

HOST, PORT = "localhost", 7777

# Create a socket (SOCK_STREAM means a TCP socket)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    # Connect to server and send data
    sock.connect((HOST, PORT))

    sock.setblocking(True)

    while True:
        # Assumes each recv() delivers exactly one serialized EventMessage;
        # the stream itself carries no length framing.
        ser = sock.recv(4096)
        if not ser:
            break  # server closed the connection

        ev = EventMessage()
        ev.ParseFromString(ser)
        print(text_format.MessageToString(ev))
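The loop above relies on every recv() call delivering exactly one complete serialized EventMessage, which a TCP stream does not guarantee. A hedged sketch of length-prefixed framing (the 4-byte big-endian prefix is an assumption about the protocol, not something the plugin necessarily implements):

import struct


def recv_exact(sock, n):
    """Reads exactly n bytes from a blocking socket, or returns None on EOF."""
    buf = b""
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            return None
        buf += chunk
    return buf


def recv_length_prefixed_message(sock, message):
    """Reads a 4-byte big-endian length prefix, then that many payload bytes."""
    header = recv_exact(sock, 4)
    if header is None:
        return None
    (length,) = struct.unpack(">I", header)
    payload = recv_exact(sock, length)
    if payload is None:
        return None
    message.ParseFromString(payload)
    return message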

Example 20
def InitEnv(env_proto):
    assert type(env_proto) is env_pb2.EnvProto
    env_proto_str = text_format.MessageToString(env_proto)
    oneflow_api.InitEnv(env_proto_str)
Example 21
def main(_):
    if not FLAGS.save_path:
        print "You must specify --save_path."

    with tf.name_scope("inference/rnn"):
        input_data = tf.placeholder(tf.int32, [1, 1], name="inputs")
        input_data_f = tf.to_float(input_data)
        initial_state_c_0 = tf.zeros([1, 3], name="zeros")
        initial_state_h_0 = tf.zeros([1, 3], name="zeros_1")
        initial_state_c_1 = tf.zeros([1, 3], name="zeros_2")
        initial_state_h_1 = tf.zeros([1, 3], name="zeros_3")
        w = tf.constant([[0.2, 0.5, 0.3]])
        final_state_c_0 = tf.add(
            initial_state_c_0 + 0.1,
            input_data_f,
            name="RNN/MultiRNNCell/Cell0/BasicLSTMCell/add_2")
        final_state_h_0 = tf.add(
            initial_state_h_0 + 0.2,
            input_data_f,
            name="RNN/MultiRNNCell/Cell0/BasicLSTMCell/mul_2")
        final_state_c_1 = tf.add(
            initial_state_c_1 + 0.3,
            input_data_f,
            name="RNN/MultiRNNCell/Cell1/BasicLSTMCell/add_2")
        final_state_h_1 = tf.add(
            initial_state_h_1 + 0.4,
            input_data_f,
            name="RNN/MultiRNNCell/Cell1/BasicLSTMCell/mul_2")
        y = final_state_c_0 + final_state_h_0 + final_state_c_1 + final_state_h_1 - (
            3 * input_data_f)
        predictions = tf.mul(tf.matmul(input_data_f, w), y, name="predictions")

    with tf.Session() as sess:
        tf.initialize_all_variables().run()

        # Save model for use in C++.
        # --------------------------

        # Save meta information about the RNN.
        rnn_proto = rnn_pb2.RNNProto()
        rnn_proto.type = 2
        rnn_proto.input_tensor_name = input_data.name
        rnn_proto.logits_tensor_name = predictions.name
        rnn_proto.predictions_tensor_name = predictions.name

        rnn_proto.h.add(
            initial=initial_state_h_0.name,
            final=final_state_h_0.name,
        )
        rnn_proto.h.add(
            initial=initial_state_h_1.name,
            final=final_state_h_1.name,
        )
        rnn_proto.c.add(
            initial=initial_state_c_0.name,
            final=final_state_c_0.name,
        )
        rnn_proto.c.add(
            initial=initial_state_c_1.name,
            final=final_state_c_1.name,
        )
        with open(os.path.join(FLAGS.save_path, "rnn.pbtxt"), "wb") as f:
            f.write(text_format.MessageToString(rnn_proto))

        # Save the vocabulary.
        vocab = vocab_pb2.VocabProto()
        vocab.min_frequency = 1
        word_to_id = {
            "<unk>": 0,
            "<s>": 1,
            "the": 2,
        }
        for w in word_to_id:
            item = vocab.item.add()
            item.id = word_to_id[w]
            item.word = w
        with open(os.path.join(FLAGS.save_path, "vocab.pbtxt"), "wb") as f:
            f.write(text_format.MessageToString(vocab))

        # Note: graph_util.convert_variables_to_constants() appends ':0' onto the variable names, which
        # is why it isn't included in 'inference/lstm/predictions'.
        graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=sess.graph.as_graph_def(),
            output_node_names=[predictions.name.split(':', 1)[0]])

        tf.train.write_graph(graph_def,
                             FLAGS.save_path,
                             "graph.pb",
                             as_text=False)
        tf.train.write_graph(graph_def, FLAGS.save_path, "graph.pbtxt")
Example 22
def InitLazyGlobalSession(config_proto):
    assert type(config_proto) is job_set_pb.ConfigProto
    config_proto_str = text_format.MessageToString(config_proto)
    oneflow_api.InitLazyGlobalSession(config_proto_str)
Example 23
if __name__ == '__main__':
    caffe.set_mode_gpu()
    p = make_parser()
    args = p.parse_args()

    # build and save testable net
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)
    print "Building BN calc net..."
    testable_msg = make_testable(args.train_model)
    BN_calc_path = os.path.join(
        args.out_dir,
        '__for_calculating_BN_stats_' + os.path.basename(args.train_model))
    with open(BN_calc_path, 'w') as f:
        f.write(text_format.MessageToString(testable_msg))

    # use testable net to calculate BN layer stats
    print "Calculate BN stats..."
    train_ims, train_labs = extract_dataset(testable_msg)
    train_size = len(train_ims)
    minibatch_size = testable_msg.layer[0].dense_image_data_param.batch_size
    num_iterations = train_size // minibatch_size + train_size % minibatch_size
    in_h, in_w = (360, 480)
    test_net, test_msg = make_test_files(BN_calc_path, args.weights,
                                         num_iterations, in_h, in_w)

    # save deploy prototxt
    #print "Saving deployment prototext file..."
    #test_path = os.path.join(args.out_dir, "deploy.prototxt")
    #with open(test_path, 'w') as f:
Example 24
def CurJobBuildAndInferCtx_SetTrainConf(train_config_proto):
    serialized_train_conf = str(text_format.MessageToString(train_config_proto))
    oneflow_api.CurJobBuildAndInferCtx_SetTrainConf(serialized_train_conf)
Example 25
def update_checkpoint_state_internal(save_dir,
                                     model_checkpoint_path,
                                     all_model_checkpoint_paths=None,
                                     latest_filename=None,
                                     save_relative_paths=False,
                                     all_model_checkpoint_timestamps=None,
                                     last_preserved_timestamp=None):
    """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.
    save_relative_paths: If `True`, will write relative paths to the checkpoint
      state file.
    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
      seconds since the Epoch) indicating when the checkpoints in
      `all_model_checkpoint_paths` were created.
    last_preserved_timestamp: A float, indicating the number of seconds since
      the Epoch when the last preserved checkpoint was written, e.g. due to a
      `keep_checkpoint_every_n_hours` parameter (see
      `tf.contrib.checkpoint.CheckpointManager` for an implementation).

  Raises:
    RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointSate.
  """
    # Writes the "checkpoint" file for the coordinator for later restoration.
    coord_checkpoint_filename = _GetCheckpointFilename(save_dir,
                                                       latest_filename)
    if save_relative_paths:
        if os.path.isabs(model_checkpoint_path):
            rel_model_checkpoint_path = os.path.relpath(
                model_checkpoint_path, save_dir)
        else:
            rel_model_checkpoint_path = model_checkpoint_path
        rel_all_model_checkpoint_paths = []
        for p in all_model_checkpoint_paths:
            if os.path.isabs(p):
                rel_all_model_checkpoint_paths.append(
                    os.path.relpath(p, save_dir))
            else:
                rel_all_model_checkpoint_paths.append(p)
        ckpt = generate_checkpoint_state_proto(
            save_dir,
            rel_model_checkpoint_path,
            all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
            all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
            last_preserved_timestamp=last_preserved_timestamp)
    else:
        ckpt = generate_checkpoint_state_proto(
            save_dir,
            model_checkpoint_path,
            all_model_checkpoint_paths=all_model_checkpoint_paths,
            all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
            last_preserved_timestamp=last_preserved_timestamp)

    if coord_checkpoint_filename == ckpt.model_checkpoint_path:
        raise RuntimeError(
            "Save path '%s' conflicts with path used for "
            "checkpoint state.  Please use a different save path." %
            model_checkpoint_path)

    # Preventing potential read/write race condition by *atomically* writing to a
    # file.
    file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                        text_format.MessageToString(ckpt))
Example 26
def GetOpParallelSymbolId(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    return oneflow_api.GetOpParallelSymbolId(serialized_op_conf)
Example 27
    def write_to_file(self):
        with open(APOLLO_ROOT + self.conf_file, 'w') as prototxt:
            prototxt.write(text_format.MessageToString(self.proto_root))
Example 28
      def build_example(label, param_dict_real, zip_path_label):
        """Build the model with parameter values set in param_dict_real.

        Args:
          label: Label of the model
          param_dict_real: Parameter dictionary (arguments to the factories
            make_graph and make_test_inputs)
          zip_path_label: Filename in the zip

        Returns:
          (tflite_model_binary, report) where tflite_model_binary is the
          serialized flatbuffer as a string and report is a dictionary with
          keys `converter_log` (log of the TFLite conversion), `tf_log` (log of
          the TF conversion), `converter` (a string with the success status of
          the conversion), and `tf` (a string with the success status of the TF
          conversion).
        """

        np.random.seed(RANDOM_SEED)
        report = {"converter": report_lib.NOTRUN, "tf": report_lib.FAILED}

        # Build graph
        report["tf_log"] = ""
        report["converter_log"] = ""
        tf.reset_default_graph()

        with tf.Graph().as_default():
          with tf.device("/cpu:0"):
            try:
              inputs, outputs = make_graph(param_dict_real)
            except (tf.errors.UnimplementedError,
                    tf.errors.InvalidArgumentError, ValueError):
              report["tf_log"] += traceback.format_exc()
              return None, report

          sess = tf.Session()
          try:
            baseline_inputs, baseline_outputs = (
                make_test_inputs(param_dict_real, sess, inputs, outputs))
          except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
                  ValueError):
            report["tf_log"] += traceback.format_exc()
            return None, report
          report["converter"] = report_lib.FAILED
          report["tf"] = report_lib.SUCCESS
          # Convert the graph with the TFLite converter (toco).
          input_tensors = [(input_tensor.name.split(":")[0], input_tensor.shape,
                            input_tensor.dtype) for input_tensor in inputs]
          output_tensors = [_normalize_output_name(out.name) for out in outputs]
          # pylint: disable=g-long-ternary
          graph_def = freeze_graph(
              sess,
              tf.global_variables() + inputs +
              outputs) if use_frozen_graph else sess.graph_def

        if "split_tflite_lstm_inputs" in param_dict_real:
          extra_toco_options.split_tflite_lstm_inputs = param_dict_real[
              "split_tflite_lstm_inputs"]
        tflite_model_binary, toco_log = options.tflite_convert_function(
            options,
            graph_def,
            input_tensors,
            output_tensors,
            extra_toco_options=extra_toco_options,
            test_params=param_dict_real)
        report["converter"] = (
            report_lib.SUCCESS
            if tflite_model_binary is not None else report_lib.FAILED)
        report["converter_log"] = toco_log

        if options.save_graphdefs:
          zipinfo = zipfile.ZipInfo(zip_path_label + ".pbtxt")
          archive.writestr(zipinfo, text_format.MessageToString(graph_def),
                           zipfile.ZIP_DEFLATED)

        if tflite_model_binary:
          if options.make_edgetpu_tests:
            # Set proper min max values according to input dtype.
            baseline_inputs, baseline_outputs = generate_inputs_outputs(
                tflite_model_binary, min_value=0, max_value=255)
          zipinfo = zipfile.ZipInfo(zip_path_label + ".bin")
          archive.writestr(zipinfo, tflite_model_binary, zipfile.ZIP_DEFLATED)
          example = {"inputs": baseline_inputs, "outputs": baseline_outputs}

          example_fp = StringIO()
          write_examples(example_fp, [example])
          zipinfo = zipfile.ZipInfo(zip_path_label + ".inputs")
          archive.writestr(zipinfo, example_fp.getvalue(), zipfile.ZIP_DEFLATED)

          example_fp2 = StringIO()
          write_test_cases(example_fp2, zip_path_label + ".bin", [example])
          zipinfo = zipfile.ZipInfo(zip_path_label + "_tests.txt")
          archive.writestr(zipinfo, example_fp2.getvalue(),
                           zipfile.ZIP_DEFLATED)

          zip_manifest_label = zip_path_label + " " + label
          if zip_path_label == label:
            zip_manifest_label = zip_path_label

          zip_manifest.append(zip_manifest_label + "\n")

        return tflite_model_binary, report
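A standalone sketch of the ".pbtxt into the zip" step in the options.save_graphdefs branch above; dump_pbtxt_to_zip and its parameters are hypothetical names standing in for build_example's closure variables (archive, zip_path_label, graph_def):

import zipfile
from google.protobuf import text_format

def dump_pbtxt_to_zip(zip_path, entry_name, proto_msg):
    # Serialize the proto as text and store it as <entry_name>.pbtxt.
    with zipfile.ZipFile(zip_path, 'a') as archive:
        info = zipfile.ZipInfo(entry_name + '.pbtxt')
        archive.writestr(info, text_format.MessageToString(proto_msg),
                         zipfile.ZIP_DEFLATED)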
Example n. 29
0
    def test_api(self):
        logging.getLogger().setLevel(logging.DEBUG)
        kvstore_type = 'etcd'
        db_base_dir = 'dp_test'
        os.environ['ETCD_BASE_DIR'] = db_base_dir
        data_portal_name = 'test_data_source'
        kvstore = DBClient(kvstore_type, True)
        kvstore.delete_prefix(db_base_dir)
        portal_input_base_dir = './portal_upload_dir'
        portal_output_base_dir = './portal_output_dir'
        raw_data_publish_dir = 'raw_data_publish_dir'
        portal_manifest = dp_pb.DataPortalManifest(
                name=data_portal_name,
                data_portal_type=dp_pb.DataPortalType.Streaming,
                output_partition_num=4,
                input_file_wildcard="*.done",
                input_base_dir=portal_input_base_dir,
                output_base_dir=portal_output_base_dir,
                raw_data_publish_dir=raw_data_publish_dir,
                processing_job_id=-1,
                next_job_id=0
            )
        kvstore.set_data(common.portal_kvstore_base_dir(data_portal_name),
                      text_format.MessageToString(portal_manifest))
        if gfile.Exists(portal_input_base_dir):
            gfile.DeleteRecursively(portal_input_base_dir)
        gfile.MakeDirs(portal_input_base_dir)
        all_fnames = ['1001/{}.done'.format(i) for i in range(100)]
        all_fnames.append('{}.xx'.format(100))
        all_fnames.append('1001/_SUCCESS')
        for fname in all_fnames:
            fpath = os.path.join(portal_input_base_dir, fname)
            gfile.MakeDirs(os.path.dirname(fpath))
            with gfile.Open(fpath, "w") as f:
                f.write('xxx')
        portal_master_addr = 'localhost:4061'
        portal_options = dp_pb.DataPotraMasterlOptions(
                use_mock_etcd=True,
                long_running=False,
                check_success_tag=True,
            )
        data_portal_master = DataPortalMasterService(
                int(portal_master_addr.split(':')[1]),
                data_portal_name, kvstore_type,
                portal_options
            )
        data_portal_master.start()

        channel = make_insecure_channel(portal_master_addr, ChannelType.INTERNAL)
        portal_master_cli = dp_grpc.DataPortalMasterServiceStub(channel)
        recv_manifest = portal_master_cli.GetDataPortalManifest(empty_pb2.Empty())
        self.assertEqual(recv_manifest.name, portal_manifest.name)
        self.assertEqual(recv_manifest.data_portal_type, portal_manifest.data_portal_type)
        self.assertEqual(recv_manifest.output_partition_num, portal_manifest.output_partition_num)
        self.assertEqual(recv_manifest.input_file_wildcard, portal_manifest.input_file_wildcard)
        self.assertEqual(recv_manifest.input_base_dir, portal_manifest.input_base_dir)
        self.assertEqual(recv_manifest.output_base_dir, portal_manifest.output_base_dir)
        self.assertEqual(recv_manifest.raw_data_publish_dir, portal_manifest.raw_data_publish_dir)
        self.assertEqual(recv_manifest.next_job_id, 1)
        self.assertEqual(recv_manifest.processing_job_id, 0)
        self._check_portal_job(kvstore, all_fnames, portal_manifest, 0)
        mapped_partition = set()
        task_0 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=0))
        task_0_1 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=0))
        self.assertEqual(task_0, task_0_1)
        self.assertTrue(task_0.HasField('map_task'))
        mapped_partition.add(task_0.map_task.partition_id)
        self._check_map_task(task_0.map_task, all_fnames,
                             task_0.map_task.partition_id,
                             portal_manifest)
        portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=0, partition_id=task_0.map_task.partition_id,
            part_state=dp_pb.PartState.kIdMap)
        )
        task_1 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=0))
        self.assertTrue(task_1.HasField('map_task'))
        mapped_partition.add(task_1.map_task.partition_id)
        self._check_map_task(task_1.map_task, all_fnames,
                             task_1.map_task.partition_id,
                             portal_manifest)

        task_2 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=1))
        self.assertTrue(task_2.HasField('map_task'))
        mapped_partition.add(task_2.map_task.partition_id)
        self._check_map_task(task_2.map_task, all_fnames,
                             task_2.map_task.partition_id,
                             portal_manifest)

        task_3 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=2))
        self.assertTrue(task_3.HasField('map_task'))
        mapped_partition.add(task_3.map_task.partition_id)
        self._check_map_task(task_3.map_task, all_fnames,
                             task_3.map_task.partition_id,
                             portal_manifest)

        self.assertEqual(len(mapped_partition), portal_manifest.output_partition_num)

        portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=0, partition_id=task_1.map_task.partition_id,
            part_state=dp_pb.PartState.kIdMap)
        )

        pending_1 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=4))
        self.assertTrue(pending_1.HasField('pending'))
        pending_2 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=3))
        self.assertTrue(pending_2.HasField('pending'))

        portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=1, partition_id=task_2.map_task.partition_id,
            part_state=dp_pb.PartState.kIdMap)
        )

        portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=2, partition_id=task_3.map_task.partition_id,
            part_state=dp_pb.PartState.kIdMap)
        )

        reduce_partition = set()
        task_4 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=0))
        task_4_1 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=0))
        self.assertEqual(task_4, task_4_1)
        self.assertTrue(task_4.HasField('reduce_task'))
        reduce_partition.add(task_4.reduce_task.partition_id)
        self._check_reduce_task(task_4.reduce_task,
                                task_4.reduce_task.partition_id,
                                portal_manifest)
        task_5 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=1))
        self.assertTrue(task_5.HasField('reduce_task'))
        reduce_partition.add(task_5.reduce_task.partition_id)
        self._check_reduce_task(task_5.reduce_task,
                                task_5.reduce_task.partition_id,
                                portal_manifest)
        task_6 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=2))
        self.assertTrue(task_6.HasField('reduce_task'))
        reduce_partition.add(task_6.reduce_task.partition_id)
        self._check_reduce_task(task_6.reduce_task,
                                task_6.reduce_task.partition_id,
                                portal_manifest)
        task_7 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=3))
        self.assertTrue(task_7.HasField('reduce_task'))
        reduce_partition.add(task_7.reduce_task.partition_id)
        self.assertEqual(len(reduce_partition), 4)
        self._check_reduce_task(task_7.reduce_task,
                                task_7.reduce_task.partition_id,
                                portal_manifest)

        task_8 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=5))
        self.assertTrue(task_8.HasField('pending'))

        portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=0, partition_id=task_4.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce)
        )
        portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=1, partition_id=task_5.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce)
        )
        portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=2, partition_id=task_6.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce)
        )
        portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=3, partition_id=task_7.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce)
        )

        time.sleep(31)
        task_9 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=5))
        self.assertTrue(task_9.HasField('finished'))

        data_portal_master.stop()
        gfile.DeleteRecursively(portal_input_base_dir)
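The kvstore.set_data(..., text_format.MessageToString(...)) call near the top of this test follows a common "proto as text in a key-value store" pattern; a minimal sketch, with a plain dict standing in for the real DBClient:

from google.protobuf import text_format
from google.protobuf.wrappers_pb2 import StringValue

kv = {}  # stand-in for the key-value store
msg = StringValue(value='hello')
kv['some/key'] = text_format.MessageToString(msg)            # write
restored = text_format.Parse(kv['some/key'], StringValue())  # read back
assert restored.value == 'hello'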
Example n. 30
0
    # Get masks for the regular convolutions.
    blobmask = get_masks(ref_net, percentile)

    # Get masks for the coordinate/confidence convolutions.
    blobmask.update(mboxes)

    # Resize the network definitions to the pruned sizes.
    sizes = {k: sum(v) for k, v in blobmask.items()}
    train_par = resize_network(train_par, sizes)
    test_par = resize_network(test_par, sizes, verbose=False)
    dep_par = resize_network(dep_par, sizes, verbose=False)

    # Write the resized network definitions.
    with open('models/ssd_face_pruned/face_train.prototxt', 'w') as f:
        f.write(txtf.MessageToString(train_par))
    with open('models/ssd_face_pruned/face_test.prototxt', 'w') as f:
        f.write(txtf.MessageToString(test_par))
    with open('models/ssd_face_pruned/face_deploy.prototxt', 'w') as f:
        f.write(txtf.MessageToString(dep_par))

    # Load the pruned net with empty parameters.
    new_net = caffe.Net('models/ssd_face_pruned/face_train.prototxt',
                        'models/empty.caffemodel', caffe.TRAIN)

    # Copy the masked parameters into the pruned net.
    set_params(ref_net, new_net, train_par, blobmask)

    # Save the pruned net parameters.
    new_net.save('models/ssd_face_pruned/face_init.caffemodel')
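For completeness, the prototxt files written above can be parsed back into NetParameter protos with text_format.Merge. A minimal sketch, assuming pycaffe (and therefore caffe.proto.caffe_pb2) is installed and reusing the same path:

from google.protobuf import text_format
from caffe.proto import caffe_pb2

net_param = caffe_pb2.NetParameter()
with open('models/ssd_face_pruned/face_train.prototxt') as f:
    text_format.Merge(f.read(), net_param)  # fills net_param in place
print(len(net_param.layer), 'layers in the pruned training net')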