Example No. 1
def run(rank, model, train_data, test_data, queue, param_q, stop_flag):
    # Get the initial model from the server
    while True:
        if not param_q.empty():
            param_dict = param_q.get()
            tmp = OrderedDict(
                map(lambda item: (item[0], torch.from_numpy(item[1])),
                    param_dict.items()))
            model.load_state_dict(tmp)
            break
    print('Model received successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100
    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')

    for epoch in range(int(args.epochs)):
        model.train()
        # Decay the learning rate at the specified epoch
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))
        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            it_start = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            it_comp_end = time.time()
            # noinspection PyBroadException
            try:
                if delta_ws:
                    queue.put({
                        rank:
                        [loss.data.numpy(),
                         np.array(args.train_bsz), False]
                    })

                for delta in delta_ws:
                    dist.send(tensor=delta, dst=0)

                for idx, param in enumerate(model.parameters()):
                    tmp_tensor = torch.zeros_like(param.data)
                    dist.recv(tensor=tmp_tensor, src=0)
                    param.data = tmp_tensor

                #print('Rank {}, Epoch {}, Batch {}/{}, Loss:{}'
                #     .format(rank, epoch, batch_idx, len(train_data), loss.data[0]))
            except Exception as e:
                print(str(e))
                print('Should Stop: {}!'.format(stop_flag.value))
                break

            it_comm_end = time.time()
            it_duration = it_comm_end - it_start
            it_comp_duration = it_comp_end - it_start
            it_comm_duration = it_comm_end - it_comp_end
            time_logs.write(
                str(it_duration) + "\t" + str(it_comp_duration) + "\t" +
                str(it_comm_duration) + "\n")
            time_logs.flush()

        # test the model
        print("test Model:", epoch)
        # test_model(rank, model, test_data, criterion=criterion)
        if stop_flag.value:
            break
    queue.put({rank: [[], [], True]})
    time_logs.close()
    print("Worker {} has completed epoch {}!".format(args.this_rank, epoch))
Example No. 2
def run(rank, workers, model, save_path, train_data, test_data):
    # Get the initial model from the server
    _group = [w for w in workers] + [0]  # the group is the workers plus the server (rank 0)
    group = dist.new_group(_group)

    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model received successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100

    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specified epochs
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            batch_comp_time = time.time()
            # Synchronization
            # send the epoch training loss first
            dist.gather(loss.data, dst=0, group=group)
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=delta_ws[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()

            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time

            logs = torch.tensor([
                0.0, batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch,
                                                  loss.data.item()))

        e_time = time.time()
        #epoch_train_loss /= len(train_data)
        #epoch_train_loss = format(epoch_train_loss, '.4f')
        # test the model
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx
        batch_comp_interval /= batch_idx
        batch_comm_interval /= batch_idx
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        #dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
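Example No. 2 synchronizes through paired dist.gather/dist.scatter calls, which implies a matching loop on rank 0 that is not shown here. The sketch below is an assumption of what that server side could look like (server_step and group_size are hypothetical names): gather one loss and one delta per group member, average the workers' deltas, apply them to the global model, and scatter each refreshed parameter back.

import torch
import torch.distributed as dist

def server_step(model, group, group_size):
    # gather one training loss per member; the server contributes a dummy 0-dim tensor
    losses = [torch.tensor(0.0) for _ in range(group_size)]
    dist.gather(torch.tensor(0.0), gather_list=losses, dst=0, group=group)

    for param in model.parameters():
        # gather one delta per member; index 0 is the server's own dummy contribution
        deltas = [torch.zeros_like(param.data) for _ in range(group_size)]
        dist.gather(torch.zeros_like(param.data), gather_list=deltas, dst=0, group=group)
        param.data -= sum(deltas[1:]) / (group_size - 1)
        # scatter the refreshed parameter back to every member
        dist.scatter(param.data.clone(),
                     scatter_list=[param.data.clone() for _ in range(group_size)],
                     src=0,
                     group=group)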
Example No. 3
def run(rank, workers, model, save_path, train_data, test_data, global_lr):
    # Get the initial model from the server
    print(workers)

    _group = [w for w in workers] + [0]  # the group is the workers plus the server (rank 0)
    group = dist.new_group(_group)

    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model received successfully!')

    temp_lr = global_lr.get()

    # every model family uses the learning rate supplied by the server
    optimizer = MySGD(model.parameters(), lr=temp_lr)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print('Begin!')

    # the parameters that will be transferred to the thread
    model_cache = [p.data + 0.0 for p in model.parameters()]
    global_update = [torch.zeros_like(p) for p in model.parameters()]
    local_update = [torch.zeros_like(p) for p in model.parameters()]
    it_count = Value(c_float, 0.)  # number of local updates accumulated in the current iteration
    data_lock = Lock()
    update_lock = Queue()
    update_lock.put(1)

    loss_t = torch.tensor(0.0)
    receive_end = Value(c_bool, False)
    batch_communication_interval = Value(c_float, 0.0)
    stale_in_iteration = Value(c_float, 0.)

    sender_td = Thread(target=sender,
                       args=(
                           model_cache,
                           global_update,
                           local_update,
                           it_count,
                           loss_t,
                           update_lock,
                           data_lock,
                           group,
                           receive_end,
                           batch_communication_interval,
                           stale_in_iteration,
                       ),
                       daemon=True)
    sender_td.start()

    time_logs = open("./record" + str(rank), 'w')
    osp_logs = open("./log" + str(rank), 'w')
    Stale_Threshold = args.stale_threshold
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specified epoch
        # the new rate comes from the server, because the local worker and the server update at different speeds
        if not global_lr.empty():
            g_lr = global_lr.get()
            if args.model == 'AlexNet':
                for param_group in optimizer.param_groups:
                    param_group['lr'] = g_lr
                    print('LR Decreased! Now: {}'.format(param_group['lr']))

        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            optimizer.step()

            # Aggregate local update
            data_lock.acquire()
            # aggregate loss
            loss_t.data += loss.data
            it_count.value += 1
            for g_idx, update in enumerate(local_update):
                update.data += delta_ws[g_idx].data
            data_lock.release()

            batch_computation_time = time.time()

            # Unblock the sender thread once the local update holds at least one gradient
            if it_count.value == 1:
                update_lock.put(1)
            # Busy-wait until the sender drains pending updates below the staleness threshold
            while it_count.value >= Stale_Threshold:
                pass

            if receive_end.value:
                receive_end.value = False
                for idx, param in enumerate(model.parameters()):
                    param.data = model_cache[idx]  # without local update
                    # param.data = model_cache[idx] - global_update[idx] # with local update

            batch_end_time = time.time()
            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_computation_time - batch_start_time
            osp_logs.write(
                str(batch_end_time - batch_start_time) + "\t" +
                str(batch_computation_time - batch_start_time) + "\n")
            osp_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch,
                                                  loss.data.item()))

        e_time = time.time()
        # test the model after training
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx
        batch_comp_interval /= batch_idx
        logs = torch.tensor([
            acc, batch_interval, batch_comp_interval,
            batch_communication_interval.value, stale_in_iteration.value
        ])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        # dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
    sender_td.join()
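The sender thread launched in Example No. 3 is not defined in this listing. Judging only from the arguments passed to it, it is assumed to drain the accumulated local_update under data_lock, push it to the server, pull the fresh global model into model_cache, and signal the training loop through receive_end; the sketch below illustrates that assumed role and is not the source implementation.

import time
import torch
import torch.distributed as dist

def sender(model_cache, global_update, local_update, it_count, loss_t,
           update_lock, data_lock, group, receive_end,
           batch_communication_interval, stale_in_iteration):
    # global_update is unused in this sketch; it is kept only to match the caller's signature
    while True:
        update_lock.get()           # block until the worker has produced at least one gradient
        comm_start = time.time()

        data_lock.acquire()
        staleness = it_count.value  # how many local steps this push aggregates
        it_count.value = 0.0
        loss_snapshot = loss_t.clone()
        loss_t.data.zero_()
        pending = [u.clone() for u in local_update]
        for u in local_update:
            u.data.zero_()
        data_lock.release()

        # push the aggregated local update, then pull the new global model
        dist.gather(loss_snapshot, dst=0, group=group)
        for idx, u in enumerate(pending):
            dist.gather(tensor=u, dst=0, group=group)
            recv = torch.zeros_like(u)
            dist.scatter(tensor=recv, src=0, group=group)
            model_cache[idx].data = recv
        receive_end.value = True

        batch_communication_interval.value = time.time() - comm_start
        stale_in_iteration.value = staleness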
Example No. 4
def run(rank, workers, model, save_path, train_data, test_data):
    # Get the initial model parameters sent by the parameter server
    _group = [w for w in workers] + [0]  # the group is the workers plus the server (rank 0)
    group = dist.new_group(_group)

    param_num = 0
    for p in model.parameters():
        tmp_p = torch.zeros_like(p)
        param_num += torch.numel(tmp_p)
        dist.scatter(tensor=tmp_p, src=0, group=group)
        p.data = tmp_p
    print('Model received successfully!')

    compression_num = int(param_num * args.ratio)
    compression_num = compression_num if compression_num > 0 else 1
    dist.gather(torch.tensor([compression_num / param_num]),
                dst=0,
                group=group)

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        learning_rate = 0.1
    else:
        learning_rate = args.lr
    optimizer = MySGD(model.parameters(), lr=learning_rate)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    elif args.model in ['Abalone', 'Bodyfat', 'Housing']:
        criterion = torch.nn.MSELoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    elif args.model in [
            'LROnMnist', 'LROnCifar10', 'LROnCifar100', 'Abalone', 'Bodyfat',
            'Housing'
    ]:
        decay_period = 1000000  # learning rate is constant for LR (convex) models
    else:
        decay_period = 100

    print('Begin!')

    global_clock = 0
    g_remain = [torch.zeros_like(param.data) for param in model.parameters()]
    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(args.epochs):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        s_time = time.time()
        model.train()

        # Decay the learning rate at the specified epoch
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))

        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            g_remain, g_large_change = get_upload(g_remain, delta_ws,
                                                  args.ratio,
                                                  args.isCompensate)

            batch_comp_time = time.time()
            # Synchronization
            # send the epoch training loss first
            dist.gather(loss.data, dst=0, group=group)
            for idx, param in enumerate(model.parameters()):
                dist.gather(tensor=g_large_change[idx], dst=0, group=group)
                recv = torch.zeros_like(delta_ws[idx])
                dist.scatter(tensor=recv, src=0, group=group)
                param.data = recv

            epoch_train_loss += loss.data.item()
            batch_end_time = time.time()

            batch_interval += batch_end_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_end_time - batch_comp_time

            logs = torch.tensor([
                0.0, batch_interval / (batch_idx + 1),
                batch_comp_interval / (batch_idx + 1),
                batch_comm_interval / (batch_idx + 1)
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        print('Rank {}, Epoch {}, Loss:{}'.format(rank, epoch,
                                                  loss.data.item()))

        e_time = time.time()
        #epoch_train_loss /= len(train_data)
        #epoch_train_loss = format(epoch_train_loss, '.4f')
        # test the model after training
        #test_loss, acc = test_model(rank, model, test_data, criterion=criterion)
        acc = 0.0
        batch_interval /= batch_idx + 1
        batch_comp_interval /= batch_idx + 1
        batch_comm_interval /= batch_idx + 1
        logs = torch.tensor(
            [acc, batch_interval, batch_comp_interval, batch_comm_interval])
        time_logs.write(str(logs) + '\n')
        time_logs.flush()
        #dist.gather(tensor=logs, dst = 0, group = group)
    time_logs.close()
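get_upload in Example No. 4 is not shown. Given args.ratio, args.isCompensate and the g_remain residual buffer, it is assumed to implement top-k gradient sparsification with error compensation; the code below is only an illustrative sketch of that technique, not the source implementation.

import torch

def get_upload(g_remain, delta_ws, ratio, is_compensate):
    """Keep only the largest `ratio` fraction of each delta; carry the rest as a residual."""
    new_remain, g_large_change = [], []
    for residual, delta in zip(g_remain, delta_ws):
        # accumulate the new delta on top of the unsent residual when compensation is on
        total = residual + delta if is_compensate else delta.clone()
        k = max(1, int(total.numel() * ratio))
        flat = total.view(-1)
        # pick the k largest-magnitude entries to upload
        _, idx = torch.topk(flat.abs(), k)
        upload = torch.zeros_like(flat)
        upload[idx] = flat[idx]
        g_large_change.append(upload.view_as(total))
        # whatever was not uploaded becomes the next residual
        new_remain.append(total - upload.view_as(total))
    return new_remain, g_large_change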
Example No. 5
def run(rank, model, train_data, test_data, queue, param_q, stop_flag):
    # Get the initial model parameters sent by the parameter server
    while True:
        if not param_q.empty():
            param_dict = param_q.get()
            tmp = OrderedDict(
                map(lambda item: (item[0], torch.from_numpy(item[1])),
                    param_dict.items()))
            model.load_state_dict(tmp)
            break
    print('Model received successfully!')

    if args.model in ['MnistCNN', 'AlexNet', 'ResNet18OnCifar10']:
        optimizer = MySGD(model.parameters(), lr=0.1)
    else:
        optimizer = MySGD(model.parameters(), lr=0.01)

    if args.model in ['MnistCNN', 'AlexNet']:
        criterion = torch.nn.NLLLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.model in ['AlexNet', 'ResNet18OnCifar10']:
        decay_period = 50
    else:
        decay_period = 100
    print('Begin!')

    time_logs = open("./record" + str(rank), 'w')
    for epoch in range(int(args.epochs)):
        batch_interval = 0.0
        batch_comp_interval = 0.0
        batch_comm_interval = 0.0
        batch_push_interval = 0.0
        batch_pull_interval = 0.0
        model.train()
        # Decay the learning rate at the specified epoch
        #if args.model == 'AlexNet':
        if (epoch + 1) % decay_period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
                print('LR Decreased! Now: {}'.format(param_group['lr']))
        epoch_train_loss = 0
        for batch_idx, (data, target) in enumerate(train_data):
            batch_start_time = time.time()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            delta_ws = optimizer.get_delta_w()

            batch_comp_time = time.time()
            # noinspection PyBroadException
            try:  # catch exceptions raised when the PS process stops
                if delta_ws:
                    queue.put({
                        rank:
                        [loss.data.numpy(),
                         np.array(args.train_bsz), False]
                    })
                for delta in delta_ws:
                    dist.send(tensor=delta, dst=0)

                batch_push_time = time.time()

                for idx, param in enumerate(model.parameters()):
                    tmp_tensor = torch.zeros_like(param.data)
                    dist.recv(tensor=tmp_tensor, src=0)
                    param.data = tmp_tensor

                batch_tmp_time = time.time()
                batch_pull_time = time.time()
                #print('Rank {}, Epoch {}, Batch {}/{}, Loss:{}'
                #     .format(rank, epoch, batch_idx, len(train_data), loss.data[0]))
            except Exception as e:
                print(str(e))
                print('Should Stop: {}!'.format(stop_flag.value))
                break

            batch_interval += batch_pull_time - batch_start_time
            batch_comp_interval += batch_comp_time - batch_start_time
            batch_comm_interval += batch_pull_time - batch_comp_time
            batch_push_interval += batch_push_time - batch_comp_time
            batch_pull_interval += batch_pull_time - batch_push_time
            b_interval = batch_interval / (batch_idx + 1)
            b_comp_interval = batch_comp_interval / (batch_idx + 1)
            b_comm_interval = batch_comm_interval / (batch_idx + 1)
            b_push_interval = batch_push_interval / (batch_idx + 1)
            b_pull_interval = batch_pull_interval / (batch_idx + 1)
            logs = torch.tensor([
                0.0, b_interval, b_comp_interval, b_comm_interval,
                b_push_interval, b_pull_interval,
                batch_pull_time - batch_tmp_time
            ])
            time_logs.write(str(logs) + '\n')
            time_logs.flush()

        batch_interval /= batch_idx
        batch_comp_interval /= batch_idx
        batch_comm_interval /= batch_idx
        batch_push_interval /= batch_idx
        batch_pull_interval /= batch_idx
        logs = torch.tensor([
            0.0, batch_interval, batch_comp_interval, batch_comm_interval,
            batch_push_interval, batch_pull_interval
        ])
        time_logs.write(str(epoch) + '\t' + str(logs) + '\n')
        time_logs.flush()
        # test the model after training
        print("test Model:", epoch)
        # test_model(rank, model, test_data, criterion=criterion)
        if stop_flag.value:
            break
    queue.put({rank: [[], [], True]})
    time_logs.close()
    print("Worker {} has completed epoch {}!".format(args.this_rank, epoch))