TaskContext¶
TaskContext is an abstraction of task contexts that allow running tasks to access information about themselves (e.g., the partition and the stage they compute, attempt IDs) and to register task completion and failure listeners.
Contract¶
addTaskCompletionListener¶
addTaskCompletionListener[U](
  f: (TaskContext) => U): TaskContext
addTaskCompletionListener(
  listener: TaskCompletionListener): TaskContext
Registers a TaskCompletionListener to be executed on task completion (regardless of whether the task succeeded or failed)
val rdd = sc.range(0, 5, numSlices = 1)
import org.apache.spark.TaskContext
val printTaskInfo = (tc: TaskContext) => {
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}
rdd.foreachPartition { _ =>
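  // Runs on an executor; the listener fires when the task completes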
  val tc = TaskContext.get
  tc.addTaskCompletionListener(printTaskInfo)
}
addTaskFailureListener¶
addTaskFailureListener(
  f: (TaskContext, Throwable) => Unit): TaskContext
addTaskFailureListener(
  listener: TaskFailureListener): TaskContext
Registers a TaskFailureListener to be executed on task failure
val rdd = sc.range(0, 2, numSlices = 2)
import org.apache.spark.TaskContext
val printTaskErrorInfo = (tc: TaskContext, error: Throwable) => {
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |error:         ${error.toString}
                |-------------------""".stripMargin
  println(msg)
}
val throwExceptionForOddNumber = (n: Long) => {
  if (n % 2 == 1) {
    throw new Exception(s"No way it will pass for odd number: $n")
  }
}
// FIXME This won't trigger the failure listener:
// foreachPartition ignores its iterator, so the lazy map
// (and hence the exception) is never evaluated.
rdd.map(throwExceptionForOddNumber).foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
}
// Listener registration matters: register the failure listener
// before the failing map is evaluated (here, by count).
rdd.mapPartitions { (it: Iterator[Long]) =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
  it
}.map(throwExceptionForOddNumber).count
fetchFailed¶
fetchFailed: Option[FetchFailedException]
Used when:
- TaskRunner is requested to run
getKillReason¶
getKillReason(): Option[String]
Reason the task was killed, if it was (None otherwise)
getLocalProperty¶
getLocalProperty(
  key: String): String
Looks up a local property by key
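Local properties are set on the driver with SparkContext.setLocalProperty and propagated to the tasks of jobs submitted from that thread. A minimal sketch (the property key my.custom.key is made up for this demo):
sc.setLocalProperty("my.custom.key", "42")

import org.apache.spark.TaskContext
sc.range(0, 2, numSlices = 2).foreach { _ =>
  // Read the driver-side property back inside a task
  val value = TaskContext.get.getLocalProperty("my.custom.key")
  println(s"my.custom.key = $value")
}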
getMetricsSources¶
getMetricsSources(
  sourceName: String): Seq[Source]
Looks up Sources by name
isCompleted¶
isCompleted(): Boolean
isInterrupted¶
isInterrupted(): Boolean
Whether the task has been killed (interrupted)
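A long-running task can poll isInterrupted to stop early once it has been killed (e.g., after SparkContext.cancelAllJobs). A minimal sketch:
import org.apache.spark.TaskContext
sc.range(0, 10, numSlices = 2).mapPartitions { it =>
  val tc = TaskContext.get
  // Stop consuming the partition as soon as the task is killed
  it.takeWhile(_ => !tc.isInterrupted)
}.count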
killTaskIfInterrupted¶
killTaskIfInterrupted(): Unit
Throws a TaskKilledException (with the kill reason) if the task has been killed
Registering Accumulator¶
registerAccumulator(
  a: AccumulatorV2[_, _]): Unit
Registers an AccumulatorV2
Used when:
- AccumulatorV2 is requested to deserialize itself
resources¶
resources(): Map[String, ResourceInformation]
Resources allocated to the task
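The map is only non-empty when custom resources are configured (e.g., spark.task.resource.gpu.amount). A sketch, assuming a GPU resource has been set up:
import org.apache.spark.TaskContext
sc.range(0, 1).foreach { _ =>
  // ResourceInformation holds a resource name and the assigned addresses
  TaskContext.get.resources().get("gpu").foreach { info =>
    println(s"gpu addresses: ${info.addresses.mkString(", ")}")
  }
}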
taskMetrics¶
taskMetrics(): TaskMetrics
TaskMetrics of the task
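A sketch that prints a couple of the metrics from a task completion listener (some metrics are only finalized by the executor after the task body finishes, so treat the values as a snapshot):
import org.apache.spark.TaskContext
val printMetrics = (tc: TaskContext) => {
  val m = tc.taskMetrics
  println(s"peakExecutionMemory: ${m.peakExecutionMemory}")
  println(s"memoryBytesSpilled:  ${m.memoryBytesSpilled}")
}
sc.range(0, 2, numSlices = 2).foreachPartition { it =>
  TaskContext.get.addTaskCompletionListener(printMetrics)
  it.size // consume the partition so the task does some work
}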
others¶
Important
There are other methods, but they do not seem very interesting.
Implementations¶
- BarrierTaskContext
- TaskContextImpl
Serializable¶
TaskContext is Serializable (java.io.Serializable).
Accessing TaskContext¶
get(): TaskContext
get returns the thread-local TaskContext instance.
import org.apache.spark.TaskContext
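// On the driver (e.g., in spark-shell) there is no running task,
// so TaskContext.get returns null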
val tc = TaskContext.get
val rdd = sc.range(0, 3, numSlices = 3)
assert(rdd.partitions.size == 3)
rdd.foreach { n =>
  import org.apache.spark.TaskContext
  val tc = TaskContext.get
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}