import json

import tensorflow as tf
import yaml

from data.preprocess_scripts import *
from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN
from data.utils import capitalize_and_period

# Datasets whose observations carry no proprioceptive state;
# for these, the previous action is used as the current state.
DATASET_NAMES_NO_STATE = [
    'nyu_door_opening_surprising_effectiveness',
    'usc_cloth_sim_converted_externally_to_rlds',
    'cmu_franka_exploration_dataset_converted_externally_to_rlds',
    'imperialcollege_sawyer_wrist_cam'
]
|
|
# Per-dataset mapping of image observation keys
with open('configs/dataset_img_keys.json', 'r') as file:
    IMAGE_KEYS = json.load(file)

# Global dataset/model configuration
with open('configs/base.yaml', 'r') as file:
    config = yaml.safe_load(file)
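# Expected layout of configs/base.yaml (sketch; only the keys read in this
# module are shown, and the values are illustrative):
#
#   common:
#     img_history_size: 2
#     action_chunk_size: 64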
|
|
|
|
def assemble_state_vec(arm_concat: tf.Tensor, arm_format: str,
                       base_concat=None, base_format=None):
    """
    Assemble the unified state/action vector from the arm and (optional) base.

    Returns a tuple `(state_vec, mask_vec)`: `state_vec` holds the given
    values scattered to their indices in the unified vector, and `mask_vec`
    is 1.0 at exactly those indices and 0.0 elsewhere.
    """
    state_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
    mask_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)

    # Scatter the arm values into the unified vector
    arm_concat = tf.cast(arm_concat, tf.float32)
    arm_format = arm_format.split(',')
    state_vec = tf.tensor_scatter_nd_update(
        state_vec,
        [[STATE_VEC_IDX_MAPPING[name]] for name in arm_format],
        arm_concat
    )
    mask_vec = tf.tensor_scatter_nd_update(
        mask_vec,
        [[STATE_VEC_IDX_MAPPING[name]] for name in arm_format],
        tf.ones(len(arm_format), dtype=tf.float32)
    )

    # Scatter the base values, if any
    if base_concat is not None:
        base_concat = tf.cast(base_concat, tf.float32)
        base_format = base_format.split(',')
        state_vec = tf.tensor_scatter_nd_update(
            state_vec,
            [[STATE_VEC_IDX_MAPPING[name]] for name in base_format],
            base_concat
        )
        mask_vec = tf.tensor_scatter_nd_update(
            mask_vec,
            [[STATE_VEC_IDX_MAPPING[name]] for name in base_format],
            tf.ones(len(base_format), dtype=tf.float32)
        )
    return state_vec, mask_vec
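# Minimal usage sketch (illustrative; the exact key names below are assumed
# to exist in STATE_VEC_IDX_MAPPING and may differ in practice):
#
#   arm = tf.constant([0.1, -0.2, 0.3, 0.0, 0.5, -0.1, 1.0])
#   fmt = ("arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,"
#          "arm_joint_4_pos,arm_joint_5_pos,gripper_open")
#   state_vec, mask_vec = assemble_state_vec(arm, fmt)
#   # state_vec holds the seven values at their mapped indices; mask_vec is
#   # 1.0 at those indices and 0.0 everywhere else.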
|
|
|
|
@tf.autograph.experimental.do_not_convert
def _generate_json_state_agilex(episode: dict, dataset_name: str):
    """
    Generate the json metadata dict, states, masks, and actions for an
    episode of the AgileX dataset.
    """
    # Validate the relevant config entries
    IMG_HISTORY_SIZE = config['common']['img_history_size']
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config['common']['action_chunk_size']
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    episode_metadata = {
        'dataset_name': dataset_name,
        '#steps': 0,
        'instruction': None
    }

    base_act = None
    last_base_act = None
    episode_states = []
    episode_acts = []
    episode_masks = []
    has_base = None
    for step_id, step in enumerate(iter(episode['steps'])):
        action = step['action']
        if has_base is None:
            has_base = 'base_concat' in action
        if has_base:
            base_act = action['base_concat']

        state = step['observation']

        # Parse the formats; the base format, if any, is the suffix of the
        # action format starting at 'base'
        arm_format = state['format'].numpy().decode('utf-8')
        base_format = None
        if has_base:
            act_format = action['format'].numpy().decode('utf-8')
            base_format_idx = act_format.find('base')
            base_format = act_format[base_format_idx:]

        # The base state is the previous base action (zeros at the first step)
        arm_state = state['arm_concat']
        base_state = None
        if has_base:
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
            last_base_act = base_act

        # Assemble the state vector and its mask
        state_vec, mask_vec = assemble_state_vec(
            arm_state, arm_format, base_state, base_format)
        # The AgileX action shares the state's format
        act_vec, mask_vec = assemble_state_vec(
            action['arm_concat'], arm_format, base_state, base_format
        )
        episode_states.append(state_vec)
        episode_masks.append(mask_vec)
        episode_acts.append(act_vec)

        # Parse the instruction; only the first one is kept
        instr = step['observation']['natural_language_instruction']
        instr = instr.numpy().decode('utf-8')
        instr = capitalize_and_period(instr)
        if episode_metadata['instruction'] is None:
            episode_metadata['instruction'] = instr

        # step_id is zero-based, so the step count so far is step_id + 1
        episode_metadata['#steps'] = step_id + 1

    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)
    episode_acts = tf.stack(episode_acts)

    return episode_metadata, episode_states, episode_masks, episode_acts
|
|
|
|
@tf.autograph.experimental.do_not_convert
def _generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json metadata dict, states, and masks for a given episode.
    """
    # Validate the relevant config entries
    IMG_HISTORY_SIZE = config['common']['img_history_size']
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config['common']['action_chunk_size']
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    episode_metadata = {
        'dataset_name': dataset_name,
        '#steps': 0,
        'instruction': None
    }

    base_act = None
    last_base_act = None
    episode_states = []
    episode_masks = []
    has_base = None
    for step_id, step in enumerate(iter(episode['steps'])):
        action = step['action']
        if has_base is None:
            has_base = 'base_concat' in action
        if has_base:
            base_act = action['base_concat']

        state = step['observation']

        # Parse the formats; the base format, if any, is the suffix of the
        # action format starting at 'base'
        arm_format = state['format'].numpy().decode('utf-8')
        base_format = None
        if has_base:
            act_format = action['format'].numpy().decode('utf-8')
            base_format_idx = act_format.find('base')
            base_format = act_format[base_format_idx:]

        # The base state is the previous base action (zeros at the first step)
        arm_state = state['arm_concat']
        base_state = None
        if has_base:
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
            last_base_act = base_act

        # Assemble the state vector and its mask
        state_vec, mask_vec = assemble_state_vec(
            arm_state, arm_format, base_state, base_format)
        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Parse the instruction; only the first one is kept
        instr = step['observation']['natural_language_instruction']
        instr = instr.numpy().decode('utf-8')
        instr = capitalize_and_period(instr)
        if episode_metadata['instruction'] is None:
            episode_metadata['instruction'] = instr

        # step_id is zero-based, so the step count so far is step_id + 1
        episode_metadata['#steps'] = step_id + 1

    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)

    return episode_metadata, episode_states, episode_masks
|
|
|
|
@tf.autograph.experimental.do_not_convert
def _generate_json_state_nostate_ds(episode: dict, dataset_name: str):
    """
    Generate the json metadata dict, states, and masks for an episode of a
    dataset without proprioceptive state.
    Since there is no state, the previous action is used as the current state.
    """
    # Validate the relevant config entries
    IMG_HISTORY_SIZE = config['common']['img_history_size']
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config['common']['action_chunk_size']
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    episode_metadata = {
        'dataset_name': dataset_name,
        '#steps': 0,
        'instruction': None
    }

    last_base_act = None
    last_arm_act = None
    episode_states = []
    episode_masks = []
    has_base = None
    for step_id, step in enumerate(iter(episode['steps'])):
        action = step['action']
        if has_base is None:
            has_base = 'base_concat' in action
        if has_base:
            base_act = action['base_concat']
            if last_base_act is None:
                last_base_act = base_act * 0

        # At the first step, the "previous" action is all zeros
        arm_act = action['arm_concat']
        if last_arm_act is None:
            last_arm_act = arm_act * 0

        act_format = action['format'].numpy().decode('utf-8')

        # Use the previous action as the current state
        if has_base:
            last_act_concat = tf.concat([last_arm_act, last_base_act], axis=0)
        else:
            last_act_concat = last_arm_act
        state_vec, mask_vec = assemble_state_vec(
            last_act_concat, act_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Parse the instruction; only the first one is kept
        instr = step['observation']['natural_language_instruction']
        instr = instr.numpy().decode('utf-8')
        instr = capitalize_and_period(instr)
        if episode_metadata['instruction'] is None:
            episode_metadata['instruction'] = instr

        # Remember this step's action for the next step
        last_arm_act = arm_act
        if has_base:
            last_base_act = base_act

        # step_id is zero-based, so the step count so far is step_id + 1
        episode_metadata['#steps'] = step_id + 1

    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)

    return episode_metadata, episode_states, episode_masks
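# Timing sketch for the no-state datasets: the "state" assembled at step t is
# the action executed at step t-1 (all zeros at t = 0), e.g.
#
#   step:   0      1        2        3
#   state:  zeros  act_0    act_1    act_2
#
# so a proprioception-like input is still produced even though the dataset
# provides none.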
|
|
|
|
@tf.autograph.experimental.do_not_convert
def generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json metadata dict and state for an episode, dispatching to
    the appropriate handler for the dataset.
    """
    if isinstance(dataset_name, tf.Tensor):
        dataset_name = dataset_name.numpy().decode('utf-8')

    # Apply the dataset-specific step preprocessor
    # (a module of the same name in data.preprocess_scripts)
    episode['steps'] = episode['steps'].map(
        globals()[dataset_name].process_step,
    )

    if dataset_name == "agilex":
        return _generate_json_state_agilex(episode, dataset_name)
    if dataset_name in DATASET_NAMES_NO_STATE:
        return _generate_json_state_nostate_ds(episode, dataset_name)
    return _generate_json_state(episode, dataset_name)
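# Usage sketch (illustrative; assumes `ds` is a tf.data.Dataset of episodes
# from the corresponding builder, and that 'bridge' names a module in
# data.preprocess_scripts that defines `process_step`):
#
#   for episode in ds.take(1):
#       meta, states, masks = generate_json_state(episode, 'bridge')
#       print(meta['#steps'], states.shape, masks.shape)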
|
|