
Text Classification


This topic focuses on one of the most popular tasks in NLP: text classification. It helps you assign texts (messages, documents, words) to one of several classes. For instance, emotion recognition is a text classification task because we classify texts into classes such as happy, sad, and so on.

Input and output data

Input for a text classification model is a sequence of words. The sequence can be preprocessed in many ways:

  • One-Hot encoding — a vector indicating the presence of a token in the text. The vector is the size of the vocabulary of the whole dataset.
  • Count-encoding — a vector indicating the frequency of a token in the text.
  • TF-IDF vectorization — a vector with the weighted frequencies of a token in the text and the corpus.
  • Word2Vec or FastText vectors — embeddings of words based on the context around them.
  • Pretrained vectors from a language model — embeddings trained using language modeling techniques.
  • Trained vectors from scratch — custom embeddings trained on a specific corpus.
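
As a quick illustration, here is a minimal sketch of the first three encodings using sklearn; the two-sentence corpus is just a toy example.

```python
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

corpus = ["the cat sat on the mat", "the dog barked at the cat"]

# one-hot encoding: binary presence of each token
one_hot = CountVectorizer(binary=True).fit_transform(corpus)
# count-encoding: raw token frequencies per text
counts = CountVectorizer().fit_transform(corpus)
# TF-IDF: frequencies weighted down for tokens that are common across the corpus
tfidf = TfidfVectorizer().fit_transform(corpus)

print(one_hot.toarray())
print(counts.toarray())
print(tfidf.toarray())
```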

The output of a multi-class text classification model is a single label: each sample is assigned exactly one class. If one text can be marked with more than one label, the output is a vector of labels, and the task is called multi-label classification.
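
For the multi-label case, the label vector can be built with a binarizer; the topic labels below are purely hypothetical.

```python
from sklearn.preprocessing import MultiLabelBinarizer

# hypothetical example: each text carries one or more topic labels
sample_labels = [["sports"], ["politics", "economy"], ["sports", "economy"]]

mlb = MultiLabelBinarizer()
y = mlb.fit_transform(sample_labels)
print(mlb.classes_)  # ['economy' 'politics' 'sports']
print(y)             # one binary indicator vector per text
```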

How to choose the model

To classify texts, we can choose many different models. However, some of them are more suitable in specific settings.

  1. A small amount of data. If you don't have a lot of data, you should first experiment with more basic encoding methods such as one-hot encoding, count-encoding, and TF-IDF vectorization, and then use classical machine learning models such as SVM (Support Vector Machines) or Decision Trees (see the pipeline sketch after this list).
  2. A lot of data. If you have a lot of texts to classify, you should pay more attention to vectorization techniques such as Word2Vec or FastText, or embeddings pre-trained with language models. If the topics in your data are unique, you can even train your vectors from scratch. After vectorization, use neural networks to classify texts; models such as RNNs, CNNs, and Transformer-based models (BERT, RoBERTa) are the most appropriate.
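
For the small-data scenario from point 1, a minimal sketch of such a baseline might look like this (the toy texts and labels are assumptions):

```python
from sklearn.pipeline import make_pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC

# toy data for illustration only
texts = ["great product", "awful quality", "not bad at all", "would not buy again"]
labels = ["pos", "neg", "pos", "neg"]

# TF-IDF features followed by a linear SVM: a strong baseline for small datasets
model = make_pipeline(TfidfVectorizer(ngram_range=(1, 2)), LinearSVC())
model.fit(texts, labels)
print(model.predict(["really great quality"]))
```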

Note that these methods can be combined to achieve higher scores. For instance, you can concatenate BERT embeddings with TF-IDF vectors, as in the sketch below, or use RNN and CNN models together.
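
Below is a rough sketch of the feature-concatenation idea; real BERT sentence embeddings are replaced with random vectors just to keep the example self-contained.

```python
import numpy as np
from scipy.sparse import csr_matrix, hstack
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

texts = ["good service", "terrible support", "okay experience", "awful delay"]
labels = [1, 0, 1, 0]

# stand-in for real BERT sentence embeddings (normally produced by a
# pre-trained encoder); random numbers keep the sketch self-contained
bert_embeddings = np.random.rand(len(texts), 768)

tfidf = TfidfVectorizer().fit_transform(texts)            # sparse TF-IDF matrix
combined = hstack([tfidf, csr_matrix(bert_embeddings)])   # concatenate both feature sets

clf = LogisticRegression(max_iter=1000).fit(combined, labels)
```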

Performance measures

To understand the performance of a trained text classification model, you should use the same metrics as for any other classification pipeline. The most common metrics are as follows:

  1. Accuracy: overall correctness of the model predictions.
  2. Precision: accuracy of positive predictions.
  3. Recall: the ability to find all positive instances.
  4. F1 Score: harmonic mean of precision and recall.
  5. AUC-ROC: area under the ROC curve, suitable for imbalanced datasets.
  6. Confusion Matrix: detailed breakdown of predictions.
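
All of these metrics can be computed in a few lines with sklearn; the predictions below are made up for illustration.

```python
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix, roc_auc_score)

# made-up predictions for a binary task
y_true = [0, 1, 1, 0, 1, 0]
y_pred = [0, 1, 0, 0, 1, 1]
y_score = [0.2, 0.9, 0.4, 0.1, 0.8, 0.6]      # predicted probability of class 1

print(accuracy_score(y_true, y_pred))
print(classification_report(y_true, y_pred))  # precision, recall, and F1 per class
print(confusion_matrix(y_true, y_pred))       # detailed breakdown of predictions
print(roc_auc_score(y_true, y_score))         # area under the ROC curve
```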

Models

In this section, we will cover the most common models for text classification. Most of them can be found in the sklearn library, while others can be implemented with the tensorflow or pytorch libraries.

SVM

The Support Vector Machines algorithm creates a line or a hyperplane that separates the data into classes. It performs exceptionally well with high-dimensional data, such as sparse text vectors, but can become slow to train on very large datasets.

Key Components:

  • Margin: the distance from the separating line (or hyperplane) to the nearest data points. It can be soft or hard, depending on the type of SVM used.
  • Kernel: a function that transforms text data into a higher-dimensional space, capturing complex word relationships.
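
A minimal sketch of an SVM classifier on TF-IDF features, assuming a toy dataset:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import SVC

texts = ["I loved it", "I hated it", "absolutely wonderful", "completely useless"]
labels = [1, 0, 1, 0]

vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(texts)

# the RBF kernel implicitly maps the TF-IDF vectors into a higher-dimensional
# space; C controls how soft the margin is (smaller C = softer margin)
clf = SVC(kernel="rbf", C=1.0).fit(X, labels)
print(clf.predict(vectorizer.transform(["it was wonderful"])))
```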

GBDT

Gradient-Boosted Decision Trees is an algorithm that optimizes the model's predictive performance through an iterative learning process. It combines weak learners (usually decision trees) sequentially, with each new learner focusing on the residuals left by the previous step, which leads to a strong learner.
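
A comparable sketch with sklearn's gradient boosting implementation (the toy data is again an assumption):

```python
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

texts = ["refund please", "great purchase", "item arrived broken", "fast delivery"]
labels = [0, 1, 0, 1]

X = TfidfVectorizer().fit_transform(texts).toarray()  # trees work on dense feature arrays

# each new tree is fitted to the residual errors of the ensemble built so far
clf = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1).fit(X, labels)
```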

RNN + CNN

An RNN is a recurrent neural network architecture that captures long-range dependencies between words in sentences, which makes it well-suited for text classification tasks. Combining an RNN with a CNN can further enhance the model's performance.

Approach:

  • RNN: utilizing RNN with word embeddings as input helps the model understand the sequential nature of the text and extract contextual information effectively.
  • CNN: applying CNN before RNN helps capture local context information, which complements the RNN's ability to grasp long-range dependencies.

You can use only an RNN or a CNN on its own, but their combination, sketched below, is beneficial for capturing both long- and short-range dependencies.
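
One possible way to combine the two, sketched here with tensorflow/Keras; the vocabulary size, sequence length, and number of classes are assumptions:

```python
import tensorflow as tf
from tensorflow.keras import layers

# assumed dataset parameters: a 20,000-token vocabulary, sequences padded
# to 200 tokens, and 3 target classes
vocab_size, max_len, num_classes = 20_000, 200, 3

model = tf.keras.Sequential([
    tf.keras.Input(shape=(max_len,)),
    layers.Embedding(vocab_size, 128),                    # word embeddings
    layers.Conv1D(64, kernel_size=5, activation="relu"),  # local (n-gram) context
    layers.MaxPooling1D(pool_size=2),
    layers.LSTM(64),                                      # long-range dependencies
    layers.Dense(num_classes, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
```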

BERT

BERT and its analogs are popular solutions suitable for various text classification tasks. BERT uses a transformer-based architecture and bidirectional context to capture complex relationships in the text.

Features:

  • Pretrained Model: BERT can be used as a pre-trained encoder followed by a linear classification layer at the end. This approach is simple and effective for many tasks.
  • Custom Implementation: the classification head on top of BERT can be built from scratch and fine-tuned on the specific text classification problem. Additional numerical features extracted from the data can also be incorporated to improve the model's performance.
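
A minimal sketch of the pre-trained approach using the Hugging Face transformers library; the checkpoint name and number of labels are example choices, and the classification head still needs fine-tuning on your own data:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# "bert-base-uncased" and num_labels=2 are assumptions, not requirements
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased",
                                                           num_labels=2)

inputs = tokenizer(["the plot was gripping"], return_tensors="pt",
                   truncation=True, padding=True)
with torch.no_grad():
    logits = model(**inputs).logits        # BERT encoder followed by a linear layer
print(logits.softmax(dim=-1))              # class probabilities
```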

Main issues

In text classification, there are many problems that you can face. We will cover the most common of them.

  1. Dirty data preprocessing. Sometimes the data is dirty: there are many extra characters, such as emojis or hashtags. Classical machine learning algorithms perform better on clean texts, so you should perform tokenization, stop word removal, and stemming to standardize the text. You can also keep elements such as URLs or emojis as additional features, or remove them with regular expressions (see the cleaning sketch after this list).
  2. Ambiguity and polysemy. Some words in your texts can be ambiguous. For example, play can be a verb or a noun. To handle such words, use contextual embeddings like BERT, GPT, or ELMo, which capture word senses better.
  3. Out-of-vocabulary words. Your model will commonly encounter previously unseen words during inference. To get better results for such words, use subword tokenization methods such as Byte-Pair Encoding (BPE) or WordPiece.
  4. Lack of context. Sometimes models do not grasp the main context of the text they need to classify. To better capture the context, use models such as recurrent neural networks (RNNs) or Transformers, or consider language models that generate context-aware word embeddings.
  5. Lack of annotated data. In real-world scenarios, we often don't have a lot of labeled texts. To navigate such situations, use semi-supervised learning or transfer learning to make use of unlabeled data. In addition, you can use data augmentation techniques like synonym replacement, back-translation, or paraphrasing to generate new data.
  6. Model complexity. State-of-the-art models such as BERT can be heavy and require considerable training resources. To use them, look for their smaller versions or quantize them to reduce model size and complexity. Moreover, you can add custom layers and train only those while keeping the main transformer model frozen, so that it is not updated during training.
  7. Model overfitting. Small models can overfit quite frequently, and big deep-learning models also overfit fast on a small amount of data. To overcome this problem, regularize the model using dropout, L1/L2 regularization, or early stopping. Additionally, you can use cross-validation to better analyze the model's learning process.
  8. Handling multiple classes. To work with multiple classes simultaneously, first analyze the loss function you use: for multi-class problems, use appropriate loss functions like categorical cross-entropy; for multi-label problems, consider binary cross-entropy with sigmoid activation. Secondly, analyze how well each class is represented in your data and use class weighting so that the model pays equal attention to every class, or generate additional data for rare classes.
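
As a small illustration of point 1, here is one possible cleaning function; the regular expressions are only a rough sketch, not a complete preprocessing pipeline:

```python
import re

def clean_text(text: str) -> str:
    """A simple cleaning sketch: strip URLs, hashtags, emojis, and punctuation."""
    text = text.lower()
    text = re.sub(r"https?://\S+", " ", text)   # remove URLs
    text = re.sub(r"#\w+", " ", text)           # remove hashtags
    text = re.sub(r"[^a-z\s]", " ", text)       # drop emojis, punctuation, and digits
    return re.sub(r"\s+", " ", text).strip()

print(clean_text("Loved it!!! 😍 #bestday https://example.com"))  # -> "loved it"
```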

Conclusion

Categorizing texts based on their content is a crucial task in NLP, known as text classification. Commonly used models include SVM, GBDT, RNN, CNN, and BERT. However, there are challenges such as preprocessing dirty data, managing ambiguity, and incorporating context. Choosing appropriate techniques and models is essential to improve the performance of a text classification system.
