When I started learning programming, one of the most confusing concepts I encountered was recursion. Now that I have finally become the pro programmer I wished I could be as a newbie, I feel that all that confusion and struggle was unnecessary. If I had had a teacher like the present me, I would have learned it easily! Recently, while learning Haskell, I became enamored with how elegant Haskell code can be, especially recursive code. Inspired by recursion in Haskell, I decided to rewrite some common functions from JS and Python in Haskell style. In this post, I’ll show you how to write recursive functions step by step. At the end, I’ll show you more recursive functions I wrote!
Many tutorials explain recursion by introducing Fibonacci numbers. I think that’s counterproductive. We don’t need another complicated concept to explain an already complicated concept. Instead, I’ll start with some plain reasoning.
Let’s take a look at this piece of code:
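Here it is as the smallest possible self-calling function: `foo` takes no arguments and has no stopping condition. In this sketch the call is wrapped in `try`/`except` so the error message is printed rather than ending your session with a long traceback.

```python
def foo():
    # foo calls itself unconditionally: nothing ever tells it to stop
    foo()

try:
    foo()
except RecursionError as err:
    # Python aborts once the call stack grows past its recursion limit
    print("RecursionError:", err)
```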
Open a terminal or any REPL, and run this code. You’ll get an error! Don’t be angry with me yet; this is to be expected. Let’s reason about this code for a bit. This `foo` function is defined to call itself, and since nothing tells it to stop, it keeps calling itself until the call stack gets too deep. Python then stops it with a RecursionError whose message, in human language, reads “maximum recursion depth exceeded”.
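You can ask Python how deep it lets the call stack grow with `sys.getrecursionlimit()` from the standard library. The sketch below pairs it with a hypothetical `depth` helper to show that recursion which stays well under that limit finishes normally. The exact limit depends on your interpreter and settings; in CPython the default is usually 1000.

```python
import sys

def depth(n):
    # Base case: no more calls to make
    if n == 0:
        return 0
    # Each call adds one stack frame and recurses on a smaller n
    return 1 + depth(n - 1)

limit = sys.getrecursionlimit()
print("recursion limit:", limit)
# Staying well below the limit, the recursion completes normally
print(depth(limit // 2))
```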
Let’s modify the code to make it stop after meeting a condition:
```python
def foo(n):
    # Stop once n drops to 1 or below
    if n <= 1:
        return
    foo(n - 1)

foo(10)
```
This code does nothing, but this time it will not throw an error!
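There are three major changes that should draw your attention in the second version of `foo`:
1. The `foo` function takes an argument, `n`.
2. It has a stopping condition: once `n <= 1`, it returns without calling itself again.
3. Each recursive call passes `n - 1`, so every call moves one step closer to that stopping condition.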
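To see those three pieces at work, here is a hypothetical `countdown` variant of `foo` (not from the original code). It keeps the same stopping condition, but instead of discarding its argument it returns the argument of every call, so you can watch `n` shrink toward the base case.

```python
def countdown(n):
    # Base case: the stopping condition, reached once n drops to 1 or below
    if n <= 1:
        return [n]
    # Recursive case: record n, then call again with a smaller argument
    return [n] + countdown(n - 1)

print(countdown(5))  # [5, 4, 3, 2, 1]
```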
Believe it or not, that’s all recursion is about! We’ll consolidate our understanding by writing some recursive functions. All the functions I’m going to show you are inspired by Haskell!
Let’s first write our own version of `max` in Python. In real code you should use the built-in `max` function; this is just for practice.
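```python
def max2(xs):
    # Base case: the max of a one-element list is that element
    if len(xs) == 1:
        return xs[0]
    head, tail = xs[0], xs[1:]
    # Compute the max of the tail once, then compare it with head
    tail_max = max2(tail)
    return head if head > tail_max else tail_max

print(max2([3, 98, 345]))  # 345
```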
The `max2` function takes a list. If the list has only one element, the function stops and returns that element. Notice that when a recursion stops, it must return a value. (If your function exists to perform side effects rather than compute a result, it can return nothing.) Otherwise, we take the first element of the list and name it `head`, and name the rest of the list `tail`.
We compare `head` with the largest element of `tail`, which we don’t know yet, so we pass `tail` to `max2` again. We don’t need to care how the remaining calls to `max2` unfold, because we know they will eventually reach the one-element case and return a value. If `head` is bigger than the max of `tail`, then `head` is the max of the original list. Otherwise, the result of `max2(tail)` is the max.
If this is still confusing, I suggest passing a short list to the function, writing down the execution process, and observing each step.
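One way to do that is to let the function print its own trace. The sketch below is a hypothetical `max2_traced`, identical to `max2` except for an extra `depth` parameter that is used only to indent the printout, so each nested call sits one level deeper than its caller.

```python
def max2_traced(xs, depth=0):
    # depth only controls indentation, so nested calls are easy to tell apart
    indent = "  " * depth
    print(f"{indent}max2({xs})")
    if len(xs) == 1:
        print(f"{indent}-> {xs[0]}  (base case)")
        return xs[0]
    head, tail = xs[0], xs[1:]
    tail_max = max2_traced(tail, depth + 1)
    result = head if head > tail_max else tail_max
    print(f"{indent}-> {result}  (compare {head} with {tail_max})")
    return result

max2_traced([3, 98, 345])
# max2([3, 98, 345])
#   max2([98, 345])
#     max2([345])
#     -> 345  (base case)
#   -> 345  (compare 98 with 345)
# -> 345  (compare 3 with 345)
```

Reading the output top to bottom shows the calls going down to the base case; reading the `->` lines shows the answers coming back up.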
`reverse` takes a list and returns a new list with the elements in reverse order.
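```python
# Python already has reversed() and list.reverse(),
# so I name this function reverse2
def reverse2(xs):
    # Base case: a list with zero or one element is its own reverse
    if len(xs) <= 1:
        return xs
    head, tail = xs[0], xs[1:]
    # Reverse the tail, then put the head at the end
    return reverse2(tail) + [head]

print(reverse2([1, 2, 3, 4, 5, 6]))  # [6, 5, 4, 3, 2, 1]
```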
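A common alternative, not the approach used above, threads an accumulator through the calls so the reversed list is built on the way down rather than on the way back up. `reverse_acc` and its `acc` parameter are hypothetical names for this sketch. Note that CPython does not eliminate tail calls, so this version still runs into the recursion limit on long lists.

```python
def reverse_acc(xs, acc=None):
    # acc holds the elements reversed so far; start with an empty list
    if acc is None:
        acc = []
    # Base case: nothing left to move, so the accumulator is the answer
    if len(xs) == 0:
        return acc
    head, tail = xs[0], xs[1:]
    # Put head in front of what has been reversed so far and keep going
    return reverse_acc(tail, [head] + acc)

print(reverse_acc([1, 2, 3, 4, 5, 6]))  # [6, 5, 4, 3, 2, 1]
```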
`map` takes a function and a list, applies the function to each element of the list, and returns a list of the results.
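```python
# Python has a built-in map,
# so I name this function map2
def map2(f, xs):
    # Base case: mapping over an empty list gives an empty list
    if len(xs) == 0:
        return []
    head, tail = xs[0], xs[1:]
    # Apply f to the head, then map over the rest
    return [f(head)] + map2(f, tail)

print(map2(lambda x: x + 1, [2, 2, 2, 2]))  # [3, 3, 3, 3]
```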
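As a sanity check, `map2` should agree with Python’s built-in `map`. One difference worth knowing: in Python 3 the built-in `map` returns a lazy iterator, so it needs `list()` around it before the two can be compared. The word list here is made up for illustration.

```python
def map2(f, xs):
    # Repeated from above so this block runs on its own
    if len(xs) == 0:
        return []
    head, tail = xs[0], xs[1:]
    return [f(head)] + map2(f, tail)

words = ["recursion", "is", "fun"]
mine = map2(str.upper, words)
# Built-in map is lazy in Python 3, so materialize it with list()
builtin = list(map(str.upper, words))
print(mine)             # ['RECURSION', 'IS', 'FUN']
print(mine == builtin)  # True
```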
`zipWith` takes a function and two lists. It walks over the two lists in parallel and, at each step, combines the current element of each list using the provided function.
```python
def zipWith(f, listA, listB):
    # Base case: stop as soon as either list is empty
    if len(listA) == 0 or len(listB) == 0:
        return []
    headA, tailA = listA[0], listA[1:]
    headB, tailB = listB[0], listB[1:]
    # Combine the two heads with f, then zip the two tails
    return [f(headA, headB)] + zipWith(f, tailA, tailB)

print(zipWith(lambda x, y: x + y, [2, 2, 2, 2], [3, 3, 3, 3, 3]))  # [5, 5, 5, 5]
# The result list will only be as long as the shorter source list
```
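`zipWith` is more general than Python’s built-in `zip`: give it a function that simply builds a tuple and you get `zip`’s behavior back, including stopping at the end of the shorter list. The sample lists below are made up for illustration.

```python
def zipWith(f, listA, listB):
    # Repeated from above so this block runs on its own
    if len(listA) == 0 or len(listB) == 0:
        return []
    headA, tailA = listA[0], listA[1:]
    headB, tailB = listB[0], listB[1:]
    return [f(headA, headB)] + zipWith(f, tailA, tailB)

names = ["a", "b", "c"]
counts = [1, 2, 3, 4]
# A function that just pairs its arguments turns zipWith into zip
pairs = zipWith(lambda x, y: (x, y), names, counts)
print(pairs)                              # [('a', 1), ('b', 2), ('c', 3)]
print(pairs == list(zip(names, counts)))  # True
```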
`replicate` takes a number n and an arbitrary element, and returns a list of length n in which every element is that arbitrary element.
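```python
def replicate(n, x):
    # Base case: zero (or fewer) copies is an empty list
    if n <= 0:
        return []
    # One copy of x, followed by n - 1 more
    return [x] + replicate(n - 1, x)

print(replicate(4, 'hello'))  # ['hello', 'hello', 'hello', 'hello']
```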
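In everyday Python the same result comes from list repetition, `[x] * n`, which likewise returns an empty list when `n` is zero or negative. This sketch checks that the recursive version matches it.

```python
def replicate(n, x):
    # Repeated from above so this block runs on its own
    if n <= 0:
        return []
    return [x] + replicate(n - 1, x)

print(replicate(3, 0))             # [0, 0, 0]
print(replicate(3, 0) == [0] * 3)  # True
print(replicate(-2, "hi"))         # []
```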
`filter` takes a predicate function and a list, and returns a new list containing only the elements that pass the predicate.
```python
# Python has a built-in filter,
# so I name this function filter2
def filter2(f, xs):
    # Base case: filtering an empty list gives an empty list
    if len(xs) == 0:
        return []
    head, tail = xs[0], xs[1:]
    # Keep head only if it passes the predicate, then filter the tail
    return [head] + filter2(f, tail) if f(head) else filter2(f, tail)

print(filter2(lambda x: x > 4, [3, 2, 4, 5, 6]))  # [5, 6]
```
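`filter2` lines up with two everyday Python idioms: a list comprehension with an `if` clause, and the built-in `filter` (which, like `map`, returns a lazy iterator in Python 3). The `is_even` predicate and the sample list are made up for illustration.

```python
def filter2(f, xs):
    # Repeated from above so this block runs on its own
    if len(xs) == 0:
        return []
    head, tail = xs[0], xs[1:]
    return [head] + filter2(f, tail) if f(head) else filter2(f, tail)

def is_even(x):
    return x % 2 == 0

nums = [3, 2, 4, 5, 6]
mine = filter2(is_even, nums)
print(mine)                                     # [2, 4, 6]
print(mine == [x for x in nums if is_even(x)])  # True
print(mine == list(filter(is_even, nums)))      # True
```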
Here comes the most complicated one so far. Quick sort is a very common sorting algorithm in computer science. This `quickSort` function takes a list of numbers and returns the same numbers sorted in ascending order. The function first takes out the first number of the list as the pivot, then iterates over the rest of the list, comparing each element with the pivot. If the element is smaller, it is appended to a new list called `smaller`; otherwise, it is appended to another new list called `bigger`. Then the function recurses: it applies `quickSort` to both `smaller` and `bigger`, and concatenates the results with the pivot in the middle. The `smaller` and `bigger` lists are divided again in every recursive call until their length drops to one or zero, at which point the recursion stops and returns the list unchanged.
```python
def quickSort(xs):
    # Base case: empty and one-element lists are already sorted
    if len(xs) <= 1:
        return xs
    pivot, rest = xs[0], xs[1:]
    smaller, bigger = [], []
    # Partition the rest of the list around the pivot
    for x in rest:
        if x < pivot:
            smaller.append(x)
        else:
            bigger.append(x)
    return quickSort(smaller) + [pivot] + quickSort(bigger)

print(quickSort([44, 14, 65, 34]))  # [14, 34, 44, 65]
```
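The partition step can also be written with two list comprehensions, which reads closer to how this function is often written in Haskell. `quickSort2` is a hypothetical name for this variant. It scans `rest` twice instead of once, trading a little speed for brevity, and, like the version above, uses the first element as the pivot.

```python
def quickSort2(xs):
    # Base case: empty and one-element lists are already sorted
    if len(xs) <= 1:
        return xs
    pivot, rest = xs[0], xs[1:]
    # Partition with comprehensions instead of an explicit loop
    smaller = [x for x in rest if x < pivot]
    bigger = [x for x in rest if x >= pivot]
    return quickSort2(smaller) + [pivot] + quickSort2(bigger)

print(quickSort2([44, 14, 65, 34]))  # [14, 34, 44, 65]
print(quickSort2([3, 1, 3, 2, 1]))   # [1, 1, 2, 3, 3]
```

Duplicates of the pivot go into `bigger`, so equal values are kept rather than dropped.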