import json
import os
from contextlib import closing
from enum import Enum
from typing import Optional

from flask import current_app

from LMSAPI import scheduler


class CronTaskStatus(Enum):
    """Lifecycle states of a row in ``cron_tasks`` (the ``status`` column)."""

    active = "active"
    processing = "processing"
    completed = "completed"
    failed = "failed"

    @classmethod
    def get_values(cls) -> list:
        """Return the raw string value of every status, in definition order."""
        return [member.value for member in cls]

    @classmethod
    def is_valid(cls, status: str) -> bool:
        """Return True if ``status`` is one of the defined status strings."""
        return status in cls.get_values()

    @classmethod
    def get_name(cls, status: str) -> str:
        """Return the member name for the status value ``status``.

        Raises:
            Exception: if ``status`` is not a defined status value.
        """
        if not cls.is_valid(status):
            raise Exception("status is invalid")
        return cls(status).name

class CronTaskType(Enum):
    """Kinds of background jobs that can be queued in ``cron_tasks``.

    The integer value is what gets written to the ``task_type`` column.
    """

    image_converter = 1

    @classmethod
    def get_values(cls) -> list:
        """List every task-type integer, in declaration order."""
        return [member.value for member in cls]

    @classmethod
    def is_valid(cls, task_type: int) -> bool:
        """Tell whether ``task_type`` matches a declared task-type value."""
        known_values = cls.get_values()
        return task_type in known_values

    @classmethod
    def get_name(cls, task_type: int) -> str:
        """Map a task-type value to its member name.

        Raises:
            Exception: when ``task_type`` is not a declared value.
        """
        if cls.is_valid(task_type):
            return cls(task_type).name
        raise Exception("task_type is invalid")
    

class CronTasks:
    """Database helpers for the ``cron_tasks`` table of the ``lname`` database.

    Every public method enters the scheduler's Flask app context and opens a
    connection through ``scheduler.app.ms.db(lname).connect()``.
    NOTE(review): that object presumably is a SQLAlchemy-style engine; confirm.
    The connection is always closed, even when a statement raises.

    Values that end up inside SQL text are coerced to ``int`` (ids, task
    types). All other values go through :meth:`_literal`, so quotes in file
    paths or error messages can no longer break or inject into the statements.
    """

    # table:
    # cron_tasks

    # params:
    # id - SERIAL PRIMARY KEY
    # status - VARCHAR NOT NULL DEFAULT 'active'
    # task_type - integer NOT NULL
    # task_params - JSON
    # error_msg - VARCHAR
    # received_time - TIMESTAMP WITH TIME ZONE DEFAULT NOW()
    # done_time - TIMESTAMP WITH TIME ZONE

    @staticmethod
    def _literal(value) -> str:
        """Render ``value`` as an escaped PostgreSQL string literal, or ``NULL`` for None.

        The result is an ``E'...'`` escape string with backslashes and single
        quotes doubled. That is safe whatever ``standard_conforming_strings``
        is set to.

        NOTE(review): values are still inlined into the SQL text. If the
        DBAPI driver %-formats raw SQL, a literal '%' may need attention.
        Confirm against the driver in use.
        """
        if value is None:
            return "NULL"
        escaped = str(value).replace("\\", "\\\\").replace("'", "''")
        return "E'{}'".format(escaped)

    @staticmethod
    def _id_list(ids) -> str:
        """Render ``ids`` as a comma-separated list of integers, or ``"0"`` when empty.

        ``int()`` rejects anything that is not an integer id (ValueError). The
        ``"0"`` fallback keeps ``IN (...)`` syntactically valid. SERIAL ids
        start at 1 by default, so it matches nothing.
        """
        return ",".join(str(int(i)) for i in ids) or "0"

    @staticmethod
    def _rows(result) -> list:
        """Turn an executed result (``keys()`` plus iterable row tuples) into a list of dicts."""
        keys = tuple(result.keys())
        return [dict(zip(keys, row)) for row in result]

    @staticmethod
    def check_cron_table(lname: str) -> bool:
        """Return whether table ``public.cron_tasks`` exists in the ``lname`` database."""
        sql = """
        SELECT EXISTS (
            SELECT FROM information_schema.tables
            WHERE table_schema = 'public'
                AND table_name = 'cron_tasks'
        );
        """
        with scheduler.app.app_context():
            with closing(scheduler.app.ms.db(lname).connect()) as conn:
                res = CronTasks._rows(conn.execute(sql))
        return res[0].get("exists")

    @staticmethod
    def validate_task_params(task_type: "CronTaskType", task_params: dict) -> dict:
        """Check ``task_params`` for ``task_type`` and return a normalized dict.

        For ``CronTaskType.image_converter``, ``file_path`` is required. The
        keys ``density``, ``output_format`` and ``output_file_dir`` are copied
        as-is, or set to ``None`` when absent. Any other key is dropped.

        Raises:
            Exception: if ``task_type`` is unsupported or ``file_path`` is missing or empty.
        """
        if task_type != CronTaskType.image_converter:
            raise Exception("task_type is invalid")
        file_path = task_params.get("file_path")
        if not file_path:
            raise Exception("file_path is required")
        return {
            "file_path": file_path,
            "density": task_params.get("density"),
            "output_format": task_params.get("output_format"),
            "output_file_dir": task_params.get("output_file_dir"),
        }

    @staticmethod
    def create_cron_task(lname: str, task_type: "CronTaskType", task_params: dict) -> None:
        """Insert a new task row. ``status`` and ``received_time`` take their column defaults.

        ``task_params`` is stored as JSON. ``None`` is stored as SQL NULL; the
        old ``or "NULL"`` fallback was unreachable.
        """
        params_sql = (
            CronTasks._literal(json.dumps(task_params))
            if task_params is not None
            else "NULL"
        )
        sql = """
        INSERT INTO cron_tasks (task_type, task_params)
        VALUES ({task_type}, {task_params})
        """.format(
            task_type=int(task_type.value),
            task_params=params_sql,
        )
        with scheduler.app.app_context():
            with closing(scheduler.app.ms.db(lname).connect()) as conn:
                conn.execute(sql)

    @staticmethod
    def get_cron_tasks(lname: str, task_type: "CronTaskType") -> list:
        """Fetch the active tasks of ``task_type``, oldest first, and mark them 'processing'.

        Returns:
            A list of dicts with keys ``id`` and ``task_params``.

        Raises:
            Re-raises any error from the SELECT, after logging it.

        NOTE(review): the SELECT and the UPDATE are separate statements, so two
        concurrent schedulers could pick up the same rows. Consider
        ``FOR UPDATE SKIP LOCKED`` if several workers run.
        """
        select_sql = """
        SELECT
                id,
                task_params

        FROM cron_tasks
        WHERE status = 'active'
            AND task_type = {}
        ORDER BY received_time ASC
        """.format(int(task_type.value))

        with scheduler.app.app_context():
            with closing(scheduler.app.ms.db(lname).connect()) as conn:
                try:
                    res = CronTasks._rows(conn.execute(select_sql))
                except Exception as e:
                    current_app.logger.error("get cron tasks failed: %s", str(e))
                    raise
                # Nothing fetched, so there is nothing to claim; skip the write.
                if res:
                    update_sql = """
                    UPDATE cron_tasks
                    SET status = 'processing'
                    WHERE id IN ({});
                    """.format(CronTasks._id_list(r.get("id") for r in res))
                    conn.execute(update_sql)
        return res

    @staticmethod
    def update_task(lname: str, task_id: int, status: str, error_msg: Optional[str] = None) -> None:
        """Set ``status``, ``done_time = NOW()`` and ``error_msg`` on task ``task_id``.

        An empty or ``None`` ``error_msg`` is stored as NULL.
        """
        sql = """
        UPDATE cron_tasks
        SET status = {status},
            done_time = NOW(),
            error_msg = {error_msg}
        WHERE id = {task_id};
        """.format(
            status=CronTasks._literal(status),
            task_id=int(task_id),
            error_msg=CronTasks._literal(error_msg) if error_msg else "NULL",
        )
        with scheduler.app.app_context():
            with closing(scheduler.app.ms.db(lname).connect()) as conn:
                conn.execute(sql)

    @staticmethod
    def update_cron_tasks(lname: str, completed_tasks: list, error_tasks: dict) -> None:
        """Finalize a batch of tasks in one round trip.

        Args:
            completed_tasks: ids to mark 'completed'.
            error_tasks: mapping of task id to error (stringified). These tasks
                are marked 'failed' with that message. Messages keep their
                quotes now; previously ``'`` was replaced with ``|``.
        """
        if error_tasks:
            error_ids = "ARRAY[{}]".format(",".join(str(int(k)) for k in error_tasks))
            error_msgs = "ARRAY[{}]".format(
                ",".join(CronTasks._literal(str(v)) for v in error_tasks.values())
            )
        else:
            # Typed empty arrays so that unnest() resolves its element type.
            error_ids = "ARRAY[]::integer[]"
            error_msgs = "ARRAY[]::text[]"

        sql = """
        UPDATE cron_tasks
        SET status = 'completed',
            done_time = NOW()
        WHERE id in ({completed_tasks});

        UPDATE cron_tasks
        SET status = 'failed',
            done_time = NOW(),
            error_msg = t.error_msg
        FROM (
            SELECT
                    unnest({error_tasks_keys}) AS id,
                    unnest({error_tasks_values}) AS error_msg
        ) AS t
        WHERE cron_tasks.id = t.id;
        """.format(
            completed_tasks=CronTasks._id_list(completed_tasks),
            error_tasks_keys=error_ids,
            error_tasks_values=error_msgs,
        )

        with scheduler.app.app_context():
            with closing(scheduler.app.ms.db(lname).connect()) as conn:
                conn.execute(sql)