Added transform layer example #5

Open · wants to merge 9 commits into main
34 changes: 34 additions & 0 deletions mcl/array_math.py
@@ -0,0 +1,34 @@
from __future__ import annotations

from mcl.machine_types import i32, memref
from mcl.ndarray import Array, to_scalar_array
from mcl.vm import machine_op


def array_exp(array: Array):
    new_memref = machine_op("memref_exp", memref, array.data)
    return Array(dtype=array.dtype, data=new_memref)


def array_sqrt(array: Array):
    new_memref = machine_op("memref_sqrt", memref, array.data)
    return Array(dtype=array.dtype, data=new_memref)


def array_sum(array: Array, axis: int = -1, keepdims: bool = False):
    # axis and keepdims are lowered to i32 operands for the VM op.
    new_memref = machine_op("memref_sum", memref, array.data, i32(axis), i32(keepdims))
    return Array(dtype=array.dtype, data=new_memref)


def array_max(array: Array, axis: int = -1, keepdims: bool = False):
    new_memref = machine_op("memref_max", memref, array.data, i32(axis), i32(keepdims))
    return Array(dtype=array.dtype, data=new_memref)


def array_maximum(array: Array, other):
    # Elementwise maximum against another Array or a scalar; scalars are
    # boxed into a one-element array of the same dtype first.
    if isinstance(other, Array):
        other = other.data
    else:
        other = to_scalar_array(other, array.dtype).data

    new_memref = machine_op("memref_maximum", memref, array.data, other)
    return Array(dtype=array.dtype, data=new_memref)


def array_matmul(matrix_1: Array, matrix_2: Array):
    new_memref = machine_op("memref_matmul", memref, matrix_1.data, matrix_2.data)
    return Array(dtype=matrix_1.dtype, data=new_memref)
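
These wrappers keep the surface thin: each forwards to a single memref_* VM op and re-boxes the result with the input's dtype. A hypothetical usage sketch (assuming the ops mirror their NumPy namesakes and using Array.random from this PR):

from mcl.machine_types import intp
from mcl.ndarray import Array
from mcl.array_math import array_exp, array_sum, array_matmul

a = Array.random((intp(2), intp(2)))       # f32 array via memref_alloc_random
b = array_exp(a)                           # elementwise exp, same shape and dtype
s = array_sum(b, axis=-1, keepdims=False)  # reduce over the last axis
c = array_matmul(a, a)                     # (2, 2) @ (2, 2) -> (2, 2)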
97 changes: 96 additions & 1 deletion mcl/machine_types.py
@@ -2,7 +2,7 @@

import typing as _tp

-from mcl.vm import machine_op, machine_type, struct_type
+from mcl.vm import machine_op, machine_type

T = _tp.TypeVar("T")

@@ -17,12 +17,42 @@ def __add__(self, other) -> i32:
        else:
            return NotImplemented

    def __sub__(self, other) -> i32:
        if type(other) is i32:
            return machine_op("int_sub", i32, self, other)
        else:
            return NotImplemented

    def __mul__(self, other) -> i32:
        if type(other) is i32:
            return machine_op("int_mul", i32, self, other)
        else:
            return NotImplemented

    def __floordiv__(self, other) -> i32:
        if type(other) is i32:
            return machine_op("int_floordiv", i32, self, other)
        else:
            return NotImplemented

    def __eq__(self, other) -> bool:
        if type(other) is i32:
            return machine_op("int_eq", bool, self, other)
        else:
            return NotImplemented

    def __lt__(self, other) -> bool:
        if type(other) is i32:
            return machine_op("int_lt", bool, self, other)
        else:
            return NotImplemented

    def __mod__(self, other) -> i32:
        if type(other) is i32:
            return machine_op("int_mod", i32, self, other)
        else:
            return NotImplemented
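
# A quick illustration (hypothetical REPL session) of the dispatch pattern
# above: operands must be exactly i32, otherwise NotImplemented is returned
# and Python raises TypeError instead of silently coercing.
#
#     x = i32(7)
#     y = i32(3)
#     x - y    # machine_op("int_sub", ...) -> i32(4)
#     x % y    # machine_op("int_mod", ...) -> i32(1)
#     x < y    # machine_op("int_lt", ...)  -> False
#     x + 3    # plain Python int -> TypeError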


@machine_type(builtin=True, final=True)
class i64:
@@ -84,9 +114,70 @@ def __lt__(self, other) -> bool:
            return machine_op("int_lt", bool, self, other)
        else:
            return NotImplemented

    def __mod__(self, other) -> intp:
        if type(other) is intp:
            return machine_op("int_mod", intp, self, other)
        else:
            return NotImplemented

    def __index__(self) -> int:
        return machine_op("cast", int, self)

    def __hash__(self):
        # Defining __eq__ sets __hash__ to None, so restore hashability by
        # hashing the underlying Python int (obtained via __index__).
        return hash(int(self))
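
# Why __index__ and __hash__ matter (hypothetical usage sketch): __index__
# lets an intp stand in for a Python int wherever an index is required, and
# __hash__ must be restored explicitly because defining __eq__ disables it.
#
#     n = intp(2)
#     [10, 20, 30][n]   # __index__ converts intp -> int for indexing
#     range(n)          # works anywhere an int-like value is expected
#     {n: "ok"}[n]      # hashable again, so intp can key a dict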

@machine_type(builtin=True, final=True)
class f32:
    __machine_repr__ = "f32"

    def __add__(self, other) -> f32:
        if type(other) is f32:
            return machine_op("float_add", f32, self, other)
        else:
            return NotImplemented

    def __sub__(self, other) -> f32:
        if type(other) is f32:
            return machine_op("float_sub", f32, self, other)
        else:
            return NotImplemented

    def __mul__(self, other) -> f32:
        if type(other) is f32:
            return machine_op("float_mul", f32, self, other)
        else:
            return NotImplemented

    def __floordiv__(self, other) -> f32:
        if type(other) is f32:
            return machine_op("float_floordiv", f32, self, other)
        else:
            return NotImplemented

    def __truediv__(self, other) -> f32:
        if type(other) is f32:
            return machine_op("float_truediv", f32, self, other)
        else:
            return NotImplemented

    def __eq__(self, other) -> bool:
        if type(other) is f32:
            return machine_op("float_eq", bool, self, other)
        else:
            return NotImplemented

    def __lt__(self, other) -> bool:
        if type(other) is f32:
            return machine_op("float_lt", bool, self, other)
        else:
            return NotImplemented

    def __mod__(self, other) -> f32:
        if type(other) is f32:
            return machine_op("float_mod", f32, self, other)
        else:
            return NotImplemented


@machine_type(builtin=True, final=True)
@@ -97,6 +188,10 @@ class memref[T]:
    def alloc(cls, shape: tuple[intp, ...], type: _tp.Type[T]) -> memref[T]:
        return machine_op("memref_alloc", memref, shape, type)

    @classmethod
    def alloc_random(cls, shape: tuple[intp, ...], type: _tp.Type[T]) -> memref[T]:
        # Like alloc, but asks the VM to fill the new buffer with random values.
        return machine_op("memref_alloc_random", memref, shape, type)

    @property
    def shape(self) -> tuple[intp, ...]:
        return machine_op("memref_shape", tuple, self)
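A hypothetical sketch of the two allocation entry points (assuming f32 elements and the shape shown):

from mcl.machine_types import intp, f32, memref

shape = (intp(2), intp(3))
buf = memref.alloc(shape, f32)         # uninitialized 2x3 f32 buffer
rnd = memref.alloc_random(shape, f32)  # same shape, filled by the VM with randoms
rnd.shape                              # -> (intp(2), intp(3))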
123 changes: 117 additions & 6 deletions mcl/ndarray.py
@@ -3,9 +3,10 @@
import typing as _tp

from mcl.builtins import tuple_cast
-from mcl.machine_types import i32, intp, memref
-from mcl.vm import struct_type
+from mcl.machine_types import i32, intp, memref, f32
+from mcl.vm import struct_type, machine_op
+from mcl.dialects import LoopNestAPI
+from mcl.vm import _get_machine_value


@struct_type()
@@ -29,6 +30,9 @@ class Number(Generic):
class Integer(Number):
    pass


@struct_type()
class Float(Number):
    pass

type _IntLike = int | intp
type _Indices = tuple[_IntLike, ...] | _IntLike
@@ -49,6 +53,21 @@ def __eq__(self, other) -> bool:
            return self.value == other.value
        return NotImplemented


@struct_type()
class Float32(Float):
    value: f32

    @classmethod
    def from_memory(cls, data: memref, index: tuple[intp, ...]) -> Float32:
        return cls(value=data.load(index, f32))

    def __eq__(self, other) -> bool:
        if type(other) is f32:
            return self.value == other
        elif isinstance(other, Float32):
            return self.value == other.value
        return NotImplemented
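
# Float32 mirrors the existing Int32 wrapper (hypothetical sketch, given some
# Array `arr` backed by f32 storage): from_memory loads a single element, and
# __eq__ accepts either a raw f32 or another Float32.
#
#     x = Float32.from_memory(arr.data, (intp(0), intp(1)))  # load arr[0, 1]
#     x == f32(1.5)                  # compare against a machine float
#     x == Float32(value=f32(1.5))   # or against another boxed value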


@struct_type(final=True)
class Array[T]:
@@ -123,6 +142,30 @@ def __getitem__(self, idx: _Indices) -> Array[T] | Generic:
        idx = tuple_cast(intp, idx)
        # TODO: There's no assertion that checks if idx is within bounds.
        return self.dtype.type.from_memory(self.data, idx)

    def __add__(self, other: Array[T]) -> Array[T]:
        # NOTE: memref_add presumably writes into self.data (its result is
        # discarded), so `a + b` mutates `a` and returns it rather than
        # allocating a new array. The same holds for __truediv__ and __sub__.
        if isinstance(other, Array):
            other_memref = other.data
        else:
            other_memref = to_scalar_array(other, self.dtype).data
        machine_op("memref_add", memref, self.data, other_memref)
        return self

    def __truediv__(self, other: Array[T]) -> Array[T]:
        if isinstance(other, Array):
            other_memref = other.data
        else:
            other_memref = to_scalar_array(other, self.dtype).data
        machine_op("memref_truediv", memref, self.data, other_memref)
        return self

    def __sub__(self, other: Array[T]) -> Array[T]:
        if isinstance(other, Array):
            other_memref = other.data
        else:
            other_memref = to_scalar_array(other, self.dtype).data
        machine_op("memref_sub", memref, self.data, other_memref)
        return self
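
    # Caveat worth noting (hypothetical sketch): because these operators
    # update the left operand's buffer, the result aliases the input.
    #
    #     a = Array.random((intp(2), intp(2)))
    #     b = a + 1.0   # scalar boxed via to_scalar_array, added in place
    #     b is a        # True: `+` returned the mutated left operand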

    def broadcast_to(self, shape: tuple[intp, ...]) -> None:
        # This function can also serve as an assertion
@@ -144,6 +187,68 @@ def broadcast_to(self, shape: tuple[intp, ...]) -> None:

        self.data = self.data.view(new_shape, new_strides, self.data.offset)

    def reshape(self, shape: tuple[intp, ...], copy: bool = True) -> Array[T]:
        # Check if this is a valid reshape
        num_elems_orig = intp(1)
        for i in self.shape:
            num_elems_orig *= i
        num_elems_new = intp(1)
        has_neg_one = False
        neg_one_idx = -1
        for idx, i in enumerate(shape):
            if i == intp(-1):
                if has_neg_one:
                    raise ValueError("Only one dimension can be -1")
                has_neg_one = True
                neg_one_idx = idx
            else:
                num_elems_new *= i

        if has_neg_one:
            assert num_elems_orig % num_elems_new == intp(0), f"Cannot reshape array of shape {self.shape} to shape {shape}"

            shape = shape[:neg_one_idx] + (num_elems_orig // num_elems_new,) + shape[neg_one_idx + 1:]
        else:
            assert num_elems_orig == num_elems_new, f"Cannot reshape array of shape {self.shape} to shape {shape}"

        new_strides = [intp(0)] * len(shape)

        if copy:
            self = self.copy()

        # TODO: Check if this logic is true
        # If we are at this point, we can assume that the array is contiguous
        # If it wasn't earlier the copy would have made it contiguous

        # TODO: This should be intp(size_of_element) instead of 4
        new_strides[-1] = intp(4)

        for i in range(len(shape) - 2, -1, -1):
            new_strides[i] = new_strides[i + 1] * shape[i + 1]

        self.data = self.data.view(shape, new_strides, self.data.offset)
        return self
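
    # Example (hypothetical): -1 infers the remaining dimension from the
    # element count, and copy=False reshapes the array in place.
    #
    #     a = Array.random((intp(2), intp(6)))
    #     b = a.reshape((intp(3), intp(-1)))   # -1 inferred as 4; b is a copy
    #     a.reshape((intp(12),), copy=False)   # flattens a itself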

    def transpose(self, axis: tuple[intp, ...] | None = None) -> Array[T]:
        # Check if axis is valid
        if axis is None:
            axis = tuple([intp(i) for i in range(self.ndim - 1, -1, -1)])

        assert intp(len(axis)) == self.ndim
        assert set(axis) == set([intp(i) for i in range(self.ndim)])

        new_shape = [intp(0)] * self.ndim
        new_strides = [intp(0)] * self.ndim

        # TODO: These instances of _get_machine_value need to be removed
        for i, j in enumerate(axis):
            new_shape[i] = self.shape[_get_machine_value(j)]
            new_strides[i] = self.strides[_get_machine_value(j)]

        self.data = self.data.view(tuple(new_shape), tuple(new_strides), self.data.offset)

        return self
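
    # Example (hypothetical): with no argument the axes are simply reversed;
    # the method rewrites the view on the same array and returns it.
    #
    #     a = Array.random((intp(2), intp(3)))
    #     a.transpose()                     # now shape (3, 2), strides swapped
    #     a.transpose((intp(1), intp(0)))   # explicit permutation back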

    @classmethod
    def is_advanced(cls, idx: _Indices) -> bool:
        return any((isinstance(i, Array) and i.ndim > intp(0)) for i in idx)
@@ -267,7 +372,13 @@ def copy(self) -> Array[T]:
        return Array(dtype=self.dtype, data=self.data.copy())

    def print(self) -> None:
        # Gather elements in loop-nest order into a Python list for display,
        # then let the VM print the raw memref as well.
        res = []
        for idx in LoopNestAPI.from_tuple(self.shape):
            res.append(self[idx].value)
        print(res)
        machine_op("memref_print", None, self.data)

    @classmethod
    def random(cls, shape: tuple[intp, ...]) -> Array:
        return Array(dtype=DType(Float32), data=memref.alloc_random(shape, f32))

def to_scalar_array(data, dtype):
    # Box a Python scalar as a one-element Array so elementwise memref ops can
    # broadcast against it.
    # TODO: the buffer is hardcoded to f32 here; it should follow `dtype`.
    temp_memref = machine_op("memref_alloc", memref, (intp(1),), f32)
    machine_op("memref_store", None, temp_memref, (i32(0),), data)
    return Array(dtype=dtype, data=temp_memref)