Skip to content
Snippets Groups Projects
Commit fa340d40 authored by Ivan Pavlovich's avatar Ivan Pavlovich
Browse files

Tests des modèles sur la même machine

parent 862cedb4
No related branches found
No related tags found
No related merge requests found
Showing
with 329713 additions and 15 deletions
......@@ -19,7 +19,8 @@ from variables.pubmed import NCDS, NCDS_MESH_TERM
DATASET_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../testModel/dataset"))
TMP_DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../dataSources/PubMed/tmp"))
DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./data"))
SAVE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./save"))
MODEL_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./model"))
LOGS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./logs"))
# Get PMIDs for all the articles from my testing dataset
......@@ -73,15 +74,21 @@ tokenized_dataset = tokenized_dataset.remove_columns(["text"])
tokenized_dataset.set_format("torch")
training_args = TrainingArguments(
output_dir = DATA_DIR,
output_dir = SAVE_DIR,
num_train_epochs = 32,
per_device_train_batch_size = 12,
per_device_train_batch_size = 9,
warmup_steps = 500,
weight_decay = 0.01,
logging_dir = LOGS_DIR,
logging_steps=10,
fp16=True,
remove_unused_columns=False,
save_strategy="epoch",
# evaluation_strategy="epoch",
save_total_limit=3,
# load_best_model_at_end=True,
# metric_for_best_model="f1",
# greater_is_better=True
)
model = BartForSequenceClassification.from_pretrained(
......@@ -94,12 +101,17 @@ model = BartForSequenceClassification.from_pretrained(
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
for i in range(torch.cuda.device_count()):
print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
if torch.cuda.device_count() > 1:
print(f"Using {torch.cuda.device_count()} GPUs!")
model = torch.nn.DataParallel(model)
model.to(device)
model = model.module if isinstance(model, torch.nn.DataParallel) else model
trainer = Trainer(
model = model, # The instantiated model to be trained
args = training_args, # Training arguments, defined above
......@@ -107,11 +119,23 @@ trainer = Trainer(
tokenizer = tokenizer # The tokenizer that was used
)
last_checkpoint = None
if os.path.isdir(SAVE_DIR):
from transformers.trainer_utils import get_last_checkpoint
last_checkpoint = get_last_checkpoint(SAVE_DIR)
if last_checkpoint:
print(f"Resuming training from checkpoint: {last_checkpoint}")
trainer.train(resume_from_checkpoint=last_checkpoint)
else:
print("Starting training from scratch")
trainer.train()
trainer.train()
# If the model is wrapped in DataParallel, access the underlying model using `model.module`
model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
# Save the fine-tuned model
model_to_save.save_pretrained(DATA_DIR)
tokenizer.save_pretrained(DATA_DIR)
\ No newline at end of file
model_to_save.save_pretrained(MODEL_DIR)
tokenizer.save_pretrained(MODEL_DIR)
\ No newline at end of file
File added
File added
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
No preview for this file type
......@@ -16,14 +16,14 @@ MODELS_CATEGORIES = [
]
MODELS = {
'facebook/bart-large-mnli': {'category': "huggingface", 'predict': classify, 'requiresDependency': True, 'dependency': create_classifier},
'facebook/bart-large-mnli': {'category': "huggingface", 'predict': classify, 'requiresDependency': True, 'dependency': create_classifier}, # Params: 407M, Size: 1.6GB
# 'MoritzLaurer/bge-m3-zeroshot-v2.0': {'category': "huggingface", 'predict': classify, 'requiresDependency': True, 'dependency': create_classifier},
'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli': {'category': "huggingface", 'predict': classify, 'requiresDependency': True, 'dependency': create_classifier},
'MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33': {'category': "huggingface", 'predict': classify, 'requiresDependency': True, 'dependency': create_classifier},
'MoritzLaurer/multilingual-MiniLMv2-L6-mnli-xnli': {'category': "huggingface", 'predict': classify, 'requiresDependency': True, 'dependency': create_classifier},
'llama3.2': {'category': "ollama", 'predict': ollama.classify, 'requiresDependency': False, 'dependency': None},
'mistral-small': {'category': "ollama", 'predict': ollama.classify, 'requiresDependency': False, 'dependency': None},
'deepseek-v2': {'category': "ollama", 'predict': ollama.classify, 'requiresDependency': False, 'dependency': None},
'gemini-hosted': {'category': "hosted", 'predict': gemini_classify, 'requiresDependency': True, 'dependency': gemini_start_chat},
'cohere-hosted': {'category': "hosted", 'predict': cohere_classify, 'requiresDependency': True, 'dependency': cohere_create_client}
'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli': {'category': "huggingface", 'predict': classify, 'requiresDependency': True, 'dependency': create_classifier}, # Params: 184M, Size: 369MB
'MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33': {'category': "huggingface", 'predict': classify, 'requiresDependency': True, 'dependency': create_classifier}, # Params: 184M, Size: 369MB
'MoritzLaurer/multilingual-MiniLMv2-L6-mnli-xnli': {'category': "huggingface", 'predict': classify, 'requiresDependency': True, 'dependency': create_classifier}, # Params: 107M, Size: 428MB
'llama3.2': {'category': "ollama", 'predict': ollama.classify, 'requiresDependency': False, 'dependency': None}, # Params: 3B, Size: 2GB,
'mistral-small': {'category': "ollama", 'predict': ollama.classify, 'requiresDependency': False, 'dependency': None}, # Params: 24B, Size: 14GB
'deepseek-v2': {'category': "ollama", 'predict': ollama.classify, 'requiresDependency': False, 'dependency': None}, # Params: 16B, Size: 8.9GB
'gemini-hosted': {'category': "hosted", 'predict': gemini_classify, 'requiresDependency': True, 'dependency': gemini_start_chat}, # Params: , Input token limit: 1048576, Output token limit: 8192
'cohere-hosted': {'category': "hosted", 'predict': cohere_classify, 'requiresDependency': True, 'dependency': cohere_create_client} # Params: , Context: 128k, Output token limit: 4k
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment