import dataclasses
import functools
from graphcast import autoregressive
from graphcast import casting
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
import haiku as hk
import jax
import numpy as np
import xarray
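
# End-to-end wiring of a small, randomly initialized GraphCast model: build
# the wrapped predictor, extract train/eval batches from an example dataset,
# initialize parameters, and run an autoregressive rollout.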


def parse_file_parts(file_name):
  return dict(part.split("-", 1) for part in file_name.split("_"))


def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # One-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)

  # Cast inputs/outputs so the bulk of the computation can run in bfloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Normalizes inputs and predicts residuals with respect to the most recent
  # input frame, using the per-level statistics loaded in main.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce multi-step trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  loss, diagnostics = predictor.loss(inputs, targets, forcings)
  # Reduce each xarray loss/diagnostic to a scalar jax value for autodiff.
  return xarray_tree.map_structure(
      lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
      (loss, diagnostics))
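

# `hk.transform_with_state` turns the two functions above into pairs of pure
# `init`/`apply` functions, which is what allows them to be jitted below.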


def grads_fn(params, state, model_config, task_config, inputs, targets,
             forcings):
  def _aux(params, state, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), model_config, task_config,
        i, t, f)
    return loss, (diagnostics, next_state)
  (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
      _aux, has_aux=True)(params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads


# Jax doesn't seem to like passing configs as args through the jit. Passing
# them in via partial (instead of capture by closure) forces jax to invalidate
# the jit cache if the configs change.
def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)


# Always pass params and state, so the usage below is simpler.
def with_params(fn):
  return functools.partial(fn, params=params, state=state)


# Our models aren't stateful, so the state is always empty; just return the
# predictions. This is required by the rollout code, and generally simpler.
def drop_state(fn):
  return lambda **kw: fn(**kw)[0]


if __name__ == "__main__":
  # Placeholder path for an example batch dataset (NetCDF); replace with a
  # real file.
  file = "xxxx.nc"
  params = None
  state = {}

  # Hyperparameters for a small, randomly initialized model. The original
  # notebook reads these from slider widgets (`random_mesh_size.value`, etc.);
  # the concrete values below are assumed example settings.
  model_config = graphcast.ModelConfig(
      resolution=0,
      mesh_size=4,
      latent_size=32,
      gnn_msg_steps=4,
      hidden_layers=1,
      radius_query_fraction_edge_length=0.6)
  task_config = graphcast.TaskConfig(
      input_variables=graphcast.TASK.input_variables,
      target_variables=graphcast.TASK.target_variables,
      forcing_variables=graphcast.TASK.forcing_variables,
      pressure_levels=graphcast.PRESSURE_LEVELS[13],
      input_duration=graphcast.TASK.input_duration,
  )

  example_batch = xarray.load_dataset(file).compute()
  train_steps = 1
  eval_steps = 1

  train_inputs, train_targets, train_forcings = (
      data_utils.extract_inputs_targets_forcings(
          example_batch,
          target_lead_times=slice("6h", f"{train_steps*6}h"),
          **dataclasses.asdict(task_config)))
  eval_inputs, eval_targets, eval_forcings = (
      data_utils.extract_inputs_targets_forcings(
          example_batch,
          target_lead_times=slice("6h", f"{eval_steps*6}h"),
          **dataclasses.asdict(task_config)))
print("All Examples: ", example_batch.dims.mapping) print("Train Inputs: ", train_inputs.dims.mapping) print("Train Targets: ", train_targets.dims.mapping) print("Train Forcings:", train_forcings.dims.mapping) print("Eval Inputs: ", eval_inputs.dims.mapping) print("Eval Targets: ", eval_targets.dims.mapping) print("Eval Forcings: ", eval_forcings.dims.mapping)

  # Per-level normalization statistics used by the InputsAndResiduals wrapper.
  with open("stats/diffs_stddev_by_level.nc", "rb") as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()
  with open("stats/mean_by_level.nc", "rb") as f:
    mean_by_level = xarray.load_dataset(f).compute()
  with open("stats/stddev_by_level.nc", "rb") as f:
    stddev_by_level = xarray.load_dataset(f).compute()

  init_jitted = jax.jit(with_configs(run_forward.init))

  if params is None:
    params, state = init_jitted(
        rng=jax.random.PRNGKey(0),
        inputs=train_inputs,
        targets_template=train_targets,
        forcings=train_forcings)
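
  # Optional sanity check (not in the original flow): count trainable
  # parameters of the freshly initialized model.
  num_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
  print("Num params:", num_params)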

  loss_fn_jitted = drop_state(
      with_params(jax.jit(with_configs(loss_fn.apply))))
  grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
  run_forward_jitted = drop_state(
      with_params(jax.jit(with_configs(run_forward.apply))))
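
  # Sketch of a single gradient computation through the rollout, using the
  # jitted helpers above. The optimizer update is intentionally omitted; an
  # optax-style update could follow. All names used here are defined earlier
  # in this script.
  loss, diagnostics, next_state, grads = grads_fn_jitted(
      inputs=train_inputs, targets=train_targets, forcings=train_forcings)
  mean_grad = np.mean(
      jax.tree_util.tree_flatten(
          jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
  print(f"Loss: {float(loss):.4f}, Mean |grad|: {mean_grad:.6f}")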

  assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
      "Model resolution doesn't match the data resolution. You likely want to "
      "re-filter the dataset list, and download the correct data.")
print("Inputs: ", eval_inputs.dims.mapping) print("Targets: ", eval_targets.dims.mapping) print("Forcings:", eval_forcings.dims.mapping)

  # The NaN-filled targets act as a template: they carry the shape and
  # coordinates of the desired outputs without leaking any target values.
  predictions = rollout.chunked_prediction(
      run_forward_jitted,
      rng=jax.random.PRNGKey(0),
      inputs=eval_inputs,
      targets_template=eval_targets * np.nan,
      forcings=eval_forcings)
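
  # Minimal follow-up: inspect the rollout and, optionally, persist it.
  # "predictions.nc" is an arbitrary output path assumed for this sketch.
  print("Predictions:", predictions.dims.mapping)
  predictions.to_netcdf("predictions.nc")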