-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: counter_test.go
153 lines (130 loc) · 3.2 KB
/
counter_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package xio_test
import (
	"bytes"
	"errors"
	"io"
	"strings"
	"sync"
	"testing"

	"github.com/ozanh/xio"
)
// counterTestCase is one table entry shared by the read-counter tests in
// this file.
type counterTestCase struct {
	// name is passed to t.Run as the subtest name.
	name string
	// reader is the source handed to the counter under test.
	reader io.Reader

	// Expectations.
	//
	// wantData is compared against the bytes read; a nil slice makes the
	// tests skip that comparison.
	wantData []byte
	// wantCount is the value Count is expected to report after reading.
	wantCount uint64
	// wantErr is the error the read is expected to produce.
	wantErr error

	// concurrency, when greater than zero, is the number of goroutines
	// the atomic-counter test starts to call Read in parallel.
	concurrency int
}
// TestReadCounter checks at run time that pointers to both counter wrappers
// satisfy the xio.ReadCounter interface.
func TestReadCounter(t *testing.T) {
	checks := []struct {
		value   any
		failMsg string
	}{
		{
			value:   new(xio.ReadDataCounter[io.Reader]),
			failMsg: "ReadDataCounter does not implement ReadCounter",
		},
		{
			value:   new(xio.ReadDataAtomicCounter[io.Reader]),
			failMsg: "ReadDataAtomicCounter does not implement ReadCounter",
		},
	}

	for _, c := range checks {
		// The comma-ok form reports a failed assertion without panicking.
		if _, ok := c.value.(xio.ReadCounter); !ok {
			t.Fatal(c.failMsg)
		}
	}
}
// TestReadDataCounter drains every table reader through a counter created by
// xio.NewReadDataCounter and verifies the bytes read, the error returned by
// io.ReadAll, and the total reported by Count.
func TestReadDataCounter(t *testing.T) {
	tests := []counterTestCase{
		{
			name:      "normal string read",
			reader:    strings.NewReader("hello"),
			wantData:  []byte("hello"),
			wantCount: 5,
		},
		{
			name:      "empty string read",
			reader:    strings.NewReader(""),
			wantData:  []byte(""),
			wantCount: 0,
		},
		{
			name:      "multiple reads",
			reader:    bytes.NewReader([]byte("hello world")),
			wantData:  []byte("hello world"),
			wantCount: 11,
		},
		{
			// wantData is left nil, so only the error and the (zero)
			// count are asserted for this case.
			name:    "error case",
			reader:  &xio.ErrOrEofReader{Err: xio.ErrCmpReadError},
			wantErr: xio.ErrCmpReadError,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			counter := xio.NewReadDataCounter(tt.reader)
			data, err := io.ReadAll(counter)

			// A nil wantData means the case does not care about the payload.
			if tt.wantData != nil && !bytes.Equal(data, tt.wantData) {
				t.Errorf("expected data '%s', got '%s'", tt.wantData, data)
			}

			// Checked unconditionally: errors.Is(err, nil) holds only for a
			// nil err, so success cases now fail on any unexpected error
			// instead of silently passing.
			if !errors.Is(err, tt.wantErr) {
				t.Errorf("expected error %v, got %v", tt.wantErr, err)
			}

			if got := counter.Count(); got != tt.wantCount {
				t.Errorf("expected count %d, got %d", tt.wantCount, got)
			}
		})
	}
}
// TestReadDataAtomicCounter exercises the counter created by
// xio.NewReadDataAtomicCounter. Sequential cases drain the reader with
// io.ReadAll and verify data, error and count; the concurrent case issues
// Read calls from several goroutines and verifies only the final count.
func TestReadDataAtomicCounter(t *testing.T) {
	tests := []counterTestCase{
		{
			name:      "normal string read",
			reader:    strings.NewReader("hello"),
			wantData:  []byte("hello"),
			wantCount: 5,
		},
		{
			name:      "empty string read",
			reader:    strings.NewReader(""),
			wantData:  []byte(""),
			wantCount: 0,
		},
		{
			// Ten goroutines each issue a single 2-byte Read, which is
			// enough capacity (20 bytes) to drain the 11-byte source.
			name:        "concurrent reads",
			reader:      strings.NewReader("hello world"),
			wantCount:   11,
			concurrency: 10,
		},
		{
			name:    "error case",
			reader:  &xio.ErrOrEofReader{Err: xio.ErrCmpReadError},
			wantErr: xio.ErrCmpReadError,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// syncReader serialises access to the underlying reader so that
			// the concurrent case only stresses the counter itself.
			counter := xio.NewReadDataAtomicCounter(&syncReader{r: tt.reader})

			if tt.concurrency > 0 {
				var wg sync.WaitGroup
				for i := 0; i < tt.concurrency; i++ {
					wg.Add(1)
					go func() {
						defer wg.Done()
						// Results are deliberately discarded: once the
						// source is drained the remaining goroutines get an
						// end-of-input error, and only the aggregated count
						// is asserted below.
						_, _ = counter.Read(make([]byte, 2))
					}()
				}
				wg.Wait()
			} else {
				data, err := io.ReadAll(counter)

				// A nil wantData means the case does not care about the payload.
				if tt.wantData != nil && !bytes.Equal(data, tt.wantData) {
					t.Errorf("expected data '%s', got '%s'", tt.wantData, data)
				}

				// errors.Is is nil-safe, unlike comparing err.Error()
				// strings, which panicked when an error was expected but
				// none was returned. It also makes success cases fail on an
				// unexpected error, since errors.Is(err, nil) holds only
				// for a nil err.
				if !errors.Is(err, tt.wantErr) {
					t.Errorf("expected error %v, got %v", tt.wantErr, err)
				}
			}

			if got := counter.Count(); got != tt.wantCount {
				t.Errorf("expected count %d, got %d", tt.wantCount, got)
			}
		})
	}
}
// syncReader wraps an io.Reader and guards every Read call with a mutex, so
// that at most one goroutine at a time reads from the wrapped reader.
type syncReader struct {
	mu sync.Mutex // held for the duration of each Read
	r  io.Reader  // wrapped source
}

// Read acquires the mutex, forwards p to the wrapped reader and returns that
// reader's result unchanged. The deferred unlock releases the mutex even if
// the wrapped Read panics.
func (s *syncReader) Read(p []byte) (int, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	n, err := s.r.Read(p)
	return n, err
}