#!/usr/bin/python3 -i
#
# Copyright (c) 2015-2023 The Khronos Group Inc.
# Copyright (c) 2015-2023 Valve Corporation
# Copyright (c) 2015-2023 LunarG, Inc.
# Copyright (c) 2015-2023 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
from generators.generator_utils import (incIndent, decIndent, addIndent)
from generators.vulkan_object import (Member)
from generators.base_generator import BaseGenerator

class LayerChassisDispatchOutputGenerator(BaseGenerator):
    """Generates layer_chassis_dispatch.h and layer_chassis_dispatch.cpp.

    The header declares one Dispatch* prototype per Vulkan command. The source defines
    those functions (except the hand-written ones in `no_autogen_list`): each forwards to
    the next layer's dispatch table and, when `wrap_handles` is set, unwraps
    non-dispatchable handles (NDOs) before the down-chain call, wraps handles returned by
    create/allocate commands, and removes destroyed handles from `unique_id_mapping`.
    """
    def __init__(self):
        BaseGenerator.__init__(self)

        # Commands which are not autogenerated but still intercepted
        self.no_autogen_list = [
            'vkCreateInstance',
            'vkDestroyInstance',
            'vkCreateDevice',
            'vkDestroyDevice',
            'vkCreateSwapchainKHR',
            'vkCreateSharedSwapchainsKHR',
            'vkGetSwapchainImagesKHR',
            'vkDestroySwapchainKHR',
            'vkQueuePresentKHR',
            'vkCreateGraphicsPipelines',
            'vkCreateComputePipelines',
            'vkCreateRayTracingPipelinesNV',
            'vkCreateRayTracingPipelinesKHR',
            'vkResetDescriptorPool',
            'vkDestroyDescriptorPool',
            'vkAllocateDescriptorSets',
            'vkFreeDescriptorSets',
            'vkCreateDescriptorUpdateTemplate',
            'vkCreateDescriptorUpdateTemplateKHR',
            'vkDestroyDescriptorUpdateTemplate',
            'vkDestroyDescriptorUpdateTemplateKHR',
            'vkUpdateDescriptorSetWithTemplate',
            'vkUpdateDescriptorSetWithTemplateKHR',
            'vkCmdPushDescriptorSetWithTemplateKHR',
            'vkDebugMarkerSetObjectTagEXT',
            'vkDebugMarkerSetObjectNameEXT',
            'vkCreateRenderPass',
            'vkCreateRenderPass2KHR',
            'vkCreateRenderPass2',
            'vkDestroyRenderPass',
            'vkSetDebugUtilsObjectNameEXT',
            'vkSetDebugUtilsObjectTagEXT',
            'vkGetPhysicalDeviceDisplayPropertiesKHR',
            'vkGetPhysicalDeviceDisplayProperties2KHR',
            'vkGetPhysicalDeviceDisplayPlanePropertiesKHR',
            'vkGetPhysicalDeviceDisplayPlaneProperties2KHR',
            'vkGetDisplayPlaneSupportedDisplaysKHR',
            'vkGetDisplayModePropertiesKHR',
            'vkGetDisplayModeProperties2KHR',
            'vkEnumerateInstanceExtensionProperties',
            'vkEnumerateInstanceLayerProperties',
            'vkEnumerateDeviceExtensionProperties',
            'vkEnumerateDeviceLayerProperties',
            'vkEnumerateInstanceVersion',
            'vkGetPhysicalDeviceToolPropertiesEXT',
            'vkSetPrivateDataEXT',
            'vkGetPrivateDataEXT',
            'vkDeferredOperationJoinKHR',
            'vkGetDeferredOperationResultKHR',
            'vkSetPrivateData',
            'vkGetPrivateData',
            'vkBuildAccelerationStructuresKHR',
            'vkGetDescriptorEXT',
            'vkReleasePerformanceConfigurationINTEL',
            'vkExportMetalObjectsEXT',
            # These are for special-casing the pInheritanceInfo issue (must be ignored for primary CBs)
            'vkAllocateCommandBuffers',
            'vkFreeCommandBuffers',
            'vkDestroyCommandPool',
            'vkBeginCommandBuffer',
            'vkGetAccelerationStructureBuildSizesKHR'
            ]

        # Names of pNext extension structs that contain NDOs, in discovery order
        # (populated by generateSource()).
        self.ndo_extension_structs = []

    def isNonDispatchable(self, name: str) -> bool:
        """Return True if `name` is a Vulkan handle type that is not dispatchable."""
        return name in self.vk.handles and not self.vk.handles[name].dispatchable

    def containsNonDispatchableObject(self, structName: str, visited=None) -> bool:
        """Return True if struct `structName` has an NDO member, directly or through nested structs.

        `visited` holds the struct names already examined during this query, so both direct
        and indirect self-references terminate. Callers normally omit it.
        """
        if visited is None:
            visited = set()
        struct = self.vk.structs[structName]
        visited.add(structName)
        visited.add(struct.name)
        for member in struct.members:
            if self.isNonDispatchable(member.type):
                return True
            # Recurse into member structs that have not been examined yet
            if member.type in self.vk.structs and member.type not in visited:
                if self.containsNonDispatchableObject(member.type, visited):
                    return True
        return False

    # Now that the data is all collected and complete, generate and output the wrapping/unwrapping routines
    def generate(self):
        """Write the generated-file banner, then the header or source body selected by self.filename."""
        self.write(f'''// *** THIS FILE IS GENERATED - DO NOT EDIT ***
// See {os.path.basename(__file__)} for modifications

/***************************************************************************
*
* Copyright (c) 2015-2023 The Khronos Group Inc.
* Copyright (c) 2015-2023 Valve Corporation
* Copyright (c) 2015-2023 LunarG, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
****************************************************************************/\n''')
        self.write('// NOLINTBEGIN') # Wrap for clang-tidy to ignore

        if self.filename == 'layer_chassis_dispatch.h':
            self.generateHeader()
        elif self.filename == 'layer_chassis_dispatch.cpp':
            self.generateSource()
        else:
            self.write(f'\nFile name {self.filename} has no code to generate\n')

        self.write('// NOLINTEND') # Wrap for clang-tidy to ignore

    def generateHeader(self):
        """Emit layer_chassis_dispatch.h: one Dispatch* prototype per Vulkan command."""
        out = []
        out.append('''
#pragma once

extern bool wrap_handles;

class ValidationObject;
void WrapPnextChainHandles(ValidationObject *layer_data, const void *pNext);

''')
        for command in self.vk.commands.values():
            # vkFoo -> DispatchFoo, dropping the VKAPI_ATTR / VKAPI_CALL decorations
            prototype = command.cPrototype
            prototype = prototype.replace("VKAPI_ATTR ", "")
            prototype = prototype.replace("VKAPI_CALL vk", "Dispatch")
            if command.protect:
                out.append(f'#ifdef {command.protect}\n')
            out.append(f'{prototype}\n')
            if command.protect:
                out.append(f'#endif // {command.protect}\n')

        self.write("".join(out))

    def generateSource(self):
        """Emit layer_chassis_dispatch.cpp: WrapPnextChainHandles() plus every autogenerated Dispatch* function."""
        # Must run first: both the pNext helper and the dispatch functions consult
        # self.ndo_extension_structs.
        self._collectNdoExtensionStructs()

        out = []
        out.append('''
#include "utils/cast_utils.h"
#include "chassis.h"
#include "layer_chassis_dispatch.h"
#include "vk_safe_struct.h"
#include "state_tracker/pipeline_state.h"

#define DISPATCH_MAX_STACK_ALLOCATIONS 32

''')
        out.append(self._generateWrapPnextChainHandles())

        for command in [x for x in self.vk.commands.values() if x.name not in self.no_autogen_list]:
            out.append(self._generateDispatchFunction(command))

        self.write("".join(out))

    def _collectNdoExtensionStructs(self):
        """Record, once each and in discovery order, every pNext-extension struct that contains an NDO."""
        for struct in [x for x in self.vk.structs.values() if x.sType and x.extendedBy]:
            for extendedStruct in struct.extendedBy:
                if extendedStruct not in self.ndo_extension_structs and self.containsNonDispatchableObject(extendedStruct):
                    self.ndo_extension_structs.append(extendedStruct)

    def _generateWrapPnextChainHandles(self) -> str:
        """Return the C++ WrapPnextChainHandles() definition, which unwraps NDOs in every known pNext struct."""
        out = []
        out.append('''
// Unique Objects pNext extension handling function
void WrapPnextChainHandles(ValidationObject *layer_data, const void *pNext) {
    void *cur_pnext = const_cast<void *>(pNext);
    while (cur_pnext != nullptr) {
        VkBaseOutStructure *header = reinterpret_cast<VkBaseOutStructure *>(cur_pnext);

        switch (header->sType) {
''')

        for struct in [self.vk.structs[x] for x in self.ndo_extension_structs]:
            indent = '                '
            (_, api_pre, _) = self.uniquifyMembers(struct.members, indent, 'safe_struct->', 0, False, False, False)
            # Only process extension structs containing handles
            if not api_pre:
                continue
            if struct.protect:
                out.append(f'#ifdef {struct.protect}\n')
            out.append(f'            case {struct.sType}: {{\n')
            out.append(f'                    safe_{struct.name} *safe_struct = reinterpret_cast<safe_{struct.name} *>(cur_pnext);\n')
            out.append(api_pre)
            out.append('                } break;\n')
            if struct.protect:
                out.append(f'#endif // {struct.protect}\n')
            out.append('\n')

        out.append('''            default:
                break;
        }

        // Process the next structure in the chain
        cur_pnext = header->pNext;
    }
}
''')
        return "".join(out)

    def _wrappedArgument(self, param) -> str:
        """Return the text passed for `param` in the wrapped down-chain call.

        Parameters that uniquifyMembers() copies into a `local_` variable (const NDO arrays,
        structs containing NDOs, or structs whose pNext chain may carry NDO structs) are
        replaced by that local, cast back to the parameter's declared type. All other
        parameters are passed through by name. Building each argument separately avoids
        the substring collisions of a text-wide str.replace().
        """
        struct = self.vk.structs[param.type] if param.type in self.vk.structs else None
        isLocal = (self.isNonDispatchable(param.type) and param.length and param.const) or (struct and self.containsNonDispatchableObject(struct.name))
        isExtended = struct and struct.extendedBy and any(x in self.ndo_extension_structs for x in struct.extendedBy)
        if not (isLocal or isExtended):
            return param.name
        if not param.pointer:
            return f'(const {param.type})local_{param.name}'
        if param.const:
            return f'(const {param.type}*)local_{param.name}'
        return f'({param.type}*)local_{param.name}'

    def _generateDispatchFunction(self, command) -> str:
        """Return the C++ definition of Dispatch<command>, inside its #ifdef guard when it has one."""
        out = []
        if command.protect:
            out.append(f'#ifdef {command.protect}\n')

        # Generate NDO wrapping/unwrapping code for all parameters
        isCreate = any(x in command.name for x in ['Create', 'Allocate', 'GetRandROutputDisplayEXT', 'GetDrmDisplayEXT', 'RegisterDeviceEvent', 'RegisterDisplayEvent', 'AcquirePerformanceConfigurationINTEL'])
        isDestroy = any(x in command.name for x in ['Destroy', 'Free'])
        indent = '    '

        # Handle ndo create/allocate operations: wrap the handle(s) returned through the last parameter
        create_ndo_code = ''
        if isCreate:
            lastParam = command.params[-1]
            handle_type = lastParam.type
            if self.isNonDispatchable(handle_type):
                # Check for special case where multiple handles are returned
                ndo_array = lastParam.length is not None
                create_ndo_code += f'{indent}if (VK_SUCCESS == result) {{\n'
                indent = incIndent(indent)
                ndo_dest = f'*{lastParam.name}'
                if ndo_array:
                    create_ndo_code += f'{indent}for (uint32_t index0 = 0; index0 < {lastParam.length}; index0++) {{\n'
                    indent = incIndent(indent)
                    ndo_dest = f'{lastParam.name}[index0]'
                create_ndo_code += f'{indent}{ndo_dest} = layer_data->WrapNew({ndo_dest});\n'
                if ndo_array:
                    indent = decIndent(indent)
                    create_ndo_code += f'{indent}}}\n'
                indent = decIndent(indent)
                create_ndo_code += f'{indent}}}\n'

        # Handle ndo destroy/free operations
        destroy_ndo_code = ''
        if isDestroy:
            param = command.params[-2] # Last param is always VkAllocationCallbacks
            if self.isNonDispatchable(param.type):
                # Remove a single handle from the map
                destroy_ndo_code += addIndent(indent,
f'''uint64_t {param.name}_id = CastToUint64({param.name});
auto iter = unique_id_mapping.pop({param.name}_id);
if (iter != unique_id_mapping.end()) {{
    {param.name} = ({param.type})iter->second;
}} else {{
    {param.name} = ({param.type})0;
}}''')
        (api_decls, api_pre, api_post) = self.uniquifyMembers(command.params, indent, '', 0, isCreate, isDestroy, True)
        api_post += create_ndo_code
        if isDestroy:
            api_pre += destroy_ndo_code
        elif api_pre:
            api_pre = f'    {{\n{api_pre}{indent}}}\n'

        # If API doesn't contain NDO's, we still need to make a down-chain call
        down_chain_call_only = not api_decls and not api_pre and not api_post

        # Drop the last character of cPrototype (presumably the trailing ';') to open a body
        prototype = command.cPrototype[:-1]
        prototype = prototype.replace("VKAPI_ATTR ", "")
        prototype = prototype.replace("VKAPI_CALL vk", "Dispatch")
        out.append(f'\n{prototype} {{\n')

        # Argument lists for the pass-through call and for the wrapped (local_*) call
        paramstext = ', '.join(param.name for param in command.params)
        wrapped_paramstext = ', '.join(self._wrappedArgument(param) for param in command.params)

        # First, add check and down-chain call. Use correct dispatch table
        dispatch_table = 'instance_dispatch_table' if command.instance else 'device_dispatch_table'

        # first parameter is always dispatchable
        out.append(f'    auto layer_data = GetLayerDataPtr(get_dispatch_key({command.params[0].name}), layer_data_map);\n')
        # Put all this together for the final down-chain call
        if not down_chain_call_only:
            out.append(f'    if (!wrap_handles) return layer_data->{dispatch_table}.{command.name[2:]}({paramstext});\n')

        # Handle return values, if any
        assignResult = f'{command.returnType} result = ' if command.returnType != 'void' else ''

        # Pre-pend declarations and pre-api-call codegen. Join the two sections with a
        # newline so the last declaration and the first pre-call statement do not end
        # up on the same generated line.
        pre_sections = [section.rstrip() for section in (api_decls, api_pre) if section]
        out.append('\n'.join(pre_sections))
        out.append('\n')
        # Generate the wrapped dispatch call
        out.append(f'    {assignResult}layer_data->{dispatch_table}.{command.name[2:]}({wrapped_paramstext});\n')

        out.append(api_post.rstrip())
        out.append('\n')
        # Handle the return result variable, if any
        if assignResult != '':
            out.append('    return result;\n')
        out.append('}\n')
        if command.protect:
            out.append(f'#endif // {command.protect}\n')
        return "".join(out)

    #
    # Clean up local declarations
    def cleanUpLocalDeclarations(self, indent, prefix, name, len, deferred_name):
        """Return C++ that frees a heap-allocated `local_<prefix><name>` copy after the down-chain call.

        Args:
            indent: indentation for the emitted code.
            prefix: prefix of the local variable name (e.g. '' for top-level parameters).
            name: member/parameter name.
            len: array length expression, or None for a single struct (shadows the builtin;
                kept for interface compatibility).
            deferred_name: name of a VkDeferredOperationKHR parameter, or None.

        Returns '' when nothing was heap allocated (single struct with no deferred operation,
        which lives on the stack). With a deferred operation, deletion is postponed to
        `deferred_operation_post_completion` when the call returned VK_OPERATION_DEFERRED_KHR.
        """
        cleanup = ''
        if len is not None or deferred_name is not None:
            delete_var = f'local_{prefix}{name}'
            if len is None:
                delete_code = f'delete {delete_var}'
            else:
                delete_code = f'delete[] {delete_var}'
            cleanup = f'{indent}if ({delete_var}) {{\n'
            if deferred_name is not None:
                cleanup += f'{indent}    // Fix check for deferred ray tracing pipeline creation\n'
                cleanup += f'{indent}    // https://github.com/KhronosGroup/Vulkan-ValidationLayers/issues/5817\n'
                cleanup += f'{indent}    const bool is_operation_deferred = ({deferred_name} != VK_NULL_HANDLE) && (result == VK_OPERATION_DEFERRED_KHR);\n'
                cleanup += f'{indent}    if (is_operation_deferred) {{\n'
                cleanup += f'{indent}        std::vector<std::function<void()>> cleanup{{[{delete_var}](){{ {delete_code}; }}}};\n'
                cleanup += f'{indent}        layer_data->deferred_operation_post_completion.insert({deferred_name}, cleanup);\n'
                cleanup += f'{indent}    }} else {{\n'
                cleanup += f'{indent}        {delete_code};\n'
                cleanup += f'{indent}    }}\n'
            else:
                cleanup += f'{indent}    {delete_code};\n'
            cleanup += f'{indent}}}\n'
        return cleanup

    #
    # topLevel indicates if elements are passed directly into the function else they're below a ptr/struct
    # isCreate means that this is API creates or allocates NDOs
    # isDestroy indicates that this API destroys or frees NDOs
    def uniquifyMembers(self, members: list[Member], indent: str, prefix: str, arrayIndex: int, isCreate: bool, isDestroy: bool, topLevel: bool):
        """Generate C++ that unwraps every NDO reachable from `members`.

        Args:
            members: a command's parameters (topLevel=True) or a struct's members.
            indent: indentation for the emitted code.
            prefix: C++ expression prefix used to reach each member (e.g. 'safe_struct->',
                'local_pInfo->', 'local_pInfos[index0].'); '' for top-level parameters.
            arrayIndex: numeric suffix for the generated loop variable (index0, index1, ...),
                incremented for nested levels so loops do not shadow each other.
            isCreate: at top level, NDO pointer parameters are skipped (the caller wraps the
                created handles separately).
            isDestroy: at top level, a single NDO parameter is not unwrapped here (the caller
                emits the unique_id_mapping removal code for it).
            topLevel: `members` are the command's own parameters rather than struct members.

        Returns:
            (decls, pre_code, post_code): local variable declarations, code to run before the
            down-chain call, and cleanup code to run after it.
        """
        decls = ''
        pre_code = ''
        post_code = ''
        index = f'index{str(arrayIndex)}'
        arrayIndex += 1
        # Process any NDOs in this structure and recurse for any sub-structs in this struct
        for member in members:
            # Handle NDOs
            if self.isNonDispatchable(member.type):
                count_name = member.length
                if (count_name is not None) and not topLevel:
                    count_name = f'{prefix}{member.length}'

                if (not topLevel) or (not isCreate) or (not member.pointer):
                    if count_name is not None:
                        if topLevel:
                            # Small arrays use the stack buffer; larger ones are heap allocated in pre_code
                            decls += f'{indent}{member.type} var_local_{prefix}{member.name}[DISPATCH_MAX_STACK_ALLOCATIONS];\n'
                            decls += f'{indent}{member.type} *local_{prefix}{member.name} = nullptr;\n'
                        pre_code += f'{indent}    if ({prefix}{member.name}) {{\n'
                        indent = incIndent(indent)
                        if topLevel:
                            pre_code += f'{indent}    local_{prefix}{member.name} = {count_name} > DISPATCH_MAX_STACK_ALLOCATIONS ? new {member.type}[{count_name}] : var_local_{prefix}{member.name};\n'
                            pre_code += f'{indent}    for (uint32_t {index} = 0; {index} < {count_name}; ++{index}) {{\n'
                            indent = incIndent(indent)
                            pre_code += f'{indent}    local_{prefix}{member.name}[{index}] = layer_data->Unwrap({member.name}[{index}]);\n'
                        else:
                            pre_code += f'{indent}    for (uint32_t {index} = 0; {index} < {count_name}; ++{index}) {{\n'
                            indent = incIndent(indent)
                            pre_code += f'{indent}    {prefix}{member.name}[{index}] = layer_data->Unwrap({prefix}{member.name}[{index}]);\n'
                        indent = decIndent(indent)
                        pre_code += f'{indent}    }}\n'
                        indent = decIndent(indent)
                        pre_code += f'{indent}    }}\n'
                        if topLevel:
                            # Free only if the heap path was taken, i.e. not the stack buffer
                            post_code += f'{indent}if (local_{prefix}{member.name} != var_local_{prefix}{member.name})\n'
                            indent = incIndent(indent)
                            post_code += f'{indent}delete[] local_{prefix}{member.name};\n'
                            indent = decIndent(indent)
                    else:
                        if topLevel:
                            if not isDestroy:
                                pre_code += f'{indent}    {member.name} = layer_data->Unwrap({member.name});\n'
                        else:
                            # Read the handle from the caller's original struct: drop the literal
                            # leading 'local_' from the prefix. removeprefix() removes exactly that
                            # text, whereas strip('local_') would strip any of the characters
                            # l/o/c/a/_ from both ends of the prefix.
                            fix = prefix.removeprefix('local_')
                            pre_code += f'{indent}    if ({fix}{member.name}) {{\n'
                            indent = incIndent(indent)
                            pre_code += f'{indent}    {prefix}{member.name} = layer_data->Unwrap({fix}{member.name});\n'
                            indent = decIndent(indent)
                            pre_code += f'{indent}    }}\n'
            # Handle Structs that contain NDOs at some level
            elif member.type in self.vk.structs:
                struct = self.vk.structs[member.type]
                process_pnext = struct.extendedBy and any(x in self.ndo_extension_structs for x in struct.extendedBy)
                # Structs at first level will have an NDO, OR, we need a safe_struct for the pnext chain
                if self.containsNonDispatchableObject(member.type) or process_pnext:
                    # Structs with pointer members are deep-copied through their safe_* wrapper
                    safe_type = 'safe_' + member.type if any(x.pointer for x in struct.members) else member.type
                    # Struct Array
                    if member.length is not None:
                        # Check if this function can be deferred.
                        deferred_name = next((x.name for x in members if x.type == 'VkDeferredOperationKHR'), None)
                        # Update struct prefix
                        if topLevel:
                            new_prefix = f'local_{member.name}'
                            # Declare safe_VarType for struct
                            decls += f'{indent}{safe_type} *{new_prefix} = nullptr;\n'
                        else:
                            new_prefix = f'{prefix}{member.name}'
                        pre_code += f'{indent}    if ({prefix}{member.name}) {{\n'
                        indent = incIndent(indent)
                        if topLevel:
                            pre_code += f'{indent}    {new_prefix} = new {safe_type}[{member.length}];\n'
                        pre_code += f'{indent}    for (uint32_t {index} = 0; {index} < {prefix}{member.length}; ++{index}) {{\n'
                        indent = incIndent(indent)
                        if topLevel:
                            if 'safe_' in safe_type:
                                # Handle special initialize function for VkAccelerationStructureBuildGeometryInfoKHR
                                if member.type == "VkAccelerationStructureBuildGeometryInfoKHR":
                                    pre_code += f'{indent}    {new_prefix}[{index}].initialize(&{member.name}[{index}], false, nullptr);\n'
                                else:
                                    pre_code += f'{indent}    {new_prefix}[{index}].initialize(&{member.name}[{index}]);\n'
                            else:
                                pre_code += f'{indent}    {new_prefix}[{index}] = {member.name}[{index}];\n'
                            if process_pnext:
                                pre_code += f'{indent}    WrapPnextChainHandles(layer_data, {new_prefix}[{index}].pNext);\n'
                        local_prefix = f'{new_prefix}[{index}].'
                        # Process sub-structs in this struct
                        (tmp_decl, tmp_pre, tmp_post) = self.uniquifyMembers(struct.members, indent, local_prefix, arrayIndex, isCreate, isDestroy, False)
                        decls += tmp_decl
                        pre_code += tmp_pre
                        post_code += tmp_post
                        indent = decIndent(indent)
                        pre_code += f'{indent}    }}\n'
                        indent = decIndent(indent)
                        pre_code += f'{indent}    }}\n'
                        if topLevel:
                            post_code += self.cleanUpLocalDeclarations(indent, prefix, member.name, member.length, deferred_name)
                    # Single Struct
                    elif member.pointer:
                        # Check if this function can be deferred.
                        deferred_name = next((x.name for x in members if x.type == 'VkDeferredOperationKHR'), None)
                        # Update struct prefix
                        if topLevel:
                            new_prefix = f'local_{member.name}->'
                            # Without a deferred operation the copy can live on the stack;
                            # otherwise it is heap allocated so it can outlive this call.
                            if deferred_name is None:
                                decls += f'{indent}{safe_type} var_local_{prefix}{member.name};\n'
                            decls +=  f'{indent}{safe_type} *local_{prefix}{member.name} = nullptr;\n'
                        else:
                            new_prefix = f'{prefix}{member.name}->'
                        # Declare safe_VarType for struct
                        pre_code += f'{indent}    if ({prefix}{member.name}) {{\n'
                        indent = incIndent(indent)
                        if topLevel:
                            if deferred_name is None:
                                pre_code += f'{indent}    local_{prefix}{member.name} = &var_local_{prefix}{member.name};\n'
                            else:
                                pre_code += f'{indent}    local_{prefix}{member.name} = new {safe_type};\n'
                            if 'safe_' in safe_type:
                                # Handle special initialize function for VkAccelerationStructureBuildGeometryInfoKHR
                                if member.type == "VkAccelerationStructureBuildGeometryInfoKHR":
                                    pre_code += f'{indent}    local_{prefix}{member.name}->initialize({member.name}, false, nullptr);\n'
                                else:
                                    pre_code += f'{indent}    local_{prefix}{member.name}->initialize({member.name});\n'
                            else:
                                pre_code += f'{indent}    *local_{prefix}{member.name} = *{member.name};\n'
                        # Process sub-structs in this struct
                        (tmp_decl, tmp_pre, tmp_post) = self.uniquifyMembers(struct.members, indent, new_prefix, arrayIndex, isCreate, isDestroy, False)
                        decls += tmp_decl
                        pre_code += tmp_pre
                        post_code += tmp_post
                        if process_pnext:
                            pre_code += f'{indent}    WrapPnextChainHandles(layer_data, {new_prefix}pNext);\n'
                        indent = decIndent(indent)
                        pre_code += f'{indent}    }}\n'
                        if topLevel:
                            post_code += self.cleanUpLocalDeclarations(indent, prefix, member.name, member.length, deferred_name)
                    else:
                        # Update struct prefix
                        if topLevel:
                            # No code path here handles a by-value top-level struct parameter that
                            # needs unwrapping; stop generation with a diagnostic (exit status 1)
                            # instead of exiting silently.
                            sys.exit(f'{type(self).__name__}: unsupported by-value top-level struct parameter '
                                     f'{member.type} {member.name} containing non-dispatchable handles')
                        else:
                            new_prefix = f'{prefix}{member.name}.'
                        # Process sub-structs in this struct
                        (tmp_decl, tmp_pre, tmp_post) = self.uniquifyMembers(struct.members, indent, new_prefix, arrayIndex, isCreate, isDestroy, False)
                        decls += tmp_decl
                        pre_code += tmp_pre
                        post_code += tmp_post
                        if process_pnext:
                            pre_code += f'{indent}    WrapPnextChainHandles(layer_data, {prefix}{member.name}.pNext);\n'
        return decls, pre_code, post_code
