Fix initialization for the linear model: use Xavier-uniform weights and zero biases for the linear layers

parent 8cac1989
......@@ -42,7 +42,7 @@ Precomputed distances are optional and only required for evaluation purposes
python --mode lstm
python --mode linear
python --mode linear --epochs 4000
Model checkpoints and predictions on the validation and test sets are saved
to `out/`.
......@@ -20,6 +20,8 @@ class SpiralLayer(torch.nn.Module):
super(SpiralLayer, self).__init__()
if linear:
self.layer = nn.Linear(in_features * seq_length, out_features)
torch.nn.init.xavier_uniform_(self.layer.weight, gain=1)
torch.nn.init.constant_(self.layer.bias, 0)
self.layer = nn.LSTM(in_features, out_features, batch_first=True)
self.linear = linear
......@@ -48,6 +50,14 @@ class SpiralNet(torch.nn.Module):
self.fc2 = nn.Linear(outputs[2], 256)
self.fc3 = nn.Linear(256, n_classes)
if linear:
torch.nn.init.xavier_uniform_(self.fc1.weight, gain=1)
torch.nn.init.xavier_uniform_(self.fc2.weight, gain=1)
torch.nn.init.xavier_uniform_(self.fc3.weight, gain=1)
torch.nn.init.constant_(self.fc1.bias, 0)
torch.nn.init.constant_(self.fc2.bias, 0)
torch.nn.init.constant_(self.fc3.bias, 0)
def forward(self, x, indices):
x = F.dropout(x, p=0.3,
x = F.relu(self.fc1(x))
