View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *   http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing,
13   * software distributed under the License is distributed on an
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   * KIND, either express or implied.  See the License for the
16   * specific language governing permissions and limitations
17   * under the License.
18   */
19  package org.apache.syncope.core.persistence.jpa.dao;
20  
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.syncope.common.lib.SyncopeConstants;
import org.apache.syncope.core.persistence.api.dao.MalformedPathException;
import org.apache.syncope.core.persistence.api.dao.RoleDAO;
import org.apache.syncope.core.persistence.api.entity.Realm;
import org.apache.syncope.core.spring.security.AuthContextUtils;
import org.apache.syncope.ext.opensearch.client.OpenSearchUtils;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch._types.ScriptSortType;
import org.opensearch.client.opensearch._types.SearchType;
import org.opensearch.client.opensearch._types.SortOptions;
import org.opensearch.client.opensearch._types.SortOrder;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.client.opensearch._types.query_dsl.QueryBuilders;
import org.opensearch.client.opensearch.core.CountRequest;
import org.opensearch.client.opensearch.core.SearchRequest;
import org.opensearch.client.opensearch.core.search.Hit;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.transaction.annotation.Transactional;
43  
44  public class OpenSearchRealmDAO extends JPARealmDAO {
45  
46      protected static final List<SortOptions> ES_SORT_OPTIONS_REALM = List.of(
47              new SortOptions.Builder().
48                      script(s -> s.type(ScriptSortType.Number).
49                      script(t -> t.inline(i -> i.lang("painless").
50                      source("doc['fullPath'].value.chars().filter(ch -> ch == '/').count()"))).
51                      order(SortOrder.Asc)).
52                      build());
53  
54      protected final OpenSearchClient client;
55  
56      protected final int indexMaxResultWindow;
57  
58      public OpenSearchRealmDAO(
59              final RoleDAO roleDAO,
60              final ApplicationEventPublisher publisher,
61              final OpenSearchClient client,
62              final int indexMaxResultWindow) {
63  
64          super(roleDAO, publisher);
65          this.client = client;
66          this.indexMaxResultWindow = indexMaxResultWindow;
67      }
68  
69      @Transactional(readOnly = true)
70      @Override
71      public Realm findByFullPath(final String fullPath) {
72          if (SyncopeConstants.ROOT_REALM.equals(fullPath)) {
73              return getRoot();
74          }
75  
76          if (StringUtils.isBlank(fullPath) || !PATH_PATTERN.matcher(fullPath).matches()) {
77              throw new MalformedPathException(fullPath);
78          }
79  
80          SearchRequest request = new SearchRequest.Builder().
81                  index(OpenSearchUtils.getRealmIndex(AuthContextUtils.getDomain())).
82                  searchType(SearchType.QueryThenFetch).
83                  query(new Query.Builder().term(QueryBuilders.term().
84                          field("fullPath").value(FieldValue.of(fullPath)).build()).build()).
85                  size(1).
86                  build();
87  
88          try {
89              String result = client.search(request, Void.class).hits().hits().stream().findFirst().
90                      map(Hit::id).
91                      orElse(null);
92              return find(result);
93          } catch (Exception e) {
94              LOG.error("While searching ES for one match", e);
95          }
96  
97          return null;
98      }
99  
100     protected List<String> search(final Query query) {
101         SearchRequest request = new SearchRequest.Builder().
102                 index(OpenSearchUtils.getRealmIndex(AuthContextUtils.getDomain())).
103                 searchType(SearchType.QueryThenFetch).
104                 query(query).
105                 sort(ES_SORT_OPTIONS_REALM).
106                 build();
107 
108         try {
109             return client.search(request, Void.class).hits().hits().stream().
110                     map(Hit::id).
111                     collect(Collectors.toList());
112         } catch (Exception e) {
113             LOG.error("While searching in OpenSearch", e);
114             return List.of();
115         }
116     }
117 
118     @Override
119     public List<Realm> findByName(final String name) {
120         List<String> result = search(
121                 new Query.Builder().term(QueryBuilders.term().
122                         field("name").value(FieldValue.of(name)).build()).build());
123         return result.stream().map(this::find).collect(Collectors.toList());
124     }
125 
126     @Override
127     public List<Realm> findChildren(final Realm realm) {
128         List<String> result = search(
129                 new Query.Builder().term(QueryBuilders.term().
130                         field("parent_id").value(FieldValue.of(realm.getKey())).build()).build());
131         return result.stream().map(this::find).collect(Collectors.toList());
132     }
133 
134     protected Query buildDescendantQuery(final String base, final String keyword) {
135         Query prefix = new Query.Builder().disMax(QueryBuilders.disMax().queries(
136                 new Query.Builder().term(QueryBuilders.term().
137                         field("fullPath").value(FieldValue.of(base)).build()).build(),
138                 new Query.Builder().regexp(QueryBuilders.regexp().
139                         field("fullPath").value(SyncopeConstants.ROOT_REALM.equals(base) ? "/.*" : base + "/.*").
140                         build()).build()).build()).build();
141 
142         if (keyword == null) {
143             return prefix;
144         }
145 
146         StringBuilder output = new StringBuilder();
147         for (char c : keyword.toLowerCase().toCharArray()) {
148             if (c == '%') {
149                 output.append(".*");
150             } else if (Character.isLetter(c)) {
151                 output.append('[').
152                         append(c).
153                         append(Character.toUpperCase(c)).
154                         append(']');
155             } else {
156                 output.append(OpenSearchUtils.escapeForLikeRegex(c));
157             }
158         }
159 
160         return new Query.Builder().bool(QueryBuilders.bool().must(
161                 prefix,
162                 new Query.Builder().regexp(QueryBuilders.regexp().
163                         field("name").value(output.toString()).build()).
164                         build()).build()).
165                 build();
166     }
167 
168     @Override
169     public int countDescendants(final String base, final String keyword) {
170         CountRequest request = new CountRequest.Builder().
171                 index(OpenSearchUtils.getRealmIndex(AuthContextUtils.getDomain())).
172                 query(buildDescendantQuery(base, keyword)).
173                 build();
174 
175         try {
176             return (int) client.count(request).count();
177         } catch (Exception e) {
178             LOG.error("While counting in OpenSearch", e);
179             return 0;
180         }
181     }
182 
183     @Override
184     public List<Realm> findDescendants(
185             final String base,
186             final String keyword,
187             final int page,
188             final int itemsPerPage) {
189 
190         SearchRequest request = new SearchRequest.Builder().
191                 index(OpenSearchUtils.getRealmIndex(AuthContextUtils.getDomain())).
192                 searchType(SearchType.QueryThenFetch).
193                 query(buildDescendantQuery(base, keyword)).
194                 from(itemsPerPage * (page <= 0 ? 0 : page - 1)).
195                 size(itemsPerPage < 0 ? indexMaxResultWindow : itemsPerPage).
196                 sort(ES_SORT_OPTIONS_REALM).
197                 build();
198 
199         List<String> result = List.of();
200         try {
201             result = client.search(request, Void.class).hits().hits().stream().
202                     map(Hit::id).
203                     collect(Collectors.toList());
204         } catch (Exception e) {
205             LOG.error("While searching in OpenSearch", e);
206         }
207 
208         return result.stream().map(this::find).collect(Collectors.toList());
209     }
210 
211     @Override
212     public List<String> findDescendants(final String base, final String prefix) {
213         Query prefixQuery = new Query.Builder().disMax(QueryBuilders.disMax().queries(
214                 new Query.Builder().term(QueryBuilders.term().
215                         field("fullPath").value(FieldValue.of(base)).build()).build(),
216                 new Query.Builder().prefix(QueryBuilders.prefix().
217                         field("fullPath").value(SyncopeConstants.ROOT_REALM.equals(prefix) ? "/" : prefix + "/").
218                         build()).build()).build()).build();
219 
220         Query query = new Query.Builder().bool(QueryBuilders.bool().must(
221                 buildDescendantQuery(base, (String) null),
222                 prefixQuery).build()).
223                 build();
224 
225         SearchRequest request = new SearchRequest.Builder().
226                 index(OpenSearchUtils.getRealmIndex(AuthContextUtils.getDomain())).
227                 searchType(SearchType.QueryThenFetch).
228                 query(query).
229                 from(0).
230                 size(indexMaxResultWindow).
231                 sort(ES_SORT_OPTIONS_REALM).
232                 build();
233 
234         List<String> result = List.of();
235         try {
236             result = client.search(request, Void.class).hits().hits().stream().
237                     map(Hit::id).
238                     collect(Collectors.toList());
239         } catch (Exception e) {
240             LOG.error("While searching in OpenSearch", e);
241         }
242         return result;
243     }
244 }