Skip to content
Snippets Groups Projects
Commit f9da9f5b authored by ivan.pavlovic's avatar ivan.pavlovic
Browse files

Test sur les LLM et les zero-shot avec le nouveau dataset

parent f1f4e18e
No related branches found
No related tags found
No related merge requests found
---------------------------------
MODEL: microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract
TRESHOLD: 0.7
---------------------------------
---------------------------------
PMID: 39737510
Predictions: ['Noncommunicable Diseases']
MeshTerm: ['Humans', 'Adult', 'Middle Aged', 'India', 'Male', 'Female', 'Noncommunicable Diseases', 'Risk Factors', 'Aged', 'Poverty Areas', 'Adolescent', 'Prevalence', 'Smoking', 'Alcohol Drinking', 'Obesity', 'Cross-Sectional Studies', 'Young Adult', 'Urban Population']
import sys
import os
# Ajouter le répertoire parent au chemin de recherche
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from api.parser.jsonParser import parseJsonFile
import time
from model.cohereCommand import cohere_classify
# Candidate labels handed to the classifier by predict(). test_model() also
# turns each label into a dataset file name (spaces -> "_", lower-cased).
LABELS = [
"Noncommunicable Diseases",
"Diabetes",
"Cancer",
"Chronic respiratory disease",
"Cardiovascular diseases",
"Mental Health",
"Diabetes type 1",
"Diabetes type 2"
]
def predict(article):
    """Classify one article with the Cohere model.

    The text sent to the classifier is the concatenation of the article's
    "Title", "ArticleTitle" and "Abstract" values, in that order and with no
    separator. Any of the three values that is not a string is treated as an
    empty string, so a missing/non-text field cannot break the concatenation.

    Args:
        article: Object indexable by the keys "Title", "ArticleTitle" and
            "Abstract" (presumably a dict produced by parseJsonFile -- verify
            against the parser).

    Returns:
        Whatever ``cohere_classify`` returns for the assembled text and
        ``LABELS``.

    Raises:
        KeyError: If ``article`` is a dict that lacks one of the three keys.
    """
    fields = (article["Title"], article["ArticleTitle"], article["Abstract"])
    # Same guard for all three fields: only real strings take part in the text.
    text = "".join(field for field in fields if isinstance(field, str))
    return cohere_classify(text, LABELS)
def get_wanted_predictions(article):
    """Build the expected outcome for every label in ``LABELS``.

    Args:
        article: Object whose "Predictions" entry supports the ``in`` test.

    Returns:
        A dict keyed by label; the value is True when the label is found in
        ``article["Predictions"]`` and False otherwise.
    """
    return {name: name in article["Predictions"] for name in LABELS}
def confusion_matrix(wanted, prediction):
    """Count agreements and disagreements between expected and predicted labels.

    The result is a 2x2 list indexed as ``[row][col]`` where row 0 means
    "predicted" and row 1 "not predicted", column 0 means "wanted" and
    column 1 "not wanted":

        [[wanted & predicted,     not wanted & predicted],
         [wanted & not predicted, not wanted & not predicted]]

    Args:
        wanted: Dict mapping each label to its expected truth value.
        prediction: Dict mapping the same labels to the predicted truth value.

    Returns:
        The 2x2 list of counts described above.
    """
    counts = [[0, 0], [0, 0]]
    for name, expected in wanted.items():
        row = 0 if prediction[name] else 1
        col = 0 if expected else 1
        counts[row][col] += 1
    return counts
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    if denominator == 0:
        return float("nan")
    return numerator / denominator


def test_model():
    """Run the Cohere classifier over every dataset file and log the results.

    For each label in ``LABELS`` the matching ``./dataset/<label>.json`` file
    is parsed, every article in it is classified with ``predict`` and the
    outcome is compared with the article's own "Predictions" entry. A
    per-article report and the aggregated confusion matrix, true positive
    rate, true negative rate and precision are written to
    ``./results/cohere/v1.txt``.

    A dataset file that cannot be parsed is reported on stdout and skipped,
    so its articles are neither looked up (undefined) nor replaced by the
    previous label's articles. A metric whose denominator is zero is
    reported as ``nan`` instead of raising ``ZeroDivisionError``.

    Returns:
        None. All results are written to the report file and stdout.
    """
    separator = "---------------------------------"

    # Explicit UTF-8 so article metadata with non-ASCII characters can always
    # be written, whatever the platform's default encoding is.
    with open("./results/cohere/v1.txt", "w", encoding="utf-8") as file:
        print(separator, file=file)
        print("MODEL: Cohere", file=file)
        print(separator, file=file)

        # Layout follows confusion_matrix(): [0][0] wanted & predicted,
        # [1][0] wanted & not predicted, [0][1] not wanted & predicted,
        # [1][1] not wanted & not predicted.
        result_matrix = [[0, 0], [0, 0]]

        start = time.time()

        for label in LABELS:
            filename = label.replace(" ", "_").replace(",", "").lower()
            print(filename)
            try:
                articles = parseJsonFile(f"./dataset/{filename}.json")
            except Exception as e:
                print(f"Error: {e}")
                # Nothing was loaded for this label: move on rather than
                # iterating over a stale or undefined `articles`.
                continue

            for article in articles:
                print(separator, file=file)

                wanted = get_wanted_predictions(article)

                print(f"PMID: {article['PMID']}")
                print(f"PMID: {article['PMID']}", file=file)
                pred = article["Predictions"]
                print(f"Predictions: {pred}", file=file)
                print(f"MeshTerm: {article['MeshTerms']}", file=file)

                results = predict(article)
                selected_labels = results["labels"]
                # Distinct loop name so the outer `label` is not shadowed.
                predictions = {name: name in selected_labels for name in LABELS}

                print(f"Wanted: {wanted}", file=file)
                print(f"Predicted: {predictions}", file=file)
                print(f"Selected labels: {selected_labels}", file=file)

                matrix = confusion_matrix(wanted, predictions)
                print(f"Confusion matrix: {matrix}", file=file)

                for i in range(2):
                    for j in range(2):
                        result_matrix[i][j] += matrix[i][j]

                # NOTE(review): the closing separator and the pause are assumed
                # to run once per article; the pause presumably throttles calls
                # to the Cohere API -- confirm the required delay.
                print(separator, file=file)
                time.sleep(6)

        end = time.time()
        print(f"Time to classify all articles: {end-start} seconds", file=file)

        print(f"Result confusion matrix: {result_matrix}", file=file)

        true_pos = result_matrix[0][0]
        false_neg = result_matrix[1][0]
        false_pos = result_matrix[0][1]
        true_neg = result_matrix[1][1]

        tpr = _safe_ratio(true_pos, true_pos + false_neg)
        tnr = _safe_ratio(true_neg, true_neg + false_pos)
        precision = _safe_ratio(true_pos, true_pos + false_pos)

        print(f"True Positive Rate (TPR): {tpr}", file=file)
        print(f"True Negative Rate (TNR): {tnr}", file=file)
        print(f"Precision: {precision}", file=file)

        print()
# Script entry point: the evaluation starts as soon as this module is executed
# or imported, because the call is not wrapped in a `__main__` guard.
test_model()
\ No newline at end of file
......@@ -7,7 +7,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from api.parser.jsonParser import parseJsonFile
import time
from model.gemini import gemini_classify
from model.gemini import gemini_classify, gemini_start_chat
LABELS = [
"Noncommunicable Diseases",
......@@ -20,7 +20,7 @@ LABELS = [
"Diabetes type 2"
]
def predict(article):
def predict(chat, article):
title = article["Title"]
articleTitle = article["ArticleTitle"]
abstract = article["Abstract"]
......@@ -31,7 +31,7 @@ def predict(article):
if not isinstance(abstract, str):
abstract = ""
return gemini_classify(title + articleTitle + abstract, LABELS)
return gemini_classify(chat, title + articleTitle + abstract, LABELS)
def get_wanted_predictions(article):
wanted = {}
......@@ -57,7 +57,7 @@ def confusion_matrix(wanted, prediction):
return matrix
def test_model():
with open(f"./results/gemini/v1.txt", "w") as file:
with open(f"./results/gemini/v2.txt", "w") as file:
print("---------------------------------", file=file)
print(f"MODEL: Gemini", file=file)
......@@ -67,27 +67,36 @@ def test_model():
result_matrix = [[0, 0], [0, 0]]
start = time.time()
chat = gemini_start_chat()
for label in LABELS:
try:
filename = label.replace(" ", "_").replace(",", "").lower()
print(filename)
articles = parseJsonFile(f"./data/{filename}.json")
nb_entries += len(articles)
articles = parseJsonFile(f"./dataset/{filename}.json")
# nb_entries += len(articles)
except Exception as e:
print(f"Error: {e}")
for article in articles:
nb_entries += 1
print("---------------------------------", file=file)
wanted = get_wanted_predictions(article)
print(f"PMID: {article["PMID"]}")
print(f"PMID: {article["PMID"]}", file=file)
pred = article["Predictions"]
print(f"Predictions: {pred}", file=file)
print(f"MeshTerm: {article["MeshTerms"]}", file=file)
results = predict(article)
try:
results = predict(chat, article)
except:
end = time.time()
print("Time: ", end-start)
print("Nb entries: ", nb_entries)
predictions = {}
......@@ -109,7 +118,7 @@ def test_model():
print("---------------------------------", file=file)
time.sleep(1)
# time.sleep(1)
end = time.time()
......
......@@ -8,9 +8,9 @@ from api.parser.jsonParser import parseJsonFile
import time
import statistics
from model.HuggingFace.zero_shot_classification import create_classifier, classify, MODELS, LABELS
from model.HuggingFace.zero_shot_classification import create_classifier, classify, MODELS
TRESHOLD = 0.6
TRESHOLD = 0.7
KEYWORDS = [
"Diabetes",
......@@ -18,6 +18,17 @@ KEYWORDS = [
"Diabetes type 2"
]
# Labels evaluated by this script. In the visible part of test_model() each
# label is also turned into a dataset file name (spaces -> "_", lower-cased).
LABELS = [
"Noncommunicable Diseases",
"Diabetes",
"Cancer",
"Chronic respiratory disease",
"Cardiovascular diseases",
"Mental Health",
"Diabetes type 1",
"Diabetes type 2"
]
def predict(classifier, article, file):
pmid = article["PMID"]
title = article["Title"]
......@@ -80,7 +91,7 @@ def confusion_matrix(wanted, prediction):
def test_model(model):
classifier = create_classifier(model)
filename = model.replace(" ", "_").replace("-", "_").replace(".", "_").replace("/", "-")
with open(f"./results/v6_2/{filename}.txt", "w") as file:
with open(f"./results/zero_shot/v1/{filename}.txt", "w") as file:
print("---------------------------------", file=file)
print(f"MODEL: {model}", file=file)
......@@ -91,13 +102,12 @@ def test_model(model):
result_matrix = [[0, 0], [0, 0]]
start = time.time()
for key in KEYWORDS:
for label in LABELS:
try:
# filename = key.replace(" ", "_").replace(",", "").lower()
filename = key.replace(" ", "_").lower()
filename = label.replace(" ", "_").replace(",", "").lower()
print(filename)
articles = parseJsonFile(f"./data/{filename}.json")
articles = parseJsonFile(f"./dataset/{filename}.json")
print(articles)
nb_entries += len(articles)
except Exception as e:
......@@ -113,10 +123,6 @@ def test_model(model):
print(f"Predictions: {pred}", file=file)
print(f"MeshTerm: {article["MeshTerms"]}", file=file)
results_multi = predict(classifier, article, file)
print(f"Labels: {results_multi["labels"]}", file=file)
print(f"Scores: {results_multi["scores"]}", file=file)
predictions = {}
selected_labels = []
......@@ -127,11 +133,9 @@ def test_model(model):
print(f"Labels: {results["labels"]}", file=file)
print(f"Scores: {results["scores"]}", file=file)
id = results_multi["labels"].index(label)
predictions[label] = (results["scores"][0] + results_multi["scores"][id]) / 2 > TRESHOLD
predictions[label] = results["scores"][0] > TRESHOLD
if (results["scores"][0] + results_multi["scores"][id]) / 2 > TRESHOLD:
if results["scores"][0] > TRESHOLD:
selected_labels.append(label)
print(f"Wanted: {wanted}", file=file)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment