In [1]:
# import
import yaml
import numpy as np
from mlreco.iotools.factories import loader_factory
# configure
# NOTE(review): `data_dirs` below is an absolute path on one machine; point it at
# your own copy of the data before running this cell.
# The YAML text lives under its own name so that `cfg` only ever refers to the
# parsed dictionary (re-running the cell never sees `cfg` as a string).
cfg_yaml = """
iotool:
  batch_size: 1
  shuffle: False
  num_workers: 1
  collate_fn: CollateSparse
  dataset:
    name: LArCVDataset
    data_dirs:
      - /Users/kvtsang/Work/ml/tmp
    data_key: larcv_mc_20190614_022436_403576.root  
    limit_num_files: 10
    schema:
      input_reco:
        - parse_sparse3d_scn
        - sparse3d_reco
      ghost_label:
        - parse_sparse3d_scn
        - sparse3d_ghost
"""
# safe_load only builds plain Python types (dict/list/str/int/bool), which is
# all this configuration contains.
cfg = yaml.safe_load(cfg_yaml)

# Build the data loader from the parsed config and keep an iterator over it
# for the cells below.
loader,data_keys = loader_factory(cfg)
it = iter(loader)
Welcome to JupyROOT 6.16/00
Loading file: /Users/kvtsang/Work/ml/tmp/larcv_mc_20190614_022436_403576.root
Loading tree sparse3d_reco
Loading tree sparse3d_ghost
In [2]:
import plotly.plotly as py
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
# Switch plotly's offline mode on for this notebook before any `iplot` call.
# NOTE(review): connected=False presumably makes plotly.js load from the local
# package rather than a CDN — confirm against the plotly.offline docs.
init_notebook_mode(connected=False)

def get_layout_3d(ranges,titles):
    """Build the plotly layout used for the 3D scatter displays.

    Args:
        ranges: indexable with at least three entries; entries 0, 1 and 2 are
            used as the ``range`` of the x, y and z scene axes respectively.
        titles: indexable with the three axis titles (x, y, z order), or None
            to fall back to the labels 'x', 'y' and 'z'.

    Returns:
        A ``go.Layout`` for a 768x768 figure with no legend, no margins, a
        cube aspect ratio and a fixed camera.
    """
    fallback_titles = ('x', 'y', 'z')

    def axis_spec(index):
        # The three axes share every setting except their range and title.
        return dict(
            nticks=10,
            range=ranges[index],
            showticklabels=True,
            title=fallback_titles[index] if titles is None else titles[index],
            backgroundcolor="lightgray",
            gridcolor="rgb(255, 255, 255)",
            showbackground=True,
        )

    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=1.2, y=1.2, z=0.075),
    )
    scene = dict(
        xaxis=axis_spec(0),
        yaxis=axis_spec(1),
        zaxis=axis_spec(2),
        aspectmode='cube',
        camera=camera,
    )
    return go.Layout(
        showlegend=False,
        width=768,
        height=768,
        margin=dict(l=0, r=0, b=0, t=0),
        scene=scene,
    )

def trace(x,y,z,color,colorscale=None,markersize=2, name=None, hovertext=None):
    """Create a 3D marker trace for the given points.

    Args:
        x, y, z: coordinates of the points, passed straight to ``go.Scatter3d``.
        color: marker colour (a single colour or one value per point).
        colorscale: colour scale applied to ``color``; None leaves it unset.
        markersize: marker size.
        name: trace name.
        hovertext: optional hover text; when given, it is shown in the hover
            box in addition to the coordinates.

    Returns:
        A ``go.Scatter3d`` trace in 'markers' mode with opacity 0.7.
    """
    # hoverinfo is a flaglist: the flags must be joined with '+'. A list is
    # interpreted as one hoverinfo value *per point*, so the previous
    # ['x','y','z'] only configured the first three points (one flag each).
    hoverinfo = 'x+y+z' if hovertext is None else 'x+y+z+text'

    marker = dict(
        size=markersize,
        color=color,
        colorscale=colorscale,
        opacity=0.7,
    )
    return go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        name=name,
        marker=marker,
        hoverinfo=hoverinfo,
        hovertext=hovertext,
    )

def ranges(data_array):
    """Compute the overall (min, max) along the first three columns.

    Args:
        data_array: iterable of 2D arrays; columns 0, 1 and 2 of every array
            are pooled together per column.

    Returns:
        A tuple ``((xmin, xmax), (ymin, ymax), (zmin, zmax))`` where x, y, z
        refer to columns 0, 1 and 2.
    """
    bounds = []
    for axis in range(3):
        # Pool this column across all arrays before taking the extrema.
        column = np.concatenate([points[:, axis] for points in data_array])
        bounds.append((column.min(), column.max()))
    return tuple(bounds)

def edep_color(data):
    """Map values to colour values on a clipped log scale.

    The colour is ``100 * log10(data)``; every entry above 80% of the largest
    colour value is replaced by 40% of that largest value.

    Args:
        data: array-like of values (NOTE(review): presumably strictly
            positive, since the log is taken — confirm with callers).

    Returns:
        A new numpy array of colour values with the same shape as ``data``.
    """
    color = 100 * np.log10(data)
    peak = color.max()
    # Entries close to the peak are pulled down to 40% of it.
    near_peak = color > 0.8 * peak
    color[near_peak] = 0.4 * peak
    return color

def plot(traces,ranges,titles=None):
    """Render the given traces as an inline 3D figure.

    Note: this definition rebinds the name ``plot`` that the notebook imports
    from ``plotly.offline``.

    Args:
        traces: traces handed to ``go.Figure`` as its data.
        ranges: per-axis ranges forwarded to ``get_layout_3d``.
        titles: optional axis titles forwarded to ``get_layout_3d``.
    """
    layout = get_layout_3d(ranges, titles)
    figure = go.Figure(data=traces, layout=layout)
    iplot(figure)
In [20]:
# Fetch the next item from the loader iterator `it` created in the first cell.
# Each run of this cell advances the iterator, so `data` changes on re-run.
data = next(it)
In [34]:
# Second element of the fetched item.
# NOTE(review): presumably the `ghost_label` entry of the schema — confirm the ordering.
ghost = data[1]
# The first three columns are used as the x, y, z coordinates of each point.
x, y, z = ghost[:, :3].T
# The last column is used as the per-point colour value.
label = ghost[:, -1]

# Fixed limits for all three axes.
# NOTE(review): hard-coded; the `ranges` helper could derive them from the data instead.
xyz_range = ((0., 1500.), (0., 1500.), (0., 1500.))

# One marker trace coloured by `label`, rendered inline.
seg  = trace(x,y,z,color=label, colorscale='YlOrRd', markersize=2)
plot([seg], xyz_range)