From d316d248ab2876653d53ad083aac255674f442c6 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Thu, 14 Mar 2019 02:43:40 -0300
Subject: [PATCH] vk_shader_decompiler: Implement non-OperationCode visits

---
 .../renderer_vulkan/vk_shader_decompiler.cpp  | 136 +++++++++++++++++-
 1 file changed, 129 insertions(+), 7 deletions(-)

diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index e4c3e3d9c..5060dbba9 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -494,25 +494,136 @@ private:
             return (this->*decompiler)(*operation);
 
         } else if (const auto gpr = std::get_if<GprNode>(node)) {
-            UNIMPLEMENTED();
+            const u32 index = gpr->GetIndex();
+            if (index == Register::ZeroIndex) {
+                return Constant(t_float, 0.0f);
+            }
+            return Emit(OpLoad(t_float, registers.at(index)));
 
         } else if (const auto immediate = std::get_if<ImmediateNode>(node)) {
-            UNIMPLEMENTED();
+            return BitcastTo<Type::Float>(Constant(t_uint, immediate->GetValue()));
 
         } else if (const auto predicate = std::get_if<PredicateNode>(node)) {
-            UNIMPLEMENTED();
+            const auto value = [&]() -> Id {
+                switch (const auto index = predicate->GetIndex(); index) {
+                case Tegra::Shader::Pred::UnusedIndex:
+                    return v_true;
+                case Tegra::Shader::Pred::NeverExecute:
+                    return v_false;
+                default:
+                    return Emit(OpLoad(t_bool, predicates.at(index)));
+                }
+            }();
+            if (predicate->IsNegated()) {
+                return Emit(OpLogicalNot(t_bool, value));
+            }
+            return value;
 
         } else if (const auto abuf = std::get_if<AbufNode>(node)) {
-            UNIMPLEMENTED();
+            const auto attribute = abuf->GetIndex();
+            const auto element = abuf->GetElement();
+
+            switch (attribute) {
+            case Attribute::Index::Position:
+                if (stage != ShaderStage::Fragment) {
+                    UNIMPLEMENTED();
+                    break;
+                } else {
+                    if (element == 3) {
+                        return Constant(t_float, 1.0f);
+                    }
+                    return Emit(OpLoad(t_float, AccessElement(t_in_float, frag_coord, element)));
+                }
+            case Attribute::Index::TessCoordInstanceIDVertexID:
+                // TODO(Subv): Find out what the values are for the first two elements when inside a
+                // vertex shader, and what's the value of the fourth element when inside a Tess Eval
+                // shader.
+                ASSERT(stage == ShaderStage::Vertex);
+                switch (element) {
+                case 2:
+                    return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, instance_index)));
+                case 3:
+                    return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, vertex_index)));
+                }
+                UNIMPLEMENTED_MSG("Unmanaged TessCoordInstanceIDVertexID element={}", element);
+                return Constant(t_float, 0);
+            case Attribute::Index::FrontFacing:
+                // TODO(Subv): Find out what the values are for the other elements.
+                ASSERT(stage == ShaderStage::Fragment);
+                if (element == 3) {
+                    const Id is_front_facing = Emit(OpLoad(t_bool, front_facing));
+                    const Id true_value =
+                        BitcastTo<Type::Float>(Constant(t_int, static_cast<s32>(-1)));
+                    const Id false_value = BitcastTo<Type::Float>(Constant(t_int, 0));
+                    return Emit(OpSelect(t_float, is_front_facing, true_value, false_value));
+                }
+                UNIMPLEMENTED_MSG("Unmanaged FrontFacing element={}", element);
+                return Constant(t_float, 0.0f);
+            default:
+                if (IsGenericAttribute(attribute)) {
+                    const Id pointer =
+                        AccessElement(t_in_float, input_attributes.at(attribute), element);
+                    return Emit(OpLoad(t_float, pointer));
+                }
+                break;
+            }
+            UNIMPLEMENTED_MSG("Unhandled input attribute: {}", static_cast<u32>(attribute));
 
         } else if (const auto cbuf = std::get_if<CbufNode>(node)) {
-            UNIMPLEMENTED();
+            const Node offset = cbuf->GetOffset();
+            const Id buffer_id = constant_buffers.at(cbuf->GetIndex());
+
+            Id buffer_index{};
+            Id buffer_element{};
+
+            if (const auto immediate = std::get_if<ImmediateNode>(offset)) {
+                // Direct access
+                const u32 offset_imm = immediate->GetValue();
+                ASSERT(offset_imm % 4 == 0);
+                buffer_index = Constant(t_uint, offset_imm / 16);
+                buffer_element = Constant(t_uint, (offset_imm / 4) % 4);
+
+            } else if (std::holds_alternative<OperationNode>(*offset)) {
+                // Indirect access
+                // TODO(Rodrigo): Use a uniform buffer stride of 4 and drop this slow math (which
+                // emits sub-optimal code on GLSL from my testing).
+                const Id offset_id = BitcastTo<Type::Uint>(Visit(offset));
+                const Id unsafe_offset = Emit(OpUDiv(t_uint, offset_id, Constant(t_uint, 4)));
+                const Id final_offset = Emit(
+                    OpUMod(t_uint, unsafe_offset, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS - 1)));
+                buffer_index = Emit(OpUDiv(t_uint, final_offset, Constant(t_uint, 4)));
+                buffer_element = Emit(OpUMod(t_uint, final_offset, Constant(t_uint, 4)));
+
+            } else {
+                UNREACHABLE_MSG("Unmanaged offset node type");
+            }
+
+            const Id pointer = Emit(OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0),
+                                                  buffer_index, buffer_element));
+            return Emit(OpLoad(t_float, pointer));
 
         } else if (const auto gmem = std::get_if<GmemNode>(node)) {
-            UNIMPLEMENTED();
+            const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor());
+            const Id real = BitcastTo<Type::Uint>(Visit(gmem->GetRealAddress()));
+            const Id base = BitcastTo<Type::Uint>(Visit(gmem->GetBaseAddress()));
+
+            Id offset = Emit(OpISub(t_uint, real, base));
+            offset = Emit(OpUDiv(t_uint, offset, Constant(t_uint, 4u)));
+            return Emit(OpLoad(t_float, Emit(OpAccessChain(t_gmem_float, gmem_buffer,
+                                                           Constant(t_uint, 0u), offset))));
 
         } else if (const auto conditional = std::get_if<ConditionalNode>(node)) {
-            UNIMPLEMENTED();
+            // It's invalid to call conditional on nested nodes, use an operation instead
+            const Id true_label = OpLabel();
+            const Id skip_label = OpLabel();
+            Emit(OpBranchConditional(Visit(conditional->GetCondition()), true_label, skip_label));
+            Emit(true_label);
+
+            VisitBasicBlock(conditional->GetCode());
+
+            Emit(OpBranch(skip_label));
+            Emit(skip_label);
+            return {};
 
         } else if (const auto comment = std::get_if<CommentNode>(node)) {
             Name(Emit(OpUndef(t_void)), comment->GetText());
@@ -719,6 +830,17 @@ private:
         return false;
     }
 
+    template <typename... Args>
+    Id AccessElement(Id pointer_type, Id composite, Args... elements_) {
+        std::vector<Id> members;
+        auto elements = {elements_...};
+        for (const auto element : elements) {
+            members.push_back(Constant(t_uint, element));
+        }
+
+        return Emit(OpAccessChain(pointer_type, composite, members));
+    }
+
     template <Type type>
     Id VisitOperand(Operation operation, std::size_t operand_index) {
         const Id value = Visit(operation[operand_index]);