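# test_model.py (3.07 KiB)
# Loads a BART sequence-classification checkpoint from ./data and prints the
# per-label probabilities and predicted label id for one example title + abstract.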
from transformers import BartForSequenceClassification, BartTokenizerFast
import torch.nn.functional as F
import os
import torch
# Directory holding the saved model and tokenizer files, resolved relative to
# this script rather than the current working directory.
_SCRIPT_DIR = os.path.dirname(__file__)
DATA_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, "./data"))

# Load the BART sequence classifier and its fast tokenizer from DATA_DIR.
model = BartForSequenceClassification.from_pretrained(DATA_DIR)
tokenizer = BartTokenizerFast.from_pretrained(DATA_DIR)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Example paper used as a smoke-test input. Adjacent string literals are joined
# by Python at compile time, so these values are identical to one long literal.
title = (
    "Consumption of Sodium and Its Ratio to Potassium in Relation to All-Cause, "
    "Cause-Specific, and Premature Noncommunicable Disease Mortality in "
    "Middle-Aged Japanese Adults: A Prospective Cohort Study."
)
abstract = (
    "Reducing premature noncommunicable disease (NCD) mortality is a global challenge. "
    "Sodium is thought to increase risk of NCDs via an effect of salt per se or "
    "high-salt foods on hypertension-induced cardiovascular disease (CVD) and "
    "gastrointestinal cancer. "
    "Further, relative risk of CVD is reportedly more closely associated with "
    "sodium-to-potassium ratio than that with sodium alone. "
    "However, few studies have investigated the effect of consumption of sodium "
    "or its ratio to consumption of potassium on risk of premature NCD death."
    "We examined associations between intake of sodium and sodium-to-potassium ratio "
    "and risk of all-cause and cause-specific death, including premature NCD, "
    "in a Japanese prospective cohort study."
    "During 1995-1998, a validated food frequency questionnaire was administered "
    "in 11 areas to 83,048 men and women aged 45-74 y. "
    "During 1,587,901 person-years of follow-up until the end of 2018, "
    "17,727 all-cause deaths and 3555 premature NCD deaths were identified."
    "Higher sodium intake was significantly associated with increased risk of "
    "all-cause and premature NCD mortality, but not all NCD mortality, among men: "
    "multivariate hazards ratios for the highest compared with lowest quintiles (HR) were "
    "1.11 (95% CI: 1.03, 1.20; P-trend < 0.01) for all-cause and "
    "1.25 (95% CI: 1.06, 1.47; P-trend < 0.01) for premature NCD mortality. "
    "When intakes were expressed as ratio to potassium intake, these associations "
    "(HR of all-cause: 1.19, 95% CI: 1.11-1.27; P-trend < 0.01; "
    "HR of premature NCD: 1.27, 95% CI: 1.10, 1.46; P-trend < 0.01), "
    "including associations with cancers (HR: 1.18, 95% CI: 1.07, 1.31; P-trend = 0.02), "
    "were strengthened in men."
    "This prospective cohort study showed that both sodium intake and "
    "sodium-to-potassium ratio are associated with increased risk of all-cause "
    "and early NCD mortality in middle-aged men."
)

# Model input: title immediately followed by abstract.
# NOTE(review): no separator is inserted, so the text reads "...Study.Reducing...".
# This must match whatever preprocessing the checkpoint was trained with. If
# training used a space or a sentence pair, change it here to match.
sequence = "".join((title, abstract))
# Tokenize the title+abstract into a batch of one. truncation=True caps the
# input at tokenizer.model_max_length, so an over-long text is clipped instead
# of overflowing the model's position embeddings and raising during forward.
# NOTE(review): this assumes model_max_length is populated from the checkpoint
# in DATA_DIR. If it is not, the tokenizer warns and does not truncate.
inputs = tokenizer(sequence, return_tensors="pt", truncation=True).to(device)

# Inference only: no_grad() stops autograd from recording the forward pass.
# Otherwise every activation would be kept alive for a backward pass that never runs.
with torch.no_grad():
    outputs = model(**inputs)

# Softmax over the last dimension (the per-class logits) gives probabilities.
logits = outputs.logits
probs = F.softmax(logits, dim=-1)

# Select the single example in the batch. Indexing keeps a 1-D tensor even when
# the model has only one label. squeeze() would instead yield a 0-d tensor,
# which cannot be iterated.
probs = probs[0]
for i, prob in enumerate(probs):
    print(f"Label {i}: {prob.item():.4f}")

# The predicted label id is the index of the highest probability.
predicted_label_id = probs.argmax(-1).item()
print(f"Predicted category: {predicted_label_id}")