def load_and_test_net(data_list, datadir, weights, model, num_cls,
                      dset='test', base_model=None, base_path=None, **test_args):
    # Setup GPU Usage
    if torch.cuda.is_available():
        kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        kwargs = {}

    # Eval tgt from AddaNet or TaskNet model #
    if model == 'AddaNet':
        net = get_model(model, num_cls=num_cls, weights_init=weights,
                        model=base_model)
        net = net.tgt_net
    elif model == 'CalibratorNet':
        net = get_model(model, num_cls=num_cls, weights_init=weights,
                        model=base_model)

    else:
        net = get_model(model, num_cls=num_cls, weights_init=weights)

    # Load data
    loaders = []
    for (i, data) in enumerate(data_list):
        test_data = load_data(data, dset, batch=100,
                              rootdir=datadir[i], num_channels=net.num_channels,
                              image_size=net.image_size, download=True, kwargs=kwargs)
        if test_data is None:
            print('skipping test')
            return
        loaders.append(test_data)

    return test(net, loaders_list=loaders, save_path=base_path + "_with_" + "_".join(data_list) + "log", **test_args)
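
# A minimal usage sketch for load_and_test_net above. Paths, dataset names and the
# backbone are placeholders (assumptions), not values taken from the original project.
if __name__ == '__main__':
    load_and_test_net(data_list=['cityscapes'],             # one loader per entry
                      datadir=['/path/to/data'],            # parallel list of root dirs
                      weights='snapshots/net-iter5000.pth', # weights_init for get_model
                      model='CalibratorNet',
                      num_cls=19,
                      dset='test',
                      base_model='fcn8s',                   # backbone wrapped by CalibratorNet
                      base_path='results/fcn8s')            # prefix for the saved log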
# Example 2
def main(output, dataset, datadir, lr, momentum, snapshot, downscale, cls_weights, gpu,
         weights_init, num_cls, lsgan, max_iter, lambda_d, lambda_g,
         train_discrim_only, weights_discrim, crop_size, weights_shared,
         discrim_feat, half_crop, batch, model):
    # So data is sampled in consistent way
    #np.random.seed(1337)
    #torch.manual_seed(1337)
    logdir = 'runs/{:s}/{:s}_to_{:s}/lr{:.1g}_ld{:.2g}_lg{:.2g}'.format(model, dataset[0],
                                                                        dataset[1], lr, lambda_d, lambda_g)
    if weights_shared:
        logdir += '_weightshared'
    else:
        logdir += '_weightsunshared'
    if discrim_feat:
        logdir += '_discrimfeat'
    else:
        logdir += '_discrimscore'
    logdir += '/' + datetime.now().strftime('%Y_%b_%d-%H:%M')
    writer = SummaryWriter(log_dir=logdir)

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    print('Train Discrim Only', train_discrim_only)
    print(weights_init)

    net = get_model('CalibratorNet', model=model, num_cls=num_cls,
                    cali_model='resnet_9blocks', src_weights_init=weights_init,
                    task='segmentation')

    net.src_net.eval()

    
    #odim = 1 if lsgan else 2
    odim = 4
    idim = num_cls if not discrim_feat else 4096

    '''
    discriminator = Discriminator(input_dim=idim, output_dim=odim,
                                  pretrained=not (weights_discrim == None),
                                  weights_init=weights_discrim).cuda()
    
    net.discriminator = discriminator
    '''
    print(net)
    transform = net.src_net.module.transform if hasattr(net.src_net,'module') else net.src_net.transform
    
    loader = AddaDataLoader(transform, dataset, datadir, downscale,
                            crop_size=crop_size, half_crop=half_crop,
                            batch_size=batch, shuffle=True, num_workers=2)
    print('dataset', dataset)

    # Class weighted loss?
    if cls_weights is not None:
        weights = np.loadtxt(cls_weights)
    else:
        weights = None

    # setup optimizers

    weight_decay = 0.005
    betas = (0.9, 0.999)
    lr = 2e-4  # note: overrides the lr passed on the command line for the Adam optimizers below

    '''
    opt_dis = optim.SGD(net.discriminator.parameters(), lr=lr,
                        momentum = momentum, weight_decay = 0.0005)

    opt_p = optim.SGD(net.pixel_discriminator.parameters(), lr=lr,
                        momentum = momentum, weight_decay = 0.0005)
    
    opt_cali = optim.SGD(net.calibrator_T.parameters(), lr=lr,
                         momentum = momentum, weight_decay = 0.0005)
    '''
    


    opt_dis = optim.Adam(net.discriminator.parameters(), lr=lr,
                         weight_decay=weight_decay, betas=betas)

    opt_p = optim.Adam(net.pixel_discriminator.parameters(), lr=lr,
                       weight_decay=weight_decay, betas=betas)

    opt_cali = optim.Adam(net.calibrator_T.parameters(), lr=lr,
                          weight_decay=weight_decay, betas=betas)

    iteration = 0
    num_update_g = 0
    last_update_g = -1
    losses_super_s = deque(maxlen=100)
    losses_super_t = deque(maxlen=100)
    losses_dis = deque(maxlen=100)
    losses_rep = deque(maxlen=100)
    accuracies_dom = deque(maxlen=100)
    intersections = np.zeros([100, num_cls])
    unions = np.zeros([100, num_cls])
    accuracy = deque(maxlen=100)
    print('max iter:', max_iter)

    net.src_net.train()
    net.discriminator.train()
    net.pixel_discriminator.train()
    net.calibrator_T.train()
    freq_D = 8
    freq_G = 8
    while iteration < max_iter:

        for im_s, im_t, label_s, label_t in loader:
            if iteration > max_iter:
                break
            if im_s.size(0) != im_t.size(0):
                continue
            
            info_str = 'Iteration {}: '.format(iteration)

            if not check_label(label_s, num_cls):
                continue

            ###########################
            # 1. Setup Data Variables #
            ###########################
            im_s = make_variable(im_s, requires_grad=False)
            label_s = make_variable(label_s, requires_grad=False)
            im_t = make_variable(im_t, requires_grad=False)
            label_t = make_variable(label_t, requires_grad=False)

            #############################
            # 2. Optimize Discriminator #
            #############################

            # accumulated gradients: step and reset the discriminator optimizers every freq_D iterations
            if (iteration + 1) % freq_D == 0:
                opt_dis.step()
                opt_p.step()
                opt_dis.zero_grad()
                opt_p.zero_grad()

            pert_T_s, pert_T_t = forward_calibrator(im_s, im_t, net.calibrator_T)

            fake_T_s = torch.clamp(im_s + pert_T_s, -3, 3)
            fake_T_t = torch.clamp(im_t + pert_T_t, -3, 3)

            score_s, score_t = forward_clean_data(im_s, im_t, net.src_net)
            score_T_s, score_T_t = forward_pert_data(fake_T_s, fake_T_t, net.src_net)

            pred_p_t = net.pixel_discriminator(im_t)
            pred_p_s = net.pixel_discriminator(im_s)
            pred_p_T_s = net.pixel_discriminator(fake_T_s)
            pred_p_T_t = net.pixel_discriminator(fake_T_t)

            # prediction for feature discriminator            

            gan_criterion = GANLoss().cuda()
            idt_criterion = torch.nn.L1Loss()
            cycle_criterion = torch.nn.L1Loss()

            # 0,1,2,3 for 4 different domains


            pred_s = net.discriminator(score_s)
            pred_t = net.discriminator(score_t)
            pred_T_t = net.discriminator(score_T_t)
            pred_T_s = net.discriminator(score_T_s)
                       
            

                    
            # prepare domain labels: 0 = source, 1 = target, 2 = perturbed target, 3 = perturbed source
            batch_t, _, h, w = score_t.size()
            batch_s, _, _, _ = score_s.size()
            label_0 = make_variable(0 * torch.ones(batch_s, h, w).long(),  # s
                                    requires_grad=False)
            label_1 = make_variable(1 * torch.ones(batch_t, h, w).long(),  # t
                                    requires_grad=False)
            label_2 = make_variable(2 * torch.ones(batch_t, h, w).long(),  # T_t
                                    requires_grad=False)
            label_3 = make_variable(3 * torch.ones(batch_s, h, w).long(),  # T_s
                                    requires_grad=False)

            P_loss_s = gan_criterion(pred_p_s, 0)
            P_loss_t = gan_criterion(pred_p_t, 1)
            P_loss_T_t = gan_criterion(pred_p_T_t, 2)
            P_loss_T_s = gan_criterion(pred_p_T_s, 3)


            dis_pred_concat = torch.cat([pred_s, pred_t, pred_T_t, pred_T_s])
            dis_label_concat = torch.cat([label_0, label_1, label_2, label_3])
            
            # compute loss for discriminator
            loss_dis = supervised_loss(dis_pred_concat, dis_label_concat)

            D_loss = loss_dis + P_loss_T_t + P_loss_s + P_loss_t + P_loss_T_s
            
            D_loss.backward(retain_graph=True)
            losses_dis.append(D_loss.item())


            info_str += ' D:{:.3f}'.format(np.mean(losses_dis))
            # optimize discriminator

            

            ###########################
            # Optimize Target Network #
            ###########################



            if True:  # calibrator update branch (runs every iteration)

                last_update_g = iteration
                num_update_g += 1

                # accumulated gradients: step and reset the calibrator optimizer every freq_G iterations
                if (iteration + 1) % freq_G == 0:
                    opt_cali.step()
                    opt_dis.zero_grad()
                    opt_p.zero_grad()
                    opt_cali.zero_grad()
                # fool-the-discriminator labels: both calibrated streams should be classified as source (0)
                batch, _, h, w = score_t.size()

                label_T_t = make_variable(0 * torch.ones(batch, h, w).long(),
                                          requires_grad=False)
                label_T_s = make_variable(0 * torch.ones(batch, h, w).long(),
                                          requires_grad=False)

                P_loss_T_t = gan_criterion(pred_p_T_t, 0)
                P_loss_T_s = gan_criterion(pred_p_T_s, 0)

                
                G_loss_T_t = supervised_loss(pred_T_t, label_T_t)
                G_loss_T_s = supervised_loss(pred_T_s, label_T_s)

                G_loss = G_loss_T_t + 0.2 * G_loss_T_s + P_loss_T_t + 0.2 * P_loss_T_s
                G_loss.backward()



                losses_rep.append(G_loss.item())
                #writer.add_scalar('loss/generator', np.mean(losses_rep), iteration)

                # optimize target net

                # log net update info
                info_str += ' G:{:.3f}'.format(np.mean(losses_rep))


            ###########################
            # Log and compute metrics #
            ###########################
            if iteration % 10 == 0 and iteration > 0:
                # compute metrics
                intersection, union, acc = seg_accuracy(score_T_t, label_t.data, num_cls)
                intersections = np.vstack([intersections[1:, :], intersection[np.newaxis, :]])
                unions = np.vstack([unions[1:, :], union[np.newaxis, :]])
                accuracy.append(acc.item() * 100)
                acc = np.mean(accuracy)
                mIoU = np.mean(np.maximum(intersections, 1) / np.maximum(unions, 1)) * 100

                info_str += ' acc:{:0.2f}  mIoU:{:0.2f}'.format(acc, mIoU)
                writer.add_scalar('metrics/acc', np.mean(accuracy), iteration)
                writer.add_scalar('metrics/mIoU', np.mean(mIoU), iteration)
                print(info_str)
                logging.info(info_str)

            iteration += 1

            ################
            # Save outputs #
            ################

            # every 500 iters save current model
            if iteration % 500 == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.state_dict(),
                               '{}/net-itercurr.pth'.format(output))
                torch.save(net.discriminator.state_dict(),
                           '{}/discriminator-itercurr.pth'.format(output))

            # save labeled snapshots
            if iteration % snapshot == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.state_dict(),
                               '{}/net-iter{}.pth'.format(output, iteration))
                torch.save(net.discriminator.state_dict(),
                           '{}/discriminator-iter{}.pth'.format(output, iteration))

            if iteration - last_update_g >= len(loader):
                print('No suitable discriminator found -- returning.')
                torch.save(net.state_dict(),
                           '{}/net-iter{}.pth'.format(output, iteration))
                iteration = max_iter  # make sure outside loop breaks
                break

    writer.close()
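
# forward_calibrator, forward_clean_data and forward_pert_data are not defined in this
# snippet. A hedged sketch of what they are assumed to do (run one module on both the
# source and the target batch); the original helpers may differ:
def forward_calibrator(im_s, im_t, calibrator):
    # predict an additive perturbation for each domain
    return calibrator(im_s), calibrator(im_t)


def forward_clean_data(im_s, im_t, src_net):
    # segmentation scores on the unperturbed images
    return src_net(im_s), src_net(im_t)


def forward_pert_data(fake_s, fake_t, src_net):
    # segmentation scores on the calibrated (perturbed) images
    return src_net(fake_s), src_net(fake_t)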
# Example 3
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B',
            'sem_AB'
        ]

        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')

        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']

        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # load/define networks
        # The naming convention is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)

            # For the semantic consistency loss, load an FCN network as f_s here.
            self.netPixelCLS = get_model(opt.weights_model_type,
                                         num_cls=opt.num_cls,
                                         pretrained=True,
                                         weights_init=opt.weights_init)
            # Specially initialize Pixel CLS network
            if len(self.gpu_ids) > 0:
                assert (torch.cuda.is_available())
                self.netPixelCLS.to(self.gpu_ids[0])
                self.netPixelCLS = torch.nn.DataParallel(
                    self.netPixelCLS, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # self.criterionCLS = torch.nn.modules.CrossEntropyLoss()
            self.criterionSemantic = torch.nn.KLDivLoss(reduction='batchmean')
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
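
# Note on criterionSemantic above: torch.nn.KLDivLoss expects its first argument to be
# log-probabilities and its second to be probabilities. A hedged sketch of how a
# semantic-consistency term could be computed with it (function and argument names are
# illustrative, not taken from the original model):
import torch.nn.functional as F

def semantic_consistency(logits_fake, logits_real, criterion):
    # KL divergence between the class distributions predicted on the real and fake images
    return criterion(F.log_softmax(logits_fake, dim=1),
                     F.softmax(logits_real.detach(), dim=1))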
# Example 4
def main(output, dataset, target_name, datadir, batch_size, lr, iterations,
         momentum, snapshot, downscale, augmentation, fyu, crop_size, weights,
         model, gpu, num_cls, nthreads, model_weights, data_flag,
         serial_batches, resize_to, start_step, preprocessing, small,
         rundir_flag):
    if weights is not None:
        raise RuntimeError("weights don't work because eric is bad at coding")
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    logdir_flag = data_flag
    if rundir_flag != "":
        logdir_flag += "_{}".format(rundir_flag)

    logdir = 'runs/{:s}/{:s}/{:s}'.format(model, '-'.join(dataset),
                                          logdir_flag)
    writer = SummaryWriter(log_dir=logdir)
    if model == 'fcn8s':
        net = get_model(model,
                        num_cls=num_cls,
                        weights_init=model_weights,
                        output_last_ft=True)
    else:
        net = get_model(model,
                        num_cls=num_cls,
                        finetune=True,
                        weights_init=model_weights)
    net.cuda()

    str_ids = gpu.split(',')
    gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >= 0:
            gpu_ids.append(id)

    # set gpu ids
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    transform = []
    target_transform = []

    if preprocessing:
        transform.extend([
            torchvision.transforms.Resize(
                [int(resize_to), int(int(resize_to) * 1.8)],
                interpolation=Image.BICUBIC)
        ])
        target_transform.extend([
            torchvision.transforms.Resize(
                [int(resize_to), int(int(resize_to) * 1.8)],
                interpolation=Image.NEAREST)
        ])

    transform.extend([net.module.transform])
    target_transform.extend([to_tensor_raw])
    transform = torchvision.transforms.Compose(transform)
    target_transform = torchvision.transforms.Compose(target_transform)

    datasets = [
        get_dataset(name,
                    os.path.join(datadir, name),
                    num_cls=num_cls,
                    transform=transform,
                    target_transform=target_transform,
                    data_flag=data_flag,
                    small=small) for name in dataset
    ]

    target_dataset = get_dataset(target_name,
                                 os.path.join(datadir, target_name),
                                 num_cls=num_cls,
                                 transform=transform,
                                 target_transform=target_transform,
                                 data_flag=data_flag,
                                 small=small)

    if weights is not None:
        weights = np.loadtxt(weights)

    if augmentation:
        collate_fn = lambda batch: augment_collate(
            batch, crop=crop_size, flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate

    loaders = [
        torch.utils.data.DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=not serial_batches,
                                    num_workers=nthreads,
                                    collate_fn=collate_fn,
                                    pin_memory=True,
                                    drop_last=True) for dataset in datasets
    ]

    target_loader = torch.utils.data.DataLoader(target_dataset,
                                                batch_size=batch_size,
                                                shuffle=not serial_batches,
                                                num_workers=nthreads,
                                                collate_fn=collate_fn,
                                                pin_memory=True,
                                                drop_last=True)
    iteration = start_step
    losses = deque(maxlen=10)
    losses_domain_syn = deque(maxlen=10)
    losses_domain_gta = deque(maxlen=10)
    losses_task = deque(maxlen=10)

    for loader in loaders:
        loader.dataset.__getitem__(0, debug=True)

    input_dim = 2048
    configs = {
        "input_dim": input_dim,
        "hidden_layers": [1000, 500, 100],
        "num_classes": 2,
        'num_domains': 2,
        'mode': 'dynamic',
        'mu': 1e-2,
        'gamma': 10.0
    }

    mdan = MDANet(configs).to(gpu_ids[0])
    mdan = torch.nn.DataParallel(mdan, gpu_ids)
    mdan.train()

    opt = torch.optim.Adam(itertools.chain(mdan.module.parameters(),
                                           net.module.parameters()),
                           lr=1e-4)

    # cnt = 0
    for (im_syn, label_syn), (im_gta,
                              label_gta), (im_cs,
                                           label_cs) in multi_source_infinite(
                                               loaders, target_loader):
        # cnt += 1
        # print(cnt)
        # Clear out gradients
        opt.zero_grad()

        # load data/label
        im_syn = make_variable(im_syn, requires_grad=False)
        label_syn = make_variable(label_syn, requires_grad=False)

        im_gta = make_variable(im_gta, requires_grad=False)
        label_gta = make_variable(label_gta, requires_grad=False)

        im_cs = make_variable(im_cs, requires_grad=False)
        label_cs = make_variable(label_cs, requires_grad=False)

        if iteration == 0:
            print("im_syn size: {}".format(im_syn.size()))
            print("label_syn size: {}".format(label_syn.size()))

            print("im_gta size: {}".format(im_gta.size()))
            print("label_gta size: {}".format(label_gta.size()))

            print("im_cs size: {}".format(im_cs.size()))
            print("label_cs size: {}".format(label_cs.size()))

        if not (im_syn.size() == im_gta.size() == im_cs.size()):
            print(im_syn.size())
            print(im_gta.size())
            print(im_cs.size())

        # forward pass and compute loss
        preds_syn, ft_syn = net(im_syn)
        # pooled_ft_syn = avg_pool(ft_syn)

        preds_gta, ft_gta = net(im_gta)
        # pooled_ft_gta = avg_pool(ft_gta)

        preds_cs, ft_cs = net(im_cs)
        # pooled_ft_cs = avg_pool(ft_cs)

        loss_synthia = supervised_loss(preds_syn, label_syn)
        loss_gta = supervised_loss(preds_gta, label_gta)

        loss = loss_synthia + loss_gta
        losses_task.append(loss.item())

        logprobs, sdomains, tdomains = mdan(ft_syn, ft_gta, ft_cs)

        slabels = torch.ones(batch_size, requires_grad=False).type(
            torch.LongTensor).to(gpu_ids[0])
        tlabels = torch.zeros(batch_size, requires_grad=False).type(
            torch.LongTensor).to(gpu_ids[0])

        # TODO: increase task loss
        # Compute prediction accuracy on multiple training sources.
        domain_losses = torch.stack([
            F.nll_loss(sdomains[j], slabels) +
            F.nll_loss(tdomains[j], tlabels)
            for j in range(configs['num_domains'])
        ])
        losses_domain_syn.append(domain_losses[0].item())
        losses_domain_gta.append(domain_losses[1].item())

        # Different final loss function depending on different training modes.
        if configs['mode'] == "maxmin":
            loss = torch.max(loss) + configs['mu'] * torch.min(domain_losses)
        elif configs['mode'] == "dynamic":
            loss = torch.log(
                torch.sum(
                    torch.exp(configs['gamma'] *
                              (loss + configs['mu'] * domain_losses)))
            ) / configs['gamma']
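        # The "dynamic" mode above is a smooth maximum over the per-source objectives:
        #   loss = (1/gamma) * log( sum_j exp( gamma * (task_loss + mu * domain_loss_j) ) )
        # which approaches max_j(task_loss + mu * domain_loss_j) as gamma grows.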

        # backward pass
        loss.backward()
        losses.append(loss.item())

        torch.nn.utils.clip_grad_norm_(net.module.parameters(), 10)
        torch.nn.utils.clip_grad_norm_(mdan.module.parameters(), 10)
        # step gradients
        opt.step()

        # log results
        if iteration % 10 == 0:
            logging.info(
                'Iteration {}:\t{:.3f} Domain SYN: {:.3f} Domain GTA: {:.3f} Task: {:.3f}'
                .format(iteration, np.mean(losses), np.mean(losses_domain_syn),
                        np.mean(losses_domain_gta), np.mean(losses_task)))
            writer.add_scalar('loss', np.mean(losses), iteration)
            writer.add_scalar('domain_syn', np.mean(losses_domain_syn),
                              iteration)
            writer.add_scalar('domain_gta', np.mean(losses_domain_gta),
                              iteration)
            writer.add_scalar('task', np.mean(losses_task), iteration)
        iteration += 1

        if iteration % 500 == 0:
            os.makedirs(output, exist_ok=True)
            torch.save(net.module.state_dict(),
                       '{}/net-itercurr.pth'.format(output))

        if iteration % snapshot == 0:
            torch.save(net.module.state_dict(),
                       '{}/iter_{}.pth'.format(output, iteration))

        if iteration >= iterations:
            logging.info('Optimization complete.')
            break

# Example 5
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        self.semantic_loss = opt.semantic_loss

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = [
            'D_A_1', 'G_A_1', 'cycle_A_1', 'idt_A_1', 'D_B_1', 'G_B_1',
            'cycle_B_1', 'idt_B_1', 'D_A_2', 'G_A_2', 'cycle_A_2', 'idt_A_2',
            'D_B_2', 'G_B_2', 'cycle_B_2', 'idt_B_2'
        ]

        if opt.SAD:
            self.loss_names.extend(['D_3_1', 'G_s1s2'])

        if opt.CCD or opt.HF_CCD:
            self.loss_names.extend(['D_21', 'G_s1s21'])
            self.loss_names.extend(['D_12', 'G_s2s12'])

        if self.semantic_loss:
            self.loss_names.extend(['sem_syn', 'sem_gta'])

        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A_1 = ['real_A_1', 'fake_B_1', 'rec_A_1']
        visual_names_B_1 = ['real_B', 'fake_A_1', 'rec_B_1']

        visual_names_A_2 = ['real_A_2', 'fake_B_2', 'rec_A_2']
        visual_names_B_2 = ['fake_A_2', 'rec_B_2']

        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A_1.append('idt_A_1')
            visual_names_B_1.append('idt_B_1')

            visual_names_A_2.append('idt_A_2')
            visual_names_B_2.append('idt_B_2')

        self.visual_names = visual_names_A_1 + visual_names_B_1 + visual_names_A_2 + visual_names_B_2
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            # self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
            if opt.Shared_DT:
                self.model_names = [
                    'G_A_1', 'G_B_1', 'D_A', 'D_B_1', 'D_B_2', 'G_A_2', 'G_B_2'
                ]
            else:
                self.model_names = [
                    'G_A_1', 'G_B_1', 'D_A_1', 'D_B_1', 'G_A_2', 'G_B_2',
                    'D_A_2', 'D_B_2'
                ]
            if opt.SAD:
                self.model_names.append('D_3')

            if opt.CCD or opt.HF_CCD:
                self.model_names.append('D_12')
                self.model_names.append('D_21')

        else:  # during test time, only load Gs
            self.model_names = ['G_A_1', 'G_B_1', 'G_A_2', 'G_B_2']

        # load/define networks
        # The naming convention is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A_1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                          opt.which_model_netG, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          self.gpu_ids)
        self.netG_B_1 = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                          opt.which_model_netG, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          self.gpu_ids)

        self.netG_A_2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                          opt.which_model_netG, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          self.gpu_ids)

        self.netG_B_2 = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                          opt.which_model_netG, opt.norm,
                                          not opt.no_dropout, opt.init_type,
                                          self.gpu_ids)

        if opt.semantic_loss:
            self.netPixelCLS_SYN = get_model(opt.weights_model_type,
                                             num_cls=opt.num_cls,
                                             pretrained=True,
                                             weights_init=opt.weights_syn)
            self.netPixelCLS_GTA = get_model(opt.weights_model_type,
                                             num_cls=opt.num_cls,
                                             pretrained=True,
                                             weights_init=opt.weights_gta)
            if len(self.gpu_ids) > 0:
                assert (torch.cuda.is_available())
                self.netPixelCLS_SYN.to(self.gpu_ids[0])
                self.netPixelCLS_SYN = torch.nn.DataParallel(
                    self.netPixelCLS_SYN, self.gpu_ids)
                self.netPixelCLS_GTA.to(self.gpu_ids[0])
                self.netPixelCLS_GTA = torch.nn.DataParallel(
                    self.netPixelCLS_GTA, self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            if opt.Shared_DT:
                self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                                opt.which_model_netD,
                                                opt.n_layers_D, opt.norm,
                                                use_sigmoid, opt.init_type,
                                                self.gpu_ids)
            else:
                self.netD_A_1 = networks.define_D(opt.output_nc, opt.ndf,
                                                  opt.which_model_netD,
                                                  opt.n_layers_D, opt.norm,
                                                  use_sigmoid, opt.init_type,
                                                  self.gpu_ids)

                self.netD_A_2 = networks.define_D(opt.output_nc, opt.ndf,
                                                  opt.which_model_netD,
                                                  opt.n_layers_D, opt.norm,
                                                  use_sigmoid, opt.init_type,
                                                  self.gpu_ids)

            self.netD_B_1 = networks.define_D(opt.input_nc, opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D, opt.norm,
                                              use_sigmoid, opt.init_type,
                                              self.gpu_ids)

            self.netD_B_2 = networks.define_D(opt.input_nc, opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D, opt.norm,
                                              use_sigmoid, opt.init_type,
                                              self.gpu_ids)

            if opt.SAD:
                self.netD_3 = networks.define_D(opt.input_nc, opt.ndf,
                                                opt.which_model_netD,
                                                opt.n_layers_D, opt.norm,
                                                use_sigmoid, opt.init_type,
                                                self.gpu_ids)
            if opt.CCD or opt.HF_CCD:
                self.netD_12 = networks.define_D(opt.input_nc, opt.ndf,
                                                 opt.which_model_netD,
                                                 opt.n_layers_D, opt.norm,
                                                 use_sigmoid, opt.init_type,
                                                 self.gpu_ids)
                self.netD_21 = networks.define_D(opt.input_nc, opt.ndf,
                                                 opt.which_model_netD,
                                                 opt.n_layers_D, opt.norm,
                                                 use_sigmoid, opt.init_type,
                                                 self.gpu_ids)

        if self.isTrain:
            self.fake_A_1_pool = ImagePool(
                opt.pool_size
            )  # create image buffer to store previously generated images
            self.fake_B_1_pool = ImagePool(opt.pool_size)
            self.fake_A_2_pool = ImagePool(opt.pool_size)
            self.fake_B_2_pool = ImagePool(opt.pool_size)
            self.fake_A_21_pool = ImagePool(opt.pool_size)
            self.fake_A_12_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionSemantic = torch.nn.KLDivLoss(reduction='batchmean')
            # initialize optimizers
            if opt.Shared_DT:
                self.optimizer_D = torch.optim.Adam(itertools.chain(
                    self.netD_A.parameters(), self.netD_B_1.parameters(),
                    self.netD_B_2.parameters()),
                                                    lr=opt.lr,
                                                    betas=(opt.beta1, 0.999))
            else:
                self.optimizer_D_1 = torch.optim.Adam(itertools.chain(
                    self.netD_A_1.parameters(), self.netD_B_1.parameters()),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
                self.optimizer_D_2 = torch.optim.Adam(itertools.chain(
                    self.netD_A_2.parameters(), self.netD_B_2.parameters()),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))

            self.optimizer_G_1 = torch.optim.Adam(itertools.chain(
                self.netG_A_1.parameters(), self.netG_B_1.parameters()),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))

            self.optimizer_G_2 = torch.optim.Adam(itertools.chain(
                self.netG_A_2.parameters(), self.netG_B_2.parameters()),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))

            if opt.SAD:
                self.optimizer_D_3 = torch.optim.Adam(self.netD_3.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))

            if opt.CCD or opt.HF_CCD:
                self.optimizer_D_21 = torch.optim.Adam(
                    self.netD_21.parameters(),
                    lr=opt.lr,
                    betas=(opt.beta1, 0.999))
                self.optimizer_D_12 = torch.optim.Adam(
                    self.netD_12.parameters(),
                    lr=opt.lr,
                    betas=(opt.beta1, 0.999))

            self.optimizers = []
            self.optimizers.append(self.optimizer_G_1)
            self.optimizers.append(self.optimizer_G_2)
            if opt.Shared_DT:
                self.optimizers.append(self.optimizer_D)
            else:
                self.optimizers.append(self.optimizer_D_1)
                self.optimizers.append(self.optimizer_D_2)

            if opt.SAD:
                self.optimizers.append(self.optimizer_D_3)
            if opt.CCD or opt.HF_CCD:
                self.optimizers.append(self.optimizer_D_12)
                self.optimizers.append(self.optimizer_D_21)
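
# The ImagePool buffers created above are normally consumed when the discriminators are
# updated: previously generated fakes are mixed into the current batch via pool.query().
# A hedged sketch following the standard CycleGAN recipe (the actual backward_D helper
# is not shown in this snippet):
def backward_D_basic(netD, real, fake, criterion_gan):
    # real images should be classified as real, pooled fakes as fake
    loss_real = criterion_gan(netD(real), True)
    loss_fake = criterion_gan(netD(fake.detach()), False)
    loss_D = 0.5 * (loss_real + loss_fake)
    loss_D.backward()
    return loss_D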
# Example 6
def main(output, dataset, datadir, lr, momentum, snapshot, downscale,
         cls_weights, gpu, weights_init, num_cls, lsgan, max_iter, lambda_d,
         lambda_g, train_discrim_only, weights_discrim, crop_size,
         weights_shared, discrim_feat, half_crop, batch, model, data_flag,
         resize, with_mmd_loss, small):
    # So data is sampled in consistent way
    np.random.seed(1336)
    torch.manual_seed(1336)
    logdir = 'runs/{:s}/{:s}_to_{:s}/lr{:.1g}_ld{:.2g}_lg{:.2g}'.format(
        model, dataset[0], dataset[1], lr, lambda_d, lambda_g)
    if weights_shared:
        logdir += '_weights_shared'
    else:
        logdir += '_weights_unshared'
    if discrim_feat:
        logdir += '_discrim_feat'
    else:
        logdir += '_discrim_score'
    logdir += '/' + datetime.now().strftime('%Y_%b_%d-%H:%M')
    writer = SummaryWriter(log_dir=logdir)

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    print('Train Discrim Only', train_discrim_only)
    if model == 'fcn8s':
        net = get_model(model,
                        num_cls=num_cls,
                        pretrained=True,
                        weights_init=weights_init,
                        output_last_ft=discrim_feat)
    else:
        net = get_model(model,
                        num_cls=num_cls,
                        finetune=True,
                        pretrained=True,
                        weights_init=weights_init,
                        output_last_ft=discrim_feat)

    net.cuda()
    str_ids = gpu.split(',')
    gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >= 0:
            gpu_ids.append(id)

    # set gpu ids
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    if weights_shared:
        net_src = net  # shared weights
    else:
        net_src = get_model(model,
                            num_cls=num_cls,
                            finetune=True,
                            pretrained=True,
                            weights_init=weights_init,
                            output_last_ft=discrim_feat)
        net_src.eval()

    # initialize Discrminator
    odim = 1 if lsgan else 2
    idim = num_cls if not discrim_feat else 4096
    print('Discrim_feat', discrim_feat, idim)
    print('Discriminator init weights: ', weights_discrim)
    discriminator = Discriminator(input_dim=idim,
                                  output_dim=odim,
                                  pretrained=not (weights_discrim == None),
                                  weights_init=weights_discrim).cuda()

    discriminator.to(gpu_ids[0])
    discriminator = torch.nn.DataParallel(discriminator, gpu_ids)

    loader = AddaDataLoader(net.module.transform,
                            dataset,
                            datadir,
                            downscale,
                            resize=resize,
                            crop_size=crop_size,
                            half_crop=half_crop,
                            batch_size=batch,
                            shuffle=True,
                            num_workers=16,
                            src_data_flag=data_flag,
                            small=small)
    print('dataset', dataset)

    # Class weighted loss?
    if cls_weights is not None:
        weights = np.loadtxt(cls_weights)
    else:
        weights = None

    # setup optimizers
    opt_dis = torch.optim.SGD(discriminator.module.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=0.0005)
    opt_rep = torch.optim.SGD(net.module.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=0.0005)

    iteration = 0
    num_update_g = 0
    last_update_g = -1
    losses_super_s = deque(maxlen=100)
    losses_super_t = deque(maxlen=100)
    losses_dis = deque(maxlen=100)
    losses_rep = deque(maxlen=100)
    accuracies_dom = deque(maxlen=100)
    intersections = np.zeros([100, num_cls])
    iu_deque = deque(maxlen=100)
    unions = np.zeros([100, num_cls])
    accuracy = deque(maxlen=100)
    print('Max Iter:', max_iter)

    net.train()
    discriminator.train()

    loader.loader_src.dataset.__getitem__(0, debug=True)
    loader.loader_tgt.dataset.__getitem__(0, debug=True)

    while iteration < max_iter:

        for im_s, im_t, label_s, label_t in loader:

            if iteration == 0:
                print("IM S: {}".format(im_s.size()))
                print("Label S: {}".format(label_s.size()))
                print("IM T: {}".format(im_t.size()))
                print("Label T: {}".format(label_t.size()))

            if iteration > max_iter:
                break

            info_str = 'Iteration {}: '.format(iteration)

            if not check_label(label_s, num_cls):
                continue

            ###########################
            # 1. Setup Data Variables #
            ###########################
            im_s = make_variable(im_s, requires_grad=False)
            label_s = make_variable(label_s, requires_grad=False)
            im_t = make_variable(im_t, requires_grad=False)
            label_t = make_variable(label_t, requires_grad=False)

            #############################
            # 2. Optimize Discriminator #
            #############################

            # zero gradients for optimizer
            opt_dis.zero_grad()
            opt_rep.zero_grad()

            # extract features
            if discrim_feat:
                score_s, feat_s = net_src(im_s)
                score_s = Variable(score_s.data, requires_grad=False)
                f_s = Variable(feat_s.data, requires_grad=False)
            else:
                score_s = Variable(net_src(im_s).data, requires_grad=False)
                f_s = score_s

            dis_score_s = discriminator(f_s)

            if discrim_feat:
                score_t, feat_t = net(im_t)
                score_t = Variable(score_t.data, requires_grad=False)
                f_t = Variable(feat_t.data, requires_grad=False)
            else:
                score_t = Variable(net(im_t).data, requires_grad=False)
                f_t = score_t
            dis_score_t = discriminator(f_t)

            dis_pred_concat = torch.cat((dis_score_s, dis_score_t))

            # prepare real and fake labels
            batch_t, _, h, w = dis_score_t.size()
            batch_s, _, _, _ = dis_score_s.size()
            dis_label_concat = make_variable(torch.cat([
                torch.ones(batch_s, h, w).long(),
                torch.zeros(batch_t, h, w).long()
            ]),
                                             requires_grad=False)

            # compute loss for discriminator
            loss_dis = supervised_loss(dis_pred_concat, dis_label_concat)
            (lambda_d * loss_dis).backward()
            losses_dis.append(loss_dis.item())

            # optimize discriminator
            opt_dis.step()

            # compute discriminator acc
            pred_dis = torch.squeeze(dis_pred_concat.max(1)[1])
            dom_acc = (pred_dis == dis_label_concat).float().mean().item()
            accuracies_dom.append(dom_acc * 100.)

            # add discriminator info to log
            info_str += " domacc:{:0.1f}  D:{:.3f}".format(
                np.mean(accuracies_dom), np.mean(losses_dis))
            writer.add_scalar('loss/discriminator', np.mean(losses_dis),
                              iteration)
            writer.add_scalar('acc/discriminator', np.mean(accuracies_dom),
                              iteration)

            ###########################
            # Optimize Target Network #
            ###########################
            # (G updates below are gated on np.mean(accuracies_dom) > dom_acc_thresh)

            dom_acc_thresh = 60

            if train_discrim_only and np.mean(accuracies_dom) > dom_acc_thresh:
                os.makedirs(output, exist_ok=True)
                torch.save(
                    discriminator.module.state_dict(),
                    '{}/discriminator_abv60.pth'.format(output))
                break

            if not train_discrim_only and np.mean(
                    accuracies_dom) > dom_acc_thresh:

                last_update_g = iteration
                num_update_g += 1
                if num_update_g % 1 == 0:
                    print(
                        'Updating G with adversarial loss ({:d} times)'.format(
                            num_update_g))

                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features
                if discrim_feat:
                    score_t, feat_t = net(im_t)
                    score_t = Variable(score_t.data, requires_grad=False)
                    f_t = feat_t
                else:
                    score_t = net(im_t)
                    f_t = score_t

                # score_t = net(im_t)
                dis_score_t = discriminator(f_t)

                # create fake label
                batch, _, h, w = dis_score_t.size()
                target_dom_fake_t = make_variable(torch.ones(batch, h,
                                                             w).long(),
                                                  requires_grad=False)

                # compute loss for target net
                loss_gan_t = supervised_loss(dis_score_t, target_dom_fake_t)
                (lambda_g * loss_gan_t).backward()
                losses_rep.append(loss_gan_t.item())
                writer.add_scalar('loss/generator', np.mean(losses_rep),
                                  iteration)

                # optimize target net
                opt_rep.step()

                # log net update info
                info_str += ' G:{:.3f}'.format(np.mean(losses_rep))

            if (not train_discrim_only) and weights_shared and np.mean(
                    accuracies_dom) > dom_acc_thresh:
                print('Updating G using source supervised loss.')
                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features
                if discrim_feat:
                    score_s, feat_s = net(im_s)
                else:
                    score_s = net(im_s)

                loss_supervised_s = supervised_loss(score_s,
                                                    label_s,
                                                    weights=weights)

                if with_mmd_loss:
                    print("Updating G using discrepancy loss")
                    lambda_discrepancy = 0.1
                    loss_mmd = mmd_loss(feat_s, feat_t) * 0.5 + mmd_loss(
                        score_s, score_t) * 0.5
                    loss_supervised_s += lambda_discrepancy * loss_mmd

                loss_supervised_s.backward()
                losses_super_s.append(loss_supervised_s.item())
                info_str += ' clsS:{:.2f}'.format(np.mean(losses_super_s))
                writer.add_scalar('loss/supervised/source',
                                  np.mean(losses_super_s), iteration)

                # optimize target net
                opt_rep.step()

            # compute supervised loss for target -- monitoring only, no backward()
            loss_supervised_t = supervised_loss(score_t,
                                                label_t,
                                                weights=weights)
            losses_super_t.append(loss_supervised_t.item())
            info_str += ' clsT:{:.2f}'.format(np.mean(losses_super_t))
            writer.add_scalar('loss/supervised/target',
                              np.mean(losses_super_t), iteration)

            ###########################
            # Log and compute metrics #
            ###########################
            if iteration % 10 == 0 and iteration > 0:

                # compute metrics
                intersection, union, acc = seg_accuracy(
                    score_t, label_t.data, num_cls)
                intersections = np.vstack(
                    [intersections[1:, :], intersection[np.newaxis, :]])
                unions = np.vstack([unions[1:, :], union[np.newaxis, :]])
                accuracy.append(acc.item() * 100)
                acc = np.mean(accuracy)
                mIoU = np.mean(
                    np.maximum(intersections, 1) / np.maximum(unions, 1)) * 100

                iu = (intersection / union) * 10000
                iu_deque.append(np.nanmean(iu))

                info_str += ' acc:{:0.2f}  mIoU:{:0.2f}'.format(
                    acc, np.mean(iu_deque))
                writer.add_scalar('metrics/acc', np.mean(accuracy), iteration)
                writer.add_scalar('metrics/mIoU', np.mean(mIoU), iteration)
                logging.info(info_str)

            iteration += 1

            ################
            # Save outputs #
            ################

            # every 500 iters save current model
            if iteration % 500 == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.module.state_dict(),
                               '{}/net-itercurr.pth'.format(output))
                torch.save(discriminator.module.state_dict(),
                           '{}/discriminator-itercurr.pth'.format(output))

            # save labeled snapshots
            if iteration % snapshot == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.module.state_dict(),
                               '{}/net-iter{}.pth'.format(output, iteration))
                torch.save(
                    discriminator.module.state_dict(),
                    '{}/discriminator-iter{}.pth'.format(output, iteration))

            if iteration - last_update_g >= 3 * len(loader):
                print('No suitable discriminator found -- returning.')
                torch.save(net.module.state_dict(),
                           '{}/net-iter{}.pth'.format(output, iteration))
                iteration = max_iter  # make sure outside loop breaks
                break

    writer.close()
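
# For reference, a common way to turn running intersection/union buffers like the ones
# kept in main() above into per-class IoU and mean IoU. This is a generic sketch, not
# the exact metric the script logs:
import numpy as np

def mean_iou(intersections, unions, eps=1e-10):
    # intersections, unions: arrays of shape [history, num_cls]
    per_class = intersections.sum(axis=0) / np.maximum(unions.sum(axis=0), eps)
    return per_class, float(np.nanmean(per_class) * 100.0)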
# Example 7
def main(output, dataset, datadir, batch_size, lr, step, iterations, momentum,
         snapshot, downscale, augmentation, fyu, crop_size, weights, model,
         gpu, num_cls):
    if weights is not None:
        raise RuntimeError("weights don't work because eric is bad at coding")
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()

    logdir = 'runs/{:s}/{:s}'.format(model, '-'.join(dataset))
    writer = SummaryWriter(log_dir=logdir)
    net = get_model(model, num_cls=num_cls, finetune=True)
    net.cuda()
    transform = []
    target_transform = []
    if downscale is not None:
        transform.append(torchvision.transforms.Scale(1024 // downscale))
        target_transform.append(
            torchvision.transforms.Scale(1024 // downscale,
                                         interpolation=Image.NEAREST))
    transform.extend([torchvision.transforms.Scale(1024), net.transform])
    target_transform.extend([
        torchvision.transforms.Scale(1024, interpolation=Image.NEAREST),
        to_tensor_raw
    ])
    transform = torchvision.transforms.Compose(transform)
    target_transform = torchvision.transforms.Compose(target_transform)

    datasets = [
        get_dataset(name,
                    os.path.join(datadir, name),
                    transform=transform,
                    target_transform=target_transform) for name in dataset
    ]

    if weights is not None:
        weights = np.loadtxt(weights)
    opt = torch.optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=0.0005)

    if augmentation:
        collate_fn = lambda batch: augment_collate(
            batch, crop=crop_size, flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate
    print(datasets)
    loaders = [
        torch.utils.data.DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=2,
                                    collate_fn=collate_fn,
                                    pin_memory=True) for dataset in datasets
    ]
    iteration = 0
    losses = deque(maxlen=10)
    for im, label in roundrobin_infinite(*loaders):
        # Clear out gradients
        opt.zero_grad()

        # load data/label
        im = make_variable(im, requires_grad=False)
        label = make_variable(label, requires_grad=False)

        # forward pass and compute loss
        preds = net(im)
        loss = supervised_loss(preds, label)

        # backward pass
        loss.backward()
        losses.append(loss.item())

        # step gradients
        opt.step()

        # log results
        if iteration % 10 == 0:
            logging.info('Iteration {}:\t{}'.format(iteration,
                                                    np.mean(losses)))
            writer.add_scalar('loss', np.mean(losses), iteration)
        iteration += 1
        if step is not None and iteration % step == 0:
            logging.info('Decreasing learning rate by 0.1.')
            step_lr(opt, 0.1)
        if iteration % snapshot == 0:
            torch.save(net.state_dict(),
                       '{}-iter{}.pth'.format(output, iteration))
        if iteration >= iterations:
            logging.info('Optimization complete.')
            break
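
# step_lr is not defined in this snippet. A minimal sketch of the behaviour the call
# site above suggests (scale every parameter group's learning rate by a factor):
def step_lr(optimizer, gamma):
    for group in optimizer.param_groups:
        group['lr'] *= gamma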
# Example 8
def main(output, dataset, datadir, lr, momentum, snapshot, downscale,
         cls_weights, gpu, weights_init, num_cls, lsgan, max_iter, lambda_d,
         lambda_g, train_discrim_only, weights_discrim, crop_size,
         weights_shared, discrim_feat, half_crop, batch, model):

    # So data is sampled in consistent way
    np.random.seed(1337)
    torch.manual_seed(1337)
    logdir = 'runs/{:s}/{:s}_to_{:s}/lr{:.1g}_ld{:.2g}_lg{:.2g}'.format(
        model, dataset[0], dataset[1], lr, lambda_d, lambda_g)
    if weights_shared:
        logdir += '_weightshared'
    else:
        logdir += '_weightsunshared'
    if discrim_feat:
        logdir += '_discrimfeat'
    else:
        logdir += '_discrimscore'
    logdir += '/' + datetime.now().strftime('%Y_%b_%d-%H:%M')
    writer = SummaryWriter(log_dir=logdir)

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    print(model)
    net = get_model(model,
                    num_cls=num_cls,
                    pretrained=True,
                    weights_init=weights_init,
                    output_last_ft=discrim_feat)

    loader = AddaDataLoader(net.transform,
                            dataset,
                            datadir,
                            downscale,
                            crop_size=crop_size,
                            half_crop=half_crop,
                            batch_size=batch,
                            shuffle=True,
                            num_workers=2)
    print('dataset', dataset)

    # Class weighted loss?
    weights = None

    # setup optimizers
    opt_rep = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=0.0005)

    iteration = 0
    num_update_g = 0
    last_update_g = -1
    losses_super_s = deque(maxlen=100)
    losses_super_t = deque(maxlen=100)
    intersections = np.zeros([100, num_cls])
    unions = np.zeros([100, num_cls])
    accuracy = deque(maxlen=100)
    print('max iter:', max_iter)

    net.train()
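    # Frozen CPN pseudo-label refiner: loaded once, kept in eval mode, gradients disabled.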
    cpn = UNet_CPN()
    cpn.load_state_dict(torch.load('CPN_cycada.pth'))
    cpn.cuda()
    cpn.eval()
    for param in cpn.parameters():
        param.requires_grad = False

    while iteration < max_iter:
        for im_s, im_t, label_s, label_t in loader:

            if iteration > max_iter:
                break
            info_str = 'Iteration {}: '.format(iteration)

            if not check_label(label_s, num_cls):
                continue

            ###########################
            # 1. Setup Data Variables #
            ###########################
            im_s = make_variable(im_s, requires_grad=False)
            label_s = make_variable(label_s, requires_grad=False)
            im_t = make_variable(im_t, requires_grad=False)
            label_t = make_variable(label_t, requires_grad=False)

            #############################
            # 2. Optimize Discriminator #
            #############################

            # zero gradients for optimizer
            opt_rep.zero_grad()
            score_s = net(im_s)

            # Refine the source prediction into a pseudo-label with the frozen CPN,
            # then combine the ground-truth loss with a down-weighted pseudo-label loss.
            _, fake_label_s = torch.max(score_s, 1)
            _, fake_label_s = torch.max(cpn(im_s, one_hot(fake_label_s, 19)),
                                        1)
            loss_supervised_s = supervised_loss(
                score_s, label_s, weights=weights) + 0.5 * supervised_loss(
                    score_s, fake_label_s, weights=weights)

            score_t = net(im_t)

            _, fake_label_t = torch.max(score_t, 1)
            _, fake_label_t = torch.max(cpn(im_t, one_hot(fake_label_t, 19)),
                                        1)
            loss_supervised_t = supervised_loss(score_t,
                                                fake_label_t,
                                                weights=weights)

            loss = loss_supervised_s + loss_supervised_t

            losses_super_t.append(loss_supervised_t.item())
            info_str += ' clsT:{:.2f}'.format(np.mean(losses_super_t))
            writer.add_scalar('loss/supervised/target',
                              np.mean(losses_super_t), iteration)

            losses_super_s.append(loss_supervised_s.item())
            info_str += ' clsS:{:.2f}'.format(np.mean(losses_super_s))
            writer.add_scalar('loss/supervised/source',
                              np.mean(losses_super_s), iteration)

            loss.backward()
            # optimize target net
            opt_rep.step()

            ###########################
            # Log and compute metrics #
            ###########################
            if iteration % 1 == 0 and iteration > 0:
                # compute metrics
                intersection, union, acc = seg_accuracy(
                    score_t, label_t.data, num_cls)
                intersections = np.vstack(
                    [intersections[1:, :], intersection[np.newaxis, :]])
                unions = np.vstack([unions[1:, :], union[np.newaxis, :]])
                accuracy.append(acc.item() * 100)
                acc = np.mean(accuracy)
                mIoU = np.mean(
                    np.maximum(intersections, 1) / np.maximum(unions, 1)) * 100

                info_str += ' acc:{:0.2f}  mIoU:{:0.2f}'.format(acc, mIoU)
                writer.add_scalar('metrics/acc', np.mean(accuracy), iteration)
                writer.add_scalar('metrics/mIoU', np.mean(mIoU), iteration)
                logging.info(info_str)
            iteration += 1

            ################
            # Save outputs #
            ################

            # every 500 iters save current model
            if iteration % 500 == 0:
                os.makedirs(output, exist_ok=True)
                torch.save(net.state_dict(),
                           '{}/net-itercurr.pth'.format(output))

    writer.close()
Esempio n. 9
def main(output, dataset, datadir, lr, momentum, snapshot, downscale,
         cls_weights, weights_init, num_cls, lsgan, max_iter, lambda_d,
         lambda_g, train_discrim_only, weights_discrim, crop_size,
         weights_shared, discrim_feat, half_crop, batch, model, targetsup):

    targetSup = targetsup
    # So data is sampled in consistent way
    np.random.seed(1337)
    torch.manual_seed(1337)
    logdir = 'runs/{:s}/{:s}_to_{:s}/lr{:.1g}_ld{:.2g}_lg{:.2g}'.format(
        model, dataset[0], dataset[1], lr, lambda_d, lambda_g)
    if weights_shared:
        logdir += '_weightshared'
    else:
        logdir += '_weightsunshared'
    if discrim_feat:
        logdir += '_discrimfeat'
    else:
        logdir += '_discrimscore'
    logdir += '/' + datetime.now().strftime('%Y_%b_%d-%H:%M')
    writer = SummaryWriter(logdir)

    config_logging()
    print('Train Discrim Only', train_discrim_only)
    net = get_model(model, num_cls=num_cls, output_last_ft=discrim_feat)
    net.load_state_dict(torch.load(weights_init))
    if weights_shared:
        net_src = net  # shared weights
    else:
        net_src = get_model(model,
                            num_cls=num_cls,
                            output_last_ft=discrim_feat)
        net_src.load_state_dict(torch.load(weights_init))
        net_src.eval()

    print("GOT MODEL")

    # Discriminator head: one regression output for LSGAN, two logits otherwise;
    # its input is either the last feature map (4096 channels) or the per-class score map.
    odim = 1 if lsgan else 2
    idim = num_cls if not discrim_feat else 4096
    print('discrim_feat', discrim_feat, idim)
    print('discriminator init weights: ', weights_discrim)

    discriminator = Discriminator(input_dim=idim,
                                  output_dim=odim,
                                  pretrained=weights_discrim is not None,
                                  weights_init=weights_discrim)
    if torch.cuda.is_available():
        discriminator = discriminator.cuda()

    loader = AddaDataLoader(None,
                            dataset,
                            datadir,
                            downscale,
                            crop_size=crop_size,
                            half_crop=half_crop,
                            batch_size=batch,
                            shuffle=True,
                            num_workers=2)
    print('dataset', dataset)

    # Class weighted loss?
    if cls_weights is not None:
        weights = np.loadtxt(cls_weights)
    else:
        weights = None

    # setup optimizers
    opt_dis = torch.optim.SGD(discriminator.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=0.0005)
    opt_rep = torch.optim.SGD(net.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=0.0005)

    iteration = 0
    num_update_g = 0
    last_update_g = -1
    losses_super_s = deque(maxlen=100)
    losses_super_t = deque(maxlen=100)
    losses_dis = deque(maxlen=100)
    losses_rep = deque(maxlen=100)
    accuracies_dom = deque(maxlen=100)
    intersections = np.zeros([100, num_cls])
    unions = np.zeros([100, num_cls])
    accuracy = deque(maxlen=100)
    print('max iter:', max_iter)

    net.train()
    discriminator.train()
    IoU_s = deque(maxlen=100)
    IoU_t = deque(maxlen=100)

    Recall_s = deque(maxlen=100)
    Recall_t = deque(maxlen=100)

    while iteration < max_iter:

        for im_s, im_t, label_s, label_t in loader:

            if iteration > max_iter:
                break

            info_str = 'Iteration {}: '.format(iteration)

            if not check_label(label_s, num_cls):
                continue

            ###########################
            # 1. Setup Data Variables #
            ###########################
            im_s = make_variable(im_s, requires_grad=False)
            label_s = make_variable(label_s, requires_grad=False)
            im_t = make_variable(im_t, requires_grad=False)
            label_t = make_variable(label_t, requires_grad=False)

            #############################
            # 2. Optimize Discriminator #
            #############################

            # zero gradients for optimizer
            opt_dis.zero_grad()
            opt_rep.zero_grad()

            # extract features
            if discrim_feat:
                score_s, feat_s = net_src(im_s)
                score_s = Variable(score_s.data, requires_grad=False)
                f_s = Variable(feat_s.data, requires_grad=False)
            else:
                score_s = Variable(net_src(im_s).data, requires_grad=False)
                f_s = score_s
            dis_score_s = discriminator(f_s)

            if discrim_feat:
                score_t, feat_t = net(im_t)
                score_t = Variable(score_t.data, requires_grad=False)
                f_t = Variable(feat_t.data, requires_grad=False)
            else:
                score_t = Variable(net(im_t).data, requires_grad=False)
                f_t = score_t
            dis_score_t = discriminator(f_t)

            dis_pred_concat = torch.cat((dis_score_s, dis_score_t))

            # prepare domain labels: source pixels -> 1 ("real"), target pixels -> 0 ("fake")
            batch_t, _, h, w = dis_score_t.size()
            batch_s, _, _, _ = dis_score_s.size()
            dis_label_concat = make_variable(torch.cat([
                torch.ones(batch_s, h, w).long(),
                torch.zeros(batch_t, h, w).long()
            ]),
                                             requires_grad=False)

            # compute loss for discriminator
            loss_dis = supervised_loss(dis_pred_concat, dis_label_concat)
            (lambda_d * loss_dis).backward()
            losses_dis.append(loss_dis.item())

            # optimize discriminator
            opt_dis.step()

            # compute discriminator acc
            pred_dis = torch.squeeze(dis_pred_concat.max(1)[1])
            dom_acc = (pred_dis == dis_label_concat).float().mean().item()
            accuracies_dom.append(dom_acc * 100.)

            # add discriminator info to log
            info_str += " domacc:{:0.1f}  D:{:.3f}".format(
                np.mean(accuracies_dom), np.mean(losses_dis))
            writer.add_scalar('loss/discriminator', np.mean(losses_dis),
                              iteration)
            writer.add_scalar('acc/discriminator', np.mean(accuracies_dom),
                              iteration)

            ###########################
            # Optimize Target Network #
            ###########################

            dom_acc_thresh = 55
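            # Gate: the target network (G) is only updated once the domain
            # discriminator separates source from target above this accuracy.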

            if not train_discrim_only and np.mean(
                    accuracies_dom) > dom_acc_thresh:

                last_update_g = iteration
                num_update_g += 1
                if num_update_g % 1 == 0:
                    print(
                        'Updating G with adversarial loss ({:d} times)'.format(
                            num_update_g))

                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features
                if discrim_feat:
                    score_t, feat_t = net(im_t)
                    score_t = Variable(score_t.data, requires_grad=False)
                    f_t = feat_t
                else:
                    score_t = net(im_t)
                    f_t = score_t

                #score_t = net(im_t)
                dis_score_t = discriminator(f_t)

                # create fake label
                batch, _, h, w = dis_score_t.size()
                target_dom_fake_t = make_variable(torch.ones(batch, h,
                                                             w).long(),
                                                  requires_grad=False)

                # compute loss for target net
                loss_gan_t = supervised_loss(dis_score_t, target_dom_fake_t)
                (lambda_g * loss_gan_t).backward()
                losses_rep.append(loss_gan_t.item())
                writer.add_scalar('loss/generator', np.mean(losses_rep),
                                  iteration)

                # optimize target net
                opt_rep.step()

                # log net update info
                info_str += ' G:{:.3f}'.format(np.mean(losses_rep))

            if (not train_discrim_only) and weights_shared and (
                    np.mean(accuracies_dom) > dom_acc_thresh):

                print('Updating G using source supervised loss.')

                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features
                if discrim_feat:
                    score_s, _ = net(im_s)
                    score_t, _ = net(im_t)
                else:
                    score_s = net(im_s)
                    score_t = net(im_t)

                loss_supervised_s = supervised_loss(score_s,
                                                    label_s,
                                                    weights=weights)
                loss_supervised_t = supervised_loss(score_t,
                                                    label_t,
                                                    weights=weights)
                loss_supervised = loss_supervised_s

                if targetSup:
                    loss_supervised += loss_supervised_t

                loss_supervised.backward()

                losses_super_s.append(loss_supervised_s.item())
                info_str += ' clsS:{:.2f}'.format(np.mean(losses_super_s))
                writer.add_scalar('loss/supervised/source',
                                  np.mean(losses_super_s), iteration)

                losses_super_t.append(loss_supervised_t.item())
                info_str += ' clsT:{:.2f}'.format(np.mean(losses_super_t))
                writer.add_scalar('loss/supervised/target',
                                  np.mean(losses_super_t), iteration)

                # optimize target net
                opt_rep.step()

            ###########################
            # Log and compute metrics #
            ###########################
            if iteration % 10 == 0 and iteration > 0:

                # compute metrics
                intersection, union, acc = seg_accuracy(
                    score_t, label_t.data, num_cls)
                iou_s = IoU(score_s, label_s)
                iou_t = IoU(score_t, label_t)
                rc_s = recall(score_s, label_s)
                rc_t = recall(score_t, label_t)
                IoU_s.append(iou_s.item())
                IoU_t.append(iou_t.item())
                Recall_s.append(rc_s.item())
                Recall_t.append(rc_t.item())
                intersections = np.vstack(
                    [intersections[1:, :], intersection[np.newaxis, :]])
                unions = np.vstack([unions[1:, :], union[np.newaxis, :]])
                accuracy.append(acc.item() * 100)
                acc = np.mean(accuracy)
                mIoU = np.mean(
                    np.maximum(intersections, 1) / np.maximum(unions, 1)) * 100

                info_str += ' IoU:{:0.2f}  Recall:{:0.2f}'.format(iou_s, rc_s)
                # writer.add_scalar('metrics/acc', np.mean(accuracy), iteration)
                # writer.add_scalar('metrics/mIoU', np.mean(mIoU), iteration)
                # writer.add_scalar('metrics/RealIoU_Source', np.mean(IoU_s))
                # writer.add_scalar('metrics/RealIoU_Target', np.mean(IoU_t))
                # writer.add_scalar('metrics/RealRecall_Source', np.mean(Recall_s))
                # writer.add_scalar('metrics/RealRecall_Target', np.mean(Recall_t))
                logging.info(info_str)
                print(info_str)

                im_s = Image.fromarray(
                    np.uint8(
                        norm(im_s[0]).permute(1, 2, 0).cpu().data.numpy() *
                        255))
                im_t = Image.fromarray(
                    np.uint8(
                        norm(im_t[0]).permute(1, 2, 0).cpu().data.numpy() *
                        255))
                label_s = Image.fromarray(
                    np.uint8(label_s[0].cpu().data.numpy() * 255))
                label_t = Image.fromarray(
                    np.uint8(label_t[0].cpu().data.numpy() * 255))
                score_s = Image.fromarray(
                    np.uint8(mxAxis(score_s[0]).cpu().data.numpy() * 255))
                score_t = Image.fromarray(
                    np.uint8(mxAxis(score_t[0]).cpu().data.numpy() * 255))

                im_s.save(output + "/im_s.png")
                im_t.save(output + "/im_t.png")
                label_s.save(output + "/label_s.png")
                label_t.save(output + "/label_t.png")
                score_s.save(output + "/score_s.png")
                score_t.save(output + "/score_t.png")

            iteration += 1

            ################
            # Save outputs #
            ################

            # every 500 iters save current model
            if iteration % 500 == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.state_dict(),
                               '{}/net-itercurr.pth'.format(output))
                torch.save(discriminator.state_dict(),
                           '{}/discriminator-itercurr.pth'.format(output))

            # save labeled snapshots
            if iteration % snapshot == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.state_dict(),
                               '{}/net-iter{}.pth'.format(output, iteration))
                torch.save(
                    discriminator.state_dict(),
                    '{}/discriminator-iter{}.pth'.format(output, iteration))

            if iteration - last_update_g >= len(loader):
                print('No generator update for a full epoch -- '
                      'the discriminator may have become too strong.')
                # import pdb;pdb.set_trace()
                # torch.save(net.state_dict(),'{}/net-iter{}.pth'.format(output, iteration))
                # iteration = max_iter # make sure outside loop breaks
                # break

    writer.close()
Esempio n. 10
def main(output, dataset, datadir, batch_size, lr, step, iterations, momentum,
         snapshot, downscale, augmentation, fyu, crop_size, weights, model,
         gpu, num_cls, nthreads, model_weights, data_flag, serial_batches,
         resize_to, start_step, preprocessing, small, rundir_flag, force_split,
         adam):
    if weights is not None:
        raise RuntimeError('class weights are not supported by this training script')
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    logdir_flag = data_flag
    if rundir_flag != "":
        logdir_flag += "_{}".format(rundir_flag)

    logdir = 'runs/{:s}/{:s}/{:s}'.format(model, '-'.join(dataset),
                                          logdir_flag)
    writer = SummaryWriter(log_dir=logdir)
    if model == 'fcn8s':
        net = get_model(model, num_cls=num_cls, weights_init=model_weights)
    else:
        net = get_model(model,
                        num_cls=num_cls,
                        finetune=True,
                        weights_init=model_weights)
    net.cuda()

    str_ids = gpu.split(',')
    gpu_ids = []
    for str_id in str_ids:
        gpu_id = int(str_id)
        if gpu_id >= 0:
            gpu_ids.append(gpu_id)

    # set gpu ids
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    transform = []
    target_transform = []

    if preprocessing:
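        # Resize images (bilinear) and labels (nearest-neighbor, so class ids are never
        # blended) to roughly resize_to x 1.8 * resize_to before the network transform.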
        transform.extend([
            torchvision.transforms.Resize(
                [int(resize_to), int(int(resize_to) * 1.8)])
        ])
        target_transform.extend([
            torchvision.transforms.Resize(
                [int(resize_to), int(int(resize_to) * 1.8)],
                interpolation=Image.NEAREST)
        ])

    transform.extend([net.module.transform])
    target_transform.extend([to_tensor_raw])
    transform = torchvision.transforms.Compose(transform)
    target_transform = torchvision.transforms.Compose(target_transform)

    if force_split:
        datasets = []
        datasets.append(
            get_dataset(dataset[0],
                        os.path.join(datadir, dataset[0]),
                        num_cls=num_cls,
                        transform=transform,
                        target_transform=target_transform,
                        data_flag=data_flag))
        datasets.append(
            get_dataset(dataset[1],
                        os.path.join(datadir, dataset[1]),
                        num_cls=num_cls,
                        transform=transform,
                        target_transform=target_transform))
    else:
        datasets = [
            get_dataset(name,
                        os.path.join(datadir, name),
                        num_cls=num_cls,
                        transform=transform,
                        target_transform=target_transform,
                        data_flag=data_flag) for name in dataset
        ]

    if weights is not None:
        weights = np.loadtxt(weights)

    if adam:
        print("Using Adam")
        opt = torch.optim.Adam(net.module.parameters(), lr=1e-4)
    else:
        print("Using SGD")
        opt = torch.optim.SGD(net.module.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=0.0005)

    if augmentation:
        collate_fn = lambda batch: augment_collate(
            batch, crop=crop_size, flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate

    loaders = [
        torch.utils.data.DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=not serial_batches,
                                    num_workers=nthreads,
                                    collate_fn=collate_fn,
                                    pin_memory=True) for dataset in datasets
    ]
    iteration = start_step
    losses = deque(maxlen=10)

    # Sanity-check one sample from each dataset in debug mode before training starts.
    for loader in loaders:
        loader.dataset.__getitem__(0, debug=True)

    for im, label in roundrobin_infinite(*loaders):
        # Clear out gradients
        opt.zero_grad()

        # load data/label
        im = make_variable(im, requires_grad=False)
        label = make_variable(label, requires_grad=False)

        if iteration == 0:
            print("im size: {}".format(im.size()))
            print("label size: {}".format(label.size()))

        # forward pass and compute loss
        preds = net(im)
        loss = supervised_loss(preds, label)

        # backward pass
        loss.backward()
        losses.append(loss.item())

        # step gradients
        opt.step()

        # log results
        if iteration % 10 == 0:
            logging.info('Iteration {}:\t{}'.format(iteration,
                                                    np.mean(losses)))
            writer.add_scalar('loss', np.mean(losses), iteration)
        iteration += 1
        if step is not None and iteration % step == 0:
            logging.info('Decreasing learning rate by 0.1.')
            step_lr(opt, 0.1)

        if iteration % snapshot == 0:
            torch.save(net.module.state_dict(),
                       '{}/iter_{}.pth'.format(output, iteration))

        if iteration >= iterations:
            logging.info('Optimization complete.')
            break
Esempio n. 11
def main(config_path):
    config = None

    config_file = config_path.split('/')[-1]
    version = config_file.split('.')[0][1:]

    with open(config_path, 'r') as f:
        config = json.load(f)

    config["version"] = version
    config_logging()

    # Initialize SummaryWriter - For tensorboard visualizations
    logdir = 'runs/{:s}/{:s}/{:s}/{:s}'.format(config["model"],
                                               config["dataset"],
                                               'v{}'.format(config["version"]),
                                               'tflogs')
    logdir = logdir + "/"

    checkpointdir = join('runs', config["model"], config["dataset"],
                         'v{}'.format(config["version"]), 'checkpoints')

    print("Logging directory: {}".format(logdir))
    print("Checkpoint directory: {}".format(checkpointdir))

    versionpath = join('runs', config["model"], config["dataset"],
                       'v{}'.format(config["version"]))

    if not exists(versionpath):
        os.makedirs(versionpath)
        os.makedirs(checkpointdir)
        os.makedirs(logdir)
    elif exists(versionpath) and config["force"]:
        shutil.rmtree(versionpath)
        os.makedirs(versionpath)
        os.makedirs(checkpointdir)
        os.makedirs(logdir)
    else:
        print(
            "Version {} already exists! Please run with different version number"
            .format(config["version"]))
        logging.info(
            "Version {} already exists! Please run with different version number"
            .format(config["version"]))
        sys.exit(-1)

    writer = SummaryWriter(logdir)
    # Get appropriate model based on config parameters
    net = get_model(config["model"], num_cls=config["num_cls"])
    if args.load:
        net.load_state_dict(torch.load(args.load))
        print("============ Loading Model ===============")

    model_parameters = filter(lambda p: p.requires_grad, net.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])

    dataset = config["dataset"]
    num_workers = config["num_workers"]
    pin_memory = config["pin_memory"]
    dataset = dataset[0]

    datasets_train = get_fcn_dataset(config["dataset"],
                                     config["data_type"],
                                     join(config["datadir"],
                                          config["dataset"]),
                                     split='train')
    datasets_val = get_fcn_dataset(config["dataset"],
                                   config["data_type"],
                                   join(config["datadir"], config["dataset"]),
                                   split='val')
    datasets_test = get_fcn_dataset(config["dataset"],
                                    config["data_type"],
                                    join(config["datadir"], config["dataset"]),
                                    split='test')

    if config["weights"] is not None:
        weights = np.loadtxt(config["weights"])
    opt = torch.optim.SGD(net.parameters(),
                          lr=config["lr"],
                          momentum=config["momentum"],
                          weight_decay=0.0005)

    if config["augmentation"]:
        collate_fn = lambda batch: augment_collate(
            batch, crop=config["crop_size"], flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate

    train_loader = torch.utils.data.DataLoader(datasets_train,
                                               batch_size=config["batch_size"],
                                               shuffle=True,
                                               num_workers=num_workers,
                                               collate_fn=collate_fn,
                                               pin_memory=pin_memory)

    val_loader = torch.utils.data.DataLoader(datasets_val,
                                             batch_size=config["batch_size"],
                                             shuffle=True,
                                             num_workers=num_workers,
                                             collate_fn=collate_fn,
                                             pin_memory=pin_memory)

    test_loader = torch.utils.data.DataLoader(datasets_test,
                                              batch_size=config["batch_size"],
                                              shuffle=False,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn,
                                              pin_memory=pin_memory)

    data_metric = {'train': None, 'val': None, 'test': None}
    Q_size = len(train_loader) / config["batch_size"]

    metrics = {'losses': list(), 'ious': list(), 'recalls': list()}

    # Give each split its own lists; a shallow copy would share the same list
    # objects across train/val/test and mix their metrics in the first epoch.
    data_metric['train'] = {key: list() for key in metrics}
    data_metric['val'] = {key: list() for key in metrics}
    data_metric['test'] = {key: list() for key in metrics}
    num_cls = config["num_cls"]
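    # confusion-matrix accumulator (only used by the commented-out fast_hist metrics below)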
    hist = np.zeros((num_cls, num_cls))
    iteration = 0

    for epoch in range(config["num_epochs"] + 1):
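        # Each epoch: one full training pass, then val/test passes at their configured intervals.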
        if config["phase"] == 'train':
            net.train()
            iterator = tqdm(iter(train_loader))

            # Epoch train
            print("Train Epoch!")
            for im, label in iterator:
                if torch.isnan(im).any() or torch.isnan(label).any():
                    import pdb
                    pdb.set_trace()
                iteration += 1
                # Clear out gradients
                opt.zero_grad()
                # load data/label
                im = make_variable(im, requires_grad=False)
                label = make_variable(label, requires_grad=False)
                #print(im.size())

                # forward pass and compute loss
                preds = net(im)
                #score = preds.data
                #_, pred = torch.max(score, 1)

                #hist += fast_hist(label.cpu().numpy().flatten(), pred.cpu().numpy().flatten(),num_cls)

                #acc_overall, acc_percls, iu, fwIU = result_stats(hist)
                loss = supervised_loss(preds, label)
                # iou = jaccard_score(preds, label)
                precision, rc, fscore, support, iou = sklearnScores(
                    preds, label.type(torch.IntTensor))
                #print(acc_overall, np.nanmean(acc_percls), np.nanmean(iu), fwIU)
                # backward pass
                loss.backward()

                # TODO: Right now this is running average, ideally we want true average. Make that change
                # Total average will be memory intensive, let it be running average for the moment.
                data_metric['train']['losses'].append(loss.item())
                data_metric['train']['ious'].append(iou)
                data_metric['train']['recalls'].append(rc)
                # step gradients
                opt.step()

                # Train visualizations - each iteration
                if iteration % config["train_tf_interval"] == 0:
                    vizz = preprocess_viz(im, preds, label)
                    writer.add_scalar('train/loss', loss, iteration)
                    writer.add_scalar('train/IOU', iou, iteration)
                    writer.add_scalar('train/recall', rc, iteration)
                    imutil = vutils.make_grid(torch.from_numpy(vizz),
                                              nrow=3,
                                              normalize=True,
                                              scale_each=True)
                    writer.add_image('{}_image_data'.format('train'), imutil,
                                     iteration)

                iterator.set_description("TRAIN V: {} | Epoch: {}".format(
                    config["version"], epoch))
                iterator.refresh()

                if iteration % 20000 == 0:
                    torch.save(
                        net.state_dict(),
                        join(checkpointdir,
                             'iter_{}_{}.pth'.format(iteration, epoch)))

            # clean before test/val
            opt.zero_grad()

            # Train visualizations - per epoch
            vizz = preprocess_viz(im, preds, label)
            writer.add_scalar('trainepoch/loss',
                              np.mean(data_metric['train']['losses']),
                              global_step=epoch)
            writer.add_scalar('trainepoch/IOU',
                              np.mean(data_metric['train']['ious']),
                              global_step=epoch)
            writer.add_scalar('trainepoch/recall',
                              np.mean(data_metric['train']['recalls']),
                              global_step=epoch)
            imutil = vutils.make_grid(torch.from_numpy(vizz),
                                      nrow=3,
                                      normalize=True,
                                      scale_each=True)
            writer.add_image('{}_image_data'.format('trainepoch'),
                             imutil,
                             global_step=epoch)

            print("Loss :{}".format(np.mean(data_metric['train']['losses'])))
            print("IOU :{}".format(np.mean(data_metric['train']['ious'])))
            print("recall :{}".format(np.mean(
                data_metric['train']['recalls'])))

            if epoch % config["checkpoint_interval"] == 0:
                torch.save(net.state_dict(),
                           join(checkpointdir, 'iter{}.pth'.format(epoch)))

            # Train epoch done. Free up lists
            for key in data_metric['train'].keys():
                data_metric['train'][key] = list()

            if epoch % config["val_epoch_interval"] == 0:
                net.eval()
                print("Val_epoch!")
                iterator = tqdm(iter(val_loader))
                for im, label in iterator:
                    # load data/label
                    im = make_variable(im, requires_grad=False)
                    label = make_variable(label, requires_grad=False)

                    # forward pass and compute loss
                    preds = net(im)
                    loss = supervised_loss(preds, label)
                    precision, rc, fscore, support, iou = sklearnScores(
                        preds, label.type(torch.IntTensor))

                    data_metric['val']['losses'].append(loss.item())
                    data_metric['val']['ious'].append(iou)
                    data_metric['val']['recalls'].append(rc)

                    iterator.set_description("VAL V: {} | Epoch: {}".format(
                        config["version"], epoch))
                    iterator.refresh()

                # Val visualizations
                vizz = preprocess_viz(im, preds, label)
                writer.add_scalar('valepoch/loss',
                                  np.mean(data_metric['val']['losses']),
                                  global_step=epoch)
                writer.add_scalar('valepoch/IOU',
                                  np.mean(data_metric['val']['ious']),
                                  global_step=epoch)
                writer.add_scalar('valepoch/Recall',
                                  np.mean(data_metric['val']['recalls']),
                                  global_step=epoch)
                imutil = vutils.make_grid(torch.from_numpy(vizz),
                                          nrow=3,
                                          normalize=True,
                                          scale_each=True)
                writer.add_image('{}_image_data'.format('val'),
                                 imutil,
                                 global_step=epoch)

                # Val epoch done. Free up lists
                for key in data_metric['val'].keys():
                    data_metric['val'][key] = list()

            # Epoch Test
            if epoch % config["test_epoch_interval"] == 0:
                net.eval()
                print("Test_epoch!")
                iterator = tqdm(iter(test_loader))
                for im, label in iterator:
                    # load data/label
                    im = make_variable(im, requires_grad=False)
                    label = make_variable(label, requires_grad=False)

                    # forward pass and compute loss
                    preds = net(im)
                    loss = supervised_loss(preds, label)
                    precision, rc, fscore, support, iou = sklearnScores(
                        preds, label.type(torch.IntTensor))

                    data_metric['test']['losses'].append(loss.item())
                    data_metric['test']['ious'].append(iou)
                    data_metric['test']['recalls'].append(rc)

                    iterator.set_description("TEST V: {} | Epoch: {}".format(
                        config["version"], epoch))
                    iterator.refresh()

                # Test visualizations
                writer.add_scalar('testepoch/loss',
                                  np.mean(data_metric['test']['losses']),
                                  global_step=epoch)
                writer.add_scalar('testepoch/IOU',
                                  np.mean(data_metric['test']['ious']),
                                  global_step=epoch)
                writer.add_scalar('testepoch/Recall',
                                  np.mean(data_metric['test']['recalls']),
                                  global_step=epoch)

                # Test epoch done. Free up lists
                for key in data_metric['test'].keys():
                    data_metric['test'][key] = list()

            if config["step"] is not None and epoch % config["step"] == 0:
                logging.info('Decreasing learning rate by 0.1 factor')
                step_lr(opt, 0.1)

    logging.info('Optimization complete.')