API Reference

Configuration and Data Structures

Args

GraphNetSim.ArgsType
Args

Configuration structure for training and evaluating Graph Neural Network simulators.

Fields

Network Architecture

  • mps::Integer=15: Number of message passing steps (higher = more expressive but slower)
  • layer_size::Integer=128: Latent dimension for MLP hidden layers
  • hidden_layers::Integer=2: Number of hidden layers in each MLP module

Training Configuration

  • epochs::Integer=1: Number of passes over the entire dataset
  • steps::Integer=10e6: Total number of training steps
  • checkpoint::Integer=10000: Interval (in steps) for saving checkpoints
  • norm_steps::Integer=1000: Steps to accumulate normalization statistics before weight updates
  • batchsize::Integer=1: Batch size (currently limited to 1 - full trajectory per batch)

Normalization

  • max_norm_steps::Integer=10.0f6: Maximum steps for online normalizer accumulation

Data Augmentation

  • types_updated::Vector{Integer}=[1]: Node types whose features are predicted
  • types_noisy::Vector{Integer}=[0]: Node types to which noise is added during training
  • noise_stddevs::Vector{Float32}=[0.0f0]: Standard deviations for Gaussian noise (per type or broadcast)

Training Strategy

  • training_strategy::TrainingStrategy=DerivativeTraining(): Method for computing loss

Hardware and Optimization

  • use_cuda::Bool=true: Enable CUDA GPU acceleration (if available)
  • gpu_device::Union{Nothing,CuDevice}: CUDA device to use (auto-selected if CUDA functional)
  • optimizer_learning_rate_start::Float32=1.0f-4: Initial learning rate
  • optimizer_learning_rate_stop::Union{Nothing,Float32}=nothing: Final learning rate (for decay schedule)

Validation

  • show_progress_bars::Bool=true: Show training progress bars
  • use_valid::Bool=true: Load validation checkpoint (best loss) instead of final checkpoint
  • solver_valid::OrdinaryDiffEqAlgorithm=Tsit5(): ODE solver for validation rollouts
  • solver_valid_dt::Union{Nothing,Float32}=nothing: Fixed timestep for validation solver
  • reset_valid::Bool=false: Reset validation after loading checkpoint
  • save_step::Bool=false: Save loss at every step (can create large log files)
source
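The fields above can be sketched in use as follows. This is a hedged sketch: it assumes `Args` exposes a keyword constructor matching the listed field defaults (e.g. via `Base.@kwdef`); the exact constructor form is not shown in this reference.

```julia
using GraphNetSim

# Assumed keyword constructor; every keyword corresponds to a field listed above.
args = Args(
    mps = 10,                  # fewer message passing steps -> faster, less expressive
    layer_size = 64,
    hidden_layers = 2,
    noise_stddevs = [1.0f-2],  # inject Gaussian noise into noisy node types
    training_strategy = DerivativeTraining(),
    use_cuda = false,          # force CPU execution
)
```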

Dataset

GraphNetSim.DatasetType
Dataset

A mutable structure containing trajectory data and associated metadata for training, validation, or testing.

Fields

  • meta::Dict{String,Any}: Dictionary containing all metadata for the dataset, including feature names, trajectory information, and device settings.
  • datafile::String: Path to the data file (usually .h5 or .jld2 format).
  • lock::ReentrantLock: Lock for thread-safe access to the datafile during concurrent operations.
source

Dataset Constructors

GraphNetSim.DatasetMethod
Dataset(datafile::String, metafile::String, args)

Construct a Dataset from separate data and metadata files.

Validates that both files exist and have correct formats (.h5 or .jld2 for data, .json for metadata), then loads trajectories and merges metadata with provided arguments.

Arguments

  • datafile::String: Path to the data file (.h5 or .jld2 containing trajectory data).
  • metafile::String: Path to the metadata file (.json containing dataset configuration).
  • args: A structure or object whose fields will be merged into the metadata dictionary.

Returns

  • Dataset: A new Dataset object initialized with the provided data and metadata.

Throws

  • ArgumentError: If datafile or metafile do not exist or have invalid formats.
source
GraphNetSim.DatasetMethod
Dataset(split::Symbol, path::String, args)

Construct a Dataset by specifying a split type and directory path.

Locates and loads the appropriate data file based on the split type (:train, :valid, or :test), expecting a "meta.json" file in the given directory.

Arguments

  • split::Symbol: The dataset split, one of :train, :valid, or :test.
  • path::String: Directory path containing the metadata file "meta.json" and the corresponding data file.
  • args: A structure or object whose fields will be merged into the metadata dictionary.

Returns

  • Dataset: A new Dataset object initialized with data from the specified split.

Throws

  • ArgumentError: If split is invalid, if meta.json is not found, or if the corresponding data file cannot be found.
source
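Both constructors can be sketched together. A hedged sketch, assuming the file layout described above (paths and file names here are illustrative only):

```julia
using GraphNetSim

args = Args()  # default configuration (keyword constructor assumed)

# Variant 1: explicit data and metadata files.
ds = Dataset("data/train.h5", "data/meta.json", args)

# Variant 2: split symbol + directory; expects data/meta.json plus
# data/train.jld2 or data/train.h5 (located via get_file).
ds_train = Dataset(:train, "data", args)
ds_valid = Dataset(:valid, "data", args)
```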
GraphNetSim.get_fileFunction
get_file(split::Symbol, path::String)

Locate the data file corresponding to a dataset split in the given directory.

Attempts to find a .jld2 file first, then a .h5 file with the split name (e.g., "train.jld2" or "train.h5").

Arguments

  • split::Symbol: The dataset split name (converted to string).
  • path::String: Directory path to search for the data file.

Returns

  • String: Full path to the located data file.

Throws

  • ArgumentError: If no data file with the specified split name is found in the directory.
source

Training and Evaluation

Main Training Function

GraphNetSim.train_networkFunction
train_network(opt, ds_path::String, cp_path::String; kws...)

Train a Graph Neural Network simulator on trajectory data.

Initializes a graph network, loads dataset, computes normalization statistics, and performs supervised training using the specified training strategy. Validates periodically on validation set and saves checkpoints for the best model.

Arguments

  • opt: Optimizer configuration (e.g., Optimisers.Adam(1f-4)).
  • ds_path::String: Path to dataset directory (must contain train, valid, test splits).
  • cp_path::String: Path where checkpoints and logs are saved.

Keyword Arguments

  • mps::Int=15: Number of message passing steps in the network.
  • layer_size::Int=128: Latent dimension of hidden MLP layers.
  • hidden_layers::Int=2: Number of hidden layers in each MLP.
  • batchsize::Int=1: Batch size for training (default uses full trajectories).
  • epochs::Int=1: Number of training epochs.
  • steps::Int=10e6: Total number of training steps.
  • checkpoint::Int=10000: Create checkpoint every N steps.
  • norm_steps::Int=1000: Steps for accumulating normalization statistics without updates.
  • max_norm_steps::Float32=10.0f6: Maximum steps for online normalizers.
  • types_updated::Vector{Int}=[1]: Node types whose features are predicted.
  • types_noisy::Vector{Int}=[0]: Node types to which noise is added.
  • noise_stddevs::Vector{Float32}=[0.0f0]: Standard deviations for Gaussian noise.
  • training_strategy::TrainingStrategy=DerivativeTraining(): Training method to use.
  • use_cuda::Bool=true: Use CUDA GPU if available.
  • solver_valid::OrdinaryDiffEqAlgorithm=Tsit5(): ODE solver for validation rollouts.
  • solver_valid_dt::Union{Nothing,Float32}=nothing: Fixed timestep for validation (if set).
  • optimizer_learning_rate_start::Float32=1.0f-4: Initial learning rate.
  • optimizer_learning_rate_stop::Union{Nothing,Float32}=nothing: Final learning rate for decay schedule.
  • show_progress_bars::Bool=true: Show training progress bars.
  • use_valid::Bool=true: Use validation checkpoint for early stopping.

Returns

  • Float32: Minimum validation loss achieved during training.

Training Strategies

  • DerivativeTraining: Train on current step derivatives (collocation)
  • BatchingStrategy: Custom batching of trajectory segments

Example

train_network(
    Optimisers.Adam(1f-4),
    "./data",
    "./checkpoints";
    epochs=10,
    steps=50000,
    mps=15,
    layer_size=128,
    training_strategy=DerivativeTraining()
)
source
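Training with a solver-based strategy instead of the default DerivativeTraining can be sketched as follows. A hedged sketch: the SingleShooting argument order follows its constructor as documented in this reference, (tstart, dt, tstop, solver); paths and step counts are illustrative.

```julia
using GraphNetSim, Optimisers, OrdinaryDiffEq

# Simulate from t = 0 to t = 1, saving every 0.01, integrated with Tsit5.
strategy = SingleShooting(0.0f0, 0.01f0, 1.0f0, Tsit5())

train_network(
    Optimisers.Adam(1f-4),
    "./data",
    "./checkpoints";
    steps = 50_000,
    training_strategy = strategy,
)
```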

Internal Training Functions

GraphNetSim.train_gns!Function
train_gns!(gns::GraphNetwork, opt_state, ds_train::Dataset, ds_valid::Dataset, df_train, df_valid, df_step, device::Function, cp_path::String, args::Args)

Execute the main training loop for a Graph Neural Network simulator.

Performs supervised learning with periodic validation, checkpoint saving, and learning rate scheduling. Supports various training strategies (derivative, batching) and handles feature noise injection, gradient accumulation over multiple time steps, and online normalizer updates.

Arguments

  • gns::GraphNetwork: Graph network model containing parameters and normalizers.
  • opt_state: Optimizer state from Optimisers.jl.
  • ds_train::Dataset: Training dataset with trajectories and metadata.
  • ds_valid::Dataset: Validation dataset for monitoring training progress.
  • df_train: DataFrame storing training loss at checkpoints.
  • df_valid: DataFrame storing best validation losses.
  • df_step: DataFrame storing loss at each step (if save_step enabled).
  • device::Function: Device placement function (cpudevice or gpudevice).
  • cp_path::String: Path for saving checkpoints and logs.
  • args::Args: Configuration including optimizer, strategy, and training parameters.

Algorithm

For each training step:

  1. Prepare input data and compute node features
  2. For each time step in trajectory:
    • Build computation graph with neighborhood connections
    • Forward pass through network
    • Compute loss depending on training strategy
    • Accumulate gradients
  3. Update parameters after accumulation window
  4. Optionally decay learning rate
  5. Every N steps: validate on full trajectories and save checkpoint if improved

Returns

  • Float32: Minimum validation loss achieved.

Notes

  • First norm_steps steps only accumulate normalization statistics
  • Validation using long trajectory rollouts occurs at checkpoint intervals
  • Checkpoints saved to cp_path/valid when validation loss improves
  • Final checkpoint always saved at cp_path
source

Normalization Setup

GraphNetSim.calc_normsFunction
calc_norms(dataset::Dataset, device::Function, args::Args)

Initialize and compute feature normalizers from dataset statistics.

Computes normalization statistics for edge features, node features, and output features based on metadata specifications. Supports offline min/max and mean/std normalization, boolean encoding, one-hot encoded features, and online accumulation strategies.

Arguments

  • dataset::Dataset: Dataset object containing feature metadata and specifications.
  • device::Function: Device placement function (cpudevice or gpudevice).
  • args::Args: Configuration struct with norm_steps for online normalizer accumulation.

Returns

  • Tuple: (quantities, e_norms, n_norms, o_norms)
    • quantities::Int: Total number of input feature dimensions
    • e_norms: Dictionary or single normalizer for edge features
    • n_norms::Dict: Dictionary mapping node feature names to normalizers
    • o_norms::Dict: Dictionary mapping output feature names to normalizers

Normalizer Types

  • NormaliserOfflineMinMax: Fixed min/max normalization with learnable target range
  • NormaliserOfflineMeanStd: Fixed mean/standard deviation normalization
  • NormaliserOnline: Accumulates statistics online during training

Notes

  • One-hot encoded Int32 features are expanded to multiple dimensions
  • Boolean features are mapped to [0.0, 1.0] range
  • Distance features add dimensions for domain boundary constraints
source

Main Evaluation Function

GraphNetSim.eval_networkFunction
eval_network(ds_path::String, cp_path::String, out_path::String, solver=nothing; start, stop, dt=nothing, saves, mse_steps, kws...)

Evaluate a trained Graph Neural Network simulator on test trajectories.

Loads a trained network from checkpoint, performs long-term trajectory rollouts, computes error metrics against ground truth, and saves results to disk.

Arguments

  • ds_path::String: Path to dataset directory containing test split.
  • cp_path::String: Path to checkpoint directory (contains model parameters).
  • out_path::String: Path where evaluation results are saved.
  • solver: ODE solver for long-term predictions (e.g., Tsit5()).

Keyword Arguments

  • start::Real: Start time for evaluation.
  • stop::Real: End time for evaluation.
  • saves::AbstractVector: Time points where solution is saved.
  • mse_steps::AbstractVector: Time points where error metrics are computed.
  • dt::Union{Nothing,Real}=nothing: Fixed timestep for solver (if applicable).
  • mps::Int=15: Number of message passing steps (must match training config).
  • layer_size::Int=128: Hidden layer size (must match training config).
  • hidden_layers::Int=2: Number of hidden layers (must match training config).
  • types_updated::Vector{Int}=[1]: Updated node types (must match training config).
  • use_cuda::Bool=true: Use CUDA GPU if available.
  • use_valid::Bool=true: Load from best validation checkpoint instead of final checkpoint.

Output

Saves results to out_path/{solver_name}/trajectories.h5:

  • Ground truth positions, velocities, accelerations
  • Predicted positions, velocities, accelerations
  • Prediction errors for each trajectory

Example

eval_network(
    "./data",
    "./checkpoints",
    "./results",
    Tsit5();
    start=0.0f0,
    stop=1.0f0,
    dt=0.01f0,
    saves=0.0:0.01:1.0,
    mse_steps=0.0:0.1:1.0
)
source
GraphNetSim.eval_network!Function
eval_network!(solver, gns::GraphNetwork, ds_test::Dataset, device::Function, out_path::String, start::Real, stop::Real, dt, saves, mse_steps, args::Args)

Perform evaluation loops and trajectory rollouts for all test samples.

Executes long-term predictions for each test trajectory, computes performance metrics relative to ground truth, and saves trajectories and errors to HDF5 format.

Arguments

  • solver: ODE solver for trajectory rollouts (or nothing for collocation).
  • gns::GraphNetwork: Trained graph network model.
  • ds_test::Dataset: Test dataset with trajectories.
  • device::Function: Device placement function.
  • out_path::String: Output directory for results.
  • start::Real: Evaluation start time.
  • stop::Real: Evaluation end time.
  • dt: Fixed timestep (or nothing for adaptive).
  • saves: Time points to save solution.
  • mse_steps: Time points to compute errors.
  • args::Args: Configuration parameters.

Algorithm

For each test trajectory:

  1. Extract initial conditions from data
  2. Create computation graph with graph network
  3. Roll out trajectory using ODE solver for specified duration
  4. Extract position, velocity, and acceleration from solution
  5. Compute mean squared error against ground truth
  6. Report cumulative error at specified time points

Returns

  • Tuple: (traj_ops, errors)
    • traj_ops::Dict: Dictionary of trajectories with ground truth and predictions
    • errors::Dict: Squared errors for each trajectory

Output Files

Creates {out_path}/{solver_name}/trajectories.h5 containing:

  • Ground truth and predicted trajectories
  • Error fields for each time step
  • Properly indexed for easy post-processing
source

Training Strategies

Abstract Base Type

GraphNetSim.prepare_trainingFunction
prepare_training(strategy)

Function that is executed once before training. Can be overridden by training strategies if necessary.

Arguments

  • strategy: Used training strategy.

Returns

  • Tuple containing the results of the function.
source
GraphNetSim.get_deltaFunction
get_delta(strategy, trajectory_length)

Returns the delta between samples in the training data.

Arguments

  • strategy: Used training strategy.
  • trajectory_length: Length of the trajectory (used for derivative strategies).

Returns

  • Delta between samples in the training data.
source
get_delta(::SolverStrategy, ::Integer)

Returns the delta (step size) for solver-based training strategies.

For most solver-based strategies, returns 1 (advancing by single timestep). Can be overridden by specific strategies (e.g., BatchingStrategy).

Arguments

  • strategy::SolverStrategy: Solver-based training strategy.
  • trajectory_length::Integer: Length of the trajectory (unused in base implementation).

Returns

  • Integer delta between training samples.
source
get_delta(strategy::BatchingStrategy, ::Integer)

Returns the number of steps per batch.

Arguments

  • strategy::BatchingStrategy: BatchingStrategy instance.
  • Unused trajectory length parameter.

Returns

  • Integer: Number of steps per batch (strategy.steps).
source
get_delta(strategy::DerivativeStrategy, trajectory_length)

Returns the effective trajectory length for derivative training.

If the strategy's window_size is greater than 0 and smaller than trajectory_length, returns window_size. Otherwise returns the full trajectory_length.

Arguments

  • strategy::DerivativeStrategy: Derivative-based strategy.
  • trajectory_length::Integer: Length of the trajectory.
source
GraphNetSim.init_train_stepFunction
init_train_step(strategy, t)

Function that is executed before each training sample.

Arguments

  • strategy: Used training strategy.
  • t: Tuple containing the variables necessary for initializing training.
  • ta: Tuple with additional variables returned from prepare_training.

Returns

  • Tuple containing the prepared variables for the training step (passed to train_step).
source
init_train_step(strategy::SolverStrategy, t::Tuple)

Initializes a training step for solver-based strategies.

Extracts initial conditions, packs state into ComponentArray format, and prepares ground truth data for ODE problem setup.

Arguments

  • strategy::SolverStrategy: Solver-based training strategy.
  • t::Tuple: Input tuple containing (gns, data, position, velocity, meta, output_fields, target_fields, node_type, mask, device, ...).

Returns

  • Tuple: Initialized data for training step.
source
init_train_step(strategy::BatchingStrategy, t::Tuple)

Initializes a training step for the BatchingStrategy.

Selects the next batch via nextBatch(), extracts initial conditions for that time window, packs state into ComponentArray, and prepares ground truth data.

Arguments

  • strategy::BatchingStrategy: BatchingStrategy instance.
  • t::Tuple: Input tuple (gns, data, meta, output_fields, target_fields, node_type, mask, val_mask, device, _, batches, show_progress_bars).

Returns

  • Tuple: Initialized batch data for training step.
source
init_train_step(strategy::DerivativeStrategy, t::Tuple)

Initializes a training step for derivative-based strategies.

Extracts target derivatives at a single datapoint and normalizes using network feature normalizers.

Arguments

  • strategy::DerivativeStrategy: Derivative-based strategy.
  • t::Tuple: Input tuple with network, data, and sampling information.
source
GraphNetSim.train_stepFunction
train_step(strategy, t)

Performs a single training step and returns the resulting gradients and loss.

Arguments

  • strategy: Training strategy that is used (determines dispatch).
  • t: Tuple that is returned from init_train_step.

Returns

  • Gradients for optimization step.
  • Loss for optimization step.
source
train_step(strategy::SolverStrategy, t::Tuple)

Performs one training step for solver-based strategies.

Constructs an ODE problem from the GNS model, solves it using the strategy's solver, and computes gradients via sensitivity analysis (adjoint method).

Arguments

  • strategy::SolverStrategy: Solver-based training strategy.
  • t::Tuple: Initialized data from init_train_step().

Returns

  • Tuple: (gradients, loss) - Gradients for optimization and scalar training loss.

Algorithm

  1. Create ODE right-hand side function using ode_func_train().
  2. Setup ODE problem with initial conditions and parameters.
  3. Compute loss via train_loss().
  4. Backpropagate through ODE solver using sensitivity algorithm.
  5. Return gradients and loss.
source
train_step(strategy::BatchingStrategy, t::Tuple)

Performs one training step for the BatchingStrategy.

Constructs an ODE problem for the selected batch, solves it using the strategy's solver, and computes gradients via sensitivity analysis. Updates batch loss.

Arguments

  • strategy::BatchingStrategy: BatchingStrategy instance.
  • t::Tuple: Data tuple from init_train_step().
source
train_step(strategy::DerivativeStrategy, t::Tuple)

Performs one training step for derivative-based strategies.

Evaluates network on graph at a single timepoint and computes loss against target derivatives. Computes gradients via backpropagation.

Arguments

  • strategy::DerivativeStrategy: Derivative-based strategy.
  • t::Tuple: Data tuple from init_train_step().
source
GraphNetSim.validation_stepFunction
validation_step(strategy, t)

Performs validation of a single trajectory. Should be overridden by training strategies to determine the simulation and data intervals before calling the inner function _validation_step.

Arguments

  • strategy: Type of training strategy (used for dispatch).
  • t: Tuple containing the variables necessary for validation.

Returns

  • Validation loss for the trajectory.
source
validation_step(strategy::SolverStrategy, t::Tuple)

Validation step for solver-based strategies.

Computes validation loss by rolling out the GNS model over the full validation trajectory and comparing predicted outputs with ground truth.

Arguments

  • strategy::SolverStrategy: Solver-based training strategy.
  • t::Tuple: Validation data tuple containing (gns, data, meta, ...).

Returns

  • Float32: Validation loss (mean squared error).
source
validation_step(strategy::DerivativeStrategy, t::Tuple)

Validation step for derivative-based strategies.

Computes validation loss by rolling out GNS model over trajectory window and comparing derivatives with ground truth.

Arguments

  • strategy::DerivativeStrategy: Derivative-based strategy.
  • t::Tuple: Validation data tuple.
source
GraphNetSim._validation_stepFunction
_validation_step(t, sim_interval, data_interval)

Inner function for validation of a single trajectory.

Arguments

  • t: Tuple containing the variables necessary for validation.
  • sim_interval: Interval that determines the simulated time for the validation.
  • data_interval: Interval that determines the indices of the timesteps in ground truth and prediction data.

Returns

  • Loss calculated on the difference between ground truth and prediction (via mse).
  • Ground truth data with data_interval as timesteps.
  • Prediction data with data_interval as timesteps.
source
GraphNetSim.batchTrajectoryFunction
batchTrajectory(strategy::BatchingStrategy, data::Dict)

Partitions a trajectory into time intervals (batches) for sequential training.

Divides the full trajectory duration into equal-sized time intervals, creating one Batch object per interval. Used for memory-efficient training on long sequences.

Arguments

  • strategy::BatchingStrategy: Batching strategy with interval specifications.
  • data::Dict: Data dictionary containing "dt" (timestep) and "trajectory_length".

Returns

  • Vector{Batch}: Array of Batch objects partitioning the trajectory.
source

Concrete Strategies

GraphNetSim.SingleShootingType
SingleShooting(tstart, dt, tstop, solver; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP()), solargs...)

The default solver based training that is normally used for NeuralODEs. Simulates the system from tstart to tstop and calculates the loss based on the difference between the prediction and the ground truth at the timesteps tstart:dt:tstop.

Arguments

  • tstart: Start time of the simulation.
  • dt: Interval at which the simulation is saved.
  • tstop: Stop time of the simulation.
  • solver: Solver that is used for simulating the system.

Keyword Arguments

  • sense = InterpolatingAdjoint(autojacvec = ZygoteVJP()): The sensitivity algorithm that is used for calculating the sensitivities.
  • solargs: Keyword arguments that are passed on to the solver.
source
GraphNetSim.MultipleShootingType
MultipleShooting(tstart, dt, tstop, solver, interval_size, continuity_term = 100; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true), solargs...)

Similar to SingleShooting, but splits the trajectory into intervals that are solved independently and then combined for the loss calculation. Useful if the network tends to get stuck in a local minimum when SingleShooting is used.

Arguments

  • tstart: Start time of the simulation.
  • dt: Interval at which the simulation is saved.
  • tstop: Stop time of the simulation.
  • solver: Solver that is used for simulating the system.
  • interval_size: Size of the intervals (i.e. number of datapoints in one interval).
  • continuity_term = 100: Factor by which the error between endpoints of consecutive intervals is multiplied.

Keyword Arguments

  • sense = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true): The sensitivity algorithm that is used for calculating the sensitivities.
  • solargs: Keyword arguments that are passed on to the solver.
source
GraphNetSim.DerivativeTrainingType
struct DerivativeTraining <: DerivativeStrategy

Derivative-based training strategy using finite-difference ground truth.

Compares network output with finite-difference derivatives from data. Faster than solver-based training, useful for initial model training. Supports temporal windowing and optional random shuffling.

source
GraphNetSim.BatchingStrategyType
struct BatchingStrategy <: SolverStrategy

Solver-based training strategy that batches long trajectories into segments.

Divides long trajectories into time intervals and solves/trains on each segment independently. Useful for memory-efficient training on long sequences.

source
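The constructors of the concrete strategies described in this section can be contrasted in a short hedged sketch (the BatchingStrategy constructor is not fully documented here, so it is omitted; interval sizes and times are illustrative):

```julia
using GraphNetSim, OrdinaryDiffEq

deriv  = DerivativeTraining()                                  # fast, finite-difference targets
single = SingleShooting(0.0f0, 0.01f0, 1.0f0, Tsit5())         # one rollout over [0, 1]
multi  = MultipleShooting(0.0f0, 0.01f0, 1.0f0, Tsit5(), 10)   # 10 datapoints per interval
```

Any of these can then be passed to train_network via the training_strategy keyword.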

Normalization and Data Statistics

Computing Normalization Statistics

GraphNetSim.data_minmaxFunction
data_minmax(path)

Calculates the minimum and maximum values for each numeric feature across all dataset partitions.

Iterates through training, validation, and test datasets to compute global min/max bounds for all numeric features (Int32 and Float32 types).

Arguments

  • path: Path to the dataset files.

Returns

  • Dict: Dictionary mapping feature names to [min, max] value pairs computed from all datasets.
source
GraphNetSim.data_meanstdFunction
data_meanstd(path)

Calculates the mean and standard deviation for each feature across all dataset partitions.

Arguments

  • path: Path to the dataset files.

Returns

  • Mean and standard deviation for each feature across the training, validation, and test sets
source
GraphNetSim.der_minmaxFunction
der_minmax(path)

Calculates the minimum and maximum across training, validation, and test sets for each numeric feature.

Combines results from both training/validation and test data to compute overall min/max bounds.

Arguments

  • path: Path to the dataset files.

Returns

  • Dict: Dictionary mapping feature names to [min, max] value pairs across all datasets.
source
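Computing these global statistics up front can be sketched as follows. A hedged sketch: the directory path and the feature name looked up at the end are illustrative assumptions, not part of the documented API.

```julia
using GraphNetSim

bounds  = data_minmax("./data")    # Dict: feature name => [min, max] over all splits
stats   = data_meanstd("./data")   # mean and std over train, valid, and test sets
dbounds = der_minmax("./data")     # Dict: feature name => [min, max] bounds

# Illustrative lookup; "velocity" is an assumed feature name from the metadata.
vmin, vmax = bounds["velocity"]
```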

Graph Construction and ODE Solving

Building Computation Graphs

GraphNetSim.build_graphFunction
build_graph(gns::GraphNetCore.GraphNetwork, data::Dict{String,Any}, datapoint::Integer, meta, node_type, device)

Construct a FeatureGraph from trajectory data at a specific time step.

Extracts position and velocity data from the trajectory dictionary at the given time point, then delegates to the second method to construct the graph with edge connectivity and normalized features.

Arguments

  • gns::GraphNetCore.GraphNetwork: Graph network model containing normalizers for features.
  • data::Dict{String,Any}: Dictionary containing trajectory data (position, velocity, etc.).
  • datapoint::Integer: Time step index to extract from the trajectory.
  • meta::Dict{String,Any}: Metadata dictionary with connectivity and feature settings.
  • node_type: One-hot encoded node type features.
  • device::Function: Device placement function (cpu or gpu).

Returns

  • GraphNetCore.FeatureGraph: Constructed graph with normalized node and edge features.
source
build_graph(gns::GraphNetCore.GraphNetwork, position, velocity, meta, node_type, mask, device)

Construct a FeatureGraph from position and velocity data with edge connectivity.

Computes edges based on spatial proximity using GPU-accelerated neighborhood search, calculates relative displacements and normalized distances. Node features are constructed from position, velocity, node type, and distance bounds to domain boundaries. All features are normalized using the normalizers stored in the model.

Arguments

  • gns::GraphNetCore.GraphNetwork: Graph network model containing normalizers.
  • position::AbstractArray: Particle positions with shape (dims, n_particles).
  • velocity::AbstractArray: Particle velocities with shape (dims, n_particles).
  • meta::Dict{String,Any}: Metadata with default_connectivity_radius, bounds, dims, input_features, and device settings.
  • node_type::AbstractArray: One-hot encoded node type features.
  • mask::AbstractVector: Indices of particles to include in the graph (fluid particles).
  • device::Function: Device placement function.

Returns

  • GraphNetCore.FeatureGraph: Graph with normalized node features, normalized edge features, sender and receiver indices.
source

ODE Integration

GraphNetSim.rolloutFunction
rollout(solver, gns::GraphNetwork, initial_state, output_fields, meta, target_fields, node_type, mask, val_mask, start, stop, dt, saves, device; pr=nothing)

Solves the ODE problem for a Graph Neural Network simulator using the given solver and computes solution at specified timesteps.

Solves the ODEProblem of the GNS model over the specified time interval. The function handles both fixed and adaptive timestep solvers, with optional progress reporting.

Arguments

  • solver: ODE solver algorithm (e.g., Tsit5(), RK4()) from OrdinaryDiffEq.jl.
  • gns::GraphNetwork: Graph neural network model to evaluate for dynamics.
  • initial_state::Dict: Dictionary with "position" and "velocity" arrays for initial conditions.
  • output_fields::Vector{String}: Names of output features predicted by the network.
  • meta::Dict: Dataset metadata containing feature dimensions and specifications.
  • target_fields::Vector{String}: Names of target (output) features.
  • node_type::Vector: One-hot encoded node type indicators.
  • mask::Vector: Boolean mask for valid nodes in graph.
  • val_mask::Vector: Validation/evaluation mask for output features.
  • start::Float32: Start time of ODE integration.
  • stop::Float32: Stop time of ODE integration.
  • dt::Union{Nothing,Float32}: Fixed timestep (if nothing, uses adaptive timestepping).
  • saves::Vector: Timesteps where solution should be saved.
  • device::Function: Device placement function (cpudevice or gpudevice).

Keyword Arguments

  • pr::Union{Nothing,ProgressBar}=nothing: Progress bar for tracking ODE solve.

Returns

  • sol: Solution object containing state trajectories at specified saves timesteps.

Notes

  • Uses ode_step_eval() as the right-hand side function for the ODE.
  • State is packed as ComponentArray with x (position) and dx (velocity) fields.
  • Output features are denormalized using stored normalizers before return.
source

Data Utilities and Conversion

Dataset Loading

GraphNetSim.keystrajFunction
keystraj(datafile::String)

Extract trajectory keys from a data file.

Opens either a .jld2 or .h5 file and returns all top-level keys, representing individual trajectories stored in the file.

Arguments

  • datafile::String: Path to the data file (.h5 or .jld2).

Returns

  • Array{String,1}: Array of trajectory keys from the file.
source
MLCore.getobs!Function
MLUtils.getobs!(buffer::Dict{String,Any}, ds::Dataset, idx::Int)

Load trajectory data into a pre-allocated buffer using the MLUtils interface.

Retrieves a single trajectory by index, populates all metadata and features into the provided buffer dictionary, and applies trajectory preparation (device transfer, masking, validation masks). Modifies the buffer in-place.

Arguments

  • buffer::Dict{String,Any}: Pre-allocated dictionary to store trajectory data (modified in-place).
  • ds::Dataset: The dataset object.
  • idx::Int: Index of the trajectory to retrieve (1-indexed).

Returns

  • Dict{String,Any}: The modified buffer containing the trajectory data.
source
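Loading trajectories through this interface can be sketched as follows. A hedged sketch, assuming the Dataset construction shown earlier in this reference; the paths are illustrative.

```julia
using GraphNetSim, MLUtils

args = Args()                         # keyword constructor assumed
ds   = Dataset(:train, "./data", args)

buffer = Dict{String,Any}()           # pre-allocated, reused across iterations
for i in 1:2
    traj = MLUtils.getobs!(buffer, ds, i)  # fills buffer in-place and returns it
    # traj now holds the trajectory's features, metadata, and masks
end
```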

Format Conversion

GraphNetSim.csv_to_hdf5Function
csv_to_hdf5(source::String, output::String; 
             dt::Float64=0.01, 
             n_trajectories::Int=1,
             dims::Vector{Int}=[1, 2],
             groupby_col::Symbol=:Idp,
             interpolation_scheme::String="pchip",
             pos_col_prefix::String="Points",
             vel_col_prefix::String="Vel",
             type_col::Symbol=:Type,
             extra_fields::Vector{Symbol}=Symbol[])

Convert particle trajectory data from CSV to HDF5 format with computed accelerations.

Arguments

  • source::String: Path to input CSV file
  • output::String: Path to output HDF5 file

Keyword Arguments

  • dt::Float64: Time step (default: 0.01)
  • n_trajectories::Int: Number of trajectories to process (default: 1)
  • dims::Vector{Int}: Spatial dimensions to extract, e.g. [1, 2] or [1, 3] (default: [1, 2])
  • groupby_col::Symbol: Column name for grouping particles (default: :Idp)
  • interpolation_scheme::String: Acceleration calculation method (default: "pchip")
    • "central_diff": Central difference scheme
    • "forward_diff": Forward difference scheme
    • "backward_diff": Backward difference scheme
    • "from_pos": Acceleration from position (2nd order central difference)
    • "pchip": PCHIP interpolation of velocity derivatives
    • "linear": LinearInterpolation
    • "quadratic": QuadraticInterpolation
    • "cubic_spline": CubicSpline
    • "quadratic_spline": QuadraticSpline
    • "cubic_hermite": CubicHermiteSpline
    • "lagrange": LagrangeInterpolation
    • "akima": AkimaInterpolation
  • pos_col_prefix::String: Prefix for position columns (default: "Points")
  • vel_col_prefix::String: Prefix for velocity columns (default: "Vel")
  • type_col::Symbol: Column name for particle type (default: :Type)
  • extra_fields::Vector{Symbol}: Additional CSV columns to copy to HDF5 (e.g. [:Mass, :Temperature])

Example

# 3D simulation with PCHIP interpolation
csv_to_hdf5("data/dam_break.csv", "data/dam_break_first.h5"; 
            dt=0.01, dims=[1, 2, 3], interpolation_scheme="pchip")

# 2D (skip y-dimension), with extra fields
csv_to_hdf5("data/input.csv", "output.h5"; 
            dims=[1, 3], interpolation_scheme="cubic_spline",
            extra_fields=[:Mass, :Pressure])
source

Visualization

VTK Export

GraphNetSim.visualizeFunction
visualize(inPath, outFolder, Position, subgroupTrajectory, Parameters, Trajectorys, NumberOfTimesteps)

Reads an HDF5 file into a dictionary and writes it in the VTK HDF5 file format using dictToVTKHDF().

The input file must have a group for each trajectory and subgroups for all trajectory components used; datasets are linked to each subgroup of a trajectory. If only certain trajectories should be written, specify them with a Vector{Int} (minimum index is 1). subgroupTrajectory gives the name of the subgroup that will be written, and each trajectory must have the same number of timesteps.

Timesteps can be detected automatically if a dataset "timesteps" is linked to each trajectory group. This dataset must contain the largest timestep number; timesteps 1:1:max will then be read. When using automatic timesteps, all dataset names to be read need "[Int]" appended.

Arguments

  • inPath::String: Complete path ending with .h5 file.
  • outFolder::String: Complete path for output folder.
  • Position::String: Name of HDF5 dataset for position data.
  • subgroupTrajectory::String: Name of subgroup within trajectory groups.
  • Parameters::Vector{String}: Names of other HDF5 datasets to read (optional).
  • Trajectorys::Union{Vector{Int},Nothing}: Trajectory indices to write (optional, auto-detected if nothing).
  • NumberOfTimesteps::Union{Vector{Int},Nothing}: Timesteps to write (optional, auto-detected if nothing).

Returns

  • Dict: Dictionary mapping (trajectory, dataset_name, timestep) tuples to numerical arrays.
source