Example #1
def check():
    #-------------------------
    # Construct the DANN model
    #-------------------------
    feature_extractor = Feature_Extractor().to(DEVICE)
    class_classifier = Class_Classifier().to(DEVICE)
    domain_classifier = Domain_Classifier().to(DEVICE)

    class_criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
    domain_criterion = torch.nn.CrossEntropyLoss().to(DEVICE)

    DOMAINS = [("usps", "mnistm"), ("mnistm", "svhn"), ("svhn", "usps")]

    for SOURCE, TARGET in DOMAINS:
        if TARGET != opt.target: continue

        # Read the DANN model
        feature_extractor, class_classifier, domain_classifier = utils.loadDANN(
            opt.model, feature_extractor, class_classifier, domain_classifier)

        # Create Dataloader
        source_black = (SOURCE == 'usps')
        target_black = (TARGET == 'usps')
        source_test_set = dataset.NumberClassify(
            "./hw3_data/digits",
            SOURCE,
            train=False,
            black=source_black,
            transform=transforms.ToTensor())
        target_train_set = dataset.NumberClassify(
            "./hw3_data/digits",
            TARGET,
            train=True,
            black=target_black,
            transform=transforms.ToTensor())
        target_test_set = dataset.NumberClassify(
            "./hw3_data/digits",
            TARGET,
            train=False,
            black=target_black,
            transform=transforms.ToTensor())
        print("Source_test: \t{}, {}".format(SOURCE, len(source_test_set)))
        print("Target_train: \t{}, {}".format(TARGET, len(target_train_set)))
        print("Target_test: \t{}, {}".format(TARGET, len(target_test_set)))
        source_test_loader = DataLoader(source_test_set,
                                        batch_size=opt.batch_size,
                                        shuffle=False,
                                        num_workers=opt.threads)
        target_train_loader = DataLoader(target_train_set,
                                         batch_size=opt.batch_size,
                                         shuffle=False,
                                         num_workers=opt.threads)
        target_test_loader = DataLoader(target_test_set,
                                        batch_size=opt.batch_size,
                                        shuffle=False,
                                        num_workers=opt.threads)

        # Predict
        class_acc, class_loss, domain_acc, domain_loss = val(
            feature_extractor, class_classifier, domain_classifier,
            source_test_loader, 0, class_criterion, domain_criterion)
        print("{}_Test: ".format(SOURCE))
        print(
            "[class_acc: {:.2f}% ] [class_loss: {:.4f}] [domain_acc: {:.2f} %] [domain_loss: {:.4f}]"
            .format(100 * class_acc, class_loss, 100 * domain_acc,
                    domain_loss))

        class_acc, class_loss, domain_acc, domain_loss = val(
            feature_extractor, class_classifier, domain_classifier,
            target_train_loader, 1, class_criterion, domain_criterion)
        print("{}_Train: ".format(TARGET))
        print(
            "[class_acc: {:.2f}% ] [class_loss: {:.4f}] [domain_acc: {:.2f} %] [domain_loss: {:.4f}]"
            .format(100 * class_acc, class_loss, 100 * domain_acc,
                    domain_loss))

        class_acc, class_loss, domain_acc, domain_loss = val(
            feature_extractor, class_classifier, domain_classifier,
            target_test_loader, 1, class_criterion, domain_criterion)
        print("{}_Test: ".format(TARGET))
        print(
            "[class_acc: {:.2f}% ] [class_loss: {:.4f}] [domain_acc: {:.2f} %] [domain_loss: {:.4f}]"
            .format(100 * class_acc, class_loss, 100 * domain_acc,
                    domain_loss))
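
Example 1 relies on a val helper that is not defined in this snippet. Below is a minimal sketch of what such a helper could look like, assuming the three DANN sub-networks above, loaders that yield (img, label, _) tuples, and a domain flag of 0 for source and 1 for target batches; the repo's actual predict.val may differ.

def val(feature_extractor, class_classifier, domain_classifier,
        loader, domain, class_criterion, domain_criterion):
    """Evaluate class/domain accuracy and loss on one loader (sketch)."""
    feature_extractor.eval()
    class_classifier.eval()
    domain_classifier.eval()

    class_correct = domain_correct = total = 0
    class_loss_sum = domain_loss_sum = 0.0

    with torch.no_grad():
        for img, label, _ in loader:
            img = img.to(DEVICE)
            label = label.type(torch.long).view(-1).to(DEVICE)
            domain_label = torch.full((img.size(0),), domain,
                                      dtype=torch.long, device=DEVICE)

            feature = feature_extractor(img).view(-1, 128 * 7 * 7)
            class_out = class_classifier(feature)
            domain_out = domain_classifier(feature, 1.0)

            # Accumulate per-sample sums so unequal batch sizes don't skew the averages.
            class_loss_sum += class_criterion(class_out, label).item() * img.size(0)
            domain_loss_sum += domain_criterion(domain_out, domain_label).item() * img.size(0)
            class_correct += (class_out.argmax(dim=1) == label).sum().item()
            domain_correct += (domain_out.argmax(dim=1) == domain_label).sum().item()
            total += img.size(0)

    return (class_correct / total, class_loss_sum / total,
            domain_correct / total, domain_loss_sum / total)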
Example #2
def dann_performance_test(source, target, epochs, threshold, lr, weight_decay):
    #------------------------------------------------------
    # Create Model, optimizer, scheduler, and loss function
    #------------------------------------------------------
    feature_extractor = Feature_Extractor().to(DEVICE)
    class_classifier  = Class_Classifier().to(DEVICE)
    domain_classifier = Domain_Classifier().to(DEVICE)

    optimizer = optim.Adam([
        {'params': feature_extractor.parameters()},
        {'params': class_classifier.parameters()},
        {'params': domain_classifier.parameters()}
        ], lr=lr, betas=(opt.b1, opt.b2), weight_decay=weight_decay
    )
    
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[], gamma=0.1)
    
    class_criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
    domain_criterion = torch.nn.CrossEntropyLoss().to(DEVICE)

    #------------------
    # Create Dataloader
    #------------------
    source_black = (source == 'usps')
    target_black = (target == 'usps')
    source_train_set = dataset.NumberClassify("./hw3_data/digits", source, train=True, black=source_black, transform=transforms.ToTensor())
    target_train_set = dataset.NumberClassify("./hw3_data/digits", target, train=True, black=target_black, transform=transforms.ToTensor())
    source_test_set  = dataset.NumberClassify("./hw3_data/digits", source, train=False, black=source_black, transform=transforms.ToTensor())
    target_test_set  = dataset.NumberClassify("./hw3_data/digits", target, train=False, black=target_black, transform=transforms.ToTensor())
    print("Source_train: \t{}, {}".format(source, len(source_train_set)))
    print("Source_test: \t{}, {}".format(source, len(source_test_set)))
    print("Target_train: \t{}, {}".format(target, len(target_train_set)))
    print("Target_test: \t{}, {}".format(target, len(target_test_set)))
    
    source_train_loader = DataLoader(source_train_set, batch_size=opt.batch_size, shuffle=True, num_workers=opt.threads)
    source_test_loader  = DataLoader(source_test_set, batch_size=opt.batch_size, shuffle=False, num_workers=opt.threads)
    target_train_loader = DataLoader(target_train_set, batch_size=opt.batch_size, shuffle=True, num_workers=opt.threads)
    target_test_loader  = DataLoader(target_test_set, batch_size=opt.batch_size, shuffle=False, num_workers=opt.threads)
    
    for epoch in range(1, epochs + 1):
        scheduler.step()

        loss = []
        feature_extractor.train()
        class_classifier.train()
        domain_classifier.train()

        for index, (source_batch, target_batch) in enumerate(zip(source_train_loader, target_train_loader), 1):
            source_img, source_label, _ = source_batch
            target_img, target_label, _ = target_batch
            batch_size_src, batch_size_tgt = source_img.shape[0], target_img.shape[0]

            source_img   = source_img.to(DEVICE)
            source_label = source_label.type(torch.long).view(-1).to(DEVICE)
            target_img   = target_img.to(DEVICE)
            target_label = target_label.type(torch.long).view(-1).to(DEVICE)

            source_domain_labels = torch.zeros(batch_size_src, dtype=torch.long, device=DEVICE)
            target_domain_labels = torch.ones(batch_size_tgt, dtype=torch.long, device=DEVICE)
            
            constant = opt.alpha

            optimizer.zero_grad()

            source_feature = feature_extractor(source_img).view(-1, 128 * 7 * 7)
            target_feature = feature_extractor(target_img).view(-1, 128 * 7 * 7)
            
            source_class_predict  = class_classifier(source_feature)
            source_domain_predict = domain_classifier(source_feature, constant)
            target_domain_predict = domain_classifier(target_feature, constant)

            #---------------------------------------
            # Compute the loss
            # For the case of unsupervised learning:
            #   When source domain img:
            #     loss = class_loss + domain_loss
            #   When target domain img:
            #     loss = domain_loss
            # 
            #   Needs to maximize the domain_loss 
            #------------------------------------
            # 1. class loss
            class_loss = class_criterion(source_class_predict, source_label)
            
            # 2. Domain loss
            # print("Source_domain_predict.shape: \t{}".format(source_domain_predict.shape))
            # print("Source_domain_labels.shape: \t{}".format(source_domain_labels.shape))
            source_domain_loss = domain_criterion(source_domain_predict, source_domain_labels)
            target_domain_loss = domain_criterion(target_domain_predict, target_domain_labels)
            domain_loss = target_domain_loss + source_domain_loss

            # Final loss
            # loss = class_loss + constant * domain_loss
            # loss = constant * domain_loss
            # loss.backward()
            # Warm up on the classification loss for 5 epochs, then switch to
            # the (reversed) domain loss for adversarial training.
            if epoch <= 5:
                class_loss.backward()
            else:
                domain_loss.backward()
            optimizer.step()

            source_class_predict = source_class_predict.cpu().detach().numpy()
            source_label = source_label.cpu().detach().numpy()

            source_acc = np.mean(np.argmax(source_class_predict, axis=1) == source_label)

            if index % opt.log_interval == 0:
                print("[Epoch %d] [ %d/%d ] [src_acc: %d%%] [loss_Y: %f] [loss_D: %f]" % (
                        epoch, index, len(source_train_loader), 100 * source_acc, class_loss.item(), constant * domain_loss.item()))
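
The constant passed to Domain_Classifier in Example 2 is the DANN gradient-reversal coefficient: the domain classifier minimizes the domain loss, while the reversed gradient drives the feature extractor to maximize it (the "Needs to maximize the domain_loss" note above). The repo's layer is not shown; here is a minimal sketch of a gradient reversal layer, assuming Domain_Classifier.forward(feature, constant) applies it to its input.

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by -constant backward."""
    @staticmethod
    def forward(ctx, x, constant):
        ctx.constant = constant
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the gradient flowing back to the feature extractor.
        return grad_output.neg() * ctx.constant, None


def grad_reverse(x, constant=1.0):
    return GradReverse.apply(x, constant)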
Example #3
def domain_adaptation(source, target, epochs, threshold, lr, weight_decay):
    """ 
    Using DANN framework to train the network. 
    
    Parameter
    ---------
    source, target : str

    epochs : int

    threshold : float
        Maximum accuracy of pervious epochs

    lr : float
        Learning Rate 

    weight_decay : float
        Weight Regularization

    Return
    ------
    feature_extractor, class_classifier, domain_classifier : nn.Module
        (...)
    """
    feature_extractor = Feature_Extractor().to(DEVICE)
    class_classifier  = Class_Classifier().to(DEVICE)
    domain_classifier = Domain_Classifier().to(DEVICE)

    optimizer = optim.Adam([
        {'params': feature_extractor.parameters()},
        {'params': class_classifier.parameters()},
        {'params': domain_classifier.parameters()}], 
        lr=lr, betas=(opt.b1, opt.b2), weight_decay=weight_decay
    )
    
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40], gamma=0.2)
    
    class_criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
    domain_criterion = torch.nn.CrossEntropyLoss().to(DEVICE)

    # Create Dataloader
    # TODO: Make attributes nametuple for the datasets
    source_black = (source == 'usps')
    target_black = (target == 'usps')
    source_train_set = dataset.NumberClassify(
        "./hw3_data/digits", source, 
        train=True, 
        black=source_black, 
        transform=transforms.ToTensor()
    )
    target_train_set = dataset.NumberClassify(
        "./hw3_data/digits", target, 
        train=True, 
        black=target_black, 
        transform=transforms.ToTensor()
    )
    source_test_set  = dataset.NumberClassify(
        "./hw3_data/digits", source, 
        train=False, 
        black=source_black, 
        transform=transforms.ToTensor()
    )
    target_test_set  = dataset.NumberClassify(
        "./hw3_data/digits", target, 
        train=False, 
        black=target_black, 
        transform=transforms.ToTensor()
    )
    print("Source_train: \t{}, {}".format(source, len(source_train_set)))
    print("Source_test: \t{}, {}".format(source, len(source_test_set)))
    print("Target_train: \t{}, {}".format(target, len(target_train_set)))
    print("Target_test: \t{}, {}".format(target, len(target_test_set)))
    
    source_train_loader = DataLoader(source_train_set, batch_size=opt.batch_size, shuffle=True, num_workers=opt.threads)
    source_test_loader  = DataLoader(source_test_set, batch_size=opt.batch_size, shuffle=False, num_workers=opt.threads)
    target_train_loader = DataLoader(target_train_set, batch_size=opt.batch_size, shuffle=True, num_workers=opt.threads)
    target_test_loader  = DataLoader(target_test_set, batch_size=opt.batch_size, shuffle=False, num_workers=opt.threads)
    
    source_pred_values = []
    target_pred_values = []

    for epoch in range(1, epochs + 1):

        feature_extractor, class_classifier, domain_classifier = train(feature_extractor, class_classifier, domain_classifier, source_train_loader, target_train_loader, optimizer, epoch, class_criterion, domain_criterion)
        source_pred_value = predict.val(feature_extractor, class_classifier, domain_classifier, source_test_loader, 0, class_criterion, domain_criterion)
        target_pred_value = predict.val(feature_extractor, class_classifier, domain_classifier, target_test_loader, 1, class_criterion, domain_criterion)
        # Return: class_acc, class_loss, domain_acc, domain_loss

        print("[Epoch {}] [ src_acc: {:.2f}% ] [ tgc_acc: {:.2f}% ] [ src_loss: {:.4f} ] [ tgc_loss: {:.4f} ]".format(
               epoch, 100 * source_pred_value[0], 100 * target_pred_value[0], source_pred_value[1] + source_pred_value[3], target_pred_value[1] + target_pred_value[3]))

        # Tracing the accuracy, loss
        source_pred_values.append(source_pred_value)
        target_pred_values.append(target_pred_value)

        y_source = np.asarray(source_pred_values, dtype=float)
        y_target = np.asarray(target_pred_values, dtype=float)

        x = np.arange(start=1, stop=epoch+1)
        plt.clf()
        plt.figure(figsize=(12.8, 7.2))
        plt.plot(x, y_source[:, 0], 'r', label='Source Class Accuracy', linewidth=1)
        plt.plot(x, y_target[:, 0], 'b', label='Target Class Accuracy', linewidth=1)
        plt.plot(x, np.asarray([0.3]).repeat(len(x)), 'b-', linewidth=0.1)
        plt.plot(x, np.asarray([0.4]).repeat(len(x)), 'b-', linewidth=0.1)
        plt.plot(x, y_source[:, 2], 'r:', label='Source Domain Accuracy', linewidth=1)
        plt.plot(x, y_target[:, 2], 'b:', label='Target Domain Accuracy', linewidth=1)
        plt.legend(loc=0)
        plt.xlabel("Epochs(s)")
        plt.title("[Acc - Train {} Test {}] vs Epoch(s)".format(source, target))
        plt.savefig("DANN_{}_{}_{}-Acc.png".format(opt.alpha, source, target))
        plt.close()

        plt.clf()
        plt.figure(figsize=(12.8, 7.2))
        plt.plot(x, y_source[:, 1], 'r', label='Source Class Loss', linewidth=1)
        plt.plot(x, y_target[:, 1], 'b', label='Target Class Loss', linewidth=1)
        plt.plot(x, y_source[:, 3], 'r:', label='Source Domain Loss', linewidth=1)
        plt.plot(x, y_target[:, 3], 'b:', label='Target Domain Loss', linewidth=1)
        plt.legend(loc=0)
        plt.xlabel("Epochs(s)")
        plt.title("[Loss - Train {} Test {}] vs Epoch(s)".format(source, target))
        plt.savefig("DANN_{}_{}_{}-Loss.png".format(opt.alpha, source, target))
        plt.close()

        with open('statistics.txt', 'a') as textfile:
            textfile.write(datetime.datetime.now().strftime("%d, %b %Y %H:%M:%S"))
            textfile.write(str(source_pred_values))
            textfile.write(str(target_pred_values))

        if target_pred_value[0] > threshold:
            # Update the threshold to the new maximum when target accuracy improves.
            threshold = target_pred_value[0]
            utils.saveDANN("./models/dann/{}/DANN_{}_{}_{}.pth".format(opt.tag, source, target, epoch), feature_extractor, class_classifier, domain_classifier)

        # Step the LR scheduler once per epoch, after the optimizer updates
        # (stepping at the top of the loop shifts the milestones by one epoch).
        scheduler.step()

    return feature_extractor, class_classifier, domain_classifier
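
Example 3 passes a fixed opt.alpha to train as the adaptation constant. In the original DANN paper (Ganin et al., 2016) the coefficient is instead annealed from 0 toward 1 over training; a sketch of that schedule, which could replace the fixed constant:

import numpy as np

def lambda_schedule(p, gamma=10.0):
    """DANN-paper schedule: lambda ramps from 0 toward 1 as training progresses."""
    # p is the training progress in [0, 1], e.g. epoch / epochs.
    return 2.0 / (1.0 + np.exp(-gamma * p)) - 1.0

# Usage sketch: constant = lambda_schedule(epoch / epochs) instead of opt.alpha.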
Example #4
def main():
    os.system("clear")

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Images to read for every iteration")
    parser.add_argument("--perpexity",
                        type=float,
                        help="the perpexity of the t-SNE algorithm")
    parser.add_argument(
        "--threads",
        type=int,
        default=8,
        help="Number of cpu threads to use during batch generation")
    parser.add_argument("--adda", action="store_true")
    parser.add_argument("--dann", action="store_true")
    opt = parser.parse_args()

    paths = {}
    DOMAINS = [("usps", "mnistm"), ("mnistm", "svhn"), ("svhn", "usps")]

    #----------------------------
    # ADDA Model
    #----------------------------
    if opt.adda:
        source_encoder = Feature()
        target_encoder = Feature()
        classifier = Classifier(128 * 7 * 7, 1000, 10)

    if opt.dann:
        feature_extractor = Feature_Extractor()
        class_classifier = Class_Classifier()

        for SOURCE, TARGET in DOMAINS:
            sourcepath = "./hw3_data/digits/{}/".format(SOURCE)
            targetpath = "./hw3_data/digits/{}/".format(TARGET)
            sourcecsv  = os.path.join(sourcepath, "test.csv")
            targetcsv  = os.path.join(targetpath, "test.csv")
            modelpath  = os.path.join("./models/dann/20190504/ADDA_{}.pth".format(TARGET))

            # paths is a plain dict, so assign rather than append to a missing key.
            paths[(SOURCE, TARGET)] = (sourcepath, targetpath, sourcecsv, targetcsv, modelpath)

        # Read model
        for (source, target) in paths:
            sourcepath, targetpath, sourcecsv, targetcsv, modelpath = paths[(source, target)]
            feature_extractor, class_classifier, _ = utils.loadDANN(
                modelpath, Feature_Extractor(), Class_Classifier(),
                Domain_Classifier())

            # Randomly pick 20 images for each class from the test CSVs.
            s_df = pd.read_csv(sourcecsv)
            t_df = pd.read_csv(targetcsv)

            imgs = []
            labels = []

            #----------------
            # Source Domain
            #----------------
            for label in range(0, 10):
                df = s_df[s_df["label"] == label].sample(n=20)

                for index, (img_name, _) in df.iterrows():
                    img = PIL.Image.open(img_name)
                    imgs.append(img)
                    labels.append(label)

            # ToTensor() converts one image at a time, so stack the batch manually.
            x1 = torch.stack([transforms.ToTensor()(img) for img in imgs])
            y1 = torch.tensor(labels, dtype=torch.float)

            x1 = feature_extractor(x1).view(-1, 128 * 7 * 7)

            imgs = []
            labels = []

            #----------------
            # Target Domain
            #----------------
            for label in range(0, 10):
                df = t_df[t_df["label"] == label].sample(n=20)

                for index, (img_name, _) in df.iterrows():
                    img = PIL.Image.open(img_name)
                    imgs.append(img)
                    labels.append(label)

            x2 = torch.stack([transforms.ToTensor()(img) for img in imgs])
            y2 = torch.tensor(labels, dtype=torch.float)

            # Flatten to match x1 (was .view(-1, 128, 7, 7)).
            x2 = feature_extractor(x2).view(-1, 128 * 7 * 7)

            # Draw tsne
            tsne(x1,
                 x2,
                 n=2,
                 perpexity=opt.perpexity,
                 source=source,
                 target=target)

    #----------------------------
    # DANN Model
    #----------------------------
    if opt.dann:
        feature_extractor = Feature_Extractor()
        class_classifier = Class_Classifier()
        domain_classifier = Domain_Classifier()

        for SOURCE, TARGET in DOMAINS:
            sourcepath = "./hw3_data/digits/{}".format(SOURCE)
            targetpath = "./hw3_data/digits/{}".format(TARGET)
            sourcecsv = os.path.join(sourcepath, "test.csv")
            targetcsv = os.path.join(targetpath, "test.csv")
            modelpath = os.path.join(
                "./models/dann/20190504/ADDA_{}.pth".format(TARGET))

            # paths is a plain dict, so assign rather than append to a missing key.
            paths[(SOURCE, TARGET)] = (sourcepath, targetpath, sourcecsv, targetcsv, modelpath)

        for (source, target) in paths:
            sourcepath, targetpath, sourcecsv, targetcsv, modelpath = paths[(source, target)]
            feature_extractor, class_classifier, _ = utils.loadDANN(
                modelpath, feature_extractor, class_classifier,
                domain_classifier)

            # Random pick 20 images for each class.
            s_df = pd.read_csv(sourcecsv)
            t_df = pd.read_csv(targetcsv)

            # Draw tsne
            # NOTE: the per-class sampling and feature extraction are missing
            # here; they mirror the block above, with x and y standing for the
            # resulting source/target feature tensors.
            tsne(x,
                 y,
                 n=2,
                 perpexity=opt.perpexity,
                 source=source,
                 target=target)
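
The tsne helper used in Example 4 is not defined in this snippet. Given the call sites, here is a minimal sketch using scikit-learn's TSNE; the perpexity keyword keeps the example's spelling, features are assumed to arrive as torch tensors, and np/plt are the module-level imports the rest of the code already uses. The repo's actual helper may differ.

from sklearn.manifold import TSNE

def tsne(x1, x2, n=2, perpexity=30.0, source="", target=""):
    """Project source/target features to n-D with t-SNE and plot by domain (sketch)."""
    feats = np.concatenate([x1.detach().cpu().numpy(),
                            x2.detach().cpu().numpy()], axis=0)
    embedded = TSNE(n_components=n, perplexity=perpexity).fit_transform(feats)

    num_src = x1.shape[0]
    plt.figure(figsize=(12.8, 7.2))
    plt.scatter(embedded[:num_src, 0], embedded[:num_src, 1], s=4, label=source)
    plt.scatter(embedded[num_src:, 0], embedded[num_src:, 1], s=4, label=target)
    plt.legend(loc=0)
    plt.title("t-SNE: {} vs {}".format(source, target))
    plt.savefig("tSNE_{}_{}.png".format(source, target))
    plt.close()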
Example #5
def adversarial_discriminative_domain_adaptation(source, target, source_epochs, target_epochs, threshold, source_lr, target_lr, weight_decay):
    """ Using ADDA framework to train the network. """
    #------------------------------------------------------
    # Create Model, optimizer, scheduler, and loss function
    #------------------------------------------------------
    source_encoder   = Feature().to(DEVICE)
    target_encoder   = Feature().to(DEVICE)
    class_classifier = Classifier(128 * 7 * 7, 1000, 10).to(DEVICE)
    discriminator    = Discriminator().to(DEVICE)

    source_optimizer = optim.Adam([{'params': source_encoder.parameters()},
                                   {'params': class_classifier.parameters()}], 
                                    lr=source_lr, betas=(opt.b1, opt.b2), weight_decay=weight_decay)
    
    source_scheduler = optim.lr_scheduler.MultiStepLR(source_optimizer, milestones=[], gamma=0.1)
    
    encoder_criterion     = nn.CrossEntropyLoss().to(DEVICE)
    adversarial_criterion = nn.MSELoss().to(DEVICE)
    
    #------------------
    # Create Dataloader
    #------------------
    source_black = (source == 'usps')
    target_black = (target == 'usps')
    
    source_train_set = dataset.NumberClassify("./hw3_data/digits", source, train=True, black=source_black, transform=transforms.ToTensor())
    target_train_set = dataset.NumberClassify("./hw3_data/digits", target, train=True, black=target_black, transform=transforms.ToTensor())
    source_test_set  = dataset.NumberClassify("./hw3_data/digits", source, train=False, black=source_black, transform=transforms.ToTensor())
    target_test_set  = dataset.NumberClassify("./hw3_data/digits", target, train=False, black=target_black, transform=transforms.ToTensor())
    
    print("Source_train: \t{}, {}".format(source, len(source_train_set)))
    print("Source_test: \t{}, {}".format(source, len(source_test_set)))
    print("Target_train: \t{}, {}".format(target, len(target_train_set)))
    print("Target_test: \t{}, {}".format(target, len(target_test_set)))
    
    source_train_loader = DataLoader(source_train_set, batch_size=opt.batch_size, shuffle=True, num_workers=opt.threads)
    source_test_loader  = DataLoader(source_test_set, batch_size=opt.batch_size, shuffle=False, num_workers=opt.threads)
    target_train_loader = DataLoader(target_train_set, batch_size=opt.batch_size, shuffle=True, num_workers=opt.threads)
    target_test_loader  = DataLoader(target_test_set, batch_size=opt.batch_size, shuffle=False, num_workers=opt.threads)
    
    source_pred_values = []
    target_pred_values = []

    #------------------
    # Train Source Domain
    #------------------
    if not opt.pretrain:
        for epoch in range(1, source_epochs + 1):
            source_scheduler.step()
            source_encoder, class_classifier = train_source(source_encoder, class_classifier, source_optimizer, source_train_loader, encoder_criterion, epoch)

            source_acc, source_loss = val(source_encoder, class_classifier, source_test_loader, encoder_criterion)
            target_acc, target_loss = val(source_encoder, class_classifier, target_test_loader, encoder_criterion)
            # Return: class_acc, class_loss

            print("[Epoch {}] [ src_acc: {:.2f}% ] [ tgt_acc: {:.2f}% ] [ src_loss: {:.4f} ] [ tgt_loss: {:.4f} ]".format(
                epoch, 100 * source_acc, 100 * target_acc, source_loss, target_loss))

            # Tracing the accuracy, loss
            source_pred_values.append((source_acc, source_loss))
            target_pred_values.append((target_acc, target_loss))

            y_source = np.asarray(source_pred_values, dtype=float)
            y_target = np.asarray(target_pred_values, dtype=float)
            x = np.arange(start=1, stop=epoch + 1)

            # Draw graphs
            draw_graphs(x, y_source, y_target, threshold, source_epochs, source, target)

            with open('statistics.txt', 'a') as textfile:
                textfile.write(datetime.datetime.now().strftime("%d, %b %Y %H:%M:%S"))
                textfile.write(str(source_acc))
                textfile.write(str(target_acc))

            if target_acc > threshold:
                # Update the threshold
                threshold = target_acc
                
                savepath = "./models/adda/{}/ADDA_{}_{}_{}.pth".format(opt.tag, source, target, epoch)
                utils.saveADDA(savepath, source_encoder, target_encoder, class_classifier)
                print("Model saved to: {}".format(savepath))
    
        #----------------------------------
        # Save the model at the last epoch
        #----------------------------------
        savepath = "./models/adda/{}/ADDA_{}_{}_{}.pth".format(opt.tag, source, target, source_epochs)
        utils.saveADDA(savepath, source_encoder, target_encoder, class_classifier)
        print("Model saved to: {}".format(savepath))

    if opt.pretrain:
        if not os.path.exists(opt.pretrain):
            raise IOError("Pretrained model not found: {}".format(opt.pretrain))
        
        source_encoder, _, class_classifier = utils.loadADDA(opt.pretrain, source_encoder, target_encoder, class_classifier)
        source_acc, source_loss = val(source_encoder, class_classifier, source_test_loader, encoder_criterion)
        target_acc, target_loss = val(source_encoder, class_classifier, target_test_loader, encoder_criterion)
                
        source_pred_values = [(source_acc, source_loss) for _ in range(0, opt.source_epochs, opt.val_interval)]
        target_pred_values = [(target_acc, target_loss) for _ in range(0, opt.source_epochs, opt.val_interval)]

        
    #--------------------------------------------------------------
    # Initialize the target-domain encoder with the source weights
    #--------------------------------------------------------------
    target_encoder.load_state_dict(source_encoder.state_dict())
    target_optimizer = optim.Adam(target_encoder.parameters(), target_lr, betas=(opt.b1, opt.b2))
    discri_optimizer = optim.Adam(discriminator.parameters(), target_lr)
    target_scheduler = optim.lr_scheduler.MultiStepLR(target_optimizer, milestones=[], gamma=0.1)

    #------------------
    # Train Target Domain
    #------------------
    for epoch in range(source_epochs + 1, source_epochs + target_epochs + 1):
        target_scheduler.step()
        
        target_encoder, discriminator = train_target(source_encoder, target_encoder, discriminator, adversarial_criterion,
                                                     source_train_loader, target_train_loader, discri_optimizer, target_optimizer, epoch)

        if epoch % opt.val_interval == 0:
            target_acc, target_loss = val(target_encoder, class_classifier, target_test_loader, encoder_criterion)

            print("[Epoch {}] [ src_acc: {:.2f}% ] [ tgt_acc: {:.2f}% ] [ src_loss: {:.4f} ] [ tgt_loss: {:.4f} ]".format(
                   epoch, 100 * source_acc, 100 * target_acc, source_loss, target_loss))

            # Tracing the accuracy, loss
            source_pred_values.append((source_acc, source_loss))
            target_pred_values.append((target_acc, target_loss))

            y_source = np.asarray(source_pred_values, dtype=float)
            y_target = np.asarray(target_pred_values, dtype=float)
            x = np.arange(start=1, stop=epoch+1, step=opt.val_interval)
        
            # Draw the graphs
            draw_graphs(x, y_source, y_target, threshold, source_epochs, source, target)

            # with open('statistics.txt', 'a') as textfile:
            #     textfile.write(datetime.datetime.now().strftime("%d, %b %Y %H:%M:%S"))
            #     textfile.write(str(source_acc))
            #     textfile.write(str(target_acc))

        if target_acc > threshold:
            # Update the threshold
            threshold = target_acc
            
            savepath = "./models/adda/{}/ADDA_{}_{}_{}.pth".format(opt.tag, source, target, epoch)
            utils.saveADDA(savepath, source_encoder, target_encoder, class_classifier)
            print("Model saved to: {}".format(savepath))

    #----------------------------------
    # Save the model at the last epoch
    #----------------------------------
    savepath = "./models/adda/{}/ADDA_{}_{}_{}.pth".format(opt.tag, source, target, epoch)
    utils.saveADDA(savepath, source_encoder, target_encoder, class_classifier)
    print("Model saved to: {}".format(savepath))

    return source_encoder, target_encoder, class_classifier, discriminator
def train_A_test_B(source, target, epochs, lr, weight_decay):
    """ 
      Train the model with source data. Test the model with target data. 
    """
    #------------------------------------------------------
    # Create Model, optimizer, scheduler, and loss function
    #------------------------------------------------------
    feature_extractor = Feature_Extractor().to(DEVICE)
    class_classifier = Class_Classifier().to(DEVICE)

    optimizer = optim.Adam([{
        'params': feature_extractor.parameters()
    }, {
        'params': class_classifier.parameters()
    }],
                           lr=lr,
                           betas=(opt.b1, opt.b2),
                           weight_decay=weight_decay)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=[],
                                               gamma=0.1)
    class_criterion = torch.nn.CrossEntropyLoss().to(DEVICE)

    source_black = (source == 'usps')
    target_black = (target == 'usps')

    source_train_set = dataset.NumberClassify("./hw3_data/digits",
                                              source,
                                              train=True,
                                              black=source_black,
                                              transform=transforms.ToTensor())
    source_test_set = dataset.NumberClassify("./hw3_data/digits",
                                             source,
                                             train=False,
                                             black=source_black,
                                             transform=transforms.ToTensor())
    target_train_set = dataset.NumberClassify("./hw3_data/digits",
                                              target,
                                              train=True,
                                              black=target_black,
                                              transform=transforms.ToTensor())
    target_test_set = dataset.NumberClassify("./hw3_data/digits",
                                             target,
                                             train=False,
                                             black=target_black,
                                             transform=transforms.ToTensor())
    print("Source_train: \t{}, {}".format(source, len(source_train_set)))
    print("Source_test: \t{}, {}".format(source, len(source_test_set)))
    print("Target_train: \t{}, {}".format(target, len(target_train_set)))
    print("Target_test: \t{}, {}".format(target, len(target_test_set)))
    source_train_loader = DataLoader(source_train_set,
                                     batch_size=opt.batch_size,
                                     shuffle=True,
                                     num_workers=opt.threads)
    source_test_loader = DataLoader(source_test_set,
                                    batch_size=opt.batch_size,
                                    shuffle=False,
                                    num_workers=opt.threads)
    target_train_loader = DataLoader(target_train_set,
                                     batch_size=opt.batch_size,
                                     shuffle=True,
                                     num_workers=opt.threads)
    target_test_loader = DataLoader(target_test_set,
                                    batch_size=opt.batch_size,
                                    shuffle=False,
                                    num_workers=opt.threads)

    src_acc = []
    src_los = []
    tgt_acc = []
    tgt_los = []

    for epoch in range(1, epochs + 1):
        scheduler.step()

        feature_extractor, class_classifier = train(feature_extractor,
                                                    class_classifier,
                                                    source_train_loader,
                                                    optimizer, epoch,
                                                    class_criterion)
        test_accuracy, test_loss = val(feature_extractor, class_classifier,
                                       source_test_loader, target_test_loader,
                                       class_criterion)
        print(
            "[Epoch %d] [src_acc: %d%%] [tgt_acc: %d%%] [src_loss: %f] [tgt_loss: %f]"
            % (epoch, 100 * test_accuracy[0], 100 * test_accuracy[1],
               test_loss[0], test_loss[1]))

        # Tracing the accuracy
        src_acc.append(test_accuracy[0])
        src_los.append(test_loss[0])
        tgt_acc.append(test_accuracy[1])
        tgt_los.append(test_loss[1])

        x = np.arange(start=1, stop=epoch + 1)
        plt.clf()
        plt.figure(figsize=(12.8, 7.2))
        plt.plot(x, src_acc, label='Source Accuracy', linewidth=1)
        plt.plot(x, tgt_acc, label='Target Accuracy', linewidth=1)
        plt.legend(loc=0)
        plt.xlabel("Epochs(s)")
        plt.title("[Acc - Train {} Test {}] vs Epoch(s)".format(
            source, target))
        plt.savefig("Train_A_Test_B_{}_{}-Acc.png".format(source, target))
        plt.close()

        plt.clf()
        plt.figure(figsize=(12.8, 7.2))
        plt.plot(x, src_los, label='Source Loss', linewidth=1)
        plt.plot(x, tgt_los, label='Target Loss', linewidth=1)
        plt.legend(loc=0)
        plt.xlabel("Epochs(s)")
        plt.title("[Loss - Train {} Test {}] vs Epoch(s)".format(
            source, target))
        plt.savefig("Train_A_Test_B_{}_{}-Loss.png".format(source, target))
        plt.close()

        with open('statistics.txt', 'a') as textfile:
            # textfile.write(datetime.datetime.now().strftime("%d, %b %Y %H:%M:%S"))
            # textfile.write(str(src_acc))
            # textfile.write(str(tgt_acc))
            textfile.write("Source: {}, Target: {}, Accuracy: {}\n".format(
                source, target, max(tgt_acc)))
            print("Source: {}, Target: {}, Accuracy: {}".format(
                source, target, max(tgt_acc)))

        # if epoch % opt.save_interval == 0:
        #     save("./models/dann/{}/Train_A_Test_B_{}_{}_{}.pth".format(opt.tag, source, target, epoch), feature_extractor, class_classifier)

    return feature_extractor, class_classifier
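
The adversarial train_target step that Example 5 calls is not included in the snippet. Below is a minimal sketch of one ADDA epoch, assuming the Discriminator outputs one value per sample (matching the MSELoss adversarial criterion above, LSGAN-style) and that the loaders yield (img, label, _) tuples; the repo's version may differ.

def train_target(source_encoder, target_encoder, discriminator, criterion,
                 source_loader, target_loader, discri_optimizer, target_optimizer, epoch):
    """One ADDA epoch (sketch): alternate discriminator / target-encoder updates."""
    source_encoder.eval()      # the source encoder stays frozen in ADDA
    target_encoder.train()
    discriminator.train()

    for (src_img, _, _), (tgt_img, _, _) in zip(source_loader, target_loader):
        src_img, tgt_img = src_img.to(DEVICE), tgt_img.to(DEVICE)

        # 1. Discriminator step: label source features 1, target features 0.
        discri_optimizer.zero_grad()
        src_feat = source_encoder(src_img).view(src_img.size(0), -1).detach()
        tgt_feat = target_encoder(tgt_img).view(tgt_img.size(0), -1).detach()
        d_loss = (criterion(discriminator(src_feat), torch.ones(src_feat.size(0), 1, device=DEVICE))
                  + criterion(discriminator(tgt_feat), torch.zeros(tgt_feat.size(0), 1, device=DEVICE)))
        d_loss.backward()
        discri_optimizer.step()

        # 2. Target-encoder step: make target features look like source (label 1).
        target_optimizer.zero_grad()
        tgt_feat = target_encoder(tgt_img).view(tgt_img.size(0), -1)
        g_loss = criterion(discriminator(tgt_feat), torch.ones(tgt_feat.size(0), 1, device=DEVICE))
        g_loss.backward()
        target_optimizer.step()

    return target_encoder, discriminator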