Example No. 1
def main():
    args = parser.parse_args()  # `parser` is the argparse parser defined elsewhere in the full script
    # Fold the negated CLI flags into positive booleans.
    args.tensorboard = not args.no_tensorboard
    args.load_model = not args.clear_weights
    args.save_checkpoints = not args.no_save_checkpoints

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. "
            "This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! "
            "You may see unexpected behavior when restarting "
            "from checkpoints."
        )

    log_prefix = args.log_prefix
    time_str = misc_util.get_time_str()
    checkpoint_dir = os.path.join(log_prefix, args.checkpoint_dirname, time_str)

    # Parse the comma-separated GPU ids; the first id hosts the master replica.
    torch_devices = [int(gpu_id.strip()) for gpu_id in args.pytorch_gpu_ids.split(",")]
    args.gpu = torch_devices[0]
    device = "cuda:" + str(torch_devices[0])

    model = ImagenetModel()
    model = pt_util.get_data_parallel(model, torch_devices)
    model.to(device)

    start_iter = 0
    if args.load_model:
        # Restore weights from the most recent checkpoint matching the glob; returns the saved iteration.
        start_iter = pt_util.restore_from_folder(model, os.path.join(log_prefix, args.checkpoint_dirname, "*"))
    args.start_epoch = start_iter

    train_logger = None
    test_logger = None
    if args.tensorboard:
        train_logger = tensorboard_logger.Logger(
            os.path.join(log_prefix, args.tensorboard_dirname, time_str + "_train")
        )
        test_logger = tensorboard_logger.Logger(
            os.path.join(log_prefix, args.tensorboard_dirname, time_str + "_test")
        )

    main_worker(model, args.gpu, args, train_logger, test_logger, checkpoint_dir)
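
All four examples route the model through pt_util.get_data_parallel before moving it to a device. The helper is project-specific, but from its call sites (a module plus a list of GPU ids) it plausibly amounts to a thin wrapper around torch.nn.DataParallel. A minimal sketch under that assumption, not the project's actual implementation:

import torch.nn as nn

def get_data_parallel(module: nn.Module, device_ids):
    # Hypothetical stand-in for pt_util.get_data_parallel, for illustration only.
    if not device_ids or len(device_ids) <= 1:
        return module  # a single device needs no replication wrapper
    return nn.DataParallel(module, device_ids=device_ids)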
Example No. 2
    def __init__(self, args):
        super(VinceModel, self).__init__(args)
        self.args = args
        self.num_frames = self.args.num_frames

        # Backbone feature extractor and pooling.
        self.feature_extractor = self.args.backbone(self.args, -2)
        resnet_output_channels = self.feature_extractor.output_channels
        self.output_channels = resnet_output_channels

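        # Pool the spatial feature map to a single vector, either with learned attention or global average pooling.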
        if self.args.use_attention:
            self.average_layers = pt_util.AttentionPool2D(
                resnet_output_channels, keepdim=False, return_masks=True)
        else:
            self.average_layers = nn.Sequential(
                torch.nn.AdaptiveAvgPool2d((1, 1)), pt_util.RemoveDim((2, 3)))

        self.feature_extractor = pt_util.get_data_parallel(
            self.feature_extractor, args.feature_extractor_gpu_ids)
        self.feature_extractor_device = args.feature_extractor_gpu_ids[0]

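        # Projection head: a two-layer MLP mapping backbone features into the contrastive embedding space.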
        self.embedding = nn.Sequential(
            nn.Linear(self.output_channels, self.output_channels),
            constants.NONLINEARITY(),
            nn.Linear(self.output_channels, self.args.vince_embedding_size),
        )
        if self.args.jigsaw:
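            # Jigsaw head: a shared linear layer per patch, then an MLP over the concatenated 9 (3x3) patch features.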
            self.jigsaw_linear = nn.Linear(self.output_channels,
                                           self.output_channels)
            self.jigsaw_embedding = nn.Sequential(
                nn.Linear(self.output_channels * 9, self.output_channels),
                constants.NONLINEARITY(),
                nn.Linear(self.output_channels,
                          self.args.vince_embedding_size),
            )
        if self.args.inter_batch_comparison:
            if self.num_frames > 1:
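                # Block-diagonal mask flagging pairs of frames drawn from the same clip,
                # zero-padded on the right to cover the queue entries.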
                diag_mask = pt_util.from_numpy(
                    scipy.linalg.block_diag(
                        *[np.ones((self.num_frames, self.num_frames), dtype=bool)]
                        * (self.args.batch_size // self.num_frames)
                    )
                ).to(device=self.device)
                self.similarity_mask = torch.cat(
                    (
                        diag_mask,
                        torch.zeros(
                            (self.args.batch_size, self.args.vince_queue_size),
                            device=self.device,
                            dtype=torch.bool),
                    ),
                    dim=1,
                )

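            # Identity mask marking each batch element paired with itself, again zero-padded across the queue.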
            eye = torch.eye(self.args.batch_size,
                            device=self.device,
                            dtype=torch.bool)
            self.eye_mask = torch.cat(
                (
                    eye,
                    torch.zeros(
                        (self.args.batch_size, self.args.vince_queue_size),
                        device=self.device,
                        dtype=torch.bool),
                ),
                dim=1,
            )

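        # Optional auxiliary ImageNet heads: a linear probe and a two-layer MLP classifier.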
        if self.args.use_imagenet:
            self.imagenet_decoders = nn.ModuleList([
                nn.Linear(self.output_channels, 1000),
                nn.Sequential(
                    nn.Linear(self.output_channels, self.output_channels),
                    constants.NONLINEARITY(),
                    nn.Linear(self.output_channels, 1000),
                ),
            ])
            self.num_imagenet_decoders = len(self.imagenet_decoders)
Example No. 3
    def __init__(
        self,
        encoder_type,
        decoder_output_info,
        recurrent=False,
        end_to_end=False,
        hidden_size=512,
        target_vector_size=None,
        action_size=None,
        gpu_ids=None,
        create_decoder=True,
        blind=False,
    ):
        assert action_size is not None, "action_size must be provided"
        assert target_vector_size is not None, "target_vector_size must be provided"
        self.aah_im_blind = blind
        self.end_to_end = end_to_end
        self.action_size = action_size
        self.target_vector_size = target_vector_size
        self.decoder_enabled = False
        self.decoder_outputs = None
        self.class_pred = None
        self.visual_encoder_features = None
        self.visual_features = None

        super(RLBaseWithVisualEncoder, self).__init__(
            recurrent,
            recurrent_input_size=hidden_size + self.target_vector_size + self.action_size,
            hidden_size=hidden_size,
        )

        if self.aah_im_blind:
            # No visual input: project (goal vector, previous action) straight to the RL input size.
            self.blind_projection = nn.Sequential(
                nn.Linear(
                    self.target_vector_size + self.action_size,
                    hidden_size + self.target_vector_size + self.action_size,
                )
            )
        else:
            self.visual_encoder = encoder_type(decoder_output_info,
                                               create_decoder)
            self.num_output_channels = self.visual_encoder.num_output_channels

            self.visual_encoder = pt_util.get_data_parallel(
                self.visual_encoder, gpu_ids)

            self.decoder_output_info = decoder_output_info

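            # Project the encoder feature map (4x4 spatial after pooling) down to a single hidden vector.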
            self.visual_projection = nn.Sequential(
                ConvBlock(self.num_output_channels, hidden_size),
                ConvBlock(hidden_size, hidden_size),
                nn.AvgPool2d(2, 2),
                pt_util.RemoveDim((2, 3)),
                nn.Linear(hidden_size * 4 * 4, hidden_size),
            )

        self.rl_layers = nn.Sequential(
            nn.Linear(hidden_size + self.target_vector_size + self.action_size,
                      hidden_size),
            nn.ELU(inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU(inplace=True),
        )

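        # Egomotion head: predicts the action taken between two consecutive state embeddings (inverse dynamics).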
        self.egomotion_layer = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size), nn.ELU(inplace=True),
            nn.Linear(hidden_size, action_size))

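        # Motion-model head: predicts the next state embedding from the current state and action (forward dynamics).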
        self.motion_model_layer = nn.Sequential(
            nn.Linear(hidden_size + action_size, hidden_size),
            nn.ELU(inplace=True), nn.Linear(hidden_size, hidden_size))

        self.critic_linear = init(
            nn.Linear(hidden_size, 1),
            nn.init.orthogonal_,
            lambda x: nn.init.constant_(x, 0),
            np.sqrt(2),
        )
Example No. 4
def main():
    # `args` comes from the module-level argparse setup elsewhere in the full script.
    torch_devices = [int(gpu_id.strip()) for gpu_id in args.pytorch_gpu_ids.split(",")]
    render_gpus = [int(gpu_id.strip()) for gpu_id in args.render_gpu_ids.split(",")]
    device = "cuda:" + str(torch_devices[0])

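    # Decoder heads as (name, output_channels) pairs.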
    decoder_output_info = [("reconstruction", 3), ("depth", 1),
                           ("surface_normals", 3)]
    if USE_SEMANTIC:
        decoder_output_info.append(("semantic", 41))

    model = ShallowVisualEncoder(decoder_output_info)
    model = pt_util.get_data_parallel(model, torch_devices)
    model = pt_util.DummyScope(model, ["base", "visual_encoder"])
    model.to(device)

    print("Model constructed")
    print(model)

    train_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(224)
    ])

    train_transforms_depth = transforms.Compose([
        PIL.Image.fromarray,
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(224), np.array
    ])

    train_transforms_semantic = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(224)
    ])

    sensors = ["RGB_SENSOR", "DEPTH_SENSOR"] + (["SEMANTIC_SENSOR"] if USE_SEMANTIC else [])
    if args.dataset == "suncg":
        data_train = HabitatImageGenerator(
            render_gpus,
            "suncg",
            args.data_subset,
            "data/dumps/suncg/{split}/dataset_one_ep_per_scene.json.gz",
            images_before_reset=1000,
            sensors=sensors,
            transform=train_transforms,
            depth_transform=train_transforms_depth,
            semantic_transform=train_transforms_semantic,
        )

        data_test = HabitatImageGenerator(
            render_gpus,
            "suncg",
            "val",
            "data/dumps/suncg/{split}/dataset_one_ep_per_scene.json.gz",
            images_before_reset=1000,
            sensors=sensors,
        )
    elif args.dataset == "mp3d":
        data_train = HabitatImageGenerator(
            render_gpus,
            "mp3d",
            args.data_subset,
            "data/dumps/mp3d/{split}/dataset_one_ep_per_scene.json.gz",
            images_before_reset=1000,
            sensors=sensors,
            transform=train_transforms,
            depth_transform=train_transforms_depth,
            semantic_transform=train_transforms_semantic,
        )

        data_test = HabitatImageGenerator(
            render_gpus,
            "mp3d",
            "val",
            "data/dumps/mp3d/{split}/dataset_one_ep_per_scene.json.gz",
            images_before_reset=1000,
            sensors=sensors,
        )
    elif args.dataset == "gibson":
        data_train = HabitatImageGenerator(
            render_gpus,
            "gibson",
            args.data_subset,
            "data/datasets/pointnav/gibson/v1/{split}/{split}.json.gz",
            images_before_reset=1000,
            sensors=sensors,
            transform=train_transforms,
            depth_transform=train_transforms_depth,
            semantic_transform=train_transforms_semantic,
        )

        data_test = HabitatImageGenerator(
            render_gpus,
            "gibson",
            "val",
            "data/datasets/pointnav/gibson/v1/{split}/{split}.json.gz",
            images_before_reset=1000,
            sensors=sensors,
        )
    else:
        raise NotImplementedError("No rule for this dataset.")

    print("Num train images", len(data_train))
    print("Num val images", len(data_test))

    print("Using device", device)
    print("num cpus:", args.num_processes)

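    # Images are rendered on the fly; worker_init_fn gives each loader worker its own generator/simulator state.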
    train_loader = torch.utils.data.DataLoader(
        data_train,
        batch_size=BATCH_SIZE,
        num_workers=args.num_processes,
        worker_init_fn=data_train.worker_init_fn,
        shuffle=False,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(
        data_test,
        batch_size=TEST_BATCH_SIZE,
        num_workers=len(render_gpus) if args.num_processes > 0 else 0,
        worker_init_fn=data_test.worker_init_fn,
        shuffle=False,
        pin_memory=True,
    )

    log_prefix = args.log_prefix
    time_str = misc_util.get_time_str()
    checkpoint_dir = os.path.join(log_prefix, args.checkpoint_dirname,
                                  time_str)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    start_iter = 0
    if args.load_model:
        start_iter = pt_util.restore_from_folder(
            model, os.path.join(log_prefix, args.checkpoint_dirname, "*"))

    train_logger = None
    test_logger = None
    if args.tensorboard:
        train_logger = tensorboard_logger.Logger(
            os.path.join(log_prefix, args.tensorboard_dirname,
                         time_str + "_train"))
        test_logger = tensorboard_logger.Logger(
            os.path.join(log_prefix, args.tensorboard_dirname,
                         time_str + "_test"))

    total_num_steps = start_iter

    if args.save_checkpoints and not args.no_weight_update:
        pt_util.save(model,
                     checkpoint_dir,
                     num_to_keep=5,
                     iteration=total_num_steps)

    evaluate_model(model, device, test_loader, total_num_steps, test_logger,
                   decoder_output_info)

    for epoch in range(EPOCHS + 1):
        total_num_steps = train_model(model, device, train_loader, optimizer,
                                      total_num_steps, train_logger,
                                      decoder_output_info, checkpoint_dir)
        evaluate_model(model, device, test_loader, total_num_steps,
                       test_logger, decoder_output_info)