diff --git a/features/src/structs.rs b/features/src/structs.rs index aefbde5..849d8ac 100644 --- a/features/src/structs.rs +++ b/features/src/structs.rs @@ -1,6 +1,9 @@ use std::fmt; use std::ops; +use num_traits::NumCast; +use num_traits::cast; + #[derive(Debug, Copy, Clone)] pub struct Tuple { x: f32, @@ -9,13 +12,24 @@ pub struct Tuple { w: f32, } +macro_rules! num_traits_cast { + ($tt:ident) => { + cast($tt).unwrap() + }; +} + impl Tuple { - pub fn new(x: T, y: T, z: T, w: T) -> Tuple { + pub fn new(x: X, y: Y, z: Z, w: W) -> Tuple + where X: NumCast, + Y: NumCast, + Z: NumCast, + W: NumCast, + { Tuple { - x: num_traits::cast(x).unwrap(), - y: num_traits::cast(y).unwrap(), - z: num_traits::cast(z).unwrap(), - w: num_traits::cast(w).unwrap(), + x: num_traits_cast!(x), + y: num_traits_cast!(y), + z: num_traits_cast!(z), + w: num_traits_cast!(w), } } @@ -28,20 +42,28 @@ impl Tuple { } } - pub fn point(x: f32, y: f32, z: f32) -> Tuple { + pub fn point(x: T, y: U, z: V) -> Tuple + where T: NumCast, + U: NumCast, + V: NumCast, + + { Tuple { - x, - y, - z, + x: num_traits_cast!(x), + y: num_traits_cast!(y), + z: num_traits_cast!(z), w: 1.0, } } - pub fn vector(x: f32, y: f32, z: f32) -> Tuple { + pub fn vector(x: X, y: Y, z: Z) -> Tuple + where X: NumCast, + Y: NumCast, + Z: NumCast { Tuple { - x, - y, - z, + x: num_traits_cast!(x), + y: num_traits_cast!(y), + z: num_traits_cast!(z), w: 0.0, } } @@ -427,9 +449,36 @@ mod tests { #[test] fn works_with_i32() { - let a = Tuple::new(1, 2, 3); + let a = Tuple::new(1, 2, 3, 0); assert_eq!(1.0, a.x()); assert_eq!(2.0, a.y()); assert_eq!(3.0, a.z()); } + + #[test] + fn works_with_mixed_types() { + let a = Tuple::new(1.1, 2.2, 3, 0); + assert_eq!(1.1, a.x()); + assert_eq!(2.2, a.y()); + assert_eq!(3.0, a.z()); + } + + #[test] + fn point_with_mixed_types() { + let a = Tuple::point(1.0, 2.2, 3); + assert_eq!(1.0, a.x()); + assert_eq!(2.2, a.y()); + assert_eq!(3.0, a.z()); + assert_eq!(1.0, a.w()); + } + + #[test] + fn vector_with_mixed_types() { + let a = Tuple::vector(1.0, 2.2, 3); + assert_eq!(1.0, a.x()); + assert_eq!(2.2, a.y()); + assert_eq!(3.0, a.z()); + assert_eq!(0.0, a.w()); + + } }