
Bonus: PyTorch dataset class


Write PyTorch dataset class

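In PyTorch, the Dataset class is an abstraction over a collection of samples: it tells PyTorch how many samples there are and how to load any one of them by index, so data can be loaded one item at a time instead of all at once. To use your own data, define a class that inherits from torch.utils.data.Dataset and implements __len__ and __getitem__.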
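As a warm-up, here is a minimal sketch of a map-style dataset. `SquaresDataset` is a hypothetical toy class, not part of Labelme or PyTorch; it only shows the two methods a subclass provides and how `torch.utils.data.DataLoader` consumes them.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of items; the DataLoader uses this to know when an epoch ends
        return self.n

    def __getitem__(self, index):
        # Compute the item for this index (a real dataset would load files here)
        x = torch.tensor(float(index))
        return x, x**2


dataset = SquaresDataset(5)
print(len(dataset))  # 5

# batch_size=2 over 5 items gives 3 batches of sizes 2, 2, 1
loader = DataLoader(dataset, batch_size=2, shuffle=False)
for xs, ys in loader:
    print(xs.tolist(), ys.tolist())
```

The DataLoader never looks inside the dataset; it only calls `len(dataset)` and `dataset[i]`, then collates the returned items into batches.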


Let's write one for the amazon-picking-challenge-2016 dataset we created in Export dataset.


Here is the base structure:

```python
...
import glob
import os

from torch.utils import data


ROOT_DIR = "./amazon-picking-challenge-2016"


class AmazonPickingChallenge2016Dataset(data.Dataset):
    def __init__(self):
        # Find item ids in the dataset (e.g., 1466803322395175933):
        # each subdirectory of ROOT_DIR is one item
        item_ids = []
        for filename in os.listdir(ROOT_DIR):
            filepath = os.path.join(ROOT_DIR, filename)
            if not os.path.isdir(filepath):
                continue
            item_id = filename
            item_ids.append(item_id)
        self.item_ids = sorted(item_ids)

        # Find label names in the dataset from the mask file names
        label_names = set()
        for item_id in item_ids:
            mask_paths = glob.glob(os.path.join(ROOT_DIR, item_id, "mask_*.jpg"))
            for mask_path in mask_paths:
                stem = os.path.splitext(os.path.basename(mask_path))[0]
                # Drop the first two "_"-separated fields; the rest is the label name
                label_name = "_".join(stem.split("_")[2:])
                # Skip the "_container_" mask; it is not collected as an object class
                if label_name == "_container_":
                    continue
                label_names.add(label_name)
        self.label_names = sorted(label_names)

    def __len__(self):
        # Return the size of the dataset
        return len(self.item_ids)

    def __getitem__(self, index):
        item_id = self.item_ids[index]
        # Read image, bounding boxes, labels, and masks from item_id
        ...

        return image, bboxes, labels, masks
```
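The label-name parsing in __init__ is the least obvious part of the code above. The sketch below pulls it into a standalone function, assuming mask files are named `mask_<index>_<label>.jpg`, which is what the `[2:]` slice implies. The function name and the example file names are hypothetical, chosen only to illustrate the parsing.

```python
import os


def label_name_from_mask_path(mask_path):
    """Hypothetical helper mirroring the parsing in __init__.

    Drops the first two "_"-separated fields ("mask" and the index) and
    re-joins the rest, so label names that themselves contain "_" survive.
    """
    stem = os.path.splitext(os.path.basename(mask_path))[0]
    return "_".join(stem.split("_")[2:])


# Hypothetical file names, assumed to follow mask_<index>_<label>.jpg
examples = [
    "1466803322395175933/mask_00_crayola_24_ct.jpg",
    "1466803322395175933/mask_01__container_.jpg",
]
for path in examples:
    print(path, "->", label_name_from_mask_path(path))
```

The second example shows why the check compares against `"_container_"` with surrounding underscores: splitting `mask_01__container_` on `"_"` yields empty fields around `container`, and re-joining them restores those underscores.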
  • __init__(self): The constructor initializes the dataset. Here it sets two instance attributes: self.item_ids and self.label_names.
    • self.item_ids: The sorted list of item IDs, one per entry in the dataset (here, one per subdirectory of ROOT_DIR). Depending on the task and model architecture, an item typically corresponds to one image.
    • self.label_names: A sorted list of unique label names (strings). Because models usually represent object labels as integers, this list is needed to convert between label names and integer IDs.
  • __len__(self): Returns the size of the dataset. It determines how many items make up one epoch when PyTorch iterates over the dataset; with a DataLoader, the number of batches per epoch follows from this value and the batch size.
  • __getitem__(self, index): Returns the data item at the given index. Data loading typically happens here, so this is where we implement image and mask loading.
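To make the roles of self.label_names and __len__ concrete, here is a small sketch using only the standard library. The label names are hypothetical, and using the position in the sorted list as the integer ID is one common convention, assumed here rather than taken from the rest of the guide.

```python
import math

# Hypothetical label names, sorted as self.label_names would hold them
label_names = sorted({"kleenex_tissue_box", "crayola_24_ct", "expo_dry_erase_board_eraser"})

# Name -> integer ID: the position of the name in the sorted list
label_name_to_id = {name: i for i, name in enumerate(label_names)}


def id_to_label_name(label_id):
    # Integer ID -> name: index back into the sorted list
    return label_names[label_id]


print(label_names)
print(label_name_to_id["crayola_24_ct"])  # 0
print(id_to_label_name(2))  # kleenex_tissue_box

# __len__ sets the epoch size. A DataLoader yields ceil(len / batch_size)
# batches per epoch, or floor(len / batch_size) when drop_last=True.
num_items, batch_size = 10, 4
print(math.ceil(num_items / batch_size))  # 3
print(num_items // batch_size)  # 2
```

Sorting makes the mapping deterministic across runs regardless of directory listing order. Some models reserve ID 0 for background (torchvision's detection models do, for example); in that case, shift every ID by one.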

