一、csv文件中包含图片地址与分类
1、先读取csv文件
1 2
# Load the raw CSV index files: train.csv holds image paths + labels,
# test.csv holds image paths only (no label column).
train_data = pd.read_csv("../data/classify-leaves/train.csv")
test_data = pd.read_csv("../data/classify-leaves/test.csv")
|
2、将分类取出,得到classes
1
# Unique label names, sorted so the class -> id mapping is deterministic.
classes = sorted(set(train_data['label']))
|
3、把classes转成对应的数字
1
# Map each class name to a contiguous integer id 0..n_classes-1.
# BUG FIX: the snippet used `n_classes` without defining it; derive it
# from the `classes` list built above.
n_classes = len(classes)
classes_to_num = dict(zip(classes, range(n_classes)))
|
4、获取train_data
(1)定义一个class
初始化,传入csv文件路径,image公共路径,处理后的长宽,transform
1 2 3 4 5 6 7 8 9 10 11 12 13
def __init__(self, csv_path, img_path, mode, height=224, weight=224,
             valid_ratio=0.2, transform=None):
    """Leaf-classification dataset backed by a CSV index file.

    Args:
        csv_path: path to the CSV listing image file names (and labels).
        img_path: common directory prefix prepended to each image name.
        mode: 'train', 'valid' or 'test' — selects which slice is served.
        height / weight: target resize dimensions in pixels.
            NOTE(review): `weight` is presumably a typo for `width`; kept
            as-is because callers may pass it by keyword.
        valid_ratio: fraction of the training CSV held out for validation.
        transform: torchvision transform applied to each loaded image.
    """
    super(LeaveDataset, self).__init__()
    self.csv_path = csv_path
    self.resize_height = height
    self.resize_weight = weight
    self.transform = transform
    self.img_path = img_path
    self.mode = mode
|
(2)获取data_info
使用pd.read_csv()读取数据,去除表头
1
# Read the whole CSV without header inference; the header text becomes
# row 0, so real samples start at index 1.
self.data_info = pd.read_csv(csv_path, header=None)
|
(3)划分数据集
分析获取到的csv数据
是否有验证集,若没有,则将训练集划分为训练集与验证集,通过分析结构,得到图片路径与label,
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
# One row is the header (header=None kept it as data), so subtract it
# from the usable sample count.
self.data_len = len(self.data_info.index) - 1
self.train_len = int(self.data_len * (1 - valid_ratio))
if mode == 'train':
    # Rows 1 .. train_len-1 (row 0 is the header text).
    # NOTE(review): slice 1:train_len yields train_len-1 samples — likely
    # an off-by-one; confirm whether 1:train_len+1 was intended.
    self.img_arr = np.asarray(self.data_info.iloc[1:self.train_len, 0])
    self.label_arr = np.asarray(self.data_info.iloc[1:self.train_len, 1])
elif mode == 'valid':
    # Remainder of the training CSV becomes the validation split.
    self.img_arr = np.asarray(self.data_info.iloc[self.train_len:, 0])
    self.label_arr = np.asarray(self.data_info.iloc[self.train_len:, 1])
elif mode == 'test':
    # Test CSV has no label column; take every data row.
    self.img_arr = np.asarray(self.data_info.iloc[1:, 0])
self.real_len = len(self.img_arr)
print('Finished reading the {} set of Leaves Dataset ({} samples found)'
      .format(self.mode, self.real_len))
|
(4)通过图片路径,获取图片,返回图片与label
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
def __getitem__(self, item):
    """Return the transformed image (and its numeric label unless in test mode)."""
    single_image_name = self.img_arr[item]
    # img_path is a plain string prefix, so simple concatenation is used here.
    img = Image.open(self.img_path + single_image_name)
    img = self.transform(img)
    if self.mode == 'test':
        # Test split has no labels — return the image alone.
        return img
    else:
        # Convert the string label to its integer id via the module-level map.
        label = self.label_arr[item]
        num_label = classes_to_num[label]
        return (img, num_label)

def __len__(self):
    """Number of samples in the selected split."""
    return self.real_len
|
5、获取test_data
test_data与train_data基本相同,没有label
6、调用traindata、validdata、testData
1 2 3 4
# Build the three splits. Train and valid share the same CSV; the class
# itself carves it up according to `mode` / valid_ratio.
train_dataset = LeaveDataset(csv_path="../data/classify-leaves/train.csv",
                             img_path=Img_PATH, mode='train',
                             transform=train_transform)
valid_dataset = LeaveDataset(csv_path="../data/classify-leaves/train.csv",
                             img_path=Img_PATH, mode='valid',
                             transform=val_test_transform)
test_dataset = LeaveDataset(csv_path="../data/classify-leaves/test.csv",
                            img_path=Img_PATH, mode='test',
                            transform=val_test_transform)
|
7、使用DataLoader分批次加载数据
1 2 3 4 5
# Batch the three splits.
# BUG FIX: valid_loader previously wrapped train_dataset, so validation
# metrics were computed on training data. Also, valid/test loaders should
# not shuffle — evaluation order must be stable (e.g. for submissions).
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=5)
valid_loader = DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=5)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=5)
|
8、训练
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
# Standard supervised training loop: forward, loss, backward, step,
# then accumulate loss and accuracy statistics per epoch.
for epoch in range(num_epoch):
    net.train()
    loss_sum = 0
    loss_correct = 0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = net(inputs)
        loss = loss_func(outputs, labels)
        # Clear stale gradients before backprop, then update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Predicted class = argmax over logits.
        _, pred = torch.max(outputs.data, dim=1)
        correct = pred.eq(labels.data).cpu().sum()
        loss_sum += loss.item()
        loss_correct += correct.item()
        # NOTE(review): `step` is never initialised in this snippet —
        # it must be set to 0 before the loop (or this line removed).
        step += 1
    # Per-epoch averages; accuracy denominator assumes full batches.
    print("train epoch", epoch + 1,
          "train loss is: ", loss_sum * 1.0 / len(train_loader),
          "train correct is: ", loss_correct * 100.0 / len(train_loader) / batch_size)
|
二、若为压缩包,解压为按照分类,将图片存入分类下的文件夹中
1、定义类别名称
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# The ten CIFAR-10 class names; index position doubles as the label id.
label_name = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
|
2、批量glob训练数据与测试数据压缩包
1 2 3
# Collect the pickled CIFAR-10 batch files.
# BUG FIX: the Windows paths used unescaped backslashes ("\d", "\c", ...)
# which are invalid escape sequences (SyntaxWarning, error in future
# Python); raw strings produce the identical paths safely.
train_list = glob.glob(r"C:\document\python\pythonDemo\pytorchTest\data\cifar-10-python\cifar-10-batches-py\data_batch_*")
test_list = glob.glob(r"C:\document\python\pythonDemo\pytorchTest\data\cifar-10-python\cifar-10-batches-py\test_batch*")
|
3、定义解压后训练与测试图片路径
1 2
# Root directory the decoded images are written into; one sub-folder per
# class is created beneath it.
# NOTE(review): the heading mentions train AND test paths but only this
# one is defined — confirm a train save path is declared elsewhere.
save_path = "/pytorchTest/data/cifar-10-python/cifar-10-batches-py/test"
|
4、遍历数据压缩包(注意:下面示例遍历的是测试数据 test_list,训练数据同理)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
# Decode each pickled CIFAR-10 batch and write every image into
# save_path/<class-name>/<original-filename>.
for l in test_list:
    l_dict = unpickle(l)
    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_label = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]
        im_label_name = label_name[im_label]
        # CIFAR rows are flat 3072-byte vectors: reshape to CHW, then
        # transpose to HWC as cv2.imwrite expects.
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, (1, 2, 0))
        # Create the per-class folder on first use.
        if not os.path.exists("{}/{}".format(save_path, im_label_name)):
            os.mkdir("{}/{}".format(save_path, im_label_name))
        # NOTE(review): cv2 writes channels as BGR while CIFAR stores RGB;
        # saved images may have swapped channels — verify if colors matter.
        cv2.imwrite("{}/{}/{}".format(save_path, im_label_name,
                                      im_name.decode("utf-8")), im_data)
|
三、已经将图片按照图片分类存入分类下的文件夹中
1、定义分类集合,并转化为对应数字
1 2 3 4 5 6 7 8 9 10
# The ten CIFAR-10 class names; index position doubles as the label id.
# BUG FIX: the list was assigned to `abel_name` (typo) while the loop
# below read `label_name` — corrected to one consistent name.
label_name = ["airplane", "automobile", "bird", "cat", "deer",
              "dog", "frog", "horse", "ship", "truck"]

# Map class-folder name -> integer label id.
label_dict = {name: idx for idx, name in enumerate(label_name)}
|
2、定义加载图片函数,Image.open()
并规定以什么样的方式打开
1 2 3
def default_loader(path):
    """Open the image at *path* with PIL and force 3-channel RGB."""
    img = Image.open(path)
    return img.convert("RGB")
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Training-time augmentation pipeline: random flips/rotation/color jitter
# plus a random crop, then tensor conversion and CIFAR-10 normalisation.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, hue=0.2),
    transforms.RandomGrayscale(0.2),
    # NOTE(review): train crops to 28x28 while the test transform crops
    # to 32x32 — the network will see different input sizes; confirm
    # this mismatch is intended.
    transforms.RandomCrop(28),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Deterministic evaluation pipeline: center crop + same normalisation.
test_transform = transforms.Compose([
    transforms.CenterCrop((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
|
4、继承torch.utils.data中的Dataset,并定义自己的Dataset Class
(1)首先定义__init__()函数
传入图片地址集合,数据增强方案transform
**根据文件夹名字为分类,可以采用split获取该图片的label,**并使用label_dict转为数字。使用append加入imgs列表中
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
def __init__(self, im_list, transform=None, loader=default_loader):
    """Dataset over a flat list of image paths organised as <class>/<file>.

    Args:
        im_list: list of absolute image paths (class folder is the
            second-to-last path component).
        transform: optional transform applied to each loaded image.
        loader: callable that opens a path and returns a PIL image.
    """
    super(MyDataset, self).__init__()
    imgs = []
    for im_item in im_list:
        # Folder name is the class label; map it to its integer id.
        # NOTE(review): splitting on "\\" is Windows-only — os.sep would
        # make this portable.
        im_label_name = im_item.split("\\")[-2]
        imgs.append([im_item, label_dict[im_label_name]])
    self.imgs = imgs
    self.transform = transform
    self.loader = loader
|
(2)定义getitem()函数
根据保存起来的imgs,里面有图片地址,以及分类。调用default_loader打开图片,然后进行数据增强,返回增强后的图片与分类
1 2 3 4 5 6 7 8
def __getitem__(self, index):
    """Load, optionally transform, and return (image, label) for *index*."""
    im_path, im_label = self.imgs[index]
    im_data = self.loader(im_path)
    if self.transform is not None:
        im_data = self.transform(im_data)
    return im_data, im_label
|
(3)len()函数
1 2 3
def __len__(self):
    """Number of indexed samples."""
    return len(self.imgs)
|
5、调用上面的定义
(1)使用glob获取所有图片地址
1 2 3
# Gather every extracted PNG, one class per sub-folder, for both splits.
# (Raw strings spell out the same Windows paths as the doubled backslashes.)
im_train_list = glob.glob(r"C:\document\python\pythonDemo\pytorchTest\data\cifar-10-python\cifar-10-batches-py\train\*\*.png")
im_test_list = glob.glob(r"C:\document\python\pythonDemo\pytorchTest\data\cifar-10-python\cifar-10-batches-py\test\*\*.png")
|
(2)调用Mydataset
1 2 3
# Wrap the path lists in the custom dataset with the matching pipelines.
train_dataset = MyDataset(im_train_list, transform=train_transform)
test_dataset = MyDataset(im_test_list, transform=test_transform)
|
6、dataloader
1 2 3
# Batch both splits; only training data is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=128,
                          shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=128,
                         shuffle=False, num_workers=4)
|
四、划分数据集
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
"""Split an image dataset into train / validation / test sub-folders."""
import os
import random
import shutil


def makedir(new_dir):
    # Create the directory (including parents) if it does not exist yet.
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


random.seed(1)  # deterministic shuffle -> reproducible split

# Raw strings: the originals used unescaped backslashes (invalid escape
# sequences); these spell the identical Windows paths safely.
dataset_dir = r"C:\document\python\pythonDemo\pytorchTest\data\Rice_Image_Dataset\Rice_Image_Dataset"
split_dir = r"C:\document\python\pythonDemo\pytorchTest\data\Rice_Image_Dataset"
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "val")
test_dir = os.path.join(split_dir, "test")

# NOTE(review): train_pct + valid_pct already equals 1.0, so the test
# split is always empty and test_pct is never used — confirm the
# intended ratios (e.g. 0.8 / 0.1 / 0.1).
train_pct = 0.9
valid_pct = 0.1
test_pct = 0.1

for root, dirs, files in os.walk(dataset_dir):
    for sub_dir in dirs:
        # Shuffle each class folder, then carve it by index thresholds.
        imgs = os.listdir(os.path.join(root, sub_dir))
        imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
        random.shuffle(imgs)
        img_count = len(imgs)
        train_point = int(img_count * train_pct)
        valid_point = int(img_count * (train_pct + valid_pct))
        for i in range(img_count):
            if i < train_point:
                out_dir = os.path.join(train_dir, sub_dir)
            elif i < valid_point:
                out_dir = os.path.join(valid_dir, sub_dir)
            else:
                out_dir = os.path.join(test_dir, sub_dir)
            makedir(out_dir)
            target_path = os.path.join(out_dir, imgs[i])
            # NOTE(review): builds the source from dataset_dir, not `root`
            # — only correct for class folders directly under dataset_dir;
            # verify the directory depth.
            src_path = os.path.join(dataset_dir, sub_dir, imgs[i])
            shutil.copy(src_path, target_path)
        print('Class:{}, train:{}, valid:{}, test:{}'.format(
            sub_dir, train_point, valid_point - train_point,
            img_count - valid_point))
|
五、相关知识
1、glob
返回所有匹配的文件路径列表。它只有一个参数pathname,定义了文件路径匹配规则,这里可以是绝对路径,也可以是相对路径。下面是使用glob.glob的例子:
1 2 3 4 5 6 7 8
| import glob
print (glob.glob(r"/home/qiaoyunhao/*/*.png"),"\n")
print (glob.glob(r'../*.py'))
|
2、asarray
转换输入为数组 array
输入参数
a:类数组。输入数据,可以是转换为数组的任意形式。比如列表、元组列表、元组、元组元组、列表元组和 ndarray;
dtype:数据类型,可选。默认情况下,该参数与输入数据的数据类型相同。
order:{‘C’,’F’},可选。选择是行优先(C-style)或列优先(Fortran-style)存储。默认为行优先。
返回值
out:ndarray。‘a’ 的数组形式。如果输入已经是匹配 dtype 和 order 参数的 ndarray 形式,则不执行复制,如果输入是 ndarray 的一个子类,则返回一个基类 ndarray。
3、enumerate()
enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。