PyTorch - загрузка данных

PyTorch включает пакет под названием torchvision, который используется для загрузки и подготовки набора данных. Он включает в себя две основные функции, а именно Dataset и DataLoader, которые помогают в преобразовании и загрузке набора данных.

Набор данных

Набор данных используется для чтения и преобразования точки данных из данного набора данных. Базовый синтаксис для реализации упомянут ниже -

trainset = torchvision.datasets.CIFAR10(root = './data', train = True,
   download = True, transform = transform)

DataLoader используется для перемешивания и пакетной обработки данных. Его можно использовать для загрузки данных параллельно многопроцессорным рабочим.

trainloader = torch.utils.data.DataLoader(trainset, batch_size = 4,
   shuffle = True, num_workers = 2)

Пример: загрузка файла CSV

Мы используем пакет Python Panda для загрузки файла csv. Исходный файл имеет следующий формат: (название изображения, 68 ориентиров - каждый ориентир имеет координаты оси, y).

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)