Example #1
def define_loss():
    class_weights = params['class_weights']
    # Loss, with class weighting
    if params.get('focal_loss', False):
        modelVars['criterion'] = utils.FocalLoss(alpha=class_weights.tolist())
    elif params['balance_classes'] == 2:
        #modelVars['criterion'] = nn.BCEWithLogitsLoss(weight=torch.cuda.FloatTensor(class_weights.astype(np.float32)))
        modelVars['criterion'] = nn.CrossEntropyLoss(
            weight=torch.cuda.FloatTensor(class_weights.astype(np.float32)))
    elif params['balance_classes'] == 3 or params[
            'balance_classes'] == 0 or params['balance_classes'] == 12:
        modelVars['criterion'] = nn.CrossEntropyLoss()
    elif params['balance_classes'] == 8:
        modelVars['criterion'] = nn.CrossEntropyLoss(reduction='none')  # per-sample losses; 'reduce=False' is deprecated
    elif params['balance_classes'] == 6 or params['balance_classes'] == 7:
        modelVars['criterion'] = nn.CrossEntropyLoss(
            weight=torch.cuda.FloatTensor(class_weights.astype(np.float32)),
            reduction='none')  # 'reduce=False' is deprecated
    elif params['balance_classes'] == 10:
        modelVars['criterion'] = utils.FocalLoss(params['numClasses'])
    elif params['balance_classes'] == 11:
        modelVars['criterion'] = utils.FocalLoss(
            params['numClasses'],
            alpha=torch.cuda.FloatTensor(class_weights.astype(np.float32)))
    else:
        modelVars['criterion'] = nn.CrossEntropyLoss(
            weight=torch.cuda.FloatTensor(class_weights.astype(np.float32)))
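utils.FocalLoss above is project-specific code whose implementation is not shown. For orientation, a minimal sketch of a class-weighted multi-class focal loss in the spirit of the calls above (the gamma default and the exact alpha handling are assumptions, not the project's code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    # Minimal sketch, NOT the project's utils.FocalLoss:
    # cross-entropy scaled by (1 - p_t)**gamma, optionally class-weighted.
    def __init__(self, gamma=2.0, alpha=None):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # optional 1-D tensor of per-class weights

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, weight=self.alpha, reduction='none')
        pt = torch.exp(-ce)  # probability assigned to the true class
        return ((1.0 - pt) ** self.gamma * ce).mean()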
def real_channel():
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map['deeplabv3plus_mobilenet'](num_classes=cfg.n_classes,
                                                 output_stride=cfg.stride)

    # if opts.separable_conv and 'plus' in opts.model:
    #     network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=cfg.momentum)

    # Set up metrics
    # metrics = StreamSegMetrics(opts.num_classes)
    l_r = cfg.l_r
    weight_decay = cfg.weight_decay
    lr_policy = cfg.lr_policy
    total_itrs = cfg.total_iter
    step_size = cfg.step_size
    loss_type = cfg.real_loss
    # Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {
            'params': model.backbone.parameters(),
            'lr': 0.1 * l_r
        },
        {
            'params': model.classifier.parameters(),
            'lr': l_r
        },
    ],
                                lr=l_r,
                                momentum=0.9,
                                weight_decay=weight_decay)
    #optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, total_itrs, power=0.9)
    elif lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=step_size,
                                                    gamma=0.1)
    else:
        raise ValueError("Unknown lr_policy: %s" % lr_policy)  # avoid returning an unbound scheduler

    # Set up criterion
    #criterion = utils.get_loss(opts.loss_type)
    if loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255)
    else:
        raise ValueError("Unknown loss_type: %s" % loss_type)  # avoid returning an unbound criterion

    return criterion, optimizer, scheduler, model
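utils.PolyLR above also comes from the project; the standard polynomial schedule it names is lr = base_lr * (1 - iter / max_iters) ** power. A minimal sketch under that assumption:

from torch.optim.lr_scheduler import _LRScheduler

class PolyLR(_LRScheduler):
    # Minimal sketch of a polynomial LR decay scheduler.
    def __init__(self, optimizer, max_iters, power=0.9, min_lr=1e-6, last_epoch=-1):
        self.max_iters = max_iters
        self.power = power
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        factor = (1 - self.last_epoch / self.max_iters) ** self.power
        return [max(base_lr * factor, self.min_lr) for base_lr in self.base_lrs]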
Example #3
    def __init__(self, data_loader, opts):
        #super(Trainer, self).__init__(data_loader, opts)
        #self.opts = opts
        self.train_loader = data_loader[0]
        self.val_loader = data_loader[1]

        # Set up model
        model_map = {
            'deeplabv3_resnet50': network.deeplabv3_resnet50,
            'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
            'deeplabv3_resnet101': network.deeplabv3_resnet101,
            'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
            'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
            'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
        }

        self.model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(self.model.classifier)

        def set_bn_momentum(model, momentum=0.1):
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.momentum = momentum

        # PyTorch BN momentum weights the *new* batch statistic:
        # running = (1 - m) * running + m * batch, so m=0.01 updates the
        # running stats slowly (the counterpart of a TF-style 0.99 decay).
        set_bn_momentum(self.model.backbone, momentum=0.01)


        # Set up metrics
        self.metrics = StreamSegMetrics(opts.num_classes)

        # Set up optimizer
        self.optimizer = torch.optim.SGD(params=[
            {'params': self.model.backbone.parameters(), 'lr': 0.1*opts.lr},
            {'params': self.model.classifier.parameters(), 'lr': opts.lr},
        ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

        if opts.lr_policy=='poly':
            self.scheduler = utils.PolyLR(self.optimizer, opts.total_itrs, power=0.9)
        elif opts.lr_policy=='step':
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=opts.step_size, gamma=0.1)

        # Set up criterion
        if opts.loss_type == 'focal_loss':
            self.criterion = utils.FocalLoss(ignore_index=255, size_average=True)
        elif opts.loss_type == 'cross_entropy':
            self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')


        self.best_mean_iu = 0
        self.iteration = 0
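StreamSegMetrics is not shown either; conceptually it accumulates a confusion matrix over validation batches and derives the 'Overall Acc', 'Mean IoU' and 'Class IoU' entries read later in these examples. A minimal numpy sketch of that idea (names and details are assumptions, not the project's exact class):

import numpy as np

class StreamSegMetrics:
    # Minimal confusion-matrix sketch, not the project's exact implementation.
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion = np.zeros((n_classes, n_classes), dtype=np.int64)

    def update(self, label_trues, label_preds):
        for lt, lp in zip(label_trues, label_preds):
            mask = (lt >= 0) & (lt < self.n_classes)  # drop ignore pixels (e.g. 255)
            idx = self.n_classes * lt[mask].astype(int) + lp[mask]
            self.confusion += np.bincount(idx, minlength=self.n_classes ** 2
                                          ).reshape(self.n_classes, self.n_classes)

    def get_results(self):
        h = self.confusion.astype(np.float64)
        iou = np.diag(h) / (h.sum(axis=1) + h.sum(axis=0) - np.diag(h) + 1e-10)
        return {'Overall Acc': np.diag(h).sum() / (h.sum() + 1e-10),
                'Mean IoU': float(np.nanmean(iou)),
                'Class IoU': dict(enumerate(iou))}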
Example #4
def main():
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=True,
                                 num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet,
        'doubleattention_resnet50': network.doubleattention_resnet50,
        'doubleattention_resnet101': network.doubleattention_resnet101,
        'head_resnet50': network.head_resnet50,
        'head_resnet101': network.head_resnet101
    }

    model = model_map[opts.model](num_classes=opts.num_classes,
                                  output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {
            'params': model.backbone.parameters(),
            'lr': 0.1 * opts.lr
        },
        {
            'params': model.classifier.parameters(),
            'lr': opts.lr
        },
    ],
                                lr=opts.lr,
                                momentum=0.9,
                                weight_decay=opts.weight_decay)
    # optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.1)

    # Set up criterion
    # criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
        coss_manifode = utils.ManifondLoss(alpha=1).to(device)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
        coss_manifode = utils.ManifondLoss(alpha=1).to(device)

    def save_ckpt(path):
        """ save current model
        """
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ==========   Train Loop   ==========#
    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224,
                                    0.225])  # denormalization for ori images

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(opts=opts,
                                          model=model,
                                          loader=val_loader,
                                          device=device,
                                          metrics=metrics,
                                          ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    interval_loss = 0
    while True:  # cur_itrs < opts.total_itrs:
        # =====  Train  =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = (criterion(outputs, labels)
                    + coss_manifode(outputs, labels) * 0.01)  # CE/focal loss + manifold regularizer
            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)

            if (cur_itrs) % 10 == 0:
                interval_loss = interval_loss / 10
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if (cur_itrs) % opts.val_interval == 0:
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts,
                    model=model,
                    loader=val_loader,
                    device=device,
                    metrics=metrics,
                    ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))

                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs,
                                   val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs,
                                   val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        target = train_dst.decode_target(target).transpose(
                            2, 0, 1).astype(np.uint8)
                        lbl = train_dst.decode_target(lbl).transpose(
                            2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate(
                            (img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                return
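One fragile spot in save_ckpt above: it reads model.module, which only exists once the model has been wrapped in nn.DataParallel. A small helper (a sketch, not part of the project) that yields the same checkpoint keys in either case:

import torch.nn as nn

def unwrapped_state_dict(model):
    # Return the underlying model's state_dict whether or not it is wrapped
    # in nn.DataParallel, so checkpoint keys stay consistent either way.
    base = model.module if isinstance(model, nn.DataParallel) else model
    return base.state_dict()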
Example #5
            # mark cnn parameters
            for param in modelVars['model'].parameters():
                param.is_cnn_param = True
            # unmark fc
            for param in modelVars['model']._fc.parameters():
                param.is_cnn_param = False
        # modify model
        modelVars['model'] = models.modify_meta(mdlParams, modelVars['model'])
        # Mark new parameters
        for param in modelVars['model'].parameters():
            if not hasattr(param, 'is_cnn_param'):
                param.is_cnn_param = False

    if mdlParams['focal_loss']:
        modelVars['criterion'] = utils.FocalLoss(
            mdlParams['numClasses'],
            alpha=torch.FloatTensor(class_weights.astype(np.float32)).to(
                modelVars['device']))
    else:
        modelVars['criterion'] = nn.CrossEntropyLoss(weight=torch.FloatTensor(
            class_weights.astype(np.float32)).to(modelVars['device']))

    if mdlParams.get('with_meta', False):
        if mdlParams['freeze_cnn']:
            modelVars['optimizer'] = optim.Adam(
                filter(lambda p: p.requires_grad,
                       modelVars['model'].parameters()),
                lr=mdlParams['learning_rate_meta'])
            # sanity check
            for param in filter(lambda p: p.requires_grad,
                                modelVars['model'].parameters()):
                print(param.name, param.shape)
def main():
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    # Set up metrics
    # metrics = StreamSegMetrics(opts.num_classes)
    metrics = StreamSegMetrics(21)
    # Set up optimizer
    # criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
    elif opts.loss_type == 'logit':
        criterion = nn.BCELoss(reduction='mean')

    def save_ckpt(path):
        """ save current model
        """
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None:
        print("Error: --ckpt given, but this script cannot read a model checkpoint")
        return

    _, val_dst, test_dst = get_dataset(opts)
    val_loader = data.DataLoader(
        val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    test_loader = data.DataLoader(
        test_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    vis_sample_id = np.random.randint(0, len(test_loader), opts.vis_num_samples,
                                      np.int32) if opts.enable_vis else None  # sample idxs for visualization

    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori images
    # ==========   Test Loop   ==========#

    if opts.test_only:
        print("Dataset: %s,  Val set: %d, Test set: %d" %
              (opts.dataset, len(val_dst), len(test_dst)))

        metrics = StreamSegMetrics(21)
        print("val")

        test_score, ret_samples = test_single(opts=opts,
                                              loader=test_loader, device=device, metrics=metrics,
                                              ret_samples_ids=vis_sample_id)
        print("test")
        test_score, ret_samples = test_multiple(
            opts=opts, loader=test_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
        print(metrics.to_str(test_score))
        return

    # ==========   Train Loop   ==========#
    utils.mkdir('checkpoints/multiple_model2')
    for class_num in range(opts.start_class, opts.num_classes):
        # ==========   Dataset   ==========#
        train_dst, val_dst, test_dst = get_dataset_multiple(opts, class_num)
        train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2)
        val_loader = data.DataLoader(val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
        test_loader = data.DataLoader(test_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
        print("Dataset: %s Class %d, Train set: %d, Val set: %d, Test set: %d" % (
            opts.dataset, class_num, len(train_dst), len(val_dst), len(test_dst)))

        # ==========   Model   ==========#
        # (model_map is assumed to be defined as in the previous examples)
        model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(model.classifier)
        utils.set_bn_momentum(model.backbone, momentum=0.01)

        # ==========   Params and learning rate   ==========#
        params_list = [
            {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
            {'params': model.classifier.parameters(), 'lr': 0.1 * opts.lr}  # opts.lr
        ]
        if 'SA' in opts.model:
            params_list.append({'params': model.attention.parameters(), 'lr': 0.1 * opts.lr})
        optimizer = torch.optim.Adam(params=params_list, lr=opts.lr, weight_decay=opts.weight_decay)

        if opts.lr_policy == 'poly':
            scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
        elif opts.lr_policy == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

        model = nn.DataParallel(model)
        model.to(device)

        best_score = 0.0
        cur_itrs = 0
        cur_epochs = 0

        interval_loss = 0
        while True:  # cur_itrs < opts.total_itrs:
            # =====  Train  =====
            model.train()

            cur_epochs += 1
            for (images, labels) in train_loader:
                cur_itrs += 1

                images = images.to(device, dtype=torch.float32)
                labels = labels.to(device, dtype=torch.long)
                # labels=(labels==class_num).float()
                optimizer.zero_grad()
                outputs = model(images)

                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                np_loss = loss.detach().cpu().numpy()
                interval_loss += np_loss
                if vis is not None:
                    vis.vis_scalar('Loss', cur_itrs, np_loss)

                if (cur_itrs) % 10 == 0:
                    interval_loss = interval_loss / 10
                    print("Epoch %d, Itrs %d/%d, Loss=%f" %
                          (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                    interval_loss = 0.0

                if (cur_itrs) % opts.val_interval == 0:
                    save_ckpt('checkpoints/multiple_model2/latest_%s_%s_class%d_os%d.pth' %
                              (opts.model, opts.dataset, class_num, opts.output_stride,))
                    print("validation...")
                    model.eval()
                    val_score, ret_samples = validate(
                        opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
                        ret_samples_ids=vis_sample_id, class_num=class_num)
                    print(metrics.to_str(val_score))

                    if val_score['Mean IoU'] > best_score:  # save best model
                        best_score = val_score['Mean IoU']
                        save_ckpt('checkpoints/multiple_model2/best_%s_%s_class%d_os%d.pth' %
                                  (opts.model, opts.dataset, class_num, opts.output_stride))

                    if vis is not None:  # visualize validation score and samples
                        vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
                        vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
                        vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                        for k, (img, target, lbl) in enumerate(ret_samples):
                            img = (denorm(img) * 255).astype(np.uint8)
                            target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
                            lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
                            concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                            vis.vis_image('Sample %d' % k, concat_img)
                    model.train()

                scheduler.step()

                if cur_itrs >= opts.total_itrs:
                    save_ckpt('checkpoints/multiple_model2/latest_%s_%s_class%d_os%d.pth' %
                              (opts.model, opts.dataset, class_num, opts.output_stride,))
                    print("Saving..")
                    break
            if cur_itrs >= opts.total_itrs:
                cur_itrs = 0
                break

        print("Model of class %d is trained and saved " % (class_num))
Example #7
def main():
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset=='voc' and not opts.crop_val:
        opts.val_batch_size = 1
    

    pipe = create_dali_pipeline(batch_size=opts.batch_size, num_threads=8,
                                device_id=0, data_dir="/home/ubuntu/cityscapes")
    pipe.build()
    train_loader = DALIGenericIterator(pipe, output_map=['image', 'label'], last_batch_policy=LastBatchPolicy.PARTIAL)

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)
    
    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1*opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy=='poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy=='step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

    # Set up criterion
    #criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):
        """ save current model
        """
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)
    
    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    #==========   Train Loop   ==========#
    interval_loss = 0

    class_conv = [255, 255, 255, 255, 255,
                  255, 255, 255, 0, 1,
                  255, 255, 2, 3, 4,
                  255, 255, 255, 5, 255,
                  6, 7, 8, 9, 10,
                  11, 12, 13, 14, 15,
                  255, 255, 16, 17, 18]

    while True: #cur_itrs < opts.total_itrs:
        # =====  Train  =====
        model.train()
        #model = model.half()
        cur_epochs += 1
        train_iter = iter(train_loader)  # create once per epoch; re-creating it every step would repeat the first batch
        while True:
            try:
                nvtx.range_push("Batch " + str(cur_itrs))

                nvtx.range_push("Data loading")
                data = next(train_iter)
                cur_itrs += 1

                images = data[0]['image'].to(dtype=torch.float32)
                labels = data[0]['label'][:, :, :, 0].to(dtype=torch.long)
                # profiling placeholder: the next line overwrites the real labels with zeros
                labels = torch.zeros(data[0]['label'][:, :, :, 0].shape).to(device, dtype=torch.long)
                nvtx.range_pop()

                nvtx.range_push("Forward pass")
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                nvtx.range_pop()

                nvtx.range_push("Backward pass")
                loss.backward()
                optimizer.step()
                nvtx.range_pop()

                np_loss = loss.detach().cpu().numpy()
                interval_loss += np_loss

                nvtx.range_pop()

                if cur_itrs == 10:  # profiling run: stop after 10 iterations
                    break

                if vis is not None:
                    vis.vis_scalar('Loss', cur_itrs, np_loss)

                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

                scheduler.step()  

                if cur_itrs >=  opts.total_itrs:
                    return
            except StopIteration:
                break

        break
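On the DALI side, the usual pattern is to iterate the DALIGenericIterator directly and rewind it between epochs, rather than re-creating Python iterators by hand. A minimal sketch (step_fn is a hypothetical per-step callback, and newer DALI versions can do the rewind implicitly with auto_reset=True):

def run_epochs(train_loader, num_epochs, step_fn):
    # Drive a DALIGenericIterator for several epochs; step_fn(images, labels)
    # stands in for the forward/backward training step.
    for _ in range(num_epochs):
        for batch in train_loader:
            step_fn(batch[0]['image'], batch[0]['label'])
        train_loader.reset()  # rewind the pipeline for the next epoch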
Example #8
            for param in modelVars['model']._fc.parameters():
                param.is_cnn_param = False
        # modify model
        modelVars['model'] = models.modify_meta(params, modelVars['model'])
        # Mark new parameters
        for param in modelVars['model'].parameters():
            if not hasattr(param, 'is_cnn_param'):
                param.is_cnn_param = False
    # multi gpu support
    if len(params['numGPUs']) > 1:
        modelVars['model'] = nn.DataParallel(modelVars['model'])
    modelVars['model'] = modelVars['model'].cuda()
    #summary(modelVars['model'], modelVars['model'].input_size)# (params['input_size'][2], params['input_size'][0], params['input_size'][1]))
    # Loss, with class weighting
    if params.get('focal_loss', False):
        modelVars['criterion'] = utils.FocalLoss(alpha=class_weights.tolist())
    elif params['balance_classes'] == 2:
        #modelVars['criterion'] = nn.BCEWithLogitsLoss(weight=torch.cuda.FloatTensor(class_weights.astype(np.float32)))
        modelVars['criterion'] = nn.CrossEntropyLoss(weight=torch.cuda.FloatTensor(class_weights.astype(np.float32)))
    elif params['balance_classes'] == 3 or params['balance_classes'] == 0 or params['balance_classes'] == 12:
        modelVars['criterion'] = nn.CrossEntropyLoss()
    elif params['balance_classes'] == 8:
        modelVars['criterion'] = nn.CrossEntropyLoss(reduction='none')  # 'reduce=False' is deprecated
    elif params['balance_classes'] == 6 or params['balance_classes'] == 7:
        modelVars['criterion'] = nn.CrossEntropyLoss(weight=torch.cuda.FloatTensor(class_weights.astype(np.float32)), reduction='none')
    elif params['balance_classes'] == 10:
        modelVars['criterion'] = utils.FocalLoss(params['numClasses'])
    elif params['balance_classes'] == 11:
        modelVars['criterion'] = utils.FocalLoss(params['numClasses'], alpha=torch.cuda.FloatTensor(class_weights.astype(np.float32)))
    else:
        modelVars['criterion'] = nn.CrossEntropyLoss(weight=torch.cuda.FloatTensor(class_weights.astype(np.float32)))
def main():

    opts = get_argparser().parse_args()

    save_dir = os.path.join(opts.save_dir, opts.model, '')  # proper join; trailing separator kept because later paths are built by concatenation
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    print('Save position is %s\n' % (save_dir))

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    # select the GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s,  CUDA_VISIBLE_DEVICES: %s\n" % (device, opts.gpu_id))

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=8,
                                   drop_last=True,
                                   pin_memory=False)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.batch_size,
                                 shuffle=True,
                                 num_workers=8,
                                 drop_last=True,
                                 pin_memory=False)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    model_map = {
        'self_contrast': network.self_contrast,
        'DCNet_L1': network.DCNet_L1,
        'DCNet_L12': network.DCNet_L12,
        'DCNet_L123': network.DCNet_L123,
        'FCN': network.FCN,
        'UNet': network.UNet,
        'SegNet': network.SegNet,
        'cloudSegNet': network.cloudSegNet,
        'cloudUNet': network.cloudUNet
    }

    print('Model = %s, num_classes=%d' % (opts.model, opts.num_classes))
    model = model_map[opts.model](n_classes=opts.num_classes,
                                  is_batchnorm=True,
                                  in_channels=opts.in_channels,
                                  feature_scale=opts.feature_scale,
                                  is_deconv=False)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opts.lr,
                                momentum=0.9,
                                weight_decay=opts.weight_decay)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.5)

    # Set up criterion
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss()
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss()

    def save_ckpt(path):
        """ save current model
        """
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s\n\n" % path)

    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in checkpoint["model_state"].items() if (k in model_dict)
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            scheduler.max_iters = opts.total_itrs
            scheduler.min_lr = opts.lr
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Continue training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        print("Best_score is %s" % (str(best_score)))
        # del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ==========   Train Loop   ==========#
    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization

    interval_loss = 0
    train_loss = list()
    train_accuracy = list()
    best_val_itrs = list()
    while True:  # cur_itrs < opts.total_itrs:
        # =====  Train  =====
        model.train()
        cur_epochs += 1

        for (images, labels) in train_loader:
            if (cur_itrs) == 0 or (cur_itrs) % opts.print_interval == 0:
                t1 = time.time()

            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss

            if (cur_itrs) % opts.print_interval == 0:
                interval_loss = interval_loss / opts.print_interval
                train_loss.append(interval_loss)
                t2 = time.time()
                print("Epoch %d, Itrs %d/%d, Loss=%f, Time = %f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss,
                       t2 - t1))
                interval_loss = 0.0

            # save the ckpt file per 5000 itrs
            if (cur_itrs) % opts.val_interval == 0:
                print("validation...")
                model.eval()

                save_ckpt(save_dir + 'latest_%s_%s_itrs%s.pth' %
                          (opts.model, opts.dataset, str(cur_itrs)))
                time_before_val = time.time()
                val_score, ret_samples = validate(
                    opts=opts,
                    model=model,
                    loader=val_loader,
                    device=device,
                    metrics=metrics,
                    ret_samples_ids=vis_sample_id)

                time_after_val = time.time()
                print('Time_val = %f' % (time_after_val - time_before_val))
                print(metrics.to_str(val_score))

                train_accuracy.append(val_score['overall_acc'])
                if val_score['overall_acc'] > best_score:  # save best model
                    best_score = val_score['overall_acc']
                    save_ckpt(save_dir + 'best_%s_%s_.pth' %
                              (opts.model, opts.dataset))
                    best_val_itrs.append(cur_itrs)
                model.train()
            scheduler.step()  # update

            if cur_itrs >= opts.total_itrs:
                print(cur_itrs)
                print(opts.total_itrs)
                return
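The restore block above filters the checkpoint by key name only; a slightly safer variant (a sketch, not the project's code) also checks tensor shapes, which matters when the classifier head has a different class count:

def load_matching_weights(model, ckpt_state):
    # Copy only the checkpoint entries whose names AND shapes match the model;
    # return the keys that were skipped so the caller can log them.
    model_dict = model.state_dict()
    matched = {k: v for k, v in ckpt_state.items()
               if k in model_dict and model_dict[k].shape == v.shape}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return [k for k in ckpt_state if k not in matched]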
Example #10
def train(args):
    # get model 
    model = getattr(models, config.model_name)()
    if args.ckpt and not args.resume:
        state = torch.load(args.ckpt, map_location='cpu')
        model.load_state_dict(state['state_dict'])
        print('train with pretrained weight val_f1', state['f1'])
    model = model.to(device)
    print(model)
    # data
    train_dataset = ECGDataset(data_path=config.train_data, train=True)
    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=6)
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, num_workers=4)
    print("train_datasize", len(train_dataset), "val_datasize", len(val_dataset))
    # optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    w = torch.tensor(train_dataset.wc, dtype=torch.float).to(device)  # class weights (computed but not passed to the loss below)
    criterion = utils.FocalLoss()
    # model save dir
    model_save_dir = '%s/%s_%s' % (config.ckpt, config.model_name, time.strftime("%Y%m%d%H%M"))
    utils.mkdirs(model_save_dir)
    if args.ex: model_save_dir += args.ex
    best_f1 = -1
    lr = config.lr
    start_epoch = 1
    stage = 1
    # train from last save point
    if args.resume:
        if os.path.exists(args.ckpt):  # weight path
            model_save_dir = args.ckpt
            current_w = torch.load(os.path.join(args.ckpt, config.current_w))
            best_w = torch.load(os.path.join(model_save_dir, config.best_w))
            best_f1 = best_w['f1']  # the checkpoint stores the F1 score under 'f1'
            start_epoch = current_w['epoch'] + 1
            lr = current_w['lr']
            stage = current_w['stage']
            model.load_state_dict(current_w['state_dict'])
            if start_epoch - 1 in config.stage_epoch:
                stage += 1
                lr /= config.lr_decay
                utils.adjust_learning_rate(optimizer, lr)
                model.load_state_dict(best_w['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(start_epoch - 1))
    #logger = Logger(logdir=model_save_dir, flush_secs=2)
    # =========>start training<=========
    for epoch in range(start_epoch, config.max_epoch + 1):
        since = time.time()
        train_loss, train_f1 = train_epoch(model, optimizer, criterion, train_dataloader, show_interval=100)
        val_loss, val_f1 = val_epoch(model, criterion, val_dataloader)
        print('#epoch:%02d stage:%d train_loss:%.3e train_f1:%.3f  val_loss:%0.3e val_f1:%.3f time:%s\n'
              % (epoch, stage, train_loss, train_f1, val_loss, val_f1, utils.print_time_cost(since)))
        # logger.log_value('train_loss', train_loss, step=epoch)
        # logger.log_value('train_f1', train_f1, step=epoch)
        # logger.log_value('val_loss', val_loss, step=epoch)
        # logger.log_value('val_f1', val_f1, step=epoch)
        state = {"state_dict": model.state_dict(), "epoch": epoch, "loss": val_loss, 'f1': val_f1, 'lr': lr,
                 'stage': stage}
        save_ckpt(state, best_f1 < val_f1, model_save_dir)
        best_f1 = max(best_f1, val_f1)
        print(best_f1)
        if epoch in config.stage_epoch:
            stage += 1
            lr /= config.lr_decay
            best_w = os.path.join(model_save_dir, config.best_w)
            model.load_state_dict(torch.load(best_w)['state_dict'])
            print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr))
            utils.adjust_learning_rate(optimizer, lr)
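utils.adjust_learning_rate is not shown; the conventional implementation simply rewrites the rate on every parameter group, which is presumably all it does here (a sketch under that assumption):

def adjust_learning_rate(optimizer, lr):
    # Set a new learning rate on every parameter group of the optimizer.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr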
Example #11
def main():
    # add configuration file
    # Dictionary for model configuration
    mdlParams = {}

    # Import machine config
    pc_cfg = importlib.import_module('pc_cfgs.' + sys.argv[1])
    mdlParams.update(pc_cfg.mdlParams)

    # Import model config
    model_cfg = importlib.import_module('cfgs.' + sys.argv[2])
    mdlParams_model = model_cfg.init(mdlParams)

    if 'cc_late' in sys.argv[2]:
        mdlParams['color_augmentation'] = True
    else:
        mdlParams['color_augmentation'] = False
    mdlParams.update(mdlParams_model)

    # Indicate training
    mdlParams['trainSetState'] = 'train'

    # Path name from filename
    mdlParams['saveDirBase'] = mdlParams['saveDir'] + sys.argv[2]

    # Set visible devices
    if 'gpu' in sys.argv[3]:
        mdlParams['numGPUs'] = [[
            int(s) for s in re.findall(r'\d+', sys.argv[3])
        ][-1]]
        cuda_str = ""
        for i in range(len(mdlParams['numGPUs'])):
            cuda_str = cuda_str + str(mdlParams['numGPUs'][i])
            if i != len(mdlParams['numGPUs']) - 1:  # 'is not' would compare identity, not value
                cuda_str = cuda_str + ","
        print("Devices to use:", cuda_str)
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_str

    # Specify val set to train for
    if len(sys.argv) > 4:
        mdlParams['cv_subset'] = [
            int(s) for s in re.findall(r'\d+', sys.argv[4])
        ]
        print("Training validation sets", mdlParams['cv_subset'])

    # Check if there is a validation set, if not, evaluate train error instead
    if 'valIndCV' in mdlParams or 'valInd' in mdlParams:
        eval_set = 'valInd'
        print("Evaluating on validation set during training.")
    else:
        eval_set = 'trainInd'
        print("No validation set, evaluating on training set during training.")

    # Check whether previous folds have already been trained
    prevFile = Path(os.path.join(mdlParams['saveDirBase'], 'CV.pkl'))
    # print(prevFile)
    if prevFile.exists():
        print("Part of CV already done")
        with open(os.path.join(mdlParams['saveDirBase'], 'CV.pkl'), 'rb') as f:
            allData = pickle.load(f)
    else:
        allData = {}
        allData['f1Best'] = {}
        allData['sensBest'] = {}
        allData['specBest'] = {}
        allData['accBest'] = {}
        allData['waccBest'] = {}
        allData['aucBest'] = {}
        allData['convergeTime'] = {}
        allData['bestPred'] = {}
        allData['targets'] = {}

    # Take care of CV
    if mdlParams.get('cv_subset', None) is not None:
        cv_set = mdlParams['cv_subset']
    else:
        cv_set = range(mdlParams['numCV'])
    for cv in cv_set:
        # Check if this fold was already trained
        already_trained = False
        if 'valIndCV' in mdlParams:
            mdlParams['saveDir'] = mdlParams['saveDirBase'] + '/CVSet' + str(
                cv) + str(cuda_str)
            # mdlParams['saveDir'] = mdlParams['saveDirBase'] + '\CVSet' + str(cv)
            if os.path.isdir(mdlParams['saveDirBase']):
                if os.path.isdir(mdlParams['saveDir']):
                    all_max_iter = []
                    for name in os.listdir(mdlParams['saveDir']):
                        int_list = [int(s) for s in re.findall(r'\d+', name)]
                        if len(int_list) > 0:
                            all_max_iter.append(int_list[-1])
                        # if '-' + str(mdlParams['training_steps'])+ '.pt' in name:
                        #    print("Fold %d already fully trained"%(cv))
                        #    already_trained = True
                    all_max_iter = np.array(all_max_iter)
                    if len(all_max_iter) > 0 and np.max(
                            all_max_iter) >= mdlParams['training_steps']:
                        print(
                            "Fold %d already fully trained with %d iterations"
                            % (cv, np.max(all_max_iter)))
                        already_trained = True
        if already_trained:
            continue
        print("CV set", cv)
        # Reset model graph
        importlib.reload(models)
        # importlib.reload(torchvision)
        # Collect model variables
        modelVars = {}
        # print("here")
        modelVars['device'] = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        print(modelVars['device'])
        # Def current CV set
        mdlParams['trainInd'] = mdlParams['trainIndCV']
        if 'valIndCV' in mdlParams:
            mdlParams['valInd'] = mdlParams['valIndCV']
        # Def current path for saving stuff
        if 'valIndCV' in mdlParams:
            mdlParams['saveDir'] = mdlParams['saveDirBase'] + '/CVSet' + str(
                cv) + str(cuda_str)
            # mdlParams['saveDir'] = mdlParams['saveDirBase'] + '\CVSet' + str(cv)
        else:
            mdlParams['saveDir'] = mdlParams['saveDirBase']
        # Create basepath if it doesnt exist yet
        if not os.path.isdir(mdlParams['saveDirBase']):
            os.mkdir(mdlParams['saveDirBase'])
        # Check if there is something to load
        load_old = 0
        if os.path.isdir(mdlParams['saveDir']):
            # Check if a checkpoint is in there
            if len([name for name in os.listdir(mdlParams['saveDir'])]) > 0:
                load_old = 1
                print("Loading old model")
            else:
                # Delete whatever is in there (nothing happens)
                filelist = [
                    os.remove(mdlParams['saveDir'] + '/' + f)
                    for f in os.listdir(mdlParams['saveDir'])
                ]
                # filelist = [os.remove(mdlParams['saveDir'] + '\\' + f) for f in os.listdir(mdlParams['saveDir'])]
        else:
            os.mkdir(mdlParams['saveDir'])
        # Save training progress in here
        save_dict = {}
        save_dict['acc'] = []
        save_dict['loss'] = []
        save_dict['wacc'] = []
        save_dict['auc'] = []
        save_dict['sens'] = []
        save_dict['spec'] = []
        save_dict['f1'] = []
        save_dict['step_num'] = []
        if mdlParams['print_trainerr']:
            save_dict_train = {}
            save_dict_train['acc'] = []
            save_dict_train['loss'] = []
            save_dict_train['wacc'] = []
            save_dict_train['auc'] = []
            save_dict_train['sens'] = []
            save_dict_train['spec'] = []
            save_dict_train['f1'] = []
            save_dict_train['step_num'] = []
            # Potentially calculate setMean to subtract
        elif mdlParams['balance_classes'] == 9:
            # Only use official indicies for calculation
            print("Balance 9")
            indices_ham = mdlParams['trainInd']
            print("Length of trainInd line " + str(len(indices_ham)))
            if mdlParams['numClasses'] == 9:
                class_weights_ = 1.0 / np.mean(
                    mdlParams['labels_array'][indices_ham, :8], axis=0)
                # print("class before",class_weights_)
                class_weights = np.zeros([mdlParams['numClasses']])
                class_weights[:8] = class_weights_
                class_weights[-1] = np.max(class_weights_)
            else:
                print(len(mdlParams['labels_array']))
                class_weights = 1.0 / np.mean(mdlParams['labels_array'],
                                              axis=0)
            print("Current class weights", class_weights)
            if isinstance(mdlParams['extra_fac'], float):
                class_weights = np.power(class_weights, mdlParams['extra_fac'])
            else:
                class_weights = class_weights * mdlParams['extra_fac']
            print("Current class weights with extra", class_weights)

            # Meta scaler
        if mdlParams.get('meta_features',
                         None) is not None and mdlParams['scale_features']:
            mdlParams[
                'feature_scaler_meta'] = sklearn.preprocessing.StandardScaler(
                ).fit(mdlParams['meta_array'][mdlParams['trainInd'], :])
            print("scaler mean", mdlParams['feature_scaler_meta'].mean_, "var",
                  mdlParams['feature_scaler_meta'].var_)

            # Set up dataloaders
        num_workers = psutil.cpu_count(logical=False)
        # For train
        dataset_train = utils.ISICDataset(mdlParams, 'trainInd')
        # For val
        dataset_val = utils.ISICDataset(mdlParams, 'valInd')
        if mdlParams['multiCropEval'] > 0:
            modelVars['dataloader_valInd'] = DataLoader(
                dataset_val,
                batch_size=mdlParams['multiCropEval'],
                shuffle=False,
                num_workers=num_workers,
                pin_memory=True)
        else:
            modelVars['dataloader_valInd'] = DataLoader(
                dataset_val,
                batch_size=mdlParams['batchSize'],
                shuffle=False,
                num_workers=num_workers,
                pin_memory=True)

        if mdlParams['balance_classes'] == 12 or mdlParams[
                'balance_classes'] == 13:
            # print(np.argmax(mdlParams['labels_array'][mdlParams['trainInd'],:],1).size(0))
            strat_sampler = utils.StratifiedSampler(mdlParams)
            modelVars['dataloader_trainInd'] = DataLoader(
                dataset_train,
                batch_size=mdlParams['batchSize'],
                sampler=strat_sampler,
                num_workers=num_workers,
                pin_memory=True)
        else:
            modelVars['dataloader_trainInd'] = DataLoader(
                dataset_train,
                batch_size=mdlParams['batchSize'],
                shuffle=True,
                num_workers=num_workers,
                pin_memory=True,
                drop_last=True)
            # print("Setdiff",np.setdiff1d(mdlParams['trainInd'],mdlParams['trainInd']))
        # Define model
        modelVars['model'] = models.getModel(mdlParams)()
        # Load trained model
        if mdlParams.get('meta_features', None) is not None:
            # Find best checkpoint
            files = glob(mdlParams['model_load_path'] + '/CVSet' + str(cv) +
                         str(cuda_str) + '/*')
            # files = glob(mdlParams['model_load_path'] + '\CVSet' + str(cv) + '\*')
            global_steps = np.zeros([len(files)])
            # print("files",files)
            for i in range(len(files)):
                # Use meta files to find the highest index
                if 'best' not in files[i]:
                    continue
                if 'checkpoint' not in files[i]:
                    continue
                    # Extract global step
                nums = [int(s) for s in re.findall(r'\d+', files[i])]
                global_steps[i] = nums[-1]
            # Create path with maximum global step found
            chkPath = mdlParams['model_load_path'] + '/CVSet' + str(
                cv) + '/checkpoint_best-' + str(int(
                    np.max(global_steps))) + '.pt'
            # chkPath = mdlParams['model_load_path'] + '\CVSet' + str(cv) + '\checkpoint_best-' + str(
            #     int(np.max(global_steps))) + '.pt'
            print("Restoring lesion-trained CNN for meta data training: ",
                  chkPath)
            # Load
            state = torch.load(chkPath)
            # Initialize model
            curr_model_dict = modelVars['model'].state_dict()
            for name, param in state['state_dict'].items():
                # print(name,param.shape)
                if isinstance(param, nn.Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                if curr_model_dict[name].shape == param.shape:
                    curr_model_dict[name].copy_(param)
                else:
                    print("not restored", name, param.shape)
            # modelVars['model'].load_state_dict(state['state_dict'])
        # Original input size
        # if 'Dense' not in mdlParams['model_type']:
        #    print("Original input size",modelVars['model'].input_size)
        # print(modelVars['model'])
        # Define classifier layer
        if 'Dense' in mdlParams['model_type']:
            if mdlParams['input_size'][0] != 224:
                modelVars['model'] = utils.modify_densenet_avg_pool(
                    modelVars['model'])
                # print(modelVars['model'])
            num_ftrs = modelVars['model'].classifier.in_features
            modelVars['model'].classifier = nn.Linear(num_ftrs,
                                                      mdlParams['numClasses'])
            # print(modelVars['model'])
        elif 'dpn' in mdlParams['model_type']:
            num_ftrs = modelVars['model'].classifier.in_channels
            modelVars['model'].classifier = nn.Conv2d(num_ftrs,
                                                      mdlParams['numClasses'],
                                                      [1, 1])
            # modelVars['model'].add_module('real_classifier',nn.Linear(num_ftrs, mdlParams['numClasses']))
            # print(modelVars['model'])
        elif 'efficient' in mdlParams['model_type'] and ('0' in mdlParams['model_type'] or '1' in mdlParams['model_type'] \
                or '2' in mdlParams['model_type'] or '3' in mdlParams['model_type']):
            num_ftrs = modelVars['model'].classifier.in_features
            modelVars['model'].classifier = nn.Linear(num_ftrs,
                                                      mdlParams['numClasses'])

        elif 'efficient' in mdlParams['model_type']:
            num_ftrs = modelVars['model']._fc.in_features
            # Replace '_fc' itself; assigning a new '.classifier' attribute would
            # leave the original EfficientNet head in place
            modelVars['model']._fc = nn.Linear(num_ftrs,
                                               mdlParams['numClasses'])

        elif 'wsl' in mdlParams['model_type'] or 'Resnet' in mdlParams[
                'model_type'] or 'Inception' in mdlParams['model_type']:
            # These models expose the classification head as 'fc'
            num_ftrs = modelVars['model'].fc.in_features
            modelVars['model'].fc = nn.Linear(num_ftrs,
                                              mdlParams['numClasses'])
        else:
            num_ftrs = modelVars['model'].last_linear.in_features
            modelVars['model'].last_linear = nn.Linear(num_ftrs,
                                                       mdlParams['numClasses'])
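        # Hedged sketch: the Linear-head branches above all follow one pattern; a
        # generic helper (illustrative, not a repo function) could read:
        # def replace_head(model, attr_name, num_classes):
        #     old_head = getattr(model, attr_name)
        #     setattr(model, attr_name, nn.Linear(old_head.in_features, num_classes))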
        # Take care of the meta-data case
        if mdlParams.get('meta_features', None) is not None:
            # freeze cnn first
            if mdlParams['freeze_cnn']:
                # deactivate all
                for param in modelVars['model'].parameters():
                    param.requires_grad = False
                if 'wsl' in mdlParams['model_type'] or 'Resnet' in mdlParams[
                        'model_type'] or 'Inception' in mdlParams['model_type']:
                    # Activate classifier layer
                    for param in modelVars['model'].fc.parameters():
                        param.requires_grad = True
                elif ('efficient' in mdlParams['model_type'] and ('0' in mdlParams['model_type'] or '1' in mdlParams['model_type'] \
                        or '2' in mdlParams['model_type'] or '3' in mdlParams['model_type'])) or 'Dense' in mdlParams['model_type']:
                    # Activate classifier layer
                    for param in modelVars['model'].classifier.parameters():
                        param.requires_grad = True
                elif 'efficient' in mdlParams['model_type']:
                    # Activate classifier layer
                    for param in modelVars['model']._fc.parameters():
                        param.requires_grad = True
                else:
                    # Activate classifier layer
                    for param in modelVars['model'].last_linear.parameters():
                        param.requires_grad = True
            else:
                # mark cnn parameters
                for param in modelVars['model'].parameters():
                    param.is_cnn_param = True
                # unmark fc
                for param in modelVars['model']._fc.parameters():
                    param.is_cnn_param = False
            # Modify the model to take the meta data as additional input
            modelVars['model'] = models.modify_meta(mdlParams,
                                                    modelVars['model'])
            # Mark new parameters
            for param in modelVars['model'].parameters():
                if not hasattr(param, 'is_cnn_param'):
                    param.is_cnn_param = False
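            # Hedged sanity check (illustrative): count trainable vs. frozen parameters
            # n_trainable = sum(p.numel() for p in modelVars['model'].parameters()
            #                   if p.requires_grad)
            # n_total = sum(p.numel() for p in modelVars['model'].parameters())
            # print('Trainable parameters: %d / %d' % (n_trainable, n_total))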
        # Multi-GPU support
        if len(mdlParams['numGPUs']) > 1:
            modelVars['model'] = nn.DataParallel(modelVars['model'])
        modelVars['model'] = modelVars['model'].cuda()
        # summary(modelVars['model'], modelVars['model'].input_size)# (mdlParams['input_size'][2], mdlParams['input_size'][0], mdlParams['input_size'][1]))
        # Loss, with class weighting
        if mdlParams.get('focal_loss', False):
            modelVars['criterion'] = utils.FocalLoss(
                alpha=class_weights.tolist())
        elif mdlParams['balance_classes'] == 3 or mdlParams[
                'balance_classes'] == 0 or mdlParams['balance_classes'] == 12:
            modelVars['criterion'] = nn.CrossEntropyLoss()
        elif mdlParams['balance_classes'] == 8:
            # Per-example losses (no reduction); re-weighted manually in the loop
            modelVars['criterion'] = nn.CrossEntropyLoss(reduction='none')
        elif mdlParams['balance_classes'] == 6 or mdlParams[
                'balance_classes'] == 7:
            modelVars['criterion'] = nn.CrossEntropyLoss(
                weight=torch.cuda.FloatTensor(class_weights.astype(
                    np.float32)),
                reduction='none')
        elif mdlParams['balance_classes'] == 10:
            modelVars['criterion'] = utils.FocalLoss(mdlParams['numClasses'])
        elif mdlParams['balance_classes'] == 11:
            modelVars['criterion'] = utils.FocalLoss(
                mdlParams['numClasses'],
                alpha=torch.cuda.FloatTensor(class_weights.astype(np.float32)))
        else:
            modelVars['criterion'] = nn.CrossEntropyLoss(
                weight=torch.cuda.FloatTensor(class_weights.astype(
                    np.float32)))
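        # Note: the reduction='none' criteria above return one loss per example; these
        # are re-weighted with mdlParams['loss_fac_per_example'] and averaged manually
        # in the training loop below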

        if mdlParams.get('meta_features', None) is not None:
            if mdlParams['freeze_cnn']:
                modelVars['optimizer'] = optim.Adam(
                    filter(lambda p: p.requires_grad,
                           modelVars['model'].parameters()),
                    lr=mdlParams['learning_rate_meta'])
                # Sanity check: print the parameters that remain trainable
                for name, param in modelVars['model'].named_parameters():
                    if param.requires_grad:
                        print(name, param.shape)
            else:
                modelVars['optimizer'] = optim.Adam(
                    [{
                        'params':
                        filter(lambda p: not p.is_cnn_param,
                               modelVars['model'].parameters()),
                        'lr':
                        mdlParams['learning_rate_meta']
                    }, {
                        'params':
                        filter(lambda p: p.is_cnn_param,
                               modelVars['model'].parameters()),
                        'lr':
                        mdlParams['learning_rate']
                    }],
                    lr=mdlParams['learning_rate'])
        else:
            modelVars['optimizer'] = optim.Adam(
                modelVars['model'].parameters(), lr=mdlParams['learning_rate'])
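        # Design note: with meta features and an unfrozen CNN, two learning rates are
        # used: newly added meta/classifier parameters train at 'learning_rate_meta',
        # while the pretrained CNN body keeps the base 'learning_rate'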

        # Decay the LR by a factor of 1/LRstep every 'lowerLRAfter' epochs
        modelVars['scheduler'] = lr_scheduler.StepLR(
            modelVars['optimizer'],
            step_size=mdlParams['lowerLRAfter'],
            gamma=1 / np.float32(mdlParams['LRstep']))

        # Define softmax
        modelVars['softmax'] = nn.Softmax(dim=1)

        # Set up training
        # Optionally resume from a saved checkpoint
        if load_old:
            # Find the last regular checkpoint (not the best-model checkpoint)
            files = glob(mdlParams['saveDir'] + '/*')
            global_steps = np.zeros([len(files)])
            for i in range(len(files)):
                # Skip best-model checkpoints; only regular ones count here
                if 'best' in files[i]:
                    continue
                if 'checkpoint-' not in files[i]:
                    continue
                # Extract the global step from the file name
                nums = [int(s) for s in re.findall(r'\d+', files[i])]
                global_steps[i] = nums[-1]
            # Create path with maximum global step found
            chkPath = mdlParams['saveDir'] + '/checkpoint-' + str(
                int(np.max(global_steps))) + '.pt'
            print("Restoring: ", chkPath)
            # Load
            state = torch.load(chkPath)
            # Initialize model and optimizer
            modelVars['model'].load_state_dict(state['state_dict'])
            modelVars['optimizer'].load_state_dict(state['optimizer'])
            start_epoch = state['epoch'] + 1
            mdlParams['valBest'] = state.get('valBest', 1000)
            mdlParams['lastBestInd'] = state.get('lastBestInd',
                                                 int(np.max(global_steps)))
        else:
            start_epoch = 1
            mdlParams['lastBestInd'] = -1
            # Track the best validation metric; 1000 is a sentinel so the first
            # evaluation always improves on it (lower is better)
            mdlParams['valBest'] = 1000

        # Num batches
        numBatchesTrain = int(
            math.floor(len(mdlParams['trainInd']) / mdlParams['batchSize']))
        print("Train batches", numBatchesTrain)

        # Run training
        start_time = time.time()
        print("Start training...")
        for step in range(start_epoch, mdlParams['training_steps'] + 1):
            # One epoch of training
            # Step the per-epoch LR scheduler (only once training nears 'lowerLRat')
            if step >= mdlParams['lowerLRat'] - mdlParams['lowerLRAfter']:
                modelVars['scheduler'].step()
            modelVars['model'].train()
            for j, (inputs, labels, indices,
                    _) in enumerate(modelVars['dataloader_trainInd']):
                # print(indices)
                # t_load = time.time()
                # Run optimization
                if mdlParams.get('meta_features', None) is not None:
                    inputs[0] = inputs[0].cuda()
                    inputs[1] = inputs[1].cuda()
                else:
                    inputs = inputs.cuda()
                # print(inputs.shape)
                labels = labels.cuda()
                # zero the parameter gradients
                modelVars['optimizer'].zero_grad()
                # Forward pass; track gradient history only during training
                with torch.set_grad_enabled(True):
                    if mdlParams.get('aux_classifier', False):
                        outputs, outputs_aux = modelVars['model'](inputs)
                        loss1 = modelVars['criterion'](outputs, labels)
                        labels_aux = labels.repeat(mdlParams['multiCropTrain'])
                        loss2 = modelVars['criterion'](outputs_aux, labels_aux)
                        loss = loss1 + mdlParams[
                            'aux_classifier_loss_fac'] * loss2
                    else:
                        # print("load",time.time()-t_load)
                        # t_fwd = time.time()
                        outputs = modelVars['model'](inputs)
                        # print("forward",time.time()-t_fwd)
                        # t_bwd = time.time()
                        loss = modelVars['criterion'](outputs, labels)
                    # Optionally re-weight the per-example loss by example-specific factors
                    if mdlParams['balance_classes'] == 6 or mdlParams[
                            'balance_classes'] == 7 or mdlParams[
                                'balance_classes'] == 8:
                        # loss = loss.cpu()
                        indices = indices.numpy()
                        loss = loss * torch.cuda.FloatTensor(
                            mdlParams['loss_fac_per_example'][indices].astype(
                                np.float32))
                        loss = torch.mean(loss)
                        # loss = loss.cuda()
                    # backward + optimize only if in training phase
                    loss.backward()
                    modelVars['optimizer'].step()
                    # print("backward",time.time()-t_bwd)
            if step % mdlParams['display_step'] == 0 or step == 1:
                # Calculate evaluation metrics
                if mdlParams['classification']:
                    # Adjust model state
                    modelVars['model'].eval()
                    # Get metrics
                    loss, accuracy, sensitivity, specificity, conf_matrix, f1, auc, waccuracy, predictions, targets, _, _ = utils.getErrClassification_mgpu(
                        mdlParams, eval_set, modelVars)
                    # Save in mat
                    save_dict['loss'].append(loss)
                    save_dict['acc'].append(accuracy)
                    save_dict['wacc'].append(waccuracy)
                    save_dict['auc'].append(auc)
                    save_dict['sens'].append(sensitivity)
                    save_dict['spec'].append(specificity)
                    save_dict['f1'].append(f1)
                    save_dict['step_num'].append(step)
                    if os.path.isfile(mdlParams['saveDir'] + '/progression_' +
                                      eval_set + '.mat'):
                        os.remove(mdlParams['saveDir'] + '/progression_' +
                                  eval_set + '.mat')
                    io.savemat(
                        mdlParams['saveDir'] + '/progression_' + eval_set +
                        '.mat', save_dict)
                # Lower is better: eval_metric is the negative mean per-class accuracy
                eval_metric = -np.mean(waccuracy)
                # Check whether we have a new best value
                if eval_metric < mdlParams['valBest']:
                    mdlParams['valBest'] = eval_metric
                    if mdlParams['classification']:
                        allData['f1Best'][cv] = f1
                        allData['sensBest'][cv] = sensitivity
                        allData['specBest'][cv] = specificity
                        allData['accBest'][cv] = accuracy
                        allData['waccBest'][cv] = waccuracy
                        allData['aucBest'][cv] = auc
                    oldBestInd = mdlParams['lastBestInd']
                    mdlParams['lastBestInd'] = step
                    allData['convergeTime'][cv] = step
                    # Save best predictions
                    allData['bestPred'][cv] = predictions
                    allData['targets'][cv] = targets
                    # Write results to file
                    with open(mdlParams['saveDirBase'] + '/CV.pkl', 'wb') as f:
                        pickle.dump(allData, f, pickle.HIGHEST_PROTOCOL)
                    # Delete the previously best model checkpoint
                    if os.path.isfile(mdlParams['saveDir'] +
                                      '/checkpoint_best-' + str(oldBestInd) +
                                      '.pt'):
                        os.remove(mdlParams['saveDir'] + '/checkpoint_best-' +
                                  str(oldBestInd) + '.pt')
                    # Save currently best model
                    state = {
                        'epoch': step,
                        'valBest': mdlParams['valBest'],
                        'lastBestInd': mdlParams['lastBestInd'],
                        'state_dict': modelVars['model'].state_dict(),
                        'optimizer': modelVars['optimizer'].state_dict()
                    }
                    torch.save(
                        state, mdlParams['saveDir'] + '/checkpoint_best-' +
                        str(step) + '.pt')

                # If it is not a new best, still save the current model and delete the
                # previous regular checkpoint (best checkpoints are handled above)
                state = {
                    'epoch': step,
                    'valBest': mdlParams['valBest'],
                    'lastBestInd': mdlParams['lastBestInd'],
                    'state_dict': modelVars['model'].state_dict(),
                    'optimizer': modelVars['optimizer'].state_dict()
                }
                torch.save(
                    state,
                    mdlParams['saveDir'] + '/checkpoint-' + str(step) + '.pt')
                # Delete last one
                if step == mdlParams['display_step']:
                    lastInd = 1
                else:
                    lastInd = step - mdlParams['display_step']
                if os.path.isfile(mdlParams['saveDir'] + '/checkpoint-' +
                                  str(lastInd) + '.pt'):
                    os.remove(mdlParams['saveDir'] + '/checkpoint-' +
                              str(lastInd) + '.pt')
                # Duration so far
                duration = time.time() - start_time
                # Print
                if mdlParams['classification']:
                    print("\n")
                    print("Config:", sys.argv[2])
                    print('Fold: %d Epoch: %d/%d (%d h %d m %d s)' %
                          (cv, step, mdlParams['training_steps'],
                           int(duration / 3600),
                           int(np.mod(duration, 3600) / 60),
                           int(np.mod(np.mod(duration, 3600), 60))) +
                          time.strftime("%d.%m.-%H:%M:%S", time.localtime()))
                    print("Loss on ", eval_set, "set: ", loss, " Accuracy: ",
                          accuracy, " F1: ", f1, " (best WACC: ",
                          -mdlParams['valBest'], " at Epoch ",
                          mdlParams['lastBestInd'], ")")
                    print("Auc", auc, "Mean AUC", np.mean(auc))
                    print("Per Class Acc", waccuracy, "Weighted Accuracy",
                          np.mean(waccuracy))
                    print("Sensitivity: ", sensitivity, "Specificity",
                          specificity)
                    print("Confusion Matrix")
                    print(conf_matrix)
                    # Potentially peek at test error
                    if mdlParams['peak_at_testerr']:
                        loss, accuracy, sensitivity, specificity, _, f1, _, _, _, _, _, _ = utils.getErrClassification_mgpu(
                            mdlParams, 'testInd', modelVars)
                        print("Test loss: ", loss, " Accuracy: ", accuracy,
                              " F1: ", f1)
                        print("Sensitivity: ", sensitivity, "Specificity",
                              specificity)
                    # Potentially print train err
                    if mdlParams['print_trainerr'] and 'train' not in eval_set:
                        loss, accuracy, sensitivity, specificity, conf_matrix, f1, auc, waccuracy, predictions, targets, _, _ = utils.getErrClassification_mgpu(
                            mdlParams, 'trainInd', modelVars)
                        # Save in mat
                        save_dict_train['loss'].append(loss)
                        save_dict_train['acc'].append(accuracy)
                        save_dict_train['wacc'].append(waccuracy)
                        save_dict_train['auc'].append(auc)
                        save_dict_train['sens'].append(sensitivity)
                        save_dict_train['spec'].append(specificity)
                        save_dict_train['f1'].append(f1)
                        save_dict_train['step_num'].append(step)
                        if os.path.isfile(mdlParams['saveDir'] +
                                          '/progression_trainInd.mat'):
                            os.remove(mdlParams['saveDir'] +
                                      '/progression_trainInd.mat')
                        scipy.io.savemat(
                            mdlParams['saveDir'] + '/progression_trainInd.mat',
                            save_dict_train)
                        print("Train loss: ", loss, " Accuracy: ", accuracy,
                              " F1: ", f1)
                        print("Sensitivity: ", sensitivity, "Specificity",
                              specificity)
        # Free everything in modelvars
        modelVars.clear()
        # After CV training: print the best results for this fold
        print("Best F1:", allData['f1Best'][cv])
        print("Best Sens:", allData['sensBest'][cv])
        print("Best Spec:", allData['specBest'][cv])
        print("Best Acc:", allData['accBest'][cv])
        print("Best Per Class Accuracy:", allData['waccBest'][cv])
        print("Best Weighted Acc:", np.mean(allData['waccBest'][cv]))
        print("Best AUC:", allData['aucBest'][cv])
        print("Best Mean AUC:", np.mean(allData['aucBest'][cv]))
        print("Convergence Steps:", allData['convergeTime'][cv])
testloader = data.DataLoader(cityscapesDataSet(DATA_DIRECTORY,
                                               DATA_LIST_PATH,
                                               crop_size=(512, 256),
                                               mean=IMG_MEAN,
                                               scale=False,
                                               mirror=False,
                                               set=SET,
                                               transform=val_transform),
                             batch_size=cfg.batch,
                             shuffle=False,
                             pin_memory=True)
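
# Note: despite its name, 'testloader' drives the adaptation training loop below;
# each batch yields an adapted image together with the corresponding real
# image/label pair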

fusion_loss = FusionLoss()
jacc_loss = JaccardLoss()
mse_loss = torch.nn.MSELoss()
adapt_loss = utils.FocalLoss(ignore_index=255, size_average=True)

interval_loss = 0
cur_itrs = cfg.cur_itrs
cur_epochs = cfg.cur_epochs
total_iter = cfg.total_iter
output_stride = cfg.stride
############################# Training #########################
print(" \n ###################### Training ############### \n")
while True:
    net.train()
    cur_epochs += 1
    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processed' % index)
        image_adapt, _, name, image_real, target_real = batch