//
// Copyright (C) 1999, 2000, Marco Kesseler
//

implementation module jpeg

//**************************************************************************
//
// This code was initially (loosely) based on the code by Jeroen Fokker 
// (Department of Computer Science, Utrecht, August 7, 1995), but it has
// undergone quite some changes since. It has been tested on JPEG images
// inside PDF files (RGB, CMYK and greyscale). I do not know how well this
// code works for images from other sources.
//
// The output is an uncompressed image in 'interleaved' or 'chunky' format,
// returned as a list of characters (which is the way uncompressed images are
// stored in PDF files). The JPEG decompression algorithm supports all allowed
// JPEG precisions, but the output function only supports 8-bit images. Those
// who require other output formats probably need to change the 'segments2...image'
// function, and perhaps the 'imageToCharList' function.
//
// Big differences between this code and that of Jeroen are:
//
// - This code is more efficient:
//   - it uses a Huffman decoder that does not read codes one
//	   bit at a time, but takes a number of bits in one go.
//   - it uses arrays at a few places (albeit at as few places
//     as possible: lists are not that bad).
//	 - it tries to avoid append functions, in favour of
//	   continuations.
//	 - it uses a fast discrete cosine transform (which matters a lot!)
//	 - Intermediate bitmaps with 3 or 4 components use tuples to represent
//	   their pixels instead of lists, which avoids some node building.
//
// - This code does NOT use a monadic style. I do not like carrying things
//	 around implicitly, as I tend to forget what these things are (I know, the
//	 type system doesn't, but I am not a type system).
//
//	- This code complies better with the JPEG specifications. Actually,
//	  the following is mandatory to implement a basic decoder:
//	  - It supports multiple Huffman table specifications per DHT
//		segment, instead of just one.
//	  - It supports restart markers.
//    - It supports greyscale images (i.e. frames with a single component)
//
//	- This code incorporates some error checking, at the cost of being
//	  a bit less compact. At the moment we simply 'abort' if we find an
//	  error. This is perhaps not always appropriate.
//
//	- This code supports an Adobe extension for images that use an
//	  Ycc or YccK 'colourspace'. Use of such a colourspace tends to offer
//	  better compression, and it can be regularly found in images in
//	  PDF files.
//
//************************************************************************

import StdClass, StdBool, StdChar, StdInt, StdFunc
import StdList, StdEnum, StdArray, StdMisc, StdString

import scanbin, huffman, bitstream, basic, dct, matrix
import colourTransform, imagePlanes

//*******************************************************************************
//
//	Main DCTDecode function
//
//*******************************************************************************

// Decodes a complete JPEG byte stream into uncompressed image bytes.
DCTDecode :: [!Char] -> [!Char]
DCTDecode bytes
	= scanSegments markerSegments emptyState
where
	// First split the stream at its markers, then interpret the segments in
	// order, starting from a state without tables or frame information.
	markerSegments = segment bytes

//
// Leave exactly one of the following three configurations uncommented to pick the DCT transformation
//
/*
IDCT64ScaleFactors :== IDCT64SlowScaleFactors
IDCT64 :== IDCT64Slow
:: QuantisationTable :== [[Int]]
*/
/*
IDCT64ScaleFactors :== IDCT64RealScaleFactors
IDCT64 :== IDCT64Real
:: QuantisationTable :== [[Real]]
*/

// Active configuration: the Int variants of the IDCT (IDCT64IntScaleFactors /
// IDCT64Int), with quantisation tables stored as matrices of Int.
IDCT64ScaleFactors :== IDCT64IntScaleFactors
IDCT64 :== IDCT64Int
:: QuantisationTable :== [[Int]]

//
// The following definitions can be used to set the appropriate dictionary
// implementation.
//
/*
import dictionary
:: Dict key value :== BTree key value
newDict :== newBTree
*/
import simpledict
// Active dictionary implementation: SimpleDict from module simpledict.
:: Dict key value :== SimpleDict key value
newDict :== emptySimpleDict

//*******************************************************************************
//
//	Constants
//
//*******************************************************************************

// Marker codes: the byte that follows a '\xFF' byte at the start of a segment.
SOF0	:== '\xC0'	// Start Of Frame (Baseline DCT)

DHT		:== '\xC4'	// Define Huffman Table(s)
SOI		:== '\xD8'	// Start Of Image
EOI		:== '\xD9'	// End Of Image
SOS		:== '\xDA'	// Start Of Scan
DQT		:== '\xDB'	// Define Quantisation Table(s)
DRI		:== '\xDD'	// Define Restart Interval
COM		:== '\xFE'	// Comment

// Application segments (APP0 and APPF bound the range of application markers)

APP0	:== '\xE0'
APPE	:== '\xEE'
APPF	:== '\xEF'

// Restart Markers (RST0 up to and including RST7)

RST0	:== '\xD0'
RST7	:== '\xD7'

// Adobe transform values, as found in the transform byte of an Adobe APPE segment

AdobeNoTransform	:== 0
AdobeYcc2RGB		:== 1
AdobeYccK2CMYK		:== 2

//*******************************************************************************
//
//	Basic Types
//
//*******************************************************************************

// A marker code paired with the bytes that follow it, up to the next marker.
:: Segment	:== (Char, [Char])

// A pair of integers: used for (x, y) image dimensions and (h, v) sampling factors.
:: Dim :== (!Int, !Int)
// DC Huffman symbols are the bit length of the DC difference; AC symbols are
// (number of preceding zeroes, bit length of the value).
:: DCTable :== HuffTree Int
:: ACTable :== HuffTree (Int, Int)

//*******************************************************************************
//
//	Segment Types
//
//*******************************************************************************

// Frame parameters, as read from a SOF0 segment.
:: Frame =
	{
		samplePrecision		:: Int,							// bits per sample
		dimensions			:: Dim,							// (width, height) of the image
		maxSamplingFactors	:: Dim,							// maximum (h, v) over all components
		nComponents			:: Int,
		adobeTransform		:: Int,							// from an Adobe APPE segment; 0 if none seen
		imageComponents		:: Dict !Int ImageComponent		// keyed by component id
	}
	
// Per-component frame parameters.
:: ImageComponent =
	{
		samplingFactors	:: Dim,		// (h, v) sampling factors
		qTabId			:: Int		// id of the quantisation table to use
	}

// Scan header parameters, as read from an SOS segment. The spectral selection
// and successive approximation values are stored but not consulted by decodeScan.
:: Scan =
	{
		scanComponents	:: [ScanComponent],
		startSpectral	:: Int,
		endSpectral		:: Int,
		appoxHigh		:: Int,		// sic: misspelling of 'approxHigh' (also used in scanScan)
		approxLow		:: Int
	}

// Per-component scan parameters: which image component, and which Huffman tables.
:: ScanComponent =
	{
		imageComponentId	:: Int,
		dcTableId			:: Int,
		acTableId			:: Int
	}
	
// Everything collected from the header segments read so far.
:: ScanState =
	{
		quantisationTables	:: Dict Int QuantisationTable,
		dcTables			:: Dict Int DCTable,
		acTables			:: Dict Int ACTable,
		frame				:: Frame,
		restartInterval		:: Int,			// 0: no restart markers
		comments			:: [String]		// COM segment texts, most recent first
	}

// The state before any header segment has been read: no tables, an empty
// frame, no restart interval and no comments.
emptyState :: ScanState
emptyState =
	{
		restartInterval		= 0,				// Zero means: no restart markers
		comments			= [],
		quantisationTables	= newDict,
		dcTables			= newDict,
		acTables			= newDict,
		frame =
			{
				imageComponents		= newDict,		// Set of component descriptions
				nComponents			= 0,			// Number of components
				samplePrecision		= 0,			// Bits per sample
				dimensions			= (0,0),		// Size of entire image
				maxSamplingFactors	= (0,0),		// Maximum of all components
				adobeTransform		= 0				// Zero means: no transformation needed
			}
	}
	
//***************************************************************************
//
// Segmenting
//
// The JPEG format divides its data into segments by means of markers. All
// markers start with a '\xFF' byte and are followed by a byte that is not
// '\x00' or '\xFF'. Any marker may optionally be preceded by any number
// of '\xFF' fill bytes. The following function divides the input into a
// number of segments, consisting of a marker code and the bytes of the segment
//
//***************************************************************************

// Splits a byte list into (marker code, segment bytes) pairs. Bytes before the
// first marker are dropped, and every segment's byte list is followed by an
// endless run of '\xFF' bytes (see 'terminator' below).
segment :: [Char] -> [Segment]
segment cs = segs
where
	(_, segs) = segment cs	// bytes that precede first mark are ignored
	
	// The local 'segment' returns the bytes of the segment currently being read,
	// paired with all segments that follow it. After a '\xFF' byte, 'marker'
	// distinguishes three cases:
	//  - '\x00': a stuffed byte, so the '\xFF' is ordinary data;
	//  - '\xFF': a fill byte, which is kept as data of the current segment
	//    (the commented-out alternative would drop it instead);
	//  - anything else: the current segment ends and a new one starts with
	//    that marker code.
	segment [c : cs]
	| c <> '\xFF'
		= ([c : seg], segs)
		with
			(seg, segs) = segment cs
		= marker cs
		with
			marker [c : cs]	
			| c == '\x00'
				= (['\xFF' : seg], segs)
				with (seg, segs) = segment cs
			| c == '\xFF'
				= (['\xFF' : seg], segs)
				with (seg, segs) = marker cs
				//= marker cs
				= (terminator, [(c, seg) : segs])
				with (seg, segs) = segment cs
			marker _ = abort "marker lacks second byte"
	segment []
		= (terminator, [])
	
	// We end all segments by an infinite sequence of 1-bits. These act as a terminator
	// for the Huffman decoders. Normally we should not encounter these
	// codes, because we only read as many codes as needed, but if we do we can signal
	// an error.
	
	terminator = repeat '\xFF'

//*************************************************************************
//
// Scanning of Application Markers (Currently only Adobe)
//
//*************************************************************************

// Parses the payload of an APPE segment, which must be an Adobe segment:
// the string "Adobe", a 16-bit version, 16-bit flags0 and flags1 and an 8-bit
// colour transform code. The transform code is stored in the frame, then the
// continuation is called with the updated state. Unexpected values abort.
scanAppE :: ScanState [Char] (ScanState -> [Char]) -> [Char]
scanAppE state bytes continue
	# (bytes, identifier) 					= takeString (size "Adobe") bytes
	| identifier <> "Adobe"					= abort "unknown APPE marker (not an Adobe marker)"
	# (bytes, version)						= takeInt16 bytes
	# (bytes, flags0)						= takeInt16 bytes
	# (bytes, flags1)						= takeInt16 bytes
	# (bytes, transform)					= takeInt8 bytes
	| version < 100							= abort "Unsupported Adobe revision number"
	| (flags0 bitand (bitnot 0x8000)) <> 0	= abort "Unknown Adobe flags0"	// only bit 15 may be set
	| flags1 <> 0							= abort "Unknown Adobe flags1"
	| transform > 2							= abort ("Unknown Adobe colour transform code: " +++ toString transform)
											= continue {state & frame.adobeTransform = transform}
	       
//*************************************************************************
//
// Scanning
//
//*************************************************************************
	
//
// The quantisation factors are stored in zigzag order, so normally dequantisation
// will be done before de-zigzagging each block. We do not do this, but we
// perform dequantisation inside the inverse DCT transform for reasons of
// efficiency. Therefore we de-zigzag the quantisation factors here as well.

// Reads all quantisation table specifications of a DQT segment.
// 'remain' is the number of payload bytes still to be read. Each
// specification is one byte (precision nibble, table id nibble) followed by
// 64 entries: 8-bit entries for precision 0 (65 bytes in total), 16-bit entries
// for precision 1 (129 bytes in total). Each table is de-zigzagged, passed
// through IDCT64ScaleFactors and stored under its id, replacing any earlier
// table with the same id.
scanQuantisationTables :: ScanState Int [Char] (ScanState -> [Char]) -> [Char]
scanQuantisationTables state remain bytes continue
	| remain > 0
		# (bytes, (precision, i)) = take2Int4 bytes
		| precision == 0
			# (bytes, table)	= takeN 64 takeInt8 bytes
			# table				= IDCT64ScaleFactors (zigzag table)
			# state				= {state & quantisationTables = insertReplace state.quantisationTables (i, table)}
			= scanQuantisationTables state (remain - 65) bytes continue
		| precision == 1
			# (bytes, table)	= takeN 64 takeInt16 bytes
			# table				= IDCT64ScaleFactors (zigzag table)
			# state				= {state & quantisationTables = insertReplace state.quantisationTables (i, table)}
			= scanQuantisationTables state (remain - 129) bytes continue
			= abort "scanQuantisationTable: unknown precision"
		= continue state

// Reads a SOF0 frame header and replaces the frame in the state with it.
// The Adobe colour transform recorded so far is carried over into the new frame.
scanFrame :: ScanState [Char] (ScanState -> [Char]) -> [Char]
scanFrame state bytes continue
	= continue {state & frame = newFrame}
where
	// Header fields in stream order: precision, height, width, component count,
	// followed by one three-byte description per component.
	(rest1, precision)				= takeInt8 bytes
	(rest2, height)					= takeInt16 rest1
	(rest3, width)					= takeInt16 rest2
	(rest4, componentCount)			= takeInt8 rest3
	((_, maxH, maxV), components)	= takeN componentCount readComponent (rest4, 0, 0)

	newFrame =
		{
			samplePrecision		= precision,
			dimensions			= (width, height),
			maxSamplingFactors	= (maxH, maxV),
			nComponents			= componentCount,
			adobeTransform		= state.frame.adobeTransform,
			imageComponents		= insertReplaceList newDict components
		}

	// Reads one component description (id, sampling factors, table id) while
	// keeping track of the largest horizontal and vertical sampling factors.
	readComponent (input, highestH, highestV)
		# (input, compId)	= takeInt8 input
		# (input, (h, v))	= take2Int4 input
		# (input, tableId)	= takeInt8 input
		= ((input, max highestH h, max highestV v), (compId, {samplingFactors = (h, v), qTabId = tableId}))

// Reads all Huffman table specifications of a DHT segment. 'remain' is the
// number of payload bytes still to be read. Each specification is one byte
// (class nibble, table id nibble), 16 code counts, and then as many values as
// the counts add up to. Class 0 (DC) values are single bytes; class 1 (AC)
// values are read as two nibbles (run of zeroes, value length). The resulting
// tree replaces any earlier table of the same class and id.
scanHuffmanTable :: ScanState Int [Char] (ScanState -> [Char]) -> [Char]
scanHuffmanTable state remain bytes continue
	| remain > 0
		# (bytes, (tClass, i))	= take2Int4 bytes
		# (bytes, counts)		= takeN 16 takeInt8 bytes
		# nValues				= sum counts
		# remain				= remain - 17 - nValues		// 1 class/id byte + 16 counts + values
		| tClass == 0
			# (bytes, values)	= takeN nValues takeInt8 bytes
			# dcTree			= huffTreeFromCountsAndValues counts values
			# state				= {state & dcTables = insertReplace state.dcTables (i, dcTree)}
			= scanHuffmanTable  state remain bytes continue
		| tClass == 1
			# (bytes, values)	= takeN nValues take2Int4 bytes
			# acTree			= huffTreeFromCountsAndValues counts values
			# state				= {state & acTables = insertReplace state.acTables (i, acTree)}
			= scanHuffmanTable state remain bytes continue
			= abort "unknown class"
		= continue state

// Reads the 16-bit restart interval of a DRI segment into the state.
scanRestartInterval :: ScanState [Char] (ScanState -> [Char]) -> [Char]
scanRestartInterval state bytes continue
	= continue updatedState
where
	// A value of 0 switches restart markers off again.
	(_, interval)	= takeInt16 bytes
	updatedState	= {state & restartInterval = interval}

// Reads an SOS scan header. Returns the header and a bit stream over the
// entropy-coded data that follows it in the same segment.
scanScan :: [Char] -> (Scan, LEBitStream)
scanScan bytes
	= (header, toBitStream afterHeader)
where
	// Header fields in stream order: component count, component descriptions,
	// spectral start, spectral end, approximation nibbles.
	(rest1, componentCount)			= takeInt8 bytes
	(rest2, components)				= takeN componentCount readScanComponent rest1
	(rest3, spectralStart)			= takeInt8 rest2
	(rest4, spectralEnd)			= takeInt8 rest3
	(afterHeader, (high, low))		= take2Int4 rest4

	header =
		{
			scanComponents	= components,
			startSpectral	= spectralStart,
			endSpectral		= spectralEnd,
			appoxHigh		= high,
			approxLow		= low
		}

	// One component description: component id, then DC/AC table id nibbles.
	readScanComponent input
		# (input, componentId)		= takeInt8 input
		# (input, (dcId, acId))		= take2Int4 input
		= (input, {imageComponentId = componentId, dcTableId = dcId, acTableId = acId})

// Records the text of a COM segment in the state (newest comment first).
scanComment :: ScanState [Char] (ScanState -> [Char]) -> [Char]
scanComment state comment continue
	= continue withComment
where
	withComment = {state & comments = [toString comment : state.comments]}

// Interprets the segments one by one, threading the ScanState through them.
// Table and frame segments update the state; an SOS segment starts decoding a
// scan (together with the data of any RSTn segments that directly follow it),
// after which scanning resumes with the remaining segments and the same state.
// Unrecognised application segments are skipped; other unknown markers abort.
scanSegments :: [Segment] ScanState -> [Char]
scanSegments [(mark, segment) : segments] state
	= case mark of
		SOF0	-> scanFrame state payload (scanSegments segments)
		DHT		-> scanHuffmanTable state (segmentLength - 2) payload (scanSegments segments)
		DQT		-> scanQuantisationTables state (segmentLength - 2) payload (scanSegments segments)
		SOS		-> decode state (scanScan payload) (takeDataSegments segments)
		DRI		-> scanRestartInterval state payload (scanSegments segments)
		// 'payload' runs on into the endless terminator that 'segment' appends,
		// so the comment text must be cut off at the declared segment length.
		COM		-> scanComment state (take (segmentLength - 2) payload) (scanSegments segments)
		SOI 	-> scanSegments segments emptyState
		EOI 	-> scanSegments segments emptyState
		
		// Application segments
		
		APPE	-> scanAppE state payload (scanSegments segments)
		
		// Other application segments, or unsupported segments
		
		_  		-> otherSegment mark
			
	where
		// A segment starts with its 16-bit length, which counts the two
		// length bytes themselves.
		(payload, segmentLength) = takeInt16 segment
		
		// Collects the data of the RSTn segments that follow a scan.
		// A stream that ends right after the scan yields no further data.
		takeDataSegments :: [Segment] -> ([LEBitStream], [Segment])
		takeDataSegments [(mark, segment) : segments]
		| mark >= RST0 && mark <= RST7
			= ([toBitStream segment : data], segments`)
			with (data, segments`) = takeDataSegments segments
			= ([], [(mark, segment) : segments])
		takeDataSegments []
			= ([], [])
		
		decode state (scan, data1) (dataN, segments)
			= decodeScan scan state [data1 : dataN] (scanSegments segments state)
		
		// JPEG decoders may ignore APPn segments they do not recognise (e.g. APP0),
		// so those are skipped. Any other marker is not supported.
		otherSegment m
		| m >= APP0 && m <= APPF
			= scanSegments segments state
			= abort ("unsupported segment: " +++ (toString [c1, c2]))
			with (c1, c2) = charToHex m

scanSegments [] state
	= []

//*************************************************************************
//
// MCU Decoder
//
// The JPEG standard does not allow more than 10 blocks/MCU.
// Adobe's code sometimes does not adhere to that. We do not care: this
// code does not contain any limits in this respect.
//
// Likewise we support an unlimited number of components instead of just
// the 4 that the JPEG standard prescribes.
//
//*************************************************************************

// Decodes one scan and hands the resulting image, together with 'continue',
// to imageToCharList (presumably appending 'continue' after the image bytes).
//
// 'segments' holds one bit stream per restart interval: the data following the
// SOS header, then the data of each RSTn segment. A frame with one component
// is decoded as a greyscale image in 8x8 units. Otherwise each MCU covers
// (8 * maxH) x (8 * maxV) pixels and contains the units of all scan components,
// decoded in the order of 'scan.scanComponents'.
//
// NOTE(review): each scan is turned into a complete image on its own, so this
// presumes a single interleaved scan per frame.
decodeScan :: Scan ScanState [LEBitStream] [Char] -> [Char]
decodeScan scan state segments continue
| frame.nComponents == 1
	= imageToCharList image xSize ySize continue
	with
		image			= segments2greyImage precision image`		
		(_, image`)		= takeNxN xMCUCount yMCUCount takeListElement mcus
		mcus 			= takeMCUs (xMCUCount * yMCUCount) state.restartInterval mcuDecoder segstate
		mcuDecoder		= componentDecoder True state (hd scan.scanComponents)	// only the first scan component
		xMCUCount		= xSize ceilDiv 8
		yMCUCount		= ySize ceilDiv 8
		segstate		= map (\s -> (s, 0)) segments	// each bit stream starts with DC predictor 0
		
	= imageToCharList image xSize ySize continue
	with
		image			= segments2colourImage precision frame image`
		(_, image`)		= takeNxN xMCUCount yMCUCount takeListElement mcus
		mcus 			= takeMCUs (xMCUCount * yMCUCount) state.restartInterval mcuDecoder segstate		
		mcuDecoder s	= takeSeq (map (componentDecoder False state)  scan.scanComponents) s
		xMCUCount		= xSize ceilDiv (8 * maxH)
		yMCUCount		= ySize ceilDiv (8 * maxV)
		segstate		= map (\s -> (s, repeat 0)) segments	// DC predictors (presumably one per component), all 0
	
where
	frame			= state.frame
	precision		= frame.samplePrecision
	(maxH, maxV)	= frame.maxSamplingFactors
	(xSize, ySize)	= frame.dimensions
	
	// The following function delivers a number of MCUs, taking into account
	// the proper restart interval. Without restart markers, all MCUs come
	// from the first bit stream.
	
	takeMCUs n restart mcuDecoder segments
	| restart == 0
		= mcus
		with
			(_, mcus) = takeN n mcuDecoder (hd segments)
		= takeNfromSegments n restart mcuDecoder segments

	// The following function delivers, for one scan component, a decoder for
	// that component's part of a 'minimum coded unit' (MCU).
	//
	// The h and v values are the horizontal and vertical sampling factors.
	// For colour images (single = False) they indicate how many units of this
	// component in both directions are present in the MCU.  For greyscale images
	// (single = True) each mcu consists of a single unit, regardless of
	// the sampling factors of the component
	//
	// The decoder value is a decoder for one unit. It is built from two functions:
	// the entropy decoder consumes a number of bits from the stream, and builds a
	// list of 64 integers. The unitDecoder takes this list, dequantises it, performs
	// the inverse DCT and upsamples the result by (maxH/h, maxV/v).
	//

	componentDecoder single state compDef input
	| single
		= decoder input
		= (input`, matFlatten component)
		with (input`, component) = takeNxN h v decoder input
	where
		decoder				= entropyDecoder dcTable acTable (unitDecoder precision upsampleValues quantisationTable)
		imageComponent		= get frame.imageComponents compDef.imageComponentId
		dcTable				= get state.dcTables compDef.dcTableId
		acTable				= get state.acTables compDef.acTableId
		quantisationTable	= get state.quantisationTables imageComponent.qTabId
		(h,v)				= imageComponent.samplingFactors
		(maxH, maxV)		= frame.maxSamplingFactors
		upsampleValues		= (maxH/h, maxV/v)


//*************************************************************************
//
// Unit Decoder
//
//*************************************************************************

// Turns the 64 entropy-decoded values of one unit into pixel rows: the values
// are de-zigzagged, dequantised and inverse-transformed by IDCT64, and the
// result is upsampled by the factors in 'u'.
unitDecoder :: Int Dim QuantisationTable [Int] -> [[Int]]
unitDecoder precision u quaTab block
	= upsamp u pixels
where
	coefficients	= zigzag block
	pixels			= IDCT64 coefficients quaTab (SignedRangeLimitFunction precision)

//*************************************************************************
//
// ac and dc decoders (6 seems to be a good lookahead for the Huffman decoders)
//
//*************************************************************************

// Decodes one unit from the bit stream. The state is (bit stream, DC value of
// the previous unit of this component). One DC coefficient and up to 63 AC
// coefficients are decoded; the resulting 64 values are given to 'process', and
// the result is returned along with the new stream position and DC value.
//
// huffmanDecoder is given a tree, a lookahead of 6 bits and a pair of
// continuations. The first continuation receives the number of bits of the
// matched code ('nbits') and the decoded symbol; it skips those bits itself
// (see skipBits below). NOTE(review): decodeUndefined is presumably called
// for bit patterns that match no code — confirm in module huffman.
entropyDecoder :: (HuffTree Int) (HuffTree (Int, Int)) ([Int] -> [[Int]]) (LEBitStream, Int) -> ((LEBitStream, Int), [[Int]])
entropyDecoder dcHuffTree acHuffTree process (s, dc) = dcDecoder s dc
where
	dcDecoder = huffmanDecoder dcHuffTree 6 (dcDecode acDecoder process, decodeUndefined)
	acDecoder = huffmanDecoder acHuffTree 6 (acDecode acDecoder, decodeUndefined)

	//
	// dcDecode takes 'valueLength' bits from its input stream, sign-extends
	// it and adds it to the current dc value (which is passed in as state).
	// Then it calls the argument acDecoder to deliver the rest
	//

	dcDecode acDecoder process nbits valueLength
		= getValue
	where
		getValue stream dc
			= ((stream``, dc`), process [dc` : acValues])
			where
				(value, stream`)		= getBits (skipBits stream nbits) valueLength
				(stream``, acValues)	= acDecoder stream` 0
				delta					= extend value valueLength
				dc`						= dc + delta

	// acDecode handles one AC symbol. Its state is the number of AC
	// coefficients produced so far. Symbol (0,0) ends the block: the remaining
	// coefficients up to 63 are zero. Any other symbol (nzeroes, valueLength)
	// produces nzeroes zeroes followed by one sign-extended value, and decoding
	// continues until 63 coefficients have been produced.

	acDecode acDecoder nbits (0,0)
		= completeWithZeroes
	where
		completeWithZeroes stream nProduced
			= (skipBits stream nbits, copy (63 - nProduced) 0 [])
	acDecode acDecoder nbits (nzeroes, valueLength)
		= getValue
	where
		getValue stream nProduced
			= (stream``, copy nzeroes 0 [ac : acValues])
		where
			(value, stream`)		= getBits (skipBits stream nbits) valueLength
			(stream``, acValues)	= if (nProduced` < 63)
										(acDecoder stream` nProduced`)
										(stream`, [])
			ac						= extend value valueLength
			nProduced`				= nProduced + nzeroes + 1

//*************************************************************************
//
// Get the decoded (planar) parts into a chunky image format
//
//*************************************************************************

// Builds a greyscale image from the matrix of decoded units.
segments2greyImage :: Int (Matrix (Matrix Int)) -> Image
segments2greyImage precision units
	= Grey precision pixels
where
	pixels = matFlatten units

// Interleaves the decoded component planes into chunky pixels. Three- and
// four-component images get tuple pixels and, when the Adobe transform asks
// for it, a Ycc->RGB or YccK->CMYK conversion; any other component count
// uses the generic representation.
segments2colourImage :: Int Frame (Matrix [Matrix Int]) -> Image
segments2colourImage precision frame planarSegments
	| componentCount == 3
		= Chunky3xN precision (if (transform == AdobeYcc2RGB) (matMap (Ycc2RGBFunction precision) pixels3) pixels3)
	| componentCount == 4
		= Chunky4xN precision (if (transform == AdobeYccK2CMYK) (matMap (YccK2CMYKFunction precision) pixels4) pixels4)
	| otherwise
		= Chunky precision (matFlatten (matMap matZip planarSegments))
where
	componentCount	= frame.nComponents
	transform		= frame.adobeTransform
	pixels3			= matFlatten (matMap matZip3 planarSegments)
	pixels4			= matFlatten (matMap matZip4 planarSegments)
	
//*********************************************************************************
//
// Zig-zagging
//
// We use a transposed version of the zigzag table, so that we can avoid a
// transpose during dct decoding.
//
//*********************************************************************************

// The zigzag index (0..63) of each position of an 8x8 block, stored transposed.
// 'zigzag' uses every entry as an index into a zigzag-ordered list.
zigzagTable =: 	transpose
					[[ 0, 1, 5, 6,14,15,27,28]
					,[ 2, 4, 7,13,16,26,29,42]
					,[ 3, 8,12,17,25,30,41,43]
					,[ 9,11,18,24,31,40,44,53]
					,[10,19,23,32,39,45,52,54]
					,[20,22,33,38,46,51,55,60]
					,[21,34,37,47,50,56,59,61]
					,[35,36,48,49,57,58,62,63]
					]

// Rearranges a zigzag-ordered list of 64 values into a (transposed) 8x8 matrix.
zigzag :: [a] -> [[a]]
zigzag xs = matMap pick zigzagTable
where
	// An array allows constant-time lookup of each zigzag index.
	pick i = values.[i]
	values = {x \\ x <- xs}

//*********************************************************************************
//
// Some Small Stuff
//
//*********************************************************************************

// Upsamples a matrix by repeating every element 'hFactor' times within its row
// and every row 'vFactor' times.
upsamp :: Dim [[a]] -> [[a]]
upsamp (1,1) rows	= rows
upsamp (hFactor, vFactor) rows
	= stretch vFactor [stretch hFactor row \\ row <- rows]
where
	// Repeats each element of a list n times, keeping the order.
	stretch n xs = foldr (copy n) [] xs

// copy n x rest: n copies of x, followed by rest.
// A non-positive n yields just 'rest'. Testing n == 0 alone would make a
// negative count recurse without end, building an endless list of x.
copy :: Int a [a] -> [a]	
copy n x rest
| n <= 0
	= rest
	= [x : copy (n - 1) x rest]

//
// extend sign-extends a value of 'length' bits
//

//
// extend sign-extends a value of 'length' bits: values whose top bit is set
// are returned unchanged, smaller ones map to the negative range
// value - (2^length - 1). A length of 0 always gives 0.
//

extend :: !Int !Int -> Int
extend value length
| length == 0			= 0
| value < threshold		= value - ((1 << length) - 1)
| otherwise				= value
where
	threshold = 1 << (length - 1)

// Integer division rounded up (for a non-negative numerator and positive divisor).
(ceilDiv) infix 7 :: Int Int -> Int
(ceilDiv) numerator divisor
	= roundedUp / divisor
where
	roundedUp = numerator + divisor - 1
