Monday, August 13, 2007

Fun Functions: Flatten

This post is about a function that may not seem like much: it takes a tree and returns its elements in a list, in the order an inorder traversal would visit them. This is a common enough task in programming, and so mundane that it hardly deserves a blog post of its own. So why am I writing this? What has captured me is one particular implementation of the function.

Let us first set the scene. This blog post is a literate Haskell file, so you should be able to save it to a file and run it yourself.
> module Flatten where
> import Test.QuickCheck
> data Tree a = Node (Tree a) a (Tree a) | Leaf deriving Show
So we have a data type for trees, and it's the standard one. Nothing surprising here. The scene is set.

How would we go about implementing flatten, then? There are in fact quite a number of ways to implement it, and I'm going to show several of them. The last version is my favorite, and after having seen all the others I hope you will appreciate it as I do. Let us start with something rather naive: a straightforward recursive function which turns a tree into a list.
> flatten1 :: Tree a -> [a]
> flatten1 Leaf           = []
> flatten1 (Node t1 a t2) = flatten1 t1 ++ [a] ++ flatten1 t2
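To make the input and output concrete, here is a small sketch that applies flatten1 to two sample trees. The names example and leftSpine are made up for this illustration and are not part of the post's code; Tree and flatten1 are simply repeated so that the snippet stands on its own.

```haskell
-- Tree and flatten1 are copied from the post so the snippet is self-contained.
data Tree a = Node (Tree a) a (Tree a) | Leaf deriving Show

flatten1 :: Tree a -> [a]
flatten1 Leaf           = []
flatten1 (Node t1 a t2) = flatten1 t1 ++ [a] ++ flatten1 t2

-- Hypothetical sample tree:    2
--                             / \
--                            1   3
example :: Tree Int
example = Node (Node Leaf 1 Leaf) 2 (Node Leaf 3 Leaf)

-- Hypothetical helper: a tree with n nodes that leans entirely to the
-- left, i.e. every right subtree is a Leaf. Handy for timing experiments.
leftSpine :: Int -> Tree Int
leftSpine n = foldl (\t x -> Node t x Leaf) Leaf [1 .. n]
```

In GHCi, flatten1 example should give [1,2,3], and flatten1 (leftSpine n) should give [1 .. n].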
That's very concise and easy to understand. But this solution has a slight problem: it takes quadratic time in the worst case. This is because ++ takes time linear in the length of its left argument, so on a tree that leans to the left the same elements get copied again at every level. It is a well-known problem with many solutions. Incidentally, some people seem to think that the naive version is fine; it is used on the type level by alpheccar in his blog.

The standard solution, due to John Hughes, is to represent lists as functions from lists to lists. This way we get constant-time concatenation. Implementing lists the Hughes way gives us our second variation of flatten.
> flatten2 :: Tree a -> [a]
> flatten2 tree = help2 tree []
> 
> help2 :: Tree a -> [a] -> [a]
> help2 Leaf           = id
> help2 (Node t1 a t2) = help2 t1 . (a :) . help2 t2
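For readers who have not met the Hughes representation before, the following is a minimal sketch of it in isolation. The names DList, fromList, singleton, append and toList are chosen for this illustration only; they are assumptions, not names used elsewhere in this post.

```haskell
-- A list is represented by the function that prepends it to another list.
type DList a = [a] -> [a]

-- Convert an ordinary list into the function representation.
fromList :: [a] -> DList a
fromList xs = (xs ++)

-- A one-element list; this is the role (a :) plays in help2.
singleton :: a -> DList a
singleton x = (x :)

-- Concatenation is just function composition, which is constant time.
append :: DList a -> DList a -> DList a
append = (.)

-- Get an ordinary list back by applying the function to the empty list.
toList :: DList a -> [a]
toList f = f []

demo :: [Int]
demo = toList (fromList [1, 2] `append` singleton 3 `append` fromList [4, 5])
```

Read this way, help2 Leaf = id is the empty list, and flatten2 plays the part of toList by supplying the final [].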
Hmmmm. I don't know what you think of this version, but I don't like it, for two reasons. First, it uses higher-order programming when it isn't necessary, and that just obfuscates things. Secondly, it needs a helper function. It would be nice if we could do without that.

Let's start by tackling the first problem, the higher-order programming. One way to remove the higher-orderness is to give the function help2 an extra argument (this is called eta expansion, if that means anything to you). This gives us the following version.
> flatten3 :: Tree a -> [a]
> flatten3 tree = help3 tree []
> 
> help3 :: Tree a -> [a] -> [a]
> help3 Leaf           rest = rest
> help3 (Node t1 a t2) rest = help3 t1 (a : help3 t2 rest)
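One way to read help3 is through the equation help3 t rest == flatten1 t ++ rest. The sketch below states that as a plain Boolean check and walks through one small evaluation in comments. The names specHolds and example are hypothetical; Tree, flatten1 and help3 are repeated from the post so the snippet is self-contained.

```haskell
data Tree a = Node (Tree a) a (Tree a) | Leaf deriving Show

flatten1 :: Tree a -> [a]
flatten1 Leaf           = []
flatten1 (Node t1 a t2) = flatten1 t1 ++ [a] ++ flatten1 t2

help3 :: Tree a -> [a] -> [a]
help3 Leaf           rest = rest
help3 (Node t1 a t2) rest = help3 t1 (a : help3 t2 rest)

-- Hypothetical check: the accumulator version should agree with the
-- naive version followed by an append of the accumulator.
specHolds :: Tree Int -> [Int] -> Bool
specHolds t rest = help3 t rest == flatten1 t ++ rest

example :: Tree Int
example = Node (Node Leaf 1 Leaf) 2 (Node Leaf 3 Leaf)

-- One possible reduction order for a smaller tree:
--   help3 (Node (Node Leaf 1 Leaf) 2 Leaf) []
-- = help3 (Node Leaf 1 Leaf) (2 : help3 Leaf [])
-- = help3 Leaf (1 : help3 Leaf (2 : help3 Leaf []))
-- = 1 : help3 Leaf (2 : help3 Leaf [])
-- = 1 : 2 : help3 Leaf []
-- = [1, 2]
```

Note that no ++ appears in help3 itself: each element is consed onto the front of the accumulated rest exactly once.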
I call the extra argument rest because it contains the rest of the elements that we should put in the list. This solution still uses a helper function, which is a bit of a nuisance, but unfortunately I don't know an easy way to remove it.

Instead I'm going to present another way to think about flattening that will take us down another road. How would you go about implementing an inorder traversal with a while loop instead of recursion? The solution would be to use a stack to remember the parts of the tree that we haven't visited yet. We can use the same idea to write a new version of flatten, using a list as the stack. The top of the stack will be the part of the tree that we are visiting at the moment. We will need a little trick to remember when we have already visited a subtree, and that is to simply remove it from its parent, replacing it with Leaf. The final code looks like this.
> flatten4 :: Tree a -> [a]
> flatten4 tree = help4 [tree]
>
> help4 :: [Tree a] -> [a]
> help4 [] = []
> help4 (Leaf : ts) = help4 ts
> help4 (Node Leaf a t : ts) = a : help4 (t : ts)
> help4 (Node t1 a t2 : ts) = help4 (t1 : Node Leaf a t2 : ts)
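To watch the stack evolve, here is a sketch that follows the same four equations as help4 but records every stack it passes through instead of producing the output list. The names stackStates and example are hypothetical and exist only for this illustration.

```haskell
data Tree a = Node (Tree a) a (Tree a) | Leaf deriving Show

-- Mirrors help4 equation by equation, but returns the successive stacks.
stackStates :: [Tree a] -> [[Tree a]]
stackStates []                       = [[]]
stackStates s@(Leaf : ts)            = s : stackStates ts
stackStates s@(Node Leaf _ t : ts)   = s : stackStates (t : ts)
stackStates s@(Node t1 a t2 : ts)    = s : stackStates (t1 : Node Leaf a t2 : ts)

example :: Tree Int
example = Node (Node Leaf 1 Leaf) 2 Leaf

-- The stacks for example, written with L for Leaf and N for Node:
--   [N (N L 1 L) 2 L]       left subtree not visited yet: push it
--   [N L 1 L, N L 2 L]      left is Leaf: emit 1, continue with the right
--   [L, N L 2 L]            nothing here: pop
--   [N L 2 L]               left is Leaf: emit 2, continue with the right
--   [L]                     nothing here: pop
--   []                      done
```

The parent pushed back in the last equation always has Leaf as its left child, which is how the traversal knows not to descend to the left a second time.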
This version is pretty clever. It takes apart the tree and spreads it out into the list until there is nothing left of the tree. But one might argue that this version is too clever. It is certainly not easy to understand from the code. Plus it uses a helper function, something that we have been trying to get away from.

The final version that I will show builds upon the idea that we can represent the list of trees from our previous version as simply a tree. We are going to use the input tree as the stack and modify it as we go along. It goes like this.
> flatten5 :: Tree a -> [a]
> flatten5 Leaf = []
> flatten5 (Node Leaf a t) = a : flatten5 t
> flatten5 (Node (Node t1 a t2) b t3) = flatten5 (Node t1 a (Node t2 b t3))
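The restructuring in the third equation can be looked at on its own. The sketch below isolates it as a helper, here called rotate (a hypothetical name, as is example), and applies it repeatedly to a small left-leaning tree. Tree and flatten1 are repeated from the post, with Eq added to the deriving clause for the comparisons.

```haskell
data Tree a = Node (Tree a) a (Tree a) | Leaf deriving (Show, Eq)

flatten1 :: Tree a -> [a]
flatten1 Leaf           = []
flatten1 (Node t1 a t2) = flatten1 t1 ++ [a] ++ flatten1 t2

-- The restructuring performed by the third equation of flatten5,
-- applied once at the root. Other trees are returned unchanged.
rotate :: Tree a -> Tree a
rotate (Node (Node t1 a t2) b t3) = Node t1 a (Node t2 b t3)
rotate t                          = t

example :: Tree Int
example = Node (Node (Node Leaf 1 Leaf) 2 Leaf) 3 Leaf

-- rotate example
--   = Node (Node Leaf 1 Leaf) 2 (Node Leaf 3 Leaf)
-- rotate (rotate example)
--   = Node Leaf 1 (Node Leaf 2 (Node Leaf 3 Leaf))
```

Each step leaves the inorder sequence unchanged, and after two steps the root has a Leaf on its left, so the second equation of flatten5 can read off the first element.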
What happens in the third equation is that the tree is restructured so that it becomes a linear structure, just like a list. It is then simple to read off the element in the second equation.

This is my favorite version. It's short, with no helper functions and no higher-order stuff. Plus it is efficient. Now, you might complain that it does more allocations than some of the previous versions. Indeed we have to allocate new trees in the third equation. But note that that equation is tail recursive, which means that it doesn't use any stack. It is possible to implement this version of the function so that it doesn't use any more space than any of the other versions above: the control stack is encoded into the tree.

One additional thing about this version is that its termination is not entirely obvious. This function will fool several termination checkers that I know of. In particular, in the third equation the argument tree in the recursive call is not smaller than the input tree, which is a criterion that many termination checkers use. To understand that the function terminates you must instead see that it is the left subtree of the input that gets smaller, together with the fact that trees are finite. To me this is an extra bonus. You might think otherwise.

Finally, I have some code to test that all the versions I've given above agree with each other.
> instance Arbitrary a => Arbitrary (Tree a) where
>   arbitrary = sized arbTree
>    where arbTree 0 = return Leaf
>          arbTree n | n>0 = frequency $
>                 [(1, return Leaf)
>                 ,(4, do elem <- arbitrary
>                         t1   <- arbTree (n `div` 2)
>                         t2   <- arbTree (n `div` 2)
>                         return (Node t1 elem t2))]
> prop_flatten2 t = flatten1 t == flatten2 (t ::Tree Int)
> prop_flatten3 t = flatten1 t == flatten3 (t ::Tree Int)
> prop_flatten4 t = flatten1 t == flatten4 (t ::Tree Int)
> prop_flatten5 t = flatten1 t == flatten5 (t ::Tree Int)
> runTests = mapM_ quickCheck [prop_flatten2,prop_flatten3,
>                              prop_flatten4,prop_flatten5]
> 
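Coming back to the termination argument for flatten5: one way to make it precise, and this is only my reading of it, is a lexicographic measure consisting of the total number of nodes followed by the number of nodes in the left subtree. The second equation shrinks the first component, and the third equation keeps the first component fixed while shrinking the second. The names size, measure, next, decreases and allDecrease below are hypothetical helpers for this sketch.

```haskell
data Tree a = Node (Tree a) a (Tree a) | Leaf

-- Number of nodes in a tree.
size :: Tree a -> Int
size Leaf         = 0
size (Node l _ r) = size l + 1 + size r

-- (total size, size of the left subtree), compared lexicographically
-- by the standard Ord instance for pairs.
measure :: Tree a -> (Int, Int)
measure Leaf           = (0, 0)
measure t@(Node l _ _) = (size t, size l)

-- The argument flatten5 passes to its recursive call, if there is one.
next :: Tree a -> Maybe (Tree a)
next Leaf                       = Nothing
next (Node Leaf _ t)            = Just t
next (Node (Node t1 a t2) b t3) = Just (Node t1 a (Node t2 b t3))

-- Does a single step of flatten5 make the measure strictly smaller?
decreases :: Tree a -> Bool
decreases t = maybe True (\t' -> measure t' < measure t) (next t)

-- Check every step along the whole chain of recursive calls.
allDecrease :: Tree a -> Bool
allDecrease t = decreases t && maybe True allDecrease (next t)
```

Since both components are natural numbers, a strictly decreasing chain of such pairs has to be finite, which is what a termination checker would need to be convinced of.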
