跨语言编程#
此页面展示了如何使用 Ray 的跨语言编程功能。
设置驱动程序#
您需要在驱动程序中设置代码搜索路径。
import ray
ray.init(job_config=ray.job_config.JobConfig(code_search_path=["/path/to/code"]))
java -classpath <classpath> \
-Dray.address=<address> \
-Dray.job.code-search-path=/path/to/code/ \
<classname> <args>
如果您将 Python 和 Java 代码放在不同的目录中,可以在代码搜索路径中指定多个目录,以便 worker 同时加载这两种代码。
import ray

# JobConfig takes code_search_path as a list of directories (one entry per
# path), not a single colon-separated string.
ray.init(job_config=ray.job_config.JobConfig(code_search_path=["/path/to/jars", "/path/to/pys"]))
java -classpath <classpath> \
-Dray.address=<address> \
-Dray.job.code-search-path=/path/to/jars:/path/to/pys \
<classname> <args>
Python 调用 Java#
假设您有一个 Java 静态方法和一个 Java 类,如下所示
package io.ray.demo;
/** Holds a static method that this page calls from Python as a Java remote function. */
public class Math {
  /** Returns the sum of {@code a} and {@code b}. */
  public static int add(int a, int b) {
    return a + b;
  }
}
package io.ray.demo;
// A regular Java class.
/** Counter whose value starts at 0; this page creates it as a Java actor from Python. */
public class Counter {
  private int value = 0;

  /** Adds 1 to the counter and returns the updated value. */
  public int increment() {
    this.value += 1;
    return this.value;
  }
}
然后,在 Python 中,您可以调用前面提到的 Java 远程函数,或从前面提到的 Java 类创建 actor。
import ray
with ray.init(job_config=ray.job_config.JobConfig(code_search_path=["/path/to/code"])):
    # Define a Java class.
    counter_class = ray.cross_language.java_actor_class(
        "io.ray.demo.Counter")

    # Create a Java actor and call actor method.
    counter = counter_class.remote()
    obj_ref1 = counter.increment.remote()
    assert ray.get(obj_ref1) == 1
    obj_ref2 = counter.increment.remote()
    assert ray.get(obj_ref2) == 2

    # Define a Java function.
    add_function = ray.cross_language.java_function(
        "io.ray.demo.Math", "add")

    # Call the Java remote function.
    obj_ref3 = add_function.remote(1, 2)
    assert ray.get(obj_ref3) == 3
Java 调用 Python#
假设您有一个 Python 模块,如下所示
# /path/to/the_dir/ray_demo.py
import ray
@ray.remote
class Counter(object):
    """Counter that starts at 0; declared as a Ray actor class via @ray.remote."""

    def __init__(self):
        self.value = 0

    def increment(self):
        """Add 1 to the counter and return the new value."""
        self.value += 1
        return self.value
@ray.remote
def add(a, b):
    """Return ``a + b``."""
    return a + b
注意
您需要使用 @ray.remote 装饰该函数或类。
然后,在 Java 中,您可以调用前面提到的 Python 远程函数,或从前面提到的 Python 类创建 actor。
package io.ray.demo;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.Ray;
import io.ray.api.function.PyActorClass;
import io.ray.api.function.PyActorMethod;
import io.ray.api.function.PyFunction;
import org.testng.Assert;
/**
 * Demo driver that creates the Python {@code Counter} actor and calls the Python {@code add}
 * remote function, both from the {@code ray_demo} module.
 */
public class JavaCallPythonDemo {

  public static void main(String[] args) {
    // Set the code-search-path to the directory of your `ray_demo.py` file.
    System.setProperty("ray.job.code-search-path", "/path/to/the_dir/");
    Ray.init();

    // Define a Python class.
    PyActorClass actorClass = PyActorClass.of(
        "ray_demo", "Counter");

    // Create a Python actor and call actor method.
    PyActorHandle actor = Ray.actor(actorClass).remote();
    ObjectRef objRef1 = actor.task(
        PyActorMethod.of("increment", int.class)).remote();
    Assert.assertEquals(objRef1.get(), 1);
    ObjectRef objRef2 = actor.task(
        PyActorMethod.of("increment", int.class)).remote();
    Assert.assertEquals(objRef2.get(), 2);

    // Call the Python remote function.
    ObjectRef objRef3 = Ray.task(PyFunction.of(
        "ray_demo", "add", int.class), 1, 2).remote();
    Assert.assertEquals(objRef3.get(), 3);

    Ray.shutdown();
  }
}
跨语言数据序列化#
如果 Ray 调用的参数和返回值的类型是以下类型,Ray 会自动序列化和反序列化它们
- 基本数据类型
| MessagePack | Python | Java |
| --- | --- | --- |
| nil | None | null |
| bool | bool | Boolean |
| int | int | Short / Integer / Long / BigInteger |
| float | float | Float / Double |
| str | str | String |
| bin | bytes | byte[] |
- 基本容器类型
| MessagePack | Python | Java |
| --- | --- | --- |
| array | list | Array |
- Ray 内置类型
ActorHandle
注意
请注意 Python 和 Java 之间的 float / double 精度差异。如果 Java 使用 float 类型接收输入参数,则双精度 Python 数据在 Java 中会降至 float 精度。
BigInteger 支持的最大值为 2^64-1。参见:msgpack/msgpack。如果值大于 2^64-1,则将该值发送到 Python 会引发异常。
以下示例展示了如何将这些类型作为参数传递以及如何返回这些类型。
您可以编写一个返回输入数据的 Python 函数
# ray_serialization.py
import ray
@ray.remote
def py_return_input(v):
    """Return the argument ``v`` unchanged."""
    return v
然后,您可以将对象从 Java 传递到 Python,再从 Python 传回 Java
package io.ray.demo;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.function.PyFunction;
import java.math.BigInteger;
import org.testng.Assert;
/**
 * Demo driver that sends each sample value to the Python {@code py_return_input} function and
 * asserts that the value returned equals the value sent.
 */
public class SerializationDemo {

  public static void main(String[] args) {
    Ray.init();

    Object[] inputs = new Object[]{
        true,  // Boolean
        Byte.MAX_VALUE,  // Byte
        Short.MAX_VALUE,  // Short
        Integer.MAX_VALUE,  // Integer
        Long.MAX_VALUE,  // Long
        BigInteger.valueOf(Long.MAX_VALUE),  // BigInteger
        "Hello World!",  // String
        1.234f,  // Float
        1.234,  // Double
        "example binary".getBytes()};  // byte[]
    for (Object o : inputs) {
      // The runtime class of each input is passed as the expected return type.
      ObjectRef res = Ray.task(
          PyFunction.of("ray_serialization", "py_return_input", o.getClass()),
          o).remote();
      Assert.assertEquals(res.get(), o);
    }

    Ray.shutdown();
  }
}
跨语言异常栈#
假设您有一个 Java 包,如下所示
package io.ray.demo;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.function.PyFunction;
/** Java side of the cross-language exception example. */
public class MyRayClass {

  /**
   * Submits the Python function {@code raise_exception} from module {@code ray_exception} as a
   * Ray task and returns the result of {@code get()} on its object reference.
   */
  public static int raiseExceptionFromPython() {
    PyFunction<Integer> raiseException = PyFunction.of(
        "ray_exception", "raise_exception", Integer.class);
    ObjectRef<Integer> refObj = Ray.task(raiseException).remote();
    return refObj.get();
  }
}
以及一个 Python 模块,如下所示
# ray_exception.py
import ray
@ray.remote
def raise_exception():
    1 / 0
    # The division above raises ZeroDivisionError; it is the innermost
    # error shown in the exception stack on this page.
然后,运行以下代码
# ray_exception_demo.py
import ray
with ray.init(job_config=ray.job_config.JobConfig(code_search_path=["/path/to/ray_exception"])):
    # Call the Java static method as a remote function.
    obj_ref = ray.cross_language.java_function(
        "io.ray.demo.MyRayClass",
        "raiseExceptionFromPython").remote()
    ray.get(obj_ref) # <-- raise exception from here.
异常栈将是
Traceback (most recent call last):
File "ray_exception_demo.py", line 9, in <module>
ray.get(obj_ref) # <-- raise exception from here.
File "ray/python/ray/_private/client_mode_hook.py", line 105, in wrapper
return func(*args, **kwargs)
File "ray/python/ray/_private/worker.py", line 2247, in get
raise value
ray.exceptions.CrossLanguageError: An exception raised from JAVA:
io.ray.api.exception.RayTaskException: (pid=61894, ip=172.17.0.2) Error executing task c8ef45ccd0112571ffffffffffffffffffffffff01000000
at io.ray.runtime.task.TaskExecutor.execute(TaskExecutor.java:186)
at io.ray.runtime.RayNativeRuntime.nativeRunTaskExecutor(Native Method)
at io.ray.runtime.RayNativeRuntime.run(RayNativeRuntime.java:231)
at io.ray.runtime.runner.worker.DefaultWorker.main(DefaultWorker.java:15)
Caused by: io.ray.api.exception.CrossLanguageException: An exception raised from PYTHON:
ray.exceptions.RayTaskError: ray::raise_exception() (pid=62041, ip=172.17.0.2)
File "ray_exception.py", line 7, in raise_exception
1 / 0
ZeroDivisionError: division by zero