Code example #1
def train(args, model, optimizer, dataloader_train, dataloader_val):
    writer = SummaryWriter(
        comment='_{}_{}'.format(args.optimizer, args.context_path))
    if args.loss == 'dice':
        loss_func = DiceLoss()
    elif args.loss == 'crossentropy':
        loss_func = torch.nn.CrossEntropyLoss(ignore_index=255)
    max_miou = 0
    step = 0
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(args.num_epochs):
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        model.train()
        tq = tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda().long()
            # mixed-precision forward pass
            with torch.cuda.amp.autocast():
                output, output_sup1, output_sup2 = model(data)
                loss1 = loss_func(output, label)       # principal loss
                loss2 = loss_func(output_sup1, label)  # auxiliary loss 1
                loss3 = loss_func(output_sup2, label)  # auxiliary loss 2
                loss = loss1 + loss2 + loss3
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss.item())
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()  # refresh the loss scale right after stepping
            step += 1
            writer.add_scalar('loss_step', loss.item(), step)
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('epoch/loss_epoch_train', float(loss_train_mean),
                          epoch)
        print('loss for train : %f' % (loss_train_mean))
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(model.state_dict(),
                       os.path.join(args.save_model_path, 'model.pth'))

        if epoch % args.validation_step == 0 and epoch != 0:
            precision, miou = val(args, model, dataloader_val)
            if miou > max_miou:
                max_miou = miou
                os.makedirs(args.save_model_path, exist_ok=True)
                torch.save(
                    model.state_dict(),
                    os.path.join(args.save_model_path, 'best_dice_loss.pth'))
            writer.add_scalar('epoch/precision_val', precision, epoch)
            writer.add_scalar('epoch/miou_val', miou, epoch)
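
Every example on this page calls poly_lr_scheduler without defining it. A minimal sketch that is consistent with the call sites above (the power=0.9 exponent is an assumption; each project's actual helper may differ):

def poly_lr_scheduler(optimizer, init_lr, iter, max_iter, power=0.9):
    """Polynomial decay: lr = init_lr * (1 - iter / max_iter) ** power."""
    # the parameter name `iter` mirrors the call sites, though it shadows the builtin
    lr = init_lr * (1 - iter / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr  # update every parameter group in place
    return lr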
Code example #2
def train(args, model, optimizer, dataloader_train):
    criterion = torch.nn.CrossEntropyLoss()  # build the loss once, not per batch
    step = 0
    for epoch in range(args.num_epochs):
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        model.train()
        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda().long()  # CrossEntropyLoss needs Long class indices
            output = model(data)
            # label[:, 0] drops the channel axis so the targets are (N, H, W)
            loss = criterion(output, label[:, 0])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            loss_record.append(loss.item())
            if i % 50 == 0:
                average = sum(loss_record) / len(loss_record)
                print('epoch:%d' % epoch, 'step:%d' % i, 'loss:%f' % average)
        loss_train_mean = np.mean(loss_record)
        print('loss for train : %f' % (loss_train_mean))
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(
                model.module.state_dict(),
                os.path.join(args.save_model_path,
                             'epoch_{}.pth'.format(epoch)))
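
These train functions all receive a plain args object rather than individual hyper-parameters. A hedged sketch of the attributes example #2 reads, built with SimpleNamespace for illustration (the real projects use argparse, and the values here are placeholders only):

from types import SimpleNamespace

args = SimpleNamespace(
    num_epochs=100,          # epochs to train
    learning_rate=0.01,      # initial lr fed to poly_lr_scheduler
    use_gpu=True,            # move tensors to CUDA when available
    checkpoint_step=5,       # save a checkpoint every N epochs
    save_model_path='./checkpoints',
)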
Code example #3
def train(args, model, optimizer, dataloader_train, dataloader_val, csv_path):
    writer = SummaryWriter()
    bce = torch.nn.BCEWithLogitsLoss()  # build the criterion once
    step = 0
    for epoch in range(args.num_epochs):
        lr = poly_lr_scheduler(optimizer, args.learning_rate, iter=epoch, max_iter=args.num_epochs)
        model.train()
        tq = tqdm.tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda()
            output, output_sup1, output_sup2 = model(data)
            # principal + two auxiliary losses; BCEWithLogitsLoss expects float
            # targets with the same shape as the logits
            loss1 = bce(output, label)
            loss2 = bce(output_sup1, label)
            loss3 = bce(output_sup2, label)
            loss = loss1 + loss2 + loss3
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            writer.add_scalar('loss_step', loss, step)
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('loss_epoch_train', float(loss_train_mean), epoch)
        print('loss for train : %f' % (loss_train_mean))
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(model.module.state_dict(), os.path.join(args.save_model_path, 'epoch_{}.pth'.format(epoch)))
        if epoch % args.validation_step == 0:
            dice = val(args, model, dataloader_val, csv_path)
            writer.add_scalar('precision_val', dice, epoch)
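
Examples #3, #4 and #6 train with BCEWithLogitsLoss while #1 and #2 use CrossEntropyLoss, and the two criteria expect different targets. A small self-contained illustration of the contract (shapes are arbitrary):

import torch

logits = torch.randn(2, 3, 4, 4)   # (N, C, H, W) raw scores
# BCEWithLogitsLoss wants float targets with the same shape as the logits
bce_targets = torch.rand(2, 3, 4, 4)
bce_loss = torch.nn.BCEWithLogitsLoss()(logits, bce_targets)
# CrossEntropyLoss wants Long class indices without the channel axis
ce_targets = torch.randint(0, 3, (2, 4, 4))
ce_loss = torch.nn.CrossEntropyLoss()(logits, ce_targets)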
Code example #4
def train(args, model, optimizer, dataloader_train, csv_path):
    writer = SummaryWriter()
    bce = torch.nn.BCEWithLogitsLoss()  # build the criterion once
    step = 0
    for epoch in range(args.epoch_start_i, args.num_epochs):
        lr = poly_lr_scheduler(optimizer, args.learning_rate, iter=epoch, max_iter=args.num_epochs)
        model.train()
        tq = tqdm.tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda()

            output = model(data)
            loss = bce(output, label)
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            writer.add_scalar('loss_step', loss, step)
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('loss_epoch_train', float(loss_train_mean), epoch)
        print('loss for train : %f' % loss_train_mean)
        if epoch % args.checkpoint_step == 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(model.module.state_dict(), os.path.join(args.save_model_path, 'epoch_{}.pth'.format(epoch)))
        if epoch % args.validation_step == 0:
            dice = val(args, model, csv_path)
            writer.add_scalar('precision_val', dice, epoch)
Code example #5
def train(args, model, optimizer, dataloader_train):
    criterion = SegmentationMultiLosses(nclass=2).cuda()  # build the loss once
    step = 0
    for epoch in range(1, args.num_epochs):
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        model.train()

        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):

            data = data.cuda()
            label = label.cuda()
            output = model(data)

            loss = criterion(output, label[:, 0])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1

            loss_record.append(loss.item())
            if i % 50 == 0:
                loss_average = sum(loss_record) / len(loss_record)
                print('epoch:%d' % epoch, 'step:%d' % i,
                      'loss:%f' % loss_average)

        loss_train_mean = np.mean(loss_record)

        print('loss for train : %f' % (loss_train_mean))
        if epoch % args.checkpoint_step == 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(
                model.state_dict(),
                os.path.join(args.save_model_path,
                             'epoch_{}.pth'.format(epoch)))
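
SegmentationMultiLosses in example #5 comes from an external package and is not defined here. A rough stand-in, under the assumption that it sums cross-entropy over every head of a multi-output model (the real class may weight or select heads differently):

import torch.nn as nn

class SegmentationMultiLossesSketch(nn.Module):
    """Hypothetical stand-in for SegmentationMultiLosses."""
    def __init__(self, nclass=2):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()

    def forward(self, outputs, target):
        # accept a single tensor or a tuple of per-head logits
        if not isinstance(outputs, (tuple, list)):
            outputs = (outputs,)
        return sum(self.ce(out, target) for out in outputs)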
Code example #6
def train(args, model, optimizer, train_img_path, train_label_path, val_img_path, val_label_path, csv_path):
    writer = SummaryWriter()
    step = 0
    # build the dataset and loader once; DataLoader reshuffles on its own every epoch
    dataset_train = ADE(train_img_path, train_label_path, scale=(args.crop_height, args.crop_width), mode='train')
    dataloader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers
    )
    bce = torch.nn.BCEWithLogitsLoss()  # build the criterion once
    for epoch in range(args.epoch_start_i, args.num_epochs):
        lr = poly_lr_scheduler(optimizer, args.learning_rate, iter=epoch, max_iter=args.num_epochs)
        model.train()
        tq = tqdm.tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda()
            output = model(data)  # e.g. (5, 3, 480, 640)
            loss = bce(output, label)
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            writer.add_scalar('loss_step', loss, step)
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('loss_epoch_train', float(loss_train_mean), epoch)
        print('loss for train : %f' % loss_train_mean)
        if epoch % args.checkpoint_step == 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(model.module.state_dict(), os.path.join(args.save_model_path, 'epoch_{}.pth'.format(epoch)))
        if epoch % args.validation_step == 0:
            dice = val(args, model, val_img_path, val_label_path, csv_path)
            writer.add_scalar('precision_val', dice, epoch)
Code example #7
def train(args, model, optimizer, criterion, dataloader_train, dataloader_val):
    writer = SummaryWriter(args.log_path)
    print(args)

    model.train()
    max_iter = len(dataloader_train)
    max_segmiou = -1
    for i, (data, seg, pos) in enumerate(dataloader_train):
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=i,
                               max_iter=max_iter)
        data = data.cuda()  # the inputs must live on the same device as the model
        seg = seg.cuda()
        pos = pos.cuda()
        seg_pred, pos_pred = model(data)
        loss1 = criterion(pos_pred, pos)
        loss2 = criterion(seg_pred, seg)
        loss = loss1 + loss2
        loss = torch.mean(loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        writer.add_scalar('loss_step', loss.item(), i)
        print('iter:{}/{} lr:{:.5f} loss:{:.5f}'.format(
            i, max_iter, lr, loss.item()))

        if (i + 1) % args.validation_step == 0:  # equivalent, without the precedence trap
            segmiou, posmiou = val(args, model, dataloader_val)
            writer.add_scalar('segmiou', segmiou, i)
            writer.add_scalar('posmiou', posmiou, i)
            if segmiou > max_segmiou:
                max_segmiou = segmiou
                torch.save(model.module.state_dict(),
                           os.path.join(args.save_model_path, 'best.pth'))
            model.train()
    torch.save(model.module.state_dict(),
               os.path.join(args.save_model_path, 'last.pth'))
Code example #8
def train(args, model, optimizer, dataloader_train, dataloader_val):
    # the SummaryWriter logs everything we want to inspect in TensorBoard
    writer = SummaryWriter(
        comment='_{}_{}'.format(args.optimizer, args.context_path))
    # choose the loss function
    if args.loss == 'dice':
        # class defined by the authors in loss.py
        loss_func = DiceLoss()
    elif args.loss == 'crossentropy':
        loss_func = torch.nn.CrossEntropyLoss(ignore_index=255)
    # initialize the counters
    max_miou = 0
    step = 0
    # start the training loop
    for epoch in range(args.num_epochs):
        # update the learning rate with polynomial decay
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        # switch the model to training mode
        model.train()
        # progress bar
        tq = tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        # per-batch losses for this epoch
        loss_record = []

        # iterate over the mini-batches
        for i, (data, label) in enumerate(dataloader_train):

            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda().long()

            # The model returns:
            # - the final output after the FFM
            # - the 16x-downsampled context-path output, after its ARM
            # - the 32x-downsampled context-path output, after its ARM
            output, output_sup1, output_sup2 = model(data)

            # Compute the loss.
            # Principal loss function (l_p in the paper):
            loss1 = loss_func(output, label)
            # Auxiliary loss functions (l_i, for i = 2, 3 in the paper):
            loss2 = loss_func(output_sup1, label)
            loss3 = loss_func(output_sup2, label)

            # alpha = 1; compute equation 2 of the paper:
            loss = loss1 + loss2 + loss3

            # progress-bar bookkeeping
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            # zero_grad clears the old gradients from the last step (otherwise the
            # gradients of every loss.backward() call would accumulate);
            # loss.backward() computes the gradients of the loss w.r.t. the
            # parameters via backpropagation; optimizer.step() updates the
            # parameters using those gradients.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # increment the global step counter
            step += 1
            # log the values for the plots
            writer.add_scalar('loss_step', loss, step)
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('epoch/loss_epoch_train', float(loss_train_mean),
                          epoch)
        print('loss for train : %f' % (loss_train_mean))

        # save a checkpoint of the model trained so far
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(model.state_dict(),
                       os.path.join(args.save_model_path, 'model.pth'))

        # run validation every validation_step epochs
        if epoch % args.validation_step == 0 and epoch != 0:

            # val returns the metrics computed on the validation set
            precision, miou = val(args, model, dataloader_val)

            # track the best miou and save the corresponding best model
            if miou > max_miou:
                max_miou = miou
                os.makedirs(args.save_model_path, exist_ok=True)
                torch.save(
                    model.state_dict(),
                    os.path.join(args.save_model_path, 'best_dice_loss.pth'))

            writer.add_scalar('epoch/precision_val', precision, epoch)
            writer.add_scalar('epoch/miou_val', miou, epoch)
    # close the writer so all pending events are flushed to disk
    writer.close()
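
DiceLoss is referenced in examples #1, #8, #9, #10 and #13 but defined in the projects' own loss.py. A hedged sketch of a soft multi-class Dice loss (the smoothing constant and the lack of ignore_index handling are assumptions; the real class may differ):

import torch
import torch.nn.functional as F

class DiceLoss(torch.nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, target):
        # logits: (N, C, H, W); target: (N, H, W) class indices (no ignore label)
        num_classes = logits.shape[1]
        probs = F.softmax(logits, dim=1)
        one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
        dims = (0, 2, 3)
        intersection = torch.sum(probs * one_hot, dims)
        cardinality = torch.sum(probs + one_hot, dims)
        dice = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1.0 - dice.mean()  # 0 when prediction and target match perfectly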
Code example #9
def train(args, model, optimizer, dataloader_train, dataloader_val_train,
          dataloader_test):
    writer = SummaryWriter(log_dir='runs_50_adadelta',
                           comment='_{}_{}'.format(args.optimizer,
                                                   args.context_path))
    if args.loss == 'dice':
        loss_func = DiceLoss()
    elif args.loss == 'crossentropy':
        loss_func = torch.nn.CrossEntropyLoss()
    max_miou = 0
    step = 0
    for epoch in range(args.epoch_start_i, args.num_epochs):
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        model.train()
        tq = tqdm.tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda()
            output, output_sup1, output_sup2 = model(data)
            loss1 = loss_func(output, label)
            loss2 = loss_func(output_sup1, label)
            loss3 = loss_func(output_sup2, label)
            loss = loss1 + loss2 + loss3
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            writer.add_scalar('loss_step', loss, step)
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('epoch/loss_epoch_train', float(loss_train_mean),
                          epoch)
        print('loss for train : %f' % (loss_train_mean))
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(
                model.module.state_dict(),
                os.path.join(args.save_model_path, 'latest_dice_loss.pth'))

        if epoch % args.validation_step == 0:
            oa, miou, cm, cks, f1 = val(args, model, dataloader_val_train,
                                        'train')
            oa_test, miou_test, cm_test, cks_test, f1_test = val(
                args, model, dataloader_test, 'test')
            if miou > max_miou:
                max_miou = miou
                torch.save(
                    model.module.state_dict(),
                    os.path.join(args.save_model_path, 'best_dice_loss.pth'))
            writer.add_scalar('epoch/oa_train', oa, epoch)
            writer.add_scalar('epoch/oa_test', oa_test, epoch)
            writer.add_scalar('epoch/miou_train', miou, epoch)
            writer.add_scalar('epoch/miou_test', miou_test, epoch)
            writer.add_scalar('epoch/cks_train', cks, epoch)
            writer.add_scalar('epoch/cks_test', cks_test, epoch)
            writer.add_scalar('epoch/f1_train', f1, epoch)
            writer.add_scalar('epoch/f1_test', f1_test, epoch)
            with open(os.path.join(args.save_model_path,
                                   'classification_results.txt'),
                      mode='a') as f:
                f.write('epoch: ' + str(epoch) + '\n')
                f.write('\nmiou:\t' + str(miou))
                f.write('\noverall accuracy:\t' + str(oa))
                f.write('\ncohen kappa:\t' + str(cks))
                f.write('\nconfusion matrix:\n')
                f.write(str(cm))
                f.write('\nf1:\t' + str(f1))
                f.write('\n\n')
Code example #10
def train(args, model, optimizer, dataloader_train, dataloader_val):
    plotting.output_file('learning_curve_%s_%s.html' %
                         (args.optimizer, args.context_path))
    fig_loss = plotting.figure(title='Loss Curve',
                               x_axis_label='epochs',
                               y_axis_label='loss',
                               plot_width=600,
                               plot_height=600)
    fig_precision = plotting.figure(title='Precision Curve',
                                    x_axis_label='epochs',
                                    y_axis_label='precision',
                                    plot_width=600,
                                    plot_height=600)
    fig_miou = plotting.figure(title='mIOU Curve',
                               x_axis_label='epochs',
                               y_axis_label='mIOU',
                               plot_width=600,
                               plot_height=600)

    if args.loss == 'dice':
        loss_func = DiceLoss()
    elif args.loss == 'crossentropy':
        loss_func = torch.nn.CrossEntropyLoss()

    max_miou = 0
    loss_list = []
    epoch_x = []
    precision_list = []
    miou_list = []
    for epoch in range(args.num_epochs):
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        model.train()
        tq = tqdm.tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda()
            output, output_sup1, output_sup2 = model(data)
            loss1 = loss_func(output, label)
            loss2 = loss_func(output_sup1, label)
            loss3 = loss_func(output_sup2, label)
            loss = loss1 + loss2 + loss3
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        loss_list.append(loss_train_mean)
        print('loss for train : %f' % (loss_train_mean))

        if epoch % args.checkpoint_step == 0 and epoch != 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(
                model.module.state_dict(),
                os.path.join(args.save_model_path, 'latest_dice_loss.pth'))

        if epoch % args.validation_step == 0 or epoch == (args.num_epochs - 1):
            precision, miou = val(args, model, dataloader_val)
            if miou > max_miou:
                max_miou = miou
                torch.save(
                    model.module.state_dict(),
                    os.path.join(args.save_model_path, 'best_dice_loss.pth'))

            precision_list.append(precision)
            miou_list.append(miou)
            epoch_x.append(epoch)

    fig_loss.line(list(range(len(loss_list))),
                  loss_list,
                  legend_label='train loss, min: %.4f' % min(loss_list),
                  line_width=2,
                  line_color='red')
    fig_precision.line(epoch_x,
                       precision_list,
                       legend_label='precision, max: %.4f' %
                       max(precision_list),
                       line_width=2,
                       line_color='blue')
    fig_miou.line(epoch_x,
                  miou_list,
                  legend_label='miou, max: %.4f' % max(miou_list),
                  line_width=2,
                  line_color='green')
    plotting.save(row(fig_loss, fig_precision, fig_miou))
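
Example #10 replaces TensorBoard with bokeh and writes all three curves once, after training finishes. The imports it assumes are:

from bokeh import plotting
from bokeh.layouts import row

Note that plot_width and plot_height were renamed to width and height in bokeh 3.x, so this snippet targets an older bokeh release.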
Code example #11
def train(args, model_G, model_D, optimizer_G, optimizer_D, CamVid_dataloader_train, CamVid_dataloader_val, IDDA_dataloader, curr_epoch, max_miou):
    # The IDDA dataloader supplies the labelled source domain, the CamVid train
    # dataloader supplies the unlabelled target domain, and the CamVid val
    # dataloader is used for evaluation.
    writer = SummaryWriter(comment='_{}_{}_{}'.format(args.optimizer_G, args.optimizer_D, args.context_path))


    scaler = amp.GradScaler()
    if args.loss_G == 'dice':
        loss_func_G = DiceLoss()
    elif args.loss_G == 'crossentropy':
        loss_func_G = torch.nn.CrossEntropyLoss()
        
    loss_func_adv = torch.nn.BCEWithLogitsLoss()
    loss_func_D = torch.nn.BCEWithLogitsLoss()
        
    step = 0
    for epoch in range(curr_epoch + 1, args.num_epochs + 1):  # added +1 shift to finish with an eval
        lr_G = poly_lr_scheduler(optimizer_G, args.learning_rate_G, iter=epoch, max_iter=args.num_epochs)
        lr_D = poly_lr_scheduler(optimizer_D, args.learning_rate_D, iter=epoch, max_iter=args.num_epochs)
        model_G.train()
        model_D.train()
        tq = tqdm(total=len(CamVid_dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr_G %f , lr_D %f' % (epoch, lr_G ,lr_D )) 

        # ground-truth domain labels for the discriminator
        source_label = 0
        target_label = 1
        # lists to track the three losses
        loss_G_record = []
        loss_adv_record = []  # adversarial loss of the generator
        loss_D_record = []    # discriminator loss
        
        source_train_loader = enumerate(IDDA_dataloader)
        s_size = len(IDDA_dataloader)
        target_loader = enumerate(CamVid_dataloader_train)
        t_size = len(CamVid_dataloader_train)

        for i in range(t_size):

            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

            # train G: freeze the discriminator so only the generator receives gradients
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source; restart the iterator if IDDA runs out before CamVid
            try:
                _, batch = next(source_train_loader)
            except StopIteration:
                source_train_loader = enumerate(IDDA_dataloader)
                _, batch = next(source_train_loader)
            data, label = batch
            data = data.cuda()
            label = label.long().cuda()

            with amp.autocast():
                output_s, output_sup1, output_sup2 = model_G(data)
                loss1 = loss_func_G(output_s, label)
                loss2 = loss_func_G(output_sup1, label)
                loss3 = loss_func_G(output_sup2, label)
                loss_G = loss1 + loss2 + loss3

            scaler.scale(loss_G).backward()

            # train with target: the loop runs over t_size, so this iterator
            # cannot be exhausted here
            _, batch = next(target_loader)

            data, _ = batch
            data = data.cuda()
            with amp.autocast():
                output_t, output_sup1, output_sup2 = model_G(data)
                D_out = model_D(F.softmax(output_t, dim=1))
                # adversarial loss: label the target outputs as source to fool the discriminator
                loss_adv = loss_func_adv(
                    D_out, torch.full_like(D_out, source_label, dtype=torch.float32))
                loss_adv = loss_adv * args.lambda_adv  # typically 0.001 or 0.01

            scaler.scale(loss_adv).backward()

            # train D: unfreeze the discriminator
            for param in model_D.parameters():
                param.requires_grad = True

            # train with source: the discriminator should label these as source
            output_s = output_s.detach()  # keep discriminator gradients out of the generator
            with amp.autocast():
                D_out = model_D(F.softmax(output_s, dim=1))
                loss_D = loss_func_D(
                    D_out, torch.full_like(D_out, source_label, dtype=torch.float32))
                loss_D = loss_D / 2  # halve so source and target contribute equally
            scaler.scale(loss_D).backward()

            # train with target: the discriminator should label these as target
            output_t = output_t.detach()
            with amp.autocast():
                D_out = model_D(F.softmax(output_t, dim=1))
                loss_D = loss_func_D(
                    D_out, torch.full_like(D_out, target_label, dtype=torch.float32))
                loss_D = loss_D / 2
            scaler.scale(loss_D).backward()

            tq.update(args.batch_size)
            losses = {"loss_seg": '%.6f' % loss_G.item(),
                      "loss_adv": '%.6f' % loss_adv.item(),
                      "loss_D": '%.6f' % loss_D.item()}
            tq.set_postfix(losses)

            loss_G_record.append(loss_G.item())
            loss_adv_record.append(loss_adv.item())
            loss_D_record.append(loss_D.item())
            step += 1
            writer.add_scalar('loss_G_step', loss_G, step)      # segmentation loss
            writer.add_scalar('loss_adv_step', loss_adv, step)  # adversarial loss
            writer.add_scalar('loss_D_step', loss_D, step)      # discriminator loss
            scaler.step(optimizer_G)  # apply the generator update
            scaler.step(optimizer_D)  # apply the discriminator update
            scaler.update()

        tq.close()
        loss_G_train_mean = np.mean(loss_G_record)
        loss_adv_train_mean = np.mean(loss_adv_record)
        loss_D_train_mean = np.mean(loss_D_record)
        writer.add_scalar('epoch/loss_G_train_mean', float(loss_G_train_mean), epoch)
        writer.add_scalar('epoch/loss_adv_train_mean', float(loss_adv_train_mean), epoch)
        writer.add_scalar('epoch/loss_D_train_mean', float(loss_D_train_mean), epoch)

        # periodically save a full training-state checkpoint
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            state = {
                "epoch": epoch,
                "model_G_state": model_G.module.state_dict(),
                "optimizer_G": optimizer_G.state_dict() ,
                "model_D_state": model_D.module.state_dict(), 
                "optimizer_D": optimizer_D.state_dict(),
                "max_miou": max_miou
            }

            torch.save(state, os.path.join(args.save_model_path, 'latest_dice_loss.pth'))

            print("*** epoch " + str(epoch) + " saved as recent checkpoint!!!")

        if epoch % args.validation_step == 0 and epoch != 0:
            precision, miou = val(args, model_G, CamVid_dataloader_val)
            if miou > max_miou:
                max_miou = miou
                os.makedirs(args.save_model_path, exist_ok=True)
                state = {
                    "epoch": epoch,
                    "model_state": model_G.module.state_dict(),
                    "optimizer": optimizer_G.state_dict(),
                    "max_miou": max_miou
                }
                torch.save(state, os.path.join(args.save_model_path, 'best_dice_loss.pth'))
                print("*** epoch " + str(epoch) + " saved as best checkpoint!!!")
            writer.add_scalar('epoch/precision_val', precision, epoch)
            writer.add_scalar('epoch/miou_val', miou, epoch)
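
Example #11 never shows model_D. A hedged sketch of a fully-convolutional discriminator in the style commonly paired with this adversarial setup (layer sizes are assumptions; the actual model may differ):

import torch.nn as nn

class FCDiscriminator(nn.Module):
    """Maps (N, num_classes, H, W) softmax maps to a 1-channel real/fake score map."""
    def __init__(self, num_classes, ndf=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(num_classes, ndf, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, stride=2, padding=1),  # per-patch domain score
        )

    def forward(self, x):
        return self.model(x)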
Code example #12
File: train2.py  Project: shaofeifei11/CI-CAM
            }, {
                'params': net_module.backbone.parameters()
            }],
            lr=lr)

    if args.attention:
        branch = 2
    else:
        branch = 1
    for e in range(begin_epoch, epoch):
        if args.decay:
            poly_lr_scheduler(optimizer,
                              lr,
                              e,
                              lr_decay_iter=1,
                              max_iter=epoch,
                              dataset=args.dataset,
                              backbone_rate=args.backbone_rate,
                              decay_rate=args.decay_rate,
                              decay_epoch=args.decay_epoch)
        net.train()
        epoch_loss = 0
        epoch_acc = 0
        epoch_bbox_loss = 0
        for i, dat in enumerate(train_loader):
            images, labels = dat
            images, labels = images.to(cfg.device), labels.to(cfg.device)
            labels = labels.long()
            optimizer.zero_grad()

            cam_up, out_up, pred_sort_up, pred_ids_up, cam_down, out_down, pred_sort_down, pred_ids_down = net(
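
The fragment above is cut off on both ends, but its visible tail shows an optimizer built with separate parameter groups so the backbone can be driven at its own learning rate. A minimal, self-contained illustration of that pattern (the model and rates here are hypothetical):

import torch
import torch.nn as nn

# hypothetical two-part model standing in for head + backbone
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),   # "head"
    nn.Conv2d(16, 16, 3, padding=1),  # "backbone"
)
optimizer = torch.optim.SGD(
    [{'params': model[0].parameters()},               # head uses the base lr
     {'params': model[1].parameters(), 'lr': 1e-3}],  # backbone uses a reduced lr
    lr=1e-2, momentum=0.9,
)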
Code example #13
File: train.py  Project: DeskDown/BiseNetv1
def train(args, model, optimizer, dataloader_train, dataloader_val, scaler):
    # Prepare the tensorboard
    writer = makeWriter(dataloader_train, args, model)
    # init loss func
    losses = {"dice": DiceLoss(), "crossentropy": torch.nn.CrossEntropyLoss(
        ignore_index=255)}
    loss_func = losses[args.loss]
    if args.use_lrScheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer,
            mode='min',
            factor=0.1,
            patience=3,
            threshold=0.0001,
            min_lr=0,
        )
    max_miou = 0
    step = 0
    # start training
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        # lr = optimizer.param_groups[0]['lr']
        lr = poly_lr_scheduler(optimizer, args.learning_rate,
                               iter=epoch, max_iter=args.num_epochs)
        loss_record = []
        principal_loss_record = []
        # progress bar
        tq = tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description("epoch: {}/{}".format(epoch, args.num_epochs))
        for i, (data, label) in enumerate(dataloader_train):
            label = label.type(torch.LongTensor)
            if args.use_gpu:
                data = data.to(device)
                label = label.to(device)

            # forward
            optimizer.zero_grad()
            if scaler:
                cm = amp.autocast()
            else:
                cm = dummy_cm()

            with cm:
                output, output_sup1, output_sup2 = model(data)
                loss1 = loss_func(output, label)
                loss2 = loss_func(output_sup1, label)
                loss3 = loss_func(output_sup2, label)
                loss = loss1 + loss2 + loss3

            tq.update(args.batch_size)
            tq.set_postfix(loss=f"{loss:.4f}", lr=lr)
            # backward
            if scaler:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            if args.use_lrScheduler:
                scheduler.step(loss)

            step += 1
            # log the progress
            writer.add_scalar("loss_step", loss, step)
            loss_record.append(loss.item())
            principal_loss_record.append(loss1.item())

        tq.close()
        loss_train_mean = np.mean(loss_record)
        pri_train_mean = np.mean(principal_loss_record)
        writer.add_scalar("epoch/loss_epoch_train",
                          float(loss_train_mean), epoch)
        writer.add_scalar("epoch/pri_loss_epoch_train",
                          float(pri_train_mean), epoch)

        if epoch % args.checkpoint_step == 0 or epoch == args.num_epochs:
            os.makedirs(args.save_model_path, exist_ok=True)
            torch.save(
                model.state_dict(),
                os.path.join(args.save_model_path, f"model_{epoch}_.pth"),
            )

        if (epoch % args.validation_step == 0 or
            epoch == args.num_epochs or
                epoch == 1):
            precision, miou, val_loss = val(
                args, model, dataloader_val, loss_func)
            if miou > max_miou:
                max_miou = miou
                os.makedirs(args.save_model_path, exist_ok=True)
                torch.save(
                    model.state_dict(),
                    os.path.join(args.save_model_path,
                                 f"best_{args.loss}_loss.pth"),
                )

            writer.add_scalar("epoch/precision_val", precision, epoch)
            writer.add_scalar("epoch/miou_val", miou, epoch)
            writer.add_scalar("epoch/loss_val", loss, epoch)
            print("epoch: {}, train_loss: {}, val_loss: {}, val_precision: {}, val_miou: {}".format(
                epoch, pri_train_mean, val_loss, precision, miou
            ))
        writer.flush()

    writer.close()
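
dummy_cm in example #13 is not shown. It only needs to be a no-op context manager so the `with cm:` block still works when AMP is disabled; a minimal definition under that assumption (Python 3.7+ could use contextlib.nullcontext() instead):

import contextlib

@contextlib.contextmanager
def dummy_cm():
    # no-op stand-in for amp.autocast() when mixed precision is off
    yield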