Skip to content
Snippets Groups Projects
Commit 13372b77 authored by Ivan Pavlovich's avatar Ivan Pavlovich
Browse files

Génération des résultats sur la même machine et création/test du script de fine-tuning multi-label

parent a119304a
Branches
No related tags found
No related merge requests found
Showing
with 322874 additions and 45 deletions
...@@ -47,7 +47,7 @@ print(len(data)) ...@@ -47,7 +47,7 @@ print(len(data))
articles = [article for article in data if article["PMID"] not in dataset_pmids] articles = [article for article in data if article["PMID"] not in dataset_pmids]
articles = articles[:1000] # articles = articles[:1000]
print(len(articles)) print(len(articles))
...@@ -126,25 +126,4 @@ model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else mo ...@@ -126,25 +126,4 @@ model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else mo
# Save the fine-tuned model # Save the fine-tuned model
model_to_save.save_pretrained(DATA_DIR) model_to_save.save_pretrained(DATA_DIR)
tokenizer.save_pretrained(DATA_DIR) tokenizer.save_pretrained(DATA_DIR)
\ No newline at end of file
title = "Consumption of Sodium and Its Ratio to Potassium in Relation to All-Cause, Cause-Specific, and Premature Noncommunicable Disease Mortality in Middle-Aged Japanese Adults: A Prospective Cohort Study."
abstract = "Reducing premature noncommunicable disease (NCD) mortality is a global challenge. Sodium is thought to increase risk of NCDs via an effect of salt per se or high-salt foods on hypertension-induced cardiovascular disease (CVD) and gastrointestinal cancer. Further, relative risk of CVD is reportedly more closely associated with sodium-to-potassium ratio than that with sodium alone. However, few studies have investigated the effect of consumption of sodium or its ratio to consumption of potassium on risk of premature NCD death.We examined associations between intake of sodium and sodium-to-potassium ratio and risk of all-cause and cause-specific death, including premature NCD, in a Japanese prospective cohort study.During 1995-1998, a validated food frequency questionnaire was administered in 11 areas to 83,048 men and women aged 45-74 y. During 1,587,901 person-years of follow-up until the end of 2018, 17,727 all-cause deaths and 3555 premature NCD deaths were identified.Higher sodium intake was significantly associated with increased risk of all-cause and premature NCD mortality, but not all NCD mortality, among men: multivariate hazards ratios for the highest compared with lowest quintiles (HR) were 1.11 (95% CI: 1.03, 1.20; P-trend < 0.01) for all-cause and 1.25 (95% CI: 1.06, 1.47; P-trend < 0.01) for premature NCD mortality. When intakes were expressed as ratio to potassium intake, these associations (HR of all-cause: 1.19, 95% CI: 1.11-1.27; P-trend < 0.01; HR of premature NCD: 1.27, 95% CI: 1.10, 1.46; P-trend < 0.01), including associations with cancers (HR: 1.18, 95% CI: 1.07, 1.31; P-trend = 0.02), were strengthened in men.This prospective cohort study showed that both sodium intake and sodium-to-potassium ratio are associated with increased risk of all-cause and early NCD mortality in middle-aged men."
# Run the model once on the concatenated title + abstract.
sequence = title + abstract
inputs = tokenizer(sequence, return_tensors="pt").to(device)
outputs = model(**inputs)

# Turn the raw logits into a probability distribution over the labels
# (softmax along the class dimension), then drop the batch dimension.
logits = outputs.logits
probs = F.softmax(logits, dim=-1).squeeze()

# Report the probability of every label.
for label_id, label_prob in enumerate(probs.tolist()):
    print(f"Label {label_id}: {label_prob:.4f}")

# The predicted category is the label with the highest probability.
predicted_label_id = probs.argmax(-1).item()
print(f"Predicted category: {predicted_label_id}")
\ No newline at end of file
# https://medium.com/@lidores98/finetuning-huggingface-facebook-bart-model-2c758472e340
import pandas as pd
import torch
from datasets import Dataset
import random
from transformers import BartTokenizerFast
from transformers import BartForSequenceClassification, Trainer, TrainingArguments, EvalPrediction
import numpy as np
from transformers import pipeline
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from testModel.utils import get_dataset_filename, get_article_data
from parsers.jsonParser import parseJsonFile
from variables.pubmed import NCDS, NCDS_MESH_TERM
def _resolve_from_script(relative_path):
    """Return the absolute form of *relative_path*, taken relative to this file's directory."""
    return os.path.abspath(os.path.join(os.path.dirname(__file__), relative_path))


# Directories read and written by this script, all anchored on the script location.
DATASET_DIR = _resolve_from_script("../../testModel/dataset")
TMP_DATA_DIR = _resolve_from_script("../../dataSources/PubMed/tmp")
DATA_DIR = _resolve_from_script("./data")
LOGS_DIR = _resolve_from_script("./logs")
# Collect the PMIDs of every article in the testing dataset so those articles
# can be excluded from the training data.
dataset_pmids = []
for ncd in NCDS:
    try:
        filename = get_dataset_filename(ncd)
        articles = parseJsonFile(f"{DATASET_DIR}/{filename}.json")
    except Exception as e:
        # Skip this NCD. Falling through would re-add the previous iteration's
        # `articles` (or raise NameError if the very first file failed).
        print(f"Error: {e}")
        continue
    dataset_pmids += [article["PMID"] for article in articles]

# Load the raw training data. Nothing below can run without it, so stop with
# the real cause instead of a NameError on `data` one line later.
try:
    data = parseJsonFile(f"{TMP_DATA_DIR}/save_3_years.json")
except Exception as e:
    print(f"Error: {e}")
    sys.exit(1)

# Remove the testing-dataset articles from the training data. A set makes each
# membership test O(1) instead of scanning the whole PMID list per article.
# NOTE(review): assumes PMIDs are hashable scalars (str/int) — confirm.
test_pmids = set(dataset_pmids)
articles = [article for article in data if article["PMID"] not in test_pmids]
# Build the training rows: one text plus one multi-hot label vector per article.
dataset = []
for article in articles:
    title, abstract = get_article_data(article)
    combined_text = title + abstract
    # Lower-case both sides so the MeSH term comparison ignores case.
    lowered_mesh_terms = [term.lower() for term in article["MeshTerms"]]
    # One slot per NCD: 1.0 when that NCD's MeSH term is attached to the article, else 0.0.
    label_vector = [
        float(NCDS_MESH_TERM[ncd].lower() in lowered_mesh_terms)
        for ncd in NCDS
    ]
    dataset.append({"text": combined_text, "labels": label_vector})

# Pivot the rows into the column layout expected by Dataset.from_dict.
text_column = [row["text"] for row in dataset]
labels_column = [row["labels"] for row in dataset]
train_dataset = Dataset.from_dict({"text": text_column, "labels": labels_column})
tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-large-mnli')


def preprocess_function(examples):
    """Tokenize the "text" column of a batch, with truncation and padding="max_length"."""
    batch_texts = examples["text"]
    return tokenizer(batch_texts, truncation=True, padding="max_length")


# Tokenize in batches, drop the raw text column, and expose the rest as torch tensors.
tokenized_dataset = (
    train_dataset
    .map(preprocess_function, batched=True)
    .remove_columns(["text"])
)
tokenized_dataset.set_format("torch")
# Training hyper-parameters; checkpoints go to DATA_DIR and logs to LOGS_DIR.
training_args = TrainingArguments(
    output_dir=DATA_DIR,
    num_train_epochs=32,
    per_device_train_batch_size=12,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir=LOGS_DIR,
    logging_steps=10,
    # Mixed precision is only usable on a GPU. The script falls back to CPU
    # when CUDA is missing, so enable fp16 only when CUDA is actually there.
    fp16=torch.cuda.is_available(),
    # Keep every dataset column instead of letting Trainer prune by signature.
    remove_unused_columns=False,
)

# Start from the MNLI checkpoint and replace its classification head with a
# multi-label head holding one output per NCD.
model = BartForSequenceClassification.from_pretrained(
    "facebook/bart-large-mnli",
    num_labels=len(NCDS),
    problem_type="multi_label_classification",
    # The pretrained head has a different number of outputs than len(NCDS);
    # allow it to be re-initialised instead of failing on the size mismatch.
    ignore_mismatched_sizes=True,
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
if torch.cuda.device_count() > 1:
    # Do not wrap the model in torch.nn.DataParallel here: Trainer already
    # parallelises across the visible GPUs, and handing it a pre-wrapped model
    # would hide the PreTrainedModel interface it relies on.
    print(f"Using {torch.cuda.device_count()} GPUs!")
model.to(device)

trainer = Trainer(
    model=model,                      # The instantiated model to be trained
    args=training_args,               # Training arguments, defined above
    train_dataset=tokenized_dataset,  # Training dataset
    tokenizer=tokenizer,              # The tokenizer that was used
)
trainer.train()

# Defensive unwrap: if the model was wrapped in DataParallel elsewhere, save
# the underlying module so save_pretrained is available.
model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model

# Save the fine-tuned model and its tokenizer side by side in DATA_DIR.
model_to_save.save_pretrained(DATA_DIR)
tokenizer.save_pretrained(DATA_DIR)
\ No newline at end of file
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#SBATCH --job-name=finetunning #SBATCH --job-name=finetunning
#SBATCH --time=12:00:00 #SBATCH --time=12:00:00
#SBATCH --partition=shared-gpu #SBATCH --partition=shared-gpu
#SBATCH --gres=gpu:1,VramPerGpu:25G #SBATCH --gres=gpu:2,VramPerGpu:25G
#SBATCH --mem=32G #SBATCH --mem=32G
#SBATCH --ntasks=1 # One task, multiple GPUs #SBATCH --ntasks=1 # One task, multiple GPUs
#SBATCH --nodes=1 # Use a single node #SBATCH --nodes=1 # Use a single node
......
from transformers import pipeline
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from variables.pubmed import NCDS
# Directory next to this script from which the model and tokenizer are loaded
# (presumably where the fine-tuning script saved them — confirm).
DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./data"))
# Text-classification pipeline using the model and tokenizer stored in DATA_DIR.
# return_all_scores=True requests a score for every label, not only the best one;
# the code below indexes the result as predictions[0][i], so keep this flag as is.
classifier = pipeline("text-classification", model=DATA_DIR, tokenizer=DATA_DIR, return_all_scores=True)
title = "Consumption of Sodium and Its Ratio to Potassium in Relation to All-Cause, Cause-Specific, and Premature Noncommunicable Disease Mortality in Middle-Aged Japanese Adults: A Prospective Cohort Study."
abstract = "Reducing premature noncommunicable disease (NCD) mortality is a global challenge. Sodium is thought to increase risk of NCDs via an effect of salt per se or high-salt foods on hypertension-induced cardiovascular disease (CVD) and gastrointestinal cancer. Further, relative risk of CVD is reportedly more closely associated with sodium-to-potassium ratio than that with sodium alone. However, few studies have investigated the effect of consumption of sodium or its ratio to consumption of potassium on risk of premature NCD death.We examined associations between intake of sodium and sodium-to-potassium ratio and risk of all-cause and cause-specific death, including premature NCD, in a Japanese prospective cohort study.During 1995-1998, a validated food frequency questionnaire was administered in 11 areas to 83,048 men and women aged 45-74 y. During 1,587,901 person-years of follow-up until the end of 2018, 17,727 all-cause deaths and 3555 premature NCD deaths were identified.Higher sodium intake was significantly associated with increased risk of all-cause and premature NCD mortality, but not all NCD mortality, among men: multivariate hazards ratios for the highest compared with lowest quintiles (HR) were 1.11 (95% CI: 1.03, 1.20; P-trend < 0.01) for all-cause and 1.25 (95% CI: 1.06, 1.47; P-trend < 0.01) for premature NCD mortality. When intakes were expressed as ratio to potassium intake, these associations (HR of all-cause: 1.19, 95% CI: 1.11-1.27; P-trend < 0.01; HR of premature NCD: 1.27, 95% CI: 1.10, 1.46; P-trend < 0.01), including associations with cancers (HR: 1.18, 95% CI: 1.07, 1.31; P-trend = 0.02), were strengthened in men.This prospective cohort study showed that both sodium intake and sodium-to-potassium ratio are associated with increased risk of all-cause and early NCD mortality in middle-aged men."
sequence = title + abstract
# Ask the tokenizer to truncate: without it, a title + abstract longer than the
# model's maximum sequence length makes the forward pass fail.
predictions = classifier(sequence, truncation=True)
print(predictions)

# Minimum score for a label to be reported as predicted.
threshold = 0.7
# predictions[0] is the list of per-label results for the single input;
# position i is mapped to NCDS[i].
labels = [NCDS[i] for i, score in enumerate(predictions[0]) if score["score"] > threshold]
print("Predicted Labels:", labels)
\ No newline at end of file
...@@ -2,43 +2,41 @@ import argparse ...@@ -2,43 +2,41 @@ import argparse
import os import os
import sys import sys
CLUSTERS = { REMOTES = {
"bamboo": "hpc-cluster-bamboo", "bamboo": { "remote": "hpc-cluster-bamboo", "destination":"/home/users/p/pavlovii/NCD-Project"},
"baobab": "hpc-cluster-baobab", "baobab": { "remote": "hpc-cluster-baobab", "destination":"/home/users/p/pavlovii/NCD-Project"},
"yggdrasil": "hpc-cluster-yggdrasil" "yggdrasil": { "remote": "hpc-cluster-yggdrasil", "destination":"/home/users/p/pavlovii/NCD-Project"},
"phenix": { "remote": "phenix-service", "destination": "/home/HES/ivan.pavlovic/NCD-Project"},
"laptop": { "remote": "anthoine-laptop", "destination": "/home/guest/Documents/NCD-Project"},
"tower": { "remote": "anthoine", "destination": "/home/guest/Documents/NCD-Project"}
} }
parser = argparse.ArgumentParser(description="Sync project files to an HPC cluster.") parser = argparse.ArgumentParser(description="Sync project files to a remote.")
parser.add_argument("--bamboo", action="store_true", help="Use hpc-cluster-bamboo")
parser.add_argument("--baobab", action="store_true", help="Use hpc-cluster-baobab")
parser.add_argument("--yggdrasil", action="store_true", help="Use hpc-cluster-yggdrasil")
if len(sys.argv) == 1: parser.add_argument(
parser.print_help() "--remote",
exit(1) type=str,
default="bamboo",
help="Specify the remote to use (default: bamboo)"
)
args = parser.parse_args() args = parser.parse_args()
selected_cluster = None selected_remote = REMOTES[args.remote]
for cluster_name, cluster_host in CLUSTERS.items():
if getattr(args, cluster_name):
selected_cluster = cluster_host
break
if not selected_cluster: if selected_remote is None:
print("Error: No cluster selected. Use --bamboo, --baobab, or --yggdrasil.") print(f"Error: No cluster selected. Choises: {[remote for remote in REMOTES.keys()]}")
parser.print_help() parser.print_help()
exit(1) exit(1)
PROJECT_PWD = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) PROJECT_PWD = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
DEST_PATH = f"/home/users/p/pavlovii/NCD-Project"
commands = [ commands = [
f"rclone copy {selected_cluster}:{DEST_PATH}/testModel/results {PROJECT_PWD}/testModel/results -P", f"rclone copy {selected_remote["remote"]}:{selected_remote["destination"]}/testModel/results {PROJECT_PWD}/testModel/results -P",
] ]
for cmd in commands: for cmd in commands:
print(f"-> Running: {cmd}") print(f"-> Running: {cmd}")
os.system(cmd) os.system(cmd)
print(f"Results successfully copied from {selected_cluster}:{DEST_PATH}") print(f"Files successfully copied from {selected_remote["remote"]}:{selected_remote["destination"]}")
\ No newline at end of file \ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment