PyTorch 將資料用
torch.utils.data.Dataset
類別包裝起來,定義每一次訓練迭代的資料長相,例如:一張影像和一個標籤、一張影像和多個標籤、一張影像和多個矩形方框的座標與長寬……等,將所有資料打包起來,送進torch.utils.data.DataLoader
類別,定義如何取樣資料,以及使用多少資源來得到一個批次 (batch) 的資料。
常用的資料集官方已整理好,可以逕行呼叫使用,如下筆者以 MS COCO 資料集為例;亦會說明如何依照讀者的需求,建立自定義資料集。
官方支持的常見資料集
官方提供以下常用公開資料集的 torchvision.datasets 類別,按應用條列如下,筆者皆建立資料集主頁面的連結,有更詳細的資料集說明,以及下載連結 (截至 PyTorch 1.6.0 資料集種類)。
- 影像辨識:MNIST(手寫數字)、EMNIST(手寫數字)、QMNIST(手寫數字)、USPS(手寫數字)、KMNIST(手寫日文字)、Fashion-MNIST(衣著)、LSUN(物件、場景)、Imagenet(物件、場景)、CIFAR10(物件)、CIFAR100(物件)、STL10(動物、交通)、SVHN(門牌數字)、CelebA(人臉 ID、屬性、特徵點)
- 物件偵測:MS COCO、VOCDetection
- 影像分割:VOCSegmentation、Cityscapes
- 標題生成:MS COCO、SBU、Flickr8k、Flickr30k
- 動作辨識:Kinetics400(影片)、HMDB51(影片)、UCF101(影片)
- 假影像資料生成:FakeData(可以給定影像大小、資料集大小和類別數)
- 其他:PhotoTour(局部影像)、SBDataset(物件邊緣)
- 通用格式:ImageFolder、DatasetFolder
其中,通用格式的ImageFolder class
是辨識任務經常使用的影像資料集格式,會依照資料夾儲存的影像,建立每張影像歸屬之類別,如下所示。
root/dog/001.png
root/dog/002.png
root/dog/003.pngroot/cat/cat.png
root/cat/cat1.png
root/cat/cat_.png
而DatasetFolder class
則不限定資料集為影像,同樣會依照資料夾儲存的影像,建立每張影像歸屬之類別,如下所示。
root/dog/001.txt
root/dog/002.txt
root/dog/003.txtroot/cat/cat.txt
root/cat/cat1.txt
root/cat/cat_.txt
官方資料集使用範例
PyTorch 提供兩種 MS COCO 資料集,分別為生成影像 caption 的dset.CocoCaptions
,以及物件偵測用的dset.CocoDetection
。首先,先進行 pycocotools
套件安裝。
pip install "git+https://github.com/philferriere/cocoapi.git#egg=pycocotools&subdirectory=PythonAPI"
官方資料集使用方法大同小異,如下例,傳入影像位址root
與標記 json 檔位址annFile
(下載請參照地址),完成初始化動作,即可索引在資料集長度內的每一項物件,因此筆者依照訓練和驗證兩個用途的需求,分別建立資料集。
讓我們再看一例,如下是建立物件偵測的資料集,每一次索引可以拿到一張影像的張量與影像內的物件標記資料。
值得一提的是,PyTorch 提供許多轉換功能。如上兩例,初始化當中的trns.ToTensor()
,即是將影像轉換為張量形式,方便後續訓練過程的操作。筆者常用的轉換功能由下面的例子來說明:
首先一律將影像調整大小至256*256
,之後隨機擷取224*224
大小的影像,轉換為張量形式,最後是數值的歸一化。多個轉換功能使用trns.Compose
串聯在一起,而trns.Compose
具有順序性,因此轉換功能的順序值得注意。
先調整至略大的影像,再隨機擷取至模型輸入之大小,是常用的影像增量技巧之一。
影像大小的配置,與歸一化的數值,皆是依照 PyTorch 提供的預訓練模型進行配置,採224*224
大小的影像進行訓練,而歸一化的均值mean
與標準差std
皆以RGB
的順序來運算。
Tricks & Tips 👻
PIL
讀取的影像為RGB
順序,cv2
讀取的影像為BGR
順序。- PyTorch 的
transform
接口多是對應到PIL
和numpy
,多採用此兩個套件的功能可減少物件轉換的麻煩。
自定義資料集 (Custom Dataset)
繼承自torch.utils.data.Dataset
,一個自定義資料集的框架如下,主要實現__getitem__()
和__len__()
這兩個方法。
如下,筆者以狗狗資料集為例,下載地址。
主要常以資料位址、子資料集的標籤和轉換條件…..等,作為繼承Dataset
類別的自定義資料集的初始條件,再分別定義訓練與驗證的轉換條件傳入訓練集與驗證集。藉由train_transfrom
進行資料增量,提高資料的多樣性;相反地,val_transfrom
則應維持每一次的驗證條件為公正的,因此僅包含調整大小、轉換至張量,以及歸一化。
最後,torch.utils.data.DataLoader
類別定義如何取樣dataset
資料集,是否shuffle
用以打亂資料集的順序,使用多少num_workers
的線程 (thread) 資源來得到一個batch_size
大小的資料。因此,每一次的for
當中,可以得到大小batch_size*3*224*224
的影像資料,大小batch_size
的標記。
更多的自定義資料集,可以參考官方提供的 FaceLandmarksDataset,以及 github 範例。
感謝您的閱讀,如果文章有益請在底下長按拍手
有任何問題歡迎在底下留言或是來信交流wanju.ts@gmail.com