BERT Example (Spam or Ham Classification)#
[1]:
!pip install -q tensorflow-text
|████████████████████████████████| 4.9 MB 6.7 MB/s
[4]:
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
import csv
import matplotlib.pyplot as plt
import datetime
%matplotlib inline
%load_ext tensorboard
The tensorboard extension is already loaded. To reload it, use:
%reload_ext tensorboard
[5]:
def check_gpu():
    """Report whether TensorFlow can see a GPU.

    Returns:
        str: 'Found GPU at: <name>' when ``tf.test.gpu_device_name()`` returns
        '/device:GPU:0', otherwise 'GPU device not found'.
    """
    device_name = tf.test.gpu_device_name()
    # Any name other than the first-GPU name is reported as "not found".
    if device_name != '/device:GPU:0':
        return 'GPU device not found'
    return 'Found GPU at: {}'.format(device_name)
# Last expression in the cell, so the returned status string is shown as output.
check_gpu()
[5]:
'Found GPU at: /device:GPU:0'
[7]:
# One TensorBoard run directory per execution, named after the current time.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/" + run_stamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)
Loading URLs#
[8]:
# TF Hub handle of the BERT encoder (per the handle name: English, uncased, L-12 / H-768 / A-12, version 4).
encoder_url = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"
# TF Hub handle of the uncased-English BERT preprocessing model (version 3).
preprocessing_url = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
Loading data#
[9]:
# Tab-separated file; column names are supplied here, and QUOTE_NONE makes the
# parser treat quote characters inside messages as ordinary text.
df = pd.read_csv(
    "./SMSSpamCollection.csv",
    sep='\t',
    names=["label", "message"],
    quoting=csv.QUOTE_NONE,
)
df.head()
[9]:
label | message | |
---|---|---|
0 | ham | Go until jurong point, crazy.. Available only ... |
1 | ham | Ok lar... Joking wif u oni... |
2 | spam | Free entry in 2 a wkly comp to win FA Cup fina... |
3 | ham | U dun say so early hor... U c already then say... |
4 | ham | Nah I don't think he goes to usf, he lives aro... |
Checking Class Count#
[10]:
# Number of messages per class in the raw data.
df.label.value_counts()
[10]:
ham 4827
spam 747
Name: label, dtype: int64
Treating Imbalanced Class#
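The classes are imbalanced (4,827 ham vs. 747 spam). There are two common ways to handle this: upsampling the minority class or downsampling the majority class. Here we downsample ham to match the number of spam messages.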
[11]:
# Split the frame by class so each class can be resampled on its own.
spam_df = df.loc[df['label'] == 'spam']
ham_df = df.loc[df['label'] == 'ham']
[12]:
# Downsample ham to the number of spam rows so both classes are equally
# represented. random_state pins the draw so a Restart & Run All reproduces the
# same subset.
ham_df = ham_df.sample(n=len(spam_df), random_state=42)
[13]:
# Combine both classes and shuffle the rows (frac=1 returns every row in random
# order). The shuffle is seeded because model.fit's validation_split holds out
# the trailing rows, so the row order decides the train/validation split.
new_dataset = pd.concat([spam_df, ham_df]).sample(frac=1, random_state=42)
new_dataset.head()
[13]:
label | message | |
---|---|---|
1134 | ham | As I entered my cabin my PA said, '' Happy B'd... |
5173 | ham | Oh k. . I will come tomorrow |
4077 | spam | 87077: Kick off a new season with 2wks FREE go... |
3189 | spam | This is the 2nd time we have tried 2 contact u... |
4834 | spam | New Mobiles from 2004, MUST GO! Txt: NOKIA to ... |
[14]:
# Class counts after balancing.
new_dataset['label'].value_counts()
[14]:
ham 747
spam 747
Name: label, dtype: int64
One-hot encoding the target class#
[15]:
# One indicator column per class: 1 where the label matches, 0 otherwise.
# A vectorized comparison replaces the row-by-row .apply(lambda ...).
new_dataset['spam'] = (new_dataset['label'] == 'spam').astype('int64')
new_dataset['ham'] = (new_dataset['label'] == 'ham').astype('int64')
new_dataset.head()
[15]:
label | message | spam | ham | |
---|---|---|---|---|
1134 | ham | As I entered my cabin my PA said, '' Happy B'd... | 0 | 1 |
5173 | ham | Oh k. . I will come tomorrow | 0 | 1 |
4077 | spam | 87077: Kick off a new season with 2wks FREE go... | 1 | 0 |
3189 | spam | This is the 2nd time we have tried 2 contact u... | 1 | 0 |
4834 | spam | New Mobiles from 2004, MUST GO! Txt: NOKIA to ... | 1 | 0 |
Loading BERT Preprocessor and Encoder#
[16]:
# Text -> BERT input tensors (word ids, type ids, mask).
preprocessor = hub.KerasLayer(handle=preprocessing_url)
# BERT encoder; trainable=False is the KerasLayer default, spelled out to make
# clear that the encoder weights stay frozen.
bert_encoder = hub.KerasLayer(handle=encoder_url, trainable=False)
[17]:
def get_embeddings(sentences):
    """Return the BERT pooled-output embedding for each input sentence.

    Args:
        sentences: the batch of raw text passed to the module-level
            ``preprocessor``.

    Returns:
        The ``'pooled_output'`` entry of the dict returned by ``bert_encoder``
        for the preprocessed batch.
    """
    preprocessed_values = preprocessor(sentences)
    return bert_encoder(preprocessed_values)['pooled_output']
[18]:
# Three short probe sentences used to sanity-check the embeddings.
sample_sentences = [
    'I am an artist',
    'He was writing',
    'I paint everyday',
]
embeddings = get_embeddings(sample_sentences)
[19]:
from sklearn.metrics.pairwise import cosine_similarity
[20]:
# Cosine similarity between sentence 0 ('I am an artist') and sentence 1 ('He was writing').
cosine_similarity([embeddings[0]],[embeddings[1]])
[20]:
array([[0.7917898]], dtype=float32)
[21]:
# Cosine similarity between sentence 1 ('He was writing') and sentence 2 ('I paint everyday').
cosine_similarity([embeddings[1]],[embeddings[2]])
[21]:
array([[0.770146]], dtype=float32)
[22]:
# Cosine similarity between sentence 0 ('I am an artist') and sentence 2 ('I paint everyday').
cosine_similarity([embeddings[0]],[embeddings[2]])
[22]:
array([[0.99068]], dtype=float32)
Building Model#
Using the TensorFlow Keras functional API:
The Keras functional API is a way to create models that are more flexible than the tf.keras.Sequential API. The functional API can handle models with non-linear topology, shared layers, and even multiple inputs or outputs.
[23]:
# Raw strings enter the graph as one string per example.
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='Input-Text')

# Preprocess the text, run the BERT encoder, and keep its pooled output.
preprocessed_text = preprocessor(text_input)
encoded_text = bert_encoder(preprocessed_text)
pooled_output = encoded_text['pooled_output']

# Classification head: dropout -> 64-unit ReLU layer -> 2-way softmax.
dropout_layer = tf.keras.layers.Dropout(0.2, name='Dropout-Layer')
hidden_layer = tf.keras.layers.Dense(units=64, activation='relu', name='Relu-Layer')
softmax_layer = tf.keras.layers.Dense(units=2, activation='softmax', name='Output-Layer')

dropout_output = dropout_layer(pooled_output)
relu_output = hidden_layer(dropout_output)
output = softmax_layer(relu_output)

model = tf.keras.Model(inputs=[text_input], outputs=[output])
[24]:
# Print the layer-by-layer architecture and the trainable / non-trainable parameter counts.
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Input-Text (InputLayer) [(None,)] 0 []
keras_layer (KerasLayer) {'input_word_ids': 0 ['Input-Text[0][0]']
(None, 128),
'input_type_ids':
(None, 128),
'input_mask': (Non
e, 128)}
keras_layer_1 (KerasLayer) {'encoder_outputs': 109482241 ['keras_layer[0][0]',
[(None, 128, 768), 'keras_layer[0][1]',
(None, 128, 768), 'keras_layer[0][2]']
(None, 128, 768),
(None, 128, 768),
(None, 128, 768),
(None, 128, 768),
(None, 128, 768),
(None, 128, 768),
(None, 128, 768),
(None, 128, 768),
(None, 128, 768),
(None, 128, 768)],
'default': (None,
768),
'pooled_output': (
None, 768),
'sequence_output':
(None, 128, 768)}
Dropout-Layer (Dropout) (None, 768) 0 ['keras_layer_1[0][13]']
Relu-Layer (Dense) (None, 64) 49216 ['Dropout-Layer[0][0]']
Output-Layer (Dense) (None, 2) 130 ['Relu-Layer[0][0]']
==================================================================================================
Total params: 109,531,587
Trainable params: 49,346
Non-trainable params: 109,482,241
__________________________________________________________________________________________________
[25]:
# Draw the model graph; show_layer_activations=True asks for each layer's activation to be shown.
tf.keras.utils.plot_model(model,show_layer_activations=True)
[25]:
[26]:
# Labels are one-hot over the columns ['spam', 'ham'], so output index 0 is the
# spam class. Without class_id, Precision and Recall pool the decisions of both
# softmax outputs, which for a 2-way one-hot problem reproduces accuracy (all
# three numbers come out identical every epoch). Pinning class_id reports
# precision and recall for the spam class specifically.
SPAM_CLASS_ID = 0
METRICS = [
    tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision', class_id=SPAM_CLASS_ID),
    tf.keras.metrics.Recall(name='recall', class_id=SPAM_CLASS_ID),
]
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=METRICS,
)
[27]:
# Open TensorBoard inline on the parent directory under which log_dir is created.
%tensorboard --logdir logs/fit
[28]:
# Fit on the message text with the one-hot ['spam', 'ham'] targets; 10% of the
# rows are held out for validation and every epoch is logged to TensorBoard.
history = model.fit(
    x=new_dataset['message'],
    y=new_dataset[['spam', 'ham']],
    batch_size=32,
    epochs=30,
    validation_split=0.1,
    shuffle=True,
    callbacks=[tensorboard_callback],
)
Epoch 1/30
42/42 [==============================] - 47s 898ms/step - loss: 0.5015 - accuracy: 0.7500 - precision: 0.7500 - recall: 0.7500 - val_loss: 0.2910 - val_accuracy: 0.9133 - val_precision: 0.9133 - val_recall: 0.9133
Epoch 2/30
42/42 [==============================] - 37s 902ms/step - loss: 0.3127 - accuracy: 0.8705 - precision: 0.8705 - recall: 0.8705 - val_loss: 0.2676 - val_accuracy: 0.9067 - val_precision: 0.9067 - val_recall: 0.9067
Epoch 3/30
42/42 [==============================] - 34s 823ms/step - loss: 0.2545 - accuracy: 0.9010 - precision: 0.9010 - recall: 0.9010 - val_loss: 0.1959 - val_accuracy: 0.9333 - val_precision: 0.9333 - val_recall: 0.9333
Epoch 4/30
42/42 [==============================] - 38s 911ms/step - loss: 0.2192 - accuracy: 0.9204 - precision: 0.9204 - recall: 0.9204 - val_loss: 0.1943 - val_accuracy: 0.9400 - val_precision: 0.9400 - val_recall: 0.9400
Epoch 5/30
42/42 [==============================] - 37s 890ms/step - loss: 0.2235 - accuracy: 0.9167 - precision: 0.9167 - recall: 0.9167 - val_loss: 0.1607 - val_accuracy: 0.9400 - val_precision: 0.9400 - val_recall: 0.9400
Epoch 6/30
42/42 [==============================] - 35s 851ms/step - loss: 0.1903 - accuracy: 0.9278 - precision: 0.9278 - recall: 0.9278 - val_loss: 0.1431 - val_accuracy: 0.9467 - val_precision: 0.9467 - val_recall: 0.9467
Epoch 7/30
42/42 [==============================] - 38s 906ms/step - loss: 0.1804 - accuracy: 0.9338 - precision: 0.9338 - recall: 0.9338 - val_loss: 0.1351 - val_accuracy: 0.9467 - val_precision: 0.9467 - val_recall: 0.9467
Epoch 8/30
42/42 [==============================] - 38s 914ms/step - loss: 0.1627 - accuracy: 0.9427 - precision: 0.9427 - recall: 0.9427 - val_loss: 0.1531 - val_accuracy: 0.9533 - val_precision: 0.9533 - val_recall: 0.9533
Epoch 9/30
42/42 [==============================] - 37s 894ms/step - loss: 0.1891 - accuracy: 0.9308 - precision: 0.9308 - recall: 0.9308 - val_loss: 0.1272 - val_accuracy: 0.9533 - val_precision: 0.9533 - val_recall: 0.9533
Epoch 10/30
42/42 [==============================] - 39s 930ms/step - loss: 0.1659 - accuracy: 0.9382 - precision: 0.9382 - recall: 0.9382 - val_loss: 0.1206 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 11/30
42/42 [==============================] - 37s 896ms/step - loss: 0.1520 - accuracy: 0.9494 - precision: 0.9494 - recall: 0.9494 - val_loss: 0.1295 - val_accuracy: 0.9600 - val_precision: 0.9600 - val_recall: 0.9600
Epoch 12/30
42/42 [==============================] - 40s 964ms/step - loss: 0.1570 - accuracy: 0.9479 - precision: 0.9479 - recall: 0.9479 - val_loss: 0.1130 - val_accuracy: 0.9600 - val_precision: 0.9600 - val_recall: 0.9600
Epoch 13/30
42/42 [==============================] - 40s 963ms/step - loss: 0.1440 - accuracy: 0.9516 - precision: 0.9516 - recall: 0.9516 - val_loss: 0.1106 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 14/30
42/42 [==============================] - 40s 962ms/step - loss: 0.1413 - accuracy: 0.9524 - precision: 0.9524 - recall: 0.9524 - val_loss: 0.1045 - val_accuracy: 0.9733 - val_precision: 0.9733 - val_recall: 0.9733
Epoch 15/30
42/42 [==============================] - 38s 928ms/step - loss: 0.1528 - accuracy: 0.9494 - precision: 0.9494 - recall: 0.9494 - val_loss: 0.1398 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 16/30
42/42 [==============================] - 39s 954ms/step - loss: 0.1319 - accuracy: 0.9494 - precision: 0.9494 - recall: 0.9494 - val_loss: 0.1044 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 17/30
42/42 [==============================] - 40s 977ms/step - loss: 0.1409 - accuracy: 0.9509 - precision: 0.9509 - recall: 0.9509 - val_loss: 0.1031 - val_accuracy: 0.9733 - val_precision: 0.9733 - val_recall: 0.9733
Epoch 18/30
42/42 [==============================] - 40s 958ms/step - loss: 0.1509 - accuracy: 0.9487 - precision: 0.9487 - recall: 0.9487 - val_loss: 0.1307 - val_accuracy: 0.9600 - val_precision: 0.9600 - val_recall: 0.9600
Epoch 19/30
42/42 [==============================] - 37s 894ms/step - loss: 0.1427 - accuracy: 0.9487 - precision: 0.9487 - recall: 0.9487 - val_loss: 0.0957 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 20/30
42/42 [==============================] - 39s 945ms/step - loss: 0.1364 - accuracy: 0.9479 - precision: 0.9479 - recall: 0.9479 - val_loss: 0.1113 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 21/30
42/42 [==============================] - 38s 930ms/step - loss: 0.1402 - accuracy: 0.9516 - precision: 0.9516 - recall: 0.9516 - val_loss: 0.1447 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 22/30
42/42 [==============================] - 40s 955ms/step - loss: 0.1229 - accuracy: 0.9591 - precision: 0.9591 - recall: 0.9591 - val_loss: 0.1055 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 23/30
42/42 [==============================] - 40s 969ms/step - loss: 0.1334 - accuracy: 0.9516 - precision: 0.9516 - recall: 0.9516 - val_loss: 0.0972 - val_accuracy: 0.9733 - val_precision: 0.9733 - val_recall: 0.9733
Epoch 24/30
42/42 [==============================] - 41s 990ms/step - loss: 0.1255 - accuracy: 0.9568 - precision: 0.9568 - recall: 0.9568 - val_loss: 0.0941 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 25/30
42/42 [==============================] - 40s 959ms/step - loss: 0.1286 - accuracy: 0.9531 - precision: 0.9531 - recall: 0.9531 - val_loss: 0.0980 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 26/30
42/42 [==============================] - 38s 927ms/step - loss: 0.1511 - accuracy: 0.9435 - precision: 0.9435 - recall: 0.9435 - val_loss: 0.2420 - val_accuracy: 0.9133 - val_precision: 0.9133 - val_recall: 0.9133
Epoch 27/30
42/42 [==============================] - 37s 904ms/step - loss: 0.1655 - accuracy: 0.9382 - precision: 0.9382 - recall: 0.9382 - val_loss: 0.1260 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 28/30
42/42 [==============================] - 38s 916ms/step - loss: 0.1237 - accuracy: 0.9561 - precision: 0.9561 - recall: 0.9561 - val_loss: 0.0966 - val_accuracy: 0.9733 - val_precision: 0.9733 - val_recall: 0.9733
Epoch 29/30
42/42 [==============================] - 38s 908ms/step - loss: 0.1300 - accuracy: 0.9583 - precision: 0.9583 - recall: 0.9583 - val_loss: 0.0945 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 30/30
42/42 [==============================] - 39s 936ms/step - loss: 0.1284 - accuracy: 0.9516 - precision: 0.9516 - recall: 0.9516 - val_loss: 0.1394 - val_accuracy: 0.9733 - val_precision: 0.9733 - val_recall: 0.9733
[29]:
# One row per epoch: every logged metric plus the epoch index as the last column.
history_df = pd.DataFrame({**history.history, 'epochs': history.epoch})
[30]:
# Peek at the first three epochs.
history_df.iloc[:3]
[30]:
loss | accuracy | precision | recall | val_loss | val_accuracy | val_precision | val_recall | epochs | |
---|---|---|---|---|---|---|---|---|---|
0 | 0.501539 | 0.750000 | 0.750000 | 0.750000 | 0.290969 | 0.913333 | 0.913333 | 0.913333 | 0 |
1 | 0.312725 | 0.870536 | 0.870536 | 0.870536 | 0.267604 | 0.906667 | 0.906667 | 0.906667 | 1 |
2 | 0.254531 | 0.901042 | 0.901042 | 0.901042 | 0.195910 | 0.933333 | 0.933333 | 0.933333 | 2 |
[31]:
# 2x3 grid of training curves; each entry lists the history columns drawn in
# one panel, filled row by row.
fig, ax = plt.subplots(2, 3, figsize=(15, 10))
panels = [
    ['loss', 'val_loss'],
    ['accuracy', 'val_accuracy'],
    ['precision', 'val_precision'],
    ['recall', 'val_recall'],
    ['precision', 'recall'],
    ['val_precision', 'val_recall'],
]
for panel_ax, columns in zip(ax.flat, panels):
    history_df.plot(x='epochs', y=columns, ax=panel_ax)
plt.show()
[34]:
# Hand-written sample messages used to spot-check the trained model.
reviews = [
'Enter a chance to win $5000, hurry up, offer valid until march 31, 2021',
'You are awarded a SiPix Digital Camera! call 09061221061 from landline. Delivery within 28days. T Cs Box177. M221BP. 2yr warranty. 150ppm. 16 . p p£3.99',
'it to 80488. Your 500 free text messages are valid until 31 December 2005.',
'Hey Sam, Are you coming for a cricket game tomorrow',
"Why don't you wait 'til at least wednesday to see if you get your ."
]
# One row of softmax outputs per message; columns follow the training target order ['spam', 'ham'].
preds = model.predict(reviews)
[36]:
# Put each message next to its predicted class probabilities
# (column 0 of preds is spam, column 1 is ham).
reviews_df = pd.DataFrame({
    'spam': preds[:, 0],
    'ham': preds[:, 1],
    'message': reviews,
})
reviews_df
[36]:
spam | ham | message | |
---|---|---|---|
0 | 0.724841 | 0.275158 | Enter a chance to win $5000, hurry up, offer v... |
1 | 0.985587 | 0.014413 | You are awarded a SiPix Digital Camera! call 0... |
2 | 0.512872 | 0.487128 | it to 80488. Your 500 free text messages are v... |
3 | 0.003944 | 0.996056 | Hey Sam, Are you coming for a cricket game tom... |
4 | 0.000492 | 0.999508 | Why don't you wait 'til at least wednesday to ... |
[45]:
# Predicted class index per message. new_dataset is the same frame that was
# passed to model.fit, so these predictions include the training rows.
class_probs = model.predict(new_dataset.message)
predictions = class_probs.argmax(axis=1)
[47]:
# True class index per message, recovered from the one-hot ['spam', 'ham'] columns.
original = new_dataset[['spam', 'ham']].values.argmax(axis=1)
Class index 0 is spam and class index 1 is ham (the argmax is taken over the ['spam', 'ham'] columns).
[53]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
[58]:
# Rows are true classes, columns are predicted classes. Index 0 is spam and
# index 1 is ham (argmax over the ['spam', 'ham'] columns), so the ticks are
# labelled with class names instead of bare indices. labels=[0, 1] pins the
# row/column order of the matrix to match those names.
class_names = ['spam', 'ham']
cm = confusion_matrix(original, predictions, labels=[0, 1])
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('Truth');  # trailing ';' keeps the Text(...) repr out of the cell output
[58]:
Text(33.0, 0.5, 'Truth')