In this post, I’m going to show you how to use a neural network built with keras together with the LIME algorithm, as implemented in eli5's TextExplainer class. To do this, we will write a scikit-learn-compatible wrapper for a keras bidirectional LSTM model. The wrapper will also handle tokenization and store the vocabulary.
Get data set
from sklearn.datasets import fetch_20newsgroups

categories = ['alt.atheism', 'soc.religion.christian',
              'comp.graphics', 'sci.med']

twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=('headers', 'footers'),
)
twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=('headers', 'footers'),
)
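Headers and footers are stripped so the model can't key on metadata such as email addresses. fetch_20newsgroups returns a Bunch object with the raw posts in .data and integer labels in .target; a quick sanity check on what we just loaded (not strictly necessary, but cheap):

# Class names and number of documents in each split.
print(twenty_train.target_names)
print(len(twenty_train.data), len(twenty_test.data))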
Setup with keras
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import accuracy_score
from keras.models import Model
from keras.layers import Input, Dense, LSTM, Dropout, Embedding, SpatialDropout1D, Bidirectional, concatenate
from keras.layers import GlobalAveragePooling1D, GlobalMaxPooling1D
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from eli5.lime import TextExplainer
import regex as re
import numpy as np


class KerasTextClassifier(BaseEstimator, TransformerMixin):
    '''Wrapper class for keras text classification models that takes raw text as input.'''

    def __init__(self, max_words=30000, input_length=100, emb_dim=20,
                 n_classes=4, epochs=5, batch_size=32):
        # Store every constructor argument under its own name so that
        # BaseEstimator.get_params() (and hence scikit-learn tooling) works.
        self.max_words = max_words
        self.input_length = input_length
        self.emb_dim = emb_dim
        self.n_classes = n_classes
        self.epochs = epochs
        self.batch_size = batch_size
        self.model = self._get_model()
        self.tokenizer = Tokenizer(num_words=self.max_words + 1,
                                   lower=True, split=' ', oov_token="UNK")

    def _get_model(self):
        input_text = Input((self.input_length,))
        # Index 0 is reserved for padding and index max_words + 1 for the
        # out-of-vocabulary token, hence input_dim = max_words + 2.
        text_embedding = Embedding(input_dim=self.max_words + 2,
                                   output_dim=self.emb_dim,
                                   input_length=self.input_length,
                                   mask_zero=False)(input_text)
        text_embedding = SpatialDropout1D(0.5)(text_embedding)
        bilstm = Bidirectional(LSTM(units=32, return_sequences=True,
                                    recurrent_dropout=0.5))(text_embedding)
        # Summarize the sequence with both average and max pooling.
        x = concatenate([GlobalAveragePooling1D()(bilstm),
                         GlobalMaxPooling1D()(bilstm)])
        x = Dropout(0.7)(x)
        x = Dense(512, activation="relu")(x)
        x = Dropout(0.6)(x)
        x = Dense(512, activation="relu")(x)
        x = Dropout(0.5)(x)
        out = Dense(units=self.n_classes, activation="softmax")(x)
        model = Model(input_text, out)
        model.compile(optimizer="adam",
                      loss="sparse_categorical_crossentropy",
                      metrics=["accuracy"])
        return model

    def _get_sequences(self, texts):
        seqs = self.tokenizer.texts_to_sequences(texts)
        return pad_sequences(seqs, maxlen=self.input_length, value=0)

    def _preprocess(self, texts):
        # Replace every digit with the token "DIGIT" so that numbers
        # don't blow up the vocabulary.
        return [re.sub(r"\d", "DIGIT", x) for x in texts]

    def fit(self, X, y):
        '''
        Fit the vocabulary and the model.

        :params:
        X: list of texts.
        y: labels.
        '''
        self.tokenizer.fit_on_texts(self._preprocess(X))
        # Keep only the max_words most frequent words and give the
        # out-of-vocabulary token its own fixed index.
        self.tokenizer.word_index = {e: i for e, i in self.tokenizer.word_index.items()
                                     if i <= self.max_words}
        self.tokenizer.word_index[self.tokenizer.oov_token] = self.max_words + 1
        seqs = self._get_sequences(self._preprocess(X))
        self.model.fit(seqs, y, batch_size=self.batch_size,
                       epochs=self.epochs, validation_split=0.1)
        return self  # scikit-learn convention: fit returns the estimator

    def predict_proba(self, X, y=None):
        seqs = self._get_sequences(self._preprocess(X))
        return self.model.predict(seqs)

    def predict(self, X, y=None):
        return np.argmax(self.predict_proba(X), axis=1)

    def score(self, X, y):
        y_pred = self.predict(X)
        return accuracy_score(y, y_pred)
Using TensorFlow backend.
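A quick check (not part of the original code) that the wrapper really behaves like a scikit-learn estimator. Because every __init__ argument is stored under its own name, BaseEstimator.get_params() can introspect the hyperparameters and clone() can build a fresh, unfitted copy:

from sklearn.base import clone

# All hyperparameters are visible to scikit-learn tooling ...
print(KerasTextClassifier().get_params())
# ... so utilities like clone() work on the wrapper.
fresh_model = clone(KerasTextClassifier())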
text_model = KerasTextClassifier(epochs=20, max_words=20000, input_length=200)
text_model.fit(twenty_train.data, twenty_train.target)
Train on 2031 samples, validate on 226 samples
Epoch 1/20
2031/2031 [==============================] - 21s 10ms/step - loss: 1.3766 - acc: 0.2767 - val_loss: 1.3677 - val_acc: 0.3496
Epoch 2/20
2031/2031 [==============================] - 19s 9ms/step - loss: 1.3355 - acc: 0.3673 - val_loss: 1.2756 - val_acc: 0.4071
Epoch 3/20
2031/2031 [==============================] - 16s 8ms/step - loss: 1.1739 - acc: 0.4407 - val_loss: 0.9695 - val_acc: 0.5398
Epoch 4/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.8664 - acc: 0.6041 - val_loss: 0.7929 - val_acc: 0.5973
Epoch 5/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.6665 - acc: 0.6736 - val_loss: 0.6011 - val_acc: 0.6991
Epoch 6/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.5710 - acc: 0.7282 - val_loss: 0.5969 - val_acc: 0.7212
Epoch 7/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.5109 - acc: 0.7691 - val_loss: 0.5813 - val_acc: 0.7434
Epoch 8/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.4502 - acc: 0.8129 - val_loss: 0.5595 - val_acc: 0.7566
Epoch 9/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.3907 - acc: 0.8237 - val_loss: 0.5311 - val_acc: 0.7788
Epoch 10/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.3382 - acc: 0.8641 - val_loss: 0.6223 - val_acc: 0.7788
Epoch 11/20
2031/2031 [==============================] - 17s 8ms/step - loss: 0.3191 - acc: 0.8764 - val_loss: 0.5141 - val_acc: 0.8186
Epoch 12/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.2929 - acc: 0.8902 - val_loss: 0.5169 - val_acc: 0.8363
Epoch 13/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.2764 - acc: 0.9050 - val_loss: 0.4989 - val_acc: 0.8451
Epoch 14/20
2031/2031 [==============================] - 17s 8ms/step - loss: 0.2509 - acc: 0.9010 - val_loss: 0.5256 - val_acc: 0.8274
Epoch 15/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.2375 - acc: 0.9060 - val_loss: 0.5371 - val_acc: 0.8319
Epoch 16/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.2090 - acc: 0.9222 - val_loss: 0.6175 - val_acc: 0.8186
Epoch 17/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.1660 - acc: 0.9385 - val_loss: 0.6829 - val_acc: 0.8053
Epoch 18/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.1819 - acc: 0.9325 - val_loss: 0.7315 - val_acc: 0.8142
Epoch 19/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.1876 - acc: 0.9261 - val_loss: 0.6830 - val_acc: 0.8319
Epoch 20/20
2031/2031 [==============================] - 16s 8ms/step - loss: 0.1493 - acc: 0.9478 - val_loss: 0.7457 - val_acc: 0.8363
The validation loss bottoms out around epoch 13 and creeps up afterwards while the training accuracy keeps climbing, so the model is starting to overfit. On the held-out test set:

text_model.score(twenty_test.data, twenty_test.target)

0.7316910785619174
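Before turning to explanations, a quick spot check (not in the original post) that the predictions line up with the true labels:

# Predicted vs. true class names for a handful of test documents.
preds = text_model.predict(twenty_test.data[:5])
for pred, true in zip(preds, twenty_test.target[:5]):
    print(twenty_train.target_names[pred], '| true:', twenty_train.target_names[true])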
Now let's see which words the network bases its prediction on. We pick a document from the test set; TextExplainer generates perturbed versions of it (with random words removed), collects the network's predictions for them via predict_proba, and fits an interpretable white-box model on that data:

doc = twenty_test.data[2]
te = TextExplainer(random_state=42)
te.fit(doc, text_model.predict_proba)
te.show_prediction(target_names=twenty_train.target_names)
y=alt.atheism (probability 0.098, score -2.154) top features

Contribution | Feature
---|---
-0.016 | <BIAS>
-2.138 | Highlighted in text (sum)

in <1993apr29.112642.1@vms.ocom.okstate.edu> chorley@vms.ocom.okstate.edu writes: >
as a child i can remember picking up a centipede and getting a rather painful sting, but it quickly subsided. much less painful compared to a bee sting. centipedes have a poison claw (one of the front feet) to stun their prey, but in my single experience it did not have a lot of "bite" to it.

y=comp.graphics (probability 0.000, score -21.081) top features

Contribution | Feature
---|---
-0.090 | <BIAS>
-20.991 | Highlighted in text (sum)

in <1993apr29.112642.1@vms.ocom.okstate.edu> chorley@vms.ocom.okstate.edu writes: >
as a child i can remember picking up a centipede and getting a rather painful sting, but it quickly subsided. much less painful compared to a bee sting. centipedes have a poison claw (one of the front feet) to stun their prey, but in my single experience it did not have a lot of "bite" to it.

y=sci.med (probability 0.902, score 3.084) top features

Contribution | Feature
---|---
+3.504 | Highlighted in text (sum)
-0.420 | <BIAS>

in <1993apr29.112642.1@vms.ocom.okstate.edu> chorley@vms.ocom.okstate.edu writes: >
as a child i can remember picking up a centipede and getting a rather painful sting, but it quickly subsided. much less painful compared to a bee sting. centipedes have a poison claw (one of the front feet) to stun their prey, but in my single experience it did not have a lot of "bite" to it.

y=soc.religion.christian (probability 0.000, score -14.546) top features

Contribution | Feature
---|---
-0.038 | <BIAS>
-14.508 | Highlighted in text (sum)

in <1993apr29.112642.1@vms.ocom.okstate.edu> chorley@vms.ocom.okstate.edu writes: >
as a child i can remember picking up a centipede and getting a rather painful sting, but it quickly subsided. much less painful compared to a bee sting. centipedes have a poison claw (one of the front feet) to stun their prey, but in my single experience it did not have a lot of "bite" to it.
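The highlighted words push the prediction strongly towards sci.med, which matches the network's output. Since TextExplainer explains the network through a white-box surrogate trained on perturbed copies of the document, it is worth checking how faithful that surrogate is. eli5 exposes this through the explainer's metrics_ attribute; mean_KL_divergence should be close to 0 and score close to 1 (the exact values will vary between runs):

# Fidelity of the local white-box surrogate to the LSTM's predictions.
print(te.metrics_)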