import cv2
import numpy as np
import os
import pickle
 
 
data_dir = os.path.join("data", "cifar-10-batches-py")
train_o_dir = os.path.join("data", "train")
test_o_dir = os.path.join("data", "test")
 
train = True   # if False, skip the training batches and extract only the test set
 
# Unpickle a batch file and return the decoded dict
def unpickle(file):
    with open(file, 'rb') as fo:
        dict_ = pickle.load(fo, encoding='bytes')
    return dict_
 
def my_mkdir(my_dir):
    if not os.path.isdir(my_dir):
        os.makedirs(my_dir)
 
 
# Generate training-set images
if __name__ == '__main__':
    if train:
        for j in range(1, 6):
            data_path = os.path.join(data_dir, "data_batch_" + str(j))  # data_batch_1 .. data_batch_5
            train_data = unpickle(data_path)
            print(data_path + " is loading...")
 
            for i in range(0, 10000):
                img = np.reshape(train_data[b'data'][i], (3, 32, 32))
                img = img.transpose(1, 2, 0)
 
                label_num = str(train_data[b'labels'][i])
                o_dir = os.path.join(train_o_dir, "data_batch_" + str(j), label_num)
                my_mkdir(o_dir)
 
                img_name = label_num + '_' + str(i + (j - 1)*10000) + '.png'
                img_path = os.path.join(o_dir, img_name)
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # CIFAR arrays are RGB; cv2.imwrite expects BGR
                cv2.imwrite(img_path, img)
            print(data_path + " loaded.")
 
    print("test_batch is loading...")
 
    # Generate test-set images
    test_data_path = os.path.join(data_dir, "test_batch")
    test_data = unpickle(test_data_path)
    for i in range(0, 10000):
        img = np.reshape(test_data[b'data'][i], (3, 32, 32))
        img = img.transpose(1, 2, 0)
 
        label_num = str(test_data[b'labels'][i])
        o_dir = os.path.join(test_o_dir, label_num)
        my_mkdir(o_dir)
 
        img_name = label_num + '_' + str(i) + '.png'
        img_path = os.path.join(o_dir, img_name)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # CIFAR arrays are RGB; cv2.imwrite expects BGR
        cv2.imwrite(img_path, img)
 
    print("test_batch loaded.")

import os

my_mkdir("data/traintxt")
# Generate one txt index file per training batch
data_dir = "data/train/"
datat = "data/traintxt"
for j in range(1, 6):
    data_path = os.path.join(data_dir, "data_batch_" + str(j))  # data_batch_1 .. data_batch_5
    datatraint = os.path.join(datat, "data_batch_" + str(j) + ".txt")
    ft = open(datatraint, 'w')
    print(data_path)
    for root, s_dirs, _ in os.walk(data_path, topdown=True):  # walk the class sub-folders under this batch folder
        print(s_dirs)
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)      # full path of one class folder
            img_list = os.listdir(i_dir)             # all files in that class folder
            for i in range(len(img_list)):
                if not img_list[i].endswith('png'):  # skip anything that is not a png
                    continue
                label = img_list[i].split('_')[0]
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                ft.write(line)
    ft.close()  # close inside the loop so every batch's txt is flushed

# Generate the combined train.txt
data_dir = "data/train/"
datat = "data"
datatraint = os.path.join(datat, "train.txt")
ft = open(datatraint, 'w')
for j in range(1, 6):
    data_path = os.path.join(data_dir, "data_batch_" + str(j))  # data_batch_1 .. data_batch_5
    print(data_path)
    for root, s_dirs, _ in os.walk(data_path, topdown=True):  # walk the class sub-folders under this batch folder
        print(s_dirs)
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)      # full path of one class folder
            img_list = os.listdir(i_dir)             # all files in that class folder
            for i in range(len(img_list)):
                if not img_list[i].endswith('png'):  # skip anything that is not a png
                    continue
                label = img_list[i].split('_')[0]
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                ft.write(line)
ft.close()

# Generate test.txt
data_dir = "data"
datat = "data"

data_path = os.path.join(data_dir, "test")  
datatraint = os.path.join(datat, "test.txt")
ft = open(datatraint, 'w')
  
print(data_path)
for root, s_dirs, _ in os.walk(data_path, topdown=True):  # walk the class sub-folders under data/test
    print(s_dirs)
    for sub_dir in s_dirs:
        i_dir = os.path.join(root, sub_dir)      # full path of one class folder
        img_list = os.listdir(i_dir)             # all files in that class folder
        for i in range(len(img_list)):
            if not img_list[i].endswith('png'):  # skip anything that is not a png
                continue
            label = img_list[i].split('_')[0]
            img_path = os.path.join(i_dir, img_list[i])
            line = img_path + ' ' + label + '\n'
            ft.write(line)
ft.close()
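
Each line of the generated txt files is "<image path> <numeric label>" separated by a single space. A minimal parse check (check_index is a hypothetical helper, not part of the scripts above):

def check_index(txt_path, n=3):
    # Verify every line splits into a path and an integer label; print the first n entries.
    with open(txt_path) as f:
        for i, line in enumerate(f):
            path, label = line.rstrip().split()
            label = int(label)  # raises ValueError if the label column is malformed
            if i < n:
                print(path, label)

check_index("data/train.txt")
check_index("data/test.txt")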

Training script (main.py), adapted from another project's GitHub repo:

'''Train CIFAR-10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset
from PIL import Image
import torchvision
import torchvision.transforms as transforms

import os
import argparse
from models import *
from utils import progress_bar

class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        imgs = []
        with open(txt_path, 'r') as fh:
            for line in fh:
                line = line.rstrip()
                words = line.split()
                imgs.append((words[0], int(words[1])))  # (image path, label)
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)



parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# No random augmentation at test time; only tensor conversion and normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = MyDataset(txt_path='/work/aiit/warming/cifar-10-batches-py/train.txt',
                     transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = MyDataset(txt_path='/work/aiit/warming/cifar-10-batches-py/test.txt',
                    transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
# net = VGG('VGG19')
# net = ResNet18()
# net = PreActResNet18()
# net = GoogLeNet()
# net = DenseNet121()
# net = ResNeXt29_2x64d()
# net = MobileNet()
# net = MobileNetV2()
# net = DPN92()
# net = ShuffleNetG2()
# net = SENet18()
# net = ShuffleNetV2(1)
# net = EfficientNetB0()
net = RegNetX_200MF()
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.pth')
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=0.9, weight_decay=5e-4)


# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    # Save the full model every epoch; create the folder first so torch.save cannot fail.
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(net, './checkpoint/regnetx_200mf.pth')


def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        #torch.save(net, './checkpoint/ckpt1.pth')
        best_acc = acc


for epoch in range(start_epoch, start_epoch+100):
    train(epoch)
    test(epoch)
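
The script trains at a fixed learning rate for all 100 epochs. If you want the rate to decay over the run, one common option is a cosine annealing schedule stepped once per epoch, replacing the plain loop above (an optional sketch, not in the original script):

# Optional: decay the learning rate with cosine annealing over the 100-epoch run.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(start_epoch, start_epoch + 100):
    train(epoch)
    test(epoch)
    scheduler.step()  # step once per epoch so the LR follows the cosine curve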

Prediction

import torch
import cv2
import torch.nn.functional as F
import sys
sys.path.append('/work/aiit/warming/pytorch-cifar-master/models')
# import vgg
# import torchvision.models as models
# from vgg2 import vgg  # important: even though this import looks unused (greyed out), torch.load needs the model class definition on the path, or loading will fail
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as np
  
classes = ('plane', 'car', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck')
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # net = models.vgg19(pretrained=False)
    model = torch.load('/work/aiit/warming/pytorch-cifar-master/checkpoint/regnetx_200mf.pth')  # load the saved model
    model = model.to(device)
    model.eval()  # switch the model to eval (test) mode

    img = cv2.imread("/work/aiit/warming/cifar-10-batches-py/test/1/1_6.png")  # read the image to predict
    img = cv2.resize(img, (32, 32))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # cv2 loads BGR; convert to RGB to match the PIL-loaded training data
    trans = transforms.Compose([
        transforms.ToTensor(),
        # use the same normalization statistics as during training
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    img = trans(img)
    img = img.to(device)
    img = img.unsqueeze(0)  # add a batch dimension: the model expects 4-D input [batch_size, channels, height, width], while a single image is 3-D [channels, height, width]
    # after unsqueeze the shape is [1, 3, 32, 32]
    output = model(img)
    prob = F.softmax(output, dim=1)  # prob holds the probabilities of the 10 classes
    print(prob)
    value, predicted = torch.max(output.data, 1)
    # print(predicted.item())
    # print(value)
    pred_class = classes[predicted.item()]
    print(pred_class)
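    # Optional (not in the original): unpack the probability tensor into the top few
    # candidates to see how confident the model is beyond the single argmax class.
    top_prob, top_idx = torch.topk(prob, 3, dim=1)
    for p, idx in zip(top_prob[0], top_idx[0]):
        print('%s: %.4f' % (classes[idx.item()], p.item()))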
  
    '''prob = F.softmax(output, dim=1)
    prob = Variable(prob)
    prob = prob.cpu().numpy()  # a GPU tensor must be moved back to the CPU before it can be converted to numpy for display
    print(prob)  # prob holds the probabilities of the 10 classes
    pred = np.argmax(prob)  # pick the class with the highest probability
    print(pred)
    print(pred.item())
    pred_class = classes[pred]
    print(pred_class)'''

 
