/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.loss.Loss;

/**
 * Masked-language-model loss for BERT-style pretraining.
 *
 * <p>The loss is the mask-weighted mean of {@code -sum(logProbs * oneHot(target))} over all
 * flattened token positions. It is a negative log-likelihood when the prediction at
 * {@code logProbsIdx} holds log-probabilities (presumably log-softmax output — confirm with the
 * producing block). The label and mask arrays are looked up by index in the label {@link NDList};
 * the predictions are looked up by index in the prediction {@link NDList}.
 */
public class BertMaskedLanguageModelLoss extends Loss {

    /**
     * Small constant that keeps the denominators of {@link #evaluate} and {@link #accuracy} away
     * from zero when the mask selects no positions.
     */
    private static final float EPSILON = 1.0E-5f;

    private final int labelIdx;
    private final int maskIdx;
    private final int logProbsIdx;

    /**
     * Creates the loss, which is registered under the name {@code "BertMLLoss"}.
     *
     * @param labelIdx index of the target token ids within the label list
     * @param maskIdx index of the position mask within the label list
     * @param logProbsIdx index of the per-position scores within the prediction list
     */
    public BertMaskedLanguageModelLoss(int labelIdx, int maskIdx, int logProbsIdx) {
        super("BertMLLoss");
        this.labelIdx = labelIdx;
        this.maskIdx = maskIdx;
        this.logProbsIdx = logProbsIdx;
    }

    /**
     * Computes the masked negative log-likelihood.
     *
     * <p>Target ids and mask are flattened; the scores array's dimension 1 is taken as the
     * dictionary size. The scores are assumed to be laid out as (flattened positions,
     * dictionarySize) so they line up with the one-hot targets — TODO confirm against the
     * producing block.
     *
     * @param labels label list containing target ids and mask
     * @param predictions prediction list containing the per-position scores
     * @return scalar loss: {@code sum(loss * mask) / (sum(mask) + EPSILON)}
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        // DJL sub-manager idiom: intermediates created below live in this scope; only the
        // result handed to scope.ret(...) is returned to the caller's manager.
        try (NDManager scope = NDManager.subManagerOf(labels)) {
            scope.tempAttachAll(labels, predictions);
            NDArray logProbs = predictions.get(logProbsIdx);
            int dictionarySize = (int) logProbs.getShape().get(1);
            NDArray targetIds = labels.get(labelIdx).flatten();
            NDArray mask = labels.get(maskIdx).flatten().toType(DataType.FLOAT32, false);
            NDArray targetOneHots = targetIds.oneHot(dictionarySize);
            // Selecting the target column via the one-hot product, then negating.
            NDArray perExampleLoss = logProbs.mul(targetOneHots).sum(new int[] {1}).mul(-1);
            NDArray numerator = perExampleLoss.mul(mask).sum();
            // EPSILON prevents 0/0 when the mask selects no positions.
            NDArray denominator = mask.sum().add(EPSILON);
            return scope.ret(numerator.div(denominator));
        }
    }

    /**
     * Computes the fraction of masked positions whose highest-scoring id equals the target id.
     *
     * <p>Predicted ids come from {@code argMax} over dimension 1 of the scores, cast to INT32.
     * The target ids are compared against them directly, so they are presumably an integer type
     * compatible with INT32 — verify for the engine in use.
     *
     * @param labels label list containing target ids and mask
     * @param predictions prediction list containing the per-position scores
     * @return scalar FLOAT32 accuracy in the mask-weighted sense; 0 (not NaN) when the mask
     *     selects no positions
     */
    public NDArray accuracy(NDList labels, NDList predictions) {
        try (NDManager scope = NDManager.subManagerOf(labels)) {
            scope.tempAttachAll(labels, predictions);
            // Work in FLOAT32 throughout, matching evaluate(); avoids boolean*int arithmetic.
            NDArray mask = labels.get(maskIdx).flatten().toType(DataType.FLOAT32, false);
            NDArray targetIds = labels.get(labelIdx).flatten();
            NDArray logProbs = predictions.get(logProbsIdx);
            NDArray predictedIds = logProbs.argMax(1).toType(DataType.INT32, false);
            NDArray hits = predictedIds.eq(targetIds).toType(DataType.FLOAT32, false).mul(mask);
            NDArray hitCount = hits.sum();
            // Clamp instead of adding EPSILON: the ratio stays exact whenever the mask sum is
            // at least EPSILON, and an empty mask yields 0 / EPSILON = 0 rather than NaN.
            NDArray count = mask.sum().maximum(EPSILON);
            return scope.ret(hitCount.div(count));
        }
    }
}

