1 /**
2 Matrix module.
3 */
4 module karasux.linear_algebra.matrix;
5 
6 import std.algorithm : swap;
7 import std.math : cos, sin;
8 import std.traits : isNumeric;
9 
10 import karasux.linear_algebra.vector : Vector;
11 
12 @safe:
13 
/**
Matrix structure.

Elements are stored in column-major order (`elements_[column][row]`);
`ptr` exposes that storage directly.

Params:
    ROWS = matrix rows.
    COLS = matrix columns.
    E = element type.
*/
struct Matrix(size_t ROWS, size_t COLS, E = float)
{
    static assert(ROWS > 0);
    static assert(COLS > 0);
    static assert(isNumeric!E);

    /**
    Initialize by row major elements.

    Params:
        elements = matrix row major elements.
    Returns:
        initialized matrix.
    */
    static typeof(this) fromRows(scope const(E)[COLS][ROWS] elements)
    {
        auto m = typeof(this)();
        foreach (j; 0 .. COLS)
        {
            foreach (i; 0 .. ROWS)
            {
                // Transpose the row-major argument into column-major storage.
                m.elements_[j][i] = elements[i][j];
            }
        }
        return m;
    }

    static if (COLS == ROWS)
    {
        /**
        Initialize unit (identity) matrix.

        Returns:
            unit matrix.
        */
        static typeof(this) unit()
        {
            auto m = typeof(this)();
            foreach (j; 0 .. COLS)
            {
                foreach (i; 0 .. ROWS)
                {
                    m.elements_[j][i] = cast(E)((i == j) ? 1 : 0);
                }
            }
            return m;
        }

        /**
        Create scale matrix.

        The first `COLS - 1` diagonal elements are set to `factors`, the
        last diagonal element is 1 and every other element is 0.

        Params:
            factors = scale factors (exactly `COLS - 1` values).
        Returns:
            scale matrix.
        */
        static typeof(this) scale(Factors...)(Factors factors)
        {
            static assert(factors.length == COLS - 1);
            auto m = typeof(this)();
            m.fill(cast(E) 0);
            foreach (i, f; factors)
            {
                m[i, i] = cast(E) f;
            }
            m[ROWS - 1, COLS - 1] = cast(E) 1.0;
            return m;
        }

        /**
        Create translate matrix.

        Starts from the unit matrix and stores `factors` in the first
        `COLS - 1` rows of the last column.

        Params:
            factors = translate factors (exactly `COLS - 1` values).
        Returns:
            translate matrix.
        */
        static typeof(this) translate(Factors...)(Factors factors)
        {
            static assert(factors.length == COLS - 1);

            auto m = typeof(this).unit();
            foreach (i, f; factors)
            {
                m[i, COLS - 1] = cast(E) f;
            }
            return m;
        }

        // Rotations assign floating point cos/sin results to elements, so
        // they are only provided for floating point element types. Without
        // this guard, e.g. Matrix!(4, 4, int) fails to compile because these
        // non-template members are compiled with the struct.
        static if (ROWS == 4 && COLS == 4 && __traits(isFloating, E))
        {
            /**
            Create rotation matrix about the X axis.

            Params:
                theta = rotation angle in radians.
            Returns:
                rotation matrix.
            */
            static typeof(this) rotateX(E theta)
            {
                immutable c = cos(theta);
                immutable s = sin(theta);
                auto m = typeof(this).unit();
                m[1, 1] = c;
                m[1, 2] = -s;
                m[2, 1] = s;
                m[2, 2] = c;
                return m;
            }

            /**
            Create rotation matrix about the Y axis.

            Params:
                theta = rotation angle in radians.
            Returns:
                rotation matrix.
            */
            static typeof(this) rotateY(E theta)
            {
                immutable c = cos(theta);
                immutable s = sin(theta);
                auto m = typeof(this).unit();
                m[0, 0] = c;
                m[0, 2] = s;
                m[2, 0] = -s;
                m[2, 2] = c;
                return m;
            }

            /**
            Create rotation matrix about the Z axis.

            Params:
                theta = rotation angle in radians.
            Returns:
                rotation matrix.
            */
            static typeof(this) rotateZ(E theta)
            {
                immutable c = cos(theta);
                immutable s = sin(theta);
                auto m = typeof(this).unit();
                m[0, 0] = c;
                m[0, 1] = -s;
                m[1, 0] = s;
                m[1, 1] = c;
                return m;
            }
        }
    }

    @property const scope
    {
        /// Number of rows.
        size_t rows() { return ROWS; }

        /// Number of columns.
        size_t columns() { return COLS; }
    }

    /**
    Get an element.

    Params:
        i = row index.
        j = column index.
    Returns:
        element value.
    */
    ref const(E) opIndex(size_t i, size_t j) const return scope
    in (i < ROWS)
    in (j < COLS)
    {
        return elements_[j][i];
    }

    /**
    Set an element.

    Params:
        value = element value.
        i = row index.
        j = column index.
    Returns:
        assigned element value.
    */
    ref const(E) opIndexAssign()(auto ref const(E) value, size_t i, size_t j) return scope
    in (i < ROWS)
    in (j < COLS)
    {
        return elements_[j][i] = value;
    }

    /**
    Operation and assign an element.

    Params:
        op = operator.
        value = element value.
        i = row index.
        j = column index.
    Returns:
        assigned element value.
    */
    ref const(E) opIndexOpAssign(string op)(auto ref const(E) value, size_t i, size_t j) return scope
    in (i < ROWS)
    in (j < COLS)
    {
        return mixin("elements_[j][i] " ~ op ~ "= value");
    }

    /**
    Element-wise operation and assign other matrix.

    Params:
        op = operator applied to each pair of corresponding elements.
        value = other matrix value.
    Returns:
        this matrix.
    */
    ref typeof(this) opOpAssign(string op)(auto ref const(typeof(this)) value) return scope
    {
        foreach (j, ref column; elements_)
        {
            foreach (i, ref v; column)
            {
                mixin("v " ~ op ~ "= value[i, j];");
            }
        }
        return this;
    }

    /**
    Matrix multiplication.

    Stores `lhs * rhs` into this matrix. The product is accumulated in a
    temporary before being stored, so this matrix may itself be passed as
    `lhs` or `rhs` (e.g. `m.mul(m, n)`).

    Params:
        lhs = left hand side matrix.
        rhs = right hand side matrix.
    Returns:
        calculated this matrix.
    */
    ref typeof(this) mul(size_t N, E1, E2)(
            auto ref const(Matrix!(ROWS, N, E1)) lhs,
            auto ref const(Matrix!(N, COLS, E2)) rhs) return scope
    {
        // Writing directly into elements_ would corrupt the result when
        // `this` aliases lhs or rhs, because later dot products would read
        // already-overwritten elements.
        E[ROWS][COLS] product;
        foreach (j; 0 .. COLS)
        {
            foreach (i; 0 .. ROWS)
            {
                E sum = cast(E) 0;
                foreach (k; 0 .. N)
                {
                    sum += lhs[i, k] * rhs[k, j];
                }
                product[j][i] = sum;
            }
        }
        elements_ = product;
        return this;
    }

    /**
    Fill elements.

    Params:
        value = filler value.
    Returns:
        this matrix.
    */
    ref typeof(this) fill()(auto ref const(E) value) return scope
    {
        foreach (ref column; elements_)
        {
            column[] = value;
        }
        return this;
    }

    /**
    Matrix pointer.

    Returns:
        pointer to the first element of the column-major storage.
    */
    @property const(E)* ptr() const return scope
    out (r; r != null)
    {
        return &elements_[0][0];
    }

    /**
    Swap 2 rows.

    Params:
        row1 = swap target row 1.
        row2 = swap target row 2.
    */
    void swapRows(size_t row1, size_t row2) scope
    in (row1 < ROWS)
    in (row2 < ROWS)
    {
        if (row1 == row2)
        {
            return;
        }

        foreach (column; 0 .. COLS)
        {
            swap(elements_[column][row1], elements_[column][row2]);
        }
    }

private:
    // Column-major storage: elements_[column][row].
    E[ROWS][COLS] elements_;
}
325 
///
@nogc nothrow pure unittest
{
    import std.math : isClose;

    immutable m = Matrix!(2, 3).fromRows([
        [1, 2, 3],
        [4, 5, 6],
    ]);
    assert(m.rows == 2);
    assert(m.columns == 3);

    // Every element reads back at the (row, column) it was given in.
    immutable float[3][2] expected = [
        [1, 2, 3],
        [4, 5, 6],
    ];
    foreach (row; 0 .. 2)
    {
        foreach (column; 0 .. 3)
        {
            assert(isClose(m[row, column], expected[row][column]));
        }
    }
}

///
@nogc nothrow pure unittest
{
    import std.math : isClose;

    immutable m = Matrix!(3, 3).unit;
    assert(m.rows == 3);
    assert(m.columns == 3);

    // Ones on the diagonal, zeros everywhere else.
    foreach (row; 0 .. 3)
    {
        foreach (column; 0 .. 3)
        {
            immutable expected = (row == column) ? 1.0f : 0.0f;
            assert(isClose(m[row, column], expected));
        }
    }
}

///
@nogc nothrow pure unittest
{
    import std.math : isClose;

    auto m = Matrix!(2, 2).fromRows([
        [1, 2],
        [3, 4]
    ]);

    // Overwrite every element through the index-assign operator.
    m[0, 0] = 3.0f;
    m[0, 1] = 4.0f;
    m[1, 0] = 5.0f;
    m[1, 1] = 6.0f;

    assert(isClose(m[0, 0], 3.0f));
    assert(isClose(m[0, 1], 4.0f));
    assert(isClose(m[1, 0], 5.0f));
    assert(isClose(m[1, 1], 6.0f));
}

///
@nogc nothrow pure unittest
{
    import std.math : isClose;

    auto m = Matrix!(2, 2).fromRows([
        [1, 2],
        [3, 4]
    ]);

    // Increment every element in place.
    m[0, 0] += 1.0f;
    m[0, 1] += 1.0f;
    m[1, 0] += 1.0f;
    m[1, 1] += 1.0f;

    immutable float[2][2] expected = [
        [2, 3],
        [4, 5],
    ];
    foreach (row; 0 .. 2)
    {
        foreach (column; 0 .. 2)
        {
            assert(isClose(m[row, column], expected[row][column]));
        }
    }
}

///
@nogc nothrow pure unittest
{
    import std.math : isClose;

    auto m = Matrix!(2, 2).fromRows([
        [1, 2],
        [3, 4]
    ]);
    immutable t = Matrix!(2, 2).fromRows([
        [3, 4],
        [5, 6]
    ]);

    // Element-wise addition of another matrix.
    m += t;

    immutable float[2][2] expected = [
        [4, 6],
        [8, 10],
    ];
    foreach (row; 0 .. 2)
    {
        foreach (column; 0 .. 2)
        {
            assert(isClose(m[row, column], expected[row][column]));
        }
    }
}

///
@nogc nothrow pure unittest
{
    import std.math : isClose;

    immutable lhs = Matrix!(2, 3).fromRows([
        [3, 4, 5],
        [6, 7, 8],
    ]);
    immutable rhs = Matrix!(3, 2).fromRows([
        [3, 4],
        [6, 7],
        [8, 9],
    ]);

    auto product = Matrix!(2, 2)();
    product.mul(lhs, rhs);

    // Each element is the dot product of a lhs row and a rhs column.
    assert(isClose(product[0, 0], 3 * 3 + 4 * 6 + 5 * 8));
    assert(isClose(product[0, 1], 3 * 4 + 4 * 7 + 5 * 9));
    assert(isClose(product[1, 0], 6 * 3 + 7 * 6 + 8 * 8));
    assert(isClose(product[1, 1], 6 * 4 + 7 * 7 + 8 * 9));
}

///
@nogc nothrow pure unittest
{
    import std.math : isClose;

    auto m = Matrix!(2, 2)();
    m.fill(1.0);

    // Every element now holds the filler value.
    foreach (row; 0 .. 2)
    {
        foreach (column; 0 .. 2)
        {
            assert(isClose(m[row, column], 1.0));
        }
    }
}

///
@nogc nothrow pure unittest
{
    import std.math : isClose;

    // The struct literal fills the column-major storage directly,
    // so the first stored element is 1.
    immutable m = Matrix!(2, 2)([[1, 2], [3, 4]]);
    immutable first = *m.ptr;
    assert(isClose(first, 1.0));
}

///
@nogc nothrow pure unittest
{
    import std.math : isClose;

    immutable m = Matrix!(4, 4).scale(2.0, 3.0, 4.0);

    // Scale factors on the diagonal, 1 in the last diagonal slot, 0 elsewhere.
    immutable float[4] diagonal = [2.0, 3.0, 4.0, 1.0];
    foreach (row; 0 .. 4)
    {
        foreach (column; 0 .. 4)
        {
            immutable expected = (row == column) ? diagonal[row] : 0.0f;
            assert(isClose(m[row, column], expected));
        }
    }
}

///
@nogc nothrow pure unittest
{
    import karasux.linear_algebra.vector : isClose;

    immutable m = Matrix!(4, 4).translate(2.0, 3.0, 4.0);
    immutable point = Vector!4([1.0, 2.0, 3.0, 1.0]);

    // The translation factors are added to the x, y and z components.
    auto moved = Vector!4();
    moved.mul(m, point);
    assert(moved.isClose(Vector!4([3, 5, 7, 1])));
}

///
@nogc nothrow pure unittest
{
    import karasux.linear_algebra.vector : isClose;

    immutable c = cos(0.5);
    immutable s = sin(0.5);

    immutable m = Matrix!(4, 4).rotateX(0.5);
    immutable x = Vector!4([1.0, 0.0, 0.0, 1.0]);
    immutable y = Vector!4([0.0, 1.0, 0.0, 1.0]);
    immutable z = Vector!4([0.0, 0.0, 1.0, 1.0]);

    // The X axis is fixed; Y and Z rotate in the YZ plane.
    auto rotated = Vector!4();
    rotated.mul(m, x);
    assert(rotated.isClose(x));
    rotated.mul(m, y);
    assert(rotated.isClose(Vector!4([0.0, c, s, 1.0])));
    rotated.mul(m, z);
    assert(rotated.isClose(Vector!4([0.0, -s, c, 1.0])));
}

///
@nogc nothrow pure unittest
{
    import karasux.linear_algebra.vector : isClose;

    immutable c = cos(0.5);
    immutable s = sin(0.5);

    immutable m = Matrix!(4, 4).rotateY(0.5);
    immutable x = Vector!4([1.0, 0.0, 0.0, 1.0]);
    immutable y = Vector!4([0.0, 1.0, 0.0, 1.0]);
    immutable z = Vector!4([0.0, 0.0, 1.0, 1.0]);

    // The Y axis is fixed; X and Z rotate in the XZ plane.
    auto rotated = Vector!4();
    rotated.mul(m, x);
    assert(rotated.isClose(Vector!4([c, 0.0, -s, 1.0])));
    rotated.mul(m, y);
    assert(rotated.isClose(y));
    rotated.mul(m, z);
    assert(rotated.isClose(Vector!4([s, 0.0, c, 1.0])));
}

///
@nogc nothrow pure unittest
{
    import karasux.linear_algebra.vector : isClose;

    immutable c = cos(0.5);
    immutable s = sin(0.5);

    immutable m = Matrix!(4, 4).rotateZ(0.5);
    immutable x = Vector!4([1.0, 0.0, 0.0, 1.0]);
    immutable y = Vector!4([0.0, 1.0, 0.0, 1.0]);
    immutable z = Vector!4([0.0, 0.0, 1.0, 1.0]);

    // The Z axis is fixed; X and Y rotate in the XY plane.
    auto rotated = Vector!4();
    rotated.mul(m, x);
    assert(rotated.isClose(Vector!4([c, s, 0.0, 1.0])));
    rotated.mul(m, y);
    assert(rotated.isClose(Vector!4([-s, c, 0.0, 1.0])));
    rotated.mul(m, z);
    assert(rotated.isClose(z));
}
567 
/**
isClose for matrix.

Compares two matrices element by element with `std.math.isClose`.

Params:
    R = rows.
    C = columns.
    E = element type.
    Tolerances = types of the optional tolerance arguments.
    a = matrix.
    b = other matrix.
    tolerances = optional `maxRelDiff` and then `maxAbsDiff`, forwarded
        unchanged to `std.math.isClose` for every element pair. When
        omitted, the `std.math.isClose` defaults are used. Because the
        default `maxAbsDiff` is 0, passing an absolute tolerance is needed
        when comparing against elements that are exactly zero.
Returns:
    true if every pair of corresponding elements is close.
*/
bool isClose(size_t R, size_t C, E, Tolerances...)(
    auto scope ref const(Matrix!(R, C, E)) a,
    auto scope ref const(Matrix!(R, C, E)) b,
    Tolerances tolerances) @nogc nothrow pure
{
    import std.math : mathIsClose = isClose;

    static assert(Tolerances.length <= 2,
        "isClose accepts at most maxRelDiff and maxAbsDiff");

    foreach (c; 0 .. C)
    {
        foreach (r; 0 .. R)
        {
            if (!mathIsClose(a[r, c], b[r, c], tolerances))
            {
                return false;
            }
        }
    }

    return true;
}
598 
///
@nogc nothrow pure @safe unittest
{
    immutable a = Matrix!(4, 4).fromRows([
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16],
    ]);

    // A matrix is close to itself.
    assert(isClose(a, a));

    // A copy is close in both directions.
    Matrix!(4, 4) b = a;
    assert(isClose(a, b));
    assert(isClose(b, a));

    // Changing a single element breaks closeness in both directions.
    b[0, 0] = 100.0f;
    assert(!isClose(a, b));
    assert(!isClose(b, a));
}

///
@nogc nothrow pure @safe unittest
{
    auto a = Matrix!(4, 4).fromRows([
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16],
    ]);

    // Rows 0 and 2 exchanged; rows 1 and 3 untouched.
    immutable expected = Matrix!(4, 4).fromRows([
        [9, 10, 11, 12],
        [5, 6, 7, 8],
        [1, 2, 3, 4],
        [13, 14, 15, 16],
    ]);

    a.swapRows(0, 2);
    assert(isClose(a, expected));

    // Swapping a row with itself leaves the matrix unchanged.
    a.swapRows(3, 3);
    assert(isClose(a, expected));
}
644