Understand Zero-Shot Learning with Python Text Classification Example
Let's learn this useful machine learning method with a Python example
Machine learning models are developing faster than ever, with new Large Language Model (LLM) announcements arriving every now and then. That is not even counting the large models for other tasks such as computer vision, image generation, audio, and more. We are in the era of large models.
However, researchers and developers can't rely on classic supervised training alone to keep up with the latest trends and deliver stable results, because acquiring labelled data often takes a lot of time and money. That is why many alternative learning methods were developed; one of them is Zero-Shot Learning.
So, what is Zero-Shot Learning, and how do we use it in machine learning development? Let's explore it further.
Zero-Shot Learning
Zero-Shot Learning has many definitions, and the exact meaning is very task-dependent. Generalizing, we can define Zero-Shot Learning as a way for a model to learn from data so that it can classify classes it has not seen before. Basically, the model tries to extrapolate beyond its training data.
In the usual way of training, we must train our model on a specific dataset for the task. With large models, we often use transfer learning, fine-tuning a pre-trained model on our dataset. Zero-Shot Learning, in contrast, tries to perform the task without any task-specific training.
Intuitively, Zero-Shot Learning learns from the seen data and tries to predict the unseen classes using auxiliary information, such as class attributes. According to a paper by Romera-Paredes and Torr (2017), Zero-Shot Learning is inherently a two-stage process:
Training stage, where knowledge about the attributes is captured, and
Inference stage, where this knowledge is used to categorise instances among a new set of classes.
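To make the two stages concrete, here is a toy sketch in Python. It is not the Romera-Paredes and Torr method (which learns a linear mapping in closed form); it swaps in a simpler nearest-centroid attribute predictor. The class names, attributes, and feature vectors are all made up for illustration. The training stage learns attribute predictors from seen classes only, and the inference stage matches the predicted attributes against the attribute signatures of classes that never appeared in training.

```python
# Toy sketch of the two-stage zero-shot process (hypothetical data throughout).
ATTRIBUTES = ["striped", "hooved", "aquatic"]

# Seen classes: we have labelled examples and attribute signatures for these.
SEEN_SIGNATURES = {
    "horse": (0, 1, 0),
    "tiger": (1, 0, 0),
    "dolphin": (0, 0, 1),
    "cat": (0, 0, 0),
}
SEEN_DATA = [
    ([0.1, 0.9, 0.1], "horse"),
    ([0.9, 0.1, 0.1], "tiger"),
    ([0.1, 0.1, 0.9], "dolphin"),
    ([0.1, 0.1, 0.1], "cat"),
]

# Unseen classes: no examples at all, only attribute signatures (the auxiliary information).
UNSEEN_SIGNATURES = {
    "zebra": (1, 1, 0),
    "clownfish": (1, 0, 1),
}


def mean(vectors):
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def train_attribute_models(data, signatures):
    """Training stage: per attribute, store centroids of seen examples with and without it."""
    models = []
    for j in range(len(ATTRIBUTES)):
        pos = [x for x, label in data if signatures[label][j] == 1]
        neg = [x for x, label in data if signatures[label][j] == 0]
        models.append((mean(pos), mean(neg)))
    return models


def attribute_scores(models, x):
    """Positive score means x is closer to the 'has this attribute' centroid."""
    return [sq_dist(x, neg) - sq_dist(x, pos) for pos, neg in models]


def predict_unseen(models, x, unseen_signatures):
    """Inference stage: pick the unseen class whose signature agrees best with the scores."""
    scores = attribute_scores(models, x)

    def agreement(sig):
        return sum(s if bit == 1 else -s for s, bit in zip(scores, sig))

    return max(unseen_signatures, key=lambda name: agreement(unseen_signatures[name]))


models = train_attribute_models(SEEN_DATA, SEEN_SIGNATURES)
print(predict_unseen(models, [0.9, 0.9, 0.1], UNSEEN_SIGNATURES))  # zebra
print(predict_unseen(models, [0.9, 0.1, 0.9], UNSEEN_SIGNATURES))  # clownfish
```

Neither "zebra" nor "clownfish" has a single training example; the model recognizes them only because each attribute was learned separately from seen classes and then recombined.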
I suggest reading the article by Ekin Tiu (2021) to understand Zero-Shot Learning further.
For now, let's continue with our Python tutorial on Zero-Shot Learning. In this example, we will use Task-Aware Representation of Sentences (TARS) from the FlairNLP framework to classify text without much training data.
Zero-Shot Learning Tutorial with FlairNLP
FlairNLP (Flair) is an NLP framework that provides various models for NLP tasks, one of which is the TARS model. TARS lets a model handle various NLP tasks with little or no task-specific training. Let's use it for our example, which is adapted from the FlairNLP Zero-Shot Learning tutorial.
First, we need to install the FlairNLP package.
pip install flair
After the installation, let’s use the TARS Classifier model to perform text classification without training data (Zero-Shot).
Let’s prepare the model and the text example first.
from flair.models import TARSClassifier
from flair.data import Sentence
# load the model
tars = TARSClassifier.load('tars-base')
# prepare text to classify
sentence = Sentence("I don't know what are you talking about!")
Then, we define the labels we want the model to predict. For example, I want the model to classify the text as either “angry” or “happy”.
# define the classes that you want to predict using descriptive names
classes = ["angry", "happy"]
# predict for these classes
tars.predict_zero_shot(sentence, classes, multi_label=False)
# Print sentence with predicted labels
print(sentence)
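To build some intuition for what `predict_zero_shot` does, here is a rough pure-Python sketch of TARS-style label selection. Roughly speaking, TARS pairs each candidate label name with the text and asks one binary classifier whether they match, which is why it can accept labels it was never trained on. The scorer `toy_match_probability` and its cue-word lists are hypothetical stand-ins for the transformer model, and the 0.5 threshold for multi-label mode is an assumption for this sketch. Only the selection logic (the single best label when `multi_label=False`, every label above the threshold otherwise) is meant to mirror the parameter used above.

```python
import math

# Hypothetical cue words standing in for what a trained TARS model has learned.
CUE_WORDS = {
    "angry": {"hate", "furious", "annoying"},
    "happy": {"love", "great", "glad", "thanks"},
}


def toy_match_probability(label: str, text: str) -> float:
    """Hypothetical stand-in for the TARS binary classifier: P(label fits text)."""
    words = {w.strip("!?.,").lower() for w in text.split()}
    overlap = len(words & CUE_WORDS.get(label, set()))
    # squash the cue-word count into (0, 1) with a logistic function
    return 1.0 / (1.0 + math.exp(-(2.0 * overlap - 1.0)))


def predict_zero_shot_sketch(text, candidate_labels, multi_label=False, threshold=0.5):
    """Ask one yes/no question per candidate label, then select labels."""
    scored = [(label, toy_match_probability(label, text)) for label in candidate_labels]
    if multi_label:
        # multi-label: keep every label whose match probability clears the threshold
        return [(label, p) for label, p in scored if p > threshold]
    # single-label: keep only the best-scoring label
    return [max(scored, key=lambda pair: pair[1])]


classes = ["angry", "happy"]
print(predict_zero_shot_sketch("I hate waiting, this is so annoying!", classes))
print(predict_zero_shot_sketch("Thanks, I love it but I hate the price", classes, multi_label=True))
```

Because the candidate labels are plain strings supplied at prediction time, the selection step never depends on which labels were seen during training.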
The result is shown in the image below. Overall, the prediction is reasonable, since the sentence I gave is close to an angry tone.
You can always change the labels, and you can even train the TARS model on your own data with more descriptive labels. For example, here is the code for training the TARS model on question classification (the TREC-6 dataset), taken from the FlairNLP tutorial.
from flair.data import Corpus
from flair.datasets import TREC_6
from flair.models import TARSClassifier
from flair.trainers import ModelTrainer
# 1. define label names in natural language since some datasets come with cryptic set of labels
label_name_map = {'ENTY': 'question about entity',
'DESC': 'question about description',
'ABBR': 'question about abbreviation',
'HUM': 'question about person',
'NUM': 'question about number',
'LOC': 'question about location'
}
# 2. get the corpus
corpus: Corpus = TREC_6(label_name_map=label_name_map)
# 3. what label do you want to predict?
label_type = 'question_class'
# 4. make a label dictionary
label_dict = corpus.make_label_dictionary(label_type=label_type)
# 5. start from our existing TARS base model for English
tars = TARSClassifier.load("tars-base")
# 5a: alternatively, comment out previous line and comment in next line to train a new TARS model from scratch instead
# tars = TARSClassifier(embeddings="bert-base-uncased")
# 6. switch to a new task (TARS can do multiple tasks so you must define one)
tars.add_and_switch_to_new_task(task_name="question classification",
label_dictionary=label_dict,
label_type=label_type,
)
# 7. initialize the text classifier trainer
trainer = ModelTrainer(tars, corpus)
# 8. start the training
trainer.train(base_path='resources/taggers/trec', # path to store the model artifacts
learning_rate=0.02, # use very small learning rate
mini_batch_size=16,
mini_batch_chunk_size=4, # optionally set this if transformer is too much for your machine
max_epochs=1, # terminate after 1 epoch; raise this for a real training run
)
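The `label_name_map` step matters because TARS reads the label text itself: “question about location” gives the model far more to work with than “LOC”. As a rough illustration of how a mapped label can become training data, the sketch below turns one labelled question into binary (label description, text, target) pairs: the correct description as a positive and a few randomly sampled wrong descriptions as negatives, which is roughly the kind of pair TARS learns from. The function `make_tars_pairs` and its `num_negatives` and `seed` parameters are hypothetical, not part of the Flair API.

```python
import random

# Natural-language label names, copied from label_name_map in the training code.
LABEL_NAME_MAP = {
    "ENTY": "question about entity",
    "DESC": "question about description",
    "ABBR": "question about abbreviation",
    "HUM": "question about person",
    "NUM": "question about number",
    "LOC": "question about location",
}


def make_tars_pairs(text, raw_label, label_name_map, num_negatives=2, seed=0):
    """Hypothetical helper: one positive pair plus sampled negative pairs for a labelled text."""
    rng = random.Random(seed)  # seeded so the sampled negatives are reproducible
    positive = label_name_map[raw_label]
    wrong_names = [name for code, name in label_name_map.items() if code != raw_label]
    negatives = rng.sample(wrong_names, k=min(num_negatives, len(wrong_names)))
    # target 1 = "this label fits the text", target 0 = "it does not"
    return [(positive, text, 1)] + [(name, text, 0) for name in negatives]


for label_text, text, target in make_tars_pairs("What is the capital of France?", "LOC", LABEL_NAME_MAP):
    print(target, "|", label_text, "|", text)
```

Framing training this way means the model learns a general “does this description fit this text?” skill rather than a fixed set of output classes, which is what lets the trained model be reused for new labels later.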
With a pre-trained base model available, Zero-Shot Learning makes it much easier to build a classifier without labelled data. I hope you found the explanation and example useful!
If you like this newsletter post, please share, subscribe, and comment to help others find it.