| from enum import Enum |
| from PIL import Image |
|
|
|
|
def unravel_index(index, shape):
    """Map a flat row-major ``index`` onto per-dimension coordinates of ``shape``.

    The last dimension varies fastest. Each coordinate is produced with ``%``
    and ``//`` on ``index``, so numpy/torch scalars keep their own type.

    Args:
        index: flat position (int or an integer scalar supporting ``%``/``//``).
        shape: sequence of dimension sizes.

    Returns:
        Tuple with one coordinate per entry of ``shape`` (empty for an empty shape).
    """
    coords = ()
    remaining = index
    for size in reversed(shape):
        # Prepend so the fastest-varying (last) dimension ends up last.
        coords = (remaining % size,) + coords
        # Rebind rather than ``//=`` so a caller-owned tensor is never mutated in place.
        remaining = remaining // size
    return coords
|
|
|
|
class ExplicitEnum(Enum):
    """Enum base whose failed lookups name every accepted value.

    Subclasses also gain ``options()`` for listing the accepted values.
    """

    @classmethod
    def options(cls):
        """Return the values of all members as a list."""
        return [*cls._value2member_map_]

    @classmethod
    def _missing_(cls, value):
        # Enum invokes this hook when ``cls(value)`` matches no member; raise a
        # ValueError whose message spells out the valid values.
        accepted = [*cls._value2member_map_]
        message = f"{value} is not a valid {cls.__name__}, please select one of {accepted}"
        raise ValueError(message)
|
|
|
|
class InferenceMethod(ExplicitEnum):
    """All the implemented inference methods.

    FIRST/SECOND/LAST pick one page and GRID tiles all pages into one image
    (see ``get_page_scope``). MAX_CONFIDENCE/SOFT_VOTING/HARD_VOTING keep every
    page and combine per-page logits (see ``apply_decision_strategy``).
    """

    # Single-page selection.
    FIRST = "first"
    SECOND = "second"
    LAST = "last"

    # All pages tiled into one image.
    GRID = "grid"

    # Per-page predictions aggregated afterwards.
    MAX_CONFIDENCE = "max_confidence"
    SOFT_VOTING = "soft_voting"
    HARD_VOTING = "hard_voting"

    @property
    def scope(self):
        """Return how this method consumes pages: "sample", "sample-grid" or "iter"."""
        if self in (InferenceMethod.FIRST, InferenceMethod.SECOND, InferenceMethod.LAST):
            return "sample"
        if self is InferenceMethod.GRID:
            return "sample-grid"
        return "iter"

    def get_page_scope(self, pages):
        """Select the model input(s) for this method from ``pages``.

        Args:
            pages: indexable sequence of page images (presumably PIL images when
                GRID is used — TODO confirm against callers).

        Returns:
            ``pages`` unchanged for "iter"-scope methods. For GRID, a single
            tiled image, or the last page if tiling fails. For FIRST and LAST,
            the first or last page. For SECOND, the second page, or the first
            page if there is only one.
        """
        if self.scope == "iter":
            return pages
        if self is InferenceMethod.GRID:
            try:
                return equal_image_grid(pages)
            except Exception as err:
                # Tiling is best-effort: degrade to the last page, but surface the
                # failure instead of silently swallowing it.
                import warnings

                warnings.warn(
                    f"Could not build page grid ({err!r}); falling back to the last page.",
                    stacklevel=2,
                )
                return pages[-1]
        if self is InferenceMethod.FIRST:
            return pages[0]
        if self is InferenceMethod.SECOND:
            # Single-page documents have no second page; use the only one.
            return pages[1] if len(pages) > 1 else pages[0]
        if self is InferenceMethod.LAST:
            return pages[-1]

    def apply_decision_strategy(self, page_logits):
        """Combine per-page logits into one predicted class index.

        Args:
            page_logits: array/tensor of shape [NUM_PAGES x CLASSES] supporting
                ``argmax``, ``mean``, ``shape`` and ``tolist`` (numpy- or
                torch-like — assumed, confirm against callers).

        Returns:
            The predicted class index, as the scalar type that ``argmax`` or
            indexing yields on ``page_logits``. Returns None for methods that
            are not decision strategies (FIRST/SECOND/LAST/GRID).
        """
        if self is InferenceMethod.MAX_CONFIDENCE:
            # Flat argmax over the whole matrix, then recover (page, class).
            index = page_logits.argmax()
            indices = unravel_index(index, page_logits.shape)
            print(f"The page which is max confident: {indices[0]}")
            return indices[-1]
        if self is InferenceMethod.HARD_VOTING:
            # Each page votes for its own argmax class and the most frequent class
            # wins. Ties go to the tied class whose first vote comes from the
            # earliest page. (The former ``argmax(-1).max()`` returned the
            # highest voted class *index*, not the majority.)
            from collections import Counter

            page_votes = page_logits.argmax(-1)
            vote_list = page_votes.tolist()
            counts = Counter(vote_list)
            winner = max(counts, key=counts.get)
            # Index back into ``page_votes`` so the return type matches the other branches.
            return page_votes[vote_list.index(winner)]
        if self is InferenceMethod.SOFT_VOTING:
            # Average logits across pages, then take the best class.
            return page_logits.mean(0).argmax(-1)
        return None
|
|
|
|
def equal_image_grid(images):
    """Tile ``images`` into a single RGB image laid out on a near-square grid.

    Images with zero width or height are dropped. The rest are resized
    (bicubic) to the width of the narrowest one, keeping their aspect ratio.
    They are then pasted row-major into equally sized cells, where a cell is
    the widest by the tallest resized image.

    Args:
        images: sequence of PIL images.

    Returns:
        A new ``"RGB"`` PIL image containing the tiled images.

    Raises:
        ValueError: if no image has a positive width and height.
    """

    def compute_grid(n, max_cols=6):
        # Return (rows, cols) with rows * cols >= n and cols <= max_cols,
        # starting from a floor(sqrt(n)) square.
        side = int(n**0.5)
        cols = min(side, max_cols)
        rows = side
        if rows * cols >= n:
            return rows, cols
        # One extra column often suffices, but never exceed the cap (previously
        # n > 36 could yield max_cols + 1 columns).
        if cols < max_cols:
            cols += 1
            if rows * cols >= n:
                return rows, cols
        # Otherwise add just enough rows: ceil(n / cols).
        rows = -(-n // cols)
        return rows, cols

    # Filter first so the grid is sized for the images that are actually pasted.
    images = [im for im in images if im.height > 0 and im.width > 0]
    if not images:
        raise ValueError("equal_image_grid needs at least one image with non-empty width and height")

    rows, cols = compute_grid(len(images))

    min_width = min(im.width for im in images)
    # max(1, ...) keeps very wide images from truncating to a zero height,
    # which resize cannot produce.
    images = [
        im.resize((min_width, max(1, int(im.height * min_width / im.width))), resample=Image.BICUBIC)
        for im in images
    ]

    cell_w = max(im.width for im in images)
    cell_h = max(im.height for im in images)
    grid = Image.new("RGB", size=(cols * cell_w, rows * cell_h))

    for i, im in enumerate(images):
        row, col = divmod(i, cols)
        grid.paste(im, box=(col * cell_w, row * cell_h))
    return grid
|
|