Jraph API

GraphsTuple

class jraph.GraphsTuple(nodes: Optional[ArrayTree], edges: Optional[ArrayTree], receivers: Optional[jnp.ndarray], senders: Optional[jnp.ndarray], globals: Optional[ArrayTree], n_node: jnp.ndarray, n_edge: jnp.ndarray)

An ordered collection of graphs in a sparse format.

The values of nodes, edges and globals can be ArrayTrees: nests of features with JAX-compatible values. For example, nodes in a graph may have more than one type of attribute.

For a batch of n graphs, a GraphsTuple typically takes the following form:

  • n_node: The number of nodes per graph. It is a vector of integers with shape [n_graphs], such that graph.n_node[i] is the number of nodes in the i-th graph.

  • n_edge: The number of edges per graph. It is a vector of integers with shape [n_graphs], such that graph.n_edge[i] is the number of edges in the i-th graph.

  • nodes: The node features. It is either None (the graph has no node features), or an array of shape [n_nodes] + node_shape, where n_nodes = sum(graph.n_node) is the total number of nodes in the batch of graphs, and node_shape is the shape of the features of each node. The position of a node within its own graph can be recovered with the help of graph.n_node. For instance, the features of the second node of the third graph are stored at index 1 + graph.n_node[0] + graph.n_node[1] of graph.nodes. Note that a None value for this field does not mean that the graphs have no nodes, only that they have no node features.

  • edges: The edge features. It is either None (the graph has no edge features), or an array of shape [n_edges] + edge_shape, where n_edges = sum(graph.n_edge) is the total number of edges in the batch of graphs, and edge_shape is the shape of the features of each edge.

    The position of an edge within its own graph can be recovered with the help of graph.n_edge. For instance, the features of the third edge of the third graph are stored at index 2 + graph.n_edge[0] + graph.n_edge[1] of graph.edges.

    Having a None value for this field does not necessarily mean that the graph has no edges, only that they do not have edge features.

  • receivers: The indices of the receiver nodes, for each edge. It is either None (if the graph has no edges), or a vector of integers of shape [n_edges], such that graph.receivers[i] is the index of the node receiving from the i-th edge.

    Note that the index is absolute (in other words, cumulative), i.e. graph.receivers takes values in [0, n_nodes). For instance, an edge from the vertex with relative index 2 to the vertex with relative index 3 in the second graph of the batch would have a receivers value of 3 + graph.n_node[0]. If graph.receivers is None, then graph.edges and graph.senders should also be None.

  • senders: The indices of the sender nodes, for each edge. It is either None (if the graph has no edges), or a vector of integers of shape [n_edges], such that graph.senders[i] is the index of the node sending from the i-th edge.

    Note that the index is absolute, i.e. graph.senders takes values in [0, n_nodes). For instance, an edge from the vertex with relative index 1 to the vertex with relative index 3 in the third graph of the batch would have a senders value of 1 + graph.n_node[0] + graph.n_node[1].

    If graphs.senders is None, then graphs.edges and graphs.receivers should also be None.

  • globals: The global features of the graphs. It is either None (the graphs have no global features), or an array of shape [n_graphs] + global_shape representing graph-level features.

Batching & Padding Utilities

jraph.batch(graphs)

Returns a batched graph given a list of graphs.

This function concatenates the nodes, edges, globals, n_node and n_edge of a sequence of GraphsTuples along axis 0. For senders and receivers, offsets are computed so that connectivity remains valid for the new node indices.

For example:

import jax
import jax.numpy as jnp
import jraph

key = jax.random.PRNGKey(0)
graph_1 = jraph.GraphsTuple(nodes=jax.random.normal(key, (3, 64)),
                            edges=jax.random.normal(key, (5, 64)),
                            senders=jnp.array([0, 0, 1, 1, 2]),
                            receivers=jnp.array([1, 2, 0, 2, 1]),
                            n_node=jnp.array([3]),
                            n_edge=jnp.array([5]),
                            globals=jax.random.normal(key, (1, 64)))
graph_2 = jraph.GraphsTuple(nodes=jax.random.normal(key, (5, 64)),
                            edges=jax.random.normal(key, (10, 64)),
                            senders=jnp.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]),
                            receivers=jnp.array([1, 2, 0, 2, 1, 0, 2, 1, 3, 2]),
                            n_node=jnp.array([5]),
                            n_edge=jnp.array([10]),
                            globals=jax.random.normal(key, (1, 64)))
batch = jraph.batch([graph_1, graph_2])

batch.nodes.shape
>> (8, 64)
batch.edges.shape
>> (15, 64)
# Offsets computed on senders and receivers
batch.senders
>> DeviceArray([0, 0, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7], dtype=int32)
batch.receivers
>> DeviceArray([1, 2, 0, 2, 1, 4, 5, 3, 5, 4, 3, 5, 4, 6, 5], dtype=int32)
batch.n_node
>> DeviceArray([3, 5], dtype=int32)
batch.n_edge
>> DeviceArray([5, 10], dtype=int32)

If a GraphsTuple does not contain any graphs, it will be dropped from the batch.

This function cannot be jit-compiled because it is data-dependent.

This functionality was implemented as utils_tf.concat in the TensorFlow version of graph_nets.

Parameters

graphs (Sequence[GraphsTuple]) – sequence of GraphsTuples which will be batched into a single graph.

Return type

GraphsTuple

jraph.unbatch(graph)

Returns a list of graphs given a batched graph.

This function does not support jax.jit, because the shape of the output is data-dependent!

Parameters

graph (GraphsTuple) – the batched graph, which will be unbatched into a list of graphs.

Return type

List[GraphsTuple]

jraph.pad_with_graphs(graph, n_node, n_edge, n_graph=2)

Pads a GraphsTuple to the requested size by adding computation-preserving graphs.

The GraphsTuple is padded by first adding a dummy graph which contains the padding nodes and edges, and then empty graphs without nodes or edges.

The empty graphs and the dummy graph do not interfere with the GraphNet calculations on the original graph, and so are computation-preserving.

The padded size must leave room for at least one extra node and one extra graph, because the dummy graph needs at least one node.

This function does not support jax.jit, because the shape of the output is data-dependent.

Parameters
  • graph (GraphsTuple) – the GraphsTuple to be padded.

  • n_node (int) – the number of nodes in the padded GraphsTuple.

  • n_edge (int) – the number of edges in the padded GraphsTuple.

  • n_graph (int) – the number of graphs in the padded GraphsTuple. Default is 2, which is the lowest possible value, because we always have at least one graph in the original GraphsTuple and we need one dummy graph for the padding.

Raises
  • ValueError – if the passed n_graph is smaller than 2.

  • RuntimeError – if the given GraphsTuple is too large for the given padding.

Return type

GraphsTuple

Returns

A padded GraphsTuple.

jraph.get_number_of_padding_with_graphs_graphs(padded_graph)

Returns the number of padding graphs in padded_graph.

Warning: This method only gives correct results for graphs that have been padded with pad_with_graphs.

Parameters

padded_graph (GraphsTuple) – a GraphsTuple that has been padded with pad_with_graphs.

Return type

int

Returns

The number of padding graphs.

jraph.get_number_of_padding_with_graphs_nodes(padded_graph)

Returns the number of padding nodes in the given padded_graph.

Warning: This method only gives correct results for graphs that have been padded with pad_with_graphs.

Parameters

padded_graph (GraphsTuple) – a GraphsTuple that has been padded with pad_with_graphs.

Return type

int

Returns

The number of padding nodes.

jraph.get_number_of_padding_with_graphs_edges(padded_graph)

Returns the number of padding edges in the given padded_graph.

Warning: This method only gives correct results for graphs that have been padded with pad_with_graphs.

Parameters

padded_graph (GraphsTuple) – a GraphsTuple that has been padded with pad_with_graphs.

Return type

int

Returns

The number of padding edges.

jraph.unpad_with_graphs(padded_graph)

Unpads the given graph by removing the dummy graph and empty graphs.

This function assumes that the given graph was padded with the pad_with_graphs function.

This function does not support jax.jit, because the shape of the output is data-dependent!

Parameters

padded_graph (GraphsTuple) – GraphsTuple padded with a dummy graph and empty graphs.

Return type

GraphsTuple

Returns

The unpadded graph.

jraph.get_node_padding_mask(padded_graph)

Returns a mask for the nodes of a padded graph.

Parameters

padded_graph (GraphsTuple) – GraphsTuple padded using pad_with_graphs. This graph must contain at least one array of node features, so that the total number of nodes can be inferred statically from its shape and the method can be jitted.

Return type

ArrayTree

Returns

Boolean array of shape [total_num_nodes] containing True for real nodes, and False for padding nodes.

jraph.get_edge_padding_mask(padded_graph)

Returns a mask for the edges of a padded graph.

Parameters

padded_graph (GraphsTuple) – GraphsTuple padded using pad_with_graphs.

Return type

ArrayTree

Returns

Boolean array of shape [total_num_edges] containing True for real edges, and False for padding edges.

jraph.get_graph_padding_mask(padded_graph)

Returns a mask for the graphs of a padded graph.

Parameters

padded_graph (GraphsTuple) – GraphsTuple padded using pad_with_graphs.

Return type

ArrayTree

Returns

Boolean array of shape [total_num_graphs] containing True for real graphs, and False for padding graphs.

Segment Utilities

jraph.segment_mean(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False)

Returns the mean for each segment.

Parameters
  • data (ndarray) – the values which are averaged segment-wise.

  • segment_ids (ndarray) – indices for the segments.

  • num_segments (Optional[int]) – total number of segments.

  • indices_are_sorted (bool) – whether segment_ids is known to be sorted.

  • unique_indices (bool) – whether segment_ids is known to be free of duplicates.

jraph.segment_max(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False)

Alias for jax.ops.segment_max.

Parameters
  • data (ndarray) – an array with the values to be maxed over.

  • segment_ids (ndarray) – an array with integer dtype that indicates the segments of data (along its leading axis) to be maxed over. Values can be repeated and need not be sorted. Values outside of the range [0, num_segments) are dropped and do not contribute to the result.

  • num_segments (Optional[int]) – optional, an int with positive value indicating the number of segments. The default is jnp.maximum(jnp.max(segment_ids) + 1, jnp.max(-segment_ids)) but since num_segments determines the size of the output, a static value must be provided to use segment_max in a jit-compiled function.

  • indices_are_sorted (bool) – whether segment_ids is known to be sorted

  • unique_indices (bool) – whether segment_ids is known to be free of duplicates

Returns

An array with shape (num_segments,) + data.shape[1:] representing the segment maxima.

jraph.segment_softmax(logits, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False)

Computes a segment-wise softmax.

For a given tree of logits that can be divided into segments, computes a softmax over the segments. For example:

logits = jnp.array([1.0, 2.0, 3.0, 1.0, 2.0])
segment_ids = jnp.array([0, 0, 0, 1, 1])
segment_softmax(logits, segment_ids)
>> DeviceArray([0.09003057, 0.24472848, 0.66524094, 0.26894142, 0.7310586],
>>             dtype=float32)

Parameters
  • logits (ndarray) – an array of logits to be segment softmaxed.

  • segment_ids (ndarray) – an array with integer dtype that indicates the segments of data (along its leading axis) over which the softmax is computed. Values can be repeated and need not be sorted. Values outside of the range [0, num_segments) are dropped and do not contribute to the result.

  • num_segments (Optional[int]) – optional, an int with positive value indicating the number of segments. The default is jnp.maximum(jnp.max(segment_ids) + 1, jnp.max(-segment_ids)) but since num_segments determines the size of the output, a static value must be provided to use segment_softmax in a jit-compiled function.

  • indices_are_sorted (bool) – whether segment_ids is known to be sorted

  • unique_indices (bool) – whether segment_ids is known to be free of duplicates

Return type

ArrayTree

Returns

The segment softmax-ed logits.

jraph.partition_softmax(logits, partitions, sum_partitions=None)

Compute a softmax within partitions of an array.

For example:

logits = jnp.array([1.0, 2.0, 3.0, 1.0, 2.0])
partitions = jnp.array([3, 2])
partition_softmax(logits, partitions)
>> DeviceArray([0.09003057, 0.24472848, 0.66524094, 0.26894142, 0.7310586],
>>             dtype=float32)

Parameters
  • logits (ArrayTree) – the logits for the softmax.

  • partitions (ndarray) – the number of nodes per graph. It is a vector of integers with shape [n_graphs], such that graph.n_node[i] is the number of nodes in the i-th graph.

  • sum_partitions (Optional[int]) – the sum of n_node. If it is not passed, the output shape is data-dependent, so the method cannot be jitted.

Returns

The softmax over partitions.

Misc Utilities

jraph.concatenated_args(update=None, *, axis=-1)

Decorator that concatenates arguments before being passed to an update_fn.

By default node, edge and global features are passed separately to update functions. However, it is common practice to concatenate these features before passing them to a neural network. This wrapper concatenates the arguments for you.

For example:

import jax.numpy as jnp
import jraph

def net(x):
  # Hypothetical stand-in for a neural network.
  return 2.0 * x

# Without the wrapper
def update_node_fn(nodes, receivers, globals):
  return net(jnp.concatenate([nodes, receivers, globals], axis=-1))

# With the wrapper
@jraph.concatenated_args
def update_node_fn(features):
  return net(features)

Parameters
  • update (Optional[Callable[…, ArrayTree]]) – an update function that takes jnp.ndarray.

  • axis (int) – the axis upon which to concatenate.

Return type

Union[Callable[…, ArrayTree], Callable[[Callable[…, ArrayTree]], ArrayTree]]

Returns

A wrapped function with the arguments concatenated.

Models

jraph.GraphNetwork(update_edge_fn, update_node_fn, update_global_fn=None, aggregate_edges_for_nodes_fn=segment_sum, aggregate_nodes_for_globals_fn=segment_sum, aggregate_edges_for_globals_fn=segment_sum, attention_logit_fn=None, attention_normalize_fn=segment_softmax, attention_reduce_fn=None)

Returns a method that applies a configured GraphNetwork.

This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261

There is one difference: for the node update, the edges a node sends and the edges it receives are aggregated separately. This is slightly more general than the algorithm described in the paper; the original behaviour is recovered by using only the received-edge aggregation in the node update.

In addition, this implementation supports softmax attention over incoming edge features.

Example usage:

gn = jraph.GraphNetwork(update_edge_fn=update_edge_function,
                        update_node_fn=update_node_function, **kwargs)
# Conduct multiple rounds of message passing with the same parameters:
for _ in range(num_message_passing_steps):
  graph = gn(graph)

Parameters
  • update_edge_fn (Optional[Callable[[ArrayTree, ArrayTree, ArrayTree, ArrayTree], ArrayTree]]) – function used to update the edges or None to deactivate edge updates.

  • update_node_fn (Optional[Callable[[ArrayTree, ArrayTree, ArrayTree, ArrayTree], ArrayTree]]) – function used to update the nodes or None to deactivate node updates.

  • update_global_fn (Optional[Callable[[ArrayTree, ArrayTree, ArrayTree], ArrayTree]]) – function used to update the globals or None to deactivate globals updates.

  • aggregate_edges_for_nodes_fn (Callable[[ArrayTree, ndarray, int], ArrayTree]) – function used to aggregate messages to each node.

  • aggregate_nodes_for_globals_fn (Callable[[ArrayTree, ndarray, int], ArrayTree]) – function used to aggregate the nodes for the globals.

  • aggregate_edges_for_globals_fn (Callable[[ArrayTree, ndarray, int], ArrayTree]) – function used to aggregate the edges for the globals.

  • attention_logit_fn (Optional[Callable[[ArrayTree, ArrayTree, ArrayTree, ArrayTree], ArrayTree]]) – function used to calculate the attention weights or None to deactivate attention mechanism.

  • attention_normalize_fn (Optional[Callable[[ArrayTree, ndarray, int], ArrayTree]]) – function used to normalize raw attention logits or None if attention mechanism is not active.

  • attention_reduce_fn (Optional[Callable[[ArrayTree, ArrayTree], ArrayTree]]) – function used to apply weights to the edge features or None if attention mechanism is not active.

Returns

A method that applies the configured GraphNetwork.

jraph.InteractionNetwork(update_edge_fn, update_node_fn, aggregate_edges_for_nodes_fn=segment_sum, include_sent_messages_in_node_update=False)

Returns a method that applies a configured InteractionNetwork.

An interaction network computes interactions on the edges based on the previous edge features and on the features of the nodes sending into those edges. It then updates the nodes based on the incoming updated edges. See https://arxiv.org/abs/1612.00222 for more details.

This implementation adds an option not in https://arxiv.org/abs/1612.00222, which is to include edge features for which a node is a sender in the arguments to the node update function.

Parameters
  • update_edge_fn (Callable[[ArrayTree, ArrayTree, ArrayTree], ArrayTree]) – a function mapping the edge update inputs to a single edge feature.

  • update_node_fn (Callable[…, ArrayTree]) – a function mapping the node update inputs to a single node feature.

  • aggregate_edges_for_nodes_fn (Callable[[ArrayTree, ndarray, int], ArrayTree]) – function used to aggregate messages to each node.

  • include_sent_messages_in_node_update (bool) – pass edge features for which a node is a sender to the node update function.

jraph.GraphMapFeatures(embed_edge_fn=None, embed_node_fn=None, embed_global_fn=None)

Returns a function that embeds the components of a graph independently.

Parameters
  • embed_edge_fn (Optional[Callable[[ArrayTree], ArrayTree]]) – function used to embed the edges.

  • embed_node_fn (Optional[Callable[[ArrayTree], ArrayTree]]) – function used to embed the nodes.

  • embed_global_fn (Optional[Callable[[ArrayTree], ArrayTree]]) – function used to embed the globals.

jraph.RelationNetwork(update_edge_fn, update_global_fn, aggregate_edges_for_globals_fn=segment_sum)

Returns a method that applies a Relation Network.

See https://arxiv.org/abs/1706.01427 for more details.

This implementation has one more argument, aggregate_edges_for_globals_fn, which changes how edge features are aggregated. The paper uses the default, utils.segment_sum.

Parameters
  • update_edge_fn (Callable[[ArrayTree, ArrayTree], ArrayTree]) – function used to update the edges.

  • update_global_fn (Callable[[ArrayTree], ArrayTree]) – function used to update the globals.

  • aggregate_edges_for_globals_fn (Callable[[ArrayTree, ndarray, int], ArrayTree]) – function used to aggregate the edges for the globals.

jraph.DeepSets(update_node_fn, update_global_fn, aggregate_nodes_for_globals_fn=segment_sum)

Returns a method that applies a DeepSets layer.

Implementation for the model described in https://arxiv.org/abs/1703.06114 (M. Zaheer, S. Kottur, S. Ravanbakhsh, B. Poczos, R. Salakhutdinov, A. Smola). See also PointNet (https://arxiv.org/abs/1612.00593, C. Qi, H. Su, K. Mo, L. J. Guibas) for a related model.

This module operates on sets, which can be thought of as graphs without edges. The node features are first updated based on their value and the global features, and new global features are then computed based on the updated node features.

Parameters
  • update_node_fn (Callable[[ArrayTree, ArrayTree], ArrayTree]) – function used to update the nodes.

  • update_global_fn (Callable[[ArrayTree], ArrayTree]) – function used to update the globals.

  • aggregate_nodes_for_globals_fn (Callable[[ArrayTree, ndarray, int], ArrayTree]) – function used to aggregate the nodes for the globals.

jraph.GraphNetGAT(update_edge_fn, update_node_fn, attention_logit_fn, attention_reduce_fn, update_global_fn=None, aggregate_edges_for_nodes_fn=segment_sum, aggregate_nodes_for_globals_fn=segment_sum, aggregate_edges_for_globals_fn=segment_sum)

Returns a method that applies a GraphNet with attention on edge features.

Parameters
  • update_edge_fn (Callable[[ArrayTree, ArrayTree, ArrayTree, ArrayTree], ArrayTree]) – function used to update the edges.

  • update_node_fn (Callable[[ArrayTree, ArrayTree, ArrayTree, ArrayTree], ArrayTree]) – function used to update the nodes.

  • attention_logit_fn (Callable[[ArrayTree, ArrayTree, ArrayTree, ArrayTree], ArrayTree]) – function used to calculate the attention weights.

  • attention_reduce_fn (Callable[[ArrayTree, ArrayTree], ArrayTree]) – function used to apply attention weights to the edge features.

  • update_global_fn (Optional[Callable[[ArrayTree, ArrayTree, ArrayTree], ArrayTree]]) – function used to update the globals or None to deactivate globals updates.

  • aggregate_edges_for_nodes_fn (Callable[[ArrayTree, ndarray, int], ArrayTree]) – function used to aggregate attention-weighted messages to each node.

  • aggregate_nodes_for_globals_fn (Callable[[ArrayTree, ndarray, int], ArrayTree]) – function used to aggregate the nodes for the globals.

  • aggregate_edges_for_globals_fn (Callable[[ArrayTree, ndarray, int], ArrayTree]) – function used to aggregate attention-weighted edges for the globals.

Returns

A function that applies a GraphNet Graph Attention layer.

jraph.GAT(attention_query_fn, attention_logit_fn, node_update_fn=None)

Returns a method that applies a Graph Attention Network layer.

Graph Attention message passing as described in https://arxiv.org/abs/1710.10903. This model expects node features as a jnp.array, may use edge features for computing attention weights, and ignores global features. It does not support nests.

NOTE: this implementation assumes that the input graph has self edges. To recover the behavior of the referenced paper, please add self edges.

Parameters
  • attention_query_fn (Callable[[ArrayTree], ArrayTree]) – function that generates attention queries from sender node features.

  • attention_logit_fn (Callable[[ArrayTree, ArrayTree, ArrayTree], ArrayTree]) – function that converts attention queries into logits for softmax attention.

  • node_update_fn (Optional[Callable[[ArrayTree], ArrayTree]]) – function that updates the aggregated messages. If None, will apply leaky relu and concatenate (if using multi-head attention).

Returns

A function that applies a Graph Attention layer.

jraph.GraphConvolution(update_node_fn, aggregate_nodes_fn=segment_sum, add_self_edges=False, symmetric_normalization=True)

Returns a method that applies a Graph Convolution layer.

Graph Convolutional layer as in https://arxiv.org/abs/1609.02907.

NOTE: This implementation does not add an activation after aggregation. If you are stacking layers, you may want to add an activation between each layer.

Parameters
  • update_node_fn (Callable[[ArrayTree], ArrayTree]) – function used to update the nodes. In the paper a single-layer MLP is used.

  • aggregate_nodes_fn (Callable[[ArrayTree, ndarray, int], ArrayTree]) – function used to aggregate the sender nodes.

  • add_self_edges (bool) – whether to add self edges to nodes in the graph as in the paper definition of GCN. Defaults to False.

  • symmetric_normalization (bool) – whether to use symmetric normalization. Defaults to True.

Returns

A method that applies a Graph Convolution layer.