class MLP_Regression():
    """Training wrapper around an ``MLP`` network fitted with a summed MSE loss.

    The constructor copies hyper-parameters out of ``parameters``, opens a
    ``SummaryWriter`` and builds the network, the Adam optimiser and a
    ``StepLR`` scheduler via :meth:`init_net`.
    """

    # Keys of ``parameters`` copied verbatim onto same-named attributes.
    _COPIED_KEYS = ('lr', 'hidden_units', 'mode', 'batch_size', 'num_batches',
                    'x_shape', 'y_shape')

    def __init__(self, label, parameters):
        """Store the configuration and build the network.

        Args:
            label: name used in the writer comment and the model file name.
            parameters: mapping providing ``lr``, ``hidden_units``, ``mode``,
                ``batch_size``, ``num_batches``, ``x_shape``, ``y_shape`` and
                ``save_dir``.
        """
        super().__init__()
        self.writer = SummaryWriter(comment=f"_{label}_training")
        self.label = label
        for key in self._COPIED_KEYS:
            setattr(self, key, parameters[key])
        self.save_model_path = f'{parameters["save_dir"]}/{label}_model.pt'
        self.best_loss = np.inf
        self.init_net(parameters)

    def init_net(self, parameters):
        """Create the save directory, network, optimiser and LR scheduler."""
        save_dir = parameters["save_dir"]
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        net_config = dict(input_shape=self.x_shape,
                          classes=self.y_shape,
                          batch_size=self.batch_size,
                          hidden_units=self.hidden_units,
                          mode=self.mode)
        self.net = MLP(net_config).to(DEVICE)
        self.optimiser = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        # Halves the learning rate every 5000 scheduler steps.
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimiser,
                                                         step_size=5000,
                                                         gamma=0.5)
        print("MLP Parameters: ")
        print(
            f'batch size: {self.batch_size}, input shape: {net_config["input_shape"]}, hidden units: {net_config["hidden_units"]}, output shape: {net_config["classes"]}, lr: {self.lr}'
        )

    def train_step(self, train_data):
        """Run one optimisation pass over every ``(x, y)`` batch in ``train_data``.

        ``self.loss_info`` holds the loss tensor of the most recent batch and
        ``self.epoch_loss`` its Python float value once the pass is finished.
        """
        self.net.train()
        for inputs, targets in train_data:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            self.net.zero_grad()
            predictions = self.net(inputs)
            self.loss_info = torch.nn.functional.mse_loss(predictions,
                                                          targets,
                                                          reduction='sum')
            self.loss_info.backward()
            self.optimiser.step()

        # Only the final batch's loss is kept, not a sum over the pass.
        self.epoch_loss = self.loss_info.item()

    def evaluate(self, x_test):
        """Return the network output for ``x_test`` as a NumPy array."""
        self.net.eval()
        with torch.no_grad():
            predictions = self.net(x_test.to(DEVICE))
        return predictions.detach().cpu().numpy()

    def log_progress(self, step):
        """Write the most recent batch loss to the summary writer."""
        write_loss(self.writer, self.loss_info, step)
示例#2
0
def model(data, labels):
    """Train a fresh MLP on ``data`` for ``args.T`` passes, recording outputs.

    The network is evaluated on ``data`` at every pass and its output is
    collected. After every pass except the last, one Adam step on the
    cross-entropy against ``labels`` is taken.

    Returns:
        list of the network outputs, one entry per pass.
    """
    network = MLP(n_pixels, 10, args.units, args.nonlinear)
    if cuda:
        network.cuda()
    optimizer = th.optim.Adam(network.parameters(), args.lr)

    prediction_list = []
    for step in range(args.T):
        logits = network(data)
        prediction_list.append(logits)

        # The final pass is evaluation only: no parameter update follows it.
        if step == args.T - 1:
            continue

        optimizer.zero_grad()
        F.cross_entropy(logits, labels).backward()
        optimizer.step()

    return prediction_list
class Greedy_Bandit(Bandit):
    """Bandit whose reward model is an ``MLP`` trained with a summed MSE loss."""

    def __init__(self, label, *args):
        """Initialise the base bandit and attach a summary writer."""
        super().__init__(label, *args)
        # NOTE(review): the writer is stored as a 1-tuple, and log_progress
        # indexes element 0. Kept as-is so any external use of
        # ``self.writer[0]`` continues to work.
        self.writer = (SummaryWriter(comment=f"_{label}_training"), )

    def init_net(self, parameters):
        """Build the network, Adam optimiser and StepLR scheduler.

        The input width is the feature width of ``self.x`` plus 2, and the
        output width is 1 for a one-dimensional ``self.y``, otherwise
        ``self.y.shape[1]``.
        """
        if len(self.y.shape) == 1:
            n_outputs = 1
        else:
            n_outputs = self.y.shape[1]

        net_config = dict(input_shape=self.x.shape[1] + 2,
                          classes=n_outputs,
                          batch_size=self.batch_size,
                          hidden_units=parameters['hidden_units'],
                          mode=parameters['mode'])
        self.net = MLP(net_config).to(DEVICE)
        self.optimiser = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        # Halves the learning rate every 5000 scheduler steps.
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimiser,
                                                         step_size=5000,
                                                         gamma=0.5)
        print(f'Bandit {self.label} Parameters: ')
        print(
            f'buffer_size: {self.buffer_size}, batch size: {self.batch_size}, number of samples: {self.n_samples}, epsilon: {self.epsilon}'
        )
        print("MLP Parameters: ")
        print(
            f'input shape: {net_config["input_shape"]}, hidden units: {net_config["hidden_units"]}, output shape: {net_config["classes"]}, lr: {self.lr}'
        )

    def loss_step(self, x, y, batch_id):
        """Take one optimiser step on ``(x, y)`` and return the loss tensor.

        ``batch_id`` is accepted for interface compatibility and not used.
        """
        self.net.train()
        self.net.zero_grad()
        predictions = self.net(x).squeeze()
        net_loss = torch.nn.functional.mse_loss(predictions,
                                                y,
                                                reduction='sum')
        net_loss.backward()
        self.optimiser.step()
        return net_loss

    def log_progress(self, step):
        """Log the current loss and the latest cumulative regret."""
        writer = self.writer[0]
        write_loss(writer, self.loss_info, step)
        writer.add_scalar('logs/cumulative_regret',
                          self.cumulative_regrets[-1], step)
class MLP_Classification():
    """Training wrapper around an ``MLP`` (or ``MLP_Dropout``) classifier.

    The constructor copies hyper-parameters out of ``parameters``, opens a
    ``SummaryWriter`` and builds the network, an SGD optimiser and a
    ``StepLR`` scheduler via :meth:`init_net`.
    """

    def __init__(self, label, parameters):
        """Store the configuration and build the network.

        Args:
            label: name used in the writer comment, the model file name and
                the accuracy printout.
            parameters: mapping providing ``lr``, ``hidden_units``, ``mode``,
                ``batch_size``, ``num_batches``, ``x_shape``, ``classes``,
                ``save_dir`` and ``dropout``.
        """
        super().__init__()
        self.writer = SummaryWriter(comment=f"_{label}_training")
        self.label = label
        self.lr = parameters['lr']
        self.hidden_units = parameters['hidden_units']
        self.mode = parameters['mode']
        self.batch_size = parameters['batch_size']
        self.num_batches = parameters['num_batches']
        self.x_shape = parameters['x_shape']
        self.classes = parameters['classes']
        self.save_model_path = f'{parameters["save_dir"]}/{label}_model.pt'
        self.best_acc = 0.
        self.dropout = parameters['dropout']
        self.init_net(parameters)

    def init_net(self, parameters):
        """Create the save directory, network, optimiser and LR scheduler."""
        if not os.path.exists(parameters["save_dir"]):
            os.makedirs(parameters["save_dir"])

        model_params = {
            'input_shape': self.x_shape,
            'classes': self.classes,
            'batch_size': self.batch_size,
            'hidden_units': self.hidden_units,
            'mode': self.mode,
            'dropout': self.dropout,
        }
        # A truthy ``dropout`` setting selects the dropout variant.
        if self.dropout:
            self.net = MLP_Dropout(model_params).to(DEVICE)
            print('MLP Dropout Parameters: ')
        else:
            self.net = MLP(model_params).to(DEVICE)
            print('MLP Parameters: ')
        print(f'batch size: {self.batch_size}, input shape: {model_params["input_shape"]}, hidden units: {model_params["hidden_units"]}, output shape: {model_params["classes"]}, lr: {self.lr}')
        self.optimiser = torch.optim.SGD(self.net.parameters(), lr=self.lr)
        # Halves the learning rate every 100 scheduler steps.
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimiser, step_size=100, gamma=0.5)

    def train_step(self, train_data):
        """Run one optimisation pass over every ``(x, y)`` batch in ``train_data``.

        ``self.loss_info`` is left holding the summed cross-entropy of the
        most recent batch.
        """
        self.net.train()
        for x, y in tqdm(train_data):
            x, y = x.to(DEVICE), y.to(DEVICE)
            self.net.zero_grad()
            self.loss_info = torch.nn.functional.cross_entropy(self.net(x), y, reduction='sum')
            self.loss_info.backward()
            self.optimiser.step()

    def predict(self, X):
        """Return ``(preds, probs)`` for a batch ``X``.

        ``probs`` is the softmax of the network output over dimension 1 and
        ``preds`` the arg-max class index of each row.
        """
        probs = torch.nn.Softmax(dim=1)(self.net(X))
        preds = torch.argmax(probs, dim=1)
        return preds, probs

    def evaluate(self, test_loader):
        """Compute accuracy over ``test_loader`` and store it in ``self.acc``.

        Raises:
            ZeroDivisionError: if ``test_loader`` yields no samples.
        """
        self.net.eval()
        print('Evaluating on validation data')
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y in tqdm(test_loader):
                X, y = X.to(DEVICE), y.to(DEVICE)
                preds, _ = self.predict(X)
                # Count the samples actually seen: the last batch may be
                # smaller than ``self.batch_size``.
                total += y.size(0)
                correct += (preds == y).sum().item()
        self.acc = correct / total
        print(f'{self.label} validation accuracy: {self.acc}')

    def log_progress(self, step):
        """Write the latest batch loss and validation accuracy to the writer."""
        write_loss(self.writer, self.loss_info, step)
        write_acc(self.writer, self.acc, step)
示例#5
0
def _split_fold_indices(fold_ids, split, n_folds=5):
    """Return ``(train_idx, test_idx)`` using fold ``split`` as the test fold.

    ``fold_ids`` maps the string keys ``'0'`` .. ``str(n_folds - 1)`` to index
    collections. The training indices are the remaining folds combined with
    ``+`` in ascending fold order.
    """
    if split not in range(n_folds):
        raise ValueError(
            'split must be an integer in [0, {}], got {!r}'.format(
                n_folds - 1, split))
    train_keys = [str(k) for k in range(n_folds) if k != split]
    train_idx = fold_ids[train_keys[0]]
    for key in train_keys[1:]:
        train_idx = train_idx + fold_ids[key]
    return train_idx, fold_ids[str(split)]


def _standardize(train_x, test_x, mean, dev):
    """Apply ``(x - mean) / dev`` to both the training and the test tensor."""
    return (train_x - mean) / dev, (test_x - mean) / dev


def main(args):
    """Train four site-local MLPs with periodic noisy weight averaging.

    One model per site (NYU, UM, USM, UCLA) is trained on that site's data.
    Every ``args.pace`` local steps, and on the last step of an epoch, the
    local weights are perturbed with Gaussian (``args.type == 'G'``) or
    Laplace noise, averaged with equal weight into a global model, and copied
    back to every site. After training, the global weights, the last epoch's
    per-site test outputs and the per-site training-loss history are saved.

    ``args`` must provide: ``seed``, ``res_dir``, ``type``, ``noise``,
    ``pace``, ``model_dir``, ``vec_dir``, ``overlap``, ``split``, ``sepnorm``,
    ``nsteps``, ``test_batch_size1``..``test_batch_size4``, ``dim``,
    ``lr1``..``lr4`` and ``epochs``.

    Raises:
        ValueError: if ``args.split`` is not one of 0..4.
    """
    torch.manual_seed(args.seed)
    sites = ('NYU', 'UM', 'USM', 'UCLA')
    n_sites = len(sites)

    res_dir = os.path.join(args.res_dir, args.type + str(args.noise),
                           str(args.pace))
    os.makedirs(res_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)

    # ---- load features / labels and the fold index files -----------------
    xs, ys = [], []
    for site in sites:
        raw = dd.io.load(
            os.path.join(args.vec_dir, site + '_correlation_matrix.h5'))
        xs.append(torch.from_numpy(raw['data']).float())
        ys.append(torch.from_numpy(raw['label']).long())

    idx_suffix = '_sub_overlap.h5' if args.overlap else '_sub.h5'
    fold_ids = [dd.io.load('./idx/' + site + idx_suffix) for site in sites]

    # ---- train / test split ----------------------------------------------
    x_train, y_train, x_test, y_test = [], [], [], []
    for x, y, ids in zip(xs, ys, fold_ids):
        tr, te = _split_fold_indices(ids, args.split)
        x_train.append(x[tr])
        y_train.append(y[tr])
        x_test.append(x[te])
        y_test.append(y[te])

    # ---- standardisation with training statistics only --------------------
    if args.sepnorm:
        # Each site is normalised with its own training mean / std.
        for i in range(n_sites):
            mean = x_train[i].mean(0, keepdim=True)
            dev = x_train[i].std(0, keepdim=True)
            x_train[i], x_test[i] = _standardize(x_train[i], x_test[i], mean,
                                                 dev)
    else:
        # All sites share the statistics of the pooled training data.
        pooled = torch.cat(x_train, 0)
        mean = pooled.mean(0, keepdim=True)
        dev = pooled.std(0, keepdim=True)
        for i in range(n_sites):
            x_train[i], x_test[i] = _standardize(x_train[i], x_test[i], mean,
                                                 dev)

    # ---- data loaders ------------------------------------------------------
    train_sets = [TensorDataset(x, y) for x, y in zip(x_train, y_train)]
    train_loaders = [
        DataLoader(ds, batch_size=len(ds) // args.nsteps, shuffle=True)
        for ds in train_sets
    ]
    train_loader = DataLoader(ConcatDataset(train_sets),
                              batch_size=500,
                              shuffle=False)

    test_batch_sizes = [
        args.test_batch_size1, args.test_batch_size2, args.test_batch_size3,
        args.test_batch_size4
    ]
    test_loaders = [
        DataLoader(TensorDataset(x, y), batch_size=bs, shuffle=False)
        for x, y, bs in zip(x_test, y_test, test_batch_sizes)
    ]

    # ---- models: the four local models are built before the global one ----
    models = [MLP(6105, args.dim, 2).to(device) for _ in range(n_sites)]
    learning_rates = [args.lr1, args.lr2, args.lr3, args.lr4]
    optimizers = [
        optim.Adam(m.parameters(), lr=lr, weight_decay=5e-2)
        for m, lr in zip(models, learning_rates)
    ]

    model = MLP(6105, args.dim, 2).to(device)
    print(model)
    nnloss = nn.NLLLoss()

    def aggregate(weights):
        """Average the noised local weights into ``model`` and broadcast them."""
        noise_cls = tdist.Normal if args.type == 'G' else tdist.Laplace
        with torch.no_grad():
            # state_dict() tensors share storage with the live parameters, so
            # building the dicts once and copying in place updates the models.
            global_state = model.state_dict()
            local_states = [m.state_dict() for m in models]
            for key, global_param in global_state.items():
                if local_states[0][key].dtype == torch.int64:
                    # Integer entries are taken from the first site without
                    # noise and are not broadcast back to the local models.
                    global_param.data.copy_(local_states[0][key])
                    continue

                temp = torch.zeros_like(global_param)
                for s in range(n_sites):
                    local_param = local_states[s][key]
                    # Noise scale is args.noise times the std of this tensor.
                    noise_dist = noise_cls(
                        torch.tensor([0.0]),
                        args.noise * torch.std(local_param.detach().cpu()))
                    noise = noise_dist.sample(
                        local_param.size()).squeeze().to(device)
                    temp += weights[s] * (local_param + noise)

                # Update the global model, then every local model.
                global_param.data.copy_(temp)
                for s in range(n_sites):
                    local_states[s][key].data.copy_(global_param)

    def train(epoch):
        """Run one federated epoch and return the mean loss of each site."""
        pace = args.pace
        for local_model, optimizer in zip(models, optimizers):
            local_model.train()
            # Learning rate is halved every 20 epochs (including epoch 0).
            if epoch % 20 == 0:
                for group in optimizer.param_groups:
                    group['lr'] = 0.5 * group['lr']

        # Uniform averaging weights across sites.
        weights = [1.0 / n_sites] * n_sites

        loss_all = [0.0] * n_sites
        num_data = [0] * n_sites
        for t in range(args.nsteps):
            for i in range(n_sites):
                optimizers[i].zero_grad()
                # NOTE(review): a fresh iterator is created every step, so
                # each step draws the first batch of a new shuffle rather
                # than walking through one pass of the loader.
                a, b = next(iter(train_loaders[i]))
                num_data[i] += b.size(0)
                a = a.to(device)
                b = b.to(device)
                loss = nnloss(models[i](a), b)
                loss.backward()
                loss_all[i] += loss.item() * b.size(0)
                optimizers[i].step()
            if (t + 1) % pace == 0 or t == args.nsteps - 1:
                aggregate(weights)

        return tuple(loss_all[i] / num_data[i] for i in range(n_sites))

    def test(federated_model, dataloader, train=False):
        """Evaluate ``federated_model`` on ``dataloader``.

        Returns ``(loss, accuracy, targets, outputs, preds)`` where loss and
        accuracy are averaged over the dataset and the last three are
        per-batch lists of NumPy arrays.
        """
        federated_model.eval()
        test_loss = 0
        correct = 0
        outputs = []
        preds = []
        targets = []
        with torch.no_grad():
            for data, target in dataloader:
                # NOTE(review): only the first label of each batch is kept in
                # ``targets``; confirm this is intended before relying on it.
                targets.append(target[0].detach().numpy())
                data = data.to(device)
                target = target.to(device)
                output = federated_model(data)
                outputs.append(output.detach().cpu().numpy())
                test_loss += nnloss(output, target).item() * target.size(0)
                pred = output.data.max(1)[1]
                preds.append(pred.detach().cpu().numpy())
                correct += pred.eq(target.view(-1)).sum().item()

        test_loss /= len(dataloader.dataset)
        correct /= len(dataloader.dataset)
        if train:
            print('Train set local: Average loss: {:.4f}, Average acc: {:.4f}'.
                  format(test_loss, correct))
        else:
            print('Test set local: Average loss: {:.4f}, Average acc: {:.4f}'.
                  format(test_loss, correct))
        return test_loss, correct, targets, outputs, preds

    best_acc = 0
    train_loss = {i: [] for i in range(n_sites)}
    site_results = {}  # last epoch's outputs / preds / targets per site
    for epoch in range(args.epochs):
        start_time = time.time()
        print(f"Epoch Number {epoch + 1}")
        losses = train(epoch)
        print(
            ' L1 loss: {:.4f}, L2 loss: {:.4f}, L3 loss: {:.4f}, L4 loss: {:.4f}'
            .format(*losses))
        for i, site_loss in enumerate(losses):
            train_loss[i].append(site_loss)
        test(model, train_loader, train=True)

        accs = []
        for site, loader in zip(sites, test_loaders):
            print('===' + site + '===')
            _, acc, targets, outputs, preds = test(model, loader, train=False)
            accs.append(acc)
            site_results[site] = {
                'outputs': outputs,
                'preds': preds,
                'targets': targets
            }
        mean_acc = sum(accs) / n_sites
        if mean_acc > best_acc:
            best_acc = mean_acc
        total_time = time.time() - start_time
        print('Communication time over the network', round(total_time, 2),
              's\n')

    model_wts = copy.deepcopy(model.state_dict())
    torch.save(model_wts, os.path.join(args.model_dir,
                                       str(args.split) + '.pth'))
    for site in sites:
        dd.io.save(
            os.path.join(res_dir, site + '_' + str(args.split) + '.h5'),
            site_results[site])
    dd.io.save(os.path.join(res_dir, 'train_loss.h5'), {'loss': train_loss})
    print('Best Acc:', best_acc)
    print('split:', args.split, '   noise:', args.noise, '   pace:', args.pace)
示例#6
0
class VanillaAE(nn.Module):
    """MLP auto-encoder trained with a pixel MSE loss plus focal frequency loss."""

    def __init__(self, opt):
        """Build the generator, losses and optimizer from the options ``opt``.

        Reads ``no_cuda``, ``nc``, ``imageSize``, ``nz``, ``nblk``, ``netG``,
        ``train_log_file``, the focal-frequency-loss settings (``ffl_w``,
        ``alpha``, ``patch_factor``, ``ave_spectrum``, ``log_matrix``,
        ``batch_matrix``) and the Adam settings (``lr``, ``beta1``, ``beta2``).
        """
        super().__init__()
        self.opt = opt
        device_name = "cpu" if opt.no_cuda else "cuda:0"
        self.device = torch.device(device_name)

        channels = int(opt.nc)
        side = int(opt.imageSize)
        latent_dim = int(opt.nz)
        n_blocks = int(opt.nblk)
        # Input and output width are both channels * side * side.
        flat_dim = channels * side * side

        # generator
        generator = MLP(input_dim=flat_dim,
                        output_dim=flat_dim,
                        dim=latent_dim,
                        n_blk=n_blocks,
                        norm='none',
                        activ='relu')
        self.netG = generator.to(self.device)
        weights_init(self.netG)
        # An empty ``opt.netG`` means there is no checkpoint to load.
        if opt.netG != '':
            checkpoint = torch.load(opt.netG, map_location=self.device)
            self.netG.load_state_dict(checkpoint)
        print_and_write_log(opt.train_log_file, 'netG:')
        print_and_write_log(opt.train_log_file, str(self.netG))

        # losses: pixel-level MSE and focal frequency loss
        self.criterion = nn.MSELoss()
        freq_loss = FFL(loss_weight=opt.ffl_w,
                        alpha=opt.alpha,
                        patch_factor=opt.patch_factor,
                        ave_spectrum=opt.ave_spectrum,
                        log_matrix=opt.log_matrix,
                        batch_matrix=opt.batch_matrix)
        self.criterion_freq = freq_loss.to(self.device)

        # misc
        self.to(self.device)

        # optimizer
        self.optimizerG = optim.Adam(self.netG.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.beta1, opt.beta2))

    def forward(self):
        """Do nothing; the model is driven via gen_update() and sample()."""
        return None

    def gen_update(self, data, epoch, matrix=None):
        """Take one optimizer step on ``data``.

        The pixel loss is the MSE scaled by ``opt.mse_w``. The frequency loss
        is applied only from epoch ``opt.freq_start_epoch`` onwards and is a
        zero tensor before that.

        Returns:
            tuple ``(errG_pix, errG_freq)`` of loss tensors.
        """
        self.netG.zero_grad()
        real = data.to(self.device)
        if matrix is not None:
            matrix = matrix.to(self.device)
        recon = self.netG(real)

        errG_pix = self.criterion(recon, real) * self.opt.mse_w

        if epoch < self.opt.freq_start_epoch:
            errG_freq = torch.tensor(0.0).to(self.device)
        else:
            errG_freq = self.criterion_freq(recon, real, matrix)

        total_loss = errG_pix + errG_freq
        total_loss.backward()
        self.optimizerG.step()

        return errG_pix, errG_freq

    def sample(self, x):
        """Reconstruct ``x`` without gradients, then restore training mode."""
        inputs = x.to(self.device)
        self.netG.eval()
        with torch.no_grad():
            recon = self.netG(inputs)
        self.netG.train()

        return recon

    def save_checkpoints(self, ckpt_dir, epoch):
        """Save the generator weights to ``<ckpt_dir>/netG_epoch_<epoch>.pth``."""
        ckpt_path = '%s/netG_epoch_%03d.pth' % (ckpt_dir, epoch)
        torch.save(self.netG.state_dict(), ckpt_path)
示例#7
0
def main(args):
    torch.manual_seed(args.seed)
    if not os.path.exists(args.res_dir):
        os.mkdir(args.res_dir)
    if not os.path.exists(args.model_dir):
        os.mkdir(args.model_dir)

    data1 = dd.io.load(os.path.join(args.vec_dir, 'NYU_correlation_matrix.h5'))
    data2 = dd.io.load(os.path.join(args.vec_dir, 'UM_correlation_matrix.h5'))
    data3 = dd.io.load(os.path.join(args.vec_dir, 'USM_correlation_matrix.h5'))
    data4 = dd.io.load(os.path.join(args.vec_dir,
                                    'UCLA_correlation_matrix.h5'))

    x1 = torch.from_numpy(data1['data']).float()
    y1 = torch.from_numpy(data1['label']).long()
    x2 = torch.from_numpy(data2['data']).float()
    y2 = torch.from_numpy(data2['label']).long()
    x3 = torch.from_numpy(data3['data']).float()
    y3 = torch.from_numpy(data3['label']).long()
    x4 = torch.from_numpy(data4['data']).float()
    y4 = torch.from_numpy(data4['label']).long()

    if args.overlap:
        idNYU = dd.io.load('./idx/NYU_sub_overlap.h5')
        idUM = dd.io.load('./idx/UM_sub_overlap.h5')
        idUSM = dd.io.load('./idx/USM_sub_overlap.h5')
        idUCLA = dd.io.load('./idx/UCLA_sub_overlap.h5')
    else:
        idNYU = dd.io.load('./idx/NYU_sub.h5')
        idUM = dd.io.load('./idx/UM_sub.h5')
        idUSM = dd.io.load('./idx/USM_sub.h5')
        idUCLA = dd.io.load('./idx/UCLA_sub.h5')

    if args.split == 0:
        tr1 = idNYU['1'] + idNYU['2'] + idNYU['3'] + idNYU['4']
        tr2 = idUM['1'] + idUM['2'] + idUM['3'] + idUM['4']
        tr3 = idUSM['1'] + idUSM['2'] + idUSM['3'] + idUSM['4']
        tr4 = idUCLA['1'] + idUCLA['2'] + idUCLA['3'] + idUCLA['4']
        te1 = idNYU['0']
        te2 = idUM['0']
        te3 = idUSM['0']
        te4 = idUCLA['0']
    elif args.split == 1:
        tr1 = idNYU['0'] + idNYU['2'] + idNYU['3'] + idNYU['4']
        tr2 = idUM['0'] + idUM['2'] + idUM['3'] + idUM['4']
        tr3 = idUSM['0'] + idUSM['2'] + idUSM['3'] + idUSM['4']
        tr4 = idUCLA['0'] + idUCLA['2'] + idUCLA['3'] + idUCLA['4']
        te1 = idNYU['1']
        te2 = idUM['1']
        te3 = idUSM['1']
        te4 = idUCLA['1']
    elif args.split == 2:
        tr1 = idNYU['0'] + idNYU['1'] + idNYU['3'] + idNYU['4']
        tr2 = idUM['0'] + idUM['1'] + idUM['3'] + idUM['4']
        tr3 = idUSM['0'] + idUSM['1'] + idUSM['3'] + idUSM['4']
        tr4 = idUCLA['0'] + idUCLA['1'] + idUCLA['3'] + idUCLA['4']
        te1 = idNYU['2']
        te2 = idUM['2']
        te3 = idUSM['2']
        te4 = idUCLA['2']
    elif args.split == 3:
        tr1 = idNYU['0'] + idNYU['1'] + idNYU['2'] + idNYU['4']
        tr2 = idUM['0'] + idUM['1'] + idUM['2'] + idUM['4']
        tr3 = idUSM['0'] + idUSM['1'] + idUSM['2'] + idUSM['4']
        tr4 = idUCLA['0'] + idUCLA['1'] + idUCLA['2'] + idUCLA['4']
        te1 = idNYU['3']
        te2 = idUM['3']
        te3 = idUSM['3']
        te4 = idUCLA['3']
    elif args.split == 4:
        tr1 = idNYU['0'] + idNYU['1'] + idNYU['2'] + idNYU['3']
        tr2 = idUM['0'] + idUM['1'] + idUM['2'] + idUM['3']
        tr3 = idUSM['0'] + idUSM['1'] + idUSM['2'] + idUSM['3']
        tr4 = idUCLA['0'] + idUCLA['1'] + idUCLA['2'] + idUCLA['3']
        te1 = idNYU['4']
        te2 = idUM['4']
        te3 = idUSM['4']
        te4 = idUCLA['4']

    x1_train = x1[tr1]
    y1_train = y1[tr1]
    x2_train = x2[tr2]
    y2_train = y2[tr2]
    x3_train = x3[tr3]
    y3_train = y3[tr3]
    x4_train = x4[tr4]
    y4_train = y4[tr4]

    x1_test = x1[te1]
    y1_test = y1[te1]
    x2_test = x2[te2]
    y2_test = y2[te2]
    x3_test = x3[te3]
    y3_test = y3[te3]
    x4_test = x4[te4]
    y4_test = y4[te4]

    if args.sepnorm:
        mean = x1_train.mean(0, keepdim=True)
        dev = x1_train.std(0, keepdim=True)
        x1_train = (x1_train - mean) / dev
        x1_test = (x1_test - mean) / dev

        mean = x2_train.mean(0, keepdim=True)
        dev = x2_train.std(0, keepdim=True)
        x2_train = (x2_train - mean) / dev
        x2_test = (x2_test - mean) / dev

        mean = x3_train.mean(0, keepdim=True)
        dev = x3_train.std(0, keepdim=True)
        x3_train = (x3_train - mean) / dev
        x3_test = (x3_test - mean) / dev

        mean = x4_train.mean(0, keepdim=True)
        dev = x4_train.std(0, keepdim=True)
        x4_train = (x4_train - mean) / dev
        x4_test = (x4_test - mean) / dev
    else:
        mean = torch.cat((x1_train, x2_train, x3_train, x4_train),
                         0).mean(0, keepdim=True)
        dev = torch.cat((x1_train, x2_train, x3_train, x4_train),
                        0).std(0, keepdim=True)
        x1_train = (x1_train - mean) / dev
        x1_test = (x1_test - mean) / dev
        x2_train = (x2_train - mean) / dev
        x2_test = (x2_test - mean) / dev
        x3_train = (x3_train - mean) / dev
        x3_test = (x3_test - mean) / dev
        x4_train = (x4_train - mean) / dev
        x4_test = (x4_test - mean) / dev

    train = TensorDataset(
        torch.cat((x1_train, x2_train, x3_train, x4_train), 0),
        torch.cat((y1_train, y2_train, y3_train, y4_train), 0))
    train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True)

    test1 = TensorDataset(x1_test, y1_test)
    test_loader1 = DataLoader(test1,
                              batch_size=args.test_batch_size1,
                              shuffle=False)
    test2 = TensorDataset(x2_test, y2_test)
    test_loader2 = DataLoader(test2,
                              batch_size=args.test_batch_size2,
                              shuffle=False)
    test3 = TensorDataset(x3_test, y3_test)
    test_loader3 = DataLoader(test3,
                              batch_size=args.test_batch_size3,
                              shuffle=False)
    test4 = TensorDataset(x4_test, y4_test)
    test_loader4 = DataLoader(test4,
                              batch_size=args.test_batch_size4,
                              shuffle=False)

    model = MLP(6105, args.dim, 2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-2)
    print(model)
    nnloss = nn.NLLLoss()

    def train(data_loader, epoch):
        """Run one training epoch and return the mean per-sample loss.

        Args:
            data_loader: iterable of ``(data, target)`` batches that also
                exposes ``.dataset`` (used to average the loss).
            epoch: zero-based epoch index; only used for the lr schedule.

        Returns:
            Sum of ``batch_loss * batch_size`` over all batches divided by
            ``len(data_loader.dataset)``.
        """
        model.train()

        # Halve the learning rate every 20 epochs. The previous code had two
        # byte-identical branches (epoch <= 50 / epoch > 50); merged here.
        # NOTE(review): epoch 0 also matches, so the configured lr is halved
        # before the very first update -- confirm this is intended.
        if epoch % 20 == 0:
            for group in optimizer.param_groups:
                group['lr'] = 0.5 * group['lr']

        running_loss = 0

        for data, target in data_loader:
            optimizer.zero_grad()
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            loss = nnloss(output, target)
            loss.backward()
            # Weight the batch loss by the batch size so that a smaller final
            # batch does not skew the epoch average.
            running_loss += loss.item() * target.size(0)
            # Clip after backward() and before step() so the clipped
            # gradients are the ones applied.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
        return running_loss / len(data_loader.dataset)

    def test(data_loader, train=False):
        """Evaluate ``model`` on ``data_loader`` and print loss and accuracy.

        Args:
            data_loader: iterable of ``(data, target)`` batches that also
                exposes ``.dataset``.
            train: only selects the label of the printed line
                ("Train set" vs "Test set").

        Returns:
            ``(test_loss, correct, targets, outputs, preds)``: the mean loss,
            the fraction of correctly classified samples, and three lists
            with one entry per batch (first target of the batch, raw model
            outputs, predicted class indices) as numpy arrays.
        """
        model.eval()
        test_loss = 0
        correct = 0
        outputs = []
        preds = []
        targets = []
        # Evaluation needs no gradients; no_grad avoids building autograd
        # graphs for every batch without changing any computed value.
        with torch.no_grad():
            for data, target in data_loader:
                data = data.to(device)
                # Only the FIRST label of each batch is recorded.
                # NOTE(review): presumably every test batch holds samples
                # that share one label -- confirm against the batch sizes.
                targets.append(target[0].detach().numpy())
                target = target.to(device)
                output = model(data)
                outputs.append(output.detach().cpu().numpy())
                test_loss += nnloss(output, target).item() * target.size(0)
                # Index of the largest output per row is the predicted class.
                pred = output.max(1)[1]
                preds.append(pred.detach().cpu().numpy())
                correct += pred.eq(target.view(-1)).sum().item()

        test_loss /= len(data_loader.dataset)
        correct /= len(data_loader.dataset)
        if train:
            print(
                'Train set: Average loss: {:.4f}, Average acc: {:.4f}'.format(
                    test_loss, correct))
        else:
            print('Test set: Average loss: {:.4f}, Average acc: {:.4f}'.format(
                test_loss, correct))
        return test_loss, correct, targets, outputs, preds

    best_acc = 0
    best_epoch = 0
    for epoch in range(args.epochs):
        start_time = time.time()
        print(f"Epoch Number {epoch + 1}")
        l1 = train(train_loader, epoch)
        test(train_loader, train=True)
        print(' L1 loss: {:.4f}'.format(l1))
        print('===NYU===')
        _, acc1, targets1, outputs1, preds1 = test(test_loader1, train=False)
        print('===UM===')
        _, acc2, targets2, outputs2, preds2 = test(test_loader2, train=False)
        print('===USM===')
        _, acc3, targets3, outputs3, preds3 = test(test_loader3, train=False)
        print('===UCLA===')
        _, acc4, targets4, outputs4, preds4 = test(test_loader4, train=False)
        if (acc1 + acc2 + acc3 + acc4) / 4 > best_acc:
            best_acc = (acc1 + acc2 + acc3 + acc4) / 4
            best_epoch = epoch
        total_time = time.time() - start_time
        print('Communication time over the network', round(total_time, 2),
              's\n')
    model_wts = copy.deepcopy(model.state_dict())
    torch.save(model_wts, os.path.join(args.model_dir,
                                       str(args.split) + '.pth'))
    dd.io.save(os.path.join(args.res_dir, 'NYU_' + str(args.split) + '.h5'), {
        'outputs': outputs1,
        'preds': preds1,
        'targets': targets1
    })
    dd.io.save(os.path.join(args.res_dir, 'UM_' + str(args.split) + '.h5'), {
        'outputs': outputs2,
        'preds': preds2,
        'targets': targets2
    })
    dd.io.save(os.path.join(args.res_dir, 'USM_' + str(args.split) + '.h5'), {
        'outputs': outputs3,
        'preds': preds3,
        'targets': targets3
    })
    dd.io.save(os.path.join(args.res_dir, 'UCLA_' + str(args.split) + '.h5'), {
        'outputs': outputs4,
        'preds': preds4,
        'targets': targets4
    })
    print('Best Acc:', best_acc, 'Best Epoch:', best_epoch)
    print('split:', args.split)
示例#8
0
def main(args):
    """Train an MLP on one site's data and evaluate it on all four sites.

    Loads the ``<SITE>_correlation_matrix.h5`` files for NYU, UM, USM and
    UCLA from ``args.vec_dir``, standardises the features, trains on the site
    named by ``args.trainsite`` and, after every epoch, evaluates on each of
    the four sites. The final weights are written to
    ``<model_dir>/<trainsite>.pth`` and the last epoch's per-site results to
    ``<res_dir>/<trainsite>/<SITE>.h5``.

    Args:
        args: namespace providing ``res_dir``, ``model_dir``, ``vec_dir``,
            ``trainsite``, ``seed``, ``sepnorm``, ``batch_size``,
            ``test_batch_size1`` .. ``test_batch_size4``, ``dim``, ``lr`` and
            ``epochs``.

    Raises:
        ValueError: if ``args.trainsite`` is not one of the four known sites
            (previously this surfaced later as a ``NameError``).
    """
    site_names = ('NYU', 'UM', 'USM', 'UCLA')
    if args.trainsite not in site_names:
        raise ValueError('unknown trainsite {!r}; expected one of {}'.format(
            args.trainsite, site_names))

    # makedirs(exist_ok=True) replaces the check-then-mkdir pairs: it creates
    # missing parents and does not fail if the directory already exists.
    os.makedirs(os.path.join(args.res_dir, args.trainsite), exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)

    torch.manual_seed(args.seed)

    features = {}
    labels = {}
    for site in site_names:
        raw = dd.io.load(
            os.path.join(args.vec_dir, site + '_correlation_matrix.h5'))
        features[site] = torch.from_numpy(raw['data']).float()
        labels[site] = torch.from_numpy(raw['label']).long()

    if args.sepnorm:
        # Standardise every site with its own per-feature statistics.
        for site in site_names:
            x = features[site]
            mean = x.mean(0, keepdim=True)
            dev = x.std(0, keepdim=True)
            features[site] = (x - mean) / dev
    else:
        # Standardise every site with the training site's statistics.
        # mean/dev are computed before any tensor is overwritten.
        reference = features[args.trainsite]
        mean = reference.mean(0, keepdim=True)
        dev = reference.std(0, keepdim=True)
        for site in site_names:
            features[site] = (features[site] - mean) / dev

    datasets = {
        site: TensorDataset(features[site], labels[site])
        for site in site_names
    }

    train_loader = DataLoader(datasets[args.trainsite],
                              batch_size=args.batch_size,
                              shuffle=True)

    # One evaluation loader per site, each with its own batch size.
    test_batch_sizes = (args.test_batch_size1, args.test_batch_size2,
                        args.test_batch_size3, args.test_batch_size4)
    test_loaders = {
        site: DataLoader(datasets[site], batch_size=size, shuffle=False)
        for site, size in zip(site_names, test_batch_sizes)
    }

    model = MLP(6105, args.dim, 2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-2)
    print(model)
    nnloss = nn.NLLLoss()

    def train(data_loader, epoch):
        """Run one training epoch; return the mean per-sample loss."""
        model.train()
        # Halve the learning rate every 20 epochs (the two former branches
        # for epoch <= 50 and epoch > 50 were identical).
        # NOTE(review): epoch 0 also matches, so the configured lr is halved
        # before the first update -- confirm this is intended.
        if epoch % 20 == 0:
            for group in optimizer.param_groups:
                group['lr'] = 0.5 * group['lr']

        running_loss = 0

        for data, target in data_loader:
            optimizer.zero_grad()
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            loss = nnloss(output, target)
            loss.backward()
            running_loss += loss.item() * target.size(0)
            optimizer.step()

        return running_loss / len(data_loader.dataset)

    def test(data_loader, train=False):
        """Evaluate ``model``; print and return loss, accuracy and outputs."""
        model.eval()
        test_loss = 0
        correct = 0
        outputs = []
        preds = []
        targets = []
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for data, target in data_loader:
                data = data.to(device)
                # Only the first label of each batch is recorded.
                # NOTE(review): presumably each test batch shares one label
                # -- confirm against the test_batch_size* settings.
                targets.append(target[0].detach().numpy())
                target = target.to(device)
                # The model is used directly instead of a `federated_model`
                # name that was only bound as a side effect of the epoch loop.
                output = model(data)
                outputs.append(output.detach().cpu().numpy())
                test_loss += nnloss(output, target).item() * target.size(0)
                pred = output.max(1)[1]
                preds.append(pred.detach().cpu().numpy())
                correct += pred.eq(target.view(-1)).sum().item()

        test_loss /= len(data_loader.dataset)
        correct /= len(data_loader.dataset)
        if train:
            print(
                'Train set: Average loss: {:.4f}, Average acc: {:.4f}'.format(
                    test_loss, correct))
        else:
            print('Test set: Average loss: {:.4f}, Average acc: {:.4f}'.format(
                test_loss, correct))
        return test_loss, correct, targets, outputs, preds

    # Results of the most recent epoch, keyed by site name. Stays empty when
    # args.epochs == 0, in which case no result file is written.
    results = {}
    for epoch in range(args.epochs):
        start_time = time.time()
        print(f"Epoch Number {epoch + 1}")
        l1 = train(train_loader, epoch)
        print(' L1 loss: {:.4f}'.format(l1))
        for site in site_names:
            print('===' + site + '===')
            _, _, targets, outputs, preds = test(test_loaders[site],
                                                 train=False)
            results[site] = {
                'outputs': outputs,
                'preds': preds,
                'targets': targets
            }
        total_time = time.time() - start_time
        print('Communication time over the network', round(total_time, 2),
              's\n')

    model_wts = copy.deepcopy(model.state_dict())
    torch.save(model_wts, os.path.join(args.model_dir,
                                       args.trainsite + '.pth'))
    for site, payload in results.items():
        dd.io.save(os.path.join(args.res_dir, args.trainsite, site + '.h5'),
                   payload)
示例#9
0
def file_classify_demo(fobjs: List[FileInfo]):
    """Train and evaluate a small MLP classifier on PCA keyword fingerprints.

    Steps:
      1. Accumulate keyword frequencies over all non-test files.
      2. Ask every file to build its word vector from that vocabulary and
         split the vectors into train and test sets by ``fobj.istest``
         (files with an empty ``kwfreq`` are skipped).
      3. Fit a 20-component PCA on the training vectors and project both sets.
      4. Train ``MLP(20, 7)`` for 301 epochs with Adam and cross-entropy,
         printing train/test accuracy every 20 epochs.
      5. Save the weights, reload them into a fresh model and print its
         outputs.

    Args:
        fobjs: file descriptors exposing ``istest``, ``keywords``, ``kwfreq``,
            ``label``, ``wordvec`` and ``set_wordvec(words)``.
    """
    # Vocabulary: keyword -> total frequency over the training files only.
    words = {}
    for fobj in fobjs:
        if fobj.istest:
            continue
        for kw, freq in zip(fobj.keywords, fobj.kwfreq):
            words[kw] = words.get(kw, 0) + freq

    # make keyword score vec: train and test
    all_wordvec = []
    all_wordvec_test = []
    labels_train = []
    labels_test = []
    for fobj in fobjs:
        fobj.set_wordvec(words)
        if not fobj.kwfreq:
            continue
        if fobj.istest:
            all_wordvec_test.append(fobj.wordvec)
            labels_test.append(fobj.label)
        else:
            all_wordvec.append(fobj.wordvec)
            labels_train.append(fobj.label)

    # pca make fingerprints
    inputdim = 20
    outputdim = 7
    pca = PCA(n_components=inputdim)
    pca.fit(all_wordvec)
    fprints = pca.transform(all_wordvec)
    fprints_test = pca.transform(all_wordvec_test)
    print('PCA ratio sum:', sum(pca.explained_variance_ratio_))
    print()

    x_train = torch.from_numpy(fprints).float()
    x_test = torch.from_numpy(fprints_test).float()
    y_train = torch.Tensor(labels_train).long()
    # Tensor copy of the test labels, used only for the accuracy computation.
    y_test_labels = torch.Tensor(labels_test).long()

    train_dataset = TensorDataset(x_train, y_train)
    dloader = DataLoader(train_dataset, batch_size=6, shuffle=True)

    model = MLP(inputdim, outputdim)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    lossfunc = nn.CrossEntropyLoss()

    epoch = 300 + 1
    for ecnt in range(epoch):
        # Tensors are used directly: the torch.autograd.Variable wrapper is
        # deprecated and has been a no-op since tensors and variables merged.
        for inputs, labels in dloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = lossfunc(outputs, labels)
            loss.backward()
            # Gradient clipping was flagged as important by the original
            # author; keep it between backward() and step().
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

        if ecnt % 20 == 0:
            print('Epoch:', ecnt)
            model.eval()
            # Evaluate without autograd and compare predictions with tensor
            # ops instead of numpy calls on torch tensors.
            with torch.no_grad():
                train_pred = model(x_train).argmax(dim=1)
                test_pred = model(x_test).argmax(dim=1)
            tran_accu = (train_pred == y_train).sum().item() / len(
                labels_train)
            test_accu = (test_pred == y_test_labels).sum().item() / len(
                labels_test)

            print('tran_accu', tran_accu)
            print('test_accu', test_accu)
            print()
            model.train()

    # save model
    # NOTE(review): the path uses Windows separators and a relative `ai`
    # directory; it is kept unchanged here.
    save_path = r'.\ai\classify-demo.pth'
    torch.save(model.state_dict(), save_path)

    # load model
    new_model = MLP(inputdim, outputdim)
    new_model.load_state_dict(torch.load(save_path))
    # A freshly constructed module is in training mode; switch to eval mode
    # so inference is deterministic if MLP contains dropout/batch-norm layers.
    new_model.eval()

    with torch.no_grad():
        y_train_look = new_model(x_train)
        y_test = new_model(x_test)
    print(y_train_look)
    print(y_test)
    print(y_test.argmax(dim=1))
    print(labels_test)
    print(len(labels_test))
    print(len(labels_train))
示例#10
0
def main(args):
    """Train and evaluate an MLP on one site using one of five folds.

    Loads ``<site>_correlation_matrix.h5`` from ``args.vec_dir``, creates (or
    reuses) a five-fold split stored in ``<id_dir>/<site>_sub_overlap.h5``,
    holds out fold ``args.split`` for testing, standardises the features with
    the training statistics and trains for ``args.epochs`` epochs. The final
    weights go to ``<model_dir>/<site>/<split>.pth`` and the last epoch's test
    results to ``<res_dir>/<site>_<split>.h5``.

    Args:
        args: namespace providing ``seed``, ``res_dir``, ``model_dir``,
            ``vec_dir``, ``id_dir``, ``site``, ``split``, ``batch_size``,
            ``dim``, ``lr`` and ``epochs``.

    Raises:
        ValueError: if ``args.site`` is not a known site or ``args.split`` is
            not in ``0..4`` (both previously surfaced as a ``NameError``).
    """
    torch.manual_seed(args.seed)

    # Number of consecutive rows that form one group; folds are drawn at
    # group level and the test loader uses it as batch size.
    # NOTE(review): presumably one group == one subject's samples -- confirm.
    # The trailing numbers are the original author's annotations.
    reps_by_site = {
        'NYU': 145,  # 7
        'UM': 265,  # 9
        'USM': 205,  # 8
        'UCLA': 85,  # 7
    }
    if args.site not in reps_by_site:
        raise ValueError('unknown site {!r}; expected one of {}'.format(
            args.site, sorted(reps_by_site)))
    if args.split not in range(5):
        raise ValueError('split must be in 0..4, got {!r}'.format(args.split))
    rep = reps_by_site[args.site]

    save_model_dir = os.path.join(args.model_dir, args.site)
    # makedirs(exist_ok=True) replaces the check-then-mkdir pairs.
    os.makedirs(args.res_dir, exist_ok=True)
    os.makedirs(save_model_dir, exist_ok=True)

    data = dd.io.load(
        os.path.join(args.vec_dir, args.site + '_correlation_matrix.h5'))
    x = torch.from_numpy(data['data']).float()
    y = torch.from_numpy(data['label']).long()

    split_dir = os.path.join(args.id_dir, args.site + '_sub_overlap.h5')

    if not os.path.exists(split_dir):  # save splitting
        n = len(y) // rep
        group_order = list(range(n))
        random.seed(args.seed)
        random.shuffle(group_order)
        folds = {}
        for i in range(5):  # 5 splits
            # For i == 4 the upper bound is 5 * n // 5 == n, so the last fold
            # takes the remainder without needing a special case.
            groups = group_order[i * n // 5:(i + 1) * n // 5]
            # Expand each group index into the row indices it covers.
            folds[str(i)] = [g * rep + j for g in groups for j in range(rep)]

        dd.io.save(split_dir, folds)
        print("data saved")

    # Named fold_ids rather than `id` to avoid shadowing the builtin.
    fold_ids = dd.io.load(split_dir)

    # Held-out fold for testing; the remaining folds, in ascending order,
    # are concatenated for training.
    te = fold_ids[str(args.split)]
    train_parts = [fold_ids[str(k)] for k in range(5) if k != args.split]
    tr = train_parts[0]
    for part in train_parts[1:]:
        tr = tr + part

    x_train = x[tr]
    y_train = y[tr]
    x_test = x[te]
    y_test = y[te]

    # Standardise with training statistics only, so no test information
    # enters the normalisation.
    mean = x_train.mean(0, keepdim=True)
    dev = x_train.std(0, keepdim=True)
    x_train = (x_train - mean) / dev
    x_test = (x_test - mean) / dev

    # Distinct names: the nested train()/test() functions defined further
    # down used to rebind the `train`/`test` dataset variables.
    train_set = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True)

    test_set = TensorDataset(x_test, y_test)
    test_loader = DataLoader(test_set, batch_size=rep, shuffle=False)

    model = MLP(6105, args.dim, 2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-2)
    print(model)
    # The loss was referenced below but never defined in this function; it is
    # defined locally as in the sibling training scripts.
    # NOTE(review): if a module-level `nnloss` exists, confirm it is NLLLoss.
    nnloss = nn.NLLLoss()

    def train(data_loader, optimizer, epoch):
        """Run one training epoch; return the mean per-sample loss."""
        model.train()
        # Halve the learning rate every 20 epochs (the two former branches
        # for epoch <= 50 and epoch > 50 were identical).
        # NOTE(review): epoch 0 also matches, so the configured lr is halved
        # before the first update -- confirm this is intended.
        if epoch % 20 == 0:
            for group in optimizer.param_groups:
                group['lr'] = 0.5 * group['lr']

        loss_all = 0

        for data, target in data_loader:
            optimizer.zero_grad()
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            loss = nnloss(output, target)
            loss.backward()
            loss_all += loss.item() * target.size(0)
            optimizer.step()

        return loss_all / len(data_loader.dataset)

    def test(data_loader, train):
        """Evaluate ``model``; print metrics, return targets/outputs/preds.

        The tuple is now returned for both values of ``train`` (it used to be
        ``None`` when ``train`` was true).
        """
        model.eval()
        test_loss = 0
        correct = 0
        outputs = []
        preds = []
        targets = []
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for data, target in data_loader:
                data = data.to(device)
                # Only the first label of each batch is recorded.
                targets.append(target[0].detach().numpy())
                target = target.to(device)
                output = model(data)
                outputs.append(output.detach().cpu().numpy())
                test_loss += nnloss(output, target).item() * target.size(0)
                pred = output.max(1)[1]
                preds.append(pred.detach().cpu().numpy())
                correct += pred.eq(target.view(-1)).sum().item()

        test_loss /= len(data_loader.dataset)
        correct /= len(data_loader.dataset)
        if train:
            print(
                'Train set: Average loss: {:.4f}, Average acc: {:.4f}'.format(
                    test_loss, correct))
        else:
            print('Test set: Average loss: {:.4f}, Average acc: {:.4f}'.format(
                test_loss, correct))
        return targets, outputs, preds

    # Defaults so the save below does not hit unbound names if epochs == 0.
    targets, outputs, preds = [], [], []
    for epoch in range(args.epochs):
        start_time = time.time()
        print(f"Epoch Number {epoch + 1}")
        l1 = train(train_loader, optimizer, epoch)
        print(' L1 loss: {:.4f}'.format(l1))
        test(train_loader, train=True)
        targets, outputs, preds = test(test_loader, train=False)
        total_time = time.time() - start_time
        print('Communication time over the network', round(total_time, 2),
              's\n')
    model_wts = copy.deepcopy(model.state_dict())
    torch.save(model_wts, os.path.join(save_model_dir,
                                       str(args.split) + '.pth'))
    dd.io.save(
        os.path.join(args.res_dir, args.site + '_' + str(args.split) + '.h5'),
        {
            'outputs': outputs,
            'preds': preds,
            'targets': targets
        })
    print('site:', args.site, '  split:', args.split)