1 package emissary.kff;
2
3 import emissary.kff.KffFilter.FilterType;
4 import emissary.test.core.junit5.UnitTest;
5
6 import jakarta.annotation.Nullable;
7 import net.spy.memcached.MemcachedClient;
8 import net.spy.memcached.internal.GetFuture;
9 import net.spy.memcached.internal.OperationFuture;
10 import net.spy.memcached.ops.Operation;
11 import net.spy.memcached.ops.OperationStatus;
12 import org.apache.commons.lang3.NotImplementedException;
13 import org.junit.jupiter.api.AfterEach;
14 import org.junit.jupiter.api.BeforeEach;
15 import org.junit.jupiter.api.Test;
16 import org.mockito.ArgumentMatchers;
17 import org.mockito.stubbing.Answer;
18
import java.io.IOException;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
30
31 import static org.junit.jupiter.api.Assertions.assertEquals;
32 import static org.junit.jupiter.api.Assertions.assertFalse;
33 import static org.junit.jupiter.api.Assertions.assertThrows;
34 import static org.junit.jupiter.api.Assertions.assertTrue;
35 import static org.mockito.Mockito.mock;
36 import static org.mockito.Mockito.validateMockitoUsage;
37 import static org.mockito.Mockito.when;
38
39 class KffMemcachedTest extends UnitTest {
40
41 private static final String TEST_ID_WITH_SPACES = "TEST ID";
42 private static final String TEST_PAYLOAD = "TEST DATA";
43 private static final String TEST_UNFORMATTED_ID_HASH = "01e44cd59b2c0e8acbb99647d579f74f91bde66e4a243dc212a3c8e8739c9957";
44 private String expectedKey = "";
45 @Nullable
46 private MemcachedClient mockMemcachedClient = null;
47 private boolean isBinaryConnection = false;
48 @Nullable
49 private String cacheResult = null;
50
51 @BeforeEach
52 public void setup() {
53 mockMemcachedClient = createMockMemcachedClient();
54
55 }
56
57 @AfterEach
58 @Override
59 public void tearDown() throws Exception {
60 super.tearDown();
61 isBinaryConnection = false;
62 validateMockitoUsage();
63 }
64
65 @Test
66 void testKffMemcachedCreation() throws Exception {
67 KffMemcached mcdFilter = createTestFilter(Boolean.TRUE, Boolean.TRUE, TEST_ID_WITH_SPACES);
68 mcdFilter.setPreferredAlgorithm("SHA-256");
69 assertEquals("SHA-256", mcdFilter.getPreferredAlgorithm());
70 assertEquals("KFF", mcdFilter.getName());
71 assertEquals(FilterType.DUPLICATE, mcdFilter.getFilterType());
72 }
73
74 @Test
75 void testThrowsWithNonAsciiAndDups() throws Exception {
76 KffMemcached mcdFilter = createTestFilter(Boolean.TRUE, Boolean.TRUE, TEST_ID_WITH_SPACES);
77 ChecksumResults results = createSums(mcdFilter);
78 assertThrows(IllegalArgumentException.class, () -> {
79 mcdFilter.check(TEST_ID_WITH_SPACES, results);
80 });
81 }
82
83
84 @Test
85 void testNoHitNoStoreIdDupe() throws Exception {
86 KffMemcached mcdFilter = createTestFilter(Boolean.FALSE, Boolean.FALSE, TEST_UNFORMATTED_ID_HASH);
87 assertFalse(mcdFilter.check(TEST_ID_WITH_SPACES, createSums(mcdFilter)), "Filter should not hit");
88 }
89
90 @Test
91 void testHitNoStoreIdDupe() throws Exception {
92 KffMemcached mcdFilter = createTestFilter(Boolean.FALSE, Boolean.TRUE, null);
93 assertTrue(mcdFilter.check(TEST_ID_WITH_SPACES, createSums(mcdFilter)), "Filter should hit");
94 }
95
96 @Test
97 void testNoHitWithStoreIdDupe() throws Exception {
98 KffMemcached mcdFilter = createTestFilter(Boolean.TRUE, Boolean.FALSE, TEST_UNFORMATTED_ID_HASH);
99 assertFalse(mcdFilter.check(TEST_ID_WITH_SPACES, createSums(mcdFilter)), "Filter should not hit");
100 }
101
102 @Test
103 void testHitWithStoreIdDupe() throws Exception {
104 isBinaryConnection = true;
105 KffMemcached mcdFilter = createTestFilter(Boolean.TRUE, Boolean.TRUE, TEST_ID_WITH_SPACES);
106 assertTrue(mcdFilter.check(TEST_ID_WITH_SPACES, createSums(mcdFilter)), "Filter should hit");
107 }
108
109 private static ChecksumResults createSums(KffMemcached mcd) throws NoSuchAlgorithmException {
110 List<String> kffalgs = new ArrayList<>();
111 kffalgs.add(mcd.getPreferredAlgorithm());
112 return new ChecksumCalculator(kffalgs).digest(TEST_PAYLOAD.getBytes());
113 }
114
115
116 private KffMemcached createTestFilter(Boolean storeIdDupe, boolean simulateHit, @Nullable String expectedKey)
117 throws IOException, NoSuchFieldException,
118 IllegalAccessException {
119 KffMemcached filter = new KffMemcached(TEST_ID_WITH_SPACES, "KFF", FilterType.DUPLICATE, mockMemcachedClient);
120 setPrivateMembersForTesting(filter, storeIdDupe);
121 if (simulateHit) {
122 cacheResult = "FAKE FIND";
123 } else {
124 cacheResult = null;
125 }
126 this.expectedKey = expectedKey;
127 return filter;
128 }
129
130 private static void checkForValidAscii(String key) {
131
132
133 if (key.length() > 250 || key.contains(" ") || key.contains("\n")) {
134 throw new IllegalArgumentException("Invalid Key for ASCII Memcached");
135 }
136 }
137
138 private static void setPrivateMembersForTesting(KffMemcached cacheFilter, @Nullable Boolean storeIdDupe)
139 throws NoSuchFieldException, IllegalAccessException {
140
141
142 if (storeIdDupe != null) {
143 Field storeIdDupeField = KffMemcached.class.getDeclaredField("storeIdDupe");
144 storeIdDupeField.setAccessible(true);
145 storeIdDupeField.set(cacheFilter, storeIdDupe);
146 }
147
148 }
149
150 private MemcachedClient createMockMemcachedClient() {
151
152 MemcachedClient localMockMemcachedClient = mock(MemcachedClient.class);
153
154 when(localMockMemcachedClient.asyncGet(ArgumentMatchers.anyString())).thenAnswer((Answer<TestGetFuture<Object>>) invocation -> {
155 Object[] args = invocation.getArguments();
156 return new TestGetFuture<>(new CountDownLatch(1), 500, (String) args[0]);
157 });
158
159 when(localMockMemcachedClient.set(ArgumentMatchers.anyString(), ArgumentMatchers.anyInt(), ArgumentMatchers.any()))
160 .thenAnswer((Answer<Future<Boolean>>) invocation -> {
161 Object[] args = invocation.getArguments();
162 String key = (String) args[0];
163
164 if (!key.equals(expectedKey)) {
165 throw new Exception("Key :" + key + " not equal to expected key: " + expectedKey);
166 }
167
168 if (!isBinaryConnection) {
169 checkForValidAscii(key);
170 }
171
172 return new OperationFuture<>(key, new CountDownLatch(1), 500, Executors.newFixedThreadPool(1));
173 });
174
175 return localMockMemcachedClient;
176 }
177
178 private class TestGetFuture<T> extends GetFuture<T> {
179
180 public TestGetFuture(CountDownLatch l, long opTimeout, String key) {
181 super(l, opTimeout, key, Executors.newFixedThreadPool(1));
182 }
183
184 @Override
185 public boolean cancel(boolean ign) {
186 return true;
187 }
188
189 @Override
190 public T get() throws InterruptedException, ExecutionException {
191 return null;
192 }
193
194 @Override
195 @SuppressWarnings("unchecked")
196 public T get(long duration, TimeUnit units) throws InterruptedException, TimeoutException, ExecutionException {
197 return (T) cacheResult;
198 }
199
200 @Override
201 public OperationStatus getStatus() {
202 return new OperationStatus(true, "Done");
203 }
204
205 @Override
206 public void set(Future<T> d, OperationStatus s) {
207 throw new NotImplementedException("don't call set");
208 }
209
210 @Override
211 public void setOperation(Operation to) {
212 throw new NotImplementedException("don't call setOperation");
213 }
214
215 @Override
216 public boolean isCancelled() {
217 return true;
218 }
219
220 @Override
221 public boolean isDone() {
222 return true;
223 }
224
225 }
226
227 }