示例#1
0
def train(args):
    """Train a 3-channel-in / 1-channel-out ``Unet`` on ``../data/train``.

    Uses ``BCELoss`` with a default-configured Adam optimizer. A progress line
    ``step/total, train_loss:...`` is printed for every batch. After each
    epoch, the *sum* (not the mean) of that epoch's batch losses is printed.
    The final state dict is written to ``model_weights.pth``.

    Args:
        args: namespace providing ``batch_size`` and ``epochs``.

    Returns:
        The trained model.
    """
    dataset = MyDataset("../data/train", transform=x_transforms, target_transform=y_transforms)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=1)

    net = Unet(3, 1).to(device)
    net.train()
    loss_fn = torch.nn.BCELoss()
    opt = optim.Adam(net.parameters())

    total_epochs = args.epochs
    for epoch_idx in range(total_epochs):
        print('Epoch {}/{}'.format(epoch_idx, total_epochs - 1))
        print('-' * 10)
        n_samples = len(loader.dataset)
        # Batches per epoch (ceiling division); only used in the progress line.
        n_batches = (n_samples - 1) // loader.batch_size + 1
        running = 0
        for batch_no, (images, masks) in enumerate(loader, start=1):
            images = images.to(device)
            masks = masks.to(device)
            opt.zero_grad()
            preds = net(images)
            batch_loss = loss_fn(preds, masks)
            batch_loss.backward()
            opt.step()
            running += batch_loss.item()
            print("%d/%d, train_loss:%0.3f" % (batch_no, n_batches, batch_loss.item()))
        print("epoch %d loss:%0.3f" % (epoch_idx, running))
    torch.save(net.state_dict(), 'model_weights.pth')
    return net
示例#2
0
def train():
    """Train the SQuAD ``Unet`` reader and keep the checkpoint with the best dev F1.

    Runs 30 epochs of Adamax over the trainable parameters, clipping the
    gradient norm at 10. The mean training loss is printed every 100 steps.
    After each epoch the model is evaluated on the dev set with ``Evaluate``.
    Whenever F1 improves, the state dict is saved to
    ``save_dir + "best_model"``. Finally, the per-epoch F1 and exact-match
    histories are pickled next to it.
    """
    save_dir = "/home/FuDawei/NLP/SQUAD/unet/data/"
    train_examples, dev_examples, opt = prepare_train(save_dir)
    epoch = 30
    batch_size = 32
    model = Unet(opt=opt).to(device)
    # Materialise the trainable parameters as a list. ``filter`` would return a
    # one-shot iterator: the optimizer constructor consumes it, and
    # ``clip_grad_norm_`` below would then see no parameters and clip nothing.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adamax(parameters, lr=opt["lr"])
    best_score, exact_scores, f1_scores = 0, [], []

    count = 0
    total_loss = 0
    for ep in range(epoch):
        model.train()
        for batch_data in get_batch_data(train_examples, batch_size):
            data = model.get_data(batch_data)
            loss = model(data)
            model.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(parameters, 10)
            optimizer.step()
            # NOTE(review): reset_parameters() is called unconditionally here
            # and again below when opt["fix_word_embedding"] is falsy, which
            # makes the conditional call redundant. Kept as-is; confirm which
            # of the two is intended.
            model.reset_parameters()
            count += 1

            total_loss += loss.item()
            if count % 100 == 0:
                # Mean loss over the last 100 steps (the counter spans epochs).
                print(total_loss / 100)
                total_loss = 0
            if not opt["fix_word_embedding"]:
                model.reset_parameters()
        print(ep)
        model.eval()
        exact, f1 = Evaluate(dev_examples, model, opt)
        exact_scores.append(exact)
        f1_scores.append(f1)
        if f1 > best_score:
            best_score = f1
            torch.save(model.state_dict(), save_dir + "best_model")
    with open(save_dir + '_f1_scores.pkl', 'wb') as f:
        pkl.dump(f1_scores, f)
    with open(save_dir + '_exact_scores.pkl', 'wb') as f:
        pkl.dump(exact_scores, f)
示例#3
0
    # Fragment: the enclosing definition is not visible. So are the earlier
    # steps that set c_fix and produce fix_flair*, origin_fix* and seg_fix*.
    # Run the generator on the three fixed T2 inputs under the current
    # condition vector c_fix; outputs are detached and moved to NumPy.
    # (Presumably c_fix selects T1ce here, judging by the names — confirm.)
    fix_t1ce = generator(t2_fix.float().cuda(), c_fix).data.cpu().numpy()
    fix_t1ce_2 = generator(t2_fix_2.float().cuda(), c_fix).data.cpu().numpy()
    fix_t1ce_3 = generator(t2_fix_3.float().cuda(), c_fix).data.cpu().numpy()

    # Switch the condition vector to [0, 0, 1] and repeat for the same inputs.
    c_fix = torch.tensor([[0, 0, 1]]).float().cuda()
    fix_t1 = generator(t2_fix.float().cuda(), c_fix).data.cpu().numpy()
    fix_t1_2 = generator(t2_fix_2.float().cuda(), c_fix).data.cpu().numpy()
    fix_t1_3 = generator(t2_fix_3.float().cuda(), c_fix).data.cpu().numpy()
    # One horizontal strip per sample: input, generated panels, segmentation.
    # [0][0] takes the first item along the first two axes — presumably
    # (batch, channel, H, W) arrays; TODO confirm.
    gen_fix = np.hstack((t2_fix[0][0], fix_flair[0][0], fix_t1ce[0][0],
                         fix_t1[0][0], seg_fix[0][0]))
    gen_fix_2 = np.hstack((t2_fix_2[0][0], fix_flair_2[0][0], fix_t1ce_2[0][0],
                           fix_t1_2[0][0], seg_fix_2[0][0]))
    gen_fix_3 = np.hstack((t2_fix_3[0][0], fix_flair_3[0][0], fix_t1ce_3[0][0],
                           fix_t1_3[0][0], seg_fix_3[0][0]))

    # Stack the original/generated strips of the three samples vertically and
    # save the grid (without axes) to glips_bw.png.
    plt.axis('off')
    plt.imshow(
        np.vstack((origin_fix, gen_fix, origin_fix_2, gen_fix_2, origin_fix_3,
                   gen_fix_3)))
    plt.savefig('glips_bw.png', format='png')
    plt.close()

# `f` is opened outside this view.
f.close()

# Give each network its own path variable so no checkpoint overwrites another:
# reusing `model_save_g` for the segmentor path would save the generator there
# and then immediately overwrite it with the U-Net weights.
model_save_g = './weight/generator_t2_tumor_bw.pth'
model_save_s = './weight/segmentor_t2_tumor_bw.pth'
model_save_d = './weight/discriminator_t2_bw.pth'
torch.save(generator.state_dict(), model_save_g)
torch.save(unet.state_dict(), model_save_s)
torch.save(discriminator.state_dict(), model_save_d)
示例#4
0
文件: train.py 项目: ngocphucck/Unet
def train_net(image_size, batch_size, num_epochs, lr, num_workers, checkpoint):
    """Train a ``Unet`` with Dice loss, checkpointing whenever validation improves.

    Args:
        image_size: side length that images are rescaled to (square images).
        batch_size: batch size passed to ``data_loaders``.
        num_epochs: number of training epochs.
        lr: Adam learning rate.
        num_workers: only logged here; it is not forwarded to ``data_loaders``.
        checkpoint: path of a state dict to resume from, or a falsy value to
            start from scratch.

    Checkpoints are written to ``weights/checkpoint<epoch>.pth`` each time the
    mean validation loss reaches a new minimum.
    """
    train_loader, val_loader = data_loaders(image_size=(image_size,
                                                        image_size),
                                            batch_size=batch_size)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Unet().to(device)
    if checkpoint:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(checkpoint, map_location=device))

    criterion = DiceLoss().to(device)
    optimizer = Adam(model.parameters(), lr=lr)

    logging.info(f'Start training:\n'
                 f'Num epochs:               {num_epochs}\n'
                 f'Batch size:               {batch_size}\n'
                 f'Learning rate:            {lr}\n'
                 f'Num workers:              {num_workers}\n'
                 f'Scale image size:         {image_size}\n'
                 f'Device:                   {device}\n'
                 f'Checkpoint:               {checkpoint}\n')

    train_losses = []
    val_losses = []
    # Best validation loss over the whole run. It is initialised once, before
    # the epoch loop, so that only genuine improvements trigger a checkpoint.
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}: ')
        train_batch_losses = []
        val_batch_losses = []

        model.train()
        for x_train, y_train in tqdm(train_loader):
            x_train = x_train.to(device)
            y_train = y_train.to(device)
            y_pred = model(x_train)

            optimizer.zero_grad()
            loss = criterion(y_pred, y_train)
            train_batch_losses.append(loss.item())
            loss.backward()
            optimizer.step()

        train_losses.append(sum(train_batch_losses) / len(train_batch_losses))
        print(
            f'-----------------------Train loss: {train_losses[-1]} -------------------------------'
        )

        # Evaluate in eval mode and without building autograd graphs.
        model.eval()
        with torch.no_grad():
            for x_val, y_val in tqdm(val_loader):
                x_val = x_val.to(device)
                y_val = y_val.to(device)
                y_pred = model(x_val)

                loss = criterion(y_pred, y_val)
                val_batch_losses.append(loss.item())

        val_losses.append(sum(val_batch_losses) / len(val_batch_losses))
        print(
            f'-----------------------Val loss: {val_losses[-1]} -------------------------------'
        )
        if val_losses[-1] < best_val_loss:
            best_val_loss = val_losses[-1]
            os.makedirs('weights/', exist_ok=True)
            torch.save(model.state_dict(), f'weights/checkpoint{epoch+1}.pth')
            print(f'Save checkpoint in: weights/checkpoint{epoch+1}.pth')
示例#5
0
def train():
    """Train ``Unet`` (batch-norm variant) with SGD and log the per-epoch mean loss.

    Binary segmentation (``opt.cls_num == 1``) applies a sigmoid to the
    flattened outputs and uses ``BCELoss``. Multi-class segmentation applies
    ``log_softmax`` over dim 1 and uses ``NLLLoss``. After every epoch:
    - the mean batch loss is appended to a timestamped log file;
    - ``adjusting_rate`` is called;
    - every 10th epoch, a timestamped state dict is saved.
    """
    model = Unet(input_channel=opt.input_channel, cls_num=opt.cls_num)
    model_name = 'Unet_bn'
    train_logger = LogWay(
        datetime.datetime.now().strftime('%Y-%m-%d %H-%M-%S') + '.txt')
    train_data = My_Dataset(opt.train_images, opt.train_masks)
    train_dataloader = DataLoader(train_data,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  num_workers=0)

    # Neither loss holds parameters or tensor buffers here, so it does not
    # need to be moved to the GPU.
    if opt.cls_num == 1:
        criterion = torch.nn.BCELoss()
    else:
        criterion = torch.nn.NLLLoss()
    if use_gpu:
        model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opt.learning_rate,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay)

    for epoch in range(opt.epoch):
        loss_sum = 0
        num_batches = 0
        # Tensors are used directly: ``Variable`` wrapping has been a no-op
        # since PyTorch 0.4.
        for i, (data, target) in enumerate(train_dataloader):
            if use_gpu:
                data = data.cuda()
                target = target.cuda()
            outputs = model(data)

            if opt.cls_num == 1:
                outputs = torch.sigmoid(outputs).view(-1)
                mask_true = target.view(-1)
                loss = criterion(outputs, mask_true)
            else:
                # torch.nn.functional exposes ``log_softmax``; ``LogSoftmax``
                # is the nn.Module class and does not exist on F.
                outputs = F.log_softmax(outputs, dim=1)
                loss = criterion(outputs, target)

            loss_sum = loss_sum + loss.item()
            num_batches += 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("epoch:{} batch:{} loss:{}".format(epoch + 1, i,
                                                     loss.item()))
        # max(..., 1) keeps an empty loader from dividing by zero (and from
        # touching the never-bound loop index).
        info = 'Time:{}    Epoch:{}    Loss_avg:{}\n'.format(
            str(datetime.datetime.now()), epoch + 1,
            loss_sum / max(num_batches, 1))
        train_logger.add(info)
        adjusting_rate(optimizer, opt.learning_rate, epoch + 1)
        realepoch = epoch + 1
        if realepoch % 10 == 0:
            save_name = datetime.datetime.now().strftime(
                '%Y-%m-%d %H-%M-%S') + ' ' + model_name + str(
                    realepoch) + '.pt'
            torch.save(model.state_dict(), save_name)
示例#6
0
            # Fragment: the enclosing function, the epoch loop and the
            # validation setup (v_bar, valid_loss, net, criterion, ...) are
            # outside this view.
            for img, masks in v_bar:
                if is_gpu:
                    img, masks = img.cuda(), masks.cuda()
                masks_pr = net(img)
                loss = criterion(masks_pr, masks)
                # Accumulate a sample-weighted sum: batch mean loss times batch size.
                valid_loss += loss.item() * img.shape[0]
                v_bar.set_postfix(ordered_dict={'valid_loss': loss.item()})
        # Record & update: turn the summed losses into per-sample averages by
        # dividing by the dataset sizes. This assumes train_loss was
        # accumulated the same sample-weighted way outside this view — confirm.
        train_loss = train_loss / train_loader.dataset.__len__()
        valid_loss = valid_loss / valid_loader.dataset.__len__()
        train_loss_list.append(train_loss)
        valid_loss_list.append(valid_loss)
        # Learning rate of the first parameter group.
        lr_list.append(
            [param_group['lr'] for param_group in optimizer.param_groups][0])
        print('epoch: {}, train_loss: {:.6f}, valid_loss: {:.6f}, lr: {:.6f}'.
              format(epoch, train_loss, valid_loss, lr_list[-1]))
        # Save the weights whenever validation loss beats the best so far.
        if valid_loss < valid_loss_min:
            print('model update, saving...')
            torch.save(net.state_dict(), model_id)
            valid_loss_min = valid_loss
        # The scheduler is stepped with the validation metric (a
        # ReduceLROnPlateau-style API, presumably — confirm the scheduler type).
        scheduler.step(valid_loss)
    # Training and validation are finished.
    # Write the loss and learning-rate histories, one space-separated line
    # each, to '<model_id>.rec'.
    with open(model_id + '.rec', 'w') as fout:
        fout.write('trainloss:\n')
        fout.write(' '.join([str(x) for x in train_loss_list]) + '\n')
        fout.write('validloss:\n')
        fout.write(' '.join([str(x) for x in valid_loss_list]) + '\n')
        fout.write('lr:\n')
        fout.write(' '.join([str(x) for x in lr_list]) + '\n')
        # NOTE(review): the add_scalars calls below reference names that
        # nothing in this fragment defines (writer, loss_G_A, batch_idx,
        # train_dataloader, ...). They also sit inside the `with open(...)`
        # block after training has finished. They appear to have been spliced
        # in from a different GAN-style training script — verify before use.
        writer.add_scalars('Train/GAN_loss', {
            'A': loss_G_A.item(),
            'B': loss_G_B.item()
        },
                           epoch * len(train_dataloader) + batch_idx)
        writer.add_scalars('Train/CYCLE_loss', {
            'A': loss_cycle_A.item(),
            'B': loss_cycle_B.item()
        },
                           epoch * len(train_dataloader) + batch_idx)
        writer.add_scalars('Train/D_loss', {
            'A': loss_D_A,
            'B': loss_D_B
        },
                           epoch * len(train_dataloader) + batch_idx)
        writer.add_scalars('Train/CC', {
            'A': cc_A.item(),
            'B': cc_B.item()
        },
                           epoch * len(train_dataloader) + batch_idx)
        writer.add_scalars('Train/vertex_loss', {
            'A': loss_vertex_A.item(),
            'B': loss_vertex_B.item()
        },
                           epoch * len(train_dataloader) + batch_idx)

    # NOTE(review): netG_A/netG_B/netD_A/netD_B are also undefined in this
    # fragment; the model/ directory must already exist for these saves.
    torch.save(netG_A.state_dict(), "model/netG_A_pair.pkl")
    torch.save(netG_B.state_dict(), "model/netG_B_pair.pkl")
    torch.save(netD_A.state_dict(), "model/netD_A_pair.pkl")
    torch.save(netD_B.state_dict(), "model/netD_B_pair.pkl")
示例#8
0
def main(args):
    """Train a U-Net with BCE-with-logits loss, cosine LR and mIOU early stopping.

    Each epoch trains on ``train_dl`` and then evaluates on ``valid_dl``. When
    validation mIOU improves, a checkpoint is written to
    ``./checkpoint/{down_method}_{up_method}_{separable}.ckpt``. The
    checkpoint holds model, optimizer and scheduler state plus the loss
    histories. Training stops early once mIOU has failed to improve for more
    than ``args.patience`` consecutive epochs.

    Args:
        args: namespace with epochs, batch_size, num_workers, valid_ratio,
            threshold, separable, down_method, up_method, lr, weight_decay
            and patience.
    """
    import os

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    ### Hyperparameters Setting ###
    epochs = args.epochs
    batch_size = args.batch_size
    num_workers = args.num_workers
    valid_ratio = args.valid_ratio
    threshold = args.threshold
    separable = args.separable
    down_method = args.down_method
    up_method = args.up_method
    # Mean train/val losses are appended to the histories every
    # `log_every`-th epoch (10th, 20th, ...).
    log_every = 10
    ckpt_path = f'./checkpoint/{down_method}_{up_method}_{separable}.ckpt'

    ### DataLoader ###
    dataset = DataSetWrapper(batch_size, num_workers, valid_ratio)
    train_dl, valid_dl = dataset.get_data_loaders(train=True)

    ### Model: U-Net ###
    model = Unet(input_dim=1,
                 separable=separable,
                 down_method=down_method,
                 up_method=up_method)
    model.summary()
    model = nn.DataParallel(model).to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    # NOTE(review): scheduler.step() is called once per epoch, but T_max is
    # the number of batches per epoch. One cosine half-period therefore spans
    # len(train_dl) epochs rather than `epochs`. Kept as-is; confirm intent.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     T_max=len(train_dl),
                                                     eta_min=0,
                                                     last_epoch=-1)
    criterion = nn.BCEWithLogitsLoss()
    train_losses = []
    val_losses = []

    ### Train & Validation start ###
    mIOU_list = []
    best_mIOU = 0.
    epochs_without_improvement = 0

    for epoch in range(epochs):
        # A bare `(epoch + 1) % log_every` is truthy on every epoch EXCEPT the
        # multiples of log_every; compare with 0 for the intended cadence.
        record_losses = (epoch + 1) % log_every == 0

        ### train ###
        pbar = tqdm(train_dl)
        model.train()
        losses = []

        for (img, label) in pbar:
            optimizer.zero_grad()
            img, label = img.to(device), label.to(device)
            pred = model(img)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            # get_last_lr() reports the LR currently in use. get_lr() is meant
            # to be called by the scheduler itself; outside step() it warns
            # and need not match the active LR.
            pbar.set_description(
                f'E: {epoch + 1} | L: {loss.item():.4f} | lr: {scheduler.get_last_lr()[0]:.7f}'
            )
        scheduler.step()
        if record_losses:
            train_losses.append(sum(losses) / len(losses))

        ### validation ###
        with torch.no_grad():
            model.eval()
            mIOU = []
            losses = []
            pbar = tqdm(valid_dl)
            for (img, label) in pbar:
                img, label = img.to(device), label.to(device)
                pred = model(img)

                loss = criterion(pred, label)

                # NOTE(review): `pred` holds logits (BCEWithLogitsLoss); confirm
                # get_IOU applies a sigmoid before comparing with `threshold`.
                mIOU.append(get_IOU(pred, label, threshold=threshold))
                losses.append(loss.item())

            mIOU = sum(mIOU) / len(mIOU)
            mIOU_list.append(mIOU)
            # Report the mean validation loss rather than the last batch's.
            val_loss = sum(losses) / len(losses)
            if record_losses:
                val_losses.append(val_loss)

            print(
                f'VL: {val_loss:.4f} | mIOU: {100 * mIOU:.1f}% | best mIOU: {100 * best_mIOU:.1f}'
            )

        ### Early Stopping ###
        if mIOU > best_mIOU:
            best_mIOU = mIOU
            save_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses,
                'best_mIOU': best_mIOU
            }
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
            torch.save(save_state, ckpt_path)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement > args.patience:
                print('Early stopped...')
                return
示例#9
0
文件: main.py 项目: oosky9/2D-Unet
def train(args, x_train, y_train, x_valid, y_valid):
    """Train a ``Unet`` with BCE loss, tracking BCE and Dice on train/valid data.

    Validation runs on the first epoch and on every 10th epoch. When the
    validation Dice beats the best seen so far, the state dict is saved to
    ``<args.save_model_path>/best_model_<epoch>.pth``. Per-epoch metrics are
    written to TensorBoard through ``SummaryWriter``.

    Args:
        args: namespace with ``device``, ``lr``, ``batch_size``, ``epochs`` and
            ``save_model_path``.
        x_train, y_train, x_valid, y_valid: data handed to ``load_dataset``.
    """
    writer = SummaryWriter()

    best_dice = 0

    model = Unet().to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    bce_loss = torch.nn.BCELoss()

    train_dataloader = load_dataset(x_train, y_train, args.batch_size, True)
    valid_dataloader = load_dataset(x_valid, y_valid, args.batch_size, False)

    result = {key: [] for key in ('train/BCE', 'train/Dice', 'valid/BCE', 'valid/Dice')}

    for epoch in range(args.epochs):
        print('train step: epoch {}'.format(str(epoch+1).zfill(4)))
        # The validation pass below switches the model to eval mode; restore
        # training behaviour (BatchNorm statistics, Dropout) at every epoch.
        model.train()

        train_bce = []
        train_dice = []

        for inp_im, lab_im in tqdm(train_dataloader):
            inp_im = inp_im.to(args.device)
            lab_im = lab_im.to(args.device)

            pred = model(inp_im)

            bce = bce_loss(pred, lab_im)
            dice = calc_dice(pred, lab_im)

            train_bce.append(bce.item())
            train_dice.append(dice)

            model.zero_grad()
            bce.backward()
            optimizer.step()

        result['train/BCE'].append(statistics.mean(train_bce))
        result['train/Dice'].append(statistics.mean(train_dice))

        writer.add_scalar('train/BinaryCrossEntropy', result['train/BCE'][-1], epoch+1)
        writer.add_scalar('train/DiceScore', result['train/Dice'][-1], epoch+1)

        print('BCE: {}, Dice: {}'.format(result['train/BCE'][-1], result['train/Dice'][-1]))

        # Validate on the first epoch and on every 10th epoch.
        if (epoch+1) % 10 == 0 or (epoch+1) == 1:

            with torch.no_grad():
                print('valid step: epoch {}'.format(str(epoch+1).zfill(4)))
                model.eval()

                valid_bce = []
                valid_dice = []
                for inp_im, lab_im in tqdm(valid_dataloader):
                    inp_im = inp_im.to(args.device)
                    lab_im = lab_im.to(args.device)

                    pred = model(inp_im)

                    bce = bce_loss(pred, lab_im)
                    dice = calc_dice(pred, lab_im)

                    valid_bce.append(bce.item())
                    valid_dice.append(dice)

                result['valid/BCE'].append(statistics.mean(valid_bce))
                result['valid/Dice'].append(statistics.mean(valid_dice))

                writer.add_scalar('valid/BinaryCrossEntropy', result['valid/BCE'][-1], epoch+1)
                writer.add_scalar('valid/DiceScore', result['valid/Dice'][-1], epoch+1)

                print('BCE: {}, Dice: {}'.format(result['valid/BCE'][-1], result['valid/Dice'][-1]))

                if best_dice < result['valid/Dice'][-1]:
                    best_dice = result['valid/Dice'][-1]

                    # torch.save does not create missing directories.
                    os.makedirs(args.save_model_path, exist_ok=True)
                    best_model_name = os.path.join(args.save_model_path, f'best_model_{epoch + 1:04}.pth')
                    print('save model ==>> {}'.format(best_model_name))
                    torch.save(model.state_dict(), best_model_name)

    # Flush pending TensorBoard events and release the event file.
    writer.close()
示例#10
0
    # Fragment: this is the tail of an epoch loop whose header is not visible.
    # running_*_loss / running_*_score hold per-batch values collected earlier
    # in the loop; `i` is the epoch counter, incremented manually at the end.
    ### then compute mean loss for current epoch and save ckp ###
    epoch_train_loss, epoch_train_score = np.mean(running_train_loss), np.mean(running_train_score)
    print('Train loss : {} iou : {}'.format(epoch_train_loss, epoch_train_score))
    train_loss.append(epoch_train_loss)
    train_iou.append(epoch_train_score)

    epoch_val_loss, epoch_val_score = np.mean(running_val_loss), np.mean(running_val_score)
    print('Validation loss : {} iou : {}'.format(epoch_val_loss, epoch_val_score))
    val_loss.append(epoch_val_loss)
    val_iou.append(epoch_val_score)

    # Bundle the epoch index, model and optimizer state for saving.
    # NOTE(review): 'valid_loss_min' stores this epoch's validation loss, not
    # the running minimum held in `valid_loss_min` — confirm this is intended.
    checkpoint = {
        'epoch': i,
        'valid_loss_min': epoch_val_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    # Save this epoch's checkpoint unconditionally (second argument False;
    # presumably save_ckp's "is best" flag — confirm).
    save_ckp(checkpoint, False, checkpoint_path + f'/model_{i}.pth', best_model_path)
    # Also save as the best model when validation loss did not increase (<=).
    if epoch_val_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min, epoch_val_loss))
        # save checkpoint as best model
        save_ckp(checkpoint, True, checkpoint_path + f'/model_{i}.pth', best_model_path)
        valid_loss_min = epoch_val_loss

    i = i + 1

def unet_train():
    """Train a shared ``Unet`` encoder with one classifier head per crack dataset.

    Each batch pairs a real-crack (RC) batch with a sealed-crack (SC) batch,
    drawn through ``ConcatDataset``. Both image batches go through the shared
    ``unet`` and then through their own head (``RC_classifier`` /
    ``SC_classifier``), and a joint two-dataset loss is minimised.

    For every (loss, optimizer) pair in ``losslist`` x ``optimlist``:
    - the three networks' state dicts are saved under ``./trained_models/``
      at the end of each epoch;
    - the per-epoch loss history is saved under ``./logs/``.

    Raises:
        NotImplementedError: for the 'bce' / 'lovasz' losses, which have not
            been adapted to the two-dataset setup.
        ValueError: for an unknown loss or optimizer name.
    """
    import os

    batch_size = 2
    num_epochs = [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
    num_workers = 2
    lr = 0.0001

    losslist = ['dice']  # supported: 'dice', 'focal' ('bce'/'lovasz' not ported)
    optimlist = ['adam']  # ['adam', 'sgd']

    SC_root_dir = '../dataset-EdmSealedCrack-512'
    train_files, val_files, test_files = myutils.organize_SC_files(SC_root_dir)

    train_RC_dataset = DatasetRealCrack('../dataset-EdmCrack600-512/A/train',
                                        transform=transform)
    train_SC_dataset = DatasetSealedCrack(files=train_files,
                                          root_dir=SC_root_dir,
                                          transform=data_Train_transforms)
    # NOTE(review): the validation datasets are built but not used below.
    val_RC_dataset = DatasetRealCrack('../dataset-EdmCrack600-512/A/val',
                                      transform=transform)
    val_SC_dataset = DatasetSealedCrack(files=val_files,
                                        root_dir=SC_root_dir,
                                        transform=data_Test_transforms)

    train_loader = torch.utils.data.DataLoader(
        ConcatDataset(train_RC_dataset, train_SC_dataset),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers)

    doubleFocalloss = focalloss.FocalLoss_2_datasets(gamma=2)

    # Output directories must exist before logging or saving into them.
    os.makedirs('./logs', exist_ok=True)
    os.makedirs('./trained_models', exist_ok=True)
    # basicConfig only takes effect on its first call, so configure it once.
    logging.basicConfig(filename='./logs/logger_unet.log',
                        level=logging.INFO)

    epoidx = -1
    for los in losslist:
        for opt in optimlist:
            start = time.time()
            print(los, opt)
            torch.manual_seed(77)
            torch.cuda.manual_seed(77)
            # Alternative backbones kept for reference: Unet_SpatialPyramidPooling(3),
            # or smp.Unet('resnet34', encoder_weights='imagenet') with an empty
            # segmentation_head and classifier(16, 2) heads.
            unet = Unet(3).cuda()
            SC_classifier = classifier(64, 2).cuda()
            RC_classifier = classifier(64, 2).cuda()

            # To resume from epoch N, set prev_epoch = N and load
            # trained_models/{unet,SC_classifier,RC_classifier}_adam_dice_N.pkl
            # into the three networks; new checkpoints are numbered N+1, N+2, ...
            prev_epoch = 0

            history = []
            # The heads are trained and saved alongside the shared encoder, so
            # all three networks' parameters must be optimised.
            params = (list(unet.parameters()) +
                      list(SC_classifier.parameters()) +
                      list(RC_classifier.parameters()))
            if 'adam' in opt:
                optimizer = torch.optim.Adam(params, lr=lr)
            elif 'sgd' in opt:
                optimizer = torch.optim.SGD(params,
                                            lr=10 * lr,
                                            momentum=0.9)
            else:
                raise ValueError('unsupported optimizer: %r' % (opt,))

            epoidx += 1
            for epoch in range(num_epochs[epoidx]):
                totalloss = 0.0
                for i, (realCrack_batch,
                        sealedCrack_batch) in enumerate(train_loader):
                    SC_images = sealedCrack_batch[0].cuda()
                    SC_masks = sealedCrack_batch[1].cuda()
                    RC_images = realCrack_batch[0].cuda()
                    RC_masks = realCrack_batch[1].cuda()
                    SC_encoder = unet(SC_images)
                    RC_encoder = unet(RC_images)
                    # With a DeepLab-v3 backbone, feed SC_encoder['out'] /
                    # RC_encoder['out'] to the heads instead.
                    SC_outputs = SC_classifier(SC_encoder)
                    RC_outputs = RC_classifier(RC_encoder)

                    if 'bce' in los:
                        raise NotImplementedError(
                            "'bce' loss is not implemented for the two-dataset setup")
                    elif 'dice' in los:
                        branch_RC = {'outputs': RC_outputs, 'masks': RC_masks}
                        branch_SC = {'outputs': SC_outputs, 'masks': SC_masks}
                        loss = dice_loss_2_datasets(branch_RC, branch_SC)
                    elif 'lovasz' in los:
                        raise NotImplementedError(
                            "'lovasz' loss is not implemented for the two-dataset setup")
                    elif 'focal' in los:
                        branch_RC = {
                            'outputs': RC_outputs,
                            'masks': RC_masks.long()
                        }
                        branch_SC = {
                            'outputs': SC_outputs,
                            'masks': SC_masks.long()
                        }
                        loss = doubleFocalloss(branch_RC, branch_SC)
                    else:
                        raise ValueError('unsupported loss: %r' % (los,))

                    # Accumulate a plain float: adding the loss tensor itself
                    # would chain every iteration's autograd graph together
                    # and keep them all alive for the whole epoch.
                    # NOTE(review): weighted by the RC batch size only (the
                    # original author questioned a *2 here) — confirm.
                    totalloss += loss.item() * RC_images.size(0)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    if i % 10 == 0:
                        print(epoch, i)
                        print("total loss: ", totalloss)
                    if i % 1000 == 0:
                        print("Epoch:%d;     Iteration:%d;      Loss:%f" %
                              (epoch, i, loss.item()))

                # Save all three networks once per epoch, after the last batch.
                tag = opt + '_' + los + '_' + str(epoch + 1 + prev_epoch) + '.pkl'
                torch.save(unet.state_dict(),
                           './trained_models/unet_' + tag)
                torch.save(RC_classifier.state_dict(),
                           './trained_models/RC_classifier_' + tag)
                torch.save(SC_classifier.state_dict(),
                           './trained_models/SC_classifier_' + tag)
                # Per-epoch history of the accumulated (batch-size-weighted) loss.
                history.append(totalloss)
                np.save('./logs/unet_' + opt + '_' + los + '.npy', np.array(history))
            end = time.time()
            print((end - start) / 60)
示例#12
0
            correct += (
                predicted.cpu() == labels.cpu()).squeeze().sum().numpy()

        avg_correct = correct / len(eval_dataloader) / n_size
        val_loss, pixel_acc_avg, mean_iou_avg, fw_iou_avg = eval_net(
            model, eval_dataloader, device)
        writer.add_scalar('Loss/test', loss_sigma / len(eval_dataloader),
                          global_step)
        writer.add_scalar('fw_iou/test', fw_iou_avg, global_step)
        writer.add_scalar('acc/test', avg_correct, global_step)

        if epoch == 0:
            _fw_iou_avg = fw_iou_avg
            net_save_path = 'checkpoints/student_net' + '.pth'
            model_file = {
                'net': model.state_dict(),
                'correct': correct / len(eval_dataloader),
                'epoch': epoch + 1
            }
            torch.save(model_file, net_save_path)
            print(
                '-------------------------{} set correct:{:.4%}---------------------'
                .format('Valid', avg_correct))
            print(
                '-------------------------{} set fw_iou:{:.4%}---------------------'
                .format('Valid', fw_iou_avg))
        elif fw_iou_avg > _fw_iou_avg:
            _fw_iou_avg = fw_iou_avg
            net_save_path = 'checkpoints/student_net' + '.pth'
            model_file = {
                'net': model.state_dict(),