Note: This article is reproduced from stackoverflow.com.
apache-spark scala apache-spark-sql

scala spark UDF filter array of struct

Posted on 2020-03-28 23:13:48

I have a DataFrame with the following schema:

root
 |-- x: Long (nullable = false)
 |-- y: Long (nullable = false)
 |-- features: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- name: string (nullable = true)
 |    |    |-- score: double (nullable = true)

For example, I have this data:

+--------------------+--------------------+------------------------------------------+
|                x   |              y     |       features                           |
+--------------------+--------------------+------------------------------------------+
|10                  |          9         |[["f1", 5.9], ["ft2", 6.0], ["ft3", 10.9]]|
|11                  |          0         |[["f4", 0.9], ["ft1", 4.0], ["ft2", 0.9] ]|
|20                  |          9         |[["f5", 5.9], ["ft2", 6.4], ["ft3", 1.9] ]|
|18                  |          8         |[["f1", 5.9], ["ft4", 8.1], ["ft2", 18.9]]|
+--------------------+--------------------+------------------------------------------+

I would like to keep only the features whose name starts with a particular prefix, say "ft", so that the result is:

+--------------------+--------------------+-----------------------------+
|                x   |              y     |       features              |
+--------------------+--------------------+-----------------------------+
|10                  |          9         |[["ft2", 6.0], ["ft3", 10.9]]|
|11                  |          0         |[["ft1", 4.0], ["ft2", 0.9] ]|
|20                  |          9         |[["ft2", 6.4], ["ft3", 1.9] ]|
|18                  |          8         |[["ft4", 8.1], ["ft2", 18.9]]|
+--------------------+--------------------+-----------------------------+

I'm not on Spark 2.4+, so I cannot use the solution from "Spark (Scala) filter array of structs without explode".

I tried to use a UDF, but it still does not work. Here are my attempts. First I defined a UDF:

def filterFeature: UserDefinedFunction =
  udf((features: Seq[Row]) =>
    features.filter { x =>
      x.getString(0).startsWith("ft")
    }
  )

But when I apply this UDF:

df.withColumn("filtered", filterFeature($"features"))

I get the error "Schema for type org.apache.spark.sql.Row is not supported". Apparently a UDF cannot simply return Row values, because Spark cannot infer a schema for Row. Then I tried:

def filterFeature: UserDefinedFunction =
  udf((features: Seq[Row]) =>
    features.filter { x =>
      x.getString(0).startsWith("ft")
    }, (StringType, DoubleType)
  )

I then got an error:

 error: type mismatch;
 found   : (org.apache.spark.sql.types.StringType.type, org.apache.spark.sql.types.DoubleType.type)
 required: org.apache.spark.sql.types.DataType
              }, (StringType, DoubleType)
                 ^
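The compiler expects a single DataType as the second argument of udf, and a tuple of two types is not one. The sketch below builds by hand the one DataType that describes the whole features column, assuming the field names name and score from the schema above; FeatureSchema, keepFt and filterFeature are hypothetical names.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._

object FeatureSchema {
  // One DataType for array<struct<name:string,score:double>>, mirroring the
  // `features` column (field names assumed to be name and score).
  val featuresType: ArrayType = ArrayType(
    StructType(Seq(
      StructField("name", StringType, nullable = true),
      StructField("score", DoubleType, nullable = true)
    )),
    containsNull = true
  )

  // A plain Function1 value (not a lambda literal), so only the untyped
  // udf(f: AnyRef, dataType: DataType) overload applies.
  val keepFt: Seq[Row] => Seq[Row] = _.filter(_.getString(0).startsWith("ft"))

  // A def, so the UDF is only created when used (Spark 3.x rejects this
  // overload unless spark.sql.legacy.allowUntypedScalaUDF is enabled).
  def filterFeature: UserDefinedFunction = udf(keepFt, featuresType)
}
```

The key point is that the struct is described as one StructType nested inside an ArrayType, rather than as a tuple of field types.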

I also tried a case class as suggested by some answers:

case class FilteredFeature(featureName: String, featureScore: Double)
def filterFeature: UserDefinedFunction =
  udf((features: Seq[Row]) =>
    features.filter { x =>
      x.getString(0).startsWith("ft")
    }, FilteredFeature
  )

But I got:

 error: type mismatch;
 found   : FilteredFeature.type
 required: org.apache.spark.sql.types.DataType
              }, FilteredFeature
                 ^
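Here FilteredFeature refers to the case class's companion object, which is a value and not a DataType. One way to derive a DataType from the case class is Encoders.product, as in this sketch (CaseClassSchema is a hypothetical name):

```scala
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.types.{ArrayType, DataType}

// The question's case class, declared at top level so Spark can reflect on it.
case class FilteredFeature(featureName: String, featureScore: Double)

object CaseClassSchema {
  // Encoders.product[T].schema is a StructType (a DataType) derived from the
  // case class; ArrayType wraps it to describe a Seq of such structs.
  val featuresType: DataType = ArrayType(Encoders.product[FilteredFeature].schema)
}
```

This value could then be passed as the second argument of udf(f, dataType). Because the returned Rows are matched to struct fields by position, the output fields would be renamed to featureName and featureScore, and featureScore becomes non-nullable since Double is a primitive.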

I tried:

case class FilteredFeature(featureName: String, featureScore: Double)
def filterFeature: UserDefinedFunction =
  udf((features: Seq[Row]) =>
    features.filter { x =>
      x.getString(0).startsWith("ft")
    }, Seq[FilteredFeature]
  )

I got:

<console>:192: error: missing argument list for method apply in class GenericCompanion
Unapplied methods are only converted to functions when a function type is expected.
You can make this conversion explicit by writing `apply _` or `apply(_)` instead of `apply`.
              }, Seq[FilteredFeature]
                    ^

I tried:

case class FilteredFeature(featureName: String, featureScore: Double)
def filterFeature: UserDefinedFunction =
  udf((features: Seq[Row]) =>
    features.filter { x =>
      x.getString(0).startsWith("ft")
    }, Seq[FilteredFeature](_)
  )

I got:

<console>:201: error: type mismatch;
 found   : Seq[FilteredFeature]
 required: FilteredFeature
              }, Seq[FilteredFeature](_)
                          ^

What should I do in this case?

Questioner: al3xtouch (viewed 336 times)

Answer by Raphael Roth, 2020-01-31 23:19

You have two options:

a) Provide a schema to the UDF; this lets you return Seq[Row].

b) Convert the Seq[Row] to a Seq of Tuple2 or of a case class; then you don't need to provide a schema (but the struct field names are lost if you use tuples: they become _1 and _2).
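Option b) could look roughly like the following sketch. It reuses the question's FilteredFeature case class; FilterFeaturesTyped, keepPrefix and filterFeatures are hypothetical names, and score is assumed never to be null, because getDouble throws on a null value.

```scala
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, udf}

// The question's case class; its field names become the output struct's field names.
case class FilteredFeature(featureName: String, featureScore: Double)

object FilterFeaturesTyped {
  // Converts each struct Row to a FilteredFeature, keeping those whose name
  // starts with `prefix`. A null array stays null.
  def keepPrefix(prefix: String): Seq[Row] => Seq[FilteredFeature] = features =>
    if (features == null) null
    else features.collect {
      // A type pattern never matches null, so null elements are skipped.
      case r: Row if Option(r.getString(0)).exists(_.startsWith(prefix)) =>
        FilteredFeature(r.getString(0), r.getDouble(1))
    }

  def filterFeatures(df: DataFrame, prefix: String): DataFrame = {
    // The return type Seq[FilteredFeature] is known at compile time, so Spark
    // derives the output schema itself and no DataType argument is needed.
    val filterFeature = udf(keepPrefix(prefix))
    df.withColumn("features", filterFeature(col("features")))
  }
}
```

Usage would be FilterFeaturesTyped.filterFeatures(df, "ft"). Compared with tuples, the case class keeps meaningful field names, at the cost of declaring one field per struct member.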

I would prefer option a) in your case (it scales well to structs with many fields, since the schema is simply taken from the input column):

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

val schema = df.schema("features").dataType

val filterFeature = udf((features: Seq[Row]) => features.filter(_.getAs[String]("name").startsWith("ft")), schema)
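The schema above marks the array, its elements and name as nullable, and the one-liner throws a NullPointerException on any of those nulls. Below is a fuller, null-tolerant sketch of option a), runnable locally. FilterFeaturesWithSchema, keepPrefix, filterFeatures and the Feature test-data case class are hypothetical names. Note that Spark 3.0 disables the untyped udf(f, dataType) overload by default; the sketch sets spark.sql.legacy.allowUntypedScalaUDF so it also runs there, and the setting has no effect on older versions.

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.{col, udf}

// Test-data case class; field names match the question's schema (name, score).
case class Feature(name: String, score: Double)

object FilterFeaturesWithSchema {
  // Keeps struct elements whose "name" starts with `prefix`. A null array stays
  // null; null elements and null names are dropped instead of throwing.
  def keepPrefix(prefix: String): Seq[Row] => Seq[Row] = features =>
    if (features == null) null
    else features.filter { r =>
      r != null && Option(r.getAs[String]("name")).exists(_.startsWith(prefix))
    }

  def filterFeatures(df: DataFrame, prefix: String): DataFrame = {
    // Reuse the column's own type, array<struct<name,score>>, as the UDF's return type.
    val schema = df.schema("features").dataType
    val filterFeature = udf(keepPrefix(prefix), schema)
    df.withColumn("features", filterFeature(col("features")))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("filter-features")
      // Only needed on Spark 3.x, which disables udf(f, dataType) by default.
      .config("spark.sql.legacy.allowUntypedScalaUDF", "true")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(
      (10L, 9L, Seq(Feature("f1", 5.9), Feature("ft2", 6.0), Feature("ft3", 10.9))),
      (11L, 0L, Seq(Feature("f4", 0.9), Feature("ft1", 4.0), Feature("ft2", 0.9)))
    ).toDF("x", "y", "features")

    filterFeatures(df, "ft").show(false)
    spark.stop()
  }
}
```

Passing keepPrefix(prefix) as a function value rather than a lambda literal keeps the call on the udf(f: AnyRef, dataType: DataType) overload unambiguously.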