A Beginner's Guide To:
Machine Learning Classification


A Brief History of Classification

Can computers learn from data?

This question led to the development of machine learning. In 1959, Arthur Samuel (a pioneer computer scientist) defined machine learning as “the field of study that gives computers the ability to learn without being explicitly programmed”.

While this may sound complex, we hope to make some machine learning concepts approachable and easy to learn through examples and activities in this article.

Understanding Algorithms

Think of it this way: just as you can learn these concepts by experimenting with examples, computers can do the same!

To grasp how machine learning works, we need to understand algorithms, the methods used to create models from data. In this article, we will show you popular algorithms that perform classification tasks on a simple dataset related to income.

Case Study

Is Income Predictable?

You are starting your senior year of high school and are trying to decide whether to continue your education after graduation. You are wondering whether the long-run returns will be high enough to let you live a comfortable life and pay off student loans. To find out, you look at US census data and gather real examples of people's:

1. Age
2. Number of years of education
3. Income

You classify income as high if it exceeds $50,000. (The data comes from the 1994 US Census Bureau database; $50,000 is worth roughly $88,000 in 2020 dollars.)

Let's take a quick look at the data below:

Age | Years of Education | High Income
----|--------------------|------------
38  | 8                  | False
59  | 13                 | True
26  | 13                 | False
51  | 13                 | True
34  | 11                 | False
50  | 15                 | True
45  | 13                 | True
57  | 10                 | True

(first 8 of the 32 records shown)

Exploring the dataset

As you probably realized, looking at raw data isn't very useful. So now let's look at visualizations to see what more we can learn from this dataset!

First, let's look at a simple bar chart showing the number of people with and without high income. As you can see here, there is an equal number of people (16) with high and low income in the dataset, totaling 32 records.


Next, let's look at the spread of people's ages. When the high income button to the right is selected, you can see how old the people with a high income are.

As you can see, it seems as if older people are more likely to be high income. This makes sense as more experience usually means a better salary!

This graph shows the number of people by years of education. When the high income button is selected, you can see how many years of education people with a high income have.

As you can see from this graph, those with high income tend to be the ones with more years of education. However, this is not always the case, as you will see some low income individuals with 9, 10, 11, and 13 years of education.

Finally, let’s look at the age and education variables together in a scatterplot.

After looking at the data, do you think more years of schooling truly result in a higher income for you?

Evaluating Your Model

After exploring the data, we hope that you got a sense of which patterns are more likely to mean that someone has a high income. By doing this, you were actually training your own model, which is almost exactly what machine learning models do! Now that your model is trained, let's put it to the test and have you predict the income of 3 mystery people.

Mystery Person #1:
Age: 57
Years of Education: 15

Mystery Person #2:
Age: 22
Years of Education: 7

Mystery Person #3:
Age: 30
Years of Education: 10

Using Machine Learning

We will now complete the same classification task with three commonly used classification algorithms:

1. K Nearest Neighbors (KNN)

2. Decision Trees

3. Logistic Regression

Training

Before we get into the algorithms, we need to understand why models need training. Training a model simply means showing the model a set of labeled examples. Each example helps the model learn the relationship between the variables (in this case, age and years of education) and the classification value (income). Once the relationship is learned, the model can predict an unknown classification value for new sets of variables.

Warning: Overfitting

One thing to pay attention to when using machine learning algorithms is overfitting. Overfitting happens when a model is too complex and starts to fit random error in the data rather than the expected relationships between variables. A model is considered "overfit" when it fits your training data really well, yet performs poorly on new data.

One way to identify overfitting is to reserve a portion of your dataset and introduce it only after you have finished creating your model, to see how it performs. If your model performs poorly on the reserved set, it is overfit to the training data! One way to control how the model performs is by adjusting its hyperparameters, which are higher-level properties of the model. In our case, they would be the value of K, the depth of the decision tree, or the threshold value for the logistic regression.
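To make this concrete, here is a minimal sketch in Python (using scikit-learn) of how one might reserve a held-out portion of the data and compare training accuracy to held-out accuracy. The small arrays come from the sample table above, and the choice of classifier is just a stand-in, not the exact setup used in this article.

```python
# A minimal sketch of reserving a held-out set to check for overfitting.
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# (age, years of education) and high-income labels from the sample table above.
X = [[38, 8], [59, 13], [26, 13], [51, 13], [34, 11], [50, 15], [45, 13], [57, 10]]
y = [False, True, False, True, False, True, True, True]

# Reserve 25% of the examples; the model never sees them during training.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

model = KNeighborsClassifier(n_neighbors=3)
model.fit(X_train, y_train)

# A large gap between these two scores is a sign of overfitting.
print("train accuracy:", model.score(X_train, y_train))
print("test accuracy:", model.score(X_test, y_test))
```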

K Nearest Neighbors (KNN)

The K Nearest Neighbors (KNN) algorithm classifies data based on the data points that are most similar to it. KNN plots a test point alongside the training data points and classifies it according to the classes of the K training points closest to it.

In the example on the right, our "test point" is the red point. Click on the buttons to see how different values of K can lead to different classifications.

When we set K = 3, we see that there are two points near it that are blue, and one that is black. Since we would classify based on the majority, we would classify our test point as blue.
However, when we set K = 6, four of the nearest neighbors are black, while only two are blue. Thus, when K = 6, we would label our test point as black since black now has the majority.

Here, we will apply KNN to our dataset, where you can interact and vary the number of neighbors (K) to see how it will affect what the model predicts.


Here, you can play around with the number of neighbors (K) to see what our model would predict for the last mystery person, who has an age of 30 and 10 years of education. If you remember from above, this person is classified as low income, but if you set K to 2 or less, the KNN algorithm would actually predict high income. This is because the nearest neighbor is classified as high income.

However, having K = 2 presents an interesting problem, as the two neighbors are classified differently. While there are many approaches to breaking a tie, we are using the "nearest" approach, which uses the class of the nearest neighbor. There is also the "random" approach, which selects the class of a random neighbor.
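To make the procedure concrete, here is a minimal sketch of KNN written from scratch in Python, including the "nearest" tie-breaking rule described above. The handful of training points is taken from the sample table, and knn_predict is just an illustrative helper, not the exact implementation behind the interactive demo.

```python
import math

def knn_predict(train_points, train_labels, test_point, k):
    """Classify test_point by majority vote among its k nearest neighbors,
    breaking ties with the class of the single nearest neighbor."""
    # Indices of training points, sorted by Euclidean distance to the test point.
    order = sorted(range(len(train_points)),
                   key=lambda i: math.dist(train_points[i], test_point))
    nearest = order[:k]

    # Count votes for each class among the k nearest neighbors.
    votes = {}
    for i in nearest:
        votes[train_labels[i]] = votes.get(train_labels[i], 0) + 1
    top = max(votes.values())
    winners = [label for label, count in votes.items() if count == top]

    # "Nearest" tie-break: fall back to the class of the closest neighbor.
    if len(winners) > 1:
        return train_labels[nearest[0]]
    return winners[0]

# A few (age, years of education) points and labels from the sample table above.
points = [(38, 8), (59, 13), (26, 13), (51, 13), (34, 11), (50, 15)]
labels = [False, True, False, True, False, True]
print(knn_predict(points, labels, (30, 10), k=3))
```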

Decision Trees

The decision tree is an important machine learning algorithm that is commonly used for classification problems, just like this one! The algorithm follows a set of if-else conditions to classify data according to those conditions. An example of a decision tree is shown on the right.

At a high level, decision trees work by reducing information entropy as the data moves down the tree and answers a set of questions.

Information entropy can be thought of as a measurement of uncertainty. Analogously, it can be thought of as how surprised you would be to discover the real label. If all you know is that half of the labels are true and the other half are false, your uncertainty when making predictions would be very high. On the other extreme, if all labels are true then you would not be surprised at all upon discovering the real label. Your uncertainty would be very low and thus entropy would be 0.

At each step (also called a “split”), there is a question about the data at each node in the tree.

The format of a split is “which data points have a value of less than V for the attribute A”.

1. Attributes that can take any numerical value within a range are called continuous. Both age and years of education are therefore considered continuous. An important first step in the algorithm is to create discretized values for each attribute. To keep things simple, age is discretized in steps of 5 (20, 25, 30, ...) and years of education are discretized in steps of 2 (6, 8, 10, ...). Each discretization forms a candidate for a potential split.

2. Calculate the entropy of the data at the current node using the formula for Shannon information entropy, $H = -\sum_i p_i \log_2 p_i$, where $p_i$ is the proportion of data points belonging to class $i$. This gives an indication of how mixed the dataset is at this level, and the goal is to create the split that reduces the entropy of our child nodes by the most.

3. For each discretization and for each attribute:

3a. Data points that have a value less than V for attribute A are passed to the left node, while data points with a value greater than or equal to V for attribute A are passed to the right node.

3b. Take a weighted average of the entropy of both child nodes where the weightings are equal to the number of data points on the child nodes divided by the number of data points on the parent nodes. This average of child entropies is also called “the conditional entropy”.

3c. Subtract this weighted average from the parent entropy to calculate the information gain. Choose the split that results in the highest information gain (see the sketch after this list). To avoid making excessive splits, make sure to stop if the highest information gain is 0.

4. Iteratively repeat steps 2 and 3 until you have reached the max depth or until all the labels of the data points belong to the same class (remember that this is when entropy equals 0).
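Here is the sketch referred to above: a small Python snippet that computes Shannon entropy for a set of labels and the information gain of a single candidate split. The attribute index and threshold are placeholders that you would loop over for every discretized candidate.

```python
import math

def entropy(labels):
    """Shannon entropy: -sum(p * log2(p)) over the class proportions."""
    total = len(labels)
    result = 0.0
    for cls in set(labels):
        p = labels.count(cls) / total
        result -= p * math.log2(p)
    return result

def information_gain(points, labels, attribute, threshold):
    """Parent entropy minus the weighted (conditional) entropy of the two
    child nodes produced by the split "attribute value < threshold"."""
    left = [lab for pt, lab in zip(points, labels) if pt[attribute] < threshold]
    right = [lab for pt, lab in zip(points, labels) if pt[attribute] >= threshold]
    if not left or not right:
        return 0.0
    conditional = (len(left) / len(labels)) * entropy(left) \
                + (len(right) / len(labels)) * entropy(right)
    return entropy(labels) - conditional

# Candidate split "age < 40" on a few points from the sample table above.
points = [(38, 8), (59, 13), (26, 13), (51, 13), (34, 11), (50, 15)]
labels = [False, True, False, True, False, True]
print(information_gain(points, labels, attribute=0, threshold=40))
```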

What is depth?

The depth of decision trees refers to the number of layers that the tree has. Decision tree depth is a delicate balance, as too much depth could cause overfitting, and too little depth could lead to less accuracy. To ensure that your tree has appropriate depth, there are two commonly used methods:

1. Grow the tree until it overfits, then start cutting back

2. Prevent the tree from growing by stopping it just before it perfectly classifies the training data.

Let's look at how the data performs at different depths below.

Below is a visualization showing the decision tree run on our census dataset! Use the scroller to change the maximum depth of the decision tree.


Now, let’s gain a different perspective on how the decision tree makes predictions. Each line in the graph above represents a split in the decision tree. This graph shows how the decision tree breaks down the data into smaller and smaller rectangles and makes the prediction based on which rectangle the data point lands in.

One interesting thing to note is that as the max depth increases, the decision tree tries to account for the anomalous data point (the 32-year-old with 9 years of education who is a high earner). If we were to increase the max depth past 4, the decision tree would carve out the square containing the anomalous data point and predict all future data points within that square as "High Income". What is this an example of?


Logistic Regression

Contrary to its name, logistic regression is not a regression algorithm but a classification algorithm. It works by performing a linear regression on the attributes and then using the regression output to calculate the probability that a data point belongs to a certain class.

Before understanding logistic regression, one must first understand linear regression. Suppose we have a set of points given as tuples (x, y). We can visualize this set of points in a scatterplot, such as the one shown here:

What if we wanted to find the line that best fits this data? Our data is noisy, so no line will fit it perfectly, but one can try to draw such a line by hand.

Alternatively, one can use a method called “ordinary least squares”, or OLS for short. Since a line can be written as a function, any line can be used to create estimates of an output given input values. Each point has a corresponding residual, which is computed by taking the difference between the actual value and our estimated value. OLS is the method that calculates the slope and intercept of the line that minimizes the sum of the squares of these residual values:

$\min_{m, b} \sum_{x_i \in \text{points}} \left(y_i - (m \cdot x_i + b)\right)^2$

This scary equation pretty much just helps us find the line that best fits our data points. So, for our example, our ordinary least-squares regression line is $Y = 0.5x + 3$.
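As an illustration, here is a small Python sketch that computes the least-squares slope and intercept from the closed-form formulas. The (x, y) points are made up to scatter around y = 0.5x + 3 and are not the points in the plot above.

```python
def ols_fit(xs, ys):
    """Return the slope m and intercept b that minimize the sum of squared
    residuals, using the closed-form least-squares formulas."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Slope: covariance of x and y divided by the variance of x.
    m = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) \
        / sum((x - mean_x) ** 2 for x in xs)
    b = mean_y - m * mean_x
    return m, b

# Made-up noisy points scattered around y = 0.5x + 3.
xs = [1, 2, 4, 6, 8, 10]
ys = [3.6, 3.9, 5.2, 5.8, 7.1, 7.9]
m, b = ols_fit(xs, ys)
print(f"best-fit line: y = {m:.2f}x + {b:.2f}")
```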

Now that we know the basics of linear regression, we can move on to logistic regression! While explaining the theory behind logistic regression is beyond the scope of this article, we will explain the overall concept of how it works.

For a given data point, we can find the OLS estimate using our least-squares regression line equation. This OLS output will be the x-value of the data point. The y-value of the data point is the label, which, in our original example, is whether or not an individual has a high income (we set y = 1 if the data point has a high income, and y = 0 otherwise). For our dataset, we get a regression line of approximately

$Y = -16.4 + 0.25 \cdot \text{Age} + 0.53 \cdot \text{YoE},$

where YoE is short for the “Years of Education” value. We can then plot the points as such:

Notice how the points with X < 0 are usually labeled 0 (low income), while points with X > 0 have a label of 1 (high income). We can see that there is a point with X ≈ -3.593 that is a clear outlier. Looking back at our exploratory data analysis, this is the outlier point of (9, 32), which is labeled as high income despite the fact that all similar points are low income.

Moving beyond the dataset, the natural next question is: how would our model label points it hasn't seen before?

To do so, we fit a sigmoid function on our linear regression line. The sigmoid function is defined as $f(x) = \frac{1}{1+e^{-x}}$ and is plotted on the right.

The sigmoid function applied to a linear regression model provides the probability that we classify the output as a 1 (having high income) given the linear regression output, or $P(y = 1 \mid x)$. The mathematical intuition behind this is also beyond the scope of this article, but these probabilities are then used to classify unknown points. We can set a threshold value and then, given the probability of a point being labeled 1, use the threshold to decide how to classify the data point.

Changing the threshold could affect what we label points as. Feel free to use the slider below to see how big our linear regression output needs to be to label the data as “high income” given different threshold values. Note that the red line is the threshold probability, and the shaded green area shows the corresponding values of X that would be classified as “1”, or having high income.
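To tie the pieces together, here is a small Python sketch that pushes the approximate regression output through the sigmoid and applies a 0.5 threshold, using mystery person #3 (age 30, 10 years of education) as the input. The predict_high_income helper is illustrative, not the article's exact model.

```python
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def predict_high_income(age, yoe, threshold=0.5):
    """Push the (approximate) regression output through the sigmoid and
    compare the resulting probability against the threshold."""
    linear_output = -16.4 + 0.25 * age + 0.53 * yoe
    probability = sigmoid(linear_output)          # P(y = 1 | x)
    return probability, probability >= threshold

# Mystery person #3: age 30, 10 years of education.
# Linear output = -16.4 + 7.5 + 5.3 = -3.6, so the probability is about
# 0.027 and the point is labeled low income at a 0.5 threshold.
p, label = predict_high_income(30, 10)
print(f"P(high income) = {p:.3f}, predicted high income: {label}")
```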


Model Selection

Now that we have introduced the three classification models, how do we decide which one to use?

Note that for the earlier parts of this article, we only ran the algorithms on 32 data points. However, the dataset these points originally came from actually has 32,561 rows! If we run the algorithms on all of the rows (using the best hyperparameters for each model), we get the following accuracies:

Algorithm           | Training Accuracy | Test Accuracy
--------------------|-------------------|--------------
KNN                 | 0.794             | 0.787
Decision Tree       | 0.799             | 0.784
Logistic Regression | 0.783             | 0.778


We see that for each algorithm, the training and test accuracies are close to each other, which suggests that our models do not exhibit overfitting. In addition, the accuracies are quite close across the different algorithms! This means that, when looking at accuracy alone, the performance of these models is similar. As a future step, one could look at other measures of performance, such as precision and recall, to determine which model is best. Other factors, such as how interpretable a model is, may also be considerations when selecting models in the future.
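If you want to try a comparison like this yourself, the sketch below shows one way it could be set up with scikit-learn. The file name, column names, and hyperparameter values are placeholders rather than the exact settings used to produce the table above.

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression

# Hypothetical CSV with age, years of education, and a high-income flag.
data = pd.read_csv("census.csv")
X = data[["age", "education_years"]]
y = data["high_income"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0)

models = {
    "KNN": KNeighborsClassifier(n_neighbors=15),          # placeholder hyperparameters
    "Decision Tree": DecisionTreeClassifier(max_depth=4),
    "Logistic Regression": LogisticRegression(),
}
for name, model in models.items():
    model.fit(X_train, y_train)
    print(name,
          "train:", round(model.score(X_train, y_train), 3),
          "test:", round(model.score(X_test, y_test), 3))
```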

Conclusion

Great Job!

We hope that you now have a good understanding of the basics of three different machine learning classification algorithms. As you saw, these algorithms learn from the input data they are given and use it to classify new observations. These algorithms are used to perform analytical tasks that would take humans hundreds more hours to perform!

Machine learning has applications in many fields, such as medicine, defense, finance, and technology, just to name a few. Classification algorithms are at the heart of many innovations like image recognition, speech recognition, self-driving cars, email spam filtering, and e-commerce product recommendations. We hope you are inspired to further your knowledge of these topics!