Commit 8cac1989 authored by Alexander Dielen's avatar Alexander Dielen
Browse files

initial commit

# Shape Correspondences using Spiral Sequences
This repository contains the code to reproduce the results for the
sequence-based networks in the paper
Lim, I., Dielen, A., Campen, M., & Kobbelt, L. (2018). A Simple Approach to
Intrinsic Correspondence Learning on Unstructured 3D Meshes. arXiv preprint
## Results
## Dependencies
* [pytorch][pytorch] 0.4.1
* [openmesh][openmesh]
* [h5py][h5py]
* [matplotlib][matplotlib]
## Dataset
To run the code you need a copy of the meshes in the [FAUST][faust] dataset and
the precomputed SHOT descriptors. Both are contained in this [archive][dropbox]
originally posted by Jonathan Masci [here][tutorial]. Simply extract the
`EG16_tutorial` folder to `data/` and you should be set.
Alternatively you can run the preprocessing code of Masci et al. yourself. The
code for computing the SHOT descriptor can be found [here][matlab1]. We computed
geodesic distances using the code from [here][matlab2]. Copy the mesh files to
`data/meshes/`, the shot descriptors to `data/shot/` and the geodesic distances
to `data/dists/`.
Precomputed distances are optional and only required for evaluation purposes
(i.e. the graph above).
## Running the code
python --mode lstm
python --mode linear
Model checkpoints and predictions on the validation and test sets are saved
to `out/`.
Note: On the first run, the spiral sequences for all meshes and all possible
rotations are precomputed. This may take a while. You can control the number of
parallel processes using the `--processes N` option (default: 4).
import os
import random
import functools
import multiprocessing
import pickle
import h5py
import numpy as np
from import Dataset
from utils import extract_and_save
split = {
'training': (0, 70),
'validation': (70, 80),
'test': (80, 100)
data_dir = 'data/'
spiral_dir = data_dir + 'spirals/'
if os.path.exists(data_dir + 'meshes/'):
mesh_dir = data_dir + 'meshes/'
mesh_dir = data_dir + 'EG16_tutorial/dataset/FAUST_registrations/' \
if os.path.exists(data_dir + 'shot/'):
desc_dir = data_dir + 'shot/'
desc_dir = data_dir + 'EG16_tutorial/dataset/FAUST_registrations/' \
class FaustDataset(Dataset):
def __init__(self, mode, seq_length):
self.descriptors = []
self.indices = []
self.filenames = []
for i in range(*split[mode]):
# load descriptor
filename = desc_dir + 'tr_reg_{:03}.mat'.format(i)
f = h5py.File(filename, 'r')
desc = np.array(f['desc'], dtype=np.float32).T
desc = np.ascontiguousarray(desc)
# load spiral
filename = spiral_dir + '{}/tr_reg_{:03}.pkl'.format(seq_length, i)
self.indices.append(pickle.load(open(filename, 'rb')))
# remember filename
def __len__(self):
return len(self.descriptors)
def __getitem__(self, idx):
indices = []
for i in range(len(self.indices[idx])):
pick = random.randint(0, len(self.indices[idx][i]) - 1)
indices = np.array(indices, dtype=np.int64)
return self.descriptors[idx], indices, self.filenames[idx]
def precompute_spirals(seq_length, processes):
output_dir = spiral_dir + '{}/'.format(seq_length)
if not os.path.exists(output_dir):
func = functools.partial(extract_and_save, seq_length=seq_length,
input_dir=mesh_dir, output_dir=output_dir)
pool = multiprocessing.Pool(processes), list(range(100)))
def geodesic_dists_available():
return os.path.exists(data_dir + 'dists/')
def geodesic_dists(idx):
filename = data_dir + 'dists/tr_reg_{:03}.mat'.format(idx)
f = h5py.File(filename, 'r')
dist = np.array(f['dist'])
return dist
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import sys
import os
import time
import random
import argparse
import faust_data
import utils
class SpiralLayer(torch.nn.Module):
def __init__(self, linear, in_features, out_features, seq_length):
super(SpiralLayer, self).__init__()
if linear:
self.layer = nn.Linear(in_features * seq_length, out_features)
self.layer = nn.LSTM(in_features, out_features, batch_first=True)
self.linear = linear
def forward(self, x, indices):
bs, seq_length = indices.size()
x = torch.index_select(x, 0, indices.view(-1))
if self.linear:
x = x.view(bs, -1)
x = self.layer(x)
x = x.view(bs, seq_length, -1)
x = self.layer(x)[1][0]
x = x.squeeze()
return x
class SpiralNet(torch.nn.Module):
def __init__(self, linear, in_features=544, n_classes=6890, seq_length=30):
super(SpiralNet, self).__init__()
outputs = [100, 150, 200] if linear else [150, 200, 250]
self.fc1 = nn.Linear(in_features, 16)
self.spiral1 = SpiralLayer(linear, 16, outputs[0], seq_length)
self.spiral2 = SpiralLayer(linear, outputs[0], outputs[1], seq_length)
self.spiral3 = SpiralLayer(linear, outputs[1], outputs[2], seq_length)
self.fc2 = nn.Linear(outputs[2], 256)
self.fc3 = nn.Linear(256, n_classes)
def forward(self, x, indices):
x = F.dropout(x, p=0.3,
x = F.relu(self.fc1(x))
x = F.relu(self.spiral1(x, indices))
x = F.relu(self.spiral2(x, indices))
x = F.relu(self.spiral3(x, indices))
x = F.dropout(x, p=0.3,
x = self.fc2(x)
x = F.dropout(x, p=0.3,
x = self.fc3(x)
return x
def train(net, data, targets, device, criterion, optimizer, epoch, n_epochs):
losses, accuracies = [], []
for i in np.random.permutation(len(data)):
x, indices, _ = data[i]
x = torch.from_numpy(x).to(device)
indices = torch.from_numpy(indices).to(device)
outputs = net(x, indices)
predictions = outputs.argmax(1)
correct = predictions.eq(targets).sum().item()
loss = criterion(outputs, targets)
accuracies.append(correct / len(predictions))
print('Epoch {} of {}'.format(epoch, n_epochs))
print('Training Set')
print(' Loss: {0:.15f}'.format(np.mean(losses)))
print(' Accuracy: {0:.15f}'.format(np.mean(accuracies)))
def validate(net, data, targets, device, criterion):
losses, accuracies = [], []
with torch.no_grad():
for i in range(len(data)):
x, indices, _ = data[i]
x = torch.from_numpy(x).to(device)
indices = torch.from_numpy(indices).to(device)
outputs = net(x, indices)
predictions = outputs.argmax(1)
correct = predictions.eq(targets).sum().item()
loss = criterion(outputs, targets)
accuracies.append(correct / len(predictions))
print('Validation Set')
print(' Loss: {0:.15f}'.format(np.mean(losses)))
print(' Accuracy: {0:.15f}\n'.format(np.mean(accuracies)))
return np.mean(accuracies)
def save_predictions(net, data, device, predictions_dir):
with torch.no_grad():
for i in range(len(data)):
x, indices, filename = data[i]
x = torch.from_numpy(x).to(device)
indices = torch.from_numpy(indices).to(device)
outputs = net(x, indices)
predictions = outputs.argmax(1)
filename = predictions_dir + filename + '.npy', predictions.cpu().numpy())
def save_checkpoint(net, optimizer, epoch, checkpoint_dir):{
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
}, checkpoint_dir + 'checkpoint.tar')
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, default='lstm')
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--seq_length', type=int, default=30)
parser.add_argument('--processes', type=int, default=4)
parser.add_argument('--checkpoint', type=str)
return parser.parse_args()
def save_args(filename, args):
args_list = ['{}: {}'.format(arg, getattr(args, arg)) for arg in vars(args)]
args_text = '\n'.join(sorted(args_list))
text_file = open(filename, 'w')
def main():
# make reproducible
run_dir = 'out/' + time.strftime('%Y-%m-%d-%H-%M-%S/', time.localtime())
checkpoint_dir = run_dir + 'checkpoint/'
predictions_dir = run_dir + 'predictions/'
if not os.path.exists(checkpoint_dir):
if not os.path.exists(predictions_dir):
args = parse_args()
save_args(run_dir + 'args.txt', args)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = SpiralNet(args.mode=='linear').to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(),
n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('Number of trainable parameters: {}'.format(n_params))
# restore parameters & epoch from checkpoint
if args.checkpoint is not None:
checkpoint = torch.load(args.checkpoint)
start_epoch = checkpoint['epoch'] + 1
start_epoch = 1
print('Loading data...\n')
faust_data.precompute_spirals(args.seq_length, args.processes)
training_dataset = faust_data.FaustDataset('training', args.seq_length)
validation_dataset = faust_data.FaustDataset('validation', args.seq_length)
test_dataset = faust_data.FaustDataset('test', args.seq_length)
targets = torch.arange(6890, dtype=torch.long).to(device)
max_acc = 0
for epoch in range(start_epoch, args.epochs + 1):
train(net, training_dataset, targets, device, criterion, optimizer, epoch, args.epochs)
acc = validate(net, validation_dataset, targets, device, criterion)
if epoch > args.epochs // 4 and acc > max_acc:
max_acc = acc
save_predictions(net, test_dataset, device, predictions_dir)
save_checkpoint(net, optimizer, epoch, checkpoint_dir)
if faust_data.geodesic_dists_available():
errors = []
targets = targets.cpu().numpy()
for i in range(*faust_data.split['test']):
filename = predictions_dir + 'tr_reg_{:03}.npy'.format(i)
predictions = np.load(filename)
dists = faust_data.geodesic_dists(i)
errors.append(dists[targets, predictions])
errors = np.concatenate(errors)
utils.save_benchmark(run_dir + 'results.pdf', errors)
if __name__ == "__main__":
import openmesh as om
import numpy as np
import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
import faust_data
def _next_ring(mesh, last_ring, other):
res = []
def is_new_vertex(idx):
return (idx not in last_ring and
idx not in other and
idx not in res)
for vh1 in last_ring:
vh1 = om.VertexHandle(vh1)
# first pass: all vertices after last_ring
after_last_ring = False
for vh2 in mesh.vv(vh1):
if after_last_ring:
if is_new_vertex(vh2.idx()):
if vh2.idx() in last_ring:
after_last_ring = True
# second pass: all vertices before last_ring
for vh2 in mesh.vv(vh1):
if vh2.idx() in last_ring:
if is_new_vertex(vh2.idx()):
return res
def extract_spirals(filename, seq_length):
mesh = om.read_trimesh(filename)
spirals = []
for vh0 in mesh.vertices():
reference_one_ring = []
for vh1 in mesh.vv(vh0):
rotated_spirals = []
for shift in range(len(reference_one_ring)):
spiral = [vh0.idx()]
one_ring = list(np.roll(reference_one_ring, -shift))
last_ring = one_ring
next_ring = _next_ring(mesh, last_ring, spiral)
while len(spiral) + len(next_ring) < seq_length:
last_ring = next_ring
next_ring = _next_ring(mesh, last_ring, spiral)
return spirals
def extract_and_save(idx, seq_length=None, input_dir=None, output_dir=None):
ply_filename = input_dir + 'tr_reg_{:03}.ply'.format(idx)
pkl_filename = output_dir + 'tr_reg_{:03}.pkl'.format(idx)
if not os.path.isfile(pkl_filename):
print('Computing spirals for tr_reg_{:03}.ply'.format(idx))
spirals = extract_spirals(ply_filename, seq_length)
pickle.dump(spirals, open(pkl_filename, 'wb'))
def save_benchmark(filename, errors):
x = np.arange(0.0, 0.2, 0.001)
y = np.zeros(x.shape)
errors = np.sort(errors)
m = 0
n = len(errors)
for idx, val in enumerate(x):
while errors[m] <= val and m < n:
m += 1
y[idx] = float(m) / n
references = [
('monet_raw.npy', 'tab:green', 'MoNet (raw)'),
('gcnn_symmetric.npy', 'tab:orange', 'GCNN (symmetric)'),
('gcnn_asymmetric.npy', 'tab:blue', 'GCNN (asymmetric)'),
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(x, y, 'tab:red', linewidth=2.2, label='our method')
for fname, color, label in references:
arr = np.load(faust_data.data_dir + 'references/' + fname)
ax.plot(arr[0], arr[1], color, linewidth=2.2, label=label)
ax.legend(loc='lower right')
ax.set(xlabel='geodesic radius', ylabel='% correct correspondences')
plt.axis([0, 0.2, 0, 1])
plt.xticks(np.arange(0.0, 0.25, 0.05))
plt.yticks(np.arange(0.0, 1.20, 0.20))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment