Skip to content

Commit 8993daa

Browse files
committed
CSHARP-4453: Support Bucket and BucketAuto stages in LINQ3.
1 parent ec46c34 commit 8993daa

20 files changed

+741
-514
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using MongoDB.Bson;
17+
using MongoDB.Bson.IO;
18+
using MongoDB.Bson.Serialization;
19+
using MongoDB.Bson.Serialization.Serializers;
20+
using MongoDB.Driver.Core.Misc;
21+
22+
namespace MongoDB.Driver
23+
{
24+
/// <summary>
25+
/// Static factory class for AggregateBucketAutoResultIdSerializer.
26+
/// </summary>
27+
public static class AggregateBucketAutoResultIdSerializer
28+
{
29+
/// <summary>
30+
/// Creates an instance of AggregateBucketAutoResultIdSerializer.
31+
/// </summary>
32+
/// <typeparam name="TValue">The value type.</typeparam>
33+
/// <param name="valueSerializer">The value serializer.</param>
34+
/// <returns>A AggregateBucketAutoResultIdSerializer.</returns>
35+
public static IBsonSerializer<AggregateBucketAutoResultId<TValue>> Create<TValue>(IBsonSerializer<TValue> valueSerializer)
36+
{
37+
return new AggregateBucketAutoResultIdSerializer<TValue>(valueSerializer);
38+
}
39+
}
40+
41+
/// <summary>
42+
/// A serializer for AggregateBucketAutoResultId.
43+
/// </summary>
44+
/// <typeparam name="TValue">The type of the values.</typeparam>
45+
public class AggregateBucketAutoResultIdSerializer<TValue> : ClassSerializerBase<AggregateBucketAutoResultId<TValue>>, IBsonDocumentSerializer
46+
{
47+
private readonly IBsonSerializer<TValue> _valueSerializer;
48+
49+
/// <summary>
50+
/// Initializes a new instance of the <see cref="AggregateBucketAutoResultIdSerializer{TValue}"/> class.
51+
/// </summary>
52+
/// <param name="valueSerializer">The value serializer.</param>
53+
public AggregateBucketAutoResultIdSerializer(IBsonSerializer<TValue> valueSerializer)
54+
{
55+
_valueSerializer = Ensure.IsNotNull(valueSerializer, nameof(valueSerializer));
56+
}
57+
58+
/// <inheritdoc/>
59+
protected override AggregateBucketAutoResultId<TValue> DeserializeValue(BsonDeserializationContext context, BsonDeserializationArgs args)
60+
{
61+
var reader = context.Reader;
62+
reader.ReadStartDocument();
63+
TValue min = default;
64+
TValue max = default;
65+
while (reader.ReadBsonType() != 0)
66+
{
67+
var name = reader.ReadName();
68+
switch (name)
69+
{
70+
case "min": min = _valueSerializer.Deserialize(context); break;
71+
case "max": max = _valueSerializer.Deserialize(context); break;
72+
default: throw new BsonSerializationException($"Invalid element name for AggregateBucketAutoResultId: {name}.");
73+
}
74+
}
75+
reader.ReadEndDocument();
76+
return new AggregateBucketAutoResultId<TValue>(min, max);
77+
}
78+
79+
/// <inheritdoc/>
80+
protected override void SerializeValue(BsonSerializationContext context, BsonSerializationArgs args, AggregateBucketAutoResultId<TValue> value)
81+
{
82+
var writer = context.Writer;
83+
writer.WriteStartDocument();
84+
writer.WriteName("min");
85+
_valueSerializer.Serialize(context, value.Min);
86+
writer.WriteName("max");
87+
_valueSerializer.Serialize(context, value.Max);
88+
writer.WriteEndDocument();
89+
}
90+
91+
/// <inheritdoc/>
92+
public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo)
93+
{
94+
serializationInfo = memberName switch
95+
{
96+
"Min" => new BsonSerializationInfo("min", _valueSerializer, _valueSerializer.ValueType),
97+
"Max" => new BsonSerializationInfo("max", _valueSerializer, _valueSerializer.ValueType),
98+
_ => null
99+
};
100+
return serializationInfo != null;
101+
}
102+
}
103+
}

src/MongoDB.Driver/GroupForLinq3Result.cs

-57
This file was deleted.

src/MongoDB.Driver/IAggregateFluentExtensions.cs

+36-11
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public static IAggregateFluent<AggregateBucketAutoResult<TValue>> BucketAuto<TRe
9595
}
9696

9797
/// <summary>
98-
/// Appends a $bucketAuto stage to the pipeline.
98+
/// Appends a $bucketAuto stage to the pipeline (this overload can only be used with LINQ3).
9999
/// </summary>
100100
/// <typeparam name="TResult">The type of the result.</typeparam>
101101
/// <typeparam name="TValue">The type of the value.</typeparam>
@@ -110,13 +110,46 @@ public static IAggregateFluent<TNewResult> BucketAuto<TResult, TValue, TNewResul
110110
this IAggregateFluent<TResult> aggregate,
111111
Expression<Func<TResult, TValue>> groupBy,
112112
int buckets,
113-
Expression<Func<IGrouping<TValue, TResult>, TNewResult>> output,
113+
Expression<Func<IGrouping<AggregateBucketAutoResultId<TValue>, TResult>, TNewResult>> output,
114114
AggregateBucketAutoOptions options = null)
115115
{
116116
Ensure.IsNotNull(aggregate, nameof(aggregate));
117+
if (aggregate.Database.Client.Settings.LinqProvider != LinqProvider.V3)
118+
{
119+
throw new InvalidOperationException("This overload of BucketAuto can only be used with LINQ3.");
120+
}
121+
117122
return aggregate.AppendStage(PipelineStageDefinitionBuilder.BucketAuto(groupBy, buckets, output, options));
118123
}
119124

125+
/// <summary>
126+
/// Appends a $bucketAuto stage to the pipeline (this method can only be used with LINQ2).
127+
/// </summary>
128+
/// <typeparam name="TResult">The type of the result.</typeparam>
129+
/// <typeparam name="TValue">The type of the value.</typeparam>
130+
/// <typeparam name="TNewResult">The type of the new result.</typeparam>
131+
/// <param name="aggregate">The aggregate.</param>
132+
/// <param name="groupBy">The expression providing the value to group by.</param>
133+
/// <param name="buckets">The number of buckets.</param>
134+
/// <param name="output">The output projection.</param>
135+
/// <param name="options">The options (optional).</param>
136+
/// <returns>The fluent aggregate interface.</returns>
137+
public static IAggregateFluent<TNewResult> BucketAutoForLinq2<TResult, TValue, TNewResult>(
138+
this IAggregateFluent<TResult> aggregate,
139+
Expression<Func<TResult, TValue>> groupBy,
140+
int buckets,
141+
Expression<Func<IGrouping<TValue, TResult>, TNewResult>> output, // the IGrouping for BucketAuto has been wrong all along, only fixing it for LINQ3
142+
AggregateBucketAutoOptions options = null)
143+
{
144+
Ensure.IsNotNull(aggregate, nameof(aggregate));
145+
if (aggregate.Database.Client.Settings.LinqProvider != LinqProvider.V2)
146+
{
147+
throw new InvalidOperationException("The BucketAutoForLinq2 method can only be used with LINQ2.");
148+
}
149+
150+
return aggregate.AppendStage(PipelineStageDefinitionBuilder.BucketAutoForLinq2(groupBy, buckets, output, options));
151+
}
152+
120153
/// <summary>
121154
/// Appends a $densify stage to the pipeline.
122155
/// </summary>
@@ -396,15 +429,7 @@ public static IAggregateFluent<BsonDocument> Group<TResult>(this IAggregateFluen
396429
public static IAggregateFluent<TNewResult> Group<TResult, TKey, TNewResult>(this IAggregateFluent<TResult> aggregate, Expression<Func<TResult, TKey>> id, Expression<Func<IGrouping<TKey, TResult>, TNewResult>> group)
397430
{
398431
Ensure.IsNotNull(aggregate, nameof(aggregate));
399-
if (aggregate.Database.Client.Settings.LinqProvider == LinqProvider.V2)
400-
{
401-
return aggregate.AppendStage(PipelineStageDefinitionBuilder.Group(id, group));
402-
}
403-
else
404-
{
405-
var (groupStage, projectStage) = PipelineStageDefinitionBuilder.GroupForLinq3(id, group);
406-
return aggregate.AppendStage(groupStage).AppendStage(projectStage);
407-
}
432+
return aggregate.AppendStage(PipelineStageDefinitionBuilder.Group(id, group));
408433
}
409434

410435
/// <summary>

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupPipelineOptimizer.cs src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs

+72-18
Original file line numberDiff line numberDiff line change
@@ -24,41 +24,50 @@
2424

2525
namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers
2626
{
27-
internal class AstGroupPipelineOptimizer
27+
internal class AstGroupingPipelineOptimizer
2828
{
2929
#region static
3030
public static AstPipeline Optimize(AstPipeline pipeline)
3131
{
32-
var optimizer = new AstGroupPipelineOptimizer();
32+
var optimizer = new AstGroupingPipelineOptimizer();
3333
for (var i = 0; i < pipeline.Stages.Count; i++)
3434
{
3535
var stage = pipeline.Stages[i];
36-
if (stage is AstGroupStage groupStage)
36+
if (IsGroupingStage(stage))
3737
{
38-
pipeline = optimizer.OptimizeGroupStage(pipeline, i, groupStage);
38+
pipeline = optimizer.OptimizeGroupingStage(pipeline, i, stage);
3939
}
4040
}
4141

4242
return pipeline;
43+
44+
static bool IsGroupingStage(AstStage stage)
45+
{
46+
return stage.NodeType switch
47+
{
48+
AstNodeType.GroupStage or AstNodeType.BucketStage or AstNodeType.BucketAutoStage => true,
49+
_ => false
50+
};
51+
}
4352
}
4453
#endregion
4554

4655
private readonly AccumulatorSet _accumulators = new AccumulatorSet();
4756
private AstExpression _element; // normally either "$$ROOT" or "$_v"
4857

49-
private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStage groupStage)
58+
private AstPipeline OptimizeGroupingStage(AstPipeline pipeline, int i, AstStage groupingStage)
5059
{
5160
try
5261
{
53-
if (IsOptimizableGroupStage(groupStage, out _element))
62+
if (IsOptimizableGroupingStage(groupingStage, out _element))
5463
{
5564
var followingStages = GetFollowingStagesToOptimize(pipeline, i + 1);
5665
if (followingStages == null)
5766
{
5867
return pipeline;
5968
}
6069

61-
var mappings = OptimizeGroupAndFollowingStages(groupStage, followingStages);
70+
var mappings = OptimizeGroupingAndFollowingStages(groupingStage, followingStages);
6271
if (mappings.Length > 0)
6372
{
6473
return (AstPipeline)AstNodeReplacer.Replace(pipeline, mappings);
@@ -72,23 +81,57 @@ private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStag
7281

7382
return pipeline;
7483

75-
static bool IsOptimizableGroupStage(AstGroupStage groupStage, out AstExpression element)
84+
static bool IsOptimizableGroupingStage(AstStage groupingStage, out AstExpression element)
7685
{
77-
// { $group : { _id : ?, _elements : { $push : element } } }
78-
if (groupStage.Fields.Count == 1)
86+
if (groupingStage is AstGroupStage groupStage)
87+
{
88+
// { $group : { _id : ?, _elements : { $push : element } } }
89+
if (groupStage.Fields.Count == 1)
90+
{
91+
var field = groupStage.Fields[0];
92+
return IsElementsPush(field, out element);
93+
}
94+
}
95+
96+
if (groupingStage is AstBucketStage bucketStage)
97+
{
98+
// { $bucket : { groupBy : ?, boundaries : ?, default : ?, output : { _elements : { $push : element } } } }
99+
if (bucketStage.Output.Count == 1)
100+
{
101+
var output = bucketStage.Output[0];
102+
return IsElementsPush(output, out element);
103+
}
104+
}
105+
106+
if (groupingStage is AstBucketAutoStage bucketAutoStage)
79107
{
80-
var field = groupStage.Fields[0];
81-
if (field.Path == "_elements" &&
108+
// { $bucketAuto : { groupBy : ?, buckets : ?, granularity : ?, output : { _elements : { $push : element } } } }
109+
if (bucketAutoStage.Output.Count == 1)
110+
{
111+
var output = bucketAutoStage.Output[0];
112+
return IsElementsPush(output, out element);
113+
}
114+
}
115+
116+
element = null;
117+
return false;
118+
119+
static bool IsElementsPush(AstAccumulatorField field, out AstExpression element)
120+
{
121+
if (
122+
field.Path == "_elements" &&
82123
field.Value is AstUnaryAccumulatorExpression unaryAccumulatorExpression &&
83124
unaryAccumulatorExpression.Operator == AstUnaryAccumulatorOperator.Push)
84125
{
85126
element = unaryAccumulatorExpression.Arg;
86127
return true;
87128
}
129+
else
130+
{
131+
element = null;
132+
return false;
133+
}
88134
}
89-
90-
element = null;
91-
return false;
92135
}
93136

94137
static List<AstStage> GetFollowingStagesToOptimize(AstPipeline pipeline, int from)
@@ -135,7 +178,7 @@ static bool IsLastStageThatCanBeOptimized(AstStage stage)
135178
}
136179
}
137180

138-
private (AstNode, AstNode)[] OptimizeGroupAndFollowingStages(AstGroupStage groupStage, List<AstStage> followingStages)
181+
private (AstNode, AstNode)[] OptimizeGroupingAndFollowingStages(AstStage groupingStage, List<AstStage> followingStages)
139182
{
140183
var mappings = new List<(AstNode, AstNode)>();
141184

@@ -148,10 +191,21 @@ static bool IsLastStageThatCanBeOptimized(AstStage stage)
148191
}
149192
}
150193

151-
var newGroupStage = AstStage.Group(groupStage.Id, _accumulators);
152-
mappings.Add((groupStage, newGroupStage));
194+
var newGroupingStage = CreateNewGroupingStage(groupingStage, _accumulators);
195+
mappings.Add((groupingStage, newGroupingStage));
153196

154197
return mappings.ToArray();
198+
199+
static AstStage CreateNewGroupingStage(AstStage groupingStage, AccumulatorSet accumulators)
200+
{
201+
return groupingStage switch
202+
{
203+
AstGroupStage groupStage => AstStage.Group(groupStage.Id, accumulators),
204+
AstBucketStage bucketStage => AstStage.Bucket(bucketStage.GroupBy, bucketStage.Boundaries, bucketStage.Default, accumulators),
205+
AstBucketAutoStage bucketAutoStage => AstStage.BucketAuto(bucketAutoStage.GroupBy, bucketAutoStage.Buckets, bucketAutoStage.Granularity, accumulators),
206+
_ => throw new Exception($"Unexpected {nameof(groupingStage)} node type: {groupingStage.NodeType}.")
207+
};
208+
}
155209
}
156210

157211
private AstStage OptimizeFollowingStage(AstStage stage)

0 commit comments

Comments
 (0)