Example #1
def main():
    # torch.manual_seed(1234)
    # torch.cuda.manual_seed(1234)
    opt = TrainOptions()
    args = opt.initialize()

    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args)
    targetloader_iter, sourceloader_iter = iter(targetloader), iter(
        sourceloader)

    model, optimizer = CreateModel(args)
    model_D, optimizer_D = CreateDiscriminator(args)

    start_iter = 0
    if args.restore_from is not None:
        start_iter = int(args.restore_from.rsplit('/', 1)[1].rsplit('_')[1])

    train_writer = tensorboardX.SummaryWriter(
        os.path.join(args.snapshot_dir, "logs", model_name))

    bce_loss = torch.nn.BCEWithLogitsLoss()
    l1_loss = torch.nn.L1Loss()
    cos_loss = torch.nn.CosineSimilarity(dim=0, eps=1e-06)
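    # bce_loss drives the adversarial game against the discriminator, l1_loss
    # enforces agreement between the two classifier heads on target images, and
    # cos_loss measures how aligned the two heads' weights are (used further below)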

    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    model_D.train()
    model_D.cuda()
    loss = [
        'loss_seg_src', 'loss_seg_trg', 'loss_D_trg_fake', 'loss_D_src_real',
        'loss_D_trg_real'
    ]
    _t['iter time'].tic()

    pbar = tqdm(range(start_iter, args.num_steps_stop))
    #for i in range(start_iter, args.num_steps):
    for i in pbar:

        model.adjust_learning_rate(args, optimizer, i)
        model_D.adjust_learning_rate(args, optimizer_D, i)

        optimizer.zero_grad()
        optimizer_D.zero_grad()
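        # freeze the discriminator so the adversarial losses below only update
        # the segmentation network (generator step)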
        for param in model_D.parameters():
            param.requires_grad = False

        src_img, src_lbl, _, _ = next(sourceloader_iter)
        src_img, src_lbl = Variable(src_img).cuda(), Variable(
            src_lbl.long()).cuda()
        src_seg_score, src_seg_score2 = model(src_img)
        loss_seg_src1 = CrossEntropy2d(src_seg_score, src_lbl)
        loss_seg_src2 = CrossEntropy2d(src_seg_score2, src_lbl)
        loss_seg_src = loss_seg_src1 + loss_seg_src2
        loss_seg_src.backward()

        if args.data_label_folder_target is not None:
            trg_img, trg_lbl, _, _ = next(targetloader_iter)
            trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(
                trg_lbl.long()).cuda()
            # unpack both heads here as well, so trg_seg_score2 is defined below
            trg_seg_score, trg_seg_score2 = model(trg_img)
            loss_seg_trg = model.loss
        else:
            trg_img, _, name = next(targetloader_iter)
            trg_img = Variable(trg_img).cuda()
            trg_seg_score, trg_seg_score2 = model(trg_img)
            loss_seg_trg = 0

        outD_trg = model_D(F.softmax(trg_seg_score))
        outD_trg2 = model_D(F.softmax(trg_seg_score2))
        loss_D_trg_fake1 = bce_loss(
            outD_trg,
            Variable(torch.FloatTensor(outD_trg.data.size()).fill_(0)).cuda())
        loss_D_trg_fake2 = bce_loss(
            outD_trg2,
            Variable(torch.FloatTensor(outD_trg2.data.size()).fill_(0)).cuda())
        loss_D_trg_fake = loss_D_trg_fake1 + loss_D_trg_fake2

        loss_agree = l1_loss(F.softmax(trg_seg_score),
                             F.softmax(trg_seg_score2))

        loss_trg = args.lambda_adv_target * loss_D_trg_fake + loss_seg_trg + loss_agree
        loss_trg.backward()

        #Weight Discrepancy Loss

        W5 = None
        W6 = None
        if args.model == 'DeepLab2':

            for (w5, w6) in zip(model.layer5.parameters(),
                                model.layer6.parameters()):
                if W5 is None and W6 is None:
                    W5 = w5.view(-1)
                    W6 = w6.view(-1)
                else:
                    W5 = torch.cat((W5, w5.view(-1)), 0)
                    W6 = torch.cat((W6, w6.view(-1)), 0)

        #ipdb.set_trace()
        #loss_weight = (torch.matmul(W5, W6) / (torch.norm(W5) * torch.norm(W6)) + 1) # +1 is for a positive loss
        # loss_weight = loss_weight  * damping * 2
        loss_weight = args.weight_div * (cos_loss(W5, W6) + 1)
        loss_weight.backward()

        # discriminator update: make D trainable again; the segmentation outputs
        # are detached below so D's gradients stay out of the segmentation network
        for param in model_D.parameters():
            param.requires_grad = True

        src_seg_score, trg_seg_score = src_seg_score.detach(
        ), trg_seg_score.detach()
        src_seg_score2, trg_seg_score2 = src_seg_score2.detach(
        ), trg_seg_score2.detach()

        outD_src = model_D(F.softmax(src_seg_score))
        loss_D_src_real1 = bce_loss(
            outD_src,
            Variable(torch.FloatTensor(
                outD_src.data.size()).fill_(0)).cuda()) / 2
        outD_src2 = model_D(F.softmax(src_seg_score2))
        loss_D_src_real2 = bce_loss(
            outD_src2,
            Variable(torch.FloatTensor(
                outD_src2.data.size()).fill_(0)).cuda()) / 2
        loss_D_src_real = loss_D_src_real1 + loss_D_src_real2
        loss_D_src_real.backward()

        outD_trg = model_D(F.softmax(trg_seg_score))
        loss_D_trg_real1 = bce_loss(
            outD_trg,
            Variable(torch.FloatTensor(
                outD_trg.data.size()).fill_(1)).cuda()) / 2
        outD_trg2 = model_D(F.softmax(trg_seg_score2))
        loss_D_trg_real2 = bce_loss(
            outD_trg2,
            Variable(torch.FloatTensor(
                outD_trg2.data.size()).fill_(1)).cuda()) / 2
        loss_D_trg_real = loss_D_trg_real1 + loss_D_trg_real2
        loss_D_trg_real.backward()

        d_loss = loss_D_src_real.data + loss_D_trg_real.data

        optimizer.step()
        optimizer_D.step()

        for m in loss:
            train_writer.add_scalar(m, eval(m), i + 1)

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '.pth'))
            torch.save(
                model_D.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '_D.pth'))

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][adv loss %.4f][d loss %.4f][agree loss %.4f][div loss %.4f][lr %.4f][%.2fs]' %
                  (i + 1, loss_seg_src.data, loss_D_trg_fake.data, d_loss, loss_agree.data,
                   loss_weight.data, optimizer.param_groups[0]['lr'] * 10000, _t['iter time'].diff))
            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
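
For reference, the "Weight Discrepancy Loss" block in Example #1 reduces to a cosine penalty on the flattened weights of the two classifier heads (model.layer5 and model.layer6). A minimal sketch, with head_a/head_b as illustrative stand-ins for those heads:

import torch
import torch.nn.functional as F

def weight_discrepancy(head_a, head_b, weight_div=1.0):
    # flatten and concatenate every parameter tensor of each classifier head
    w_a = torch.cat([p.view(-1) for p in head_a.parameters()])
    w_b = torch.cat([p.view(-1) for p in head_b.parameters()])
    # cosine similarity lies in [-1, 1]; the +1 keeps the loss non-negative,
    # and minimising it pushes the two heads toward orthogonal weight vectors
    return weight_div * (F.cosine_similarity(w_a, w_b, dim=0, eps=1e-6) + 1)
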
Example #2
def main():
    # torch.manual_seed(1234)
    # torch.cuda.manual_seed(1234)
    opt = TrainOptions()
    args = opt.initialize()
    
    _t = {'iter time' : Timer()}
    
    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)   
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)
    
    sourceloader, targetloader = CreateSrcDataLoader(args), CreateTrgDataLoader(args)
    targetloader_iter, sourceloader_iter = iter(targetloader), iter(sourceloader)
    
    model, optimizer = CreateModel(args)
    model_D, optimizer_D = CreateDiscriminator(args)
    
    start_iter = 0
    if args.restore_from is not None:
        start_iter = int(args.restore_from.rsplit('/', 1)[1].rsplit('_')[1])
        
    train_writer = tensorboardX.SummaryWriter(os.path.join(args.snapshot_dir, "logs", model_name))
    
    bce_loss = torch.nn.BCEWithLogitsLoss()
    cent_loss = ConditionalEntropyLoss()
    
    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    model_D.train()
    model_D.cuda()
    loss = ['loss_seg_src', 'loss_seg_trg', 'loss_D_trg_fake', 'loss_D_src_real', 'loss_D_trg_real']
    _t['iter time'].tic()

    pbar = tqdm(range(start_iter,args.num_steps_stop))
    #for i in range(start_iter, args.num_steps):
    for i in pbar:
        
        model.adjust_learning_rate(args, optimizer, i)
        model_D.adjust_learning_rate(args, optimizer_D, i)
        
        optimizer.zero_grad()
        optimizer_D.zero_grad()
        for param in model_D.parameters():
            param.requires_grad = False 
            
        src_img, src_lbl, _, _ = next(sourceloader_iter)
        src_img, src_lbl = Variable(src_img).cuda(), Variable(src_lbl.long()).cuda()
        src_seg_score = model(src_img)
        loss_seg_src = CrossEntropy2d(src_seg_score, src_lbl)
        #loss_seg_src = model.loss   
        loss_seg_src.backward()
        
        if args.data_label_folder_target is not None:
            trg_img, trg_lbl, _, _ = next(targetloader_iter)
            trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(trg_lbl.long()).cuda()
            trg_seg_score = model(trg_img) 
            loss_seg_trg = model.loss
        else:
            trg_img, _, name = next(targetloader_iter)
            trg_img = Variable(trg_img).cuda()
            trg_seg_score = model(trg_img)
            #ipdb.set_trace()
            loss_seg_trg = cent_loss(trg_seg_score)
            #loss_seg_trg= entropy_loss(F.softmax(trg_seg_score))
            #loss_seg_trg = 0
        
        outD_trg = model_D(F.softmax(trg_seg_score))
        loss_D_trg_fake = bce_loss(outD_trg, Variable(torch.FloatTensor(outD_trg.data.size()).fill_(0)).cuda())
        #loss_D_trg_fake = model_D.loss
        
        loss_trg = args.lambda_adv_target * (loss_D_trg_fake + loss_seg_trg)
        loss_trg.backward()
        
        for param in model_D.parameters():
            param.requires_grad = True
        
        src_seg_score, trg_seg_score = src_seg_score.detach(), trg_seg_score.detach()
        
        outD_src = model_D(F.softmax(src_seg_score))
        loss_D_src_real = bce_loss(outD_src, Variable(torch.FloatTensor(outD_src.data.size()).fill_(0)).cuda())/ 2
        #loss_D_src_real = model_D.loss / 2
        loss_D_src_real.backward()
        
        outD_trg = model_D(F.softmax(trg_seg_score))
        loss_D_trg_real = bce_loss(outD_trg, Variable(torch.FloatTensor(outD_trg.data.size()).fill_(1)).cuda())/ 2
        #loss_D_trg_real = model_D.loss / 2
        loss_D_trg_real.backward()   

        d_loss = loss_D_src_real.data + loss_D_trg_real.data
       
        
        optimizer.step()
        optimizer_D.step()
        
        
        for m in loss:
            train_writer.add_scalar(m, eval(m), i+1)
            
        if (i+1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), os.path.join(args.snapshot_dir, '%s_' %(args.source) +str(i+1)+'.pth' )) 
            torch.save(model_D.state_dict(), os.path.join(args.snapshot_dir, '%s_' %(args.source) +str(i+1)+'_D.pth' ))   
            
        if (i+1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][adv loss %.4f][d loss %.4f][lr %.4f][%.2fs]' %
                  (i + 1, loss_seg_src.data, loss_D_trg_fake.data, d_loss,
                   optimizer.param_groups[0]['lr'] * 10000, _t['iter time'].diff))
            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
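
ConditionalEntropyLoss is referenced in Example #2 but not defined in this listing. For unlabeled target images it typically penalises uncertain predictions; a minimal sketch, assuming the usual mean per-pixel entropy of the softmax distribution over logits of shape (N, C, H, W):

import torch.nn as nn
import torch.nn.functional as F

class ConditionalEntropyLoss(nn.Module):
    def forward(self, logits):
        p = F.softmax(logits, dim=1)            # per-pixel class probabilities
        log_p = F.log_softmax(logits, dim=1)    # numerically stable log-probs
        return -(p * log_p).sum(dim=1).mean()   # entropy averaged over pixels
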
Example #3
File: train.py  Project: yahong7/DFSNet
def train(cfg):
    # configure train
    train_name = time.strftime("%m%d_%H%M%S", time.localtime(
    )) + '_' + cfg.model + '_' + os.path.basename(cfg.dataset_path)
    cfg.name = train_name
    log_interval = int(np.ceil(cfg.max_epochs * 0.1))
    print(cfg)

    # cpu or gpu?
    if torch.cuda.is_available() and cfg.device is not None:
        device = torch.device(cfg.device)
    else:
        if not torch.cuda.is_available():
            print("hey man, buy a GPU!")
        device = torch.device("cpu")

    # data
    print('Loading Data')
    train_data = monoSimDataset(path=cfg.dataset_path,
                                mode='train',
                                seed=cfg.seed,
                                debug_data=cfg.debug)
    train_data_loader = DataLoader(train_data,
                                   cfg.batch_size,
                                   drop_last=True,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    val_data = monoSimDataset(path=cfg.dataset_path,
                              mode='val',
                              seed=cfg.seed,
                              debug_data=cfg.debug)
    val_data_loader = DataLoader(val_data,
                                 cfg.batch_size,
                                 shuffle=False,
                                 drop_last=True,
                                 num_workers=cfg.num_workers)

    # configure model
    print('Loading Model')
    model = MobileNetV2_Lite(True, cfg.mask_learn_rate)
    assert model is not None
    model.to(device)
    if cfg.cp_path:
        cp_data = torch.load(cfg.cp_path, map_location=device)
        try:
            model.load_state_dict(cp_data['model'])
        except Exception as e:
            model.load_state_dict(cp_data['model'], strict=False)
            print(e)

        cp_data['cfg'] = '' if 'cfg' not in cp_data else cp_data['cfg']
        print(cp_data['cfg'])

    # criterion and optimizer
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.lr,
        # momentum=cfg.momentum,
        weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           factor=0.5,
                                                           verbose=True)

    pred_criterion = nn.MSELoss()
    mask_criterion = CrossEntropy2d()
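    # pred_criterion (MSE) supervises the score regression output; mask_criterion
    # supervises the per-pixel heatmap; cfg.mask_learn_rate blends the two below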

    # checkpoint
    if cfg.cp_num > 0:
        cp_dir_path = os.path.normcase(os.path.join('checkpoints', train_name))
        os.mkdir(cp_dir_path)
        best_cp = []
        history_dir_path = os.path.normcase(
            os.path.join(cp_dir_path, 'history'))
        os.mkdir(history_dir_path)
        with open(os.path.normcase(os.path.join(cp_dir_path, 'config.txt')),
                  'w') as f:
            info = str(cfg) + '#' * 30 + '\npre_cfg:\n' + str(
                cp_data['cfg']) if cfg.cp_path else str(cfg)
            f.write(info)

    # visible
    if cfg.visible:
        log_writer = SummaryWriter(os.path.join("log", train_name))
        log_writer.add_text('cur_cfg', cfg.__str__())
        if cfg.cp_path:
            log_writer.add_text('pre_cfg', cp_data['cfg'].__str__())

    # Start!
    print("Start training!\n")
    for epoch in range(1, cfg.max_epochs + 1):
        if epoch % int(cfg.max_epochs / 10) == 0 and cfg.mask_lr_decay < 1:
            cfg.mask_learn_rate *= cfg.mask_lr_decay
            print("[{}] Mask learn rate: {:.4e}".format(
                epoch, cfg.mask_learn_rate))

        # train
        model.train()
        epoch_loss = 0
        for img, mask, target in tqdm(
                train_data_loader,
                desc='[{}] mini_batch'.format(epoch),
                bar_format='{desc}: {n_fmt}/{total_fmt} -{percentage:3.0f}%'):
            img = img.to(device)
            mask = mask.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            pred, heatmap = model(img)
            if cfg.mask_learn_rate == 0:
                loss = pred_criterion(pred, target)
            elif cfg.mask_learn_rate == 1:  # mask-only supervision
                loss = mask_criterion(heatmap, mask)
            else:
                loss = (1 - cfg.mask_learn_rate) * pred_criterion(
                    pred, target) + cfg.mask_learn_rate * mask_criterion(
                        heatmap, mask)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
        train_loss = epoch_loss / len(train_data_loader)
        scheduler.step(train_loss)

        print("[{}] Training - loss: {:.4e}".format(epoch, train_loss))
        if cfg.visible:
            log_writer.add_scalar('Train/Loss', train_loss, epoch)
            log_writer.add_scalar('Train/lr', optimizer.param_groups[0]['lr'],
                                  epoch)

        # val
        if epoch % 5 == 0 or cfg.debug:
            if cfg.model.split('_')[0] == 'MobileNetV3':
                model.train()
            else:
                model.eval()
            with torch.no_grad():
                val_pred_loss = 0
                scores = np.zeros((1))
                prediction = np.zeros((1))
                for img, mask, target in tqdm(
                        val_data_loader,
                        desc='[{}] val_batch'.format(epoch),
                        bar_format=
                        '{desc}: {n_fmt}/{total_fmt} -{percentage:3.0f}%'):
                    img = img.to(device)
                    mask = mask.to(device)
                    target = target.to(device)
                    pred, heatmap = model(img)
                    val_pred_loss += nn.functional.mse_loss(pred,
                                                            target,
                                                            reduction='sum')
                    scores = np.append(scores,
                                       target.cpu().numpy().reshape((-1)))
                    prediction = np.append(prediction,
                                           pred.cpu().numpy().reshape((-1)))
                val_pred_loss = val_pred_loss / len(val_data)
                prediction = np.nan_to_num(prediction)
                srocc = stats.spearmanr(prediction[1:], scores[1:])[0]
                lcc = stats.pearsonr(prediction[1:], scores[1:])[0]

                print("[{}] Val - MSE: {:.4e}".format(epoch, val_pred_loss))
                print("[{}] Val - LCC: {:.4f}, SROCC: {:.4f}".format(
                    epoch, lcc, srocc))
                if cfg.visible:
                    idx = np.random.randint(0, mask.shape[0])
                    heatmap_s = torch.softmax(heatmap, 1)[idx, 1, :, :]
                    log_writer.add_scalar('Val/MSE', val_pred_loss, epoch)
                    log_writer.add_scalar('Val/LCC', lcc, epoch)
                    log_writer.add_scalar('Val/SROCC', srocc, epoch)
                    log_writer.add_image('Val/img', img[idx], epoch)
                    log_writer.add_image('Val/mask',
                                         torch.squeeze(mask[idx]),
                                         epoch,
                                         dataformats='HW')
                    log_writer.add_image('Val/heatmap',
                                         torch.squeeze(heatmap_s),
                                         epoch,
                                         dataformats='HW')

        # checkpoint
        if cfg.cp_num > 0:
            # model.cpu()
            cp_name = "{}_{:.4e}.pth".format(epoch, train_loss)

            if epoch < cfg.cp_num + 1:
                best_cp.append([cp_name, train_loss])
                best_cp.sort(key=lambda x: x[1])
                best_cp_path = os.path.normcase(
                    os.path.join(cp_dir_path, cp_name))

                cp_data = dict(
                    cfg=str(cfg),
                    model=model.state_dict(),
                )
                torch.save(cp_data, best_cp_path)
            else:
                if train_loss < best_cp[-1][1]:
                    os.remove(
                        os.path.normcase(
                            os.path.join(cp_dir_path, best_cp[-1][0])))
                    best_cp[-1] = [cp_name, train_loss]
                    best_cp.sort(key=lambda x: x[1])
                    best_cp_path = os.path.normcase(
                        os.path.join(cp_dir_path, cp_name))
                    cp_data = dict(
                        cfg=str(cfg),
                        model=model.state_dict(),
                    )
                    torch.save(cp_data, best_cp_path)

            if ((log_interval > 0) and (epoch % log_interval == 0 or epoch % 100 == 0)) or \
                    (epoch == cfg.max_epochs):
                history_cp_path = os.path.normcase(
                    os.path.join(history_dir_path, cp_name))
                cp_data = dict(
                    cfg=str(cfg),
                    model=model.state_dict(),
                )
                torch.save(cp_data, history_cp_path)

            # model.to(device)

    return model.cpu()
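
For completeness, a minimal sketch of how train(cfg) might be invoked. The config object only needs the attributes the function actually reads; all values below are illustrative, and the real project presumably builds cfg with argparse:

from types import SimpleNamespace

cfg = SimpleNamespace(
    model='MobileNetV2_Lite', dataset_path='data/monoSim', seed=42, debug=False,
    device='cuda:0', batch_size=16, num_workers=4,
    lr=1e-3, weight_decay=1e-4, max_epochs=100,
    mask_learn_rate=0.5, mask_lr_decay=0.9,
    cp_path=None, cp_num=3, visible=True,
)
trained_model = train(cfg)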