Commit c038693b authored by Ivan Pavlovich

DONE: finetuning script for facebook/bart-large-mnli (not tested)

parent ae69dccc
{
    "ALL": {
        "month": {
            "max": 5309,
            "min": 2959,
            "avg": 4457.583333333333
        }
    },
    "\"Noncommunicable Diseases\"[Mesh:noexp]": {
        "month": {
            "max": 59,
            "min": 20,
            "avg": 42.361111111111114
        }
    },
    "\"Diabetes Mellitus\"[Mesh:noexp]": {
        "month": {
            "max": 670,
            "min": 219,
            "avg": 501.4166666666667
        }
    },
    "\"Neoplasms\"[Mesh:noexp]": {
        "month": {
            "max": 2177,
            "min": 1330,
            "avg": 1829.888888888889
        }
    },
    "\"Respiratory Tract Diseases\"[Mesh:noexp]": {
        "month": {
            "max": 40,
            "min": 8,
            "avg": 24.36111111111111
        }
    },
    "\"Cardiovascular Diseases\"[Mesh:noexp]": {
        "month": {
            "max": 819,
            "min": 503,
            "avg": 675.4166666666666
        }
    },
    "\"Mental Health\"[Mesh:noexp]": {
        "month": {
            "max": 686,
            "min": 306,
            "avg": 518.4722222222222
        }
    },
    "\"Diabetes Mellitus, Type 1\"[Mesh:noexp]": {
        "month": {
            "max": 260,
            "min": 141,
            "avg": 219.36111111111111
        }
    },
    "\"Diabetes Mellitus, Type 2\"[Mesh:noexp]": {
        "month": {
            "max": 1061,
            "min": 575,
            "avg": 879.1666666666666
        }
    }
}
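For quick inspection, a minimal sketch of reading these interval statistics back; the path assumes the data/data_num_test.json file written at the end of the counting script below:

import json

# Load the per-term statistics written by the counting script below (path is an assumption).
with open("data/data_num_test.json") as f:
    stats = json.load(f)

# Report the average number of articles per month for each MeSH query.
for term, intervals in stats.items():
    print(f"{term}: avg {intervals['month']['avg']:.1f} articles/month")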
import json
import sys
import os
import statistics
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from requests import get
import time
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from variables.pubmed import PUBMED_API_KEY, NCDS_MESH_TERM
from dataSources.PubMed.util import get_mesh_noexp_term, url_encode
INTERVALS = [
    "day",
    "week",
    "month"
]
def get_count_for_year(year, term, interval = "month"):
    # Note: the window is hardcoded to end-2021 through end-2024 regardless of the `year` argument.
    start_date = datetime(2021, 12, 31)

    counts = []

    while(start_date < datetime(2024, 12, 31)):
        if interval == "day":
            end_date = start_date + timedelta(days=1)
        elif interval == "week":
            end_date = start_date + timedelta(weeks=1)
        elif interval == "month":
            # Jump to the first day of the next month, and start the interval one day
            # past the previous maxdate (mindate/maxdate are inclusive in esearch).
            end_date = (start_date + timedelta(days=32)).replace(day=1)
            start_date += timedelta(days=1)

        print(f"Interval: {start_date.strftime('%Y/%m/%d')} - {end_date.strftime('%Y/%m/%d')}")

        url = f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=pubmed&api_key={PUBMED_API_KEY}&term={term}&retmode=json&mindate={start_date.strftime("%Y/%m/%d")}&maxdate={end_date.strftime("%Y/%m/%d")}'

        response = get(url)
        search_res = response.json()

        counts.append(int(search_res["esearchresult"]["count"]))

        start_date = end_date
        time.sleep(0.1)

    max_count = max(counts)
    min_count = min(counts)
    avg_count = statistics.mean(counts)

    return {"max": max_count, "min": min_count, "avg": avg_count}
data = {}
data["ALL"] = {}

ncds_mesh_terms = get_mesh_noexp_term(NCDS_MESH_TERM)

for ncd in ncds_mesh_terms:
    data[ncd] = {}
    print("NCD: ", ncd)

    for interval in INTERVALS:
        print("INTERVAL: ", interval)
        counts = get_count_for_year(2024, url_encode(ncd), interval)
        print(counts)
        data[ncd][interval] = counts

print("ALL NCDS")
for interval in INTERVALS:
    print("INTERVAL: ", interval)
    counts = get_count_for_year(2024, url_encode(" OR ".join(ncds_mesh_terms)), interval)
    print(counts)
    data["ALL"][interval] = counts

DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./data"))

with open(f"{DATA_DIR}/data_num_test.json", "w") as json_file:
    json.dump(data, json_file, indent=4)
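As a sanity check for the query format, a minimal sketch of a single esearch count request (the term and date range are illustrative; retmax=0 keeps the response to just the count):

from requests import get

# One-off count query against PubMed esearch; without an API key, stay under ~3 requests/s.
url = (
    "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    "?db=pubmed&retmode=json&retmax=0"
    "&term=%22Diabetes+Mellitus%22%5BMesh%3Anoexp%5D"
    "&mindate=2024/01/01&maxdate=2024/01/31"
)
print(get(url).json()["esearchresult"]["count"])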
@@ -3,65 +3,121 @@
 import pandas as pd
 import torch
-from datasets import Dataset, load_metric
+from datasets import Dataset
 import random
 from transformers import BartTokenizerFast
 from transformers import BartForSequenceClassification, Trainer, TrainingArguments, EvalPrediction
 import numpy as np
 from transformers import pipeline
+import os
+import sys
 
-# Split to train and test portions
-df_train = df.head(train_portion)
-df_test = df.tail(test_portion)
-
-# Convert to Dataset objects
-train_ds = Dataset.from_pandas(df_train, split="train")
-test_ds = Dataset.from_pandas(df_test, split="test")
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
+
+from testModel.utils import get_dataset_filename, get_article_data
+from parsers.jsonParser import parseJsonFile
+from variables.pubmed import NCDS, NCDS_MESH_TERM
+
+DATASET_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../testModel/dataset"))
+TMP_DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../dataSources/PubMed/tmp"))
+DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "./data"))
+
+# Get PMIDs for all the articles from the testing dataset
+dataset_pmids = []
+for ncd in NCDS:
+    try:
+        filename = get_dataset_filename(ncd)
+        articles = parseJsonFile(f"{DATASET_DIR}/{filename}.json")
+    except Exception as e:
+        print(f"Error: {e}")
+
+    dataset_pmids += [article["PMID"] for article in articles]
+
+print(len(dataset_pmids))
+
+# Remove articles of the testing dataset from the training data
+try:
+    data = parseJsonFile(f"{TMP_DATA_DIR}/save_3_years.json")
+except Exception as e:
+    print(f"Error: {e}")
+
+print(len(data))
+
+articles = [article for article in data if article["PMID"] not in dataset_pmids]
+
+print(len(articles))
+
+dataset = []
+for article in articles:
+    title, abstract = get_article_data(article)
+    text = title + abstract
+    articles_mesh_terms = [mesh.lower() for mesh in article["MeshTerms"]]
+    for ncd, mesh in NCDS_MESH_TERM.items():
+        if mesh.lower() in articles_mesh_terms:
+            dataset.append({
+                "text": text,
+                "label": ncd
+            })
+
+print(len(dataset))
+
+label2id = {label: i for i, label in enumerate(NCDS)}
+
+train_dataset = Dataset.from_dict({
+    "text": [item["text"] for item in dataset],
+    "label": [label2id[item["label"]] for item in dataset]
+})
 
 tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-large-mnli')
 
-def create_input_sequence(sample):
-    text = sample["text"]
-    label = sample["class"][0]
-    contradiction_label = random.choice([x for x in label_to_int if x != label])
-    encoded_sequence = tokenizer(text * 2, [template.format(label), template.format(contradiction_label)], truncation = True, padding = 'max_length')
-    encoded_sequence["labels"] = [2, 0]
-    encoded_sequence["input_sentence"] = tokenizer.batch_decode(encoded_sequence.input_ids)
-    return encoded_sequence
-
-train_dataset = train_ds.map(create_input_sequence, batched = True, batch_size = 1, remove_columns = ["class", "text"])
-test_dataset = test_ds.map(create_input_sequence, batched = True, batch_size = 1, remove_columns = ["class", "text"])
-
-def compute_metrics(p: EvalPrediction):
-    metric_acc = load_metric("accuracy")
-    metric_f1 = load_metric("f1")
-    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
-    preds = np.argmax(preds, axis = 1)
-    result = {}
-    result["accuracy"] = metric_acc.compute(predictions = preds, references = p.label_ids)["accuracy"]
-    result["f1"] = metric_f1.compute(predictions = preds, references = p.label_ids, average = 'macro')["f1"]
-    return result
+def preprocess_function(examples):
+    return tokenizer(examples["text"], truncation=True, padding="max_length")
+
+tokenized_dataset = train_dataset.map(preprocess_function, batched=True)
+tokenized_dataset = tokenized_dataset.remove_columns(["text"])
+tokenized_dataset = tokenized_dataset.rename_column("label", "labels")
+tokenized_dataset.set_format("torch")
+
+print(len(tokenized_dataset))
 
 training_args = TrainingArguments(
-    output_dir = model_directory,          # Output directory
+    output_dir = DATA_DIR,                 # Output directory
     num_train_epochs = 32,                 # Total number of training epochs
     per_device_train_batch_size = 16,      # Batch size per device during training
-    per_device_eval_batch_size = 64,       # Batch size for evaluation
     warmup_steps = 500,                    # Number of warmup steps for learning rate scheduler
     weight_decay = 0.01,                   # Strength of weight decay
 )
 
-model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli", num_labels = len(label_to_int), ignore_mismatched_sizes = True)
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+print(f"Using device: {device}")
+
+model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli", num_labels = len(NCDS), ignore_mismatched_sizes=True).to(device)
 
 trainer = Trainer(
     model = model,                         # The instantiated model to be trained
     args = training_args,                  # Training arguments, defined above
-    compute_metrics = compute_metrics,     # A function to compute the metrics
-    train_dataset = train_dataset,         # Training dataset
-    eval_dataset = test_dataset,           # Evaluation dataset
+    train_dataset = tokenized_dataset,     # Training dataset
     tokenizer = tokenizer                  # The tokenizer that was used
 )
 
-classifier = pipeline("zero-shot-classification", model = model, tokenizer = tokenizer, device = 0)
-
-classifier(sequences, label_to_int, multi_label=False)
+trainer.train()
+
+title = "Consumption of Sodium and Its Ratio to Potassium in Relation to All-Cause, Cause-Specific, and Premature Noncommunicable Disease Mortality in Middle-Aged Japanese Adults: A Prospective Cohort Study."
+abstract = "Reducing premature noncommunicable disease (NCD) mortality is a global challenge. Sodium is thought to increase risk of NCDs via an effect of salt per se or high-salt foods on hypertension-induced cardiovascular disease (CVD) and gastrointestinal cancer. Further, relative risk of CVD is reportedly more closely associated with sodium-to-potassium ratio than that with sodium alone. However, few studies have investigated the effect of consumption of sodium or its ratio to consumption of potassium on risk of premature NCD death.We examined associations between intake of sodium and sodium-to-potassium ratio and risk of all-cause and cause-specific death, including premature NCD, in a Japanese prospective cohort study.During 1995-1998, a validated food frequency questionnaire was administered in 11 areas to 83,048 men and women aged 45-74 y. During 1,587,901 person-years of follow-up until the end of 2018, 17,727 all-cause deaths and 3555 premature NCD deaths were identified.Higher sodium intake was significantly associated with increased risk of all-cause and premature NCD mortality, but not all NCD mortality, among men: multivariate hazards ratios for the highest compared with lowest quintiles (HR) were 1.11 (95% CI: 1.03, 1.20; P-trend < 0.01) for all-cause and 1.25 (95% CI: 1.06, 1.47; P-trend < 0.01) for premature NCD mortality. When intakes were expressed as ratio to potassium intake, these associations (HR of all-cause: 1.19, 95% CI: 1.11-1.27; P-trend < 0.01; HR of premature NCD: 1.27, 95% CI: 1.10, 1.46; P-trend < 0.01), including associations with cancers (HR: 1.18, 95% CI: 1.07, 1.31; P-trend = 0.02), were strengthened in men.This prospective cohort study showed that both sodium intake and sodium-to-potassium ratio are associated with increased risk of all-cause and early NCD mortality in middle-aged men."
+sequence = title + abstract
+
+inputs = tokenizer(sequence, return_tensors="pt").to(device)
+outputs = model(**inputs)
+print(outputs)
+
+predicted_label_id = outputs.logits.argmax(-1).item()
+print(f"Predicted category: {predicted_label_id}")
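The inference block above prints only the raw class id; a minimal sketch of mapping it back to a readable NCD name, assuming NCDS is the same list of category names used to build label2id:

# Invert the label mapping used during training so the argmax id becomes a category name.
id2label = {i: label for i, label in enumerate(NCDS)}
print(f"Predicted category: {id2label[predicted_label_id]}")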
#!/bin/bash
#SBATCH --time=01:00:00
#SBATCH --partition=shared-gpu
#SBATCH --gpus=1

module load CUDA/12.3.0 GCC/12.2.0 Python/3.10.8 Qt5/5.15.7

source ../../.venv/bin/activate

python3 ./facebook-bart-large-mnli.py
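For reference, the job would typically be submitted with sbatch; the script filename here is illustrative:

sbatch facebook-bart-large-mnli.sh   # queue the job on the shared-gpu partition
squeue -u $USER                      # check the job's state in the queue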
@@ -37,7 +37,8 @@ commands = [
     f"rclone copy {PROJECT_PWD}/models {selected_cluster}:{DEST_PATH}/models -P",
     f"rclone copy {PROJECT_PWD}/testModel {selected_cluster}:{DEST_PATH}/testModel -P",
     f"rclone copy {PROJECT_PWD}/parsers {selected_cluster}:{DEST_PATH}/parsers -P",
-    f"rclone copy {PROJECT_PWD}/variables {selected_cluster}:{DEST_PATH}/variables -P"
+    f"rclone copy {PROJECT_PWD}/variables {selected_cluster}:{DEST_PATH}/variables -P",
+    f"rclone copy {PROJECT_PWD}/dataSources {selected_cluster}:{DEST_PATH}/dataSources -P"
 ]
 
 for cmd in commands:
@@ -6,15 +6,24 @@ import sys
 PROJECT_PWD = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
 DEST_PATH = f"/home/guest/Documents/NCD-Project"
 
+# commands = [
+#     f"rclone copy {PROJECT_PWD}/models anthoine:{DEST_PATH}/models -P",
+#     f"rclone copy {PROJECT_PWD}/testModel anthoine:{DEST_PATH}/testModel -P",
+#     f"rclone copy {PROJECT_PWD}/parsers anthoine:{DEST_PATH}/parsers -P",
+#     f"rclone copy {PROJECT_PWD}/variables anthoine:{DEST_PATH}/variables -P",
+#     f"rclone copy {PROJECT_PWD}/dataSources anthoine:{DEST_PATH}/dataSources -P"
+# ]
+
 commands = [
-    f"rclone copy {PROJECT_PWD}/models anthoine:{DEST_PATH}/models -P",
-    f"rclone copy {PROJECT_PWD}/testModel anthoine:{DEST_PATH}/testModel -P",
-    f"rclone copy {PROJECT_PWD}/parsers anthoine:{DEST_PATH}/parsers -P",
-    f"rclone copy {PROJECT_PWD}/variables anthoine:{DEST_PATH}/variables -P"
+    f"rclone copy {PROJECT_PWD}/models anthoine-laptop:{DEST_PATH}/models -P",
+    f"rclone copy {PROJECT_PWD}/testModel anthoine-laptop:{DEST_PATH}/testModel -P",
+    f"rclone copy {PROJECT_PWD}/parsers anthoine-laptop:{DEST_PATH}/parsers -P",
+    f"rclone copy {PROJECT_PWD}/variables anthoine-laptop:{DEST_PATH}/variables -P",
+    f"rclone copy {PROJECT_PWD}/dataSources anthoine-laptop:{DEST_PATH}/dataSources -P"
 ]
 
 for cmd in commands:
     print(f"-> Running: {cmd}")
     os.system(cmd)
 
-print(f"Files successfully copied to anthoine:{DEST_PATH}")
+print(f"Files successfully copied to anthoine-laptop:{DEST_PATH}")