/*
 * Decompiled with CFR 0.152.
 */
package com.hcl.appscan.ifa.nlp;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.hcl.appscan.ifa.nlp.FindingFeatureModel;
import com.hcl.appscan.ifa.nlp.IFindingEvaluator;
import com.ibm.appscan.assessment.model.Finding;
import com.ibm.appscan.common.utils.functional.Lazy;
import com.ibm.appscan.ifa.common.IfaException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.NoSuchElementException;
import org.apache.commons.io.IOUtils;

/**
 * Base class for finding evaluators backed by an ONNX model loaded from a classpath resource.
 *
 * <p>Subclasses must set {@code m_model_resource} before the first prediction: the
 * {@link OrtSession} is built lazily from that resource on first access and is released by
 * {@link #close()}.
 */
public abstract class OnnxModel
extends FindingFeatureModel
implements AutoCloseable,
IFindingEvaluator {
    /** Session for the model; created on first access by {@code sessionLoader()}. */
    protected Lazy<OrtSession> m_session = Lazy.of(this::sessionLoader);
    /** ONNX Runtime environment shared by all model instances, obtained lazily. */
    protected static Lazy<OrtEnvironment> m_environment = Lazy.of(OrtEnvironment::getEnvironment);
    /**
     * Resource name of the model file, resolved with {@link Class#getResourceAsStream(String)}
     * against the runtime class of this instance (so relative names resolve in the subclass's package).
     */
    protected String m_model_resource;

    /**
     * Closes the ONNX session if the lazy holder reports it as loaded; does nothing otherwise.
     * The shared static environment is not closed here.
     *
     * @throws OrtException if the session fails to close
     */
    @Override
    public void close() throws OrtException {
        if (this.m_session.isLoaded()) {
            this.m_session.get().close();
        }
    }

    /**
     * Encodes a finding as the model's input map.
     *
     * <p>Produces a single string tensor named {@code "input"} with shape
     * {@code [1, m_features.size()]} holding the values from {@code getFeatures(f)}.
     * The returned tensors are owned by the caller and must be closed.
     *
     * @param f the finding to encode
     * @return map from model input name to tensor
     * @throws OrtException if the tensor cannot be created
     * @throws IfaException as propagated from {@code getFeatures}
     */
    protected HashMap<String, OnnxTensor> transform(Finding f) throws OrtException, IfaException {
        int featureCount = this.m_features.size();
        long[] shape = new long[]{1L, featureCount};
        // NOTE(review): if getFeatures(f) yields a different count than m_features, toArray either
        // returns a larger array or null-pads this one, and tensor creation/inference will likely fail.
        String[] values = (String[])this.getFeatures(f).toArray(new String[featureCount]);
        HashMap<String, OnnxTensor> output = new HashMap<String, OnnxTensor>();
        output.put("input", OnnxTensor.createTensor(m_environment.get(), values, shape));
        return output;
    }

    /**
     * Reads the model bytes from {@code m_model_resource} and creates an ONNX session from them.
     *
     * @return a new session for the model
     * @throws NoSuchElementException if the resource is missing or the session cannot be created;
     *     when loading fails, the underlying exception is attached as the cause
     */
    private OrtSession sessionLoader() {
        InputStream modelStream = this.getClass().getResourceAsStream(this.m_model_resource);
        if (modelStream == null) {
            // Report the missing resource clearly instead of surfacing a bare NullPointerException.
            throw new NoSuchElementException("Could not load model resource ->" + this.m_model_resource + " not found");
        }
        // try-with-resources so the stream is closed whether or not the session is created.
        try (InputStream in = modelStream) {
            return m_environment.get().createSession(IOUtils.toByteArray(in));
        }
        catch (Exception e) {
            NoSuchElementException failure = new NoSuchElementException("Could not load model resource ->" + e.getMessage());
            // NoSuchElementException(String, Throwable) only exists on Java 15+; initCause works everywhere.
            failure.initCause(e);
            throw failure;
        }
    }

    /**
     * Runs the model on a finding and returns the predicted score.
     *
     * <p>Reads model output index 1 as a {@code float[][]} and returns element {@code [0][1]}.
     * NOTE(review): presumably output 1 holds class probabilities and column 1 is the positive
     * class — confirm against the exported model.
     *
     * @param f the finding to score
     * @return the predicted value
     * @throws IfaException if ONNX Runtime fails while creating inputs or running inference,
     *     or as propagated from {@code transform}
     * @throws Exception any other failure from session loading or feature extraction
     */
    @Override
    public Float getPredictedValue(Finding f) throws Exception {
        HashMap<String, OnnxTensor> inputs = null;
        try {
            inputs = this.transform(f);
            // Result holds the output tensors; closing it releases them once the value is copied out.
            try (OrtSession.Result result = this.m_session.get().run(inputs)) {
                float[][] predictions = (float[][])result.get(1).getValue();
                return Float.valueOf(predictions[0][1]);
            }
        }
        catch (OrtException e) {
            // NOTE(review): the OrtException cause is dropped; attach it if IfaException offers a cause-taking constructor.
            throw new IfaException("IFA Inference failed -> " + e.getMessage());
        }
        finally {
            // Input tensors were created for this call only, so release them here.
            if (inputs != null) {
                for (OnnxTensor tensor : inputs.values()) {
                    tensor.close();
                }
            }
        }
    }
}

