/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.huggingface.tokenizers;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.jni.CharSpan;
import ai.djl.huggingface.tokenizers.jni.LibUtils;
import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary;
import ai.djl.modality.nlp.preprocess.Tokenizer;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.util.Ec2Utils;
import ai.djl.util.NativeResource;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class HuggingFaceTokenizer
extends NativeResource<Long>
implements Tokenizer {
    /** Class-wide SLF4J logger; used for the warnings emitted while reconciling length settings. */
    private static final Logger logger = LoggerFactory.getLogger(HuggingFaceTokenizer.class);
    /**
     * Default {@code addSpecialTokens} argument for the encode/batchEncode overloads that omit it.
     * Its negation is used as the default {@code skipSpecialTokens} for decode/batchDecode.
     */
    private boolean addSpecialTokens;
    /** Truncation strategy: read from the native tokenizer, overridable via the "truncation" option. */
    private TruncationStrategy truncation;
    /** Padding strategy: read from the native tokenizer, overridable via the "padding" option. */
    private PaddingStrategy padding;
    /** Maximum sequence length; {@code -1} is treated as "not specified" when settings are applied. */
    private int maxLength;
    /** Stride value forwarded to the native library when truncation is enabled. */
    private int stride;
    /** Padding multiple forwarded to the native library; {@code 0} disables the multiple-of adjustment. */
    private int padToMultipleOf;
    /** Upper bound applied to {@link #maxLength}; defaults to 512 when no option is supplied. */
    private int modelMaxLength;

    /**
     * Wraps a native tokenizer handle and applies the optional user settings on top of the
     * configuration reported by the native library.
     *
     * <p>Recognised option keys: {@code addSpecialTokens}, {@code modelMaxLength},
     * {@code truncation}, {@code padding}, {@code maxLength}, {@code stride} and
     * {@code padToMultipleOf}. Keys that are absent keep the native (or built-in) default.
     *
     * @param handle the native tokenizer handle
     * @param options user supplied overrides, may be {@code null}
     */
    private HuggingFaceTokenizer(long handle, Map<String, String> options) {
        super(handle);

        // Baseline: whatever the native tokenizer is currently configured with.
        truncation = TruncationStrategy.fromValue(TokenizersLibrary.LIB.getTruncationStrategy(handle));
        padding = PaddingStrategy.fromValue(TokenizersLibrary.LIB.getPaddingStrategy(handle));
        maxLength = TokenizersLibrary.LIB.getMaxLength(handle);
        stride = TokenizersLibrary.LIB.getStride(handle);
        padToMultipleOf = TokenizersLibrary.LIB.getPadToMultipleOf(handle);

        if (options == null) {
            addSpecialTokens = true;
            modelMaxLength = 512;
        } else {
            addSpecialTokens = Boolean.parseBoolean(options.getOrDefault("addSpecialTokens", "true"));
            modelMaxLength = ArgumentsUtil.intValue(options, "modelMaxLength", 512);
            if (options.containsKey("truncation")) {
                truncation = TruncationStrategy.fromValue(options.get("truncation"));
            }
            if (options.containsKey("padding")) {
                padding = PaddingStrategy.fromValue(options.get("padding"));
            }
            // The native values act as defaults for the numeric overrides.
            maxLength = ArgumentsUtil.intValue(options, "maxLength", maxLength);
            stride = ArgumentsUtil.intValue(options, "stride", stride);
            padToMultipleOf = ArgumentsUtil.intValue(options, "padToMultipleOf", padToMultipleOf);
        }

        // Push the reconciled settings back to the native tokenizer.
        updateTruncationAndPadding();
    }

    /**
     * Creates a tokenizer from an identifier without any option overrides.
     *
     * @param name the identifier passed to the native {@code createTokenizer} call
     * @return a new tokenizer instance
     */
    public static HuggingFaceTokenizer newInstance(String name) {
        final Map<String, String> noOptions = null;
        return newInstance(name, noOptions);
    }

    /**
     * Creates a tokenizer from an identifier and applies the given options.
     *
     * @param identifier the identifier passed to the native {@code createTokenizer} call
     * @param options tokenizer options, may be {@code null}
     * @return a new tokenizer instance
     */
    public static HuggingFaceTokenizer newInstance(String identifier, Map<String, String> options) {
        Ec2Utils.callHome("Huggingface");
        LibUtils.checkStatus();

        final long nativeHandle = TokenizersLibrary.LIB.createTokenizer(identifier);
        return new HuggingFaceTokenizer(nativeHandle, options);
    }

    /**
     * Creates a tokenizer from a file or directory without any option overrides.
     *
     * @param modelPath a tokenizer file, or a directory containing {@code tokenizer.json}
     * @return a new tokenizer instance
     * @throws IOException if the file cannot be read
     */
    public static HuggingFaceTokenizer newInstance(Path modelPath) throws IOException {
        final Map<String, String> noOptions = null;
        return newInstance(modelPath, noOptions);
    }

    /**
     * Creates a tokenizer from a file or directory and applies the given options.
     *
     * <p>When {@code modelPath} is a directory, {@code tokenizer.json} inside it is loaded.
     *
     * @param modelPath a tokenizer file, or a directory containing {@code tokenizer.json}
     * @param options tokenizer options, may be {@code null}
     * @return a new tokenizer instance
     * @throws IOException if the file cannot be read
     */
    public static HuggingFaceTokenizer newInstance(Path modelPath, Map<String, String> options) throws IOException {
        final Path tokenizerFile =
                Files.isDirectory(modelPath) ? modelPath.resolve("tokenizer.json") : modelPath;
        try (InputStream stream = Files.newInputStream(tokenizerFile)) {
            return newInstance(stream, options);
        }
    }

    /**
     * Creates a tokenizer from a stream whose full content is handed to the native
     * {@code createTokenizerFromString} call.
     *
     * <p>The stream is read via {@code Utils.toString}; this method does not close it itself.
     *
     * @param is the stream supplying the tokenizer definition
     * @param options tokenizer options, may be {@code null}
     * @return a new tokenizer instance
     * @throws IOException if reading the stream fails
     */
    public static HuggingFaceTokenizer newInstance(InputStream is, Map<String, String> options) throws IOException {
        Ec2Utils.callHome("Huggingface");
        LibUtils.checkStatus();

        final String definition = Utils.toString(is);
        final long nativeHandle = TokenizersLibrary.LIB.createTokenizerFromString(definition);
        return new HuggingFaceTokenizer(nativeHandle, options);
    }

    /**
     * Encodes the sentence with the default special-token setting and returns its tokens.
     *
     * @param sentence the text to tokenize
     * @return a fixed-size list view over the token array of the resulting encoding
     */
    public List<String> tokenize(String sentence) {
        final String[] tokens = encode(sentence).getTokens();
        return Arrays.asList(tokens);
    }

    /**
     * Joins tokens back into a sentence.
     *
     * <p>Tokens are joined with single spaces, every {@code " ##"} sequence is removed so that
     * continuation pieces attach to the preceding token, and the result is trimmed.
     *
     * @param tokens the tokens to join
     * @return the reconstructed sentence
     */
    public String buildSentence(List<String> tokens) {
        final String spaced = String.join(" ", tokens);
        final String merged = spaced.replace(" ##", "");
        return merged.trim();
    }

    /**
     * Releases the native tokenizer.
     *
     * <p>The handle is swapped out atomically, so only the first call deletes the native
     * object; later calls find {@code null} and do nothing.
     */
    public void close() {
        final Long nativePointer = handle.getAndSet(null);
        if (nativePointer == null) {
            return;
        }
        TokenizersLibrary.LIB.deleteTokenizer(nativePointer);
    }

    /**
     * Encodes a single text.
     *
     * @param text the input text
     * @param addSpecialTokens flag forwarded to the native {@code encode} call
     * @return the resulting encoding
     */
    public Encoding encode(String text, boolean addSpecialTokens) {
        final long nativeEncoding = TokenizersLibrary.LIB.encode(getHandle(), text, addSpecialTokens);
        return toEncoding(nativeEncoding);
    }

    /**
     * Encodes a single text using this tokenizer's configured {@code addSpecialTokens} default.
     *
     * @param text the input text
     * @return the resulting encoding
     */
    public Encoding encode(String text) {
        return encode(text, addSpecialTokens);
    }

    /**
     * Encodes a pair of texts into a single encoding.
     *
     * @param text the first text
     * @param textPair the second text
     * @param addSpecialTokens flag forwarded to the native {@code encodeDual} call
     * @return the resulting encoding
     */
    public Encoding encode(String text, String textPair, boolean addSpecialTokens) {
        final long nativeEncoding =
                TokenizersLibrary.LIB.encodeDual(getHandle(), text, textPair, addSpecialTokens);
        return toEncoding(nativeEncoding);
    }

    /**
     * Encodes a pair of texts using this tokenizer's configured {@code addSpecialTokens} default.
     *
     * @param text the first text
     * @param textPair the second text
     * @return the resulting encoding
     */
    public Encoding encode(String text, String textPair) {
        return encode(text, textPair, addSpecialTokens);
    }

    /**
     * Encodes a list of strings into one encoding by delegating to the array overload.
     *
     * @param inputs the input strings
     * @param addSpecialTokens flag forwarded to the native library
     * @return the resulting encoding
     */
    public Encoding encode(List<String> inputs, boolean addSpecialTokens) {
        return encode(inputs.toArray(Utils.EMPTY_ARRAY), addSpecialTokens);
    }

    /**
     * Encodes a list of strings using the configured {@code addSpecialTokens} default.
     *
     * @param inputs the input strings
     * @return the resulting encoding
     */
    public Encoding encode(List<String> inputs) {
        return encode(inputs, addSpecialTokens);
    }

    /**
     * Encodes an array of strings into one encoding.
     *
     * @param inputs the input strings
     * @param addSpecialTokens flag forwarded to the native {@code encodeList} call
     * @return the resulting encoding
     */
    public Encoding encode(String[] inputs, boolean addSpecialTokens) {
        final long nativeEncoding =
                TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens);
        return toEncoding(nativeEncoding);
    }

    /**
     * Encodes an array of strings using the configured {@code addSpecialTokens} default.
     *
     * @param inputs the input strings
     * @return the resulting encoding
     */
    public Encoding encode(String[] inputs) {
        return encode(inputs, addSpecialTokens);
    }

    /**
     * Encodes each string of the list separately by delegating to the array overload.
     *
     * @param inputs the batch of input strings
     * @param addSpecialTokens flag forwarded to the native library
     * @return one encoding per handle returned by the native call
     */
    public Encoding[] batchEncode(List<String> inputs, boolean addSpecialTokens) {
        return batchEncode(inputs.toArray(Utils.EMPTY_ARRAY), addSpecialTokens);
    }

    /**
     * Batch-encodes a list using the configured {@code addSpecialTokens} default.
     *
     * @param inputs the batch of input strings
     * @return one encoding per handle returned by the native call
     */
    public Encoding[] batchEncode(List<String> inputs) {
        return batchEncode(inputs, addSpecialTokens);
    }

    /**
     * Encodes each string of the array separately.
     *
     * @param inputs the batch of input strings
     * @param addSpecialTokens flag forwarded to the native {@code batchEncode} call
     * @return one encoding per handle returned by the native call, in the same order
     */
    public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) {
        final long[] nativeEncodings =
                TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens);
        // Sequential stream: handles are converted (and released) in order.
        return Arrays.stream(nativeEncodings)
                .mapToObj(this::toEncoding)
                .toArray(Encoding[]::new);
    }

    /**
     * Batch-encodes an array using the configured {@code addSpecialTokens} default.
     *
     * @param inputs the batch of input strings
     * @return one encoding per handle returned by the native call
     */
    public Encoding[] batchEncode(String[] inputs) {
        return batchEncode(inputs, addSpecialTokens);
    }

    /**
     * Encodes a batch of text pairs.
     *
     * <p>The keys of {@code inputs} are passed as the first texts and the values as the second
     * texts of the native {@code batchEncodePair} call.
     *
     * @param inputs the text pairs
     * @param addSpecialTokens flag forwarded to the native library
     * @return one encoding per handle returned by the native call, in the same order
     */
    public Encoding[] batchEncode(PairList<String, String> inputs, boolean addSpecialTokens) {
        final String[] first = inputs.keyArray(Utils.EMPTY_ARRAY);
        final String[] second = inputs.valueArray(Utils.EMPTY_ARRAY);
        final long[] nativeEncodings =
                TokenizersLibrary.LIB.batchEncodePair(getHandle(), first, second, addSpecialTokens);

        final Encoding[] result = new Encoding[nativeEncodings.length];
        int index = 0;
        for (long nativeEncoding : nativeEncodings) {
            result[index++] = toEncoding(nativeEncoding);
        }
        return result;
    }

    /**
     * Batch-encodes text pairs using the configured {@code addSpecialTokens} default.
     *
     * @param inputs the text pairs
     * @return one encoding per handle returned by the native call
     */
    public Encoding[] batchEncode(PairList<String, String> inputs) {
        return batchEncode(inputs, addSpecialTokens);
    }

    /**
     * Decodes token ids back into text.
     *
     * @param ids the token ids
     * @param skipSpecialTokens flag forwarded to the native {@code decode} call
     * @return the decoded text
     */
    public String decode(long[] ids, boolean skipSpecialTokens) {
        final String text = TokenizersLibrary.LIB.decode(getHandle(), ids, skipSpecialTokens);
        return text;
    }

    /**
     * Decodes token ids, skipping special tokens exactly when this tokenizer is configured
     * not to add them.
     *
     * @param ids the token ids
     * @return the decoded text
     */
    public String decode(long[] ids) {
        final boolean skipSpecialTokens = !addSpecialTokens;
        return decode(ids, skipSpecialTokens);
    }

    /**
     * Decodes a batch of token-id sequences.
     *
     * @param batchIds one id array per sequence
     * @param skipSpecialTokens flag forwarded to the native {@code batchDecode} call
     * @return the decoded texts as returned by the native library
     */
    public String[] batchDecode(long[][] batchIds, boolean skipSpecialTokens) {
        final String[] texts =
                TokenizersLibrary.LIB.batchDecode(getHandle(), batchIds, skipSpecialTokens);
        return texts;
    }

    /**
     * Decodes a batch, skipping special tokens exactly when this tokenizer is configured not
     * to add them.
     *
     * @param batchIds one id array per sequence
     * @return the decoded texts as returned by the native library
     */
    public String[] batchDecode(long[][] batchIds) {
        final boolean skipSpecialTokens = !addSpecialTokens;
        return batchDecode(batchIds, skipSpecialTokens);
    }

    /**
     * Turns on padding and truncation if either is currently disabled.
     *
     * <p>{@code DO_NOT_PAD} becomes {@code LONGEST} and {@code DO_NOT_TRUNCATE} becomes
     * {@code LONGEST_FIRST}. The native tokenizer is only reconfigured when at least one of
     * the two settings actually changed.
     */
    public void enableBatch() {
        final boolean paddingDisabled = padding == PaddingStrategy.DO_NOT_PAD;
        final boolean truncationDisabled = truncation == TruncationStrategy.DO_NOT_TRUNCATE;
        if (!paddingDisabled && !truncationDisabled) {
            return;
        }

        if (paddingDisabled) {
            padding = PaddingStrategy.LONGEST;
        }
        if (truncationDisabled) {
            truncation = TruncationStrategy.LONGEST_FIRST;
        }
        updateTruncationAndPadding();
    }

    /**
     * Creates an empty builder.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder pre-populated from the given arguments.
     *
     * @param arguments option names mapped to values; each value is stored via {@code toString()}
     * @return a new, configured {@link Builder}
     */
    public static Builder builder(Map<String, ?> arguments) {
        final Builder configured = builder();
        configured.configure(arguments);
        return configured;
    }

    /**
     * Reconciles {@code maxLength} with {@code modelMaxLength} and {@code padToMultipleOf}, then
     * pushes the truncation and padding settings to the native tokenizer.
     *
     * <p>The length is only adjusted when it is actually used, i.e. when truncation is enabled
     * or padding is {@code MAX_LENGTH}:
     * <ul>
     *   <li>{@code -1} (not specified) is replaced by {@code modelMaxLength};</li>
     *   <li>values above {@code modelMaxLength} are capped to it;</li>
     *   <li>with {@code MAX_LENGTH} padding plus truncation and a non-zero
     *       {@code padToMultipleOf}, the length is rounded up to the next multiple, or down by
     *       one multiple if rounding up would exceed {@code modelMaxLength}.</li>
     * </ul>
     */
    private void updateTruncationAndPadding() {
        final boolean isTruncate = truncation != TruncationStrategy.DO_NOT_TRUNCATE;
        final boolean padToMax = padding == PaddingStrategy.MAX_LENGTH;

        if (padToMax || isTruncate) {
            if (maxLength == -1) {
                logger.warn("maxLength is not explicitly specified, use modelMaxLength: {}", modelMaxLength);
                maxLength = modelMaxLength;
            } else if (maxLength > modelMaxLength) {
                logger.warn("maxLength is greater then modelMaxLength, change to: {}", modelMaxLength);
                maxLength = modelMaxLength;
            }

            if (padToMax && isTruncate && padToMultipleOf != 0) {
                // Computed once and reused for both the check and the rounding.
                final int remainder = maxLength % padToMultipleOf;
                if (remainder != 0) {
                    int newMaxLength = maxLength + padToMultipleOf - remainder;
                    if (newMaxLength > modelMaxLength) {
                        // Rounding up overshoots the model limit: fall back to the multiple below.
                        newMaxLength -= padToMultipleOf;
                    }
                    logger.warn(
                            "maxLength ({}) is not a multiple of padToMultipleOf ({}), change to: {}",
                            maxLength,
                            padToMultipleOf,
                            newMaxLength);
                    maxLength = newMaxLength;
                }
            }
        }

        if (isTruncate) {
            TokenizersLibrary.LIB.setTruncation(getHandle(), maxLength, truncation.name(), stride);
        } else {
            TokenizersLibrary.LIB.disableTruncation(getHandle());
        }

        if (padding == PaddingStrategy.DO_NOT_PAD) {
            TokenizersLibrary.LIB.disablePadding(getHandle());
        } else {
            TokenizersLibrary.LIB.setPadding(getHandle(), maxLength, padding.name(), padToMultipleOf);
        }
    }

    /**
     * Copies a native encoding into a Java {@link Encoding} and releases the native object.
     *
     * <p>Overflowing encodings reported by the native library are converted recursively; each
     * recursive call releases its own handle. The handle passed in is deleted in a
     * {@code finally} block so it is released even if one of the native getters or the
     * recursion throws.
     *
     * @param encoding the native encoding handle; it must not be used after this call
     * @return the Java-side copy of the encoding
     */
    private Encoding toEncoding(long encoding) {
        try {
            long[] ids = TokenizersLibrary.LIB.getTokenIds(encoding);
            long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding);
            String[] tokens = TokenizersLibrary.LIB.getTokens(encoding);
            long[] wordIds = TokenizersLibrary.LIB.getWordIds(encoding);
            long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding);
            long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding);
            CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding);

            long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);
            Encoding[] overflowing = new Encoding[overflowingHandles.length];
            for (int i = 0; i < overflowingHandles.length; ++i) {
                overflowing[i] = toEncoding(overflowingHandles[i]);
            }

            return new Encoding(ids, typeIds, tokens, wordIds, attentionMask, specialTokenMask, charSpans, overflowing);
        } finally {
            // Always release the native encoding, including on failure paths.
            TokenizersLibrary.LIB.deleteEncoding(encoding);
        }
    }

    /**
     * Safety net that releases the native tokenizer if the instance was never closed.
     *
     * <p>{@code super.finalize()} runs in a {@code finally} block so the superclass finalizer
     * is not skipped when {@link #close()} throws.
     *
     * @throws Throwable whatever {@code close()} or the superclass finalizer throws
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            close();
        } finally {
            super.finalize();
        }
    }

    /**
     * Fluent builder for {@link HuggingFaceTokenizer}.
     *
     * <p>All settings except the tokenizer path and the manager are collected as string options
     * and handed to {@code HuggingFaceTokenizer.newInstance} by {@link #build()}.
     */
    public static final class Builder {
        private final Map<String, String> options = new ConcurrentHashMap<>();
        private Path tokenizerPath;
        private NDManager manager;

        /** Creates a builder whose {@code addSpecialTokens} option starts out as {@code "true"}. */
        Builder() {
            options.put("addSpecialTokens", "true");
        }

        /**
         * Sets the manager the built tokenizer is attached to.
         *
         * @param manager the manager; when left unset the tokenizer is not attached to anything
         * @return this builder
         */
        public Builder optManager(NDManager manager) {
            this.manager = manager;
            return this;
        }

        /**
         * Sets the tokenizer identifier; when present it takes precedence over the path.
         *
         * @param tokenizerName the identifier stored under the {@code "tokenizer"} option
         * @return this builder
         */
        public Builder optTokenizerName(String tokenizerName) {
            options.put("tokenizer", tokenizerName);
            return this;
        }

        /**
         * Sets the tokenizer file or directory, used when no tokenizer name is configured.
         *
         * @param tokenizerPath the path to load the tokenizer from
         * @return this builder
         */
        public Builder optTokenizerPath(Path tokenizerPath) {
            this.tokenizerPath = tokenizerPath;
            return this;
        }

        /**
         * Sets the {@code "addSpecialTokens"} option.
         *
         * @param addSpecialTokens the value to store
         * @return this builder
         */
        public Builder optAddSpecialTokens(boolean addSpecialTokens) {
            options.put("addSpecialTokens", Boolean.toString(addSpecialTokens));
            return this;
        }

        /**
         * Sets the {@code "truncation"} option to {@code "true"} or {@code "false"}.
         *
         * @param enabled whether truncation is requested
         * @return this builder
         */
        public Builder optTruncation(boolean enabled) {
            options.put("truncation", Boolean.toString(enabled));
            return this;
        }

        /**
         * Sets the {@code "truncation"} option to {@code ONLY_FIRST}.
         *
         * @return this builder
         */
        public Builder optTruncateFirstOnly() {
            options.put("truncation", TruncationStrategy.ONLY_FIRST.name());
            return this;
        }

        /**
         * Sets the {@code "truncation"} option to {@code ONLY_SECOND}.
         *
         * @return this builder
         */
        public Builder optTruncateSecondOnly() {
            options.put("truncation", TruncationStrategy.ONLY_SECOND.name());
            return this;
        }

        /**
         * Sets the {@code "padding"} option to {@code "true"} or {@code "false"}.
         *
         * @param enabled whether padding is requested
         * @return this builder
         */
        public Builder optPadding(boolean enabled) {
            options.put("padding", Boolean.toString(enabled));
            return this;
        }

        /**
         * Sets the {@code "padding"} option to {@code MAX_LENGTH}.
         *
         * @return this builder
         */
        public Builder optPadToMaxLength() {
            options.put("padding", PaddingStrategy.MAX_LENGTH.name());
            return this;
        }

        /**
         * Sets the {@code "maxLength"} option.
         *
         * @param maxLength the maximum length
         * @return this builder
         */
        public Builder optMaxLength(int maxLength) {
            options.put("maxLength", Integer.toString(maxLength));
            return this;
        }

        /**
         * Sets the {@code "padToMultipleOf"} option.
         *
         * @param padToMultipleOf the padding multiple
         * @return this builder
         */
        public Builder optPadToMultipleOf(int padToMultipleOf) {
            options.put("padToMultipleOf", Integer.toString(padToMultipleOf));
            return this;
        }

        /**
         * Sets the {@code "stride"} option.
         *
         * @param stride the stride value
         * @return this builder
         */
        public Builder optStride(int stride) {
            options.put("stride", Integer.toString(stride));
            return this;
        }

        /**
         * Copies every entry of {@code arguments} into the options, converting values with
         * {@code toString()}. Existing keys are overwritten.
         *
         * @param arguments the arguments to copy; values must not be {@code null}
         */
        public void configure(Map<String, ?> arguments) {
            arguments.forEach((key, value) -> options.put(key, value.toString()));
        }

        /**
         * Attaches the tokenizer to the configured manager, if there is one.
         *
         * @param tokenizer the freshly created tokenizer
         * @return the same tokenizer, for chaining
         */
        private HuggingFaceTokenizer managed(HuggingFaceTokenizer tokenizer) {
            if (manager == null) {
                return tokenizer;
            }
            manager.attachInternal(tokenizer.getUid(), tokenizer);
            return tokenizer;
        }

        /**
         * Builds the tokenizer from the configured name or, failing that, the configured path.
         *
         * @return the new tokenizer
         * @throws IOException if loading from the path fails
         * @throws IllegalArgumentException if neither a tokenizer name nor a path was set
         */
        public HuggingFaceTokenizer build() throws IOException {
            final String name = options.get("tokenizer");
            final HuggingFaceTokenizer tokenizer;
            if (name != null) {
                tokenizer = HuggingFaceTokenizer.newInstance(name, options);
            } else if (tokenizerPath != null) {
                tokenizer = HuggingFaceTokenizer.newInstance(tokenizerPath, options);
            } else {
                throw new IllegalArgumentException("Missing tokenizer path.");
            }
            return managed(tokenizer);
        }
    }

    /** Padding modes understood by the tokenizer. */
    private static enum PaddingStrategy {
        LONGEST,
        MAX_LENGTH,
        DO_NOT_PAD;

        /**
         * Parses an option value into a strategy.
         *
         * <p>{@code "true"} maps to {@link #LONGEST} and {@code "false"} to {@link #DO_NOT_PAD};
         * any other value must equal a constant name. All comparisons ignore case, so
         * {@code "True"} is accepted just like {@code "max_length"}.
         *
         * @param value the option value, may be {@code null}
         * @return the matching strategy
         * @throws IllegalArgumentException if the value matches nothing (including {@code null})
         */
        static PaddingStrategy fromValue(String value) {
            if ("true".equalsIgnoreCase(value)) {
                return LONGEST;
            }
            if ("false".equalsIgnoreCase(value)) {
                return DO_NOT_PAD;
            }
            for (PaddingStrategy strategy : PaddingStrategy.values()) {
                if (strategy.name().equalsIgnoreCase(value)) {
                    return strategy;
                }
            }
            throw new IllegalArgumentException("Invalid PaddingStrategy: " + value);
        }
    }

    /** Truncation modes understood by the tokenizer. */
    private static enum TruncationStrategy {
        LONGEST_FIRST,
        ONLY_FIRST,
        ONLY_SECOND,
        DO_NOT_TRUNCATE;

        /**
         * Parses an option value into a strategy.
         *
         * <p>{@code "true"} maps to {@link #LONGEST_FIRST} and {@code "false"} to
         * {@link #DO_NOT_TRUNCATE}; any other value must equal a constant name. All comparisons
         * ignore case, so {@code "True"} is accepted just like {@code "only_first"}.
         *
         * @param value the option value, may be {@code null}
         * @return the matching strategy
         * @throws IllegalArgumentException if the value matches nothing (including {@code null})
         */
        static TruncationStrategy fromValue(String value) {
            if ("true".equalsIgnoreCase(value)) {
                return LONGEST_FIRST;
            }
            if ("false".equalsIgnoreCase(value)) {
                return DO_NOT_TRUNCATE;
            }
            for (TruncationStrategy strategy : TruncationStrategy.values()) {
                if (strategy.name().equalsIgnoreCase(value)) {
                    return strategy;
                }
            }
            throw new IllegalArgumentException("Invalid TruncationStrategy: " + value);
        }
    }
}

