package org.tensorflow.contrib.android;

import android.content.res.AssetManager;
import android.os.Build;
import android.os.Trace;
import android.text.TextUtils;
import android.util.Log;
import com.sensetime.stmobile.STMobileHumanActionNative;
import defpackage.C2984hka;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.types.UInt8;

/* loaded from: classes2.dex */
public class TensorFlowInferenceInterface {
    private static final String ASSET_FILE_PREFIX = "file:///android_asset/";
    private static final String TAG = "TensorFlowInferenceInterface";
    /** Chunk size used when reading a serialized graph; also the minimum initial buffer size. */
    private static final int READ_BUFFER_SIZE = 16384;

    // Names and tensors passed to feed*() since the last run(); kept in parallel so they can be
    // logged on failure and closed after every run().
    private final List<String> feedNames = new ArrayList<>();
    private final List<Tensor<?>> feedTensors = new ArrayList<>();
    // Names requested by the last run() and the tensors it produced, index-aligned (the runner
    // returns outputs in the order fetches were registered).
    private final List<String> fetchNames = new ArrayList<>();
    private List<Tensor<?>> fetchTensors = new ArrayList<>();
    private final Graph g;
    private final String modelName;
    // Created lazily the first time run() is asked to collect statistics.
    private RunStats runStats;
    // Accumulates feeds, fetches and targets for the next run; replaced after every run().
    private Session.Runner runner;
    private final Session sess;

    /**
     * A tensor reference parsed from {@code "opName"} or {@code "opName:outputIndex"}.
     *
     * <p>NOTE(review): the decompiler reports this class was originally private; it is kept
     * public here only to avoid changing the visible interface.
     */
    public static class TensorId {
        String name;
        int outputIndex;

        private TensorId() {
        }

        /**
         * Splits {@code str} at its last colon into an operation name and output index.
         *
         * <p>When there is no colon, or the text after the last colon is not an integer, the whole
         * string is used as the operation name and the output index defaults to 0.
         *
         * @param str tensor reference such as {@code "input"} or {@code "softmax:0"}
         * @return the parsed reference (never null)
         */
        public static TensorId parse(String str) {
            TensorId tensorId = new TensorId();
            tensorId.name = str;
            tensorId.outputIndex = 0;
            int colon = str.lastIndexOf(':');
            if (colon >= 0) {
                try {
                    tensorId.outputIndex = Integer.parseInt(str.substring(colon + 1));
                    tensorId.name = str.substring(0, colon);
                } catch (NumberFormatException ignored) {
                    // Suffix is not numeric: the colon belongs to the operation name itself.
                }
            }
            return tensorId;
        }
    }

    /**
     * Loads a serialized GraphDef either from the APK assets or from the file system.
     *
     * @param assetManager asset manager used to open the model
     * @param str model location; a {@code file:///android_asset/} prefix forces an asset lookup,
     *     otherwise the name is tried as an asset first and then as a file-system path
     * @throws RuntimeException if the model cannot be opened, read or imported
     */
    public TensorFlowInferenceInterface(AssetManager assetManager, String str) {
        prepareNativeRuntime();
        this.modelName = str;
        this.g = new Graph();
        this.sess = new Session(this.g, (byte[]) null);
        this.runner = this.sess.runner();
        InputStream modelStream = openModel(assetManager, str);
        try {
            try {
                importGraphFrom(modelStream);
            } finally {
                // Close on success and on failure alike; the stream is owned by this constructor.
                modelStream.close();
            }
        } catch (IOException e) {
            throw new RuntimeException("Failed to load model from '" + str + "'", e);
        }
    }

    /**
     * Loads a serialized GraphDef by reading {@code inputStream} to its end.
     *
     * <p>The stream is not closed; it remains owned by the caller.
     *
     * @throws RuntimeException if the stream cannot be read or does not hold a valid graph
     */
    public TensorFlowInferenceInterface(InputStream inputStream) {
        prepareNativeRuntime();
        this.modelName = "";
        this.g = new Graph();
        this.sess = new Session(this.g, (byte[]) null);
        this.runner = this.sess.runner();
        try {
            importGraphFrom(inputStream);
        } catch (IOException e) {
            throw new RuntimeException("Failed to load model from the input stream", e);
        }
    }

    /**
     * Wraps an already-built graph. Note that {@link #close()} also closes {@code graph}.
     */
    public TensorFlowInferenceInterface(Graph graph) {
        prepareNativeRuntime();
        this.modelName = "";
        this.g = graph;
        this.sess = new Session(graph, (byte[]) null);
        this.runner = this.sess.runner();
    }

    /**
     * Opens the model as an asset, falling back to a file-system path when the name had no asset
     * prefix and no such asset exists.
     */
    private static InputStream openModel(AssetManager assetManager, String model) {
        boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX);
        String assetName = hasAssetPrefix ? model.substring(ASSET_FILE_PREFIX.length()) : model;
        try {
            return assetManager.open(assetName);
        } catch (IOException assetError) {
            if (hasAssetPrefix) {
                throw new RuntimeException("Failed to load model from '" + model + "'", assetError);
            }
            // Not an asset; perhaps the model is a regular file on disk.
            try {
                return new FileInputStream(model);
            } catch (IOException fileError) {
                throw new RuntimeException("Failed to load model from '" + model + "'", assetError);
            }
        }
    }

    /** Reads the whole stream and imports it into {@link #g}, inside balanced trace sections. */
    private void importGraphFrom(InputStream in) throws IOException {
        beginTraceSection("initializeTensorFlow");
        try {
            byte[] graphDef;
            beginTraceSection("readGraphDef");
            try {
                graphDef = readFully(in);
            } finally {
                endTraceSection();
            }
            loadGraph(graphDef, this.g);
        } finally {
            endTraceSection();
        }
    }

    /**
     * Reads {@code in} until end of stream. {@link InputStream#available()} is used only as a size
     * hint, because it is not guaranteed to report the full remaining length.
     */
    private static byte[] readFully(InputStream in) throws IOException {
        ByteArrayOutputStream out =
                new ByteArrayOutputStream(Math.max(in.available(), READ_BUFFER_SIZE));
        byte[] chunk = new byte[READ_BUFFER_SIZE];
        int count;
        while ((count = in.read(chunk, 0, chunk.length)) != -1) {
            out.write(chunk, 0, count);
        }
        return out.toByteArray();
    }

    // android.os.Trace only exists from API 18 onwards.
    private static void beginTraceSection(String sectionName) {
        if (Build.VERSION.SDK_INT >= 18) {
            Trace.beginSection(sectionName);
        }
    }

    private static void endTraceSection() {
        if (Build.VERSION.SDK_INT >= 18) {
            Trace.endSection();
        }
    }

    /** Registers {@code tensor} as input {@code str} for the next run and tracks it for cleanup. */
    private void addFeed(String str, Tensor<?> tensor) {
        TensorId id = TensorId.parse(str);
        this.runner.feed(id.name, id.outputIndex, tensor);
        this.feedNames.add(str);
        this.feedTensors.add(tensor);
    }

    /** Closes every fed tensor and forgets the feed names. */
    private void closeFeeds() {
        for (Tensor<?> tensor : this.feedTensors) {
            tensor.close();
        }
        this.feedTensors.clear();
        this.feedNames.clear();
    }

    /** Closes every fetched tensor and forgets the fetch names. */
    private void closeFetches() {
        for (Tensor<?> tensor : this.fetchTensors) {
            tensor.close();
        }
        this.fetchTensors.clear();
        this.fetchNames.clear();
    }

    /**
     * Returns the output tensor produced by the last run for {@code str}.
     *
     * @throws RuntimeException if {@code str} was not among the outputs of the last run
     */
    private Tensor<?> getTensor(String str) {
        int index = this.fetchNames.indexOf(str);
        if (index < 0) {
            throw new RuntimeException(
                    "Node '" + str + "' was not provided to run(), so it cannot be read");
        }
        return this.fetchTensors.get(index);
    }

    /**
     * Imports a serialized GraphDef into {@code graph}.
     *
     * @throws IOException if the bytes are not a valid graph serialization (the underlying
     *     {@link IllegalArgumentException} is attached as the cause)
     */
    private void loadGraph(byte[] bArr, Graph graph) throws IOException {
        beginTraceSection("importGraphDef");
        try {
            graph.importGraphDef(bArr);
        } catch (IllegalArgumentException e) {
            throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage(), e);
        } finally {
            endTraceSection();
        }
    }

    /**
     * Makes sure the TensorFlow JNI methods are available, loading {@code tensorflow_inference}
     * if a probe call fails with {@link UnsatisfiedLinkError}.
     */
    private void prepareNativeRuntime() {
        try {
            // Probe: constructing RunStats is expected to throw UnsatisfiedLinkError while the
            // native library is not loaded. Close it so the probe does not leak its handle.
            new RunStats().close();
        } catch (UnsatisfiedLinkError notLoaded) {
            try {
                System.loadLibrary("tensorflow_inference");
            } catch (UnsatisfiedLinkError loadFailed) {
                throw new RuntimeException(
                        "Native TF methods not found; check that the correct native libraries are present in the APK.",
                        loadFailed);
            }
        }
    }

    /**
     * Releases pending feeds, the last fetched outputs, the session and the graph (also when the
     * graph was supplied by the caller), plus any collected run statistics.
     */
    public void close() {
        closeFeeds();
        closeFetches();
        this.sess.close();
        this.g.close();
        if (this.runStats != null) {
            this.runStats.close();
        }
        this.runStats = null;
    }

    // ---- feed*: register an input tensor of the given shape (jArr / jArr2) for the next run ----

    /** Feeds {@code byteBuffer} as a UInt8 tensor. */
    public void feed(String str, ByteBuffer byteBuffer, long... jArr) {
        addFeed(str, Tensor.create(UInt8.class, jArr, byteBuffer));
    }

    public void feed(String str, DoubleBuffer doubleBuffer, long... jArr) {
        addFeed(str, Tensor.create(jArr, doubleBuffer));
    }

    public void feed(String str, FloatBuffer floatBuffer, long... jArr) {
        addFeed(str, Tensor.create(jArr, floatBuffer));
    }

    public void feed(String str, IntBuffer intBuffer, long... jArr) {
        addFeed(str, Tensor.create(jArr, intBuffer));
    }

    public void feed(String str, LongBuffer longBuffer, long... jArr) {
        addFeed(str, Tensor.create(jArr, longBuffer));
    }

    /** Feeds {@code bArr} as a UInt8 tensor. */
    public void feed(String str, byte[] bArr, long... jArr) {
        addFeed(str, Tensor.create(UInt8.class, jArr, ByteBuffer.wrap(bArr)));
    }

    public void feed(String str, double[] dArr, long... jArr) {
        addFeed(str, Tensor.create(jArr, DoubleBuffer.wrap(dArr)));
    }

    public void feed(String str, float[] fArr, long... jArr) {
        addFeed(str, Tensor.create(jArr, FloatBuffer.wrap(fArr)));
    }

    public void feed(String str, int[] iArr, long... jArr) {
        addFeed(str, Tensor.create(jArr, IntBuffer.wrap(iArr)));
    }

    /** Feeds the values in {@code jArr} with shape {@code jArr2}. */
    public void feed(String str, long[] jArr, long... jArr2) {
        addFeed(str, Tensor.create(jArr2, LongBuffer.wrap(jArr)));
    }

    /** Feeds a boolean tensor; each element is encoded as one byte (1 for true, 0 for false). */
    public void feed(String str, boolean[] zArr, long... jArr) {
        byte[] bArr = new byte[zArr.length];
        for (int i = 0; i < zArr.length; i++) {
            bArr[i] = zArr[i] ? (byte) 1 : (byte) 0;
        }
        addFeed(str, Tensor.create(Boolean.class, jArr, ByteBuffer.wrap(bArr)));
    }

    /** Feeds a scalar string tensor holding {@code bArr}. */
    public void feedString(String str, byte[] bArr) {
        addFeed(str, Tensor.create(bArr, String.class));
    }

    /** Feeds a string tensor with one element per row of {@code bArr}. */
    public void feedString(String str, byte[][] bArr) {
        addFeed(str, Tensor.create(bArr, String.class));
    }

    // ---- fetch*: copy an output of the last run() into the caller's buffer or array ----

    public void fetch(String str, ByteBuffer byteBuffer) {
        getTensor(str).writeTo(byteBuffer);
    }

    public void fetch(String str, DoubleBuffer doubleBuffer) {
        getTensor(str).writeTo(doubleBuffer);
    }

    public void fetch(String str, FloatBuffer floatBuffer) {
        getTensor(str).writeTo(floatBuffer);
    }

    public void fetch(String str, IntBuffer intBuffer) {
        getTensor(str).writeTo(intBuffer);
    }

    public void fetch(String str, LongBuffer longBuffer) {
        getTensor(str).writeTo(longBuffer);
    }

    public void fetch(String str, byte[] bArr) {
        fetch(str, ByteBuffer.wrap(bArr));
    }

    public void fetch(String str, double[] dArr) {
        fetch(str, DoubleBuffer.wrap(dArr));
    }

    public void fetch(String str, float[] fArr) {
        fetch(str, FloatBuffer.wrap(fArr));
    }

    public void fetch(String str, int[] iArr) {
        fetch(str, IntBuffer.wrap(iArr));
    }

    public void fetch(String str, long[] jArr) {
        fetch(str, LongBuffer.wrap(jArr));
    }

    /** Safety net for callers that forget {@link #close()}. */
    protected void finalize() throws Throwable {
        try {
            close();
        } finally {
            super.finalize();
        }
    }

    /** Returns the accumulated run statistics summary, or "" if none were collected. */
    public String getStatString() {
        return this.runStats == null ? "" : this.runStats.summary();
    }

    /** Returns the underlying graph (closed by {@link #close()}). */
    public Graph graph() {
        return this.g;
    }

    /**
     * Looks up an operation by name.
     *
     * @throws RuntimeException if the graph has no operation called {@code str}
     */
    public Operation graphOperation(String str) {
        Operation operation = this.g.operation(str);
        if (operation == null) {
            throw new RuntimeException(
                    "Node '" + str + "' does not exist in model '" + this.modelName + "'");
        }
        return operation;
    }

    /** Runs the graph to compute {@code strArr}, without collecting statistics. */
    public void run(String[] strArr) {
        run(strArr, false);
    }

    /** Runs the graph to compute {@code strArr}; when {@code z} is true, statistics are recorded. */
    public void run(String[] strArr, boolean z) {
        run(strArr, z, new String[0]);
    }

    /**
     * Runs the graph with the pending feeds.
     *
     * <p>Outputs of any previous run are released first. Feeds are always released afterwards and
     * the runner is reset, even on failure. On failure the partially registered outputs are also
     * discarded, so a later {@code fetch} reports the name as unknown instead of reading stale
     * state.
     *
     * @param strArr output names ({@code "op"} or {@code "op:index"}) to compute and keep for
     *     {@code fetch}
     * @param z whether to collect run statistics (see {@link #getStatString()})
     * @param strArr2 names of target operations to run without fetching
     * @throws RuntimeException if TensorFlow rejects a name or the run itself fails
     */
    public void run(String[] strArr, boolean z, String[] strArr2) {
        closeFetches();
        try {
            try {
                // Registered inside the try so a bad name still triggers the cleanup below.
                for (String fetchName : strArr) {
                    this.fetchNames.add(fetchName);
                    TensorId id = TensorId.parse(fetchName);
                    this.runner.fetch(id.name, id.outputIndex);
                }
                for (String target : strArr2) {
                    this.runner.addTarget(target);
                }
                if (z) {
                    Session.Run result =
                            this.runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
                    this.fetchTensors = result.outputs;
                    if (this.runStats == null) {
                        this.runStats = new RunStats();
                    }
                    this.runStats.add(result.metadata);
                } else {
                    this.fetchTensors = this.runner.run();
                }
            } catch (RuntimeException e) {
                Log.e(TAG, "Failed to run TensorFlow inference with inputs:[" + TextUtils.join(", ", this.feedNames) + "], outputs:[" + TextUtils.join(", ", this.fetchNames) + "]");
                // Keep fetchNames/fetchTensors aligned: drop whatever this failed run left behind.
                closeFetches();
                throw e;
            }
        } finally {
            closeFeeds();
            this.runner = this.sess.runner();
        }
    }
}
