package youversion.red.moments.util

import kotlin.coroutines.CoroutineContext
import kotlin.math.abs
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.async
import kotlinx.coroutines.cancelChildren
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.launch
import red.platform.Log
import red.platform.LogLevel
import red.platform.threads.AtomicInt
import red.platform.threads.AtomicReference
import red.platform.threads.SuspendedLock
import red.platform.threads.decr
import red.platform.threads.freeze
import red.platform.threads.getValue
import red.platform.threads.incr
import red.platform.threads.set
import red.platform.threads.setValue
import red.platform.threads.sync
import red.tasks.CoroutineDispatchContext
import red.tasks.CoroutineDispatchers.withContext
import red.tasks.CoroutineDispatchers.withMain

/**
 * Loading state of a [PagedList], reported through [PageListener.onStateChange].
 */
enum class PagingState {
    /** Initial state; also restored when the last active fetch finishes without signalling the end of the data. */
    Idle,
    /** At least one page fetch is in progress. */
    Paging,
    /** The fetch that finished last got `null` from [PageDataSource.fetchPage], i.e. nothing more to load. */
    Done
}

/**
 * Supplies the pages shown by a [PagedList] and performs the fetches it requests.
 */
interface PageDataSource<T> {

    /** Number of items per page; used to map an item index to a page index and an offset within that page. */
    val pageSize: Int

    /**
     * Size of the page window kept by [PagedList]. Half of this value (integer division) is the
     * number of pages observed before the current page, and the window ends that many pages
     * after it (end exclusive).
     */
    val maxPageBuffer: Int

    /**
     * Consulted by [PagedList.get] for an item it is about to return; only items for which this
     * returns `true` can trigger fetching of further pages.
     */
    fun isFetchable(item: T): Boolean

    /**
     * Fetches the page at [pageIndex].
     *
     * @return `null` when there is nothing more to load (the list then moves to
     * [PagingState.Done]); any non-null value means more data may exist. The meaning of the
     * number itself is not used by [PagedList] — presumably an item count, confirm with implementors.
     */
    suspend fun fetchPage(pageIndex: Int): Int?

    /**
     * Returns a stream of the contents of the page at [pageIndex]. [PagedList] collects it and
     * treats every emission that differs from the cached page as the page's new contents.
     */
    suspend fun getPage(pageIndex: Int): Flow<List<T>>
}

/**
 * Receives notifications from a [PagedList].
 */
interface PageListener<T> {

    /**
     * Called when the known item count grows to [newSize] because of the page at [pageIndex].
     * Also called with `(0, 0)` when the list reaches [PagingState.Done] without any size having
     * been recorded.
     */
    fun onSizeChanged(pageIndex: Int, newSize: Int)

    /** Called when the page at [pageIndex] emitted contents that differ from the cached ones. */
    fun onPageChanged(pageIndex: Int, items: List<T>)

    /** Called when [PageDataSource.fetchPage] for [pageIndex] threw [exception]. */
    fun onError(pageIndex: Int, exception: Exception)

    /** Called on every state assignment, with the state before and after the assignment. */
    fun onStateChange(oldState: PagingState, newState: PagingState)
}

/**
 * Minimal [CoroutineScope] over a caller-supplied context that can be shut down via [close].
 */
private class CloseableCoroutineScope(context: CoroutineContext) : CoroutineScope {

    override val coroutineContext: CoroutineContext = context

    /**
     * Stops the work started in this scope by cancelling the children of its context.
     *
     * KJB's original note: cancelling the scope itself is probably the right call, but on iOS
     * that threw an exception that could not be caught, so only the children are cancelled.
     */
    fun close() = coroutineContext.cancelChildren()
}

/**
 * A list facade over a [PageDataSource] that keeps only a window of pages cached.
 *
 * Construction immediately starts observing the pages around page 0 and requests a fetch of
 * page 0. Reading an item through [get] moves the window when needed and may request further
 * fetches. Page contents, size changes, fetch errors and state transitions are reported to
 * [listener].
 *
 * @param dataSource supplies page contents and performs fetches.
 * @param listener receives change notifications.
 * @param coroutineContext combined with a [SupervisorJob] to form the scope of all background work.
 * @param fetchDispatchContext context in which [PageDataSource.fetchPage] is invoked.
 * @param pageDispatchContext context in which the page flows are collected.
 */
class PagedList<T>(
    private val dataSource: PageDataSource<T>,
    private val listener: PageListener<T>,
    coroutineContext: CoroutineContext,
    private val fetchDispatchContext: CoroutineDispatchContext,
    private val pageDispatchContext: CoroutineDispatchContext
) {

    private val _state = AtomicReference(PagingState.Idle)
    // -1 means "no size recorded yet"
    private val _size = AtomicInt(-1)
    // cached page contents, keyed by page index
    private val pages = AtomicReference(emptyMap<Int, List<T>>())
    // starts far enough from 0 that the first moveTo(0) always (re)builds the window
    private val currentPageIndex = AtomicInt(-dataSource.maxPageBuffer)
    // collectors for the pages currently inside the window
    private val activeFlows = AtomicReference(emptyList<Deferred<Unit>>())

    // number of fetches requested but not yet finished
    private val activeFetches = AtomicInt(0)
    // highest page index for which a fetch has been requested; -1 when none
    private val maxFetchablePageIndex = AtomicInt(-1)
    // highest item index that has already been considered as a fetch trigger in get()
    private val lastFetchableItem = AtomicInt(0)
    private val fetchLock = SuspendedLock()
    private val pageLock = SuspendedLock()
    private val scope = CloseableCoroutineScope(SupervisorJob() + coroutineContext)

    init {
        moveTo(0)
        fetchPage(0)
        freeze()
    }

    /** Current loading state. */
    val state: PagingState
        get() = _state.value

    /**
     * Number of rows to present: the largest item count recorded so far (0 when none), plus one
     * extra row while the state is not [PagingState.Done] (presumably a slot for a loading
     * placeholder — confirm with the UI code).
     */
    val size: Int
        get() {
            val s = _size.getValue().takeIf { it != -1 } ?: 0
            if (state != PagingState.Done) {
                return s + 1
            }
            return s
        }

    /** Drops the references to active collectors and cached pages, then cancels all work in the scope. */
    fun close() {
        activeFlows.set(emptyList())
        pages.set(emptyMap())
        scope.close()
    }

    /**
     * Re-centres the page window on [pageIndex] when it is at least half of
     * [PageDataSource.maxPageBuffer] pages away from the current centre.
     *
     * Re-centring evicts cached pages outside the new window, cancels all running page
     * collectors and starts collectors for every non-negative page index inside the window.
     */
    private fun moveTo(pageIndex: Int) {
        val diff = abs(pageIndex - currentPageIndex.getValue())
        // we're outside of the bounds of what we need

//        if (Log.level == LogLevel.DEBUG) {
//            Log.d("PagedList", "moveTo? $pageIndex -> ${currentPageIndex.getValue()} -> ${dataSource.maxPageBuffer} -> $diff")
//        }
        // NOTE: with maxPageBuffer < 2 this is 0, the window below is empty and no page is observed.
        val pageBuffer = dataSource.maxPageBuffer / 2
        if (diff >= pageBuffer) {
//            if (Log.level == LogLevel.DEBUG) {
//                Log.d("PagedList", "moveTo($pageIndex)")
//            }
            currentPageIndex.setValue(pageIndex)
            // first page of the window (may be negative near the start of the list)
            val startPage = pageIndex - pageBuffer
            // end of the window, exclusive
            val endPage = pageIndex + pageBuffer
            // get a mutable copy of the cached pages
            val newPages = pages.value.toMutableMap()
            // Evict every cached page outside [startPage, endPage). The map is keyed by page
            // index, so eviction has to go by key: the entry count is not an upper bound on the
            // keys (e.g. after earlier pages were dropped), and bounding the loop by it left
            // pages past the window cached forever.
            newPages.keys.retainAll { it >= startPage && it < endPage }
            // store the trimmed set of pages
            pages.set(newPages)
            // stop monitoring the previously observed pages for changes
            activeFlows.value.forEach {
                it.cancel()
            }
            // start monitoring the pages of the new window
            val newActiveFlows = mutableListOf<Deferred<Unit>>()
            for (index in startPage until endPage) {
                if (index >= 0) {
                    newActiveFlows += collectPageAsync(index)
                }
            }
            activeFlows.set(newActiveFlows.freeze())
        }
    }

    /**
     * Stores [state] and notifies the listener of the transition. When the list becomes
     * [PagingState.Done] without any size having been recorded, the size is set to 0 and
     * reported as well.
     */
    private fun setState(state: PagingState) {
        val oldState = _state.value
        _state.set(state)
        listener.onStateChange(oldState, state)
        if (state == PagingState.Done && _size.getValue() == -1) {
            _size.setValue(0)
            listener.onSizeChanged(0, 0)
        }
    }

    /**
     * Requests a fetch of [pageIndex] unless a fetch for that page or a later one was already
     * requested. The fetch runs in [fetchDispatchContext] inside [fetchLock]; errors thrown by
     * the data source are logged and forwarded to the listener.
     */
    private fun fetchPage(pageIndex: Int) {
        if (pageIndex > maxFetchablePageIndex.getValue()) {
            if (Log.level == LogLevel.DEBUG) {
                Log.d("PagedList", "fetchPage($pageIndex)")
            }
            maxFetchablePageIndex.setValue(pageIndex)
            // the first outstanding fetch switches the list to Paging
            if (activeFetches.incr() == 1) {
                setState(PagingState.Paging)
            }

            scope.launch {
                try {
                    withContext(fetchDispatchContext) {
                        fetchLock.sync {
                            // stays true when the fetch throws, so a failed fetch ends in Idle, not Done
                            var hasMore = true
                            try {
                                hasMore = dataSource.fetchPage(pageIndex) != null
                            } catch (e: Exception) {
                                Log.e("PagedList", "Error during page fetch", e)
                                e.freeze()
                                withMain {
                                    listener.onError(pageIndex, e)
                                }
                            } finally {
                                // the last outstanding fetch decides the resulting state
                                if (activeFetches.decr() == 0) {
                                    withMain {
                                        setState(if (hasMore) PagingState.Idle else PagingState.Done)
                                    }
                                }
                            }
                        }
                    }
                } catch (e: Exception) {
                    Log.e("PagedList", "Error during page fetch", e)
                }
            }
        }
    }

    /** Element-wise comparison of two page snapshots; two `null`s are equal, `null` never equals a list. */
    private fun isEqual(a: List<T>?, b: List<T>?): Boolean {
        if (a == null || b == null) {
            return a == b
        }
        if (a.size != b.size) {
            return false
        }
        for (i in a.indices) {
            if (a[i] != b[i]) {
                return false
            }
        }
        return true
    }

    /**
     * Raises the recorded item count when the non-empty [page] at [pageIndex] implies a larger
     * one, and notifies the listener. The recorded count never shrinks here.
     */
    private fun updateSize(page: List<T>, pageIndex: Int) {
        if (page.isNotEmpty()) {
            // assumes all pages before this one are full
            val newSize = (pageIndex * dataSource.pageSize) + page.size
            if (newSize > _size.getValue()) {
                _size.setValue(newSize)
                listener.onSizeChanged(pageIndex, newSize)
            }
        }
    }

    /**
     * Starts collecting the flow of the page at [pageIndex] in [pageDispatchContext]. Every
     * emission that differs from the cached page replaces it and is reported to the listener.
     *
     * @return the handle of the collector; cancel it to stop observing the page.
     */
    private fun collectPageAsync(pageIndex: Int): Deferred<Unit> =
        scope.async {
            try {
                withContext(pageDispatchContext) {
                    if (Log.level == LogLevel.DEBUG) {
                        Log.d("PagedList", "collectPageAsync($pageIndex)")
                    }
                    dataSource.getPage(pageIndex).collect { page ->
                        pageLock.sync {
                            val lastPage = pages.value[pageIndex]
                            if (!isEqual(lastPage, page)) {
                                // publish a new frozen copy of the page map containing the new contents
                                val currentPages = pages.value.toMutableMap()
                                currentPages[pageIndex] = page.freeze()
                                currentPages.freeze()
                                withMain {
                                    pages.set(currentPages)
                                    listener.onPageChanged(pageIndex, page)
                                    updateSize(page, pageIndex)
                                }
                            } else {
                                updateSize(page, pageIndex)
                            }
                        }
                    }
                }
            } catch (e: Exception) {
                // NOTE(review): this also catches the CancellationException raised when the
                // collector is cancelled, so routine cancellations end up in the error log.
                Log.e("PagedList", "Error during page collection", e)
            }
        }

    /**
     * Returns the item at [index], or `null` when its page is not cached or has no item at that
     * position.
     *
     * Side effects: may re-centre the page window on the item's page, and, when the item is
     * fetchable, lies beyond the last trigger index and sits within the last three positions of
     * its page, requests fetches up to the page after it.
     */
    operator fun get(index: Int): T? {
        if (Log.level == LogLevel.DEBUG) {
            Log.d("PagedList", "get($index)")
        }
        val pageIndex = index / dataSource.pageSize
        moveTo(pageIndex)
        val page = pages.value[pageIndex] ?: emptyList()
        val pageItemIndex = index % dataSource.pageSize
        val item = page.getOrNull(pageItemIndex)
        item?.takeIf { index > lastFetchableItem.getValue() && dataSource.isFetchable(it) }?.let {
            lastFetchableItem.setValue(index)
            if (Log.level == LogLevel.DEBUG) {
                Log.d("PagedList", "Page size: ${dataSource.pageSize}")
                Log.d("PagedList", "Last fetchable index: $index -> $pageItemIndex")
            }
            // distance from the item to the end of a full page
            val diff = dataSource.pageSize - pageItemIndex
            if (Log.level == LogLevel.DEBUG) {
                Log.d("PagedList", "Last fetchable diff: $diff")
            }
            if (diff <= 3) {
                val lastPageIndex = maxFetchablePageIndex.getValue()
                // nothing requested yet (-1): fetch page 0 only; otherwise fetch from the last
                // requested page up to the page after the current one
                val startPageIndex = lastPageIndex.takeIf { it > -1 } ?: 0
                val endPageIndex = lastPageIndex.takeIf { it == -1 } ?: pageIndex
                if (Log.level == LogLevel.DEBUG) {
                    Log.d("PagedList", "Pages: $startPageIndex -> $endPageIndex")
                }
                // fetchPage ignores indices that were already requested
                for (i in startPageIndex..endPageIndex + 1) {
                    fetchPage(i)
                }
            }
        }
        return item
    }
}
