img_phy_sim.data
PhysGen Dataset Loader and Export Utilities
This module provides a PyTorch-compatible interface for loading, processing, visualizing, and exporting samples from the PhysGen dataset, a large-scale physics-based sound propagation dataset. It is designed for machine learning workflows that operate on image-like representations of physical simulations, such as sound maps derived from urban environments.
The module wraps the official PhysGen Hugging Face dataset and exposes flexible options for selecting dataset variations, input/output modalities, and training splits. In addition, it includes utilities for resizing tensors, retrieving individual samples, constructing PyTorch DataLoaders, and exporting datasets to disk as image files.
The core idea is to make PhysGen easy to integrate into deep learning pipelines while supporting common preprocessing steps such as resizing, modality switching, and derived target generation (e.g. complex-only outputs).
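For the complex-only case, the derived target is the negated difference between the variant's sound map and the baseline simulation's sound map (this mirrors the subtraction in `PhysGenDataset.__getitem__`; the helper name below is illustrative, shown here as a NumPy sketch):

```python
import numpy as np

def complex_only_target(soundmap: np.ndarray, baseline: np.ndarray) -> np.ndarray:
    """Reflection/diffraction-only target: negated (variant - baseline) difference."""
    diff = soundmap[0] - baseline[0]   # drop the channel dimension -> (H, W)
    return -diff[np.newaxis, ...]      # negate and restore shape (1, H, W)

# Where the variant map is louder than the baseline, the derived target is negative.
full = np.full((1, 4, 4), 0.8)
base = np.full((1, 4, 4), 0.5)
target = complex_only_target(full, base)  # every pixel becomes -0.3
```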
Main features:
- PyTorch `Dataset` implementation for the PhysGen dataset
- Support for multiple PhysGen variations (baseline, reflection, diffraction, combined)
- Configurable input types (OSM or base simulation)
- Configurable output types (standard sound maps or complex-only targets)
- Automatic image-to-tensor conversion and normalization
- Utilities for resizing tensors to model-friendly dimensions
- Easy access to single samples as NumPy arrays for debugging and visualization
- Dataset export functionality for saving inputs and targets as PNG images
- Command-line interface for batch dataset extraction
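One of the resizing utilities snaps height and width down to the nearest multiple of 14 (useful for backbones that require patch-aligned inputs). The size arithmetic alone can be sketched as follows; the function name is hypothetical, not part of the module:

```python
def snap_down_to_multiple(x: int, base: int = 14) -> int:
    # Largest multiple of `base` that does not exceed x
    return x - (x % base)

assert snap_down_to_multiple(256) == 252  # a 256x256 map shrinks to 252x252
assert snap_down_to_multiple(28) == 28    # already divisible: unchanged
```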
Typical workflow:
1. Initialize the dataset via `PhysGenDataset(...)`.
2. Wrap it in a PyTorch `DataLoader` using `get_dataloader()`.
3. Retrieve individual samples for inspection using `get_image()`.
4. Train or evaluate a model using the provided tensors.
5. Optionally export the dataset to disk using `save_dataset()`.
Dependencies:
- torch
- torchvision
- numpy
- cv2 (OpenCV)
- PIL
- datasets (Hugging Face)
- prime_printer (for progress visualization)
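The heavier dependencies are optional at import time: the module wraps them in try/except guards and records availability in module-level flags (`DATA_DEPENDENCIES_AVAILABLE`, `PRIME_AVAILABLE`). The pattern, as a minimal sketch:

```python
# Optional-dependency guard: the module can still be imported when the
# Hugging Face `datasets` package (or torch) is missing; features that
# need it are simply skipped based on the flag.
DATA_DEPENDENCIES_AVAILABLE = True
try:
    from datasets import load_dataset  # noqa: F401
except Exception:
    DATA_DEPENDENCIES_AVAILABLE = False
```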
PhysGen references:
- Dataset: https://huggingface.co/datasets/mspitzna/physicsgen
- Paper: https://arxiv.org/abs/2503.05333
- GitHub: https://github.com/physicsgen/physicsgen
Example:
```python
from img_phy_sim.data import PhysGenDataset, get_dataloader

dataset = PhysGenDataset(
    variation="sound_reflection",
    mode="train",
    input_type="osm",
    output_type="complex_only"
)

loader = get_dataloader(
    mode="train",
    variation="sound_reflection",
    input_type="osm",
    output_type="complex_only"
)

input_img, target_img, idx = dataset[0]
```
Author:
Tobia Ippolito
Classes:
- PhysGenDataset(Dataset)
  PyTorch dataset wrapper for PhysGen with flexible input/output configuration.

Functions:
- resize_tensor_to_divisible_by_14(tensor)
  Resize tensors so height and width are divisible by 14.
- get_dataloader(...)
  Create a PyTorch DataLoader for the PhysGen dataset.
- get_image(...)
  Retrieve a single dataset sample (optionally as NumPy arrays).
- save_dataset(...)
  Export PhysGen inputs and targets as PNG images to disk.
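For reference, `save_dataset()` pairs each exported input with its target through a shared file index; the `input_`/`target_` prefixes and `physgen_{idx}.png` naming come from the module source, while the `export_paths` helper below is a hypothetical sketch of that scheme:

```python
import os

def export_paths(output_osm_path: str, output_real_path: str, idx: int):
    # save_dataset() writes paired PNGs whose names share an index,
    # e.g. osm/input_physgen_3.png and real/target_physgen_3.png on POSIX.
    file_name = f"physgen_{idx}.png"
    return (os.path.join(output_osm_path, "input_" + file_name),
            os.path.join(output_real_path, "target_" + file_name))
```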
Also see:
- https://huggingface.co/datasets/mspitzna/physicsgen
- https://arxiv.org/abs/2503.05333
- https://github.com/physicsgen/physicsgen
1""" 2**PhysGen Dataset Loader and Export Utilities** 3 4This module provides a PyTorch-compatible interface for loading, processing, 5visualizing, and exporting samples from the PhysGen dataset, a large-scale 6physics-based sound propagation dataset. It is designed for machine learning 7workflows that operate on image-like representations of physical simulations, 8such as sound maps derived from urban environments. 9 10The module wraps the official PhysGen Hugging Face dataset and exposes flexible 11options for selecting dataset variations, input/output modalities, and training 12splits. In addition, it includes utilities for resizing tensors, retrieving 13individual samples, constructing PyTorch DataLoaders, and exporting datasets 14to disk as image files. 15 16The core idea is to make PhysGen easy to integrate into deep learning pipelines 17while supporting common preprocessing steps such as resizing, modality switching, 18and derived target generation (e.g. complex-only outputs). 19 20Main features: 21- PyTorch `Dataset` implementation for the PhysGen dataset 22- Support for multiple PhysGen variations (baseline, reflection, diffraction, combined) 23- Configurable input types (OSM or base simulation) 24- Configurable output types (standard sound maps or complex-only targets) 25- Automatic image-to-tensor conversion and normalization 26- Utilities for resizing tensors to model-friendly dimensions 27- Easy access to single samples as NumPy arrays for debugging and visualization 28- Dataset export functionality for saving inputs and targets as PNG images 29- Command-line interface for batch dataset extraction 30 31Typical workflow: 321. Initialize the dataset via `PhysGenDataset(...)`. 332. Wrap it in a PyTorch `DataLoader` using `get_dataloader()`. 343. Retrieve individual samples for inspection using `get_image()`. 354. Train or evaluate a model using the provided tensors. 365. Optionally export the dataset to disk using `save_dataset()`. 
37 38Dependencies: 39- torch 40- torchvision 41- numpy 42- cv2 (OpenCV) 43- PIL 44- datasets (Hugging Face) 45- prime_printer (for progress visualization) 46 47PhysGen references: 48- Dataset: https://huggingface.co/datasets/mspitzna/physicsgen 49- Paper: https://arxiv.org/abs/2503.05333 50- GitHub: https://github.com/physicsgen/physicsgen 51 52Example: 53```python 54from img_phy_sim.data import PhysGenDataset, get_dataloader 55 56dataset = PhysGenDataset( 57 variation="sound_reflection", 58 mode="train", 59 input_type="osm", 60 output_type="complex_only" 61) 62 63loader = get_dataloader( 64 mode="train", 65 variation="sound_reflection", 66 input_type="osm", 67 output_type="complex_only" 68) 69 70input_img, target_img, idx = dataset[0] 71``` 72 73Author:<br> 74Tobia Ippolito 75 76Classes: 77- PhysGenDataset(Dataset)<br> 78 PyTorch dataset wrapper for PhysGen with flexible input/output configuration. 79 80Functions: 81- resize_tensor_to_divisible_by_14(tensor)<br> 82 Resize tensors so height and width are divisible by 14. 83- get_dataloader(...)<br> 84 Create a PyTorch DataLoader for the PhysGen dataset. 85- get_image(...)<br> 86 Retrieve a single dataset sample (optionally as NumPy arrays). 87- save_dataset(...)<br> 88 Export PhysGen inputs and targets as PNG images to disk. 
89 90Also see: 91- https://huggingface.co/datasets/mspitzna/physicsgen 92- https://arxiv.org/abs/2503.05333 93- https://github.com/physicsgen/physicsgen 94""" 95 96 97 98# --------------- 99# >>> Imports <<< 100# --------------- 101import os 102import shutil 103 104import numpy as np 105import cv2 106 107DATA_DEPENDENCIES_AVAILABLE = True 108try: 109 from datasets import load_dataset 110except Exception: 111 DATA_DEPENDENCIES_AVAILABLE = False 112 113try: 114 import torch 115 import torch.nn.functional as F 116 from torch.utils.data import DataLoader, Dataset 117 # import torchvision.transforms as transforms 118 from torchvision import transforms 119except Exception: 120 DATA_DEPENDENCIES_AVAILABLE = False 121 122try: 123 import prime_printer as prime 124 PRIME_AVAILABLE = True 125except Exception: 126 PRIME_AVAILABLE = False 127 128 129if DATA_DEPENDENCIES_AVAILABLE: 130 # -------------- 131 # >>> Helper <<< 132 # -------------- 133 def resize_tensor_to_divisible_by_14(tensor: torch.Tensor) -> torch.Tensor: 134 """ 135 Resize a tensor to the next smaller (H, W) divisible by 14. 
136 137 Args: 138 tensor (torch.Tensor): Input tensor of shape (C, H, W) or (B, C, H, W) 139 140 Returns: 141 torch.Tensor: Resized tensor 142 """ 143 if tensor.dim() == 3: 144 c, h, w = tensor.shape 145 new_h = h - (h % 14) 146 new_w = w - (w % 14) 147 return F.interpolate(tensor.unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False).squeeze(0) 148 149 elif tensor.dim() == 4: 150 b, c, h, w = tensor.shape 151 new_h = h - (h % 14) 152 new_w = w - (w % 14) 153 return F.interpolate(tensor, size=(new_h, new_w), mode='bilinear', align_corners=False) 154 155 else: 156 raise ValueError("Tensor must be 3D (C, H, W) or 4D (B, C, H, W)") 157 158 159 160 # ------------------ 161 # >>> Main Class <<< 162 # ------------------ 163 class PhysGenDataset(Dataset): 164 165 def __init__(self, variation="sound_baseline", mode="train", input_type="osm", output_type="standard"): 166 """ 167 Loads PhysGen Dataset. 168 169 Parameters: 170 - variation : str <br> 171 Chooses the used dataset variant: sound_baseline, sound_reflection, sound_diffraction, sound_combined. 172 - mode : str <br> 173 Can be "train", "test", "eval". 
174 - input_type : str <br> 175 Defines the used Input -> "osm", "base_simulation" 176 - output_type : str <br> 177 Defines the Output -> "standard", "complex_only" 178 """ 179 self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' 180 # get data 181 self.dataset = load_dataset("mspitzna/physicsgen", name=variation, trust_remote_code=True) 182 # print("Keys:", self.dataset.keys()) 183 self.dataset = self.dataset[mode] 184 185 self.input_type = input_type 186 self.output_type = output_type 187 if self.input_type == "base_simulation" or self.output_type == "complex_only": 188 self.basesimulation_dataset = load_dataset("mspitzna/physicsgen", name="sound_baseline", trust_remote_code=True) 189 self.basesimulation_dataset = self.basesimulation_dataset[mode] 190 191 self.transform = transforms.Compose([ 192 transforms.ToTensor(), # Converts [0,255] PIL image to [0,1] FloatTensor 193 ]) 194 print(f"PhysGen ({variation}) Dataset for {mode} got created") 195 196 def __len__(self): 197 return len(self.dataset) 198 199 def __getitem__(self, idx): 200 sample = self.dataset[idx] 201 # print(sample) 202 # print(sample.keys()) 203 if self.input_type == "base_simulation": 204 input_img = self.basesimulation_dataset[idx]["soundmap"] 205 else: 206 input_img = sample["osm"] # PIL Image 207 target_img = sample["soundmap"] # PIL Image 208 209 input_img = self.transform(input_img) 210 target_img = self.transform(target_img) 211 212 # Fix real image size 512x512 > 256x256 213 input_img = F.interpolate(input_img.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False) 214 input_img = input_img.squeeze(0) 215 # target_img = target_img.unsqueeze(0) 216 217 # # change size 218 # input_img = resize_tensor_to_divisible_by_14(input_img) 219 # target_img = resize_tensor_to_divisible_by_14(target_img) 220 221 # add fake rgb 222 # if input_img.shape[0] == 1: # shape (B, 1, H, W) 223 # input_img = input_img.repeat(3, 1, 1) # make it 
(B, 3, H, W) 224 225 if self.output_type == "complex_only": 226 base_simulation_img = self.transform(self.basesimulation_dataset[idx]["soundmap"]) 227 # base_simulation_img = resize_tensor_to_divisible_by_14(self.transform(self.basesimulation_dataset[idx]["soundmap"])) 228 # target_img = torch.abs(target_img[0] - base_simulation_img[0]) 229 target_img = target_img[0] - base_simulation_img[0] 230 target_img = target_img.unsqueeze(0) 231 target_img *= -1 232 233 return input_img, target_img, idx 234 235 236 237 # ---------------------- 238 # >>> Loading Helper <<< 239 # ---------------------- 240 def get_dataloader(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True): 241 dataset = PhysGenDataset(mode=mode, variation=variation, input_type=input_type, output_type=output_type) 242 return DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=1) 243 244 245 246 def get_image(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True, 247 return_output=False, as_numpy_array=True): 248 dataset = PhysGenDataset(mode=mode, variation=variation, input_type=input_type, output_type=output_type) 249 loader = DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=1) 250 cur_data = next(iter(loader)) 251 input_ = cur_data[0] 252 output_ = cur_data[1] 253 254 if as_numpy_array: 255 input_ = input_.detach().cpu().numpy() 256 output_ = output_.detach().cpu().numpy() 257 258 # remove batch channel 259 input_ = np.squeeze(input_, axis=0) 260 output_ = np.squeeze(output_, axis=0) 261 262 if len(input_.shape) == 3: 263 input_ = np.squeeze(input_, axis=0) 264 output_ = np.squeeze(output_, axis=0) 265 266 # opencv format 267 # if np.issubdtype(img.dtype, np.floating): 268 # img = (img - img.min()) / (img.max() - img.min() + 1e-8) 269 # img = (img * 255).astype(np.uint8) 270 input_ = np.transpose(input_, (1, 0)) 271 output_ = np.transpose(output_, (1, 0)) 272 273 274 result = 
input_ 275 if return_output: 276 result = [input_, output_] 277 278 return result 279 280 281 282 def save_dataset(output_real_path, output_osm_path, 283 variation, input_type, output_type, 284 data_mode, 285 info_print=False, progress_print=True): 286 # Clearing 287 if os.path.exists(output_osm_path) and os.path.isdir(output_osm_path): 288 shutil.rmtree(output_osm_path) 289 os.makedirs(output_osm_path) 290 print(f"Cleared {output_osm_path}.") 291 else: 292 os.makedirs(output_osm_path) 293 print(f"Created {output_osm_path}.") 294 295 if os.path.exists(output_real_path) and os.path.isdir(output_real_path): 296 shutil.rmtree(output_real_path) 297 os.makedirs(output_real_path) 298 print(f"Cleared {output_real_path}.") 299 else: 300 os.makedirs(output_real_path) 301 print(f"Created {output_real_path}.") 302 303 # Load Dataset 304 dataset = PhysGenDataset(mode=data_mode, variation=variation, input_type=input_type, output_type=output_type) 305 data_len = len(dataset) 306 dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4) 307 308 # Save Dataset 309 for i, data in enumerate(dataloader): 310 if progress_print: 311 # print(f'Progress {i+1}/{data_len}') 312 if PRIME_AVAILABLE: 313 prime.get_progress_bar(total=data_len, progress=i+1, 314 should_clear=True, left_bar_char='|', right_bar_char='|', 315 progress_char='#', empty_char=' ', 316 front_message='Physgen Data Loading', back_message='', size=15) 317 318 input_img, target_img, idx = data 319 idx = idx[0].item() if isinstance(idx, torch.Tensor) else idx 320 321 # forward_img = inference_forward(input_img, model, DEVICE) 322 323 if info_print: 324 # print(f"Prediction shape [forward]: {forward_img.shape}") 325 print(f"Prediction shape [osm]: {input_img.shape}") 326 print(f"Prediction shape [target]: {target_img.shape}") 327 328 print(f"OSM Info:\n -> shape: {input_img.shape}\n -> min: {input_img.min()}, max: {input_img.max()}") 329 330 # Transform to Numpy 331 # pred_img = forward_img.squeeze(2) 332 
# if not (0 <= pred_img.min() <= 255 and 0 <= pred_img.max() <=255): 333 # raise ValueError(f"Prediction has values out of 0-256 range => min:{pred_img.min()}, max:{pred_img.max()}") 334 # if pred_img.max() <= 1.0: 335 # pred_img *= 255 336 # pred_img = pred_img.astype(np.uint8) 337 338 real_img = target_img.squeeze(0).cpu().squeeze(0).detach().numpy() 339 if not (0 <= real_img.min() <= 255 and 0 <= real_img.max() <=255): 340 raise ValueError(f"Real target has values out of 0-256 range => min:{real_img.min()}, max:{real_img.max()}") 341 if info_print: 342 print( f"\nReal target has values out of 0-256 range => min:{real_img.min()}, max:{real_img.max()}") 343 if real_img.max() <= 1.0: 344 real_img *= 255 345 if info_print: 346 print( f"Real target has values out of 0-256 range => min:{real_img.min()}, max:{real_img.max()}") 347 real_img = real_img.astype(np.uint8) 348 if info_print: 349 print( f"Real target has values out of 0-256 range => min:{real_img.min()}, max:{real_img.max()}") 350 351 if len(input_img.shape) == 4: 352 osm_img = input_img[0, 0].cpu().detach().numpy() 353 else: 354 osm_img = input_img[0].cpu().detach().numpy() 355 if not (0 <= osm_img.min() <= 255 and 0 <= osm_img.max() <=255): 356 raise ValueError(f"Real target has values out of 0-256 range => min:{osm_img.min()}, max:{osm_img.max()}") 357 if osm_img.max() <= 1.0: 358 osm_img *= 255 359 osm_img = osm_img.astype(np.uint8) 360 361 if info_print: 362 print(f"OSM Info:\n -> shape: {osm_img.shape}\n -> min: {osm_img.min()}, max: {osm_img.max()}") 363 364 # Save Results 365 file_name = f"physgen_{idx}.png" 366 367 # save pred image 368 # save_img = os.path.join(output_pred_path, file_name) 369 # cv2.imwrite(save_img, pred_img) 370 # print(f" -> saved pred at {save_img}") 371 372 # save real image 373 save_img = os.path.join(output_real_path, "target_"+file_name) 374 cv2.imwrite(save_img, real_img) 375 if info_print: 376 print(f" -> saved real at {save_img}") 377 378 # save osm image 379 save_img = 
os.path.join(output_osm_path, "input_"+file_name) 380 cv2.imwrite(save_img, osm_img) 381 if info_print: 382 print(f" -> saved osm at {save_img}") 383 print(f"\nSuccessfull saved {data_len} datapoints into {os.path.abspath(output_real_path)} & {os.path.abspath(output_osm_path)}") 384 385 386 387 # ----------------------- 388 # >>> Make it runable <<< 389 # ----------------------- 390 if __name__ == "__main__": 391 import argparse 392 393 parser = argparse.ArgumentParser(description="Save OSM and real PhysGen dataset images.") 394 395 parser.add_argument("--output_real_path", type=str, required=True, help="Path to save real target images") 396 parser.add_argument("--output_osm_path", type=str, required=True, help="Path to save OSM input images") 397 parser.add_argument("--variation", type=str, required=True, help="PhysGen variation (e.g. box_texture, box_position, etc.)") 398 parser.add_argument("--input_type", type=str, required=True, help="Input type (e.g. osm_depth)") 399 parser.add_argument("--output_type", type=str, required=True, help="Output type (e.g. real_depth)") 400 parser.add_argument("--data_mode", type=str, required=True, help="Data Mode: train, test, val") 401 parser.add_argument("--info_print", action="store_true", help="Print additional info") 402 parser.add_argument("--no_progress", action="store_true", help="Disable progress printing") 403 404 args = parser.parse_args() 405 406 save_dataset( 407 output_real_path=args.output_real_path, 408 output_osm_path=args.output_osm_path, 409 variation=args.variation, 410 input_type=args.input_type, 411 output_type=args.output_type, 412 data_mode=args.data_mode, 413 info_print=args.info_print, 414 progress_print=not args.no_progress 415 ) 416 417