Introduction to the tf.data API

Before continuing, we will take a look at how TensorFlow handles data input to any kind of model we might train. The TensorFlow tf.data API provides all the tools needed to easily build complex input pipelines. A common pipeline might involve loading raw training data, preprocessing it, shuffling it, and then batching it ready for training. The tf.data API lets us perform all of these steps with simple, reusable pieces of code.

The tf.data API has two main components that you need to understand. The first is the tf.data.Dataset; this represents your raw data. More specifically, it holds a sequence of elements, where each element contains one or more tensor objects. For the task of image classification, one element would be a single training example consisting of two tensors: one for the image and one for its corresponding label.
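As a quick sketch (the NumPy arrays here are dummy stand-ins for real data, not part of the example that follows later), a dataset whose elements are image/label pairs can be built and inspected like this; tf.data.Dataset.from_tensor_slices() is covered in more detail shortly:

    import numpy as np
    import tensorflow as tf

    # Dummy stand-ins for real data: four 32x32 RGB images and four integer labels
    images = np.zeros((4, 32, 32, 3), dtype=np.float32)
    labels = np.array([0, 1, 2, 3], dtype=np.int64)

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    # Each element of this dataset is an (image, label) pair of tensors
    print(dataset.output_types)   # (tf.float32, tf.int64)
    print(dataset.output_shapes)  # image: (32, 32, 3), label: scalar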

The second component is the tf.data.Iterator. Iterators allow you to extract elements from a dataset and act as the connection between your dataset and your model code. TensorFlow provides several types of iterators that serve different purposes and involve varying levels of difficulty to use.

Creating a dataset can be done in two ways. The first way is by creating a data source. An easy example is using tf.data.Dataset.from_tensor_slices(), which will create a dataset from slices of one or more Tensor objects. The other way of producing datasets is to use a dataset transformation on an existing dataset. Doing so will return a new dataset that incorporates the applied transformation. It's important to understand that all input pipelines must begin with a data source. Once you have a Dataset object, it is common to apply multiple transformations, all chained together, to it.
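As a small illustration (using a dummy array of ten numbers rather than real training data), a pipeline starts from a data source and then chains transformations onto the resulting Dataset object:

    import numpy as np
    import tensorflow as tf

    # Data source: a dataset created from slices of an in-memory array
    features = np.arange(10, dtype=np.float32)
    dataset = tf.data.Dataset.from_tensor_slices(features)

    # Transformations: each call returns a new dataset, so they chain naturally
    dataset = dataset.map(lambda x: x * 2.0)   # preprocess every element
    dataset = dataset.shuffle(buffer_size=10)  # shuffle with a ten-element buffer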

For now, two examples of simple transformations are Dataset.batch(), which returns batches of a set size from your Dataset object, and Dataset.repeat(), which restarts the dataset from the beginning whenever the end is reached. This makes it easy to iterate over a dataset many times; the optional count parameter sets how many times the data is repeated, and omitting it repeats indefinitely.
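For example, here is a minimal sketch (again on a dummy ten-element dataset) of how repeat() and batch() combine, with the count argument bounding the number of passes:

    import numpy as np
    import tensorflow as tf

    dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))
    dataset = dataset.repeat(2)  # yield the ten elements twice; omit count to repeat forever
    dataset = dataset.batch(4)   # group consecutive elements into batches of four

    # This pipeline yields five batches of four elements before being exhausted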

Now that we have a dataset set up, we can use a tf.data.Iterator to iterate over and extract elements from it. Again, there are several kinds of iterators available, but the simplest one, which we will use here, is the one-shot iterator. It supports going through a dataset only once, but it is very simple to set up. We create it by calling the make_one_shot_iterator() method on our dataset and assigning the result to a variable. We can then call get_next() on the created iterator and assign the returned op to another variable.

Now, whenever this op is run in a session, a new batch will be extracted from the dataset to use; running it repeatedly walks through the dataset exactly once:

    def train(self, save_dir='./save', batch_size=500):
        # Use Keras to load the complete CIFAR-10 dataset into memory (not scalable)
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

        # Convert class vectors to binary class matrices (one-hot labels)
        y_train = tf.keras.utils.to_categorical(y_train, 10)
        y_test = tf.keras.utils.to_categorical(y_test, 10)

        # Use the TensorFlow tf.data API to handle shuffling, repetition, and batching
        dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        dataset_train = dataset_train.shuffle(buffer_size=10000)
        dataset_train = dataset_train.repeat()
        dataset_train = dataset_train.batch(batch_size)
        dataset_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
        dataset_test = dataset_test.repeat()
        dataset_test = dataset_test.batch(batch_size)

        # Create one-shot iterators and the ops that extract the next batch
        iter_train = dataset_train.make_one_shot_iterator()
        iter_train_op = iter_train.get_next()
        iter_test = dataset_test.make_one_shot_iterator()
        iter_test_op = iter_test.get_next()

        # Build model graph
        self.build_graph()
        # ... (the remainder of the train() method is not shown)
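Although the rest of the method is not shown, the iterator ops above would typically be consumed inside a session loop, in the usual TensorFlow 1.x style. The following is a minimal hypothetical sketch; train_op, loss, x_ph, and y_ph are stand-in names for whatever ops and placeholders build_graph() actually defines:

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(1000):
            # Each run of iter_train_op pulls the next batch out of the pipeline
            batch_x, batch_y = sess.run(iter_train_op)
            # Hypothetical training step using ops created by build_graph():
            # _, loss_value = sess.run([train_op, loss],
            #                          feed_dict={x_ph: batch_x, y_ph: batch_y})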
