diff --git a/src/lib.rs b/src/lib.rs index 1908218..b036d8e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -96,8 +96,45 @@ pub fn decode_tag(bytes: &[u8]) -> Result, Error> { pub trait Schema<'a> { type Element: 'a; - fn should_unwrap(element_id: u64) -> bool; - fn decode<'b: 'a>(element_id: u64, bytes: &'b[u8]) -> Result; + fn should_unwrap(&self, element_id: u64) -> bool; + fn decode<'b: 'a>(&self, element_id: u64, bytes: &'b[u8]) -> Result; + + fn decode_element<'b: 'a>(&self, bytes: &'b[u8]) -> Result, Error> { + match decode_tag(bytes) { + Ok(None) => Ok(None), + Err(err) => Err(err), + Ok(Some((element_id, payload_size_tag, tag_size))) => { + let should_unwrap = self.should_unwrap(element_id); + + let payload_size = match (should_unwrap, payload_size_tag) { + (true, _) => 0, + (false, Varint::Unknown) => return Err(Error::UnknownElementLength), + (false, Varint::Value(size)) => size as usize + }; + + let element_size = tag_size + payload_size; + if element_size > bytes.len() { + // need to read more still + return Ok(None); + } + + match self.decode(element_id, &bytes[tag_size..element_size]) { + Ok(element) => Ok(Some((element, element_size))), + Err(error) => Err(error) + } + } + } + } + + fn iter_for<'b: 'a>(self, bytes: &'b[u8]) -> EbmlIterator<'a, Self> + where Self: Sized + { + EbmlIterator { + schema: self, + slice: bytes, + position: 0 + } + } } pub struct Webm; @@ -110,64 +147,27 @@ pub enum WebmElement<'a> { impl<'a> Schema<'a> for Webm { type Element = WebmElement<'a>; - fn should_unwrap(element_id: u64) -> bool { + fn should_unwrap(&self, element_id: u64) -> bool { false } - fn decode<'b: 'a>(element_id: u64, bytes: &'b[u8]) -> Result, Error> { + fn decode<'b: 'a>(&self, element_id: u64, bytes: &'b[u8]) -> Result, Error> { // dummy Ok(WebmElement::Unknown(element_id, bytes)) } } -pub fn decode_element<'a, 'b: 'a, T: Schema<'a>>(bytes: &'b[u8]) -> Result, Error> { - match decode_tag(bytes) { - Ok(None) => Ok(None), - Err(err) => Err(err), - Ok(Some((element_id, payload_size_tag, tag_size))) => { - let should_unwrap = T::should_unwrap(element_id); - - let payload_size = match (should_unwrap, payload_size_tag) { - (true, _) => 0, - (false, Varint::Unknown) => return Err(Error::UnknownElementLength), - (false, Varint::Value(size)) => size as usize - }; - - let element_size = tag_size + payload_size; - if element_size > bytes.len() { - // need to read more still - return Ok(None); - } - - match T::decode(element_id, &bytes[tag_size..element_size]) { - Ok(element) => Ok(Some((element, element_size))), - Err(error) => Err(error) - } - } - } -} - pub struct EbmlIterator<'b, T: Schema<'b>> { - schema: std::marker::PhantomData, + schema: T, slice: &'b[u8], position: usize, } -impl<'b, T: Schema<'b>> EbmlIterator<'b, T> { - pub fn new(bytes: &'b[u8]) -> Self { - EbmlIterator { - schema: std::marker::PhantomData, - slice: bytes, - position: 0 - } - } -} - impl<'b, T: Schema<'b>> Iterator for EbmlIterator<'b, T> { type Item = T::Element; fn next(&mut self) -> Option { - match decode_element::(&self.slice[self.position..]) { + match self.schema.decode_element(&self.slice[self.position..]) { Err(_) => None, Ok(None) => None, Ok(Some((element, element_size))) => { @@ -249,7 +249,7 @@ mod tests { #[test] fn decode_sanity_test() { - let decoded = decode_element::(TEST_FILE); + let decoded = Webm.decode_element(TEST_FILE); if let Ok(Some((WebmElement::Unknown(tag, slice), element_size))) = decoded { assert_eq!(tag, 0x0A45DFA3); // EBML tag, sans the length indicator bit assert_eq!(slice.len(), 31); // known header payload length @@ -273,7 +273,7 @@ mod tests { #[test] fn decode_webm_test1() { - let mut iter = EbmlIterator::::new(TEST_FILE); + let mut iter = Webm.iter_for(TEST_FILE); // EBML Header assert_webm_blob(iter.next(), 0x0A45DFA3, 31);