[Enhancement] Pass stream load label directly to TransactionStmtExecutor (backport #63334) (#63463)

Signed-off-by: meegoo <meegoo.sr@gmail.com>
Co-authored-by: meegoo <meegoo.sr@gmail.com>
This commit is contained in:
mergify[bot] 2025-09-23 10:41:06 -07:00 committed by GitHub
parent ce2f48c0f8
commit b491a3f7c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 85 additions and 34 deletions

View File

@ -159,8 +159,18 @@ public class StreamLoadMultiStmtTask extends AbstractStreamLoadTask {
}
public void beginTxn(TransactionResult resp) {
// Ensure a non-empty label is generated in TransactionStmtExecutor.beginStmt
// by providing a valid executionId. Use the pre-generated loadId as executionId.
if (context.getExecutionId() == null) {
context.setExecutionId(loadId);
}
// Also propagate compute resource so the txn carries the same resource context.
if (context.getCurrentComputeResource() == null) {
context.setCurrentComputeResource(computeResource);
}
TransactionStmtExecutor.beginStmt(context, new BeginStmt(NodePosition.ZERO),
TransactionState.LoadJobSourceType.MULTI_STATEMENT_STREAMING);
TransactionState.LoadJobSourceType.MULTI_STATEMENT_STREAMING, label);
this.txnId = context.getTxnId();
LOG.info("start transaction id {}", txnId);
}

View File

@ -81,28 +81,44 @@ public class TransactionStmtExecutor {
beginStmt(context, stmt, TransactionState.LoadJobSourceType.INSERT_STREAMING);
}
public static void beginStmt(ConnectContext context, BeginStmt stmt, TransactionState.LoadJobSourceType sourceType) {
public static void beginStmt(ConnectContext context, BeginStmt stmt,
TransactionState.LoadJobSourceType sourceType) {
beginStmt(context, stmt, sourceType, null);
}
// Overload allowing explicit label override for creating the transaction state.
// If labelOverride is null or empty, it falls back to the default label built from executionId.
public static void beginStmt(ConnectContext context, BeginStmt stmt,
TransactionState.LoadJobSourceType sourceType,
String labelOverride) {
GlobalTransactionMgr globalTransactionMgr = GlobalStateMgr.getCurrentState().getGlobalTransactionMgr();
if (context.getTxnId() != 0) {
//Repeated begin does not create a new transaction
// Repeated begin does not create a new transaction
ExplicitTxnState explicitTxnState = globalTransactionMgr.getExplicitTxnState(context.getTxnId());
String label = explicitTxnState.getTransactionState().getLabel();
long transactionId = explicitTxnState.getTransactionState().getTransactionId();
context.getState().setOk(0, 0, buildMessage(label, TransactionStatus.PREPARE, transactionId, -1));
context.getState().setOk(0, 0,
buildMessage(label, TransactionStatus.PREPARE, transactionId, -1));
return;
}
long transactionId = GlobalStateMgr.getCurrentState().getGlobalTransactionMgr()
.getTransactionIDGenerator().getNextTransactionId();
String label = DebugUtil.printId(context.getExecutionId());
TransactionState transactionState = new TransactionState(transactionId, label, null,
String label = (labelOverride != null && !labelOverride.isEmpty())
? labelOverride
: DebugUtil.printId(context.getExecutionId());
TransactionState transactionState = new TransactionState(
transactionId,
label,
null,
sourceType,
new TransactionState.TxnCoordinator(TransactionState.TxnSourceType.FE, FrontendOptions.getLocalHostAddress()),
new TransactionState.TxnCoordinator(TransactionState.TxnSourceType.FE,
FrontendOptions.getLocalHostAddress()),
context.getExecTimeout() * 1000L);
transactionState.setPrepareTime(System.currentTimeMillis());
transactionState.setComputeResource(context.getCurrentComputeResource());
boolean combinedTxnLog = LakeTableHelper.supportCombinedTxnLog(TransactionState.LoadJobSourceType.INSERT_STREAMING);
boolean combinedTxnLog = LakeTableHelper.supportCombinedTxnLog(sourceType);
transactionState.setUseCombinedTxnLog(combinedTxnLog);
ExplicitTxnState explicitTxnState = new ExplicitTxnState();
@ -110,7 +126,8 @@ public class TransactionStmtExecutor {
globalTransactionMgr.addTransactionState(transactionId, explicitTxnState);
context.setTxnId(transactionId);
context.getState().setOk(0, 0, buildMessage(label, TransactionStatus.PREPARE, transactionId, -1));
context.getState().setOk(0, 0,
buildMessage(label, TransactionStatus.PREPARE, transactionId, -1));
}
public static void loadData(Database database,

View File

@ -19,10 +19,14 @@ import com.starrocks.catalog.OlapTable;
import com.starrocks.common.Config;
import com.starrocks.common.StarRocksException;
import com.starrocks.common.jmockit.Deencapsulation;
import com.starrocks.common.util.DebugUtil;
import com.starrocks.http.rest.TransactionResult;
import com.starrocks.server.WarehouseManager;
import com.starrocks.thrift.TNetworkAddress;
import com.starrocks.thrift.TUniqueId;
import com.starrocks.transaction.TransactionState;
import com.starrocks.transaction.TransactionStmtExecutor;
import mockit.Mock;
import mockit.MockUp;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -35,29 +39,10 @@ public class StreamLoadMultiStmtTaskTest {
@BeforeEach
public void setUp() {
db = new Database(20000, "test_db");
multiTask = new StreamLoadMultiStmtTask(1L, db, "label_multi", "userA", "127.0.0.1", 10000L,
System.currentTimeMillis(), WarehouseManager.DEFAULT_RESOURCE);
}
@Test
public void testTryLoadExistingSubTask() throws Exception {
// prepare a sub task manually and put into taskMaps
StreamLoadTask sub = new StreamLoadTask(10L, new Database(), new OlapTable(), "label_multi", "userA",
"127.0.0.1", 5000, 1, 0, System.currentTimeMillis(), WarehouseManager.DEFAULT_RESOURCE);
Deencapsulation.setField(sub, "tableName", "t1");
Deencapsulation.invoke(sub, "setState", StreamLoadTask.State.LOADING);
java.util.Map<Integer, TNetworkAddress> addrMap = com.google.common.collect.Maps.newHashMap();
addrMap.put(0, new TNetworkAddress("beHost", 8040));
Deencapsulation.setField(sub, "channelIdToBEHTTPAddress", addrMap);
@SuppressWarnings("unchecked") java.util.Map<String, StreamLoadTask> map =
(java.util.Map<String, StreamLoadTask>) Deencapsulation.getField(multiTask, "taskMaps");
map.put("t1", sub);
TransactionResult resp = new TransactionResult();
TNetworkAddress addr = multiTask.tryLoad(0, "t1", resp);
Assertions.assertTrue(resp.stateOK());
Assertions.assertNotNull(addr);
Assertions.assertEquals("t1", multiTask.getTableName());
// Initialize basic fixtures
db = new Database(1L, "test_db");
multiTask = new StreamLoadMultiStmtTask(1L, db, "label_multi", "u", "127.0.0.1",
1000L, System.currentTimeMillis(), WarehouseManager.DEFAULT_RESOURCE);
}
@Test
@ -90,6 +75,45 @@ public class StreamLoadMultiStmtTaskTest {
Assertions.assertTrue(multiTask.endTimeMs() > 0);
}
@Test
public void testBeginTxnSetsExecutionIdAndResource() throws Exception {
// Capture loadId from task for later comparison
TUniqueId expectedLoadId = (TUniqueId) Deencapsulation.getField(multiTask, "loadId");
// Mock static beginStmt to assert context is pre-populated and set a fake txnId
new MockUp<TransactionStmtExecutor>() {
@Mock
public void beginStmt(com.starrocks.qe.ConnectContext ctx,
com.starrocks.sql.ast.txn.BeginStmt stmt,
TransactionState.LoadJobSourceType sourceType,
String labelOverride) {
// executionId must be set and convertible to non-empty label
Assertions.assertNotNull(ctx.getExecutionId());
String label = DebugUtil.printId(ctx.getExecutionId());
Assertions.assertNotNull(label);
Assertions.assertFalse(label.isEmpty());
// executionId should equal task.loadId
Assertions.assertEquals(expectedLoadId.getHi(), ctx.getExecutionId().getHi());
Assertions.assertEquals(expectedLoadId.getLo(), ctx.getExecutionId().getLo());
// compute resource should be propagated
Assertions.assertEquals(WarehouseManager.DEFAULT_RESOURCE, ctx.getCurrentComputeResource());
// labelOverride should equal the multi-statement task's label
Assertions.assertEquals("label_multi", labelOverride);
// simulate txn id assignment inside begin
ctx.setTxnId(987654321L);
}
};
TransactionResult resp = new TransactionResult();
multiTask.beginTxn(resp);
Assertions.assertTrue(resp.stateOK());
Assertions.assertEquals(987654321L, multiTask.getTxnId());
}
@Test
public void testCheckNeedRemoveAndDurable() throws Exception {
Assertions.assertFalse(multiTask.checkNeedRemove(System.currentTimeMillis(), false));
@ -133,6 +157,6 @@ public class StreamLoadMultiStmtTaskTest {
multiTask.afterVisible(txnState, true);
multiTask.replayOnVisible(txnState);
List<List<String>> show = multiTask.getShowInfo();
Assertions.assertTrue(show.isEmpty() || show.size() >= 0);
Assertions.assertEquals(2, show.size());
}
}