def prepare(self,
                ckpt_dir: str,
                optimizer: str = 'lars',
                learning_rate: float = 1.0,
                weight_decay: float = 0.0,
                temperature: float = 0.07,
                distributed: bool = False,
                local_rank: int = 0,
                mixed_precision: bool = True,
                **kwargs):  # pylint: disable=unused-argument
        """Prepare training."""

        # Distributed training (optional)
        if distributed:
            self.backbone = DistributedDataParallel(
                nn.SyncBatchNorm.convert_sync_batchnorm(self.backbone).to(local_rank),
                device_ids=[local_rank]
            )
            self.projector = DistributedDataParallel(
                nn.SyncBatchNorm.convert_sync_batchnorm(self.projector).to(local_rank),
                device_ids=[local_rank]
            )
        else:
            self.backbone.to(local_rank)
            self.projector.to(local_rank)

        # Mixed precision training (optional)
        self.scaler = torch.cuda.amp.GradScaler() if mixed_precision else None
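        # Note: the scaler is presumably consumed in the (not shown) training step.
        # A sketch of the standard torch.cuda.amp pattern it implies (not necessarily
        # the exact code used in this repository):
        #   with torch.cuda.amp.autocast():
        #       loss = self.loss_function(z1, z2)   # z1, z2: projected views (illustrative names)
        #   self.scaler.scale(loss).backward()
        #   self.scaler.step(self.optimizer)
        #   self.scaler.update()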

        # Optimization
        self.optimizer = get_optimizer(
            params=[
                {'params': self.backbone.parameters()},
                {'params': self.projector.parameters()},
            ],
            name=optimizer,
            lr=learning_rate,
            weight_decay=weight_decay
        )

        # Loss function
        self.loss_function = SimCLRLoss(
            temperature=temperature,
            distributed=distributed,
            local_rank=local_rank
        )
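        # For reference, SimCLRLoss presumably implements the standard NT-Xent objective
        # from the SimCLR paper (an assumption, not taken from this repository's source):
        # for a positive pair (i, j) of augmented views,
        #   l(i, j) = -log( exp(sim(z_i, z_j) / temperature)
        #                   / sum_{k != i} exp(sim(z_i, z_k) / temperature) )
        # where sim is cosine similarity; with distributed=True, negatives are gathered
        # across ranks before the denominator is computed.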

        # TensorBoard
        self.writer = SummaryWriter(ckpt_dir) if local_rank == 0 else None

        self.ckpt_dir = ckpt_dir                # pylint: disable=attribute-defined-outside-init
        self.distributed = distributed          # pylint: disable=attribute-defined-outside-init
        self.local_rank = local_rank            # pylint: disable=attribute-defined-outside-init
        self.mixed_precision = mixed_precision  # pylint: disable=attribute-defined-outside-init
        self.prepared = True                    # pylint: disable=attribute-defined-outside-init
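
# A minimal usage sketch for the `prepare` method above. Assumptions (illustrative,
# not taken from this repository): a task class `SimCLR` whose constructor sets the
# `backbone` and `projector` attributes that `prepare` references.
#
#   task = SimCLR(backbone=backbone, projector=projector)
#   task.prepare(
#       ckpt_dir='./checkpoints/simclr',
#       optimizer='lars',
#       learning_rate=1.0,
#       weight_decay=0.0,
#       temperature=0.07,
#       distributed=False,
#       local_rank=0,
#       mixed_precision=True,
#   )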
Example #2
def main(args):
    """Main function."""

    # 1. Configurations
    torch.backends.cudnn.benchmark = True
    ENCODER_CONFIGS, DECODER_CONFIGS, Config, Encoder, Decoder = \
        AVAILABLE_MODELS[args.backbone_type]

    config = Config(args)
    config.save()

    logfile = os.path.join(config.checkpoint_dir, 'main.log')
    logger = get_logger(stream=False, logfile=logfile)

    # 2. Data
    input_transform = get_transform(config.data,
                                    size=config.input_size,
                                    mode=config.augmentation,
                                    noise=config.noise)
    target_transform = get_transform(config.data,
                                     size=config.input_size,
                                     mode='test')
    if config.data == 'wm811k':
        train_set = torch.utils.data.ConcatDataset([
            WM811KForDenoising('./data/wm811k/unlabeled/train/',
                               input_transform, target_transform),
            WM811KForDenoising('./data/wm811k/labeled/train/', input_transform,
                               target_transform),
        ])
        valid_set = torch.utils.data.ConcatDataset([
            WM811KForDenoising('./data/wm811k/unlabeled/valid/',
                               input_transform, target_transform),
            WM811KForDenoising('./data/wm811k/labeled/valid/', input_transform,
                               target_transform),
        ])
        test_set = torch.utils.data.ConcatDataset([
            WM811KForDenoising('./data/wm811k/unlabeled/test/',
                               input_transform, target_transform),
            WM811KForDenoising('./data/wm811k/labeled/test/', input_transform,
                               target_transform),
        ])
    else:
        raise ValueError(
            f"Denoising only supports 'wm811k' data. Received '{config.data}'."
        )

    # 3. Model
    encoder = Encoder(ENCODER_CONFIGS[config.backbone_config],
                      in_channels=IN_CHANNELS[config.data])
    decoder = Decoder(DECODER_CONFIGS[config.backbone_config],
                      input_shape=encoder.output_shape,
                      output_shape=(OUT_CHANNELS[config.data],
                                    config.input_size, config.input_size))

    # 4. Optimization
    params = [{
        'params': encoder.parameters()
    }, {
        'params': decoder.parameters()
    }]
    optimizer = get_optimizer(params=params,
                              name=config.optimizer,
                              lr=config.learning_rate,
                              weight_decay=config.weight_decay,
                              momentum=config.momentum)
    scheduler = get_scheduler(optimizer=optimizer,
                              name=config.scheduler,
                              epochs=config.epochs,
                              milestone=config.milestone,
                              warmup_steps=config.warmup_steps)

    # 5. Experiment (Denoising)
    experiment_kwargs = {
        'encoder': encoder,
        'decoder': decoder,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'loss_function': nn.CrossEntropyLoss(reduction='mean'),
        'metrics': None,
        'checkpoint_dir': config.checkpoint_dir,
        'write_summary': config.write_summary,
    }
    experiment = Denoising(**experiment_kwargs)

    # 6. Run (train, evaluate, and test model)
    run_kwargs = {
        'train_set': train_set,
        'valid_set': valid_set,
        'test_set': test_set,
        'epochs': config.epochs,
        'batch_size': config.batch_size,
        'num_workers': config.num_workers,
        'device': config.device,
        'logger': logger,
        'save_every': config.save_every,
    }

    logger.info(f"Data: {config.data}")
    logger.info(f"Augmentation: {config.augmentation}")
    logger.info(
        f"Train : Valid : Test = {len(train_set):,} : {len(valid_set):,} : {len(test_set):,}"
    )
    logger.info(
        f"Trainable parameters ({encoder.__class__.__name__}): {encoder.num_parameters:,}"
    )
    logger.info(
        f"Trainable parameters ({decoder.__class__.__name__}): {decoder.num_parameters:,}"
    )
    logger.info(f"Saving model checkpoints to: {experiment.checkpoint_dir}")
    logger.info(
        f"Epochs: {run_kwargs['epochs']}, Batch size: {run_kwargs['batch_size']}"
    )
    logger.info(
        f"Workers: {run_kwargs['num_workers']}, Device: {run_kwargs['device']}"
    )

    steps_per_epoch = (len(train_set) + config.batch_size - 1) // config.batch_size
    logger.info(f"Training steps per epoch: {steps_per_epoch:,}")
    logger.info(
        f"Total number of training iterations: {steps_per_epoch * config.epochs:,}"
    )

    experiment.run(**run_kwargs)
    logger.handlers.clear()
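
# Hypothetical entry point for the script above (a sketch only; the real CLI parser
# and its flags are defined elsewhere in the repository):
#
#   if __name__ == '__main__':
#       main(parse_args())   # parse_args: assumed helper returning the argument namespace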
Example #3
def main():
    print("---------------------")
    print("Actions")
    print("STOP", HabitatSimActions.STOP)
    print("FORWARD", HabitatSimActions.MOVE_FORWARD)
    print("LEFT", HabitatSimActions.TURN_LEFT)
    print("RIGHT", HabitatSimActions.TURN_RIGHT)

    log_dir = "{}/models/{}/".format(args.dump_location, args.exp_name)
    dump_dir = "{}/dump/{}/".format(args.dump_location, args.exp_name)
    tb_dir = log_dir + "tensorboard"
    if not os.path.exists(tb_dir):
        os.makedirs(tb_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists("{}/images/".format(dump_dir)):
        os.makedirs("{}/images/".format(dump_dir))
    logging.basicConfig(
        filename=log_dir + 'train.log',
        level=logging.INFO)
    print("Dumping at {}".format(log_dir))
    print("Arguments starting with ", args)
    logging.info(args)
    device = args.device = torch.device("cuda:0" if args.cuda else "cpu")
    # Logging and loss variables
    num_scenes = args.num_processes
    num_episodes = int(args.num_episodes)

    # setting up rewards and losses
    # policy_loss = 0
    best_cost = float('inf')
    costs = deque(maxlen=1000)
    exp_costs = deque(maxlen=1000)
    pose_costs = deque(maxlen=1000)
    l_masks = torch.zeros(num_scenes).float().to(device)
    # best_local_loss = np.inf
    # if args.eval:
    #     traj_lengths = args.max_episode_length // args.num_local_steps
    # l_action_losses = deque(maxlen=1000)
    print("Setup rewards")

    print("starting envrionments ...")
    # Starting environments
    torch.set_num_threads(1)
    envs = make_vec_envs(args)
    obs, infos = envs.reset()
    print("environments reset")

    # show_gpu_usage()
    # Initialize map variables
    ### Full map consists of 4 channels containing the following:
    ### 1. Obstacle Map
    ### 2. Explored Area
    ### 3. Current Agent Location
    ### 4. Past Agent Locations
    print("creating maps and poses ")
    torch.set_grad_enabled(False)
    # Calculating full and local map sizes
    map_size = args.map_size_cm // args.map_resolution
    full_w, full_h = map_size, map_size
    local_w, local_h = int(full_w / args.global_downscaling), \
                       int(full_h / args.global_downscaling)
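    # Worked example with assumed, typical values (map_size_cm=2400, map_resolution=5 cm,
    # global_downscaling=2): map_size = 2400 // 5 = 480, so the full map is 480x480 cells
    # covering 24 m x 24 m, and each local map is 240x240 cells covering 12 m x 12 m.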
    # Initializing full and local map
    full_map = torch.zeros(num_scenes, 4, full_w, full_h).float().to(device)
    local_map = torch.zeros(num_scenes, 4, local_w, local_h).float().to(device)
    # Initial full and local pose
    full_pose = torch.zeros(num_scenes, 3).float().to(device)
    local_pose = torch.zeros(num_scenes, 3).float().to(device)
    # Origin of local map
    origins = np.zeros((num_scenes, 3))
    # Local Map Boundaries
    lmb = np.zeros((num_scenes, 4)).astype(int)
    ### Planner pose inputs have 7 dimensions
    ### 1-3 store continuous global agent location
    ### 4-7 store local map boundaries
    planner_pose_inputs = np.zeros((num_scenes, 7))

    # show_gpu_usage()
    start_full_pose = np.zeros(3)
    start_full_pose[:2] = args.map_size_cm / 100.0 / 2.0

    def init_map_and_pose():
        full_map.fill_(0.)
        full_pose.fill_(0.)
        full_pose[:, :2] = args.map_size_cm / 100.0 / 2.0

        full_pose_np = full_pose.cpu().numpy()
        planner_pose_inputs[:, :3] = full_pose_np
        for e in range(num_scenes):
            r, c = full_pose_np[e, 1], full_pose_np[e, 0]
            loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                            int(c * 100.0 / args.map_resolution)]

            full_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.0

            lmb[e] = get_local_map_boundaries((loc_r, loc_c),
                                              (local_w, local_h),
                                              (full_w, full_h))

            planner_pose_inputs[e, 3:] = lmb[e]
            origins[e] = [lmb[e][2] * args.map_resolution / 100.0,
                          lmb[e][0] * args.map_resolution / 100.0, 0.]
        for e in range(num_scenes):
            local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]]
            local_pose[e] = full_pose[e] - \
                            torch.from_numpy(origins[e]).to(device).float()

    init_map_and_pose()
    print("maps and poses intialized")

    print("defining architecture")
    # slam
    nslam_module = Neural_SLAM_Module(args).to(device)
    slam_optimizer = get_optimizer(nslam_module.parameters(), args.slam_optimizer)
    slam_memory = FIFOMemory(args.slam_memory_size)

    # # Local policy
    # print("policy observation space", envs.observation_space.spaces['rgb'])
    # print("policy action space ", envs.action_space)
    # l_observation_space = gym.spaces.Box(0, 255,
    #                                      (3,
    #                                       args.frame_width,
    #                                       args.frame_width), dtype='uint8')
    # # todo change this to use envs.observation_space.spaces['rgb'].shape later
    # l_policy = Local_IL_Policy(l_observation_space.shape, envs.action_space.n,
    #                            recurrent=args.use_recurrent_local,
    #                            hidden_size=args.local_hidden_size,
    #                            deterministic=args.use_deterministic_local).to(device)
    # local_optimizer = get_optimizer(l_policy.parameters(), args.local_optimizer)
    # show_gpu_usage()

    print("loading model weights")
    # Loading model
    if args.load_slam != "0":
        print("Loading slam {}".format(args.load_slam))
        state_dict = torch.load(args.load_slam,
                                map_location=lambda storage, loc: storage)
        nslam_module.load_state_dict(state_dict)
    if not args.train_slam:
        nslam_module.eval()

    #     if args.load_local != "0":
    #         print("Loading local {}".format(args.load_local))
    #         state_dict = torch.load(args.load_local,
    #                                 map_location=lambda storage, loc: storage)
    #         l_policy.load_state_dict(state_dict)
    #     if not args.train_local:
    #         l_policy.eval()

    print("predicting first pose and initializing maps")
    # if not (args.use_gt_pose and args.use_gt_map):
    # delta_pose is the expected change in pose when action is applied at
    # the current pose in the absence of noise.
    # initially no action is applied so this is zero.
    delta_poses = torch.from_numpy(np.zeros(local_pose.shape)).float().to(device)
    # initial estimate for local pose and local map from first observation,
    # initialized (zero) pose and map
    _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
        nslam_module(obs, obs, delta_poses, local_map[:, 0, :, :],
                     local_map[:, 1, :, :], local_pose)
    # if args.use_gt_pose:
    #     # todo update local_pose here
    #     full_pose = envs.get_gt_pose()
    #     for e in range(num_scenes):
    #         local_pose[e] = full_pose[e] - \
    #                         torch.from_numpy(origins[e]).to(device).float()
    # if args.use_gt_map:
    #     full_map = envs.get_gt_map()
    #     for e in range(num_scenes):
    #         local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]]
    print("slam module returned pose and maps")

    # NOT NEEDED : 4/29
    local_pose_np = local_pose.cpu().numpy()
    # update local map for each scene - input for planner
    for e in range(num_scenes):
        r, c = local_pose_np[e, 1], local_pose_np[e, 0]
        loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                        int(c * 100.0 / args.map_resolution)]
        local_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.

    #     # todo get goal from env here
    global_goals = envs.get_goal_coords().int()

    # Compute planner inputs
    planner_inputs = [{} for e in range(num_scenes)]
    for e, p_input in enumerate(planner_inputs):
        p_input['goal'] = global_goals[e].detach().cpu().numpy()
        p_input['map_pred'] = local_map[e, 0, :, :].detach().cpu().numpy()
        p_input['exp_pred'] = local_map[e, 1, :, :].detach().cpu().numpy()
        p_input['pose_pred'] = planner_pose_inputs[e]

    # Output stores local goals as well as the ground-truth action
    planner_out = envs.get_short_term_goal(planner_inputs)
    # planner output contains:
    # Distance to short term goal - positive discretized number
    # angle to short term goal - angle from -180 to 180, discretized into 5-degree buckets (multiply by 5 to get the true angle)
    # GT action - action to be taken according to planner (int)
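    # e.g. a stored angle value of 12 decodes to 12 * 5 = 60 degrees.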

    # going to step through the episodes, so cache previous information
    last_obs = obs.detach()
    local_rec_states = torch.zeros(num_scenes, args.local_hidden_size).to(device)
    start = time.time()
    total_num_steps = -1
    torch.set_grad_enabled(False)

    print("starting episodes")
    with TensorboardWriter(tb_dir, flush_secs=60) as writer:
        for itr_counter, ep_num in enumerate(range(num_episodes)):
            print("------------------------------------------------------")
            print("Episode", ep_num)

            # if itr_counter >= 20:
            #     print("DONE WE FIXED IT")
            #     die()
            # for step in range(args.max_episode_length):
            step_bar = tqdm(range(args.max_episode_length))
            for step in step_bar:
                # print("------------------------------------------------------")
                # print("episode ", ep_num, "step ", step)
                total_num_steps += 1
                l_step = step % args.num_local_steps

                # Local Policy
                # ------------------------------------------------------------------
                # cache previous information
                del last_obs
                last_obs = obs.detach()
                #             if not args.use_optimal_policy and not args.use_shortest_path_gt:
                #                 local_masks = l_masks
                #                 local_goals = planner_out[:, :-1].to(device).long()

                #                 if args.train_local:
                #                     torch.set_grad_enabled(True)

                #                 # local policy "step"
                #                 action, action_prob, local_rec_states = l_policy(
                #                     obs,
                #                     local_rec_states,
                #                     local_masks,
                #                     extras=local_goals,
                #                 )

                #                 if args.train_local:
                #                     action_target = planner_out[:, -1].long().to(device)
                #                     # doubt: this is probably wrong? one is action probability and the other is action
                #                     policy_loss += nn.CrossEntropyLoss()(action_prob, action_target)
                #                     torch.set_grad_enabled(False)
                #                 l_action = action.cpu()
                #             else:
                #                 if args.use_optimal_policy:
                #                     l_action = planner_out[:, -1]
                #                 else:
                #                     l_action = envs.get_optimal_gt_action()

                l_action = envs.get_optimal_action(start_full_pose, full_pose).cpu()
                # if step > 10:
                #     l_action = torch.tensor([HabitatSimActions.STOP])

                # ------------------------------------------------------------------
                # ------------------------------------------------------------------
                # Env step
                # print("stepping with action ", l_action)
                # try:
                obs, rew, done, infos = envs.step(l_action)

                # ------------------------------------------------------------------
                # Reinitialize variables when episode ends
                # doubt: what if the episode ends before max_episode_length?
                # maybe add `or done` here?
                if l_action == HabitatSimActions.STOP or step == args.max_episode_length - 1:
                    print("l_action", l_action)
                    init_map_and_pose()
                    del last_obs
                    last_obs = obs.detach()
                    print("Reinitialize since at end of episode ")
                    obs, infos = envs.reset()

                # except:
                #     print("can't do that")
                #     print(l_action)
                #     init_map_and_pose()
                #     del last_obs
                #     last_obs = obs.detach()
                #     print("Reinitialize since at end of episode ")
                #     break
                # step_bar.set_description("rew, done, info-sensor_pose, pose_err (stepping) {}, {}, {}, {}".format(rew, done, infos[0]['sensor_pose'], infos[0]['pose_err']))
                if total_num_steps % args.log_interval == 0 and False:
                    print("rew, done, info-sensor_pose, pose_err after stepping ", rew, done, infos[0]['sensor_pose'],
                          infos[0]['pose_err'])
                # l_masks = torch.FloatTensor([0 if x else 1
                #                              for x in done]).to(device)

                # ------------------------------------------------------------------
                # # ------------------------------------------------------------------
                # # Reinitialize variables when episode ends
                # # doubt what if episode ends before max_episode_length?
                # # maybe add (or done ) here?
                # if step == args.max_episode_length - 1 or l_action == HabitatSimActions.STOP:  # Last episode step
                #     init_map_and_pose()
                #     del last_obs
                #     last_obs = obs.detach()
                #     print("Reinitialize since at end of episode ")
                #     break

                # ------------------------------------------------------------------
                # ------------------------------------------------------------------
                # Neural SLAM Module
                delta_poses_np = np.zeros(local_pose_np.shape)
                if args.train_slam:
                    # Add frames to memory
                    for env_idx in range(num_scenes):
                        env_obs = obs[env_idx].to("cpu")
                        env_poses = torch.from_numpy(np.asarray(
                            delta_poses_np[env_idx]
                        )).float().to("cpu")
                        env_gt_fp_projs = torch.from_numpy(np.asarray(
                            infos[env_idx]['fp_proj']
                        )).unsqueeze(0).float().to("cpu")
                        env_gt_fp_explored = torch.from_numpy(np.asarray(
                            infos[env_idx]['fp_explored']
                        )).unsqueeze(0).float().to("cpu")
                        # TODO change pose err here
                        env_gt_pose_err = torch.from_numpy(np.asarray(
                            infos[env_idx]['pose_err']
                        )).float().to("cpu")
                        slam_memory.push(
                            (last_obs[env_idx].cpu(), env_obs, env_poses),
                            (env_gt_fp_projs, env_gt_fp_explored, env_gt_pose_err))
                        delta_poses_np[env_idx] = get_delta_pose(local_pose_np[env_idx], l_action[env_idx])
                delta_poses = torch.from_numpy(delta_poses_np).float().to(device)
                # print("delta pose from SLAM ", delta_poses)
                _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
                    nslam_module(last_obs, obs, delta_poses, local_map[:, 0, :, :],
                                 local_map[:, 1, :, :], local_pose, build_maps=True)
                # print("updated local pose from SLAM ", local_pose)
                # if args.use_gt_pose:
                #     # todo update local_pose here
                #     full_pose = envs.get_gt_pose()
                #     for e in range(num_scenes):
                #         local_pose[e] = full_pose[e] - \
                #                         torch.from_numpy(origins[e]).to(device).float()
                #     print("updated local pose from gt ", local_pose)
                # if args.use_gt_map:
                #     full_map = envs.get_gt_map()
                #     for e in range(num_scenes):
                #         local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]]
                #     print("updated local map from gt")
                local_pose_np = local_pose.cpu().numpy()
                planner_pose_inputs[:, :3] = local_pose_np + origins
                local_map[:, 2, :, :].fill_(0.)  # Resetting current location channel
                for e in range(num_scenes):
                    r, c = local_pose_np[e, 1], local_pose_np[e, 0]
                    loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                                    int(c * 100.0 / args.map_resolution)]
                    local_map[e, 2:, loc_r - 2:loc_r + 3, loc_c - 2:loc_c + 3] = 1.
                if l_step == args.num_local_steps - 1:
                    # For every global step, update the full and local maps
                    for e in range(num_scenes):
                        full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]] = \
                            local_map[e]
                        full_pose[e] = local_pose[e] + \
                                       torch.from_numpy(origins[e]).to(device).float()

                        full_pose_np = full_pose[e].cpu().numpy()
                        r, c = full_pose_np[1], full_pose_np[0]
                        loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                                        int(c * 100.0 / args.map_resolution)]

                        lmb[e] = get_local_map_boundaries((loc_r, loc_c),
                                                          (local_w, local_h),
                                                          (full_w, full_h))

                        planner_pose_inputs[e, 3:] = lmb[e]
                        origins[e] = [lmb[e][2] * args.map_resolution / 100.0,
                                      lmb[e][0] * args.map_resolution / 100.0, 0.]

                        local_map[e] = full_map[e, :,
                                       lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]]
                        local_pose[e] = full_pose[e] - \
                                        torch.from_numpy(origins[e]).to(device).float()

                local_pose_np = local_pose.cpu().numpy()
                planner_pose_inputs[:, :3] = local_pose_np + origins
                local_map[:, 2, :, :].fill_(0.)  # Resetting current location channel
                for e in range(num_scenes):
                    r, c = local_pose_np[e, 1], local_pose_np[e, 0]
                    loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                                    int(c * 100.0 / args.map_resolution)]
                    local_map[e, 2:, loc_r - 2:loc_r + 3, loc_c - 2:loc_c + 3] = 1.

                planner_inputs = [{} for e in range(num_scenes)]
                for e, p_input in enumerate(planner_inputs):
                    p_input['map_pred'] = local_map[e, 0, :, :].cpu().numpy()
                    p_input['exp_pred'] = local_map[e, 1, :, :].cpu().numpy()
                    p_input['pose_pred'] = planner_pose_inputs[e]
                    p_input['goal'] = global_goals[e].cpu().numpy()
                planner_out = envs.get_short_term_goal(planner_inputs)

                ### TRAINING
                torch.set_grad_enabled(True)
                # ------------------------------------------------------------------
                # Train Neural SLAM Module
                if args.train_slam and len(slam_memory) > args.slam_batch_size:
                    for _ in range(args.slam_iterations):
                        inputs, outputs = slam_memory.sample(args.slam_batch_size)
                        b_obs_last, b_obs, b_poses = inputs
                        gt_fp_projs, gt_fp_explored, gt_pose_err = outputs

                        b_obs = b_obs.to(device)
                        b_obs_last = b_obs_last.to(device)
                        b_poses = b_poses.to(device)

                        gt_fp_projs = gt_fp_projs.to(device)
                        gt_fp_explored = gt_fp_explored.to(device)
                        gt_pose_err = gt_pose_err.to(device)

                        b_proj_pred, b_fp_exp_pred, _, _, b_pose_err_pred, _ = \
                            nslam_module(b_obs_last, b_obs, b_poses,
                                         None, None, None,
                                         build_maps=False)
                        loss = 0
                        if args.proj_loss_coeff > 0:
                            proj_loss = F.binary_cross_entropy(b_proj_pred,
                                                               gt_fp_projs)
                            costs.append(proj_loss.item())
                            loss += args.proj_loss_coeff * proj_loss

                        if args.exp_loss_coeff > 0:
                            exp_loss = F.binary_cross_entropy(b_fp_exp_pred,
                                                              gt_fp_explored)
                            exp_costs.append(exp_loss.item())
                            loss += args.exp_loss_coeff * exp_loss

                        if args.pose_loss_coeff > 0:
                            pose_loss = torch.nn.MSELoss()(b_pose_err_pred,
                                                           gt_pose_err)
                            pose_costs.append(args.pose_loss_coeff *
                                              pose_loss.item())
                            loss += args.pose_loss_coeff * pose_loss

                        if args.train_slam:
                            slam_optimizer.zero_grad()
                            loss.backward()
                            slam_optimizer.step()

                        del b_obs_last, b_obs, b_poses
                        del gt_fp_projs, gt_fp_explored, gt_pose_err
                        del b_proj_pred, b_fp_exp_pred, b_pose_err_pred

                # ------------------------------------------------------------------

                # ------------------------------------------------------------------
                # Train Local Policy
                # if (l_step + 1) % args.local_policy_update_freq == 0 \
                #         and args.train_local:
                #     local_optimizer.zero_grad()
                #     policy_loss.backward()
                #     local_optimizer.step()
                #     l_action_losses.append(policy_loss.item())
                #     policy_loss = 0
                #     local_rec_states = local_rec_states.detach_()
                # ------------------------------------------------------------------

                # Finish Training
                torch.set_grad_enabled(False)
                # ------------------------------------------------------------------

                # ------------------------------------------------------------------
                # Logging
                writer.add_scalar("SLAM_Loss_Proj", np.mean(costs), total_num_steps)
                writer.add_scalar("SLAM_Loss_Exp", np.mean(exp_costs), total_num_steps)
                writer.add_scalar("SLAM_Loss_Pose", np.mean(pose_costs), total_num_steps)

                gettime = lambda: str(datetime.now()).split('.')[0]
                if total_num_steps % args.log_interval == 0:
                    end = time.time()
                    time_elapsed = time.gmtime(end - start)
                    log = " ".join([
                        "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                        "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                        gettime(),
                        "num timesteps {},".format(total_num_steps *
                                                   num_scenes),
                        "FPS {},".format(int(total_num_steps * num_scenes \
                                             / (end - start)))
                    ])

                    log += "\n\tLosses:"

                    # if args.train_local and len(l_action_losses) > 0:
                    #     log += " ".join([
                    #         " Local Loss:",
                    #         "{:.3f},".format(
                    #             np.mean(l_action_losses))
                    #     ])

                    if args.train_slam and len(costs) > 0:
                        log += " SLAM Loss proj/exp/pose: {:.4f}/{:.4f}/{:.4f}".format(
                            np.mean(costs),
                            np.mean(exp_costs),
                            np.mean(pose_costs))

                    print(log)
                    logging.info(log)
                # ------------------------------------------------------------------

                # ------------------------------------------------------------------
                # Save best models
                if (total_num_steps * num_scenes) % args.save_interval < \
                        num_scenes:

                    # Save Neural SLAM Model
                    if len(costs) >= 1000 and np.mean(costs) < best_cost \
                            and not args.eval:
                        print("Saved best model")
                        best_cost = np.mean(costs)
                        torch.save(nslam_module.state_dict(),
                                   os.path.join(log_dir, "model_best.slam"))

                    # Save Local Policy Model
                    # if len(l_action_losses) >= 100 and \
                    #         (np.mean(l_action_losses) <= best_local_loss) \
                    #         and not args.eval:
                    #     torch.save(l_policy.state_dict(),
                    #                os.path.join(log_dir, "model_best.local"))
                    #
                    #     best_local_loss = np.mean(l_action_losses)

                # Save periodic models
                if (total_num_steps * num_scenes) % args.save_periodic < \
                        num_scenes:
                    step = total_num_steps * num_scenes
                    if args.train_slam:
                        torch.save(nslam_module.state_dict(),
                                   os.path.join(dump_dir,
                                                "periodic_{}.slam".format(step)))
                    # if args.train_local:
                    #     torch.save(l_policy.state_dict(),
                    #                os.path.join(dump_dir,
                    #                             "periodic_{}.local".format(step)))
                # ------------------------------------------------------------------

                if l_action == HabitatSimActions.STOP:  # Last episode step
                    break

    # Print and save model performance numbers during evaluation
    if args.eval:
        logfile = open("{}/explored_area.txt".format(dump_dir), "w+")
        for e in range(num_scenes):
            for i in range(explored_area_log[e].shape[0]):
                logfile.write(str(explored_area_log[e, i]) + "\n")
                logfile.flush()

        logfile.close()

        logfile = open("{}/explored_ratio.txt".format(dump_dir), "w+")
        for e in range(num_scenes):
            for i in range(explored_ratio_log[e].shape[0]):
                logfile.write(str(explored_ratio_log[e, i]) + "\n")
                logfile.flush()

        logfile.close()

        log = "Final Exp Area: \n"
        for i in range(explored_area_log.shape[2]):
            log += "{:.5f}, ".format(
                np.mean(explored_area_log[:, :, i]))

        log += "\nFinal Exp Ratio: \n"
        for i in range(explored_ratio_log.shape[2]):
            log += "{:.5f}, ".format(
                np.mean(explored_ratio_log[:, :, i]))

        print(log)
        logging.info(log)
Example #4
def main():
    # Setup Logging
    log_dir = "{}/models/{}/".format(args.dump_location, args.exp_name)
    dump_dir = "{}/dump/{}/".format(args.dump_location, args.exp_name)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists("{}/images/".format(dump_dir)):
        os.makedirs("{}/images/".format(dump_dir))

    logging.basicConfig(filename=log_dir + 'train.log', level=logging.INFO)
    print("Dumping at {}".format(log_dir))
    print(args)
    logging.info(args)

    # Logging and loss variables
    num_scenes = args.num_processes
    num_episodes = int(args.num_episodes)
    device = args.device = torch.device("cuda:0" if args.cuda else "cpu")
    policy_loss = 0

    best_cost = 100000
    costs = deque(maxlen=1000)
    exp_costs = deque(maxlen=1000)
    pose_costs = deque(maxlen=1000)

    g_masks = torch.ones(num_scenes).float().to(device)
    l_masks = torch.zeros(num_scenes).float().to(device)

    best_local_loss = np.inf
    best_g_reward = -np.inf

    if args.eval:
        traj_lengths = args.max_episode_length // args.num_local_steps
        explored_area_log = np.zeros((num_scenes, num_episodes, traj_lengths))
        explored_ratio_log = np.zeros((num_scenes, num_episodes, traj_lengths))

    g_episode_rewards = deque(maxlen=1000)

    l_action_losses = deque(maxlen=1000)

    g_value_losses = deque(maxlen=1000)
    g_action_losses = deque(maxlen=1000)
    g_dist_entropies = deque(maxlen=1000)

    per_step_g_rewards = deque(maxlen=1000)

    g_process_rewards = np.zeros((num_scenes))

    # Starting environments
    torch.set_num_threads(1)
    envs = make_vec_envs(args)
    obs, infos = envs.reset()

    # Initialize map variables
    ### Full map consists of 4 channels containing the following:
    ### 1. Obstacle Map
    ### 2. Explored Area
    ### 3. Current Agent Location
    ### 4. Past Agent Locations

    torch.set_grad_enabled(False)

    # Calculating full and local map sizes
    map_size = args.map_size_cm // args.map_resolution
    full_w, full_h = map_size, map_size
    local_w, local_h = int(full_w / args.global_downscaling), \
                       int(full_h / args.global_downscaling)

    # Initializing full and local map
    full_map = torch.zeros(num_scenes, 4, full_w, full_h).float().to(device)
    local_map = torch.zeros(num_scenes, 4, local_w, local_h).float().to(device)

    # Initial full and local pose
    full_pose = torch.zeros(num_scenes, 3).float().to(device)
    local_pose = torch.zeros(num_scenes, 3).float().to(device)

    # Origin of local map
    origins = np.zeros((num_scenes, 3))

    # Local Map Boundaries
    lmb = np.zeros((num_scenes, 4)).astype(int)

    ### Planner pose inputs have 7 dimensions
    ### 1-3 store continuous global agent location
    ### 4-7 store local map boundaries
    planner_pose_inputs = np.zeros((num_scenes, 7))

    def init_map_and_pose():
        full_map.fill_(0.)
        full_pose.fill_(0.)
        full_pose[:, :2] = args.map_size_cm / 100.0 / 2.0

        locs = full_pose.cpu().numpy()
        planner_pose_inputs[:, :3] = locs
        for e in range(num_scenes):
            r, c = locs[e, 1], locs[e, 0]
            loc_r, loc_c = [
                int(r * 100.0 / args.map_resolution),
                int(c * 100.0 / args.map_resolution)
            ]

            full_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.0

            lmb[e] = get_local_map_boundaries(
                (loc_r, loc_c), (local_w, local_h), (full_w, full_h))

            planner_pose_inputs[e, 3:] = lmb[e]
            origins[e] = [
                lmb[e][2] * args.map_resolution / 100.0,
                lmb[e][0] * args.map_resolution / 100.0, 0.
            ]

        for e in range(num_scenes):
            local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1],
                                    lmb[e, 2]:lmb[e, 3]]
            local_pose[e] = full_pose[e] - \
                            torch.from_numpy(origins[e]).to(device).float()

    init_map_and_pose()

    # Global policy observation space
    g_observation_space = gym.spaces.Box(0,
                                         1, (8, local_w, local_h),
                                         dtype='uint8')

    # Global policy action space
    g_action_space = gym.spaces.Box(low=0.0,
                                    high=1.0,
                                    shape=(2, ),
                                    dtype=np.float32)

    # Local policy observation space
    l_observation_space = gym.spaces.Box(
        0, 255, (3, args.frame_width, args.frame_width), dtype='uint8')

    # Local and Global policy recurrent layer sizes
    l_hidden_size = args.local_hidden_size
    g_hidden_size = args.global_hidden_size

    # slam
    nslam_module = Neural_SLAM_Module(args).to(device)
    slam_optimizer = get_optimizer(nslam_module.parameters(),
                                   args.slam_optimizer)

    # Global policy
    g_policy = RL_Policy(g_observation_space.shape,
                         g_action_space,
                         base_kwargs={
                             'recurrent': args.use_recurrent_global,
                             'hidden_size': g_hidden_size,
                             'downscaling': args.global_downscaling
                         }).to(device)
    g_agent = algo.PPO(g_policy,
                       args.clip_param,
                       args.ppo_epoch,
                       args.num_mini_batch,
                       args.value_loss_coef,
                       args.entropy_coef,
                       lr=args.global_lr,
                       eps=args.eps,
                       max_grad_norm=args.max_grad_norm)

    # Local policy
    l_policy = Local_IL_Policy(
        l_observation_space.shape,
        envs.action_space.n,
        recurrent=args.use_recurrent_local,
        hidden_size=l_hidden_size,
        deterministic=args.use_deterministic_local).to(device)
    local_optimizer = get_optimizer(l_policy.parameters(),
                                    args.local_optimizer)

    # Storage
    g_rollouts = GlobalRolloutStorage(args.num_global_steps, num_scenes,
                                      g_observation_space.shape,
                                      g_action_space, g_policy.rec_state_size,
                                      1).to(device)

    slam_memory = FIFOMemory(args.slam_memory_size)

    # Loading model
    if args.load_slam != "0":
        print("Loading slam {}".format(args.load_slam))
        state_dict = torch.load(args.load_slam,
                                map_location=lambda storage, loc: storage)
        nslam_module.load_state_dict(state_dict)

    if not args.train_slam:
        nslam_module.eval()

    if args.load_global != "0":
        print("Loading global {}".format(args.load_global))
        state_dict = torch.load(args.load_global,
                                map_location=lambda storage, loc: storage)
        g_policy.load_state_dict(state_dict)

    if not args.train_global:
        g_policy.eval()

    if args.load_local != "0":
        print("Loading local {}".format(args.load_local))
        state_dict = torch.load(args.load_local,
                                map_location=lambda storage, loc: storage)
        l_policy.load_state_dict(state_dict)

    if not args.train_local:
        l_policy.eval()

    # Predict map from frame 1:
    poses = torch.from_numpy(
        np.asarray([
            infos[env_idx]['sensor_pose'] for env_idx in range(num_scenes)
        ])).float().to(device)

    _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
        nslam_module(obs, obs, poses, local_map[:, 0, :, :],
                     local_map[:, 1, :, :], local_pose)

    # Compute Global policy input
    locs = local_pose.cpu().numpy()
    global_input = torch.zeros(num_scenes, 8, local_w, local_h)
    global_orientation = torch.zeros(num_scenes, 1).long()

    for e in range(num_scenes):
        r, c = locs[e, 1], locs[e, 0]
        loc_r, loc_c = [
            int(r * 100.0 / args.map_resolution),
            int(c * 100.0 / args.map_resolution)
        ]

        local_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.
        global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)

    global_input[:, 0:4, :, :] = local_map.detach()
    global_input[:, 4:, :, :] = nn.MaxPool2d(args.global_downscaling)(full_map)
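    # Channels 0-3 are the local map; channels 4-7 are the full map max-pooled down to
    # the local resolution, e.g. with global_downscaling=2 a 480x480 full map (assumed
    # size) is pooled to 240x240 so it matches (local_w, local_h).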

    g_rollouts.obs[0].copy_(global_input)
    g_rollouts.extras[0].copy_(global_orientation)

    # Run Global Policy (global_goals = Long-Term Goal)
    g_value, g_action, g_action_log_prob, g_rec_states = \
        g_policy.act(
            g_rollouts.obs[0],
            g_rollouts.rec_states[0],
            g_rollouts.masks[0],
            extras=g_rollouts.extras[0],
            deterministic=False
        )

    cpu_actions = nn.Sigmoid()(g_action).cpu().numpy()
    global_goals = [[int(action[0] * local_w),
                     int(action[1] * local_h)] for action in cpu_actions]
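    # Worked example: the sigmoid squashes each action coordinate into (0, 1), which is
    # then scaled to local-map cells; e.g. action (0.25, 0.75) with local_w = local_h = 240
    # (an assumed value) yields a long-term goal at cell (60, 180).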

    # Compute planner inputs
    planner_inputs = [{} for e in range(num_scenes)]
    for e, p_input in enumerate(planner_inputs):
        p_input['goal'] = global_goals[e]
        p_input['map_pred'] = global_input[e, 0, :, :].detach().cpu().numpy()
        p_input['exp_pred'] = global_input[e, 1, :, :].detach().cpu().numpy()
        p_input['pose_pred'] = planner_pose_inputs[e]

    # Output stores local goals as well as the ground-truth action
    output = envs.get_short_term_goal(planner_inputs)

    last_obs = obs.detach()
    local_rec_states = torch.zeros(num_scenes, l_hidden_size).to(device)
    start = time.time()

    total_num_steps = -1
    g_reward = 0

    torch.set_grad_enabled(False)

    for ep_num in range(num_episodes):
        for step in range(args.max_episode_length):
            total_num_steps += 1

            g_step = (step // args.num_local_steps) % args.num_global_steps
            eval_g_step = step // args.num_local_steps + 1
            l_step = step % args.num_local_steps

            # ------------------------------------------------------------------
            # Local Policy
            del last_obs
            last_obs = obs.detach()
            local_masks = l_masks
            local_goals = output[:, :-1].to(device).long()

            if args.train_local:
                torch.set_grad_enabled(True)

            action, action_prob, local_rec_states = l_policy(
                obs,
                local_rec_states,
                local_masks,
                extras=local_goals,
            )

            if args.train_local:
                action_target = output[:, -1].long().to(device)
                policy_loss += nn.CrossEntropyLoss()(action_prob,
                                                     action_target)
                torch.set_grad_enabled(False)
            l_action = action.cpu()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Env step
            obs, rew, done, infos = envs.step(l_action)

            l_masks = torch.FloatTensor([0 if x else 1
                                         for x in done]).to(device)
            g_masks *= l_masks
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Reinitialize variables when episode ends
            if step == args.max_episode_length - 1:  # Last episode step
                init_map_and_pose()
                del last_obs
                last_obs = obs.detach()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Neural SLAM Module
            if args.train_slam:
                # Add frames to memory
                for env_idx in range(num_scenes):
                    env_obs = obs[env_idx].to("cpu")
                    env_poses = torch.from_numpy(
                        np.asarray(
                            infos[env_idx]['sensor_pose'])).float().to("cpu")
                    env_gt_fp_projs = torch.from_numpy(
                        np.asarray(infos[env_idx]['fp_proj'])).unsqueeze(
                            0).float().to("cpu")
                    env_gt_fp_explored = torch.from_numpy(
                        np.asarray(infos[env_idx]['fp_explored'])).unsqueeze(
                            0).float().to("cpu")
                    env_gt_pose_err = torch.from_numpy(
                        np.asarray(
                            infos[env_idx]['pose_err'])).float().to("cpu")
                    slam_memory.push(
                        (last_obs[env_idx].cpu(), env_obs, env_poses),
                        (env_gt_fp_projs, env_gt_fp_explored, env_gt_pose_err))

            poses = torch.from_numpy(
                np.asarray([
                    infos[env_idx]['sensor_pose']
                    for env_idx in range(num_scenes)
                ])).float().to(device)

            _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
                nslam_module(last_obs, obs, poses, local_map[:, 0, :, :],
                             local_map[:, 1, :, :], local_pose, build_maps=True)

            locs = local_pose.cpu().numpy()
            planner_pose_inputs[:, :3] = locs + origins
            local_map[:, 2, :, :].fill_(0.)  # Resetting current location channel
            for e in range(num_scenes):
                r, c = locs[e, 1], locs[e, 0]
                loc_r, loc_c = [
                    int(r * 100.0 / args.map_resolution),
                    int(c * 100.0 / args.map_resolution)
                ]

                local_map[e, 2:, loc_r - 2:loc_r + 3, loc_c - 2:loc_c + 3] = 1.
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Global Policy
            if l_step == args.num_local_steps - 1:
                # For every global step, update the full and local maps
                for e in range(num_scenes):
                    full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]] = \
                        local_map[e]
                    full_pose[e] = local_pose[e] + \
                                   torch.from_numpy(origins[e]).to(device).float()

                    locs = full_pose[e].cpu().numpy()
                    r, c = locs[1], locs[0]
                    loc_r, loc_c = [
                        int(r * 100.0 / args.map_resolution),
                        int(c * 100.0 / args.map_resolution)
                    ]

                    lmb[e] = get_local_map_boundaries(
                        (loc_r, loc_c), (local_w, local_h), (full_w, full_h))

                    planner_pose_inputs[e, 3:] = lmb[e]
                    origins[e] = [
                        lmb[e][2] * args.map_resolution / 100.0,
                        lmb[e][0] * args.map_resolution / 100.0, 0.
                    ]

                    local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1],
                                            lmb[e, 2]:lmb[e, 3]]
                    local_pose[e] = full_pose[e] - \
                                    torch.from_numpy(origins[e]).to(device).float()

                locs = local_pose.cpu().numpy()
                for e in range(num_scenes):
                    global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)
                global_input[:, 0:4, :, :] = local_map
                global_input[:, 4:, :, :] = \
                    nn.MaxPool2d(args.global_downscaling)(full_map)

                if False:
                    for i in range(4):
                        ax[i].clear()
                        ax[i].set_yticks([])
                        ax[i].set_xticks([])
                        ax[i].set_yticklabels([])
                        ax[i].set_xticklabels([])
                        ax[i].imshow(global_input.cpu().numpy()[0, 4 + i])
                    plt.gcf().canvas.flush_events()
                    # plt.pause(0.1)
                    fig.canvas.start_event_loop(0.001)
                    plt.gcf().canvas.flush_events()

                # Get exploration reward and metrics
                g_reward = torch.from_numpy(
                    np.asarray([
                        infos[env_idx]['exp_reward']
                        for env_idx in range(num_scenes)
                    ])).float().to(device)

                if args.eval:
                    g_reward = g_reward * 50.0  # Convert reward to area in m2

                g_process_rewards += g_reward.cpu().numpy()
                g_total_rewards = g_process_rewards * \
                                  (1 - g_masks.cpu().numpy())
                g_process_rewards *= g_masks.cpu().numpy()
                per_step_g_rewards.append(np.mean(g_reward.cpu().numpy()))

                if np.sum(g_total_rewards) != 0:
                    for tr in g_total_rewards:
                        if tr != 0:
                            g_episode_rewards.append(tr)

                if args.eval:
                    exp_ratio = torch.from_numpy(
                        np.asarray([
                            infos[env_idx]['exp_ratio']
                            for env_idx in range(num_scenes)
                        ])).float()

                    for e in range(num_scenes):
                        explored_area_log[e, ep_num, eval_g_step - 1] = \
                            explored_area_log[e, ep_num, eval_g_step - 2] + \
                            g_reward[e].cpu().numpy()
                        explored_ratio_log[e, ep_num, eval_g_step - 1] = \
                            explored_ratio_log[e, ep_num, eval_g_step - 2] + \
                            exp_ratio[e].cpu().numpy()

                # Add samples to global policy storage
                g_rollouts.insert(global_input, g_rec_states, g_action,
                                  g_action_log_prob, g_value, g_reward,
                                  g_masks, global_orientation)

                # Sample long-term goal from global policy
                g_value, g_action, g_action_log_prob, g_rec_states = \
                    g_policy.act(
                        g_rollouts.obs[g_step + 1],
                        g_rollouts.rec_states[g_step + 1],
                        g_rollouts.masks[g_step + 1],
                        extras=g_rollouts.extras[g_step + 1],
                        deterministic=False
                    )
                cpu_actions = nn.Sigmoid()(g_action).cpu().numpy()
                global_goals = [[
                    int(action[0] * local_w),
                    int(action[1] * local_h)
                ] for action in cpu_actions]

                g_reward = 0
                g_masks = torch.ones(num_scenes).float().to(device)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Get short term goal
            planner_inputs = [{} for e in range(num_scenes)]
            for e, p_input in enumerate(planner_inputs):
                p_input['map_pred'] = local_map[e, 0, :, :].cpu().numpy()
                p_input['exp_pred'] = local_map[e, 1, :, :].cpu().numpy()
                p_input['pose_pred'] = planner_pose_inputs[e]
                p_input['goal'] = global_goals[e]

            output = envs.get_short_term_goal(planner_inputs)
            # ------------------------------------------------------------------

            ### TRAINING
            torch.set_grad_enabled(True)
            # ------------------------------------------------------------------
            # Train Neural SLAM Module
            if args.train_slam and len(slam_memory) > args.slam_batch_size:
                for _ in range(args.slam_iterations):
                    inputs, outputs = slam_memory.sample(args.slam_batch_size)
                    b_obs_last, b_obs, b_poses = inputs
                    gt_fp_projs, gt_fp_explored, gt_pose_err = outputs

                    b_obs = b_obs.to(device)
                    b_obs_last = b_obs_last.to(device)
                    b_poses = b_poses.to(device)

                    gt_fp_projs = gt_fp_projs.to(device)
                    gt_fp_explored = gt_fp_explored.to(device)
                    gt_pose_err = gt_pose_err.to(device)

                    b_proj_pred, b_fp_exp_pred, _, _, b_pose_err_pred, _ = \
                        nslam_module(b_obs_last, b_obs, b_poses,
                                     None, None, None,
                                     build_maps=False)
                    loss = 0
                    if args.proj_loss_coeff > 0:
                        proj_loss = F.binary_cross_entropy(
                            b_proj_pred, gt_fp_projs)
                        costs.append(proj_loss.item())
                        loss += args.proj_loss_coeff * proj_loss

                    if args.exp_loss_coeff > 0:
                        exp_loss = F.binary_cross_entropy(
                            b_fp_exp_pred, gt_fp_explored)
                        exp_costs.append(exp_loss.item())
                        loss += args.exp_loss_coeff * exp_loss

                    if args.pose_loss_coeff > 0:
                        pose_loss = torch.nn.MSELoss()(b_pose_err_pred,
                                                       gt_pose_err)
                        pose_costs.append(args.pose_loss_coeff *
                                          pose_loss.item())
                        loss += args.pose_loss_coeff * pose_loss

                    if args.train_slam:
                        slam_optimizer.zero_grad()
                        loss.backward()
                        slam_optimizer.step()

                    del b_obs_last, b_obs, b_poses
                    del gt_fp_projs, gt_fp_explored, gt_pose_err
                    del b_proj_pred, b_fp_exp_pred, b_pose_err_pred

            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Train Local Policy
            if (l_step + 1) % args.local_policy_update_freq == 0 \
                    and args.train_local:
                local_optimizer.zero_grad()
                policy_loss.backward()
                local_optimizer.step()
                l_action_losses.append(policy_loss.item())
                policy_loss = 0
                local_rec_states = local_rec_states.detach_()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Train Global Policy
            if g_step % args.num_global_steps == args.num_global_steps - 1 \
                    and l_step == args.num_local_steps - 1:
                if args.train_global:
                    g_next_value = g_policy.get_value(
                        g_rollouts.obs[-1],
                        g_rollouts.rec_states[-1],
                        g_rollouts.masks[-1],
                        extras=g_rollouts.extras[-1]).detach()

                    g_rollouts.compute_returns(g_next_value, args.use_gae,
                                               args.gamma, args.tau)
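                    # Returns are discounted (optionally with GAE via use_gae/gamma/tau)
                    # and bootstrapped with the value of the final rollout observation;
                    # the PPO update below consumes them.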
                    g_value_loss, g_action_loss, g_dist_entropy = \
                        g_agent.update(g_rollouts)
                    g_value_losses.append(g_value_loss)
                    g_action_losses.append(g_action_loss)
                    g_dist_entropies.append(g_dist_entropy)
                g_rollouts.after_update()
            # ------------------------------------------------------------------

            # Finish Training
            torch.set_grad_enabled(False)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Logging
            if total_num_steps % args.log_interval == 0:
                end = time.time()
                time_elapsed = time.gmtime(end - start)
                log = " ".join([
                    "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                    "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                    "num timesteps {},".format(total_num_steps *
                                               num_scenes),
                    "FPS {},".format(int(total_num_steps * num_scenes \
                                         / (end - start)))
                ])

                log += "\n\tRewards:"

                if len(g_episode_rewards) > 0:
                    log += " ".join([
                        " Global step mean/med rew:",
                        "{:.4f}/{:.4f},".format(np.mean(per_step_g_rewards),
                                                np.median(per_step_g_rewards)),
                        " Global eps mean/med/min/max eps rew:",
                        "{:.3f}/{:.3f}/{:.3f}/{:.3f},".format(
                            np.mean(g_episode_rewards),
                            np.median(g_episode_rewards),
                            np.min(g_episode_rewards),
                            np.max(g_episode_rewards))
                    ])

                log += "\n\tLosses:"

                if args.train_local and len(l_action_losses) > 0:
                    log += " ".join([
                        " Local Loss:",
                        "{:.3f},".format(np.mean(l_action_losses))
                    ])

                if args.train_global and len(g_value_losses) > 0:
                    log += " ".join([
                        " Global Loss value/action/dist:",
                        "{:.3f}/{:.3f}/{:.3f},".format(
                            np.mean(g_value_losses), np.mean(g_action_losses),
                            np.mean(g_dist_entropies))
                    ])

                if args.train_slam and len(costs) > 0:
                    log += " ".join([
                        " SLAM Loss proj/exp/pose:",
                        "{:.4f}/{:.4f}/{:.4f}".format(np.mean(costs),
                                                      np.mean(exp_costs),
                                                      np.mean(pose_costs))
                    ])

                print(log)
                logging.info(log)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Save best models
            if (total_num_steps * num_scenes) % args.save_interval < \
                    num_scenes:

                # Save Neural SLAM Model
                if len(costs) >= 1000 and np.mean(costs) < best_cost \
                        and not args.eval:
                    best_cost = np.mean(costs)
                    torch.save(nslam_module.state_dict(),
                               os.path.join(log_dir, "model_best.slam"))

                # Save Local Policy Model
                if len(l_action_losses) >= 100 and \
                        (np.mean(l_action_losses) <= best_local_loss) \
                        and not args.eval:
                    torch.save(l_policy.state_dict(),
                               os.path.join(log_dir, "model_best.local"))

                    best_local_loss = np.mean(l_action_losses)

                # Save Global Policy Model
                if len(g_episode_rewards) >= 100 and \
                        (np.mean(g_episode_rewards) >= best_g_reward) \
                        and not args.eval:
                    torch.save(g_policy.state_dict(),
                               os.path.join(log_dir, "model_best.global"))
                    best_g_reward = np.mean(g_episode_rewards)

            # Save periodic models
            if (total_num_steps * num_scenes) % args.save_periodic < \
                    num_scenes:
                step = total_num_steps * num_scenes
                if args.train_slam:
                    torch.save(
                        nslam_module.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.slam".format(step)))
                if args.train_local:
                    torch.save(
                        l_policy.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.local".format(step)))
                if args.train_global:
                    torch.save(
                        g_policy.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.global".format(step)))
            # ------------------------------------------------------------------

    # Print and save model performance numbers during evaluation
    if args.eval:
        logfile = open("{}/explored_area.txt".format(dump_dir), "w+")
        for e in range(num_scenes):
            for i in range(explored_area_log[e].shape[0]):
                logfile.write(str(explored_area_log[e, i]) + "\n")
                logfile.flush()

        logfile.close()

        logfile = open("{}/explored_ratio.txt".format(dump_dir), "w+")
        for e in range(num_scenes):
            for i in range(explored_ratio_log[e].shape[0]):
                logfile.write(str(explored_ratio_log[e, i]) + "\n")
                logfile.flush()

        logfile.close()

        log = "Final Exp Area: \n"
        for i in range(explored_area_log.shape[2]):
            log += "{:.5f}, ".format(np.mean(explored_area_log[:, :, i]))

        log += "\nFinal Exp Ratio: \n"
        for i in range(explored_ratio_log.shape[2]):
            log += "{:.5f}, ".format(np.mean(explored_ratio_log[:, :, i]))

        print(log)
        logging.info(log)
    def prepare(self,
                ckpt_dir: str,
                optimizer: str,
                learning_rate: float = 0.01,
                weight_decay: float = 1e-4,
                cosine_warmup: int = 10,
                cosine_cycles: int = 1,
                cosine_min_lr: float = 5e-3,
                epochs: int = 1000,
                batch_size: int = 256,
                num_workers: int = 0,
                key_momentum: float = 0.999,
                distributed: bool = False,
                local_rank: int = 0,
                mixed_precision: bool = True,
                resume: str = None):
        """Prepare MoCo pre-training."""

        # Set attributes
        self.ckpt_dir = ckpt_dir                # pylint: disable=attribute-defined-outside-init
        self.epochs = epochs                    # pylint: disable=attribute-defined-outside-init
        self.batch_size = batch_size            # pylint: disable=attribute-defined-outside-init
        self.num_workers = num_workers          # pylint: disable=attribute-defined-outside-init
        self.key_momentum = key_momentum        # pylint: disable=attribute-defined-outside-init
        self.distributed = distributed          # pylint: disable=attribute-defined-outside-init
        self.local_rank = local_rank            # pylint: disable=attribute-defined-outside-init
        self.mixed_precision = mixed_precision  # pylint: disable=attribute-defined-outside-init
        self.resume = resume                    # pylint: disable=attribute-defined-outside-init

        """
        Initialize optimizer & LR scheduler.
            1. If training from scratch, optimizer states are created automatically
                on the device of their parameters, so nothing extra is required.
            2. If resuming from a model checkpoint, however, optimizer states must be
                placed on the device of the current `local_rank`. A common approach is:
                    a) Load the checkpoint on 'cpu': `torch.load(ckpt, map_location='cpu')`.
                    b) Manually move all optimizer states to the appropriate device
                        (a minimal sketch follows after this method).
        """  # pylint: disable=pointless-string-statement
        self.optimizer = get_optimizer(
            params=self.net_q.parameters(),
            name=optimizer,
            lr=learning_rate,
            weight_decay=weight_decay
        )
        # Learning rate scheduling; if cosine_warmup < 0: scheduler = None.
        self.scheduler = get_cosine_scheduler(
            self.optimizer,
            epochs=self.epochs,
            warmup_steps=cosine_warmup,
            cycles=cosine_cycles,
            min_lr=cosine_min_lr,
        )

        # Resuming from previous checkpoint (optional)
        if resume is not None:
            if not os.path.exists(resume):
                raise FileNotFoundError(resume)
            self.load_model_from_checkpoint(resume)

        # Distributed training (optional, disabled by default.)
        if distributed:
            self.net_q = DistributedDataParallel(
                module=self.net_q.to(local_rank),
                device_ids=[local_rank]
            )
        else:
            self.net_q.to(local_rank)
            
        # No DDP wrapping for key encoder, as it does not have gradients
        self.net_k.to(local_rank)
        
        # Mixed precision training (optional, enabled by default.)
        self.scaler = torch.cuda.amp.GradScaler() if mixed_precision else None

        # TensorBoard
        self.writer = SummaryWriter(ckpt_dir) if local_rank == 0 else None

        # Ready to train!
        self.prepared = True
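
# A minimal sketch of step (b) from the docstring above, assuming the checkpoint
# was loaded with `torch.load(ckpt, map_location='cpu')` and the optimizer state
# already restored via `optimizer.load_state_dict(...)`. The helper name and
# `device` argument are illustrative; the same pattern appears in the PIRL
# example further below.
import torch


def _move_optimizer_state(optimizer: "torch.optim.Optimizer", device: torch.device) -> None:
    """Move every optimizer state tensor onto `device` (e.g. the local GPU)."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)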
Beispiel #6
0
def main_worker(local_rank: int, config: object):

    torch.cuda.set_device(local_rank)
    if config.distributed:
        raise NotImplementedError

    config.batch_size = config.batch_size // config.num_gpus_per_node
    config.num_workers = max(1, config.num_workers // config.num_gpus_per_node)

    in_channels = int(config.decouple_input) + 1

    # Model
    BACKBONE_CONFIGS, Backbone = AVAILABLE_MODELS[config.backbone_type]
    Projector = PROJECTOR_TYPES[config.projector_type]
    encoder = Backbone(BACKBONE_CONFIGS[config.backbone_config],
                       in_channels=in_channels)
    head = Projector(encoder.out_channels, config.projector_size)

    # Optimization
    params = [{'params': encoder.parameters()}, {'params': head.parameters()}]
    optimizer = get_optimizer(params=params,
                              name=config.optimizer,
                              lr=config.learning_rate,
                              weight_decay=config.weight_decay,
                              momentum=config.momentum)
    scheduler = get_scheduler(optimizer=optimizer,
                              name=config.scheduler,
                              epochs=config.epochs,
                              warmup_steps=config.warmup_steps)

    # Data
    data_kwargs = {
        'transform':
        WM811KTransform(size=config.input_size, mode='test'),
        'positive_transform':
        WM811KTransform(size=config.input_size, mode=config.augmentation),
        'decouple_input':
        config.decouple_input,
    }
    train_set = torch.utils.data.ConcatDataset([
        WM811KForWaPIRL('./data/wm811k/unlabeled/train/', **data_kwargs),
        WM811KForWaPIRL('./data/wm811k/labeled/train/', **data_kwargs),
    ])
    valid_set = torch.utils.data.ConcatDataset([
        WM811KForWaPIRL('./data/wm811k/unlabeled/valid/', **data_kwargs),
        WM811KForWaPIRL('./data/wm811k/labeled/valid/', **data_kwargs),
    ])
    test_set = torch.utils.data.ConcatDataset([
        WM811KForWaPIRL('./data/wm811k/unlabeled/test/', **data_kwargs),
        WM811KForWaPIRL('./data/wm811k/labeled/test/', **data_kwargs),
    ])

    # Experiment (WaPIRL)
    experiment_kwargs = {
        'backbone':
        encoder,
        'projector':
        head,
        'memory':
        MemoryBank(size=(len(train_set), config.projector_size),
                   device=local_rank),
        'optimizer':
        optimizer,
        'scheduler':
        scheduler,
        'loss_function':
        WaPIRLLoss(temperature=config.temperature),
        'loss_weight':
        config.loss_weight,
        'num_negatives':
        config.num_negatives,
        'distributed':
        config.distributed,
        'local_rank':
        local_rank,
        'metrics': {
            'top@1': TopKAccuracy(num_classes=1 + config.num_negatives, k=1),
            'top@5': TopKAccuracy(num_classes=1 + config.num_negatives, k=5)
        },
        'checkpoint_dir':
        config.checkpoint_dir,
        'write_summary':
        config.write_summary,
    }
    experiment = WaPIRL(**experiment_kwargs)

    if local_rank == 0:
        logfile = os.path.join(config.checkpoint_dir, 'main.log')
        logger = get_logger(stream=False, logfile=logfile)
        logger.info(f"Data: {config.data}")
        logger.info(f"Augmentation: {config.augmentation}")
        logger.info(f"Observations: {len(train_set):,}")
        logger.info(
            f"Trainable parameters ({encoder.__class__.__name__}): {encoder.num_parameters:,}"
        )
        logger.info(
            f"Trainable parameters ({head.__class__.__name__}): {head.num_parameters:,}"
        )
        logger.info(
            f"Projection head: {config.projector_type} ({config.projector_size})"
        )
        logger.info(f"Checkpoint directory: {config.checkpoint_dir}")
    else:
        logger = None

    # Train (WaPIRL)
    run_kwargs = {
        'train_set': train_set,
        'valid_set': valid_set,
        'epochs': config.epochs,
        'batch_size': config.batch_size,
        'num_workers': config.num_workers,
        'logger': logger,
        'save_every': config.save_every,
    }
    experiment.run(**run_kwargs)

    if logger is not None:
        logger.handlers.clear()
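
# The `TopKAccuracy` metric configured above ('top@1', 'top@5') is project code
# whose implementation is not shown in these examples. The snippet below is a
# minimal, generic sketch of top-k accuracy over (N, C) logits with integer
# targets; the name and signature are illustrative assumptions, not the project API.
import torch


def top_k_accuracy(logits: torch.Tensor, targets: torch.Tensor, k: int = 1) -> float:
    """Fraction of rows whose true class index is among the k largest logits."""
    topk = logits.topk(k, dim=1).indices               # (N, k)
    correct = (topk == targets.unsqueeze(1)).any(dim=1)
    return correct.float().mean().item()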
    def prepare(self,
                ckpt_dir: str,
                optimizer: str,
                learning_rate: float,
                weight_decay: float,
                cosine_warmup: int = 0,
                cosine_cycles: int = 1,
                cosine_min_lr: float = 5e-3,
                epochs: int = 2000,
                batch_size: int = 256,
                num_workers: int = 4,
                key_momentum: float = 0.999,
                pseudo_momentum: float = 0.5,
                threshold: float = 0.5,
                ramp_up: int = 50,
                distributed: bool = False,
                local_rank: int = 0,
                mixed_precision: bool = True,
                resume: str = None):
        """
        Initialize settings needed for model training.
        """

        # Set attributes
        self.ckpt_dir = ckpt_dir  # pylint: disable=attribute-defined-outside-init
        self.epochs = epochs  # pylint: disable=attribute-defined-outside-init
        self.batch_size = batch_size  # pylint: disable=attribute-defined-outside-init
        self.num_workers = num_workers  # pylint: disable=attribute-defined-outside-init
        self.key_momentum = key_momentum  # pylint: disable=attribute-defined-outside-init
        self.pseudo_momentum = pseudo_momentum  # pylint: disable=attribute-defined-outside-init
        self.threshold = threshold  # pylint: disable=attribute-defined-outside-init
        self.ramp_up = ramp_up  # pylint: disable=attribute-defined-outside-init
        self.distributed = distributed  # pylint: disable=attribute-defined-outside-init
        self.local_rank = local_rank  # pylint: disable=attribute-defined-outside-init
        self.mixed_precision = mixed_precision  # pylint: disable=attribute-defined-outside-init
        self.resume = resume  # pylint: disable=attribute-defined-outside-init

        # Initialize optimizer
        self.optimizer = get_optimizer(params=self.net_q.parameters(),
                                       name=optimizer,
                                       lr=learning_rate,
                                       weight_decay=weight_decay)
        # LR scheduling (if cosine_warmup < 0: scheduler = None)
        self.scheduler = get_cosine_scheduler(
            self.optimizer,
            epochs=epochs,
            warmup_steps=cosine_warmup,
            cycles=cosine_cycles,
            min_lr=cosine_min_lr,
        )

        # Resume from previous checkpoint (if 'resume' is not None)
        if resume is not None:
            if not os.path.exists(resume):
                raise FileNotFoundError(resume)
            self.load_model_from_checkpoint(resume)

        # Distributed training
        if distributed:
            self.net_q = DistributedDataParallel(
                module=self.net_q.to(local_rank), device_ids=[local_rank])
        else:
            self.net_q.to(local_rank)

        # NO DDP wrapping for {pseudo, key} encoders; no gradients
        self.net_ps.to(local_rank)
        self.net_k.to(local_rank)

        # Mixed precision training
        self.scaler = torch.cuda.amp.GradScaler() if mixed_precision else None

        # TensorBoard
        self.writer = SummaryWriter(ckpt_dir) if local_rank == 0 else None

        self.prepared = True
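
# `get_cosine_scheduler` is project code whose internals are not shown here.
# The snippet below is a minimal, hypothetical sketch of cosine decay with linear
# warmup built on torch.optim.lr_scheduler.LambdaLR; the `cycles` and `min_lr`
# options used above are omitted for brevity.
import math

import torch


def cosine_with_warmup(optimizer: "torch.optim.Optimizer", epochs: int, warmup_steps: int):
    """Linear warmup for `warmup_steps` epochs, then cosine decay towards zero."""
    def lr_lambda(epoch: int) -> float:
        if epoch < warmup_steps:
            return float(epoch + 1) / float(max(1, warmup_steps))
        progress = (epoch - warmup_steps) / float(max(1, epochs - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)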
Beispiel #8
0
def main(args):
    """Main function."""

    # 1. Configurations
    torch.backends.cudnn.benchmark = True
    BACKBONE_CONFIGS, Config, Backbone = AVAILABLE_MODELS[args.backbone_type]
    Projector = PROJECTOR_TYPES[args.projector_type]

    config = Config(args)
    config.save()

    logfile = os.path.join(config.checkpoint_dir, 'main.log')
    logger = get_logger(stream=False, logfile=logfile)

    # 2. Data
    if config.data == 'wm811k':
        data_transforms = {
            'transform':
            get_transform(data=config.data,
                          size=config.input_size,
                          mode='test'),
            'positive_transform':
            get_transform(
                data=config.data,
                size=config.input_size,
                mode=config.augmentation,
            ),
        }
        train_set = torch.utils.data.ConcatDataset([
            WM811KForPIRL('./data/wm811k/unlabeled/train/', **data_transforms),
            WM811KForPIRL('./data/wm811k/labeled/train/', **data_transforms),
        ])
        valid_set = torch.utils.data.ConcatDataset([
            WM811KForPIRL('./data/wm811k/unlabeled/valid/', **data_transforms),
            WM811KForPIRL('./data/wm811k/labeled/valid/', **data_transforms),
        ])
        test_set = torch.utils.data.ConcatDataset([
            WM811KForPIRL('./data/wm811k/unlabeled/test/', **data_transforms),
            WM811KForPIRL('./data/wm811k/labeled/test/', **data_transforms),
        ])
    else:
        raise ValueError(
            f"PIRL only supports 'wm811k' data. Received '{config.data}'.")

    # 3. Model
    backbone = Backbone(BACKBONE_CONFIGS[config.backbone_config],
                        in_channels=IN_CHANNELS[config.data])
    projector = Projector(backbone.out_channels, config.projector_size)

    # 4. Optimization
    params = [{
        'params': backbone.parameters()
    }, {
        'params': projector.parameters()
    }]
    optimizer = get_optimizer(params=params,
                              name=config.optimizer,
                              lr=config.learning_rate,
                              weight_decay=config.weight_decay,
                              momentum=config.momentum)
    scheduler = get_scheduler(optimizer=optimizer,
                              name=config.scheduler,
                              epochs=config.epochs,
                              milestone=config.milestone,
                              warmup_steps=config.warmup_steps)

    # 5. Experiment (PIRL)
    experiment_kwargs = {
        'backbone':
        backbone,
        'projector':
        projector,
        'memory':
        MemoryBank(size=(len(train_set), config.projector_size),
                   device=config.device),
        'optimizer':
        optimizer,
        'scheduler':
        scheduler,
        'loss_function':
        PIRLLoss(temperature=config.temperature),
        'loss_weight':
        config.loss_weight,
        'num_negatives':
        config.num_negatives,
        'metrics': {
            'top@1': TopKAccuracy(num_classes=1 + config.num_negatives, k=1),
            'top@5': TopKAccuracy(num_classes=1 + config.num_negatives, k=5)
        },
        'checkpoint_dir':
        config.checkpoint_dir,
        'write_summary':
        config.write_summary,
    }
    experiment = PIRL(**experiment_kwargs)

    # 6. Run (train, evaluate, and test model)
    run_kwargs = {
        'train_set': train_set,
        'valid_set': valid_set,
        'test_set': test_set,
        'epochs': config.epochs,
        'batch_size': config.batch_size,
        'num_workers': config.num_workers,
        'device': config.device,
        'logger': logger,
        'save_every': config.save_every,
    }

    logger.info(f"Data: {config.data}")
    logger.info(f"Augmentation: {config.augmentation}")
    logger.info(
        f"Train : Valid : Test = {len(train_set):,} : {len(valid_set):,} : {len(test_set):,}"
    )
    logger.info(
        f"Trainable parameters ({backbone.__class__.__name__}): {backbone.num_parameters:,}"
    )
    logger.info(
        f"Trainable parameters ({projector.__class__.__name__}): {projector.num_parameters:,}"
    )
    logger.info(f"Projector type: {config.projector_type}")
    logger.info(f"Projector dimension: {config.projector_size}")
    logger.info(f"Saving model checkpoints to: {experiment.checkpoint_dir}")
    logger.info(
        f"Epochs: {run_kwargs['epochs']}, Batch size: {run_kwargs['batch_size']}"
    )
    logger.info(
        f"Workers: {run_kwargs['num_workers']}, Device: {run_kwargs['device']}"
    )

    steps_per_epoch = len(train_set) // config.batch_size + 1
    logger.info(f"Training steps per epoch: {steps_per_epoch:,}")
    logger.info(
        f"Total number of training iterations: {steps_per_epoch * config.epochs:,}"
    )

    if config.resume_from_checkpoint is not None:
        logger.info(
            f"Resuming from a checkpoint: {config.resume_from_checkpoint}")
        model_ckpt = os.path.join(config.resume_from_checkpoint,
                                  'best_model.pt')
        memory_ckpt = os.path.join(config.resume_from_checkpoint,
                                   'best_memory.pt')
        experiment.load_model_from_checkpoint(
            model_ckpt)  # load model & optimizer
        experiment.memory.load(memory_ckpt)  # load memory bank

        # Assign optimizer variables to appropriate device
        for state in experiment.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(config.device)

    experiment.run(**run_kwargs)
    logger.handlers.clear()
Beispiel #9
0
def main_worker(local_rank: int, config: object):

    torch.cuda.set_device(local_rank)
    if config.distributed:
        raise NotImplementedError

    if local_rank == 0:
        logfile = os.path.join(config.checkpoint_dir, 'main.log')
        logger = get_logger(stream=False, logfile=logfile)
    else:
        logger = None

    in_channels = int(config.decouple_input) + 1
    num_classes = 9

    # 2. Dataset
    train_transform = WM811KTransform(size=config.input_size,
                                      mode=config.augmentation)
    test_transform = WM811KTransform(size=config.input_size, mode='test')
    train_set = WM811K('./data/wm811k/labeled/train/',
                       transform=train_transform,
                       proportion=config.label_proportion,
                       decouple_input=config.decouple_input)
    valid_set = WM811K('./data/wm811k/labeled/valid/',
                       transform=test_transform,
                       decouple_input=config.decouple_input)
    test_set = WM811K('./data/wm811k/labeled/test/',
                      transform=test_transform,
                      decouple_input=config.decouple_input)

    # 3. Model
    BACKBONE_CONFIGS, Backbone = AVAILABLE_MODELS[config.backbone_type]
    backbone = Backbone(BACKBONE_CONFIGS[config.backbone_config],
                        in_channels=in_channels)
    classifier = LinearClassifier(in_channels=backbone.out_channels,
                                  num_classes=num_classes)

    # 3-1. Load pre-trained weights (if provided)
    if config.pretrained_model_file is not None:
        try:
            backbone.load_weights_from_checkpoint(
                path=config.pretrained_model_file, key='backbone')
        except KeyError:
            backbone.load_weights_from_checkpoint(
                path=config.pretrained_model_file, key='encoder')
        finally:
            if logger is not None:
                logger.info(
                    f"Loaded pre-trained model from: {config.pretrained_model_file}"
                )
    else:
        if logger is not None:
            logger.info("No pre-trained model provided.")

    # 3-2. Finetune or freeze weights of backbone
    if config.freeze:
        backbone.freeze_weights()
        if logger is not None:
            logger.info("Freezing backbone weights.")

    # 4. Optimization
    params = [{
        'params': backbone.parameters()
    }, {
        'params': classifier.parameters()
    }]
    optimizer = get_optimizer(params=params,
                              name=config.optimizer,
                              lr=config.learning_rate,
                              weight_decay=config.weight_decay,
                              momentum=config.momentum)
    scheduler = get_scheduler(optimizer=optimizer,
                              name=config.scheduler,
                              epochs=config.epochs,
                              warmup_steps=config.warmup_steps)

    # 5. Experiment (classification)
    experiment_kwargs = {
        'backbone':
        backbone,
        'classifier':
        classifier,
        'optimizer':
        optimizer,
        'scheduler':
        scheduler,
        'loss_function':
        LabelSmoothingLoss(num_classes, smoothing=config.label_smoothing),
        'distributed':
        config.distributed,
        'local_rank':
        local_rank,
        'checkpoint_dir':
        config.checkpoint_dir,
        'write_summary':
        config.write_summary,
        'metrics': {
            'accuracy': MultiAccuracy(num_classes=num_classes),
            'f1': MultiF1Score(num_classes=num_classes, average='macro'),
        },
    }
    experiment = Classification(**experiment_kwargs)

    # 6. Run (classification)
    run_kwargs = {
        'train_set': train_set,
        'valid_set': valid_set,
        'test_set': test_set,
        'epochs': config.epochs,
        'batch_size': config.batch_size,
        'num_workers': config.num_workers,
        'logger': logger,
    }
    experiment.run(**run_kwargs)
    if logger is not None:
        logger.handlers.clear()
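
# `LabelSmoothingLoss` is project code not shown in these examples. The sketch
# below is one common, generic formulation of label smoothing (an illustrative
# assumption, not the project's exact implementation): put (1 - smoothing) on the
# true class, spread `smoothing` uniformly over the remaining classes, and take
# the cross entropy between the model's log-probabilities and that target.
import torch
import torch.nn.functional as F


def label_smoothing_loss(logits: torch.Tensor, targets: torch.Tensor,
                         num_classes: int, smoothing: float = 0.1) -> torch.Tensor:
    log_probs = F.log_softmax(logits, dim=1)                              # (N, C)
    with torch.no_grad():
        true_dist = torch.full_like(log_probs, smoothing / (num_classes - 1))
        true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)
    return torch.mean(torch.sum(-true_dist * log_probs, dim=1))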
Beispiel #10
0
def test():

    ##########################################################
    # # Realsense test
    # pipeline = rs.pipeline()
    # config = rs.config()
    # config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
    # config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
    # pipeline.start(config)

    # frames = pipeline.wait_for_frames()
    # color_frame = frames.get_color_frame()
    # img = np.asanyarray(color_frame.get_data())
    # img = cv2.resize(img, dsize=(256, 256), interpolation=cv2.INTER_CUBIC)
    # cv2.namedWindow('RealSense', cv2.WINDOW_AUTOSIZE)
    # cv2.imshow('RealSense', img)
    # cv2.waitKey(1)
    ##########################################################

    device = args.device = torch.device("cuda:0" if args.cuda else "cpu")

    # Setup Logging
    log_dir = "{}/models/{}/".format(args.dump_location, args.exp_name)
    dump_dir = "{}/dump/{}/".format(args.dump_location, args.exp_name)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists("{}/images/".format(dump_dir)):
        os.makedirs("{}/images/".format(dump_dir))

    logging.basicConfig(filename=log_dir + 'train.log', level=logging.INFO)
    print("Dumping at {}".format(log_dir))
    logging.info(args)

    # Logging and loss variables
    num_scenes = args.num_processes
    num_episodes = int(args.num_episodes)
    device = args.device = torch.device("cuda:0" if args.cuda else "cpu")
    policy_loss = 0

    best_cost = 100000
    costs = deque(maxlen=1000)
    exp_costs = deque(maxlen=1000)
    pose_costs = deque(maxlen=1000)

    g_masks = torch.ones(num_scenes).float().to(device)
    l_masks = torch.zeros(num_scenes).float().to(device)

    best_local_loss = np.inf
    best_g_reward = -np.inf

    if args.eval:
        traj_lengths = args.max_episode_length // args.num_local_steps
        explored_area_log = np.zeros((num_scenes, num_episodes, traj_lengths))
        explored_ratio_log = np.zeros((num_scenes, num_episodes, traj_lengths))

    g_episode_rewards = deque(maxlen=1000)

    l_action_losses = deque(maxlen=1000)

    g_value_losses = deque(maxlen=1000)
    g_action_losses = deque(maxlen=1000)
    g_dist_entropies = deque(maxlen=1000)

    per_step_g_rewards = deque(maxlen=1000)

    g_process_rewards = np.zeros((num_scenes))

    # Starting environments
    torch.set_num_threads(1)
    envs = make_vec_envs(args)
    obs, infos = envs.reset()

    # Initialize map variables
    ### Full map consists of 4 channels containing the following:
    ### 1. Obstacle Map
    ### 2. Explored Area
    ### 3. Current Agent Location
    ### 4. Past Agent Locations

    torch.set_grad_enabled(False)

    # Calculating full and local map sizes
    map_size = args.map_size_cm // args.map_resolution
    full_w, full_h = map_size, map_size

    local_w, local_h = int(full_w / args.global_downscaling), \
                       int(full_h / args.global_downscaling)

    # Initializing full and local map
    full_map = torch.zeros(num_scenes, 4, full_w, full_h).float().to(device)
    local_map = torch.zeros(num_scenes, 4, local_w, local_h).float().to(device)

    # Initial full and local pose
    full_pose = torch.zeros(num_scenes, 3).float().to(device)
    local_pose = torch.zeros(num_scenes, 3).float().to(device)

    # Origin of local map
    origins = np.zeros((num_scenes, 3))

    # Local Map Boundaries
    lmb = np.zeros((num_scenes, 4)).astype(int)

    ### Planner pose inputs have 7 dimensions:
    ### 1-3 store the global agent location
    ### 4-7 store local map boundaries
    planner_pose_inputs = np.zeros((num_scenes, 7))

    # Initialize full_map and full_pose
    def init_map_and_pose():
        full_map.fill_(0.)
        full_pose.fill_(0.)
        full_pose[:, :2] = args.map_size_cm / 100.0 / 2.0

        locs = full_pose.cpu().numpy()
        planner_pose_inputs[:, :3] = locs
        for e in range(num_scenes):
            r, c = locs[e, 1], locs[e, 0]
            loc_r, loc_c = [
                int(r * 100.0 / args.map_resolution),
                int(c * 100.0 / args.map_resolution)
            ]

            full_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.0

            lmb[e] = get_local_map_boundaries(
                (loc_r, loc_c), (local_w, local_h), (full_w, full_h))

            planner_pose_inputs[e, 3:] = lmb[e]
            origins[e] = [
                lmb[e][2] * args.map_resolution / 100.0,
                lmb[e][0] * args.map_resolution / 100.0, 0.
            ]

        for e in range(num_scenes):
            local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1],
                                    lmb[e, 2]:lmb[e, 3]]
            local_pose[e] = full_pose[e] - \
                            torch.from_numpy(origins[e]).to(device).float()

    init_map_and_pose()
    # Global policy observation space
    g_observation_space = gym.spaces.Box(0,
                                         1, (8, local_w, local_h),
                                         dtype='uint8')

    # Global policy action space
    g_action_space = gym.spaces.Box(low=0.0,
                                    high=1.0,
                                    shape=(2, ),
                                    dtype=np.float32)

    # Local policy observation space
    l_observation_space = gym.spaces.Box(
        0, 255, (3, args.frame_width, args.frame_width), dtype='uint8')

    # Local and Global policy recurrent layer sizes
    l_hidden_size = args.local_hidden_size
    g_hidden_size = args.global_hidden_size

    # slam
    nslam_module = Neural_SLAM_Module(args).to(device)
    slam_optimizer = get_optimizer(nslam_module.parameters(),
                                   args.slam_optimizer)

    # Global policy
    # obs_space.shape = [8, 500, 500]
    # act_space = Box, shape (2,)
    # g_hidden_size = 256
    g_policy = RL_Policy(g_observation_space.shape,
                         g_action_space,
                         base_kwargs={
                             'recurrent': args.use_recurrent_global,
                             'hidden_size': g_hidden_size,
                             'downscaling': args.global_downscaling
                         }).to(device)
    g_agent = algo.PPO(g_policy,
                       args.clip_param,
                       args.ppo_epoch,
                       args.num_mini_batch,
                       args.value_loss_coef,
                       args.entropy_coef,
                       lr=args.global_lr,
                       eps=args.eps,
                       max_grad_norm=args.max_grad_norm)

    # Local policy
    l_policy = Local_IL_Policy(
        l_observation_space.shape,
        envs.action_space.n,
        recurrent=args.use_recurrent_local,
        hidden_size=l_hidden_size,
        deterministic=args.use_deterministic_local).to(device)
    local_optimizer = get_optimizer(l_policy.parameters(),
                                    args.local_optimizer)

    # Storage
    g_rollouts = GlobalRolloutStorage(args.num_global_steps, num_scenes,
                                      g_observation_space.shape,
                                      g_action_space, g_policy.rec_state_size,
                                      1).to(device)

    slam_memory = FIFOMemory(args.slam_memory_size)

    # Loading model
    if args.load_slam != "0":
        print("Loading slam {}".format(args.load_slam))
        state_dict = torch.load(args.load_slam,
                                map_location=lambda storage, loc: storage)
        nslam_module.load_state_dict(state_dict)

    if not args.train_slam:
        nslam_module.eval()

    if args.load_global != "0":
        print("Loading global {}".format(args.load_global))
        state_dict = torch.load(args.load_global,
                                map_location=lambda storage, loc: storage)
        g_policy.load_state_dict(state_dict)

    if not args.train_global:
        g_policy.eval()

    if args.load_local != "0":
        print("Loading local {}".format(args.load_local))
        state_dict = torch.load(args.load_local,
                                map_location=lambda storage, loc: storage)
        l_policy.load_state_dict(state_dict)

    if not args.train_local:
        l_policy.eval()

    # /////////////////////////////////////////////////////////////// TESTING
    from matplotlib import image
    if args.testing:
        test_images = {}
        for i in range(5):
            for j in range(12):
                img_pth = 'imgs/robots_rs/test_{}_{}.jpg'.format(i + 1, j)
                img = image.imread(img_pth)
                test_images[(i + 1, j)] = np.array(img)

        poses_array = []
        for i in range(8):
            poses_array.append(np.array([[0.3, 0.0, 0.0], [0.3, 0.0, 0.0]]))
        for i in range(4):
            poses_array.append(
                np.array([[0.0, 0.0, -0.24587], [0.0, 0.0, -0.27587]]))

        # index from 1 to 5
        test_1_idx = 3
        test_2_idx = 1
        # image1_1 = image.imread('imgs/robots_rs/img_128_6.jpg')
        # image1_2 = image.imread('imgs/robots_rs/img_128_7.jpg')
        # image1_3 = image.imread('imgs/robots_rs/img_128_8.jpg')
        # image2_1 = image.imread('imgs/robots_rs/img_128_30.jpg')
        # image2_2 = image.imread('imgs/robots_rs/img_128_31.jpg')
        # image2_3 = image.imread('imgs/robots_rs/img_128_32.jpg')
        # # image_data = np.asarray(image)
        # # plt.imshow(image)
        # # plt.show()
        # image_data_1_1 = np.array(image1_1)
        # image_data_1_2 = np.array(image1_2)
        # image_data_1_3 = np.array(image1_3)
        # image_data_2_1 = np.array(image2_1)
        # image_data_2_2 = np.array(image2_2)
        # image_data_2_3 = np.array(image2_3)
        # image_data_1_all = np.array([image_data_1_1, image_data_2_1])
        # image_data_2_all = np.array([image_data_1_2, image_data_2_2])
        # image_data_3_all = np.array([image_data_1_3, image_data_2_3])
        image_data_all = np.array(
            [test_images[(test_1_idx, 0)], test_images[(test_2_idx, 0)]])
        obs = torch.from_numpy(image_data_all).float().to(device)
        obs = obs.permute((0, 3, 1, 2)).contiguous()

        # print(f"New obs: {obs}")
        print(f"New obs size: {obs.size()}")
    # /////////////////////////////////////////////////////////////// TESTING

    # Predict map from frame 1:
    poses = torch.from_numpy(
        np.asarray([
            infos[env_idx]['sensor_pose'] for env_idx in range(num_scenes)
        ])).float().to(device)

    _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
        nslam_module(obs, obs, poses, local_map[:, 0, :, :],
                     local_map[:, 1, :, :], local_pose)

    # print(f"\n\n local_map shape: {local_map.shape}")
    # print(f"\n obs shape: {obs.shape}")
    # print(f"\n poses shape: {poses.shape}")

    # Compute Global policy input
    locs = local_pose.cpu().numpy()

    global_input = torch.zeros(num_scenes, 8, local_w, local_h)
    global_orientation = torch.zeros(num_scenes, 1).long()

    for e in range(num_scenes):
        r, c = locs[e, 1], locs[e, 0]
        loc_r, loc_c = [
            int(r * 100.0 / args.map_resolution),
            int(c * 100.0 / args.map_resolution)
        ]

        local_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.
        global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)

    global_input[:, 0:4, :, :] = local_map.detach()
    global_input[:, 4:, :, :] = nn.MaxPool2d(args.global_downscaling)(full_map)
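    # Global policy input (8 channels): 0-3 are the local map (obstacles, explored
    # area, current and past agent locations); 4-7 are the full map max-pooled
    # down to the local map resolution.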

    g_rollouts.obs[0].copy_(global_input)
    g_rollouts.extras[0].copy_(global_orientation)

    # Run Global Policy (global_goals = Long-Term Goal)
    g_value, g_action, g_action_log_prob, g_rec_states = \
        g_policy.act(
            g_rollouts.obs[0],
            g_rollouts.rec_states[0],
            g_rollouts.masks[0],
            extras=g_rollouts.extras[0],
            deterministic=False
        )

    cpu_actions = nn.Sigmoid()(g_action).cpu().numpy()
    global_goals = [[int(action[0] * local_w),
                     int(action[1] * local_h)] for action in cpu_actions]

    # Compute planner inputs
    planner_inputs = [{} for e in range(num_scenes)]
    for e, p_input in enumerate(planner_inputs):
        p_input['goal'] = global_goals[e]
        p_input['map_pred'] = global_input[e, 0, :, :].detach().cpu().numpy()
        p_input['exp_pred'] = global_input[e, 1, :, :].detach().cpu().numpy()
        p_input['pose_pred'] = planner_pose_inputs[e]

    # Output stores local goals as well as the ground-truth action
    output = envs.get_short_term_goal(planner_inputs)

    last_obs = obs.detach()
    local_rec_states = torch.zeros(num_scenes, l_hidden_size).to(device)
    start = time.time()

    total_num_steps = -1
    g_reward = 0

    torch.set_grad_enabled(False)

    # fig, axis = plt.subplots(1,3)
    fig, axis = plt.subplots(2, 3)
    # a = [[1, 0, 1], [1, 0, 1], [1, 0, 1]]
    # plt.imshow(a)

    for ep_num in range(num_episodes):
        for step in range(args.max_episode_length):

            total_num_steps += 1

            g_step = (step // args.num_local_steps) % args.num_global_steps
            eval_g_step = step // args.num_local_steps + 1
            l_step = step % args.num_local_steps
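            # g_step indexes the slot in the global rollout storage, eval_g_step is a
            # 1-based count of global steps used for evaluation logging, and l_step
            # counts environment steps within the current num_local_steps window.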

            # ------------------------------------------------------------------
            # Local Policy
            del last_obs
            last_obs = obs.detach()
            local_masks = l_masks
            local_goals = output[:, :-1].to(device).long()

            if args.train_local:
                torch.set_grad_enabled(True)

            action, action_prob, local_rec_states = l_policy(
                obs,
                local_rec_states,
                local_masks,
                extras=local_goals,
            )

            if args.train_local:
                action_target = output[:, -1].long().to(device)
                policy_loss += nn.CrossEntropyLoss()(action_prob,
                                                     action_target)
                torch.set_grad_enabled(False)
            l_action = action.cpu()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            print(f"l_action: {l_action}")
            print(f"l_action size: {l_action.size()}")
            # Env step
            obs, rew, done, infos = envs.step(l_action)

            # ////////////////////////////////////////////////////////////////// TESTING
            # obs_all = _process_obs_for_display(obs)
            # _ims = [transform_rgb_bgr(obs_all[0]), transform_rgb_bgr(obs_all[1])]

            # ax1.imshow(_ims[0])
            # ax2.imshow(_ims[1])
            # plt.savefig(f"imgs/img_0_{step}.png")
            # # plt.clf()

            # ////////////////////////////////////////////////////////////////// TESTING

            l_masks = torch.FloatTensor([0 if x else 1
                                         for x in done]).to(device)
            g_masks *= l_masks
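            # l_masks is 0 for environments whose episode just ended and 1 otherwise;
            # folding it into g_masks keeps the global mask at 0 until the next
            # long-term goal is sampled.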
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Reinitialize variables when episode ends
            if step == args.max_episode_length - 1:  # Last episode step
                print("Final step")
                init_map_and_pose()
                del last_obs
                last_obs = obs.detach()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Neural SLAM Module
            if args.train_slam:
                # Add frames to memory
                for env_idx in range(num_scenes):
                    env_obs = obs[env_idx].to("cpu")
                    env_poses = torch.from_numpy(
                        np.asarray(
                            infos[env_idx]['sensor_pose'])).float().to("cpu")
                    env_gt_fp_projs = torch.from_numpy(
                        np.asarray(infos[env_idx]['fp_proj'])).unsqueeze(
                            0).float().to("cpu")
                    env_gt_fp_explored = torch.from_numpy(
                        np.asarray(infos[env_idx]['fp_explored'])).unsqueeze(
                            0).float().to("cpu")
                    env_gt_pose_err = torch.from_numpy(
                        np.asarray(
                            infos[env_idx]['pose_err'])).float().to("cpu")
                    slam_memory.push(
                        (last_obs[env_idx].cpu(), env_obs, env_poses),
                        (env_gt_fp_projs, env_gt_fp_explored, env_gt_pose_err))

            poses = torch.from_numpy(
                np.asarray([
                    infos[env_idx]['sensor_pose']
                    for env_idx in range(num_scenes)
                ])).float().to(device)

            # ///////////////////////////////////////////////////////////////// TESTING
            if args.testing:
                # obs = torch.from_numpy(obs_).float().to(self.device)
                # obs_cpu = obs.detach().cpu().numpy()
                # last_obs_cpu = last_obs.detach().cpu().numpy()
                # print(f"obs shape: {obs_cpu.shape}")
                # print(f"last_obs shape: {last_obs_cpu.shape}")

                original_obs = obs
                original_last_obs = last_obs
                original_poses = poses

                print(f"step: {step}")
                last_obs = torch.from_numpy(image_data_all).float().to(device)
                last_obs = last_obs.permute((0, 3, 1, 2)).contiguous()
                image_data_all = np.array([
                    test_images[(test_1_idx, step + 1)],
                    test_images[(test_2_idx, step + 1)]
                ])
                obs = torch.from_numpy(image_data_all).float().to(device)
                obs = obs.permute((0, 3, 1, 2)).contiguous()
                _poses = poses_array[step]
                poses = torch.from_numpy(_poses).float().to(device)
                # if step == 0:
                #     print(f"step: {step}")
                #     last_obs = torch.from_numpy(image_data_1_all).float().to(device)
                #     last_obs = last_obs.permute((0, 3, 1, 2)).contiguous()
                #     obs = torch.from_numpy(image_data_2_all).float().to(device)
                #     obs = obs.permute((0, 3, 1, 2)).contiguous()
                #     _poses = np.array([[0.2, 0.0, 0.0], [0.2, 0.0, 0.0]])
                #     poses = torch.from_numpy(_poses).float().to(device)
                # elif step == 1:
                #     print(f"step: {step}")
                #     last_obs = torch.from_numpy(image_data_2_all).float().to(device)
                #     last_obs = last_obs.permute((0, 3, 1, 2)).contiguous()
                #     obs = torch.from_numpy(image_data_3_all).float().to(device)
                #     obs = obs.permute((0, 3, 1, 2)).contiguous()
                #     _poses = np.array([[0.4, 0.0, 0.0], [0.2, 0.0, 0.17587]])
                #     poses = torch.from_numpy(_poses).float().to(device)
                # _poses = np.asarray([infos[env_idx]['sensor_pose'] for env_idx in range(num_scenes)])
                # print(f"New poses: {_poses}")
                # last_obs = torch.from_numpy(image_data_1_1).float().to(device)
                # obs = torch.from_numpy(image_data_1_2).float().to(device)

                # print(f"Original obs: {original_obs}")
                # print(f"Original obs shape: {original_obs.size()}")
                # print(f"Obs: {obs}")
                # print(f"Obs shape: {obs.size()}")
                # print(f"Original last_obs: {original_last_obs}")
                # print(f"Original last_obs shape: {original_last_obs.size()}")
                # print(f"last_obs: {last_obs}")
                # print(f"Last_obs shape: {last_obs.size()}")
                # print(f"Original poses: {original_poses}")
                # print(f"Original poses shape: {original_poses.size()}")
                print(f"Local poses : {local_pose}")
            # ///////////////////////////////////////////////////////////////// TESTING


            _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
                nslam_module(last_obs, obs, poses, local_map[:, 0, :, :],
                             local_map[:, 1, :, :], local_pose, build_maps=True)

            locs = local_pose.cpu().numpy()
            planner_pose_inputs[:, :3] = locs + origins
            local_map[:,
                      2, :, :].fill_(0.)  # Resetting current location channel
            for e in range(num_scenes):
                r, c = locs[e, 1], locs[e, 0]
                loc_r, loc_c = [
                    int(r * 100.0 / args.map_resolution),
                    int(c * 100.0 / args.map_resolution)
                ]

                local_map[e, 2:, loc_r - 2:loc_r + 3, loc_c - 2:loc_c + 3] = 1.

            # //////////////////////////////////////////////////////////////////
            if args.testing:
                local_map_draw = local_map

                if step % 1 == 0:
                    obs_all = _process_obs_for_display(obs)
                    _ims = [
                        transform_rgb_bgr(obs_all[0]),
                        transform_rgb_bgr(obs_all[1])
                    ]

                    imgs_1 = local_map_draw[0, :, :, :].cpu().numpy()
                    imgs_2 = local_map_draw[1, :, :, :].cpu().numpy()

                    # axis[1].imshow(imgs_1[0], cmap='gray')
                    # axis[2].imshow(imgs_1[1], cmap='gray')
                    # axis[0].imshow(_ims[0])
                    axis[0][1].imshow(imgs_1[0], cmap='gray')
                    axis[0][2].imshow(imgs_1[1], cmap='gray')
                    axis[0][0].imshow(_ims[0])
                    axis[1][1].imshow(imgs_2[0], cmap='gray')
                    axis[1][2].imshow(imgs_2[1], cmap='gray')
                    axis[1][0].imshow(_ims[1])
                    plt.savefig(f"imgs/test_{step}.png")

                obs = original_obs
                last_obs = original_last_obs
                poses = original_poses
            # //////////////////////////////////////////////////////////////////

            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Global Policy
            if l_step == args.num_local_steps - 1:
                # For every global step, update the full and local maps
                for e in range(num_scenes):
                    full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]] = \
                        local_map[e]
                    full_pose[e] = local_pose[e] + \
                                   torch.from_numpy(origins[e]).to(device).float()

                    locs = full_pose[e].cpu().numpy()
                    r, c = locs[1], locs[0]
                    loc_r, loc_c = [
                        int(r * 100.0 / args.map_resolution),
                        int(c * 100.0 / args.map_resolution)
                    ]

                    lmb[e] = get_local_map_boundaries(
                        (loc_r, loc_c), (local_w, local_h), (full_w, full_h))

                    planner_pose_inputs[e, 3:] = lmb[e]
                    origins[e] = [
                        lmb[e][2] * args.map_resolution / 100.0,
                        lmb[e][0] * args.map_resolution / 100.0, 0.
                    ]

                    local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1],
                                            lmb[e, 2]:lmb[e, 3]]
                    local_pose[e] = full_pose[e] - \
                                    torch.from_numpy(origins[e]).to(device).float()

                locs = local_pose.cpu().numpy()
                for e in range(num_scenes):
                    global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)
                global_input[:, 0:4, :, :] = local_map
                global_input[:, 4:, :, :] = \
                    nn.MaxPool2d(args.global_downscaling)(full_map)

                if False:
                    for i in range(4):
                        ax[i].clear()
                        ax[i].set_yticks([])
                        ax[i].set_xticks([])
                        ax[i].set_yticklabels([])
                        ax[i].set_xticklabels([])
                        ax[i].imshow(global_input.cpu().numpy()[0, 4 + i])
                    plt.gcf().canvas.flush_events()
                    # plt.pause(0.1)
                    fig.canvas.start_event_loop(0.001)
                    plt.gcf().canvas.flush_events()

                # Get exploration reward and metrics
                g_reward = torch.from_numpy(
                    np.asarray([
                        infos[env_idx]['exp_reward']
                        for env_idx in range(num_scenes)
                    ])).float().to(device)

                if args.eval:
                    g_reward = g_reward * 50.0  # Convert reward to area in m2

                g_process_rewards += g_reward.cpu().numpy()
                g_total_rewards = g_process_rewards * \
                                  (1 - g_masks.cpu().numpy())
                g_process_rewards *= g_masks.cpu().numpy()
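                # g_process_rewards accumulates reward within each ongoing episode;
                # where g_masks is 0 (episode finished), the running total is copied
                # into g_total_rewards and then reset, so only per-episode totals are
                # pushed to g_episode_rewards below.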
                per_step_g_rewards.append(np.mean(g_reward.cpu().numpy()))

                if np.sum(g_total_rewards) != 0:
                    for tr in g_total_rewards:
                        if tr != 0:
                            g_episode_rewards.append(tr)

                if args.eval:
                    exp_ratio = torch.from_numpy(
                        np.asarray([
                            infos[env_idx]['exp_ratio']
                            for env_idx in range(num_scenes)
                        ])).float()

                    for e in range(num_scenes):
                        explored_area_log[e, ep_num, eval_g_step - 1] = \
                            explored_area_log[e, ep_num, eval_g_step - 2] + \
                            g_reward[e].cpu().numpy()
                        explored_ratio_log[e, ep_num, eval_g_step - 1] = \
                            explored_ratio_log[e, ep_num, eval_g_step - 2] + \
                            exp_ratio[e].cpu().numpy()

                # Add samples to global policy storage
                g_rollouts.insert(global_input, g_rec_states, g_action,
                                  g_action_log_prob, g_value, g_reward,
                                  g_masks, global_orientation)

                # Sample long-term goal from global policy
                g_value, g_action, g_action_log_prob, g_rec_states = \
                    g_policy.act(
                        g_rollouts.obs[g_step + 1],
                        g_rollouts.rec_states[g_step + 1],
                        g_rollouts.masks[g_step + 1],
                        extras=g_rollouts.extras[g_step + 1],
                        deterministic=False
                    )
                cpu_actions = nn.Sigmoid()(g_action).cpu().numpy()
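                # Scale the sigmoid-normalized (0, 1) actions to pixel coordinates on the local map.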
                global_goals = [[
                    int(action[0] * local_w),
                    int(action[1] * local_h)
                ] for action in cpu_actions]

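                # Reset the global reward accumulator and masks until the next global step.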
                g_reward = 0
                g_masks = torch.ones(num_scenes).float().to(device)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Get short term goal
            planner_inputs = [{} for e in range(num_scenes)]
            for e, p_input in enumerate(planner_inputs):
                p_input['map_pred'] = local_map[e, 0, :, :].cpu().numpy()
                p_input['exp_pred'] = local_map[e, 1, :, :].cpu().numpy()
                p_input['pose_pred'] = planner_pose_inputs[e]
                p_input['goal'] = global_goals[e]

            output = envs.get_short_term_goal(planner_inputs)

            # print(f"\n output (short term goal): {output}\n")

            # ------------------------------------------------------------------

            ### TRAINING
            torch.set_grad_enabled(True)
            # ------------------------------------------------------------------
            # Train Neural SLAM Module
            if args.train_slam and len(slam_memory) > args.slam_batch_size:
                for _ in range(args.slam_iterations):
                    inputs, outputs = slam_memory.sample(args.slam_batch_size)
                    b_obs_last, b_obs, b_poses = inputs
                    gt_fp_projs, gt_fp_explored, gt_pose_err = outputs

                    b_obs = b_obs.to(device)
                    b_obs_last = b_obs_last.to(device)
                    b_poses = b_poses.to(device)

                    gt_fp_projs = gt_fp_projs.to(device)
                    gt_fp_explored = gt_fp_explored.to(device)
                    gt_pose_err = gt_pose_err.to(device)

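                    # Supervised forward pass: predict the egocentric projection, explored
                    # area, and pose error without registering maps (build_maps=False).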
                    b_proj_pred, b_fp_exp_pred, _, _, b_pose_err_pred, _ = \
                        nslam_module(b_obs_last, b_obs, b_poses,
                                     None, None, None,
                                     build_maps=False)
                    loss = 0
                    if args.proj_loss_coeff > 0:
                        proj_loss = F.binary_cross_entropy(
                            b_proj_pred, gt_fp_projs)
                        costs.append(proj_loss.item())
                        loss += args.proj_loss_coeff * proj_loss

                    if args.exp_loss_coeff > 0:
                        exp_loss = F.binary_cross_entropy(
                            b_fp_exp_pred, gt_fp_explored)
                        exp_costs.append(exp_loss.item())
                        loss += args.exp_loss_coeff * exp_loss

                    if args.pose_loss_coeff > 0:
                        pose_loss = torch.nn.MSELoss()(b_pose_err_pred,
                                                       gt_pose_err)
                        pose_costs.append(args.pose_loss_coeff *
                                          pose_loss.item())
                        loss += args.pose_loss_coeff * pose_loss

                    # args.train_slam is already guaranteed by the enclosing condition.
                    slam_optimizer.zero_grad()
                    loss.backward()
                    slam_optimizer.step()

                    del b_obs_last, b_obs, b_poses
                    del gt_fp_projs, gt_fp_explored, gt_pose_err
                    del b_proj_pred, b_fp_exp_pred, b_pose_err_pred

            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Train Local Policy
            if (l_step + 1) % args.local_policy_update_freq == 0 \
                    and args.train_local:
                local_optimizer.zero_grad()
                policy_loss.backward()
                local_optimizer.step()
                l_action_losses.append(policy_loss.item())
                policy_loss = 0
                local_rec_states = local_rec_states.detach_()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Train Global Policy
            if g_step % args.num_global_steps == args.num_global_steps - 1 \
                    and l_step == args.num_local_steps - 1:
                if args.train_global:
                    g_next_value = g_policy.get_value(
                        g_rollouts.obs[-1],
                        g_rollouts.rec_states[-1],
                        g_rollouts.masks[-1],
                        extras=g_rollouts.extras[-1]).detach()

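                    # Compute discounted returns (optionally with GAE) from the bootstrapped
                    # value, then update the global policy.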
                    g_rollouts.compute_returns(g_next_value, args.use_gae,
                                               args.gamma, args.tau)
                    g_value_loss, g_action_loss, g_dist_entropy = \
                        g_agent.update(g_rollouts)
                    g_value_losses.append(g_value_loss)
                    g_action_losses.append(g_action_loss)
                    g_dist_entropies.append(g_dist_entropy)
                g_rollouts.after_update()
            # ------------------------------------------------------------------

            # Finish Training
            torch.set_grad_enabled(False)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Logging
            if total_num_steps % args.log_interval == 0:
                end = time.time()
                time_elapsed = time.gmtime(end - start)
                log = " ".join([
                    "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                    "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                    "num timesteps {},".format(total_num_steps *
                                               num_scenes),
                    "FPS {},".format(int(total_num_steps * num_scenes \
                                         / (end - start)))
                ])

                log += "\n\tRewards:"

                if len(g_episode_rewards) > 0:
                    log += " ".join([
                        " Global step mean/med rew:",
                        "{:.4f}/{:.4f},".format(np.mean(per_step_g_rewards),
                                                np.median(per_step_g_rewards)),
                        " Global eps mean/med/min/max eps rew:",
                        "{:.3f}/{:.3f}/{:.3f}/{:.3f},".format(
                            np.mean(g_episode_rewards),
                            np.median(g_episode_rewards),
                            np.min(g_episode_rewards),
                            np.max(g_episode_rewards))
                    ])

                log += "\n\tLosses:"

                if args.train_local and len(l_action_losses) > 0:
                    log += " ".join([
                        " Local Loss:",
                        "{:.3f},".format(np.mean(l_action_losses))
                    ])

                if args.train_global and len(g_value_losses) > 0:
                    log += " ".join([
                        " Global Loss value/action/dist:",
                        "{:.3f}/{:.3f}/{:.3f},".format(
                            np.mean(g_value_losses), np.mean(g_action_losses),
                            np.mean(g_dist_entropies))
                    ])

                if args.train_slam and len(costs) > 0:
                    log += " ".join([
                        " SLAM Loss proj/exp/pose:",
                        "{:.4f}/{:.4f}/{:.4f}".format(np.mean(costs),
                                                      np.mean(exp_costs),
                                                      np.mean(pose_costs))
                    ])

                print(log)
                logging.info(log)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Save best models
            if (total_num_steps * num_scenes) % args.save_interval < \
                    num_scenes:

                # Save Neural SLAM Model
                if len(costs) >= 1000 and np.mean(costs) < best_cost \
                        and not args.eval:
                    best_cost = np.mean(costs)
                    torch.save(nslam_module.state_dict(),
                               os.path.join(log_dir, "model_best.slam"))

                # Save Local Policy Model
                if len(l_action_losses) >= 100 and \
                        (np.mean(l_action_losses) <= best_local_loss) \
                        and not args.eval:
                    torch.save(l_policy.state_dict(),
                               os.path.join(log_dir, "model_best.local"))

                    best_local_loss = np.mean(l_action_losses)

                # Save Global Policy Model
                if len(g_episode_rewards) >= 100 and \
                        (np.mean(g_episode_rewards) >= best_g_reward) \
                        and not args.eval:
                    torch.save(g_policy.state_dict(),
                               os.path.join(log_dir, "model_best.global"))
                    best_g_reward = np.mean(g_episode_rewards)

            # Save periodic models
            if (total_num_steps * num_scenes) % args.save_periodic < \
                    num_scenes:
                step = total_num_steps * num_scenes
                if args.train_slam:
                    torch.save(
                        nslam_module.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.slam".format(step)))
                if args.train_local:
                    torch.save(
                        l_policy.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.local".format(step)))
                if args.train_global:
                    torch.save(
                        g_policy.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.global".format(step)))
            # ------------------------------------------------------------------
    print("Finishing Episodes")

    # Print and save model performance numbers during evaluation
    if args.eval:
        logfile = open("{}/explored_area.txt".format(dump_dir), "w+")
        for e in range(num_scenes):
            for i in range(explored_area_log[e].shape[0]):
                logfile.write(str(explored_area_log[e, i]) + "\n")
                logfile.flush()

        logfile.close()

        logfile = open("{}/explored_ratio.txt".format(dump_dir), "w+")
        for e in range(num_scenes):
            for i in range(explored_ratio_log[e].shape[0]):
                logfile.write(str(explored_ratio_log[e, i]) + "\n")
                logfile.flush()

        logfile.close()

        log = "Final Exp Area: \n"
        for i in range(explored_area_log.shape[2]):
            log += "{:.5f}, ".format(np.mean(explored_area_log[:, :, i]))

        log += "\nFinal Exp Ratio: \n"
        for i in range(explored_ratio_log.shape[2]):
            log += "{:.5f}, ".format(np.mean(explored_ratio_log[:, :, i]))

        print(log)
        logging.info(log)

    imgs_1 = local_map[0, :, :, :].cpu().numpy()
    imgs_2 = local_map[1, :, :, :].cpu().numpy()

    obs_all = _process_obs_for_display(obs)

    # fig, axis = plt.subplots(1, 3)
    # axis[0].imshow(obs_all[0])
    # axis[1].imshow(imgs_1[0], cmap='gray')
    # axis[2].imshow(imgs_1[1], cmap='gray')
    # NOTE: early return; the interactive OpenCV debugging loop below never runs.
    return

    cv2.imshow("Camera", transform_rgb_bgr(obs_all[0]))
    cv2.imshow("Proj", imgs_1[0])
    cv2.imshow("Map", imgs_1[1])

    cv2.imshow("Camera2", transform_rgb_bgr(obs_all[1]))
    cv2.imshow("Proj2", imgs_2[0])
    cv2.imshow("Map2", imgs_2[1])

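    # Manual keyboard control for debugging: 'w' (119), 'a' (97) and 'd' (100) select
    # discrete actions 1/2/3 for both agents, and 'f' (102) exits the loop.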
    action = 1
    while action != 4:
        k = cv2.waitKey(0)
        if k == 119:
            action = 1
            action_2 = 1
        elif k == 100:
            action = 3
            action_2 = 1
        elif k == 97:
            action = 2
            action_2 = 2
        elif k == 102:
            action = 4
            break
        else:
            action = 1
            action_2 = 1

        last_obs = obs.detach()

        obs, rew, done, infos = envs.step(
            torch.from_numpy(np.array([action, action_2])))

        obs_all = _process_obs_for_display(obs)
        cv2.imshow("Camera", transform_rgb_bgr(obs_all[0]))
        cv2.imshow("Camera2", transform_rgb_bgr(obs_all[1]))

        poses = torch.from_numpy(
            np.asarray([
                infos[env_idx]['sensor_pose'] for env_idx in range(num_scenes)
            ])).float().to(device)

        _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
            nslam_module(last_obs, obs, poses, local_map[:, 0, :, :],
                            local_map[:, 1, :, :], local_pose, build_maps=True)

        imgs_1 = local_map[0, :, :, :].cpu().numpy()
        imgs_2 = local_map[1, :, :, :].cpu().numpy()
        cv2.imshow("Proj", imgs_1[0])
        cv2.imshow("Map", imgs_1[1])
        cv2.imshow("Proj2", imgs_2[0])
        cv2.imshow("Map2", imgs_2[1])

    # plt.show()

    print("\n\nDone\n\n")
Beispiel #11
0
def main(args):
    """Main function."""

    # 0. CONFIGURATIONS
    torch.backends.cudnn.benchmark = True
    BACKBONE_CONFIGS, Config, Backbone = AVAILABLE_MODELS[args.backbone_type]
    Projector = PROJECTOR_TYPES[args.projector_type]

    config = Config(args)
    config.save()

    logfile = os.path.join(config.checkpoint_dir, 'main.log')
    logger = get_logger(stream=False, logfile=logfile)

    # 1. DATA
    input_transform = get_transform(config.data,
                                    size=config.input_size,
                                    mode='train')
    if config.data == 'wm811k':
        in_channels = 2
        train_set = torch.utils.data.ConcatDataset([
            WM811KForSimCLR('./data/wm811k/unlabeled/train/',
                            transform=input_transform),
            WM811KForSimCLR('./data/wm811k/labeled/train/',
                            transform=input_transform),
        ])
        valid_set = torch.utils.data.ConcatDataset([
            WM811KForSimCLR('./data/wm811k/unlabeled/valid/',
                            transform=input_transform),
            WM811KForSimCLR('./data/wm811k/labeled/valid/',
                            transform=input_transform),
        ])
        test_set = torch.utils.data.ConcatDataset([
            WM811KForSimCLR('./data/wm811k/unlabeled/test/',
                            transform=input_transform),
            WM811KForSimCLR('./data/wm811k/labeled/test/',
                            transform=input_transform),
        ])
    elif config.data == 'cifar10':
        in_channels = 3
        train_set = CIFAR10ForSimCLR('./data/cifar10/',
                                     train=True,
                                     transform=input_transform)
        valid_set = CIFAR10ForSimCLR('./data/cifar10/',
                                     train=False,
                                     transform=input_transform)
        test_set = valid_set
    elif config.data == 'stl10':
        raise NotImplementedError
    elif config.data == 'imagenet':
        raise NotImplementedError
    else:
        raise ValueError
    logger.info(f"Data type: {config.data}")
    logger.info(
        f"Train : Valid : Test = {len(train_set):,} : {len(valid_set):,} : {len(test_set):,}"
    )
    steps_per_epoch = len(train_set) // config.batch_size + 1
    logger.info(f"Training steps per epoch: {steps_per_epoch:,}")
    logger.info(
        f"Total number of training iterations: {steps_per_epoch * config.epochs:,}"
    )

    # 2. MODEL
    backbone = Backbone(BACKBONE_CONFIGS[config.backbone_config], in_channels)
    projector = Projector(backbone.out_channels, config.projector_size)
    logger.info(
        f"Trainable parameters ({backbone.__class__.__name__}): {backbone.num_parameters:,}"
    )
    logger.info(
        f"Trainable parameters ({projector.__class__.__name__}): {projector.num_parameters:,}"
    )
    logger.info(f"Embedding dimension: {config.projector_size}")

    # 3. OPTIMIZATION (TODO: add LARS optimizer)
    params = [{
        'params': backbone.parameters()
    }, {
        'params': projector.parameters()
    }]
    optimizer = get_optimizer(params=params,
                              name=config.optimizer,
                              lr=config.learning_rate,
                              weight_decay=config.weight_decay,
                              momentum=config.momentum)
    scheduler = get_scheduler(optimizer=optimizer,
                              name=config.scheduler,
                              epochs=config.epochs,
                              milestone=config.milestone,
                              warmup_steps=config.warmup_steps)

    # 4. EXPERIMENT (SimCLR)
    experiment_kwargs = {
        'backbone': backbone,
        'projector': projector,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'loss_function': SimCLRLoss(temperature=config.temperature),
        'metrics': None,
        'checkpoint_dir': config.checkpoint_dir,
        'write_summary': config.write_summary,
    }
    experiment = SimCLR(**experiment_kwargs)
    logger.info(f"Saving model checkpoints to: {experiment.checkpoint_dir}")

    # 5. RUN (SimCLR)
    run_kwargs = {
        'train_set': train_set,
        'valid_set': valid_set,
        'test_set': test_set,
        'epochs': config.epochs,
        'batch_size': config.batch_size,
        'num_workers': config.num_workers,
        'device': config.device,
        'logger': logger,
        'save_every': config.save_every,
    }
    logger.info(
        f"Epochs: {run_kwargs['epochs']}, Batch size: {run_kwargs['batch_size']}"
    )
    logger.info(
        f"Workers: {run_kwargs['num_workers']}, Device: {run_kwargs['device']}"
    )

    if config.resume_from_checkpoint is not None:
        logger.info(
            f"Resuming from checkpoint: {config.resume_from_checkpoint}")
        model_ckpt = os.path.join(config.resume_from_checkpoint,
                                  'best_model.pt')
        experiment.load_model_from_checkpoint(model_ckpt)

        # Assign optimizer variables to appropriate device
        for state in experiment.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(config.device)

    experiment.run(**run_kwargs)
    logger.handlers.clear()
Beispiel #12
0
def main(args):
    """Main function."""

    # 0. Main configurations
    BACKBONE_CONFIGS, Config, Backbone = AVAILABLE_MODELS[args.backbone_type]
    Classifier = CLASSIFIER_TYPES['linear']
    config = Config(args)
    config.save()

    np.random.seed(config.seed)
    torch.manual_seed(config.seed)  # For reproducibility
    torch.backends.cudnn.benchmark = True

    logfile = os.path.join(config.checkpoint_dir, 'main.log')
    logger = get_logger(stream=False, logfile=logfile)

    in_channels = IN_CHANNELS[config.data]
    num_classes = NUM_CLASSES[config.data]

    # 1. Dataset
    if config.data == 'wm811k':
        train_transform = get_transform(config.data,
                                        size=config.input_size,
                                        mode=config.augmentation)
        test_transform = get_transform(config.data,
                                       size=config.input_size,
                                       mode='test')
        train_set = WM811K('./data/wm811k/labeled/train/',
                           transform=train_transform,
                           proportion=config.label_proportion,
                           seed=config.seed)
        valid_set = WM811K('./data/wm811k/labeled/valid/',
                           transform=test_transform)
        test_set = WM811K('./data/wm811k/labeled/test/',
                          transform=test_transform)
    else:
        raise NotImplementedError

    steps_per_epoch = len(train_set) // config.batch_size + 1
    logger.info(f"Data type: {config.data}")
    logger.info(
        f"Train : Valid : Test = {len(train_set):,} : {len(valid_set):,} : {len(test_set):,}"
    )
    logger.info(f"Training steps per epoch: {steps_per_epoch:,}")
    logger.info(
        f"Total number of training iterations: {steps_per_epoch * config.epochs:,}"
    )

    # 2. Model
    backbone = Backbone(BACKBONE_CONFIGS[config.backbone_config], in_channels)
    classifier = Classifier(
        in_channels=backbone.out_channels,
        num_classes=num_classes,
        dropout=config.dropout,
    )

    # 3. Optimization (TODO: add LARS)
    params = [
        {
            'params': backbone.parameters(),
            'lr': config.learning_rate
        },
        {
            'params': classifier.parameters(),
            'lr': config.learning_rate
        },
    ]
    optimizer = get_optimizer(params=params,
                              name=config.optimizer,
                              lr=config.learning_rate,
                              weight_decay=config.weight_decay)
    scheduler = get_scheduler(optimizer=optimizer,
                              name=config.scheduler,
                              epochs=config.epochs,
                              milestone=config.milestone,
                              warmup_steps=config.warmup_steps)

    # 4. Experiment (Mixup)
    experiment_kwargs = {
        'backbone': backbone,
        'classifier': classifier,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'loss_function': nn.CrossEntropyLoss(),
        'checkpoint_dir': config.checkpoint_dir,
        'write_summary': config.write_summary,
        'metrics': {
            'accuracy': MultiAccuracy(num_classes=num_classes),
            'f1': MultiF1Score(num_classes=num_classes, average='macro'),
        },
    }
    experiment = Mixup(**experiment_kwargs)
    logger.info(f"Saving model checkpoints to: {experiment.checkpoint_dir}")

    # 5. Run (Mixup)
    run_kwargs = {
        'train_set': train_set,
        'valid_set': valid_set,
        'test_set': test_set,
        'epochs': config.epochs,
        'batch_size': config.batch_size,
        'num_workers': config.num_workers,
        'device': config.device,
        'logger': logger,
        'eval_metric': config.eval_metric,
        'balance': config.balance,
        'disable_mixup': config.disable_mixup,
    }
    logger.info(
        f"Epochs: {run_kwargs['epochs']}, Batch size: {run_kwargs['batch_size']}"
    )
    logger.info(
        f"Workers: {run_kwargs['num_workers']}, Device: {run_kwargs['device']}"
    )
    logger.info(f"Mixup enabled: {not config.disable_mixup}")

    experiment.run(**run_kwargs)
    logger.handlers.clear()
Beispiel #13
0
    def prepare(self,
                ckpt_dir: str,
                optimizer: str = 'lars',
                learning_rate: float = 1.0,
                weight_decay: float = 0.0,
                cosine_warmup: int = 0,
                epochs: int = 100,
                batch_size: int = 256,
                num_workers: int = 0,
                distributed: bool = False,
                local_rank: int = 0,
                mixed_precision: bool = True,
                **kwargs):  # pylint: disable=unused-argument
        """Prepare classifier training."""

        # Set attributes
        self.ckpt_dir = ckpt_dir  # pylint: disable=attribute-defined-outside-init
        self.epochs = epochs  # pylint: disable=attribute-defined-outside-init
        self.batch_size = batch_size  # pylint: disable=attribute-defined-outside-init
        self.num_workers = num_workers  # pylint: disable=attribute-defined-outside-init
        self.distributed = distributed  # pylint: disable=attribute-defined-outside-init
        self.local_rank = local_rank  # pylint: disable=attribute-defined-outside-init
        self.mixed_precision = mixed_precision  # pylint: disable=attribute-defined-outside-init

        # Distributed training (optional)
        if distributed:
            self.backbone = DistributedDataParallel(
                nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.backbone).to(local_rank),
                device_ids=[local_rank])
            self.classifier = DistributedDataParallel(
                nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.classifier).to(local_rank),
                device_ids=[local_rank])
        else:
            self.backbone.to(local_rank)
            self.classifier.to(local_rank)

        # Mixed precision training (optional)
        self.scaler = torch.cuda.amp.GradScaler() if mixed_precision else None

        # Optimization (TODO: freeze)
        self.optimizer = get_optimizer(params=[
            {
                'params': self.backbone.parameters()
            },
            {
                'params': self.classifier.parameters()
            },
        ],
                                       name=optimizer,
                                       lr=learning_rate,
                                       weight_decay=weight_decay)
        self.scheduler = get_cosine_scheduler(self.optimizer,
                                              epochs=epochs,
                                              warmup_steps=cosine_warmup)

        # Loss function
        self.loss_function = nn.CrossEntropyLoss()

        # TensorBoard
        self.writer = SummaryWriter(ckpt_dir) if local_rank == 0 else None

        # Ready to train!
        self.prepared = True
Beispiel #14
0
    def prepare(self,
                ckpt_dir: str,
                optimizer: str = 'lars',
                learning_rate: float = 0.2,
                weight_decay: float = 1.5 * 1e-6,
                cosine_warmup: int = 10,
                cosine_cycles: int = 1,
                cosine_min_lr: float = 0.,
                epochs: int = 1000,
                batch_size: int = 256,
                num_workers: int = 0,
                distributed: bool = False,
                local_rank: int = 0,
                mixed_precision: bool = True,
                resume: str = None):
        """Prepare BYOL pre-training."""

        # Set attributes
        self.ckpt_dir = ckpt_dir
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.distributed = distributed
        self.local_rank = local_rank
        self.mixed_precision = mixed_precision
        self.resume = resume

        self.optimizer = get_optimizer(
            params=[
                {
                    'params': self.online_net.parameters()
                },
                {
                    'params': self.online_predictor.parameters()
                },
            ],
            name=optimizer,
            lr=learning_rate,
            weight_decay=weight_decay  # TODO: remove params from batch norm
        )

        self.scheduler = get_cosine_scheduler(
            self.optimizer,
            epochs=self.epochs,
            warmup_steps=cosine_warmup,
            cycles=cosine_cycles,
            min_lr=cosine_min_lr,
        )

        # Resuming from previous checkpoint (optional)
        if resume is not None:
            if not os.path.exists(resume):
                raise FileNotFoundError
            self.load_model_from_checkpoint(resume)

        # Distributed training (optional, disabled by default.)
        if distributed:
            self.online_net = DistributedDataParallel(
                module=self.online_net.to(local_rank), device_ids=[local_rank])
            self.online_predictor = DistributedDataParallel(
                module=self.online_predictor.to(local_rank),
                device_ids=[local_rank])
        else:
            self.online_net.to(local_rank)
            self.online_predictor.to(local_rank)

        # No DDP wrapping for target network; no gradient updates
        self.target_net.to(local_rank)

        # Mixed precision training (optional, enabled by default)
        self.scaler = torch.cuda.amp.GradScaler() if mixed_precision else None

        # TensorBoard
        self.writer = SummaryWriter(ckpt_dir) if local_rank == 0 else None

        # Ready to train
        self.prepared = True