package app.megachat.shared.base.util

import app.megachat.shared.base.data.ResultWithError
import kotlin.time.Duration
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flatMapLatest
import kotlinx.coroutines.flow.flatMapMerge
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.mapNotNull
import kotlinx.coroutines.flow.merge
import kotlinx.coroutines.flow.stateIn

/**
 * Creates a cold [Flow] that emits [Unit] up to [times] times.
 *
 * The first emission happens after waiting [after]. Later emissions are spaced [every] apart.
 * The flow completes immediately after the last emission, rather than idling for one more
 * [every] period. If [times] is zero or negative, nothing is emitted, although the initial
 * [after] delay still runs.
 *
 * @param after delay before the first emission.
 * @param every delay between consecutive emissions.
 * @param times maximum number of emissions; the default is effectively unbounded.
 */
fun tick(
  after: Duration = Duration.ZERO,
  every: Duration,
  times: Int = Int.MAX_VALUE,
): Flow<Unit> =
  flow {
    delay(after)
    repeat(times) { index ->
      // No wait before the first tick (we already waited `after`) and none after the last one.
      if (index > 0) delay(every)
      emit(Unit)
    }
  }

/**
 * Emits [value] into this [MutableSharedFlow] without suspending, via [MutableSharedFlow.tryEmit].
 *
 * @throws IllegalStateException if `tryEmit` rejects the value, i.e. the flow's buffer
 *   configuration does not allow a non-suspending emission at this moment.
 */
fun <T> MutableSharedFlow<T>.fastEmit(value: T) {
  val accepted = tryEmit(value)
  if (!accepted) {
    error("Must create MutableSharedFlow with extraBufferCapacity = 1 and onBufferOverflow != SUSPEND")
  }
}

/**
 * Suspends until this flow produces a non-null element whose [transform] result is also
 * non-null, and returns that result.
 *
 * Null elements are skipped without invoking [transform]. Like `Flow.first()`, this throws
 * [NoSuchElementException] if the flow completes before such an element appears.
 */
suspend fun <T : Any, R : Any> Flow<T?>.firstNotNull(transform: suspend (T) -> R?): R =
  mapNotNull { element -> element?.let { transform(it) } }.first()

/**
 * Returns a [StateFlow] that mirrors this one and also re-reads this flow's current `value`
 * every time [on] emits.
 *
 * Values produced by [on] are ignored; each one only acts as a trigger. Consecutive equal
 * values are still suppressed by [DerivedStateFlow].
 */
fun <T> StateFlow<T>.refresh(on: Flow<*>): StateFlow<T> {
  val source = this
  val triggered = on.map { source.value }
  return DerivedStateFlow(
    getValue = { source.value },
    flow = merge(source, triggered),
  )
}

/**
 * Maps this [StateFlow] through [transform], keeping the result a [StateFlow].
 *
 * `value` on the result calls [transform] on this flow's current value each time it is read.
 * Collectors receive [transform] applied to every value this flow emits.
 */
fun <T, R> StateFlow<T>.mapState(transform: (T) -> R): StateFlow<R> {
  val upstream = this
  return DerivedStateFlow(
    getValue = { transform(upstream.value) },
    flow = upstream.map { current -> transform(current) },
  )
}

/**
 * Switches to the [StateFlow] returned by [transform] for this flow's latest value.
 *
 * `value` on the result reads the current value of `transform(this.value)`. Collectors follow
 * only the inner flow for the most recent upstream value (`flatMapLatest`).
 */
@OptIn(ExperimentalCoroutinesApi::class)
fun <T, R> StateFlow<T>.flatMapLatestState(transform: (T) -> StateFlow<R>): StateFlow<R> {
  val upstream = this
  return DerivedStateFlow(
    getValue = { transform(upstream.value).value },
    flow = upstream.flatMapLatest { current -> transform(current) },
  )
}

/**
 * Merges the [StateFlow]s returned by [transform] for each upstream value into one [StateFlow].
 *
 * `value` on the result reads the current value of `transform(this.value)`. Collectors receive
 * emissions from every inner flow subscribed so far (`flatMapMerge`), not only the latest one.
 *
 * NOTE(review): `flatMapMerge` limits how many inner flows are collected concurrently, and
 * inner StateFlows never complete. Once that limit is reached, later upstream values may never
 * be subscribed, while `value` keeps reflecting them. Confirm this is intended.
 */
@OptIn(ExperimentalCoroutinesApi::class)
fun <T, R> StateFlow<T>.flatMapMergeState(transform: (T) -> StateFlow<R>): StateFlow<R> {
  val upstream = this
  return DerivedStateFlow(
    getValue = { transform(upstream.value).value },
    flow = upstream.flatMapMerge { current -> transform(current) },
  )
}

/**
 * Combines a list of [ResultWithError] state flows into one.
 *
 * All non-null `result`s are passed to [onSuccess] and all non-null `error`s to [onFailure].
 * Either side of the combined [ResultWithError] is null when there was nothing to pass to its
 * callback; in that case the callback is not invoked.
 */
inline fun <reified T : Any, E : Any, T2 : Any, E2 : Any> combineStates(
  flows: List<StateFlow<ResultWithError<T, E>>>,
  crossinline onSuccess: (List<T>) -> T2,
  crossinline onFailure: (List<E>) -> E2,
): StateFlow<ResultWithError<T2, E2>> =
  combineStates(flows) { results ->
    val successes = results.mapNotNull { it.result }
    val failures = results.mapNotNull { it.error }
    ResultWithError(
      result = if (successes.isEmpty()) null else onSuccess(successes),
      error = if (failures.isEmpty()) null else onFailure(failures),
    )
  }

/**
 * Combines a list of [StateFlow]s into one [StateFlow] by applying [transform] to the array of
 * their current values.
 *
 * `value` on the result is computed on demand from every source's current value.
 *
 * An empty [flows] list is supported. `value` is then `transform(emptyArray())`, and collectors
 * receive that same single value. `combine` over an empty list completes without emitting, which
 * would otherwise leave collectors of this StateFlow waiting forever for a first value.
 */
inline fun <reified T, R> combineStates(
  flows: List<StateFlow<T>>,
  crossinline transform: (Array<T>) -> R,
): StateFlow<R> =
  DerivedStateFlow(
    getValue = { transform(flows.map { it.value }.toTypedArray()) },
    flow =
      if (flows.isEmpty()) {
        // `combine` emits nothing for an empty input; emit the same value `getValue` reports.
        flow { emit(transform(emptyArray())) }
      } else {
        combine(flows, transform)
      },
  )

/**
 * Combines two [StateFlow]s into a single [StateFlow] using [transform].
 *
 * `value` on the result calls [transform] with both sources' current values. Collectors receive
 * the stream produced by `combine` over the two sources.
 */
inline fun <reified T1, reified T2, reified R> combineStates(
  flow1: StateFlow<T1>,
  flow2: StateFlow<T2>,
  noinline transform: (T1, T2) -> R,
): StateFlow<R> {
  val current = { transform(flow1.value, flow2.value) }
  val updates = combine(flow1, flow2) { a, b -> transform(a, b) }
  return DerivedStateFlow(getValue = current, flow = updates)
}

/**
 * Combines three [StateFlow]s into a single [StateFlow] using [transform].
 *
 * `value` on the result calls [transform] with all sources' current values. Collectors receive
 * the stream produced by `combine` over the three sources.
 */
inline fun <reified T1, reified T2, reified T3, reified R> combineStates(
  flow1: StateFlow<T1>,
  flow2: StateFlow<T2>,
  flow3: StateFlow<T3>,
  noinline transform: (T1, T2, T3) -> R,
): StateFlow<R> {
  val current = { transform(flow1.value, flow2.value, flow3.value) }
  val updates = combine(flow1, flow2, flow3) { a, b, c -> transform(a, b, c) }
  return DerivedStateFlow(getValue = current, flow = updates)
}

/**
 * Combines four [StateFlow]s into a single [StateFlow] using [transform].
 *
 * `value` on the result calls [transform] with all sources' current values. Collectors receive
 * the stream produced by `combine` over the four sources.
 */
inline fun <reified T1, reified T2, reified T3, reified T4, reified R> combineStates(
  flow1: StateFlow<T1>,
  flow2: StateFlow<T2>,
  flow3: StateFlow<T3>,
  flow4: StateFlow<T4>,
  noinline transform: (T1, T2, T3, T4) -> R,
): StateFlow<R> {
  val current = { transform(flow1.value, flow2.value, flow3.value, flow4.value) }
  val updates = combine(flow1, flow2, flow3, flow4) { a, b, c, d -> transform(a, b, c, d) }
  return DerivedStateFlow(getValue = current, flow = updates)
}

/**
 * Combines five [StateFlow]s into a single [StateFlow] using [transform].
 *
 * `value` on the result calls [transform] with all sources' current values. Collectors receive
 * the stream produced by `combine` over the five sources.
 */
inline fun <reified T1, reified T2, reified T3, reified T4, reified T5, reified R> combineStates(
  flow1: StateFlow<T1>,
  flow2: StateFlow<T2>,
  flow3: StateFlow<T3>,
  flow4: StateFlow<T4>,
  flow5: StateFlow<T5>,
  noinline transform: (T1, T2, T3, T4, T5) -> R,
): StateFlow<R> {
  val current = { transform(flow1.value, flow2.value, flow3.value, flow4.value, flow5.value) }
  val updates =
    combine(flow1, flow2, flow3, flow4, flow5) { a, b, c, d, e -> transform(a, b, c, d, e) }
  return DerivedStateFlow(getValue = current, flow = updates)
}

/**
 * A read-only [StateFlow] whose [value] is computed on demand by [getValue] and whose emissions
 * come from [flow].
 *
 * Every call to [collect] subscribes to [flow] afresh inside its own coroutine scope. Consecutive
 * equal values are dropped before they reach the collector.
 *
 * See https://github.com/Kotlin/kotlinx.coroutines/issues/2631#issuecomment-870565860
 */
class DerivedStateFlow<T>(
  private val getValue: () -> T,
  private val flow: Flow<T>,
) : StateFlow<T> {

  override val value: T
    get() = getValue()

  override val replayCache: List<T>
    get() = listOf(value)

  override suspend fun collect(collector: FlowCollector<T>): Nothing = coroutineScope {
    val deduplicated = flow.distinctUntilChanged()
    val shared = deduplicated.stateIn(scope = this)
    shared.collect(collector)
  }
}
