Linear Regression is a simple and powerful model for predicting a numeric response from a set of one or more independent variables. This article will focus mostly on how the method is used in machine learning, so we won't cover common use cases like causal inference or experimental design. And although it may seem like linear regression is overlooked in modern machine learning's ever-increasing world of complex neural network architectures, the algorithm is still widely used across a large number of domains because it is effective, easy to interpret, and easy to extend. The key ideas in linear regression are recycled everywhere, so understanding the algorithm is a must-have for a strong foundation in machine learning.
Let's Be More Specific
Linear regression is a supervised algorithm[ℹ] that learns to model a dependent variable, $y$, as a function of some independent variables (aka "features"), $x_i$, by finding a line (or surface) that best "fits" the data. In general, we assume $y$ to be some number and each $x_i$ can be basically anything. For example: predicting the price of a house using the number of rooms in that house ($y$: price, $x_1$: number of rooms), or predicting weight from height and age ($y$: weight, $x_1$: height, $x_2$: age).
In general, the equation for linear regression is

$$y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots + \beta_p x_p + \epsilon$$

where:
- $y$: the dependent variable we want to predict,
- $x_1, \dots, x_p$: the independent variables (features),
- $\beta_0$: the intercept,
- $\beta_1, \dots, \beta_p$: the coefficients (weights) on each feature,
- $\epsilon$: the irreducible error term.

Fitting a linear regression model is all about finding the set of coefficients that best model $y$ as a function of our features. We may never know the true parameters for our model, but we can estimate them (more on this later). Once we've estimated these coefficients, $\hat{\beta}_i$, we predict future values, $\hat{y}$, as:

$$\hat{y} = \hat{\beta}_0 + \hat{\beta}_1 x_1 + \hat{\beta}_2 x_2 + \dots + \hat{\beta}_p x_p$$

So predicting future values (often called inference) is as simple as plugging the values of our features into our equation!
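As a quick worked example (the numbers here are hypothetical, not taken from any real dataset): suppose a fitted one-feature model for house price were

$$\hat{y} = 50{,}000 + 200 \cdot x_1,$$

with $x_1$ measured in square feet. Plugging in $x_1 = 1{,}000$ gives a predicted price of $\hat{y} = 50{,}000 + 200 \cdot 1{,}000 = \$250{,}000$.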
To make linear regression easier to digest, let's go through a quick, high-level introduction of how it works. We'll scroll through the core concepts of the algorithm at a high level, and then delve into the details thereafter:
Let's fit a model to predict housing price ($) in San Diego, USA
using the size of the house (in square-footage):
We'll start with a very simple model, predicting the price of each house
to be just the average house price in our dataset, ~$290,000, ignoring
the different sizes of each house:
Of course we know this model is bad: it doesn't fit the data well at all. But how can we quantify exactly how bad it is?
To evaluate our model's performance quantitatively, we plot the error
of each observation directly. These errors, or
residuals, measure the distance between
each observation and the predicted value for that observation. We'll
make use of these residuals later when we talk about evaluating
regression models, but we can clearly see that our model has a lot
of error.
The goal of linear regression is to reduce this error so that we find a line (or surface) that 'best' fits our data. For our simple regression problem, that involves estimating the y-intercept and slope of our model, $\hat{\beta}_0$ and $\hat{\beta}_1$.
For our specific problem, the best fit line is shown. There's still error,
sure, but the general pattern is captured well. As a result, we can be
reasonably confident that if we plug in new values of square-footage,
our predicted values of price would be reasonably accurate.
Once we've fit our model, predicting future values is super easy: we just plug the values of our features into our equation! For our simple model, that means plugging in a value for square footage, $x_1$, into our model (try adjusting the slider):
Value: 350
Thus, our model predicts a house that is 350 square-feet will cost
$293,683.
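If you prefer to follow along in code, here is a minimal sketch of this fit-then-predict workflow using scikit-learn. The square-footage and price numbers below are made up for illustration and won't reproduce the exact values in the charts above.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical training data: square footage (feature) and sale price (target).
square_feet = np.array([[250], [400], [550], [700], [900]])  # shape (n_samples, n_features)
price = np.array([255_000, 271_000, 290_000, 308_000, 330_000])

# Fit the model: estimates the intercept (beta_0) and slope (beta_1).
model = LinearRegression()
model.fit(square_feet, price)

# Predict the price of a 350-square-foot house by plugging it into the fitted equation.
print(model.intercept_, model.coef_)
print(model.predict(np.array([[350]])))
```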
Now that we have a high-level idea of how linear regression works, let's
dive a bit deeper. The remainder of this article will cover how to evaluate
regression models, how to find the "best" model, how to interpret different
forms of regression models, and the assumptions underpinning correct usage
of regression models in statistical settings.
Let's dive in!
To train an accurate linear regression model, we need a way to quantify how well (or how poorly) our model performs. In machine learning, we call such performance-measuring
functions loss functions. Several popular loss functions exist for
regression problems.[ℹ]
To measure our model's performance, we'll use one of the most popular: mean-squared
error (MSE).
Mean-Squared Error (MSE)
MSE quantifies how close a predicted value is to the true value, so we'll use
it to quantify how close a regression line is to a set of points. MSE works by
squaring the distance between each data point and the regression line (the red
residuals in the graphs above), summing the squared values, and then dividing
by the number of data points:

$$\text{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2$$
The name is quite literal: take the mean of the squared errors. The squaring
of errors prevents negative and positive terms from canceling out in the sum,[ℹ]
and gives more weight to points further from the regression line, punishing outliers.
In practice, we'll fit our regression model to a set of training data and evaluate its performance using MSE on a separate test dataset.
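As a reference, here's a minimal NumPy sketch of the formula above; the two arrays are made-up test-set labels and predictions.

```python
import numpy as np

def mean_squared_error(y_true, y_pred):
    """MSE: the average of the squared residuals, (1/n) * sum((y_i - y_hat_i)^2)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return np.mean((y_true - y_pred) ** 2)

# Hypothetical test-set labels and model predictions.
y_test = [300_000, 350_000, 410_000]
y_hat = [310_000, 340_000, 430_000]
print(mean_squared_error(y_test, y_hat))  # 200000000.0 (the error is in squared dollars)
```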
R-Squared
Regression models may also be evaluated with so-called goodness-of-fit measures, which summarize how well a model fits a set of data. The most popular goodness-of-fit measure for linear regression is r-squared, a metric that represents the percentage of the variance in $y$ explained by our features $X$.[ℹ]
More specifically, r-squared measures the variance explained by our model, normalized against the total variance of the data (the variance around the trivial model that always predicts the mean):

$$R^2 = 1 - \frac{\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2}{\sum_{i=1}^{n}\left(y_i - \bar{y}\right)^2} = 1 - \frac{\text{RSS}}{\text{TSS}}$$

The highest possible value for r-squared is 1, representing a model that captures 100% of the variance. A negative r-squared means that our model is doing worse (capturing less variance) than a flat line through the mean of our data would. (The name "r-squared" misleadingly implies that it cannot be negative.)
To build intuition for yourself, try changing the weight and
intercept terms below to see how the MSE and r-squared change across different
possible models for a toy dataset (click Shuffle Data to make a new toy dataset):
Intercept ($\beta_0$): 5.00
Weight ($\beta_1$): 0.00
You will often see R-Squared referenced in statistical contexts as a way to
assess model fit.
Selecting An Evaluation Metric
Many methods exist for evaluating regression models, each with different concerns
around interpretability, theory, and usability. The evaluation metric should
reflect whatever it is you actually care about when making predictions. For example,
when we use MSE, we are implicitly saying that we think the cost of our prediction error should reflect the quadratic (squared) distance between what we predicted and what is correct. This works well if we want to punish outliers, or if we care about predicting the conditional mean (the value that minimizes squared error), but it comes at the cost of interpretability: our error is reported in squared units (though this may be fixed with RMSE). If instead we wanted our error to reflect the linear distance between what we predicted and what is correct, or we cared about predicting the conditional median (the value that minimizes absolute error), we could try something like Mean Absolute Error (MAE). Whatever the case, you should be thinking of your evaluation metric as
part of your modeling process, and select the best metric based on the
specific concerns of your use-case.
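To make the trade-off concrete, here's a tiny sketch (with made-up numbers) showing how a single outlier inflates MSE far more than MAE:

```python
import numpy as np

# Hypothetical targets and predictions; the last observation is a large outlier.
y_true = np.array([10.0, 12.0, 11.0, 50.0])
y_pred = np.array([11.0, 11.0, 12.0, 15.0])

mse = np.mean((y_true - y_pred) ** 2)    # squared errors: 1, 1, 1, 1225 -> MSE = 307.0
mae = np.mean(np.abs(y_true - y_pred))   # absolute errors: 1, 1, 1, 35  -> MAE = 9.5
print(mse, mae)
```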
Let's recap what we've learned so far: Linear regression is all about finding a
line (or surface) that fits our data well. And as we just saw, this involves selecting
the coefficients for our model that minimize our evaluation metric. But how can
we best estimate these coefficients? In practice, they're unknown, and selecting
them by hand quickly becomes infeasible for regression models with many features.
There must be a better way!
Luckily for us, several algorithms exist to do just this. We'll discuss two: an
iterative solution and a closed-form solution.
An Iterative Solution
Gradient descent is an iterative optimization algorithm that estimates some set
of coefficients to yield the minimum of a convex function. Put simply: it will
find suitable coefficients for our regression model that minimize prediction error
(remember, lower MSE equals better model).
A full conversation on gradient descent is outside the scope of this article (stay tuned
for our future article on the subject), but if you'd like to learn more, click
the "Show Math" button below. Otherwise, read on!
Gradient descent will iteratively identify the coefficients our model needs to fit the data. Let's see an example directly. We'll fit data to our equation $\hat{y} = \hat{\beta}_0 + \hat{\beta}_1 x$, so gradient descent will learn two coefficients, $\hat{\beta}_0$ (the intercept) and $\hat{\beta}_1$ (the weight). To do so, interact with the plot below. Try dragging the intercept and weight values to create a poorly fit (large error) solution, then run gradient descent to see the error iteratively improve.
Click the buttons to run 1, 10, or 100 steps of gradient descent, and see the linear regression model update live. The error at each iteration of gradient descent (or manual coefficient update) is shown in the bottom chart. With each weight update, we recalculate the error, so you can see how gradient descent improves our model iteratively.
Intercept ($\hat{\beta}_0$): 0.100
Weight ($\hat{\beta}_1$): 0.100
Although gradient descent is the most popular optimization algorithm in machine learning, it's not perfect! It doesn't work for every loss function, and it may not always find the optimal set of coefficients for your model. Still, it has many extensions to help solve these issues, and is widely used across machine learning.
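For readers who want to see the mechanics in code, below is a minimal sketch of gradient descent on the MSE loss for a one-feature model. The data is simulated, and the learning rate and step count are arbitrary choices for this toy example.

```python
import numpy as np

def gradient_descent(x, y, lr=0.01, steps=5000):
    """Estimate beta0 (intercept) and beta1 (weight) for y ~ beta0 + beta1 * x by minimizing MSE."""
    beta0, beta1 = 0.0, 0.0                        # arbitrary starting point
    n = len(x)
    for _ in range(steps):
        error = (beta0 + beta1 * x) - y            # residuals at the current coefficients
        beta0 -= lr * (2 / n) * np.sum(error)      # step opposite the gradient of MSE w.r.t. beta0
        beta1 -= lr * (2 / n) * np.sum(error * x)  # step opposite the gradient of MSE w.r.t. beta1
    return beta0, beta1

# Simulated data roughly following y = 2 + 3x plus noise.
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 50)
y = 2 + 3 * x + rng.normal(0, 1, 50)
print(gradient_descent(x, y))  # estimates should land near (2, 3)
```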
A Closed-Form Solution
We'd be remiss not to mention the Normal Equation, a widely taught method for
obtaining estimates for our linear regression coefficients. The Normal Equation
is a closed-form solution that allows us to estimate our coefficients directly
by minimizing the residual sum of squares (RSS) of our data:

$$\hat{\beta} = (X^{T}X)^{-1}X^{T}y$$
The RSS should look familiar: it was a key piece of both the MSE and r-squared formulas, and it represents our model's total squared error:

$$\text{RSS} = \sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2$$
Add circles to the chart below to see how the Normal Equation calculates two coefficients, the intercept and the weight, for the corresponding regression model.
Despite providing a convenient closed-form solution for finding our optimal coefficients, the Normal Equation estimates are often not used in practice, because inverting the $X^{T}X$ matrix becomes computationally expensive when the number of features is large. While our two-feature example above runs fast (we can run it in the browser!), most machine learning models are more complicated. For this reason, we often just use gradient descent.
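Here's a minimal NumPy sketch of the Normal Equation on a tiny, made-up dataset. In practice you'd prefer np.linalg.lstsq (or a pseudo-inverse) over an explicit matrix inverse for numerical stability.

```python
import numpy as np

# Hypothetical data with one feature; prepend a column of ones for the intercept.
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([2.1, 4.3, 6.2, 8.1, 9.9])
X = np.column_stack([np.ones_like(x), x])  # design matrix, shape (n, 2)

# Normal Equation: beta_hat = (X^T X)^{-1} X^T y
beta_hat = np.linalg.inv(X.T @ X) @ X.T @ y
print(beta_hat)                            # [intercept, weight] ~= [0.30, 1.94]

# More numerically stable alternative that avoids the explicit inverse.
beta_hat_stable, *_ = np.linalg.lstsq(X, y, rcond=None)
```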
Are Our Coefficients Valid?
In research publications and statistical software, coefficients of regression
models are often presented with associated p-values. These p-values come from
traditional null hypothesis statistical tests: t-tests are used to measure whether a given coefficient is significantly different from zero (against the null hypothesis that the coefficient equals zero), while F-tests are used to measure whether any of the terms in a regression model are significantly different from zero. Different opinions exist on the utility of such tests (e.g. chapter 10.7 of [1] maintains they're not super important). We don't
take a strong stance on this issue, but believe practitioners should always assess
the standard error around any parameter estimates for themselves and present them
in their research.
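If you want to see these quantities for yourself, statistical packages report them alongside the coefficient estimates. Here's a minimal sketch using statsmodels on simulated data (the true intercept and slope here are 5 and 2):

```python
import numpy as np
import statsmodels.api as sm

# Simulated data: one feature with a real linear effect plus Gaussian noise.
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 100)
y = 5 + 2 * x + rng.normal(0, 3, 100)

X = sm.add_constant(x)        # add the intercept column
results = sm.OLS(y, X).fit()  # ordinary least squares fit
print(results.summary())      # coefficients, standard errors, t-statistics, p-values, F-statistic
```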
One of the most powerful aspects of regression models is their
interpretability. However, different forms of regression models require
different interpretations. To make this clear, we'll walk through several
typical constructs of a regression model, and describe how to interpret each
in turn. For every model form, we interpret the error term, $\epsilon$, as irreducible noise not captured by our model.
Select a tab to learn
how to interpret the given form of regression model:
A Regression Model With One Binary Feature
Example: $\text{price} = \$172{,}893 + \$241{,}582 \cdot \text{pool} + \epsilon$

Interpretation: This model summarizes the difference in average housing prices between houses without swimming pools ($\text{pool} = 0$) and houses with swimming pools ($\text{pool} = 1$). The intercept, $172,893, is the average predicted price for houses that do not have swimming pools (to see this, set $\text{pool}$ to 0 and simplify the equation). To find the average predicted price for houses that do have pools, we plug in $\text{pool} = 1$ to obtain $172,893 + $241,582 × 1 = $414,475. The difference between these two subpopulation means is equal to the coefficient on $\text{pool}$: it tells us that houses with pools cost $241,582 more on average than houses that do not have pools.[ℹ]
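You can verify this interpretation numerically: with a single binary feature, the fitted intercept equals the mean of the group coded 0 and the fitted coefficient equals the difference between the two group means. A minimal sketch with made-up prices:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical prices for houses without (0) and with (1) a pool.
pool = np.array([0, 0, 0, 1, 1, 1])
price = np.array([150_000, 170_000, 190_000, 400_000, 410_000, 420_000])

model = LinearRegression().fit(pool.reshape(-1, 1), price)
print(model.intercept_)  # 170000.0 -> mean price of houses without a pool
print(model.coef_[0])    # 240000.0 -> difference between the group means
print(price[pool == 1].mean() - price[pool == 0].mean())  # same 240000.0, computed directly
```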
Of course, this is not an exhaustive list of regression models; many other forms exist!
When teaching regression models, it's common to mention the various
assumptions underpinning linear regression. For completeness, we'll list some of those assumptions here. However, in the context of machine learning we care most about whether the predictions made from our model generalize well to unseen data. We'll use our model if it generalizes well, even if it violates
statistical assumptions. Still, no treatment of regression is complete without
mentioning the assumptions.
Validity:
Does the data we're modeling match the problem we're actually trying to solve?
Representativeness:
Is the sample data used to train the regression model representative of the population
to which it will be applied?
Additivity and Linearity:
The deterministic component of a regression model is a linear function of the separate predictors: $\beta_0 + \beta_1 x_1 + \dots + \beta_p x_p$.
Independence of Errors:
The errors from our model are independent.
Homoscedasticity:
The errors from our model have equal variance.
Normality of Errors:
The errors from our model are normally distributed.
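Several of the error-related assumptions can be eyeballed by plotting residuals against predicted values: curvature suggests non-linearity, a funnel shape suggests heteroscedasticity, and a histogram (or Q-Q plot) of the residuals helps check normality. A minimal sketch, assuming you already have arrays of observed values and model predictions (simulated here):

```python
import numpy as np
import matplotlib.pyplot as plt

# Simulated stand-ins for a fitted model's predictions and the observed values.
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 200)
y_true = 2 + 3 * x + rng.normal(0, 1, 200)
y_pred = 2 + 3 * x

residuals = y_true - y_pred

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.scatter(y_pred, residuals, alpha=0.6)  # residuals vs. predictions: look for patterns
ax1.axhline(0, color="red")
ax1.set_xlabel("Predicted value")
ax1.set_ylabel("Residual")
ax2.hist(residuals, bins=20)               # rough check of the normality assumption
ax2.set_xlabel("Residual")
plt.show()
```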
What If the Assumptions Fail?
What should we do if the assumptions for our regression model aren't met? Don't
fret, it's not the end of the world! First, double-check that the assumptions even
matter in the first place: if the predictions made from our model generalize well
to unseen data, and our task is to create a model that generalizes well, then we're
probably fine. If not, figure out which assumption is being violated, and how to
address it! This will change depending on the assumption being violated, but in general, one can attempt to extend the model, incorporate new data, transform the existing data, or some combination thereof. If a model transformation is unfit,
perhaps the application (or research question) can be changed or restricted to
better align with the data. In practice, some combination of the above will usually
suffice.
The study of linear regression is a very deep topic: there's a ton to talk about, and we'd be foolish to try to cover it all in one single article. Some of the topics left unmentioned are: regularization methods, selection techniques, common regression transformations, Bayesian formulations of regression, and additional evaluation techniques. For those interested in learning more, we recommend diving deep into the aforementioned topics, or reading the resources below. Hopefully this article serves as a nice starting point for learning about linear regression.
Thanks for reading. We hope that the article is insightful no matter where you
are along your Machine Learning journey, and that you came away with a better
understanding of linear regression.
To learn more about Machine Learning, check out our
self-paced courses, our
YouTube videos, and the
Dive into Deep Learning
textbook. If you have any comments or ideas related to
MLU-Explain articles, feel free to reach out
directly. The code for
this article is available
here.
This article is a product of the following resources + the awesome people who made (and contributed to) them:
[1] Regression And Other Stories
(Gelman, Hill and Vehtari 2020).
[2] The Elements of Statistical Learning
(Hastie, Tibshirani and Friedman 2009).
[3] Mathematical Statistics and Data Analysis
(John A. Rice, 2010).
D3.js
(Mike Bostock &
Philippe Rivière)
KaTeX
(Emily Eisenberg
& Sophie Alpert)
Svelte
(Rich Harris)