Training Strategies

MeshGraphNets.DerivativeTraining (Type)
DerivativeTraining(; window_size = 0)

Compares the derivative predicted by the system with the derivative computed from the data via finite differences. Useful for the initial training of the system, since it is faster than training with a solver.
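Because the targets come directly from the data, derivative training needs only one evaluation of the network per state and no ODE solve. The sketch below illustrates this in plain Julia. It assumes forward differences and a mean-squared-error loss, since the package's actual difference scheme and loss are not specified here. `finite_difference_targets`, `derivative_loss` and `model` are hypothetical names.

```julia
# Sketch only: forward differences and MSE are assumptions, not the package's documented scheme.

# Finite-difference derivative targets from a sampled trajectory.
# `u` is a (features x timesteps) matrix sampled every `dt`.
function finite_difference_targets(u::AbstractMatrix, dt::Real)
    # (u[:, k+1] - u[:, k]) / dt for k = 1, ..., T-1
    return (u[:, 2:end] .- u[:, 1:end-1]) ./ dt
end

# Mean squared error between predicted derivatives and the finite-difference targets.
function derivative_loss(du_pred::AbstractMatrix, u::AbstractMatrix, dt::Real)
    du_true = finite_difference_targets(u, dt)
    return sum(abs2, du_pred .- du_true) / length(du_true)
end

# Toy data: u(t) = t^2 sampled at t = 0.0, 0.5, 1.0, 1.5
dt = 0.5
u = reshape(collect(0.0:dt:1.5) .^ 2, 1, :)  # 1 x 4 matrix
du = finite_difference_targets(u, dt)        # [0.5 1.5 2.5]

# Hypothetical stand-in for the network: the exact derivative du/dt = 2*sqrt(u).
model(x) = 2 .* sqrt.(x)
du_pred = model(u[:, 1:end-1])               # evaluated at the first T-1 states
loss = derivative_loss(du_pred, u, dt)       # 0.25, caused only by the finite-difference error
```

Even an exact model gets a nonzero loss here, because forward differences only approximate the true derivative. This is one reason derivative training is suited to initial training rather than to final fitting.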

Keyword Arguments

  • window_size = 0: Number of steps of each trajectory, counted from its beginning, that are used for training. If it is zero, the whole trajectory is used.
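The effect of window_size can be pictured as slicing each trajectory before training. This is a minimal sketch. It assumes that one "step" is one saved time point, i.e. one column of a (features x timesteps) matrix. `apply_window` is a hypothetical name, and clamping windows that are longer than the trajectory is also an assumption.

```julia
# Hypothetical helper mirroring the documented `window_size` semantics.
# Assumption: one "step" is one saved time point, i.e. one column of a
# (features x timesteps) trajectory matrix.
function apply_window(trajectory::AbstractMatrix, window_size::Integer)
    window_size >= 0 || throw(ArgumentError("window_size must be non-negative"))
    window_size == 0 && return trajectory           # zero: use the whole trajectory
    n = min(window_size, size(trajectory, 2))       # assumption: clamp to the trajectory length
    return trajectory[:, 1:n]                       # first n steps, starting at the beginning
end

traj = reshape(collect(1.0:10.0), 1, :)             # one feature, 10 timesteps
size(apply_window(traj, 0), 2)                      # 10
size(apply_window(traj, 4), 2)                      # 4
```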
MeshGraphNets.SolverTraining (Type)
SolverTraining(tstart, dt, tstop, solver; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP()), solargs...)

The default solver-based training that is normally used for neural ODEs. Simulates the system from tstart to tstop and calculates the loss from the difference between the prediction and the ground truth at the timesteps tstart:dt:tstop.
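The following sketch shows what this loss compares. The model is integrated from tstart to tstop, its state is recorded at tstart:dt:tstop, and the mean squared error to the ground truth at those points is taken. It makes several simplifying choices. A hand-written explicit Euler integrator stands in for `solver`, and a scalar toy ODE replaces the graph network. The names `simulate`, `solver_loss` and `dynamics` are hypothetical. The package additionally differentiates through the solve using `sense`, which this sketch does not do.

```julia
# Stand-in for `solver`: fixed-step explicit Euler with `substeps` steps per save interval.
function simulate(f, u0, p, tstart, dt, tstop; substeps = 10)
    ts = tstart:dt:tstop                    # the save points tstart:dt:tstop
    us = Vector{typeof(u0)}(undef, length(ts))
    us[1] = u0
    h = dt / substeps
    u = u0
    for k in 2:length(ts)
        t = ts[k-1]
        for _ in 1:substeps
            u = u + h * f(u, p, t)
            t += h
        end
        us[k] = u                           # record the state at ts[k]
    end
    return us
end

# Mean squared error between the simulated states and the ground truth at the save points.
function solver_loss(f, p, u0, ground_truth, tstart, dt, tstop)
    pred = simulate(f, u0, p, tstart, dt, tstop)
    return sum(abs2, pred .- ground_truth) / length(ground_truth)
end

dynamics(u, p, t) = p * u                   # hypothetical learned model: du/dt = p * u
tstart, dt, tstop = 0.0, 0.1, 1.0
truth = exp.(-collect(tstart:dt:tstop))     # ground truth generated with p = -1

loss_good = solver_loss(dynamics, -1.0, 1.0, truth, tstart, dt, tstop)  # only Euler error remains
loss_bad  = solver_loss(dynamics, -0.5, 1.0, truth, tstart, dt, tstop)  # wrong parameter
```

Because every save point depends on the whole simulation up to that time, each loss evaluation requires a full solve. This is why this strategy is slower than derivative training.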

Arguments

  • tstart: Start time of the simulation.
  • dt: Interval at which the simulation is saved.
  • tstop: Stop time of the simulation.
  • solver: Solver that is used for simulating the system.

Keyword Arguments

  • sense = InterpolatingAdjoint(autojacvec = ZygoteVJP()): The sensitivity algorithm that is used for calculating the sensitivities.
  • solargs: Keyword arguments that are passed on to the solver.
MeshGraphNets.MultipleShooting (Type)
MultipleShooting(tstart, dt, tstop, solver, interval_size, continuity_term = 100; sense = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true), solargs...)

Similar to SolverTraining, but splits the trajectory into intervals that are solved independently and then combines their predictions for the loss calculation. Useful if the network tends to get stuck in a local minimum when SolverTraining is used.
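The splitting step can be sketched as follows. The sketch assumes that consecutive intervals share their boundary datapoint, which is the usual multiple-shooting layout, so that the end of one interval can be compared with the start of the next. The package's exact splitting may differ, and `split_intervals` is a hypothetical name.

```julia
# Hypothetical helper: split the time indices 1:n into intervals of `interval_size`
# datapoints. Assumption: consecutive intervals share their boundary point, so the
# end of one interval can be compared with the start of the next.
function split_intervals(n::Integer, interval_size::Integer)
    interval_size >= 2 || throw(ArgumentError("interval_size must be at least 2"))
    ranges = UnitRange{Int}[]
    start = 1
    while start < n
        stop = min(start + interval_size - 1, n)
        push!(ranges, start:stop)
        start = stop                      # the next interval starts at this boundary
    end
    return ranges
end

split_intervals(10, 4)                    # [1:4, 4:7, 7:10]
split_intervals(9, 4)                     # [1:4, 4:7, 7:9]
```

Each range can then be simulated on its own. In common multiple-shooting implementations, each interval starts from the ground-truth state at its first index, and this sketch assumes the same here. Under that setup, one long, hard-to-fit trajectory becomes several short ones.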

Arguments

  • tstart: Start time of the simulation.
  • dt: Interval at which the simulation is saved.
  • tstop: Stop time of the simulation.
  • solver: Solver that is used for simulating the system.
  • interval_size: Size of the intervals (i.e. number of datapoints in one interval).
  • continuity_term = 100: Factor by which the continuity error, i.e. the mismatch between the boundary points of consecutive intervals, is multiplied.
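The two parts of the loss can be sketched as a data term plus continuity_term times a boundary mismatch. The sketch assumes squared errors for both terms. It also assumes that the continuity error compares the last predicted point of each interval with the first predicted point of the next, because the exact penalty used by the package is not documented here. `multiple_shooting_loss` is a hypothetical name.

```julia
# Sketch of a multiple-shooting loss. Assumptions: squared errors for both terms,
# and the continuity error compares the last predicted point of each interval
# with the first predicted point of the next one.
function multiple_shooting_loss(preds, truths, continuity_term = 100)
    # data term: prediction vs. ground truth inside every interval
    data = sum(sum(abs2, p .- t) for (p, t) in zip(preds, truths))
    # continuity term: mismatch at the shared boundary of consecutive intervals
    continuity = sum((sum(abs2, preds[i][end] .- preds[i+1][1])
                      for i in 1:length(preds)-1); init = 0.0)
    return data + continuity_term * continuity
end

# Two intervals of three datapoints each, sharing one boundary point of the trajectory.
truths = [[1.0, 0.9, 0.8], [0.8, 0.7, 0.6]]
preds  = [[1.0, 0.9, 0.85], [0.8, 0.7, 0.6]]  # first interval drifts at its end

multiple_shooting_loss(preds, truths)         # ≈ 0.2525: the boundary mismatch dominates
multiple_shooting_loss(preds, truths, 0)      # ≈ 0.0025: data term only
```

Raising continuity_term pushes the optimizer to join the intervals into one consistent trajectory. Lowering it lets each interval fit its own data more freely.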

Keyword Arguments

  • sense = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true): The sensitivity algorithm that is used for calculating the sensitivities.
  • solargs: Keyword arguments that are passed on to the solver.