BERT masked LM training
Initial Setup
I will use a BERT model from Hugging Face's transformers library and a lightweight wrapper over PyTorch
called PyTorch Lightning to avoid writing boilerplate.
!pip install transformers
!pip install pytorch-lightning
To run this over TPUs, the following dependencies are also needed.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
To demonstrate, I'll use a small text corpus, which can be downloaded as follows:
import urllib.request
txt_url = "https://s3.amazonaws.com/text-datasets/nietzsche.txt"
urllib.request.urlretrieve(txt_url, 'train.txt')
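To make sure the download worked, I can peek at the first few lines of the file (an optional check, just for illustration):
# Optional: print the first few lines of the corpus
with open('train.txt') as f:
    for _ in range(3):
        print(f.readline().strip())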
Imports and Configs
import pytorch_lightning as pl
from argparse import Namespace
from transformers import (
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    AdamW,
    DataCollatorForLanguageModeling
)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
args = Namespace()
args.train = "train.txt"
args.max_len = 128
args.model_name = "bert-base-uncased"
args.epochs = 1
args.batch_size = 4
Create Dataloader
The Dataset class reads a text file. Each line in the file forms a single element of the dataset after tokenization with BERT's tokenizer.
tokenizer = BertTokenizer.from_pretrained(args.model_name)
class MaskedLMDataset(Dataset):
    def __init__(self, file, tokenizer):
        self.tokenizer = tokenizer
        self.lines = self.load_lines(file)
        self.ids = self.encode_lines(self.lines)

    def load_lines(self, file):
        # Keep only non-empty, non-whitespace lines
        with open(file) as f:
            lines = [
                line
                for line in f.read().splitlines()
                if (len(line) > 0 and not line.isspace())
            ]
        return lines

    def encode_lines(self, lines):
        # Tokenize all lines at once, adding [CLS]/[SEP] and truncating to max_len
        batch_encoding = self.tokenizer(
            lines, add_special_tokens=True, truncation=True, max_length=args.max_len
        )
        return batch_encoding["input_ids"]

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        return torch.tensor(self.ids[idx], dtype=torch.long)
train_dataset = MaskedLMDataset(args.train, tokenizer)
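As a quick, purely illustrative sanity check, I can look at how many lines made it into the dataset and decode one sample back into text:
# Number of non-empty lines, and what one tokenized example decodes to
print(len(train_dataset))
print(tokenizer.decode(train_dataset[0]))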
A collator function in PyTorch takes a list of elements produced by the dataset class and creates a batch of inputs (and targets). Hugging Face provides a convenient collator which takes a list of input ids from my dataset, masks 15% of the tokens, and creates a batch after appropriate padding.
Targets are created by cloning the input ids. Then, if a token is chosen for masking, the corresponding input id is replaced by either the [MASK] token id (80% chance), a random token id (10% chance), or left unchanged (10% chance). If a token is not chosen for masking, the corresponding target id is set to -100, so that it is ignored during loss calculation.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    collate_fn=data_collator
)
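To see the 80/10/10 masking and the -100 label convention in action, one batch can be pulled from the loader and inspected. This is only a sanity check; the masked positions change on every pass because the masking is random:
batch = next(iter(train_loader))
print(batch["input_ids"].shape)                 # (batch_size, padded sequence length)
print(tokenizer.decode(batch["input_ids"][0]))  # some tokens show up as [MASK]
print(batch["labels"][0])                       # -100 everywhere except the masked positions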
Define model, training step and optimizer
class Bert(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.bert = BertForMaskedLM.from_pretrained(args.model_name)

    def forward(self, input_ids, labels):
        return self.bert(input_ids=input_ids, labels=labels)

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        labels = batch["labels"]
        outputs = self(input_ids=input_ids, labels=labels)
        loss = outputs[0]  # masked-LM loss
        return {"loss": loss}

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=1e-5)
model = Bert()
Train
This is where PyTorch Lightning does an awesome job. Once the model and
data loader are ready, I can train on a CPU, a single GPU, multiple GPUs, a single TPU core, or multiple TPU cores with just two lines of code.
- Initialise the Trainer as per the hardware:
  - CPU: trainer = pl.Trainer(max_epochs=1)
  - GPU (single or multiple): trainer = pl.Trainer(max_epochs=1, gpus=8)
  - Single TPU core: trainer = pl.Trainer(max_epochs=1, tpu_cores=[1])
  - Multiple TPU cores: trainer = pl.Trainer(max_epochs=1, tpu_cores=8)
- Run the fit function.
trainer = pl.Trainer(gpus=1)
trainer.fit(model, train_loader)
Saving and Loading
The weights can be saved and then loaded into a plain PyTorch module for prediction, like this.
torch.save(model.state_dict(), 'saved.bin')
class BertPred(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertForMaskedLM.from_pretrained('bert-base-uncased')

    def forward(self, input_ids, labels=None):
        return self.bert(input_ids=input_ids, labels=labels)
new_model = BertPred()
new_model.load_state_dict(torch.load('saved.bin'))
new_model.eval()
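As a final, illustrative check, the reloaded model can be asked to fill in a masked token. The sentence below is arbitrary, and the predicted word will depend on how long the model was fine-tuned:
# Minimal masked-token prediction sketch (example sentence is arbitrary)
text = "the meaning of life is [MASK] ."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = new_model(input_ids=inputs["input_ids"])[0]  # no labels, so the first output is the logits
mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
predicted_id = logits[0, mask_pos].argmax(dim=-1)
print(tokenizer.decode(predicted_id))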