# fine_tuning_wrapper.py
import os
import sys
import time

import torch
from transformers import pipeline

# Make the project root importable so shared variables can be loaded.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from variables.pubmed import NCDS

# Directory containing the fine-tuned model weights and tokenizer files.
MODEL_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./model"))
def create_classifier_fine_tunning():
    # Report the CUDA environment before building the pipeline.
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU count: {torch.cuda.device_count()}")
    # Use the first GPU if available, otherwise fall back to CPU (-1).
    device = 0 if torch.cuda.is_available() else -1
    return pipeline("text-classification", model=MODEL_DIR, tokenizer=MODEL_DIR,
                    device=device, return_all_scores=True, truncation=True)
def classify_fine_tunning(classifier, sequence, labels, threshold):
    # Time the inference call on the input sequence.
    start = time.time()
    results = classifier(sequence)
    end = time.time()
    print(results)
    # Map each label to True/False depending on whether its score exceeds the
    # threshold. Assumes the order of scores in results[0] matches `labels`.
    predictions = {}
    for i, score in enumerate(results[0]):
        predictions[labels[i]] = score["score"] > threshold
    return predictions, end - start
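

# Minimal usage sketch. It assumes NCDS is an ordered collection of label names
# matching the model's output order; the sample text and the 0.5 threshold are
# illustrative values, not taken from this repository.
if __name__ == "__main__":
    clf = create_classifier_fine_tunning()
    preds, elapsed = classify_fine_tunning(
        clf,
        "Patient presents with elevated blood glucose and hypertension.",
        list(NCDS),
        0.5,
    )
    print(preds, f"took {elapsed:.2f}s")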