use crate::structs::Tuple; use crate::num_traits_cast; use num_traits::NumCast; use std::ops::{Index, IndexMut}; #[derive(Debug)] pub struct Matrix { matrix: Vec>, } impl Matrix { pub fn default(width: usize, height: usize) -> Self { Matrix { matrix: vec![vec![0.0f32; width]; height], } } pub fn from_array(array: [[T; W]; H]) -> Matrix where T: NumCast + Copy { let mut matrix: Vec> = Vec::with_capacity(H); for r in array.iter() { let mut row: Vec = Vec::with_capacity(W); for v in r.iter() { row.push(num_traits_cast!(*v)); } matrix.push(row); } Matrix { matrix, } } pub fn from_vec(matrix: Vec>) -> Matrix where T: NumCast + Copy { let mut matrix_f32 : Vec> = Vec::with_capacity(matrix.len()); for r in matrix.iter() { let mut row: Vec = Vec::with_capacity(r.len()); for v in r.iter() { row.push(num_traits_cast!(*v)); } matrix_f32.push(row); } Matrix { matrix: matrix_f32, } } pub fn identity(size: usize) -> Matrix { let mut m = Self::default(size, size); for i in 0..m.matrix.len() { m.matrix[i][i] = 1.0; } m } pub fn transpose(&mut self) { for i in 0..self.matrix.len() { for j in i..self.matrix[0].len() { let v = self.matrix[i][j]; self.matrix[i][j] = self.matrix[j][i]; self.matrix[j][i] = v; } } } pub fn determinant(&self) -> f32 { if self.matrix[0].len() == 2 { self.matrix[0][0] * self.matrix[1][1] - self.matrix[0][1] * self.matrix[1][0] } else { let mut sum = 0.0; for (col, val) in self.matrix[0].iter().enumerate().take(self.matrix[0].len()) { sum += val * self.cofactor(0, col); } sum } } pub fn minor(&self, row: usize, col: usize) -> f32 { let m = self.sub_matrix(row, col); let det = m.determinant(); det } pub fn cofactor(&self, row: usize, col: usize) -> f32 { let minor = self.minor(row, col); if (row + col) & 0x1 == 0 { minor } else { minor * -1.0 } } pub fn sub_matrix(&self, skip_row: usize, skip_col: usize) -> Matrix { let mut m = Vec::>::with_capacity(self.matrix.len() - 1); for (i, row) in self.matrix.iter().enumerate().take(self.matrix.len()) { if i == skip_row { continue; } let mut r = Vec::::with_capacity(row.len() - 1); for (j, col) in row.iter().enumerate().take(row.len()) { if j == skip_col { continue; } r.push(*col); } m.push(r); } Matrix::from_vec(m) } pub fn is_invertable(&self) -> bool { self.determinant() != 0.0 } pub fn inverse(&self) -> Matrix { // seems dangerous if !self.is_invertable() { panic!("We can't invert {:?}", self.matrix); } //let mut matrix: Vec> = Vec::with_capacity(self.matrix.len()); let mut matrix = Matrix::default(self.matrix.len(), self.matrix[0].len()); let det = self.determinant(); for (row_idx, row) in self.matrix.iter().enumerate().take(self.matrix.len()) { for (col_idx, _) in row.iter().enumerate().take(row.len()) { let c = self.cofactor(row_idx, col_idx); let val = c / det; matrix[col_idx][row_idx] = val; } } matrix } } impl Index for Matrix { type Output = Vec; fn index(&self, index: usize) -> &Self::Output { &self.matrix[index] } } impl IndexMut for Matrix { fn index_mut(&mut self, index: usize) -> &mut Self::Output { &mut self.matrix[index] } } impl PartialEq for Matrix { fn eq(&self, _rhs: &Self) -> bool { if self.matrix.len() != _rhs.matrix.len() { return false; } for row_idx in 0..self.matrix.len() { if self.matrix[row_idx].len() != _rhs.matrix[row_idx].len() { return false; } for col_idx in 0..self.matrix[row_idx].len() { if !relative_eq!( self.matrix[row_idx][col_idx], _rhs.matrix[row_idx][col_idx]) { return false; } } } true } } impl Matrix { fn calc_val_for_mul(&self, row: usize, rhs: &Matrix, col: usize) -> f32 { let mut sum = 0.0; for i in 0..self.matrix.len() { sum += self.matrix[row][i] * rhs.matrix[i][col]; } sum } fn calc_val_for_mul_tuple(&self, row: usize, tuple: &Tuple) -> f32 { (self.matrix[row][0] * tuple.x()) + (self.matrix[row][1] * tuple.y()) + (self.matrix[row][2] * tuple.z()) + (self.matrix[row][3] * tuple.w()) } } impl std::ops::Mul<&Matrix> for &Matrix { type Output = Matrix; fn mul(self, _rhs: &Matrix) -> Matrix { let mut result: Vec> = Vec::with_capacity(self.matrix.len()); for row in 0..self.matrix.len() { let width = self.matrix[row].len(); let mut new_col = Vec::with_capacity(width); for col in 0..width { new_col.push( self.calc_val_for_mul(row, &_rhs, col)); } result.push(new_col); } Matrix::from_vec(result) } } impl std::ops::Mul<&Tuple> for &Matrix { type Output = Tuple; fn mul(self, _rhs: &Tuple) -> Tuple { Tuple::new( self.calc_val_for_mul_tuple(0, &_rhs), self.calc_val_for_mul_tuple(1, &_rhs), self.calc_val_for_mul_tuple(2, &_rhs), self.calc_val_for_mul_tuple(3, &_rhs), ) } } impl std::ops::Mul<&Matrix> for &Tuple { type Output = Tuple; fn mul(self, rhs: &Matrix) -> Tuple { Tuple::new( rhs.calc_val_for_mul_tuple(0, &self), rhs.calc_val_for_mul_tuple(1, &self), rhs.calc_val_for_mul_tuple(2, &self), rhs.calc_val_for_mul_tuple(3, &self), ) } } #[cfg(test)] mod tests { use super::*; #[test] fn matrix_4x4() { let m = [ [1.0, 2.0, 3.0, 4.0], [5.5, 6.5, 7.5, 8.5], [9.0, 10.0, 11.0, 12.0], [13.5, 14.5, 15.5, 16.5], ]; let matrix = Matrix::from_array(m); assert_eq!(1.0, matrix[0][0]); assert_eq!(4.0, matrix[0][3]); assert_eq!(5.5, matrix[1][0]); assert_eq!(7.5, matrix[1][2]); assert_eq!(11.0, matrix[2][2]); assert_eq!(13.5, matrix[3][0]); assert_eq!(15.5, matrix[3][2]); } #[test] fn matrix_4x4_array() { let m = [ [1.0, 2.0, 3.0, 4.0], [5.5, 6.5, 7.5, 8.5], [9.0, 10.0, 11.0, 12.0], [13.5, 14.5, 15.5, 16.5], ]; let matrix = Matrix::from_array(m); assert_eq!(1.0, matrix[0][0]); assert_eq!(4.0, matrix[0][3]); assert_eq!(5.5, matrix[1][0]); assert_eq!(7.5, matrix[1][2]); assert_eq!(11.0, matrix[2][2]); assert_eq!(13.5, matrix[3][0]); assert_eq!(15.5, matrix[3][2]); } #[test] fn matrix_2x2() { let m = [ [-3, 5,], [1, 2,], ]; let matrix = Matrix::from_array(m); assert_eq!(-3.0, matrix[0][0]); assert_eq!(5.0, matrix[0][1]); assert_eq!(1.0, matrix[1][0]); assert_eq!(2.0, matrix[1][1]); } #[test] fn matrix_3x3() { let m = [ [-3, 5, 0], [1, -2, -7], [0, 1, 1], ]; let matrix = Matrix::from_array(m); assert_eq!(-3.0, matrix[0][0]); assert_eq!(-2.0, matrix[1][1]); assert_eq!(1.0, matrix[2][2]); } #[test] fn matrix_equality_a() { let a = [ [1, 2, 3, 4], [5, 6, 7, 8], [9, 8, 7, 6], [5, 4, 3, 2], ]; let m_a = Matrix::from_array(a); let b = [ [1, 2, 3, 4], [5, 6, 7, 8], [9, 8, 7, 6], [5, 4, 3, 2], ]; let m_b = Matrix::from_array(b); assert_eq!(m_a, m_b); } #[test] fn matrix_equality_b() { let a = [ [1, 2, 3, 4], [5, 6, 7, 8], [9, 8, 7, 6], [5, 4, 3, 2], ]; let m_a = Matrix::from_array(a); let b = [ [2, 3, 4, 5], [6, 7, 8, 9], [8, 7, 6, 5], [4, 3, 2, 1], ]; let m_b = Matrix::from_array(b); assert_ne!(m_a, m_b); } #[test] fn multiply() { let matrix_a = Matrix::from_array([ [1, 2, 3, 4,], [5, 6, 7, 8,], [9, 8, 7, 6,], [5, 4, 3, 2,], ]); let matrix_b = Matrix::from_array([ [-2, 1, 2, 3,], [3, 2, 1, -1,], [4, 3, 6, 5,], [1, 2, 7, 8,], ]); let expected = Matrix::from_array([ [20, 22, 50, 48], [44, 54, 114, 108], [40, 58, 110, 102,], [16, 26, 46, 42], ]); assert_eq!(&matrix_a * &matrix_b, expected); } #[test] fn multiply_by_tuple() { let matrix = Matrix::from_array([ [1, 2, 3, 4], [2, 4, 4, 2], [8, 6, 4, 1], [0, 0, 0, 1], ]); let tuple = Tuple::new(1, 2, 3, 1); let expected = Tuple::new(18, 24, 33, 1); assert_eq!(&matrix * &tuple, expected); } #[test] fn multiply_by_tuple_reverse() { let matrix = Matrix::from_array([ [1, 2, 3, 4], [2, 4, 4, 2], [8, 6, 4, 1], [0, 0, 0, 1], ]); let tuple = Tuple::new(1, 2, 3, 1); let expected = Tuple::new(18, 24, 33, 1); assert_eq!(&tuple * &matrix, expected); } #[test] fn matrix_by_identity() { let matrix = Matrix::from_array([ [0, 1, 2, 4,], [1, 2, 4, 8,], [2, 4, 8, 16], [4, 8, 16, 32,] ]); let expected = Matrix::from_array([ [0, 1, 2, 4,], [1, 2, 4, 8,], [2, 4, 8, 16], [4, 8, 16, 32,] ]); assert_eq!(&matrix * &Matrix::identity(4), expected); } #[test] fn tuple_by_identity() { let t = Tuple::new(1, 2, 3, 4); let expected = Tuple::new(1, 2, 3, 4); assert_eq!(&Matrix::identity(4) * &t, expected); } #[test] fn transposition() { let mut m = Matrix::from_array([ [0, 9, 3, 0], [9, 8, 0, 8], [1, 8, 5, 3], [0, 0, 5, 8], ]); let expected = Matrix::from_array([ [0, 9, 1, 0], [9, 8, 8, 0], [3, 0, 5, 5], [0, 8, 3, 8], ]); m.transpose(); assert_eq!(m, expected); } #[test] fn transpose_identity() { let mut m = Matrix::identity(4); m.transpose(); assert_eq!(m, Matrix::identity(4)); } #[test] fn determinant_2x2() { let m = Matrix::from_array([ [1, 5], [-3, 2], ]); assert_eq!(17.0, m.determinant()); } #[test] fn submatrix_3x3() { let start = Matrix::from_array([ [1, 5, 0], [-3, 2, 7], [0, 6, -3], ]); let expected = Matrix::from_array([ [-3, 2], [0, 6], ]); assert_eq!(expected, start.sub_matrix(0, 2)); } #[test] fn submatrix_4x4() { let start = Matrix::from_array([ [-6, 1, 1, 6], [-8, 5, 8, 6], [-1, 0, 8, 2], [-7, 1, -1, 1], ]); let expected = Matrix::from_array([ [-6, 1, 6], [-8, 8, 6], [-7, -1, 1], ]); assert_eq!(expected, start.sub_matrix(2, 1)); } #[test] fn minor_3x3() { let m = Matrix::from_array([ [3, 5, 0], [2, -1, -7], [6, -1, 5], ]); let s = m.sub_matrix(1, 0); assert_eq!(25.0, s.determinant()); assert_eq!(25.0, m.minor(1, 0)); } #[test] fn cofactor_3x3() { let m = Matrix::from_array([ [3, 5, 0], [2, -1, -7], [6, -1, 5], ]); assert_eq!(-12.0, m.minor(0, 0)); assert_eq!(-12.0, m.cofactor(0, 0)); assert_eq!(25.0, m.minor(1, 0)); assert_eq!(-25.0, m.cofactor(1, 0)); } #[test] fn determinant_3x3() { let m = Matrix::from_array([ [1, 2, 6], [-5, 8, -4], [2, 6, 4], ]); assert_eq!(56.0, m.cofactor(0, 0)); assert_eq!(12.0, m.cofactor(0, 1)); assert_eq!(-46.0, m.cofactor(0, 2)); assert_eq!(-196.0, m.determinant()); } #[test] fn determinant_4x4() { let m = Matrix::from_array([ [-2, -8, 3, 5], [-3, 1, 7, 3], [1, 2, -9, 6], [-6, 7, 7, -9], ]); assert_eq!(690.0, m.cofactor(0, 0)); assert_eq!(447.0, m.cofactor(0, 1)); assert_eq!(210.0, m.cofactor(0, 2)); assert_eq!(51.0, m.cofactor(0, 3)); assert_eq!(-4071.0, m.determinant()); } #[test] fn can_invert_invertable() { let m = Matrix::from_array([ [6, 4, 4, 4], [5, 5, 7, 6], [4, -9, 3, -7], [9, 1, 7, -6], ]); assert_eq!(-2120.0, m.determinant()); assert_eq!(true, m.is_invertable()); } #[test] fn can_invert_not_invertable() { let m = Matrix::from_array([ [-4, 2, -2, -3], [9, 6, 2, 6], [0, -5, 1, -5], [0, 0, 0, 0], ]); assert_eq!(0.0, m.determinant()); assert_eq!(false, m.is_invertable()); } pub fn assert_matrix_eq(_lhs: &Matrix, _rhs: &Matrix, max_relative: f32) -> bool { if _lhs.matrix.len() != _rhs.matrix.len() { return false; } for row_idx in 0.._lhs.matrix.len() { if _lhs.matrix[row_idx].len() != _rhs.matrix[row_idx].len() { return false; } for col_idx in 0.._lhs.matrix[row_idx].len() { assert_relative_eq!( _lhs.matrix[row_idx][col_idx], _rhs.matrix[row_idx][col_idx], max_relative = max_relative); } } true } #[test] fn inverse() { let m = Matrix::from_array::([ [-5, 2, 6, -8], [1, -5, 1, 8], [7, 7, -6, -7], [1, -3, 7, 4], ]); let b = m.inverse(); assert_eq!(532.0, m.determinant()); assert_eq!(-160.0, m.cofactor(2, 3)); assert_eq!(-160.0/532.0, b[3][2]); assert_eq!(105.0, m.cofactor(3, 2)); assert_eq!(105.0/532.0, b[2][3]); let expected = Matrix::from_array::([ [0.21805, 0.45113, 0.24060, -0.04511], [-0.80827, -1.45677, -0.44361, 0.52068], [-0.07895, -0.22368, -0.05263, 0.19737], [-0.52256, -0.81392, -0.30075, 0.30639], ]); assert_matrix_eq(&expected, &b, 0.0001); } #[test] fn inverse_2() { let m = Matrix::from_array([ [8, -5, 9, 2], [7, 5, 6, 1], [-6, 0, 9, 6], [-3, 0, -9, -4], ]).inverse(); let expected = Matrix::from_array([ [-0.15385, -0.15385, -0.28205, -0.53846], [-0.07692, 0.12308, 0.02564, 0.03077], [0.35897, 0.35897, 0.43590, 0.92308], [-0.69321, -0.69321, -0.76923, -1.92308], ]); assert_matrix_eq(&expected, &m, 0.01); } #[test] fn inverse_3() { let m = Matrix::from_array([ [9, 3, 0, 9], [-5, -2, -6, -3], [-4, 9, 6, 4], [-7, 6, 6, 2], ]).inverse(); let expected = Matrix::from_array([ [-0.04074, -0.07778, 0.14444, -0.22222], [-0.07778, 0.03333, 0.36667, -0.33333], [-0.02901, -0.14630, -0.10926, 0.12963], [0.17778, 0.06667, -0.26667, 0.33333], ]); assert_matrix_eq(&expected, &m, 0.01); } #[test] fn multiply_by_inverse() { let a = Matrix::from_array([ [3, -9, 7, 3], [3, -8, 2, -9], [-4, 4, 4, 1], [-6, 5, -1, 1], ]); let b = Matrix::from_array([ [8, 2, 2, 2], [3, -1, 7, 0], [7, 0, 5, 4], [6, -2, 0, 5], ]); let c = &a * &b; let r = &c * &b.inverse(); let expected = Matrix::from_array([ [3, -9, 7, 3], [3, -8, 2, -9], [-4, 4, 4, 1], [-6, 5, -1, 1], ]); assert_matrix_eq(&r, &expected, 0.00001); } }