Skip to content
Snippets Groups Projects
Commit 2ffce2dd authored by ivan.pavlovic's avatar ivan.pavlovic
Browse files

Ajout de 4 model

parent 66dcab41
No related branches found
No related tags found
No related merge requests found
Showing
with 1694 additions and 203 deletions
# https://huggingface.co/facebook/bart-large-mnli
from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
candidate_labels = [
"Diabetes",
"Cancer",
"Chronic respiratory disease",
"Cardiovascular diseases",
"Mental Health",
"Diabetes type 1",
"Diabetes type 2"
]
def classify(sequence):
results = classifier(sequence, candidate_labels)
return results
# print(f"Sequence: {sequence_to_classify}")
# print(f"Labels: {results["labels"]}")
# print(f"Scores: {results["scores"]}")
# https://huggingface.co/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract
from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract")
candidate_labels = [
"Diabetes",
"Cancer",
"Chronic respiratory disease",
"Cardiovascular diseases",
"Mental Health",
"Diabetes type 1",
"Diabetes type 2"
]
def classify(sequence):
results = classifier(sequence, candidate_labels)
return results
# print(f"Sequence: {sequence_to_classify}")
# print(f"Labels: {results["labels"]}")
# print(f"Scores: {results["scores"]}")
File added
File added
File added
File added
File added
# https://huggingface.co/MoritzLaurer/bge-m3-zeroshot-v2.0
from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="MoritzLaurer/bge-m3-zeroshot-v2.0")
candidate_labels = [
"Diabetes",
"Cancer",
"Chronic respiratory disease",
"Cardiovascular diseases",
"Mental Health",
"Diabetes type 1",
"Diabetes type 2"
]
def classify(sequence):
results = classifier(sequence, candidate_labels)
return results
# print(f"Sequence: {sequence_to_classify}")
# print(f"Labels: {results["labels"]}")
# print(f"Scores: {results["scores"]}")
from transformers import pipeline
LABELS = [
"Diabetes",
"Cancer",
"Chronic respiratory disease",
"Cardiovascular diseases",
"Mental Health",
"Diabetes type 1",
"Diabetes type 2"
]
MODELS = [
"facebook/bart-large-mnli", # https://huggingface.co/facebook/bart-large-mnli
"MoritzLaurer/bge-m3-zeroshot-v2.0", # https://huggingface.co/MoritzLaurer/bge-m3-zeroshot-v2.0
"MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
"MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33",
"MoritzLaurer/multilingual-MiniLMv2-L6-mnli-xnli",
"microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract", # https://huggingface.co/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract
]
def create_classifier(model = MODELS[0]):
return pipeline("zero-shot-classification", model=model)
def classify(classifier, sequence, lables = LABELS, debug = False):
results = classifier(sequence, lables)
if debug:
print(f"Sequence: {sequence}")
print(f"Labels: {results["labels"]}")
print(f"Scores: {results["scores"]}")
return results
This diff is collapsed.
...@@ -502,7 +502,7 @@ Labels: ['Cardiovascular diseases', 'Diabetes', 'Diabetes type 2', 'Cancer', 'Di ...@@ -502,7 +502,7 @@ Labels: ['Cardiovascular diseases', 'Diabetes', 'Diabetes type 2', 'Cancer', 'Di
Scores: [0.38244596123695374, 0.23532404005527496, 0.12610754370689392, 0.08088988065719604, 0.07161928713321686, 0.055826857686042786, 0.04778643324971199] Scores: [0.38244596123695374, 0.23532404005527496, 0.12610754370689392, 0.08088988065719604, 0.07161928713321686, 0.055826857686042786, 0.04778643324971199]
Selected labels: [] Selected labels: []
--------------------------------- ---------------------------------
Time to classify all articles: 374.61803555488586 seconds Time to classify all articles: 380.2531027793884 seconds
Good classification: 12/75 ( 16.0% ) Good classification: 12/75 ( 16.0% )
At least one good per entrie: 12/63 ( 19.05% ) At least one good per entrie: 12/63 ( 19.05% )
Mean score: 0.7297625946998596 Mean score: 0.7297625946998596
......
This diff is collapsed.
This diff is collapsed.
...@@ -502,7 +502,7 @@ Labels: ['Cardiovascular diseases', 'Diabetes type 2', 'Diabetes', 'Chronic resp ...@@ -502,7 +502,7 @@ Labels: ['Cardiovascular diseases', 'Diabetes type 2', 'Diabetes', 'Chronic resp
Scores: [0.34633418917655945, 0.2441960871219635, 0.09421524405479431, 0.09026889503002167, 0.08233360201120377, 0.07539966702461243, 0.06725231558084488] Scores: [0.34633418917655945, 0.2441960871219635, 0.09421524405479431, 0.09026889503002167, 0.08233360201120377, 0.07539966702461243, 0.06725231558084488]
Selected labels: [] Selected labels: []
--------------------------------- ---------------------------------
Time to classify all articles: 418.86657190322876 seconds Time to classify all articles: 420.01833271980286 seconds
Good classification: 31/75 ( 41.33% ) Good classification: 31/75 ( 41.33% )
At least one good per entrie: 31/63 ( 49.21% ) At least one good per entrie: 31/63 ( 49.21% )
Mean score: 0.797397155972088 Mean score: 0.797397155972088
......
...@@ -8,9 +8,7 @@ from api.parser.jsonParser import parseJsonFile ...@@ -8,9 +8,7 @@ from api.parser.jsonParser import parseJsonFile
import time import time
import statistics import statistics
# from model.HuggingFace.BartLargeMnli import classify from model.HuggingFace.zero_shot_classification import create_classifier, classify, MODELS
# from model.HuggingFace.BioMedBERT import classify
from model.HuggingFace.bge_m3_zeroshot_v2 import classify
TRESHOLD = 0.6 TRESHOLD = 0.6
...@@ -20,6 +18,9 @@ KEYWORDS = [ ...@@ -20,6 +18,9 @@ KEYWORDS = [
"Diabetes type 2" "Diabetes type 2"
] ]
model = MODELS[6]
classifier = create_classifier(model)
def predict(article, file): def predict(article, file):
pmid = article["PMID"] pmid = article["PMID"]
title = article["Title"] title = article["Title"]
...@@ -34,11 +35,10 @@ def predict(article, file): ...@@ -34,11 +35,10 @@ def predict(article, file):
print(f"Predictions: {pred}", file=file) print(f"Predictions: {pred}", file=file)
print(f"MeshTerm: {meshTerms}", file=file) print(f"MeshTerm: {meshTerms}", file=file)
return classify(title + " \n " + articleTitle + " \n " + abstract) return classify(classifier, title + " \n " + articleTitle + " \n " + abstract)
# with open("./results/BartLargeMnli.txt", "w") as file: filename = model.replace(" ", "_").replace("-", "_").replace(".", "_").replace("/", "-")
# with open("./results/BiomedBERT.txt", "w") as file: with open(f"./results/{filename}.txt", "w") as file:
with open("./results/bge_m3_zeroshot_v2.txt", "w") as file:
predictions = [] predictions = []
selected_scores = [] selected_scores = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment