Regression is a statistical approach for predicting continuous real values such as age, weight, or salary.
In regression, we have a dependent variable that we want to predict from one or more independent variables.
The goal of regression is to find the relationship between the independent variables and the dependent variable, which can then be used to predict the outcome of an event.
Simple linear regression is one type of regression technique. Its objective is to find the straight line that best fits the data.
The best-fit line is chosen to minimize the sum of the squared vertical distances (residuals) between the line and the data points. This is the least-squares criterion.
If we have an independent variable $x$ and a dependent variable $y$, the linear relationship between them can be written as

$$y = b_0 + b_1 x$$

where $b_0$ is the y-intercept and $b_1$ is the slope.
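For a single feature, the least-squares line also has a closed-form solution: $b_1 = \sum_i (x_i - \bar{x})(y_i - \bar{y}) / \sum_i (x_i - \bar{x})^2$ and $b_0 = \bar{y} - b_1 \bar{x}$. The minimal NumPy sketch below computes both; the function name `fit_ols` and the toy points are illustrative and not part of the original tutorial. It can serve as a sanity check on the weights that the Keras model learns later.

```python
import numpy as np

def fit_ols(x, y):
    """Closed-form least-squares estimates (b0, b1) for y = b0 + b1*x.

    Illustrative helper, not from the original tutorial.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean, y_mean = x.mean(), y.mean()
    # Slope: co-variation of x and y divided by the variation of x
    b1 = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    # Intercept: the least-squares line passes through (x_mean, y_mean)
    b0 = y_mean - b1 * x_mean
    return float(b0), float(b1)

# Toy points that lie exactly on y = 1 + 2x
print(fit_ols([0, 1, 2, 3], [1, 3, 5, 7]))  # (1.0, 2.0)
```

On real data the points will not lie on a line, but the same formulas give the line with the smallest sum of squared residuals.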
Now, let’s implement simple linear regression using Keras.
You can download the dataset here.
Let’s take a quick look at the dataset.
We want to predict salary based on years of experience.
Let’s first import the necessary modules and load the data:
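```python
import pandas as pd
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.optimizers import SGD

# First column: years of experience; second column: salary
df = pd.read_csv('Salary.csv')
# Keras expects 2-D input of shape (n_samples, n_features), so keep X as one column
X = df.iloc[:, 0].values.reshape(-1, 1)
y = df.iloc[:, 1].values
```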
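Now we can create our Keras model for linear regression. We will use a single Dense layer with one unit and a linear activation. With one input, this layer computes w*x + b, which is exactly the simple linear regression equation: the kernel w plays the role of the slope b1 and the bias b plays the role of the intercept b0.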
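```python
model = Sequential()
# One output unit fed by one input feature: output = w * x + b
model.add(Dense(1, input_dim=1))
# Linear (identity) activation, so the output is not transformed
model.add(Activation('linear'))
```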
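A Keras `Dense` layer computes `activation(dot(input, kernel) + bias)`. The NumPy sketch below reproduces that computation for a layer with one input and one unit and a linear activation, to show that it is the equation y = b0 + b1*x in matrix form. The helper name `dense_forward` and the weight values are hypothetical, chosen only for illustration.

```python
import numpy as np

def dense_forward(inputs, kernel, bias):
    """Hypothetical helper: the Dense computation with a linear activation."""
    # inputs: (n_samples, 1), kernel: (1, 1), bias: (1,)
    return inputs @ kernel + bias

kernel = np.array([[2.0]])           # stands in for the slope b1
bias = np.array([1.0])               # stands in for the intercept b0
x = np.array([[0.0], [1.0], [2.5]])  # three samples, one feature each
print(dense_forward(x, kernel, bias).ravel())  # [1. 3. 6.]
```

After training, `model.get_weights()` returns the Dense layer's kernel and bias as NumPy arrays (the `Activation` layer has no weights), so the learned slope and intercept can be read off in the same shapes used here.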
Compile the model with the SGD optimizer using a learning rate of 0.01, and set mean squared error (MSE) as the loss function:
```python
# Plain stochastic gradient descent with a learning rate of 0.01
sgd = SGD(learning_rate=0.01)
model.compile(loss='mse', optimizer=sgd)
```
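To make the roles of the MSE loss and the learning rate concrete, here is a minimal from-scratch sketch of gradient descent on the MSE of y ≈ b0 + b1*x. It assumes full-batch updates with no momentum; Keras's `SGD` applies the same kind of step (parameter minus learning rate times gradient), but `model.fit` uses mini-batches of 32 samples by default, so its exact trajectory can differ. The function name and the toy data are illustrative, and the toy run uses a larger learning rate and more epochs than the tutorial so that it converges on this small example.

```python
import numpy as np

def gradient_descent(x, y, learning_rate=0.01, epochs=500):
    """Illustrative full-batch gradient descent on mean((b0 + b1*x - y)**2)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(x)
    b0, b1 = 0.0, 0.0
    for _ in range(epochs):
        # Residuals of the current line
        error = (b0 + b1 * x) - y
        # Partial derivatives of the MSE with respect to b0 and b1
        grad_b0 = 2.0 * error.sum() / n
        grad_b1 = 2.0 * (error * x).sum() / n
        # Step against the gradient, scaled by the learning rate
        b0 -= learning_rate * grad_b0
        b1 -= learning_rate * grad_b1
    return float(b0), float(b1)

# Toy points on y = 1 + 2x; the result should be approximately (1.0, 2.0)
print(gradient_descent([0, 1, 2, 3], [1, 3, 5, 7], learning_rate=0.05, epochs=2000))
```

If the learning rate is too large, the updates overshoot and the loss diverges; if it is too small, many more epochs are needed to reach the least-squares line.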
Now we can use the model.fit() function to train the model for 500 epochs:
```python
# Train for 500 epochs; verbose=0 suppresses the per-epoch progress output
history = model.fit(X, y, epochs=500, verbose=0)
```
Once the model is trained, we can use the predict method to make predictions:
```python
pred = model.predict(X)  # shape (n_samples, 1)
```
Let’s plot our predictions against the data:
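```python
plt.scatter(X, y, c='blue')    # observed data points
plt.plot(X, pred, color='g')   # fitted regression line
plt.xlabel('Years of experience')
plt.ylabel('Salary')
plt.show()
```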
The green line shows the model's predictions and the blue points represent our data.
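The plot gives a visual check; a number is often useful too. The sketch below computes the mean squared error and the coefficient of determination R² with NumPy. The helper name `regression_metrics` is illustrative, not from the original; it flattens its inputs so that the `(n_samples, 1)` array returned by `model.predict` can be passed directly, e.g. `regression_metrics(y, pred)`.

```python
import numpy as np

def regression_metrics(y_true, y_pred):
    """Illustrative helper: return (MSE, R^2) for a set of predictions."""
    # Flatten so column vectors such as model.predict output work too
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    residuals = y_true - y_pred
    mse = np.mean(residuals ** 2)
    # R^2 = 1 - (residual sum of squares / total sum of squares)
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # zero if y_true is constant
    r2 = 1.0 - ss_res / ss_tot
    return float(mse), float(r2)

print(regression_metrics([1, 2, 3, 4], [1.5, 2.0, 2.5, 4.0]))
```

An R² close to 1 means the line explains most of the variation in salary; the MSE is in squared salary units, so its square root is easier to interpret.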
The complete code can be found in this GitHub repo.