Skip to content
Snippets Groups Projects
Commit 88696137 authored by Ivan Pavlovich's avatar Ivan Pavlovich
Browse files

Remade the test script to add easily new models

parent e8e74d8f
No related branches found
No related tags found
No related merge requests found
Showing
with 42661 additions and 8 deletions
...@@ -36,12 +36,19 @@ def create_classifier(model = MODELS[0]): ...@@ -36,12 +36,19 @@ def create_classifier(model = MODELS[0]):
return pipeline("zero-shot-classification", model=model, device=device) return pipeline("zero-shot-classification", model=model, device=device)
def classify(classifier, sequence, lables = LABELS, debug = False): def classify(classifier, sequence, lables, treshold, debug = False):
results = classifier(sequence, lables) predictions = {}
if debug: if debug:
print(f"Sequence: {sequence}") print(f"Sequence: {sequence}")
for label in lables:
results = classifier(sequence, [label])
if debug:
print(f"Labels: {results['labels']}") print(f"Labels: {results['labels']}")
print(f"Scores: {results['scores']}") print(f"Scores: {results['scores']}")
return results predictions[label] = results["scores"][0] > treshold
return predictions
...@@ -22,10 +22,19 @@ def add_confusion_matrices(confusion_matrix, tmp_confusion_matrix): ...@@ -22,10 +22,19 @@ def add_confusion_matrices(confusion_matrix, tmp_confusion_matrix):
return confusion_matrix return confusion_matrix
def get_tpr(confusion_matrix): def get_tpr(confusion_matrix):
return confusion_matrix[0][0] / (confusion_matrix[0][0] + confusion_matrix[1][0]) denominator = (confusion_matrix[0][0] + confusion_matrix[1][0])
if denominator == 0:
return 0
return confusion_matrix[0][0] / denominator
def get_tnr(confusion_matrix): def get_tnr(confusion_matrix):
return confusion_matrix[1][1] / (confusion_matrix[1][1] + confusion_matrix[0][1]) denominator = (confusion_matrix[1][1] + confusion_matrix[0][1])
if denominator == 0:
return 0
return confusion_matrix[1][1] / denominator
def get_precision(confusion_matrix): def get_precision(confusion_matrix):
return confusion_matrix[0][0] / (confusion_matrix[0][0] + confusion_matrix[0][1]) denominator = (confusion_matrix[0][0] + confusion_matrix[0][1])
\ No newline at end of file if denominator == 0:
return 0
return confusion_matrix[0][0] / denominator
\ No newline at end of file
File moved
File moved
File moved
File moved
File moved
File moved
Source diff could not be displayed: it is too large. Options to address this: view the blob.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment