/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import jcuda.Pointer;
import jcuda.runtime.JCuda;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.instructions.gpu.context.CSRPointer;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysds.runtime.matrix.data.LibMatrixCuDNN;

/**
 * Copies single rows of a GPU-resident input matrix into one reusable dense device buffer.
 *
 * <p>The buffer holds {@code numColumns} values and is allocated once in the constructor.
 * Every call to {@link #getNthRow(int)} overwrites it, so the pointer returned by a previous
 * call is only valid until the next call. The input may be stored sparse (CSR) or dense.
 *
 * <p>The buffer is released by {@link #close()}. Use this class with try-with-resources.
 * {@code close()} is idempotent. Once it has been called, {@link #getNthRow(int)} is rejected.
 */
public class LibMatrixCuDNNInputRowFetcher
extends LibMatrixCUDA
implements AutoCloseable {
    /** GPU context used for allocation, slicing and freeing. */
    GPUContext gCtx;
    /** Instruction name forwarded to the GPU allocation and slicing helpers. */
    String instName;
    /** Number of columns of the input matrix, i.e. the number of values in one fetched row. */
    int numColumns;
    /** True when the input matrix is held on the device in sparse format. */
    boolean isInputInSparseFormat;
    /** Device input: a {@link CSRPointer} when sparse, otherwise a dense {@link Pointer}. */
    Object inPointer;
    /** Reusable device row buffer; set to {@code null} once this fetcher has been closed. */
    Pointer outPointer;

    /**
     * Prepares the input matrix on the device and allocates the row buffer.
     *
     * @param gCtx     GPU context that owns the input and the row buffer
     * @param instName instruction name, forwarded to the GPU helpers
     * @param image    input matrix whose rows will be fetched
     */
    public LibMatrixCuDNNInputRowFetcher(GPUContext gCtx, String instName, MatrixObject image) {
        this.gCtx = gCtx;
        this.instName = instName;
        this.numColumns = LibMatrixCUDA.toInt(image.getNumColumns());
        this.isInputInSparseFormat = LibMatrixCUDA.isInSparseFormat(gCtx, image);
        this.inPointer = this.isInputInSparseFormat
                ? LibMatrixCUDA.getSparsePointer(gCtx, image, instName)
                : LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, image, instName);
        // NOTE(review): the third argument presumably controls zero-initialisation; confirm in GPUContext.allocate.
        this.outPointer = gCtx.allocate(instName, this.rowSizeInBytes(), false);
    }

    /**
     * Copies row {@code n} of the input into the shared row buffer and returns that buffer.
     *
     * <p>The returned pointer is the same buffer on every call, so its contents are replaced by
     * the next call. Row indexing is assumed to be 0-based (TODO confirm against the slice helpers).
     *
     * @param n index of the row to fetch
     * @return device pointer to a dense buffer of {@code numColumns} values holding row {@code n}
     * @throws IllegalStateException if this fetcher has already been closed
     */
    public Pointer getNthRow(int n) {
        final Pointer rowBuffer = this.outPointer;
        if (rowBuffer == null) {
            // Prevent slicing into (and handing out) a device buffer that has already been freed.
            throw new IllegalStateException("Cannot fetch row " + n + " for instruction "
                    + this.instName + ": the row fetcher has already been closed");
        }
        if (this.isInputInSparseFormat) {
            // The buffer is cleared before the sparse slice. Presumably the sparse-to-dense slice writes
            // only the stored non-zeros, so stale values from the previous row would otherwise remain.
            JCuda.cudaDeviceSynchronize();
            JCuda.cudaMemset(rowBuffer, 0, this.rowSizeInBytes());
            JCuda.cudaDeviceSynchronize();
            LibMatrixCUDA.sliceSparseDense(this.gCtx, this.instName, (CSRPointer)this.inPointer, rowBuffer, n, n, 0, LibMatrixCUDA.toInt(this.numColumns - 1), this.numColumns);
        } else {
            LibMatrixCUDA.sliceDenseDense(this.gCtx, this.instName, (Pointer)this.inPointer, rowBuffer, n, n, 0, LibMatrixCUDA.toInt(this.numColumns - 1), this.numColumns);
        }
        return rowBuffer;
    }

    /**
     * Releases the row buffer. Calling this method more than once has no further effect.
     *
     * <p>The field is cleared before freeing. A second call therefore never frees the same device
     * pointer again, even if the first free attempt threw.
     *
     * @throws RuntimeException wrapping a {@link DMLRuntimeException} raised while freeing
     */
    @Override
    public void close() {
        final Pointer toFree = this.outPointer;
        if (toFree == null) {
            return; // already closed
        }
        this.outPointer = null;
        try {
            // NOTE(review): the third argument presumably requests an eager free; confirm in GPUContext.
            this.gCtx.cudaFreeHelper(null, toFree, true);
        }
        catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
    }

    /** Size in bytes of one dense row of {@code numColumns} values, computed in long arithmetic. */
    private long rowSizeInBytes() {
        return (long)this.numColumns * (long)sizeOfDataType;
    }
}

