Commit 90b0177b authored by Ivan Pavlovich

Starting to try to fine-tune a zero-shot model for my labels

parent eed59417
@@ -106,8 +106,6 @@ def getPubmedData(term, date_min, date_max, nb_items = -1, debug = False, store
             for part in entrie["MedlineCitation"]["Article"]["Abstract"]["AbstractText"]:
                 if "#text" in part:
                     data["Abstract"] += part["#text"]
-        elif not isinstance(entrie["MedlineCitation"]["Article"]["Abstract"]["AbstractText"], str):
-            data["Abstract"] = entrie["MedlineCitation"]["Article"]["Abstract"]["AbstractText"]["#text"]
         else:
             data["Abstract"] = entrie["MedlineCitation"]["Article"]["Abstract"]["AbstractText"]
# https://medium.com/@lidores98/finetuning-huggingface-facebook-bart-model-2c758472e340
# Copied code: still need to understand it and modify it for my usage
import random

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, load_metric  # note: load_metric moved to the separate "evaluate" package in newer datasets releases
from transformers import (
    BartForSequenceClassification,
    BartTokenizerFast,
    EvalPrediction,
    Trainer,
    TrainingArguments,
    pipeline,
)
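
# NOTE: the names below are defined in earlier cells of the Medium article this
# is copied from; the values here are illustrative placeholders so the snippet
# runs, not the article's actual data or labels.
df = pd.read_csv("data.csv")                       # assumed to have "text" and "class" columns
label_to_int = {"label_a": 0, "label_b": 1, "label_c": 2}  # placeholder label set (>= 3 labels, see note at the model below)
template = "This example is {}."                   # default zero-shot hypothesis template
train_portion = int(len(df) * 0.8)
test_portion = len(df) - train_portion
model_directory = "./bart-finetuned"               # placeholder output directory
sequences = ["An example abstract to classify."]   # placeholder input for the pipeline at the end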
# Split into train and test portions
df_train = df.head(train_portion)
df_test = df.tail(test_portion)
# Convert to Dataset objects
train_ds = Dataset.from_pandas(df_train, split="train")
test_ds = Dataset.from_pandas(df_test, split="test")
tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-large-mnli')
def create_input_sequence(sample):
    # Called with batched = True and batch_size = 1, so "text" is a
    # one-element list and "class" holds the single gold label
    text = sample["text"]
    label = sample["class"][0]
    # Pick a random wrong label to build the contradiction pair
    contradiction_label = random.choice([x for x in label_to_int if x != label])
    # Pair the same premise with an entailed and a contradicted hypothesis
    encoded_sequence = tokenizer(text * 2, [template.format(label), template.format(contradiction_label)], truncation = True, padding = 'max_length')
    # facebook/bart-large-mnli label ids: 0 = contradiction, 1 = neutral, 2 = entailment
    encoded_sequence["labels"] = [2, 0]
    encoded_sequence["input_sentence"] = tokenizer.batch_decode(encoded_sequence.input_ids)
    return encoded_sequence
train_dataset = train_ds.map(create_input_sequence, batched = True, batch_size = 1, remove_columns = ["class", "text"])
test_dataset = test_ds.map(create_input_sequence, batched = True, batch_size = 1, remove_columns = ["class", "text"])
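# Quick sanity check (illustrative, not from the article): each source sample
# maps to two rows, an entailment pair (label 2) and a contradiction pair (label 0)
print(train_dataset[0]["input_sentence"])
print(train_dataset[0]["labels"])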
def compute_metrics(p: EvalPrediction):
    metric_acc = load_metric("accuracy")
    metric_f1 = load_metric("f1")
    # Predictions may arrive as a tuple; the logits are the first element
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(preds, axis = 1)
    result = {}
    result["accuracy"] = metric_acc.compute(predictions = preds, references = p.label_ids)["accuracy"]
    result["f1"] = metric_f1.compute(predictions = preds, references = p.label_ids, average = 'macro')["f1"]
    return result
training_args = TrainingArguments(
    output_dir = model_directory,        # Output directory
    num_train_epochs = 32,               # Total number of training epochs
    per_device_train_batch_size = 16,    # Batch size per device during training
    per_device_eval_batch_size = 64,     # Batch size for evaluation
    warmup_steps = 500,                  # Number of warmup steps for the learning rate scheduler
    weight_decay = 0.01,                 # Strength of weight decay
)
# Caution: the [2, 0] labels above assume the MNLI head layout (3 classes), so
# resizing the head to len(label_to_int) outputs only works with at least 3 labels
model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli", num_labels = len(label_to_int), ignore_mismatched_sizes = True)
trainer = Trainer(
    model = model,                       # The instantiated model to be trained
    args = training_args,                # Training arguments, defined above
    compute_metrics = compute_metrics,   # A function to compute the metrics
    train_dataset = train_dataset,       # Training dataset
    eval_dataset = test_dataset,         # Evaluation dataset
    tokenizer = tokenizer                # The tokenizer that was used
)
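# Run the fine-tuning. The snippet as copied never calls train(); the article
# presumably does this before building the pipeline below.
trainer.train()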
classifier = pipeline("zero-shot-classification", model = model, tokenizer = tokenizer, device = 0)
# Candidate labels are the label names; list() in case label_to_int is a dict
print(classifier(sequences, list(label_to_int), multi_label = False))