File size: 5,095 Bytes
97aa5af | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | import os
import numpy as np
import torch
class ClassificationData:
    """Indexable dataset of point clouds paired with classification labels.

    ``data_dict`` must provide the keys ``'pcs'`` (point clouds) and
    ``'labels'``. Labels are converted with ``torch.from_numpy``, so they
    must be numpy arrays; point clouds are converted with ``torch.tensor``.
    """

    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.pcs = self.find_attribute('pcs')
        self.labels = self.find_attribute('labels')
        self.check_data()

    def find_attribute(self, attribute):
        """Return ``data_dict[attribute]``, or ``None`` (after printing a warning) if absent.

        The earlier try/bare-except version fell through to returning an
        unbound local, so a missing key raised ``UnboundLocalError``.
        """
        if attribute in self.data_dict:
            return self.data_dict[attribute]
        print("Given data directory has no key attribute \"{}\"".format(attribute))
        return None

    def check_data(self):
        """Validate the inputs and normalise them to batched form.

        After this call ``pcs`` has 3 dimensions and ``labels`` has 2, with
        one leading entry per point cloud.

        Raises:
            AssertionError: if a required key is missing, an array has an
                unsupported number of dimensions, or the number of point
                clouds and labels disagree.
        """
        for name in ('pcs', 'labels'):
            assert getattr(self, name) is not None, (
                "Missing required key \"{}\" in data_dict!".format(name)
            )
        assert 1 < len(self.pcs.shape) < 4, (
            "Error in dimension of point clouds! Given data dimension: {}".format(self.pcs.shape)
        )
        assert 0 < len(self.labels.shape) < 3, (
            "Error in dimension of labels! Given data dimension: {}".format(self.labels.shape)
        )
        if len(self.pcs.shape) == 2:
            # A single 2-D cloud: add a leading batch axis, 3 values per point.
            self.pcs = self.pcs.reshape(1, -1, 3)
        if len(self.labels.shape) == 1:
            if self.labels.shape[0] == self.pcs.shape[0]:
                # One label per cloud, shape (N,) -> (N, 1). Reshaping to
                # (1, N) here made the count check below fail for N > 1.
                self.labels = self.labels.reshape(-1, 1)
            else:
                # Otherwise treat the vector as the label row of one cloud.
                self.labels = self.labels.reshape(1, -1)
        assert self.pcs.shape[0] == self.labels.shape[0], (
            "Inconsistency in the number of point clouds and number of ground truth labels!"
        )

    def __len__(self):
        """Return the number of point clouds."""
        return self.pcs.shape[0]

    def __getitem__(self, index):
        """Return ``(points, label)`` as a float32 tensor and an int64 tensor."""
        # The previous version indexed labels with the undefined name ``idx``.
        points = torch.tensor(self.pcs[index]).float()
        label = torch.from_numpy(self.labels[index]).long()
        return points, label
class RegistrationData:
    """Indexable dataset of (template, source, transformation) triples.

    ``data_dict`` must provide the keys ``'template'``, ``'source'`` and
    ``'transformation'``; each item is returned as three float32 tensors.
    """

    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.template = self.find_attribute('template')
        self.source = self.find_attribute('source')
        self.transformation = self.find_attribute('transformation')
        self.check_data()

    def find_attribute(self, attribute):
        """Return ``data_dict[attribute]``, or ``None`` (after printing a warning) if absent."""
        if attribute in self.data_dict:
            return self.data_dict[attribute]
        print("Given data directory has no key attribute \"{}\"".format(attribute))
        return None

    def check_data(self):
        """Validate the inputs and normalise them to batched form.

        After this call ``template``, ``source`` and ``transformation`` each
        have 3 dimensions with one leading entry per sample.

        Raises:
            AssertionError: if a required key is missing, an array has an
                unsupported number of dimensions, or the sample counts differ.
        """
        # Fail with a clear message instead of an AttributeError on None.shape.
        for name in ('template', 'source', 'transformation'):
            assert getattr(self, name) is not None, (
                "Missing required key \"{}\" in data_dict!".format(name)
            )
        assert 1 < len(self.template.shape) < 4, (
            "Error in dimension of point clouds! Given data dimension: {}".format(self.template.shape)
        )
        assert 1 < len(self.source.shape) < 4, (
            "Error in dimension of point clouds! Given data dimension: {}".format(self.source.shape)
        )
        assert 1 < len(self.transformation.shape) < 4, (
            "Error in dimension of transformations! Given data dimension: {}".format(self.transformation.shape)
        )
        # Single (unbatched) inputs get a leading batch axis.
        if len(self.template.shape) == 2:
            self.template = self.template.reshape(1, -1, 3)
        if len(self.source.shape) == 2:
            self.source = self.source.reshape(1, -1, 3)
        if len(self.transformation.shape) == 2:
            # A single matrix; the reshape requires it to hold 16 values (4x4).
            self.transformation = self.transformation.reshape(1, 4, 4)
        assert self.template.shape[0] == self.source.shape[0], (
            "Inconsistency in the number of template and source point clouds!"
        )
        assert self.source.shape[0] == self.transformation.shape[0], (
            "Inconsistency in the number of transformation and source point clouds!"
        )

    def __len__(self):
        """Return the number of samples."""
        return self.template.shape[0]

    def __getitem__(self, index):
        """Return ``(template, source, transformation)`` as float32 tensors."""
        return (
            torch.tensor(self.template[index]).float(),
            torch.tensor(self.source[index]).float(),
            torch.tensor(self.transformation[index]).float(),
        )
class FlowData:
    """Indexable dataset of (frame1, frame2, flow) triples for flow estimation.

    ``data_dict`` must provide the keys ``'frame1'``, ``'frame2'`` and
    ``'flow'``; each item is returned as three float32 tensors.
    """

    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.frame1 = self.find_attribute('frame1')
        self.frame2 = self.find_attribute('frame2')
        self.flow = self.find_attribute('flow')
        self.check_data()

    def find_attribute(self, attribute):
        """Return ``data_dict[attribute]``, or ``None`` (after printing a warning) if absent.

        The previous version read the non-existent ``self.data``. Its bare
        ``except`` swallowed that error, printed a misleading "no key"
        message for every key, and then raised ``UnboundLocalError``.
        """
        if attribute in self.data_dict:
            return self.data_dict[attribute]
        print("Given data directory has no key attribute \"{}\"".format(attribute))
        return None

    def check_data(self):
        """Validate the inputs and normalise them to batched form.

        After this call ``frame1``, ``frame2`` and ``flow`` each have 3
        dimensions with one leading entry per sample.

        Raises:
            AssertionError: if a required key is missing, an array has an
                unsupported number of dimensions, or the sample counts differ.
        """
        for name in ('frame1', 'frame2', 'flow'):
            assert getattr(self, name) is not None, (
                "Missing required key \"{}\" in data_dict!".format(name)
            )
        assert 1 < len(self.frame1.shape) < 4, (
            "Error in dimension of point clouds! Given data dimension: {}".format(self.frame1.shape)
        )
        assert 1 < len(self.frame2.shape) < 4, (
            "Error in dimension of point clouds! Given data dimension: {}".format(self.frame2.shape)
        )
        assert 1 < len(self.flow.shape) < 4, (
            "Error in dimension of flow! Given data dimension: {}".format(self.flow.shape)
        )
        # Single (unbatched) inputs get a leading batch axis, 3 values per point.
        if len(self.frame1.shape) == 2:
            self.frame1 = self.frame1.reshape(1, -1, 3)
        if len(self.frame2.shape) == 2:
            self.frame2 = self.frame2.reshape(1, -1, 3)
        if len(self.flow.shape) == 2:
            self.flow = self.flow.reshape(1, -1, 3)
        assert self.frame1.shape[0] == self.frame2.shape[0], (
            "Inconsistency in the number of frame1 and frame2 point clouds!"
        )
        assert self.frame2.shape[0] == self.flow.shape[0], (
            "Inconsistency in the number of flow and frame2 point clouds!"
        )

    def __len__(self):
        """Return the number of samples."""
        return self.frame1.shape[0]

    def __getitem__(self, index):
        """Return ``(frame1, frame2, flow)`` as float32 tensors."""
        return (
            torch.tensor(self.frame1[index]).float(),
            torch.tensor(self.frame2[index]).float(),
            torch.tensor(self.flow[index]).float(),
        )
class UserData:
    """Dispatch user-supplied data to the dataset class for an application.

    Supported ``application`` values are ``'classification'``,
    ``'registration'`` and ``'flow_estimation'``. Length and indexing are
    delegated to the selected dataset object.
    """

    def __init__(self, application, data_dict):
        self.application = application
        if self.application == 'classification':
            self.data_class = ClassificationData(data_dict)
        elif self.application == 'registration':
            self.data_class = RegistrationData(data_dict)
        elif self.application == 'flow_estimation':
            self.data_class = FlowData(data_dict)
        else:
            # Previously an unknown name left data_class unset, so the first
            # len()/index call failed with an unrelated AttributeError.
            raise ValueError(
                "Unknown application \"{}\"; expected one of "
                "'classification', 'registration', 'flow_estimation'.".format(application)
            )

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.data_class)

    def __getitem__(self, index):
        """Return the wrapped dataset's item at ``index``."""
        return self.data_class[index]
|