Skip to content

Instantly share code, notes, and snippets.

@BloodAxe
Created August 19, 2020 08:53
Show Gist options
  • Save BloodAxe/70265dcab73c9a078a0928223c244c39 to your computer and use it in GitHub Desktop.
Save BloodAxe/70265dcab73c9a078a0928223c244c39 to your computer and use it in GitHub Desktop.
# https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/dataset.py#L373
class TrainingValidationDataset(Dataset):
    """
    Dataset that reads images from disk and packs each one, together with its per-image
    metadata, into a sample dictionary.

    Every sample holds the image file basename and its quality factor. When provided, it also
    holds the payload bits, and the target label plus a binary "modified" flag. Extra entries are
    produced by ``compute_features`` followed by ``transform``. Only the keys listed in
    ``features`` are copied into the sample, converted with ``tensor_from_rgb_image``.
    """

    def __init__(
        self,
        images: Union[List, np.ndarray],
        targets: Optional[Union[List, np.ndarray]],
        quality: Union[List, np.ndarray],
        bits: Optional[Union[List, np.ndarray]],
        transform: Union[A.Compose, A.BasicTransform],
        features: List[str],
    ):
        """
        :param images: Image file paths; each one is read with ``cv2.imread``.
        :param targets: Optional per-image integer labels, same length as ``images``.
            A label greater than zero sets the modification flag in the sample.
        :param quality: Per-image quality factor (qf); stored as ``int`` in each sample.
        :param bits: Optional per-image payload bits; stored as a float32 tensor.
        :param transform: Transform called as ``transform(**data)``; must return a dict.
        :param features: Feature names passed to ``compute_features``; also selects which
            entries of the transformed data are included in each sample.
        :raises ValueError: If ``targets`` is given and its length differs from ``images``.
        """
        if targets is not None:
            if len(images) != len(targets):
                raise ValueError(f"Size of images and targets does not match: {len(images)} {len(targets)}")
        self.images = images
        self.targets = targets
        self.transform = transform
        self.features = features
        self.quality = quality
        self.bits = bits

    def __len__(self):
        return len(self.images)

    def __repr__(self):
        # targets is Optional (e.g. unlabeled data); np.bincount(None) would raise.
        targets_hist = np.bincount(self.targets) if self.targets is not None else None
        return (
            f"TrainingValidationDataset(len={len(self)}, targets_hist={targets_hist}, "
            f"qf={np.bincount(self.quality)}, features={self.features})"
        )

    def __getitem__(self, index):
        """
        Load the image at ``index`` and build its sample dictionary.

        :raises FileNotFoundError: If ``cv2.imread`` cannot read the image (returns ``None``).
        """
        image_fname = self.images[index]
        try:
            image = cv2.imread(image_fname)
            if image is None:
                # cv2.imread signals a missing/unreadable file by returning None, not raising.
                raise FileNotFoundError(image_fname)
        except Exception as e:
            print("Cannot read image ", image_fname, "at index", index)
            print(e)
            # Re-raise: continuing would feed None (or an unbound name) into
            # compute_features / transform and fail later with an unrelated error.
            raise

        qf = self.quality[index]

        data = {"image": image}
        data.update(compute_features(image, image_fname, self.features))
        data = self.transform(**data)

        sample = {INPUT_IMAGE_ID_KEY: os.path.basename(image_fname), INPUT_IMAGE_QF_KEY: int(qf)}

        if self.bits is not None:
            sample[INPUT_TRUE_PAYLOAD_BITS] = torch.tensor(self.bits[index], dtype=torch.float32)

        if self.targets is not None:
            target = int(self.targets[index])
            sample[INPUT_TRUE_MODIFICATION_TYPE] = target
            # Binary flag: any non-zero label counts as "modified".
            sample[INPUT_TRUE_MODIFICATION_FLAG] = torch.tensor([target > 0]).float()

        # Keep only the requested feature keys from the transformed data.
        for key, value in data.items():
            if key in self.features:
                sample[key] = tensor_from_rgb_image(value)

        return sample
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment