@msjgriffiths
Created October 14, 2021 15:14
Example sentence embedding
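using Flux  # written against the stateful recurrent-layer API (`Recur`) of Flux 0.13/0.14

# Input data of shape (features, batch, time steps): 100 sequences of 12 time steps,
# each step a 32-dimensional feature vector. Masking of padded steps is ignored.
x = randn(Float32, 32, 100, 12)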
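hidden_states = Chain(
    # A 4-layer stacked LSTM. Given a 3-D array, each layer folds over the last
    # (time) dimension, so the output has shape (32, batch, time).
    # The layers are stateful: call `Flux.reset!(hidden_states)` before each new batch.
    LSTM(32 => 64),
    LSTM(64 => 64),
    LSTM(64 => 64),
    LSTM(64 => 32),
)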
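attention = Chain(
    # Self-attentive pooling from "A Structured Self-attentive Sentence Embedding"
    # (arXiv:1703.03130), with a single attention hop (r = 1).
    x -> permutedims(x, (1, 3, 2)),  # (32, time, batch): Hᵀ for each example
    Parallel(
        ⊠,         # NNlib.batched_mul: one matrix product per example
        identity,  # Hᵀ, shape (32, time, batch)
        # A = softmax(Wₛ₂ tanh(Wₛ₁ Hᵀ)); the transpose is handled by the permutedims above
        Chain(
            Dense(32 => 12, tanh; bias = false),  # Wₛ₁ with d_a = 12: (12, time, batch)
            Dense(12 => 1; bias = false),         # Wₛ₂: (1, time, batch)
            # Softmax over time (dims = 2). The default dims = 1 has length 1 here
            # and would return all ones.
            x -> softmax(x; dims = 2),
            x -> permutedims(x, (2, 1, 3)),       # (time, 1, batch)
        ),
    ),
    # (32, time, batch) ⊠ (time, 1, batch) -> (32, 1, batch): attention-weighted sum of states.
    # Drop the singleton dimension and transpose to (batch, 32), for any batch size.
    x -> permutedims(dropdims(x; dims = 2)),
)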
sentence_embedding = Chain(hidden_states, attention)
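The attention block follows the single-hop (r = 1) case of arXiv:1703.03130: score each time step, normalize the scores with a softmax over time, and take the weighted sum of the hidden states. The sketch below restates that computation for one sentence in plain Julia, without Flux. It is an illustration under assumptions, not the gist's code: `softmax_vec` and `attention_pool` are hypothetical helper names, random matrices stand in for trained weights, and H is laid out as n × d (time × features) to match the paper's notation.

```julia
# Numerically stable softmax over a vector (hypothetical helper; the Flux model
# uses NNlib's `softmax` instead).
function softmax_vec(z::AbstractVector)
    e = exp.(z .- maximum(z))  # subtract the max so `exp` cannot overflow
    return e ./ sum(e)
end

# Single-hop (r = 1) self-attentive pooling for ONE sentence:
#   a = softmax(w_s2 * tanh(W_s1 * Hᵀ)),  m = a * H
# H    :: n × d    hidden states, one row per time step
# W_s1 :: d_a × d  first projection (random here, trained in practice)
# w_s2 :: 1 × d_a  second projection, a row vector because r = 1
function attention_pool(H::AbstractMatrix, W_s1::AbstractMatrix, w_s2::AbstractMatrix)
    scores = vec(w_s2 * tanh.(W_s1 * H'))  # n scores, one per time step
    a = softmax_vec(scores)                # n non-negative weights summing to 1
    m = vec(a' * H)                        # d-dimensional sentence embedding
    return m, a
end

n, d, d_a = 12, 32, 12  # time steps, hidden size, attention size (same sizes as the gist)
H    = randn(Float32, n, d)
W_s1 = randn(Float32, d_a, d)
w_s2 = randn(Float32, 1, d_a)

m, a = attention_pool(H, W_s1, w_s2)
println(size(m), " ", sum(a))

# Why the softmax axis matters: over a dimension of length 1, every weight is exactly 1.
println(softmax_vec([3.0f0]))
```

The embedding m is a convex combination of the rows of H, so its size is fixed (d) no matter how many time steps the sentence has.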
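In the Flux model the same computation runs on a whole batch at once, which is what the `permutedims` calls and `⊠` are for. The following plain-Julia sketch mirrors those steps on a (features, batch, time) array so the shape at each stage can be checked. Assumptions: `batched_mul_loop` is a hypothetical loop written to show what `NNlib.batched_mul` computes (it is not NNlib's implementation), and the matrices `W_s1` and `w_s2` stand in for the two bias-free `Dense` layers. The result is (features, batch); the gist's final step arranges the same values as (batch, features).

```julia
# Reference loop for what `⊠` (NNlib.batched_mul) computes: an ordinary
# matrix product for each slice along the third (batch) dimension.
function batched_mul_loop(A::AbstractArray{T,3}, B::AbstractArray{T,3}) where {T}
    size(A, 3) == size(B, 3) || throw(DimensionMismatch("batch sizes differ"))
    C = similar(A, size(A, 1), size(B, 2), size(A, 3))
    for k in axes(A, 3)
        C[:, :, k] = A[:, :, k] * B[:, :, k]
    end
    return C
end

# Batched single-hop attention pooling on a (d, batch, time) array, following
# the same permutedims steps as the Flux model. Plain matrices stand in for the
# two bias-free Dense layers (hypothetical names W_s1, w_s2).
function pool_batch(h::AbstractArray{T,3}, W_s1::AbstractMatrix{T}, w_s2::AbstractMatrix{T}) where {T}
    d, batch, n = size(h)
    Ht = permutedims(h, (1, 3, 2))                # (d, n, batch): Hᵀ per example
    flat = reshape(Ht, d, :)                      # (d, n * batch), one column per time step
    scores = reshape(w_s2 * tanh.(W_s1 * flat), 1, n, batch)  # (1, n, batch)
    e = exp.(scores .- maximum(scores; dims = 2)) # softmax over time, i.e. dims = 2
    a = e ./ sum(e; dims = 2)
    # (d, n, batch) ⊠ (n, 1, batch) -> (d, 1, batch)
    m = batched_mul_loop(Ht, permutedims(a, (2, 1, 3)))
    return dropdims(m; dims = 2)                  # (d, batch)
end

h    = randn(Float32, 32, 100, 12)  # (features, batch, time), as in the gist
W_s1 = randn(Float32, 12, 32)
w_s2 = randn(Float32, 1, 12)
emb  = pool_batch(h, W_s1, w_s2)
println(size(emb))  # (32, 100)
```

Flattening to (d, n * batch) before the first projection is roughly how a `Dense` layer treats a 3-D input, so a single matrix product scores every time step of every example.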