-- |A module defining a recursive multi-dimensional tensor data type and
-- arithmetic functions for it.
module Tensor
  ( -- * Data types
    Tensor(Scalar, Vector)
  , TensorAlgebra
    -- * Basic functions
  , tenVerify
  , isScalar
  , order
  , tenDims
    -- * Folds, maps etc.
  , tenFold
  , tenMap
  , tenZipWith
    -- * Creation and export of tensors
  , list2tensor
  , tensor2list
  , dim2tensor
  , dim3tensor
    -- * Special tensors
  , kronecker
  , levicivita
    -- * Arithmetic
  , scalarmul
  , tenAdd
  , tenBinOp
  , fullcontract
  , externalprod
  , contract
  , selfcontract
    -- * Index rearrangement
  , transpose
  , swapind
  , revind
  , reindex
  , slice
  ) where

import Text.ParserCombinators.ReadP

-- -*************************************************************************
-- Data types
-- -*************************************************************************

-- |The Tensor data type is defined recursively as either a Scalar or an array
-- of sub-tensors called Vector. Equal size of all sub-tensors in a Vector is
-- assumed by most functions below. Use the tenVerify function to verify
-- correctness if you build tensors using constructors.
data Tensor n
  = Scalar n           -- ^ A single element; a tensor of order 0.
  | Vector [Tensor n]  -- ^ A list of sub-tensors; each nesting level adds one index.
-- |Function type for folds. The first function processes Scalars, the second
-- combines the results from sub-tensors.
type TensorAlgebra n r = (n -> r, [r] -> r)

-- -*************************************************************************
-- Class instances
-- -*************************************************************************

-- |Structural equality: two tensors are equal iff they have the same nesting
-- and all corresponding elements compare equal. Comparing constructor by
-- constructor (rather than via 'tenDims', which takes the head of every
-- Vector) keeps this total: empty Vectors and ragged, unverified tensors
-- compare unequal instead of raising an error. '(/=)' is the class default,
-- @not (a == b)@.
instance (Eq n) => Eq (Tensor n) where
  Scalar a == Scalar b = a == b
  Vector a == Vector b = a == b -- list equality also compares the lengths
  _        == _        = False

-- |'fmap' applies a function to every element and keeps the shape.
instance Functor Tensor where
  fmap = tenMap

-- |Scalars are shown with the element's own 'show', Vectors as
-- comma-separated lists in square brackets, e.g. @[[1,2],[3,4]]@.
instance (Show n) => Show (Tensor n) where
  show (Scalar n) = show n
  show (Vector l) = "[" ++ sl l ++ "]"
    where
      -- Recognising the last element by pattern (instead of computing the
      -- length of the tail at every step) keeps rendering linear.
      sl []     = ""
      sl [x]    = show x
      sl (x:xs) = show x ++ ',' : sl xs

-- |Parses either a plain element (wrapped in Scalar) or a bracketed,
-- comma-separated list of sub-tensors (wrapped in Vector). The precedence
-- argument is ignored.
instance (Read n) => Read (Tensor n) where
  readsPrec _ = readP_to_S $ readS_to_P readsScalar +++ readS_to_P readsVector

-- |Parse a scalar and wrap it into a Scalar constructor
readsScalar :: Read n => String -> [(Tensor n, String)] -- = ReadS (Tensor n)
readsScalar = map (\(n,s) -> (Scalar n,s)) . reads

-- |Parse a vector or higher-order tensor. The reads function used will
-- recurse via the list read function to the Tensor read function again.
-- Because of this, the tensor consistency check is quite wasteful. It would
-- suffice to call tenVerify once on the final result, but AFAIK there is no
-- hook in the Read class that allows that.
readsVector :: Read n => String -> [(Tensor n, String)] -- = ReadS (Tensor n)
readsVector = filter (tenVerify . fst) . map (\(l,s) -> (Vector l,s)) . reads

-- -*************************************************************************
-- Basic tensor functions
-- -*************************************************************************

-- |Verify correctness of Tensor data. To be used if you build tensors
-- yourself from constructors. Returns False if the sub-tensors of some
-- Vector differ in size or in order, or if any Vector is empty.
tenVerify :: Tensor n -> Bool
tenVerify = decide . tenFold minmaxsizes
  where
    -- Consistent iff every index has a single size (min == max).
    decide = all (uncurry (==))
    -- Min/max size for every tensor index.
    minmaxsizes :: TensorAlgebra n [(Int, Int)]
    minmaxsizes = (const [], merge)
    merge [] = [(-1, 0)] -- catch empty Vector
    merge (x:xs) = dup (length (x:xs)) : foldl (mmZipWith mm) x xs
    dup x = (x, x)
    mm (a,b) (c,d) = (min a c, max b d)

-- Special version of zipWith which causes minmaxsize to fail for
-- inconsistent tensor order: an entry present in only one of the lists
-- becomes (-1, 0), whose min and max differ.
mmZipWith :: (a -> b -> (Int, Int)) -> [a] -> [b] -> [(Int, Int)]
mmZipWith _ [] [] = []
mmZipWith f (_:xs) [] = (-1, 0) : mmZipWith f xs []
mmZipWith f [] (_:ys) = (-1, 0) : mmZipWith f [] ys
mmZipWith f (x:xs) (y:ys) = f x y : mmZipWith f xs ys

-- |Concise check for Scalar
isScalar :: Tensor n -> Bool
isScalar (Scalar _) = True
isScalar (Vector _) = False

-- |Tensor order (number of indices). Only the first sub-tensor of each
-- Vector is followed; an empty Vector raises an error.
order :: Tensor n -> Int
order (Scalar _) = 0
order (Vector (h:_)) = 1 + order h
order (Vector []) = error "order: hit empty Vector"

-- |Tensor dimensions, starting with outer Vector = first index. Only the
-- first sub-tensor of each Vector is inspected. For an empty Vector its size
-- 0 is still reported; demanding anything beyond it raises an error.
tenDims :: Tensor n -> [Int]
tenDims (Scalar _) = []
tenDims (Vector l) = length l : inner l
  where
    inner (h:_) = tenDims h
    inner []    = error "tenDims: hit empty Vector"

-- -*************************************************************************
-- Folds, maps etc.
-- -*************************************************************************

-- |Fold for the Tensor data type.
tenFold :: TensorAlgebra n r -> Tensor n -> r
tenFold (fs, _) (Scalar s) = fs s
tenFold f@(_, fv) (Vector v) = fv $ map (tenFold f) v

-- |Apply a function to all elements of the tensor
tenMap :: (a -> b) -> Tensor a -> Tensor b
tenMap f = tenFold (Scalar . f, Vector)

-- |zipWith for tensors with the same order and dimensions. Raises an error
-- on mismatching order or dimensions, or on an empty Vector.
tenZipWith :: TensorAlgebra (a, b) c -> Tensor a -> Tensor b -> c
tenZipWith (fs, _) (Scalar a) (Scalar b) = fs (a, b)
tenZipWith _ (Vector _) (Scalar _) =
  error "tenZipWith: order of first argument exceeds that of second"
tenZipWith _ (Scalar _) (Vector _) =
  error "tenZipWith: order of second argument exceeds that of first"
tenZipWith f@(_, fv) (Vector a) (Vector b)
  | null a || null b = error "tenZipWith: hit empty vector"
  | length a /= length b = error "tenZipWith: mismatching dimensions"
  | otherwise = fv $ zipWith (tenZipWith f) a b

-- The following functions are not exported. They are used internally for
-- rearranging indices and contractions. They are partial: applied to a
-- Scalar they fail with a pattern-match error.

-- |map for Vectors
vecMap :: (Tensor a -> Tensor b) -> Tensor a -> Tensor b
vecMap f (Vector l) = Vector $ map f l

-- |map for n-th order Vectors of Vectors. Used for operations on the n-th
-- index, possibly after rearrangement.
nVecMap :: Int -> (Tensor a -> Tensor b) -> Tensor a -> Tensor b
nVecMap n = (!! n) . iterate vecMap

-- |zipWith for Vectors
vecZipWith :: (Tensor a -> Tensor b -> Tensor c) -> Tensor a -> Tensor b
           -> Tensor c
vecZipWith f (Vector l1) (Vector l2) = Vector $ zipWith f l1 l2

-- |(zipWith id) for Vectors, for completeness
vecZid :: [Tensor a -> Tensor b] -> Tensor a -> Tensor b
vecZid fv (Vector l) = Vector $ zipWith id fv l

-- |zipWith for n-th order Vectors of Vectors. Used for contraction of two
-- tensors with the n-th index, possibly after rearrangement.
nVecZipWith :: Int -> (Tensor a -> Tensor b -> Tensor c) -> Tensor a
            -> Tensor b -> Tensor c
nVecZipWith n = (!! n) . iterate vecZipWith

-- |Pull a function returning a pair out by one Vector level.
vecUnzipWith :: (Tensor a -> (Tensor b, Tensor c)) -> Tensor a
             -> (Tensor b, Tensor c)
vecUnzipWith f (Vector l) = (Vector $ fst fl, Vector $ snd fl)
  where fl = unzip $ map f l
-- |Pull a function returning a pair out by n Vector levels.
nVecUnzipWith :: Int -> (Tensor a -> (Tensor b, Tensor c)) -> Tensor a
              -> (Tensor b, Tensor c)
nVecUnzipWith n = (!! n) . iterate vecUnzipWith

-- -*************************************************************************
-- Creation and export of tensors
-- -*************************************************************************

-- |Build Tensor from list of dimensions and list of elements.
-- The last tensor index = innermost Vector element changes fastest.
-- Raises an error if any dimension is zero or negative, or if the number of
-- elements differs from the product of the dimensions. An empty dimension
-- list together with a one-element list yields a Scalar.
list2tensor :: [Int] -> [n] -> Tensor n
list2tensor s d
  | any (<= 0) s = error "list2tensor: zero or negative dimension"
  | total /= length d = error "list2tensor: size mismatch"
  | otherwise = l2t strides d
  where
    -- scanr never returns an empty list: the first entry is the element
    -- count of the whole tensor, the rest are the element counts of the
    -- sub-tensors at each nesting level (ending with 1 for a Scalar).
    (total : strides) = scanr (*) 1 s
    l2t :: [Int] -> [a] -> Tensor a
    l2t [] (x:_) = Scalar x
    l2t (c:cs) xs = Vector $ map (l2t cs) $ split c xs
    -- Cut a list into consecutive chunks of length n.
    split :: Int -> [a] -> [[a]]
    split _ [] = []
    split n l = let (chunk, rest) = splitAt n l in chunk : split n rest

-- |Turn tensor into pair of list of dimensions and list of elements
tensor2list :: Tensor n -> ([Int], [n])
tensor2list = tenFold t2l
  where
    t2l :: TensorAlgebra n ([Int], [n])
    t2l = (scalar2l, vectorl2l)
    scalar2l x = ([], [x])
    vectorl2l [] = error "tensor2list: hit empty Vector"
    vectorl2l l@(h:_) = (length l : fst h, concatMap snd l)

-- |Build 2D tensor from nested lists. Haskell's strong typing forbids
-- generalising this. In order to guard against higher-order nested lists
-- being passed, this function is restricted to Num's.
dim2tensor :: Num n => [[n]] -> Tensor n
dim2tensor l
  | not $ tenVerify result =
      error "dim2tensor: inconsistent tensor dimension or empty sub-tensor"
  | otherwise = result
  where result = Vector $ map (Vector . map Scalar) l
-- |Build 3D tensor from nested lists. Haskell's strong typing forbids
-- generalising this. In order to guard against higher-order nested lists
-- being passed, this function is restricted to Num's.
dim3tensor :: Num n => [[[n]]] -> Tensor n
dim3tensor l
  | tenVerify result = result
  | otherwise =
      error "dim3tensor: inconsistent tensor dimension or empty sub-tensor"
  where
    result = Vector [ Vector [ Vector (map Scalar row) | row <- plane ]
                    | plane <- l ]

-- -*************************************************************************
-- Special tensors
-- -*************************************************************************

-- |Unit matrix or Kronecker symbol in given dimension: a d x d tensor with
-- ones on the diagonal and zeros elsewhere.
kronecker :: Num n => Int -> Tensor n
kronecker d
  | d <= 0 = error "kronecker: dimension must be positive"
  | otherwise =
      Vector [ Vector [ Scalar (delta r c) | c <- [0 .. d-1] ] | r <- [0 .. d-1] ]
  where
    delta r c = if r == c then 1 else 0

-- |Levi-Civita or fully antisymmetric tensor in d dimensions (and of order d)
levicivita :: Num n => Int -> Tensor n
levicivita d
  | d < 2 = error "levicivita: order must be >=2"
  | otherwise = lcbuilder [] 0
  where
    -- Build a slice of a Levi-Civita tensor. The list of already chosen
    -- index values is kept in reverse order; inv counts its inversions, so
    -- a complete choice yields (-1)^inv.
    lcbuilder :: Num a => [Int] -> Int -> Tensor a
    lcbuilder chosen inv
      | length chosen == d = Scalar ((-1) ^ inv)
      | otherwise = Vector (map (branch chosen inv) [1 .. d])
    -- Sub-tensor for the next index value x: all zeros if x repeats an
    -- earlier value, otherwise recurse, adding the inversions x creates
    -- (the number of earlier values greater than x).
    branch :: Num a => [Int] -> Int -> Int -> Tensor a
    branch chosen inv x
      | x `elem` chosen = zeros (d - length chosen - 1)
      | otherwise = lcbuilder (x : chosen) (inv + length (filter (x <) chosen))
    -- A zero tensor of the given order with every dimension equal to d.
    zeros :: Num a => Int -> Tensor a
    zeros 0 = Scalar 0
    zeros o = Vector (replicate d (zeros (o - 1)))

-- -*************************************************************************
-- Arithmetic
-- -*************************************************************************

-- |Multiplication with a (non-wrapped) scalar; each element e becomes n * e.
scalarmul :: Num n => n -> Tensor n -> Tensor n
scalarmul n = tenMap (n *)

-- |Add two equally-sized tensors element by element
tenAdd :: Num n => Tensor n -> Tensor n -> Tensor n
tenAdd = tenBinOp (+)

-- |Arbitrary element-wise binary operator
tenBinOp :: (a -> b -> c) -> Tensor a -> Tensor b -> Tensor c
tenBinOp f = tenZipWith (\(a, b) -> Scalar (f a b), Vector)

-- |Full contraction with respect to all indices of two tensors with equal
-- dimensions. The result is a scalar not wrapped in a Scalar constructor.
fullcontract :: Num n => Tensor n -> Tensor n -> n
fullcontract = tenZipWith (\(a, b) -> a * b, sum)

-- |External product (no contractions): every element of the first tensor is
-- replaced by a copy of the second tensor scaled by that element.
externalprod :: Num n => Tensor n -> Tensor n -> Tensor n
externalprod t1 t2 = tenFold (\s -> scalarmul s t2, Vector) t1
-- |General contraction of two tensors. The first argument contains pairs of
-- one-based indices to be contracted; the first component of each pair
-- refers to the first tensor, the second component to the second tensor.
-- The contraction is performed by shifting indices to be contracted to the
-- end, then using fullcontract on corresponding inner tensors of the two
-- tensor arguments. The result carries the uncontracted indices of the first
-- tensor followed by those of the second. Out-of-range or repeated index
-- numbers raise an error before any rearrangement is attempted.
contract :: Num n => [(Int,Int)] -> Tensor n -> Tensor n -> Tensor n
contract inds x y = checked `seq` nVecMap (ox-li) contractWithY x'
  where
    ox = order x
    oy = order y
    li = length inds
    -- Validated pairs in the order produced by sortind. Forcing their count
    -- runs every check before the index arithmetic below, so bad input
    -- yields a descriptive message rather than e.g. a negative-index error.
    pairs = sortind $ checkind [] [] inds
    checked = length pairs
    (xinds, yinds) = unzip pairs
    x' = pushall xinds x
    y' = pushall yinds y
    -- Contract one inner tensor of x' with every inner tensor of y'.
    contractWithY xs = nVecMap (oy-li) (Scalar . fullcontract xs) y'
    -- Range-check each index against its own operand and reject any index
    -- number used twice for the same operand.
    checkind :: [Int] -> [Int] -> [(Int,Int)] -> [(Int,Int)]
    checkind _ _ [] = []
    checkind seen0 seen1 ((i0,i1):it)
      | i0 < 1 || i0 > ox = error $ "contract: index " ++ (show i0)
                                    ++ " out of range for first operand"
      | i1 < 1 || i1 > oy = error $ "contract: index " ++ (show i1)
                                    ++ " out of range for second operand"
      | i0 `elem` seen0 = error $ "contract: duplicate index " ++ (show i0)
                                  ++ " for first operand"
      | i1 `elem` seen1 = error $ "contract: duplicate index " ++ (show i1)
                                  ++ " for second operand"
      | otherwise = (i0, i1) : checkind (i0:seen0) (i1:seen1) it
    -- Insertion sort by descending sum of the index pair. Both operands are
    -- rearranged with the same pair order, so the contraction itself does
    -- not depend on this order.
    sortind :: [(Int,Int)] -> [(Int,Int)]
    sortind = foldr inspair []
    inspair :: (Int,Int) -> [(Int,Int)] -> [(Int,Int)]
    inspair p [] = [p]
    inspair p l@(lh:lt)
      | cmpindpair p lh = p : l
      | otherwise = lh : inspair p lt
    cmpindpair :: (Int,Int) -> (Int,Int) -> Bool
    cmpindpair (i0,i1) (i2,i3) = i0+i1 > i2+i3
-- |Contract two indices of one tensor. The indices are one-based and must
-- be distinct, and their dimensions must agree. The result has an order two
-- less than the argument.
selfcontract :: Num n => Int -> Int -> Tensor n -> Tensor n
selfcontract a b t
  | i0 < 0 = error $ "selfcontract: index number " ++ (show (i0+1)) ++ " out of range"
  | i1 >= order t = error $ "selfcontract: index number " ++ (show (i1+1)) ++ " out of range"
  | i0 == i1 = error $ "selfcontract: cannot contract index " ++ (show a) ++ " with itself"
  | otherwise = nVecMap i0 (secon (i1-i0)) t
  where
    i0 = (min a b) - 1
    i1 = (max a b) - 1
    -- Contract the first index of v with the index k places below it,
    -- moving that index up one level per step until the two are adjacent.
    secon 1 v
      | vlength (vhead v) /= vlength v = error "selfcontract: dimension mismatch"
      | otherwise = sctrace v
    secon k v = vecMap (secon (k-1)) $ transpose v
    -- Sum of the diagonal sub-tensors v[j][j].
    sctrace v
      | vlength v == 1 = vhead $ vhead v
      | otherwise = tenAdd (vhead $ vhead v) (sctrace $ vecMap vtail $ vtail v)

-- -*************************************************************************
-- Index rearrangement
-- -*************************************************************************

-- This is the most "interesting" section, as Haskell's hierarchical structures
-- do not lend themselves well to this kind of thing.

-- |Exchange the first two indices
transpose :: Tensor n -> Tensor n
transpose v
  | vnull $ vhead v = Vector []
  | otherwise = vcons (fst headtail) (transpose $ snd headtail)
  where headtail = vecUnzipWith vheadtail v

-- |Swap two indices. As for all exported functions, these are one-based.
swapind :: Int -> Int -> Tensor n -> Tensor n
swapind a b t
  | a == b = t
  | i0 < 0 = error $ "swapind: index "++(show $ i0+1)++" out of range"
  | i1 >= order t = error $ "swapind: index "++(show $ i1+1)++" out of range"
  | otherwise = nVecMap i0 (swind (i1-i0)) t
  where
    i0 = (min a b) - 1
    i1 = (max a b) - 1
    -- Swap the top index with the one k places below it: push the top down
    -- k places, then pull the other index (now one place higher) to the top.
    swind 1 v = transpose v
    swind k v = pullindex (k-1) $ pushindex k v

-- |Reverse order of indices. Vectors are expected to be non-empty (the
-- merge step takes the head of each list of sub-tensors).
revind :: Tensor n -> Tensor n
revind s@(Scalar _) = s
revind (Vector vl) = rimerge $ map revind vl
  where
    -- Combine a list of tensors into one whose last index runs over the list.
    rimerge :: [Tensor n] -> Tensor n
    rimerge l
      | isScalar $ head l = Vector l
      | vnull $ head l = Vector []
      | otherwise = vcons (rimerge $ map vhead l) (rimerge $ map vtail l)
-- |Rearrange indices. The first argument contains a list of index numbers
-- that are moved to the front / outside in the given order. The remaining
-- indices, if any, are left in their existing order at the back.
-- Index numbers are one-based. Out-of-range and duplicate index numbers are
-- rejected with an error that names the number as passed by the caller.
reindex :: [Int] -> Tensor n -> Tensor n
reindex il t = validate [] il `seq` ristep il ord t
  where
    ord = order t
    -- Check the caller's index numbers up front: ristep renumbers the
    -- remaining indices after each step, so errors raised there would
    -- report shifted numbers.
    validate :: [Int] -> [Int] -> ()
    validate _ [] = ()
    validate seen (i:is)
      | i < 1 || i > ord = error $ "reindex: index "++(show i)++" out of range"
      | i `elem` seen = error $ "reindex: duplicate index " ++ (show i)
      | otherwise = validate (i:seen) is
    ristep :: [Int] -> Int -> Tensor a -> Tensor a
    ristep [] _ s = s
    ristep _ _ (Scalar _) = error "reindex: index list exceeds tensor order"
    ristep l@(ih:_) o s
      | ih < 1 || ih > o = error $ "reindex: index "++(show ih)++" out of range"
      | otherwise = vecMap (ristep (rmind l) (o-1)) $ pullindex (ih-1) s
    -- Remove the first index number from the list, and decrement those that
    -- exceed it so the list can be used on the sub-tensors.
    rmind :: [Int] -> [Int]
    rmind (ih:it) = map (conddec ih) it
    rmind [] = []
    conddec :: Int -> Int -> Int
    conddec i x
      | x < i = x
      | x == i = error $ "reindex: duplicate index " ++ (show i)
      | otherwise = x-1

-- |Slice of a tensor for specific values of some indices. The first argument
-- is a list of indices and their values, both one-based. Each index number
-- may appear at most once; the sliced indices are removed from the result.
slice :: [(Int,Int)] -> Tensor n -> Tensor n
slice [] t = t
slice inds t
  | head il < 1 || last il > order t = error "slice: index number out of range"
  | or (zipWith (==) il (tail il)) = error "slice: duplicate index number"
  | otherwise = slrec il gl vl t
  where
    -- Index numbers in ascending order, with their values alongside.
    (il, vl) = unzip $ inssort (\(x,_) (z,_) -> x < z) inds
    gl = inds2gaps il
    -- Recurse over indices to evaluate
    slrec :: [Int] -> [Int] -> [Int] -> Tensor a -> Tensor a
    slrec [] [] [] s = s
    slrec (i:is) (g:gs) (v:vs) s = nVecMap g (slrec is gs vs . sleval i v) s
    -- Evaluate one index by returning the corresponding list element of the
    -- Vector that is the last argument
    sleval :: Int -> Int -> Tensor a -> Tensor a
    sleval i v (Vector l)
      | v < 1 || v > length l = error $ "slice: value " ++ (show v) ++ " for index " ++ (show i) ++ " out of range"
      | otherwise = l !! (v-1)
-- The following are used only internally and not exported

-- |Shift the topmost index downward by n places.
pushindex :: Int -> Tensor n -> Tensor n
pushindex depth (Vector (first:rest)) = case rest of
  [] -> nVecMap depth (\x -> Vector [x]) first
  _  -> nVecZipWith depth vcons first (pushindex depth (Vector rest))

-- |Shift nth index upward to the top. n is zero-based.
pullindex :: Int -> Tensor n -> Tensor n
pullindex depth v@(Vector _) =
  if vnull (iterate vhead v !! depth)
    then Vector []
    else let (front, rest) = nVecUnzipWith depth vheadtail v
         in vcons front (pullindex depth rest)

-- |Push some indices to the back in the given order
pushall :: [Int] -> Tensor n -> Tensor n
pushall [] t = t
pushall inds t = pastep (inds2gaps sortedInds) depths t
  where
    ord = order t
    -- Ascending list of indices and their push depths, generated from the
    -- index list
    (sortedInds, depths) = unzip $ zipWith fixpush [0..] (reorder2orig inds)
    -- Tag every index with its 0-based destination slot, then sort by index
    -- number
    reorder2orig :: [Int] -> [(Int, Int)]
    reorder2orig = inssort (\(_, a) (_, b) -> a < b) . zip [0..]
    -- Pair an index with its push depth, computed from its rank in the
    -- sorted order and its destination slot
    fixpush :: Int -> (Int, Int) -> (Int, Int)
    fixpush rank (slot, ind) = (ind, ord - ind + min slot rank)
    -- Descend by the gap to the next index to move, push it down by its
    -- depth, then continue with the remaining gaps and depths
    pastep :: [Int] -> [Int] -> Tensor a -> Tensor a
    pastep [] [] s = s
    pastep (g:gs) (p:ps) s = nVecMap g (pastep gs ps . pushindex p) s
-- |Insertion sort, used by pushall and slice to order the indices to work on.
-- The first argument is a comparison function which must return true iff its
-- first argument is to be sorted before its second. Elements that compare
-- neither way keep their relative input order.
inssort :: (t -> t -> Bool) -> [t] -> [t]
inssort cmp = foldr insertOne []
  where
    -- Place x directly in front of the first element it sorts before.
    insertOne x sorted =
      let (before, after) = break (cmp x) sorted
      in before ++ x : after

-- |Compute the number of indices to skip between indices to do something to.
-- Used by slice and pushall. The result has the same number of entries as the
-- argument; the first is the first entry of the argument minus one, giving the
-- number of initial indices to ignore.
inds2gaps :: [Int] -> [Int]
inds2gaps l = zipWith (\cur prev -> cur - prev - 1) l (0 : l)

-- Small auxiliary functions - standard list functions for Vectors
-- It would be nice to generalise the following auxiliary functions, but bad
-- for debugging because errors would always occur in vecDo. Each of them
-- fails with a pattern-match error when given a Scalar (and, where a head is
-- needed, an empty Vector).

vecDo :: ([Tensor a] -> b) -> Tensor a -> b
vecDo f (Vector l) = f l

-- |Prepend a sub-tensor to a Vector.
vcons :: Tensor n -> Tensor n -> Tensor n
vcons x (Vector xs) = Vector (x : xs)

-- |Is the Vector empty?
vnull :: Tensor n -> Bool
vnull (Vector xs) = null xs

-- |Number of sub-tensors in a Vector.
vlength :: Tensor n -> Int
vlength (Vector xs) = length xs

-- |First sub-tensor of a Vector.
vhead :: Tensor n -> Tensor n
vhead (Vector (x:_)) = x

-- |The Vector without its first sub-tensor.
vtail :: Tensor n -> Tensor n
vtail (Vector (_:xs)) = Vector xs

-- |vhead and vtail in one step.
vheadtail :: Tensor n -> (Tensor n, Tensor n)
vheadtail (Vector (x:xs)) = (x, Vector xs)