7 Different Ways to Flatten a List of Lists in Python

Learn how to "unlist" (unnest) irregular lists and nested lists of tuples, ints or strings, with or without recursion, or by using itertools.

Ever wondered how you can flatten, or unnest, a 2D list of lists in Python?

In other words, how do you turn 2D lists into 1D ones:

  • [[1, 2], [3, 4], [5, 6, 7], [8]] -> [1, 2, 3, 4, 5, 6, 7, 8]?
  • [[1, 2], [4, 5], [[[7]]], [[[[8]]]]] -> [1, 2, 4, 5, 7, 8]?
  • [1, 2, 3, [4, 5, 6], [[7, 8]]] -> [1, 2, 3, 4, 5, 6, 7, 8]?

What about lists with mixed types, such as lists of strings or lists of tuples?

  • [[1, 2], "three", ["four", "five"]] -> [1, 2, "three", "four", "five"]
  • [[1, 2], (3, 4), (5, 6, 7), [8]] -> [1, 2, 3, 4, 5, 6, 7, 8]

In this post, we’ll see how we can unnest an arbitrarily nested list of lists in 7 different ways. Each method has pros and cons, and varies in performance. By going over each one, you’ll learn how to identify the most appropriate solution for your problem by creating your own flatten() function in Python.

For all examples we'll use Python 3, and pytest for the tests.

By the end of this guide, you'll have learned:

  • how to flatten / unnest a list of mixed types, including lists of strings, lists of tuples or ints
  • the best way to flatten lists of lists with list comprehensions
  • how to unfold a list and remove duplicates
  • how to flatten a nested list of lists using the built-in sum function
  • how to use numpy to flatten nested lists
  • how to use itertools chain to create a flat list
  • the best way to flatten a nested list using recursion or without recursion

Table of Contents

  1. Flattening a list of lists with list comprehensions
  2. How to flatten list of strings, tuples or mixed types
  3. How to flatten a nested list and remove duplicates
  4. Flattening a nested list of lists with the sum function
  5. Flattening using itertools.chain
  6. Flatten a regular list of lists with numpy
  7. Flattening irregular lists

  8. Which method is faster? A performance comparison

  9. Conclusion

Flattening a list of lists with list comprehensions

Let’s imagine that we have a simple list of lists like this [[1, 3], [2, 5], [1]] and we want to flatten it.

In other words, we want to convert the original list into a flat list like this [1, 3, 2, 5, 1]. The first way of doing that is through list/generator comprehensions: we iterate through each sublist, then iterate over the items of that sublist, producing a single element each time.

The following function accepts a list of lists as an argument and returns a generator. The reason for that is to avoid building the whole flattened list in memory. We can then consume the generator to create a single list.

To make sure everything works as expected, we can assert the behavior with the test_flatten unit test.

from typing import List, Any, Iterable


def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a generator comprehension."""
    return (item
            for sublist in lst
            for item in sublist)


def test_flatten():
    lst = [[1, 3], [2, 5], [1]]

    assert list(flatten_gen_comp(lst)) == [1, 3, 2, 5, 1]

This function returns a generator, so to get a list back we need to convert the generator to a list with list().

When we run the test we can see it passing...

============================= test session starts ==============================

flatten.py::test_flatten PASSED                                          [100%]

============================== 1 passed in 0.01s ===============================

Process finished with exit code 0
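Because flatten_gen_comp returns a generator, its result can be consumed only once. The short sketch below reuses the function from above (the variable names gen, first and second are just for illustration) and shows that a second list() call on the same generator comes back empty, while calling the function again builds a fresh generator.

```python
from typing import Any, Iterable, List


def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a generator comprehension."""
    return (item for sublist in lst for item in sublist)


lst = [[1, 3], [2, 5], [1]]
gen = flatten_gen_comp(lst)

first = list(gen)   # consumes every element of the generator
second = list(gen)  # the generator is already exhausted

print(first)   # [1, 3, 2, 5, 1]
print(second)  # []

# Calling the function again creates a new, unconsumed generator.
print(list(flatten_gen_comp(lst)))  # [1, 3, 2, 5, 1]
```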

If you prefer, you can make the code shorter by using a lambda function.

>>> flatten_lambda = lambda lst: (item for sublist in lst for item in sublist)

>>> lst = [[1, 3], [2, 5], [1]]

>>> list(flatten_lambda(lst))
[1, 3, 2, 5, 1]
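The heading mentions list comprehensions, while flatten_gen_comp uses a generator comprehension. As a minimal sketch (flatten_list_comp and flatten_loops are hypothetical names, not from the original post), swapping the parentheses for square brackets gives an eager list comprehension, and writing the same logic as explicit loops shows that the for clauses read in the same order as nested for statements.

```python
from typing import Any, List


def flatten_list_comp(lst: List[List[Any]]) -> List[Any]:
    """Flatten one level with a list comprehension (builds the list eagerly)."""
    return [item for sublist in lst for item in sublist]


def flatten_loops(lst: List[List[Any]]) -> List[Any]:
    """The same logic written as explicit nested loops."""
    result = []
    for sublist in lst:       # the first 'for' clause is the outer loop...
        for item in sublist:  # ...and the second one is the inner loop
            result.append(item)
    return result


lst = [[1, 3], [2, 5], [1]]
print(flatten_list_comp(lst))  # [1, 3, 2, 5, 1]
print(flatten_loops(lst))      # [1, 3, 2, 5, 1]
```

The list comprehension is handy when you need the whole list right away anyway; the generator version remains the lighter option when you only iterate over the result once.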

How to flatten list of strings, tuples or mixed types

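The technique we've seen assumes that every element of the outer list is a list. If an element is some other iterable, such as a string, it gets unfolded as well; if it isn't iterable at all, such as an int, we get a TypeError. Let's see what happens if we plug in a list of lists and strings.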

from typing import List, Any, Iterable


def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a generator comprehension."""
    return (item
            for sublist in lst
            for item in sublist)


>>> lst = [['hello', 'world'], ['my', 'dear'], 'friend']
>>> list(flatten_gen_comp(lst))
['hello', 'world', 'my', 'dear', 'f', 'r', 'i', 'e', 'n', 'd']

Oops, that's not what we want! The reason is that the string 'friend' is itself iterable, so the function unfolds it into single characters.

One way of preventing that is by checking whether the item is a list or not. If it isn't, we yield it as-is instead of iterating over it.

from typing import List, Any, Iterable


def flatten(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a generator function.
    Returns a flattened version of list lst.
    """
    for sublist in lst:
        if isinstance(sublist, list):
            # Unfold lists one level.
            for item in sublist:
                yield item
        else:
            # Anything else (strings, tuples, ints...) is kept as-is.
            yield sublist

>>> lst = [['hello', 'world'], ['my', 'dear'], 'friend']

>>> list(flatten(lst))
['hello', 'world', 'my', 'dear', 'friend']

Since we only unfold items that are lists, any other item, including tuples and other iterables, is kept intact. That makes the function safe for lists of tuples as well:

>>> lst = [[1, 2], 3, (4, 5)]
>>> list(flatten(lst))
[1, 2, 3, (4, 5)]

Lastly, this flatten function works for lists of mixed types.

>>> lst = [[1, 2], 3, (4, 5), ["string"], "hello"]

>>> list(flatten(lst))
[1, 2, 3, (4, 5), 'string', 'hello']
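If tuples should be unfolded as well, as in the [[1, 2], (3, 4), (5, 6, 7), [8]] example from the introduction, one option is to pass a tuple of types to isinstance. The function below is a hypothetical variant, flatten_types, with an assumed types parameter that defaults to (list, tuple); strings are still yielded whole because str is not in that tuple.

```python
from typing import Any, Iterable, List, Tuple, Type


def flatten_types(lst: List[Any],
                  types: Tuple[Type, ...] = (list, tuple)) -> Iterable[Any]:
    """Flatten one level, unfolding only items whose type is in `types`."""
    for sublist in lst:
        # isinstance accepts a tuple of types and matches any of them.
        if isinstance(sublist, types):
            yield from sublist
        else:
            yield sublist


print(list(flatten_types([[1, 2], (3, 4), (5, 6, 7), [8]])))
# [1, 2, 3, 4, 5, 6, 7, 8]
print(list(flatten_types([[1, 2], "three", ["four", "five"]])))
# [1, 2, 'three', 'four', 'five']
```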

How to flatten a list and remove duplicates

To flatten a list of lists and return a list without duplicates, the simplest way is to pass the flattened output to set(). Keep in mind that a set doesn't preserve the original order, and that every item must be hashable.

Another downside is that if the list is big, there'll be a performance penalty, since we need to build the set from the generator and then convert the set back to a list.

from typing import List, Any, Iterable


def flatten(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a generator function.
    Returns a flattened version of list lst.
    """
    for sublist in lst:
        if isinstance(sublist, list):
            # Unfold lists one level.
            for item in sublist:
                yield item
        else:
            # Anything else (strings, tuples, ints...) is kept as-is.
            yield sublist

>>> lst = [[1, 2], 3, (4, 5), ["string"], "hello", 3, 4, "hello"]

>>> list(set(flatten(lst)))
[1, 2, 3, 'hello', 4, (4, 5), 'string']
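If the original order matters, a common alternative to set() is dict.fromkeys: dictionary keys are unique and, since Python 3.7, preserve insertion order. A minimal sketch, reusing the flatten function from above (the name unique is just for illustration):

```python
from typing import Any, Iterable, List


def flatten(lst: List[Any]) -> Iterable[Any]:
    """Flatten one level, unfolding only list items."""
    for sublist in lst:
        if isinstance(sublist, list):
            yield from sublist
        else:
            yield sublist


lst = [[1, 2], 3, (4, 5), ["string"], "hello", 3, 4, "hello"]

# dict keys are unique and keep insertion order, so this drops duplicates
# while keeping the first occurrence of each item.
unique = list(dict.fromkeys(flatten(lst)))
print(unique)  # [1, 2, 3, (4, 5), 'string', 'hello', 4]
```

Like the set version, this still requires every item to be hashable.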

Flattening a list of lists with the sum function

The second strategy is a bit unconventional and, truth be told, very "magical".

Did you know that we can use the built-in function sum to create flattened lists?

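All we need to do is pass the list as the first argument and an empty list as the start value. sum then adds each sublist to that empty list with +, which for lists means concatenation. The following code snippet illustrates that.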

If you’re curious about this approach, I discuss it in more detail in another post.

from typing import List, Any, Iterable


def flatten_sum(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using sum."""
    return sum(lst, [])


def test_flatten():
    lst = [[1, 3], [2, 5], [1]]

    assert list(flatten_sum(lst)) == [1, 3, 2, 5, 1]

And the tests pass too...

flatten.py::test_flatten PASSED

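Even though it seems clever, it's not a good idea to use it in production IMHO. Each + creates a brand-new list and copies every element accumulated so far, so the total work grows quadratically with the number of sublists. As we'll see later, this is the worst way to flatten a list in terms of performance.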
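To see where that cost comes from, here is a rough sketch of what sum(lst, []) does, next to a linear alternative based on list.extend. Both function names, flatten_sum_loop and flatten_extend, are hypothetical and not part of the original benchmark.

```python
from typing import Any, List


def flatten_sum_loop(lst: List[List[Any]]) -> List[Any]:
    """Roughly what sum(lst, []) does: repeated list concatenation."""
    result: List[Any] = []
    for sublist in lst:
        # `result + sublist` allocates a new list and copies every element
        # accumulated so far, so the total work grows quadratically.
        result = result + sublist
    return result


def flatten_extend(lst: List[List[Any]]) -> List[Any]:
    """A linear alternative: grow a single list in place."""
    result: List[Any] = []
    for sublist in lst:
        result.extend(sublist)  # grows `result` in place (amortized linear)
    return result


lst = [[1, 3], [2, 5], [1]]
print(flatten_sum_loop(lst) == sum(lst, []))  # True
print(flatten_extend(lst))                    # [1, 3, 2, 5, 1]
```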

Flattening using itertools.chain

The third alternative is to use the chain function from the itertools module. In a nutshell, chain creates a single iterator from a sequence of other iterables. This function is a perfect match for our use case.

The official docs give a roughly equivalent implementation of chain, which looks like this.

import itertools


def chain(*iterables):
    # chain('ABC', 'DEF') --> A B C D E F
    for it in iterables:
        for element in it:
            yield element

We can then flatten the multi-level lists like so:

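import itertools
from typing import List, Any, Iterable


def flatten_chain(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using chain."""
    # Unpack the sublists so each one is passed to chain as a separate iterable.
    return itertools.chain(*lst)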


def test_flatten():
    lst = [[1, 3], [2, 5], [1]]

    assert list(flatten_chain(lst)) == [1, 3, 2, 5, 1]

And the test passes...

flatten.py::test_flatten PASSED
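The itertools module also offers chain.from_iterable, which takes the outer iterable as a single argument instead of unpacking it with *. A minimal sketch (flatten_chain_from_iterable is a hypothetical name) showing that it gives the same result and also accepts a lazily generated outer iterable:

```python
import itertools
from typing import Any, Iterable


def flatten_chain_from_iterable(lst: Iterable[Iterable[Any]]) -> Iterable[Any]:
    """Flatten one level using itertools.chain.from_iterable."""
    # Unlike chain(*lst), this doesn't unpack `lst` into arguments up front,
    # so the outer iterable can itself be a lazy generator.
    return itertools.chain.from_iterable(lst)


lst = [[1, 3], [2, 5], [1]]
print(list(flatten_chain_from_iterable(lst)))  # [1, 3, 2, 5, 1]

# Works with a lazily produced outer iterable, too.
rows = ([n, n * 10] for n in range(3))
print(list(flatten_chain_from_iterable(rows)))  # [0, 0, 1, 10, 2, 20]
```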

Flatten a regular list of lists with numpy

Another option to create flat lists from nested ones is to use numpy. This library is mostly used to represent and perform operations on multidimensional arrays such as 2D and 3D arrays.

Conveniently, some of its functions also accept plain Python lists, including nested ones. For example, we can use the numpy.concatenate function, which converts each sublist to an array and joins them, to flatten a regular list of lists. Since it returns a NumPy array, we call tolist() to get a plain Python list back.

>>> import numpy as np

>>> lst = [[1, 3], [2, 5], [1], [7, 8]]

>>> np.concatenate(lst).tolist()
[1, 3, 2, 5, 1, 7, 8]

Flattening irregular lists

So far we've been flattening regular lists, but what happens if we have a list like this [[1, 3], [2, 5], 1] or this [1, [2, 3], [[4]], [], [[[[[[[[[5]]]]]]]]]]?

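Unfortunately, none of the preceding approaches handles that well: the comprehension, sum and chain all raise a TypeError on the bare int 1, and the isinstance-based flatten only removes a single level of nesting. In this section, we’ll see two solutions for that, one recursive and the other iterative.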
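To make the problem concrete, the sketch below feeds irregular lists to simplified copies of two earlier approaches. flatten_one_level is a hypothetical stand-in for the isinstance-based flatten, and deep is a shallower version of the second example above. The exact error messages may differ between Python versions, so only the exception type is worth relying on.

```python
import itertools
from typing import Any, Iterable, List


def flatten_chain(lst: List[Any]) -> Iterable[Any]:
    return itertools.chain(*lst)


def flatten_one_level(lst: List[Any]) -> Iterable[Any]:
    # Same idea as the isinstance-based flatten() shown earlier.
    for sublist in lst:
        if isinstance(sublist, list):
            yield from sublist
        else:
            yield sublist


irregular = [[1, 3], [2, 5], 1]

try:
    list(flatten_chain(irregular))
except TypeError as exc:
    print("chain failed:", exc)  # the bare int 1 is not iterable

try:
    sum(irregular, [])
except TypeError as exc:
    print("sum failed:", exc)  # a list can't be concatenated with an int

deep = [1, [2, 3], [[4]], [], [[[5]]]]
print(list(flatten_one_level(deep)))  # [1, 2, 3, [4], [[5]]]
```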

The recursive approach

Solving the flattening problem recursively means iterating over each list element and deciding whether the item is already flat or not. If it is, we yield it; otherwise, we call flatten_recursive on it. But better than words is code, so let’s see some code.

from typing import List, Any, Iterable


def flatten_recursive(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using recursion."""
    for item in lst:
        if isinstance(item, list):
            yield from flatten_recursive(item)
        else:
            yield item

def test_flatten_recursive():
    lst = [[1, 3], [2, 5], 1]

    assert list(flatten_recursive(lst)) == [1, 3, 2, 5, 1]

Bear in mind that the if isinstance(item, list) check means only lists get unfolded; tuples and other iterables are yielded as-is. On the other hand, it will flatten lists of mixed types with no trouble.

>>> lst = [[1, 2], 3, (4, 5), ["string"], "hello"]
>>> list(flatten_recursive(lst))
[1, 2, 3, (4, 5), 'string', 'hello']
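If you also want to unfold tuples, ranges and other iterables, one option is to test against collections.abc.Iterable instead of list. The sketch below, flatten_any (a hypothetical name), assumes that strings and bytes should stay whole; without that exclusion a string would recurse forever, because iterating over a one-character string yields that same string again.

```python
from collections.abc import Iterable
from typing import Any, Iterator


def flatten_any(items: Iterable) -> Iterator[Any]:
    """Recursively flatten any iterable, treating strings and bytes as atoms."""
    for item in items:
        # Exclude str and bytes explicitly so they are yielded whole.
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            yield from flatten_any(item)
        else:
            yield item


lst = [1, [2, (3, 4)], [[5, "six"]], range(7, 9)]
print(list(flatten_any(lst)))  # [1, 2, 3, 4, 5, 'six', 7, 8]
```

Keep in mind that dicts are iterable too, so this version would yield their keys.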

The iterative approach

The iterative approach is no doubt the most complex of all of them.

In this approach, we’ll use a deque to flatten the irregular list. Quoting the official docs:

"Deques are a generalization of stacks and queues (the name is pronounced “deck” and is short for “double-ended queue”). Deques support thread-safe, memory efficient appends and pops from either side of the deque with approximately the same O(1) performance in either direction."

In other words, we can use a deque as an explicit stack that plays the role of the call stack in the recursive solution.

To do that, we’ll start just like we did in the recursive case: we iterate through each element, and if the element is not a list, we append it to the left of the deque. The appendleft method adds the element at the leftmost position of the deque, for example:

>>> from collections import deque

>>> l = deque()

>>> l.appendleft(2)

>>> l
deque([2])

>>> l.appendleft(7)

>>> l
deque([7, 2])

If the element is a list, though, we need to reverse it before passing it to the extendleft method, like this: my_deque.extendleft(reversed(item)). extendleft adds the items one at a time, each at the leftmost position of the deque, so they end up in reverse order. That’s exactly why we need to reverse the sublist before extending left. To make things clearer, let’s see an example.

>>> l = deque()

>>> l.extendleft([1, 2, 3])

>>> l
deque([3, 2, 1])
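This trick matters because the deque may already hold items waiting to be processed. In the small sketch below (the "rest" placeholder is made up), reversing the sublist before extendleft puts its items at the front in their original order, ahead of what was already queued, which is the last-in, first-out behavior a stack needs.

```python
from collections import deque

# Pretend an item is already waiting on the deque.
q = deque(["rest"])

# extendleft pushes items one at a time onto the left end, so feeding it
# the reversed sublist leaves the sublist in its original order, ahead
# of everything that was already waiting.
q.extendleft(reversed([1, 2, 3]))
print(q)  # deque([1, 2, 3, 'rest'])

# popleft() therefore hands back the sublist's items first, in order.
print(q.popleft())  # 1
```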

The final step is to iterate over the deque, removing the leftmost element and yielding it if it’s not a list. If the element is a list, we extend the deque to the left with its (reversed) items. The full implementation goes like this:

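from collections import deque
from typing import List, Any, Iterable


def flatten_deque(lst: List[Any]) -> Iterable[Any]:
    """Flatten a list using a deque."""
    q = deque()
    for item in lst:
        # Push the next top-level item onto the left of the deque,
        # unfolding it (in order) if it's a list.
        if isinstance(item, list):
            q.extendleft(reversed(item))
        else:
            q.appendleft(item)
        # Drain the deque: yield plain items and unfold nested lists in place.
        while q:
            elem = q.popleft()
            if isinstance(elem, list):
                q.extendleft(reversed(elem))
            else:
                yield elem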

def test_flatten_super_irregular():
    lst = [1, [2, 3], [4], [], [[[[[[[[[5]]]]]]]]]]

    assert list(flatten_deque(lst)) == [1, 2, 3, 4, 5]

When we run the test, it happily passes...

============================= test session starts ==============================
...

flatten.py::test_flatten_super_irregular PASSED                          [100%]

============================== 1 passed in 0.01s ===============================

This approach can also flatten lists of mixed types with no trouble.

>>> lst = [[1, 2], 3, (4, 5), ["string"], "hello"]
>>> list(flatten_deque(lst))
[1, 2, 3, (4, 5), 'string', 'hello']
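For comparison, here is a different iterative technique, not the one used in this post: instead of pushing individual elements onto a deque, it keeps a stack of iterators, one per list currently being walked. The name flatten_iter_stack is hypothetical. It doesn't need to push every element of a sublist or reverse anything, and since it never calls itself, the nesting depth isn't bounded by Python's recursion limit.

```python
from typing import Any, Iterator, List


def flatten_iter_stack(lst: List[Any]) -> Iterator[Any]:
    """Flatten an arbitrarily nested list with an explicit stack of iterators."""
    stack: List[Iterator[Any]] = [iter(lst)]
    while stack:
        # Keep pulling items from the innermost list we're walking through.
        for item in stack[-1]:
            if isinstance(item, list):
                # Descend: remember our position by pushing a new iterator.
                stack.append(iter(item))
                break
            yield item
        else:
            # The innermost iterator is exhausted; go back up one level.
            stack.pop()


lst = [1, [2, 3], [4], [], [[[[[[[[[5]]]]]]]]]]
print(list(flatten_iter_stack(lst)))  # [1, 2, 3, 4, 5]
```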

Which method is faster? A performance comparison

As we’ve seen in the previous section, if our multi-level list is irregular, we have little choice. But assuming that this is not a frequent use case, how do these approaches compare in terms of performance?

In this last part, we’ll run a benchmark and compare which solutions perform best.

To do that, we can use the timeit module, which IPython exposes as the %timeit magic: %timeit [code_to_measure]. Below are the timings for each function. As we can see, flatten_chain is the fastest implementation of all: it flattened our list lst in 267 µs on average. The slowest implementation is flatten_sum, taking around 42 ms to flatten the same list.

PS: Special thanks to @hynekcer, who pointed out a bug in this benchmark. Since most of the functions return a generator, which does no work until it's consumed, we need to consume all elements to get a fair assessment. We can either iterate over the generator or create a list out of it.

Flatten generator comprehension

In [1]: lst = [[1, 2, 3], [4, 5, 6], [7], [8, 9]] * 1_000

In [2]: %timeit list(flatten_gen_comp(lst))
615 µs ± 2.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Flatten sum

In [3]: %timeit list(flatten_sum(lst))
42 ms ± 660 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Flatten chain

In [4]: %timeit list(flatten_chain(lst))
267 µs ± 517 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Flatten numpy

In [5]: %timeit list(flatten_numpy(lst))
4.65 ms ± 14.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Flatten recursive

In [6]: %timeit list(flatten_recursive(lst))
3.02 ms ± 174 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Flatten deque

In [7]: %timeit list(flatten_deque(lst))
2.97 ms ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
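If you're not working in IPython, similar measurements can be taken with the timeit module directly. The script below is a sketch that re-declares three of the functions from this post and uses timeit.repeat; the benchmark helper, the number and repeat values, and reporting the best run per call are assumptions rather than the setup behind the numbers above, and the printed timings will depend on your machine.

```python
import itertools
import timeit
from typing import Any, Callable, Dict, Iterable, List


def flatten_gen_comp(lst: List[Any]) -> Iterable[Any]:
    return (item for sublist in lst for item in sublist)


def flatten_sum(lst: List[Any]) -> List[Any]:
    return sum(lst, [])


def flatten_chain(lst: List[Any]) -> Iterable[Any]:
    return itertools.chain(*lst)


def benchmark(
    funcs: Dict[str, Callable[[List[Any]], Iterable[Any]]],
    lst: List[Any],
    number: int = 10,
    repeat: int = 3,
) -> Dict[str, float]:
    """Return the best observed time per call, in seconds, for each function."""
    results: Dict[str, float] = {}
    for name, func in funcs.items():
        # list() forces lazy generators to produce every element,
        # as the PS above recommends.
        timings = timeit.repeat(lambda: list(func(lst)), number=number, repeat=repeat)
        # Each timing is the total for `number` calls, so divide to get per call.
        results[name] = min(timings) / number
    return results


if __name__ == "__main__":
    lst = [[1, 2, 3], [4, 5, 6], [7], [8, 9]] * 1_000
    funcs = {
        "gen_comp": flatten_gen_comp,
        "sum": flatten_sum,
        "chain": flatten_chain,
    }
    for name, seconds in benchmark(funcs, lst).items():
        print(f"{name:>10}: {seconds * 1e6:10.1f} usec per call")
```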

Conclusion

We can flatten a multi-level list in several different ways with Python, and each approach has its pros and cons. In this post we looked at 7 different ways to create 1D lists from nested lists, covering:

  • regular nested lists
  • irregular lists
  • lists of mixed types
  • lists of strings
  • lists of tuples or ints
  • recursive and iterative approaches
  • using the itertools module
  • removing duplicates

See you next time!