Exemple #1
0
    def run(rank, size):
        """Exchange ten (feature, label) tensor pairs between two processes.

        The process whose rank equals ``task.to_rank`` is the sender: it
        initialises ``disturb`` and pushes the pairs with
        ``task.fork_detach``, waiting on each request.  Any other rank acts
        as the receiver: it joins the distributed environment, pulls each
        pair with ``utils.fork_recv`` and asserts on the received values and
        dtypes.

        Args:
            rank: rank of the current process.
            size: size forwarded to ``torch.zeros`` to build every tensor.
        """
        n_rounds = 10
        port = 12123
        wire_dtypes = (torch.float32, torch.uint8)

        if rank != task.to_rank:
            # Some server task running on another node
            utils.init_distributedenv(1, port=port)

            for _ in range(n_rounds):
                feat, label, _ = utils.fork_recv(rank=0, dtype=wire_dtypes)

                # Features are expected to be all zeros, labels all ones.
                assert torch.all(
                    torch.eq(feat, torch.zeros(size, dtype=torch.float32)))
                assert torch.all(
                    torch.eq(label, torch.zeros(size, dtype=torch.uint8) + 1))

                # The dtypes requested from fork_recv must be honoured.
                assert label.dtype == torch.uint8
                assert feat.dtype == torch.float32
            return

        # process disturb-ed
        disturb.init(port=port)
        for _ in range(n_rounds):
            request = task.fork_detach(
                torch.zeros(size),
                torch.zeros(size, dtype=torch.uint8) + 1,
                dtype=wire_dtypes,
            )
            request.wait()
Exemple #2
0
def main():
    """Evaluate a trained network on batches received from rank 0.

    Steps:
        1. Parse the command line and pick the device (CUDA when available,
           otherwise CPU).
        2. Execute the user configuration module given by ``args.config``;
           the argument parser is exposed to it as ``config.argsparser``.
        3. Load the weights stored under the ``"model"`` key of
           ``args.snapshot`` into ``config.net``.
        4. Join the distributed environment and loop on ``utils.fork_recv``
           until a stop meta-data message arrives, accumulating the mapped
           labels and the arg-max predictions.
        5. Hand both tensors to ``utils.display_evaluation_result``.

    Raises:
        AssertionError: if ``--label`` and ``--label-name`` differ in size.
    """
    parser = get_parser()
    args, _ = parser.parse_known_args()

    device = torch.device(
        f"cuda:{args.gpu_device}" if torch.cuda.is_available() else "cpu"
    )

    assert len(args.label) == len(
        args.label_name
    ), "The size of '--label' must be the same as the size of '--label-name'"

    # load the conf
    spec = importlib.util.spec_from_file_location("config", args.config)
    config = importlib.util.module_from_spec(spec)
    config.argsparser = (
        parser  # Share the configargparse.ArgumentParser with the user defined module
    )
    spec.loader.exec_module(config)

    # create the net
    net = config.net.to(device)

    # load the snapshot
    net.load_state_dict(torch.load(args.snapshot, map_location=device)["model"])

    # init the rank of this task
    utils.init_distributedenv(
        rank=args.task_rank, world_size=args.world_size, ip=args.master_ip
    )

    net.eval()

    # Per-batch results are collected in lists and concatenated once after
    # the loop, instead of re-copying the whole history on every batch.
    # Each list is seeded with an empty LongTensor so the final tensors are
    # defined (and empty) when no batch was received.
    label_chunks = [torch.LongTensor([])]
    pred_chunks = [torch.LongTensor([])]

    with torch.no_grad():
        while True:
            features, y_mapper, is_meta_data = utils.fork_recv(
                rank=0, dtype=(torch.float32, torch.long)
            )

            if is_meta_data:
                # For meta-data messages the second value is the meta-data
                # itself, not a label tensor.
                meta_data = y_mapper

                if const.should_stop(meta_data):
                    break

                # A meta-data message is a control signal, not a batch: it
                # must not be fed to the network or counted in the results.
                continue

            target = config.mapper(y_mapper)
            y_pred = net(features.to(device))

            # Index of the highest score along dim 1 is the predicted class.
            _, predicted = torch.max(y_pred.data, dim=1)

            label_chunks.append(target.cpu())
            pred_chunks.append(predicted.cpu())

    total_labels = torch.cat(label_chunks)
    total_pred = torch.cat(pred_chunks)

    utils.display_evaluation_result(args, total_labels, total_pred)
Exemple #3
0
    def run(rank, size):
        """Send ten zero tensors from one process and check them on the other.

        The process whose rank equals ``task.to_rank`` initialises
        ``disturb`` and sends the tensors with ``task.isend``, waiting on
        each request.  Any other rank joins the distributed environment,
        receives with ``utils.recv`` and asserts every element is zero.

        Args:
            rank: rank of the current process.
            size: size forwarded to ``torch.zeros`` to build every tensor.
        """
        n_rounds = 10
        port = 12121

        if rank != task.to_rank:
            # Some server task running on another node
            utils.init_distributedenv(1, port=port)

            for _ in range(n_rounds):
                received, _ = utils.recv(rank=0)
                assert torch.all(torch.eq(received, torch.zeros(size)))
            return

        # process disturb-ed
        disturb.init(port=port)
        for _ in range(n_rounds):
            request = task.isend(torch.zeros(size))
            request.wait()
Exemple #4
0
def init(expected_domain_tasks=int(os.getenv("DAMPED_N_DOMAIN", 1)),
         port=29500) -> None:
    """Initialize the damped distributed environment.

    Logs that it is waiting for the domain-task trainers, joins the
    distributed environment through ``utils.init_distributedenv`` (passing
    ``0`` as first argument, presumably this process' rank) and finally
    instantiates ``ManagedMemory``.

    Args:
        expected_domain_tasks (int): The number of expected domain tasks.
            The default is read from the ``DAMPED_N_DOMAIN`` environment
            variable once, when the function is defined, and falls back
            to 1.
        port (int): port on which the tensors will be exchanged.
    """
    logger.info("Waiting for domain-task trainer connection")

    # One slot for this process plus one per expected domain task.
    total_processes = expected_domain_tasks + 1
    utils.init_distributedenv(0, world_size=total_processes, port=port)

    # init ManagedMemory
    ManagedMemory()
def test_init():
    """Check that a one-process world reports itself as initialized."""
    utils.init_distributedenv(0, world_size=1, port=6223)

    initialized = dist.is_initialized()
    assert initialized
Exemple #6
0
def main():
    """Run the main training function.

    Loads the user configuration module given by ``args.config`` (which
    provides ``net``, ``optimizer``, ``criterion`` and ``mapper``), sets up a
    ``utils.Monitor``, joins the distributed environment and then loops on
    ``utils.fork_recv`` until a stop meta-data message is received.

    Each received message is either:
      * meta-data: control signal that can stop the loop, toggle whether the
        input gradient is sent back, or switch between train and eval mode;
      * a (features, label) batch: used for an optimisation step in train
        mode, or accumulated into the validation metrics in eval mode.
    """
    parser = get_parser()
    args, _ = parser.parse_known_args()

    # Use the requested CUDA device when available, otherwise the CPU.
    device = torch.device(
        f"cuda:{args.gpu_device}" if torch.cuda.is_available() else "cpu")

    # load the conf
    spec = importlib.util.spec_from_file_location("config", args.config)
    config = importlib.util.module_from_spec(spec)
    config.argsparser = (
        parser  # Share the configargparse.ArgumentParser with the user defined module
    )
    spec.loader.exec_module(config)

    # create the net and training optim/criterion
    optimizer = config.optimizer
    net = config.net.to(device)
    criterion = config.criterion

    # keep track of some values while training
    # (both counters are reset every ``args.log_interval`` updates)
    total_correct = 0
    total_target = 0

    # Default the tensorboard directory to a sub-folder of the experiment.
    tensorboard_dir = args.tensorboard_dir
    if args.tensorboard_dir == "":
        tensorboard_dir = os.path.join("exp/", args.exp_path, "tensorboard")

    monitor = utils.Monitor(
        tensorboard_dir=tensorboard_dir,
        save_path=os.path.join("exp/", args.exp_path),
        exp_id=net.__class__.__name__,
        model=net,
        eval_metrics="acc, loss",
        early_metric="acc",
        save_best_metrics=True,
        n_checkpoints=args.n_checkpoint,
    )
    monitor.set_optimizer(optimizer)
    monitor.save_model_summary()

    if args.resume:
        print("resumed from %s" % args.resume, flush=True)
        # load last checkpoint
        monitor.load_checkpoint(args.resume, args.load_optimizer)

    # Eval related
    eval_mode = False
    # labels / predictions accumulated over the current evaluation pass
    total_labels = torch.LongTensor([])
    total_pred = torch.LongTensor([])
    # running sum of the eval losses and number of eval batches seen
    loss_batches = 0
    loss_batches_count = 0

    # indicate if damped.disturb-ed toolkit wants the gradient from the DomainTask
    send_backward_grad = False

    # init the rank of this task
    utils.init_distributedenv(rank=args.task_rank,
                              world_size=args.world_size,
                              ip=args.master_ip)

    print("Training started on %s" % time.strftime("%d-%m-%Y %H:%M"),
          flush=True)

    # TODO(pchampio) refactor this training loop into sub-functions
    while True:
        features, y_mapper, is_meta_data = utils.fork_recv(
            rank=0, dtype=(torch.float32, torch.long))

        if is_meta_data:
            # For meta-data messages the second value carries the meta-data.
            meta_data = y_mapper

            if const.should_stop(meta_data):
                break

            if const.is_no_wait_backward(meta_data):
                print("Switch to NOT sending backward gradient", flush=True)
                send_backward_grad = False

            if const.is_wait_backward(meta_data):
                print("Switch to sending backward gradient", flush=True)
                send_backward_grad = True

            # Remember the previous mode to detect train <-> eval transitions.
            last_eval = eval_mode
            eval_mode = const.is_eval(meta_data)

            # detect changes from train to eval
            if eval_mode and not last_eval:
                print("Running evaluation on dev..", flush=True)
                net.eval()

            # detect changes from eval to train
            if not eval_mode and last_eval:
                net.train()
                # display validation metrics
                accuracy = (accuracy_score(total_labels.flatten().numpy(),
                                           total_pred.flatten().numpy()) * 100)

                # Mean eval loss; stays 0 when no eval batch was received.
                loss = 0
                if loss_batches_count != 0:
                    loss = loss_batches / loss_batches_count

                monitor.update_dev_scores([
                    utils.Metric("acc", accuracy),
                    utils.Metric(
                        "loss",
                        loss,
                        higher_better=False,
                    ),
                ])

                monitor.save_models()
                # presumably the validation-pass counter -- confirm in Monitor
                monitor.vctr += 1
                # clear for next eval
                total_labels = torch.LongTensor([])
                total_pred = torch.LongTensor([])
                loss_batches = 0
                loss_batches_count = 0

            # When meta_data is shared, no features/label are sent
            continue

        # Map the received label tensor with the user-defined mapper.
        target = config.mapper(y_mapper)

        # Gradient tracking is enabled on the input so that ``input.grad``
        # can be sent back after ``loss.backward()``.
        # NOTE(review): ``input`` shadows the Python builtin of the same name.
        input = features.to(device)
        input.requires_grad = True

        # Eval
        if eval_mode:
            # NOTE(review): no ``torch.no_grad()`` is visible around this
            # forward pass, so it runs with gradient tracking enabled.
            y_pred = net(input)

            # send back the gradient if needed
            if send_backward_grad:
                # backward will not be applied (eval), send 0 grad
                damped.disturb.DomainTask._isend(
                    0, torch.zeros(*input.size())).wait()

            # Index of the highest score along dim 1 is the predicted class.
            _, predicted = torch.max(y_pred.data, dim=1)

            total_labels = torch.cat((total_labels, target.cpu()))
            total_pred = torch.cat((total_pred, predicted.cpu()))

            loss = criterion(y_pred, target.to(device))

            loss_batches += loss.cpu().detach().numpy()
            loss_batches_count += 1
            continue

        # Train: one optimisation step on the received batch.
        optimizer.zero_grad()
        y_pred = net(input)

        if torch.any(torch.isnan(y_pred)):
            print(features)
            print("ERROR: ignoring this batch, prediction is NaN")
            # NOTE(review): when send_backward_grad is True, nothing is sent
            # back for this batch -- confirm the sender does not block
            # waiting for a gradient.
            continue

        loss = criterion(y_pred, target.to(device))
        loss.backward()

        # send back the gradient if asked
        if send_backward_grad:
            damped.disturb.DomainTask._isend(0, input.grad.data.cpu()).wait()

        optimizer.step()

        # Number of samples whose arg-max prediction matches the target.
        correct = (torch.argmax(y_pred.data,
                                1) == target.to(device)).sum().item()
        total_correct += correct
        total_target += target.size(0)

        monitor.train_loss.append(loss.item())

        # presumably the update (training step) counter -- confirm in Monitor
        monitor.uctr += 1
        if monitor.uctr % args.log_interval == 0:
            # Accuracy over the batches seen since the last log.
            accuracy = (total_correct / total_target) * 100
            print(
                "Train batch [{}]\tLoss: {:.6f}\tTrain Accuracy: {:.3f}".
                format(
                    monitor.uctr,
                    loss.item(),
                    accuracy,
                ),
                flush=True,
            )
            monitor.tensorboard_writter.add_scalar("/train/accuracy", accuracy,
                                                   monitor.uctr)
            monitor.tensorboard_writter.add_scalar("/train/loss", loss.item(),
                                                   monitor.uctr)
            # Start a fresh accuracy window for the next log interval.
            total_correct = 0
            total_target = 0

    print("Training finished on %s" % time.strftime("%d-%m-%Y %H:%M"))