Example #1
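A training loop for Fourier Domain Adaptation (FDA): each source batch is restyled toward the target domain with FDA_source_to_target, the dataset mean image is subtracted from both domains, and the segmentation model is trained with a class-weighted segmentation loss on the source batch plus an entropy loss on the target batch that is switched on once the iteration passes args.switch2entropy.
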
def main():
    opt = TrainOptions()  # loading train options(arg parser)
    args = opt.initialize()  # get arguments
    os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)  # print set options

    sourceloader, targetloader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args)
    sourceloader_iter, targetloader_iter = iter(sourceloader), iter(
        targetloader)

    model, optimizer = CreateModel(args)

    start_iter = 0
    if args.restore_from is not None:
        # parse the start iteration from snapshot names like '<source>_<iter>.pth'
        start_iter = int(os.path.splitext(os.path.basename(args.restore_from))[0].rsplit('_', 1)[1])

    cudnn.enabled = True
    cudnn.benchmark = True

    model.train()
    model.cuda()

    # losses to log
    loss = ['loss_seg_src', 'loss_seg_trg']
    loss_train = 0.0
    loss_val = 0.0
    loss_train_list = []
    loss_val_list = []

    mean_img = torch.zeros(1, 1)
    class_weights = Variable(CS_weights).cuda()

    _t['iter time'].tic()
    for i in range(start_iter, args.num_steps):
        model.adjust_learning_rate(args, optimizer, i)  # adjust learning rate
        optimizer.zero_grad()  # zero grad

        src_img, src_lbl, _, _ = next(sourceloader_iter)  # new batch source
        trg_img, trg_lbl, _, _ = next(targetloader_iter)  # new batch target

        scr_img_copy = src_img.clone()

        if mean_img.shape[-1] < 2:
            B, C, H, W = src_img.shape
            mean_img = IMG_MEAN.repeat(B, 1, H, W)

        #-------------------------------------------------------------------#

        # 1. source to target, target to target
        src_in_trg = FDA_source_to_target(src_img, trg_img,
                                          L=args.LB)  # src_lbl
        trg_in_trg = trg_img

        # 2. subtract mean
        src_img = src_in_trg.clone() - mean_img  # src, src_lbl
        trg_img = trg_in_trg.clone() - mean_img  # trg, trg_lbl

        #-------------------------------------------------------------------#

        # evaluate and update params #####
        src_img, src_lbl = Variable(src_img).cuda(), Variable(
            src_lbl.long()).cuda()  # to gpu
        src_seg_score = model(src_img,
                              lbl=src_lbl,
                              weight=class_weights,
                              ita=args.ita)  # forward pass
        loss_seg_src = model.loss_seg  # get loss
        loss_ent_src = model.loss_ent

        # get target loss, only entropy for backpro
        trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(
            trg_lbl.long()).cuda()  # to gpu
        trg_seg_score = model(trg_img,
                              lbl=trg_lbl,
                              weight=class_weights,
                              ita=args.ita)  # forward pass
        loss_seg_trg = model.loss_seg  # get loss
        loss_ent_trg = model.loss_ent

        triger_ent = 0.0
        if i > args.switch2entropy:
            triger_ent = 1.0

        loss_all = loss_seg_src + triger_ent * args.entW * loss_ent_trg  # seg loss on source + entropy loss on target (enabled after switch2entropy)

        loss_all.backward()
        optimizer.step()

        loss_train += loss_seg_src.detach().cpu().numpy()
        loss_val += loss_seg_trg.detach().cpu().numpy()

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '.pth'))

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][trg seg loss %.4f][lr %.4f][%.2fs]' % \
                    (i + 1, loss_seg_src.data, loss_seg_trg.data, optimizer.param_groups[0]['lr']*10000, _t['iter time'].diff) )

            sio.savemat(args.tempdata, {
                'src_img': src_img.cpu().numpy(),
                'trg_img': trg_img.cpu().numpy()
            })

            loss_train /= args.print_freq
            loss_val /= args.print_freq
            loss_train_list.append(loss_train)
            loss_val_list.append(loss_val)
            sio.savemat(args.matname, {
                'loss_train': loss_train_list,
                'loss_val': loss_val_list
            })
            loss_train = 0.0
            loss_val = 0.0

            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
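
FDA_source_to_target is defined elsewhere in the repository. For reference, here is a minimal sketch of the low-frequency amplitude swap described in the FDA paper (Yang & Soatto, CVPR 2020), assuming (B, C, H, W) float tensors and torch >= 1.8; the repository's actual implementation may differ in windowing details:

import torch

def fda_source_to_target_sketch(src_img, trg_img, L=0.01):
    # Hypothetical re-implementation: swap the low-frequency FFT amplitudes
    # of the source batch with those of the target batch.
    fft_src = torch.fft.fft2(src_img)            # complex, (B, C, H, W)
    fft_trg = torch.fft.fft2(trg_img)
    amp_src, pha_src = fft_src.abs(), fft_src.angle()
    amp_trg = fft_trg.abs()

    # Center the spectra and copy the target's amplitudes into a
    # (2b x 2b) window around DC; L controls the window size.
    H, W = src_img.shape[-2:]
    b = max(1, int(min(H, W) * L))
    amp_src = torch.fft.fftshift(amp_src, dim=(-2, -1))
    amp_trg = torch.fft.fftshift(amp_trg, dim=(-2, -1))
    cH, cW = H // 2, W // 2
    amp_src[..., cH - b:cH + b, cW - b:cW + b] = \
        amp_trg[..., cH - b:cH + b, cW - b:cW + b]
    amp_src = torch.fft.ifftshift(amp_src, dim=(-2, -1))

    # Recombine the mutated amplitude with the original source phase.
    return torch.fft.ifft2(torch.polar(amp_src, pha_src)).real
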
Example #2
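Adversarial adaptation with a single discriminator: the segmentation model is trained on source labels (and on target pseudo-labels when args.data_label_folder_target is set) while also being pushed to fool the discriminator into classifying target predictions as source (label 0); the discriminator is then updated on detached source (label 0) and target (label 1) score maps.
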
def main():

    opt = TrainOptions()
    args = opt.initialize()

    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args)
    sourceloader_iter, targetloader_iter = iter(sourceloader), iter(
        targetloader)

    model, optimizer = CreateModel(args)
    model_D, optimizer_D = CreateDiscriminator(args)

    start_iter = 0
    if args.restore_from is not None:
        # parse the start iteration from snapshot names like '<source>_<iter>.pth'
        start_iter = int(os.path.splitext(os.path.basename(args.restore_from))[0].rsplit('_', 1)[1])

    train_writer = tensorboardX.SummaryWriter(
        os.path.join(args.snapshot_dir, "logs", model_name))

    bce_loss = torch.nn.BCEWithLogitsLoss()

    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    model_D.train()
    model_D.cuda()
    loss = [
        'loss_seg_src', 'loss_seg_trg', 'loss_D_trg_fake', 'loss_D_src_real',
        'loss_D_trg_real'
    ]

    _t['iter time'].tic()
    for i in range(start_iter, args.num_steps):

        model.adjust_learning_rate(args, optimizer, i)
        model_D.adjust_learning_rate(args, optimizer_D, i)

        optimizer.zero_grad()
        optimizer_D.zero_grad()
        """
        train Segmentation model
        """
        for param in model_D.parameters():
            param.requires_grad = False
        '''
        train on [S':transferred source images] for loss_seg_src
        '''
        src_img, src_lbl, _, _ = next(sourceloader_iter)
        src_img, src_lbl = Variable(src_img).cuda(), Variable(
            src_lbl.long()).cuda()
        src_seg_score = model(src_img, lbl=src_lbl)
        loss_seg_src = model.loss
        loss_seg_src.backward()
        '''
        train on [T:target images] for loss_seg_trg
        '''
        if args.data_label_folder_target is not None:
            trg_img, trg_lbl, _, _ = next(targetloader_iter)
            trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(
                trg_lbl.long()).cuda()
            trg_seg_score = model(trg_img, lbl=trg_lbl)
            loss_seg_trg = model.loss
        else:
            trg_img, _, name = next(targetloader_iter)
            trg_img = Variable(trg_img).cuda()
            trg_seg_score = model(trg_img)
            loss_seg_trg = 0

        outD_trg = model_D(F.softmax(trg_seg_score, dim=1), 0)
        loss_D_trg_fake = model_D.loss

        loss_trg = args.lambda_adv_target * loss_D_trg_fake + loss_seg_trg
        loss_trg.backward()
        """
        train Discriminator
        """
        for param in model_D.parameters():
            param.requires_grad = True

        src_seg_score, trg_seg_score = src_seg_score.detach(), trg_seg_score.detach()

        outD_src = model_D(F.softmax(src_seg_score, dim=1), 0)
        loss_D_src_real = model_D.loss / 2
        loss_D_src_real.backward()

        outD_trg = model_D(F.softmax(trg_seg_score, dim=1), 1)
        loss_D_trg_real = model_D.loss / 2
        loss_D_trg_real.backward()
        """
        update segmentation model & discriminator model
        """
        optimizer.step()
        optimizer_D.step()

        for m in loss:
            train_writer.add_scalar(m, eval(m), i + 1)

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '.pth'))

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][lr %.4f][%.2fs]' %
                  (i + 1, loss_seg_src.data, optimizer.param_groups[0]['lr'] *
                   10000, _t['iter time'].diff))
            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
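
model_D(score, label) and model_D.loss are project-specific, but the 0/1 second argument suggests the usual patch-discriminator objective: compare the discriminator's logit map against a constant domain-label map. A hedged sketch of that loss:

import torch
import torch.nn.functional as F

def domain_bce_loss(d_logits, domain_label):
    # d_logits: discriminator output map; domain_label: 0 (source) or 1 (target).
    target = torch.full_like(d_logits, float(domain_label))
    return F.binary_cross_entropy_with_logits(d_logits, target)
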
Example #3
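A self-training round implemented as a method: the target loaders are pointed at previously generated pseudo-labels, and the model is fine-tuned on source and pseudo-labeled target batches, optionally adding a focal term on the target loss (exponent args.gamma, weight args.beta).
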
    def train_adp(self,round_):
        self.args.data_label_folder_target = root_base + '/dataset/generated_data/'
        self.args.shuffel_ = True
        self.args.data_dir_target = root_base + '/dataset/generated_data/'
        self.args.data_list_target = self.args.data_gen_list
        self.args.batch_size = 1
 
        _t = {'iter time' : Timer()}

        self.args.num_steps = int(self.cnt_img * self.args.epoch_per_round * self.args.no_of_patches_per_image)
        
        sourceloader, targetloader = CreateSrcDataLoader(self.args), CreateTrgDataLoader(self.args)
        targetloader_iter, sourceloader_iter = iter(targetloader), iter(sourceloader)
        start_iter = 0

        train_writer = tensorboardX.SummaryWriter(os.path.join(self.args.snapshot_dir, "logs", self.model_name))

        cudnn.enabled = True
        cudnn.benchmark = True
        self.model.train()
        self.model.cuda()

        loss = ['loss_seg_src', 'loss_seg_trg']
        _t['iter time'].tic()

        for i in range(start_iter, self.args.num_steps):

            self.model.adjust_learning_rate(self.args, self.optimizer, i)
            self.optimizer.zero_grad()
            src_img, src_lbl, _, _, _ = next(sourceloader_iter)
            src_img, src_lbl = Variable(src_img).cuda(), Variable(src_lbl.long()).cuda()
            src_seg_score = self.model(src_img, lbl=src_lbl)       
            loss_seg_src = self.model.loss
            loss_src = torch.mean(loss_seg_src)     
            ##############################
            loss_src.backward()

            trg_img, trg_lbl, _, _ = next(targetloader_iter)
            trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(trg_lbl.long()).cuda()
            trg_seg_score = self.model(trg_img, lbl=trg_lbl) 
            ############################
            loss_seg_trg = self.model.loss 
            
            ##########Focal loss############
            loss_trg_2 = torch.mean(loss_seg_trg)
            if self.args.focal_loss:
                pt = torch.exp(-loss_seg_trg)
                loss_trg = loss_seg_trg * (1 - pt) ** self.args.gamma
                trg_fcl = torch.mean(loss_trg)
            else:
                trg_fcl = 0
        
            loss_trg = self.args.beta * trg_fcl + loss_trg_2
            loss_trg.backward()
        
            src_seg_score, trg_seg_score = src_seg_score.detach(), trg_seg_score.detach()
            self.optimizer.step()

            if (i + 1) % self.args.saving_step == 0:
                print('taking snapshot ...')
                if self.args.focal_loss:
                    torch.save(self.model.state_dict(),
                               os.path.join(self.args.snapshot_dir,
                                            self.args.source + "_to_" + self.args.target
                                            + "_w_focal_loss_" + self.args.model + '.pth'))
                else:
                    torch.save(self.model.state_dict(),
                               os.path.join(self.args.snapshot_dir,
                                            self.args.source + "_to_" + self.args.target
                                            + "_wo_focal_loss" + self.args.model + '.pth'))
            if (i+1) % 100 == 0:
                _t['iter time'].toc(average=False)
                print ('[it %d][src seg loss %.4f][lr %.4f][%.2fs]' % \
                        (i + 1, loss_src.data, self.optimizer.param_groups[0]['lr']*10000, _t['iter time'].diff))
                
                _t['iter time'].tic()
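
Isolated for clarity, the focal re-weighting above follows Lin et al. (2017), under the assumption that self.model.loss is an unreduced per-pixel cross-entropy:

import torch

def focal_target_loss(ce_loss, gamma=2.0, beta=1.0):
    # pt = exp(-CE) recovers the model's probability for the true class;
    # (1 - pt)^gamma down-weights confident (easy) pixels.
    pt = torch.exp(-ce_loss)
    focal_term = torch.mean((1.0 - pt) ** gamma * ce_loss)
    return beta * focal_term + torch.mean(ce_loss)
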
Example #4
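Multi-level adversarial training: the segmentation model emits four score maps, two discriminators (model_D1, model_D2) act on the first two, and after iteration 9001 the discriminators' target-side losses are re-weighted by prediction-entropy maps via prob_2_entropy and WeightedBCEWithLogitsLoss.
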
def main():

    opt = TrainOptions()
    args = opt.initialize()

    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args)
    targetloader_iter, sourceloader_iter = iter(targetloader), iter(
        sourceloader)

    model, optimizer = CreateModel(args)
    model_D1, optimizer_D1 = CreateDiscriminator(args, 1)
    model_D2, optimizer_D2 = CreateDiscriminator(args, 2)
    start_iter = 0
    if args.restore_from is not None:
        # parse the start iteration from snapshot names like '<source>_<iter>.pth'
        start_iter = int(os.path.splitext(os.path.basename(args.restore_from))[0].rsplit('_', 1)[1])

    train_writer = tensorboardX.SummaryWriter(
        os.path.join(args.snapshot_dir, "logs", model_name))

    bce_loss = torch.nn.BCEWithLogitsLoss()
    interp_target = nn.Upsample(size=(1024, 1024),
                                mode='bilinear',
                                align_corners=True)
    interp_source = nn.Upsample(size=(1024, 1024),
                                mode='bilinear',
                                align_corners=True)
    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    model_D1.train()
    model_D1.cuda()
    model_D2.train()
    model_D2.cuda()
    weight_loss = WeightedBCEWithLogitsLoss()
    weight_map_loss = WeightMapLoss()
    loss = [
        'loss_seg_src', 'loss_seg_trg', 'loss_D_trg_fake', 'loss_D_src_real',
        'loss_D_trg_real'
    ]
    _t['iter time'].tic()
    for i in range(start_iter, args.num_steps):
        print(i)
        model.adjust_learning_rate(args, optimizer, i)
        model_D1.adjust_learning_rate(args, optimizer_D1, i)
        model_D2.adjust_learning_rate(args, optimizer_D2, i)
        optimizer.zero_grad()
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()

        ##train G
        for param in model_D1.parameters():
            param.requires_grad = False
        for param in model_D2.parameters():
            param.requires_grad = False

        try:
            src_img, src_lbl, weight_map, _ = next(sourceloader_iter)
        except StopIteration:
            sourceloader_iter = iter(sourceloader)
            src_img, src_lbl, weight_map, _ = next(sourceloader_iter)
        src_img, src_lbl, weight_map = Variable(src_img).cuda(), Variable(
            src_lbl.long()).cuda(), Variable(weight_map.long()).cuda()
        src_seg_score1, src_seg_score2, src_seg_score3, src_seg_score4 = model(
            src_img, lbl=src_lbl, weight=weight_map)
        #import pdb;pdb.set_trace()
        #WeightLoss1 = weight_map_loss(src_seg_score1, src_lbl, weight_map)
        #WeightLoss2 = weight_map_loss(src_seg_score2, src_lbl, weight_map)
        loss_seg_src = model.loss
        #print('WeightLoss2, WeightLoss1:', WeightLoss2.data, WeightLoss1.data)
        loss_seg_src.backward()

        if args.data_label_folder_target is not None:
            trg_img, trg_lbl, _, name = next(targetloader_iter)
            trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(
                trg_lbl.long()).cuda()
            trg_seg_score1, trg_seg_score2, trg_seg_score3, trg_seg_score4 = model(
                trg_img, lbl=trg_lbl)
            loss_seg_trg = model.loss
        else:
            trg_img, _, name = next(targetloader_iter)
            trg_img = Variable(trg_img).cuda()
            trg_seg_score1, trg_seg_score2, trg_seg_score3, trg_seg_score4 = model(
                trg_img)
            loss_seg_trg = 0
        outD1_trg = model_D1(F.softmax(trg_seg_score1, dim=1), 0)
        outD2_trg = model_D2(F.softmax(trg_seg_score2, dim=1), 0)
        #import pdb;pdb.set_trace()
        outD1_trg = interp_target(outD1_trg)  #[1, 1, 1024, 1024]
        outD2_trg = interp_target(outD2_trg)
        '''
        if i > 9001:
            #import pdb;pdb.set_trace()
            weight_map1 = prob_2_entropy(F.softmax(trg_seg_score1)) #[1, 1, 1024, 1024]
            weight_map2 = prob_2_entropy(F.softmax(trg_seg_score2)) #[1, 1, 1024, 1024]
            loss_D1_trg_fake = weight_loss(outD1_trg, Variable(torch.FloatTensor(outD1_trg.data.size()).fill_(0)).cuda(), weight_map1, 0.3, 1)
            loss_D2_trg_fake = weight_loss(outD2_trg, Variable(torch.FloatTensor(outD2_trg.data.size()).fill_(0)).cuda(), weight_map2, 0.3, 1)
        else:
            loss_D1_trg_fake = model_D1.loss
            loss_D2_trg_fake = model_D2.loss
        loss_D_trg_fake = loss_D1_trg_fake*0.2 + loss_D2_trg_fake
        '''

        loss_D_trg_fake = model_D1.loss * 0.2 + model_D2.loss
        loss_trg = args.lambda_adv_target * loss_D_trg_fake + loss_seg_trg
        loss_trg.backward()

        ###train D
        for param in model_D1.parameters():
            param.requires_grad = True
        for param in model_D2.parameters():
            param.requires_grad = True

        src_seg_score1, src_seg_score2, src_seg_score3, src_seg_score4 = (
            s.detach() for s in (src_seg_score1, src_seg_score2,
                                 src_seg_score3, src_seg_score4))
        trg_seg_score1, trg_seg_score2, trg_seg_score3, trg_seg_score4 = (
            t.detach() for t in (trg_seg_score1, trg_seg_score2,
                                 trg_seg_score3, trg_seg_score4))

        outD1_src = model_D1(F.softmax(src_seg_score1, dim=1), 0)
        outD2_src = model_D2(F.softmax(src_seg_score2, dim=1), 0)

        loss_D1_src_real = model_D1.loss / 2
        loss_D1_src_real.backward()
        loss_D2_src_real = model_D2.loss / 2
        loss_D2_src_real.backward()
        loss_D_src_real = loss_D1_src_real + loss_D2_src_real

        outD1_trg = model_D1(F.softmax(trg_seg_score1, dim=1), 1)
        outD2_trg = model_D2(F.softmax(trg_seg_score2, dim=1), 1)

        outD1_trg = interp_target(outD1_trg)
        outD2_trg = interp_target(outD2_trg)
        if i > 9001:
            weight_map1 = prob_2_entropy(F.softmax(trg_seg_score1, dim=1))
            weight_map2 = prob_2_entropy(F.softmax(trg_seg_score2, dim=1))
            loss_D1_trg_real = weight_loss(
                outD1_trg,
                Variable(torch.FloatTensor(
                    outD1_trg.data.size()).fill_(1)).cuda(), weight_map1, 0.3,
                1) / 2
            loss_D2_trg_real = weight_loss(
                outD2_trg,
                Variable(torch.FloatTensor(
                    outD2_trg.data.size()).fill_(1)).cuda(), weight_map2, 0.3,
                1) / 2

        else:
            loss_D1_trg_real = model_D1.loss / 2
            loss_D2_trg_real = model_D2.loss / 2

        loss_D1_trg_real.backward()
        loss_D2_trg_real.backward()
        loss_D_trg_real = loss_D1_trg_real + loss_D2_trg_real

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        for m in loss:
            train_writer.add_scalar(m, eval(m), i + 1)

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '_D2.pth'))
        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][trg seg loss %.4f][lr %.4f][%.2fs]' %
                  (i + 1, loss_seg_src.data, float(loss_seg_trg),
                   optimizer.param_groups[0]['lr'] * 10000, _t['iter time'].diff))
            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
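
prob_2_entropy is imported from elsewhere; it commonly follows the ADVENT formulation, a per-pixel entropy map normalized by log2(C). A sketch under that assumption:

import math
import torch

def prob_2_entropy_sketch(prob):
    # prob: softmax probabilities of shape (B, C, H, W); returns the
    # per-pixel, per-channel entropy contribution, normalized by log2(C).
    _, c, _, _ = prob.shape
    return -prob * torch.log2(prob + 1e-30) / math.log2(c)
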
Example #5
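The FDA loop from Example #1 extended with pseudo-label self-supervision: a third loader (CreatePseudoTrgLoader) supplies pseudo-labeled target batches, the total loss becomes seg(source) + seg(pseudo) + args.entW * entropy(pseudo), optimizer state can be restored and moved back to the GPU, and training is logged to wandb.
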
def main():
    opt = TrainOptions()
    args = opt.initialize()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args)
    sourceloader_iter, targetloader_iter = iter(sourceloader), iter(
        targetloader)

    pseudotrgloader = CreatePseudoTrgLoader(args)
    pseudoloader_iter = iter(pseudotrgloader)

    model, optimizer = CreateModel(args)

    start_iter = 0
    if args.restore_from is not None:
        # parse the start iteration from snapshot names like '<source>_<iter>.pth'
        start_iter = int(os.path.splitext(os.path.basename(args.restore_from))[0].rsplit('_', 1)[1])
    if args.restore_optim_from is not None:
        optimizer.load_state_dict(torch.load(args.restore_optim_from))
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    cudnn.enabled = True
    cudnn.benchmark = True

    model.train()
    model.cuda()

    wandb.watch(model, log='gradient', log_freq=1)

    # losses to log
    loss = ['loss_seg_src', 'loss_seg_psu']
    loss_train = 0.0
    loss_val = 0.0
    loss_pseudo = 0.0
    loss_train_list = []
    loss_val_list = []
    loss_pseudo_list = []

    mean_img = torch.zeros(1, 1)
    class_weights = Variable(CS_weights).cuda()

    _t['iter time'].tic()
    for i in range(start_iter, args.num_steps):

        model.adjust_learning_rate(args, optimizer, i)  # adjust learning rate
        optimizer.zero_grad()  # zero grad

        src_img, src_lbl, _, _ = next(sourceloader_iter)  # new batch source
        trg_img, trg_lbl, _, _ = next(targetloader_iter)  # new batch target
        psu_img, psu_lbl, _, _ = next(pseudoloader_iter)  # new pseudo-labeled batch

        scr_img_copy = src_img.clone()

        if mean_img.shape[-1] < 2:
            B, C, H, W = src_img.shape
            mean_img = IMG_MEAN.repeat(B, 1, H, W)

        #-------------------------------------------------------------------#

        # 1. source to target, target to target
        src_in_trg = FDA_source_to_target(src_img, trg_img,
                                          L=args.LB)  # src_lbl
        trg_in_trg = trg_img

        # 2. subtract mean
        src_img = src_in_trg.clone() - mean_img  # src_1, trg_1, src_lbl
        trg_img = trg_in_trg.clone() - mean_img  # trg_1, trg_0, trg_lbl
        psu_img = psu_img.clone() - mean_img

        #-------------------------------------------------------------------#

        # evaluate and update params #####
        src_img, src_lbl = Variable(src_img).cuda(), Variable(
            src_lbl.long()).cuda()  # to gpu
        src_seg_score = model(src_img,
                              lbl=src_lbl,
                              weight=class_weights,
                              ita=args.ita)  # forward pass
        loss_seg_src = model.loss_seg  # get loss
        loss_ent_src = model.loss_ent

        # use pseudo label as supervision
        psu_img, psu_lbl = Variable(psu_img).cuda(), Variable(
            psu_lbl.long()).cuda()
        psu_seg_score = model(psu_img,
                              lbl=psu_lbl,
                              weight=class_weights,
                              ita=args.ita)
        loss_seg_psu = model.loss_seg
        loss_ent_psu = model.loss_ent

        loss_all = loss_seg_src + loss_seg_psu + args.entW * loss_ent_psu  # seg loss on source and pseudo-labeled target, plus entropy on the pseudo batch
        loss_all.backward()
        optimizer.step()

        loss_train += loss_seg_src.detach().cpu().numpy()
        loss_val += loss_seg_psu.detach().cpu().numpy()

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '.pth'))
            torch.save(
                optimizer.state_dict(),
                os.path.join(args.snapshot_dir_optim,
                             '%s_' % (args.source) + '.pth'))
            wandb.log({
                "src seg loss": loss_seg_src.data,
                "psu seg loss": loss_seg_psu.data,
                "learnign rate": optimizer.param_groups[0]['lr'] * 10000
            })
        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][psu seg loss %.4f][lr %.4f][%.2fs]' % \
                    (i + 1, loss_seg_src.data, loss_seg_psu.data, optimizer.param_groups[0]['lr']*10000, _t['iter time'].diff) )

            sio.savemat(args.tempdata, {
                'src_img': src_img.cpu().numpy(),
                'trg_img': trg_img.cpu().numpy()
            })

            loss_train /= args.print_freq
            loss_val /= args.print_freq
            loss_train_list.append(loss_train)
            loss_val_list.append(loss_val)
            sio.savemat(args.matname, {
                'loss_train': loss_train_list,
                'loss_val': loss_val_list
            })
            loss_train = 0.0
            loss_val = 0.0

            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
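
Both resume paths above recover the start iteration from the snapshot filename; a small helper equivalent to the parsing used here, assuming the '<source>_<iter>.pth' naming produced by the save calls:

import os

def start_iter_from_snapshot(path):
    # e.g. '/snapshots/gta5_50000.pth' -> 50000
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.rsplit('_', 1)[1])
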
Example #6
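Joint training of a DeepLPFNet generator, a discriminator, and a segmentation network: batches alternate between discriminator and generator/segmentation phases on a fixed schedule, and every epoch ends with a validation pass that accumulates a confusion matrix and reports mIoU on the target domain.
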
def main():
    args = arg_parser.Parse()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    logger = Logger(args.log_dir)
    logger.PrintAndLogArgs(args)
    saver = ImageAndLossSaver(args.tb_logs_dir, logger.log_folder,
                              args.checkpoints_dir, args.save_pics_every)
    source_loader, target_train_loader, target_eval_loader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args,
                                   'train'), CreateTrgDataLoader(args, 'val')
    epoch_size = np.maximum(len(target_train_loader.dataset),
                            len(source_loader.dataset))
    steps_per_epoch = int(np.floor(epoch_size / args.batch_size))
    source_loader.dataset.SetEpochSize(epoch_size)
    target_train_loader.dataset.SetEpochSize(epoch_size)

    generator = model.DeepLPFNet()
    generator = nn.DataParallel(generator.cuda())
    generator_criterion = model.GeneratorLoss()
    generator_optimizer = optim.Adam(generator.parameters(),
                                     lr=args.generator_lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08)
    discriminator = model.Discriminator()
    discriminator = nn.DataParallel(discriminator.cuda())
    discriminator_criterion = model.DiscriminatorLoss()
    discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                         lr=args.discriminator_lr,
                                         betas=(0.9, 0.999),
                                         eps=1e-08)
    semseg_net, semseg_optimizer = CreateModel(args)
    semseg_net = nn.DataParallel(semseg_net.cuda())

    logger.info('######### Network created #########')
    logger.info('Architecture of Generator:\n' + str(generator))
    logger.info('Architecture of Discriminator:\n' + str(discriminator))
    logger.info('Architecture of Backbone net:\n' + str(semseg_net))

    for epoch in range(args.num_epochs):
        generator.train()
        discriminator.train()
        semseg_net.train()
        saver.Reset()
        discriminate_src = True
        source_loader_iter, target_train_loader_iter, target_eval_loader_iter = iter(
            source_loader), iter(target_train_loader), iter(target_eval_loader)
        logger.info('#################[Epoch %d]#################' %
                    (epoch + 1))

        for batch_num in range(steps_per_epoch):
            start_time = time.time()
            training_discriminator = (batch_num >= args.generator_boost) and (
                batch_num - args.generator_boost) % (
                    args.discriminator_iters +
                    args.generator_iters) < args.discriminator_iters
            src_img, src_lbl, src_shapes, src_names = next(
                source_loader_iter)  # new batch source
            trg_eval_img, trg_eval_lbl, trg_shapes, trg_names = next(
                target_train_loader_iter)  # new batch target

            generator_optimizer.zero_grad()
            discriminator_optimizer.zero_grad()
            semseg_optimizer.zero_grad()

            src_input_batch = Variable(src_img, requires_grad=False).cuda()
            src_label_batch = Variable(src_lbl, requires_grad=False).cuda()
            trg_input_batch = Variable(trg_eval_img,
                                       requires_grad=False).cuda()
            # trg_label_batch = Variable(trg_lbl, requires_grad=False).cuda()
            src_in_trg = generator(src_input_batch, trg_input_batch)  # G(S,T)

            if training_discriminator:  #train discriminator
                if discriminate_src:
                    discriminator_src_in_trg = discriminator(
                        src_in_trg)  # D(G(S,T))
                    discriminator_trg = None  # D(T)
                else:
                    discriminator_src_in_trg = None  # D(G(S,T))
                    discriminator_trg = discriminator(trg_input_batch)  # D(T)
                discriminate_src = not discriminate_src
                loss = discriminator_criterion(discriminator_src_in_trg,
                                               discriminator_trg)
            else:  #train generator and semseg net
                discriminator_trg = discriminator(trg_input_batch)  # D(T)
                predicted, loss_seg, loss_ent = semseg_net(
                    src_in_trg, lbl=src_label_batch)  # F(G(S.T))
                src_in_trg_labels = torch.argmax(predicted, dim=1)
                loss = generator_criterion(loss_seg, loss_ent, args.entW,
                                           discriminator_trg)

            saver.WriteLossHistory(training_discriminator, loss.item())
            loss.backward()

            if training_discriminator:  # train discriminator
                discriminator_optimizer.step()
            else:  # train generator and semseg net
                generator_optimizer.step()
                semseg_optimizer.step()

            saver.running_time += time.time() - start_time

            if (not training_discriminator) and saver.SaveImagesIteration:
                saver.SaveTrainImages(epoch, src_img[0, :, :, :],
                                      src_in_trg[0, :, :, :], src_lbl[0, :, :],
                                      src_in_trg_labels[0, :, :])

            if (batch_num + 1) % args.print_every == 0:
                logger.PrintAndLogData(saver, epoch, batch_num,
                                       args.print_every)

            if (batch_num + 1) % args.save_checkpoint == 0:
                saver.SaveModelsCheckpoint(semseg_net, discriminator,
                                           generator, epoch, batch_num)

        #Validation:
        semseg_net.eval()
        rand_samp_inds = np.random.randint(0, len(target_eval_loader.dataset),
                                           5)
        rand_batchs = np.floor(rand_samp_inds / args.batch_size).astype(int)
        cm = torch.zeros((NUM_CLASSES, NUM_CLASSES)).cuda()
        for val_batch_num, (trg_eval_img, trg_eval_lbl, _,
                            _) in enumerate(target_eval_loader):
            with torch.no_grad():
                trg_input_batch = Variable(trg_eval_img,
                                           requires_grad=False).cuda()
                trg_label_batch = Variable(trg_eval_lbl,
                                           requires_grad=False).cuda()
                pred_softs_batch = semseg_net(trg_input_batch)
                pred_batch = torch.argmax(pred_softs_batch, dim=1)
                cm += compute_cm_batch_torch(pred_batch, trg_label_batch,
                                             IGNORE_LABEL, NUM_CLASSES)
                print('Validation: saw', val_batch_num * args.batch_size,
                      'examples')
                if (val_batch_num + 1) in rand_batchs:
                    rand_offset = np.random.randint(0, args.batch_size)
                    saver.SaveValidationImages(
                        epoch, trg_input_batch[rand_offset, :, :, :],
                        trg_label_batch[rand_offset, :, :],
                        pred_batch[rand_offset, :, :])
        iou, miou = compute_iou_torch(cm)
        saver.SaveEpochAccuracy(iou, miou, epoch)
        logger.info(
            'Average accuracy of Epoch #%d on target domain: mIoU = %2f' %
            (epoch + 1, miou))
        logger.info(
            '-----------------------------------Epoch #%d Finished-----------------------------------'
            % (epoch + 1))
        del cm, trg_input_batch, trg_label_batch, pred_softs_batch, pred_batch

    saver.tb.close()
    logger.info('Finished training.')
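
The training_discriminator predicate in the loop above encodes the phase schedule; the same logic isolated, with parameter names mirroring args:

def is_discriminator_batch(batch_num, generator_boost, discriminator_iters, generator_iters):
    # The first generator_boost batches train only the generator/semseg nets;
    # afterwards each cycle runs discriminator_iters discriminator batches
    # followed by generator_iters generator batches.
    if batch_num < generator_boost:
        return False
    cycle = discriminator_iters + generator_iters
    return (batch_num - generator_boost) % cycle < discriminator_iters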