img_phy_sim.data

PhysGen Dataset Loader and Export Utilities

This module provides a PyTorch-compatible interface for loading, processing, visualizing, and exporting samples from the PhysGen dataset, a large-scale physics-based sound propagation dataset. It is designed for machine learning workflows that operate on image-like representations of physical simulations, such as sound maps derived from urban environments.

The module wraps the official PhysGen Hugging Face dataset and exposes flexible options for selecting dataset variations, input/output modalities, and training splits. In addition, it includes utilities for resizing tensors, retrieving individual samples, constructing PyTorch DataLoaders, and exporting datasets to disk as image files.

The core idea is to make PhysGen easy to integrate into deep learning pipelines while supporting common preprocessing steps such as resizing, modality switching, and derived target generation (e.g. complex-only outputs).

Main features:

  • PyTorch Dataset implementation for the PhysGen dataset
  • Support for multiple PhysGen variations (baseline, reflection, diffraction, combined)
  • Configurable input types (OSM or base simulation)
  • Configurable output types (standard sound maps or complex-only targets)
  • Automatic image-to-tensor conversion and normalization
  • Utilities for resizing tensors to model-friendly dimensions
  • Easy access to single samples as NumPy arrays for debugging and visualization
  • Dataset export functionality for saving inputs and targets as PNG images
  • Command-line interface for batch dataset extraction
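
The "complex-only" target mentioned above is derived by subtracting the baseline simulation's sound map from the chosen variation's sound map and negating the result (as done in PhysGenDataset.__getitem__ for output_type="complex_only"). A minimal NumPy sketch of just that arithmetic (function and array names are illustrative):

```python
import numpy as np

def complex_only_target(variation_map: np.ndarray, baseline_map: np.ndarray) -> np.ndarray:
    # Negated difference between a variation sound map and the baseline
    # simulation: positive values mark where the variation is quieter
    # than the baseline, mirroring the complex_only target derivation.
    return -(variation_map - baseline_map)
```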

Typical workflow:

  1. Initialize the dataset via PhysGenDataset(...).
  2. Wrap it in a PyTorch DataLoader using get_dataloader().
  3. Retrieve individual samples for inspection using get_image().
  4. Train or evaluate a model using the provided tensors.
  5. Optionally export the dataset to disk using save_dataset().
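
Step 5's export begins by resetting both output directories before any PNGs are written. That clear-or-create pattern from save_dataset can be sketched with the standard library alone (the helper name is illustrative):

```python
import os
import shutil

def reset_dir(path: str) -> str:
    # Remove the directory tree if it already exists, then recreate it
    # empty -- the same clearing step save_dataset applies to its
    # output_osm_path and output_real_path before exporting.
    if os.path.exists(path) and os.path.isdir(path):
        shutil.rmtree(path)
        action = "Cleared"
    else:
        action = "Created"
    os.makedirs(path)
    return action
```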

Dependencies:

  • torch
  • torchvision
  • numpy
  • cv2 (OpenCV)
  • PIL
  • datasets (Hugging Face)
  • prime_printer (for progress visualization)

PhysGen references:

  • Dataset: https://huggingface.co/datasets/mspitzna/physicsgen
  • Paper: https://arxiv.org/abs/2503.05333
  • GitHub: https://github.com/physicsgen/physicsgen

Example:

from img_phy_sim.data import PhysGenDataset, get_dataloader

dataset = PhysGenDataset(
    variation="sound_reflection",
    mode="train",
    input_type="osm",
    output_type="complex_only"
)

loader = get_dataloader(
    mode="train",
    variation="sound_reflection",
    input_type="osm",
    output_type="complex_only"
)

input_img, target_img, idx = dataset[0]

Author:
Tobia Ippolito

Classes:

  • PhysGenDataset(Dataset)
    PyTorch dataset wrapper for PhysGen with flexible input/output configuration.

Functions:

  • resize_tensor_to_divisible_by_14(tensor)
    Resize tensors so height and width are divisible by 14.
  • get_dataloader(...)
    Create a PyTorch DataLoader for the PhysGen dataset.
  • get_image(...)
    Retrieve a single dataset sample (optionally as NumPy arrays).
  • save_dataset(...)
    Export PhysGen inputs and targets as PNG images to disk.
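
resize_tensor_to_divisible_by_14 rounds each spatial dimension down to the nearest multiple of 14 before interpolating. The size computation on its own (helper name is illustrative):

```python
def shrink_to_multiple_of_14(h: int, w: int) -> tuple:
    # Round height and width down to the nearest multiples of 14,
    # matching the size computation in resize_tensor_to_divisible_by_14.
    return h - (h % 14), w - (w % 14)

print(shrink_to_multiple_of_14(256, 300))  # (252, 294)
```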

Also see:

  • https://huggingface.co/datasets/mspitzna/physicsgen
  • https://arxiv.org/abs/2503.05333
  • https://github.com/physicsgen/physicsgen

  1"""
  2**PhysGen Dataset Loader and Export Utilities**
  3
  4This module provides a PyTorch-compatible interface for loading, processing,
  5visualizing, and exporting samples from the PhysGen dataset, a large-scale
  6physics-based sound propagation dataset. It is designed for machine learning
  7workflows that operate on image-like representations of physical simulations,
  8such as sound maps derived from urban environments.
  9
 10The module wraps the official PhysGen Hugging Face dataset and exposes flexible
 11options for selecting dataset variations, input/output modalities, and training
 12splits. In addition, it includes utilities for resizing tensors, retrieving
 13individual samples, constructing PyTorch DataLoaders, and exporting datasets
 14to disk as image files.
 15
 16The core idea is to make PhysGen easy to integrate into deep learning pipelines
 17while supporting common preprocessing steps such as resizing, modality switching,
 18and derived target generation (e.g. complex-only outputs).
 19
 20Main features:
 21- PyTorch `Dataset` implementation for the PhysGen dataset
 22- Support for multiple PhysGen variations (baseline, reflection, diffraction, combined)
 23- Configurable input types (OSM or base simulation)
 24- Configurable output types (standard sound maps or complex-only targets)
 25- Automatic image-to-tensor conversion and normalization
 26- Utilities for resizing tensors to model-friendly dimensions
 27- Easy access to single samples as NumPy arrays for debugging and visualization
 28- Dataset export functionality for saving inputs and targets as PNG images
 29- Command-line interface for batch dataset extraction
 30
 31Typical workflow:
 321. Initialize the dataset via `PhysGenDataset(...)`.
 332. Wrap it in a PyTorch `DataLoader` using `get_dataloader()`.
 343. Retrieve individual samples for inspection using `get_image()`.
 354. Train or evaluate a model using the provided tensors.
 365. Optionally export the dataset to disk using `save_dataset()`.
 37
 38Dependencies:
 39- torch
 40- torchvision
 41- numpy
 42- cv2 (OpenCV)
 43- PIL
 44- datasets (Hugging Face)
 45- prime_printer (for progress visualization)
 46
 47PhysGen references:
 48- Dataset: https://huggingface.co/datasets/mspitzna/physicsgen
 49- Paper: https://arxiv.org/abs/2503.05333
 50- GitHub: https://github.com/physicsgen/physicsgen
 51
 52Example:
 53```python
 54from img_phy_sim.data import PhysGenDataset, get_dataloader
 55
 56dataset = PhysGenDataset(
 57    variation="sound_reflection",
 58    mode="train",
 59    input_type="osm",
 60    output_type="complex_only"
 61)
 62
 63loader = get_dataloader(
 64    mode="train",
 65    variation="sound_reflection",
 66    input_type="osm",
 67    output_type="complex_only"
 68)
 69
 70input_img, target_img, idx = dataset[0]
 71```
 72
 73Author:<br>
 74Tobia Ippolito
 75
 76Classes:
 77- PhysGenDataset(Dataset)<br>
 78    PyTorch dataset wrapper for PhysGen with flexible input/output configuration.
 79
 80Functions:
 81- resize_tensor_to_divisible_by_14(tensor)<br>
 82    Resize tensors so height and width are divisible by 14.
 83- get_dataloader(...)<br>
 84    Create a PyTorch DataLoader for the PhysGen dataset.
 85- get_image(...)<br>
 86    Retrieve a single dataset sample (optionally as NumPy arrays).
 87- save_dataset(...)<br>
 88    Export PhysGen inputs and targets as PNG images to disk.
 89
 90Also see:
 91- https://huggingface.co/datasets/mspitzna/physicsgen
 92- https://arxiv.org/abs/2503.05333
 93- https://github.com/physicsgen/physicsgen
 94"""



# ---------------
# >>> Imports <<<
# ---------------
import os
import shutil

import numpy as np
import cv2

DATA_DEPENDENCIES_AVAILABLE = True
try:
    from datasets import load_dataset
except Exception:
    DATA_DEPENDENCIES_AVAILABLE = False

try:
    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, Dataset
    from torchvision import transforms
except Exception:
    DATA_DEPENDENCIES_AVAILABLE = False

try:
    import prime_printer as prime
    PRIME_AVAILABLE = True
except Exception:
    PRIME_AVAILABLE = False


if DATA_DEPENDENCIES_AVAILABLE:
    # --------------
    # >>> Helper <<<
    # --------------
    def resize_tensor_to_divisible_by_14(tensor: torch.Tensor) -> torch.Tensor:
        """
        Resize a tensor to the next smaller (H, W) divisible by 14.

        Args:
            tensor (torch.Tensor): Input tensor of shape (C, H, W) or (B, C, H, W)

        Returns:
            torch.Tensor: Resized tensor
        """
        if tensor.dim() == 3:
            c, h, w = tensor.shape
            new_h = h - (h % 14)
            new_w = w - (w % 14)
            return F.interpolate(tensor.unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False).squeeze(0)

        elif tensor.dim() == 4:
            b, c, h, w = tensor.shape
            new_h = h - (h % 14)
            new_w = w - (w % 14)
            return F.interpolate(tensor, size=(new_h, new_w), mode='bilinear', align_corners=False)

        else:
            raise ValueError("Tensor must be 3D (C, H, W) or 4D (B, C, H, W)")


    # ------------------
    # >>> Main Class <<<
    # ------------------
    class PhysGenDataset(Dataset):

        def __init__(self, variation="sound_baseline", mode="train", input_type="osm", output_type="standard"):
            """
            Loads the PhysGen dataset.

            Parameters:
            - variation : str <br>
                Selects the dataset variant: "sound_baseline", "sound_reflection", "sound_diffraction", "sound_combined".
            - mode : str <br>
                One of "train", "test", "eval".
            - input_type : str <br>
                Selects the input modality: "osm" or "base_simulation".
            - output_type : str <br>
                Selects the target type: "standard" or "complex_only".
            """
            self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
            # Load the requested variation and split
            self.dataset = load_dataset("mspitzna/physicsgen", name=variation, trust_remote_code=True)
            self.dataset = self.dataset[mode]

            self.input_type = input_type
            self.output_type = output_type
            # The baseline simulation is needed either as input or to derive complex-only targets
            if self.input_type == "base_simulation" or self.output_type == "complex_only":
                self.basesimulation_dataset = load_dataset("mspitzna/physicsgen", name="sound_baseline", trust_remote_code=True)
                self.basesimulation_dataset = self.basesimulation_dataset[mode]

            self.transform = transforms.Compose([
                transforms.ToTensor(),  # Converts [0,255] PIL image to [0,1] FloatTensor
            ])
            print(f"PhysGen ({variation}) dataset for '{mode}' created")

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            sample = self.dataset[idx]
            if self.input_type == "base_simulation":
                input_img = self.basesimulation_dataset[idx]["soundmap"]
            else:
                input_img = sample["osm"]  # PIL Image
            target_img = sample["soundmap"]  # PIL Image

            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

            # Downscale the input from 512x512 to 256x256
            input_img = F.interpolate(input_img.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
            input_img = input_img.squeeze(0)

            if self.output_type == "complex_only":
                # Target becomes the negated difference to the baseline simulation
                base_simulation_img = self.transform(self.basesimulation_dataset[idx]["soundmap"])
                target_img = target_img[0] - base_simulation_img[0]
                target_img = target_img.unsqueeze(0)
                target_img *= -1

            return input_img, target_img, idx



    # ----------------------
    # >>> Loading Helper <<<
    # ----------------------
    def get_dataloader(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True):
        dataset = PhysGenDataset(mode=mode, variation=variation, input_type=input_type, output_type=output_type)
        return DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=1)



    def get_image(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True,
                return_output=False, as_numpy_array=True):
        dataset = PhysGenDataset(mode=mode, variation=variation, input_type=input_type, output_type=output_type)
        loader = DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=1)
        cur_data = next(iter(loader))
        input_ = cur_data[0]
        output_ = cur_data[1]

        if as_numpy_array:
            input_ = input_.detach().cpu().numpy()
            output_ = output_.detach().cpu().numpy()

            # remove the batch dimension
            input_ = np.squeeze(input_, axis=0)
            output_ = np.squeeze(output_, axis=0)

            # remove a singleton channel dimension, if present
            if len(input_.shape) == 3:
                input_ = np.squeeze(input_, axis=0)
                output_ = np.squeeze(output_, axis=0)

            # swap axes for OpenCV-style indexing
            input_ = np.transpose(input_, (1, 0))
            output_ = np.transpose(output_, (1, 0))

        result = input_
        if return_output:
            result = [input_, output_]

        return result



    def save_dataset(output_real_path, output_osm_path,
                    variation, input_type, output_type,
                    data_mode,
                    info_print=False, progress_print=True):
        # Clear / create output directories
        if os.path.exists(output_osm_path) and os.path.isdir(output_osm_path):
            shutil.rmtree(output_osm_path)
            os.makedirs(output_osm_path)
            print(f"Cleared {output_osm_path}.")
        else:
            os.makedirs(output_osm_path)
            print(f"Created {output_osm_path}.")

        if os.path.exists(output_real_path) and os.path.isdir(output_real_path):
            shutil.rmtree(output_real_path)
            os.makedirs(output_real_path)
            print(f"Cleared {output_real_path}.")
        else:
            os.makedirs(output_real_path)
            print(f"Created {output_real_path}.")

        # Load Dataset
        dataset = PhysGenDataset(mode=data_mode, variation=variation, input_type=input_type, output_type=output_type)
        data_len = len(dataset)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

        # Save Dataset
        for i, data in enumerate(dataloader):
            if progress_print and PRIME_AVAILABLE:
                prime.get_progress_bar(total=data_len, progress=i+1,
                                    should_clear=True, left_bar_char='|', right_bar_char='|',
                                    progress_char='#', empty_char=' ',
                                    front_message='Physgen Data Loading', back_message='', size=15)

            input_img, target_img, idx = data
            idx = idx[0].item() if isinstance(idx, torch.Tensor) else idx

            if info_print:
                print(f"Input shape [osm]: {input_img.shape}")
                print(f"Target shape: {target_img.shape}")
                print(f"OSM Info:\n    -> shape: {input_img.shape}\n    -> min: {input_img.min()}, max: {input_img.max()}")

            # Transform target to a uint8 NumPy image
            real_img = target_img.squeeze(0).squeeze(0).detach().cpu().numpy()
            if not (0 <= real_img.min() <= 255 and 0 <= real_img.max() <= 255):
                raise ValueError(f"Real target has values out of the 0-255 range => min:{real_img.min()}, max:{real_img.max()}")
            if real_img.max() <= 1.0:
                real_img *= 255
            real_img = real_img.astype(np.uint8)
            if info_print:
                print(f"Real target value range => min:{real_img.min()}, max:{real_img.max()}")

            # Transform input to a uint8 NumPy image
            if len(input_img.shape) == 4:
                osm_img = input_img[0, 0].cpu().detach().numpy()
            else:
                osm_img = input_img[0].cpu().detach().numpy()
            if not (0 <= osm_img.min() <= 255 and 0 <= osm_img.max() <= 255):
                raise ValueError(f"OSM input has values out of the 0-255 range => min:{osm_img.min()}, max:{osm_img.max()}")
            if osm_img.max() <= 1.0:
                osm_img *= 255
            osm_img = osm_img.astype(np.uint8)

            if info_print:
                print(f"OSM Info:\n    -> shape: {osm_img.shape}\n    -> min: {osm_img.min()}, max: {osm_img.max()}")

            # Save Results
            file_name = f"physgen_{idx}.png"

            # save real image
            save_img = os.path.join(output_real_path, "target_"+file_name)
            cv2.imwrite(save_img, real_img)
            if info_print:
                print(f"    -> saved real at {save_img}")

            # save osm image
            save_img = os.path.join(output_osm_path, "input_"+file_name)
            cv2.imwrite(save_img, osm_img)
            if info_print:
                print(f"    -> saved osm at {save_img}")
        print(f"\nSuccessfully saved {data_len} datapoints into {os.path.abspath(output_real_path)} & {os.path.abspath(output_osm_path)}")



    # ------------------------
    # >>> Make it runnable <<<
    # ------------------------
    if __name__ == "__main__":
        import argparse

        parser = argparse.ArgumentParser(description="Save OSM and real PhysGen dataset images.")

        parser.add_argument("--output_real_path", type=str, required=True, help="Path to save real target images")
        parser.add_argument("--output_osm_path", type=str, required=True, help="Path to save OSM input images")
        parser.add_argument("--variation", type=str, required=True, help="PhysGen variation: sound_baseline, sound_reflection, sound_diffraction, sound_combined")
        parser.add_argument("--input_type", type=str, required=True, help="Input type: osm or base_simulation")
        parser.add_argument("--output_type", type=str, required=True, help="Output type: standard or complex_only")
        parser.add_argument("--data_mode", type=str, required=True, help="Data mode: train, test, eval")
        parser.add_argument("--info_print", action="store_true", help="Print additional info")
        parser.add_argument("--no_progress", action="store_true", help="Disable progress printing")

        args = parser.parse_args()

        save_dataset(
            output_real_path=args.output_real_path,
            output_osm_path=args.output_osm_path,
            variation=args.variation,
            input_type=args.input_type,
            output_type=args.output_type,
            data_mode=args.data_mode,
            info_print=args.info_print,
            progress_print=not args.no_progress
        )
DATA_DEPENDENCIES_AVAILABLE = True
def resize_tensor_to_divisible_by_14(tensor: torch.Tensor) -> torch.Tensor:
134    def resize_tensor_to_divisible_by_14(tensor: torch.Tensor) -> torch.Tensor:
135        """
136        Resize a tensor to the next smaller (H, W) divisible by 14.
137        
138        Args:
139            tensor (torch.Tensor): Input tensor of shape (C, H, W) or (B, C, H, W)
140        
141        Returns:
142            torch.Tensor: Resized tensor
143        """
144        if tensor.dim() == 3:
145            c, h, w = tensor.shape
146            new_h = h - (h % 14)
147            new_w = w - (w % 14)
148            return F.interpolate(tensor.unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False).squeeze(0)
149        
150        elif tensor.dim() == 4:
151            b, c, h, w = tensor.shape
152            new_h = h - (h % 14)
153            new_w = w - (w % 14)
154            return F.interpolate(tensor, size=(new_h, new_w), mode='bilinear', align_corners=False)
155        
156        else:
157            raise ValueError("Tensor must be 3D (C, H, W) or 4D (B, C, H, W)")

Resize a tensor to the next smaller (H, W) divisible by 14.

Args: tensor (torch.Tensor): Input tensor of shape (C, H, W) or (B, C, H, W)

Returns: torch.Tensor: Resized tensor

class PhysGenDataset(typing.Generic[+_T_co]):
164    class PhysGenDataset(Dataset):
165
166        def __init__(self, variation="sound_baseline", mode="train", input_type="osm", output_type="standard"):
167            """
168            Loads PhysGen Dataset.
169
170            Parameters:
171            - variation : str <br>
172                Chooses the used dataset variant: sound_baseline, sound_reflection, sound_diffraction, sound_combined.
173            - mode : str <br>
174                Can be "train", "test", "eval".
175            - input_type : str <br>
176                Defines the used Input -> "osm", "base_simulation"
177            - output_type : str <br>
178                Defines the Output -> "standard", "complex_only"
179            """
180            self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
181            # get data
182            self.dataset = load_dataset("mspitzna/physicsgen", name=variation, trust_remote_code=True)
183            # print("Keys:", self.dataset.keys())
184            self.dataset = self.dataset[mode]
185            
186            self.input_type = input_type
187            self.output_type = output_type
188            if self.input_type == "base_simulation" or self.output_type == "complex_only":
189                self.basesimulation_dataset = load_dataset("mspitzna/physicsgen", name="sound_baseline", trust_remote_code=True)
190                self.basesimulation_dataset = self.basesimulation_dataset[mode]
191
192            self.transform = transforms.Compose([
193                transforms.ToTensor(),  # Converts [0,255] PIL image to [0,1] FloatTensor
194            ])
195            print(f"PhysGen ({variation}) Dataset for {mode} got created")
196
197        def __len__(self):
198            return len(self.dataset)
199
200        def __getitem__(self, idx):
201            sample = self.dataset[idx]
202            # print(sample)
203            # print(sample.keys())
204            if self.input_type == "base_simulation":
205                input_img = self.basesimulation_dataset[idx]["soundmap"]
206            else:
207                input_img = sample["osm"]  # PIL Image
208            target_img = sample["soundmap"]  # PIL Image
209
210            input_img = self.transform(input_img)
211            target_img = self.transform(target_img)
212
213            # Fix real image size 512x512 > 256x256
214            input_img = F.interpolate(input_img.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
215            input_img = input_img.squeeze(0)
216            # target_img = target_img.unsqueeze(0)
217
218            # # change size
219            # input_img = resize_tensor_to_divisible_by_14(input_img)
220            # target_img = resize_tensor_to_divisible_by_14(target_img)
221
222            # add fake rgb
223            # if input_img.shape[0] == 1:  # shape (B, 1, H, W)
224            #     input_img = input_img.repeat(3, 1, 1)  # make it (B, 3, H, W)
225
226            if self.output_type == "complex_only":
227                base_simulation_img = self.transform(self.basesimulation_dataset[idx]["soundmap"])
228                # base_simulation_img = resize_tensor_to_divisible_by_14(self.transform(self.basesimulation_dataset[idx]["soundmap"]))
229                # target_img = torch.abs(target_img[0] - base_simulation_img[0])
230                target_img = target_img[0] - base_simulation_img[0]
231                target_img = target_img.unsqueeze(0)
232                target_img *= -1
233
234            return input_img, target_img, idx

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many ~torch.utils.data.Sampler implementations and the default options of ~torch.utils.data.DataLoader. Subclasses could also optionally implement __getitems__(), for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples.

sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

PhysGenDataset( variation='sound_baseline', mode='train', input_type='osm', output_type='standard')
166        def __init__(self, variation="sound_baseline", mode="train", input_type="osm", output_type="standard"):
167            """
168            Loads PhysGen Dataset.
169
170            Parameters:
171            - variation : str <br>
172                Chooses the used dataset variant: sound_baseline, sound_reflection, sound_diffraction, sound_combined.
173            - mode : str <br>
174                Can be "train", "test", "eval".
175            - input_type : str <br>
176                Defines the used Input -> "osm", "base_simulation"
177            - output_type : str <br>
178                Defines the Output -> "standard", "complex_only"
179            """
180            self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
181            # get data
182            self.dataset = load_dataset("mspitzna/physicsgen", name=variation, trust_remote_code=True)
183            # print("Keys:", self.dataset.keys())
184            self.dataset = self.dataset[mode]
185            
186            self.input_type = input_type
187            self.output_type = output_type
188            if self.input_type == "base_simulation" or self.output_type == "complex_only":
189                self.basesimulation_dataset = load_dataset("mspitzna/physicsgen", name="sound_baseline", trust_remote_code=True)
190                self.basesimulation_dataset = self.basesimulation_dataset[mode]
191
192            self.transform = transforms.Compose([
193                transforms.ToTensor(),  # Converts [0,255] PIL image to [0,1] FloatTensor
194            ])
195            print(f"PhysGen ({variation}) Dataset for {mode} got created")

Loads PhysGen Dataset.

Parameters:

  • variation : str
    Chooses the used dataset variant: sound_baseline, sound_reflection, sound_diffraction, sound_combined.
  • mode : str
    Can be "train", "test", "eval".
  • input_type : str
    Defines the used Input -> "osm", "base_simulation"
  • output_type : str
    Defines the Output -> "standard", "complex_only"
device
dataset
input_type
output_type
transform
def get_dataloader( mode='train', variation='sound_reflection', input_type='osm', output_type='complex_only', shuffle=True):
241    def get_dataloader(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True):
242        dataset = PhysGenDataset(mode=mode, variation=variation, input_type=input_type, output_type=output_type)
243        return DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=1)
def get_image( mode='train', variation='sound_reflection', input_type='osm', output_type='complex_only', shuffle=True, return_output=False, as_numpy_array=True):
247    def get_image(mode='train', variation="sound_reflection", input_type="osm", output_type="complex_only", shuffle=True, 
248                return_output=False, as_numpy_array=True):
249        dataset = PhysGenDataset(mode=mode, variation=variation, input_type=input_type, output_type=output_type)
250        loader = DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=1)
251        cur_data = next(iter(loader))
252        input_ = cur_data[0]
253        output_ = cur_data[1]
254
255        if as_numpy_array:
256            input_ = input_.detach().cpu().numpy()
257            output_ = output_.detach().cpu().numpy()
258
259            # remove batch channel
260            input_ = np.squeeze(input_, axis=0)
261            output_ = np.squeeze(output_, axis=0)
262
263            if len(input_.shape) == 3:
264                input_ = np.squeeze(input_, axis=0)
265                output_ = np.squeeze(output_, axis=0)
266
267            # opencv format
268            # if np.issubdtype(img.dtype, np.floating):
269            #     img = (img - img.min()) / (img.max() - img.min() + 1e-8)
270            #     img = (img * 255).astype(np.uint8)
271            input_ = np.transpose(input_, (1, 0))
272            output_ = np.transpose(output_, (1, 0))
273
274
275        result = input_
276        if return_output:
277            result = [input_, output_]
278
279        return result
def save_dataset( output_real_path, output_osm_path, variation, input_type, output_type, data_mode, info_print=False, progress_print=True):
283    def save_dataset(output_real_path, output_osm_path, 
284                    variation, input_type, output_type,
285                    data_mode, 
286                    info_print=False, progress_print=True):
    # Clear (or create) the output directories
    for path in (output_osm_path, output_real_path):
        if os.path.exists(path) and os.path.isdir(path):
            shutil.rmtree(path)
            os.makedirs(path)
            print(f"Cleared {path}.")
        else:
            os.makedirs(path)
            print(f"Created {path}.")
    # Load Dataset
    dataset = PhysGenDataset(mode=data_mode, variation=variation, input_type=input_type, output_type=output_type)
    data_len = len(dataset)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
    # Save Dataset
    for i, data in enumerate(dataloader):
        if progress_print:
            if PRIME_AVAILABLE:
                prime.get_progress_bar(total=data_len, progress=i+1,
                                       should_clear=True, left_bar_char='|', right_bar_char='|',
                                       progress_char='#', empty_char=' ',
                                       front_message='Physgen Data Loading', back_message='', size=15)
            else:
                print(f"Progress {i+1}/{data_len}")

        input_img, target_img, idx = data
        idx = idx[0].item() if isinstance(idx, torch.Tensor) else idx

        if info_print:
            print(f"Input shape [osm]: {input_img.shape}")
            print(f"Target shape: {target_img.shape}")
            print(f"OSM Info:\n    -> shape: {input_img.shape}\n    -> min: {input_img.min()}, max: {input_img.max()}")
        # Transform target to NumPy, validate range, scale to 0-255 uint8
        real_img = target_img.squeeze(0).squeeze(0).detach().cpu().numpy()
        if not (0 <= real_img.min() <= 255 and 0 <= real_img.max() <= 255):
            raise ValueError(f"Real target has values out of 0-255 range => min:{real_img.min()}, max:{real_img.max()}")
        if real_img.max() <= 1.0:
            real_img *= 255
        real_img = real_img.astype(np.uint8)
        if info_print:
            print(f"Real target value range => min:{real_img.min()}, max:{real_img.max()}")
351
        # Transform input to NumPy, validate range, scale to 0-255 uint8
        if len(input_img.shape) == 4:
            osm_img = input_img[0, 0].cpu().detach().numpy()
        else:
            osm_img = input_img[0].cpu().detach().numpy()
        if not (0 <= osm_img.min() <= 255 and 0 <= osm_img.max() <= 255):
            raise ValueError(f"OSM input has values out of 0-255 range => min:{osm_img.min()}, max:{osm_img.max()}")
        if osm_img.max() <= 1.0:
            osm_img *= 255
        osm_img = osm_img.astype(np.uint8)

        if info_print:
            print(f"OSM Info:\n    -> shape: {osm_img.shape}\n    -> min: {osm_img.min()}, max: {osm_img.max()}")
364
        # Save Results
        file_name = f"physgen_{idx}.png"

        # save real (target) image
        save_img = os.path.join(output_real_path, "target_" + file_name)
        cv2.imwrite(save_img, real_img)
        if info_print:
            print(f"    -> saved real at {save_img}")

        # save osm (input) image
        save_img = os.path.join(output_osm_path, "input_" + file_name)
        cv2.imwrite(save_img, osm_img)
        if info_print:
            print(f"    -> saved osm at {save_img}")

    print(f"\nSuccessfully saved {data_len} datapoints into {os.path.abspath(output_real_path)} & {os.path.abspath(output_osm_path)}")