Example #1
def Pre_train(args, FeatExtor, DepthEsmator, data_loader_real,
              data_loader_fake, summary_writer, saver, savefilename):

    # savepath = os.path.join(args.results_path, savefilename)
    # mkdir(savepath)
    ####################
    # 1. setup network #
    ####################
    # set train state for Dropout and BN layers
    FeatExtor.train()
    DepthEsmator.train()

    FeatExtor = DataParallel(FeatExtor)
    DepthEsmator = DataParallel(DepthEsmator)

    criterionDepth = torch.nn.MSELoss()

    optimizer_DG_depth = optim.Adam(list(FeatExtor.parameters()) +
                                    list(DepthEsmator.parameters()),
                                    lr=args.lr_DG_depth,
                                    betas=(args.beta1, args.beta2))

    iternum = max(len(data_loader_real), len(data_loader_fake))

    print('iternum={}'.format(iternum))

    ####################
    # 2. train network #
    ####################
    global_step = 0

    for epoch in range(args.pre_epochs):

        # epoch=epochNum+5

        data_real = get_inf_iterator(data_loader_real)
        data_fake = get_inf_iterator(data_loader_fake)

        for step in range(iternum):

            cat_img_real, depth_img_real, lab_real = next(data_real)
            cat_img_fake, depth_img_fake, lab_fake = next(data_fake)

            ori_img = torch.cat([cat_img_real, cat_img_fake], 0)
            ori_img = ori_img.cuda()

            depth_img = torch.cat([depth_img_real, depth_img_fake], 0)
            depth_img = depth_img.cuda()

            feat_ext, _ = FeatExtor(ori_img)
            depth_Pre = DepthEsmator(feat_ext)

            Loss_depth = criterionDepth(depth_Pre, depth_img)

            optimizer_DG_depth.zero_grad()
            Loss_depth.backward()
            optimizer_DG_depth.step()

            info = {
                'Loss_depth': Loss_depth.item(),
            }
            for tag, value in info.items():
                summary_writer.add_scalar(tag, value, global_step)

            #============ print the log info ============#
            if (step + 1) % args.log_step == 0:
                errors = OrderedDict([('Loss_depth', Loss_depth.item())])
                saver.print_current_errors((epoch + 1), (step + 1), errors)

            global_step += 1

        if ((epoch + 1) % args.model_save_epoch == 0):
            model_save_path = os.path.join(args.results_path, 'snapshots',
                                           savefilename)
            mkdir(model_save_path)

            torch.save(
                FeatExtor.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Ext-{}.pt".format(epoch + 1)))
            torch.save(
                DepthEsmator.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Depth-{}.pt".format(epoch + 1)))

    torch.save(FeatExtor.state_dict(),
               os.path.join(model_save_path, "DGFA-Ext-final.pt"))
    torch.save(DepthEsmator.state_dict(),
               os.path.join(model_save_path, "DGFA-Depth-final.pt"))
Example #2
def Train(args, FeatExtor, DepthEstor, FeatEmbder, data_loader1_real,
          data_loader1_fake, data_loader2_real, data_loader2_fake,
          data_loader3_real, data_loader3_fake, data_loader_target,
          summary_writer, Saver, savefilename):

    ####################
    # 1. setup network #
    ####################
    # set train state for Dropout and BN layers
    FeatExtor.train()
    DepthEstor.train()
    FeatEmbder.train()

    FeatExtor = DataParallel(FeatExtor)
    DepthEstor = DataParallel(DepthEstor)

    # setup criterion and optimizer
    criterionCls = nn.BCEWithLogitsLoss()
    criterionDepth = torch.nn.MSELoss()

    if args.optimizer_meta == 'adam':

        optimizer_all = optim.Adam(itertools.chain(FeatExtor.parameters(),
                                                   DepthEstor.parameters(),
                                                   FeatEmbder.parameters()),
                                   lr=args.lr_meta,
                                   betas=(args.beta1, args.beta2))

    else:
        raise NotImplementedError('Not a suitable optimizer')

    iternum = max(len(data_loader1_real), len(data_loader1_fake),
                  len(data_loader2_real), len(data_loader2_fake),
                  len(data_loader3_real), len(data_loader3_fake))

    print('iternum={}'.format(iternum))

    ####################
    # 2. train network #
    ####################
    global_step = 0

    for epoch in range(args.epochs):

        data1_real = get_inf_iterator(data_loader1_real)
        data1_fake = get_inf_iterator(data_loader1_fake)

        data2_real = get_inf_iterator(data_loader2_real)
        data2_fake = get_inf_iterator(data_loader2_fake)

        data3_real = get_inf_iterator(data_loader3_real)
        data3_fake = get_inf_iterator(data_loader3_fake)

        for step in range(iternum):

            #============ one batch extraction ============#

            cat_img1_real, depth_img1_real, lab1_real = next(data1_real)
            cat_img1_fake, depth_img1_fake, lab1_fake = next(data1_fake)

            cat_img2_real, depth_img2_real, lab2_real = next(data2_real)
            cat_img2_fake, depth_img2_fake, lab2_fake = next(data2_fake)

            cat_img3_real, depth_img3_real, lab3_real = next(data3_real)
            cat_img3_fake, depth_img3_fake, lab3_fake = next(data3_fake)

            #============ one batch collection ============#

            catimg1 = torch.cat([cat_img1_real, cat_img1_fake], 0).cuda()
            depth_img1 = torch.cat([depth_img1_real, depth_img1_fake],
                                   0).cuda()
            lab1 = torch.cat([lab1_real, lab1_fake], 0).float().cuda()

            catimg2 = torch.cat([cat_img2_real, cat_img2_fake], 0).cuda()
            depth_img2 = torch.cat([depth_img2_real, depth_img2_fake],
                                   0).cuda()
            lab2 = torch.cat([lab2_real, lab2_fake], 0).float().cuda()

            catimg3 = torch.cat([cat_img3_real, cat_img3_fake], 0).cuda()
            depth_img3 = torch.cat([depth_img3_real, depth_img3_fake],
                                   0).cuda()
            lab3 = torch.cat([lab3_real, lab3_fake], 0).float().cuda()

            catimg = torch.cat([catimg1, catimg2, catimg3], 0)
            depth_GT = torch.cat([depth_img1, depth_img2, depth_img3], 0)
            label = torch.cat([lab1, lab2, lab3], 0)

            #============ domain list augmentation ============#
            catimglist = [catimg1, catimg2, catimg3]
            lablist = [lab1, lab2, lab3]
            deplist = [depth_img1, depth_img2, depth_img3]

            domain_list = list(range(len(catimglist)))
            random.shuffle(domain_list)

            meta_train_list = domain_list[:args.metatrainsize]
            meta_test_list = domain_list[args.metatrainsize:]
            print('metatrn={}, metatst={}'.format(meta_train_list,
                                                  meta_test_list[0]))

            #============ meta training ============#

            Loss_dep_train = 0.0
            Loss_cls_train = 0.0

            adapted_state_dicts = []

            for index in meta_train_list:

                catimg_meta = catimglist[index]
                lab_meta = lablist[index]
                depGT_meta = deplist[index]

                batchidx = list(range(len(catimg_meta)))
                random.shuffle(batchidx)

                img_rand = catimg_meta[batchidx, :]
                lab_rand = lab_meta[batchidx]
                depGT_rand = depGT_meta[batchidx, :]

                feat_ext_all, feat = FeatExtor(img_rand)
                pred = FeatEmbder(feat)
                depth_Pre = DepthEstor(feat_ext_all)

                Loss_cls = criterionCls(pred.squeeze(), lab_rand)
                Loss_dep = criterionDepth(depth_Pre, depGT_rand)

                Loss_dep_train += Loss_dep
                Loss_cls_train += Loss_cls

                zero_param_grad(FeatEmbder.parameters())
                grads_FeatEmbder = torch.autograd.grad(Loss_cls,
                                                       FeatEmbder.parameters(),
                                                       create_graph=True)
                fast_weights_FeatEmbder = FeatEmbder.cloned_state_dict()

                adapted_params = OrderedDict()
                for (key, val), grad in zip(FeatEmbder.named_parameters(),
                                            grads_FeatEmbder):
                    adapted_params[key] = val - args.meta_step_size * grad
                    fast_weights_FeatEmbder[key] = adapted_params[key]

                adapted_state_dicts.append(fast_weights_FeatEmbder)

            #============ meta testing ============#
            Loss_dep_test = 0.0
            Loss_cls_test = 0.0

            index = meta_test_list[0]

            catimg_meta = catimglist[index]
            lab_meta = lablist[index]
            depGT_meta = deplist[index]

            batchidx = list(range(len(catimg_meta)))
            random.shuffle(batchidx)

            img_rand = catimg_meta[batchidx, :]
            lab_rand = lab_meta[batchidx]
            depGT_rand = depGT_meta[batchidx, :]

            feat_ext_all, feat = FeatExtor(img_rand)
            depth_Pre = DepthEstor(feat_ext_all)
            Loss_dep = criterionDepth(depth_Pre, depGT_rand)

            for n_scr in range(len(meta_train_list)):
                a_dict = adapted_state_dicts[n_scr]

                pred = FeatEmbder(feat, a_dict)
                Loss_cls = criterionCls(pred.squeeze(), lab_rand)

                Loss_cls_test += Loss_cls

            Loss_dep_test = Loss_dep

            Loss_dep_train_ave = Loss_dep_train / len(meta_train_list)

            Loss_meta_train = Loss_cls_train + args.W_depth * Loss_dep_train
            Loss_meta_test = Loss_cls_test + args.W_depth * Loss_dep_test

            Loss_all = Loss_meta_train + args.W_metatest * Loss_meta_test

            optimizer_all.zero_grad()
            Loss_all.backward()
            optimizer_all.step()

            if (step + 1) % args.log_step == 0:
                errors = OrderedDict([
                    ('Loss_meta_train', Loss_meta_train.item()),
                    ('Loss_meta_test', Loss_meta_test.item()),
                    ('Loss_cls_train', Loss_cls_train.item()),
                    ('Loss_cls_test', Loss_cls_test.item()),
                    ('Loss_dep_train_ave', Loss_dep_train_ave.item()),
                    ('Loss_dep_test', Loss_dep_test.item()),
                ])
                Saver.print_current_errors((epoch + 1), (step + 1), errors)

            #============ tensorboard the log info ============#
            info = {
                'Loss_meta_train': Loss_meta_train.item(),
                'Loss_meta_test': Loss_meta_test.item(),
                'Loss_cls_train': Loss_cls_train.item(),
                'Loss_cls_test': Loss_cls_test.item(),
                'Loss_dep_train_ave': Loss_dep_train_ave.item(),
                'Loss_dep_test': Loss_dep_test.item(),
            }
            for tag, value in info.items():
                summary_writer.add_scalar(tag, value, global_step)

            global_step += 1

            #############################
            # 2.4 save model parameters #
            #############################
            if ((step + 1) % args.model_save_step == 0):
                model_save_path = os.path.join(args.results_path, 'snapshots',
                                               savefilename)
                mkdir(model_save_path)

                torch.save(
                    FeatExtor.state_dict(),
                    os.path.join(
                        model_save_path,
                        "FeatExtor-{}-{}.pt".format(epoch + 1, step + 1)))
                torch.save(
                    FeatEmbder.state_dict(),
                    os.path.join(
                        model_save_path,
                        "FeatEmbder-{}-{}.pt".format(epoch + 1, step + 1)))
                torch.save(
                    DepthEstor.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DepthEstor-{}-{}.pt".format(epoch + 1, step + 1)))

        if ((epoch + 1) % args.model_save_epoch == 0):
            model_save_path = os.path.join(args.results_path, 'snapshots',
                                           savefilename)
            mkdir(model_save_path)

            torch.save(
                FeatExtor.state_dict(),
                os.path.join(model_save_path,
                             "FeatExtor-{}.pt".format(epoch + 1)))
            torch.save(
                FeatEmbder.state_dict(),
                os.path.join(model_save_path,
                             "FeatEmbder-{}.pt".format(epoch + 1)))
            torch.save(
                DepthEstor.state_dict(),
                os.path.join(model_save_path,
                             "DepthEstor-{}.pt".format(epoch + 1)))

    torch.save(FeatExtor.state_dict(),
               os.path.join(model_save_path, "FeatExtor-final.pt"))
    torch.save(FeatEmbder.state_dict(),
               os.path.join(model_save_path, "FeatEmbder-final.pt"))
    torch.save(DepthEstor.state_dict(),
               os.path.join(model_save_path, "DepthEstor-final.pt"))
Example #3
def train(classifier, generator, critic, src_data_loader, tgt_data_loader):
    """Train generator, classifier and critic jointly."""
    ####################
    # 1. setup network #
    ####################

    # set train state for Dropout and BN layers
    classifier.train()
    generator.train()
    critic.train()

    # set criterion for classifier and optimizers
    criterion = nn.CrossEntropyLoss()
    optimizer_c = get_optimizer(classifier, "Adam")
    optimizer_g = get_optimizer(generator, "Adam")
    optimizer_d = get_optimizer(critic, "Adam")

    # zip source and target data pair
    data_iter_src = get_inf_iterator(src_data_loader)
    data_iter_tgt = get_inf_iterator(tgt_data_loader)

    # counter
    g_step = 0

    # positive and negative labels
    pos_labels = make_variable(torch.FloatTensor([1]))
    neg_labels = make_variable(torch.FloatTensor([-1]))

    ####################
    # 2. train network #
    ####################

    for epoch in range(params.num_epochs):
        ###########################
        # 2.1 train discriminator #
        ###########################
        # requires to compute gradients for D
        for p in critic.parameters():
            p.requires_grad = True

        # set steps for discriminator
        if g_step < 25 or g_step % 500 == 0:
            # this helps to start with the critic at optimum
            # even in the first iterations.
            critic_iters = 100
        else:
            critic_iters = params.d_steps

        # loop for optimizing discriminator
        for d_step in range(critic_iters):
            # convert images into torch.Variable
            images_src, labels_src = next(data_iter_src)
            images_tgt, _ = next(data_iter_tgt)
            images_src = make_variable(images_src)
            labels_src = make_variable(labels_src.squeeze_())
            images_tgt = make_variable(images_tgt)
            if images_src.size(0) != params.batch_size or \
                    images_tgt.size(0) != params.batch_size:
                continue

            # zero gradients for optimizer
            optimizer_d.zero_grad()

            # compute source data loss for discriminator
            feat_src = generator(images_src)
            d_loss_src = critic(feat_src.detach())
            d_loss_src = d_loss_src.mean()
            d_loss_src.backward(neg_labels)

            # compute target data loss for discriminator
            feat_tgt = generator(images_tgt)
            d_loss_tgt = critic(feat_tgt.detach())
            d_loss_tgt = d_loss_tgt.mean()
            d_loss_tgt.backward(pos_labels)

            # compute gradient penalty
            gradient_penalty = calc_gradient_penalty(critic, feat_src.data,
                                                     feat_tgt.data)
            gradient_penalty.backward()

            # optimize weights of discriminator
            d_loss = -d_loss_src + d_loss_tgt + gradient_penalty
            optimizer_d.step()

        ########################
        # 2.2 train classifier #
        ########################

        # zero gradients for optimizer
        optimizer_c.zero_grad()

        # compute loss for critic
        preds_c = classifier(generator(images_src).detach())
        c_loss = criterion(preds_c, labels_src)

        # optimize source classifier
        c_loss.backward()
        optimizer_c.step()

        #######################
        # 2.3 train generator #
        #######################
        # avoid to compute gradients for D
        for p in critic.parameters():
            p.requires_grad = False

        # zero grad for optimizer of generator
        optimizer_g.zero_grad()

        # compute source data classification loss for generator
        feat_src = generator(images_src)
        preds_c = classifier(feat_src)
        g_loss_cls = criterion(preds_c, labels_src)
        g_loss_cls.backward()

        # compute source data discrimination loss for generator
        feat_src = generator(images_src)
        g_loss_src = critic(feat_src).mean()
        g_loss_src.backward(pos_labels)

        # compute target data discrimination loss for generator
        feat_tgt = generator(images_tgt)
        g_loss_tgt = critic(feat_tgt).mean()
        g_loss_tgt.backward(neg_labels)

        # compute loss for generator
        g_loss = g_loss_src - g_loss_tgt + g_loss_cls

        # optimize weights of generator
        optimizer_g.step()
        g_step += 1

        ##################
        # 2.4 print info #
        ##################
        if ((epoch + 1) % params.log_step == 0):
            print("Epoch [{}/{}]:"
                  "d_loss={:.5f} c_loss={:.5f} g_loss={:.5f} "
                  "D(x)={:.5f} D(G(z))={:.5f} GP={:.5f}".format(
                      epoch + 1, params.num_epochs, d_loss.item(),
                      c_loss.item(), g_loss.item(), d_loss_src.item(),
                      d_loss_tgt.item(), gradient_penalty.item()))

        #############################
        # 2.5 save model parameters #
        #############################
        if ((epoch + 1) % params.save_step == 0):
            save_model(critic, "WGAN-GP_critic-{}.pt".format(epoch + 1))
            save_model(classifier,
                       "WGAN-GP_classifier-{}.pt".format(epoch + 1))
            save_model(generator, "WGAN-GP_generator-{}.pt".format(epoch + 1))

    return classifier, generator
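
calc_gradient_penalty is likewise external to this listing. A sketch following the standard WGAN-GP penalty, assuming the features are 2-D (batch, dim) tensors and using the paper's usual weight of 10:

import torch
from torch import autograd

def calc_gradient_penalty(critic, feat_src, feat_tgt, lambda_gp=10.0):
    # random convex combination of source and target features
    alpha = torch.rand(feat_src.size(0), 1, device=feat_src.device)
    interpolates = (alpha * feat_src +
                    (1 - alpha) * feat_tgt).requires_grad_(True)
    critic_out = critic(interpolates)
    grads = autograd.grad(outputs=critic_out, inputs=interpolates,
                          grad_outputs=torch.ones_like(critic_out),
                          create_graph=True, retain_graph=True)[0]
    # penalize deviation of the critic's gradient norm from 1
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()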
Example #4
def domain_adapt(F, F_1, F_2, F_t, source_dataset, target_dataset, excerpt,
                 pseudo_labels, plot):
    """Perform Doamin Adaptation between source and target domains."""
    # set criterion for classifier and optimizers
    criterion = nn.CrossEntropyLoss()
    if 0:
        optimType = "Adam"
        cfg.learning_rate = 1.0E-4
    else:
        optimType = "sgd"
        cfg.learning_rate = 1.0E-4
    optimizer_F = get_optimizer(F, optimType)
    optimizer_F_1 = get_optimizer(F_1, optimType)
    optimizer_F_2 = get_optimizer(F_2, optimType)
    optimizer_F_t = get_optimizer(F_t, optimType)

    # get labelled target dataset
    print('pseudo_labels = %s' % str(pseudo_labels))
    target_dataset_labelled = get_dummy(target_dataset,
                                        excerpt,
                                        pseudo_labels,
                                        get_dataset=True)

    # merge source data and target data
    merged_dataset = ConcatDataset([source_dataset, target_dataset_labelled])

    print('target_dataset_labelled = %d' % len(target_dataset_labelled))

    # start training
    plt.figure()

    for k in range(cfg.num_epochs_k):
        # set train state for Dropout and BN layers
        F.train()
        F_1.train()
        F_2.train()
        F_t.train()

        losses = []

        merged_dataloader = make_data_loader(merged_dataset)
        target_dataloader_labelled = make_data_loader(target_dataset_labelled)
        target_dataloader_labelled_iter = get_inf_iterator(
            target_dataloader_labelled)

        if 0:
            plt.figure()
            atr.showDataSet(target_dataloader_labelled)
            plt.waitforbuttonpress()

        if 0:
            # There's a bug here, the labels are not the same data type.  print them out!!
            source_dataloader_iter = get_inf_iterator(
                make_data_loader(source_dataset))

            a, b = next(source_dataloader_iter)
            c, d = next(target_dataloader_labelled_iter)
            print('source labels = {}'.format(b))
            print('target labels = {}'.format(d))
            sys.exit(0)

        for epoch in range(cfg.num_epochs_adapt):
            if optimType == 'sgd':
                adjustLearningRate(optimizer_F, cfg.learning_rate, epoch,
                                   cfg.num_epochs_adapt)
                adjustLearningRate(optimizer_F_1, cfg.learning_rate, epoch,
                                   cfg.num_epochs_adapt)
                adjustLearningRate(optimizer_F_2, cfg.learning_rate, epoch,
                                   cfg.num_epochs_adapt)
                adjustLearningRate(optimizer_F_t, cfg.learning_rate, epoch,
                                   cfg.num_epochs_adapt)

            for step, rez in enumerate(merged_dataloader):
                #!!print('rez = %s' % rez)
                images, labels = rez
                if images.shape[0] < cfg.batch_size:
                    print('WARNING: batch of size %d smaller than desired %d: skipping' % \
                          (images.shape[0], cfg.batch_size))
                    continue

                # sample from T_l
                images_tgt, labels_tgt = next(target_dataloader_labelled_iter)
                while images_tgt.shape[0] < cfg.batch_size:
                    print('WARNING: target batch of size %d smaller than desired %d' % \
                          (images_tgt.shape[0], cfg.batch_size))
                    images_tgt, labels_tgt = next(
                        target_dataloader_labelled_iter)

                # convert into torch.autograd.Variable
                images = make_variable(images)
                labels = make_variable(labels)
                images_tgt = make_variable(images_tgt)
                labels_tgt = make_variable(labels_tgt)

                # zero-grad optimizer
                optimizer_F.zero_grad()
                optimizer_F_1.zero_grad()
                optimizer_F_2.zero_grad()
                optimizer_F_t.zero_grad()

                # forward networks
                #print('images shape = {}'.format(images.shape))#!!
                out_F = F(images)
                #print('out_F = {}'.format(out_F.shape))#!!
                out_F_1 = F_1(out_F)
                out_F_2 = F_2(out_F)
                out_F_t = F_t(F(images_tgt))

                # compute labelling loss
                loss_similiar = calc_similiar_penalty(F_1, F_2)
                loss_F_1 = criterion(out_F_1, labels)
                loss_F_2 = criterion(out_F_2, labels)
                loss_labelling = loss_F_1 + loss_F_2 + 0.03 * loss_similiar
                loss_labelling.backward()

                # compute target specific loss
                loss_F_t = criterion(out_F_t, labels_tgt)
                loss_F_t.backward()

                # optimize
                optimizer_F.step()
                optimizer_F_1.step()
                optimizer_F_2.step()
                optimizer_F_t.step()

                losses.append(loss_F_t.item())

                # print step info
                if ((step + 1) % cfg.log_step == 0):
                    print("K[{}/{}] Epoch [{}/{}] Step[{}/{}] Loss("
                          "labelling={:.5f} target={:.5f})".format(
                              k + 1,
                              cfg.num_epochs_k,
                              epoch + 1,
                              cfg.num_epochs_adapt,
                              step + 1,
                              len(merged_dataloader),
                              loss_labelling.item(),  #.data[0],
                              loss_F_t.item(),  #.data[0],
                          ))
                    #!!print('end of loop')

                    if plot:
                        plt.clf()
                        plt.plot(losses)
                        plt.grid(1)
                        plt.title(
                            'Loss for domain adaptation, k = {}/{}, epoch = {}/{}'
                            .format(k, cfg.num_epochs_k, epoch,
                                    cfg.num_epochs_adapt))
                        plt.waitforbuttonpress(0.0001)

        # re-compute the number of selected target data
        num_target = (k + 2) * len(source_dataset) // 20
        num_target = min(num_target, cfg.num_target_max)
        print(">>> Set num of sampled target data: {}".format(num_target))

        # re-generate pseudo labels
        excerpt, pseudo_labels = generate_labels(F,
                                                 F_1,
                                                 F_2,
                                                 target_dataset,
                                                 num_target,
                                                 useWeightedSampling=True)
        print(">>> Genrate pseudo labels [{}] numtarget = {}".format(
            len(target_dataset_labelled), num_target))

        print('sizes = {}, {}, excerpt = {}, \npseudo_labels = {}'.format(
            len(excerpt), len(pseudo_labels), excerpt, pseudo_labels))

        # get labelled target dataset
        target_dataset_labelled = get_dummy(target_dataset,
                                            excerpt,
                                            pseudo_labels,
                                            get_dataset=True)

        # re-merge source data and target data
        merged_dataset = ConcatDataset(
            [source_dataset, target_dataset_labelled])

        # save model
        if ((k + 1) % cfg.save_step == 0):
            save_model(F, "adapt-F-{}.pt".format(k + 1))
            save_model(F_1, "adapt-F_1-{}.pt".format(k + 1))
            save_model(F_2, "adapt-F_2-{}.pt".format(k + 1))
            save_model(F_t, "adapt-F_t-{}.pt".format(k + 1))

    # save final model
    save_model(F, "adapt-F-final.pt")
    save_model(F_1, "adapt-F_1-final.pt")
    save_model(F_2, "adapt-F_2-final.pt")
    save_model(F_t, "adapt-F_t-final.pt")
Example #5
def domain_adapt(F, F_1, F_2, F_t, source_dataset, target_dataset, excerpt,
                 pseudo_labels):
    """Perform Doamin Adaptation between source and target domains."""
    # set criterion for classifier and optimizers
    criterion = nn.CrossEntropyLoss()
    optimizer_F = get_optimizer(F, "Adam")
    optimizer_F_1 = get_optimizer(F_1, "Adam")
    optimizer_F_2 = get_optimizer(F_2, "Adam")
    optimizer_F_t = get_optimizer(F_t, "Adam")

    # get labelled target dataset
    target_dataset_labelled = get_dummy(target_dataset,
                                        excerpt,
                                        pseudo_labels,
                                        get_dataset=True)

    # merge source data and target data
    merged_dataset = ConcatDataset([source_dataset, target_dataset_labelled])

    # start training
    for k in range(cfg.num_epochs_k):
        # set train state for Dropout and BN layers
        F.train()
        F_1.train()
        F_2.train()
        F_t.train()

        merged_dataloader = make_data_loader(merged_dataset)
        target_dataloader_labelled = get_inf_iterator(
            make_data_loader(target_dataset_labelled))

        for epoch in range(cfg.num_epochs_adapt):
            for step, (images, labels) in enumerate(merged_dataloader):
                # sample from T_l
                images_tgt, labels_tgt = next(target_dataloader_labelled)

                # convert into torch.autograd.Variable
                images = make_variable(images)
                labels = make_variable(labels)
                images_tgt = make_variable(images_tgt)
                labels_tgt = make_variable(labels_tgt)

                # zero-grad optimizer
                optimizer_F.zero_grad()
                optimizer_F_1.zero_grad()
                optimizer_F_2.zero_grad()
                optimizer_F_t.zero_grad()

                # forward networks
                out_F = F(images)
                out_F_1 = F_1(out_F)
                out_F_2 = F_2(out_F)
                out_F_t = F_t(F(images_tgt))

                # compute labelling loss
                loss_similiar = calc_similiar_penalty(F_1, F_2)
                loss_F_1 = criterion(out_F_1, labels)
                loss_F_2 = criterion(out_F_2, labels)
                loss_labelling = loss_F_1 + loss_F_2 + loss_similiar
                loss_labelling.backward()

                # compute target specific loss
                loss_F_t = criterion(out_F_t, labels_tgt)
                loss_F_t.backward()

                # optimize
                optimizer_F.step()
                optimizer_F_1.step()
                optimizer_F_2.step()
                optimizer_F_t.step()

                # print step info
                if ((step + 1) % cfg.log_step == 0):
                    print("K[{}/{}] Epoch [{}/{}] Step[{}/{}] Loss("
                          "labelling={:.5f} target={:.5f})".format(
                              k + 1,
                              cfg.num_epochs_k,
                              epoch + 1,
                              cfg.num_epochs_adapt,
                              step + 1,
                              len(merged_dataloader),
                              loss_labelling.item(),
                              loss_F_t.item(),
                          ))

        # re-compute the number of selected target data
        num_target = (k + 2) * len(source_dataset) // 20
        num_target = min(num_target, cfg.num_target_max)
        print(">>> Set num of sampled target data: {}".format(num_target))

        # re-generate pseudo labels
        excerpt, pseudo_labels = genarate_labels(F, F_1, F_2, target_dataset,
                                                 num_target)
        print(">>> Genrate pseudo labels [{}]".format(
            len(target_dataset_labelled)))

        # get labelled target dataset
        target_dataset_labelled = get_dummy(target_dataset,
                                            excerpt,
                                            pseudo_labels,
                                            get_dataset=True)

        # re-merge source data and target data
        merged_dataset = ConcatDataset(
            [source_dataset, target_dataset_labelled])

        # save model
        if ((k + 1) % cfg.save_step == 0):
            save_model(F, "adapt-F-{}.pt".format(k + 1))
            save_model(F_1, "adapt-F_1-{}.pt".format(k + 1))
            save_model(F_2, "adapt-F_2-{}.pt".format(k + 1))
            save_model(F_t, "adapt-F_t-{}.pt".format(k + 1))

    # save final model
    save_model(F, "adapt-F-final.pt")
    save_model(F_1, "adapt-F_1-final.pt")
    save_model(F_2, "adapt-F_2-final.pt")
    save_model(F_t, "adapt-F_t-final.pt")
Example #6
def train(classifier, generator, critic, src_data_loader, tgt_data_loader):
    """Train generator, classifier and critic jointly."""
    ####################
    # 1. setup network #
    ####################

    # set train state for Dropout and BN layers
    classifier.train()
    generator.train()
    # set criterion for classifier and optimizers
    criterion = nn.CrossEntropyLoss()
    optimizer_c = get_optimizer(classifier, "Adam")

    # zip source and target data pair
    data_iter_src = get_inf_iterator(src_data_loader)

    # counter
    g_step = 0

    ####################
    # 2. train network #
    ####################

    for epoch in range(params.num_epochs):
        ###########################
        # 2.1 train discriminator #
        ###########################
        # requires to compute gradients for D
        for p in critic.parameters():
            p.requires_grad = True

        # set steps for discriminator
        if g_step < 25 or g_step % 500 == 0:
            # this helps to start with the critic at optimum
            # even in the first iterations.
            critic_iters = 100
        else:
            critic_iters = params.d_steps
        critic_iters = 0  # overrides the schedule above; the critic loop below is disabled
        # loop for optimizing discriminator
        #for d_step in range(critic_iters):
        # convert images into torch.Variable
        images_src, labels_src = next(data_iter_src)

        images_src = make_variable(images_src).cuda()
        labels_src = make_variable(labels_src.squeeze_()).cuda()
        # print(type(images_src))

        ########################
        # 2.2 train classifier #
        ########################

        # zero gradients for optimizer
        optimizer_c.zero_grad()

        # compute loss for critic
        preds_c = classifier(generator(images_src))
        c_loss = criterion(preds_c, labels_src)

        # optimize source classifier
        c_loss.backward()
        optimizer_c.step()
        g_step += 1

        ##################
        # 2.4 print info #
        ##################
        if ((epoch + 1) % 500 == 0):
            # print("Epoch [{}/{}]:"
            #       "c_loss={:.5f}"
            #       "D(x)={:.5f}"
            #       .format(epoch + 1,
            #               params.num_epochs,
            #               c_loss.item(),
            #               ))
            test(classifier, generator, src_data_loader, params.src_dataset)
        if ((epoch + 1) % 500 == 0):
            save_model(generator, "Mnist-generator-{}.pt".format(epoch + 1))
            save_model(classifier, "Mnist-classifer{}.pt".format(epoch + 1))
Example #7
def Train(args, FeatExtor, DepthEsmator, FeatEmbder, Discriminator1,
          Discriminator2, Discriminator3, PreFeatExtorS1, PreFeatExtorS2,
          PreFeatExtorS3, data_loader1_real, data_loader1_fake,
          data_loader2_real, data_loader2_fake, data_loader3_real,
          data_loader3_fake, data_loader_target, summary_writer, Saver,
          savefilename):

    ####################
    # 1. setup network #
    ####################
    # set train state for Dropout and BN layers
    FeatExtor.train()
    FeatEmbder.train()
    DepthEsmator.train()
    Discriminator1.train()
    Discriminator2.train()
    Discriminator3.train()

    PreFeatExtorS1.eval()
    PreFeatExtorS2.eval()
    PreFeatExtorS3.eval()

    FeatExtor = DataParallel(FeatExtor)
    FeatEmbder = DataParallel(FeatEmbder)
    DepthEsmator = DataParallel(DepthEsmator)
    Discriminator1 = DataParallel(Discriminator1)
    Discriminator2 = DataParallel(Discriminator2)
    Discriminator3 = DataParallel(Discriminator3)

    PreFeatExtorS1 = DataParallel(PreFeatExtorS1)
    PreFeatExtorS2 = DataParallel(PreFeatExtorS2)
    PreFeatExtorS3 = DataParallel(PreFeatExtorS3)

    # setup criterion and optimizer
    criterionDepth = torch.nn.MSELoss()
    criterionAdv = loss.GANLoss()
    criterionCls = torch.nn.BCEWithLogitsLoss()

    optimizer_DG_depth = optim.Adam(itertools.chain(FeatExtor.parameters(),
                                                    DepthEsmator.parameters()),
                                    lr=args.lr_DG_depth,
                                    betas=(args.beta1, args.beta2))

    optimizer_DG_conf = optim.Adam(itertools.chain(FeatExtor.parameters(),
                                                   FeatEmbder.parameters()),
                                   lr=args.lr_DG_conf,
                                   betas=(args.beta1, args.beta2))

    optimizer_critic1 = optim.Adam(Discriminator1.parameters(),
                                   lr=args.lr_critic,
                                   betas=(args.beta1, args.beta2))

    optimizer_critic2 = optim.Adam(Discriminator2.parameters(),
                                   lr=args.lr_critic,
                                   betas=(args.beta1, args.beta2))

    optimizer_critic3 = optim.Adam(Discriminator3.parameters(),
                                   lr=args.lr_critic,
                                   betas=(args.beta1, args.beta2))

    iternum = max(len(data_loader1_real), len(data_loader1_fake),
                  len(data_loader2_real), len(data_loader2_fake),
                  len(data_loader3_real), len(data_loader3_fake))

    print('iternum={}'.format(iternum))

    ####################
    # 2. train network #
    ####################
    global_step = 0

    for epoch in range(args.epochs):

        data1_real = get_inf_iterator(data_loader1_real)
        data1_fake = get_inf_iterator(data_loader1_fake)

        data2_real = get_inf_iterator(data_loader2_real)
        data2_fake = get_inf_iterator(data_loader2_fake)

        data3_real = get_inf_iterator(data_loader3_real)
        data3_fake = get_inf_iterator(data_loader3_fake)

        for step in range(iternum):

            FeatExtor.train()
            FeatEmbder.train()
            DepthEsmator.train()
            Discriminator1.train()
            Discriminator2.train()
            Discriminator3.train()

            #============ one batch extraction ============#

            cat_img1_real, depth_img1_real, lab1_real = next(data1_real)
            cat_img1_fake, depth_img1_fake, lab1_fake = next(data1_fake)

            cat_img2_real, depth_img2_real, lab2_real = next(data2_real)
            cat_img2_fake, depth_img2_fake, lab2_fake = next(data2_fake)

            cat_img3_real, depth_img3_real, lab3_real = next(data3_real)
            cat_img3_fake, depth_img3_fake, lab3_fake = next(data3_fake)

            #============ one batch collection ============#

            ori_img1 = torch.cat([cat_img1_real, cat_img1_fake], 0).cuda()
            depth_img1 = torch.cat([depth_img1_real, depth_img1_fake], 0)
            lab1 = torch.cat([lab1_real, lab1_fake], 0)

            ori_img2 = torch.cat([cat_img2_real, cat_img2_fake], 0).cuda()
            depth_img2 = torch.cat([depth_img2_real, depth_img2_fake], 0)
            lab2 = torch.cat([lab2_real, lab2_fake], 0)

            ori_img3 = torch.cat([cat_img3_real, cat_img3_fake], 0).cuda()
            depth_img3 = torch.cat([depth_img3_real, depth_img3_fake], 0)
            lab3 = torch.cat([lab3_real, lab3_fake], 0)

            ori_img = torch.cat([ori_img1, ori_img2, ori_img3], 0)
            # ori_img = ori_img.cuda()

            depth_GT = torch.cat([depth_img1, depth_img2, depth_img3], 0)
            depth_GT = depth_GT.cuda()

            label = torch.cat([lab1, lab2, lab3], 0)
            label = label.long().squeeze().cuda()

            with torch.no_grad():
                pre_feat_ext1 = PreFeatExtorS1(ori_img1)[1]
                pre_feat_ext2 = PreFeatExtorS2(ori_img2)[1]
                pre_feat_ext3 = PreFeatExtorS3(ori_img3)[1]

            #============ Depth supervision ============#

            ######### 1. depth loss #########
            optimizer_DG_depth.zero_grad()

            feat_ext_all, feat_ext = FeatExtor(ori_img)
            depth_Pre = DepthEsmator(feat_ext_all)

            Loss_depth = args.W_depth * criterionDepth(depth_Pre, depth_GT)

            Loss_depth.backward()
            optimizer_DG_depth.step()

            #============ domain generalization supervision ============#

            optimizer_DG_conf.zero_grad()

            _, feat_ext = FeatExtor(ori_img)

            feat_tgt = feat_ext

            #************************* confusion all **********************************#

            # predict on generator
            loss_generator1 = criterionAdv(Discriminator1(feat_tgt), True)

            loss_generator2 = criterionAdv(Discriminator2(feat_tgt), True)

            loss_generator3 = criterionAdv(Discriminator3(feat_tgt), True)

            feat_embd, label_pred = FeatEmbder(feat_ext)

            ########## cross-domain triplet loss #########
            Loss_triplet = TripletLossCal(args, feat_embd, lab1, lab2, lab3)

            Loss_cls = criterionCls(label_pred.squeeze(), label.float())

            Loss_gen = args.W_genave * (loss_generator1 + loss_generator2 +
                                        loss_generator3)

            Loss_G = args.W_trip * Loss_triplet + args.W_cls * Loss_cls + args.W_gen * Loss_gen

            Loss_G.backward()
            optimizer_DG_conf.step()

            #************************* confusion domain 1 with 2,3 **********************************#

            feat_src = torch.cat([pre_feat_ext1, pre_feat_ext1, pre_feat_ext1],
                                 0)

            # predict on discriminator
            optimizer_critic1.zero_grad()

            real_loss = criterionAdv(Discriminator1(feat_src), True)
            fake_loss = criterionAdv(Discriminator1(feat_tgt.detach()), False)

            loss_critic1 = 0.5 * (real_loss + fake_loss)

            loss_critic1.backward()
            optimizer_critic1.step()

            #************************* confusion domain 2 with 1,3 **********************************#

            feat_src = torch.cat([pre_feat_ext2, pre_feat_ext2, pre_feat_ext2],
                                 0)

            # predict on discriminator
            optimizer_critic2.zero_grad()

            real_loss = criterionAdv(Discriminator2(feat_src), True)
            fake_loss = criterionAdv(Discriminator2(feat_tgt.detach()), False)

            loss_critic2 = 0.5 * (real_loss + fake_loss)

            loss_critic2.backward()
            optimizer_critic2.step()

            #************************* confusion domain 3 with 1,2 **********************************#

            feat_src = torch.cat([pre_feat_ext3, pre_feat_ext3, pre_feat_ext3],
                                 0)

            # predict on discriminator
            optimizer_critic3.zero_grad()

            real_loss = criterionAdv(Discriminator3(feat_src), True)
            fake_loss = criterionAdv(Discriminator3(feat_tgt.detach()), False)

            loss_critic3 = 0.5 * (real_loss + fake_loss)

            loss_critic3.backward()
            optimizer_critic3.step()

            #============ tensorboard the log info ============#
            info = {
                'Loss_depth': Loss_depth.item(),
                'Loss_triplet': Loss_triplet.item(),
                'Loss_cls': Loss_cls.item(),
                'Loss_G': Loss_G.item(),
                'loss_critic1': loss_critic1.item(),
                'loss_generator1': loss_generator1.item(),
                'loss_critic2': loss_critic2.item(),
                'loss_generator2': loss_generator2.item(),
                'loss_critic3': loss_critic3.item(),
                'loss_generator3': loss_generator3.item(),
            }
            for tag, value in info.items():
                summary_writer.add_scalar(tag, value, global_step)

            if (step + 1) % args.tst_step == 0:
                depth_Pre_real = torch.cat([
                    depth_Pre[0:args.batchsize],
                    depth_Pre[2 * args.batchsize:3 * args.batchsize],
                    depth_Pre[4 * args.batchsize:5 * args.batchsize]
                ], 0)
                depth_Pre_fake = torch.cat([
                    depth_Pre[args.batchsize:2 * args.batchsize],
                    depth_Pre[3 * args.batchsize:4 * args.batchsize],
                    depth_Pre[5 * args.batchsize:6 * args.batchsize]
                ], 0)

                depth_Pre_all = vutils.make_grid(depth_Pre,
                                                 normalize=True,
                                                 scale_each=True)
                depth_Pre_real = vutils.make_grid(depth_Pre_real,
                                                  normalize=True,
                                                  scale_each=True)
                depth_Pre_fake = vutils.make_grid(depth_Pre_fake,
                                                  normalize=True,
                                                  scale_each=True)

                summary_writer.add_image('Depth_Image_all', depth_Pre_all,
                                         global_step)
                summary_writer.add_image('Depth_Image_real', depth_Pre_real,
                                         global_step)
                summary_writer.add_image('Depth_Image_fake', depth_Pre_fake,
                                         global_step)

            #============ print the log info ============#
            if (step + 1) % args.log_step == 0:
                errors = OrderedDict([
                    ('Loss_depth', Loss_depth.item()),
                    ('Loss_triplet', Loss_triplet.item()),
                    ('Loss_cls', Loss_cls.item()), ('Loss_G', Loss_G.item()),
                    ('loss_critic1', loss_critic1.item()),
                    ('loss_generator1', loss_generator1.item()),
                    ('loss_critic2', loss_critic2.item()),
                    ('loss_generator2', loss_generator2.item()),
                    ('loss_critic3', loss_critic3.item()),
                    ('loss_generator3', loss_generator3.item())
                ])

                Saver.print_current_errors((epoch + 1), (step + 1), errors)

            if (step + 1) % args.tst_step == 0:
                evaluate.evaluate_img(FeatExtor, DepthEsmator,
                                      data_loader_target, (epoch + 1),
                                      (step + 1), Saver)

            global_step += 1

            #############################
            # 2.4 save model parameters #
            #############################
            if ((step + 1) % args.model_save_step == 0):
                model_save_path = os.path.join(args.results_path, 'snapshots',
                                               savefilename)
                mkdir(model_save_path)

                torch.save(
                    FeatExtor.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-Ext-{}-{}.pt".format(epoch + 1, step + 1)))

                torch.save(
                    FeatEmbder.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-Embd-{}-{}.pt".format(epoch + 1, step + 1)))

                torch.save(
                    DepthEsmator.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-Depth-{}-{}.pt".format(epoch + 1, step + 1)))

                torch.save(
                    Discriminator1.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-D1-{}-{}.pt".format(epoch + 1, step + 1)))

                torch.save(
                    Discriminator2.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-D2-{}-{}.pt".format(epoch + 1, step + 1)))

                torch.save(
                    Discriminator3.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-D3-{}-{}.pt".format(epoch + 1, step + 1)))

        if ((epoch + 1) % args.model_save_epoch == 0):
            model_save_path = os.path.join(args.results_path, 'snapshots',
                                           savefilename)
            mkdir(model_save_path)

            torch.save(
                FeatExtor.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Ext-{}.pt".format(epoch + 1)))

            torch.save(
                FeatEmbder.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Embd-{}.pt".format(epoch + 1)))

            torch.save(
                DepthEsmator.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Depth-{}.pt".format(epoch + 1)))

            torch.save(
                Discriminator1.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-D1-{}.pt".format(epoch + 1)))

            torch.save(
                Discriminator2.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-D2-{}.pt".format(epoch + 1)))

            torch.save(
                Discriminator3.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-D3-{}.pt".format(epoch + 1)))

    torch.save(FeatExtor.state_dict(),
               os.path.join(model_save_path, "DGFA-Ext-final.pt"))

    torch.save(FeatEmbder.state_dict(),
               os.path.join(model_save_path, "DGFA-Embd-final.pt"))

    torch.save(DepthEsmator.state_dict(),
               os.path.join(model_save_path, "DGFA-Depth-final.pt"))

    torch.save(Discriminator1.state_dict(),
               os.path.join(model_save_path, "DGFA-D1-final.pt"))

    torch.save(Discriminator2.state_dict(),
               os.path.join(model_save_path, "DGFA-D2-final.pt"))

    torch.save(Discriminator3.state_dict(),
               os.path.join(model_save_path, "DGFA-D3-final.pt"))