Skip to content

Instantly share code, notes, and snippets.

@joonas-fi
Last active January 29, 2023 15:59
Show Gist options
  • Save joonas-fi/c48c556b77eab28f9fed374928266c43 to your computer and use it in GitHub Desktop.
Save joonas-fi/c48c556b77eab28f9fed374928266c43 to your computer and use it in GitHub Desktop.
A stab at faster MD4 for Go
package main
import (
"encoding/binary"
"hash"
)
const (
	// Size is the length of an MD4 checksum in bytes.
	Size = 16

	// BlockSize is the block size of MD4 in bytes.
	BlockSize = 64
)
// le is a shorthand for the little-endian byte order helpers.
var le = binary.LittleEndian
// md4Hash carries the running state of one hash computation.
type md4Hash struct {
	// digest words
	a uint32
	b uint32
	c uint32
	d uint32

	// inputBytesHashed counts the "payload" bytes handed to Write().
	inputBytesHashed int64

	// queued collects input that does not yet add up to a whole block
	// (BlockSize, 64 bytes): Write() calls need not align with block
	// boundaries, so leftovers wait here for later writes to complete them.
	// It lives inside the struct so it is reused, which reduces allocations.
	queued [BlockSize]byte

	// queuedLength is the number of valid bytes in queued, i.e. the pending
	// bytes are queued[:queuedLength].
	queuedLength int
}
// New returns a hash state with its four digest words set to their starting
// constants, no bytes counted and nothing queued.
func New() *md4Hash {
	state := new(md4Hash)
	state.a = 0x67452301
	state.b = 0xefcdab89
	state.c = 0x98badcfe
	state.d = 0x10325476
	return state
}
// compile-time assertion that *md4Hash satisfies the hash.Hash interface
var _ hash.Hash = (*md4Hash)(nil)
// Write feeds p into the hash. Complete blocks are processed right away; a
// trailing partial block is parked in m.queued until later writes complete it.
// All of p is always consumed, so the result is always (len(p), nil).
func (m *md4Hash) Write(p []byte) (int, error) {
	total := len(p)
	m.inputBytesHashed += int64(total)

	// first top up a partially filled queue left behind by earlier writes
	if m.queuedLength > 0 {
		n := copy(m.queued[m.queuedLength:], p)
		m.queuedLength += n
		p = p[n:]

		if m.queuedLength < BlockSize {
			// still not a full block; wait for the next Write()
			return total, nil
		}

		m.runBlock(m.queued[:])
		m.queuedLength = 0
	}

	// the queue is empty here, so full blocks can be hashed straight from
	// the input without copying
	for len(p) >= BlockSize {
		m.runBlock(p)
		p = p[BlockSize:]
	}

	// fewer than BlockSize bytes remain (possibly none); they necessarily fit
	m.queuedLength = copy(m.queued[:], p)

	return total, nil
}
// Sum appends the digest of the data written so far to b and returns the
// resulting slice.
func (m *md4Hash) Sum(b []byte) []byte {
	// Sum() must not change internal state, so finalize a copy of the
	// receiver instead of the receiver itself.
	snapshot := *m
	return snapshot.SumUnsafe(b)
}
// SumUnsafe is a faster version of Sum() that is allowed to change internal
// state: it writes the end marker, padding and length into the receiver
// itself, then appends the resulting digest to b and returns that slice.
func (m *md4Hash) SumUnsafe(b []byte) []byte {
	// remember the payload size first; the end marker and padding written
	// below go through Write() and would otherwise be counted too
	payloadBytes := m.inputBytesHashed

	m.Write([]byte{0x80}) // end marker

	// the last 8 bytes of the final block hold the length of the digested stream
	const lengthOffset = BlockSize - 8

	if m.queuedLength > lengthOffset {
		// the length doesn't fit in the current queued block: pad it with
		// zeroes so it gets flushed out, which starts a new, empty block
		m.Write(make([]byte, BlockSize-m.queuedLength))
		// now m.queued is guaranteed to have space for the 8 bytes
	}

	// zero the rest of the queued buffer
	rest := m.queued[m.queuedLength:]
	for i := range rest {
		rest[i] = 0x00
	}

	le.PutUint64(m.queued[lengthOffset:], uint64(payloadBytes*8)) // in bits
	m.runBlock(m.queued[:])

	// serialize the four digest words, little-endian, in order a, b, c, d
	var sum [Size]byte
	for i, word := range [...]uint32{m.a, m.b, m.c, m.d} {
		le.PutUint32(sum[i*4:], word)
	}

	return append(b, sum[:]...)
}
// Reset overwrites the receiver with a freshly constructed state from New().
func (m *md4Hash) Reset() {
	*m = *New()
}
// Size returns the checksum length in bytes, i.e. the package-level Size constant.
func (m *md4Hash) Size() int {
return Size
}
// BlockSize returns the block size in bytes, i.e. the package-level BlockSize constant.
func (m *md4Hash) BlockSize() int {
return BlockSize
}
// runBlock folds one block into the digest: it adds the words computed by
// runBlock2() to the current digest words.
//
// Callers guarantee len(block) >= 64.
func (m *md4Hash) runBlock(block []byte) {
	da, db, dc, dd := m.runBlock2(block)
	m.a, m.b, m.c, m.d = m.a+da, m.b+db, m.c+dc, m.d+dd
}
// runBlock2 runs the three rounds over the first 64 bytes of block, starting
// from the current digest words. It does not modify the receiver: the
// returned (a, b, c, d) are meant to be added to the digest by the caller.
func (m *md4Hash) runBlock2(block []byte) (uint32, uint32, uint32, uint32) {
	const (
		round2Const = 0x5a827999
		round3Const = 0x6ed9eba1
	)

	// decode the block into sixteen little-endian 32-bit words
	var x [16]uint32
	for i := range x {
		x[i] = le.Uint32(block[i*4:])
	}

	a, b, c, d := m.a, m.b, m.c, m.d // working copies of the digest

	// round 1: words taken in order 0..15, four per iteration
	for i := 0; i < 16; i += 4 {
		a = rotl(a+f(b, c, d)+x[i], 3)
		d = rotl(d+f(a, b, c)+x[i+1], 7)
		c = rotl(c+f(d, a, b)+x[i+2], 11)
		b = rotl(b+f(c, d, a)+x[i+3], 19)
	}

	// round 2: words taken column-wise (i, i+4, i+8, i+12)
	for i := 0; i < 4; i++ {
		a = rotl(a+g(b, c, d)+x[i]+round2Const, 3)
		d = rotl(d+g(a, b, c)+x[i+4]+round2Const, 5)
		c = rotl(c+g(d, a, b)+x[i+8]+round2Const, 9)
		b = rotl(b+g(c, d, a)+x[i+12]+round2Const, 13)
	}

	// round 3: starting offsets visited in the order 0, 2, 1, 3, and within
	// each iteration the words are i, i+8, i+4, i+12
	for _, i := range [...]int{0, 2, 1, 3} {
		a = rotl(a+h(b, c, d)+x[i]+round3Const, 3)
		d = rotl(d+h(a, b, c)+x[i+8]+round3Const, 9)
		c = rotl(c+h(d, a, b)+x[i+4]+round3Const, 11)
		b = rotl(b+h(c, d, a)+x[i+12]+round3Const, 15)
	}

	return a, b, c, d
}
// rotl rotates x left by n bits.
func rotl(x, n uint32) uint32 {
	high := x << n       // bits that stay inside the word after shifting up
	low := x >> (32 - n) // bits pushed out at the top, wrapped to the bottom
	return high | low
}
// f selects bitwise between y and z: where a bit of x is set the result takes
// the bit from y, otherwise from z.
func f(x, y, z uint32) uint32 {
	return z ^ (x & (y ^ z))
}
// g is the bitwise majority function: each result bit is set when at least
// two of the corresponding bits of x, y and z are set.
func g(x, y, z uint32) uint32 {
	return (x & y) | (z & (x | y))
}
// h returns the bitwise XOR (parity) of its three arguments.
func h(x, y, z uint32) uint32 {
	parity := x ^ y
	parity ^= z
	return parity
}
package main
import (
"fmt"
"testing"
"github.com/function61/gokit/testing/assert"
gomd4 "golang.org/x/crypto/md4"
)
// testMaterial is the input hashed by both benchmarks: the buffer built by gen1MBFile().
// A tiny alternative input is kept below for quick experiments:
// var testMaterial = []byte("Hello world")
var testMaterial = gen1MBFile()
// BenchmarkGoMD4 hashes testMaterial with golang.org/x/crypto/md4, creating a
// fresh hash on every iteration.
func BenchmarkGoMD4(b *testing.B) {
	for i := 0; i < b.N; i++ {
		digest := gomd4.New()
		digest.Write(testMaterial)
		digest.Sum(nil)
	}
}
// BenchmarkOurMD4 hashes testMaterial with this package's implementation,
// creating a fresh hash on every iteration.
func BenchmarkOurMD4(b *testing.B) {
	for i := 0; i < b.N; i++ {
		digest := New()
		digest.Write(testMaterial)
		digest.Sum(nil)
	}
}
// TestMD4 hashes each input with a single Write() and compares the
// hex-encoded digest with the expected value. Every case runs as a subtest
// named after its input.
func TestMD4(t *testing.T) {
	type testCase struct {
		input          string
		expectedOutput string
	}

	cases := []testCase{
		{
			input:          "The quick brown fox jumps over the lazy dog",
			expectedOutput: "1bee69a46ba811185c194762abaeae90",
		},
		{
			input:          "The quick brown fox jumps over the lazy cog",
			expectedOutput: "b86e130ce7028da59e672d56ad0113df",
		},
		{
			input:          "",
			expectedOutput: "31d6cfe0d16ae931b73c59d7e0c089c0",
		},
		{
			input:          "a",
			expectedOutput: "bde52cb31de33e46245e05fbdbd6fb24",
		},
		{
			input:          "abc",
			expectedOutput: "a448017aaf21d8525fc10ae87aa6729d",
		},
		{
			input:          "message digest",
			expectedOutput: "d9130a8164549fe818874806e1c7014b",
		},
		{
			input:          "abcdefghijklmnopqrstuvwxyz",
			expectedOutput: "d79e1c308aa5bbcdeea8ed63df412da9",
		},
		{
			input:          "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789",
			expectedOutput: "043f8582f241db351ce627e153e7f0e4",
		},
		{
			input:          "12345678901234567890123456789012345678901234567890123456789012345678901234567890",
			expectedOutput: "e33b4ddc9c38f2199c3e7b164fcc0536",
		},
	}

	for _, tc := range cases {
		t.Run(tc.input, func(t *testing.T) {
			digest := New()
			digest.Write([]byte(tc.input))
			assert.EqualString(t, hex(digest.Sum(nil)), tc.expectedOutput)
		})
	}
}
// hex formats input as a lowercase hexadecimal string, two digits per byte.
func hex(input []byte) string {
	const lowerHex = "%x"
	encoded := fmt.Sprintf(lowerHex, input)
	return encoded
}
// gen1MBFile returns a 1 MiB buffer in which every byte holds the low 8 bits
// of its own index, i.e. the pattern 0x00..0xff repeated.
func gen1MBFile() []byte {
	const size = 1024 * 1024

	data := make([]byte, size)
	for idx := range data {
		data[idx] = byte(idx)
	}
	return data
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment