Alex
Created July 31, 2024

Solving the Schrödinger equation directly with neural networks

Numerically solve the Schrödinger equation directly in 1D and 2D with neural networks, using an MLP for the 1D case and a convolutional network for the 2D case.



Story
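
The networks are trained by minimizing the energy of a candidate wave function, i.e. the Rayleigh quotient

$$E[\psi] = \frac{\int |\nabla \psi|^2\, dx + \int V \psi^2\, dx}{\int \psi^2\, dx},$$

whose minimizer over normalized wave functions is the ground state. The potential $V$ sampled on a grid is the network input, and the predicted wave-function values $\psi$ on the same grid are the output. The two notebooks below implement this in 1D with an MLP and in 2D with a convolutional encoder-decoder.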


Code

quantum1d

Python
#!/usr/bin/env python
# coding: utf-8

# # Computation of the quantum state using a deep neural network, 1D case
# 
# Here we compute the quantum state by directly solving the steady-state Schrödinger equation. We provide the potential function $V(x)$ on a grid and use a deep neural network to generate the wave-function values $\psi(x)$.

# In[21]:


import torch 
import torch.nn as nn 
import numpy as np 
import matplotlib.pyplot as plt 


# Define the computation domain; here we take it to be $[-10, 10]$ with 100 grid points.

# In[23]:


L = 10       # half-width of the domain [-L, L]
N = 100      # number of grid points
h = 2*L/N    # (nominal) grid spacing


# Define a simple feedforward neural network.

# In[25]:


class FeedForward(nn.Module):
    """The feedfoward network"""
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, n_layers: int, act = torch.tanh):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.layers = torch.nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.act = act 

    def forward(self, x):
        x = self.input_layer(x)
        x = self.act(x)
        for layer in self.layers:
            x = self.act(layer(x))
        x = self.output_layer(x)
        return x
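

# A quick shape check (illustrative, not part of the original notebook; `net_demo`
# is a throwaway instance): the network maps a batch of grid-sampled potentials
# to wave-function values on the same grid.
net_demo = FeedForward(input_dim=N, output_dim=N, hidden_dim=64, n_layers=4)
print(net_demo(torch.randn(3, N)).shape)  # torch.Size([3, 100])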


# Define the potential functions.

# In[28]:


def V(x, poten_type):
  """Potential on the grid. Note: these return the NEGATIVE of the physical
  potential; the loss below subtracts this term accordingly."""
  if poten_type == 'box':
    return 1e-10*x              # (numerically) zero potential inside the box
  elif poten_type == 'harmonic':
    return - x**2               # harmonic oscillator, physical V(x) = x^2
  elif poten_type == 'H':
    return 1/np.abs(x+1e-10)    # hydrogen-like, physical V(x) = -1/|x|
  else:
    raise NotImplementedError("potential function not defined")

def u_true(x, poten_type):
  """Known ground-state wave functions for comparison (None if unavailable)."""
  if poten_type == 'box':
    return np.cos(np.pi*x/(2*L))
  elif poten_type == 'harmonic':
    return (1/np.pi)**0.25 * np.exp(-0.5*x**2)
  elif poten_type == 'H':
    return None
  else:
    raise NotImplementedError("potential function not defined")


# Define the loss, which is the energy (Rayleigh quotient)
# $$\frac{\int |\nabla \psi(x)|^2\, dx + \int V(x) \psi^2(x)\, dx}{\int \psi^2(x)\, dx}.$$
# Note that `V` above returns $-V(x)$, so the code computes this as `(ugradnorm - potential)/unorm`.

# In[29]:


def quad(x):
  """Trapezoidal rule along the last axis (without the factor h, which cancels in the loss)."""
  return x[:,0]/2 + torch.sum(x[:,1:-1],axis=-1) + x[:,-1]/2

def dx(x):
  """Forward differences with zero padding at both ends of the grid."""
  return torch.diff(x, prepend=torch.zeros(x.shape[0],1).to(x.device),
                    append=torch.zeros(x.shape[0],1).to(x.device))
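

# A quick accuracy check (illustrative, not part of the original notebook):
# `quad` is the trapezoidal rule up to the factor h, so h*quad(f) approximates
# the integral of f over [-L, L]; for exp(-x^2) that is sqrt(pi) ≈ 1.7725.
xs = torch.linspace(-L, L, N).view(1, -1)
print(h * quad(torch.exp(-xs**2)))  # ~1.75 (h = 2L/N slightly underestimates the true spacing 2L/(N-1))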


def loss_fn(unet, input, h, eps = 0):
    u = unet(input).squeeze()
    unorm = quad(u**2)                # ~ \int \psi^2 dx, up to the factor h
    ugradnorm = quad((dx(u)/h)**2)    # ~ \int |\nabla\psi|^2 dx, up to the factor h
    potential = quad(u**2*input)      # input holds V on the grid (negative of the physical potential)
    # Average the Rayleigh quotient over the batch of potentials.
    return torch.mean((ugradnorm - potential)/unorm)


# Generate the potential values on the grid; these serve as the network input.

# In[33]:


# Use the potential values V(x) as the network input.
x_grid = np.linspace(-L, L, N)
V_value_harmonic = V(x_grid, poten_type='harmonic')
V_value_box = V(x_grid, poten_type='box')
V_value_H = V(x_grid, poten_type='H')
V_value = torch.tensor(np.stack([V_value_box, V_value_harmonic, V_value_H], axis=0), dtype=torch.float32)
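

# An optional verification (illustrative, not in the original notebook): plugging
# the exact harmonic ground state into the discrete energy should give a value
# near the true ground-state energy, E0 = 1 in units where H = -d^2/dx^2 + x^2.
psi = torch.tensor(u_true(x_grid, 'harmonic'), dtype=torch.float32).view(1, -1)
V_h = torch.tensor(V_value_harmonic, dtype=torch.float32).view(1, -1)
E0_est = (quad((dx(psi)/h)**2) - quad(psi**2 * V_h)) / quad(psi**2)
print(E0_est)  # expected to be close to 1.0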


# In[34]:


unet = FeedForward(input_dim=N, output_dim=N, hidden_dim=64, n_layers=4)


# In[35]:


loss_fn(unet, V_value, 1)  # quick check that the loss evaluates (h only rescales the gradient term)


# In[36]:


input = V_value  # the stacked potentials serve as the training batch


# Train the network with Adam

# In[37]:


lr = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
unet.to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr)
num_epochs = 10000
input = input.to(device)

for i in range(num_epochs):
    loss = loss_fn(unet, input, h)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if i % 100 == 0:
        print('Loss step {}:'.format(i), loss.item())



# In[39]:


x = torch.tensor(np.linspace(-L,L,N))
u = unet(torch.tensor(V_value_box,dtype=torch.float32).to(device).view(1,-1))
# Normalize the prediction and rescale the analytic solution to match its mean.
plt.plot(x.detach().cpu(), torch.abs(u/torch.sqrt(2*L/N*quad(u**2))).detach().cpu().squeeze(),'.-')
tmp1 = torch.mean(u/torch.sqrt(2*L/N*quad(u**2))).detach().cpu()
tmp2 = torch.mean(u_true(x.cpu(), 'box'))
plt.plot(x.detach().cpu(), torch.abs(u_true(x.cpu(), 'box')/tmp2*tmp1))
plt.legend(['Numerical solution', 'True solution'])
plt.savefig('box1d.png')


# In[42]:


plt.figure(1)
u = unet(torch.tensor(V_value_harmonic,dtype=torch.float32).to(device).view(1,-1))
plt.plot(x.detach().cpu(), torch.abs(u/torch.sqrt(2*L/N*quad(u**2))).detach().cpu().squeeze(),'.-')
tmp1 = torch.mean(u/torch.sqrt(2*L/N*quad(u**2))).detach().cpu()
tmp2 = torch.mean(u_true(x.cpu(), 'harmonic'))
plt.plot(x.detach().cpu(), torch.abs(u_true(x.cpu(), 'harmonic')/tmp2*tmp1))
plt.legend(['Numerical solution', 'True solution'])

plt.savefig('harmonic1d.png')


# In[43]:


plt.figure(2)
u = unet(torch.tensor(V_value_H,dtype=torch.float32).to(device).view(1,-1))
plt.plot(x.detach().cpu(), torch.abs(u/torch.sqrt(2*L/N*quad(u**2))).detach().cpu().squeeze(),'.-')
plt.legend(['Numerical solution'])  # no closed-form reference is plotted for this potential

plt.savefig('H1d.png')


quantum2d

Python
#!/usr/bin/env python
# coding: utf-8

# # Computation of the quantum state using a deep neural network, 2D case
# 
# We extend the previous 1D approach to the 2D case, using a convolutional network.
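# The loss is the same Rayleigh quotient as in 1D, now with the two-dimensional gradient:
# $$\frac{\int |\nabla \psi|^2\, dA + \int V \psi^2\, dA}{\int \psi^2\, dA}.$$
# As in 1D, the code stores $-V$, so the potential term is subtracted inside `loss_fn`.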

# In[1]:


import torch 
import torch.nn as nn 
import numpy as np 
import matplotlib.pyplot as plt 


# In[2]:


L = 10       # half-width of the square domain [-L, L]^2
N = 64       # grid points per dimension
h = 2*L/N    # (nominal) grid spacing


# In[3]:


def V(x, poten_type):
  """Potential on the 2-D grid; as in 1D, these return the NEGATIVE of the physical potential."""
  if poten_type == 'box':
    return 1e-10*torch.ones_like(x[:,0])                   # (numerically) zero inside the box
  elif poten_type == 'harmonic':
    return - x[:,0]**2 - x[:,1]**2                         # physical V = x^2 + y^2
  elif poten_type == 'H':
    return 1/(torch.sqrt(x[:,0]**2 + x[:,1]**2) + 1e-10)   # physical V = -1/r
  else:
    raise NotImplementedError("potential function not defined")


# In[4]:


x = np.linspace(-L, L, N)
x_grid = torch.tensor(np.stack(np.meshgrid(x, x), -1), dtype=torch.float32)
x_grid = x_grid.view(N*N, 2)
V_value_harmonic = V(x_grid, poten_type='harmonic')
V_value_box = V(x_grid, poten_type='box')
V_value_H = V(x_grid, poten_type='H')
V_value = torch.stack([V_value_box, V_value_harmonic]).view(2, N, N)


# In[5]:


def quad(x):
    """2-D trapezoidal rule over the last two axes, up to the constant factor
    4/h^2 (which cancels in the Rayleigh quotient): interior points get weight
    4, edge points 2, corners 1."""
    return (4*torch.sum(x[:,1:-1,1:-1],axis=[1,2])  + 2*torch.sum(x[:,0,1:-1],axis=-1) + 2*torch.sum(x[:,-1,1:-1],axis=-1)
            + 2*torch.sum(x[:,1:-1,0],axis=-1) + 2*torch.sum(x[:,1:-1,-1],axis=-1) + x[:,0,0] + x[:,0,-1] + x[:,-1,0] + x[:,-1,-1])
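

# Quick check (illustrative, not in the original notebook): with the weights
# above, (h**2/4)*quad on a grid of ones approximates the domain area (2L)^2.
print((h**2/4) * quad(torch.ones(1, N, N)))  # ~387.6 vs. 400; the gap reflects h = 2L/N vs. the true spacing 2L/(N-1)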


# In[6]:


def dx(x):
  """Differences along the first grid axis, zero-padded at the boundaries."""
  return torch.diff(x, axis=-2, prepend=torch.zeros(x.shape[0],1,x.shape[-1]).to(x.device),
                    append=torch.zeros(x.shape[0],1,x.shape[-1]).to(x.device))

def dy(x):
  """Differences along the second grid axis, zero-padded at the boundaries."""
  return torch.diff(x, axis=-1, prepend=torch.zeros(x.shape[0],x.shape[-2],1).to(x.device),
                    append=torch.zeros(x.shape[0],x.shape[-2],1).to(x.device))


# In[7]:


def loss_fn(unet, input, h, eps = 0):
    u = unet(input).squeeze(1)
    unorm = quad(u**2)
    ugradnorm = quad((dx(u)/h)**2) + quad((dy(u)/h)**2)
    potential = quad(u**2*input.squeeze(1))   # input holds V on the grid (negative of the physical potential)
    # Average the Rayleigh quotient over the batch of potentials.
    return torch.mean((ugradnorm - potential)/unorm)


# In[8]:


nz = 256   # flattened encoder output size / decoder input channels
nc = 1     # input channels (the potential V on the grid)
ndf = 64   # encoder feature maps
ngf = 64   # decoder feature maps
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.ReLU(),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.ReLU(),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.ReLU(),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf , 4, 4, 1, bias=False),
            nn.BatchNorm2d(ndf),
            nn.ReLU(),
            nn.Flatten())  # output: (B, nz) with nz = ndf * 2 * 2 = 256
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.shape[0],x.shape[1],1,1)
        
        return self.decoder(x)
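

# A quick shape check (illustrative, not part of the original notebook; `net_demo`
# is a throwaway instance): the encoder compresses a (B, 1, 64, 64) potential to
# nz = 256 features, and the decoder expands it back to a (B, 1, 64, 64) wave function.
net_demo = ConvNet()
print(net_demo(torch.randn(2, 1, 64, 64)).shape)  # torch.Size([2, 1, 64, 64])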


# In[9]:


# Use the potential values as the network input, now with all three potentials.
V_value_harmonic = V(x_grid, poten_type='harmonic')
V_value_box = V(x_grid, poten_type='box')
V_value_H = V(x_grid, poten_type='H')
V_value = torch.stack([V_value_box, V_value_harmonic, V_value_H])


# In[10]:


input = V_value.view(3, 1, N, N)


# In[11]:


unet = ConvNet()


# In[12]:


u = unet(input).squeeze(1)  # one forward pass as a shape check


# In[13]:


# Manual evaluation of the loss terms before training (matches loss_fn).
unorm = quad(u**2)
ugradnorm = quad((dx(u)/h)**2) + quad((dy(u)/h)**2)
potential = quad(u**2*input.squeeze(1))
torch.mean((ugradnorm - potential)/unorm)


# In[14]:


lr = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
unet.to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr)
num_epochs = 10000
input = input.to(device)

for i in range(num_epochs):
    loss = loss_fn(unet, input, h)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if i % 100 == 0:
        print('Loss step {}:'.format(i), loss.item())


# In[15]:


# Contour plot of the predicted wave function for the 2-D harmonic potential.
plt.contourf(unet(V_value_harmonic.view(1,1,N,N).to(device))[0,0].detach().cpu())


# In[16]:


# Contour plot of the predicted wave function for the 2-D hydrogen-like potential.
plt.contourf(unet(V_value_H.view(1,1,N,N).to(device))[0,0].detach().cpu())


Credits

Alex