LSTM/GRU for Text Classification

Text is naturally sequential. A classical pipeline for sequence classification is:

  1. Convert tokens to integer ids: \(w_1, \dots, w_T\)
  2. Embed tokens into vectors: \(e_t = E[w_t]\), with \(E \in \mathbb{R}^{V \times d}\)
  3. Apply a recurrent layer (LSTM or GRU) to obtain a sequence representation
  4. Predict label via logistic output: \(\hat{y} = \sigma(u^\top h + b)\)

We optimize binary cross-entropy: \(-\frac{1}{n}\sum_i [y_i\log \hat{y}_i + (1-y_i)\log(1-\hat{y}_i)]\).
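For intuition, this loss is easy to evaluate by hand for a small batch: a confident correct prediction contributes little, a wrong one contributes a lot.

```r
# Binary cross-entropy for three examples: one confident-correct,
# one correct, one wrong
y <- c(1, 0, 1)
p <- c(0.9, 0.2, 0.4)
-mean(y * log(p) + (1 - y) * log(1 - p))
## [1] 0.4149316
```

The third example (\(y = 1\), \(\hat{y} = 0.4\)) alone contributes \(-\log 0.4 \approx 0.92\), dominating the average.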

In this lesson we:

  • load the IMDB dataset (pre-tokenized sequences),
  • pad sequences to fixed length,
  • fit an embedding + LSTM or GRU classifier,
  • regularize via dropout,
  • validate via early stopping, and
  • report performance on a held-out test set.

Step 1: Load the dataset

We limit the vocabulary to the \(V\) most frequent words.

library(keras)  # R interface to Keras (library(keras3) on newer installations)

max_words <- 10000  # vocabulary size V
max_len   <- 200    # sequence length T after padding/truncation

imdb <- dataset_imdb(num_words = max_words)

x_train <- imdb$train$x
y_train <- imdb$train$y

x_test <- imdb$test$x
y_test <- imdb$test$y

length(x_train); length(y_train)
## [1] 25000
## [1] 25000
length(x_test);  length(y_test)
## [1] 25000
## [1] 25000
table(y_train)
## y_train
##     0     1 
## 12500 12500
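To inspect what the integer ids encode, a review can be decoded back to words. A sketch, assuming the dataset's default offset of 3 (ids 0, 1, 2 are reserved for padding, sequence start, and out-of-vocabulary):

```r
word_index <- dataset_imdb_word_index()  # named list: word -> frequency rank
reverse_index <- names(word_index)
names(reverse_index) <- unlist(word_index)

decode_review <- function(ids) {
  paste(sapply(ids, function(i) {
    w <- reverse_index[as.character(i - 3)]  # undo the index_from = 3 offset
    if (is.na(w)) "?" else w
  }), collapse = " ")
}

substr(decode_review(x_train[[1]]), 1, 60)
```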

Step 2: Pad sequences

We need fixed-length tensors for batching, so we pad or truncate every review to length \(T\). By default, pad_sequences() both pads and truncates at the beginning of each sequence, keeping the end of long reviews.

x_train_pad <- pad_sequences(x_train, maxlen = max_len)
x_test_pad  <- pad_sequences(x_test,  maxlen = max_len)

dim(x_train_pad)
## [1] 25000   200
dim(x_test_pad)
## [1] 25000   200
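A toy example makes the default behavior concrete: short sequences are left-padded with zeros, and long ones keep their final tokens.

```r
pad_sequences(list(c(5, 9), c(3, 8, 2, 7, 1)), maxlen = 4)
##      [,1] [,2] [,3] [,4]
## [1,]    0    0    5    9
## [2,]    8    2    7    1
```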

Step 3: Train/validation split

set.seed(123)
n <- nrow(x_train_pad)
idx <- sample.int(n)

n_train <- floor(0.85 * n)
tr_idx <- idx[1:n_train]
va_idx <- idx[(n_train + 1):n]

X_tr <- x_train_pad[tr_idx, , drop = FALSE]
y_tr <- y_train[tr_idx]

X_va <- x_train_pad[va_idx, , drop = FALSE]
y_va <- y_train[va_idx]

Step 4: Model design (Embedding + BiRNN)

We use:

  • Embedding dimension \(d=128\)
  • Bidirectional GRU/LSTM to capture both left-to-right and right-to-left dependencies
  • Dropout for regularization
  • Dense head with sigmoid output

embed_dim <- 128

input <- layer_input(shape = c(max_len), dtype = "int32", name = "tokens")

x <- input |>
  layer_embedding(input_dim = max_words, output_dim = embed_dim, name = "embedding") |>
  layer_bidirectional(layer_gru(units = 64, dropout = 0.2, recurrent_dropout = 0.0)) |> # recurrent_dropout > 0 would disable the fast cuDNN kernel
  layer_dense(units = 64, activation = "relu") |>
  layer_dropout(rate = 0.3) |>
  layer_dense(units = 1, activation = "sigmoid", name = "p_yes")

model <- keras_model(inputs = input, outputs = x)

model |>
  compile(
    optimizer = optimizer_adam(learning_rate = 1e-3),
    loss = "binary_crossentropy",
    metrics = list("accuracy", metric_auc(name = "auc"))
  )

model
## Model: "functional_9"
## ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
## ┃ Layer (type)                                    ┃ Output Shape                         ┃              Param # ┃
## ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
## │ tokens (InputLayer)                             │ (None, 200)                          │                    0 │
## ├─────────────────────────────────────────────────┼──────────────────────────────────────┼──────────────────────┤
## │ embedding (Embedding)                           │ (None, 200, 128)                     │            1,280,000 │
## ├─────────────────────────────────────────────────┼──────────────────────────────────────┼──────────────────────┤
## │ bidirectional (Bidirectional)                   │ (None, 128)                          │               74,496 │
## ├─────────────────────────────────────────────────┼──────────────────────────────────────┼──────────────────────┤
## │ dense_15 (Dense)                                │ (None, 64)                           │                8,256 │
## ├─────────────────────────────────────────────────┼──────────────────────────────────────┼──────────────────────┤
## │ dropout_4 (Dropout)                             │ (None, 64)                           │                    0 │
## ├─────────────────────────────────────────────────┼──────────────────────────────────────┼──────────────────────┤
## │ p_yes (Dense)                                   │ (None, 1)                            │                   65 │
## └─────────────────────────────────────────────────┴──────────────────────────────────────┴──────────────────────┘
##  Total params: 1,362,817 (5.20 MB)
##  Trainable params: 1,362,817 (5.20 MB)
##  Non-trainable params: 0 (0.00 B)
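The parameter counts above can be verified by hand. The embedding holds \(V \times d = 10000 \times 128 = 1{,}280{,}000\) weights. A GRU with Keras's default reset-after gating has three gates, each with an input kernel, a recurrent kernel, and two bias vectors, so one direction with \(u = 64\) units costs

  \(3\,(d\,u + u^2 + 2u) = 3\,(128 \cdot 64 + 64^2 + 2 \cdot 64) = 37{,}248\)

parameters; two directions give \(74{,}496\). The dense head adds \(128 \cdot 64 + 64 = 8{,}256\) and the output layer \(64 + 1 = 65\), for \(1{,}362{,}817\) in total.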

Step 5: Training with early stopping

cb_es <- callback_early_stopping(
  monitor = "val_auc",
  mode = "max",
  patience = 3,
  restore_best_weights = TRUE
)

set.seed(123)  # seeds R's RNG only; full reproducibility also needs the backend seed, e.g. set_random_seed()
history <- model |>
  fit(
    x = X_tr, y = y_tr,
    validation_data = list(X_va, y_va),
    epochs = 20,
    batch_size = 256,
    callbacks = list(cb_es),
    verbose = 2
  )

library(ggplot2)  # plot(history) returns a ggplot object

plot(history) +
  theme_minimal() +
  ggtitle("Training history — GRU sentiment model")

Step 6: Evaluation on test set

metrics <- model |>
  evaluate(x_test_pad, y_test, verbose = 0)

setNames(as.numeric(metrics), names(metrics)) |>
  round(3)
## accuracy      auc     loss 
##    0.848    0.938    0.381
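As a sanity check, the reported AUC can be recomputed from the predicted probabilities. The helper below is illustrative (not part of Keras) and uses the Mann–Whitney rank identity: AUC is the probability that a randomly chosen positive scores above a randomly chosen negative.

```r
auc_rank <- function(y, p) {
  r  <- rank(p)                        # average ranks handle ties
  n1 <- sum(y == 1); n0 <- sum(y == 0)
  (sum(r[y == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

p_test <- model |> predict(x_test_pad)
auc_rank(y_test, as.vector(p_test))    # should be close to the evaluate() AUC
```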

Step 7: Confusion matrix and thresholding

The default threshold is 0.5, but when misclassification costs are asymmetric, business decisions often call for a threshold tuned toward precision or recall.

p_test <- model |>
  predict(x_test_pad)

pred_05 <- ifelse(p_test >= 0.5, 1, 0)

cm <- table(
  truth = factor(y_test, levels = c(0, 1)),
  pred  = factor(pred_05, levels = c(0, 1))
)

cm
##      pred
## truth     0     1
##     0 11601   899
##     1  2894  9606

Compute accuracy, precision, recall, F1:

# Extract the confusion-matrix cells; rows are truth, columns are predictions
tp <- cm["1","1"]; tn <- cm["0","0"]; fp <- cm["0","1"]; fn <- cm["1","0"]

acc <- (tp + tn) / sum(cm)
prec <- tp / (tp + fp)
rec  <- tp / (tp + fn)
f1   <- 2 * (prec * rec) / (prec + rec)

c(accuracy = acc, precision = prec, recall = rec, f1 = f1) |>
  round(3)
##  accuracy precision    recall        f1 
##     0.848     0.914     0.768     0.835
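When costs are asymmetric, the threshold itself is a tunable quantity. A minimal sketch that sweeps thresholds and picks the F1-maximizing one (in practice, tune on the validation set, not the test set):

```r
f1_at <- function(thr, y, p) {
  pred <- as.integer(p >= thr)
  tp <- sum(pred == 1 & y == 1)
  fp <- sum(pred == 1 & y == 0)
  fn <- sum(pred == 0 & y == 1)
  if (tp == 0) return(0)
  prec <- tp / (tp + fp); rec <- tp / (tp + fn)
  2 * prec * rec / (prec + rec)
}

thresholds <- seq(0.05, 0.95, by = 0.05)
f1s <- sapply(thresholds, f1_at, y = y_test, p = as.vector(p_test))
thresholds[which.max(f1s)]  # F1-maximizing threshold
```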

Summary

  • Embeddings learn dense semantic representations of tokens.
  • GRU/LSTM layers model order and context beyond bag-of-words.
  • Bidirectionality improves representation for classification tasks.
  • Use early stopping and dropout to control overfitting, and consider threshold tuning when costs are asymmetric.
 

A work by Gianluca Sottile

gianluca.sottile@unipa.it