Summary
In this post, I cover an approach to a document AI problem using a task flow implemented in Apache Airflow. The particular problem is the de-duplication of invoices, which comes up in the payment provider space. I use Azure AI Document Intelligence for OCR, Azure OpenAI for vector embeddings, and Redis Enterprise for vector search.
Architecture
Code Snippets
File Sensor DAG
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@task.sensor(task_id="check_inbox", mode="reschedule", timeout=10, executor_config=executor_config_volume_mount)
def check_inbox() -> PokeReturnValue:
    """File sensor for the invoices inbox.

    If files are detected in the inbox, a cascade of processing tasks is
    triggered: OCR, Embed, Dedup.

    :return: PokeReturnValue with is_done=True and the inbox file paths as the
        XCom value when files are present; otherwise is_done=False so the
        sensor reschedules.
    """
    storage_var = Variable.get("storage", deserialize_json=True, default_var=None)
    # Workaround for an apparent Airflow bug: Variable.get can return a raw
    # JSON string even with deserialize_json=True.  The original check
    # `type(x) != 'dict'` compared a type object to a string and was always
    # True; isinstance handles both cases correctly.
    if not isinstance(storage_var, dict):
        storage_var = json.loads(storage_var)
    inbox_path = storage_var['inbox']
    inbox_files = [os.path.join(inbox_path, file) for file in os.listdir(inbox_path)]
    logging.info(f'Number of files to be processed: {len(inbox_files)}')
    if inbox_files:
        return PokeReturnValue(is_done=True, xcom_value=inbox_files)
    return PokeReturnValue(is_done=False)
OCR DAG
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@task(task_id='parse_invoice', executor_config=executor_config_volume_mount)
def parse_invoice(inbox_file: str) -> dict:
    """Run OCR on a single invoice file from the inbox.

    The OCR result is a space-delimited string built from a configurable
    number of invoice fields.

    :param inbox_file: path to the invoice file to process
    :return: dict containing the OCR output plus the source file path
    """
    from invoice.lib.ocr import ocr

    parsed = ocr(inbox_file)
    parsed['file'] = inbox_file
    logging.info(f'Invoice: {pprint.pformat(parsed)}')
    return parsed
OCR Client (Azure AI Doc Intelligence)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@retry(wait=wait_random_exponential(min=10, max=60), stop=stop_after_attempt(3))
def ocr(filepath: str) -> dict:
    """Execute Azure Form Recognizer (Document Intelligence) OCR on an invoice.

    :param filepath: path to the invoice file to analyze
    :return: dict that includes a text string of space-separated values from
        the configured invoice fields
    """
    formrec_var = Variable.get("formrec", deserialize_json=True, default_var=None)
    # Workaround for an apparent Airflow bug: Variable.get can return a raw
    # JSON string even with deserialize_json=True.  The original check
    # `type(x) != 'dict'` was always True; isinstance is the correct test.
    if not isinstance(formrec_var, dict):
        formrec_var = json.loads(formrec_var)
    key = formrec_var["key"]
    endpoint = formrec_var["endpoint"]
    vector_fields = formrec_var["fields"]
    client = DocumentAnalysisClient(endpoint=endpoint, credential=AzureKeyCredential(key))
    with open(filepath, "rb") as f:
        poller = client.begin_analyze_document("prebuilt-invoice", document=f, locale="en-US")
    # Block until the service finishes analysis; take the first document.
    invoice = poller.result().documents[0]
    return stringify(invoice, vector_fields)
Embedding DAG
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@task(task_id='embed_invoice')
def embed_invoice(invoice: dict) -> dict:
    """Add an OpenAI embedding to an invoice dict.

    Accepts an invoice dict that includes a text field ('ocr') of the OCR
    output and adds an OpenAI embedding (array of floats) under 'vector'.

    :param invoice: dict with at least 'ocr' and 'file' keys
    :return: the same dict with a 'vector' key added
    """
    from invoice.lib.embed import get_embedding

    vector = get_embedding(invoice['ocr'])
    invoice['vector'] = vector
    # Log the vector length, not the vector itself (the original logged the
    # full float array, flooding the task log).
    logging.info(f'Invoice: {invoice["file"]}, Vector len: {len(invoice["vector"])}')
    return invoice
Embedding Client (Azure OpenAI)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@retry(wait=wait_random_exponential(min=3, max=100), stop=stop_after_attempt(10))
def get_embedding(text: str) -> list[float]:
    """Fetch an Azure OpenAI embedding for the given text.

    Retries with random exponential backoff to ride out rate limiting.

    :param text: OCR output string to embed
    :return: embedding vector as a list of floats
    """
    response = openai.Embedding.create(
        input=text,
        engine="EmbeddingModel"
    )
    return response['data'][0]['embedding']
Vector Search DAG
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@task(task_id='dedup_invoice', executor_config=executor_config_volume_mount)
def dedup_invoice(invoice: dict) -> None:
    """Send the invoice dict into a Redis VSS lookup to determine its
    disposition - process it, or call it a duplicate.
    """
    from invoice.lib.vss import dedup

    disposition = dedup(invoice)
    logging.info(f'Invoice: {invoice["file"]}, Result: {disposition}')
Vector Search Client (Redis Enterprise)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def dedup(invoice: dict) -> str:
    """Determine the disposition of an invoice via Redis vector similarity search.

    Accepts a dict that includes a vector for a given invoice file.  That
    vector is sent into Redis VSS, restricted to the same customer name.  If
    another invoice is within the configured similarity bound of the input
    invoice, it is disposed as a duplicate and moved to the 'dups' directory.
    Otherwise it is stored in Redis as a net-new invoice and moved to the
    'processed' directory.

    :param invoice: dict with at least 'vector', 'customer_name', and 'file' keys
    :return: 'duplicate' or 'processed'
    """
    re_var = Variable.get("re", deserialize_json=True, default_var=None)
    # Workaround for an apparent Airflow bug: Variable.get can return a raw
    # JSON string even with deserialize_json=True.  The original check
    # `type(x) != 'dict'` was always True; isinstance is the correct test.
    if not isinstance(re_var, dict):
        re_var = json.loads(re_var)
    storage_var = Variable.get("storage", deserialize_json=True, default_var=None)
    if not isinstance(storage_var, dict):  # same Airflow workaround as above
        storage_var = json.loads(storage_var)
    creds = redis.UsernamePasswordCredentialProvider(re_var['user'], re_var['pwd'])
    client = redis.Redis(host=re_var['host'], port=re_var['port'], credential_provider=creds)
    try:
        client.ft(re_var['vector_index']).info()
    except redis.ResponseError:
        # FT.INFO raises ResponseError when the index does not exist yet;
        # create it.  (The original bare `except:` also hid connection errors.)
        idx_def = IndexDefinition(index_type=IndexType.HASH, prefix=[re_var['vector_prefix']])
        schema = [
            TextField('customer_name'),
            VectorField('vector',
                'HNSW',
                { 'TYPE': re_var['vector_type'], 'DIM': re_var['vector_dim'], 'DISTANCE_METRIC': re_var['vector_metric'] }
            )
        ]
        client.ft(re_var['vector_index']).create_index(schema, definition=idx_def)
    vec = np.array(invoice['vector'], dtype=np.float32).tobytes()
    # KNN-1 search restricted to the same customer; 'score' is a vector
    # distance, so similarity = 1 - score.
    q = Query(f'@customer_name:({invoice["customer_name"]}) => [KNN 1 @vector $query_vec AS score]')\
        .return_fields('score')\
        .dialect(2)
    results = client.ft(re_var['vector_index']).search(q, query_params={'query_vec': vec})
    docs = results.docs
    if len(docs) > 0 and 1 - float(docs[0].score) > re_var['vector_similarity_bound']:
        similarity = round(1 - float(docs[0].score), 2)
        shutil.move(invoice['file'], storage_var['dups'])
        logging.info(f'Duplicate invoice:{os.path.basename(invoice["file"])}, Similarity:{similarity}')
        return 'duplicate'
    else:
        similarity = round(1 - float(docs[0].score), 2) if docs else 'N/A'
        client.hset(f'invoice:{uuid.uuid4()}',
            mapping={'customer_name': invoice['customer_name'], 'file': os.path.basename(invoice['file']), 'vector': vec})
        shutil.move(invoice['file'], storage_var['processed'])
        logging.info(f'Processed invoice:{os.path.basename(invoice["file"])}, Similarity:{similarity}')
        return 'processed'
Source
Copyright ©1993-2024 Joey E Whelan, All rights reserved.