2020-01-30 | Tobias Sterbak


Data validation for NLP machine learning applications

An important part of machine learning applications is making sure that there is no data degeneration while a model is in production. Sometimes downstream data processing changes silently, and machine learning models are very prone to fail silently because of this. So data validation is a crucial step of every production machine learning pipeline. This is relatively easy for well-specified tabular data, but for NLP it's much harder to write down assumptions about the data and enforce them. I'll show you some approaches to validate text data in machine learning use-cases.

Load the dataset

We will use the Twitter Disaster Dataset from Kaggle. Let's load it and have a quick look.

import numpy as np
import pandas as pd
from pprint import pprint
train = pd.read_csv("data/train.csv")
test = pd.read_csv("data/test.csv")
train.head()

   id keyword location                                               text  target
0   1     NaN      NaN  Our Deeds are the Reason of this #earthquake M...       1
1   4     NaN      NaN             Forest fire near La Ronge Sask. Canada       1
2   5     NaN      NaN  All residents asked to 'shelter in place' are ...       1
3   6     NaN      NaN  13,000 people receive #wildfires evacuation or...       1
4   7     NaN      NaN  Just got sent this photo from Ruby #Alaska as ...       1

Write the schema

We have to properly write down our assumptions about the data we expect. These are called validation constraints. The [marshmallow](https://github.com/marshmallow-code/marshmallow) library is really helpful for defining and validating such schemas.

from marshmallow import Schema, fields, validate, ValidationError
from typing import List

First we create a tokenizer function to split the tweets into tokens. Here we use a very simple approach and just split at whitespace. In practice you would use something more sophisticated (see the sketch after the next code block).

def tokenize(text: str) -> List[str]:
    """Split the given text into tokens by splitting at whitespace."""
    return text.split(" ")
# Build the vocabulary as the set of all tokens seen in the training data.
vocabulary = {token for doc in train.text for token in tokenize(doc)}
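The whitespace tokenizer is good enough for this demonstration. If you want something more robust, a tweet-aware tokenizer could be swapped in; this is a minimal sketch using NLTK's TweetTokenizer, assuming nltk is installed:

from nltk.tokenize import TweetTokenizer

tweet_tokenizer = TweetTokenizer()

def tokenize_robust(text: str) -> List[str]:
    """Tokenize tweets while keeping hashtags, mentions and emoticons intact."""
    return tweet_tokenizer.tokenize(text)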

Next, we write a custom validation constraint function to check whether enough of the words in a text are known to our vocabulary. Note that the vocabulary is a property of the upstream machine learning model.

def validate_in_vocabulary(text: str) -> None:
    """
    Require that at least 10% of the tokens in the text are known to the vocabulary.
    
    Raise a ValidationError listing the unknown words otherwise.
    """
    tokenized_text = tokenize(text)
    unknown_words = [t for t in tokenized_text if t not in vocabulary]
    known_word_score = (len(tokenized_text) - len(unknown_words)) / len(tokenized_text)
    
    if known_word_score < 0.1 and len(unknown_words) > 0:
        raise ValidationError("Too many unknown words: {}".format(unknown_words))

We use our custom vocabulary validator alongside marshmallow's built-in validators to check the length of the tweets. We require tweets to be at least 7 characters long and allow up to 280 characters, the maximum tweet length on Twitter.

class TweetSchema(Schema):
    id = fields.Integer(required=True)
    text = fields.String(validate=[
        validate_in_vocabulary,
        validate.Length(min=7),
        validate.Length(max=280)
    ])
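To sanity-check the schema, we can first try it on a single record. The row below is copied from the training data, so all of its tokens are known and it passes the length constraints:

# load() raises a ValidationError on failure and returns the deserialized data otherwise.
TweetSchema().load({"id": 4, "text": "Forest fire near La Ronge Sask. Canada"})
# {'id': 4, 'text': 'Forest fire near La Ronge Sask. Canada'}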

Validate test data

Now we can use the schema defined above to validate new data inputs.

test_data = test[["id", "text"]].to_dict(orient="records")
try:
    print("## Received {} rows.\n".format(len(test_data)))
    result = TweetSchema(many=True).load(test_data)
except ValidationError as err:
    print("Error log:")
    pprint(err.messages)
    # Keep only the rows that passed validation; err.messages is keyed by row index.
    result = [d for i, d in enumerate(err.valid_data)
              if i not in err.messages]
finally:
    print("\n## Received {} valid rows, removed {} rows"
          .format(len(result), len(test_data) - len(result)))
## Received 3263 rows.

Error log:
{14: {'text': ["Too many unknown words: ['Awesome!']"]},
 367: {'text': ["Too many unknown words: ['Buying', 'MoP', "
                "'http://t.co/tl7o6Zsqzy']"]},
 394: {'text': ["Too many unknown words: ['Name:', 'Chizu\\nGender:', "
                "'Male\\nAge:', '10\\nHair:', 'Red\\nEyes:', 'Pink\\nDere', "
                "'Type:', 'Pasokon\\nBlood', 'Type:', "
                "'O\\nhttp://t.co/cOyPF9ACTd']"]},
 748: {'text': ["Too many unknown words: ['Vamos', 'Newells']"]},
 780: {'text': ['Shorter than minimum length 7.']},
 1433: {'text': ["Too many unknown words: ['@yourgirlhaileyy', "
                 "'leaveevacuateexitbe', 'banished.']"]},
 1551: {'text': ["Too many unknown words: ['@martinsymiguel', "
                 "'@FilipeCoelho92', 'FATALITY']"]},
 1554: {'text': ["Too many unknown words: ['@Blawnndee', 'FATALITY!!!']"]},
 1561: {'text': ["Too many unknown words: ['@GodHunt_sltv', 'FATALITY']"]},
 1824: {'text': ['Shorter than minimum length 7.']},
 1827: {'text': ['Too many unknown words: [\'Graham\', "Phillips\'", '
                 "'Fundraiser', 'Canceled', '@JustGiving', '\\n\\n#fundraise', "
                 "'#Ukraine', '#donbas', '\\n\\nhttp://t.co/HIbEf3CXOX', "
                 "'http://t.co/9crFKQzD52']"]},
 2091: {'text': ["Too many unknown words: ['RETWEET', "
                 "'#FOLLOWnGAIN\\n\\n@TheMGWVboss', '\\n\\n??#FOLLOW??', "
                 "'?~(', '??~\\x89Û¢)~?', '@ynovak', '@IAmMrCash', "
                 "'@Frankies_Style', '@MamanyaDana', '@Mayhem_4U']"]},
 2379: {'text': ["Too many unknown words: ['rainstorm??']"]},
 2567: {'text': ["Too many unknown words: ['@ShojoShit', '/SCREAMS']"]},
 2569: {'text': ["Too many unknown words: ['*aggressively', 'screams*', "
                 "'https://t.co/8bHaejsUUt']"]},
 2571: {'text': ["Too many unknown words: ['@melodores', '@Hozier', "
                 "'*SCREAMS*']"]},
 2614: {'text': ["Too many unknown words: ['Sinking.', '{part', "
                 "'2}\\n??????\\n#buildingmuseum', '#TheBEACHDC', '#VSCOcam', "
                 "'#acreativedc', '#dc', '#dctography', '#vscodc\\x89Û_', "
                 "'https://t.co/SsD9ign6HO']"]},
 2853: {'text': ['Too many unknown words: '
                 "['Truth...\\nhttps://t.co/GLzggDjQeH\\n#News\\n#BBC\\n#CNN\\n#Islam\\n#Truth\\n#god\\n#ISIS\\n#terrorism\\n#Quran\\n#Lies', "
                 "'http://t.co/MYtCbJ6nmh']"]},
 2860: {'text': ['Too many unknown words: '
                 "['Truth...\\nhttps://t.co/Kix1j4ZyGx\\n#News\\n#BBC\\n#CNN\\n#Islam\\n#Truth\\n#god\\n#ISIS\\n#terrorism\\n#Quran\\n#Lies', "
                 "'http://t.co/pi6Qn7y7ql']"]},
 2863: {'text': ['Too many unknown words: '
                 "['Truth...\\nhttps://t.co/n1K5nlib9X\\n#News\\n#BBC\\n#CNN\\n#Islam\\n#Truth\\n#god\\n#ISIS\\n#terrorism\\n#Quran\\n#Lies', "
                 "'http://t.co/CGz84MUOCZ']"]}}

## Received 3243 valid rows, removed 20 rows

This way we get nice output that we could format further and use for logging in our data pipeline or machine learning API. Depending on your application, you have to decide how to handle violations of the validation constraints. For example, you could decide not to make a prediction in such cases, or trigger an alert, as sketched below.
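A minimal sketch of this, assuming Python's standard logging and a hypothetical model object with a predict method, could look like this:

import logging

logger = logging.getLogger("prediction-service")

def predict_validated(rows):
    """Validate incoming rows, alert on violations and only predict on the valid ones."""
    try:
        valid_rows = TweetSchema(many=True).load(rows)
    except ValidationError as err:
        # Trigger an alert, e.g. via a logging handler hooked up to your monitoring.
        logger.warning("Validation failed for row indices: %s",
                       sorted(err.messages.keys()))
        valid_rows = [d for i, d in enumerate(err.valid_data)
                      if i not in err.messages]
    # `model` is a hypothetical stand-in for your trained upstream model.
    return model.predict([row["text"] for row in valid_rows])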

