5 Different Ways to Flatten a List of Lists in Python
In this post, we’ll see how to flatten a list in 5 different ways. Each method has its pros and cons, and their performance varies. By the end of this tutorial, I hope you’ll be able to identify the most appropriate solution for your problem. The Python version used in this tutorial was 3.8, and for the tests I used
Flattening a List With List or Generator Comprehensions
Let’s imagine that we have a simple list of lists like this, [[1, 3], [2, 5], [1]], and we want to flatten it. In other words, we want to turn it into [1, 3, 2, 5, 1]. The first way of doing that is through list/generator comprehensions. For those already familiar with them, this may sound very straightforward: we iterate through each sub-list, then iterate over the items of that sub-list, producing a single element at a time.
The following function accepts a list as an argument and returns a generator. The reason is to avoid building a whole new list in memory: a generator produces the elements on demand.
To make sure everything works as expected, we can assert the behavior with the test_flatten unit test.
from typing import List, Any, Iterable


def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a generator comprehension."""
    return (item for sublist in lst for item in sublist)


def test_flatten():
    l = [[1, 3], [2, 5], [1]]
    assert list(flatten_gen_comp(l)) == [1, 3, 2, 5, 1]
When we run the test we can see it passing...
============================= test session starts ==============================
flatten.py::test_flatten PASSED [100%]

============================== 1 passed in 0.01s ===============================

Process finished with exit code 0
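The section heading mentions list comprehensions too, so here is a minimal sketch of the eager variant, assuming you actually want a list rather than a lazy generator. The name flatten_list_comp is hypothetical; the only difference from flatten_gen_comp is the square brackets, which build the whole result up front.

```python
from typing import Any, Iterator, List


def flatten_list_comp(lst: List[List[Any]]) -> List[Any]:
    """Flatten one level of nesting eagerly (hypothetical helper)."""
    return [item for sublist in lst for item in sublist]


def flatten_gen_comp(lst: List[List[Any]]) -> Iterator[Any]:
    """Same traversal, but lazy: items are produced on demand."""
    return (item for sublist in lst for item in sublist)


l = [[1, 3], [2, 5], [1]]
eager = flatten_list_comp(l)  # a fully built list
lazy = flatten_gen_comp(l)    # nothing computed yet
first = next(lazy)            # computes just the first item
rest = list(lazy)             # consumes the remaining items
```

Reading the comprehension left to right mirrors two nested for loops: the outer for picks a sub-list, and the inner for picks each item in it.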
Flattening a List With sum
The second strategy is a bit unconventional and, truth be told, very "magical". It turns out we can use the built-in sum function to flatten a list. Doing that is simple and concise: we pass the list of lists as the first argument and an empty list as the start value. The following function illustrates that.
If you’re curious about this approach, I discuss it in more detail in the 5 Hidden Python Features You Probably Never Heard Of post.
from typing import List, Any, Iterable


def flatten_sum(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using sum."""
    return sum(lst, [])


def test_flatten():
    l = [[1, 3], [2, 5], [1]]
    assert list(flatten_sum(l)) == [1, 3, 2, 5, 1]
And the tests pass too...
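Even though it seems clever, it's not a good idea to use it in production IMHO. Each + creates a brand-new list and copies everything accumulated so far, so the total work grows quadratically with the number of sub-lists. As we'll see later, this is the worst way to flatten a list in terms of performance.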
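To see where the cost comes from, the loop below spells out roughly what sum(lst, []) does: it starts from the empty list and applies + once per sub-list. The helper name sum_like is hypothetical, and it is a simplified model of the behavior rather than CPython's actual implementation.

```python
from typing import Any, List


def sum_like(lst: List[List[Any]], start: List[Any]) -> List[Any]:
    """Simplified model of sum(lst, start) for lists (hypothetical helper)."""
    acc = start
    for sublist in lst:
        # `acc + sublist` builds a brand-new list, copying len(acc) items
        # every time, so the total copying grows quadratically.
        acc = acc + sublist
    return acc


l = [[1, 3], [2, 5], [1]]
result = sum_like(l, [])
same_as_sum = result == sum(l, [])
```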
The third alternative is to use the chain function from the itertools module. In a nutshell, chain creates a single iterator from a sequence of other iterables, which makes it a perfect match for our use case.
An equivalent implementation of chain, taken from the official docs, looks like this:
def chain(*iterables):
    # chain('ABC', 'DEF') --> A B C D E F
    for it in iterables:
        for element in it:
            yield element
We can then flatten the list like so:
import itertools
from typing import List, Any, Iterable


def flatten_chain(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using chain."""
    return itertools.chain(*lst)


def test_flatten():
    l = [[1, 3], [2, 5], [1]]
    assert list(flatten_chain(l)) == [1, 3, 2, 5, 1]
And the test passes...
[OMITTED] .... flatten.py::test_flatten PASSED ... [OMITTED]
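itertools also provides chain.from_iterable, an alternate constructor that takes the outer list as a single argument instead of unpacking it with *. The sketch below uses only the standard library; the name flatten_chain_from_iterable is hypothetical. A practical difference is that it consumes the outer iterable lazily rather than unpacking it all into call arguments up front, which is handy when the outer sequence is itself a generator.

```python
import itertools
from typing import Any, Iterable, Iterator


def flatten_chain_from_iterable(lst: Iterable[Iterable[Any]]) -> Iterator[Any]:
    """Flatten one level using chain.from_iterable (hypothetical helper name)."""
    return itertools.chain.from_iterable(lst)


l = [[1, 3], [2, 5], [1]]
flat = list(flatten_chain_from_iterable(l))

# The outer iterable can be lazy too, e.g. a generator of sub-lists.
lazy_outer = ([n, n * 10] for n in range(3))
flat_lazy = list(flatten_chain_from_iterable(lazy_outer))
```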
Flattening Irregular Lists
What happens if we have a list like [[1, 3], [2, 5], 1], or like [1, [2, 3], [4], [], [[[[[[[5]]]]]]]]? Unfortunately, none of the preceding approaches handle these well: they raise a TypeError as soon as they hit an element that isn't iterable, such as the int 1, and they only remove one level of nesting. In this section, we’ll see two solutions for that, one recursive and the other iterative.
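A quick way to see the problem is to feed an irregular list to one of the one-level approaches. The sketch below restates flatten_gen_comp so it runs on its own; the int 1 at the top level is not iterable, so the generator raises TypeError partway through, after it has already produced the items of the earlier sub-lists.

```python
from typing import Any, Iterator, List


def flatten_gen_comp(lst: List[Any]) -> Iterator[Any]:
    """One-level flatten from the first section."""
    return (item for sublist in lst for item in sublist)


irregular = [[1, 3], [2, 5], 1]
produced = []
error = None
try:
    for item in flatten_gen_comp(irregular):
        produced.append(item)
except TypeError as exc:
    # Raised when the comprehension tries to iterate over the int 1.
    error = exc

# Even without bare ints, only one level of nesting is removed.
deep = [[1, [2, 3]], [4]]
one_level = list(flatten_gen_comp(deep))
```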
The Recursive Approach
Solving the flattening problem recursively means iterating over each item and deciding whether the element is already flat. If it is, we yield it; otherwise, we call flatten_recursive on it. But code speaks louder than words, so let’s see some code.
from typing import List, Any, Iterable


def flatten_recursive(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using recursion."""
    for item in lst:
        if isinstance(item, list):
            yield from flatten_recursive(item)
        else:
            yield item


def test_flatten_recursive():
    l = [[1, 3], [2, 5], 1]
    assert list(flatten_recursive(l)) == [1, 3, 2, 5, 1]
Bear in mind that the isinstance(item, list) check means it only works with lists: tuples and other iterables are yielded as-is.
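If you need to handle tuples, sets, or other iterables as well, one option is to test against collections.abc.Iterable instead of list. A minimal sketch, assuming that strings and bytes should be treated as atoms (otherwise a one-character string would recurse forever, since iterating over it yields itself); the name flatten_iterable is hypothetical.

```python
from collections.abc import Iterable
from typing import Any, Iterator


def flatten_iterable(items: Iterable) -> Iterator[Any]:
    """Recursively flatten any nested iterables (hypothetical helper).

    str and bytes are yielded whole: a 1-character string iterates to
    itself, which would otherwise recurse without end.
    """
    for item in items:
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            yield from flatten_iterable(item)
        else:
            yield item


mixed = [1, (2, 3), [4, [5, ("six",)]], {7}]
flat = list(flatten_iterable(mixed))
```

Keep in mind that this treats every iterable as a container, so a dict, for example, would contribute only its keys.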
The Iterative Approach
The iterative approach is without doubt the most complex of them all. In this approach, we’ll use a deque to flatten the irregular list. Quoting the official docs, “Deques are a generalization of stacks and queues (the name is pronounced “deck” and is short for “double-ended queue”). Deques support thread-safe, memory efficient appends and pops from either side of the deque with approximately the same O(1) performance in either direction.”
In other words, we can use a deque to simulate the call stack of the recursive solution.
To do that, we’ll start just like we did in the recursive case: we iterate through each element and, if the element is not a list, we add it to the left end of the deque. The appendleft method inserts an element at the leftmost position of the deque, for example:
>>> from collections import deque
>>> l = deque()
>>> l.appendleft(2)
>>> l
deque([2])
>>> l.appendleft(7)
>>> l
deque([7, 2])
If the element is a list, though, we need to reverse it first and then pass it to the extendleft method, like this: my_deque.extendleft(reversed(item)). Like appendleft, extendleft adds each item to the leftmost position of the deque, one at a time. As a result, the elements end up in the deque in reverse order. That’s exactly why we need to reverse the sub-list before extending left. To make things clearer, let’s see an example.
>>> l = deque()
>>> l.extendleft([1, 2, 3])
>>> l
deque([3, 2, 1])
The final step is to repeatedly remove the leftmost element of the deque, yielding it if it’s not a list. If the element is a list, we extend the deque to the left with its (reversed) contents instead. The full implementation goes like this:
from collections import deque
from typing import List, Any, Iterable


def flatten_deque(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a deque."""
    q = deque()
    # Walk the top-level items back to front: each appendleft/extendleft
    # pushes onto the left end, so this keeps lst[0] leftmost.
    for item in reversed(lst):
        if isinstance(item, list):
            q.extendleft(reversed(item))
        else:
            q.appendleft(item)
    while q:
        elem = q.popleft()
        if isinstance(elem, list):
            # Unpack the nested list in place, preserving its order.
            q.extendleft(reversed(elem))
        else:
            yield elem


def test_flatten_super_irregular():
    l = [1, [2, 3], [4], [], [[[[[[[5]]]]]]]]
    assert list(flatten_deque(l)) == [1, 2, 3, 4, 5]
When we run the test, it happily passes...
============================= test session starts ==============================
...
flatten.py::test_flatten_super_irregular PASSED [100%]

============================== 1 passed in 0.01s ===============================
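One practical reason to prefer the deque version is nesting depth. Each nested list costs flatten_recursive one more nested generator, so very deep inputs can run into Python's recursion limit (1000 by default, see sys.getrecursionlimit()), whereas the deque keeps its pending work in a data structure. The sketch below builds a list nested 5,000 levels deep and runs both functions, restated here so the block runs on its own, with the deque version walking the top level in reverse so it preserves order. Exactly where, or whether, the recursive version fails may depend on the Python version, so the code only catches the error rather than asserting it.

```python
from collections import deque
from typing import Any, Iterator, List


def flatten_recursive(lst: List[Any]) -> Iterator[Any]:
    for item in lst:
        if isinstance(item, list):
            yield from flatten_recursive(item)
        else:
            yield item


def flatten_deque(lst: List[Any]) -> Iterator[Any]:
    q = deque()
    for item in reversed(lst):
        if isinstance(item, list):
            q.extendleft(reversed(item))
        else:
            q.appendleft(item)
    while q:
        elem = q.popleft()
        if isinstance(elem, list):
            q.extendleft(reversed(elem))
        else:
            yield elem


# Build [[[...[42]...]]] nested 5,000 levels deep.
deep: List[Any] = [42]
for _ in range(5_000):
    deep = [deep]

deque_result = list(flatten_deque(deep))

try:
    recursive_result = list(flatten_recursive(deep))
except RecursionError:
    # Each level adds a nested generator; past the limit Python gives up.
    recursive_result = None
```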
As we’ve seen in the previous section, if our list is irregular, we have little choice. But assuming that this is not a frequent use case, how do these approaches compare in terms of performance?
In this last part, we’ll run a benchmark to see which solutions perform best.
To do that, we can use the timeit module, which IPython exposes as the %timeit magic: %timeit [code_to_measure]. The timings for each function are listed below. As we can see, flatten_chain is the fastest implementation of all, flattening our list in 217 µs on average. The slowest implementation is flatten_sum, taking around 30 ms to flatten the same list.
PS: A special thanks to @hynekcer, who pointed out a bug in this benchmark. Since most of the functions return a lazy generator or iterator, nothing is actually flattened until the elements are consumed, so we must consume all of them to get a meaningful measurement. We can either iterate over the generator or create a list out of it.
Flatten Generator Comprehension

lst = [[1, 2, 3], [4, 5, 6], [7], [8, 9]] * 1_000
%timeit list(flatten_gen_comp(lst))
433 µs ± 5.58 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Flatten Sum

%timeit list(flatten_sum(lst))
30 ms ± 343 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Flatten Chain

%timeit list(flatten_chain(lst))
217 µs ± 31.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Flatten Recursive

%timeit list(flatten_recursive(lst))
1.86 ms ± 20.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Flatten Deque

%timeit list(flatten_deque(lst))
2.04 ms ± 78.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
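Outside IPython, the same comparison can be scripted with the timeit module's repeat function. The sketch below is a reproduction under assumptions, not the exact setup behind the numbers above: it uses the same input list, wraps every call in list(...) so generators are fully consumed, restates the five functions (with the deque version walking the top level in reverse to keep order), and reports the best mean time per call. Absolute timings will differ across machines and Python versions.

```python
import itertools
import timeit
from collections import deque
from typing import Any, Dict, Iterable, List


def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    return (item for sublist in lst for item in sublist)


def flatten_sum(lst: List[Any]) -> Iterable[Any]:
    return sum(lst, [])


def flatten_chain(lst: List[Any]) -> Iterable[Any]:
    return itertools.chain(*lst)


def flatten_recursive(lst: List[Any]) -> Iterable[Any]:
    for item in lst:
        if isinstance(item, list):
            yield from flatten_recursive(item)
        else:
            yield item


def flatten_deque(lst: List[Any]) -> Iterable[Any]:
    q = deque()
    for item in reversed(lst):
        if isinstance(item, list):
            q.extendleft(reversed(item))
        else:
            q.appendleft(item)
    while q:
        elem = q.popleft()
        if isinstance(elem, list):
            q.extendleft(reversed(elem))
        else:
            yield elem


def benchmark(number: int = 10, repeat: int = 3) -> Dict[str, float]:
    """Return the best mean seconds per call for each function."""
    lst = [[1, 2, 3], [4, 5, 6], [7], [8, 9]] * 1_000
    results = {}
    for func in (flatten_gen_comp, flatten_sum, flatten_chain,
                 flatten_recursive, flatten_deque):
        # list(...) forces generators to do all their work.
        timings = timeit.repeat(lambda: list(func(lst)), number=number, repeat=repeat)
        results[func.__name__] = min(timings) / number
    return results


if __name__ == "__main__":
    for name, seconds in sorted(benchmark().items(), key=lambda kv: kv[1]):
        print(f"{name:20} {seconds * 1e6:10.1f} µs")
```

Taking the minimum of the repeats is a common way to reduce noise from other processes, since slower runs are usually caused by interference rather than by the code itself.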
We can flatten a list in several different ways with Python. Each approach has its pros and cons. In this post we looked at 5 different ways to flatten a list, including irregular lists.
If you liked this post, consider sharing it with your friends or buying me a coffee:)
See you next time!
All functions measured with a line like %timeit some_function(l) are measured incorrectly. A time on the order of hundreds of nanoseconds to get a list of 9 * 1000 numbers is nonsense. You will see that the result doesn't depend on the number of items.
It is because the code in a generator function, including anything after a yield, is postponed until you iterate over the output values. If you put a slow sleep command after a yield line, you will see no difference, because it never runs. Even a generator expression command %timeit (x for x in range(1000000000)) is about as fast as the same for
A correct measurement is made with %timeit list(some_function(l)) or %timeit for _ in some_function(l): pass
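The point about deferred execution is easy to check. In the sketch below, slow_gen is a hypothetical generator function with a sleep after its yield: calling it returns almost instantly because none of its body has run yet, and the delay only shows up once list(...) consumes it.

```python
import time


def slow_gen():
    """Hypothetical generator whose body is deliberately slow."""
    yield 1
    time.sleep(0.2)  # runs only when the consumer asks for the next item


start = time.perf_counter()
gen = slow_gen()    # creates the generator object only
created_in = time.perf_counter() - start

start = time.perf_counter()
items = list(gen)   # now the body runs, including the sleep
consumed_in = time.perf_counter() - start
```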
The fastest code for flattening is flatten_chain(l), because the expression itertools.chain(*lst) uses only a library function call and no Python loops.
By the way, is this a typo?
def flatten_recursive(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using recursion."""
    for item in lst:
        if isinstance(item, list):
            yield from flatten(item)  # This line! I think it should be `yield from flatten_recursive(item)`
        else:
            yield item
Thanks, HEE JAE CHOI, and yes, this is a typo!
Thanks for flagging it; I've just fixed it. I renamed the function before publishing to make it easier to understand and forgot to double-check it in IPython.
This is the correct version:
def flatten_recursive(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using recursion."""
    for item in lst:
        if isinstance(item, list):
            yield from flatten_recursive(item)
        else:
            yield item
Here's a different, non-recursive solution for deep flattening:
def flatten(bad):
    good = []
    # Note: this pops from `bad`, so the input list ends up empty.
    while bad:
        e = bad.pop()
        if isinstance(e, list):
            bad.extend(e)
        else:
            good.append(e)
    return good[::-1]
Very likely quite fast as well.
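Note that this flatten empties the list passed to it, because it pops from bad directly. If the caller still needs the original, a simple option is to run the same algorithm on a shallow copy; the name flatten_copy is hypothetical. Only the outer list needs copying: the nested lists are merely read (extend copies their items into the work list), so they are left untouched.

```python
from typing import Any, List


def flatten_copy(lst: List[Any]) -> List[Any]:
    """Stack-based deep flatten that leaves `lst` intact (hypothetical name)."""
    work = list(lst)  # shallow copy: we pop from this, not from `lst`
    good = []
    while work:
        e = work.pop()
        if isinstance(e, list):
            work.extend(e)  # only reads `e`; nested lists are not modified
        else:
            good.append(e)
    return good[::-1]  # leaves were collected right to left


data = [1, [2, 3], [4], [], [[[[[[[5]]]]]]]]
flat = flatten_copy(data)
```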
Hynek Černoch: Maybe it would be better to use this quite different function as an in-place flatten, very much like list.sort() or list.remove(). The return value is then always None, and the list is replaced by its flattened version. That can be implemented like this:
def flatten_inplace(lst):
    build = []
    while lst:
        e = lst.pop()
        if isinstance(e, list):
            lst.extend(e)
        else:
            build.append(e)
    # `build` holds the leaves in reverse order; restore them into `lst`.
    lst.extend(build[::-1])