BERT Example (Spam or Ham Classification)#

[1]:
!pip install -q tensorflow-text

[4]:
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
import csv
import matplotlib.pyplot as plt
import datetime

%matplotlib inline
%load_ext tensorboard
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
[5]:
def check_gpu():
    # tf.test.gpu_device_name() returns '' when no GPU is visible
    device_name = tf.test.gpu_device_name()
    if device_name != '/device:GPU:0':
        return 'GPU device not found'
    return 'Found GPU at: {}'.format(device_name)
check_gpu()
[5]:
'Found GPU at: /device:GPU:0'
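
Alternatively, tf.config.list_physical_devices is the more current API for this check; a minimal sketch (not part of the original run):

[ ]:
# List all GPUs visible to TensorFlow (empty list means none found)
gpus = tf.config.list_physical_devices('GPU')
print('GPUs available:', gpus)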
[7]:
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

Defining TF Hub URLs#

[8]:
encoder_url = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"
preprocessing_url = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

Loading data#

[9]:
# The dataset is tab-separated with no header row
df = pd.read_csv("./SMSSpamCollection.csv", sep='\t',
                 quoting=csv.QUOTE_NONE, names=["label", "message"])
df.head()
[9]:
label message
0 ham Go until jurong point, crazy.. Available only ...
1 ham Ok lar... Joking wif u oni...
2 spam Free entry in 2 a wkly comp to win FA Cup fina...
3 ham U dun say so early hor... U c already then say...
4 ham Nah I don't think he goes to usf, he lives aro...

Checking Class Count#

[10]:
df['label'].value_counts()
[10]:
ham     4827
spam     747
Name: label, dtype: int64

Treating Class Imbalance#

The classes are clearly imbalanced: 4,827 ham messages versus only 747 spam. There are two common ways to handle this: upsampling the minority class or downsampling the majority class. Here we downsample ham (an upsampling alternative is sketched after the balanced counts below).

[11]:
spam_df = df[df.label == 'spam']
ham_df = df[df.label == 'ham']
[12]:
# Randomly downsample ham to match the number of spam rows
ham_df = ham_df.sample(spam_df.shape[0])
[13]:
# Combine the two classes and shuffle; sample(frac=1) returns all rows in random order
new_dataset = pd.concat([spam_df, ham_df]).sample(frac=1)

new_dataset.head()
[13]:
label message
1134 ham As I entered my cabin my PA said, '' Happy B'd...
5173 ham Oh k. . I will come tomorrow
4077 spam 87077: Kick off a new season with 2wks FREE go...
3189 spam This is the 2nd time we have tried 2 contact u...
4834 spam New Mobiles from 2004, MUST GO! Txt: NOKIA to ...
[14]:
new_dataset.label.value_counts()
[14]:
ham     747
spam    747
Name: label, dtype: int64
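
Downsampling discards roughly 4,000 ham messages. The alternative is to upsample spam with replacement until the classes match; a sketch, not run in this notebook (random_state is an illustrative choice):

[ ]:
# Upsample the spam class with replacement instead of discarding ham rows
ham_full = df[df.label == 'ham']
spam_up = spam_df.sample(ham_full.shape[0], replace=True, random_state=42)
balanced_up = pd.concat([ham_full, spam_up]).sample(frac=1)
balanced_up.label.value_counts()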

One-Hot Encoding the Target Class#

[15]:
new_dataset['spam'] = new_dataset.label.apply(lambda x: 1 if x=='spam' else 0)
new_dataset['ham'] = new_dataset.label.apply(lambda x: 1 if x=='ham' else 0)
new_dataset.head()
[15]:
label message spam ham
1134 ham As I entered my cabin my PA said, '' Happy B'd... 0 1
5173 ham Oh k. . I will come tomorrow 0 1
4077 spam 87077: Kick off a new season with 2wks FREE go... 1 0
3189 spam This is the 2nd time we have tried 2 contact u... 1 0
4834 spam New Mobiles from 2004, MUST GO! Txt: NOKIA to ... 1 0
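
The same two indicator columns could also be produced with pd.get_dummies; a one-line sketch:

[ ]:
# pd.get_dummies yields the same 'ham'/'spam' indicator columns
pd.get_dummies(new_dataset.label).head()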

Loading the BERT Preprocessor and Encoder#

[16]:
preprocessor = hub.KerasLayer(preprocessing_url)
bert_encoder = hub.KerasLayer(encoder_url)
[17]:
def get_embeddings(sentences):
    # Tokenize and pad the raw strings, then return BERT's pooled_output:
    # one fixed-size 768-dim embedding per sentence
    preprocessed_values = preprocessor(sentences)
    return bert_encoder(preprocessed_values)['pooled_output']
[18]:
embeddings = get_embeddings([
    'I am an artist',
    'He was writing',
    'I paint everyday'
])
[19]:
from sklearn.metrics.pairwise import cosine_similarity
[20]:
cosine_similarity([embeddings[0]],[embeddings[1]])
[20]:
array([[0.7917898]], dtype=float32)
[21]:
cosine_similarity([embeddings[1]],[embeddings[2]])
[21]:
array([[0.770146]], dtype=float32)
[22]:
cosine_similarity([embeddings[0]],[embeddings[2]])
[22]:
array([[0.99068]], dtype=float32)
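
As expected, the first and third sentences ('I am an artist', 'I paint everyday') form the most similar pair (~0.99). All pairwise similarities can also be computed in one call:

[ ]:
# Full 3x3 similarity matrix in a single call
cosine_similarity(embeddings.numpy())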

Building Model#

This model is built with the Keras functional API:

The Keras functional API is a way to create models that are more flexible than the tf.keras.Sequential API. The functional API can handle models with non-linear topology, shared layers, and even multiple inputs or outputs.

[23]:
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='Input-Text')
preprocessed_text = preprocessor(text_input)
encoded_text = bert_encoder(preprocessed_text)

# Classification head on top of BERT's pooled (sentence-level) output
dropout_output = tf.keras.layers.Dropout(0.2, name='Dropout-Layer')(encoded_text['pooled_output'])
relu_output = tf.keras.layers.Dense(units=64, activation='relu', name='Relu-Layer')(dropout_output)
output = tf.keras.layers.Dense(units=2, activation='softmax', name='Output-Layer')(relu_output)

model = tf.keras.Model(inputs=[text_input], outputs=[output])
[24]:
model.summary()
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 Input-Text (InputLayer)        [(None,)]            0           []

 keras_layer (KerasLayer)       {'input_word_ids':   0           ['Input-Text[0][0]']
                                (None, 128),
                                 'input_type_ids':
                                (None, 128),
                                 'input_mask': (Non
                                e, 128)}

 keras_layer_1 (KerasLayer)     {'encoder_outputs':  109482241   ['keras_layer[0][0]',
                                 [(None, 128, 768),               'keras_layer[0][1]',
                                 (None, 128, 768),                'keras_layer[0][2]']
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768)],
                                 'default': (None,
                                768),
                                 'pooled_output': (
                                None, 768),
                                 'sequence_output':
                                 (None, 128, 768)}

 Dropout-Layer (Dropout)        (None, 768)          0           ['keras_layer_1[0][13]']

 Relu-Layer (Dense)             (None, 64)           49216       ['Dropout-Layer[0][0]']

 Output-Layer (Dense)           (None, 2)            130         ['Relu-Layer[0][0]']

==================================================================================================
Total params: 109,531,587
Trainable params: 49,346
Non-trainable params: 109,482,241
__________________________________________________________________________________________________
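
Note the parameter counts: the ~109.5M non-trainable parameters are the frozen BERT weights (hub.KerasLayer freezes loaded weights by default), so only the 49,346 parameters of the classification head are trained. To fine-tune BERT itself, the encoder could be loaded with trainable=True, at a much higher compute cost; a sketch:

[ ]:
# Sketch: load the encoder with its weights unfrozen for full fine-tuning
bert_encoder = hub.KerasLayer(encoder_url, trainable=True)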
[25]:
tf.keras.utils.plot_model(model, show_layer_activations=True)
[25]:
../_images/Concepts_bert_example_29_0.png
[26]:
METRICS = [
    tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall')
]

model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=METRICS
)
[27]:
%tensorboard --logdir logs/fit
[28]:
history = model.fit(new_dataset['message'], new_dataset[['spam','ham']], epochs=30, validation_split=0.1,
                    shuffle=True, batch_size=32, callbacks=[tensorboard_callback])
Epoch 1/30
42/42 [==============================] - 47s 898ms/step - loss: 0.5015 - accuracy: 0.7500 - precision: 0.7500 - recall: 0.7500 - val_loss: 0.2910 - val_accuracy: 0.9133 - val_precision: 0.9133 - val_recall: 0.9133
Epoch 2/30
42/42 [==============================] - 37s 902ms/step - loss: 0.3127 - accuracy: 0.8705 - precision: 0.8705 - recall: 0.8705 - val_loss: 0.2676 - val_accuracy: 0.9067 - val_precision: 0.9067 - val_recall: 0.9067
Epoch 3/30
42/42 [==============================] - 34s 823ms/step - loss: 0.2545 - accuracy: 0.9010 - precision: 0.9010 - recall: 0.9010 - val_loss: 0.1959 - val_accuracy: 0.9333 - val_precision: 0.9333 - val_recall: 0.9333
Epoch 4/30
42/42 [==============================] - 38s 911ms/step - loss: 0.2192 - accuracy: 0.9204 - precision: 0.9204 - recall: 0.9204 - val_loss: 0.1943 - val_accuracy: 0.9400 - val_precision: 0.9400 - val_recall: 0.9400
Epoch 5/30
42/42 [==============================] - 37s 890ms/step - loss: 0.2235 - accuracy: 0.9167 - precision: 0.9167 - recall: 0.9167 - val_loss: 0.1607 - val_accuracy: 0.9400 - val_precision: 0.9400 - val_recall: 0.9400
Epoch 6/30
42/42 [==============================] - 35s 851ms/step - loss: 0.1903 - accuracy: 0.9278 - precision: 0.9278 - recall: 0.9278 - val_loss: 0.1431 - val_accuracy: 0.9467 - val_precision: 0.9467 - val_recall: 0.9467
Epoch 7/30
42/42 [==============================] - 38s 906ms/step - loss: 0.1804 - accuracy: 0.9338 - precision: 0.9338 - recall: 0.9338 - val_loss: 0.1351 - val_accuracy: 0.9467 - val_precision: 0.9467 - val_recall: 0.9467
Epoch 8/30
42/42 [==============================] - 38s 914ms/step - loss: 0.1627 - accuracy: 0.9427 - precision: 0.9427 - recall: 0.9427 - val_loss: 0.1531 - val_accuracy: 0.9533 - val_precision: 0.9533 - val_recall: 0.9533
Epoch 9/30
42/42 [==============================] - 37s 894ms/step - loss: 0.1891 - accuracy: 0.9308 - precision: 0.9308 - recall: 0.9308 - val_loss: 0.1272 - val_accuracy: 0.9533 - val_precision: 0.9533 - val_recall: 0.9533
Epoch 10/30
42/42 [==============================] - 39s 930ms/step - loss: 0.1659 - accuracy: 0.9382 - precision: 0.9382 - recall: 0.9382 - val_loss: 0.1206 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 11/30
42/42 [==============================] - 37s 896ms/step - loss: 0.1520 - accuracy: 0.9494 - precision: 0.9494 - recall: 0.9494 - val_loss: 0.1295 - val_accuracy: 0.9600 - val_precision: 0.9600 - val_recall: 0.9600
Epoch 12/30
42/42 [==============================] - 40s 964ms/step - loss: 0.1570 - accuracy: 0.9479 - precision: 0.9479 - recall: 0.9479 - val_loss: 0.1130 - val_accuracy: 0.9600 - val_precision: 0.9600 - val_recall: 0.9600
Epoch 13/30
42/42 [==============================] - 40s 963ms/step - loss: 0.1440 - accuracy: 0.9516 - precision: 0.9516 - recall: 0.9516 - val_loss: 0.1106 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 14/30
42/42 [==============================] - 40s 962ms/step - loss: 0.1413 - accuracy: 0.9524 - precision: 0.9524 - recall: 0.9524 - val_loss: 0.1045 - val_accuracy: 0.9733 - val_precision: 0.9733 - val_recall: 0.9733
Epoch 15/30
42/42 [==============================] - 38s 928ms/step - loss: 0.1528 - accuracy: 0.9494 - precision: 0.9494 - recall: 0.9494 - val_loss: 0.1398 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 16/30
42/42 [==============================] - 39s 954ms/step - loss: 0.1319 - accuracy: 0.9494 - precision: 0.9494 - recall: 0.9494 - val_loss: 0.1044 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 17/30
42/42 [==============================] - 40s 977ms/step - loss: 0.1409 - accuracy: 0.9509 - precision: 0.9509 - recall: 0.9509 - val_loss: 0.1031 - val_accuracy: 0.9733 - val_precision: 0.9733 - val_recall: 0.9733
Epoch 18/30
42/42 [==============================] - 40s 958ms/step - loss: 0.1509 - accuracy: 0.9487 - precision: 0.9487 - recall: 0.9487 - val_loss: 0.1307 - val_accuracy: 0.9600 - val_precision: 0.9600 - val_recall: 0.9600
Epoch 19/30
42/42 [==============================] - 37s 894ms/step - loss: 0.1427 - accuracy: 0.9487 - precision: 0.9487 - recall: 0.9487 - val_loss: 0.0957 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 20/30
42/42 [==============================] - 39s 945ms/step - loss: 0.1364 - accuracy: 0.9479 - precision: 0.9479 - recall: 0.9479 - val_loss: 0.1113 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 21/30
42/42 [==============================] - 38s 930ms/step - loss: 0.1402 - accuracy: 0.9516 - precision: 0.9516 - recall: 0.9516 - val_loss: 0.1447 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 22/30
42/42 [==============================] - 40s 955ms/step - loss: 0.1229 - accuracy: 0.9591 - precision: 0.9591 - recall: 0.9591 - val_loss: 0.1055 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 23/30
42/42 [==============================] - 40s 969ms/step - loss: 0.1334 - accuracy: 0.9516 - precision: 0.9516 - recall: 0.9516 - val_loss: 0.0972 - val_accuracy: 0.9733 - val_precision: 0.9733 - val_recall: 0.9733
Epoch 24/30
42/42 [==============================] - 41s 990ms/step - loss: 0.1255 - accuracy: 0.9568 - precision: 0.9568 - recall: 0.9568 - val_loss: 0.0941 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 25/30
42/42 [==============================] - 40s 959ms/step - loss: 0.1286 - accuracy: 0.9531 - precision: 0.9531 - recall: 0.9531 - val_loss: 0.0980 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 26/30
42/42 [==============================] - 38s 927ms/step - loss: 0.1511 - accuracy: 0.9435 - precision: 0.9435 - recall: 0.9435 - val_loss: 0.2420 - val_accuracy: 0.9133 - val_precision: 0.9133 - val_recall: 0.9133
Epoch 27/30
42/42 [==============================] - 37s 904ms/step - loss: 0.1655 - accuracy: 0.9382 - precision: 0.9382 - recall: 0.9382 - val_loss: 0.1260 - val_accuracy: 0.9667 - val_precision: 0.9667 - val_recall: 0.9667
Epoch 28/30
42/42 [==============================] - 38s 916ms/step - loss: 0.1237 - accuracy: 0.9561 - precision: 0.9561 - recall: 0.9561 - val_loss: 0.0966 - val_accuracy: 0.9733 - val_precision: 0.9733 - val_recall: 0.9733
Epoch 29/30
42/42 [==============================] - 38s 908ms/step - loss: 0.1300 - accuracy: 0.9583 - precision: 0.9583 - recall: 0.9583 - val_loss: 0.0945 - val_accuracy: 0.9800 - val_precision: 0.9800 - val_recall: 0.9800
Epoch 30/30
42/42 [==============================] - 39s 936ms/step - loss: 0.1284 - accuracy: 0.9516 - precision: 0.9516 - recall: 0.9516 - val_loss: 0.1394 - val_accuracy: 0.9733 - val_precision: 0.9733 - val_recall: 0.9733
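
The validation loss fluctuates rather than decreasing monotonically (note the spike at epoch 26), so an early-stopping callback could be worth adding; a sketch:

[ ]:
# Sketch: stop once val_loss has not improved for 5 epochs and
# restore the best weights seen so far
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True)
# then pass callbacks=[tensorboard_callback, early_stop] to model.fit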
[29]:
history_df = pd.DataFrame(history.history)
history_df['epochs'] = history.epoch
[30]:
history_df.head(3)
[30]:
loss accuracy precision recall val_loss val_accuracy val_precision val_recall epochs
0 0.501539 0.750000 0.750000 0.750000 0.290969 0.913333 0.913333 0.913333 0
1 0.312725 0.870536 0.870536 0.870536 0.267604 0.906667 0.906667 0.906667 1
2 0.254531 0.901042 0.901042 0.901042 0.195910 0.933333 0.933333 0.933333 2
[31]:

fig, ax = plt.subplots(2, 3, figsize=(15, 10))
history_df.plot(x='epochs', y=['loss', 'val_loss'], ax=ax[0][0])
history_df.plot(x='epochs', y=['accuracy', 'val_accuracy'], ax=ax[0][1])
history_df.plot(x='epochs', y=['precision', 'val_precision'], ax=ax[0][2])
history_df.plot(x='epochs', y=['recall', 'val_recall'], ax=ax[1][0])
history_df.plot(x='epochs', y=['precision', 'recall'], ax=ax[1][1])
history_df.plot(x='epochs', y=['val_precision', 'val_recall'], ax=ax[1][2])
plt.show()
../_images/Concepts_bert_example_35_0.png
[34]:
reviews = [
    'Enter a chance to win $5000, hurry up, offer valid until march 31, 2021',
    'You are awarded a SiPix Digital Camera! call 09061221061 from landline. Delivery within 28days. T Cs Box177. M221BP. 2yr warranty. 150ppm. 16 . p p£3.99',
    'it to 80488. Your 500 free text messages are valid until 31 December 2005.',
    'Hey Sam, Are you coming for a cricket game tomorrow',
    "Why don't you wait 'til at least wednesday to see if you get your ."
]
preds = model.predict(reviews)
[36]:
reviews_df = pd.DataFrame(preds, columns=['spam','ham'])
reviews_df['message'] = reviews
reviews_df
[36]:
spam ham message
0 0.724841 0.275158 Enter a chance to win $5000, hurry up, offer v...
1 0.985587 0.014413 You are awarded a SiPix Digital Camera! call 0...
2 0.512872 0.487128 it to 80488. Your 500 free text messages are v...
3 0.003944 0.996056 Hey Sam, Are you coming for a cricket game tom...
4 0.000492 0.999508 Why don't you wait 'til at least wednesday to ...
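
To attach a hard label, the larger of the two probabilities can be taken per row; a sketch:

[ ]:
# Sketch: derive a predicted label from the two class probabilities
reviews_df['predicted'] = np.where(reviews_df.spam > reviews_df.ham, 'spam', 'ham')
reviews_df[['message', 'predicted']]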
[45]:
# Column order is ['spam', 'ham'], so argmax gives 0 = spam, 1 = ham
predictions = np.argmax(model.predict(new_dataset.message), axis=1)
[47]:
original = np.argmax(new_dataset[['spam','ham']].values, axis=1)

In these encoded arrays, 0 denotes spam and 1 denotes ham.

[53]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
[58]:
cm = confusion_matrix(original, predictions)
sns.heatmap(cm, annot=True, fmt='d')
plt.xlabel('Predicted')
plt.ylabel('Truth')
[58]:
Text(33.0, 0.5, 'Truth')
../_images/Concepts_bert_example_42_1.png
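
Beyond the confusion matrix, sklearn's classification_report gives per-class precision, recall, and F1; a sketch (label order follows the column order of new_dataset[['spam','ham']], i.e. 0 = spam, 1 = ham):

[ ]:
from sklearn.metrics import classification_report
# Per-class precision/recall/F1; 0 = spam, 1 = ham
print(classification_report(original, predictions, target_names=['spam', 'ham']))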