# student-data.py (2.47 KiB)
# Authored by thibault.capt
# Importation des bibliothèques
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
def manhattan_distance(x1: np.ndarray, x2: np.ndarray) -> float:
    """Return the Manhattan (L1) distance between two arrays of equal shape.

    Args:
        x1: First array.
        x2: Second array; must have the same shape as ``x1``.

    Returns:
        The sum of the absolute element-wise differences, as a Python float.
        Two empty arrays are at distance ``0.0``.

    Raises:
        ValueError: If ``x1`` and ``x2`` do not have the same shape. A
            mismatch used to yield 0, i.e. the smallest possible distance,
            which made incompatible inputs look identical.
    """
    if x1.shape != x2.shape:
        raise ValueError(
            f"shape mismatch: {x1.shape} vs {x2.shape}"
        )
    # Vectorised so that arrays of any dimensionality are handled, not only
    # 1-D ones as with an index loop over range(x1.size).
    return float(np.abs(x1 - x2).sum())
if __name__ == '__main__':
    # Load the dataset; the first CSV row holds the column names.
    dataset = pd.read_csv("Data/student-data-test.csv", header=0)
    # Use every column except the first one as features. Cast to float so
    # that the centroid means written back below are not truncated when the
    # CSV columns are parsed as integers.
    X = dataset.iloc[:, 1:].to_numpy(dtype=float)

    k = 3
    max_iter = 100

    # Initialise the k centroids with k distinct rows of X drawn at random.
    # Fancy indexing returns a copy, so updating the centroids never alters X.
    centroids = X[np.random.choice(X.shape[0], k, replace=False)]

    # One entry per iteration, plotted once the loop has finished.
    total_variances = []

    # K-Means algorithm
    for i in range(max_iter):
        # Start every iteration with empty clusters.
        clusters = [[] for _ in range(k)]

        # Assign each point to the cluster of its nearest centroid
        # (Manhattan distance; ties go to the lowest cluster index).
        for point in X:
            distances = [manhattan_distance(point, centroid) for centroid in centroids]
            clusters[int(np.argmin(distances))].append(point)

        # Keep the previous centroids for the convergence check below.
        old_centroids = centroids.copy()

        # Move each centroid to the mean of its cluster; a centroid whose
        # cluster is empty stays where it is.
        for j, members in enumerate(clusters):
            if members:
                centroids[j] = np.mean(members, axis=0)

        # Total "variance" of this iteration: sum of the squared Manhattan
        # distances between every point and its cluster's updated centroid.
        total_variance = sum(
            (
                manhattan_distance(point, centroids[j]) ** 2
                for j, members in enumerate(clusters)
                for point in members
            ),
            0.0,
        )
        total_variances.append(total_variance)

        # Plot the clusters and centroids of this iteration, using the first
        # two feature columns as the x and y axes.
        for cluster_id, centroid in enumerate(centroids):
            cluster_points = np.array(clusters[cluster_id])
            # An empty cluster becomes a 1-D empty array that cannot be
            # indexed by column, so only its centroid is drawn.
            if cluster_points.size:
                plt.scatter(cluster_points[:, 0], cluster_points[:, 1], label=f"Cluster {cluster_id}")
            plt.scatter(centroid[0], centroid[1], marker="+", s=200, c="red", label=f"Centroïde {cluster_id + 1}")
        plt.xlabel("Grade 1")
        plt.ylabel("Grade 2")
        plt.title(f"Clustering K-Means {i}")
        plt.legend()
        plt.show()

        # Converged: no centroid moved during this iteration.
        if np.array_equal(old_centroids, centroids):
            break

    # Plot the total variance recorded at each iteration.
    plt.plot(range(1, len(total_variances) + 1), total_variances, marker='o')
    plt.xlabel('Itération')
    plt.ylabel('Variance totale')
    plt.title('Variance totale à chaque itération de K-Means')
    plt.grid(True)
    plt.show()