{"id":153,"date":"2020-11-19T04:31:00","date_gmt":"2020-11-19T04:31:00","guid":{"rendered":"https:\/\/imwarming.com\/?p=153"},"modified":"2020-11-19T04:31:00","modified_gmt":"2020-11-19T04:31:00","slug":"cifar-10-dataset","status":"publish","type":"post","link":"https:\/\/imwarming.com\/?p=153","title":{"rendered":"cifar-10-dataset"},"content":{"rendered":"<div class=\"cnblogs_code\">\n<pre><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> cv2\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> numpy as np\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> os\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> pickle\n \n \ndata_dir <\/span>= os.path.join(<span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data<\/span><span style=\"color: #800000;\">\"<\/span>, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">cifar-10-batches-py<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)\ntrain_o_dir <\/span>= os.path.join(<span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data<\/span><span style=\"color: #800000;\">\"<\/span>, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">train<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)\ntest_o_dir <\/span>= os.path.join(<span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data<\/span><span style=\"color: #800000;\">\"<\/span>, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">test<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)\n \ntrain <\/span>= True   <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u4e0d\u89e3\u538b\u8bad\u7ec3\u96c6\uff0c\u4ec5\u89e3\u538b\u6d4b\u8bd5\u96c6<\/span>\n \n<span style=\"color: 
#008000;\">#<\/span><span style=\"color: #008000;\"> \u89e3\u538b\u7f29\uff0c\u8fd4\u56de\u89e3\u538b\u540e\u7684\u5b57\u5178<\/span>\n<span style=\"color: #0000ff;\">def<\/span><span style=\"color: #000000;\"> unpickle(file):\n    with open(file, <\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">rb<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">) as fo:\n        dict_ <\/span>= pickle.load(fo, encoding=<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">bytes<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n    <\/span><span style=\"color: #0000ff;\">return<\/span><span style=\"color: #000000;\"> dict_\n \n<\/span><span style=\"color: #0000ff;\">def<\/span><span style=\"color: #000000;\"> my_mkdir(my_dir):\n    <\/span><span style=\"color: #0000ff;\">if<\/span> <span style=\"color: #0000ff;\">not<\/span><span style=\"color: #000000;\"> os.path.isdir(my_dir):\n        os.makedirs(my_dir)\n \n \n<\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u751f\u6210\u8bad\u7ec3\u96c6\u56fe\u7247\uff0c<\/span>\n<span style=\"color: #0000ff;\">if<\/span> <span style=\"color: #800080;\">__name__<\/span> == <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">__main__<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">:\n    <\/span><span style=\"color: #0000ff;\">if<\/span><span style=\"color: #000000;\"> train:\n        <\/span><span style=\"color: #0000ff;\">for<\/span> j <span style=\"color: #0000ff;\">in<\/span> range(1, 6<span style=\"color: #000000;\">):\n            data_path <\/span>= os.path.join(data_dir, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data_batch_<\/span><span style=\"color: #800000;\">\"<\/span> + str(j))  <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> data_batch_12345<\/span>\n            train_data 
=<span style=\"color: #000000;\"> unpickle(data_path)\n            <\/span><span style=\"color: #0000ff;\">print<\/span>(data_path + <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\"> is loading...<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)\n \n            <\/span><span style=\"color: #0000ff;\">for<\/span> i <span style=\"color: #0000ff;\">in<\/span> range(0, 10000<span style=\"color: #000000;\">):\n                img <\/span>= np.reshape(train_data[b<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">data<\/span><span style=\"color: #800000;\">'<\/span>][i], (3, 32, 32<span style=\"color: #000000;\">))\n                img <\/span>= img.transpose(1, 2<span style=\"color: #000000;\">, 0)\n \n                label_num <\/span>= str(train_data[b<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">labels<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">][i])\n                o_dir <\/span>= os.path.join(train_o_dir, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data_batch_<\/span><span style=\"color: #800000;\">\"<\/span> +<span style=\"color: #000000;\"> str(j) ,label_num)\n                my_mkdir(o_dir)\n \n                img_name <\/span>= label_num + <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">_<\/span><span style=\"color: #800000;\">'<\/span> + str(i + (j - 1)*10000) + <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">.png<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">\n                img_path <\/span>=<span style=\"color: #000000;\"> os.path.join(o_dir, img_name)\n                cv2.imwrite(img_path, img)\n            <\/span><span style=\"color: #0000ff;\">print<\/span>(data_path + <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\"> loaded.<\/span><span style=\"color: 
#800000;\">\"<\/span><span style=\"color: #000000;\">)\n \n    <\/span><span style=\"color: #0000ff;\">print<\/span>(<span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">test_batch is loading...<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)\n \n    <\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u751f\u6210\u6d4b\u8bd5\u96c6\u56fe\u7247<\/span>\n    test_data_path = os.path.join(data_dir, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">test_batch<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)\n    test_data <\/span>=<span style=\"color: #000000;\"> unpickle(test_data_path)\n    <\/span><span style=\"color: #0000ff;\">for<\/span> i <span style=\"color: #0000ff;\">in<\/span> range(0, 10000<span style=\"color: #000000;\">):\n        img <\/span>= np.reshape(test_data[b<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">data<\/span><span style=\"color: #800000;\">'<\/span>][i], (3, 32, 32<span style=\"color: #000000;\">))\n        img <\/span>= img.transpose(1, 2<span style=\"color: #000000;\">, 0)\n \n        label_num <\/span>= str(test_data[b<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">labels<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">][i])\n        o_dir <\/span>=<span style=\"color: #000000;\"> os.path.join(test_o_dir, label_num)\n        my_mkdir(o_dir)\n \n        img_name <\/span>= label_num + <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">_<\/span><span style=\"color: #800000;\">'<\/span> + str(i) + <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">.png<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">\n        img_path <\/span>=<span style=\"color: #000000;\"> os.path.join(o_dir, img_name)\n        cv2.imwrite(img_path, img)\n \n    
<\/span><span style=\"color: #0000ff;\">print<\/span>(<span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">test_batch loaded.<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)\n\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> sys\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> os\nmy_mkdir(<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data\/traintxt<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)\n<\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\">\u751f\u6210batch\u7684txt   <\/span>\ndata_dir = <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data\/train\/<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">\ndatat <\/span>= <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data\/traintxt<\/span><span style=\"color: #800000;\">\"<\/span>\n<span style=\"color: #0000ff;\">for<\/span> j <span style=\"color: #0000ff;\">in<\/span> range(1, 6<span style=\"color: #000000;\">):\n  data_path <\/span>= os.path.join(data_dir, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data_batch_<\/span><span style=\"color: #800000;\">\"<\/span> + str(j))  <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> data_batch_12345<\/span>\n  datatraint = os.path.join(datat, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data_batch_<\/span><span style=\"color: #800000;\">\"<\/span> + str(j) + <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">.txt<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)\n  ft <\/span>= open(datatraint, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">w<\/span><span style=\"color: #800000;\">'<\/span><span 
style=\"color: #000000;\">)\n  <\/span><span style=\"color: #0000ff;\">print<\/span><span style=\"color: #000000;\">(data_path)\n  <\/span><span style=\"color: #0000ff;\">for<\/span> root, s_dirs, _ <span style=\"color: #0000ff;\">in<\/span> os.walk(data_path, topdown=True):  <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u83b7\u53d6 train\u6587\u4ef6\u4e0b\u5404\u6587\u4ef6\u5939\u540d\u79f0<\/span>\n      <span style=\"color: #0000ff;\">print<\/span><span style=\"color: #000000;\">(s_dirs)\n      <\/span><span style=\"color: #0000ff;\">for<\/span> sub_dir <span style=\"color: #0000ff;\">in<\/span><span style=\"color: #000000;\"> s_dirs:\n          i_dir <\/span>= os.path.join(root, sub_dir)             <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u83b7\u53d6\u5404\u7c7b\u7684\u6587\u4ef6\u5939 \u7edd\u5bf9\u8def\u5f84<\/span>\n          img_list = os.listdir(i_dir)                    <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u83b7\u53d6\u7c7b\u522b\u6587\u4ef6\u5939\u4e0b\u6240\u6709png\u56fe\u7247\u7684\u8def\u5f84<\/span>\n          <span style=\"color: #0000ff;\">for<\/span> i <span style=\"color: #0000ff;\">in<\/span><span style=\"color: #000000;\"> range(len(img_list)):\n              <\/span><span style=\"color: #0000ff;\">if<\/span> <span style=\"color: #0000ff;\">not<\/span> img_list[i].endswith(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">png<\/span><span style=\"color: #800000;\">'<\/span>):         <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u82e5\u4e0d\u662fpng\u6587\u4ef6\uff0c\u8df3\u8fc7<\/span>\n                  <span style=\"color: #0000ff;\">continue<\/span><span style=\"color: #000000;\">\n              label <\/span>= img_list[i].split(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">_<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)[0]\n  
            img_path <\/span>=<span style=\"color: #000000;\"> os.path.join(i_dir, img_list[i])\n              line <\/span>= img_path + <span style=\"color: #800000;\">'<\/span> <span style=\"color: #800000;\">'<\/span> + label + <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">\\n<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">\n              ft.write(line)\nft.close()\n\n<\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\">\u603b\u751f\u6210txt<\/span>\ndata_dir = <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data\/train\/<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">\ndatat <\/span>= <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">\ndatatraint <\/span>= os.path.join(datat, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">train.txt<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)\nft <\/span>= open(datatraint, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">w<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n<\/span><span style=\"color: #0000ff;\">for<\/span> j <span style=\"color: #0000ff;\">in<\/span> range(1, 6<span style=\"color: #000000;\">):\n  data_path <\/span>= os.path.join(data_dir, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data_batch_<\/span><span style=\"color: #800000;\">\"<\/span> + str(j))  <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> data_batch_12345<\/span>\n  <span style=\"color: #0000ff;\">print<\/span><span style=\"color: #000000;\">(data_path)\n  <\/span><span style=\"color: #0000ff;\">for<\/span> root, s_dirs, _ <span style=\"color: #0000ff;\">in<\/span> os.walk(data_path, topdown=True):  <span style=\"color: 
#008000;\">#<\/span><span style=\"color: #008000;\"> \u83b7\u53d6 train\u6587\u4ef6\u4e0b\u5404\u6587\u4ef6\u5939\u540d\u79f0<\/span>\n      <span style=\"color: #0000ff;\">print<\/span><span style=\"color: #000000;\">(s_dirs)\n      <\/span><span style=\"color: #0000ff;\">for<\/span> sub_dir <span style=\"color: #0000ff;\">in<\/span><span style=\"color: #000000;\"> s_dirs:\n          i_dir <\/span>= os.path.join(root, sub_dir)             <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u83b7\u53d6\u5404\u7c7b\u7684\u6587\u4ef6\u5939 \u7edd\u5bf9\u8def\u5f84<\/span>\n          img_list = os.listdir(i_dir)                 <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u83b7\u53d6\u7c7b\u522b\u6587\u4ef6\u5939\u4e0b\u6240\u6709png\u56fe\u7247\u7684\u8def\u5f84<\/span>\n          <span style=\"color: #0000ff;\">for<\/span> i <span style=\"color: #0000ff;\">in<\/span><span style=\"color: #000000;\"> range(len(img_list)):\n              <\/span><span style=\"color: #0000ff;\">if<\/span> <span style=\"color: #0000ff;\">not<\/span> img_list[i].endswith(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">png<\/span><span style=\"color: #800000;\">'<\/span>):         <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u82e5\u4e0d\u662fpng\u6587\u4ef6\uff0c\u8df3\u8fc7<\/span>\n                  <span style=\"color: #0000ff;\">continue<\/span><span style=\"color: #000000;\">\n              label <\/span>= img_list[i].split(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">_<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)[0]\n              img_path <\/span>=<span style=\"color: #000000;\"> os.path.join(i_dir, img_list[i])\n              line <\/span>= img_path + <span style=\"color: #800000;\">'<\/span> <span style=\"color: #800000;\">'<\/span> + label + <span style=\"color: #800000;\">'<\/span><span style=\"color: 
#800000;\">\\n<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">\n              ft.write(line)\nft.close()\n\n<\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\">test\u7684txt<\/span>\ndata_dir = <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">\ndatat <\/span>= <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">data<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">\n\ndata_path <\/span>= os.path.join(data_dir, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">test<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)  \ndatatraint <\/span>= os.path.join(datat, <span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">test.txt<\/span><span style=\"color: #800000;\">\"<\/span><span style=\"color: #000000;\">)\nft <\/span>= open(datatraint, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">w<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n  \n<\/span><span style=\"color: #0000ff;\">print<\/span><span style=\"color: #000000;\">(data_path)\n<\/span><span style=\"color: #0000ff;\">for<\/span> root, s_dirs, _ <span style=\"color: #0000ff;\">in<\/span> os.walk(data_path, topdown=True):  <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u83b7\u53d6 test\u6587\u4ef6\u4e0b\u5404\u6587\u4ef6\u5939\u540d\u79f0<\/span>\n    <span style=\"color: #0000ff;\">print<\/span><span style=\"color: #000000;\">(s_dirs)\n    <\/span><span style=\"color: #0000ff;\">for<\/span> sub_dir <span style=\"color: #0000ff;\">in<\/span><span style=\"color: #000000;\"> s_dirs:\n        i_dir <\/span>= os.path.join(root, sub_dir)             <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> 
\u83b7\u53d6\u5404\u7c7b\u7684\u6587\u4ef6\u5939 \u7edd\u5bf9\u8def\u5f84<\/span>\n        img_list = os.listdir(i_dir)                 <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u83b7\u53d6\u7c7b\u522b\u6587\u4ef6\u5939\u4e0b\u6240\u6709png\u56fe\u7247\u7684\u8def\u5f84<\/span>\n        <span style=\"color: #0000ff;\">for<\/span> i <span style=\"color: #0000ff;\">in<\/span><span style=\"color: #000000;\"> range(len(img_list)):\n            <\/span><span style=\"color: #0000ff;\">if<\/span> <span style=\"color: #0000ff;\">not<\/span> img_list[i].endswith(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">png<\/span><span style=\"color: #800000;\">'<\/span>):         <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u82e5\u4e0d\u662fpng\u6587\u4ef6\uff0c\u8df3\u8fc7<\/span>\n                <span style=\"color: #0000ff;\">continue<\/span><span style=\"color: #000000;\">\n            label <\/span>= img_list[i].split(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">_<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)[0]\n            img_path <\/span>=<span style=\"color: #000000;\"> os.path.join(i_dir, img_list[i])\n            line <\/span>= img_path + <span style=\"color: #800000;\">'<\/span> <span style=\"color: #800000;\">'<\/span> + label + <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">\\n<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">\n            ft.write(line)\nft.close()<\/span><\/pre>\n<\/div>\n<p>Training script adapted from&nbsp;<a href=\"https:\/\/github.com\/kuangliu\/pytorch-cifar\" rel=\"noreferrer noopener\" target=\"_blank\">main.py in kuangliu&rsquo;s pytorch-cifar repository on GitHub<\/a>:<\/p>\n<div class=\"cnblogs_code\">\n<pre><span style=\"color: #800000;\">'''<\/span><span style=\"color: #800000;\">Train CIFAR10 with PyTorch.<\/span><span style=\"color: #800000;\">'''<\/span>\n<span style=\"color: 
#0000ff;\">import<\/span><span style=\"color: #000000;\"> torch\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> torch.nn as nn\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> torch.optim as optim\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> torch.nn.functional as F\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> torch.backends.cudnn as cudnn\n<\/span><span style=\"color: #0000ff;\">from<\/span> torch.utils.data <span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> Dataset\n<\/span><span style=\"color: #0000ff;\">from<\/span> PIL <span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> Image\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> torchvision\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> torchvision.transforms as transforms\n\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> os\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> argparse\n<\/span><span style=\"color: #0000ff;\">from<\/span> models <span style=\"color: #0000ff;\">import<\/span> *\n<span style=\"color: #0000ff;\">from<\/span> utils <span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> progress_bar\n\n<\/span><span style=\"color: #0000ff;\">class<\/span><span style=\"color: #000000;\"> mydataset(Dataset):\n    <\/span><span style=\"color: #0000ff;\">def<\/span> <span style=\"color: #800080;\">__init__<\/span>(self,txt_path,transform = None,target_transform =<span style=\"color: #000000;\"> None):\n        fh <\/span>= open(txt_path,<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">r<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n        imgs 
<\/span>=<span style=\"color: #000000;\"> []\n        <\/span><span style=\"color: #0000ff;\">for<\/span> line <span style=\"color: #0000ff;\">in<\/span><span style=\"color: #000000;\"> fh:\n            line <\/span>=<span style=\"color: #000000;\"> line.rstrip()\n            words <\/span>=<span style=\"color: #000000;\"> line.split()\n            imgs.append((words[0],int(words[<\/span>1<span style=\"color: #000000;\">])))\n            self.imgs <\/span>=<span style=\"color: #000000;\"> imgs\n            self.transform <\/span>=<span style=\"color: #000000;\"> transform\n            self.target_transform <\/span>=<span style=\"color: #000000;\"> target_transform\n    <\/span><span style=\"color: #0000ff;\">def<\/span> <span style=\"color: #800080;\">__getitem__<\/span><span style=\"color: #000000;\">(self,index):\n        fn,label <\/span>=<span style=\"color: #000000;\"> self.imgs[index]\n        img <\/span>=<span style=\"color: #000000;\"> Image.open(fn)\n        <\/span><span style=\"color: #0000ff;\">if<\/span> self.transform <span style=\"color: #0000ff;\">is<\/span> <span style=\"color: #0000ff;\">not<\/span><span style=\"color: #000000;\"> None:\n            img <\/span>=<span style=\"color: #000000;\"> self.transform(img)\n        <\/span><span style=\"color: #0000ff;\">return<\/span><span style=\"color: #000000;\"> img,label\n    <\/span><span style=\"color: #0000ff;\">def<\/span> <span style=\"color: #800080;\">__len__<\/span><span style=\"color: #000000;\">(self):\n        <\/span><span style=\"color: #0000ff;\">return<\/span><span style=\"color: #000000;\"> len(self.imgs)\n\n\n\nparser <\/span>= argparse.ArgumentParser(description=<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">pytorch cifar10 training<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\nparser.add_argument(<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">--lr<\/span><span style=\"color: 
#800000;\">'<\/span>, default=0.1, type=float, help=<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">learning rate<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\nparser.add_argument(<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">--resume<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">-r<\/span><span style=\"color: #800000;\">'<\/span>, action=<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">store_true<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">,\n                    help<\/span>=<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">resume from checkpoint<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\nargs <\/span>=<span style=\"color: #000000;\"> parser.parse_args()\n\ndevice <\/span>= <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">cuda<\/span><span style=\"color: #800000;\">'<\/span> <span style=\"color: #0000ff;\">if<\/span> torch.cuda.is_available() <span style=\"color: #0000ff;\">else<\/span> <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">cpu<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">\nbest_acc <\/span>= 0  <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> best test accuracy<\/span>\nstart_epoch = 0  <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> start from epoch 0 or last checkpoint epoch<\/span>\n\n<span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> data<\/span>\n<span style=\"color: #0000ff;\">print<\/span>(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">==&gt; preparing data..<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\ntransform_train <\/span>=<span 
style=\"color: #000000;\"> transforms.Compose([\n    transforms.RandomCrop(<\/span>32, padding=4<span style=\"color: #000000;\">),\n    transforms.RandomHorizontalFlip(),\n    transforms.ToTensor(),\n    transforms.Normalize((<\/span>0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010<span style=\"color: #000000;\">)),\n])\n\ntransform_test <\/span>=<span style=\"color: #000000;\"> transforms.Compose([\n    transforms.RandomCrop(<\/span>32, padding=4<span style=\"color: #000000;\">),\n    transforms.RandomHorizontalFlip(),\n    transforms.ToTensor(),\n    transforms.Normalize((<\/span>0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010<span style=\"color: #000000;\">)),\n])\n\ntrainset <\/span>= mydataset(txt_path = <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">\/work\/aiit\/warming\/cifar-10-batches-py\/train.txt<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">,\n                            transform<\/span>=<span style=\"color: #000000;\">transform_train)\ntrainloader <\/span>=<span style=\"color: #000000;\"> torch.utils.data.DataLoader(\n    trainset, batch_size<\/span>=128, shuffle=True, num_workers=2<span style=\"color: #000000;\">)\n\ntestset <\/span>= mydataset(txt_path = <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">\/work\/aiit\/warming\/cifar-10-batches-py\/test.txt<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">,\n                             transform<\/span>=<span style=\"color: #000000;\">transform_test)\ntestloader <\/span>=<span style=\"color: #000000;\"> torch.utils.data.DataLoader(\n    testset, batch_size<\/span>=100, shuffle=False, num_workers=2<span style=\"color: #000000;\">)\n\nclasses <\/span>= (<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">plane<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">car<\/span><span style=\"color: 
#800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">bird<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">cat<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">deer<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">,\n           <\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">dog<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">frog<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">horse<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">ship<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">truck<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n\n<\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> model<\/span>\n<span style=\"color: #0000ff;\">print<\/span>(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">==&gt; building model..<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n<\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\">net = vgg.vgg('vgg19')<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\">net = resnet18()<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\"> net = preactresnet18()<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\"> net = googlenet()<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\"> net = densenet121()<\/span><span style=\"color: 
#008000;\">\n#<\/span><span style=\"color: #008000;\"> net = resnext29_2x64d()<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\"> net = mobilenet()<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\"> net = mobilenetv2()<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\"> net = dpn92()<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\"> net = shufflenetg2()<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\">net = senet18()<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\"> net = shufflenetv2(1)<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\"> net = efficientnetb0()<\/span>\nnet =<span style=\"color: #000000;\"> RegNetX_200MF()\nnet <\/span>=<span style=\"color: #000000;\"> net.to(device)\n<\/span><span style=\"color: #0000ff;\">if<\/span> device == <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">cuda<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">:\n    net <\/span>=<span style=\"color: #000000;\"> torch.nn.DataParallel(net)\n    cudnn.benchmark <\/span>=<span style=\"color: #000000;\"> True\n\n<\/span><span style=\"color: #0000ff;\">if<\/span><span style=\"color: #000000;\"> args.resume:\n    <\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> load checkpoint.<\/span>\n    <span style=\"color: #0000ff;\">print<\/span>(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">==&gt; resuming from checkpoint..<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n    <\/span><span style=\"color: #0000ff;\">assert<\/span> os.path.isdir(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">checkpoint<\/span><span style=\"color: #800000;\">'<\/span>), <span style=\"color: #800000;\">'<\/span><span 
style=\"color: #800000;\">error: no checkpoint directory found!<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">\n    checkpoint <\/span>= torch.load(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">.\/checkpoint\/ckpt.pth<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n    net.load_state_dict(checkpoint[<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">net<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">])\n    best_acc <\/span>= checkpoint[<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">acc<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">]\n    start_epoch <\/span>= checkpoint[<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">epoch<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">]\n\ncriterion <\/span>=<span style=\"color: #000000;\"> nn.CrossEntropyLoss()\noptimizer <\/span>= optim.SGD(net.parameters(), lr=<span style=\"color: #000000;\">args.lr,\n                      momentum<\/span>=0.9, weight_decay=5e-4<span style=\"color: #000000;\">)\n\n\n<\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> training<\/span>\n<span style=\"color: #0000ff;\">def<\/span><span style=\"color: #000000;\"> train(epoch):\n    <\/span><span style=\"color: #0000ff;\">print<\/span>(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">\\nEpoch: %d<\/span><span style=\"color: #800000;\">'<\/span> %<span style=\"color: #000000;\"> epoch)\n    net.train()\n    train_loss <\/span>=<span style=\"color: #000000;\"> 0\n    correct <\/span>=<span style=\"color: #000000;\"> 0\n    total <\/span>=<span style=\"color: #000000;\"> 0\n    <\/span><span style=\"color: #0000ff;\">for<\/span> batch_idx, (inputs, targets) <span style=\"color: #0000ff;\">in<\/span><span style=\"color: 
#000000;\"> enumerate(trainloader):\n        inputs, targets <\/span>=<span style=\"color: #000000;\"> inputs.to(device), targets.to(device)\n        optimizer.zero_grad()\n        outputs <\/span>=<span style=\"color: #000000;\"> net(inputs)\n        loss <\/span>=<span style=\"color: #000000;\"> criterion(outputs, targets)\n        loss.backward()\n        optimizer.step()\n\n        train_loss <\/span>+=<span style=\"color: #000000;\"> loss.item()\n        _, predicted <\/span>= outputs.max(1<span style=\"color: #000000;\">)\n        total <\/span>+=<span style=\"color: #000000;\"> targets.size(0)\n        correct <\/span>+=<span style=\"color: #000000;\"> predicted.eq(targets).sum().item()\n\n        progress_bar(batch_idx, len(trainloader), <\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">loss: %.3f | acc: %.3f%% (%d\/%d)<\/span><span style=\"color: #800000;\">'<\/span>\n                     % (train_loss\/(batch_idx+1), 100.*correct\/<span style=\"color: #000000;\">total, correct, total))\n    torch.save(net, <\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">.\/checkpoint\/regnetx_200mf.pth<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n\n\n<\/span><span style=\"color: #0000ff;\">def<\/span><span style=\"color: #000000;\"> test(epoch):\n    <\/span><span style=\"color: #0000ff;\">global<\/span><span style=\"color: #000000;\"> best_acc\n    net.eval()\n    test_loss <\/span>=<span style=\"color: #000000;\"> 0\n    correct <\/span>=<span style=\"color: #000000;\"> 0\n    total <\/span>=<span style=\"color: #000000;\"> 0\n    with torch.no_grad():\n        <\/span><span style=\"color: #0000ff;\">for<\/span> batch_idx, (inputs, targets) <span style=\"color: #0000ff;\">in<\/span><span style=\"color: #000000;\"> enumerate(testloader):\n            inputs, targets <\/span>=<span style=\"color: #000000;\"> inputs.to(device), targets.to(device)\n            outputs 
<\/span>=<span style=\"color: #000000;\"> net(inputs)\n            loss <\/span>=<span style=\"color: #000000;\"> criterion(outputs, targets)\n\n            test_loss <\/span>+=<span style=\"color: #000000;\"> loss.item()\n            _, predicted <\/span>= outputs.max(1<span style=\"color: #000000;\">)\n            total <\/span>+=<span style=\"color: #000000;\"> targets.size(0)\n            correct <\/span>+=<span style=\"color: #000000;\"> predicted.eq(targets).sum().item()\n\n            progress_bar(batch_idx, len(testloader), <\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">loss: %.3f | acc: %.3f%% (%d\/%d)<\/span><span style=\"color: #800000;\">'<\/span>\n                         % (test_loss\/(batch_idx+1), 100.*correct\/<span style=\"color: #000000;\">total, correct, total))\n\n    <\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> save checkpoint.<\/span>\n    acc = 100.*correct\/<span style=\"color: #000000;\">total\n    <\/span><span style=\"color: #0000ff;\">if<\/span> acc &gt;<span style=\"color: #000000;\"> best_acc:\n        <\/span><span style=\"color: #0000ff;\">print<\/span>(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">saving..<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n        state <\/span>=<span style=\"color: #000000;\"> {\n            <\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">net<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">: net.state_dict(),\n            <\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">acc<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">: acc,\n            <\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">epoch<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">: epoch,\n        }\n        
<\/span><span style=\"color: #0000ff;\">if<\/span> <span style=\"color: #0000ff;\">not<\/span> os.path.isdir(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">checkpoint<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">):\n            os.mkdir(<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">checkpoint<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n        <\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\">torch.save(net, '.\/checkpoint\/ckpt1.pth')<\/span>\n        best_acc =<span style=\"color: #000000;\"> acc\n\n\n<\/span><span style=\"color: #0000ff;\">for<\/span> epoch <span style=\"color: #0000ff;\">in<\/span> range(start_epoch, start_epoch+100<span style=\"color: #000000;\">):\n    train(epoch)\n    test(epoch)<\/span><\/pre>\n<\/div>\n<p>\u9884\u6d4b<\/p>\n<div class=\"cnblogs_code\">\n<pre><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> torch\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> cv2\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> torch.nn.functional as f\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> sys \nsys.path.append(<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">\/work\/aiit\/warming\/pytorch-cifar-master\/models<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n<\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\">import vgg<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\">import torchvision.models as models<\/span><span style=\"color: #008000;\">\n#<\/span><span style=\"color: #008000;\">from vgg2 import vgg 
#\u91cd\u8981\uff0c\u867d\u7136\u663e\u793a\u7070\u8272(\u5373\u5728\u6b64\u4ee3\u7801\u4e2d\u6ca1\u7528\u5230)\uff0c\u4f46\u82e5\u6ca1\u6709\u5f15\u5165\u8fd9\u4e2a\u6a21\u578b\u4ee3\u7801\uff0c\u52a0\u8f7d\u6a21\u578b\u65f6\u4f1a\u627e\u4e0d\u5230\u6a21\u578b<\/span>\n<span style=\"color: #0000ff;\">from<\/span> torch.autograd <span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> Variable\n<\/span><span style=\"color: #0000ff;\">from<\/span> torchvision <span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> datasets, transforms\n<\/span><span style=\"color: #0000ff;\">import<\/span><span style=\"color: #000000;\"> numpy as np\n  \nclasses <\/span>= (<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">plane<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">car<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">bird<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">cat<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">deer<\/span><span style=\"color: #800000;\">'<\/span>,<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">dog<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">frog<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">horse<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">ship<\/span><span style=\"color: #800000;\">'<\/span>, <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">truck<\/span><span style=\"color: #800000;\">'<\/span><span 
style=\"color: #000000;\">)\n<\/span><span style=\"color: #0000ff;\">if<\/span> <span style=\"color: #800080;\">__name__<\/span> == <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">__main__<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">:\n    device <\/span>= torch.device(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">cuda<\/span><span style=\"color: #800000;\">'<\/span> <span style=\"color: #0000ff;\">if<\/span> torch.cuda.is_available() <span style=\"color: #0000ff;\">else<\/span> <span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">cpu<\/span><span style=\"color: #800000;\">'<\/span><span style=\"color: #000000;\">)\n    <\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\">net=models.vgg19(pretrained=False)    <\/span>\n    model = (torch.load(<span style=\"color: #800000;\">'<\/span><span style=\"color: #800000;\">\/work\/aiit\/warming\/pytorch-cifar-master\/checkpoint\/regnetx_200mf.pth<\/span><span style=\"color: #800000;\">'<\/span>)) <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u52a0\u8f7d\u6a21\u578b<\/span>\n    model =<span style=\"color: #000000;\"> model.to(device)\n    model.eval() <\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u628a\u6a21\u578b\u8f6c\u4e3atest\u6a21\u5f0f<\/span>\n<span style=\"color: #000000;\">  \n    img <\/span>= cv2.imread(<span style=\"color: #800000;\">\"<\/span><span style=\"color: #800000;\">\/work\/aiit\/warming\/cifar-10-batches-py\/test\/1\/1_6.png<\/span><span style=\"color: #800000;\">\"<\/span>) <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u8bfb\u53d6\u8981\u9884\u6d4b\u7684\u56fe\u7247<\/span>\n    img=cv2.resize(img,(32,32<span style=\"color: #000000;\">))\n    trans <\/span>=<span style=\"color: #000000;\"> transforms.Compose(\n    [\n     transforms.ToTensor(),\n     
transforms.Normalize(mean<\/span>=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5<span style=\"color: #000000;\">))\n    ])\n  \n    img <\/span>=<span style=\"color: #000000;\"> trans(img)\n    img <\/span>=<span style=\"color: #000000;\"> img.to(device)\n    img <\/span>= img.unsqueeze(0) <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u56fe\u7247\u6269\u5c55\u591a\u4e00\u7ef4,\u56e0\u4e3a\u8f93\u5165\u5230\u4fdd\u5b58\u7684\u6a21\u578b\u4e2d\u662f4\u7ef4\u7684[batch_size,\u901a\u9053,\u957f\uff0c\u5bbd]\uff0c\u800c\u666e\u901a\u56fe\u7247\u53ea\u6709\u4e09\u7ef4\uff0c[\u901a\u9053,\u957f\uff0c\u5bbd]<\/span>\n    <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\"> \u6269\u5c55\u540e\uff0c\u4e3a[1\uff0c3\uff0c32\uff0c32]<\/span>\n    output =<span style=\"color: #000000;\"> model(img)\n    prob <\/span>= f.softmax(output,dim=1) <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\">prob\u662f10\u4e2a\u5206\u7c7b\u7684\u6982\u7387<\/span>\n    <span style=\"color: #0000ff;\">print<\/span><span style=\"color: #000000;\">(prob)\n    value, predicted <\/span>= torch.max(output.data, 1<span style=\"color: #000000;\">)\n    <\/span><span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\">print(predicted.item())<\/span>\n    <span style=\"color: #008000;\">#<\/span><span style=\"color: #008000;\">print(value)<\/span>\n    pred_class =<span style=\"color: #000000;\"> classes[predicted.item()]\n    <\/span><span style=\"color: #0000ff;\">print<\/span><span style=\"color: #000000;\">(pred_class)\n  \n    <\/span><span style=\"color: #800000;\">'''<\/span><span style=\"color: #800000;\">prob = f.softmax(output, dim=1)\n    prob = Variable(prob)\n    prob = prob.cpu().numpy() # \u7528gpu\u7684\u6570\u636e\u8bad\u7ec3\u7684\u6a21\u578b\u4fdd\u5b58\u7684\u53c2\u6570\u90fd\u662fgpu\u5f62\u5f0f\u7684\uff0c\u8981\u663e\u793a\u5219\u5148\u8981\u8f6c\u56decpu\uff0c\u518d\u8f6c\u56denumpy\u6a21\u5f0f\n    
print(prob) # prob\u662f10\u4e2a\u5206\u7c7b\u7684\u6982\u7387\n    pred = np.argmax(prob) # \u9009\u51fa\u6982\u7387\u6700\u5927\u7684\u4e00\u4e2a\n    print(pred)\n    print(pred.item())\n    pred_class = classes[pred]\n    print(pred_class)<\/span><span style=\"color: #800000;\">'''<\/span><\/pre>\n<\/div>\n<p>&nbsp;<\/p>\n","protected":false},"excerpt":{"rendered":"<p>import cv2 import numpy as np import os import pic [&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[1],"tags":[],"class_list":["post-153","post","type-post","status-publish","format-standard","hentry","category-uncategorized"],"yoast_head":"<!-- This site is optimized with the Yoast SEO plugin v18.6 - https:\/\/yoast.com\/wordpress\/plugins\/seo\/ -->\n<title>cifar-10-dataset - imwarming<\/title>\n<meta name=\"robots\" content=\"index, follow, max-snippet:-1, max-image-preview:large, max-video-preview:-1\" \/>\n<link rel=\"canonical\" href=\"https:\/\/imwarming.com\/?p=153\" \/>\n<meta property=\"og:locale\" content=\"zh_CN\" \/>\n<meta property=\"og:type\" content=\"article\" \/>\n<meta property=\"og:title\" content=\"cifar-10-dataset - imwarming\" \/>\n<meta property=\"og:description\" content=\"import cv2 import numpy as np import os import pic [&hellip;]\" \/>\n<meta property=\"og:url\" content=\"https:\/\/imwarming.com\/?p=153\" \/>\n<meta property=\"og:site_name\" content=\"imwarming\" \/>\n<meta property=\"article:published_time\" content=\"2020-11-19T04:31:00+00:00\" \/>\n<meta name=\"twitter:card\" content=\"summary_large_image\" \/>\n<meta name=\"twitter:label1\" content=\"\u4f5c\u8005\" \/>\n\t<meta name=\"twitter:data1\" content=\"warming\" \/>\n\t<meta name=\"twitter:label2\" content=\"\u9884\u8ba1\u9605\u8bfb\u65f6\u95f4\" \/>\n\t<meta name=\"twitter:data2\" content=\"7 \u5206\" \/>\n<script type=\"application\/ld+json\" 
class=\"yoast-schema-graph\">{\"@context\":\"https:\/\/schema.org\",\"@graph\":[{\"@type\":\"WebSite\",\"@id\":\"https:\/\/imwarming.com\/#website\",\"url\":\"https:\/\/imwarming.com\/\",\"name\":\"imwarming\",\"description\":\"\u6c38\u8fdc\u5e74\u8f7b\uff0c\u6c38\u8fdc\u70ed\u6cea\u76c8\u7736\",\"potentialAction\":[{\"@type\":\"SearchAction\",\"target\":{\"@type\":\"EntryPoint\",\"urlTemplate\":\"https:\/\/imwarming.com\/?s={search_term_string}\"},\"query-input\":\"required name=search_term_string\"}],\"inLanguage\":\"zh-Hans\"},{\"@type\":\"WebPage\",\"@id\":\"https:\/\/imwarming.com\/?p=153#webpage\",\"url\":\"https:\/\/imwarming.com\/?p=153\",\"name\":\"cifar-10-dataset - imwarming\",\"isPartOf\":{\"@id\":\"https:\/\/imwarming.com\/#website\"},\"datePublished\":\"2020-11-19T04:31:00+00:00\",\"dateModified\":\"2020-11-19T04:31:00+00:00\",\"author\":{\"@id\":\"https:\/\/imwarming.com\/#\/schema\/person\/9d76869a558bac6dd0d6d58f420ee8ea\"},\"breadcrumb\":{\"@id\":\"https:\/\/imwarming.com\/?p=153#breadcrumb\"},\"inLanguage\":\"zh-Hans\",\"potentialAction\":[{\"@type\":\"ReadAction\",\"target\":[\"https:\/\/imwarming.com\/?p=153\"]}]},{\"@type\":\"BreadcrumbList\",\"@id\":\"https:\/\/imwarming.com\/?p=153#breadcrumb\",\"itemListElement\":[{\"@type\":\"ListItem\",\"position\":1,\"name\":\"\u9996\u9875\",\"item\":\"https:\/\/imwarming.com\/\"},{\"@type\":\"ListItem\",\"position\":2,\"name\":\"cifar-10-dataset\"}]},{\"@type\":\"Person\",\"@id\":\"https:\/\/imwarming.com\/#\/schema\/person\/9d76869a558bac6dd0d6d58f420ee8ea\",\"name\":\"warming\",\"image\":{\"@type\":\"ImageObject\",\"@id\":\"https:\/\/imwarming.com\/#personlogo\",\"inLanguage\":\"zh-Hans\",\"url\":\"https:\/\/secure.gravatar.com\/avatar\/c4a913eed88f7601b76bbf2b103472621195b6fa2f742af89b5ea185b60e7cff?s=96&d=mm&r=g\",\"contentUrl\":\"https:\/\/secure.gravatar.com\/avatar\/c4a913eed88f7601b76bbf2b103472621195b6fa2f742af89b5ea185b60e7cff?s=96&d=mm&r=g\",\"caption\":\"warming\"},\"sameAs\":[\"https:\/\/imw
arming.com\"],\"url\":\"https:\/\/imwarming.com\/?author=1\"}]}<\/script>\n<!-- \/ Yoast SEO plugin. -->","yoast_head_json":{"title":"cifar-10-dataset - imwarming","robots":{"index":"index","follow":"follow","max-snippet":"max-snippet:-1","max-image-preview":"max-image-preview:large","max-video-preview":"max-video-preview:-1"},"canonical":"https:\/\/imwarming.com\/?p=153","og_locale":"zh_CN","og_type":"article","og_title":"cifar-10-dataset - imwarming","og_description":"import cv2 import numpy as np import os import pic [&hellip;]","og_url":"https:\/\/imwarming.com\/?p=153","og_site_name":"imwarming","article_published_time":"2020-11-19T04:31:00+00:00","twitter_card":"summary_large_image","twitter_misc":{"\u4f5c\u8005":"warming","\u9884\u8ba1\u9605\u8bfb\u65f6\u95f4":"7 \u5206"},"schema":{"@context":"https:\/\/schema.org","@graph":[{"@type":"WebSite","@id":"https:\/\/imwarming.com\/#website","url":"https:\/\/imwarming.com\/","name":"imwarming","description":"\u6c38\u8fdc\u5e74\u8f7b\uff0c\u6c38\u8fdc\u70ed\u6cea\u76c8\u7736","potentialAction":[{"@type":"SearchAction","target":{"@type":"EntryPoint","urlTemplate":"https:\/\/imwarming.com\/?s={search_term_string}"},"query-input":"required name=search_term_string"}],"inLanguage":"zh-Hans"},{"@type":"WebPage","@id":"https:\/\/imwarming.com\/?p=153#webpage","url":"https:\/\/imwarming.com\/?p=153","name":"cifar-10-dataset - 
imwarming","isPartOf":{"@id":"https:\/\/imwarming.com\/#website"},"datePublished":"2020-11-19T04:31:00+00:00","dateModified":"2020-11-19T04:31:00+00:00","author":{"@id":"https:\/\/imwarming.com\/#\/schema\/person\/9d76869a558bac6dd0d6d58f420ee8ea"},"breadcrumb":{"@id":"https:\/\/imwarming.com\/?p=153#breadcrumb"},"inLanguage":"zh-Hans","potentialAction":[{"@type":"ReadAction","target":["https:\/\/imwarming.com\/?p=153"]}]},{"@type":"BreadcrumbList","@id":"https:\/\/imwarming.com\/?p=153#breadcrumb","itemListElement":[{"@type":"ListItem","position":1,"name":"\u9996\u9875","item":"https:\/\/imwarming.com\/"},{"@type":"ListItem","position":2,"name":"cifar-10-dataset"}]},{"@type":"Person","@id":"https:\/\/imwarming.com\/#\/schema\/person\/9d76869a558bac6dd0d6d58f420ee8ea","name":"warming","image":{"@type":"ImageObject","@id":"https:\/\/imwarming.com\/#personlogo","inLanguage":"zh-Hans","url":"https:\/\/secure.gravatar.com\/avatar\/c4a913eed88f7601b76bbf2b103472621195b6fa2f742af89b5ea185b60e7cff?s=96&d=mm&r=g","contentUrl":"https:\/\/secure.gravatar.com\/avatar\/c4a913eed88f7601b76bbf2b103472621195b6fa2f742af89b5ea185b60e7cff?s=96&d=mm&r=g","caption":"warming"},"sameAs":["https:\/\/imwarming.com"],"url":"https:\/\/imwarming.com\/?author=1"}]}},"_links":{"self":[{"href":"https:\/\/imwarming.com\/index.php?rest_route=\/wp\/v2\/posts\/153","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/imwarming.com\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/imwarming.com\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/imwarming.com\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/imwarming.com\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=153"}],"version-history":[{"count":0,"href":"https:\/\/imwarming.com\/index.php?rest_route=\/wp\/v2\/posts\/153\/revisions"}],"wp:attachment":[{"href":"https:\/\/imwarming.com\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&pa
rent=153"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/imwarming.com\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=153"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/imwarming.com\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=153"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}