解答

  1. use std::convert::TryFrom;
  2. use thiserror::Error;
  3. #[derive(Debug, Error)]
  4. enum Error {
  5. #[error("Invalid varint")]
  6. InvalidVarint,
  7. #[error("Invalid wire-type")]
  8. InvalidWireType,
  9. #[error("Unexpected EOF")]
  10. UnexpectedEOF,
  11. #[error("Invalid length")]
  12. InvalidSize(#[from] std::num::TryFromIntError),
  13. #[error("Unexpected wire-type)")]
  14. UnexpectedWireType,
  15. #[error("Invalid string (not UTF-8)")]
  16. InvalidString,
  17. }
  /// A wire type as seen on the wire: the low three bits of a field's tag.
  ///
  /// Only the wire types this parser needs are represented; `I64` (wire type 1)
  /// is intentionally not supported.
  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  enum WireType {
      /// The value is a single VARINT.
      Varint,
      /// The value is a length, encoded as a VARINT, followed by exactly that
      /// many bytes.
      Len,
      /// The value is exactly 4 bytes in little-endian order containing a
      /// 32-bit signed integer.
      I32,
  }
  /// A field's value, typed according to the field's wire type.
  ///
  /// Length-delimited values borrow from the input buffer instead of copying
  /// it. (`I64` values are not supported by this parser.)
  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  enum FieldValue<'a> {
      /// A VARINT value.
      Varint(u64),
      /// The raw bytes of a length-delimited value.
      Len(&'a [u8]),
      /// A 4-byte little-endian signed integer.
      I32(i32),
  }
  38. #[derive(Debug)]
  39. /// A field, containing the field number and its value.
  40. struct Field<'a> {
  41. field_num: u64,
  42. value: FieldValue<'a>,
  43. }
  44. trait ProtoMessage<'a>: Default + 'a {
  45. fn add_field(&mut self, field: Field<'a>) -> Result<(), Error>;
  46. }
  47. impl TryFrom<u64> for WireType {
  48. type Error = Error;
  49. fn try_from(value: u64) -> Result<WireType, Error> {
  50. Ok(match value {
  51. 0 => WireType::Varint,
  52. //1 => WireType::I64, -- not needed for this exercise
  53. 2 => WireType::Len,
  54. 5 => WireType::I32,
  55. _ => return Err(Error::InvalidWireType),
  56. })
  57. }
  58. }
  59. impl<'a> FieldValue<'a> {
  60. fn as_string(&self) -> Result<&'a str, Error> {
  61. let FieldValue::Len(data) = self else {
  62. return Err(Error::UnexpectedWireType);
  63. };
  64. std::str::from_utf8(data).map_err(|_| Error::InvalidString)
  65. }
  66. fn as_bytes(&self) -> Result<&'a [u8], Error> {
  67. let FieldValue::Len(data) = self else {
  68. return Err(Error::UnexpectedWireType);
  69. };
  70. Ok(data)
  71. }
  72. fn as_u64(&self) -> Result<u64, Error> {
  73. let FieldValue::Varint(value) = self else {
  74. return Err(Error::UnexpectedWireType);
  75. };
  76. Ok(*value)
  77. }
  78. }
  79. /// Parse a VARINT, returning the parsed value and the remaining bytes.
  80. fn parse_varint(data: &[u8]) -> Result<(u64, &[u8]), Error> {
  81. for i in 0..7 {
  82. let Some(b) = data.get(i) else {
  83. return Err(Error::InvalidVarint);
  84. };
  85. if b & 0x80 == 0 {
  86. // This is the last byte of the VARINT, so convert it to
  87. // a u64 and return it.
  88. let mut value = 0u64;
  89. for b in data[..=i].iter().rev() {
  90. value = (value << 7) | (b & 0x7f) as u64;
  91. }
  92. return Ok((value, &data[i + 1..]));
  93. }
  94. }
  95. // More than 7 bytes is invalid.
  96. Err(Error::InvalidVarint)
  97. }
  98. /// Convert a tag into a field number and a WireType.
  99. fn unpack_tag(tag: u64) -> Result<(u64, WireType), Error> {
  100. let field_num = tag >> 3;
  101. let wire_type = WireType::try_from(tag & 0x7)?;
  102. Ok((field_num, wire_type))
  103. }
  104. /// Parse a field, returning the remaining bytes
  105. fn parse_field(data: &[u8]) -> Result<(Field, &[u8]), Error> {
  106. let (tag, remainder) = parse_varint(data)?;
  107. let (field_num, wire_type) = unpack_tag(tag)?;
  108. let (fieldvalue, remainder) = match wire_type {
  109. WireType::Varint => {
  110. let (value, remainder) = parse_varint(remainder)?;
  111. (FieldValue::Varint(value), remainder)
  112. }
  113. WireType::Len => {
  114. let (len, remainder) = parse_varint(remainder)?;
  115. let len: usize = len.try_into()?;
  116. if remainder.len() < len {
  117. return Err(Error::UnexpectedEOF);
  118. }
  119. let (value, remainder) = remainder.split_at(len);
  120. (FieldValue::Len(value), remainder)
  121. }
  122. WireType::I32 => {
  123. if remainder.len() < 4 {
  124. return Err(Error::UnexpectedEOF);
  125. }
  126. let (value, remainder) = remainder.split_at(4);
  127. // Unwrap error because `value` is definitely 4 bytes long.
  128. let value = i32::from_le_bytes(value.try_into().unwrap());
  129. (FieldValue::I32(value), remainder)
  130. }
  131. };
  132. Ok((Field { field_num, value: fieldvalue }, remainder))
  133. }
  134. /// Parse a message in the given data, calling `T::add_field` for each field in
  135. /// the message.
  136. ///
  137. /// The entire input is consumed.
  138. fn parse_message<'a, T: ProtoMessage<'a>>(mut data: &'a [u8]) -> Result<T, Error> {
  139. let mut result = T::default();
  140. while !data.is_empty() {
  141. let parsed = parse_field(data)?;
  142. result.add_field(parsed.0)?;
  143. data = parsed.1;
  144. }
  145. Ok(result)
  146. }
  /// A phone number entry, decoded from a message embedded in a `Person`.
  #[derive(Debug, Default, Clone, PartialEq, Eq)]
  struct PhoneNumber<'a> {
      /// Field 1: the phone number text.
      number: &'a str,
      /// Field 2: the kind of number (e.g. "home" or "mobile").
      type_: &'a str,
  }
  152. #[derive(Debug, Default)]
  153. struct Person<'a> {
  154. name: &'a str,
  155. id: u64,
  156. phone: Vec<PhoneNumber<'a>>,
  157. }
  158. impl<'a> ProtoMessage<'a> for Person<'a> {
  159. fn add_field(&mut self, field: Field<'a>) -> Result<(), Error> {
  160. match field.field_num {
  161. 1 => self.name = field.value.as_string()?,
  162. 2 => self.id = field.value.as_u64()?,
  163. 3 => self.phone.push(parse_message(field.value.as_bytes()?)?),
  164. _ => {} // skip everything else
  165. }
  166. Ok(())
  167. }
  168. }
  169. impl<'a> ProtoMessage<'a> for PhoneNumber<'a> {
  170. fn add_field(&mut self, field: Field<'a>) -> Result<(), Error> {
  171. match field.field_num {
  172. 1 => self.number = field.value.as_string()?,
  173. 2 => self.type_ = field.value.as_string()?,
  174. _ => {} // skip everything else
  175. }
  176. Ok(())
  177. }
  178. }
  179. fn main() {
  180. let person: Person = parse_message(&[
  181. 0x0a, 0x07, 0x6d, 0x61, 0x78, 0x77, 0x65, 0x6c, 0x6c, 0x10, 0x2a, 0x1a,
  182. 0x16, 0x0a, 0x0e, 0x2b, 0x31, 0x32, 0x30, 0x32, 0x2d, 0x35, 0x35, 0x35,
  183. 0x2d, 0x31, 0x32, 0x31, 0x32, 0x12, 0x04, 0x68, 0x6f, 0x6d, 0x65, 0x1a,
  184. 0x18, 0x0a, 0x0e, 0x2b, 0x31, 0x38, 0x30, 0x30, 0x2d, 0x38, 0x36, 0x37,
  185. 0x2d, 0x35, 0x33, 0x30, 0x38, 0x12, 0x06, 0x6d, 0x6f, 0x62, 0x69, 0x6c,
  186. 0x65,
  187. ])
  188. .unwrap();
  189. println!("{:#?}", person);
  190. }
  #[cfg(test)]
  mod test {
      use super::*;

      /// The sample `Person` message that `main` decodes.
      const SAMPLE: &[u8] = &[
          0x0a, 0x07, 0x6d, 0x61, 0x78, 0x77, 0x65, 0x6c, 0x6c, 0x10, 0x2a, 0x1a,
          0x16, 0x0a, 0x0e, 0x2b, 0x31, 0x32, 0x30, 0x32, 0x2d, 0x35, 0x35, 0x35,
          0x2d, 0x31, 0x32, 0x31, 0x32, 0x12, 0x04, 0x68, 0x6f, 0x6d, 0x65, 0x1a,
          0x18, 0x0a, 0x0e, 0x2b, 0x31, 0x38, 0x30, 0x30, 0x2d, 0x38, 0x36, 0x37,
          0x2d, 0x35, 0x33, 0x30, 0x38, 0x12, 0x06, 0x6d, 0x6f, 0x62, 0x69, 0x6c,
          0x65,
      ];

      #[test]
      fn as_string() {
          assert!(FieldValue::Varint(10).as_string().is_err());
          assert!(FieldValue::I32(10).as_string().is_err());
          // Length-delimited, but not valid UTF-8.
          assert!(FieldValue::Len(&[0xff, 0xfe]).as_string().is_err());
          assert_eq!(FieldValue::Len(b"hello").as_string().unwrap(), "hello");
      }

      #[test]
      fn as_bytes() {
          assert!(FieldValue::Varint(10).as_bytes().is_err());
          assert!(FieldValue::I32(10).as_bytes().is_err());
          assert_eq!(FieldValue::Len(b"hello").as_bytes().unwrap(), b"hello");
      }

      #[test]
      fn as_u64() {
          assert_eq!(FieldValue::Varint(10).as_u64().unwrap(), 10u64);
          assert!(FieldValue::I32(10).as_u64().is_err());
          assert!(FieldValue::Len(b"hello").as_u64().is_err());
      }

      #[test]
      fn parse_varint_single_and_multi_byte() {
          let (value, rest) = parse_varint(&[0x01, 0x08]).unwrap();
          assert_eq!(value, 1);
          assert_eq!(rest, &[0x08u8][..]);

          // 150 is the canonical two-byte example: 0x96 0x01.
          let (value, rest) = parse_varint(&[0x96, 0x01]).unwrap();
          assert_eq!(value, 150);
          assert!(rest.is_empty());
      }

      #[test]
      fn parse_varint_rejects_truncated_input() {
          assert!(matches!(parse_varint(&[]), Err(Error::InvalidVarint)));
          assert!(matches!(parse_varint(&[0x80]), Err(Error::InvalidVarint)));
      }

      #[test]
      fn parse_field_errors() {
          // Tag 0x0b: field 1 with wire type 3, which is unsupported.
          assert!(matches!(parse_field(&[0x0b]), Err(Error::InvalidWireType)));
          // Field 1, Len, claims 5 bytes but only 1 follows.
          assert!(matches!(parse_field(&[0x0a, 0x05, b'a']), Err(Error::UnexpectedEOF)));
          // Field 1, I32, but only 3 payload bytes.
          assert!(matches!(parse_field(&[0x0d, 1, 0, 0]), Err(Error::UnexpectedEOF)));
      }

      #[test]
      fn parse_field_i32() {
          let (field, rest) = parse_field(&[0x0d, 0xfe, 0xff, 0xff, 0xff, 0x2a]).unwrap();
          assert_eq!(field.field_num, 1);
          assert!(matches!(field.value, FieldValue::I32(-2)));
          assert_eq!(rest, &[0x2au8][..]);
      }

      #[test]
      fn parses_sample_person() {
          let person: Person = parse_message(SAMPLE).unwrap();
          assert_eq!(person.name, "maxwell");
          assert_eq!(person.id, 42);
          assert_eq!(person.phone.len(), 2);
          assert_eq!(person.phone[0].number, "+1202-555-1212");
          assert_eq!(person.phone[0].type_, "home");
          assert_eq!(person.phone[1].number, "+1800-867-5308");
          assert_eq!(person.phone[1].type_, "mobile");
      }

      #[test]
      fn empty_input_yields_default_message() {
          let person: Person = parse_message(&[]).unwrap();
          assert_eq!(person.name, "");
          assert_eq!(person.id, 0);
          assert!(person.phone.is_empty());
      }
  }