{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE OverloadedStrings, ScopedTypeVariables, PatternGuards #-}

-- | Haskell indenter.

module HIndent
  (-- * Formatting functions.
   reformat
  ,prettyPrint
  ,parseMode
  -- * Testing
  ,test
  ,testFile
  ,testAst
  ,testFileAst
  ,defaultExtensions
  ,getExtensions
  )
  where

import           Control.Monad.State.Strict
import           Control.Monad.Trans.Maybe
import           Data.ByteString (ByteString)
import qualified Data.ByteString as S
import           Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Internal as S
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Char8 as L8
import qualified Data.ByteString.UTF8 as UTF8
import qualified Data.ByteString.Unsafe as S
import           Data.Char
import           Data.Foldable (foldr')
import           Data.Either
import           Data.Function
import           Data.Functor.Identity
import           Data.List
import           Data.Maybe
import           Data.Monoid
import           Data.Text (Text)
import qualified Data.Text as T
import           Data.Traversable hiding (mapM)
import           HIndent.CodeBlock
import           HIndent.Pretty
import           HIndent.Types
import qualified Language.Haskell.Exts as Exts
import           Language.Haskell.Exts hiding (Style, prettyPrint, Pretty, style, parse)
import           Prelude

-- | Format the given source.
--
-- The input is split with 'cppSplitBlocks'. Shebang and CPP directive blocks
-- are emitted unchanged; every Haskell block is parsed and pretty-printed.
-- The resulting blocks are joined with single newlines.
--
-- * @config@: printer configuration; its 'configExtensions' are added to the
--   extensions used for parsing.
-- * @mexts@: when given, replaces the extension list of 'parseMode'.
-- * @mfilepath@: file name stored in the parse mode; @\"<interactive>\"@ is
--   used when absent.
--
-- Returns 'Left' with a @location: message@ string when a Haskell block
-- fails to parse.
reformat :: Config -> Maybe [Extension] -> Maybe FilePath -> ByteString -> Either String Builder
reformat config mexts mfilepath =
    preserveTrailingNewline
        (fmap (mconcat . intersperse "\n") . mapM processBlock . cppSplitBlocks)
  where
    -- Format a single block; only Haskell source is actually reformatted.
    processBlock :: CodeBlock -> Either String Builder
    processBlock (Shebang text) = Right $ S.byteString text
    processBlock (CPPDirectives text) = Right $ S.byteString text
    processBlock (HaskellSource line text) =
        let ls = S8.lines text
            prefix = findPrefix ls
            code = unlines' (map (stripPrefix prefix) ls)
            source = UTF8.toString code
            -- Extensions (and optionally a base language) declared in the
            -- block itself take precedence, followed by the configured
            -- ones, followed by those of the base parse mode.
            mode'' =
                case readExtensions source of
                    Nothing -> mode'
                    Just (mlang, exts') ->
                        mode'
                        { baseLanguage = fromMaybe (baseLanguage mode') mlang
                        , extensions =
                              exts' ++ configExtensions config ++ extensions mode'
                        }
        in case parseModuleWithComments mode'' source of
               ParseOk (m, comments) ->
                   fmap
                       (S.lazyByteString . addPrefix prefix . S.toLazyByteString)
                       (prettyPrint config m comments)
               ParseFailed loc e ->
                   -- The parser reports lines relative to this block; shift
                   -- them by the block's line offset.
                   Left
                       (Exts.prettyPrint (loc {srcLine = srcLine loc + line}) ++
                        ": " ++ e)

    -- Join strict lines with newlines (no trailing newline is added).
    unlines' :: [ByteString] -> ByteString
    unlines' = S.concat . intersperse "\n"

    -- Join lazy lines with newlines (no trailing newline is added).
    unlines'' :: [L8.ByteString] -> L8.ByteString
    unlines'' = L.concat . intersperse "\n"

    -- Put the stripped prefix back in front of every output line.
    addPrefix :: ByteString -> L8.ByteString -> L8.ByteString
    addPrefix prefix = unlines'' . map (L8.fromStrict prefix <>) . L8.lines

    -- Remove the common prefix from a line; blank lines are left alone.
    stripPrefix :: ByteString -> ByteString -> ByteString
    stripPrefix prefix line
        | S.null (S8.dropWhile (== '\n') line) = line
        | otherwise =
            fromMaybe
                (error "Missing expected prefix")
                (s8_stripPrefix prefix line)

    -- Prefix shared by all non-blank lines, restricted to what 'takePrefix'
    -- accepts.
    findPrefix :: [ByteString] -> ByteString
    findPrefix = takePrefix False . findSmallestPrefix . dropNewlines

    -- Discard blank lines so they do not shrink the common prefix.
    dropNewlines :: [ByteString] -> [ByteString]
    dropNewlines = filter (not . S.null . S8.dropWhile (== '\n'))

    -- Keep leading spaces and tabs plus at most one '>'; the flag records
    -- whether that '>' has already been consumed.
    takePrefix :: Bool -> ByteString -> ByteString
    takePrefix bracketUsed txt =
        case S8.uncons txt of
            Just ('>', txt')
                | not bracketUsed -> S8.cons '>' (takePrefix True txt')
                | otherwise -> ""
            Just (c, txt')
                | c == ' ' || c == '\t' ->
                    S8.cons c (takePrefix bracketUsed txt')
            _ -> ""

    -- Longest prefix common to every line of the list.
    findSmallestPrefix :: [ByteString] -> ByteString
    findSmallestPrefix [] = ""
    findSmallestPrefix (p:ps) =
        case S8.uncons p of
            Nothing -> ""
            Just (first, rest)
                -- 'S.tail' is safe here: the guard proved every line starts
                -- with @first@, hence is non-empty.
                | all (startsWith first) ps ->
                    S8.cons first (findSmallestPrefix (rest : map S.tail ps))
                | otherwise -> ""
      where
        startsWith c x = not (S.null x) && S8.head x == c

    -- Base parse mode: caller-supplied extensions (if any) and file name.
    mode' :: ParseMode
    mode' =
        (maybe parseMode (\exts -> parseMode {extensions = exts}) mexts)
        { parseFilename = fromMaybe "<interactive>" mfilepath }

    -- Whitespace-only input yields empty output. Otherwise a final newline
    -- is guaranteed when the input had one or the config asks for one.
    preserveTrailingNewline
        :: Monad m
        => (ByteString -> m Builder) -> ByteString -> m Builder
    preserveTrailingNewline f x
        | S8.null x || S8.all isSpace x = return mempty
        | hasTrailingLine x || configTrailingNewline config =
            fmap ensureTrailingNewline (f x)
        | otherwise = f x

    -- Append a newline unless the rendered output already ends with one.
    ensureTrailingNewline :: Builder -> Builder
    ensureTrailingNewline out
        | hasTrailingLine (L.toStrict (S.toLazyByteString out)) = out
        | otherwise = out <> "\n"

-- | Does the strict bytestring end with a newline character?
--
-- The empty bytestring has no trailing newline.
hasTrailingLine :: ByteString -> Bool
hasTrailingLine xs = not (S8.null xs) && S8.last xs == '\n'

-- | Print the module.
--
-- Fixities from 'baseFixities' are applied first (the module is used as-is
-- if that fails), the comments are attached to the tree with
-- 'collectAllComments', and the annotated tree is rendered with
-- 'runPrinterStyle'. The result is always 'Right'.
prettyPrint :: Config
            -> Module SrcSpanInfo
            -> [Comment]
            -> Either a Builder
prettyPrint config m comments =
  Right (runPrinterStyle config (pretty ast))
  where
    withFixities = fromMaybe m (applyFixities baseFixities m)
    ast = evalState (collectAllComments withFixities) comments

-- | Run a printer from a fresh 'PrintState' built around the given config
-- and return the output it produced.
--
-- Calls 'error' if the printer itself fails (i.e. ends in @mzero@).
runPrinterStyle :: Config -> Printer () -> Builder
runPrinterStyle config m =
  maybe
    (error "Printer failed with mzero call.")
    psOutput
    (runIdentity (runMaybeT (execStateT (runPrinter m) initialState)))
  where
    initialState =
      PrintState
      { psIndentLevel = 0
      , psOutput = mempty
      , psNewline = False
      , psColumn = 0
      , psLine = 1
      , psConfig = config
      , psInsideCase = False
      , psFitOnOneLine = False
      , psEolComment = False
      }

-- | Parse mode: 'defaultParseMode' with every entry of 'knownExtensions'
-- that is not a 'DisableExtension' switched on, and without any assumed
-- fixities.
parseMode :: ParseMode
parseMode =
  defaultParseMode {extensions = allExtensions
                   ,fixities = Nothing}
  where allExtensions = filter isNotDisabled knownExtensions
        isNotDisabled (DisableExtension _) = False
        isNotDisabled _ = True

-- | Read the given file and run 'test' on its contents.
testFile :: FilePath -> IO ()
testFile fp = S.readFile fp >>= test

-- | Read the given file and print the result of 'testAst' on its contents
-- (either the parse error or the comment-annotated AST).
testFileAst :: FilePath -> IO ()
testFileAst fp = S.readFile fp >>= print . testAst

-- | Reformat the input with 'defaultConfig' (no extra extensions, no file
-- name) and print the result to stdout. Calls 'error' with the message if
-- reformatting fails.
test :: ByteString -> IO ()
test =
  either error (L8.putStrLn . S.toLazyByteString) .
  reformat defaultConfig Nothing Nothing

-- | Parse the source and annotate it with comments, yielding the resulting AST.
--
-- The source is parsed with 'parseMode'. On failure only the parser's
-- message is returned; the source location is dropped.
testAst :: ByteString -> Either String (Module NodeInfo)
testAst x =
  case parseModuleWithComments parseMode (UTF8.toString x) of
    ParseOk (m,comments) ->
      Right
        (evalState
           (collectAllComments
              -- Fall back to the unmodified module if fixities cannot be applied.
              (fromMaybe m (applyFixities baseFixities m)))
           comments)
    ParseFailed _ e -> Left e

-- | Default extensions: every 'EnableExtension' entry of 'knownExtensions'
-- except the ones listed in 'badExtensions'.
defaultExtensions :: [Extension]
defaultExtensions =
  [ e
  | e@EnableExtension {} <- knownExtensions ] \\
  map EnableExtension badExtensions

-- | Extensions which steal too much syntax.
badExtensions :: [KnownExtension]
badExtensions =
    [Arrows -- steals proc
    ,TransformListComp -- steals the group keyword
    ,XmlSyntax, RegularPatterns -- steals a-b
    ,UnboxedTuples -- breaks (#) lens operator
    -- ,QuasiQuotes -- breaks [x| ...], making whitespace free list comps break
    ,PatternSynonyms -- steals the pattern keyword
    ,RecursiveDo -- steals the rec keyword
    ,DoRec -- same
    ,TypeApplications -- since GHC 8 and haskell-src-exts-1.19
    ]


-- | Strip a prefix from a strict bytestring.
--
-- Returns @Just rest@ when the first argument is a prefix of the second,
-- where @rest@ is the second argument without that prefix, and 'Nothing'
-- otherwise.
s8_stripPrefix :: ByteString -> ByteString -> Maybe ByteString
s8_stripPrefix bs1 bs2
   -- 'S.unsafeDrop' is fine here: the prefix check guarantees that @bs2@ is
   -- at least as long as @bs1@.
   | bs1 `S.isPrefixOf` bs2 = Just (S.unsafeDrop (S.length bs1) bs2)
   | otherwise = Nothing

--------------------------------------------------------------------------------
-- Extensions stuff stolen from hlint

-- | Consume an extensions list from arguments.
--
-- Starts from 'defaultExtensions' and processes the names left to right:
--
-- * @\"Haskell98\"@ discards everything accumulated so far;
-- * @\"NoFoo\"@ removes @Foo@ when @Foo@ is a readable extension name;
-- * any other readable name is moved to the front of the list;
-- * anything else calls 'error' with @\"Unknown extension: ...\"@.
getExtensions :: [Text] -> [Extension]
getExtensions = foldl f defaultExtensions . map T.unpack
  -- NOTE: this is deliberately the lazy 'foldl'. The "Haskell98" case
  -- ignores its accumulator, so a strict fold would additionally force (and
  -- possibly fail on) names that precede a later "Haskell98".
  where f _ "Haskell98" = []
        f a ('N':'o':x)
          | Just x' <- readExtension x = delete x' a
        -- Also reached for "No..." names whose remainder is not an extension.
        f a x
          | Just x' <- readExtension x = x' : delete x' a
        f _ x = error $ "Unknown extension: " ++ x

--------------------------------------------------------------------------------
-- Comments

-- | Apply an action to every element of a structure, visiting the elements
-- in the order given by the comparison function rather than in traversal
-- order.
--
-- Elements that compare equal are visited in their original traversal
-- order. Every result is put back at the position of the element it was
-- computed from, so the shape of the structure is preserved.
traverseInOrder
  :: (Monad m, Traversable t, Functor m)
  => (b -> b -> Ordering) -> (b -> m b) -> t b -> m (t b)
traverseInOrder cmp f ast = do
  results <- mapM visit (sortBy (cmp `on` snd) indexed)
  return (fmap (resultAt results) positions)
  where
    -- A single pass numbers every position and collects the numbered
    -- elements (in reverse, because they are consed onto the accumulator).
    ((_, revIndexed), positions) = mapAccumL number (0 :: Integer, []) ast
    indexed = reverse revIndexed
    number (i, acc) x = ((i + 1, (i, x) : acc), i)
    -- Run the action on one element, remembering where it came from.
    visit (i, x) = do
      v <- f x
      return (i, v)
    -- Every position was visited above, so the lookup cannot fail.
    resultAt results i = fromMaybe (error "traverseInOrder") (lookup i results)

-- | Collect all comments in the module by traversing the tree. Read
-- this from bottom to top.
collectAllComments :: Module SrcSpanInfo -> State [Comment] (Module NodeInfo)
collectAllComments :: Module SrcSpanInfo -> State [Comment] (Module NodeInfo)
collectAllComments =
  (Module NodeInfo -> State [Comment] (Module NodeInfo))
-> Module NodeInfo -> State [Comment] (Module NodeInfo)
forall {m :: * -> *} {t :: * -> *} {a} {t}.
(MonadState (t a) m, Foldable t) =>
(t -> m t) -> t -> m t
shortCircuit
    ((NodeInfo -> StateT [Comment] Identity NodeInfo)
-> Module NodeInfo -> State [Comment] (Module NodeInfo)
traverseBackwards
     -- Finally, collect backwards comments which come after each node.
       ((SrcSpan -> SomeComment -> NodeComment)
-> (SrcSpan -> SrcSpan -> Bool)
-> NodeInfo
-> StateT [Comment] Identity NodeInfo
collectCommentsBy
          SrcSpan -> SomeComment -> NodeComment
CommentAfterLine
          (\SrcSpan
nodeSpan SrcSpan
commentSpan ->
              (Int, Int) -> Int
forall a b. (a, b) -> a
fst (SrcSpan -> (Int, Int)
srcSpanStart SrcSpan
commentSpan) Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= (Int, Int) -> Int
forall a b. (a, b) -> a
fst (SrcSpan -> (Int, Int)
srcSpanEnd SrcSpan
nodeSpan)))) (Module NodeInfo -> State [Comment] (Module NodeInfo))
-> (Module SrcSpanInfo -> State [Comment] (Module NodeInfo))
-> Module SrcSpanInfo
-> State [Comment] (Module NodeInfo)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=<
  (Module NodeInfo -> State [Comment] (Module NodeInfo))
-> Module NodeInfo -> State [Comment] (Module NodeInfo)
forall {m :: * -> *} {t :: * -> *} {a} {t}.
(MonadState (t a) m, Foldable t) =>
(t -> m t) -> t -> m t
shortCircuit Module NodeInfo -> State [Comment] (Module NodeInfo)
addCommentsToTopLevelWhereClauses (Module NodeInfo -> State [Comment] (Module NodeInfo))
-> (Module SrcSpanInfo -> State [Comment] (Module NodeInfo))
-> Module SrcSpanInfo
-> State [Comment] (Module NodeInfo)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=<
  (Module NodeInfo -> State [Comment] (Module NodeInfo))
-> Module NodeInfo -> State [Comment] (Module NodeInfo)
forall {m :: * -> *} {t :: * -> *} {a} {t}.
(MonadState (t a) m, Foldable t) =>
(t -> m t) -> t -> m t
shortCircuit
    ((NodeInfo -> StateT [Comment] Identity NodeInfo)
-> Module NodeInfo -> State [Comment] (Module NodeInfo)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse
     -- Collect forwards comments which start at the end line of a
     -- node: Does the start line of the comment match the end-line
     -- of the node?
       ((SrcSpan -> SomeComment -> NodeComment)
-> (SrcSpan -> SrcSpan -> Bool)
-> NodeInfo
-> StateT [Comment] Identity NodeInfo
collectCommentsBy
          SrcSpan -> SomeComment -> NodeComment
CommentSameLine
          (\SrcSpan
nodeSpan SrcSpan
commentSpan ->
              (Int, Int) -> Int
forall a b. (a, b) -> a
fst (SrcSpan -> (Int, Int)
srcSpanStart SrcSpan
commentSpan) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== (Int, Int) -> Int
forall a b. (a, b) -> a
fst (SrcSpan -> (Int, Int)
srcSpanEnd SrcSpan
nodeSpan)))) (Module NodeInfo -> State [Comment] (Module NodeInfo))
-> (Module SrcSpanInfo -> State [Comment] (Module NodeInfo))
-> Module SrcSpanInfo
-> State [Comment] (Module NodeInfo)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=<
  (Module NodeInfo -> State [Comment] (Module NodeInfo))
-> Module NodeInfo -> State [Comment] (Module NodeInfo)
forall {m :: * -> *} {t :: * -> *} {a} {t}.
(MonadState (t a) m, Foldable t) =>
(t -> m t) -> t -> m t
shortCircuit
    ((NodeInfo -> StateT [Comment] Identity NodeInfo)
-> Module NodeInfo -> State [Comment] (Module NodeInfo)
traverseBackwards
     -- Collect backwards comments which are on the same line as a
     -- node: Does the start line & end line of the comment match
     -- that of the node?
       ((SrcSpan -> SomeComment -> NodeComment)
-> (SrcSpan -> SrcSpan -> Bool)
-> NodeInfo
-> StateT [Comment] Identity NodeInfo
collectCommentsBy
          SrcSpan -> SomeComment -> NodeComment
CommentSameLine
          (\SrcSpan
nodeSpan SrcSpan
commentSpan ->
              (Int, Int) -> Int
forall a b. (a, b) -> a
fst (SrcSpan -> (Int, Int)
srcSpanStart SrcSpan
commentSpan) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== (Int, Int) -> Int
forall a b. (a, b) -> a
fst (SrcSpan -> (Int, Int)
srcSpanStart SrcSpan
nodeSpan) Bool -> Bool -> Bool
&&
              (Int, Int) -> Int
forall a b. (a, b) -> a
fst (SrcSpan -> (Int, Int)
srcSpanStart SrcSpan
commentSpan) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== (Int, Int) -> Int
forall a b. (a, b) -> a
fst (SrcSpan -> (Int, Int)
srcSpanEnd SrcSpan
nodeSpan)))) (Module NodeInfo -> State [Comment] (Module NodeInfo))
-> (Module SrcSpanInfo -> State [Comment] (Module NodeInfo))
-> Module SrcSpanInfo
-> State [Comment] (Module NodeInfo)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=<
  (Module NodeInfo -> State [Comment] (Module NodeInfo))
-> Module NodeInfo -> State [Comment] (Module NodeInfo)
forall {m :: * -> *} {t :: * -> *} {a} {t}.
(MonadState (t a) m, Foldable t) =>
(t -> m t) -> t -> m t
shortCircuit
    ((NodeInfo -> StateT [Comment] Identity NodeInfo)
-> Module NodeInfo -> State [Comment] (Module NodeInfo)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse
     -- First, collect forwards comments for declarations which both
     -- start on column 1 and occur before the declaration.
       ((SrcSpan -> SomeComment -> NodeComment)
-> (SrcSpan -> SrcSpan -> Bool)
-> NodeInfo
-> StateT [Comment] Identity NodeInfo
collectCommentsBy
          SrcSpan -> SomeComment -> NodeComment
CommentBeforeLine
          (\SrcSpan
nodeSpan SrcSpan
commentSpan ->
              ((Int, Int) -> Int
forall a b. (a, b) -> b
snd (SrcSpan -> (Int, Int)
srcSpanStart SrcSpan
nodeSpan) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 Bool -> Bool -> Bool
&&
               (Int, Int) -> Int
forall a b. (a, b) -> b
snd (SrcSpan -> (Int, Int)
srcSpanStart SrcSpan
commentSpan) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1) Bool -> Bool -> Bool
&&
              (Int, Int) -> Int
forall a b. (a, b) -> a
fst (SrcSpan -> (Int, Int)
srcSpanStart SrcSpan
commentSpan) Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< (Int, Int) -> Int
forall a b. (a, b) -> a
fst (SrcSpan -> (Int, Int)
srcSpanStart SrcSpan
nodeSpan)))) (Module NodeInfo -> State [Comment] (Module NodeInfo))
-> (Module SrcSpanInfo -> Module NodeInfo)
-> Module SrcSpanInfo
-> State [Comment] (Module NodeInfo)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.
  (SrcSpanInfo -> NodeInfo) -> Module SrcSpanInfo -> Module NodeInfo
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SrcSpanInfo -> NodeInfo
nodify
  where
    -- Wrap a bare source span into a 'NodeInfo' that carries no
    -- comments yet ('mempty' is the empty comment list).
    nodify :: SrcSpanInfo -> NodeInfo
    nodify s = NodeInfo s mempty
    -- Visit the nodes via 'traverseInOrder', using a comparator that
    -- orders them by descending end position (the node ending last
    -- compares smallest).
    traverseBackwards ::
         (NodeInfo -> State [Comment] NodeInfo)
      -> Module NodeInfo
      -> State [Comment] (Module NodeInfo)
    traverseBackwards =
      traverseInOrder
        (flip compare `on` (srcSpanEnd . srcInfoSpan . nodeInfoSpan))
    -- Stop traversing if all comments have been consumed: when the
    -- pool of comments held in the state is empty, the value is
    -- returned untouched and the given pass is not run at all.
    shortCircuit m v = do
      comments <- get
      if null comments
        then return v
        else m v

-- | Collect comments by satisfying the given predicate. To collect a
-- comment means to remove it from the pool of available comments in
-- the State and attach it to the node. This allows for a multiple
-- pass approach: each pass only sees comments that earlier passes
-- left behind.
--
-- The first argument builds the 'NodeComment' for each collected
-- comment. The predicate receives the node's span first and the
-- comment's span second; comments for which it holds are attached to
-- the node (in their original order), all others stay in the State.
collectCommentsBy
  :: (SrcSpan -> SomeComment -> NodeComment)
  -> (SrcSpan -> SrcSpan -> Bool)
  -> NodeInfo
  -> State [Comment] NodeInfo
collectCommentsBy cons predicate nodeInfo@(NodeInfo (SrcSpanInfo nodeSpan _) _) = do
  comments <- get
  -- 'partition' keeps the relative order of both halves.
  let (mine, others) = partition belongsToNode comments
  put others
  return $ addCommentsToNode cons mine nodeInfo
  where
    belongsToNode (Comment _ commentSpan _) = predicate nodeSpan commentSpan

-- | Reintroduce comments which were immediately above declarations in where clauses.
-- Affects where clauses of top level declarations only.
--
-- Only top-level pattern bindings that carry a @where@ block of plain
-- declarations are inspected, and within such a block only pattern
-- bindings of a single identifier receive comments. Every comment that
-- gets attached is removed from the pool of comments in the State.
addCommentsToTopLevelWhereClauses ::
     Module NodeInfo -> State [Comment] (Module NodeInfo)
addCommentsToTopLevelWhereClauses (Module info mhead pragmas imports topLevelDecls) =
  Module info mhead pragmas imports <$>
  traverse addCommentsToWhereClauses topLevelDecls
  where
    -- Process the where-declarations of a top-level pattern binding;
    -- any other declaration is returned unchanged.
    addCommentsToWhereClauses ::
         Decl NodeInfo -> State [Comment] (Decl NodeInfo)
    addCommentsToWhereClauses (PatBind bindInfo pat rhs (Just (BDecls declsInfo whereDecls))) = do
      newWhereDecls <- traverse addCommentsToPatBind whereDecls
      return $ PatBind bindInfo pat rhs (Just (BDecls declsInfo newWhereDecls))
    addCommentsToWhereClauses other = return other
    -- Attach the comments sitting directly above a binding of the form
    -- @name = ...@ to that binding's node; other declarations are
    -- returned unchanged.
    addCommentsToPatBind :: Decl NodeInfo -> State [Comment] (Decl NodeInfo)
    addCommentsToPatBind (PatBind bindInfo (PVar varInfo (Ident declNodeInfo declString)) rhs binds) = do
      bindInfoWithComments <- addCommentsBeforeNode bindInfo
      return $
        PatBind
          bindInfoWithComments
          (PVar varInfo (Ident declNodeInfo declString))
          rhs
          binds
    addCommentsToPatBind other = return other
    -- Move the comments located directly above the node out of the
    -- State and onto the node as 'CommentBeforeLine' comments.
    addCommentsBeforeNode :: NodeInfo -> State [Comment] NodeInfo
    addCommentsBeforeNode nodeInfo = do
      comments <- get
      let (notAbove, above) = partitionAboveNotAbove comments nodeInfo
      put notAbove
      return $ addCommentsToNode CommentBeforeLine above nodeInfo
    -- Split the comments into (not above, above) relative to the node.
    -- The fold runs from the last comment to the first and tracks the
    -- span a comment has to sit on top of: initially the node itself,
    -- then each comment that was accepted. This way a whole stack of
    -- consecutive comment lines above the node is picked up. Both
    -- result lists keep the original order.
    partitionAboveNotAbove :: [Comment] -> NodeInfo -> ([Comment], [Comment])
    partitionAboveNotAbove cs (NodeInfo (SrcSpanInfo nodeSpan _) _) =
      fst $
      foldr'
        (\comment@(Comment _ commentSpan _) ((ls, rs), lastSpan) ->
           if comment `isAbove` lastSpan
             then ((ls, comment : rs), commentSpan)
             else ((comment : ls, rs), lastSpan))
        (([], []), nodeSpan)
        cs
    -- A comment is above a span when it starts in the same column as
    -- the span and ends on the line right before the span's first line.
    isAbove :: Comment -> SrcSpan -> Bool
    isAbove (Comment _ commentSpan _) targetSpan =
      let (_, commentColStart) = srcSpanStart commentSpan
          (commentLnEnd, _) = srcSpanEnd commentSpan
          (lnStart, colStart) = srcSpanStart targetSpan
       in commentColStart == colStart && commentLnEnd + 1 == lnStart
addCommentsToTopLevelWhereClauses other = return other

-- | Append the given comments to the comments already stored on a node.
--
-- Each 'Comment' is converted with the supplied constructor function,
-- which receives the comment's span and its text wrapped as
-- 'MultiLine' or 'EndOfLine' depending on the comment's multi-line
-- flag. Existing comments stay in front; new ones keep their order.
addCommentsToNode :: (SrcSpan -> SomeComment -> NodeComment)
                  -> [Comment]
                  -> NodeInfo
                  -> NodeInfo
addCommentsToNode mkNodeComment newComments nodeInfo@(NodeInfo (SrcSpanInfo _ _) existingComments) =
  nodeInfo
    {nodeInfoComments = existingComments <> map toNodeComment newComments}
  where
    -- Convert one parsed comment into the node-level representation.
    toNodeComment :: Comment -> NodeComment
    toNodeComment (Comment multiLine commentSpan commentString) =
      mkNodeComment commentSpan (wrap commentString)
      where
        wrap
          | multiLine = MultiLine
          | otherwise = EndOfLine