What's new in Apache Spark 3.5.0 - watermark propagation

Versions: Apache Spark 3.5.0

Watermarks, or rather the management of multiple watermarks, have been a thorn in the side of Apache Spark Structured Streaming. The situation improved in the previous release (3.4.0) but there was still some room for improvement. Well, there was, because the 3.5.0 release brings a serious fix for the multiple watermarks scenario.

Jungtaek Lim added watermark propagation, where the output watermark of a stateful operator becomes the input watermark of the next stateful operator. That's a change from the previous releases, where only a global watermark was available. Why was this global watermark problematic? You can read about that in the Correctness issue in Apache Spark.

Watermark propagation

Before delving into the details, let's see some code and a high-level activity flow. Below is a job with 2 stateful operations, a join and a grouping:

val clicksWithWatermark = clicksStream.toDF.withWatermark("clickTime", "10 minutes")
val impressionsWithWatermark = impressionsStream.toDF.withWatermark("impressionTime", "33 minutes")

val joined = impressionsWithWatermark.join(
  clicksWithWatermark,
  functions.expr("""clickAdId = impressionAdId AND
      clickTime >= impressionTime AND
      clickTime <= impressionTime + interval 2 minutes
      """),
  joinType = "leftOuter"
)

val groups = joined.groupBy($"clickAdId", functions.window($"clickTime", "1 hour"))
  .count()

The watermark propagation consists of generating the watermark at the child nodes and moving it up to the parent nodes of the query plan.

How does it translate to the implementation? Let's see!

Delving into details

How does this watermark propagation feature work? First, it might not work at all if you disable it by setting spark.sql.streaming.statefulOperator.allowMultiple to false. In that case, the behavior remains the same as in Apache Spark 3.4.0.
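As a reminder, the flag is a runtime configuration entry, so it can be set on the session before starting the query; sparkSession below stands for an already created SparkSession:

// The flag defaults to true; setting it to false forbids multiple stateful operators
// and keeps the pre-3.5.0, single global watermark behavior.
sparkSession.conf.set("spark.sql.streaming.statefulOperator.allowMultiple", "false")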

Otherwise, something new happens. It starts at the query planning stage, where the query reads the propagated watermarks for late data and state eviction from the PropagateWatermarkSimulator.

The propagation consists of traversing the execution plan bottom-up and recording the watermarks encountered at each step. The propagation logic applies a different operation depending on the node type, as in the sketch below:
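Here is an illustrative, simplified sketch of that bottom-up simulation. It is not the actual Spark code: the real PropagateWatermarkSimulator works with Option[Long] watermark values and more node types, and the simulatePropagation name, the Map[Int, Long] shape and the 0L defaults below are my simplifications:

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming.{EventTimeWatermarkExec, StateStoreWriter}

def simulatePropagation(plan: SparkPlan, globalWatermarkMs: Long): Map[Int, Long] = {
  val nodeToWatermark = scala.collection.mutable.Map.empty[Int, Long]

  plan.foreachUp { node =>
    // Bottom-up traversal: the children have already been assigned a watermark here.
    val childWatermarks = node.children.map(child => nodeToWatermark(child.id))
    val outputWatermark = node match {
      // A watermark definition node emits the watermark computed for this micro-batch.
      case _: EventTimeWatermarkExec => globalWatermarkMs
      // A stateful operator takes the minimum of its children's watermarks as its input
      // watermark and derives its output watermark from it (a join can lower it further).
      case stateful: StateStoreWriter =>
        stateful.produceOutputWatermark(childWatermarks.min).getOrElse(0L)
      // Any other node simply passes the smallest child watermark through.
      case _ if childWatermarks.nonEmpty => childWatermarks.min
      // Leaf nodes (the sources) don't define any watermark yet.
      case _ => 0L
    }
    nodeToWatermark(node.id) = outputWatermark
  }

  nodeToWatermark.toMap
}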

By the end of this traversal, the propagator builds a map of query node ids and their respective watermark values. When the planner wants to get the watermark value for late data or state eviction, it calls these two methods, which read the watermarks from the just-built node-to-watermark map:

  override def getInputWatermarkForLateEvents(batchId: Long, stateOpId: Long): Long = {
    // We use watermark for previous microbatch to determine late events.
    getInputWatermark(batchId - 1, stateOpId)
  }

  override def getInputWatermarkForEviction(batchId: Long, stateOpId: Long): Long =
    getInputWatermark(batchId, stateOpId)
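To connect the two methods with the example analyzed later in this post: assuming the stream-stream join got the operator id 0, the two lookups for micro-batch 2 would resolve against different simulated batches. The calls below are hypothetical and the values anticipate the tables shown further down:

// Late events of micro-batch 2 are filtered with the watermark simulated for micro-batch 1...
val joinWatermarkForLateEvents =
  watermarkPropagator.getInputWatermarkForLateEvents(batchId = 2, stateOpId = 0) // 0
// ...while the state eviction of micro-batch 2 uses the watermark simulated for that very batch.
val joinWatermarkForEviction =
  watermarkPropagator.getInputWatermarkForEviction(batchId = 2, stateOpId = 0) // epoch millis of 9:40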

The methods are called only for the stateful operations:

  object WatermarkPropagationRule extends SparkPlanPartialRule {
    private def inputWatermarkForLateEvents(stateInfo: StatefulOperatorStateInfo): Option[Long] = {
      Some(watermarkPropagator.getInputWatermarkForLateEvents(currentBatchId,
        stateInfo.operatorId))
    }

    private def inputWatermarkForEviction(stateInfo: StatefulOperatorStateInfo): Option[Long] = {
      Some(watermarkPropagator.getInputWatermarkForEviction(currentBatchId, stateInfo.operatorId))
    }

    override val rule: PartialFunction[SparkPlan, SparkPlan] = {
      case s: StateStoreSaveExec if s.stateInfo.isDefined =>
        s.copy(
          eventTimeWatermarkForLateEvents = inputWatermarkForLateEvents(s.stateInfo.get),
          eventTimeWatermarkForEviction = inputWatermarkForEviction(s.stateInfo.get)
        )

      case s: SessionWindowStateStoreSaveExec if s.stateInfo.isDefined =>
        // ...

      case s: SessionWindowStateStoreRestoreExec if s.stateInfo.isDefined =>
        // ...

      case s: StreamingDeduplicateExec if s.stateInfo.isDefined =>
        // ...

      case s: StreamingDeduplicateWithinWatermarkExec if s.stateInfo.isDefined =>
        // ...

      case m: FlatMapGroupsWithStateExec if m.stateInfo.isDefined =>
        // ...

      case m: FlatMapGroupsInPandasWithStateExec if m.stateInfo.isDefined =>
        // ...

      case j: StreamingSymmetricHashJoinExec =>
        // ...
    }
  }

Example

If the above was hard to understand, let's analyze this example:

val clicksWithWatermark = clicksStream.toDF
  .withWatermark("clickTime", "10 minutes")
val impressionsWithWatermark = impressionsStream.toDF
  .withWatermark("impressionTime", "20 minutes")

val joined = impressionsWithWatermark.join(
  clicksWithWatermark,
  functions.expr("""clickAdId = impressionAdId AND
              clickTime >= impressionTime AND
              clickTime <= impressionTime + interval 2 minutes
              """),
  joinType = "leftOuter"
)

val groups = joined
  .groupBy(functions.window($"clickTime", "1 hour"))
  .count()

// New micro-batch
clicksStream.addData(Seq(
  Click(1, Timestamp.valueOf("2023-06-10 10:10:00")),
  Click(1, Timestamp.valueOf("2023-06-10 10:20:00")),
  Click(2, Timestamp.valueOf("2023-06-10 10:30:00"))
))
impressionsStream.addData(Seq(
  Impression(1, Timestamp.valueOf("2023-06-10 10:00:00"))
))
query.processAllAvailable()

// New micro-batch
clicksStream.addData(Seq(
  Click(4, Timestamp.valueOf("2023-06-10 11:01:00")),
  Click(1, Timestamp.valueOf("2023-06-10 11:03:00")),
  Click(2, Timestamp.valueOf("2023-06-10 11:04:00"))
))
impressionsStream.addData(Seq(
  Impression(5, Timestamp.valueOf("2023-06-10 11:00:00"))
))
query.processAllAvailable()
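The excerpt above omits the surrounding setup: the event case classes, the MemoryStream sources and the started query. One possible way to complete it could be the following, where all the names and settings not shown in the excerpt are assumptions of mine:

import java.sql.Timestamp
import org.apache.spark.sql.{SQLContext, SparkSession, functions}
import org.apache.spark.sql.execution.streaming.MemoryStream

// Event types used by the sources; the field names match the join and groupBy expressions above.
case class Click(clickAdId: Int, clickTime: Timestamp)
case class Impression(impressionAdId: Int, impressionTime: Timestamp)

val sparkSession = SparkSession.builder()
  .master("local[2]")
  .appName("watermark-propagation-demo")
  .getOrCreate()
import sparkSession.implicits._
implicit val sqlContext: SQLContext = sparkSession.sqlContext

// In-memory sources fed micro-batch by micro-batch with the addData(...) calls above.
val clicksStream = MemoryStream[Click]
val impressionsStream = MemoryStream[Impression]

// ...clicksWithWatermark, impressionsWithWatermark, joined and groups defined as above...

// The query referenced by processAllAvailable(); a console sink is enough for the demo.
val query = groups.writeStream
  .format("console")
  .outputMode("append")
  .start()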

The query has 2 stateful operations, the join and the window aggregate, plus 2 watermark definition nodes at the bottom of the query plan. After processing the datasets above, the watermark propagation algorithm assigns the following watermark values to the query nodes:

Micro-batch | Input global watermark | Output global watermark   | JOIN watermarks (input / output) | Aggregation watermarks (input / output) | Comments
0           | 0                      | 0                         | 0 / 0                            | 0 / 0                                   |
1           | 0                      | MIN(9:40, 10:20) = 9:40   | 0 / 0                            | 0 / 0                                   | as the global input watermark is empty, the propagation doesn't apply
2           | 9:40                   | MIN(10:40, 10:54) = 10:40 | 9:40 / 9:37:59                   | 9:37:59 / 9:37:59                       |
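To see where the MIN(9:40, 10:20) = 9:40 of micro-batch 1 comes from, here is the underlying date arithmetic, written as plain Java time operations rather than Spark code, applied to the first pair of addData calls:

import java.time.{Duration, LocalDateTime}

// Each withWatermark node subtracts its delay from the max event time it has seen so far.
val clicksNodeWatermark = LocalDateTime.parse("2023-06-10T10:30:00")
  .minus(Duration.ofMinutes(10)) // 10:20
val impressionsNodeWatermark = LocalDateTime.parse("2023-06-10T10:00:00")
  .minus(Duration.ofMinutes(20)) // 09:40
// The output global watermark takes the minimum of the per-node values, hence 9:40.
val outputGlobalWatermark =
  if (clicksNodeWatermark.isBefore(impressionsNodeWatermark)) clicksNodeWatermark
  else impressionsNodeWatermark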

Confused? I was. Initially I thought of the propagation as local, i.e. happening within the same micro-batch. But don't forget that the watermark itself is only updated after the micro-batch runs, here:

  protected def markMicroBatchEnd(): Unit = {
    watermarkTracker.updateWatermark(lastExecution.executedPlan)
    reportTimeTaken("commitOffsets") {
      if (!commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
        throw QueryExecutionErrors.concurrentStreamLogUpdate(currentBatchId)
      }
    }
    committedOffsets ++= availableOffsets
  }

That's why the watermark propagation only becomes active from micro-batch number 2. It also explains why the watermark nodes don't propagate the event times from the current micro-batch but only from the previous one.

Watermarks and past batches

But the watermark propagation doesn't stop there. It also impacts the watermarks used for late data and state eviction, both available from these methods:

class PropagateWatermarkSimulator extends WatermarkPropagator with Logging {

  private def getInputWatermark(batchId: Long, stateOpId: Long): Long = {
    // ...
    inputWatermarks(batchId).get(stateOpId) match {
      case Some(wmOpt) =>
        Math.max(wmOpt.getOrElse(0L), 0L)
      // ...
    }
  }

  override def getInputWatermarkForLateEvents(batchId: Long, stateOpId: Long): Long = {
    // We use watermark for previous microbatch to determine late events.
    getInputWatermark(batchId - 1, stateOpId)
  }

  override def getInputWatermarkForEviction(batchId: Long, stateOpId: Long): Long =
    getInputWatermark(batchId, stateOpId)

If we apply this logic to our previous example, we will get:

Micro-batch | JOIN watermarks (late data / state eviction) | Aggregation watermarks (late data / state eviction)
0           | 0 / 0                                        | 0 / 0
1           | 0 / 0                                        | 0 / 0
2           | 0 / 9:40                                     | 0 / 9:37:59

Watermarks are probably the most evolving feature of the last 3 versions of Apache Spark Structured Streaming. And I don't think this is the last time you'll hear about them in my "What's new in Apache Spark" series ;)