
Rolling Cube


Problem Statement

Consider a 3x3 grid holding eight cubes and one empty square. Each cube has one black face, a uniquely colored face opposite it, and four gray faces. Initially every cube shows its black face up and the center square is empty. A move rolls a cube that is orthogonally adjacent to the empty square into it, tipping the cube over its bottom edge (cubes are never lifted). The goal is to reach a position where every colored face points up; the code below additionally requires the empty square to end up back in the center.

Mathematical Structure

Cube States and Rotations

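Because the goal only asks for colored faces to point up, the model ignores which cube occupies which cell and tracks orientations only. Because the four gray faces are indistinguishable, a cube's orientation is fully described by where its black face points, which leaves six orientations: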

  • 0: Black face up
  • 1: Colored face up
  • 2: Gray face up, black facing north
  • 3: Gray face up, black facing east
  • 4: Gray face up, black facing south
  • 5: Gray face up, black facing west
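Why exactly six? A cube has 24 rotational positions, but with the four gray faces interchangeable only the position of the black face matters, and it can point up, down, or in one of four compass directions. The sketch below checks this by brute force. The names `SRC`, `roll_cube`, and `orientation_code` are hypothetical helpers, not part of the original code. Faces are listed in the order (top, bottom, north, east, south, west), chosen so that the index of the black face `'K'` coincides with the orientation codes above (index 1, black at the bottom, means the colored face is up).

```python
from collections import deque

# Face positions, ordered so that the index of the black face equals the
# orientation code (0 = up, 1 = down i.e. color up, 2-5 = N, E, S, W).
T, B, N, E, S, W = range(6)

# Hypothetical helper table: SRC[d][i] is the position whose face ends up
# at position i after the cube tips over in direction d.
SRC = {
    'U': (S, N, T, E, B, W),  # roll north: top -> north -> bottom -> south -> top
    'D': (N, S, B, E, T, W),  # roll south
    'R': (W, E, N, T, S, B),  # roll east
    'L': (E, W, N, B, S, T),  # roll west
}

def roll_cube(cube, d):
    """Return the face tuple after rolling the cube one square in direction d."""
    return tuple(cube[src] for src in SRC[d])

def orientation_code(cube):
    """Collapse a fully labeled cube to the six-way code: where is black?"""
    return cube.index('K')

# Label every face distinctly (K = black, C = color, g1..g4 = gray) so all
# rotational positions can be told apart, then explore them by rolling.
start = ('K', 'C', 'g1', 'g2', 'g3', 'g4')  # black up, color down
seen = {start}
queue = deque([start])
while queue:
    cube = queue.popleft()
    for d in SRC:
        nxt = roll_cube(cube, d)
        if nxt not in seen:
            seen.add(nxt)
            queue.append(nxt)

codes = sorted({orientation_code(c) for c in seen})
print(len(seen), codes)  # 24 [0, 1, 2, 3, 4, 5]
```

Each of the six codes is shared by exactly four of the 24 rotations (the quarter turns about the axis through the black face), which is why the gray faces can be ignored.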

State Space

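  • 9 possible positions for the empty square
  • 6 possible orientations for each of the 8 cubes
  • Total states: 9 × 6^8 = 9 × 1,679,616 = 15,116,544 states
  • Each state has 2-4 possible moves (2 with the empty square in a corner, 3 on an edge, 4 in the center), for 24 × 6^8 = 40,310,784 directed edges in total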
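These counts can be reproduced with a few lines of arithmetic. `neighbor_count` is a hypothetical helper that counts the cells orthogonally adjacent to a given empty square; summing it over all nine empty positions gives the number of moves per orientation assignment.

```python
def neighbor_count(r, c):
    """Number of cells orthogonally adjacent to (r, c) on a 3x3 board."""
    return sum(0 <= r + dr < 3 and 0 <= c + dc < 3
               for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)))

configs = 6 ** 8                      # orientations of the 8 cubes
states = 9 * configs                  # times 9 empty positions
moves_per_config = sum(neighbor_count(r, c) for r in range(3) for c in range(3))
edges = configs * moves_per_config    # directed edges in the full graph

print(configs, states, moves_per_config, edges)
# 1679616 15116544 24 40310784
```

The 24 comes from 4 corners × 2 moves + 4 edges × 3 moves + 1 center × 4 moves, an average of about 2.7 moves per state.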

Implementation

State Representation

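from dataclasses import dataclass
from typing import Set, Tuple

@dataclass(frozen=True)
class State:
    orientations: Tuple[Tuple[int | None, ...], ...]  # 3x3 grid of orientations (None = empty)
    empty: Tuple[int, int]                            # (row, col) of the empty cell

    def next_states(self) -> Set[Tuple['State', str]]:
        # A move letter names the direction the cube travels; row 0 is north.
        moves = {'U': (-1, 0), 'D': (1, 0), 'L': (0, -1), 'R': (0, 1)}
        er, ec = self.empty
        neighbors = set()
        for move, (dr, dc) in moves.items():
            r, c = er - dr, ec - dc  # the cube that travels in direction `move` into the empty cell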
            if 0 <= r < 3 and 0 <= c < 3:
                new_grid = [[x for x in row] for row in self.orientations]
                new_grid[er][ec] = roll(new_grid[r][c], move)
                new_grid[r][c] = None
                neighbors.add((
                    State(tuple(tuple(row) for row in new_grid), (r, c)),
                    move
                ))
        return neighbors
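Storing roughly 15 million `State` objects, each holding nested tuples, is memory-hungry. One possible alternative, an assumption and not what the original code does, is to pack a state into a single integer: the empty index (0-8) times 6^8, plus the eight orientations as base-6 digits in row-major order. A `State` can be flattened for this with `tuple(x for row in state.orientations for x in row)`; `encode` and `decode` are hypothetical names.

```python
from typing import Optional, Tuple

Grid = Tuple[Optional[int], ...]  # 9 cells in row-major order, None = empty

def encode(cells: Grid) -> int:
    """Pack a flat 3x3 grid into an int in range(9 * 6**8)."""
    empty = cells.index(None)
    code = 0
    for value in cells:
        if value is not None:
            code = code * 6 + value   # one base-6 digit per cube, row-major
    return empty * 6 ** 8 + code

def decode(code: int) -> Grid:
    """Inverse of encode: recover the flat grid from its integer code."""
    empty, rest = divmod(code, 6 ** 8)
    digits = []
    for _ in range(8):
        rest, d = divmod(rest, 6)     # least significant digit first
        digits.append(d)
    cells = list(reversed(digits))    # back to row-major order
    cells.insert(empty, None)         # put the empty cell back in place
    return tuple(cells)

start = tuple(None if i == 4 else 0 for i in range(9))  # center empty, black up
assert decode(encode(start)) == start
```

Because every code lies in `range(9 * 6**8)`, the integers could also serve directly as NetworkX node ids or as indices into a flat array of visited flags.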

Cube Rolling Mechanics

def roll(orientation: int, direction: str) -> int:
    if orientation == 0:  # black up
        return {'U': 2,   # black faces north
                'D': 4,   # black faces south
                'L': 5,   # black faces west 
                'R': 3    # black faces east
               }[direction]
    elif orientation == 1:  # color up
        return {'U': 4,   # black faces south
                'D': 2,   # black faces north
                'L': 3,   # black faces east
                'R': 5    # black faces west
               }[direction]
    
    black_facing = {2: 'N', 3: 'E', 4: 'S', 5: 'W'}[orientation]
    moves = {'N': 'U', 'S': 'D', 'E': 'R', 'W': 'L'}
    
    if direction == moves[black_facing]:        # Rolling towards black
        return 1                                # Color comes up
    opposites = {'N': 'S', 'S': 'N', 'E': 'W', 'W': 'E'}
    if direction == moves[opposites[black_facing]]:  # Rolling away from black
        return 0                                     # Black comes up
    
    return orientation  # Rolling parallel, orientation stays same
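As a cross-check of the table-driven `roll` function, the sketch below rolls an explicit six-face model of a cube and compares the resulting black-face position with `roll` for all 24 (orientation, direction) pairs. The face-permutation table `SRC` and the helpers `roll_cube` and `cube_with_black_at` are illustrative assumptions; `roll` is the article's function, condensed but with the same logic.

```python
from itertools import product

# Face positions: top, bottom, north, east, south, west. With this order the
# index of the black face 'K' equals the article's orientation code.
T, B, N, E, S, W = range(6)
OPPOSITE = {T: B, B: T, N: S, S: N, E: W, W: E}

# SRC[d][i]: which position's face lands on position i after rolling in direction d.
SRC = {
    'U': (S, N, T, E, B, W),
    'D': (N, S, B, E, T, W),
    'R': (W, E, N, T, S, B),
    'L': (E, W, N, B, S, T),
}

def roll_cube(cube, d):
    return tuple(cube[src] for src in SRC[d])

def cube_with_black_at(pos):
    """A cube with black at pos, color opposite it, gray ('G') elsewhere."""
    cube = ['G'] * 6
    cube[pos] = 'K'
    cube[OPPOSITE[pos]] = 'C'
    return tuple(cube)

def roll(orientation, direction):
    """The article's roll(), condensed."""
    if orientation == 0:
        return {'U': 2, 'D': 4, 'L': 5, 'R': 3}[direction]
    if orientation == 1:
        return {'U': 4, 'D': 2, 'L': 3, 'R': 5}[direction]
    black_facing = {2: 'N', 3: 'E', 4: 'S', 5: 'W'}[orientation]
    moves = {'N': 'U', 'S': 'D', 'E': 'R', 'W': 'L'}
    opposites = {'N': 'S', 'S': 'N', 'E': 'W', 'W': 'E'}
    if direction == moves[black_facing]:
        return 1
    if direction == moves[opposites[black_facing]]:
        return 0
    return orientation

mismatches = [(o, d) for o, d in product(range(6), 'UDLR')
              if roll_cube(cube_with_black_at(o), d).index('K') != roll(o, d)]
print(mismatches)  # []
```

An empty list means the six-case lookup agrees with the physical face permutation for every orientation and direction.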

Graph Generation

def generate_all_states() -> tuple[nx.DiGraph, Dict[State, int]]:
    G = nx.DiGraph()
    states_dict = {}
    states_count = 0
    
    # Generate all possible states
    empty_positions = [(i, j) for i in range(3) for j in range(3)]
    for empty_pos in empty_positions:
        for orientations in itertools.product(range(6), repeat=8):
            grid = [[None] * 3 for _ in range(3)]
            idx = 0
            for r in range(3):
                for c in range(3):
                    if (r, c) != empty_pos:
                        grid[r][c] = orientations[idx]
                        idx += 1
            
            state = State(tuple(tuple(row) for row in grid), empty_pos)
            states_dict[state] = states_count
            G.add_node(states_count, state=state)
            states_count += 1
    
    # Add edges for valid moves
    for state, node_id in states_dict.items():
        for next_state, move in state.next_states():
            next_id = states_dict[next_state]
            G.add_edge(node_id, next_id, move=move)
    
    return G, states_dict
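Every move can be undone by rolling the same cube straight back, so the directed graph is symmetric: each edge `u -> v` labeled with a move has a partner `v -> u` labeled with the opposite move. The sketch below checks this on a flat, row-major tuple representation, an assumption made to keep the example self-contained; `ROLL` is the article's `roll` written out as a table, and `successors` and `is_reversible` are hypothetical helpers. Because only the moving cube changes, the 54 boards with one uniform orientation and each possible empty position cover every (position, move, orientation) combination.

```python
# The article's roll() as a lookup table: ROLL[orientation][direction].
ROLL = {
    0: {'U': 2, 'D': 4, 'L': 5, 'R': 3},
    1: {'U': 4, 'D': 2, 'L': 3, 'R': 5},
    2: {'U': 1, 'D': 0, 'L': 2, 'R': 2},
    3: {'U': 3, 'D': 3, 'L': 0, 'R': 1},
    4: {'U': 0, 'D': 1, 'L': 4, 'R': 4},
    5: {'U': 5, 'D': 5, 'L': 1, 'R': 0},
}
DELTAS = {'U': (-1, 0), 'D': (1, 0), 'L': (0, -1), 'R': (0, 1)}
OPPOSITE_MOVE = {'U': 'D', 'D': 'U', 'L': 'R', 'R': 'L'}

def successors(cells):
    """Yield (move, next_cells) for a row-major 9-tuple with None as the empty cell."""
    empty = cells.index(None)
    er, ec = divmod(empty, 3)
    for move, (dr, dc) in DELTAS.items():
        r, c = er - dr, ec - dc  # the cube that travels in direction `move`
        if 0 <= r < 3 and 0 <= c < 3:
            src = 3 * r + c
            nxt = list(cells)
            nxt[empty] = ROLL[cells[src]][move]
            nxt[src] = None
            yield move, tuple(nxt)

def is_reversible(cells):
    """True if every move from `cells` is undone by the opposite move."""
    for move, nxt in successors(cells):
        if dict(successors(nxt)).get(OPPOSITE_MOVE[move]) != cells:
            return False
    return True

boards = [tuple(None if i == empty else o for i in range(9))
          for empty in range(9) for o in range(6)]
print(all(is_reversible(b) for b in boards))  # True
```

Given this symmetry, an undirected `nx.Graph` would store each pair of moves once, at the cost of recovering the move letter from the direction of travel (for example by calling `next_states`) instead of reading it from the edge.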

Finding the Solution

def find_solution(G: nx.DiGraph, states_dict: Dict[State, int]) -> str:
    initial_state = State(tuple(tuple(0 if (i != 1 or j != 1) else None 
                                    for j in range(3))
                              for i in range(3)), (1, 1))
    
    goal_state = State(tuple(tuple(1 if (i != 1 or j != 1) else None 
                                  for j in range(3))
                            for i in range(3)), (1, 1))
    
    path = nx.shortest_path(G, states_dict[initial_state], states_dict[goal_state])
    return ''.join(G[path[i]][path[i+1]]['move'] for i in range(len(path)-1))
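Building all 15 million nodes up front is not required to find a shortest path. A common alternative, swapped in here rather than taken from the original, is a breadth-first search that starts from the initial position, generates successors on demand, and records each state's parent so the move sequence can be read back. It stores only states it actually reaches, though in the worst case that is still every reachable state. The sketch uses a flat row-major tuple (None for the empty cell) and a table version of `roll`; `bfs_solve` and `replay` are hypothetical names.

```python
from collections import deque
from typing import Dict, Optional, Tuple

Cells = Tuple[Optional[int], ...]  # 9 cells, row-major, None = empty

# The article's roll() as a table: ROLL[orientation][direction].
ROLL = {
    0: {'U': 2, 'D': 4, 'L': 5, 'R': 3},
    1: {'U': 4, 'D': 2, 'L': 3, 'R': 5},
    2: {'U': 1, 'D': 0, 'L': 2, 'R': 2},
    3: {'U': 3, 'D': 3, 'L': 0, 'R': 1},
    4: {'U': 0, 'D': 1, 'L': 4, 'R': 4},
    5: {'U': 5, 'D': 5, 'L': 1, 'R': 0},
}
DELTAS = {'U': (-1, 0), 'D': (1, 0), 'L': (0, -1), 'R': (0, 1)}

def successors(cells: Cells):
    """Yield (move, next_cells); the move letter is the cube's direction of travel."""
    empty = cells.index(None)
    er, ec = divmod(empty, 3)
    for move, (dr, dc) in DELTAS.items():
        r, c = er - dr, ec - dc
        if 0 <= r < 3 and 0 <= c < 3:
            src = 3 * r + c
            nxt = list(cells)
            nxt[empty] = ROLL[cells[src]][move]
            nxt[src] = None
            yield move, tuple(nxt)

def bfs_solve(start: Cells, goal: Cells) -> Optional[str]:
    """Return a shortest move string from start to goal, or None if unreachable."""
    parent: Dict[Cells, Tuple[Optional[Cells], str]] = {start: (None, '')}
    queue = deque([start])
    while queue:
        cells = queue.popleft()
        if cells == goal:
            moves = []
            while parent[cells][0] is not None:  # walk back to the start
                cells, move = parent[cells]
                moves.append(move)
            return ''.join(reversed(moves))
        for move, nxt in successors(cells):
            if nxt not in parent:                # first visit is at minimum depth
                parent[nxt] = (cells, move)
                queue.append(nxt)
    return None

def replay(cells: Cells, moves: str) -> Cells:
    """Apply a move string; raises KeyError on an illegal move."""
    for m in moves:
        cells = dict(successors(cells))[m]
    return cells

START = tuple(None if i == 4 else 0 for i in range(9))  # black up, center empty
GOAL = tuple(None if i == 4 else 1 for i in range(9))   # color up, center empty
```

A full run would be `solution = bfs_solve(START, GOAL)` followed by `replay(START, solution) == GOAL` as a sanity check. In pure Python the full search may take a long time and a lot of memory, so it is worth trying it first on a nearby target such as `replay(START, "UL")`.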

Key Implementation Details

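  1. State Design

    • Immutable, hashable states via a frozen dataclass, so they can serve as dictionary keys and graph nodes
    • Each state stores the orientations of all nine cells plus the empty position
    • Successors are generated on demand: only the 2-4 cubes adjacent to the empty square can move
  2. Rolling Mechanics

    • Covers all six orientations; black-up and color-up cubes use fixed lookup tables
    • Three cases for gray-up states:
      • Rolling towards the black face: the colored face comes up
      • Rolling away from the black face: the black face comes up
      • Rolling parallel to it: the orientation is unchanged
  3. Graph Generation

    • Two-phase approach:
      • First generate all 15,116,544 states, whether reachable or not
      • Then add the roughly 40.3 million directed edges for valid moves
    • Simple and exhaustive, but memory-intensive: every state is a Python object and a NetworkX node
  4. Solution Finding

    • Uses NetworkX’s shortest_path, which on an unweighted graph performs a breadth-first search
    • Converts the node path into a sequence of move letters
    • The result has the minimum number of moves for the chosen goal (all colored faces up, empty square back in the center); if the goal is unreachable, shortest_path raises networkx.NetworkXNoPath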

The solution combines a compact six-value orientation model, an explicit state graph, and breadth-first search to solve a physical puzzle. Because the graph contains every state and every legal move, the shortest path it returns is optimal, and because the rolling rules follow how a physical cube tips over its edge, every move in that path can be carried out on the real puzzle.

This puzzle serves as an excellent example of how abstract mathematical concepts can be applied to solve concrete physical problems, and how proper state representation and systematic exploration can make even large state spaces manageable.

Complete Python Code

import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Tuple, Set, Dict
import networkx as nx
import itertools

@dataclass(frozen=True)
class State:
    """State of the puzzle: cube orientations and empty position.
    orientations: 0=black up, 1=color up, 2=black north, 3=black east, 4=black south, 5=black west
    empty: (row, col) of empty cell
    """
    orientations: Tuple[Tuple[int | None, ...], ...]
    empty: Tuple[int, int]

    def next_states(self) -> Set[Tuple['State', str]]:
        """Get all valid next states and their moves."""
        moves = {'U': (-1, 0), 'D': (1, 0), 'L': (0, -1), 'R': (0, 1)}
        er, ec = self.empty
        neighbors = set()

        for move, (dr, dc) in moves.items():
            r, c = er - dr, ec - dc
            if 0 <= r < 3 and 0 <= c < 3:
                new_grid = [[x for x in row] for row in self.orientations]
                new_grid[er][ec] = roll(new_grid[r][c], move)
                new_grid[r][c] = None
                neighbors.add((
                    State(tuple(tuple(row) for row in new_grid), (r, c)),
                    move
                ))
        return neighbors

def roll(orientation: int, direction: str) -> int:
    """Return new orientation after rolling."""
    if orientation == 0:  # black up
        return {'U': 2,   # black faces north
                'D': 4,   # black faces south
                'L': 5,   # black faces west 
                'R': 3    # black faces east
               }[direction]
    elif orientation == 1:  # color up
        return {'U': 4,   # black faces south
                'D': 2,   # black faces north
                'L': 3,   # black faces east
                'R': 5    # black faces west
               }[direction]
    
    black_facing = {2: 'N', 3: 'E', 4: 'S', 5: 'W'}[orientation]
    moves = {'N': 'U', 'S': 'D', 'E': 'R', 'W': 'L'}
    
    # If rolling in direction black faces, color comes up
    if direction == moves[black_facing]:
        return 1
    # If rolling opposite to black face direction, black comes up
    opposites = {'N': 'S', 'S': 'N', 'E': 'W', 'W': 'E'}
    if direction == moves[opposites[black_facing]]:
        return 0
        
    # If rolling parallel, orientation stays the same
    return orientation

def generate_all_states() -> tuple[nx.DiGraph, Dict[State, int]]:
    """Generate all possible states and their connections."""
    G = nx.DiGraph()
    states_dict = {}  # map State -> node_id
    states_count = 0
    
    # Generate all possible empty positions
    empty_positions = [(i, j) for i in range(3) for j in range(3)]
    
    # First generate all possible states
    print("Generating states...")
    for empty_pos in empty_positions:
        # For each possible orientation of the 8 cubes
        for orientations in itertools.product(range(6), repeat=8):
            # Create state grid
            grid = [[None] * 3 for _ in range(3)]
            idx = 0
            for r in range(3):
                for c in range(3):
                    if (r, c) != empty_pos:
                        grid[r][c] = orientations[idx]
                        idx += 1
            
            state = State(tuple(tuple(row) for row in grid), empty_pos)
            states_dict[state] = states_count
            G.add_node(states_count, state=state)
            states_count += 1
            
    # Now add all edges between states
    print("\nAdding edges...")
    edge_count = 0
    for state, node_id in states_dict.items():
        for next_state, move in state.next_states():
            next_id = states_dict[next_state]
            G.add_edge(node_id, next_id, move=move)
            edge_count += 1
                
    print(f"\nTotal state ids: {states_count}")
    print(f"Total edges: {edge_count}")
    return G, states_dict

def find_solution(G: nx.DiGraph, states_dict: Dict[State, int]) -> str:
    """Find shortest solution in the pre-built graph."""
    initial_state = State(tuple(tuple(0 if (i != 1 or j != 1) else None 
                                    for j in range(3))
                              for i in range(3)), (1, 1))
    
    goal_state = State(tuple(tuple(1 if (i != 1 or j != 1) else None 
                                  for j in range(3))
                            for i in range(3)), (1, 1))
    
    path = nx.shortest_path(G, states_dict[initial_state], states_dict[goal_state])
    return ''.join(G[path[i]][path[i+1]]['move'] for i in range(len(path)-1))

def visualize_states(solution: str):
    """Visualize all states in the solution path."""
    initial_state = State(tuple(tuple(0 if (i != 1 or j != 1) else None 
                                    for j in range(3))
                              for i in range(3)), (1, 1))
    
    states = [initial_state]
    current_state = initial_state
    for move in solution:
        for next_state, m in current_state.next_states():
            if m == move:
                current_state = next_state
                break
        states.append(current_state)
    
    n_states = len(states)
    cols = min(5, n_states)
    rows = (n_states + cols - 1) // cols
    # squeeze=False always returns a 2D array of axes, so axes[row][col] works
    # even when there is only one row or one column
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)
    
    colors = {
        0: 'black',
        1: 'red',
        2: 'lightgray',
        3: 'lightgray',
        4: 'lightgray',
        5: 'lightgray'
    }
    labels = {
        0: 'B',
        1: 'C',
        2: 'N',
        3: 'E',
        4: 'S',
        5: 'W'
    }
    
    for idx, state in enumerate(states):
        row, col = divmod(idx, cols)
        ax = axes[row][col]
        
        # Draw grid
        for i in range(4):
            ax.axhline(y=i, color='black', linewidth=1)
            ax.axvline(x=i, color='black', linewidth=1)
        
        # Fill squares
        for i in range(3):
            for j in range(3):
                value = state.orientations[i][j]
                if value is not None:
                    ax.fill([j, j+1, j+1, j], 
                           [3-i-1, 3-i-1, 3-i, 3-i], 
                           color=colors[value], 
                           alpha=0.3)
                    text_color = 'white' if value == 0 else 'black'
                    ax.text(j+0.5, 2.5-i, labels[value], 
                           ha='center', va='center', 
                           color=text_color)
        
        # Mark empty cell
        er, ec = state.empty
        ax.plot(ec+0.5, 2.5-er, 'rx', markersize=10)
        
        title = "Start" if idx == 0 else f"After {solution[idx-1]}"
        ax.set_title(title)
        ax.set_aspect('equal')
        ax.axis('off')
    
    for idx in range(n_states, rows * cols):
        row, col = divmod(idx, cols)
        fig.delaxes(axes[row][col])
    
    plt.tight_layout()
    plt.show()

def main():
    print("Generating complete state graph...")
    G, states_dict = generate_all_states()
    
    print("\nFinding shortest solution...")
    solution = find_solution(G, states_dict)
    print(f"Solution: {solution}")
    print(f"Number of moves: {len(solution)}")
    
    print("\nVisualizing solution path...")
    visualize_states(solution)

if __name__ == "__main__":
    main()