// ************************************************************
//	Clean Linear Algebra Subroutines - CLAS
//	Version 0.6 - February 23, 1998 - Thorsten Zoerner
// 	Catholic University of Nijmegen - zoerner@cs.kun.nl
// ************************************************************

implementation module Clas3

import Clas2, StdMisc

// Overloaded matrix-matrix product: `x *** y` multiplies two matrices of the
// same representation m. The operator is infix with priority 7.
class MatrixMatrixProduct m
where
	(***) infix 7 :: m m -> m

// Product of two real matrices stored row by row (a.[i] is row i; a.[i,k] is
// the entry in row i, column k, the convention used by lu in this module).
// For a of size m x p and b of size p x n, the result c is m x n with
//     c.[i,j] = sum over k of a.[i,k] * b.[k,j]
// Each entry is computed directly, so the result is laid out row by row like
// the arguments.
// Aborts when the column count of a differs from the row count of b.
instance MatrixMatrixProduct {# {# Real}}
where
	(***) a b
		| cols_a <> rows_b
			= abort "\n(***): matrix dimensions do not match!"
		= { { entry i j \\ j <- [0 .. dec cols_b]} \\ i <- [0 .. dec rows_a]}
	where
		rows_a = size a
		rows_b = size b
		// an empty matrix has no row 0 to read the column count from
		cols_a = if (rows_a == 0) 0 (size a.[0])
		cols_b = if (rows_b == 0) 0 (size b.[0])
		// inner product of row i of a with column j of b
		entry i j = sum [a.[i,k] * b.[k,j] \\ k <- [0 .. dec cols_a]]
	
// solve a b: solves the linear system with matrix a and right-hand side b
// using the unpivoted factorization `lu a` (aborts where lu aborts).
// The packed factors are handed to forwardSubst together with b, and its
// result is handed to backwardSubst. Both helpers come from Clas2; presumably
// they solve L y = b and then U x = y.
solve :: *Matrix .Vector -> .Vector
solve a b
	# factors = lu a
	= backwardSubst factors (forwardSubst factors b)

// solvePartPiv a b: like solve, but factors a with luPartPiv (row pivoting).
// The right-hand side is first reordered with perVec, using the pivot vector
// returned by luPartPiv. The reordered vector then goes through forwardSubst
// and backwardSubst with the packed factors (helpers from Clas2).
solvePartPiv :: *Matrix .Vector -> .Vector
solvePartPiv a b
	# (factors, perm) = luPartPiv a
	= backwardSubst factors (forwardSubst factors (perVec perm b))

// lu a: LU decomposition without pivoting (Doolittle form), computed by
// destructive updates of the unique matrix a.
// On return:
//   - the strict lower triangle holds the multipliers l_ik = a_ik / a_kk of L,
//     whose unit diagonal is implicit and not stored;
//   - the upper triangle, diagonal included, holds U.
// The row count n is also used as the column bound, so a is assumed square.
// Aborts with "matrix is singular!" when a pivot a.[k,k] is exactly 0.0 for
// some k < n-1. The last pivot is never used as a divisor, so it is not
// checked.
// No rows are exchanged, so a zero pivot also aborts for some nonsingular
// matrices; luPartPiv handles those.
lu :: *Matrix -> *Matrix
lu aa = lu_k 0 a
where
	(n, a) = usize aa
	// lu_k k a: eliminate column k below the diagonal, then go on with k+1.
	// Only columns 0 .. n-2 need elimination. Using `>=` instead of `==` also
	// stops the empty matrix (n == 0, dec n == -1) before it indexes a.[0,0].
	lu_k :: Int *Matrix -> *Matrix
	lu_k k a
		| (k >= dec n) = a
		= lu_k (inc k) (lu_i (inc k) a)
	where
		// lu_i i a: for rows i .. n-1, replace a[i,k] by the multiplier
		// a[i,k] / a[k,k] and update the rest of that row.
		lu_i :: Int *Matrix -> *Matrix
		lu_i i a
			| (i==n) = a
		lu_i i a =: {[k, k] = akk, [i, k] = aik}
			| (akk==0.0) = abort "matrix is singular!"
			= lu_i (inc i) (lu_j (inc k) { a & [i,k] = aik / akk})
		where
			// lu_j j a: a[i,j] := a[i,j] - a[i,k] * a[k,j] for j = k+1 .. n-1.
			// a[i,k] is re-read from the updated array, so it is the multiplier.
			lu_j :: Int *Matrix -> *Matrix
			lu_j j a
				| (j==n) = a
			lu_j j a =: { [i, j] = aij, [i, k] = aik, [k, j] = akj}
				= lu_j (inc j) { a & [i,j] = aij - aik * akj}

// luPartPiv a: LU decomposition with partial (row) pivoting, computed by
// destructive updates of the unique matrix a.
// Returns (lu, p):
//   - lu packs L and U in the same layout as `lu`: multipliers in the strict
//     lower triangle (unit diagonal implicit), U in the upper triangle;
//   - p starts as the identity {0 .. n-1} and receives the same swaps as the
//     rows of the matrix. solvePartPiv applies it to b via perVec.
// For column k, the pivot row is k + amax {a[k,k] .. a[n-1,k]}.
// NOTE(review): this assumes amax (from Clas2) returns the position of the
// largest-magnitude entry. TODO confirm.
// Aborts with "matrix is singular!" (same message as `lu`) when the chosen
// pivot is exactly 0.0, instead of dividing by zero.
luPartPiv :: *Matrix -> (*Matrix, *{# Int})
luPartPiv aa = lu_k 0 a { i \\ i <- [0 .. dec n]}
where
	(n, a) = usize aa
	// lu_k k a p: pick the pivot row mu for column k, and swap rows k and mu
	// in both a and p when they differ. Then eliminate below row k.
	// Using `>=` also stops the empty matrix (n == 0) before amax is applied to
	// an empty array.
	lu_k :: Int *Matrix *{# Int} -> (*Matrix, *{# Int})
	lu_k k a p
		| (k >= dec n) = (a, p)
		#! mu = k + amax { a.[m,k] \\ m <- [k .. dec n]}
		| (mu==k) = lu_k (inc k) (lu_i (inc k) a) p
		= lu_k (inc k) (lu_i (inc k) (uniMat (swap a k mu))) (swap p k mu)
	where
		// lu_i i a: for rows i .. n-1, replace a[i,k] by the multiplier
		// a[i,k] / a[k,k] and update the rest of that row.
		lu_i :: Int *Matrix -> *Matrix
		lu_i i a
			| (i==n) = a
		lu_i i a =: {[k, k] = akk, [i, k] = aik}
			// A zero pivot is rejected as in `lu`; dividing by it would
			// silently fill the factors with Inf/NaN.
			| (akk==0.0) = abort "matrix is singular!"
			= lu_i (inc i) (lu_j (inc k) { a & [i,k] = aik / akk})
		where
			// lu_j j a: a[i,j] := a[i,j] - a[i,k] * a[k,j] for j = k+1 .. n-1.
			lu_j :: Int *Matrix -> *Matrix
			lu_j j a
				| (j==n) = a
			lu_j j a =: { [i, j] = aij, [i, k] = aik, [k, j] = akj}
				= lu_j (inc j) { a & [i,j] = aij - aik * akj}

// invert a: inverse of a by in-place Gauss-Jordan elimination, without row
// exchanges.
// The row count n is also used as the column bound, so a is assumed square.
// There is one pivot step per diagonal position j = 0 .. n-1. With
// p = a[j,j], each step performs:
//   a[i,j] := a[i,j] / p                  for i <> j          (column j)
//   a[j,j] := 1 / p
//   a[l,k] := a[l,k] - a[l,j] * a[j,k]    for l <> j, k <> j
//             (a[l,j] is already divided by p; a[j,k] is still unchanged)
//   a[j,k] := -(a[j,k] * (1 / p))         for k <> j          (row j)
// NOTE(review): there is no zero-pivot check and no pivoting. A zero a[j,j]
// presumably yields Inf/NaN entries instead of an abort as in `lu`.
// TODO decide whether that is intended.
invert :: *Matrix -> .Matrix
invert aa = inv_j 0 a
where
	(n, a) = usize aa
	// inv_j j a: perform the pivot step on a[j,j]; done once j reaches n.
	inv_j :: Int *Matrix -> *Matrix
	inv_j j a 
		| (j==n) = a
		= inv_i 0 a
	where
		// inv_i i a: divide column j (all rows except j) by ajj = a[j,j].
		// a[j,j] is not modified during this loop. When i reaches n it is
		// replaced by 1 / ajj, and the column loop inv_k starts.
		inv_i :: Int *Matrix -> *Matrix
		inv_i i a =: { [j, j] = ajj}
			| (i==n) = inv_k 0 { a & [j, j] = 1.0 / ajj}
			| (i==j) = inv_i (inc i) a
			#! (aij, a) = a![i, j]
			= inv_i (inc i) { a & [i, j] = aij / ajj}
		where
			// inv_k k a: update every column k <> j; after the last column,
			// continue with the next pivot j+1.
			inv_k :: Int *Matrix -> *Matrix
			inv_k k a
				| (k==n) = inv_j (inc j) a
				| (k==j) = inv_k (inc k) a
				= inv_l 0 a
			where
				// inv_l l a: a[l,k] -= a[l,j] * a[j,k] for every row l <> j.
				// Here ajj is re-read from a[j,j] and is therefore already 1/p.
				// When l reaches n, the row-j entry a[j,k] becomes -(a[j,k] * 1/p).
				inv_l :: Int *Matrix -> *Matrix
				inv_l l a =: { [j, j] = ajj, [j, k] = ajk}
					| (l==n) = inv_k (inc k) { a & [j, k] = ~(ajk * ajj)}
					| (l==j) = inv_l (inc l) a
					#! (alk, a) = a![l, k]
					#! (alj, a) = a![l, j]
					= inv_l (inc l) { a & [l, k] = alk - alj * ajk}
