Example 5.4 oc aln model deterministic
Optimal control of the ALN model
This notebook shows how to compute the optimal control (OC) signal for the ALN model for a simple example task.
import matplotlib.pyplot as plt
import numpy as np
import os
# Walk up the directory tree until the working directory is the "neurolib"
# repository folder, so that the notebook runs from a well-defined location.
# At the filesystem root the parent directory equals the current one; stop with
# an error there instead of looping forever.
while os.path.basename(os.getcwd()) != "neurolib":
    parent_dir = os.path.dirname(os.getcwd())
    if parent_dir == os.getcwd():
        raise RuntimeError("Could not find a 'neurolib' directory above the current working directory.")
    os.chdir(parent_dir)
# We import the model, stimuli, and the optimal control package
from neurolib.models.aln import ALNModel
from neurolib.utils.stimulus import ZeroInput
from neurolib.control.optimal_control import oc_aln
from neurolib.utils.plot_oc import plot_oc_singlenode, plot_oc_network
# This will reload all imports as soon as the code changes
# (IPython autoreload extension; mode 2 reloads all modules before executing code).
# NOTE: these are IPython magics and only work inside Jupyter/IPython, not plain Python.
%load_ext autoreload
%autoreload 2
def getfinalstate(model):
    """Read out the final state of the most recent simulation of ``model``.

    Returns an array of shape ``(N, V, T)``: ``N`` nodes (rows of
    ``model.params.Cmat``), ``V = len(model.state_vars)`` state variables and
    ``T = model.getMaxDelay() + 1`` time steps. For state variables whose name
    contains "rates" or "IA", the last ``T`` stored samples of each node are
    copied; for every other variable the single per-node value is broadcast
    along the time axis.
    """
    num_nodes = model.params.Cmat.shape[0]
    history_len = model.getMaxDelay() + 1
    final_state = np.zeros((num_nodes, len(model.state_vars), history_len))

    for var_idx, var_name in enumerate(model.state_vars):
        keeps_history = "rates" in var_name or "IA" in var_name
        for node in range(num_nodes):
            # History variables: last `history_len` samples of this node.
            # Other variables: one value for this node, broadcast over time.
            index = (node, slice(-history_len, None)) if keeps_history else node
            final_state[node, var_idx, :] = model.state[var_name][index]
    return final_state
def setinitstate(model, state):
    """Write a state array into the initial-value parameters of ``model``.

    Parameters
    ----------
    model : neurolib model
        Its ``params[<init_var>]`` entries are overwritten in place for every
        name in ``model.init_vars``.
    state : np.ndarray
        Array of shape ``(N, V, T_state)`` such as the one returned by
        ``getfinalstate``; axis 1 is indexed by the position of each name in
        ``model.init_vars``.

    For init variables whose name contains "rates" or "IA", the last
    ``model.getMaxDelay() + 1`` samples of every node are stored; all other init
    variables receive only the last sample of every node.

    NOTE(review): this pairs ``model.init_vars[v]`` with ``state[:, v]``, i.e. it
    assumes ``init_vars`` lists the variables in the same order as the
    ``state_vars`` used to build ``state``; confirm for the model in use.
    """
    history_len = model.getMaxDelay() + 1
    # Each assignment covers all nodes at once (``state[:, v, ...]``), so no
    # per-node loop is needed; the original repeated identical writes N times.
    for v, init_var in enumerate(model.init_vars):
        if "rates" in init_var or "IA" in init_var:
            model.params[init_var] = state[:, v, -history_len:]
        else:
            model.params[init_var] = state[:, v, -1]
def getstate(model):
    """Combine the initial values and the simulated trajectory into one array.

    Returns an array of shape ``(N, 3, 1 + n_steps)``. Axis 1 holds
    ``rates_exc``, ``rates_inh`` and ``IA`` (in that order). On the last axis,
    index 0 is the last sample of the corresponding ``*_init`` parameter and the
    remaining entries are the simulated series ``model.rates_exc``,
    ``model.rates_inh`` and ``model.IA``.
    """
    last_init_values = [
        model.params["rates_exc_init"][:, -1],
        model.params["rates_inh_init"][:, -1],
        model.params["IA_init"][:, -1],
    ]
    initial = np.stack(last_init_values, axis=1)[:, :, np.newaxis]
    trajectory = np.stack((model.rates_exc, model.rates_inh, model.IA), axis=1)
    return np.concatenate((initial, trajectory), axis=2)
We stimulate the system with a known control signal, define the resulting activity as the target, and compute the optimal control for this target. We choose the weights such that only the precision is penalized (w_p=1, w_2=0). Hence, the optimal control signal should converge to the input signal.
We first study current inputs. We will later proceed to rate inputs.
# We initialize a single-node ALN model and let it settle: simulate 10 s in the
# up state and use the final state as initial condition for all following runs.
model = ALNModel()
model.params.duration = 10000
model.params.mue_ext_mean = 2.  # up state
model.run()
setinitstate(model, getfinalstate(model))

# Some parameters to define stimulation signals
dt = model.params["dt"]
duration = 10.
amplitude = 1.
period = duration / 4.

# We define a "zero-input" and a sine-input. The sine fills all but the first and
# last time step; its sample times are derived from the array length, so the
# signal fits for any integration step dt.
# (Named `stimulus` rather than `input` to avoid shadowing the Python builtin.)
zero_input = ZeroInput().generate_input(duration=duration + dt, dt=dt)
stimulus = np.copy(zero_input)
sine_times = np.arange(stimulus.shape[1] - 2) * dt
stimulus[0, 1:-1] = amplitude * np.sin(2. * np.pi * sine_times / period)  # other functions or random values can be used as well

# We set the duration of the simulation, apply the sine to the excitatory
# current input (all other inputs are zero), and run the simulation
model.params["duration"] = duration
model.params["ext_exc_current"] = stimulus
model.params["ext_inh_current"] = zero_input
model.params["ext_exc_rate"] = zero_input
model.params["ext_inh_rate"] = zero_input
model.run()

# Define the result of the stimulation as target
target = getstate(model)
target_input = np.concatenate((stimulus, zero_input, zero_input, zero_input), axis=0)[np.newaxis, :, :]

# Remove stimuli and re-run the simulation
model.params["ext_exc_current"] = zero_input
model.params["ext_inh_current"] = zero_input
control = np.concatenate((zero_input, zero_input, zero_input, zero_input), axis=0)[np.newaxis, :, :]
model.run()

# combine initial value and simulation result to one array
state = getstate(model)
plot_oc_singlenode(duration, dt, state, target, control, target_input)
# Control and cost matrices (one row for the single node): control only the
# first input channel and measure deviations only in the first output variable.
n_inputs = len(model.input_vars)
n_outputs = len(model.output_vars)
control_mat = np.zeros((1, n_inputs))
cost_mat = np.zeros((1, n_outputs))
control_mat[0, 0] = 1.
cost_mat[0, 0] = 1.

# Set up the optimal control problem; parameters are taken from `model`, and
# `print_array` lists the iterations at which intermediate results are printed.
report_iterations = np.arange(0, 501, 25)
model_controlled = oc_aln.OcAln(
    model,
    target,
    print_array=report_iterations,
    control_matrix=control_mat,
    cost_matrix=cost_mat,
)
# Only the precision is penalized (w_p default value 1, w_2 default value 0)
model_controlled.weights["w_p"] = 1.
model_controlled.weights["w_2"] = 0.

# Run 500 iterations of the optimal control gradient descent and plot the result
model_controlled.optimize(500)
state, control = model_controlled.get_xs(), model_controlled.control
plot_oc_singlenode(duration, dt, state, target, control, target_input, model_controlled.cost_history)

# Optionally continue with another 100 iterations; every re-execution of this
# cell adds a further 100 iterations.
model_controlled.optimize(100)
state, control = model_controlled.get_xs(), model_controlled.control
plot_oc_singlenode(duration, dt, state, target, control, target_input, model_controlled.cost_history)
Let us now look at a scenario with rate-type control inputs.
# Sine-shaped excitatory rate input, specified in Hz (converted to kHz below)
amplitude = 40.
offset = 60.
period = duration / 4.

# We define a "zero-input" and a sine-input. Sample times are derived from the
# array length, so the signal fits for any integration step dt.
# (Named `stimulus` rather than `input` to avoid shadowing the Python builtin.)
zero_input = ZeroInput().generate_input(duration=duration + dt, dt=dt)
stimulus = np.copy(zero_input)
sine_times = np.arange(stimulus.shape[1] - 2) * dt
stimulus[0, 1:-1] = offset + amplitude * np.sin(2. * np.pi * sine_times / period)  # other functions or random values can be used as well

# We apply the stimulus to the excitatory rate input only, and run the simulation
model.params["ext_exc_current"] = zero_input
model.params["ext_inh_current"] = zero_input
model.params["ext_exc_rate"] = stimulus * 1e-3  # rate inputs need to be converted to kHz
model.params["ext_inh_rate"] = zero_input
model.run()

# Define the result of the stimulation as target (target_input stays in Hz for plotting)
target = getstate(model)
target_input = np.concatenate((zero_input, zero_input, stimulus, zero_input), axis=0)[np.newaxis, :, :]

# Remove stimuli and re-run the simulation
model.params["ext_exc_rate"] = zero_input
control = np.concatenate((zero_input, zero_input, zero_input, zero_input), axis=0)[np.newaxis, :, :]
model.run()

# combine initial value and simulation result to one array
state = getstate(model)
plot_oc_singlenode(duration, dt, state, target, control, target_input, plot_control_vars=[2, 3])
# Control matrix needs to be adjusted for rate inputs: only the third input
# channel (model.input_vars[2], the excitatory rate input) is controlled.
control_mat = np.zeros((1, len(model.input_vars)))
control_mat[0, 2] = 1.
# Cost matrix: measure deviations only in the first output variable. Defined
# explicitly so this cell does not depend on a matrix left over from an earlier cell.
cost_mat = np.zeros((1, len(model.output_vars)))
cost_mat[0, 0] = 1.

model_controlled = oc_aln.OcAln(model, target, print_array=np.arange(0, 501, 25), control_matrix=control_mat, cost_matrix=cost_mat)
model_controlled.weights["w_p"] = 1.  # default value 1
model_controlled.weights["w_2"] = 0.  # default value 0

# We run 500 iterations of the optimal control gradient descent algorithm.
# The model's rate inputs are in kHz, so the control is multiplied by 1e3 to
# plot it in Hz alongside target_input.
model_controlled.optimize(500)
state = model_controlled.get_xs()
control = model_controlled.control
plot_oc_singlenode(duration, dt, state, target, control * 1e3, target_input, model_controlled.cost_history, plot_control_vars=[2, 3])

# Do another 100 iterations if you want to.
# Repeated execution will continue with further 100 iterations.
model_controlled.optimize(100)
state = model_controlled.get_xs()
control = model_controlled.control
plot_oc_singlenode(duration, dt, state, target, control * 1e3, target_input, model_controlled.cost_history, plot_control_vars=[2, 3])
Network case
Let us now study a simple 2-node network of ALN model nodes. We first define the coupling matrix and the distance matrix; then we can initialize the model.
# Coupling matrix, Cmat[i, j] = strength of the connection from node j to node i:
# strength 1 from node 0 to node 1 and 0.5 from node 1 to node 0; diagonal elements are zero.
cmat = np.array([[0., 0.5], [1., 0.]])
dmat = np.array([[0., 0.], [0., 0.]])  # no delay
model = ALNModel(Cmat=cmat, Dmat=dmat)
model.params.duration = 10000
model.params.mue_ext_mean = 2.  # up state
model.params.de = 0.0
model.params.di = 0.0
model.run()
setinitstate(model, getfinalstate(model))

# Stimulation parameters (same values as in the single-node examples),
# defined here so this cell does not rely on variables from earlier cells
dt = model.params["dt"]
duration = 10.
amplitude = 1.
period = duration / 4.
model.params["duration"] = duration

# We define the control input matrix to enable (1) or disable (0) each input channel in each node
control_mat = np.zeros((model.params.N, len(model.input_vars)))
control_mat[0, 0] = 1.  # only allow inputs in the first input channel (excitatory current) of node 0

# Sine stimulus; sample times are derived from the array length so the signal
# fits for any dt. (Named `stimulus` to avoid shadowing the builtin `input`.)
zero_input = ZeroInput().generate_input(duration=duration + dt, dt=dt)
stimulus = np.copy(zero_input)
sine_times = np.arange(stimulus.shape[1] - 4) * dt
stimulus[0, 1:-3] = amplitude * np.sin(2. * np.pi * sine_times / period)  # other functions or random values can be used as well

# Network inputs of shape (N, n_inputs, T): the stimulus is applied to every
# node/channel enabled in control_mat and is zero elsewhere.
input_nw = control_mat[:, :, np.newaxis] * stimulus[0][np.newaxis, np.newaxis, :]
zero_input_nw = np.zeros_like(input_nw)

# We set the stimulus in the input channels and run the simulation
model.params["ext_exc_current"] = input_nw[:, 0, :]
model.params["ext_inh_current"] = input_nw[:, 1, :]
model.params["ext_exc_rate"] = input_nw[:, 2, :]
model.params["ext_inh_rate"] = input_nw[:, 3, :]
model.run()

# Define the result of the stimulation as target
target = getstate(model)

# Remove stimuli and re-run the simulation
model.params["ext_exc_current"] = zero_input_nw[:, 0, :]
model.params["ext_inh_current"] = zero_input_nw[:, 1, :]
model.params["ext_exc_rate"] = zero_input_nw[:, 2, :]
model.params["ext_inh_rate"] = zero_input_nw[:, 3, :]
model.run()

# combine initial value and simulation result to one array
state = getstate(model)
plot_oc_network(model.params.N, duration, dt, state, target, zero_input_nw, input_nw)
# We define the cost (precision) matrix to specify in which nodes and output
# channels deviations from the target are measured.
cost_mat = np.zeros((model.params.N, len(model.output_vars)))
cost_mat[1, 0] = 1.  # only measure the first output variable (model.output_vars[0]) in node 1

# We set the external stimulation to zero. This is the "initial guess" for the OC algorithm
model.params["ext_exc_current"] = zero_input_nw[:, 0, :]
model.params["ext_inh_current"] = zero_input_nw[:, 1, :]
model.params["ext_exc_rate"] = zero_input_nw[:, 2, :]
model.params["ext_inh_rate"] = zero_input_nw[:, 3, :]

# We load the optimal control class.
# print_array (optional parameter) defines for which iterations intermediate results are printed.
# Parameters will be taken from the input model.
model_controlled = oc_aln.OcAln(model, target, print_array=np.arange(0, 501, 25), control_matrix=control_mat, cost_matrix=cost_mat)

# We run 500 iterations of the optimal control gradient descent algorithm
model_controlled.optimize(500)
state = model_controlled.get_xs()
control = model_controlled.control
plot_oc_network(model.params.N, duration, dt, state, target, control, input_nw, model_controlled.cost_history, model_controlled.step_sizes_history)

# Do another 100 iterations if you want to.
# Repeated execution will continue with further 100 iterations.
# Reset the zero-step flag first so that a zero step size encountered in the
# previous run does not block further iterations (presumed semantics; confirm in OcAln).
model_controlled.zero_step_encountered = False
model_controlled.optimize(100)
state = model_controlled.get_xs()
control = model_controlled.control
plot_oc_network(model.params.N, duration, dt, state, target, control, input_nw, model_controlled.cost_history, model_controlled.step_sizes_history)
Delayed network of neural populations
We now consider a network topology with delayed signalling between the two nodes.
cmat = np.array([[0., 0.], [1., 0.]])  # diagonal elements are zero, connection strength is 1 from node 0 to node 1
dmat = np.array([[0., 0.], [18, 0.]])  # distance from 0 to 1, delay is computed by dividing by the signal speed params.signalV
model = ALNModel(Cmat=cmat, Dmat=dmat)
model.params.mue_ext_mean = 2.  # up state
model.run()
setinitstate(model, getfinalstate(model))

dt = model.params["dt"]
duration = 6.
model.params.duration = duration
model.run()

# We define the control input matrix to enable or disable certain channels and nodes.
# It has one row per node and one column per *input* variable (not per state variable).
control_mat = np.zeros((model.params.N, len(model.input_vars)))
control_mat[0, 0] = 1.  # only allow inputs in the first input channel (excitatory current) of node 0

# Three pulses as control input (named `stimulus` to avoid shadowing the builtin `input`)
zero_input = ZeroInput().generate_input(duration=duration + dt, dt=dt)
stimulus = zero_input.copy()
stimulus[0, [10, 20, 30]] = 10.

# Network inputs of shape (N, n_inputs, T): the stimulus is applied to every
# node/channel enabled in control_mat and is zero elsewhere.
input_nw = control_mat[:, :, np.newaxis] * stimulus[0][np.newaxis, np.newaxis, :]
zero_input_nw = np.zeros_like(input_nw)

model.params["ext_exc_current"] = input_nw[:, 0, :]
model.params["ext_inh_current"] = input_nw[:, 1, :]
model.params["ext_exc_rate"] = input_nw[:, 2, :]
model.params["ext_inh_rate"] = input_nw[:, 3, :]
model.run()

# Define the result of the stimulation as target
target = getstate(model)

# Remove stimuli and re-run the simulation
model.params["ext_exc_current"] = zero_input_nw[:, 0, :]
model.params["ext_inh_current"] = zero_input_nw[:, 1, :]
model.params["ext_exc_rate"] = zero_input_nw[:, 2, :]
model.params["ext_inh_rate"] = zero_input_nw[:, 3, :]
model.run()

# combine initial value and simulation result to one array
state = getstate(model)
plot_oc_network(model.params.N, duration, dt, state, target, zero_input_nw, input_nw)
# Cost matrix: only measure deviations in the first output variable of node 1.
# Defined explicitly instead of reusing the matrix from the previous example.
cost_mat = np.zeros((model.params.N, len(model.output_vars)))
cost_mat[1, 0] = 1.

# We set the external stimulation to zero as the "initial guess" for the OC algorithm
model.params["ext_exc_current"] = zero_input_nw[:, 0, :]
model.params["ext_inh_current"] = zero_input_nw[:, 1, :]
model.params["ext_exc_rate"] = zero_input_nw[:, 2, :]
model.params["ext_inh_rate"] = zero_input_nw[:, 3, :]

# We load the optimal control class.
# print_array (optional parameter) defines for which iterations intermediate results are printed.
# Parameters will be taken from the input model.
model_controlled = oc_aln.OcAln(model, target, print_array=np.arange(0, 501, 25), control_matrix=control_mat, cost_matrix=cost_mat)

# We run 500 iterations of the optimal control gradient descent algorithm
model_controlled.optimize(500)
state = model_controlled.get_xs()
control = model_controlled.control
plot_oc_network(model.params.N, duration, dt, state, target, control, input_nw, model_controlled.cost_history, model_controlled.step_sizes_history)

# Perform another 100 iterations to improve the result; repeated execution adds
# another 100 iterations. Convergence to the input stimulus can be relatively slow.
# Reset the zero-step flag first, as in the undelayed network example, so that a
# previously encountered zero step size does not block further iterations.
model_controlled.zero_step_encountered = False
model_controlled.optimize(100)
state = model_controlled.get_xs()
control = model_controlled.control
plot_oc_network(model.params.N, duration, dt, state, target, control, input_nw, model_controlled.cost_history, model_controlled.step_sizes_history)