arrsingh.com

Distributed Systems & Artificial Intelligence

Simple Linear Regression in Go

Linear regression is a statistical model that predicts one dependent variable from one or more independent variables. In simpler terms, given a dataset where an output variable (y) is (approximately) a linear combination of several input variables (x1, x2, x3 ... xn), linear regression can be used to estimate (or predict) values of y for new values of x1...xn.

Simple linear regression is the simplest linear model: there is only one dependent variable (y) and one independent variable (x). It is used to estimate the parameters of the line that best fits the data, namely the slope (m) and the y-intercept (c). The equation of a line is given by:

$$y = mx + c$$

The good news is that simple linear regression has a closed-form solution, so it's easy to implement. Given a dataset of n observations (x_i, y_i), the slope and the y-intercept of the line of best fit are given by:

\[m = \frac{n \sum\limits_{i=1}^{n}x_i y_i - \sum\limits_{i=1}^{n}x_i \sum\limits_{i=1}^{n}y_i}{n\sum\limits_{i=1}^{n}x_i^2 - \left(\sum\limits_{i=1}^{n}x_i\right)^2}\]

\[c = \frac{\sum\limits_{i=1}^{n} y_i - m \sum\limits_{i=1}^{n}x_i}{n}\]
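These formulas come from least squares. The line of best fit is the pair (m, c) that minimizes the sum of squared residuals S(m, c). Setting both partial derivatives of S to zero gives two linear equations, the normal equations:

$$
S(m, c) = \sum_{i=1}^{n} \left(y_i - m x_i - c\right)^2
$$

$$
\frac{\partial S}{\partial c} = 0 \;\Rightarrow\; \sum_{i=1}^{n} y_i = m \sum_{i=1}^{n} x_i + n c
\qquad
\frac{\partial S}{\partial m} = 0 \;\Rightarrow\; \sum_{i=1}^{n} x_i y_i = m \sum_{i=1}^{n} x_i^2 + c \sum_{i=1}^{n} x_i
$$

The first equation rearranges directly to the formula for c. Substituting that c into the second equation and multiplying through by n gives the formula for m. The denominator of m is zero only when every x_i has the same value. In that case no unique slope exists.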

With this background, the Go implementation is straightforward. We'll use the gonum mat package for the vectors and gonum's blas64 package for the dot products.

```go
package main

import (
	"math"

	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/mat"
)

// Line holds the parameters of y = mx + c.
type Line struct {
	YIntercept float64 // c
	Slope      float64 // m
}

// SimpleLinearRegression returns the least-squares line of best fit for the
// points (x[i], y[i]). It panics if x and y have different lengths.
func SimpleLinearRegression(x, y *mat.VecDense) *Line {
	if x.Len() != y.Len() {
		panic("vector x must be the same length as y")
	}

	n := float64(x.Len())

	// Σ(x): sum of the elements of x
	sumX := mat.Sum(x)
	// Σ(y): sum of the elements of y
	sumY := mat.Sum(y)

	// Σ(xy): dot product of x and y
	sumXY := blas64.Dot(x.RawVector(), y.RawVector())

	// Σ(x^2): sum of the squares of the elements of x
	// (the dot product of x with itself)
	sumXSq := blas64.Dot(x.RawVector(), x.RawVector())

	// (Σx)^2: square of the sum of the elements of x
	sqSumX := math.Pow(sumX, 2)

	// Equation of a line: y = mx + c

	// Slope of the line (m):
	// m = (nΣ(xy) - Σ(x)Σ(y)) / (nΣ(x^2) - (Σx)^2)
	m := (n*sumXY - sumX*sumY) / (n*sumXSq - sqSumX)

	// Y-intercept (c):
	// c = (Σ(y) - mΣ(x)) / n
	c := (sumY - m*sumX) / n

	return &Line{
		YIntercept: c,
		Slope:      m,
	}
}
```
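The same line can be computed with plain slices if you would rather not depend on gonum. The following is a minimal sketch and not the post's implementation. It uses the algebraically equivalent mean-centered form, m = Σ(x - x̄)(y - ȳ) / Σ(x - x̄)^2 and c = ȳ - m·x̄. This form is less prone to cancellation than nΣ(x^2) - (Σx)^2 when the x values are large. If every x is identical, the sketch returns an error rather than dividing by zero. The names fitLine and errZeroVariance are hypothetical.

```go
package main

import (
	"errors"
	"fmt"
)

// errZeroVariance is a hypothetical sentinel error returned when all x values
// are identical, so no unique slope exists.
var errZeroVariance = errors.New("x values have zero variance; slope is undefined")

// fitLine is a hypothetical stdlib-only helper that returns the least-squares
// slope (m) and y-intercept (c) using the mean-centered formulas.
func fitLine(x, y []float64) (m, c float64, err error) {
	if len(x) != len(y) {
		return 0, 0, errors.New("x and y must have the same length")
	}
	if len(x) == 0 {
		return 0, 0, errors.New("x and y must not be empty")
	}
	n := float64(len(x))

	// Means of x and y.
	var meanX, meanY float64
	for i := range x {
		meanX += x[i]
		meanY += y[i]
	}
	meanX /= n
	meanY /= n

	// Σ(x - x̄)(y - ȳ) and Σ(x - x̄)^2.
	var sxy, sxx float64
	for i := range x {
		dx := x[i] - meanX
		sxy += dx * (y[i] - meanY)
		sxx += dx * dx
	}
	if sxx == 0 {
		return 0, 0, errZeroVariance
	}

	m = sxy / sxx
	c = meanY - m*meanX
	return m, c, nil
}

func main() {
	x := []float64{1, 2, 9, 8, 3, 7}
	y := []float64{4, 5, 2, 3, 1, 0}
	m, c, err := fitLine(x, y)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Printf("Slope = %f\n", m)
	fmt.Printf("Intercept = %f\n", c)
}
```

For the sample data used in this post, this prints the same slope and intercept as the gonum version.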

 

We can use some dummy data and call the SimpleLinearRegression() function to find the slope and intercept for the line of best fit.

```go
// In main(), with "fmt" added to the import list above:
x := mat.NewVecDense(6, []float64{1, 2, 9, 8, 3, 7})
y := mat.NewVecDense(6, []float64{4, 5, 2, 3, 1, 0})

line := SimpleLinearRegression(x, y)
fmt.Printf("Slope = %f\n", line.Slope)
fmt.Printf("Intercept = %f\n", line.YIntercept)
```

Running the code above gives us the values for the slope and the y-intercept.

```
Slope = -0.275862
Intercept = 3.879310
```
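As a sanity check, you can reproduce the same numbers by hand from the closed-form formulas and the six (x, y) pairs above:

$$
\begin{aligned}
n &= 6, \quad \sum x_i = 30, \quad \sum y_i = 15, \quad \sum x_i y_i = 59, \quad \sum x_i^2 = 208 \\
m &= \frac{6 \cdot 59 - 30 \cdot 15}{6 \cdot 208 - 30^2} = \frac{354 - 450}{1248 - 900} = \frac{-96}{348} \approx -0.275862 \\
c &= \frac{15 - m \cdot 30}{6} = \frac{15 + 8.275862}{6} \approx 3.879310
\end{aligned}
$$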

Verifying the results

We can verify the result against the LinearRegression function in the gonum stat package, which takes plain []float64 slices:

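```go
// Now use the stat package (import "gonum.org/v1/gonum/stat").
// New names xs and ys avoid redeclaring x and y from the previous snippet.
xs := []float64{1, 2, 9, 8, 3, 7}
ys := []float64{4, 5, 2, 3, 1, 0}

// stat.LinearRegression returns (alpha, beta), i.e. the intercept and the slope.
// nil means unweighted; origin=false fits an intercept instead of forcing c = 0.
c, m := stat.LinearRegression(xs, ys, nil, false)
fmt.Printf("Slope: %f\n", m)
fmt.Printf("Y-Intercept: %f\n", c)
```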

This prints out the same values:

```
Slope: -0.275862
Y-Intercept: 3.879310
```
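The introduction describes regression as a way to estimate y for new values of x. The following minimal sketch shows that last step. The Predict method is hypothetical, since the post's code doesn't define one. The Line type is repeated so the snippet compiles on its own. The coefficients are the values computed above, written as the exact fractions -16/58 and 225/58.

```go
package main

import "fmt"

// Line mirrors the struct from the post so this sketch compiles on its own.
type Line struct {
	YIntercept float64
	Slope      float64
}

// Predict is a hypothetical helper (not part of the original code) that
// evaluates y = mx + c for a new x.
func (l *Line) Predict(x float64) float64 {
	return l.Slope*x + l.YIntercept
}

func main() {
	// Coefficients from the example above, written as exact fractions:
	// m = -16/58 ≈ -0.275862, c = 225/58 ≈ 3.879310.
	line := &Line{YIntercept: 225.0 / 58.0, Slope: -16.0 / 58.0}
	for _, x := range []float64{0, 5, 10} {
		fmt.Printf("x = %4.1f  ->  predicted y = %f\n", x, line.Predict(x))
	}
}
```

As a quick check, the prediction at x = 5 (the mean of x) comes out as 2.5 (the mean of y). This holds because the least-squares line always passes through the point (x̄, ȳ).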
