PyTorch 自定義資料集 (Custom Dataset)

rowan.ts
Oct 7, 2018

--

Photo by Franki Chamaki on Unsplash

PyTorch 將資料用torch.utils.data.Dataset類別包裝起來,定義每一次訓練迭代的資料長相,例如:一張影像和一個標籤、一張影像和多個標籤、一張影像和多個矩形方框的座標與長寬……等,將所有資料打包起來,送進torch.utils.data.DataLoader類別,定義如何取樣資料,以及使用多少資源來得到一個批次 (batch) 的資料。

常用的資料集官方已整理好,可以逕行呼叫使用,如下筆者以 MS COCO 資料集為例;亦會說明如何依照讀者的需求,建立自定義資料集。

官方支持的常見資料集

官方提供以下常用公開資料集的 torchvision.datasets 類別,按應用條列如下,筆者皆建立資料集主頁面的連結,有更詳細的資料集說明,以及下載連結 (截至 PyTorch 1.6.0 資料集種類)。

其中,通用格式的ImageFolder class是辨識任務經常使用的影像資料集格式,會依照資料夾儲存的影像,建立每張影像歸屬之類別,如下所示。

root/dog/001.png
root/dog/002.png
root/dog/003.png
root/cat/cat.png
root/cat/cat1.png
root/cat/cat_.png

DatasetFolder class則不限定資料集為影像,同樣會依照資料夾儲存的影像,建立每張影像歸屬之類別,如下所示。

root/dog/001.txt
root/dog/002.txt
root/dog/003.txt
root/cat/cat.txt
root/cat/cat1.txt
root/cat/cat_.txt

官方資料集使用範例

PyTorch 提供兩種 MS COCO 資料集,分別為生成影像 caption 的dset.CocoCaptions,以及物件偵測用的dset.CocoDetection。首先,先進行 pycocotools套件安裝。

pip install "git+https://github.com/philferriere/cocoapi.git#egg=pycocotools&subdirectory=PythonAPI"

官方資料集使用方法大同小異,如下例,傳入影像位址root與標記 json 檔位址annFile (下載請參照地址),完成初始化動作,即可索引在資料集長度內的每一項物件,因此筆者依照訓練和驗證兩個用途的需求,分別建立資料集。

dset.CocoCaptions 使用範例

讓我們再看一例,如下是建立物件偵測的資料集,每一次索引可以拿到一張影像的張量與影像內的物件標記資料。

dset.CocoDetection 使用範例

值得一提的是,PyTorch 提供許多轉換功能。如上兩例,初始化當中的trns.ToTensor(),即是將影像轉換為張量形式,方便後續訓練過程的操作。筆者常用的轉換功能由下面的例子來說明:

首先一律將影像調整大小至256*256,之後隨機擷取224*224大小的影像,轉換為張量形式,最後是數值的歸一化。多個轉換功能使用trns.Compose串聯在一起,而trns.Compose具有順序性,因此轉換功能的順序值得注意。

先調整至略大的影像,再隨機擷取至模型輸入之大小,是常用的影像增量技巧之一。

影像大小的配置,與歸一化的數值,皆是依照 PyTorch 提供的預訓練模型進行配置,採224*224大小的影像進行訓練,而歸一化的均值mean與標準差std皆以RGB的順序來運算。

transform 使用範例

Tricks & Tips 👻

  1. PIL讀取的影像為RGB順序,cv2讀取的影像為BGR順序。
  2. PyTorch 的transform 接口多是對應到PILnumpy,多採用此兩個套件的功能可減少物件轉換的麻煩。

自定義資料集 (Custom Dataset)

繼承自torch.utils.data.Dataset,一個自定義資料集的框架如下,主要實現__getitem__()__len__()這兩個方法。

PyTorch 資料集類別框架

如下,筆者以狗狗資料集為例,下載地址

主要常以資料位址、子資料集的標籤和轉換條件…..等,作為繼承Dataset類別的自定義資料集的初始條件,再分別定義訓練與驗證的轉換條件傳入訓練集與驗證集。藉由train_transfrom進行資料增量,提高資料的多樣性;相反地,val_transfrom則應維持每一次的驗證條件為公正的,因此僅包含調整大小、轉換至張量,以及歸一化。

最後,torch.utils.data.DataLoader類別定義如何取樣dataset資料集,是否shuffle用以打亂資料集的順序,使用多少num_workers的線程 (thread) 資源來得到一個batch_size大小的資料。因此,每一次的for當中,可以得到大小batch_size*3*224*224的影像資料,大小batch_size的標記。

自定義資料集範例

更多的自定義資料集,可以參考官方提供的 FaceLandmarksDataset,以及 github 範例

感謝您的閱讀,如果文章有益請在底下長按拍手
有任何問題歡迎在底下留言或是來信交流wanju.ts@gmail.com

--

--

No responses yet