diff --git a/matrix/src/lib.rs b/matrix/src/lib.rs index 6a0aefc..e4024d2 100644 --- a/matrix/src/lib.rs +++ b/matrix/src/lib.rs @@ -6,24 +6,24 @@ use structs::Tuple; use std::ops::Index; #[derive(Debug)] -pub struct Matrix { - matrix: [[f32; COUNT]; COUNT], +pub struct Matrix { + matrix: [[f32; W]; H], } -impl Matrix { - pub fn new() -> Matrix { +impl Matrix { + pub fn new() -> Matrix { Matrix { - matrix: [[0f32; COUNT]; COUNT], + matrix: [[0f32; W]; H], } } - pub fn from_array(matrix: [[f32; COUNT]; COUNT]) -> Matrix { + pub fn from_array(matrix: [[f32; W]; H]) -> Matrix { Matrix { matrix: matrix, } } - pub fn identity() -> Matrix { + pub fn identity() -> Matrix { let mut m = Matrix::new(); for i in 0..m.matrix.len() { m.matrix[i][i] = 1.0; @@ -46,14 +46,14 @@ impl Matrix { } } -impl Index for Matrix { - type Output = [f32; COUNT]; +impl Index for Matrix { + type Output = [f32; W]; fn index(&self, index: usize) -> &Self::Output { &self.matrix[index] } } -impl PartialEq for Matrix { +impl PartialEq for Matrix { fn eq(&self, _rhs: &Self) -> bool { if self.matrix.len() != _rhs.matrix.len() { return false; @@ -73,10 +73,10 @@ impl PartialEq for Matrix { } } -impl Matrix { - fn calc_val_for_mul(&self, row: usize, rhs: &Matrix, col: usize) -> f32 { +impl Matrix { + fn calc_val_for_mul(&self, row: usize, rhs: &Matrix, col: usize) -> f32 { let mut sum = 0.0; - for i in 0..self.matrix.len() { + for i in 0..W { sum += self.matrix[row][i] * rhs.matrix[i][col]; } sum @@ -90,13 +90,13 @@ impl Matrix { } } -impl std::ops::Mul> for Matrix { - type Output = Matrix; +impl std::ops::Mul> for Matrix { + type Output = Matrix; - fn mul(self, _rhs: Matrix) -> Matrix { - let mut result = [[0f32; COUNT]; COUNT]; - for row in 0..COUNT { - for col in 0..COUNT { + fn mul(self, _rhs: Matrix) -> Matrix { + let mut result = [[0f32; W]; H]; + for row in 0..H { + for col in 0..W { result[row][col] = self.calc_val_for_mul(row, &_rhs, col); } } @@ -105,7 +105,7 @@ impl std::ops::Mul> for Matrix { } -impl std::ops::Mul for Matrix { +impl std::ops::Mul for Matrix { type Output = Tuple; fn mul(self, _rhs: Tuple) -> Tuple { @@ -296,7 +296,7 @@ mod tests { let t = Tuple::new(1.0, 2.0, 3.0, 4.0); let expected = Tuple::new(1.0, 2.0, 3.0, 4.0); - assert_eq!(Matrix::<4>::identity() * t, expected); + assert_eq!(Matrix::<4, 4>::identity() * t, expected); } #[test] @@ -322,7 +322,7 @@ mod tests { fn transpose_identity() { let mut m = Matrix::identity(); m.transpose(); - assert_eq!(m, Matrix::<4>::identity()); + assert_eq!(m, Matrix::<4, 4>::identity()); } #[test]