Solution

  1. /// A wire type as seen on the wire.
  2. enum WireType {
  3.     /// The Varint WireType indicates the value is a single VARINT.
  4.     Varint,
  5.     /// The I64 WireType indicates that the value is precisely 8 bytes in
  6.     /// little-endian order containing a 64-bit signed integer or double type.
  7.     //I64,  -- not needed for this exercise
  8.     /// The Len WireType indicates that the value is a length represented as a
  9.     /// VARINT followed by exactly that number of bytes.
  10.     Len,
  11.     // The I32 WireType indicates that the value is precisely 4 bytes in
  12.     // little-endian order containing a 32-bit signed integer or float type.
  13.     //I32,  -- not needed for this exercise
  14. }
  15. #[derive(Debug)]
  16. /// A field's value, typed based on the wire type.
  17. enum FieldValue<'a> {
  18.     Varint(u64),
  19.     //I64(i64),  -- not needed for this exercise
  20.     Len(&'a [u8]),
  21.     //I32(i32),  -- not needed for this exercise
  22. }
  23. #[derive(Debug)]
  24. /// A field, containing the field number and its value.
  25. struct Field<'a> {
  26.     field_num: u64,
  27.     value: FieldValue<'a>,
  28. }
/// A message type that can be assembled one decoded field at a time.
///
/// `Default` is a supertrait because `parse_message` starts from
/// `T::default()` and then feeds every parsed field to `add_field`.
trait ProtoMessage<'a>: Default {
    /// Incorporate a single parsed `field` into `self`.
    ///
    /// The field may borrow from the input buffer for the lifetime `'a`.
    fn add_field(&mut self, field: Field<'a>);
}
  32. impl From<u64> for WireType {
  33.     fn from(value: u64) -> Self {
  34.         match value {
  35.             0 => WireType::Varint,
  36.             //1 => WireType::I64,  -- not needed for this exercise
  37.             2 => WireType::Len,
  38.             //5 => WireType::I32,  -- not needed for this exercise
  39.             _ => panic!("Invalid wire type: {value}"),
  40.         }
  41.     }
  42. }
  43. impl<'a> FieldValue<'a> {
  44.     fn as_str(&self) -> &'a str {
  45.         let FieldValue::Len(data) = self else {
  46.             panic!("Expected string to be a `Len` field");
  47.         };
  48.         std::str::from_utf8(data).expect("Invalid string")
  49.     }
  50.     fn as_bytes(&self) -> &'a [u8] {
  51.         let FieldValue::Len(data) = self else {
  52.             panic!("Expected bytes to be a `Len` field");
  53.         };
  54.         data
  55.     }
  56.     fn as_u64(&self) -> u64 {
  57.         let FieldValue::Varint(value) = self else {
  58.             panic!("Expected `u64` to be a `Varint` field");
  59.         };
  60.         *value
  61.     }
  62. }
  63. /// Parse a VARINT, returning the parsed value and the remaining bytes.
  64. fn parse_varint(data: &[u8]) -> (u64, &[u8]) {
  65.     for i in 0..7 {
  66.         let Some(b) = data.get(i) else {
  67.             panic!("Not enough bytes for varint");
  68.         };
  69.         if b & 0x80 == 0 {
  70.             // This is the last byte of the VARINT, so convert it to
  71.             // a u64 and return it.
  72.             let mut value = 0u64;
  73.             for b in data[..=i].iter().rev() {
  74.                 value = (value << 7) | (b & 0x7f) as u64;
  75.             }
  76.             return (value, &data[i + 1..]);
  77.         }
  78.     }
  79.     // More than 7 bytes is invalid.
  80.     panic!("Too many bytes for varint");
  81. }
  82. /// Convert a tag into a field number and a WireType.
  83. fn unpack_tag(tag: u64) -> (u64, WireType) {
  84.     let field_num = tag >> 3;
  85.     let wire_type = WireType::from(tag & 0x7);
  86.     (field_num, wire_type)
  87. }
  88. /// Parse a field, returning the remaining bytes
  89. fn parse_field(data: &[u8]) -> (Field, &[u8]) {
  90.     let (tag, remainder) = parse_varint(data);
  91.     let (field_num, wire_type) = unpack_tag(tag);
  92.     let (fieldvalue, remainder) = match wire_type {
  93.         WireType::Varint => {
  94.             let (value, remainder) = parse_varint(remainder);
  95.             (FieldValue::Varint(value), remainder)
  96.         }
  97.         WireType::Len => {
  98.             let (len, remainder) = parse_varint(remainder);
  99.             let len: usize = len.try_into().expect("len not a valid `usize`");
  100.             if remainder.len() < len {
  101.                 panic!("Unexpected EOF");
  102.             }
  103.             let (value, remainder) = remainder.split_at(len);
  104.             (FieldValue::Len(value), remainder)
  105.         }
  106.     };
  107.     (Field { field_num, value: fieldvalue }, remainder)
  108. }
  109. /// Parse a message in the given data, calling `T::add_field` for each field in
  110. /// the message.
  111. ///
  112. /// The entire input is consumed.
  113. fn parse_message<'a, T: ProtoMessage<'a>>(mut data: &'a [u8]) -> T {
  114.     let mut result = T::default();
  115.     while !data.is_empty() {
  116.         let parsed = parse_field(data);
  117.         result.add_field(parsed.0);
  118.         data = parsed.1;
  119.     }
  120.     result
  121. }
  122. #[derive(PartialEq)]
  123. #[derive(Debug, Default)]
  124. struct PhoneNumber<'a> {
  125.     number: &'a str,
  126.     type_: &'a str,
  127. }
  128. #[derive(PartialEq)]
  129. #[derive(Debug, Default)]
  130. struct Person<'a> {
  131.     name: &'a str,
  132.     id: u64,
  133.     phone: Vec<PhoneNumber<'a>>,
  134. }
  135. impl<'a> ProtoMessage<'a> for Person<'a> {
  136.     fn add_field(&mut self, field: Field<'a>) {
  137.         match field.field_num {
  138.             1 => self.name = field.value.as_str(),
  139.             2 => self.id = field.value.as_u64(),
  140.             3 => self.phone.push(parse_message(field.value.as_bytes())),
  141.             _ => {} // skip everything else
  142.         }
  143.     }
  144. }
  145. impl<'a> ProtoMessage<'a> for PhoneNumber<'a> {
  146.     fn add_field(&mut self, field: Field<'a>) {
  147.         match field.field_num {
  148.             1 => self.number = field.value.as_str(),
  149.             2 => self.type_ = field.value.as_str(),
  150.             _ => {} // skip everything else
  151.         }
  152.     }
  153. }
  154. fn main() {
  155.     let person: Person = parse_message(&[
  156.         0x0a, 0x07, 0x6d, 0x61, 0x78, 0x77, 0x65, 0x6c, 0x6c, 0x10, 0x2a, 0x1a,
  157.         0x16, 0x0a, 0x0e, 0x2b, 0x31, 0x32, 0x30, 0x32, 0x2d, 0x35, 0x35, 0x35,
  158.         0x2d, 0x31, 0x32, 0x31, 0x32, 0x12, 0x04, 0x68, 0x6f, 0x6d, 0x65, 0x1a,
  159.         0x18, 0x0a, 0x0e, 0x2b, 0x31, 0x38, 0x30, 0x30, 0x2d, 0x38, 0x36, 0x37,
  160.         0x2d, 0x35, 0x33, 0x30, 0x38, 0x12, 0x06, 0x6d, 0x6f, 0x62, 0x69, 0x6c,
  161.         0x65,
  162.     ]);
  163.     println!("{:#?}", person);
  164. }
  165. #[cfg(test)]
  166. mod tests {
  167.     use super::*;
  168.     #[test]
  169.     fn test_id() {
  170.         let person_id: Person = parse_message(&[0x10, 0x2a]);
  171.         assert_eq!(person_id, Person { name: "", id: 42, phone: vec![] });
  172.     }
  173.     #[test]
  174.     fn test_name() {
  175.         let person_name: Person = parse_message(&[
  176.             0x0a, 0x0e, 0x62, 0x65, 0x61, 0x75, 0x74, 0x69, 0x66, 0x75, 0x6c, 0x20,
  177.             0x6e, 0x61, 0x6d, 0x65,
  178.         ]);
  179.         assert_eq!(
  180.             person_name,
  181.             Person { name: "beautiful name", id: 0, phone: vec![] }
  182.         );
  183.     }
  184.     #[test]
  185.     fn test_just_person() {
  186.         let person_name_id: Person =
  187.             parse_message(&[0x0a, 0x04, 0x45, 0x76, 0x61, 0x6e, 0x10, 0x16]);
  188.         assert_eq!(person_name_id, Person { name: "Evan", id: 22, phone: vec![] });
  189.     }
  190.     #[test]
  191.     fn test_phone() {
  192.         let phone: Person = parse_message(&[
  193.             0x0a, 0x00, 0x10, 0x00, 0x1a, 0x16, 0x0a, 0x0e, 0x2b, 0x31, 0x32, 0x33,
  194.             0x34, 0x2d, 0x37, 0x37, 0x37, 0x2d, 0x39, 0x30, 0x39, 0x30, 0x12, 0x04,
  195.             0x68, 0x6f, 0x6d, 0x65,
  196.         ]);
  197.         assert_eq!(
  198.             phone,
  199.             Person {
  200.                 name: "",
  201.                 id: 0,
  202.                 phone: vec![PhoneNumber { number: "+1234-777-9090", type_: "home" },],
  203.             }
  204.         );
  205.     }
  206. }