Skip to content

Instantly share code, notes, and snippets.

@lubomir
Created January 10, 2014 13:41
Show Gist options
  • Save lubomir/5d84c6da5b53c7d95cde to your computer and use it in GitHub Desktop.
Save lubomir/5d84c6da5b53c7d95cde to your computer and use it in GitHub Desktop.
Part 3 of the final assignment for the course "Complex-Valued Neural Networks with Multi-Valued Neurons"
module Iris where
import Data.Vector (Vector)
import Control.Arrow
import qualified Data.Vector as V
import Data.Complex
-- |Training samples. Each row holds four real inputs, which are turned into
-- points on the unit circle with 'cis', paired with a class label (0, 1 or 2).
learning :: Vector (Vector (Complex Double), Int)
learning = V.fromList (map toSample rows)
  where
    toSample (phases, label) = (V.fromList (map cis phases), label)
    rows =
      [ ([0.31, 0.873, 0.095, 0.058], 0)
      , ([0.233, 0.582, 0.095, 0.058], 0)
      , ([0.155, 0.698, 0.071, 0.058], 0)
      , ([0.116, 0.64, 0.118, 0.058], 0)
      , ([0.271, 0.931, 0.095, 0.058], 0)
      , ([0.427, 1.105, 0.166, 0.175], 0)
      , ([0.116, 0.814, 0.095, 0.116], 0)
      , ([0.271, 0.814, 0.118, 0.058], 0)
      , ([0.039, 0.524, 0.095, 0.058], 0)
      , ([0.233, 0.64, 0.118, 0], 0)
      , ([0.427, 0.989, 0.118, 0.058], 0)
      , ([0.194, 0.814, 0.142, 0.058], 0)
      , ([0.194, 0.582, 0.095, 0], 0)
      , ([0, 0.582, 0.024, 0], 0)
      , ([0.582, 1.164, 0.047, 0.058], 0)
      , ([0.543, 1.396, 0.118, 0.175], 0)
      , ([0.427, 1.105, 0.071, 0.175], 0)
      , ([0.31, 0.873, 0.095, 0.116], 0)
      , ([0.543, 1.047, 0.166, 0.116], 0)
      , ([0.31, 1.047, 0.118, 0.116], 0)
      , ([0.427, 0.814, 0.166, 0.058], 0)
      , ([0.31, 0.989, 0.118, 0.175], 0)
      , ([0.116, 0.931, 0, 0.058], 0)
      , ([0.31, 0.756, 0.166, 0.233], 0)
      , ([0.194, 0.814, 0.213, 0.058], 0)
      , ([0.271, 0.582, 0.142, 0.058], 0)
      , ([0.271, 0.814, 0.142, 0.175], 0)
      , ([0.349, 0.873, 0.118, 0.058], 0)
      , ([0.349, 0.814, 0.095, 0.058], 0)
      , ([0.155, 0.698, 0.142, 0.058], 0)
      , ([0.621, 0.698, 0.899, 0.989], 1)
      , ([0.698, 0.465, 0.71, 0.698], 1)
      , ([0.776, 0.291, 0.923, 0.814], 1)
      , ([0.698, 0.465, 0.876, 0.64], 1)
      , ([0.814, 0.524, 0.781, 0.698], 1)
      , ([0.892, 0.582, 0.805, 0.756], 1)
      , ([0.97, 0.465, 0.899, 0.756], 1)
      , ([0.931, 0.582, 0.947, 0.931], 1)
      , ([0.659, 0.524, 0.828, 0.814], 1)
      , ([0.543, 0.349, 0.592, 0.524], 1)
      , ([0.465, 0.233, 0.663, 0.582], 1)
      , ([0.465, 0.233, 0.639, 0.524], 1)
      , ([0.582, 0.407, 0.686, 0.64], 1)
      , ([0.659, 0.407, 0.97, 0.873], 1)
      , ([0.427, 0.582, 0.828, 0.814], 1)
      , ([0.659, 0.814, 0.828, 0.873], 1)
      , ([0.931, 0.64, 0.876, 0.814], 1)
      , ([0.776, 0.175, 0.805, 0.698], 1)
      , ([0.504, 0.582, 0.734, 0.698], 1)
      , ([0.465, 0.291, 0.71, 0.698], 1)
      , ([0.465, 0.349, 0.805, 0.64], 1)
      , ([0.698, 0.582, 0.852, 0.756], 1)
      , ([0.582, 0.349, 0.71, 0.64], 1)
      , ([0.271, 0.175, 0.544, 0.524], 1)
      , ([0.504, 0.407, 0.757, 0.698], 1)
      , ([0.543, 0.582, 0.757, 0.64], 1)
      , ([0.543, 0.524, 0.757, 0.698], 1)
      , ([0.737, 0.524, 0.781, 0.698], 1)
      , ([0.31, 0.291, 0.473, 0.582], 1)
      , ([0.543, 0.465, 0.734, 0.698], 1)
      , ([0.776, 0.756, 1.183, 1.396], 2)
      , ([0.582, 0.407, 0.97, 1.047], 2)
      , ([1.086, 0.582, 1.16, 1.164], 2)
      , ([0.776, 0.524, 1.089, 0.989], 2)
      , ([0.853, 0.582, 1.136, 1.222], 2)
      , ([1.28, 0.582, 1.325, 1.164], 2)
      , ([0.233, 0.291, 0.828, 0.931], 2)
      , ([1.164, 0.524, 1.254, 0.989], 2)
      , ([0.931, 0.291, 1.136, 0.989], 2)
      , ([1.125, 0.931, 1.207, 1.396], 2)
      , ([0.853, 0.698, 0.97, 1.105], 2)
      , ([0.814, 0.407, 1.018, 1.047], 2)
      , ([0.97, 0.582, 1.065, 1.164], 2)
      , ([0.543, 0.291, 0.947, 1.105], 2)
      , ([0.582, 0.465, 0.97, 1.338], 2)
      , ([0.814, 0.698, 1.018, 1.28], 2)
      , ([0.853, 0.582, 1.065, 0.989], 2)
      , ([1.319, 1.047, 1.349, 1.222], 2)
      , ([1.319, 0.349, 1.396, 1.28], 2)
      , ([0.659, 0.116, 0.947, 0.814], 2)
      , ([1.008, 0.698, 1.112, 1.28], 2)
      , ([0.504, 0.465, 0.923, 1.105], 2)
      , ([1.319, 0.465, 1.349, 1.105], 2)
      , ([0.776, 0.407, 0.923, 0.989], 2)
      , ([0.931, 0.756, 1.112, 1.164], 2)
      , ([1.125, 0.698, 1.183, 0.989], 2)
      , ([0.737, 0.465, 0.899, 0.989], 2)
      , ([0.698, 0.582, 0.923, 0.989], 2)
      , ([0.814, 0.465, 1.089, 1.164], 2)
      , ([1.125, 0.582, 1.136, 0.873], 2)
      ]
-- |Held-out samples used to measure accuracy. The format is the same as the
-- training table: four real inputs mapped onto the unit circle with 'cis',
-- plus a class label (0, 1 or 2).
testing :: Vector (Vector (Complex Double), Int)
testing = V.fromList (map toSample rows)
  where
    toSample (phases, label) = (V.fromList (map cis phases), label)
    rows =
      [ ([0.194, 0.64, 0.142, 0.058], 0)
      , ([0.427, 0.814, 0.118, 0.175], 0)
      , ([0.349, 1.222, 0.118, 0], 0)
      , ([0.465, 1.28, 0.095, 0.058], 0)
      , ([0.233, 0.64, 0.118, 0.058], 0)
      , ([0.271, 0.698, 0.047, 0.058], 0)
      , ([0.465, 0.873, 0.071, 0.058], 0)
      , ([0.233, 0.931, 0.095, 0], 0)
      , ([0.039, 0.582, 0.071, 0.058], 0)
      , ([0.31, 0.814, 0.118, 0.058], 0)
      , ([0.271, 0.873, 0.071, 0.116], 0)
      , ([0.078, 0.175, 0.071, 0.116], 0)
      , ([0.039, 0.698, 0.071, 0.058], 0)
      , ([0.271, 0.873, 0.142, 0.291], 0)
      , ([0.31, 1.047, 0.213, 0.175], 0)
      , ([0.194, 0.582, 0.095, 0.116], 0)
      , ([0.31, 1.047, 0.142, 0.058], 0)
      , ([0.116, 0.698, 0.095, 0.058], 0)
      , ([0.388, 0.989, 0.118, 0.058], 0)
      , ([0.271, 0.756, 0.095, 0.058], 0)
      , ([1.047, 0.698, 0.876, 0.756], 1)
      , ([0.814, 0.698, 0.828, 0.814], 1)
      , ([1.008, 0.64, 0.923, 0.814], 1)
      , ([0.465, 0.175, 0.71, 0.698], 1)
      , ([0.853, 0.465, 0.852, 0.814], 1)
      , ([0.543, 0.465, 0.828, 0.698], 1)
      , ([0.776, 0.756, 0.876, 0.873], 1)
      , ([0.233, 0.233, 0.544, 0.524], 1)
      , ([0.892, 0.524, 0.852, 0.698], 1)
      , ([0.349, 0.407, 0.686, 0.756], 1)
      , ([0.271, 0, 0.592, 0.524], 1)
      , ([0.621, 0.582, 0.757, 0.814], 1)
      , ([0.659, 0.116, 0.71, 0.524], 1)
      , ([0.698, 0.524, 0.876, 0.756], 1)
      , ([0.504, 0.524, 0.615, 0.698], 1)
      , ([0.931, 0.64, 0.805, 0.756], 1)
      , ([0.504, 0.582, 0.828, 0.814], 1)
      , ([0.582, 0.407, 0.734, 0.524], 1)
      , ([0.737, 0.116, 0.828, 0.814], 1)
      , ([0.504, 0.291, 0.686, 0.582], 1)
      , ([1.202, 0.465, 1.207, 1.047], 2)
      , ([1.396, 1.047, 1.278, 1.105], 2)
      , ([0.814, 0.465, 1.089, 1.222], 2)
      , ([0.776, 0.465, 0.97, 0.814], 2)
      , ([0.698, 0.349, 1.089, 0.756], 2)
      , ([1.319, 0.582, 1.207, 1.28], 2)
      , ([0.776, 0.814, 1.089, 1.338], 2)
      , ([0.814, 0.64, 1.065, 0.989], 2)
      , ([0.659, 0.582, 0.899, 0.989], 2)
      , ([1.008, 0.64, 1.041, 1.164], 2)
      , ([0.931, 0.64, 1.089, 1.338], 2)
      , ([1.008, 0.64, 0.97, 1.28], 2)
      , ([0.582, 0.407, 0.97, 1.047], 2)
      , ([0.97, 0.698, 1.16, 1.28], 2)
      , ([0.931, 0.756, 1.112, 1.396], 2)
      , ([0.931, 0.582, 0.994, 1.28], 2)
      , ([0.776, 0.291, 0.947, 1.047], 2)
      , ([0.853, 0.582, 0.994, 1.105], 2)
      , ([0.737, 0.814, 1.041, 1.28], 2)
      , ([0.621, 0.582, 0.97, 0.989], 2)
      ]
import Control.Arrow
import Control.Monad (zipWithM_)
import Data.Complex
import Data.List (minimumBy)
import Data.Maybe (isNothing)
import Data.Vector (Vector)
import qualified Data.Vector as V
import System.Random
import Text.Printf
import Iris
-- |How many equal sectors the unit circle is divided into for classification.
type NumSectors = Int

-- |Zero-based index of a sector (the first sector is number 0).
type Sector = Int

-- |A sample: a vector of input values paired with the expected output.
type LearningSample input output = (Vector input, output)
-- |Settings of the periodic multi-valued neuron classifier.
data Options = Options
  { numSectors   :: Int
    -- ^Number of distinct class labels (sector labels run 0 .. numSectors-1).
  , periodicity  :: Int
    -- ^How many times each label repeats around the circle; the circle is
    -- split into @numSectors * periodicity@ sectors in total.
  , learningRate :: Complex Double
    -- ^Learning-rate coefficient (main sets it to 1).
    -- NOTE(review): confirm the training loop actually applies it.
  }
-- |A multi-valued neuron with periodic activation, represented only by its
-- weight vector. The first weight is the bias weight: 'getWeightedSum' pairs
-- it with a constant input of 1.
data MVNP w = MVNP (Vector w)
  deriving Show
-- |Orphan 'Random' instance for complex numbers. The real part is drawn
-- first, then the imaginary part, each independently; for 'randomR' each
-- part uses the corresponding component of the two range endpoints.
instance (Random a, RealFloat a) => Random (Complex a) where
  randomR (lo, hi) gen0 =
    let (re, gen1) = randomR (realPart lo, realPart hi) gen0
        (im, gen2) = randomR (imagPart lo, imagPart hi) gen1
    in (re :+ im, gen2)
  random gen0 =
    let (re, gen1) = random gen0
        (im, gen2) = random gen1
    in (re :+ im, gen2)
-- |Periodic activation: split the circle into @numSectors * periodicity@
-- equal sectors, find the one containing the point, and reduce its index
-- modulo @numSectors@. Labels therefore repeat 0, 1, .., numSectors-1
-- around the circle @periodicity@ times.
--
periodicActivation :: (RealFloat a) => Options -> Complex a -> Sector
periodicActivation opts s = fineSector `mod` numSectors opts
  where
    fineSector = basicActivation (numSectors opts * periodicity opts) s
-- |Given a (positive) number of sectors and a point in the complex plane,
-- find the sector the point belongs to. Sector @k@ spans the phases
-- [k * 2π/n, (k+1) * 2π/n), so the result is always in [0, n).
--
basicActivation :: (RealFloat a)
                => NumSectors -- ^Number of sectors
                -> Complex a  -- ^Argument
                -> Sector     -- ^Index of result sector
-- The floating-point quotient can round up to exactly n for points just
-- below the positive real axis; wrapping with `mod` maps that case back to
-- sector 0 instead of returning an out-of-range index.
basicActivation n s = floor (posPhase s / sectBound n 1) `mod` n
-- |Feed an input vector (without the bias input) through the neuron and
-- return the sector label chosen by the periodic activation.
--
runNeuron :: (RealFloat w)
          => Options
          -> MVNP (Complex w)
          -> Vector (Complex w)
          -> Sector
runNeuron opts neuron = periodicActivation opts . getWeightedSum neuron
-- |Compute the phase of a complex number normalised to the half-open range
-- [0, 2π). By contrast, 'phase' returns values in (-π, π].
--
posPhase :: RealFloat a => Complex a -> a
posPhase x
  | phi >= 0          = phi
  -- Adding 2π to a tiny negative phase can round to exactly 2π. That angle
  -- is the same direction as phase 0, so report 0 and keep the result
  -- inside [0, 2π).
  | wrapped >= 2 * pi = 0
  | otherwise         = wrapped
  where
    phi     = phase x
    wrapped = phi + 2 * pi
-- |For the given options, a desired class @r@ and a point @x@ in the complex
-- plane, pick the sector labelled @r@ whose boundary is nearest to @x@ along
-- the circle, and return the phase of that sector's lower bound.
--
-- Distances are circular arc lengths in [0, π] between 'posPhase' of the
-- point and each sector boundary. This correctly compares bounds on both
-- sides of the positive real axis. (Previously bounds were folded with
-- @clamp pi@, which subtracts π instead of 2π and mis-measured distances
-- for every bound above π.)
--
-- Calls 'error' if no sector carries label @r@ (i.e. @r@ is out of range).
--
closestSector :: (RealFloat a, Ord a) => Options -> Sector -> Complex a -> a
closestSector Options{numSectors = s, periodicity = l} r x
  | null candidates = error "closestSector: desired sector is out of range"
  | otherwise       = sectBound total (minimumBy cmp candidates)
  where
    total = s * l
    -- Every sector that carries label r in the periodic layout.
    candidates = [r, r + s .. total - 1]
    target = posPhase x
    -- Shortest way around the circle between a bound and the point. Both
    -- angles lie in [0, 2π], so the raw difference never exceeds 2π.
    arc bound = let d = abs (bound - target) in min d (2 * pi - d)
    distance m = min (arc (sectBound total m))
                     (arc (sectBound total ((m + 1) `mod` total)))
    cmp m n = compare (distance m) (distance n)
-- |Subtract @m@ from @n@ repeatedly until the result is no longer greater
-- than @m@. A value that is already @<= m@ (or NaN) is returned unchanged.
-- Only the upper end is bounded. The loop never terminates when @m <= 0@
-- and @n > m@, or when @n@ is +Infinity.
--
clamp :: (Num a, Ord a) => a -> a -> a
clamp m = until (not . (> m)) (subtract m)
-- |Apply one error-correction step to the neuron for the given input.
-- Each weight w_j becomes w_j + e * conj(x_j) / (k + 1). Here x_0 = 1 is the
-- bias input and k is the length of the input vector. The step size is
-- fixed at 1. The error is taken as given and not checked for correctness.
--
updateWeights :: (RealFloat a)
              => MVNP (Complex a)
              -> Vector (Complex a) -- ^Input sample
              -> Complex a          -- ^Error on this sample
              -> MVNP (Complex a)   -- ^Updated neuron
updateWeights (MVNP weights) input err = MVNP (V.zipWith step weights biased)
  where
    biased = V.cons 1 input
    rate = 1
    size = fromIntegral (V.length input + 1)
    step w x = w + rate * err * conjugate x / size
-- |Phase at which sector @x@ starts when the circle is split into @n@ equal
-- sectors, i.e. @x * 2π / n@.
--
sectBound :: (Floating a, Ord a) => NumSectors -> Sector -> a
sectBound n x = start / count
  where
    start = fromIntegral x * 2 * pi
    count = fromIntegral n
-- |Weighted sum the neuron computes for an input vector. A constant bias
-- input of 1 is prepended automatically, so the input should be one element
-- shorter than the weight vector. Any surplus elements on either side are
-- silently dropped by the pairwise product.
--
getWeightedSum :: (Num w)
               => MVNP w   -- ^Actual neuron
               -> Vector w -- ^Input without the leading 1
               -> w        -- ^Result
getWeightedSum (MVNP weights) input = V.sum products
  where
    products = V.zipWith (*) weights (V.cons 1 input)
-- |Print each weight on its own line as @w_<index> = <re> + <im> i@.
printWeights :: Vector (Complex Double) -> IO ()
printWeights ws = mapM_ line (zip [0 :: Int ..] (V.toList ws))
  where
    line :: (Int, Complex Double) -> IO ()
    line (i, w) = printf "w_%d = %f + %f i\n" i (realPart w) (imagPart w)
-- |Check whether the neuron classifies a sample correctly. Returns 'Nothing'
-- if it does. Otherwise it returns the learning error: the unit vector at
-- the phase chosen by 'closestSector' for the desired class, minus the unit
-- vector at the lower bound of the sector the weighted sum currently falls
-- into.
--
testSample :: (RealFloat a, Show a)
           => Options
           -> MVNP (Complex a)
           -> LearningSample (Complex a) Sector
           -> Maybe (Complex a) -- ^Maybe error
testSample opts n (i, r)
  | runNeuron opts n i == r = Nothing
  | otherwise               = Just (desired - current)
  where
    z       = getWeightedSum n i
    total   = numSectors opts * periodicity opts
    desired = cis (closestSector opts r z)
    current = cis (sectBound total (basicActivation total z))
-- |Train the neuron with the error-correction rule, sweeping over the
-- samples until one complete pass needs no correction. Returns the trained
-- neuron and the number of passes performed, including that final clean pass.
--
-- Every error reported by 'testSample' is multiplied by the 'learningRate'
-- option before being handed to 'updateWeights'. Previously the option was
-- silently ignored; the result is unchanged for a rate of 1.
--
-- NOTE: there is no pass limit. If the samples cannot be learned, this never
-- returns.
runErrorCorrection :: (RealFloat a, Show a)
                   => Options
                   -> MVNP (Complex a)
                   -> Vector (LearningSample (Complex a) Sector)
                   -> (MVNP (Complex a), Int)
runErrorCorrection opts neuron0 samples = go True 1 neuron0 samples
  where
    -- Options stores the rate as Complex Double; convert it to the neuron's
    -- number type.
    lr = learningRate opts
    rate = realToFrac (realPart lr) :+ realToFrac (imagPart lr)
    -- clean stays True while no sample in the current pass needed a fix.
    go clean epoch neuron rest
      | V.null rest =
          if clean then (neuron, epoch) else go True (epoch + 1) neuron samples
      | otherwise =
          let (inp, expected) = V.head rest
              remaining = V.tail rest
          in case testSample opts neuron (inp, expected) of
               Nothing  -> go clean epoch neuron remaining
               Just err ->
                 go False epoch (updateWeights neuron inp (rate * err)) remaining
-- |Train a 3-class, periodicity-3 neuron from random initial weights on the
-- learning set, then report the weights, the number of passes, and the
-- accuracy on the testing set.
main :: IO ()
main = do
  let opts = Options { numSectors = 3, periodicity = 3, learningRate = 1 }
  initial <- V.replicateM 5 (randomRIO ((-0.5) :+ (-0.5), 0.5 :+ 0.5))
  putStrLn "initial weights: "
  printWeights initial
  let (trained, iters) = runErrorCorrection opts (MVNP initial) learning
      MVNP ws = trained
  putStrLn "final weights: "
  printWeights ws
  let classifiedOk = isNothing . testSample opts trained
      correct = V.length (V.filter classifiedOk testing)
      total = V.length testing
      accuracy = fromIntegral correct / fromIntegral total * 100 :: Double
  putStrLn ("Done after " ++ show iters ++ " iterations")
  putStrLn ("Correct " ++ show correct ++ " (" ++ show accuracy ++ " %)")
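{- Run 50 times and average the results with:

$ for i in $(seq 1 50); do ./part3 | tail -n1; done | awk '{s+=$2;n+=1}END{print s/n}'
57.8

The last output line has the form "Correct N (P %)", so awk's $2 is the
count N of correctly classified testing samples (out of 60), not the
percentage: 57.8 correct on average is roughly 96 % accuracy.
-}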
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment