Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(query): use quickselect instead of sorting during pagination #8995

Merged
merged 13 commits into from
Sep 14, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixed comments
Harshil Goel committed Sep 14, 2023
commit cf8dcc212d81de8b46b707832917ed7d827c8080
133 changes: 133 additions & 0 deletions types/select.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright 2016-2023 Dgraph Labs, Inc. and Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package types

// The functions below are adapted from Go's sort package (zsortinterface.go).
// insertionSort sorts the half-open range data[a:b] in place using insertion sort.
// It relies only on data.Less and data.Swap.
func insertionSort(data byValue, a, b int) {
	for next := a + 1; next < b; next++ {
		// Sink the element at next towards a until its left neighbour is not greater.
		pos := next
		for pos > a {
			if !data.Less(pos, pos-1) {
				break
			}
			data.Swap(pos, pos-1)
			pos--
		}
	}
}

// order2 returns the indices a and b ordered so that the element at the first
// returned index is not greater than the element at the second. The indices are
// returned unchanged when data[b] is not less than data[a].
func order2(data byValue, a, b int) (int, int) {
	first, second := a, b
	if data.Less(second, first) {
		first, second = second, first
	}
	return first, second
}

// median returns whichever of the indices a, b, c refers to the median of the
// three elements, found with three order2 comparisons. data is not modified.
func median(data byValue, a, b, c int) int {
	lo, mid := order2(data, a, b)
	// The larger index of the second comparison can no longer be the median.
	mid, _ = order2(data, mid, c)
	_, mid = order2(data, lo, mid)
	return mid
}

// medianAdjacent returns the index of the median of the element at a and its
// two immediate neighbours, a-1 and a+1.
func medianAdjacent(data byValue, a int) int {
	left, right := a-1, a+1
	return median(data, left, a, right)
}

// choosePivot returns the index of a pivot candidate inside data[a:b].
//
// With l = b - a, three evenly spaced candidates a+l/4, a+2*(l/4) and a+3*(l/4)
// are considered:
//   - [0, 8): the middle candidate is returned without any comparison.
//   - [8, shortestNinther): the median of the three candidates is returned.
//   - [shortestNinther, ∞): the Tukey ninther method is used, i.e. every candidate
//     is first replaced by the median of itself and its two neighbours.
//
// data is only compared, never modified.
func choosePivot(data byValue, a, b int) (pivot int) {
	const shortestNinther = 50

	l := b - a
	i := a + l/4*1
	j := a + l/4*2
	k := a + l/4*3

	if l < 8 {
		return j
	}
	if l >= shortestNinther {
		// Tukey ninther method, the idea came from Rust's implementation.
		i = medianAdjacent(data, i)
		j = medianAdjacent(data, j)
		k = medianAdjacent(data, k)
	}
	return median(data, i, j, k)
}

// partition rearranges the half-open range data[a:b] around the element that is
// initially at index pivot and returns that element's final index p. Afterwards
// every element of data[a:p] is less than data[p] and no element of data[p+1:b]
// is less than data[p].
func partition(data byValue, a, b, pivot int) int {
	// Park the pivot at a so that it stays put while the rest is scanned.
	data.Swap(a, pivot)
	left, right := a+1, b-1 // both bounds are inclusive

	for {
		// Skip the prefix that is already on the correct (smaller) side.
		for left <= right && data.Less(left, a) {
			left++
		}
		// Skip the suffix that is already on the correct (not smaller) side.
		for left <= right && !data.Less(right, a) {
			right--
		}
		if left > right {
			break
		}
		data.Swap(left, right)
		left++
		right--
	}

	// right now points at the last element smaller than the pivot (or at a itself).
	data.Swap(right, a)
	return right
}

// QuickSelect partially reorders data[low..high] (both bounds inclusive) so that
// the element at index k is the one that would be there if the range were fully
// sorted, no element before k is greater than it, and no element after k is less
// than it. k is expected to lie within [low, high].
//
// Ranges of at most nine elements are finished off with insertionSort.
func QuickSelect(data byValue, low, high, k int) {
	for low < high {
		if high-low <= 8 {
			insertionSort(data, low, high+1)
			return
		}

		// choosePivot and partition work on the half-open range [a, b) while high
		// is inclusive, so the bound has to be high+1. Passing high itself would
		// leave data[high] out of the partitioning step.
		end := high + 1
		pivotIndex := partition(data, low, end, choosePivot(data, low, end))

		switch {
		case k < pivotIndex:
			high = pivotIndex - 1
		case k > pivotIndex:
			low = pivotIndex + 1
		default:
			return
		}
	}
}
117 changes: 1 addition & 116 deletions types/sort.go
Original file line number Diff line number Diff line change
@@ -110,121 +110,6 @@ func IsSortable(tid TypeID) bool {
}
}

// insertionSort sorts the half-open range data[a:b] in place using insertion sort.
// It relies only on data.Less and data.Swap.
func insertionSort(data byValue, a, b int) {
	for next := a + 1; next < b; next++ {
		// Sink the element at next towards a until its left neighbour is not greater.
		pos := next
		for pos > a {
			if !data.Less(pos, pos-1) {
				break
			}
			data.Swap(pos, pos-1)
			pos--
		}
	}
}

// order2 returns the indices a and b ordered so that the element at the first
// returned index is not greater than the element at the second. The indices are
// returned unchanged when data[b] is not less than data[a].
func order2(data byValue, a, b int) (int, int) {
	first, second := a, b
	if data.Less(second, first) {
		first, second = second, first
	}
	return first, second
}

// median returns whichever of the indices a, b, c refers to the median of the
// three elements, found with three order2 comparisons. data is not modified.
func median(data byValue, a, b, c int) int {
	lo, mid := order2(data, a, b)
	// The larger index of the second comparison can no longer be the median.
	mid, _ = order2(data, mid, c)
	_, mid = order2(data, lo, mid)
	return mid
}

// medianAdjacent returns the index of the median of the element at a and its
// two immediate neighbours, a-1 and a+1.
func medianAdjacent(data byValue, a int) int {
	left, right := a-1, a+1
	return median(data, left, a, right)
}

// choosePivot returns the index of a pivot candidate inside data[a:b].
//
// With l = b - a, three evenly spaced candidates a+l/4, a+2*(l/4) and a+3*(l/4)
// are considered:
//   - [0, 8): the middle candidate is returned without any comparison.
//   - [8, shortestNinther): the median of the three candidates is returned.
//   - [shortestNinther, ∞): the Tukey ninther method is used, i.e. every candidate
//     is first replaced by the median of itself and its two neighbours.
//
// data is only compared, never modified.
func choosePivot(data byValue, a, b int) (pivot int) {
	const shortestNinther = 50

	l := b - a
	i := a + l/4*1
	j := a + l/4*2
	k := a + l/4*3

	if l < 8 {
		return j
	}
	if l >= shortestNinther {
		// Tukey ninther method, the idea came from Rust's implementation.
		i = medianAdjacent(data, i)
		j = medianAdjacent(data, j)
		k = medianAdjacent(data, k)
	}
	return median(data, i, j, k)
}

// partition rearranges the half-open range data[a:b] around the element that is
// initially at index pivot and returns that element's final index p. Afterwards
// every element of data[a:p] is less than data[p] and no element of data[p+1:b]
// is less than data[p].
func partition(data byValue, a, b, pivot int) int {
	// Park the pivot at a so that it stays put while the rest is scanned.
	data.Swap(a, pivot)
	left, right := a+1, b-1 // both bounds are inclusive

	for {
		// Skip the prefix that is already on the correct (smaller) side.
		for left <= right && data.Less(left, a) {
			left++
		}
		// Skip the suffix that is already on the correct (not smaller) side.
		for left <= right && !data.Less(right, a) {
			right--
		}
		if left > right {
			break
		}
		data.Swap(left, right)
		left++
		right--
	}

	// right now points at the last element smaller than the pivot (or at a itself).
	data.Swap(right, a)
	return right
}

// randomizedSelectionFinding partially reorders data[low..high] (both bounds
// inclusive) so that the element at index k is the one that would be there if the
// range were fully sorted, no element before k is greater than it, and no element
// after k is less than it. k is expected to lie within [low, high].
//
// Ranges of at most nine elements are finished off with insertionSort.
func randomizedSelectionFinding(data byValue, low, high, k int) {
	for low < high {
		if high-low <= 8 {
			insertionSort(data, low, high+1)
			return
		}

		// choosePivot and partition work on the half-open range [a, b) while high
		// is inclusive, so the bound has to be high+1. Passing high itself would
		// leave data[high] out of the partitioning step.
		end := high + 1
		pivotIndex := partition(data, low, end, choosePivot(data, low, end))

		switch {
		case k < pivotIndex:
			high = pivotIndex - 1
		case k > pivotIndex:
			low = pivotIndex + 1
		default:
			return
		}
	}
}

// SortTopN sorts the given array in-place, but only selects and orders the top n
// entries instead of sorting every value.
func SortTopN(v [][]Val, ul *[]uint64, desc []bool, lang string, n int) error {
@@ -264,7 +149,7 @@ func SortTopN(v [][]Val, ul *[]uint64, desc []bool, lang string, n int) error {
if nul > n {
b1 := sortBase{v[:nul], desc, ul, nil, cl}
toBeSorted1 := byValue{b1}
randomizedSelectionFinding(toBeSorted1, 0, nul-1, n)
QuickSelect(toBeSorted1, 0, nul-1, n)
}
toBeSorted.values = toBeSorted.values[:n]
sort.Sort(toBeSorted)
1 change: 1 addition & 0 deletions worker/sort.go
Original file line number Diff line number Diff line change
@@ -455,6 +455,7 @@ func multiSort(ctx context.Context, r *sortresult, ts *pb.SortMessage) error {
}

start, end := x.PageRange(int(ts.Count), int(r.multiSortOffsets[i]), len(ul.Uids))
// TODO(harshil.goel): can be improved for other cases too
if end < len(ul.Uids)/2 {
//nolint:gosec
if err := types.SortTopN(vals, &ul.Uids, desc, "", end); err != nil {