F[unctional] Core, IO[mperative] Shell

May 4, 2024

Let's design a hit counter using a functional domain model based on the Tagless-Final style and The Clean Architecture.

Layers

The domain is the lowest layer. Each additional layer may access any of the layers beneath it.

  1. Domain
  2. Use Cases
  3. Drivers
  4. Runners

Domain

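Our domain has a single entity, the count, and two operations on it: increment and get. The operations are abstract in the effect type F[_], which is what makes this the tagless-final style.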

package domain.entities:

  case class Count(value: Int)

package domain.operations:

  import domain.entities.Count

  trait Counter[F[_]]:

    def increment: F[Unit]
    def get: F[Count]
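
Because `Counter` is abstract in `F[_]`, any type constructor can interpret it. As a minimal sketch, assuming a hypothetical `Step` type that threads the running total through a plain function, the operations can be implemented with no mutation and no I/O. The `sketch.state` package, `Step`, and `pureCounter` are illustrative names, not part of the design above; `Count` and `Counter` are repeated so the block compiles on its own.

```scala
package sketch.state:

  // Copies of the domain definitions, repeated so the sketch compiles alone.
  case class Count(value: Int)

  trait Counter[F[_]]:
    def increment: F[Unit]
    def get: F[Count]

  // Hypothetical effect type: a function from the current total to the
  // next total paired with a result.
  type Step[A] = Int => (Int, A)

  val pureCounter: Counter[Step] =
    new Counter[Step]:

      // Bump the total and return no result.
      override def increment: Step[Unit] =
        n => (n + 1, ())

      // Leave the total unchanged and report it.
      override def get: Step[Count] =
        n => (n, Count(n))
```

Applying `pureCounter.increment` to `0` yields `(1, ())`, and applying `pureCounter.get` to `1` yields `(1, Count(1))`.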

Use cases

In our application, we have one use case: increment the count and then get its new value.

This layer imports from the domain layer.

package usecases:

  import cats.Monad
  import cats.syntax.flatMap.toFlatMapOps
  import cats.syntax.functor.toFunctorOps
  import domain.entities.Count

  trait Counter[F[_]]:

    def counter: domain.operations.Counter[F]

    def incrementAndGet(using Monad[F]): F[Count] =
      for
        _ <- counter.increment
        c <- counter.get
      yield c
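
One way to check the "increment, then get" ordering without a real counter is to interpret the use case in an effect that only records which operations were requested. The following is a sketch under stated assumptions: `Monad` here is a hypothetical, stripped-down stand-in for `cats.Monad`, `Trace`, `tracing`, and `recorded` are illustrative names, and `incrementAndGet` is written as a free-standing function rather than a trait method so the block compiles on its own.

```scala
package sketch.trace:

  // Copies of the domain definitions, repeated so the sketch compiles alone.
  case class Count(value: Int)

  trait Counter[F[_]]:
    def increment: F[Unit]
    def get: F[Count]

  // Hypothetical stand-in for cats.Monad with only what the sketch needs.
  trait Monad[F[_]]:
    def pure[A](a: A): F[A]
    extension [A](fa: F[A])
      def flatMap[B](f: A => F[B]): F[B]
      def map[B](f: A => B): F[B] = fa.flatMap(a => pure(f(a)))

  // Same shape as the use case, but as a free-standing function.
  def incrementAndGet[F[_]: Monad](counter: Counter[F]): F[Count] =
    for
      _ <- counter.increment
      c <- counter.get
    yield c

  // Hypothetical effect: the names of the operations requested so far,
  // paired with a result.
  type Trace[A] = (List[String], A)

  given Monad[Trace] with
    def pure[A](a: A): Trace[A] = (Nil, a)
    extension [A](fa: Trace[A])
      def flatMap[B](f: A => Trace[B]): Trace[B] =
        val (first, a) = fa
        val (second, b) = f(a)
        (first ++ second, b)

  // A stub counter that records each call and always reports zero.
  val tracing: Counter[Trace] =
    new Counter[Trace]:
      override def increment: Trace[Unit] = (List("increment"), ())
      override def get: Trace[Count] = (List("get"), Count(0))

  // Evaluates to (List("increment", "get"), Count(0)).
  val recorded: Trace[Count] = incrementAndGet(tracing)
```

The stub never counts anything, so the `Count(0)` in the result is meaningless; the point is the recorded order of the two calls.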

Drivers

In this layer, we have a simple in-memory counter, which is useful for testing, and an RDBMS counter, which is useful for production. Both stay polymorphic in the effect type F; the concrete IO type does not appear until the runners layer.

This layer imports from the use cases layer and the domain layer.

In-memory counter

package drivers.mem:

  import cats.Functor
  import cats.syntax.functor.toFunctorOps
  import cats.effect.kernel.Ref
  import cats.effect.kernel.Ref.Make
  import domain.entities.Count

  def counter[F[_]: Make: Functor]: F[usecases.Counter[F]] =
    for
      ref <- Ref.of[F, Int](0)
    yield 
      new usecases.Counter[F]:

        override def counter =
          new domain.operations.Counter[F]:

            override def increment: F[Unit] =
              ref.update(x => x + 1)

            override def get: F[Count] =
              ref.get.map(x => Count(x))

RDBMS counter

package drivers.db:

  import cats.Applicative
  import cats.syntax.functor.toFunctorOps
  import domain.entities.Count
  import java.sql.Connection

  trait Transactor[F[_]]:
    def transact[A](k: Connection => F[A]): F[A]

  def init[F[_]: Transactor: Applicative]: F[Unit] =
    summon[Transactor[F]].transact { c =>
      summon[Applicative[F]].pure {
        val s1 = c.createStatement()
        s1.executeUpdate(
          """|CREATE TABLE DATA
             |  ( `KEY` VARCHAR(256) NOT NULL
             |  , `VALUE` VARCHAR(256) NOT NULL
             |  , PRIMARY KEY (`KEY`)
             |  )
             |""".stripMargin
        )
        s1.close()

        val s2 = c.createStatement()
        s2.executeUpdate(
          """|INSERT INTO DATA
             |  ( `KEY`
             |  , `VALUE`
             |  )
             |  VALUES
             |  ( 'COUNT'
             |  , '0'
             |  )
             |""".stripMargin
        )
        s2.close()
      }
    }

  def counter[F[_]: Transactor: Applicative]: usecases.Counter[F] =
    new usecases.Counter[F]:

      override def counter: domain.operations.Counter[F] =
        new domain.operations.Counter[F]:

          override def increment: F[Unit] =
            summon[Transactor[F]].transact { c =>
              for
                oldCount <- get
              yield {
                val stmt =
                  c.prepareStatement(
                    """|UPDATE DATA
                       |  SET `VALUE` = ?
                       |  WHERE `KEY` = 'COUNT'
                       |""".stripMargin
                  )
                val newValue = oldCount.value + 1
                stmt.setInt(1, newValue)
                stmt.executeUpdate()
                stmt.close()
              }
            }

          override def get: F[Count] =
            summon[Transactor[F]].transact { c =>
              summon[Applicative[F]].pure {
                val stmt = c.createStatement()
                val rs =
                  stmt.executeQuery(
                    """|SELECT `VALUE`
                       |  FROM DATA
                       |  WHERE `KEY` = 'COUNT'
                       |""".stripMargin
                  )
                rs.next()
                val count = rs.getInt("VALUE")
                stmt.close() // closing the statement also closes its result set
                Count(count)
              }
            }
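
One detail of this driver is worth spelling out. `Applicative.pure` takes an ordinary by-value argument, so a block of JDBC calls passed to it runs as soon as the callback given to `transact` is invoked, before the resulting `F[A]` exists. The sketch below illustrates that difference with a hypothetical `Thunk` type standing in for a real effect; `Thunk`, `delay`, and `demo` are illustrative names and not part of the design above.

```scala
package sketch.eager:

  import scala.collection.mutable.ListBuffer

  // Hypothetical minimal effect: a description that does nothing until run.
  final class Thunk[A](val run: () => A)

  object Thunk:

    // By-value argument, like Applicative.pure: a block passed here has
    // already been evaluated by the time the Thunk is constructed.
    def pure[A](a: A): Thunk[A] = new Thunk(() => a)

    // By-name argument: the block is evaluated each time run() is called.
    def delay[A](a: => A): Thunk[A] = new Thunk(() => a)

  // Records the order in which the two blocks are evaluated.
  def demo(): List[String] =
    val log = ListBuffer.empty[String]
    val eager = Thunk.pure { log += "pure body"; 1 }
    log += "built eager"
    val late = Thunk.delay { log += "delay body"; 2 }
    log += "built late"
    eager.run()
    late.run()
    log.toList
```

Calling `demo()` returns `List("pure body", "built eager", "built late", "delay body")`: the `pure` block ran while the value was being built, and the `delay` block ran only when asked. In the driver, this means the timing of the JDBC calls depends on when the `Transactor` instance chooses to invoke its callback.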

Runners

We wrap everything with a main method that uses I/O to exercise and validate our implementations.

This layer imports from the drivers layer, the use cases layer, and the domain layer.

package main:

  import cats.effect.IO
  import cats.effect.IOApp
  import cats.effect.std.Console
  import domain.entities.Count
  import drivers.db.Transactor
  import java.sql.Connection
  import java.sql.DriverManager
  import scala.util.chaining.scalaUtilChainingOps

  object Main extends IOApp.Simple:

    given tx: Transactor[IO] with

      private lazy val c: Connection =
        Class.forName("org.h2.Driver")
        val connection: Connection =
          DriverManager.getConnection("jdbc:h2:mem:", "sa", "")
        connection.setAutoCommit(false)
        connection

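      override def transact[A](k: Connection => IO[A]): IO[A] =
        // Defer k(c) so its JDBC calls happen when the IO runs, and commit
        // as an effect instead of while the IO value is being built.
        IO.defer(k(c))
          .flatTap(_ => IO(c.commit()))
          .recoverWith {
            case e =>
              IO(c.rollback()) *> IO.raiseError[A](e)
          }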

    def assertEquals[A](expected: A, got: A): IO[Unit] =
      if (expected == got) {
        Console[IO].println(s"✅ got ${got}")
      } else {
        IO.raiseError(new Exception(s"❌ expected ${expected}, but got ${got}"))
      }

    def test(svc: usecases.Counter[IO]): IO[Unit] =
      for
        _ <- svc.incrementAndGet.flatTap(assertEquals(Count(1), _))
        _ <- svc.incrementAndGet.flatTap(assertEquals(Count(2), _))
        _ <- svc.incrementAndGet.flatTap(assertEquals(Count(3), _))
      yield ()

    override val run: IO[Unit] =
      for
        _ <- Console[IO].println("mem service:")
        _ <- drivers.mem.counter[IO].flatMap(test)
        _ <- Console[IO].println("db service:")
        _ <- drivers.db.init
        _ <- test(drivers.db.counter[IO])
      yield ()

Try it out

This file is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/scala/clean.md |
  codedown scala |
  scala-cli -q --scala 3 _.scala \
    --dep org.typelevel::cats-effect:3.5.4 \
    --dep com.h2database:h2:2.2.224
mem service:
✅ got Count(1)
✅ got Count(2)
✅ got Count(3)
db service:
✅ got Count(1)
✅ got Count(2)
✅ got Count(3)
