package com.appian.documentunderstanding.prediction.table;

import com.appian.documentunderstanding.common.DocumentExtractionFeatureToggles;
import com.appian.documentunderstanding.populate.AnnotationType;
import com.appian.documentunderstanding.populate.KeyData;
import com.appian.documentunderstanding.populate.TableData;
import com.appian.documentunderstanding.prediction.DocumentUnderstandingAbstractEsPredictionService;
import com.appian.documentunderstanding.prediction.PredictionType;
import com.appian.documentunderstanding.prediction.SearchRequestExecutor;
import com.appian.documentunderstanding.prediction.datatypes.CustomComplexField;
import com.appian.documentunderstanding.prediction.datatypes.CustomDatatype;
import com.appian.documentunderstanding.prediction.datatypes.CustomField;
import com.appian.documentunderstanding.prediction.datatypes.CustomFieldType;
import com.appian.documentunderstanding.prediction.metrics.DocExtractPredictionMetricsCollector;
import com.appian.documentunderstanding.prediction.table.TablePredictionEsBridge;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.lang.reflect.Type;
import java.util.AbstractMap;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.xml.namespace.QName;
import org.apache.commons.collections.CollectionUtils;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.lucene.search.function.FieldValueFactorFunction;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.DisMaxQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.metrics.TopHitsAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder;

/* loaded from: input_file:com/appian/documentunderstanding/prediction/table/DocumentUnderstandingTableEsPredictionService.class */
/**
 * Table prediction backed by Elasticsearch.
 *
 * <p>Given the tables extracted from a document, this service searches the index for table
 * signatures recorded for the complex fields of a CDT. For each such field it picks the extracted
 * table whose leading rows share the most values with a returned signature. That table is returned
 * together with the chosen header row and the column-to-field mappings stored in the index.
 */
public class DocumentUnderstandingTableEsPredictionService extends DocumentUnderstandingAbstractEsPredictionService {
    /** Number of leading rows of each extracted table that are treated as header candidates. */
    static final int MAX_HEADER_PREDICTION_ROWS = 3;

    /** Source fields fetched for every top hit of the search. */
    private static final String[] FETCH_SOURCE_INCLUDES = {
        TablePredictionEsBridge.Field.tableSignature.name(),
        TablePredictionEsBridge.Field.parentFieldName.name(),
        TablePredictionEsBridge.Field.mappingsJsonBlob.name(),
        TablePredictionEsBridge.Field.counter.name()
    };

    /** Shared parser for {@code mappingsJsonBlob}; Gson instances are thread-safe and reusable. */
    private static final Gson GSON = new Gson();

    /** Gson target type for {@code mappingsJsonBlob}: a JSON object mapping strings to strings. */
    private static final Type STRING_MAP_TYPE = new TypeToken<Map<String, String>>() { }.getType();

    private final SearchRequestExecutor searchRequestExecutor;

    /**
     * Result of matching a stored table signature against one extracted table.
     *
     * <p>It holds the table id, the index of the table's best header-candidate row, and the number
     * of values that row shares with the signature.
     */
    public static class TableHeaderIntersectionData {
        private final String tableId;
        private int headerIndex;
        private int maxIntersections;

        /**
         * @param str the table id
         * @param i the index of the best header-candidate row
         * @param i2 the number of values that row shares with the signature
         */
        public TableHeaderIntersectionData(String str, int i, int i2) {
            this.tableId = str;
            this.headerIndex = i;
            this.maxIntersections = i2;
        }

        public String getTableId() {
            return this.tableId;
        }

        public int getHeaderIndex() {
            return this.headerIndex;
        }

        public void setHeaderIndex(int i) {
            this.headerIndex = i;
        }

        public int getMaxIntersections() {
            return this.maxIntersections;
        }

        public void setMaxIntersections(int i) {
            this.maxIntersections = i;
        }
    }

    public DocumentUnderstandingTableEsPredictionService(SearchRequestExecutor searchRequestExecutor, DocExtractPredictionMetricsCollector docExtractPredictionMetricsCollector, DocumentExtractionFeatureToggles documentExtractionFeatureToggles) {
        super(PredictionType.TABLE, docExtractPredictionMetricsCollector, documentExtractionFeatureToggles);
        this.searchRequestExecutor = searchRequestExecutor;
    }

    /**
     * Predicts which extracted table populates each complex field of {@code customDatatype}.
     *
     * @param customDatatype the CDT whose complex fields are predicted
     * @param set names of the fields to predict, or {@code null} for every complex field
     * @param collection extracted key data; only {@link TableData} entries are considered
     * @param z not used by this implementation
     * @return field name mapped to a single-element list holding the predicted table; empty when
     *     there are no tables or no eligible fields
     */
    @Override // com.appian.documentunderstanding.prediction.PredictionService
    public Map<String, Collection<KeyData>> getPredictions(CustomDatatype customDatatype, Set<String> set, Collection<KeyData> collection, boolean z) {
        List<TableData> tables = collection.stream()
            .filter(TableData.class::isInstance)
            .map(TableData.class::cast)
            .collect(Collectors.toList());
        if (tables.isEmpty()) {
            return Collections.emptyMap();
        }
        String unversionedQName = customDatatype.getUnversionedQName();
        // Explicit HashMap: handleTopHit removes fields once predicted, and Collectors.toMap makes
        // no mutability guarantee.
        Map<String, QName> remainingFields = new HashMap<>(mapRecordFieldsToQName(customDatatype, set));
        Map<String, String> fieldToChildLocalName = remainingFields.entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getLocalPart()));
        if (fieldToChildLocalName.isEmpty()) {
            return Collections.emptyMap();
        }
        Set<List<String>> headerCandidateRows = tables.stream()
            .flatMap(table -> table.getRows().stream().limit(MAX_HEADER_PREDICTION_ROWS))
            .collect(Collectors.toSet());
        SearchResponse response = searchIndex(unversionedQName, headerCandidateRows, fieldToChildLocalName);

        ImmutableMap.Builder<String, Collection<KeyData>> predictions = ImmutableMap.builder();
        Map<String, TableData> tablesById = new HashMap<>();
        for (TableData table : tables) {
            tablesById.put(tableId(table), table);
        }
        getAllTopHits(response).forEach(hit -> handleTopHit(hit, customDatatype, remainingFields, tablesById, predictions));
        return predictions.build();
    }

    /** Builds the id a table is tracked under: {@code <AnnotationType.TABLE>-<page>-<indexInPage>}. */
    private static String tableId(TableData table) {
        return String.format("%s-%d-%d", AnnotationType.TABLE, table.getPage(), table.getIndexInPage());
    }

    /**
     * Turns one top hit into a prediction for its parent field, if possible.
     *
     * <p>The hit is ignored in any of these cases:
     * <ul>
     *   <li>its field is not in {@code remainingFields};</li>
     *   <li>none of its stored column mappings target a field of the child CDT;</li>
     *   <li>it carries no table signature;</li>
     *   <li>no extracted table shares a value with the signature.</li>
     * </ul>
     *
     * <p>On success the prediction is added to {@code predictions} and the field is removed from
     * {@code remainingFields}. As a result the first successful hit per field wins, and the builder
     * never receives a duplicate key.
     */
    private void handleTopHit(SearchHit hit, CustomDatatype customDatatype, Map<String, QName> remainingFields, Map<String, TableData> tablesById, ImmutableMap.Builder<String, Collection<KeyData>> predictions) {
        Map<String, Object> source = hit.getSourceAsMap();
        String parentFieldName = (String) source.get(TablePredictionEsBridge.Field.parentFieldName.name());
        if (!remainingFields.containsKey(parentFieldName)) {
            return;
        }
        Map<String, String> columnMappings = validColumnMappings(source, customDatatype, parentFieldName);
        if (columnMappings.isEmpty()) {
            return;
        }
        @SuppressWarnings("unchecked") // the index stores tableSignature as a list of strings
        List<String> signature = (List<String>) source.get(TablePredictionEsBridge.Field.tableSignature.name());
        if (signature == null) {
            return;
        }
        // Stream.max keeps the first element on ties, matching the previous strict '>' comparison.
        tablesById.entrySet().stream()
            .map(entry -> findMaxTableIntersection(signature, entry.getKey(), entry.getValue()))
            .max(Comparator.comparingInt(TableHeaderIntersectionData::getMaxIntersections))
            .filter(best -> best.getMaxIntersections() > 0)
            .ifPresent(best -> {
                TableData table = tablesById.get(best.getTableId());
                predictions.put(parentFieldName, ImmutableList.<KeyData>of(table.createTableDataWithHeaderIndexAndColumnMappings(best.getHeaderIndex(), columnMappings)));
                remainingFields.remove(parentFieldName);
            });
    }

    /**
     * Parses the column mappings stored in a hit.
     *
     * <p>Only entries whose target is a field of the child CDT behind {@code parentFieldName} are
     * kept. A missing or empty mappings blob yields an empty map.
     */
    private Map<String, String> validColumnMappings(Map<String, Object> source, CustomDatatype customDatatype, String parentFieldName) {
        // Gson returns null for null or empty input.
        Map<String, String> storedMappings = GSON.fromJson((String) source.get(TablePredictionEsBridge.Field.mappingsJsonBlob.name()), STRING_MAP_TYPE);
        if (storedMappings == null) {
            return Collections.emptyMap();
        }
        Collection<String> childFieldNames = childFieldNames(customDatatype, parentFieldName);
        return storedMappings.entrySet().stream()
            .filter(entry -> childFieldNames.contains(entry.getValue()))
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /**
     * Returns the field names of the child CDT referenced by the complex field
     * {@code parentFieldName}, or an empty set if there is no such complex field.
     */
    private static Collection<String> childFieldNames(CustomDatatype customDatatype, String parentFieldName) {
        return customDatatype.getFields(CustomFieldType.COMPLEX).stream()
            .filter(field -> field.getName().equals(parentFieldName))
            .findFirst()
            .filter(CustomComplexField.class::isInstance)
            .map(CustomComplexField.class::cast)
            .<Collection<String>>map(field -> field.getCustomDatatype().getFieldNames())
            .orElse(Collections.emptySet());
    }

    /** Multiplies the query score by {@code log(1 + counter)} of each indexed document. */
    private QueryBuilder wrapWithScoreFunction(QueryBuilder queryBuilder) {
        return QueryBuilders.functionScoreQuery(queryBuilder, ScoreFunctionBuilders.fieldValueFactorFunction(TablePredictionEsBridge.Field.counter.name()).modifier(FieldValueFactorFunction.Modifier.LOG1P));
    }

    /**
     * Builds and executes the table-signature search.
     *
     * @param str unversioned QName of the parent CDT
     * @param set header-candidate rows of the extracted tables
     * @param map requested field name mapped to the local part of its child CDT QName
     */
    SearchResponse searchIndex(String str, Set<List<String>> set, Map<String, String> map) {
        return this.searchRequestExecutor.execute(createSearchRequest(str, set, map));
    }

    /**
     * Builds the search request.
     *
     * <p>Documents are filtered to the parent CDT and requested fields, matched against the header
     * rows, and bucketed by {@code parentFieldName}. Each bucket carries a top-hits sub-aggregation.
     * When the custom-score toggle is on, each row is scored by its matching cells and boosted by
     * {@code counter}, and top hits sort by score first. Otherwise every cell is an independent
     * match clause and top hits sort by {@code counter} first.
     *
     * @param str unversioned QName of the parent CDT
     * @param set header-candidate rows of the extracted tables
     * @param map requested field name mapped to the local part of its child CDT QName
     */
    SearchRequest createSearchRequest(String str, Set<List<String>> set, Map<String, String> map) {
        BoolQueryBuilder baseQuery = filter(str, map);
        TopHitsAggregationBuilder topHits = AggregationBuilders.topHits("frequently_choosen")
            .from(0)
            .fetchSource(FETCH_SOURCE_INCLUDES, new String[0])
            .trackScores(true);
        QueryBuilder query;
        if (getFeatureToggles().isCustomScoreEnabled()) {
            appendTableSignatureConstantScoreQuery(set, baseQuery);
            query = wrapWithScoreFunction(baseQuery);
            topHits = topHits.sort(SortBuilders.scoreSort())
                .sort(SortBuilders.fieldSort(TablePredictionEsBridge.Field.counter.name()).order(SortOrder.DESC));
        } else {
            appendTableSignatureQuery(set, baseQuery);
            query = baseQuery;
            topHits = topHits.sort(SortBuilders.fieldSort(TablePredictionEsBridge.Field.counter.name()).order(SortOrder.DESC))
                .sort(SortBuilders.scoreSort());
        }
        return this.searchRequestExecutor.buildSearchRequest(new SearchSourceBuilder()
            .query(query)
            .aggregation(AggregationBuilders.terms("group_by_field_names")
                .field(TablePredictionEsBridge.Field.parentFieldName.name())
                .size(map.size())
                .subAggregation(topHits)));
    }

    /**
     * Adds a required dis_max clause over the header rows.
     *
     * <p>Each row is a bool of constant-score cell matches, and dis_max keeps the score of the
     * best-matching row.
     */
    private void appendTableSignatureConstantScoreQuery(Set<List<String>> headerRows, BoolQueryBuilder query) {
        query.minimumShouldMatch(1);
        DisMaxQueryBuilder bestRow = QueryBuilders.disMaxQuery();
        for (List<String> row : headerRows) {
            BoolQueryBuilder rowQuery = QueryBuilders.boolQuery();
            for (String cell : row) {
                rowQuery.should(QueryBuilders.constantScoreQuery(QueryBuilders.matchQuery(TablePredictionEsBridge.Field.tableSignature.name(), cell)));
            }
            bestRow.add(rowQuery);
        }
        query.should(bestRow);
    }

    /**
     * Restricts documents to the given parent CDT and to at least one requested (field name, child
     * CDT) pair.
     */
    private BoolQueryBuilder filter(String parentCdtQName, Map<String, String> fieldToChildLocalName) {
        BoolQueryBuilder fieldsQuery = QueryBuilders.boolQuery().queryName("CDT Fields (name and type) bool query");
        for (Map.Entry<String, String> entry : fieldToChildLocalName.entrySet()) {
            fieldsQuery.should(QueryBuilders.boolQuery()
                .queryName("Filter for CDT field name '" + entry.getKey() + "'")
                .filter(QueryBuilders.termQuery(TablePredictionEsBridge.Field.parentFieldName.name(), entry.getKey()))
                .filter(QueryBuilders.termQuery(TablePredictionEsBridge.Field.childCdtQName.name(), entry.getValue())));
        }
        return QueryBuilders.boolQuery()
            .filter(QueryBuilders.termQuery(TablePredictionEsBridge.Field.parentCdtQName.name(), parentCdtQName).queryName("Parent CDT Filter"))
            .filter(fieldsQuery);
    }

    /** Adds one should-clause per header cell and requires at least one of them to match. */
    private void appendTableSignatureQuery(Set<List<String>> headerRows, BoolQueryBuilder query) {
        query.minimumShouldMatch(1);
        for (List<String> row : headerRows) {
            for (String cell : row) {
                query.should(QueryBuilders.matchQuery(TablePredictionEsBridge.Field.tableSignature.name(), cell));
            }
        }
    }

    /**
     * Finds which of the table's first {@link #MAX_HEADER_PREDICTION_ROWS} rows shares the most
     * values with {@code signature}.
     *
     * <p>Ties keep the earliest row. When no row shares a value, the result has header index 0 and
     * zero intersections.
     */
    private TableHeaderIntersectionData findMaxTableIntersection(List<String> signature, String tableId, TableData table) {
        List<? extends Collection<?>> extractedRows = table.getExtractedValue();
        int candidateCount = Math.min(extractedRows.size(), MAX_HEADER_PREDICTION_ROWS);
        int bestRow = 0;
        int bestCount = 0;
        for (int row = 0; row < candidateCount; row++) {
            int count = CollectionUtils.intersection(extractedRows.get(row), signature).size();
            if (count > bestCount) {
                bestRow = row;
                bestCount = count;
            }
        }
        return new TableHeaderIntersectionData(tableId, bestRow, bestCount);
    }

    /**
     * Maps each complex field of {@code customDatatype} to the QName of its child CDT.
     *
     * <p>When {@code fieldNames} is non-null, only the fields it names are included.
     */
    private Map<String, QName> mapRecordFieldsToQName(CustomDatatype customDatatype, Set<String> fieldNames) {
        return customDatatype.getFields(CustomFieldType.COMPLEX).stream()
            .filter(field -> fieldNames == null || fieldNames.contains(field.getName()))
            .filter(CustomComplexField.class::isInstance)
            .map(CustomComplexField.class::cast)
            .collect(Collectors.toMap(CustomComplexField::getName, field -> QName.valueOf(field.getCustomDatatype().getQName())));
    }
}
