从 Lambda 表达式中捕获类型参数

引言

Forge 在高版本(Minecraft 1.13+)为事件总线添加了使用 Functional Interface 监听事件的方式。以下是接口声明:

<T extends Event> void addListener(Consumer<T> consumer);  

我们可以注意到,我们不需要向 addListener 方法传入一个额外的 Class 代表事件的类型,也就是说事件类型只能通过我们传入的 Consumer<T> 拿到。换言之,我们需要从 Consumer<T> 中推断出具体的 T 是什么类型,考虑到 Java 的泛型擦除机制,这不得不说是一个困难。

本文将以 Consumer<T> 类型为例,阐述将类型参数从泛型类型的实例(尤其是 Lambda 表达式)中捕获的方式。读者应能相对容易地将其推广到其他类型。

本文将使用 Java 11(实际上相较基于 Java 8 的等价代码,只是单纯地多了一些 var 而已),以证明该解决方案不单纯限于 Java 8 或更低版本。

常规解决方案

有的读者可能会说,这很简单啊:我们只需要通过一点反射的小技巧,就可以拿到具体的类型。

类似的实现有很多,比如 Guava 在 TypeToken 中的实现。这里我们也随手写一个:

public static Class<?> getErased(Type type)  
{
    if (type instanceof ParameterizedType)
    {
        return getErased(((ParameterizedType) type).getRawType());
    }
    if (type instanceof GenericArrayType)
    {
        return Array.newInstance(getErased(((GenericArrayType) type).getGenericComponentType()), 0).getClass();
    }
    if (type instanceof TypeVariable<?>)
    {
        var bounds = ((TypeVariable<?>) type).getBounds();
        return bounds.length > 0 ? getErased(bounds[0]) : Object.class;
    }
    if (type instanceof WildcardType)
    {
        var bounds = ((WildcardType) type).getUpperBounds();
        return bounds.length > 0 ? getErased(bounds[0]) : Object.class;
    }
    if (type instanceof Class<?>)
    {
        return (Class<?>) type;
    }
    return Object.class;
}

public static Class<?> getConsumerParameterType(Consumer<?> consumer) throws ReflectiveOperationException  
{
    for (var type : consumer.getClass().getGenericInterfaces())
    {
        if (type instanceof ParameterizedType && ((ParameterizedType) type).getRawType() == Consumer.class)
        {
            return getErased(((ParameterizedType) type).getActualTypeArguments()[0]);
        }
    }
    if (consumer.getClass().isSynthetic())
    {
        return getConsumerLambdaParameterType(consumer);
    }
    throw new NoSuchMethodException();
}

getConsumerParameterType 方法的实现对于匿名内部类非常完美,换句话说,下面这一实例中的具体参数类型将会很容易地被捕获:

Consumer<String> a = new Consumer<String>  
{
    @Override
    public void accept(String s)
    {
        System.out.println(s);
    }
}

但对于 Lambda 表达式(也包括方法引用)呢?

Consumer<String> b = s -> System.out.println(s);  
Consumer<String> c = System.out::println;  

实际上,我们从 Lambda 表达式中,根本拿不到一个 ParameterizedType,更不要说从里面提取参数类型了。我们需要其他的方法,当然了,也是更为 dirty hack 的方法。

探究 Lambda 表达式常量池

每个 .class 后缀的文件中都有一段二进制存放的是该类或接口的常量池(Constant Pool),其中包含着描述每个字段和方法的符号引用。

我们知道,对于每个 Lambda 表达式,JVM 都会为其生成对应的类型,而它们也包含了常量池。我们需要做的,就是把 Lambda 表达式对应的常量池提取出来,并寻找对我们有用的方法的符号引用。

我们需要排除这些符号引用:

  • 构造方法的符号引用
  • 覆盖 Object 类的方法的符号引用

我们可以通过调用 Class 类的 getConstantPool 方法(非公开)拿到常量池,也就是 ConstantPool 类的实例。但是,从 Java 9 开始,常量池的实现被移动到了 jdk.internal.reflect 包下,换言之,如果想要调用常量池的一些非公开方法,我们需要一些更激进的策略。

我们可以通过替换 Method 类下的 override 字段来绕过 JVM 的限制。为此,我们先编写辅助用的方法:

public static Method getMethod(Class<?> objClass, String methodName) throws NoSuchMethodException  
{
    for (var method : objClass.getDeclaredMethods())
    {
        if (methodName.equals(method.getName()))
        {
            return method;
        }
    }
    throw new NoSuchMethodException();
}

public static Object invoke(Object obj, String methodName, Object... args) throws ReflectiveOperationException  
{
    var overrideField = AccessibleObject.class.getDeclaredField("override");
    overrideField.setAccessible(true);
    var targetMethod = getMethod(obj.getClass(), methodName);
    overrideField.set(targetMethod, true);
    return targetMethod.invoke(obj, args);
}

上面的 invoke 方法将会找到特定名称的方法,并针对性的调用该方法。由于仅为演示可行性用,因此上面的方法并未考虑性能,因此如果读者需要在实际开发中用到,请自行对该实现进行优化。

从常量池中获取参数类型

我们将接下来的工作分为四步:

  • 找到 Lambda 表达式对应的 Class 的常量池
  • 依次遍历常量池中对方法的符号引用
  • 排除无关方法的符号引用
  • 从第一个满足条件的方法中取出参数类型

需要用到 ConstantPool 的两个方法:

  • getSize 方法:用于获取常量池的元素个数
  • getMethodAt 方法:用于获取常量池特定位置对方法的符号引用

注意如果 getMethodAt 寻找的位置对应的不是对方法的符号引用,调用该方法将会报错。我们需要把该报错屏蔽掉,然后尝试寻找下一个常量池中的元素。

下面的实现贯彻了上面提到的四步。需要注意的一点是,遍历常量池是倒序进行的:

public static Class<?> getConsumerLambdaParameterType(Consumer<?> consumer) throws ReflectiveOperationException  
{
    var consumerClass = consumer.getClass();
    var constantPool = invoke(consumerClass, "getConstantPool");
    for (var i = (int) invoke(constantPool, "getSize") - 1; i >= 0; --i)
    {
        try
        {
            var member = (Member) invoke(constantPool, "getMethodAt", i);
            if (member instanceof Method && member.getDeclaringClass() != Object.class)
            {
                return ((Method) member).getParameterTypes()[0];
            }
        }
        catch (Exception ignored)
        {
            // ignored
        }
    }
    throw new NoSuchMethodException();
}

我们可以把该方法的调用补充到之前编写的 getConsumerParameterType 方法下。这里使用的判别标准是该方法所对应的类是否是 Java 编译器自动生成的(isSynthetic 方法返回 true):

public static Class<?> getConsumerParameterType(Consumer<?> consumer) throws ReflectiveOperationException  
{
    for (var type : consumer.getClass().getGenericInterfaces())
    {
        if (type instanceof ParameterizedType && ((ParameterizedType) type).getRawType() == Consumer.class)
        {
            return getErased(((ParameterizedType) type).getActualTypeArguments()[0]);
        }
    }
    // lambda start
    if (consumer.getClass().isSynthetic())
    {
        return getConsumerLambdaParameterType(consumer);
    }
    // lambda end
    throw new NoSuchMethodException();
}

各位可以自己试一试了:

public static void main(String[] args)  
{
    try
    {
        Consumer<Consumer<String>> consumerConsumer = c -> c.accept("");
        System.out.println(getConsumerParameterType(consumerConsumer));

        Consumer<String> stringConsumer = System.out::println;
        System.out.println(getConsumerParameterType(stringConsumer));

        Consumer<Long> longConsumer = l -> System.out.println(l + "L");
        System.out.println(getConsumerParameterType(longConsumer));
    }
    catch (ReflectiveOperationException e)
    {
        e.printStackTrace();
    }
}

鸣谢