[Enhancement] support map type in UDAF (#60840)
Signed-off-by: yan zhang <dirtysalt1987@gmail.com>
This commit is contained in:
parent
96b2c68575
commit
f5f8e9bc2c
|
|
@ -452,6 +452,7 @@ public:
|
|||
helper.get_result_from_boxed_array(ctx, type, output.get(), res, batch_size);
|
||||
} else {
|
||||
helper.get_result_from_boxed_array(ctx, type, to, res, batch_size);
|
||||
(void)ColumnHelper::update_nested_has_null(to);
|
||||
down_cast<NullableColumn*>(to)->update_has_null();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>com.starrocks.udf</groupId>
|
||||
<artifactId>java-udf</artifactId>
|
||||
<version>1.0-SNAPSHOT</version>
|
||||
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<java.version>17</java.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-math3 -->
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-math3</artifactId>
|
||||
<version>3.6.1</version>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<version>3.3</version>
|
||||
<configuration>
|
||||
<source>${java.version}</source>
|
||||
<target>${java.version}</target>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-assembly-plugin</artifactId>
|
||||
<configuration>
|
||||
<descriptorRefs>
|
||||
<descriptorRef>jar-with-dependencies</descriptorRef>
|
||||
</descriptorRefs>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>make-assembly</id>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>single</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
package com.starrocks.udf;
|
||||
|
||||
import org.apache.commons.math3.special.Erf;
|
||||
|
||||
public class NormalCdf {
|
||||
public final Double evaluate(Double mean, Double standardDeviation, Double value) {
|
||||
if (mean == null || standardDeviation == null || value == null) {
|
||||
return null;
|
||||
}
|
||||
return 0.5 * (1 + Erf.erf((value - mean) / (standardDeviation * Math.sqrt(2))));
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
package com.starrocks.udf;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.util.Map;
|
||||
import java.util.function.BinaryOperator;
|
||||
|
||||
public class SumMap<K, T> {
|
||||
BinaryOperator<T> sumFunction;
|
||||
|
||||
public SumMap(BinaryOperator<T> sumFunction) {
|
||||
this.sumFunction = sumFunction;
|
||||
}
|
||||
|
||||
public static class State<K, T> {
|
||||
Map<K, T> values = new java.util.HashMap<>();
|
||||
|
||||
public byte[] serialize() throws IOException {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
ObjectOutputStream oos = new ObjectOutputStream(baos);
|
||||
for (Map.Entry<K, T> entry : values.entrySet()) {
|
||||
oos.writeObject(entry.getKey());
|
||||
oos.writeObject(entry.getValue());
|
||||
}
|
||||
oos.flush();
|
||||
oos.close();
|
||||
return baos.toByteArray();
|
||||
}
|
||||
|
||||
public int serializeLength() throws IOException {
|
||||
byte[] serializedData = serialize();
|
||||
return serializedData.length;
|
||||
}
|
||||
|
||||
public void serialize(java.nio.ByteBuffer buff) throws IOException {
|
||||
byte[] serializedData = serialize();
|
||||
buff.put(serializedData);
|
||||
}
|
||||
|
||||
public void deserialize(java.nio.ByteBuffer buff) throws IOException, ClassNotFoundException {
|
||||
int length = buff.remaining();
|
||||
byte[] data = new byte[length];
|
||||
buff.get(data);
|
||||
try (java.io.ObjectInputStream ois = new java.io.ObjectInputStream(new java.io.ByteArrayInputStream(data))) {
|
||||
while (true) {
|
||||
try {
|
||||
K key = (K) ois.readObject();
|
||||
T value = (T) ois.readObject();
|
||||
values.put(key, value);
|
||||
} catch (java.io.EOFException eof) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void destroy(State state) {
|
||||
}
|
||||
|
||||
public final void update(State state, Map<K, T> val) {
|
||||
if (val != null) {
|
||||
for (Map.Entry<K, T> entry : val.entrySet()) {
|
||||
K key = entry.getKey();
|
||||
T value = entry.getValue();
|
||||
state.values.merge(key, value, sumFunction);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void serialize(State state, java.nio.ByteBuffer buff) throws IOException {
|
||||
state.serialize(buff);
|
||||
}
|
||||
|
||||
public void merge(State state, java.nio.ByteBuffer buffer) throws IOException, ClassNotFoundException {
|
||||
State<K, T> oldState = new State();
|
||||
oldState.deserialize(buffer);
|
||||
for (Map.Entry<K, T> entry : oldState.values.entrySet()) {
|
||||
K key = entry.getKey();
|
||||
T value = entry.getValue();
|
||||
state.values.merge(key, value, sumFunction);
|
||||
}
|
||||
}
|
||||
|
||||
public Map<Object, T> finalize(State state) {
|
||||
return state.values;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
package com.starrocks.udf;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
|
||||
public class SumMapInt64 extends SumMap<Object, Long> {
|
||||
public SumMapInt64() {
|
||||
super(Long::sum);
|
||||
}
|
||||
|
||||
public static class State extends SumMap.State<Object, Long> {
|
||||
public int serializeLength() throws IOException {
|
||||
return super.serializeLength();
|
||||
}
|
||||
}
|
||||
|
||||
public State create() {
|
||||
return new State();
|
||||
}
|
||||
|
||||
public void destroy(State state) {
|
||||
super.destroy(state);
|
||||
}
|
||||
|
||||
// only support scalar type!!!
|
||||
public final void update(State state, Map<Object, Long> val) {
|
||||
super.update(state, val);
|
||||
}
|
||||
|
||||
public void serialize(State state, java.nio.ByteBuffer buff) throws IOException {
|
||||
super.serialize(state, buff);
|
||||
}
|
||||
|
||||
public void merge(State state, java.nio.ByteBuffer buffer) throws IOException, ClassNotFoundException {
|
||||
super.merge(state, buffer);
|
||||
}
|
||||
|
||||
public Map<Object, Long> finalize(State state) {
|
||||
return super.finalize(state);
|
||||
}
|
||||
}
|
||||
|
|
@ -179,7 +179,7 @@ public class CreateFunctionAnalyzer {
|
|||
} else {
|
||||
createdFunction = analyzeStarrocksJarUdtf(stmt, checksum, handleClass);
|
||||
}
|
||||
|
||||
|
||||
stmt.setFunction(createdFunction);
|
||||
}
|
||||
|
||||
|
|
@ -198,7 +198,7 @@ public class CreateFunctionAnalyzer {
|
|||
}
|
||||
|
||||
private Function analyzeStarrocksJarUdf(CreateFunctionStmt stmt, String checksum,
|
||||
JavaUDFInternalClass handleClass) {
|
||||
JavaUDFInternalClass handleClass) {
|
||||
checkStarrocksJarUdfClass(stmt, handleClass);
|
||||
|
||||
FunctionName functionName = stmt.getFunctionName();
|
||||
|
|
@ -304,8 +304,8 @@ public class CreateFunctionAnalyzer {
|
|||
}
|
||||
|
||||
private Function analyzeStarrocksJarUdaf(CreateFunctionStmt stmt, String checksum,
|
||||
JavaUDFInternalClass mainClass,
|
||||
JavaUDFInternalClass udafStateClass) {
|
||||
JavaUDFInternalClass mainClass,
|
||||
JavaUDFInternalClass udafStateClass) {
|
||||
FunctionName functionName = stmt.getFunctionName();
|
||||
FunctionArgsDef argsDef = stmt.getArgsDef();
|
||||
TypeDef returnType = stmt.getReturnType();
|
||||
|
|
@ -329,7 +329,7 @@ public class CreateFunctionAnalyzer {
|
|||
}
|
||||
|
||||
private Function analyzeStarrocksJarUdtf(CreateFunctionStmt stmt, String checksum,
|
||||
JavaUDFInternalClass mainClass) {
|
||||
JavaUDFInternalClass mainClass) {
|
||||
FunctionName functionName = stmt.getFunctionName();
|
||||
FunctionArgsDef argsDef = stmt.getArgsDef();
|
||||
TypeDef returnType = stmt.getReturnType();
|
||||
|
|
@ -488,17 +488,23 @@ public class CreateFunctionAnalyzer {
|
|||
}
|
||||
|
||||
private void checkUdfType(Method method, Type expType, Class<?> ptype, String pname) {
|
||||
if (!(expType instanceof ScalarType)) {
|
||||
Class<?> cls = null;
|
||||
if (expType instanceof ScalarType) {
|
||||
ScalarType scalarType = (ScalarType) expType;
|
||||
cls = PRIMITIVE_TYPE_TO_JAVA_CLASS_TYPE.get(scalarType.getPrimitiveType());
|
||||
} else if (expType instanceof MapType) {
|
||||
cls = Map.class;
|
||||
} else if (expType instanceof ArrayType) {
|
||||
cls = List.class;
|
||||
} else {
|
||||
ErrorReport.reportSemanticException(ErrorCode.ERR_COMMON_ERROR,
|
||||
String.format("UDF class '%s' method '%s' does not support non-scalar type '%s'",
|
||||
String.format("UDF class '%s' method '%s' does not support type '%s'",
|
||||
clazz.getCanonicalName(), method.getName(), expType));
|
||||
}
|
||||
ScalarType scalarType = (ScalarType) expType;
|
||||
Class<?> cls = PRIMITIVE_TYPE_TO_JAVA_CLASS_TYPE.get(scalarType.getPrimitiveType());
|
||||
if (cls == null) {
|
||||
ErrorReport.reportSemanticException(ErrorCode.ERR_COMMON_ERROR,
|
||||
String.format("UDF class '%s' method '%s' does not support type '%s'",
|
||||
clazz.getCanonicalName(), method.getName(), scalarType));
|
||||
clazz.getCanonicalName(), method.getName(), expType));
|
||||
}
|
||||
if (!cls.equals(ptype)) {
|
||||
ErrorReport.reportSemanticException(ErrorCode.ERR_COMMON_ERROR,
|
||||
|
|
|
|||
|
|
@ -57,6 +57,30 @@ public class CreateFunctionStmtAnalyzerTest {
|
|||
createFunctionSql, 32).get(0);
|
||||
}
|
||||
|
||||
private CreateFunctionStmt createMapStmt(String symbol, String type) {
|
||||
String createFunctionSql = String.format("CREATE %s FUNCTION ABC.MY_UDAF_MAP(map<string,string>) \n"
|
||||
+ "RETURNS map<string,string> \n"
|
||||
+ "properties (\n"
|
||||
+ " \"symbol\" = \"%s\",\n"
|
||||
+ " \"type\" = \"StarrocksJar\",\n"
|
||||
+ " \"file\" = \"http://localhost:8080/\"\n"
|
||||
+ ");", type, symbol);
|
||||
return (CreateFunctionStmt) com.starrocks.sql.parser.SqlParser.parse(
|
||||
createFunctionSql, 32).get(0);
|
||||
}
|
||||
|
||||
private CreateFunctionStmt createListStmt(String symbol, String type) {
|
||||
String createFunctionSql = String.format("CREATE %s FUNCTION ABC.MY_UDAF_LIST(array<string>) \n"
|
||||
+ "RETURNS array<string> \n"
|
||||
+ "properties (\n"
|
||||
+ " \"symbol\" = \"%s\",\n"
|
||||
+ " \"type\" = \"StarrocksJar\",\n"
|
||||
+ " \"file\" = \"http://localhost:8080/\"\n"
|
||||
+ ");", type, symbol);
|
||||
return (CreateFunctionStmt) com.starrocks.sql.parser.SqlParser.parse(
|
||||
createFunctionSql, 32).get(0);
|
||||
}
|
||||
|
||||
private CreateFunctionStmt createPyStmt(String symbol, String type, String target) {
|
||||
Config.enable_udf = true;
|
||||
String createFunctionSql = String.format("CREATE FUNCTION ABC.MY_UDF_JSON_GET(string, string) \n"
|
||||
|
|
@ -168,6 +192,7 @@ public class CreateFunctionStmtAnalyzerTest {
|
|||
+ ");", args, ret);
|
||||
return sql;
|
||||
}
|
||||
|
||||
void mockClazz(Class<?> clazz) {
|
||||
new MockUp<CreateFunctionAnalyzer>() {
|
||||
@Mock
|
||||
|
|
@ -249,6 +274,66 @@ public class CreateFunctionStmtAnalyzerTest {
|
|||
}
|
||||
}
|
||||
|
||||
public static class EmptyAggMapEval {
|
||||
public static class State {
|
||||
public int serializeLength() {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
public State create() {
|
||||
return new State();
|
||||
}
|
||||
|
||||
public void destroy(State state) {
|
||||
}
|
||||
|
||||
public final void update(State state, Map<String, String> val) {
|
||||
}
|
||||
|
||||
public void serialize(State state, java.nio.ByteBuffer buff) {
|
||||
|
||||
}
|
||||
|
||||
public void merge(State state, java.nio.ByteBuffer buffer) {
|
||||
|
||||
}
|
||||
|
||||
public Map<String, String> finalize(State state) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static class EmptyAggListEval {
|
||||
public static class State {
|
||||
public int serializeLength() {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
public State create() {
|
||||
return new State();
|
||||
}
|
||||
|
||||
public void destroy(State state) {
|
||||
}
|
||||
|
||||
public final void update(State state, List<String> val) {
|
||||
}
|
||||
|
||||
public void serialize(State state, java.nio.ByteBuffer buff) {
|
||||
|
||||
}
|
||||
|
||||
public void merge(State state, java.nio.ByteBuffer buffer) {
|
||||
|
||||
}
|
||||
|
||||
public List<String> finalize(State state) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJUDAF() {
|
||||
try {
|
||||
|
|
@ -277,6 +362,62 @@ public class CreateFunctionStmtAnalyzerTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJUDAFMap() {
|
||||
try {
|
||||
Config.enable_udf = true;
|
||||
new MockUp<CreateFunctionAnalyzer>() {
|
||||
@Mock
|
||||
public String computeMd5(CreateFunctionStmt stmt) {
|
||||
return "0xff";
|
||||
}
|
||||
};
|
||||
new MockUp<UDFInternalClassLoader>() {
|
||||
@Mock
|
||||
public final Class<?> loadClass(String name, boolean resolve)
|
||||
throws ClassNotFoundException {
|
||||
if (name.contains("$")) {
|
||||
return EmptyAggMapEval.State.class;
|
||||
}
|
||||
return EmptyAggMapEval.class;
|
||||
}
|
||||
};
|
||||
CreateFunctionStmt stmt = createMapStmt("symbol", "AGGREGATE");
|
||||
new CreateFunctionAnalyzer().analyze(stmt, connectContext);
|
||||
Assertions.assertEquals("0xff", stmt.getFunction().getChecksum());
|
||||
} finally {
|
||||
Config.enable_udf = false;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJUDAFList() {
|
||||
try {
|
||||
Config.enable_udf = true;
|
||||
new MockUp<CreateFunctionAnalyzer>() {
|
||||
@Mock
|
||||
public String computeMd5(CreateFunctionStmt stmt) {
|
||||
return "0xff";
|
||||
}
|
||||
};
|
||||
new MockUp<UDFInternalClassLoader>() {
|
||||
@Mock
|
||||
public final Class<?> loadClass(String name, boolean resolve)
|
||||
throws ClassNotFoundException {
|
||||
if (name.contains("$")) {
|
||||
return EmptyAggListEval.State.class;
|
||||
}
|
||||
return EmptyAggListEval.class;
|
||||
}
|
||||
};
|
||||
CreateFunctionStmt stmt = createListStmt("symbol", "AGGREGATE");
|
||||
new CreateFunctionAnalyzer().analyze(stmt, connectContext);
|
||||
Assertions.assertEquals("0xff", stmt.getFunction().getChecksum());
|
||||
} finally {
|
||||
Config.enable_udf = false;
|
||||
}
|
||||
}
|
||||
|
||||
public static class JUDTF {
|
||||
public String[] process(String s, String s2) {
|
||||
return null;
|
||||
|
|
|
|||
|
|
@ -281,4 +281,31 @@ PROPERTIES
|
|||
select echo_map2(map("","")), echo_map2(map(null, null));
|
||||
-- result:
|
||||
{null:""} {null:null}
|
||||
-- !result
|
||||
CREATE aggregate FUNCTION sum_map(map<string,bigint>)
|
||||
RETURNS map<string,bigint>
|
||||
properties (
|
||||
"symbol" = "com.starrocks.udf.SumMapInt64",
|
||||
"type" = "StarrocksJar",
|
||||
"file" = "${udf_url}/starrocks-jdbc/java-udf.jar?v=2"
|
||||
);
|
||||
-- result:
|
||||
-- !result
|
||||
create table map_table (id int, data map<string, bigint>);
|
||||
-- result:
|
||||
-- !result
|
||||
insert into map_table values (1, map{"a": 10, "b": 20}), (1, map{"a": 20, "c": 20}), (2, map{"d": 20, "e": 30}), (2, map{null: 40, "d": 30});
|
||||
-- result:
|
||||
-- !result
|
||||
select id, data from map_table order by id;
|
||||
-- result:
|
||||
1 {"a":10,"b":20}
|
||||
1 {"a":20,"c":20}
|
||||
2 {"d":20,"e":30}
|
||||
2 {null:40,"d":30}
|
||||
-- !result
|
||||
select id, sum_map(data) from map_table group by id order by id;
|
||||
-- result:
|
||||
1 {"a":30,"b":20,"c":20}
|
||||
2 {null:40,"d":50,"e":30}
|
||||
-- !result
|
||||
|
|
@ -179,4 +179,18 @@ PROPERTIES
|
|||
"type" = "StarrocksJar",
|
||||
"file" = "${udf_url}/starrocks-jdbc/ArrayMap.jar?v=2"
|
||||
);
|
||||
select echo_map2(map("","")), echo_map2(map(null, null));
|
||||
select echo_map2(map("","")), echo_map2(map(null, null));
|
||||
|
||||
CREATE aggregate FUNCTION sum_map(map<string,bigint>)
|
||||
RETURNS map<string,bigint>
|
||||
properties (
|
||||
"symbol" = "com.starrocks.udf.SumMapInt64",
|
||||
"type" = "StarrocksJar",
|
||||
"file" = "${udf_url}/starrocks-jdbc/java-udf.jar?v=2"
|
||||
);
|
||||
|
||||
create table map_table (id int, data map<string, bigint>);
|
||||
insert into map_table values (1, map{"a": 10, "b": 20}), (1, map{"a": 20, "c": 20}), (2, map{"d": 20, "e": 30}), (2, map{null: 40, "d": 30});
|
||||
select id, data from map_table order by id;
|
||||
select id, sum_map(data) from map_table group by id order by id;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue