Test closure once semantics in Swift

Today I was working on NanoFlick's TCA port. It's a joy to work on - you can write really descriptive tests and be nice and lazy during iterative testing and debugging - replaying a test requires no tapping, just waiting for the actions to be sent through the reducer.

One of the features of TCA is its dependency injection system. It's a really nice way to handle dependencies where every dependency endpoint can be swapped out with a mock function. However, what if I want to make guarantees in my tests about how many times I call a test function?

The naive way I've been doing this is to use a var to track the number of times a function is called.

var timesCalled = 0
store.dependencies.fetchMovies = { _ in
    timesCalled += 1
    return .none
}
store.send(.fetchMovies)
XCTAssertEqual(timesCalled, 1)

This is, like, three lines of non-contiguous code every time I want to make this assertion. It's not a lot, but it's enough to be annoying. In the end, we'll be able to write this:

store.dependencies.fetchMovies = Run.once { _ in
    return .none
}
store.send(.fetchMovies)

and it'll work essentially the same way.

The implementation

The first try at doing this is to make just a straightforward closure. Because we want to be able to handle arbitrary argument lists, we have to use the generic parameter pack feature that just rolled out in Swift 5.9. This is a really cool feature that lets us write functions that take an arbitrary number of arguments, each with its own type.
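To make the parameter pack idea concrete, here is a minimal sketch of a wrapper that forwards whatever arguments it receives. The name forward is hypothetical (it is not part of TCA or XCTest), and the sketch assumes a Swift 5.9 or newer toolchain.

```swift
// Sketch only: `forward` is a hypothetical helper, not part of TCA or XCTest.
// Assumes a Swift 5.9 or newer toolchain (parameter packs).
func forward<each T, U>(_ closure: @escaping (repeat each T) -> U) -> (repeat each T) -> U {
    // The returned closure takes the same pack of arguments and passes it straight through.
    return { (args: repeat each T) in
        closure(repeat each args)
    }
}

// One argument: T is the single-element pack {Int}.
let double = forward { (n: Int) in n * 2 }

// Two arguments of different types: T is the pack {String, Int}.
let label = forward { (title: String, year: Int) in "\(title) (\(year))" }

print(double(21))             // prints "42"
print(label("Alien", 1979))   // prints "Alien (1979)"
```

The same generic signature covers both call sites, which is what lets a single wrapper stand in for dependency endpoints with different argument lists.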

We take those arguments and pass them to the closure, and then we increment the counter representing the number of calls to the function. If the counter is greater than 1, we fail the test.

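// file and line default to the call site so a failure points at the test that created the closure.
func once<each T, U>(file: StaticString = #file, line: UInt = #line, _ closure: @escaping (repeat each T) -> U) -> ((repeat each T) -> U) {
    var count = 0
    return { (x: repeat each T) in
        count += 1
        if count > 1 {
            XCTFail("Called too many times: \(count)", file: file, line: line)
        }
        return closure(repeat each x)
    }
}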

This is great, but what if we also want to make sure that the function is called at all? Then we need to notice when the closure goes out of scope without ever having been called. We can do this by moving the captured counter into a class and using its deinitializer to perform the final check.

class Run {
    var count = 0

    func run() {
        count += 1
        if count > 1 {
            XCTFail("Called too many times: \(count)", file: file, line: line)
        }
    }
    
    let file: StaticString
    let line: UInt
    
    init(file: StaticString = #file, line: UInt = #line) {
        self.file = file
        self.line = line
    }

    deinit {
        if count == 0 {
            XCTFail("Value not used", file: file, line: line)
        }
    }

    static func once<each T, U>(file: StaticString = #file, line: UInt = #line, _ closure: @escaping (repeat each T) -> U) -> ((repeat each T) -> U) {
        let once = Run(file: file, line: line)
        return { (x: repeat each T) in
            defer {
                once.run()
            }
            return closure(repeat each x)
        }
    }
}
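The deinit check works because ARC releases the Run instance as soon as the closure capturing it is released. Here is a minimal XCTest-free sketch of that mechanism. Tracker, makeHandler, and demo are hypothetical names, and a plain callback stands in for XCTFail.

```swift
// Sketch only: `Tracker`, `makeHandler`, and `demo` are hypothetical names.
final class Tracker {
    private var count = 0
    private let onUnused: () -> Void

    init(onUnused: @escaping () -> Void) {
        self.onUnused = onUnused
    }

    func record() {
        count += 1
    }

    deinit {
        // Runs when the last strong reference (here, the closure's capture) goes away.
        if count == 0 {
            onUnused()
        }
    }
}

func makeHandler(onUnused: @escaping () -> Void) -> () -> Void {
    let tracker = Tracker(onUnused: onUnused)
    // The returned closure is the only thing keeping `tracker` alive.
    return { tracker.record() }
}

func demo() -> [String] {
    var log: [String] = []
    do {
        let used = makeHandler { log.append("used handler was never called") }
        used()
    }   // `used` is released with count == 1, so nothing is logged
    // Created and dropped without a call: deinit sees count == 0.
    _ = makeHandler { log.append("unused handler was never called") }
    return log
}

print(demo())   // prints ["unused handler was never called"]
```

One consequence of this design is that the check only fires when the closure is actually released, so anything that keeps the closure alive past the end of the test also delays the check.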

Great! Now we can say something like

store.dependencies.fetchMovies = Run.once { _ in
    return .none
}

What if we want to add more functionality to Run, like the fancy JavaScript libraries we're now oh-so-close to? It's actually pretty straightforward! Add an enum that records which rule is in effect, keep the count in the same reference-counted Run instance, and expose an ergonomic static function for each rule.

class Run {
    enum Modality {
        case once
        case atLeastOnce
    }
    
    var modality: Modality
    var count = 0

    func run() {
        count += 1
        if count > 1 && modality == .once {
            XCTFail("Called too many times: \(count)", file: file, line: line)
        }
    }
    
    let file: StaticString
    let line: UInt
    
    init(_ modality: Modality, file: StaticString = #file, line: UInt = #line) {
        self.modality = modality
        self.file = file
        self.line = line
    }

    deinit {
        if count == 0 {
            XCTFail("Value not used", file: file, line: line)
        }
    }
    
    static func atLeastOnce<each T, U>(file: StaticString = #file, line: UInt = #line, _ closure: @escaping (repeat each T) -> U) -> ((repeat each T) -> U) {
        amount(modality: .atLeastOnce, file: file, line: line, closure)
    }
    
    static func once<each T, U>(file: StaticString = #file, line: UInt = #line, _ closure: @escaping (repeat each T) -> U) -> ((repeat each T) -> U) {
        amount(modality: .once, file: file, line: line, closure)
    }
    
    // Run is a plain class here, so the returned closure stays synchronous and run() needs no await.
    private static func amount<each T, U>(modality: Modality, file: StaticString = #file, line: UInt = #line, _ closure: @escaping (repeat each T) -> U) -> ((repeat each T) -> U) {
        let once = Run(modality, file: file, line: line)
        return { (x: repeat each T) in
            defer {
                once.run()
            }
            return closure(repeat each x)
        }
    }
}

Thread safety

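There's actually a race condition here. If the wrapped dependency ends up being called from two threads at nearly the same time (say, after calling store.send(.fetchMovies) twice in a row), there may not be an error. This is because count += 1 is a read-modify-write, not an atomic operation: both calls can read 0 and both write 1, leaving count holding the value 1. A classic OS 101 race condition!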
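For reference, the lost update goes away once the increment is guarded by a lock. The sketch below is dependency-free and uses hypothetical names (LockedCounter, hammer); it uses NSLock and DispatchQueue.concurrentPerform from Foundation and Dispatch to hit the counter from several threads.

```swift
import Foundation
import Dispatch

// Sketch only: `LockedCounter` and `hammer` are hypothetical names.
final class LockedCounter: @unchecked Sendable {
    private let lock = NSLock()
    private var count = 0

    func increment() {
        lock.lock()
        defer { lock.unlock() }
        count += 1   // read-modify-write, now performed by one thread at a time
    }

    var value: Int {
        lock.lock()
        defer { lock.unlock() }
        return count
    }
}

func hammer(iterations: Int) -> Int {
    let counter = LockedCounter()
    // Runs the closure `iterations` times, potentially on several threads at once,
    // and returns only after every iteration has finished.
    DispatchQueue.concurrentPerform(iterations: iterations) { _ in
        counter.increment()
    }
    return counter.value
}

print(hammer(iterations: 10_000))   // prints "10000"
```

Without the lock, the final value could come out lower than the number of iterations, which is exactly the failure mode that would let a double call slip past the once check.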

One way to fix this would be to make everything async by making Run an actor instead of a class. This would fix the race condition, at the cost of either making the returned function async, or having an unstructured task run the increment and praying that it runs before the test finishes. If the test finishes before the task runs, then we'll get a false pass - XCTFail will run without it being recorded against the test!

// if we made it detached.
// problem: what if this runs too late?
private static func amount<each T, U>(modality: Modality, file: StaticString = #file, line: UInt = #line, _ closure: @escaping (repeat each T) -> U) -> ((repeat each T) -> U) {
    let once = Run(modality, file: file, line: line)
    return { (x: repeat each T) in
        defer {
            Task {
                await once.run()
            }
        }
        return closure(repeat each x)
    }
}
// if we made it async
// now not a perfect type substitution :(
private static func amount<each T, U>(modality: Modality, file: StaticString = #file, line: UInt = #line, _ closure: @escaping (repeat each T) -> U) -> ((repeat each T) async -> U) {
    let once = Run(modality, file: file, line: line)
    return { (x: repeat each T) in
        // await is not allowed inside a defer body here, so record the call after the closure returns
        let result = closure(repeat each x)
        await once.run()
        return result
    }
}

Unless you are already returning an async function, this is pretty undesirable, especially because there's no contention in the common case. No reason to have all that async complication!

The other way is to use Point-Free's LockIsolated to meet the Sendable guarantee without actually having to go through async / await. This is probably the best way. We get the synchronous code we want, and we get the thread safety we need.

private static func amount<each T, U>(modality: Modality, file: StaticString = #file, line: UInt = #line, _ closure: @escaping (repeat each T) -> U) -> ((repeat each T) -> U) {
    let once = LockIsolated(Run(modality, file: file, line: line))
    return { (x: repeat each T) in
        defer {
            // withValue holds the lock while run() updates the count, so no await is needed
            once.withValue { $0.run() }
        }
        return closure(repeat each x)
    }
}

Note that this built on Xcode 15 beta 2, but is broken on Xcode 15 beta 3. To work around it, just comment out the convenience functions, call amount directly, and await the next update.