1 package emissary.test.core.junit5.extensions;
2
3 import org.junit.jupiter.api.TestTemplate;
4 import org.junit.jupiter.api.extension.ExtendWith;
5 import org.junit.jupiter.api.extension.ExtensionContext;
6 import org.junit.jupiter.api.extension.TestExecutionExceptionHandler;
7 import org.junit.jupiter.api.extension.TestTemplateInvocationContext;
8 import org.junit.jupiter.api.extension.TestTemplateInvocationContextProvider;
9 import org.junit.jupiter.api.parallel.Execution;
10 import org.junit.platform.commons.util.AnnotationUtils;
11 import org.junit.platform.commons.util.Preconditions;
12 import org.opentest4j.TestAbortedException;
13
14 import java.lang.annotation.Retention;
15 import java.lang.annotation.RetentionPolicy;
16 import java.lang.annotation.Target;
17 import java.util.Iterator;
18 import java.util.Locale;
19 import java.util.NoSuchElementException;
20 import java.util.Spliterator;
21 import java.util.Spliterators;
22 import java.util.stream.Stream;
23 import java.util.stream.StreamSupport;
24
25 import static java.lang.annotation.ElementType.ANNOTATION_TYPE;
26 import static java.lang.annotation.ElementType.METHOD;
27 import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD;
28
29
30
31
32
33
34 @Target({METHOD, ANNOTATION_TYPE})
35 @Retention(RetentionPolicy.RUNTIME)
36 @Execution(SAME_THREAD)
37 @ExtendWith(TestAttempts.AttemptTestExtension.class)
38 @TestTemplate
39 public @interface TestAttempts {
40
41
42 int value() default 3;
43
44
45
46
47 class AttemptTestExtension implements TestTemplateInvocationContextProvider, TestExecutionExceptionHandler {
48
49 protected static final ExtensionContext.Namespace EXTENSION_CONTEXT_NAMESPACE = ExtensionContext.Namespace.create(AttemptTestExtension.class);
50
51 @Override
52 public boolean supportsTestTemplate(ExtensionContext context) {
53
54 return AnnotationUtils.isAnnotated(context.getRequiredTestMethod(), TestAttempts.class);
55 }
56
57 @Override
58 public Stream<TestTemplateInvocationContext> provideTestTemplateInvocationContexts(ExtensionContext context) {
59 return StreamSupport.stream(splitTestTemplateInvocationContexts(context), false);
60 }
61
62 @Override
63 public void handleTestExecutionException(ExtensionContext context, Throwable throwable) {
64 handleTestAttemptFailure(context.getParent().orElseThrow(() -> new UnsupportedOperationException("No template context found")),
65 throwable);
66 }
67
68 protected static Spliterator<TestTemplateInvocationContext> splitTestTemplateInvocationContexts(ExtensionContext context) {
69 return Spliterators.spliteratorUnknownSize(getTestTemplateInvocationContextProvider(context), Spliterator.ORDERED);
70 }
71
72 protected static AcceptFirstPassingAttempt getTestTemplateInvocationContextProvider(ExtensionContext context) {
73 ExtensionContext.Store store = context.getStore(EXTENSION_CONTEXT_NAMESPACE);
74 String key = context.getRequiredTestMethod().toString();
75 return store.getOrComputeIfAbsent(key, k -> createTestTemplateInvocationContextProvider(context), AcceptFirstPassingAttempt.class);
76 }
77
78 protected static AcceptFirstPassingAttempt createTestTemplateInvocationContextProvider(ExtensionContext context) {
79 TestAttempts retryTest = AnnotationUtils.findAnnotation(context.getRequiredTestMethod(), TestAttempts.class)
80 .orElseThrow(() -> new UnsupportedOperationException("Missing @TestAttempts annotation."));
81 int maxAttempts = retryTest.value();
82 Preconditions.condition(maxAttempts > 0, "Total test attempts need to be greater than 0");
83 return new AcceptFirstPassingAttempt(maxAttempts);
84 }
85
86 protected void handleTestAttemptFailure(ExtensionContext context, Throwable throwable) {
87 AcceptFirstPassingAttempt testAttempt = getTestTemplateInvocationContextProvider(context);
88 testAttempt.failed();
89
90 if (testAttempt.hasNext()) {
91
92 throw new TestAbortedException(
93 String.format(Locale.getDefault(), "Test attempt %d of %d failed, retrying...", testAttempt.exceptions,
94 testAttempt.maxAttempts),
95 throwable);
96 } else {
97
98 throw new AssertionError(
99 String.format(Locale.getDefault(), "Test attempt %d of %d failed", testAttempt.exceptions, testAttempt.maxAttempts),
100 throwable);
101 }
102 }
103
104
105
106
107 private static class AcceptFirstPassingAttempt implements Iterator<TestTemplateInvocationContext> {
108
109 protected final int maxAttempts;
110 protected int attempts;
111 protected int exceptions;
112
113 private AcceptFirstPassingAttempt(int maxAttempts) {
114 this.maxAttempts = maxAttempts;
115 }
116
117
118
119
120
121
122
123 @Override
124 public boolean hasNext() {
125 return isFirstAttempt() || (hasNoPassingAttempts() && hasMoreAttempts());
126 }
127
128
129
130
131
132
133
134 @Override
135 public TestTemplateInvocationContext next() {
136 if (!hasNext()) {
137 throw new NoSuchElementException();
138 }
139
140 ++attempts;
141 return new TestAttemptsInvocationContext(maxAttempts);
142 }
143
144 void failed() {
145 ++exceptions;
146 }
147
148 boolean isFirstAttempt() {
149 return attempts == 0;
150 }
151
152 boolean hasNoPassingAttempts() {
153 return attempts == exceptions;
154 }
155
156 boolean hasMoreAttempts() {
157 return attempts != maxAttempts;
158 }
159 }
160
161
162
163
164 static class TestAttemptsInvocationContext implements TestTemplateInvocationContext {
165
166 final int maxAttempts;
167
168 public TestAttemptsInvocationContext(int maxAttempts) {
169 this.maxAttempts = maxAttempts;
170 }
171
172 @Override
173 public String getDisplayName(int invocationIndex) {
174 return "Attempt " + invocationIndex + " of " + maxAttempts;
175 }
176 }
177 }
178 }
179