Created
August 14, 2017 23:42
-
-
Save thearchduke/e8226c0cd3c74cc6c31b60fa2d3085ea to your computer and use it in GitHub Desktop.
Basic k-nearest neighbor classifier in standard-library Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
Example python implementation of K-nearest neighbor machine learning algorithm | |
''' | |
from collections import Counter | |
import random | |
class KNNClassifier(object):
    """K-nearest neighbor classifier over labeled numeric vectors.

    ``raw_data`` rows are formatted like so::

        [(x1, y1, z1, class), (x2, y2, z2, class), ...]

    Every element except the last must be numeric; the last element is the
    class label. On construction the rows are shuffled and split into a
    training set (searched for neighbors) and a testing set (scored by
    :attr:`accuracy`).
    """

    def __init__(self, raw_data, distance='manhattan', k=3, training_split=0.7,
                 rng=None):
        """Copy, shuffle and split ``raw_data`` and configure the classifier.

        Args:
            raw_data: Iterable of labeled rows (label in the last position).
                It is copied into a new list, so the caller's data is never
                reordered and tuples/generators are accepted.
            distance: Name of the distance metric; only ``'manhattan'`` is
                supported.
            k: Number of nearest training points that vote on a prediction.
            training_split: Fraction of rows placed in the training set; the
                remaining rows form the testing set.
            rng: Optional ``random.Random`` instance used for the shuffle so
                splits can be reproduced. Defaults to the ``random`` module's
                shared generator.

        Raises:
            NotImplementedError: If ``distance`` is not ``'manhattan'``.
            ValueError: If ``k`` is less than 1 or greater than the number of
                rows in ``raw_data``.
        """
        if distance != 'manhattan':
            raise NotImplementedError()
        # list() both copies the data and normalizes tuples/iterables into a
        # mutable sequence that shuffle() can reorder in place.
        my_raw_data = list(raw_data)
        if k < 1:
            raise ValueError("KNN requires a k-value of at least 1")
        if k > len(my_raw_data):
            raise ValueError("KNN cannot have k-value greater than dataset size")
        shuffler = rng if rng is not None else random
        shuffler.shuffle(my_raw_data)
        split_at = int(training_split * len(my_raw_data))
        self.training_data = my_raw_data[:split_at]
        self.testing_data = my_raw_data[split_at:]
        self.k = k
        self.distance_func = self.manhattan_distance

    def manhattan_distance(self, p1, p2):
        """Return the L1 distance between the feature parts of two rows.

        The last element of ``p1`` is treated as its class label and skipped,
        so only the first ``len(p1) - 1`` coordinates are compared. ``p2`` is
        indexed at those same positions, so it must be at least that long.
        The result is always a float.
        """
        n_features = len(p1) - 1
        return float(sum(abs(p1[i] - p2[i]) for i in range(n_features)))

    def get_neighbors(self, test_point):
        """Return the class labels of the ``k`` training rows nearest ``test_point``.

        ``test_point`` must have the same layout as a ``raw_data`` row: its
        trailing label slot is ignored by the distance function. Labels are
        ordered nearest-first. ``sorted`` is stable, so rows at equal distance
        keep their training-set order. If the training set has fewer than
        ``k`` rows, every training label is returned.
        """
        scored = sorted(
            ((self.distance_func(test_point, row), row[-1])
             for row in self.training_data),
            key=lambda pair: pair[0],
        )
        # Exactly k voters; the previous ``[:self.k + 1]`` took one too many.
        return [label for _, label in scored[:self.k]]

    def get_classification(self, test_point):
        """Predict the class of ``test_point`` by majority vote of its neighbors.

        On Python 3.7+ ``Counter.most_common`` orders equal counts by first
        insertion. Neighbor labels arrive nearest-first, so a tied vote is won
        by the class of the nearest tied neighbor.

        Raises:
            ValueError: If the training set is empty, so no neighbors can vote.
        """
        votes = Counter(self.get_neighbors(test_point))
        if not votes:
            raise ValueError("KNN cannot classify without training data")
        return votes.most_common(1)[0][0]

    @property
    def accuracy(self):
        """Fraction of testing rows whose predicted class equals their label.

        Raises:
            ValueError: If the testing set is empty (e.g. ``training_split=1``).
        """
        if not self.testing_data:
            raise ValueError("KNN accuracy requires a non-empty testing set")
        correct = sum(1 for point in self.testing_data
                      if self.get_classification(point) == point[-1])
        return float(correct) / len(self.testing_data)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment