#[macro_use] extern crate approx; use structs::Tuple; use std::ops::Index; #[derive(Debug)] pub struct Matrix { matrix: Vec>, } impl Matrix { pub fn default(width: usize, height: usize) -> Self { Matrix { matrix: vec![vec![0.0f32; width]; height], } } pub fn from_array(array: [[f32; W]; H]) -> Matrix { let mut matrix: Vec> = Vec::with_capacity(H); for r in array.iter() { let mut row: Vec = Vec::with_capacity(W); for v in r.iter() { row.push(*v); } matrix.push(row); } Matrix { matrix, } } pub fn from_vec(matrix: Vec>) -> Matrix { Matrix { matrix, } } pub fn identity(size: usize) -> Matrix { let mut m = Self::default(size, size); for i in 0..m.matrix.len() { m.matrix[i][i] = 1.0; } m } pub fn transpose(&mut self) { for i in 0..self.matrix.len() { for j in i..self.matrix[0].len() { let v = self.matrix[i][j]; self.matrix[i][j] = self.matrix[j][i]; self.matrix[j][i] = v; } } } pub fn determinant(&self) -> f32 { if self.matrix[0].len() == 2 { self.matrix[0][0] * self.matrix[1][1] - self.matrix[0][1] * self.matrix[1][0] } else { let mut sum = 0.0; for (col, val) in self.matrix[0].iter().enumerate().take(self.matrix[0].len()) { sum += val * self.cofactor(0, col); } sum } } pub fn minor(&self, row: usize, col: usize) -> f32 { let m = self.sub_matrix(row, col); let det = m.determinant(); det } pub fn cofactor(&self, row: usize, col: usize) -> f32 { let minor = self.minor(row, col); if (row + col) & 0x1 == 0 { minor } else { minor * -1.0 } } pub fn sub_matrix(&self, skip_row: usize, skip_col: usize) -> Matrix { let mut m = Vec::>::with_capacity(self.matrix.len() - 1); for (i, row) in self.matrix.iter().enumerate().take(self.matrix.len()) { if i == skip_row { continue; } let mut r = Vec::::with_capacity(row.len() - 1); for (j, col) in row.iter().enumerate().take(row.len()) { if j == skip_col { continue; } r.push(*col); } m.push(r); } Matrix::from_vec(m) } } impl Index for Matrix { type Output = Vec; fn index(&self, index: usize) -> &Self::Output { &self.matrix[index] } } impl PartialEq for Matrix { fn eq(&self, _rhs: &Self) -> bool { if self.matrix.len() != _rhs.matrix.len() { return false; } for row_idx in 0..self.matrix.len() { if self.matrix[row_idx].len() != _rhs.matrix[row_idx].len() { return false; } for col_idx in 0..self.matrix[row_idx].len() { if !relative_eq!(self.matrix[row_idx][col_idx], _rhs.matrix[row_idx][col_idx]) { return false; } } } true } } impl Matrix { fn calc_val_for_mul(&self, row: usize, rhs: &Matrix, col: usize) -> f32 { let mut sum = 0.0; for i in 0..self.matrix.len() { sum += self.matrix[row][i] * rhs.matrix[i][col]; } sum } fn calc_val_for_mul_tuple(&self, row: usize, tuple: &Tuple) -> f32 { (self.matrix[row][0] * tuple.x()) + (self.matrix[row][1] * tuple.y()) + (self.matrix[row][2] * tuple.z()) + (self.matrix[row][3] * tuple.w()) } } impl std::ops::Mul for Matrix { type Output = Matrix; fn mul(self, _rhs: Matrix) -> Matrix { let mut result: Vec> = Vec::with_capacity(self.matrix.len()); for row in 0..self.matrix.len() { let width = self.matrix[row].len(); let mut new_col = Vec::with_capacity(width); for col in 0..width { new_col.push( self.calc_val_for_mul(row, &_rhs, col)); } result.push(new_col); } Matrix::from_vec(result) } } impl std::ops::Mul for Matrix { type Output = Tuple; fn mul(self, _rhs: Tuple) -> Tuple { Tuple::new( self.calc_val_for_mul_tuple(0, &_rhs), self.calc_val_for_mul_tuple(1, &_rhs), self.calc_val_for_mul_tuple(2, &_rhs), self.calc_val_for_mul_tuple(3, &_rhs), ) } } #[cfg(test)] mod tests { use super::*; #[test] fn matrix_4x4() { let m = [ [1.0, 2.0, 3.0, 4.0], [5.5, 6.5, 7.5, 8.5], [9.0, 10.0, 11.0, 12.0], [13.5, 14.5, 15.5, 16.5], ]; let matrix = Matrix::from_array(m); assert_eq!(1.0, matrix[0][0]); assert_eq!(4.0, matrix[0][3]); assert_eq!(5.5, matrix[1][0]); assert_eq!(7.5, matrix[1][2]); assert_eq!(11.0, matrix[2][2]); assert_eq!(13.5, matrix[3][0]); assert_eq!(15.5, matrix[3][2]); } #[test] fn matrix_4x4_array() { let m = [ [1.0, 2.0, 3.0, 4.0], [5.5, 6.5, 7.5, 8.5], [9.0, 10.0, 11.0, 12.0], [13.5, 14.5, 15.5, 16.5], ]; let matrix = Matrix::from_array(m); assert_eq!(1.0, matrix[0][0]); assert_eq!(4.0, matrix[0][3]); assert_eq!(5.5, matrix[1][0]); assert_eq!(7.5, matrix[1][2]); assert_eq!(11.0, matrix[2][2]); assert_eq!(13.5, matrix[3][0]); assert_eq!(15.5, matrix[3][2]); } #[test] fn matrix_2x2() { let m = [ [-3.0, 5.0,], [1.0, 2.0,], ]; let matrix = Matrix::from_array(m); assert_eq!(-3.0, matrix[0][0]); assert_eq!(5.0, matrix[0][1]); assert_eq!(1.0, matrix[1][0]); assert_eq!(2.0, matrix[1][1]); } #[test] fn matrix_3x3() { let m = [ [-3.0, 5.0, 0.0], [1.0, -2.0, -7.0], [0.0, 1.0, 1.0], ]; let matrix = Matrix::from_array(m); assert_eq!(-3.0, matrix[0][0]); assert_eq!(-2.0, matrix[1][1]); assert_eq!(1.0, matrix[2][2]); } #[test] fn matrix_equality_a() { let a = [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 8.0, 7.0, 6.0], [5.0, 4.0, 3.0, 2.0], ]; let m_a = Matrix::from_array(a); let b = [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 8.0, 7.0, 6.0], [5.0, 4.0, 3.0, 2.0], ]; let m_b = Matrix::from_array(b); assert_eq!(m_a, m_b); } #[test] fn matrix_equality_b() { let a = [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 8.0, 7.0, 6.0], [5.0, 4.0, 3.0, 2.0], ]; let m_a = Matrix::from_array(a); let b = [ [2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0], [8.0, 7.0, 6.0, 5.0], [4.0, 3.0, 2.0, 1.0], ]; let m_b = Matrix::from_array(b); assert_ne!(m_a, m_b); } #[test] fn multiply() { let matrix_a = Matrix::from_array([ [1.0, 2.0, 3.0, 4.0,], [5.0, 6.0, 7.0, 8.0,], [9.0, 8.0, 7.0, 6.0,], [5.0, 4.0, 3.0, 2.0,], ]); let matrix_b = Matrix::from_array([ [-2.0, 1.0, 2.0, 3.0,], [3.0, 2.0, 1.0, -1.0,], [4.0, 3.0, 6.0, 5.0,], [1.0, 2.0, 7.0, 8.0,], ]); let expected = Matrix::from_array([ [20.0, 22.0, 50.0, 48.0], [44.0, 54.0, 114.0, 108.0], [40.0, 58.0, 110.0, 102.0,], [16.0, 26.0, 46.0, 42.0], ]); assert_eq!(matrix_a * matrix_b, expected); } #[test] fn multiply_by_tuple() { let matrix = Matrix::from_array([ [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 4.0, 2.0], [8.0, 6.0, 4.0, 1.0], [0.0, 0.0, 0.0, 1.0], ]); let tuple = Tuple::new(1.0, 2.0, 3.0, 1.0); let expected = Tuple::new(18.0, 24.0, 33.0, 1.0); assert_eq!(matrix * tuple, expected); } #[test] fn matrix_by_identity() { let matrix = Matrix::from_array([ [0.0, 1.0, 2.0, 4.0,], [1.0, 2.0, 4.0, 8.0,], [2.0, 4.0, 8.0, 16.0], [4.0, 8.0, 16.0, 32.0,] ]); let expected = Matrix::from_array([ [0.0, 1.0, 2.0, 4.0,], [1.0, 2.0, 4.0, 8.0,], [2.0, 4.0, 8.0, 16.0], [4.0, 8.0, 16.0, 32.0,] ]); assert_eq!(matrix * Matrix::identity(4), expected); } #[test] fn tuple_by_identity() { let t = Tuple::new(1.0, 2.0, 3.0, 4.0); let expected = Tuple::new(1.0, 2.0, 3.0, 4.0); assert_eq!(Matrix::identity(4) * t, expected); } #[test] fn transposition() { let mut m = Matrix::from_array([ [0.0, 9.0, 3.0, 0.0], [9.0, 8.0, 0.0, 8.0], [1.0, 8.0, 5.0, 3.0], [0.0, 0.0, 5.0, 8.0], ]); let expected = Matrix::from_array([ [0.0, 9.0, 1.0, 0.0], [9.0, 8.0, 8.0, 0.0], [3.0, 0.0, 5.0, 5.0], [0.0, 8.0, 3.0, 8.0], ]); m.transpose(); assert_eq!(m, expected); } #[test] fn transpose_identity() { let mut m = Matrix::identity(4); m.transpose(); assert_eq!(m, Matrix::identity(4)); } #[test] fn determinant_2x2() { let m = Matrix::from_array([ [1.0, 5.0], [-3.0, 2.0], ]); assert_eq!(17.0, m.determinant()); } #[test] fn submatrix_3x3() { let start = Matrix::from_array([ [1.0, 5.0, 0.0], [-3.0, 2.0, 7.0], [0.0, 6.0, -3.0], ]); let expected = Matrix::from_array([ [-3.0, 2.0], [0.0, 6.0], ]); assert_eq!(expected, start.sub_matrix(0, 2)); } #[test] fn submatrix_4x4() { let start = Matrix::from_array([ [-6.0, 1.0, 1.0, 6.0], [-8.0, 5.0, 8.0, 6.0], [-1.0, 0.0, 8.0, 2.0], [-7.0, 1.0, -1.0, 1.0], ]); let expected = Matrix::from_array([ [-6.0, 1.0, 6.0], [-8.0, 8.0, 6.0], [-7.0, -1.0, 1.0], ]); assert_eq!(expected, start.sub_matrix(2, 1)); } #[test] fn minor_3x3() { let m = Matrix::from_array([ [3.0, 5.0, 0.0], [2.0, -1.0, -7.0], [6.0, -1.0, 5.0], ]); let s = m.sub_matrix(1, 0); assert_eq!(25.0, s.determinant()); assert_eq!(25.0, m.minor(1, 0)); } #[test] fn cofactor_3x3() { let m = Matrix::from_array([ [3.0, 5.0, 0.0], [2.0, -1.0, -7.0], [6.0, -1.0, 5.0], ]); assert_eq!(-12.0, m.minor(0, 0)); assert_eq!(-12.0, m.cofactor(0, 0)); assert_eq!(25.0, m.minor(1, 0)); assert_eq!(-25.0, m.cofactor(1, 0)); } #[test] fn determinant_3x3() { let m = Matrix::from_array([ [1.0, 2.0, 6.0], [-5.0, 8.0, -4.0], [2.0, 6.0, 4.0], ]); assert_eq!(56.0, m.cofactor(0, 0)); assert_eq!(12.0, m.cofactor(0, 1)); assert_eq!(-46.0, m.cofactor(0, 2)); assert_eq!(-196.0, m.determinant()); } #[test] fn determinant_4x4() { let m = Matrix::from_array([ [-2.0, -8.0, 3.0, 5.0], [-3.0, 1.0, 7.0, 3.0], [1.0, 2.0, -9.0, 6.0], [-6.0, 7.0, 7.0, -9.0], ]); assert_eq!(690.0, m.cofactor(0, 0)); assert_eq!(447.0, m.cofactor(0, 1)); assert_eq!(210.0, m.cofactor(0, 2)); assert_eq!(51.0, m.cofactor(0, 3)); assert_eq!(-4071.0, m.determinant()); } }