Skip to content
Snippets Groups Projects
Commit 8cac1989 authored by Alexander Dielen's avatar Alexander Dielen
Browse files

initial commit

parents
No related branches found
No related tags found
No related merge requests found
data
out
__pycache__
.idea
# Shape Correspondences using Spiral Sequences
This repository contains the code to reproduce the results for the
sequence-based networks in the paper
Lim, I., Dielen, A., Campen, M., & Kobbelt, L. (2018). A Simple Approach to
Intrinsic Correspondence Learning on Unstructured 3D Meshes. arXiv preprint
arXiv:1809.06664.
## Results
![](data/results.png)
## Dependencies
* [pytorch][pytorch] 0.4.1
* [openmesh][openmesh]
* [h5py][h5py]
* [matplotlib][matplotlib]
## Dataset
To run the code you need a copy of the meshes in the [FAUST][faust] dataset and
the precomputed SHOT descriptors. Both are contained in this [archive][dropbox]
originally posted by Jonathan Masci [here][tutorial]. Simply extract the
`EG16_tutorial` folder to `data/` and you should be set.
Alternatively you can run the preprocessing code of Masci et al. yourself. The
code for computing the SHOT descriptor can be found [here][matlab1]. We computed
geodesic distances using the code from [here][matlab2]. Copy the mesh files to
`data/meshes/`, the shot descriptors to `data/shot/` and the geodesic distances
to `data/dists/`.
Precomputed distances are optional and only required for evaluation purposes
(i.e. the graph above).
## Running the code
python train.py --mode lstm
python train.py --mode linear
Model checkpoints and predictions on the validation and test sets are saved
to `out/`.
Note: On the first run, the spiral sequences for all meshes and all possible
rotations are precomputed. This may take a while. You can control the number of
parallel processes using the `--processes N` option (default: 4).
[pytorch]: https://pytorch.org/
[openmesh]: https://pypi.org/project/openmesh/
[h5py]: https://pypi.org/project/h5py/
[matplotlib]: https://pypi.org/project/matplotlib/
[faust]: http://faust.is.tue.mpg.de/
[dropbox]: https://www.dropbox.com/s/aamd98nynkvbcop/EG16_tutorial.tar.bz2?dl=0
[tutorial]: https://github.com/jonathanmasci/EG16_tutorial/blob/master/deep_learning_for_3D_shape_analysis.ipynb
[matlab1]: https://github.com/davideboscaini/shape_utils
[matlab2]: https://github.com/jonathanmasci/ShapeNet_data_preparation_toolbox
File added
File added
File added
File added
File added
data/results.png

33 KiB

import os
import random
import functools
import multiprocessing
import pickle
import h5py
import numpy as np
from torch.utils.data import Dataset
from utils import extract_and_save
split = {
'training': (0, 70),
'validation': (70, 80),
'test': (80, 100)
}
data_dir = 'data/'
spiral_dir = data_dir + 'spirals/'
if os.path.exists(data_dir + 'meshes/'):
mesh_dir = data_dir + 'meshes/'
else:
mesh_dir = data_dir + 'EG16_tutorial/dataset/FAUST_registrations/' \
'meshes/orig/'
if os.path.exists(data_dir + 'shot/'):
desc_dir = data_dir + 'shot/'
else:
desc_dir = data_dir + 'EG16_tutorial/dataset/FAUST_registrations/' \
'data/diam=200/descs/shot/'
class FaustDataset(Dataset):
def __init__(self, mode, seq_length):
self.descriptors = []
self.indices = []
self.filenames = []
for i in range(*split[mode]):
# load descriptor
filename = desc_dir + 'tr_reg_{:03}.mat'.format(i)
f = h5py.File(filename, 'r')
desc = np.array(f['desc'], dtype=np.float32).T
desc = np.ascontiguousarray(desc)
self.descriptors.append(desc)
f.close()
# load spiral
filename = spiral_dir + '{}/tr_reg_{:03}.pkl'.format(seq_length, i)
self.indices.append(pickle.load(open(filename, 'rb')))
# remember filename
self.filenames.append('tr_reg_{:03}'.format(i))
def __len__(self):
return len(self.descriptors)
def __getitem__(self, idx):
indices = []
for i in range(len(self.indices[idx])):
pick = random.randint(0, len(self.indices[idx][i]) - 1)
indices.append(self.indices[idx][i][pick])
indices = np.array(indices, dtype=np.int64)
return self.descriptors[idx], indices, self.filenames[idx]
def precompute_spirals(seq_length, processes):
output_dir = spiral_dir + '{}/'.format(seq_length)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
func = functools.partial(extract_and_save, seq_length=seq_length,
input_dir=mesh_dir, output_dir=output_dir)
pool = multiprocessing.Pool(processes)
pool.map(func, list(range(100)))
pool.close()
pool.join()
def geodesic_dists_available():
return os.path.exists(data_dir + 'dists/')
def geodesic_dists(idx):
filename = data_dir + 'dists/tr_reg_{:03}.mat'.format(idx)
f = h5py.File(filename, 'r')
dist = np.array(f['dist'])
f.close()
return dist
train.py 0 → 100644
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import sys
import os
import time
import random
import argparse
import faust_data
import utils
class SpiralLayer(torch.nn.Module):
def __init__(self, linear, in_features, out_features, seq_length):
super(SpiralLayer, self).__init__()
if linear:
self.layer = nn.Linear(in_features * seq_length, out_features)
else:
self.layer = nn.LSTM(in_features, out_features, batch_first=True)
self.linear = linear
def forward(self, x, indices):
bs, seq_length = indices.size()
x = torch.index_select(x, 0, indices.view(-1))
if self.linear:
x = x.view(bs, -1)
x = self.layer(x)
else:
x = x.view(bs, seq_length, -1)
x = self.layer(x)[1][0]
x = x.squeeze()
return x
class SpiralNet(torch.nn.Module):
def __init__(self, linear, in_features=544, n_classes=6890, seq_length=30):
super(SpiralNet, self).__init__()
outputs = [100, 150, 200] if linear else [150, 200, 250]
self.fc1 = nn.Linear(in_features, 16)
self.spiral1 = SpiralLayer(linear, 16, outputs[0], seq_length)
self.spiral2 = SpiralLayer(linear, outputs[0], outputs[1], seq_length)
self.spiral3 = SpiralLayer(linear, outputs[1], outputs[2], seq_length)
self.fc2 = nn.Linear(outputs[2], 256)
self.fc3 = nn.Linear(256, n_classes)
def forward(self, x, indices):
x = F.dropout(x, p=0.3, training=self.training)
x = F.relu(self.fc1(x))
x = F.relu(self.spiral1(x, indices))
x = F.relu(self.spiral2(x, indices))
x = F.relu(self.spiral3(x, indices))
x = F.dropout(x, p=0.3, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=0.3, training=self.training)
x = self.fc3(x)
return x
def train(net, data, targets, device, criterion, optimizer, epoch, n_epochs):
net.train()
losses, accuracies = [], []
for i in np.random.permutation(len(data)):
x, indices, _ = data[i]
x = torch.from_numpy(x).to(device)
indices = torch.from_numpy(indices).to(device)
optimizer.zero_grad()
outputs = net(x, indices)
predictions = outputs.argmax(1)
correct = predictions.eq(targets).sum().item()
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
losses.append(loss.item())
accuracies.append(correct / len(predictions))
print('========================================')
print('Epoch {} of {}'.format(epoch, n_epochs))
print('========================================')
print('Training Set')
print(' Loss: {0:.15f}'.format(np.mean(losses)))
print(' Accuracy: {0:.15f}'.format(np.mean(accuracies)))
def validate(net, data, targets, device, criterion):
net.eval()
losses, accuracies = [], []
with torch.no_grad():
for i in range(len(data)):
x, indices, _ = data[i]
x = torch.from_numpy(x).to(device)
indices = torch.from_numpy(indices).to(device)
outputs = net(x, indices)
predictions = outputs.argmax(1)
correct = predictions.eq(targets).sum().item()
loss = criterion(outputs, targets)
losses.append(loss.item())
accuracies.append(correct / len(predictions))
print('Validation Set')
print(' Loss: {0:.15f}'.format(np.mean(losses)))
print(' Accuracy: {0:.15f}\n'.format(np.mean(accuracies)))
sys.stdout.flush()
return np.mean(accuracies)
def save_predictions(net, data, device, predictions_dir):
net.eval()
with torch.no_grad():
for i in range(len(data)):
x, indices, filename = data[i]
x = torch.from_numpy(x).to(device)
indices = torch.from_numpy(indices).to(device)
outputs = net(x, indices)
predictions = outputs.argmax(1)
filename = predictions_dir + filename + '.npy'
np.save(filename, predictions.cpu().numpy())
def save_checkpoint(net, optimizer, epoch, checkpoint_dir):
torch.save({
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
}, checkpoint_dir + 'checkpoint.tar')
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, default='lstm')
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--seq_length', type=int, default=30)
parser.add_argument('--processes', type=int, default=4)
parser.add_argument('--checkpoint', type=str)
return parser.parse_args()
def save_args(filename, args):
args_list = ['{}: {}'.format(arg, getattr(args, arg)) for arg in vars(args)]
args_text = '\n'.join(sorted(args_list))
text_file = open(filename, 'w')
text_file.write(args_text)
text_file.close()
def main():
# make reproducible
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
run_dir = 'out/' + time.strftime('%Y-%m-%d-%H-%M-%S/', time.localtime())
checkpoint_dir = run_dir + 'checkpoint/'
predictions_dir = run_dir + 'predictions/'
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
if not os.path.exists(predictions_dir):
os.makedirs(predictions_dir)
args = parse_args()
save_args(run_dir + 'args.txt', args)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = SpiralNet(args.mode=='linear').to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=args.lr)
n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('Number of trainable parameters: {}'.format(n_params))
# restore parameters & epoch from checkpoint
if args.checkpoint is not None:
checkpoint = torch.load(args.checkpoint)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
else:
start_epoch = 1
print('Loading data...\n')
faust_data.precompute_spirals(args.seq_length, args.processes)
training_dataset = faust_data.FaustDataset('training', args.seq_length)
validation_dataset = faust_data.FaustDataset('validation', args.seq_length)
test_dataset = faust_data.FaustDataset('test', args.seq_length)
targets = torch.arange(6890, dtype=torch.long).to(device)
max_acc = 0
for epoch in range(start_epoch, args.epochs + 1):
train(net, training_dataset, targets, device, criterion, optimizer, epoch, args.epochs)
acc = validate(net, validation_dataset, targets, device, criterion)
if epoch > args.epochs // 4 and acc > max_acc:
max_acc = acc
save_predictions(net, test_dataset, device, predictions_dir)
save_checkpoint(net, optimizer, epoch, checkpoint_dir)
if faust_data.geodesic_dists_available():
errors = []
targets = targets.cpu().numpy()
for i in range(*faust_data.split['test']):
filename = predictions_dir + 'tr_reg_{:03}.npy'.format(i)
predictions = np.load(filename)
dists = faust_data.geodesic_dists(i)
errors.append(dists[targets, predictions])
errors = np.concatenate(errors)
utils.save_benchmark(run_dir + 'results.pdf', errors)
if __name__ == "__main__":
main()
utils.py 0 → 100644
import openmesh as om
import numpy as np
import os
import pickle
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import faust_data
def _next_ring(mesh, last_ring, other):
res = []
def is_new_vertex(idx):
return (idx not in last_ring and
idx not in other and
idx not in res)
for vh1 in last_ring:
vh1 = om.VertexHandle(vh1)
# first pass: all vertices after last_ring
after_last_ring = False
for vh2 in mesh.vv(vh1):
if after_last_ring:
if is_new_vertex(vh2.idx()):
res.append(vh2.idx())
if vh2.idx() in last_ring:
after_last_ring = True
# second pass: all vertices before last_ring
for vh2 in mesh.vv(vh1):
if vh2.idx() in last_ring:
break
if is_new_vertex(vh2.idx()):
res.append(vh2.idx())
return res
def extract_spirals(filename, seq_length):
mesh = om.read_trimesh(filename)
spirals = []
for vh0 in mesh.vertices():
reference_one_ring = []
for vh1 in mesh.vv(vh0):
reference_one_ring.append(vh1.idx())
rotated_spirals = []
for shift in range(len(reference_one_ring)):
spiral = [vh0.idx()]
one_ring = list(np.roll(reference_one_ring, -shift))
last_ring = one_ring
next_ring = _next_ring(mesh, last_ring, spiral)
spiral.extend(last_ring)
while len(spiral) + len(next_ring) < seq_length:
last_ring = next_ring
next_ring = _next_ring(mesh, last_ring, spiral)
spiral.extend(last_ring)
spiral.extend(next_ring)
rotated_spirals.append(spiral[:seq_length])
spirals.append(rotated_spirals)
return spirals
def extract_and_save(idx, seq_length=None, input_dir=None, output_dir=None):
ply_filename = input_dir + 'tr_reg_{:03}.ply'.format(idx)
pkl_filename = output_dir + 'tr_reg_{:03}.pkl'.format(idx)
if not os.path.isfile(pkl_filename):
print('Computing spirals for tr_reg_{:03}.ply'.format(idx))
spirals = extract_spirals(ply_filename, seq_length)
pickle.dump(spirals, open(pkl_filename, 'wb'))
def save_benchmark(filename, errors):
x = np.arange(0.0, 0.2, 0.001)
y = np.zeros(x.shape)
errors = np.sort(errors)
m = 0
n = len(errors)
for idx, val in enumerate(x):
while errors[m] <= val and m < n:
m += 1
y[idx] = float(m) / n
references = [
('monet_raw.npy', 'tab:green', 'MoNet (raw)'),
('gcnn_symmetric.npy', 'tab:orange', 'GCNN (symmetric)'),
('gcnn_asymmetric.npy', 'tab:blue', 'GCNN (asymmetric)'),
]
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(x, y, 'tab:red', linewidth=2.2, label='our method')
for fname, color, label in references:
arr = np.load(faust_data.data_dir + 'references/' + fname)
ax.plot(arr[0], arr[1], color, linewidth=2.2, label=label)
ax.legend(loc='lower right')
ax.set(xlabel='geodesic radius', ylabel='% correct correspondences')
ax.grid()
plt.axis([0, 0.2, 0, 1])
plt.xticks(np.arange(0.0, 0.25, 0.05))
plt.yticks(np.arange(0.0, 1.20, 0.20))
plt.tight_layout()
fig.savefig(filename)
plt.close(fig)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment