Notes on Transformer & BERT
BERT
- BERT is pretrained with two objectives: masked language modeling and next-sentence prediction.
- [CLS] is prepended to every input and [SEP] marks the end of each sentence/segment (intro & coda); see the sketch below.
- The final hidden state of the [CLS] token is what the classification head uses as input to predict the label.
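A quick sketch of how the tokenizer adds these special tokens (assumes the bert-base-uncased checkpoint; the sample sentences are made up):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer("First sentence.", "Second sentence.")
# convert_ids_to_tokens shows [CLS] at the start and [SEP] after each sentence
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))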
The input embeddings are passed through multiple encoder layers to produce the final hidden states.
- To use the pretrained model for text classification, add a sequence classification head on top of the base BERT model. The sequence classification head is a linear layer that takes the final hidden states and transforms them into logits. The cross-entropy loss is then calculated between the logits and the target to find the most likely label (a minimal sketch follows).
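A plain-PyTorch sketch of that head (hidden size and batch size are illustrative; the real Hugging Face head also applies pooling/dropout):

import torch
import torch.nn as nn

hidden_size, num_labels, batch = 768, 2, 8
classifier = nn.Linear(hidden_size, num_labels)     # the sequence classification head

cls_hidden = torch.randn(batch, hidden_size)        # stand-in for the final [CLS] hidden states
logits = classifier(cls_hidden)                     # shape (batch, num_labels)
labels = torch.randint(0, num_labels, (batch,))
loss = nn.functional.cross_entropy(logits, labels)  # loss between logits and targets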
Text classification with BERT
https://huggingface.co/docs/transformers/tasks/sequence_classification
Load the IMDb dataset first:
from datasets import load_dataset

imdb = load_dataset("imdb")   # IMDb movie-review dataset (train/test/unsupervised splits)
imdb["test"][0]               # {"label": 0, "text": "..."}
tokenizer(examples["text"], truncation=True): truncation=True truncates sequences so they are no longer than DistilBERT's maximum input length (512 tokens).
To apply the preprocessing over the whole dataset, use map(); setting batched=True processes multiple elements of the dataset at once, which speeds up tokenization (see the sketch below).
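A sketch of the preprocessing step (preprocess_function and tokenized_imdb are illustrative names reused later; the checkpoint follows the tutorial):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def preprocess_function(examples):
    # truncate each review to DistilBERT's maximum input length
    return tokenizer(examples["text"], truncation=True)

# batched=True hands map() several examples per call
tokenized_imdb = imdb.map(preprocess_function, batched=True)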
Fine-tuning
Import the required modules:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)
id2label & label2id: maps for translating between ids and labels, e.g.
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
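The Trainer below also expects the data_collator and compute_metrics it references; a minimal sketch following the tutorial (assumes the evaluate and numpy packages are installed):

from transformers import DataCollatorWithPadding
import evaluate
import numpy as np

# dynamically pad each batch to its longest sequence
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)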
Define hyperparameters & trainer:
training_args = TrainingArguments(
output_dir="my_awesome_model",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=2,
weight_decay=0.01,
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
push_to_hub=True,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_imdb["train"],
eval_dataset=tokenized_imdb["test"],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
trainer.train()
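After training, one way to try the fine-tuned model is the pipeline API (a sketch; the local path reuses output_dir above and the review text is made up):

from transformers import pipeline

trainer.save_model("my_awesome_model")   # write the final model (and tokenizer) locally
classifier = pipeline("sentiment-analysis", model="my_awesome_model")
classifier("This was a masterpiece. I loved every minute of it.")
# returns a list like [{"label": "POSITIVE", "score": ...}]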