How to Flatten a List of Lists in Python

by Leodanis Pozo Ramos, Jun 26, 2023 (intermediate, data-science)

Sometimes, when you’re working with data, you may have the data as a list of nested lists. A common operation is to flatten this data into a one-dimensional list in Python. Flattening a list involves converting a multidimensional list, such as a matrix, into a one-dimensional list.

To better illustrate what it means to flatten a list, say that you have the following matrix of numeric values:

Python
>>> matrix = [
...     [9, 3, 8, 3],
...     [4, 5, 2, 8],
...     [6, 4, 3, 1],
...     [1, 0, 4, 5],
... ]

The matrix variable holds a Python list that contains four nested lists. Each nested list represents a row in the matrix. The rows store four items or numbers each. Now say that you want to turn this matrix into the following list:

Python
[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

How do you manage to flatten your matrix and get a one-dimensional list like the one above? In this tutorial, you’ll learn how to do that in Python.

How to Flatten a List of Lists With a for Loop

How can you flatten a list of lists in Python? In general, to flatten a list of lists, you can run the following steps either explicitly or implicitly:

  1. Create a new empty list to store the flattened data.
  2. Iterate over each nested list or sublist in the original list.
  3. Add every item from the current sublist to the list of flattened data.
  4. Return the resulting list with the flattened data.

You can follow several paths and use multiple tools to run these steps in Python. Arguably, the most natural and readable way to do this is to use a for loop, which allows you to explicitly iterate over the sublists.

Then you need a way to add items to the new flattened list. For that, you have a couple of valid options. First, you’ll turn to the .extend() method from the list class itself, and then you’ll give the augmented concatenation operator (+=) a go.

To continue with the matrix example, here’s how you would translate these steps into Python code using a for loop and the .extend() method:

Python
>>> def flatten_extend(matrix):
...     flat_list = []
...     for row in matrix:
...         flat_list.extend(row)
...     return flat_list
...

Inside flatten_extend(), you first create a new empty list called flat_list. You’ll use this list to store the flattened data when you extract it from matrix. Then you start a loop to iterate over the inner, or nested, lists from matrix. In this example, you use the name row to represent the current nested list.

In every iteration, you use .extend() to add the content of the current sublist to flat_list. This method takes an iterable as an argument and appends its items to the end of the target list.
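As a quick illustration of that last point, the sketch below calls .extend() with a few different kinds of iterables and contrasts it with .append(). The sample values are made up for the demonstration:

```python
# .extend() accepts any iterable, not only lists.
flat_list = []

flat_list.extend([9, 3])    # A list
flat_list.extend((8, 3))    # A tuple
flat_list.extend(range(2))  # A range object that yields 0 and 1

print(flat_list)  # [9, 3, 8, 3, 0, 1]

# By contrast, .append() adds its argument as a single item,
# so the row would stay nested.
nested = []
nested.append([4, 5])
print(nested)  # [[4, 5]]
```

This difference is why .extend(), rather than .append(), is the right method for the flattening loop.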

Now go ahead and run the following code to check that your function does the job:

Python
>>> flatten_extend(matrix)
[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

That’s neat! You’ve flattened your first list of lists. As a result, you have a one-dimensional list containing all the numeric values from matrix.

With .extend(), you’ve come up with a Pythonic and readable way to flatten your lists. You can get the same result using the augmented concatenation operator (+=) on your flat_list object. However, this alternative approach may not be as readable:

Python
>>> def flatten_concatenation(matrix):
...     flat_list = []
...     for row in matrix:
...         flat_list += row
...     return flat_list
...

This function is similar to flatten_extend(). The only difference is that you’ve replaced the call to .extend() with an augmented concatenation. Concatenations like this allow you to append a list of items to the end of an existing list.

Go ahead and call this function with matrix as an argument:

Python
>>> flatten_concatenation(matrix)
[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

This call to flatten_concatenation() returns the same result as your previous flatten_extend(). Both functions are equivalent and interchangeable. Their goal is to flatten a list of lists. So, it’s up to you to decide which one to use in your code. However, readability-wise, flatten_extend() seems to be a better solution.
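One detail worth keeping in mind is that += on a list mutates the existing object, while the plain + operator builds a new list. The following sketch, which uses made-up sample data and a hypothetical alias name, checks this by comparing object identities:

```python
flat_list = [9, 3]
alias = flat_list  # A second name bound to the same list object

flat_list += [8, 3]  # In place: the original object grows
print(alias)               # [9, 3, 8, 3]
print(alias is flat_list)  # True

flat_list = flat_list + [4, 5]  # Plain +: a brand-new list is created
print(alias)               # [9, 3, 8, 3]
print(alias is flat_list)  # False
```

Inside flatten_concatenation(), the in-place behavior means that the loop keeps growing a single list instead of creating a new one in every iteration.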

Now that you know how to tackle the problem of flattening a list of lists, you’re ready to continue exploring other approaches to the same task. This way, you’ll be better prepared to choose the right tool for the job in every concrete situation.

After finishing your exploration, you’ll run a performance test to compare the execution time of the different approaches to flattening a list of lists. This test will provide relevant data for those use cases in which the code’s performance is critical for you.

To kick things off, you’ll start by using a list comprehension, which is a popular and Pythonic list-transformation tool.

Using a Comprehension to Flatten a List of Lists

List comprehensions are a distinctive feature of Python. They’re quite popular in the Python community, so you’ll likely find them in many codebases. List comprehensions allow you to quickly create and transform lists using a syntax that mimics a for loop but only requires a single line of code.

The core syntax of a list comprehension looks something like this:

Python
[expression(item) for item in iterable]

Every list comprehension needs at least three components:

  1. expression() is a Python expression that returns a concrete value, and most of the time, that value depends on item.
  2. item is the current object from iterable.
  3. iterable can be any Python iterable object, such as a list, tuple, set, string, or generator.

The for construct iterates over the items in iterable, while expression(item) provides the corresponding item for the new list that results from running the comprehension. Note that comprehensions can also have nested for clauses and conditional (if) clauses. In this tutorial, you’ll use nested for clauses.

You can use a list comprehension when you need to flatten a list of lists. The function below shows how:

Python
>>> def flatten_comprehension(matrix):
...     return [item for row in matrix for item in row]
...

This list comprehension has two nested for clauses. The first one iterates over the rows in matrix, which is your list of lists. The second for clause iterates over the items in each row. In this case, the expression is relatively straightforward because you only need to extract the items from each sublist.
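If the order of the for clauses feels confusing, it may help to read them from left to right as nested loops. The sketch below spells out the same logic explicitly. The name flatten_nested_loops() is hypothetical and only serves this comparison:

```python
def flatten_nested_loops(matrix):
    """Spell out the comprehension as explicit nested loops."""
    flat_list = []
    for row in matrix:  # First for clause
        for item in row:  # Second for clause
            flat_list.append(item)  # The expression: item itself
    return flat_list


matrix = [[9, 3, 8, 3], [4, 5, 2, 8]]
print(flatten_nested_loops(matrix))  # [9, 3, 8, 3, 4, 5, 2, 8]
```

The outer loop comes first in the comprehension, just as it does in the nested loops.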

Here’s how this function works in practice:

Python
>>> matrix = [
...     [9, 3, 8, 3],
...     [4, 5, 2, 8],
...     [6, 4, 3, 1],
...     [1, 0, 4, 5],
... ]

>>> flatten_comprehension(matrix)
[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

The call to flatten_comprehension() processes the content of matrix, flattens it, and returns a new one-dimensional list containing the original data.

Comprehensions are quite popular in Python. They allow you to create new lists out of existing iterables. They’re like concise for loops that can help you quickly transform your data and get a new list as a result.
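Because the expression and an optional condition live in the same construct, a comprehension can flatten, filter, and transform in a single pass. The sketch below is only an illustration: the function name flatten_even_squares() and the choice of filter are made up for this example.

```python
def flatten_even_squares(matrix):
    """Flatten matrix, keep the even items only, and square them."""
    return [item**2 for row in matrix for item in row if item % 2 == 0]


matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
]
print(flatten_even_squares(matrix))  # [64, 16, 4, 64]
```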

However, comprehensions aren’t the only alternative tool that you can use to flatten a list of lists. The Python standard library hosts a few other tools that can help you with that task. You can even find useful options in the built-in tool kit. Are you ready to learn about them?

Flattening a List Using Standard-Library and Built-in Tools

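You’ll also find a few standard-library and built-in tools that you can use to flatten a list of lists in Python. For example, you can use any of the following tools:

  • The chain() function from the itertools module
  • The reduce() function from the functools module
  • The built-in sum() function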

In the following sections, you’ll learn how these tools can assist you when you need to flatten a list of lists in your code.

Chaining Iterables With itertools.chain()

As its name suggests, the chain() function chains multiple iterables into a single one. However, instead of giving you a list, chain() returns an iterator that yields items from all the input iterables until they get exhausted.

You can take advantage of chain() along with list() to flatten a list of lists. Here’s how:

Python
>>> from itertools import chain

>>> def flatten_chain(matrix):
...     return list(chain.from_iterable(matrix))
...

>>> matrix = [
...     [9, 3, 8, 3],
...     [4, 5, 2, 8],
...     [6, 4, 3, 1],
...     [1, 0, 4, 5],
... ]

>>> flatten_chain(matrix)
[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

In this example, you first import chain() from the itertools module. Instead of being a regular function, chain is implemented as a class. That’s why chain has .from_iterable() as a class method. This method provides an alternative constructor that you can use to build a chain from an iterable of iterables, so you can feed it with a list of lists.

The final step is to build a list out of the iterator that .from_iterable() returns. To do that, you call list(), which consumes the iterator and stores its data in a new list. Now if you call flatten_chain() with matrix as an argument, you get a list of flattened data, as expected.

Something to highlight in this solution is that it’s a readable approach to the problem. You can read this almost as plain English: chain the rows of matrix into a single iterable and then convert it into a list.
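For comparison, chain() itself takes the iterables as separate positional arguments, so calling it directly means unpacking the list of lists with the * operator. A minimal sketch of that variant follows. The name flatten_chain_unpack() is hypothetical:

```python
from itertools import chain


def flatten_chain_unpack(matrix):
    """Flatten by unpacking the rows into chain() as arguments."""
    return list(chain(*matrix))


matrix = [[9, 3, 8, 3], [4, 5, 2, 8]]

# Both spellings yield the same items in the same order.
print(flatten_chain_unpack(matrix))       # [9, 3, 8, 3, 4, 5, 2, 8]
print(list(chain.from_iterable(matrix)))  # [9, 3, 8, 3, 4, 5, 2, 8]
```

The unpacking variant has to read the whole outer iterable up front to build the argument list, while .from_iterable() walks it lazily, which is one reason to prefer the latter.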

Concatenating Lists With functools.reduce()

The reduce() function from the functools module is another tool that you can use to flatten lists of lists. This function is part of Python’s functional programming tool kit. It takes a two-argument function that must return a single value and applies that function to the items in an iterable.

To do its job, reduce() takes a pair of items and computes a partial result. Then it uses that result and the next item to compute the next partial result. This process creates an implicit accumulator that stores the cumulative value in every step.
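To make that implicit accumulator visible, you can turn to itertools.accumulate(), which applies a two-argument function in the same stepwise way but yields every partial result instead of only the last one. The sketch below assumes Python 3.8 or later for the initial argument and uses made-up sample data:

```python
from itertools import accumulate

matrix = [[9, 3], [8, 3], [4, 5]]

# accumulate() applies the same two-argument function as reduce()
# would, but yields every partial result along the way.
partial_results = list(accumulate(matrix, lambda x, y: x + y, initial=[]))

for step in partial_results:
    print(step)
# []
# [9, 3]
# [9, 3, 8, 3]
# [9, 3, 8, 3, 4, 5]
```

The last partial result is the value that reduce() would return for the same function, data, and initial value.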

You can use different function-like objects with reduce() to flatten a list of lists. In the example below, you use a custom lambda function:

Python
>>> from functools import reduce

>>> def flatten_reduce_lambda(matrix):
...     return list(reduce(lambda x, y: x + y, matrix, []))
...

>>> matrix = [
...     [9, 3, 8, 3],
...     [4, 5, 2, 8],
...     [6, 4, 3, 1],
...     [1, 0, 4, 5],
... ]

>>> flatten_reduce_lambda(matrix)
[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

The first argument to reduce() is the lambda function, which takes two arguments, x and y, and returns their sum. Then you pass matrix as the second argument to reduce(). Finally, you use an empty list as the third argument. This argument holds an initial value to start the computation with.

Note that in this example, reduce() internally runs a list-concatenation process. It concatenates the initial empty list with the rows in matrix. The final result is your desired flattened list.

Apart from choosing a custom lambda function, you can also use a few other functions from the standard library to get the same result as in the example above. For example, you can use the following functions from the operator module:

  • add() sums two numbers together. It’s equivalent to the addition operator (+).
  • concat() concatenates two values together. It’s equivalent to the concatenation operator on lists (+).
  • iconcat() concatenates two values together in place. It’s equivalent to the augmented concatenation operator (+=).

To flatten a list of lists by using one of these functions as an argument to reduce(), you just need to replace the lambda function in the above example with the desired function. Go ahead and give it a try!
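As a starting point, here’s a sketch of how that replacement might look with concat() and iconcat(). It also shows one caveat with iconcat(): if you leave out the initial empty list, then the first row becomes the accumulator and gets mutated. The careless name and the sample data are made up for this example:

```python
from functools import reduce
from operator import concat, iconcat

matrix = [[9, 3], [8, 3], [4, 5]]

# concat() builds a new list at every step, like x + y.
print(reduce(concat, matrix, []))  # [9, 3, 8, 3, 4, 5]

# iconcat() extends the accumulator in place, like x += y.
# The fresh [] passed as the initial value is the object that grows.
print(reduce(iconcat, matrix, []))  # [9, 3, 8, 3, 4, 5]
print(matrix)  # [[9, 3], [8, 3], [4, 5]]

# Without the initial value, the first row becomes the accumulator,
# so iconcat() mutates the input data.
careless = [[9, 3], [8, 3], [4, 5]]
reduce(iconcat, careless)
print(careless)  # [[9, 3, 8, 3, 4, 5], [8, 3], [4, 5]]
```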

Using sum() to Concatenate Lists

The built-in sum() function is another tool that you can use to flatten a list of lists. This use case of Python’s sum() may seem weird to you at first glance. Some developers may say that it’s not readable or obvious, but it works, and you may find this approach in other people’s code:

Python
>>> def flatten_sum(matrix):
...     return sum(matrix, [])
...

>>> matrix = [
...     [9, 3, 8, 3],
...     [4, 5, 2, 8],
...     [6, 4, 3, 1],
...     [1, 0, 4, 5],
... ]

>>> flatten_sum(matrix)
[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

In this example, you use sum() to concatenate the sublists in matrix. Note that for this to work, you must provide an empty list as the second argument to sum(). This argument sets an initial value to start the concatenation.

Even if the built-in sum() function may not be optimized for this type of operation, it provides a quick, one-liner solution that doesn’t require you to import anything or run an explicit loop to get it to work. So, it can save you some thinking and coding time.
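To see why the second argument is required, consider what happens without it. The start value of sum() defaults to 0, and adding an integer to a list isn’t a valid operation. The following sketch, with made-up sample data and a hypothetical failed flag, demonstrates both cases:

```python
matrix = [[9, 3], [8, 3], [4, 5]]

# With the default start value of 0, the first step is 0 + [9, 3],
# which raises a TypeError.
try:
    sum(matrix)
except TypeError as error:
    failed = True
    print(f"TypeError: {error}")
else:
    failed = False

# With an empty list as the start value, every step is list + list.
print(sum(matrix, []))  # [9, 3, 8, 3, 4, 5]
```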

Considering Performance While Flattening Your Lists

Up to this point, you’ve learned about several tools and techniques that you can use to flatten a list of lists in Python. An important aspect of any data-processing algorithm is its efficiency in terms of execution time. This is especially true in Python because its code generally runs slower than equivalent code in languages like C++ or Java.

In this section, you’ll write a small script to test how quickly the different approaches can flatten lists of lists. To kick things off, create a new flatten.py file and put in it all the flattening functions that you’ve coded so far. It should look something like this:

Python
# flatten.py

from functools import reduce
from itertools import chain
from operator import add, concat, iconcat

def flatten_extend(matrix):
    flat_list = []
    for row in matrix:
        flat_list.extend(row)
    return flat_list

def flatten_concatenation(matrix):
    flat_list = []
    for row in matrix:
        flat_list += row
    return flat_list

def flatten_comprehension(matrix):
    return [item for row in matrix for item in row]

def flatten_chain(matrix):
    return list(chain.from_iterable(matrix))

def flatten_reduce_lambda(matrix):
    return list(reduce(lambda x, y: x + y, matrix, []))

def flatten_reduce_add(matrix):
    return reduce(add, matrix, [])

def flatten_reduce_concat(matrix):
    return reduce(concat, matrix, [])

def flatten_reduce_iconcat(matrix):
    return reduce(iconcat, matrix, [])

def flatten_sum(matrix):
    return sum(matrix, [])

Note that in this code, you’re trying out all of the functions that you can use with reduce(), including add(), concat(), and iconcat().

Then use your favorite code editor or IDE to create a new file with the following content:

Python
# performance.py

from timeit import timeit

import flatten

SIZE = 1000
TO_MS = 1000
NUM = 10
FUNCTIONS = [
    "flatten_extend",
    "flatten_concatenation",
    "flatten_comprehension",
    "flatten_chain",
    "flatten_reduce_lambda",
    "flatten_reduce_add",
    "flatten_reduce_concat",
    "flatten_reduce_iconcat",
    "flatten_sum",
]

matrix = [list(range(SIZE))] * SIZE

results = {
    func: timeit(f"flatten.{func}(matrix)", globals=globals(), number=NUM)
    for func in FUNCTIONS
}

print(f"Time to flatten a {SIZE}x{SIZE} matrix (in milliseconds):\n")

for func, time in sorted(results.items(), key=lambda result: result[1]):
    print(f"{func + '()':.<30}{time * TO_MS / NUM:.>7.2f} ms")

In this script, you first import timeit() and your flatten module. Next, you create four constants to store some data that you’ll use later. The final constant is a list with the names of all the functions that you want to compare. Then you create a list of lists with sample data. You call it matrix, as usual.

In this example, you use the timeit() function from the timeit module to measure the execution time of each flattening function. A dictionary comprehension helps you build a dictionary that maps function names to their execution times.
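Besides a string of code, timeit() also accepts a callable as its first argument, which avoids the need for the globals argument. The sketch below times a single function this way. It’s a simplified variation and not the script from this tutorial, and the 100x100 sample size is an arbitrary choice to keep the run short:

```python
from timeit import timeit


def flatten_extend(matrix):
    flat_list = []
    for row in matrix:
        flat_list.extend(row)
    return flat_list


matrix = [list(range(100)) for _ in range(100)]

# number=10 runs the callable ten times. timeit() returns the total
# elapsed time in seconds as a float.
total_seconds = timeit(lambda: flatten_extend(matrix), number=10)

# Average per call, converted to milliseconds.
print(f"{total_seconds * 1000 / 10:.4f} ms per call")
```

The exact figure depends on your machine, so treat it as a relative measure rather than an absolute one.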

Finally, you run a for loop to show the results sorted by execution time. With this code in place, you can run the script from your command line using the command below. Note that this will take some time before you get the actual output:

Shell
$ python performance.py
Time to flatten a 1000x1000 matrix (in milliseconds):

flatten_concatenation()..........1.95 ms
flatten_extend().................2.03 ms
flatten_reduce_iconcat().........2.68 ms
flatten_chain()..................4.60 ms
flatten_comprehension()..........7.79 ms
flatten_sum().................1113.22 ms
flatten_reduce_concat().......1117.15 ms
flatten_reduce_lambda().......1117.52 ms
flatten_reduce_add()..........1118.80 ms

In this output, you can see that there are significant differences between the fastest and slowest functions. If you run the script several times, then you’ll see that the functions flatten_concatenation(), flatten_extend(), and flatten_reduce_iconcat() are always fighting for first place. This behavior makes sense because these functions perform in-place mutation on an existing list object.

Meanwhile, flatten_chain() is in a decent fourth place but still takes about twice the time of the first three functions. However, there’s nothing wrong with chain(). The slow-down effect can come from the call to list(), which creates the final list by consuming the iterator that chain() returns. So, if you’re looking for a memory-efficient solution and can consume the items one at a time instead of building a list, then chain() is for you.
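The memory benefit of chain() shows up when you skip the call to list() and consume the iterator directly. Here’s a minimal sketch of that idea, using made-up sample data and islice() from itertools to take only a few items:

```python
from itertools import chain, islice

matrix = [[9, 3, 8, 3], [4, 5, 2, 8], [6, 4, 3, 1]]

# The iterator yields one item at a time, so an aggregation like
# sum() never needs the full flattened list in memory.
total = sum(chain.from_iterable(matrix))
print(total)  # 56

# You can also take just the first few flattened items.
first_five = list(islice(chain.from_iterable(matrix), 5))
print(first_five)  # [9, 3, 8, 3, 4]
```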

Similarly, flatten_comprehension() always occupies the fifth place, maybe because the list comprehension requires running two nested loops to get the list flattened.

The remaining functions, namely flatten_sum() and the reduce()-based ones that rely on plain concatenation, are in the last places, with much poorer performance. This behavior makes sense if you consider that these functions keep creating new intermediate lists until they get the final result. Each step copies everything accumulated so far, so the total work grows quadratically with the number of rows. In contrast, the top three functions create only one list and mutate it in place. So, they’re also more memory efficient.
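A rough back-of-the-envelope model can make this gap more concrete. The sketch below doesn’t measure anything. It only counts how many item slots get written under a simplified cost model that ignores interpreter and memory-allocation overhead. Both function names are hypothetical:

```python
def items_copied_by_concatenation(rows, row_length):
    """Count list slots written when each step does acc = acc + row.

    Step k builds a brand-new list holding k rows' worth of items.
    """
    return sum(k * row_length for k in range(1, rows + 1))


def items_copied_by_extend(rows, row_length):
    """Count list slots written when each row is added in place.

    This ignores the occasional resize of the underlying array.
    """
    return rows * row_length


print(items_copied_by_concatenation(1000, 1000))  # 500500000
print(items_copied_by_extend(1000, 1000))         # 1000000
```

Under these assumptions, repeated concatenation writes about 500 times more items for a 1000x1000 matrix, which is in the same ballpark as the difference in the timings above.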

With these results, you can conclude that for flattening a list of lists in Python, your best bet is using a for loop and the augmented concatenation operator (+=) or the .extend() method. Using flatten_sum() or certain reduce() functions means creating new intermediate lists, which makes these solutions significantly slower than the others.

Flattening Python Lists for Data Science With NumPy

Many data scientists use Python as their programming language of choice. Are you one of them? If that’s the case, then you spend a lot of time preparing your data for further processing and analysis. In this context, having the ability to flatten lists of lists may be a common requirement.

A flat dataset may be useful when you need to train a machine learning model or use a data science algorithm.

As an example of how to flatten your data, say that you have a NumPy array of nested arrays representing a matrix, and you want to get a flattened array containing all the data from the original one. NumPy arrays have a .flatten() method that does all you need:

Python
>>> import numpy as np

>>> matrix = np.array(
...     [
...         [9, 3, 8, 3],
...         [4, 5, 2, 8],
...         [6, 4, 3, 1],
...         [1, 0, 4, 5],
...     ]
... )

>>> matrix
array([[9, 3, 8, 3],
       [4, 5, 2, 8],
       [6, 4, 3, 1],
       [1, 0, 4, 5]])

>>> matrix.flatten()
array([9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5])

If you call .flatten() on matrix, then you get a one-dimensional array containing all the data. You’ve flattened the original multidimensional array into a flat, or one-dimensional, one. Cool!

Conclusion

In this tutorial, you’ve learned how to flatten a list of lists in Python. You’ve used different tools and techniques to accomplish this task. First, you used a for loop along with the .extend() list method. Then you used other tools, such as list comprehensions and functions like functools.reduce(), itertools.chain(), and sum().

Then you ran a performance test to find out which of these tools offer faster solutions for flattening lists of lists. The test results indicate that your best options include using a loop and the .extend() method or the augmented concatenation operator.

Finally, you explored NumPy’s .flatten() method, a data science–related tool that allows you to flatten multidimensional arrays in your Python code.

With all this knowledge, you’re now ready to start flattening nested data in Python, whether you’re doing regular data processing or data science.
