View Javadoc
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19  package org.apache.syncope.core.persistence.jpa.dao;
20  
21  import static org.assertj.core.api.Assertions.assertThat;
22  import static org.junit.jupiter.api.Assertions.assertEquals;
23  import static org.junit.jupiter.api.Assertions.assertFalse;
24  import static org.mockito.ArgumentMatchers.any;
25  import static org.mockito.ArgumentMatchers.anyString;
26  import static org.mockito.ArgumentMatchers.eq;
27  import static org.mockito.Mockito.doAnswer;
28  import static org.mockito.Mockito.mock;
29  import static org.mockito.Mockito.when;
30  
31  import java.io.IOException;
32  import java.util.List;
33  import java.util.Optional;
34  import java.util.Set;
35  import org.apache.commons.lang3.tuple.Triple;
36  import org.apache.syncope.common.lib.SyncopeConstants;
37  import org.apache.syncope.common.lib.types.AnyTypeKind;
38  import org.apache.syncope.core.persistence.api.attrvalue.validation.PlainAttrValidationManager;
39  import org.apache.syncope.core.persistence.api.dao.DynRealmDAO;
40  import org.apache.syncope.core.persistence.api.dao.GroupDAO;
41  import org.apache.syncope.core.persistence.api.dao.RealmDAO;
42  import org.apache.syncope.core.persistence.api.dao.search.AnyCond;
43  import org.apache.syncope.core.persistence.api.dao.search.AttrCond;
44  import org.apache.syncope.core.persistence.api.dao.search.SearchCond;
45  import org.apache.syncope.core.persistence.api.entity.AnyUtils;
46  import org.apache.syncope.core.persistence.api.entity.AnyUtilsFactory;
47  import org.apache.syncope.core.persistence.api.entity.DynRealm;
48  import org.apache.syncope.core.persistence.api.entity.EntityFactory;
49  import org.apache.syncope.core.persistence.api.entity.PlainAttrValue;
50  import org.apache.syncope.core.persistence.api.entity.PlainSchema;
51  import org.apache.syncope.core.persistence.api.entity.Realm;
52  import org.apache.syncope.core.persistence.jpa.entity.JPAPlainSchema;
53  import org.apache.syncope.core.persistence.jpa.entity.user.JPAUPlainAttrValue;
54  import org.apache.syncope.core.persistence.jpa.entity.user.JPAUser;
55  import org.apache.syncope.core.provisioning.api.utils.RealmUtils;
56  import org.apache.syncope.core.spring.security.AuthContextUtils;
57  import org.apache.syncope.ext.opensearch.client.OpenSearchUtils;
58  import org.junit.jupiter.api.BeforeEach;
59  import org.junit.jupiter.api.Test;
60  import org.junit.jupiter.api.extension.ExtendWith;
61  import org.mockito.Mock;
62  import org.mockito.MockedStatic;
63  import org.mockito.Mockito;
64  import org.mockito.junit.jupiter.MockitoExtension;
65  import org.opensearch.client.opensearch._types.FieldValue;
66  import org.opensearch.client.opensearch._types.SearchType;
67  import org.opensearch.client.opensearch._types.query_dsl.BoolQuery;
68  import org.opensearch.client.opensearch._types.query_dsl.DisMaxQuery;
69  import org.opensearch.client.opensearch._types.query_dsl.Query;
70  import org.opensearch.client.opensearch._types.query_dsl.QueryBuilders;
71  import org.opensearch.client.opensearch.core.SearchRequest;
72  import org.springframework.util.ReflectionUtils;
73  
74  @ExtendWith(MockitoExtension.class)
75  public class OpenSearchAnySearchDAOTest {
76  
77      @Mock
78      private RealmDAO realmDAO;
79  
80      @Mock
81      private DynRealmDAO dynRealmDAO;
82  
83      @Mock
84      private GroupDAO groupDAO;
85  
86      @Mock
87      private EntityFactory entityFactory;
88  
89      @Mock
90      private AnyUtilsFactory anyUtilsFactory;
91  
92      @Mock
93      private PlainAttrValidationManager validator;
94  
95      private OpenSearchAnySearchDAO searchDAO;
96  
97      @BeforeEach
98      protected void setupSearchDAO() {
99          searchDAO = new OpenSearchAnySearchDAO(
100                 realmDAO,
101                 dynRealmDAO,
102                 null,
103                 groupDAO,
104                 null,
105                 null,
106                 entityFactory,
107                 anyUtilsFactory,
108                 validator,
109                 null,
110                 10000);
111     }
112 
113     @Test
114     public void getAdminRealmsFilter4realm() throws IOException {
115         // 1. mock
116         Realm root = mock(Realm.class);
117         when(root.getFullPath()).thenReturn(SyncopeConstants.ROOT_REALM);
118 
119         when(realmDAO.findByFullPath(SyncopeConstants.ROOT_REALM)).thenReturn(root);
120         when(realmDAO.findDescendants(eq(SyncopeConstants.ROOT_REALM), anyString())).thenReturn(List.of("rootKey"));
121 
122         // 2. test
123         Set<String> adminRealms = Set.of(SyncopeConstants.ROOT_REALM);
124         Triple<Optional<Query>, Set<String>, Set<String>> filter =
125                 searchDAO.getAdminRealmsFilter(root, true, adminRealms, AnyTypeKind.USER);
126 
127         assertThat(new Query.Builder().disMax(QueryBuilders.disMax().queries(
128                 new Query.Builder().term(QueryBuilders.term().field("realm").value(FieldValue.of("rootKey")).build()).
129                         build()).build()).build()).
130                 usingRecursiveComparison().isEqualTo(filter.getLeft().get());
131         assertEquals(Set.of(), filter.getMiddle());
132         assertEquals(Set.of(), filter.getRight());
133     }
134 
135     @Test
136     public void getAdminRealmsFilter4dynRealm() {
137         // 1. mock
138         DynRealm dyn = mock(DynRealm.class);
139         when(dyn.getKey()).thenReturn("dyn");
140 
141         when(dynRealmDAO.find("dyn")).thenReturn(dyn);
142 
143         // 2. test
144         Set<String> adminRealms = Set.of("dyn");
145         Triple<Optional<Query>, Set<String>, Set<String>> filter =
146                 searchDAO.getAdminRealmsFilter(realmDAO.getRoot(), true, adminRealms, AnyTypeKind.USER);
147         assertFalse(filter.getLeft().isPresent());
148         assertEquals(Set.of("dyn"), filter.getMiddle());
149         assertEquals(Set.of(), filter.getRight());
150     }
151 
152     @Test
153     public void getAdminRealmsFilter4groupOwner() {
154         Set<String> adminRealms = Set.of(RealmUtils.getGroupOwnerRealm("/any", "groupKey"));
155         Triple<Optional<Query>, Set<String>, Set<String>> filter =
156                 searchDAO.getAdminRealmsFilter(realmDAO.getRoot(), true, adminRealms, AnyTypeKind.USER);
157         assertFalse(filter.getLeft().isPresent());
158         assertEquals(Set.of(), filter.getMiddle());
159         assertEquals(Set.of("groupKey"), filter.getRight());
160     }
161 
162     @Test
163     public void searchRequest4groupOwner() throws IOException {
164         // 1. mock
165         AnyUtils anyUtils = mock(AnyUtils.class);
166         when(anyUtils.getField("key")).thenReturn(ReflectionUtils.findField(JPAUser.class, "id"));
167         when(anyUtils.newPlainAttrValue()).thenReturn(new JPAUPlainAttrValue());
168 
169         when(anyUtilsFactory.getInstance(AnyTypeKind.USER)).thenReturn(anyUtils);
170 
171         when(entityFactory.newEntity(PlainSchema.class)).thenReturn(new JPAPlainSchema());
172 
173         when(groupDAO.findKey("groupKey")).thenReturn("groupKey");
174 
175         try (MockedStatic<OpenSearchUtils> utils = Mockito.mockStatic(OpenSearchUtils.class)) {
176             utils.when(() -> OpenSearchUtils.getAnyIndex(
177                     SyncopeConstants.MASTER_DOMAIN, AnyTypeKind.USER)).thenReturn("master_user");
178 
179             // 2. test
180             Set<String> adminRealms = Set.of(RealmUtils.getGroupOwnerRealm("/any", "groupKey"));
181 
182             AnyCond anyCond = new AnyCond(AttrCond.Type.ISNOTNULL);
183             anyCond.setSchema("key");
184 
185             SearchRequest request = new SearchRequest.Builder().
186                     index(OpenSearchUtils.getAnyIndex(AuthContextUtils.getDomain(), AnyTypeKind.USER)).
187                     searchType(SearchType.QueryThenFetch).
188                     query(searchDAO.getQuery(realmDAO.findByFullPath("/any"), true,
189                             adminRealms, SearchCond.getLeaf(anyCond), AnyTypeKind.USER)).
190                     from(1).
191                     size(10).
192                     build();
193 
194             assertThat(
195                     new Query.Builder().bool(QueryBuilders.bool().
196                             must(new Query.Builder().exists(QueryBuilders.exists().field("id").build()).build()).
197                             must(new Query.Builder().term(QueryBuilders.term().field("memberships").
198                                     value(FieldValue.of("groupKey")).build()).build()).build()).build()).
199                     usingRecursiveComparison().
200                     isEqualTo(request.query());
201         }
202     }
203 
204     @Test
205     public void issueSYNCOPE1725() throws IOException {
206         // 1. mock
207         AnyUtils anyUtils = mock(AnyUtils.class);
208         when(anyUtils.getField("key")).thenReturn(ReflectionUtils.findField(JPAUser.class, "id"));
209         JPAUPlainAttrValue value = new JPAUPlainAttrValue();
210         when(anyUtils.newPlainAttrValue()).thenReturn(value);
211 
212         when(anyUtilsFactory.getInstance(AnyTypeKind.USER)).thenReturn(anyUtils);
213 
214         when(entityFactory.newEntity(PlainSchema.class)).thenReturn(new JPAPlainSchema());
215 
216         doAnswer(ic -> {
217             value.setStringValue(ic.getArgument(1));
218             return null;
219         }).when(validator).validate(any(PlainSchema.class), anyString(), any(PlainAttrValue.class));
220 
221         AnyCond cond1 = new AnyCond(AttrCond.Type.EQ);
222         cond1.setSchema("key");
223         cond1.setExpression("1");
224 
225         AnyCond cond2 = new AnyCond(AttrCond.Type.EQ);
226         cond2.setSchema("key");
227         cond2.setExpression("2");
228 
229         AnyCond cond3 = new AnyCond(AttrCond.Type.EQ);
230         cond3.setSchema("key");
231         cond3.setExpression("3");
232 
233         AnyCond cond4 = new AnyCond(AttrCond.Type.EQ);
234         cond4.setSchema("key");
235         cond4.setExpression("4");
236 
237         AnyCond cond5 = new AnyCond(AttrCond.Type.EQ);
238         cond5.setSchema("key");
239         cond5.setExpression("5");
240 
241         AnyCond cond6 = new AnyCond(AttrCond.Type.EQ);
242         cond6.setSchema("key");
243         cond6.setExpression("6");
244 
245         try (MockedStatic<OpenSearchUtils> utils = Mockito.mockStatic(OpenSearchUtils.class)) {
246             utils.when(() -> OpenSearchUtils.getAnyIndex(
247                     SyncopeConstants.MASTER_DOMAIN, AnyTypeKind.USER)).thenReturn("master_user");
248 
249             Query query = searchDAO.getQuery(
250                     SearchCond.getAnd(
251                             List.of(SearchCond.getLeaf(cond1),
252                                     SearchCond.getLeaf(cond2),
253                                     SearchCond.getLeaf(cond3),
254                                     SearchCond.getLeaf(cond4),
255                                     SearchCond.getLeaf(cond5),
256                                     SearchCond.getLeaf(cond6))),
257                     AnyTypeKind.USER);
258             assertEquals(Query.Kind.Bool, query._kind());
259             assertEquals(6, ((BoolQuery) query._get()).must().size());
260             assertThat(
261                     new Query.Builder().bool(QueryBuilders.bool().
262                             must(new Query.Builder().term(
263                                     QueryBuilders.term().field("id").value(FieldValue.of("1")).build()).build()).
264                             must(new Query.Builder().term(
265                                     QueryBuilders.term().field("id").value(FieldValue.of("2")).build()).build()).
266                             must(new Query.Builder().term(
267                                     QueryBuilders.term().field("id").value(FieldValue.of("3")).build()).build()).
268                             must(new Query.Builder().term(
269                                     QueryBuilders.term().field("id").value(FieldValue.of("4")).build()).build()).
270                             must(new Query.Builder().term(
271                                     QueryBuilders.term().field("id").value(FieldValue.of("5")).build()).build()).
272                             must(new Query.Builder().term(
273                                     QueryBuilders.term().field("id").value(FieldValue.of("6")).build()).build()).
274                             build()).build()).
275                     usingRecursiveComparison().isEqualTo(query);
276 
277             query = searchDAO.getQuery(
278                     SearchCond.getOr(
279                             List.of(SearchCond.getLeaf(cond1),
280                                     SearchCond.getLeaf(cond2),
281                                     SearchCond.getLeaf(cond3),
282                                     SearchCond.getLeaf(cond4),
283                                     SearchCond.getLeaf(cond5),
284                                     SearchCond.getLeaf(cond6))),
285                     AnyTypeKind.USER);
286             assertEquals(Query.Kind.DisMax, query._kind());
287             assertEquals(6, ((DisMaxQuery) query._get()).queries().size());
288             assertThat(
289                     new Query.Builder().disMax(QueryBuilders.disMax().
290                             queries(new Query.Builder().term(
291                                     QueryBuilders.term().field("id").value(FieldValue.of("1")).build()).build()).
292                             queries(new Query.Builder().term(
293                                     QueryBuilders.term().field("id").value(FieldValue.of("2")).build()).build()).
294                             queries(new Query.Builder().term(
295                                     QueryBuilders.term().field("id").value(FieldValue.of("3")).build()).build()).
296                             queries(new Query.Builder().term(
297                                     QueryBuilders.term().field("id").value(FieldValue.of("4")).build()).build()).
298                             queries(new Query.Builder().term(
299                                     QueryBuilders.term().field("id").value(FieldValue.of("5")).build()).build()).
300                             queries(new Query.Builder().term(
301                                     QueryBuilders.term().field("id").value(FieldValue.of("6")).build()).build()).
302                             build()).build()).
303                     usingRecursiveComparison().isEqualTo(query);
304 
305             query = searchDAO.getQuery(
306                     SearchCond.getAnd(List.of(
307                             SearchCond.getOr(List.of(
308                                     SearchCond.getLeaf(cond1),
309                                     SearchCond.getLeaf(cond2),
310                                     SearchCond.getLeaf(cond3))),
311                             SearchCond.getOr(List.of(
312                                     SearchCond.getLeaf(cond4),
313                                     SearchCond.getLeaf(cond5),
314                                     SearchCond.getLeaf(cond6))))),
315                     AnyTypeKind.USER);
316             assertEquals(Query.Kind.Bool, query._kind());
317             assertEquals(2, ((BoolQuery) query._get()).must().size());
318             Query left = ((BoolQuery) query._get()).must().get(0);
319             assertEquals(Query.Kind.DisMax, left._kind());
320             assertEquals(3, ((DisMaxQuery) left._get()).queries().size());
321             Query right = ((BoolQuery) query._get()).must().get(1);
322             assertEquals(Query.Kind.DisMax, right._kind());
323             assertEquals(3, ((DisMaxQuery) right._get()).queries().size());
324             assertThat(
325                     new Query.Builder().bool(QueryBuilders.bool().
326                             must(new Query.Builder().disMax(QueryBuilders.disMax().
327                                     queries(new Query.Builder().term(
328                                             QueryBuilders.term().field("id").
329                                                     value(FieldValue.of("1")).build()).build()).
330                                     queries(new Query.Builder().term(
331                                             QueryBuilders.term().field("id").
332                                                     value(FieldValue.of("2")).build()).build()).
333                                     queries(new Query.Builder().term(
334                                             QueryBuilders.term().field("id").
335                                                     value(FieldValue.of("3")).build()).build()).build()).
336                                     build()).
337                             must(new Query.Builder().disMax(QueryBuilders.disMax().
338                                     queries(new Query.Builder().term(
339                                             QueryBuilders.term().field("id").
340                                                     value(FieldValue.of("4")).build()).build()).
341                                     queries(new Query.Builder().term(
342                                             QueryBuilders.term().field("id").
343                                                     value(FieldValue.of("5")).build()).build()).
344                                     queries(new Query.Builder().term(
345                                             QueryBuilders.term().field("id").
346                                                     value(FieldValue.of("6")).build()).build()).build()).
347                                     build()).
348                             build()).build()).
349                     usingRecursiveComparison().isEqualTo(query);
350 
351             query = searchDAO.getQuery(
352                     SearchCond.getOr(List.of(
353                             SearchCond.getAnd(List.of(
354                                     SearchCond.getLeaf(cond1),
355                                     SearchCond.getLeaf(cond2),
356                                     SearchCond.getLeaf(cond3))),
357                             SearchCond.getAnd(List.of(
358                                     SearchCond.getLeaf(cond4),
359                                     SearchCond.getLeaf(cond5),
360                                     SearchCond.getLeaf(cond6))))),
361                     AnyTypeKind.USER);
362             assertEquals(Query.Kind.DisMax, query._kind());
363             assertEquals(2, ((DisMaxQuery) query._get()).queries().size());
364             left = ((DisMaxQuery) query._get()).queries().get(0);
365             assertEquals(Query.Kind.Bool, left._kind());
366             assertEquals(3, ((BoolQuery) left._get()).must().size());
367             right = ((DisMaxQuery) query._get()).queries().get(1);
368             assertEquals(Query.Kind.Bool, right._kind());
369             assertEquals(3, ((BoolQuery) right._get()).must().size());
370             assertThat(
371                     new Query.Builder().disMax(QueryBuilders.disMax().
372                             queries(new Query.Builder().bool(QueryBuilders.bool().
373                                     must(new Query.Builder().term(
374                                             QueryBuilders.term().field("id").
375                                                     value(FieldValue.of("1")).build()).build()).
376                                     must(new Query.Builder().term(
377                                             QueryBuilders.term().field("id").
378                                                     value(FieldValue.of("2")).build()).build()).
379                                     must(new Query.Builder().term(
380                                             QueryBuilders.term().field("id").
381                                                     value(FieldValue.of("3")).build()).build()).build()).
382                                     build()).
383                             queries(new Query.Builder().bool(QueryBuilders.bool().
384                                     must(new Query.Builder().term(
385                                             QueryBuilders.term().field("id").
386                                                     value(FieldValue.of("4")).build()).build()).
387                                     must(new Query.Builder().term(
388                                             QueryBuilders.term().field("id").
389                                                     value(FieldValue.of("5")).build()).build()).
390                                     must(new Query.Builder().term(
391                                             QueryBuilders.term().field("id").
392                                                     value(FieldValue.of("6")).build()).build()).build()).
393                                     build()).
394                             build()).build()).
395                     usingRecursiveComparison().isEqualTo(query);
396         }
397     }
398 }