Scala 3 — Book

Dependent Function Types

Language
This doc page is specific to Scala 3, and may cover new concepts not available in Scala 2. Unless otherwise stated, all the code examples in this page assume you are using Scala 3.

A dependent function type describes function types, where the result type may depend on the function’s parameter values. The concept of dependent types, and of dependent function types is more advanced and you would typically only come across it when designing your own libraries or using advanced libraries.

Dependent Method Types

Let’s consider the following example of a heterogenous database that can store values of different types. The key contains the information about what’s the type of the corresponding value:

trait Key { type Value }

trait DB {
  def get(k: Key): Option[k.Value] // a dependent method
}

Given a key, the method get lets us access the map and potentially returns the stored value of type k.Value. We can read this path-dependent type as: “depending on the concrete type of the argument k, we return a matching value”.

For example, we could have the following keys:

object Name extends Key { type Value = String }
object Age extends Key { type Value = Int }

The following calls to method get would now type check:

val db: DB = ...
val res1: Option[String] = db.get(Name)
val res2: Option[Int] = db.get(Age)

Calling the method db.get(Name) returns a value of type Option[String], while calling db.get(Age) returns a value of type Option[Int]. The return type depends on the concrete type of the argument passed to get—hence the name dependent type.

Dependent Function Types

As seen above, Scala 2 already had support for dependent method types. However, creating values of type DB is quite cumbersome:

// a user of a DB
def user(db: DB): Unit =
  db.get(Name) ... db.get(Age)

// creating an instance of the DB and passing it to `user`
user(new DB {
  def get(k: Key): Option[k.Value] = ... // implementation of DB
})

We manually need to create an anonymous inner class of DB, implementing the get method. For code that relies on creating many different instances of DB this is very tedious.

The trait DB only has a single abstract method get. Wouldn’t it be nice, if we could use lambda syntax instead?

user { k =>
  ... // implementation of DB
}

In fact, this is now possible in Scala 3! We can define DB as a dependent function type:

type DB = (k: Key) => Option[k.Value]
//        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
//      A dependent function type

Given this definition of DB the above call to user type checks, as is.

You can read more about the internals of dependent function types in the reference documentation.

Case Study: Numerical Expressions

Let us assume we want to define a module that abstracts over the internal represention of numbers. This can be useful, for instance, to implement libraries for automatic derivation.

We start by defining our module for numbers:

trait Nums:
  // the type of numbers is left abstract
  type Num

  // some operations on numbers
  def lit(d: Double): Num
  def add(l: Num, r: Num): Num
  def mul(l: Num, r: Num): Num

We omit the concrete implementation of Nums, but as an exercise you could implement Nums by assigning type Num = Double and implement methods accordingly.

A program that uses our number abstraction now has the following type:

type Prog = (n: Nums) => n.Num => n.Num

val ex: Prog = nums => x => nums.add(nums.lit(0.8), x)

The type of a function that computes the derivative of programs like ex is:

def derivative(input: Prog): Double

Given the facility of dependent function types, calling this function with different programs is very convenient:

derivative { nums => x => x }
derivative { nums => x => nums.add(nums.lit(0.8), x) }
// ...

To recall, the same program in the encoding above would be:

derivative(new Prog {
  def apply(nums: Nums)(x: nums.Num): nums.Num = x
})
derivative(new Prog {
  def apply(nums: Nums)(x: nums.Num): nums.Num = nums.add(nums.lit(0.8), x)
})
// ...

Combination with Context Functions

The combination of extension methods, context functions, and dependent functions provides a powerful tool for library designers. For instance, we can refine our library from above as follows:

trait NumsDSL extends Nums:
  extension (x: Num)
    def +(y: Num) = add(x, y)
    def *(y: Num) = mul(x, y)

def const(d: Double)(using n: Nums): n.Num = n.lit(d)

type Prog = (n: NumsDSL) ?=> n.Num => n.Num
//                       ^^^
//     prog is now a context function that implicitly
//     assumes a NumsDSL in the calling context

def derivative(input: Prog): Double = ...

// notice how we do not need to mention Nums in the examples below?
derivative { x => const(1.0) + x }
derivative { x => x * x + const(2.0) }
// ...

Contributors to this page: