Example #1
def validate(model, optimizer, criterion, metrics, options):
    model.eval()

    losses = AverageMeter()
    for metric in metrics:
        metric.reset()

    for batch_idx, (data, target) in zip(maybe_range(options.max_batch_per_epoch),
                                         options.val_loader):
        data = convert_dtype(options.dtype, data)
        if options.force_target_dtype:
            target = convert_dtype(options.dtype, target)

        if options.use_cuda:
            data, target = data.cuda(), target.cuda()

        with torch.no_grad():
            output = model(data)

            loss = criterion(output, target)
            losses.update(loss.item(), data.size(0))

            for metric in metrics:
                metric_value = metric(output, target)
                metric.update(metric_value, data.size(0))

    metrics_averages = {metric.name: metric.average().item() for metric in metrics}
    loss_average = global_average(losses.sum, losses.count).item()
    return metrics_averages, loss_average
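All of these examples rely on an AverageMeter helper defined elsewhere in each project (Example #1 also uses project-specific utilities such as maybe_range, convert_dtype and global_average). Below is a minimal sketch of the meter, assuming the common val/avg/sum/count interface used throughout; some of the later projects expose the running average through a .val property, which is why they log epoch_loss.val as an epoch summary.

class AverageMeter:
    """Track the most recent value and the running average of a metric.

    Minimal sketch (assumption): modeled on the helper from the PyTorch
    ImageNet example, which matches the reset()/update() usage above.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count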
Example #2
    def train_one_epoch(self):
        """
        One epoch training function
        :return:
        """

        tqdm_batch = tqdm.tqdm(self.data_loader.train_loader,
                               total=self.data_loader.train_iterations,
                               desc="Epoch-{}-".format(self.current_epoch))

        self.train()

        epoch_loss = AverageMeter()
        top1_acc = AverageMeter()
        top5_acc = AverageMeter()

        current_batch = 0
        for i, (x, y) in enumerate(tqdm_batch):
            if self.cuda:
                x, y = x.cuda(non_blocking=self.config.async_loading), y.cuda(
                    non_blocking=self.config.async_loading)

            self.optimizer.zero_grad()
            #             self.adjust_learning_rate(self.optimizer, self.current_epoch, i, self.data_loader.train_iterations)

            pred = self(x)
            cur_loss = self.loss_fn(pred, y)

            if np.isnan(float(cur_loss.item())):
                raise ValueError('Loss is nan during training...')

            cur_loss.backward()
            self.optimizer.step()

            top1, top5 = cls_accuracy(pred.data, y.data, topk=(1, 5))
            top1_acc.update(top1.item(), x.size(0))
            top5_acc.update(top5.item(), x.size(0))

            epoch_loss.update(cur_loss.item())

            self.current_iteration += 1
            current_batch += 1

        tqdm_batch.close()

        print("Training at epoch-" + str(self.current_epoch) + " | " +
              "loss: " + str(epoch_loss.val) + "\tTop1 Acc: " +
              str(top1_acc.val))
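cls_accuracy is not shown above; the following is a sketch of the usual top-k accuracy helper (an assumption, modeled on the standard PyTorch recipe) that returns one percentage per requested k, matching the topk=(1, 5) call in this example.

import torch


def cls_accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in `topk`."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # indices of the maxk largest logits per sample, transposed to (maxk, batch)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res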
Example #3
    def validate(self):

        # set the model in eval mode
        self.net.eval()

        # Initialize average meters
        epoch_loss = AverageMeter()
        epoch_iou = AverageMeter()
        epoch_filtered_iou = AverageMeter()

        tqdm_batch = tqdm(self.valid_loader, f'Epoch-{self.current_epoch}-')
        with torch.no_grad():
            for x in tqdm_batch:
                # prepare data
                imgs = torch.tensor(x['img'],
                                    dtype=torch.float,
                                    device=self.device)
                masks = torch.tensor(x['mask'],
                                     dtype=torch.float,
                                     device=self.device)

                # model
                pred, *_ = self.net(imgs)

                # loss
                cur_loss = self.loss(pred, masks)
                if np.isnan(float(cur_loss.item())):
                    raise ValueError('Loss is nan during validation...')

                # metrics
                pred_t = torch.sigmoid(pred) > 0.5
                masks_t = masks > 0.5

                cur_iou = iou_pytorch(pred_t, masks_t)
                cur_filtered_iou = iou_pytorch(remove_small_mask_batch(pred_t),
                                               masks_t)

                batch_size = imgs.shape[0]
                epoch_loss.update(cur_loss.item(), batch_size)
                epoch_iou.update(cur_iou.item(), batch_size)
                epoch_filtered_iou.update(cur_filtered_iou.item(), batch_size)

        tqdm_batch.close()

        logging.info(f'Validation at epoch- {self.current_epoch} |'
                     f'loss: {epoch_loss.val:.5} - IOU: {epoch_iou.val:.5}'
                     f' - Filtered IOU: {epoch_filtered_iou.val:.5}')

        return epoch_filtered_iou.val
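iou_pytorch and remove_small_mask_batch are project-specific. A minimal sketch of a batched IoU over boolean masks, under the assumption that predictions and targets arrive as (N, 1, H, W) or (N, H, W) boolean tensors, as produced by the thresholding above:

import torch


def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Mean intersection-over-union over a batch of boolean masks."""
    if outputs.dim() == 4:  # (N, 1, H, W) -> (N, H, W)
        outputs = outputs.squeeze(1)
        labels = labels.squeeze(1)
    outputs = outputs.bool()
    labels = labels.bool()

    intersection = (outputs & labels).float().sum(dim=(1, 2))
    union = (outputs | labels).float().sum(dim=(1, 2))
    iou = (intersection + eps) / (union + eps)  # eps guards empty masks
    return iou.mean()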
Example #4
    def validate(self):
        """
        One epoch validation
        :return:
        """
        # set the model in validation mode
        self.model.eval()
        val_losses = []

        valid_loss = AverageMeter()
        valid_err_translation = AverageMeter()
        valid_err_rotation = AverageMeter()
        valid_err_joints = AverageMeter()

        t = Transformation(config=self.config)

        val_tic = time.time()

        with torch.no_grad():
            for x, y in self.data_loader.valid_loader:
                # validate on gpu
                if self.cuda:
                    x = x.to(device=self.device, dtype=torch.float)
                    y = y.to(device=self.device, dtype=torch.float)

                # model
                pred = self.model(x.type(torch.FloatTensor))

                if self.config.data_output_type == "joints_absolute":
                    loss_joints = self.loss(pred, y)
                    total_loss = loss_joints
                    valid_loss.update(total_loss.item())
                    valid_err_joints.update(total_loss.item())
                elif self.config.data_output_type == "q_trans_simple":
                    loss_q_trans_simple = self.loss(pred, y)
                    total_loss = loss_q_trans_simple
                elif self.config.data_output_type == "joints_relative":
                    total_loss = self.loss(pred, y)
                    valid_loss.update(total_loss.item())
                    # print("Validation loss {:f}".format(total_loss.item()))
                elif self.config.data_output_type == "pose_relative":
                    # loss for rotation
                    # select rotation indices from the prediction tensor
                    indices = torch.tensor([3, 4, 5, 6])
                    indices = indices.to(self.device)
                    rotation = torch.index_select(pred, 1, indices)
                    # select rotation indices from the label tensor
                    y_rot = torch.index_select(y, 1, indices)
                    # calc MSE loss for rotation
                    # print("Rotation Pred", rotation[0])
                    # print("Rotation Label", y_rot[0])
                    # q_distance = pq.Quaternion.distance(rotation)

                    loss_rotation = self.loss(rotation, y_rot)
                    # penalty loss from facebook paper posenet
                    # penalty_loss = self.config.rot_reg * torch.mean((torch.sum(quater ** 2, dim=1) - 1) ** 2)
                    penalty_loss = 0

                    # loss for translation
                    # select translation indices from the prediction tensor
                    indices = torch.tensor([0, 1, 2])
                    indices = indices.to(self.device)
                    translation = torch.index_select(pred, 1, indices)
                    # select translation indices from the label tensor
                    y_trans = torch.index_select(y, 1, indices)

                    # calc MSE loss for translation
                    loss_translation = self.loss(translation, y_trans)
                    total_loss = penalty_loss + loss_rotation + loss_translation

                    q_pred = pq.Quaternion(rotation[0].cpu().detach().numpy())
                    q_rot = pq.Quaternion(y_rot[0].cpu().detach().numpy())
                    q_dist = math.degrees(pq.Quaternion.distance(
                        q_pred, q_rot))
                    valid_err_rotation.update(q_dist)

                    # loss for translation
                    # select translation indices from the prediction tensor
                    indices = torch.tensor([0, 1, 2])
                    indices = indices.to(self.device)
                    translation = torch.index_select(pred, 1, indices)
                    # select translation indices from the label tensor
                    y_trans = torch.index_select(y, 1, indices)

                    trans_pred = translation[0].cpu().detach().numpy()
                    trans_label = y_trans[0].cpu().detach().numpy()

                    # calc translation MSE
                    mse_trans = (np.square(trans_pred - trans_label)).mean()
                    valid_err_translation.update(mse_trans)

                    # use simple loss
                    total_loss = self.loss(pred, y)
                    valid_loss.update(total_loss.item())

                elif self.config.data_output_type == "pose_absolute":
                    # select rotation indices from the prediction tensor
                    indices = torch.tensor([3, 4, 5, 6])
                    indices = indices.to(self.device)
                    rotation = torch.index_select(pred, 1, indices)
                    # select rotation indices from the label tensor
                    y_rot = torch.index_select(y, 1, indices)
                    q_pred = pq.Quaternion(rotation[0].cpu().detach().numpy())
                    q_rot = pq.Quaternion(y_rot[0].cpu().detach().numpy())
                    q_dist = math.degrees(pq.Quaternion.distance(
                        q_pred, q_rot))
                    valid_err_rotation.update(q_dist)

                    # loss for translation
                    # select translation indices from the prediction tensor
                    indices = torch.tensor([0, 1, 2])
                    indices = indices.to(self.device)
                    translation = torch.index_select(pred, 1, indices)
                    # select translation indices from the label tensor
                    y_trans = torch.index_select(y, 1, indices)

                    trans_pred = translation[0].cpu().detach().numpy()
                    trans_label = y_trans[0].cpu().detach().numpy()

                    # calc translation MSE
                    mse_trans = (np.square(trans_pred - trans_label)).mean()
                    valid_err_translation.update(mse_trans)

                    # use simple loss
                    total_loss = self.loss(pred, y)
                    valid_loss.update(total_loss.item())
                else:
                    raise Exception("Wrong data output type chosen.")

                if np.isnan(float(total_loss.item())):
                    raise ValueError('Loss is nan during Validation.')
                val_losses.append(total_loss.item())

        # update logging dict
        # self.logging_dict["loss_validation_mse"].append(np.mean(val_losses))

        self.logging_dict["valid_loss"].append(valid_loss.val)
        self.logging_dict["valid_err_rotation"].append(valid_err_rotation.val)
        self.logging_dict["valid_err_translation"].append(
            valid_err_translation.val)
        self.logging_dict["valid_err_joints"].append(valid_err_joints.val)

        progress = float((self.current_epoch + 1) / self.config.max_epoch)
        val_duration = time.time() - val_tic
        self.logger.info(
            "Valid Epoch: {:>4d} | Total: {:>4d} | Progress: {:2.2%} | Loss: {:>3.2e} | Translation [mm]: {:>3.2e} |"
            " Rotation [deg] {:>3.2e} | Joints [deg] {:>3.2e} ({:02d}:{:02d}:{:02d}) "
            .format(self.current_epoch +
                    1, self.config.max_epoch, progress, valid_loss.val,
                    valid_err_translation.val, valid_err_rotation.val,
                    valid_err_joints.val, int(val_duration / 3600),
                    int(np.mod(val_duration, 3600) /
                        60), int(np.mod(np.mod(val_duration, 3600), 60))) +
            time.strftime("%d.%m.%y %H:%M:%S", time.localtime()))
        return np.mean(val_losses)
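The repeated index_select blocks above split a 7-dimensional pose vector into its translation part (indices 0..2) and its quaternion part (indices 3..6). Plain tensor slicing is equivalent and avoids building index tensors on the device; a small sketch:

import torch


def split_pose(pred: torch.Tensor):
    """Split (N, 7) pose predictions into translation (N, 3) and quaternion (N, 4)."""
    translation = pred[:, 0:3]  # same as index_select with indices [0, 1, 2]
    rotation = pred[:, 3:7]     # same as index_select with indices [3, 4, 5, 6]
    return translation, rotation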
Example #5
def train(model, train_loader, test_loader, opt):
    logging.info('===> Init the optimizer ...')
    criterion = SmoothCrossEntropy()
    if opt.use_sgd:
        logging.info("===> Use SGD")
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=opt.lr * 100,
                                    momentum=0.9,
                                    weight_decay=opt.weight_decay)
    else:
        logging.info("===> Use Adam")
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.lr,
                                     weight_decay=opt.weight_decay)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           opt.epochs,
                                                           eta_min=opt.lr)
    optimizer, scheduler, opt.lr = load_pretrained_optimizer(
        opt.pretrained_model, optimizer, scheduler, opt.lr)

    logging.info('===> Init Metric ...')
    opt.train_losses = AverageMeter()
    opt.test_losses = AverageMeter()
    best_test_overall_acc = 0.
    avg_acc_when_best = 0.

    logging.info('===> start training ...')
    for _ in range(opt.epoch, opt.epochs):
        opt.epoch += 1
        # reset tracker
        opt.train_losses.reset()
        opt.test_losses.reset()

        train_overall_acc, train_class_acc, opt = train_step(
            model, train_loader, optimizer, criterion, opt)
        test_overall_acc, test_class_acc, opt = infer(model, test_loader,
                                                      criterion, opt)

        scheduler.step()

        # ------------------  save ckpt
        if test_overall_acc > best_test_overall_acc:
            best_test_overall_acc = test_overall_acc
            avg_acc_when_best = test_class_acc
            logging.info(
                "Got a new best model on Test with Overall ACC {:.4f}. "
                "Its avg acc is {:.4f}".format(best_test_overall_acc,
                                               avg_acc_when_best))
            save_ckpt(model, optimizer, scheduler, opt, 'best')

        # ------------------ show information
        logging.info(
            "===> Epoch {}/{}, Train Loss {:.4f}, Test Overall Acc {:.4f}, Test Avg Acc {:4f}, "
            "Best Test Overall Acc {:.4f}, Its test avg acc {:.4f}.".format(
                opt.epoch, opt.epochs, opt.train_losses.avg, test_overall_acc,
                test_class_acc, best_test_overall_acc, avg_acc_when_best))

        info = {
            'train_loss': opt.train_losses.avg,
            'train_OA': train_overall_acc,
            'train_avg_acc': train_class_acc,
            'test_loss': opt.test_losses.avg,
            'test_OA': test_overall_acc,
            'test_avg_acc': test_class_acc,
            'lr': scheduler.get_lr()[0]
        }
        for tag, value in info.items():
            opt.logger.scalar_summary(tag, value, opt.step)

    save_ckpt(model, optimizer, scheduler, opt, 'last')
    logging.info(
        'Saving the final model. Finished! Best Test Overall Acc {:.4f}, Its test avg acc {:.4f}. '
        'Last Test Overall Acc {:.4f}, Its test avg acc {:.4f}.'.format(
            best_test_overall_acc, avg_acc_when_best, test_overall_acc,
            test_class_acc))
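SmoothCrossEntropy is defined in the surrounding project; here is a minimal label-smoothing cross-entropy sketch (the smoothing value of 0.2 is an assumption) that would plug into the loop above as the criterion.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SmoothCrossEntropy(nn.Module):
    """Cross entropy with uniform label smoothing."""

    def __init__(self, smoothing: float = 0.2):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        n_classes = logits.size(-1)
        log_probs = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            # spread `smoothing` uniformly over the wrong classes, the rest on the true class
            true_dist = torch.full_like(log_probs, self.smoothing / (n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=-1))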
Example #6
def train_one_epoch(sess, ops, epoch, lr):
    """
    One epoch training
    """

    is_training = True

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    weight_loss_meter = AverageMeter()
    cls_loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    sess.run(ops['train_init_op'])
    feed_dict = {ops['is_training_pl']: is_training, ops['learning_rate']: lr}

    batch_idx = 0
    end = time.time()
    while True:
        try:
            _, loss, classification_loss, weight_loss, \
            tower_logits, tower_labels, = sess.run([ops['train_op'],
                                                    ops['loss'],
                                                    ops['classification_loss'],
                                                    ops['weight_loss'],
                                                    ops['tower_logits'],
                                                    ops['tower_labels']],
                                                   feed_dict=feed_dict)
            for logits, labels in zip(tower_logits, tower_labels):
                pred = np.argmax(logits, -1)
                correct = np.mean(pred == labels)
                acc_meter.update(correct, pred.shape[0])

            # update meters
            loss_meter.update(loss)
            cls_loss_meter.update(classification_loss)
            weight_loss_meter.update(weight_loss)
            batch_time.update(time.time() - end)
            end = time.time()

            if (batch_idx + 1) % config.print_freq == 0:
                logger.info(
                    f'Train: [{epoch}][{batch_idx}] '
                    f'T {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    f'acc {acc_meter.val:.3f} ({acc_meter.avg:.3f}) '
                    f'loss {loss_meter.val:.3f} ({loss_meter.avg:.3f}) '
                    f'cls loss {cls_loss_meter.val:.3f} ({cls_loss_meter.avg:.3f}) '
                    f'weight loss {weight_loss_meter.val:.3f} ({weight_loss_meter.avg:.3f})'
                )
                wandb.log({
                    "Train Acc": acc_meter.val,
                    "Train Avg Acc": acc_meter.avg,
                    "Train Loss": loss_meter.val,
                    "Train Avg Loss": loss_meter.avg,
                    "Train Class Loss": cls_loss_meter.val,
                    "Train Avg Class Loss": cls_loss_meter.avg,
                    "Train Weight Loss": weight_loss_meter.val,
                    "Train Avg Weight Loss": weight_loss_meter.avg
                })
            batch_idx += 1
        except tf.errors.OutOfRangeError:
            break
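This loop runs a TensorFlow 1.x session over a tf.data pipeline and relies on ops['train_init_op'] to rewind the iterator at the start of each epoch, terminating when the iterator raises tf.errors.OutOfRangeError. A hypothetical sketch of that plumbing (data shapes and names are assumptions):

import numpy as np
import tensorflow as tf  # TF 1.x style API, matching the session-based loop above

# Dummy stand-ins; the real project feeds point clouds and labels.
train_points = np.random.rand(256, 1024, 3).astype(np.float32)
train_labels = np.random.randint(0, 40, size=256).astype(np.int64)

dataset = (tf.data.Dataset.from_tensor_slices((train_points, train_labels))
           .shuffle(256)
           .batch(16))

# Initializable iterator: running `train_init_op` restarts a pass over the data,
# and sess.run(...) raises tf.errors.OutOfRangeError once the pass is exhausted,
# which is what the `while True` / `except OutOfRangeError` pattern depends on.
iterator = dataset.make_initializable_iterator()
points, labels = iterator.get_next()
train_init_op = iterator.initializer  # corresponds to ops['train_init_op']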
Example #7
class Self_CLAN:
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.start_iter = args.start_iter
        self.num_steps = args.num_steps
        self.num_classes = args.num_classes
        self.preheat = self.num_steps/20  # damping instead of early stopping
        self.source_label = 0
        self.target_label = 1
        self.best_miou = 0
        # TODO: CHANGE LOSS for LSGAN
        self.bce_loss = torch.nn.BCEWithLogitsLoss()
        #self.bce_loss = torch.nn.MSELoss() #LSGAN
        self.weighted_bce_loss = WeightedBCEWithLogitsLoss()
        self.aux_acc = AverageMeter()
        self.save_path = args.prediction_dir  # dir to save class mIoU when validating model
        self.losses = {'seg': list(),'seg_t': list(), 'adv': list(), 'weight': list(), 'ds': list(), 'dt': list(), 'aux': list()}
        self.rotations = [0, 90, 180, 270]

        cudnn.enabled = True
        #cudnn.benchmark = True

        # set up models
        if args.model.name == 'DeepLab':
            self.model = Res_Deeplab(num_classes=args.num_classes, restore_from=args.model.restore_from)
            self.optimizer = optim.SGD(self.model.optim_parameters(args.model.optimizer), lr=args.model.optimizer.lr, momentum=args.model.optimizer.momentum, weight_decay=args.model.optimizer.weight_decay)
        if args.model.name == 'ErfNet':
            self.model = ERFNet(args.num_classes)  # To add image-net pre-training and double classificator
            self.optimizer = optim.SGD(self.model.optim_parameters(args.model.optimizer), lr=args.model.optimizer.lr, momentum=args.model.optimizer.momentum, weight_decay=args.model.optimizer.weight_decay)

        if args.method.adversarial:
            self.model_D = discriminator(name=args.discriminator.name, num_classes=args.num_classes, restore_from=args.discriminator.restore_from)
            self.optimizer_D = optim.Adam(self.model_D.parameters(), lr=args.discriminator.optimizer.lr, betas=(0.9, 0.99))
        if args.method.self:
            self.model_A = auxiliary(name=args.auxiliary.name, input_dim=args.auxiliary.classes, aux_classes=args.auxiliary.aux_classes, restore_from=args.auxiliary.restore_from)
            self.optimizer_A = optim.Adam(self.model_A.parameters(), lr=args.auxiliary.optimizer.lr, betas=(0.9, 0.99))
            self.aux_loss = nn.CrossEntropyLoss()

    def train(self, src_loader, tar_loader, val_loader):

        loss_rot = loss_adv = loss_weight = loss_D_s = loss_D_t = 0
        args = self.args
        log = self.logger
        device = self.device

        interp_source = nn.Upsample(size=(args.datasets.source.images_size[1], args.datasets.source.images_size[0]), mode='bilinear',  align_corners=True)
        interp_target = nn.Upsample(size=(args.datasets.target.images_size[1], args.datasets.target.images_size[0]), mode='bilinear',  align_corners=True)
        interp_prediction = nn.Upsample(size=(args.auxiliary.images_size[1], args.auxiliary.images_size[0]), mode='bilinear', align_corners=True)

        source_iter = enumerate(src_loader)
        target_iter = enumerate(tar_loader)

        self.model.train()
        self.model = self.model.to(device)

        if args.method.adversarial:
            self.model_D.train()
            self.model_D = self.model_D.to(device)

        if args.method.self:
            self.model_A.train()
            self.model_A = self.model_A.to(device)

        log.info('###########   TRAINING STARTED  ############')
        start = time.time()

        for i_iter in range(self.start_iter, self.num_steps):

            self.model.train()
            self.optimizer.zero_grad()
            adjust_learning_rate(self.optimizer, self.preheat, args.num_steps, args.power, i_iter, args.model.optimizer)

            # Train with adversarial loss
            if args.method.adversarial:
                self.model_D.train()
                self.optimizer_D.zero_grad()
                adjust_learning_rate(self.optimizer_D, self.preheat, args.num_steps, args.power, i_iter, args.discriminator.optimizer)

            # Adding Rotation task
            if args.method.self:
                self.model_A.train()
                self.optimizer_A.zero_grad()
                adjust_learning_rate(self.optimizer_A, self.preheat, args.num_steps, args.power, i_iter, args.auxiliary.optimizer)

            damping = (1 - i_iter/self.num_steps)  # similar to early stopping

        # ======================================================================================
        # train G
        # ======================================================================================
            if args.method.adversarial:
                for param in self.model_D.parameters():  # Remove Grads in D
                    param.requires_grad = False

            # Train with Source
            _, batch = next(source_iter)
            images_s, labels_s, _, _ = batch
            images_s = images_s.to(device)
            pred_source1_, pred_source2_ = self.model(images_s)

            pred_source1 = interp_source(pred_source1_)
            pred_source2 = interp_source(pred_source2_)

            # Segmentation Loss
            loss_seg = (loss_calc(self.num_classes, pred_source1, labels_s, device) + loss_calc(self.num_classes, pred_source2, labels_s, device))
            loss_seg.backward()
            self.losses['seg'].append(loss_seg.item())

            # Train with Target
            _, batch = next(target_iter)
            images_t, labels_t = batch
            images_t = images_t.to(device)
            pred_target1_, pred_target2_ = self.model(images_t)

            pred_target1 = interp_target(pred_target1_)
            pred_target2 = interp_target(pred_target2_)

            # Semi-supervised approach
            if args.use_target_labels and i_iter % int(1 / args.target_frac) == 0:
                loss_seg_t = (loss_calc(args.num_classes, pred_target1, labels_t, device) + loss_calc(args.num_classes, pred_target2, labels_t, device))
                loss_seg_t.backward()
                self.losses['seg_t'].append(loss_seg_t.item())

            # Adversarial Loss
            if args.method.adversarial:

                pred_target1 = pred_target1.detach()
                pred_target2 = pred_target2.detach()

                # TODO: Save the weightmap
                weight_map = weightmap(F.softmax(pred_target1, dim=1), F.softmax(pred_target2, dim=1))

                D_out = interp_target(self.model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

                # Adaptive Adversarial Loss
                if i_iter > self.preheat:
                    loss_adv = self.weighted_bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(self.source_label).to(device), weight_map, args.Epsilon, args.Lambda_local)
                else:
                    loss_adv = self.bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(self.source_label).to(device))

                loss_adv.requires_grad = True
                loss_adv = loss_adv * self.args.Lambda_adv * damping
                loss_adv.backward()
                self.losses['adv'].append(loss_adv.item())

        # Weight Discrepancy Loss
            if args.weight_loss:

                # Init container variables of DeepLab weights of layers 5 and 6
                W5 = None
                W6 = None
                # TODO: ADD ERF-NET
                if args.model.name == 'DeepLab':  
                    
                    for (w5, w6) in zip(self.model.layer5.parameters(), self.model.layer6.parameters()):
                        if W5 is None and W6 is None:
                            W5 = w5.view(-1)
                            W6 = w6.view(-1)
                        else:
                            W5 = torch.cat((W5, w5.view(-1)), 0)
                            W6 = torch.cat((W6, w6.view(-1)), 0)

                # Cosine distance between W5 and W6 vectors 
                loss_weight = (torch.matmul(W5, W6) / (torch.norm(W5) * torch.norm(W6)) + 1)  # +1 is for a positive loss
                loss_weight = loss_weight * args.Lambda_weight * damping * 2
                loss_weight.backward()
                self.losses['weight'].append(loss_weight.item())

        # ======================================================================================
        # train D
        # ======================================================================================
            if args.method.adversarial:
                # Bring back Grads in D
                for param in self.model_D.parameters():
                    param.requires_grad = True

                # Train with Source
                pred_source1 = pred_source1.detach()
                pred_source2 = pred_source2.detach()

                D_out_s = interp_source(self.model_D(F.softmax(pred_source1 + pred_source2, dim=1)))
                loss_D_s = self.bce_loss(D_out_s, torch.FloatTensor(D_out_s.data.size()).fill_(self.source_label).to(device))
                loss_D_s.backward()
                self.losses['ds'].append(loss_D_s.item())

                # Train with Target
                pred_target1 = pred_target1.detach()
                pred_target2 = pred_target2.detach()
                weight_map = weight_map.detach()

                D_out_t = interp_target(self.model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

                # Adaptive Adversarial Loss
                if i_iter > self.preheat:
                    loss_D_t = self.weighted_bce_loss(D_out_t, torch.FloatTensor(D_out_t.data.size()).fill_(self.target_label).to(device), weight_map, args.Epsilon, args.Lambda_local)
                else:
                    loss_D_t = self.bce_loss(D_out_t, torch.FloatTensor(D_out_t.data.size()).fill_(self.target_label).to(device))

                loss_D_t.backward()
                self.losses['dt'].append(loss_D_t.item())

        # ======================================================================================
        # Train SELF SUPERVISED TASK
        # ======================================================================================
            if args.method.self:

                ''' SELF-SUPERVISED (ROTATION) ALGORITHM 
                - Get squared prediction 
                - Rotate it randomly (0,90,180,270) -> assign self-label (0,1,2,3)  [*2 IF WANT TO CLASSIFY ALSO S/T]
                - Send rotated prediction to the classifier
                - Get loss 
                - Update weights of classifier and G (segmentation network) 
                '''

                # Train with Target
                pred_target1 = pred_target1_.detach()
                pred_target2 = pred_target2_.detach()

                pred_target = interp_prediction(F.softmax(pred_target1 + pred_target2, dim=1))

                # Rotate prediction randomly
                label_target = torch.empty(1, dtype=torch.long).random_(args.auxiliary.aux_classes).to(device)
                rotated_pred_target = rotate_tensor(pred_target, self.rotations[label_target.item()])

                pred_target_label = self.model_A(rotated_pred_target)
                loss_rot_target = self.aux_loss(pred_target_label, label_target)

                save_rotations(args.images_dir, pred_target, rotated_pred_target, i_iter)

                # calculate accuracy of aux
                if label_target.item() == F.softmax(pred_target_label, dim=1).argmax(dim=1).item():
                    self.aux_acc.update(1)
                else:
                    self.aux_acc.update(0)

                loss_rot = loss_rot_target * args.Lambda_aux
                loss_rot.backward()
                self.losses['aux'].append(loss_rot.item())

            # Optimizers steps
            self.optimizer.step()
            if args.method.adversarial:
                self.optimizer_D.step()
            if args.method.self:
                self.optimizer_A.step()

            if i_iter % 10 == 0:
                log.info('Iter = {0:6d}/{1:6d}, loss_seg = {2:.4f} loss_rot = {3:.4f}, loss_adv = {4:.4f}, loss_weight = {5:.4f}, loss_D_s = {6:.4f} loss_D_t = {7:.4f}, aux_acc = {8:.2f}%'.format(
                        i_iter, self.num_steps, loss_seg, loss_rot, loss_adv, loss_weight, loss_D_s, loss_D_t, self.aux_acc.val))

            if (i_iter % args.save_pred_every == 0 and i_iter != 0) or i_iter == self.num_steps-1:
                i_iter = i_iter if i_iter != self.num_steps - 1 else i_iter + 1  # for last iter

                # Validate and calculate mIoU
                self.validate(i_iter, val_loader)
                miou = compute_mIoU(i_iter, args.datasets.target.val.label_dir, self.save_path, args.datasets.target.json_file,
                            args.datasets.target.base_list, args.results_dir)

                log.info('saving weights...')

                # TODO: SAVE ONLY BEST AND LAST MODELS
                torch.save(self.model.state_dict(), join(args.snapshot_dir, 'GTA5' + '.pth'))
                if args.method.adversarial:
                    torch.save(self.model_D.state_dict(), join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
                if args.method.self:
                    torch.save(self.model_A.state_dict(), join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_Aux.pth'))

                # SAVE LOSS PLOT, EXAMPLE TENSORS
                save_losses_plot(args.results_dir, self.losses)
                # TODO: SAVE TENSORS ALSO FOR ADV AND SELF
                save_segmentations(args.images_dir, images_s, labels_s, pred_source1, images_t)
                save_rotations(args.images_dir, pred_target, rotated_pred_target, i_iter)

            del images_s, labels_s, pred_source1, pred_source2, pred_source1_, pred_source2_
            del images_t, labels_t, pred_target1, pred_target2, pred_target1_, pred_target2_, rotated_pred_target

        end = time.time()
        days = int((end - start) / 86400)
        log.info('Total training time: {} days, {} hours, {} min, {} sec '.format(days, int((end - start) / 3600)-(days*24), int((end - start) / 60 % 60),  int((end - start) % 60)))
        print('### Experiment: ' + args.experiment + ' finished ###')

    def validate(self, current_iter, val_loader):
        self.model.eval()
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear', align_corners=True)
        self.save_path = join(self.args.prediction_dir, str(current_iter))
        os.makedirs(self.save_path, exist_ok=True)

        print('### STARTING EVALUATING ###')
        print('total to process: %d' % len(val_loader))
        with torch.no_grad():
            for index, batch in enumerate(val_loader):
                if index % 100 == 0:
                    print('%d processed' % index)
                image, _, name, = batch
                output1, output2 = self.model(image.to(self.device))
                output = interp(output1 + output2).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                name = name[0].split('/')[-1]
                output.save('%s/%s' % (self.save_path, name))
                output_col.save('%s/%s_color.png' % (self.save_path, name.split('.')[0]))

        print('### EVALUATING FINISHED ###')
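The rotation self-supervision block in train() feeds rotate_tensor an angle drawn from self.rotations = [0, 90, 180, 270]. A minimal sketch, assuming NCHW tensors and multiples of 90 degrees, built on torch.rot90:

import torch


def rotate_tensor(x: torch.Tensor, angle: int) -> torch.Tensor:
    """Rotate an NCHW tensor counter-clockwise by 0, 90, 180 or 270 degrees."""
    k = (angle // 90) % 4
    return torch.rot90(x, k, dims=(2, 3))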
Example #8
def train_eval(train_loader, val_loader, model, criterion, optimizer, args, epoch, fnames=[]):
	batch_time = AverageMeter()
	losses = AverageMeter()
	top1 = AverageMeter()

	model.eval()

	end = time.time()
	scores = np.zeros((len(train_loader.dataset), args.num_classes))
	labels = np.zeros((len(train_loader.dataset), ))

	# Checkpoint begin batch.
	b_batch=0
	if epoch == cp_recorder.contextual['b_epoch']:
		b_batch = cp_recorder.contextual['b_batch']+1
	
	for i, (union, obj1, obj2, bpos, target, _, _, _) in enumerate(train_loader):
		# Jump to contextual batch
		if i < b_batch:
			continue

		target = target.cuda(non_blocking=True)
		union = union.cuda()
		obj1 = obj1.cuda()
		obj2 = obj2.cuda()
		bpos = bpos.cuda()
		
		output, _ = model(union, obj1, obj2, bpos)
		
		loss = criterion(output, target)
		losses.update(loss.item(), union.size(0))

		prec1 = accuracy(output.data, target)
		top1.update(prec1[0], union.size(0))

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		batch_time.update(time.time() - end)
		end = time.time()
		
		if (i+1) % args.print_freq == 0:
			"""Every 10 batches, print on screen and print train information on tensorboard
			"""
			niter = epoch * len(train_loader) + i
			print('Train [Batch {0}/{1}|Epoch {2}/{3}]:  '
					'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
					'Loss {loss.val:.4f} ({loss.avg:.4f})  '
					'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
						i, len(train_loader), epoch, args.epoch, batch_time=batch_time,
						loss=losses, top1=top1))
			
			writer.add_scalars('Loss (per batch)', {'train-10b': loss.item()}, niter)
			writer.add_scalars('Prec@1 (per batch)', {'train-10b': prec1[0]}, niter)

		
		if (i+1) % (args.print_freq*10) == 0 :
			# Every 100 batches, print on screen and print validation information on tensorboard
			
			top1_avg_val, loss_avg_val, prec, recall, ap = validate_eval(val_loader, model, criterion, args, epoch)
			writer.add_scalars('Loss (per batch)', {'valid': loss_avg_val}, niter)
			writer.add_scalars('Prec@1 (per batch)', {'valid': top1_avg_val}, niter)
			writer.add_scalars('mAP (per batch)', {'valid': np.nan_to_num(ap).mean()}, niter)

			# Save checkpoint every 100 batches.
			cp_recorder.record_contextual({'b_epoch': epoch, 'b_batch': i, 'prec': top1_avg_val, 'loss': loss_avg_val, 
				'class_prec': prec, 'class_recall': recall, 'class_ap': ap, 'mAP': np.nan_to_num(ap).mean()})
			cp_recorder.save_checkpoint(model)
		
		# Record scores.
		output_f = F.softmax(output, dim=1)  # To [0, 1]
		output_np = output_f.data.cpu().numpy()
		labels_np = target.data.cpu().numpy()
		b_ind = i*args.batch_size
		e_ind = b_ind + min(args.batch_size, output_np.shape[0])
		scores[b_ind:e_ind, :] = output_np
		labels[b_ind:e_ind] = labels_np

	
	res_scores = multi_scores(scores, labels, ['precision', 'recall', 'average_precision'])
	print('Train [Epoch {0}/{1}]:  '
		'*Time {2:.2f}mins ({batch_time.avg:.2f}s)  '
		'*Loss {loss.avg:.4f}  '
		'*Prec@1 {top1.avg:.3f}'.format(epoch, args.epoch, batch_time.sum/60,
			batch_time=batch_time, loss=losses, top1=top1))
	
	return top1.avg, losses.avg, res_scores['precision'], res_scores['recall'], res_scores['average_precision']
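multi_scores turns the collected score matrix and label vector into per-class precision, recall and average precision. A hypothetical sketch built on scikit-learn (the original project's implementation and signature may differ):

import numpy as np
from sklearn.metrics import precision_score, recall_score, average_precision_score
from sklearn.preprocessing import label_binarize


def multi_scores(scores, labels, metrics):
    """scores: (N, C) class probabilities, labels: (N,) integer class ids."""
    n_classes = scores.shape[1]
    labels = labels.astype(int)
    preds = scores.argmax(axis=1)
    class_ids = np.arange(n_classes)
    onehot = label_binarize(labels, classes=class_ids)

    results = {}
    if 'precision' in metrics:
        results['precision'] = precision_score(labels, preds, labels=class_ids,
                                               average=None, zero_division=0)
    if 'recall' in metrics:
        results['recall'] = recall_score(labels, preds, labels=class_ids,
                                         average=None, zero_division=0)
    if 'average_precision' in metrics:
        results['average_precision'] = average_precision_score(onehot, scores, average=None)
    return results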
Example #9
    def train_one_epoch(self):
        # initialize tqdm batch
        tqdm_batch = tqdm(self.dataloader.loader,
                          total=self.dataloader.num_iterations,
                          desc="epoch-{}-".format(self.current_epoch))

        self.netG.train()
        self.netD.train()

        epoch_lossG = AverageMeter()
        epoch_lossD = AverageMeter()

        for curr_it, x in enumerate(tqdm_batch):
            #y = torch.full((self.batch_size,), self.real_label)
            x = x[0]
            y = torch.randn(x.size(0), )
            fake_noise = torch.randn(x.size(0), self.config.g_input_size, 1, 1)

            if self.cuda:
                x = x.cuda(non_blocking=self.config.async_loading)
                y = y.cuda(non_blocking=self.config.async_loading)
                fake_noise = fake_noise.cuda(non_blocking=self.config.async_loading)

            x = Variable(x)
            y = Variable(y)
            fake_noise = Variable(fake_noise)
            ####################
            # Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            # train with real
            self.netD.zero_grad()
            D_real_out = self.netD(x)
            y.fill_(self.real_label)
            loss_D_real = self.loss(D_real_out, y)
            loss_D_real.backward()

            # train with fake
            G_fake_out = self.netG(fake_noise)
            y.fill_(self.fake_label)

            D_fake_out = self.netD(G_fake_out.detach())

            loss_D_fake = self.loss(D_fake_out, y)
            loss_D_fake.backward()
            #D_mean_fake_out = D_fake_out.mean().item()

            loss_D = loss_D_fake + loss_D_real
            self.optimD.step()

            ####################
            # Update G network: maximize log(D(G(z)))
            self.netG.zero_grad()
            y.fill_(self.real_label)
            D_out = self.netD(G_fake_out)
            loss_G = self.loss(D_out, y)
            loss_G.backward()

            #D_G_mean_out = D_out.mean().item()

            self.optimG.step()

            epoch_lossD.update(loss_D.item())
            epoch_lossG.update(loss_G.item())

            self.current_iteration += 1

            self.summary_writer.add_scalar("epoch/Generator_loss",
                                           epoch_lossG.val,
                                           self.current_iteration)
            self.summary_writer.add_scalar("epoch/Discriminator_loss",
                                           epoch_lossD.val,
                                           self.current_iteration)

        gen_out = self.netG(self.fixed_noise)
        out_img = self.dataloader.plot_samples_per_epoch(
            gen_out.data, self.current_iteration)
        self.summary_writer.add_image('train/generated_image', out_img,
                                      self.current_iteration)

        tqdm_batch.close()

        self.logger.info("Training at epoch-" + str(self.current_epoch) +
                         " | " + "Discriminator loss: " +
                         str(epoch_lossD.val) + " - Generator Loss-: " +
                         str(epoch_lossG.val))
Example #10
    def train_one_epoch(self):
        tqdm_batch = tqdm(self.dataloader, total=self.dataset.num_iterations,
                          desc="epoch-{}-".format(self.current_epoch))

        self.model.train()
        self.discriminator.train()

        epoch_loss = AverageMeter()
        epoch_lossD = AverageMeter()

        for curr_it, (note, pre_note, position) in enumerate(tqdm_batch):
            if self.cuda:
                note = note.cuda(non_blocking=self.config.async_loading)
                pre_note = pre_note.cuda(non_blocking=self.config.async_loading)
                position = position.cuda(non_blocking=self.config.async_loading)

            note = Variable(note)
            pre_note = Variable(pre_note)
            position = Variable(position)

            ####################
            self.model.zero_grad()
            self.discriminator.zero_grad()
            zeros = torch.randn(note.size(0), ).fill_(0.).cuda()
            ones = torch.randn(note.size(0), ).fill_(1.).cuda()

            ####################
            gen_note, mean, var = self.model(note, pre_note, position)
            f_logits = self.discriminator(gen_note)
            r_logits = self.discriminator(note)

            ####################
            gan_loss = self.lossD(f_logits, ones)

            loss_model = self.loss(gen_note, note, mean, var, gan_loss)
            loss_model.backward(retain_graph=True)
            self.optimVAE.step()

            ####################
            r_lossD = self.lossD(r_logits, ones)
            f_lossD = self.lossD(f_logits, zeros)

            loss_D = r_lossD + f_lossD
            loss_D.backward(retain_graph=True)
            self.optimD.step()

            ####################
            epoch_lossD.update(loss_D.item())
            epoch_loss.update(loss_model.item())

            self.current_iteration += 1

            self.summary_writer.add_scalar("epoch/Generator_loss", epoch_loss.val, self.current_iteration)
            self.summary_writer.add_scalar("epoch/Discriminator_loss", epoch_lossD.val, self.current_iteration)

        out_img = self.model(self.fixed_noise, self.zero_note, torch.tensor([330], dtype=torch.long).cuda(), False)
        self.summary_writer.add_image('train/generated_image',
                                      torch.gt(out_img, 0.3).type('torch.FloatTensor').view(1, 384, 96) * 255,
                                      self.current_iteration)

        tqdm_batch.close()

        self.logger.info("Training at epoch-" + str(self.current_epoch) + " | " + "Discriminator loss: " +
                         str(epoch_lossD.val) + " - Generator Loss-: " + str(epoch_loss.val))

        if epoch_loss.val < self.best_error:
            self.best_error = epoch_loss.val
            return True
        else:
            return False
Example #11
    def _train_epoch(self, epoch):
        if self.gpu == 0:
            self.logger.info('\n')

        self.model.train()

        self.supervised_loader.train_sampler.set_epoch(epoch)
        self.unsupervised_loader.train_sampler.set_epoch(epoch)

        if self.mode == 'supervised':
            dataloader = iter(self.supervised_loader)
            tbar = tqdm(range(len(self.supervised_loader)), ncols=135)
        else:
            dataloader = iter(
                zip(cycle(self.supervised_loader),
                    cycle(self.unsupervised_loader)))
            tbar = tqdm(range(self.iter_per_epoch), ncols=135)

        self._reset_metrics()

        for batch_idx in tbar:

            if self.mode == 'supervised':
                (input_l, target_l), (input_ul,
                                      target_ul) = next(dataloader), (None,
                                                                      None)
            else:
                (input_l, target_l), (input_ul, target_ul, ul1, br1, ul2, br2,
                                      flip) = next(dataloader)

            input_l, target_l = input_l.cuda(non_blocking=True), target_l.cuda(
                non_blocking=True)
            input_ul, target_ul = input_ul.cuda(
                non_blocking=True), target_ul.cuda(non_blocking=True)
            self.optimizer.zero_grad()

            if self.mode == 'supervised':
                total_loss, cur_losses, outputs = self.model(
                    x_l=input_l,
                    target_l=target_l,
                    x_ul=input_ul,
                    curr_iter=batch_idx,
                    target_ul=target_ul,
                    epoch=epoch - 1)
            else:
                kargs = {
                    'gpu': self.gpu,
                    'ul1': ul1,
                    'br1': br1,
                    'ul2': ul2,
                    'br2': br2,
                    'flip': flip
                }
                total_loss, cur_losses, outputs = self.model(
                    x_l=input_l,
                    target_l=target_l,
                    x_ul=input_ul,
                    curr_iter=batch_idx,
                    target_ul=target_ul,
                    epoch=epoch - 1,
                    **kargs)
                target_ul = target_ul[:, 0]

            total_loss.backward()
            self.optimizer.step()

            if self.gpu == 0:
                if batch_idx % 100 == 0:
                    self.logger.info("epoch: {} train_loss: {}".format(
                        epoch, total_loss))

            if batch_idx == 0:
                for key in cur_losses:
                    if not hasattr(self, key):
                        setattr(self, key, AverageMeter())

            # self._update_losses has already implemented synchronized DDP
            self._update_losses(cur_losses)

            self._compute_metrics(outputs, target_l, target_ul, epoch - 1)

            if self.gpu == 0:
                logs = self._log_values(cur_losses)

                if batch_idx % self.log_step == 0:
                    self.wrt_step = (epoch - 1) * len(
                        self.unsupervised_loader) + batch_idx
                    self._write_scalars_tb(logs)

                # if batch_idx % int(len(self.unsupervised_loader)*0.9) == 0:
                #     self._write_img_tb(input_l, target_l, input_ul, target_ul, outputs, epoch)

                descrip = 'T ({}) | '.format(epoch)
                for key in cur_losses:
                    descrip += key + ' {:.2f} '.format(
                        getattr(self, key).average)
                descrip += 'm1 {:.2f} m2 {:.2f}|'.format(
                    self.mIoU_l, self.mIoU_ul)
                tbar.set_description(descrip)

            del input_l, target_l, input_ul, target_ul
            del total_loss, cur_losses, outputs

            self.lr_scheduler.step(epoch=epoch - 1)

        return logs if self.gpu == 0 else None
Example #12
class MnistAgent:
    def __init__(self, config):

        self.config = config
        self.logger = logging.getLogger("Agent")
        # define models
        self.model = config.model

        # define data_loader
        bs = config.batch_size
        num_workers = config["num_workers"] if "num_workers" in config else 8
        data_root = config.data_root
        size = config.image_size
        images_per_class = config.images_per_class
        specific_classes = config[
            'specific_classes'] if 'specific_classes' in config else None

        # define loss
        loss_name = config.loss.name if "loss" in config else "cross_entropy"

        if loss_name == "cross_entropy":
            self.loss = nn.CrossEntropyLoss()
        elif loss_name == "smooth_svm":
            tau = float(config.loss.tau)
            alpha = float(config.loss.alpha)
            k = int(config.loss.k)
            self.loss = SmoothSVM(config.num_classes, alpha, tau, k)

        if specific_classes:
            self.loss = nn.CrossEntropyLoss()

        # define optimizer
        if self.config.optim == "SGD":
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.config.learning_rate,
                                       nesterov=self.config.nesterov,
                                       momentum=self.config.momentum,
                                       weight_decay=self.config.w_decay)
        elif self.config.optim == "ADAM":
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.config.learning_rate,
                                        betas=(0.9, 0.99))

        self.acum_batches = int(
            config["acum_batches"]) if "acum_batches" in config else 0

        # initialize counter
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_metric = 0

        self.summary_writer = None

        # set cuda flag
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda & self.config.cuda

        # set the manual seed for torch
        if self.cuda:
            self.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            self.model = self.model.to(self.device)
            self.loss = self.loss.to(self.device)

            self.logger.info("Program will run on *****GPU-CUDA***** ")
            #            print_cuda_statistics()
        else:
            self.device = torch.device("cpu")
            self.logger.info("Program will run on *****CPU*****\n")

        # Model Loading from the latest checkpoint if not found start from scratch.
        if self.config.checkpoint:
            self.load_checkpoint(self.config.checkpoint)

        input_channels = int(
            config.input_channels) if "input_channels" in config else 1
        prob_drop_stroke = float(
            config.prob_drop_stroke) if "prob_drop_stroke" in config else 0.0
        self.input_channels = input_channels
        self.train_data_loader, self.train_dataset = dataset.train(
            data_root,
            bs,
            images_per_class,
            size,
            num_workers,
            input_channels=input_channels,
            prob_drop_stroke=prob_drop_stroke,
            specific_folders=specific_classes)
        self.valid_data_loader, self.valid_dataset = dataset.valid(
            data_root,
            bs,
            size,
            num_workers,
            input_channels=input_channels,
            specific_folders=specific_classes)

        if "scheduler" in self.config:
            if self.config.scheduler.type == "Cyclic":
                self.scheduler = CyclicScheduler(
                    self.optimizer,
                    min_lr=config.learning_rate,
                    max_lr=config.scheduler.max_lr,
                    period=config.scheduler.period,
                    warm_start=config.scheduler.warm_start)
            elif self.config.scheduler.type == "Cyclic2":
                epoch_step = config.scheduler.step_size

                minibatches_in_epoch = int(len(self.train_dataset) / bs)
                steps_up = int(minibatches_in_epoch * epoch_step / 2.0)

                self.scheduler = CyclicLR(
                    self.optimizer,
                    base_lr=config.learning_rate,
                    max_lr=config.scheduler.max_lr,
                    step_size_up=steps_up,
                    mode=config.scheduler.mode,
                    gamma=config.scheduler.gamma,
                )
            elif self.config.scheduler.type == "LROnPlateau":
                patience = self.config.scheduler.patience
                factor = self.config.scheduler.factor
                min_lr = self.config.scheduler.min_lr
                threshold = self.config.scheduler.threshold
                self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    self.optimizer,
                    patience=patience,
                    factor=factor,
                    min_lr=min_lr,
                    threshold=threshold)
        else:
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, patience=1)

        # Summary Writer
        run_name = f'{config.exp_name}'
        self.summary_writer = SummaryWriter(self.config.summary_dir, run_name)

    def load_checkpoint(self, filename="checkpoint.pth.tar"):
        """
        Latest checkpoint loader
        :param file_name: name of the checkpoint file
        :return:
        """
        checkpoint = torch.load(filename)
        self.model.load_state_dict(checkpoint)

    def save_checkpoint(self, file_name=None, is_best=0):
        """
        Checkpoint saver
        :param file_name: name of the checkpoint file
        :param is_best: boolean flag to indicate whether current checkpoint's accuracy is the best so far
        :return:
        """
        if not file_name:
            file_name = f"{self.current_epoch}_{self.current_iteration}.pth"
        PATH = join(self.config.checkpoint_dir, file_name)
        if not self.config.dry_run:
            torch.save(self.model.state_dict(), PATH)
            self.logger.info(f"model saved in {PATH}")

    def run(self):
        """
        The main operator
        :return:
        """
        try:
            self.train()

        except KeyboardInterrupt:
            self.finalize()
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def run_validation_only(self):
        print("Validation oonlyyyy")

        self.validate()

    def run_create_concrete(self):
        path = "./configs/pairs/"
        config_files = [
            join(path, f) for f in listdir(path) if isfile(join(path, f))
            if ".py" not in f and "generic" not in f and "__" not in f
        ]

        for config_file in config_files:
            print("MMMMMMMMMMMMMMMMMMMMMMMMMM", config_file)
            _config = process_config(config_file, create_folders=False)

            data_root = join(_config.data_root, "valid")

            ds = dataset.Dataset(data_root,
                                 specific_folders=_config.specific_classes)
            classes = ds.classes

            ## Find checkpoint
            print(_config.exp_name)
            checkpoint_file = glob(
                f"./experiments/{_config.exp_name}*/checkpoints/killed*.pth"
            )[0]
            self.load_checkpoint(checkpoint_file)
            print("CHECKPOINT!", checkpoint_file)

            input_channels = int(_config.input_channels)

            loader, test_ds = dataset.test(
                "./input/quickdraw/test_simplified.csv",
                200,
                _config.image_size,
                num_workers=8,
                input_channels=input_channels)

            self.validate_concrete("test", classes, loader, ds)

            ## valid
            ds = self.valid_dataset
            loader = self.valid_data_loader

            self.validate_concrete("valid", classes, loader, ds)

    def train(self):
        """
        Main training loop
        :return:
        """
        for epoch in range(1, self.config.max_epoch + 1):
            self.train_one_epoch()
            self.validate()

            self.current_epoch += 1

    def train_one_epoch(self):
        """
        One epoch of training
        :return:
        """
        self.loss_train_avg = AverageMeter()
        self.mapk_train_avg = AverageMeter()

        self.model.train()

        count = 0

        def logging(output, target, loss, batch_idx):
            mapk3_metric = mapk3(output, target)
            self.mapk_train_avg.update(mapk3_metric)

            self.logger.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tMapk3: {:.6f}'
                .format(self.current_epoch, batch_idx * len(data),
                        len(self.train_data_loader.dataset),
                        100. * batch_idx / len(self.train_data_loader),
                        loss.item(), mapk3_metric))
            if self.summary_writer:
                iteration = self.current_epoch * 1000 + int(
                    1000. * batch_idx / len(self.train_data_loader))

                self.summary_writer.add_scalar("train_loss", loss.item(),
                                               iteration)
                self.summary_writer.add_scalar("train_mapk", mapk3_metric,
                                               iteration)
                self.summary_writer.add_scalar(
                    "lr", self.optimizer.param_groups[0]['lr'], iteration)

        for batch_idx, (data, target) in enumerate(self.train_data_loader):
            if count <= 0:
                # apply gradients accumulated over the previous `acum_batches`
                # batches (a no-op on the very first batch) and reset them
                self.optimizer.step()
                self.optimizer.zero_grad()
                count = self.acum_batches
            data, target = data.to(self.device), target.to(self.device)

            output = self.model(data)

            loss = self.loss(output, target)

            ## Logging and gradient accumulation
            if batch_idx % self.config.log_interval == 0:
                logging(output, target, loss, batch_idx)

            if self.acum_batches >= 2:
                loss = loss / self.acum_batches

            loss.backward()

            self.loss_train_avg.update(loss.item())

            self.current_iteration += 1
            count = count - 1
            if self.scheduler and "step_batch" in dir(self.scheduler):
                self.scheduler.step_batch(self.current_iteration)
        self.save_checkpoint()

    def validate_concrete(self, mode, classes, loader, ds):
        self.model.eval()

        r = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(loader):
                print(batch_idx, len(loader))
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)

                x = output.max(1, keepdim=True)[1].cpu().numpy()
                r = r + [x]

        df = pd.DataFrame(np.concatenate(r))
        df[0] = df[0].apply(lambda x: classes[x])
        df.to_csv(f"./concrete/{mode}-{classes[0]}-{classes[1]}.csv",
                  header=False)

    def validate(self, step=False, calc_confusion=False):
        """
        One cycle of model validation
        :return:
        """
        self.model.eval()
        correct = 0
        loss_avg = AverageMeter()
        mapk_avg = AverageMeter()

        if calc_confusion:
            num_categories = 340
            confusion = np.zeros((num_categories, num_categories),
                                 dtype=np.float32)

        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                print(
                    f"validation calc: {batch_idx} of {len(self.valid_data_loader)}"
                )

                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)

                if calc_confusion:
                    s_output = F.softmax(output, dim=1)
                    _, prediction_categories = s_output.topk(3,
                                                             dim=1,
                                                             sorted=True)

                    for bpc, bc in zip(prediction_categories[:, 0], target):
                        confusion[bpc, bc] += 1

                loss_value = self.loss(output,
                                       target).item()  # per-batch loss value
                loss_avg.update(loss_value)

                pred = output.max(1, keepdim=True)[
                    1]  # get the index of the max log-probability
                #                pred = refine(pred.cpu().numpy(), "valid", batch_idx * self.config.batch_size)
                #                pred = torch.from_numpy(pred).cuda()

                correct += pred.eq(target.view_as(pred)).sum().item()

                mapk3_metric = mapk3(output, target, "valid",
                                     batch_idx * self.config.batch_size)
                mapk_avg.update(mapk3_metric)

        if hasattr(self, "loss_train_avg") and hasattr(self, "mapk_train_avg"):
            self.logger.info("Epoch, LossV, LossT, Mapk3V, Mapk3T")
            self.logger.info(
                "| {} | {:.4f} | {:.4f} | {:.4f} | {:.4f} | {}".format(
                    self.current_epoch, loss_avg.val, self.loss_train_avg.val,
                    mapk_avg.val, self.mapk_train_avg.val, step))

            if self.summary_writer:
                iteration = (self.current_epoch + 1) * 1000

                self.summary_writer.add_scalar("valid_loss", loss_avg.val,
                                               iteration)
                self.summary_writer.add_scalar("valid_mapk", mapk_avg.val,
                                               iteration)

            if self.scheduler and "step" in dir(self.scheduler):
                old_lr = self.optimizer.param_groups[0]['lr']
                self.scheduler.step(loss_avg.val)
                new_lr = self.optimizer.param_groups[0]['lr']
                if old_lr > new_lr:
                    self.logger.info(f"Changing LR from {old_lr} to {new_lr}")

        self.model.train()

        if calc_confusion:
            for c in range(confusion.shape[0]):
                category_count = confusion[c, :].sum()
                if category_count != 0:
                    confusion[c, :] /= category_count

            np.save("confusion.np", confusion)

    def test(self):
        self.model.eval()

        test_data_loader, test_dataset = dataset.test(
            self.config.testcsv,
            self.config.batch_size,
            self.config.image_size,
            num_workers=8,
            input_channels=self.input_channels)

        self.load_checkpoint(self.config.checkpoint)

        cls_to_idx = {
            cls: idx
            for idx, cls in enumerate(self.train_dataset.classes)
        }
        print(cls_to_idx)
        idx_to_cls = {cls_to_idx[c]: c for c in cls_to_idx}

        def row2string(r):
            v = [r[-1], r[-2], r[-3]]
            v = [v.item() for v in v]
            v = map(lambda v: v if v < 340 else 0, v)

            v = [idx_to_cls[v].replace(' ', '_') for v in v]

            return ' '.join(v)

        labels = []
        key_ids = []
        outputs = []

        with torch.no_grad():
            for idx, (data, target) in enumerate(test_data_loader):
                data = data.to(self.device)
                output = self.model(data)

                n = output.detach().cpu().numpy()
                outputs.append(n)

                order = np.argsort(n, 1)[:, -3:]
                #                order = refine(order, "test", idx * self.config.batch_size)

                predicted_y = [row2string(o) for o in order]
                labels = labels + predicted_y
                key_ids = key_ids + target.numpy().tolist()

                if idx % 10 == 0:
                    print(f"{idx} of  {len(test_data_loader)}")

        import pickle

        with open('labels', 'wb') as fp:
            pickle.dump(key_ids, fp)
        with open('idx_to_cls', 'wb') as fp:
            pickle.dump(idx_to_cls, fp)

        _all = np.concatenate(outputs)
        np.save("output.npy", _all)

        d = {'key_id': key_ids, 'word': labels}
        df = pd.DataFrame.from_dict(d)
        df.to_csv('submission.csv', index=False)

    def finalize(self):
        """
        Finalizes the two main components of the process: the operator and the data loader
        :return:
        """
        self.logger.info("saving before finalize")
        file_name = f"killed_{self.current_epoch}_{self.current_iteration}.pth"
        self.save_checkpoint(file_name=file_name)
        self.logger.info(f"saved! with name {file_name}")
Beispiel #14
0
    def validate_epoch(self, epoch=0, store=False):
        self.logger.show_nl("Epoch: [{0}]".format(epoch))
        losses = AverageMeter()
        smooth_losses = AverageMeter()
        image_losses = AverageMeter()
        label_losses = AverageMeter()
        len_val = len(self.val_loader)
        pb = tqdm(self.val_loader)

        self.model.eval()

        with torch.no_grad():
            for i, (name, _, hsi) in enumerate(pb):
                hsi = hsi.to(self.device)
                sens = self.sens_list[i].to(self.device)
                img_real = create_rgb(sens, hsi)

                pred = self.model(img_real)

                img_pred = create_rgb(pred, hsi)

                smooth_loss = self.smooth_criterion(pred, pred.size(0))
                image_loss = self.image_criterion(img_pred, img_real)
                label_loss = self.label_criterion(pred, sens)
                loss = self.calc_total_loss(image_loss, label_loss, smooth_loss)

                losses.update(loss.item(), n=self.batch_size)
                image_losses.update(image_loss.item(), n=self.batch_size)
                label_losses.update(label_loss.item(), n=self.batch_size)
                smooth_losses.update(smooth_loss.item(), n=self.batch_size)

                # img_pred = to_array(img_pred[0])
                # img_real = to_array(img_real[0])

                for m in self.metrics:
                    m.update(img_pred, img_real)

                desc = self.logger.make_desc(
                    i + 1, len_val,
                    ('loss', losses, '.4f'),
                    ('IL', image_losses, '.4f'),
                    ('LL', label_losses, '.4f'),
                    ('SL', smooth_losses, '.4f'),
                    *(
                        (m.__name__, m, '.4f')
                        for m in self.metrics
                    )
                )

                pb.set_description(desc)
                self.logger.dump(desc)

                # @zjw: tensorboard
                self.logger.add_scalar('Estimator-Loss/validate/total_losses', losses.val, epoch * len_val + i)
                self.logger.add_scalar('Estimator-Loss/validate/image_losses', image_losses.val, epoch * len_val + i)
                self.logger.add_scalar('Estimator-Loss/validate/label_losses', label_losses.val, epoch * len_val + i)
                self.logger.add_scalar('Estimator-Loss/validate/smooth_losses', smooth_losses.val, epoch * len_val + i)
                for m in self.metrics:
                    self.logger.add_scalar('Estimator-validate/metrics/' + m.__name__, m.val, epoch * len_val + i)

                if store:
                    self.logger.add_images('Estimator-validate/real-pred', torch.cat((img_real, img_pred), dim=3),
                                           epoch * len_val + i)
                    self.save_image_tensor(self.gpc.add_suffix(name[0], suffix='real-pred', underline=True),
                                    torch.cat((img_real, img_pred), dim=3), epoch)
                    # self.save_image(self.gpc.add_suffix(name[0], suffix='pred', underline=True), (img_pred*255).astype('uint8'), epoch)
                    # self.save_image(self.gpc.add_suffix(name[0], suffix='real', underline=True), (img_real*255).astype('uint8'), epoch)

        return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc)
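
Both sensitivity-estimation snippets (this validate_epoch and the train_epoch below) call create_rgb(sens, hsi) to render an RGB image from a camera sensitivity function and a hyperspectral cube. A hypothetical sketch of that projection, assuming sens has shape (3, num_bands) and hsi has shape (batch, num_bands, H, W); the real tensor layout in the original project may differ:

import torch

def create_rgb(sens, hsi):
    # weight each spectral band by the per-channel sensitivity and integrate over bands:
    # (3, bands) x (batch, bands, H, W) -> (batch, 3, H, W)
    rgb = torch.einsum('cb,nbhw->nchw', sens, hsi)
    # keep the reconstruction in [0, 1], as the training code expects
    return rgb.clamp(0.0, 1.0)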
Beispiel #15
0
    def train_epoch(self, epoch=0):
        losses = AverageMeter()
        smooth_losses = AverageMeter()
        image_losses = AverageMeter()
        label_losses = AverageMeter()
        len_train = len(self.train_loader)
        pb = tqdm(self.train_loader)

        self.model.train()

        for i, (_, hsi) in enumerate(pb):
            # sens for sensitivity and hsi for hyperspectral image
            hsi = hsi.to(self.device)
            sens = create_sensitivity('C').to(self.device)

            # The reconstructed RGB in the range [0,1]
            img_real = create_rgb(sens, hsi)

            pred = self.model(img_real)

            # Reconstruct RGB from sensitivity function and HSI
            img_pred = create_rgb(pred, hsi)

            smooth_loss = self.smooth_criterion(pred, pred.size(0))
            image_loss = self.image_criterion(img_pred, img_real)
            label_loss = self.label_criterion(pred, sens)
            loss = self.calc_total_loss(image_loss, label_loss, smooth_loss)

            losses.update(loss.item(), n=self.batch_size)
            image_losses.update(image_loss.item(), n=self.batch_size)
            label_losses.update(label_loss.item(), n=self.batch_size)
            smooth_losses.update(smooth_loss.item(), n=self.batch_size)

            # Compute gradients and do SGD step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            desc = self.logger.make_desc(
                i + 1, len_train,
                ('loss', losses, '.4f'),
                ('IL', image_losses, '.4f'),
                ('LL', label_losses, '.4f'),
                ('SL', smooth_losses, '.4f')
            )

            pb.set_description(desc)
            self.logger.dump(desc)

            # @zjw: tensorboard
            self.logger.add_scalar('Estimator-Loss/train/total_losses', losses.val, epoch * len_train + i)
            self.logger.add_scalar('Estimator-Loss/train/image_losses', image_losses.val, epoch * len_train + i)
            self.logger.add_scalar('Estimator-Loss/train/label_losses', label_losses.val, epoch * len_train + i)
            self.logger.add_scalar('Estimator-Loss/train/smooth_losses', smooth_losses.val, epoch * len_train + i)
            self.logger.add_scalar('Estimator-Lr', self.optimizer.param_groups[0]['lr'], epoch * len_train + i)
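
calc_total_loss combines the image, label and smoothness terms into the scalar that is backpropagated. A plausible sketch is a weighted sum; the weight names image_w, label_w and smooth_w are hypothetical and would normally come from the configuration:

def calc_total_loss(self, image_loss, label_loss, smooth_loss,
                    image_w=1.0, label_w=1.0, smooth_w=1.0):
    # weighted combination of the three criteria used in train_epoch/validate_epoch
    return image_w * image_loss + label_w * label_loss + smooth_w * smooth_loss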
Beispiel #16
0
    def train_one_epoch(self):
        tqdm_batch = tqdm(self.dataloader,
                          total=self.dataset.num_iterations,
                          desc="epoch-{}-".format(self.current_epoch))

        self.model.train()
        self.phrase_model.train()

        epoch_loss = AverageMeter()
        epoch_phrase_loss = AverageMeter()

        for curr_it, (note, pre_note, pre_phrase,
                      position) in enumerate(tqdm_batch):
            if self.cuda:
                note = note.cuda(non_blocking=self.config.async_loading)
                pre_note = pre_note.cuda(non_blocking=self.config.async_loading)
                pre_phrase = pre_phrase.cuda(non_blocking=self.config.async_loading)
                position = position.cuda(non_blocking=self.config.async_loading)

            note = Variable(note)
            pre_note = Variable(pre_note)
            pre_phrase = Variable(pre_phrase)
            position = Variable(position)

            ####################
            self.model.zero_grad()
            self.phrase_model.zero_grad()

            #################### Generator ####################
            self.free(self.model)
            self.frozen(self.phrase_model)

            phrase_feature, _, _ = self.phrase_model(pre_phrase, position)
            gen_note, mean, var, pre_mean, pre_var, _, _ = self.model(
                note, pre_note, phrase_feature)

            gen_loss = self.loss(gen_note, note, mean, var, pre_mean, pre_var)
            gen_loss.backward(retain_graph=True)
            self.optimVAE.step()

            #################### Phrase Encoder ####################
            self.free(self.phrase_model)
            self.frozen(self.model)

            phrase_feature, mean, var = self.phrase_model(pre_phrase, position)
            gen_note, _, _, _, _, _, _ = self.model(note, pre_note,
                                                    phrase_feature)

            phrase_loss = self.phrase_loss(gen_note, note, mean, var)
            phrase_loss.backward(retain_graph=True)
            self.optim_phrase.step()

            ####################
            epoch_loss.update(gen_loss.item())
            epoch_phrase_loss.update(phrase_loss.item())

            self.current_iteration += 1

            self.summary_writer.add_scalar("epoch/Generator_loss",
                                           epoch_loss.val,
                                           self.current_iteration)
            self.summary_writer.add_scalar("epoch/PhrasseEncoder_loss",
                                           epoch_phrase_loss.val,
                                           self.current_iteration)

        tqdm_batch.close()
        self.scheduler.step(epoch_loss.val)
        self.scheduler_phrase.step(epoch_phrase_loss.val)

        self.logger.info("Training at epoch-" + str(self.current_epoch) +
                         " | " + "Discriminator loss: " +
                         " - Generator Loss-: " + str(epoch_loss.val))

        if epoch_loss.val < self.best_error:
            self.best_error = epoch_loss.val
            return True, epoch_loss.val
        else:
            return False, epoch_loss.val
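
The alternating optimization above depends on free/frozen helpers that decide which sub-network receives gradients in each phase. A minimal sketch, assuming they simply toggle requires_grad on the module parameters:

def free(self, module):
    # let gradients flow into this sub-network
    for p in module.parameters():
        p.requires_grad = True

def frozen(self, module):
    # block gradient updates for this sub-network
    for p in module.parameters():
        p.requires_grad = False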
Beispiel #17
0
    def train(self, src_loader, tar_loader, val_loader, test_loader):

        num_batches = len(src_loader)
        print_freq = max(num_batches // self.config.training_num_print_epoch,
                         1)
        i_iter = self.start_iter
        start_epoch = i_iter // num_batches
        num_epochs = self.config.num_epochs
        best_acc = 0
        for epoch in range(start_epoch, num_epochs):
            self.model.train()
            batch_time = AverageMeter()
            losses = AverageMeter()

            # adjust learning rate
            self.scheduler.step()

            for it, src_batch in enumerate(src_loader):
                t = time.time()

                self.optimizer.zero_grad()
                src = src_batch
                src = to_device(src, self.device)
                src_imgs, src_cls_lbls, src_aux_imgs, src_aux_lbls = src

                src_main_logits = self.model(src_imgs, 'main_task')
                src_main_loss = self.class_loss_func(src_main_logits,
                                                     src_cls_lbls)
                loss = src_main_loss * self.config.loss_weight['main_task']

                loss.backward()
                self.optimizer.step()

                losses.update(loss.item(), src_imgs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - t)

                i_iter += 1

                print_string = 'Epoch {:>2} | iter {:>4} | loss: {:.3f} | src_main: {:.3f} | {:4.2f} s/it'

                self.logger.info(
                    print_string.format(epoch, i_iter, losses.avg,
                                        src_main_loss.item(), batch_time.avg))
                self.writer.add_scalar('losses/all_loss', losses.avg, i_iter)
                self.writer.add_scalar('losses/src_main_loss', src_main_loss,
                                       i_iter)
            # del loss, src_class_loss, src_aux_loss, tar_aux_loss, tar_entropy_loss
            # del src_aux_logits, src_class_logits
            # del tar_aux_logits, tar_class_logits

            # validation
            self.save(self.config.model_dir, i_iter)

            if val_loader is not None:
                self.logger.info('validating...')
                class_acc = self.test(val_loader)
                # self.writer.add_scalar('val/aux_acc', class_acc, i_iter)
                self.writer.add_scalar('val/class_acc', class_acc, i_iter)
                if class_acc > best_acc:
                    best_acc = class_acc
                    self.save(self.config.best_model_dir, i_iter)
                    # todo copy current model to best model
                self.logger.info(
                    'Best validation accuracy: {:.2f} %'.format(best_acc))

            if test_loader is not None:
                self.logger.info('testing...')
                class_acc = self.test(test_loader)
                # self.writer.add_scalar('test/aux_acc', class_acc, i_iter)
                self.writer.add_scalar('test/class_acc', class_acc, i_iter)
                if class_acc > best_acc:
                    best_acc = class_acc
                    # todo copy current model to best model
                self.logger.info(
                    'Best testing accuracy: {:.2f} %'.format(best_acc))

        self.logger.info('Best testing accuracy: {:.2f} %'.format(best_acc))
        self.logger.info('Finished Training.')
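
to_device(src, self.device) moves a whole batch (here a tuple of images and labels) onto the training device in one call. A minimal sketch that handles tensors as well as lists/tuples of tensors, which is all these batches appear to contain:

import torch

def to_device(batch, device):
    # recursively move tensors (and containers of tensors) to the given device
    if torch.is_tensor(batch):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(to_device(b, device) for b in batch)
    return batch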
Beispiel #18
0
def validate_eval(val_loader, model, criterion, args, epoch=None, fnames=[]):
	batch_time = AverageMeter()
	losses = AverageMeter()
	top1 = AverageMeter()

	model.eval()

	end = time.time()
	scores = np.zeros((len(val_loader.dataset), args.num_classes))
	labels = np.zeros((len(val_loader.dataset), ))
	for i, (union, obj1, obj2, bpos, target, _, _, _) in enumerate(val_loader):
		with torch.no_grad():
			target = target.cuda(non_blocking=True)
			union = union.cuda()
			obj1 = obj1.cuda()
			obj2 = obj2.cuda()
			bpos = bpos.cuda()

			output, _ = model(union, obj1, obj2, bpos)
			
			loss = criterion(output, target)
			losses.update(loss.item(), union.size(0))
			prec1 = accuracy(output.data, target)
			top1.update(prec1[0], union.size(0))

			batch_time.update(time.time() - end)
			end = time.time()

			# Record scores.
			output_f = F.softmax(output, dim=1)  # To [0, 1]
			output_np = output_f.data.cpu().numpy()
			labels_np = target.data.cpu().numpy()
			b_ind = i*args.batch_size
			e_ind = b_ind + min(args.batch_size, output_np.shape[0])
			scores[b_ind:e_ind, :] = output_np
			labels[b_ind:e_ind] = labels_np
	
	print('Test [Epoch {0}/{1}]:  '
		'*Time {2:.2f}mins ({batch_time.avg:.2f}s)  '
		'*Loss {loss.avg:.4f}  '
		'*Prec@1 {top1.avg:.3f}'.format(epoch, args.epoch, batch_time.sum/60,
			batch_time=batch_time, top1=top1, loss=losses))

	res_scores = multi_scores(scores, labels, ['precision', 'recall', 'average_precision'])
	return top1.avg, losses.avg, res_scores['precision'], res_scores['recall'], res_scores['average_precision']
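
accuracy(output.data, target) above (and cls_accuracy in earlier snippets) computes top-k classification accuracy. A sketch of the standard recipe, returning a percentage per requested k, which matches the prec1[0] indexing used above:

import torch

def accuracy(output, target, topk=(1,)):
    """Compute precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res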
Beispiel #19
0
    def single_train(self, iter):
        self.model.train()
        running_loss = AverageMeter()
        running_aux_loss = AverageMeter()
        running_mas_loss = AverageMeter()
        running_iou = AverageMeter()
        running_pixelacc = AverageMeter()  # pixel accuracy
        count = 0
        # local_state_dict = deepcopy(self.best_state_dict)
        # pre_loss = numpy.inf
        print(f"{iter} / {self.iters} ==== train\n")
        for input in self.train_dataloader:
            data, label = input
            N, C, H, W = data.size()
            if torch.cuda.is_available() and self.gpu:
                data = data.cuda(self.gpu[0])
                label = label.cuda(self.gpu[0])
            label = label.long()
            self.optimizer.zero_grad()
            label = label.view(label.shape[0], *label.shape[2:])
            mas_o, aux_o = self.model(data)
            mas_o_m = mas_o.argmax(dim=1)
            self.metricer.loadData(mas_o_m.cpu().numpy(), label.cpu().numpy())

            # aux_o = aux_o.view(-1,self.numclass)
            # mas_o = mas_o.view(-1,self.numclass)
            # label = label.view(-1)

            aux_loss = self.crition(aux_o, label)
            mas_loss = self.crition(mas_o, label)
            mas_loss, aux_loss = torch.mean(mas_loss), torch.mean(aux_loss)
            aux_weight = self.config['aux_weight']
            loss = aux_loss * aux_weight + mas_loss
            # batch = N*H*W
            running_loss.update(loss.item(), N * H * W)
            running_aux_loss.update(aux_loss.item(), N * H * W)
            running_mas_loss.update(mas_loss.item(), N * H * W)
            running_pixelacc.update(self.metricer.pixelAccuracy())
            running_iou.update(self.metricer.meanIntersectionOverUnion())

            loss.backward()
            self.optimizer.step()
            if count % self.print_freq == self.print_freq - 1:
                print(
                    f"[step {count//self.print_freq}] --- aux_loss:{running_aux_loss.val}  "
                    f"mas_loss:{running_mas_loss.val}  "
                    f"loss:{running_loss.val}  "
                    f"pixelacc:{running_pixelacc.val}  "
                    f"meaniou:{running_iou.val}")

                self.writer.add_scalar(
                    "aux_loss", running_aux_loss.val,
                    iter * len(self.train_dataloader) + count)
                self.writer.add_scalar(
                    "mas_loss", running_mas_loss.val,
                    iter * len(self.train_dataloader) + count)
                self.writer.add_scalar(
                    "loss", running_loss.val,
                    iter * len(self.train_dataloader) + count)

                self.writer.add_scalar(
                    "pixel_acc", running_pixelacc.val,
                    iter * len(self.train_dataloader) + count)
                self.writer.add_scalar(
                    "mean_iou", running_iou.val,
                    iter * len(self.train_dataloader) + count)

                # TODO: minimum IoU
                # if running_loss.val < self.min_loss and \
                #             running_pixelacc.val > self.best_acc:
                #     self.min_loss = running_loss.val
                #     self.best_acc = running_pixelacc.val
                #     local_state_dict = deepcopy(self.model.state_dict())
            torch.cuda.empty_cache()
            count += 1
            self.metricer.reset()

        # self.best_state_dict = local_state_dict
        print(
            f"train result at epoch [{iter}/{self.iters}] :mIou/mAcc: {running_iou.avg}/{running_pixelacc.avg}"
        )
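
self.metricer accumulates a confusion matrix from predictions and labels and exposes pixelAccuracy() and meanIntersectionOverUnion(). A minimal sketch of such a helper (here called SegMetric, a hypothetical name), assuming integer class maps and no ignore label; the project's own implementation may differ:

import numpy as np

class SegMetric:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        self.confusion = np.zeros((self.num_classes, self.num_classes), dtype=np.int64)

    def loadData(self, pred, label):
        # accumulate a (num_classes x num_classes) confusion matrix
        mask = (label >= 0) & (label < self.num_classes)
        idx = self.num_classes * label[mask].astype(int) + pred[mask].astype(int)
        self.confusion += np.bincount(idx, minlength=self.num_classes ** 2).reshape(
            self.num_classes, self.num_classes)

    def pixelAccuracy(self):
        return np.diag(self.confusion).sum() / max(self.confusion.sum(), 1)

    def meanIntersectionOverUnion(self):
        inter = np.diag(self.confusion)
        union = self.confusion.sum(1) + self.confusion.sum(0) - inter
        return np.mean(inter / np.maximum(union, 1))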
Beispiel #20
0
    def train_one_epoch(self):
        # initialize tqdm batch
        tqdm_batch = tqdm(self.trainloader.loader,
                          total=self.trainloader.num_iterations,
                          desc="epoch-{}-".format(self.current_epoch))

        self.generator.train()
        self.discriminator.train()
        epoch_loss_gen = AverageMeter()
        epoch_loss_dis = AverageMeter()
        epoch_loss_ce = AverageMeter()
        epoch_loss_unlab = AverageMeter()
        epoch_loss_fake = AverageMeter()

        for curr_it, (patches_lab, patches_unlab,
                      labels) in enumerate(tqdm_batch):
            #y = torch.full((self.batch_size,), self.real_label)
            if self.cuda:
                patches_lab = patches_lab.cuda()
                patches_unlab = patches_unlab.cuda()
                labels = labels.cuda()

            patches_lab = Variable(patches_lab)
            patches_unlab = Variable(patches_unlab.float())
            labels = Variable(labels).long()

            noise_vector = torch.tensor(
                np.random.uniform(
                    -1, 1,
                    [self.config.batch_size, self.config.noise_dim])).float()
            if self.cuda:
                noise_vector = noise_vector.cuda()
            patches_fake = self.generator(noise_vector)

            ## Discriminator
            # Supervised loss
            lab_output, lab_output_sofmax = self.discriminator(patches_lab)
            lab_loss = self.criterion(lab_output, labels)

            unlab_output, unlab_output_softmax = self.discriminator(
                patches_unlab)
            fake_output, fake_output_softmax = self.discriminator(
                patches_fake.detach())

            # Unlabeled Loss and Fake loss
            unlab_lsp = torch.logsumexp(unlab_output, dim=1)
            fake_lsp = torch.logsumexp(fake_output, dim=1)
            unlab_loss = -0.5 * torch.mean(unlab_lsp) + 0.5 * torch.mean(
                F.softplus(unlab_lsp, 1))
            fake_loss = 0.5 * torch.mean(F.softplus(fake_lsp, 1))
            discriminator_loss = lab_loss + unlab_loss + fake_loss

            self.d_optim.zero_grad()
            discriminator_loss.backward()
            self.d_optim.step()

            ## Generator
            _, _, unlab_feature = self.discriminator(patches_unlab,
                                                     get_feature=True)
            _, _, fake_feature = self.discriminator(patches_fake,
                                                    get_feature=True)

            # Feature matching loss
            unlab_feature, fake_feature = torch.mean(unlab_feature,
                                                     0), torch.mean(
                                                         fake_feature, 0)
            fm_loss = torch.mean(torch.abs(unlab_feature - fake_feature))

            # Variational Inference loss
            mu, log_sigma = self.encoder(patches_fake)
            vi_loss = gaussian_nll(mu, log_sigma, noise_vector)

            generator_loss = fm_loss + self.config.vi_loss_weight * vi_loss

            self.g_optim.zero_grad()
            self.e_optim.zero_grad()
            generator_loss.backward()
            self.g_optim.step()
            self.e_optim.step()

            epoch_loss_gen.update(generator_loss.item())
            epoch_loss_dis.update(discriminator_loss.item())
            epoch_loss_ce.update(lab_loss.item())
            epoch_loss_unlab.update(unlab_loss.item())
            epoch_loss_fake.update(fake_loss.item())
            self.current_iteration += 1

            print("Epoch: {0}, Iteration: {1}/{2}, Gen loss: {3:.3f}, Dis loss: {4:.3f} :: CE loss {5:.3f}, Unlab loss: {6:.3f}, Fake loss: {7:.3f}, VI loss: {8:.3f}".format(
                                self.current_epoch, self.current_iteration,\
                                self.trainloader.num_iterations, generator_loss.item(), discriminator_loss.item(),\
                                lab_loss.item(), unlab_loss.item(), fake_loss.item(), vi_loss.item()))

        tqdm_batch.close()

        self.logger.info("Training at epoch-" + str(self.current_epoch) + " | " +\
         " Generator loss: " + str(epoch_loss_gen.val) +\
          " Discriminator loss: " + str(epoch_loss_dis.val) +\
           " CE loss: " + str(epoch_loss_ce.val) + " Unlab loss: " + str(epoch_loss_unlab.val) + " Fake loss: " + str(epoch_loss_fake.val))
Beispiel #21
0
    def single_eval(self, iter):
        self.model.eval()
        running_mas_loss = AverageMeter()
        running_iou = AverageMeter()
        running_pixelacc = AverageMeter()  # pixel accuracy
        count = 0
        print(f"{iter} / {self.iters} === eval\n")
        with torch.no_grad():
            for input in self.val_dataloader:
                data, label = input
                N, C, H, W = data.shape
                if torch.cuda.is_available() and self.gpu:
                    data = data.cuda(self.gpu[0])
                    label = label.cuda(self.gpu[0])
                label = label.long()
                label = label.view(label.shape[0], *label.shape[2:])
                mas_o = self.model(data)
                mas_o_m = mas_o.argmax(dim=1)

                self.metricer.loadData(mas_o_m.cpu().numpy(),
                                       label.cpu().numpy())

                mas_loss = self.crition(mas_o, label)
                running_mas_loss.update(mas_loss.item(), N * H * W)
                running_pixelacc.update(self.metricer.pixelAccuracy())
                running_iou.update(self.metricer.meanIntersectionOverUnion())

                if count % self.print_freq == self.print_freq - 1:
                    print(
                        f"[{count//self.print_freq}次]   -----    mas_loss:{running_mas_loss.val}  "
                        f"pixelacc:{running_pixelacc.val}  "
                        f"meaniou:{running_iou.val}")

                    self.writer.add_scalar(
                        "eval_loss", running_mas_loss.val,
                        iter * len(self.val_dataloader) + count)
                    self.writer.add_scalar(
                        "pixel_acc", running_pixelacc.val,
                        iter * len(self.val_dataloader) + count)
                    self.writer.add_scalar(
                        "mean_iou", running_iou.val,
                        iter * len(self.val_dataloader) + count)

                count += 1
                self.metricer.reset()
Beispiel #22
0
def main(cfg, distributed=True):
    if distributed:
        # DDP: initialize the default process group
        dist.init_process_group('nccl')
        # DDP: each process drives its own local GPU
        local_rank = dist.get_rank()
        print(local_rank)
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda', local_rank)
    else:
        device = torch.device("cuda:0")
        local_rank = 0

    ###################################################
    mode = cfg.mode
    n_class = cfg.n_class
    model_path = cfg.model_path  # save model
    log_path = cfg.log_path
    output_path = cfg.output_path

    if local_rank == 0:
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        if not os.path.exists(output_path):
            os.makedirs(output_path)

    task_name = cfg.task_name
    print(task_name)

    ###################################
    print("preparing datasets and dataloaders......")
    batch_size = cfg.batch_size
    sub_batch_size = cfg.sub_batch_size
    size_g = (cfg.size_g, cfg.size_g)
    size_p = (cfg.size_p, cfg.size_p)
    num_workers = cfg.num_workers
    trainset_cfg = cfg.trainset_cfg
    valset_cfg = cfg.valset_cfg

    data_time = AverageMeter("DataTime", ':3.3f')
    batch_time = AverageMeter("BatchTime", ':3.3f')

    transformer_train = TransformerSegGL(crop_size=cfg.crop_size)
    dataset_train = OralDatasetSeg(
        trainset_cfg["img_dir"],
        trainset_cfg["mask_dir"],
        trainset_cfg["meta_file"],
        label=trainset_cfg["label"],
        transform=transformer_train,
    )
    if distributed:
        sampler_train = DistributedSampler(dataset_train, shuffle=True)
        dataloader_train = DataLoader(dataset_train,
                                      num_workers=num_workers,
                                      batch_size=batch_size,
                                      collate_fn=collateGL,
                                      sampler=sampler_train,
                                      pin_memory=True)
    else:
        dataloader_train = DataLoader(dataset_train,
                                      num_workers=num_workers,
                                      batch_size=batch_size,
                                      collate_fn=collateGL,
                                      shuffle=True,
                                      pin_memory=True)
    transformer_val = TransformerSegGLVal()
    dataset_val = OralDatasetSeg(valset_cfg["img_dir"],
                                 valset_cfg["mask_dir"],
                                 valset_cfg["meta_file"],
                                 label=valset_cfg["label"],
                                 transform=transformer_val)
    dataloader_val = DataLoader(dataset_val,
                                num_workers=2,
                                batch_size=batch_size,
                                collate_fn=collateGL,
                                shuffle=False,
                                pin_memory=True)

    ###################################
    print("creating models......")
    path_g = cfg.path_g
    path_g2l = cfg.path_g2l
    path_l2g = cfg.path_l2g
    model = GLNet(n_class, cfg.encoder, **cfg.model_cfg)
    if mode == 3:
        global_fixed = GLNet(n_class, cfg.encoder, **cfg.model_cfg)
    else:
        global_fixed = None
    model, global_fixed = create_model_load_weights(model,
                                                    global_fixed,
                                                    device,
                                                    mode=mode,
                                                    distributed=distributed,
                                                    local_rank=local_rank,
                                                    evaluation=False,
                                                    path_g=path_g,
                                                    path_g2l=path_g2l,
                                                    path_l2g=path_l2g)

    ###################################
    num_epochs = cfg.num_epochs
    learning_rate = cfg.lr

    optimizer = get_optimizer(model, mode, learning_rate=learning_rate)
    scheduler = LR_Scheduler(cfg.scheduler, learning_rate, num_epochs,
                             len(dataloader_train))
    ##################################
    if cfg.loss == "ce":
        criterion = nn.CrossEntropyLoss(reduction='mean')
    elif cfg.loss == "sce":
        criterion = SymmetricCrossEntropyLoss(alpha=cfg.alpha,
                                              beta=cfg.beta,
                                              num_classes=cfg.n_class)
        # criterion4 = NormalizedSymmetricCrossEntropyLoss(alpha=cfg.alpha, beta=cfg.beta, num_classes=cfg.n_class)
    elif cfg.loss == "focal":
        criterion = FocalLoss(gamma=3)
    elif cfg.loss == "ce-dice":
        criterion = nn.CrossEntropyLoss(reduction='mean')
        # criterion2 =

    #######################################
    trainer = Trainer(criterion, optimizer, n_class, size_g, size_p,
                      sub_batch_size, mode, cfg.lamb_fmreg)
    evaluator = Evaluator(n_class, size_g, size_p, sub_batch_size, mode)
    evaluation = cfg.evaluation
    val_vis = cfg.val_vis

    best_pred = 0.0
    print("start training......")

    # log
    if local_rank == 0:
        f_log = open(os.path.join(log_path, ".log"), 'w')
        log = task_name + '\n'
        for k, v in cfg.__dict__.items():
            log += str(k) + ' = ' + str(v) + '\n'
        f_log.write(log)
        f_log.flush()
    # writer
    if local_rank == 0:
        writer = SummaryWriter(log_dir=log_path)
    writer_info = {}

    for epoch in range(num_epochs):
        trainer.set_train(model)
        optimizer.zero_grad()
        tbar = tqdm(dataloader_train)
        train_loss = 0

        start_time = time.time()
        for i_batch, sample in enumerate(tbar):
            data_time.update(time.time() - start_time)
            scheduler(optimizer, i_batch, epoch, best_pred)
            # loss = trainer.train(sample, model)
            loss = trainer.train(sample, model, global_fixed)
            train_loss += loss.item()
            score_train, score_train_global, score_train_local = trainer.get_scores(
            )

            batch_time.update(time.time() - start_time)
            start_time = time.time()

            if i_batch % 20 == 0 and local_rank == 0:
                if mode == 1:
                    tbar.set_description(
                        'Train loss: %.4f;global mIoU: %.4f; data time: %.2f; batch time: %.2f'
                        % (train_loss /
                           (i_batch + 1), score_train_global["iou_mean"],
                           data_time.avg, batch_time.avg))
                elif mode == 2:
                    tbar.set_description(
                        'Train loss: %.4f;agg mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                        % (train_loss / (i_batch + 1), score_train["iou_mean"],
                           score_train_local["iou_mean"], data_time.avg,
                           batch_time.avg))
                else:
                    tbar.set_description(
                        'Train loss: %.4f;agg mIoU: %.4f; global mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                        % (train_loss / (i_batch + 1), score_train["iou_mean"],
                           score_train_global["iouu_mean"],
                           score_train_local["iou_mean"], data_time.avg,
                           batch_time.avg))

        score_train, score_train_global, score_train_local = trainer.get_scores(
        )
        trainer.reset_metrics()
        data_time.reset()
        batch_time.reset()

        if evaluation and epoch % 1 == 0 and local_rank == 0:
            with torch.no_grad():
                model.eval()
                print("evaluating...")
                tbar = tqdm(dataloader_val)

                start_time = time.time()
                for i_batch, sample in enumerate(tbar):
                    data_time.update(time.time() - start_time)
                    predictions, predictions_global, predictions_local = evaluator.eval_test(
                        sample, model, global_fixed)
                    score_val, score_val_global, score_val_local = evaluator.get_scores(
                    )

                    batch_time.update(time.time() - start_time)
                    if i_batch % 20 == 0 and local_rank == 0:
                        if mode == 1:
                            tbar.set_description(
                                'global mIoU: %.4f; data time: %.2f; batch time: %.2f'
                                % (score_val_global["iou_mean"], data_time.avg,
                                   batch_time.avg))
                        elif mode == 2:
                            tbar.set_description(
                                'agg mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                                % (score_val["iou_mean"],
                                   score_val_local["iou_mean"], data_time.avg,
                                   batch_time.avg))
                        else:
                            tbar.set_description(
                                'agg mIoU: %.4f; global mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                                % (score_val["iou_mean"],
                                   score_val_global["iou_mean"],
                                   score_val_local["iou_mean"], data_time.avg,
                                   batch_time.avg))

                    if val_vis and i_batch == len(
                            tbar) // 2:  # val set result visualize
                        mask_rgb = class_to_RGB(np.array(sample['mask'][1]))
                        mask_rgb = ToTensor()(mask_rgb)
                        writer_info.update(mask=mask_rgb,
                                           prediction_global=ToTensor()(
                                               class_to_RGB(
                                                   predictions_global[1])))
                        if mode == 2 or mode == 3:
                            writer_info.update(prediction=ToTensor()(class_to_RGB(
                                predictions[1])),
                                               prediction_local=ToTensor()(
                                                   class_to_RGB(
                                                       predictions_local[1])))

                    start_time = time.time()

                data_time.reset()
                batch_time.reset()
                score_val, score_val_global, score_val_local = evaluator.get_scores(
                )
                evaluator.reset_metrics()

                # save model
                best_pred = save_ckpt_model(model, cfg, score_val,
                                            score_val_global, best_pred, epoch)
                # log
                update_log(
                    f_log, cfg,
                    [score_train, score_train_global, score_train_local],
                    [score_val, score_val_global, score_val_local], epoch)
                # writer
                if mode == 1:
                    writer_info.update(
                        loss=train_loss / len(tbar),
                        lr=optimizer.param_groups[0]['lr'],
                        mIOU={
                            "train": score_train_global["iou_mean"],
                            "val": score_val_global["iou_mean"],
                        },
                        global_mIOU={
                            "train": score_train_global["iou_mean"],
                            "val": score_val_global["iou_mean"],
                        },
                        mucosa_iou={
                            "train": score_train_global["iou"][2],
                            "val": score_val_global["iou"][2],
                        },
                        tumor_iou={
                            "train": score_train_global["iou"][3],
                            "val": score_val_global["iou"][3],
                        },
                    )
                else:
                    writer_info.update(
                        loss=train_loss / len(tbar),
                        lr=optimizer.param_groups[0]['lr'],
                        mIOU={
                            "train": score_train["iou_mean"],
                            "val": score_val["iou_mean"],
                        },
                        global_mIOU={
                            "train": score_train_global["iou_mean"],
                            "val": score_val_global["iou_mean"],
                        },
                        local_mIOU={
                            "train": score_train_local["iou_mean"],
                            "val": score_val_local["iou_mean"],
                        },
                        mucosa_iou={
                            "train": score_train["iou"][2],
                            "val": score_val["iou"][2],
                        },
                        tumor_iou={
                            "train": score_train["iou"][3],
                            "val": score_val["iou"][3],
                        },
                    )

                update_writer(writer, writer_info, epoch)
    if local_rank == 0:
        f_log.close()
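
LR_Scheduler is constructed with (mode, base_lr, num_epochs, iters_per_epoch) and then called as scheduler(optimizer, i_batch, epoch, best_pred) once per batch. A minimal sketch of a per-iteration polynomial decay with that calling convention; the actual class presumably supports the other modes selected by cfg.scheduler:

class LR_Scheduler:
    """Polynomial learning-rate decay, updated once per batch."""

    def __init__(self, mode, base_lr, num_epochs, iters_per_epoch, power=0.9):
        self.mode = mode                  # e.g. 'poly'; other modes not sketched here
        self.base_lr = base_lr
        self.iters_per_epoch = iters_per_epoch
        self.total_iters = num_epochs * iters_per_epoch
        self.power = power

    def __call__(self, optimizer, i_batch, epoch, best_pred=None):
        cur_iter = epoch * self.iters_per_epoch + i_batch
        lr = self.base_lr * (1 - cur_iter / max(self.total_iters, 1)) ** self.power
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr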
Beispiel #23
0
    def _valid_epoch(self, epoch):
        if self.val_loader is None:
            self.logger.warning(
                'No data loader was passed for the validation step; no validation is performed!'
            )
            return {}
        self.logger.info('\n###### EVALUATION ######')

        self.model.eval()
        self.wrt_mode = 'val'
        total_loss_val = AverageMeter()
        total_inter, total_union = 0, 0
        total_correct, total_label = 0, 0

        tbar = tqdm(self.val_loader, ncols=130)
        with torch.no_grad():
            val_visual = []
            for batch_idx, (data, target) in enumerate(tbar):
                target, data = target.cuda(non_blocking=True), data.cuda(
                    non_blocking=True)

                H, W = target.size(1), target.size(2)
                up_sizes = (ceil(H / 8) * 8, ceil(W / 8) * 8)
                pad_h, pad_w = up_sizes[0] - data.size(
                    2), up_sizes[1] - data.size(3)
                data = F.pad(data, pad=(0, pad_w, 0, pad_h), mode='reflect')
                output = self.model(data)
                output = output[:, :, :H, :W]

                # LOSS
                loss = F.cross_entropy(output,
                                       target,
                                       ignore_index=self.ignore_index)
                total_loss_val.update(loss.item())

                correct, labeled, inter, union = eval_metrics(
                    output, target, self.num_classes, self.ignore_index)
                total_inter, total_union = total_inter + inter, total_union + union
                total_correct, total_label = total_correct + correct, total_label + labeled

                # LIST OF IMAGE TO VIZ (15 images)
                if len(val_visual) < 15:
                    if isinstance(data, list): data = data[0]
                    target_np = target.data.cpu().numpy()
                    output_np = output.data.max(1)[1].cpu().numpy()
                    val_visual.append(
                        [data[0].data.cpu(), target_np[0], output_np[0]])

                # PRINT INFO
                pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
                IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
                mIoU = IoU.mean()
                seg_metrics = {
                    "Pixel_Accuracy":
                    np.round(pixAcc, 3),
                    "Mean_IoU":
                    np.round(mIoU, 3),
                    "Class_IoU":
                    dict(zip(range(self.num_classes), np.round(IoU, 3)))
                }

                tbar.set_description(
                    'EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} |'
                    .format(epoch, total_loss_val.average, pixAcc, mIoU))

            self._add_img_tb(val_visual, 'val')

            # METRICS TO TENSORBOARD
            self.wrt_step = (epoch) * len(self.val_loader)
            self.writer.add_scalar(f'{self.wrt_mode}/loss',
                                   total_loss_val.average, self.wrt_step)
            for k, v in list(seg_metrics.items())[:-1]:
                self.writer.add_scalar(f'{self.wrt_mode}/{k}', v,
                                       self.wrt_step)

            log = {'val_loss': total_loss_val.average, **seg_metrics}
            self.html_results.add_results(epoch=epoch, seg_resuts=log)
            self.html_results.save()

            if (time.time() - self.start_time) / 3600 > 22:
                self._save_checkpoint(epoch, save_best=self.improved)
        return log
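
eval_metrics(output, target, num_classes, ignore_index) returns the per-batch correct/labeled pixel counts and the per-class intersection/union that the loop above accumulates into pixAcc and mIoU. A sketch of the usual histogram-based computation, assuming output holds raw logits of shape (N, C, H, W):

import torch

def eval_metrics(output, target, num_classes, ignore_index):
    pred = output.argmax(dim=1)           # (N, H, W) predicted class map
    valid = target != ignore_index        # pixels that count towards the metrics

    labeled = valid.sum().item()
    correct = ((pred == target) & valid).sum().item()

    # per-class intersection / union over the valid pixels
    pred_v = pred[valid].float()
    target_v = target[valid].float()
    inter = pred_v[pred_v == target_v]
    area_inter = torch.histc(inter, bins=num_classes, min=0, max=num_classes - 1)
    area_pred = torch.histc(pred_v, bins=num_classes, min=0, max=num_classes - 1)
    area_target = torch.histc(target_v, bins=num_classes, min=0, max=num_classes - 1)
    area_union = area_pred + area_target - area_inter

    return correct, labeled, area_inter.cpu().numpy(), area_union.cpu().numpy()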
Beispiel #24
0
    def train_one_epoch(self):
        """
        One epoch training function
        """

        # Set the model to be in training mode
        self.model.train()
        # Initialize your average meters
        train_loss = AverageMeter()
        train_err_joints = AverageMeter()
        train_err_rotation = AverageMeter()
        train_err_translation = AverageMeter()

        total_batch = len(self.data_loader.train_loader)

        print("Starting Epoch Training with batch size: {:d}".format(
            total_batch))
        print("Batch Size: {:d}".format(self.config.batch_size))
        current_batch = 1

        # times = {}

        q_dist_list = []
        trans_list = []
        joint_list = []

        tic = time.time()

        mean_speed = 0
        total_loss_sum = 0

        samples = 300
        np_batch_bench = np.zeros([samples, 4])

        for x, y in self.data_loader.train_loader:
            batch_start = time.time()

            if self.cuda:
                x = x.to(device=self.device, dtype=torch.long)
                y = y.to(device=self.device, dtype=torch.long)

            progress = float(
                self.current_epoch * self.data_loader.train_iterations +
                current_batch) / (self.config.max_epoch *
                                  self.data_loader.train_iterations)

            # adjust learning rate
            lr = adjust_learning_rate(self.optimizer,
                                      self.current_epoch,
                                      self.config,
                                      batch=current_batch,
                                      nBatch=self.data_loader.train_iterations)

            # model

            pred = self.model(x)

            if self.config.data_output_type == "joints_absolute":
                loss_joints = self.loss(pred, y)
                total_loss = loss_joints
                train_loss.update(total_loss.item())
                train_err_joints.update(total_loss.item())
            elif self.config.data_output_type == "q_trans_simple":
                loss_q_trans_simple = self.loss(pred, y)
                total_loss = loss_q_trans_simple
            elif self.config.data_output_type == "pose_relative":
                # loss for rotation
                # select rotation indices from the prediction tensor
                indices = torch.tensor([3, 4, 5, 6])
                indices = indices.to(self.device)
                rotation = torch.index_select(pred, 1, indices)
                # select rotation indices from the label tensor
                y_rot = torch.index_select(y, 1, indices)
                # calc MSE loss for rotation
                # loss_rotation = self.loss(rotation, y_rot)

                # trans_list.append(loss_rotation[0].item().numpy())
                # print(loss_rotation.item())

                # penalty loss from facebook paper posenet
                # penalty_loss = self.config.rot_reg * torch.mean((torch.sum(quater ** 2, dim=1) - 1) ** 2)
                penalty_loss = 0

                q_pred = pq.Quaternion(rotation[0].cpu().detach().numpy())
                q_rot = pq.Quaternion(y_rot[0].cpu().detach().numpy())
                q_dist = math.degrees(pq.Quaternion.distance(q_pred, q_rot))
                q_dist_list.append(q_dist)

                # loss for translation
                # select translation indices from the prediction tensor
                indices = torch.tensor([0, 1, 2])
                indices = indices.to(self.device)
                translation = torch.index_select(pred, 1, indices)
                # select translation indices from the label tensor
                y_trans = torch.index_select(y, 1, indices)

                # calc MSE loss for translation
                loss_translation = self.loss(translation, y_trans)
                trans_list.append(loss_translation.item())

                # total_loss = penalty_loss + loss_rotation + loss_translation
                # use simple loss
                total_loss = self.loss(pred.double(), y.double())

                # calc translation MSE (q_dist was already computed above)
                trans_pred = translation[0].cpu().detach().numpy()
                trans_label = y_trans[0].cpu().detach().numpy()
                mse_trans = (np.square(trans_pred - trans_label)).mean()
                train_err_translation.update(mse_trans)
                train_err_rotation.update(q_dist)

            elif self.config.data_output_type == "pose_absolute":
                # select rotation indices from the prediction tensor
                indices = torch.tensor([3, 4, 5, 6])
                indices = indices.to(self.device)
                rotation = torch.index_select(pred, 1, indices)
                # select rotation indices from the label tensor
                y_rot = torch.index_select(y, 1, indices)

                q_pred = pq.Quaternion(rotation[0].cpu().detach().numpy())
                q_rot = pq.Quaternion(y_rot[0].cpu().detach().numpy())
                q_dist = math.degrees(pq.Quaternion.distance(q_pred, q_rot))
                q_dist_list.append(q_dist)

                # loss for translation
                # select translation indices from the prediction tensor
                indices = torch.tensor([0, 1, 2])
                indices = indices.to(self.device)
                translation = torch.index_select(pred, 1, indices)
                # select translation indices from the label tensor
                y_trans = torch.index_select(y, 1, indices)

                trans_pred = translation[0].cpu().detach().numpy()
                trans_label = y_trans[0].cpu().detach().numpy()

                # calc MSE loss for translation
                loss_translation = self.loss(translation, y_trans)
                trans_list.append(loss_translation.item())

                # use simple loss
                total_loss = self.loss(pred, y)

                # calc translation MSE
                mse_trans = (np.square(trans_pred - trans_label)).mean()
                train_err_translation.update(mse_trans)
                train_err_rotation.update(q_dist)

            elif self.config.data_output_type == "joints_relative":
                total_loss = self.loss(pred, y)
                train_err_joints.update(total_loss.item())
                # print("Train loss {:f}".format(total_loss.item()))
                joint_list.append(total_loss.item())
            else:
                raise Exception("Wrong data output type chosen.")

            if np.isnan(float(total_loss.item())):
                raise ValueError('Loss is nan during training...')

            # optimizer
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            train_loss.update(total_loss.item())

            self.current_iteration += 1

            batch_duration = time.time() - batch_start
            mean_speed += batch_duration
            speed = float(mean_speed / current_batch)
            remaining_sec = speed * (total_batch - current_batch) * (
                self.config.max_epoch - self.current_epoch)
            batch_progress = float(current_batch / total_batch) * 100
            # print(int(batch_progress) % 5)

            total_loss_sum += total_loss.item()
            avg_total_loss = float(total_loss_sum / current_batch)
            #
            # if avg_total_loss <= self.config.min_avg_loss:
            #     print("Loss is {:.3e} <= {:.3e}".format(avg_total_loss, self.config.min_avg_loss))
            # else:
            #     print("Loss is {:.3e} > {:.3e}".format(avg_total_loss, self.config.min_avg_loss))

            if self.config.DEBUG_TRAINING_DURATION:  # and int(math.floor(batch_progress)) % 25 == 0:
                print(
                    "Current Batch {:d} {:d} {:2.1%} {:.2f} s Avg {:.2f} s/batch Loss {:.3e} Remaining {:s}"
                    .format(
                        current_batch, total_batch,
                        float(current_batch / total_batch), batch_duration,
                        speed, avg_total_loss,
                        time.strftime('Days %d Time %H:%M:%S',
                                      time.gmtime(remaining_sec))))

            if current_batch > samples:
                break

            print(np_batch_bench.shape)
            print(current_batch)
            np_batch_bench[current_batch - 1][0] = current_batch
            np_batch_bench[current_batch - 1][1] = total_batch
            np_batch_bench[current_batch - 1][2] = batch_duration
            np_batch_bench[current_batch - 1][3] = speed

            current_batch += 1

        # save mean of q_dist_list into bigger array
        mean = np.mean(np.asarray(q_dist_list))
        # print("Q mean {:3.2f} deg".format(mean))
        mean_t = np.mean(np.asarray(trans_list))
        mean_joints = np.mean(np.asarray(joint_list))

        self.trans_mean.append([self.iter, mean_t])
        self.q_dist_mean.append([self.iter, mean])
        self.joints_mean.append([self.iter, mean_joints])
        self.iter += 1

        # update logging dict
        self.logging_dict["learning_rate"].append(lr)
        self.logging_dict["train_loss"].append(train_loss.val)
        self.logging_dict["train_err_rotation"].append(train_err_rotation.val)
        self.logging_dict["train_err_translation"].append(
            train_err_translation.val)
        self.logging_dict["train_err_joints"].append(train_err_joints.val)

        # print progress
        progress = float((self.current_epoch + 1) / self.config.max_epoch)
        duration_epoch = time.time() - tic
        if self.current_epoch % self.config.display_step == 0 or self.current_epoch % 1 == 0:
            self.duration = time.time() - self.start_time
            self.logger.info(
                "Train Epoch: {:>4d} | Total: {:>4d} | Progress: {:>3.2%} | Loss: {:>3.2e} | Translation [mm]: {:>3.2e} |"
                " Rotation [deg] {:>3.2e} | Joints [deg] {:>3.2e} | ({:02d}:{:02d}:{:02d}) "
                .format(self.current_epoch + 1, self.config.max_epoch,
                        progress, train_loss.val, train_err_translation.val,
                        train_err_rotation.val, train_err_joints.val,
                        int(self.duration /
                            3600), int(np.mod(self.duration, 3600) / 60),
                        int(np.mod(np.mod(self.duration, 3600), 60))) +
                time.strftime("%d.%m.%y %H:%M:%S", time.localtime()))

        ds = pd.DataFrame(np_batch_bench)
        print(ds)
        path = "/home/speerponar/pytorch_models/evaluation/"
        ds.to_csv(path + "test_" + str(self.config.batch_size) + "w_" +
                  str(self.config.data_loader_workers) + ".csv")
        print("Save csv file")
Beispiel #25
0
class Trainer(BaseTrainer):
    def __init__(self,
                 model,
                 resume,
                 config,
                 supervised_loader,
                 unsupervised_loader,
                 iter_per_epoch,
                 val_loader=None,
                 val_logger=None,
                 train_logger=None):
        super(Trainer, self).__init__(model, resume, config, iter_per_epoch,
                                      val_logger, train_logger)

        self.supervised_loader = supervised_loader
        self.unsupervised_loader = unsupervised_loader
        self.val_loader = val_loader

        self.cutmix_param = config['cutmix']
        self.mix_image = CutMix(self.cutmix_param)

        self.ignore_index = self.val_loader.dataset.ignore_index
        self.wrt_mode, self.wrt_step = 'train_', 0
        self.log_step = config['trainer'].get(
            'log_per_iter', int(np.sqrt(self.val_loader.batch_size)))
        if config['trainer']['log_per_iter']:
            self.log_step = int(self.log_step / self.val_loader.batch_size) + 1

        self.num_classes = self.val_loader.dataset.num_classes
        self.mode = self.model.module.mode

        # TRANSFORMS FOR VISUALIZATION
        self.restore_transform = transforms.Compose([
            DeNormalize(self.val_loader.MEAN, self.val_loader.STD),
            transforms.ToPILImage()
        ])
        self.viz_transform = transforms.Compose(
            [transforms.Resize((400, 400)),
             transforms.ToTensor()])

        self.start_time = time.time()

    def _train_epoch(self, epoch):
        self.html_results.save()

        self.logger.info('\n')
        self.model.train()

        if self.mode == 'supervised':
            dataloader = iter(self.supervised_loader)
            tbar = tqdm(range(len(self.supervised_loader)), ncols=135)
        else:
            dataloader = iter(
                zip(cycle(self.supervised_loader), self.unsupervised_loader))
            tbar = tqdm(range(len(self.unsupervised_loader)), ncols=135)

        self._reset_metrics()
        for batch_idx in tbar:
            if self.mode == 'supervised':
                (input_l, target_l), (input_ul,
                                      target_ul) = next(dataloader), (None,
                                                                      None)
            else:
                (input_l, target_l), (input_ul, target_ul) = next(dataloader)
                #input_ul, target_ul = input_ul.cuda(non_blocking=True), target_ul.cuda(non_blocking=True)
                target_ul = target_ul.cuda(non_blocking=True)

                if self.cutmix_param["cutmix_transform"]:
                    input_ul_mix = self.mix_image.generate_cutmix_images(
                        input_ul)
                    input_ul = input_ul_mix.cuda(non_blocking=True)
                    del input_ul_mix
                    # print(f"Applied cutmix to encoder")

            input_l, target_l = input_l.cuda(non_blocking=True), target_l.cuda(
                non_blocking=True)

            self.optimizer.zero_grad()

            total_loss, cur_losses, outputs = self.model(x_l=input_l,
                                                         target_l=target_l,
                                                         x_ul=input_ul,
                                                         curr_iter=batch_idx,
                                                         target_ul=target_ul,
                                                         epoch=epoch - 1)
            total_loss = total_loss.mean()
            total_loss.backward()
            self.optimizer.step()

            self._update_losses(cur_losses)
            self._compute_metrics(outputs, target_l, target_ul, epoch - 1)
            logs = self._log_values(cur_losses)

            if batch_idx % self.log_step == 0:
                self.wrt_step = (epoch - 1) * len(
                    self.unsupervised_loader) + batch_idx
                self._write_scalars_tb(logs)

            if batch_idx % int(len(self.unsupervised_loader) * 0.9) == 0:
                self._write_img_tb(input_l, target_l, input_ul, target_ul,
                                   outputs, epoch)

            del input_l, target_l, input_ul, target_ul
            del total_loss, cur_losses, outputs

            tbar.set_description(
                'T ({}) | Ls {:.2f} Lu {:.2f} Lw {:.2f} PW {:.2f} m1 {:.2f} m2 {:.2f}|'
                .format(epoch, self.loss_sup.average, self.loss_unsup.average,
                        self.loss_weakly.average, self.pair_wise.average,
                        self.mIoU_l, self.mIoU_ul))

            self.lr_scheduler.step(epoch=epoch - 1)

        return logs

    def _valid_epoch(self, epoch):
        if self.val_loader is None:
            self.logger.warning(
                'No data loader was passed for the validation step, so no validation is performed!'
            )
            return {}
        self.logger.info('\n###### EVALUATION ######')

        self.model.eval()
        self.wrt_mode = 'val'
        total_loss_val = AverageMeter()
        total_inter, total_union = 0, 0
        total_correct, total_label = 0, 0

        tbar = tqdm(self.val_loader, ncols=130)
        with torch.no_grad():
            val_visual = []
            for batch_idx, (data, target) in enumerate(tbar):
                target, data = target.cuda(non_blocking=True), data.cuda(
                    non_blocking=True)

                H, W = target.size(1), target.size(2)
                up_sizes = (ceil(H / 8) * 8, ceil(W / 8) * 8)
                pad_h, pad_w = up_sizes[0] - data.size(
                    2), up_sizes[1] - data.size(3)
                data = F.pad(data, pad=(0, pad_w, 0, pad_h), mode='reflect')
                output = self.model(data)
                output = output[:, :, :H, :W]

                # LOSS
                loss = F.cross_entropy(output,
                                       target,
                                       ignore_index=self.ignore_index)
                total_loss_val.update(loss.item())

                correct, labeled, inter, union = eval_metrics(
                    output, target, self.num_classes, self.ignore_index)
                total_inter, total_union = total_inter + inter, total_union + union
                total_correct, total_label = total_correct + correct, total_label + labeled

                # LIST OF IMAGE TO VIZ (15 images)
                if len(val_visual) < 15:
                    if isinstance(data, list): data = data[0]
                    target_np = target.data.cpu().numpy()
                    output_np = output.data.max(1)[1].cpu().numpy()
                    val_visual.append(
                        [data[0].data.cpu(), target_np[0], output_np[0]])

                # PRINT INFO
                pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
                IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
                mIoU = IoU.mean()
                seg_metrics = {
                    "Pixel_Accuracy":
                    np.round(pixAcc, 3),
                    "Mean_IoU":
                    np.round(mIoU, 3),
                    "Class_IoU":
                    dict(zip(range(self.num_classes), np.round(IoU, 3)))
                }

                tbar.set_description(
                    'EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} |'
                    .format(epoch, total_loss_val.average, pixAcc, mIoU))

            self._add_img_tb(val_visual, 'val')

            # METRICS TO TENSORBOARD
            self.wrt_step = (epoch) * len(self.val_loader)
            self.writer.add_scalar(f'{self.wrt_mode}/loss',
                                   total_loss_val.average, self.wrt_step)
            for k, v in list(seg_metrics.items())[:-1]:
                self.writer.add_scalar(f'{self.wrt_mode}/{k}', v,
                                       self.wrt_step)

            log = {'val_loss': total_loss_val.average, **seg_metrics}
            self.html_results.add_results(epoch=epoch, seg_resuts=log)
            self.html_results.save()

            if (time.time() - self.start_time) / 3600 > 22:
                self._save_checkpoint(epoch, save_best=self.improved)
        return log

    def _reset_metrics(self):
        self.loss_sup = AverageMeter()
        self.loss_unsup = AverageMeter()
        self.loss_weakly = AverageMeter()
        self.pair_wise = AverageMeter()
        self.total_inter_l, self.total_union_l = 0, 0
        self.total_correct_l, self.total_label_l = 0, 0
        self.total_inter_ul, self.total_union_ul = 0, 0
        self.total_correct_ul, self.total_label_ul = 0, 0
        self.mIoU_l, self.mIoU_ul = 0, 0
        self.pixel_acc_l, self.pixel_acc_ul = 0, 0
        self.class_iou_l, self.class_iou_ul = {}, {}

    def _update_losses(self, cur_losses):
        if "loss_sup" in cur_losses.keys():
            self.loss_sup.update(cur_losses['loss_sup'].mean().item())
        if "loss_unsup" in cur_losses.keys():
            self.loss_unsup.update(cur_losses['loss_unsup'].mean().item())
        if "loss_weakly" in cur_losses.keys():
            self.loss_weakly.update(cur_losses['loss_weakly'].mean().item())
        if "pair_wise" in cur_losses.keys():
            self.pair_wise.update(cur_losses['pair_wise'].mean().item())

    def _compute_metrics(self, outputs, target_l, target_ul, epoch):
        seg_metrics_l = eval_metrics(outputs['sup_pred'], target_l,
                                     self.num_classes, self.ignore_index)
        self._update_seg_metrics(*seg_metrics_l, True)
        seg_metrics_l = self._get_seg_metrics(True)
        self.pixel_acc_l, self.mIoU_l, self.class_iou_l = seg_metrics_l.values(
        )

        if self.mode == 'semi':
            seg_metrics_ul = eval_metrics(outputs['unsup_pred'], target_ul,
                                          self.num_classes, self.ignore_index)
            self._update_seg_metrics(*seg_metrics_ul, False)
            seg_metrics_ul = self._get_seg_metrics(False)
            self.pixel_acc_ul, self.mIoU_ul, self.class_iou_ul = seg_metrics_ul.values(
            )

    def _update_seg_metrics(self,
                            correct,
                            labeled,
                            inter,
                            union,
                            supervised=True):
        if supervised:
            self.total_correct_l += correct
            self.total_label_l += labeled
            self.total_inter_l += inter
            self.total_union_l += union
        else:
            self.total_correct_ul += correct
            self.total_label_ul += labeled
            self.total_inter_ul += inter
            self.total_union_ul += union

    def _get_seg_metrics(self, supervised=True):
        if supervised:
            pixAcc = 1.0 * self.total_correct_l / (np.spacing(1) +
                                                   self.total_label_l)
            IoU = 1.0 * self.total_inter_l / (np.spacing(1) +
                                              self.total_union_l)
        else:
            pixAcc = 1.0 * self.total_correct_ul / (np.spacing(1) +
                                                    self.total_label_ul)
            IoU = 1.0 * self.total_inter_ul / (np.spacing(1) +
                                               self.total_union_ul)
        mIoU = IoU.mean()
        return {
            "Pixel_Accuracy": np.round(pixAcc, 3),
            "Mean_IoU": np.round(mIoU, 3),
            "Class_IoU": dict(zip(range(self.num_classes), np.round(IoU, 3)))
        }

    def _log_values(self, cur_losses):
        logs = {}
        if "loss_sup" in cur_losses.keys():
            logs['loss_sup'] = round(self.loss_sup.average, 4)
        if "loss_unsup" in cur_losses.keys():
            logs['loss_unsup'] = round(self.loss_unsup.average, 4)
        if "loss_weakly" in cur_losses.keys():
            logs['loss_weakly'] = round(self.loss_weakly.average, 4)
        if "pair_wise" in cur_losses.keys():
            logs['pair_wise'] = round(self.pair_wise.average, 4)

        logs['mIoU_labeled'] = round(self.mIoU_l, 3)
        logs['pixel_acc_labeled'] = round(self.pixel_acc_l, 4)
        if self.mode == 'semi':
            logs['mIoU_unlabeled'] = round(self.mIoU_ul, 4)
            logs['pixel_acc_unlabeled'] = round(self.pixel_acc_ul, 4)
        return logs

    def _write_scalars_tb(self, logs):
        for k, v in logs.items():
            if 'class_iou' not in k:
                self.writer.add_scalar(f'train/{k}', v, self.wrt_step)
        for i, opt_group in enumerate(self.optimizer.param_groups):
            self.writer.add_scalar(f'train/Learning_rate_{i}', opt_group['lr'],
                                   self.wrt_step)
        current_rampup = self.model.module.unsup_loss_w.current_rampup
        self.writer.add_scalar('train/Unsupervised_rampup', current_rampup,
                               self.wrt_step)

    def _add_img_tb(self, val_visual, wrt_mode):
        val_img = []
        palette = self.val_loader.dataset.palette
        for imgs in val_visual:
            imgs = [
                self.restore_transform(i) if
                (isinstance(i, torch.Tensor)
                 and len(i.shape) == 3) else colorize_mask(i, palette)
                for i in imgs
            ]
            imgs = [i.convert('RGB') for i in imgs]
            imgs = [self.viz_transform(i) for i in imgs]
            val_img.extend(imgs)
        val_img = torch.stack(val_img, 0)
        val_img = make_grid(val_img.cpu(),
                            nrow=val_img.size(0) // len(val_visual),
                            padding=5)
        self.writer.add_image(f'{wrt_mode}/inputs_targets_predictions',
                              val_img, self.wrt_step)

    def _write_img_tb(self, input_l, target_l, input_ul, target_ul, outputs,
                      epoch):
        outputs_l_np = outputs['sup_pred'].data.max(1)[1].cpu().numpy()
        targets_l_np = target_l.data.cpu().numpy()
        imgs = [[i.data.cpu(), j, k]
                for i, j, k in zip(input_l, outputs_l_np, targets_l_np)]
        self._add_img_tb(imgs, 'supervised')

        if self.mode == 'semi':
            outputs_ul_np = outputs['unsup_pred'].data.max(1)[1].cpu().numpy()
            targets_ul_np = target_ul.data.cpu().numpy()
            imgs = [[i.data.cpu(), j, k]
                    for i, j, k in zip(input_ul, outputs_ul_np, targets_ul_np)]
            self._add_img_tb(imgs, 'unsupervised')
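The CutMix helper behind self.mix_image is project specific; below is a minimal sketch of the standard CutMix box sampling it presumably follows (the function name cutmix_batch and the Beta parameter alpha are illustrative, not taken from the example):

import numpy as np
import torch

def cutmix_batch(images, alpha=1.0):
    """Paste a random crop from a shuffled copy of the batch into each image (standard CutMix boxes)."""
    lam = np.random.beta(alpha, alpha)                     # mixing ratio drawn from Beta(alpha, alpha)
    n, _, h, w = images.shape
    perm = torch.randperm(n)
    cut_ratio = np.sqrt(1.0 - lam)                         # box area covers (1 - lam) of the image
    cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
    cy, cx = np.random.randint(h), np.random.randint(w)    # box centre
    y1, y2 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    x1, x2 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)
    mixed = images.clone()
    mixed[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    return mixed, perm, lam

# usage: mixed_batch, _, _ = cutmix_batch(torch.randn(4, 3, 64, 64))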
Beispiel #26
0
    def train_one_epoch_supervision(self):

        # Set the model to be in training mode
        self.net.train()

        # Initialize average meters
        epoch_loss = AverageMeter()
        epoch_acc = AverageMeter()
        epoch_iou = AverageMeter()
        epoch_filtered_iou = AverageMeter()

        tqdm_batch = tqdm(self.train_loader, f'Epoch-{self.current_epoch}-')
        for x in tqdm_batch:
            # prepare data
            imgs = torch.tensor(x['img'],
                                dtype=torch.float,
                                device=self.device)
            masks = torch.tensor(x['mask'],
                                 dtype=torch.float,
                                 device=self.device)
            salty = torch.tensor(x['salty'],
                                 dtype=torch.float,
                                 device=self.device)

            # model
            pred, pred_seg_pure, pred_salty = self.net(imgs)

            # loss
            cur_pred_loss = self.loss(pred, masks)
            cur_seg_pure_loss = self.loss(pred_seg_pure[salty.squeeze() > 0.5],
                                          masks[salty.squeeze() > 0.5])
            cur_salty_loss = F.binary_cross_entropy_with_logits(
                pred_salty, salty)
            cur_loss = 0.005 * cur_salty_loss + 0.5 * cur_seg_pure_loss + cur_pred_loss
            if np.isnan(float(cur_loss.item())):
                raise ValueError('Loss is nan during training...')

            # optimizer
            self.optimizer.zero_grad()
            cur_loss.backward()
            self.optimizer.step()

            # metrics
            pred_t = torch.sigmoid(pred) > 0.5
            masks_t = masks > 0.5

            cur_acc = torch.sum(pred_t == masks_t).item() / masks.numel()
            cur_iou = iou_pytorch(pred_t, masks_t)
            cur_filtered_iou = iou_pytorch(remove_small_mask_batch(pred_t),
                                           masks_t)

            batch_size = imgs.shape[0]
            epoch_loss.update(cur_loss.item(), batch_size)
            epoch_acc.update(cur_acc, batch_size)
            epoch_iou.update(cur_iou.item(), batch_size)
            epoch_filtered_iou.update(cur_filtered_iou.item(), batch_size)

        tqdm_batch.close()

        logging.info(
            f'Training at epoch- {self.current_epoch} |'
            f'loss: {epoch_loss.val:.5} - Acc: {epoch_acc.val:.5}'
            f'- IOU: {epoch_iou.val:.5} - Filtered IOU: {epoch_filtered_iou.val:.5}'
        )
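iou_pytorch and remove_small_mask_batch above come from the surrounding project. A rough sketch of what a batch-wise binary IoU such as the one logged here typically computes (the epsilon guard is an assumption):

import torch

def batch_iou(pred, target, eps=1e-6):
    """Mean IoU over a batch of binary masks; pred and target are bool tensors of shape (N, ...)."""
    pred = pred.reshape(pred.size(0), -1).float()
    target = target.reshape(target.size(0), -1).float()
    inter = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1) - inter
    return ((inter + eps) / (union + eps)).mean()

# usage with the tensors from the loop above: batch_iou(pred_t, masks_t)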
Beispiel #27
0
def train():
    info_format = 'Epoch: [{}]\t loss: {: .6f} train mF1: {: .6f} \t val mF1: {: .6f}\t test mF1: {:.6f} \t ' \
                  'best val mF1: {: .6f}\t best test mF1: {:.6f}'
    opt.printer.info('===> Init the optimizer ...')
    criterion = torch.nn.BCEWithLogitsLoss().to(opt.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    scheduler = ReduceLROnPlateau(optimizer,
                                  "min",
                                  patience=opt.lr_patience,
                                  verbose=True,
                                  factor=0.5,
                                  cooldown=30,
                                  min_lr=opt.lr / 100)
    opt.scheduler = 'ReduceLROnPlateau'

    optimizer, scheduler, opt.lr = load_pretrained_optimizer(
        opt.pretrained_model, optimizer, scheduler, opt.lr)

    opt.printer.info('===> Init Metric ...')
    opt.losses = AverageMeter()

    best_val_value = 0.
    best_test_value = 0.

    opt.printer.info('===> Start training ...')
    for _ in range(opt.epoch, opt.total_epochs):
        opt.epoch += 1
        loss, train_value = train_step(model, train_loader, optimizer,
                                       criterion, opt)
        val_value = test(model, valid_loader, opt)
        test_value = test(model, test_loader, opt)

        if val_value > best_val_value:
            best_val_value = val_value
            save_ckpt(model,
                      optimizer,
                      scheduler,
                      opt.epoch,
                      opt.save_path,
                      opt.post,
                      name_post='val_best')
        if test_value > best_test_value:
            best_test_value = test_value
            save_ckpt(model,
                      optimizer,
                      scheduler,
                      opt.epoch,
                      opt.save_path,
                      opt.post,
                      name_post='test_best')

        opt.printer.info(
            info_format.format(opt.epoch, loss, train_value, val_value,
                               test_value, best_val_value, best_test_value))

        if opt.scheduler == 'ReduceLROnPlateau':
            scheduler.step(opt.losses.avg)
        else:
            scheduler.step()

    opt.printer.info('Saving the final model. Finished!')
Beispiel #28
0
    def _reset_metrics(self):
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.total_loss = AverageMeter()
        self.total_inter, self.total_union = 0, 0
        self.total_correct, self.total_label = 0, 0
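The meters reset here, and used throughout these examples, follow the usual AverageMeter pattern. A minimal sketch assuming the common val/avg/sum/count interface (some snippets above expose the running mean as .average instead):

class AverageMeter:
    """Tracks the most recent value and the running average of a scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val          # last observed value
        self.sum += val * n     # weighted sum (n = batch size)
        self.count += n
        self.avg = self.sum / self.count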
Beispiel #29
0
    def train_one_epoch(self):
        """
        One epoch training function
        """
        # Initialize tqdm
        tqdm_batch = tqdm(self.data_loader.train_loader,
                          total=self.data_loader.train_iterations,
                          desc="Epoch-{}-".format(self.current_epoch))
        # Set the model to be in training mode
        self.model.train()
        # Initialize your average meters
        epoch_loss = AverageMeter()
        top1_acc = AverageMeter()
        top5_acc = AverageMeter()

        current_batch = 0
        for x, y in tqdm_batch:
            if self.cuda:
                x, y = x.cuda(non_blocking=self.config.async_loading), y.cuda(
                    non_blocking=self.config.async_loading)

            # current iteration over total iterations
            progress = float(
                self.current_epoch * self.data_loader.train_iterations +
                current_batch) / (self.config.max_epoch *
                                  self.data_loader.train_iterations)
            # progress = float(self.current_iteration) / (self.config.max_epoch * self.data_loader.train_iterations)
            x, y = Variable(x), Variable(y)
            lr = adjust_learning_rate(self.optimizer,
                                      self.current_epoch,
                                      self.config,
                                      batch=current_batch,
                                      nBatch=self.data_loader.train_iterations)
            # model
            pred = self.model(x, progress)
            # loss
            cur_loss = self.loss(pred, y)
            if np.isnan(float(cur_loss.item())):
                raise ValueError('Loss is nan during training...')
            # optimizer
            self.optimizer.zero_grad()
            cur_loss.backward()
            self.optimizer.step()

            top1, top5 = cls_accuracy(pred.data, y.data, topk=(1, 5))

            epoch_loss.update(cur_loss.item())
            top1_acc.update(top1.item(), x.size(0))
            top5_acc.update(top5.item(), x.size(0))

            self.current_iteration += 1
            current_batch += 1

            self.summary_writer.add_scalar("epoch/loss", epoch_loss.val,
                                           self.current_iteration)
            self.summary_writer.add_scalar("epoch/accuracy", top1_acc.val,
                                           self.current_iteration)
        tqdm_batch.close()

        self.logger.info("Training at epoch-" + str(self.current_epoch) +
                         " | " + "loss: " + str(epoch_loss.val) +
                         "- Top1 Acc: " + str(top1_acc.val) + "- Top5 Acc: " +
                         str(top5_acc.val))
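cls_accuracy above is a project helper. A minimal sketch of a top-k classification accuracy with the same call shape (that it returns percentages is an assumption):

import torch

def topk_accuracy(output, target, topk=(1, 5)):
    """output: (N, C) logits, target: (N,) class indices; returns one percentage per k."""
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)   # (N, maxk) predicted classes
    correct = pred.eq(target.view(-1, 1))                           # broadcast against targets
    return [correct[:, :k].any(dim=1).float().mean() * 100.0 for k in topk]

# usage: top1, top5 = topk_accuracy(pred.data, y.data, topk=(1, 5))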
Beispiel #30
0
    def _validate(self, config):
        """
        One epoch validation
        :return:
        """
        self.data_loader = Cifar100DataLoader(self.config)
        self.loss_fn = nn.CrossEntropyLoss()
        self.loss_fn = self.loss_fn.to(self.device)
        tqdm_batch = tqdm.tqdm(self.data_loader.valid_loader,
                               total=self.data_loader.valid_iterations,
                               desc="Valiation at -{}-".format(
                                   self.current_epoch))

        self.eval()

        epoch_loss = AverageMeter()
        top1_acc = AverageMeter()
        top5_acc = AverageMeter()

        for x, y in tqdm_batch:
            if self.cuda:
                x, y = x.cuda(non_blocking=self.config.async_loading), y.cuda(
                    non_blocking=self.config.async_loading)

            # model
            pred = self(x)
            # loss
            cur_loss = self.loss_fn(pred, y)
            if np.isnan(float(cur_loss.item())):
                raise ValueError('Loss is nan during validation...')

            top1, top5 = cls_accuracy(pred.data, y.data, topk=(1, 5))
            top1_acc.update(top1.item(), x.size(0))
            top5_acc.update(top5.item(), x.size(0))

            epoch_loss.update(cur_loss.item())

        print("Validation results at epoch-" + str(self.current_epoch) +
              " | " + "loss: " + str(epoch_loss.avg) + "\tTop1 Acc: " +
              str(top1_acc.val))

        tqdm_batch.close()

        return top1_acc.avg
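This validation loop runs the forward pass with gradients enabled; wrapping it in torch.no_grad(), as the validation loops elsewhere in these examples do, avoids building the autograd graph during evaluation. A minimal sketch of that pattern (model, loader and loss_fn are placeholders):

import torch

def evaluate(model, loader, loss_fn, device="cuda"):
    model.eval()
    total_loss, n = 0.0, 0
    with torch.no_grad():                      # no autograd graph is built during evaluation
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            total_loss += loss_fn(pred, y).item() * x.size(0)
            n += x.size(0)
    return total_loss / max(n, 1)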
Beispiel #31
0
    def validate(self):
        """
        One epoch validation
        :return:
        """
        tqdm_batch = tqdm(self.data_loader.valid_loader,
                          total=self.data_loader.valid_iterations,
                          desc="Valiation at -{}-".format(self.current_epoch))

        # set the model in evaluation mode
        self.model.eval()

        epoch_loss = AverageMeter()
        top1_acc = AverageMeter()
        top5_acc = AverageMeter()

        for x, y in tqdm_batch:
            if self.cuda:
                x, y = x.cuda(non_blocking=self.config.async_loading), y.cuda(
                    non_blocking=self.config.async_loading)

            x, y = Variable(x), Variable(y)
            # model
            pred = self.model(x)
            # loss
            cur_loss = self.loss(pred, y)
            if np.isnan(float(cur_loss.item())):
                raise ValueError('Loss is nan during validation...')

            top1, top5 = cls_accuracy(pred.data, y.data, topk=(1, 5))
            epoch_loss.update(cur_loss.item())
            top1_acc.update(top1.item(), x.size(0))
            top5_acc.update(top5.item(), x.size(0))

        self.logger.info("Validation results at epoch-" +
                         str(self.current_epoch) + " | " + "loss: " +
                         str(epoch_loss.avg) + "- Top1 Acc: " +
                         str(top1_acc.val) + "- Top5 Acc: " +
                         str(top5_acc.val))

        tqdm_batch.close()

        return top1_acc.avg
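Several loops above move batches with .cuda(non_blocking=...). A minimal sketch of the usual pairing: non-blocking host-to-device copies only overlap with computation when the DataLoader hands out pinned-memory batches (the dataset and sizes below are placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
loader = DataLoader(dataset, batch_size=16, shuffle=True,
                    pin_memory=True, num_workers=2)

for x, y in loader:
    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)   # overlaps the copy with compute only for pinned batches
        y = y.cuda(non_blocking=True)
    # ... forward pass, loss, backward would follow here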