def cross_subject(data, label, session_id, category_number, batch_size,
                  iteration, lr, momentum, log_interval):
    """Train MS-MDAER on one session, holding out the last subject as target.

    ``data[session_id]`` / ``label[session_id]`` are treated as list-like
    sequences with one entry per subject. The last entry is removed and used
    as the target domain; every remaining entry becomes its own source domain
    with a dedicated DataLoader.

    Args:
        data: Indexable by ``session_id``; each item must support ``pop()``
            (e.g. a list of per-subject arrays — confirm against caller).
        label: Same layout as ``data``.
        session_id: Which session of ``data`` / ``label`` to use.
        category_number: Passed to ``MSMDAERNet`` as ``number_of_category``.
        batch_size: Batch size for every DataLoader; loaders shuffle and drop
            the incomplete final batch.
        iteration, lr, momentum, log_interval: Forwarded unchanged to
            ``MSMDAER``.

    Returns:
        Whatever ``MSMDAER.train()`` returns (presumably the target accuracy
        — verify against ``MSMDAER.train``).
    """
    # Read from the ``data`` argument (not a global) and deep-copy the session
    # so popping the target subject never mutates the caller's containers.
    session_data = copy.deepcopy(data[session_id])
    session_label = copy.deepcopy(label[session_id])
    target_data, target_label = session_data.pop(), session_label.pop()
    # The remaining entries are already private copies; no second copy needed.
    source_data, source_label = session_data, session_label

    def _make_loader(x, y):
        """Wrap one domain's samples in a shuffled, drop-last DataLoader."""
        return torch.utils.data.DataLoader(
            dataset=utils.CustomDataset(x, y),
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )

    # One loader per source subject; labels are indexed by the data position.
    source_loaders = [
        _make_loader(subject_data, source_label[j])
        for j, subject_data in enumerate(source_data)
    ]
    target_loader = _make_loader(target_data, target_label)

    model = MSMDAER(
        model=models.MSMDAERNet(
            pretrained=False,
            number_of_source=len(source_loaders),
            number_of_category=category_number,
        ),
        source_loaders=source_loaders,
        target_loader=target_loader,
        batch_size=batch_size,
        iteration=iteration,
        lr=lr,
        momentum=momentum,
        log_interval=log_interval,
    )
    acc = model.train()
    return acc
def __init__(self, model=None, source_loaders=0, target_loader=0,
             batch_size=64, iteration=10000, lr=0.001, momentum=0.9,
             log_interval=10):
    """Store the network, data loaders and optimisation hyper-parameters.

    Args:
        model: Network to train. When omitted, a fresh ``models.MSMDAERNet()``
            is built for this instance. A default written directly in the
            signature would be evaluated once and shared by all instances.
        source_loaders: Source-domain data loaders, stored as-is
            (placeholder default ``0``).
        target_loader: Target-domain data loader, stored as-is
            (placeholder default ``0``).
        batch_size: Stored as ``self.batch_size``.
        iteration: Stored as ``self.iteration``.
        lr: Stored as ``self.lr``.
        momentum: Stored as ``self.momentum``.
        log_interval: Stored as ``self.log_interval``.
    """
    if model is None:
        # Build lazily so each instance gets its own weights and importing
        # the module does not construct a network as a side effect.
        model = models.MSMDAERNet()
    self.model = model
    # `device` is defined outside this block (not visible here).
    self.model.to(device)
    self.source_loaders = source_loaders
    self.target_loader = target_loader
    self.batch_size = batch_size
    self.iteration = iteration
    self.lr = lr
    self.momentum = momentum
    self.log_interval = log_interval
>>>>>>> Stashed changes del one_session_label del one_session_data source_loaders = [] for j in range(len(source_data)): source_loaders.append(torch.utils.data.DataLoader(dataset=utils.CustomDataset(source_data[j], source_label[j]), batch_size=batch_size, shuffle=True, drop_last=True)) target_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(target_data, target_label), batch_size=batch_size, shuffle=True, drop_last=True) model = MSMDAER(model=models.MSMDAERNet(pretrained=False, number_of_source=len(source_loaders), number_of_category=category_number), source_loaders=source_loaders, target_loader=target_loader, batch_size=batch_size, iteration=iteration, lr=lr, momentum=momentum, log_interval=log_interval) # print(model.__getModel__()) acc = model.train() print('Target_subject_id: {}, current_session_id: {}, acc: {}'.format(test_idx, session_id, acc)) return acc <<<<<<< Updated upstream def cross_session(data, label, subject_id, category_number, batch_size, iteration, lr, momentum, log_interval):