Text is naturally sequential. A classical pipeline for sequence classification is:
We optimize binary cross-entropy: \(-\frac{1}{n}\sum_i [y_i\log \hat{y}_i + (1-y_i)\log(1-\hat{y}_i)]\).
In this lesson we load the IMDB review dataset, pad the sequences to a fixed length, train a bidirectional GRU sentiment classifier with early stopping, and evaluate it with accuracy, AUC, precision, recall, and F1.
We limit the vocabulary to the \(V\) most frequent words.
# Vocabulary and sequence-length hyperparameters.
max_words <- 10000  # keep only the V = 10,000 most frequent words
max_len <- 200      # pad/truncate every review to T = 200 tokens

# IMDB movie-review data, pre-tokenized to integer word indices by keras3's
# loader; indices outside the top `max_words` are dropped at load time.
imdb <- dataset_imdb(num_words = max_words)
x_train <- imdb$train$x  # list of variable-length integer sequences
y_train <- imdb$train$y  # 0/1 sentiment labels
x_test <- imdb$test$x
y_test <- imdb$test$y
# Sanity check: 25,000 reviews per split, perfectly balanced classes.
# NOTE(review): four "[1] 25000" lines were echoed below but only two
# length() calls are shown -- the x_test/y_test counts were presumably
# printed by a line lost in extraction.
length(x_train); length(y_train)## [1] 25000
## [1] 25000
## [1] 25000
## [1] 25000
## y_train
## 0 1
## 12500 12500
We need fixed-length tensors for batching. We pad/truncate to length \(T\).
# Make every review exactly `max_len` tokens long (pad short ones, truncate
# long ones) so each split becomes a rectangular matrix ready for batching.
x_train_pad <- x_train |> pad_sequences(maxlen = max_len)
x_test_pad <- x_test |> pad_sequences(maxlen = max_len)
dim(x_train_pad)
## [1] 25000 200
## [1] 25000 200
We use the following architecture and hyperparameters:
embed_dim <- 128

# Functional API: integer token ids -> learned embeddings -> bidirectional
# GRU encoder -> small dense head -> sigmoid probability of a positive review.
input <- layer_input(shape = c(max_len), dtype = "int32", name = "tokens")

net <- input |>
  layer_embedding(
    input_dim = max_words, output_dim = embed_dim, name = "embedding"
  ) |>
  layer_bidirectional(
    layer_gru(units = 64, dropout = 0.2, recurrent_dropout = 0.0)
  ) |>
  layer_dense(units = 64, activation = "relu") |>
  layer_dropout(rate = 0.3) |>
  layer_dense(units = 1, activation = "sigmoid", name = "p_yes")

model <- keras_model(inputs = input, outputs = net)

# Binary cross-entropy matches the sigmoid output; track accuracy and AUC.
model |>
  compile(
    optimizer = optimizer_adam(learning_rate = 1e-3),
    loss = "binary_crossentropy",
    metrics = list("accuracy", metric_auc(name = "auc"))
  )

model
## Model: "functional_9"
## ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
## ┃ Layer (type) ┃ Output Shape ┃ Param # ┃
## ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
## │ tokens (InputLayer) │ (None, 200) │ 0 │
## ├─────────────────────────────────────────────────┼──────────────────────────────────────┼──────────────────────┤
## │ embedding (Embedding) │ (None, 200, 128) │ 1,280,000 │
## ├─────────────────────────────────────────────────┼──────────────────────────────────────┼──────────────────────┤
## │ bidirectional (Bidirectional) │ (None, 128) │ 74,496 │
## ├─────────────────────────────────────────────────┼──────────────────────────────────────┼──────────────────────┤
## │ dense_15 (Dense) │ (None, 64) │ 8,256 │
## ├─────────────────────────────────────────────────┼──────────────────────────────────────┼──────────────────────┤
## │ dropout_4 (Dropout) │ (None, 64) │ 0 │
## ├─────────────────────────────────────────────────┼──────────────────────────────────────┼──────────────────────┤
## │ p_yes (Dense) │ (None, 1) │ 65 │
## └─────────────────────────────────────────────────┴──────────────────────────────────────┴──────────────────────┘
## Total params: 1,362,817 (5.20 MB)
## Trainable params: 1,362,817 (5.20 MB)
## Non-trainable params: 0 (0.00 B)
# Early stopping on validation AUC: higher is better (mode = "max"); wait
# 3 stagnant epochs, then roll the model back to its best-epoch weights.
cb_es <- callback_early_stopping(
  monitor = "val_auc",
  mode = "max",
  patience = 3,
  restore_best_weights = TRUE
)
# NOTE(review): the original fit() referenced X_tr / y_tr / X_va / y_va,
# none of which are defined anywhere in this document. Train on the padded
# training tensors and let Keras carve out a validation split, so the
# "val_auc" metric monitored by `cb_es` is actually produced.
set.seed(123)  # seeds R's RNG only; TF/Keras keeps its own RNG state
history <- model |>
  fit(
    x = x_train_pad, y = y_train,
    validation_split = 0.2,
    epochs = 20,
    batch_size = 256,
    callbacks = list(cb_es),
    verbose = 2
  )

plot(history) +
  theme_minimal() +
  ggtitle("Training history — GRU sentiment model")

# Final, honest evaluation on the held-out test set.
metrics <- model |>
  evaluate(x_test_pad, y_test, verbose = 0)
setNames(as.numeric(metrics), names(metrics)) |>
  round(3)
## accuracy auc loss
## 0.848 0.938 0.381
A default threshold is 0.5, but business decisions often require optimizing for precision or recall.
# Predicted P(positive) for every test review.
p_test <- predict(model, x_test_pad)

# Hard 0/1 labels at the default 0.5 cut-off.
pred_05 <- as.integer(p_test >= 0.5)

# Confusion matrix: truth on rows, prediction on columns. Explicit factor
# levels keep the 0/1 ordering stable even if one class were ever absent.
cm <- table(
  truth = factor(y_test, levels = c(0, 1)),
  pred = factor(pred_05, levels = c(0, 1))
)
cm
## pred
## truth 0 1
## 0 11601 899
## 1 2894 9606
Compute accuracy, precision, recall, F1:
# Derive the standard metrics from the confusion matrix `cm`
# (rows = truth, columns = prediction; "1" is the positive class).
# NOTE(review): the original first assigned cm["0","0"] to a variable named
# `precision` as a "placeholder" -- that value is the true-NEGATIVE count,
# not precision, so the misleading line is removed.
tp <- cm["1", "1"]  # true positives
tn <- cm["0", "0"]  # true negatives
fp <- cm["0", "1"]  # false positives: truth 0, predicted 1
fn <- cm["1", "0"]  # false negatives: truth 1, predicted 0

acc <- (tp + tn) / sum(cm)
prec <- tp / (tp + fp)          # of predicted positives, how many are right
rec <- tp / (tp + fn)           # of actual positives, how many were found
f1 <- 2 * (prec * rec) / (prec + rec)  # harmonic mean of precision/recall

c(accuracy = acc, precision = prec, recall = rec, f1 = f1) |>
  round(3)
## accuracy precision recall f1
## 0.848 0.914 0.768 0.835
A work by Gianluca Sottile
gianluca.sottile@unipa.it