""" Contains MNIST dataset """
import os
import logging
import tempfile
import urllib.request
import gzip
import PIL
import tqdm
import numpy as np
from . import ImagesOpenset
from ..decorators import parallel, any_action_failed
logger = logging.getLogger('mnist')
[docs]class MNIST(ImagesOpenset):
""" MNIST dataset
Examples
--------
::
# download MNIST data, split into train/test and create dataset instances
mnist = MNIST()
# iterate over the dataset
for batch in mnist.train.gen_batch(BATCH_SIZE, shuffle=True, n_epochs=2):
# do something with a batch
# download MNIST data and show progress bar
mnist = MNIST(bar=True)
"""
TRAIN_IMAGES_URL = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
TRAIN_LABELS_URL = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"
TEST_IMAGES_URL = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
TEST_LABELS_URL = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
ALL_URLS = [TRAIN_IMAGES_URL, TRAIN_LABELS_URL, TEST_IMAGES_URL, TEST_LABELS_URL]
num_classes = 10
def __init__(self, *args, bar=False, preloaded=None, train_test=True, **kwargs):
self.bar = tqdm.tqdm(total=8) if bar else None
super().__init__(*args, preloaded=preloaded, train_test=train_test, **kwargs)
if self.bar:
self.bar.close()
@property
def _get_from_urls(self):
""" List of URLs and type of content (0 - images, 1 - labels) """
return [[self.ALL_URLS[i], i % 2] for i in range(len(self.ALL_URLS))]
def _gather_data(self, all_res, *args, **kwargs):
_ = args, kwargs
if any_action_failed(all_res):
raise IOError('Could not download files:', all_res)
images = np.concatenate([all_res[0], all_res[2]])
labels = np.concatenate([all_res[1], all_res[3]])
preloaded = images, labels
train_len, test_len = len(all_res[0]), len(all_res[2])
index, train_index, test_index = self._infer_train_test_index(train_len, test_len)
return preloaded, index, train_index, test_index
[docs] @parallel(init='_get_from_urls', post='_gather_data', target='t')
def download(self, url, content, path=None):
""" Load data from the web site """
logger.info('Downloading %s', url)
if path is None:
path = tempfile.gettempdir()
filename = os.path.basename(url)
localname = os.path.join(path, filename)
if not os.path.isfile(localname):
opener = urllib.request.URLopener()
opener.addheader('User-agent', 'Mozilla/5.0') # https://github.com/pytorch/vision/issues/1938
opener.retrieve(url, localname)
logger.info("Downloaded %s", filename)
if self.bar:
self.bar.update(1)
with open(localname, 'rb') as f:
data = self._extract_images(f) if content == 0 else self._extract_labels(f)
if self.bar:
self.bar.update(1)
return data
#
# _read32, extract_images, extract_labels are taken from tensorflow
#
@staticmethod
def _read32(bytestream):
dtype = np.dtype(np.uint32).newbyteorder('>')
return np.frombuffer(bytestream.read(4), dtype=dtype)[0]
def _extract_images(self, f):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth].
Args:
f: A file object that can be passed into a gzip reader.
Returns:
data: A 4D uint8 numpy array [index, y, x, depth].
Raises:
ValueError: If the bytestream does not start with 2051.
"""
logger.info('Extracting %s', f.name)
with gzip.GzipFile(fileobj=f) as bytestream:
magic = self._read32(bytestream)
if magic != 2051:
raise ValueError(f"Invalid magic number {magic} in MNIST image file: {f.name} (expected 2051")
num_images = self._read32(bytestream)
rows = self._read32(bytestream)
cols = self._read32(bytestream)
buf = bytestream.read(rows * cols * num_images)
data = np.frombuffer(buf, dtype=np.uint8)
data = data.reshape(num_images, rows, cols)
return self.create_array([PIL.Image.fromarray(image) for image in data])
def _extract_labels(self, f):
"""Extract the labels into a 1D uint8 numpy array [index].
Args:
f: A file object that can be passed into a gzip reader.
Returns:
labels: a 1D uint8 numpy array.
Raises:
ValueError: If the bystream doesn't start with 2049.
"""
logger.info('Extracting %s', f.name)
with gzip.GzipFile(fileobj=f) as bytestream:
magic = self._read32(bytestream)
if magic != 2049:
raise ValueError(f"Invalid magic number {magic} in MNIST label file: {f.name} (expected 2049)")
num_items = self._read32(bytestream)
buf = bytestream.read(num_items)
labels = np.frombuffer(buf, dtype=np.uint8)
return labels