Example #1
import copy

import numpy as np
import torch

import models  # project-local module providing DeepCoral
import utils   # project-local module providing CustomDataset
# DCORAL is the trainer wrapper whose __init__ is shown in Example #2.

def cross_session(data, label, session_id, subject_id, category_number,
                  batch_size, iteration, lr, momentum, log_interval):
    # Leave-one-session-out (LOSO): hold out one session as the target,
    # train on the remaining sessions.
    train_idxs = list(range(3))  # three recording sessions
    del train_idxs[session_id]
    test_idx = session_id

    # The held-out session is the target domain; the remaining sessions form the source domain.
    target_data, target_label = copy.deepcopy(data[test_idx][subject_id]), copy.deepcopy(label[test_idx][subject_id])
    source_data, source_label = copy.deepcopy(data[train_idxs][:, subject_id]), copy.deepcopy(label[train_idxs][:, subject_id])

    # Stack all training sessions into single data/label arrays.
    source_data_comb = source_data[0]
    source_label_comb = source_label[0]
    for j in range(1, len(source_data)):
        source_data_comb = np.vstack((source_data_comb, source_data[j]))
        source_label_comb = np.vstack((source_label_comb, source_label[j]))
    source_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(source_data_comb, source_label_comb),
                                                batch_size=batch_size,
                                                shuffle=True,
                                                drop_last=True)
    target_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(target_data, target_label),
                                                batch_size=batch_size,
                                                shuffle=True,
                                                drop_last=True)
    model = DCORAL(model=models.DeepCoral(pretrained=False, number_of_category=category_number),
                source_loader=source_loader,
                target_loader=target_loader,
                batch_size=batch_size,
                iteration=iteration,
                lr=lr,
                momentum=momentum,
                log_interval=log_interval)
    # print(model.__getModel__())
    acc = model.train()
    print('Target_session_id: {}, current_subject_id: {}, acc: {}'.format(test_idx, subject_id, acc))
    return acc
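
utils.CustomDataset above is a project-local class that is not shown on this page. As a rough sketch, a minimal wrapper of this kind pairs a data array with a label array; the class name and constructor signature are taken from the call above, while the body is an assumption:

import numpy as np
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Pair a data array with a label array (illustrative sketch, not the project's actual class)."""
    def __init__(self, data, labels):
        self.data = torch.as_tensor(np.asarray(data), dtype=torch.float32)
        self.labels = torch.as_tensor(np.asarray(labels), dtype=torch.long).squeeze()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
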
Example #2
# Trainer wrapper constructed in Example #1; `device` is a module-level
# torch.device defined elsewhere in the file.
class DCORAL:
    def __init__(self, model=models.DeepCoral(), source_loader=None, target_loader=None,
                 batch_size=64, iteration=10000, lr=0.001, momentum=0.9, log_interval=10):
        self.model = model
        self.model.to(device)
        self.source_loader = source_loader
        self.target_loader = target_loader
        self.batch_size = batch_size
        self.iteration = iteration
        self.lr = lr
        self.momentum = momentum
        self.log_interval = log_interval
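
For context, the CORAL loss that a Deep CORAL trainer like this one minimizes aligns the second-order statistics (covariances) of source and target features. Below is a minimal sketch following Sun & Saenko's Deep CORAL formulation; the repository's own implementation may differ in details:

import torch

def coral_loss(source, target):
    """Squared Frobenius distance between source and target feature covariances."""
    d = source.size(1)
    source_c = source - source.mean(dim=0, keepdim=True)     # center the features
    target_c = target - target.mean(dim=0, keepdim=True)
    cov_s = source_c.t() @ source_c / (source.size(0) - 1)   # (d x d) covariance
    cov_t = target_c.t() @ target_c / (target.size(0) - 1)
    return ((cov_s - cov_t) ** 2).sum() / (4 * d * d)
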
Example #3
import math
import os

import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter

# FLAG, DEVICE, models, the adaptation losses (CORAL, mmd_linear,
# mmd_rbf_noaccelerate) and the helpers (cross_dataset_LT, my_cross_dataset,
# adjust_lr, cross_test, log_to_csv) are project-local and assumed in scope.

def cross_train():

    # Basic parameters
    gpus = FLAG.gpus
    batch_size = FLAG.batch_size
    epochs = FLAG.epoch
    init_lr = FLAG.lr
    LOG_INTERVAL = 10
    TEST_INTERVAL = 2
    source_name = FLAG.source
    target_name = FLAG.target
    model_name = FLAG.arch
    adapt_mode = FLAG.adapt_mode
    l2_decay = 5e-4

    #Loading dataset
    if FLAG.isLT:
        source_train,target_train,target_test,classes = cross_dataset_LT(FLAG)
    else:
        source_train,target_train,target_test,classes = my_cross_dataset(FLAG)
    source_train_loader = torch.utils.data.DataLoader(dataset=source_train,batch_size=batch_size,
                    shuffle=True,num_workers=8,drop_last=True)
    target_train_loader = torch.utils.data.DataLoader(dataset=target_train,batch_size=batch_size,
                    shuffle=True,num_workers=8,drop_last=True)
    target_test_loader = torch.utils.data.DataLoader(dataset=target_test,batch_size=batch_size,
                    shuffle=False,num_workers=8)
    #Define model
    if adapt_mode == 'ddc':
        cross_model = models.DDCNet(FLAG)
        
        #adapt_loss_function = mmd_linear
        adapt_loss_function = mmd_rbf_noaccelerate
        #print(model)

    elif adapt_mode == 'coral':
        cross_model = models.DeepCoral(FLAG)
        adapt_loss_function = CORAL

    elif adapt_mode == 'mmd':
        cross_model = models.DDCNet(FLAG)
        adapt_loss_function = mmd_linear

    else:
        raise ValueError('Unknown adapt_mode: {}'.format(adapt_mode))
    
    # FLAG.gpus is a comma-separated id string such as '0,1'
    if len(gpus) > 1:
        gpus = gpus.split(',')
        gpus = [int(v) for v in gpus]
        cross_model = nn.DataParallel(cross_model, device_ids=gpus)

    cross_model.to(DEVICE)
    #Define Optimizer
    if len(gpus)>1:
        optimizer = optim.SGD([{'params':cross_model.module.sharedNet.parameters()},
                            {'params':cross_model.module.cls_fc.parameters(),'lr':init_lr}],
                            lr=init_lr/10,momentum=0.9,weight_decay=l2_decay)

    else:
        optimizer = optim.SGD([{'params':cross_model.sharedNet.parameters()},
                            {'params':cross_model.cls_fc.parameters(),'lr':init_lr}],
                            lr=init_lr/10,momentum=0.9,weight_decay=l2_decay)
    #print(optimizer.param_groups)
    #loss function
    criterion = torch.nn.CrossEntropyLoss()
    #Training
    
    best_result = 0.0
    #Model store
    model_dir = os.path.join('./cross_models/',adapt_mode+'-'+source_name+'2'+target_name+'-'+model_name)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    #Tensorboard configuration
    log_dir = os.path.join('./cross_logs/',adapt_mode+'-'+source_name+'2'+target_name+'-'+model_name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    writer = SummaryWriter(logdir=log_dir)

    for epoch in range(1, epochs + 1):
        cross_model.train()
        len_source_loader = len(source_train_loader)
        len_target_loader = len(target_train_loader)
        iter_source = iter(source_train_loader)
        iter_target = iter(target_train_loader)

        # Run one epoch over as many batches as the shorter loader provides
        if len_target_loader <= len_source_loader:
            iter_num = len_target_loader
            target_is_shorter = True
        else:
            iter_num = len_source_loader
            target_is_shorter = False
        #Adaptive learning rate
        optimizer = adjust_lr(optimizer,epoch,FLAG)
        writer.add_scalar('data/SharedNet lr',optimizer.param_groups[0]['lr'],epoch)
        running_loss = 0.0
        for i in range(1,iter_num+1):

            if target_is_shorter:
                target_data, _ = next(iter_target)
                if i % len_target_loader == 0:
                    iter_source = iter(source_train_loader)
                source_data, source_label = next(iter_source)
            else:
                source_data, source_label = next(iter_source)
                if i % len_source_loader == 0:
                    iter_target = iter(target_train_loader)
                target_data, _ = next(iter_target)

            input_source_data,input_source_label = source_data.to(DEVICE),source_label.to(DEVICE).squeeze()
            input_target_data = target_data.to(DEVICE)

            optimizer.zero_grad()


            label_source_pred, source_output, target_output = cross_model(input_source_data, input_target_data)
            loss_adapt = adapt_loss_function(source_output, target_output)
            loss_cls = criterion(label_source_pred, input_source_label)
            # Ramp the adaptation weight from 0 toward 1 over training (DANN-style schedule)
            lambda_1 = 2 / (1 + math.exp(-10 * epoch / epochs)) - 1
            loss = loss_cls + lambda_1 * loss_adapt
            
            
            if i % 5 == 0:
                n_iter = (epoch - 1) * iter_num + i
                writer.add_scalar('data/adapt loss', loss_adapt, n_iter)
                writer.add_scalar('data/cls loss', loss_cls, n_iter)
                writer.add_scalar('data/total loss', loss, n_iter)
                #print(optimizer.param_groups[0]['lr'])

            loss.backward()
            optimizer.step()

            #Print statistics
            running_loss += loss.item()
            if i % LOG_INTERVAL == 0:  # Print every LOG_INTERVAL mini-batches
                print('Epoch:[{}/{}], Batch:[{}/{}] loss: {}'.format(epoch, epochs, i, iter_num, running_loss / LOG_INTERVAL))
                running_loss = 0

        if epoch % TEST_INTERVAL == 0:  # Test every TEST_INTERVAL epochs
            
            acc_test,class_corr,class_total=cross_test(cross_model,target_test_loader,epoch)
            #log test acc
            writer.add_scalar('data/test accuracy',acc_test,epoch)
            #Store the best model
            if acc_test > best_result:
                model_path = os.path.join(model_dir,
                            '{}-{}-{}-epoch_{}-accval_{}.pth'.format(source_name, target_name, model_name, epoch, round(acc_test, 3)))
                torch.save(cross_model, model_path)
                # Log per-class results
                log_path = os.path.join(model_dir,
                            '{}-{}-{}-epoch_{}-accval_{}.csv'.format(source_name, target_name, model_name, epoch, round(acc_test, 3)))
                log_to_csv(log_path, classes, class_corr, class_total)
                best_result = acc_test
            else:
                print('Accuracy in this epoch did not exceed the best result.')

    writer.close()
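
mmd_linear and mmd_rbf_noaccelerate above come from the repository's MMD module, which this page does not show. As a sketch, a common linear-kernel MMD estimate used in DDC-style code looks like the following; it assumes the two feature batches have equal size, which the drop_last=True loaders above guarantee:

import torch

def mmd_linear(f_of_X, f_of_Y):
    """Linear-kernel MMD between two equally sized feature batches."""
    delta = f_of_X - f_of_Y            # requires matching batch sizes
    return torch.mean(delta @ delta.t())
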
Example #4
import torch

import data_loader  # project-local data loading helpers
import models       # project-local: provides DeepCoral
# CFG, DEVICE and train() are defined elsewhere in the script.


def load_data(src, tar, root_dir):
    source_loader = data_loader.load_data(
        root_dir, src, CFG['batch_size'], True, CFG['kwargs'])
    target_train_loader = data_loader.load_data(
        root_dir, tar, CFG['batch_size'], False, CFG['kwargs'])
    target_test_loader = data_loader.load_data(
        root_dir, tar, CFG['batch_size'], False, CFG['kwargs'])
    return source_loader, target_train_loader, target_test_loader


if __name__ == '__main__':
    torch.manual_seed(CFG['seed'])

    source_name = "amazon"
    target_name = "webcam"

    source_loader, target_train_loader, target_test_loader = load_data(
        source_name, target_name, CFG['data_path'])

    model = models.DeepCoral(CFG['n_class'],
                             adapt_loss='mmd',
                             backbone='alexnet').to(DEVICE)
    optimizer = torch.optim.SGD([
        {'params': model.sharedNet.parameters()},
        {'params': model.fc.parameters()},
        {'params': model.cls_fc.parameters(), 'lr': 10 * CFG['lr']},
    ], lr=CFG['lr'], momentum=CFG['momentum'], weight_decay=CFG['l2_decay'])

    train(source_loader, target_train_loader, target_test_loader, model, optimizer, CFG)


Example #5

import torch

import data_loader  # project-local data loading helpers
import models       # project-local: provides DeepCoral
# CFG, DEVICE and train() are defined elsewhere in the script.


def load_data(src, tar, root_dir):
    source_loader = data_loader.load_data(
        root_dir, src, CFG['batch_size'], True, CFG['kwargs'])
    target_train_loader = data_loader.load_data(
        root_dir, tar, CFG['batch_size'], False, CFG['kwargs'])
    target_test_loader = data_loader.load_data(
        root_dir, tar, CFG['batch_size'], False, CFG['kwargs'])
    return source_loader, target_train_loader, target_test_loader

if __name__ == '__main__':
    torch.manual_seed(CFG['seed'])

    source_name = "amazon"
    target_name = "webcam"

    source_loader, target_train_loader, target_test_loader = load_data(source_name, target_name, CFG['data_path'])

    model = models.DeepCoral(CFG['n_class'], CFG['backbone']).to(DEVICE)
    optimizer = torch.optim.SGD([
        {'params': model.sharedNet.parameters()},
        {'params': model.fc.parameters()},
        {'params': model.cls_fc.parameters(), 'lr': 10 * CFG['lr']},
    ], lr=CFG['lr'], momentum=CFG['momentum'], weight_decay=CFG['l2_decay'])

    train(source_loader, target_train_loader, target_test_loader, model, optimizer, CFG)
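
Both snippets above read their settings from a global CFG dict defined elsewhere in the script. Here is an illustrative configuration for an Office-31 amazon-to-webcam run; the keys are taken from the code above, while the values are plausible assumptions rather than the repository's defaults:

CFG = {
    'data_path': './data/Office31/',   # root containing the 'amazon' and 'webcam' folders
    'kwargs': {'num_workers': 4},      # extra DataLoader arguments
    'batch_size': 32,
    'lr': 1e-3,
    'momentum': 0.9,
    'l2_decay': 5e-4,
    'seed': 10,
    'n_class': 31,                     # Office-31 has 31 classes
    'backbone': 'resnet50',
}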