View Javadoc
1   package emissary.kff;
2   
import emissary.kff.KffFilter.FilterType;
import emissary.test.core.junit5.UnitTest;

import jakarta.annotation.Nullable;
import net.spy.memcached.MemcachedClient;
import net.spy.memcached.internal.GetFuture;
import net.spy.memcached.internal.OperationFuture;
import net.spy.memcached.ops.Operation;
import net.spy.memcached.ops.OperationStatus;
import org.apache.commons.lang3.NotImplementedException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.stubbing.Answer;

import java.io.IOException;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.validateMockitoUsage;
import static org.mockito.Mockito.when;
38  
39  class KffMemcachedTest extends UnitTest {
40  
41      private static final String TEST_ID_WITH_SPACES = "TEST ID";
42      private static final String TEST_PAYLOAD = "TEST DATA";
43      private static final String TEST_UNFORMATTED_ID_HASH = "01e44cd59b2c0e8acbb99647d579f74f91bde66e4a243dc212a3c8e8739c9957";
44      private String expectedKey = "";
45      @Nullable
46      private MemcachedClient mockMemcachedClient = null;
47      private boolean isBinaryConnection = false;
48      @Nullable
49      private String cacheResult = null;
50  
51      @BeforeEach
52      public void setup() {
53          mockMemcachedClient = createMockMemcachedClient();
54  
55      }
56  
57      @AfterEach
58      @Override
59      public void tearDown() throws Exception {
60          super.tearDown();
61          isBinaryConnection = false;
62          validateMockitoUsage();
63      }
64  
65      @Test
66      void testKffMemcachedCreation() throws Exception {
67          KffMemcached mcdFilter = createTestFilter(Boolean.TRUE, Boolean.TRUE, TEST_ID_WITH_SPACES);
68          mcdFilter.setPreferredAlgorithm("SHA-256");
69          assertEquals("SHA-256", mcdFilter.getPreferredAlgorithm());
70          assertEquals("KFF", mcdFilter.getName());
71          assertEquals(FilterType.DUPLICATE, mcdFilter.getFilterType());
72      }
73  
74      @Test
75      void testThrowsWithNonAsciiAndDups() throws Exception {
76          KffMemcached mcdFilter = createTestFilter(Boolean.TRUE, Boolean.TRUE, TEST_ID_WITH_SPACES);
77          ChecksumResults results = createSums(mcdFilter);
78          assertThrows(IllegalArgumentException.class, () -> {
79              mcdFilter.check(TEST_ID_WITH_SPACES, results);
80          });
81      }
82  
83  
84      @Test
85      void testNoHitNoStoreIdDupe() throws Exception {
86          KffMemcached mcdFilter = createTestFilter(Boolean.FALSE, Boolean.FALSE, TEST_UNFORMATTED_ID_HASH);
87          assertFalse(mcdFilter.check(TEST_ID_WITH_SPACES, createSums(mcdFilter)), "Filter should not hit");
88      }
89  
90      @Test
91      void testHitNoStoreIdDupe() throws Exception {
92          KffMemcached mcdFilter = createTestFilter(Boolean.FALSE, Boolean.TRUE, null);
93          assertTrue(mcdFilter.check(TEST_ID_WITH_SPACES, createSums(mcdFilter)), "Filter should hit");
94      }
95  
96      @Test
97      void testNoHitWithStoreIdDupe() throws Exception {
98          KffMemcached mcdFilter = createTestFilter(Boolean.TRUE, Boolean.FALSE, TEST_UNFORMATTED_ID_HASH);
99          assertFalse(mcdFilter.check(TEST_ID_WITH_SPACES, createSums(mcdFilter)), "Filter should not hit");
100     }
101 
102     @Test
103     void testHitWithStoreIdDupe() throws Exception {
104         isBinaryConnection = true;
105         KffMemcached mcdFilter = createTestFilter(Boolean.TRUE, Boolean.TRUE, TEST_ID_WITH_SPACES);
106         assertTrue(mcdFilter.check(TEST_ID_WITH_SPACES, createSums(mcdFilter)), "Filter should hit");
107     }
108 
109     private static ChecksumResults createSums(KffMemcached mcd) throws NoSuchAlgorithmException {
110         List<String> kffalgs = new ArrayList<>();
111         kffalgs.add(mcd.getPreferredAlgorithm());
112         return new ChecksumCalculator(kffalgs).digest(TEST_PAYLOAD.getBytes());
113     }
114 
115 
116     private KffMemcached createTestFilter(Boolean storeIdDupe, boolean simulateHit, @Nullable String expectedKey)
117             throws IOException, NoSuchFieldException,
118             IllegalAccessException {
119         KffMemcached filter = new KffMemcached(TEST_ID_WITH_SPACES, "KFF", FilterType.DUPLICATE, mockMemcachedClient);
120         setPrivateMembersForTesting(filter, storeIdDupe);
121         if (simulateHit) {
122             cacheResult = "FAKE FIND";
123         } else {
124             cacheResult = null;
125         }
126         this.expectedKey = expectedKey;
127         return filter;
128     }
129 
130     private static void checkForValidAscii(String key) {
131         // Arbitrary string up to 250 bytes in length. No space or newlines for
132         // ASCII mode
133         if (key.length() > 250 || key.contains(" ") || key.contains("\n")) {
134             throw new IllegalArgumentException("Invalid Key for ASCII Memcached");
135         }
136     }
137 
138     private static void setPrivateMembersForTesting(KffMemcached cacheFilter, @Nullable Boolean storeIdDupe)
139             throws NoSuchFieldException, IllegalAccessException {
140 
141         // Overriding the protected attribute of the field for testing
142         if (storeIdDupe != null) {
143             Field storeIdDupeField = KffMemcached.class.getDeclaredField("storeIdDupe");
144             storeIdDupeField.setAccessible(true);
145             storeIdDupeField.set(cacheFilter, storeIdDupe);
146         }
147 
148     }
149 
150     private MemcachedClient createMockMemcachedClient() {
151 
152         MemcachedClient localMockMemcachedClient = mock(MemcachedClient.class);
153 
154         when(localMockMemcachedClient.asyncGet(ArgumentMatchers.anyString())).thenAnswer((Answer<TestGetFuture<Object>>) invocation -> {
155             Object[] args = invocation.getArguments();
156             return new TestGetFuture<>(new CountDownLatch(1), 500, (String) args[0]);
157         });
158 
159         when(localMockMemcachedClient.set(ArgumentMatchers.anyString(), ArgumentMatchers.anyInt(), ArgumentMatchers.any()))
160                 .thenAnswer((Answer<Future<Boolean>>) invocation -> {
161                     Object[] args = invocation.getArguments();
162                     String key = (String) args[0];
163 
164                     if (!key.equals(expectedKey)) {
165                         throw new Exception("Key :" + key + " not equal to expected key: " + expectedKey);
166                     }
167 
168                     if (!isBinaryConnection) {
169                         checkForValidAscii(key);
170                     }
171 
172                     return new OperationFuture<>(key, new CountDownLatch(1), 500, Executors.newFixedThreadPool(1));
173                 });
174 
175         return localMockMemcachedClient;
176     }
177 
178     private class TestGetFuture<T> extends GetFuture<T> {
179 
180         public TestGetFuture(CountDownLatch l, long opTimeout, String key) {
181             super(l, opTimeout, key, Executors.newFixedThreadPool(1));
182         }
183 
184         @Override
185         public boolean cancel(boolean ign) {
186             return true;
187         }
188 
189         @Override
190         public T get() throws InterruptedException, ExecutionException {
191             return null;
192         }
193 
194         @Override
195         @SuppressWarnings("unchecked")
196         public T get(long duration, TimeUnit units) throws InterruptedException, TimeoutException, ExecutionException {
197             return (T) cacheResult;
198         }
199 
200         @Override
201         public OperationStatus getStatus() {
202             return new OperationStatus(true, "Done");
203         }
204 
205         @Override
206         public void set(Future<T> d, OperationStatus s) {
207             throw new NotImplementedException("don't call set");
208         }
209 
210         @Override
211         public void setOperation(Operation to) {
212             throw new NotImplementedException("don't call setOperation");
213         }
214 
215         @Override
216         public boolean isCancelled() {
217             return true;
218         }
219 
220         @Override
221         public boolean isDone() {
222             return true;
223         }
224 
225     }
226 
227 }