diff --git a/apps/search_indexes/filters.py b/apps/search_indexes/filters.py index ce77bc06..3101fdc2 100644 --- a/apps/search_indexes/filters.py +++ b/apps/search_indexes/filters.py @@ -4,6 +4,7 @@ from django_elasticsearch_dsl_drf.filter_backends import SearchFilterBackend, \ FacetedSearchFilterBackend, GeoSpatialFilteringFilterBackend from search_indexes.utils import OBJECT_FIELD_PROPERTIES from six import iteritems +from functools import reduce class CustomGeoSpatialFilteringFilterBackend(GeoSpatialFilteringFilterBackend): @@ -56,10 +57,30 @@ class CustomFacetedSearchFilterBackend(FacetedSearchFilterBackend): :param view: :return: """ - def makefilter(cur_facet): - def myfilter(x): + def make_filter(cur_facet): + def _filter(x): return cur_facet['facet']._params['field'] != next(iter(x._params)) - return myfilter + return _filter + + def make_tags_filter(cur_facet, tags_to_remove_ids): + def _filter(x): + if hasattr(x, '_params') and (x._params.get('must') or x._params.get('should')): + ret = [] + for t in ['must', 'should']: + terms = x._params.get(t) + if terms: + for term in terms: + if cur_facet['facet']._params['field'] != next(iter(term._params)): + return True # different fields. preserve filter + else: + ret.append(next(iter(term._params.values())) not in tags_to_remove_ids) + return all(ret) + if cur_facet['facet']._params['field'] != next(iter(x._params)): + return True # different fields. preserve filter + else: + return next(iter(x._params.values())) not in tags_to_remove_ids + return _filter + __facets = self.construct_facets(request, view) setattr(view.paginator, 'facets_computed', {}) for __field, __facet in iteritems(__facets): @@ -71,29 +92,84 @@ class CustomFacetedSearchFilterBackend(FacetedSearchFilterBackend): 'global' ).bucket(__field, agg) else: - qs = queryset.__copy__() - qs.query = queryset.query._clone() - filterer = makefilter(__facet) - for param_type in ['must', 'must_not', 'should']: - if qs.query._proxied._params.get(param_type): - qs.query._proxied._params[param_type] = list( - filter( - filterer, qs.query._proxied._params[param_type] + if __field != 'tag' or not request.query_params.getlist('tags_id__in'): + qs = queryset.__copy__() + qs.query = queryset.query._clone() + filterer = make_filter(__facet) + for param_type in ['must', 'must_not', 'should']: + if qs.query._proxied._params.get(param_type): + qs.query._proxied._params[param_type] = list( + filter( + filterer, qs.query._proxied._params[param_type] + ) ) - ) - sh = qs.query._proxied._params.get('should') - if (not sh or not len(sh)) \ - and qs.query._proxied._params.get('minimum_should_match'): - qs.query._proxied._params.pop('minimum_should_match') - facet_name = '_filter_' + __field - qs.aggs.bucket( - facet_name, - 'filter', - filter=agg_filter - ).bucket(__field, agg) - view.paginator.facets_computed.update({facet_name: qs.execute().aggregations[facet_name]}) + sh = qs.query._proxied._params.get('should') + if (not sh or not len(sh)) \ + and qs.query._proxied._params.get('minimum_should_match'): + qs.query._proxied._params.pop('minimum_should_match') + facet_name = '_filter_' + __field + qs.aggs.bucket( + facet_name, + 'filter', + filter=agg_filter + ).bucket(__field, agg) + view.paginator.facets_computed.update({facet_name: qs.execute().aggregations[facet_name]}) + else: + tag_facets = [] + facet_name = '_filter_' + __field + for category_tags_ids in request.query_params.getlist('tags_id__in'): + tags_to_remove = category_tags_ids.split('__') + qs = queryset.__copy__() + qs.query = queryset.query._clone() + filterer = make_tags_filter(__facet, tags_to_remove) + for param_type in ['must', 'should']: + if qs.query._proxied._params.get(param_type): + if qs.query._proxied._params.get(param_type): + qs.query._proxied._params[param_type] = list( + filter( + filterer, qs.query._proxied._params[param_type] + ) + ) + sh = qs.query._proxied._params.get('should') + if (not sh or not len(sh)) \ + and qs.query._proxied._params.get('minimum_should_match'): + qs.query._proxied._params.pop('minimum_should_match') + qs.aggs.bucket( + facet_name, + 'filter', + filter=agg_filter + ).bucket(__field, agg) + tag_facets.append(qs.execute().aggregations[facet_name]) + view.paginator.facets_computed.update({facet_name: self.merge_buckets(tag_facets)}) return queryset + @staticmethod + def merge_buckets(buckets: list): + """Reduces all buckets preserving class""" + result_bucket = buckets[0] + for bucket in buckets[1:]: + for tag in bucket.tag.buckets._l_: + if tag not in result_bucket.tag.buckets._l_: + result_bucket.tag.buckets._l_.append(tag) + def reducer(prev, cur): + try: + index = list(map(lambda x: x['key'], prev)).index(cur['key']) + if cur['doc_count'] < prev[index]['doc_count']: + prev[index]['doc_count'] = cur['doc_count'] + except ValueError: + prev.append(cur) + return prev + + result_bucket.tag.buckets._l_ = list(reduce( + reducer, result_bucket.tag.buckets._l_, [] + )) + result_bucket.doc_count = reduce( + lambda prev, cur: prev + cur['doc_count'], + result_bucket.tag.buckets._l_, + 0 + ) + return result_bucket + class CustomSearchFilterBackend(SearchFilterBackend): """Custom SearchFilterBackend."""