diff --git a/models/FineTuning/facebook-bart-large-mnli.py b/models/FineTuning/facebook-bart-large-mnli.py
index d68b4bd35884dac985f4a4a4cd85bc2b527b209e..48e7884c9d21abbd5da1f6caba7f7e0e0d4af424 100644
--- a/models/FineTuning/facebook-bart-large-mnli.py
+++ b/models/FineTuning/facebook-bart-large-mnli.py
@@ -21,6 +21,7 @@ from variables.pubmed import NCDS, NCDS_MESH_TERM
 DATASET_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../testModel/dataset"))
 TMP_DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../dataSources/PubMed/tmp"))
 DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./data"))
+LOGS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./logs"))
 
 # Get PMIDs for all the articles from my testing dataset
 dataset_pmids = []
@@ -46,6 +47,8 @@ print(len(data))
 
 articles = [article for article in data if article["PMID"] not in dataset_pmids]
 
+articles = articles[:1000]  # NOTE(review): caps the training set at 1000 articles - confirm this is intentional
+
 print(len(articles))
 
 dataset = []
@@ -89,15 +92,25 @@ print(len(tokenized_dataset))
 training_args = TrainingArguments(
     output_dir = DATA_DIR, # Output directory
     num_train_epochs = 32, # Total number of training epochs
-    per_device_train_batch_size = 16, # Batch size per device during training
+    per_device_train_batch_size = 12, # Batch size per device during training
     warmup_steps = 500, # Number of warmup steps for learning rate scheduler
     weight_decay = 0.01, # Strength of weight decay
+    logging_dir = LOGS_DIR,
+    logging_steps=10,
+    fp16=torch.cuda.is_available(), # fp16 mixed precision needs CUDA; keeps the CPU fallback below usable
+    remove_unused_columns=False,
 )
 
+model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli", num_labels = len(NCDS), ignore_mismatched_sizes=True)
+
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 print(f"Using device: {device}")
 
-model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli", num_labels = len(NCDS), ignore_mismatched_sizes=True).to(device)
+if torch.cuda.device_count() > 1:
+    # Trainer applies DataParallel itself when several GPUs are visible, so the model is not wrapped here
+    print(f"Using {torch.cuda.device_count()} GPUs!")
+
+model.to(device)
 
 trainer = Trainer(
     model = model, # The instantiated model to be trained
@@ -108,6 +121,13 @@ trainer = Trainer(
 
 trainer.train()
 
+# If the model is wrapped in DataParallel, access the underlying model using `model.module`
+model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
+
+# Save the fine-tuned model
+model_to_save.save_pretrained(DATA_DIR)
+tokenizer.save_pretrained(DATA_DIR)
+
 title = "Consumption of Sodium and Its Ratio to Potassium in Relation to All-Cause, Cause-Specific, and Premature Noncommunicable Disease Mortality in Middle-Aged Japanese Adults: A Prospective Cohort Study."
 abstract = "Reducing premature noncommunicable disease (NCD) mortality is a global challenge. Sodium is thought to increase risk of NCDs via an effect of salt per se or high-salt foods on hypertension-induced cardiovascular disease (CVD) and gastrointestinal cancer. Further, relative risk of CVD is reportedly more closely associated with sodium-to-potassium ratio than that with sodium alone. However, few studies have investigated the effect of consumption of sodium or its ratio to consumption of potassium on risk of premature NCD death.We examined associations between intake of sodium and sodium-to-potassium ratio and risk of all-cause and cause-specific death, including premature NCD, in a Japanese prospective cohort study.During 1995-1998, a validated food frequency questionnaire was administered in 11 areas to 83,048 men and women aged 45-74 y. During 1,587,901 person-years of follow-up until the end of 2018, 17,727 all-cause deaths and 3555 premature NCD deaths were identified.Higher sodium intake was significantly associated with increased risk of all-cause and premature NCD mortality, but not all NCD mortality, among men: multivariate hazards ratios for the highest compared with lowest quintiles (HR) were 1.11 (95% CI: 1.03, 1.20; P-trend < 0.01) for all-cause and 1.25 (95% CI: 1.06, 1.47; P-trend < 0.01) for premature NCD mortality. When intakes were expressed as ratio to potassium intake, these associations (HR of all-cause: 1.19, 95% CI: 1.11-1.27; P-trend < 0.01; HR of premature NCD: 1.27, 95% CI: 1.10, 1.46; P-trend < 0.01), including associations with cancers (HR: 1.18, 95% CI: 1.07, 1.31; P-trend = 0.02), were strengthened in men.This prospective cohort study showed that both sodium intake and sodium-to-potassium ratio are associated with increased risk of all-cause and early NCD mortality in middle-aged men."
 sequence = title + abstract
@@ -116,8 +136,17 @@ inputs = tokenizer(sequence, return_tensors="pt").to(device)
 
-outputs = model(**inputs)
+model.eval() # Training leaves the model in train mode; switch dropout off so the probabilities are deterministic
+with torch.no_grad(): # Inference only: no gradients are needed
+    outputs = model(**inputs)
 
-print(outputs)
+# Apply softmax to get probabilities for all the labels
+logits = outputs.logits
+probs = torch.softmax(logits, dim=-1) # Softmax over the last dimension (the logits for each class)
 
-predicted_label_id = outputs.logits.argmax(-1).item()
+# Print out the probabilities for each label
+probs = probs.squeeze() # Remove the extra batch dimension
+for i, prob in enumerate(probs):
+    print(f"Label {i}: {prob.item():.4f}")
 
+# Get the predicted label id
+predicted_label_id = probs.argmax(-1).item()
 print(f"Predicted category: {predicted_label_id}")
\ No newline at end of file
diff --git a/models/FineTuning/job_launch.sbatch b/models/FineTuning/job_launch.sbatch
new file mode 100644
index 0000000000000000000000000000000000000000..468f265ebf0879d4b4709dc948b19b523c7a9dec
--- /dev/null
+++ b/models/FineTuning/job_launch.sbatch
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+#SBATCH --cpus-per-task=4
+#SBATCH --job-name=finetunning
+#SBATCH --time=12:00:00
+#SBATCH --partition=shared-gpu
+#SBATCH --gres=gpu:1,VramPerGpu:25G
+#SBATCH --mem=32G
+#SBATCH --ntasks=1 # One task (a single GPU is requested via --gres)
+#SBATCH --nodes=1 # Use a single node
+
+module load CUDA/12.3.0 GCC/12.2.0 Python/3.10.8 Qt5/5.15.7
+
+source ../../.venv/bin/activate
+
+python3 facebook-bart-large-mnli.py
\ No newline at end of file
diff --git a/models/FineTuning/job_launch.sh b/models/FineTuning/job_launch.sh
deleted file mode 100644
index c44e472cb18bf660b43f74ce0089004b71731770..0000000000000000000000000000000000000000
--- a/models/FineTuning/job_launch.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-#!/bin/bash
-
-#SBATCH --time=01:00:00
-#SBATCH --partition=shared-gpu
-#SBATCH --gpus=1
-
-module load CUDA/12.3.0 GCC/12.2.0 Python/3.10.8 Qt5/5.15.7
-
-source ../../.venv/bin/activate
-
-python3 ./facebook-bart-large-mnli.py
\ No newline at end of file
diff --git a/models/FineTuning/test_model.py b/models/FineTuning/test_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c60cfa9ef823b84edf3ce66d611212e41ff8e974
--- /dev/null
+++ b/models/FineTuning/test_model.py
@@ -0,0 +1,36 @@
+from transformers import BartForSequenceClassification, BartTokenizerFast
+import torch.nn.functional as F
+import os
+import torch
+
+DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./data"))
+
+model = BartForSequenceClassification.from_pretrained(DATA_DIR)
+tokenizer = BartTokenizerFast.from_pretrained(DATA_DIR)
+
+# Now, you can reuse the model for inference
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+model.to(device)
+
+title = "Consumption of Sodium and Its Ratio to Potassium in Relation to All-Cause, Cause-Specific, and Premature Noncommunicable Disease Mortality in Middle-Aged Japanese Adults: A Prospective Cohort Study."
+abstract = "Reducing premature noncommunicable disease (NCD) mortality is a global challenge. Sodium is thought to increase risk of NCDs via an effect of salt per se or high-salt foods on hypertension-induced cardiovascular disease (CVD) and gastrointestinal cancer. Further, relative risk of CVD is reportedly more closely associated with sodium-to-potassium ratio than that with sodium alone. However, few studies have investigated the effect of consumption of sodium or its ratio to consumption of potassium on risk of premature NCD death.We examined associations between intake of sodium and sodium-to-potassium ratio and risk of all-cause and cause-specific death, including premature NCD, in a Japanese prospective cohort study.During 1995-1998, a validated food frequency questionnaire was administered in 11 areas to 83,048 men and women aged 45-74 y. During 1,587,901 person-years of follow-up until the end of 2018, 17,727 all-cause deaths and 3555 premature NCD deaths were identified.Higher sodium intake was significantly associated with increased risk of all-cause and premature NCD mortality, but not all NCD mortality, among men: multivariate hazards ratios for the highest compared with lowest quintiles (HR) were 1.11 (95% CI: 1.03, 1.20; P-trend < 0.01) for all-cause and 1.25 (95% CI: 1.06, 1.47; P-trend < 0.01) for premature NCD mortality. When intakes were expressed as ratio to potassium intake, these associations (HR of all-cause: 1.19, 95% CI: 1.11-1.27; P-trend < 0.01; HR of premature NCD: 1.27, 95% CI: 1.10, 1.46; P-trend < 0.01), including associations with cancers (HR: 1.18, 95% CI: 1.07, 1.31; P-trend = 0.02), were strengthened in men.This prospective cohort study showed that both sodium intake and sodium-to-potassium ratio are associated with increased risk of all-cause and early NCD mortality in middle-aged men."
+sequence = title + abstract
+
+inputs = tokenizer(sequence, return_tensors="pt", truncation=True).to(device)
+
+# Inference only: no gradients are needed, so skip building the autograd graph
+with torch.no_grad():
+    outputs = model(**inputs)
+
+# Apply softmax to get probabilities for all the labels
+logits = outputs.logits
+probs = F.softmax(logits, dim=-1) # Softmax over the last dimension (the logits for each class)
+
+# Print out the probabilities for each label
+probs = probs.squeeze() # Remove the extra batch dimension
+for i, prob in enumerate(probs):
+    print(f"Label {i}: {prob.item():.4f}")
+
+# Get the predicted label id
+predicted_label_id = probs.argmax(-1).item()
+print(f"Predicted category: {predicted_label_id}")
\ No newline at end of file
diff --git a/scripts/copy_files_hpc.py b/scripts/copy_files_hpc.py
index 0dfb323a18b091094279fdddadd08a05dca24c1b..e5f494f67428f92aabb6217d612e00c320654723 100644
--- a/scripts/copy_files_hpc.py
+++ b/scripts/copy_files_hpc.py
@@ -2,47 +2,45 @@ import argparse
 import os
 import sys
 
-CLUSTERS = {
-    "bamboo": "hpc-cluster-bamboo",
-    "baobab": "hpc-cluster-baobab",
-    "yggdrasil": "hpc-cluster-yggdrasil"
+REMOTES = {
+    "bamboo": { "remote": "hpc-cluster-bamboo", "destination":"/home/users/p/pavlovii/NCD-Project"},
+    "baobab": { "remote": "hpc-cluster-baobab", "destination":"/home/users/p/pavlovii/NCD-Project"},
+    "yggdrasil": { "remote": "hpc-cluster-yggdrasil", "destination":"/home/users/p/pavlovii/NCD-Project"},
+    "phenix": { "remote": "phenix-service", "destination": "/home/HES/ivan.pavlovic/NCD-Project"},
+    "laptop": { "remote": "anthoine-laptop", "destination": "/home/guest/Documents/NCD-Project"},
+    "tower": { "remote": "anthoine", "destination": "/home/guest/Documents/NCD-Project"}
 }
 
-parser = argparse.ArgumentParser(description="Sync project files to an HPC cluster.")
-parser.add_argument("--bamboo", action="store_true", help="Use hpc-cluster-bamboo")
-parser.add_argument("--baobab", action="store_true", help="Use hpc-cluster-baobab")
-parser.add_argument("--yggdrasil", action="store_true", help="Use hpc-cluster-yggdrasil")
+parser = argparse.ArgumentParser(description="Sync project files to a remote.")
 
-if len(sys.argv) == 1:
-    parser.print_help()
-    exit(1)
+parser.add_argument(
+    "--remote",
+    type=str,
+    default="bamboo",
+    help="Specify the remote to use (default: bamboo)"
+)
 
 args = parser.parse_args()
 
-selected_cluster = None
-for cluster_name, cluster_host in CLUSTERS.items():
-    if getattr(args, cluster_name):
-        selected_cluster = cluster_host
-        break
+selected_remote = REMOTES.get(args.remote)
 
-if not selected_cluster:
-    print("Error: No cluster selected. Use --bamboo, --baobab, or --yggdrasil.")
+if selected_remote is None:
+    print(f"Error: Unknown remote '{args.remote}'. Choices: {list(REMOTES)}")
     parser.print_help()
     exit(1)
 
 PROJECT_PWD = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
-DEST_PATH = f"/home/users/p/pavlovii/NCD-Project"
 
 commands = [
-    f"rclone copy {PROJECT_PWD}/models {selected_cluster}:{DEST_PATH}/models -P",
-    f"rclone copy {PROJECT_PWD}/testModel {selected_cluster}:{DEST_PATH}/testModel -P",
-    f"rclone copy {PROJECT_PWD}/parsers {selected_cluster}:{DEST_PATH}/parsers -P",
-    f"rclone copy {PROJECT_PWD}/variables {selected_cluster}:{DEST_PATH}/variables -P",
-    f"rclone copy {PROJECT_PWD}/dataSources {selected_cluster}:{DEST_PATH}/dataSources -P"
+    f"rclone copy {PROJECT_PWD}/models {selected_remote['remote']}:{selected_remote['destination']}/models -P",
+    f"rclone copy {PROJECT_PWD}/testModel {selected_remote['remote']}:{selected_remote['destination']}/testModel -P",
+    f"rclone copy {PROJECT_PWD}/parsers {selected_remote['remote']}:{selected_remote['destination']}/parsers -P",
+    f"rclone copy {PROJECT_PWD}/variables {selected_remote['remote']}:{selected_remote['destination']}/variables -P",
+    f"rclone copy {PROJECT_PWD}/dataSources {selected_remote['remote']}:{selected_remote['destination']}/dataSources -P"
 ]
 
 for cmd in commands:
     print(f"-> Running: {cmd}")
     os.system(cmd)
 
-print(f"Files successfully copied to {selected_cluster}:{DEST_PATH}")
\ No newline at end of file
+print(f"Files successfully copied to {selected_remote['remote']}:{selected_remote['destination']}")
\ No newline at end of file