# starrocks/test/test_sql_cases.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
###########################################################################
# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###########################################################################
"""
test_sql_cases
@Time : 2022/10/28 13:41
@Author : Brook.Ye
"""
import copy
import json
import os
import re
import sys
import threading
import time
import uuid
from typing import List
import pymysql
from nose import tools
from parameterized import parameterized
from cup import log
from lib import sr_sql_lib
from lib import choose_cases
from lib import sql_annotation
from lib.sr_sql_lib import self_print
from lib import *
# Run mode, controlled by environment variables:
# - record_mode == "true": record mode (t) — run sql and save result into r dir
# - otherwise: validate mode (r) — run sql and compare result with r
record_mode = os.environ.get("record_mode", "false") == "true"
# arrow_mode: route plain statements through the Arrow Flight SQL endpoint
arrow_mode = os.environ.get("arrow_mode", "false") == "true"

# Select the cases to run; abort early (successfully) when nothing matched.
case_list = choose_cases.choose_cases(record_mode).case_list
if not case_list:
    self_print(f"{'-' * 60}\n** INFO: No case! **", ColorEnum.RED)
    sys.exit(0)
def doc_func(_, num, param):
    """Build the one-line doc string shown for a parameterized case.

    Format: "<num><pad>: [<file>:]<name>", file prefix omitted when empty.
    """
    case = param[0][0]
    label = case.name if case.file == "" else case.file + ":" + case.name
    return "%-5d: %s" % (num, label)
def name_func(testcase_func, param_num, param):
    """Derive the generated test-method name from the case name."""
    case_name = param.args[0].name
    return parameterized.to_safe_name(case_name)
class TestSQLCases(sr_sql_lib.StarrocksSQLApiLib):
"""
SQL Framework Run Test
"""
_multiprocess_can_split_ = True
    def __init__(self, *args, **kwargs):
        """
        Initialize the SQL API base class and pre-validate the chosen cases:
        no two cases may create a database with the same name.
        """
        super(TestSQLCases, self).__init__(*args, **kwargs)
        # annotation only (no value assigned here): the case currently under
        # execution, set at the start of test_sql_basic
        self.case_info: choose_cases.ChooseCase.CaseTR
        self._check_db_unique()
    def setUp(self, *args, **kwargs):
        """
        Per-case setup: open the StarRocks connection and connection pool,
        verify the cluster is healthy, then apply test-only FE configs and
        session variables.
        """
        super().setUp()
        self.connect_starrocks()
        self.create_starrocks_conn_pool()
        self.check_cluster_status()
        self._init_global_configs()
        self._init_global_session_variables()
def _init_global_configs(self):
"""
Configs that are not ready for production but it can be used for testing.
"""
default_configs = [
"'mv_refresh_fail_on_filter_data' = 'true'",
"'enable_mv_refresh_query_rewrite' = 'true'",
# enlarge task run concurrency to speed up mv's refresh and find more potential bugs
"'task_runs_concurrency' = '16'",
]
for config in default_configs:
sql = "ADMIN SET FRONTEND CONFIG (%s)" % config
self.execute_sql(sql)
def _init_global_session_variables(self):
"""
Session variables that are not ready for production but it can be used for testing.
"""
session_variables = [
"new_planner_optimize_timeout = 10000",
]
for session_variable in session_variables:
sql = "SET %s;" % session_variable
self.execute_sql(sql)
    @sql_annotation.ignore_timeout()
    def tearDown(self):
        """
        Per-case teardown: run the case's custom CLEANUP statements, drop the
        databases/resources registered by _init_data, persist the result in
        record mode, then close every external connection.
        """
        super().tearDown()
        log.info("[TearDown begin]: %s" % self.case_info.name)

        # reset the scheduler interval
        sql = "ADMIN SET FRONTEND CONFIG 'dynamic_partition_check_interval_seconds' = '600'"
        self.execute_sql(sql)

        # run custom cleanup (always); each statement is best-effort — a
        # failing cleanup statement must not fail the case
        try:
            if hasattr(self.case_info, "cleanup"):
                for stmt in self.case_info.cleanup:
                    try:
                        self.execute_single_statement(stmt, -1, False)
                    except Exception as e:
                        log.warning(f"cleanup stmt error: {e}")
        except Exception as e:
            log.warning(f"cleanup error: {e}")

        # drop everything this case created (collected in _init_data)
        for each_db in self.db:
            self.drop_database(each_db)
        for each_resource in self.resource:
            self.drop_resource(each_resource)

        res = None
        if record_mode:
            # save case result into db
            res = self.save_r_into_db(self.case_info.file, self.case_info.name, self.res_log, self.version)

        self.close_starrocks()

        # destroy connection pool
        # NOTE(review): reaches into the pool's private _idle_cache attribute —
        # confirm against the pool implementation before upgrading that library.
        while self.connection_pool and len(self.connection_pool._idle_cache) > 0:
            log.info(f"Teardown: freeze conn pool, size: {len(self.connection_pool._idle_cache)}")
            conn = self.connection_pool._idle_cache.pop()
            conn.close()

        self.close_trino()
        self.close_spark()
        self.close_hive()
        self.close_starrocks_arrow()

        # the record-mode save is asserted only after all connections are closed
        if record_mode:
            tools.assert_true(res, "Save %s.%s result error" % (self.case_info.file, self.case_info.name))

        self_print(f"{'*' * 20} [FINISH] {self.case_info.name} {'*' * 20}", ColorEnum.GREEN, bold=True)
def _init_data(self, sql_list: List) -> List:
self.db = list()
self.resource = list()
sql_list = self._replace_uuid_variables(sql_list)
for sql in sql_list:
if isinstance(sql, str):
db_name = self._get_db_name(sql)
if len(db_name) > 0:
self.db.append(db_name)
resource_name = self._get_resource_name(sql)
if len(resource_name) > 0:
self.resource.append(resource_name)
elif isinstance(sql, dict) and sql.get("type", "") == LOOP_FLAG:
tools.assert_in("stat", sql, "LOOP STATEMENT FORMAT ERROR!")
for each_sql in sql["stat"]:
db_name = self._get_db_name(each_sql)
if len(db_name) > 0:
self.db.append(db_name)
resource_name = self._get_resource_name(each_sql)
if len(resource_name) > 0:
self.resource.append(resource_name)
elif isinstance(sql, dict) and sql.get("type", "") == CONCURRENCY_FLAG:
tools.assert_in("thread", sql, "CONCURRENCY THREAD FORMAT ERROR!")
for each_thread in sql["thread"]:
for each_cmd in each_thread["cmd"]:
db_name = self._get_db_name(each_cmd)
if len(db_name) > 0:
self.db.append(db_name)
resource_name = self._get_resource_name(each_cmd)
if len(resource_name) > 0:
self.resource.append(resource_name)
elif isinstance(sql, dict) and sql.get("type", "") == CLEANUP_FLAG:
tools.assert_in("cmd", sql, "CLEANUP STATEMENT FORMAT ERROR!")
for each_cmd in sql["cmd"]:
db_name = self._get_db_name(each_cmd)
if len(db_name) > 0:
self.db.append(db_name)
resource_name = self._get_resource_name(each_cmd)
if len(resource_name) > 0:
self.resource.append(resource_name)
else:
tools.ok_(False, "Init data error!")
self._clear_db_and_resource_if_exists()
if len(self.db) == 0:
self._create_and_use_db()
return sql_list
def _clear_db_and_resource_if_exists(self):
for each_db in self.db:
log.info("init drop db: %s" % each_db)
self.drop_database(each_db)
for each_resource in self.resource:
log.info("init drop resource: %s" % each_resource)
self.drop_resource(each_resource)
def _create_and_use_db(self):
db_name = "test_db_%s" % uuid.uuid4().hex
self.db.append(db_name)
self.execute_sql("CREATE DATABASE %s;" % db_name)
self.execute_sql("USE %s;" % db_name)
self_print("[SQL]: CREATE DATABASE %s;" % db_name)
self_print("[SQL]: USE %s;" % db_name)
def _check_db_unique(self):
all_db_dict = dict()
for case in case_list:
sql_list = self._replace_uuid_variables(case.sql)
for sql in sql_list:
if isinstance(sql, str):
db_name = self._get_db_name(sql)
if len(db_name) > 0:
all_db_dict.setdefault(db_name, set()).add(case.name)
elif isinstance(sql, dict) and sql.get("type", "") == LOOP_FLAG:
tools.assert_in("stat", sql, "LOOP STATEMENT FORMAT ERROR!")
for each_sql in sql["stat"]:
db_name = self._get_db_name(each_sql)
if len(db_name) > 0:
all_db_dict.setdefault(db_name, set()).add(case.name)
elif isinstance(sql, dict) and sql.get("type", "") == CONCURRENCY_FLAG:
tools.assert_in("thread", sql, "CONCURRENCY THREAD FORMAT ERROR!")
for each_thread in sql["thread"]:
for each_cmd in each_thread["cmd"]:
db_name = self._get_db_name(each_cmd)
if len(db_name) > 0:
all_db_dict.setdefault(db_name, set()).add(case.name)
elif isinstance(sql, dict) and sql.get("type", "") == CLEANUP_FLAG:
tools.assert_in("cmd", sql, "CLEANUP STATEMENT FORMAT ERROR!")
for each_cmd in sql["cmd"]:
db_name = self._get_db_name(each_cmd)
if len(db_name) > 0:
all_db_dict.setdefault(db_name, set()).add(case.name)
else:
tools.ok_(False, "Check db uniqueness error!")
error_info_dict = {db: list(cases) for db, cases in all_db_dict.items() if len(cases) > 1}
tools.assert_true(
len(error_info_dict) <= 0, "Pre Check Failed, Duplicate DBs: \n%s" % json.dumps(error_info_dict, indent=2)
)
@staticmethod
def _replace_uuid_variables(sql_list: List) -> List:
ret = list()
variable_dict = {}
for sql in sql_list:
if isinstance(sql, str):
uuid_vars = re.findall(r"\${(uuid[0-9]*)}", sql)
for each_uuid in uuid_vars:
if each_uuid not in variable_dict:
variable_dict[each_uuid] = uuid.uuid4().hex
elif isinstance(sql, dict) and sql.get("type", "") == LOOP_FLAG:
tools.assert_in("stat", sql, "LOOP STATEMENT FORMAT ERROR!")
for each_sql in sql["stat"]:
uuid_vars = re.findall(r"\${(uuid[0-9]*)}", each_sql)
for each_uuid in uuid_vars:
if each_uuid not in variable_dict:
variable_dict[each_uuid] = uuid.uuid4().hex
elif isinstance(sql, dict) and sql.get("type", "") == CONCURRENCY_FLAG:
tools.assert_in("thread", sql, "CONCURRENCY THREAD FORMAT ERROR!")
for each_thread in sql["thread"]:
for each_cmd in each_thread["cmd"]:
uuid_vars = re.findall(r"\${(uuid[0-9]*)}", each_cmd)
for each_uuid in uuid_vars:
if each_uuid not in variable_dict:
variable_dict[each_uuid] = uuid.uuid4().hex
elif isinstance(sql, dict) and sql.get("type", "") == CLEANUP_FLAG:
tools.assert_in("cmd", sql, "CLEANUP STATEMENT FORMAT ERROR!")
for each_cmd in sql["cmd"]:
uuid_vars = re.findall(r"\${(uuid[0-9]*)}", each_cmd)
for each_uuid in uuid_vars:
if each_uuid not in variable_dict:
variable_dict[each_uuid] = uuid.uuid4().hex
else:
tools.ok_(False, "Replace uuid error!")
for sql in sql_list:
if isinstance(sql, str):
for each_var in variable_dict:
sql = sql.replace("${%s}" % each_var, variable_dict[each_var])
ret.append(sql)
elif isinstance(sql, dict) and sql.get("type", "") == LOOP_FLAG:
_sql = copy.deepcopy(sql)
tmp_stat = []
for each_sql in _sql["stat"]:
for each_var in variable_dict:
each_sql = each_sql.replace("${%s}" % each_var, variable_dict[each_var])
tmp_stat.append(each_sql)
_sql["stat"] = tmp_stat
ret.append(_sql)
elif isinstance(sql, dict) and sql.get("type", "") == CONCURRENCY_FLAG:
_sql = copy.deepcopy(sql)
tools.assert_in("thread", sql, "CONCURRENCY THREAD FORMAT ERROR!")
for each_thread in _sql["thread"]:
tmp_cmd = []
for each_cmd in each_thread["cmd"]:
for each_var in variable_dict:
each_cmd = each_cmd.replace("${%s}" % each_var, variable_dict[each_var])
tmp_cmd.append(each_cmd)
each_thread["cmd"] = tmp_cmd
ret.append(_sql)
elif isinstance(sql, dict) and sql.get("type", "") == CLEANUP_FLAG:
_sql = copy.deepcopy(sql)
tools.assert_in("cmd", _sql, "CLEANUP STATEMENT FORMAT ERROR!")
tmp_cmd = []
for each_cmd in _sql["cmd"]:
for each_var in variable_dict:
each_cmd = each_cmd.replace("${%s}" % each_var, variable_dict[each_var])
tmp_cmd.append(each_cmd)
_sql["cmd"] = tmp_cmd
ret.append(_sql)
return ret
@staticmethod
def _get_db_name(sql: str) -> str:
db_name = ""
if "CREATE DATABASE" in sql.upper():
db_name = sql.rstrip(";").strip().split(" ")[-1]
return db_name
@staticmethod
def _get_resource_name(sql: str) -> str:
matches = list()
if "CREATE EXTERNAL RESOURCE" in sql.upper():
matches = re.findall(r"CREATE EXTERNAL RESOURCE \"?([a-zA-Z0-9_-]+)\"?", sql, flags=re.IGNORECASE)
return matches[0] if len(matches) > 0 else ""
# -------------------------------------------
# [CASE]
# -------------------------------------------
@parameterized.expand([[case_info] for case_info in case_list], doc_func=doc_func, name_func=name_func)
@sql_annotation.timeout()
def test_sql_basic(self, case_info: choose_cases.ChooseCase.CaseTR):
"""
sql tester
Args:
case_info:
name: case name
file: case info {table: duplicate, file: A.csv ...}
sql: run sql
result: result
"""
# -------------------------------------------
# [CASE EXECUTE]
# -------------------------------------------
# replace all db_name with each run
self.case_info = case_info
self_print("-" * 60, ColorEnum.GREEN, bold=True)
self_print(f"[case name]: {case_info.name}", ColorEnum.GREEN, bold=True)
self_print(f"[case file]: {case_info.file}", ColorEnum.GREEN, bold=True)
self_print("-" * 60, ColorEnum.GREEN, bold=True)
sql_list = self._init_data(case_info.sql)
self_print(f"\t → case db: {self.db}")
self_print(f"\t → case resource: {self.resource}")
log.info(
"""
*********************************************
Start to run: %s
*********************************************"""
% case_info.name
)
# record mode, init info
if record_mode:
self.res_log.append(case_info.info)
for sql_id, sql in enumerate(sql_list):
if arrow_mode and isinstance(sql, str):
sql = sql.strip()
if sql.startswith(
(
"mysql:",
"shell:",
"--",
"function:",
"CHECK:",
"PROPERTY:",
"LOOP",
"END LOOP",
"CONCURRENCY",
"END CONCURRENCY",
)
):
self_print(f"[arrow_mode] Skip non-arrow SQL: {sql}", ColorEnum.YELLOW)
continue
if not sql.startswith("arrow:"):
sql = "arrow: " + sql
uncheck = False
ori_sql = sql
var = None
if isinstance(sql, str):
if record_mode:
self.res_log.append(case_info.ori_sql[sql_id])
# uncheck flag, owns the highest priority
if sql.startswith(sr_sql_lib.UNCHECK_FLAG):
uncheck = True
sql = sql[len(sr_sql_lib.UNCHECK_FLAG) :]
actual_res, actual_res_log, var, order = self.execute_single_statement(sql, sql_id, record_mode)
if not record_mode and not uncheck:
# check mode only work in validating mode
# pretreatment expect res
expect_res = case_info.result[sql_id]
expect_res_for_log = expect_res if len(expect_res) < 1000 else expect_res[:1000] + "..."
log.info(
f"""[{sql_id}.result]:
- [exp]: {expect_res_for_log}
- [act]: {actual_res}"""
)
# -------------------------------------------
# [CHECKER]
# -------------------------------------------
self.check(sql_id, sql, expect_res, actual_res, order, ori_sql)
elif isinstance(sql, dict) and sql["type"] == sr_sql_lib.LOOP_FLAG:
# loop statement
# loop properties
loop_timeout = sql["prop"]["timeout"]
loop_interval = sql["prop"]["interval"]
loop_desc = sql["prop"]["desc"]
self_print(f"\n[LOOP] Start: {loop_desc}...", color=ColorEnum.BLUE, logout=True, bold=True)
if record_mode:
self.res_log.append("".join(sql["ori"]))
loop_check_res = self.execute_loop_statement(sql["stat"], sql_id, loop_timeout, loop_interval)
self_print(f"[LOOP] Finish!\n", color=ColorEnum.BLUE, logout=True, bold=True)
tools.ok_(loop_check_res, f"Loop checker fail: {''.join(sql['ori'])}!")
elif isinstance(sql, dict) and sql.get("type", "") == CLEANUP_FLAG:
# record mode: write cleanup block in-place to res_log to keep position
if record_mode:
if isinstance(sql.get("ori"), list):
self.res_log.append("".join(sql["ori"]))
else:
self.res_log.append("CLEANUP {")
for stmt in sql.get("cmd", []):
self.res_log.append(stmt)
self.res_log.append("} END CLEANUP")
# do nothing during execution here; cleanup executes in tearDown
continue
elif isinstance(sql, dict) and sql["type"] == sr_sql_lib.CONCURRENCY_FLAG:
# concurrency statement
self_print(f"[CONCURRENCY] Start...", color=ColorEnum.CYAN, logout=True)
if record_mode:
self.res_log.append("\n" + CONCURRENCY_FLAG + " {")
t_info_list: List[dict] = sql["thread"]
thread_list = []
# get now db
_outer_db = self.get_now_db()
# thread group
_t_conn_list = []
for _t_info_id, _thread in enumerate(t_info_list):
# _thread prop: name, count, info, ori, cmd, res
_t_count = _thread["count"]
_t_name = _thread["name"]
_t_ori_cmd = _thread["ori"]
_t_cmd = _thread["cmd"]
_t_res = _thread["res"]
# thread exec, set count in (*)
for _t_exec_id in range(_t_count):
this_t_id = f"{_t_name}-{_t_info_id}-{_t_exec_id}"
# init a conn for thread
this_conn = self.connection_pool.connection()
_t_conn_list.append([this_conn, this_t_id])
t = threading.Thread(
name=f"Thread-{this_t_id}",
target=self.execute_thread,
args=(this_t_id, _t_cmd, _t_res, _t_ori_cmd, record_mode, _outer_db, this_conn),
)
thread_list.append(t)
threading.excepthook = self.custom_except_hook
for thread in thread_list:
thread.start()
for thread in thread_list:
thread.join()
# release conn
for _conn, _id in _t_conn_list:
self_print(f"[{_id}] Close connection...", color=ColorEnum.CYAN, logout=True)
_conn.close()
if len(self.thread_res) != 0:
err_t_name = "\n\t- ".join(self.thread_res.keys())
err_msg = f"[CONCURRENCY] FAIL:\n\t- {err_t_name}"
self_print(err_msg, color=ColorEnum.RED, bold=True, logout=True)
# console detail
for _, t_err in self.thread_res.items():
self_print(t_err, bold=True, logout=True)
tools.ok_(False, err_msg)
# record mode
if record_mode:
for _t_info_id, _thread in enumerate(t_info_list):
_t_name = _thread["name"]
_t_count = _thread["count"]
_t_uid = f"{_t_name}-{_t_info_id}"
_t_info_line = _thread["info"]
self.res_log.append(_t_info_line)
# check thread result info
tools.assert_in(_t_uid, self.thread_res_log, f"Thread log of {_t_uid} is not found!")
tools.eq_(
len(self.thread_res_log[_t_uid]),
_t_count,
f"Thread log size: {len(self.thread_res_log[_t_uid])} error, maybe you used the same thread name?",
)
s_thread_log = self.thread_res_log[_t_uid][0]
for exec_res_log in self.thread_res_log[_t_uid]:
if exec_res_log != s_thread_log:
self_print(
"Thread result of exec not equal: \n - %s\n - %s" % (exec_res_log, s_thread_log),
color=ColorEnum.RED,
logout=True,
bold=True,
)
self.res_log.extend(s_thread_log)
self.res_log.append("")
self_print(f"[CONCURRENCY] SUCCESS!", color=ColorEnum.CYAN, logout=True)
if record_mode:
self.res_log.append("} " + END_CONCURRENCY_FLAG + "\n")