/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * license agreements; and to You under the Apache License, version 2.0:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * This file is part of the Apache Pekko project, which was derived from Akka.
 */

/*
 * Copyright (C) 2021 Lightbend Inc. <https://www.lightbend.com>
 */

package org.apache.pekko.persistence.r2dbc.migration

import scala.concurrent.Await
import scala.concurrent.duration._

import org.apache.pekko
import pekko.Done
import pekko.actor.testkit.typed.scaladsl.LogCapturing
import pekko.actor.testkit.typed.scaladsl.ScalaTestWithActorTestKit
import pekko.actor.typed.ActorSystem
import pekko.persistence.r2dbc.TestActors.Persister
import pekko.persistence.r2dbc.TestConfig
import pekko.persistence.r2dbc.TestData
import pekko.persistence.r2dbc.TestDbLifecycle
import pekko.persistence.typed.PersistenceId
import com.typesafe.config.Config
import com.typesafe.config.ConfigFactory
import org.scalatest.wordspec.AnyWordSpecLike

/**
 * Configuration for the `MigrationToolSpec` test: defines the pekko-persistence-jdbc journal,
 * snapshot store and read journal plugins against a local Postgres database, layered on top of
 * `TestConfig.config`.
 */
object MigrationToolSpec {

  // All three jdbc plugins share the "default" database definition and point at the
  // jdbc_event_journal / jdbc_snapshot tables.
  private val JdbcSourceConfig: String = """
    pekko-persistence-jdbc {
      shared-databases {
        default {
          profile = "slick.jdbc.PostgresProfile$"
          db {
            host = "127.0.0.1"
            url = "jdbc:postgresql://127.0.0.1:5432/postgres?reWriteBatchedInserts=true"
            user = postgres
            password = postgres
            driver = "org.postgresql.Driver"
            numThreads = 20
            maxConnections = 20
            minConnections = 5
          }
        }
      }
    }

    jdbc-journal {
      use-shared-db = "default"
      tables.event_journal.tableName = "jdbc_event_journal"
    }
    jdbc-snapshot-store {
      use-shared-db = "default"
      tables.snapshot.tableName = "jdbc_snapshot"
    }
    jdbc-read-journal {
      use-shared-db = "default"
      tables.event_journal.tableName = "jdbc_event_journal"
    }
    """

  /** Combined test configuration: the jdbc plugin settings above, falling back to `TestConfig.config`. */
  val config: Config =
    ConfigFactory.parseString(JdbcSourceConfig).withFallback(TestConfig.config)
}

/**
 * Verifies that `MigrationTool` copies events, snapshots and migration progress from the
 * pekko-persistence-jdbc plugins (source) to the r2dbc persistence plugin (target).
 *
 * The suite only runs when `pekko.persistence.r2dbc.dialect` is `postgres`; for any other
 * dialect the `should` clause is marked pending and `beforeAll` skips the database setup.
 */
class MigrationToolSpec
    extends ScalaTestWithActorTestKit(MigrationToolSpec.config)
    with AnyWordSpecLike
    with TestDbLifecycle
    with TestData
    with LogCapturing {

  override def typedSystem: ActorSystem[_] = system

  // Upper bound for each blocking database setup step performed in beforeAll.
  private val dbSetupTimeout = 10.seconds

  private val migrationConfig = system.settings.config.getConfig("pekko.persistence.r2dbc.migration")
  // The source journal is the jdbc-journal plugin defined in MigrationToolSpec.config.
  private val sourceJournalPluginId = "jdbc-journal"
  private val sourceSnapshotPluginId = migrationConfig.getString("source.snapshot-plugin-id")

  private val targetPluginId = migrationConfig.getString("target.persistence-plugin-id")

  private val migration = new MigrationTool(system)

  private val testEnabled: Boolean = {
    // don't run this for Yugabyte since it is using pekko-persistence-jdbc
    system.settings.config.getString("pekko.persistence.r2dbc.dialect") == "postgres"
  }

  /**
   * Creates the jdbc source tables (if missing) and the migration progress table, then clears all
   * three so every run starts from empty tables. Does nothing when the suite is disabled.
   *
   * Each step uses its own executor label so that a failing statement can be identified from the
   * logs.
   */
  override protected def beforeAll(): Unit = {
    super.beforeAll()

    if (testEnabled) {
      // NOTE(review): this layout is meant to match what pekko-persistence-jdbc's Postgres profile
      // expects for the event journal; keep it in sync with that plugin's schema.
      Await.result(
        r2dbcExecutor.executeDdl("beforeAll create jdbc_event_journal table") { connection =>
          connection.createStatement("""CREATE TABLE IF NOT EXISTS jdbc_event_journal(
                                       |  ordering BIGSERIAL,
                                       |  persistence_id VARCHAR(255) NOT NULL,
                                       |  sequence_number BIGINT NOT NULL,
                                       |  deleted BOOLEAN DEFAULT FALSE NOT NULL,
                                       |
                                       |  writer VARCHAR(255) NOT NULL,
                                       |  write_timestamp BIGINT,
                                       |  adapter_manifest VARCHAR(255),
                                       |
                                       |  event_ser_id INTEGER NOT NULL,
                                       |  event_ser_manifest VARCHAR(255) NOT NULL,
                                       |  event_payload BYTEA NOT NULL,
                                       |
                                       |  meta_ser_id INTEGER,
                                       |  meta_ser_manifest VARCHAR(255),
                                       |  meta_payload BYTEA,
                                       |
                                       |  PRIMARY KEY(persistence_id, sequence_number)
                                       |)""".stripMargin)
        },
        dbSetupTimeout)

      Await.result(
        r2dbcExecutor.executeDdl("beforeAll create jdbc_snapshot table") { connection =>
          connection.createStatement("""CREATE TABLE IF NOT EXISTS jdbc_snapshot (
                                       |  persistence_id VARCHAR(255) NOT NULL,
                                       |  sequence_number BIGINT NOT NULL,
                                       |  created BIGINT NOT NULL,
                                       |
                                       |  snapshot_ser_id INTEGER NOT NULL,
                                       |  snapshot_ser_manifest VARCHAR(255) NOT NULL,
                                       |  snapshot_payload BYTEA NOT NULL,
                                       |
                                       |  meta_ser_id INTEGER,
                                       |  meta_ser_manifest VARCHAR(255),
                                       |  meta_payload BYTEA,
                                       |
                                       |  PRIMARY KEY(persistence_id, sequence_number)
                                       |)""".stripMargin)
        },
        dbSetupTimeout)

      // The tables are created with IF NOT EXISTS, so rows from previous runs may still be present.
      Await.result(
        r2dbcExecutor.updateOne("beforeAll delete jdbc_event_journal")(
          _.createStatement("delete from jdbc_event_journal")),
        dbSetupTimeout)
      Await.result(
        r2dbcExecutor.updateOne("beforeAll delete jdbc_snapshot")(_.createStatement("delete from jdbc_snapshot")),
        dbSetupTimeout)

      Await.result(migration.migrationDao.createProgressTable(), dbSetupTimeout)
      Await.result(
        r2dbcExecutor.updateOne("beforeAll delete migration_progress")(
          _.createStatement("delete from migration_progress")),
        dbSetupTimeout)
    }
  }

  /**
   * Persists `events` for `pid` through the source (jdbc) journal and snapshot plugins, then stops
   * the persister and blocks until it acknowledges with `Done`.
   */
  private def persistEvents(pid: PersistenceId, events: Seq[String]): Unit = {
    val probe = testKit.createTestProbe[Done]()
    val persister = testKit.spawn(Persister(pid, sourceJournalPluginId, sourceSnapshotPluginId, tags = Set.empty))
    events.foreach { event =>
      persister ! Persister.Persist(event)
    }
    persister ! Persister.Stop(probe.ref)
    probe.expectMessage(Done)
  }

  /** Asserts that the target plugin recovers `pid` with exactly `expectedEvents`, joined by `|`. */
  private def assertEvents(pid: PersistenceId, expectedEvents: Seq[String]): Unit =
    assertState(pid, expectedEvents.mkString("|"))

  /**
   * Spawns a persister for `pid` on the target r2dbc journal and snapshot plugins, asserts that its
   * recovered state equals `expectedState`, and then stops it.
   */
  private def assertState(pid: PersistenceId, expectedState: String): Unit = {
    // Typed as Any because the probe receives both the String state and the Done acknowledgement.
    val probe = testKit.createTestProbe[Any]()
    val targetPersister =
      testKit.spawn(Persister(pid, targetPluginId + ".journal", targetPluginId + ".snapshot", tags = Set.empty))
    targetPersister ! Persister.GetState(probe.ref)
    probe.expectMessage(expectedState)
    targetPersister ! Persister.Stop(probe.ref)
    probe.expectMessage(Done)
  }

  "MigrationTool" should {
    if (!testEnabled) {
      info(s"MigrationToolSpec not enabled for ${system.settings.config.getString("pekko.persistence.r2dbc.dialect")}")
      pending
    }

    "migrate events of one persistenceId" in {
      val pid = PersistenceId.ofUniqueId(nextPid())

      val events = List("e-1", "e-2", "e-3")
      persistEvents(pid, events)

      migration.migrateEvents(pid.id).futureValue shouldBe 3L

      assertEvents(pid, events)
    }

    "migrate events of a persistenceId several times" in {
      val pid = PersistenceId.ofUniqueId(nextPid())

      val events = List("e-1", "e-2", "e-3")
      persistEvents(pid, events)

      migration.migrateEvents(pid.id).futureValue shouldBe 3L
      assertEvents(pid, events)
      // running again should be idempotent and not fail
      migration.migrateEvents(pid.id).futureValue shouldBe 0L
      assertEvents(pid, events)

      // and running again should find new events
      val moreEvents = List("e-4", "e-5")
      persistEvents(pid, moreEvents)
      migration.migrateEvents(pid.id).futureValue shouldBe 2L

      assertEvents(pid, events ++ moreEvents)
    }

    "migrate snapshot of one persistenceId" in {
      val pid = PersistenceId.ofUniqueId(nextPid())

      persistEvents(pid, List("e-1", "e-2-snap", "e-3"))

      migration.migrateSnapshot(pid.id).futureValue shouldBe 1L

      assertState(pid, "e-1|e-2-snap")
    }

    "migrate snapshot of a persistenceId several times" in {
      val pid = PersistenceId.ofUniqueId(nextPid())

      persistEvents(pid, List("e-1", "e-2-snap", "e-3"))

      migration.migrateSnapshot(pid.id).futureValue shouldBe 1L
      assertState(pid, "e-1|e-2-snap")
      // running again should be idempotent and not fail
      migration.migrateSnapshot(pid.id).futureValue shouldBe 0L
      assertState(pid, "e-1|e-2-snap")

      // and running again should find new snapshot
      persistEvents(pid, List("e-4-snap", "e-5"))
      migration.migrateSnapshot(pid.id).futureValue shouldBe 1L

      assertState(pid, "e-1|e-2-snap|e-3|e-4-snap")
    }

    "update event migration_progress" in {
      val pid = PersistenceId.ofUniqueId(nextPid())
      migration.migrationDao.currentProgress(pid.id).futureValue.map(_.eventSeqNr) shouldBe None

      persistEvents(pid, List("e-1", "e-2", "e-3"))
      migration.migrateEvents(pid.id).futureValue shouldBe 3L
      migration.migrationDao.currentProgress(pid.id).futureValue.map(_.eventSeqNr) shouldBe Some(3L)

      // store and migration some more
      persistEvents(pid, List("e-4", "e-5"))
      migration.migrateEvents(pid.id).futureValue shouldBe 2L
      migration.migrationDao.currentProgress(pid.id).futureValue.map(_.eventSeqNr) shouldBe Some(5L)
    }

    "update snapshot migration_progress" in {
      val pid = PersistenceId.ofUniqueId(nextPid())
      migration.migrationDao.currentProgress(pid.id).futureValue.map(_.snapshotSeqNr) shouldBe None

      persistEvents(pid, List("e-1", "e-2-snap", "e-3"))
      migration.migrateSnapshot(pid.id).futureValue shouldBe 1L
      migration.migrationDao.currentProgress(pid.id).futureValue.map(_.snapshotSeqNr) shouldBe Some(2L)

      // store and migration some more
      persistEvents(pid, List("e-4", "e-5-snap", "e-6"))
      migration.migrateSnapshot(pid.id).futureValue shouldBe 1L
      migration.migrationDao.currentProgress(pid.id).futureValue.map(_.snapshotSeqNr) shouldBe Some(5L)
    }

    "migrate all persistenceIds" in {
      val numberOfPids = 10
      val pids = (1 to numberOfPids).map(_ => PersistenceId.ofUniqueId(nextPid()))
      val events = List("e-1", "e-2", "e-3", "e-4-snap", "e-5", "e-6-snap", "e-7", "e-8", "e-9")

      pids.foreach { pid =>
        persistEvents(pid, events)
      }

      migration.migrateAll().futureValue

      pids.foreach { pid =>
        assertEvents(pid, events)
      }
    }

  }
}
