5 Different Ways to Flatten a List of Lists in Python

In this post, we’ll see how to flatten a list in 5 different ways. Each method has pros and cons, and they differ in performance. By the end of this tutorial, I hope you’ll be able to identify the most appropriate solution for your problem. The Python version used in this tutorial is 3.8, and I used pytest for the tests.

Flattening a List With List or Generator Comprehensions

Let’s imagine that we have a simple list of lists like this [[1, 3], [2, 5], [1]] and we want to flatten it. In other words, we want to turn it into [1, 3, 2, 5, 1]. The first way of doing that is through list/generator comprehensions. For those already familiar with them, this may sound very straightforward. What we need to do is iterate over each sub-list and then over each item inside it, producing a single element each time.

The following function accepts a list as an argument and returns a generator. The reason for that is to avoid building a whole list in memory. With generators we can produce the elements on demand.

To make sure everything works as expected, we can assert the behavior with the test_flatten unit test.

from typing import List, Any, Iterable


def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a generator comprehension."""
    return (item
            for sublist in lst
            for item in sublist)


def test_flatten():
    l = [[1, 3], [2, 5], [1]]

    assert list(flatten_gen_comp(l)) == [1, 3, 2, 5, 1]

When we run the test we can see it passing...

============================= test session starts ==============================

flatten.py::test_flatten PASSED                                          [100%]

============================== 1 passed in 0.01s ===============================

Process finished with exit code 0
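
If the caller needs a list right away, the same nested loop can also be written as a list comprehension. This is only a minimal sketch; the name flatten_list_comp is made up for this example and is not one of the functions benchmarked later.

```python
from typing import Any, List


def flatten_list_comp(lst: List[List[Any]]) -> List[Any]:
    """Flatten one level of nesting, building the result list eagerly."""
    # The two `for` clauses read left to right, like nested for loops.
    return [item for sublist in lst for item in sublist]


print(flatten_list_comp([[1, 3], [2, 5], [1]]))  # [1, 3, 2, 5, 1]
```

The trade-off is memory: the list version holds every element at once, while the generator version produces them on demand.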

Flattening a List With sum

The second strategy is a bit unconventional and, truth be told, very "magical". It turns out we can use the built-in function sum to flatten a list. Doing that is simple and concise: we pass the list as the first argument, along with an empty list as the start value. The following function illustrates that.

If you’re curious about this approach, I discuss it in more detail in the 5 Hidden Python Features You Probably Never Heard Of post.

from typing import List, Any, Iterable


def flatten_sum(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using sum."""
    return sum(lst, [])


def test_flatten():
    l = [[1, 3], [2, 5], [1]]

    assert list(flatten_sum(l)) == [1, 3, 2, 5, 1]

And the tests pass too...

flatten.py::test_flatten PASSED

Even though it seems clever, it's not a good idea to use it in production IMHO. As we'll see later, this is the worst way to flatten a list in terms of performance.
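
To see where the cost comes from, it may help to spell out what sum does with lists. The sketch below shows that sum(lst, []) is a chain of list concatenations. The reduce version is added here only for illustration. Each + builds a brand-new list and copies everything accumulated so far, so the work grows roughly quadratically with the number of sub-lists.

```python
from functools import reduce
from operator import add

lst = [[1, 3], [2, 5], [1]]

# sum(lst, []) starts from the empty list and adds each sub-list in turn:
# (([] + [1, 3]) + [2, 5]) + [1]
by_sum = sum(lst, [])

# The same fold written with reduce, to make the repeated `+` visible.
by_reduce = reduce(add, lst, [])

print(by_sum)     # [1, 3, 2, 5, 1]
print(by_reduce)  # [1, 3, 2, 5, 1]
```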

Flattening using itertools.chain

The third alternative is to use the chain function from the itertools module. In a nutshell, chain creates a single iterator from a sequence of other iterables. This function is a perfect match for our use case.

An equivalent implementation of chain, taken from the official docs, looks like this.

import itertools


def chain(*iterables):
    # chain('ABC', 'DEF') --> A B C D E F
    for it in iterables:
        for element in it:
            yield element

We can then flatten the list like so:

import itertools
from typing import List, Any, Iterable


def flatten_chain(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using chain."""
    return itertools.chain(*lst)


def test_flatten():
    l = [[1, 3], [2, 5], [1]]

    assert list(flatten_chain(l)) == [1, 3, 2, 5, 1]

And the test passes...

[OMITTED]
....
flatten.py::test_flatten PASSED   
...
[OMITTED]
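
As a side note, itertools also provides chain.from_iterable, which takes the outer iterable directly instead of unpacking it with *lst. The sketch below is an alternative spelling, not something measured in this post, and the function name flatten_chain_from_iterable is hypothetical.

```python
import itertools
from typing import Any, Iterable, List


def flatten_chain_from_iterable(lst: List[List[Any]]) -> Iterable[Any]:
    """Flatten one level without unpacking the outer list into arguments."""
    return itertools.chain.from_iterable(lst)


print(list(flatten_chain_from_iterable([[1, 3], [2, 5], [1]])))  # [1, 3, 2, 5, 1]
```

Because it never unpacks the outer list into arguments, this form also accepts a lazy outer iterable, such as a generator of lists.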

Flattening Irregular Lists

What happens if we have a list like this [[1, 3], [2, 5], 1] or this [1, [2, 3], [[4]], [], [[[[[[[[[5]]]]]]]]]]? Unfortunately, none of the preceding approaches can handle them: each one fails with a TypeError as soon as it reaches an element that is not a list. In this section, we’ll see two solutions for that, one recursive and the other iterative.
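
The failure is easy to reproduce. The sketch below uses a small helper, raises_type_error, written only for this demonstration, and re-creates the three earlier approaches inline as lambdas.

```python
import itertools
from typing import Any, Callable, List


def raises_type_error(flatten: Callable[[List[Any]], Any], lst: List[Any]) -> bool:
    """Return True if fully consuming flatten(lst) raises TypeError."""
    try:
        list(flatten(lst))
    except TypeError:
        return True
    return False


irregular = [[1, 3], [2, 5], 1]

# sum ends up evaluating [1, 3, 2, 5] + 1, which lists do not support.
print(raises_type_error(lambda l: sum(l, []), irregular))  # True
# chain tries to iterate over the bare integer 1.
print(raises_type_error(lambda l: itertools.chain(*l), irregular))  # True
# The generator comprehension hits the same problem in its inner loop.
print(raises_type_error(lambda l: (i for s in l for i in s), irregular))  # True
```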

The Recursive Approach

Solving the flattening problem recursively means iterating over each item and deciding whether the element is already flat or not. If so, we yield it; otherwise, we call flatten_recursive on it. But better than words is code, so let’s see some code.

from typing import List, Any, Iterable


def flatten_recursive(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using recursion."""
    for item in lst:
        if isinstance(item, list):
            yield from flatten_recursive(item)
        else:
            yield item

def test_flatten_recursive():
    l = [[1, 3], [2, 5], 1]

    assert list(flatten_recursive(l)) == [1, 3, 2, 5, 1]

Bear in mind that the isinstance(item, list) check means the function only recurses into lists. Tuples, sets, and other iterables are yielded as-is.
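
If other containers need flattening too, one possible variation is to check against the Iterable abstract base class instead of list. This is only a sketch with a hypothetical name, flatten_any_iterable. It assumes strings and bytes should be treated as atoms rather than split into characters.

```python
from collections.abc import Iterable as IterableABC
from typing import Any, Iterable


def flatten_any_iterable(items: Iterable[Any]) -> Iterable[Any]:
    """Recursively flatten nested iterables, leaving str and bytes whole."""
    for item in items:
        # A string is iterable and each of its characters is again a string,
        # so recursing into it would only stop at a RecursionError.
        if isinstance(item, IterableABC) and not isinstance(item, (str, bytes)):
            yield from flatten_any_iterable(item)
        else:
            yield item


print(list(flatten_any_iterable([1, (2, 3), [[4]], "ab"])))  # [1, 2, 3, 4, 'ab']
```

Note that with this check a dict would be flattened into its keys, which may or may not be what you want.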

The Iterative Approach

The iterative approach is no doubt the most complex of all of them. In this approach, we’ll use a deque to flatten the irregular list. Quoting the official docs, “Deques are a generalization of stacks and queues (the name is pronounced “deck” and is short for “double-ended queue”). Deques support thread-safe, memory efficient appends and pops from either side of the deque with approximately the same O(1) performance in either direction.”

In other words, we can use deque to simulate the stacking operation of the recursive solution.

To do that, we’ll start just like we did in the recursive case: we’ll iterate through each element and, if the element is not a list, append it to the left of the deque. The appendleft method appends the element to the leftmost position of the deque, for example:

>>> from collections import deque

>>> l = deque()

>>> l.appendleft(2)

>>> l
deque([2])

>>> l.appendleft(7)

>>> l
deque([7, 2])

If the element is a list, though, we need to reverse it before passing it to the extendleft method, like this: my_deque.extendleft(reversed(item)). Much like list.extend, extendleft adds the items one at a time, but it puts each one at the leftmost position of the deque. As a result, the elements end up in reverse order. That’s exactly why we need to reverse the sub-list before extending left. To make things clearer, let’s see an example.

>>> l = deque()

>>> l.extendleft([1, 2, 3])

>>> l
deque([3, 2, 1])
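
For comparison, here is a small sketch of the same call with the sub-list reversed first. The starting value 9 is arbitrary and only there to show where the new items land.

```python
from collections import deque

q = deque([9])

# Reversing first means the repeated left-insertions restore the original
# order: 3 goes in first, then 2, and 1 ends up in the leftmost position.
q.extendleft(reversed([1, 2, 3]))

print(q)  # deque([1, 2, 3, 9])
```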

The final step is to iterate over the deque, removing the leftmost element and yielding it if it’s not a list. If the element is a list, then we need to extend it left. The full implementation goes like this:

from collections import deque
from typing import List, Any, Iterable


def flatten_deque(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a deque."""
    q = deque()
    for item in lst:
        if isinstance(item, list):
            q.extendleft(reversed(item))
        else:
            q.appendleft(item)
        while q:
            elem = q.popleft()
            if isinstance(elem, list):
                q.extendleft(reversed(elem))
            else:
                yield elem

def test_flatten_super_irregular():
    l = [1, [2, 3], [4], [], [[[[[[[[[5]]]]]]]]]]

    assert list(flatten_deque(l)) == [1, 2, 3, 4, 5]

When we run the test, it happily passes...

============================= test session starts ==============================
...

flatten.py::test_flatten_super_irregular PASSED                          [100%]

============================== 1 passed in 0.01s ===============================

Comparing Performance

As we’ve seen in the previous section, if our list is irregular we have little choice. But assuming that this is not a frequent use case, how do these approaches compare in terms of performance?

In this last part, we’ll run a benchmark and compare which solutions perform best.

To do that, we can use the timeit module, which we can invoke in IPython with %timeit [code_to_measure]. The timings for each function are listed below. As we can see, flatten_chain is the fastest implementation of all: it flattened our list lst in 217 µs on average. The slowest implementation is flatten_sum, which took around 30 ms to flatten the same list.

PS: A special thanks to @hynekcer who pointed out a bug in this benchmark. Since most of the functions return a generator, we need to consume all elements in order to get a better assessment. We can either iterate over the generator or create a list out of it.
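
Outside IPython, the same pitfall can be reproduced with the plain timeit module. The sketch below re-declares flatten_gen_comp so it runs on its own. The number=100 repetition count is an arbitrary choice, and the absolute timings will vary by machine.

```python
import timeit
from typing import Any, Iterable, List


def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a generator comprehension."""
    return (item for sublist in lst for item in sublist)


lst = [[1, 2, 3], [4, 5, 6], [7], [8, 9]] * 1_000

# Only creates the generator object; no element is ever produced.
lazy = timeit.timeit(lambda: flatten_gen_comp(lst), number=100)

# list(...) consumes the generator, so the flattening work is measured.
consumed = timeit.timeit(lambda: list(flatten_gen_comp(lst)), number=100)

print(f"not consumed: {lazy:.6f}s, consumed: {consumed:.6f}s")
```

The first number stays tiny no matter how long lst is, which is the sign that nothing was actually flattened.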

Flatten Generator Comprehension

In [1]: lst = [[1, 2, 3], [4, 5, 6], [7], [8, 9]] * 1_000

In [2]: %timeit list(flatten_gen_comp(lst))
433 µs ± 5.58 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Flatten Sum

In [3]: %timeit list(flatten_sum(lst))
30 ms ± 343 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Flatten Chain

In [4]: %timeit list(flatten_chain(lst))
217 µs ± 31.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Flatten Recursive

In [5]: %timeit list(flatten_recursive(lst))
1.86 ms ± 20.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Flatten Deque

In [6]: %timeit list(flatten_deque(lst))
2.04 ms ± 78.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Conclusion

We can flatten a list in several different ways with Python. Each approach has its pros and cons. In this post we looked at 5 different ways to flatten a list, including irregular lists.

If you liked this post, consider sharing it with your friends or buying me a coffee:)

See you next time!

Hynek Černoch

All functions measured with a line like %timeit some_function(l) are measured incorrectly. A time of a few hundred nanoseconds to get a list of 9 * 1000 numbers makes no sense. You will see that the result doesn't depend on the number of items.

This is because any code after a yield statement is postponed until you enumerate the output values. If you put a slow sleep call after a yield line, you will see no difference because it never runs. Even a generator expression such as %timeit (x for x in range(1000000000)) is about as fast as the same expression with range(1).

A correct measurement is by %timeit list(some_function(l)) or %timeit for _ in some_function(l): pass

The fastest code for flattening is flatten_chain(l) because the expression itertools.chain(*lst) uses only a library function call and no Python loops.

Miguel Brito

You're absolutely correct. I missed that. So much so that for the tests I always convert it to list.

I'll edit the article with the correct timings.

Thanks for pointing it out!

Edit: Fixed it!

HEE JAE CHOI

Great article!!

By the way, is this a typo?

def flatten_recursive(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using recursion."""
    for item in lst:
        if isinstance(item, list):
            yield from flatten(item)  # This line! I think it will be `yield from flatten_recursive(item)`
        else:
            yield item

Miguel Brito

Thanks HEE JAE CHOI, and yes, this is a typo!

Thanks for flagging it, I've just fixed it. I renamed it before publishing to make it easier to understand and forgot to double check it in IPython.

This is the correct version:

def flatten_recursive(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using recursion."""
    for item in lst:
        if isinstance(item, list):
            yield from flatten_recursive(item)
        else:
            yield item

Ruud van der Ham

Here's a different, non recursive, solution for deep flattening:

def flatten(bad):
    good = []
    while bad:
        e = bad.pop()
        if isinstance(e, list):
            bad.extend(e)
        else:
            good.append(e)
    return good[::-1]

Very likely quite fast as well.

Ruud van der Ham

Hynek Černoch Good point. It is true that the original list will be destroyed. But if that doesn't matter, a deep copy is not required.

Ruud van der Ham

Hynek Černoch Maybe it would be better to use this quite different function as an in-place flatten, very much like list.sort() or list.remove(). The return value is then always None, and the contents of the list are replaced by the flattened version. That can be implemented like this:

def flatten_inplace(lst):
    build = []
    while lst:
        e = lst.pop()
        if isinstance(e, list):
            lst.extend(e)
        else:
            build.append(e)
    lst.extend(build[::-1])