Code Example #1
def run(size, rank, epoch, batchsize):
    #print('run')
    if MODEL == 'CNN' and DATA_SET == 'KWS':
        model = CNNKws()
    elif MODEL == 'CNN' and DATA_SET == 'Cifar10':
        model = CNNCifar()
    elif MODEL == 'ResNet18' and DATA_SET == 'Cifar10':
        model = ResNet18()

    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(), lr=LR, weight_decay=1e-3)
    loss_func = torch.nn.CrossEntropyLoss()

    train_loader = get_local_data(size, rank, batchsize)
    if rank == 0:
        test_loader = get_testset(rank)
        #fo = open("file_multi"+str(rank)+".txt", 'w')

    group_list = [i for i in range(size)]
    group = dist.new_group(group_list)

    model, round = load_model(model, group, rank)
    while round < MAX_ROUND:
        sys.stdout.flush()
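        # Rank 0 evaluates the current global model on the test set before local training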
        if rank == 0:
            accuracy = 0
            positive_test_number = 0
            total_test_number = 0
            for step, (test_x, test_y) in enumerate(test_loader):
                test_x = test_x.cuda()
                test_y = test_y.cuda()
                test_output = model(test_x)
                pred_y = torch.max(test_output, 1)[1].data.cpu().numpy()
                positive_test_number += (
                    pred_y == test_y.data.cpu().numpy()).astype(int).sum()
                # print(positive_test_number)
                total_test_number += float(test_y.size(0))
            accuracy = positive_test_number / total_test_number
            print('Round: ', round, ' Rank: ', rank,
                  '| test accuracy: %.4f' % accuracy)
            #fo.write(str(round) + "    " + str(rank) + "    " + str(accuracy) + "\n")

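        # Local training for `epoch` epochs on this worker's data partition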
        for epoch_cnt in range(epoch):
            for step, (b_x, b_y) in enumerate(train_loader):
                b_x = b_x.cuda()
                b_y = b_y.cuda()
                optimizer.zero_grad()
                output = model(b_x)
                loss = loss_func(output, b_y)
                loss.backward()
                optimizer.step()

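        # Aggregate the locally trained models across all workers via all-reduce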
        model = all_reduce(model, size, group)
        #if (round+1) % ROUND_NUMBER_FOR_REDUCE == 0:
        #model = all_reduce(model, size, group)

        if (round + 1) % ROUND_NUMBER_FOR_SAVE == 0:
            save_model(model, round + 1, rank)
        round += 1
Code Example #2
    def _start_reduction_threads(self):
        num_buckets = len(self.bucket_sizes)
        self._reduction_queues = [queue.Queue() for _ in range(num_buckets)]
        self._reduction_threads = []
        self._reduction_streams = [[] for _ in range(num_buckets)]
        self._nccl_streams = []
        self._default_streams = []
        for dev_id in self.device_ids:
            with torch.cuda.device(dev_id):
                # TODO: don't assume we're on a default stream
                self._default_streams.append(torch.cuda.current_stream())
                self._nccl_streams.append(torch.cuda.Stream())
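        # One reduction queue, one set of per-device streams, and one worker thread per bucket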
        for reduction_queue, reduction_streams in zip(self._reduction_queues, self._reduction_streams):
            for dev_id in self.device_ids:
                with torch.cuda.device(dev_id):
                    reduction_streams.append(torch.cuda.Stream())
            # We only use the first device for distributed reductions
            dist._register_stream(reduction_streams[0])

            group_id = dist.new_group()

            self._reduction_threads.append(threading.Thread(
                target=self._reduction_thread_fn,
                args=(reduction_queue, group_id, self.device_ids, reduction_streams, self._nccl_streams)))
            self._reduction_threads[-1].daemon = True
            self._reduction_threads[-1].start()
Code Example #3
    def _init_group_test(self):
        group = [1, 2]
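        # new_group is a collective call: every process must enter it, even ranks outside `group`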
        group_id = dist.new_group(group)
        rank = dist.get_rank()
        if rank not in group:
            return ([], None, rank)

        return (group, group_id, rank)
Code Example #4
def init_processes(master_address, world_size, rank, epoch_per_round,
                   batch_size, run):
    # change 'tcp' to 'nccl' if running on GPU worker
    dist.init_process_group(backend='tcp',
                            init_method=master_address,
                            world_size=world_size,
                            rank=rank)
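    # Build a group spanning every rank and pass it to the training function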
    group = dist.new_group([i for i in range(world_size)])
    run(world_size, rank, group, epoch_per_round, batch_size)
Code Example #5
File: master.py Project: holoword/fl
def run():
    modell = model.CNN()
    # modell = model.AlexNet()

    size = dist.get_world_size()
    rank = dist.get_rank()

    group_list = []
    for i in range(size):
        group_list.append(i)
    group = dist.new_group(group_list)

    while (1):

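        # Broadcast the current global model from the master (rank 0) to every worker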
        for param in modell.parameters():
            # for dst in range(1, size):
            # dist.send(param.data, dst=dst)
            dist.broadcast(param.data, src=0, group=group)

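        # Sum the workers' updated parameters and average them (rank 0 contributes zeros, hence size - 1)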
        for param in modell.parameters():
            tensor_temp = torch.zeros_like(param.data)
            dist.reduce(tensor_temp, dst=0, op=dist.reduce_op.SUM, group=group)
            param.data = tensor_temp / (size - 1)
Code Example #6
    def _register_nccl_grad_hook(self):
        """
        This function registers the callback all-reduction function for the
        NCCL backend. All gradients will be all reduced in one single step.
        The NCCL reduction will directly be enqueued into the
        default CUDA stream. Therefore, no synchronization is needed.
        """
        # Creating a new group
        self.nccl_reduction_group_id = dist.new_group()

        def reduction_fn_nccl():
            # This function only needs to be called once
            if not self.need_reduction:
                return

            self.need_reduction = False
            all_grads = [[] for _ in range(len(self._module_copies))]
            all_grads_buckets_iters = []

            # Bucketing all the gradients
            for dev_idx, module in enumerate(self._module_copies):
                for param in module.parameters():
                    if not param.requires_grad or param.grad is None:
                        continue
                    if param.grad.requires_grad:
                        raise RuntimeError("DistributedDataParallel only works "
                                           "with gradients that don't require "
                                           "grad")
                    # Adding the gradients for reduction
                    all_grads[dev_idx].append(param.grad.data)

                # Now bucketing the parameters
                dev_grads_buckets = _take_tensors(all_grads[dev_idx],
                                                  self.nccl_reduce_bucket_size)

                all_grads_buckets_iters.append(dev_grads_buckets)

            # Now reduce each bucket one after another
            for grads_batch in zip(*all_grads_buckets_iters):
                grads_batch_coalesced = []
                # Coalesce each bucket
                for dev_idx, dev_grads_batch in enumerate(grads_batch):
                    dev_id = self.device_ids[dev_idx]
                    with torch.cuda.device(dev_id):
                        dev_grads_batch_coalesced = _flatten_dense_tensors(dev_grads_batch)
                        grads_batch_coalesced.append(dev_grads_batch_coalesced)

                # We will only use device 0's results, but this single op should be
                # faster than doing the following two operation sequentially:
                # (1) intra-node reduce to lead GPU, followed by
                # (2) inter-node allreduce for all the first lead GPUs in all nodes
                dist.all_reduce_multigpu(grads_batch_coalesced,
                                         group=self.nccl_reduction_group_id)

                # Now only work on the first device of self.device_ids, uncoalesce
                # the gradients for each bucket
                grads_batch_coalesced[0] /= dist.get_world_size()
                grads_batch_reduced = _unflatten_dense_tensors(grads_batch_coalesced[0], grads_batch[0])
                for grad, reduced in zip(grads_batch[0], grads_batch_reduced):
                    grad.copy_(reduced)

            # clear the gradients and save memory for replicas
            for module in self._module_copies[1:]:
                for param in module.parameters():
                    if param.requires_grad:
                        param.grad = None
                        param.data.set_()

        # Now register the reduction hook on the parameters
        for p in self.module.parameters():
            if not p.requires_grad:
                continue

            @torch.utils.hooks.unserializable_hook
            def allreduce_hook(*unused):
                Variable._execution_engine.queue_callback(reduction_fn_nccl)

            p.register_hook(allreduce_hook)
Code Example #7
File: multi_slave.py Project: holoword/fl
def run(size, rank):


    modell = model.CNN()
    #modell = model.AlexNet()

    optimizer = torch.optim.Adam(modell.parameters(), lr=LR)
    loss_func = torch.nn.CrossEntropyLoss()



    if(IID == True):
        train_loader = Mnist().get_train_data()
        test_data = Mnist().get_test_data()
    else:
        if(rank > 0):
            if(rank == 1):
                train_loader = Mnist_noniid().get_train_data1()
                test_data = Mnist_noniid().get_test_data1()
            if(rank == 2):
                train_loader = Mnist_noniid().get_train_data2()
                test_data = Mnist_noniid().get_test_data2()
            if(rank == 3):
                train_loader = Mnist_noniid().get_train_data3()
                test_data = Mnist_noniid().get_test_data3()
            if(rank == 4):
                train_loader = Mnist_noniid().get_train_data4()
                test_data = Mnist_noniid().get_test_data4()
            if(rank == 5):
                train_loader = Mnist_noniid().get_train_data5()
                test_data = Mnist_noniid().get_test_data5()

    #size = dist.get_world_size()
    #rank = dist.get_rank()

    #train_loader = Mnist().get_train_data()
    #test_data = Mnist().get_test_data()

    for step, (b_x, b_y) in enumerate(test_data):
        test_x = b_x
        test_y = b_y

    group_list = []
    for i in range(size):
        group_list.append(i)
    group = dist.new_group(group_list)

    for epoch in range(MAX_EPOCH):

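        # Fetch the latest global model from the master before local training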
        modell = get_new_model(modell, group)
        #current_model = copy.deepcopy(modell)

        test_output, last_layer = modell(test_x)
        pred_y = torch.max(test_output, 1)[1].data.numpy()
        accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))


        for step, (b_x, b_y) in enumerate(train_loader):

            #modell = get_new_model(modell)
            #current_model = copy.deepcopy(modell)

            output = modell(b_x)[0]
            loss = loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()   
            optimizer.step()

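        # Send the locally trained parameters to the master (rank 0) for aggregation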
        for param in modell.parameters():
            dist.reduce(param.data, dst=0, op=dist.reduce_op.SUM, group=group)

        f = open('./test.txt', 'a')
        print('Epoch: ', epoch, ' Rank: ', rank, '| train loss: %.4f' % loss.data.numpy(),
              '| test accuracy: %.2f' % accuracy, file=f)
        print('Epoch: ', epoch, ' Rank: ', rank, '| train loss: %.4f' % loss.data.numpy(),
              '| test accuracy: %.2f' % accuracy)
        f.close()
Code Example #8
File: learner.py Project: AragornThorongil/ICPP
def run(rank, workers, model, save_path, train_data, test_data):
    # Get the initial model from the server
    _group = [w for w in workers] + [0]  # workers plus the server (rank 0)
    group = dist.new_group(_group)

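    # Receive the initial parameters from the server, one scatter per tensor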
    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Reduce the learning rate LR at specific epochs
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            batch_comp_time = time.time()
            # Synchronization
            # send the epoch train loss first
            dist.gather(loss.data, dst=0, group=group)
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=delta_ws[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()

            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time

            logs = torch.tensor([
                0.0, batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch,
                                                  loss.data.item()))

        e_time = time.time()
        #epoch_train_loss /= len(train_data)
        #epoch_train_loss = format(epoch_train_loss, '.4f')
        # test the model
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        batch_comm_interval /= batch_idx + 1
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        #dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
Code Example #9
File: param_server.py Project: AragornThorongil/DGS
def run(rank, model, train_pics, train_bsz):
    workers = [v + 1 for v in range(args.workers_num)]
    _group = [w for w in workers] + [rank]  # workers plus this server's rank
    group = dist.new_group(_group)

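    # Send the initial model to the workers, one scatter per parameter tensor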
    for p in model.parameters():
        scatter_p_list = [p.data for _ in range(len(workers) + 1)]
        dist.scatter(tensor=p.data, scatter_list=scatter_p_list, group=group)

    print('Model Sent Finished!')

    print('Begin!')

    trainloss_file = './trainloss' + args.model + '.txt'
    if (os.path.isfile(trainloss_file)):
        os.remove(trainloss_file)
    f_trainloss = open(trainloss_file, 'a')

    tmp = [
        (0, 0)
        for _ in range(int(math.ceil(train_pics / (len(workers) * train_bsz))))
    ]

    g_list = [[torch.zeros_like(param.data) for param in model.parameters()]
              for _ in range(len(workers) + 1)]

    s_time = time.time()
    epoch_time = s_time
    global_clock = 0
    epoch_train_loss = 0.0
    sparsification_ratio = 0.0
    for epoch in range(args.epochs):
        for batch_idx, (_, _) in enumerate(tmp):
            # receive the list of train loss from workers
            info_list = [torch.tensor([0.0]) for _ in range(len(workers) + 1)]
            dist.gather(tensor=torch.tensor([0.0]),
                        gather_list=info_list,
                        group=group)
            batch_loss = sum(info_list).item() / len(workers)
            epoch_train_loss += batch_loss

            sparsification_ratio_list = [
                torch.tensor([0.0]) for _ in range(len(workers) + 1)
            ]
            dist.gather(tensor=torch.tensor([0.0]),
                        gather_list=sparsification_ratio_list,
                        group=group)
            batch_ratio = sum(sparsification_ratio_list).item() / len(workers)
            sparsification_ratio += batch_ratio

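            # Gather the workers' updates, average them, apply them to the global model, and scatter the new parameters back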
            param_idx = 0
            g_quare_sum = torch.tensor([0.])
            for layer_idx, param in enumerate(model.parameters()):
                tensor = torch.zeros_like(param.data)
                # FIXME FIXED: every tensor in gather_list must be a newly created object, otherwise the gather misbehaves
                gather_list = [
                    torch.zeros_like(param.data)
                    for _ in range(len(workers) + 1)
                ]
                dist.gather(tensor=tensor,
                            gather_list=gather_list,
                            group=group)
                tensor = sum(gather_list) / len(workers)
                param.data -= tensor
                scatter_list = [param.data for _ in range(len(workers) + 1)]
                dist.scatter(tensor=tensor,
                             scatter_list=scatter_list,
                             group=group)
                param_idx += 1

            global_clock += 1

            f_trainloss.write(
                str(args.this_rank) + "\t" +
                str(epoch_train_loss / float(batch_idx + 1)) + "\t" +
                str(batch_loss) + "\t" + str(0) + "\t" + str(0) + "\t" +
                str(epoch) + "\t" + str(0) + "\t" + str(0) + "\t" + str(0) +
                "\t" + str(batch_ratio) + "\t" +
                str(sparsification_ratio / float(batch_idx + 1)) + "\t" +
                str(global_clock) + '\n')
            # f_trainloss.flush()

            #print('Done {}/{}!'.format(batch_idx, len(tmp)))
        print('Done Epoch {}/{}!'.format(epoch + 1, args.epochs))

        e_epoch_time = time.time()

        # test_acc, batch_interval, batch_comp_interval, batch_comm_interval
        logs = torch.tensor([0.0, 0.0, 0.0, 0.0])
        logs_list = [torch.zeros_like(logs) for _ in range(len(workers) + 1)]
        #dist.gather(tensor = logs, gather_list = logs_list, group = group)
        test_acc, batch_interval, batch_comp_interval, batch_comm_interval = zip(
            *logs_list)
        test_acc = sum(test_acc) / len(workers)
        batch_interval = sum(batch_interval) / len(workers)
        batch_comp_interval = sum(batch_comp_interval) / len(workers)
        batch_comm_interval = sum(batch_comm_interval) / len(workers)

        # f_trainloss.write(str(args.this_rank) +
        #                   "\t" + str(epoch_train_loss / float(batch_idx+1)) +
        #                   "\t" + str(0) +
        #                   "\t" + str(e_epoch_time - epoch_time) +
        #                   "\t" + str(e_epoch_time - s_time) +
        #                   "\t" + str(epoch) +
        #                   "\t" + str(test_acc.item()) +
        #                   "\t" + str(batch_interval.item()) +
        #                   "\t" + str(batch_comp_interval.item()) +
        #                   "\t" + str(batch_comm_interval.item()) +
        #                   "\t" + str(sparsification_ratio / float(batch_idx+1)) +
        #                   "\t" + str(global_clock) +
        #                   '\n')
        f_trainloss.flush()
        epoch_time = e_epoch_time
        epoch_train_loss = 0.0
        sparsification_ratio = 0.0

        if (epoch + 1) % 2 == 0:
            if not os.path.exists('model_state'):
                os.makedirs('model_state')
            torch.save(
                model.state_dict(), 'model_state' + '/' + args.model + '_' +
                str(epoch + 1) + '.pkl')

    f_trainloss.close()
Code Example #10
File: learner.py Project: AragornThorongil/DGS
def run(rank, workers, model, save_path, train_data, test_data):
    # Get the initial model parameters sent by the parameter server
    _group = [w for w in workers] + [0]  # workers plus the server (rank 0)
    group = dist.new_group(_group)

    param_num = 0
    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        param_num += torch.numel(tmp_p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    compression_num = int(param_num * args.ratio)
    compression_num = compression_num if compression_num > 0 else 1
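    # Report this worker's compression ratio to the server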
    dist.gather(torch.tensor([compression_num / param_num]),
                dst=0,
                group=group)

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        learning_rate = 0.1
    else:
        learning_rate = args.lr
    optimizer = MySGD(model.parameters(), lr=learning_rate)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    elif args.model in ['Abalone', 'Bodyfat', 'Housing']:
        criterion = torch.nn.MSELoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    elif args.model in [
            'LROnMnist', 'LROnCifar10', 'LROnCifar100', 'Abalone', 'Bodyfat',
            'Housing'
    ]:
        decay_period = 1000000  # learning rate is constant for LR (convex) models
    else:
        decay_period = 100

    print('Begin!')

    global_clock = 0
    g_remain = [torch.zeros_like(param.data) for param in model.parameters()]
    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate LR at the specified epochs
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            g_remain, g_large_change = get_upload(g_remain, delta_ws,
                                                  args.ratio,
                                                  args.isCompensate)

            batch_comp_time = time.time()
            # 同步操作
            # send epoch train loss firstly
            dist.gather(loss.data, dst=0, group=group)
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=g_large_change[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()

            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time

            logs = torch.tensor([
                0.0, batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch,
                                                  loss.data.item()))

        e_time = time.time()
        #epoch_train_loss /= len(train_data)
        #epoch_train_loss = format(epoch_train_loss, '.4f')
        # Test the model after training
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        batch_comm_interval /= batch_idx + 1
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        #dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
Code Example #11
def run(rank, workers, model, save_path, train_data, test_data, global_lr):
    # Get the initial model from the server
    print(workers)

    _group = [w for w in workers] + [0]  # workers plus the server (rank 0)
    group = dist.new_group(_group)

    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    temp_lr = global_lr.get()

    optimizer = MySGD(model.parameters(), lr=temp_lr)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print('Begin!')

    # the parameters that will be transferred to the thread
    model_cache = [p.data + 0.0 for p in model.parameters()]
    global_update = [torch.zeros_like(p) for p in model.parameters()]
    local_update = [torch.zeros_like(p) for p in model.parameters()]
    it_count = Value(c_float,
                     0.)  # count update times in an iteration by local worker
    data_lock = Lock()
    update_lock = Queue()
    update_lock.put(1)

    loss_t = torch.tensor(0.0)
    receive_end = Value(c_bool, False)
    batch_communication_interval = Value(c_float, 0.0)
    stale_in_iteration = Value(c_float, 0.)

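    # Background thread that exchanges local updates and the global model with the server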
    sender_td = Thread(target=sender,
                       args=(
                           model_cache,
                           global_update,
                           local_update,
                           it_count,
                           loss_t,
                           update_lock,
                           data_lock,
                           group,
                           receive_end,
                           batch_communication_interval,
                           stale_in_iteration,
                       ),
                       daemon=True)
    sender_td.start()

    time_logs = open("./record" + str(rank), 'w')
    osp_logs = open("./log" + str(rank), 'w')
    Stale_Threshold = args.stale_threshold
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specific epoch
        # the learning rate must be decreased on the server because local workers and the server update at different speeds
        if not global_lr.empty():
            g_lr = global_lr.get()
            if args.model == 'AlexNet':
                for param_group in optimizer.param_groups:
                    param_group['lr'] = g_lr
                    print('LR Decreased! Now: {}'.format(param_group['lr']))

        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            optimizer.step()

            # Aggregate local update
            data_lock.acquire()
            # aggregate loss
            loss_t.data += loss.data
            it_count.value += 1
            for g_idx, update in enumerate(local_update):
                update.data += delta_ws[g_idx].data
            data_lock.release()

            batch_computation_time = time.time()

            # Open the lock once the local update has at least one gradient
            if it_count.value == 1:
                update_lock.put(1)
            while it_count.value >= Stale_Threshold:
                pass

            if receive_end.value:
                receive_end.value = False
                for idx, param in enumerate(model.parameters()):
                    param.data = model_cache[idx]  # without local update
                    # param.data = model_cache[idx] - global_update[idx] # with local update

            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_computation_time - batch_start_time
            osp_logs.write(
                str(batch_end_time - batch_start_time) + "\t" +
                str(batch_computation_time - batch_start_time) + "\n")
            osp_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch,
                                                  loss.data.item()))

        e_time = time.time()
        # Test the model after training
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        logs = torch.tensor([
            acc, batch_interval, batch_comp_interval,
            batch_communication_interval.value, stale_in_iteration.value
        ])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        # dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
    sender_td.join()
Code Example #12
def run(size, rank):
    modell = model.CNN()
    # modell = model.AlexNet()

    optimizer = torch.optim.Adam(modell.parameters(), lr=LR)
    loss_func = torch.nn.CrossEntropyLoss()

    # size = dist.get_world_size()
    # rank = dist.get_rank()

    if (IID == True):
        train_loader = Mnist().get_train_data()
        test_data = Mnist().get_test_data()
        test_x = torch.unsqueeze(test_data.test_data, dim=1).type(
            torch.FloatTensor
        ) / 255.  # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
        test_y = test_data.test_labels
    else:
        if (rank > 0):
            if (rank == 1):
                train_loader = Mnist_noniid().get_train_data1()
                test_data = Mnist_noniid().get_test_data1()
                test_x = torch.unsqueeze(test_data.test_data, dim=1).type(
                    torch.FloatTensor
                ) / 255.  # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
                test_y = test_data.test_labels
            if (rank == 2):
                train_loader = Mnist_noniid().get_train_data2()
                test_data = Mnist_noniid().get_test_data2()
                test_x = torch.unsqueeze(test_data.test_data, dim=1).type(
                    torch.FloatTensor
                ) / 255.  # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
                test_y = test_data.test_labels
            if (rank == 3):
                train_loader = Mnist_noniid().get_train_data3()
                test_data = Mnist_noniid().get_test_data3()
                test_x = torch.unsqueeze(test_data.test_data, dim=1).type(
                    torch.FloatTensor
                ) / 255.  # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
                test_y = test_data.test_labels
            if (rank == 4):
                train_loader = Mnist_noniid().get_train_data4()
                test_data = Mnist_noniid().get_test_data4()
                test_x = torch.unsqueeze(test_data.test_data, dim=1).type(
                    torch.FloatTensor
                ) / 255.  # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
                test_y = test_data.test_labels
            if (rank == 5):
                train_loader = Mnist_noniid().get_train_data5()
                test_data = Mnist_noniid().get_test_data5()
                test_x = torch.unsqueeze(test_data.test_data, dim=1).type(
                    torch.FloatTensor
                ) / 255.  # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
                test_y = test_data.test_labels

    # test_x = torch.unsqueeze(test_data.test_data, dim=1).type(
    #     torch.FloatTensor) / 255.  # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
    # test_y = test_data.test_labels

    group_list = []
    for i in range(size):
        group_list.append(i)
    group = dist.new_group(group_list)

    for epoch in range(MAX_EPOCH):

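        # Fetch the latest global model from the master before local training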
        modell = get_new_model(modell)
        # current_model = copy.deepcopy(modell)

        for step, (b_x, b_y) in enumerate(train_loader):
            # modell = get_new_model(modell)
            # current_model = copy.deepcopy(modell)

            output = modell(b_x)[0]
            loss = loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # new_model = copy.deepcopy(modell)

        # for param1, param2 in zip( current_model.parameters(), new_model.parameters() ):
        # dist.reduce(param2.data-param1.data, dst=0, op=dist.reduce_op.SUM, group=group)

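        # Send the locally trained parameters to the master (rank 0) for aggregation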
        for param in modell.parameters():
            dist.reduce(param, dst=0, op=dist.reduce_op.SUM, group=group)

        test_output, last_layer = modell(test_x)
        pred_y = torch.max(test_output, 1)[1].data.numpy()
        accuracy = float(
            (pred_y == test_y.data.numpy()).astype(int).sum()) / float(
                test_y.size(0))
        print('Epoch: ', epoch, ' Rank: ', rank,
              '| train loss: %.4f' % loss.data.numpy(),
              '| test accuracy: %.2f' % accuracy)
Code Example #13
def run(rank, model, train_pics, train_bsz, g_lr):
    workers = [v + 1 for v in range(args.workers_num)]
    print(workers)

    _group = [w for w in workers] + [rank]  # workers plus this server's rank
    group = dist.new_group(_group)

    for p in model.parameters():
        scatter_p_list = [p.data for _ in range(len(workers) + 1)]
        dist.scatter(tensor=p.data, scatter_list=scatter_p_list, group=group)

    # initialize learning rate

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        global_lr = 0.1
    else:
        global_lr = 0.01

    for w in workers:
        g_lr.put(global_lr)
    print('Model Sent Finished!')

    print('Begin!')

    epoch_train_loss = 0.0
    trainloss_file = './trainloss' + args.model + '.txt'
    log_file = './log' + args.model + '.txt'
    if (os.path.isfile(trainloss_file)):
        os.remove(trainloss_file)
    if (os.path.isfile(log_file)):
        os.remove(log_file)
    f_trainloss = open(trainloss_file, 'a')
    f_log = open(log_file, 'a')

    tmp = [(0, 0) for _ in range(
        int(
            math.ceil(
                int(train_pics * args.data_ratio) /
                (len(workers) * train_bsz))))]

    s_time = time.time()
    epoch_time = s_time
    total_iteration_time = 0
    iteration_times_count_epoch = 0
    iteration_times_epoch = len(tmp) * len(workers)
    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    global_clock = 0
    epoch_clock = 0
    real_epoch = 0
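    # Server loop: collect losses and updates from the workers, apply them to the global model, and send it back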
    for epoch in range(args.epochs):
        for batch_idx, (_, _) in enumerate(tmp):
            batch_start_time = time.time()

            # receive the list of train loss and local iteration count from workers
            loss_it_list = [
                torch.tensor([0.0, 0.0]) for _ in range(len(workers) + 1)
            ]
            dist.gather(tensor=torch.tensor([0.0, 1.0]),
                        gather_list=loss_it_list,
                        group=group)

            loss_avg = [loss_it[0] / loss_it[1] for loss_it in loss_it_list]
            epoch_train_loss += sum(loss_avg).item() / len(workers)
            iteration_times = sum(loss_it_list)[1]
            total_iteration_time += iteration_times
            iteration_times_count_epoch += iteration_times

            # receive global update from each worker
            for update_idx, param in enumerate(model.parameters()):
                tensor = torch.zeros_like(param.data)
                gather_list = [
                    torch.zeros_like(param.data)
                    for _ in range(len(workers) + 1)
                ]
                dist.gather(tensor=tensor,
                            gather_list=gather_list,
                            group=group)
                # here we only use the average temporarily, for simplicity
                tensor = sum(gather_list) / len(workers)
                param.data -= tensor
            # send updated model back to workers
            for param_idx, param in enumerate(model.parameters()):
                scatter_list = [param.data for _ in range(len(workers) + 1)]
                dist.scatter(tensor=torch.zeros_like(param.data),
                             scatter_list=scatter_list,
                             group=group)

            global_clock += 1
            epoch_clock += 1
            # Decay the learning rate at the specified epochs
            # do not update the rate within an epoch
            temp_epoch = int(total_iteration_time / iteration_times_epoch) + 1
            if temp_epoch > real_epoch:
                real_epoch = temp_epoch
                if real_epoch % decay_period == 0:
                    global_lr *= 0.1
                    for w in workers:
                        g_lr.put(global_lr)
                    print('LR Decreased! Now: {}'.format(global_lr))

            # evaluate the time of each iteration
            batch_end_time = time.time()
            batch_time_interval = batch_end_time - batch_start_time
            f_log.write(
                str(batch_time_interval) + "\t" + str(iteration_times) + "\n")
            f_log.flush()

            if iteration_times_count_epoch >= iteration_times_epoch:
                e_epoch_time = time.time()
                iteration_times_count_epoch -= iteration_times_epoch
                # test_acc, batch_interval, batch_comp_interval, batch_comm_interval
                logs = torch.tensor([0.0, 0.0, 0.0, 0.0])
                logs_list = [
                    torch.zeros_like(logs) for _ in range(len(workers) + 1)
                ]
                # dist.gather(tensor=logs, gather_list=logs_list, group=group)
                test_acc, batch_interval, batch_comp_interval, batch_comm_interval = zip(
                    *logs_list)
                test_acc = sum(test_acc) / len(workers)
                batch_interval = sum(batch_interval) / len(workers)
                batch_comp_interval = sum(batch_comp_interval) / len(workers)
                batch_comm_interval = sum(batch_comm_interval) / len(workers)

                f_trainloss.write(
                    str(args.this_rank) + "\t" +
                    str(epoch_train_loss / float(epoch_clock)) + "\t" +
                    str(0) + "\t" + str(e_epoch_time - epoch_time) + "\t" +
                    str(e_epoch_time - s_time) + "\t" +
                    str(int(total_iteration_time / iteration_times_epoch)) +
                    "\t" + str(test_acc.item()) + "\t" +
                    str(batch_interval.item()) + "\t" +
                    str(batch_comp_interval.item()) + "\t" +
                    str(batch_comm_interval.item()) + "\t" +
                    str(global_clock) + '\n')
                print("total_iteration_time:{}, iteration_times:{}".format(
                    total_iteration_time, iteration_times_epoch))
                f_trainloss.flush()
                epoch_time = e_epoch_time
                epoch_train_loss = 0.0
                epoch_clock = 0

        print('Done Epoch {}/{}!'.format(epoch + 1, args.epochs))

    f_trainloss.close()