File size: 2,373 Bytes
e72f783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# src/cache.py
# LRU cache keyed by image SHA256 hash
# Prevents recomputing WideResNet + CLIP for repeated images
# maxsize=128: holds at most 128 inference results in RAM (~100MB max)

import hashlib
from collections import OrderedDict
from PIL import Image
import io


MAX_CACHE_SIZE = 128


class LRUCache:
    """
    Bounded least-recently-used store built on top of an OrderedDict.

    Intended usage (not enforced by the class itself):
      key   -> SHA256 hex digest of the raw image bytes
      value -> dict of features already computed for that image

    A hand-written cache is used rather than functools.lru_cache so the
    caller chooses the key explicitly (the image hash) instead of the
    function arguments being used as the key.

    Entries are ordered oldest-first: the front of the OrderedDict is the
    eviction candidate, the back is the most recently touched entry.
    Lookup counters (hits / misses) are kept for stats().
    """

    def __init__(self, maxsize=MAX_CACHE_SIZE):
        self.maxsize = maxsize
        self.cache = OrderedDict()
        self.hits = self.misses = 0

    def get(self, key):
        """Return the value stored under key, or None when it is absent.

        A successful lookup counts as a hit and marks the entry as the
        most recently used one; an unsuccessful lookup counts as a miss.
        """
        store = self.cache
        if key in store:
            # Touching an entry pushes it to the back (= freshest).
            store.move_to_end(key)
            self.hits += 1
            return store[key]
        self.misses += 1
        return None

    def set(self, key, value):
        """Store value under key and mark it as the most recently used.

        If the cache then holds more than maxsize entries, the single
        least recently used entry is dropped.
        """
        store = self.cache
        store[key] = value
        # A brand-new key is already last, so this is a no-op for it;
        # an overwritten key keeps its old slot and must be moved back.
        store.move_to_end(key)
        if len(store) > self.maxsize:
            # Front of the OrderedDict = least recently used.
            store.popitem(last=False)

    def stats(self):
        """Return a dict with hit/miss counters, hit rate and occupancy."""
        hits, misses = self.hits, self.misses
        lookups = hits + misses
        if lookups > 0:
            ratio = hits / lookups
        else:
            # No lookups yet: report 0.0 rather than dividing by zero.
            ratio = 0.0
        return {
            "hits": hits,
            "misses": misses,
            "total": lookups,
            "hit_rate": round(ratio, 4),
            "current_size": len(self.cache),
            "max_size": self.maxsize,
        }

    def clear(self):
        """Drop every cached entry and reset both lookup counters."""
        self.hits = self.misses = 0
        self.cache.clear()


def get_image_hash(image_bytes: bytes) -> str:
    """
    Return the SHA256 hex digest of the given raw image bytes.

    Identical byte strings always produce the identical digest, which is
    what makes the digest usable as a cache key: submitting the same
    image bytes again yields the same key and therefore a cache hit.
    Per this module's design notes the digest also serves as the unique
    image ID in the HF Dataset logs.
    """
    digest = hashlib.sha256(image_bytes)
    return digest.hexdigest()


def pil_to_bytes(pil_img: Image.Image) -> bytes:
    """Serialize a PIL image to PNG-encoded bytes so it can be hashed."""
    # Encode into an in-memory buffer; nothing is written to disk.
    with io.BytesIO() as png_buffer:
        pil_img.save(png_buffer, format="PNG")
        return png_buffer.getvalue()


# Global cache instance shared by every module that imports it.
# It is created right here, the first time this module is imported, and
# stays alive for as long as the process does.
# NOTE(review): an earlier comment said it is initialised in api/startup.py;
# presumably that file only imports this name — confirm against startup code.
inference_cache = LRUCache(maxsize=MAX_CACHE_SIZE)