
Unit testing, Apache Spark, and Java are three things you’ll rarely see together. And yes, all three are possible and work well together.

Why Unit Test With Spark?

I’m not an advocate of TDD (Test-Driven Development), except when I’m writing Big Data code. No matter how small you make your sample dataset, running the whole job still takes longer to turn around than you’d like. Worse, you aren’t testing just the specific code you care about: you have to wait for the compiler, and then for the sample data to eventually reach that code.

All of this is time wasted waiting for code to run. Your focus should be on fast turnaround for fixes and new code. The faster you can check that your code works correctly, the more productive you’ll be and the better your code will be.

When you have a problem piece of data, you should be able to replicate that problem in less than one minute. You should already have a similar unit test written that you can copy and paste. Then, you put the problem data in and verify that the unit test fails. Finally, you fix the issue and rerun the test to verify that it passes.
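One way to make that copy-and-paste step even cheaper is a table of regression cases, where each new problem record is a single added row. The sketch below is plain Java with no Spark dependency. It assumes the per-record parsing logic is available as an ordinary static method (the hypothetical `CardRegression.parse`, which mirrors the card-parsing lambda used later in this post and assumes tab-separated input). The record `"Ten\tHearts"` is an invented example of problem data, and `Map.Entry` stands in for Scala’s `Tuple2`.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

// Hypothetical regression table for the card-parsing logic. Each row is
// an input line plus the suit and value the parser should produce.
class CardRegression {

    // Mirrors the mapToPair lambda: "<value>\t<suit>" -> (suit, value).
    static Map.Entry<String, Integer> parse(String line) {
        String[] split = line.split("\t");
        int cardValue = Integer.parseInt(split[0]);
        String cardSuit = split[1];
        return new SimpleImmutableEntry<>(cardSuit, cardValue);
    }

    // One regression case: input line, expected suit, expected value.
    static final class Case {
        final String input;
        final String suit;
        final int value;

        Case(String input, String suit, int value) {
            this.input = input;
            this.suit = suit;
            this.value = value;
        }
    }

    // Returns the inputs whose parsed output differs from the expectation.
    // Any exception thrown by parse() also counts as a failure.
    static List<String> failingInputs(List<Case> cases) {
        List<String> failures = new ArrayList<>();
        for (Case c : cases) {
            try {
                Map.Entry<String, Integer> actual = parse(c.input);
                if (!actual.getKey().equals(c.suit) || actual.getValue() != c.value) {
                    failures.add(c.input);
                }
            } catch (RuntimeException e) {
                failures.add(c.input);
            }
        }
        return failures;
    }

    static List<Case> cases() {
        return Arrays.asList(
            new Case("1\tHeart", "Heart", 1),
            new Case("2\tDiamonds", "Diamonds", 2),
            // Newly pasted problem record (invented example). It fails until
            // parse() is fixed to handle spelled-out card values.
            new Case("Ten\tHearts", "Hearts", 10));
    }
}
```

Running `failingInputs(cases())` reports only the pasted problem record, which is the "verify the test fails" step; after fixing `parse`, the same call should come back empty.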

Anyone who has gone through my Big Data classes has heard this speech. They’ve also thanked me at 3 AM when they could quickly diagnose a problem and fix it.

How to Unit Test Spark

Before Strata NYC 2015 in October 2015, there wasn’t a ready-made way to unit test your Spark code. That isn’t something most Data Scientists cared about, but Data Engineers did. At this conference, Holden Karau released her Spark unit testing framework, called Spark Testing Base (spark-testing-base).

It gives you access to the test suite base classes that Apache Spark uses internally but doesn’t expose publicly. She also added other goodness aimed at making it easier to unit test client code.

Refactoring Your Code

If you’re familiar with Spark and unit testing, you’re probably wondering how you can test your code. The short answer is that it’s going to take some refactoring.

Let’s say we started out with the following code:

    // Load our input data.
    JavaRDD<String> input = sc.textFile(inputFile);

    // Split up into suits and numbers and
    // transform into pairs
    JavaPairRDD<String, Integer> suitsAndValues = input.mapToPair(w -> {
        String[] split = w.split("\t");

        int cardValue = Integer.parseInt(split[0]);
        String cardSuit = split[1];
        return new Tuple2<String, Integer>(cardSuit, cardValue);
    });

    // Sum the card values for each suit
    JavaPairRDD<String, Integer> counts = suitsAndValues.reduceByKey((x, y) -> x + y);

    counts.saveAsTextFile(output);

Let’s say we wanted to unit test the creation of the suitsAndValues RDD. As the code stands, we can’t do that: there is no way for us to reach into the lambda function passed to mapToPair.

We need to refactor this code to be unit testable. How?

    // Load our input data.
    JavaRDD<String> input = sc.textFile(inputFile);

    JavaPairRDD<String, Integer> suitsAndValues = runETL(input);

    // Sum the card values for each suit and save the result
    JavaPairRDD<String, Integer> counts = suitsAndValues.reduceByKey((x, y) -> x + y);
    counts.saveAsTextFile(output);

In this snippet, we’ve moved the entire mapToPair call into its own method. Let’s take a look at the method:

public static JavaPairRDD<String, Integer> runETL(
    JavaRDD<String> input) {
    // Split up into suits and numbers and transform into pairs
    JavaPairRDD<String, Integer> suitsAndValues = input.mapToPair(w -> {
        String[] split = w.split("\t");

        int cardValue = Integer.parseInt(split[0]);
        String cardSuit = split[1];

        return new Tuple2<String, Integer>(cardSuit, cardValue);
    });

    return suitsAndValues;
}

This method now takes an RDD as input, runs the lambda function on it, and returns a JavaPairRDD. This way, we can pass in different JavaRDDs as input and check the resulting JavaPairRDD.
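A complementary refactoring, not part of the original code, is to move the lambda body itself into a named static method. The per-record logic can then be tested with plain JUnit calls, while the runETL test still covers the Spark wiring. In a real job the method would return `scala.Tuple2<String, Integer>` so it could be passed as a method reference, e.g. `input.mapToPair(RefactoredForTests::toSuitAndValue)`. The sketch below returns `Map.Entry` instead so it compiles and runs without Spark on the classpath; `CardParsing` and `toSuitAndValue` are hypothetical names.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Map;

// Hypothetical holder class. In the post, the equivalent method would sit
// in RefactoredForTests next to runETL and return scala.Tuple2.
class CardParsing {

    // The former lambda body as a named method that a test can call
    // directly. Expects lines shaped like "<value>\t<suit>".
    static Map.Entry<String, Integer> toSuitAndValue(String line) {
        String[] split = line.split("\t");
        int cardValue = Integer.parseInt(split[0]);
        String cardSuit = split[1];
        return new SimpleImmutableEntry<>(cardSuit, cardValue);
    }
}
```

The trade-off is one more method per transformation, in exchange for record-level tests that run in milliseconds without starting a SparkContext.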

Adding the Unit Tests

You’ll need to add Holden’s spark-testing-base to your Maven POM. For Spark 1.6, it also has a runtime dependency on Spark Hive (spark-hive), which you can add with test scope as well.

<dependency>
    <groupId>com.holdenkarau</groupId>
    <artifactId>spark-testing-base_2.10</artifactId>
    <version>${spark.version}_0.3.3</version>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-hive_2.10</artifactId>
    <version>${spark.version}</version>
    <scope>test</scope>
</dependency>

Unit Testing the Code

To start off, let’s take a look at the test class, which extends spark-testing-base’s SharedJavaSparkContext:

import static org.junit.Assert.assertEquals;

import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.junit.Test;

import com.holdenkarau.spark.testing.JavaRDDComparisons;
import com.holdenkarau.spark.testing.SharedJavaSparkContext;

import scala.Option;
import scala.Serializable;
import scala.Tuple2;
import scala.reflect.ClassTag;
import unittests.RefactoredForTests;

public class SparkTest extends SharedJavaSparkContext implements
    Serializable {
    private static final long serialVersionUID = -5681683598336701496L;
}

You can see that we’re importing some of the spark-testing-base classes.

There are some Scala classes too. spark-testing-base is written in Scala, just like Spark. Some of the objects that spark-testing-base passes back are Scala classes. Don’t worry, the code samples will show how to use them.

Here is an actual test:

@Test
public void verifyMapTest() {
    // Create and run the test
    List<String> input = Arrays.asList("1\tHeart", "2\tDiamonds");
    JavaRDD<String> inputRDD = jsc().parallelize(input);
    JavaPairRDD<String, Integer> result = RefactoredForTests
                                          .runETL(inputRDD);

    // Create the expected output
    List<Tuple2<String, Integer>> expectedInput = Arrays.asList(
                new Tuple2<String, Integer>("Heart", 1),
                new Tuple2<String, Integer>("Diamonds", 2));
    JavaPairRDD<String, Integer> expectedRDD = jsc()
            .parallelizePairs(expectedInput);

    ClassTag<Tuple2<String, Integer>> tag =
        scala.reflect.ClassTag$.MODULE$
        .apply(Tuple2.class);

    // Run the assertions on the result and expected
    JavaRDDComparisons.assertRDDEquals(
        JavaRDD.fromRDD(JavaPairRDD.toRDD(result), tag),
        JavaRDD.fromRDD(JavaPairRDD.toRDD(expectedRDD), tag));
}

The unit test starts off by creating the input RDD. This is the RDD that we’re going to pass into the refactored method for processing.

Next, we create the expected output RDD. This is what we’re expecting the output of the refactored method to be.

Then, we create a ClassTag object. JavaRDD.fromRDD needs it to carry the element type (Tuple2) through the conversion from JavaPairRDD to JavaRDD, because Scala uses ClassTags to keep type information that the JVM erases at runtime.

Using the JavaRDDComparisons object’s assertRDDEquals method, we compare the two RDDs. If there are any differences, the assertion fails, and so does the test.
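To make "compare the two RDDs" more concrete: if I read spark-testing-base correctly, assertRDDEquals ignores element order and instead checks that each distinct element occurs the same number of times in both RDDs. The plain-Java sketch below illustrates that idea on in-memory lists. It is not the library’s implementation; `UnorderedCompare` and `anyCountMismatch` are made-up names, and strings such as `"Heart:1"` stand in for the `Tuple2` elements.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Illustrative only: order-insensitive comparison of two collections by
// counting how often each element occurs (multiset equality).
class UnorderedCompare {

    // Counts occurrences of each distinct element.
    static <T> Map<T, Integer> counts(List<T> items) {
        Map<T, Integer> counts = new HashMap<>();
        for (T item : items) {
            counts.merge(item, 1, Integer::sum);
        }
        return counts;
    }

    // Returns some element whose occurrence count differs between the two
    // lists, or Optional.empty() if they hold the same elements.
    static <T> Optional<T> anyCountMismatch(List<T> expected, List<T> result) {
        Map<T, Integer> expectedCounts = counts(expected);
        Map<T, Integer> resultCounts = counts(result);
        for (Map.Entry<T, Integer> e : expectedCounts.entrySet()) {
            if (!e.getValue().equals(resultCounts.get(e.getKey()))) {
                return Optional.of(e.getKey());
            }
        }
        for (T key : resultCounts.keySet()) {
            if (!expectedCounts.containsKey(key)) {
                return Optional.of(key);
            }
        }
        return Optional.empty();
    }
}
```

Because only counts matter, `["Heart:1", "Diamonds:2"]` and `["Diamonds:2", "Heart:1"]` compare as equal, which is what you usually want when partitioning can shuffle the output order.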

Negative Tests

We also need to write negative tests. In these tests, we deliberately use wrong data and check that the difference is detected and reported the way we expect.

Here is an example of a negative test:

@Test
public void verifyFailureTest() {
    // Create and run the test
    List<String> input = Arrays.asList("1\tHeart", "2\tDiamonds");
    JavaRDD<String> inputRDD = jsc().parallelize(input);
    JavaPairRDD<String, Integer> result = RefactoredForTests
                                          .runETL(inputRDD);

    // Create the expected output
    List<Tuple2<String, Integer>> expectedInput = Arrays.asList(
                new Tuple2<String, Integer>("Heart", 1),
                new Tuple2<String, Integer>("Diamonds", 12));
    JavaPairRDD<String, Integer> expectedRDD = jsc()
            .parallelizePairs(expectedInput);

    // Create ClassTag to allow class reflection
    ClassTag<Tuple2<String, Integer>> tag =
        scala.reflect.ClassTag$.MODULE$
        .apply(Tuple2.class);

    // Get the pair of Tuple2s that don't match
    Option<Tuple2<Option<Tuple2<String, Integer>>, Option<Tuple2<String, Integer>>>>
        compareWithOrder = JavaRDDComparisons.compareWithOrder(
            JavaRDD.fromRDD(JavaPairRDD.toRDD(result), tag),
            JavaRDD.fromRDD(JavaPairRDD.toRDD(expectedRDD), tag));

    // Create the objects that don't match
    Option<Tuple2<String, Integer>> rightTuple =
        Option.apply(new Tuple2<String, Integer>("Diamonds", 2));
    Option<Tuple2<String, Integer>> wrongTuple =
        Option.apply(new Tuple2<String, Integer>("Diamonds", 12));

    Tuple2<Option<Tuple2<String, Integer>>, Option<Tuple2<String, Integer>>> together =
        new Tuple2<Option<Tuple2<String, Integer>>, Option<Tuple2<String, Integer>>>(
            rightTuple, wrongTuple);

    // Assert the right and wrong values
    Option<Tuple2<Option<Tuple2<String, Integer>>, Option<Tuple2<String, Integer>>>>
        wrongValue = Option.apply(together);
    assertEquals(wrongValue, compareWithOrder);
}

The test starts out very similar to the positive test above. We create the input and the expected output, and run the method to get its output.

In this case, we’ve purposely put a difference between the expected and actual output. To assert on that difference, we need to build objects that describe it.

In this example, we create several Tuple2 objects to encapsulate the difference: one holds Diamonds with the actual value of 2, and the other holds Diamonds with the deliberately wrong expected value of 12.

Finally, we assert that the output of JavaRDDComparisons.compareWithOrder equals the mismatched pair of Tuple2s that we created.
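The nested return type of compareWithOrder (an Option wrapping a Tuple2 of two Options) is easier to read once you see the same shape in plain Java. The sketch below compares two lists position by position and reports a mismatch as a pair of Optionals, in the order the lists were passed; an empty side means that list ran out of elements first. This illustrates the ordered-comparison idea only and is not spark-testing-base’s code: `OrderedCompare` and `firstMismatch` are hypothetical, `Optional` stands in for Scala’s `Option`, `Map.Entry` for `Tuple2`, and strings for the card tuples.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Illustrative only: position-by-position comparison that reports the
// first index where two lists disagree, as a pair of Optionals.
class OrderedCompare {

    // Element at index i, or empty if the list is shorter than that.
    // Assumes the lists contain no null elements.
    static <T> Optional<T> at(List<T> list, int i) {
        return i < list.size() ? Optional.of(list.get(i)) : Optional.<T>empty();
    }

    // Returns (leftElement, rightElement) for the first differing index,
    // or Optional.empty() if both lists are equal element by element.
    static <T> Optional<Map.Entry<Optional<T>, Optional<T>>> firstMismatch(
            List<T> left, List<T> right) {
        int longest = Math.max(left.size(), right.size());
        for (int i = 0; i < longest; i++) {
            Optional<T> l = at(left, i);
            Optional<T> r = at(right, i);
            if (!l.equals(r)) {
                Map.Entry<Optional<T>, Optional<T>> pair = new SimpleImmutableEntry<>(l, r);
                return Optional.of(pair);
            }
        }
        return Optional.empty();
    }
}
```

With the negative test’s data, comparing `["Heart:1", "Diamonds:2"]` against `["Heart:1", "Diamonds:12"]` yields the pair `("Diamonds:2", "Diamonds:12")`, mirroring the rightTuple/wrongTuple pair asserted above.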

Testing Heavily Nested/Chained Processing

Some Spark code uses heavily nested or chained calls. The code snippet below gives an example:

    input.map().map().reduceByKey().map().map()

When you see code like this, I want you to think about my 3 AM principle. If you got a call at 3 AM, would you understand this code? Could you actually debug the issue? I’m going to say no. This is a matter of style and opinion, but I don’t like code that is this heavily chained. I’d separate each step into its own RDD, which makes it vastly easier to unit test. With too much chaining, you won’t know which step has the bug in it. With very granular unit tests, you will know exactly which one failed.

Here’s an example of what this would look like:

    JavaRDD data1 = doSomething(inputRDD);
    JavaRDD data2 = doSomething2(data1);
    JavaRDD data3 = doSomethingWithAReduce(data2);
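To show why the separated version is easier to test, here is a plain-Java sketch of the card example split into named stages, where plain lists stand in for RDDs so it runs without Spark. Each stage takes the previous stage’s output, so a test can feed it hand-built input and check its output alone. `CardPipeline`, `parseCards`, and `sumBySuit` are hypothetical names, and the tab-separated input format is an assumption.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Illustrative only: each stage is a named method with its own input and
// output, so each one can be tested in isolation. Lists stand in for RDDs.
class CardPipeline {

    // Stage 1: "<value>\t<suit>" lines -> (suit, value) pairs.
    static List<Map.Entry<String, Integer>> parseCards(List<String> lines) {
        List<Map.Entry<String, Integer>> pairs = new ArrayList<>();
        for (String line : lines) {
            String[] split = line.split("\t");
            pairs.add(new SimpleImmutableEntry<>(split[1], Integer.parseInt(split[0])));
        }
        return pairs;
    }

    // Stage 2: sum the values per suit (the reduceByKey step).
    // TreeMap keeps the suits sorted so the output order is deterministic.
    static Map<String, Integer> sumBySuit(List<Map.Entry<String, Integer>> pairs) {
        Map<String, Integer> sums = new TreeMap<>();
        for (Map.Entry<String, Integer> pair : pairs) {
            sums.merge(pair.getKey(), pair.getValue(), Integer::sum);
        }
        return sums;
    }

    // The full job is just the stages composed in order.
    static Map<String, Integer> run(List<String> lines) {
        return sumBySuit(parseCards(lines));
    }
}
```

If a test of `run` fails, separate tests of `parseCards` and `sumBySuit` tell you which stage is at fault, which is the point of the 3 AM principle.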

Coming from MRUnit

If you’re coming from an MRUnit background, I want to point out a few differences. One big difference is how Spark is started. MRUnit didn’t start the entire framework; spark-testing-base does start Spark. This means that spark-testing-base is slower than MRUnit during initialization. Even so, it’s still much faster than executing an entire Spark job and waiting for it to hit a breakpoint.

Speaking of initialization, unlike MRUnit, you don’t need to reset spark-testing-base between subsequent tests.
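SharedJavaSparkContext does this work for you. As I understand it, the context is created once and then reused across the tests in a class, so the startup cost is paid only once. The plain-Java sketch below shows that general shared-fixture pattern, not spark-testing-base’s code. `SharedContextDemo` and `FakeContext` are hypothetical names, and `FakeContext` only counts how many times it is constructed instead of starting Spark.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Illustrative only: the "start once, reuse everywhere" idea behind a
// shared test context. FakeContext is a hypothetical stand-in for an
// expensive resource such as a local SparkContext.
class SharedContextDemo {

    // Counts constructions so we can see the expensive start ran once.
    static final AtomicInteger STARTS = new AtomicInteger();

    static final class FakeContext {
        FakeContext() {
            // A real context would spin up its scheduler here; we just count.
            STARTS.incrementAndGet();
        }
    }

    private static FakeContext shared;

    // Lazily create the context on first use and hand the same instance
    // to every later caller, e.g. each @Test method in a class.
    static synchronized FakeContext context() {
        if (shared == null) {
            shared = new FakeContext();
        }
        return shared;
    }
}
```

The flip side of sharing is that tests must not leave state behind in the context (cached RDDs, changed configuration) that a later test could trip over.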

Next Steps

Using unit tests is just one of the ways qualified Data Engineers separate themselves and become vastly better at data engineering. If you want more information about Spark, Java, and unit testing, my Professional Spark Development and Professional Data Engineering classes will teach you the skills you need.