Bonus: PyTorch dataset class
Write a PyTorch dataset class
In PyTorch, the Dataset class is a useful abstraction that allows you to work with a large amount of data. To create a class for your own data, you need to define a class that inherits from torch.utils.data.Dataset.
Let's write one for the amazon-picking-challenge-2016 dataset we created in Export dataset.
Here is the base structure:
...
import glob
import os

from torch.utils import data

ROOT_DIR = "./amazon-picking-challenge-2016"


class AmazonPickingChallenge2016Dataset(data.Dataset):
    def __init__(self):
        # Find item ids in the dataset (e.g., 1466803322395175933).
        # Each item is stored as a subdirectory of ROOT_DIR.
        item_ids = []
        for filename in os.listdir(ROOT_DIR):
            filepath = os.path.join(ROOT_DIR, filename)
            if not os.path.isdir(filepath):
                continue
            item_id = filename
            item_ids.append(item_id)
        self.item_ids = sorted(item_ids)

        # Find label names in the dataset from the mask filenames
        label_names = set()
        for item_id in item_ids:
            mask_paths = glob.glob(os.path.join(ROOT_DIR, item_id, "mask_*.jpg"))
            for mask_path in mask_paths:
                stem = os.path.splitext(os.path.basename(mask_path))[0]
                # Drop the first two "_"-separated fields and keep the rest
                label_name = "_".join(stem.split("_")[2:])
                if label_name == "_container_":
                    continue
                label_names.add(label_name)
        self.label_names = sorted(label_names)

    def __len__(self):
        # Return the size of the dataset
        return len(self.item_ids)

    def __getitem__(self, index):
        item_id = self.item_ids[index]
        # Read image, bounding boxes, labels, and masks from item_id
        ...
        return image, bboxes, labels, masks
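To make the label-name parsing in the constructor easier to follow, here is a small standalone sketch of that one step. The helper name label_name_from_mask_path and the example paths are hypothetical, made up for illustration. The sketch assumes mask files are named mask_<index>_<label>.jpg, which is what the parsing code implies but which the guide does not state explicitly.

```python
import os


def label_name_from_mask_path(mask_path):
    """Extract the label name from a mask filename (hypothetical helper).

    Assumes the stem looks like "mask_<index>_<label>", where <label>
    may itself contain underscores.
    """
    stem = os.path.splitext(os.path.basename(mask_path))[0]
    # Drop the first two "_"-separated fields ("mask" and the index)
    return "_".join(stem.split("_")[2:])


# Made-up example paths, not taken from the real dataset
example_paths = [
    "./amazon-picking-challenge-2016/example_item/mask_0_red_box.jpg",
    "./amazon-picking-challenge-2016/example_item/mask_1__container_.jpg",
]

parsed = [label_name_from_mask_path(p) for p in example_paths]
print(parsed)

# The constructor skips the special "_container_" label
kept = [name for name in parsed if name != "_container_"]
print(kept)
```

Because the label is rebuilt with "_".join, labels that contain underscores survive intact, and a label with leading and trailing underscores, such as _container_, comes back unchanged.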
__init__(self): The constructor, which initializes the dataset. It sets two instance attributes, self.item_ids and self.label_names.
    self.item_ids: The sorted list of item IDs, one per data entry in the dataset. Depending on the task and model architecture, an entry typically corresponds to one image.
    self.label_names: The sorted list of unique label names (strings). Since the model usually represents object labels as integers, this list is needed to convert between names and integer IDs.
__len__(self): Returns the size of the dataset. This value determines how many items are visited in one epoch when PyTorch iterates over the dataset.
__getitem__(self, index): Returns the data item for the given index. Data loading typically happens in this method, so this is where we implement image and mask loading.
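As a rough illustration of how self.label_names can be used for the name-to-integer conversion, the sketch below uses the position of each name in the sorted list as its integer ID. The label names and the helper functions encode_labels and decode_labels are hypothetical. The guide may use a different convention, for example reserving 0 for the background.

```python
# Hypothetical label names; in the dataset class these would come from
# dataset.label_names, which is already sorted and unique.
label_names = sorted({"red_box", "blue_cup", "green_book"})

# Assumption: the index in the sorted list serves as the integer class ID
label_name_to_id = {name: i for i, name in enumerate(label_names)}


def encode_labels(names):
    """Convert label names (strings) to integer IDs."""
    return [label_name_to_id[name] for name in names]


def decode_labels(ids):
    """Convert integer IDs back to label names."""
    return [label_names[i] for i in ids]


ids = encode_labels(["red_box", "blue_cup"])
print(ids)                 # [2, 0]
print(decode_labels(ids))  # ['red_box', 'blue_cup']
```

Sorting the names before assigning IDs keeps the mapping stable between runs, as long as the set of labels on disk does not change.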