Source code for geowatch.tasks.depth.demo_transform


import torch
from PIL import Image
import numpy as np

try:
    import accimage
except ImportError:
    accimage = None


def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


class Scale(object):
    """Resize a ``PIL.Image`` to the given ``(width, height)`` size."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        image = self.changeScale(image, self.size)
        return image

    def changeScale(self, img, size, interpolation=Image.BILINEAR):
        ow, oh = size
        return img.resize((ow, oh), interpolation)


class CenterCrop(object):
    """Crop the central ``(width, height)`` region out of a ``PIL.Image``."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        image = self.centerCrop(image, self.size)
        return image

    def centerCrop(self, image, size):
        w1, h1 = image.size
        tw, th = size
        if w1 == tw and h1 == th:
            return image
        x1 = int(round((w1 - tw) / 2.))
        y1 = int(round((h1 - th) / 2.))
        image = image.crop((x1, y1, tw + x1, th + y1))
        return image


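# Illustrative usage sketch (not part of the original module): resize a PIL
# image with ``Scale`` and then take the central region with ``CenterCrop``.
# The ``_demo_*`` helper names and the input/target sizes are made-up examples.
def _demo_scale_centercrop():
    img = Image.new('RGB', (640, 480))
    img = Scale((320, 240))(img)       # resize to width=320, height=240
    img = CenterCrop((224, 224))(img)  # crop the central 224x224 window
    assert img.size == (224, 224)

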
class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range
    [0.0, 1.0].
    """

    def __call__(self, image):
        image = self.to_tensor(image)
        return image

    def to_tensor(self, pic):
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            if pic.ndim == 2:
                # add a trailing channel axis so grayscale arrays transpose cleanly
                pic = pic[:, :, None]
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img.float().div(255)

        if accimage is not None and isinstance(pic, accimage.Image):
            nppic = np.zeros(
                [pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img


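# Illustrative usage sketch (not part of the original module): ``ToTensor`` on
# a uint8 HWC numpy array and on the equivalent PIL image; both yield CHW
# float tensors scaled into [0, 1]. The array shape is an arbitrary example.
def _demo_to_tensor():
    arr = np.random.randint(0, 256, size=(4, 5, 3), dtype=np.uint8)
    t = ToTensor()(arr)
    assert t.shape == (3, 4, 5) and t.dtype == torch.float32
    assert float(t.min()) >= 0.0 and float(t.max()) <= 1.0

    # the PIL path produces the same values as the numpy path
    t2 = ToTensor()(Image.fromarray(arr))
    assert torch.allclose(t, t2)

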
class ToNumpy(object):
    """Converts a torch.FloatTensor of shape (C x H x W) to a
    numpy.ndarray (H x W x C)
    """

    def __call__(self, image):
        return self.to_numpy(image)

    def to_numpy(self, image):
        image = image.cpu().data.numpy()
        return image.transpose((1, 2, 0))


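# Illustrative sketch (not part of the original module): ``ToNumpy`` is the
# layout inverse of ``ToTensor``, turning a CHW tensor back into an HWC array.
def _demo_to_numpy():
    t = torch.rand(3, 4, 5)
    arr = ToNumpy()(t)
    assert arr.shape == (4, 5, 3)
    assert np.allclose(arr, t.numpy().transpose((1, 2, 0)))

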
class Normalize(object):
    """Normalize a (C x H x W) tensor in place with per-channel mean and std."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image = self.normalize(image, self.mean, self.std)
        return image

    def normalize(self, tensor, mean, std):
        for t, m, s in zip(tensor, mean, std):
            t.sub_(m).div_(s)
        return tensor


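# Illustrative sketch (not part of the original module): ``Normalize`` subtracts
# a per-channel mean and divides by a per-channel std, modifying the tensor in
# place. The statistics below are arbitrary example values.
def _demo_normalize():
    t = torch.full((3, 4, 5), 0.5)
    out = Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25])(t)
    assert out is t  # normalization happens in place
    assert torch.allclose(out, torch.zeros(3, 4, 5))

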
class Lighting(object):
    """PCA-based lighting noise (AlexNet-style color jitter) on a CHW tensor."""

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, image):
        if self.alphastd == 0:
            return image
        alpha = image.new().resize_(3).normal_(0, self.alphastd)
        rgb = self.eigvec.type_as(image).clone()\
            .mul(alpha.view(1, 3).expand(3, 3))\
            .mul(self.eigval.view(1, 3).expand(3, 3))\
            .sum(1).squeeze()
        image = image.add(rgb.view(3, 1, 1).expand_as(image))
        return image


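# Illustrative end-to-end sketch (not part of the original module): chaining the
# transforms above into a simple preprocessing pipeline. The sizes, lighting
# eigenvalues/eigenvectors, and normalization statistics are placeholder values,
# not ones taken from this project.
def _demo_pipeline():
    img = Image.new('RGB', (640, 480))
    transforms = [
        Scale((256, 256)),
        CenterCrop((224, 224)),
        ToTensor(),
        Lighting(alphastd=0.1,
                 eigval=torch.tensor([0.2175, 0.0188, 0.0045]),
                 eigvec=torch.tensor([[-0.5675, 0.7192, 0.4009],
                                      [-0.5808, -0.0045, -0.8140],
                                      [-0.5836, -0.6948, 0.4203]])),
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25]),
    ]
    for tf in transforms:
        img = tf(img)
    assert img.shape == (3, 224, 224)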