Skip to content
Snippets Groups Projects
Commit 87753e5f authored by Julia Berger's avatar Julia Berger
Browse files

added support for additional frame infos

parent 24e1f920
Branches
No related tags found
No related merge requests found
......@@ -10,10 +10,12 @@ class NeRFSyntheticDataset(Dataset):
def __init__(self,
root_dir,
split='train',
additional_info=False,
img_downscale=1):
self.root_dir = root_dir
self.split = split
self.img_downscale = img_downscale
self.additional_info = additional_info
self.read_json()
def read_json(self):
......@@ -27,6 +29,7 @@ class NeRFSyntheticDataset(Dataset):
# Extract image and corresponding transformation matrix
imgs = []
poses = []
self.infos = []
for frame in meta['frames']:
file_path = frame['file_path']
if file_path[-4:] != '.png':
......@@ -51,6 +54,13 @@ class NeRFSyntheticDataset(Dataset):
imgs.append(np.array(img))
poses.append(np.array(frame['transform_matrix']))
# Get additional information per frame, except file path and transformation matrix.
info = {}
for key in frame.keys():
if key != 'file_path' and key != 'transform_matrix':
info[key] = frame[key]
self.infos.append(info)
self.rgbs = (np.array(imgs) / 255.).astype(np.float32)
self.poses = np.array(poses).astype(np.float32)
......@@ -71,9 +81,20 @@ class NeRFSyntheticDataset(Dataset):
self.focal_y = .5 * self.h / np.tan(.5 * camera_angle_y)
self.focal_y /= self.img_downscale
def __len__(self):
return self.rgbs.shape[0]
def __getitem__(self, idx):
if self.additional_info:
return {
'rgb': self.rgbs[idx],
'pose': self.poses[idx],
'focal_x': self.focal_x,
'focal_y': self.focal_y,
'info': self.infos[idx]
}
else:
return {
'rgb': self.rgbs[idx],
'pose': self.poses[idx],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment