API Reference
Configuration and Data Structures
Args
GraphNetSim.Args — Type
Args

Configuration structure for training and evaluating Graph Neural Network simulators.
Fields
Network Architecture
- mps::Integer=15: Number of message passing steps (higher = more expressive but slower)
- layer_size::Integer=128: Latent dimension for MLP hidden layers
- hidden_layers::Integer=2: Number of hidden layers in each MLP module
Training Configuration
- epochs::Integer=1: Number of passes over the entire dataset
- steps::Integer=10e6: Total number of training steps
- checkpoint::Integer=10000: Interval (in steps) for saving checkpoints
- norm_steps::Integer=1000: Steps to accumulate normalization statistics before weight updates
- batchsize::Integer=1: Batch size (currently limited to 1 - full trajectory per batch)
Normalization
- max_norm_steps::Integer=10.0f6: Maximum steps for online normalizer accumulation
Data Augmentation
- types_updated::Vector{Integer}=[1]: Node types whose features are predicted
- types_noisy::Vector{Integer}=[0]: Node types to which noise is added during training
- noise_stddevs::Vector{Float32}=[0.0f0]: Standard deviations for Gaussian noise (per type or broadcast)
Training Strategy
- training_strategy::TrainingStrategy=DerivativeTraining(): Method for computing loss
Hardware and Optimization
- use_cuda::Bool=true: Enable CUDA GPU acceleration (if available)
- gpu_device::Union{Nothing,CuDevice}: CUDA device to use (auto-selected if CUDA functional)
- optimizer_learning_rate_start::Float32=1.0f-4: Initial learning rate
- optimizer_learning_rate_stop::Union{Nothing,Float32}=nothing: Final learning rate (for decay schedule)
Validation
- show_progress_bars::Bool=true: Show training progress bars
- use_valid::Bool=true: Load validation checkpoint (best loss) instead of final checkpoint
- solver_valid::OrdinaryDiffEqAlgorithm=Tsit5(): ODE solver for validation rollouts
- solver_valid_dt::Union{Nothing,Float32}=nothing: Fixed timestep for validation solver
- reset_valid::Bool=false: Reset validation after loading checkpoint
- save_step::Bool=false: Save loss at every step (can create large log files)
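Assuming Args supports keyword construction (as the field defaults above suggest), a typical configuration might be sketched as follows; all values shown are illustrative:

```julia
using GraphNetSim

# Hypothetical configuration for a quick CPU run; every keyword shown
# corresponds to a field listed above, and unspecified fields keep
# their defaults.
args = Args(
    mps = 10,                  # fewer message passing steps than the default 15
    layer_size = 64,
    noise_stddevs = [1.0f-2],  # Gaussian noise on the types_noisy node types
    training_strategy = DerivativeTraining(),
    use_cuda = false
)
```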
Dataset
GraphNetSim.Dataset — Type
Dataset

A mutable structure containing trajectory data and associated metadata for training, validation, or testing.
Fields
- meta::Dict{String,Any}: Dictionary containing all metadata for the dataset, including feature names, trajectory information, and device settings.
- datafile::String: Path to the data file (usually .h5 or .jld2 format).
- lock::ReentrantLock: Lock for thread-safe access to the datafile during concurrent operations.
Dataset Constructors
GraphNetSim.Dataset — Method
Dataset(datafile::String, metafile::String, args)

Construct a Dataset from separate data and metadata files.
Validates that both files exist and have correct formats (.h5 or .jld2 for data, .json for metadata), then loads trajectories and merges metadata with provided arguments.
Arguments
- datafile::String: Path to the data file (.h5 or .jld2 containing trajectory data).
- metafile::String: Path to the metadata file (.json containing dataset configuration).
- args: A structure or object whose fields will be merged into the metadata dictionary.
Returns
Dataset: A new Dataset object initialized with the provided data and metadata.
Throws
ArgumentError: If datafile or metafile do not exist or have invalid formats.
GraphNetSim.Dataset — Method
Dataset(split::Symbol, path::String, args)

Construct a Dataset by specifying a split type and directory path.
Locates and loads the appropriate data file based on the split type (:train, :valid, or :test), expecting a "meta.json" file in the given directory.
Arguments
- split::Symbol: The dataset split, one of :train, :valid, or :test.
- path::String: Directory path containing the metadata file "meta.json" and the corresponding data file.
- args: A structure or object whose fields will be merged into the metadata dictionary.
Returns
Dataset: A new Dataset object initialized with data from the specified split.
Throws
ArgumentError: If split is invalid, if meta.json is not found, or if the corresponding data file cannot be found.
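The two constructors above can be sketched as follows; the file paths are hypothetical and `args` is assumed to be an Args instance:

```julia
using GraphNetSim

args = Args()  # assuming default keyword construction

# From explicit data and metadata files:
ds = Dataset("data/train.h5", "data/meta.json", args)

# Or by split, expecting "meta.json" plus "train.h5"/"train.jld2" in "./data":
ds_train = Dataset(:train, "./data", args)
ds_valid = Dataset(:valid, "./data", args)
```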
GraphNetSim.get_file — Function
get_file(split::Symbol, path::String)

Locate the data file corresponding to a dataset split in the given directory.
Attempts to find a .jld2 file first, then a .h5 file with the split name (e.g., "train.jld2" or "train.h5").
Arguments
- split::Symbol: The dataset split name (converted to string).
- path::String: Directory path to search for the data file.
Returns
String: Full path to the located data file.
Throws
ArgumentError: If no data file with the specified split name is found in the directory.
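For example (hypothetical directory layout):

```julia
using GraphNetSim

# Prefers "train.jld2" over "train.h5" in the given directory;
# throws ArgumentError if neither file exists.
datafile = get_file(:train, "./data")
```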
Training and Evaluation
Main Training Function
GraphNetSim.train_network — Function
train_network(opt, ds_path::String, cp_path::String; kws...)

Train a Graph Neural Network simulator on trajectory data.
Initializes a graph network, loads dataset, computes normalization statistics, and performs supervised training using the specified training strategy. Validates periodically on validation set and saves checkpoints for the best model.
Arguments
- opt: Optimizer configuration (e.g., Optimisers.Adam(1f-4)).
- ds_path::String: Path to dataset directory (must contain train, valid, test splits).
- cp_path::String: Path where checkpoints and logs are saved.
Keyword Arguments
- mps::Int=15: Number of message passing steps in the network.
- layer_size::Int=128: Latent dimension of hidden MLP layers.
- hidden_layers::Int=2: Number of hidden layers in each MLP.
- batchsize::Int=1: Batch size for training (default uses full trajectories).
- epochs::Int=1: Number of training epochs.
- steps::Int=10e6: Total number of training steps.
- checkpoint::Int=10000: Create checkpoint every N steps.
- norm_steps::Int=1000: Steps for accumulating normalization statistics without updates.
- max_norm_steps::Float32=10.0f6: Maximum steps for online normalizers.
- types_updated::Vector{Int}=[1]: Node types whose features are predicted.
- types_noisy::Vector{Int}=[0]: Node types to which noise is added.
- noise_stddevs::Vector{Float32}=[0.0f0]: Standard deviations for Gaussian noise.
- training_strategy::TrainingStrategy=DerivativeTraining(): Training method to use.
- use_cuda::Bool=true: Use CUDA GPU if available.
- solver_valid::OrdinaryDiffEqAlgorithm=Tsit5(): ODE solver for validation rollouts.
- solver_valid_dt::Union{Nothing,Float32}=nothing: Fixed timestep for validation (if set).
- optimizer_learning_rate_start::Float32=1.0f-4: Initial learning rate.
- optimizer_learning_rate_stop::Union{Nothing,Float32}=nothing: Final learning rate for decay schedule.
- show_progress_bars::Bool=true: Show training progress bars.
- use_valid::Bool=true: Use validation checkpoint for early stopping.
Returns
Float32: Minimum validation loss achieved during training.
Training Strategies
- DerivativeTraining: Train on current step derivatives (collocation)
- BatchingStrategy: Custom batching of trajectory segments
Example

#TODO add example with different training strategies
train_network(
Optimisers.Adam(1f-4),
"./data",
"./checkpoints";
epochs=10,
steps=50000,
mps=15,
layer_size=128,
training_strategy=DerivativeTraining()
)

Internal Training Functions
GraphNetSim.train_gns! — Function
train_gns!(gns::GraphNetwork, opt_state, ds_train::Dataset, ds_valid::Dataset, df_train, df_valid, df_step, device::Function, cp_path::String, args::Args)

Execute the main training loop for a Graph Neural Network simulator.
Performs supervised learning with periodic validation, checkpoint saving, and learning rate scheduling. Supports various training strategies (derivative, batching) and handles feature noise injection, gradient accumulation over multiple time steps, and online normalizer updates.
Arguments
- gns::GraphNetwork: Graph network model containing parameters and normalizers.
- opt_state: Optimizer state from Optimisers.jl.
- ds_train::Dataset: Training dataset with trajectories and metadata.
- ds_valid::Dataset: Validation dataset for monitoring training progress.
- df_train: DataFrame storing training loss at checkpoints.
- df_valid: DataFrame storing best validation losses.
- df_step: DataFrame storing loss at each step (if save_step enabled).
- device::Function: Device placement function (cpu_device or gpu_device).
- cp_path::String: Path for saving checkpoints and logs.
- args::Args: Configuration including optimizer, strategy, and training parameters.
Algorithm
For each training step:
- Prepare input data and compute node features
- For each time step in trajectory:
- Build computation graph with neighborhood connections
- Forward pass through network
- Compute loss depending on training strategy
- Accumulate gradients
- Update parameters after accumulation window
- Optionally decay learning rate
- Every N steps: validate on full trajectories and save checkpoint if improved
Returns
Float32: Minimum validation loss achieved.
Notes
- First norm_steps steps only accumulate normalization statistics
- Validation using long trajectory rollouts occurs at checkpoint intervals
- Checkpoints saved to cp_path/valid when validation loss improves
- Final checkpoint always saved at cp_path
Normalization Setup
GraphNetSim.calc_norms — Function
calc_norms(dataset::Dataset, device::Function, args::Args)

Initialize and compute feature normalizers from dataset statistics.
Computes normalization statistics for edge features, node features, and output features based on metadata specifications. Supports offline min/max and mean/std normalization, boolean encoding, one-hot encoded features, and online accumulation strategies.
Arguments
- dataset::Dataset: Dataset object containing feature metadata and specifications.
- device::Function: Device placement function (cpu_device or gpu_device).
- args::Args: Configuration struct with norm_steps for online normalizer accumulation.
Returns
Tuple: (quantities, e_norms, n_norms, o_norms)
- quantities::Int: Total number of input feature dimensions
- e_norms: Dictionary or single normalizer for edge features
- n_norms::Dict: Dictionary mapping node feature names to normalizers
- o_norms::Dict: Dictionary mapping output feature names to normalizers
Normalizer Types
- NormaliserOfflineMinMax: Fixed min/max normalization with learnable target range
- NormaliserOfflineMeanStd: Fixed mean/standard deviation normalization
- NormaliserOnline: Accumulates statistics online during training
Notes
- One-hot encoded Int32 features are expanded to multiple dimensions
- Boolean features are mapped to [0.0, 1.0] range
- Distance features add dimensions for domain boundary constraints
Main Evaluation Function
GraphNetSim.eval_network — Function
eval_network(ds_path::String, cp_path::String, out_path::String, solver=nothing; start, stop, dt=nothing, saves, mse_steps, kws...)

Evaluate a trained Graph Neural Network simulator on test trajectories.
Loads a trained network from checkpoint, performs long-term trajectory rollouts, computes error metrics against ground truth, and saves results to disk.
Arguments
- ds_path::String: Path to dataset directory containing test split.
- cp_path::String: Path to checkpoint directory (contains model parameters).
- out_path::String: Path where evaluation results are saved.
- solver: ODE solver for long-term predictions (e.g., Tsit5()).
Keyword Arguments
- start::Real: Start time for evaluation.
- stop::Real: End time for evaluation.
- saves::AbstractVector: Time points where solution is saved.
- mse_steps::AbstractVector: Time points where error metrics are computed.
- dt::Union{Nothing,Real}=nothing: Fixed timestep for solver (if applicable).
- mps::Int=15: Number of message passing steps (must match training config).
- layer_size::Int=128: Hidden layer size (must match training config).
- hidden_layers::Int=2: Number of hidden layers (must match training config).
- types_updated::Vector{Int}=[1]: Updated node types (must match training config).
- use_cuda::Bool=true: Use CUDA GPU if available.
- use_valid::Bool=true: Load from best validation checkpoint instead of final checkpoint.
Output
Saves results to out_path/{solver_name}/trajectories.h5:
- Ground truth positions, velocities, accelerations
- Predicted positions, velocities, accelerations
- Prediction errors for each trajectory
Example
eval_network(
"./data",
"./checkpoints",
"./results";
solver=Tsit5(),
start=0.0f0,
stop=1.0f0,
dt=0.01f0,
saves=0.0:0.01:1.0,
mse_steps=0.0:0.1:1.0
)

GraphNetSim.eval_network! — Function
eval_network!(solver, gns::GraphNetwork, ds_test::Dataset, device::Function, out_path::String, start::Real, stop::Real, dt, saves, mse_steps, args::Args)

Perform evaluation loops and trajectory rollouts for all test samples.
Executes long-term predictions for each test trajectory, computes performance metrics relative to ground truth, and saves trajectories and errors to HDF5 format.
Arguments
- solver: ODE solver for trajectory rollouts (or nothing for collocation).
- gns::GraphNetwork: Trained graph network model.
- ds_test::Dataset: Test dataset with trajectories.
- device::Function: Device placement function.
- out_path::String: Output directory for results.
- start::Real: Evaluation start time.
- stop::Real: Evaluation end time.
- dt: Fixed timestep (or nothing for adaptive).
- saves: Time points to save solution.
- mse_steps: Time points to compute errors.
- args::Args: Configuration parameters.
Algorithm
For each test trajectory:
- Extract initial conditions from data
- Create computation graph with graph network
- Roll out trajectory using ODE solver for specified duration
- Extract position, velocity, and acceleration from solution
- Compute mean squared error against ground truth
- Report cumulative error at specified time points
Returns
Tuple: (traj_ops, errors)
- traj_ops::Dict: Dictionary of trajectories with ground truth and predictions
- errors::Dict: Squared errors for each trajectory
Output Files
Creates {out_path}/{solver_name}/trajectories.h5 containing:
- Ground truth and predicted trajectories
- Error fields for each time step
- Properly indexed for easy post-processing
Training Strategies
Abstract Base Type
GraphNetSim.prepare_training — Function
prepare_training(strategy)

Function that is executed once before training. Can be overridden by training strategies if necessary.
Arguments
strategy: Used training strategy.
Returns
- Tuple containing the results of the function.
GraphNetSim.get_delta — Function
get_delta(strategy, trajectory_length)

Returns the delta between samples in the training data.
Arguments
- strategy: Used training strategy.
- trajectory_length: Trajectory length (used for Derivative strategies).
Returns
- Delta between samples in the training data.
get_delta(::SolverStrategy, ::Integer)

Returns the delta (step size) for solver-based training strategies.
For most solver-based strategies, returns 1 (advancing by single timestep). Can be overridden by specific strategies (e.g., BatchingStrategy).
Arguments
- strategy::SolverStrategy: Solver-based training strategy.
- trajectory_length::Integer: Length of the trajectory (unused in base implementation).
Returns
- Integer delta between training samples.
get_delta(strategy::BatchingStrategy, ::Integer)

Returns the number of steps per batch.
Arguments
- strategy::BatchingStrategy: BatchingStrategy instance.
- Unused trajectory length parameter.
Returns
Integer: Number of steps per batch (strategy.steps).
get_delta(strategy::DerivativeStrategy, trajectory_length)

Returns the effective trajectory length for derivative training.
If the strategy's window_size is > 0 and smaller than trajectory_length, returns window_size. Otherwise returns the full trajectory_length.
Arguments
- strategy::DerivativeStrategy: Derivative-based strategy.
- trajectory_length::Integer: Length of the trajectory.
GraphNetSim.init_train_step — Function
init_train_step(strategy, t)

Function that is executed before each training sample.
Arguments
- strategy: Used training strategy.
- t: Tuple containing the variables necessary for initializing training.
- ta: Tuple with additional variables that is returned from prepare_training.
Returns
- Tuple containing variables needed for train_step.
init_train_step(strategy::SolverStrategy, t::Tuple)

Initializes a training step for solver-based strategies.
Extracts initial conditions, packs state into ComponentArray format, and prepares ground truth data for ODE problem setup.
Arguments
- strategy::SolverStrategy: Solver-based training strategy.
- t::Tuple: Input tuple containing (gns, data, position, velocity, meta, output_fields, target_fields, node_type, mask, device, ...).
Returns
Tuple: Initialized data for training step.
init_train_step(strategy::BatchingStrategy, t::Tuple)

Initializes a training step for the BatchingStrategy.
Selects the next batch via nextBatch(), extracts initial conditions for that time window, packs state into ComponentArray, and prepares ground truth data.
Arguments
- strategy::BatchingStrategy: BatchingStrategy instance.
- t::Tuple: Input tuple (gns, data, meta, output_fields, target_fields, node_type, mask, val_mask, device, …, batches, show_progress_bars).
Returns
Tuple: Initialized batch data for training step.
init_train_step(strategy::DerivativeStrategy, t::Tuple)

Initializes a training step for derivative-based strategies.
Extracts target derivatives at a single datapoint and normalizes using network feature normalizers.
Arguments
- strategy::DerivativeStrategy: Derivative-based strategy.
- t::Tuple: Input tuple with network, data, and sampling information.
GraphNetSim.train_step — Function
train_step(strategy, t)

Performs a single training step and returns the resulting gradients and loss.
Arguments
- strategy: Solver strategy that is used for training.
- t: Tuple that is returned from init_train_step.
Returns
- Gradients for optimization step.
- Loss for optimization step.
train_step(strategy::SolverStrategy, t::Tuple)

Performs one training step for solver-based strategies.
Constructs an ODE problem from the GNS model, solves it using the strategy's solver, and computes gradients via sensitivity analysis (adjoint method).
Arguments
- strategy::SolverStrategy: Solver-based training strategy.
- t::Tuple: Initialized data from init_train_step().
Returns
Tuple: (gradients, loss) - Gradients for optimization and scalar training loss.
Algorithm
- Create ODE right-hand side function using ode_func_train().
- Set up ODE problem with initial conditions and parameters.
- Compute loss via train_loss().
- Backpropagate through ODE solver using sensitivity algorithm.
- Return gradients and loss.
train_step(strategy::BatchingStrategy, t::Tuple)

Performs one training step for the BatchingStrategy.
Constructs an ODE problem for the selected batch, solves it using the strategy's solver, and computes gradients via sensitivity analysis. Updates batch loss.
Arguments
- strategy::BatchingStrategy: BatchingStrategy instance.
- t::Tuple: Data tuple from init_train_step().
train_step(strategy::DerivativeStrategy, t::Tuple)

Performs one training step for derivative-based strategies.
Evaluates network on graph at a single timepoint and computes loss against target derivatives. Computes gradients via backpropagation.
Arguments
- strategy::DerivativeStrategy: Derivative-based strategy.
- t::Tuple: Data tuple from init_train_step().
GraphNetSim.validation_step — Function
validation_step(strategy, t)

Performs validation of a single trajectory. Should be overridden by training strategies to determine the simulation and data intervals before calling the inner function _validation_step.
Arguments
- strategy: Type of training strategy (used for dispatch).
- t: Tuple containing the variables necessary for validation.
Returns
- See _validation_step.
validation_step(strategy::SolverStrategy, t::Tuple)

Validation step for solver-based strategies.
Computes validation loss by rolling out the GNS model over the full validation trajectory and comparing predicted outputs with ground truth.
Arguments
- strategy::SolverStrategy: Solver-based training strategy.
- t::Tuple: Validation data tuple containing (gns, data, meta, ...).
Returns
Float32: Validation loss (mean squared error).
validation_step(strategy::DerivativeStrategy, t::Tuple)

Validation step for derivative-based strategies.
Computes validation loss by rolling out GNS model over trajectory window and comparing derivatives with ground truth.
Arguments
- strategy::DerivativeStrategy: Derivative-based strategy.
- t::Tuple: Validation data tuple.
GraphNetSim._validation_step — Function
_validation_step(t, sim_interval, data_interval)

Inner function for validation of a single trajectory.
Arguments
- t: Tuple containing the variables necessary for validation.
- sim_interval: Interval that determines the simulated time for the validation.
- data_interval: Interval that determines the indices of the timesteps in ground truth and prediction data.
Returns
- Loss calculated on the difference between ground truth and prediction (via mse).
- Ground truth data with data_interval as timesteps.
- Prediction data with data_interval as timesteps.
GraphNetSim.batchTrajectory — Function
batchTrajectory(strategy::BatchingStrategy, data::Dict)

Partitions a trajectory into time intervals (batches) for sequential training.
Divides the full trajectory duration into equal-sized time intervals, creating one Batch object per interval. Used for memory-efficient training on long sequences.
Arguments
- strategy::BatchingStrategy: Batching strategy with interval specifications.
- data::Dict: Data dictionary containing "dt" (timestep) and "trajectory_length".
Returns
Vector{Batch}: Array of Batch objects partitioning the trajectory.
Concrete Strategies
GraphNetSim.SingleShooting — Type
SingleShooting(tstart, dt, tstop, solver; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP()), solargs...)

The default solver-based training that is normally used for NeuralODEs. Simulates the system from tstart to tstop and calculates the loss based on the difference between the prediction and the ground truth at the timesteps tstart:dt:tstop.
Arguments
- tstart: Start time of the simulation.
- dt: Interval at which the simulation is saved.
- tstop: Stop time of the simulation.
- solver: Solver that is used for simulating the system.
Keyword Arguments
- sense = InterpolatingAdjoint(autojacvec = ZygoteVJP()): The sensitivity algorithm that is used for calculating the sensitivities.
- solargs: Keyword arguments that are passed on to the solver.
GraphNetSim.MultipleShooting — Type
MultipleShooting(tstart, dt, tstop, solver, interval_size, continuity_term = 100; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true), solargs...)

Similar to SingleShooting, but splits the trajectory into intervals that are solved independently and then combines them for loss calculation. Useful if the network tends to get stuck in a local minimum if SingleShooting is used.
Arguments
- tstart: Start time of the simulation.
- dt: Interval at which the simulation is saved.
- tstop: Stop time of the simulation.
- solver: Solver that is used for simulating the system.
- interval_size: Size of the intervals (i.e. number of datapoints in one interval).
- continuity_term = 100: Factor by which the error between points of consecutive intervals is multiplied.
Keyword Arguments
- sense = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true): The sensitivity algorithm that is used for calculating the sensitivities.
- solargs: Keyword arguments that are passed on to the solver.
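Both shooting strategies can be constructed from the signatures above and handed to train_network; a minimal sketch (times and interval size are illustrative):

```julia
using GraphNetSim, OrdinaryDiffEq

# Simulate t = 0..1, saving (and comparing) every 0.01:
single = SingleShooting(0.0f0, 0.01f0, 1.0f0, Tsit5())

# Same horizon, split into intervals of 10 datapoints with the
# default continuity penalty of 100:
multiple = MultipleShooting(0.0f0, 0.01f0, 1.0f0, Tsit5(), 10)

# Passed to training via the training_strategy keyword, e.g.:
# train_network(Optimisers.Adam(1f-4), "./data", "./checkpoints";
#               training_strategy = multiple)
```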
GraphNetSim.DerivativeTraining — Type
struct DerivativeTraining <: DerivativeStrategy

Derivative-based training strategy using finite-difference ground truth.
Compares network output with finite-difference derivatives from data. Faster than solver-based training, useful for initial model training. Supports temporal windowing and optional random shuffling.
GraphNetSim.BatchingStrategy — Type
struct BatchingStrategy <: SolverStrategy

Solver-based training strategy that batches long trajectories into segments.
Divides long trajectories into time intervals and solves/trains on each segment independently. Useful for memory-efficient training on long sequences.
Normalization and Data Statistics
Computing Normalization Statistics
GraphNetSim.data_minmax — Function
data_minmax(path)

Calculates the minimum and maximum values for each numeric feature across all dataset partitions.
Iterates through training, validation, and test datasets to compute global min/max bounds for all numeric features (Int32 and Float32 types).
Arguments
path: Path to the dataset files.
Returns
Dict: Dictionary mapping feature names to [min, max] value pairs computed from all datasets.
GraphNetSim.data_meanstd — Function
data_meanstd(path)

Calculates the mean and standard deviation for each feature in the given part of the dataset.
Arguments
path: Path to the dataset files.
Returns
- Mean and standard deviation computed over the training, validation, and test sets.
GraphNetSim.der_minmax — Function
der_minmax(path)

Calculates the minimum and maximum across training, validation, and test sets for each numeric feature.
Combines results from both training/validation and test data to compute overall min/max bounds.
Arguments
path: Path to the dataset files.
Returns
Dict: Dictionary mapping feature names to [min, max] value pairs across all datasets.
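A usage sketch for the three statistics helpers above (the path is hypothetical):

```julia
using GraphNetSim

bounds   = data_minmax("./data")   # Dict: feature name => [min, max] over all splits
stats    = data_meanstd("./data")  # per-feature mean and standard deviation
d_bounds = der_minmax("./data")    # combined derivative min/max bounds
```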
Graph Construction and ODE Solving
Building Computation Graphs
GraphNetSim.build_graph — Function
build_graph(gns::GraphNetCore.GraphNetwork, data::Dict{String,Any}, datapoint::Integer, meta, node_type, device)

Construct a FeatureGraph from trajectory data at a specific time step.
Extracts position and velocity data from the trajectory dictionary at the given time point, then delegates to the second method to construct the graph with edge connectivity and normalized features.
Arguments
- gns::GraphNetCore.GraphNetwork: Graph network model containing normalizers for features.
- data::Dict{String,Any}: Dictionary containing trajectory data (position, velocity, etc.).
- datapoint::Integer: Time step index to extract from the trajectory.
- meta::Dict{String,Any}: Metadata dictionary with connectivity and feature settings.
- node_type: One-hot encoded node type features.
- device::Function: Device placement function (cpu or gpu).
Returns
GraphNetCore.FeatureGraph: Constructed graph with normalized node and edge features.
build_graph(gns::GraphNetCore.GraphNetwork, position, velocity, meta, node_type, mask, device)

Construct a FeatureGraph from position and velocity data with edge connectivity.
Computes edges based on spatial proximity using GPU-accelerated neighborhood search, calculates relative displacements and normalized distances. Node features are constructed from position, velocity, node type, and distance bounds to domain boundaries. All features are normalized using the normalizers stored in the model.
Arguments
- gns::GraphNetCore.GraphNetwork: Graph network model containing normalizers.
- position::AbstractArray: Particle positions with shape (dims, n_particles).
- velocity::AbstractArray: Particle velocities with shape (dims, n_particles).
- meta::Dict{String,Any}: Metadata with default_connectivity_radius, bounds, dims, input_features, and device settings.
- node_type::AbstractArray: One-hot encoded node type features.
- mask::AbstractVector: Indices of particles to include in the graph (fluid particles).
- device::Function: Device placement function.
Returns
GraphNetCore.FeatureGraph: Graph with normalized node features, normalized edge features, sender and receiver indices.
ODE Integration
GraphNetSim.rollout — Function
rollout(solver, gns::GraphNetwork, initial_state, output_fields, meta, target_fields, node_type, mask, val_mask, start, stop, dt, saves, device; pr=nothing)

Solves the ODE problem for a Graph Neural Network simulator using the given solver and computes the solution at specified timesteps.
Solves the ODEProblem of the GNS model over the specified time interval. The function handles both fixed and adaptive timestep solvers, with optional progress reporting.
Arguments
- solver: ODE solver algorithm (e.g., Tsit5(), RK4()) from OrdinaryDiffEq.jl.
- gns::GraphNetwork: Graph neural network model to evaluate for dynamics.
- initial_state::Dict: Dictionary with "position" and "velocity" arrays for initial conditions.
- output_fields::Vector{String}: Names of output features predicted by the network.
- meta::Dict: Dataset metadata containing feature dimensions and specifications.
- target_fields::Vector{String}: Names of target (output) features.
- node_type::Vector: One-hot encoded node type indicators.
- mask::Vector: Boolean mask for valid nodes in graph.
- val_mask::Vector: Validation/evaluation mask for output features.
- start::Float32: Start time of ODE integration.
- stop::Float32: Stop time of ODE integration.
- dt::Union{Nothing,Float32}: Fixed timestep (if nothing, uses adaptive timestepping).
- saves::Vector: Timesteps where solution should be saved.
- device::Function: Device placement function (cpu_device or gpu_device).
Keyword Arguments
pr::Union{Nothing,ProgressBar}=nothing: Progress bar for tracking ODE solve.
Returns
- sol: Solution object containing state trajectories at the specified saves timesteps.
Notes
- Uses ode_step_eval() as the right-hand side function for the ODE.
- State is packed as a ComponentArray with x (position) and dx (velocity) fields.
- Output features are denormalized using stored normalizers before return.
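A call sketch mirroring the documented argument order; the model, state, metadata, masks, field names, and device function are assumed to have been prepared as described elsewhere in this reference:

```julia
using GraphNetSim, OrdinaryDiffEq

# gns, initial_state, meta, node_type, mask, val_mask and device are
# assumed to exist; the field names below are illustrative.
sol = rollout(Tsit5(), gns, initial_state,
              ["velocity"], meta, ["acceleration"],
              node_type, mask, val_mask,
              0.0f0, 1.0f0, nothing,      # start, stop, dt (nothing = adaptive)
              0.0f0:0.01f0:1.0f0, device) # saves, device
```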
Data Utilities and Conversion
Dataset Loading
GraphNetSim.keystraj — Function
keystraj(datafile::String)

Extract trajectory keys from a data file.
Opens either a .jld2 or .h5 file and returns all top-level keys, representing individual trajectories stored in the file.
Arguments
datafile::String: Path to the data file (.h5 or .jld2).
Returns
Array{String,1}: Array of trajectory keys from the file.
MLCore.getobs! — Function
MLUtils.getobs!(buffer::Dict{String,Any}, ds::Dataset, idx::Int)

Load trajectory data into a pre-allocated buffer using the MLUtils interface.
Retrieves a single trajectory by index, populates all metadata and features into the provided buffer dictionary, and applies trajectory preparation (device transfer, masking, validation masks). Modifies the buffer in-place.
Arguments
- buffer::Dict{String,Any}: Pre-allocated dictionary to store trajectory data (modified in-place).
- ds::Dataset: The dataset object.
- idx::Int: Index of the trajectory to retrieve (1-indexed).
Returns
Dict{String,Any}: The modified buffer containing the trajectory data.
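A sketch of in-place loading via the MLUtils interface (dataset path and construction are hypothetical):

```julia
using GraphNetSim, MLUtils

ds = Dataset(:train, "./data", Args())  # hypothetical dataset

buffer = Dict{String,Any}()
traj = getobs!(buffer, ds, 1)  # fills buffer with the first trajectory in-place
```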
Format Conversion
GraphNetSim.csv_to_hdf5 — Function
csv_to_hdf5(source::String, output::String;
dt::Float64=0.01,
n_trajectories::Int=1,
dims::Vector{Int}=[1, 2],
groupby_col::Symbol=:Idp,
interpolation_scheme::String="pchip",
pos_col_prefix::String="Points",
vel_col_prefix::String="Vel",
type_col::Symbol=:Type,
extra_fields::Vector{Symbol}=Symbol[])

Convert particle trajectory data from CSV to HDF5 format with computed accelerations.
Arguments
- source::String: Path to input CSV file
- output::String: Path to output HDF5 file
Keyword Arguments
- dt::Float64: Time step (default: 0.01)
- n_trajectories::Int: Number of trajectories to process (default: 1)
- dims::Vector{Int}: Spatial dimensions to extract, e.g. [1, 2] or [1, 3] (default: [1, 2])
- groupby_col::Symbol: Column name for grouping particles (default: :Idp)
- interpolation_scheme::String: Acceleration calculation method (default: "pchip")
- "central_diff": Central difference scheme
- "forward_diff": Forward difference scheme
- "backward_diff": Backward difference scheme
- "from_pos": Acceleration from position (2nd order central difference)
- "pchip": PCHIP interpolation of velocity derivatives
- "linear": LinearInterpolation
- "quadratic": QuadraticInterpolation
- "cubic_spline": CubicSpline
- "quadratic_spline": QuadraticSpline
- "cubic_hermite": CubicHermiteSpline
- "lagrange": LagrangeInterpolation
- "akima": AkimaInterpolation
- pos_col_prefix::String: Prefix for position columns (default: "Points")
- vel_col_prefix::String: Prefix for velocity columns (default: "Vel")
- type_col::Symbol: Column name for particle type (default: :Type)
- extra_fields::Vector{Symbol}: Additional CSV columns to copy to HDF5 (e.g. [:Mass, :Temperature])
Example
# 3D simulation with PCHIP interpolation
csv_to_hdf5("data/dam_break.csv", "data/dam_break_first.h5";
dt=0.01, dims=[1, 2, 3], interpolation_scheme="pchip")
# 2D (skip y-dimension), with extra fields
csv_to_hdf5("data/input.csv", "output.h5";
dims=[1, 3], interpolation_scheme="cubic_spline",
    extra_fields=[:Mass, :Pressure])

Visualization
VTK Export
GraphNetSim.visualize — Function
visualize(inPath, outFolder, Position, subgroupTrajectory, Parameters, Trajectorys, NumberOfTimesteps)

Reads an HDF5 file into a dictionary and writes it in the VTK HDF5 file format with dictToVTKHDF().
The input file must have a group for each trajectory and subgroups for all trajectories used, with datasets linked to each trajectory subgroup. If only certain trajectories should be written, specify them with a Vector{Int} (minimum index is 1). Validation contains the names of the subgroups that will be written, and each trajectory must have the same number of timesteps.
Timesteps can be automatically detected if dataset "timesteps" is linked to each trajectory group. This dataset must contain the number for the largest timestep and timesteps 1:1:max will be read. When using timesteps, all datasets to be read need "[Int]" appended.
Arguments
- inPath::String: Complete path ending with .h5 file.
- outFolder::String: Complete path for output folder.
- Position::String: Name of HDF5 dataset for position data.
- subgroupTrajectory::String: Name of subgroup within trajectory groups.
- Parameters::Vector{String}: Names of other HDF5 datasets to read (optional).
- Trajectorys::Union{Vector{Int},Nothing}: Trajectory indices to write (optional, auto-detected if nothing).
- NumberOfTimesteps::Union{Vector{Int},Nothing}: Timesteps to write (optional, auto-detected if nothing).
Returns
Dict: Dictionary mapping (trajectory, dataset_name, timestep) tuples to numerical arrays.
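A call sketch with hypothetical file and dataset names; trajectories and timesteps are auto-detected when the last two arguments are nothing:

```julia
using GraphNetSim

visualize("results/tsit5/trajectories.h5",  # input HDF5 file
          "results/vtk",                    # output folder
          "position",                       # name of the position dataset
          "prediction",                     # subgroup within each trajectory group
          ["velocity"],                     # further datasets to export
          nothing,                          # trajectory indices (auto-detect)
          nothing)                          # timesteps (auto-detect)
```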