4. Kaleidoscope: Adding JIT and Optimizer Support
This chapter of the Kaleidoscope tutorial introduces Just-In-Time (JIT) compilation and simple optimizations of the generated code. As such, this is the first variant of the language implementation where you can actually execute the Kaleidoscope code. Thus, this is a bit more fun than the others as you finally get to see the language working for real!
Constant Folding
If you studied the LLVM IR generated in the previous chapters, you will have seen that it isn't particularly well optimized. There is one case, though, where some nice optimization is done for us automatically.
For example:
def test(x) 1+2+x;
produces the following LLVM IR:
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
ret double %addtmp
}
That's not exactly what the parse tree would suggest. The InstructionBuilder automatically performs an optimization technique known as 'Constant Folding'. This optimization is very important; in fact, many compilers implement the folding directly in the generation of the Abstract Syntax Tree (AST). With LLVM, that isn't necessary as it is automatically provided for you (no extra charge!).
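To make the idea concrete, here is a minimal sketch of folding at build time. The Node, Const, Param, Add and FoldingBuilder types are hypothetical stand-ins invented for this illustration; they are not part of Ubiquity.NET.Llvm, and the real InstructionBuilder is certainly more involved.

```csharp
// Hypothetical stand-ins for IR values; NOT Ubiquity.NET.Llvm types.
public abstract record Node;
public sealed record Const(double Value) : Node;
public sealed record Param(string Name) : Node;
public sealed record Add(Node Left, Node Right) : Node;

public static class FoldingBuilder
{
    // Builds an add, folding immediately when both operands are constants.
    // Anything else is emitted as a regular add node.
    public static Node BuildAdd(Node lhs, Node rhs)
        => (lhs, rhs) is (Const l, Const r)
            ? new Const(l.Value + r.Value)
            : new Add(lhs, rhs);
}
```

Building 1+2+x the way a parser would, as BuildAdd( BuildAdd( new Const( 1 ), new Const( 2 ) ), new Param( "x" ) ), yields an Add of Const( 3 ) and Param( "x" ). The inner add never exists as an instruction, which mirrors the single fadd in the IR above.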
Obviously constant folding isn't the only possible optimization and InstructionBuilder only operates on the individual instructions as they are built. So, there are limits on what InstructionBuilder can do.
For example:
def test(x) (1+2+x)*(x+(1+2));
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
%addtmp1 = fadd double %x, 3.000000e+00
%multmp = fmul double %addtmp, %addtmp1
ret double %multmp
}
In this case the operands of the additions are identical. Ideally this would generate as
temp = x+3; result = temp*temp;
rather than computing x+3 twice. This isn't something that InstructionBuilder alone can do. Ultimately this requires two distinct transformations:
- Re-association of expressions to make the additions lexically identical (e.g. recognize that x+3 == 3+x )
- Common Subexpression Elimination to remove the redundant add instruction.
Fortunately, LLVM provides a very broad set of optimization transformations that can handle this and many other scenarios.
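The two transformations listed above can be sketched together as a toy value-numbering table. This is only an illustration under simplifying assumptions: the ValueNumbering class and its string-based "instructions" are hypothetical, it relies on addition being commutative, and it is not how the LLVM Reassociate or GVN passes are implemented.

```csharp
using System.Collections.Generic;

// Toy value-numbering table; a hypothetical illustration, not an LLVM pass.
public sealed class ValueNumbering
{
    private readonly Dictionary<string, string> Known = new Dictionary<string, string>();
    private int NextTemp;

    // Instructions emitted so far, in order.
    public List<string> Emitted { get; } = new List<string>();

    // Returns the name holding lhs + rhs, emitting an add only when it is new.
    public string Add(string lhs, string rhs)
    {
        // Canonical operand order makes "3 + x" and "x + 3" share one key.
        if (string.CompareOrdinal(lhs, rhs) > 0)
        {
            (lhs, rhs) = (rhs, lhs);
        }

        string key = $"add {lhs}, {rhs}";
        if (Known.TryGetValue(key, out var existing))
        {
            return existing; // redundant computation: reuse the earlier result
        }

        string name = $"t{NextTemp++}";
        Emitted.Add($"{name} = {key}");
        Known.Add(key, name);
        return name;
    }
}
```

Calling Add( "3", "x" ) and then Add( "x", "3" ) on the same instance returns the same temporary name both times and records only one emitted instruction.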
LLVM Optimization Passes
LLVM provides many different optimization passes, each handling a specific scenario with different trade-offs. One of the values of LLVM as a general compilation back-end is that it doesn't enforce any particular set of optimizations. By default, there aren't any optimizations (other than the obvious constant folding built into the InstructionBuilder). All optimizations are entirely in the hands of the front-end application. The compiler implementor controls which passes are applied and in what order they are run. This ensures that the optimizations are tailored to meet the needs of the language and runtime environment.
For Kaleidoscope, optimizations are limited to a single function, as functions are generated when the user types them in on the command line. Ultimately, whole program optimization is off the table (you never know when the user will enter the last expression, so it is incorrect to eliminate unused functions). In order to support per-function optimization, a FunctionPassManager is created to hold the passes used for optimizing a function. The FunctionPassManager supports running the passes to transform a function into the optimized form. Since a pass manager is tied to the module and, for JIT support, each function is generated into its own module, a new method in the code generator is used to create the module and initialize the pass manager.
private void InitializeModuleAndPassManager( )
{
Module = Context.CreateBitcodeModule( );
Module.Layout = JIT.TargetMachine.TargetData;
FunctionPassManager = new FunctionPassManager( Module );
if( !DisableOptimizations )
{
FunctionPassManager.AddInstructionCombiningPass( )
.AddReassociatePass( )
.AddGVNPass( )
.AddCFGSimplificationPass( );
}
FunctionPassManager.Initialize( );
}
Creating the pass manager isn't enough to get the optimizations. Something needs to actually provide the pass manager with the function to optimize. The most sensible place to put that is as the last step of generating the function.
FunctionPassManager.Run( function );
This will run the passes defined when the FunctionPassManager was created, resulting in better generated code.
define double @test(double %x) {
entry:
%addtmp = fadd double %x, 3.000000e+00
%multmp = fmul double %addtmp, %addtmp
ret double %multmp
}
The passes eliminate the redundant add instruction to produce a simpler, yet still correct, representation of the generated code. LLVM provides a wide variety of optimization passes. Unfortunately, not all of them are well documented yet. Looking into what Clang uses is helpful, as is using the LLVM 'opt.exe' tool to run passes individually, or in various combinations and orderings, to see how well they optimize the code your front-end generates. (This can lead to changing the passes and their ordering, as well as changing what the front-end generates so that the optimizer can handle the input better.) This is not an exact science with a one-size-fits-all solution. Many common passes are likely relevant to all languages, though their ordering may differ depending on the needs of the language and runtime. Getting the optimizations and their ordering right for a given language is arguably where the most work lies in creating a production quality language using LLVM.
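Why ordering matters can be shown with a deliberately tiny model. In the sketch below the "IR" is just a list of strings and the two "passes" are plain functions; ToyPipeline and its members are hypothetical names made up for this example and bear no relation to the real FunctionPassManager.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical toy pipeline; the "IR" is a list of strings, not LLVM IR.
public static class ToyPipeline
{
    // Rewrites "3 + x" as "x + 3" so equivalent adds become textually identical.
    public static IReadOnlyList<string> Canonicalize(IReadOnlyList<string> lines)
        => lines.Select(l => l.Replace("3 + x", "x + 3")).ToList();

    // Drops repeated lines, keeping the first occurrence.
    public static IReadOnlyList<string> RemoveDuplicates(IReadOnlyList<string> lines)
        => lines.Distinct().ToList();

    // Applies each pass in the order given.
    public static IReadOnlyList<string> Run(
        IReadOnlyList<string> input,
        params Func<IReadOnlyList<string>, IReadOnlyList<string>>[] passes)
        => passes.Aggregate(input, (current, pass) => pass(current));
}
```

For the input lines "3 + x" and "x + 3", running Canonicalize and then RemoveDuplicates leaves a single line, while running the same two passes in the opposite order leaves two, because the duplicate only becomes visible after canonicalization.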
Adding JIT Compilation
Now that the code generation produces optimized code, it is time to get to the fun part - executing code! The basic idea is to allow the user to type in the Kaleidoscope code as supported thus far and it will execute to produce a result. Unlike the previous chapters, instead of just printing out the LLVM IR representation of a top level expression it is executed and the results are provided back to the user.
Main Driver
The changes needed to the main driver are pretty simple, mostly consisting of removing a couple lines of code that print out the LLVM IR for the module at the end and for each function when defined. The code already supported showing the results if it was a floating point value by checking if the generated value is a ConstantFP. We'll see a bit later on why that is a ConstantFP value.
Code Generator
The code generation needs an update to support using a JIT engine to generate and execute the Kaleidoscope code provided by the user.
To use the optimization transforms, the generator needs a new using declaration for their namespace.
using Ubiquity.NET.Llvm.Transforms;
Generator fields
To begin with, the generator needs some additional members, including the JIT engine.
private readonly DynamicRuntimeState RuntimeState;
private readonly Context Context;
private readonly InstructionBuilder InstructionBuilder;
private readonly IDictionary<string, Value> NamedValues = new Dictionary<string, Value>( );
private FunctionPassManager? FunctionPassManager;
private readonly bool DisableOptimizations;
private BitcodeModule? Module;
private readonly KaleidoscopeJIT JIT = new KaleidoscopeJIT( );
private readonly Dictionary<string, ulong> FunctionModuleMap = new Dictionary<string, ulong>( );
The JIT engine is retained for the generator to use. The same engine is retained for the lifetime of the generator so that functions are added to the same engine and can call functions previously added. The JIT provides a 'handle' for every module added, which is used to reference the module in the JIT. This handle is normally used to remove the module from the JIT engine when re-defining a function. Thus, a map of function names to the JIT handles created for them is maintained. Additionally, a collection of defined function prototypes is retained to enable matching a function call to a previously defined function. Since the JIT support uses a module per function approach, lookups on the current module aren't sufficient.
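The handle bookkeeping described above can be sketched in isolation. FakeJit and FunctionRegistry below are hypothetical stand-ins written for this illustration; a string takes the place of a module, and nothing here calls the real KaleidoscopeJIT API.

```csharp
using System.Collections.Generic;

// Hypothetical stand-in for a JIT that hands out one handle per added module.
public sealed class FakeJit
{
    private readonly Dictionary<ulong, string> Modules = new Dictionary<ulong, string>();
    private ulong NextHandle = 1;

    public int ModuleCount => Modules.Count;

    public ulong AddModule(string body)
    {
        ulong handle = NextHandle++;
        Modules.Add(handle, body);
        return handle;
    }

    public void RemoveModule(ulong handle) => Modules.Remove(handle);

    public string GetBody(ulong handle) => Modules[handle];
}

// Tracks which handle currently backs each named function.
public sealed class FunctionRegistry
{
    private readonly FakeJit Jit;
    private readonly Dictionary<string, ulong> FunctionModuleMap = new Dictionary<string, ulong>();

    public FunctionRegistry(FakeJit jit) => Jit = jit;

    // Adds the module for a named function, removing any earlier definition first.
    public void Define(string name, string body)
    {
        if (FunctionModuleMap.Remove(name, out ulong oldHandle))
        {
            Jit.RemoveModule(oldHandle);
        }

        FunctionModuleMap.Add(name, Jit.AddModule(body));
    }

    public string Lookup(string name) => Jit.GetBody(FunctionModuleMap[name]);
}
```

Defining the same name twice leaves exactly one module in the stand-in JIT, and a lookup returns the second body. Without the removal step the first module would linger, which is the situation the name-to-handle map exists to avoid.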
Generator initialization
The initialization of the generator requires updating to support the new members.
public CodeGenerator( DynamicRuntimeState globalState, bool disableOptimization = false, TextWriter? outputWriter = null )
: base( null )
{
JIT.OutputWriter = outputWriter ?? Console.Out;
globalState.ValidateNotNull( nameof( globalState ) );
if( globalState.LanguageLevel > LanguageLevel.SimpleExpressions )
{
throw new ArgumentException( "Language features not supported by this generator", nameof( globalState ) );
}
RuntimeState = globalState;
Context = new Context( );
DisableOptimizations = disableOptimization;
InitializeModuleAndPassManager( );
InstructionBuilder = new InstructionBuilder( Context );
}
The bool indicating whether optimizations are disabled is stored, and an initial module and pass manager are created.
The option to disable optimizations is useful for debugging the code generation itself as optimizations can alter or even eliminate incorrectly generated code. Thus, when modifying the generation itself, it is useful to disable the optimizations.
JIT Engine
The JIT engine itself is a class provided in the Kaleidoscope.Runtime library derived from the Ubiquity.NET.Llvm OrcJIT engine.
// -----------------------------------------------------------------------
// <copyright file="KaleidoscopeJIT.cs" company="Ubiquity.NET Contributors">
// Copyright (c) Ubiquity.NET Contributors. All rights reserved.
// </copyright>
// -----------------------------------------------------------------------
using System;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Runtime.InteropServices;
using Ubiquity.NET.Llvm;
using Ubiquity.NET.Llvm.JIT;
namespace Kaleidoscope.Runtime
{
/// <summary>JIT engine for Kaleidoscope language</summary>
/// <remarks>
/// This engine uses the <see cref="Ubiquity.NET.Llvm.JIT.OrcJit"/> engine to support lazy
/// compilation of LLVM IR modules added to the JIT.
/// </remarks>
public sealed class KaleidoscopeJIT
: OrcJit
{
/// <summary>Initializes a new instance of the <see cref="KaleidoscopeJIT"/> class.</summary>
public KaleidoscopeJIT( )
: base( BuildTargetMachine( ) )
{
AddInteropCallback( "putchard", new CallbackHandler1( PutChard ) );
AddInteropCallback( "printd", new CallbackHandler1( Printd ) );
}
/// <summary>Gets or sets the output writer for output from the program.</summary>
/// <remarks>The default writer is <see cref="Console.Out"/>.</remarks>
public TextWriter OutputWriter { get; set; } = Console.Out;
/// <summary>Delegate for an interop callback taking no parameters</summary>
/// <returns>value for the function</returns>
[UnmanagedFunctionPointer( CallingConvention.Cdecl )]
public delegate double CallbackHandler0( );
/// <summary>Delegate for an interop callback taking one parameter</summary>
/// <param name="arg1">First parameter</param>
/// <returns>value for the function</returns>
[UnmanagedFunctionPointer( CallingConvention.Cdecl )]
public delegate double CallbackHandler1( double arg1 );
/// <summary>Delegate for an interop callback taking two parameters</summary>
/// <param name="arg1">First parameter</param>
/// <param name="arg2">Second parameter</param>
/// <returns>value for the function</returns>
[UnmanagedFunctionPointer( CallingConvention.Cdecl )]
public delegate double CallbackHandler2( double arg1, double arg2 );
/// <summary>Delegate for an interop callback taking three parameters</summary>
/// <param name="arg1">First parameter</param>
/// <param name="arg2">Second parameter</param>
/// <param name="arg3">Third parameter</param>
/// <returns>value for the function</returns>
[UnmanagedFunctionPointer( CallingConvention.Cdecl )]
public delegate double CallbackHandler3( double arg1, double arg2, double arg3 );
/// <summary>Delegate for an interop callback taking four parameters</summary>
/// <param name="arg1">First parameter</param>
/// <param name="arg2">Second parameter</param>
/// <param name="arg3">Third parameter</param>
/// <param name="arg4">Fourth parameter</param>
/// <returns>value for the function</returns>
[UnmanagedFunctionPointer( CallingConvention.Cdecl )]
public delegate double CallbackHandler4( double arg1, double arg2, double arg3, double arg4 );
[SuppressMessage( "Design", "CA1031:Do not catch general exception types", Justification = "Native callback *MUST NOT* surface managed exceptions" )]
private double Printd( double x )
{
// STOP ALL EXCEPTIONS from bubbling out to JIT'ed code
try
{
OutputWriter.WriteLine( x );
return 0.0;
}
catch
{
return 0.0;
}
}
[SuppressMessage( "Design", "CA1031:Do not catch general exception types", Justification = "Native callback *MUST NOT* surface managed exceptions" )]
private double PutChard( double x )
{
// STOP ALL EXCEPTIONS from bubbling out to JIT'ed code
try
{
OutputWriter.Write( ( char )x );
return 0.0;
}
catch
{
return 0.0;
}
}
private static TargetMachine BuildTargetMachine( )
{
string hostTriple = Triple.HostTriple.ToString( );
return Target.FromTriple( hostTriple )
.CreateTargetMachine( hostTriple
, /*cpu*/null
, /*features*/null
, CodeGenOpt.Default
, RelocationMode.Default
, CodeModel.JitDefault
);
}
}
}
OrcJit provides support for declaring functions that are external to the JIT that the JIT'd module code can call. For Kaleidoscope, two such functions are defined directly in KaleidoscopeJIT (putchard and printd). These are consistent with the same functions used in the official LLVM C++ tutorial, which allows sharing of samples between the two. These functions are used to provide rudimentary console output support.
Warning
All such methods implemented in .NET must block any exception from bubbling out of the call, as the JIT engine doesn't know anything about them and neither does the Kaleidoscope language. Exceptions thrown in these functions would produce undefined results, most likely crashing the application.
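The guard pattern from the warning can also be factored into a reusable wrapper. The sketch below is an assumption-laden illustration: CallbackGuard is a hypothetical helper, and it uses Func<double, double> for brevity, whereas a callback actually registered with the JIT would still need a delegate type such as CallbackHandler1 with the appropriate calling convention.

```csharp
using System;

// Hypothetical helper; not part of Kaleidoscope.Runtime.
public static class CallbackGuard
{
    // Wraps a managed function so that no exception can escape to the caller.
    public static Func<double, double> Wrap(Func<double, double> inner, double fallback = 0.0)
        => x =>
        {
            try
            {
                return inner(x);
            }
            catch
            {
                // Swallow everything; JIT'd native code cannot handle managed exceptions.
                return fallback;
            }
        };
}
```

A wrapped function that throws simply returns the fallback value, while one that succeeds returns its normal result. Swallowing every exception is normally poor practice in .NET, which is why the listing above suppresses CA1031 explicitly for these callbacks.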
PassManager
Every time a new function definition is processed, the generator creates a new module and initializes the function pass manager for the module. This is done in a new method, InitializeModuleAndPassManager():
private void InitializeModuleAndPassManager( )
{
Module = Context.CreateBitcodeModule( );
Module.Layout = JIT.TargetMachine.TargetData;
FunctionPassManager = new FunctionPassManager( Module );
if( !DisableOptimizations )
{
FunctionPassManager.AddInstructionCombiningPass( )
.AddReassociatePass( )
.AddGVNPass( )
.AddCFGSimplificationPass( );
}
FunctionPassManager.Initialize( );
}
The module creation is pretty straightforward. The important detail is that the layout information is pulled from the target machine for the JIT and applied to the module.
Once the module is created, the FunctionPassManager is constructed. If optimizations are not disabled, the optimization passes are added to the pass manager. The set of passes used is a very basic set since the Kaleidoscope language isn't particularly complex at this point.
Generator Dispose
Since the JIT engine is disposable, the code generator's Dispose() method must now call the Dispose() method on the JIT engine.
public void Dispose( )
{
JIT.Dispose( );
Module?.Dispose( );
Context.Dispose( );
}
Generate Method
To actually execute the code the generated modules are added to the JIT. If the function is an anonymous top level expression, it is eagerly compiled and a delegate is retrieved from the JIT to allow calling the compiled function directly. The delegate is then called to get the result. Once an anonymous function produces a value, it is no longer used so is removed from the JIT and the result value returned. For other functions the module is added to the JIT and the function is returned.
For named function definitions, the module is lazily added to the JIT as it isn't known if or when the function will be called. The JIT engine compiles lazily added modules into native code on first use. (Though if the function is never used, then creating the IR module was wasted. Chapter 7.1 has a solution for even that extra overhead: truly lazy JIT.) Since Kaleidoscope is generally a dynamic language, it is possible and reasonable for the user to re-define a function (to fix an error, or to provide a completely different implementation altogether). Therefore, any named function is removed from the JIT, if it existed, before adding the new definition. Otherwise the JIT resolver would still resolve to the previously compiled instance.
public OptionalValue<Value> Generate( IAstNode ast )
{
ast.ValidateNotNull( nameof( ast ) );
// Prototypes, including extern are ignored as AST generation
// adds them to the RuntimeState so that already has the declarations
if( !( ast is FunctionDefinition definition ) )
{
return default;
}
InitializeModuleAndPassManager( );
Debug.Assert( !( Module is null ), "Module initialization failed" );
var function = ( IrFunction )(definition.Accept( this ) ?? throw new CodeGeneratorException(ExpectValidFunc));
if( definition.IsAnonymous )
{
// eagerly compile modules for anonymous functions as calling the function is the guaranteed next step
ulong jitHandle = JIT.AddEagerlyCompiledModule( Module );
var nativeFunc = JIT.GetFunctionDelegate<KaleidoscopeJIT.CallbackHandler0>( definition.Name );
var retVal = Context.CreateConstant( nativeFunc( ) );
JIT.RemoveModule( jitHandle );
return OptionalValue.Create<Value>( retVal );
}
else
{
// Destroy any previously generated module for this function.
// This allows re-definition as the new module will provide the
// implementation. This is needed, otherwise both the MCJIT
// and OrcJit engines will resolve to the original module, despite
// claims to the contrary in the official tutorial text. (Though,
// to be fair it may have been true in the original JIT and might
// still be true for the interpreter)
if( FunctionModuleMap.Remove( definition.Name, out ulong handle ) )
{
JIT.RemoveModule( handle );
}
// Unknown if any future input will call the function so add it for lazy compilation.
// Native code is generated for the module automatically only when required.
ulong jitHandle = JIT.AddLazyCompiledModule( Module );
FunctionModuleMap.Add( definition.Name, jitHandle );
return OptionalValue.Create<Value>( function );
}
}
Keeping all the JIT interaction in the Generate method isolates the rest of the generation from any awareness of the JIT. This will help when adding truly lazy JIT compilation in Chapter 7.1 and AOT compilation in Chapter 8.
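The difference between eager and lazy compilation used in the Generate method can be pictured with the standard Lazy<T> type. This is only an analogy: LazyFunction is a hypothetical class, the "compile" step is an ordinary delegate, and the real JIT defers native code generation through its own mechanisms.

```csharp
using System;

// Hypothetical analogy for a lazily compiled function; not a JIT API.
public sealed class LazyFunction
{
    private readonly Lazy<Func<double>> Compiled;

    // Counts how many times the (pretend) compilation ran.
    public int CompileCount { get; private set; }

    public LazyFunction(Func<Func<double>> compile)
    {
        // The factory runs only when Value is first read.
        Compiled = new Lazy<Func<double>>(() =>
        {
            CompileCount++;
            return compile();
        });
    }

    // The first call triggers compilation; later calls reuse the result.
    public double Invoke() => Compiled.Value();
}
```

Constructing a LazyFunction does no compilation work at all. The first Invoke() pays the cost once, and a function that is never invoked never pays it, which is the trade-off that makes lazy addition a sensible default for named definitions.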
Function call expressions
Since functions are no longer collected into a single module, the code to find the target for a function call requires updating to look up the function from a collection of functions mapped by name.
public override Value? Visit( FunctionCallExpression functionCall )
{
if( Module is null )
{
throw new InvalidOperationException( "Can't visit a function call without an active module" );
}
functionCall.ValidateNotNull( nameof( functionCall ) );
string targetName = functionCall.FunctionPrototype.Name;
IrFunction? function;
if( RuntimeState.FunctionDeclarations.TryGetValue( targetName, out Prototype target ) )
{
function = GetOrDeclareFunction( target );
}
else if( !Module.TryGetFunction( targetName, out function ) )
{
throw new CodeGeneratorException( $"Definition for function {targetName} not found" );
}
var args = ( from expr in functionCall.Arguments
select expr.Accept( this ) ?? throw new CodeGeneratorException(ExpectValidExpr)
).ToArray();
return InstructionBuilder.Call( function, args ).RegisterName( "calltmp" );
}
This will look up the function prototype by name and call GetOrDeclareFunction() with the prototype found. If the prototype wasn't found, then it falls back to the previous lookup in the current module. This fallback is needed to support recursive functions, where the referenced function actually is in the current module.
GetOrDeclareFunction()
The next step is to update GetOrDeclareFunction() to handle mapping the function's prototype and the re-definition of functions.
// Retrieves a Function for a prototype from the current module if it exists,
// otherwise declares the function and returns the newly declared function.
private IrFunction GetOrDeclareFunction( Prototype prototype )
{
if( Module is null )
{
throw new InvalidOperationException( "ICE: Can't get or declare a function without an active module" );
}
if( Module.TryGetFunction( prototype.Name, out IrFunction? function ) )
{
return function;
}
var llvmSignature = Context.GetFunctionType( Context.DoubleType, prototype.Parameters.Select( _ => Context.DoubleType ) );
var retVal = Module.CreateFunction( prototype.Name, llvmSignature );
int index = 0;
foreach( var argId in prototype.Parameters )
{
retVal.Parameters[ index ].Name = argId.Name;
++index;
}
return retVal;
}
This distinguishes the special case of an anonymous top level expression as those are never added to the prototype maps. They are only in the JIT engine long enough to execute once and are then removed. Since they are, by definition, anonymous they can never be referenced by anything else.
Function Definitions
Visiting a function definition needs to add a call to the function pass manager to run the optimization passes for the function. It makes sense to do this immediately after completing the generation of the function.
public override Value? Visit( FunctionDefinition definition )
{
definition.ValidateNotNull( nameof( definition ) );
var function = GetOrDeclareFunction( definition.Signature );
if( !function.IsDeclaration )
{
throw new CodeGeneratorException( $"Function {function.Name} cannot be redefined in the same module" );
}
try
{
var entryBlock = function.AppendBasicBlock( "entry" );
InstructionBuilder.PositionAtEnd( entryBlock );
NamedValues.Clear( );
foreach( var param in definition.Signature.Parameters )
{
NamedValues[ param.Name ] = function.Parameters[ param.Index ];
}
var funcReturn = definition.Body.Accept( this ) ?? throw new CodeGeneratorException( ExpectValidFunc );
InstructionBuilder.Return( funcReturn );
function.Verify( );
FunctionPassManager?.Run( function );
return function;
}
catch( CodeGeneratorException )
{
function.EraseFromParent( );
throw;
}
}
Conclusion
While the number of words needed to describe the changes to support optimization and JIT execution here isn't exactly small, the actual code changes required really are. The Parser and JIT engine do all the heavy lifting. Ubiquity.NET.Llvm provides a clean interface to the JIT that fits with common patterns and runtime support for .NET. Very cool, indeed!