import os
from picamera2 import Picamera2
from libcamera import Transform
import cv2
from libcamera import controls
import numpy as np
import time
import pprint


# Quiet libcamera's console logging; "3" should be the ERROR threshold in
# libcamera's level numbering (confirm against its environment-variable docs).
# NOTE(review): libcamera reads LIBCAMERA_LOG_LEVELS when its logger is first
# initialised, and picamera2/libcamera are already imported above; confirm the
# setting still takes effect, or move it above those imports.
os.environ["LIBCAMERA_LOG_LEVELS"] = "3"
# Frame width and height in pixels, used by make_camera() for the "main" stream.
w, h = 640, 480

# Orientation transform applied to each camera, keyed by camera index.
_TRANSFORMS = {
    0: Transform(vflip=1),  # use hflip=1 for the lab item, vflip=1 for the metal part
    1: Transform(),  # camera 1: no flip
}

# Optional manual exposure settings, currently disabled (commented out).
# ~ EXPOSURE = 15000
# ~ AN_GAIN = 3


def linear_contrast_stretch(image):
    """Linearly rescale a single-channel image so its values span 0..255.

    The minimum pixel value maps to 0 and the maximum to 255; values in
    between are scaled proportionally and truncated to ``uint8``.

    Args:
        image: 2-D numpy array (one channel), any numeric dtype.

    Returns:
        A new ``uint8`` array of the same shape, or ``image`` itself when it
        is empty or all pixels are equal (there is nothing to stretch).

    Raises:
        ValueError: if ``image`` is not 2-D.
    """
    if image.ndim != 2:
        raise ValueError("this function expect a single channel (2D) image")

    # An empty image has no min/max; hand it back untouched like a flat one.
    if image.size == 0:
        return image

    # Work with Python floats so the span cannot wrap around in the image's
    # unsigned integer dtype (uint8 scalars overflow on subtraction).
    min_val, max_val = float(image.min()), float(image.max())

    if max_val > min_val:
        span = max_val - min_val
        # Multiply before dividing so the maximum lands exactly on 255.0
        # rather than a hair below it (which would truncate to 254).
        stretched = (image.astype(np.float32) - min_val) * 255.0 / span
        return np.clip(stretched, 0, 255).astype(np.uint8)
    return image
    


def make_camera(index):
    """Open and configure camera ``index`` for YUV420 preview capture.

    The camera is configured but not started: ``w`` x ``h`` YUV420 "main"
    stream, frame duration limits of 16667 (matching FrameRate 60),
    Saturation 0, and the orientation from ``_TRANSFORMS``. The configuration
    reported back by ``camera_configuration()`` is printed for inspection.

    Args:
        index: Camera number passed to ``Picamera2``; also selects the
            transform in ``_TRANSFORMS`` (identity ``Transform()`` if the
            index has no entry).

    Returns:
        The configured, not-yet-started ``Picamera2`` instance. The caller
        owns it and is responsible for stopping/closing it.

    Raises:
        Whatever ``Picamera2`` raises while opening or configuring the
        camera. If the failure happens after the camera was opened, it is
        closed before the exception propagates.
    """
    picam = Picamera2(index)
    try:
        base_controls = {
            "FrameDurationLimits": (16667, 16667),
            "FrameRate": 60,
            "Saturation": 0,
            # ~ "ExposureTime": EXPOSURE,
            # ~ "AnalogueGain": AN_GAIN,
        }

        config = picam.create_preview_configuration(
            main={"size": (w, h), "format": "YUV420"},
            transform=_TRANSFORMS.get(index, Transform()),
            controls=base_controls,
        )
        picam.configure(config)

        # NOTE(review): the camera has not been started yet at this point;
        # confirm what this delay is meant to wait for.
        time.sleep(2.0)
        actual_config = picam.camera_configuration()
        print(f"camera {index} actual configuration:")
        pprint.pprint(actual_config)

        return picam
    except BaseException:
        # Close the camera we just opened so a failed setup (or Ctrl-C during
        # the delay) does not leave it held; the caller never received it.
        picam.close()
        raise


def _y_plane(frame):
    """Return the Y (luma) rows of a YUV420 array produced by ``make_array``.

    Assumes the YUV420 layout in which the Y plane occupies the first
    two-thirds of the array's rows.
    """
    return frame[: frame.shape[0] * 2 // 3, :]


def _release_cameras(picams):
    """Stop every camera, then close every camera, reporting failures instead of raising."""
    for action in ("stop", "close"):
        for cam in picams:
            try:
                getattr(cam, action)()
            except Exception as exc:
                # Keep going so one faulty camera does not leave the others open.
                print(f"camera {action}() failed: {exc!r}")


def main(plot=True):
    """Stream both cameras side by side until 'q' is pressed or Ctrl-C.

    Each iteration captures one request from every camera, keeps the Y plane
    of each frame, contrast-stretches it, and (if ``plot``) shows the frames
    horizontally stacked in an OpenCV window. Cameras are stopped and closed
    on exit, including when setup fails part-way through.

    Args:
        plot: When true, display the stacked frames with ``cv2.imshow``.
    """
    picams = []
    try:
        # Append one at a time so that if opening camera 1 fails, camera 0 is
        # already in the list and gets released by the finally-block.
        for index in range(2):
            picams.append(make_camera(index))

        for cam in picams:
            cam.start()

        time.sleep(1.0)

        while True:
            requests = []
            try:
                for cam in picams:
                    requests.append(cam.capture_request())
                frames = [
                    linear_contrast_stretch(_y_plane(req.make_array(name="main")))
                    for req in requests
                ]
            finally:
                # Hand every captured request back, even if processing failed.
                for req in requests:
                    req.release()

            if plot:
                cv2.imshow('cam_0cam_1', np.hstack(frames))

            if cv2.waitKey(1) == ord('q'):
                break

    except KeyboardInterrupt:
        pass
    finally:
        _release_cameras(picams)
        print("Cameras stopped.")


if __name__ == "__main__":
    # Start OpenCV's HighGUI window thread before any window is created.
    cv2.startWindowThread()
    try:
        main()
    finally:
        # Tear down the preview window even if main() exits with an error.
        cv2.destroyAllWindows()
