import os
from picamera2 import Picamera2
from libcamera import Transform
import cv2
from libcamera import controls
import numpy as np
import time
import pprint
from typing import List, Any
from collections import deque


# Quieten libcamera's own log output ("3" presumably means ERROR and above).
# NOTE(review): this is assigned after picamera2/libcamera are imported;
# confirm libcamera still picks it up at that point.
os.environ["LIBCAMERA_LOG_LEVELS"] = "3"

# Capture resolution shared by both cameras.
w = 640
h = 480

# Image transform per camera, keyed by the index passed to make_camera.
# hflip=1: for Lab item , vflip=1: for metal
_TRANSFORMS = {
    0: Transform(vflip=1),
    1: Transform(),
}

# Manual exposure settings, currently disabled.
# ~ EXPOSURE = 15000
# ~ AN_GAIN = 3


def linear_contrast_stretch(image):
    """Linearly rescale a single-channel image so its values span 0..255.

    The darkest pixel maps to 0 and the brightest to 255. Values in between
    are scaled proportionally, clipped to [0, 255] and truncated to uint8.

    Args:
        image: 2-D numpy array (one channel).

    Returns:
        A uint8 array of the same shape. If the image is empty or every pixel
        has the same value (nothing to stretch), ``image`` itself is returned
        unchanged, keeping its original dtype.

    Raises:
        ValueError: if ``image`` is not 2-D.
    """
    if image.ndim != 2:
        raise ValueError("this function expect a single channel (2D) image")

    if image.size == 0:
        # min()/max() raise on zero-size arrays; there is nothing to stretch.
        return image

    min_val, max_val = image.min(), image.max()

    if max_val > min_val:
        # Work in Python floats: subtracting numpy uint8 scalars would wrap
        # around instead of producing the true range.
        lo = float(min_val)
        scale = 255.0 / (float(max_val) - lo)
        stretched = (image.astype(np.float32) - lo) * scale
        return np.clip(stretched, 0, 255).astype(np.uint8)
    return image
    


def make_camera(index):
    """Open and configure the Picamera2 camera with number ``index``.

    The camera gets a ``w`` x ``h`` YUV420 "main" stream and the transform
    ``_TRANSFORMS[index]``, so an index without an entry raises KeyError. Its
    controls request a fixed frame duration of 16667 (1/60 s in
    microseconds), FrameRate 60 and Saturation 0. After configuring, the
    function waits two seconds, then prints the configuration the camera
    reports.

    Args:
        index: Camera number passed to ``Picamera2``.

    Returns:
        The configured Picamera2 object. It is not started here.
    """
    camera = Picamera2(index)

    frame_controls = dict(
        FrameDurationLimits=(16667, 16667),
        FrameRate=60,
        Saturation=0,
        # ~ ExposureTime=EXPOSURE,
        # ~ AnalogueGain=AN_GAIN,
    )
    main_stream = dict(size=(w, h), format="YUV420")

    preview_config = camera.create_preview_configuration(
        main=main_stream,
        transform=_TRANSFORMS[index],
        controls=frame_controls,
    )
    camera.configure(preview_config)

    # Fixed two-second pause after configuring.
    time.sleep(2.0)

    reported = camera.camera_configuration()
    print("camera actual controls:")
    pprint.pprint(reported)

    return camera

def grab_sync_pair(picams: List[Any], frame_duration: float):
    """Return one request from each of two cameras with matching timestamps.

    A request is pulled from ``picams[0]`` and from ``picams[1]``. A request
    whose timestamp is more than half a frame earlier than the other one is
    released and replaced by the next request from the same camera. This
    repeats until the two lie within half a frame of each other.

    Args:
        picams: Objects providing ``capture_request()``. Only indexes 0 and 1
            are used.
        frame_duration: Frame period, in the same units as the converted
            timestamps (``SensorTimestamp / 1000``, i.e. microseconds).

    Returns:
        ``(req0, req1)``. These are not released here; the caller owns them.

    NOTE(review): loops forever if the two streams never come within half a
    frame of each other.
    """
    def sensor_time_us(request):
        # Dividing by 1000 converts SensorTimestamp to microseconds.
        return request.get_metadata()['SensorTimestamp'] / 1000

    tolerance = frame_duration / 2
    pending = [picams[0].capture_request(), picams[1].capture_request()]

    while True:
        t0 = sensor_time_us(pending[0])
        t1 = sensor_time_us(pending[1])
        if t1 > t0 + tolerance:
            early = 0  # camera 0's frame is too early; its next one should fit
        elif t0 > t1 + tolerance:
            early = 1  # camera 1's frame is too early
        else:
            return pending[0], pending[1]
        pending[early].release()
        pending[early] = picams[early].capture_request()


def check_sync(duration=float("inf"), threshold=30.0, window=20):
    """Measure how closely camera 0 and camera 1 are frame-synchronised.

    Both cameras are opened with make_camera and started. Matched request
    pairs are then pulled with grab_sync_pair. For each pair, the absolute
    difference of the two SensorTimestamp values (divided by 1000, i.e.
    microseconds) is added to a rolling window of the last ``window`` pairs.
    Every 20th pair is printed. The cameras are always stopped and closed on
    exit, including when opening the second camera fails.

    Args:
        duration: Seconds to run. The default (infinity) runs until Ctrl+C.
        threshold: Largest acceptable window-average difference, in the same
            units as the converted timestamps (microseconds).
        window: Number of most recent pairs to average. Must be >= 1.

    Returns:
        None for an infinite run. For a finite run, False if the average over
        a full window ever exceeded ``threshold``, otherwise True.

    Raises:
        ValueError: if ``window`` is less than 1 (checked before any camera
            is opened).
    """
    duration = float(duration)
    threshold = float(threshold)
    window = int(window)
    if window < 1:
        # A zero-length window would divide by zero when averaging.
        raise ValueError("window must be at least 1")
    finite = duration != float("inf")

    #set_trigger_mode(enable=True)

    # One frame period at 60 fps, in microseconds (matches make_camera).
    frame_duration = 1000000 / 60.

    deltas = deque(maxlen=window)
    status = True
    picams = []

    try:
        # Open/start inside the try so a failure while bringing up the second
        # camera still stops and closes the first one.
        for i in range(2):
            picams.append(make_camera(i))
        for cam in picams:
            cam.start()

        # Monotonic clock: wall-clock time can jump (e.g. NTP sync).
        start_time = time.monotonic()
        c = 0
        try:
            while not (finite and time.monotonic() - start_time > duration):
                req0, req1 = grab_sync_pair(picams, frame_duration)
                try:
                    ts0 = req0.get_metadata()['SensorTimestamp'] / 1000  # use microseconds
                    ts1 = req1.get_metadata()['SensorTimestamp'] / 1000

                    diff = abs(ts0 - ts1)
                    deltas.append(diff)

                    c += 1
                    if c % 20 == 0:
                        print("timestamp delta between two cameras: ", ts0, "and", ts1, "difference", diff)
                finally:
                    # Always return the buffers, even if interrupted mid-iteration.
                    req0.release()
                    req1.release()

                if finite and len(deltas) == window:
                    avg_diff = sum(deltas) / len(deltas)
                    if avg_diff > threshold:
                        status = False
        except KeyboardInterrupt:
            # Ctrl + C to properly stop cameras
            print("Stopping cameras (Ctrl+C pressed)...")
    finally:
        for cam in picams:
            cam.stop()
        for cam in picams:
            cam.close()
        print("Cameras stopped.")

    if not finite:
        return None
    return status
    
def _y_plane(frame):
    """Return the luma (Y) rows of a YUV420 frame produced by make_array.

    Uses the layout this script relies on: the first two thirds of the rows
    are the Y plane and the remaining rows hold chroma.
    """
    return frame[:frame.shape[0] * 2 // 3, :]


def main(plot=True):
    """Run a sync check, then preview both cameras' contrast-stretched luma.

    First calls check_sync() with its defaults (which runs until Ctrl+C) and
    prints the result. Then it reopens both cameras and loops: it grabs one
    request from each, extracts the Y plane, applies linear_contrast_stretch
    and, when ``plot`` is true, shows the two frames side by side in one
    OpenCV window. Press 'q' in the window, or Ctrl+C, to stop. The cameras
    are always stopped and closed on exit.

    Args:
        plot: Whether to display the frames with cv2.imshow.
    """
    print(check_sync())

    picams = []
    try:
        # Open/start inside the try so a failure on the second camera still
        # releases the first.
        for i in range(2):
            picams.append(make_camera(i))
        for cam in picams:
            cam.start()

        time.sleep(1.0)

        while True:
            # NOTE(review): these captures are not timestamp-matched;
            # grab_sync_pair could be used if aligned pairs are needed.
            req0 = picams[0].capture_request()
            req1 = picams[1].capture_request()
            try:
                raw_0 = req0.make_array(name="main")
                raw_1 = req1.make_array(name="main")
            finally:
                req0.release()
                req1.release()

            # YUV420 case: keep only the luma rows of each frame.
            frame_0 = linear_contrast_stretch(_y_plane(raw_0))
            frame_1 = linear_contrast_stretch(_y_plane(raw_1))

            if plot:
                cv2.imshow('cam_0cam_1', np.hstack((frame_0, frame_1)))

            if cv2.waitKey(1) == ord('q'):
                break

    except KeyboardInterrupt:
        pass
    finally:
        for cam in picams:
            cam.stop()
        for cam in picams:
            cam.close()
        print("Cameras stopped.")


if __name__ == "__main__":
    cv2.startWindowThread()
    try:
        main()
    finally:
        # Close any OpenCV windows even if main() raised.
        cv2.destroyAllWindows()
