コード例 #1
0
def exchange(model, world_size, rank):
    """Ring-exchange model parameters: send this rank's parameters to
    rank+1 and receive replacement parameters from rank-1 (mod world_size).

    A deep copy of the model serves as the send buffer so the non-blocking
    sends cannot race with the in-place receives into ``model``.

    Args:
        model: module whose parameters are replaced in place.
        world_size: total number of ranks in the process group.
        rank: this process's rank.

    Returns:
        The same ``model`` object, now holding the parameters from rank-1.
    """
    old_model = copy.deepcopy(model)
    # BUGFIX: keep the isend work handles and wait on them. The original
    # dropped them, so the sends were never known to complete and the send
    # buffer (old_model) could be reclaimed mid-transfer.
    send_reqs = [
        dist.isend(param.data, dst=(rank + 1) % world_size)
        for param in old_model.parameters()
    ]
    for param in model.parameters():
        dist.recv(param.data, src=(rank - 1) % world_size)
    for req in send_reqs:
        req.wait()
    return model
コード例 #2
0
    def test_send_recv(self):
        """Each rank sends its own tensor to every peer, then receives the
        tensor from each peer in turn and checks it against the expected
        value before hitting the barrier."""
        my_rank = dist.get_rank()
        world = dist.get_world_size()

        outgoing = _build_tensor(my_rank + 1)
        for peer in range(world):
            if peer != my_rank:
                dist.send(outgoing, peer)

        for peer in range(world):
            if peer == my_rank:
                continue
            received = _build_tensor(peer + 1, value=-1)
            dist.recv(received, peer)
            self.assertEqual(received, _build_tensor(peer + 1))

        self._barrier()
コード例 #3
0
    def test_isend(self):
        """Rank 0 issues a non-blocking send to every other rank and then
        waits on each request; every other rank receives its tensor from
        rank 0 and verifies the payload."""
        my_rank = dist.get_rank()
        world = dist.get_world_size()

        if my_rank != 0:
            buf = _build_tensor(my_rank, -1)
            dist.recv(buf, 0)
            self.assertEqual(buf, _build_tensor(my_rank, 10))
        else:
            pending = []
            for peer in range(1, world):
                pending.append(dist.isend(_build_tensor(peer, 10), peer))
            for req in pending:
                req.wait()
                self.assertTrue(req.is_completed())

        self._barrier()
コード例 #4
0
    def test_send_recv_any_source(self):
        """Send a rank-stamped tensor to every peer, then receive world-1
        tensors from any source and confirm each distinct peer was heard
        from exactly once."""
        my_rank = dist.get_rank()
        world = dist.get_world_size()

        payload = _build_tensor(10, my_rank)
        for peer in range(world):
            if peer == my_rank:
                continue
            dist.send(payload, peer)

        senders = set()
        for _ in range(world - 1):
            buf = _build_tensor(10, value=-1)
            src_rank = dist.recv(buf)
            # Each sender stamps its own rank into the tensor values.
            self.assertTrue(buf.eq(src_rank).all())
            senders.add(src_rank)

        self.assertEqual(len(senders), world - 1)
        self._barrier()
コード例 #5
0
def run(rank, model, train_data, test_data, queue, param_q, stop_flag):
    """Worker-side training loop for a parameter-server setup.

    Blocks until the initial model arrives on ``param_q``, then trains for
    ``args.epochs`` epochs. After each batch the local update tensors are
    pushed to rank 0 (``dist.send``) and the refreshed global parameters are
    pulled back (``dist.recv``). Per-iteration timings (total / compute /
    communication) are appended to ``./record<rank>``.

    Args:
        rank: this worker's rank; key of the messages placed on ``queue``.
        model: local model replica; its parameters are overwritten in place.
        train_data: iterable yielding (data, target) training batches.
        test_data: test set — currently unused (test_model call is commented out).
        queue: queue for ``{rank: [loss, batch_size, is_done]}`` messages to the server.
        param_q: queue delivering the initial state_dict as numpy arrays.
        stop_flag: shared value; a truthy ``.value`` stops after the current epoch.
    """
    # Get the initial model from the server
    while True:
        if not param_q.empty():
            param_dict = param_q.get()
            # The state_dict was shipped as numpy arrays; convert back to tensors.
            tmp = OrderedDict(
                map(lambda item: (item[0], torch.from_numpy(item[1])),
                    param_dict.items()))
            model.load_state_dict(tmp)
            break
    print('Model recved successfully!')

    # Per-model hyperparameters: learning rate, loss function, LR-decay period.
    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100
    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')

    for epoch in range(int(args.epochs)):
        model.train()
        # Decay the learning at the specific epoch
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))
        epoch_train_loss = 0  # NOTE(review): never accumulated below — dead variable.
        for batch_idx, (data, target) in enumerate(train_data):
            it_start = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            # Local update tensors produced by MySGD from the fresh gradients.
            delta_ws = optimizer.get_delta_w()

            it_comp_end = time.time()
            # The server process may have exited; any failure of the queue or
            # dist calls below ends this epoch's batch loop instead of crashing.
            # noinspection PyBroadException
            try:
                if delta_ws:
                    queue.put({
                        rank:
                        [loss.data.numpy(),
                         np.array(args.train_bsz), False]
                    })

                # Push the update, one tensor per parameter, to the server...
                for delta in delta_ws:
                    dist.send(tensor=delta, dst=0)

                # ...then pull the refreshed global parameters back in place.
                for idx, param in enumerate(model.parameters()):
                    tmp_tensor = torch.zeros_like(param.data)
                    dist.recv(tensor=tmp_tensor, src=0)
                    param.data = tmp_tensor

                #print('Rank {}, Epoch {}, Batch {}/{}, Loss:{}'
                #     .format(rank, epoch, batch_idx, len(train_data), loss.data[0]))
            except Exception as e:
                print(str(e))
                print('Should Stop: {}!'.format(stop_flag.value))
                break

            # Log total / compute / communication duration for this iteration.
            it_comm_end = time.time()
            it_duration = it_comm_end - it_start
            it_comp_duration = it_comp_end - it_start
            it_comm_duration = it_comm_end - it_comp_end
            time_logs.write(
                str(it_duration) + "\t" + str(it_comp_duration) + "\t" +
                str(it_comm_duration) + "\n")
            time_logs.flush()

        # test the model
        print("test Model:", epoch)
        # test_model(rank, model, test_data, criterion=criterion)
        if stop_flag.value:
            break
    # Tell the server this worker has finished (is_done flag set).
    queue.put({rank: [[], [], True]})
    time_logs.close()
    print("Worker {} has completed epoch {}!".format(args.this_rank, epoch))
コード例 #6
0
ファイル: param_server.py プロジェクト: AragornThorongil/ICPP
def run(model, test_data, queue, param_q, stop_signal, train_pics):
    """Parameter-server loop that batches worker gradients before updating.

    Broadcasts the initial model to all workers via ``param_q``, then serves
    messages from ``queue``: for each message it receives the worker's
    gradient tensors (``dist.recv``), immediately sends the current
    parameters back (``dist.send``), and applies the *average* of the cached
    gradient sets once ``args.stale_threshold`` of them have accumulated.
    Progress is logged to ./trainloss*, ./staleness* and ./log* files.

    Args:
        model: the global model, updated in place.
        test_data: test set (unused here; evaluation is commented out).
        queue: queue of ``{rank: [loss, batch_size, is_done]}`` worker messages.
        param_q: queue handing workers the initial state_dict (numpy arrays).
        stop_signal: queue; a ``1`` is put on it when training should stop.
        train_pics: number of training samples that constitutes one epoch.
    """
    if args.model == 'MnistCNN':
        criterion = torch.nn.NLLLoss()
    else:
        # NOTE(review): identical to the if-branch; the worker side pairs
        # NLLLoss with CrossEntropyLoss here. Left unchanged because
        # criterion is unused (the test_model call below is commented out).
        criterion = torch.nn.NLLLoss()

    # Transform tensor to numpy so the state_dict can cross the process queue
    tmp = map(lambda item: (item[0], item[1].numpy()), model.state_dict().items())
    _tmp = OrderedDict(tmp)
    workers = [v+1 for v in range(args.workers_num)]
    for _ in workers:
        param_q.put(_tmp)
    print('Model Sent Finished!')

    print('Begin!')

    # Gradient sets cached from workers until args.stale_threshold are batched.
    worker_gradient_list = []

    epoch_train_loss = 0
    iteration_in_epoch = 0
    data_size_epoch = 0   # samples consumed so far in the current epoch
    epoch_count = 0
    staleness_sum_suqare_epoch = 0
    staleness_sum_epoch = 0

    staleness = 0                                # global update clock
    learner_staleness = {l: 0 for l in workers}  # clock each worker last synced at
    s_time = time.time()
    epoch_time = s_time

    # In SSP, the fast workers have to wait the slowest worker a given duration
    # The fast worker exceeding the duration will be pushed into the queue to wait
    trainloss_file = './trainloss' + args.model + '.txt'
    staleness_file = './staleness' + args.model + ".txt"
    log_file = './log' + args.model + ".txt"

    # Start every run with fresh log files.
    if os.path.isfile(trainloss_file):
        os.remove(trainloss_file)
    if os.path.isfile(staleness_file):
        os.remove(staleness_file)
    if os.path.isfile(log_file):
        os.remove(log_file)
    f_trainloss = open(trainloss_file, 'a')
    f_staleness = open(staleness_file, 'a')
    f_log = open(log_file, 'a')

    it_start_time = time.time()
    global_clock = 0
    while True:
        if not queue.empty():
            batch_start_time = time.time()

            tmp_dict = queue.get()
            rank_src = list(tmp_dict.keys())[0]
            isWorkerEnd = tmp_dict[rank_src][2]
            if isWorkerEnd:
                print("Worker {} has completed all its data computation!".format(rank_src))
                learner_staleness.pop(rank_src)
                if len(learner_staleness) == 0:
                    # All workers finished: close every log file (f_log was
                    # previously leaked on this exit path) and signal shutdown.
                    f_trainloss.close()
                    f_staleness.close()
                    f_log.close()
                    stop_signal.put(1)
                    print('Epoch is done: {}'.format(epoch_count))
                    break
                continue

            batch_receive_time = time.time()

            # Receive this worker's gradient tensor for every parameter.
            tmp_gradient = []
            for param in model.parameters():
                tmp = torch.zeros_like(param.data)
                dist.recv(tensor=tmp, src=rank_src)
                tmp_gradient.append(tmp)

            # add the local gradients to the k-list cache
            worker_gradient_list.append(tmp_gradient)
            iteration_loss = tmp_dict[rank_src][0]
            batch_size = tmp_dict[rank_src][1]

            iteration_in_epoch += 1
            epoch_train_loss += iteration_loss
            data_size_epoch += batch_size

            # Staleness = clock ticks since this worker last synchronized.
            stale = int(staleness - learner_staleness[rank_src])
            learner_staleness[rank_src] = staleness
            staleness_sum_epoch += stale
            # staleness_sum_suqare_epoch += stale ** 2

            batch_preprocess_time = time.time()

            # return current model to worker
            for idx, param in enumerate(model.parameters()):
                dist.send(tensor=param.data, dst=rank_src)

            batch_send_time = time.time()

            # update current model once number of cached mini-batches reaches
            # args.stale_threshold
            if len(worker_gradient_list) >= args.stale_threshold:
                global_clock += 1
                # Apply the average of all cached gradient sets per parameter.
                for idx, param in enumerate(model.parameters()):
                    delta_ws = torch.zeros_like(param.data)
                    for w_g_idx in range(len(worker_gradient_list)):
                        delta_ws += worker_gradient_list[w_g_idx][idx]
                    param.data -= delta_ws/len(worker_gradient_list)

                f_log.write(str(len(worker_gradient_list))+"\n")
                f_log.flush()
                worker_gradient_list = []
                staleness += 1      # update system clock

            batch_end_time = time.time()
            batch_interval = batch_end_time - batch_start_time
            batch_receive_interval = batch_receive_time - batch_start_time
            batch_send_interval = batch_send_time - batch_preprocess_time

            it_iterval = batch_end_time - it_start_time
            it_start_time = time.time()
            # epoch, rank, batch size, stale, then timing breakdown
            f_staleness.write(str(epoch_count) +
                        "\t" + str(rank_src) +
                        "\t" + str(batch_size) +
                        "\t" + str(stale) +
                        "\t" + str(batch_interval) +
                        "\t" + str(batch_receive_interval) +
                        "\t" + str(batch_send_interval) +
                        "\t" + str(it_iterval) +
                        "\t" + str(global_clock) +
                        '\n')

            # once an epoch's worth of samples is consumed, log the average
            # train loss and reset the per-epoch accumulators
            if data_size_epoch >= train_pics:
                e_epoch_time = time.time()
                # variance of stale (disabled — square-sum accumulation is off)
                # diversity_stale = (staleness_sum_suqare_epoch/iteration_in_epoch)\
                #                  - (staleness_sum_epoch/iteration_in_epoch)**2
                diversity_stale = 0.0
                staleness_sum_suqare_epoch = 0
                staleness_sum_epoch = 0
                # test_loss, test_acc = test_model(dist.get_rank(), model, test_data, criterion=criterion)
                test_acc = 0
                # rank, trainloss, variance of staleness, time in one epoch, time till now
                f_trainloss.write(str(args.this_rank) +
                                  "\t" + str(epoch_train_loss/float(iteration_in_epoch)) +
                                  "\t" + str(0) +
                                  "\t" + str(e_epoch_time - epoch_time) +
                                  "\t" + str(e_epoch_time - s_time) +
                                  "\t" + str(e_epoch_time - s_time) +
                                  "\t" + str(epoch_count) +
                                  "\t" + str(test_acc) +
                                  "\t" + str(diversity_stale) +
                                  "\t" + str(global_clock) +
                                  '\n')
                f_trainloss.flush()
                f_staleness.flush()
                iteration_in_epoch = 0
                epoch_count += 1
                epoch_train_loss = 0
                data_size_epoch = 0
                epoch_time = e_epoch_time

            # The training stop
            if epoch_count >= args.epochs:
                f_trainloss.close()
                f_staleness.close()
                f_log.close()
                stop_signal.put(1)
                print('Epoch is done: {}'.format(epoch_count))
                break

        e_time = time.time()
        if (e_time - s_time) >= float(args.timeout):
            # Timeout: close all log files (f_log was previously leaked here).
            f_trainloss.close()
            f_staleness.close()
            f_log.close()
            stop_signal.put(1)
            print('Time up: {}, Stop Now!'.format(e_time - s_time))
            break
コード例 #7
0
def run(model, test_data, queue, param_q, stop_signal, train_pics):
    """Parameter-server loop implementing stale-synchronous parallel (SSP).

    Sends the initial model to every worker via ``param_q``, then serves
    worker messages from ``queue``: receives each worker's gradients
    (``dist.recv``) and applies them immediately with a staleness-damped
    step, but releases parked workers (sending back the current parameters)
    only while no absent worker lags more than ``args.stale_threshold``
    clock ticks. Progress is logged to ./trainloss* and ./staleness* files.

    Args:
        model: the global model, updated in place.
        test_data: test set (unused here; evaluation is commented out).
        queue: queue of ``{rank: [loss, batch_size, is_done]}`` worker messages.
        param_q: queue handing workers the initial state_dict (numpy arrays).
        stop_signal: queue; a ``1`` is put on it when training should stop.
        train_pics: number of training samples that constitutes one epoch.
    """
    if args.model == 'MnistCNN':
        criterion = torch.nn.NLLLoss()
    else:
        # NOTE(review): identical to the if-branch; the worker side pairs
        # NLLLoss with CrossEntropyLoss. Left unchanged because criterion
        # is unused (the test_model call below is commented out).
        criterion = torch.nn.NLLLoss()

    # Transform tensor to numpy so the state_dict can cross the process queue
    tmp = map(lambda item: (item[0], item[1].numpy()),
              model.state_dict().items())
    _tmp = OrderedDict(tmp)
    workers = [v + 1 for v in range(args.workers_num)]
    for _ in workers:
        param_q.put(_tmp)
    print('Model Sent Finished!')

    print('Begin!')

    epoch_train_loss = 0
    iteration_in_epoch = 0
    data_size_epoch = 0  # samples consumed so far in the current epoch
    epoch_count = 0
    staleness_sum_suqare_epoch = 0
    staleness_sum_epoch = 0

    staleness = 0  # the global clock
    learner_staleness = {l: 0 for l in workers}  # clock each worker last synced at
    s_time = time.time()
    epoch_time = s_time

    # In SSP, the fast workers have to wait the slowest worker a given duration
    # The fast worker exceeding the duration will be pushed into the queue to wait
    stale_stack = []  # workers currently parked, awaiting the model broadcast

    trainloss_file = './trainloss' + args.model + '.txt'
    staleness_file = './staleness' + args.model + ".txt"

    # Start every run with fresh log files.
    if os.path.isfile(trainloss_file):
        os.remove(trainloss_file)
    if os.path.isfile(staleness_file):
        os.remove(staleness_file)
    f_trainloss = open(trainloss_file, 'a')
    f_staleness = open(staleness_file, 'a')
    global_clock = 0
    while True:
        if not queue.empty():
            it_start_time = time.time()

            tmp_dict = queue.get()
            rank_src = list(tmp_dict.keys())[0]
            isWorkerEnd = tmp_dict[rank_src][2]
            if isWorkerEnd:
                print(
                    "Worker {} has completed all its data computation!".format(
                        rank_src))
                learner_staleness.pop(rank_src)
                if len(learner_staleness) == 0:
                    f_trainloss.close()
                    f_staleness.close()
                    stop_signal.put(1)
                    print('Epoch is done: {}'.format(epoch_count))
                    break
                continue

            # Staleness = clock ticks since this worker last synchronized.
            stale = int(staleness - learner_staleness[rank_src])
            staleness_sum_epoch += stale
            # staleness_sum_suqare_epoch += stale**2
            staleness_sum_suqare_epoch += 0.0
            staleness += 1
            learner_staleness[rank_src] = staleness
            stale_stack.append(rank_src)

            # recv gradients of the worker and update the current model,
            # damping the step by how stale the contribution is
            for idx, param in enumerate(model.parameters()):
                tmp_tensor = torch.zeros_like(param.data)
                dist.recv(tensor=tmp_tensor, src=rank_src)
                param.data -= tmp_tensor / (stale + 1)

            global_clock += 1
            iteration_loss = tmp_dict[rank_src][0]
            batch_size = tmp_dict[rank_src][1]

            iteration_in_epoch += 1
            epoch_train_loss += iteration_loss
            data_size_epoch += batch_size

            # judge if any absent worker exceeds the staleness threshold (SSP)
            outOfStale = False
            for stale_each_worker in learner_staleness:
                # 'and' (short-circuit) replaces the original bitwise '&':
                # same result on booleans, and skips the second comparison
                # when the worker is already parked in stale_stack.
                if (stale_each_worker not in stale_stack) and \
                    (staleness - learner_staleness[stale_each_worker] > args.stale_threshold):
                    outOfStale = True
                    break
            if not outOfStale:
                # No straggler beyond the threshold: release every parked
                # worker by sending it the current model.
                for i in range(len(stale_stack)):
                    rank_wait = stale_stack.pop()
                    # update the value of staleness in the corresponding learner
                    learner_staleness[rank_wait] = staleness
                    for idx, param in enumerate(model.parameters()):
                        dist.send(tensor=param.data, dst=rank_wait)
            else:
                # Keep workers parked until the straggler catches up.
                continue

            it_end_time = time.time()
            # epoch, rank, batch size, stale, iteration time, global clock
            f_staleness.write(
                str(epoch_count) + "\t" + str(rank_src) + "\t" +
                str(batch_size) + "\t" + str(stale) + "\t" +
                str(it_end_time - it_start_time) + "\t" + str(global_clock) +
                '\n')
            f_staleness.flush()

            # once an epoch's worth of samples is consumed, log the average
            # train loss and reset the per-epoch accumulators
            if data_size_epoch >= train_pics:
                e_epoch_time = time.time()
                # variance of stale
                diversity_stale = (staleness_sum_suqare_epoch/iteration_in_epoch)\
                                 - (staleness_sum_epoch/iteration_in_epoch)**2
                staleness_sum_suqare_epoch = 0
                staleness_sum_epoch = 0
                #test_loss, test_acc = test_model(dist.get_rank(), model, test_data, criterion=criterion)
                # Placeholder while test_model is disabled (the original
                # assigned this twice in a row).
                test_acc = 0.0
                # rank, trainloss, variance of staleness, time in one epoch, time till now
                f_trainloss.write(
                    str(args.this_rank) + "\t" +
                    str(epoch_train_loss / float(iteration_in_epoch)) + "\t" +
                    str(diversity_stale) + "\t" +
                    str(e_epoch_time - epoch_time) + "\t" +
                    str(e_epoch_time - s_time) + "\t" + str(epoch_count) +
                    "\t" + str(test_acc) + "\t" + str(global_clock) + '\n')
                f_trainloss.flush()
                iteration_in_epoch = 0
                epoch_count += 1
                epoch_train_loss = 0
                data_size_epoch = 0
                epoch_time = e_epoch_time

            # The training stop
            if epoch_count >= args.epochs:
                f_trainloss.close()
                f_staleness.close()
                stop_signal.put(1)
                print('Epoch is done: {}'.format(epoch_count))
                break

        e_time = time.time()
        if (e_time - s_time) >= float(args.timeout):
            f_trainloss.close()
            f_staleness.close()
            stop_signal.put(1)
            print('Time up: {}, Stop Now!'.format(e_time - s_time))
            break
コード例 #8
0
def get_new_model(model):
    """Receive fresh parameter tensors from rank 0 into ``model`` in place
    and return the updated model."""
    for p in model.parameters():
        dist.recv(p.data, src=0)
    return model
コード例 #9
0
def run(rank, model, train_data, test_data, queue, param_q, stop_flag):
    """Worker-side training loop with detailed push/pull timing.

    Same protocol as the plain worker ``run``: receive the initial model
    through ``param_q``, then per batch push local updates to rank 0 and
    pull refreshed parameters back, while logging running-average timings
    (total / compute / communication / push / pull) to ``./record<rank>``.

    Args:
        rank: this worker's rank; key of the messages placed on ``queue``.
        model: local model replica, updated in place each iteration.
        train_data: iterable yielding (data, target) training batches.
        test_data: test set — unused (test_model call is commented out).
        queue: queue for ``{rank: [loss, batch_size, is_done]}`` server messages.
        param_q: queue delivering the initial state_dict as numpy arrays.
        stop_flag: shared value; a truthy ``.value`` stops after the epoch.
    """
    # Fetch the initial model parameters sent by the parameter server.
    while True:
        if not param_q.empty():
            param_dict = param_q.get()
            # The state_dict was shipped as numpy arrays; convert back to tensors.
            tmp = OrderedDict(
                map(lambda item: (item[0], torch.from_numpy(item[1])),
                    param_dict.items()))
            model.load_state_dict(tmp)
            break
    print('Model recved successfully!')

    # Per-model hyperparameters: learning rate, loss function, LR-decay period.
    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100
    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(int(args.epochs)):
        # Running totals for this epoch's timing breakdown.
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        batch_push_interval = 0.0
        batch_pull_interval = 0.0
        model.train()
        # Decrease the learning rate at the designated epochs
        # (originally AlexNet-specific).
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))
        epoch_train_loss = 0  # NOTE(review): never accumulated below — dead variable.
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            # Local update tensors produced by MySGD from the fresh gradients.
            delta_ws = optimizer.get_delta_w()

            batch_comp_time = time.time()
            # noinspection PyBroadException
            try:  # catch exceptions caused by the parameter-server process stopping
                if delta_ws:
                    queue.put({
                        rank:
                        [loss.data.numpy(),
                         np.array(args.train_bsz), False]
                    })
                # Push the update, one tensor per parameter, to the server...
                for delta in delta_ws:
                    dist.send(tensor=delta, dst=0)

                batch_push_time = time.time()

                # ...then pull the refreshed global parameters back in place.
                for idx, param in enumerate(model.parameters()):
                    tmp_tensor = torch.zeros_like(param.data)
                    dist.recv(tensor=tmp_tensor, src=0)
                    param.data = tmp_tensor

                batch_tmp_time = time.time()
                batch_pull_time = time.time()
                #print('Rank {}, Epoch {}, Batch {}/{}, Loss:{}'
                #     .format(rank, epoch, batch_idx, len(train_data), loss.data[0]))
            except Exception as e:
                print(str(e))
                print('Should Stop: {}!'.format(stop_flag.value))
                break

            # Accumulate this batch's timings and log the running averages.
            batch_interval += batch_pull_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_pull_time - batch_comp_time
            batch_push_interval += batch_push_time - batch_comp_time
            batch_pull_interval += batch_pull_time - batch_push_time
            b_interval = batch_interval / (batch_idx + 1)
            b_comp_interval = batch_comp_interval / (batch_idx + 1)
            b_comm_interval = batch_comm_interval / (batch_idx + 1)
            b_push_interval = batch_push_interval / (batch_idx + 1)
            b_pull_interval = batch_pull_interval / (batch_idx + 1)
            logs = torch.tensor([
                0.0, b_interval, b_comp_interval, b_comm_interval,
                b_push_interval, b_pull_interval,
                batch_pull_time - batch_tmp_time
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        # Per-epoch averages.
        # NOTE(review): divides by batch_idx rather than batch_idx + 1 —
        # off-by-one vs the running averages above; also raises NameError
        # if train_data is empty. Confirm intent before changing.
        batch_interval /= batch_idx
        batch_comp_interval /= batch_idx
        batch_comm_interval /= batch_idx
        batch_push_interval /= batch_idx
        batch_pull_interval /= batch_idx
        logs = torch.tensor([
            0.0, batch_interval, batch_comp_interval, batch_comm_interval,
            batch_push_interval, batch_pull_interval
        ])
        time_logs.write(str(epoch) + '\t' + str(logs) + '\n')
        time_logs.flush()
        # Run the test after this epoch's training.
        print("test Model:", epoch)
        # test_model(rank, model, test_data, criterion=criterion)
        if stop_flag.value:
            break
    # Tell the server this worker has finished (is_done flag set).
    queue.put({rank: [[], [], True]})
    time_logs.close()
    print("Worker {} has completed epoch {}!".format(args.this_rank, epoch))