Gate transform
The purpose of the gate transform is to remove conditional statements in MSL by replacing them with equivalent code involving duplicated variables. (Some if statements do remain, but these have a special form and are used only to indicate to the MessagePassingTransform that an evidence message is needed.)
There are 3 kinds of conditional statement allowed in MSL:
- An `if` statement whose condition is a boolean variable or the complement of a boolean variable.
- An `if` statement whose condition is of the form `(x==value)`, where `value` is a literal boolean or integer and `x` is a variable or array element. There must be a separate `if` for each value. This is generated by `Variable.Case` and abbreviated as a `case` statement.
- An `if` statement whose condition is of the form `(x==i)`, where `i` is a loop counter ranging over all values of `x`, `x` is an integer variable or array element, and `x` is initialized outside the loop. This is generated by `Variable.Switch` and abbreviated as a `switch` statement.
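All three kinds are turned into tests on a case array. The following is a minimal Python sketch of the deterministic reading of `Gate.Cases`, consistent with the examples on this page. For a boolean condition, element 0 is on when the condition is true and element 1 when it is false. For an integer condition, element k is on when the condition equals k. The function names and the explicit `num_values` parameter are hypothetical. The real factor works on distributions, not values.

```python
from typing import List


def gate_cases_bool(c: bool) -> List[bool]:
    """Hypothetical deterministic analogue of Gate.Cases for a boolean condition.

    Element 0 stands for the 'then' branch (c is true) and element 1 for the
    'else' branch (c is false), matching c_cases[0] / c_cases[1].
    """
    return [c, not c]


def gate_cases_int(i: int, num_values: int) -> List[bool]:
    """Hypothetical deterministic analogue of Gate.Cases for an integer condition.

    Element k is true exactly when i == k, so both `if(i==k)` (case) and
    `if(i==j)` inside a loop over j (switch) become tests on i_cases[...].
    """
    if not 0 <= i < num_values:
        raise ValueError("condition value out of range")
    return [i == k for k in range(num_values)]


if __name__ == "__main__":
    print(gate_cases_bool(True))   # [True, False]
    print(gate_cases_int(2, 3))    # [False, False, True]
```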
For each of the three kinds of conditional statement, the transform does the following:
- The condition variable is split into multiple boolean variables, one for each case. These are stored in an array called `c_cases`, which is initialized by `Gate.Cases`. The `c_cases` array is given the `DoNotSendEvidence` attribute.
- Each statement inside the body is wrapped by a new `if` on one of the case variables above. If the statement to be wrapped is a `for` loop, the wrapping is applied to its body instead (as if the loop were unrolled). After the transform, these wrapped `if`s are all that remains of the original conditional. In particular, there are no more `case` or `switch` statements.
- Enter variable: a random variable that is used inside the conditional but initialized outside it. The transform splits an enter variable into multiple clones, one for each case. The clones are initialized by `Gate.Enter`, `Gate.EnterPartial`, or `Gate.EnterOne` as appropriate. `Gate.Enter` creates a clone for every case. `Gate.EnterPartial` creates clones for a subset of cases, for example when the variable is used only in the `else` clause. `Gate.EnterOne` creates only one clone, for situations where the compiler can easily determine that the variable must be used in exactly one case.
- Exit variable: a random variable that is initialized inside the conditional and subsequently used outside it. An exit variable is first defined in terms of separate clones, one inside each case. The clones are then merged by `Gate.Exit`.
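The wrapping rule described above, including the special handling of `for` loops, can be sketched as a tiny source-to-source pass over a toy syntax tree. This Python sketch rests on several assumptions. `Stmt`, `For`, `If`, `wrap_case` and `render` are hypothetical names. Cloning of enter variables is not modelled. Adjacent wrappers on the same case are merged so that the output reads like the example tables; the text does not say whether the real transform does this.

```python
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class Stmt:
    """A plain statement, kept as source text."""
    text: str


@dataclass
class For:
    """A for loop: header text plus a body of nodes."""
    header: str
    body: List["Node"] = field(default_factory=list)


@dataclass
class If:
    """A wrapper `if(condition) { ... }` on a case variable."""
    condition: str
    body: List["Node"] = field(default_factory=list)


Node = Union[Stmt, For, If]


def wrap_case(body: List[Node], case_expr: str) -> List[Node]:
    """Wrap each statement of a conditional body in if(case_expr).

    For loops are not wrapped themselves: the wrapping is pushed into their
    bodies, as if the loop were unrolled. Adjacent wrappers on the same case
    are merged (a presentation choice of this sketch).
    """
    out: List[Node] = []
    for node in body:
        if isinstance(node, For):
            out.append(For(node.header, wrap_case(node.body, case_expr)))
        elif out and isinstance(out[-1], If) and out[-1].condition == case_expr:
            out[-1].body.append(node)
        else:
            out.append(If(case_expr, [node]))
    return out


def render(nodes: List[Node]) -> str:
    """Print nodes on one line, in the style of the example tables."""
    parts = []
    for node in nodes:
        if isinstance(node, Stmt):
            parts.append(node.text)
        elif isinstance(node, For):
            parts.append(f"for({node.header}) {{ {render(node.body)} }}")
        else:
            parts.append(f"if({node.condition}) {{ {render(node.body)} }}")
    return " ".join(parts)


if __name__ == "__main__":
    loop = [For("int i = 0; i < 3; i++",
                [Stmt("Constrain.EqualRandom(array[i], constDist);")])]
    # for(int i = 0; i < 3; i++) { if(c_cases[0]) { Constrain.EqualRandom(array[i], constDist); } }
    print(render(wrap_case(loop, "c_cases[0]")))
```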
The reason for the variable cloning process is to ensure that no case of the conditional refers to a variable outside that case. Otherwise, information could flow freely out of the case. We don't want information to leave a case unless that case is switched on, and we don't know whether it is until we perform inference. We also don't want information flowing between mutually exclusive cases of the same conditional. So we use Enter and Exit factors to control the flow of information.
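On a model this small, the effect of gating can be checked by brute force. The Python sketch below enumerates every joint state of the exit-variable example given later on this page and computes exact posteriors, which gives a reference answer to compare inference results against. Two parameters are assumptions: the prior `pc` on `c`, and the Bernoulli parameter `r`, which stands in for `constDist`. The constraint multiplies only the weights of the `c`-true states. It therefore moves the posterior of `c` only through that case's evidence, and leaves the `else` case's distribution for `x` untouched.

```python
from itertools import product
from typing import Tuple


def bern(prob_true: float, value: bool) -> float:
    """Probability of `value` under a Bernoulli(prob_true) distribution."""
    return prob_true if value else 1.0 - prob_true


def exit_example_posterior(pc: float, p: float, q: float, r: float) -> Tuple[float, float]:
    """Exact posteriors for the exit-variable example, by enumeration.

    Assumed model: c ~ Bernoulli(pc);
      if c:   x ~ Bernoulli(p), and x is constrained by Bernoulli(r)
      else:   x ~ Bernoulli(q)
    Returns (P(c = true), P(x = true)) after conditioning on the constraint.
    """
    weights = {}
    for c, x in product([True, False], repeat=2):
        if c:
            # The constraint lives inside the 'then' case only.
            w = pc * bern(p, x) * bern(r, x)
        else:
            # The 'else' case never sees the constraint.
            w = (1.0 - pc) * bern(q, x)
        weights[(c, x)] = w
    z = sum(weights.values())  # model evidence
    post_c = (weights[(True, True)] + weights[(True, False)]) / z
    post_x = (weights[(True, True)] + weights[(False, True)]) / z
    return post_c, post_x


if __name__ == "__main__":
    print(exit_example_posterior(pc=0.5, p=0.5, q=0.5, r=0.9))
```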
The implementation of GateTransform:
- GateAnalysisTransform attaches a GateBlock attribute to all conditional statements. A GateBlock attribute describes all external variables used or defined inside the block. For arrays, the uses of an array must be grouped into disjoint sets of indices, so that each use is covered by one of the sets. Each use is described by an indexing pattern and a set of bindings (assignments to condition variables).
- ConvertCondition parses the condition expression and checks whether a ConditionInformation already exists for the condition variable. If not, it creates a ConditionInformation object and inserts code to define the cases array into the output (see ConditionInformation.Build). ConvertCondition calls ConvertConditionCase to convert the Then and Else bodies. If the condition is deterministic, there is no need to collect evidence, so instead of wrapping each statement, the transform wraps the entire converted block with the condition.
- ConvertConditionCase wraps every statement in the block with a conditional so that evidence can be collected.
- ConvertVariableDecl attaches the current ConditionInformation as an attribute to the declaration. Any variable with this attribute is a local variable of the condition.
- ConvertVariableRefExpr and ConvertArrayIndexer call ReplaceWithClone.
- ReplaceWithClone checks whether the expression contains a reference to a stochastic non-local variable. If so, it calls GetClone.
- ConditionInformation.GetClone fetches a ClonedVarInfo from the CondInfo.cloneMap, or creates one if needed. When it creates a clone, the new declaration is placed in a sufficient set of containers. That is, if the expression being cloned refers to a local variable of a container, the clone must be declared inside that container. CloneVarInfo.GetCloneForCase returns the new expression.
- ConditionInformation.CreateCloneArray generates code to define the clone array and adds it outside the condition statement. This code is wrapped with a condition constructed from the bindings where the variable is used.
- CloneVarInfo.GetCloneForCase generates code to extract an element of the clone array. Because this new variable must send evidence to the cases array, the code is placed inside the condition statement.
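The GetClone caching behaviour can be pictured with a small Python sketch. The first stochastic, non-local reference to a variable creates a clone array, named as in the examples (`x_cond_c`), and later references reuse it. Everything here is a hypothetical simplification: `ConditionInfo`, `replace_with_clone`, the `slot` argument, and the representation of expressions as plain names. The sketch ignores the indexing patterns, bindings and container placement described above.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class ConditionInfo:
    """Hypothetical, much-simplified stand-in for ConditionInformation."""

    cond_name: str                                        # e.g. "c" or "i"
    local_vars: Set[str] = field(default_factory=set)     # declared inside the conditional
    clone_map: Dict[str, str] = field(default_factory=dict)
    clone_decls: List[str] = field(default_factory=list)

    def get_clone(self, var: str) -> str:
        """Return the clone-array name for var, creating it on first use only."""
        if var not in self.clone_map:
            clone = f"{var}_cond_{self.cond_name}"
            self.clone_map[var] = clone
            # The real transform would emit the clone array declaration
            # outside the conditional here; this sketch only records its name.
            self.clone_decls.append(clone)
        return self.clone_map[var]


def replace_with_clone(info: ConditionInfo, var: str, stochastic: bool, slot: int) -> str:
    """Rewrite a reference to var made inside one case of the conditional.

    Only stochastic variables declared outside the conditional are replaced
    by an element of their clone array; locals and constants are kept.
    """
    if not stochastic or var in info.local_vars:
        return var
    return f"{info.get_clone(var)}[{slot}]"


if __name__ == "__main__":
    info = ConditionInfo("c", local_vars={"b"})
    print(replace_with_clone(info, "x", True, 0))          # x_cond_c[0]
    print(replace_with_clone(info, "constDist", False, 0))  # constDist
```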
Examples follow.
If statement
Input | Output |
---|---|
if(c) { bool b = Factor.Bernoulli(0.1); Constrain.True(b); } | bool[] c_cases = new bool[2]; c_cases = Gate.Cases(c); if(c_cases[0]) { bool b = Factor.Bernoulli(0.1); Constrain.True(b); } |
if(c) { Constrain.EqualRandom(x, constDist); } | bool[] c_cases = new bool[2]; c_cases = Gate.Cases(c); bool[] x_cond_c = new bool[1]; x_cond_c = Gate.EnterPartial(c_cases, x, 0); if(c_cases[0]) { Constrain.EqualRandom(x_cond_c[0], constDist); } |
if(c) { double sum = Factor.Sum(array); } | bool[] c_cases = new bool[2]; c_cases = Gate.Cases(c); bool[][] array_cond_c = new bool[1][]; for(int _gateind = 0; _gateind < 1; _gateind++) { array_cond_c[_gateind] = new bool[3]; } array_cond_c = Gate.EnterPartial(c_cases, array, 0); if(c_cases[0]) { double sum = Factor.Sum(array_cond_c[0]); } |
for(int i = 0; i < 3; i++) { if(c[i]) { bool b = Factor.Bernoulli(0.1); Constrain.True(b); } } | for(int i = 0; i < 3; i++) { bool[] c_i_cases = new bool[2]; c_i_cases = Gate.Cases(c[i]); if(c_i_cases[0]) { bool b = Factor.Bernoulli(0.1); Constrain.True(b); } } |
for(int i = 0; i < 3; i++) { if(c[i]) { Constrain.EqualRandom(array[i], constDist); } } | for(int i = 0; i < 3; i++) { bool[] c_i_cases = new bool[2]; c_i_cases = Gate.Cases(c[i]); bool[] array_i_cond_c = new bool[1]; array_i_cond_c = Gate.EnterPartial(c_i_cases, array[i], 0); if(c_i_cases[0]) { Constrain.EqualRandom(array_i_cond_c[0], constDist); } } |
if(c) { for(int i = 0; i < 3; i++) { Constrain.EqualRandom(array[i], constDist); } } | bool[] c_cases = new bool[2]; c_cases = Gate.Cases(c); for(int i = 0; i < 3; i++) { bool[] array_i_cond_c = new bool[1]; array_i_cond_c = Gate.EnterPartial(c_cases, array[i], 0); if(c_cases[0]) { Constrain.EqualRandom(array_i_cond_c[0], constDist); } } |
if(c) { for(int k = 0; k < 2; k++) { for(int i = 0; i < 3; i++) { Constrain.EqualRandom(array[i], constDist); } } } | bool[] c_cases = new bool[2]; c_cases = Gate.Cases(c); for(int i = 0; i < 3; i++) { bool[] array_i_cond_c = new bool[1]; array_i_cond_c = Gate.EnterPartial(c_cases, array[i], 0); } for(int k = 0; k < 2; k++) { for(int i = 0; i < 3; i++) { /* re-enter non-nested container */ if(c_cases[0]) { Constrain.EqualRandom(array_i_cond_c[0], constDist); } } } |
for(int j = 0; j < 2; j++) { if(c[j]) { for(int i = 0; i < 3; i++) { Constrain.EqualRandom(array[i][j], constDist); } } } | for(int j = 0; j < 2; j++) { bool[] c_j_cases = new bool[2]; c_j_cases = Gate.Cases(c[j]); for(int i = 0; i < 3; i++) { bool[] array_i_j_cond_c_j = new bool[1]; array_i_j_cond_c_j = Gate.EnterPartial(c_j_cases, array[i][j], 0); if(c_j_cases[0]) { Constrain.EqualRandom(array_i_j_cond_c_j[0], constDist); } } } |
Exit variable example
Input | Output |
---|---|
bool x; if(c) { x = Factor.Bernoulli(p); Constrain.EqualRandom(x, constDist); } else { x = Factor.Bernoulli(q); } | bool x; bool[] c_cases = new bool[2]; c_cases = Gate.Cases(c); bool[] x_cond_c = new bool[2]; if(c_cases[0]) { x_cond_c[0] = Factor.Bernoulli(p); Constrain.EqualRandom(x_cond_c[0], constDist); } if(c_cases[1]) { x_cond_c[1] = Factor.Bernoulli(q); } x = Gate.Exit(c_cases, x_cond_c); |
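Read deterministically, `Gate.Exit` returns the clone that belongs to whichever case is on. In the output above, only that case's clone is assigned. The Python function below is a hypothetical sketch of this reading and not Infer.NET's implementation, whose factor operates on distributions. The values assigned in the usage stand in for draws from `Factor.Bernoulli(p)` and `Factor.Bernoulli(q)`.

```python
from typing import Optional, Sequence, TypeVar

T = TypeVar("T")


def gate_exit(cases: Sequence[bool], clones: Sequence[Optional[T]]) -> Optional[T]:
    """Deterministic analogue of Gate.Exit: return the clone of the active case.

    Clones of inactive cases may be None, mirroring the transformed code, in
    which only the active case assigns its clone.
    """
    if len(cases) != len(clones):
        raise ValueError("need exactly one clone per case")
    active = [k for k, on in enumerate(cases) if on]
    if len(active) != 1:
        raise ValueError("exactly one case must be on")
    return clones[active[0]]


if __name__ == "__main__":
    c = False
    c_cases = [c, not c]                  # deterministic reading of Gate.Cases(c)
    x_cond_c = [None, None]
    if c_cases[0]:
        x_cond_c[0] = True                # stands in for Factor.Bernoulli(p)
    if c_cases[1]:
        x_cond_c[1] = False               # stands in for Factor.Bernoulli(q)
    print(gate_exit(c_cases, x_cond_c))   # False
```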
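Alternative method for a loop inside a conditional: instead of cloning each array element inside the loop, clone the whole array once outside it: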
Input | Output |
---|---|
if(c) { for(int i = 0; i < 3; i++) { Constrain.EqualRandom(array[i], constDist); } } | bool[] c_cases = new bool[2]; c_cases = Gate.Cases(c); bool[][] array_cond_c = new bool[1][]; for(int _gateind = 0; _gateind < 1; _gateind++) { array_cond_c[_gateind] = new bool[3]; } array_cond_c = Gate.EnterPartial(c_cases, array, 0); for(int i = 0; i < 3; i++) { if(c_cases[0]) { Constrain.EqualRandom(array_cond_c[0][i], constDist); } } |
Case statement
Input | Output |
---|---|
if(i==0) { bool b = Factor.Bernoulli(0.1); Constrain.True(b); } if(i==1) { Constrain.EqualRandom(x, constDist); } if(i==2) { Constrain.EqualRandom(array[i], constDist); } | bool[] i_cases = new bool[3]; i_cases = Gate.Cases(i); if(i_cases[0]) { bool b = Factor.Bernoulli(0.1); Constrain.True(b); } bool[] x_cond_i = new bool[1]; x_cond_i = Gate.EnterPartial(i_cases, x, 1); if(i_cases[1]) { Constrain.EqualRandom(x_cond_i[0], constDist); } bool array_2_cond_i; array_2_cond_i = Gate.EnterOne(i_cases, array[2], 2); if(i_cases[2]) { Constrain.EqualRandom(array_2_cond_i, constDist); } |
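Read deterministically, the three enter factors differ only in how many clones they produce: one per case, one per listed case, or a single clone. The sketch below uses hypothetical Python names and signatures. It mirrors the shapes in the case example above: `x_cond_i` is a one-element array for case 1, and `array_2_cond_i` is a single value for case 2. What this view hides is the message-passing behaviour that motivates the clones. Information reaches the original variable only through the gate, so a case that is off cannot influence it.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def _check_index(cases: Sequence[bool], index: int) -> None:
    """Reject case indices that do not exist in the cases array."""
    if not 0 <= index < len(cases):
        raise IndexError(f"case index {index} out of range")


def gate_enter(cases: Sequence[bool], value: T) -> List[T]:
    """Analogue of Gate.Enter: one clone for every case."""
    return [value for _ in cases]


def gate_enter_partial(cases: Sequence[bool], value: T, *indices: int) -> List[T]:
    """Analogue of Gate.EnterPartial: one clone per listed case index."""
    for index in indices:
        _check_index(cases, index)
    return [value for _ in indices]


def gate_enter_one(cases: Sequence[bool], value: T, index: int) -> T:
    """Analogue of Gate.EnterOne: a single clone for one known case."""
    _check_index(cases, index)
    return value


if __name__ == "__main__":
    i_cases = [False, True, False]                  # i == 1
    x_cond_i = gate_enter_partial(i_cases, 0.7, 1)
    array = [0.1, 0.2, 0.3]
    array_2_cond_i = gate_enter_one(i_cases, array[2], 2)
    print(x_cond_i, array_2_cond_i)                 # [0.7] 0.3
```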
Switch statement
Input | Output |
---|---|
for(int j = 0; j < 3; j++) { if(i==j) { bool b = Factor.Bernoulli(0.1); Constrain.True(b); } } | bool[] i_cases = new bool[3]; i_cases = Gate.Cases(i); for(int j = 0; j < 3; j++) { if(i_cases[j]) { bool b = Factor.Bernoulli(0.1); Constrain.True(b); } } |
bool b = Factor.Bernoulli(0.1); for(int j = 0; j < 3; j++) { if(i==j) { Constrain.True(b); } } | bool[] i_cases = new bool[3]; i_cases = Gate.Cases(i); bool[] b_cond_i = new bool[3]; b_cond_i = Gate.Enter(i_cases, b); for(int j = 0; j < 3; j++) { if(i_cases[j]) { /* Make a copy to ensure only one use of b_cond_i[j] */ bool b_cond_i_j = Factor.Copy(b_cond_i[j]); Constrain.True(b_cond_i_j); } } |
for(int j = 0; j < 3; j++) { if(i==j) { Constrain.EqualRandom(x[i], constDist); } } | bool[] i_cases = new bool[3]; i_cases = Gate.Cases(i); for(int j = 0; j < 3; j++) { if(i_cases[j]) { bool x_i_cond_i = Gate.EnterOne(i_cases, x[j], j); Constrain.EqualRandom(x_i_cond_i, constDist); } } |
bool x; for(int j = 0; j < 3; j++) { if(i==j) { x = Factor.Bernoulli(const[i]); Constrain.EqualRandom(x, constDist[i]); } } | bool x; bool[] i_cases = new bool[3]; i_cases = Gate.Cases(i); bool[] x_cond_i = new bool[3]; for(int j = 0; j < 3; j++) { if(i_cases[j]) { bool x_cond_i_j = Factor.Bernoulli(const[j]); x_cond_i[j] = Factor.Copy(x_cond_i_j); Constrain.EqualRandom(x_cond_i_j, constDist[j]); } } x = Gate.Exit(i_cases, x_cond_i); |
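The switch rewrite preserves the loop's meaning. Ignoring inference, for every value of `i` the body runs in exactly one iteration (`j == i`), both before and after the transform. The Python sketch below checks this. `run_original`, `run_transformed` and `trace` are hypothetical helpers, and `Gate.Cases` is replaced by its deterministic reading.

```python
from typing import Callable, List

Body = Callable[[int], None]


def run_original(i: int, n: int, body: Body) -> None:
    """Input form: for(int j = 0; j < n; j++) { if(i==j) { body } }."""
    for j in range(n):
        if i == j:
            body(j)


def run_transformed(i: int, n: int, body: Body) -> None:
    """Output form: i_cases = Gate.Cases(i), then if(i_cases[j]) inside the loop."""
    i_cases = [i == k for k in range(n)]  # deterministic stand-in for Gate.Cases
    for j in range(n):
        if i_cases[j]:
            body(j)


def trace(runner: Callable[[int, int, Body], None], i: int, n: int) -> List[int]:
    """Record which loop iterations execute the conditional body."""
    visited: List[int] = []
    runner(i, n, visited.append)
    return visited


if __name__ == "__main__":
    for i in range(3):
        assert trace(run_original, i, 3) == trace(run_transformed, i, 3) == [i]
    print("switch rewrite runs the body only in iteration j == i")
```

This also explains why the third row can use `Gate.EnterOne(i_cases, x[j], j)`: inside case `j`, `i == j`, so `x[i]` and `x[j]` are the same element.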