Skip to content
Snippets Groups Projects
Select Git revision
  • 9a54a40006c994d4deb78d2ac33b6ad376c5e8b9
  • master default protected
2 results

fine_tuning_wrapper.py

Blame
  • fine_tuning_wrapper.py 1005 B
    from transformers import pipeline
    import os
    import sys
    import torch
    import time
    
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
    
    from variables.pubmed import NCDS
    
    # Absolute path of the "model" directory that sits next to this file.
    MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model")
    
    def create_classifier_fine_tunning():
        """Build the text-classification pipeline for the local fine-tuned model.

        Prints CUDA diagnostics, then loads model and tokenizer from
        ``MODEL_DIR``. Device 0 is used when CUDA is available, otherwise
        -1 (CPU).

        The pipeline is created with ``return_all_scores=True`` and
        ``truncation=True``. ``classify_fine_tunning`` relies on the resulting
        score layout. NOTE(review): ``return_all_scores`` is deprecated upstream
        in favour of ``top_k=None``, which also changes nesting and ordering of
        the output — do not swap it without updating the consumer.

        Returns:
            The pipeline object returned by ``transformers.pipeline``.
        """
        has_cuda = torch.cuda.is_available()
        print(f" CUDA available: {has_cuda}")
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPUs number: {torch.cuda.device_count()}")

        if has_cuda:
            target_device = 0
        else:
            target_device = -1

        return pipeline(
            "text-classification",
            model=MODEL_DIR,
            tokenizer=MODEL_DIR,
            device=target_device,
            return_all_scores=True,
            truncation=True,
        )
    
    def classify_fine_tunning(classifier, sequence, labels, threshold):
        """Classify *sequence* and turn each per-label score into a boolean.

        Args:
            classifier: Callable invoked as ``classifier(sequence)``. Only the
                first element of its result is read; that element must be an
                iterable of dicts each carrying a ``"score"`` key.
            sequence: Input forwarded unchanged to ``classifier``.
            labels: Indexable of label names. ``labels[i]`` names the i-th score
                entry, so it must follow the classifier's output order.
                NOTE(review): presumably the model's id2label order — confirm.
            threshold: A label is predicted only when its score is strictly
                greater than this value.

        Returns:
            ``(predictions, elapsed)`` where ``predictions`` maps each label to
            a bool and ``elapsed`` is the classifier call duration in seconds.

        Raises:
            IndexError: when ``labels`` is a sequence with fewer entries than
                the classifier returns scores (the ``labels[i]`` lookup fails).
        """
        # perf_counter is monotonic and high-resolution. time.time() follows
        # the wall clock, which can be adjusted mid-call and skew the duration.
        start = time.perf_counter()
        results = classifier(sequence)
        elapsed = time.perf_counter() - start

        # Dump the raw classifier output (debug aid).
        print(results)

        # Scores are matched to labels purely by position within results[0].
        predictions = {
            labels[i]: entry["score"] > threshold
            for i, entry in enumerate(results[0])
        }

        return predictions, elapsed