Beispiel #1
0
    def __init__(self, n_epochs, lr):
        """Initialize training state: config, data pipeline, model, loss, optimizer.

        Args:
            n_epochs: total number of training epochs.
            lr: learning rate consumed by the optimizer created in build_opt().
        """
        self.n_epochs = n_epochs
        self.lr = lr

        # Project helper returning the datasets and their dataloaders
        # (presumably keyed by phase, e.g. 'train'/'val' — confirm against helper).
        self.datasets, self.dataloaders = get_ds_loaders()
        self.model = CustomNet().to(device)
        self.criterion = torch.nn.CrossEntropyLoss()

        # Must run after self.model and self.lr are set: builds the optimizer
        # (and scheduler) over the model's trainable parameters.
        self.build_opt()
        self.vis = Visualizer()
Beispiel #2
0
class Train():
    """Fine-tuning training loop.

    Runs alternating train/val phases for ``n_epochs``, logs per-step and
    per-epoch statistics to a Visualizer, tracks the best validation
    accuracy, and restores the best weights at the end of training.
    """

    def __init__(self, n_epochs, lr):
        """Set up data, model, loss, optimizer/scheduler and visualizer.

        Args:
            n_epochs: total number of training epochs.
            lr: learning rate used by the SGD optimizer built in build_opt().
        """
        self.n_epochs = n_epochs
        self.lr = lr

        self.datasets, self.dataloaders = get_ds_loaders()
        self.model = CustomNet().to(device)
        self.criterion = torch.nn.CrossEntropyLoss()

        self.build_opt()
        self.vis = Visualizer()

        # for name, param in self.model.named_parameters():
        #     print(name, param.requires_grad)

    def build_opt(self):
        """Create the SGD optimizer over trainable parameters and the LR scheduler."""
        # Only parameters with requires_grad=True are optimized — frozen
        # layers (e.g. a pre-trained backbone) are excluded.
        self.optimizer = optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr,
                                   momentum=0.9)  # train only FC layer parameter
        # print(self.optimizer.param_groups) # => only attached layers params

        # Decay LR by a factor of 0.1 every 7 epochs
        self.exp_lr_scheduler = lr_scheduler.StepLR(self.optimizer, step_size=7, gamma=0.1)

    def train(self):
        """Train for n_epochs and return the model loaded with the best-val weights."""
        start = time.time()

        best_model_wts = copy.deepcopy(self.model.state_dict())
        best_acc = 0.0

        for epoch in range(self.n_epochs):
            print("Epoch [%d/%d]" % (epoch + 1, self.n_epochs))
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    self.model.train()  # Set model to training mode
                else:
                    self.model.eval()  # Set model to evaluate mode

                epoch_loss = []
                epoch_acc = []

                # Iterate over data.
                for step, (inputs, labels) in enumerate(self.dataloaders[phase]):
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    self.optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):

                        # for param in self.model.parameters():
                        #     print(param[0], param.requires_grad)
                        # => if param.requires_grad == False, param doesn't change

                        outputs = self.model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = self.criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            self.optimizer.step()

                    # statistics
                    step_loss = loss.item()
                    step_acc = float(torch.sum(preds == labels.data)) / len(preds)

                    if phase == 'train':
                        print("[%d/%d] [%d/%d] loss: %.3f acc: %.3f" % (
                        epoch + 1, self.n_epochs, step + 1, len(self.dataloaders[phase]), step_loss, step_acc))
                        self.vis.plot("Train loss plot per step", step_loss)
                        self.vis.plot("Train acc plot per step", step_acc)

                    epoch_loss.append(step_loss)
                    epoch_acc.append(step_acc)

                # BUGFIX: step the LR scheduler AFTER this epoch's optimizer
                # steps. Since PyTorch 1.1.0 calling scheduler.step() before
                # optimizer.step() skips the first value of the LR schedule
                # (and emits a UserWarning).
                if phase == 'train':
                    self.exp_lr_scheduler.step()

                epoch_loss = np.mean(epoch_loss)
                epoch_acc = np.mean(epoch_acc)

                print("[%d/%d] phase=%s: Avg loss: %.3f Avg acc: %.3f" % (
                epoch + 1, self.n_epochs, phase, epoch_loss, epoch_acc))
                self.vis.plot("%s avg loss plot per epoch" % phase, epoch_loss)
                self.vis.plot("%s avg acc plot per epoch" % phase, epoch_acc)

                # deep copy the model
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(self.model.state_dict())

            print()

        time_elapsed = time.time() - start
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:4f}'.format(best_acc))

        # load best model weights
        self.model.load_state_dict(best_model_wts)
        return self.model
Beispiel #3
0
 def build_net(self):
     """Instantiate the network and attach it to the target device."""
     net = CustomNet()
     self.model = net.to(device)
Beispiel #4
0
def main():
    """Train a U^2-Net-style segmentation model.

    Supports optional GAN training (discriminator + adversarial loss),
    VGG perceptual loss, mixup/heavy augmentation, multiscale/fullsize
    training, Apex mixed precision and multi-GPU (DataParallel).
    Checkpoints are written to ``./saved_models/<model_name>/`` every
    ``save_frq`` iterations; stats go to TensorBoard via SummaryWriter.
    """
    # ---------------------------------------------------------
    # Configurations
    # ---------------------------------------------------------

    heavy_augmentation = True  # False to use author's default implementation
    gan_training = False
    mixup_augmentation = False
    fullsize_training = False
    multiscale_training = False
    multi_gpu = True
    mixed_precision_training = True

    model_name = "u2net"  # "u2net", "u2netp", "u2net_heavy"
    se_type = None  # "csse", "sse", "cse", None; None to use author's default implementation
    # checkpoint = "saved_models/u2net/u2net.pth"
    checkpoint = None
    checkpoint_netD = None

    # Loss weights for the adversarial and perceptual terms (GAN training only).
    w_adv = 0.2
    w_vgg = 0.2

    train_dirs = [
        "../datasets/sky_segmentation_dataset/datasets/cvprw2020_sky_seg/train/"
    ]
    train_dirs_file_limit = [
        None,
    ]

    image_ext = '.jpg'
    label_ext = '.png'
    dataset_name = "cvprw2020_sky_seg"

    lr = 0.0003
    epoch_num = 500
    batch_size_train = 48
    # batch_size_val = 1
    workers = 16
    save_frq = 1000  # save the model every `save_frq` iterations

    save_debug_samples = False
    debug_samples_dir = "./debug/"

    # ---------------------------------------------------------

    model_dir = './saved_models/' + model_name + '/'
    os.makedirs(model_dir, exist_ok=True)

    writer = SummaryWriter()

    if fullsize_training:
        # Full-resolution images vary in size, so batching/multiscale is off.
        batch_size_train = 1
        multiscale_training = False

    # ---------------------------------------------------------
    # 1. Construct data input pipeline
    # ---------------------------------------------------------

    # Get dataset name
    dataset_name = dataset_name.replace(" ", "_")

    # Get training data
    assert len(train_dirs) == len(train_dirs_file_limit), \
        "Different train dirs and train dirs file limit length!"

    tra_img_name_list = []
    tra_lbl_name_list = []
    for d, flimit in zip(train_dirs, train_dirs_file_limit):
        img_files = glob.glob(d + '**/*' + image_ext, recursive=True)
        if flimit:
            # Optionally subsample this directory to at most `flimit` images.
            img_files = np.random.choice(img_files, size=flimit, replace=False)

        print(f"directory: {d}, files: {len(img_files)}")

        for img_path in img_files:
            # Labels live in a parallel "alpha" directory with a different ext.
            lbl_path = img_path.replace("/image/", "/alpha/") \
                .replace(image_ext, label_ext)

            if os.path.exists(img_path) and os.path.exists(lbl_path):
                assert os.path.splitext(
                    os.path.basename(img_path))[0] == os.path.splitext(
                        os.path.basename(lbl_path))[0], "Wrong filename."

                tra_img_name_list.append(img_path)
                tra_lbl_name_list.append(lbl_path)
            else:
                print(
                    f"Warning, dropping sample {img_path} because label file {lbl_path} not found!"
                )

    tra_img_name_list, tra_lbl_name_list = shuffle(tra_img_name_list,
                                                   tra_lbl_name_list)

    train_num = len(tra_img_name_list)
    # val_num = 0  # unused
    print(f"dataset name        : {dataset_name}")
    print(f"training samples    : {train_num}")

    # Construct data input pipeline
    if heavy_augmentation:
        transform = AlbuSampleTransformer(
            get_heavy_transform(
                fullsize_training=fullsize_training,
                transform_size=False if
                (fullsize_training or multiscale_training) else True))
    else:
        transform = transforms.Compose([
            RescaleT(320),
            RandomCrop(288),
        ])

    # Create dataset and dataloader
    dataset_kwargs = dict(img_name_list=tra_img_name_list,
                          lbl_name_list=tra_lbl_name_list,
                          transform=transforms.Compose([
                              transform,
                          ] + ([
                              SaveDebugSamples(out_dir=debug_samples_dir),
                          ] if save_debug_samples else []) + ([
                              ToTensorLab(flag=0),
                          ] if not multiscale_training else [])))
    if mixup_augmentation:
        _dataset_cls = MixupAugSalObjDataset
    else:
        _dataset_cls = SalObjDataset

    salobj_dataset = _dataset_cls(**dataset_kwargs)
    salobj_dataloader = DataLoader(
        salobj_dataset,
        batch_size=batch_size_train,
        collate_fn=multi_scale_collater if multiscale_training else None,
        shuffle=True,
        pin_memory=True,
        num_workers=workers)

    # ---------------------------------------------------------
    # 2. Load model
    # ---------------------------------------------------------

    # Instantiate model
    if model_name == "u2net":
        net = U2NET(3, 1, se_type=se_type)
    elif model_name == "u2netp":
        net = U2NETP(3, 1, se_type=se_type)
    elif model_name == "u2net_heavy":
        net = u2net_heavy()
    elif model_name == "custom":
        net = CustomNet()
    else:
        raise ValueError(f"Unknown model_name: {model_name}")

    # Restore model weights from checkpoint
    if checkpoint:
        if not os.path.exists(checkpoint):
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint}")

        try:
            print(f"Restoring from checkpoint: {checkpoint}")
            net.load_state_dict(torch.load(checkpoint, map_location="cpu"))
            print(" - [x] success")
        # BUGFIX: bare `except:` also swallowed KeyboardInterrupt/SystemExit
        # and hid the actual failure reason.
        except Exception as e:
            print(f" - [!] error: {e}")

    if torch.cuda.is_available():
        net.cuda()

    if gan_training:
        netD = MultiScaleNLayerDiscriminator()

        if checkpoint_netD:
            if not os.path.exists(checkpoint_netD):
                raise FileNotFoundError(
                    f"Discriminator checkpoint file not found: {checkpoint_netD}"
                )

            try:
                print(
                    f"Restoring discriminator from checkpoint: {checkpoint_netD}"
                )
                netD.load_state_dict(
                    torch.load(checkpoint_netD, map_location="cpu"))
                print(" - [x] success")
            # BUGFIX: narrow the bare except and surface the error.
            except Exception as e:
                print(f" - [!] error: {e}")

        if torch.cuda.is_available():
            netD.cuda()

        # Frozen VGG19 feature extractor for the perceptual loss.
        vgg19 = VGG19Features()
        vgg19.eval()
        if torch.cuda.is_available():
            vgg19 = vgg19.cuda()

    # ---------------------------------------------------------
    # 3. Define optimizer
    # ---------------------------------------------------------

    optimizer = optim.Adam(net.parameters(),
                           lr=lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)
    # optimizer = optim.SGD(net.parameters(), lr=lr)
    # scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=lr/4, max_lr=lr,
    #                                         mode="triangular2",
    #                                         step_size_up=2 * len(salobj_dataloader))

    if gan_training:
        optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.9))

    # ---------------------------------------------------------
    # 4. Initialize AMP and data parallel stuffs
    # ---------------------------------------------------------

    GOT_AMP = False
    if mixed_precision_training:
        try:
            print("Checking for Apex AMP support...")
            from apex import amp
            GOT_AMP = True
            print(" - [x] yes")
        except ImportError:
            print(" - [!] no")

    if GOT_AMP:
        amp.register_float_function(torch, 'sigmoid')
        net, optimizer = amp.initialize(net, optimizer, opt_level="O1")

        if gan_training:
            netD, optimizerD = amp.initialize(netD, optimizerD, opt_level="O1")
            vgg19 = amp.initialize(vgg19, opt_level="O1")

    if torch.cuda.device_count() > 1 and multi_gpu:
        print(f"Multi-GPU training using {torch.cuda.device_count()} GPUs.")
        net = nn.DataParallel(net)

        if gan_training:
            netD = nn.DataParallel(netD)
            vgg19 = nn.DataParallel(vgg19)
    else:
        print(f"Training using {torch.cuda.device_count()} GPUs.")

    # ---------------------------------------------------------
    # 5. Training
    # ---------------------------------------------------------

    print("Start training...")

    ite_num = 0
    ite_num4val = 0  # iterations since the last checkpoint (stats window)
    running_loss = 0.0
    running_bce_loss = 0.0
    running_tar_loss = 0.0
    running_adv_loss = 0.0
    running_per_loss = 0.0
    running_fake_loss = 0.0
    running_real_loss = 0.0
    running_lossD = 0.0

    for epoch in tqdm(range(0, epoch_num), desc="All epochs"):
        net.train()
        if gan_training:
            netD.train()

        for i, data in enumerate(
                tqdm(salobj_dataloader, desc=f"Epoch #{epoch}")):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            image_key = "image"
            label_key = "label"
            inputs, labels = data[image_key], data[label_key]
            # tqdm.write(f"input tensor shape: {inputs.shape}")

            inputs = inputs.type(torch.FloatTensor)
            labels = labels.type(torch.FloatTensor)

            # Wrap them in Variable
            if torch.cuda.is_available():
                inputs_v, labels_v = \
                    Variable(inputs.cuda(), requires_grad=False), \
                    Variable(labels.cuda(), requires_grad=False)
            else:
                inputs_v, labels_v = \
                    Variable(inputs, requires_grad=False), \
                    Variable(labels, requires_grad=False)

            # # Zero the parameter gradients
            # optimizer.zero_grad()

            # Forward + backward + optimize

            d6 = 0
            if model_name == "custom":
                d0, d1, d2, d3, d4, d5 = net(inputs_v)
            else:
                d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)

            if gan_training:
                # --- Discriminator update ---
                optimizerD.zero_grad()

                # Detach d0 so generator gradients don't flow through netD's step.
                dis_fake = netD(inputs_v, d0.detach())
                dis_real = netD(inputs_v, labels_v)

                loss_fake = bce_with_logits_loss(dis_fake,
                                                 torch.zeros_like(dis_fake))
                loss_real = bce_with_logits_loss(dis_real,
                                                 torch.ones_like(dis_real))
                lossD = loss_fake + loss_real

                if GOT_AMP:
                    with amp.scale_loss(lossD, optimizerD) as scaled_loss:
                        scaled_loss.backward()
                else:
                    lossD.backward()

                optimizerD.step()

                writer.add_scalar("lossD/fake", loss_fake.item(), ite_num)
                writer.add_scalar("lossD/real", loss_real.item(), ite_num)
                writer.add_scalar("lossD/sum", lossD.item(), ite_num)
                running_fake_loss += loss_fake.item()
                running_real_loss += loss_real.item()
                running_lossD += lossD.item()

            # Zero the parameter gradients
            optimizer.zero_grad()

            if model_name == "custom":
                loss2, loss = multi_bce_loss_fusion5(d0, d1, d2, d3, d4, d5,
                                                     labels_v)
            else:
                loss2, loss = multi_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6,
                                                    labels_v)

            writer.add_scalar("lossG/bce", loss.item(), ite_num)
            running_bce_loss += loss.item()

            if gan_training:
                # BUGFIX: keep these as tensors so the .item() calls below
                # work even when w_adv / w_vgg are 0 (they were plain floats
                # before, which crashed on .item()), and predefine the VGG
                # feature vars so the `del` at the end of the step cannot
                # raise NameError when w_vgg is 0.
                loss_adv = torch.tensor(0.0, device=d0.device)
                loss_per = torch.tensor(0.0, device=d0.device)
                vgg19_fm_pred = vgg19_fm_label = None

                # Adversarial loss
                if w_adv:
                    dis_fake = netD(inputs_v, d0)
                    loss_adv = bce_with_logits_loss(dis_fake,
                                                    torch.ones_like(dis_fake))

                # Perceptual loss
                if w_vgg:
                    vgg19_fm_pred = vgg19(inputs_v * d0)
                    vgg19_fm_label = vgg19(inputs_v * labels_v)
                    loss_per = mae_loss(vgg19_fm_pred, vgg19_fm_label)

                loss = loss + w_adv * loss_adv + w_vgg * loss_per

            if GOT_AMP:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            # scheduler.step()

            writer.add_scalar("lossG/sum", loss.item(), ite_num)
            writer.add_scalar("lossG/loss2", loss2.item(), ite_num)
            running_loss += loss.item()
            running_tar_loss += loss2.item()
            if gan_training:
                writer.add_scalar("lossG/adv", loss_adv.item(), ite_num)
                writer.add_scalar("lossG/perceptual", loss_per.item(), ite_num)
                running_adv_loss += loss_adv.item()
                running_per_loss += loss_per.item()

            if ite_num % 200 == 0:
                # Periodically log image triplets for visual inspection.
                writer.add_images("inputs", inv_normalize(inputs_v), ite_num)
                writer.add_images("labels", labels_v, ite_num)
                writer.add_images("preds", d0, ite_num)

            # Delete temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss
            if gan_training:
                del dis_fake, dis_real, loss_fake, loss_real, lossD, loss_adv, vgg19_fm_pred, vgg19_fm_label, loss_per

            # Print stats
            tqdm.write(
                "[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train G/sum: %3f, G/bce: %3f, G/bce_tar: %3f, G/adv: %3f, G/percept: %3f, D/fake: %3f, D/real: %3f, D/sum: %3f"
                % (epoch + 1, epoch_num,
                   (i + 1) * batch_size_train, train_num, ite_num,
                   running_loss / ite_num4val, running_bce_loss / ite_num4val,
                   running_tar_loss / ite_num4val, running_adv_loss /
                   ite_num4val, running_per_loss / ite_num4val,
                   running_fake_loss / ite_num4val, running_real_loss /
                   ite_num4val, running_lossD / ite_num4val))

            if ite_num % save_frq == 0:
                # Save checkpoint (unwrap DataParallel if present).
                torch.save(
                    net.module.state_dict() if hasattr(
                        net, "module") else net.state_dict(), model_dir +
                    model_name + (("_" + se_type) if se_type else "") +
                    ("_" + dataset_name) +
                    ("_mixup_aug" if mixup_augmentation else "") +
                    ("_heavy_aug" if heavy_augmentation else "") +
                    ("_fullsize" if fullsize_training else "") +
                    ("_multiscale" if multiscale_training else "") +
                    "_bce_itr_%d_train_%3f_tar_%3f.pth" %
                    (ite_num, running_loss / ite_num4val,
                     running_tar_loss / ite_num4val))

                if gan_training:
                    torch.save(
                        netD.module.state_dict() if hasattr(netD, "module")
                        else netD.state_dict(), model_dir + "netD_" +
                        model_name + (("_" + se_type) if se_type else "") +
                        ("_" + dataset_name) +
                        ("_mixup_aug" if mixup_augmentation else "") +
                        ("_heavy_aug" if heavy_augmentation else "") +
                        ("_fullsize" if fullsize_training else "") +
                        ("_multiscale" if multiscale_training else "") +
                        "itr_%d.pth" % (ite_num))

                # Reset stats
                running_loss = 0.0
                running_bce_loss = 0.0
                running_tar_loss = 0.0
                running_adv_loss = 0.0
                running_per_loss = 0.0
                running_fake_loss = 0.0
                running_real_loss = 0.0
                running_lossD = 0.0
                ite_num4val = 0

                net.train()  # resume train
                if gan_training:
                    netD.train()

    writer.close()
    print("Training completed successfully.")
Beispiel #5
0
def main():
    """Run salient-object segmentation inference over a directory of images.

    Loads the selected model checkpoint, runs each image through the network
    (optionally at full resolution), normalizes the primary output map and
    writes the predictions to ``prediction_dir``.
    """

    full_sized_testing = False
    model_name = 'u2net'
    # model_name = 'u2netp'

    # image_dir = './test_images/'
    # image_dir = "../detectron2_mask_prediction/predictions/image/"
    image_dir = "../datasets/DUTS-TE/image/"
    # prediction_dir = './test_data/' + model_name + '_results/'
    # model_dir = './saved_models/' + model_name + '/' + model_name + '.pth'

    model_dir = "./saved_models/u2net/u2net.pth"
    # model_dir = "./u2net_mixed_person_n_portraits_heavy_aug_multiscale_bce_itr_8000_train_0.384033_tar_0.046964.pth"
    assert os.path.isfile(model_dir)
    prediction_dir = f"./predictions{'_fullsize' if full_sized_testing else ''}_{os.path.splitext(os.path.basename(model_dir))[0]}/"

    os.makedirs(prediction_dir, exist_ok=True)

    img_name_list = glob.glob(image_dir + '*')

    # Keep only files with a known image extension.
    img_exts = [".jpg", ".jpeg", ".png", ".jfif"]
    img_name_list = list(
        filter(lambda p: os.path.splitext(p)[-1].lower() in img_exts,
               img_name_list))
    # print(img_name_list)

    # --------- 2. dataloader ---------
    # 1. dataloader
    test_salobj_dataset = SalObjDataset(
        img_name_list=img_name_list,
        lbl_name_list=[],
        transform=transforms.Compose(([] if full_sized_testing else [
            RescaleT(320),
        ]) + [
            ToTensorLab(flag=0),
        ]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if (model_name == 'u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1, se_type=None)
    elif (model_name == 'u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)
    elif (model_name == 'custom'):
        net = CustomNet()
    else:
        # BUGFIX: previously an unknown model_name fell through and `net`
        # was referenced while undefined (NameError at load_state_dict).
        raise ValueError(f"Unknown model_name: {model_name}")
    # BUGFIX: map_location lets GPU-saved checkpoints load on CPU-only hosts;
    # net.cuda() below still moves the weights to GPU when available.
    net.load_state_dict(torch.load(model_dir, map_location="cpu"))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        # BUGFIX: os.path.basename is portable; split("/") broke on Windows paths.
        print("inferencing:", os.path.basename(img_name_list[i_test]))

        inputs_test = data_test['image']
        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d7 = 0
        if model_name == "custom":
            d1, d2, d3, d4, d5, d6 = net(inputs_test)
        else:
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        # save results to test_results folder
        save_output(img_name_list[i_test], pred, prediction_dir)

        del d1, d2, d3, d4, d5, d6, d7
Beispiel #6
0
        "G:/EVA5/ToGit/Planercnn/content/planercnn/test/inference/"
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device: %s" % device)

    midas_model = MidasNet(r"G:\EVA5\ToGit\model-f6b98070.pt",
                           non_negative=True)
    midas_model.eval()
    midas_model.to(device)
    #print(midas_model)
    print("Model Loaded")

    # model = CustomNet("model-f46da743.pt", non_negative=True, yolo_cfg=yolo_cfg)
    model = CustomNet("G:\EVA5\ToGit\yolov3-spp-ultralytics.pt",
                      non_negative=True,
                      yolo_cfg=yolo_cfg)

    model.gr = 1.0
    model.hyp = hyp
    model.to(device)

    #print(model)

    # freeze(model, base=True)

    # Training on images of size 64

    batch_size = 256
    img_size = 64
    test_batch_size = 256