diff --git a/iris.py b/iris.py index 19f11820e617391f631656449c1ff13746e63c0e..a95c1ce4e387ffa4aeb4426dea76395a2b7aea7b 100644 --- a/iris.py +++ b/iris.py @@ -1,7 +1,7 @@ import pandas as pd import numpy as np from collections import Counter - +import matplotlib.pyplot as plt def manhattan_distance(x1: np.ndarray, x2: np.ndarray) -> float: distance = 0 @@ -26,6 +26,7 @@ if __name__ == '__main__': centroids = X[np.random.choice(X.shape[0], k, replace=False)] # Nombre maximal d'itérations max_iter = 100 + total_variances = [] # Algorithme K-Means for i in range(max_iter): @@ -45,15 +46,17 @@ if __name__ == '__main__': if len(clusters[j]) > 0: centroids[j] = np.mean(clusters[j], axis=0) + # Calculer la variance totale à cette itération + total_variance = 0.0 + for j in range(k): + for point in clusters[j]: + total_variance += manhattan_distance(point, centroids[j]) ** 2 + + total_variances.append(total_variance) + # Convergence ? if np.all(old_centroids == centroids): break - # Calculer la variance totale - total_variance = 0.0 - for cluster_index, cluster_points in enumerate(clusters): - cluster_center = centroids[cluster_index] - for point in cluster_points: - total_variance += manhattan_distance(point, cluster_center) ** 2 # Calculer le taux de classification par cluster et la classe majoritaire par cluster cluster_classifications = {} @@ -70,4 +73,12 @@ if __name__ == '__main__': # Afficher le taux de classification par cluster et la classe majoritaire par cluster for cluster_index in range(k): print(f'Cluster {cluster_index + 1} - Taux de classification : {cluster_classifications[cluster_index]}') - print(f'Cluster {cluster_index + 1} - Classe majoritaire : {cluster_majority_class[cluster_index]}') \ No newline at end of file + print(f'Cluster {cluster_index + 1} - Classe majoritaire : {cluster_majority_class[cluster_index]}') + + # Afficher les variances totales à chaque itération + plt.plot(range(1, len(total_variances) + 1), total_variances, marker='o') + plt.xlabel('Itération') + plt.ylabel('Variance totale') + plt.title('Variance totale à chaque itération de K-Means') + plt.grid(True) + plt.show()