I've met quite a few data practitioners who scorn sampling. Ideally, if one can process the whole dataset, the model can only improve. In practice, the tradeoff is much more complex. First, one can build more complex models on a sampled set, particularly if the time complexity of model building is non-linear, and in most situations it is at least N*log(N). A faster model-building cycle allows you to iterate over models and converge on the best approach faster. In many situations, time to action beats the potential improvement in prediction accuracy that a model built on the complete dataset might bring.
Sampling may be combined with appropriate filtering; in many practical situations, focusing on one subproblem at a time leads to a better understanding of the whole problem domain. In many cases, this partitioning is at the foundation of the algorithm, as in decision trees, which are considered later. Often the nature of the problem requires you to focus on a subset of the original data. For example, a cyber security analysis is often focused on a specific set of IPs rather than the whole network, as this allows one to iterate over hypotheses faster. Including the set of all IPs in the network may complicate things initially, if not throw the modeling off track altogether.
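Such a focusing filter is just a predicate applied to each row before, or together with, sampling. The following is a minimal sketch; the file name, the set of IPs, and the assumption that the source IP is the first comma-separated field are made up for illustration:

import scala.io.Source

// Hypothetical set of IPs under investigation
val ipsOfInterest = Set("10.1.1.15", "10.1.1.16")

// Keep only the rows whose first comma-separated field is an IP of interest
val lines = Source.fromFile("chapter01/data/network/in.txt").getLines
val focused = lines.filter(line => ipsOfInterest.contains(line.split(",")(0)))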
When dealing with rare events, such as clickthroughs in ad tech, sampling the positive and negative cases with different probabilities, which is also sometimes called oversampling, often leads to better predictions in a short amount of time.
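The following is a minimal sketch of such class-dependent sampling on a row-by-row basis; the file clicks.csv, the assumption that the last comma-separated field is a 0/1 click label, and the two rates are all hypothetical:

import scala.util.Random

// Keep all of the rare positives and only 1% of the negatives (arbitrary rates)
val positiveRate = 1.0
val negativeRate = 0.01

def keep(line: String): Boolean = {
  val isPositive = line.split(",").last == "1" // assumed 0/1 click label in the last field
  Random.nextDouble() <= (if (isPositive) positiveRate else negativeRate)
}

val sampled = scala.io.Source.fromFile("clicks.csv").getLines.filter(keep)

Keep in mind that the class proportions in the result are distorted on purpose, so predicted probabilities have to be calibrated back to the original class prior before they are used.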
Fundamentally, sampling is equivalent to tossing a coin, or calling a random number generator, for each data row. Thus it is very much like a stream filter operation, where the filtering is performed on an augmented column of random numbers. Let's consider the following example:
import scala.util.Random
import util.Properties

val threshold = 0.05
val lines = scala.io.Source.fromFile("chapter01/data/iris/in.txt").getLines
val newLines = lines.filter(_ => Random.nextDouble() <= threshold)

val w = new java.io.FileWriter(new java.io.File("out.txt"))
newLines.foreach { s =>
  w.write(s + Properties.lineSeparator)
}
w.close
This is all good, but it has the following disadvantages:

- The number of lines in the resulting file is not known in advance; it is only approximately the threshold fraction of the original file and varies from run to run.
- The result of the sampling is not reproducible; each run selects a different random subset of rows, so two sampled datasets cannot be expected to contain consistent subsets.
To fix the first point, we'll need to pass a more complex object to the function, as we need to maintain state during the original list traversal, which makes the original algorithm less functional and less parallelizable (this will be discussed later):
import scala.reflect.ClassTag
import scala.util.Random
import util.Properties

def reservoirSample[T: ClassTag](input: Iterator[T], k: Int): Array[T] = {
  val reservoir = new Array[T](k)
  // Put the first k elements in the reservoir.
  var i = 0
  while (i < k && input.hasNext) {
    val item = input.next()
    reservoir(i) = item
    i += 1
  }
  if (i < k) {
    // If input size < k, trim the array size
    reservoir.take(i)
  } else {
    // If input size > k, continue the sampling process.
    while (input.hasNext) {
      val item = input.next()
      i += 1
      // The i-th element should survive in the reservoir with probability k/i
      val replacementIndex = Random.nextInt(i)
      if (replacementIndex < k) {
        reservoir(replacementIndex) = item
      }
    }
    reservoir
  }
}

val numLines = 15
val w = new java.io.FileWriter(new java.io.File("out.txt"))
val lines = io.Source.fromFile("chapter01/data/iris/in.txt").getLines
reservoirSample(lines, numLines).foreach { s =>
  w.write(s + Properties.lineSeparator)
}
w.close
This will output numLines lines. Similarly to reservoir sampling, stratified sampling is guaranteed to provide the same ratio of input/output rows for all strata defined by the levels of another attribute. We can achieve this by splitting the original dataset into N subsets corresponding to the levels, performing reservoir sampling on each of them, and merging the results afterwards. However, the MLlib library, which will be covered in Chapter 3, Working with Spark and MLlib, already has a stratified sampling implementation:
val origLinesRdd = sc.textFile("file://...")
val keyedRdd = origLinesRdd.keyBy(r => r.split(",")(0))
val fractions = keyedRdd.countByKey.keys.map(r => (r, 0.1)).toMap
val sampledWithKey = keyedRdd.sampleByKeyExact(withReplacement = false, fractions)
val sampled = sampledWithKey.map(_._2).collect
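For comparison, the manual split/sample/merge approach described above can be sketched in plain Scala, reusing the reservoirSample function defined earlier. This sketch assumes the data fits in memory and, as in the MLlib example, that the stratum key is the first comma-separated field:

// Group rows by stratum, reservoir-sample the same fraction from each stratum,
// and merge the per-stratum samples back together.
val fraction = 0.1
val allLines = scala.io.Source.fromFile("chapter01/data/iris/in.txt").getLines.toSeq
val stratifiedSample = allLines
  .groupBy(_.split(",")(0))
  .values
  .flatMap { stratum =>
    val k = math.max(1, (fraction * stratum.size).round.toInt)
    reservoirSample(stratum.iterator, k)
  }
  .toSeq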
The other bullet point is more subtle; sometimes we want a consistent subset of values across multiple datasets, either for reproducibility or to join with another sampled dataset. In general, if we sample two datasets, the results will contain random subsets of IDs which might have very little or no intersection. Cryptographic hashing functions come to the rescue here. The result of applying a hash function such as MD5 or SHA1 is a sequence of bits that is, at least in theory, statistically uncorrelated with the input. We will use the MurmurHash function, which is part of the scala.util.hashing package:
import scala.util.hashing.MurmurHash3._
import util.Properties

val markLow = 0
val markHigh = 4096
val seed = 12345

def consistentFilter(s: String): Boolean = {
  // Hash the first field and keep only the upper 16 bits of the 32-bit hash
  val hash = stringHash(s.split(" ")(0), seed) >>> 16
  hash >= markLow && hash < markHigh
}

val w = new java.io.FileWriter(new java.io.File("out.txt"))
val lines = io.Source.fromFile("chapter01/data/iris/in.txt").getLines
lines.filter(consistentFilter).foreach { s =>
  w.write(s + Properties.lineSeparator)
}
w.close
This function is guaranteed to return exactly the same subset of records based on the value of the first field: it is either all records where the first field equals a certain value, or none of them. It will come up with approximately one-sixteenth of the original sample, since the range of hash is 0 to 65,535 and the [markLow, markHigh) window covers 4,096 of these values.
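For instance, to keep roughly 10 percent of the records instead, the window can be widened accordingly; this adjustment is hypothetical and not part of the original listing. Shifting markLow while keeping the window width the same yields a different, non-overlapping consistent sample:

// Keep hashes in [0, 6554), roughly a tenth of the 65,536 possible values
val markLow = 0
val markHigh = 6554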