DeepEM: A Deep Neural Network for DEM Inversion

by Paul Wright, Mark Cheung, Rajat Thomas, Richard Galvez, Alexandre Szenicer, Meng Jin, Andres Munoz-Jaramillo, and David Fouhey

The intensity observed in the optically thin SDO/AIA EUV channels (94 Å, 131 Å, 171 Å, 193 Å, 211 Å, 335 Å) can be related to the temperature distribution of the solar corona (the differential emission measure; DEM) as

\begin{equation} g_{i} = \int_{T} K_{i}(T) \xi(T) dT \, . \end{equation}

In this equation, $g_{i}$ is the intensity (in DN s$^{-1}$ px$^{-1}$) in the $i$th SDO/AIA channel, $K_{i}(T)$ is the corresponding temperature response function, and the DEM, $\xi(T)$, is in units of cm$^{-5}$ K$^{-1}$. This integral equation can be written in matrix form as $\vec{g} = {\bf K}\vec{\xi}$; however, this is an ill-posed inverse problem, and any attempt to directly recover $\vec{\xi}$ leads to significant noise amplification.
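To see why the direct inversion is fragile, consider a toy numpy sketch (with made-up, broadly overlapping response kernels standing in for the real SDO/AIA response functions): when ${\bf K}$ is ill-conditioned, a small perturbation of $\vec{g}$ can produce a much larger perturbation of the recovered $\vec{\xi}$.

import numpy as np

# Toy illustration only: 6 broad, overlapping kernels standing in for K
logT = np.linspace(5.5, 7.2, 18)
centers = np.linspace(5.7, 7.0, 6)
K = np.exp(-((logT[None, :] - centers[:, None]) / 0.3)**2)

xi_true = np.exp(-((logT - 6.3) / 0.2)**2)       # a smooth "true" DEM shape
g = K @ xi_true                                  # noise-free intensities
g_noisy = g * (1 + 0.01 * np.random.randn(6))    # add 1% noise

xi_clean = np.linalg.pinv(K) @ g                 # minimum-norm direct inversion
xi_noisy = np.linalg.pinv(K) @ g_noisy
print('condition number of K:', np.linalg.cond(K))
print('relative change in g: ', np.linalg.norm(g_noisy - g) / np.linalg.norm(g))
print('relative change in xi:', np.linalg.norm(xi_noisy - xi_clean) / np.linalg.norm(xi_clean))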

There are numerous methods for tackling mathematical problems of this kind, and an increasing number of methods in the literature for recovering the differential emission measure, including one based on the concept of sparsity (Cheung et al. 2015). In the following notebook, we will demonstrate how a simple 2D convolutional neural network with 1x1 kernels allows for a significant improvement in computational speed for DEM inversion, with fidelity similar to Basis Pursuit (Sparse Inversion). Additionally, this method, DeepEM, provides solutions with emission measure > 0 in every temperature bin.

DeepEM: A Deep Learning Approach for DEM Inversion

Paul J. Wright, Mark Cheung, Rajat Thomas, Richard Galvez, Alexandre Szenicer, Meng Jin, Andres Munoz-Jaramillo, and David Fouhey


In this chapter we will introduce a Deep Learning approach for DEM Inversion. For this notebook, DeepEM is trained on one set of SDO/AIA observations (six optically thin channels; $6 \times N \times N$) and the corresponding DEM solutions (18 temperature bins from log$_{10}$T = 5.5 to 7.2, $18 \times N \times N$; Cheung et al. 2015) at a resolution of $512 \times 512$ ($N = 512$), using a 2D convolutional neural network with $1 \times 1$ kernels and two hidden layers.

The DeepEM method presented here uses every DEM solution for training, with no regard to the quality or existence of the solution. As will be demonstrated, when this method is trained with a single set of images and DEM solutions, the DeepEM solutions have a fidelity similar to Sparse Inversion (with a significantly increased computation speed), and additionally, DeepEM finds positive solutions at every pixel with reduced noise in the DEM maps.

In [ ]:
import os
import json
import time
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from src import aia_deep_em
from scipy.io import readsav
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torch.autograd import Variable
from torch.utils.data import DataLoader

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
In [2]:
def em_scale(y):
    # Scale emission measure [cm^-5 K^-1] to order-unity values for training
    return np.sqrt(y/1e25)

def em_unscale(y):
    # Invert em_scale: recover emission measure in physical units
    return 1e25*(y*y)

def img_scale(x):
    # Square-root scale the AIA intensities; negative values are clipped to
    # zero (note: the clipping modifies the input array in place)
    x2 = x
    bad = np.where(x2 <= 0.0)
    x2[bad] = 0.0
    return np.sqrt(x2)

def img_unscale(x):
    # Invert img_scale
    return x*x
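As a quick sanity check of these helpers (with arbitrary example values), note that img_scale clips negative values to zero, and that it does so in place on its input array:

em = np.array([1e25, 4e25])
print(em_unscale(em_scale(em)))      # recovers [1e25, 4e25]

aia = np.array([-1.0, 0.0, 4.0])
print(img_scale(aia))                # negatives clipped, then sqrt: [0. 0. 2.]
print(img_unscale(img_scale(aia)))   # [0. 0. 4.]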

Step 1: Obtain Data and Sparse Inversion Solutions for Training

In [3]:
# Each .aia.npy file contains a (6, 512, 512) cube of AIA images, and each
# .emcube.npy file the corresponding (18, 512, 512) Sparse Inversion DEM cube.
aia_files = ['AIA_DEM_2011-01-27','AIA_DEM_2011-02-22','AIA_DEM_2011-03-20']
em_cube_files = aia_files

for k, (afile, emfile) in enumerate(zip(aia_files, em_cube_files)):
    afile_name = os.path.join('./SomeData2/', afile + '.aia.npy')
    emfile_name = os.path.join('./SomeData2/', emfile + '.emcube.npy')
    if k == 0:
        # Use the first file to determine the array shapes, then allocate one
        # array for the AIA data (X) and one for the DEM solutions (y)
        X = np.load(afile_name)
        y = np.load(emfile_name)
 
        X = np.zeros((len(aia_files), X.shape[0], X.shape[1], X.shape[2]))
        y = np.zeros((len(em_cube_files), y.shape[0], y.shape[1], y.shape[2]))
        
        nlgT = y.shape[1]  # number of temperature bins (18)
        lgtaxis = np.arange(y.shape[1])*0.1 + 5.5  # log10 T axis: 5.5 to 7.2
        
    X[k] = np.load(afile_name)
    y[k] = np.load(emfile_name) 

Step 2: Define the Model

We first define the model as a 2D Convolutional Neural Network (CNN) with a kernel size of 1x1. The model accepts a data cube of $6 \times N \times N$ (SDO/AIA data) and returns a data cube of $18 \times N \times N$ (DEM); when trained, it transforms the input at each pixel (the six SDO/AIA channels; $6 \times 1 \times 1$) to the output at that pixel (the DEM; $18 \times 1 \times 1$).

In [4]:
model = nn.Sequential(
    nn.Conv2d(6, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 18, kernel_size=1)).cuda()  # move the model onto the GPU
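Because every kernel is 1x1, this network is equivalent to a small fully-connected network (6 -> 300 -> 300 -> 18) applied independently at each pixel, so the number of parameters does not depend on the image size. As a quick check:

# (6*300 + 300) + (300*300 + 300) + (300*18 + 18) = 97,818 parameters
print(sum(p.numel() for p in model.parameters()))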

Step 3: Train the Model

For training our CNN we select one SDO/AIA data cube ($6\times512\times512$) and the corresponding Sparse Inversion DEM output ($18\times512\times512$). In the case presented here, we train the CNN on an image of the Sun obtained on 27-01-2011, validate on an image of the Sun obtained one synodic rotation later (+26 days; 22-02-2011), and finally test on an image another 26 days later (20-03-2011).

In [5]:
X = img_scale(X)  # square-root scale the AIA images (negative values are clipped to zero)
y = em_scale(y)   # square-root scale the DEMs

X_train = X[0:1]  # 2011-01-27 (training)
y_train = y[0:1]

X_val = X[1:2]    # 2011-02-22 (validation)
y_val = y[1:2]

X_test = X[2:3]   # 2011-03-20 (test)
y_test = y[2:3]

Plotting SDO/AIA Observations ${\it vs.}$ Basis Pursuit DEM bins

For the test data set, the SDO/AIA images for 171 Å, 211 Å, and 94 Å, and the corresponding DEM bins near the peak sensitivity of these relatively isothermal channels (log$_{10}$T ~ 5.9, 6.3, and 7.0), are shown in Figure 1: a set of SDO/AIA images (171 Å, 211 Å, and 94 Å, left to right) with the corresponding DEM maps for temperature bins that are near the peak sensitivity of each SDO/AIA channel. It is also clear from the DEM maps that a number of pixels are zero. These pixels are primarily located off-disk, but there are a number of pixels on-disk that show this behaviour.

In [6]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(y_test[0,8,:,:],vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(y_test[0,4,:,:],vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(y_test[0,15,:,:],vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)

Figure 1: Left to Right: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top), with the corresponding DEM bins (chosen at the peak sensitivity of each of the SDO/AIA channels) shown below. In the DEM maps (bottom) it is clear that some pixels have solutions of DEM = 0, seen as dark regions/clusters of pixels both on and off the disk.


To implement training and testing of our model, we first define a DEMdata class, which selects the training, validation, or test data and returns the SDO/AIA images and corresponding DEMs as PyTorch tensors, and we define functions for training and validation/testing: train_model and valtest_model.

N.B. It is not necessary to train the model yourself; if required, the trained model can be loaded onto the CPU as follows:

model = nn.Sequential(
    nn.Conv2d(6, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 18, kernel_size=1))

dem_model_file = 'DeepEM_CNN_HelioML.pth'
model.load_state_dict(torch.load(dem_model_file, map_location='cpu'))

Once you have loaded the model, skip to Step 4: Testing the Model.

In [7]:
class DEMdata(nn.Module):
    # Simple container that selects the requested split ('train', 'val', or
    # 'test') and returns (AIA image, DEM) pairs as FloatTensors for DataLoader
    def __init__(self, xtrain, ytrain, xtest, ytest, xval, yval, split='train'):
        super(DEMdata, self).__init__()

        if split == 'train':
            self.x = xtrain
            self.y = ytrain
        if split == 'val':
            self.x = xval
            self.y = yval
        if split == 'test':
            self.x = xtest
            self.y = ytest

    def __getitem__(self, index):
        return (torch.from_numpy(self.x[index]).type(torch.FloatTensor),
                torch.from_numpy(self.y[index]).type(torch.FloatTensor))

    def __len__(self):
        return self.x.shape[0]
In [8]:
def train_model(dem_loader, criterion, optimizer, epochs=500):
    model.train()
    train_loss_all_batches = []
    train_loss_epoch = []
    train_val = []
    for k in range(epochs):
        count_ = 0
        avg_loss = 0
        # =================== progress indicator ==============
        if k % ((epochs + 1) // 4) == 0:
            print('[{0}]: {1:.1f}% complete: '.format(k, k / epochs * 100))
        # =====================================================
        for img, dem in dem_loader:
            count_ += 1
            optimizer.zero_grad()
            # =================== forward =====================
            img = img.cuda()
            dem = dem.cuda()

            output = model(img) 
            loss = criterion(output, dem)
            # =================== backward ====================
            loss.backward()
            optimizer.step()
            
            train_loss_all_batches.append(loss.item())
            avg_loss += loss.item()
        # =================== validation ===================
        dem_data_val = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='val')
        dem_loader_val = DataLoader(dem_data_val, batch_size=1)
        val_loss, dummy, dem_pred_val, dem_in_test_val = valtest_model(dem_loader_val, criterion)
        
        train_loss_epoch.append(avg_loss/count_)
        train_val.append(val_loss)
        
        if k > 0:
            # N.B. train_val[k-1] is the validation loss from the previous epoch
            print('Epoch: ', k, 'trn_loss: ', avg_loss/count_, 'val_loss: ', train_val[k-1])
        else:
            print('Epoch: ', k, 'trn_loss: ', avg_loss/count_)
            
    torch.save(model.state_dict(), 'DeepEM_CNN_HelioML.pth')
    return train_loss_epoch, train_val

def valtest_model(dem_loader, criterion):
    # Evaluate the model on the supplied data (validation or test split)
    model.eval()
    
    val_loss = 0
    count = 0
    test_loss = []
    for img, dem in dem_loader:
        count += 1
        # =================== forward =====================
        img = img.cuda()
        dem = dem.cuda()
        
        output = model(img)
        loss = criterion(output, dem)
        test_loss.append(loss.item())
        val_loss += loss.item()
        
    return val_loss/count, test_loss, output, dem

We choose the Adam optimiser with a learning rate of 1e-4, and weight_decay set to 1e-9. We use Mean Squared Error (MSE) between the Sparse Inversion DEM map and the DeepEM map as our loss function.

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-9)
criterion = nn.MSELoss().cuda()

Using the classes and functions defined above, dem_data returns the training data, which is loaded by the DataLoader with batch_size=1 (one 512 x 512 image per batch). train_model then returns the training and validation losses for each epoch (train_loss and valdn_loss).

In [10]:
dem_data = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='train')
dem_loader = DataLoader(dem_data, batch_size=1)

t0=time.time() #Timing how long it takes to predict the DEMs
train_loss, valdn_loss = train_model(dem_loader, criterion, optimizer, epochs=500)
ttime = "Training time = {0} seconds".format(time.time()-t0)
print(ttime)
[0]: 0.0% complete: 
Epoch:  0 trn_loss:  2.720517635345459
Epoch:  1 trn_loss:  2.5714499950408936 val_loss:  2.992682933807373
Epoch:  2 trn_loss:  2.4300031661987305 val_loss:  2.836111307144165
Epoch:  3 trn_loss:  2.2960352897644043 val_loss:  2.6875860691070557
Epoch:  4 trn_loss:  2.169210433959961 val_loss:  2.546846628189087
Epoch:  5 trn_loss:  2.0493409633636475 val_loss:  2.4136226177215576
Epoch:  6 trn_loss:  1.9361165761947632 val_loss:  2.2875678539276123
Epoch:  7 trn_loss:  1.8289133310317993 val_loss:  2.1680002212524414
Epoch:  8 trn_loss:  1.727626919746399 val_loss:  2.054877996444702
Epoch:  9 trn_loss:  1.6316800117492676 val_loss:  1.9475942850112915
Epoch:  10 trn_loss:  1.540708065032959 val_loss:  1.845640778541565
Epoch:  11 trn_loss:  1.4545608758926392 val_loss:  1.7487554550170898
Epoch:  12 trn_loss:  1.3730440139770508 val_loss:  1.6566749811172485
Epoch:  13 trn_loss:  1.2958085536956787 val_loss:  1.5690476894378662
Epoch:  14 trn_loss:  1.2226998805999756 val_loss:  1.485715627670288
Epoch:  15 trn_loss:  1.153700828552246 val_loss:  1.4066847562789917
Epoch:  16 trn_loss:  1.0887045860290527 val_loss:  1.3317606449127197
Epoch:  17 trn_loss:  1.0275851488113403 val_loss:  1.2607773542404175
Epoch:  18 trn_loss:  0.9703051447868347 val_loss:  1.1936999559402466
Epoch:  19 trn_loss:  0.9168525338172913 val_loss:  1.1306010484695435
Epoch:  20 trn_loss:  0.8672220706939697 val_loss:  1.071594476699829
Epoch:  21 trn_loss:  0.8214365839958191 val_loss:  1.0167442560195923
Epoch:  22 trn_loss:  0.779478132724762 val_loss:  0.966064453125
Epoch:  23 trn_loss:  0.7412143349647522 val_loss:  0.919457733631134
Epoch:  24 trn_loss:  0.7066997289657593 val_loss:  0.8770449161529541
Epoch:  25 trn_loss:  0.6760365962982178 val_loss:  0.8389584422111511
Epoch:  26 trn_loss:  0.649154543876648 val_loss:  0.805176854133606
Epoch:  27 trn_loss:  0.625911295413971 val_loss:  0.7756196856498718
Epoch:  28 trn_loss:  0.606098473072052 val_loss:  0.7501355409622192
Epoch:  29 trn_loss:  0.58946692943573 val_loss:  0.7284752726554871
Epoch:  30 trn_loss:  0.5757277011871338 val_loss:  0.7103342413902283
Epoch:  31 trn_loss:  0.5646038055419922 val_loss:  0.6953976154327393
Epoch:  32 trn_loss:  0.5557752847671509 val_loss:  0.6833671927452087
Epoch:  33 trn_loss:  0.5489102602005005 val_loss:  0.6739091873168945
Epoch:  34 trn_loss:  0.5437614917755127 val_loss:  0.6667168140411377
Epoch:  35 trn_loss:  0.5400499105453491 val_loss:  0.6614289879798889
Epoch:  36 trn_loss:  0.5374208092689514 val_loss:  0.6575927734375
Epoch:  37 trn_loss:  0.5354573130607605 val_loss:  0.6546897888183594
Epoch:  38 trn_loss:  0.5337497591972351 val_loss:  0.6522194743156433
Epoch:  39 trn_loss:  0.5319835543632507 val_loss:  0.6497853398323059
Epoch:  40 trn_loss:  0.5299403667449951 val_loss:  0.6471131443977356
Epoch:  41 trn_loss:  0.5274948477745056 val_loss:  0.644046425819397
Epoch:  42 trn_loss:  0.5245950222015381 val_loss:  0.6405224204063416
Epoch:  43 trn_loss:  0.521243691444397 val_loss:  0.6365557909011841
Epoch:  44 trn_loss:  0.5174883008003235 val_loss:  0.6322070956230164
Epoch:  45 trn_loss:  0.513403058052063 val_loss:  0.627569317817688
Epoch:  46 trn_loss:  0.5090671181678772 val_loss:  0.6227394938468933
Epoch:  47 trn_loss:  0.5045645833015442 val_loss:  0.6178194284439087
Epoch:  48 trn_loss:  0.4999774396419525 val_loss:  0.6129067540168762
Epoch:  49 trn_loss:  0.4953801929950714 val_loss:  0.6080872416496277
Epoch:  50 trn_loss:  0.4908350706100464 val_loss:  0.603421151638031
Epoch:  51 trn_loss:  0.48641449213027954 val_loss:  0.5989682674407959
Epoch:  52 trn_loss:  0.4821968078613281 val_loss:  0.5947808027267456
Epoch:  53 trn_loss:  0.4782201945781708 val_loss:  0.5908913612365723
Epoch:  54 trn_loss:  0.47451189160346985 val_loss:  0.587300717830658
Epoch:  55 trn_loss:  0.4710758924484253 val_loss:  0.5839908123016357
Epoch:  56 trn_loss:  0.46790266036987305 val_loss:  0.5809233784675598
Epoch:  57 trn_loss:  0.4649679958820343 val_loss:  0.5780518054962158
Epoch:  58 trn_loss:  0.46223893761634827 val_loss:  0.5753233432769775
Epoch:  59 trn_loss:  0.45968690514564514 val_loss:  0.5726969838142395
Epoch:  60 trn_loss:  0.4572847783565521 val_loss:  0.5701340436935425
Epoch:  61 trn_loss:  0.4549981355667114 val_loss:  0.5675960779190063
Epoch:  62 trn_loss:  0.4527876079082489 val_loss:  0.5650503635406494
Epoch:  63 trn_loss:  0.45062288641929626 val_loss:  0.5624792575836182
Epoch:  64 trn_loss:  0.44848552346229553 val_loss:  0.559876561164856
Epoch:  65 trn_loss:  0.44636502861976624 val_loss:  0.557245135307312
Epoch:  66 trn_loss:  0.44425684213638306 val_loss:  0.5545910000801086
Epoch:  67 trn_loss:  0.4421635568141937 val_loss:  0.5519225597381592
Epoch:  68 trn_loss:  0.44008588790893555 val_loss:  0.5492464900016785
Epoch:  69 trn_loss:  0.4380251467227936 val_loss:  0.5465660691261292
Epoch:  70 trn_loss:  0.43598219752311707 val_loss:  0.5438817739486694
Epoch:  71 trn_loss:  0.43395718932151794 val_loss:  0.5411955714225769
Epoch:  72 trn_loss:  0.43195196986198425 val_loss:  0.538510262966156
Epoch:  73 trn_loss:  0.4299684464931488 val_loss:  0.535831868648529
Epoch:  74 trn_loss:  0.42800891399383545 val_loss:  0.5331716537475586
Epoch:  75 trn_loss:  0.4260765612125397 val_loss:  0.5305460691452026
Epoch:  76 trn_loss:  0.42417314648628235 val_loss:  0.5279734134674072
Epoch:  77 trn_loss:  0.42230165004730225 val_loss:  0.5254625082015991
Epoch:  78 trn_loss:  0.4204569458961487 val_loss:  0.5230153203010559
Epoch:  79 trn_loss:  0.418636292219162 val_loss:  0.520649790763855
Epoch:  80 trn_loss:  0.4168963134288788 val_loss:  0.5184133648872375
Epoch:  81 trn_loss:  0.41519343852996826 val_loss:  0.5162205100059509
Epoch:  82 trn_loss:  0.4135139584541321 val_loss:  0.5140416622161865
Epoch:  83 trn_loss:  0.41185620427131653 val_loss:  0.5118736624717712
Epoch:  84 trn_loss:  0.4102168679237366 val_loss:  0.509716272354126
Epoch:  85 trn_loss:  0.40859371423721313 val_loss:  0.5075733065605164
Epoch:  86 trn_loss:  0.40698516368865967 val_loss:  0.5054500699043274
Epoch:  87 trn_loss:  0.40538710355758667 val_loss:  0.5033512115478516
Epoch:  88 trn_loss:  0.40379223227500916 val_loss:  0.5012767314910889
Epoch:  89 trn_loss:  0.4021846055984497 val_loss:  0.49922382831573486
Epoch:  90 trn_loss:  0.40055906772613525 val_loss:  0.4971690773963928
Epoch:  91 trn_loss:  0.3989052176475525 val_loss:  0.49507924914360046
Epoch:  92 trn_loss:  0.39720144867897034 val_loss:  0.49295514822006226
Epoch:  93 trn_loss:  0.3954240381717682 val_loss:  0.4907684326171875
Epoch:  94 trn_loss:  0.39359250664711 val_loss:  0.4885094463825226
Epoch:  95 trn_loss:  0.39175736904144287 val_loss:  0.4862063229084015
Epoch:  96 trn_loss:  0.3899799883365631 val_loss:  0.483944833278656
Epoch:  97 trn_loss:  0.38832423090934753 val_loss:  0.4818169176578522
Epoch:  98 trn_loss:  0.38678768277168274 val_loss:  0.47982609272003174
Epoch:  99 trn_loss:  0.38529014587402344 val_loss:  0.47788265347480774
Epoch:  100 trn_loss:  0.3837854564189911 val_loss:  0.4759517312049866
Epoch:  101 trn_loss:  0.3822726905345917 val_loss:  0.4740460216999054
Epoch:  102 trn_loss:  0.38076186180114746 val_loss:  0.4721814692020416
Epoch:  103 trn_loss:  0.3792618215084076 val_loss:  0.47036755084991455
Epoch:  104 trn_loss:  0.3777764141559601 val_loss:  0.4686030149459839
Epoch:  105 trn_loss:  0.3763037919998169 val_loss:  0.46687597036361694
Epoch:  106 trn_loss:  0.374838650226593 val_loss:  0.46516773104667664
Epoch:  107 trn_loss:  0.3733749985694885 val_loss:  0.463457852602005
Epoch:  108 trn_loss:  0.3719085156917572 val_loss:  0.46172890067100525
Epoch:  109 trn_loss:  0.37043771147727966 val_loss:  0.4599689841270447
Epoch:  110 trn_loss:  0.36896270513534546 val_loss:  0.4581727981567383
Epoch:  111 trn_loss:  0.3674834966659546 val_loss:  0.45633983612060547
Epoch:  112 trn_loss:  0.36599811911582947 val_loss:  0.45447540283203125
Epoch:  113 trn_loss:  0.3645072877407074 val_loss:  0.45258671045303345
Epoch:  114 trn_loss:  0.3630163073539734 val_loss:  0.45068180561065674
Epoch:  115 trn_loss:  0.36152660846710205 val_loss:  0.44877704977989197
Epoch:  116 trn_loss:  0.36004722118377686 val_loss:  0.4469039738178253
Epoch:  117 trn_loss:  0.3585922420024872 val_loss:  0.44508621096611023
Epoch:  118 trn_loss:  0.35713210701942444 val_loss:  0.44322019815444946
Epoch:  119 trn_loss:  0.3556671440601349 val_loss:  0.441264271736145
Epoch:  120 trn_loss:  0.35419708490371704 val_loss:  0.4392160475254059
Epoch:  121 trn_loss:  0.3527316451072693 val_loss:  0.4371379017829895
Epoch:  122 trn_loss:  0.35127541422843933 val_loss:  0.4350956380367279
Epoch:  123 trn_loss:  0.34982478618621826 val_loss:  0.43313682079315186
Epoch:  124 trn_loss:  0.348373681306839 val_loss:  0.4312806725502014
[125]: 25.0% complete: 
Epoch:  125 trn_loss:  0.3469199240207672 val_loss:  0.4295201301574707
Epoch:  126 trn_loss:  0.3454664647579193 val_loss:  0.42783012986183167
Epoch:  127 trn_loss:  0.3440166711807251 val_loss:  0.4261738657951355
Epoch:  128 trn_loss:  0.34257179498672485 val_loss:  0.42450881004333496
Epoch:  129 trn_loss:  0.3411303162574768 val_loss:  0.42280280590057373
Epoch:  130 trn_loss:  0.33969107270240784 val_loss:  0.42104291915893555
Epoch:  131 trn_loss:  0.33825448155403137 val_loss:  0.4192368686199188
Epoch:  132 trn_loss:  0.33682143688201904 val_loss:  0.4174092710018158
Epoch:  133 trn_loss:  0.33539149165153503 val_loss:  0.41558653116226196
Epoch:  134 trn_loss:  0.33396199345588684 val_loss:  0.41378873586654663
Epoch:  135 trn_loss:  0.33253103494644165 val_loss:  0.4120235741138458
Epoch:  136 trn_loss:  0.33109962940216064 val_loss:  0.4102880656719208
Epoch:  137 trn_loss:  0.32967042922973633 val_loss:  0.40856727957725525
Epoch:  138 trn_loss:  0.32824480533599854 val_loss:  0.4068402051925659
Epoch:  139 trn_loss:  0.3268235921859741 val_loss:  0.405087411403656
Epoch:  140 trn_loss:  0.32540783286094666 val_loss:  0.4033021330833435
Epoch:  141 trn_loss:  0.32399845123291016 val_loss:  0.40149375796318054
Epoch:  142 trn_loss:  0.3225962519645691 val_loss:  0.39967775344848633
Epoch:  143 trn_loss:  0.3212057650089264 val_loss:  0.39787551760673523
Epoch:  144 trn_loss:  0.3198336362838745 val_loss:  0.39610692858695984
Epoch:  145 trn_loss:  0.31847119331359863 val_loss:  0.3943817913532257
Epoch:  146 trn_loss:  0.3171083629131317 val_loss:  0.3926984965801239
Epoch:  147 trn_loss:  0.3157441318035126 val_loss:  0.39104223251342773
Epoch:  148 trn_loss:  0.3143830895423889 val_loss:  0.3893944323062897
Epoch:  149 trn_loss:  0.313029021024704 val_loss:  0.3877376317977905
Epoch:  150 trn_loss:  0.31168273091316223 val_loss:  0.3860611021518707
Epoch:  151 trn_loss:  0.3103446662425995 val_loss:  0.3843643367290497
Epoch:  152 trn_loss:  0.3090153932571411 val_loss:  0.382657527923584
Epoch:  153 trn_loss:  0.3076956272125244 val_loss:  0.3809569180011749
Epoch:  154 trn_loss:  0.3063872456550598 val_loss:  0.3792808949947357
Epoch:  155 trn_loss:  0.30508995056152344 val_loss:  0.3776388168334961
Epoch:  156 trn_loss:  0.303801566362381 val_loss:  0.3760322034358978
Epoch:  157 trn_loss:  0.3025206923484802 val_loss:  0.3744567632675171
Epoch:  158 trn_loss:  0.30124780535697937 val_loss:  0.372900128364563
Epoch:  159 trn_loss:  0.29998430609703064 val_loss:  0.37134647369384766
Epoch:  160 trn_loss:  0.2987309992313385 val_loss:  0.3697834610939026
Epoch:  161 trn_loss:  0.29748842120170593 val_loss:  0.36820539832115173
Epoch:  162 trn_loss:  0.2962568998336792 val_loss:  0.3666168451309204
Epoch:  163 trn_loss:  0.295036643743515 val_loss:  0.36502566933631897
Epoch:  164 trn_loss:  0.29382768273353577 val_loss:  0.3634432256221771
Epoch:  165 trn_loss:  0.2926303446292877 val_loss:  0.3618801534175873
Epoch:  166 trn_loss:  0.2914446294307709 val_loss:  0.36034318804740906
Epoch:  167 trn_loss:  0.2902708947658539 val_loss:  0.3588337004184723
Epoch:  168 trn_loss:  0.28910937905311584 val_loss:  0.3573489487171173
Epoch:  169 trn_loss:  0.28796035051345825 val_loss:  0.35588300228118896
Epoch:  170 trn_loss:  0.28682413697242737 val_loss:  0.35442987084388733
Epoch:  171 trn_loss:  0.2857007086277008 val_loss:  0.3529864251613617
Epoch:  172 trn_loss:  0.28459030389785767 val_loss:  0.35155197978019714
Epoch:  173 trn_loss:  0.28349268436431885 val_loss:  0.3501293361186981
Epoch:  174 trn_loss:  0.2824079990386963 val_loss:  0.34872186183929443
Epoch:  175 trn_loss:  0.28133609890937805 val_loss:  0.34733250737190247
Epoch:  176 trn_loss:  0.28027698397636414 val_loss:  0.34596219658851624
Epoch:  177 trn_loss:  0.2792307734489441 val_loss:  0.3446095883846283
Epoch:  178 trn_loss:  0.27819761633872986 val_loss:  0.3432721793651581
Epoch:  179 trn_loss:  0.27717745304107666 val_loss:  0.34194621443748474
Epoch:  180 trn_loss:  0.2761702835559845 val_loss:  0.34063005447387695
Epoch:  181 trn_loss:  0.27517616748809814 val_loss:  0.33932459354400635
Epoch:  182 trn_loss:  0.27419498562812805 val_loss:  0.3380321264266968
Epoch:  183 trn_loss:  0.27322694659233093 val_loss:  0.33675676584243774
Epoch:  184 trn_loss:  0.2722719609737396 val_loss:  0.3355017602443695
Epoch:  185 trn_loss:  0.2713298797607422 val_loss:  0.3342685103416443
Epoch:  186 trn_loss:  0.2704004943370819 val_loss:  0.33305591344833374
Epoch:  187 trn_loss:  0.2694838047027588 val_loss:  0.33186066150665283
Epoch:  188 trn_loss:  0.2685796916484833 val_loss:  0.3306792080402374
Epoch:  189 trn_loss:  0.26768800616264343 val_loss:  0.32950714230537415
Epoch:  190 trn_loss:  0.2668086588382721 val_loss:  0.3283434212207794
Epoch:  191 trn_loss:  0.2659415900707245 val_loss:  0.327188640832901
Epoch:  192 trn_loss:  0.2650865912437439 val_loss:  0.3260452151298523
Epoch:  193 trn_loss:  0.26424339413642883 val_loss:  0.3249155282974243
Epoch:  194 trn_loss:  0.26341190934181213 val_loss:  0.32380229234695435
Epoch:  195 trn_loss:  0.2625919580459595 val_loss:  0.3227068781852722
Epoch:  196 trn_loss:  0.26178327202796936 val_loss:  0.3216293454170227
Epoch:  197 trn_loss:  0.26098570227622986 val_loss:  0.32056835293769836
Epoch:  198 trn_loss:  0.260198712348938 val_loss:  0.3195227384567261
Epoch:  199 trn_loss:  0.259422242641449 val_loss:  0.31849080324172974
Epoch:  200 trn_loss:  0.25865572690963745 val_loss:  0.3174709677696228
Epoch:  201 trn_loss:  0.25789886713027954 val_loss:  0.3164633512496948
Epoch:  202 trn_loss:  0.2571505904197693 val_loss:  0.3154677152633667
Epoch:  203 trn_loss:  0.25640979409217834 val_loss:  0.31448352336883545
Epoch:  204 trn_loss:  0.25567570328712463 val_loss:  0.3135102689266205
Epoch:  205 trn_loss:  0.25494787096977234 val_loss:  0.31254690885543823
Epoch:  206 trn_loss:  0.25422531366348267 val_loss:  0.3115929067134857
Epoch:  207 trn_loss:  0.2535097599029541 val_loss:  0.31064674258232117
Epoch:  208 trn_loss:  0.252804160118103 val_loss:  0.30970701575279236
Epoch:  209 trn_loss:  0.252104789018631 val_loss:  0.3087732493877411
Epoch:  210 trn_loss:  0.25141236186027527 val_loss:  0.30785098671913147
Epoch:  211 trn_loss:  0.2507249116897583 val_loss:  0.30695295333862305
Epoch:  212 trn_loss:  0.2500661313533783 val_loss:  0.3061336576938629
Epoch:  213 trn_loss:  0.2494550347328186 val_loss:  0.30540451407432556
Epoch:  214 trn_loss:  0.24884366989135742 val_loss:  0.3046551048755646
Epoch:  215 trn_loss:  0.24822528660297394 val_loss:  0.3038537800312042
Epoch:  216 trn_loss:  0.24761027097702026 val_loss:  0.3030061423778534
Epoch:  217 trn_loss:  0.2470024973154068 val_loss:  0.3021312355995178
Epoch:  218 trn_loss:  0.24640202522277832 val_loss:  0.3012598156929016
Epoch:  219 trn_loss:  0.24580825865268707 val_loss:  0.30042243003845215
Epoch:  220 trn_loss:  0.2452208697795868 val_loss:  0.2996369004249573
Epoch:  221 trn_loss:  0.244639053940773 val_loss:  0.29889917373657227
Epoch:  222 trn_loss:  0.24405993521213531 val_loss:  0.2981848418712616
Epoch:  223 trn_loss:  0.24347925186157227 val_loss:  0.29746243357658386
Epoch:  224 trn_loss:  0.2428930103778839 val_loss:  0.29669973254203796
Epoch:  225 trn_loss:  0.2422923445701599 val_loss:  0.29588282108306885
Epoch:  226 trn_loss:  0.24167487025260925 val_loss:  0.29502788186073303
Epoch:  227 trn_loss:  0.2410658448934555 val_loss:  0.2941628694534302
Epoch:  228 trn_loss:  0.2404797077178955 val_loss:  0.2933007776737213
Epoch:  229 trn_loss:  0.23993295431137085 val_loss:  0.29246610403060913
Epoch:  230 trn_loss:  0.2394104301929474 val_loss:  0.2916789650917053
Epoch:  231 trn_loss:  0.23890362679958344 val_loss:  0.29092440009117126
Epoch:  232 trn_loss:  0.2384023666381836 val_loss:  0.29022637009620667
Epoch:  233 trn_loss:  0.23790280520915985 val_loss:  0.2896081209182739
Epoch:  234 trn_loss:  0.23740547895431519 val_loss:  0.2890579104423523
Epoch:  235 trn_loss:  0.23691561818122864 val_loss:  0.28854289650917053
Epoch:  236 trn_loss:  0.23643429577350616 val_loss:  0.28802838921546936
Epoch:  237 trn_loss:  0.23595643043518066 val_loss:  0.2874943017959595
Epoch:  238 trn_loss:  0.23547813296318054 val_loss:  0.2869454622268677
Epoch:  239 trn_loss:  0.23500068485736847 val_loss:  0.28640303015708923
Epoch:  240 trn_loss:  0.23452812433242798 val_loss:  0.2858823537826538
Epoch:  241 trn_loss:  0.23406320810317993 val_loss:  0.28537803888320923
Epoch:  242 trn_loss:  0.23360517621040344 val_loss:  0.2848685085773468
Epoch:  243 trn_loss:  0.23315173387527466 val_loss:  0.28433284163475037
Epoch:  244 trn_loss:  0.23270143568515778 val_loss:  0.2837727665901184
Epoch:  245 trn_loss:  0.23225480318069458 val_loss:  0.283214807510376
Epoch:  246 trn_loss:  0.23181205987930298 val_loss:  0.2826923131942749
Epoch:  247 trn_loss:  0.23137272894382477 val_loss:  0.2822265326976776
Epoch:  248 trn_loss:  0.23093655705451965 val_loss:  0.28180626034736633
Epoch:  249 trn_loss:  0.23050373792648315 val_loss:  0.28139224648475647
[250]: 50.0% complete: 
Epoch:  250 trn_loss:  0.23007355630397797 val_loss:  0.2809486985206604
Epoch:  251 trn_loss:  0.22964556515216827 val_loss:  0.28046396374702454
Epoch:  252 trn_loss:  0.2292202115058899 val_loss:  0.2799537479877472
Epoch:  253 trn_loss:  0.22879669070243835 val_loss:  0.279446542263031
Epoch:  254 trn_loss:  0.22837336361408234 val_loss:  0.2789609432220459
Epoch:  255 trn_loss:  0.22794978320598602 val_loss:  0.2784919738769531
Epoch:  256 trn_loss:  0.22752612829208374 val_loss:  0.2780184745788574
Epoch:  257 trn_loss:  0.22710344195365906 val_loss:  0.27752435207366943
Epoch:  258 trn_loss:  0.22668224573135376 val_loss:  0.27700480818748474
Epoch:  259 trn_loss:  0.22626447677612305 val_loss:  0.27647534012794495
Epoch:  260 trn_loss:  0.2258533388376236 val_loss:  0.2759590148925781
Epoch:  261 trn_loss:  0.22545453906059265 val_loss:  0.2754822373390198
Epoch:  262 trn_loss:  0.22506257891654968 val_loss:  0.2750554084777832
Epoch:  263 trn_loss:  0.22467337548732758 val_loss:  0.2746695280075073
Epoch:  264 trn_loss:  0.22428232431411743 val_loss:  0.2743052840232849
Epoch:  265 trn_loss:  0.22388948500156403 val_loss:  0.2739402949810028
Epoch:  266 trn_loss:  0.2234969139099121 val_loss:  0.27357298135757446
Epoch:  267 trn_loss:  0.22310614585876465 val_loss:  0.2732067406177521
Epoch:  268 trn_loss:  0.2227174937725067 val_loss:  0.27283775806427
Epoch:  269 trn_loss:  0.2223328948020935 val_loss:  0.2724590301513672
Epoch:  270 trn_loss:  0.2219529151916504 val_loss:  0.27206385135650635
Epoch:  271 trn_loss:  0.22157585620880127 val_loss:  0.27165547013282776
Epoch:  272 trn_loss:  0.2211996465921402 val_loss:  0.27124887704849243
Epoch:  273 trn_loss:  0.22082313895225525 val_loss:  0.2708624601364136
Epoch:  274 trn_loss:  0.22044669091701508 val_loss:  0.2705067992210388
Epoch:  275 trn_loss:  0.22007158398628235 val_loss:  0.27017322182655334
Epoch:  276 trn_loss:  0.2196982353925705 val_loss:  0.26984238624572754
Epoch:  277 trn_loss:  0.21932564675807953 val_loss:  0.2694939374923706
Epoch:  278 trn_loss:  0.2189537137746811 val_loss:  0.26912862062454224
Epoch:  279 trn_loss:  0.21858231723308563 val_loss:  0.26875779032707214
Epoch:  280 trn_loss:  0.2182110995054245 val_loss:  0.2683919370174408
Epoch:  281 trn_loss:  0.21783941984176636 val_loss:  0.26803430914878845
Epoch:  282 trn_loss:  0.21746760606765747 val_loss:  0.26767662167549133
Epoch:  283 trn_loss:  0.21709734201431274 val_loss:  0.267311692237854
Epoch:  284 trn_loss:  0.21673113107681274 val_loss:  0.26692721247673035
Epoch:  285 trn_loss:  0.21636822819709778 val_loss:  0.2665267884731293
Epoch:  286 trn_loss:  0.21600845456123352 val_loss:  0.26613423228263855
Epoch:  287 trn_loss:  0.21565183997154236 val_loss:  0.2657707631587982
Epoch:  288 trn_loss:  0.2152981162071228 val_loss:  0.2654394507408142
Epoch:  289 trn_loss:  0.21494588255882263 val_loss:  0.26512253284454346
Epoch:  290 trn_loss:  0.2145947366952896 val_loss:  0.26479652523994446
Epoch:  291 trn_loss:  0.21424508094787598 val_loss:  0.26445162296295166
Epoch:  292 trn_loss:  0.2138969898223877 val_loss:  0.2640887498855591
Epoch:  293 trn_loss:  0.2135496437549591 val_loss:  0.26371198892593384
Epoch:  294 trn_loss:  0.21320249140262604 val_loss:  0.26331961154937744
Epoch:  295 trn_loss:  0.21285559237003326 val_loss:  0.26290369033813477
Epoch:  296 trn_loss:  0.2125086486339569 val_loss:  0.26246121525764465
Epoch:  297 trn_loss:  0.21215955913066864 val_loss:  0.2619968056678772
Epoch:  298 trn_loss:  0.21180710196495056 val_loss:  0.2615213394165039
Epoch:  299 trn_loss:  0.21145445108413696 val_loss:  0.2610474228858948
Epoch:  300 trn_loss:  0.21110428869724274 val_loss:  0.26057708263397217
Epoch:  301 trn_loss:  0.21076083183288574 val_loss:  0.26011136174201965
Epoch:  302 trn_loss:  0.2104257196187973 val_loss:  0.2596893310546875
Epoch:  303 trn_loss:  0.21008871495723724 val_loss:  0.2593444585800171
Epoch:  304 trn_loss:  0.20974841713905334 val_loss:  0.2590535581111908
Epoch:  305 trn_loss:  0.20940981805324554 val_loss:  0.2587902843952179
Epoch:  306 trn_loss:  0.20907369256019592 val_loss:  0.2585374116897583
Epoch:  307 trn_loss:  0.20873907208442688 val_loss:  0.2582770884037018
Epoch:  308 trn_loss:  0.20840512216091156 val_loss:  0.2579915523529053
Epoch:  309 trn_loss:  0.2080717235803604 val_loss:  0.25766631960868835
Epoch:  310 trn_loss:  0.2077386975288391 val_loss:  0.25730040669441223
Epoch:  311 trn_loss:  0.20740638673305511 val_loss:  0.2569100260734558
Epoch:  312 trn_loss:  0.20707452297210693 val_loss:  0.2565248906612396
Epoch:  313 trn_loss:  0.20674149692058563 val_loss:  0.25616714358329773
Epoch:  314 trn_loss:  0.2064056098461151 val_loss:  0.2558384835720062
Epoch:  315 trn_loss:  0.20606675744056702 val_loss:  0.2555282413959503
Epoch:  316 trn_loss:  0.20572897791862488 val_loss:  0.25522732734680176
Epoch:  317 trn_loss:  0.20540033280849457 val_loss:  0.2549355924129486
Epoch:  318 trn_loss:  0.20508302748203278 val_loss:  0.2546514868736267
Epoch:  319 trn_loss:  0.20477096736431122 val_loss:  0.2543649673461914
Epoch:  320 trn_loss:  0.20445936918258667 val_loss:  0.2540694773197174
Epoch:  321 trn_loss:  0.20414535701274872 val_loss:  0.25376734137535095
Epoch:  322 trn_loss:  0.203828364610672 val_loss:  0.25346270203590393
Epoch:  323 trn_loss:  0.2035098671913147 val_loss:  0.25315940380096436
Epoch:  324 trn_loss:  0.20319294929504395 val_loss:  0.2528567314147949
Epoch:  325 trn_loss:  0.2028798758983612 val_loss:  0.2525583505630493
Epoch:  326 trn_loss:  0.20257116854190826 val_loss:  0.25227075815200806
Epoch:  327 trn_loss:  0.2022656500339508 val_loss:  0.25199374556541443
Epoch:  328 trn_loss:  0.20196178555488586 val_loss:  0.25172460079193115
Epoch:  329 trn_loss:  0.20165878534317017 val_loss:  0.2514559030532837
Epoch:  330 trn_loss:  0.2013561874628067 val_loss:  0.2511786222457886
Epoch:  331 trn_loss:  0.2010539025068283 val_loss:  0.25088828802108765
Epoch:  332 trn_loss:  0.20075224339962006 val_loss:  0.2505873441696167
Epoch:  333 trn_loss:  0.20045195519924164 val_loss:  0.25028255581855774
Epoch:  334 trn_loss:  0.20015352964401245 val_loss:  0.2499808669090271
Epoch:  335 trn_loss:  0.19985702633857727 val_loss:  0.2496829777956009
Epoch:  336 trn_loss:  0.19956228137016296 val_loss:  0.249386727809906
Epoch:  337 trn_loss:  0.19926895201206207 val_loss:  0.24909284710884094
Epoch:  338 trn_loss:  0.19897693395614624 val_loss:  0.24880051612854004
Epoch:  339 trn_loss:  0.19868628680706024 val_loss:  0.24850903451442719
Epoch:  340 trn_loss:  0.1983967274427414 val_loss:  0.2482195943593979
Epoch:  341 trn_loss:  0.1981082260608673 val_loss:  0.24793249368667603
Epoch:  342 trn_loss:  0.19782096147537231 val_loss:  0.24764996767044067
Epoch:  343 trn_loss:  0.1975352168083191 val_loss:  0.2473699003458023
Epoch:  344 trn_loss:  0.1972511261701584 val_loss:  0.2470943182706833
Epoch:  345 trn_loss:  0.19696849584579468 val_loss:  0.24682076275348663
Epoch:  346 trn_loss:  0.1966872364282608 val_loss:  0.2465454488992691
Epoch:  347 trn_loss:  0.19640739262104034 val_loss:  0.2462620735168457
Epoch:  348 trn_loss:  0.19612900912761688 val_loss:  0.24597685039043427
Epoch:  349 trn_loss:  0.1958521604537964 val_loss:  0.24569553136825562
Epoch:  350 trn_loss:  0.19557689130306244 val_loss:  0.24541902542114258
Epoch:  351 trn_loss:  0.195303276181221 val_loss:  0.2451448142528534
Epoch:  352 trn_loss:  0.1950312852859497 val_loss:  0.24487270414829254
Epoch:  353 trn_loss:  0.19476088881492615 val_loss:  0.24460318684577942
Epoch:  354 trn_loss:  0.19449208676815033 val_loss:  0.24433715641498566
Epoch:  355 trn_loss:  0.19422470033168793 val_loss:  0.24407032132148743
Epoch:  356 trn_loss:  0.19395869970321655 val_loss:  0.24380047619342804
Epoch:  357 trn_loss:  0.19369405508041382 val_loss:  0.24352973699569702
Epoch:  358 trn_loss:  0.1934306025505066 val_loss:  0.24326191842556
Epoch:  359 trn_loss:  0.19316822290420532 val_loss:  0.2429966926574707
Epoch:  360 trn_loss:  0.19290682673454285 val_loss:  0.24273164570331573
Epoch:  361 trn_loss:  0.19264641404151917 val_loss:  0.24246689677238464
Epoch:  362 trn_loss:  0.19238699972629547 val_loss:  0.24220331013202667
Epoch:  363 trn_loss:  0.1921285092830658 val_loss:  0.24194449186325073
Epoch:  364 trn_loss:  0.1918710172176361 val_loss:  0.24168723821640015
Epoch:  365 trn_loss:  0.1916145533323288 val_loss:  0.24142853915691376
Epoch:  366 trn_loss:  0.19135892391204834 val_loss:  0.24116498231887817
Epoch:  367 trn_loss:  0.19110405445098877 val_loss:  0.24090419709682465
Epoch:  368 trn_loss:  0.19084981083869934 val_loss:  0.24064858257770538
Epoch:  369 trn_loss:  0.19059611856937408 val_loss:  0.2403944581747055
Epoch:  370 trn_loss:  0.19034279882907867 val_loss:  0.24014070630073547
Epoch:  371 trn_loss:  0.19008919596672058 val_loss:  0.23988883197307587
Epoch:  372 trn_loss:  0.1898353099822998 val_loss:  0.23963764309883118
Epoch:  373 trn_loss:  0.18958121538162231 val_loss:  0.23938502371311188
Epoch:  374 trn_loss:  0.18932710587978363 val_loss:  0.23913346230983734
[375]: 75.0% complete: 
Epoch:  375 trn_loss:  0.1890748292207718 val_loss:  0.238894522190094
Epoch:  376 trn_loss:  0.188825786113739 val_loss:  0.2386762797832489
Epoch:  377 trn_loss:  0.18857771158218384 val_loss:  0.23847529292106628
Epoch:  378 trn_loss:  0.18832935392856598 val_loss:  0.23828119039535522
Epoch:  379 trn_loss:  0.18808183073997498 val_loss:  0.23808833956718445
Epoch:  380 trn_loss:  0.18783555924892426 val_loss:  0.2378917783498764
Epoch:  381 trn_loss:  0.18759050965309143 val_loss:  0.23767992854118347
Epoch:  382 trn_loss:  0.18734709918498993 val_loss:  0.23744869232177734
Epoch:  383 trn_loss:  0.18710505962371826 val_loss:  0.23721159994602203
Epoch:  384 trn_loss:  0.18686382472515106 val_loss:  0.23697692155838013
Epoch:  385 trn_loss:  0.186623215675354 val_loss:  0.23673711717128754
Epoch:  386 trn_loss:  0.1863832175731659 val_loss:  0.23647886514663696
Epoch:  387 trn_loss:  0.1861439198255539 val_loss:  0.23620271682739258
Epoch:  388 trn_loss:  0.18590551614761353 val_loss:  0.23592042922973633
Epoch:  389 trn_loss:  0.18566836416721344 val_loss:  0.23563691973686218
Epoch:  390 trn_loss:  0.1854325234889984 val_loss:  0.23535510897636414
Epoch:  391 trn_loss:  0.18519799411296844 val_loss:  0.235079824924469
Epoch:  392 trn_loss:  0.1849648356437683 val_loss:  0.2348124235868454
Epoch:  393 trn_loss:  0.18473288416862488 val_loss:  0.2345460057258606
Epoch:  394 trn_loss:  0.1845020055770874 val_loss:  0.2342797964811325
Epoch:  395 trn_loss:  0.18427212536334991 val_loss:  0.23402251303195953
Epoch:  396 trn_loss:  0.18404310941696167 val_loss:  0.23377874493598938
Epoch:  397 trn_loss:  0.18381494283676147 val_loss:  0.23354452848434448
Epoch:  398 trn_loss:  0.18358762562274933 val_loss:  0.23331600427627563
Epoch:  399 trn_loss:  0.18336105346679688 val_loss:  0.23309046030044556
Epoch:  400 trn_loss:  0.18313530087471008 val_loss:  0.232865571975708
Epoch:  401 trn_loss:  0.18291045725345612 val_loss:  0.23263958096504211
Epoch:  402 trn_loss:  0.18268656730651855 val_loss:  0.2324121743440628
Epoch:  403 trn_loss:  0.1824636459350586 val_loss:  0.23218554258346558
Epoch:  404 trn_loss:  0.18224164843559265 val_loss:  0.2319575995206833
Epoch:  405 trn_loss:  0.18202045559883118 val_loss:  0.23172275722026825
Epoch:  406 trn_loss:  0.1817999631166458 val_loss:  0.23148134350776672
Epoch:  407 trn_loss:  0.18158024549484253 val_loss:  0.23123939335346222
Epoch:  408 trn_loss:  0.18136116862297058 val_loss:  0.2310003638267517
Epoch:  409 trn_loss:  0.18114276230335236 val_loss:  0.2307618409395218
Epoch:  410 trn_loss:  0.18092504143714905 val_loss:  0.2305237501859665
Epoch:  411 trn_loss:  0.18070799112319946 val_loss:  0.23028503358364105
Epoch:  412 trn_loss:  0.18049176037311554 val_loss:  0.23004287481307983
Epoch:  413 trn_loss:  0.18027636408805847 val_loss:  0.22979889810085297
Epoch:  414 trn_loss:  0.18006153404712677 val_loss:  0.22955799102783203
Epoch:  415 trn_loss:  0.1798473596572876 val_loss:  0.22932063043117523
Epoch:  416 trn_loss:  0.17963387072086334 val_loss:  0.22908517718315125
Epoch:  417 trn_loss:  0.1794208437204361 val_loss:  0.22885063290596008
Epoch:  418 trn_loss:  0.17920824885368347 val_loss:  0.22861893475055695
Epoch:  419 trn_loss:  0.17899605631828308 val_loss:  0.22838878631591797
Epoch:  420 trn_loss:  0.17878447473049164 val_loss:  0.2281588912010193
Epoch:  421 trn_loss:  0.17857390642166138 val_loss:  0.22793038189411163
Epoch:  422 trn_loss:  0.17836421728134155 val_loss:  0.22770363092422485
Epoch:  423 trn_loss:  0.17815512418746948 val_loss:  0.22747738659381866
Epoch:  424 trn_loss:  0.17794662714004517 val_loss:  0.2272491157054901
Epoch:  425 trn_loss:  0.17773878574371338 val_loss:  0.22702032327651978
Epoch:  426 trn_loss:  0.1775316596031189 val_loss:  0.2267915904521942
Epoch:  427 trn_loss:  0.1773253083229065 val_loss:  0.22655746340751648
Epoch:  428 trn_loss:  0.17711952328681946 val_loss:  0.22632186114788055
Epoch:  429 trn_loss:  0.1769145131111145 val_loss:  0.226089209318161
Epoch:  430 trn_loss:  0.1767100691795349 val_loss:  0.22585827112197876
Epoch:  431 trn_loss:  0.17650587856769562 val_loss:  0.22562970221042633
Epoch:  432 trn_loss:  0.17630203068256378 val_loss:  0.22540447115898132
Epoch:  433 trn_loss:  0.17609840631484985 val_loss:  0.22517959773540497
Epoch:  434 trn_loss:  0.1758948713541031 val_loss:  0.22495406866073608
Epoch:  435 trn_loss:  0.17569130659103394 val_loss:  0.22472889721393585
Epoch:  436 trn_loss:  0.17548780143260956 val_loss:  0.22449976205825806
Epoch:  437 trn_loss:  0.1752830147743225 val_loss:  0.22426454722881317
Epoch:  438 trn_loss:  0.17507389187812805 val_loss:  0.22402794659137726
Epoch:  439 trn_loss:  0.1748594343662262 val_loss:  0.22379425168037415
Epoch:  440 trn_loss:  0.174647256731987 val_loss:  0.22356335818767548
Epoch:  441 trn_loss:  0.17444376647472382 val_loss:  0.2233429253101349
Epoch:  442 trn_loss:  0.17423875629901886 val_loss:  0.22312171757221222
Epoch:  443 trn_loss:  0.1740340143442154 val_loss:  0.22290436923503876
Epoch:  444 trn_loss:  0.17383411526679993 val_loss:  0.22268684208393097
Epoch:  445 trn_loss:  0.17363642156124115 val_loss:  0.2224578857421875
Epoch:  446 trn_loss:  0.1734364777803421 val_loss:  0.22221679985523224
Epoch:  447 trn_loss:  0.17323671281337738 val_loss:  0.2219763845205307
Epoch:  448 trn_loss:  0.17303945124149323 val_loss:  0.22174912691116333
Epoch:  449 trn_loss:  0.17284171283245087 val_loss:  0.2215387225151062
Epoch:  450 trn_loss:  0.17264242470264435 val_loss:  0.22133596241474152
Epoch:  451 trn_loss:  0.1724439114332199 val_loss:  0.22113367915153503
Epoch:  452 trn_loss:  0.17224597930908203 val_loss:  0.2209206223487854
Epoch:  453 trn_loss:  0.1720469892024994 val_loss:  0.22069334983825684
Epoch:  454 trn_loss:  0.1718476265668869 val_loss:  0.22045543789863586
Epoch:  455 trn_loss:  0.17164887487888336 val_loss:  0.22021572291851044
Epoch:  456 trn_loss:  0.17144978046417236 val_loss:  0.21998310089111328
Epoch:  457 trn_loss:  0.17124979197978973 val_loss:  0.21976037323474884
Epoch:  458 trn_loss:  0.17104962468147278 val_loss:  0.21954482793807983
Epoch:  459 trn_loss:  0.1708494871854782 val_loss:  0.21932977437973022
Epoch:  460 trn_loss:  0.1706482172012329 val_loss:  0.21910595893859863
Epoch:  461 trn_loss:  0.17044557631015778 val_loss:  0.2188740372657776
Epoch:  462 trn_loss:  0.17024178802967072 val_loss:  0.2186402529478073
Epoch:  463 trn_loss:  0.17003701627254486 val_loss:  0.21840617060661316
Epoch:  464 trn_loss:  0.16982880234718323 val_loss:  0.21817241609096527
Epoch:  465 trn_loss:  0.16961956024169922 val_loss:  0.21793916821479797
Epoch:  466 trn_loss:  0.1694081574678421 val_loss:  0.21770605444908142
Epoch:  467 trn_loss:  0.16919618844985962 val_loss:  0.21747441589832306
Epoch:  468 trn_loss:  0.16898375749588013 val_loss:  0.21724523603916168
Epoch:  469 trn_loss:  0.1687786728143692 val_loss:  0.21702449023723602
Epoch:  470 trn_loss:  0.16857337951660156 val_loss:  0.21684497594833374
Epoch:  471 trn_loss:  0.16836611926555634 val_loss:  0.21668770909309387
Epoch:  472 trn_loss:  0.16815941035747528 val_loss:  0.2165074646472931
Epoch:  473 trn_loss:  0.16793929040431976 val_loss:  0.2163211703300476
Epoch:  474 trn_loss:  0.1677149534225464 val_loss:  0.21613281965255737
Epoch:  475 trn_loss:  0.16749779880046844 val_loss:  0.21588772535324097
Epoch:  476 trn_loss:  0.1672801524400711 val_loss:  0.21551799774169922
Epoch:  477 trn_loss:  0.16706915199756622 val_loss:  0.2152245193719864
Epoch:  478 trn_loss:  0.16686460375785828 val_loss:  0.2150420993566513
Epoch:  479 trn_loss:  0.16666196286678314 val_loss:  0.21491365134716034
Epoch:  480 trn_loss:  0.16646282374858856 val_loss:  0.21477101743221283
Epoch:  481 trn_loss:  0.166264608502388 val_loss:  0.21452462673187256
Epoch:  482 trn_loss:  0.16607001423835754 val_loss:  0.21423496305942535
Epoch:  483 trn_loss:  0.16587771475315094 val_loss:  0.21401453018188477
Epoch:  484 trn_loss:  0.1656862050294876 val_loss:  0.21378809213638306
Epoch:  485 trn_loss:  0.1654946357011795 val_loss:  0.21348606050014496
Epoch:  486 trn_loss:  0.16530491411685944 val_loss:  0.213189035654068
Epoch:  487 trn_loss:  0.16511470079421997 val_loss:  0.2129574567079544
Epoch:  488 trn_loss:  0.16492301225662231 val_loss:  0.21279770135879517
Epoch:  489 trn_loss:  0.16473230719566345 val_loss:  0.21266847848892212
Epoch:  490 trn_loss:  0.16454045474529266 val_loss:  0.21247227489948273
Epoch:  491 trn_loss:  0.1643500030040741 val_loss:  0.21223384141921997
Epoch:  492 trn_loss:  0.16416028141975403 val_loss:  0.21204014122486115
Epoch:  493 trn_loss:  0.16397173702716827 val_loss:  0.21185758709907532
Epoch:  494 trn_loss:  0.16378480195999146 val_loss:  0.21161359548568726
Epoch:  495 trn_loss:  0.16359752416610718 val_loss:  0.21132279932498932
Epoch:  496 trn_loss:  0.1634114533662796 val_loss:  0.2110316902399063
Epoch:  497 trn_loss:  0.16322572529315948 val_loss:  0.21078820526599884
Epoch:  498 trn_loss:  0.16304101049900055 val_loss:  0.21059785783290863
Epoch:  499 trn_loss:  0.16285713016986847 val_loss:  0.21039755642414093
Training time = 150.80658340454102 seconds

Plotting: MSE Loss for Training and Validation

To understand how well the model has trained, Figure 2 shows the MSE loss for training (blue) and validation (orange) as a function of epoch.

In [11]:
plt.plot(np.arange(len(train_loss[:])), train_loss[:], color="blue")
plt.plot(np.arange(len(train_loss[:]))+1, valdn_loss[:], color="orange")
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.show()

Figure 2: Training and Validation MSE loss (blue, orange) as a function of Epoch.


Step 4: Testing the Model

Now that the model has been trained, testing the model is a computationally cheap procedure. As before, we choose the data using DEMdata, and load it with DataLoader. Using valtest_model, the DeepEM map is created (${\texttt{output = model(img)}}$), and the MSE loss is calculated as during training.

In [12]:
dem_data_test = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='test')
dem_loader = DataLoader(dem_data_test, batch_size=1)

t0=time.time() #Timing how long it takes to predict the DEMs
dummy, test_loss, dem_pred, dem_in_test = valtest_model(dem_loader, criterion)
performance = "Number of DEM solutions per second = {0}".format((y_test.shape[2]*y_test.shape[3])/(time.time()-t0))

print(performance)
Number of DEM solutions per second = 2983061.2417108673

Plotting: AIA, Basis Pursuit, DeepEM

With the DeepEM map calculated, we can now compare the solutions obtained by Basis Pursuit and DeepEM. Figure 3 is similar to Figure 1, with an additional row corresponding to the DeepEM solutions: SDO/AIA images (top), Basis Pursuit DEM maps (middle), and the corresponding DeepEM maps (bottom).

In [13]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_in_test[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_in_test[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_in_test[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_pred[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_pred[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_pred[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)

Figure 3: Left to Right: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top), with the corresponding DEM bins from Basis Pursuit (chosen at the peak sensitivity of each of the SDO/AIA channels) shown below (middle). The bottom row shows the DeepEM solutions that correspond to the same bins as the Basis Pursuit solutions. DeepEM provides solutions that are similar to Basis Pursuit, but importantly, provides DEM solutions for every pixel.


Furthermore, as we have the original Basis Pursuit DEM solutions (the ground truth), we can compare the average DEM from Basis Pursuit to the average DEM from DeepEM, as they should be similar. Figure 4 shows the average Basis Pursuit DEM (black curve) and the DeepEM solution (light blue bars/dotted line).

In [14]:
def PlotTotalEM(em_unscaled, em_pred_unscaled, lgtaxis, status):
    # Average the DEMs over all pixels with status == 0 and plot the mean
    # Basis Pursuit DEM (curve) against the mean DeepEM DEM (bars)
    mask = np.zeros([status.shape[0],status.shape[1]])
    mask[np.where(status == 0.0)] = 1.0
    nmask = np.sum(mask)
    
    EM_tru_sum = np.zeros([lgtaxis.size])
    EM_inv_sum = np.zeros([lgtaxis.size])
    
    for i in range(lgtaxis.size):
        EM_tru_sum[i] = np.sum(em_unscaled[0,i,:,:]*mask)/nmask
        EM_inv_sum[i] = np.sum(em_pred_unscaled[0,i,:,:]*mask)/nmask
        
    fig = plt.figure()
    plt.plot(lgtaxis, EM_tru_sum, linewidth=3, color="black")
    plt.plot(lgtaxis, EM_inv_sum, linewidth=3, color="lightblue", linestyle='--')
    plt.tick_params(axis='both', which='major')
    plt.tick_params(axis='both', which='minor')
    
    dlogT = lgtaxis[1]-lgtaxis[0]
    plt.bar(lgtaxis-0.5*dlogT, EM_inv_sum, dlogT, linewidth=2, color='lightblue')
    
    plt.xlim(lgtaxis[0]-0.5*dlogT, lgtaxis.max()+0.5*dlogT)
    plt.xticks(np.arange(np.min(lgtaxis), np.max(lgtaxis), 2*dlogT))
    plt.ylim(1e24, 1e27)
    plt.yscale('log')
    plt.xlabel('log$_{10}$T [K]')
    plt.ylabel('Mean Emission Measure [cm$^{-5}$]')
    plt.title('Basis Pursuit (curve) vs. DeepEM (bars)')
    
    plt.show()
    return EM_inv_sum, EM_tru_sum
In [15]:
# Unscale the Basis Pursuit and DeepEM cubes back to physical units and
# average over the full 512 x 512 field of view (status = 0 everywhere)
em_unscaled = em_unscale(dem_in_test.detach().cpu().numpy())
em_pred_unscaled = em_unscale(dem_pred.detach().cpu().numpy())
status = np.zeros([512,512])

EMinv, EMTru = PlotTotalEM(em_unscaled,em_pred_unscaled,lgtaxis,status)

Figure 4: Average Basis Pursuit DEM (black line) against the Average DeepEM solution (light blue bars/dotted line). It is clear that this simple implementation of DeepEM provides, on average, DEMs that are similar to Basis Pursuit (Cheung et al 2015).


Step 5: Synthesize SDO/AIA Observations

Finally, it is also of interest to reconstruct the SDO/AIA observations from both the Basis Pursuit and DeepEM solutions.

We can pose the problem of reconstructing the SDO/AIA observations from the DEM as a 1x1 2D convolution: we define the weights as the response functions of each channel and set the biases to zero. By convolving the unscaled DEM at each pixel with the six filters (one for each SDO/AIA response function), we recover the SDO/AIA observations.

In [16]:
# We first load the AIA response functions:
cl = np.load('chianti_lines_AIA.npy')
In [17]:
# Use a 1x1 Conv2d to map each DEM pixel (18x1x1) through the 6 response
# functions, returning the synthesized flux in each channel (6x1x1)
dem2aia = nn.Conv2d(18, 6, kernel_size=1).cuda()

chianti_lines_2 = torch.zeros(6,18,1,1).cuda()
biases = torch.zeros(6).cuda()

# set the weights to the SDO/AIA response functions and the biases to zero
for i, p in enumerate(dem2aia.parameters()):
    if i == 0:
        p.data = Variable(torch.from_numpy(cl).type(torch.cuda.FloatTensor))
    else:
        p.data = biases 
In [18]:
AIA_out = img_scale(dem2aia(Variable(em_unscale(dem_in_test))).detach().cpu().numpy())
AIA_out_DeepEM = img_scale(dem2aia(Variable(em_unscale(dem_pred))).detach().cpu().numpy())
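As a quick, hypothetical consistency check (not part of the original analysis), the synthesized images can be compared against the observations in the same square-root scaling, for example for the 171 Å channel:

obs_171 = X_test[0, 2]
syn_171 = AIA_out[0, 2]
rel_diff = np.abs(syn_171 - obs_171) / np.clip(obs_171, 1e-3, None)
print('median relative difference (171 A):', np.median(rel_diff))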

Plotting SDO/AIA Observations and Synthetic Observations

In [19]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(AIA_out[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'Basis Pursuit Synthesized 211 $\AA$', color="white", size='large')
ax[0].imshow(AIA_out[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'Basis Pursuit Synthesized 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'Basis Pursuit Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(AIA_out_DeepEM[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'DeepEM Synthesized 211 $\AA$', color="white", size='large')
ax[0].imshow(AIA_out_DeepEM[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'DeepEM Synthesized 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out_DeepEM[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'DeepEM Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    

Figure 5: Left to Right: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top) with the corresponding synthesised observations from Basis Pursuit (middle) and DeepEM (bottom). DeepEM provides synthetic observations that are similar to Basis Pursuit, with the addition of solutions where the Basis Pursuit solution was zero.


Discussion

This chapter has provided an example of how a 2D Convolutional Neural Network with 1x1 kernels can be used to reduce the computational cost of DEM inversion. Future improvements to DeepEM can come in a few ways.

First, by using both the original and synthesised data, the ability of the DEM to reproduce the original or supplementary observations can be used as an additional term in the loss function, for example (a minimal sketch of such a loss follows the list below):

  • Use SDO/AIA Data to correct the DEMs
  • Use MEGS-A EUV to correct the DEMs
  • Use Hard X-ray observations to correct the DEMs
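A minimal sketch of what such a combined loss could look like, reusing the dem2aia layer from Step 5; the weighting lambda_aia is a hypothetical hyperparameter, and this loss is not part of the model trained above:

def combined_loss(output, dem, img, lambda_aia=0.1):
    # Term 1: match the Basis Pursuit DEM in scaled space, as in Step 3
    dem_term = nn.functional.mse_loss(output, dem)
    # Term 2: the predicted DEM, unscaled and pushed through the response
    # functions, should also reproduce the observed (scaled) AIA intensities
    aia_pred = dem2aia(em_unscale(output))
    aia_pred = torch.sqrt(torch.clamp(aia_pred, min=0.0))  # torch analogue of img_scale
    aia_term = nn.functional.mse_loss(aia_pred, img)
    return dem_term + lambda_aia * aia_term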
In [ ]:
 
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 390.59                 Driver Version: 390.59                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 105...  Off  | 00000000:01:00.0  On |                  N/A |
| 20%   48C    P0    N/A /  75W |   3821MiB /  4039MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+