Example #1
def run_training(cfg, uuid, override=None):
    override = override or {}  # avoid sharing a mutable default argument across calls
    try:
        logger.info("Running with configuration:\n" + pprint.pformat(cfg))
        torch.set_num_threads(1)
        set_seed(cfg['training']['seed'])

        # get new output_dir name (use for checkpoints)
        old_log_dir = cfg['saving']['log_dir']
        changed_log_dir = False
        existing_log_paths = []
        if os.path.exists(old_log_dir) and cfg['saving']['autofix_log_dir']:
            LOG_DIR, existing_log_paths = evkit.utils.logging.unused_dir_name(
                old_log_dir)
            os.makedirs(LOG_DIR, exist_ok=False)
            cfg['saving']['log_dir'] = LOG_DIR
            cfg['saving']['results_log_file'] = os.path.join(
                LOG_DIR, 'result_log.pkl')
            cfg['saving']['reward_log_file'] = os.path.join(
                LOG_DIR, 'rewards.pkl')
            cfg['saving']['visdom_log_file'] = os.path.join(
                LOG_DIR, 'visdom_logs.json')
            changed_log_dir = True
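            # Any logs already present in the old directory are replayed into the new
            # meter logger later on (see replay_logs below).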

        # Load checkpoint, config, agent
        agent = None
        if cfg['training']['resumable']:
            if cfg['saving']['checkpoint']:
                prev_run_path = cfg['saving']['checkpoint']
                ckpt_fpath = os.path.join(prev_run_path, 'checkpoints',
                                          'ckpt-latest.dat')
                if cfg['saving'][
                        'checkpoint_configs']:  # update configs with values from ckpt
                    prev_run_metadata_paths = [
                        os.path.join(prev_run_path, f)
                        for f in os.listdir(prev_run_path)
                        if f.endswith('metadata')
                    ]
                    prev_run_config_path = os.path.join(
                        prev_run_metadata_paths[0], 'config.json')
                    with open(prev_run_config_path) as f:
                        config = json.load(f)  # keys are ['cfg', 'uuid', 'seed']
                    cfg = update_dict_deepcopy(cfg, config['cfg'])
                    cfg = update_dict_deepcopy(cfg, override)
                    uuid = config['uuid']
                    logger.warning(
                        "Reusing config from {}".format(prev_run_config_path))
                if ckpt_fpath is not None and os.path.exists(ckpt_fpath):
                    checkpoint_obj = torch.load(ckpt_fpath)
                    start_epoch = checkpoint_obj['epoch']
                    logger.info("Loaded learner (epoch {}) from {}".format(
                        start_epoch, ckpt_fpath))
                    agent = checkpoint_obj['agent']
                    actor_critic = agent.actor_critic
                else:
                    logger.warning(
                        "No checkpoint found at {}".format(ckpt_fpath))

        # Make environment
        simulator, scenario = cfg['env']['env_name'].split('_')
        if cfg['env']['transform_fn_pre_aggregation'] is None:
            cfg['env']['transform_fn_pre_aggregation'] = "None"
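        # Transforms are stored in the config as strings ("---" stands in for quote
        # characters, and a missing transform is the string "None"); eval below turns
        # them back into callables (or None).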
        envs = EnvFactory.vectorized(
            cfg['env']['env_name'],
            cfg['training']['seed'],
            cfg['env']['num_processes'],
            cfg['saving']['log_dir'],
            cfg['env']['add_timestep'],
            env_specific_kwargs=cfg['env']['env_specific_kwargs'],
            num_val_processes=cfg['env']['num_val_processes'],
            preprocessing_fn=eval(
                cfg['env']['transform_fn_pre_aggregation'].replace("---",
                                                                   "'")),
            addl_repeat_count=cfg['env']['additional_repeat_count'],
            sensors=cfg['env']['sensors'],
            vis_interval=cfg['saving']['vis_interval'],
            visdom_server=cfg['saving']['visdom_server'],
            visdom_port=cfg['saving']['visdom_port'],
            visdom_log_file=cfg['saving']['visdom_log_file'],
            visdom_name=uuid)
        if 'transform_fn_post_aggregation' in cfg['env'] and cfg['env'][
                'transform_fn_post_aggregation'] is not None:
            transform, space = eval(
                cfg['env']['transform_fn_post_aggregation'].replace(
                    "---", "'"))(envs.observation_space)
            envs = ProcessObservationWrapper(envs, transform, space)
        action_space = envs.action_space
        observation_space = envs.observation_space
        retained_obs_shape = {
            k: v.shape
            for k, v in observation_space.spaces.items()
            if k in cfg['env']['sensors']
        }
        logger.info(f"Action space: {action_space}")
        logger.info(f"Observation space: {observation_space}")
        logger.info("Retaining: {}".format(
            set(observation_space.spaces.keys()).intersection(
                cfg['env']['sensors'].keys())))

        # Finish setting up the agent
        if agent is None:
            perception_model = eval(cfg['learner']['perception_network'])(
                cfg['learner']['num_stack'],
                **cfg['learner']['perception_network_kwargs'])

            forward = ForwardModel(cfg['learner']['internal_state_size'],
                                   (-1, 3),
                                   cfg['learner']['internal_state_size'])
            inverse = InverseModel(cfg['learner']['internal_state_size'],
                                   cfg['learner']['internal_state_size'], 3)
            base = ForwardInverseACModule(
                perception_model,
                forward,
                inverse,
                cfg['learner']['recurrent_policy'],
                internal_state_size=cfg['learner']['internal_state_size'])
            actor_critic = PolicyWithBase(
                base,
                action_space,
                num_stack=cfg['learner']['num_stack'],
                takeover=None)
            if cfg['learner']['use_replay']:
                agent = evkit.rl.algo.PPOReplayCuriosity(
                    actor_critic,
                    cfg['learner']['clip_param'],
                    cfg['learner']['ppo_epoch'],
                    cfg['learner']['num_mini_batch'],
                    cfg['learner']['value_loss_coef'],
                    cfg['learner']['entropy_coef'],
                    cfg['learner']['on_policy_epoch'],
                    cfg['learner']['off_policy_epoch'],
                    lr=cfg['learner']['lr'],
                    eps=cfg['learner']['eps'],
                    max_grad_norm=cfg['learner']['max_grad_norm'],
                    curiosity_reward_coef=cfg['learner']['curiosity_reward_coef'],
                    forward_loss_coef=cfg['learner']['forward_loss_coef'],
                    inverse_loss_coef=cfg['learner']['inverse_loss_coef'])
            else:
                agent = evkit.rl.algo.PPOCuriosity(
                    actor_critic,
                    cfg['learner']['clip_param'],
                    cfg['learner']['ppo_epoch'],
                    cfg['learner']['num_mini_batch'],
                    cfg['learner']['value_loss_coef'],
                    cfg['learner']['entropy_coef'],
                    lr=cfg['learner']['lr'],
                    eps=cfg['learner']['eps'],
                    max_grad_norm=cfg['learner']['max_grad_norm'],
                    curiosity_reward_coef=cfg['learner']['curiosity_reward_coef'],
                    forward_loss_coef=cfg['learner']['forward_loss_coef'],
                    inverse_loss_coef=cfg['learner']['inverse_loss_coef'])
            start_epoch = 0
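            # Note: a resumed run keeps the start_epoch loaded from the checkpoint
            # above; only a freshly built agent starts from epoch 0.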

        # Machinery for storing rollouts
        num_train_processes = cfg['env']['num_processes'] - cfg['env'][
            'num_val_processes']
        num_val_processes = cfg['env']['num_val_processes']
        assert cfg['learner']['test']  or (cfg['env']['num_val_processes'] < cfg['env']['num_processes']),\
                "Can't train without some training processes!"
        current_obs = StackedSensorDictStorage(cfg['env']['num_processes'],
                                               cfg['learner']['num_stack'],
                                               retained_obs_shape)
        current_train_obs = StackedSensorDictStorage(
            num_train_processes, cfg['learner']['num_stack'],
            retained_obs_shape)
        logger.debug(f'Stacked obs shape {current_obs.obs_shape}')

        if cfg['learner']['use_replay'] and not cfg['learner']['test']:
            rollouts = RolloutSensorDictCuriosityReplayBuffer(
                cfg['learner']['num_steps'], num_train_processes,
                current_obs.obs_shape, action_space,
                cfg['learner']['internal_state_size'], actor_critic,
                cfg['learner']['use_gae'], cfg['learner']['gamma'],
                cfg['learner']['tau'], cfg['learner']['replay_buffer_size'])
        else:
            rollouts = RolloutSensorDictStorage(
                cfg['learner']['num_steps'], num_train_processes,
                current_obs.obs_shape, action_space,
                cfg['learner']['internal_state_size'])
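        # With use_replay, the storage above also acts as a replay buffer of size
        # replay_buffer_size; otherwise it only holds the current on-policy rollout
        # of num_steps transitions per training process.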

        # Set up logging
        if cfg['saving']['logging_type'] == 'visdom':
            mlog = tnt.logger.VisdomMeterLogger(
                title=uuid,
                env=uuid,
                server=cfg['saving']['visdom_server'],
                port=cfg['saving']['visdom_port'],
                log_to_filename=cfg['saving']['visdom_log_file'])
        elif cfg['saving']['logging_type'] == 'tensorboard':
            mlog = tnt.logger.TensorboardMeterLogger(
                env=uuid,
                log_dir=cfg['saving']['log_dir'],
                plotstylecombined=True)
        else:
            raise NotImplementedError(
                f"Unknown logger type: ({cfg['saving']['logging_type']})")

        # Add metrics and logging to TB/Visdom
        loggable_metrics = [
            'metrics/rewards', 'diagnostics/dist_perplexity',
            'diagnostics/lengths', 'diagnostics/max_importance_weight',
            'diagnostics/value', 'losses/action_loss', 'losses/dist_entropy',
            'losses/value_loss'
        ]
        core_metrics = ['metrics/rewards', 'diagnostics/lengths']
        debug_metrics = ['debug/input_images']
        if 'habitat' in cfg['env']['env_name'].lower():
            for metric in [
                    'metrics/collisions', 'metrics/spl', 'metrics/success'
            ]:
                loggable_metrics.append(metric)
                core_metrics.append(metric)
        for meter in loggable_metrics:
            mlog.add_meter(meter, tnt.meter.ValueSummaryMeter())
        for debug_meter in debug_metrics:
            mlog.add_meter(debug_meter,
                           tnt.meter.SingletonMeter(),
                           ptype='image')
        mlog.add_meter('config', tnt.meter.SingletonMeter(), ptype='text')
        mlog.update_meter(cfg_to_md(cfg, uuid),
                          meters={'config'},
                          phase='train')

        # File loggers
        flog = tnt.logger.FileLogger(cfg['saving']['results_log_file'],
                                     overwrite=True)
        reward_only_flog = tnt.logger.FileLogger(
            cfg['saving']['reward_log_file'], overwrite=True)

        # replay data to mlog, move metadata file
        if changed_log_dir:
            evkit.utils.logging.replay_logs(existing_log_paths, mlog)
            evkit.utils.logging.move_metadata_file(old_log_dir,
                                                   cfg['saving']['log_dir'],
                                                   uuid)

        ##########
        # LEARN! #
        ##########
        if cfg['training']['cuda']:
            current_train_obs = current_train_obs.cuda()
            current_obs = current_obs.cuda()
            rollouts.cuda()
            actor_critic.cuda()

        # These variables are used to compute average rewards for all processes.
        episode_rewards = torch.zeros([cfg['env']['num_processes'], 1])
        episode_lengths = torch.zeros([cfg['env']['num_processes'], 1])
        episode_tracker = evkit.utils.logging.EpisodeTracker(
            cfg['env']['num_processes'])
        if cfg['learner']['test']:
            all_episodes = []

        # First observation
        obs = envs.reset()
        current_obs.insert(obs)
        mask_done = torch.FloatTensor(
            [[0.0] for _ in range(cfg['env']['num_processes'])]).pin_memory()
        states = torch.zeros(
            cfg['env']['num_processes'],
            cfg['learner']['internal_state_size']).pin_memory()
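        # pin_memory() keeps the done mask and recurrent state in page-locked host
        # memory, which presumably speeds up the repeated .cuda() copies in the step loop.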

        # Main loop
        start_time = time.time()
        n_episodes_completed = 0
        num_updates = int(cfg['training']['num_frames']) // (
            cfg['learner']['num_steps'] * cfg['env']['num_processes'])
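        # Each update consumes num_steps frames from every process, so this many
        # updates cover the num_frames budget.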
        logger.info(f"Running until num updates == {num_updates}")
        for j in range(start_epoch, num_updates, 1):
            for step in range(cfg['learner']['num_steps']):
                obs_unpacked = {
                    k: current_obs.peek()[k].peek()
                    for k in current_obs.peek()
                }
                if j == start_epoch and step < 10:
                    log_input_images(obs_unpacked,
                                     mlog,
                                     num_stack=cfg['learner']['num_stack'],
                                     key_names=['rgb_filled', 'map'],
                                     meter_name='debug/input_images',
                                     step_num=step)

                # Sample actions
                with torch.no_grad():
                    value, action, action_log_prob, states = actor_critic.act(
                        obs_unpacked, states.cuda(), mask_done.cuda())
                cpu_actions = list(action.squeeze(1).cpu().numpy())
                obs, reward, done, info = envs.step(cpu_actions)
                reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                         1)).float()
                episode_tracker.append(obs)

                # Handle terminated episodes; logging values and computing the "done" mask
                episode_rewards += reward
                episode_lengths += (1 + cfg['env']['additional_repeat_count'])
                mask_done = torch.FloatTensor([[0.0] if done_ else [1.0]
                                               for done_ in done])
                for i, (r, l, done_) in enumerate(
                        zip(episode_rewards, episode_lengths,
                            done)):  # Logging loop
                    if done_:
                        n_episodes_completed += 1
                        if cfg['learner']['test']:
                            info[i]['reward'] = r.item()
                            info[i]['length'] = l.item()
                            all_episodes.append({
                                'info': info[i],
                                'history': episode_tracker.episodes[i][:-1]
                            })
                        episode_tracker.clear_episode(i)
                        phase = 'train' if i < num_train_processes else 'val'
                        mlog.update_meter(r.item(),
                                          meters={'metrics/rewards'},
                                          phase=phase)
                        mlog.update_meter(l.item(),
                                          meters={'diagnostics/lengths'},
                                          phase=phase)
                        if 'habitat' in cfg['env']['env_name'].lower():
                            mlog.update_meter(info[i]["collisions"],
                                              meters={'metrics/collisions'},
                                              phase=phase)
                            if scenario == 'PointNav':
                                mlog.update_meter(info[i]["spl"],
                                                  meters={'metrics/spl'},
                                                  phase=phase)
                                mlog.update_meter(info[i]["success"],
                                                  meters={'metrics/success'},
                                                  phase=phase)
                episode_rewards *= mask_done
                episode_lengths *= mask_done
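                # Multiplying by the done mask zeroes the running reward/length
                # counters for any process whose episode just finished.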

                # Insert the new observation into RolloutStorage
                if cfg['training']['cuda']:
                    mask_done = mask_done.cuda()
                for k in obs:
                    if k in current_train_obs.sensor_names:
                        current_train_obs[k].insert(
                            obs[k][:num_train_processes],
                            mask_done[:num_train_processes])
                current_obs.insert(obs, mask_done)
                if not cfg['learner']['test']:
                    rollouts.insert(current_train_obs.peek(),
                                    states[:num_train_processes],
                                    action[:num_train_processes],
                                    action_log_prob[:num_train_processes],
                                    value[:num_train_processes],
                                    reward[:num_train_processes],
                                    mask_done[:num_train_processes])
                mlog.update_meter(value[:num_train_processes].mean().item(),
                                  meters={'diagnostics/value'},
                                  phase='train')

            # Training update
            if not cfg['learner']['test']:
                if not cfg['learner']['use_replay']:
                    # Moderate compute saving optimization (if no replay buffer):
                    #     Estimate future-discounted returns only once
                    with torch.no_grad():
                        next_value = actor_critic.get_value(
                            rollouts.observations.at(-1), rollouts.states[-1],
                            rollouts.masks[-1]).detach()
                    rollouts.compute_returns(next_value,
                                             cfg['learner']['use_gae'],
                                             cfg['learner']['gamma'],
                                             cfg['learner']['tau'])
                value_loss, action_loss, dist_entropy, max_importance_weight, info = agent.update(
                    rollouts)
                rollouts.after_update()  # For the next iter: initial obs <- current observation

                # Update meters with latest training info
                mlog.update_meter(dist_entropy, meters={'losses/dist_entropy'})
                mlog.update_meter(np.exp(dist_entropy),
                                  meters={'diagnostics/dist_perplexity'})
                mlog.update_meter(value_loss, meters={'losses/value_loss'})
                mlog.update_meter(action_loss, meters={'losses/action_loss'})
                mlog.update_meter(max_importance_weight,
                                  meters={'diagnostics/max_importance_weight'})

            # Main logging
            if j % cfg['saving']['log_interval'] == 0:
                num_relevant_processes = (num_val_processes if cfg['learner']['test']
                                          else num_train_processes)
                n_steps_since_logging = (cfg['saving']['log_interval'] *
                                         num_relevant_processes * cfg['learner']['num_steps'])
                total_num_steps = (j + 1) * num_relevant_processes * cfg['learner']['num_steps']

                logger.info("Update {}, num timesteps {}, FPS {}".format(
                    j + 1, total_num_steps,
                    int(n_steps_since_logging / (time.time() - start_time))))
                logger.info(f"Completed episodes: {n_episodes_completed}")
                viable_modes = ['val'] if cfg['learner']['test'] else [
                    'train', 'val'
                ]
                for metric in core_metrics:  # Log to stdout
                    for mode in viable_modes:
                        if metric in core_metrics or mode == 'train':
                            mlog.print_meter(mode,
                                             total_num_steps,
                                             meterlist={metric})
                if not cfg['learner']['test']:
                    for mode in viable_modes:  # Log to files
                        results = mlog.peek_meter(phase=mode)
                        reward_only_flog.log(mode, {
                            metric: results[metric]
                            for metric in core_metrics
                        })
                        if mode == 'train':
                            results['step_num'] = j + 1
                            flog.log('all_results', results)

                        mlog.reset_meter(total_num_steps, mode=mode)
                start_time = time.time()

            # Save checkpoint
            if not cfg['learner'][
                    'test'] and j % cfg['saving']['save_interval'] == 0:
                save_dir_absolute = os.path.join(cfg['saving']['log_dir'],
                                                 cfg['saving']['save_dir'])
                save_checkpoint({
                    'agent': agent,
                    'epoch': j
                }, save_dir_absolute, j)
            if 'test_k_episodes' in cfg[
                    'learner'] and n_episodes_completed >= cfg['learner'][
                        'test_k_episodes']:
                torch.save(
                    all_episodes,
                    os.path.join(cfg['saving']['log_dir'], 'validation.pth'))
                break

    # Clean up (either after ending normally or early [e.g. from a KeyboardInterrupt])
    finally:
        try:
            if isinstance(envs, list):
                for env in envs:
                    env.close()
            else:
                envs.close()
            logger.info("Killed envs.")
        except UnboundLocalError:
            logger.info("No envs to kill!")
Example #2
def train(cfg, uuid):
    set_seed(cfg['training']['seed'])

    ############################################################
    # Logger
    ############################################################
    logger.setLevel(logging.INFO)
    logger.info(pprint.pformat(cfg))
    logger.debug(f'Loaded Torch version: {torch.__version__}')
    logger.debug(f'Using device: {device}')
    logger.info(f"Training following tasks: ")
    for i, (s, t) in enumerate(
            zip(cfg['training']['sources'], cfg['training']['targets'])):
        logger.info(f"\tTask {i}: {s} -> {t}")
    logger.debug(f'Starting data loaders')

    ############################################################
    # Model (and possibly resume from checkpoint)
    ############################################################
    logger.debug(f'Setting up model')
    # switches to the proper pretrained encoder for the first target task
    search_and_replace_dict(cfg['learner']['model_kwargs'],
                            cfg['training']['targets'][0][0])
    model = eval(cfg['learner']['model'])(**cfg['learner']['model_kwargs'])
    logger.info(
        f"Created model. Number of trainable parameters: {count_trainable_parameters(model)}. Number of total parameters: {count_total_parameters(model)}"
    )
    try:
        logger.info(
            f"Number of trainable transfer parameters: {count_trainable_parameters(model.transfers)}. Number of total transfer parameters: {count_total_parameters(model.transfers)}"
        )
        if isinstance(model.encoder, nn.Module):
            logger.info(
                f"Number of trainable encoder parameters: {count_trainable_parameters(model.base)}. Number of total encoder parameters: {count_total_parameters(model.base)}"
            )
        if isinstance(model.side_networks, nn.Module):
            logger.info(
                f"Number of trainable side parameters: {count_trainable_parameters(model.sides)}. Number of total side parameters: {count_total_parameters(model.sides)}"
            )
        if isinstance(model.merge_operators, nn.Module):
            logger.info(
                f"Number of trainable merge (alpha) parameters: {count_trainable_parameters(model.merge_operators)}. Number of total merge (alpha) parameters: {count_total_parameters(model.merge_operators)}"
            )
    except AttributeError:
        # not every model variant exposes transfers / base / sides / merge_operators
        pass

    ckpt_fpath = cfg['training']['resume_from_checkpoint_path']
    loaded_optimizer = None
    start_epoch = 0

    if ckpt_fpath is not None and not cfg['training']['resume_training']:
        warnings.warn(
            'Checkpoint path provided but resume_training is set to False, are you sure??'
        )
    if ckpt_fpath is not None and cfg['training']['resume_training']:
        if not os.path.exists(ckpt_fpath):
            logger.warning(
                f'Trying to resume training, but checkpoint path {ckpt_fpath} does not exist. Starting training from beginning...'
            )
        else:
            model, checkpoint = load_state_dict_from_path(model, ckpt_fpath)
            start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0
            logger.info(
                f"Loaded model (epoch {start_epoch if 'epoch' in checkpoint else 'unknown'}) from {ckpt_fpath}"
            )
            if 'optimizer' in checkpoint:
                loaded_optimizer = checkpoint['optimizer']
            else:
                warnings.warn('No optimizer in checkpoint, are you sure?')
            try:  # we do not use state_dict, do not let it take up precious CUDA memory
                del checkpoint['state_dict']
            except KeyError:
                pass

    model.to(device)
    if torch.cuda.device_count() > 1:
        logger.info(f"Using {torch.cuda.device_count()} GPUs!")
        assert cfg['learner'][
            'model'] != 'ConstantModel', 'ConstantModel (blind) does not operate with multiple devices'
        model = nn.DataParallel(model, range(torch.cuda.device_count()))
        model.to(device)

    ############################################################
    # Data Loading
    ############################################################
    for key in ['sources', 'targets', 'masks']:
        cfg['training']['dataloader_fn_kwargs'][key] = cfg['training'][key]

    dataloaders = eval(cfg['training']['dataloader_fn'])(
        **cfg['training']['dataloader_fn_kwargs'])
    if cfg['training']['resume_training'] and ckpt_fpath is not None and os.path.exists(ckpt_fpath):
        if 'curr_iter_idx' in checkpoint and checkpoint['curr_iter_idx'] == -1:
            warnings.warn(
                f'curr_iter_idx is -1, Guessing curr_iter_idx to be start_epoch {start_epoch}'
            )
            dataloaders['train'].start_dl = start_epoch
        elif 'curr_iter_idx' in checkpoint:
            logger.info(
                f"Starting dataloader at {checkpoint['curr_iter_idx']}")
            dataloaders['train'].start_dl = checkpoint['curr_iter_idx']
        else:
            warnings.warn(
                f'Guessing curr_iter_idx to be start_epoch {start_epoch}')
            dataloaders['train'].start_dl = start_epoch

    ############################################################
    # Loss Functions
    ############################################################
    loss_fn_lst = cfg['training']['loss_fn']
    loss_kwargs_lst = cfg['training']['loss_kwargs']
    if not isinstance(loss_fn_lst, list):
        loss_fn_lst = [loss_fn_lst]
        loss_kwargs_lst = [loss_kwargs_lst]
    elif isinstance(loss_kwargs_lst, dict):
        loss_kwargs_lst = [loss_kwargs_lst for _ in range(len(loss_fn_lst))]

    loss_fns = []
    assert len(loss_fn_lst) == len(
        loss_kwargs_lst), 'number of loss fn/kwargs not the same'
    for loss_fn, loss_kwargs in zip(loss_fn_lst, loss_kwargs_lst):
        if loss_fn == 'perceptual_l1':
            loss_fn = perceptual_l1_loss(
                cfg['training']['loss_kwargs']['decoder_path'],
                cfg['training']['loss_kwargs']['bake_decodings'])
        elif loss_fn == 'perceptual_l2':
            loss_fn = perceptual_l2_loss(
                cfg['training']['loss_kwargs']['decoder_path'],
                cfg['training']['loss_kwargs']['bake_decodings'])
        elif loss_fn == 'perceptual_cross_entropy':
            loss_fn = perceptual_cross_entropy_loss(
                cfg['training']['loss_kwargs']['decoder_path'],
                cfg['training']['loss_kwargs']['bake_decodings'])
        else:
            loss_fn = functools.partial(eval(loss_fn), **loss_kwargs)
        loss_fns.append(loss_fn)

    if len(loss_fns) == 1 and len(cfg['training']['sources']) > 1:
        loss_fns = [
            loss_fns[0] for _ in range(len(cfg['training']['sources']))
        ]

    if 'regularizer_fn' in cfg['training'] and cfg['training'][
            'regularizer_fn'] is not None:
        assert torch.cuda.device_count(
        ) <= 1, 'Regularization does not support multi GPU, unable to access model attributes from DataParallel wrapper'
        bare_model = model.module if torch.cuda.device_count() > 1 else model
        loss_fns = [
            eval(cfg['training']['regularizer_fn'])(
                loss_fn=loss_fn,
                model=bare_model,
                **cfg['training']['regularizer_kwargs'])
            for loss_fn in loss_fns
        ]
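        # The regularizer presumably wraps each loss function so its penalty term is
        # added on top of the corresponding task loss during training.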

    ############################################################
    # More Logging
    ############################################################
    flog = tnt.logger.FileLogger(cfg['saving']['results_log_file'],
                                 overwrite=True)
    mlog = get_logger(cfg, uuid)
    mlog.add_meter('config', tnt.meter.SingletonMeter(), ptype='text')
    mlog.update_meter(cfg_to_md(cfg, uuid), meters={'config'}, phase='train')
    for task, _ in enumerate(cfg['training']['targets']):
        mlog.add_meter(f'alpha/task_{task}', tnt.meter.ValueSummaryMeter())
        mlog.add_meter(f'output/task_{task}',
                       tnt.meter.ValueSummaryMeter(),
                       ptype='image')
        mlog.add_meter(f'input/task_{task}',
                       tnt.meter.ValueSummaryMeter(),
                       ptype='image')
        mlog.add_meter(f'weight_histogram/task_{task}',
                       tnt.meter.ValueSummaryMeter(),
                       ptype='histogram')
        for loss in cfg['training']['loss_list']:
            mlog.add_meter(f'losses/{loss}_{task}',
                           tnt.meter.ValueSummaryMeter())

        if cfg['training']['task_is_classification'][task]:
            mlog.add_meter(f'accuracy_top1/task_{task}',
                           tnt.meter.ClassErrorMeter(topk=[1], accuracy=True))
            mlog.add_meter(f'accuracy_top5/task_{task}',
                           tnt.meter.ClassErrorMeter(topk=[5], accuracy=True))
            mlog.add_meter(f'perplexity_pred/task_{task}',
                           tnt.meter.ValueSummaryMeter())
            mlog.add_meter(f'perplexity_label/task_{task}',
                           tnt.meter.ValueSummaryMeter())

    ############################################################
    # Training
    ############################################################
    context = []  # populated by training and/or the final test below
    try:
        if cfg['training']['train']:
            # Optimizer
            if cfg['training'][
                    'resume_training'] and loaded_optimizer is not None:
                optimizer = loaded_optimizer
            else:
                # Merge/context/alpha parameters are excluded from weight decay.
                special_names = ('merge_operator', 'context', 'alpha')
                param_groups = [
                    {
                        'params': [param for name, param in model.named_parameters()
                                   if any(s in name for s in special_names)],
                        'weight_decay': 0.0
                    },
                    {
                        'params': [param for name, param in model.named_parameters()
                                   if not any(s in name for s in special_names)]
                    },
                ]
                optimizer = eval(cfg['learner']['optimizer_class'])(
                    param_groups,
                    lr=cfg['learner']['lr'],
                    **cfg['learner']['optimizer_kwargs'])

            # Scheduler
            scheduler = None
            if cfg['learner']['lr_scheduler_method'] is not None:
                scheduler = eval(cfg['learner']['lr_scheduler_method'])(
                    optimizer, **cfg['learner']['lr_scheduler_method_kwargs'])

            model.start_training()  # For PSP variant

            # Mixed precision training
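            # (requires NVIDIA apex; opt_level 'O1' applies mixed-precision patching
            # to the model and optimizer)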
            if cfg['training']['amp']:
                from apex import amp
                model, optimizer = amp.initialize(model,
                                                  optimizer,
                                                  opt_level='O1')

            logger.info("Starting training...")
            context = train_model(cfg,
                                  model,
                                  dataloaders,
                                  loss_fns,
                                  optimizer,
                                  start_epoch=start_epoch,
                                  num_epochs=cfg['training']['num_epochs'],
                                  save_epochs=cfg['saving']['save_interval'],
                                  scheduler=scheduler,
                                  mlog=mlog,
                                  flog=flog)
    finally:
        print(psutil.virtual_memory())
        GPUtil.showUtilization(all=True)

    ####################
    # Final Test
    ####################
    if cfg['training']['test']:
        run_kwargs = {
            'cfg': cfg,
            'mlog': mlog,
            'flog': flog,
            'optimizer': None,
            'loss_fns': loss_fns,
            'model': model,
            'use_thread': cfg['saving']['in_background'],
        }
        context, _ = run_one_epoch(dataloader=dataloaders['val'],
                                   epoch=0,
                                   train=False,
                                   **run_kwargs)

    logger.info('Waiting up to 10 minutes for all files to save...')
    mlog.flush()
    for c in context:
        c.join(600)
    logger.info('All saving is finished.')
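
# A minimal standalone sketch (not part of the script above) of the loss-config
# normalization performed in train(): a single loss_fn / loss_kwargs entry is
# promoted to parallel lists of equal length. The loss names below are hypothetical.
def normalize_loss_config(loss_fn, loss_kwargs):
    if not isinstance(loss_fn, list):
        return [loss_fn], [loss_kwargs]
    if isinstance(loss_kwargs, dict):
        return loss_fn, [loss_kwargs for _ in range(len(loss_fn))]
    return loss_fn, loss_kwargs


assert normalize_loss_config('weighted_l1_loss', {}) == (['weighted_l1_loss'], [{}])
assert normalize_loss_config(['l1', 'l2'], {'w': 1}) == (['l1', 'l2'], [{'w': 1}, {'w': 1}])
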
Example #3
def run_training(cfg, uuid, override=None):
    override = override or {}  # avoid sharing a mutable default argument across calls
    try:
        logger.info("-------------\nStarting with configuration:\n" + pprint.pformat(cfg))
        logger.info("UUID: " + uuid)
        torch.set_num_threads(1)
        set_seed(cfg['training']['seed'])

        # get new output_dir name (use for checkpoints)
        old_log_dir = cfg['saving']['log_dir']
        changed_log_dir = False
        existing_log_paths = []
        if os.path.exists(old_log_dir) and cfg['saving']['autofix_log_dir']:
            LOG_DIR, existing_log_paths = evkit.utils.logging.unused_dir_name(old_log_dir)
            os.makedirs(LOG_DIR, exist_ok=False)
            cfg['saving']['log_dir'] = LOG_DIR
            cfg['saving']['results_log_file'] = os.path.join(LOG_DIR, 'result_log.pkl')
            cfg['saving']['reward_log_file'] = os.path.join(LOG_DIR, 'rewards.pkl')
            cfg['saving']['visdom_log_file'] = os.path.join(LOG_DIR, 'visdom_logs.json')
            changed_log_dir = True

        # Load checkpoint, config, agent
        agent = None

        if cfg['training']['resumable']:
            if cfg['saving']['checkpoint']:
                prev_run_path = cfg['saving']['checkpoint']
                if cfg['saving']['checkpoint_num'] is None:
                    ckpt_fpath = os.path.join(prev_run_path, 'checkpoints', 'ckpt-latest.dat')
                else:
                    ckpt_fpath = os.path.join(prev_run_path, 'checkpoints', f"ckpt-{cfg['saving']['checkpoint_num']}.dat")
                if cfg['saving']['checkpoint_configs']:  # update configs with values from ckpt
                    prev_run_metadata_paths = [os.path.join(prev_run_path, f)
                                               for f in os.listdir(prev_run_path)
                                               if f.endswith('metadata')]
                    prev_run_config_path = os.path.join(prev_run_metadata_paths[0], 'config.json')
                    with open(prev_run_config_path) as f:
                        config = json.load(f)  # keys are ['cfg', 'uuid', 'seed']
                    true_log_dir = cfg['saving']['log_dir']
                    cfg = update_dict_deepcopy(cfg, config['cfg'])
                    uuid = config['uuid']
                    logger.warning("Reusing config from {}".format(prev_run_config_path))
                    # the saving files should always use the new log dir
                    cfg['saving']['log_dir'] = true_log_dir
                    cfg['saving']['results_log_file'] = os.path.join(true_log_dir, 'result_log.pkl')
                    cfg['saving']['reward_log_file'] = os.path.join(true_log_dir, 'rewards.pkl')
                    cfg['saving']['visdom_log_file'] = os.path.join(true_log_dir, 'visdom_logs.json')
                if ckpt_fpath is not None and os.path.exists(ckpt_fpath):
                    checkpoint_obj = torch.load(ckpt_fpath)
                    start_epoch = checkpoint_obj['epoch']
                    logger.info("Loaded learner (epoch {}) from {}".format(start_epoch, ckpt_fpath))
                    if cfg['learner']['algo'] == 'imitation_learning':
                        actor_critic = checkpoint_obj['model']
                        try:
                            actor_critic = actor_critic.module  # remove DataParallel
                        except AttributeError:
                            pass  # model was not wrapped in DataParallel
                    else:
                        agent = checkpoint_obj['agent']
                        actor_critic = agent.actor_critic
                else:
                    logger.warning("No checkpoint found at {}".format(ckpt_fpath))
        cfg = update_dict_deepcopy(cfg, override)
        logger.info("-------------\n Running with configuration:\n" + pprint.pformat(cfg))

        # Verify configs are consistent - baked version needs to match un-baked version
        try:
            taskonomy_transform = cfg['env']['transform_fn_post_aggregation_kwargs']['names_to_transforms']['taskonomy']
            taskonomy_encoder = cfg['learner']['perception_network_kwargs']['extra_kwargs']['sidetune_kwargs']['base_weights_path']
            assert taskonomy_encoder in taskonomy_transform, f'Taskonomy PostTransform and perception network base need to match. {taskonomy_encoder} != {taskonomy_transform}'
        except KeyError:
            pass

        if cfg['training']['gpu_devices'] is None:
            cfg['training']['gpu_devices'] = list(range(torch.cuda.device_count()))
        assert not (len(cfg['training']['gpu_devices']) > 1 and 'attributes' in cfg['learner']['cache_kwargs']), 'Cannot utilize cache with more than one model GPU'

        # Make environment
        simulator, scenario = cfg['env']['env_name'].split('_')

        transform_pre_aggregation = None
        if cfg['env']['transform_fn_pre_aggregation'] is not None:
            logger.warning('Using deprecated config transform_fn_pre_aggregation')
            transform_pre_aggregation = eval(cfg['env']['transform_fn_pre_aggregation'].replace("---", "'"))
        elif 'transform_fn_pre_aggregation_fn' in cfg['env'] and cfg['env'][
            'transform_fn_pre_aggregation_fn'] is not None:
            pre_aggregation_kwargs = copy.deepcopy(cfg['env']['transform_fn_pre_aggregation_kwargs'])
            transform_pre_aggregation = eval(cfg['env']['transform_fn_pre_aggregation_fn'].replace("---", "'"))(
                **eval_dict_values(pre_aggregation_kwargs))

        if 'debug_mode' in cfg['env']['env_specific_kwargs'] and cfg['env']['env_specific_kwargs']['debug_mode']:
            assert cfg['env']['num_processes'] == 1, 'Using debug mode requires you to only use one process'

        envs = EnvFactory.vectorized(
            cfg['env']['env_name'],
            cfg['training']['seed'],
            cfg['env']['num_processes'],
            cfg['saving']['log_dir'],
            cfg['env']['add_timestep'],
            env_specific_kwargs=cfg['env']['env_specific_kwargs'],
            num_val_processes=cfg['env']['num_val_processes'],
            preprocessing_fn=transform_pre_aggregation,
            addl_repeat_count=cfg['env']['additional_repeat_count'],
            sensors=cfg['env']['sensors'],
            vis_interval=cfg['saving']['vis_interval'],
            visdom_server=cfg['saving']['visdom_server'],
            visdom_port=cfg['saving']['visdom_port'],
            visdom_log_file=cfg['saving']['visdom_log_file'],
            visdom_name=uuid)

        transform_post_aggregation = None
        if 'transform_fn_post_aggregation' in cfg['env'] and cfg['env']['transform_fn_post_aggregation'] is not None:
            logger.warning('Using deprecated config transform_fn_post_aggregation')
            transform_post_aggregation = eval(cfg['env']['transform_fn_post_aggregation'].replace("---", "'"))
        elif 'transform_fn_post_aggregation_fn' in cfg['env'] and cfg['env'][
            'transform_fn_post_aggregation_fn'] is not None:
            post_aggregation_kwargs = copy.deepcopy(cfg['env']['transform_fn_post_aggregation_kwargs'])
            transform_post_aggregation = eval(cfg['env']['transform_fn_post_aggregation_fn'].replace("---", "'"))(
                **eval_dict_values(post_aggregation_kwargs))

        if transform_post_aggregation is not None:
            transform, space = transform_post_aggregation(envs.observation_space)
            envs = ProcessObservationWrapper(envs, transform, space)

        action_space = envs.action_space
        observation_space = envs.observation_space
        retained_obs_shape = {k: v.shape
                              for k, v in observation_space.spaces.items()
                              if k in cfg['env']['sensors']}
        logger.info(f"Action space: {action_space}")
        logger.info(f"Observation space: {observation_space}")
        logger.info(
            "Retaining: {}".format(set(observation_space.spaces.keys()).intersection(cfg['env']['sensors'].keys())))

        # Finish setting up the agent
        if agent is None and cfg['learner']['algo'] == 'ppo':
            perception_model = eval(cfg['learner']['perception_network'])(
                cfg['learner']['num_stack'],
                **cfg['learner']['perception_network_kwargs'])
            base = NaivelyRecurrentACModule(
                perception_unit=perception_model,
                use_gru=cfg['learner']['recurrent_policy'],
                internal_state_size=cfg['learner']['internal_state_size'])
            actor_critic = PolicyWithBase(
                base, action_space,
                num_stacks=cfg['learner']['num_stack'],
                takeover=None,
                loss_kwargs=cfg['learner']['loss_kwargs'],
                gpu_devices=cfg['training']['gpu_devices'],
            )
            if cfg['learner']['use_replay']:
                agent = evkit.rl.algo.PPOReplay(actor_critic,
                                                cfg['learner']['clip_param'],
                                                cfg['learner']['ppo_epoch'],
                                                cfg['learner']['num_mini_batch'],
                                                cfg['learner']['value_loss_coef'],
                                                cfg['learner']['entropy_coef'],
                                                cfg['learner']['on_policy_epoch'],
                                                cfg['learner']['off_policy_epoch'],
                                                cfg['learner']['num_steps'],
                                                cfg['learner']['num_stack'],
                                                lr=cfg['learner']['lr'],
                                                eps=cfg['learner']['eps'],
                                                max_grad_norm=cfg['learner']['max_grad_norm'],
                                                gpu_devices=cfg['training']['gpu_devices'],
                                                loss_kwargs=cfg['learner']['loss_kwargs'],
                                                cache_kwargs=cfg['learner']['cache_kwargs'],
                                                optimizer_class=cfg['learner']['optimizer_class'],
                                                optimizer_kwargs=cfg['learner']['optimizer_kwargs']
                )
            else:
                agent = evkit.rl.algo.PPO(actor_critic,
                                          cfg['learner']['clip_param'],
                                          cfg['learner']['ppo_epoch'],
                                          cfg['learner']['num_mini_batch'],
                                          cfg['learner']['value_loss_coef'],
                                          cfg['learner']['entropy_coef'],
                                          lr=cfg['learner']['lr'],
                                          eps=cfg['learner']['eps'],
                                          max_grad_norm=cfg['learner']['max_grad_norm']
                )
            start_epoch = 0

            # Set up data parallel
            if torch.cuda.device_count() > 1 and (cfg['training']['gpu_devices'] is None or len(cfg['training']['gpu_devices']) > 1):
                actor_critic.data_parallel(cfg['training']['gpu_devices'])

        elif agent is None and cfg['learner']['algo'] == 'slam':
            assert cfg['learner']['slam_class'] is not None, 'Must define SLAM agent class'
            actor_critic = eval(cfg['learner']['slam_class'])(**cfg['learner']['slam_kwargs'])
            start_epoch = 0

        elif cfg['learner']['algo'] == 'expert':
            actor_critic = eval(cfg['learner']['algo_class'])(**cfg['learner']['algo_kwargs'])
            start_epoch = 0

        if cfg['learner']['algo'] == 'expert':
            assert 'debug_mode' in cfg['env']['env_specific_kwargs'] and cfg['env']['env_specific_kwargs']['debug_mode'], 'need to use debug mode with expert algo'

        if cfg['learner']['perception_network_reinit'] and cfg['learner']['algo'] == 'ppo':
            logger.info('Reinit perception network, use with caution')
            # do not reset map_tower and other parts of the TaskonomyFeaturesOnlyNetwork
            old_perception_unit = actor_critic.base.perception_unit
            new_perception_unit = eval(cfg['learner']['perception_network'])(
                cfg['learner']['num_stack'],
                **cfg['learner']['perception_network_kwargs'])
            new_perception_unit.main_perception = old_perception_unit  # main perception does not change
            actor_critic.base.perception_unit = new_perception_unit  # only x['taskonomy'] changes

            # match important configs of old model
            if (actor_critic.gpu_devices is None or len(actor_critic.gpu_devices) == 1) and len(cfg['training']['gpu_devices']) > 1:
                actor_critic.data_parallel(cfg['training']['gpu_devices'])
            actor_critic.gpu_devices = cfg['training']['gpu_devices']
            agent.gpu_devices = cfg['training']['gpu_devices']

        # Machinery for storing rollouts
        num_train_processes = cfg['env']['num_processes'] - cfg['env']['num_val_processes']
        num_val_processes = cfg['env']['num_val_processes']
        assert cfg['learner']['test'] or (cfg['env']['num_val_processes'] < cfg['env']['num_processes']), \
            "Can't train without some training processes!"
        current_obs = StackedSensorDictStorage(cfg['env']['num_processes'], cfg['learner']['num_stack'],
                                               retained_obs_shape)
        if not cfg['learner']['test']:
            current_train_obs = StackedSensorDictStorage(num_train_processes, cfg['learner']['num_stack'],
                                                         retained_obs_shape)
        logger.debug(f'Stacked obs shape {current_obs.obs_shape}')

        if cfg['learner']['use_replay'] and not cfg['learner']['test']:
            rollouts = RolloutSensorDictReplayBuffer(
                cfg['learner']['num_steps'],
                num_train_processes,
                current_obs.obs_shape,
                action_space,
                cfg['learner']['internal_state_size'],
                actor_critic,
                cfg['learner']['use_gae'],
                cfg['learner']['gamma'],
                cfg['learner']['tau'],
                cfg['learner']['replay_buffer_size'],
                batch_multiplier=cfg['learner']['rollout_value_batch_multiplier']
            )
        else:
            rollouts = RolloutSensorDictStorage(
                cfg['learner']['num_steps'],
                num_train_processes,
                current_obs.obs_shape,
                action_space,
                cfg['learner']['internal_state_size'])

        # Set up logging
        if cfg['saving']['logging_type'] == 'visdom':
            mlog = tnt.logger.VisdomMeterLogger(
                title=uuid, env=uuid,
                server=cfg['saving']['visdom_server'],
                port=cfg['saving']['visdom_port'],
                log_to_filename=cfg['saving']['visdom_log_file'])
        elif cfg['saving']['logging_type'] == 'tensorboard':
            mlog = tnt.logger.TensorboardMeterLogger(
                env=uuid,
                log_dir=cfg['saving']['log_dir'],
                plotstylecombined=True)
        else:
            raise NotImplementedError("Unknown logger type: ({cfg['saving']['logging_type']})")

        # Add metrics and logging to TB/Visdom
        loggable_metrics = ['metrics/rewards',
                            'diagnostics/dist_perplexity',
                            'diagnostics/lengths',
                            'diagnostics/max_importance_weight',
                            'diagnostics/value',
                            'losses/action_loss',
                            'losses/dist_entropy',
                            'losses/value_loss',
                            'introspect/alpha']
        if 'intrinsic_loss_types' in cfg['learner']['loss_kwargs']:
            for iloss in cfg['learner']['loss_kwargs']['intrinsic_loss_types']:
                loggable_metrics.append(f"losses/{iloss}")
        core_metrics = ['metrics/rewards', 'diagnostics/lengths']
        debug_metrics = ['debug/input_images']
        if 'habitat' in cfg['env']['env_name'].lower():
            for metric in ['metrics/collisions', 'metrics/spl', 'metrics/success']:
                loggable_metrics.append(metric)
                core_metrics.append(metric)
        for meter in loggable_metrics:
            mlog.add_meter(meter, tnt.meter.ValueSummaryMeter())
        for debug_meter in debug_metrics:
            mlog.add_meter(debug_meter, tnt.meter.SingletonMeter(), ptype='image')
        try:
            for attr in cfg['learner']['perception_network_kwargs']['extra_kwargs']['attrs_to_remember']:
                mlog.add_meter(f'diagnostics/{attr}', tnt.meter.ValueSummaryMeter(), ptype='histogram')
        except KeyError:
            pass
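        # attrs_to_remember presumably names intermediate activations that the
        # perception network caches for diagnostics; they get histogram meters here
        # and are collected into flog_keys_to_remove below to keep them out of the
        # results file log.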

        mlog.add_meter('config', tnt.meter.SingletonMeter(), ptype='text')
        mlog.update_meter(cfg_to_md(cfg, uuid), meters={'config'}, phase='train')

        # File loggers
        flog = tnt.logger.FileLogger(cfg['saving']['results_log_file'], overwrite=True)
        try:
            flog_keys_to_remove = [f'diagnostics/{k}' for k in cfg['learner']['perception_network_kwargs']['extra_kwargs']['attrs_to_remember']]
        except KeyError:
            warnings.warn('Unable to find flog keys to remove')
            flog_keys_to_remove = []
        reward_only_flog = tnt.logger.FileLogger(cfg['saving']['reward_log_file'], overwrite=True)

        # replay data to mlog, move metadata file
        if changed_log_dir:
            evkit.utils.logging.replay_logs(existing_log_paths, mlog)
            evkit.utils.logging.move_metadata_file(old_log_dir, cfg['saving']['log_dir'], uuid)

        ##########
        # LEARN! #
        ##########
        if cfg['training']['cuda']:
            if not cfg['learner']['test']:
                current_train_obs = current_train_obs.cuda(device=cfg['training']['gpu_devices'][0])
            current_obs = current_obs.cuda(device=cfg['training']['gpu_devices'][0])
            # rollouts.cuda(device=cfg['training']['gpu_devices'][0])  # rollout should be on RAM
            try:
                actor_critic.cuda(device=cfg['training']['gpu_devices'][0])
            except UnboundLocalError as e:
                logger.error(f'Cannot put actor critic on cuda. Are you using a checkpoint and is it being found/initialized properly? {e}')
                raise e

        # These variables are used to compute average rewards for all processes.
        episode_rewards = torch.zeros([cfg['env']['num_processes'], 1])
        episode_lengths = torch.zeros([cfg['env']['num_processes'], 1])
        episode_tracker = evkit.utils.logging.EpisodeTracker(cfg['env']['num_processes'])
        if cfg['learner']['test']:
            all_episodes = []
            actor_critic.eval()
            try:
                actor_critic.base.perception_unit.sidetuner.attrs_to_remember = []
            except AttributeError:
                pass  # not every perception network exposes a sidetuner with attrs_to_remember

        # First observation
        obs = envs.reset()
        current_obs.insert(obs)
        mask_done = torch.FloatTensor([[0.0] for _ in range(cfg['env']['num_processes'])]).cuda(device=cfg['training']['gpu_devices'][0], non_blocking=True)
        states = torch.zeros(cfg['env']['num_processes'], cfg['learner']['internal_state_size']).cuda(device=cfg['training']['gpu_devices'][0], non_blocking=True)
        try:
            actor_critic.reset(envs=envs)
        except TypeError:  # some policies' reset() does not accept an envs kwarg
            actor_critic.reset()

        # Main loop
        start_time = time.time()
        n_episodes_completed = 0
        num_updates = int(cfg['training']['num_frames']) // (cfg['learner']['num_steps'] * cfg['env']['num_processes'])
        if cfg['learner']['test']:
            logger.info(f"Running {cfg['learner']['test_k_episodes']}")
        else:
            logger.info(f"Running until num updates == {num_updates}")
        for j in range(start_epoch, num_updates, 1):
            for step in range(cfg['learner']['num_steps']):
                obs_unpacked = {k: current_obs.peek()[k].peek() for k in current_obs.peek()}
                if j == start_epoch and step < 10:
                    log_input_images(obs_unpacked, mlog, num_stack=cfg['learner']['num_stack'],
                                     key_names=['rgb_filled', 'map'], meter_name='debug/input_images', step_num=step)

                # Sample actions
                with torch.no_grad():
                    # value, action, action_log_prob, states = actor_critic.act(
                    #     {k:v.cuda(device=cfg['training']['gpu_devices'][0]) for k, v in obs_unpacked.items()},
                    #     states.cuda(device=cfg['training']['gpu_devices'][0]),
                    #     mask_done.cuda(device=cfg['training']['gpu_devices'][0]))
                    # All should already be on training.gpu_devices[0]
                    value, action, action_log_prob, states = actor_critic.act(
                        obs_unpacked, states, mask_done, cfg['learner']['deterministic'])
                cpu_actions = list(action.squeeze(1).cpu().numpy())
                obs, reward, done, info = envs.step(cpu_actions)
                mask_done_cpu = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
                mask_done = mask_done_cpu.cuda(device=cfg['training']['gpu_devices'][0], non_blocking=True)
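                # Keep both copies of the done mask: the CPU copy for the reward/length
                # bookkeeping below, the GPU copy for the policy state and rollout storage.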
                reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()
                episode_tracker.append(obs, cpu_actions)

                # log diagnostics
                if cfg['learner']['test']:
                    try:
                        mlog.update_meter(actor_critic.perplexity.cpu(), meters={'diagnostics/dist_perplexity'}, phase='val')
                        mlog.update_meter(actor_critic.entropy.cpu(), meters={'losses/dist_entropy'}, phase='val')
                        mlog.update_meter(value.cpu(), meters={'diagnostics/value'}, phase='val')
                    except AttributeError:
                        pass


                # Handle terminated episodes; logging values and computing the "done" mask
                episode_rewards += reward
                episode_lengths += (1 + cfg['env']['additional_repeat_count'])
                for i, (r, l, done_) in enumerate(zip(episode_rewards, episode_lengths, done)):  # Logging loop
                    if done_:
                        n_episodes_completed += 1
                        if cfg['learner']['test']:
                            info[i]['reward'] = r.item()
                            info[i]['length'] = l.item()
                            if 'debug_mode' in cfg['env']['env_specific_kwargs'] and cfg['env']['env_specific_kwargs']['debug_mode']:
                                info[i]['scene_id'] = envs.env.env.env._env.current_episode.scene_id
                                info[i]['episode_id'] = envs.env.env.env._env.current_episode.episode_id
                            all_episodes.append({
                                'info': info[i],
                                'history': episode_tracker.episodes[i][:-1]})
                        episode_tracker.clear_episode(i)
                        phase = 'train' if i < num_train_processes else 'val'
                        mlog.update_meter(r.item(), meters={'metrics/rewards'}, phase=phase)
                        mlog.update_meter(l.item(), meters={'diagnostics/lengths'}, phase=phase)
                        if 'habitat' in cfg['env']['env_name'].lower():
                            mlog.update_meter(info[i]["collisions"], meters={'metrics/collisions'}, phase=phase)
                            if scenario == 'PointNav':
                                mlog.update_meter(info[i]["spl"], meters={'metrics/spl'}, phase=phase)
                                mlog.update_meter(info[i]["success"], meters={'metrics/success'}, phase=phase)

                        # reset env then agent... note this only works for single process
                        if 'debug_mode' in cfg['env']['env_specific_kwargs'] and cfg['env']['env_specific_kwargs']['debug_mode']:
                            obs = envs.reset()
                        try:
                            actor_critic.reset(envs=envs)
                        except Exception:  # reset() may not accept an envs kwarg
                            actor_critic.reset()
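                # Zero the running reward/length counters for environments whose episode just ended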
                episode_rewards *= mask_done_cpu
                episode_lengths *= mask_done_cpu

                # Insert the new observation into RolloutStorage
                current_obs.insert(obs, mask_done)
                if not cfg['learner']['test']:
                    for k in obs:
                        if k in current_train_obs.sensor_names:
                            current_train_obs[k].insert(obs[k][:num_train_processes], mask_done[:num_train_processes])
                    rollouts.insert(current_train_obs.peek(),
                                    states[:num_train_processes],
                                    action[:num_train_processes],
                                    action_log_prob[:num_train_processes],
                                    value[:num_train_processes],
                                    reward[:num_train_processes],
                                    mask_done[:num_train_processes])
                    mlog.update_meter(value[:num_train_processes].mean().item(), meters={'diagnostics/value'},
                                      phase='train')

            # Training update
            if not cfg['learner']['test']:
                if not cfg['learner']['use_replay']:
                    # Moderate compute saving optimization (if no replay buffer):
                    #     Estimate future-discounted returns only once
                    with torch.no_grad():
                        next_value = actor_critic.get_value(rollouts.observations.at(-1),
                                                            rollouts.states[-1],
                                                            rollouts.masks[-1]).detach()
                    rollouts.compute_returns(next_value, cfg['learner']['use_gae'], cfg['learner']['gamma'],
                                             cfg['learner']['tau'])
                value_loss, action_loss, dist_entropy, max_importance_weight, info = agent.update(rollouts)
                rollouts.after_update()  # For the next iter: initial obs <- current observation

                # Update meters with latest training info
                mlog.update_meter(dist_entropy, meters={'losses/dist_entropy'})
                mlog.update_meter(np.exp(dist_entropy), meters={'diagnostics/dist_perplexity'})
                mlog.update_meter(value_loss, meters={'losses/value_loss'})
                mlog.update_meter(action_loss, meters={'losses/action_loss'})
                mlog.update_meter(max_importance_weight, meters={'diagnostics/max_importance_weight'})
                if 'intrinsic_loss_types' in cfg['learner']['loss_kwargs'] and len(cfg['learner']['loss_kwargs']['intrinsic_loss_types']) > 0:
                    for iloss in cfg['learner']['loss_kwargs']['intrinsic_loss_types']:
                        mlog.update_meter(info[iloss], meters={f'losses/{iloss}'})
                try:
                    for attr in cfg['learner']['perception_network_kwargs']['extra_kwargs']['attrs_to_remember']:
                        mlog.update_meter(info[attr].cpu(), meters={f'diagnostics/{attr}'})
                except KeyError:
                    pass

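                # Log the sigmoid of the first parameter whose name contains 'alpha' (if any),
                # e.g. a sidetuning blending weight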
                try:
                    if hasattr(actor_critic, 'module'):
                        alpha = [param for name, param in actor_critic.module.named_parameters() if 'alpha' in name][0]
                    else:
                        alpha = [param for name, param in actor_critic.named_parameters() if 'alpha' in name][0]
                    mlog.update_meter(torch.sigmoid(alpha).detach().item(), meters={f'introspect/alpha'})
                except IndexError:
                    pass

            # Main logging
            if j % cfg['saving']['log_interval'] == 0:
                torch.cuda.empty_cache()
                GPUtil.showUtilization()
                count_open()
                num_relevant_processes = num_val_processes if cfg['learner']['test'] else num_train_processes
                n_steps_since_logging = cfg['saving']['log_interval'] * num_relevant_processes * cfg['learner'][
                    'num_steps']
                total_num_steps = (j + 1) * num_relevant_processes * cfg['learner']['num_steps']

                logger.info("Update {}, num timesteps {}, FPS {}".format(
                    j + 1,
                    total_num_steps,
                    int(n_steps_since_logging / (time.time() - start_time))
                ))
                logger.info(f"Completed episodes: {n_episodes_completed}")
                viable_modes = ['val'] if cfg['learner']['test'] else ['train', 'val']
                for metric in core_metrics:  # Log to stdout
                    for mode in viable_modes:
                        if metric in core_metrics or mode == 'train':
                            mlog.print_meter(mode, total_num_steps, meterlist={metric})
                if not cfg['learner']['test']:
                    for mode in viable_modes:  # Log to files
                        results = mlog.peek_meter(phase=mode)
                        reward_only_flog.log(mode, {metric: results[metric] for metric in core_metrics})
                        if mode == 'train':
                            results_to_log = {}
                            results['step_num'] = j + 1
                            results_to_log['step_num'] = results['step_num']
                            for k,v in results.items():
                                if k in flog_keys_to_remove:
                                    warnings.warn(f'Removing {k} from results_log.pkl due to large size')
                                else:
                                    results_to_log[k] = v
                            flog.log('all_results', results_to_log)

                        mlog.reset_meter(total_num_steps, mode=mode)
                start_time = time.time()

            # Save checkpoint
            if not cfg['learner']['test'] and j % cfg['saving']['save_interval'] == 0:
                save_dir_absolute = os.path.join(cfg['saving']['log_dir'], cfg['saving']['save_dir'])
                save_checkpoint(
                    {'agent': agent, 'epoch': j},
                    save_dir_absolute, j)
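            # In test mode: once enough episodes have completed, dump them to validation.pth
            # and report mean SPL / success / reward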
            if 'test_k_episodes' in cfg['learner'] and n_episodes_completed >= cfg['learner']['test_k_episodes']:
                torch.save(all_episodes, os.path.join(cfg['saving']['log_dir'], 'validation.pth'))
                all_episodes = all_episodes[:cfg['learner']['test_k_episodes']]
                spl_mean = np.mean([episode['info']['spl'] for episode in all_episodes])
                success_mean = np.mean([episode['info']['success'] for episode in all_episodes])
                reward_mean = np.mean([episode['info']['reward'] for episode in all_episodes])
                logger.info('------------ done with testing -------------')
                logger.info(f'SPL: {spl_mean} --- Success: {success_mean} --- Reward: {reward_mean}')
                for metric in mlog.meter['val'].keys():
                    mlog.print_meter('val', -1, meterlist={metric})
                break

    # Clean up (either after ending normally or early [e.g. from a KeyboardInterrupt])
    finally:
        print(psutil.virtual_memory())
        GPUtil.showUtilization(all=True)
        try:
            logger.info("### Done - Killing envs.")
            if isinstance(envs, list):
                [env.close() for env in envs]
            else:
                envs.close()
            logger.info("Killed envs.")
        except UnboundLocalError:
            logger.info("No envs to kill!")
Example #4
0
def train(cfg, uuid):
    set_seed(cfg['training']['seed'])

    ############################################################
    # Logger
    ############################################################
    logger.setLevel(logging.INFO)
    logger.info(pprint.pformat(cfg))
    logger.debug(f'Loaded Torch version: {torch.__version__}')
    logger.debug(f'Using device: {device}')

    assert len(
        cfg['training']
        ['targets']) == 1, "Transferring is only supported for one target task"
    logger.info(
        f"Training ({ cfg['training']['sources']}) -> ({cfg['training']['targets']})"
    )

    ############################################################
    # Verify configs are consistent - baked version needs to match un-baked version
    ############################################################
    taskonomy_sources = [
        src for src in cfg['training']['sources'] if 'taskonomy' in src
    ]
    assert len(
        taskonomy_sources
    ) <= 1, 'We have no way of handling multiple taskonomy features right now'
    if len(taskonomy_sources) == 1:
        # TODO refactor
        # GenericSidetuneNetwork for Vision Transfer tasks
        if 'encoder_weights_path' in cfg['learner']['model_kwargs']:
            assert cfg['learner']['model_kwargs'][
                'encoder_weights_path'] is not None, 'if we have a taskonomy feature as a source, the model should reflect that'

        # PolicyWithBase for Imitation learning
        try:
            encoder_path = cfg['learner']['model_kwargs']['base_kwargs'][
                'perception_unit_kwargs']['extra_kwargs']['sidetune_kwargs'][
                    'encoder_weights_path']
            assert encoder_path is not None, 'if we have a taskonomy feature as a source, the model should reflect that'
        except KeyError:
            pass

    ############################################################
    # Data Loading
    ############################################################
    logger.debug(f'Starting data loaders')
    data_subfolders = cfg['training']['sources'][:]
    if 'bake_decodings' in cfg['training']['loss_kwargs'] and cfg['training'][
            'loss_kwargs']['bake_decodings']:
        # do not get encodings, convert encodings to decodings
        assert all([
            'encoding' in t for t in cfg['training']['targets']
        ]), 'Do not bake_decodings if your target is not an encoding'
        target_decodings = [
            t.replace('encoding', 'decoding')
            for t in cfg['training']['targets']
        ]
        data_subfolders += target_decodings
    elif not cfg['training']['suppress_target_and_use_annotator']:
        data_subfolders += cfg['training']['targets']
    else:  # use annotator
        cfg['training']['annotator'] = load_submodule(
            eval(cfg['training']['annotator_class']),
            cfg['training']['annotator_weights_path'],
            cfg['training']['annotator_kwargs']).eval()
        cfg['training']['annotator'] = cfg['training']['annotator'].to(device)

    if cfg['training']['use_masks']:
        data_subfolders += ['mask_valid']

    if cfg['training'][
            'dataloader_fn'] is None:  # Legacy support for old config type.
        warnings.warn(
            "Empty cfg.learner.dataloader_fn is deprecated and will be removed in a future version",
            DeprecationWarning)
        logger.info(f"Using split: {cfg['training']['split_to_use']}")
        dataloaders = taskonomy_dataset.get_dataloaders(
            cfg['training']['data_dir'],
            data_subfolders,
            batch_size=cfg['training']['batch_size'],
            batch_size_val=cfg['training']['batch_size_val'],
            zip_file_name=False,
            train_folders=eval(cfg['training']['split_to_use'])['train'],
            val_folders=eval(cfg['training']['split_to_use'])['val'],
            test_folders=eval(cfg['training']['split_to_use'])['test'],
            num_workers=cfg['training']['num_workers'],
            load_to_mem=cfg['training']['load_to_mem'],
            pin_memory=cfg['training']['pin_memory'])
    else:
        cfg['training']['dataloader_fn_kwargs']['tasks'] = data_subfolders
        dataloaders = eval(cfg['training']['dataloader_fn'])(
            **cfg['training']['dataloader_fn_kwargs'])

    ############################################################
    # Model (and possibly resume from checkpoint)
    ############################################################
    logger.debug(f'Setting up model')
    model = eval(cfg['learner']['model'])(**cfg['learner']['model_kwargs'])
    logger.info(
        f"Created model. Number of trainable parameters: {count_trainable_parameters(model)}."
    )

    loaded_optimizer = None
    start_epoch = 0
    ckpt_fpath = cfg['training']['resume_from_checkpoint_path']
    if ckpt_fpath is not None:
        if cfg['training']['resume_training'] and not os.path.exists(
                ckpt_fpath):
            logger.warning(
                f'Trying to resume training, but checkpoint path {ckpt_fpath} does not exist. Starting training from beginning...'
            )
        else:
            checkpoint = torch.load(ckpt_fpath)
            start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 0

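            # Checkpoints saved from a DataParallel-wrapped model prefix parameter names with
            # 'module.'; strip it so the weights load into a bare model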
            state_dict = {
                k.replace('module.', ''): v
                for k, v in checkpoint['state_dict'].items()
            }
            model.load_state_dict(state_dict)
            logger.info(
                f"Loaded model (epoch {start_epoch if 'epoch' in checkpoint else 'unknown'}) from {ckpt_fpath}"
            )

            loaded_optimizer = checkpoint['optimizer']
            logger.info(
                f"Loaded optimizer (epoch {start_epoch if 'epoch' in checkpoint else 'unknown'}) from {ckpt_fpath}"
            )

    model.to(device)
    if torch.cuda.device_count() > 1:
        logger.info(f"Using {torch.cuda.device_count()} GPUs!")
        assert cfg['learner'][
            'model'] != 'ConstantModel', 'ConstantModel (e.g. blind) does not operate with multiple devices'
        model = torch.nn.DataParallel(model)

    ############################################################
    # Loss Function
    ############################################################
    if cfg['training']['loss_fn'] == 'perceptual_l1':
        loss_fn = perceptual_l1_loss(
            cfg['training']['loss_kwargs']['decoder_path'],
            cfg['training']['loss_kwargs']['bake_decodings'])
    elif cfg['training']['loss_fn'] == 'perceptual_l2':
        loss_fn = perceptual_l2_loss(
            cfg['training']['loss_kwargs']['decoder_path'],
            cfg['training']['loss_kwargs']['bake_decodings'])
    elif cfg['training']['loss_fn'] == 'perceptual_cross_entropy':
        loss_fn = perceptual_cross_entropy_loss(
            cfg['training']['loss_kwargs']['decoder_path'],
            cfg['training']['loss_kwargs']['bake_decodings'])
    else:
        loss_fn = functools.partial(eval(cfg['training']['loss_fn']),
                                    **cfg['training']['loss_kwargs'])

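    # Optionally wrap the loss function with a regularizer constructed from the config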
    if 'regularizer_fn' in cfg['training'] and cfg['training'][
            'regularizer_fn'] is not None:
        assert torch.cuda.device_count(
        ) <= 1, 'Regularization does not support multi GPU, unable to access model attributes from DataParallel wrapper'
        bare_model = model.module if torch.cuda.device_count() > 1 else model
        loss_fn = eval(cfg['training']['regularizer_fn'])(
            loss_fn=loss_fn,
            model=bare_model,
            **cfg['training']['regularizer_kwargs'])

    ############################################################
    # Logging
    ############################################################
    flog = tnt.logger.FileLogger(cfg['saving']['results_log_file'],
                                 overwrite=True)
    mlog = get_logger(cfg, uuid)
    mlog.add_meter('config', tnt.meter.SingletonMeter(), ptype='text')
    mlog.update_meter(cfg_to_md(cfg, uuid), meters={'config'}, phase='train')
    mlog.add_meter('input_image', tnt.meter.ValueSummaryMeter(), ptype='image')
    mlog.add_meter('decoded_image',
                   tnt.meter.ValueSummaryMeter(),
                   ptype='image')
    mlog.add_meter(f'introspect/alpha', tnt.meter.ValueSummaryMeter())
    for loss in cfg['training']['loss_list']:
        mlog.add_meter(f'losses/{loss}', tnt.meter.ValueSummaryMeter())

    # Add Classification logs
    tasks = [
        t for t in SINGLE_IMAGE_TASKS
        if len([tt for tt in cfg['training']['targets'] if t in tt]) > 0
    ]
    if 'class_object' in tasks or 'class_scene' in tasks:
        mlog.add_meter('accuracy_top1',
                       tnt.meter.ClassErrorMeter(topk=[1], accuracy=True))
        mlog.add_meter('accuracy_top5',
                       tnt.meter.ClassErrorMeter(topk=[5], accuracy=True))
        mlog.add_meter('perplexity_pred', tnt.meter.ValueSummaryMeter())
        mlog.add_meter('perplexity_label', tnt.meter.ValueSummaryMeter())
        mlog.add_meter('diagnostics/class_histogram',
                       tnt.meter.ValueSummaryMeter(),
                       ptype='histogram')
        mlog.add_meter('diagnostics/confusion_matrix',
                       tnt.meter.ValueSummaryMeter(),
                       ptype='image')

    # Add Imitation Learning logs
    if cfg['training']['targets'][0] == 'action':
        mlog.add_meter('diagnostics/accuracy',
                       tnt.meter.ClassErrorMeter(topk=[1], accuracy=True))
        mlog.add_meter('diagnostics/perplexity', tnt.meter.ValueSummaryMeter())
        mlog.add_meter('diagnostics/class_histogram',
                       tnt.meter.ValueSummaryMeter(),
                       ptype='histogram')
        mlog.add_meter('diagnostics/confusion_matrix',
                       tnt.meter.ValueSummaryMeter(),
                       ptype='image')

    ############################################################
    # Training
    ############################################################
    if cfg['training']['train']:
        if cfg['training']['resume_training'] and loaded_optimizer is None:
            warnings.warn(
                'resume_training is set but the optimizer is not found, reinitializing optimizer'
            )
        if cfg['training']['resume_training'] and loaded_optimizer is not None:
            optimizer = loaded_optimizer
        else:
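            # Two parameter groups: parameters whose names contain 'merge_operator',
            # 'context', or 'alpha' get no weight decay; all other parameters use the
            # optimizer's default settings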
            optimizer = eval(
                cfg['learner']['optimizer_class'])(
                    [
                        {
                            'params': [
                                param
                                for name, param in model.named_parameters()
                                if 'merge_operator' in name
                                or 'context' in name or 'alpha' in name
                            ],
                            'weight_decay':
                            0.0
                        },
                        {
                            'params': [
                                param
                                for name, param in model.named_parameters()
                                if 'merge_operator' not in name and 'context'
                                not in name and 'alpha' not in name
                            ]
                        },
                    ],
                    lr=cfg['learner']['lr'],
                    **cfg['learner']['optimizer_kwargs'])
        scheduler = None
        if cfg['learner']['lr_scheduler_method'] is not None:
            scheduler = eval(cfg['learner']['lr_scheduler_method'])(
                optimizer, **cfg['learner']['lr_scheduler_method_kwargs'])
        logger.info("Starting training...")
        context = train_model(cfg,
                              model,
                              dataloaders,
                              loss_fn,
                              optimizer,
                              start_epoch=start_epoch,
                              num_epochs=cfg['training']['num_epochs'],
                              save_epochs=cfg['saving']['save_interval'],
                              scheduler=scheduler,
                              mlog=mlog,
                              flog=flog)

    ####################
    # Final Test
    ####################
    if cfg['training']['test']:
        run_kwargs = {
            'cfg': cfg,
            'mlog': mlog,
            'flog': flog,
            'optimizer': None,
            'loss_fn': loss_fn,
            'model': model,
            'use_thread': cfg['saving']['in_background'],
        }
        context, _ = run_one_epoch(dataloader=dataloaders['val'],
                                   epoch=0,
                                   train=False,
                                   **run_kwargs)

    logger.info('Waiting up to 10 minutes for all files to save...')
    [c.join(600) for c in context]
    logger.info('All saving is finished.')
    def __init__(self, ckpt_path, config_data):
        # Load agent
        self.action_space = spaces.Discrete(3)
        if ckpt_path is not None:
            checkpoint_obj = torch.load(ckpt_path)
            start_epoch = checkpoint_obj["epoch"]
            print("Loaded learner (epoch {}) from {}".format(start_epoch, ckpt_path), flush=True)
            agent = checkpoint_obj["agent"]
        else:
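            # No full agent checkpoint was given: rebuild the policy and PPO agent from
            # the config, then load the network weights and optimizer state from a
            # weights-only checkpoint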
            cfg = config_data['cfg']
            perception_model = eval(cfg['learner']['perception_network'])(
                cfg['learner']['num_stack'],
                **cfg['learner']['perception_network_kwargs'])
            base = NaivelyRecurrentACModule(
                perception_unit=perception_model,
                use_gru=cfg['learner']['recurrent_policy'],
                internal_state_size=cfg['learner']['internal_state_size'])
            actor_critic = PolicyWithBase(
                base, self.action_space,
                num_stack=cfg['learner']['num_stack'],
                takeover=None)
            if cfg['learner']['use_replay']:
                agent = PPOReplay(actor_critic,
                                  cfg['learner']['clip_param'],
                                  cfg['learner']['ppo_epoch'],
                                  cfg['learner']['num_mini_batch'],
                                  cfg['learner']['value_loss_coef'],
                                  cfg['learner']['entropy_coef'],
                                  cfg['learner']['on_policy_epoch'],
                                  cfg['learner']['off_policy_epoch'],
                                  lr=cfg['learner']['lr'],
                                  eps=cfg['learner']['eps'],
                                  max_grad_norm=cfg['learner']['max_grad_norm'])
            else:
                agent = PPO(actor_critic,
                            cfg['learner']['clip_param'],
                            cfg['learner']['ppo_epoch'],
                            cfg['learner']['num_mini_batch'],
                            cfg['learner']['value_loss_coef'],
                            cfg['learner']['entropy_coef'],
                            lr=cfg['learner']['lr'],
                            eps=cfg['learner']['eps'],
                            max_grad_norm=cfg['learner']['max_grad_norm'])
            weights_path = cfg['eval_kwargs']['weights_only_path']
            ckpt = torch.load(weights_path)
            agent.actor_critic.load_state_dict(ckpt['state_dict'])
            agent.optimizer = ckpt['optimizer']
        self.actor_critic = agent.actor_critic

        self.takeover_policy = None
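        # Optional "backout" takeover policy, either hardcoded or loaded from a trained checkpoint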
        if config_data['cfg']['learner']['backout']['use_backout']:
            backout_type = config_data['cfg']['learner']['backout']['backout_type']
            if backout_type == 'hardcoded':
                self.takeover_policy = BackoutPolicy(
                    patience=config_data['cfg']['learner']['backout']['patience'],
                    num_processes=1,
                    unstuck_dist=config_data['cfg']['learner']['backout']['unstuck_dist'],
                    randomize_actions=config_data['cfg']['learner']['backout']['randomize_actions'],
                )
            elif backout_type == 'trained':
                backout_ckpt = config_data['cfg']['learner']['backout']['backout_ckpt_path']
                assert backout_ckpt is not None, 'need a checkpoint to use a trained backout'
                backout_checkpoint_obj = torch.load(backout_ckpt)
                backout_start_epoch = backout_checkpoint_obj["epoch"]
                print("Loaded takeover policy at (epoch {}) from {}".format(backout_start_epoch, backout_ckpt), flush=True)
                backout_policy = backout_checkpoint_obj["agent"].actor_critic

                self.takeover_policy = TrainedBackoutPolicy(
                    patience=config_data['cfg']['learner']['backout']['patience'],
                    num_processes=1,
                    policy=backout_policy,
                    unstuck_dist=config_data['cfg']['learner']['backout']['unstuck_dist'],
                    num_takeover_steps=config_data['cfg']['learner']['backout']['num_takeover_steps'],
                )
            else:
                assert False, f'do not recognize backout type {backout_type}'
        self.actor_critic.takeover = self.takeover_policy

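        # Optional validator used to vet the policy's chosen actions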
        self.validator = None
        if config_data['cfg']['learner']['validator']['use_validator']:
            validator_type = config_data['cfg']['learner']['validator']['validator_type']
            if validator_type == 'jerk':
                self.validator = JerkAvoidanceValidator()
            else:
                assert False, f'do not recognize validator {validator_type}'
        self.actor_critic.action_validator = self.validator

        # Set up spaces
        self.target_dim = config_data['cfg']['env']['env_specific_kwargs']['target_dim']

        map_dim = None
        self.omap = None
        if config_data['cfg']['env']['use_map']:
            self.map_kwargs = config_data['cfg']['env']['habitat_map_kwargs']
            map_dim = 84
            assert self.map_kwargs['map_building_size'] > 0, 'If we are using map in habitat, please set building size to be positive!'

        obs_space = get_obs_space(image_dim=256, target_dim=self.target_dim, map_dim=map_dim)

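        # Build the pre- and post-aggregation observation transforms specified in the config;
        # each transform also updates obs_space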
        preprocessing_fn_pre_agg = eval(config_data['cfg']['env']['transform_fn_pre_aggregation'])
        self.transform_pre_agg, obs_space = preprocessing_fn_pre_agg(obs_space)

        preprocessing_fn_post_agg = eval(config_data['cfg']['env']['transform_fn_post_aggregation'])
        self.transform_post_agg, obs_space = preprocessing_fn_post_agg(obs_space)

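        # Frame-stacking storage for a single evaluation process, covering only the configured sensors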
        self.current_obs = StackedSensorDictStorage(
            1,
            config_data['cfg']['learner']['num_stack'],
            {k: v.shape for k, v in obs_space.spaces.items()
             if k in config_data['cfg']['env']['sensors']})
        print(f'Stacked obs shape {self.current_obs.obs_shape}')

        self.current_obs = self.current_obs.cuda()
        self.actor_critic.cuda()

        self.hidden_size = config_data['cfg']['learner']['internal_state_size']
        self.test_recurrent_hidden_states = None
        self.not_done_masks = None

        self.episode_rgbs = []
        self.episode_pgs = []
        self.episode_entropy = []
        self.episode_num = 0
        self.t = 0
        self.episode_lengths = []
        self.episode_values = []
        self.last_action = None

        # Set up logging
        if config_data['cfg']['saving']['logging_type'] == 'visdom':
            self.mlog = tnt.logger.VisdomMeterLogger(
                title=config_data['uuid'], env=config_data['uuid'], server=config_data['cfg']['saving']['visdom_server'],
                port=config_data['cfg']['saving']['visdom_port'],
                log_to_filename=config_data['cfg']['saving']['visdom_log_file']
            )
            self.use_visdom = True
        elif config_data['cfg']['saving']['logging_type'] == 'tensorboard':
            self.mlog = tnt.logger.TensorboardMeterLogger(
                env=config_data['uuid'],
                log_dir=config_data['cfg']['saving']['log_dir'],
                plotstylecombined=True
            )
            self.use_visdom = False
        else:
            assert False, 'no proper logger!'

        self.log_dir = config_data['cfg']['saving']['log_dir']
        self.save_eval_videos = config_data['cfg']['saving']['save_eval_videos']
        self.mlog.add_meter('config', tnt.meter.SingletonMeter(), ptype='text')
        self.mlog.update_meter(cfg_to_md(config_data['cfg'], config_data['uuid']), meters={'config'}, phase='val')