예제 #1
0
def train(model, criterion, optimizer, loader, epoch):
    """Train `model` for one epoch over `loader`.

    Performs a forward/backward/step cycle per batch, logs running loss and
    top-1 accuracy every ``args.print_freq`` batches (when ``args.save`` is
    set), and appends one end-of-epoch summary row to ``args.out_fname``.

    Relies on module-level ``args``, ``rank``, ``util`` and ``logging``.
    """
    model.train()

    losses = util.Meter(ptag='Loss')
    top1 = util.Meter(ptag='Prec@1')

    # BUG FIX: batch_idx is read after the loop for the end-of-epoch summary;
    # without this initializer an empty loader raised NameError.  -1 in the
    # summary row signals "no batches were processed".
    batch_idx = -1
    for batch_idx, (data, target) in enumerate(loader):
        # data loading
        data = data.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward pass
        output = model(data)
        loss = criterion(output, target)

        # backward pass
        loss.backward()

        # gradient step
        optimizer.step()
        optimizer.zero_grad()

        # write log files
        train_acc = util.comp_accuracy(output, target)

        losses.update(loss.item(), data.size(0))
        top1.update(train_acc[0].item(), data.size(0))

        if batch_idx % args.print_freq == 0 and args.save:
            logging.debug(
                'epoch {} itr {}, '
                'rank {}, loss value {:.4f}, train accuracy {:.3f}'.format(
                    epoch, batch_idx, rank, losses.avg, top1.avg))

            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},'
                      '{loss.val:.4f},{loss.avg:.4f},'
                      '{top1.val:.3f},{top1.avg:.3f},-1'.format(ep=epoch,
                                                                itr=batch_idx,
                                                                loss=losses,
                                                                top1=top1),
                      file=f)

    # End-of-epoch summary row (same CSV layout as the per-batch rows).
    with open(args.out_fname, '+a') as f:
        print('{ep},{itr},'
              '{loss.val:.4f},{loss.avg:.4f},'
              '{top1.val:.3f},{top1.avg:.3f},-1'.format(ep=epoch,
                                                        itr=batch_idx,
                                                        loss=losses,
                                                        top1=top1),
              file=f)
예제 #2
0
def evaluate(model, test_loader):
    """Return the average top-1 accuracy of `model` over `test_loader`.

    Runs in eval mode with gradients disabled; batches are moved to the
    current CUDA device before the forward pass.
    """
    model.eval()
    acc_meter = util.Meter(ptag='Acc')

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            predictions = model(inputs)
            batch_acc = util.comp_accuracy(predictions, labels)
            acc_meter.update(batch_acc[0].item(), inputs.size(0))

    return acc_meter.avg
예제 #3
0
def train(rank, model, criterion, optimizer, scheduler, batch_meter,
          comm_meter, train_loader_list, test_loader_list, epoch, device,
          ue_list_epoches, G, user_weight_diff_array):
    """Run one federated semi-supervised training epoch for this `rank`.

    Each outer iteration samples the active user for this rank from
    `ue_list_epoches`, selects that user's data loader, trains locally,
    and every few iterations averages models across ranks (grouped or
    global all-reduce) and saves "before"/"after" averaged checkpoints.
    When ``args.eval_grad`` is set, the function instead loads a saved
    averaged checkpoint, accumulates gradients over the whole loader, and
    saves gradient/statistics checkpoints for offline analysis.

    Returns ``(user_id, WD_list, user_weight_diff_array)``.  Note that
    ``WD_list`` is never appended to here — it is always returned empty.

    Relies on module-level ``args``, ``util``, ``util_1`` and the
    grouping/all-reduce helpers.
    """
    # Snapshots used when saving "before"/group-average checkpoints below.
    average_model_weights = copy.deepcopy(model.state_dict())
    average_group_model_weights = copy.deepcopy(model.state_dict())

    model.train()
    WD_list = []
    top1 = util.Meter(ptag='Prec@1')
    iter_time = time.time()
    accum_steps = 1
    iteration = 0

    while iteration < args.iteration:

        ue_list = ue_list_epoches[epoch][
            iteration]  ### Get the users (a list) that are involved in the computation

        # This rank impersonates one sampled user per iteration.
        user_id = ue_list[rank]

        groups, server_list = get_groups(args)

        if args.user_semi:
            # User-level semi-supervision: pair the user's labeled loader
            # (index 0) with its unlabeled loader (index 1).
            loader = zip(train_loader_list[user_id][0],
                         train_loader_list[user_id][1])
            test_loader = test_loader_list[0]

            if args.eval_grad and epoch % args.epoch_interval == 0:
                # Gradient-evaluation mode: start from this rank's group's
                # saved post-averaging checkpoint.
                group_id = Get_group_num(args, groups, rank)
                checkpoint_weights = util_1.Load_Avg_model_checkpoint(
                    args.experiment_folder,
                    args.experiment_name,
                    epoch,
                    prefix=f'after_g{group_id+1}')
                model.load_state_dict(checkpoint_weights, strict=False)

        else:
            if args.H:
                # Hierarchical setup: "server" users train on the labeled
                # loader (index 0); other users train on their own loader.
                if user_id in set(server_list):
                    loader = train_loader_list[0]
                    test_loader = test_loader_list[0]
                else:
                    loader = train_loader_list[user_id]
                    test_loader = test_loader_list[0]

                if args.eval_grad and epoch % args.epoch_interval == 0:
                    group_id = Get_group_num(args, groups, rank)
                    checkpoint_weights = util_1.Load_Avg_model_checkpoint(
                        args.experiment_folder,
                        args.experiment_name,
                        epoch,
                        prefix=f'after_g{group_id+1}')
                    model.load_state_dict(checkpoint_weights, strict=False)

            else:
                loader = train_loader_list[user_id]
                test_loader = test_loader_list[0]

                if args.eval_grad and epoch % args.epoch_interval == 0:
                    # Flat topology resumes from the pre-averaging snapshot.
                    group_id = Get_group_num(args, groups, rank)
                    checkpoint_weights = util_1.Load_Avg_model_checkpoint(
                        args.experiment_folder,
                        args.experiment_name,
                        epoch,
                        prefix=f'before')
                    model.load_state_dict(checkpoint_weights, strict=False)

        while 1:
            break_flag = False
            train_loss = 0
            loss_steps = 0
            train_mask = 0

            for batch_idx, (data) in enumerate(loader):
                if args.user_semi:
                    # FixMatch-style batch: labeled (x) plus weakly/strongly
                    # augmented unlabeled pairs (u_w, u_s).
                    data_x, data_u = data
                    inputs_x, targets_x = data_x
                    (inputs_u_w, inputs_u_s), _ = data_u

                    batch_size = inputs_x.shape[0]
                    inputs = torch.cat(
                        (inputs_x, inputs_u_w, inputs_u_s)).to(device)

                    targets_x = targets_x.to(device)
                    logits = model(inputs)
                    logits_x = logits[:batch_size]

                    logits_u_w = logits[batch_size:]
                    del logits

                    # Supervised loss on the labeled part.
                    Lx = F.cross_entropy(logits_x, targets_x, reduction='mean')

                    # Pseudo-labels from the detached unlabeled logits.
                    pseudo_label = torch.softmax(logits_u_w.detach(), dim=-1)
                    max_probs, targets_u = torch.max(pseudo_label, dim=-1)

                    if not args.eval_grad:
                        # Fixed 0.95 confidence threshold during training.
                        mask = max_probs.ge(0.95).float()
                    else:
                        # Sweepable threshold args.tao during grad evaluation;
                        # also count samples passing the 0.95 threshold.
                        mask = max_probs.ge(args.tao).float()
                        train_mask += max_probs.ge(0.95).float().sum().item()

                    Lu = (F.cross_entropy(
                        logits_u_w, targets_u, reduction='none') *
                          mask).mean()

                    loss = Lx + Lu
                else:

                    if user_id in set(server_list):
                        # Server user: plain supervised training.
                        inputs_x, targets_x = data
                        inputs_x = inputs_x.to(device)
                        targets_x = targets_x.to(device)
                        output = model(inputs_x)
                        loss = criterion(output, targets_x)
                    else:
                        if args.labeled:
                            # Labeled client: supervised loss on the weak view.
                            (inputs_u_w, inputs_u_s), target_labels = data

                            inputs_x = inputs_u_w.to(device)
                            targets_x = target_labels.to(device)
                            output = model(inputs_x)
                            loss = criterion(output, targets_x)
                        else:
                            if args.ue_loss == 'CRL':
                                # Consistency loss: pseudo-label the weak
                                # view, cross-entropy on the strong view.
                                (inputs_u_w, inputs_u_s), _ = data

                                inputs = torch.cat(
                                    (inputs_u_w, inputs_u_s)).to(device)
                                logits = model(inputs)
                                logits_u_w, logits_u_s = logits.chunk(2)
                                del logits

                                pseudo_label = torch.softmax(
                                    logits_u_w.detach_(), dim=-1)
                                max_probs, targets_u = torch.max(pseudo_label,
                                                                 dim=-1)
                                if not args.eval_grad:
                                    mask = max_probs.ge(0.95).float()
                                else:

                                    mask = max_probs.ge(args.tao).float()
                                    train_mask += max_probs.ge(
                                        0.95).float().sum().item()

                                loss = (F.cross_entropy(
                                    logits_u_s, targets_u, reduction='none') *
                                        mask).mean()
                                train_loss += loss.item()
                                loss_steps += 1

                            if args.ue_loss == 'SF':
                                # Pseudo-label with the model in eval mode
                                # (no grad), then train on the same inputs in
                                # train mode.
                                inputs_x, targets_x = data

                                inputs = inputs_x.to(device)
                                model.eval()

                                with torch.no_grad():
                                    logits = model(inputs)

                                pseudo_label = torch.softmax(logits.detach_(),
                                                             dim=-1)
                                max_probs, targets_u = torch.max(pseudo_label,
                                                                 dim=-1)

                                if not args.eval_grad:
                                    mask = max_probs.ge(0.95).float()
                                else:
                                    mask = max_probs.ge(args.tao).float()
                                    train_mask += max_probs.ge(
                                        0.95).float().sum().item()

                                model.train()
                                output = model(inputs)
                                loss = (F.cross_entropy(
                                    output, targets_u, reduction='none') *
                                        mask).mean()

                loss.backward()

                if not args.eval_grad:
                    # In eval_grad mode gradients accumulate over the whole
                    # loader and are saved below, so no optimizer step here.
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()

                if not args.eval_grad:
                    # NOTE(review): by operator precedence this tests
                    # (iteration % args.cp) * accum_steps == 0, i.e. every
                    # args.cp iterations (accum_steps is 1 here) — confirm
                    # iteration % (args.cp * accum_steps) wasn't intended.
                    if iteration != 0 and iteration % args.cp * accum_steps == 0:

                        if epoch % args.epoch_interval == 0 or epoch == args.epoch - 1:
                            # Save per-rank and pre-averaging checkpoints.
                            util_1.Save_model_checkpoint(
                                args.experiment_name, model, rank, epoch)
                            group_id = Get_group_num(args, groups, rank)
                            save_each_group_avg_model(
                                args,
                                average_group_model_weights,
                                epoch,
                                rank,
                                rank_save=groups[group_id][0],
                                prefix=f'before_g{group_id+1}')
                            if rank == 0:
                                util_1.Save_Avg_model_checkpoint(
                                    args.experiment_name,
                                    average_model_weights,
                                    rank,
                                    epoch,
                                    prefix='before')

                        if args.user_semi:
                            if args.H:
                                # Split the active users in half; each half
                                # averages within its group, then the two
                                # edge ranks average across groups.
                                ue_list = ue_list[0:args.num_comm_ue]
                                group1_size = len(ue_list) // 2
                                group1 = np.array(ue_list)[np.arange(
                                    0, group1_size).tolist()].tolist()
                                group2 = np.array(ue_list)[np.arange(
                                    group1_size,
                                    len(ue_list)).tolist()].tolist()

                                if rank < len(ue_list) // 2:
                                    #### Group 1 average and communicate
                                    SyncAllreduce_1(model,
                                                    rank,
                                                    size=len(group1),
                                                    group=G[0])
                                else:
                                    #### Group 2 average and communicate
                                    SyncAllreduce_1(model,
                                                    rank,
                                                    size=len(group2),
                                                    group=G[1])
                                if rank == 0 or rank == args.num_rank - 1:
                                    SyncAllreduce_1(model,
                                                    rank,
                                                    size=len(groups[-1]),
                                                    group=G[-1])
                            else:
                                SyncAllreduce(model, rank, args.num_rank)

                        else:
                            if args.H:
                                average_group_model_weights = Grouping_Avg(
                                    args, model, rank, G, groups, epoch)
                            else:
                                SyncAllreduce(model, rank, args.num_rank)

                        # Snapshot the freshly averaged weights and save.
                        average_model_weights = copy.deepcopy(
                            model.state_dict())
                        if epoch % args.epoch_interval == 0 or epoch == args.epoch - 1:
                            if rank == 0:
                                util_1.Save_Avg_model_checkpoint(
                                    args.experiment_name,
                                    average_model_weights,
                                    rank,
                                    epoch,
                                    prefix='after')

                        iteration += 1
                        break_flag = True

                        break
                # Counts every processed batch as an iteration (skipped when
                # the communication branch above breaks out first).
                iteration += 1

            if args.eval_grad:
                print(f"save grad. of the whole DataLoader of UE {user_id}")
                Save_model_grad_checkpoint(args.experiment_folder,
                                           args.experiment_name, model, rank,
                                           epoch, args.tao)
                ### save train_loss train_mask of this epoch
                # NOTE(review): len(loader) raises TypeError when loader is a
                # zip (the args.user_semi path) — confirm eval_grad is never
                # combined with user_semi.
                values = {
                    'train_loss': train_loss,
                    'train_mask': train_mask,
                    'len_loader': len(loader)
                }
                print(epoch, rank, values)
                Save_train_state(args.experiment_folder, args.experiment_name,
                                 rank, epoch, values, args.tao)
                break
            if break_flag:
                break

    return user_id, WD_list, user_weight_diff_array
예제 #4
0
def run(size):
    """Simulate `size` federated-learning workers sequentially in one process.

    Sets up per-rank loaders, models, optimizers and CSV log files, then
    alternates rounds of local SGD (``cps[rank]`` steps per worker) with a
    model-averaging step until ``args.cr`` communication rounds complete.
    Running loss/accuracy is logged every 4 rounds and test accuracy every
    12 rounds; the best test accuracy per rank is written at the end.

    Relies on module-level ``args``, ``util``, ``FedProx``, the dataset
    helpers (``partition_dataset``, ``get_next_trainloader``), the
    all-reduce routines and ``evaluate``.
    """
    models = []
    anchor_models = []
    optimizers = []
    ratios = []
    iters = []
    cps = args.cp
    save_names = []
    loss_Meters = []
    top1_Meters = []
    best_test_accs = []

    # Choose the per-rank local-update counts.
    if args.constant_cp:
        # NOTE(review): if args.cp is an int this yields a scalar, but
        # cps[rank] is indexed below — confirm args.cp's type on this path.
        cps = args.cp * args.size
    elif args.persistent:
        # Fixed heterogeneous schedule; assumes size <= 10.
        cps = [5, 5, 5, 5, 5, 5, 5, 20, 20, 20]
    else:
        # Randomly slow down the first `slowRatio` fraction of nodes
        # (seeded so the assignment is reproducible across runs).
        local_cps = args.cp * np.ones(size, dtype=int)
        num_slow_nodes = int(size * args.slowRatio)
        np.random.seed(2020)
        random_cps = 5 + np.random.randn(num_slow_nodes) * 2
        for i in range(len(random_cps)):
            random_cps[i] = round(random_cps[i])
        local_cps[:num_slow_nodes] = random_cps
        cps = local_cps

    for rank in range(args.size):
        # initiate experiments folder
        save_path = 'new_results/'
        folder_name = save_path + args.name
        if rank == 0 and not os.path.isdir(folder_name) and args.save:
            os.mkdir(folder_name)
        # initiate log files
        tag = '{}/lr{:.3f}_bs{:d}_cr{:d}_avgcp{:.3f}_e{}_r{}_n{}.csv'
        saveFileName = tag.format(folder_name, args.lr, args.bs, args.cr,
                                  np.mean(args.cp), args.seed, rank, size)
        args.out_fname = saveFileName
        save_names.append(saveFileName)
        with open(args.out_fname, 'w+') as f:
            print('BEGIN-TRAINING\n'
                  'World-Size,{ws}\n'
                  'Batch-Size,{bs}\n'
                  'itr,'
                  'Loss,avg:Loss,Prec@1,avg:Prec@1,val'.format(ws=args.size,
                                                               bs=args.bs),
                  file=f)

        globalCp = args.globalCp
        total_size = args.total_size

        # seed for reproducibility
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True

        # load datasets
        # NOTE(review): train_loader, test_loader, x and y are overwritten on
        # every pass, so after this loop they hold the LAST rank's values —
        # confirm the training loop below really intends to share them.
        train_loader, test_loader, dataRatio, x, y = partition_dataset(
            rank, total_size, 1, args.alpha, args.beta, args)
        ratios.append(dataRatio)
        print(sum(len(i) for i in x))
        data_iter = iter(train_loader)
        iters.append(data_iter)

        # define neural nets model, criterion, and optimizer
        model = util.select_model(args.model, args)
        anchor_model = util.select_model(args.model, args)

        models.append(model)
        anchor_models.append(anchor_model)

        criterion = nn.CrossEntropyLoss()
        if args.FedProx:
            optimizer = FedProx.FedProxSGD(model.parameters(),
                                           lr=args.lr,
                                           momentum=0,
                                           nesterov=False,
                                           weight_decay=1e-4)
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=0,
                                  nesterov=False,
                                  weight_decay=1e-4)
        optimizers.append(optimizer)

        best_test_accuracy = 0
        best_test_accs.append(best_test_accuracy)

        losses = util.Meter(ptag='Loss')
        top1 = util.Meter(ptag='Prec@1')
        loss_Meters.append(losses)
        top1_Meters.append(top1)

        model.train()
        tic = time.time()
        print(dataRatio, len(train_loader), len(test_loader))

    round_communicated = 0
    while round_communicated < args.cr:
        for rank in range(args.size):
            model = models[rank]
            anchor_model = anchor_models[rank]
            data_iter = iters[rank]
            optimizer = optimizers[rank]
            losses = loss_Meters[rank]
            top1 = top1_Meters[rank]

            for cp in range(cps[rank]):
                try:
                    # BUG FIX: `data_iter.next()` is Python 2 only and raises
                    # AttributeError on Python 3; use the built-in next().
                    data, target = next(data_iter)
                except StopIteration:
                    # Loader exhausted: restart it.  NOTE(review): this
                    # restarts `train_loader`, which is the most recently
                    # assigned loader and may belong to a different rank —
                    # confirm against the setup loop above.
                    data_iter = iter(train_loader)
                    data, target = next(data_iter)

                # forward pass (data stays wherever the loader placed it;
                # there is deliberately no .cuda()/.to() here)
                output = model(data)
                loss = criterion(output, target)

                # backward pass
                loss.backward()
                if args.FedProx:
                    # FedProxSGD adds a proximal term against the anchor model.
                    optimizer.step(anchor_model, args.mu)
                else:
                    optimizer.step()
                optimizer.zero_grad()

                train_acc = util.comp_accuracy(output, target)
                losses.update(loss.item(), data.size(0))
                top1.update(train_acc[0].item(), data.size(0))

            # change the worker: swap in the next client's data for this rank
            train_loader, dataRatio = get_next_trainloader(
                round_communicated, x, y, rank, args)
            data_iter = iter(train_loader)
            iters[rank] = data_iter
            ratios[rank] = dataRatio

        # model averaging across the simulated workers
        if args.NSGD:
            NormalSGDALLreduce(models, anchor_models, cps, globalCp, ratios)
        elif args.FedProx:
            FedProx_SyncAllreduce(models, ratios, anchor_models)
        else:
            unbalanced_SyncAllreduce(models, ratios)
        round_communicated += 1

        # periodic CSV logging of the running loss/accuracy meters
        if round_communicated % 4 == 0:
            for rank in range(args.size):
                name = save_names[rank]
                losses = loss_Meters[rank]
                top1 = top1_Meters[rank]

                with open(name, '+a') as f:
                    print('{itr},'
                          '{loss.val:.4f},{loss.avg:.4f},'
                          '{top1.val:.3f},{top1.avg:.3f},-1'.format(
                              itr=round_communicated, loss=losses, top1=top1),
                          file=f)

        # periodic evaluation; meters are reset afterwards
        if round_communicated % 12 == 0:
            for rank in range(args.size):
                name = save_names[rank]
                model = models[rank]
                losses = loss_Meters[rank]
                top1 = top1_Meters[rank]

                test_acc, global_loss = evaluate(model, test_loader, criterion)

                if test_acc > best_test_accs[rank]:
                    best_test_accs[rank] = test_acc

                print('itr {}, '
                      'rank {}, loss value {:.4f}, '
                      'train accuracy {:.3f}, test accuracy {:.3f}, '
                      'elasped time {:.3f}'.format(round_communicated, rank,
                                                   losses.avg, top1.avg,
                                                   test_acc,
                                                   time.time() - tic))

                with open(name, '+a') as f:
                    print('{itr},{filler},{filler},'
                          '{filler},{loss:.4f},'
                          '{val:.4f}'.format(itr=-1,
                                             filler=-1,
                                             loss=global_loss,
                                             val=test_acc),
                          file=f)

                losses.reset()
                top1.reset()
                tic = time.time()

    # final best-accuracy summary row per rank (itr=-2 marks the record)
    for rank in range(args.size):
        name = save_names[rank]
        with open(name, '+a') as f:
            print('{itr} best test accuracy: {val:.4f}'.format(
                itr=-2, val=best_test_accs[rank]),
                  file=f)
예제 #5
0
def run(rank, size, G):
    """Per-process entry point: set up folders, data, model, optimizer and
    scheduler, then run federated training/evaluation for this `rank`.

    `size` is the world size and `G` the list of distributed process
    groups used for grouped all-reduce inside train().  Supports resuming
    from ``args.epoch_resume`` and a gradient-evaluation-only mode
    (``args.eval_grad``).  Relies on module-level ``args``, ``device``,
    ``util``, ``util_1`` and the ``Get_*`` factory helpers.
    """
    # initiate experiments folder
    save_path = f'./results_v0/{args.experiment_name}/'
    if rank == 0:
        if not os.path.exists(save_path):
            try:
                os.makedirs(save_path)
            except OSError:
                # Directory may already exist (race with another launch).
                pass

        folder_name = save_path + args.name + '/'
        if rank == 0 and os.path.isdir(folder_name) == False and args.save:
            os.makedirs(folder_name)
    else:
        # Non-zero ranks wait so rank 0 can create the folders first.
        time.sleep(5)

    dist.barrier()
    # seed for reproducibility
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    torch.backends.cudnn.deterministic = True

    train_loader_list, test_loader_list, path_device_idxs, max_len = Get_TrainLoader(
        args)  ### load datasets
    ue_list_epoches = util_1.Load_communicate_user_list(
        args, path_device_idxs)  ### load communicate user list

    # define neural nets model, criterion, and optimizer
    model = Get_Model(args)
    criterion = Get_Criterion(args)
    optimizer = Get_Optimizer(args, model, size=size, lr=args.lr)
    # args.fast is an int flag; convert it to a bool for the scheduler.
    if args.fast == 0:
        fast = False
    else:
        fast = True

    scheduler = Get_Scheduler(args,
                              optimizer,
                              warmup_epoch=args.warmup_epoch,
                              fast=fast)

    batch_meter = util.Meter(ptag='Time')
    comm_meter = util.Meter(ptag='Time')

    print('Now train the model')

    Fed_training = True

    # Per-(rank, epoch, iteration) weight-difference bookkeeping threaded
    # through train(); the +1 leaves room for an end-of-epoch entry.
    user_weight_diff_array = np.zeros(
        (args.size, args.epoch, args.iteration + 1))
    if Fed_training:

        if args.epoch_resume == 0:
            start_epoch = 0
            #### At the first epoch, we delete the past files
            if not args.eval_grad and rank == 0:
                util_1.init_files(args, save_path, rank, prefix='Test_Acc')
            Fed_acc_list = []

        else:
            # Resuming: continue after the saved epoch; only rank 0 needs
            # the accuracy history (it is the only rank that appends to it).
            start_epoch = args.epoch_resume + 1
            if rank == 0:
                Fed_acc_list = util_1.get_acc(args,
                                              save_path,
                                              rank,
                                              prefix='Test_Acc')

        if args.eval_grad:  #### at this time, we only want to cal. grad. dont want to change DataLoader
            args.iteration = 1

        if args.ue_loss == 'SF':
            # One epoch = one pass over the longest loader at batch size bs.
            args.iteration = max_len // args.bs

        for epoch in range(start_epoch, args.epoch):

            begin_time = time.time()
            if args.epoch_resume > 0 and epoch == args.epoch_resume + 1:
                # First epoch after a resume: load the saved averaged model.
                print('Loading saved averaged model ... epoch=', epoch,
                      args.epoch_resume)
                checkpoint_weights = util_1.Load_Avg_model_checkpoint(
                    args.experiment_folder,
                    args.experiment_name,
                    epoch,
                    prefix='after')
                model.load_state_dict(checkpoint_weights, strict=False)

            if not args.eval_grad or epoch % args.epoch_interval == 0:
                user_id, WD_list, user_weight_diff_array = train(
                    rank, model, criterion, optimizer, scheduler, batch_meter,
                    comm_meter, train_loader_list, test_loader_list, epoch,
                    device, ue_list_epoches, G, user_weight_diff_array)
                # get and save the local fine-tuning acc
                if rank == 0:

                    test_acc = evaluate(model, test_loader_list[0])
                    test_acc = round(test_acc, 2)
                    print('test acc', epoch, test_acc,
                          time.time() - begin_time)

                    if not args.eval_grad:
                        Fed_acc_list.append(test_acc)
                        util_1.Save_acc_file(args,
                                             save_path,
                                             rank,
                                             prefix='Test_Acc',
                                             acc_list=Fed_acc_list)
def train(rank, model, criterion, optimizer, scheduler, batch_meter,
          comm_meter, loader, epoch, device, ue_list_epoches, G):
    """One training epoch for this `rank` with periodic model averaging.

    "Server" ranks (rank 0, or the ranks in group6 under the hierarchical
    ``args.H`` topology) train supervised on labeled batches; all other
    ranks run FixMatch-style pseudo-labeling on (weak, strong)
    augmentation pairs.  Every ``args.cp`` batches, models are averaged
    via the all-reduce helpers using the process groups in `G`.
    ``batch_meter``/``comm_meter`` accumulate per-batch and
    per-communication wall-clock times.  Relies on module-level ``args``
    and ``size``.
    """
    model.train()

    top1 = util.Meter(ptag='Prec@1')

    iter_time = time.time()

    if args.H:
        # Hierarchical grouping: each group pairs one labeled "server"
        # rank with a slice of unlabeled ranks; group6 holds the servers.
        if args.dataset == 'emnist':
            group1 = [0] + np.arange(1, 11).tolist()
            group2 = [48] + np.arange(11, 21).tolist()
            group3 = [49] + np.arange(21, 31).tolist()
            group4 = [50] + np.arange(31, 41).tolist()
            group5 = [51] + np.arange(41, 48).tolist()
            group6 = [0, 48, 49, 50, 51]
        else:
            group6 = [0, args.size - 1]

    for batch_idx, (data) in enumerate(loader):
        # Decide whether this rank trains on this batch: when only a
        # subset of users communicates, ranks outside the sampled ue_list
        # skip local training (but still reach the all-reduce below).
        training = 0
        if args.num_comm_ue < args.size - 1 - args.H:
            ue_list = ue_list_epoches[epoch][batch_idx]
            ue_list_set = set(ue_list)
            if rank in ue_list_set:
                training = 1
            else:
                training = 0
        else:
            training = 1
        if training:
            if args.H:
                if rank in set(group6):
                    # Server rank: supervised loss on labeled data.
                    inputs_x, targets_x = data
                    inputs_x = inputs_x.to(device)
                    targets_x = targets_x.to(device)
                    output = model(inputs_x)
                    loss = criterion(output, targets_x)
                else:
                    # Client rank: pseudo-label the weak view, train on the
                    # strong view above a 0.95 confidence threshold.
                    (inputs_u_w, inputs_u_s), _ = data

                    inputs = torch.cat((inputs_u_w, inputs_u_s)).to(device)
                    logits = model(inputs)
                    logits_u_w, logits_u_s = logits.chunk(2)
                    del logits

                    pseudo_label = torch.softmax(logits_u_w.detach_(), dim=-1)
                    max_probs, targets_u = torch.max(pseudo_label, dim=-1)
                    mask = max_probs.ge(0.95).float()

                    loss = (F.cross_entropy(
                        logits_u_s, targets_u, reduction='none') *
                            mask).mean()
            else:
                if rank == 0:
                    # Flat topology: only rank 0 holds labeled data.
                    inputs_x, targets_x = data
                    inputs_x = inputs_x.to(device)
                    targets_x = targets_x.to(device)
                    output = model(inputs_x)
                    loss = criterion(output, targets_x)
                else:
                    (inputs_u_w, inputs_u_s), _ = data

                    inputs = torch.cat((inputs_u_w, inputs_u_s)).to(device)
                    logits = model(inputs)
                    logits_u_w, logits_u_s = logits.chunk(2)
                    del logits

                    pseudo_label = torch.softmax(logits_u_w.detach_(), dim=-1)
                    max_probs, targets_u = torch.max(pseudo_label, dim=-1)
                    mask = max_probs.ge(0.95).float()

                    loss = (F.cross_entropy(
                        logits_u_s, targets_u, reduction='none') *
                            mask).mean()

            # backward pass
            # accum_steps is fixed to 1, so the scaling and the modulo
            # below are currently no-ops kept for gradient accumulation.
            accum_steps = 1
            loss = loss / accum_steps
            loss.backward()
            if batch_idx % accum_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        torch.cuda.synchronize()
        comm_start = time.time()

        accum_steps = 1
        if args.H:
            if args.dataset == 'emnist':
                # NOTE(review): by precedence this tests
                # (batch_idx % args.cp) * accum_steps == 0, i.e. every
                # args.cp batches — confirm that was intended.
                if batch_idx != 0 and batch_idx % args.cp * accum_steps == 0:

                    # Intra-group averaging first (groups 1..5), ...
                    if rank in set(group1):
                        SyncAllreduce_1(model,
                                        rank,
                                        size=len(group1),
                                        group=G[0])
                    elif rank in set(group2):
                        SyncAllreduce_1(model,
                                        rank,
                                        size=len(group2),
                                        group=G[1])
                    elif rank in set(group3):
                        SyncAllreduce_1(model,
                                        rank,
                                        size=len(group3),
                                        group=G[2])
                    elif rank in set(group4):
                        SyncAllreduce_1(model,
                                        rank,
                                        size=len(group4),
                                        group=G[3])
                    elif rank in set(group5):
                        SyncAllreduce_1(model,
                                        rank,
                                        size=len(group5),
                                        group=G[4])
                    # ... then the server ranks average across groups
                    # (deliberately `if`, not `elif`: servers reduce twice).
                    if rank in set(group6):
                        SyncAllreduce_1(model,
                                        rank,
                                        size=len(group6),
                                        group=G[5])

            else:
                if batch_idx != 0 and batch_idx % args.cp * accum_steps == 0:
                    if rank < args.size // 2:
                        #### Group 1 average and communicate
                        SyncAllreduce_1(model,
                                        rank,
                                        size=args.size // 2,
                                        group=G[0])
                    else:
                        #### Group 2 average and communicate
                        SyncAllreduce_1(model,
                                        rank,
                                        size=args.size - args.size // 2,
                                        group=G[1])
                    if rank == 0 or rank == args.size - 1:
                        #### Server model 1 and server 2 average and communicate
                        SyncAllreduce_1(model, rank, size=2, group=G[2])

        else:
            if batch_idx != 0 and batch_idx % args.cp * accum_steps == 0:
                if args.num_comm_ue < args.size - 1:
                    # Only the sampled users participate in the average.
                    ue_list = ue_list_epoches[epoch][batch_idx]
                    SyncAllreduce_2(model, rank, size, ue_list)
                else:
                    SyncAllreduce(model, rank, size)

        # Skip the very first batch's timings (warm-up / lazy CUDA init).
        if not (epoch == 0 and batch_idx == 0):
            torch.cuda.synchronize()
            comm_meter.update(time.time() - comm_start)
            batch_meter.update(time.time() - iter_time)

        torch.cuda.synchronize()
        iter_time = time.time()
def run(rank, size, G):
    """Run semi-supervised distributed training on one worker process.

    Sets up the results folder, seeds RNGs, selects this rank's data loader,
    builds the model/criterion/optimizer/scheduler, then trains for
    ``args.epoch`` epochs. Rank 0 additionally evaluates the server model
    after each epoch and dumps the accuracy history to a JSON file; every
    rank checkpoints its own weights each epoch.

    Args:
        rank: this worker's rank in the process group (0 is the server).
        size: total number of workers.
        G: pre-built communication groups, forwarded to ``train``.

    Side effects: creates ``./results_v0/`` and ``./checkpoint/`` trees,
    writes an accuracy JSON (rank 0 only) and per-rank checkpoint files.
    """
    # initiate experiments folder (rank 0 creates it; others wait at barrier)
    save_path = './results_v0/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    folder_name = save_path + args.name + '/'
    if rank == 0 and not os.path.isdir(folder_name) and args.save:
        os.makedirs(folder_name)
    dist.barrier()

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # load datasets: labeled data goes to the "server" rank(s), unlabeled
    # data to the remaining user ranks (layout depends on args.H / dataset)
    if args.H:
        if args.dataset == 'emnist':
            labeled_set = [0, 48, 49, 50, 51]
            if rank in set(labeled_set):
                train_loader = labeled_trainloader
            else:
                train_loader = unlabeled_trainloader_list[rank - 1]
        else:
            if rank == 0 or rank == args.size - 1:
                train_loader = labeled_trainloader
            else:
                train_loader = unlabeled_trainloader_list[rank - 1]
    else:
        if rank == 0:
            train_loader = labeled_trainloader
        else:
            train_loader = unlabeled_trainloader_list[rank - 1]

    # define neural nets model, criterion, and optimizer
    # NOTE(review): `optim` is the project's optimizer module (its SGD
    # accepts alpha/gmf/size), not torch.optim — confirm against imports.
    model = util.select_model(args.model, args).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          alpha=args.alpha,
                          gmf=args.gmf,
                          size=size,
                          momentum=0.9,
                          nesterov=True,
                          weight_decay=1e-4)

    # cosine LR schedule with linear warmup; emnist skips warmup and ties
    # the step budget to args.epoch instead of the fixed 1024-epoch budget
    args.iteration = args.k_img // args.bs
    total_steps = 1024 * args.iteration
    warmup_epoch = 5
    if args.dataset == 'emnist':
        warmup_epoch = 0
        total_steps = args.epoch * args.iteration

    scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                warmup_epoch * args.iteration,
                                                total_steps,
                                                lr_weight=1)

    batch_meter = util.Meter(ptag='Time')
    comm_meter = util.Meter(ptag='Time')

    acc_list = []
    print('Now train the model')

    for epoch in range(args.epoch):
        if rank == 0:
            begin_time = time.time()

        train(rank, model, criterion, optimizer, scheduler, batch_meter,
              comm_meter, train_loader, epoch, device, ue_list_epoches, G)

        ### test the server model and persist the accuracy history
        if rank == 0:
            test_acc = evaluate(model, test_loader)
            acc_list.append(round(test_acc, 2))
            print('test acc', epoch, test_acc, time.time() - begin_time)
            if args.H:
                filename = f"./results_v0/{args.experiment_name}_{args.dataset}_iid{args.iid}_UE{args.size - 1}_{args.basicLabelRatio}_{args.model}_bs{args.bs}_H1_cp{args.cp}.txt"
            else:
                filename = f"./results_v0/{args.experiment_name}_{args.dataset}_iid{args.iid}_UE{args.size - 1 - args.H}_{args.basicLabelRatio}_comUE{args.num_comm_ue}_{args.model}_bs{args.bs}_H0_cp{args.cp}.txt"
            # filename is always assigned above, so write unconditionally
            # (the previous `if filename:` guard could never be false)
            with open(filename, 'w') as f:
                json.dump(acc_list, f)

        # every rank saves its own weights; exist_ok avoids the race when
        # several ranks create the checkpoint directory concurrently
        path_checkpoint = f"./checkpoint/{args.experiment_name}/"
        os.makedirs(path_checkpoint, exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict()
        }, path_checkpoint + f'{rank}_weights.pth')
예제 #8
0
def train(model, criterion, optimizer, batch_meter, comm_meter, loader, epoch,
          req):
    """Train ``model`` for one epoch using Overlap-Local-SGD communication.

    For each mini-batch: forward/backward pass, LR update, local optimizer
    step, then an Overlap-Local-SGD communication step every ``args.cp``
    iterations. Per-iteration compute and communication times are folded
    into ``batch_meter`` / ``comm_meter`` (the very first iteration is
    excluded as warm-up). Loss/accuracy are logged every
    ``args.print_freq`` batches and once at epoch end to ``args.out_fname``.

    Args:
        model: network being trained (moved to CUDA by the caller).
        criterion: loss function.
        optimizer: project optimizer exposing ``OverlapLocalSGD_step``.
        batch_meter, comm_meter: util.Meter timers updated in place.
        loader: training data loader.
        epoch: current epoch index (for logging).
        req: async communication handle threaded through epochs.

    Returns:
        req: the communication handle, passed back for the next epoch.
    """
    model.train()

    losses = util.Meter(ptag='Loss')
    top1 = util.Meter(ptag='Prec@1')

    iter_time = time.time()
    for batch_idx, (data, target) in enumerate(loader):
        # data loading
        data = data.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward pass
        output = model(data)
        loss = criterion(output, target)

        # backward pass
        loss.backward()
        update_learning_rate(optimizer,
                             epoch,
                             itr=batch_idx,
                             itr_per_epoch=len(loader))
        optimizer.step()
        optimizer.zero_grad()

        # time the communication phase separately from compute
        torch.cuda.synchronize()
        comm_start = time.time()

        # Alternative communication schemes kept for reference:
        ## CoCoD-SGD
        # optimizer.async_CoCoD_SGD_step(batch_idx, args.cp, req)
        ## Local SGD
        # if batch_idx != 0 and batch_idx % args.cp == 0:
        # SyncAllreduce(model, rank, size)
        ## EASGD
        # optimizer.elastic_average(batch_idx, args.cp)

        # Overlap-Local-SGD: communicate every args.cp local steps
        optimizer.OverlapLocalSGD_step(batch_idx, args.cp, req)

        # skip the first iteration: CUDA warm-up would skew the timers
        if not (epoch == 0 and batch_idx == 0):
            torch.cuda.synchronize()
            comm_meter.update(time.time() - comm_start)
            batch_meter.update(time.time() - iter_time)

        # write log files
        train_acc = util.comp_accuracy(output, target)
        losses.update(loss.item(), data.size(0))
        top1.update(train_acc[0].item(), data.size(0))

        if batch_idx % args.print_freq == 0 and args.save:
            # NOTE(review): `rank` is not a parameter here — this relies on
            # a module-level global; verify it is defined at call time.
            print('epoch {} itr {}, '
                  'rank {}, loss value {:.4f}, train accuracy {:.3f}'.format(
                      epoch, batch_idx, rank, losses.avg, top1.avg))

            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{ct},'
                      '{loss.val:.4f},{loss.avg:.4f},'
                      '{top1.val:.3f},{top1.avg:.3f},-1'.format(ep=epoch,
                                                                itr=batch_idx,
                                                                bt=batch_meter,
                                                                ct=comm_meter,
                                                                loss=losses,
                                                                top1=top1),
                      file=f)

        torch.cuda.synchronize()
        iter_time = time.time()

    # end-of-epoch summary row
    with open(args.out_fname, '+a') as f:
        print('{ep},{itr},{bt},{ct},'
              '{loss.val:.4f},{loss.avg:.4f},'
              '{top1.val:.3f},{top1.avg:.3f},-1'.format(ep=epoch,
                                                        itr=batch_idx,
                                                        bt=batch_meter,
                                                        ct=comm_meter,
                                                        loss=losses,
                                                        top1=top1),
              file=f)
    return req
예제 #9
0
def run(rank, size):
    """Driver for one worker: set up logging, data, model, then train/evaluate.

    Creates the experiment folder (rank 0), writes a CSV log header, seeds
    RNGs, partitions the dataset for this rank, builds the model and the
    project's SGD variant, and loops over epochs calling train()/evaluate(),
    appending a per-epoch summary row to the CSV.

    Args:
        rank: this worker's rank in the process group.
        size: total number of workers.
    """
    # initiate experiments folder
    save_path = '/users/jianyuw1/SGD_non_iid/results/'
    folder_name = save_path + args.name
    if rank == 0 and os.path.isdir(folder_name) == False and args.save:
        os.mkdir(folder_name)
    # all ranks wait until rank 0 has created the folder
    dist.barrier()
    # initiate log files: one CSV per (hyperparams, rank, size) combination
    tag = '{}/lr{:.3f}_bs{:d}_cp{:d}_a{:.2f}_b{:.2f}_e{}_r{}_n{}.csv'
    saveFileName = tag.format(folder_name, args.lr, args.bs, args.cp,
                              args.alpha, args.gmf, args.seed, rank, size)
    args.out_fname = saveFileName
    # write the CSV header describing the per-iteration log columns
    with open(args.out_fname, 'w+') as f:
        print('BEGIN-TRAINING\n'
              'World-Size,{ws}\n'
              'Batch-Size,{bs}\n'
              'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
              'CT(s),avg:CT(s),std:CT(s),'
              'Loss,avg:Loss,Prec@1,avg:Prec@1,val'.format(ws=args.size,
                                                           bs=args.bs),
              file=f)

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # load datasets: each rank gets its own shard of the training data
    train_loader, test_loader = util.partition_dataset(rank, size, args)

    # define neural nets model, criterion, and optimizer
    # NOTE(review): `optim` appears to be the project's optimizer module
    # (its SGD accepts alpha/gmf/size), not torch.optim — confirm.
    model = util.select_model(args.model, args).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          alpha=args.alpha,
                          gmf=args.gmf,
                          size=size,
                          momentum=0.9,
                          nesterov=True,
                          weight_decay=1e-4)

    # timers shared across epochs so sums accumulate over the whole run
    batch_meter = util.Meter(ptag='Time')
    comm_meter = util.Meter(ptag='Time')

    best_test_accuracy = 0
    # async communication handle threaded through train() calls
    req = None
    for epoch in range(args.epoch):
        req = train(model, criterion, optimizer, batch_meter, comm_meter,
                    train_loader, epoch, req)
        test_acc = evaluate(model, test_loader)
        if test_acc > best_test_accuracy:
            best_test_accuracy = test_acc

        # per-epoch summary row; -1 fills columns that only apply to
        # per-iteration rows (see header above)
        with open(args.out_fname, '+a') as f:
            print('{ep},{itr},{bt:.4f},{filler},{filler},'
                  '{ct:.4f},{filler},{filler},'
                  '{filler},{filler},'
                  '{filler},{filler},'
                  '{val:.4f}'.format(ep=epoch,
                                     itr=-1,
                                     bt=batch_meter.sum,
                                     ct=comm_meter.sum,
                                     filler=-1,
                                     val=test_acc),
                  file=f)