Skip to content
Snippets Groups Projects
Commit b717aba1 authored by Ivan Pavlovich's avatar Ivan Pavlovich
Browse files

Modification et test du script de fine-tuning de facebook/bart-large-mnli....

Modification et test du script de fine-tuning de facebook/bart-large-mnli. Pas concluant, à cause du temps de fine-tuning
parent c038693b
No related branches found
No related tags found
No related merge requests found
...@@ -21,6 +21,7 @@ from variables.pubmed import NCDS, NCDS_MESH_TERM ...@@ -21,6 +21,7 @@ from variables.pubmed import NCDS, NCDS_MESH_TERM
DATASET_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../testModel/dataset")) DATASET_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../testModel/dataset"))
TMP_DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../dataSources/PubMed/tmp")) TMP_DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../dataSources/PubMed/tmp"))
DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./data")) DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./data"))
LOGS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./logs"))
# Get PMIDs for all the articles from my testing dataset # Get PMIDs for all the articles from my testing dataset
dataset_pmids = [] dataset_pmids = []
...@@ -46,6 +47,8 @@ print(len(data)) ...@@ -46,6 +47,8 @@ print(len(data))
articles = [article for article in data if article["PMID"] not in dataset_pmids] articles = [article for article in data if article["PMID"] not in dataset_pmids]
articles = articles[:1000]
print(len(articles)) print(len(articles))
dataset = [] dataset = []
...@@ -89,15 +92,25 @@ print(len(tokenized_dataset)) ...@@ -89,15 +92,25 @@ print(len(tokenized_dataset))
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir = DATA_DIR, # Output directory output_dir = DATA_DIR, # Output directory
num_train_epochs = 32, # Total number of training epochs num_train_epochs = 32, # Total number of training epochs
per_device_train_batch_size = 16, # Batch size per device during training per_device_train_batch_size = 12, # Batch size per device during training
warmup_steps = 500, # Number of warmup steps for learning rate scheduler warmup_steps = 500, # Number of warmup steps for learning rate scheduler
weight_decay = 0.01, # Strength of weight decay weight_decay = 0.01, # Strength of weight decay
logging_dir = LOGS_DIR,
logging_steps=10,
fp16=True,
remove_unused_columns=False,
) )
model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli", num_labels = len(NCDS), ignore_mismatched_sizes=True)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}") print(f"Using device: {device}")
model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli", num_labels = len(NCDS), ignore_mismatched_sizes=True).to(device) if torch.cuda.device_count() > 1:
print(f"Using {torch.cuda.device_count()} GPUs!")
model = torch.nn.DataParallel(model)
model.to(device)
trainer = Trainer( trainer = Trainer(
model = model, # The instantiated model to be trained model = model, # The instantiated model to be trained
...@@ -108,6 +121,13 @@ trainer = Trainer( ...@@ -108,6 +121,13 @@ trainer = Trainer(
trainer.train() trainer.train()
# If the model is wrapped in DataParallel, access the underlying model using `model.module`
model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
# Save the fine-tuned model
model_to_save.save_pretrained(DATA_DIR)
tokenizer.save_pretrained(DATA_DIR)
title = "Consumption of Sodium and Its Ratio to Potassium in Relation to All-Cause, Cause-Specific, and Premature Noncommunicable Disease Mortality in Middle-Aged Japanese Adults: A Prospective Cohort Study." title = "Consumption of Sodium and Its Ratio to Potassium in Relation to All-Cause, Cause-Specific, and Premature Noncommunicable Disease Mortality in Middle-Aged Japanese Adults: A Prospective Cohort Study."
abstract = "Reducing premature noncommunicable disease (NCD) mortality is a global challenge. Sodium is thought to increase risk of NCDs via an effect of salt per se or high-salt foods on hypertension-induced cardiovascular disease (CVD) and gastrointestinal cancer. Further, relative risk of CVD is reportedly more closely associated with sodium-to-potassium ratio than that with sodium alone. However, few studies have investigated the effect of consumption of sodium or its ratio to consumption of potassium on risk of premature NCD death.We examined associations between intake of sodium and sodium-to-potassium ratio and risk of all-cause and cause-specific death, including premature NCD, in a Japanese prospective cohort study.During 1995-1998, a validated food frequency questionnaire was administered in 11 areas to 83,048 men and women aged 45-74 y. During 1,587,901 person-years of follow-up until the end of 2018, 17,727 all-cause deaths and 3555 premature NCD deaths were identified.Higher sodium intake was significantly associated with increased risk of all-cause and premature NCD mortality, but not all NCD mortality, among men: multivariate hazards ratios for the highest compared with lowest quintiles (HR) were 1.11 (95% CI: 1.03, 1.20; P-trend < 0.01) for all-cause and 1.25 (95% CI: 1.06, 1.47; P-trend < 0.01) for premature NCD mortality. When intakes were expressed as ratio to potassium intake, these associations (HR of all-cause: 1.19, 95% CI: 1.11-1.27; P-trend < 0.01; HR of premature NCD: 1.27, 95% CI: 1.10, 1.46; P-trend < 0.01), including associations with cancers (HR: 1.18, 95% CI: 1.07, 1.31; P-trend = 0.02), were strengthened in men.This prospective cohort study showed that both sodium intake and sodium-to-potassium ratio are associated with increased risk of all-cause and early NCD mortality in middle-aged men." abstract = "Reducing premature noncommunicable disease (NCD) mortality is a global challenge. 
Sodium is thought to increase risk of NCDs via an effect of salt per se or high-salt foods on hypertension-induced cardiovascular disease (CVD) and gastrointestinal cancer. Further, relative risk of CVD is reportedly more closely associated with sodium-to-potassium ratio than that with sodium alone. However, few studies have investigated the effect of consumption of sodium or its ratio to consumption of potassium on risk of premature NCD death.We examined associations between intake of sodium and sodium-to-potassium ratio and risk of all-cause and cause-specific death, including premature NCD, in a Japanese prospective cohort study.During 1995-1998, a validated food frequency questionnaire was administered in 11 areas to 83,048 men and women aged 45-74 y. During 1,587,901 person-years of follow-up until the end of 2018, 17,727 all-cause deaths and 3555 premature NCD deaths were identified.Higher sodium intake was significantly associated with increased risk of all-cause and premature NCD mortality, but not all NCD mortality, among men: multivariate hazards ratios for the highest compared with lowest quintiles (HR) were 1.11 (95% CI: 1.03, 1.20; P-trend < 0.01) for all-cause and 1.25 (95% CI: 1.06, 1.47; P-trend < 0.01) for premature NCD mortality. When intakes were expressed as ratio to potassium intake, these associations (HR of all-cause: 1.19, 95% CI: 1.11-1.27; P-trend < 0.01; HR of premature NCD: 1.27, 95% CI: 1.10, 1.46; P-trend < 0.01), including associations with cancers (HR: 1.18, 95% CI: 1.07, 1.31; P-trend = 0.02), were strengthened in men.This prospective cohort study showed that both sodium intake and sodium-to-potassium ratio are associated with increased risk of all-cause and early NCD mortality in middle-aged men."
sequence = title + abstract sequence = title + abstract
...@@ -116,8 +136,15 @@ inputs = tokenizer(sequence, return_tensors="pt").to(device) ...@@ -116,8 +136,15 @@ inputs = tokenizer(sequence, return_tensors="pt").to(device)
outputs = model(**inputs) outputs = model(**inputs)
print(outputs) # Apply softmax to get probabilities for all the labels
logits = outputs.logits
probs = F.softmax(logits, dim=-1) # Softmax over the last dimension (the logits for each class)
predicted_label_id = outputs.logits.argmax(-1).item() # Print out the probabilities for each label
probs = probs.squeeze() # Remove the extra batch dimension
for i, prob in enumerate(probs):
print(f"Label {i}: {prob.item():.4f}")
# Get the predicted label id
predicted_label_id = probs.argmax(-1).item()
print(f"Predicted category: {predicted_label_id}") print(f"Predicted category: {predicted_label_id}")
\ No newline at end of file
#!/bin/bash
# Slurm batch job: runs the facebook/bart-large-mnli fine-tuning script on a
# single GPU. bash is required because `source` below is not a POSIX sh builtin.
# The relative paths assume the job is submitted from this script's directory.
#SBATCH --cpus-per-task=4
#SBATCH --job-name=finetunning
#SBATCH --time=12:00:00
#SBATCH --partition=shared-gpu
# NOTE(review): VramPerGpu looks like a site-specific GRES constraint
# (presumably at least 25G of GPU memory) - confirm against the cluster docs.
#SBATCH --gres=gpu:1,VramPerGpu:25G
#SBATCH --mem=32G
#SBATCH --ntasks=1       # One task, one GPU (see --gres above)
#SBATCH --nodes=1        # Use a single node
# Load the toolchain modules, then activate the project's virtual environment.
module load CUDA/12.3.0 GCC/12.2.0 Python/3.10.8 Qt5/5.15.7
source ../../.venv/bin/activate
python3 facebook-bart-large-mnli.py
\ No newline at end of file
#!/bin/bash
# Slurm batch job: runs facebook-bart-large-mnli.py with one GPU on the
# shared-gpu partition, with a one-hour wall-time limit.
# The relative paths below assume the job starts in this script's directory.
#SBATCH --time=01:00:00
#SBATCH --partition=shared-gpu
#SBATCH --gpus=1
# Load the toolchain modules, then activate the virtual environment found two
# directories up before launching the Python script.
module load CUDA/12.3.0 GCC/12.2.0 Python/3.10.8 Qt5/5.15.7
source ../../.venv/bin/activate
python3 ./facebook-bart-large-mnli.py
\ No newline at end of file
from transformers import BartForSequenceClassification, BartTokenizerFast
import torch.nn.functional as F
import os
import torch
# Directory holding the saved model and tokenizer files, resolved relative to
# this script so the lookup does not depend on the current working directory.
DATA_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "./data")
)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Reload the tokenizer and the classifier from DATA_DIR so they can be reused
# for inference, then move the model onto the selected device.
tokenizer = BartTokenizerFast.from_pretrained(DATA_DIR)
model = BartForSequenceClassification.from_pretrained(DATA_DIR)
model.to(device)
# Sample article (title + abstract) used as a smoke test for the reloaded model.
title = "Consumption of Sodium and Its Ratio to Potassium in Relation to All-Cause, Cause-Specific, and Premature Noncommunicable Disease Mortality in Middle-Aged Japanese Adults: A Prospective Cohort Study."
abstract = "Reducing premature noncommunicable disease (NCD) mortality is a global challenge. Sodium is thought to increase risk of NCDs via an effect of salt per se or high-salt foods on hypertension-induced cardiovascular disease (CVD) and gastrointestinal cancer. Further, relative risk of CVD is reportedly more closely associated with sodium-to-potassium ratio than that with sodium alone. However, few studies have investigated the effect of consumption of sodium or its ratio to consumption of potassium on risk of premature NCD death.We examined associations between intake of sodium and sodium-to-potassium ratio and risk of all-cause and cause-specific death, including premature NCD, in a Japanese prospective cohort study.During 1995-1998, a validated food frequency questionnaire was administered in 11 areas to 83,048 men and women aged 45-74 y. During 1,587,901 person-years of follow-up until the end of 2018, 17,727 all-cause deaths and 3555 premature NCD deaths were identified.Higher sodium intake was significantly associated with increased risk of all-cause and premature NCD mortality, but not all NCD mortality, among men: multivariate hazards ratios for the highest compared with lowest quintiles (HR) were 1.11 (95% CI: 1.03, 1.20; P-trend < 0.01) for all-cause and 1.25 (95% CI: 1.06, 1.47; P-trend < 0.01) for premature NCD mortality. When intakes were expressed as ratio to potassium intake, these associations (HR of all-cause: 1.19, 95% CI: 1.11-1.27; P-trend < 0.01; HR of premature NCD: 1.27, 95% CI: 1.10, 1.46; P-trend < 0.01), including associations with cancers (HR: 1.18, 95% CI: 1.07, 1.31; P-trend = 0.02), were strengthened in men.This prospective cohort study showed that both sodium intake and sodium-to-potassium ratio are associated with increased risk of all-cause and early NCD mortality in middle-aged men."
# NOTE(review): title and abstract are joined without a separator; presumably
# this mirrors how the training sequences were built - confirm before changing.
sequence = title + abstract

# Truncate to the tokenizer's maximum length so an over-long article cannot
# exceed the model's maximum input size.
inputs = tokenizer(sequence, return_tensors="pt", truncation=True).to(device)

# Inference only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    outputs = model(**inputs)

# Apply softmax over the last dimension to get a probability per label.
logits = outputs.logits
probs = F.softmax(logits, dim=-1)

# Drop only the batch dimension; a bare squeeze() would also collapse the
# class dimension if the model had a single label, breaking the loop below.
probs = probs.squeeze(0)
for i, prob in enumerate(probs):
    print(f"Label {i}: {prob.item():.4f}")

# The predicted label is the index with the highest probability.
predicted_label_id = probs.argmax(-1).item()
print(f"Predicted category: {predicted_label_id}")
\ No newline at end of file
...@@ -2,47 +2,45 @@ import argparse ...@@ -2,47 +2,45 @@ import argparse
import os import os
import sys import sys
CLUSTERS = { REMOTES = {
"bamboo": "hpc-cluster-bamboo", "bamboo": { "remote": "hpc-cluster-bamboo", "destination":"/home/users/p/pavlovii/NCD-Project"},
"baobab": "hpc-cluster-baobab", "baobab": { "remote": "hpc-cluster-baobab", "destination":"/home/users/p/pavlovii/NCD-Project"},
"yggdrasil": "hpc-cluster-yggdrasil" "yggdrasil": { "remote": "hpc-cluster-yggdrasil", "destination":"/home/users/p/pavlovii/NCD-Project"},
"phenix": { "remote": "phenix-service", "destination": "/home/HES/ivan.pavlovic/NCD-Project"},
"laptop": { "remote": "anthoine-laptop", "destination": "/home/guest/Documents/NCD-Project"},
"tower": { "remote": "anthoine", "destination": "/home/guest/Documents/NCD-Project"}
} }
parser = argparse.ArgumentParser(description="Sync project files to an HPC cluster.") parser = argparse.ArgumentParser(description="Sync project files to a remote.")
parser.add_argument("--bamboo", action="store_true", help="Use hpc-cluster-bamboo")
parser.add_argument("--baobab", action="store_true", help="Use hpc-cluster-baobab")
parser.add_argument("--yggdrasil", action="store_true", help="Use hpc-cluster-yggdrasil")
if len(sys.argv) == 1: parser.add_argument(
parser.print_help() "--remote",
exit(1) type=str,
default="bamboo",
help="Specify the remote to use (default: bamboo)"
)
args = parser.parse_args() args = parser.parse_args()
selected_cluster = None selected_remote = REMOTES[args.remote]
for cluster_name, cluster_host in CLUSTERS.items():
if getattr(args, cluster_name):
selected_cluster = cluster_host
break
if not selected_cluster: if selected_remote is None:
print("Error: No cluster selected. Use --bamboo, --baobab, or --yggdrasil.") print(f"Error: No cluster selected. Choises: {[remote for remote in REMOTES.keys()]}")
parser.print_help() parser.print_help()
exit(1) exit(1)
PROJECT_PWD = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) PROJECT_PWD = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
DEST_PATH = f"/home/users/p/pavlovii/NCD-Project"
commands = [ commands = [
f"rclone copy {PROJECT_PWD}/models {selected_cluster}:{DEST_PATH}/models -P", f"rclone copy {PROJECT_PWD}/models {selected_remote["remote"]}:{selected_remote["destination"]}/models -P",
f"rclone copy {PROJECT_PWD}/testModel {selected_cluster}:{DEST_PATH}/testModel -P", f"rclone copy {PROJECT_PWD}/testModel {selected_remote["remote"]}:{selected_remote["destination"]}/testModel -P",
f"rclone copy {PROJECT_PWD}/parsers {selected_cluster}:{DEST_PATH}/parsers -P", f"rclone copy {PROJECT_PWD}/parsers {selected_remote["remote"]}:{selected_remote["destination"]}/parsers -P",
f"rclone copy {PROJECT_PWD}/variables {selected_cluster}:{DEST_PATH}/variables -P", f"rclone copy {PROJECT_PWD}/variables {selected_remote["remote"]}:{selected_remote["destination"]}/variables -P",
f"rclone copy {PROJECT_PWD}/dataSources {selected_cluster}:{DEST_PATH}/dataSources -P" f"rclone copy {PROJECT_PWD}/dataSources {selected_remote["remote"]}:{selected_remote["destination"]}/dataSources -P"
] ]
for cmd in commands: for cmd in commands:
print(f"-> Running: {cmd}") print(f"-> Running: {cmd}")
os.system(cmd) os.system(cmd)
print(f"Files successfully copied to {selected_cluster}:{DEST_PATH}") print(f"Files successfully copied to {selected_remote["remote"]}:{selected_remote["destination"]}")
\ No newline at end of file \ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment