Example #1
0
def gradient_descent(Xn,y,theta,alpha,num_iters=1000,tol=None,theta_hist=False):
    """Perform batch gradient descent to learn the theta that creates the best fit
    hypothesis h(theta)=Xn @ theta to the dataset

    Args:
        Xn:     Normalized Feature Matrix
        y:      Target Vector
        theta:  Initial parameter vector (starting point of the optimization)
        alpha:  (Real, >0) Learning Rate

    Kwargs:
        num_iters:  (Int) Maximum iterations to perform optimization
        tol:        (Real) If provided, supersede num_iters, breaking optimization as soon as
                    the absolute change of the cost between two consecutive iterations is <= tol
        theta_hist: (Bool) If True, also return theta's history

    Returns:
        (theta, J_history) where J_history is the list of costs after every update, or
        (theta, J_history, theta_history) when theta_hist is True, theta_history being
        the per-iteration thetas stacked with np.vstack.
    """

    # Check to see if Xn is normalized. Warn if not.
    # Column 1 is inspected (presumably column 0 is the bias/intercept column -- TODO confirm);
    # the check is skipped when there is no such column instead of raising IndexError.
    if getattr(Xn, "ndim", 0) == 2 and Xn.shape[1] > 1 and round(Xn[:,1].std()) != 1:
        utils.printYellow("Gradient Descent X matrix is not normalized. Pass in normalized in the future to ensure convergence")

    m = 1.0*len(y)
    J_history = []
    theta_history = []
    for _ in range(num_iters):
        ## Compute new theta (a new object is bound each step, so the history entries stay distinct)
        theta = theta - (alpha/m) * ((Xn @ theta - y).T @ Xn).T
        theta_history.append(theta)

        ## Save new J cost
        J_history.append(compute_cost(Xn,y,theta))
        if len(J_history) < 2:
            # Nothing to compare against yet
            continue

        ## Converged: the cost barely moved between two consecutive iterations.
        ## The absolute value matters: a decreasing cost gives a negative difference,
        ## which would otherwise always be <= tol and stop the optimization immediately.
        if (tol is not None) and (abs(J_history[-1]-J_history[-2]) <= tol):
            break

        ## Check to make sure J is decreasing...
        if J_history[-2] <= J_history[-1]:
            utils.printRed("Gradient Descent is not decreasing! Alpha: {}\t previous J {}\tJ {}. Try decreasing alpha".format(alpha,J_history[-2], J_history[-1]))
    if theta_hist:
        return theta, J_history, np.vstack(theta_history)
    return theta, J_history
def main():
    """Train a UNet segmentation model on a LiTS dataset from the command line.

    Parses the CLI arguments, seeds numpy/torch, builds the training and
    validation dataloaders, builds (or loads) the model and then alternates one
    training pass and one validation pass per epoch. When ``--log-dir`` is given
    the entire model is saved there each time the mean validation loss improves.

    Returns:
        The best (lowest) mean validation loss seen over all epochs
        (``float('inf')`` if no epoch was run).
    """
    parser = argparse.ArgumentParser(description="Train the model")
    parser.add_argument('-trainf', "--train-filepath", type=str, default=None, required=True,
                        help="training dataset filepath.")
    parser.add_argument('-validf', "--val-filepath", type=str, default=None,
                        help="validation dataset filepath.")
    parser.add_argument("--shuffle", action="store_true", default=False,
                        help="Shuffle the dataset")
    parser.add_argument("--load-weights", type=str, default=None,
                        help="load pretrained weights")
    parser.add_argument("--load-model", type=str, default=None,
                        help="load pretrained model, entire model (filepath, default: None)")

    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument('--epochs', type=int, default=30,
                        help='number of epochs to train (default: 30)')
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size")

    parser.add_argument('--img-shape', type=str, default="(1,512,512)",
                        help='Image shape (default "(1,512,512)"')

    parser.add_argument("--num-cpu", type=int, default=10,
                        help="Number of CPUs to use in parallel for dataloader.")
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA visible device (use CPU if -1, default: 0)')
    parser.add_argument('--cuda-non-deterministic', action='store_true', default=False,
                        help="sets flags for non-determinism when using CUDA (potentially fast)")

    parser.add_argument('-lr', type=float, default=0.0005,
                        help='Learning rate')
    parser.add_argument('--seed', type=int, default=0,
                        help='Seed (numpy and cuda if GPU is used.).')

    parser.add_argument('--log-dir', type=str, default=None,
                        help='Save the results/model weights/logs under the directory.')

    args = parser.parse_args()

    # TODO: support image reshape
    # Parsed only to validate the "(c,h,w)" format for now; the value is not used yet.
    img_shape = tuple(map(int, args.img_shape.strip()[1:-1].split(",")))

    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
        best_model_path = os.path.join(args.log_dir, "model_weights.pth")
    else:
        best_model_path = None

    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if args.cuda >= 0:
            if args.cuda_non_deterministic:
                printBlue("Warning: using CUDA non-deterministc. Could be faster but results might not be reproducible.")
            else:
                printBlue("Using CUDA deterministc. Use --cuda-non-deterministic might accelerate the training a bit.")
            # Make CuDNN deterministic unless explicitly disabled
            torch.backends.cudnn.deterministic = not args.cuda_non_deterministic

            torch.cuda.manual_seed_all(args.seed)

    # TODO [OPT] enable multi-GPUs ?
    # https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
    device = torch.device("cuda:{}".format(args.cuda) if torch.cuda.is_available()
                          and (args.cuda >= 0) else "cpu")

    # ================= Build dataloader =================
    # Single-channel normalization (the 3-channel variant would use mean/std lists of length 3)
    transform_normalize = transforms.Normalize(mean=[0.5],
                                               std=[0.5])

    # Warning: DO NOT use geometry transforms here (resize, crop, flip, rotation, ...):
    # they are applied through `geometry_transform` in the dataset instead.
    # Note: transforms.ColorJitter doesn't work with ToPILImage(mode='F'),
    # custom (OpenCV-based) augmentation functions are used instead.
    data_transform = transforms.Compose([
        transforms.ToTensor(),  # already done in the dataloader
        transform_normalize
    ])

    geo_transform = GeoCompose([
        OpenCVRotation(angles=(-10, 10),
                       scales=(0.9, 1.1),
                       centers=(-0.05, 0.05)),

        # TODO add more data augmentation here
    ])

    def worker_init_fn(worker_id):
        # WARNING spawn start method is used,
        # worker_init_fn cannot be an unpicklable object, e.g., a lambda function.
        # A work-around for issue #5059: https://github.com/pytorch/pytorch/issues/5059
        np.random.seed()

    data_loader_train = {'batch_size': args.batch_size,
                         'shuffle': args.shuffle,
                         'num_workers': args.num_cpu,
                         'drop_last': True,  # for GAN-like
                         'pin_memory': False,
                         'worker_init_fn': worker_init_fn,
                         }

    data_loader_valid = {'batch_size': args.batch_size,
                         'shuffle': False,
                         'num_workers': args.num_cpu,
                         'drop_last': False,
                         'pin_memory': False,
                         }

    train_set = LiTSDataset(args.train_filepath,
                            dtype=np.float32,
                            geometry_transform=geo_transform,  # TODO enable data augmentation
                            pixelwise_transform=data_transform,
                            )
    # NOTE(review): --val-filepath is optional (default None) but is passed
    # unconditionally; confirm that LiTSDataset accepts None or make the flag required.
    valid_set = LiTSDataset(args.val_filepath,
                            dtype=np.float32,
                            pixelwise_transform=data_transform,
                            )

    dataloader_train = torch.utils.data.DataLoader(train_set, **data_loader_train)
    dataloader_valid = torch.utils.data.DataLoader(valid_set, **data_loader_valid)
    # =================== Build model ===================
    # TODO: control the model by bash command

    if args.load_model and not args.load_weights:
        # load entire model (--load-weights keeps priority when both flags are given)
        # map_location lets a checkpoint saved on a GPU be loaded on the selected device
        model = torch.load(args.load_model, map_location=device)
        printYellow("Successfully loaded pretrained model.")
    else:
        model = UNet(in_ch=1,
                     out_ch=3,  # there are 3 classes: 0: background, 1: liver, 2: tumor
                     depth=4,
                     start_ch=32, # 64
                     inc_rate=2,
                     kernel_size=5, # 3
                     padding=True,
                     batch_norm=True,
                     spec_norm=False,
                     dropout=0.5,
                     up_mode='upconv',
                     include_top=True,
                     include_last_act=False,
                     )
        if args.load_weights:
            printYellow(f"Loading pretrained weights from: {args.load_weights}...")
            model.load_state_dict(torch.load(args.load_weights, map_location=device))
            printYellow("+ Done.")
        else:
            # Neither flag given: train the freshly initialized network
            # (previously `model` was left undefined and the script crashed with NameError).
            printYellow("No pretrained weights/model provided, training from scratch.")

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.95))  # TODO
    best_valid_loss = float('inf')
    # TODO TODO: add learning decay

    for epoch in range(args.epochs):
        # valid_mode is 0 for the training pass and 1 for the validation pass
        for valid_mode, dataloader in enumerate([dataloader_train, dataloader_valid]):
            n_batch_per_epoch = len(dataloader)
            if args.debug:
                n_batch_per_epoch = 1

            # infinite dataloader allows several update per iteration (for special models e.g. GAN)
            dataloader = infinite_dataloader(dataloader)
            if valid_mode:
                printYellow("Switch to validation mode.")
                model.eval()
                prev_grad_mode = torch.is_grad_enabled()
                torch.set_grad_enabled(False)
            else:
                model.train()

            st = time.time()
            cum_loss = 0
            for iter_ind in range(n_batch_per_epoch):
                supplement_logs = ""
                # loss_manager.reset_losses() # TODO: use torch.utils.tensorboard !!
                optimizer.zero_grad()

                img, msk = next(dataloader)
                img, msk = img.to(device), msk.to(device)

                # TODO this is ugly: convert dtype and convert the shape from (N, 1, 512, 512) to (N, 512, 512)
                msk = msk.to(torch.long).squeeze(1)

                msk_pred = model(img)  # shape (N, 3, 512, 512)

                # label_weights is determined according the liver_ratio & tumor_ratio
                # alternative: CrossEntropyLoss(msk_pred, msk, label_weights=[1., 10., 100.], device=device)
                loss = DiceLoss(msk_pred, msk, label_weights=[1., 20., 50.], device=device)

                # Parameters are only updated during the training pass
                if not valid_mode:
                    loss.backward()
                    optimizer.step()

                loss = loss.item()  # release
                cum_loss += loss
                if valid_mode:
                    print("\r--------(valid) {:.2%} Loss: {:.3f} (time: {:.1f}s) |supp: {}".format(
                        (iter_ind+1)/n_batch_per_epoch, cum_loss/(iter_ind+1), time.time()-st, supplement_logs), end="")
                else:
                    print("\rEpoch: {:3}/{} {:.2%} Loss: {:.3f} (time: {:.1f}s) |supp: {}".format(
                        (epoch+1), args.epochs, (iter_ind+1)/n_batch_per_epoch, cum_loss/(iter_ind+1), time.time()-st, supplement_logs), end="")
            print()
            if valid_mode:
                torch.set_grad_enabled(prev_grad_mode)

        # cum_loss / iter_ind still hold the values of the validation pass (the last
        # dataloader iterated above): validation (mean) loss of the current epoch
        valid_mean_loss = cum_loss/(iter_ind+1)

        # Track the best loss even when nothing is saved, so the returned value is
        # meaningful without --log-dir (it used to stay at +inf in that case).
        if valid_mean_loss < best_valid_loss:
            if best_model_path:
                printGreen("Valid loss decreases from {:.5f} to {:.5f}, saving best model.".format(
                    best_valid_loss, valid_mean_loss))
                # save the entire model (saving only model.state_dict() would be enough
                # for --load-weights, but --load-model expects the whole object)
                torch.save(model, best_model_path)
            best_valid_loss = valid_mean_loss

    return best_valid_loss
Example #3
0
def _probabilities_to_labels(volume_msk, threshold=0.5):
    """Turn per-class scores of shape (N, 3, H, W) into an integer label map (N, H, W).

    Channels follow the convention used throughout this function:
    0: background, 1: liver, 2: tumor. Priority is tumor, then liver, then
    background; voxels where no class reaches the threshold are background.
    """
    liver = volume_msk[:, 1] >= threshold
    tumor = volume_msk[:, 2] >= threshold
    labels = np.zeros(liver.shape, dtype=np.int64)
    labels[liver] = 1
    labels[tumor] = 2  # written last so that tumor overrides liver
    return labels


def inference():
    """Support two mode: evaluation (on valid set) or inference mode (on test-set for submission)

    Evaluation mode (``--evaluate``): the dataset also yields ground-truth masks,
    liver/lesion scores are computed per volume with ``get_scores`` and, when
    ``--save2dir`` is given, pickled to ``results.pkl`` in that directory.

    Inference mode (default): when ``--save2dir`` is given, one label volume
    ``test-segmentation-<vol_ind>.nii`` is written per volume.

    Returns:
        List of ``[vol_ind, liver_scores, lesion_scores]`` (empty in inference mode).
    """
    parser = argparse.ArgumentParser(description="Inference mode")
    parser.add_argument('-testf', "--test-filepath", type=str, default=None, required=True,
                        help="testing dataset filepath.")
    parser.add_argument("-eval", "--evaluate", action="store_true", default=False,
                        help="Evaluation mode")
    parser.add_argument("--load-weights", type=str, default=None,
                        help="Load pretrained weights, torch state_dict() (filepath, default: None)")
    parser.add_argument("--load-model", type=str, default=None,
                        help="Load pretrained model, entire model (filepath, default: None)")

    parser.add_argument("--save2dir", type=str, default=None,
                        help="save the prediction labels to the directory (default: None)")
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size")

    parser.add_argument("--num-cpu", type=int, default=10,
                        help="Number of CPUs to use in parallel for dataloader.")
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA visible device (use CPU if -1, default: 0)')
    args = parser.parse_args()

    # Without a checkpoint there is nothing to run; fail early with a usage message
    # instead of hitting a NameError on `model` after the dataset has been loaded.
    if not (args.load_weights or args.load_model):
        parser.error("one of --load-weights or --load-model is required")

    printYellow("="*10 + " Inference mode. "+"="*10)
    if args.save2dir:
        os.makedirs(args.save2dir, exist_ok=True)

    device = torch.device("cuda:{}".format(args.cuda) if torch.cuda.is_available()
                          and (args.cuda >= 0) else "cpu")

    transform_normalize = transforms.Normalize(mean=[0.5],
                                               std=[0.5])

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transform_normalize
    ])

    data_loader_params = {'batch_size': args.batch_size,
                          'shuffle': False,  # the slice order is needed to rebuild the volumes
                          'num_workers': args.num_cpu,
                          'drop_last': False,
                          'pin_memory': False
                          }

    test_set = LiTSDataset(args.test_filepath,
                           dtype=np.float32,
                           pixelwise_transform=data_transform,
                           inference_mode=(not args.evaluate),
                           )
    dataloader_test = torch.utils.data.DataLoader(test_set, **data_loader_params)
    # =================== Build model ===================
    if args.load_weights:
        # NOTE(review): start_ch=64 / kernel_size=3 differ from the UNet built in main()
        # (start_ch=32 / kernel_size=5); confirm they match the checkpoint being loaded.
        model = UNet(in_ch=1,
                     out_ch=3,  # there are 3 classes: 0: background, 1: liver, 2: tumor
                     depth=4,
                     start_ch=64,
                     inc_rate=2,
                     kernel_size=3,
                     padding=True,
                     batch_norm=True,
                     spec_norm=False,
                     dropout=0.5,
                     up_mode='upconv',
                     include_top=True,
                     include_last_act=False,
                     )
        # map_location lets a checkpoint saved on a GPU be loaded on the selected device
        model.load_state_dict(torch.load(args.load_weights, map_location=device))
        printYellow("Successfully loaded pretrained weights.")
    else:
        # load entire model
        model = torch.load(args.load_model, map_location=device)
        printYellow("Successfully loaded pretrained model.")
    model.eval()
    model.to(device)

    sigmoid_act = torch.nn.Sigmoid()
    st = time.time()

    # Per-volume metadata exposed by the dataset
    volume_start_index = test_set.volume_start_index
    spacing = test_set.spacing
    direction = test_set.direction  # use it for the submission
    offset = test_set.offset

    msk_pred_buffer = []
    if args.evaluate:
        msk_gt_buffer = []

    for data_batch in tqdm(dataloader_test):
        if args.evaluate:
            img, msk_gt = data_batch
            msk_gt_buffer.append(msk_gt.cpu().detach().numpy())
        else:
            img = data_batch
        img = img.to(device)
        with torch.no_grad():
            msk_pred = model(img)  # shape (N, 3, H, W)
            msk_pred = sigmoid_act(msk_pred)
        msk_pred_buffer.append(msk_pred.cpu().detach().numpy())

    msk_pred_buffer = np.vstack(msk_pred_buffer)  # shape (N, 3, H, W)
    if args.evaluate:
        msk_gt_buffer = np.vstack(msk_gt_buffer)

    results = []
    for vol_ind, vol_start_ind in enumerate(volume_start_index):
        # Slices of volume `vol_ind` span [start, next volume's start); the last one runs to the end
        if vol_ind == len(volume_start_index) - 1:
            vol_end_ind = None
        else:
            vol_end_ind = volume_start_index[vol_ind+1]
        volume_msk = msk_pred_buffer[vol_start_ind:vol_end_ind]  # shape (N, 3, H, W)

        if args.evaluate:
            volume_msk_gt = msk_gt_buffer[vol_start_ind:vol_end_ind]
            # NOTE(review): the prediction is (N, H, W) here; confirm the ground truth
            # has the shape get_scores expects (training squeezes a (N, 1, H, W) mask).
            # liver (label >= 1, i.e. liver including tumor)
            liver_scores = get_scores(volume_msk[:, 1] >= 0.5, volume_msk_gt >= 1, spacing[vol_ind])
            # tumor
            lesion_scores = get_scores(volume_msk[:, 2] >= 0.5, volume_msk_gt == 2, spacing[vol_ind])
            print("Liver dice", liver_scores['dice'], "Lesion dice", lesion_scores['dice'])
            results.append([vol_ind, liver_scores, lesion_scores])
        elif args.save2dir:
            # Labels: 0 background, 1 liver, 2 tumor (tumor has priority over liver).
            # The former "reverse channels + argmax" wrote tumor as 0 and background as 2.
            msk_pred = _probabilities_to_labels(volume_msk)  # shape (N, H, W) = (z, x, y)
            msk_pred = np.transpose(msk_pred, axes=(1, 2, 0))  # shape (x, y, z)
            # remember to correct 'direction' and np.transpose before the submission !!!
            if direction[vol_ind][0] == -1:
                # x-axis
                msk_pred = msk_pred[::-1, ...]
            if direction[vol_ind][1] == -1:
                # y-axis
                msk_pred = msk_pred[:, ::-1, :]
            if direction[vol_ind][2] == -1:
                # z-axis
                msk_pred = msk_pred[..., ::-1]
            # save medical image header as well
            # see: http://loli.github.io/medpy/generated/medpy.io.header.Header.html
            file_header = med_header(spacing=tuple(spacing[vol_ind]),
                                     offset=tuple(offset[vol_ind]),
                                     direction=np.diag(direction[vol_ind]))
            # submission guide:
            # see: https://github.com/PatrickChrist/LITS-CHALLENGE/blob/master/submission-guide.md
            # test-segmentation-X.nii
            filepath = os.path.join(args.save2dir, f"test-segmentation-{vol_ind}.nii")
            med_save(msk_pred, filepath, hdr=file_header)

    # Scores only exist in evaluation mode; do not drop an empty results.pkl
    # next to the submission files in inference mode.
    if args.save2dir and args.evaluate:
        outpath = os.path.join(args.save2dir, "results.pkl")
        final_result = {
            'liver': defaultdict(list),
            'tumor': defaultdict(list),
        }
        for vol_ind, liver_scores, lesion_scores in results:
            # [OTC] assuming vol_ind is continuous
            for key in liver_scores:
                final_result['liver'][key].append(liver_scores[key])
            for key in lesion_scores:
                final_result['tumor'][key].append(lesion_scores[key])
        with open(outpath, "wb") as file:
            pickle.dump(final_result, file, protocol=3)
    printGreen(f"Total elapsed time: {time.time()-st}")
    return results
Example #4
0
    def learn(self, images_path, actions, rewards, episode_starts):
        """
        Learn a state representation
        :param images_path: (numpy 1D array)
        :param actions: (np.ndarray)
        :param rewards: (numpy 1D array)
        :param episode_starts: (numpy 1D array) boolean array
                                the ith index is True if one episode starts at this frame
        :return: (np.ndarray) the learned states for the given observations
        """

        print("\nYour are using the following weights for the losses:")
        pprint(self.losses_weights_dict)

        # PREPARE DATA -------------------------------------------------------------------------------------------------
        # here, we organize the data into minibatches
        # and find pairs for the respective loss terms (for robotics priors only)

        num_samples = images_path.shape[0] - 1  # number of samples

        # indices for all time steps where the episode continues
        indices = np.array([i for i in range(num_samples) if not episode_starts[i + 1]], dtype='int64')
        np.random.shuffle(indices)

        # split indices into minibatches. minibatchlist is a list of lists; each
        # list is the id of the observation preserved through the training
        minibatchlist = [np.array(sorted(indices[start_idx:start_idx + self.batch_size]))
                         for start_idx in range(0, len(indices) - self.batch_size + 1, self.batch_size)]

        test_minibatchlist = DataLoader.createTestMinibatchList(len(images_path), MAX_BATCH_SIZE_GPU)

        # Number of minibatches used for validation:
        n_val_batches = np.round(VALIDATION_SIZE * len(minibatchlist)).astype(np.int64)
        val_indices = np.random.permutation(len(minibatchlist))[:n_val_batches]
        # Print some info
        print("{} minibatches for training, {} samples".format(len(minibatchlist) - n_val_batches,
                                                               (len(minibatchlist) - n_val_batches) * BATCH_SIZE))
        print("{} minibatches for validation, {} samples".format(n_val_batches, n_val_batches * BATCH_SIZE))
        assert n_val_batches > 0, "Not enough sample to create a validation set"

        # Stats about actions
        if not self.continuous_action:
            print('Discrete action space:')
            action_set = set(actions)
            n_actions = int(np.max(actions) + 1)
            print("{} unique actions / {} actions".format(len(action_set), n_actions))
            n_pairs_per_action = np.zeros(n_actions, dtype=np.int64)
            n_obs_per_action = np.zeros(n_actions, dtype=np.int64)
            for i in range(n_actions):
                n_obs_per_action[i] = np.sum(actions == i)

            print("Number of observations per action")
            print(n_obs_per_action)

        else:
            print('Continuous action space:')
            print('Action dimension: {}'.format(self.dim_action))

        dissimilar_pairs, same_actions_pairs = None, None
        if not self.no_priors:
            if self.continuous_action:
                print('This option (priors) doesnt support continuous action space for now !')

            dissimilar_pairs, same_actions_pairs = findPriorsPairs(self.batch_size, minibatchlist, actions, rewards,
                                                                   n_actions, n_pairs_per_action)

        if self.use_vae and self.perceptual_similarity_loss and self.path_to_dae is not None:

            self.denoiser = SRLModules(state_dim=self.state_dim_dae, action_dim=self.dim_action,
                                       model_type="custom_cnn",
                                       cuda=self.cuda, losses=["dae"])
            self.denoiser.load_state_dict(th.load(self.path_to_dae))
            self.denoiser.eval()
            self.denoiser = self.denoiser.to(self.device)
            for param in self.denoiser.parameters():
                param.requires_grad = False

        if self.episode_prior:
            idx_to_episode = {idx: episode_idx for idx, episode_idx in enumerate(np.cumsum(episode_starts))}
            minibatch_episodes = [[idx_to_episode[i] for i in minibatch] for minibatch in minibatchlist]

        data_loader = DataLoader(minibatchlist, images_path, n_workers=N_WORKERS, multi_view=self.multi_view,
                                 use_triplets=self.use_triplets, is_training=True, apply_occlusion=self.use_dae,
                                 occlusion_percentage=self.occlusion_percentage)
        test_data_loader = DataLoader(test_minibatchlist, images_path, n_workers=N_WORKERS, multi_view=self.multi_view,
                                      use_triplets=self.use_triplets, max_queue_len=1, is_training=False,
                                      apply_occlusion=self.use_dae, occlusion_percentage=self.occlusion_percentage)
        # TRAINING -----------------------------------------------------------------------------------------------------
        loss_history = defaultdict(list)

        loss_manager = LossManager(self.model, loss_history)

        best_error = np.inf
        best_model_path = "{}/srl_model.pth".format(self.log_folder)
        start_time = time.time()

        # Random features, we don't need to train a model
        if len(self.losses) == 1 and self.losses[0] == 'random':
            global N_EPOCHS
            N_EPOCHS = 0
            printYellow("Skipping training because using random features")
            th.save(self.model.state_dict(), best_model_path)

        for epoch in range(N_EPOCHS):
            # In each epoch, we do a full pass over the training data:
            epoch_loss, epoch_batches = 0, 0
            val_loss = 0
            pbar = tqdm(total=len(minibatchlist))

            for minibatch_num, (minibatch_idx, obs, next_obs, noisy_obs, next_noisy_obs) in enumerate(data_loader):

                validation_mode = minibatch_idx in val_indices
                if validation_mode:
                    self.model.eval()
                else:
                    self.model.train()

                if self.use_dae:
                    noisy_obs = noisy_obs.to(self.device)
                    next_noisy_obs = next_noisy_obs.to(self.device)
                obs, next_obs = obs.to(self.device), next_obs.to(self.device)

                self.optimizer.zero_grad()
                loss_manager.resetLosses()

                decoded_obs, decoded_next_obs = None, None
                states_denoiser = None
                states_denoiser_predicted = None
                next_states_denoiser = None
                next_states_denoiser_predicted = None

                # Predict states given observations as in Time Contrastive Network (Triplet Loss) [Sermanet et al.]
                if self.use_triplets:
                    states, positive_states, negative_states = self.model.forwardTriplets(obs[:, :3:, :, :],
                                                                                          obs[:, 3:6, :, :],
                                                                                          obs[:, 6:, :, :])

                    next_states, next_positive_states, next_negative_states = self.model.forwardTriplets(
                        next_obs[:, :3:, :, :],
                        next_obs[:, 3:6, :, :],
                        next_obs[:, 6:, :, :])
                elif self.use_autoencoder:
                    (states, decoded_obs), (next_states, decoded_next_obs) = self.model(obs), self.model(next_obs)

                elif self.use_dae:
                    (states, decoded_obs), (next_states, decoded_next_obs) = \
                        self.model(noisy_obs), self.model(next_noisy_obs)

                elif self.use_vae:
                    (decoded_obs, mu, logvar), (next_decoded_obs, next_mu, next_logvar) = self.model(obs), \
                                                                                          self.model(next_obs)
                    states, next_states = self.model.getStates(obs), self.model.getStates(next_obs)

                    if self.perceptual_similarity_loss:
                        # Predictions for the perceptual similarity loss as in DARLA
                        # https://arxiv.org/pdf/1707.08475.pdf
                        (states_denoiser, decoded_obs_denoiser), (next_states_denoiser, decoded_next_obs_denoiser) = \
                            self.denoiser(obs), self.denoiser(next_obs)

                        (states_denoiser_predicted, decoded_obs_denoiser_predicted) = self.denoiser(decoded_obs)
                        (next_states_denoiser_predicted,
                         decoded_next_obs_denoiser_predicted) = self.denoiser(next_decoded_obs)
                else:
                    states, next_states = self.model(obs), self.model(next_obs)

                # Actions associated to the observations of the current minibatch
                actions_st = actions[minibatchlist[minibatch_idx]]
                if not self.continuous_action:
                    # Discrete actions, rearrange action to have n_minibatch ligns and one column, containing the int action
                    actions_st = th.from_numpy(actions_st).view(-1, 1).requires_grad_(False).to(self.device)
                else:
                    # Continuous actions, rearrange action to have n_minibatch ligns and dim_action columns
                    actions_st = th.from_numpy(actions_st).view(-1, self.dim_action).requires_grad_(False).to(self.device)

                # L1 regularization
                if self.losses_weights_dict['l1_reg'] > 0:
                    l1Loss(loss_manager.reg_params, self.losses_weights_dict['l1_reg'], loss_manager)

                if self.losses_weights_dict['l2_reg'] > 0:
                    l2Loss(loss_manager.reg_params, self.losses_weights_dict['l2_reg'], loss_manager)

                if not self.no_priors:
                    if self.n_actions == np.inf:
                        print('This option (priors) doesnt support continuous action space for now !')

                    roboticPriorsLoss(states, next_states, minibatch_idx=minibatch_idx,
                                      dissimilar_pairs=dissimilar_pairs, same_actions_pairs=same_actions_pairs,
                                      weight=self.losses_weights_dict['priors'], loss_manager=loss_manager)

                # TODO change here to classic call (forward and backward)
                if self.use_forward_loss:
                    next_states_pred = self.model.forwardModel(states, actions_st)
                    forwardModelLoss(next_states_pred, next_states,
                                     weight=self.losses_weights_dict['forward'],
                                     loss_manager=loss_manager)

                if self.use_inverse_loss:
                    actions_pred = self.model.inverseModel(states, next_states)
                    inverseModelLoss(actions_pred, actions_st, weight=self.losses_weights_dict['inverse'],
                                     loss_manager=loss_manager, continuous_action=self.continuous_action)

                if self.use_reward_loss:
                    rewards_st = rewards[minibatchlist[minibatch_idx]].copy()
                    # Removing negative reward
                    rewards_st[rewards_st == -1] = 0
                    rewards_st = th.from_numpy(rewards_st).to(self.device)
                    rewards_pred = self.model.rewardModel(states, next_states)
                    rewardModelLoss(rewards_pred, rewards_st.long(), weight=self.losses_weights_dict['reward'],
                                    loss_manager=loss_manager)

                if self.use_autoencoder or self.use_dae:
                    loss_type = "dae" if self.use_dae else "autoencoder"
                    autoEncoderLoss(obs, decoded_obs, next_obs, decoded_next_obs,
                                    weight=self.losses_weights_dict[loss_type], loss_manager=loss_manager)

                if self.use_vae:

                    kullbackLeiblerLoss(mu, next_mu, logvar, next_logvar, loss_manager=loss_manager, beta=self.beta)

                    if self.perceptual_similarity_loss:
                        perceptualSimilarityLoss(states_denoiser, states_denoiser_predicted, next_states_denoiser,
                                                 next_states_denoiser_predicted,
                                                 weight=self.losses_weights_dict['perceptual'],
                                                 loss_manager=loss_manager)
                    else:
                        generationLoss(decoded_obs, next_decoded_obs, obs, next_obs,
                                       weight=self.losses_weights_dict['vae'], loss_manager=loss_manager)

                if self.reward_prior:
                    rewards_st = rewards[minibatchlist[minibatch_idx]]
                    rewards_st = th.from_numpy(rewards_st).float().view(-1, 1).to(self.device)
                    rewardPriorLoss(states, rewards_st, weight=self.losses_weights_dict['reward-prior'],
                                    loss_manager=loss_manager)

                if self.episode_prior:
                    episodePriorLoss(minibatch_idx, minibatch_episodes, states, self.discriminator,
                                     BALANCED_SAMPLING, weight=self.losses_weights_dict['episode-prior'],
                                     loss_manager=loss_manager)
                if self.use_triplets:
                    tripletLoss(states, positive_states, negative_states, weight=self.losses_weights_dict['triplet'],
                                loss_manager=loss_manager, alpha=0.2)
                # Compute weighted average of losses
                loss_manager.updateLossHistory()
                loss = loss_manager.computeTotalLoss()

                # We have to call backward in both train/val
                # to avoid memory error
                loss.backward()
                if validation_mode:
                    val_loss += loss.item()
                    # We do not optimize on validation data
                    # so optimizer.step() is not called
                else:
                    self.optimizer.step()
                    epoch_loss += loss.item()
                    epoch_batches += 1
                pbar.update(1)
            pbar.close()

            train_loss = epoch_loss / float(epoch_batches)
            val_loss /= float(n_val_batches)
            # Even if loss_history is modified by LossManager
            # we make it explicit
            loss_history = loss_manager.loss_history
            loss_history['train_loss'].append(train_loss)
            loss_history['val_loss'].append(val_loss)
            for key in loss_history.keys():
                if key in ['train_loss', 'val_loss']:
                    continue
                loss_history[key][-1] /= epoch_batches
                if epoch + 1 < N_EPOCHS:
                    loss_history[key].append(0)

            # Save best model
            if val_loss < best_error:
                best_error = val_loss
                th.save(self.model.state_dict(), best_model_path)

            if np.isnan(train_loss):
                printRed("NaN Loss, consider increasing NOISE_STD in the gaussian noise layer")
                sys.exit(NAN_ERROR)

            # Then we print the results for this epoch:
            if (epoch + 1) % EPOCH_FLAG == 0:
                print("Epoch {:3}/{}, train_loss:{:.4f} val_loss:{:.4f}".format(epoch + 1, N_EPOCHS, train_loss,
                                                                                val_loss))
                print("{:.2f}s/epoch".format((time.time() - start_time) / (epoch + 1)))
                if DISPLAY_PLOTS:
                    with th.no_grad():
                        self.model.eval()
                        # Optionally plot the current state space
                        plotRepresentation(self.predStatesWithDataLoader(test_data_loader), rewards,
                                           add_colorbar=epoch == 0,
                                           name="Learned State Representation (Training Data)")

                        if self.use_autoencoder or self.use_vae or self.use_dae:
                            # Plot Reconstructed Image
                            if obs[0].shape[0] == 3:  # RGB
                                plotImage(deNormalize(detachToNumpy(obs[0])), "Input Image (Train)")
                                if self.use_dae:
                                    plotImage(deNormalize(detachToNumpy(noisy_obs[0])), "Noisy Input Image (Train)")
                                if self.perceptual_similarity_loss:
                                    plotImage(deNormalize(detachToNumpy(decoded_obs_denoiser[0])),
                                              "Reconstructed Image DAE")
                                    plotImage(deNormalize(detachToNumpy(decoded_obs_denoiser_predicted[0])),
                                              "Reconstructed Image predicted DAE")
                                plotImage(deNormalize(detachToNumpy(decoded_obs[0])), "Reconstructed Image")

                            elif obs[0].shape[0] % 3 == 0:  # Multi-RGB
                                for k in range(obs[0].shape[0] // 3):
                                    plotImage(deNormalize(detachToNumpy(obs[0][k * 3:(k + 1) * 3, :, :]), "image_net"),
                                              "Input Image {} (Train)".format(k + 1))
                                    if self.use_dae:
                                        plotImage(deNormalize(detachToNumpy(noisy_obs[0][k * 3:(k + 1) * 3, :, :])),
                                                  "Noisy Input Image (Train)".format(k + 1))
                                    if self.perceptual_similarity_loss:
                                        plotImage(deNormalize(
                                            detachToNumpy(decoded_obs_denoiser[0][k * 3:(k + 1) * 3, :, :])),
                                            "Reconstructed Image DAE")
                                        plotImage(deNormalize(
                                            detachToNumpy(decoded_obs_denoiser_predicted[0][k * 3:(k + 1) * 3, :, :])),
                                            "Reconstructed Image predicted DAE")
                                    plotImage(deNormalize(detachToNumpy(decoded_obs[0][k * 3:(k + 1) * 3, :, :])),
                                              "Reconstructed Image {}".format(k + 1))

        if DISPLAY_PLOTS:
            plt.close("Learned State Representation (Training Data)")

        # Load best model before predicting states
        self.model.load_state_dict(th.load(best_model_path))

        print("Predicting states for all the observations...")
        # return predicted states for training observations
        self.model.eval()
        with th.no_grad():
            pred_states = self.predStatesWithDataLoader(test_data_loader)
        pairs_loss_weight = [k for k in zip(loss_manager.names, loss_manager.weights)]
        return loss_history, pred_states, pairs_loss_weight
Example #5
0
    def __init__(self, state_dim, model_type="resnet", inverse_model_type="linear", log_folder="logs/default",
                 seed=1, learning_rate=0.001, l1_reg=0.0, l2_reg=0.0, cuda=False,
                 multi_view=False, losses=None, losses_weights_dict=None, n_actions=6, continuous_action=False, beta=1,
                 split_dimensions=-1, path_to_dae=None, state_dim_dae=200, occlusion_percentage=None):
        """
        Build the SRL network, its Adam optimizer and the table of loss weights.

        :param state_dim: forwarded to the parent constructor; the network is then built
            from ``self.state_dim`` (presumably set by the parent -- TODO confirm)
        :param model_type: (str) "linear", "mlp", "resnet" and "custom_cnn" are always accepted;
            any other value is accepted only when "autoencoder" or "vae" is in ``losses``
        :param inverse_model_type: (str) forwarded unchanged to the network constructor
        :param log_folder: (str) stored as ``self.log_folder``
        :param seed: forwarded to the parent constructor
        :param learning_rate: (float) learning rate given to ``th.optim.Adam``
        :param l1_reg: default weight stored under the 'l1_reg' key of ``self.losses_weights_dict``
        :param l2_reg: default weight stored under the 'l2_reg' key of ``self.losses_weights_dict``
        :param cuda: (bool) the "cuda" device is selected only if this is truthy AND
            ``th.cuda.is_available()``; otherwise "cpu" is used
        :param multi_view: stored as ``self.multi_view``
        :param losses: collection of loss names (e.g. "forward", "inverse", "reward", "priors",
            "episode-prior", "reward-prior", "autoencoder", "vae", "triplet", "perceptual", "dae").
            ``None`` (the default) is treated as "no loss selected"
        :param losses_weights_dict: (dict or None) entries overriding the default loss weights
        :param n_actions: stored as ``self.dim_action`` and forwarded to the network as ``action_dim``
        :param continuous_action: (bool) stored, and forwarded to ``SRLModules`` only
            (the ``SRLModulesSplit`` call does not receive it)
        :param beta: stored as ``self.beta``
        :param split_dimensions: when this is an ``OrderedDict`` whose values sum to more than 0,
            ``SRLModulesSplit`` is used instead of ``SRLModules``
        :param path_to_dae: stored as ``self.path_to_dae``
        :param state_dim_dae: stored as ``self.state_dim_dae``
        :param occlusion_percentage: stored; reported on stdout when the "dae" loss is selected
        :raises ValueError: if ``model_type`` is unknown and neither "autoencoder" nor "vae" is requested
        """
        super(SRL4robotics, self).__init__(state_dim, BATCH_SIZE, seed, cuda)

        # The signature defaults `losses` to None, but every flag below is a
        # membership test: normalise once so the default does not raise TypeError.
        if losses is None:
            losses = []

        self.multi_view = multi_view
        self.losses = losses
        self.dim_action = n_actions
        self.continuous_action = continuous_action
        self.beta = beta
        self.denoiser = None

        # Guard clause: reject unsupported configurations before building anything
        is_known_model = model_type in ["linear", "mlp", "resnet", "custom_cnn"]
        if not (is_known_model or "autoencoder" in losses or "vae" in losses):
            raise ValueError("Unknown model: {}".format(model_type))

        # One boolean flag per loss name
        self.use_forward_loss = "forward" in losses
        self.use_inverse_loss = "inverse" in losses
        self.use_reward_loss = "reward" in losses
        self.no_priors = "priors" not in losses
        self.episode_prior = "episode-prior" in losses
        self.reward_prior = "reward-prior" in losses
        self.use_autoencoder = "autoencoder" in losses
        self.use_vae = "vae" in losses
        self.use_triplets = "triplet" in losses
        self.perceptual_similarity_loss = "perceptual" in losses
        self.use_dae = "dae" in losses
        self.path_to_dae = path_to_dae

        if isinstance(split_dimensions, OrderedDict) and sum(split_dimensions.values()) > 0:
            printYellow("Using splitted representation")
            self.model = SRLModulesSplit(state_dim=self.state_dim, action_dim=self.dim_action,
                                         model_type=model_type, cuda=cuda, losses=losses,
                                         split_dimensions=split_dimensions, inverse_model_type=inverse_model_type)
        else:
            self.model = SRLModules(state_dim=self.state_dim, action_dim=self.dim_action,
                                    continuous_action=self.continuous_action, model_type=model_type,
                                    cuda=cuda, losses=losses, inverse_model_type=inverse_model_type)

        print("Using {} model".format(model_type))

        self.cuda = cuda
        # Fall back to the CPU when CUDA is requested but not available
        self.device = th.device("cuda" if th.cuda.is_available() and cuda else "cpu")

        if self.episode_prior:
            self.discriminator = Discriminator(2 * self.state_dim).to(self.device)

        self.model = self.model.to(self.device)

        # Only parameters that require a gradient are handed to the optimizer
        learnable_params = [param for param in self.model.parameters() if param.requires_grad]

        if self.episode_prior:
            # The discriminator is trained by the same optimizer as the model
            learnable_params.extend(self.discriminator.parameters())

        self.optimizer = th.optim.Adam(learnable_params, lr=learning_rate)
        self.log_folder = log_folder
        self.model_type = model_type

        # Default weights that are updated with the weights passed to the script
        self.losses_weights_dict = {"forward": 1.0, "inverse": 2.0, "reward": 1.0, "priors": 1.0,
                                    "episode-prior": 1.0, "reward-prior": 10, "triplet": 1.0,
                                    "autoencoder": 1.0, "vae": 0.5e-6, "perceptual": 1e-6, "dae": 1.0,
                                    'l1_reg': l1_reg, "l2_reg": l2_reg, 'random': 1.0}
        self.occlusion_percentage = occlusion_percentage
        self.state_dim_dae = state_dim_dae

        if losses_weights_dict is not None:
            self.losses_weights_dict.update(losses_weights_dict)

        if self.use_dae and self.occlusion_percentage is not None:
            print("Using a maximum occlusion surface of {}".format(str(self.occlusion_percentage)))
Example #6
0
        printGreen("\n Grid search on several state_dim on dataset folder: {} \n".format(exp_config['data-folder']))

        createFolder("logs/{}".format(exp_config['data-folder']), "Dataset log folder already exist")
        createFolder("logs/{}/baselines".format(exp_config['data-folder']), "Baseline folder already exist")

        # Check that the dataset is already preprocessed
        preprocessingCheck(exp_config)

        # Grid search
        for seed in [0]:
            exp_config['seed'] = seed
            for state_dim in [3, 6]:
                # Update config
                exp_config['state-dim'] = state_dim
                log_folder, experiment_name = getLogFolderName(exp_config)
                exp_config['log-folder'] = log_folder
                exp_config['experiment-name'] = experiment_name
                # Save config in log folder
                saveConfig(exp_config, print_config=True)

                # Learn a state representation and plot it
                ok = stateRepresentationLearningCall(exp_config)
                if not ok:
                    printYellow("Skipping evaluation...")
                    continue
                # Evaluate the representation with kNN
                knnCall(exp_config)

    else:
        printYellow("Please specify one of --exp-config or --data-folder")