2017-04-10 4 views
0

私は、特定の列のデータセットの文字列を置き換えようとしています。 1または0の場合は1、それ以外の場合は「Y」、それ以外の場合は0です。rddの代わりにpysparkのSQL関数

ラムダでデータフレームからrddへの変換を使用して、対象とする列を特定できましたが、処理に時間がかかります。

各列ごとにrddへの切り替えが実行され、次にdistinctが実行されますが、これはしばらく時間がかかります!

異なる結果セットに 'Y'が存在する場合、その列は変換を必要とするものとして識別されます。

誰もが私はどのようにpespark SQL関数を排他的に各列を切り替えることなく同じ結果を得るために使用することができますか?次のように

コードは、サンプルデータに、次のとおりです。

import pyspark.sql.types as typ 
    import pyspark.sql.functions as func 

    col_names = [ 
     ('ALIVE', typ.StringType()), 
     ('AGE', typ.IntegerType()), 
     ('CAGE', typ.IntegerType()), 
     ('CNT1', typ.IntegerType()), 
     ('CNT2', typ.IntegerType()), 
     ('CNT3', typ.IntegerType()), 
     ('HE', typ.IntegerType()), 
     ('WE', typ.IntegerType()), 
     ('WG', typ.IntegerType()), 
     ('DBP', typ.StringType()), 
     ('DBG', typ.StringType()), 
     ('HT1', typ.StringType()), 
     ('HT2', typ.StringType()), 
     ('PREV', typ.StringType()) 
     ] 

    schema = typ.StructType([typ.StructField(c[0], c[1], False) for c in col_names]) 
    df = spark.createDataFrame([('Y',22,56,4,3,65,180,198,18,'N','Y','N','N','N'), 
           ('N',38,79,3,4,63,155,167,12,'N','N','N','Y','N'), 
           ('Y',39,81,6,6,60,128,152,24,'N','N','N','N','Y')] 
           ,schema=schema) 

    cols = [(col.name, col.dataType) for col in df.schema] 

    transform_cols = [] 

    for s in cols: 
     if s[1] == typ.StringType(): 
     distinct_result = df.select(s[0]).distinct().rdd.map(lambda row: row[0]).collect() 
     if 'Y' in distinct_result: 
      transform_cols.append(s[0]) 

    print(transform_cols) 

出力は次のようになります。

['ALIVE', 'DBG', 'HT2', 'PREV'] 

答えて

1

私はタスクを実行するためにudfを使用することができました。まず、(ここでは私が最初の行流し読みするためにfunc.firstを使用)YまたはNでカラムを選ぶ:

cols_sel = df.select([func.first(col).alias(col) for col in df.columns]).collect()[0].asDict() 
cols = [col_name for (col_name, v) in cols_sel.items() if v in ['Y', 'N']] 
# return ['HT2', 'ALIVE', 'DBP', 'HT1', 'PREV', 'DBG'] 

次に、あなたは10YNをマッピングするためにudf関数を作成することができます。

def map_input(val): 
    map_dict = dict(zip(['Y', 'N'], [1, 0])) 
    return map_dict.get(val) 
udf_map_input = func.udf(map_input, returnType=typ.IntegerType()) 

for col in cols: 
    df = df.withColumn(col, udf_map_input(col)) 
df.show() 

最後に、列を合計することができます。私はその後、辞書に出力を変換し、0より大きい値を持つ列をチェック

out = df.select([func.sum(col).alias(col) for col in cols]).collect() 
out = out[0] 
print([col_name for (col_name, val) in out.asDict().items() if val > 0]) 

出力

['DBG', 'HT2', 'ALIVE', 'PREV'] 
+1

おかげで(すなわちYが含まれている)、それは必ずしも、より効率的ではないが、それは有用でした私がpysparkを初めて使うときには別の解決策を見てください。 – alortimor

+0

ようこそ!それが少し助けてくれることを願っています – titipata

関連する問題