From 9265faccf53706f61c043b21754dbf9736dc042f Mon Sep 17 00:00:00 2001 From: "thibault.capt" <thibault.capt@etu.hesge.ch> Date: Mon, 9 Oct 2023 18:35:41 +0200 Subject: [PATCH] show variance --- main.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 63cc38f..03e4755 100644 --- a/main.py +++ b/main.py @@ -21,7 +21,7 @@ if __name__ == '__main__': centroids = X[np.random.choice(X.shape[0], k, replace=False)] max_iter = 100 - + total_variances = [] # Algorithme K-Means for i in range(max_iter): # Créer des clusters vides @@ -40,6 +40,14 @@ if __name__ == '__main__': if len(clusters[j]) > 0: centroids[j] = np.mean(clusters[j], axis=0) + # Calculer la variance totale à cette itération + total_variance = 0.0 + for j in range(k): + for point in clusters[j]: + total_variance += manhattan_distance(point, centroids[j]) ** 2 + + total_variances.append(total_variance) + # Afficher les centroides finaux et les clusters for iteration, centroid in enumerate(centroids): cluster_points = np.array(clusters[iteration]) @@ -54,4 +62,12 @@ if __name__ == '__main__': # Convergence ? if np.all(old_centroids == centroids): - break \ No newline at end of file + break + + # Afficher les variances totales à chaque itération + plt.plot(range(1, len(total_variances) + 1), total_variances, marker='o') + plt.xlabel('Itération') + plt.ylabel('Variance totale') + plt.title('Variance totale à chaque itération de K-Means') + plt.grid(True) + plt.show() -- GitLab