Skip to content

Commit 26ed2f4

Browse files
committed
Fixed #74 (Join on nullable and not nullable type)
1 parent 1e5c225 commit 26ed2f4

File tree

7 files changed

+212
-55
lines changed

7 files changed

+212
-55
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//using System.Linq.Expressions;
2+
3+
//namespace System.Linq.Dynamic.Core
4+
//{
5+
// /// <summary>
6+
// /// DynamicExpressionArgument
7+
// /// </summary>
8+
// public class DynamicExpressionArgument
9+
// {
10+
// /// <summary>
11+
// /// If set to <c>true</c> then also create a constructor for all the parameters. Note that this doesn't work for Linq-to-Database entities.
12+
// /// </summary>
13+
// public bool CreateParameterCtor { get; set; }
14+
15+
// /// <summary>
16+
// /// Parameters
17+
// /// </summary>
18+
// public ParameterExpression[] Parameters { get; set; }
19+
20+
// /// <summary>
21+
// /// ResultType
22+
// /// </summary>
23+
// public Type ResultType { get; set; }
24+
25+
// /// <summary>
26+
// /// Expression
27+
// /// </summary>
28+
// public string Expression { get; set; }
29+
30+
// /// <summary>
31+
// /// Values
32+
// /// </summary>
33+
// public object[] Values { get; set; }
34+
// }
35+
//}

src/System.Linq.Dynamic.Core/DynamicQueryableExtensions.cs

+83-54
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using System.Collections.Generic;
22
using System.Collections;
3+
using System.Globalization;
4+
using System.Linq.Dynamic.Core.Exceptions;
35
#if !(WINDOWS_APP45x || SILVERLIGHT)
46
using System.Diagnostics;
57
#endif
@@ -46,7 +48,7 @@ private static Expression OptimizeExpression(Expression expression)
4648
return expression;
4749
}
4850

49-
#region Any
51+
#region Any
5052
private static readonly MethodInfo _any = GetMethod(nameof(Queryable.Any));
5153

5254
/// <summary>
@@ -94,9 +96,9 @@ public static bool Any([NotNull] this IQueryable source, [NotNull] string predic
9496

9597
return Execute<bool>(_anyPredicate, source, lambda);
9698
}
97-
#endregion Any
99+
#endregion Any
98100

99-
#region AsEnumerable
101+
#region AsEnumerable
100102
#if NET35
101103
/// <summary>
102104
/// Returns the input typed as <see cref="IEnumerable{T}"/> of <see cref="object"/>./>
@@ -118,9 +120,9 @@ public static IEnumerable<dynamic> AsEnumerable([NotNull] this IQueryable source
118120
yield return obj;
119121
}
120122
}
121-
#endregion AsEnumerable
123+
#endregion AsEnumerable
122124

123-
#region Count
125+
#region Count
124126
private static readonly MethodInfo _count = GetMethod(nameof(Queryable.Count));
125127

126128
/// <summary>
@@ -168,9 +170,9 @@ public static int Count([NotNull] this IQueryable source, [NotNull] string predi
168170

169171
return Execute<int>(_countPredicate, source, lambda);
170172
}
171-
#endregion Count
173+
#endregion Count
172174

173-
#region Distinct
175+
#region Distinct
174176
private static readonly MethodInfo _distinct = GetMethod(nameof(Queryable.Distinct));
175177

176178
/// <summary>
@@ -192,9 +194,9 @@ public static IQueryable Distinct([NotNull] this IQueryable source)
192194
var optimized = OptimizeExpression(Expression.Call(typeof(Queryable), "Distinct", new Type[] { source.ElementType }, source.Expression));
193195
return source.Provider.CreateQuery(optimized);
194196
}
195-
#endregion Distinct
197+
#endregion Distinct
196198

197-
#region First
199+
#region First
198200
private static readonly MethodInfo _first = GetMethod(nameof(Queryable.First));
199201

200202
/// <summary>
@@ -236,9 +238,9 @@ public static dynamic First([NotNull] this IQueryable source, [NotNull] string p
236238

237239
return Execute(_firstPredicate, source, lambda);
238240
}
239-
#endregion First
241+
#endregion First
240242

241-
#region FirstOrDefault
243+
#region FirstOrDefault
242244
/// <summary>
243245
/// Returns the first element of a sequence, or a default value if the sequence contains no elements.
244246
/// </summary>
@@ -278,9 +280,9 @@ public static dynamic FirstOrDefault([NotNull] this IQueryable source, [NotNull]
278280
return Execute(_firstOrDefaultPredicate, source, lambda);
279281
}
280282
private static readonly MethodInfo _firstOrDefaultPredicate = GetMethod(nameof(Queryable.FirstOrDefault), 1);
281-
#endregion FirstOrDefault
283+
#endregion FirstOrDefault
282284

283-
#region GroupBy
285+
#region GroupBy
284286
/// <summary>
285287
/// Groups the elements of a sequence according to a specified key string function
286288
/// and creates a result value from each group and its key.
@@ -366,9 +368,9 @@ public static IQueryable GroupBy([NotNull] this IQueryable source, [NotNull] str
366368

367369
return source.Provider.CreateQuery(optimized);
368370
}
369-
#endregion GroupBy
371+
#endregion GroupBy
370372

371-
#region GroupByMany
373+
#region GroupByMany
372374
/// <summary>
373375
/// Groups the elements of a sequence according to multiple specified key string functions
374376
/// and creates a result value from each group (and subgroups) and its key.
@@ -427,9 +429,9 @@ static IEnumerable<GroupResult> GroupByManyInternal<TElement>(IEnumerable<TEleme
427429

428430
return result;
429431
}
430-
#endregion GroupByMany
432+
#endregion GroupByMany
431433

432-
#region Join
434+
#region Join
433435
/// <summary>
434436
/// Correlates the elements of two sequences based on matching keys. The default equality comparer is used to compare keys.
435437
/// </summary>
@@ -450,20 +452,47 @@ public static IQueryable Join([NotNull] this IQueryable outer, [NotNull] IEnumer
450452
Check.NotEmpty(innerKeySelector, nameof(innerKeySelector));
451453
Check.NotEmpty(resultSelector, nameof(resultSelector));
452454

455+
Type outerType = outer.ElementType;
456+
Type innerType = inner.AsQueryable().ElementType;
457+
453458
bool createParameterCtor = outer.IsLinqToObjects();
454-
LambdaExpression outerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, outer.ElementType, null, outerKeySelector, args);
455-
LambdaExpression innerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, inner.AsQueryable().ElementType, null, innerKeySelector, args);
459+
LambdaExpression outerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, outerType, null, outerKeySelector, args);
460+
LambdaExpression innerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, innerType, null, innerKeySelector, args);
461+
462+
Type outerSelectorReturnType = outerSelectorLambda.Body.Type;
463+
Type innerSelectorReturnType = innerSelectorLambda.Body.Type;
464+
465+
// If types are not the same, try to convert to Nullable and generate new LambdaExpression
466+
if (outerSelectorReturnType != innerSelectorReturnType)
467+
{
468+
if (ExpressionParser.IsNullableType(outerSelectorReturnType) && !ExpressionParser.IsNullableType(innerSelectorReturnType))
469+
{
470+
innerSelectorReturnType = ExpressionParser.ToNullableType(innerSelectorReturnType);
471+
innerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, innerType, innerSelectorReturnType, innerKeySelector, args);
472+
}
473+
else if (!ExpressionParser.IsNullableType(outerSelectorReturnType) && ExpressionParser.IsNullableType(innerSelectorReturnType))
474+
{
475+
outerSelectorReturnType = ExpressionParser.ToNullableType(outerSelectorReturnType);
476+
outerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, outerType, outerSelectorReturnType, outerKeySelector, args);
477+
}
478+
479+
// If types are still not the same, throw an Exception
480+
if (outerSelectorReturnType != innerSelectorReturnType)
481+
{
482+
throw new ParseException(string.Format(CultureInfo.CurrentCulture, Res.IncompatibleTypes, outerType, innerType), -1);
483+
}
484+
}
456485

457-
ParameterExpression[] parameters = new[]
486+
ParameterExpression[] parameters =
458487
{
459-
Expression.Parameter(outer.ElementType, "outer"), Expression.Parameter(inner.AsQueryable().ElementType, "inner")
488+
Expression.Parameter(outerType, "outer"), Expression.Parameter(innerType, "inner")
460489
};
461490

462491
LambdaExpression resultSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, parameters, null, resultSelector, args);
463492

464493
var optimized = OptimizeExpression(Expression.Call(
465494
typeof(Queryable), "Join",
466-
new[] { outer.ElementType, inner.AsQueryable().ElementType, outerSelectorLambda.Body.Type, resultSelectorLambda.Body.Type },
495+
new[] { outerType, innerType, outerSelectorLambda.Body.Type, resultSelectorLambda.Body.Type },
467496
outer.Expression, // outer: The first sequence to join.
468497
inner.AsQueryable().Expression, // inner: The sequence to join to the first sequence.
469498
Expression.Quote(outerSelectorLambda), // outerKeySelector: A function to extract the join key from each element of the first sequence.
@@ -490,9 +519,9 @@ public static IQueryable<TElement> Join<TElement>([NotNull] this IQueryable<TEle
490519
{
491520
return (IQueryable<TElement>)Join((IQueryable)outer, (IEnumerable)inner, outerKeySelector, innerKeySelector, resultSelector, args);
492521
}
493-
#endregion Join
522+
#endregion Join
494523

495-
#region Last
524+
#region Last
496525
private static readonly MethodInfo _last = GetMethod(nameof(Queryable.Last));
497526
/// <summary>
498527
/// Returns the last element of a sequence.
@@ -509,9 +538,9 @@ public static dynamic Last([NotNull] this IQueryable source)
509538

510539
return Execute(_last, source);
511540
}
512-
#endregion Last
541+
#endregion Last
513542

514-
#region LastOrDefault
543+
#region LastOrDefault
515544
private static readonly MethodInfo _lastDefault = GetMethod(nameof(Queryable.LastOrDefault));
516545
/// <summary>
517546
/// Returns the last element of a sequence, or a default value if the sequence contains no elements.
@@ -528,9 +557,9 @@ public static dynamic LastOrDefault([NotNull] this IQueryable source)
528557

529558
return Execute(_lastDefault, source);
530559
}
531-
#endregion LastOrDefault
560+
#endregion LastOrDefault
532561

533-
#region OrderBy
562+
#region OrderBy
534563
/// <summary>
535564
/// Sorts the elements of a sequence in ascending or descending order according to a key.
536565
/// </summary>
@@ -589,9 +618,9 @@ public static IOrderedQueryable OrderBy([NotNull] this IQueryable source, [NotNu
589618
var optimized = OptimizeExpression(queryExpr);
590619
return (IOrderedQueryable)source.Provider.CreateQuery(optimized);
591620
}
592-
#endregion OrderBy
621+
#endregion OrderBy
593622

594-
#region Page/PageResult
623+
#region Page/PageResult
595624
/// <summary>
596625
/// Returns the elements as paged.
597626
/// </summary>
@@ -677,9 +706,9 @@ public static PagedResult<TSource> PageResult<TSource>([NotNull] this IQueryable
677706

678707
return result;
679708
}
680-
#endregion Page/PageResult
709+
#endregion Page/PageResult
681710

682-
#region Reverse
711+
#region Reverse
683712
/// <summary>
684713
/// Inverts the order of the elements in a sequence.
685714
/// </summary>
@@ -691,9 +720,9 @@ public static IQueryable Reverse([NotNull] this IQueryable source)
691720

692721
return Queryable.Reverse((IQueryable<object>)source);
693722
}
694-
#endregion Reverse
723+
#endregion Reverse
695724

696-
#region Select
725+
#region Select
697726
/// <summary>
698727
/// Projects each element of a sequence into a new form.
699728
/// </summary>
@@ -786,9 +815,9 @@ public static IQueryable Select([NotNull] this IQueryable source, [NotNull] Type
786815

787816
return source.Provider.CreateQuery(optimized);
788817
}
789-
#endregion Select
818+
#endregion Select
790819

791-
#region SelectMany
820+
#region SelectMany
792821
/// <summary>
793822
/// Projects each element of a sequence to an <see cref="IQueryable"/> and combines the resulting sequences into one sequence.
794823
/// </summary>
@@ -991,9 +1020,9 @@ public static IQueryable SelectMany([NotNull] this IQueryable source, [NotNull]
9911020

9921021
return source.Provider.CreateQuery(optimized);
9931022
}
994-
#endregion SelectMany
1023+
#endregion SelectMany
9951024

996-
#region Single/SingleOrDefault
1025+
#region Single/SingleOrDefault
9971026
/// <summary>
9981027
/// Returns the only element of a sequence, and throws an exception if there
9991028
/// is not exactly one element in the sequence.
@@ -1030,9 +1059,9 @@ public static dynamic SingleOrDefault([NotNull] this IQueryable source)
10301059
var optimized = OptimizeExpression(Expression.Call(typeof(Queryable), "SingleOrDefault", new[] { source.ElementType }, source.Expression));
10311060
return source.Provider.Execute(optimized);
10321061
}
1033-
#endregion Single/SingleOrDefault
1062+
#endregion Single/SingleOrDefault
10341063

1035-
#region Skip
1064+
#region Skip
10361065
private static readonly MethodInfo _skip = GetMethod(nameof(Queryable.Skip), 1);
10371066

10381067
/// <summary>
@@ -1052,9 +1081,9 @@ public static IQueryable Skip([NotNull] this IQueryable source, int count)
10521081

10531082
return CreateQuery(_skip, source, Expression.Constant(count));
10541083
}
1055-
#endregion Skip
1084+
#endregion Skip
10561085

1057-
#region SkipWhile
1086+
#region SkipWhile
10581087
private static readonly MethodInfo _skipWhilePredicate = GetMethod(nameof(Queryable.SkipWhile), 1, _predicateParameterHas2);
10591088

10601089
/// <summary>
@@ -1081,9 +1110,9 @@ public static IQueryable SkipWhile([NotNull] this IQueryable source, [NotNull] s
10811110

10821111
return CreateQuery(_skipWhilePredicate, source, lambda);
10831112
}
1084-
#endregion SkipWhile
1113+
#endregion SkipWhile
10851114

1086-
#region Sum
1115+
#region Sum
10871116
/// <summary>
10881117
/// Computes the sum of a sequence of numeric values.
10891118
/// </summary>
@@ -1096,9 +1125,9 @@ public static object Sum([NotNull] this IQueryable source)
10961125
var optimized = OptimizeExpression(Expression.Call(typeof(Queryable), "Sum", null, source.Expression));
10971126
return source.Provider.Execute(optimized);
10981127
}
1099-
#endregion Sum
1128+
#endregion Sum
11001129

1101-
#region Take
1130+
#region Take
11021131
private static readonly MethodInfo _take = GetMethod(nameof(Queryable.Take), 1);
11031132
/// <summary>
11041133
/// Returns a specified number of contiguous elements from the start of a sequence.
@@ -1113,9 +1142,9 @@ public static IQueryable Take([NotNull] this IQueryable source, int count)
11131142

11141143
return CreateQuery(_take, source, Expression.Constant(count));
11151144
}
1116-
#endregion Take
1145+
#endregion Take
11171146

1118-
#region TakeWhile
1147+
#region TakeWhile
11191148
private static readonly MethodInfo _takeWhilePredicate = GetMethod(nameof(Queryable.TakeWhile), 1, _predicateParameterHas2);
11201149

11211150
/// <summary>
@@ -1144,7 +1173,7 @@ public static IQueryable TakeWhile([NotNull] this IQueryable source, [NotNull] s
11441173
}
11451174
#endregion TakeWhile
11461175

1147-
#region ThenBy
1176+
#region ThenBy
11481177
/// <summary>
11491178
/// Performs a subsequent ordering of the elements in a sequence in ascending order according to a key.
11501179
/// </summary>
@@ -1205,9 +1234,9 @@ public static IOrderedQueryable ThenBy([NotNull] this IOrderedQueryable source,
12051234
var optimized = OptimizeExpression(queryExpr);
12061235
return (IOrderedQueryable)source.Provider.CreateQuery(optimized);
12071236
}
1208-
#endregion OrderBy
1237+
#endregion OrderBy
12091238

1210-
#region Where
1239+
#region Where
12111240
/// <summary>
12121241
/// Filters a sequence of values based on a predicate.
12131242
/// </summary>
@@ -1260,9 +1289,9 @@ public static IQueryable Where([NotNull] this IQueryable source, [NotNull] strin
12601289
var optimized = OptimizeExpression(Expression.Call(typeof(Queryable), "Where", new[] { source.ElementType }, source.Expression, Expression.Quote(lambda)));
12611290
return source.Provider.CreateQuery(optimized);
12621291
}
1263-
#endregion
1292+
#endregion
12641293

1265-
#region Private Helpers
1294+
#region Private Helpers
12661295
// Code below is based on https://github.com/aspnet/EntityFramework/blob/9186d0b78a3176587eeb0f557c331f635760fe92/src/Microsoft.EntityFrameworkCore/EntityFrameworkQueryableExtensions.cs
12671296

12681297
private static IQueryable CreateQuery(MethodInfo operatorMethodInfo, IQueryable source)
@@ -1341,6 +1370,6 @@ private static MethodInfo GetMethod<TResult>(string name, int parameterCount = 0
13411370

13421371
private static MethodInfo GetMethod(string name, int parameterCount = 0, Func<MethodInfo, bool> predicate = null) =>
13431372
typeof(Queryable).GetTypeInfo().GetDeclaredMethods(name).Single(mi => (mi.GetParameters().Length == parameterCount + 1) && ((predicate == null) || predicate(mi)));
1344-
#endregion Private Helpers
1373+
#endregion Private Helpers
13451374
}
13461375
}

0 commit comments

Comments
 (0)