| import torch |
| import numpy as np |
| |
| import os |
| |
| from skimage import io |
| import cv2 |
# RGB -> XYZ colour matrix (meaning taken from the name; it is not referenced
# anywhere in the visible part of this file).
xyz_from_rgb = np.array(
    [
        [0.412453, 0.357580, 0.180423],
        [0.212671, 0.715160, 0.072169],
        [0.019334, 0.119193, 0.950227],
    ]
)

# XYZ -> RGB colour matrix. Numerically this is the TRANSPOSE of
# inv(xyz_from_rgb): tensor_lab2rgb right-multiplies pixels laid out as rows,
# i.e. rgb_rows = xyz_rows @ rgb_from_xyz.
rgb_from_xyz = np.array(
    [
        [3.24048134, -0.96925495, 0.05564664],
        [-1.53715152, 1.87599, -0.20404134],
        [-0.49853633, 0.04155593, 1.05731107],
    ]
)
|
|
|
|
def tensor_lab2rgb(input):
    """Convert a batch of Lab images to RGB, clipped to the range [0, 1].

    Args:
        input: Tensor laid out as (n, 3, h, w); the three channels are read
            as L, a and b in that order.

    Returns:
        A new tensor laid out as (n, 3, h, w) holding the R, G, B channels.
        ``input`` itself is not modified.
    """
    # Channels-last view (n, h, w, 3) so that every pixel becomes one row of
    # the matrix product further down.
    lab = input.permute(0, 2, 3, 1)
    lightness, chan_a, chan_b = lab[..., 0:1], lab[..., 1:2], lab[..., 2:]

    # Lab -> intermediate f(X), f(Y), f(Z) values.
    fy = (lightness + 16.0) / 116.0
    fx = fy + (chan_a / 500.0)
    fz = fy - (chan_b / 200.0)
    fz = fz.masked_fill(fz < 0, 0)  # negative f(Z) is floored at zero
    f_xyz = torch.cat((fx, fy, fz), dim=3)

    # Piecewise inverse: cube above the threshold, linear segment below it.
    # Masked assignment keeps each formula away from values it does not apply to.
    cubic = f_xyz > 0.2068966
    xyz = f_xyz.clone()
    xyz[cubic] = f_xyz[cubic] ** 3.0
    xyz[~cubic] = (f_xyz[~cubic] - 16.0 / 116.0) / 7.787

    # Scale X and Z (Y is left as is); the factors look like the D65 reference
    # white -- TODO confirm against the Lab encoder used upstream.
    xyz[..., 0] = xyz[..., 0] * 0.95047
    xyz[..., 2] = xyz[..., 2] * 1.08883

    # XYZ -> RGB as (pixels, 3) @ (3, 3), then back to channels-first.
    to_rgb = torch.from_numpy(rgb_from_xyz).type_as(f_xyz)
    flat_rgb = torch.mm(xyz.view(-1, 3), to_rgb)
    linear = flat_rgb.view(input.size(0), input.size(2), input.size(3), 3).permute(0, 3, 1, 2)

    # Piecewise transfer curve: power law above the threshold, linear below.
    bright = linear > 0.0031308
    rgb = linear.clone()
    rgb[bright] = 1.055 * torch.pow(linear[bright], 1 / 2.4) - 0.055
    rgb[~bright] = linear[~bright] * 12.92

    # Clip out-of-range values into [0, 1].
    rgb[rgb < 0] = 0
    rgb[rgb > 1] = 1
    return rgb
|
|
def get_files(img_dir):
    """Return ``(imgs, masks, xmls)`` path lists for everything under ``img_dir``.

    Thin convenience wrapper: the directory walk and the extension-based
    classification are done by ``list_files``.
    """
    return list_files(img_dir)
|
|
|
|
def list_files(in_path):
    """Recursively collect files under ``in_path``, grouped by extension.

    Extensions are compared case-insensitively:
      * ``.jpg``, ``.jpeg``, ``.gif``, ``.png``, ``.pgm`` -> image files
      * ``.bmp``                                          -> mask files
      * ``.xml``, ``.gt``, ``.txt``                       -> ground-truth files
    Every other file (``.zip`` included) is ignored.

    Args:
        in_path: Root directory handed to ``os.walk``.

    Returns:
        Tuple ``(img_files, mask_files, gt_files)`` of path lists, each path
        being the walked directory joined with the file name. Order follows
        ``os.walk``.
    """
    images, masks, ground_truths = [], [], []

    # Extension -> destination list; anything not listed here is skipped.
    bucket_for = {
        '.jpg': images,
        '.jpeg': images,
        '.gif': images,
        '.png': images,
        '.pgm': images,
        '.bmp': masks,
        '.xml': ground_truths,
        '.gt': ground_truths,
        '.txt': ground_truths,
    }

    for folder, _subdirs, names in os.walk(in_path):
        for name in names:
            # str.lower (rather than .lower()) keeps the original's TypeError
            # when a bytes path is walked.
            suffix = str.lower(os.path.splitext(name)[1])
            bucket = bucket_for.get(suffix)
            if bucket is not None:
                bucket.append(os.path.join(folder, name))

    return images, masks, ground_truths
|
|
|
|
def load_image(img_file):
    """Read an image with ``skimage.io.imread`` and return it as 3-channel data.

    Steps, in order:
      1. If the first axis of the loaded array has length 2, keep only its
         first element.
         NOTE(review): presumably this targets two-frame/multi-page results --
         confirm; an image that is exactly 2 pixels high matches as well.
      2. A 2-D (single-channel) array is expanded with ``cv2.COLOR_GRAY2RGB``.
      3. If the third axis has length 4, only its first three entries are kept.

    Args:
        img_file: Image location passed straight to ``io.imread``.

    Returns:
        A fresh ``numpy`` array copy of the processed image.
    """
    image = io.imread(img_file)

    if image.shape[0] == 2:
        image = image[0]

    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    if image.shape[2] == 4:
        image = image[:, :, :3]

    # np.array copies, so the caller never receives a view of a larger buffer.
    return np.array(image)