Ejemplo n.º 1
0
def train(net, epochs, batch_size, lr, mra_transforms, label_transforms):
    """Train `net` on the sliced MRA volumes with a soft-Dice objective.

    Parameters
    ----------
    net : segmentation network; moved tensors are sent to CUDA, so the net
        is expected to live on the GPU.
    epochs : number of full passes over the dataset.
    batch_size : samples per optimizer step.
    lr : Adam learning rate.
    mra_transforms, label_transforms : preprocessing passed to NiiDataset.
    """
    import os

    dir_imgs = "./data/after_slice/copy/data/"
    dir_labels = "./data/after_slice/copy/seg/"
    dir_model = "./model"

    # Make sure the checkpoint directory exists before training starts.
    utility.sureDir(dir_model)

    # Load data.
    dataset = NiiDataset(mra_dir=dir_imgs,
                         label_dir=dir_labels,
                         mra_transforms=mra_transforms,
                         label_transforms=label_transforms)

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=4)

    # Loss and optimizer.
    criterion = SoftDiceLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    # Begin training.
    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        print('-' * 10)

        net.train()
        dt_size = len(dataloader.dataset)
        epoch_loss = 0
        step = 0

        for img, label in dataloader:
            step += 1
            # The dataset transforms already produced tensors; only cast/move.
            inputs = img.type(torch.FloatTensor).cuda()
            label = label.type(torch.FloatTensor).cuda().squeeze()

            # Zero the parameter gradients.
            optimizer.zero_grad()

            output = net(inputs)

            # Keep only the foreground-channel scores, e.g. (75, 64, 64).
            out = output[:, 1, :, :, :].squeeze()
            print("dice: %0.3f " % dice_coeff(out, label))

            loss = criterion(out, label)
            loss.backward()
            optimizer.step()
            epoch_loss += float(loss.item())
            print("%d/%d,train_loss:%0.3f" %
                  (step, dt_size // dataloader.batch_size, loss.item()))
        # Report the 1-based epoch index, consistent with the header above.
        print("epoch %d loss:%0.3f" % (epoch + 1, epoch_loss / step))

        # BUG FIX: the original passed the checkpoint *directory* to
        # torch.save, which fails once sureDir has created it; save to a
        # file inside the directory instead.
        torch.save(net.state_dict(), os.path.join(dir_model, 'net.pth'))
Ejemplo n.º 2
0
# Sanity-check the held-out label tensor shape before wrapping it.
print(y.shape)
test_dataset = Data.TensorDataset(x, y)
# One volume per batch at test time; shuffle only affects reporting order.
testloder = Data.DataLoader(dataset=test_dataset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=1)
# ******************************************************************************************
model = VNet()
if args.cuda:
    model.cuda()
# NOTE(review): if args.optimizer is neither 'SGD' nor 'ADAM', `optimizer`
# is never bound and later use raises NameError — confirm the argparse
# choices are restricted upstream.
if args.optimizer == 'SGD':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.5)
if args.optimizer == 'ADAM':
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Defining Loss Function
criterion = SoftDiceLoss()
testcriterion = SoftDiceLoss()


def confusionmetric(X, Y):
    """Return (TN, FP, FN, TP) for a pair of binary segmentation masks.

    X is treated as the reference mask and Y as the prediction.  The two
    label values are taken from the data itself (pickle masks use a max
    value of 1, mhd masks use 255), so both encodings are supported.
    """
    x1 = np.asarray(X).reshape(-1)
    y1 = np.asarray(Y).reshape(-1)
    # Mirror sklearn.confusion_matrix label ordering: the sorted union of
    # values observed in either mask (at most two for binary masks).
    labels = np.unique(np.concatenate((x1, y1)))
    if labels.size == 1:
        # BUG FIX: with a single observed class the original flattened a
        # 1x1 confusion matrix and crashed with IndexError on com[1];
        # report everything as TN (all background) or TP (all foreground).
        n = int(x1.size)
        return (n, 0, 0, 0) if labels[0] == 0 else (0, 0, 0, n)
    neg, pos = labels[0], labels[-1]
    TN = int(np.sum((x1 == neg) & (y1 == neg)))
    FP = int(np.sum((x1 == neg) & (y1 == pos)))
    FN = int(np.sum((x1 == pos) & (y1 == neg)))
    TP = int(np.sum((x1 == pos) & (y1 == pos)))
    return TN, FP, FN, TP
def train_net(model,
              train_data_path,
              train_label_path,
              val_data_path,
              val_label_path,
              n_epochs,
              batch_size,
              # weight,
              checkpoint_dir='weights',
              lr=1e-4):
    """Train `model` on 3-class volumes read by NrrdReader3D.

    Each epoch runs one training pass (train_epoch) and one validation
    pass (val_epoch).  Weights are checkpointed every epoch under
    `checkpoint_dir` as '<epoch>.pkl'; the loss curves are drawn with
    draw_loss and the raw values written to loss/loss.txt.
    """
    # Model on cuda.
    if torch.cuda.is_available():
        model = model.cuda()

    train_dataset = NrrdReader3D(train_data_path, train_label_path)
    val_dataset = NrrdReader3D(val_data_path, val_label_path)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
    '''.format(n_epochs, batch_size, lr, len(train_dataset),
               len(val_dataset)))

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0005)
    criterion = SoftDiceLoss(n_classes=3)

    # Make sure the output directories exist before the first save/write.
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs('loss', exist_ok=True)

    losses = []
    val_losses = []
    for epoch in range(n_epochs):
        losses_avg = train_epoch(model,
                                 train_dataloader,
                                 optimizer,
                                 criterion,
                                 epoch,
                                 n_epochs,
                                 print_freq=100)

        val_losses_avg = val_epoch(model,
                                   val_dataloader,
                                   criterion,
                                   print_freq=10)

        # The epoch averages come back as tensors; move them to the host
        # and round for compact logging.
        losses.append(round(losses_avg.cpu().numpy().tolist(), 4))
        val_losses.append(round(val_losses_avg.cpu().numpy().tolist(), 4))

        # Save model parameters for this epoch.
        parameters_name = str(epoch) + '.pkl'
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, parameters_name))

    # Save loss figure.
    draw_loss(n_epochs, losses, val_losses)

    # Save loss data.  BUG FIX: the original concatenated literal braces
    # ('{0}: {0.1234}') instead of formatting; substitute the values.
    with open('loss/loss.txt', 'w') as loss_file:
        loss_file.write('train loss:\n')
        for i, loss in enumerate(losses):
            loss_file.write('{}: {}\n'.format(i, loss))
        loss_file.write('-' * 50)
        loss_file.write('\n')
        loss_file.write('validation loss:\n')
        for i, val_loss in enumerate(val_losses):
            loss_file.write('{}: {}\n'.format(i, val_loss))
Ejemplo n.º 4
0
def train_model(model, optimizer, dataloaders, num_epochs=25):
    """Train and validate `model`, log losses and per-class IoU to visdom,
    and return the model loaded with the best validation weights.

    Loss schedule: plain 2-D cross-entropy for the first 100 epochs, then a
    0.25*CE + 0.75*Lovasz-softmax blend.  A checkpoint is written whenever
    the validation mean IoU beats both the previous best and a 0.65 floor.

    Relies on module-level globals defined elsewhere in the file: use_gpu,
    device, image_datasets, viz, save_path, iou, L (Lovasz), tqdm,
    CrossEntropyLoss2d.  Unused criterion instances and commented-out
    experiments from the original were removed.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_iou = 0.0
    cel = CrossEntropyLoss2d()
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                               num_epochs,
                                               eta_min=0.00001,
                                               last_epoch=-1)
    # Per-phase histories: scalar loss and per-class (4) IoU for each epoch.
    loss_dict = {
        x: np.array([np.nan] * num_epochs)
        for x in ['train', "validation"]
    }
    iou_dict = {
        x: np.array([np.nan] * num_epochs * 4).reshape(num_epochs, 4)
        for x in ['train', "validation"]
    }
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Run a training pass followed by a validation pass each epoch.
        for phase in ['train', "validation"]:
            if phase == 'train':
                # NOTE(review): stepping the scheduler at the top of the
                # epoch is the pre-1.1 PyTorch convention; kept so the
                # effective LR schedule is unchanged.
                scheduler.step()
                model.train(True)  # training mode
            else:
                model.train(False)  # evaluate mode

            running_loss = 0.0
            running_iou = np.zeros(4)
            for step, data in enumerate(tqdm(dataloaders[phase])):
                inputs, labels = data
                if use_gpu:
                    inputs = torch.Tensor(inputs).to(device)
                    labels = torch.Tensor(labels).to(device)
                else:
                    inputs = torch.Tensor(inputs)
                    labels = torch.Tensor(labels)

                # labels appear to be one-hot along dim 1 (argmax(1) below).
                batch_size, n_input_channel, img_height, img_width = tuple(
                    inputs.shape)
                optimizer.zero_grad()
                if phase == 'train':
                    outputs = model(inputs)
                else:
                    with torch.no_grad():
                        outputs = model(inputs)
                # Warm up on cross-entropy alone, then blend in
                # Lovasz-softmax after 100 epochs.
                if epoch < 100:
                    loss = cel(outputs, labels.argmax(1))
                else:
                    loss = 0.25 * cel(outputs, labels.argmax(
                        1)) + 0.75 * L.lovasz_softmax(F.softmax(outputs),
                                                      labels.argmax(1),
                                                      per_image=True)
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                running_loss += loss.item()
                running_iou += iou(outputs, labels, average=False)
                torch.cuda.empty_cache()
            # Average over the number of batches in this phase.
            epoch_loss = running_loss / (step + 1)
            epoch_iou = running_iou / (step + 1)

            print('{} Loss: {:.4f} IOU: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_iou.mean()))

            loss_dict[phase][epoch] = epoch_loss
            iou_dict[phase][epoch] = epoch_iou
            # Visualise the last validation batch and the learning curves
            # in visdom.
            if phase == "validation":
                output_img = np.zeros((3, img_height, img_width))
                label_img = np.zeros((3, img_height, img_width))
                output_argmax = outputs[0].argmax(0)  # (height, width)
                for idx, cla in enumerate(
                        image_datasets["train"].category_list):
                    # Only the first 4 categories are rendered.
                    if idx == 4:
                        break
                    # Paint predicted and ground-truth masks in the class
                    # color given by category_list.
                    output_img += ((output_argmax
                                    == idx).float().cpu().data.numpy().reshape(
                                        (1, img_height, img_width)) *
                                   np.array(image_datasets["train"].
                                            category_list[cla]).reshape(
                                                (3, 1, 1)))
                    label_img += (labels[0, idx].cpu().data.numpy().reshape(
                        (1, img_height, img_width)) *
                                  np.array(image_datasets["train"].
                                           category_list[cla]).reshape(
                                               (3, 1, 1)))

                win_output = viz.image(output_img / 255,
                                       win="output",
                                       opts=dict(title='output'))
                win_label = viz.image(label_img / 255,
                                      win="label",
                                      opts=dict(title='label'))
                win_input = viz.image(inputs[0].cpu().data.numpy(),
                                      win="input",
                                      opts=dict(title='input'))

                # After epoch 0 the line plots exist and are replaced;
                # on epoch 0 they must be created with opts first.
                if epoch > 0:
                    viz.line(X=np.arange(epoch + 1),
                             Y=loss_dict["train"][:epoch + 1],
                             update="replace",
                             win="loss",
                             name="train")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["train"][:epoch + 1].mean(1),
                             update="replace",
                             win="iou",
                             name="train")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["train"][:epoch + 1, 0],
                             update="replace",
                             win="car",
                             name="train")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["train"][:epoch + 1, 1],
                             update="replace",
                             win="signal",
                             name="train")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["train"][:epoch + 1, 2],
                             update="replace",
                             win="pedestrian",
                             name="train")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["train"][:epoch + 1, 3],
                             update="replace",
                             win="lane",
                             name="train")
                    viz.line(X=np.arange(epoch + 1),
                             Y=loss_dict["validation"][:epoch + 1],
                             update="replace",
                             win="loss",
                             name="validation")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["validation"][:epoch + 1].mean(1),
                             update="replace",
                             win="iou",
                             name="validation")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["validation"][:epoch + 1, 0],
                             update="replace",
                             win="car",
                             name="validation")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["validation"][:epoch + 1, 1],
                             update="replace",
                             win="signal",
                             name="validation")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["validation"][:epoch + 1, 2],
                             update="replace",
                             win="pedestrian",
                             name="validation")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["validation"][:epoch + 1, 3],
                             update="replace",
                             win="lane",
                             name="validation")

                else:
                    win_loss = viz.line(X=np.arange(epoch + 1),
                                        Y=loss_dict["train"][:epoch + 1],
                                        win="loss",
                                        name="train",
                                        opts=dict(title='loss'))
                    win_iou = viz.line(X=np.arange(epoch + 1),
                                       Y=iou_dict["train"][:epoch + 1].mean(1),
                                       win="iou",
                                       name="train",
                                       opts=dict(title='iou'))
                    win_car = viz.line(X=np.arange(epoch + 1),
                                       Y=iou_dict["train"][:epoch + 1, 0],
                                       win="car",
                                       name="train",
                                       opts=dict(title='car'))
                    win_sig = viz.line(X=np.arange(epoch + 1),
                                       Y=iou_dict["train"][:epoch + 1, 1],
                                       win="signal",
                                       name="train",
                                       opts=dict(title='sig'))
                    win_ped = viz.line(X=np.arange(epoch + 1),
                                       Y=iou_dict["train"][:epoch + 1, 2],
                                       win="pedestrian",
                                       name="train",
                                       opts=dict(title='ped'))
                    win_lan = viz.line(X=np.arange(epoch + 1),
                                       Y=iou_dict["train"][:epoch + 1, 3],
                                       win="lane",
                                       name="train",
                                       opts=dict(title='lan'))
                    viz.line(X=np.arange(epoch + 1),
                             Y=loss_dict["validation"][:epoch + 1],
                             win="loss",
                             name="validation")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["validation"][:epoch + 1].mean(1),
                             win="iou",
                             name="validation")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["validation"][:epoch + 1, 0],
                             win="car",
                             name="validation")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["validation"][:epoch + 1, 1],
                             win="signal",
                             name="validation")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["validation"][:epoch + 1, 2],
                             win="pedestrian",
                             name="validation")
                    viz.line(X=np.arange(epoch + 1),
                             Y=iou_dict["validation"][:epoch + 1, 3],
                             win="lane",
                             name="validation")

            # Save a checkpoint whenever validation IoU sets a new best and
            # clears the 0.65 quality floor.
            # BUG FIX: the original overwrote best_iou/best_model_wts on
            # *every* epoch above 0.65, so the weights restored at the end
            # were from the last such epoch, not the best one.
            if phase == "validation" and epoch_iou.mean() > max(best_iou, 0.65):
                best_iou = epoch_iou.mean()
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(
                    model.state_dict(),
                    os.path.join(
                        save_path,
                        "{}_{:.4f}_{:.4f}.pth".format(epoch, epoch_loss,
                                                      epoch_iou.mean())))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val iou: {:.4f}'.format(best_iou))

    # Load best model weights before returning.
    model.load_state_dict(best_model_wts)
    return model
Ejemplo n.º 5
0
    def train(self):
        """Train the generator self.G for semantic segmentation, with
        periodic sample rendering, validation, checkpointing and LR decay.

        Loss selection via self.loss: 'nll' -> class-weighted NLLLoss2d on
        log-softmax, 'focal' -> SoftmaxFocalLoss, 'dice' -> SoftDiceLoss on
        one-hot targets.  NOTE(review): any other value leaves
        seg_criterion undefined and the first loss computation raises
        NameError — confirm the CLI restricts the choices.
        """
        # NOTE(review): this MultiStepLR is created but never stepped in
        # this method; only the manual linear decay at the bottom runs.
        scheduler = MultiStepLR(
            self.optimizer_g,
            milestones=[int(x) for x in '10,20,30,200'.split(',')])
        if self.loss == 'nll':
            # Per-class weights, normalised to sum to 1.
            class_weight = np.asarray(
                [
                    0.1,  # background
                    1,  # solid
                    1.5,  # broken
                    1.5,
                    1
                ],
                np.float32)  # fishbone
            class_weight = class_weight / np.sum(class_weight)
            class_weight = torch.from_numpy(class_weight)
            # NOTE(review): nn.NLLLoss2d is deprecated in modern PyTorch;
            # nn.NLLLoss is the drop-in replacement.
            seg_criterion = nn.NLLLoss2d(class_weight).cuda()
        elif self.loss == 'focal':
            seg_criterion = SoftmaxFocalLoss(gamma=2, OHEM_percent=0.05).cuda()
        elif self.loss == 'dice':
            # Same per-class weights as the 'nll' branch; not given to the
            # constructor but passed as a third argument at call time below.
            class_weight = np.asarray(
                [
                    0.1,  # background
                    1,  # solid
                    1.5,  # broken
                    1.5,
                    1
                ],
                np.float32)  # fishbone
            class_weight = class_weight / np.sum(class_weight)
            class_weight = torch.from_numpy(class_weight)
            seg_criterion = SoftDiceLoss().cuda()
        """Train StarGAN within a single dataset."""

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Collect the first two batches as fixed debug samples.
        fixed_A = []
        fixed_B = []

        for i, (A, B) in enumerate(self.data_loader):
            fixed_A.append(A)
            fixed_B.append(B)
            if i == 1:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_A = torch.cat(fixed_A, dim=0)
        # NOTE(review): volatile= is removed in modern PyTorch (replaced by
        # torch.no_grad()) — confirm the installed version accepts it.
        fixed_A = self.to_var(fixed_A, volatile=True)

        fixed_B = torch.cat(fixed_B, dim=0)
        fixed_B = self.to_var(fixed_B, volatile=True)
        fixed_B = fixed_B.unsqueeze(1)
        # Replicate the single-channel mask to 3 channels for image saving.
        fixed_B = torch.cat([fixed_B, fixed_B, fixed_B], 1).float()

        # Scale class indices into [-1, 1] so they render like normalised
        # images next to fixed_A.
        fake_image_list = [
            fixed_A, 2 * (fixed_B.float() / (self.class_num - 1) - 0.5)
        ]
        fake_images = torch.cat(fake_image_list, dim=3)
        save_image(self.denorm(fake_images.data.cpu()),
                   os.path.join(self.sample_path, 'target.png'),
                   nrow=1,
                   padding=0)
        print('Translated images and saved into {}..!'.format(
            self.sample_path))

        # lr cache for decaying
        g_lr = self.g_lr
        # Start with trained model if exists
        if self.pretrained_model:
            # Checkpoint names start with the epoch number (see save below).
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()

        for e in range(start, self.num_epochs):
            for i, (real_A, real_B) in enumerate(self.data_loader):
                real_A = self.to_var(real_A)
                real_B = self.to_var(real_B)

                # Skip the ragged last batch.
                if real_A.shape[0] != self.batch_size:
                    break

                # Periodically render the fixed samples through the
                # current generator for visual inspection.
                if (i + 1) % self.sample_step == 0:
                    with torch.no_grad():
                        self.G.eval()
                        fake_image_list = [fixed_A]
                        fake_B = self.G(fixed_A)

                        # Hard class prediction, replicated to 3 channels.
                        _, out_mask = torch.max(fake_B, dim=1)
                        out_mask = out_mask.unsqueeze(1)
                        out_mask = torch.cat([out_mask, out_mask, out_mask],
                                             1).float()
                        fake_image_list.append(
                            2 * (out_mask / (self.class_num - 1) - 0.5))
                        fake_image_list.append(
                            2 * (fixed_B / (self.class_num - 1) - 0.5))

                        fake_images = torch.cat(fake_image_list, dim=3)
                        save_image(self.denorm(fake_images.data.cpu()),
                                   os.path.join(
                                       self.sample_path,
                                       '{}_{}_fake.png'.format(e + 1, i + 1)),
                                   nrow=1,
                                   padding=0)
                        print('Translated images and saved into {}..!'.format(
                            self.sample_path))

                        self.G.train()

                # for a(day)
                out_mask = self.G(real_A)
                if self.loss == 'nll' or self.loss == 'focal':
                    total_loss = seg_criterion(F.log_softmax(out_mask, dim=1),
                                               real_B)
                elif self.loss == 'dice':
                    # NOTE(review): log-probabilities (not probabilities)
                    # are fed to SoftDiceLoss here — confirm that is what
                    # the loss implementation expects.
                    onehot_B = F.one_hot(real_B,
                                         self.class_num).permute(0, 3, 1, 2)
                    total_loss = seg_criterion(F.log_softmax(out_mask, dim=1),
                                               onehot_B, class_weight)

                self.reset_grad()
                total_loss.backward()
                self.optimizer_g.step()

                # Logging
                loss = {}
                loss['total_loss'] = total_loss.item()

                # Print out log info
                if (i + 1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)

                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * iters_per_epoch + i + 1)

                # Save model checkpoints
                if (i + 1) % self.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))

            # Validation pass every val_log_step epochs (no gradients are
            # taken; only the loss is logged).
            if (e + 1) % self.val_log_step == 0:
                val_iters_per_epoch = len(self.data_loader2)

                for i, (real_A, real_B) in enumerate(self.data_loader2):
                    real_A = self.to_var(real_A)
                    real_B = self.to_var(real_B)

                    if real_A.shape[0] != self.batch_size:
                        break

                    # for a(day)
                    tgt_out = self.G(real_A)
                    if self.loss == 'nll' or self.loss == 'focal':
                        val_loss = seg_criterion(F.log_softmax(tgt_out, dim=1),
                                                 real_B)
                    elif self.loss == 'dice':
                        # NOTE(review): unlike the training branch, num_classes
                        # is omitted here, so the one-hot width follows the
                        # batch's max label and can mismatch class_num —
                        # confirm intentional.
                        onehot_B = F.one_hot(real_B).permute(0, 3, 1, 2)
                        val_loss = seg_criterion(F.log_softmax(tgt_out, dim=1),
                                                 onehot_B, class_weight)

                    # Logging
                    loss = {}
                    loss['val_loss'] = val_loss.item()

                    # Print out log info
                    log = "Epoch [{}/{}], val_Iter [{}/{}]".format(
                        e + 1, self.num_epochs, i + 1, val_iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)

                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * val_iters_per_epoch + i + 1)

            # Decay learning rate linearly over the final num_epochs_decay
            # epochs.
            if (e + 1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr)
                print('Decay learning rate to g_lr: {},.'.format(g_lr))