Code example #1
    def predict(self, X, y, query_label, device=0, enable_dropout=False):
        """
        Predict the output mask for a volume once the model is trained.

        Inputs:
        - X: Volume to be predicted
        - y: labels used to build the conditioner input
        - query_label: label selecting the query class for split_batch
        - device: CUDA device index the inputs are moved to
        - enable_dropout: keep dropout active at inference time if True
        """
        self.eval()

        # Split into conditioner input, query input and query target.
        condition, query, target = split_batch(X, y, query_label)
        condition = to_cuda(condition, device)
        query = to_cuda(query, device)
        target = to_cuda(target, device)

        if enable_dropout:
            self.enable_test_dropout()

        with torch.no_grad():
            logits = self.forward(condition, query)

        # Binarise at 0.5 and move the mask back to host memory.
        mask = logits > 0.5
        mask = mask.data.cpu().numpy()
        prediction = np.squeeze(mask)
        del X, logits, mask
        return prediction
Code example #2
    def train(self, train_loader, test_loader):
        """
        Train a given model with the provided data.

        Inputs:
        - train_loader: train data in torch.utils.data.DataLoader
        - test_loader: validation data in torch.utils.data.DataLoader
        """
        model, optim, scheduler = self.model, self.optim, self.scheduler

        data_loader = {'train': train_loader, 'val': test_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            model.cuda(self.device)

        self.logWriter.log(
            'START TRAINING. : model name = %s, device = %s' %
            (self.model_name, torch.cuda.get_device_name(self.device)))
        current_iteration = self.start_iteration
        # Epochs to wait before the val-loss based model switching kicks in.
        warm_up_epoch = 5
        val_old = 0
        change_model = False
        current_model = 'seg'
        for epoch in range(self.start_epoch, self.num_epochs + 1):
            # NOTE(review): log() receives two positional args here but a
            # single arg everywhere else in this method -- confirm signature.
            self.logWriter.log(
                'train', "\n==== Epoch [ %d  /  %d ] START ====" %
                (epoch, self.num_epochs))

            for phase in ['train', 'val']:
                self.logWriter.log("<<<= Phase: %s =>>>" % phase)
                loss_arr = []
                input_img_list = []
                y_list = []
                out_list = []
                condition_input_img_list = []
                condition_y_list = []

                if phase == 'train':
                    model.train()
                    # NOTE(review): pre-1.1 PyTorch ordering (scheduler before
                    # the optimizer steps); kept as-is to preserve the schedule.
                    scheduler.step()
                else:
                    model.eval()
                for i_batch, sampled_batch in enumerate(data_loader[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.LongTensor)
                    w = sampled_batch[2].type(torch.FloatTensor)

                    query_label = data_loader[phase].batch_sampler.query_label

                    # Split the batch into a support (conditioner) half and a
                    # query half for the sampled query label.
                    input1, input2, y1, y2 = split_batch(
                        X, y, int(query_label))

                    condition_input = torch.mul(input1, y1.unsqueeze(1))
                    query_input = input2

                    if model.is_cuda:
                        condition_input, query_input, y2 = condition_input.cuda(
                            self.device, non_blocking=True), query_input.cuda(
                                self.device,
                                non_blocking=True), y2.cuda(self.device,
                                                            non_blocking=True)

                    output = model(condition_input, query_input)
                    # TODO: add weights
                    loss = self.loss_func(output, y2)
                    if phase == 'train':
                        # Backpropagate only while training: the val pass only
                        # needs loss.item(), so the costly backward (previously
                        # run unconditionally) is skipped there.
                        optim.zero_grad()
                        loss.backward()
                        optim.step()

                        if i_batch % self.log_nth == 0:
                            self.logWriter.loss_per_iter(
                                loss.item(), i_batch, current_iteration)
                        current_iteration += 1

                    loss_arr.append(loss.item())

                    # Softmax is monotonic, so argmax over the softmax equals
                    # argmax over the raw logits.
                    _, batch_output = torch.max(F.softmax(output, dim=1),
                                                dim=1)

                    out_list.append(batch_output.cpu())
                    input_img_list.append(input2.cpu())
                    y_list.append(y2.cpu())
                    condition_input_img_list.append(input1.cpu())
                    condition_y_list.append(y1)

                    # Drop per-batch tensors eagerly to keep GPU memory flat.
                    del X, y, w, output, batch_output, loss, input1, input2, y2
                    torch.cuda.empty_cache()
                    if phase == 'val':
                        # Plain-text progress bar for the validation pass.
                        if i_batch != len(data_loader[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)
                if phase == 'train':
                    self.logWriter.log('saving checkpoint ....')
                    self.save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'start_iteration': current_iteration + 1,
                            'arch': self.model_name,
                            'state_dict': model.state_dict(),
                            'optimizer': optim.state_dict(),
                            'scheduler': scheduler.state_dict(),
                        },
                        os.path.join(
                            self.exp_dir_path, CHECKPOINT_DIR,
                            'checkpoint_epoch_' + str(epoch) + '.' +
                            CHECKPOINT_EXTENSION))

                with torch.no_grad():
                    input_img_arr = torch.cat(input_img_list)
                    y_arr = torch.cat(y_list)
                    out_arr = torch.cat(out_list)
                    condition_input_img_arr = torch.cat(
                        condition_input_img_list)
                    condition_y_arr = torch.cat(condition_y_list)

                    current_loss = self.logWriter.loss_per_epoch(
                        loss_arr, phase, epoch)
                    if phase == 'val':
                        if epoch > warm_up_epoch:
                            self.logWriter.log("Diff : " +
                                               str(current_loss - val_old))
                            # Flag a switch when the val loss worsened by more
                            # than 1e-3 versus the previous epoch.
                            change_model = (current_loss - val_old) > 0.001

                        if change_model and current_model == 'seg':
                            self.logWriter.log("Setting to con")
                            current_model = 'con'
                        elif change_model and current_model == 'con':
                            self.logWriter.log("Setting to seg")
                            current_model = 'seg'
                        val_old = current_loss
                    # Log three randomly chosen samples for visual inspection.
                    index = np.random.choice(len(out_arr), 3, replace=False)
                    self.logWriter.image_per_epoch(
                        out_arr[index],
                        y_arr[index],
                        phase,
                        epoch,
                        additional_image=(input_img_arr[index],
                                          condition_input_img_arr[index],
                                          condition_y_arr[index]))
                    self.logWriter.dice_score_per_epoch(
                        phase, out_arr, y_arr, epoch)

            # Fixed: previously emitted once per phase, inside the loop.
            self.logWriter.log("==== Epoch [" + str(epoch) + " / " +
                               str(self.num_epochs) + "] DONE ====")
        # Fixed: previously logged inside the epoch loop after every phase.
        self.logWriter.log('FINISH.')
        self.logWriter.close()
Code example #3
    def train(self, train_loader, test_loader):
        """
        Train a given model with the provided data.

        Uses separate optimizers/schedulers for the conditioner (optim_c) and
        the segmentor (optim_s), alternating which one steps per epoch.

        Inputs:
        - train_loader: train data in torch.utils.data.DataLoader
        - test_loader: validation data in torch.utils.data.DataLoader
        """
        model, optim_c, optim_s, scheduler_c, scheduler_s = self.model, self.optim_c, self.optim_s, self.scheduler_c, self.scheduler_s

        data_loader = {
            'train': train_loader,
            'val': test_loader
        }

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            model.cuda(self.device)

        print('START TRAINING. : model name = %s, device = %s' % (
            self.model_name, torch.cuda.get_device_name(self.device)))
        current_iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.num_epochs + 1):
            print("\n==== Epoch [ %d  /  %d ] START ====" % (epoch, self.num_epochs))
            for phase in ['train', 'val']:
                print("<<<= Phase: %s =>>>" % phase)
                loss_arr = []
                input_img_list = []
                y_list = []
                out_list = []
                condition_input_img_list = []
                condition_y_list = []

                if phase == 'train':
                    model.train()
                    # NOTE(review): pre-1.1 PyTorch ordering (scheduler before
                    # the optimizer steps); kept as-is to preserve the schedule.
                    scheduler_c.step()
                    scheduler_s.step()
                else:
                    model.eval()
                for i_batch, sampled_batch in enumerate(data_loader[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.LongTensor)
                    w = sampled_batch[2].type(torch.FloatTensor)

                    query_label = data_loader[phase].batch_sampler.query_label

                    # Split the batch into a support (conditioner) half and a
                    # query half for the sampled query label.
                    input1, input2, y1, y2 = split_batch(X, y, int(query_label))
                    condition_input = torch.mul(input1, y1.unsqueeze(1))

                    if model.is_cuda:
                        condition_input, input2, y2 = condition_input.cuda(self.device, non_blocking=True), input2.cuda(
                            self.device,
                            non_blocking=True), y2.cuda(
                            self.device, non_blocking=True)

                    # Conditioner produces weights that parameterize the segmentor.
                    weights = model.conditioner(condition_input)
                    output = model.segmentor(input2, weights)
                    # TODO: add weights
                    loss = self.loss_func(output, y2)

                    if phase == 'train':
                        # Backpropagate only while training: the val pass only
                        # needs loss.item(), so the costly backward (previously
                        # run unconditionally) is skipped there.
                        optim_s.zero_grad()
                        optim_c.zero_grad()
                        loss.backward()
                        # Warm up both nets, then alternate which one steps.
                        if epoch <= 1:
                            optim_s.step()
                            optim_c.step()
                        elif epoch in [2, 3, 6, 7, 10]:
                            optim_s.step()
                        elif epoch in [4, 5, 8, 9]:
                            optim_c.step()
                        # NOTE(review): for epoch > 10 neither optimizer steps,
                        # so training silently freezes -- confirm intended.

                        # # TODO: value needs to be optimized, Gradient Clipping (Optional)
                        # if epoch > 1:
                        #     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.0001)

                        if i_batch % self.log_nth == 0:
                            self.logWriter.loss_per_iter(loss.item(), i_batch, current_iteration)
                        current_iteration += 1

                    loss_arr.append(loss.item())

                    # Per-pixel class prediction from the raw logits.
                    _, batch_output = torch.max(output, dim=1)

                    out_list.append(batch_output.cpu())
                    input_img_list.append(input2.cpu())
                    y_list.append(y2.cpu())
                    condition_input_img_list.append(input1.cpu())
                    condition_y_list.append(y1)

                    # Drop per-batch tensors eagerly to keep GPU memory flat.
                    del X, y, w, output, batch_output, loss, input1, input2, y2
                    torch.cuda.empty_cache()
                    if phase == 'val':
                        # Plain-text progress bar for the validation pass.
                        if i_batch != len(data_loader[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)
                if phase == 'train':
                    print('saving checkpoint ....')
                    self.save_checkpoint({
                        'epoch': epoch + 1,
                        'start_iteration': current_iteration + 1,
                        'arch': self.model_name,
                        'state_dict': model.state_dict(),
                        'optimizer_c': optim_c.state_dict(),
                        'scheduler_c': scheduler_c.state_dict(),
                        'optimizer_s': optim_s.state_dict(),
                        'scheduler_s': scheduler_s.state_dict()
                    }, os.path.join(self.exp_dir_path, CHECKPOINT_DIR,
                                    'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION))

                with torch.no_grad():
                    input_img_arr = torch.cat(input_img_list)
                    y_arr = torch.cat(y_list)
                    out_arr = torch.cat(out_list)
                    condition_input_img_arr = torch.cat(condition_input_img_list)
                    condition_y_arr = torch.cat(condition_y_list)

                    self.logWriter.loss_per_epoch(loss_arr, phase, epoch)
                    # Log three randomly chosen samples for visual inspection.
                    index = np.random.choice(len(out_arr), 3, replace=False)
                    self.logWriter.image_per_epoch(out_arr[index], y_arr[index], phase, epoch, additional_image=(
                        input_img_arr[index], condition_input_img_arr[index], condition_y_arr[index]))
                    self.logWriter.dice_score_per_epoch(phase, out_arr, y_arr, epoch)

            print("==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) + "] DONE ====")
        print('FINISH.')
        self.logWriter.close()