Source code for hedwig.validators.protobuf

import base64
import re
import typing
from copy import deepcopy
from distutils.version import StrictVersion
from functools import lru_cache
from importlib import import_module
from types import ModuleType
from typing import Tuple, Union, TypeVar, Type

import funcy
from google.protobuf import json_format
from google.protobuf.any_pb2 import Any
from google.protobuf.message import DecodeError, Message as ProtoMessage
from google.protobuf.struct_pb2 import Value

from hedwig.conf import settings
from hedwig.exceptions import ValidationError
from hedwig.protobuf import options_pb2
from hedwig.protobuf.container_pb2 import PayloadV1
from hedwig.validators.base import HedwigBaseValidator, MetaAttributes


class SchemaError(Exception):
    pass


ProtoMessageT = TypeVar("ProtoMessageT", bound=ProtoMessage)


def decode_proto_json(msg_class: typing.Any, value: Union[str, bytes]) -> ProtoMessageT:
    assert isinstance(value, str)

    msg: ProtoMessageT = msg_class()
    json_format.Parse(value, msg, ignore_unknown_fields=True)
    return msg


def encode_proto_json(msg: ProtoMessage) -> str:
    return json_format.MessageToJson(msg, preserving_proto_field_name=True, indent=0).replace("\n", "")


[docs]class ProtobufValidator(HedwigBaseValidator): """ A validator that encodes the payload using Protobuf binary format. """ schema_module: ModuleType """ The module that contains protoc compiled python classes - supplied by app """ _version_pattern_re = re.compile(r"^([0-9]+)\.\*$") def __init__(self, schema_module: ModuleType = None) -> None: # schema encoding, eg: hedwig.automatic.com/schema#/schemas/trip.created/1.0 schema_fmt = '{message_type}/{message_version}' schema_re = re.compile(r'([^/]+)/([^/]+)$') super().__init__(schema_fmt, schema_re, StrictVersion('1.0')) if schema_module is None: schema_module = settings.HEDWIG_PROTOBUF_SCHEMA_MODULE if isinstance(schema_module, str): schema_module = import_module(schema_module) self.schema_module = schema_module self._check_schema(schema_module) def _decode_proto(self, msg_class: typing.Any, value: Union[str, bytes]) -> ProtoMessageT: """ Decode an arbitrary protobuf message from value """ assert isinstance(value, bytes) msg: ProtoMessageT = msg_class() msg.ParseFromString(value) return msg def _encode_proto(self, msg: ProtoMessage) -> Union[str, bytes]: """ Encode an arbitrary protobuf message into value """ return msg.SerializeToString() @lru_cache(maxsize=20) def _msg_class(self, message_type: str, major_version: int) -> Type[ProtoMessageT]: msg_class_name = self._msg_class_name(message_type, major_version) if not hasattr(self.schema_module, msg_class_name): raise ValidationError( f"Protobuf message class not found for '{message_type}' v{major_version}. " f"Must be named '{msg_class_name}'" ) msg_class = getattr(self.schema_module, msg_class_name) return msg_class def _verify_known_minor_version(self, message_type: str, full_version: StrictVersion): msg_class = self._msg_class(message_type, full_version.version[0]) options = msg_class.DESCRIPTOR.GetOptions().Extensions[options_pb2.message_options] if options.minor_version < full_version.version[1]: raise ValidationError( f'Unknown minor version: {full_version.version[1]}, last known minor version: ' f'{options.minor_version}' ) def _extract_data( self, message_payload: Union[bytes, str], attributes: dict, use_transport_attributes: bool ) -> Tuple[MetaAttributes, bytes]: assert isinstance(message_payload, (bytes, str)) if not use_transport_attributes: try: msg_payload = self._decode_proto(PayloadV1, message_payload) # type: ignore except (DecodeError, RuntimeError, AssertionError, json_format.ParseError) as e: raise ValidationError(f"Invalid data for message: PayloadV1: {e}") data = msg_payload.data meta_attrs = MetaAttributes( msg_payload.metadata.timestamp.ToMilliseconds(), msg_payload.metadata.publisher, dict(msg_payload.metadata.headers), msg_payload.id, msg_payload.schema, msg_payload.format_version, ) else: data = message_payload meta_attrs = self._decode_meta_attributes(attributes) if meta_attrs.format_version != self._current_format_version: raise ValidationError(f"Invalid format version: {meta_attrs.format_version}") return meta_attrs, data def _extract_data_firehose(self, line: str) -> Tuple[MetaAttributes, bytes]: assert isinstance(line, str) try: msg_payload = decode_proto_json(PayloadV1, line) # type: ignore except (DecodeError, RuntimeError, AssertionError, json_format.ParseError) as e: raise ValidationError(f"Invalid data for message: PayloadV1: {e}") data = msg_payload.data if data.Is(Value.DESCRIPTOR): # unknown data types are encoded as `encode(PayloadV1(data=Value(string_value=base64(binary))))` value_msg = Value() data.Unpack(value_msg) assert value_msg.WhichOneof("kind") == "string_value" data = base64.decodebytes(value_msg.string_value.encode("utf8")) meta_attrs = MetaAttributes( msg_payload.metadata.timestamp.ToMilliseconds(), msg_payload.metadata.publisher, dict(msg_payload.metadata.headers), msg_payload.id, msg_payload.schema, msg_payload.format_version, ) return meta_attrs, data def _decode_data( self, meta_attrs: MetaAttributes, message_type: str, full_version: StrictVersion, data: Union[Any, bytes, str], ) -> ProtoMessage: assert isinstance(data, (Any, bytes, str)) msg_class = self._msg_class(message_type, full_version.version[0]) try: if isinstance(data, Any): assert data.Is(msg_class.DESCRIPTOR) data_msg = msg_class() data.Unpack(data_msg) else: data_msg = self._decode_proto(msg_class, data) except (DecodeError, RuntimeError, AssertionError, json_format.ParseError) as e: raise ValidationError(f"Invalid data for message: {msg_class.__name__}: {e}") return data_msg def _encode_payload( self, meta_attrs: MetaAttributes, data: ProtoMessage, use_transport_attributes: bool ) -> Tuple[Union[bytes, str], dict]: assert isinstance(data, ProtoMessage) if not use_transport_attributes: msg = PayloadV1() msg.format_version = str(self._current_format_version) msg.id = str(meta_attrs.id) msg.metadata.publisher = meta_attrs.publisher msg.metadata.timestamp.FromMilliseconds(meta_attrs.timestamp) for k, v in meta_attrs.headers.items(): msg.metadata.headers[k] = v msg.schema = meta_attrs.schema msg.data.Pack(data) payload = self._encode_proto(msg) msg_attrs = deepcopy(meta_attrs.headers) else: payload = self._encode_proto(data) msg_attrs = self._encode_meta_attributes(meta_attrs) return payload, msg_attrs def _encode_payload_firehose( self, message_type: str, version: StrictVersion, meta_attrs: MetaAttributes, data: ProtoMessage ) -> str: assert isinstance(data, ProtoMessage) msg = PayloadV1() msg.format_version = str(self._current_format_version) msg.id = str(meta_attrs.id) msg.metadata.publisher = meta_attrs.publisher msg.metadata.timestamp.FromMilliseconds(meta_attrs.timestamp) for k, v in meta_attrs.headers.items(): msg.metadata.headers[k] = v msg.schema = meta_attrs.schema try: self._verify_known_minor_version(message_type, version) msg.data.Pack(data) except ValidationError: value_msg = Value() value_msg.string_value = base64.b64encode(data.SerializeToString()).decode() msg.data.Pack(value_msg) payload = encode_proto_json(msg) return payload @classmethod def _msg_class_name(cls, msg_type: str, major_version: int) -> str: normalized_type = msg_type for ch in ('.', '_', '-'): normalized_type = normalized_type.replace(ch, ' ') normalized_type = normalized_type.title().replace(' ', '') return f"{normalized_type}V{major_version}" @classmethod def _check_schema(cls, schema_module: ModuleType) -> None: msg_types_found = {k for k in funcy.chain(settings.HEDWIG_MESSAGE_ROUTING, settings.HEDWIG_CALLBACKS)} errors = [] for message_type, version_pattern in msg_types_found: m = cls._version_pattern_re.match(version_pattern) if not m: errors.append(f"Invalid version '{version_pattern}' for message: '{message_type}'") continue major_version = int(m.group(1)) if major_version == 0: errors.append(f"Invalid version '{major_version}' for message: '{message_type}'. Must not be 0.") msg_class_name = cls._msg_class_name(message_type, major_version) if not hasattr(schema_module, msg_class_name): errors.append( f"Protobuf message class not found for '{message_type}' v{major_version}. " f"Must be named '{msg_class_name}'" ) continue msg_class = getattr(schema_module, msg_class_name) if options_pb2.message_options not in msg_class.DESCRIPTOR.GetOptions().Extensions: errors.append(f"Protobuf message class '{msg_class_name}' does not define option message_options") options = msg_class.DESCRIPTOR.GetOptions().Extensions[options_pb2.message_options] if not options.major_version: # default is 0 which is invalid errors.append( f"Protobuf message class '{msg_class_name}' does not define option message_options.major_version" ) elif options.major_version != major_version: errors.append( f"Protobuf message class '{msg_class_name}' option message_options.major_version isn't valid: " f"{options.major_version}, expected: {major_version}" ) # minor_version default value is 0, which is valid and type is already validated by protoc, # so nothing to do here for minor_version. if errors: raise SchemaError(str(errors))
class ProtobufJSONValidator(ProtobufValidator): """ A validator that encodes the payload using Protobuf JSON format. Documentation: https://googleapis.dev/python/protobuf/latest/google/protobuf/json_format.html """ def _decode_proto(self, msg_class: typing.Any, value: Union[str, bytes]) -> ProtoMessageT: return decode_proto_json(msg_class, value) def _encode_proto(self, msg: ProtoMessage) -> Union[str, bytes]: return encode_proto_json(msg)