Visual Graph Arena Dataset

Python Code Example

Below is a Python code snippet that demonstrates basic data loading for the dataset. You can specify the desired task using the "--task" argument. Choose between isomorphism/easy, isomorphism/hard, path/hamiltonian, path/shortest, cycle/hamiltonian, and cycle/chordless.


from torch.utils.data import Dataset
from PIL import Image
import os
import argparse
from torchvision import transforms


# ======================== Dataset Class ========================
def make_labels(label_file):
    """Read a ``labels.txt`` file and return zero-based integer class labels.

    Each line of the file holds one label. Three encodings are recognised:

    * ``true`` / ``false`` -> ``1`` / ``0``
    * integers that include ``1`` -> shifted down by 1
    * integers that include ``6`` -> shifted down by 3

    Args:
        label_file: Path to the text file, one label per line.

    Returns:
        A list of ints, one per line, in file order.

    Raises:
        ValueError: If a line is neither a boolean literal nor an integer,
            or if the integer labels contain neither ``1`` nor ``6``.
    """
    # Context manager so the handle is closed even if parsing fails below.
    with open(label_file, 'r') as handle:
        labels = [line.strip() for line in handle]

    # Detect the boolean encoding on either literal, so that a file made up
    # solely of 'false' lines is not mistaken for an integer file.
    if 'true' in labels or 'false' in labels:
        return [int(label == 'true') for label in labels]

    labels = [int(label) for label in labels]
    if 1 in labels:  # Then labels are 1, 2, 3, 4. make it 0, 1, 2, 3
        return [i - 1 for i in labels]
    if 6 in labels:  # Then labels are 3, 4, 5, 6. make it 0, 1, 2, 3
        return [i - 3 for i in labels]

    raise ValueError("Invalid labels")


class ImageDataset(Dataset):
    """Dataset that pairs ``image_<idx>.png`` files with parsed labels.

    Labels come from ``make_labels(label_file)``; the sample at index ``idx``
    is read from ``<img_dir>/image_<idx>.png``.
    """

    def __init__(self, img_dir, label_file, transformer):
        """Store the image directory, parse the labels, keep the transform.

        Args:
            img_dir: Directory containing the ``image_<idx>.png`` files.
            label_file: Path of the label file handed to ``make_labels``.
            transformer: Callable applied to each image; a falsy value
                means the image is returned untransformed.
        """
        self.img_dir = img_dir
        self.labels = make_labels(label_file)
        self.transformer = transformer

    def __len__(self):
        """Return the number of parsed labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        filename = f"image_{idx}.png"
        # Read as single-channel grayscale first ...
        gray = Image.open(os.path.join(self.img_dir, filename)).convert("L")
        # ... then replicate that channel three times to get an RGB image.
        rgb = Image.merge("RGB", (gray, gray, gray))
        target = self.labels[idx]
        if not self.transformer:
            return rgb, target
        return self.transformer(rgb), target


# ======================== Dataset Arguments ========================
# The six tasks named in the --task help text; used to validate the argument.
TASK_CHOICES = (
    'isomorphism/easy',
    'isomorphism/hard',
    'path/hamiltonian',
    'path/shortest',
    'cycle/hamiltonian',
    'cycle/chordless',
)

parser = argparse.ArgumentParser(description="Load a dataset.")
parser.add_argument('--dataset_path', type=str, default='./vision_graph_arena', help='The path of the dataset')
# `choices` lets argparse reject a mistyped task immediately with a clear
# message, instead of failing later when the label file cannot be opened.
parser.add_argument('--task', type=str, default='isomorphism/easy', choices=TASK_CHOICES,
                    help='The task you want to train on \n You can choose between `isomorphism/easy`,'
                         ' `isomorphism/hard`, `path/hamiltonian`, `path/shortest`,'
                         ' `cycle/hamiltonian`, and `cycle/chordless`')
args = parser.parse_args()

# `path/shortest` and `cycle/chordless` have four output classes;
# every other task is binary.
num_output = 4 if args.task in ('path/shortest', 'cycle/chordless') else 2

# Preprocessing pipeline applied to every image (left empty on purpose).
transform = transforms.Compose([
    # TODO: Resize and transform the image as you wish
])

# ======================== Dataset Creation ========================
train_dataset = ImageDataset(
    img_dir=f'{args.dataset_path}/{args.task}/inputs/train',
    label_file=f'{args.dataset_path}/{args.task}/labels/train/labels.txt',
    transformer=transform,
)

# Every test split reads the same label file, whatever the image directory.
_test_label_file = f'{args.dataset_path}/{args.task}/labels/test/labels.txt'

if 'isomorphism' not in args.task:
    # Non-isomorphism tasks have two test image sets, one under
    # `kamada_kawai` and one under `planar`.
    _test_img_dirs = [f'{args.dataset_path}/{args.task}/inputs/test/kamada_kawai',
                      f'{args.dataset_path}/{args.task}/inputs/test/planar']
else:
    # Isomorphism tasks have a single test image set.
    _test_img_dirs = [f'{args.dataset_path}/{args.task}/inputs/test']

test_datasets = [ImageDataset(test_dir, _test_label_file, transform)
                 for test_dir in _test_img_dirs]
        
Python