Welcome to systemml’s documentation!

Contents:

systemml package

Subpackages

systemml.mllearn package

Submodules
systemml.mllearn.estimators module
class systemml.mllearn.estimators.LinearRegression(sqlCtx, fit_intercept=True, max_iter=100, tol=1e-06, C=1.0, solver='newton-cg', transferUsingDF=False)[source]

Bases: systemml.mllearn.estimators.BaseSystemMLRegressor

Performs linear regression to model the relationship between one numerical response variable and one or more explanatory (feature) variables.

>>> import numpy as np
>>> from sklearn import datasets
>>> from systemml.mllearn import LinearRegression
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> # Load the diabetes dataset
>>> diabetes = datasets.load_diabetes()
>>> # Use only one feature
>>> diabetes_X = diabetes.data[:, np.newaxis, 2]
>>> # Split the data into training/testing sets
>>> diabetes_X_train = diabetes_X[:-20]
>>> diabetes_X_test = diabetes_X[-20:]
>>> # Split the targets into training/testing sets
>>> diabetes_y_train = diabetes.target[:-20]
>>> diabetes_y_test = diabetes.target[-20:]
>>> # Create linear regression object
>>> regr = LinearRegression(sqlCtx, solver='newton-cg')
>>> # Train the model using the training sets
>>> regr.fit(diabetes_X_train, diabetes_y_train)
>>> # The mean squared error
>>> print("Residual sum of squares: %.2f" % np.mean((regr.predict(diabetes_X_test) - diabetes_y_test) ** 2))
class systemml.mllearn.estimators.LogisticRegression(sqlCtx, penalty='l2', fit_intercept=True, max_iter=100, max_inner_iter=0, tol=1e-06, C=1.0, solver='newton-cg', transferUsingDF=False)[source]

Bases: systemml.mllearn.estimators.BaseSystemMLClassifier

Performs both binomial and multinomial logistic regression.

Scikit-learn way

>>> from sklearn import datasets, neighbors
>>> from systemml.mllearn import LogisticRegression
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> digits = datasets.load_digits()
>>> X_digits = digits.data
>>> y_digits = digits.target + 1
>>> n_samples = len(X_digits)
>>> X_train = X_digits[:int(.9 * n_samples)]
>>> y_train = y_digits[:int(.9 * n_samples)]
>>> X_test = X_digits[int(.9 * n_samples):]
>>> y_test = y_digits[int(.9 * n_samples):]
>>> logistic = LogisticRegression(sqlCtx)
>>> print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))

MLPipeline way

>>> from pyspark.ml import Pipeline
>>> from systemml.mllearn import LogisticRegression
>>> from pyspark.ml.feature import HashingTF, Tokenizer
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> training = sqlCtx.createDataFrame([
>>>     (0, "a b c d e spark", 1.0),
>>>     (1, "b d", 2.0),
>>>     (2, "spark f g h", 1.0),
>>>     (3, "hadoop mapreduce", 2.0),
>>>     (4, "b spark who", 1.0),
>>>     (5, "g d a y", 2.0),
>>>     (6, "spark fly", 1.0),
>>>     (7, "was mapreduce", 2.0),
>>>     (8, "e spark program", 1.0),
>>>     (9, "a e c l", 2.0),
>>>     (10, "spark compile", 1.0),
>>>     (11, "hadoop software", 2.0)
>>> ], ["id", "text", "label"])
>>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
>>> hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
>>> lr = LogisticRegression(sqlCtx)
>>> pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
>>> model = pipeline.fit(training)
>>> test = sqlCtx.createDataFrame([
>>>     (12, "spark i j k"),
>>>     (13, "l m n"),
>>>     (14, "mapreduce spark"),
>>>     (15, "apache hadoop")], ["id", "text"])
>>> prediction = model.transform(test)
>>> prediction.show()
class systemml.mllearn.estimators.SVM(sqlCtx, fit_intercept=True, max_iter=100, tol=1e-06, C=1.0, is_multi_class=False, transferUsingDF=False)[source]

Bases: systemml.mllearn.estimators.BaseSystemMLClassifier

Performs both binary-class and multiclass SVM (Support Vector Machines).

>>> from sklearn import datasets, neighbors
>>> from systemml.mllearn import SVM
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> digits = datasets.load_digits()
>>> X_digits = digits.data
>>> y_digits = digits.target 
>>> n_samples = len(X_digits)
>>> X_train = X_digits[:int(.9 * n_samples)]
>>> y_train = y_digits[:int(.9 * n_samples)]
>>> X_test = X_digits[int(.9 * n_samples):]
>>> y_test = y_digits[int(.9 * n_samples):]
>>> svm = SVM(sqlCtx, is_multi_class=True)
>>> print('SVM score: %f' % svm.fit(X_train, y_train).score(X_test, y_test))
class systemml.mllearn.estimators.NaiveBayes(sqlCtx, laplace=1.0, transferUsingDF=False)[source]

Bases: systemml.mllearn.estimators.BaseSystemMLClassifier

Performs multinomial Naive Bayes classification.

>>> from sklearn.datasets import fetch_20newsgroups
>>> from sklearn.feature_extraction.text import TfidfVectorizer
>>> from systemml.mllearn import NaiveBayes
>>> from sklearn import metrics
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']
>>> newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
>>> newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
>>> vectorizer = TfidfVectorizer()
>>> # Both vectors and vectors_test are SciPy CSR matrices
>>> vectors = vectorizer.fit_transform(newsgroups_train.data)
>>> vectors_test = vectorizer.transform(newsgroups_test.data)
>>> nb = NaiveBayes(sqlCtx)
>>> nb.fit(vectors, newsgroups_train.target)
>>> pred = nb.predict(vectors_test)
>>> metrics.f1_score(newsgroups_test.target, pred, average='weighted')
Module contents
SystemML Algorithms
Classification Algorithms
LogisticRegression: Performs binomial and multinomial logistic regression.
SVM: Performs both binary-class and multi-class SVM.
NaiveBayes: Multinomial Naive Bayes classifier.
Regression Algorithms
LinearRegression: Performs linear regression.
class systemml.mllearn.LinearRegression(sqlCtx, fit_intercept=True, max_iter=100, tol=1e-06, C=1.0, solver='newton-cg', transferUsingDF=False)[source]

Bases: systemml.mllearn.estimators.BaseSystemMLRegressor

Performs linear regression to model the relationship between one numerical response variable and one or more explanatory (feature) variables.

>>> import numpy as np
>>> from sklearn import datasets
>>> from systemml.mllearn import LinearRegression
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> # Load the diabetes dataset
>>> diabetes = datasets.load_diabetes()
>>> # Use only one feature
>>> diabetes_X = diabetes.data[:, np.newaxis, 2]
>>> # Split the data into training/testing sets
>>> diabetes_X_train = diabetes_X[:-20]
>>> diabetes_X_test = diabetes_X[-20:]
>>> # Split the targets into training/testing sets
>>> diabetes_y_train = diabetes.target[:-20]
>>> diabetes_y_test = diabetes.target[-20:]
>>> # Create linear regression object
>>> regr = LinearRegression(sqlCtx, solver='newton-cg')
>>> # Train the model using the training sets
>>> regr.fit(diabetes_X_train, diabetes_y_train)
>>> # The mean squared error
>>> print("Residual sum of squares: %.2f" % np.mean((regr.predict(diabetes_X_test) - diabetes_y_test) ** 2))
class systemml.mllearn.LogisticRegression(sqlCtx, penalty='l2', fit_intercept=True, max_iter=100, max_inner_iter=0, tol=1e-06, C=1.0, solver='newton-cg', transferUsingDF=False)[source]

Bases: systemml.mllearn.estimators.BaseSystemMLClassifier

Performs both binomial and multinomial logistic regression.

Scikit-learn way

>>> from sklearn import datasets, neighbors
>>> from systemml.mllearn import LogisticRegression
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> digits = datasets.load_digits()
>>> X_digits = digits.data
>>> y_digits = digits.target + 1
>>> n_samples = len(X_digits)
>>> X_train = X_digits[:int(.9 * n_samples)]
>>> y_train = y_digits[:int(.9 * n_samples)]
>>> X_test = X_digits[int(.9 * n_samples):]
>>> y_test = y_digits[int(.9 * n_samples):]
>>> logistic = LogisticRegression(sqlCtx)
>>> print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))

MLPipeline way

>>> from pyspark.ml import Pipeline
>>> from systemml.mllearn import LogisticRegression
>>> from pyspark.ml.feature import HashingTF, Tokenizer
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> training = sqlCtx.createDataFrame([
>>>     (0, "a b c d e spark", 1.0),
>>>     (1, "b d", 2.0),
>>>     (2, "spark f g h", 1.0),
>>>     (3, "hadoop mapreduce", 2.0),
>>>     (4, "b spark who", 1.0),
>>>     (5, "g d a y", 2.0),
>>>     (6, "spark fly", 1.0),
>>>     (7, "was mapreduce", 2.0),
>>>     (8, "e spark program", 1.0),
>>>     (9, "a e c l", 2.0),
>>>     (10, "spark compile", 1.0),
>>>     (11, "hadoop software", 2.0)
>>> ], ["id", "text", "label"])
>>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
>>> hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
>>> lr = LogisticRegression(sqlCtx)
>>> pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
>>> model = pipeline.fit(training)
>>> test = sqlCtx.createDataFrame([
>>>     (12, "spark i j k"),
>>>     (13, "l m n"),
>>>     (14, "mapreduce spark"),
>>>     (15, "apache hadoop")], ["id", "text"])
>>> prediction = model.transform(test)
>>> prediction.show()
class systemml.mllearn.SVM(sqlCtx, fit_intercept=True, max_iter=100, tol=1e-06, C=1.0, is_multi_class=False, transferUsingDF=False)[source]

Bases: systemml.mllearn.estimators.BaseSystemMLClassifier

Performs both binary-class and multiclass SVM (Support Vector Machines).

>>> from sklearn import datasets, neighbors
>>> from systemml.mllearn import SVM
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> digits = datasets.load_digits()
>>> X_digits = digits.data
>>> y_digits = digits.target 
>>> n_samples = len(X_digits)
>>> X_train = X_digits[:int(.9 * n_samples)]
>>> y_train = y_digits[:int(.9 * n_samples)]
>>> X_test = X_digits[int(.9 * n_samples):]
>>> y_test = y_digits[int(.9 * n_samples):]
>>> svm = SVM(sqlCtx, is_multi_class=True)
>>> print('SVM score: %f' % svm.fit(X_train, y_train).score(X_test, y_test))
class systemml.mllearn.NaiveBayes(sqlCtx, laplace=1.0, transferUsingDF=False)[source]

Bases: systemml.mllearn.estimators.BaseSystemMLClassifier

Performs multinomial Naive Bayes classification.

>>> from sklearn.datasets import fetch_20newsgroups
>>> from sklearn.feature_extraction.text import TfidfVectorizer
>>> from systemml.mllearn import NaiveBayes
>>> from sklearn import metrics
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']
>>> newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
>>> newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
>>> vectorizer = TfidfVectorizer()
>>> # Both vectors and vectors_test are SciPy CSR matrices
>>> vectors = vectorizer.fit_transform(newsgroups_train.data)
>>> vectors_test = vectorizer.transform(newsgroups_test.data)
>>> nb = NaiveBayes(sqlCtx)
>>> nb.fit(vectors, newsgroups_train.target)
>>> pred = nb.predict(vectors_test)
>>> metrics.f1_score(newsgroups_test.target, pred, average='weighted')

systemml.random package

Submodules
systemml.random.sampling module
systemml.random.sampling.normal(loc=0.0, scale=1.0, size=(1, 1), sparsity=1.0)[source]

Draw random samples from a normal (Gaussian) distribution.

loc: Mean (“centre”) of the distribution.
scale: Standard deviation (spread or “width”) of the distribution.
size: Output shape (only a tuple of length 2, i.e. (m, n), is supported).
sparsity: Sparsity (between 0.0 and 1.0).

>>> import systemml as sml
>>> import numpy as np
>>> sml.setSparkContext(sc)
>>> from systemml import random
>>> m1 = sml.random.normal(loc=3, scale=2, size=(3,3))
>>> m1.toNumPy()
array([[ 3.48857226,  6.17261819,  2.51167259],
       [ 3.60506708, -1.90266305,  3.97601633],
       [ 3.62245706,  5.9430881 ,  2.53070413]])
systemml.random.sampling.uniform(low=0.0, high=1.0, size=(1, 1), sparsity=1.0)[source]

Draw samples from a uniform distribution.

low: Lower boundary of the output interval.
high: Upper boundary of the output interval.
size: Output shape (only a tuple of length 2, i.e. (m, n), is supported).
sparsity: Sparsity (between 0.0 and 1.0).

>>> import systemml as sml
>>> import numpy as np
>>> sml.setSparkContext(sc)
>>> from systemml import random
>>> m1 = sml.random.uniform(size=(3,3))
>>> m1.toNumPy()
array([[ 0.54511396,  0.11937437,  0.72975775],
       [ 0.14135946,  0.01944448,  0.52544478],
       [ 0.67582422,  0.87068849,  0.02766852]])
systemml.random.sampling.poisson(lam=1.0, size=(1, 1), sparsity=1.0)[source]

Draw samples from a Poisson distribution.

lam: Expectation of interval, should be > 0.
size: Output shape (only a tuple of length 2, i.e. (m, n), is supported).
sparsity: Sparsity (between 0.0 and 1.0).

>>> import systemml as sml
>>> import numpy as np
>>> sml.setSparkContext(sc)
>>> from systemml import random
>>> m1 = sml.random.poisson(lam=1, size=(3,3))
>>> m1.toNumPy()
array([[ 1.,  0.,  2.],
       [ 1.,  0.,  0.],
       [ 0.,  0.,  0.]])
Module contents
Random Number Generation
Univariate distributions
normal: Normal / Gaussian distribution.
poisson: Poisson distribution.
uniform: Uniform distribution.
systemml.random.normal(loc=0.0, scale=1.0, size=(1, 1), sparsity=1.0)[source]

Draw random samples from a normal (Gaussian) distribution.

loc: Mean (“centre”) of the distribution.
scale: Standard deviation (spread or “width”) of the distribution.
size: Output shape (only a tuple of length 2, i.e. (m, n), is supported).
sparsity: Sparsity (between 0.0 and 1.0).

>>> import systemml as sml
>>> import numpy as np
>>> sml.setSparkContext(sc)
>>> from systemml import random
>>> m1 = sml.random.normal(loc=3, scale=2, size=(3,3))
>>> m1.toNumPy()
array([[ 3.48857226,  6.17261819,  2.51167259],
       [ 3.60506708, -1.90266305,  3.97601633],
       [ 3.62245706,  5.9430881 ,  2.53070413]])
systemml.random.uniform(low=0.0, high=1.0, size=(1, 1), sparsity=1.0)[source]

Draw samples from a uniform distribution.

low: Lower boundary of the output interval.
high: Upper boundary of the output interval.
size: Output shape (only a tuple of length 2, i.e. (m, n), is supported).
sparsity: Sparsity (between 0.0 and 1.0).

>>> import systemml as sml
>>> import numpy as np
>>> sml.setSparkContext(sc)
>>> from systemml import random
>>> m1 = sml.random.uniform(size=(3,3))
>>> m1.toNumPy()
array([[ 0.54511396,  0.11937437,  0.72975775],
       [ 0.14135946,  0.01944448,  0.52544478],
       [ 0.67582422,  0.87068849,  0.02766852]])
systemml.random.poisson(lam=1.0, size=(1, 1), sparsity=1.0)[source]

Draw samples from a Poisson distribution.

lam: Expectation of interval, should be > 0.
size: Output shape (only a tuple of length 2, i.e. (m, n), is supported).
sparsity: Sparsity (between 0.0 and 1.0).

>>> import systemml as sml
>>> import numpy as np
>>> sml.setSparkContext(sc)
>>> from systemml import random
>>> m1 = sml.random.poisson(lam=1, size=(3,3))
>>> m1.toNumPy()
array([[ 1.,  0.,  2.],
       [ 1.,  0.,  0.],
       [ 0.,  0.,  0.]])

Submodules

systemml.converters module

systemml.converters.getNumCols(numPyArr)[source]
systemml.converters.convertToMatrixBlock(sc, src)[source]
systemml.converters.convertToNumPyArr(sc, mb)[source]
systemml.converters.convertToPandasDF(X)[source]
systemml.converters.convertToLabeledDF(sqlCtx, X, y=None)[source]
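
A minimal sketch of how these helpers might be called directly (assuming a SparkContext sc is available; the exact return types may vary between SystemML versions):

>>> import numpy as np
>>> from pyspark.sql import SQLContext
>>> from systemml.converters import getNumCols, convertToPandasDF, convertToLabeledDF
>>> sqlCtx = SQLContext(sc)
>>> X = np.random.rand(4, 3)
>>> y = np.array([1.0, 2.0, 1.0, 2.0])
>>> getNumCols(X)                          # number of columns in the NumPy array
>>> X_pd = convertToPandasDF(X)            # Pandas DataFrame holding the same data
>>> df = convertToLabeledDF(sqlCtx, X, y)  # PySpark DataFrame with feature and label columns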

systemml.defmatrix module

systemml.defmatrix.setSparkContext(sc)[source]

This function must be invoked before using the matrix class.

sc: SparkContext
SparkContext
class systemml.defmatrix.matrix(data, op=None)[source]

Bases: object

The matrix class is a Python wrapper that implements basic matrix operators and matrix functions, as well as converters to and from common Python types (for example: NumPy arrays, PySpark DataFrames and Pandas DataFrames).

The operators supported are:

  1. Arithmetic operators: +, -, *, /, //, %, ** as well as dot (i.e. matrix multiplication)
  2. Indexing in the matrix
  3. Relational/Boolean operators: <, <=, >, >=, ==, !=, &, |

In addition, the following functions are supported for matrix:

  1. transpose
  2. Aggregation functions: sum, mean, var, sd, max, min, argmin, argmax, cumsum
  3. Global statistical built-in functions: exp, log, abs, sqrt, round, floor, ceil, sin, cos, tan, asin, acos, atan, sign, solve

Note: an evaluated matrix contains a data field, computed by the eval method, holding either a DataFrame or a NumPy array.

>>> import systemml as sml
>>> import numpy as np
>>> sml.setSparkContext(sc)

Welcome to Apache SystemML!

>>> m1 = sml.matrix(np.ones((3,3)) + 2)
>>> m2 = sml.matrix(np.ones((3,3)) + 3)
>>> m2 = m1 * (m2 + m1)
>>> m4 = 1.0 - m2
>>> m4
# This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.
mVar1 = load(" ", format="csv")
mVar2 = load(" ", format="csv")
mVar3 = mVar2 + mVar1
mVar4 = mVar1 * mVar3
mVar5 = 1.0 - mVar4
save(mVar5, " ")
>>> m2.eval()
>>> m2
# This matrix (mVar4) is backed by NumPy array. To fetch the NumPy array, invoke toNumPy() method.
>>> m4
# This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.
mVar4 = load(" ", format="csv")
mVar5 = 1.0 - mVar4
save(mVar5, " ")
>>> m4.sum(axis=1).toNumPy()
array([[-60.],
       [-60.],
       [-60.]])

Design Decisions:

  1. Until the eval() method is invoked, we create an AST (not exposed to the user) that consists of the unevaluated operations and the data required by those operations. As an analogy, a Spark user can treat the eval() method similar to calling RDD.persist() followed by RDD.count().
  2. The AST consists of two kinds of nodes: either of type matrix or of type DMLOp. Both these classes expose a _visit method, which helps in traversing the AST in a DFS manner.
  3. A matrix object can either be evaluated or not. If evaluated, the attribute ‘data’ is set to one of the supported types (for example: NumPy array or DataFrame) and the attribute ‘op’ is set to None. If not evaluated, the attribute ‘op’ refers to an intermediate node of the AST and is of type DMLOp; in this case, the attribute ‘data’ is set to None. A DMLOp has an attribute ‘inputs’ which contains a list of matrix objects or DMLOps.

  4. To simplify the traversal, every matrix object is considered immutable and every matrix operation creates a new matrix object. As an example: m1 = sml.matrix(np.ones((3,3))) creates a matrix object backed by ‘data=np.ones((3,3))’. m1 = m1 * 2 will create a new matrix object which is now backed by ‘op=DMLOp( ... )’ whose input is the earlier created matrix object.

  5. Left indexing (implemented in the __setitem__ method) is a special case, where Python expects the existing object to be mutated. To ensure the above property, we make a deep copy of the existing object and point any references to the left-indexed matrix to the newly created object. Then the left-indexed matrix is set to be backed by a DMLOp consisting of the following PyDML: left-indexed-matrix = new-deep-copied-matrix; left-indexed-matrix[index] = value.

  6. Please use m.printAST() and/or type m for debugging. Here is a sample session:

    >>> npm = np.ones((3,3))
    >>> m1 = sml.matrix(npm + 3)
    >>> m2 = sml.matrix(npm + 5)
    >>> m3 = m1 + m2
    >>> m3
    mVar2 = load(" ", format="csv")
    mVar1 = load(" ", format="csv")
    mVar3 = mVar1 + mVar2
    save(mVar3, " ")
    >>> m3.printAST()
    - [mVar3] (op).
      - [mVar1] (data).
      - [mVar2] (data).    
    
argmax(axis=None)[source]

Returns the indices of the maximum values along an axis.

axis : int, optional (only axis=1, i.e. rowIndexMax is supported in this version)

argmin(axis=None)[source]

Returns the indices of the minimum values along an axis.

axis : int, optional (only axis=1, i.e. rowIndexMin is supported in this version)

cumsum(axis=None)[source]

Returns the cumulative sum along an axis.

axis : int, optional (only axis=0, i.e. cumsum along the rows is supported in this version)

dml = []
dot(other)[source]

Numpy way of performing matrix multiplication

eval(outputDF=False)[source]

This is a convenience function that calls the global eval method

max(axis=None)[source]

Compute the maximum value along the specified axis

axis : int, optional

mean(axis=None)[source]

Compute the arithmetic mean along the specified axis

axis : int, optional

min(axis=None)[source]

Compute the minimum value along the specified axis

axis : int, optional

ml = None
printAST(numSpaces=0)[source]

Please use m.printAST() and/or type m for debugging. Here is a sample session:

>>> npm = np.ones((3,3))
>>> m1 = sml.matrix(npm + 3)
>>> m2 = sml.matrix(npm + 5)
>>> m3 = m1 + m2
>>> m3
mVar2 = load(" ", format="csv")
mVar1 = load(" ", format="csv")
mVar3 = mVar1 + mVar2
save(mVar3, " ")
>>> m3.printAST()
- [mVar3] (op).
  - [mVar1] (data).
  - [mVar2] (data).
script = None
sd(axis=None)[source]

Compute the standard deviation along the specified axis

axis : int, optional

sum(axis=None)[source]

Compute the sum along the specified axis

axis : int, optional

systemmlVarID = 0
toDF()[source]

This is a convenience function that calls the global eval method and then converts the matrix object into DataFrame.

toNumPy()[source]

This is a convenience function that calls the global eval method and then converts the matrix object into NumPy array.

toPandas()[source]

This is a convenience function that calls the global eval method and then converts the matrix object into Pandas DataFrame.

trace()[source]

Return the sum of the cells on the main diagonal of a square matrix.

transpose()[source]

Transposes the matrix.

var(axis=None)[source]

Compute the variance along the specified axis

axis : int, optional

visited = []
systemml.defmatrix.eval(outputs, outputDF=False, execute=True)[source]

Executes the unevaluated DML script and computes the matrices specified by outputs.

outputs: list of matrices or a matrix object.
outputDF: if True, back the evaluated data with a PySpark DataFrame instead of a NumPy array.
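
For illustration, a minimal sketch of forcing evaluation with this function (assuming a SparkContext sc; the printed array follows from np.ones((3,3)) * 2 + 1):

>>> import systemml as sml
>>> import numpy as np
>>> sml.setSparkContext(sc)
>>> m1 = sml.matrix(np.ones((3,3)))
>>> m2 = m1 * 2 + 1
>>> sml.eval(m2)        # executes the generated PyDML script; m2 is now backed by data
>>> m2.toNumPy()        # no further execution is needed at this point
array([[ 3.,  3.,  3.],
       [ 3.,  3.,  3.],
       [ 3.,  3.,  3.]])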

systemml.defmatrix.solve(A, b)[source]

Computes the least-squares solution for the system of linear equations A %*% x = b

>>> import numpy as np
>>> from sklearn import datasets
>>> import systemml as sml
>>> from pyspark.sql import SQLContext
>>> diabetes = datasets.load_diabetes()
>>> diabetes_X = diabetes.data[:, np.newaxis, 2]
>>> X_train = diabetes_X[:-20]
>>> X_test = diabetes_X[-20:]
>>> y_train = diabetes.target[:-20]
>>> y_test = diabetes.target[-20:]
>>> sml.setSparkContext(sc)
>>> X = sml.matrix(X_train)
>>> y = sml.matrix(y_train)
>>> A = X.transpose().dot(X)
>>> b = X.transpose().dot(y)
>>> beta = sml.solve(A, b).toNumPy()
>>> y_predicted = X_test.dot(beta)
>>> print('Residual sum of squares: %.2f' % np.mean((y_predicted - y_test) ** 2))
Residual sum of squares: 25282.12
class systemml.defmatrix.DMLOp(inputs, dml=None)[source]

Bases: object

Represents an intermediate node of the abstract syntax tree (AST) created to generate the PyDML script.

printAST(numSpaces)[source]
systemml.defmatrix.exp(X)[source]
systemml.defmatrix.log(X, y=None)[source]
systemml.defmatrix.abs(X)[source]
systemml.defmatrix.sqrt(X)[source]
systemml.defmatrix.round(X)[source]
systemml.defmatrix.floor(X)[source]
systemml.defmatrix.ceil(X)[source]
systemml.defmatrix.sin(X)[source]
systemml.defmatrix.cos(X)[source]
systemml.defmatrix.tan(X)[source]
systemml.defmatrix.asin(X)[source]
systemml.defmatrix.acos(X)[source]
systemml.defmatrix.atan(X)[source]
systemml.defmatrix.sign(X)[source]
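
These built-ins apply elementwise to a matrix and, like the operators above, are evaluated lazily. A small sketch (assuming a SparkContext sc; sqrt of [[1, 4], [9, 16]] gives the values shown):

>>> import systemml as sml
>>> import numpy as np
>>> sml.setSparkContext(sc)
>>> m = sml.matrix(np.array([[1.0, 4.0], [9.0, 16.0]]))
>>> sml.sqrt(m).toNumPy()
array([[ 1.,  2.],
       [ 3.,  4.]])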

systemml.mlcontext module

class systemml.mlcontext.MLResults(results, sc)[source]

Bases: object

Wrapper around a Java ML Results object.

results: JavaObject
A Java MLResults object as returned by calling ml.execute().
sc: SparkContext
SparkContext
get(*outputs)[source]
outputs: string, list of strings
Output variables as defined inside the DML script.
class systemml.mlcontext.MLContext(sc)[source]

Bases: object

Wrapper around the new SystemML MLContext.

sc: SparkContext
SparkContext
execute(script)[source]

Execute a DML / PyDML script.

script: Script instance
Script instance defined with the appropriate input and output variables.
ml_results: MLResults
MLResults instance.
setExplain(explain)[source]

Whether to output an explanation of the program. Mainly intended for developers.

explain: boolean

setExplainLevel(explainLevel)[source]

Set explain level.

explainLevel: string
Can be one of “hops”, “runtime”, “recompile_hops”, “recompile_runtime”, or any of the above in upper case.
setStatistics(statistics)[source]

Whether or not to output statistics (such as execution time, elapsed time) about script executions.

statistics: boolean

class systemml.mlcontext.Script(scriptString, scriptType='dml')[source]

Bases: object

Instance of a DML/PyDML Script.

scriptString: string
Can be either a file path to a DML script or a DML script itself.
scriptType: string
Script language, either “dml” for DML (R-like) or “pydml” for PyDML (Python-like).
input(*args, **kwargs)[source]
args: name, value tuple
where name is a string, and currently supported value formats are double, string, dataframe, rdd, and lists of such objects.
kwargs: dict of name, value pairs
The supported formats for name and value are the same as above.
output(*names)[source]
names: string, list of strings
Output variables as defined inside the DML script.
systemml.mlcontext.dml(scriptString)[source]

Create a dml script object based on a string.

scriptString: string
Can be a path to a dml script or a dml script itself.
script: Script instance
Instance of a script object.
systemml.mlcontext.pydml(scriptString)[source]

Create a pydml script object based on a string.

scriptString: string
Can be a path to a pydml script or a pydml script itself.
script: Script instance
Instance of a script object.
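
Putting these pieces together, here is a minimal sketch of the MLContext workflow described above (the script contents and variable names are illustrative; assumes a SparkContext sc):

>>> from systemml.mlcontext import MLContext, dml
>>> ml = MLContext(sc)
>>> ml.setStatistics(True)
>>> script = dml('z = x + y').input(x=1.0, y=2.0).output('z')
>>> results = ml.execute(script)
>>> results.get('z')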
systemml.mlcontext.getNumCols(numPyArr)[source]
systemml.mlcontext.convertToMatrixBlock(sc, src)[source]
systemml.mlcontext.convertToNumPyArr(sc, mb)[source]
systemml.mlcontext.convertToPandasDF(X)[source]
systemml.mlcontext.convertToLabeledDF(sqlCtx, X, y=None)[source]

Module contents

class systemml.MLResults(results, sc)[source]

Bases: object

Wrapper around a Java ML Results object.

results: JavaObject
A Java MLResults object as returned by calling ml.execute().
sc: SparkContext
SparkContext
get(*outputs)[source]
outputs: string, list of strings
Output variables as defined inside the DML script.
class systemml.MLContext(sc)[source]

Bases: object

Wrapper around the new SystemML MLContext.

sc: SparkContext
SparkContext
execute(script)[source]

Execute a DML / PyDML script.

script: Script instance
Script instance defined with the appropriate input and output variables.
ml_results: MLResults
MLResults instance.
setExplain(explain)[source]

Whether to output an explanation of the program. Mainly intended for developers.

explain: boolean

setExplainLevel(explainLevel)[source]

Set explain level.

explainLevel: string
Can be one of “hops”, “runtime”, “recompile_hops”, “recompile_runtime”, or any of the above in upper case.
setStatistics(statistics)[source]

Whether or not to output statistics (such as execution time, elapsed time) about script executions.

statistics: boolean

class systemml.Script(scriptString, scriptType='dml')[source]

Bases: object

Instance of a DML/PyDML Script.

scriptString: string
Can be either a file path to a DML script or a DML script itself.
scriptType: string
Script language, either “dml” for DML (R-like) or “pydml” for PyDML (Python-like).
input(*args, **kwargs)[source]
args: name, value tuple
where name is a string, and currently supported value formats are double, string, dataframe, rdd, and lists of such objects.
kwargs: dict of name, value pairs
The supported formats for name and value are the same as above.
output(*names)[source]
names: string, list of strings
Output variables as defined inside the DML script.
systemml.dml(scriptString)[source]

Create a dml script object based on a string.

scriptString: string
Can be a path to a dml script or a dml script itself.
script: Script instance
Instance of a script object.
systemml.pydml(scriptString)[source]

Create a pydml script object based on a string.

scriptString: string
Can be a path to a pydml script or a pydml script itself.
script: Script instance
Instance of a script object.
systemml.setSparkContext(sc)[source]

This function must be invoked before using the matrix class.

sc: SparkContext
SparkContext
class systemml.matrix(data, op=None)[source]

Bases: object

The matrix class is a Python wrapper that implements basic matrix operators and matrix functions, as well as converters to and from common Python types (for example: NumPy arrays, PySpark DataFrames and Pandas DataFrames).

The operators supported are:

  1. Arithmetic operators: +, -, *, /, //, %, ** as well as dot (i.e. matrix multiplication)
  2. Indexing in the matrix
  3. Relational/Boolean operators: <, <=, >, >=, ==, !=, &, |

In addition, the following functions are supported for matrix:

  1. transpose
  2. Aggregation functions: sum, mean, var, sd, max, min, argmin, argmax, cumsum
  3. Global statistical built-in functions: exp, log, abs, sqrt, round, floor, ceil, sin, cos, tan, asin, acos, atan, sign, solve

Note: an evaluated matrix contains a data field, computed by the eval method, holding either a DataFrame or a NumPy array.

>>> import systemml as sml
>>> import numpy as np
>>> sml.setSparkContext(sc)

Welcome to Apache SystemML!

>>> m1 = sml.matrix(np.ones((3,3)) + 2)
>>> m2 = sml.matrix(np.ones((3,3)) + 3)
>>> m2 = m1 * (m2 + m1)
>>> m4 = 1.0 - m2
>>> m4
# This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.
mVar1 = load(" ", format="csv")
mVar2 = load(" ", format="csv")
mVar3 = mVar2 + mVar1
mVar4 = mVar1 * mVar3
mVar5 = 1.0 - mVar4
save(mVar5, " ")
>>> m2.eval()
>>> m2
# This matrix (mVar4) is backed by NumPy array. To fetch the NumPy array, invoke toNumPy() method.
>>> m4
# This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.
mVar4 = load(" ", format="csv")
mVar5 = 1.0 - mVar4
save(mVar5, " ")
>>> m4.sum(axis=1).toNumPy()
array([[-60.],
       [-60.],
       [-60.]])

Design Decisions:

  1. Until the eval() method is invoked, we create an AST (not exposed to the user) that consists of the unevaluated operations and the data required by those operations. As an analogy, a Spark user can treat the eval() method similar to calling RDD.persist() followed by RDD.count().
  2. The AST consists of two kinds of nodes: either of type matrix or of type DMLOp. Both these classes expose a _visit method, which helps in traversing the AST in a DFS manner.
  3. A matrix object can either be evaluated or not. If evaluated, the attribute ‘data’ is set to one of the supported types (for example: NumPy array or DataFrame) and the attribute ‘op’ is set to None. If not evaluated, the attribute ‘op’ refers to an intermediate node of the AST and is of type DMLOp; in this case, the attribute ‘data’ is set to None. A DMLOp has an attribute ‘inputs’ which contains a list of matrix objects or DMLOps.

  4. To simplify the traversal, every matrix object is considered immutable and every matrix operation creates a new matrix object. As an example: m1 = sml.matrix(np.ones((3,3))) creates a matrix object backed by ‘data=np.ones((3,3))’. m1 = m1 * 2 will create a new matrix object which is now backed by ‘op=DMLOp( ... )’ whose input is the earlier created matrix object.

  5. Left indexing (implemented in the __setitem__ method) is a special case, where Python expects the existing object to be mutated. To ensure the above property, we make a deep copy of the existing object and point any references to the left-indexed matrix to the newly created object. Then the left-indexed matrix is set to be backed by a DMLOp consisting of the following PyDML: left-indexed-matrix = new-deep-copied-matrix; left-indexed-matrix[index] = value.

  6. Please use m.printAST() and/or type m for debugging. Here is a sample session:

    >>> npm = np.ones((3,3))
    >>> m1 = sml.matrix(npm + 3)
    >>> m2 = sml.matrix(npm + 5)
    >>> m3 = m1 + m2
    >>> m3
    mVar2 = load(" ", format="csv")
    mVar1 = load(" ", format="csv")
    mVar3 = mVar1 + mVar2
    save(mVar3, " ")
    >>> m3.printAST()
    - [mVar3] (op).
      - [mVar1] (data).
      - [mVar2] (data).    
    
argmax(axis=None)[source]

Returns the indices of the maximum values along an axis.

axis : int, optional (only axis=1, i.e. rowIndexMax is supported in this version)

argmin(axis=None)[source]

Returns the indices of the minimum values along an axis.

axis : int, optional (only axis=1, i.e. rowIndexMin is supported in this version)

cumsum(axis=None)[source]

Returns the cumulative sum along an axis.

axis : int, optional (only axis=0, i.e. cumsum along the rows is supported in this version)

dml = []
dot(other)[source]

Numpy way of performing matrix multiplication

eval(outputDF=False)[source]

This is a convenience function that calls the global eval method

max(axis=None)[source]

Compute the maximum value along the specified axis

axis : int, optional

mean(axis=None)[source]

Compute the arithmetic mean along the specified axis

axis : int, optional

min(axis=None)[source]

Compute the minimum value along the specified axis

axis : int, optional

ml = None
printAST(numSpaces=0)[source]

Please use m.printAST() and/or type m for debugging. Here is a sample session:

>>> npm = np.ones((3,3))
>>> m1 = sml.matrix(npm + 3)
>>> m2 = sml.matrix(npm + 5)
>>> m3 = m1 + m2
>>> m3
mVar2 = load(" ", format="csv")
mVar1 = load(" ", format="csv")
mVar3 = mVar1 + mVar2
save(mVar3, " ")
>>> m3.printAST()
- [mVar3] (op).
  - [mVar1] (data).
  - [mVar2] (data).
script = None
sd(axis=None)[source]

Compute the standard deviation along the specified axis

axis : int, optional

sum(axis=None)[source]

Compute the sum along the specified axis

axis : int, optional

systemmlVarID = 0
toDF()[source]

This is a convenience function that calls the global eval method and then converts the matrix object into DataFrame.

toNumPy()[source]

This is a convenience function that calls the global eval method and then converts the matrix object into NumPy array.

toPandas()[source]

This is a convenience function that calls the global eval method and then converts the matrix object into Pandas DataFrame.

trace()[source]

Return the sum of the cells on the main diagonal of a square matrix.

transpose()[source]

Transposes the matrix.

var(axis=None)[source]

Compute the variance along the specified axis

axis : int, optional

visited = []
systemml.eval(outputs, outputDF=False, execute=True)[source]

Executes the unevaluated DML script and computes the matrices specified by outputs.

outputs: list of matrices or a matrix object.
outputDF: if True, back the evaluated data with a PySpark DataFrame instead of a NumPy array.

systemml.solve(A, b)[source]

Computes the least-squares solution for the system of linear equations A %*% x = b

>>> import numpy as np
>>> from sklearn import datasets
>>> import systemml as sml
>>> from pyspark.sql import SQLContext
>>> diabetes = datasets.load_diabetes()
>>> diabetes_X = diabetes.data[:, np.newaxis, 2]
>>> X_train = diabetes_X[:-20]
>>> X_test = diabetes_X[-20:]
>>> y_train = diabetes.target[:-20]
>>> y_test = diabetes.target[-20:]
>>> sml.setSparkContext(sc)
>>> X = sml.matrix(X_train)
>>> y = sml.matrix(y_train)
>>> A = X.transpose().dot(X)
>>> b = X.transpose().dot(y)
>>> beta = sml.solve(A, b).toNumPy()
>>> y_predicted = X_test.dot(beta)
>>> print('Residual sum of squares: %.2f' % np.mean((y_predicted - y_test) ** 2))
Residual sum of squares: 25282.12
class systemml.DMLOp(inputs, dml=None)[source]

Bases: object

Represents an intermediate node of the abstract syntax tree (AST) created to generate the PyDML script.

printAST(numSpaces)[source]
systemml.exp(X)[source]
systemml.log(X, y=None)[source]
systemml.abs(X)[source]
systemml.sqrt(X)[source]
systemml.round(X)[source]
systemml.floor(X)[source]
systemml.ceil(X)[source]
systemml.sin(X)[source]
systemml.cos(X)[source]
systemml.tan(X)[source]
systemml.asin(X)[source]
systemml.acos(X)[source]
systemml.atan(X)[source]
systemml.sign(X)[source]
systemml.getNumCols(numPyArr)[source]
systemml.convertToMatrixBlock(sc, src)[source]
systemml.convertToNumPyArr(sc, mb)[source]
systemml.convertToPandasDF(X)[source]
systemml.convertToLabeledDF(sqlCtx, X, y=None)[source]

Indices and tables