Exemple #1
0
def main():
    """Set up and run pipelined training of a text classifier on IMDB.

    Steps performed, in order:

    1. Parse command-line arguments and configure logging (INFO on the last
       rank, WARNING on every other rank).
    2. Import the model module named by ``args.module`` and run every stage
       except the final loss stage once on zero tensors to record the shape
       and dtype of each intermediate tensor.
    3. Read the optional JSON configuration file and build a
       ``runtime.StageRuntime``.
    4. Optionally resume from a per-stage checkpoint, then build either the
       spectrain optimizer or the weight-stashing SGD optimizer.
    5. Load IMDB through keras (10000 most frequent words, sequences
       padded/truncated to 200 tokens) and run the train/validate loop,
       saving a checkpoint after each training epoch when requested.

    Uses and updates the module-level globals ``args`` and ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Only the last rank logs at INFO level; other ranks only emit warnings.
    if int(args.rank) == int(args.world_size) - 1:
        log_level = logging.INFO
    else:
        log_level = logging.WARNING
    logging.basicConfig(
        level=log_level,
        format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s")

    logging.info(f'Find median: {args.find_median}')
    logging.warning(f'rank:{args.rank}, local_rank:{args.local_rank}')

    torch.cuda.set_device(args.local_rank)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)

    # determine shapes of all tensors in passed-in model
    if args.arch == 'inception_v3':
        input_size = [args.batch_size, 3, 299, 299]
    else:
        # Token-id sequences of length 200 (matches the padding applied to
        # the IMDB data below).
        input_size = [args.batch_size, 200]
    training_tensor_shapes = {
        "input0": input_size,
        "target": [args.batch_size]
    }
    dtypes = {"input0": torch.int64, "target": torch.int64}
    inputs_module_destinations = {"input": 0}
    target_tensor_names = {"target"}
    for i, (stage, inputs,
            outputs) in enumerate(model[:-1]):  # Skip last layer (loss).
        # The first stage is fed int64 dummy inputs; every later stage is fed
        # float32 dummy inputs.
        dummy_dtype = torch.int64 if i == 0 else torch.float32
        input_tensors = [
            torch.zeros(tuple(training_tensor_shapes[input_name]),
                        dtype=dummy_dtype) for input_name in inputs
        ]
        # No autograd bookkeeping is needed for this shape-probing pass.
        with torch.no_grad():
            logging.debug(
                f'[{i}] input tensor shape: {input_tensors[0].shape}')
            output_tensors = stage(*tuple(input_tensors))
        # A stage may return a single tensor or a tuple of tensors.
        if not isinstance(output_tensors, tuple):
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    # Evaluation shapes are the training shapes with the leading (batch)
    # dimension replaced by the evaluation batch size.
    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple([args.eval_batch_size] +
                                        training_tensor_shapes[key][1:])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        # Use a context manager so the file handle is closed deterministically.
        with open(args.config_path, 'r') as config_file:
            json_config_file = json.load(config_file)
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        stage_to_rank_map = json_config_file.get("stage_to_rank_map", None)
        if stage_to_rank_map is not None:
            # JSON object keys are always strings; stage ids are used as ints.
            stage_to_rank_map = {
                int(k): v
                for (k, v) in stage_to_rank_map.items()
            }
        configuration_maps['stage_to_rank_map'] = stage_to_rank_map
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    r = runtime.StageRuntime(
        model=model,
        distributed_backend=args.distributed_backend,
        fp16=args.fp16,
        loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr,
        rank=args.rank,
        local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.IMAGE_CLASSIFICATION,
        port=args.port,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        logging.info("=> loading checkpoint '{}'".format(checkpoint_file_path))
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        logging.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    if args.spectrain:
        # The log directory name is suffixed so runs with different settings
        # do not share a directory.
        if args.log_dir is not None:
            args.log_dir += '_spectrain_v1'
        logging.info('Using spectrain_v1')
        if args.square:
            if args.log_dir is not None:
                args.log_dir += '_square'
            logging.info('s = version difference ^ 2')
        else:
            logging.info('s = version difference')
        # The spectrain optimizer is deliberately built with a single weight
        # version rather than the computed num_versions.
        optimizer = sgd.SGDWithSpectrainCHC(
            r.modules(),
            r.master_parameters,
            r.model_parameters,
            args.loss_scale,
            num_versions=1,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    else:
        logging.info('Not using spectrain')
        optimizer = sgd.SGDWithWeightStashing(
            r.modules(),
            r.master_parameters,
            r.model_parameters,
            args.loss_scale,
            num_versions=num_versions,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)

    logging.info(f'log_dir: {args.log_dir}')
    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    # Data loading code. keras is imported locally because it is only needed
    # to fetch and pad the IMDB dataset.
    from keras.preprocessing.sequence import pad_sequences
    from keras.datasets import imdb
    (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=10000)
    x_train = pad_sequences(x_train,
                            maxlen=200,
                            padding="post",
                            truncating="post")
    x_test = pad_sequences(x_test,
                           maxlen=200,
                           padding="post",
                           truncating="post")
    print(x_train.shape, x_test.shape)

    train_dataset = TensorDataset(torch.LongTensor(x_train),
                                  torch.LongTensor(y_train))
    val_dataset = TensorDataset(torch.LongTensor(x_test),
                                torch.LongTensor(y_test))

    global writer
    if dist.get_rank() == dist.get_world_size() - 1:
        # Placeholder: summary-writer creation is currently disabled.
        # writer = SummaryWriter(args.log_dir)
        pass

    # Shard the data across the replicas of the first stage when that stage
    # is replicated on more than one rank.
    distributed_sampler = False
    train_sampler = None
    val_sampler = None
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset,
                num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            distributed_sampler = True

    # shuffle and sampler are mutually exclusive in DataLoader, so shuffling
    # is only requested when no sampler is in use.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.eval_batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             drop_last=True)

    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            train_sampler.set_epoch(epoch)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch)
            # NOTE(review): if stage ids are 0-indexed this condition is
            # always true and prec1 is always reset to 0 -- confirm the stage
            # indexing convention before relying on best_prec1.
            if r.stage != r.num_stages:
                prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': r.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                    }, args.checkpoint_dir, r.stage)
Exemple #2
0
def main():
    """Set up and run pipelined training for a translation task.

    Steps performed, in order:

    1. Parse command-line arguments, set up the task via ``tasks.setup_task``
       and load the ``train`` and ``valid`` dataset splits.
    2. Import the model module named by ``args.module`` and run every stage
       except the final loss stage once on dummy tensors (on the GPU) to
       record the shape and dtype of each intermediate tensor.
    3. Read the optional JSON configuration file and build a
       ``runtime.StageRuntime``.
    4. Optionally resume from a per-stage checkpoint, build the Adam
       optimizer and its learning-rate scheduler.
    5. Run the training loop on dummy batches from the training dataset,
       saving a checkpoint after each training epoch when requested.

    Uses and updates the module-level globals ``args`` and ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    args.data = args.data_dir

    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.local_rank}"

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build criterion
    criterion = task.build_criterion(args)

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)

    # A dummy batch supplies the shapes and dtypes of the model inputs and
    # of the target.
    max_positions = (args.max_source_positions, args.max_target_positions)
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions)
    net_input = dummy_batch['net_input']
    input0 = net_input['src_tokens']
    input1 = net_input['prev_output_tokens']
    target = dummy_batch['target']

    training_tensor_shapes = {
        "input0": list(input0.size()),
        "input1": list(input1.size()),
        "target": list(target.size()),
        "ntokens": [1]
    }
    dtypes = {
        "input0": input0.dtype,
        "input1": input1.dtype,
        "target": target.dtype,
        "ntokens": torch.float32
    }
    inputs_module_destinations = {"input0": 0, "input1": 0}
    target_tensor_names = {"target", "ntokens"}
    for module_id, (stage, inputs, outputs) in enumerate(
            model[:-1]):  # Skip last layer (loss).
        input_tensors = []
        for module_input in inputs:
            # Record which module consumes each model-level input; a later
            # module consuming the same input overwrites the earlier entry.
            if module_input in inputs_module_destinations:
                inputs_module_destinations[module_input] = module_id

            input_tensor = torch.ones(tuple(
                training_tensor_shapes[module_input]),
                                      dtype=dtypes[module_input]).cuda()
            input_tensors.append(input_tensor)
        stage.cuda()
        # PyTorch should not maintain metadata for a backward pass on
        # synthetic inputs. Without the following line, the runtime is
        # as much as 1.5x slower in a full DP configuration.
        with torch.no_grad():
            output_tensors = stage(*tuple(input_tensors))
        # A stage may return a single tensor or a tuple of tensors.
        if not isinstance(output_tensors, tuple):
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    # Evaluation uses the same tensor shapes as training here.
    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple(training_tensor_shapes[key])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        # Use a context manager so the file handle is closed deterministically.
        with open(args.config_path, 'r') as config_file:
            json_config_file = json.load(config_file)
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        stage_to_rank_map = json_config_file.get("stage_to_rank_map", None)
        if stage_to_rank_map is not None:
            # JSON object keys are always strings; stage ids are used as ints.
            stage_to_rank_map = {
                int(k): v
                for (k, v) in stage_to_rank_map.items()
            }
        configuration_maps['stage_to_rank_map'] = stage_to_rank_map
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    r = runtime.StageRuntime(
        model=model,
        distributed_backend=args.distributed_backend,
        fp16=args.fp16,
        loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr,
        rank=args.rank,
        local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.TRANSLATION,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    # NOTE(review): num_versions is computed but not consumed by the
    # optimizers built below -- confirm whether it is still needed.
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = os.path.join(
            args.checkpoint_dir,
            f"checkpoint.{r.stage}.pth.tar.epoch.{args.start_epoch}")
        assert os.path.isfile(checkpoint_file_path)
        print("=> loading checkpoint '{}'".format(checkpoint_file_path))
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    # TODO: make this configurable by args
    use_adam_optimizer = True
    if use_adam_optimizer:
        optimizer = adam.Adam(r.master_parameters,
                              lr=args.lr,
                              betas=(0.9, 0.98),
                              weight_decay=args.weight_decay)
    else:
        optimizer = sgd.SGD(r.master_parameters,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    scheduler = lr_scheduler.build_lr_scheduler(args, optimizer)

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    def epoch_itr():
        """Return a fresh dummy training batch (stands in for a real iterator)."""
        return task.dataset('train').get_dummy_batch(args.max_tokens,
                                                     max_positions)

    distributed_sampler = False
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            distributed_sampler = True

    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            # NOTE(review): train_loader is not defined in this function;
            # unless it exists as a module-level global this raises
            # NameError when the first stage is replicated -- confirm.
            train_loader.sampler.set_epoch(epoch)

        # train or run forward pass only for one epoch
        if args.forward_only:
            # NOTE(review): val_loader is not defined in this function;
            # unless it exists as a module-level global this raises
            # NameError in forward-only mode -- confirm.
            validate(val_loader, r, epoch)
        else:
            train(epoch_itr, r, optimizer, epoch, scheduler)

            # Validation is disabled here, so the precision is fixed at 0.
            # prec1 = validate(val_loader, r, epoch)
            prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': r.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict()
                    }, args.checkpoint_dir, r.stage, epoch)
Exemple #3
0
def main():
    """Set up and run pipelined training of an image classifier.

    Steps performed, in order:

    1. Parse command-line arguments and select the CUDA device.
    2. Import the model module named by ``args.module`` and run every stage
       except the final loss stage once on zero tensors to record the shape
       and dtype of each intermediate tensor.
    3. Read the optional JSON configuration file and build a
       ``runtime.StageRuntime``.
    4. Optionally resume from a per-stage checkpoint and build the
       weight-stashing SGD optimizer.
    5. Build the training and validation ``ImageFolder`` datasets from
       ``args.data_dir`` (``train`` and ``val`` sub-directories) and run the
       train/validate loop, saving a checkpoint after each training epoch
       when requested.

    Uses and updates the module-level globals ``args`` and ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)

    # determine shapes of all tensors in passed-in model
    if args.arch == 'inception_v3':
        input_size = [args.batch_size, 3, 299, 299]
    else:
        input_size = [args.batch_size, 3, 224, 224]
    training_tensor_shapes = {
        "input0": input_size,
        "target": [args.batch_size]
    }
    dtypes = {"input0": torch.int64, "target": torch.int64}
    inputs_module_destinations = {"input": 0}
    target_tensor_names = {"target"}
    for (stage, inputs, outputs) in model[:-1]:  # Skip last layer (loss).
        input_tensors = [
            torch.zeros(tuple(training_tensor_shapes[input_name]),
                        dtype=torch.float32) for input_name in inputs
        ]
        # No autograd bookkeeping is needed for this shape-probing pass.
        with torch.no_grad():
            output_tensors = stage(*tuple(input_tensors))
        # A stage may return a single tensor or a tuple of tensors.
        if not isinstance(output_tensors, tuple):
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    # Evaluation shapes are the training shapes with the leading (batch)
    # dimension replaced by the evaluation batch size.
    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple([args.eval_batch_size] +
                                        training_tensor_shapes[key][1:])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        # Use a context manager so the file handle is closed deterministically.
        with open(args.config_path, 'r') as config_file:
            json_config_file = json.load(config_file)
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        stage_to_rank_map = json_config_file.get("stage_to_rank_map", None)
        if stage_to_rank_map is not None:
            # JSON object keys are always strings; stage ids are used as ints.
            stage_to_rank_map = {
                int(k): v
                for (k, v) in stage_to_rank_map.items()
            }
        configuration_maps['stage_to_rank_map'] = stage_to_rank_map
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    r = runtime.StageRuntime(
        model=model,
        distributed_backend=args.distributed_backend,
        fp16=args.fp16,
        loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr,
        rank=args.rank,
        local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.IMAGE_CLASSIFICATION,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        print("=> loading checkpoint '{}'".format(checkpoint_file_path))
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    optimizer = sgd.SGDWithWeightStashing(r.modules(),
                                          r.master_parameters,
                                          r.model_parameters,
                                          args.loss_scale,
                                          num_versions=num_versions,
                                          lr=args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.weight_decay,
                                          verbose_freq=args.verbose_frequency,
                                          macrobatch=args.macrobatch)

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data_dir, 'train')
    valdir = os.path.join(args.data_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # The real training dataset is always constructed first because the
    # synthetic replacement is sized from its length.
    if args.arch == 'inception_v3':
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(299),
                transforms.ToTensor(),
                normalize,
            ]))
        if args.synthetic_data:
            train_dataset = SyntheticDataset((3, 299, 299), len(train_dataset))
    else:
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        if args.synthetic_data:
            train_dataset = SyntheticDataset((3, 224, 224), len(train_dataset))

    # NOTE(review): validation images are always cropped to 224x224, even
    # when the architecture is inception_v3 (trained at 299x299) -- confirm
    # this is intended.
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    # Shard the data across the replicas of the first stage when that stage
    # is replicated on more than one rank.
    distributed_sampler = False
    train_sampler = None
    val_sampler = None
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset,
                num_replicas=num_ranks_in_first_stage,
                rank=args.rank)
            distributed_sampler = True

    # shuffle and sampler are mutually exclusive in DataLoader, so shuffling
    # is only requested when no sampler is in use.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.eval_batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             drop_last=True)

    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            train_sampler.set_epoch(epoch)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch)
            # NOTE(review): if stage ids are 0-indexed this condition is
            # always true and prec1 is always reset to 0 -- confirm the stage
            # indexing convention before relying on best_prec1.
            if r.stage != r.num_stages:
                prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': r.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                    }, args.checkpoint_dir, r.stage)
def main():
    """Drive GNMT translation training (or forward-only evaluation) for this rank.

    The function:

    1. parses the command line into the module-level ``args``;
    2. runs every stage of the model once on synthetic inputs to record the
       shape and dtype of each intermediate tensor;
    3. builds the ``runtime.StageRuntime`` and a weight-stashing optimizer,
       optionally restoring both from a per-stage checkpoint;
    4. builds the train/validation loaders and iterates over epochs,
       validating after each training epoch and saving a checkpoint when
       ``args.checkpoint_dir`` is set.

    Side effects: rebinds the globals ``args`` and ``best_prec1``, selects the
    CUDA device ``args.local_rank``, moves the model stages to the GPU, and
    may write checkpoint files through ``save_checkpoint``.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Special case handling for GNMT model
    l2_promote()

    torch.cuda.set_device(args.local_rank)

    # build tokenizer
    tokenizer = Tokenizer(os.path.join(args.data_dir, config.VOCAB_FNAME))

    # define loss function
    criterion = build_gnmt_criterion(vocab_size=tokenizer.vocab_size,
                                     padding_idx=config.PAD,
                                     smoothing=0.1)

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)

    # Shapes/dtypes of the tensors fed from the data loader; entries for the
    # intermediate tensors are added by the synthetic forward pass below.
    input_size = [args.max_length_train, args.batch_size]
    training_tensor_shapes = {
        "input0": input_size,
        "input1": [args.batch_size],
        "input2": input_size,
        "target": [args.max_length_train * args.batch_size],
        "target_length": [args.batch_size]
    }
    dtypes = {
        "input0": torch.int64,
        "input1": torch.int64,
        "input2": torch.int64,
        "target": torch.int64,
        "target_length": torch.int32
    }
    inputs_module_destinations = {"input0": 0, "input1": 0, "input2": 0}
    target_tensor_names = {"target", "target_length"}
    for module_id, (stage, inputs, outputs) in enumerate(
            model[:-1]):  # Skip last layer (loss).
        input_tensors = []
        for module_input in inputs:
            # Remember which module consumes each external input; if several
            # modules consume it, the last one seen wins.
            if module_input in inputs_module_destinations:
                inputs_module_destinations[module_input] = module_id

            input_tensor = torch.ones(tuple(
                training_tensor_shapes[module_input]),
                                      dtype=dtypes[module_input]).cuda()
            input_tensors.append(input_tensor)
        stage.cuda()
        # PyTorch should not maintain metadata for a backward pass on
        # synthetic inputs. Without the following line, the runtime is
        # as much as 1.5x slower in a full DP configuration.
        with torch.no_grad():
            output_tensors = stage(*tuple(input_tensors))
        # A stage may return a single tensor; normalise to a sequence so it
        # can be zipped with the declared output names.
        if type(output_tensors) is not tuple:
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    # Evaluation uses the same shapes as training here; both maps are frozen
    # into tuples before being handed to the runtime.
    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple(training_tensor_shapes[key])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        # Use a context manager so the config file handle is always closed.
        with open(args.config_path, 'r') as config_file:
            json_config_file = json.load(config_file)
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        stage_to_rank_map = json_config_file.get("stage_to_rank_map", None)
        if stage_to_rank_map is not None:
            # JSON object keys are strings, but the map is indexed by integer
            # stage id below (``stage_to_rank_map[0]``), so convert the keys.
            stage_to_rank_map = {
                int(k): v
                for (k, v) in stage_to_rank_map.items()
            }
        # A config without this key leaves the map as None, which the sampler
        # setup further down already treats as a valid value.
        configuration_maps['stage_to_rank_map'] = stage_to_rank_map
        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    r = runtime.StageRuntime(
        model=model,
        distributed_backend=args.distributed_backend,
        fp16=args.fp16,
        loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr,
        rank=args.rank,
        local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.TRANSLATION,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        # Checkpoints are stored per stage: "<resume prefix>.<stage>.pth.tar".
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        print("=> loading checkpoint '{}'".format(checkpoint_file_path))
        # NOTE: torch.load unpickles the file, so only resume from
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    # TODO: make this configurable by args
    # While this flag is hard-coded to True the SGD branch is never taken.
    use_adam_optimizer = True
    if use_adam_optimizer:
        optimizer = adam.AdamWithWeightStashing(
            modules=r.modules(),
            master_parameters=r.master_parameters,
            model_parameters=r.model_parameters,
            loss_scale=args.loss_scale,
            num_versions=num_versions,
            lr=args.lr,
            betas=(0.9, 0.999),
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    else:
        optimizer = sgd.SGDWithWeightStashing(
            modules=r.modules(),
            master_parameters=r.master_parameters,
            model_parameters=r.model_parameters,
            loss_scale=args.loss_scale,
            num_versions=num_versions,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency)

    # The optimizer state can only be restored once the optimizer exists.
    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    train_dataset = LazyParallelDataset(
        src_fname=os.path.join(args.data_dir, config.SRC_TRAIN_FNAME),
        tgt_fname=os.path.join(args.data_dir, config.TGT_TRAIN_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_train,
        max_len=args.max_length_train,
        sort=False,
        max_size=None)

    val_dataset = ParallelDataset(
        src_fname=os.path.join(args.data_dir, config.SRC_VAL_FNAME),
        tgt_fname=os.path.join(args.data_dir, config.TGT_VAL_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_train,
        max_len=args.max_length_train,
        sort=True)

    # The per-epoch sampler reseeding below is only needed when the first
    # stage is replicated across more than one rank.
    distributed_sampler = False
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            distributed_sampler = True

    # TODO: fix random seeds
    train_loader = train_dataset.get_loader(
        batch_size=args.batch_size,
        seeds=range(args.epochs),
        batch_first=False,
        shuffle=True,
        bucketing=not args.no_bucketing,
        num_workers=args.workers,
        world_size=r.num_ranks_in_first_stage,
        rank=r.rank_in_stage if r.stage == 0 else 0)

    val_loader = val_dataset.get_loader(
        batch_size=args.batch_size,
        batch_first=False,
        shuffle=True,
        num_workers=args.workers,
        world_size=r.num_ranks_in_first_stage,
        seeds=range(args.epochs),
        rank=r.rank_in_stage if r.stage == 0 else 0)

    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            train_loader.sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args.epochs, r, args.lr_policy)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch)
            # NOTE(review): if r.stage is zero-based this condition is always
            # true and prec1 is reset on every stage -- confirm whether
            # `r.num_stages - 1` was intended.
            if r.stage != r.num_stages:
                prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            # Write from every rank when the checkpoint directory is flagged
            # as not being on NFS; otherwise only rank 0 of the stage writes.
            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': r.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                        'tokenizer': tokenizer.get_state()
                    }, args.checkpoint_dir, r.stage, epoch)
def main():
    global args, best_prec1
    args = parser.parse_args()
    if rank >= args.world_size:
        return
    print("initialising device...")
    local_rank = rank % args.num_ranks_in_server
    print("workers = ", args.workers)

    writer = None
    if args.log_dir:
        writer = SummaryWriter(log_dir=args.log_dir)
    ##### ENABLING GPU DIRECT HERE THROUGH A HACK ###
    args.num_ranks_in_server = args.world_size

    torch.cuda.set_device(local_rank)
    print("local rank {} device {}".format(local_rank,
                                           torch.cuda.current_device()))
    args.rank = rank  # my change
    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # create stages of the model
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    model = module.model(criterion)
    print("Local rank {} imported module".format(local_rank))

    # determine shapes of all tensors in passed-in model

    target_size = [args.batch_size]
    if args.dataset_name == "ImageNet":
        if args.arch == 'inception_v3':
            input_size = [args.batch_size, 3, 299, 299]
        else:
            input_size = [args.batch_size, 3, 224, 224]
        first_stage_input_dtype = torch.float32
    elif args.dataset_name == "MNIST":
        input_size = [args.batch_size, 1, 28, 28]
        first_stage_input_dtype = torch.float32
    elif args.dataset_name == "CIFAR10":
        input_size = [args.batch_size, 3, 32, 32]
        first_stage_input_dtype = torch.float32
    elif args.dataset_name in ["wikitext-2", "wikitext-103"]:
        input_size = [args.batch_size, args.bptt_len]
        first_stage_input_dtype = torch.int64
        target_size = [args.batch_size * args.bptt_len]
    else:
        print("Dataset {} not supported".format(args.dataset_name))

    training_tensor_shapes = {"input0": input_size, "target": target_size}
    dtypes = {"input0": torch.int64, "target": torch.int64}
    inputs_module_destinations = {"input": 0}
    target_tensor_names = {"target"}
    stage_number = 0
    for (stage, inputs, outputs) in model[:-1]:  # Skip last layer (loss).
        input_tensors = []
        for input in inputs:
            if stage_number == 0:
                input_dtype = first_stage_input_dtype
            else:
                input_dtype = torch.float32

            input_tensor = torch.zeros(tuple(training_tensor_shapes[input]),
                                       dtype=input_dtype).cuda()

            input_tensors.append(input_tensor)
        stage_number += 1
        stage.cuda()
        with torch.no_grad():
            output_tensors = stage(*tuple(input_tensors))
        if not type(output_tensors) is tuple:
            output_tensors = [output_tensors]
        for output, output_tensor in zip(outputs, list(output_tensors)):
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype
        del output_tensors
        del input_tensors
        stage.cpu()

    #print("local rank {} finished 1 forward pass...".format(local_rank))
    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple([args.eval_batch_size] +
                                        training_tensor_shapes[key][1:])
        training_tensor_shapes[key] = tuple(training_tensor_shapes[key])

    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        json_config_file = json.load(open(args.config_path, 'r'))
        configuration_maps['module_to_stage_map'] = json_config_file.get(
            "module_to_stage_map", None)
        configuration_maps['stage_to_rank_map'] = json_config_file.get(
            "stage_to_rank_map", None)
        configuration_maps['stage_to_rank_map'] = {
            int(k): v
            for (k, v) in configuration_maps['stage_to_rank_map'].items()
        }
        # print("========================")
        # print(configuration_maps['stage_to_rank_map'])

        configuration_maps['stage_to_depth_map'] = json_config_file.get(
            "stage_to_depth_map", None)

    if args.data_prl:
        print("Modifying stage to rank map to be data parallel")
        stage_to_rank_map = configuration_maps['stage_to_rank_map']
        for k in stage_to_rank_map:
            stage_to_rank_map[k] = list(range(args.world_size))

    print("Local rank {} Staging runtime....".format(local_rank))

    if args.language_modelling:
        model_type = runtime.LANGUAGE_MODELLING
    else:
        model_type = runtime.IMAGE_CLASSIFICATION

    r = runtime.StageRuntime(
        model=model,
        distributed_backend=args.distributed_backend,
        fp16=args.fp16,
        loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr,
        rank=args.rank,
        local_rank=local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=model_type,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        print("=> loading checkpoint '{}'".format(checkpoint_file_path))
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file_path, checkpoint['epoch']))

    #print_msg(args.rank, "number of versions" + str(num_versions) )

    if args.language_modelling:
        optimizer = adam.AdamWithWeightStashing(
            r.modules(),
            r.master_parameters,
            r.model_parameters,
            args.loss_scale,
            num_versions=num_versions,
            lr=args.lr,
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    else:
        # optimizer = sgd.SGDWithWeightStashing(r.modules(), r.master_parameters,
        #                                   r.model_parameters, args.loss_scale,
        #                                   num_versions=num_versions,
        #                                   lr=args.lr,
        #                                   momentum=args.momentum,
        #                                   weight_decay=args.weight_decay,
        #                                   verbose_freq=args.verbose_frequency,
        #                                   macrobatch=args.macrobatch)
        optimizer = adam.AdamWithWeightStashing(
            r.modules(),
            r.master_parameters,
            r.model_parameters,
            args.loss_scale,
            num_versions=num_versions,
            lr=args.lr,
            weight_decay=args.weight_decay,
            verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    print(args.dataset_name)

    if args.dataset_name == "ImageNet":
        if args.arch == 'inception_v3':
            if args.synthetic_data:
                train_dataset = SyntheticDatasetImageClassification(
                    (3, 299, 299), 10000)
            else:
                traindir = os.path.join(args.data_dir, 'train')
                train_dataset = datasets.ImageFolder(
                    traindir,
                    transforms.Compose([
                        transforms.RandomResizedCrop(299),
                        transforms.ToTensor(),
                        normalize,
                    ]))
        else:
            print("Initialising dataset..")
            if args.synthetic_data:
                train_dataset = SyntheticDatasetImageClassification(
                    (3, 224, 224), 1281168)  #modified
            else:
                traindir = os.path.join(args.data_dir, 'train')
                train_dataset = datasets.ImageFolder(
                    traindir,
                    transforms.Compose([
                        transforms.RandomResizedCrop(224),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        normalize,
                    ]))

        if args.synthetic_data:
            val_dataset = SyntheticDatasetImageClassification((3, 224, 224),
                                                              10000)
        else:
            valdir = os.path.join(args.data_dir, 'val')
            val_dataset = datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]))
            # val_dataset = SyntheticDatasetImageClassification((3, 224, 224), 10000)

    elif args.dataset_name == "MNIST":
        train_dataset = datasets.MNIST(args.data_dir,
                                       download=True,
                                       train=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, )),
                                       ]))

        val_dataset = datasets.MNIST(
            args.data_dir,
            download=True,
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(
                ),  # first, convert image to PyTorch tensor
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]))
    elif args.dataset_name == "CIFAR10":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        train_dataset = datasets.CIFAR10(root=args.data_dir,
                                         train=True,
                                         transform=transform)
        val_dataset = datasets.CIFAR10(root=args.data_dir,
                                       train=False,
                                       transform=transform)
    elif args.dataset_name in args.dataset_name in [
            "wikitext-2", "wikitext-103"
    ]:
        tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        if not args.synthetic_data:
            train_dataset = huggingface.get_dataset(args.dataset_name,
                                                    tokenizer,
                                                    'train',
                                                    num_workers=1,
                                                    bptt_len=args.bptt_len,
                                                    cache_dir=args.data_dir)
            val_dataset = huggingface.get_dataset(args.dataset_name,
                                                  tokenizer,
                                                  'validation',
                                                  num_workers=1,
                                                  bptt_len=args.bptt_len,
                                                  cache_dir=args.data_dir)
        else:
            if args.dataset_name == "wikitext-2":
                train_length = 36718
            else:
                train_length = 1801350
            train_dataset = SyntheticDatasetLanguageModelling(
                tokenizer.vocab_size, args.bptt_len, train_length)
            val_dataset = SyntheticDatasetLanguageModelling(
                tokenizer.vocab_size, args.bptt_len, 3760)

    distributed_sampler = False
    train_sampler = None
    val_sampler = None
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=num_ranks_in_first_stage,
                rank=args.rank % num_ranks_in_first_stage)
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset,
                num_replicas=num_ranks_in_first_stage,
                rank=args.rank % num_ranks_in_first_stage)
            distributed_sampler = True

    print("initialising data loaders")
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             drop_last=True)

    print(
        f"Rank {args.rank}: Length of train loader: {len(train_loader)} Length of dataset: {len(train_dataset)} BPTT_LEN {args.bptt_len} BATCH SIZE {args.batch_size}"
    )
    # else:
    #     train_loader = None
    #     val_loader = None

    # if args.rank==0:
    #     lengths = torch.LongTensor([len(train_loader), len(val_loader)]).cuda()
    # else:
    #     lengths = torch.zeros((2)).long().cuda()

    lengths = torch.LongTensor([len(train_loader), len(val_loader)])

    if rank == 0:
        quantities = [len(configuration_maps['stage_to_rank_map'][0])]
        for i in range(len(configuration_maps['stage_to_rank_map']) - 1):
            curr = len(configuration_maps['stage_to_rank_map'][i])
            curr *= len(configuration_maps['stage_to_rank_map'][i + 1])
            quantities.append(curr)
        print(quantities)
        lcm = np.lcm.reduce(quantities)
        print(f"new length should be a multiple of {lcm}")
        old_length = lengths[0].item()
        lengths[0] = (lengths[0] // lcm) * lcm
        print(
            f"Rank {args.rank} : Old Train length {old_length} Adjusted Length {lengths[0]}"
        )
        old_length = lengths[1].item()
        lengths[1] = (lengths[1] // lcm) * lcm
        print(
            f"Rank {args.rank} Old Val length {old_length} Adjusted Length {lengths[1]}"
        )
        dist.broadcast(lengths, src=0)
    else:
        dist.broadcast(lengths, src=0)
        num_ranks_in_first_stage = len(
            configuration_maps['stage_to_rank_map'][0])
        lengths[0] *= num_ranks_in_first_stage
        lengths[1] *= num_ranks_in_first_stage
        lengths[0] = lengths[0] // r.num_ranks_in_stage
        lengths[1] = lengths[1] // r.num_ranks_in_stage
        train_len = lengths[0]
        val_len = lengths[1]
        print(
            f"rank {args.rank}, Adjusted train length {train_len}, Adjusted val length {val_len}"
        )

    #exit()
    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        if args.rank == 0 and distributed_sampler:
            train_sampler.set_epoch(epoch)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch, model_type, lengths,
                  writer)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch, lengths, model_type)