Spark DataSources Implementation

Extending and Practicing Spark Data Sources (a custom DataSource in 40 lines of code)

Simple Example

Spark's DataSource API is easy to extend. If you do not register the source through the ServiceLoader mechanism (a META-INF/services file), the custom data source class must be named DefaultSource and must implement the RelationProvider interface.

class DefaultSource extends RelationProvider {
  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String]): BaseRelation = {
    ???
  }
}
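
If you do use the ServiceLoader mechanism, the provider can have any name: implement DataSourceRegister, return a short name, and list the class in a META-INF/services file. A minimal sketch follows; the package, class name, and short name here are made up for illustration and are not part of the original example.

// src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
// containing one line: com.example.sources.EmptySourceProvider
package com.example.sources

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}

class EmptySourceProvider extends RelationProvider with DataSourceRegister {
  // lets callers write spark.read.format("empty") instead of the full package name
  override def shortName(): String = "empty"

  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String]): BaseRelation = ???
}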

A custom data source usually has its own configuration, so we also implement our own BaseRelation:

class DefaultSource extends RelationProvider {
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    EmptyRelation()
  }
}

case class EmptyRelation() extends BaseRelation {
  override def sqlContext: SQLContext = ???
  override def schema: StructType = ???
}

The real work is in the BaseRelation implementation, but where do the schema and the SQLContext come from? DefaultSource.createRelation already receives a SQLContext, so we can simply pass it through:

class DefaultSource extends RelationProvider {
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    EmptyRelation()(sqlContext)
  }
}

case class EmptyRelation()(@transient val sc: SQLContext) extends BaseRelation {
  override def sqlContext: SQLContext = sc

  override def schema: StructType = ???
}

What about the schema? It usually has to be derived from the parameters passed to DefaultSource.createRelation, so we normally give the custom BaseRelation an extra parameter:

class DefaultSource extends RelationProvider {
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    EmptyRelation(parameters)(sqlContext)
  }
}

case class EmptyRelation(parameters: Map[String, String])(@transient val sc: SQLContext) extends BaseRelation {
  override def sqlContext: SQLContext = sc

  override def schema: StructType = ???
}

How the schema is implemented depends on how the data source is read, so EmptyRelation also needs to implement another interface: TableScan.

case class EmptyRelation(parameters: Map[String, String])
                        (@transient val sc: SQLContext)
  extends BaseRelation with TableScan {

  override def sqlContext: SQLContext = sc

  override def schema: StructType = ???

  override def buildScan(): RDD[Row] = ???
}

Now there are two methods left for us to implement. buildScan describes how to read the data source and produce an RDD[Row]. Here is a simple example to get started:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types._

case class EmptyRelation(parameters: Map[String, String])
                        (@transient val sc: SQLContext)
  extends BaseRelation with TableScan {

  override def sqlContext: SQLContext = sc

  override def schema: StructType = {
    StructType(List(
      StructField("id", IntegerType),
      StructField("name", StringType),
      StructField("age", IntegerType)
    ))
  }

  override def buildScan(): RDD[Row] = {
    val rdd = sqlContext.sparkContext.parallelize(
      List(
        (1, "A", 20),
        (2, "B", 25)
      )
    )
    rdd.map(row => Row.fromSeq(Seq(row._1, row._2, row._3)))
  }
}

Now we can run a small test:

import org.apache.spark.sql.SparkSession

object TestExample {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()
    val df = spark.read.format("com.zqh.spark.connectors.test.empty").load()
    df.printSchema()
    df.show()
  }
}

What, only 40 lines of code and we have a custom DataSource!

root
 |-- id: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- age: integer (nullable = true)

+---+----+---+
| id|name|age|
+---+----+---+
|  1|   A| 20|
|  2|   B| 25|
+---+----+---+

In the EmptyRelation example above, the schema and buildScan methods relate as follows:

  • schema defines three fields, so every Row produced by buildScan must have three elements
  • each Row of the RDD is the data, while schema is the metadata describing it; the schema itself can be whatever you declare (a configurable variant is sketched below)
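
Since parameters is already available inside EmptyRelation, one way to make the schema configurable is to parse it from an option. This is only a sketch under assumptions not in the original article: the option name "schema" and its "name type, name type" format are invented here.

import org.apache.spark.sql.types._

// Hypothetical helper: turn .option("schema", "id int, name string, age int")
// into a StructType. Only a few types are handled.
def schemaFromParameters(parameters: Map[String, String]): StructType = {
  val fields = parameters("schema").split(",").map(_.trim).map { column =>
    val Array(name, typeName) = column.split("\\s+")
    val dataType = typeName.toLowerCase match {
      case "int" | "integer" => IntegerType
      case "long" | "bigint" => LongType
      case "string"          => StringType
      case other             => throw new IllegalArgumentException(s"unsupported type: $other")
    }
    StructField(name, dataType)
  }
  StructType(fields)
}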

To summarize the classes involved in a custom data source:

RelationProvider              BaseRelation       TableScan
      /|\                          /|\               /|\           spark
       |                            |                 |       ----------------
       |                            |                 |            user
       |                         schema()         buildScan()
 DefaultSource                      |                 |
       |                            |                 |
 createRelation() ------------> EmptyRelation --------+

JDBC DataSource

Enable MySQL's general query log; the log file here is /usr/local/var/mysql/zqhmac.log:

mysql> set GLOBAL general_log = on;
Query OK, 0 rows affected (0.08 sec)

mysql> show VARIABLES like '%general_log%';
+------------------+---------------------------------+
| Variable_name    | Value                           |
+------------------+---------------------------------+
| general_log      | ON                              |
| general_log_file | /usr/local/var/mysql/zqhmac.log |
+------------------+---------------------------------+

Spark can read over JDBC in several ways:

1. Full read with a single partition

val url = "jdbc:mysql://localhost/test"
val table = "test"
val properties = new java.util.Properties
properties.put("user", "root")
properties.put("password", "root")

val df = spark.read.jdbc(url, table, properties)

df.rdd.partitions.size // 1

MySQL general log:

2008 Query SELECT `id`,`name`,`total` FROM test

On the Spark UI there is only one executor and one task.

[Spark UI screenshot: a single task reads the whole table]

If the table is too large, this single-partition read fails with an OOM error.

[OOM error screenshot]

2. Specify bounds and partition automatically

val columnName = "id"
val lowerBound = 1
val upperBound = 1000
val numPartitions = 5

val df = spark.read.jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, properties)

df.rdd.partitions.size // 5, the requested number of partitions

One restriction of this approach is that the partition column must be an integral type:

def jdbc(
    url: String,
    table: String,
    columnName: String,
    lowerBound: Long,
    upperBound: Long,
    numPartitions: Int,
    connectionProperties: Properties): DataFrame = {
  // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions.
  this.extraOptions ++= Map(
    JDBCOptions.JDBC_PARTITION_COLUMN -> columnName,
    JDBCOptions.JDBC_LOWER_BOUND -> lowerBound.toString,
    JDBCOptions.JDBC_UPPER_BOUND -> upperBound.toString,
    JDBCOptions.JDBC_NUM_PARTITIONS -> numPartitions.toString)
  jdbc(url, table, connectionProperties)
}

Spark splits the range into partitions automatically from the bounds and the partition count. This is aimed at tables whose primary key is an auto-increment (and therefore integer) column: because auto-increment values are distributed evenly, each partition pulls roughly the same amount of data for a given range and partition count.
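
The partition boundaries that appear in the log below can be reproduced with roughly the following arithmetic. This is a simplified sketch of the idea, not Spark's exact code (the real logic lives in JDBCRelation.columnPartition):

// Reproduce the WHERE clauses that the MySQL log shows for
// lowerBound = 1, upperBound = 1000, numPartitions = 5.
val lowerBound = 1L
val upperBound = 1000L
val numPartitions = 5
val stride = upperBound / numPartitions - lowerBound / numPartitions // 200

val whereClauses = (0 until numPartitions).map { i =>
  val start = lowerBound + i * stride
  val end = start + stride
  if (i == 0) s"id < $end or id is null"            // first partition also collects NULLs
  else if (i == numPartitions - 1) s"id >= $start"  // last partition is open-ended
  else s"id >= $start AND id < $end"
}
whereClauses.foreach(println)
// id < 201 or id is null
// id >= 201 AND id < 401
// ...
// id >= 801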

MySQL general log:

2010 Query SELECT `id`,`name`,`total` FROM test WHERE id < 201 or id is null
2011 Query SELECT `id`,`name`,`total` FROM test WHERE id >= 201 AND id < 401
2012 Query SELECT `id`,`name`,`total` FROM test WHERE id >= 401 AND id < 601
2013 Query SELECT `id`,`name`,`total` FROM test WHERE id >= 601 AND id < 801
2014 Query SELECT `id`,`name`,`total` FROM test WHERE id >= 801

3. Manually constructed predicates

val predicates = Array(
  "id>=0 and id<10",
  "id>=10 and id<100",
  "id>=100 and id<1000"
)

val df = spark.read.jdbc(url, table, predicates, properties)

df.rdd.partitions.size // 3, one partition per predicate

MySQL general log:

2016 Query SELECT `id`,`name`,`total` FROM test WHERE id>=0 and id<10
2017 Query SELECT `id`,`name`,`total` FROM test WHERE id>=10 and id<100
2018 Query SELECT `id`,`name`,`total` FROM test WHERE id>=100 and id<1000

If the data is unevenly distributed, this approach works well, and it is not limited to primary keys or integer columns: any column of any type can be used.

For example, with the following test table in MySQL:

mysql> select * from test;
+-----+------+-------+
| id  | name | total |
+-----+------+-------+
|   1 | A    |     1 |
|   2 | B    |     2 |
|   3 | C    |     3 |
|  11 | A    |    12 |
|  12 | B    |    12 |
|  13 | C    |    12 |
| 100 | 1    |     0 |
| 101 | 2    |     1 |
| 102 | 2    |     1 |
+-----+------+-------+

Now partition the query manually by the name column:

val predicates = Array(
  "name = 'A'",
  "name = 'B'",
  "name = 'C'",
  "name in('1','2')"
)

val df = spark.read.jdbc(url, table, predicates, properties)

df.show

MySQL general log:

2020 Query SELECT `id`,`name`,`total` FROM test WHERE name = 'A'
2022 Query SELECT `id`,`name`,`total` FROM test WHERE name = 'C'
2023 Query SELECT `id`,`name`,`total` FROM test WHERE name = 'B'
2021 Query SELECT `id`,`name`,`total` FROM test WHERE name in('1','2')

Since the conditions are entirely custom, anything that fits into a WHERE clause works, for example limit/offset:

val predicates = Array(
  "1=1 order by name limit 3 offset 0",
  "1=1 order by name limit 3 offset 3",
  "1=1 order by name limit 3 offset 6"
)

val df = spark.read.jdbc(url, table, predicates, properties)
df.count

MySQL general log:

2025 Query SELECT 1 FROM test WHERE 1=1 order by name limit 3 offset 3
2026 Query SELECT 1 FROM test WHERE 1=1 order by name limit 3 offset 6
2027 Query SELECT 1 FROM test WHERE 1=1 order by name limit 3 offset 0

The sort column and page size can also be specified dynamically:

val orderByColumn = "name"
val limitCount = 3
val predicates = Array(
  s"1=1 order by $orderByColumn limit $limitCount offset 0",
  s"1=1 order by $orderByColumn limit $limitCount offset ${limitCount}",
  s"1=1 order by $orderByColumn limit $limitCount offset ${limitCount * 2}"
)

val df = spark.read.jdbc(url, table, predicates, properties)
df.count

MySQL general log:

2030 Query SELECT 1 FROM test WHERE 1=1 order by name limit 3 offset 3
2029 Query SELECT 1 FROM test WHERE 1=1 order by name limit 3 offset 0
2031 Query SELECT 1 FROM test WHERE 1=1 order by name limit 3 offset 6

Of course the predicates above are still hard-coded. A better approach is to query the total row count first and then build the predicates array from limitCount:

val orderByColumn = "name"
val limitCount = 3
//val totalCount = spark.read.jdbc(url, table, properties).count // log: SELECT 1 FROM test
val countDF = spark.read.jdbc(url, s"(select count(*) from $table) tmp", properties) // SELECT * FROM (select count(*) from test) tmp WHERE 1=0
val totalCount = countDF.take(1)(0).getAs[Long](0) // SELECT `count(*)` FROM (select count(*) from test) tmp

val split = totalCount / limitCount
val predicates = for (i <- 0L to split) yield s"1=1 order by $orderByColumn limit $limitCount offset ${limitCount * i}"
val df = spark.read.jdbc(url, table, predicates.toArray, properties)
df.count

MySQL general log:

2050 Query SELECT 1 FROM test WHERE 1=1 order by name limit 3 offset 0
2051 Query SELECT 1 FROM test WHERE 1=1 order by name limit 3 offset 6
2052 Query SELECT 1 FROM test WHERE 1=1 order by name limit 3 offset 3
2053 Query SELECT 1 FROM test WHERE 1=1 order by name limit 3 offset 9

JDBC Implementation

spark.read.jdbc goes through DataFrameReader, and the real work happens in its load() method:

def load(paths: String*): DataFrame = {
  sparkSession.baseRelationToDataFrame(
    DataSource.apply(
      sparkSession,
      paths = paths,
      userSpecifiedSchema = userSpecifiedSchema,
      className = source,
      options = extraOptions.toMap).resolveRelation())
}

The provider for the JDBC format is registered in DataSource:

object DataSource extends Logging {
  private val backwardCompatibilityMap: Map[String, String] = {
    val jdbc = classOf[JdbcRelationProvider].getCanonicalName
    val json = classOf[JsonFileFormat].getCanonicalName
    val csv = classOf[CSVFileFormat].getCanonicalName
    Map(
      "org.apache.spark.sql.jdbc" -> jdbc,
      "org.apache.spark.sql.json" -> json,
      "com.databricks.spark.csv" -> csv
    )
  }
}

The class that defines the JDBC data source is JdbcRelationProvider.
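
The same provider is reached through the generic format/option API, so the earlier spark.read.jdbc(...) call can equivalently be written as follows (standard Spark JDBC options; the connection values are the ones used above):

// "jdbc" resolves to JdbcRelationProvider through DataSource lookup.
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:mysql://localhost/test")
  .option("dbtable", "test")
  .option("user", "root")
  .option("password", "root")
  .load()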

JDBC Extension

Reference: http://blog.csdn.net/cjuexuan/article/details/52333970

category is the unique key: if the row already exists, update num; otherwise insert category and num.

INSERT INTO ip_category_count
(category,num,createTime)
VALUES(?,?,CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE
num=?,updateTime=CURRENT_TIMESTAMP

The corresponding PreparedStatement code; setter parameter indices start at 1, Row getter indices start at 0:

ps.setInt(1, row.getInt(0))
ps.setLong(2, row.getLong(1))
ps.setLong(3, row.getLong(1))
ps.executeUpdate()

Now suppose the following SQL:

INSERT INTO test_1 (`id`,`year`,`count`) VALUES (?,?,?)
ON DUPLICATE KEY UPDATE `id`=?,`year`=?,`count`=?

The corresponding statement code:

ps.setInt(1, row.getInt(0))
ps.setString(2, row.getString(1))
ps.setLong(3, row.getLong(2))
-------------------------------
ps.setInt(4, row.getInt(0))
ps.setString(5, row.getString(1))
ps.setLong(6, row.getLong(2))

The general rule that falls out: stmt.setInt(pos + 1, row.getInt(pos - offset))

1. i<midField,  position=i, offset=0        => stmt.setInt(i + 1, row.getInt(i - 0))
2. i>=midField, position=i, offset=midField => stmt.setInt(i + 1, row.getInt(i - midField))

Take three fields as an example. When i < midField:

  • i=0: stmt.setInt(0 + 1, row.getInt(0 - 0)), stmt.setInt(1, row.getInt(0))
  • i=1: stmt.setInt(1 + 1, row.getInt(1 - 0)), stmt.setInt(2, row.getInt(1))
  • i=2: stmt.setInt(2 + 1, row.getInt(2 - 0)), stmt.setInt(3, row.getInt(2))

When i >= midField:

  • i=3: stmt.setInt(3 + 1, row.getInt(3 - 3)), stmt.setInt(4, row.getInt(0))
  • i=4: stmt.setInt(4 + 1, row.getInt(4 - 3)), stmt.setInt(5, row.getInt(1))
  • i=5: stmt.setInt(5 + 1, row.getInt(5 - 3)), stmt.setInt(6, row.getInt(2))

The setter's first argument is the statement parameter index, the second is the Row index. For i < midField the Row position equals the index minus 0; for i >= midField it equals the index minus midField (3 in this example).

row[1,2,3]
setter(0) => set(0+1, get(0-0)) => set(1, get(0))
setter(1) => set(1+1, get(1-0)) => set(2, get(1))
setter(2) => set(2+1, get(2-0)) => set(3, get(2))
--------------------------------------------------
setter(3) => set(3+1, get(3-3)) => set(4, get(0))
setter(4) => set(4+1, get(4-3)) => set(5, get(1))
setter(5) => set(5+1, get(5-3)) => set(6, get(2))

The code:

val length = rddSchema.fields.length
val numFields = if (isUpdateMode) length * 2 else length // actual number of '?' placeholders

var i = 0
val midField = numFields / 2
while (i < numFields) {
  // in upsert mode the number of '?' is 2 * row.field.length
  if (isUpdateMode) { // upsert mode
    i < midField match {
      case true => // INSERT part: the row index equals i
        if (row.isNullAt(i)) {
          stmt.setNull(i + 1, nullTypes(i))
        } else {
          setters(i).apply(stmt, row, i, 0)
        }
      case false => // UPDATE part: the row index is shifted back by midField
        if (row.isNullAt(i - midField)) {
          stmt.setNull(i + 1, nullTypes(i - midField))
        } else {
          setters(i).apply(stmt, row, i, midField)
        }
    }
  } else { // plain insert
    if (row.isNullAt(i)) {
      stmt.setNull(i + 1, nullTypes(i))
    } else {
      setters(i).apply(stmt, row, i, 0)
    }
  }
  i = i + 1
}
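
The setters array used above is essentially a lookup table of per-type functions. Based on the rule stmt.setXXX(pos + 1, row.getXXX(pos - offset)), their shape is roughly the following (a sketch, not the exact code from the referenced post):

import java.sql.PreparedStatement
import org.apache.spark.sql.Row

// One setter per column: (statement, row, position, offset) => Unit.
// JDBC parameter indices are 1-based, Row indices are 0-based,
// hence pos + 1 on the statement side and pos - offset on the row side.
type JDBCValueSetter = (PreparedStatement, Row, Int, Int) => Unit

val intSetter: JDBCValueSetter =
  (stmt, row, pos, offset) => stmt.setInt(pos + 1, row.getInt(pos - offset))

val longSetter: JDBCValueSetter =
  (stmt, row, pos, offset) => stmt.setLong(pos + 1, row.getLong(pos - offset))

val stringSetter: JDBCValueSetter =
  (stmt, row, pos, offset) => stmt.setString(pos + 1, row.getString(pos - offset))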

To summarize the mapping:


setter[i]: | 0 | 1 | 2 | 3 | 4 | 5 |
position:  | 0 | 1 | 2 | 3 | 4 | 5 |
offset:    | 0 | 0 | 0 | 3 | 3 | 3 |
setXXX:    | 1 | 2 | 3 | 4 | 5 | 6 | i+1
getXXX:    | 0 | 1 | 2 | 0 | 1 | 2 | position-offset
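
Putting the pieces together, a minimal upsert writer for the first ON DUPLICATE KEY UPDATE statement above might look like the sketch below. This is not the code from the referenced blog post: the connection handling is simplified, batching and error handling are omitted, and the column positions assume a DataFrame with (category, num) in that order.

import java.sql.DriverManager
import org.apache.spark.sql.{DataFrame, Row}

def upsertCategoryCounts(df: DataFrame, url: String, user: String, password: String): Unit = {
  val sql =
    """INSERT INTO ip_category_count (category, num, createTime)
      |VALUES (?, ?, CURRENT_TIMESTAMP)
      |ON DUPLICATE KEY UPDATE num = ?, updateTime = CURRENT_TIMESTAMP""".stripMargin

  df.foreachPartition { (rows: Iterator[Row]) =>
    // one connection and one prepared statement per partition
    val conn = DriverManager.getConnection(url, user, password)
    val stmt = conn.prepareStatement(sql)
    try {
      rows.foreach { row =>
        stmt.setInt(1, row.getInt(0))   // category
        stmt.setLong(2, row.getLong(1)) // num for the INSERT part
        stmt.setLong(3, row.getLong(1)) // num for the UPDATE part
        stmt.executeUpdate()
      }
    } finally {
      stmt.close()
      conn.close()
    }
  }
}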
