def update_train_weights(self):
     if self.skew_dataset:
         self._train_weights = self._compute_train_weights()
         if self.use_parallel_dataloading:
             self.train_dataloader = DataLoader(
                 self.train_dataset_pt,
                 sampler=InfiniteWeightedRandomSampler(self.train_dataset, self._train_weights),
                 batch_size=self.batch_size,
                 drop_last=False,
                 num_workers=self.train_data_workers,
                 pin_memory=True,
             )
             self.train_dataloader = iter(self.train_dataloader)
예제 #2
0
    def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        positive_range=2,
        negative_range=10,
        triplet_sample_num=8,
        triplet_loss_margin=0.5,
        batch_size=128,
        log_interval=0,
        recon_loss_coef=1,
        triplet_loss_coef=[],
        triplet_loss_type=[],
        ae_loss_coef=1,
        matching_loss_coef=1,
        vae_matching_loss_coef=1,
        matching_loss_one_side=False,
        contrastive_loss_coef=0,
        lstm_kl_loss_coef=0,
        adaptive_margin=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
    ):

        print("In LSTM trainer, ae_loss_coef is: ", ae_loss_coef)
        print("In LSTM trainer, matching_loss_coef is: ", matching_loss_coef)
        print("In LSTM trainer, vae_matching_loss_coef is: ",
              vae_matching_loss_coef)

        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot

        self.recon_loss_coef = recon_loss_coef
        self.triplet_loss_coef = triplet_loss_coef
        self.ae_loss_coef = ae_loss_coef
        self.matching_loss_coef = matching_loss_coef
        self.vae_matching_loss_coef = vae_matching_loss_coef
        self.contrastive_loss_coef = contrastive_loss_coef
        self.lstm_kl_loss_coef = lstm_kl_loss_coef
        self.matching_loss_one_side = matching_loss_one_side

        # triplet loss range
        self.positve_range = positive_range
        self.negative_range = negative_range
        self.triplet_sample_num = triplet_sample_num
        self.triplet_loss_margin = triplet_loss_margin
        self.triplet_loss_type = triplet_loss_type
        self.adaptive_margin = adaptive_margin

        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset.dtype == np.uint8
        assert self.test_dataset.dtype == np.uint8

        self.batch_size = batch_size
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.skew_dataset:
            self._train_weights = self._compute_train_weights()
        else:
            self._train_weights = None

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=InfiniteRandomSampler(self.train_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.eval_statistics = OrderedDict()
        self._extra_stats_to_log = None
예제 #3
0
    def __init__(
        self,
        train_dataset,
        test_dataset,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        use_parallel_dataloading=True,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
    ):
        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot

        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        assert self.train_dataset.dtype == np.uint8
        assert self.test_dataset.dtype == np.uint8
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset

        self.batch_size = batch_size
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.skew_dataset:
            self._train_weights = self._compute_train_weights()
        else:
            self._train_weights = None

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=InfiniteRandomSampler(self.train_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.eval_statistics = OrderedDict()
        self._extra_stats_to_log = None
예제 #4
0
    def __init__(
        self,
        model,
        batch_size=128,
        log_interval=0,
        beta=0.5,
        beta_schedule=None,
        lr=None,
        do_scatterplot=False,
        normalize=False,
        mse_weight=0.1,
        is_auto_encoder=False,
        background_subtract=False,
        linearity_weight=0.0,
        distance_weight=0.0,
        loss_weights=None,
        use_linear_dynamics=False,
        use_parallel_dataloading=False,
        train_data_workers=2,
        skew_dataset=False,
        skew_config=None,
        priority_function_kwargs=None,
        start_skew_epoch=0,
        weight_decay=0,
        key_to_reconstruct='observations',
        num_epochs=None,
    ):
        #TODO:steven fix pickling
        assert not use_parallel_dataloading, "Have to fix pickling the dataloaders first"

        if skew_config is None:
            skew_config = {}
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.beta = beta
        if is_auto_encoder:
            self.beta = 0
        if lr is None:
            if is_auto_encoder:
                lr = 1e-2
            else:
                lr = 1e-3
        self.beta_schedule = beta_schedule
        self.num_epochs = num_epochs
        if self.beta_schedule is None or is_auto_encoder:
            self.beta_schedule = ConstantSchedule(self.beta)
        self.imsize = model.imsize
        self.do_scatterplot = do_scatterplot
        model.to(ptu.device)

        self.model = model
        self.representation_size = model.representation_size
        self.input_channels = model.input_channels
        self.imlength = model.imlength

        self.lr = lr
        params = list(self.model.parameters())
        self.optimizer = optim.Adam(
            params,
            lr=self.lr,
            weight_decay=weight_decay,
        )

        self.key_to_reconstruct = key_to_reconstruct
        self.use_parallel_dataloading = use_parallel_dataloading
        self.train_data_workers = train_data_workers
        self.skew_dataset = skew_dataset
        self.skew_config = skew_config
        self.start_skew_epoch = start_skew_epoch
        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if use_parallel_dataloading:
            self.train_dataset_pt = ImageDataset(train_dataset,
                                                 should_normalize=True)
            self.test_dataset_pt = ImageDataset(test_dataset,
                                                should_normalize=True)

            if self.skew_dataset:
                base_sampler = InfiniteWeightedRandomSampler(
                    self.train_dataset, self._train_weights)
            else:
                base_sampler = InfiniteRandomSampler(self.train_dataset)
            self.train_dataloader = DataLoader(
                self.train_dataset_pt,
                sampler=InfiniteRandomSampler(self.train_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=train_data_workers,
                pin_memory=True,
            )
            self.test_dataloader = DataLoader(
                self.test_dataset_pt,
                sampler=InfiniteRandomSampler(self.test_dataset),
                batch_size=batch_size,
                drop_last=False,
                num_workers=0,
                pin_memory=True,
            )
            self.train_dataloader = iter(self.train_dataloader)
            self.test_dataloader = iter(self.test_dataloader)

        self.normalize = normalize
        self.mse_weight = mse_weight
        self.background_subtract = background_subtract

        if self.normalize or self.background_subtract:
            self.train_data_mean = np.mean(self.train_dataset, axis=0)
            self.train_data_mean = normalize_image(
                np.uint8(self.train_data_mean))
        self.linearity_weight = linearity_weight
        self.distance_weight = distance_weight
        self.loss_weights = loss_weights

        self.use_linear_dynamics = use_linear_dynamics
        self._extra_stats_to_log = None

        # stateful tracking variables, reset every epoch
        self.eval_statistics = collections.defaultdict(list)
        self.eval_data = collections.defaultdict(list)
        self.num_batches = 0