Esempio n. 1
0
def main(output, dataset, datadir, lr, momentum, snapshot, downscale,
         cls_weights, gpu, weights_init, num_cls, lsgan, max_iter, lambda_d,
         lambda_g, train_discrim_only, weights_discrim, crop_size,
         weights_shared, discrim_feat, half_crop, batch, model, data_flag,
         resize, with_mmd_loss, small):
    """Adapt a segmentation network from a source to a target domain adversarially.

    Each iteration (1) trains ``discriminator`` to label source outputs 1 and
    target outputs 0. Once its running domain accuracy exceeds
    ``dom_acc_thresh`` (60%), it (2) updates ``net`` so the discriminator
    labels target outputs 1. With ``weights_shared`` it (3) also applies the
    source supervised loss, optionally plus an MMD discrepancy term. The
    target supervised loss is computed for logging only.

    Args:
        output: Directory for checkpoints; created on demand.
        dataset: ``(source, target)`` dataset names.
        datadir, downscale, resize, crop_size, half_crop, batch, data_flag,
            small: Forwarded to ``AddaDataLoader``.
        lr, momentum: SGD settings for both optimizers (weight decay 5e-4).
        snapshot: Save numbered checkpoints every ``snapshot`` iterations.
        cls_weights: Optional path to a text file of per-class loss weights.
        gpu: Comma-separated GPU ids (e.g. ``'0,1'``). Negative ids are
            ignored and at least one id must remain. Also written to
            ``CUDA_VISIBLE_DEVICES``.
        weights_init: Initial weights passed to ``get_model``.
        num_cls: Number of segmentation classes.
        lsgan: Discriminator output dim is 1 if true, else 2.
        max_iter: Number of training iterations.
        lambda_d, lambda_g: Weights of the discriminator / adversarial losses.
        train_discrim_only: Only train the discriminator. Stop as soon as its
            running accuracy passes the threshold.
        weights_discrim: Optional initial weights for the discriminator.
        weights_shared: Use ``net`` for the source domain as well. Otherwise a
            separate ``net_src`` is used; it is in eval mode and never optimized.
        discrim_feat: Discriminate on 4096-d features instead of class scores.
        model: Model name for ``get_model``.
        with_mmd_loss: Add ``0.1 * MMD`` to the source supervised loss.
            Requires ``discrim_feat``.

    Raises:
        ValueError: If ``gpu`` has no non-negative id, or if ``with_mmd_loss``
            is used without ``discrim_feat`` (raised when that update is first
            reached).
    """
    # So data is sampled in consistent way
    np.random.seed(1336)
    torch.manual_seed(1336)

    logdir = 'runs/{:s}/{:s}_to_{:s}/lr{:.1g}_ld{:.2g}_lg{:.2g}'.format(
        model, dataset[0], dataset[1], lr, lambda_d, lambda_g)
    logdir += '_weights_shared' if weights_shared else '_weights_unshared'
    logdir += '_discrim_feat' if discrim_feat else '_discrim_score'
    logdir += '/' + datetime.now().strftime('%Y_%b_%d-%H:%M')
    writer = SummaryWriter(log_dir=logdir)

    # NOTE(review): if CUDA is not initialised yet, this renumbers the chosen
    # GPUs as 0..n-1, while ``gpu_ids`` below keeps the raw ids (so e.g.
    # gpu='1' would request a device 1 that is no longer visible). Confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    print('Train Discrim Only', train_discrim_only)

    # Parse GPU ids up front; negative ids (e.g. '-1') are ignored.
    gpu_ids = [int(s) for s in gpu.split(',') if int(s) >= 0]
    if not gpu_ids:
        # Everything below relies on DataParallel wrappers (``.module``).
        raise ValueError(
            'at least one non-negative GPU id is required, got {!r}'.format(
                gpu))

    # fcn8s is built without the ``finetune`` flag; every other model with it.
    model_kwargs = dict(num_cls=num_cls,
                        pretrained=True,
                        weights_init=weights_init,
                        output_last_ft=discrim_feat)
    if model != 'fcn8s':
        model_kwargs['finetune'] = True
    net = get_model(model, **model_kwargs)
    net.cuda()

    # set gpu ids
    torch.cuda.set_device(gpu_ids[0])
    assert torch.cuda.is_available()
    net.to(gpu_ids[0])
    net = torch.nn.DataParallel(net, gpu_ids)

    if weights_shared:
        net_src = net  # shared weights
    else:
        # Separate source network: it is not passed to any optimizer, so it
        # keeps its initial weights. NOTE(review): built with finetune=True
        # even for fcn8s, unlike ``net`` above -- confirm intended.
        net_src = get_model(model,
                            num_cls=num_cls,
                            finetune=True,
                            pretrained=True,
                            weights_init=weights_init,
                            output_last_ft=discrim_feat)
        net_src.eval()

    # initialize discriminator: input is class scores, or 4096-d features
    # when ``discrim_feat`` is set
    odim = 1 if lsgan else 2
    idim = 4096 if discrim_feat else num_cls
    print('Discrim_feat', discrim_feat, idim)
    print('Discriminator init weights: ', weights_discrim)
    discriminator = Discriminator(input_dim=idim,
                                  output_dim=odim,
                                  pretrained=weights_discrim is not None,
                                  weights_init=weights_discrim).cuda()
    discriminator.to(gpu_ids[0])
    discriminator = torch.nn.DataParallel(discriminator, gpu_ids)

    loader = AddaDataLoader(net.module.transform,
                            dataset,
                            datadir,
                            downscale,
                            resize=resize,
                            crop_size=crop_size,
                            half_crop=half_crop,
                            batch_size=batch,
                            shuffle=True,
                            num_workers=16,
                            src_data_flag=data_flag,
                            small=small)
    print('dataset', dataset)

    # Optional per-class weights for the supervised loss
    weights = np.loadtxt(cls_weights) if cls_weights is not None else None

    # setup optimizers
    opt_dis = torch.optim.SGD(discriminator.module.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=0.0005)
    opt_rep = torch.optim.SGD(net.module.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=0.0005)

    iteration = 0
    num_update_g = 0
    last_update_g = -1
    # Sliding windows (last 100 values) used for smoothed logging
    losses_super_s = deque(maxlen=100)
    losses_super_t = deque(maxlen=100)
    losses_dis = deque(maxlen=100)
    losses_rep = deque(maxlen=100)
    accuracies_dom = deque(maxlen=100)
    intersections = np.zeros([100, num_cls])
    iu_deque = deque(maxlen=100)
    unions = np.zeros([100, num_cls])
    accuracy = deque(maxlen=100)
    print('Max Iter:', max_iter)

    net.train()
    discriminator.train()

    # One debug-mode fetch per domain before training; the result is
    # discarded (presumably a dataset sanity check -- confirm).
    loader.loader_src.dataset.__getitem__(0, debug=True)
    loader.loader_tgt.dataset.__getitem__(0, debug=True)

    while iteration < max_iter:

        for im_s, im_t, label_s, label_t in loader:

            if iteration == 0:
                print("IM S: {}".format(im_s.size()))
                print("Label S: {}".format(label_s.size()))
                print("IM T: {}".format(im_t.size()))
                print("Label T: {}".format(label_t.size()))

            # ``>=`` so exactly ``max_iter`` iterations run (``>`` ran one extra)
            if iteration >= max_iter:
                break

            info_str = 'Iteration {}: '.format(iteration)

            if not check_label(label_s, num_cls):
                continue

            ###########################
            # 1. Setup Data Variables #
            ###########################
            im_s = make_variable(im_s, requires_grad=False)
            label_s = make_variable(label_s, requires_grad=False)
            im_t = make_variable(im_t, requires_grad=False)
            label_t = make_variable(label_t, requires_grad=False)

            #############################
            # 2. Optimize Discriminator #
            #############################

            # zero gradients for optimizer
            opt_dis.zero_grad()
            opt_rep.zero_grad()

            # extract features (detached: only the discriminator learns here)
            if discrim_feat:
                score_s, feat_s = net_src(im_s)
                score_s = Variable(score_s.data, requires_grad=False)
                f_s = Variable(feat_s.data, requires_grad=False)
            else:
                score_s = Variable(net_src(im_s).data, requires_grad=False)
                f_s = score_s

            dis_score_s = discriminator(f_s)

            if discrim_feat:
                score_t, feat_t = net(im_t)
                score_t = Variable(score_t.data, requires_grad=False)
                f_t = Variable(feat_t.data, requires_grad=False)
            else:
                score_t = Variable(net(im_t).data, requires_grad=False)
                f_t = score_t
            dis_score_t = discriminator(f_t)

            dis_pred_concat = torch.cat((dis_score_s, dis_score_t))

            # prepare real (source = 1) and fake (target = 0) labels
            batch_t, _, h, w = dis_score_t.size()
            batch_s, _, _, _ = dis_score_s.size()
            dis_label_concat = make_variable(torch.cat([
                torch.ones(batch_s, h, w).long(),
                torch.zeros(batch_t, h, w).long()
            ]),
                                             requires_grad=False)

            # compute loss for discriminator
            loss_dis = supervised_loss(dis_pred_concat, dis_label_concat)
            (lambda_d * loss_dis).backward()
            losses_dis.append(loss_dis.item())

            # optimize discriminator
            opt_dis.step()

            # compute discriminator acc
            pred_dis = torch.squeeze(dis_pred_concat.max(1)[1])
            dom_acc = (pred_dis == dis_label_concat).float().mean().item()
            accuracies_dom.append(dom_acc * 100.)

            # add discriminator info to log
            info_str += " domacc:{:0.1f}  D:{:.3f}".format(
                np.mean(accuracies_dom), np.mean(losses_dis))
            writer.add_scalar('loss/discriminator', np.mean(losses_dis),
                              iteration)
            writer.add_scalar('acc/discriminator', np.mean(accuracies_dom),
                              iteration)

            ###########################
            # Optimize Target Network #
            ###########################

            dom_acc_thresh = 60
            # ``accuracies_dom`` is not modified below, so its mean is fixed
            dom_acc_mean = np.mean(accuracies_dom)

            if train_discrim_only and dom_acc_mean > dom_acc_thresh:
                os.makedirs(output, exist_ok=True)
                torch.save(discriminator.module.state_dict(),
                           '{}/discriminator_abv60.pth'.format(output))
                # ``break`` alone only leaves the inner loop without advancing
                # ``iteration``; the outer ``while`` would then restart the
                # loader and keep re-saving indefinitely.
                iteration = max_iter
                break

            update_g = (not train_discrim_only) and dom_acc_mean > dom_acc_thresh

            if update_g:

                last_update_g = iteration
                num_update_g += 1
                print('Updating G with adversarial loss ({:d} times)'.format(
                    num_update_g))

                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features (now with gradients flowing into ``net``)
                if discrim_feat:
                    score_t, feat_t = net(im_t)
                    score_t = Variable(score_t.data, requires_grad=False)
                    f_t = feat_t
                else:
                    score_t = net(im_t)
                    f_t = score_t

                dis_score_t = discriminator(f_t)

                # create fake label: target outputs should look like source (1)
                n_t, _, h, w = dis_score_t.size()
                target_dom_fake_t = make_variable(torch.ones(n_t, h,
                                                             w).long(),
                                                  requires_grad=False)

                # compute loss for target net
                loss_gan_t = supervised_loss(dis_score_t, target_dom_fake_t)
                (lambda_g * loss_gan_t).backward()
                losses_rep.append(loss_gan_t.item())
                writer.add_scalar('loss/generator', np.mean(losses_rep),
                                  iteration)

                # optimize target net
                opt_rep.step()

                # log net update info
                info_str += ' G:{:.3f}'.format(np.mean(losses_rep))

            if update_g and weights_shared:
                print('Updating G using source supervised loss.')
                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features
                if discrim_feat:
                    score_s, feat_s = net(im_s)
                else:
                    score_s = net(im_s)

                loss_supervised_s = supervised_loss(score_s,
                                                    label_s,
                                                    weights=weights)

                if with_mmd_loss:
                    if not discrim_feat:
                        # Without ``discrim_feat`` the net returns no
                        # intermediate features to compare.
                        raise ValueError(
                            'with_mmd_loss requires discrim_feat=True')
                    print("Updating G using discrepancy loss")
                    lambda_discrepancy = 0.1
                    # Fresh target forward pass: ``feat_t`` from the adversarial
                    # step belongs to a graph that was already back-propagated,
                    # and ``opt_rep.step()`` has since changed the weights.
                    score_t_mmd, feat_t_mmd = net(im_t)
                    score_t_mmd = Variable(score_t_mmd.data,
                                           requires_grad=False)
                    loss_mmd = mmd_loss(feat_s, feat_t_mmd) * 0.5 + mmd_loss(
                        score_s, score_t_mmd) * 0.5
                    loss_supervised_s += lambda_discrepancy * loss_mmd

                loss_supervised_s.backward()
                losses_super_s.append(loss_supervised_s.item())
                info_str += ' clsS:{:.2f}'.format(np.mean(losses_super_s))
                writer.add_scalar('loss/supervised/source',
                                  np.mean(losses_super_s), iteration)

                # optimize target net
                opt_rep.step()

            # compute supervised losses for target -- monitoring only, no backward()
            loss_supervised_t = supervised_loss(score_t,
                                                label_t,
                                                weights=weights)
            losses_super_t.append(loss_supervised_t.item())
            info_str += ' clsT:{:.2f}'.format(np.mean(losses_super_t))
            writer.add_scalar('loss/supervised/target',
                              np.mean(losses_super_t), iteration)

            ###########################
            # Log and compute metrics #
            ###########################
            if iteration % 10 == 0 and iteration > 0:

                # compute metrics
                intersection, union, acc = seg_accuracy(
                    score_t, label_t.data, num_cls)
                intersections = np.vstack(
                    [intersections[1:, :], intersection[np.newaxis, :]])
                unions = np.vstack([unions[1:, :], union[np.newaxis, :]])
                accuracy.append(acc.item() * 100)
                acc = np.mean(accuracy)
                mIoU = np.mean(
                    np.maximum(intersections, 1) / np.maximum(unions, 1)) * 100

                # NOTE(review): scaled by 10000 here but by 100 for ``mIoU``
                # above -- confirm which scale the logged value should use.
                iu = (intersection / union) * 10000
                iu_deque.append(np.nanmean(iu))

                info_str += ' acc:{:0.2f}  mIoU:{:0.2f}'.format(
                    acc, np.mean(iu_deque))
                writer.add_scalar('metrics/acc', np.mean(accuracy), iteration)
                writer.add_scalar('metrics/mIoU', np.mean(mIoU), iteration)
                logging.info(info_str)

            iteration += 1

            ################
            # Save outputs #
            ################

            # every 500 iters save current model
            if iteration % 500 == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.module.state_dict(),
                               '{}/net-itercurr.pth'.format(output))
                torch.save(discriminator.module.state_dict(),
                           '{}/discriminator-itercurr.pth'.format(output))

            # save labeled snapshots
            if iteration % snapshot == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.module.state_dict(),
                               '{}/net-iter{}.pth'.format(output, iteration))
                torch.save(
                    discriminator.module.state_dict(),
                    '{}/discriminator-iter{}.pth'.format(output, iteration))

            # Give up when G has not been updated for three passes over the data
            if iteration - last_update_g >= 3 * len(loader):
                print('No suitable discriminator found -- returning.')
                os.makedirs(output, exist_ok=True)
                torch.save(net.module.state_dict(),
                           '{}/net-iter{}.pth'.format(output, iteration))
                iteration = max_iter  # make sure outside loop breaks
                break

    writer.close()
Esempio n. 2
0
def main(output, dataset, datadir, lr, momentum, snapshot, downscale,
         cls_weights, weights_init, num_cls, lsgan, max_iter, lambda_d,
         lambda_g, train_discrim_only, weights_discrim, crop_size,
         weights_shared, discrim_feat, half_crop, batch, model, targetsup):
    """Adapt a segmentation network from a source to a target domain adversarially.

    Each iteration (1) trains ``discriminator`` to label source outputs 1 and
    target outputs 0. Once its running domain accuracy exceeds
    ``dom_acc_thresh`` (55%), it (2) updates ``net`` so the discriminator
    labels target outputs 1. With ``weights_shared`` it (3) also applies the
    source supervised loss, plus the target supervised loss when
    ``targetsup`` is truthy. Every 10 iterations, metrics are logged and
    preview PNGs of the first sample are written to ``output``.

    Args:
        output: Directory for checkpoints and preview images; created on demand.
        dataset: ``(source, target)`` dataset names.
        datadir, downscale, crop_size, half_crop, batch: Forwarded to
            ``AddaDataLoader``.
        lr, momentum: SGD settings for both optimizers (weight decay 5e-4).
        snapshot: Save numbered checkpoints every ``snapshot`` iterations.
        cls_weights: Optional path to a text file of per-class loss weights.
        weights_init: Checkpoint path loaded into ``net`` (and ``net_src``).
        num_cls: Number of segmentation classes.
        lsgan: Discriminator output dim is 1 if true, else 2.
        max_iter: Number of training iterations.
        lambda_d, lambda_g: Weights of the discriminator / adversarial losses.
        train_discrim_only: Only train the discriminator.
        weights_discrim: Optional initial weights for the discriminator.
        weights_shared: Use ``net`` for the source domain as well. Otherwise a
            separate ``net_src`` is used; it is in eval mode and never optimized.
        discrim_feat: Discriminate on 4096-d features instead of class scores.
        model: Model name for ``get_model``.
        targetsup: If truthy, also back-propagate the target supervised loss
            (uses target labels).
    """
    # Previously hard-coded to 1, which silently ignored ``targetsup``.
    targetSup = targetsup
    # So data is sampled in consistent way
    np.random.seed(1337)
    torch.manual_seed(1337)
    logdir = 'runs/{:s}/{:s}_to_{:s}/lr{:.1g}_ld{:.2g}_lg{:.2g}'.format(
        model, dataset[0], dataset[1], lr, lambda_d, lambda_g)
    logdir += '_weightshared' if weights_shared else '_weightsunshared'
    logdir += '_discrimfeat' if discrim_feat else '_discrimscore'
    logdir += '/' + datetime.now().strftime('%Y_%b_%d-%H:%M')
    writer = SummaryWriter(logdir)

    config_logging()
    print('Train Discrim Only', train_discrim_only)

    # Load the checkpoint once. Map it to CPU when CUDA is unavailable so the
    # CPU fallback below can also load GPU-saved weights.
    map_location = None if torch.cuda.is_available() else 'cpu'
    init_state = torch.load(weights_init, map_location=map_location)

    net = get_model(model, num_cls=num_cls, output_last_ft=discrim_feat)
    net.load_state_dict(init_state)
    if weights_shared:
        net_src = net  # shared weights
    else:
        # Separate source network: it is not passed to any optimizer, so it
        # keeps the initial weights.
        net_src = get_model(model,
                            num_cls=num_cls,
                            output_last_ft=discrim_feat)
        net_src.load_state_dict(init_state)  # was ``new_src`` (NameError)
        net_src.eval()

    print("GOT MODEL")

    # Discriminator input is class scores, or 4096-d features with discrim_feat
    odim = 1 if lsgan else 2
    idim = 4096 if discrim_feat else num_cls
    print('discrim_feat', discrim_feat, idim)
    print('discriminator init weights: ', weights_discrim)

    discriminator = Discriminator(input_dim=idim,
                                  output_dim=odim,
                                  pretrained=weights_discrim is not None,
                                  weights_init=weights_discrim)
    if torch.cuda.is_available():
        discriminator = discriminator.cuda()

    loader = AddaDataLoader(None,
                            dataset,
                            datadir,
                            downscale,
                            crop_size=crop_size,
                            half_crop=half_crop,
                            batch_size=batch,
                            shuffle=True,
                            num_workers=2)
    print('dataset', dataset)

    # Optional per-class weights for the supervised loss
    weights = np.loadtxt(cls_weights) if cls_weights is not None else None

    # setup optimizers
    opt_dis = torch.optim.SGD(discriminator.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=0.0005)
    opt_rep = torch.optim.SGD(net.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=0.0005)

    iteration = 0
    num_update_g = 0
    last_update_g = -1
    # Sliding windows (last 100 values) used for smoothed logging
    losses_super_s = deque(maxlen=100)
    losses_super_t = deque(maxlen=100)
    losses_dis = deque(maxlen=100)
    losses_rep = deque(maxlen=100)
    accuracies_dom = deque(maxlen=100)
    intersections = np.zeros([100, num_cls])
    unions = np.zeros([100, num_cls])
    accuracy = deque(maxlen=100)
    print('max iter:', max_iter)

    net.train()
    discriminator.train()
    IoU_s = deque(maxlen=100)
    IoU_t = deque(maxlen=100)

    Recall_s = deque(maxlen=100)
    Recall_t = deque(maxlen=100)

    while iteration < max_iter:

        for im_s, im_t, label_s, label_t in loader:

            # ``>=`` so exactly ``max_iter`` iterations run (``>`` ran one extra)
            if iteration >= max_iter:
                break

            info_str = 'Iteration {}: '.format(iteration)

            if not check_label(label_s, num_cls):
                continue

            ###########################
            # 1. Setup Data Variables #
            ###########################
            im_s = make_variable(im_s, requires_grad=False)
            label_s = make_variable(label_s, requires_grad=False)
            im_t = make_variable(im_t, requires_grad=False)
            label_t = make_variable(label_t, requires_grad=False)

            #############################
            # 2. Optimize Discriminator #
            #############################

            # zero gradients for optimizer
            opt_dis.zero_grad()
            opt_rep.zero_grad()

            # extract features (detached: only the discriminator learns here)
            if discrim_feat:
                score_s, feat_s = net_src(im_s)
                score_s = Variable(score_s.data, requires_grad=False)
                f_s = Variable(feat_s.data, requires_grad=False)
            else:
                score_s = Variable(net_src(im_s).data, requires_grad=False)
                f_s = score_s
            dis_score_s = discriminator(f_s)

            if discrim_feat:
                score_t, feat_t = net(im_t)
                score_t = Variable(score_t.data, requires_grad=False)
                f_t = Variable(feat_t.data, requires_grad=False)
            else:
                score_t = Variable(net(im_t).data, requires_grad=False)
                f_t = score_t
            dis_score_t = discriminator(f_t)

            dis_pred_concat = torch.cat((dis_score_s, dis_score_t))

            # prepare real (source = 1) and fake (target = 0) labels
            batch_t, _, h, w = dis_score_t.size()
            batch_s, _, _, _ = dis_score_s.size()
            dis_label_concat = make_variable(torch.cat([
                torch.ones(batch_s, h, w).long(),
                torch.zeros(batch_t, h, w).long()
            ]),
                                             requires_grad=False)

            # compute loss for discriminator
            loss_dis = supervised_loss(dis_pred_concat, dis_label_concat)
            (lambda_d * loss_dis).backward()
            losses_dis.append(loss_dis.item())

            # optimize discriminator
            opt_dis.step()

            # compute discriminator acc
            pred_dis = torch.squeeze(dis_pred_concat.max(1)[1])
            dom_acc = (pred_dis == dis_label_concat).float().mean().item()
            accuracies_dom.append(dom_acc * 100.)

            # add discriminator info to log
            info_str += " domacc:{:0.1f}  D:{:.3f}".format(
                np.mean(accuracies_dom), np.mean(losses_dis))
            writer.add_scalar('loss/discriminator', np.mean(losses_dis),
                              iteration)
            writer.add_scalar('acc/discriminator', np.mean(accuracies_dom),
                              iteration)

            ###########################
            # Optimize Target Network #
            ###########################

            dom_acc_thresh = 55
            update_g = (not train_discrim_only) and (
                np.mean(accuracies_dom) > dom_acc_thresh)

            if update_g:

                last_update_g = iteration
                num_update_g += 1
                print('Updating G with adversarial loss ({:d} times)'.format(
                    num_update_g))

                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features (now with gradients flowing into ``net``)
                if discrim_feat:
                    score_t, feat_t = net(im_t)
                    score_t = Variable(score_t.data, requires_grad=False)
                    f_t = feat_t
                else:
                    score_t = net(im_t)
                    f_t = score_t

                dis_score_t = discriminator(f_t)

                # create fake label: target outputs should look like source (1)
                n_t, _, h, w = dis_score_t.size()
                target_dom_fake_t = make_variable(torch.ones(n_t, h,
                                                             w).long(),
                                                  requires_grad=False)

                # compute loss for target net
                loss_gan_t = supervised_loss(dis_score_t, target_dom_fake_t)
                (lambda_g * loss_gan_t).backward()
                losses_rep.append(loss_gan_t.item())
                writer.add_scalar('loss/generator', np.mean(losses_rep),
                                  iteration)

                # optimize target net
                opt_rep.step()

                # log net update info
                info_str += ' G:{:.3f}'.format(np.mean(losses_rep))

            if update_g and weights_shared:

                print('Updating G using source supervised loss.')

                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()

                # extract features
                if discrim_feat:
                    score_s, _ = net(im_s)
                    score_t, _ = net(im_t)
                else:
                    score_s = net(im_s)
                    score_t = net(im_t)

                loss_supervised_s = supervised_loss(score_s,
                                                    label_s,
                                                    weights=weights)
                loss_supervised_t = supervised_loss(score_t,
                                                    label_t,
                                                    weights=weights)
                # Out-of-place sum: the previous ``loss_supervised += ...`` on
                # an alias mutated ``loss_supervised_s`` in place, so the
                # logged source loss silently included the target loss.
                if targetSup:
                    loss_supervised = loss_supervised_s + loss_supervised_t
                else:
                    loss_supervised = loss_supervised_s

                loss_supervised.backward()

                losses_super_s.append(loss_supervised_s.item())
                info_str += ' clsS:{:.2f}'.format(np.mean(losses_super_s))
                writer.add_scalar('loss/supervised/source',
                                  np.mean(losses_super_s), iteration)

                losses_super_t.append(loss_supervised_t.item())
                info_str += ' clsT:{:.2f}'.format(np.mean(losses_super_t))
                writer.add_scalar('loss/supervised/target',
                                  np.mean(losses_super_t), iteration)

                # optimize target net
                opt_rep.step()

            ###########################
            # Log and compute metrics #
            ###########################
            if iteration % 10 == 0 and iteration > 0:

                # compute metrics
                intersection, union, acc = seg_accuracy(
                    score_t, label_t.data, num_cls)
                IoU_s.append(IoU(score_s, label_s).item())
                IoU_t.append(IoU(score_t, label_t).item())
                Recall_s.append(recall(score_s, label_s).item())
                Recall_t.append(recall(score_t, label_t).item())
                intersections = np.vstack(
                    [intersections[1:, :], intersection[np.newaxis, :]])
                unions = np.vstack([unions[1:, :], union[np.newaxis, :]])
                accuracy.append(acc.item() * 100)
                acc = np.mean(accuracy)
                mIoU = np.mean(
                    np.maximum(intersections, 1) / np.maximum(unions, 1)) * 100

                # Latest source-domain IoU / recall for this iteration
                info_str += ' IoU:{:0.2f}  Recall:{:0.2f}'.format(
                    IoU_s[-1], Recall_s[-1])
                # writer.add_scalar('metrics/acc', np.mean(accuracy), iteration)
                # writer.add_scalar('metrics/mIoU', np.mean(mIoU), iteration)
                # writer.add_scalar('metrics/RealIoU_Source', np.mean(IoU_s))
                # writer.add_scalar('metrics/RealIoU_Target', np.mean(IoU_t))
                # writer.add_scalar('metrics/RealRecall_Source', np.mean(Recall_s))
                # writer.add_scalar('metrics/RealRecall_Target', np.mean(Recall_t))
                logging.info(info_str)
                print(info_str)

                # Dump the first sample of the batch as PNG previews. The
                # directory may not exist yet: checkpoints only create it
                # from iteration 500 / ``snapshot`` onwards.
                os.makedirs(output, exist_ok=True)
                previews = {
                    'im_s': norm(im_s[0]).permute(1, 2, 0),
                    'im_t': norm(im_t[0]).permute(1, 2, 0),
                    'label_s': label_s[0],
                    'label_t': label_t[0],
                    'score_s': mxAxis(score_s[0]),
                    'score_t': mxAxis(score_t[0]),
                }
                for preview_name, preview in previews.items():
                    Image.fromarray(
                        np.uint8(preview.cpu().data.numpy() * 255)).save(
                            output + '/' + preview_name + '.png')

            iteration += 1

            ################
            # Save outputs #
            ################

            # every 500 iters save current model
            if iteration % 500 == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.state_dict(),
                               '{}/net-itercurr.pth'.format(output))
                torch.save(discriminator.state_dict(),
                           '{}/discriminator-itercurr.pth'.format(output))

            # save labeled snapshots
            if iteration % snapshot == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.state_dict(),
                               '{}/net-iter{}.pth'.format(output, iteration))
                torch.save(
                    discriminator.state_dict(),
                    '{}/discriminator-iter{}.pth'.format(output, iteration))

            if iteration - last_update_g >= len(loader):
                # Early stopping is deliberately disabled in this variant:
                # the message is printed but training continues.
                print('No suitable discriminator found -- returning.')

    writer.close()