def main():
    # Create output directory
    path_output = './checkpoints/'
    if not os.path.exists(path_output):
        os.makedirs(path_output)

    # Hyperparameters (tune as needed)
    epochs = 30
    batch_size = 8
    alpha = 1  # adversarial (gradient-reversal) coefficient; overwritten each iteration by the schedule in the training loop
    gamma = 1  # weight of the domain (discriminator) loss in the total loss
    # Source and target domain names
    root = 'data/'

    source1 = 'real'
    source2 = 'sketch'
    source3 = 'infograph'
    target = 'quickdraw'

    # Dataloader
    dataset_s1 = dataset.DA(dir=root,
                            name=source1,
                            img_size=(224, 224),
                            train=True)
    dataset_s2 = dataset.DA(dir=root,
                            name=source2,
                            img_size=(224, 224),
                            train=True)
    dataset_s3 = dataset.DA(dir=root,
                            name=source3,
                            img_size=(224, 224),
                            train=True)
    dataset_t = dataset.DA(dir=root,
                           name=target,
                           img_size=(224, 224),
                           train=True)
    dataset_val = dataset.DA(dir=root,
                             name=target,
                             img_size=(224, 224),
                             train=True,
                             real_val=False)

    dataloader_s1 = DataLoader(dataset_s1,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=2)
    dataloader_s2 = DataLoader(dataset_s2,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=2)
    dataloader_s3 = DataLoader(dataset_s3,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=2)
    dataloader_t = DataLoader(dataset_t,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=2)
    dataloader_val = DataLoader(dataset_val,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=2)

    len_data = min(len(dataset_s1), len(dataset_s2), len(dataset_s3),
                   len(dataset_t))  # length of "shorter" domain
    len_dataloader = min(len(dataloader_s1), len(dataloader_s2),
                         len(dataloader_s3), len(dataloader_t))

    # Define networks
    feature_extractor = models.feature_extractor()
    classifier_1 = models.class_classifier()
    classifier_2 = models.class_classifier()
    classifier_3 = models.class_classifier()
    classifier_1.apply(weight_init)
    classifier_2.apply(weight_init)
    classifier_3.apply(weight_init)

    discriminator = models.discriminator()
    discriminator.apply(weight_init)

    if torch.cuda.is_available():
        feature_extractor = feature_extractor.cuda()
        classifier_1 = classifier_1.cuda()
        classifier_2 = classifier_2.cuda()
        classifier_3 = classifier_3.cuda()
        discriminator = discriminator.cuda()

    # Define loss
    cl_loss = nn.CrossEntropyLoss()
    disc_loss = nn.NLLLoss()

    # Optimizers (learning rates may need tuning; they are decayed manually at epochs 5 and 10)
    optimizer_features = SGD(feature_extractor.parameters(),
                             lr=0.0001,
                             momentum=0.9,
                             weight_decay=5e-4)
    optimizer_classifier = SGD(([{
        'params': classifier_1.parameters()
    }, {
        'params': classifier_2.parameters()
    }, {
        'params': classifier_3.parameters()
    }]),
                               lr=0.002,
                               momentum=0.9,
                               weight_decay=5e-4)

    optimizer_discriminator = SGD(([
        {
            'params': discriminator.parameters()
        },
    ]),
                                  lr=0.002,
                                  momentum=0.9,
                                  weight_decay=5e-4)

    # Lists
    training_loss = []
    train_loss = []
    train_class_loss = []
    train_domain_loss = []
    val_class_loss = []
    val_domain_loss = []
    acc_on_target = []
    best_acc = 0.0
    for epoch in range(epochs):
        epochTic = timeit.default_timer()
        tot_loss = 0.0
        tot_c_loss = 0.0
        tot_d_loss = 0.0
        tot_val_c_loss = 0.0
        tot_val_d_loss = 0.0
        w1_mean = 0.0
        w2_mean = 0.0
        w3_mean = 0.0
        feature_extractor.train()
        classifier_1.train(), classifier_2.train(), classifier_3.train()
        discriminator.train()
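        # Manual step decay: drop the classifier and discriminator learning rates
        # at epochs 5 and 10 (the feature-extractor LR stays at 1e-4).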
        if epoch + 1 == 5:
            optimizer_classifier = SGD(([{
                'params': classifier_1.parameters()
            }, {
                'params': classifier_2.parameters()
            }, {
                'params': classifier_3.parameters()
            }]),
                                       lr=0.001,
                                       momentum=0.9,
                                       weight_decay=5e-4)

            optimizer_discriminator = SGD(
                ([{
                    'params': discriminator.parameters()
                }]),
                lr=0.001,
                momentum=0.9,
                weight_decay=5e-4)

        if epoch + 1 == 10:
            optimizer_classifier = SGD(([{
                'params': classifier_1.parameters()
            }, {
                'params': classifier_2.parameters()
            }, {
                'params': classifier_3.parameters()
            }]),
                                       lr=0.0001,
                                       momentum=0.9,
                                       weight_decay=5e-4)

            optimizer_discriminator = SGD(
                ([{
                    'params': discriminator.parameters()
                }]),
                lr=0.0001,
                momentum=0.9,
                weight_decay=5e-4)
        print('*************************************************')
        for i, (data_1, data_2, data_3, data_t) in enumerate(
                zip(dataloader_s1, dataloader_s2, dataloader_s3,
                    dataloader_t)):

            p = float(i + epoch * len_data) / epochs / len_data

            alpha = 2. / (1. + np.exp(-10 * p)) - 1
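            # DANN-style schedule: alpha ramps smoothly from 0 towards 1 as the training
            # progress p goes from 0 to 1, gradually strengthening the adversarial signal.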

            img1, lb1 = data_1
            img2, lb2 = data_2
            img3, lb3 = data_3
            imgt, _ = data_t

            # Prepare data
            cur_batch = min(img1.shape[0], img2.shape[0], img3.shape[0],
                            imgt.shape[0])

            img1, lb1 = Variable(img1[0:cur_batch, :, :, :]).cuda(), Variable(
                lb1[0:cur_batch]).cuda()
            img2, lb2 = Variable(img2[0:cur_batch, :, :, :]).cuda(), Variable(
                lb2[0:cur_batch]).cuda()
            img3, lb3 = Variable(img3[0:cur_batch, :, :, :]).cuda(), Variable(
                lb3[0:cur_batch]).cuda()
            imgt = Variable(imgt[0:cur_batch, :, :, :]).cuda()

            # Forward
            optimizer_features.zero_grad()
            optimizer_classifier.zero_grad()
            optimizer_discriminator.zero_grad()

            # Extract Features
            ft1 = feature_extractor(img1)
            ft2 = feature_extractor(img2)
            ft3 = feature_extractor(img3)
            ft_t = feature_extractor(imgt)

            # Train the discriminator
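            # Features from the three source domains are concatenated and later labeled 0,
            # while target features get label 1; alpha is assumed to scale a
            # gradient-reversal layer inside models.discriminator.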
            ds_s1 = discriminator(torch.cat((ft1, ft2, ft3)), alpha)
            ds_t = discriminator(ft_t, alpha)

            # Class Prediction
            cl1 = classifier_1(ft1)
            cl2 = classifier_2(ft2)
            cl3 = classifier_3(ft3)

            # Compute the "discriminator loss"
            ds_label = torch.zeros(cur_batch * 3).long()
            dt_label = torch.ones(cur_batch).long()

            d_s = disc_loss(ds_s1, ds_label.cuda())
            d_t = disc_loss(ds_t, dt_label.cuda())

            # Cross entropy loss
            l1 = cl_loss(cl1, lb1)
            l2 = cl_loss(cl2, lb2)
            l3 = cl_loss(cl3, lb3)

            # Per-classifier weights: each source classifier is weighted by the inverse of
            # its classification loss, so better-fitting classifiers count more when the
            # predictions are fused at validation time.
            total_class_loss = 1 / l1 + 1 / l2 + 1 / l3
            w1 = (1 / l1) / total_class_loss
            w2 = (1 / l2) / total_class_loss
            w3 = (1 / l3) / total_class_loss
            w1_mean += w1.item()  # .item() keeps a plain float and avoids holding on to the autograd graph
            w2_mean += w2.item()
            w3_mean += w3.item()

            # total loss
            Class_loss = l1 + l2 + l3
            Domain_loss = gamma * (d_s + d_t)
            loss = Class_loss + Domain_loss

            loss.backward()
            optimizer_features.step()
            optimizer_classifier.step()
            optimizer_discriminator.step()

            tot_loss += loss.item() * cur_batch
            tot_c_loss += Class_loss.item()
            tot_d_loss += Domain_loss.item()
            # Progress indicator
            print('\rTraining... Progress: %.1f %%' %
                  (100 * (i + 1) / len_dataloader),
                  end='')

            # Save Class loss and Domain loss
            if i % 50 == 0:
                train_class_loss.append(tot_c_loss / (i + 1))
                train_domain_loss.append(tot_d_loss / (i + 1))
                train_loss.append(tot_loss / (i + 1) / cur_batch)

        tot_t_loss = tot_loss / (len_data)
        training_loss.append(tot_t_loss)
        w1_mean /= len_dataloader
        w2_mean /= len_dataloader
        w3_mean /= len_dataloader
        #print(w1_mean,w2_mean,w3_mean)

        print('\rEpoch [%d/%d], Training loss: %.4f' %
              (epoch + 1, epochs, tot_t_loss),
              end='\n')
        ####################################################################################################################
        # Compute the accuracy at the end of each epoch
        feature_extractor.eval()
        classifier_1.eval(), classifier_2.eval(), classifier_3.eval()
        discriminator.eval()

        tot_acc = 0
        with torch.no_grad():
            for i, (imgt, lbt) in enumerate(dataloader_val):

                cur_batch = imgt.shape[0]

                imgt = imgt.cuda()
                lbt = lbt.cuda()

                # Forward the test images
                ft_t = feature_extractor(imgt)

                pred1 = classifier_1(ft_t)
                pred2 = classifier_2(ft_t)
                pred3 = classifier_3(ft_t)
                val_ds_t = discriminator(ft_t, alpha)

                # Compute class loss
                val_l1 = cl_loss(pred1, lbt)
                val_l2 = cl_loss(pred2, lbt)
                val_l3 = cl_loss(pred3, lbt)
                val_CE_loss = val_l1 + val_l2 + val_l3

                # Compute domain loss
                val_dt_label = torch.ones(cur_batch).long()
                val_d_t = disc_loss(val_ds_t, val_dt_label.cuda())

                # Compute accuracy
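                # Fuse the three source classifiers with the loss-based weights averaged
                # over the training epoch, then take the argmax as the prediction.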
                output = pred1 * w1_mean + pred2 * w2_mean + pred3 * w3_mean
                _, pred = torch.max(output, dim=1)
                correct = pred.eq(lbt.data.view_as(pred))
                accuracy = torch.mean(correct.type(torch.FloatTensor))
                tot_acc += accuracy.item() * cur_batch

                # total loss
                tot_val_c_loss += val_CE_loss.item()
                tot_val_d_loss += val_d_t.item()

                # Progress indicator
                print('\rValidation... Progress: %.1f %%' %
                      (100 * (i + 1) / len(dataloader_val)),
                      end='')

                # Save validation loss
                if i % 50 == 0:
                    val_class_loss.append(tot_val_c_loss / (i + 1))
                    val_domain_loss.append(tot_val_d_loss / (i + 1))

            tot_t_acc = tot_acc / (len(dataset_val))

            # Print
            acc_on_target.append(tot_t_acc)
            print('\rEpoch [%d/%d], Accuracy on target: %.4f' %
                  (epoch + 1, epochs, tot_t_acc),
                  end='\n')

        # Save a checkpoint whenever the accuracy on the target improves
        if best_acc < tot_t_acc:
            torch.save(
                {
                    'epoch': epoch,
                    'feature_extractor': feature_extractor.state_dict(),
                    '{}_classifier'.format(source1): classifier_1.state_dict(),
                    '{}_classifier'.format(source2): classifier_2.state_dict(),
                    '{}_classifier'.format(source3): classifier_3.state_dict(),
                    'discriminator': discriminator.state_dict(),
                    'features_optimizer': optimizer_features.state_dict(),
                    'classifier_optimizer': optimizer_classifier.state_dict(),
                    'loss': training_loss,
                    '{}_weight'.format(source1): w1_mean,
                    '{}_weight'.format(source2): w2_mean,
                    '{}_weight'.format(source3): w3_mean,
                },
                os.path.join(path_output,
                             target + '-{}-deming.pth'.format(epoch)))
            print('Saved best model!')
            best_acc = tot_t_acc

        # Print elapsed time per epoch
        epochToc = timeit.default_timer()
        (t_min, t_sec) = divmod((epochToc - epochTic), 60)
        print('Elapsed time is: %d min: %d sec' % (t_min, t_sec))

        # Save training/validation loss curves and target accuracy
        pkl.dump(train_loss,
                 open('{}total_loss_{}.p'.format(path_output, target), 'wb'))
        pkl.dump(train_class_loss,
                 open('{}class_loss_{}.p'.format(path_output, target), 'wb'))
        pkl.dump(train_domain_loss,
                 open('{}domain_loss_{}.p'.format(path_output, target), 'wb'))
        pkl.dump(
            acc_on_target,
            open('{}target_accuracy_{}.p'.format(path_output, target), 'wb'))
        pkl.dump(
            val_class_loss,
            open('{}val_class_loss_{}.p'.format(path_output, target), 'wb'))
        pkl.dump(
            val_domain_loss,
            open('{}val_domain_loss_{}.p'.format(path_output, target), 'wb'))
Example #2
def main():
    # Create output directory
    path_output = './checkpoints/'
    if not os.path.exists(path_output):
        os.makedirs(path_output)

    # Hyperparameters (tune as needed)
    epochs = 50
    batch_size = 8
    alpha = 1  # weight of the momentum-loss term in the total loss
    save_interval = 10  # save a checkpoint every 10 epochs

    # Source and target domain names
    root = 'data/'

    source1 = 'sketch'
    source2 = 'quickdraw'
    source3 = 'infograph'
    target = 'real'

    # Dataloader
    dataset_s1 = dataset.DA(dir=root,
                            name=source1,
                            img_size=(224, 224),
                            train=True)
    dataset_s2 = dataset.DA(dir=root,
                            name=source2,
                            img_size=(224, 224),
                            train=True)
    dataset_s3 = dataset.DA(dir=root,
                            name=source3,
                            img_size=(224, 224),
                            train=True)

    if target == 'real':
        tmp = os.path.join(root, 'test')
        dataset_t = dataset.DA_test(dir=tmp, img_size=(224, 224))
    else:
        dataset_t = dataset.DA(dir=root,
                               name=target,
                               img_size=(224, 224),
                               train=False)

    dataloader_s1 = DataLoader(dataset_s1,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=2)
    dataloader_s2 = DataLoader(dataset_s2,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=2)
    dataloader_s3 = DataLoader(dataset_s3,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=2)
    dataloader_t = DataLoader(dataset_t,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=2)

    len_data = min(len(dataset_s1), len(dataset_s2), len(dataset_s3),
                   len(dataset_t))  # length of "shorter" domain

    # Define networks
    feature_extractor = models.feature_extractor()
    classifier_1 = models.class_classifier()
    classifier_2 = models.class_classifier()
    classifier_3 = models.class_classifier()

    # Weight initialization
    classifier_1.apply(weight_init)
    classifier_2.apply(weight_init)
    classifier_3.apply(weight_init)

    if torch.cuda.is_available():
        feature_extractor = feature_extractor.cuda()
        classifier_1 = classifier_1.cuda()
        classifier_2 = classifier_2.cuda()
        classifier_3 = classifier_3.cuda()

    # Define loss
    mom_loss = momentumLoss()
    cl_loss = nn.CrossEntropyLoss()

    # Optimizers (learning rates may need tuning)
    optimizer_features = Adam(feature_extractor.parameters(), lr=0.0001)
    optimizer_classifier = Adam(([{
        'params': classifier_1.parameters()
    }, {
        'params': classifier_2.parameters()
    }, {
        'params': classifier_3.parameters()
    }]),
                                lr=0.002)

    # Lists
    train_loss = []
    acc_on_target = []

    for epoch in range(epochs):
        tot_loss = 0.0
        feature_extractor.train()
        classifier_1.train(), classifier_2.train(), classifier_3.train()

        for i, (data_1, data_2, data_3, data_t) in enumerate(
                zip(dataloader_s1, dataloader_s2, dataloader_s3,
                    dataloader_t)):
            img1, lb1 = data_1
            img2, lb2 = data_2
            img3, lb3 = data_3
            if target == 'real':
                imgt = data_t
            else:
                imgt, _ = data_t

            # Prepare data
            cur_batch = min(img1.shape[0], img2.shape[0], img3.shape[0],
                            imgt.shape[0])

            img1, lb1 = Variable(img1[0:cur_batch, :, :, :]).cuda(), Variable(
                lb1[0:cur_batch]).cuda()
            img2, lb2 = Variable(img2[0:cur_batch, :, :, :]).cuda(), Variable(
                lb2[0:cur_batch]).cuda()
            img3, lb3 = Variable(img3[0:cur_batch, :, :, :]).cuda(), Variable(
                lb3[0:cur_batch]).cuda()
            imgt = Variable(imgt[0:cur_batch, :, :, :]).cuda()

            # Forward
            optimizer_features.zero_grad()
            optimizer_classifier.zero_grad()

            # Extract Features
            ft1 = feature_extractor(img1)
            ft2 = feature_extractor(img2)
            ft3 = feature_extractor(img3)
            ft_t = feature_extractor(imgt)

            # Class Prediction
            cl1 = classifier_1(ft1)
            cl2 = classifier_2(ft2)
            cl3 = classifier_3(ft3)

            # Compute "momentum loss"
            loss_mom = mom_loss(ft1, ft2, ft3, ft_t)

            # Cross entropy loss
            l1 = cl_loss(cl1, lb1)
            l2 = cl_loss(cl2, lb2)
            l3 = cl_loss(cl3, lb3)

            # total loss
            loss = l1 + l2 + l3 + alpha * loss_mom
            #print(loss_mom,(l1+l2+l3))
            loss.backward()
            optimizer_features.step()
            optimizer_classifier.step()

            tot_loss += loss.item() * cur_batch

        tot_t_loss = tot_loss / (len_data)

        # Print
        train_loss.append(tot_t_loss)
        print('*************************************************')
        print('Epoch [%d/%d], Training loss: %.4f' %
              (epoch + 1, epochs, tot_t_loss))
        ####################################################################################################################
        # Compute the accuracy at the end of each epoch
        if target != 'real':

            feature_extractor.eval()
            classifier_1.eval(), classifier_2.eval(), classifier_3.eval()

            tot_acc = 0
            with torch.no_grad():
                for i, (imgt, lbt) in enumerate(dataloader_t):

                    cur_batch = imgt.shape[0]

                    imgt = imgt.cuda()
                    lbt = lbt.cuda()

                    # Forward the test images
                    ft_t = feature_extractor(imgt)

                    pred1 = classifier_1(ft_t)
                    pred2 = classifier_2(ft_t)
                    pred3 = classifier_3(ft_t)

                    # Compute accuracy
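                    # Uniform ensemble: average the three classifiers' logits before the argmax.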
                    output = torch.mean(torch.stack((pred1, pred2, pred3)), 0)
                    _, pred = torch.max(output, dim=1)
                    correct = pred.eq(lbt.data.view_as(pred))
                    accuracy = torch.mean(correct.type(torch.FloatTensor))
                    tot_acc += accuracy.item() * cur_batch

            tot_t_acc = tot_acc / (len(dataset_t))

            # Print
            acc_on_target.append(tot_t_acc)
            print('Epoch [%d/%d], Accuracy on target: %.4f' %
                  (epoch + 1, epochs, tot_t_acc))

        # Save every save_interval
        if epoch % save_interval == 0 or epoch == epochs - 1:
            torch.save(
                {
                    'epoch': epoch,
                    'feature_extractor': feature_extractor.state_dict(),
                    '{}_classifier'.format(source1): classifier_1.state_dict(),
                    '{}_classifier'.format(source2): classifier_2.state_dict(),
                    '{}_classifier'.format(source3): classifier_3.state_dict(),
                    'features_optimizer': optimizer_features.state_dict(),
                    'classifier_optimizer': optimizer_classifier.state_dict(),
                    'loss': tot_loss,
                }, os.path.join(path_output, target + '-{}.pth'.format(epoch)))

    # Save training loss and accuracy on target (if not 'real')
    pkl.dump(train_loss, open('{}train_loss.p'.format(path_output), 'wb'))
    if target != 'real':
        pkl.dump(acc_on_target,
                 open('{}target_accuracy.p'.format(path_output), 'wb'))
Example #3
def train(opt):
    from tensorboardX import SummaryWriter
    writer = SummaryWriter(path_output)

    source1, source2, source3, target = taskSelect(opt.target)

    dataset_s1 = dataset.DA(dir=root, name=source1, img_size=(224, 224), train=True)
    dataset_s2 = dataset.DA(dir=root, name=source2, img_size=(224, 224), train=True)
    dataset_s3 = dataset.DA(dir=root, name=source3, img_size=(224, 224), train=True)
    dataset_t = dataset.DA(dir=root, name=target, img_size=(224, 224), train=True)
    dataset_tt = dataset.DA(dir=root, name=target, img_size=(224,224), train=False,real_val=False)

    dataloader_s1 = DataLoader(dataset_s1, batch_size=opt.bs, shuffle=True, num_workers=2)
    dataloader_s2 = DataLoader(dataset_s2, batch_size=opt.bs, shuffle=True, num_workers=2)
    dataloader_s3 = DataLoader(dataset_s3, batch_size=opt.bs, shuffle=True, num_workers=2)
    dataloader_t = DataLoader(dataset_t, batch_size=opt.bs, shuffle=True, num_workers=2)
    dataloader_tt = DataLoader(dataset_tt, batch_size=opt.bs, shuffle=False, num_workers=2)


    # dataset_s1 = dataset.DA(dir=root, name=source1, img_size=(224, 224), train=True)
    # dataset_s2 = dataset.DA(dir=root, name=source2, img_size=(224, 224), train=True)
    # dataset_s3 = dataset.DA(dir=root, name=source3, img_size=(224, 224), train=True)
    # dataset_t = dataset.DA(dir=root, name=target, img_size=(224, 224), train=True)

    # if target == 'real':
    #     tmp = os.path.join(root, 'test')
    #     dataset_tt = dataset.DA_test(dir=tmp, img_size=(224,224))
    # else:
    #     dataset_tt = dataset.DA(dir=root, name=target, img_size=(224, 224), train=False)

    # dataloader_s1 = DataLoader(dataset_s1, batch_size=opt.bs, shuffle=True, num_workers=2)
    # dataloader_s2 = DataLoader(dataset_s2, batch_size=opt.bs, shuffle=True, num_workers=2)
    # dataloader_s3 = DataLoader(dataset_s3, batch_size=opt.bs, shuffle=True, num_workers=2)
    # dataloader_t = DataLoader(dataset_t, batch_size=opt.bs, shuffle=True, num_workers=2)
    # dataloader_tt = DataLoader(dataset_tt, batch_size=opt.bs, shuffle=False, num_workers=2)


    len_data = min(len(dataset_s1), len(dataset_s2), len(dataset_s3), len(dataset_t))           # length of "shorter" domain
    len_bs = min(len(dataloader_s1), len(dataloader_s2), len(dataloader_s3), len(dataloader_t))

    # Define networks
    feature_extractor = models.feature_extractor()
    classifier_1 = models.class_classifier()
    classifier_2 = models.class_classifier()
    classifier_3 = models.class_classifier()
    classifier_1_ = models.class_classifier()
    classifier_2_ = models.class_classifier()
    classifier_3_ = models.class_classifier()

    # if torch.cuda.is_available():
    feature_extractor = feature_extractor.to(device)
    classifier_1 = classifier_1.to(device).apply(weight_init)
    classifier_2 = classifier_2.to(device).apply(weight_init)
    classifier_3 = classifier_3.to(device).apply(weight_init)
    classifier_1_ = classifier_1_.to(device).apply(weight_init)
    classifier_2_ = classifier_2_.to(device).apply(weight_init)
    classifier_3_ = classifier_3_.to(device).apply(weight_init)

    # Define loss
    mom_loss = momentumLoss()
    cl_loss = nn.CrossEntropyLoss()
    disc_loss = discrepancyLoss()
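    # discrepancyLoss is assumed to measure the disagreement between two classifiers'
    # outputs on the same features (e.g. an L1 distance between softmax outputs, as in MCD).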

    # Optimizers
    # Change the LR
    optimizer_features = SGD(feature_extractor.parameters(), lr=0.0001,momentum=0.9,weight_decay=5e-4)
    optimizer_classifier = SGD(([{'params': classifier_1.parameters()},
                    {'params': classifier_2.parameters()},
                    {'params': classifier_3.parameters()}]), lr=0.002,momentum=0.9,weight_decay=5e-4)

    optimizer_classifier_ = SGD(([{'params': classifier_1_.parameters()},
                    {'params': classifier_2_.parameters()},
                    {'params': classifier_3_.parameters()}]), lr=0.002,momentum=0.9,weight_decay=5e-4)

    # optimizer_features = SGD(feature_extractor.parameters(), lr=0.0001)
    # optimizer_classifier = Adam(([{'params': classifier_1.parameters()},
    #                    {'params': classifier_2.parameters()},
    #                    {'params': classifier_3.parameters()}]), lr=0.002)
    # optimizer_classifier_ = Adam(([{'params': classifier_1_.parameters()},
    #                    {'params': classifier_2_.parameters()},
    #                    {'params': classifier_3_.parameters()}]), lr=0.002)

    if opt.pretrain is not None:
        state = torch.load(opt.pretrain)
        feature_extractor.load_state_dict(state['feature_extractor'])
        classifier_1.load_state_dict(state['{}_classifier'.format(source1)])
        classifier_2.load_state_dict(state['{}_classifier'.format(source2)])
        classifier_3.load_state_dict(state['{}_classifier'.format(source3)])
        classifier_1_.load_state_dict(state['{}_classifier_'.format(source1)])
        classifier_2_.load_state_dict(state['{}_classifier_'.format(source2)])
        classifier_3_.load_state_dict(state['{}_classifier_'.format(source3)])

    # Lists
    train_loss = []
    acc_on_target = []

    tot_loss, tot_clf_loss, tot_mom_loss, tot_s2_loss, tot_s3_loss = 0.0, 0.0, 0.0, 0.0, 0.0
    n_samples, iteration = 0, 0
    tot_correct = [0, 0, 0, 0, 0, 0]
    saved_time = time.time()
    feature_extractor.train()
    classifier_1.train(), classifier_2.train(), classifier_3.train()
    classifier_1_.train(), classifier_2_.train(), classifier_3_.train()

    for epoch in range(opt.ep):

        if epoch+1 == 5:
            optimizer_classifier = SGD(([{'params': classifier_1.parameters()},
                    {'params': classifier_2.parameters()},
                    {'params': classifier_3.parameters()}]), lr=0.001,momentum=0.9,weight_decay=5e-4)

            optimizer_classifier_ = SGD(([{'params': classifier_1_.parameters()},
                            {'params': classifier_2_.parameters()},
                            {'params': classifier_3_.parameters()}]), lr=0.001,momentum=0.9,weight_decay=5e-4)

        if epoch+1 == 10:
            optimizer_classifier = SGD(([{'params': classifier_1.parameters()},
                    {'params': classifier_2.parameters()},
                    {'params': classifier_3.parameters()}]), lr=0.0001,momentum=0.9,weight_decay=5e-4)

            optimizer_classifier_ = SGD(([{'params': classifier_1_.parameters()},
                            {'params': classifier_2_.parameters()},
                            {'params': classifier_3_.parameters()}]), lr=0.0001,momentum=0.9,weight_decay=5e-4)


        for i, (data_1, data_2, data_3, data_t) in enumerate(zip(dataloader_s1, dataloader_s2, dataloader_s3, dataloader_t)):

            img1, lb1 = data_1
            img2, lb2 = data_2
            img3, lb3 = data_3
            imgt, _ = data_t

            # Prepare data
            cur_batch = min(img1.shape[0], img2.shape[0], img3.shape[0], imgt.shape[0])
            # print(i, cur_batch)
            img1, lb1 = Variable(img1[0:cur_batch,:,:,:]).to(device), Variable(lb1[0:cur_batch]).to(device)
            img2, lb2 = Variable(img2[0:cur_batch,:,:,:]).to(device), Variable(lb2[0:cur_batch]).to(device)
            img3, lb3 = Variable(img3[0:cur_batch,:,:,:]).to(device), Variable(lb3[0:cur_batch]).to(device)
            imgt = Variable(imgt[0:cur_batch,:,:,:]).to(device)

            ### STEP 1 ### train G and C pairs
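            # Three-step scheme in the spirit of Maximum Classifier Discrepancy (MCD):
            #   Step 1: train the feature extractor G and both classifier sets on the
            #           source classification losses plus the momentum loss.
            #   Step 2: freeze G and train the classifier pairs to stay accurate on the
            #           sources while maximizing their disagreement on the target
            #           (hence the minus signs on the discrepancy terms).
            #   Step 3: freeze the classifiers and train G to minimize that disagreement.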
            # Forward
            optimizer_features.zero_grad()
            optimizer_classifier.zero_grad()
            optimizer_classifier_.zero_grad()

            # Extract Features
            ft1 = feature_extractor(img1)
            ft2 = feature_extractor(img2)
            ft3 = feature_extractor(img3)
            ft_t = feature_extractor(imgt)

            # Class Prediction [bs, 345]
            cl1, cl1_ = classifier_1(ft1), classifier_1_(ft1)
            cl2, cl2_ = classifier_2(ft2), classifier_2_(ft2)
            cl3, cl3_ = classifier_3(ft3), classifier_3_(ft3)

            # Compute "momentum loss"
            loss_mom = mom_loss(ft1, ft2, ft3, ft_t)

            # Cross entropy loss
            l1, l1_ = cl_loss(cl1, lb1), cl_loss(cl1_, lb1)
            l2, l2_ = cl_loss(cl2, lb2), cl_loss(cl2_, lb2)
            l3, l3_ = cl_loss(cl3, lb3), cl_loss(cl3_, lb3)
            # total loss
            s1loss = l1 + l2 + l3 + l1_ + l2_ + l3_ + opt.alpha * loss_mom

            s1loss.backward()
            optimizer_features.step()
            optimizer_classifier.step()
            optimizer_classifier_.step()

            ### STEP 2 ### fix G, and train C pairs 
            optimizer_classifier.zero_grad()
            optimizer_classifier_.zero_grad()

            # Class Prediction on each src domain
            cl1, cl1_ = classifier_1(ft1.detach()), classifier_1_(ft1.detach())
            cl2, cl2_ = classifier_2(ft2.detach()), classifier_2_(ft2.detach())
            cl3, cl3_ = classifier_3(ft3.detach()), classifier_3_(ft3.detach())

            # discrepancy on tgt domain
            clt1, clt1_ = classifier_1(ft_t.detach()), classifier_1_(ft_t.detach())
            clt2, clt2_ = classifier_2(ft_t.detach()), classifier_2_(ft_t.detach())
            clt3, clt3_ = classifier_3(ft_t.detach()), classifier_3_(ft_t.detach())

            # classification loss
            l1, l1_ = cl_loss(cl1, lb1), cl_loss(cl1_, lb1)
            l2, l2_ = cl_loss(cl2, lb2), cl_loss(cl2_, lb2)
            l3, l3_ = cl_loss(cl3, lb3), cl_loss(cl3_, lb3)

            # print(clt1.shape)
            dl1 = disc_loss(clt1, clt1_)
            dl2 = disc_loss(clt2, clt2_)
            dl3 = disc_loss(clt3, clt3_)
            # print(dl1, dl2, dl3)

            # backward
            s2loss = l1 + l2 + l3 + l1_ + l2_ + l3_ - dl1 - dl2 - dl3
            s2loss.backward()
            optimizer_classifier.step()
            optimizer_classifier_.step()

            ### STEP 3 ### fix C pairs, train G
            optimizer_features.zero_grad()

            ft_t = feature_extractor(imgt)
            clt1, clt1_ = classifier_1(ft_t), classifier_1_(ft_t)
            clt2, clt2_ = classifier_2(ft_t), classifier_2_(ft_t)
            clt3, clt3_ = classifier_3(ft_t), classifier_3_(ft_t)

            dl1 = disc_loss(clt1, clt1_)
            dl2 = disc_loss(clt2, clt2_)
            dl3 = disc_loss(clt3, clt3_)

            s3loss = dl1 + dl2 + dl3
            s3loss.backward()
            optimizer_features.step()
            


            pred = torch.stack((cl1, cl2, cl3, cl1_, cl2_, cl3_), 0) # [6, bs, 345]
            _, pred = torch.max(pred, dim = 2) # [6, bs]
            gt = torch.stack((lb1, lb2, lb3, lb1, lb2, lb3), 0) # [6, bs]
            correct = pred.eq(gt.data)
            correct = torch.mean(correct.type(torch.FloatTensor), dim = 1).cpu().numpy()

            tot_loss += s1loss.item() * cur_batch
            tot_clf_loss += (s1loss.item() - opt.alpha * loss_mom.item()) * cur_batch
            tot_s2_loss += s2loss.item() * cur_batch
            tot_s3_loss += s3loss.item() * cur_batch
            tot_mom_loss += loss_mom.item() * cur_batch
            # accumulate element-wise (a bare `+=` would extend the list with the array's elements)
            tot_correct = [c + a * cur_batch for c, a in zip(tot_correct, correct)]
            n_samples += cur_batch

            # print(cur_batch)
            if iteration % opt.log_interval == 0:
                current_time = time.time()
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tClfLoss: {:.4f}\tMMLoss: {:.4f}\t \
                    S2Loss: {:.4f}\tS3Loss: {:.4f}\t \
                    Accu: {:.4f}\\{:.4f}\\{:.4f}\\{:.4f}\\{:.4f}\\{:.4f}\tTime: {:.3f}'.format(\
                        epoch, i * opt.bs, len_data, 100. * i / len_bs, \
                        tot_clf_loss / n_samples, 
                        tot_mom_loss / n_samples,
                        tot_s2_loss / n_samples,
                        tot_s3_loss / n_samples,
                        tot_correct[0] / n_samples,
                        tot_correct[1] / n_samples,
                        tot_correct[2] / n_samples,
                        tot_correct[3] / n_samples,
                        tot_correct[4] / n_samples,
                        tot_correct[5] / n_samples,
                        current_time - saved_time))
                writer.add_scalar('Train/ClfLoss', tot_clf_loss / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/MMLoss', tot_mom_loss / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/s2Loss', tot_s2_loss / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/s3Loss', tot_s3_loss / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu0', tot_correct[0] / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu1', tot_correct[1] / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu2', tot_correct[2] / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu0_', tot_correct[3] / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu1_', tot_correct[4] / n_samples, iteration * opt.bs)
                writer.add_scalar('Train/Accu2_', tot_correct[5] / n_samples, iteration * opt.bs)

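                # Ensemble weights for evaluation: proportional to each classifier's running
                # training accuracy, falling back to uniform 1/6 weights if all are zero.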
                saved_weight = torch.FloatTensor([tot_correct[0], tot_correct[1], tot_correct[2], tot_correct[3], tot_correct[4], tot_correct[5]]).to(device)
                if torch.sum(saved_weight) == 0.:
                    saved_weight = torch.FloatTensor(6).to(device).fill_(1)/6.
                else:
                    saved_weight = saved_weight/torch.sum(saved_weight)
                
                saved_time = time.time()
                tot_clf_loss, tot_mom_loss, tot_correct, n_samples = 0, 0, [0, 0, 0, 0, 0, 0], 0
                tot_s2_loss, tot_s3_loss = 0, 0
                train_loss.append(tot_loss)

            # evaluation and save
            if iteration % opt.eval_interval == 0 and target != 'real':
                print('weight = ', saved_weight.cpu().numpy())
                evalacc = eval(saved_weight, feature_extractor, classifier_1_, classifier_2_, classifier_3_,
                classifier_1, classifier_2, classifier_3, dataloader_tt)
                writer.add_scalar('Test/Accu', evalacc, iteration * opt.bs)
                acc_on_target.append(evalacc)
                print('Eval Acc = {:.2f}\n'.format(evalacc*100))
                torch.save({
                        'epoch': epoch,
                        'feature_extractor': feature_extractor.state_dict(),
                        '{}_classifier'.format(source1): classifier_1.state_dict(),
                        '{}_classifier'.format(source2): classifier_2.state_dict(),
                        '{}_classifier'.format(source3): classifier_3.state_dict(),
                        '{}_classifier_'.format(source1): classifier_1_.state_dict(),
                        '{}_classifier_'.format(source2): classifier_2_.state_dict(),
                        '{}_classifier_'.format(source3): classifier_3_.state_dict(),
                        'features_optimizer': optimizer_features.state_dict(),
                        'classifier_optimizer': optimizer_classifier.state_dict(),
                        'loss': tot_loss,
                        'saved_weight': saved_weight
               }, os.path.join(path_output, target + '-{}-{:.2f}.pth'.format(epoch, evalacc*100)))

            iteration += 1

    pkl.dump(train_loss, open('{}train_loss.p'.format(path_output), 'wb'))
    if target != 'real':
        pkl.dump(acc_on_target, open('{}target_accuracy.p'.format(path_output), 'wb'))
Example #4
def main():
    root = 'data/'

    source1 = 'real'
    source2 = 'infograph'
    source3 = 'quickdraw'
    target = 'sketch'
    adaptive_weight = True

    if not target == 'real':
        dataset_t = DA(dir=root, name=target, img_size=(224, 224), train=False)
    else:
        dataset_t = test_dataset(dir='data/test', img_size=(224, 224))

    dataloader_t = DataLoader(dataset_t,
                              batch_size=64,
                              shuffle=False,
                              num_workers=8)

    path = 'checkpoints/infograph-0-deming.pth'  # checkpoint to evaluate; change as needed (e.g. 'checkpoints/sketch-30.pth')

    feature_extractor = models.feature_extractor()
    classifier_1 = models.class_classifier()
    classifier_2 = models.class_classifier()
    classifier_3 = models.class_classifier()

    state = torch.load(path)
    print(len(state))
    print(state.keys())
    print()

    feature_extractor.load_state_dict(state['feature_extractor'])
    classifier_1.load_state_dict(state['{}_classifier'.format(source1)])
    classifier_2.load_state_dict(state['{}_classifier'.format(source2)])
    classifier_3.load_state_dict(state['{}_classifier'.format(source3)])

    if adaptive_weight:
        w1_mean = state['{}_weight'.format(source1)]
        w2_mean = state['{}_weight'.format(source2)]
        w3_mean = state['{}_weight'.format(source3)]
    else:
        w1_mean = 1 / 3
        w2_mean = 1 / 3
        w3_mean = 1 / 3

    if torch.cuda.is_available():
        feature_extractor = feature_extractor.cuda()
        classifier_1 = classifier_1.cuda()
        classifier_2 = classifier_2.cuda()
        classifier_3 = classifier_3.cuda()

    feature_extractor.eval()
    classifier_1.eval(), classifier_2.eval(), classifier_3.eval()

    ans = open('{}_pred.csv'.format(target), 'w')
    ans.write('image_name,label\n')
    m = nn.Softmax(1)
    with torch.no_grad():
        for idx, (img, name) in enumerate(dataloader_t):
            if torch.cuda.is_available():
                img = img.cuda()

            ft_t = feature_extractor(img)

            pred1 = classifier_1(ft_t)
            pred2 = classifier_2(ft_t)
            pred3 = classifier_3(ft_t)

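            # Weighted ensemble of the three source classifiers, using the adaptive weights
            # stored in the checkpoint (or uniform 1/3 weights if adaptive_weight is False).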
            pred = (pred1 * w1_mean + pred2 * w2_mean + pred3 * w3_mean)
            pred = m(pred)
            #embed()
            _, pred = torch.max(pred, dim=1)

            print('\r Predicting... Progress: %.1f %%' %
                  (100 * (idx + 1) / len(dataloader_t)),
                  end='')

            for i in range(len(name)):
                ans.write('{},{}\n'.format(os.path.join('test/', name[i]),
                                           pred[i].item()))  # .item() writes a plain integer label

    ans.close()