Example #1
def main(opt):
    opt.digitroot = _DIGIT_ROOT
    if opt.prefix == '':
        opt.prefix = _PREFIX
    if opt.model == '':
        opt.model = _MODEL
    if opt.beta == '':
        opt.beta = _BETA
    if opt.mu == '':
        opt.mu = _MU
    opt.gamma = _GAMMA
    opt.alpha = _ALPHA
    if opt.norm is None:
        opt.norm = _NORM

    modelname = '{0}_{1}_{2:0.1f}_{3:0.1f}'.format(opt.prefix, opt.model,
                                                   opt.beta, opt.mu)
    modelpath = 'model/' + modelname + '.pth'

    torch.cuda.set_device(opt.gpu)
    device = torch.device('cuda:{0}'.format(opt.gpu))

    now = datetime.now()
    curtime = now.isoformat()
    run_dir = "runs/{0}_{1}_ongoing".format(curtime[0:16], modelname)

    resultname = '{0}/result_{1}_{2}.txt'.format(run_dir, modelname,
                                                 opt.num_epochs)
    n_ch = 64
    n_hidden = 5
    n_resblock = 4

    prompt = ''
    prompt += '====================================\n'
    prompt += run_dir + '\n'
    for arg in vars(opt):
        prompt += '{0} : {1}\n'.format(arg, getattr(opt, arg))
    prompt += '====================================\n'
    print(prompt, end='')

    # opt.model = 'svhn_mnist'
    # opt.model = 'mnist_usps'
    # opt.model = 'usps_mnist'
    # opt.model = 'cifar10_stl10'
    # opt.model = 'stl10_cifar10'

    # opt.model = 'svhn_svhn'
    # opt.model = 'mnist_mnist'
    # opt.model = 'usps_usps'
    # opt.model = 'svhn_usps'
    #########################
    #### DATASET
    #########################
    modelsplit = opt.model.split('_')
    if modelsplit[0] in ('mnist', 'usps') and modelsplit[1] != 'svhn':
        n_c_in = 1  # number of color channels
    else:
        n_c_in = 3  # number of color channels

    if modelsplit[1] in ('mnist', 'usps') and modelsplit[0] != 'svhn':
        n_c_out = 1  # number of color channels
    else:
        n_c_out = 3  # number of color channels

    trainset, trainset2, testset = utils.load_data(opt=opt)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=opt.batch_size,
                                               drop_last=True,
                                               sampler=InfiniteSampler(
                                                   len(trainset)))  # model
    train_loader2 = torch.utils.data.DataLoader(trainset2,
                                                batch_size=opt.batch_size,
                                                drop_last=True,
                                                sampler=InfiniteSampler(
                                                    len(trainset2)))  # model
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=opt.batch_size,
                                              shuffle=True,
                                              drop_last=True)  # model

    n_sample = max(len(trainset), len(trainset2))
    iter_per_epoch = n_sample // opt.batch_size + 1

    src_train_iter = iter(train_loader)
    tgt_train_iter = iter(train_loader2)

    if opt.norm:
        X_min = -1  # data normalized with mean/std 0.5 -> range [-1, 1]
        X_max = 1
    else:
        X_min = trainset.data.min()
        X_max = trainset.data.max()

    # pdb.set_trace()

    #########################
    #### Model
    #########################
    if (modelsplit[0] in ('svhn', 'usps', 'cifar10', 'stl10')
            or modelsplit[1] == 'svhn'):
        # 3x32x32 -> 1x128x1x1 (before FC)
        model1 = conv9(p=opt.dropout_probability).cuda()
        model2 = conv9(p=opt.dropout_probability).cuda()
    else:
        # 1x28x28 -> 1x128x4x4 (before FC)
        model1 = conv3(p=opt.dropout_probability).cuda()
        model2 = conv3(p=opt.dropout_probability).cuda()

    dropout_mask1 = torch.randint(2, (1, 128, 1, 1), dtype=torch.float).cuda()
    # dropout_mask1 = torch.randint(2,(1,128,4,4), dtype=torch.float).cuda()

    weights_init_gaussian = weights_init('gaussian')

    for X, _ in train_loader:
        res_x = X.shape[-1]
        break

    for X, _ in train_loader2:
        res_y = X.shape[-1]
        break

    gen_st = Generator(n_hidden=n_hidden, n_resblock=n_resblock, n_ch=n_ch,
                       res=res_x, n_c_in=n_c_in, n_c_out=n_c_out).cuda()

    gen_ts = Generator(n_hidden=n_hidden, n_resblock=n_resblock, n_ch=n_ch,
                       res=res_y, n_c_in=n_c_out, n_c_out=n_c_in).cuda()

    dis_s = Discriminator(n_ch=n_ch, res=res_x, n_c_in=n_c_in).cuda()
    dis_t = Discriminator(n_ch=n_ch, res=res_y, n_c_in=n_c_out).cuda()

    gen_st.apply(weights_init_gaussian)
    gen_ts.apply(weights_init_gaussian)
    dis_s.apply(weights_init_gaussian)
    dis_t.apply(weights_init_gaussian)

    pool_size = 50
    fake_src_x_pool = ImagePool(pool_size * opt.batch_size)
    fake_tgt_x_pool = ImagePool(pool_size * opt.batch_size)

    #########################
    #### Loss
    #########################

    config2 = {
        'lr': opt.learning_rate,
        'weight_decay': opt.weight_decay,
        'betas': (0.5, 0.999)
    }

    opt_gen = torch.optim.Adam(
        chain(gen_st.parameters(), gen_ts.parameters(), model1.parameters(),
              model2.parameters()), **config2)
    opt_dis = torch.optim.Adam(chain(dis_s.parameters(), dis_t.parameters()),
                               **config2)

    loss_CE = torch.nn.CrossEntropyLoss().cuda()
    loss_KLD = torch.nn.KLDivLoss(reduction='batchmean').cuda()

    loss_LS = GANLoss(device, use_lsgan=True)

    #########################
    #### argument print
    #########################

    writer = SummaryWriter(run_dir)

    f = open(resultname, 'w')
    f.write(prompt)
    f.close()

    #########################
    #### Run
    #########################

    if os.path.isfile(opt.pretrained):
        modelpath = opt.pretrained
        print("model load..", modelpath)
        checkpoint = torch.load(modelpath,
                                map_location='cuda:{0}'.format(opt.gpu))
        dropout_mask1 = checkpoint['dropout_mask1']

    else:
        modelpath = 'model/{0}'.format(modelname)
        os.makedirs(modelpath, exist_ok=True)

        print("model train..")
        print(modelname)

        niter = 0
        epoch = 0

        while True:
            model1.train()
            model2.train()

            niter += 1
            src_x, src_y = next(src_train_iter)
            tgt_x, tgt_y = next(tgt_train_iter)

            src_x = src_x.cuda()
            src_y = src_y.cuda()
            tgt_x = tgt_x.cuda()

            fake_tgt_x = gen_st(src_x)
            fake_src_x = gen_ts(tgt_x)
            fake_back_src_x = gen_ts(fake_tgt_x)

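            # SBADA-style generator objective: beta weights the CE of the
            # target-domain classifier on source->target translations, mu the
            # CE of the source classifier on real source images, and
            # gamma/alpha the LSGAN terms for the source/target discriminators.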
            if opt.prefix == 'translation_noCE':
                loss_gen = opt.gamma * loss_LS(dis_s(fake_src_x), True)
                loss_gen += opt.alpha * loss_LS(dis_t(fake_tgt_x), True)
            else:
                loss_gen = opt.beta * loss_CE(model2(fake_tgt_x), src_y)
                loss_gen += opt.mu * loss_CE(model1(src_x), src_y)

                loss_gen += opt.gamma * loss_LS(dis_s(fake_src_x), True)
                loss_gen += opt.alpha * loss_LS(dis_t(fake_tgt_x), True)

            # detach fakes so the discriminator update does not backpropagate
            # into the generators
            loss_dis_s = opt.gamma * loss_LS(
                dis_s(fake_src_x_pool.query(fake_src_x.detach())), False)
            loss_dis_s += opt.gamma * loss_LS(dis_s(src_x), True)

            loss_dis_t = opt.alpha * loss_LS(
                dis_t(fake_tgt_x_pool.query(fake_tgt_x.detach())), False)
            loss_dis_t += opt.alpha * loss_LS(dis_t(tgt_x), True)

            loss_dis = loss_dis_s + loss_dis_t

            for optim, loss in zip([opt_dis, opt_gen], [loss_dis, loss_gen]):
                optim.zero_grad()
                loss.backward(retain_graph=True)
                optim.step()

            if niter % opt.print_delay == 0 and niter > 0:
                with torch.no_grad():
                    ##########################
                    loss_dis_s1 = opt.gamma * loss_LS(
                        dis_s(fake_src_x_pool.query(fake_src_x)), False)
                    loss_dis_s2 = opt.gamma * loss_LS(dis_s(src_x), True)
                    loss_dis_t1 = opt.alpha * loss_LS(
                        dis_t(fake_tgt_x_pool.query(fake_tgt_x)), False)
                    loss_dis_t2 = opt.alpha * loss_LS(dis_t(tgt_x), True)

                    loss_gen_s = opt.gamma * loss_LS(dis_s(fake_src_x), True)
                    loss_gen_t = opt.alpha * loss_LS(dis_t(fake_tgt_x), True)
                    loss_gen_CE_t = opt.beta * loss_CE(model2(fake_tgt_x),
                                                       src_y)
                    loss_gen_CE_s = opt.mu * loss_CE(model1(src_x), src_y)
                    ###########################

                    print('epoch {0} ({1}/{2}) '.format(
                        epoch, niter % iter_per_epoch, iter_per_epoch)
                          + 'dis_s1 {0:02.4f}, dis_s2 {1:02.4f}, '.format(
                              loss_dis_s1.item(), loss_dis_s2.item())
                          + 'dis_t1 {0:02.4f}, dis_t2 {1:02.4f}, '.format(
                              loss_dis_t1.item(), loss_dis_t2.item())
                          + 'loss_gen_s {0:02.4f}, loss_gen_t {1:02.4f}, '.format(
                              loss_gen_s.item(), loss_gen_t.item())
                          + 'loss_gen_CE_t {0:02.4f}, loss_gen_CE_s {1:02.4f}'.format(
                              loss_gen_CE_t.item(), loss_gen_CE_s.item()),
                          end='\r')

                    writer.add_scalar('dis/src', loss_dis_s.item(), niter)
                    writer.add_scalar('dis/src1', loss_dis_s1.item(), niter)
                    writer.add_scalar('dis/src2', loss_dis_s2.item(), niter)
                    writer.add_scalar('dis/tgt', loss_dis_t.item(), niter)
                    writer.add_scalar('dis/tgt1', loss_dis_t1.item(), niter)
                    writer.add_scalar('dis/tgt2', loss_dis_t2.item(), niter)
                    writer.add_scalar('gen', loss_gen.item(), niter)
                    writer.add_scalar('gen/src', loss_gen_s.item(), niter)
                    writer.add_scalar('gen/tgt', loss_gen_t.item(), niter)
                    writer.add_scalar(
                        'CE/tgt',
                        loss_CE(model2(fake_tgt_x), src_y).item(), niter)
                    writer.add_scalar('CE/src',
                                      loss_CE(model1(src_x), src_y).item(),
                                      niter)

                    # pdb.set_trace()
                    if niter % (opt.print_delay * 10) == 0:
                        data_grid = []
                        for x in [
                                src_x, fake_tgt_x, fake_back_src_x, tgt_x,
                                fake_src_x
                        ]:
                            x = x.to(torch.device('cpu'))
                            if x.size(1) == 1:
                                x = x.repeat(1, 3, 1, 1)  # grayscale2rgb
                            data_grid.append(x)
                        grid = make_grid(torch.cat(tuple(data_grid), dim=0),
                                         normalize=True,
                                         range=(X_min, X_max),
                                         nrow=opt.batch_size)  # for SVHN?
                        writer.add_image('generated_{0}'.format(opt.prefix),
                                         grid, niter)

            if niter % iter_per_epoch == 0 and niter > 0:
                with torch.no_grad():
                    epoch = niter // iter_per_epoch

                    model1.eval()
                    model2.eval()

                    avgaccuracy1 = 0
                    avgaccuracy2 = 0
                    n = 0
                    nagree = 0

                    for X, Y in test_loader:
                        n += X.size()[0]
                        X_test = X.cuda()
                        Y_test = Y.cuda()

                        prediction1 = model1(X_test)
                        predicted_classes1 = torch.argmax(prediction1, 1)
                        correct_count1 = (predicted_classes1 == Y_test)
                        avgaccuracy1 += correct_count1.float().sum()

                        prediction2 = model2(X_test)
                        predicted_classes2 = torch.argmax(prediction2, 1)
                        correct_count2 = (predicted_classes2 == Y_test)
                        avgaccuracy2 += correct_count2.float().sum()

                        # accumulate agreement over the whole test set,
                        # not only the final batch
                        agreement = (predicted_classes1 == predicted_classes2)
                        nagree += agreement.int().sum()

                    avgaccuracy1 = (avgaccuracy1 / n) * 100
                    avgaccuracy2 = (avgaccuracy2 / n) * 100

                    writer.add_scalar('accuracy/tgt', avgaccuracy1, niter)
                    writer.add_scalar('accuracy/src', avgaccuracy2, niter)
                    writer.add_scalar('agreement', (nagree / n) * 100, niter)

                    f = open(resultname, 'a')
                    f.write('epoch : {0}\n'.format(epoch))
                    f.write('\tloss_gen_s : {0:0.4f}\n'.format(
                        loss_gen_s.item()))
                    f.write('\tloss_gen_t : {0:0.4f}\n'.format(
                        loss_gen_t.item()))
                    f.write('\tloss_gen_CE_t : {0:0.4f}\n'.format(
                        loss_gen_CE_t.item()))
                    f.write('\tloss_gen_CE_s : {0:0.4f}\n'.format(
                        loss_gen_CE_s.item()))
                    f.write('\tloss_dis_s1 : {0:0.4f}\n'.format(
                        loss_dis_s1.item()))
                    f.write('\tloss_dis_t1 : {0:0.4f}\n'.format(
                        loss_dis_t1.item()))
                    f.write('\tloss_dis_s2 : {0:0.4f}\n'.format(
                        loss_dis_s2.item()))
                    f.write('\tloss_dis_t2 : {0:0.4f}\n'.format(
                        loss_dis_t2.item()))
                    f.write(
                        '\tavgaccuracy_tgt : {0:0.2f}\n'.format(avgaccuracy1))
                    f.write(
                        '\tavgaccuracy_src : {0:0.2f}\n'.format(avgaccuracy2))
                    f.write('\tagreement : {0}\n'.format(nagree))
                    f.close()

            if epoch >= opt.num_epochs:
                os.rename(run_dir, run_dir[:-8])
                break
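A minimal sketch of how main() above might be driven from the command line. The flag names mirror the opt attributes the function reads; every default here, and the module-level _PREFIX/_MODEL/_BETA/_MU fallbacks, are assumptions reconstructed from the code, not the project's actual entry point.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prefix', default='')       # '' -> falls back to _PREFIX
    parser.add_argument('--model', default='')        # e.g. 'svhn_mnist'; '' -> _MODEL
    parser.add_argument('--beta', type=float, default='')  # '' -> _BETA
    parser.add_argument('--mu', type=float, default='')    # '' -> _MU
    parser.add_argument('--norm', default=None)       # None -> _NORM
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--learning_rate', type=float, default=2e-4)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--dropout_probability', type=float, default=0.5)
    parser.add_argument('--print_delay', type=int, default=100)
    parser.add_argument('--pretrained', default='')   # optional checkpoint path
    main(parser.parse_args())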
Example #2
optimizer_G = torch.optim.Adam(chain(gen_st.parameters(), gen_ts.parameters()),
                               lr=0.0003)
optimizer_D_s = torch.optim.Adam(D_s.parameters(), lr=0.0003)
optimizer_D_t = torch.optim.Adam(D_t.parameters(), lr=0.0003)
optimizer = torch.optim.SGD(model.parameters(),
                            lr=opt.lr,
                            weight_decay=opt.weight_decay,
                            momentum=0.9)
# optimizer_ad = torch.optim.SGD(ad_net.parameters(), lr=opt.lr, weight_decay=opt.weight_decay, momentum=0.9)

### Data_load
import utils

trainset, trainset2, testset = utils.load_data(opt=opt)
train_loader = torch.utils.data.DataLoader(trainset,
                                           batch_size=opt.batch_size,
                                           drop_last=True,
                                           sampler=InfiniteSampler(
                                               len(trainset)))  # model
train_loader2 = torch.utils.data.DataLoader(trainset2,
                                            batch_size=opt.batch_size,
                                            drop_last=True,
                                            sampler=InfiniteSampler(
                                                len(trainset2)))  # model
test_loader = torch.utils.data.DataLoader(testset,
                                          batch_size=opt.batch_size,
                                          shuffle=True,
                                          drop_last=True)  # model

n_sample = max(len(trainset), len(trainset2))
iter_per_epoch = n_sample // opt.batch_size + 1

src_train_iter = iter(train_loader)
tgt_train_iter = iter(train_loader2)
prompt = '====================================\n'
prompt += run_dir + '\n'
for arg in vars(opt):
    prompt += '{0} : {1}\n'.format(arg, getattr(opt, arg))
prompt += ('====================================\n')
print(prompt, end='')

cuda = False
if torch.cuda.is_available():
    cuda = True
    torch.cuda.set_device(opt.gpu)
    device = torch.device('cuda:{0}'.format(opt.gpu))

if opt.norm:
    X_min = -1  # data normalized with mean/std 0.5 -> range [-1, 1]
    X_max = 1
else:
    X_min = trainset.data.min()
    X_max = trainset.data.max()
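Every loader in these snippets draws from InfiniteSampler, so next() on the training iterators never raises StopIteration. The project ships its own class; the following is only a plausible minimal reconstruction, not the shipped implementation.

import numpy as np
from torch.utils.data.sampler import Sampler


class InfiniteSampler(Sampler):
    """Yield shuffled indices endlessly, reshuffling after each full pass."""

    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        while True:
            for idx in np.random.permutation(self.num_samples):
                yield idx

    def __len__(self):
        return 2 ** 31  # effectively unbounded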
def experiment(exp, num_epochs, pretrain, consistency):
    config = get_config('config.yaml')
    identifier = '{:s}_ndf{:d}_ngf{:d}'.format(consistency,
                                               config['dis']['ndf'],
                                               config['gen']['ngf'])
    log_dir = 'log/{:s}/{:s}'.format(exp, identifier)
    snapshot_dir = 'snapshot/{:s}/{:s}'.format(exp, identifier)
    writer = SummaryWriter(log_dir=log_dir)
    os.makedirs(snapshot_dir, exist_ok=True)

    shutil.copy('config.yaml', '{:s}/{:s}'.format(snapshot_dir, 'config.yaml'))
    batch_size = int(config['batch_size'])
    pool_size = int(config['pool_size'])
    lr = float(config['lr'])
    weight_decay = float(config['weight_decay'])

    device = torch.device('cuda')

    src, tgt = load_source_target_datasets(exp)

    n_ch_s = src.train_X.shape[1]  # number of color channels
    n_ch_t = tgt.train_X.shape[1]  # number of color channels
    n_class = src.n_classes

    train_tfs = get_composed_transforms()
    test_tfs = get_composed_transforms()

    src_train = DADataset(src.train_X, src.train_y, train_tfs)
    src_test = DADataset(src.test_X, src.test_y, test_tfs)
    tgt_train = DADataset(tgt.train_X, tgt.train_y, train_tfs)
    tgt_train = SubsetDataset(tgt_train, range(1000))  # fix indices
    tgt_test = DADataset(tgt.test_X, tgt.test_y, test_tfs)
    del src, tgt

    n_sample = max(len(src_train), len(tgt_train))
    iter_per_epoch = n_sample // batch_size + 1

    cls_s = Classifier(n_class, n_ch_s).to(device)
    cls_t = Classifier(n_class, n_ch_t).to(device)

    if not pretrain:
        load_model(cls_s, 'snapshot/{:s}/pretrain_cls_s.tar'.format(exp))

    gen_s_t_params = {'input_nc': n_ch_s, 'output_nc': n_ch_t}
    gen_t_s_params = {'input_nc': n_ch_t, 'output_nc': n_ch_s}
    gen_s_t = define_G(**{**config['gen'], **gen_s_t_params}).to(device)
    gen_t_s = define_G(**{**config['gen'], **gen_t_s_params}).to(device)

    dis_s = define_D(**{**config['dis'], 'input_nc': n_ch_s}).to(device)
    dis_t = define_D(**{**config['dis'], 'input_nc': n_ch_t}).to(device)

    opt_config = {'lr': lr, 'weight_decay': weight_decay, 'betas': (0.5, 0.99)}
    opt_gen = Adam(
        chain(gen_s_t.parameters(), gen_t_s.parameters(), cls_s.parameters(),
              cls_t.parameters()), **opt_config)
    opt_dis = Adam(chain(dis_s.parameters(), dis_t.parameters()), **opt_config)

    calc_ls = GANLoss(device, use_lsgan=True).to(device)
    calc_ce = torch.nn.CrossEntropyLoss().to(device)
    calc_l1 = torch.nn.L1Loss().to(device)

    fake_src_x_pool = ImagePool(pool_size * batch_size)
    fake_tgt_x_pool = ImagePool(pool_size * batch_size)

    src_train_iter = iter(
        DataLoader(src_train,
                   batch_size=batch_size,
                   num_workers=4,
                   sampler=InfiniteSampler(len(src_train))))
    tgt_train_iter = iter(
        DataLoader(tgt_train,
                   batch_size=batch_size,
                   num_workers=4,
                   sampler=InfiniteSampler(len(tgt_train))))
    src_test_loader = DataLoader(src_test,
                                 batch_size=batch_size * 4,
                                 num_workers=4)
    tgt_test_loader = DataLoader(tgt_test,
                                 batch_size=batch_size * 4,
                                 num_workers=4)
    print('Training...')

    cls_s.train()
    cls_t.train()

    niter = 0

    if pretrain:
        while True:
            niter += 1
            src_x, src_y = next(src_train_iter)
            loss = calc_ce(cls_s(src_x.to(device)), src_y.to(device))
            opt_gen.zero_grad()
            loss.backward()
            opt_gen.step()

            if niter % iter_per_epoch == 0:
                epoch = niter // iter_per_epoch
                n_err = evaluate_classifier(cls_s, tgt_test_loader, device)
                print(epoch, n_err / len(tgt_test))

                # n_err = evaluate_classifier(cls_s, src_test_loader, device)
                # print(epoch, n_err / len(src_test))

                if epoch >= num_epochs:
                    save_model(cls_s,
                               '{:s}/pretrain_cls_s.tar'.format(snapshot_dir))
                    break
        exit()

    while True:
        niter += 1
        src_x, src_y = next(src_train_iter)
        tgt_x, tgt_y = next(tgt_train_iter)
        src_x, src_y = src_x.to(device), src_y.to(device)
        tgt_x, tgt_y = tgt_x.to(device), tgt_y.to(device)

        fake_tgt_x = gen_s_t(src_x)
        fake_back_src_x = gen_t_s(fake_tgt_x)
        fake_src_x = gen_t_s(tgt_x)
        fake_back_tgt_x = gen_s_t(fake_src_x)

        #################
        # discriminator #
        #################

        loss_dis_s = calc_ls(dis_s(fake_src_x_pool.query(fake_src_x.detach())),
                             False)
        loss_dis_s += calc_ls(dis_s(src_x), True)
        loss_dis_t = calc_ls(dis_t(fake_tgt_x_pool.query(fake_tgt_x.detach())),
                             False)
        loss_dis_t += calc_ls(dis_t(tgt_x), True)
        loss_dis = loss_dis_s + loss_dis_t

        ##########################
        # generator + classifier #
        ##########################

        # classification
        loss_gen_cls_s = calc_ce(cls_s(src_x), src_y)
        loss_gen_cls_t = calc_ce(cls_t(tgt_x), tgt_y)
        loss_gen_cls = loss_gen_cls_s + loss_gen_cls_t

        # augmented cycle consistency
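        # The four consistency modes trade label-space against pixel-space
        # constraints: 'augmented' applies CE to both translated and
        # back-translated images, 'relaxed' only to back-translated images,
        # 'simple' only to translated images, and 'cycle' falls back to
        # CycleGAN's plain L1 cycle consistency.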
        if consistency == 'augmented':
            loss_gen_aug_s = calc_ce(cls_s(fake_src_x), tgt_y)
            loss_gen_aug_s += calc_ce(cls_s(fake_back_src_x), src_y)
            loss_gen_aug_t = calc_ce(cls_t(fake_tgt_x), src_y)
            loss_gen_aug_t += calc_ce(cls_t(fake_back_tgt_x), tgt_y)
            loss_gen_aug = loss_gen_aug_s + loss_gen_aug_t
        elif consistency == 'relaxed':
            loss_gen_aug_s = calc_ce(cls_s(fake_back_src_x), src_y)
            loss_gen_aug_t = calc_ce(cls_t(fake_back_tgt_x), tgt_y)
            loss_gen_aug = loss_gen_aug_s + loss_gen_aug_t
        elif consistency == 'simple':
            loss_gen_aug_s = calc_ce(cls_s(fake_src_x), tgt_y)
            loss_gen_aug_t = calc_ce(cls_t(fake_tgt_x), src_y)
            loss_gen_aug = loss_gen_aug_s + loss_gen_aug_t
        elif consistency == 'cycle':
            loss_gen_aug_s = calc_l1(fake_back_src_x, src_x)
            loss_gen_aug_t = calc_l1(fake_back_tgt_x, tgt_x)
            loss_gen_aug = loss_gen_aug_s + loss_gen_aug_t
        else:
            raise NotImplementedError

        # deceive discriminator
        loss_gen_adv_s = calc_ls(dis_s(fake_src_x), True)
        loss_gen_adv_t = calc_ls(dis_t(fake_tgt_x), True)
        loss_gen_adv = loss_gen_adv_s + loss_gen_adv_t

        loss_gen = loss_gen_cls + loss_gen_aug + loss_gen_adv

        opt_dis.zero_grad()
        loss_dis.backward()
        opt_dis.step()

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if niter % 100 == 0 and niter > 0:
            writer.add_scalar('dis/src', loss_dis_s.item(), niter)
            writer.add_scalar('dis/tgt', loss_dis_t.item(), niter)
            writer.add_scalar('gen/cls_s', loss_gen_cls_s.item(), niter)
            writer.add_scalar('gen/cls_t', loss_gen_cls_t.item(), niter)
            writer.add_scalar('gen/aug_s', loss_gen_aug_s.item(), niter)
            writer.add_scalar('gen/aug_t', loss_gen_aug_t.item(), niter)
            writer.add_scalar('gen/adv_s', loss_gen_adv_s.item(), niter)
            writer.add_scalar('gen/adv_t', loss_gen_adv_t.item(), niter)

        if niter % iter_per_epoch == 0:
            epoch = niter // iter_per_epoch

            # dump an image grid every epoch
            data = []
            for x in [
                    src_x, fake_tgt_x, fake_back_src_x, tgt_x, fake_src_x,
                    fake_back_tgt_x
            ]:
                x = x.to(torch.device('cpu'))
                if x.size(1) == 1:
                    x = x.repeat(1, 3, 1, 1)  # grayscale2rgb
                data.append(x)
            grid = make_grid(torch.cat(tuple(data), dim=0),
                             nrow=16,
                             normalize=True,
                             range=(-1.0, 1.0))
            writer.add_image('generated', grid, epoch)

            n_err = evaluate_classifier(cls_t, tgt_test_loader, device)
            writer.add_scalar('err_tgt', n_err / len(tgt_test), epoch)

            if epoch % 50 == 0:
                models_dict = {
                    'cls_s': cls_s,
                    'cls_t': cls_t,
                    'dis_s': dis_s,
                    'dis_t': dis_t,
                    'gen_s_t': gen_s_t,
                    'gen_t_s': gen_t_s
                }
                filename = '{:s}/epoch{:d}.tar'.format(snapshot_dir, epoch)
                save_models_dict(models_dict, filename)

            if epoch >= num_epochs:
                break
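experiment() relies on helpers such as evaluate_classifier(), save_model() and load_model() that are not shown on this page. A hedged sketch of the evaluator, matching how its return value is divided by len(tgt_test) above:

import torch


def evaluate_classifier(cls, loader, device):
    """Return the number of misclassified samples over a test loader (sketch)."""
    was_training = cls.training
    cls.eval()
    n_err = 0
    with torch.no_grad():
        for x, y in loader:
            pred = torch.argmax(cls(x.to(device)), dim=1).cpu()
            n_err += (pred != y).sum().item()
    if was_training:
        cls.train()
    return n_err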
def experiment(exp, affine, num_epochs):
    log_dir = 'log/{:s}/sbada'.format(exp)
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)
    device = torch.device('cuda')

    config = get_config('config.yaml')

    alpha = float(config['weight']['alpha'])
    beta = float(config['weight']['beta'])
    gamma = float(config['weight']['gamma'])
    mu = float(config['weight']['mu'])
    new = float(config['weight']['new'])
    eta = 0.0
    batch_size = int(config['batch_size'])
    pool_size = int(config['pool_size'])
    lr = float(config['lr'])
    weight_decay = float(config['weight_decay'])

    src, tgt = load_source_target_datasets(exp)

    n_ch_s = src.train_X.shape[1]  # number of color channels
    n_ch_t = tgt.train_X.shape[1]  # number of color channels
    res = src.train_X.shape[-1]  # size of image
    n_classes = src.n_classes

    train_tfs = get_composed_transforms(train=True, hflip=False)
    test_tfs = get_composed_transforms(train=False, hflip=False)

    src_train = DADataset(src.train_X, src.train_y, train_tfs, affine)
    tgt_train = DADataset(tgt.train_X, None, train_tfs, affine)
    tgt_test = DADataset(tgt.test_X, tgt.test_y, test_tfs, affine)
    del src, tgt

    n_sample = max(len(src_train), len(tgt_train))
    iter_per_epoch = n_sample // batch_size + 1

    weights_init_kaiming = weights_init('kaiming')
    weights_init_gaussian = weights_init('gaussian')

    cls_s = LenetClassifier(n_classes, n_ch_s, res).to(device)
    cls_t = LenetClassifier(n_classes, n_ch_t, res).to(device)

    cls_s.apply(weights_init_kaiming)
    cls_t.apply(weights_init_kaiming)

    gen_s_t_params = {'res': res, 'n_c_in': n_ch_s, 'n_c_out': n_ch_t}
    gen_t_s_params = {'res': res, 'n_c_in': n_ch_t, 'n_c_out': n_ch_s}
    gen_s_t = Generator(**{**config['gen_init'], **gen_s_t_params}).to(device)
    gen_t_s = Generator(**{**config['gen_init'], **gen_t_s_params}).to(device)
    gen_s_t.apply(weights_init_gaussian)
    gen_t_s.apply(weights_init_gaussian)

    dis_s_params = {'res': res, 'n_c_in': n_ch_s}
    dis_t_params = {'res': res, 'n_c_in': n_ch_t}
    dis_s = Discriminator(**{**config['dis_init'], **dis_s_params}).to(device)
    dis_t = Discriminator(**{**config['dis_init'], **dis_t_params}).to(device)
    dis_s.apply(weights_init_gaussian)
    dis_t.apply(weights_init_gaussian)

    # use a separate dict for the optimizers; 'config' from config.yaml is
    # still read below (config['weight']['eta'])
    opt_config = {'lr': lr, 'weight_decay': weight_decay, 'betas': (0.5, 0.999)}
    opt_gen = Adam(
        chain(gen_s_t.parameters(), gen_t_s.parameters(), cls_s.parameters(),
              cls_t.parameters()), **opt_config)
    opt_dis = Adam(chain(dis_s.parameters(), dis_t.parameters()), **opt_config)

    calc_ls = GANLoss(device, use_lsgan=True)
    calc_ce = F.cross_entropy

    fake_src_x_pool = ImagePool(pool_size * batch_size)
    fake_tgt_x_pool = ImagePool(pool_size * batch_size)

    src_train_iter = iter(
        DataLoader(src_train,
                   batch_size=batch_size,
                   num_workers=4,
                   sampler=InfiniteSampler(len(src_train))))
    tgt_train_iter = iter(
        DataLoader(tgt_train,
                   batch_size=batch_size,
                   num_workers=4,
                   sampler=InfiniteSampler(len(tgt_train))))
    tgt_test_loader = DataLoader(tgt_test,
                                 batch_size=batch_size * 4,
                                 num_workers=4)
    print('Training...')

    cls_s.train()
    cls_t.train()

    niter = 0
    while True:
        niter += 1
        src_x, src_y = next(src_train_iter)
        tgt_x = next(tgt_train_iter)
        src_x, src_y = src_x.to(device), src_y.to(device)
        tgt_x = tgt_x.to(device)

        if niter >= num_epochs * 0.75 * iter_per_epoch:
            eta = float(config['weight']['eta'])

        fake_tgt_x = gen_s_t(src_x)
        fake_back_src_x = gen_t_s(fake_tgt_x)
        fake_src_x = gen_t_s(tgt_x)

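        # Self-labeling (eq5): pseudo-labels are the source classifier's own
        # argmax predictions on translated target images, computed without
        # gradients so the targets stay fixed; eta ramps in after 75% of
        # training (see above).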
        with torch.no_grad():
            fake_src_pseudo_y = torch.max(cls_s(fake_src_x), dim=1)[1]

        # eq2
        loss_gen = beta * calc_ce(cls_t(fake_tgt_x), src_y)
        loss_gen += mu * calc_ce(cls_s(src_x), src_y)

        # eq3
        loss_gen += gamma * calc_ls(dis_s(fake_src_x), True)
        loss_gen += alpha * calc_ls(dis_t(fake_tgt_x), True)

        # eq5
        loss_gen += eta * calc_ce(cls_s(fake_src_x), fake_src_pseudo_y)

        # eq6
        loss_gen += new * calc_ce(cls_s(fake_back_src_x), src_y)

        # do not backpropagate loss to generator
        fake_tgt_x = fake_tgt_x.detach()
        fake_src_x = fake_src_x.detach()
        fake_back_src_x = fake_back_src_x.detach()

        # eq3
        loss_dis_s = gamma * calc_ls(dis_s(fake_src_x_pool.query(fake_src_x)),
                                     False)
        loss_dis_s += gamma * calc_ls(dis_s(src_x), True)
        loss_dis_t = alpha * calc_ls(dis_t(fake_tgt_x_pool.query(fake_tgt_x)),
                                     False)
        loss_dis_t += alpha * calc_ls(dis_t(tgt_x), True)

        loss_dis = loss_dis_s + loss_dis_t

        for opt, loss in zip([opt_dis, opt_gen], [loss_dis, loss_gen]):
            opt.zero_grad()
            loss.backward(retain_graph=True)
            opt.step()

        if niter % 100 == 0 and niter > 0:
            writer.add_scalar('dis/src', loss_dis_s.item(), niter)
            writer.add_scalar('dis/tgt', loss_dis_t.item(), niter)
            writer.add_scalar('gen', loss_gen.item(), niter)

        if niter % iter_per_epoch == 0:
            epoch = niter // iter_per_epoch

            if epoch % 10 == 0:
                data = []
                for x in [
                        src_x, fake_tgt_x, fake_back_src_x, tgt_x, fake_src_x
                ]:
                    x = x.to(torch.device('cpu'))
                    if x.size(1) == 1:
                        x = x.repeat(1, 3, 1, 1)  # grayscale2rgb
                    data.append(x)
                grid = make_grid(torch.cat(tuple(data), dim=0),
                                 normalize=True,
                                 range=(-1.0, 1.0))
                writer.add_image('generated', grid, epoch)

            cls_t.eval()

            n_err = 0
            with torch.no_grad():
                for tgt_x, tgt_y in tgt_test_loader:
                    prob_y = F.softmax(cls_t(tgt_x.to(device)), dim=1)
                    pred_y = torch.max(prob_y, dim=1)[1]
                    pred_y = pred_y.to(torch.device('cpu'))
                    n_err += (pred_y != tgt_y).sum().item()

            writer.add_scalar('err_tgt', n_err / len(tgt_test), epoch)

            cls_t.train()

            if epoch % 50 == 0:
                models_dict = {
                    'cls_s': cls_s,
                    'cls_t': cls_t,
                    'dis_s': dis_s,
                    'dis_t': dis_t,
                    'gen_s_t': gen_s_t,
                    'gen_t_s': gen_t_s
                }
                filename = '{:s}/epoch{:d}.tar'.format(log_dir, epoch)
                save_models_dict(models_dict, filename)

            if epoch >= num_epochs:
                break
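Both training loops feed their discriminators through an ImagePool, the CycleGAN-style history buffer that mixes current fakes with replayed past ones to stabilize adversarial training. A minimal sketch under that assumption (the project's own class may differ in detail):

import random
import torch


class ImagePool:
    """Keep up to pool_size past fakes; query() returns a 50/50 mix of
    current and replayed images, as in the CycleGAN reference code."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        out = []
        for img in images:
            img = img.detach().unsqueeze(0)
            if len(self.images) < self.pool_size:
                # buffer not full yet: store and return the current image
                self.images.append(img)
                out.append(img)
            elif random.random() > 0.5:
                # replay an old image and replace it with the current one
                idx = random.randrange(len(self.images))
                out.append(self.images[idx].clone())
                self.images[idx] = img
            else:
                out.append(img)
        return torch.cat(out, dim=0)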