Commit 158da3fd authored by ivan.pavlovic

Modification of the test system

parent 2ffce2dd
[Diffs for seven other changed files in this commit are collapsed and not shown.]
@@ -8,9 +8,9 @@ from api.parser.jsonParser import parseJsonFile
 import time
 import statistics
 
-from model.HuggingFace.zero_shot_classification import create_classifier, classify, MODELS
+from model.HuggingFace.zero_shot_classification import create_classifier, classify, MODELS, LABELS
 
-TRESHOLD = 0.6
+TRESHOLD = 0.7
 
 KEYWORDS = [
     "Diabetes",
@@ -18,10 +18,7 @@ KEYWORDS = [
     "Diabetes type 2"
 ]
 
-model = MODELS[6]
-classifier = create_classifier(model)
-
-def predict(article, file):
+def predict(classifier, article, file):
     pmid = article["PMID"]
     title = article["Title"]
     articleTitle = article["ArticleTitle"]
@@ -37,8 +34,23 @@ def predict(article, file):
     return classify(classifier, title + " \n " + articleTitle + " \n " + abstract)
 
-filename = model.replace(" ", "_").replace("-", "_").replace(".", "_").replace("/", "-")
-
-with open(f"./results/{filename}.txt", "w") as file:
-
-    predictions = []
-    selected_scores = []
+def predict_label(classifier, article, label):
+    title = article["Title"]
+    articleTitle = article["ArticleTitle"]
+    abstract = article["Abstract"]
+
+    return classify(classifier, title + " \n " + articleTitle + " \n " + abstract, lables=[label])
+
+def test_model(model):
+    classifier = create_classifier(model)
+
+    filename = model.replace(" ", "_").replace("-", "_").replace(".", "_").replace("/", "-")
+
+    with open(f"./results/v3/{filename}.txt", "w") as file:
+        print("---------------------------------", file=file)
+        print(f"MODEL: {model}", file=file)
+        print(f"TRESHOLD: {TRESHOLD}", file=file)
+        print("---------------------------------", file=file)
+
+        predictions = []
+        selected_scores = []
@@ -57,33 +69,139 @@ with open(f"./results/{filename}.txt", "w") as file:
-    for article in articles:
-        print("---------------------------------", file=file)
-        results = predict(article, file)
-
-        print(f"Labels: {results["labels"]}", file=file)
-        print(f"Scores: {results["scores"]}", file=file)
-
-        selected_labels = []
-        for id, score in enumerate(results["scores"]):
-            if score >= TRESHOLD:
-                selected_labels.append(results["labels"][id])
-                selected_scores.append(score)
-
-        print(f"Selected labels: {selected_labels}", file=file)
-
-        good_prediction = False
-
-        for keyword in article["Predictions"]:
-            if keyword in selected_labels:
-                good_prediction = True
-                predictions.append(True)
-            else:
-                predictions.append(False)
-                # predictions.append( True if keyword not in selected_labels else False)
-
-        if good_prediction:
-            one_per_entrie_count += 1
-
-        print("---------------------------------", file=file)
-
-    end = time.time()
+        for article in articles:
+            print("---------------------------------", file=file)
+
+            # ---------------------------------
+            # results = predict(classifier, article, file)
+
+            # print(f"Labels: {results["labels"]}", file=file)
+            # print(f"Scores: {results["scores"]}", file=file)
+
+            # selected_labels = []
+            # for id, score in enumerate(results["scores"]):
+            #     if score >= TRESHOLD:
+            #         selected_labels.append(results["labels"][id])
+            #         selected_scores.append(score)
+
+            # print(f"Selected labels: {selected_labels}", file=file)
+
+            # good_prediction = False
+
+            # for keyword in article["Predictions"]:
+            #     if keyword in selected_labels:
+            #         good_prediction = True
+            #         predictions.append(True)
+            #     elif (
+            #         (keyword in ["Diabetes type 1", "Diabetes type 2"] and "Diabetes" in selected_labels)
+            #     ):
+            #         good_prediction = True
+            #         predictions.append(True)
+            #     else:
+            #         predictions.append(False)
+            #         # predictions.append( True if keyword not in selected_labels else False)
+
+            # if good_prediction:
+            #     one_per_entrie_count += 1
+            # ---------------------------------
+
+            # ---------------------------------
+            print(f"PMID: {article["PMID"]}", file=file)
+
+            pred = article["Predictions"]
+            print(f"Predictions: {pred}", file=file)
+            print(f"MeshTerm: {article["MeshTerms"]}", file=file)
+
+            good_prediction = False
+            selected_labels = []
+
+            for label in LABELS:
+                results = predict_label(classifier, article, label)
+
+                print(f"Labels: {results["labels"]}", file=file)
+                print(f"Scores: {results["scores"]}", file=file)
+
+                if results["scores"][0] > TRESHOLD:
+                    selected_labels.append(label)
+                    selected_scores.append(results["scores"][0])
+
+                    if label in pred:
+                        good_prediction = True
+                        predictions.append(True)
+                        pred.remove(label)
+                    else:
+                        predictions.append(False)
+
+            for label in pred:
+                predictions.append(False)
+
+            if good_prediction:
+                one_per_entrie_count += 1
+
+            print(f"Selected labels: {selected_labels}", file=file)
+            # ---------------------------------
+
+            # ---------------------------------
+            # results = predict(classifier, article, file)
+
+            # print("Multi-label classification:", file=file)
+            # print(f"Labels: {results["labels"]}", file=file)
+            # print(f"Scores: {results["scores"]}", file=file)
+
+            # results_multi = []
+            # for id, label in enumerate(results["labels"]):
+            #     results_multi.append({
+            #         "label": label,
+            #         "score": results["scores"][id]
+            #     })
+
+            # print("Solo-label classification:", file=file)
+
+            # results_solo = []
+            # for label in LABELS:
+            #     results = predict_label(classifier, article, label)
+
+            #     print(f"Labels: {results["labels"]}", file=file)
+            #     print(f"Scores: {results["scores"]}", file=file)
+
+            #     results_solo.append({
+            #         "label": label,
+            #         "score": results["scores"][0]
+            #     })
+
+            # print("Combination: ", file=file)
+
+            # results = []
+            # for result_multi in results_multi:
+            #     for result_solo in results_solo:
+            #         if result_multi["label"] == result_solo["label"]:
+            #             results.append({
+            #                 "label": result_multi["label"],
+            #                 "score": (result_multi["score"] + result_solo["score"]) / 2
+            #             })
+
+            # selected_labels = []
+            # good_prediction = False
+            # pred = article["Predictions"]
+
+            # for result in results:
+            #     print(f"\"{result["label"]}\": {result["score"]}", file=file)
+
+            #     if result["score"] >= TRESHOLD:
+            #         selected_labels.append(result["label"])
+
+            #         if label in pred:
+            #             good_prediction = True
+            #             predictions.append(True)
+            #             pred.remove(label)
+            #         else:
+            #             predictions.append(False)
+
+            # for label in pred:
+            #     predictions.append(False)
+
+            # if good_prediction:
+            #     one_per_entrie_count += 1
+
+            # print(f"Selected labels: {selected_labels}", file=file)
+            # ---------------------------------
+
+            print("---------------------------------", file=file)
+
+        end = time.time()
@@ -103,3 +221,6 @@ with open(f"./results/{filename}.txt", "w") as file:
-    print(f"Mean score: {mean_score}", file=file)
-    print(f"Median score: {median_score}", file=file)
-    print()
+        print(f"Mean score: {mean_score}", file=file)
+        print(f"Median score: {median_score}", file=file)
+        print()
+
+for model in MODELS:
+    test_model(model)
\ No newline at end of file
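
Note on the per-label scoring introduced in test_model: create_classifier and classify live in model.HuggingFace.zero_shot_classification and are not part of this diff. As a minimal sketch, assuming they wrap the Hugging Face transformers "zero-shot-classification" pipeline, the new flow of querying one label at a time and keeping scores above TRESHOLD would look roughly like the following; the checkpoint name, label list, and article text below are placeholders, not values taken from the repository.

# Minimal sketch, assuming create_classifier/classify wrap the Hugging Face
# "zero-shot-classification" pipeline; this is not the committed implementation.
from transformers import pipeline

LABELS = ["Diabetes", "Diabetes type 1", "Diabetes type 2"]  # assumed label set
TRESHOLD = 0.7

def create_classifier(model_name):
    # Build a zero-shot classifier for one model checkpoint.
    return pipeline("zero-shot-classification", model=model_name)

def classify(classifier, text, labels=LABELS):
    # Returns a dict with "labels" and "scores" sorted by descending score.
    return classifier(text, candidate_labels=labels, multi_label=True)

# Per-label scoring as in test_model: one call per label, keep labels above the threshold.
classifier = create_classifier("facebook/bart-large-mnli")  # placeholder checkpoint
article_text = "Title \n ArticleTitle \n Abstract"          # placeholder article fields
selected_labels = []
for label in LABELS:
    result = classify(classifier, article_text, labels=[label])
    if result["scores"][0] > TRESHOLD:
        selected_labels.append(label)
print(selected_labels)

Scoring each candidate label in its own call, as predict_label does, gives every label an independent entailment-style score that is compared against TRESHOLD, instead of ranking all labels together in a single multi-label call.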