From aeac2c49966e4585cdd5a3128917b5531b8185c3 Mon Sep 17 00:00:00 2001
From: Ivan Pavlovich <ivan.pavlovic@hes-so.ch>
Date: Fri, 14 Mar 2025 03:41:24 +0100
Subject: [PATCH] More precise token calculation
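
Split the per-tokenizer token counts into keyword-filter categories
(NO KEYWORDS, KEYWORDS, SUBHEADINGS, SITE PROPOSITION, PROPOSITION)
driven by each article's MeSH terms, and regenerate
doc/token_count.json with the per-category day/week/month statistics.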

---
 models/LLM/Tokenizer/doc/token_count.json |  96 +++++++++++++---
 models/LLM/Tokenizer/token_count.py       | 130 ++++++++++++++++++----
 2 files changed, 192 insertions(+), 34 deletions(-)

diff --git a/models/LLM/Tokenizer/doc/token_count.json b/models/LLM/Tokenizer/doc/token_count.json
index 28282848c..cbbf029ee 100644
--- a/models/LLM/Tokenizer/doc/token_count.json
+++ b/models/LLM/Tokenizer/doc/token_count.json
@@ -1,20 +1,90 @@
 {
     "bert-base-uncased": {
-        "day": {
-            "min": 0,
-            "max": 336390,
-            "mean": 55947.34222222222
+        "ALL": 62940760,
+        "NO KEYWORDS": {
+            "day": {
+                "min": 0,
+                "max": 336390,
+                "mean": 55947.34222222222
+            },
+            "week": {
+                "min": 0,
+                "max": 610773,
+                "mean": 390936.39751552796
+            },
+            "month": {
+                "min": 149220,
+                "max": 1988608,
+                "mean": 1701101.6216216215
+            }
         },
-        "week": {
-            "min": 0,
-            "max": 610773,
-            "mean": 390936.39751552796
+        "KEYWORDS": {
+            "day": {
+                "min": 0,
+                "max": 14061,
+                "mean": 2494.1857777777777
+            },
+            "week": {
+                "min": 0,
+                "max": 28111,
+                "mean": 17428.316770186335
+            },
+            "month": {
+                "min": 12058,
+                "max": 105204,
+                "mean": 75836.72972972973
+            }
         },
-        "month": {
-            "min": 149220,
-            "max": 1988608,
-            "mean": 1701101.6216216215
+        "SUBHEADINGS": {
+            "day": {
+                "min": 0,
+                "max": 14061,
+                "mean": 2494.1857777777777
+            },
+            "week": {
+                "min": 0,
+                "max": 28111,
+                "mean": 17428.316770186335
+            },
+            "month": {
+                "min": 12058,
+                "max": 105204,
+                "mean": 75836.72972972973
+            }
         },
-        "ALL": 62940760
+        "SITE PROPOSITION": {
+            "day": {
+                "min": 0,
+                "max": 17409,
+                "mean": 3292.2702222222224
+            },
+            "week": {
+                "min": 0,
+                "max": 36705,
+                "mean": 23004.993788819876
+            },
+            "month": {
+                "min": 13250,
+                "max": 124682,
+                "mean": 100102.81081081081
+            }
+        },
+        "PROPOSITION": {
+            "day": {
+                "min": 0,
+                "max": 24471,
+                "mean": 4493.711111111111
+            },
+            "week": {
+                "min": 0,
+                "max": 49793,
+                "mean": 31400.155279503106
+            },
+            "month": {
+                "min": 17661,
+                "max": 172341,
+                "mean": 136633.1081081081
+            }
+        }
     }
 }
\ No newline at end of file
diff --git a/models/LLM/Tokenizer/token_count.py b/models/LLM/Tokenizer/token_count.py
index 98f377612..051857de6 100644
--- a/models/LLM/Tokenizer/token_count.py
+++ b/models/LLM/Tokenizer/token_count.py
@@ -8,6 +8,7 @@ from datetime import datetime, timedelta
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
 
 from testModel.utils import get_article_data
+from variables.pubmed import NCDS_MESH_TERM, KEYWORDS_MESH_TERM, KEYWORDS_MESH_SUBHEADING, KEYWORDS_MESH_SITE_PROPOSITION, KEYWORDS_MESH_PROPOSITION
 
 DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../dataSources/PubMed/data"))
 DOC_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./doc"))
@@ -18,10 +19,40 @@ INTERVALS = [
     "month"
 ]
 
+CATEGORIES = [
+    "NO KEYWORDS",
+    "KEYWORDS",
+    "SUBHEADINGS",
+    "SITE PROPOSITION",
+    "PROPOSITION"
+]
+
 TOKENIZERS = [
     "bert-base-uncased"
 ]
 
+def lower_keywords(mesh_terms):
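+    """Return the dict's MeSH terms lower-cased, keeping multi-part terms (lists) as lists of lower-cased parts."""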
+    res = []
+
+    for _, mesh_term in mesh_terms.items():
+        if isinstance(mesh_term, list):
+            res.append([part.lower() for part in mesh_term])
+        else:
+            res.append(mesh_term.lower())
+    
+    return res
+
+def mesh_term_present(article_mesh_terms, mesh_term):
+    """Return True if mesh_term occurs in the article's MeSH terms; a multi-part term (list) matches only if every part is present."""
+    if isinstance(mesh_term, list):
+        return all(part in article_mesh_terms for part in mesh_term)
+    return mesh_term in article_mesh_terms
+
 def get_date_indices(date, start_date):
     delta_days = (date - start_date).days
     day_index = delta_days
@@ -33,6 +64,20 @@ def get_date_indices(date, start_date):
 
     return day_index, week_index, month_index
 
+def add_num_token(article_date, start_date, token_num, counts, tokenizer_name, category):
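+    """Add token_num to the day, week and month buckets of the given tokenizer/category for the article's date."""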
+    day_index, week_index, month_index = get_date_indices(article_date, start_date)
+
+    counts[tokenizer_name][category]["day"][day_index] += token_num 
+    counts[tokenizer_name][category]["week"][week_index] += token_num
+    counts[tokenizer_name][category]["month"][month_index] += token_num
+
+
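+# Lower-case the NCD and keyword MeSH term lists once up front so that membership
+# checks against the article's (also lower-cased) MeSH terms are case-insensitive.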
+ncds_mesh_terms = [mesh_term.lower() for _, mesh_term in NCDS_MESH_TERM.items()]
+keywords_mesh_terms = lower_keywords(KEYWORDS_MESH_TERM)
+keywords_subheading_mesh_terms = lower_keywords(KEYWORDS_MESH_SUBHEADING)
+keywords_site_proposition_mesh_terms = lower_keywords(KEYWORDS_MESH_SITE_PROPOSITION)
+keywords_proposition_mesh_terms = lower_keywords(KEYWORDS_MESH_PROPOSITION)
+
 
 file_path = f"{DATA_DIR}/save_3_years.json"
 with open(file_path, "r", encoding="utf-8") as file:
@@ -43,12 +88,14 @@ print(len(data))
 counts = {}
 
 for tokenizer_name in TOKENIZERS:
-    counts[tokenizer_name] = {
-        "day": {},
-        "week": {},
-        "month": {},
-        "ALL": 0
-    }
+    counts[tokenizer_name] = {}
+    counts[tokenizer_name]["ALL"] = 0
+    for category in CATEGORIES:
+        counts[tokenizer_name][category] = {
+            "day": {},
+            "week": {},
+            "month": {},
+        }
 
 start_date = datetime(2022, 1, 1)
 end_date = datetime(2025, 1, 30)
@@ -58,9 +105,10 @@ while(current_date < end_date):
     day_index, week_index, month_index = get_date_indices(current_date, start_date)
 
     for tokenizer_name in TOKENIZERS:
-        counts[tokenizer_name]["day"][day_index] = 0
-        counts[tokenizer_name]["week"][week_index] = 0
-        counts[tokenizer_name]["month"][month_index] = 0
+        for category in CATEGORIES:
+            counts[tokenizer_name][category]["day"][day_index] = 0
+            counts[tokenizer_name][category]["week"][week_index] = 0
+            counts[tokenizer_name][category]["month"][month_index] = 0
 
     current_date += timedelta(days=1)
 
@@ -74,29 +122,69 @@ for tokenizer_name in TOKENIZERS:
     for article in data:
         print(f"Article N°{i}")
 
+        article_mesh_terms = [mesh_term.lower() for mesh_term in article["MeshTerms"]]
         article_date = datetime(int(article["Date"]["Year"]), int(article["Date"]["Month"]), int(article["Date"]["Day"]))
         title, abstract = get_article_data(article)
 
         tokens = tokenizer(title+abstract, return_tensors="pt")
         num_tokens = len(tokens["input_ids"][0])
 
-        day_index, week_index, month_index = get_date_indices(article_date, start_date)
-
-        counts[tokenizer_name]["day"][day_index] += num_tokens
-        counts[tokenizer_name]["week"][week_index] += num_tokens
-        counts[tokenizer_name]["month"][month_index] += num_tokens
+        add_num_token(article_date, start_date, num_tokens, counts, tokenizer_name, "NO KEYWORDS")
         counts[tokenizer_name]["ALL"] += num_tokens
 
+        added = False
+
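+        # Only articles tagged with an NCD MeSH term are matched against the keyword
+        # lists, from the most specific (KEYWORDS) to the most general (PROPOSITION);
+        # a match in a more specific list also counts toward every broader category,
+        # and each article is added at most once thanks to the `added` flag.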
+        for ncd in ncds_mesh_terms:
+            if added:
+                break
+
+            if mesh_term_present(article_mesh_terms, ncd):
+
+                for keyword in keywords_mesh_terms:
+                    if added:
+                        break
+                    if mesh_term_present(article_mesh_terms, keyword):
+                        add_num_token(article_date, start_date, num_tokens, counts, tokenizer_name, "KEYWORDS")
+                        add_num_token(article_date, start_date, num_tokens, counts, tokenizer_name, "SUBHEADINGS")
+                        add_num_token(article_date, start_date, num_tokens, counts, tokenizer_name, "SITE PROPOSITION")
+                        add_num_token(article_date, start_date, num_tokens, counts, tokenizer_name, "PROPOSITION")
+                        added = True
+
+                for keyword in keywords_subheading_mesh_terms:
+                    if added:
+                        break
+                    if mesh_term_present(article_mesh_terms, keyword):
+                        add_num_token(article_date, start_date, num_tokens, counts, tokenizer_name, "SUBHEADINGS")
+                        add_num_token(article_date, start_date, num_tokens, counts, tokenizer_name, "SITE PROPOSITION")
+                        add_num_token(article_date, start_date, num_tokens, counts, tokenizer_name, "PROPOSITION")
+                        added = True
+
+                for keyword in keywords_site_proposition_mesh_terms:
+                    if added:
+                        break
+                    if mesh_term_present(article_mesh_terms, keyword):
+                        add_num_token(article_date, start_date, num_tokens, counts, tokenizer_name, "SITE PROPOSITION")
+                        add_num_token(article_date, start_date, num_tokens, counts, tokenizer_name, "PROPOSITION")
+                        added = True
+
+                for keyword in keywords_proposition_mesh_terms:
+                    if added:
+                        break
+                    if mesh_term_present(article_mesh_terms, keyword):
+                        add_num_token(article_date, start_date, num_tokens, counts, tokenizer_name, "PROPOSITION")
+                        added = True
+
         i += 1
 
-    for interval in INTERVALS:
-        counts[tokenizer_name][interval] = [val for _, val in counts[tokenizer_name][interval].items()]
+    for category in CATEGORIES:
+        for interval in INTERVALS:
+            counts[tokenizer_name][category][interval] = [val for _, val in counts[tokenizer_name][category][interval].items()]
 
-        counts[tokenizer_name][interval] = {
-            "min": min(counts[tokenizer_name][interval]),
-            "max": max(counts[tokenizer_name][interval]),
-            "mean": statistics.mean(counts[tokenizer_name][interval])
-        }
+            counts[tokenizer_name][category][interval] = {
+                "min": min(counts[tokenizer_name][category][interval]),
+                "max": max(counts[tokenizer_name][category][interval]),
+                "mean": statistics.mean(counts[tokenizer_name][category][interval])
+            }
 
 with open(f"{DOC_DIR}/token_count.json", "w") as json_file:
     json.dump(counts, json_file, indent=4)
-- 
GitLab