Example #1
def train(data_insts, data_labels, logger):
    # `results` maps each target task name to the prediction accuracy
    # achieved on it when training on the remaining source tasks.
    results = {}
    logger.info(
        "Training fraction = {}, number of actual training data instances = {}"
        .format(args.frac, configs["num_trains"]))
    logger.info("-" * 100)

    batch_size = configs["batch_size"]
    num_trains = configs["num_trains"]
    num_epochs = configs["num_epochs"]
    num_domains = configs["num_domains"]
    num_data_sets = configs["num_data_sets"]

    lr = configs["lr"]
    mu = configs["mu"]
    gamma = configs["gamma"]
    mode = configs["mode"]
    if configs["model"] == "gmdan":
        lamda = configs["lambda"]
    logger.info("Training with domain adaptation using PyTorch madnNet: ")
    logger.info("Hyperparameter setting = {}.".format(configs))
    error_dicts = {}
    # num_domains == 3: the number of source domains.
    for i in range(num_data_sets):  # treat each dataset in turn as the target
        # Build source instances.
        source_insts = []
        source_labels = []
        for j in range(num_data_sets):  # every dataset except the current target is a source
            if j != i:
                source_insts.append(
                    data_insts[j][:num_trains, :].todense().astype(np.float32))
                source_labels.append(
                    data_labels[j][:num_trains, :].ravel().astype(np.int64))
        # len(source_insts) == num_data_sets - 1, the number of source domains.

        # Build target instances.
        target_idx = i  # index of the held-out target domain
        target_insts = data_insts[i][:num_trains, :].todense().astype(
            np.float32)
        target_labels = data_labels[i][:num_trains, :].ravel().astype(np.int64)

        # e.g. target_insts.shape == (4465, 5000), target_labels.shape == (4465,)

        # Build the model and optimizer.
        if configs["model"] == "mdan":
            model = MDANet(configs).to(device)
        elif configs["model"] == "gmdan":
            model = GraphMDANet(configs, device).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=lr)
        model.train()  # switch to training mode
        # Training phase.
        time_start = time.time()
        for t in range(num_epochs):  # 15 epochs by default
            running_loss = 0.0
            train_loader = multi_data_loader(
                source_insts, source_labels, batch_size
            )  # containing instances and labels from multiple sources

            for xs, ys in train_loader:  # one minibatch per source domain

                losses, domain_losses, tripletlosses = train_batch(
                    target_insts, xs, ys, model, optimizer)

                # Smoothed max over the per-domain objectives: log-sum-exp / gamma
                # upper-bounds the hard max and approaches it as gamma grows.
                if configs["model"] == "mdan":
                    loss = torch.log(
                        torch.sum(
                            torch.exp(gamma * (losses + mu * domain_losses)))) / gamma
                elif configs["model"] == "gmdan":
                    loss = torch.log(
                        torch.sum(
                            torch.exp(gamma * (losses + mu * domain_losses)))) / gamma \
                        + lamda * torch.sum(tripletlosses)
                else:
                    raise ValueError(
                        "No support for model type: {}.".format(configs["model"]))
                running_loss += loss.item()
                optimizer.zero_grad()  # clear any gradients from the previous step
                loss.backward()
                optimizer.step()
            logger.info("Iteration {}, loss = {}".format(t, running_loss))
            time_end = time.time()

        # Evaluate on the held-out portion of the target domain.
        model.eval()
        val_target_insts = data_insts[i][num_trains:, :].todense().astype(
            np.float32)
        val_target_labels = data_labels[i][num_trains:, :].ravel().astype(
            np.int64)
        val_target_insts = torch.tensor(val_target_insts,
                                        requires_grad=False).to(device)
        val_target_labels = torch.tensor(val_target_labels)

        if configs["model"] == "mdan":
            preds_labels = torch.max(model.inference(val_target_insts),
                                     1)[1].cpu().data.squeeze_()
        elif configs["model"] == "gmdan":
            # Gather the source training instances needed for graph-based inference.
            val_source_insts = []
            val_source_labels = []
            for j in range(num_data_sets):
                if j != i:
                    val_source_insts.append(
                        data_insts[j][:num_trains, :].todense().astype(
                            np.float32))
            for j in range(num_domains):  # convert each source block to a tensor
                val_source_insts[j] = torch.tensor(
                    val_source_insts[j], requires_grad=False).to(device)
            preds_labels = torch.max(
                model.inference(val_source_insts, val_target_insts),
                1)[1].cpu().data.squeeze_()
        else:
            raise ValueError(
                "No support for model type: {}.".format(configs["model"]))

        # Align lengths in case inference dropped a partial final batch.
        val_target_labels = val_target_labels[:preds_labels.size()[0]]
        pred_acc = torch.sum(preds_labels == val_target_labels).item() / float(
            len(preds_labels))
        error_dicts[configs["data_name"][i]] = (
            preds_labels.numpy() != val_target_labels.numpy())

        logger.info("Prediction accuracy on {} = {} ".format(
            configs["data_name"][i], pred_acc))
        print("-----")
        results[configs["data_name"][i]] = pred_acc
    logger.info(
        "Prediction accuracy with multiple source domain adaptation using MDANet:")
    logger.info(results)
    pickle.dump(
        error_dicts,
        open(
            "{}-{}-{}-{}.pkl".format(args.name, args.frac, args.model,
                                     args.mode), "wb"))
    logger.info("*" * 100)
Example #2
        target_insts = torch.tensor(data_insts[i][num_trains:].values,
                                    requires_grad=False,
                                    dtype=torch.float32).to(device)
        target_labels = data_labels[i][num_trains:].astype(np.float32)
        # Build and train MDANet, this time with plain SGD.
        mdan = MDANet(configs).to(device)
        optimizer = optim.SGD(mdan.parameters(), lr=lr)
        mdan.train()
        # Training phase.
        time_start = time.time()

        for t in range(num_epochs):
            running_loss = 0.0
            train_loader = multi_data_loader(source_insts, source_labels,
                                             batch_size)
            for xs, ys in train_loader:
                # Domain labels: 1 for source batches, 0 for the target batch.
                # Labels are targets, not parameters, so they need no gradients.
                slabels = torch.ones(batch_size, dtype=torch.float32).to(device)
                tlabels = torch.zeros(batch_size, dtype=torch.float32).to(device)

                for j in range(num_domains):
                    xs[j] = torch.tensor(xs[j], dtype=torch.float32).to(device)
                    ys[j] = torch.tensor(ys[j], dtype=torch.float32).to(device)
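
The domain losses in these examples are adversarial: the feature extractor is trained to confuse the same domain discriminators that the discriminators are trained to sharpen. In MDAN-style networks this tug-of-war is usually implemented with a gradient reversal layer inside the model's forward pass. MDANet's internals are not shown on this page, so the following is only a sketch of that standard mechanism, not MDANet's actual code:

import torch

class GradientReversal(torch.autograd.Function):
    # Identity on the forward pass; flips (and scales) the gradient on the
    # backward pass, so minimizing the domain loss downstream of this layer
    # maximizes it upstream in the feature extractor.

    @staticmethod
    def forward(ctx, x, lambd=1.0):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: reversed grad for x, None for lambd.
        return -ctx.lambd * grad_output, None

def grad_reverse(x, lambd=1.0):
    return GradientReversal.apply(x, lambd)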
Example #3
def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    time_start = time.time()
    results = {}
    data_insts, data_labels, num_insts = [], [], []
    # Each digit dataset (MNIST_M, SVHN, MNIST, SYNTHDIGITS) is loaded in the
    # loop below, capped at 20000 training instances apiece.
    
    num_data_sets = 4
    dataset_names = ['MNIST_M', 'SVHN', 'MNIST', 'SYNTHDIGITS']
    for i in range(num_data_sets):
        dataset = MNIST_Dataset(dataset_names[i], is_train=True,
                                max_data_num=20000,
                                transform=data_transforms[split_name[True]])

        tmp_data, tmp_label = [], []
        for j in range(len(dataset)):
            tmp_data.append(dataset[j][0])
            tmp_label.append(dataset[j][1])

        tmp_data = torch.stack(tmp_data)
        tmp_label = np.array(tmp_label)

        data_insts.append(tmp_data)
        data_labels.append(tmp_label)

    if args.model=="mdan":
        configs = {"num_classes": 10,
               "num_epochs":5, "batch_size": 5, "lr": 1e-1, "mu": 10, "num_domains":
                   3,  "gamma": 10.0,  "lambda": 0.01, 'margin':0.1, 'dropout':0,'k':2,'alpha':0.2, 'device':device, 
                 "update_lr": 0.05, "meta_lr": 0.05, "update_step": 4 }
        configs["data_name"] = ['MNIST_M','SVHN','MNIST','SYNTHDIGITS']

        num_epochs = configs["num_epochs"]
        batch_size = configs["batch_size"]
        num_domains = configs["num_domains"]
        lr = configs["lr"]
        mu = configs["mu"]
        gamma = configs["gamma"]
        lamda = configs["lambda"]
        
        logger.info("Training with domain adaptation using PyTorch madnNet: ")
        logger.info("Hyperparameter setting = {}.".format(configs))
    
        
        error_dicts = {}
        target_data_insts, target_data_labels = [],[]
        for i in range(num_data_sets):
            # Build source instances.
            configs["test_task"] = configs["data_name"][i]

            source_insts = []
            source_labels = []
            for j in range(num_data_sets):
                if j != i:
                    configs["val_task"] = configs["data_name"][j]
                    source_insts.append(data_insts[j].numpy().astype(np.float32))
                    source_labels.append(data_labels[j].ravel().astype(np.int64))
            
            target_idx = i
            target_dataset = MNIST_Dataset(dataset_names[i], is_train=False,
                                           max_data_num=20000,
                                           transform=data_transforms[split_name[False]])
            tmp_data, tmp_label = [], []
            for k in range(len(target_dataset)):
                tmp_data.append(target_dataset[k][0])
                tmp_label.append(target_dataset[k][1])

            tmp_data = torch.stack(tmp_data)
            tmp_label = np.array(tmp_label)

            target_data_insts.append(tmp_data)
            target_data_labels.append(tmp_label)

            target_insts = target_data_insts[i]
            target_labels = target_data_labels[i].ravel().astype(np.int64)

            model = Mdan(configs).to(device)
            optimizer = optim.Adadelta(model.parameters(), lr=lr)
            model.train()
            # Training phase.
            time_start = time.time()
            for t in range(num_epochs):
                running_loss = 0.0
                train_loader = multi_data_loader(source_insts, source_labels, batch_size)
                for xs, ys in train_loader:
                    # Domain labels: 1 for source batches, 0 for the target batch.
                    slabels = torch.ones(batch_size, requires_grad=False).type(torch.LongTensor).to(device)
                    tlabels = torch.zeros(batch_size, requires_grad=False).type(torch.LongTensor).to(device)
                    for j in range(num_domains):
                        xs[j] = torch.tensor(xs[j], requires_grad=False).to(device)
                        ys[j] = torch.tensor(ys[j], requires_grad=False).to(device)

                    # Sample a random target batch for the domain discriminators.
                    ridx = np.random.choice(target_insts.shape[0], batch_size)
                    tinputs = target_insts[ridx].to(device)

                    optimizer.zero_grad()
                    logprobs, sdomains, tdomains = model(xs, tinputs, ys)

                    # Per-domain classification losses, plus the source-vs-target
                    # domain-discrimination losses.
                    losses = torch.stack([F.nll_loss(logprobs[j], ys[j])
                                          for j in range(num_domains)])
                    domain_losses = torch.stack([F.nll_loss(sdomains[j], slabels) +
                                                 F.nll_loss(tdomains[j], tlabels)
                                                 for j in range(num_domains)])
                    
                    if mode == "maxmin":
                        loss = torch.max(losses) + mu * torch.min(domain_losses)
                    elif mode == "dynamic":
                        loss = torch.log(torch.sum(torch.exp(gamma * (losses + mu * domain_losses)))) / gamma
                        
                    else:
                        raise ValueError("No support for the training mode on madnNet: {}.".format(mode))
                    running_loss += loss.item()
                    loss.backward()
                    optimizer.step()
                logger.info("Iteration {}, loss = {}".format(t, torch.max(losses).item()))
                logger.info("Iteration {}, loss = {}".format(t, loss.item()))
                logger.info("Iteration {}, loss = {}".format(t, running_loss))
               
            time_end = time.time()
            # Evaluate on the held-out target domain.
            model.eval()
            target_labels = torch.tensor(target_labels, dtype=torch.long)  # numpy -> tensor
            model = model.cpu()  # run inference on CPU
            preds_labels = torch.max(model.inference(target_insts), 1)[1].cpu().data.squeeze_()
            pred_acc = torch.sum(preds_labels == target_labels).item() / float(target_insts.size(0))

            logger.info("Prediction accuracy on {} = {}, time used = {} seconds.".
                        format(dataset_names[i], pred_acc, time_end - time_start))
            results[dataset_names[i]] = pred_acc
        logger.info("Prediction accuracy with multiple source domain adaptation using madnNet: ")
        logger.info(results)
        
        logger.info("*" * 100)
    else:
        raise ValueError("No support for the following model: {}.".format(args.model))