def training_step(self, batch, batch_idx, optimizer_idx):
        self.unused_(batch_idx)  # project helper marking batch_idx as intentionally unused

        # unpack the batch; note that len(batch) would be the number of
        # elements in the tuple, not the batch size
        images, masks = batch
        batch_size = images.size(0)

        # init losses
        g_loss = model.loss.GeneratorLoss()
        d_loss = model.loss.DiscriminatorLoss()
        r_loss = model.loss.ReconLoss()

        # discriminator training step
        if optimizer_idx == 0:
            fake_images, coarse_raw, recon_raw = self(images, masks)
            all_images = torch.cat([fake_images, images], dim=0)

            double_masks = torch.cat([masks, masks], dim=0)
            all_output = self.discriminator(all_images, double_masks)

            fake_output = all_output[:batch_size]
            real_output = all_output[batch_size:]

            loss = d_loss(real_output, fake_output)
            self.log('d_loss', loss.item())
            return loss

        # generator training step
        if optimizer_idx == 1:
            fake_images, coarse_raw, recon_raw = self(images, masks)
            d_output = self.discriminator(fake_images, masks)

            loss_1 = g_loss(d_output)
            loss_2 = r_loss(images, coarse_raw, recon_raw, masks)
            loss = loss_1

            # include the reconstruction term ~90% of the time; use an
            # out-of-place add so loss_1 is not mutated before logging
            if np.random.uniform(0, 1) >= 0.1:
                loss = loss + loss_2

            self.log('g_loss', loss_1.item())
            self.log('r_loss', loss_2.item())
            return loss
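A note on wiring: the optimizer_idx branches above assume the LightningModule returns the discriminator optimizer first and the generator optimizer second. A minimal sketch of a matching configure_optimizers, assuming Adam and the attribute names self.discriminator / self.generator (which may differ in the original module):

import torch

def configure_optimizers(self):
    # optimizer_idx == 0 -> discriminator, optimizer_idx == 1 -> generator,
    # matching the branches in training_step above
    d_opt = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
    g_opt = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
    return [d_opt, g_opt]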
Example #2
def evaluate(model, loss_fn, dataloader, metrics, model_dir, hyper_params):
    """Evaluate the model.
    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: (Function) a function that takes batch_output and batch_labels and computes the loss for the batch
        dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches data
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        model_dir: (string) directory to save per-file outputs in; pass "" to skip saving
        hyper_params: (HyperParams) hyperparameters
    """

    # set model to evaluation mode
    model.eval()

    # summary for current eval loop
    summ = []

    # dictionary of results for every separate image
    all_single_metrics = {}

    with torch.no_grad():
        # compute metrics over the dataset
        for idx, (scan_batch, ground_truth_batch,
                  ground_truth_filename) in enumerate(dataloader):

            # move to GPU if available (hyper_params.cuda is a device index,
            # -1 meaning CPU, so compare explicitly instead of truth-testing)
            if hyper_params.cuda != -1:
                scan_batch, ground_truth_batch = scan_batch.to(
                    device=hyper_params.cuda), ground_truth_batch.to(
                        device=hyper_params.cuda)
            # wrap in Variables (a no-op on PyTorch >= 0.4, kept for compatibility)
            scan_batch, ground_truth_batch = Variable(scan_batch), Variable(
                ground_truth_batch)

            # compute model output
            output_batch = model(scan_batch)
            loss = loss_fn(output_batch, ground_truth_batch)

            if model_dir != "":
                # compute the loss for every single file of this batch;
                # slicing with [i:i + 1] keeps the batch dimension and stays
                # on whatever device the tensors already live on
                for i in range(output_batch.shape[0]):
                    filename = str(Path(ground_truth_filename[i]).parts[-1])
                    all_single_metrics[filename] = {
                        'loss':
                        loss_fn(output_batch[i:i + 1],
                                ground_truth_batch[i:i + 1]).item()
                    }

            # binarize the output (sigmoid, then threshold);
            # 'treshold' matches the HyperParams attribute name
            output_batch = torch.sigmoid(output_batch)
            output_batch = (output_batch > hyper_params.treshold).float()

            # move to CPU and convert to numpy arrays
            output_batch = output_batch.cpu().numpy()
            ground_truth_batch = ground_truth_batch.cpu().numpy()
            error_batch = np.abs(ground_truth_batch - output_batch)

            # save result images
            if model_dir != "":
                for i in range(output_batch.shape[0]):
                    filename = str(Path(ground_truth_filename[i]).parts[-1])
                    # PIL mode 'I' expects 32-bit integer pixels
                    image = Image.fromarray(
                        output_batch[i][0].astype(np.int32), 'I')
                    image.save(
                        Path(model_dir) / filename.replace(".png", "_SEG.png"))
                    image = Image.fromarray(
                        error_batch[i][0].astype(np.int32), 'I')
                    image.save(
                        Path(model_dir) /
                        filename.replace(".png", "_SEG_ERROR.png"))

            # compute all metrics on this batch
            summary_batch = {
                metric: metrics[metric](output_batch, ground_truth_batch)
                for metric in metrics
            }
            summary_batch['loss'] = loss.item()
            summ.append(summary_batch)

            if model_dir != "":
                # compute all metrics for every single file of this batch,
                # preserving the per-file loss computed above
                for i in range(output_batch.shape[0]):
                    filename = str(Path(ground_truth_filename[i]).parts[-1])
                    saved_loss = all_single_metrics[filename]['loss']
                    all_single_metrics[filename] = {
                        metric: metrics[metric](output_batch[i:i + 1],
                                                ground_truth_batch[i:i + 1])
                        for metric in metrics
                    }
                    all_single_metrics[filename]['loss'] = saved_loss

    # compute mean of all metrics in summary
    metrics_mean = {
        metric: np.mean([x[metric] for x in summ])
        for metric in summ[0]
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v)
                                for k, v in metrics_mean.items())
    logging.info("- Eval metrics : " + metrics_string)

    if model_dir != "":
        sorted_all_single_metrics = {}
        for key in sorted(all_single_metrics,
                          key=lambda x: (all_single_metrics[x]['dsc']),
                          reverse=True):
            sorted_all_single_metrics[key] = all_single_metrics[key]
        # write all_single_metrics to a json
        all_single_metrics_path = str(
            Path(model_dir) / "metrics_test_single_file_names.json")
        utils.save_dict_to_json(sorted_all_single_metrics,
                                all_single_metrics_path)
    return metrics_mean
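The sorting step above looks up a 'dsc' key, so the metrics argument presumably includes a Dice score under that name. A minimal sketch of such a dictionary, assuming binarized numpy arrays as inputs (the project's actual metric implementations may differ):

import numpy as np

def dice_coefficient(output, ground_truth, eps=1e-7):
    # Dice = 2 * |A ∩ B| / (|A| + |B|) on binarized arrays
    intersection = np.sum(output * ground_truth)
    return (2.0 * intersection + eps) / (
        np.sum(output) + np.sum(ground_truth) + eps)

metrics = {'dsc': dice_coefficient}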
Example #3
    def calc_losses(self, data, is_train=True, global_step=0):
        if "images" not in data:
            return {}
        all_images = data["images"].to(device=device)  # (SB, NV, 3, H, W)

        SB, NV, _, H, W = all_images.shape
        all_poses = data["poses"].to(device=device)  # (SB, NV, 4, 4)
        all_bboxes = data.get("bbox")  # (SB, NV, 4)  cmin rmin cmax rmax
        all_focals = data["focal"]  # (SB)
        all_c = data.get("c")  # (SB)

        if self.use_bbox and global_step >= args.no_bbox_step:
            self.use_bbox = False
            print(">>> Stopped using bbox sampling @ iter", global_step)

        if not is_train or not self.use_bbox:
            all_bboxes = None

        all_rgb_gt = []
        all_rays = []

        curr_nviews = nviews[torch.randint(0, len(nviews), ()).item()]
        if curr_nviews == 1:
            image_ord = torch.randint(0, NV, (SB, 1))
        else:
            image_ord = torch.empty((SB, curr_nviews), dtype=torch.long)
        for obj_idx in range(SB):
            if all_bboxes is not None:
                bboxes = all_bboxes[obj_idx]
            images = all_images[obj_idx]  # (NV, 3, H, W)
            poses = all_poses[obj_idx]  # (NV, 4, 4)
            focal = all_focals[obj_idx]
            c = None
            if "c" in data:
                c = data["c"][obj_idx]
            if curr_nviews > 1:
                # Somewhat inefficient, don't know better way
                image_ord[obj_idx] = torch.from_numpy(
                    np.random.choice(NV, curr_nviews, replace=False))
            images_0to1 = images * 0.5 + 0.5

            cam_rays = util.gen_rays(poses,
                                     W,
                                     H,
                                     focal,
                                     self.z_near,
                                     self.z_far,
                                     c=c)  # (NV, H, W, 8)
            rgb_gt_all = images_0to1
            rgb_gt_all = (rgb_gt_all.permute(0, 2, 3,
                                             1).contiguous().reshape(-1, 3)
                          )  # (NV * H * W, 3)

            if all_bboxes is not None:
                pix = util.bbox_sample(bboxes, args.ray_batch_size)
                # flatten (view, row, col) triples into indices over NV*H*W rays
                pix_inds = pix[..., 0] * H * W + pix[..., 1] * W + pix[..., 2]
            else:
                pix_inds = torch.randint(0, NV * H * W,
                                         (args.ray_batch_size, ))

            rgb_gt = rgb_gt_all[pix_inds]  # (ray_batch_size, 3)
            rays = cam_rays.view(-1, cam_rays.shape[-1])[pix_inds].to(
                device=device)  # (ray_batch_size, 8)

            all_rgb_gt.append(rgb_gt)
            all_rays.append(rays)

        all_rgb_gt = torch.stack(all_rgb_gt)  # (SB, ray_batch_size, 3)
        all_rays = torch.stack(all_rays)  # (SB, ray_batch_size, 8)

        image_ord = image_ord.to(device)
        src_images = util.batched_index_select_nd(
            all_images, image_ord)  # (SB, NS, 3, H, W)
        src_poses = util.batched_index_select_nd(all_poses,
                                                 image_ord)  # (SB, NS, 4, 4)

        all_bboxes = all_poses = all_images = None  # release references before rendering

        net.encode(
            src_images,
            src_poses,
            all_focals.to(device=device),
            c=all_c.to(device=device) if all_c is not None else None,
        )

        render_dict = DotMap(render_par(
            all_rays,
            want_weights=True,
        ))
        coarse = render_dict.coarse
        fine = render_dict.fine
        using_fine = len(fine) > 0

        loss_dict = {}

        rgb_loss = self.rgb_coarse_crit(coarse.rgb, all_rgb_gt)
        loss_dict["rc"] = rgb_loss.item() * self.lambda_coarse
        if using_fine:
            fine_loss = self.rgb_fine_crit(fine.rgb, all_rgb_gt)
            rgb_loss = rgb_loss * self.lambda_coarse + fine_loss * self.lambda_fine
            loss_dict["rf"] = fine_loss.item() * self.lambda_fine

        loss = rgb_loss
        if is_train:
            loss.backward()
        loss_dict["t"] = loss.item()

        return loss_dict
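util.batched_index_select_nd is a project helper; judging from its use here, it gathers the source views selected in image_ord independently for each scene in the batch. A sketch of what such a helper could look like (an assumption, not necessarily the project's implementation):

import torch

def batched_index_select_nd(t, inds):
    # t: (SB, NV, ...), inds: (SB, NS) -> (SB, NS, ...)
    # expand inds over t's trailing dims, then gather along dim 1
    idx = inds.reshape(*inds.shape, *([1] * (t.dim() - 2)))
    idx = idx.expand(*inds.shape, *t.shape[2:])
    return t.gather(1, idx)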
Example #4
def train_model(model,
                dataloaders,
                criterion,
                optimizer,
                writer,
                scheduler,
                config,
                num_epochs=150):
    since = time.time()
    val_acc_history = []
    saved_model_name = "music_siamese_50000" + datetime.now().strftime(
        '%b%d_%H-%M-%S') + ".pth"

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    stop_times = 0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        info = {}
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0

            sum_label = torch.zeros(0)
            sum_preds = torch.zeros(0)
            for input1s, input2s, labels in tqdm(dataloaders[phase]):

                if config.multi_gpu:
                    if config.model_type == 'cnn':
                        input1s = input1s.cuda(config.device_ids[0])
                        input2s = input2s.cuda(config.device_ids[0])
                        labels = labels.cuda(config.device_ids[0])
                    else:
                        input1s = input1s.cuda(
                            config.device_ids[0]), model.init_h0().cuda(
                                config.device_ids[0])
                        input2s = input2s.cuda(
                            config.device_ids[0]), model.init_h0().cuda(
                                config.device_ids[0])
                        labels = labels.cuda(config.device_ids[0])

                else:
                    if config.model_type == 'cnn':
                        input1s = input1s.to(config.device)
                        input2s = input2s.to(config.device)
                        labels = labels.to(config.device)
                    else:
                        input1s = input1s.to(
                            config.device), model.init_h0().to(config.device)
                        input2s = input2s.to(
                            config.device), model.init_h0().to(config.device)
                        labels = labels.to(config.device)
                labels = labels.squeeze()

                sum_label = torch.cat((sum_label, labels.cpu()))
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):

                    output1s, output2s = model(input1s, input2s)
                    loss = criterion(output1s, output2s, labels)
                    # _, preds = outputs.topk(1, 1, True, True)
                    # _, preds = torch.max(outputs, 1)

                    preds = F.pairwise_distance(output1s,
                                                output2s) < config.evalue_thr
                    sum_preds = torch.cat((sum_preds, preds.cpu().float()))

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        # nn.utils.clip_grad_norm(model.parameters(), max_norm=0.1)
                        optimizer.step()

                # statistics
                running_loss += loss.item()
            running_corrects = torch.sum(sum_preds.byte() == sum_label.byte())
            tp, fp, tn, fn = 0, 0, 0, 0
            for pred, label in zip(sum_preds.byte(), sum_label.byte()):
                if pred == label:
                    if pred == 1:
                        tp += 1
                    else:
                        tn += 1
                else:
                    if pred == 1:
                        fp += 1
                    else:
                        fn += 1

            epoch_loss = running_loss / len(dataloaders[phase])
            epoch_acc = running_corrects.double() / len(
                dataloaders[phase].dataset)
            # note: tp * 2 / N equals the true-positive rate only if half of
            # the pairs are positive (and likewise for tn)
            epoch_tpr = tp * 2 / len(dataloaders[phase].dataset)
            epoch_tnr = tn * 2 / len(dataloaders[phase].dataset)
            print("corrects sum {}".format(running_corrects))
            print("epoch_tpr: {}".format(epoch_tpr))
            print("epoch_tnr: {}".format(epoch_tnr))

            # epoch_acc = np.mean(mAP)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_acc))
            info[phase] = {'acc': epoch_acc, 'loss': epoch_loss}

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                stop_times = 0
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), saved_model_name)
            if phase == 'val' and epoch_acc < best_acc:
                stop_times = stop_times + 1

            if phase == 'val':
                val_acc_history.append(epoch_acc)

        writer.add_scalars('data/acc', {
            'train': info["train"]['acc'],
            'val': info["val"]['acc']
        }, epoch)
        writer.add_scalars('data/loss', {
            'train': info["train"]['loss'],
            'val': info["val"]['loss']
        }, epoch)
        scheduler.step(info["val"]['loss'])
        if stop_times >= 20:
            break
    time_elapsed = time.time() - since

    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), saved_model_name)

    return model, val_acc_history
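The tp/fp/tn/fn loop above can also be written without a Python-level loop. A small equivalent sketch, reusing the variable names from the function:

import torch

preds = sum_preds.bool()
labels = sum_label.bool()
tp = int((preds & labels).sum())
tn = int((~preds & ~labels).sum())
fp = int((preds & ~labels).sum())
fn = int((~preds & labels).sum())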
Example #5
    def _train_epoch(self, epoch, phase="train"):
        """
		Training logic for an epoch

		:param epoch: Integer, current training epoch.
		:return: A log that contains average loss and metric in this epoch.
		"""
        import model.loss
        # log the current learning rate of each parameter group
        print("Current LR(s):")
        for param_group in self.optimizer.param_groups:
            print(param_group['lr'])

        if phase == "train":
            self.model.train()
            self.train_metrics.reset()
            torch.set_grad_enabled(True)
            metrics = self.train_metrics
        elif phase == "val":
            self.model.eval()
            self.valid_metrics.reset()
            torch.set_grad_enabled(False)
            metrics = self.valid_metrics

        outputs = []
        outputs_continuous = []
        targets = []
        targets_continuous = []

        data_loader = self.data_loader if phase == "train" else self.valid_data_loader

        for batch_idx, (data, embeddings, target, target_continuous,
                        lengths) in enumerate(data_loader):

            data, target, target_continuous = data.to(self.device), target.to(
                self.device), target_continuous.to(self.device)
            embeddings = embeddings.to(self.device)

            if phase == "train":
                self.optimizer.zero_grad()

            out = self.model(data, embeddings)

            loss = 0

            loss_categorical = self.criterion_categorical(
                out['categorical'], target)
            loss += loss_categorical

            loss_continuous = self.criterion_continuous(
                torch.sigmoid(out['continuous']), target_continuous)
            loss += loss_continuous

            if self.embed:
                loss_embed = model.loss.mse_center_loss(
                    out['embed'], embeddings, target)
                loss += loss_embed

            if phase == "train":
                loss.backward()
                self.optimizer.step()

            output = out['categorical'].cpu().detach().numpy()
            target = target.cpu().detach().numpy()
            outputs.append(output)
            targets.append(target)

            output_continuous = torch.sigmoid(
                out['continuous']).cpu().detach().numpy()
            target_continuous = target_continuous.cpu().detach().numpy()
            outputs_continuous.append(output_continuous)
            targets_continuous.append(target_continuous)

            if batch_idx % self.log_step == 0:
                self.logger.debug(
                    '{} Epoch: {} {} Loss: {:.6f} Loss categorical: {:.6f} Loss continuous: {:.6f}'
                    .format(phase, epoch, self._progress(batch_idx),
                            loss.item(), loss_categorical.item(),
                            loss_continuous.item()))

            if batch_idx == self.len_epoch:
                break

        if phase == "train":
            self.writer.set_step(epoch)
        else:
            self.writer.set_step(epoch, "valid")

        # note: these record the loss of the last batch only
        metrics.update('loss', loss.item())

        metrics.update('loss_categorical', loss_categorical.item())
        if self.embed:
            metrics.update('loss_embed', loss_embed.item())

        output = np.vstack(outputs)
        target = np.vstack(targets)
        target[target >= 0.5] = 1  # threshold to get binary labels
        target[target < 0.5] = 0

        ap = model.metric.average_precision(output, target)
        roc_auc = model.metric.roc_auc(output, target)
        metrics.update("map", np.mean(ap))
        metrics.update("roc_auc", np.mean(roc_auc))

        self.writer.add_figure(
            '%s ap per class' % phase,
            make_barplot(ap,
                         self.valid_data_loader.dataset.categorical_emotions,
                         'average_precision'))
        self.writer.add_figure(
            '%s roc auc per class' % phase,
            make_barplot(roc_auc,
                         self.valid_data_loader.dataset.categorical_emotions,
                         'roc auc'))

        metrics.update('loss_continuous', loss_continuous.item())
        output_continuous = np.vstack(outputs_continuous)
        target_continuous = np.vstack(targets_continuous)

        mse = model.metric.mean_squared_error(output_continuous,
                                              target_continuous)
        r2 = model.metric.r2(output_continuous, target_continuous)

        metrics.update("r2", np.mean(r2))
        metrics.update("mse", np.mean(mse))

        self.writer.add_figure(
            '%s r2 per class' % phase,
            make_barplot(r2,
                         self.valid_data_loader.dataset.continuous_emotions,
                         'r2'))
        self.writer.add_figure(
            '%s mse per class' % phase,
            make_barplot(mse,
                         self.valid_data_loader.dataset.continuous_emotions,
                         'mse'))

        metrics.update(
            "mre", model.metric.ERS(np.mean(r2), np.mean(ap),
                                    np.mean(roc_auc)))

        log = metrics.result()

        if phase == "train":
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            if self.do_validation:
                val_log = self._train_epoch(epoch, phase="val")
                log.update(**{'val_' + k: v for k, v in val_log.items()})

            return log

        elif phase == "val":
            if self.categorical:
                self.writer.save_results(output, "output")
            if self.continuous:
                self.writer.save_results(output_continuous,
                                         "output_continuous")

            return metrics.result()
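Usage note: because the validation pass is the method calling itself with phase="val", one call per epoch is enough; validation results come back merged under 'val_'-prefixed keys. A hypothetical driver sketch (trainer stands for an instance of this trainer class):

num_epochs = 10  # hypothetical
for epoch in range(1, num_epochs + 1):
    log = trainer._train_epoch(epoch)  # runs "train", then "val" recursively
    print(epoch, log.get('loss'), log.get('val_loss'))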
def train(model, optimizer, loss_fn, dataloader, metrics_dict, hyper_params):
    """
    Train the model.
    Args:
        model: (torch.nn.Module) the neural network
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        metrics_dict: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        hyper_params: (HyperParams) hyperparameters
    """

    # set model to training mode
    model.train()

    # summary for current training loop and a running average object for loss
    summ = []
    loss_avg = utils.RunningAverage()
    # Use tqdm for progress bar
    with tqdm(total=len(dataloader)) as t:
        for i, (scan_batch, ground_truth_batch, _) in enumerate(dataloader):
            # move to GPU if available (hyper_params.cuda is a device index,
            # -1 meaning CPU)
            if hyper_params.cuda != -1:
                scan_batch, ground_truth_batch = scan_batch.to(
                    device=hyper_params.cuda), ground_truth_batch.to(
                        device=hyper_params.cuda)
            # wrap in Variables (a no-op on PyTorch >= 0.4, kept for compatibility)
            scan_batch, ground_truth_batch = Variable(scan_batch), Variable(
                ground_truth_batch)

            # compute model output and loss
            output_batch = model(scan_batch)
            loss = loss_fn(output_batch, ground_truth_batch)

            # clear previous gradients, compute gradients of all variables wrt loss
            optimizer.zero_grad()
            loss.backward()

            # performs updates using calculated gradients
            optimizer.step()

            # binarize the output (sigmoid, then threshold)
            output_batch = torch.sigmoid(output_batch)
            output_batch = (output_batch > hyper_params.treshold).float()

            # summaries are evaluated on every batch; re-enable the check below
            # to evaluate them only once in a while:
            # if i % hyper_params.save_summary_steps == 0:
            # move to CPU and convert to numpy arrays
            output_batch_np = output_batch.data.cpu().numpy()
            ground_truth_batch_np = ground_truth_batch.data.cpu().numpy()

            # compute all metrics on this batch
            summary_batch = {
                metric: metrics_dict[metric](output_batch_np,
                                             ground_truth_batch_np)
                for metric in metrics_dict
            }
            summary_batch['loss'] = loss.item()
            summ.append(summary_batch)
            # update the average loss
            loss_avg.update(loss.item())

            t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            t.update()

    # compute mean of all metrics in summary
    metrics_mean = {
        metric: np.mean([x[metric] for x in summ])
        for metric in summ[0]
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v)
                                for k, v in metrics_mean.items())
    logging.info("- Train metrics: " + metrics_string)
    # the returned metrics_mean can be appended to a list and averaged later,
    # e.g. for k-fold cross-validation
    return metrics_mean
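As the closing comment suggests, the returned metrics_mean can be collected per fold and averaged for k-fold cross-validation. A hedged sketch of such a driver, reusing train and the evaluate function from Example #2; 'folds' and hyper_params.num_epochs are illustrative names:

import numpy as np

fold_metrics = []  # one metrics_mean dict per fold
for fold, (train_dl, val_dl) in enumerate(folds):  # 'folds' is hypothetical
    # in practice, re-initialize model and optimizer for each fold here
    for epoch in range(hyper_params.num_epochs):
        train(model, optimizer, loss_fn, train_dl, metrics_dict, hyper_params)
    fold_metrics.append(
        evaluate(model, loss_fn, val_dl, metrics_dict, "", hyper_params))

# average each metric over folds
cv_mean = {k: np.mean([m[k] for m in fold_metrics]) for k in fold_metrics[0]}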