Skip to content

Commit 74da34f

Browse files
authored
Fix calling Average without any arguments (#808)
* Fix calling Average without any arguments * . * TryFindAverageMethod
1 parent a4b4d0a commit 74da34f

File tree

5 files changed

+282
-117
lines changed

5 files changed

+282
-117
lines changed

src/System.Linq.Dynamic.Core/DynamicQueryableExtensions.cs

+15-6
Original file line numberDiff line numberDiff line change
@@ -1803,10 +1803,10 @@ public static IQueryable<TResult> Select<TResult>(this IQueryable source, Parsin
18031803
LambdaExpression lambda = DynamicExpressionParser.ParseLambda(config, createParameterCtor, source.ElementType, typeof(TResult), selector, args);
18041804

18051805
var methodCallExpression = Expression.Call(
1806-
typeof(Queryable),
1806+
typeof(Queryable),
18071807
nameof(Queryable.Select),
18081808
new[] { source.ElementType, typeof(TResult) },
1809-
source.Expression,
1809+
source.Expression,
18101810
Expression.Quote(lambda)
18111811
);
18121812

@@ -2776,10 +2776,11 @@ private static TResult Execute<TResult>(MethodInfo operatorMethodInfo, IQueryabl
27762776
}
27772777

27782778
var optimized = OptimizeExpression(Expression.Call(null, operatorMethodInfo, source.Expression));
2779-
var result = source.Provider.Execute(optimized);
2779+
var result = source.Provider.Execute(optimized)!;
27802780

2781-
return (TResult)Convert.ChangeType(result, typeof(TResult));
2781+
return ConvertResultIfNeeded<TResult>(result);
27822782
}
2783+
27832784
private static object Execute(MethodInfo operatorMethodInfo, IQueryable source, LambdaExpression expression) =>
27842785
Execute(operatorMethodInfo, source, Expression.Quote(expression));
27852786

@@ -2803,12 +2804,20 @@ private static TResult Execute<TResult>(MethodInfo operatorMethodInfo, IQueryabl
28032804
: operatorMethodInfo.MakeGenericMethod(source.ElementType);
28042805

28052806
var optimized = OptimizeExpression(Expression.Call(null, operatorMethodInfo, source.Expression, expression));
2806-
var result = source.Provider.Execute(optimized);
2807+
var result = source.Provider.Execute(optimized)!;
28072808

2808-
return (TResult)Convert.ChangeType(result, typeof(TResult));
2809+
return ConvertResultIfNeeded<TResult>(result);
28092810
}
28102811

2812+
private static TResult ConvertResultIfNeeded<TResult>(object result)
2813+
{
2814+
if (result.GetType() == typeof(TResult))
2815+
{
2816+
return (TResult)result;
2817+
}
28112818

2819+
return (TResult?)Convert.ChangeType(result, typeof(TResult))!;
2820+
}
28122821
#endregion Private Helpers
28132822
}
28142823
}

src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs

+9-2
Original file line numberDiff line numberDiff line change
@@ -2072,8 +2072,8 @@ private bool TryParseEnumerable(Expression instance, Type elementType, string me
20722072
_it = outerIt;
20732073
_parent = oldParent;
20742074

2075-
var typeToCheckForTypeOfString = type ?? instance.Type;
2076-
if (typeToCheckForTypeOfString == typeof(string) && _methodFinder.ContainsMethod(typeToCheckForTypeOfString, methodName, false, instance, ref args))
2075+
var theType = type ?? instance.Type;
2076+
if (theType == typeof(string) && _methodFinder.ContainsMethod(theType, methodName, false, instance, ref args))
20772077
{
20782078
// In case the type is a string, and does contain the methodName (like "IndexOf"), then return false to indicate that the methodName is not an Enumerable method.
20792079
expression = null;
@@ -2096,6 +2096,13 @@ private bool TryParseEnumerable(Expression instance, Type elementType, string me
20962096
callType = typeof(Queryable);
20972097
}
20982098

2099+
// #633 - For Average without any arguments, try to find the non-generic Average method on the callType for the supplied parameter type.
2100+
if (methodName == nameof(Enumerable.Average) && args.Length == 0 && _methodFinder.TryFindAverageMethod(callType, theType, out var averageMethod))
2101+
{
2102+
expression = Expression.Call(null, averageMethod, new[] { instance });
2103+
return true;
2104+
}
2105+
20992106
Type[] typeArgs;
21002107
if (new[] { "OfType", "Cast" }.Contains(methodName))
21012108
{

src/System.Linq.Dynamic.Core/Parser/SupportedMethods/MethodFinder.cs

+14
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Collections.Generic;
2+
using System.Diagnostics.CodeAnalysis;
23
using System.Linq.Dynamic.Core.Validation;
34
using System.Linq.Expressions;
45
using System.Reflection;
@@ -44,6 +45,19 @@ public MethodFinder(ParsingConfig parsingConfig, IExpressionHelper expressionHel
4445
_expressionHelper = Check.NotNull(expressionHelper);
4546
}
4647

48+
public bool TryFindAverageMethod(Type callType, Type parameterType, [NotNullWhen(true)] out MethodInfo? averageMethod)
49+
{
50+
averageMethod = callType
51+
.GetMethods()
52+
.Where(m => m is { Name: nameof(Enumerable.Average), IsGenericMethodDefinition: false })
53+
.SelectMany(m => m.GetParameters(), (m, p) => new { Method = m, Parameter = p })
54+
.Where(x => x.Parameter.ParameterType == parameterType)
55+
.Select(x => x.Method)
56+
.FirstOrDefault();
57+
58+
return averageMethod != null;
59+
}
60+
4761
public void CheckAggregateMethodAndTryUpdateArgsToMatchMethodArgs(string methodName, ref Expression[] args)
4862
{
4963
if (methodName is nameof(IAggregateSignatures.Average) or nameof(IAggregateSignatures.Sum))

src/System.Linq.Dynamic.Core/Util/QueryableMethodFinder.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ public static MethodInfo GetGenericMethod(string name)
1111
{
1212
return typeof(Queryable).GetTypeInfo().GetDeclaredMethods(name).Single(mi => mi.IsGenericMethod);
1313
}
14-
14+
1515
public static MethodInfo GetMethod(string name, Type argumentType, Type returnType, int parameterCount = 0, Func<MethodInfo, bool>? predicate = null) =>
16-
GetMethod(name, returnType, parameterCount, mi => mi.ToString().Contains(argumentType.ToString()) && ((predicate == null) || predicate(mi)));
16+
GetMethod(name, returnType, parameterCount, mi => mi.ToString().Contains(argumentType.ToString()) && (predicate == null || predicate(mi)));
1717

1818
public static MethodInfo GetMethod(string name, Type returnType, int parameterCount = 0, Func<MethodInfo, bool>? predicate = null)
1919
{

0 commit comments

Comments
 (0)