# Integration with Flux.jl

Using Wandb.jl in an existing Flux workflow is straightforward. Let's go through the mlp_mnist demo from the Flux model-zoo and update it to use Wandb. First, activate that demo's environment and add Wandb.jl to it.
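
With the demo's environment activated, adding the dependency is a single Pkg call (a minimal sketch, assuming the registered package name `Wandb`):

```julia
using Pkg
Pkg.add("Wandb")  # pulls in the WandbLogger used in the script below
```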

```julia
using Flux, Statistics
using Flux.Data: DataLoader
using Flux: onehotbatch, onecold, @epochs
using Flux.Losses: logitcrossentropy
using CUDA
using MLDatasets
using Wandb
using Dates
using Logging

lg = WandbLogger(project = "Wandb.jl", name = "fluxjl-integration-$(now())",
                 config = Dict("learning_rate" => 3e-4, "batchsize" => 256,
                               "epochs" => 100, "dataset" => "MNIST", "use_cuda" => true))

global_logger(lg)
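
# With the WandbLogger installed as the global logger, standard logging
# macros (e.g. the @info calls in `train` below) are also captured by Wandb.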

##################################################################################
# Wandb # Instead of passing arguments around we will use the global configuration
# Wandb # file from Wandb
##################################################################################
function getdata(device)
    ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"

    # Loading Dataset
    xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
    xtest, ytest = MLDatasets.MNIST.testdata(Float32)

    # Reshape Data in order to flatten each image into a linear array
    xtrain = Flux.flatten(xtrain)
    xtest = Flux.flatten(xtest)

    # One-hot-encode the labels
    ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

    # Create DataLoaders (mini-batch iterators)
    train_loader = DataLoader(
        (xtrain, ytrain),
        batchsize = get_config(lg, "batchsize"),
        shuffle = true,
    )
    test_loader = DataLoader((xtest, ytest), batchsize = get_config(lg, "batchsize"))

    return train_loader, test_loader
end

build_model(; imgsize = (28, 28, 1), nclasses = 10) =
    Chain(Dense(prod(imgsize), 32, relu), Dense(32, nclasses))

function loss_and_accuracy(data_loader, model, device)
    acc = 0
    ls = 0.0f0
    num = 0
    for (x, y) in data_loader
        x, y = device(x), device(y)
        ŷ = model(x)
        ls += logitcrossentropy(ŷ, y, agg = sum)
        acc += sum(onecold(cpu(ŷ)) .== onecold(cpu(y)))
        num += size(x, 2)
    end
    return ls / num, acc / num
end

#################################################################
# Wandb # If any parameters need to be updated, pass them as a Dict
#################################################################
function train(update_params::Dict = Dict())
    #################################
    # Wandb # Update config if needed
    #################################
    update_config!(lg, update_params)

    if CUDA.functional() && get_config(lg, "use_cuda")
        @info "Training on CUDA GPU"
        CUDA.allowscalar(false)
        device = gpu
    else
        @info "Training on CPU"
        device = cpu
    end

    # Create test and train dataloaders
    train_loader, test_loader = getdata(device)

    # Construct model
    model = build_model() |> device
    ps = Flux.params(model) # model's trainable parameters

    ## Optimizer
    opt = ADAM(get_config(lg, "learning_rate"))

    ## Training
    for epoch = 1:get_config(lg, "epochs")
        for (x, y) in train_loader
            x, y = device(x), device(y) # transfer data to device
            gs = gradient(() -> logitcrossentropy(model(x), y), ps) # compute gradient
            Flux.Optimise.update!(opt, ps, gs) # update parameters
        end

        # Report on train and test
        train_loss, train_acc = loss_and_accuracy(train_loader, model, device)
        test_loss, test_acc = loss_and_accuracy(test_loader, model, device)

        ###################################
        # Wandb # Log the loss and accuracy
        ###################################
        Wandb.log(
            lg,
            Dict(
                "Training/Loss" => train_loss,
                "Training/Accuracy" => train_acc,
                "Testing/Loss" => test_loss,
                "Testing/Accuracy" => test_acc,
            ),
        )
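
        # Each Wandb.log call is recorded as one step, so all four metrics
        # for an epoch land on the same point in the dashboard charts.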

        println("Epoch=$epoch")
        println("  train_loss = $train_loss, train_accuracy = $train_acc")
        println("  test_loss = $test_loss, test_accuracy = $test_acc")
    end
end

### Run training 
train()

################################
# Wandb # Finish the Current Run
################################
close(lg)
```
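
`close(lg)` marks the run as finished and uploads any remaining data. To try other hyperparameters in the same session, pass overrides to `train` before closing the logger; `update_config!` will sync them to the run's config. A minimal sketch (the keys must match the `config` Dict above; the values here are illustrative):

```julia
train(Dict("epochs" => 10, "learning_rate" => 1e-3))  # hypothetical override run
close(lg)
```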