diff --git a/src/cluster.c b/src/cluster.c
index bfb714912253863941f169aafc1ee9158897d5a6..fbc62fbdba7889f26509cfc0184be8ff91ce74d6 100644
--- a/src/cluster.c
+++ b/src/cluster.c
@@ -1,4 +1,5 @@
 #include "cluster.h"
+#include <stdbool.h>
 #include "vector.h"
 
 
@@ -43,26 +44,72 @@ void cluster_add_point_fpt(cluster_fpt_t* cluster, vector_fpt_t* point) {
 }
 
 
-void cluster_update_center_int(cluster_int_t* cluster) {
-    vector_destroy_int(cluster->center);
+// Recomputes the center of `cluster` from the points in its list (sum, then
+// division by the point count). Returns true if any coordinate changed.
+// A cluster whose point list is empty keeps its center and returns false.
+bool cluster_update_center_int(cluster_int_t* cluster) {
+    // save old center
+    vector_int_t* old_center = cluster->center;
     list_points_node_int_t* node = cluster->points->head;
+    // nothing to average: bail out before node is dereferenced
+    if (node == NULL) {
+        return false;
+    }
+    // create new center
     cluster->center = vector_create_int(node->point->dim);
+    // sum all values in center
     while (node != NULL) {
         vector_add_inplace_int(cluster->center, *(node->point));
         node = node->next;
     }
+    // divide by number of points
     vector_div_inplace_int(cluster->center, (int_t) cluster->points->size);
+    // check whether center has changed
+    bool changed = false;
+    for (size_t p = 0; p < cluster->center->dim; ++p) {
+        if (cluster->center->data[p] != old_center->data[p]) {
+            changed = true;
+            break;
+        }
+    }
+    // destroy old center
+    vector_destroy_int(old_center);
+    // return true if center has changed
+    return changed;
 }
 
 
-void cluster_update_center_fpt(cluster_fpt_t* cluster) {
-    vector_destroy_fpt(cluster->center);
+// Recomputes the center of `cluster` from the points in its list (sum, then
+// division by the point count). Returns true if any coordinate changed.
+// A cluster whose point list is empty keeps its center and returns false.
+bool cluster_update_center_fpt(cluster_fpt_t* cluster) {
+    // save old center
+    vector_fpt_t* old_center = cluster->center;
     list_points_node_fpt_t* node = cluster->points->head;
+    // nothing to average: bail out before node is dereferenced
+    if (node == NULL) {
+        return false;
+    }
+    // create new center
     cluster->center = vector_create_fpt(node->point->dim);
+    // sum all values in center
     while (node != NULL) {
         vector_add_inplace_fpt(cluster->center, *(node->point));
         node = node->next;
     }
+    // divide by number of points
     vector_div_inplace_fpt(cluster->center, (fpt_t) cluster->points->size);
+    // check whether center has changed
+    bool changed = false;
+    for (size_t p = 0; p < cluster->center->dim; ++p) {
+        if (cluster->center->data[p] != old_center->data[p]) {
+            changed = true;
+            break;
+        }
+    }
+    // destroy old center
+    vector_destroy_fpt(old_center);
+    // return true if center has changed
+    return changed;
 }
 
diff --git a/src/cluster.h b/src/cluster.h
index 0e2993aff11b27d54429e703540872263be5ddfc..0888785b55ebec639c3f378c68bb8a73a807af91 100644
--- a/src/cluster.h
+++ b/src/cluster.h
@@ -1,6 +1,7 @@
 #ifndef PROG_KMEANS_CLUSTER_H
 #define PROG_KMEANS_CLUSTER_H
 
+#include <stdbool.h>
 #include <stdlib.h>
 #include "linkedlist.h"
 #include "vector.h"
@@ -32,9 +33,11 @@ void cluster_add_point_int(cluster_int_t* cluster, vector_int_t* point);
 
 void cluster_add_point_fpt(cluster_fpt_t* cluster, vector_fpt_t* point);
 
-void cluster_update_center_int(cluster_int_t* cluster);
+/** Recomputes the center from the cluster's points; returns true if it changed. */
+bool cluster_update_center_int(cluster_int_t* cluster);
 
-void cluster_update_center_fpt(cluster_fpt_t* cluster);
+/** Recomputes the center from the cluster's points; returns true if it changed. */
+bool cluster_update_center_fpt(cluster_fpt_t* cluster);
 
 void cluster_reset_int(cluster_int_t* cluster);
 
diff --git a/src/kmeans.c b/src/kmeans.c
index a8b6df1d0b1af24f185191ba89e35b804a9c9fe9..11795413e2ba0334bff5083581f836f20c14f239 100644
--- a/src/kmeans.c
+++ b/src/kmeans.c
@@ -90,18 +90,77 @@ cluster_fpt_t** kmeans_init_clusters_fpt(const vector_fpt_t** points, const size
 }
 
 
-void kmeans_int(vector_int_t** points, const size_t point_count, cluster_int_t** clusters, const size_t nb_clusters, fpt_t (* distance_function)(const vector_int_t*, const vector_int_t*)) {
-    //TODO
+// Adds every point to the cluster whose center has the smallest distance, then
+// recomputes the centers; repeats until a pass leaves every center unchanged.
+void kmeans_int(vector_int_t** points, const size_t point_count, cluster_int_t** clusters, const size_t nb_clusters,
+                fpt_t (* distance_function)(const vector_int_t*, const vector_int_t*)) {
+    // clusters[0] is read unconditionally below, so there must be at least one
+    if (nb_clusters == 0) {
+        return;
+    }
     bool changed = true;
     while (changed) {
         changed = false;
+        // NOTE(review): nothing visible here removes the points added in earlier
+        // passes; presumably cluster_reset_int() belongs here - confirm.
+        for (size_t i = 0; i < point_count; ++i) {
+            vector_int_t* point = points[i];
+            // find closest cluster and add point to it
+            cluster_int_t* cmin = clusters[0];
+            // distance_function returns fpt_t; an int_t here would truncate it
+            fpt_t dmin = distance_function(point, cmin->center);
+            for (size_t k = 1; k < nb_clusters; ++k) {
+                cluster_int_t* current_cluster = clusters[k];
+                fpt_t dist = distance_function(point, current_cluster->center);
+                if (dist < dmin) {
+                    cmin = current_cluster;
+                    dmin = dist;
+                }
+            }
+            cluster_add_point_int(cmin, point);
+        }
+        // update all cluster centers once the whole pass has been assigned
+        for (size_t k = 0; k < nb_clusters; ++k) {
+            if (cluster_update_center_int(clusters[k])) {
+                changed = true;
+            }
+        }
     }
 }
 
-void kmeans_fpt(vector_fpt_t** points, const size_t point_count, cluster_fpt_t** clusters, const size_t nb_clusters, fpt_t (* distance_function)(const vector_fpt_t*, const vector_fpt_t*)) {
-    //TODO
+// Adds every point to the cluster whose center has the smallest distance, then
+// recomputes the centers; repeats until a pass leaves every center unchanged.
+void kmeans_fpt(vector_fpt_t** points, const size_t point_count, cluster_fpt_t** clusters, const size_t nb_clusters,
+                fpt_t (* distance_function)(const vector_fpt_t*, const vector_fpt_t*)) {
+    // clusters[0] is read unconditionally below, so there must be at least one
+    if (nb_clusters == 0) {
+        return;
+    }
     bool changed = true;
     while (changed) {
         changed = false;
+        // NOTE(review): nothing visible here removes the points added in earlier
+        // passes; presumably cluster_reset_fpt() belongs here - confirm.
+        for (size_t i = 0; i < point_count; ++i) {
+            vector_fpt_t* point = points[i];
+            // find closest cluster and add point to it
+            cluster_fpt_t* cmin = clusters[0];
+            fpt_t dmin = distance_function(point, cmin->center);
+            for (size_t k = 1; k < nb_clusters; ++k) {
+                cluster_fpt_t* current_cluster = clusters[k];
+                fpt_t dist = distance_function(point, current_cluster->center);
+                if (dist < dmin) {
+                    cmin = current_cluster;
+                    dmin = dist;
+                }
+            }
+            cluster_add_point_fpt(cmin, point);
+        }
+        // update all cluster centers once the whole pass has been assigned
+        for (size_t k = 0; k < nb_clusters; ++k) {
+            if (cluster_update_center_fpt(clusters[k])) {
+                changed = true;
+            }
+        }
     }
 }