Example #1
    def update(self, batch):
        inputs, labels, masks, orig_idx = unpack_batch(batch, self.opt['cuda'])
        self.model.train()
        self.optimizer.zero_grad()
        logits = self.model(inputs)
        if self.opt['crf']:
            loss, _ = self.crit(logits, masks, labels)
        else:
            logits_flat = logits.view(-1, logits.size(-1))
            loss = self.crit(logits_flat, labels.view(-1))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                       self.opt['max_grad_norm'])
        self.optimizer.step()
        return loss.item()
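The update step above leans on an `unpack_batch` helper that is not shown. A minimal sketch of what it could look like, assuming the batch is a tuple of `(inputs, labels, masks, orig_idx)` tensors; the body below is an illustration inferred from the call site, not the original implementation:

import torch

def unpack_batch(batch, use_cuda):
    # Hypothetical sketch: split the batch and move tensors to the GPU
    # when requested. The real helper may differ.
    inputs, labels, masks, orig_idx = batch
    if use_cuda:
        inputs = inputs.cuda()
        labels = labels.cuda()
        masks = masks.cuda()
    return inputs, labels, masks, orig_idx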
Example #2
    def train(self):
        source_img, source_label = self.data_loader.get_random_batch()
        target_img, target_label = self.data_loader.get_random_batch()

        assert len(source_img) == len(target_img)

        perturb_img = source_img.cuda().requires_grad_(True)
        source_img = source_img.cuda()
        target_img = target_img.cuda()
        loss_fn = AdvarsarialLoss(self.model, source_img, target_img,
                                  self.lamb, self.budget)

        for epoch in range(self.num_epochs):
            tk_perturb = get_internal_representation(self.model, perturb_img)
            self.optimizer.zero_grad()
            loss = loss_fn(perturb_img, tk_perturb)

            print('epoch {} : loss {:.4f}'.format(epoch, loss.item()))
            loss.backward()
            self.optimizer.step()
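Example #2 calls `get_internal_representation` to grab an intermediate activation for the perturbed image. One common way to implement such a helper is a forward hook; the sketch below is an assumption (the helper name comes from the call site, but the layer choice and body are illustrative):

import torch

def get_internal_representation(model, x, layer_name='layer4'):
    # Hypothetical sketch: capture the output of one named submodule
    # during a forward pass via a forward hook.
    captured = {}

    def hook(module, inputs, output):
        captured['feat'] = output

    handle = dict(model.named_modules())[layer_name].register_forward_hook(hook)
    try:
        model(x)
    finally:
        handle.remove()
    return captured['feat']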
Example #3
    def calc_losses(self, data, is_train=True, global_step=0):
        if "images" not in data:
            return {}
        all_images = data["images"].to(device=device)  # (SB, NV, 3, H, W)

        SB, NV, _, H, W = all_images.shape
        all_poses = data["poses"].to(device=device)  # (SB, NV, 4, 4)
        all_bboxes = data.get("bbox")  # (SB, NV, 4)  cmin rmin cmax rmax
        all_focals = data["focal"]  # (SB)
        all_c = data.get("c")  # (SB)

        if self.use_bbox and global_step >= args.no_bbox_step:
            self.use_bbox = False
            print(">>> Stopped using bbox sampling @ iter", global_step)

        if not is_train or not self.use_bbox:
            all_bboxes = None

        all_rgb_gt = []
        all_rays = []

        curr_nviews = nviews[torch.randint(0, len(nviews), ()).item()]
        if curr_nviews == 1:
            image_ord = torch.randint(0, NV, (SB, 1))
        else:
            image_ord = torch.empty((SB, curr_nviews), dtype=torch.long)
        for obj_idx in range(SB):
            if all_bboxes is not None:
                bboxes = all_bboxes[obj_idx]
            images = all_images[obj_idx]  # (NV, 3, H, W)
            poses = all_poses[obj_idx]  # (NV, 4, 4)
            focal = all_focals[obj_idx]
            c = None
            if "c" in data:
                c = data["c"][obj_idx]
            if curr_nviews > 1:
                # Somewhat inefficient, don't know better way
                image_ord[obj_idx] = torch.from_numpy(
                    np.random.choice(NV, curr_nviews, replace=False))
            images_0to1 = images * 0.5 + 0.5

            cam_rays = util.gen_rays(poses,
                                     W,
                                     H,
                                     focal,
                                     self.z_near,
                                     self.z_far,
                                     c=c)  # (NV, H, W, 8)
            rgb_gt_all = (images_0to1.permute(0, 2, 3, 1)
                          .contiguous().reshape(-1, 3))  # (NV*H*W, 3)

            if all_bboxes is not None:
                pix = util.bbox_sample(bboxes, args.ray_batch_size)
                pix_inds = pix[..., 0] * H * W + pix[..., 1] * W + pix[..., 2]
            else:
                pix_inds = torch.randint(0, NV * H * W,
                                         (args.ray_batch_size, ))

            rgb_gt = rgb_gt_all[pix_inds]  # (ray_batch_size, 3)
            rays = cam_rays.view(-1, cam_rays.shape[-1])[pix_inds].to(
                device=device)  # (ray_batch_size, 8)

            all_rgb_gt.append(rgb_gt)
            all_rays.append(rays)

        all_rgb_gt = torch.stack(all_rgb_gt)  # (SB, ray_batch_size, 3)
        all_rays = torch.stack(all_rays)  # (SB, ray_batch_size, 8)

        image_ord = image_ord.to(device)
        src_images = util.batched_index_select_nd(
            all_images, image_ord)  # (SB, NS, 3, H, W)
        src_poses = util.batched_index_select_nd(all_poses,
                                                 image_ord)  # (SB, NS, 4, 4)

        all_bboxes = all_poses = all_images = None  # drop references to free memory

        net.encode(
            src_images,
            src_poses,
            all_focals.to(device=device),
            c=all_c.to(device=device) if all_c is not None else None,
        )

        render_dict = DotMap(render_par(
            all_rays,
            want_weights=True,
        ))
        coarse = render_dict.coarse
        fine = render_dict.fine
        using_fine = len(fine) > 0

        loss_dict = {}

        rgb_loss = self.rgb_coarse_crit(coarse.rgb, all_rgb_gt)
        loss_dict["rc"] = rgb_loss.item() * self.lambda_coarse
        if using_fine:
            fine_loss = self.rgb_fine_crit(fine.rgb, all_rgb_gt)
            rgb_loss = rgb_loss * self.lambda_coarse + fine_loss * self.lambda_fine
            loss_dict["rf"] = fine_loss.item() * self.lambda_fine

        loss = rgb_loss
        if is_train:
            loss.backward()
        loss_dict["t"] = loss.item()

        return loss_dict
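The sampling code relies on `util.batched_index_select_nd` to pick the encoder's source views per object. A plausible implementation, inferred from the shape comments (a `(SB, NV, ...)` tensor gathered with `(SB, NS)` indices into `(SB, NS, ...)`); this is a sketch, not necessarily the library's exact code:

import torch

def batched_index_select_nd(t, inds):
    # Hypothetical sketch: for t of shape (SB, NV, ...) and inds of shape
    # (SB, NS), return t[b, inds[b]] per batch element, i.e. (SB, NS, ...).
    index = inds.reshape(inds.shape + (1,) * (t.dim() - 2))
    index = index.expand(*(inds.shape + t.shape[2:]))
    return t.gather(1, index)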
Example #4
def train_model(model,
                dataloaders,
                criterion,
                optimizer,
                writer,
                scheduler,
                config,
                num_epochs=150):
    since = time.time()
    val_acc_history = []
    saved_model_name = "music_siamese_50000" + datetime.now().strftime(
        '%b%d_%H-%M-%S') + ".pth"

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    stop_times = 0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        info = {}
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0

            sum_label = torch.zeros(0)
            sum_preds = torch.zeros(0)
            for input1s, input2s, labels in tqdm(dataloaders[phase]):

                if config.multi_gpu:
                    if config.model_type == 'cnn':
                        input1s = input1s.cuda(config.device_ids[0])
                        input2s = input2s.cuda(config.device_ids[0])
                        labels = labels.cuda(config.device_ids[0])
                    else:
                        input1s = input1s.cuda(
                            config.device_ids[0]), model.init_h0().cuda(
                                config.device_ids[0])
                        input2s = input2s.cuda(
                            config.device_ids[0]), model.init_h0().cuda(
                                config.device_ids[0])
                        labels = labels.cuda(config.device_ids[0])

                else:
                    if config.model_type == 'cnn':
                        input1s = input1s.to(config.device)
                        input2s = input2s.to(config.device)
                        labels = labels.to(config.device)
                    else:
                        input1s = input1s.to(
                            config.device), model.init_h0().to(config.device)
                        input2s = input2s.to(
                            config.device), model.init_h0().to(config.device)
                        labels = labels.to(config.device)
                labels = labels.squeeze()

                sum_label = torch.cat((sum_label, labels.cpu()))
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):

                    output1s, output2s = model(input1s, input2s)
                    loss = criterion(output1s, output2s, labels)
                    # _, preds = outputs.topk(1, 1, True, True)
                    # _, preds = torch.max(outputs, 1)

                    preds = F.pairwise_distance(output1s,
                                                output2s) < config.evalue_thr
                    sum_preds = torch.cat((sum_preds, preds.cpu().float()))

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        # nn.utils.clip_grad_norm(model.parameters(), max_norm=0.1)
                        optimizer.step()

                # statistics
                running_loss += loss.item()
            running_corrects = torch.sum(sum_preds.byte() == sum_label.byte())
            tp, fp, tn, fn = 0, 0, 0, 0
            for pred, label in zip(sum_preds.byte(), sum_label.byte()):
                if pred == label:
                    if pred == 1:
                        tp += 1
                    else:
                        tn += 1
                else:
                    if pred == 1:
                        fp += 1
                    else:
                        fn += 1

            epoch_loss = running_loss / len(dataloaders[phase])
            epoch_acc = running_corrects.double() / len(
                dataloaders[phase].dataset)
            # TPR/TNR below assume balanced data: half the pairs are positive
            epoch_tpr = tp * 2 / len(dataloaders[phase].dataset)
            epoch_tnr = tn * 2 / len(dataloaders[phase].dataset)
            print("corrects sum {}".format(str(running_corrects)))
            print("epoch_tpr: {}".format(str(epoch_tpr)))
            print("epoch_tnr: {}".format(str(epoch_tnr)))

            # epoch_acc = np.mean(mAP)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_acc))
            info[phase] = {'acc': epoch_acc, 'loss': epoch_loss}

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                stop_times = 0
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), saved_model_name)
            if phase == 'val' and epoch_acc < best_acc:
                stop_times = stop_times + 1

            if phase == 'val':
                val_acc_history.append(epoch_acc)

        writer.add_scalars('data/acc', {
            'train': info["train"]['acc'],
            'val': info["val"]['acc']
        }, epoch)
        writer.add_scalars('data/loss', {
            'train': info["train"]['loss'],
            'val': info["val"]['loss']
        }, epoch)
        scheduler.step(info["val"]['loss'])
        if stop_times >= 20:
            break
    time_elapsed = time.time() - since

    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), saved_model_name)

    return model, val_acc_history
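The siamese trainer calls `criterion(output1s, output2s, labels)` and then thresholds `F.pairwise_distance` to get predictions, which is the standard contrastive-loss setup. A minimal sketch of such a criterion, assuming label 1 marks a similar pair and a default margin of 1.0 (both assumptions; the original config may differ):

import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    # Hypothetical sketch of a contrastive criterion (Hadsell et al., 2006).
    # Assumes label == 1 for similar pairs, label == 0 for dissimilar ones.
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, out1, out2, label):
        dist = F.pairwise_distance(out1, out2)
        loss_similar = label.float() * dist.pow(2)
        loss_dissimilar = (1 - label.float()) * F.relu(self.margin - dist).pow(2)
        return (loss_similar + loss_dissimilar).mean()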
Example #5
for epoch in range(opt.start_epoch, opt.epochs):
    if epoch % opt.save_freq == 0:
        #model, epoch, model_path, iteration, prefix=""
        save_checkpoint(model, epoch, opt.model_path, 0, prefix=opt.model)

    for iteration, batch in enumerate(train_data_loader):
        lr, hr = batch[0], batch[1]
        if torch.cuda.is_available():
            lr = lr.cuda()
            hr = hr.cuda()

        sr = model(lr)
        model.zero_grad()
        loss = criterion(sr, hr)
        loss.backward()
        optimizerG.step()

        info = "===> Epoch[{}]({}/{}): time: {:4.4f}:\n".format(
            epoch, iteration,
            len(demo_dataset_x4) // opt.batch_size,
            time.time() - start_time)

        info += "Loss: {:.4f}\n".format(loss.float())
        print(info)

        if iteration % opt.iter_freq == 0:
            # model, epoch, model_path, iteration, prefix=""
            save_checkpoint(model, epoch, opt.model_path, iteration,
                            prefix=opt.model)
            # if not os.path.isdir(opt.result_dir + "{}_{}_{}_result".format(epoch, iteration, opt.model)):
            #     os.makedirs(opt.result_dir + "{}_{}_{}_result".format(epoch, iteration, opt.model))
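`save_checkpoint(model, epoch, model_path, iteration, prefix="")` is called here but not defined in the snippet. A minimal sketch consistent with that signature; the file-naming scheme is an assumption:

import os
import torch

def save_checkpoint(model, epoch, model_path, iteration, prefix=""):
    # Hypothetical sketch: persist model weights under model_path, naming
    # the file by prefix, epoch, and iteration.
    os.makedirs(model_path, exist_ok=True)
    filename = os.path.join(
        model_path, "{}_epoch{}_iter{}.pth".format(prefix, epoch, iteration))
    torch.save(model.state_dict(), filename)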
Example #6
    def _train_epoch(self, epoch, phase="train"):
        """
		Training logic for an epoch

		:param epoch: Integer, current training epoch.
		:return: A log that contains average loss and metric in this epoch.
		"""
        import model.loss
        import model.metric
        print("Current learning rate(s):")
        for param_group in self.optimizer.param_groups:
            print(param_group['lr'])

        if phase == "train":
            self.model.train()
            self.train_metrics.reset()
            torch.set_grad_enabled(True)
            metrics = self.train_metrics
        elif phase == "val":
            self.model.eval()
            self.valid_metrics.reset()
            torch.set_grad_enabled(False)
            metrics = self.valid_metrics

        outputs = []
        outputs_continuous = []
        targets = []
        targets_continuous = []

        data_loader = self.data_loader if phase == "train" else self.valid_data_loader

        for batch_idx, (data, embeddings, target, target_continuous,
                        lengths) in enumerate(data_loader):

            data, target, target_continuous = data.to(self.device), target.to(
                self.device), target_continuous.to(self.device)
            embeddings = embeddings.to(self.device)

            if phase == "train":
                self.optimizer.zero_grad()

            out = self.model(data, embeddings)

            loss = 0

            loss_categorical = self.criterion_categorical(
                out['categorical'], target)
            loss += loss_categorical

            loss_continuous = self.criterion_continuous(
                torch.sigmoid(out['continuous']), target_continuous)
            loss += loss_continuous

            if self.embed:
                loss_embed = model.loss.mse_center_loss(
                    out['embed'], embeddings, target)
                loss += loss_embed

            if phase == "train":
                loss.backward()
                self.optimizer.step()

            output = out['categorical'].cpu().detach().numpy()
            target = target.cpu().detach().numpy()
            outputs.append(output)
            targets.append(target)

            output_continuous = torch.sigmoid(
                out['continuous']).cpu().detach().numpy()
            target_continuous = target_continuous.cpu().detach().numpy()
            outputs_continuous.append(output_continuous)
            targets_continuous.append(target_continuous)

            if batch_idx % self.log_step == 0:
                self.logger.debug(
                    '{} Epoch: {} {} Loss: {:.6f} Loss categorical: {:.6f} Loss continuous: {:.6f}'
                    .format(phase, epoch, self._progress(batch_idx),
                            loss.item(), loss_categorical.item(),
                            loss_continuous.item()))

            if batch_idx == self.len_epoch:
                break

        if phase == "train":
            self.writer.set_step(epoch)
        else:
            self.writer.set_step(epoch, "valid")

        metrics.update('loss', loss.item())  # loss of the last batch in the epoch

        metrics.update('loss_categorical', loss_categorical.item())
        if self.embed:
            metrics.update('loss_embed', loss_embed.item())

        output = np.vstack(outputs)
        target = np.vstack(targets)
        target[target >= 0.5] = 1  # threshold to get binary labels
        target[target < 0.5] = 0

        ap = model.metric.average_precision(output, target)
        roc_auc = model.metric.roc_auc(output, target)
        metrics.update("map", np.mean(ap))
        metrics.update("roc_auc", np.mean(roc_auc))

        self.writer.add_figure(
            '%s ap per class' % phase,
            make_barplot(ap,
                         self.valid_data_loader.dataset.categorical_emotions,
                         'average_precision'))
        self.writer.add_figure(
            '%s roc auc per class' % phase,
            make_barplot(roc_auc,
                         self.valid_data_loader.dataset.categorical_emotions,
                         'roc auc'))

        metrics.update('loss_continuous', loss_continuous.item())
        output_continuous = np.vstack(outputs_continuous)
        target_continuous = np.vstack(targets_continuous)

        mse = model.metric.mean_squared_error(output_continuous,
                                              target_continuous)
        r2 = model.metric.r2(output_continuous, target_continuous)

        metrics.update("r2", np.mean(r2))
        metrics.update("mse", np.mean(mse))

        self.writer.add_figure(
            '%s r2 per class' % phase,
            make_barplot(r2,
                         self.valid_data_loader.dataset.continuous_emotions,
                         'r2'))
        self.writer.add_figure(
            '%s mse per class' % phase,
            make_barplot(mse,
                         self.valid_data_loader.dataset.continuous_emotions,
                         'mse'))

        metrics.update(
            "mre", model.metric.ERS(np.mean(r2), np.mean(ap),
                                    np.mean(roc_auc)))

        log = metrics.result()

        if phase == "train":
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            if self.do_validation:
                val_log = self._train_epoch(epoch, phase="val")
                log.update(**{'val_' + k: v for k, v in val_log.items()})

            return log

        elif phase == "val":
            if self.categorical:
                self.writer.save_results(output, "output")
            if self.continuous:
                self.writer.save_results(output_continuous,
                                         "output_continuous")

            return metrics.result()
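_train_epoch hands per-class scores to `make_barplot` and logs the result with `writer.add_figure`, which expects a matplotlib figure. A sketch of such a helper (figure size and label rotation are illustrative assumptions):

import matplotlib.pyplot as plt

def make_barplot(values, class_names, metric_name):
    # Hypothetical sketch: one bar per class, returned as a figure
    # suitable for SummaryWriter.add_figure.
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(range(len(values)), values)
    ax.set_xticks(range(len(values)))
    ax.set_xticklabels(class_names, rotation=90)
    ax.set_ylabel(metric_name)
    fig.tight_layout()
    return fig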
def train(model, optimizer, loss_fn, dataloader, metrics_dict, hyper_params):
    """
    Train the model.
    Args:
        model: (torch.nn.Module) the neural network
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        metrics_dict: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        hyper_params: (Params) hyperparameters
    """

    # set model to training mode
    model.train()

    # summary for current training loop and a running average object for loss
    summ = []
    loss_avg = utils.RunningAverage()
    # Use tqdm for progress bar
    with tqdm(total=len(dataloader)) as t:
        for i, (scan_batch, ground_truth_batch, _) in enumerate(dataloader):
            # move to GPU if available
            if hyper_params.cuda != -1:
                scan_batch, ground_truth_batch = scan_batch.to(
                    device=hyper_params.cuda), ground_truth_batch.to(
                        device=hyper_params.cuda)

            # compute model output and loss
            output_batch = model(scan_batch)
            loss = loss_fn(output_batch, ground_truth_batch)

            # clear previous gradients, compute gradients of all variables wrt loss
            optimizer.zero_grad()
            loss.backward()

            # performs updates using calculated gradients
            optimizer.step()

            # binarize predictions: sigmoid then threshold
            output_batch = torch.sigmoid(output_batch)
            output_batch = (output_batch > hyper_params.treshold).float()

            # Evaluate summaries only once in a while
            #if i % hyper_params.save_summary_steps == 0:
            # move to cpu and convert to numpy arrays
            output_batch_np = output_batch.cpu().numpy()
            ground_truth_batch_np = ground_truth_batch.cpu().numpy()

            # compute all metrics on this batch
            summary_batch = {
                metric: metrics_dict[metric](output_batch_np,
                                             ground_truth_batch_np)
                for metric in metrics_dict
            }
            summary_batch['loss'] = loss.item()
            summ.append(summary_batch)
            # update the average loss
            loss_avg.update(loss.item())

            t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            t.update()

    # compute mean of all metrics in summary
    metrics_mean = {
        metric: np.mean([x[metric] for x in summ])
        for metric in summ[0]
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v)
                                for k, v in metrics_mean.items())
    logging.info("- Train metrics: " + metrics_string)
    # metrics_mean can be collected across folds and averaged for k-fold cross-validation
    return metrics_mean
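The progress display above depends on `utils.RunningAverage`, used as `loss_avg.update(...)` and read back with `loss_avg()`. A minimal sketch inferred from those call sites (not the project's actual utils module):

class RunningAverage:
    # Hypothetical sketch: maintains the mean of all values seen so far.
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    def __call__(self):
        return self.total / self.count if self.count else 0.0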