#######################################################################
#
# Package Name: SeqArray
#
# Description: A list of units of selected variants
#


#######################################################################
# Filter out the unit variants according to MAF, MAC and missing rates
#
seqUnitFilterCond <- function(gdsfile, units, maf=NaN, mac=1L, missing.rate=NaN,
    minsize=1L, parallel=seqGetParallel(), verbose=TRUE)
{
    # Filter the variants within each unit by MAF, MAC and missing rate,
    # then drop the units that keep fewer than 'minsize' variants.
    #   maf, mac: one value (lower bound, >=) or two values (half-open range
    #             [lower, upper)); an NA/NaN entry disables that bound
    #   missing.rate: upper bound (<=), or NA/NaN for no missing-rate filter
    #   minsize: minimum number of variants a unit must keep
    #   parallel: passed to the MAF/MAC/missing-rate calculation
    # A variant whose statistic is NA fails any active bound on it.
    # Returns 'units' unchanged when no variant is excluded, otherwise the
    # filtered (and possibly subset) unit list.

    # check
    stopifnot(inherits(gdsfile, "SeqVarGDSClass"))
    stopifnot(inherits(units, "SeqUnitListClass"))
    stopifnot(is.numeric(maf), length(maf) %in% 1:2)
    stopifnot(is.numeric(mac), length(mac) %in% 1:2)
    stopifnot(is.numeric(missing.rate), length(missing.rate)==1L)
    stopifnot(is.numeric(minsize), length(minsize)==1L, minsize>=0L)
    stopifnot(is.logical(verbose), length(verbose)==1L)

    # save state, and restrict the filter to the variants used by the units
    seqSetFilter(gdsfile, variant.sel=unlist(units$index), action="push+set",
        warn=FALSE, verbose=FALSE)
    on.exit({ seqSetFilter(gdsfile, action="pop", verbose=FALSE) })
    if (verbose)
    {
        dm <- .seldim(gdsfile)
        .cat("Dataset: ", .pretty(dm[2L]), " sample", .plural(dm[2L]),
            " x ", .pretty(dm[3L]), " variant", .plural(dm[3L]))
    }

    # calculate # of ref. allele and missing genotype
    if (verbose)
        cat("Calculating MAF, MAC and missing rates ...\n")
    # get MAF/MAC/missing rate of the currently selected variants
    v <- .Get_MAF_MAC_Missing(gdsfile, parallel, verbose)
    # show maf, mac and missing rate
    if (verbose)
    {
        if (length(maf) == 1L)
            smaf <- paste0(">=", maf)
        else
            smaf <- paste0("[", maf[1L], ",", maf[2L], ")")
        if (length(mac) == 1L)
            smac <- paste0(">=", mac)
        else
            smac <- paste0("[", mac[1L], ",", mac[2L], ")")
        smiss <- if (is.na(missing.rate)) "no" else paste0("<=", missing.rate)
        .cat("[Filter] MAF: ", smaf, "; MAC: ", smac, "; missing rate: ", smiss)
    }

    # selection (a length-1 bound gives NA at index 2, i.e. no upper bound)
    sel <- rep(TRUE, length(v$maf))
    # check mac[1] <= ... < mac[2]
    if (!is.na(mac[1L]))
        sel <- sel & (mac[1L] <= v$mac)
    if (!is.na(mac[2L]))
        sel <- sel & (v$mac < mac[2L])
    # check maf[1] <= ... < maf[2]
    if (!is.na(maf[1L]))
        sel <- sel & (maf[1L] <= v$maf)
    if (!is.na(maf[2L]))
        sel <- sel & (v$maf < maf[2L])
    # check ... <= missing.rate
    if (!is.na(missing.rate))
        sel <- sel & (v$miss <= missing.rate)
    # an NA statistic cannot satisfy a bound; without this, all(sel) could
    # be NA (making if() fail) and units could receive NA indices
    sel[is.na(sel)] <- FALSE

    if (all(sel))
    {
        if (verbose) cat("No variant excluded!\n")
        return(units)
    }
    if (verbose)
    {
        n <- sum(!sel)
        .cat("Excluding ", .pretty(n), " variant", .plural(n), " ...")
    }

    # global sel according to all variants
    x <- rep(FALSE, .dim(gdsfile)[3L])
    x[seqGetData(gdsfile, "$variant_index")] <- sel
    sel <- x

    # for each unit, keep only the variants passing the filter
    idx <- lapply(units$index, function(ii) {
        s <- sel[ii]
        if (all(s, na.rm=TRUE)) return(ii)
        ii[s]
    })
    units$index <- idx
    # check unit size
    x <- lengths(idx) >= minsize
    if (!all(x))
    {
        if (verbose)
        {
            n <- length(idx) - sum(x)
            .cat("Remove ", .pretty(n), " unit", .plural(n))
        }
        units <- seqUnitSubset(units, x)
    }
    units
}


#######################################################################
# Get a list of units of selected variants via sliding windows based on basepairs
#
seqUnitSlidingWindows <- function(gdsfile, win.size=5000L, win.shift=2500L,
    win.start=0L, dup.rm=TRUE, verbose=TRUE)
{
    # Build a unit list from sliding windows over basepair positions,
    # chromosome by chromosome, within the current variant selection.
    #   win.size: window length in basepairs (end = start + win.size - 1)
    #   win.shift: step between consecutive window starts
    #   win.start: starting position offset passed to the C routine
    #   dup.rm: passed to the C routine (presumably drops duplicated
    #           windows -- TODO confirm against SEQ_Unit_SlidingWindows)
    # Returns a SeqUnitListClass object whose 'desp' has the columns
    # chr/start/end and whose 'index' elements are named "chr<chromosome>".

    # check
    stopifnot(inherits(gdsfile, "SeqVarGDSClass"))
    stopifnot(is.numeric(win.size), length(win.size)==1L, win.size>0L)
    stopifnot(is.numeric(win.shift), length(win.shift)==1L, win.shift>0L)
    stopifnot(is.numeric(win.start), is.finite(win.start), length(win.start)==1L)
    stopifnot(is.logical(dup.rm), length(dup.rm)==1L)
    stopifnot(is.logical(verbose), length(verbose)==1L)

    # save state
    seqSetFilter(gdsfile, action="push", verbose=FALSE)
    on.exit({ seqSetFilter(gdsfile, action="pop", verbose=FALSE) })

    # chromosome list
    chrlst <- unique(seqGetData(gdsfile, "chromosome"))
    if (length(chrlst) <= 0L) stop("No selected variant!")
    ans_tab <- ans_idx <- NULL
    for (chr in chrlst)
    {
        if (verbose)
            cat("Chromosome ", chr, ", ", sep="")
        seqSetFilterChrom(gdsfile, include=chr, intersect=TRUE, verbose=FALSE)
        idx <- which(seqGetFilter(gdsfile)$variant.sel)
        pos <- seqGetData(gdsfile, "position")
        # the window routine scans positions in order, so sort them (and
        # their variant indices) only when they are not already sorted
        if (is.unsorted(pos))
        {
            i <- order(pos)
            pos <- pos[i]; idx <- idx[i]
        }
        # generated by sliding windows
        v <- .Call(SEQ_Unit_SlidingWindows, pos, idx, win.size, win.shift, win.start,
            dup.rm, integer(length(pos)))
        names(v[[2L]]) <- rep(paste0("chr", chr), length(v[[2L]]))
        ans_idx <- c(ans_idx, v[[2L]])
        v <- v[[1L]]
        ans_tab <- rbind(ans_tab, data.frame(
            chr=rep(chr, length(v)), start=v, end=as.integer(v+win.size-1L),
            stringsAsFactors=FALSE))
        if (verbose)
            cat("# of units: ", length(v), "\n", sep="")
        # reset to the saved filter before the next chromosome
        seqSetFilter(gdsfile, action="pop", verbose=FALSE)
        seqSetFilter(gdsfile, action="push", verbose=FALSE)
    }
    if (verbose)
        cat("# of units in total: ", length(ans_idx), "\n", sep="")

    # output
    ans <- list(desp=ans_tab, index=ans_idx)
    class(ans) <- "SeqUnitListClass"
    ans
}


#######################################################################
# Create, subset and merge the units
#
seqUnitCreate <- function(idx, desp=NULL)
{
    # Create a SeqUnitListClass object from a list of variant index vectors.
    #   idx: a list; each element must be a numeric vector of variant indices.
    #        NA and values < 1 are dropped, and the indices are stored as
    #        integers (names are kept) so that seqUnitApply() accepts them
    #   desp: optional data frame describing the units (one row per unit);
    #         defaults to data.frame(id=seq_along(idx))
    # Returns list(desp=, index=) with class "SeqUnitListClass".

    # check
    stopifnot(is.list(idx))
    stopifnot(is.null(desp) || is.data.frame(desp))
    if (is.data.frame(desp))
        stopifnot(length(idx) == nrow(desp))
    if (is.null(desp))
        desp <- data.frame(id=seq_along(idx))
    for (i in seq_along(idx))
    {
        k <- idx[[i]]
        if (!is.numeric(k) || !is.vector(k))
            stop(sprintf("idx[[%d]] should be a numeric index vector.", i))
        # drop missing and non-positive indices
        if (anyNA(k) || any(k < 1L, na.rm=TRUE))
            k <- k[!is.na(k) & (k >= 1L)]
        # seqUnitApply() requires integer index vectors;
        # storage.mode<- keeps the names, unlike as.integer()
        if (!is.integer(k))
            storage.mode(k) <- "integer"
        idx[[i]] <- k
    }

    # output
    ans <- list(desp=desp, index=idx)
    class(ans) <- "SeqUnitListClass"
    ans
}

seqUnitSubset <- function(units, i)
{
    # Subset a SeqUnitListClass object by unit.
    #   i: a logical vector with one entry per unit (NA is treated as FALSE),
    #      or a numeric vector of unit positions in [1, number of units]
    # Returns the unit list with 'desp' rows and 'index' elements selected;
    # the row names of 'desp' are reset.

    # check
    stopifnot(inherits(units, "SeqUnitListClass"))
    stopifnot(is.vector(i), is.numeric(i) | is.logical(i))
    n <- length(units$index)
    if (is.logical(i))
    {
        if (n != length(i))
            stop("'i' should be a logical vector of length ", n, ".")
        i[is.na(i)] <- FALSE
    } else {
        x <- (1L <= i) & (i <= n)
        if (anyNA(x))
            stop("'i' should not have NA.")
        if (!all(x))
            stop("'i' should be between 1 and ", n, ".")
    }

    # output; drop=FALSE keeps a one-column 'desp' as a data frame with its
    # original column type instead of collapsing it to a vector
    v <- units$desp[i, , drop=FALSE]
    rownames(v) <- NULL
    units$desp <- v
    units$index <- units$index[i]
    units
}

seqUnitMerge <- function(ut1, ut2)
{
    # Concatenate two unit lists: the units of 'ut2' are appended after
    # those of 'ut1'. Returns an object of the same class as 'ut1'.
    stopifnot(inherits(ut1, "SeqUnitListClass"),
        inherits(ut2, "SeqUnitListClass"))

    merged <- ut1
    # stack the description tables row-wise, then append the index lists
    merged$desp <- rbind(ut1$desp, ut2$desp)
    merged$index <- append(ut1$index, ut2$index)
    merged
}


#######################################################################
# Apply a user-defined function to each unit
#
seqUnitApply <- function(gdsfile, units, var.name, FUN,
    as.is=c("none", "list", "unlist"), parallel=FALSE, ..., .bl_size=256L,
    .progress=FALSE, .useraw=FALSE, .padNA=TRUE, .tolist=FALSE, .envir=NULL)
{
    # Apply a user-defined function to each unit: for every element of
    # units$index the variant filter is set to that unit, the data of
    # 'var.name' is read with seqGetData(), and FUN(data, ...) is evaluated.
    #   as.is: "none" returns invisible NULL; "list" returns a list with one
    #          element per unit (NULL where FUN returned NULL); "unlist"
    #          flattens that list by one level
    #   parallel: worker specification; numeric/logical values make a new
    #             cluster, otherwise it is used as an existing cluster
    #   .bl_size: number of units sent to a worker per request
    #   .useraw, .padNA, .tolist, .envir: passed to seqGetData()

    # check
    stopifnot(inherits(gdsfile, "SeqVarGDSClass"))
    stopifnot(inherits(units, "SeqUnitListClass"))
    stopifnot(is.character(var.name), length(var.name)>0L)
    FUN <- match.fun(FUN)
    stopifnot(length(units) > 0L)
    as.is <- match.arg(as.is)
    stopifnot(is.numeric(.bl_size), length(.bl_size)==1L, .bl_size>0L)
    stopifnot(is.logical(.progress), length(.progress)==1L)
    stopifnot(is.logical(.useraw), length(.useraw)==1L)
    stopifnot(is.logical(.padNA), length(.padNA)==1L)
    stopifnot(is.null(.envir) || is.environment(.envir) || is.list(.envir))

    # further check units
    stopifnot(is.data.frame(units$desp))
    stopifnot(is.list(units$index))
    stopifnot(nrow(units$desp) == length(units$index))
    # vapply() keeps this a logical vector even for an empty unit list
    # (sapply() returned list() there, and all(list()) is an error)
    stopifnot(all(vapply(units$index, is.integer, TRUE)))

    # no unit: return an empty result without touching the file or
    # starting any worker
    if (length(units$index) == 0L)
    {
        if (as.is == "none") return(invisible())
        ans <- list()
        if (as.is == "unlist") ans <- unlist(ans, recursive=FALSE)
        return(ans)
    }

    # initialize internally
    .clear_varmap(gdsfile)
    .Call(SEQ_IntAssign, process_index, 1L)
    .Call(SEQ_IntAssign, process_count, 1L)

    # get the number of workers
    njobs <- .NumParallel(parallel)
    parallel <- .McoreParallel(parallel)
    if (njobs == 1L)
    {
        # save state
        seqSetFilter(gdsfile, action="push", verbose=FALSE)
        on.exit({ seqSetFilter(gdsfile, action="pop", verbose=FALSE) })
        # progress information
        nl <- length(units$index)
        progress <- if (.progress) .seqProgress(nl, njobs) else NULL
        # for-loop
        ans <- vector("list", nl)
        for (i in seq_len(nl))
        {
            seqSetFilter(gdsfile, variant.sel=units$index[[i]], verbose=FALSE)
            x <- seqGetData(gdsfile, var.name, .useraw, .padNA, .tolist, .envir)
            v <- FUN(x, ...)
            if (!is.null(v)) ans[[i]] <- v
            .seqProgForward(progress, 1L)
        }
        # finalize
        remove(progress)

    } else {

        # parameters for load balancing
        nl <- length(units$index)
        .bl_size <- as.integer(.bl_size)
        if (.bl_size * njobs > nl)
        {
            .bl_size <- nl %/% njobs
            if (.bl_size <= 0L) .bl_size <- 1L
        }
        totnum <- nl %/% .bl_size
        if (nl %% .bl_size) totnum <- totnum + 1L

        # multiple processes
        if (.IsForking(parallel))
        {
            # forking
            .packageEnv$gdsfile <- gdsfile
            .packageEnv$units <- units$index
            .packageEnv$var.name <- var.name
            .packageEnv$envir <- .envir
            parallel <- parallel::makeForkCluster(njobs)
            on.exit({
                with(.packageEnv, gdsfile <- units <- var.name <- envir <- NULL)
                stopCluster(parallel)
            })
        } else {
            need_cluster <- is.numeric(parallel) || is.logical(parallel)
            if (need_cluster)
            {
                # no forking on windows
                parallel <- makeCluster(njobs)
            }
            # distribute the parameters to each node
            clusterCall(parallel, function(fn, ut, vn, ss, env) {
                f <- SeqArray::seqOpen(fn, allow.duplicate=TRUE)
                .packageEnv$gdsfile <- f
                SeqArray::seqSetFilter(f, sample.sel=ss, verbose=FALSE)
                .packageEnv$units <- ut
                .packageEnv$var.name <- vn
                .packageEnv$envir <- env
                invisible()
            }, fn=gdsfile$filename, ut=units$index, vn=var.name,
                ss=.Call(SEQ_GetSpaceSample, gdsfile), env=.envir)
            # finalize
            on.exit({
                clusterCall(parallel, function() {
                    SeqArray::seqClose(.packageEnv$gdsfile)
                    with(.packageEnv, gdsfile <- units <- var.name <- envir <- NULL)
                })
            })
            if (need_cluster)
                on.exit(stopCluster(parallel), add=TRUE)
        }
        # initialize internally
        clusterApply(parallel, 1:njobs, function(i, njobs) {
            .Call(SEQ_IntAssign, process_index, i)
            .Call(SEQ_IntAssign, process_count, njobs)
        }, njobs=njobs)

        # progress information
        progress <- if (.progress) .seqProgress(length(units$index), njobs) else NULL
        # distributed for-loop
        ans <- .DynamicClusterCall(parallel, totnum,
            .fun = function(i, FUN, .useraw, .bl_size, ...)
        {
            # chuck size
            n <- .bl_size
            k <- (i - 1L) * n
            if (k + n > length(.packageEnv$units))
                n <- length(.packageEnv$units) - k
            # temporary
            f <- .packageEnv$gdsfile
            vn <- .packageEnv$var.name
            env <- .packageEnv$envir
            rv <- vector("list", n)
            # set variant filter for each sub unit
            for (j in seq_len(n))
            {
                seqSetFilter(f, variant.sel=.packageEnv$units[[j+k]], verbose=FALSE)
                x <- seqGetData(f, vn, .useraw, .padNA, .tolist, env)
                v <- FUN(x, ...)
                if (!is.null(v)) rv[[j]] <- v
            }
            # return
            rv
        }, .combinefun="list",
            .updatefun=function(i) .seqProgForward(progress, .bl_size),
            FUN=FUN, .useraw=.useraw, .bl_size=.bl_size, ...)
        ans <- unlist(ans, recursive=FALSE)
        # finalize
        remove(progress)
    }

    # output; return(invisible()) directly, because assigning invisible()
    # to a variable and then evaluating it makes the NULL visible again
    if (as.is == "none")
        return(invisible())
    if (as.is == "unlist")
        ans <- unlist(ans, recursive=FALSE)
    ans
}


print.SeqUnitListClass <- function(x, ...)
{
    # Show the structure of a unit list, listing at most 6 elements per
    # nested list; return 'x' invisibly as print methods conventionally do
    str(x, list.len=6L)
    invisible(x)
}

summary.SeqUnitListClass <- function(object, ...)
{
    # Print summary statistics of a unit list: the number of units, the
    # number of distinct variants over all units, and the distribution of
    # unit sizes. Returns invisible NULL.
    .cat("# of units: ", .pretty(length(object$index)))
    # variants can be shared by several units (e.g., overlapping windows),
    # so count distinct variant indices, not distinct units
    .cat("# of variants in total: ",
        .pretty(length(unique(unlist(object$index, use.names=FALSE)))))
    v <- lengths(object$index, use.names=FALSE)
    # size statistics are undefined (and min/max warn) without any unit
    if (length(v) > 0L)
    {
        .cat("Avg # of variants per unit: ", mean(v))
        .cat("Median # of variants in a unit: ", median(v))
        .cat("Min # of variants in a unit: ", min(v))
        .cat("Max # of variants in a unit: ", max(v))
        .cat("SD # of variants in a unit: ", sd(v))
    }
    invisible()
}
