def _merge_shapes_for_tf(shape1, shape2): """ Merge 2 shapes, return merged shape, set unknown for dims with different values. Raise exception for mismatch. """ if shape1 is None: return shape2 if shape2 is None: return shape1 utils.make_sure(utils.is_list_or_tuple(shape1), "invalid type for shape1") utils.make_sure(utils.is_list_or_tuple(shape2), "invalid type for shape2") utils.make_sure( len(shape1) == len(shape2), "shapes rank mismatch: shape1=%s, shape2=%s", shape1, shape2) merged = [] for d1, d2 in zip(shape1, shape2): d = d1 if d1 is None: d = d2 elif d2 is not None: # None means unknown in tensorflow d = None merged.append(d) return merged
def check_opset_constraints(self, opset, extra_opset=None):
    """Check ``self.opset_constraints`` against the opsets that are available.

    Args:
        opset: version used for the default ``"onnx"`` domain.
        extra_opset: optional iterable of objects exposing ``domain`` and
            ``version``; each entry registers (or overrides) that domain's version.

    Returns:
        A ``(condition, reason)`` tuple. ``condition`` is True when every
        constraint is met (``reason`` is then None). Otherwise ``condition`` is
        False and ``reason`` describes the first constraint that failed.
    """
    constraints = self.opset_constraints
    if not constraints:
        return True, None

    # Domain name -> available version; later extra entries win on duplicates.
    available = {"onnx": opset}
    available.update((item.domain, item.version) for item in (extra_opset or ()))

    for constraint in constraints:
        domain = constraint.domain
        version = available.get(domain)
        # A missing (or falsy) version means the required domain is not available.
        if not version:
            return False, "conversion requires opset {}".format(domain)

        lower = constraint.min_version
        if lower and version < lower:
            return False, "conversion requires opset {} >= {}".format(domain, lower)

        upper = constraint.max_version
        if upper and version > upper:
            return False, "conversion requires opset {} <= {}".format(domain, upper)

        excluded = constraint.excluded_version
        if not excluded:
            continue
        # excluded_version is either a single version or a list/tuple of versions.
        if utils.is_list_or_tuple(excluded):
            blocked = version in excluded
        else:
            blocked = version == excluded
        if blocked:
            return False, "conversion requires opset {} != {}".format(domain, excluded)

    return True, None