Logsignature example
This notebook is based on the examples from the torchcde package by Kidger and Morrill, which can be found at https://github.com/patrick-kidger/torchcde. Further information about the techniques described in this notebook can be found in:
Morrill, J., Salvi, C., Kidger, P., Foster, J. and Lyons, T., 2020. Neural rough differential equations for long time series. arXiv preprint arXiv:2009.08295.
Morrill, J., Kidger, P., Yang, L. and Lyons, T., 2021. Neural controlled differential equations for online prediction tasks. arXiv preprint arXiv:2106.11028.
Kidger, P., Foster, J., Li, X., Oberhauser, H. and Lyons, T., 2021. Neural SDEs as infinite-dimensional GANs. arXiv preprint arXiv:2102.03657.
In this notebook we code up a Neural CDE that uses the log-ODE method on a long time series, which turns it into a Neural RDE. We only describe the differences from the time series classification notebook.
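For orientation, here is a rough paraphrase, in our own notation, of the idea described by Morrill et al. (2020); consult the paper for the precise statement. A Neural CDE evolves a hidden state z driven by a control path X. The log-ODE method splits the time interval at points r_0 < r_1 < ... < r_n and, on each piece, replaces the driving path by its depth-N logsignature over that piece, so the CDE becomes an ODE with a (different) learned vector field written here as \hat f_\theta:

```latex
z_t = z_0 + \int_{0}^{t} f_\theta(z_s)\,\mathrm{d}X_s
\qquad\longrightarrow\qquad
\frac{\mathrm{d}z_t}{\mathrm{d}t} = \hat f_\theta(z_t)\,
\frac{\operatorname{LogSig}^{N}_{r_i, r_{i+1}}(X)}{r_{i+1} - r_i},
\qquad t \in [r_i, r_{i+1}).
```

In this notebook the logsignatures are computed once, up front, with torchcde.logsig_windows, and the resulting much shorter path is then fed to an ordinary Neural CDE using linear interpolation.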
Set up the notebook
Install dependencies
This notebook requires PyTorch and the torchcde package. The dependencies are listed in the requirements.txt file and can be installed using the following command.
import sys
!{sys.executable} -m pip uninstall -y enum34
!{sys.executable} -m pip install -r requirements.txt
Import the necessary packages
import math
import time
import torch
import torchcde
We also set some parameters that can be changed when experimenting with the method.
HIDDEN_LAYER_WIDTH = 64 # Width of the hidden layer of the CDE vector field network (CDEFunc)
NUM_EPOCHS = 10 # Number of passes over the training data (epochs) used later
NUM_TIMEPOINTS = 5000 # Number of time points to use in the generated data.
# This is large to demonstrate the utility of logsignature features.
We use the CDEFunc and NeuralCDE classes, and a version of the get_data function, from the time series classification notebook.
class CDEFunc(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels):
        ######################
        # input_channels is the number of input channels in the data X. (Determined by the data.)
        # hidden_channels is the number of channels for z_t. (Determined by you!)
        ######################
        super(CDEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, HIDDEN_LAYER_WIDTH)
        self.linear2 = torch.nn.Linear(HIDDEN_LAYER_WIDTH, input_channels * hidden_channels)

    ######################
    # For most purposes the t argument can probably be ignored; unless you want your CDE to behave differently at
    # different times, which would be unusual. But it's there if you need it!
    ######################
    def forward(self, t, z):
        # z has shape (batch, hidden_channels)
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        ######################
        # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.
        ######################
        z = z.tanh()
        ######################
        # Ignoring the batch dimension, the shape of the output tensor must be a matrix,
        # because we need it to represent a linear map from R^input_channels to R^hidden_channels.
        ######################
        z = z.view(z.size(0), self.hidden_channels, self.input_channels)
        return z
class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels, interpolation="cubic"):
        super(NeuralCDE, self).__init__()

        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)
        self.interpolation = interpolation

    def forward(self, coeffs):
        if self.interpolation == 'cubic':
            X = torchcde.NaturalCubicSpline(coeffs)
        elif self.interpolation == 'linear':
            X = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError("Only 'linear' and 'cubic' interpolation methods are implemented.")

        ######################
        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
        ######################
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        ######################
        # Actually solve the CDE.
        ######################
        z_T = torchcde.cdeint(X=X,
                              z0=z0,
                              func=self.func,
                              t=X.interval)

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
        return pred_y
def get_data(num_timepoints=100, num_samples=64):
    # Generates num_samples noisy 2D spirals (half clockwise, half anticlockwise),
    # each observed at num_timepoints regularly spaced times.
    t = torch.linspace(0., 4 * math.pi, num_timepoints)

    start = torch.rand(num_samples) * 2 * math.pi
    x_pos = torch.cos(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
    x_pos[:num_samples // 2] *= -1
    y_pos = torch.sin(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
    x_pos += 0.01 * torch.randn_like(x_pos)
    y_pos += 0.01 * torch.randn_like(y_pos)
    ######################
    # Easy to forget gotcha: time should be included as a channel; Neural CDEs need to be explicitly told the
    # rate at which time passes. Here, we have a regularly sampled dataset, so appending time is pretty simple.
    ######################
    X = torch.stack([t.unsqueeze(0).repeat(num_samples, 1), x_pos, y_pos], dim=2)
    y = torch.zeros(num_samples)
    y[:num_samples // 2] = 1

    perm = torch.randperm(num_samples)
    X = X[perm]
    y = y[perm]

    ######################
    # X is a tensor of observations, of shape (batch=num_samples, sequence=num_timepoints, channels=3)
    # y is a tensor of labels, of shape (batch=num_samples,), either 0 or 1 corresponding to anticlockwise or clockwise respectively.
    ######################
    return X, y
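The comments in CDEFunc say its output must be a matrix because it represents a linear map from R^input_channels to R^hidden_channels. A minimal sketch of why, assuming a fixed-step explicit Euler discretisation of dz = f(z) dX: each step applies the matrix f(z) to the increment of the control path. This is only an illustration; torchcde.cdeint instead reduces the CDE to an ODE and passes it to an ODE solver. The function euler_cde_solve is hypothetical and not part of torchcde.

```python
import torch


def euler_cde_solve(func, X, z0):
    """Hypothetical fixed-step explicit Euler solver for dz = func(z) dX.

    func: maps z of shape (batch, hidden) to a tensor of shape (batch, hidden, input),
          i.e. one linear map from R^input to R^hidden per sample.
    X:    observed control path, shape (batch, length, input).
    z0:   initial hidden state, shape (batch, hidden).
    """
    z = z0
    for k in range(X.size(1) - 1):
        dX = X[:, k + 1] - X[:, k]  # increment of the control, shape (batch, input)
        # Apply each sample's linear map to its increment:
        # (batch, hidden, input) @ (batch, input, 1) -> (batch, hidden, 1)
        z = z + (func(z) @ dX.unsqueeze(-1)).squeeze(-1)
    return z


# Usage: a constant vector field A gives z_T = z0 + A (X_T - X_0) exactly,
# because the increments of X telescope.
torch.manual_seed(0)
A = torch.randn(4, 3)
constant_func = lambda z: A.expand(z.size(0), 4, 3)
X = torch.randn(2, 10, 3)
z0 = torch.randn(2, 4)
z_T = euler_cde_solve(constant_func, X, z0)
expected = z0 + (X[:, -1] - X[:, 0]) @ A.T
```

With a state-dependent func, such as CDEFunc, the increments no longer telescope, which is where the hidden state's dependence on the whole path comes from.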
Now we can define a function that computes windowed logsignatures up to a specified depth, trains the model on them, and evaluates its performance on the test set.
def train_and_evaluate(train_X, train_y, test_X, test_y, depth, num_epochs, window_length):
    # Time the whole process: logsignature computation, training and evaluation
    start_time = time.time()

    # Logsignature computation step
    train_logsig = torchcde.logsig_windows(train_X, depth, window_length=window_length)
    print("Logsignature shape: {}".format(train_logsig.size()))

    model = NeuralCDE(
        input_channels=train_logsig.size(-1), hidden_channels=8, output_channels=1, interpolation="linear"
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    train_coeffs = torchcde.linear_interpolation_coeffs(train_logsig)

    train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            batch_coeffs, batch_y = batch
            pred_y = model(batch_coeffs).squeeze(-1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Reports the loss of the last batch in this epoch
        print("Epoch: {} Training loss: {}".format(epoch, loss.item()))

    # Remember to compute the logsignatures of the test data too!
    test_logsig = torchcde.logsig_windows(test_X, depth, window_length=window_length)
    test_coeffs = torchcde.linear_interpolation_coeffs(test_logsig)
    # No gradients are needed for evaluation
    with torch.no_grad():
        pred_y = model(test_coeffs).squeeze(-1)
    binary_prediction = (torch.sigmoid(pred_y) > 0.5).to(test_y.dtype)
    prediction_matches = (binary_prediction == test_y).to(test_y.dtype)
    proportion_correct = prediction_matches.sum() / test_y.size(0)
    print("Test Accuracy: {}".format(proportion_correct))

    # Total time
    elapsed = time.time() - start_time
    return proportion_correct, elapsed
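The model's readout produces raw logits, which is why training uses binary_cross_entropy_with_logits and evaluation thresholds the sigmoid at 0.5. A small sketch of both points, using a hypothetical helper binary_accuracy_from_logits that is not part of the notebook: thresholding sigmoid(logit) at 0.5 is the same as thresholding the logit at 0, and the "with logits" loss equals ordinary binary cross-entropy applied after a sigmoid, while being computed in a more numerically stable way.

```python
import torch
import torch.nn.functional as F


def binary_accuracy_from_logits(logits, targets):
    """Hypothetical helper: fraction of correct 0/1 predictions from raw logits.

    sigmoid(logit) > 0.5 holds exactly when logit > 0, so this agrees with
    the torch.sigmoid(pred_y) > 0.5 test used in train_and_evaluate.
    """
    predictions = (logits > 0).to(targets.dtype)
    return (predictions == targets).to(targets.dtype).mean().item()


logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
targets = torch.tensor([1.0, 0.0, 0.0, 0.0])
accuracy = binary_accuracy_from_logits(logits, targets)  # 3 of 4 correct -> 0.75

# Same loss value, two routes: the fused "with_logits" version avoids computing log(sigmoid(x)) directly.
loss_from_logits = F.binary_cross_entropy_with_logits(logits, targets)
loss_from_probs = F.binary_cross_entropy(torch.sigmoid(logits), targets)
```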
Here we load a high-frequency version of the spiral data used in the torchcde examples. Each sample contains NUM_TIMEPOINTS time points. This is too long to sensibly expect a Neural CDE to handle: training would be very slow, and the model would struggle to remember information from early in the sequence.
train_X, train_y = get_data(num_timepoints=NUM_TIMEPOINTS)
test_X, test_y = get_data(num_timepoints=NUM_TIMEPOINTS)
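We test the model with logsignature depths 1, 2 and 3, using a window length of 50. This reduces the effective length of each path from NUM_TIMEPOINTS = 5000 points to 101 points (100 windows), as the printed shapes below confirm. The only change from a standard Neural CDE pipeline is the application of torchcde.logsig_windows before interpolation.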
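To make the windowing concrete, here is a minimal pure-PyTorch sketch of depth-2 logsignatures over consecutive windows of a piecewise-linear path. It is not torchcde's implementation: the helper names are hypothetical, it stops at depth 2, and torchcde's basis and output conventions may differ (for instance, torchcde reports 101 points where this sketch returns one feature vector per window, i.e. 100). At depth 2, each window contributes its total increment plus the Lévy areas A_ij = 1/2 ∫ (x^i - x^i_start) dx^j - (x^j - x^j_start) dx^i for i < j.

```python
import torch


def logsig_depth2(path):
    """Depth-2 logsignature of a piecewise-linear path (hypothetical helper).

    path: (..., length, channels) with length >= 2.
    Returns (..., channels + channels * (channels - 1) // 2): the total
    increment followed by the Levy areas A_ij for i < j.
    """
    channels = path.size(-1)
    increment = path[..., -1, :] - path[..., 0, :]
    rel = path[..., :-1, :] - path[..., :1, :]    # segment start points relative to the window start
    delta = path[..., 1:, :] - path[..., :-1, :]  # increment over each linear segment
    m = torch.einsum('...ki,...kj->...ij', rel, delta)
    area = 0.5 * (m - m.transpose(-1, -2))        # antisymmetric Levy area matrix
    i, j = torch.triu_indices(channels, channels, offset=1)
    return torch.cat([increment, area[..., i, j]], dim=-1)


def windowed_logsig_depth2(x, window_length):
    """Split x (..., length, channels) into consecutive windows that share
    endpoints, and take the depth-2 logsignature of each window."""
    length = x.size(-2)
    edges = list(range(0, length - 1, window_length)) + [length - 1]
    features = [logsig_depth2(x[..., a:b + 1, :]) for a, b in zip(edges[:-1], edges[1:])]
    return torch.stack(features, dim=-2)


# Usage: an anticlockwise unit square returns to its start and encloses signed area +1.
square = torch.tensor([[0., 0.], [1., 0.], [1., 1.], [0., 1.], [0., 0.]])
print(logsig_depth2(square))  # increment (0, 0) followed by Levy area 1
```

Roughly, ignoring noise, the Lévy area between the x and y channels is positive over windows where a spiral turns anticlockwise and negative where it turns clockwise, which is one intuition for why depth-2 features could be informative for this task.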
The raw signal has 3 input channels (time, x and y). Taking logsignatures of depths 1, 2 and 3 gives a path with 3, 6 and 14 channels respectively. Higher depths capture more information about the path over each window, at the cost of more channels.
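The channel counts 3, 6 and 14 can be checked directly. The dimension of a depth-N logsignature of a d-channel path is the dimension of the truncated free Lie algebra, which Witt's formula gives as the sum over levels k = 1..N of (1/k) Σ_{j | k} μ(j) d^(k/j), where μ is the Möbius function. A small self-contained sketch; the helper names are illustrative and not part of torchcde:

```python
def mobius(n):
    """Moebius function mu(n), computed by trial division."""
    result, p = 1, 2
    while p * p <= n:
        if n % p == 0:
            n //= p
            if n % p == 0:
                return 0  # n has a squared prime factor
            result = -result
        p += 1
    if n > 1:
        result = -result  # one remaining prime factor
    return result


def logsig_channels(channels, depth):
    """Number of logsignature channels for a path with `channels` channels, via Witt's formula."""
    total = 0
    for k in range(1, depth + 1):
        level = sum(mobius(j) * channels ** (k // j) for j in range(1, k + 1) if k % j == 0)
        total += level // k  # the sum is always divisible by k
    return total


print([logsig_channels(3, depth) for depth in (1, 2, 3)])  # [3, 6, 14]
```

Level k contributes 3, 3 and 8 channels for k = 1, 2, 3 when d = 3, so the count grows quickly with depth: depth 4 would already give 32 channels.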
depths = [1, 2, 3]
window_length = 50
accuracies = []
training_times = []
for depth in depths:
    print(f'Running for logsignature depth: {depth}')
    acc, elapsed = train_and_evaluate(
        train_X, train_y, test_X, test_y, depth, NUM_EPOCHS, window_length
    )
    training_times.append(elapsed)
    accuracies.append(acc)
Running for logsignature depth: 1
Logsignature shape: torch.Size([64, 101, 3])
Epoch: 0 Training loss: 1.7253673076629639
Epoch: 1 Training loss: 2.6841232776641846
Epoch: 2 Training loss: 1.1095588207244873
Epoch: 3 Training loss: 1.8698482513427734
Epoch: 4 Training loss: 0.8444149494171143
Epoch: 5 Training loss: 1.102584719657898
Epoch: 6 Training loss: 0.9590306282043457
Epoch: 7 Training loss: 1.0678613185882568
Epoch: 8 Training loss: 0.7616084814071655
Epoch: 9 Training loss: 0.6925854086875916
Test Accuracy: 0.796875
Running for logsignature depth: 2
Logsignature shape: torch.Size([64, 101, 6])
Epoch: 0 Training loss: 3.9483087062835693
Epoch: 1 Training loss: 2.967172384262085
Epoch: 2 Training loss: 1.3951165676116943
Epoch: 3 Training loss: 0.6525543332099915
Epoch: 4 Training loss: 0.5654739141464233
Epoch: 5 Training loss: 0.6235690712928772
Epoch: 6 Training loss: 0.643418550491333
Epoch: 7 Training loss: 0.7490644454956055
Epoch: 8 Training loss: 0.6644153594970703
Epoch: 9 Training loss: 0.6092175841331482
Test Accuracy: 0.703125
Running for logsignature depth: 3
Logsignature shape: torch.Size([64, 101, 14])
Epoch: 0 Training loss: 9.29626750946045
Epoch: 1 Training loss: 2.3605875968933105
Epoch: 2 Training loss: 0.9953503608703613
Epoch: 3 Training loss: 1.4490458965301514
Epoch: 4 Training loss: 0.6993889212608337
Epoch: 5 Training loss: 1.3962339162826538
Epoch: 6 Training loss: 0.7141188979148865
Epoch: 7 Training loss: 0.7587863206863403
Epoch: 8 Training loss: 0.8748772144317627
Epoch: 9 Training loss: 0.6787529587745667
Test Accuracy: 0.5
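Finally, we print the results to the console for comparison. Note that the reported time per epoch is the total elapsed time divided by NUM_EPOCHS, so it also includes the logsignature computations and the test evaluation.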
print("Final results")
for acc, elapsed, depth in zip(accuracies, training_times, depths):
    print(
        f"Depth: {depth}\n\tAccuracy on test set: {acc*100:.1f}%\n\tTime per epoch: {elapsed/NUM_EPOCHS:.1f}s"
    )
Final results
Depth: 1
Accuracy on test set: 79.7%
Time per epoch: 4.7s
Depth: 2
Accuracy on test set: 70.3%
Time per epoch: 6.9s
Depth: 3
Accuracy on test set: 50.0%
Time per epoch: 5.2s