allow rectangular matrixes

This commit is contained in:
Jon Janzen
2021-03-31 15:38:00 -06:00
parent 82360600fe
commit 9c20d13833

View File

@@ -6,24 +6,24 @@ use structs::Tuple;
use std::ops::Index; use std::ops::Index;
#[derive(Debug)] #[derive(Debug)]
pub struct Matrix<const COUNT: usize> { pub struct Matrix<const H: usize, const W: usize> {
matrix: [[f32; COUNT]; COUNT], matrix: [[f32; W]; H],
} }
impl<const COUNT: usize> Matrix<COUNT> { impl<const H: usize, const W: usize> Matrix<H, W> {
pub fn new() -> Matrix<COUNT> { pub fn new() -> Matrix<H, W> {
Matrix { Matrix {
matrix: [[0f32; COUNT]; COUNT], matrix: [[0f32; W]; H],
} }
} }
pub fn from_array(matrix: [[f32; COUNT]; COUNT]) -> Matrix<COUNT> { pub fn from_array(matrix: [[f32; W]; H]) -> Matrix<H, W> {
Matrix { Matrix {
matrix: matrix, matrix: matrix,
} }
} }
pub fn identity() -> Matrix<COUNT> { pub fn identity() -> Matrix<H, W> {
let mut m = Matrix::new(); let mut m = Matrix::new();
for i in 0..m.matrix.len() { for i in 0..m.matrix.len() {
m.matrix[i][i] = 1.0; m.matrix[i][i] = 1.0;
@@ -46,14 +46,14 @@ impl<const COUNT: usize> Matrix<COUNT> {
} }
} }
impl<const COUNT: usize> Index<usize> for Matrix<COUNT> { impl<const H: usize, const W: usize> Index<usize> for Matrix<H, W> {
type Output = [f32; COUNT]; type Output = [f32; W];
fn index(&self, index: usize) -> &Self::Output { fn index(&self, index: usize) -> &Self::Output {
&self.matrix[index] &self.matrix[index]
} }
} }
impl<const COUNT: usize> PartialEq for Matrix<COUNT> { impl<const H: usize, const W: usize> PartialEq for Matrix<H, W> {
fn eq(&self, _rhs: &Self) -> bool { fn eq(&self, _rhs: &Self) -> bool {
if self.matrix.len() != _rhs.matrix.len() { if self.matrix.len() != _rhs.matrix.len() {
return false; return false;
@@ -73,10 +73,10 @@ impl<const COUNT: usize> PartialEq for Matrix<COUNT> {
} }
} }
impl<const COUNT: usize> Matrix<COUNT> { impl<const H: usize, const W: usize> Matrix<H, W> {
fn calc_val_for_mul(&self, row: usize, rhs: &Matrix<COUNT>, col: usize) -> f32 { fn calc_val_for_mul(&self, row: usize, rhs: &Matrix<H, W>, col: usize) -> f32 {
let mut sum = 0.0; let mut sum = 0.0;
for i in 0..self.matrix.len() { for i in 0..W {
sum += self.matrix[row][i] * rhs.matrix[i][col]; sum += self.matrix[row][i] * rhs.matrix[i][col];
} }
sum sum
@@ -90,13 +90,13 @@ impl<const COUNT: usize> Matrix<COUNT> {
} }
} }
impl<const COUNT: usize> std::ops::Mul<Matrix<COUNT>> for Matrix<COUNT> { impl<const H: usize, const W: usize> std::ops::Mul<Matrix<H, W>> for Matrix<H, W> {
type Output = Matrix<COUNT>; type Output = Matrix<H, W>;
fn mul(self, _rhs: Matrix<COUNT>) -> Matrix<COUNT> { fn mul(self, _rhs: Matrix<H, W>) -> Matrix<H, W> {
let mut result = [[0f32; COUNT]; COUNT]; let mut result = [[0f32; W]; H];
for row in 0..COUNT { for row in 0..H {
for col in 0..COUNT { for col in 0..W {
result[row][col] = self.calc_val_for_mul(row, &_rhs, col); result[row][col] = self.calc_val_for_mul(row, &_rhs, col);
} }
} }
@@ -105,7 +105,7 @@ impl<const COUNT: usize> std::ops::Mul<Matrix<COUNT>> for Matrix<COUNT> {
} }
impl<const COUNT: usize> std::ops::Mul<Tuple> for Matrix<COUNT> { impl<const H: usize, const W: usize> std::ops::Mul<Tuple> for Matrix<H, W> {
type Output = Tuple; type Output = Tuple;
fn mul(self, _rhs: Tuple) -> Tuple { fn mul(self, _rhs: Tuple) -> Tuple {
@@ -296,7 +296,7 @@ mod tests {
let t = Tuple::new(1.0, 2.0, 3.0, 4.0); let t = Tuple::new(1.0, 2.0, 3.0, 4.0);
let expected = Tuple::new(1.0, 2.0, 3.0, 4.0); let expected = Tuple::new(1.0, 2.0, 3.0, 4.0);
assert_eq!(Matrix::<4>::identity() * t, expected); assert_eq!(Matrix::<4, 4>::identity() * t, expected);
} }
#[test] #[test]
@@ -322,7 +322,7 @@ mod tests {
fn transpose_identity() { fn transpose_identity() {
let mut m = Matrix::identity(); let mut m = Matrix::identity();
m.transpose(); m.transpose();
assert_eq!(m, Matrix::<4>::identity()); assert_eq!(m, Matrix::<4, 4>::identity());
} }
#[test] #[test]