Fix race condition introduced by background platform channels (flutter/engine#29377)

This commit is contained in:
Emmanuel Garcia
2021-11-05 15:35:43 -07:00
committed by GitHub
parent 6ddd9a22c1
commit 69ac1ed4dc
6 changed files with 283 additions and 37 deletions

View File

@@ -286,6 +286,12 @@ public class DartExecutor implements BinaryMessenger {
}
}
@Override
public void enableBufferingIncomingMessages() {}
@Override
public void disableBufferingIncomingMessages() {}
/**
* Configuration options that specify which Dart entrypoint function is executed and where to find
* that entrypoint and other assets required for Dart execution.
@@ -461,5 +467,11 @@ public class DartExecutor implements BinaryMessenger {
@Nullable TaskQueue taskQueue) {
messenger.setMessageHandler(channel, handler, taskQueue);
}
@Override
public void enableBufferingIncomingMessages() {}
@Override
public void disableBufferingIncomingMessages() {}
}
}

View File

@@ -14,9 +14,10 @@ import io.flutter.embedding.engine.FlutterJNI;
import io.flutter.plugin.common.BinaryMessenger;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -33,22 +34,37 @@ class DartMessenger implements BinaryMessenger, PlatformMessageHandler {
@NonNull private final FlutterJNI flutterJNI;
@NonNull private final ConcurrentHashMap<String, HandlerInfo> messageHandlers;
/**
* Maps a channel name to an object that contains the task queue and the handler associated with
* the channel.
*
* <p>Reads and writes to this map must lock {@code handlersLock}.
*/
@NonNull private final Map<String, HandlerInfo> messageHandlers = new HashMap<>();
@NonNull private final Map<Integer, BinaryMessenger.BinaryReply> pendingReplies;
/**
* Maps a channel name to an object that holds information about the incoming Dart message.
*
* <p>Reads and writes to this map must lock {@code handlersLock}.
*/
@NonNull private Map<String, List<BufferedMessageInfo>> bufferedMessages = new HashMap<>();
@NonNull private final Object handlersLock = new Object();
@NonNull private final AtomicBoolean enableBufferingIncomingMessages = new AtomicBoolean(false);
@NonNull private final Map<Integer, BinaryMessenger.BinaryReply> pendingReplies = new HashMap<>();
private int nextReplyId = 1;
@NonNull private final DartMessengerTaskQueue platformTaskQueue = new PlatformTaskQueue();
@NonNull private WeakHashMap<TaskQueue, DartMessengerTaskQueue> createdTaskQueues;
@NonNull
private WeakHashMap<TaskQueue, DartMessengerTaskQueue> createdTaskQueues =
new WeakHashMap<TaskQueue, DartMessengerTaskQueue>();
@NonNull private TaskQueueFactory taskQueueFactory;
DartMessenger(@NonNull FlutterJNI flutterJNI, @NonNull TaskQueueFactory taskQueueFactory) {
this.flutterJNI = flutterJNI;
this.messageHandlers = new ConcurrentHashMap<>();
this.pendingReplies = new HashMap<>();
this.createdTaskQueues = new WeakHashMap<TaskQueue, DartMessengerTaskQueue>();
this.taskQueueFactory = taskQueueFactory;
}
@@ -78,6 +94,10 @@ class DartMessenger implements BinaryMessenger, PlatformMessageHandler {
}
}
/**
* Holds information about a platform handler, such as the task queue that processes messages from
* Dart.
*/
private static class HandlerInfo {
@NonNull public final BinaryMessenger.BinaryMessageHandler handler;
@Nullable public final DartMessengerTaskQueue taskQueue;
@@ -90,7 +110,22 @@ class DartMessenger implements BinaryMessenger, PlatformMessageHandler {
}
}
/** A serial task queue that can run on a concurrent ExecutorService. */
/**
* Holds information that allows to dispatch a Dart message to a platform handler when it becomes
* available.
*/
private static class BufferedMessageInfo {
@NonNull public final ByteBuffer message;
int replyId;
long messageData;
BufferedMessageInfo(@NonNull ByteBuffer message, int replyId, long messageData) {
this.message = message;
this.replyId = replyId;
this.messageData = messageData;
}
}
static class DefaultTaskQueue implements DartMessengerTaskQueue {
@NonNull private final ExecutorService executor;
@NonNull private final ConcurrentLinkedQueue<Runnable> queue;
@@ -154,18 +189,53 @@ class DartMessenger implements BinaryMessenger, PlatformMessageHandler {
@Nullable TaskQueue taskQueue) {
if (handler == null) {
Log.v(TAG, "Removing handler for channel '" + channel + "'");
messageHandlers.remove(channel);
} else {
DartMessengerTaskQueue dartMessengerTaskQueue = null;
if (taskQueue != null) {
dartMessengerTaskQueue = createdTaskQueues.get(taskQueue);
if (dartMessengerTaskQueue == null) {
throw new IllegalArgumentException(
"Unrecognized TaskQueue, use BinaryMessenger to create your TaskQueue (ex makeBackgroundTaskQueue).");
}
synchronized (handlersLock) {
messageHandlers.remove(channel);
}
Log.v(TAG, "Setting handler for channel '" + channel + "'");
return;
}
DartMessengerTaskQueue dartMessengerTaskQueue = null;
if (taskQueue != null) {
dartMessengerTaskQueue = createdTaskQueues.get(taskQueue);
if (dartMessengerTaskQueue == null) {
throw new IllegalArgumentException(
"Unrecognized TaskQueue, use BinaryMessenger to create your TaskQueue (ex makeBackgroundTaskQueue).");
}
}
Log.v(TAG, "Setting handler for channel '" + channel + "'");
List<BufferedMessageInfo> list;
synchronized (handlersLock) {
messageHandlers.put(channel, new HandlerInfo(handler, dartMessengerTaskQueue));
list = bufferedMessages.remove(channel);
if (list == null) {
return;
}
}
for (BufferedMessageInfo info : list) {
dispatchMessageToQueue(
channel, messageHandlers.get(channel), info.message, info.replyId, info.messageData);
}
}
@Override
public void enableBufferingIncomingMessages() {
enableBufferingIncomingMessages.set(true);
}
@Override
public void disableBufferingIncomingMessages() {
Map<String, List<BufferedMessageInfo>> pendingMessages;
synchronized (handlersLock) {
enableBufferingIncomingMessages.set(false);
pendingMessages = bufferedMessages;
bufferedMessages = new HashMap<>();
}
for (Map.Entry<String, List<BufferedMessageInfo>> channel : pendingMessages.entrySet()) {
for (BufferedMessageInfo info : channel.getValue()) {
dispatchMessageToQueue(
channel.getKey(), null, info.message, info.replyId, info.messageData);
}
}
}
@@ -218,16 +288,12 @@ class DartMessenger implements BinaryMessenger, PlatformMessageHandler {
}
}
@Override
public void handleMessageFromDart(
@NonNull final String channel,
private void dispatchMessageToQueue(
@NonNull String channel,
@Nullable HandlerInfo handlerInfo,
@Nullable ByteBuffer message,
final int replyId,
int replyId,
long messageData) {
// Called from the ui thread.
Log.v(TAG, "Received message from Dart over channel '" + channel + "'");
@Nullable final HandlerInfo handlerInfo = messageHandlers.get(channel);
@Nullable
final DartMessengerTaskQueue taskQueue = (handlerInfo != null) ? handlerInfo.taskQueue : null;
Runnable myRunnable =
() -> {
@@ -235,8 +301,8 @@ class DartMessenger implements BinaryMessenger, PlatformMessageHandler {
try {
invokeHandler(handlerInfo, message, replyId);
if (message != null && message.isDirect()) {
// This ensures that if a user retains an instance to the ByteBuffer and it happens to
// be direct they will get a deterministic error.
// This ensures that if a user retains an instance to the ByteBuffer and it
// happens to be direct they will get a deterministic error.
message.limit(0);
}
} finally {
@@ -245,12 +311,43 @@ class DartMessenger implements BinaryMessenger, PlatformMessageHandler {
Trace.endSection();
}
};
@NonNull
final DartMessengerTaskQueue nonnullTaskQueue =
taskQueue == null ? platformTaskQueue : taskQueue;
nonnullTaskQueue.dispatch(myRunnable);
}
@Override
public void handleMessageFromDart(
@NonNull String channel, @Nullable ByteBuffer message, int replyId, long messageData) {
// Called from the ui thread.
Log.v(TAG, "Received message from Dart over channel '" + channel + "'");
HandlerInfo handlerInfo;
boolean messageDeferred;
synchronized (handlersLock) {
handlerInfo = messageHandlers.get(channel);
messageDeferred = (enableBufferingIncomingMessages.get() && handlerInfo == null);
if (messageDeferred) {
// The channel is not defined when the Dart VM sends a message before the channels are
// registered.
//
// This is possible if the Dart VM starts before channel registration, and if the thread
// that registers the channels is busy or slow at registering the channel handlers.
//
// In such cases, the task dispatchers are queued, and processed when the channel is
// defined.
if (!bufferedMessages.containsKey(channel)) {
bufferedMessages.put(channel, new LinkedList<>());
}
List<BufferedMessageInfo> buffer = bufferedMessages.get(channel);
buffer.add(new BufferedMessageInfo(message, replyId, messageData));
}
}
if (!messageDeferred) {
dispatchMessageToQueue(channel, handlerInfo, message, replyId, messageData);
}
}
@Override
public void handlePlatformMessageResponse(int replyId, @Nullable ByteBuffer reply) {
Log.v(TAG, "Received message reply from Dart.");

View File

@@ -116,6 +116,22 @@ public interface BinaryMessenger {
setMessageHandler(channel, handler);
}
/**
* Enables the ability to queue messages received from Dart.
*
* <p>This is useful when there are pending channel handler registrations. For example, Dart may
* be initialized concurrently, and prior to the registration of the channel handlers. This
* implies that Dart may start sending messages while plugins are being registered.
*/
void enableBufferingIncomingMessages();
/**
* Disables the ability to queue messages received from Dart.
*
* <p>This can be used after all pending channel handlers have been registered.
*/
void disableBufferingIncomingMessages();
/** Handler for incoming binary messages from Flutter. */
interface BinaryMessageHandler {
/**

View File

@@ -159,6 +159,12 @@ public class FlutterNativeView implements BinaryMessenger {
dartExecutor.getBinaryMessenger().setMessageHandler(channel, handler, taskQueue);
}
@Override
public void enableBufferingIncomingMessages() {}
@Override
public void disableBufferingIncomingMessages() {}
/*package*/ FlutterJNI getFlutterJNI() {
return mFlutterJNI;
}

View File

@@ -347,6 +347,12 @@ public class FlutterView extends SurfaceView
mFirstFrameListeners.remove(listener);
}
@Override
public void enableBufferingIncomingMessages() {}
@Override
public void disableBufferingIncomingMessages() {}
/**
* Reverts this back to the {@link SurfaceView} defaults, at the back of its window and opaque.
*/

View File

@@ -1,13 +1,18 @@
package io.flutter.embedding.engine.dart;
import static android.os.Looper.getMainLooper;
import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertArrayEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.robolectric.Shadows.shadowOf;
import io.flutter.embedding.engine.FlutterJNI;
import io.flutter.embedding.engine.dart.DartMessenger.DartMessengerTaskQueue;
@@ -21,6 +26,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.robolectric.RobolectricTestRunner;
import org.robolectric.annotation.Config;
@@ -160,16 +166,16 @@ public class DartMessengerTest {
public void cleansUpMessageData() throws InterruptedException {
final FlutterJNI fakeFlutterJni = mock(FlutterJNI.class);
final DartMessenger messenger = new DartMessenger(fakeFlutterJni, () -> synchronousTaskQueue);
BinaryMessenger.TaskQueue taskQueue = messenger.makeBackgroundTaskQueue();
String channel = "foobar";
final BinaryMessenger.TaskQueue taskQueue = messenger.makeBackgroundTaskQueue();
final String channel = "foobar";
BinaryMessenger.BinaryMessageHandler handler =
(ByteBuffer message, BinaryMessenger.BinaryReply reply) -> {
reply.reply(null);
};
messenger.setMessageHandler(channel, handler, taskQueue);
final ByteBuffer message = ByteBuffer.allocateDirect(4 * 2);
int replyId = 1;
long messageData = 1234;
final int replyId = 1;
final long messageData = 1234;
messenger.handleMessageFromDart(channel, message, replyId, messageData);
verify(fakeFlutterJni).cleanupMessageData(eq(messageData));
}
@@ -178,21 +184,124 @@ public class DartMessengerTest {
public void cleansUpMessageDataOnError() throws InterruptedException {
final FlutterJNI fakeFlutterJni = mock(FlutterJNI.class);
final DartMessenger messenger = new DartMessenger(fakeFlutterJni, () -> synchronousTaskQueue);
BinaryMessenger.TaskQueue taskQueue = messenger.makeBackgroundTaskQueue();
String channel = "foobar";
final BinaryMessenger.TaskQueue taskQueue = messenger.makeBackgroundTaskQueue();
final String channel = "foobar";
BinaryMessenger.BinaryMessageHandler handler =
(ByteBuffer message, BinaryMessenger.BinaryReply reply) -> {
throw new RuntimeException("hello");
};
messenger.setMessageHandler(channel, handler, taskQueue);
final ByteBuffer message = ByteBuffer.allocateDirect(4 * 2);
int replyId = 1;
long messageData = 1234;
final int replyId = 1;
final long messageData = 1234;
messenger.handleMessageFromDart(channel, message, replyId, messageData);
verify(fakeFlutterJni).cleanupMessageData(eq(messageData));
}
@Test
public void emptyResponseWhenHandlerIsNotSet() throws InterruptedException {
final FlutterJNI fakeFlutterJni = mock(FlutterJNI.class);
final DartMessenger messenger = new DartMessenger(fakeFlutterJni, () -> synchronousTaskQueue);
final String channel = "foobar";
final ByteBuffer message = ByteBuffer.allocateDirect(4 * 2);
final int replyId = 1;
final long messageData = 1234;
messenger.handleMessageFromDart(channel, message, replyId, messageData);
shadowOf(getMainLooper()).idle();
verify(fakeFlutterJni).invokePlatformMessageEmptyResponseCallback(replyId);
}
@Test
public void buffersResponseWhenHandlerIsNotSet() throws InterruptedException {
final FlutterJNI fakeFlutterJni = mock(FlutterJNI.class);
final DartMessenger messenger = new DartMessenger(fakeFlutterJni, () -> synchronousTaskQueue);
final BinaryMessenger.TaskQueue taskQueue = messenger.makeBackgroundTaskQueue();
final String channel = "foobar";
final ByteBuffer message = ByteBuffer.allocateDirect(4 * 2);
final int replyId = 1;
final long messageData = 1234;
messenger.enableBufferingIncomingMessages();
messenger.handleMessageFromDart(channel, message, replyId, messageData);
shadowOf(getMainLooper()).idle();
verify(fakeFlutterJni, never()).invokePlatformMessageEmptyResponseCallback(eq(replyId));
final BinaryMessenger.BinaryMessageHandler handler =
(ByteBuffer msg, BinaryMessenger.BinaryReply reply) -> {
reply.reply(ByteBuffer.wrap("done".getBytes()));
};
messenger.setMessageHandler(channel, handler, taskQueue);
shadowOf(getMainLooper()).idle();
verify(fakeFlutterJni, never()).invokePlatformMessageEmptyResponseCallback(eq(replyId));
final ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
verify(fakeFlutterJni)
.invokePlatformMessageResponseCallback(anyInt(), response.capture(), anyInt());
assertArrayEquals("done".getBytes(), response.getValue().array());
}
@Test
public void disableBufferingTriggersEmptyResponseForPendingMessages()
throws InterruptedException {
final FlutterJNI fakeFlutterJni = mock(FlutterJNI.class);
final DartMessenger messenger = new DartMessenger(fakeFlutterJni, () -> synchronousTaskQueue);
final String channel = "foobar";
final ByteBuffer message = ByteBuffer.allocateDirect(4 * 2);
final int replyId = 1;
final long messageData = 1234;
messenger.enableBufferingIncomingMessages();
messenger.handleMessageFromDart(channel, message, replyId, messageData);
shadowOf(getMainLooper()).idle();
verify(fakeFlutterJni, never()).invokePlatformMessageEmptyResponseCallback(replyId);
messenger.disableBufferingIncomingMessages();
shadowOf(getMainLooper()).idle();
verify(fakeFlutterJni).invokePlatformMessageEmptyResponseCallback(replyId);
}
@Test
public void emptyResponseWhenHandlerIsUnregistered() throws InterruptedException {
final FlutterJNI fakeFlutterJni = mock(FlutterJNI.class);
final DartMessenger messenger = new DartMessenger(fakeFlutterJni, () -> synchronousTaskQueue);
final BinaryMessenger.TaskQueue taskQueue = messenger.makeBackgroundTaskQueue();
final String channel = "foobar";
final ByteBuffer message = ByteBuffer.allocateDirect(4 * 2);
final int replyId = 1;
final long messageData = 1234;
messenger.enableBufferingIncomingMessages();
messenger.handleMessageFromDart(channel, message, replyId, messageData);
shadowOf(getMainLooper()).idle();
verify(fakeFlutterJni, never()).invokePlatformMessageEmptyResponseCallback(eq(replyId));
final BinaryMessenger.BinaryMessageHandler handler =
(ByteBuffer msg, BinaryMessenger.BinaryReply reply) -> {
reply.reply(ByteBuffer.wrap("done".getBytes()));
};
messenger.setMessageHandler(channel, handler, taskQueue);
shadowOf(getMainLooper()).idle();
verify(fakeFlutterJni, never()).invokePlatformMessageEmptyResponseCallback(eq(replyId));
final ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
verify(fakeFlutterJni)
.invokePlatformMessageResponseCallback(anyInt(), response.capture(), anyInt());
assertArrayEquals("done".getBytes(), response.getValue().array());
messenger.disableBufferingIncomingMessages();
messenger.setMessageHandler(channel, null, null); // Unregister handler.
messenger.handleMessageFromDart(channel, message, replyId, messageData);
shadowOf(getMainLooper()).idle();
verify(fakeFlutterJni).invokePlatformMessageEmptyResponseCallback(replyId);
}
public void testSerialTaskQueue() throws InterruptedException {
final FlutterJNI fakeFlutterJni = mock(FlutterJNI.class);
final DartMessenger messenger = new DartMessenger(fakeFlutterJni);