Running Vision Encoder Decoder Models in Swift (or any language)
The model I am going to be using for this blog post is OleehyO/TexTeller, which is built on top of Microsoft's TrOCR model (which is bloody good for handwritten text recognition). I am working on an alternative to MathPix's Snipping Tool for macOS and wanted to be able to run this model without having to deal with Python.
The title of this post says any language because the general strategy for the encoder, decoder, and tokenizer remains the same.
Transformers can See??!
Transformers first started "seeing" with the original Vision Transformer (ViT) architecture. An image is split into patches, which are flattened, linearly embedded, and combined with positional encodings. Usually this is done for an image classification task, where the output is a single class label (or multiple labels for multi-label classification) representing the predicted image category.
The TrOCR paper introduced the idea of using pre-trained image and text transformers for the text recognition task in OCR. The basic idea is that an encoder model encodes the image into hidden states, which are fed to the decoder; the decoder then auto-regressively generates tokens, which the tokenizer translates back into text.
This Python pseudocode represents how this entire process works.
model = myAmazingVisionEncoderDecoderModel()
tokenizer = myAmazingTokenizer()
last_hidden_state = model.encoder(pixel_values).last_hidden_state
decoder_ids = [tokenizer.bos_token_id]
max_length = 50
for _ in range(max_length):
    logits = model.decoder(input_ids=decoder_ids, encoder_hidden_states=last_hidden_state)
    next_token = argmax(logits[-1])  # greedy: take the most likely next token
    if next_token == tokenizer.eos_token_id:
        break
    decoder_ids.append(next_token)
print(tokenizer.decode(decoder_ids[1:]))
Here, bos stands for beginning of sequence, and eos stands for end of sequence.
Padding and Attention Mask
In the above code we do not care about the size of input_ids, but in some cases we have to provide an input of a fixed size. Say we had to provide an input of shape [1, 100]: we would make use of the padding token. If we have only generated the decoder tokens tokenizer.bos_token_id, 280, 95 so far, we pad the rest of the input with tokenizer.pad_token_id (say 1). TrOCR then generates an attention mask by comparing the input against the padding token, so the model can ignore the padded positions.
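As a rough sketch (the token ids here are made up), the padded input and the attention mask derived from it look something like this:
import torch

pad_token_id = 1            # hypothetical pad token id
decoder_ids = [0, 280, 95]  # bos + the two tokens generated so far
max_length = 100

# pad the ids out to the fixed size the model expects
input_ids = torch.tensor([decoder_ids + [pad_token_id] * (max_length - len(decoder_ids))])

# 1 for real tokens, 0 for padding, so the model can ignore the padded positions
attention_mask = (input_ids != pad_token_id).long()

print(input_ids.shape, attention_mask[0, :6])  # torch.Size([1, 100]) tensor([1, 1, 1, 0, 0, 0])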
Exporting
There are three ways that come to my mind to run this model on-device.
- Ship all necessary Python packages (why would you ever do this if you are not using Python directly)
- Convert to an ONNX model
- Convert to a CoreML model
Converting the model to ONNX/CoreML format requires tracing the model. Since TrOCR
and TexTeller
are implemented using PyTorch, we can do this via torch.jit.trace
or torch.jit.script
. I like using torch.jit.trace
because it is a bit more mature.
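As a quick illustration of the difference (a toy module, not from TexTeller): torch.jit.trace records only the operations executed for the example input, so data-dependent branches get frozen, while torch.jit.script compiles the Python and keeps the control flow. This limitation comes back to bite us later with the decoder.
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        # data-dependent control flow
        if x.sum() > 0:
            return x * 2
        return x - 1

# trace: warns and bakes in the branch taken for torch.ones(3)
traced = torch.jit.trace(Toy().eval(), torch.ones(3))
# script: keeps the if/else intact
scripted = torch.jit.script(Toy())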
Hugging Face 🤗
This is the easiest way to export a model from Hugging Face to ONNX.
$ optimum-cli export onnx --model "OleehyO/TexTeller" exported
That's it. The amazing people behind Hugging Face have done a lot of work supporting a lot of models. This command generates a bunch of files in the exported
directory
$ ls -la exported
total 5853056
drwxr-xr-x@ 14 navanchauhan staff 448 Oct 19 19:39 .
drwxr-xr-x@ 19 navanchauhan staff 608 Oct 19 19:42 ..
-rw-r--r--@ 1 navanchauhan staff 56003 Oct 13 17:33 added_tokens.json
-rw-r--r--@ 1 navanchauhan staff 4504 Oct 13 17:33 config.json
-rw-r--r--@ 1 navanchauhan staff 908716081 Oct 13 17:33 decoder_model.onnx
-rw-r--r--@ 1 navanchauhan staff 909186959 Oct 13 17:33 decoder_model_merged.onnx
-rw-r--r--@ 1 navanchauhan staff 833037201 Oct 13 17:33 decoder_with_past_model.onnx
-rw-r--r--@ 1 navanchauhan staff 343553824 Oct 13 17:33 encoder_model.onnx
-rw-r--r--@ 1 navanchauhan staff 154 Oct 13 17:33 generation_config.json
-rw-r--r--@ 1 navanchauhan staff 70943 Oct 13 17:33 merges.txt
-rw-r--r--@ 1 navanchauhan staff 958 Oct 13 17:33 special_tokens_map.json
-rw-r--r--@ 1 navanchauhan staff 1370259 Oct 13 17:33 tokenizer.json
-rw-r--r--@ 1 navanchauhan staff 592739 Oct 13 17:33 tokenizer_config.json
-rw-r--r--@ 1 navanchauhan staff 146663 Oct 13 17:33 vocab.json
If you just care about inferencing, jump to the final section. If you want to see how you can trace the model, continue reading. I may update this section with pure Python code to run the encoder and decoder using onnxruntime.
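In the meantime, here is a rough, untested sketch of what that would look like with onnxruntime, assuming greedy decoding and the input/output names used later in this post (pixel_values, input_ids, encoder_hidden_states):
import numpy as np
import onnxruntime as ort

encoder = ort.InferenceSession("exported/encoder_model.onnx")
decoder = ort.InferenceSession("exported/decoder_model.onnx")

# pixel_values: a [1, 1, 448, 448] float32 array produced by the preprocessing pipeline below
pixel_values = np.random.rand(1, 1, 448, 448).astype(np.float32)
hidden_state = encoder.run(None, {"pixel_values": pixel_values})[0]

bos_token_id, eos_token_id = 0, 2  # ids used by the TexTeller tokenizer
decoder_ids = [bos_token_id]
for _ in range(100):
    logits = decoder.run(None, {
        "input_ids": np.array([decoder_ids], dtype=np.int64),
        "encoder_hidden_states": hidden_state,
    })[0]
    next_token = int(logits[0, -1].argmax())
    if next_token == eos_token_id:
        break
    decoder_ids.append(next_token)
# decoder_ids[1:] can then be passed to the tokenizer's decode()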
PyTorch Tracing
I extracted all the relevant configuration and utility functions from the TexTeller GitHub repository. I also loaded up a simple example image.
from PIL import Image
import requests
url = 'https://miro.medium.com/v2/resize:fit:1400/1*OReJHtogeA62SmSwzNzgvw.png'
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
# Formula image(grayscale) mean and variance
IMAGE_MEAN = 0.9545467
IMAGE_STD = 0.15394445
# Vocabulary size for TexTeller
VOCAB_SIZE = 15000
# Fixed size for input image for TexTeller
FIXED_IMG_SIZE = 448
# Image channel for TexTeller
IMG_CHANNELS = 1 # grayscale image
# Max size of token for embedding
MAX_TOKEN_SIZE = 1024
# Scaling ratio for random resizing when training
MAX_RESIZE_RATIO = 1.15
MIN_RESIZE_RATIO = 0.75
# Minimum height and width for input image for TexTeller
MIN_HEIGHT = 12
MIN_WIDTH = 30
num_beams = 1
from torchvision.transforms import v2
import torch
import cv2
import numpy as np
from typing import List, Union
general_transform_pipeline = v2.Compose([
v2.ToImage(),
v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point
v2.Grayscale(),
v2.Resize(
size=FIXED_IMG_SIZE - 1,
interpolation=v2.InterpolationMode.BICUBIC,
max_size=FIXED_IMG_SIZE,
antialias=True
),
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
])
import random
from collections import Counter
import re
def trim_white_border(image: np.ndarray):
if len(image.shape) != 3 or image.shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
if image.dtype != np.uint8:
raise ValueError(f"Image should stored in uint8")
corners = [tuple(image[0, 0]), tuple(image[0, -1]),
tuple(image[-1, 0]), tuple(image[-1, -1])]
bg_color = Counter(corners).most_common(1)[0][0]
bg_color_np = np.array(bg_color, dtype=np.uint8)
h, w = image.shape[:2]
bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
diff = cv2.absdiff(image, bg)
mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
threshold = 15
_, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
x, y, w, h = cv2.boundingRect(diff)
trimmed_image = image[y:y+h, x:x+w]
return trimmed_image
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
randi = [random.randint(0, max_size) for _ in range(4)]
pad_height_size = randi[1] + randi[3]
pad_width_size = randi[0] + randi[2]
if (pad_height_size + image.shape[0] < 30):
compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
randi[1] += compensate_height
randi[3] += compensate_height
if (pad_width_size + image.shape[1] < 30):
compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
randi[0] += compensate_width
randi[2] += compensate_width
return v2.functional.pad(
torch.from_numpy(image).permute(2, 0, 1),
padding=randi,
padding_mode='constant',
fill=(255, 255, 255)
)
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
images = [
v2.functional.pad(
img,
padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
)
for img in images
]
return images
import re
def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
result = ""
i = 0
n = len(input_str)
while i < n:
if input_str[i:i+len(old_inst)] == old_inst:
# check if the old_inst is followed by old_surr_l
start = i + len(old_inst)
else:
result += input_str[i]
i += 1
continue
if start < n and input_str[start] == old_surr_l:
# found an old_inst followed by old_surr_l, now look for the matching old_surr_r
count = 1
j = start + 1
escaped = False
while j < n and count > 0:
if input_str[j] == '\\' and not escaped:
escaped = True
j += 1
continue
if input_str[j] == old_surr_r and not escaped:
count -= 1
if count == 0:
break
elif input_str[j] == old_surr_l and not escaped:
count += 1
escaped = False
j += 1
if count == 0:
assert j < n
assert input_str[start] == old_surr_l
assert input_str[j] == old_surr_r
inner_content = input_str[start + 1:j]
# Replace the content with new pattern
result += new_inst + new_surr_l + inner_content + new_surr_r
i = j + 1
continue
else:
assert count >= 1
assert j == n
print("Warning: unbalanced surrogate pair in input string")
result += new_inst + new_surr_l
i = start + 1
continue
else:
result += input_str[i:start]
i = start
if old_inst != new_inst and (old_inst + old_surr_l) in result:
return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
else:
return result
def find_substring_positions(string, substring):
positions = [match.start() for match in re.finditer(re.escape(substring), string)]
return positions
def rm_dollar_surr(content):
pattern = re.compile(r'\\[a-zA-Z]+\$.*?\$|\$.*?\$')
matches = pattern.findall(content)
for match in matches:
if not re.match(r'\\[a-zA-Z]+', match):
new_match = match.strip('$')
content = content.replace(match, ' ' + new_match + ' ')
return content
def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
pos = find_substring_positions(input_str, old_inst + old_surr_l)
res = list(input_str)
for p in pos[::-1]:
res[p:] = list(change(''.join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r))
res = ''.join(res)
return res
def to_katex(formula: str) -> str:
res = formula
# remove mbox surrounding
res = change_all(res, r'\mbox ', r' ', r'{', r'}', r'', r'')
res = change_all(res, r'\mbox', r' ', r'{', r'}', r'', r'')
# remove hbox surrounding
res = re.sub(r'\\hbox to ?-? ?\d+\.\d+(pt)?\{', r'\\hbox{', res)
res = change_all(res, r'\hbox', r' ', r'{', r'}', r'', r' ')
# remove raise surrounding
res = re.sub(r'\\raise ?-? ?\d+\.\d+(pt)?', r' ', res)
# remove makebox
res = re.sub(r'\\makebox ?\[\d+\.\d+(pt)?\]\{', r'\\makebox{', res)
res = change_all(res, r'\makebox', r' ', r'{', r'}', r'', r' ')
# remove vbox surrounding, scalebox surrounding
res = re.sub(r'\\raisebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\raisebox{', res)
res = re.sub(r'\\scalebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\scalebox{', res)
res = change_all(res, r'\scalebox', r' ', r'{', r'}', r'', r' ')
res = change_all(res, r'\raisebox', r' ', r'{', r'}', r'', r' ')
res = change_all(res, r'\vbox', r' ', r'{', r'}', r'', r' ')
origin_instructions = [
r'\Huge',
r'\huge',
r'\LARGE',
r'\Large',
r'\large',
r'\normalsize',
r'\small',
r'\footnotesize',
r'\tiny'
]
for (old_ins, new_ins) in zip(origin_instructions, origin_instructions):
res = change_all(res, old_ins, new_ins, r'$', r'$', '{', '}')
res = change_all(res, r'\boldmath ', r'\bm', r'{', r'}', r'{', r'}')
res = change_all(res, r'\boldmath', r'\bm', r'{', r'}', r'{', r'}')
res = change_all(res, r'\boldmath ', r'\bm', r'$', r'$', r'{', r'}')
res = change_all(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')
res = change_all(res, r'\scriptsize', r'\scriptsize', r'$', r'$', r'{', r'}')
res = change_all(res, r'\emph', r'\textit', r'{', r'}', r'{', r'}')
res = change_all(res, r'\emph ', r'\textit', r'{', r'}', r'{', r'}')
origin_instructions = [
r'\left',
r'\middle',
r'\right',
r'\big',
r'\Big',
r'\bigg',
r'\Bigg',
r'\bigl',
r'\Bigl',
r'\biggl',
r'\Biggl',
r'\bigm',
r'\Bigm',
r'\biggm',
r'\Biggm',
r'\bigr',
r'\Bigr',
r'\biggr',
r'\Biggr'
]
for origin_ins in origin_instructions:
res = change_all(res, origin_ins, origin_ins, r'{', r'}', r'', r'')
res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
if res.endswith(r'\newline'):
res = res[:-8]
# remove multiple spaces
res = re.sub(r'(\\,){1,}', ' ', res)
res = re.sub(r'(\\!){1,}', ' ', res)
res = re.sub(r'(\\;){1,}', ' ', res)
res = re.sub(r'(\\:){1,}', ' ', res)
res = re.sub(r'\\vspace\{.*?}', '', res)
# merge consecutive text
def merge_texts(match):
texts = match.group(0)
merged_content = ''.join(re.findall(r'\\text\{([^}]*)\}', texts))
return f'\\text{{{merged_content}}}'
res = re.sub(r'(\\text\{[^}]*\}\s*){2,}', merge_texts, res)
res = res.replace(r'\bf ', '')
res = rm_dollar_surr(res)
# remove extra spaces (keeping only one)
res = re.sub(r' +', ' ', res)
return res.strip()
def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
images = [np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images]
images = [trim_white_border(image) for image in images]
images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image]
images = padding(images, FIXED_IMG_SIZE)
return images
imgs = inference_transform([image])
from transformers import VisionEncoderDecoderModel
mymodel = VisionEncoderDecoderModel.from_pretrained("OleehyO/TexTeller").eval()
from transformers import RobertaTokenizerFast
tokenizer = RobertaTokenizerFast.from_pretrained("OleehyO/TexTeller")
Encoder Model
In an ideal world we would just be able to run torch.jit.trace
directly on the model with the processed image:
encoder_model = mymodel.encoder
traced_model = torch.jit.trace(encoder_model, torch.stack(imgs))
/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:4713: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead
warnings.warn(
/usr/local/lib/python3.10/dist-packages/transformers/models/vit/modeling_vit.py:163: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if num_channels != self.num_channels:
/usr/local/lib/python3.10/dist-packages/transformers/models/vit/modeling_vit.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if height != self.image_size[0] or width != self.image_size[1]:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-2-1f8652b4fe66> in <cell line: 2>()
1 encoder_model = mymodel.encoder
----> 2 traced_model = torch.jit.trace(encoder_model, torch.stack(imgs))
2 frames
/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
1273 else:
1274 example_inputs = make_tuple(example_inputs)
-> 1275 module._c._create_method_from_trace(
1276 method_name,
1277 func,
RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
But we run into a RuntimeError: the trace function does not like a dictionary output, since there is no guarantee that the same keys will be returned every time. We could pass strict=False, but there is a better solution.
from collections import namedtuple
encoder_model = mymodel.encoder
EncoderOutput = namedtuple("EncoderOutput", encoder_model.forward(torch.stack(imgs)).keys())
class EncoderWrapper(torch.nn.Module):
def __init__(self, encoder):
super().__init__()
self.encoder = encoder
def forward(self, pixel_values):
output = self.encoder(pixel_values)
return EncoderOutput(**output)
wrapped_encoder_model = EncoderWrapper(encoder_model)
traced_model = torch.jit.trace(wrapped_encoder_model, torch.stack(imgs))
This can then be exported to a CoreML model directly.
import coremltools as ct
coreml_encoder_model = ct.convert(
traced_model,
inputs=[ct.TensorType(name="pixel_values", shape=torch.stack(imgs).shape)]
)
coreml_encoder_model.save("encoder.mlpackage")
In Python, this can be used to generate the last hidden state by running:
encoder_hidden_states = coreml_encoder_model.predict({"pixel_values": imgs})['hidden_states']
Decoder Model
This is where things get tricky. When running the model directly we do not have to keep track of the shape of the decoder ids, but torch.jit.trace requires the input shapes to be static so it can do its magic tracing the model. This is where the padding trick comes into play. The TrOCR model implementation states that the attention mask is automatically calculated if it is not passed to the model, which means we can ignore it for now. We also can't simply write an if len(input_ids) < max_length check, because trace() does not capture Python boolean logic.
decoder = mymodel.decoder.eval()
max_decoder_length = 100
input_ids = torch.randint(3, mymodel.config.decoder.vocab_size, (1, 80))
input_ids[0][0] = tokenizer.bos_token_id
encoder_hidden_states = torch.randn(1, 785, 768) # Example encoder_hidden_states which matches the shape of the encoder's output
def pad_input_ids(input_ids, max_length, pad_token_id):
input_ids = torch.nn.functional.pad(input_ids, (0, max_length - input_ids.size(1)), 'constant', pad_token_id)
return input_ids
class DecoderWrapper(torch.nn.Module):
def __init__(self, traced_decoder):
super().__init__()
self.traced_decoder = traced_decoder
def forward(self, input_ids=None, encoder_hidden_states=None):
correct_inputs = input_ids[input_ids != 1]
correct_inputs_reshaped = correct_inputs.unsqueeze(0)
return self.traced_decoder(
input_ids=correct_inputs_reshaped,
encoder_hidden_states=encoder_hidden_states,
use_cache=False,
)['logits']
wrapped_decoder = DecoderWrapper(decoder)
input_ids = pad_input_ids(input_ids, max_decoder_length, tokenizer.pad_token_id)
traced_decoder = torch.jit.trace(wrapped_decoder, (input_ids, encoder_hidden_states))
I did realise afterwards that I could have simplified the pad_input_ids
function since we are not tracing it. Oh well!
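For instance, something like this would have been enough (a sketch, assuming the generated ids are kept as a plain Python list; the torch version above works just as well):
def pad_input_ids_simple(ids, max_length, pad_token_id):
    # nothing here gets traced, so plain Python list manipulation is fine
    return torch.tensor([ids + [pad_token_id] * (max_length - len(ids))])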
The use_cache flag controls whether the model outputs past key values, which can then be passed to the next run to speed things up a bit, but that is beyond the scope of this post.
coreml_decoder_model = ct.convert(
traced_decoder.eval(),
minimum_deployment_target=ct.target.iOS14, # Fixes issue with the CoreML Tools version I was using
inputs=[
ct.TensorType(
name="input_ids",
shape=input_ids.shape,
dtype=int
),
ct.TensorType(
name="last_hidden_state", shape=encoder_hidden_states.shape
)
],
outputs=[ct.TensorType(name='logits')]
)
To use it for prediction:
start_token_id = tokenizer.cls_token_id
decoder_input_ids = torch.tensor([[start_token_id]], dtype=torch.int32)
max_length = 100
decoded_token_ids = []
encoder_output = coreml_encoder_model.predict({"pixel_values": imgs})['hidden_states']
for _ in range(max_length):
logits = coreml_decoder_model.predict({
"input_ids": pad_input_ids(decoder_input_ids, max_decoder_length, tokenizer.pad_token_id).unsqueeze(0),
"last_hidden_state": encoder_output
})['logits']
next_token_id = np.argmax(logits, axis=-1)[0,-1]
if next_token_id == tokenizer.eos_token_id:
break
decoder_input_ids = torch.cat([decoder_input_ids, torch.tensor([[next_token_id]], dtype=torch.int32)], dim=1)
decoded_token_ids.append(next_token_id)
output_text = tokenizer.decode(decoded_token_ids, skip_special_tokens=True)
print(f"Generated Text: {output_text}")
What about the tokenizer?
The tokenizer for this model, RobertaTokenizerFast, is a fast (Rust-backed) implementation of Byte-Pair Encoding (BPE) tokenization. For our use case, we can create a simple implementation in Python using the vocabulary and tokenizer config file for the model. (Swift implementation in the next section)
import json
import re
class MyTokenizer:
def __init__(self, vocab_file, tokenizer_file):
with open(vocab_file, 'r', encoding='utf-8') as f:
self.vocab = json.load(f)
with open(tokenizer_file, 'r', encoding='utf-8') as f:
self.tokenizer_config = json.load(f)
self.id_to_token = {v: k for k, v in self.vocab.items()}
self.special_tokens = self.tokenizer_config.get('added_tokens', [])
self.cls_token_id = self.vocab.get('<s>')
self.sep_token_id = self.vocab.get('</s>')
self.pad_token_id = self.vocab.get('<pad>')
self.unk_token_id = self.vocab.get('<unk>')
def encode(self, text):
tokens = self._tokenize(text)
token_ids = [self.vocab.get(token, self.unk_token_id) for token in tokens]
return token_ids
def decode(self, token_ids, skip_special_tokens = True):
tokens = [self.id_to_token.get(token_id, self.id_to_token[self.unk_token_id]) for token_id in token_ids]
if skip_special_tokens:
tokens = [token for token in tokens if token not in self.special_tokens and token != '</s>']
# Replace 'Ġ' with a space to handle RoBERTa's special space tokenization
decoded_string = self._convert_tokens_to_string(tokens)
return decoded_string
def _convert_tokens_to_string(self, tokens):
text = ''.join(tokens).replace('Ġ', ' ')
text = re.sub(r'\s([?.!,\'"](?:\s|$))', r'\1', text)
return text.strip()
def _tokenize(self, text) :
text = self._clean_text(text)
words = re.findall(r'\w+|\S', text)
tokens = []
for word in words:
tokens.extend(self._bpe_encode(word))
return tokens
def _bpe_encode(self, word):
if word in self.vocab:
return [word]
chars = list(word)
for i in range(len(chars) - 1):
pair = chars[i] + chars[i + 1]
if pair in self.vocab:
chars[i] = pair
del chars[i + 1]
return chars
def _clean_text(self, text):
text = text.strip()
return text
Now, we can replace the last call that we use to generate text with
output_text = MyTokenizer("./exported/vocab.json", "./exported/tokenizer.json").decode(decoded_token_ids, skip_special_tokens=True)
print(f"Generated Text: {output_text}")
Let's bring it all together
These code snippets were used in an Xcode macOS app, but can be easily converted to be used in other projects. decoder_model.onnx
, encoder_model.onnx
, vocab.json
, and tokenizer.json
were copied from the exported
directory after exporting using optimum-cli
. The CoreML models can be exported and imported similarly.
Image Processing
Do note that this and the next section are very specific to the input processing required for the TexTeller model.
// ImageUtils.swift
import Foundation
import CoreImage
import AppKit
let IMAGE_MEAN: CGFloat = 0.9545467
let IMAGE_STD: CGFloat = 0.15394445
let FIXED_IMG_SIZE: CGFloat = 448
let IMG_CHANNELS: Int = 1
let MIN_HEIGHT: CGFloat = 12
let MIN_WIDTH: CGFloat = 30
func loadImage(from urlString: String) -> NSImage? {
guard let url = URL(string: urlString), let imageData = try? Data(contentsOf: url) else {
return nil
}
return NSImage(data: imageData)
}
func nsImageToCIImage(_ image: NSImage) -> CIImage? {
guard let data = image.tiffRepresentation,
let bitmapImage = NSBitmapImageRep(data: data),
let cgImage = bitmapImage.cgImage else {
return nil
}
return CIImage(cgImage: cgImage)
}
func trimWhiteBorder(image: CIImage) -> CIImage? {
let context = CIContext()
guard let cgImage = context.createCGImage(image, from: image.extent) else {
return nil
}
let width = cgImage.width
let height = cgImage.height
let colorSpace = CGColorSpaceCreateDeviceRGB()
let bytesPerPixel = 4
let bytesPerRow = bytesPerPixel * width
let bitmapInfo = CGImageAlphaInfo.premultipliedLast.rawValue
var pixelData = [UInt8](repeating: 0, count: height * bytesPerRow)
guard let contextRef = CGContext(
data: &pixelData,
width: width,
height: height,
bitsPerComponent: 8,
bytesPerRow: bytesPerRow,
space: colorSpace,
bitmapInfo: bitmapInfo
) else {
return nil
}
contextRef.draw(cgImage, in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height)))
let whitePixel: [UInt8] = [255, 255, 255, 255]
var minX = width
var minY = height
var maxX: Int = 0
var maxY: Int = 0
for y in 0..<height {
for x in 0..<width {
let pixelIndex = (y * bytesPerRow) + (x * bytesPerPixel)
let pixel = Array(pixelData[pixelIndex..<(pixelIndex + 4)])
if pixel != whitePixel {
if x < minX { minX = x }
if x > maxX { maxX = x }
if y < minY { minY = y }
if y > maxY { maxY = y }
}
}
}
if minX == width || minY == height || maxX == 0 || maxY == 0 {
return image
}
let croppedRect = CGRect(x: CGFloat(minX), y: CGFloat(minY), width: CGFloat(maxX - minX), height: CGFloat(maxY - minY))
return image.cropped(to: croppedRect)
}
func addWhiteBorder(to image: CIImage, maxSize: CGFloat) -> CIImage {
let randomPadding = (0..<4).map { _ in CGFloat(arc4random_uniform(UInt32(maxSize))) }
var xPadding = randomPadding[0] + randomPadding[2]
var yPadding = randomPadding[1] + randomPadding[3]
if xPadding + image.extent.width < MIN_WIDTH {
let compensateWidth = (MIN_WIDTH - (xPadding + image.extent.width)) * 0.5 + 1
xPadding += compensateWidth
}
if yPadding + image.extent.height < MIN_HEIGHT {
let compensateHeight = (MIN_HEIGHT - (yPadding + image.extent.height)) * 0.5 + 1
yPadding += compensateHeight
}
let padFilter = CIFilter(name: "CICrop")!
let paddedRect = CGRect(x: image.extent.origin.x - randomPadding[0],
y: image.extent.origin.y - randomPadding[1],
width: image.extent.width + xPadding,
height: image.extent.height + yPadding)
padFilter.setValue(image, forKey: kCIInputImageKey)
padFilter.setValue(CIVector(cgRect: paddedRect), forKey: "inputRectangle")
return padFilter.outputImage ?? image
}
func padding(images: [CIImage], requiredSize: CGFloat) -> [CIImage] {
return images.map { image in
let widthPadding = requiredSize - image.extent.width
let heightPadding = requiredSize - image.extent.height
return addWhiteBorder(to: image, maxSize: max(widthPadding, heightPadding))
}
}
func inferenceTransform(images: [NSImage]) -> [CIImage] {
let ciImages = images.compactMap { nsImageToCIImage($0) }
let trimmedImages = ciImages.compactMap { trimWhiteBorder(image: $0) }
let paddedImages = padding(images: trimmedImages, requiredSize: FIXED_IMG_SIZE)
return paddedImages
}
func ciImageToFloatArray(_ image: CIImage, size: CGSize) -> [Float] {
let context = CIContext()
guard let cgImage = context.createCGImage(image, from: image.extent) else {
return []
}
let width = Int(size.width)
let height = Int(size.height)
var pixelData = [UInt8](repeating: 0, count: width * height)
let colorSpace = CGColorSpaceCreateDeviceGray()
guard let contextRef = CGContext(
data: &pixelData,
width: width,
height: height,
bitsPerComponent: 8,
bytesPerRow: width,
space: colorSpace,
bitmapInfo: CGImageAlphaInfo.none.rawValue
) else {
return []
}
contextRef.draw(cgImage, in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height)))
return pixelData.map { Float($0) / 255.0 }
}
KaTeX Utils
Just some basic regex stuff ported to Swift
// KatexUtils.swift
import Foundation
func change(_ inputStr: String, oldInst: String, newInst: String, oldSurrL: Character, oldSurrR: Character, newSurrL: String, newSurrR: String) -> String {
var result = ""
var i = 0
let n = inputStr.count
let inputArray = Array(inputStr)
while i < n {
if i + oldInst.count <= n && inputStr[inputStr.index(inputStr.startIndex, offsetBy: i)..<inputStr.index(inputStr.startIndex, offsetBy: i + oldInst.count)] == oldInst {
let start = i + oldInst.count
if start < n && inputArray[start] == oldSurrL {
var count = 1
var j = start + 1
var escaped = false
while j < n && count > 0 {
if inputArray[j] == "\\" && !escaped {
escaped = true
j += 1
continue
}
if inputArray[j] == oldSurrR && !escaped {
count -= 1
if count == 0 {
break
}
} else if inputArray[j] == oldSurrL && !escaped {
count += 1
}
escaped = false
j += 1
}
if count == 0 {
let innerContent = String(inputArray[(start + 1)..<j])
result += newInst + newSurrL + innerContent + newSurrR
i = j + 1
continue
} else {
result += newInst + newSurrL
i = start + 1
continue
}
}
}
result.append(inputArray[i])
i += 1
}
if oldInst != newInst && result.contains(oldInst + String(oldSurrL)) {
return change(result, oldInst: oldInst, newInst: newInst, oldSurrL: oldSurrL, oldSurrR: oldSurrR, newSurrL: newSurrL, newSurrR: newSurrR)
}
return result
}
func findSubstringPositions(_ string: String, substring: String) -> [Int] {
var positions: [Int] = []
var searchRange = string.startIndex..<string.endIndex
while let range = string.range(of: substring, options: [], range: searchRange) {
let position = string.distance(from: string.startIndex, to: range.lowerBound)
positions.append(position)
searchRange = range.upperBound..<string.endIndex
}
return positions
}
func rmDollarSurr(content: String) -> String {
let pattern = try! NSRegularExpression(pattern: "\\\\[a-zA-Z]+\\$.*?\\$|\\$.*?\\$", options: [])
var newContent = content
let matches = pattern.matches(in: content, options: [], range: NSRange(content.startIndex..<content.endIndex, in: content))
for match in matches.reversed() {
let matchedString = (content as NSString).substring(with: match.range)
if !matchedString.starts(with: "\\") {
let strippedMatch = matchedString.replacingOccurrences(of: "$", with: "")
newContent = newContent.replacingOccurrences(of: matchedString, with: " \(strippedMatch) ")
}
}
return newContent
}
func changeAll(inputStr: String, oldInst: String, newInst: String, oldSurrL: Character, oldSurrR: Character, newSurrL: String, newSurrR: String) -> String {
let positions = findSubstringPositions(inputStr, substring: oldInst + String(oldSurrL))
var result = inputStr
for pos in positions.reversed() {
let startIndex = result.index(result.startIndex, offsetBy: pos)
let substring = String(result[startIndex..<result.endIndex])
let changedSubstring = change(substring, oldInst: oldInst, newInst: newInst, oldSurrL: oldSurrL, oldSurrR: oldSurrR, newSurrL: newSurrL, newSurrR: newSurrR)
result.replaceSubrange(startIndex..<result.endIndex, with: changedSubstring)
}
return result
}
func toKatex(formula: String) -> String {
var res = formula
res = changeAll(inputStr: res, oldInst: "\\mbox ", newInst: " ", oldSurrL: "{", oldSurrR: "}", newSurrL: "", newSurrR: "")
res = changeAll(inputStr: res, oldInst: "\\mbox", newInst: " ", oldSurrL: "{", oldSurrR: "}", newSurrL: "", newSurrR: "")
res = res.replacingOccurrences(of: "\\[", with: "")
res = res.replacingOccurrences(of: "\\]", with: "")
res = res.replacingOccurrences(of: "\\\\[?.!,\'\"](?:\\s|$)", with: "", options: .regularExpression)
res = rmDollarSurr(content: res)
res = res.replacingOccurrences(of: " +", with: " ", options: .regularExpression)
return res.trimmingCharacters(in: .whitespacesAndNewlines)
}
Tokenizer
// RobertaTokenizerFast.swift
// I don't think this is very fast ¯\_(ツ)_/¯
import Foundation
class RobertaTokenizerFast {
var vocab: [String: Int] = [:]
var idToToken: [Int: String] = [:]
var specialTokens: [String] = []
var unkTokenId: Int?
init(vocabFile: String, tokenizerFile: String) {
if let vocabURL = Bundle.main.url(forResource: vocabFile, withExtension: "json"),
let vocabData = try? Data(contentsOf: vocabURL),
let vocabDict = try? JSONSerialization.jsonObject(with: vocabData, options: []) as? [String: Int] {
self.vocab = vocabDict
}
if let tokenizerURL = Bundle.main.url(forResource: tokenizerFile, withExtension: "json"),
let tokenizerData = try? Data(contentsOf: tokenizerURL),
let tokenizerConfig = try? JSONSerialization.jsonObject(with: tokenizerData, options: []) as? [String: Any] {
self.specialTokens = tokenizerConfig["added_tokens"] as? [String] ?? []
}
self.idToToken = vocab.reduce(into: [Int: String]()) { $0[$1.value] = $1.key }
self.unkTokenId = vocab["<unk>"]
}
func encode(text: String) -> [Int] {
let tokens = tokenize(text)
return tokens.map { vocab[$0] ?? unkTokenId! }
}
func decode(tokenIds: [Int], skipSpecialTokens: Bool = true) -> String {
let tokens = tokenIds.compactMap { idToToken[$0] }
let filteredTokens = skipSpecialTokens ? tokens.filter { !specialTokens.contains($0) && $0 != "</s>" } : tokens
return convertTokensToString(filteredTokens)
}
private func tokenize(_ text: String) -> [String] {
let cleanedText = cleanText(text)
let words = cleanedText.split(separator: " ").map { String($0) }
var tokens: [String] = []
for word in words {
tokens.append(contentsOf: bpeEncode(word))
}
return tokens
}
private func bpeEncode(_ word: String) -> [String] {
if vocab.keys.contains(word) {
return [word]
}
let chars = Array(word)
var tokens: [String] = []
var i = 0
while i < chars.count {
if i < chars.count - 1 {
let pair = String(chars[i]) + String(chars[i + 1])
if vocab.keys.contains(pair) {
tokens.append(pair)
i += 2
continue
}
}
tokens.append(String(chars[i]))
i += 1
}
return tokens
}
private func cleanText(_ text: String) -> String {
return text.trimmingCharacters(in: .whitespacesAndNewlines)
}
private func convertTokensToString(_ tokens: [String]) -> String {
let text = tokens.joined().replacingOccurrences(of: "Ġ", with: " ")
return text.replacingOccurrences(of: "\\s([?.!,\'\"](?:\\s|$))", with: "$1", options: .regularExpression, range: nil).trimmingCharacters(in: .whitespaces)
}
}
On it with ONNX
import OnnxRuntimeBindings
public enum ModelError: Error {
case encoderModelNotFound
case decoderModelNotFound
case imageError
}
public struct TexTellerModel {
public let encoderSession: ORTSession
public let decoderSession: ORTSession
private let tokenizer: RobertaTokenizerFast
public init() throws {
guard let encoderModelPath = Bundle.main.path(forResource: "encoder_model", ofType: "onnx") else {
print("Encoder model not found...")
throw ModelError.encoderModelNotFound
}
guard let decoderModelPath = Bundle.main.path(forResource: "decoder_model", ofType: "onnx") else {
print("Decoder model not found...")
throw ModelError.decoderModelNotFound
}
let env = try ORTEnv(loggingLevel: .warning)
let coreMLOptions = ORTCoreMLExecutionProviderOptions()
coreMLOptions.enableOnSubgraphs = true
coreMLOptions.createMLProgram = false
let options = try ORTSessionOptions()
// Uncomment below to use CoreML
//try options.appendCoreMLExecutionProvider(with: coreMLOptions)
encoderSession = try ORTSession(env: env, modelPath: encoderModelPath, sessionOptions: options)
decoderSession = try ORTSession(env: env, modelPath: decoderModelPath, sessionOptions: options)
self.tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer")
}
public func texIt(_ image: NSImage, rawString: Bool = false) throws -> String {
let transformedImage = inferenceTransform(images: [image])
if let firstTransformedImage = transformedImage.first {
let pixelValues = ciImageToFloatArray(firstTransformedImage, size: CGSize(width: FIXED_IMG_SIZE, height: FIXED_IMG_SIZE))
let inputTensor = try ORTValue(
tensorData: NSMutableData(
data: Data(bytes: pixelValues, count: pixelValues.count * MemoryLayout<Float>.stride)
),
elementType: .float,
shape: [
1, 1, NSNumber(value: FIXED_IMG_SIZE), NSNumber(value: FIXED_IMG_SIZE)
]
)
let encoderInput: [String: ORTValue] = [
"pixel_values": inputTensor
]
let encoderOutputNames = try self.encoderSession.outputNames()
let encoderOutputs: [String: ORTValue] = try self.encoderSession.run(
withInputs: encoderInput,
outputNames: Set(encoderOutputNames),
runOptions: nil
)
var decodedTokenIds: [Int] = []
let startTokenId = 0
let endTokenId = 2
let maxDecoderLength: Int = 100
var decoderInputIds: [Int] = [startTokenId]
let vocabSize = 15000
let decoderOutputNames = try self.decoderSession.outputNames()
for step in 0..<maxDecoderLength {
let decoderInputIdsTensor = try ORTValue(
tensorData: NSMutableData(data: Data(bytes: decoderInputIds, count: decoderInputIds.count * MemoryLayout<Int64>.stride)),
elementType: .int64,
shape: [1, NSNumber(value: decoderInputIds.count)]
)
let decoderInputs: [String: ORTValue] = [
"input_ids": decoderInputIdsTensor,
"encoder_hidden_states": encoderOutputs["last_hidden_state"]!
]
let decoderOutputs: [String: ORTValue] = try self.decoderSession.run(withInputs: decoderInputs, outputNames: Set(decoderOutputNames), runOptions: nil)
let logitsTensor = decoderOutputs["logits"]!
let logitsData = try logitsTensor.tensorData() as Data
let logits = logitsData.withUnsafeBytes {
Array(UnsafeBufferPointer<Float>(
start: $0.baseAddress!.assumingMemoryBound(to: Float.self),
count: logitsData.count / MemoryLayout<Float>.stride
))
}
let sequenceLength = decoderInputIds.count
let startIndex = (sequenceLength - 1) * vocabSize
let endIndex = startIndex + vocabSize
let lastTokenLogits = Array(logits[startIndex..<endIndex])
let nextTokenId = lastTokenLogits.enumerated().max(by: { $0.element < $1.element})?.offset ?? 9 // TODO: Should I track if this fails
if nextTokenId == endTokenId {
break
}
decodedTokenIds.append(nextTokenId)
decoderInputIds.append(nextTokenId)
}
if rawString {
return tokenizer.decode(tokenIds: decodedTokenIds)
}
return toKatex(formula: tokenizer.decode(tokenIds: decodedTokenIds))
}
throw ModelError.imageError
}
}
CoreML's Version
The above class can be modified to use CoreML instead.
import Foundation
import CoreML
import AppKit
func argmax(_ multiArray: MLMultiArray) -> Int? {
guard multiArray.dataType == .float32 else {
print("MLMultiArray is not of type Float32.")
return nil
}
var maxIndex: Int? = nil
var maxValue: Float = -Float.infinity
for i in 0..<multiArray.count {
let value = multiArray[i].floatValue
if value > maxValue {
maxValue = value
maxIndex = i
}
}
return maxIndex
}
public struct TexTellerCoreMLModel {
private let encoderModel: encoder
private let decoderModel: decoder
private let tokenizer: RobertaTokenizerFast
public init() throws {
self.encoderModel = try encoder(configuration: .init())
self.decoderModel = try decoder(configuration: .init())
self.tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer")
}
public func texIt(_ image: NSImage, rawString: Bool = false) throws -> String {
let transformedImage = inferenceTransform(images: [image])
if let firstTransformedImage = transformedImage.first {
let pixelValues = ciImageToFloatArray(firstTransformedImage, size: CGSize(width: FIXED_IMG_SIZE, height: FIXED_IMG_SIZE))
guard let multiArray = try? MLMultiArray(shape: [1,1,NSNumber(value: FIXED_IMG_SIZE), NSNumber(value: FIXED_IMG_SIZE)], dataType: .float32) else {
throw ModelError.imageError
}
for i in 0..<pixelValues.count {
multiArray[i] = NSNumber(value: pixelValues[i])
}
let prediction = try self.encoderModel.prediction(pixel_values: multiArray)
var decodedTokenIds: [Int] = []
let startTokenId = 0
let endTokenId = 2
let maxDecoderLength: Int = 100
var decoderInputIds: [Int] = [startTokenId]
let vocabSize = 15000
guard var tokenIdsArray = try? MLMultiArray(shape: [1,100], dataType: .float32) else {
throw ModelError.imageError
}
for i in 0..<100 {
tokenIdsArray[i] = 1
}
tokenIdsArray[0] = 0
var count = 1
func argmax(_ multiArray: MLMultiArray, vocabSize: Int) -> Int? {
var maxIndex: Int = 0
var maxValue: Float = -Float.infinity
for i in 0..<vocabSize {
let value = Float(truncating: multiArray[i])
if value > maxValue {
maxValue = value
maxIndex = i
}
}
return maxIndex
}
for i in 0..<32 {
print("my input is \(tokenIdsArray)")
let owo = try self.decoderModel.prediction(input_ids: tokenIdsArray, last_hidden_state: prediction.hidden_states)
print(owo.logits.shape)
print("got something")
// lastTokenLogits.enumerated().max(by: { $0.element < $1.element})?.offset ?? 9
if let nextToken = argmax(owo.logits, vocabSize: vocabSize) {
print("Next token index is \(nextToken)")
if nextToken == endTokenId {
print("Found eos token")
break
}
tokenIdsArray[i+1] = NSNumber(integerLiteral: nextToken)
decodedTokenIds.append(nextToken)
} else {
print("Failed to find the argmax.")
}
}
if rawString {
return tokenizer.decode(tokenIds: decodedTokenIds)
}
return toKatex(formula: tokenizer.decode(tokenIds: decodedTokenIds))
}
throw ModelError.imageError
}
}
Run 'em
To use the ONNX version:
do {
let mymodel = try TexTellerModel()
if let myimage = loadImage(from: "https://miro.medium.com/v2/resize:fit:1400/1*OReJHtogeA62SmSwzNzgvw.png") {
do {
let latex = try mymodel.texIt(myimage)
} catch {
print("Uh oh")
}
} else {
print("Failed to load the image")
}
} catch {
print("Error :( \(error)")
}