Example #1
    def _test_scatter_helper(self, group, group_id, rank):
        for dest in group:
            tensor = _build_tensor(dest + 1, -1)
            expected_tensor = _build_tensor(dest + 1, rank)
            tensors = (
                [_build_tensor(dest + 1, i) for i in group] if rank == dest else []
            )
            dist.scatter(tensor, src=dest, scatter_list=tensors, group=group_id)
            self.assertEqual(tensor, expected_tensor)

        self._barrier()
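The test above relies on the torch.distributed calling convention for scatter: only the source rank supplies scatter_list, while every other rank passes an empty list and simply receives its slice. A minimal standalone sketch of that convention (assuming a process group has already been initialized, e.g. via dist.init_process_group; the helper name and tensor sizes here are illustrative):

import torch
import torch.distributed as dist

def scatter_from(src, rank, world_size, group=None):
    # Only the source rank fills scatter_list; the other ranks pass an empty list
    # and receive their slice in `recv`.
    recv = torch.zeros(4)
    scatter_list = (
        [torch.full((4,), float(i)) for i in range(world_size)] if rank == src else []
    )
    dist.scatter(recv, scatter_list=scatter_list, src=src, group=group)
    return recv  # rank i ends up holding a tensor filled with the value i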
Example #2
def sender(model_cache, global_update, local_update, it_count, loss_t,
           update_lock, data_lock, group, receive_end,
           batch_communication_interval, stale_in_iteration):
    comm_count = 0
    while True:
        # this is a queue that is controlled by the computation module
        # At least one gradient has been generated, so the sender has a gradient to send
        update_lock.get()

        # Note: a lock must be held before accessing local_update
        # copy the local update to the global update, then reset the local update to zero
        data_lock.acquire()
        loss = loss_t.data
        loss_t.data = torch.tensor(0.)
        it_times = it_count.value
        it_count.value = 0.
        for idx, update in enumerate(global_update):
            update.data = local_update[idx].data
            local_update[idx].data = torch.zeros_like(update.data)
        data_lock.release()

        comm_s = time.time()
        loss_it = torch.tensor([float(loss), it_times])
        dist.gather(tensor=loss_it, dst=0, group=group)
        for idx, update in enumerate(global_update):
            dist.gather(tensor=update.data, dst=0, group=group)

        for idx, param in enumerate(model_cache):
            dist.scatter(tensor=param.data, src=0, group=group)
        receive_end.value = True

        comm_e = time.time()
        comm_count += 1
        # compute average communication time of an iteration
        batch_communication_interval.value = (
            batch_communication_interval.value * (comm_count - 1) +
            (comm_e - comm_s)) / comm_count
        stale_in_iteration.value = (stale_in_iteration.value *
                                    (comm_count - 1) + it_times) / comm_count
    return
Example #3
def run(rank, workers, model, save_path, train_data, test_data):
    # Get the initial model from the server
    _group = [w for w in workers] + [0]
    group = dist.new_group(_group)

    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at specific epochs
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            batch_comp_time = time.time()
            # Synchronization
            # send the epoch train loss first
            dist.gather(loss.data, dst=0, group=group)
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=delta_ws[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()

            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time

            logs = torch.tensor([
                0.0, batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch,
                                                  loss.data.item()))

        e_time = time.time()
        #epoch_train_loss /= len(train_data)
        #epoch_train_loss = format(epoch_train_loss, '.4f')
        # test the model
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx
        batch_comp_interval /= batch_idx
        batch_comm_interval /= batch_idx
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        #dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
Example #4
def run(rank, model, train_pics, train_bsz):
    workers = [v + 1 for v in range(args.workers_num)]
    _group = [w for w in workers] + [rank]
    group = dist.new_group(_group)

    for p in model.parameters():
        scatter_p_list = [p.data for _ in range(len(workers) + 1)]
        dist.scatter(tensor=p.data, scatter_list=scatter_p_list, group=group)

    print('Model Sent Finished!')

    print('Begin!')

    trainloss_file = './trainloss' + args.model + '.txt'
    if (os.path.isfile(trainloss_file)):
        os.remove(trainloss_file)
    f_trainloss = open(trainloss_file, 'a')

    tmp = [
        (0, 0)
        for _ in range(int(math.ceil(train_pics / (len(workers) * train_bsz))))
    ]

    g_list = [[torch.zeros_like(param.data) for param in model.parameters()]
              for _ in range(len(workers) + 1)]

    s_time = time.time()
    epoch_time = s_time
    global_clock = 0
    epoch_train_loss = 0.0
    sparsification_ratio = 0.0
    for epoch in range(args.epochs):
        for batch_idx, (_, _) in enumerate(tmp):
            # receive the list of train losses from the workers
            info_list = [torch.tensor([0.0]) for _ in range(len(workers) + 1)]
            dist.gather(tensor=torch.tensor([0.0]),
                        gather_list=info_list,
                        group=group)
            batch_loss = sum(info_list).item() / len(workers)
            epoch_train_loss += batch_loss

            sparsification_ratio_list = [
                torch.tensor([0.0]) for _ in range(len(workers) + 1)
            ]
            dist.gather(tensor=torch.tensor([0.0]),
                        gather_list=sparsification_ratio_list,
                        group=group)
            batch_ratio = sum(sparsification_ratio_list).item() / len(workers)
            sparsification_ratio += batch_ratio

            param_idx = 0
            g_quare_sum = torch.tensor([0.])
            for layer_idx, param in enumerate(model.parameters()):
                tensor = torch.zeros_like(param.data)
                # FIXME FIXED: every tensor in gather_list must be a new object, otherwise problems occur
                gather_list = [
                    torch.zeros_like(param.data)
                    for _ in range(len(workers) + 1)
                ]
                dist.gather(tensor=tensor,
                            gather_list=gather_list,
                            group=group)
                tensor = sum(gather_list) / len(workers)
                param.data -= tensor
                scatter_list = [param.data for _ in range(len(workers) + 1)]
                dist.scatter(tensor=tensor,
                             scatter_list=scatter_list,
                             group=group)
                param_idx += 1

            global_clock += 1

            f_trainloss.write(
                str(args.this_rank) + "\t" +
                str(epoch_train_loss / float(batch_idx + 1)) + "\t" +
                str(batch_loss) + "\t" + str(0) + "\t" + str(0) + "\t" +
                str(epoch) + "\t" + str(0) + "\t" + str(0) + "\t" + str(0) +
                "\t" + str(batch_ratio) + "\t" +
                str(sparsification_ratio / float(batch_idx + 1)) + "\t" +
                str(global_clock) + '\n')
            # f_trainloss.flush()

            #print('Done {}/{}!'.format(batch_idx, len(tmp)))
        print('Done Epoch {}/{}!'.format(epoch + 1, args.epochs))

        e_epoch_time = time.time()

        # test_acc, batch_interval, batch_comp_interval, batch_comm_interval
        logs = torch.tensor([0.0, 0.0, 0.0, 0.0])
        logs_list = [torch.zeros_like(logs) for _ in range(len(workers) + 1)]
        #dist.gather(tensor = logs, gather_list = logs_list, group = group)
        test_acc, batch_interval, batch_comp_interval, batch_comm_interval = zip(
            *logs_list)
        test_acc = sum(test_acc) / len(workers)
        batch_interval = sum(batch_interval) / len(workers)
        batch_comp_interval = sum(batch_comp_interval) / len(workers)
        batch_comm_interval = sum(batch_comm_interval) / len(workers)

        # f_trainloss.write(str(args.this_rank) +
        #                   "\t" + str(epoch_train_loss / float(batch_idx+1)) +
        #                   "\t" + str(0) +
        #                   "\t" + str(e_epoch_time - epoch_time) +
        #                   "\t" + str(e_epoch_time - s_time) +
        #                   "\t" + str(epoch) +
        #                   "\t" + str(test_acc.item()) +
        #                   "\t" + str(batch_interval.item()) +
        #                   "\t" + str(batch_comp_interval.item()) +
        #                   "\t" + str(batch_comm_interval.item()) +
        #                   "\t" + str(sparsification_ratio / float(batch_idx+1)) +
        #                   "\t" + str(global_clock) +
        #                   '\n')
        f_trainloss.flush()
        epoch_time = e_epoch_time
        epoch_train_loss = 0.0
        sparsification_ratio = 0.0

        if (epoch + 1) % 2 == 0:
            if not os.path.exists('model_state'):
                os.makedirs('model_state')
            torch.save(
                model.state_dict(), 'model_state' + '/' + args.model + '_' +
                str(epoch + 1) + '.pkl')

    f_trainloss.close()
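Example #3 and Example #4 are the worker and server halves of the same synchronous loop: the workers gather their deltas to rank 0, and rank 0 scatters the averaged parameters back. A minimal sketch of how such processes are typically initialized before calling run (assuming the standard torch.distributed TCP rendezvous; the address, port, and backend are illustrative):

import torch.distributed as dist

def init_process(rank, world_size, addr='127.0.0.1', port='29500', backend='gloo'):
    # Rank 0 plays the parameter server; ranks 1..world_size-1 are workers.
    dist.init_process_group(backend=backend,
                            init_method='tcp://{}:{}'.format(addr, port),
                            rank=rank,
                            world_size=world_size)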
Example #5
def run(rank, workers, model, save_path, train_data, test_data):
    # Get the initial model parameters sent from the parameter server
    _group = [w for w in workers] + [0]
    group = dist.new_group(_group)

    param_num = 0
    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        param_num += torch.numel(tmp_p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    compression_num = int(param_num * args.ratio)
    compression_num = compression_num if compression_num > 0 else 1
    dist.gather(torch.tensor([compression_num / param_num]),
                dst=0,
                group=group)

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        learning_rate = 0.1
    else:
        learning_rate = args.lr
    optimizer = MySGD(model.parameters(), lr=learning_rate)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    elif args.model in ['Abalone', 'Bodyfat', 'Housing']:
        criterion = torch.nn.MSELoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    elif args.model in [
            'LROnMnist', 'LROnCifar10', 'LROnCifar100', 'Abalone', 'Bodyfat',
            'Housing'
    ]:
        decay_period = 1000000  # learning rate is constant for LR (convex) models
    else:
        decay_period = 100

    print('Begin!')

    global_clock = 0
    g_remain = [torch.zeros_like(param.data) for param in model.parameters()]
    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specified epochs
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            g_remain, g_large_change = get_upload(g_remain, delta_ws,
                                                  args.ratio,
                                                  args.isCompensate)

            batch_comp_time = time.time()
            # Synchronization
            # send the epoch train loss first
            dist.gather(loss.data, dst=0, group=group)
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=g_large_change[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()

            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time

            logs = torch.tensor([
                0.0, batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch,
                                                  loss.data.item()))

        e_time = time.time()
        #epoch_train_loss /= len(train_data)
        #epoch_train_loss = format(epoch_train_loss, '.4f')
        # Test the model after training
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        batch_comm_interval /= batch_idx + 1
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        #dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
Example #6
def run(rank, workers, model, save_path, train_data, test_data, global_lr):
    # Get the initial model from the server
    print(workers)

    _group = [w for w in workers] + [0]
    group = dist.new_group(_group)

    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model recved successfully!')

    temp_lr = global_lr.get()

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=temp_lr)
    else:
        optimizer = MySGD(model.parameters(), lr=temp_lr)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print('Begin!')

    # the parameters that will be transferred to the thread
    model_cache = [p.data + 0.0 for p in model.parameters()]
    global_update = [torch.zeros_like(p) for p in model.parameters()]
    local_update = [torch.zeros_like(p) for p in model.parameters()]
    it_count = Value(c_float, 0.)  # number of local updates accumulated in the current iteration
    data_lock = Lock()
    update_lock = Queue()
    update_lock.put(1)

    loss_t = torch.tensor(0.0)
    receive_end = Value(c_bool, False)
    batch_communication_interval = Value(c_float, 0.0)
    stale_in_iteration = Value(c_float, 0.)

    sender_td = Thread(target=sender,
                       args=(
                           model_cache,
                           global_update,
                           local_update,
                           it_count,
                           loss_t,
                           update_lock,
                           data_lock,
                           group,
                           receive_end,
                           batch_communication_interval,
                           stale_in_iteration,
                       ),
                       daemon=True)
    sender_td.start()

    time_logs = open("./record" + str(rank), 'w')
    osp_logs = open("./log" + str(rank), 'w')
    Stale_Threshold = args.stale_threshold
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specified epoch
        # the learning rate should be decreased on the server, because the local worker and the server update at different speeds
        if not global_lr.empty():
            g_lr = global_lr.get()
            if args.model == 'AlexNet':
                for param_group in optimizer.param_groups:
                    param_group['lr'] = g_lr
                    print('LR Decreased! Now: {}'.format(param_group['lr']))

        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            optimizer.step()

            # Aggregate local update
            data_lock.acquire()
            # aggregate loss
            loss_t.data += loss.data
            it_count.value += 1
            for g_idx, update in enumerate(local_update):
                update.data += delta_ws[g_idx].data
            data_lock.release()

            batch_computation_time = time.time()

            # Open the lock once the local update has at least one gradient
            if it_count.value == 1:
                update_lock.put(1)
            while it_count.value >= Stale_Threshold:
                pass

            if receive_end.value:
                receive_end.value = False
                for idx, param in enumerate(model.parameters()):
                    param.data = model_cache[idx]  # without local update
                    # param.data = model_cache[idx] - global_update[idx] # with local update

            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_computation_time - batch_start_time
            osp_logs.write(
                str(batch_end_time - batch_start_time) + "\t" +
                str(batch_computation_time - batch_start_time) + "\n")
            osp_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch,
                                                  loss.data.item()))

        e_time = time.time()
        # Test the model after training
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx
        batch_comp_interval /= batch_idx
        logs = torch.tensor([
            acc, batch_interval, batch_comp_interval,
            batch_communication_interval.value, stale_in_iteration.value
        ])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        # dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
    sender_td.join()
Example #7
def run(rank, model, train_pics, train_bsz, g_lr):
    workers = [v + 1 for v in range(args.workers_num)]
    print(workers)

    _group = [w for w in workers] + [rank]
    group = dist.new_group(_group)

    for p in model.parameters():
        scatter_p_list = [p.data for _ in range(len(workers) + 1)]
        dist.scatter(tensor=p.data, scatter_list=scatter_p_list, group=group)

    # initialize learning rate

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        global_lr = 0.1
    else:
        global_lr = 0.01

    for w in workers:
        g_lr.put(global_lr)
    print('Model Sent Finished!')

    print('Begin!')

    epoch_train_loss = 0.0
    trainloss_file = './trainloss' + args.model + '.txt'
    log_file = './log' + args.model + '.txt'
    if (os.path.isfile(trainloss_file)):
        os.remove(trainloss_file)
    if (os.path.isfile(log_file)):
        os.remove(log_file)
    f_trainloss = open(trainloss_file, 'a')
    f_log = open(log_file, 'a')

    tmp = [(0, 0) for _ in range(
        int(
            math.ceil(
                int(train_pics * args.data_ratio) /
                (len(workers) * train_bsz))))]

    s_time = time.time()
    epoch_time = s_time
    total_iteration_time = 0
    iteration_times_count_epoch = 0
    iteration_times_epoch = len(tmp) * len(workers)
    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    global_clock = 0
    epoch_clock = 0
    real_epoch = 0
    for epoch in range(args.epochs):
        for batch_idx, (_, _) in enumerate(tmp):
            batch_start_time = time.time()

            # receive the list of train losses and local iteration counts from the workers
            loss_it_list = [
                torch.tensor([0.0, 0.0]) for _ in range(len(workers) + 1)
            ]
            dist.gather(tensor=torch.tensor([0.0, 1.0]),
                        gather_list=loss_it_list,
                        group=group)

            loss_avg = [loss_it[0] / loss_it[1] for loss_it in loss_it_list]
            epoch_train_loss += sum(loss_avg).item() / len(workers)
            iteration_times = sum(loss_it_list)[1]
            total_iteration_time += iteration_times
            iteration_times_count_epoch += iteration_times

            # receive global update from each worker
            for update_idx, param in enumerate(model.parameters()):
                tensor = torch.zeros_like(param.data)
                gather_list = [
                    torch.zeros_like(param.data)
                    for _ in range(len(workers) + 1)
                ]
                dist.gather(tensor=tensor,
                            gather_list=gather_list,
                            group=group)
                # for now we simply average the updates, for simplicity
                tensor = sum(gather_list) / len(workers)
                param.data -= tensor
            # send updated model back to workers
            for param_idx, param in enumerate(model.parameters()):
                scatter_list = [param.data for _ in range(len(workers) + 1)]
                dist.scatter(tensor=torch.zeros_like(param.data),
                             scatter_list=scatter_list,
                             group=group)

            global_clock += 1
            epoch_clock += 1
            # Decay the learning rate at the specified epoch
            # do not update it within an epoch
            temp_epoch = int(total_iteration_time / iteration_times_epoch) + 1
            if temp_epoch > real_epoch:
                real_epoch = temp_epoch
                if real_epoch % decay_period == 0:
                    global_lr *= 0.1
                    for w in workers:
                        g_lr.put(global_lr)
                    print('LR Decreased! Now: {}'.format(global_lr))

            # evaluate the time of each iteration
            batch_end_time = time.time()
            batch_time_interval = batch_end_time - batch_start_time
            f_log.write(
                str(batch_time_interval) + "\t" + str(iteration_times) + "\n")
            f_log.flush()

            if iteration_times_count_epoch >= iteration_times_epoch:
                e_epoch_time = time.time()
                iteration_times_count_epoch -= iteration_times_epoch
                # test_acc, batch_interval, batch_comp_interval, batch_comm_interval
                logs = torch.tensor([0.0, 0.0, 0.0, 0.0])
                logs_list = [
                    torch.zeros_like(logs) for _ in range(len(workers) + 1)
                ]
                # dist.gather(tensor=logs, gather_list=logs_list, group=group)
                test_acc, batch_interval, batch_comp_interval, batch_comm_interval = zip(
                    *logs_list)
                test_acc = sum(test_acc) / len(workers)
                batch_interval = sum(batch_interval) / len(workers)
                batch_comp_interval = sum(batch_comp_interval) / len(workers)
                batch_comm_interval = sum(batch_comm_interval) / len(workers)

                f_trainloss.write(
                    str(args.this_rank) + "\t" +
                    str(epoch_train_loss / float(epoch_clock)) + "\t" +
                    str(0) + "\t" + str(e_epoch_time - epoch_time) + "\t" +
                    str(e_epoch_time - s_time) + "\t" +
                    str(int(total_iteration_time / iteration_times_epoch)) +
                    "\t" + str(test_acc.item()) + "\t" +
                    str(batch_interval.item()) + "\t" +
                    str(batch_comp_interval.item()) + "\t" +
                    str(batch_comm_interval.item()) + "\t" +
                    str(global_clock) + '\n')
                print("total_iteration_time:{}, iteration_times:{}".format(
                    total_iteration_time, iteration_times_epoch))
                f_trainloss.flush()
                epoch_time = e_epoch_time
                epoch_train_loss = 0.0
                epoch_clock = 0

        print('Done Epoch {}/{}!'.format(epoch + 1, args.epochs))

    f_trainloss.close()
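All of the worker examples above share one synchronization pattern per parameter: gather the local update to rank 0, then scatter the refreshed parameter back. A condensed sketch of that round trip for both roles (a hypothetical helper, assuming the group contains ranks 0..num_workers and was built as in the examples):

import torch
import torch.distributed as dist

def sync_parameters(rank, model, group, num_workers):
    # One gather/scatter round trip per parameter.
    for p in model.parameters():
        if rank == 0:
            # Server: collect updates, average them, apply, and send the new weights back.
            gathered = [torch.zeros_like(p.data) for _ in range(num_workers + 1)]
            dist.gather(torch.zeros_like(p.data), gather_list=gathered, group=group)
            p.data -= sum(gathered) / num_workers  # rank 0 contributed zeros
            dist.scatter(torch.zeros_like(p.data),
                         scatter_list=[p.data for _ in range(num_workers + 1)],
                         group=group)
        else:
            # Worker: send the local delta, then overwrite parameters with the server's copy.
            dist.gather(p.data, dst=0, group=group)
            recv = torch.zeros_like(p.data)
            dist.scatter(recv, src=0, group=group)
            p.data = recv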