import os
import pprint
import time
from collections import deque
from typing import Any, List

from picamera2 import Picamera2
from libcamera import Transform
from libcamera import controls
import cv2
import numpy as np


# Quieten libcamera's own logging.
# NOTE(review): this is set after picamera2/libcamera are imported; confirm
# libcamera still picks it up at that point (it may need to be set earlier).
os.environ.update({"LIBCAMERA_LOG_LEVELS": "3"})

# Main-stream resolution used by make_camera().
w = 640
h = 480

# Image orientation per camera index, consumed by make_camera().
# Camera 0: use hflip=1 for the lab item or vflip=1 for metal (vflip is set).
_TRANSFORMS = {
    0: Transform(vflip=1),
    1: Transform(),
}

# ~ EXPOSURE = 15000
# ~ AN_GAIN = 3


def linear_contrast_stretch(image):
    """Linearly rescale a single-channel image so its values span 0..255.

    The smallest pixel maps to 0 and the largest to 255; values in between
    are scaled linearly and rounded to the nearest integer.

    Args:
        image: A 2-D (single-channel) array.

    Returns:
        A new ``uint8`` array of the same shape. If the image is empty or all
        pixels are equal (nothing to stretch), ``image`` itself is returned.

    Raises:
        ValueError: If ``image`` is not two-dimensional.
    """
    if image.ndim != 2:
        raise ValueError("this function expect a single channel (2D) image")

    # min()/max() on an empty array would raise; nothing to stretch anyway.
    if image.size == 0:
        return image

    # Work with Python floats so uint8 scalar arithmetic cannot wrap around.
    lo = float(image.min())
    hi = float(image.max())
    if hi <= lo:
        return image

    # Previously the factor was 255 / (min - max). That is negative, so every
    # pixel above the minimum went below zero and was clipped to black.
    scale = 255.0 / (hi - lo)
    stretched = (image.astype(np.float32) - lo) * scale
    # Round rather than truncate so float32 error cannot turn 255 into 254.
    return np.clip(np.rint(stretched), 0, 255).astype(np.uint8)
    


def make_camera(index):
    """Create and configure the Picamera2 instance for camera ``index``.

    The camera gets a preview configuration with:
      * a main stream of size ``(w, h)`` in YUV420 format;
      * orientation ``_TRANSFORMS[index]``, so only indices in that table are
        accepted (any other index raises ``KeyError``);
      * controls requesting a fixed 16667 us frame duration, a frame rate of
        60 and saturation 0.

    After configuring, it waits 2 s and prints the resulting camera
    configuration. The camera is returned configured but not started.
    """
    camera = Picamera2(index)

    # NOTE(review): FrameRate=60 restates FrameDurationLimits (1e6 / 16667 ~ 60);
    # presumably Picamera2 reduces both to the same limits -- confirm.
    requested_controls = dict(
        FrameDurationLimits=(16667, 16667),
        FrameRate=60,
        Saturation=0,
        # ~ ExposureTime=EXPOSURE,
        # ~ AnalogueGain=AN_GAIN,
    )
    main_stream = {"size": (w, h), "format": "YUV420"}

    preview_config = camera.create_preview_configuration(
        main=main_stream,
        transform=_TRANSFORMS[index],
        controls=requested_controls,
    )
    camera.configure(preview_config)

    time.sleep(2.0)
    # Despite the label, this prints the whole camera configuration.
    applied = camera.camera_configuration()
    print("camera actual controls:")
    pprint.pprint(applied)

    return camera

def grab_sync_pair(picams: List[Any], frame_duration: float):
    """Capture one request from each of two cameras with matching timestamps.

    Each request's ``SensorTimestamp`` metadata is divided by 1000 to work in
    microseconds. Suppose one request is older than the other by more than
    half a frame. Then it is released and replaced with that camera's next
    request. This repeats until the two timestamps lie within
    ``frame_duration / 2`` of each other.

    Args:
        picams: Cameras; only ``picams[0]`` and ``picams[1]`` are used.
        frame_duration: Frame period, in the same units as the timestamps
            (microseconds).

    Returns:
        ``(req0, req1)``, still held; the caller must release both.
    """
    half_frame = frame_duration / 2
    held = [picams[0].capture_request(), picams[1].capture_request()]

    def _stamp(request):
        return request.get_metadata()['SensorTimestamp'] / 1000

    while True:
        first, second = _stamp(held[0]), _stamp(held[1])
        if first + half_frame < second:
            stale = 0      # camera 0's frame is too early
        elif second + half_frame < first:
            stale = 1      # camera 1's frame is too early
        else:
            return held[0], held[1]
        held[stale].release()
        held[stale] = picams[stale].capture_request()


def check_sync(duration=float("inf"), threshold=30.0, window=20):
    """Measure the sensor-timestamp offset between cameras 0 and 1.

    Both cameras are opened with make_camera() and started. The function then
    repeatedly takes a matched request pair from grab_sync_pair() and records
    ``abs(ts0 - ts1)``, using timestamps divided by 1000 (microseconds). Every
    20th pair, the timestamps and their difference are printed.

    Args:
        duration: Run time in seconds. The default ``float("inf")`` runs until
            Ctrl+C.
        threshold: Largest acceptable mean offset, in microseconds, over the
            most recent ``window`` pairs. Only checked for finite runs.
        window: Number of most recent pairs averaged for the threshold check.

    Returns:
        ``None`` for an unbounded run. For a finite run, ``False`` if any full
        window's mean offset exceeded ``threshold``, otherwise ``True``.

    Raises:
        ValueError: If ``window < 1`` for a finite run, since the mean could
            never be computed.
    """
    duration = float(duration)
    threshold = float(threshold)
    window = int(window)
    finite = duration != float("inf")
    if finite and window < 1:
        raise ValueError("window must be at least 1 for a finite run")

    # NOTE(review): set_trigger_mode is neither defined nor imported in this
    # file as shown; confirm where it is supposed to come from.
    set_trigger_mode(enable=True)

    picams = []

    def _shutdown():
        # Stop every camera first, then close them.
        for cam in picams:
            cam.stop()
        for cam in picams:
            cam.close()

    # Setup is guarded so that a failure opening/starting camera 1 does not
    # leave camera 0 open. Errors (including Ctrl+C) still propagate.
    try:
        for index in range(2):
            picams.append(make_camera(index))
        for cam in picams:
            cam.start()
    except BaseException:
        _shutdown()
        raise

    # Frame period in microseconds at 60 fps (matches FrameDurationLimits).
    frame_duration = 1000000 / 60.
    deltas = deque(maxlen=window)
    start_time = time.time()
    status = True
    count = 0

    try:
        while not finite or time.time() - start_time <= duration:
            req0, req1 = grab_sync_pair(picams, frame_duration)
            try:
                ts0 = req0.get_metadata()['SensorTimestamp'] / 1000  # use microseconds
                ts1 = req1.get_metadata()['SensorTimestamp'] / 1000
            finally:
                req0.release()
                req1.release()

            diff = abs(ts0 - ts1)
            deltas.append(diff)

            count += 1
            if count % 20 == 0:
                print("timestamp delta between two cameras: ", ts0, "and", ts1, "difference", diff)

            # Any single full window whose mean exceeds the threshold fails the run.
            if finite and len(deltas) == window and sum(deltas) / window > threshold:
                status = False
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to end an unbounded run.
        print("Stopping cameras (Ctrl+C pressed)...")
    finally:
        _shutdown()
        print("Cameras stopped.")

    return status if finite else None
    
def _luma_plane(frame):
    """Return the Y plane of a YUV420 array: the top two thirds of its rows."""
    return frame[: frame.shape[0] * 2 // 3, :]


def main(plot=True):
    """Run the sync check, then stream contrast-stretched luma from both cameras.

    First runs ``check_sync()`` with its defaults (an unbounded run, ended
    with Ctrl+C) and prints its result. Then it reopens and starts both
    cameras. Until 'q' is pressed in the OpenCV window, or Ctrl+C, it loops:
    capture one request per camera, keep the Y plane of each YUV420 frame,
    stretch its contrast, and, when ``plot`` is true, show the two frames
    side by side.

    Args:
        plot: Whether to display the frames with ``cv2.imshow``.
    """
    print(check_sync())

    picams = []

    def _shutdown():
        # Stop every camera first, then close them.
        for cam in picams:
            cam.stop()
        for cam in picams:
            cam.close()

    # Guarded setup: a failure on camera 1 no longer leaks camera 0.
    try:
        for index in range(2):
            picams.append(make_camera(index))
        for cam in picams:
            cam.start()
    except BaseException:
        _shutdown()
        raise

    try:
        # Inside the try so Ctrl+C during the settle delay still cleans up.
        time.sleep(1.0)
        while True:
            req0 = picams[0].capture_request()
            req1 = picams[1].capture_request()
            try:
                # Each frame is cut with its own height (previously frame_1
                # was sliced using frame_0's height).
                frame_0 = linear_contrast_stretch(_luma_plane(req0.make_array(name="main")))
                frame_1 = linear_contrast_stretch(_luma_plane(req1.make_array(name="main")))
            finally:
                req0.release()
                req1.release()

            if plot:
                cv2.imshow('cam_0cam_1', np.hstack((frame_0, frame_1)))

            if cv2.waitKey(1) == ord('q'):
                break
    except KeyboardInterrupt:
        pass
    finally:
        _shutdown()
        print("Cameras stopped.")


if __name__ == "__main__":
    # Start OpenCV's window thread before any imshow() call in main().
    cv2.startWindowThread()
    try:
        main()
    finally:
        # Close the preview window even when main() raises.
        cv2.destroyAllWindows()
