Warm tip: This article is reproduced from serverfault.com, please click

apache spark-具有窗口功能的PySpark数据偏度

(apache spark - PySpark data skewness with Window Functions)

发布于 2020-11-23 11:10:54

我有一个巨大的PySpark数据框,并且正在通过我的键定义的分区上执行一系列的Window函数。

密钥的问题是,我的分区因此而倾斜,并导致事件时间轴看起来像这样,

在此处输入图片说明

我知道在进行联接时可以使用加盐技术解决此问题。但是,当我使用Window函数时,如何解决此问题?

我在Window函数中使用了滞后,超前等函数。我不能使用盐键来完成该过程,因为会得到错误的结果。

在这种情况下如何解决偏度?

我正在寻找一种动态的方式来重新划分数据框而不会产生偏斜。

根据@jxc的答案进行更新

我尝试创建示例df,并尝试在其上运行代码,

df = pd.DataFrame()
df['id'] = np.random.randint(1, 1000, size=150000)
df['id'] = df['id'].map(lambda x: 100 if x % 2 == 0 else x)
df['timestamp'] = pd.date_range(start=pd.Timestamp('2020-01-01'), periods=len(df), freq='60s')
sdf = sc.createDataFrame(df)
sdf = sdf.withColumn("amt", F.rand()*100)
w = Window.partitionBy("id").orderBy("timestamp")

sdf = sdf.withColumn("new_col", F.lag("amt").over(w) + F.lead("amt").over(w))
x = sdf.toPandas()

这给了我一个这样的事件时间表,

在此处输入图片说明

我尝试了@jxc的答案中的代码,

sdf = sc.createDataFrame(df)
sdf = sdf.withColumn("amt", F.rand()*100)

N = 24*3600*365*2
sdf_1 = sdf.withColumn('pid', F.ceil(F.unix_timestamp('timestamp')/N))

w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')
w2 = Window.partitionBy('id', 'pid')

sdf_2 = sdf_1.select(
    '*',
    F.count('*').over(w2).alias('cnt'),
    F.row_number().over(w1).alias('rn'),
    (F.lag('amt',1).over(w1) + F.lead('amt',1).over(w1)).alias('new_val')
)

sdf_3 = sdf_2.filter('rn in (1, 2, cnt-1, cnt)') \
    .withColumn('new_val', F.lag('amt',1).over(w) + F.lead('amt',1).over(w)) \
    .filter('rn in (1,cnt)')
    
df_new = sdf_2.filter('rn not in (1,cnt)').union(sdf_3)

x = df_new.toPandas()

我结束了又一个阶段,事件时间表似乎更偏斜了,

在此处输入图片说明

此外,新代码还会增加运行时间

Questioner
Sreeram TP
Viewed
0
jxc 2020-12-10 05:08:09

要处理大分区,你可以尝试根据orderBy列(最有可能是数字列或可以转换为数字的日期/时间戳列)对它进行拆分,以便所有新的子分区都保持正确的行顺序。使用新的分区程序处理行,并使用laglead函数进行计算,仅需对子分区之间边界周围的行进行后处理。(下面还讨论了如何在任务2中合并小分区)

使用你的示例,sdf并假设我们具有以下WinSpec和简单的聚合函数:

w = Window.partitionBy('id').orderBy('timestamp')
df.withColumn('new_amt', F.lag('amt',1).over(w) + F.lead('amt',1).over(w))

任务1:分割大分区:

请尝试以下方法:

  1. 选择一个Ñ到分裂时间戳并设置一个附加partitionBy柱PID(使用ceilintfloor等等):

    # N to cover 35-days' intervals
    N = 24*3600*35
    df1 = sdf.withColumn('pid', F.ceil(F.unix_timestamp('timestamp')/N))
    
  2. 添加PID为partitionBy(见W1),然后calaulte row_number()lag()lead()W1在每个新分区中查找行数(cnt),以帮助识别分区的结尾(rn == cnt)。除每个分区边界上的那些行外,所得的new_val对于大多数行都适用。

    w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')
    w2 = Window.partitionBy('id', 'pid')
    
    df2 = df1.select(
        '*',
        F.count('*').over(w2).alias('cnt'),
        F.row_number().over(w1).alias('rn'),
        (F.lag('amt',1).over(w1) + F.lead('amt',1).over(w1)).alias('new_amt')
    )
    

    下面是df2显示边界行的示例

    在此处输入图片说明

  3. 处理边界:选择边界上的行rn in (1, cnt)以及具有计算中使用的值的行rn in (2, cnt-1)w执行相同的new_val计算,并仅保存边界行的结果。

    df3 = df2.filter('rn in (1, 2, cnt-1, cnt)') \
        .withColumn('new_amt', F.lag('amt',1).over(w) + F.lead('amt',1).over(w)) \
        .filter('rn in (1,cnt)')
    

    下面显示了上面df2生成的df3

    在此处输入图片说明

  4. df3合并df2以更新边界行rn in (1,cnt)

    df_new = df2.filter('rn not in (1,cnt)').union(df3)
    

    下面的屏幕截图显示了边界行周围的最终df_new

    在此处输入图片说明

    # drop columns which are used to implement logic only
    df_new = df_new.drop('cnt', 'rn')
    

一些注意事项:

  1. 定义了以下3个WindowSpec:

    w = Window.partitionBy('id').orderBy('timestamp')          <-- fix boundary rows
    w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')  <-- calculate internal rows
    w2 = Window.partitionBy('id', 'pid')                       <-- find #rows in a partition
    

    注意:严格来说,我们最好使用以下内容w来修复边界行,以避免timestamp围绕边界的问题。

    w = Window.partitionBy('id').orderBy('pid', 'rn')          <-- fix boundary rows
    
  2. 如果你知道哪些分区是歪斜的,则只需将它们分开,然后跳过其他分区即可。如果稀疏分布,现有方法可能会将一个小分区分成2个甚至更多个分区

    df1 = df.withColumn('pid', F.when(F.col('id').isin('a','b'), F.ceil(F.unix_timestamp('timestamp')/N)).otherwise(1))
    

    如果对于每个分区,你可以检索count(行数)和min_ts= min(timestamp),然后进行更动态的尝试pid(以下M是要拆分的阈值行数):

    F.expr(f"IF(count>{M}, ceil((unix_timestamp(timestamp)-unix_timestamp(min_ts))/{N}), 1)")
    

    注意:对于分区内的偏斜,将需要生成更复杂的函数pid

  3. 如果仅使用lag(1)函数,则仅对左边界进行后处理,过滤rn in (1, cnt)并仅更新rn == 1

    df3 = df1.filter('rn in (1, cnt)') \
        .withColumn('new_amt', F.lag('amt',1).over(w)) \
        .filter('rn = 1')
    

    当我们只需要确定正确的边界并更新时,它就类似于引导函数 rn == cnt

  4. 如果仅lag(2)使用,则使用过滤并更新更多行df3

    df3 = df1.filter('rn in (1, 2, cnt-1, cnt)') \
        .withColumn('new_amt', F.lag('amt',2).over(w)) \
        .filter('rn in (1,2)')
    

    可以将相同的方法与两个延伸到混合箱laglead具有不同的偏移。

任务2:合并小分区:

根据分区中的记录数count,我们可以设置一个阈值,M以便if count>Mid拥有其自己的分区,否则我们合并分区,以使#of total records小于M(以下方法的边缘情况为2*M-2)。

M = 20000

# create pandas df with columns `id`, `count` and `f`, sort rows so that rows with count>=M are located on top
d2 = pd.DataFrame([ e.asDict() for e in sdf.groupby('id').count().collect() ]) \
    .assign(f=lambda x: x['count'].lt(M)) \
    .sort_values('f')    

# add pid column to merge smaller partitions but the total row-count in partition should be less than or around M 
# potentially there could be at most `2*M-2` records for the same pid, to make sure strictly count<M, use a for-loop to iterate d1 and set pid:
d2['pid'] = (d2.mask(d2['count'].gt(M),M)['count'].shift(fill_value=0).cumsum()/M).astype(int)

# add pid to sdf. In case join is too heavy, try using Map
sdf_1 = sdf.join(spark.createDataFrame(d2).alias('d2'), ["id"]) \
    .select(sdf["*"], F.col("d2.pid"))

# check pid: # of records and # of distinct ids
sdf_1.groupby('pid').agg(F.count('*').alias('count'), F.countDistinct('id').alias('cnt_ids')).orderBy('pid').show()
+---+-----+-------+                                                             
|pid|count|cnt_ids|
+---+-----+-------+
|  0|74837|      1|
|  1|20036|    133|
|  2|20052|    134|
|  3|20010|    133|
|  4|15065|    100|
+---+-----+-------+

现在,应该单独用pid对新Window进行分区,并将id移到orderBy,如下所示:

w3 = Window.partitionBy('pid').orderBy('id','timestamp')

根据上述w3 WinSpec自定义滞后/超前函数,然后计算new_val

lag_w3  = lambda col,n=1: F.when(F.lag('id',n).over(w3) == F.col('id'), F.lag(col,n).over(w3))
lead_w3 = lambda col,n=1: F.when(F.lead('id',n).over(w3) == F.col('id'), F.lead(col,n).over(w3))

sdf_new = sdf_1.withColumn('new_val', lag_w3('amt',1) + lead_w3('amt',1))