Skip to content

Flow matching with julia

작년 초, flow matching1에 대해서 처음 들어봤다. 한 번 공부해보고 싶은 마음이 들어서 Julia로 간단한 code를 작성해 보았다.

다만, julia로 코드를 작성할 때는 python보다 조금 더 신경을 써 줘야 한다. 에러 메시지가 상대적으로 불친절하며, type이 조금 더 엄격하다. 또한 vcat 함수에 대해 edge case가 있어서 빠른 prototyping에는 아직 이르지 않나 싶다.

# Tutorial: flow matching
using Lux, Reactant, Enzyme, Optimisers, Random
using Plots
import Base: vcat

# setup device
Reactant.set_default_backend("cpu")
const cdev = cpu_device()
const xdev = reactant_device()

vcat(a::Number, b::ConcreteRArray) = vcat(fill!(similar(b, typeof(a), (1,size(b)[2:end]...)), a), b)

# prepare data
rng = Xoshiro(777)

function generate_data(N::Int=256)
    tt = rand(rng, Float32, 1, N) # uniform(0, 1) 
    x_0 = randn(rng, Float32, 2, N) # pure noise
    x_1 = cat(
        1 .+ 0.05 * randn(rng, 2, N ÷ 2),
        -1 .+ 0.05 * randn(rng, 2, N ÷ 2),
        dims=2) .|> Float32  # data
    return tt, x_0, x_1
end
tt, x_0, x_1 = generate_data(256)
scatter(x_1[1, :], x_1[2, :])
savefig("data.png")

# parametrized vector field
v_θ = Chain(
    Dense(3, 64, elu),
    Dense(64, 64, elu),
    Dense(64, 64, elu),
    Dense(64, 2)
)

# conditional flow: "optimal" transport path
x_t = (1 .- tt) .* x_0 .+ tt .* x_1

# conditional vector field: regression target
∂x_t = x_1 - x_0


# compute flow matching loss 
function loss_fn(v_θ, ps, st, (tt, x_t, ∂x_t))
    pred, st = v_θ(cat(tt, x_t, dims=1), ps, st)
    loss = sum(abs2, pred - ∂x_t) / size(x_t)[2]
    return loss, st, nothing
end

# fitting
ps, st = Lux.setup(rng, v_θ) |> xdev

opt = Adam(1f-3)
tstate = Training.TrainState(v_θ, ps, st, opt)
vjp_rule = AutoEnzyme()


function fit(tstate::Training.TrainState, vjp, data; epochs::Int=10000)
    data = data |> xdev
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_fn, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            println("Epoch: $epoch \t Loss: $loss")
            if isnan(loss)
                break
            end
        end
    end
    return tstate
end

data = (tt, x_t, ∂x_t)
tstate = fit(tstate, vjp_rule, data, epochs=20000)

function _inference(v_θ, ps, st, x; num_steps::Int=8)
    tt = LinRange(0, 1, num_steps+1) .|> Float32 |> xdev
    dt = tt[2] - tt[1] |> Float32
    x |> xdev
    for t in tt[1:end-1]
        k1 = dt / 2 * v_θ(vcat(t, x), ps, st)[1]
        x = x + dt * v_θ(vcat(t + dt / 2, x + k1), ps, st)[1]
    end
    return x |> cdev
end

sample_input = cat(tt[1], x_0[:, 1], dims=1) |> xdev
net_compiled = @compile v_θ(sample_input, tstate.parameters, tstate.states)

xx = randn(rng, Float32, 2, 256) |> xdev

function inference(net, ps, st, xx; num_steps::Int=8)
    yy = []
    for i in axes(xx, 2)
        x = xx[:, i]
        push!(yy, _inference(net, ps, st, x; num_steps=num_steps))
    end
    return stack(yy, dims=2)
end

xx_1 = inference(net_compiled, tstate.parameters, tstate.states, xx; num_steps=8)

scatter(xx_1[1, :], xx_1[2, :])
savefig("flow_matching.png")

P.S.

약 1년 전, MLSC Lab 구성원들끼리 diffusion model에 대해서 공부하고 발표하는 시간을 가졌었다. 맨 처음 이명수 박사님께서 stochastic differential equation (SDE)이 무엇인지 빠르게 정리해 주셨고, 이어서 설윤창 교수님께서 SDE를 수치적으로 푸는 방법들에 대해 설명해 주셨었다. 그 당시 나는 diffusion model을 사용해 생성된 이미지들만 몇 번 봤을 뿐이고, 어떻게 동작하는지에 대해서는 전혀 몰랐었다. 그래서 큰 그림을 얻고자 유명한 논문을 읽고 발표를 준비했었다.

한편, 세미나를 하면서 normalizing flow라는 generative modeling에 대해서도 들을 기회가 있었다. 그러나 invertability 이슈 때문에 인기가 사그러들었다고 들었다. 2018년, Neural ODE를 사용한 continuous normalizing flow (CNF)2 논문이 invertibility 이슈를 해결했다. 하지만 학습 과정에서 Neural ODE를 여러번 수치적으로 풀어야 해서 메모리 및 계산 비용이 비쌌다. Flow matching은 ODE의 우변을 직접 regression 하는 방법이므로 학습 과정에서 수치적분 과정이 필요가 없어 계산 비용을 획기적으로 줄였고, diffusion model과 대등한 수준으로 CNF의 성능을 끌어올렸다고 한다.


  1. Lipman, Y., Chen, R.T., Ben-Hamu, H., Nickel, M., & Le, M. (2022) Flow matching for generative modeling. arXiv preprint arXiv:2210.02747

  2. Chen, R.T., Rubanova, Y., Bettencourt, J., & Duvenaud, D.K. (2018) Neural ordinary differential equations. Advances in neural information processing systems, 31