@sebinsua
Last active March 11, 2024 12:02
Smooth a stream of LLM tokens into a stream of characters while reducing jitter by stabilising output timing. Explorations of different approaches.
function sleep(ms: number) {
  return new Promise((resolve) => setTimeout(resolve, ms));
}

class AsyncQueue<T> {
  queuedItems: (T | Error)[];
  queuedProcessors: [(item: T) => void, (error: Error) => void][];

  constructor() {
    // Note: The FIFO `shift` operations we do on these arrays are `O(n)`.
    // The performance is acceptable to us for now. If we ever need
    // to optimize this we can swap the arrays for linked lists.
    // However, without benchmarking it's hard to know whether we
    // would benefit more from keeping the contiguous memory layout
    // of an array or from moving to linked lists and getting `shift`
    // operations with a time complexity of `O(1)` instead of `O(n)`.
    //
    // Note: We've implemented this here already:
    // https://gist.github.com/sebinsua/76fc5eb6fc498636bc637b9f10b7e6bf
    this.queuedItems = [];
    this.queuedProcessors = [];
  }

  enqueue(item: T | Error) {
    if (this.queuedProcessors.length > 0) {
      const [resolve, reject] = this.queuedProcessors.shift()!;
      if (item instanceof Error) {
        reject(item);
      } else {
        resolve(item);
      }
    } else {
      this.queuedItems.push(item);
    }
  }

  async dequeue(): Promise<T> {
    if (this.queuedItems.length > 0) {
      const item = this.queuedItems.shift()!;
      if (item instanceof Error) {
        throw item;
      }
      return item;
    } else {
      return new Promise((resolve, reject) =>
        this.queuedProcessors.push([resolve, reject])
      );
    }
  }

  size() {
    return this.queuedItems.length;
  }
}
interface CalculateDelayOptions {
  initialDelay?: number;
  zeroDelayQueueSize?: number;
}

function calculateDelay(
  queueSize: number,
  { initialDelay = 32, zeroDelayQueueSize = 64 }: CalculateDelayOptions = {}
): number {
  return Math.max(
    0,
    Math.floor(initialDelay - (initialDelay / zeroDelayQueueSize) * queueSize)
  );
}
// The second argument is the end-of-file flag passed by `smooth` when it
// flushes its remaining buffer (matching how `words` and `clauses` are called).
export type TokenizeFn = (
  text: string,
  eof?: boolean
) => (readonly [token: string, index: number])[];

export type TokenizeType = "preserve" | "chars" | "words" | "clauses";

export type SmoothOptions = CalculateDelayOptions & {
  tokenize?: TokenizeType | TokenizeFn;
};
function preserve(buffer: string) {
  return [[buffer, buffer.length] as const];
}

function chars(buffer: string) {
  return buffer.split("").map((token, index) => [token, index + 1] as const);
}

function chunks(buffer: string, regex: RegExp, inclusive = false, eof = false) {
  const ws = [];
  let lastIndex = 0;
  for (let currentIndex = 0; currentIndex < buffer.length; currentIndex++) {
    if (regex.test(buffer[currentIndex]!)) {
      ws.push([
        buffer.slice(lastIndex, currentIndex + (inclusive ? 1 : 0)),
        currentIndex + (inclusive ? 1 : 0),
      ] as const);
      // Skip past the delimiter when it was included in the current chunk,
      // otherwise the next chunk would repeat it.
      lastIndex = currentIndex + (inclusive ? 1 : 0);
    }
  }
  if (eof) {
    ws.push([buffer.slice(lastIndex), buffer.length] as const);
  }
  return ws;
}

function words(buffer: string, eof = false) {
  return chunks(buffer, /\s/, false, eof);
}

function clauses(buffer: string, eof = false) {
  return chunks(buffer, /[.,!?;]/, true, eof);
}

const tokenizers = {
  chars,
  words,
  clauses,
  preserve,
} as const;
/**
 * Smooth a stream of LLM tokens into a stream of characters or semantic chunks
 * while reducing jitter by stabilising output timing.
 *
 * @param streamingData A stream of LLM tokens.
 * @param options Options for the smoothing algorithm.
 */
export async function* smooth(
  streamingData: AsyncGenerator<string | undefined>,
  { tokenize: _tokenize = chars, ...options }: SmoothOptions = {}
) {
  const tokenize =
    typeof _tokenize === "function" ? _tokenize : tokenizers[_tokenize];

  const queue = new AsyncQueue<string | undefined>();

  void (async () => {
    let buffer = "";
    let lastIndex: number | undefined;
    try {
      for await (const oldToken of streamingData) {
        buffer += oldToken ?? "";
        for (const [newToken, index] of tokenize(buffer)) {
          queue.enqueue(newToken);
          lastIndex = index;
        }
        if (typeof lastIndex === "number") {
          buffer = buffer.slice(lastIndex);
          lastIndex = undefined;
        }
      }

      // Flush the buffer.
      for (const [newToken] of tokenize(buffer, true)) {
        queue.enqueue(newToken);
      }
    } catch (error) {
      queue.enqueue(error as Error);
    } finally {
      queue.enqueue(undefined);
    }
  })();

  while (true) {
    const newToken = await queue.dequeue();
    if (newToken === undefined) {
      break;
    }

    yield newToken;

    const delay = calculateDelay(queue.size(), options);
    if (delay === 0) {
      continue;
    }

    await sleep(delay);
  }
}
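
For context, here is a minimal sketch of how `smooth` might be consumed. The `fakeLlmStream` generator is a made-up stand-in for a real LLM token stream, purely for illustration:

async function* fakeLlmStream() {
  // Emit a few multi-character "tokens" with uneven timing,
  // roughly similar to what an LLM API might produce.
  for (const token of ["Hello", ", wor", "ld! Str", "eaming is ", "fun."]) {
    yield token;
    await sleep(Math.random() * 200);
  }
}

async function main() {
  // Characters are yielded one at a time, with a per-character delay that
  // shrinks as the internal queue builds up, so the output stays smooth.
  for await (const char of smooth(fakeLlmStream(), { tokenize: "chars" })) {
    process.stdout.write(char);
  }
}

main();

The second exploration below is the same algorithm, but with the array-backed `AsyncQueue` swapped out for a singly linked list so that the FIFO `shift` operations become `O(1)`.
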
function sleep(ms: number) {
  return new Promise((resolve) => setTimeout(resolve, ms));
}

class ListNode<T> {
  public value: T;
  public next: ListNode<T> | null = null;

  constructor(value: T) {
    this.value = value;
  }
}

class LinkedList<T> {
  private head: ListNode<T> | null = null;
  private tail: ListNode<T> | null = null;
  private _length: number = 0;

  public isEmpty(): boolean {
    return this._length === 0;
  }

  public size(): number {
    return this._length;
  }

  public get length(): number {
    return this._length;
  }

  public push(value: T): void {
    const newNode = new ListNode(value);
    if (this.tail) {
      this.tail.next = newNode;
    } else {
      this.head = newNode;
    }
    this.tail = newNode;
    this._length++;
  }

  public shift(): T | null {
    if (!this.head) {
      return null;
    }
    const headValue = this.head.value;
    this.head = this.head.next;
    if (!this.head) {
      this.tail = null;
    }
    this._length--;
    return headValue;
  }
}

class AsyncQueue<T> {
  queuedItems: LinkedList<T | Error>;
  queuedProcessors: LinkedList<[(item: T) => void, (error: Error) => void]>;

  constructor() {
    // Note: The FIFO `shift` operations we do are `O(n)` on arrays.
    // Therefore, we are using linked lists, however, without
    // benchmarking it's hard to know whether we would benefit
    // more from keeping the contiguous memory layout of an array
    // or from continuing to use linked lists in order to get
    // `shift` operations with a time complexity of `O(1)` instead
    // of `O(n)`.
    this.queuedItems = new LinkedList();
    this.queuedProcessors = new LinkedList();
  }

  enqueue(item: T | Error) {
    if (this.queuedProcessors.length > 0) {
      const [resolve, reject] = this.queuedProcessors.shift()!;
      if (item instanceof Error) {
        reject(item);
      } else {
        resolve(item);
      }
    } else {
      this.queuedItems.push(item);
    }
  }

  async dequeue(): Promise<T> {
    if (this.queuedItems.length > 0) {
      const item = this.queuedItems.shift()!;
      if (item instanceof Error) {
        throw item;
      }
      return item;
    } else {
      return new Promise((resolve, reject) =>
        this.queuedProcessors.push([resolve, reject])
      );
    }
  }

  size() {
    return this.queuedItems.length;
  }
}
interface CalculateDelayOptions {
  initialDelay?: number;
  zeroDelayQueueSize?: number;
}

function calculateDelay(
  queueSize: number,
  { initialDelay = 32, zeroDelayQueueSize = 64 }: CalculateDelayOptions = {}
): number {
  return Math.max(
    0,
    Math.floor(initialDelay - (initialDelay / zeroDelayQueueSize) * queueSize)
  );
}
// The second argument is the end-of-file flag passed by `smooth` when it
// flushes its remaining buffer (matching how `words` and `clauses` are called).
export type TokenizeFn = (
  text: string,
  eof?: boolean
) => (readonly [token: string, index: number])[];

export type TokenizeType = "preserve" | "chars" | "words" | "clauses";

export type SmoothOptions = CalculateDelayOptions & {
  tokenize?: TokenizeType | TokenizeFn;
};
function preserve(buffer: string) {
  return [[buffer, buffer.length] as const];
}

function chars(buffer: string) {
  return buffer.split("").map((token, index) => [token, index + 1] as const);
}

function chunks(buffer: string, regex: RegExp, inclusive = false, eof = false) {
  const ws = [];
  let lastIndex = 0;
  for (let currentIndex = 0; currentIndex < buffer.length; currentIndex++) {
    if (regex.test(buffer[currentIndex]!)) {
      ws.push([
        buffer.slice(lastIndex, currentIndex + (inclusive ? 1 : 0)),
        currentIndex + (inclusive ? 1 : 0),
      ] as const);
      // Skip past the delimiter when it was included in the current chunk,
      // otherwise the next chunk would repeat it.
      lastIndex = currentIndex + (inclusive ? 1 : 0);
    }
  }
  if (eof) {
    ws.push([buffer.slice(lastIndex), buffer.length] as const);
  }
  return ws;
}

function words(buffer: string, eof = false) {
  return chunks(buffer, /\s/, false, eof);
}

function clauses(buffer: string, eof = false) {
  return chunks(buffer, /[.,!?;]/, true, eof);
}

const tokenizers = {
  chars,
  words,
  clauses,
  preserve,
} as const;
/**
 * Smooth a stream of LLM tokens into a stream of characters or semantic chunks
 * while reducing jitter by stabilising output timing.
 *
 * @param streamingData A stream of LLM tokens.
 * @param options Options for the smoothing algorithm.
 */
export async function* smooth(
  streamingData: AsyncGenerator<string | undefined>,
  { tokenize: _tokenize = chars, ...options }: SmoothOptions = {}
) {
  const tokenize =
    typeof _tokenize === "function" ? _tokenize : tokenizers[_tokenize];

  const queue = new AsyncQueue<string | undefined>();

  void (async () => {
    let buffer = "";
    let lastIndex: number | undefined;
    try {
      for await (const oldToken of streamingData) {
        buffer += oldToken ?? "";
        for (const [newToken, index] of tokenize(buffer)) {
          queue.enqueue(newToken);
          lastIndex = index;
        }
        if (typeof lastIndex === "number") {
          buffer = buffer.slice(lastIndex);
          lastIndex = undefined;
        }
      }

      // Flush the buffer.
      for (const [newToken] of tokenize(buffer, true)) {
        queue.enqueue(newToken);
      }
    } catch (error) {
      queue.enqueue(error as Error);
    } finally {
      queue.enqueue(undefined);
    }
  })();

  while (true) {
    const newToken = await queue.dequeue();
    if (newToken === undefined) {
      break;
    }

    yield newToken;

    const delay = calculateDelay(queue.size(), options);
    if (delay === 0) {
      continue;
    }

    await sleep(delay);
  }
}
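
As a rough illustration of the pacing behaviour (with the default `initialDelay = 32` and `zeroDelayQueueSize = 64`), the per-token delay falls off linearly as the consumer falls behind and the queue grows:

calculateDelay(0);  // 32 — queue is empty, so pace output at the initial delay
calculateDelay(32); // 16 — queue is half "full", so speed up
calculateDelay(64); // 0  — queue has reached zeroDelayQueueSize, so emit as fast as possible
calculateDelay(96); // 0  — clamped at zero by Math.max
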

sebinsua commented Nov 13, 2023

I experimented with putting this on the back-end using the `ai` package in a Next.js app, with a client-side React library on the UI. In order to do so, I had to transform to and from AsyncIterator and ReadableStream (since the latter doesn't support async iterables in V8/Chrome yet).

This doesn't actually work right now: the stream output does not arrive smoothly and instead a large chunk is output at once. I've not yet confirmed whether this is due to (1) the complex JSON response structure being output, (2) some kind of low-level HTTP/TCP chunking/segmentation while streaming responses, (3) whatever the browser does when receiving responses, (4) the way React has been configured to render text as it arrives, or something else. I'll need to strip the logic down so it's small enough to inspect each layer and find out what is happening.

In practice, it might always be better for this logic to exist on the UI side, since that way we can smooth out the jitter and pauses that are caused by network conditions, too.

Either way, here are some helpful utilities for integrating with Node.js streams:

/**
 * Implements ReadableStream.from(asyncIterable), which isn't documented in MDN and isn't implemented in node.
 * https://github.com/whatwg/streams/commit/8d7a0bf26eb2cc23e884ddbaac7c1da4b91cf2bc
 *
 * Inlined from: https://github.com/vercel/ai/blob/8429dce6e6a650cb837a4aafb42367a618fa03e4/packages/core/streams/ai-stream.ts#L249C1-L266C2
 */
export function readableFromAsyncIterable<T>(iterable: AsyncIterable<T>) {
  let it = iterable[Symbol.asyncIterator]();
  return new ReadableStream<T>({
    async pull(controller) {
      const { done, value } = await it.next();
      if (done) controller.close();
      else controller.enqueue(value);
    },

    async cancel(reason) {
      await it.return?.(reason);
    },
  });
}

export async function* readableToAsyncIterable<T>(
  readableStream: ReadableStream<T>,
): AsyncIterable<T> {
  const reader = readableStream.getReader();
  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      yield value;
    }
  } finally {
    reader.releaseLock();
  }
}


function uint8ArrayToString(uint8Array: Uint8Array) {
  const decoder = new TextDecoder();
  return decoder.decode(uint8Array);
}

function stringToUint8Array(str: string) {
  const encoder = new TextEncoder();
  return encoder.encode(str);
}

async function* tokenAsyncIterable(stream: ReadableStream<Uint8Array>) {
  for await (let chunk of readableToAsyncIterable(stream)) {
    yield uint8ArrayToString(chunk);
  }
}

async function* uint8ArrayAsyncIterable(
  asyncIterable: AsyncIterable<string>,
): AsyncIterable<Uint8Array> {
  for await (let chunk of asyncIterable) {
    yield stringToUint8Array(chunk);
  }
}

Then use like so:

return new StreamingTextResponse(
  readableFromAsyncIterable(
    uint8ArrayAsyncIterable(smooth(tokenAsyncIterable(stream))),
  ),
);
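
And if the smoothing is moved to the UI side instead, as suggested above, a rough sketch might look something like this (the `/api/chat` endpoint and the `onToken` callback are placeholders, not part of the gist):

async function consumeSmoothedStream(onToken: (token: string) => void) {
  const response = await fetch("/api/chat", { method: "POST" });

  // Reuse the helpers above: decode the response's ReadableStream<Uint8Array>
  // into an async iterable of strings, then smooth it on the client so that
  // jitter caused by network conditions is absorbed too.
  for await (const token of smooth(tokenAsyncIterable(response.body!))) {
    onToken(token);
  }
}
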
