diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.java
new file mode 100644
index 0000000000000..1467c2b70ef5f
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.java
@@ -0,0 +1,445 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.planner.calcite.FlinkRelBuilder;
+import org.apache.flink.table.planner.calcite.FlinkRelFactories;
+import org.apache.flink.table.planner.plan.utils.AggregateUtil;
+import org.apache.flink.table.planner.plan.utils.ExpandUtil;
+import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+/**
+ * This rule rewrites an aggregation query with grouping sets into a regular aggregation query with
+ * expand.
+ *
+ * <p>This rule duplicates the input data two or more times (once per group set, plus once for an
+ * optional non-distinct group). This puts considerable memory pressure on the aggregate and
+ * exchange operators involved.
+ *
+ * <p>This rule is applied to plans with grouping sets, or to plans with distinct aggregations
+ * after {@link FlinkAggregateExpandDistinctAggregatesRule} has been applied.
+ *
+ * <p>`FlinkAggregateExpandDistinctAggregatesRule` rewrites an aggregate query with distinct
+ * aggregations into an expanded double aggregation. The first aggregate has grouping sets in which
+ * the regular aggregation expressions and every distinct clause are aggregated in a separate
+ * group. The results are then combined in a second aggregate.
+ *
+ * <p>Examples:
+ *
+ * MyTable: a: INT, b: BIGINT, c: VARCHAR(32), d: VARCHAR(32)
+ *
+ * Original records:
+ * | a | b | c  | d  |
+ * |:-:|:-:|:--:|:--:|
+ * | 1 | 1 | c1 | d1 |
+ * | 1 | 2 | c1 | d2 |
+ * | 2 | 1 | c1 | d1 |
+ *
+ * Example1 (expand for DISTINCT aggregates):
+ *
+ * SQL: SELECT a, SUM(DISTINCT b) as t1, COUNT(DISTINCT c) as t2, COUNT(d) as t3 FROM MyTable GROUP
+ * BY a
+ *
+ * Logical plan:
+ * <pre>{@code
+ * LogicalAggregate(group=[{0}], t1=[SUM(DISTINCT $1)], t2=[COUNT(DISTINCT $2)], t3=[COUNT($3)])
+ *  LogicalTableScan(table=[[builtin, default, MyTable]])
+ * }</pre>
+ *
+ * Logical plan after `FlinkAggregateExpandDistinctAggregatesRule` is applied:
+ * <pre>{@code
+ * LogicalProject(a=[$0], t1=[$1], t2=[$2], t3=[CAST($3):BIGINT NOT NULL])
+ *  LogicalProject(a=[$0], t1=[$1], t2=[$2], $f3=[CASE(IS NOT NULL($3), $3, 0)])
+ *   LogicalAggregate(group=[{0}], t1=[SUM($1) FILTER $4], t2=[COUNT($2) FILTER $5],
+ *     t3=[MIN($3) FILTER $6])
+ *    LogicalProject(a=[$0], b=[$1], c=[$2], t3=[$3], $g_1=[=($4, 1)], $g_2=[=($4, 2)],
+ *      $g_3=[=($4, 3)])
+ *     LogicalAggregate(group=[{0, 1, 2}], groups=[[{0, 1}, {0, 2}, {0}]], t3=[COUNT($3)],
+ *       $g=[GROUPING($0, $1, $2)])
+ *      LogicalTableScan(table=[[builtin, default, MyTable]])
+ * }</pre>
+ *
+ * Logical plan after this rule is applied:
+ * <pre>{@code
+ * LogicalCalc(expr#0..3=[{inputs}], expr#4=[IS NOT NULL($t3)], ...)
+ *  LogicalAggregate(group=[{0}], t1=[SUM($1) FILTER $4], t2=[COUNT($2) FILTER $5],
+ *    t3=[MIN($3) FILTER $6])
+ *   LogicalCalc(expr#0..4=[{inputs}], ... expr#10=[CASE($t6, $t5, $t8, $t7, $t9)],
+ *      expr#11=[1], expr#12=[=($t10, $t11)], ... $g_1=[$t12], ...)
+ *    LogicalAggregate(group=[{0, 1, 2, 4}], groups=[[]], t3=[COUNT($3)])
+ *     LogicalExpand(projects=[{a=[$0], b=[$1], c=[null], d=[$3], $e=[1]},
+ *       {a=[$0], b=[null], c=[$2], d=[$3], $e=[2]}, {a=[$0], b=[null], c=[null], d=[$3], $e=[3]}])
+ *      LogicalTableSourceScan(table=[[builtin, default, MyTable]], fields=[a, b, c, d])
+ * }</pre>
+ *
+ * '$e = 1' is equivalent to 'group by a, b'; '$e = 2' is equivalent to 'group by a, c'; '$e = 3'
+ * is equivalent to 'group by a'.
+ *
+ * Expanded records:
+ *
+ * <pre>
+ * +-----+-----+-----+-----+-----+
+ * |  a  |  b  |  c  |  d  | $e  |
+ * +-----+-----+-----+-----+-----+                 ---+---
+ * |  1  |  1  | null| d1  |  1  |                    |
+ * +-----+-----+-----+-----+-----+                    |
+ * |  1  | null| c1  | d1  |  2  |   records expanded by record1
+ * +-----+-----+-----+-----+-----+                    |
+ * |  1  | null| null| d1  |  3  |                    |
+ * +-----+-----+-----+-----+-----+                 ---+---
+ * |  1  |  2  | null| d2  |  1  |                    |
+ * +-----+-----+-----+-----+-----+                    |
+ * |  1  | null| c1  | d2  |  2  |   records expanded by record2
+ * +-----+-----+-----+-----+-----+                    |
+ * |  1  | null| null| d2  |  3  |                    |
+ * +-----+-----+-----+-----+-----+                 ---+---
+ * |  2  |  1  | null| d1  |  1  |                    |
+ * +-----+-----+-----+-----+-----+                    |
+ * |  2  | null| c1  | d1  |  2  |   records expanded by record3
+ * +-----+-----+-----+-----+-----+                    |
+ * |  2  | null| null| d1  |  3  |                    |
+ * +-----+-----+-----+-----+-----+                 ---+---
+ * </pre>
+ *
+ * Example2 (Some fields are both in DISTINCT aggregates and non-DISTINCT aggregates):
+ *
+ * SQL: SELECT MAX(a) as t1, COUNT(DISTINCT a) as t2, COUNT(DISTINCT d) as t3 FROM MyTable
+ *
+ * Field `a` appears both in a DISTINCT aggregate and in the `MAX` aggregate, so `a` must be output
+ * as two separate fields: one for the `MAX` aggregate and another for the DISTINCT aggregate.
+ *
+ * Expanded records:
+ *
+ * <pre>
+ * +-----+-----+-----+-----+
+ * |  a  |  d  | $e  | a_0 |
+ * +-----+-----+-----+-----+                 ---+---
+ * |  1  | null|  1  |  1  |                    |
+ * +-----+-----+-----+-----+                    |
+ * | null| d1  |  2  |  1  |   records expanded by record1
+ * +-----+-----+-----+-----+                    |
+ * | null| null|  3  |  1  |                    |
+ * +-----+-----+-----+-----+                 ---+---
+ * |  1  | null|  1  |  1  |                    |
+ * +-----+-----+-----+-----+                    |
+ * | null| d2  |  2  |  1  |   records expanded by record2
+ * +-----+-----+-----+-----+                    |
+ * | null| null|  3  |  1  |                    |
+ * +-----+-----+-----+-----+                 ---+---
+ * |  2  | null|  1  |  2  |                    |
+ * +-----+-----+-----+-----+                    |
+ * | null| d1  |  2  |  2  |   records expanded by record3
+ * +-----+-----+-----+-----+                    |
+ * | null| null|  3  |  2  |                    |
+ * +-----+-----+-----+-----+                 ---+---
+ * </pre>
+ *
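+ * <p>The redirection of aggregate arguments onto such duplicated fields is driven by the duplicate
+ * field map that {@code ExpandUtil.buildExpandNode} returns. Below is a minimal, self-contained
+ * sketch of that remapping step; the class name and the concrete indexes (original `a` at $0,
+ * duplicate `a_0` at $3, as in Example2) are illustrative only and not the rule's actual code:
+ *
+ * <pre>{@code
+ * import java.util.Arrays;
+ * import java.util.List;
+ * import java.util.Map;
+ * import java.util.stream.Collectors;
+ *
+ * public class ArgRemapSketch {
+ *     public static void main(String[] args) {
+ *         // Expand nulls out `a` in some projections, so MAX(a) cannot read `a` directly;
+ *         // the never-nulled duplicate `a_0` is appended at index 3 instead.
+ *         Map<Integer, Integer> duplicateFieldMap = Map.of(0, 3);
+ *         List<Integer> maxArgs = Arrays.asList(0); // argument list of MAX(a)
+ *         List<Integer> remappedArgs =
+ *                 maxArgs.stream()
+ *                         .map(a -> duplicateFieldMap.getOrDefault(a, a))
+ *                         .collect(Collectors.toList());
+ *         System.out.println(remappedArgs); // [3] -> MAX now reads `a_0`
+ *     }
+ * }
+ * }</pre>
+ *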
+ * Example3 (expand for CUBE/ROLLUP/GROUPING SETS):
+ *
+ * SQL: SELECT a, c, SUM(b) as b FROM MyTable GROUP BY GROUPING SETS (a, c)
+ *
+ * Logical plan:
+ * <pre>{@code
+ * LogicalAggregate(group=[{0, 1}], groups=[[{0}, {1}]], b=[SUM($2)])
+ *  LogicalProject(a=[$0], c=[$2], b=[$1])
+ *   LogicalTableScan(table=[[builtin, default, MyTable]])
+ * }</pre>
+ *
+ * Logical plan after this rule is applied:
+ * <pre>{@code
+ * LogicalCalc(expr#0..3=[{inputs}], proj#0..1=[{exprs}], b=[$t3])
+ *  LogicalAggregate(group=[{0, 2, 3}], groups=[[]], b=[SUM($1)])
+ *   LogicalExpand(projects=[{a=[$0], b=[$1], c=[null], $e=[1]},
+ *     {a=[null], b=[$1], c=[$2], $e=[2]}])
+ *    LogicalNativeTableScan(table=[[builtin, default, MyTable]])
+ * }</pre>
+ *
+ * '$e = 1' is equivalent to 'group by a'; '$e = 2' is equivalent to 'group by c'.
+ *
+ * Expanded records:
+ *
+ * <pre>
+ * +-----+-----+-----+-----+
+ * |  a  |  b  |  c  | $e  |
+ * +-----+-----+-----+-----+                 ---+---
+ * |  1  |  1  | null|  1  |                    |
+ * +-----+-----+-----+-----+   records expanded by record1
+ * | null|  1  | c1  |  2  |                    |
+ * +-----+-----+-----+-----+                 ---+---
+ * |  1  |  2  | null|  1  |                    |
+ * +-----+-----+-----+-----+   records expanded by record2
+ * | null|  2  | c1  |  2  |                    |
+ * +-----+-----+-----+-----+                 ---+---
+ * |  2  |  1  | null|  1  |                    |
+ * +-----+-----+-----+-----+   records expanded by record3
+ * | null|  1  | c1  |  2  |                    |
+ * +-----+-----+-----+-----+                 ---+---
+ * </pre>
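+ *
+ * <p>For intuition, the literal that this rule substitutes for a GROUPING/GROUPING_ID or GROUP_ID
+ * call can be reproduced with the bit arithmetic sketched below (a self-contained illustration
+ * with made-up names; the actual logic lives in {@code lowerGroupExpr}):
+ *
+ * <pre>{@code
+ * import java.util.ArrayList;
+ * import java.util.Arrays;
+ * import java.util.HashSet;
+ * import java.util.List;
+ * import java.util.Set;
+ *
+ * public class GroupingLiteralSketch {
+ *     // GROUPING(args...): one bit per argument, most significant first;
+ *     // 0 if the field is grouped in the current grouping set, 1 if it was nulled out.
+ *     static long grouping(List<Integer> argList, Set<Integer> groupSet) {
+ *         long res = 0L;
+ *         for (Integer arg : argList) {
+ *             res = (res << 1) + (groupSet.contains(arg) ? 0L : 1L);
+ *         }
+ *         return res;
+ *     }
+ *
+ *     // GROUP_ID(): the i-th duplicate occurrence of a grouping set gets id i.
+ *     static long groupId(List<Set<Integer>> groupSets, int indexInGroupSets) {
+ *         List<Integer> duplicates = new ArrayList<>();
+ *         for (int i = 0; i < groupSets.size(); i++) {
+ *             if (groupSets.get(i).equals(groupSets.get(indexInGroupSets))) {
+ *                 duplicates.add(i);
+ *             }
+ *         }
+ *         return duplicates.indexOf(indexInGroupSets);
+ *     }
+ *
+ *     public static void main(String[] args) {
+ *         // $g in Example1 is GROUPING($0, $1, $2):
+ *         List<Integer> argList = Arrays.asList(0, 1, 2);
+ *         System.out.println(grouping(argList, new HashSet<>(Arrays.asList(0, 1)))); // {0,1} -> 1
+ *         System.out.println(grouping(argList, new HashSet<>(Arrays.asList(0, 2)))); // {0,2} -> 2
+ *         System.out.println(grouping(argList, new HashSet<>(Arrays.asList(0))));    // {0}   -> 3
+ *
+ *         // GROUPING SETS (deptno, (), ()) from the GROUP_ID note in lowerGroupExpr:
+ *         Set<Integer> deptno = new HashSet<>(Arrays.asList(0));
+ *         Set<Integer> empty1 = new HashSet<>();
+ *         Set<Integer> empty2 = new HashSet<>();
+ *         System.out.println(groupId(Arrays.asList(deptno, empty1, empty2), 1)); // first () -> 0
+ *         System.out.println(groupId(Arrays.asList(deptno, empty1, empty2), 2)); // second () -> 1
+ *     }
+ * }
+ * }</pre>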
+ */
+@Value.Enclosing
+public class DecomposeGroupingSetsRule
+        extends RelRule<DecomposeGroupingSetsRule.DecomposeGroupingSetsRuleConfig> {
+    public static final DecomposeGroupingSetsRule INSTANCE =
+            DecomposeGroupingSetsRule.DecomposeGroupingSetsRuleConfig.DEFAULT.toRule();
+
+    protected DecomposeGroupingSetsRule(DecomposeGroupingSetsRuleConfig config) {
+        super(config);
+    }
+
+    @Override
+    public boolean matches(RelOptRuleCall call) {
+        LogicalAggregate agg = call.rel(0);
+        List<Integer> groupIdExprs =
+                JavaScalaConversionUtil.toJava(
+                        AggregateUtil.getGroupIdExprIndexes(
+                                JavaScalaConversionUtil.toScala(agg.getAggCallList())));
+        return agg.getGroupSets().size() > 1 || !groupIdExprs.isEmpty();
+    }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        LogicalAggregate agg = call.rel(0);
+        // The Long data type is used to store the groupValue in
+        // FlinkAggregateExpandDistinctAggregatesRule, and the result of the grouping function
+        // is a positive value, so the max groupCount must be less than 64.
+        if (agg.getGroupCount() >= 64) {
+            throw new TableException("group count must be less than 64.");
+        }
+
+        RelNode aggInput = agg.getInput();
+        List<Integer> groupIdExprs =
+                JavaScalaConversionUtil.toJava(
+                        AggregateUtil.getGroupIdExprIndexes(
+                                JavaScalaConversionUtil.toScala(agg.getAggCallList())));
+        List<Tuple2<AggregateCall, Integer>> aggCallsWithIndexes =
+                IntStream.range(0, agg.getAggCallList().size())
+                        .mapToObj(i -> Tuple2.of(agg.getAggCallList().get(i), i))
+                        .collect(Collectors.toList());
+
+        RelOptCluster cluster = agg.getCluster();
+        RexBuilder rexBuilder = cluster.getRexBuilder();
+        boolean needExpand = agg.getGroupSets().size() > 1;
+
+        FlinkRelBuilder relBuilder = (FlinkRelBuilder) call.builder();
+        relBuilder.push(aggInput);
+
+        ImmutableBitSet newGroupSet;
+        Map<Integer, Integer> duplicateFieldMap;
+        if (needExpand) {
+            Tuple2<scala.collection.Map<Integer, Integer>, Integer> expandResult =
+                    JavaScalaConversionUtil.toJava(
+                            ExpandUtil.buildExpandNode(
+                                    relBuilder,
+                                    JavaScalaConversionUtil.toScala(agg.getAggCallList()),
+                                    agg.getGroupSet(),
+                                    agg.getGroupSets()));
+
+            // the new groupSet contains the original groupSet and the expand_id('$e') field
+            newGroupSet = agg.getGroupSet().union(ImmutableBitSet.of(expandResult.f1));
+            duplicateFieldMap = JavaScalaConversionUtil.toJava(expandResult.f0);
+        } else {
+            // no need to add an expand node; we only need to take care of the group functions
+            newGroupSet = agg.getGroupSet();
+            duplicateFieldMap = new HashMap<>();
+        }
+
+        int newGroupCount = newGroupSet.cardinality();
+        List<AggregateCall> newAggCalls =
+                aggCallsWithIndexes.stream()
+                        .filter(p -> !groupIdExprs.contains(p.f1))
+                        .map(
+                                p -> {
+                                    AggregateCall aggCall = p.f0;
+                                    List<Integer> newArgList =
+                                            aggCall.getArgList().stream()
+                                                    .map(a -> duplicateFieldMap.getOrDefault(a, a))
+                                                    .collect(Collectors.toList());
+                                    int newFilterArg =
+                                            duplicateFieldMap.getOrDefault(
+                                                    aggCall.filterArg, aggCall.filterArg);
+                                    return aggCall.adaptTo(
+                                            relBuilder.peek(),
+                                            newArgList,
+                                            newFilterArg,
+                                            agg.getGroupCount(),
+                                            newGroupCount);
+                                })
+                        .collect(Collectors.toList());
+
+        // create a simple aggregate
+        relBuilder.aggregate(relBuilder.groupKey(newGroupSet, List.of(newGroupSet)), newAggCalls);
+        RelNode newAgg = relBuilder.peek();
+
+        // create a project mapping to the original aggregate's output
+        // get the names of the original grouping fields
+        List<String> groupingFieldsName =
+                IntStream.range(0, agg.getGroupCount())
+                        .mapToObj(x -> agg.getRowType().getFieldNames().get(x))
+                        .collect(Collectors.toList());
+
+        // create field accesses for all original grouping fields
+        List<RexNode> groupingFields =
+                IntStream.range(0, agg.getGroupSet().cardinality())
+                        .mapToObj(idx -> rexBuilder.makeInputRef(newAgg, idx))
+                        .collect(Collectors.toList());
+
+        List<Tuple2<ImmutableBitSet, Integer>> groupSetsWithIndexes =
+                IntStream.range(0, agg.getGroupSets().size())
+                        .mapToObj(i -> Tuple2.of(agg.getGroupSets().get(i), i))
+                        .collect(Collectors.toList());
+        // output aggregate calls, including `normal` agg calls and grouping agg calls
+        int aggCnt = 0;
+        List<RexNode> aggFields = new ArrayList<>();
+        for (Tuple2<AggregateCall, Integer> aggCallWithIndex : aggCallsWithIndexes) {
+            AggregateCall aggCall = aggCallWithIndex.f0;
+            int idx = aggCallWithIndex.f1;
+            if (groupIdExprs.contains(idx)) {
+                if (needExpand) {
+                    // reference to the expand_id('$e') field in the new aggregate
+                    int expandIdIdxInNewAgg = newGroupCount - 1;
+                    RexInputRef expandIdField =
+                            rexBuilder.makeInputRef(newAgg, expandIdIdxInNewAgg);
+                    // create a CASE WHEN for the group expression
+                    List<RexNode> whenThenElse =
+                            groupSetsWithIndexes.stream()
+                                    .flatMap(
+                                            tuple -> {
+                                                int i = tuple.f1;
+                                                RexNode groupExpr =
+                                                        lowerGroupExpr(
+                                                                rexBuilder,
+                                                                aggCall,
+                                                                groupSetsWithIndexes,
+                                                                i);
+                                                if (i < agg.getGroupSets().size() - 1) {
+                                                    // WHEN/THEN
+                                                    long expandIdVal =
+                                                            ExpandUtil.genExpandId(
+                                                                    agg.getGroupSet(), tuple.f0);
+                                                    RelDataType expandIdType =
+                                                            newAgg.getRowType()
+                                                                    .getFieldList()
+                                                                    .get(expandIdIdxInNewAgg)
+                                                                    .getType();
+                                                    RexNode expandIdLit =
+                                                            rexBuilder.makeLiteral(
+                                                                    expandIdVal,
+                                                                    expandIdType,
+                                                                    false);
+                                                    return Stream.of(
+                                                            // when $e = $e_value
+                                                            rexBuilder.makeCall(
+                                                                    SqlStdOperatorTable.EQUALS,
+                                                                    expandIdField,
+                                                                    expandIdLit),
+                                                            // then return the group expression
+                                                            // literal value
+                                                            groupExpr);
+                                                } else {
+                                                    // ELSE
+                                                    return Stream.of(
+                                                            // else return the group expression
+                                                            // literal value
+                                                            groupExpr);
+                                                }
+                                            })
+                                    .collect(Collectors.toList());
+                    aggFields.add(rexBuilder.makeCall(SqlStdOperatorTable.CASE, whenThenElse));
+                } else {
+                    // create a literal for the group expression
+                    aggFields.add(lowerGroupExpr(rexBuilder, aggCall, groupSetsWithIndexes, 0));
+                }
+            } else {
+                // create an access to the aggregation result
+                RexInputRef aggResult = rexBuilder.makeInputRef(newAgg, newGroupCount + aggCnt);
+                aggCnt += 1;
+                aggFields.add(aggResult);
+            }
+        }
+
+        // add a projection to establish the result schema and set the values of the group
+        // expressions.
+        RelNode project =
+                relBuilder
+                        .project(
+                                Stream.concat(groupingFields.stream(), aggFields.stream())
+                                        .collect(Collectors.toList()),
+                                Stream.concat(
+                                                groupingFieldsName.stream(),
+                                                agg.getAggCallList().stream()
+                                                        .map(AggregateCall::getName))
+                                        .collect(Collectors.toList()))
+                        .convert(agg.getRowType(), true)
+                        .build();
+
+        call.transformTo(project);
+    }
+
+    /** Returns a literal for a given group expression. */
+    private RexNode lowerGroupExpr(
+            RexBuilder builder,
+            AggregateCall call,
+            List<Tuple2<ImmutableBitSet, Integer>> groupSetsWithIndexes,
+            int indexInGroupSets) {
+
+        ImmutableBitSet groupSet = groupSetsWithIndexes.get(indexInGroupSets).f0;
+        Set<Integer> groups = groupSet.asSet();
+
+        switch (call.getAggregation().getKind()) {
+            case GROUP_ID:
+                // https://issues.apache.org/jira/browse/CALCITE-1824
+                // GROUP_ID is not in the SQL standard. It is implemented only by Oracle.
+                // GROUP_ID is useful only if you have duplicate grouping sets:
+                // if the grouping sets are distinct, GROUP_ID() always returns zero;
+                // otherwise it returns the index within the duplicate grouping sets.
+                // e.g. SELECT deptno, GROUP_ID() AS g FROM Emp GROUP BY GROUPING SETS (deptno, (),
+                // ())
+                // As you can see, the grouping set () occurs twice.
+                // So there is one row in the result for each occurrence:
+                // the first occurrence has g = 0; the second has g = 1.
+
+                List<Integer> duplicateGroupSetsIndices =
+                        groupSetsWithIndexes.stream()
+                                .filter(p -> p.f0.compareTo(groupSet) == 0)
+                                .map(tuple2 -> tuple2.f1)
+                                .collect(Collectors.toList());
+                Preconditions.checkArgument(
+                        !duplicateGroupSetsIndices.isEmpty(), "requirement failed");
+                long id = duplicateGroupSetsIndices.indexOf(indexInGroupSets);
+                return builder.makeLiteral(id, call.getType(), false);
+            case GROUPING:
+            case GROUPING_ID:
+                // The GROUPING function is defined in the SQL standard,
+                // but its definition differs between Oracle and the SQL standard:
+                // https://docs.oracle.com/cd/B28359_01/server.111/b28286/functions064.htm#SQLRF00647
+                //
+                // The GROUPING_ID function is not defined in the SQL standard, but has the same
+                // functionality as the GROUPING function in Calcite.
+                // Our implementation of the GROUPING_ID function is consistent with Oracle.
+                //
+                // NOTES:
+                // In Calcite, the javadoc of SqlGroupingFunction is not consistent with agg.iq.
+                long res = 0L;
+                for (Integer arg : call.getArgList()) {
+                    res = (res << 1L) + (groups.contains(arg) ? 0L : 1L);
+                }
+                return builder.makeLiteral(res, call.getType(), false);
+            default:
+                return builder.makeNullLiteral(call.getType());
+        }
+    }
+
+    /** Rule configuration. */
+    @Value.Immutable(singleton = false)
+    public interface DecomposeGroupingSetsRuleConfig extends RelRule.Config {
+        DecomposeGroupingSetsRule.DecomposeGroupingSetsRuleConfig DEFAULT =
+                ImmutableDecomposeGroupingSetsRule.DecomposeGroupingSetsRuleConfig.builder()
+                        .operandSupplier(b0 -> b0.operand(LogicalAggregate.class).anyInputs())
+                        .relBuilderFactory(FlinkRelFactories.FLINK_REL_BUILDER())
+                        .description("DecomposeGroupingSetsRule")
+                        .build();
+
+        @Override
+        default DecomposeGroupingSetsRule toRule() {
+            return new DecomposeGroupingSetsRule(this);
+        }
+    }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.scala
deleted file mode 100644
index 140c1d7c92c40..0000000000000
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.scala
+++ /dev/null
@@ -1,338 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ -package org.apache.flink.table.planner.plan.rules.logical - -import org.apache.flink.table.api.TableException -import org.apache.flink.table.planner.calcite.{FlinkRelBuilder, FlinkRelFactories} -import org.apache.flink.table.planner.plan.utils.{AggregateUtil, ExpandUtil} - -import com.google.common.collect.ImmutableList -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} -import org.apache.calcite.plan.RelOptRule._ -import org.apache.calcite.rel.core.AggregateCall -import org.apache.calcite.rel.logical.LogicalAggregate -import org.apache.calcite.rex.{RexBuilder, RexNode} -import org.apache.calcite.sql.SqlKind -import org.apache.calcite.sql.fun.SqlStdOperatorTable -import org.apache.calcite.util.ImmutableBitSet - -import scala.collection.JavaConversions._ - -/** - * This rule rewrites an aggregation query with grouping sets into an regular aggregation query with - * expand. - * - * This rule duplicates the input data by two or more times (# number of groupSets + an optional - * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and - * exchange operators. - * - * This rule will be used for the plan with grouping sets or the plan with distinct aggregations - * after [[FlinkAggregateExpandDistinctAggregatesRule]] applied. - * - * `FlinkAggregateExpandDistinctAggregatesRule` rewrites an aggregate query with distinct - * aggregations into an expanded double aggregation. The first aggregate has grouping sets in which - * the regular aggregation expressions and every distinct clause are aggregated in a separate group. - * The results are then combined in a second aggregate. - * - * Examples: - * - * MyTable: a: INT, b: BIGINT, c: VARCHAR(32), d: VARCHAR(32) - * - * Original records: - * | a | b | c | d | - * |:-:|:-:|:--:|:--:| - * | 1 | 1 | c1 | d1 | - * | 1 | 2 | c1 | d2 | - * | 2 | 1 | c1 | d1 | - * - * Example1 (expand for DISTINCT aggregates): - * - * SQL: SELECT a, SUM(DISTINCT b) as t1, COUNT(DISTINCT c) as t2, COUNT(d) as t3 FROM MyTable GROUP - * BY a - * - * Logical plan: - * {{{ - * LogicalAggregate(group=[{0}], t1=[SUM(DISTINCT $1)], t2=[COUNT(DISTINCT $2)], t3=[COUNT($3)]) - * LogicalTableScan(table=[[builtin, default, MyTable]]) - * }}} - * - * Logical plan after `FlinkAggregateExpandDistinctAggregatesRule` applied: - * {{{ - * LogicalProject(a=[$0], t1=[$1], t2=[$2], t3=[CAST($3):BIGINT NOT NULL]) - * LogicalProject(a=[$0], t1=[$1], t2=[$2], $f3=[CASE(IS NOT NULL($3), $3, 0)]) - * LogicalAggregate(group=[{0}], t1=[SUM($1) FILTER $4], t2=[COUNT($2) FILTER $5], - * t3=[MIN($3) FILTER $6]) - * LogicalProject(a=[$0], b=[$1], c=[$2], t3=[$3], $g_1=[=($4, 1)], $g_2=[=($4, 2)], - * $g_3=[=($4, 3)]) - * LogicalAggregate(group=[{0, 1, 2}], groups=[[{0, 1}, {0, 2}, {0}]], t3=[COUNT($3)], - * $g=[GROUPING($0, $1, $2)]) - * LogicalTableScan(table=[[builtin, default, MyTable]]) - * }}} - * - * Logical plan after this rule applied: - * {{{ - * LogicalCalc(expr#0..3=[{inputs}], expr#4=[IS NOT NULL($t3)], ...) - * LogicalAggregate(group=[{0}], t1=[SUM($1) FILTER $4], t2=[COUNT($2) FILTER $5], - * t3=[MIN($3) FILTER $6]) - * LogicalCalc(expr#0..4=[{inputs}], ... expr#10=[CASE($t6, $t5, $t8, $t7, $t9)], - * expr#11=[1], expr#12=[=($t10, $t11)], ... $g_1=[$t12], ...) 
- * LogicalAggregate(group=[{0, 1, 2, 4}], groups=[[]], t3=[COUNT($3)]) - * LogicalExpand(projects=[{a=[$0], b=[$1], c=[null], d=[$3], $e=[1]}, - * {a=[$0], b=[null], c=[$2], d=[$3], $e=[2]}, {a=[$0], b=[null], c=[null], d=[$3], $e=[3]}]) - * LogicalTableSourceScan(table=[[builtin, default, MyTable]], fields=[a, b, c, d]) - * }}} - * - * '$e = 1' is equivalent to 'group by a, b' '$e = 2' is equivalent to 'group by a, c' '$e = 3' is - * equivalent to 'group by a' - * - * Expanded records: \+-----+-----+-----+-----+-----+ \| a | b | c | d | $e | - * \+-----+-----+-----+-----+-----+ ---+--- \| 1 | 1 | null| d1 | 1 | | - * \+-----+-----+-----+-----+-----+ | \| 1 | null| c1 | d1 | 2 | records expanded by record1 - * \+-----+-----+-----+-----+-----+ | \| 1 | null| null| d1 | 3 | | \+-----+-----+-----+-----+-----+ - * ---+--- \| 1 | 2 | null| d2 | 1 | | \+-----+-----+-----+-----+-----+ | \| 1 | null| c1 | d2 | 2 | - * records expanded by record2 \+-----+-----+-----+-----+-----+ | \| 1 | null| null| d2 | 3 | | - * \+-----+-----+-----+-----+-----+ ---+--- \| 2 | 1 | null| d1 | 1 | | - * \+-----+-----+-----+-----+-----+ | \| 2 | null| c1 | d1 | 2 | records expanded by record3 - * \+-----+-----+-----+-----+-----+ | \| 2 | null| null| d1 | 3 | | \+-----+-----+-----+-----+-----+ - * ---+--- - * - * Example2 (Some fields are both in DISTINCT aggregates and non-DISTINCT aggregates): - * - * SQL: SELECT MAX(a) as t1, COUNT(DISTINCT a) as t2, count(DISTINCT d) as t3 FROM MyTable - * - * Field `a` is both in DISTINCT aggregate and `MAX` aggregate, so, `a` should be outputted as two - * individual fields, one is for `MAX` aggregate, another is for DISTINCT aggregate. - * - * Expanded records: \+-----+-----+-----+-----+ \| a | d | $e | a_0 | \+-----+-----+-----+-----+ - * ---+--- \| 1 | null| 1 | 1 | | \+-----+-----+-----+-----+ | \| null| d1 | 2 | 1 | records - * expanded by record1 \+-----+-----+-----+-----+ | \| null| null| 3 | 1 | | - * \+-----+-----+-----+-----+ ---+--- \| 1 | null| 1 | 1 | | \+-----+-----+-----+-----+ | \| null| - * d2 | 2 | 1 | records expanded by record2 \+-----+-----+-----+-----+ | \| null| null| 3 | 1 | | - * \+-----+-----+-----+-----+ ---+--- \| 2 | null| 1 | 2 | | \+-----+-----+-----+-----+ | \| null| - * d1 | 2 | 2 | records expanded by record3 \+-----+-----+-----+-----+ | \| null| null| 3 | 2 | | - * \+-----+-----+-----+-----+ ---+--- - * - * Example3 (expand for CUBE/ROLLUP/GROUPING SETS): - * - * SQL: SELECT a, c, SUM(b) as b FROM MyTable GROUP BY GROUPING SETS (a, c) - * - * Logical plan: - * {{{ - * LogicalAggregate(group=[{0, 1}], groups=[[{0}, {1}]], b=[SUM($2)]) - * LogicalProject(a=[$0], c=[$2], b=[$1]) - * LogicalTableScan(table=[[builtin, default, MyTable]]) - * }}} - * - * Logical plan after this rule applied: - * {{{ - * LogicalCalc(expr#0..3=[{inputs}], proj#0..1=[{exprs}], b=[$t3]) - * LogicalAggregate(group=[{0, 2, 3}], groups=[[]], b=[SUM($1)]) - * LogicalExpand(projects=[{a=[$0], b=[$1], c=[null], $e=[1]}, - * {a=[null], b=[$1], c=[$2], $e=[2]}]) - * LogicalNativeTableScan(table=[[builtin, default, MyTable]]) - * }}} - * - * '$e = 1' is equivalent to 'group by a' '$e = 2' is equivalent to 'group by c' - * - * Expanded records: \+-----+-----+-----+-----+ \| a | b | c | $e | \+-----+-----+-----+-----+ - * ---+--- \| 1 | 1 | null| 1 | | \+-----+-----+-----+-----+ records expanded by record1 \| null| 1 - * \| c1 | 2 | | \+-----+-----+-----+-----+ ---+--- \| 1 | 2 | null| 1 | | - * \+-----+-----+-----+-----+ records expanded by record2 \| null| 2 | c1 | 2 | | - * 
\+-----+-----+-----+-----+ ---+--- \| 2 | 1 | null| 1 | | \+-----+-----+-----+-----+ records - * expanded by record3 \| null| 1 | c1 | 2 | | \+-----+-----+-----+-----+ ---+--- - */ -class DecomposeGroupingSetsRule - extends RelOptRule( - operand(classOf[LogicalAggregate], any), - FlinkRelFactories.FLINK_REL_BUILDER, - "DecomposeGroupingSetsRule") { - - override def matches(call: RelOptRuleCall): Boolean = { - val agg: LogicalAggregate = call.rel(0) - val groupIdExprs = AggregateUtil.getGroupIdExprIndexes(agg.getAggCallList) - agg.getGroupSets.size() > 1 || groupIdExprs.nonEmpty - } - - override def onMatch(call: RelOptRuleCall): Unit = { - val agg: LogicalAggregate = call.rel(0) - // Long data type is used to store groupValue in FlinkAggregateExpandDistinctAggregatesRule, - // and the result of grouping function is a positive value, - // so the max groupCount must be less than 64. - if (agg.getGroupCount >= 64) { - throw new TableException("group count must be less than 64.") - } - - val aggInput = agg.getInput - val groupIdExprs = AggregateUtil.getGroupIdExprIndexes(agg.getAggCallList) - val aggCallsWithIndexes = agg.getAggCallList.zipWithIndex - - val cluster = agg.getCluster - val rexBuilder = cluster.getRexBuilder - val needExpand = agg.getGroupSets.size() > 1 - - val relBuilder = call.builder().asInstanceOf[FlinkRelBuilder] - relBuilder.push(aggInput) - - val (newGroupSet, duplicateFieldMap) = if (needExpand) { - val (duplicateFieldMap, expandIdIdxInExpand) = ExpandUtil.buildExpandNode( - relBuilder, - agg.getAggCallList, - agg.getGroupSet, - agg.getGroupSets) - - // new groupSet contains original groupSet and expand_id('$e') field - val newGroupSet = agg.getGroupSet.union(ImmutableBitSet.of(expandIdIdxInExpand)) - - (newGroupSet, duplicateFieldMap) - } else { - // no need add expand node, only need care about group functions - (agg.getGroupSet, Map.empty[Integer, Integer]) - } - - val newGroupCount = newGroupSet.cardinality() - val newAggCalls = aggCallsWithIndexes.collect { - case (aggCall, idx) if !groupIdExprs.contains(idx) => - val newArgList = aggCall.getArgList.map(a => duplicateFieldMap.getOrElse(a, a)).toList - val newFilterArg = duplicateFieldMap.getOrDefault(aggCall.filterArg, aggCall.filterArg) - aggCall.adaptTo( - relBuilder.peek(), - newArgList, - newFilterArg, - agg.getGroupCount, - newGroupCount) - } - - // create simple aggregate - relBuilder.aggregate( - relBuilder.groupKey(newGroupSet, ImmutableList.of[ImmutableBitSet](newGroupSet)), - newAggCalls) - val newAgg = relBuilder.peek() - - // create a project to mapping original aggregate's output - // get names of original grouping fields - val groupingFieldsName = Seq - .range(0, agg.getGroupCount) - .map(x => agg.getRowType.getFieldNames.get(x)) - - // create field access for all original grouping fields - val groupingFields = agg.getGroupSet.toList.zipWithIndex - .map { case (_, idx) => rexBuilder.makeInputRef(newAgg, idx) } - .toArray[RexNode] - - val groupSetsWithIndexes = agg.getGroupSets.zipWithIndex - // output aggregate calls including `normal` agg call and grouping agg call - var aggCnt = 0 - val aggFields = aggCallsWithIndexes.map { - case (aggCall, idx) if groupIdExprs.contains(idx) => - if (needExpand) { - // reference to expand_id('$e') field in new aggregate - val expandIdIdxInNewAgg = newGroupCount - 1 - val expandIdField = rexBuilder.makeInputRef(newAgg, expandIdIdxInNewAgg) - // create case when for group expression - val whenThenElse = groupSetsWithIndexes.flatMap { - case (subGroupSet, i) => - val 
groupExpr = lowerGroupExpr(rexBuilder, aggCall, groupSetsWithIndexes, i) - if (i < agg.getGroupSets.size() - 1) { - // WHEN/THEN - val expandIdVal = ExpandUtil.genExpandId(agg.getGroupSet, subGroupSet) - val expandIdType = newAgg.getRowType.getFieldList.get(expandIdIdxInNewAgg).getType - val expandIdLit = rexBuilder.makeLiteral(expandIdVal, expandIdType, false) - Seq( - // when $e = $e_value - rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, expandIdField, expandIdLit), - // then return group expression literal value - groupExpr - ) - } else { - // ELSE - Seq( - // else return group expression literal value - groupExpr - ) - } - } - rexBuilder.makeCall(SqlStdOperatorTable.CASE, whenThenElse) - } else { - // create literal for group expression - lowerGroupExpr(rexBuilder, aggCall, groupSetsWithIndexes, 0) - } - case _ => - // create access to aggregation result - val aggResult = rexBuilder.makeInputRef(newAgg, newGroupCount + aggCnt) - aggCnt += 1 - aggResult - } - - // add a projection to establish the result schema and set the values of the group expressions. - relBuilder.project( - groupingFields.toSeq ++ aggFields, - groupingFieldsName ++ agg.getAggCallList.map(_.name)) - relBuilder.convert(agg.getRowType, true) - - call.transformTo(relBuilder.build()) - } - - /** Returns a literal for a given group expression. */ - private def lowerGroupExpr( - builder: RexBuilder, - call: AggregateCall, - groupSetsWithIndexes: Seq[(ImmutableBitSet, Int)], - indexInGroupSets: Int): RexNode = { - - val groupSet = groupSetsWithIndexes(indexInGroupSets)._1 - val groups = groupSet.asSet() - call.getAggregation.getKind match { - case SqlKind.GROUP_ID => - // https://issues.apache.org/jira/browse/CALCITE-1824 - // GROUP_ID is not in the SQL standard. It is implemented only by Oracle. - // GROUP_ID is useful only if you have duplicate grouping sets, - // If grouping sets are distinct, GROUP_ID() will always return zero; - // Else return the index in the duplicate grouping sets. - // e.g. SELECT deptno, GROUP_ID() AS g FROM Emp GROUP BY GROUPING SETS (deptno, (), ()) - // As you can see, the grouping set () occurs twice. - // So there is one row in the result for each occurrence: - // the first occurrence has g = 0; the second has g = 1. - val duplicateGroupSetsIndices = groupSetsWithIndexes - .filter { case (gs, _) => gs.compareTo(groupSet) == 0 } - .map(_._2) - .toArray[Int] - require(duplicateGroupSetsIndices.nonEmpty) - val id: Long = duplicateGroupSetsIndices.indexOf(indexInGroupSets) - builder.makeLiteral(id, call.getType, false) - case SqlKind.GROUPING | SqlKind.GROUPING_ID => - // GROUPING function is defined in the SQL standard, - // but the definition of GROUPING is different from in Oracle and in SQL standard: - // https://docs.oracle.com/cd/B28359_01/server.111/b28286/functions064.htm#SQLRF00647 - // - // GROUPING_ID function is not defined in the SQL standard, and has the same - // functionality with GROUPING function in Calcite. - // our implementation is consistent with Oracle about GROUPING_ID function. - // - // NOTES: - // In Calcite, the java-document of SqlGroupingFunction is not consistent with agg.iq. 
- val res: Long = call.getArgList.foldLeft(0L)( - (res, arg) => (res << 1L) + (if (groups.contains(arg)) 0L else 1L)) - builder.makeLiteral(res, call.getType, false) - case _ => builder.makeNullLiteral(call.getType) - } - } -} - -object DecomposeGroupingSetsRule { - val INSTANCE: RelOptRule = new DecomposeGroupingSetsRule -} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/JavaScalaConversionUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/JavaScalaConversionUtil.scala index 84ea3d15d8d60..8e89af81f5ee4 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/JavaScalaConversionUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/JavaScalaConversionUtil.scala @@ -19,7 +19,7 @@ package org.apache.flink.table.planner.utils import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} -import java.util.{List => JList, Optional, Set => JSet} +import java.util.{List => JList, Map => JMap, Optional, Set => JSet} import java.util.function.{BiConsumer, Consumer, Function} import scala.collection.JavaConverters._ @@ -59,4 +59,10 @@ object JavaScalaConversionUtil { def toJava(set: Set[Int]): JSet[Integer] = set.map(_.asInstanceOf[Integer]).asJava + + def toJava[K, V](map: scala.collection.Map[K, V]): JMap[K, V] = + map.asJava + + def toScala[K, V](map: JMap[K, V]): scala.collection.Map[K, V] = + map.asScala }