Example #1
0
def data_setup():
    """Seed the RNGs, pick the compute device and load the corpus.

    Everything produced here is published on the shared ``g`` namespace:
    ``g.device``, ``g.corpus``, ``g.ntokens``, ``g.va_iter`` and ``g.te_iter``.
    """
    args = g.args
    seed = args.seed
    has_cuda = torch.cuda.is_available()

    # Seed every random number generator in use so runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if has_cuda:
        torch.cuda.manual_seed_all(seed)
        torch.cuda.set_device(args.local_rank)

    g.device = torch.device('cuda' if has_cuda else 'cpu')

    ###############################################################################
    # Load data
    ###############################################################################
    g.corpus = get_lm_corpus(args.data, args.dataset, use_bpe=args.bpe)
    g.ntokens = len(g.corpus.vocab)

    def build_eval_iterator(split):
        # Evaluation iterators use twice the training batch size.
        return g.corpus.get_dist_iterator(split,
                                          bsz=args.batch_size * 2,
                                          bptt=args.tgt_len,
                                          rank=util.get_global_rank(),
                                          max_rank=util.get_world_size(),
                                          device=g.device,
                                          ext_len=args.ext_len)

    # Build both iterators before publishing either one.
    valid_iter = build_eval_iterator('valid')
    test_iter = build_eval_iterator('test')
    g.va_iter, g.te_iter = valid_iter, test_iter
Example #2
0
def test_optimize():
    """Run one timed SGD step of a DDP-wrapped ``SimpleNet`` on CUDA in fp16.

    The body initialises an NCCL process group, wraps the model in
    ``DistributedDataParallel`` and performs a single optimisation step on an
    identity-matrix input, timing it after ``torch.cuda.synchronize()``.

    NOTE(review): this snippet looks truncated -- ``recv_bytes``,
    ``transmit_bytes``, ``start_time0``, ``time_list`` and ``rate`` are
    computed but never consumed in the visible body; presumably a reporting
    tail is missing. Confirm before relying on it.
    """
    global log

    # Snapshot of the network byte counters taken before any communication.
    recv_bytes, transmit_bytes = util.network_bytes()
    
    device = 'cuda'
    fp16 = True

    dim = 2 ** 12  # multiple of 8, about 67MB matrix in fp32

    model = SimpleNet(args.num_layers, dim)
    model = model.to(device)
    if fp16:
        model = model.half()
        bytes_per_number = 2
    else:
        bytes_per_number = 4

    # Size estimate: one dim x dim matrix per layer, in bytes and in MB.
    gradient_size = args.num_layers * (dim * dim) * bytes_per_number
    size_mb = gradient_size / 1e6

    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=util.get_world_size())

    model = DistributedDataParallel(model,
                                    device_ids=[args.local_rank],
                                    output_device=args.local_rank)
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Identity input; the loss below pulls the model output toward the input.
    x = torch.eye(dim)
    x = x.to(device)
    if fp16:
        x = x.half()
    time_list = []
    start_time = time.perf_counter()
    start_time0 = start_time
    for i in range(1):
        optimizer.zero_grad()

        output = model(x)

        def sqr(a): return a*a
        # Sum of squared differences between output and input.
        loss = sqr(output-x).sum()
        loss.backward()
        optimizer.step()

        # Wait for queued GPU work so the wall-clock reading covers the step.
        torch.cuda.synchronize()
        elapsed_time_sec = (time.perf_counter() - start_time)
        start_time = time.perf_counter()
        
        elapsed_time_ms = elapsed_time_sec * 1000
        time_list.append(elapsed_time_ms)
        rate = size_mb / elapsed_time_sec
Example #3
0
def main_loop():
    """Build (or restore) the model and run the training loop.

    Steps, in order:
      1. Call ``util.cancel_shutdown()`` and, unless ``args.local`` is set,
         initialise the ``torch.distributed`` process group.
      2. Create a ``TrainState`` holding a new ``MemTransformerLM``, set up
         its optimizer, then either restore weights from ``args.checkpoint``
         (after ``args.checkpoint_secondary``, if given) or apply
         ``weights_init``.
      3. Optionally wrap model/optimizer for fp16, create the LR scheduler and
         wrap the model in ``nn.DataParallel`` or ``DistributedDataParallel``.
      4. Iterate over epochs, evaluating every ``args.eval_interval`` steps
         and logging on the ``should_log`` steps.

    Training ends when ``g.state.token_count`` reaches ``args.max_tokens``,
    when an epoch consumes zero tokens, or on ``KeyboardInterrupt``.

    Returns:
        list: the scalar loss of every training step executed by this call.
    """
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(
            backend=args.dist_backend,
            #init_method=args.dist_url,
            #world_size=util.get_world_size()
        )
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    g.logger.info("creating new model")
    g.state = TrainState(args)
    g.state.model = MemTransformerLM(g.ntokens,
                                     args.n_layer,
                                     args.n_head,
                                     args.d_model,
                                     args.d_head,
                                     args.d_inner,
                                     args.dropout,
                                     args.dropatt,
                                     tie_weight=args.tied,
                                     d_embed=args.d_embed,
                                     div_val=args.div_val,
                                     tie_projs=g.tie_projs,
                                     pre_lnorm=args.pre_lnorm,
                                     tgt_len=args.tgt_len,
                                     ext_len=args.ext_len,
                                     mem_len=args.mem_len,
                                     cutoffs=g.cutoffs,
                                     same_length=args.same_length,
                                     attn_type=args.attn_type,
                                     clamp_len=args.clamp_len,
                                     sample_softmax=args.sample_softmax,
                                     freeze_below=args.freeze_below)
    g.state.model.to(g.device)
    optimizer_setup(g.state)
    if args.checkpoint:
        # The secondary checkpoint is restored first, then the primary one.
        if args.checkpoint_secondary:
            g.logger.info("restoring extra checkpoint")
            util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                         args.checkpoint_secondary,
                                         args.optim_state_dict)
        # Build the message in two steps. The previous single expression
        # `a + b if cond else ""` parsed as `(a + b) if cond else ""` and
        # logged an empty string whenever no optimizer state was given.
        restore_msg = f"Restoring model from {args.checkpoint}"
        if args.optim_state_dict:
            restore_msg += f" and optimizer from {args.optim_state_dict}"
        g.logger.info(restore_msg)
        util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                     args.checkpoint, args.optim_state_dict)

    else:
        g.state.model.apply(weights_init)
        # ensure embedding init is not overridden by out_layer in case of weight sharing
        g.state.model.word_emb.apply(weights_init)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    if g.state.args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=g.state.args.static_loss_scale,
            dynamic_loss_scale=g.state.args.dynamic_loss_scale,
            dynamic_loss_args={'init_scale': 2**16},
            verbose=False)

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        g.state.scheduler: LRFinder = LRFinder(optimizer,
                                               args.max_tokens,
                                               init_value=args.lr / 1e3)
    else:
        assert args.scheduler == 'constant'
        g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)

    g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    # Sum of step losses since the last log line; reset after each log.
    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
            model.train()

            log_tb('sizes/batch_size', args.batch_size)
            log_tb('sizes/seq_size', args.tgt_len)

            if g.state.partial_epoch:
                # reuse previously loaded tr_iter and states
                assert g.state.tr_iter is not None
                assert g.state.mems is not None
            else:
                g.state.tr_iter = g.corpus.get_dist_iterator(
                    'train',
                    rank=util.get_global_rank(),
                    max_rank=util.get_world_size(),
                    bsz=args.batch_size,
                    bptt=args.tgt_len,
                    device=g.device,
                    ext_len=args.ext_len,
                    skip_files=g.args.skip_files)
                g.state.mems = tuple()
            g.state.last_epoch = epoch

            log_start_time = time.time()
            tokens_per_epoch = 0
            for batch, (data, target, seq_len) in enumerate(g.state.tr_iter):
                # assert seq_len == data.shape[0]
                # for i in range(1, data.shape[0]):
                #     assert torch.all(torch.eq(data[i], target[i - 1]))
                #     break

                # print(g.state.token_count, data)

                # Periodic validation, with several memory-reset intervals.
                if g.state.train_step % args.eval_interval == 0:
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-1',
                                     generate_text=False,
                                     reset_mems_interval=1)
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-2',
                                     generate_text=False,
                                     reset_mems_interval=2)
                    evaluate_and_log(model,
                                     g.va_iter,
                                     'val_short-mem-3',
                                     generate_text=False,
                                     reset_mems_interval=3)
                    evaluate_and_log(model, g.va_iter, 'val')
                    if g.va_custom_iter:
                        evaluate_and_log(g.state.model,
                                         g.va_custom_iter,
                                         g.args.valid_custom,
                                         generate_text=False)

                batch_total = torch.tensor(data.shape[1]).to(g.device)
                if args.local:  # TODO(y): factor out (need way to see if dist was inited)
                    batch_total = batch_total.sum()
                else:
                    batch_total = util.dist_sum_tensor(
                        batch_total)  # global batch size
                batch_total = util.toscalar(batch_total)

                # Log every step early on, then once per log_interval.
                should_log = (g.state.train_step < args.verbose_log_steps) or \
                             (g.state.train_step + 1) % args.log_interval == 0

                model.zero_grad()

                # First output is the loss; the rest are the updated memories.
                ret = model(data, target, *g.state.mems)
                loss, g.state.mems = ret[0], ret[1:]

                loss: torch.Tensor = loss.float().mean().type_as(loss)
                with timeit('backwards', noop=not should_log):
                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()
                loss0 = util.toscalar(loss)
                util.record('loss', loss0)

                util.record('params', torch.sum(util.flat_param(model)).item())
                losses.append(loss0)
                accumulated_loss += loss0

                if args.fp16:
                    optimizer.clip_master_grads(args.clip)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.clip)

                # step-wise learning rate annealing
                if hasattr(optimizer, 'overflow') and optimizer.overflow:
                    g.logger.info("skipped iteration")
                else:
                    if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                        # linear warmup stage
                        if g.state.token_count < args.warmup_tokens:
                            curr_lr = args.lr * float(
                                g.state.token_count) / args.warmup_tokens
                            optimizer.param_groups[0]['lr'] = curr_lr
                        elif args.scheduler == 'cosine':
                            # Divide by 1e6 for numerical stability.
                            g.state.scheduler.step(g.state.token_count //
                                                   1000 // 1000)
                    else:
                        g.state.scheduler.step(g.state.token_count)

                optimizer.step()
                g.state.train_step += 1

                consumed_tokens = data.shape[0] * data.shape[1]
                world_size = int(os.environ.get("WORLD_SIZE", "8"))
                if world_size > 8:  # correction factor for multiple machines
                    consumed_tokens = consumed_tokens * (world_size // 8)
                tokens_per_epoch += consumed_tokens
                g.state.token_count += consumed_tokens
                g.token_count = g.state.token_count
                if g.state.token_count >= args.max_tokens:
                    g.state.partial_epoch = True
                    # Caught by the `except StopIteration` below.
                    raise StopIteration  # break out of parent train loop

                if should_log:
                    elapsed_time = time.time() - log_start_time
                    elapsed_steps = g.state.train_step - g.state.last_log_step

                    # compute average loss over last logging interval
                    cur_loss = accumulated_loss / elapsed_steps
                    cur_loss_mean = util.dist_mean(cur_loss)
                    log_str = f'| epoch {epoch:3d} step {g.state.train_step:>8d} ' \
                              f'| {batch:>6d} batches ' \
                              f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                              f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                              f'| loss {cur_loss:5.2f}'
                    if args.dataset in ['enwik8', 'text8']:
                        log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
                    else:
                        log_str += f' | ppl {math.exp(cur_loss):9.3f}'
                    g.logger.info(log_str)
                    log_tb('learning/epoch', epoch)
                    log_tb('_loss', cur_loss_mean)  # the most important thing
                    log_tb('learning/loss', cur_loss_mean)
                    log_tb('learning/ppl', math.exp(cur_loss_mean))

                    # currently step timings are not synchronized in multi-machine
                    # case (see #4). Can add torch.distributed.barrier() to get
                    # more accurate timings, but this may add slowness.
                    log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
                    current_lr = optimizer.param_groups[0]['lr']

                    log_tb('learning/lr', current_lr)

                    # 32 is the "canonical" batch size
                    linear_scaling_factor = batch_total / 32  # TODO(y): merge logic from master
                    log_tb('learning/base_lr',
                           current_lr / linear_scaling_factor)
                    if args.optim == 'lamb':
                        log_lamb_rs(optimizer, g.event_writer,
                                    g.state.token_count)

                    time_per_batch = elapsed_time / elapsed_steps
                    time_per_sample = time_per_batch / args.batch_size
                    time_per_token = time_per_sample / args.tgt_len

                    log_tb('times/batches_per_sec', 1 / time_per_batch)
                    log_tb('times/samples_per_sec', 1 / time_per_sample)
                    log_tb('times/tokens_per_sec', 1 / time_per_token)

                    if str(g.device) == 'cuda':
                        log_tb("memory/allocated_gb",
                               torch.cuda.memory_allocated() / 1e9)
                        log_tb("memory/max_allocated_gb",
                               torch.cuda.max_memory_allocated() / 1e9)
                        log_tb("memory/cached_gb",
                               torch.cuda.memory_cached() / 1e9)
                        log_tb("memory/max_cached_gb",
                               torch.cuda.max_memory_cached() / 1e9)

                    # Start a fresh logging interval.
                    accumulated_loss = 0
                    log_start_time = time.time()
                    g.state.last_log_step = g.state.train_step

            if args.checkpoint_each_epoch:
                g.logger.info(f'Saving checkpoint for epoch {epoch}')
                util.dist_save_checkpoint(model,
                                          optimizer,
                                          args.logdir,
                                          suffix=f'{epoch}')
            if tokens_per_epoch == 0:
                # Use the project logger like every other message here,
                # not the root `logging` module.
                g.logger.info("Zero tokens in last epoch, breaking")

                break

            g.state.partial_epoch = False

    except KeyboardInterrupt:
        g.logger.info('-' * 100)
        g.logger.info('Exiting from training early')
    except StopIteration:
        # Raised above once the token budget (args.max_tokens) is reached.
        pass

    return losses
Example #4
0
def main():
    """Train a ``MemTransformerLM`` and evaluate it on validation and test data.

    Steps, in order: optionally initialise ``torch.distributed``, build the
    model, optimizer and LR scheduler, restore ``args.checkpoint`` on global
    rank 0 if given, wrap the model for fp16 and data parallelism, then call
    ``train`` once per epoch until ``StopIteration`` or ``KeyboardInterrupt``.
    Afterwards the validation set is evaluated, ``model-best.pt`` is loaded
    from ``args.logdir`` when it exists, and the test set is evaluated.

    Progress is published through the module-level globals named in the
    ``global`` statement below.
    """
    global global_token_count, event_writer, train_step, train_loss, last_log_step, \
        best_val_loss, epoch, model

    if args.local_rank > 0:
        pass  # skip shutdown when rank is explicitly set + not zero rank
    else:
        os.system('shutdown -c')

    if not args.local:
        logger.info(
            f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}'
        )
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    model = MemTransformerLM(ntokens,
                             args.n_layer,
                             args.n_head,
                             args.d_model,
                             args.d_head,
                             args.d_inner,
                             args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    # log model info
    n_all_param = sum([p.nelement() for p in model.parameters()])
    log_tb('sizes/params', n_all_param)
    n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    log_tb('sizes/non_emb_params', n_nonemb_param)
    logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # optimizer
    if args.optim.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.mom)
    elif args.optim.lower() == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd)
    else:
        assert args.optim.lower() == 'adam'
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         args.max_tokens //
                                                         1e6,
                                                         eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        scheduler = LRFinder(optimizer,
                             args.max_tokens,
                             init_value=args.lr / 1e3)
    elif args.scheduler == 'constant':
        # Bind a no-op scheduler so `scheduler` exists when passed to
        # train(); leaving it unbound raised UnboundLocalError there.
        scheduler = util.NoOp()

    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing

    if args.checkpoint:
        # Only global rank 0 reads the checkpoint file.
        if global_rank == 0:
            util.restore_from_checkpoint(model=model,
                                         checkpoint_fn=args.checkpoint)

    model = model.to(device)
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2**16},
                                   verbose=False)

    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  #, find_unused_parameters=True)

    # Only global rank 0 creates a SummaryWriter; other ranks keep whatever
    # `event_writer` already is at module level.
    if global_rank == 0:
        event_writer = SummaryWriter(args.logdir)

    event_writer.add_text('args', str(args))

    # test checkpoint writing
    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{0}')

    # Loop over epochs.
    train_step = 0
    train_loss = 0
    last_log_step = 0
    best_val_loss = None
    # Evaluation iterators use twice the training batch size.
    va_iter, te_iter = [
        corpus.get_dist_iterator(split,
                                 global_rank,
                                 max_rank,
                                 args.batch_size * 2,
                                 args.tgt_len,
                                 device=device,
                                 ext_len=args.ext_len)
        for split in ('valid', 'test')
    ]

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=1):
            train(va_iter, optimizer, scheduler)
    except KeyboardInterrupt:
        logger.info('-' * 100)
        logger.info('Exiting from training early')
    except StopIteration:
        # NOTE(review): presumably raised inside train() to end training --
        # train() is not visible here; confirm.
        pass

    # Eval one more time.
    evaluate_and_log(optimizer, va_iter, 'val', train_step=-1)

    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.logdir, 'model-best.pt')
    if os.path.exists(model_file):
        with open(model_file, 'rb') as model_f:
            with timeit('load'):
                if args.local:
                    model = torch.load(model_f)
                else:
                    # Map the stored tensors onto this process's GPU.
                    model = torch.load(model_f,
                                       map_location=lambda storage, loc:
                                       storage.cuda(args.local_rank))
                    model = DistributedDataParallel(
                        model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)
    else:
        logger.warn('no model file, using current model for loss')

    # Run on test data.
    evaluate_and_log(optimizer, te_iter, 'test', -1)
Example #5
0
                    type=int,
                    help='how long to wait before shutting down on error')

# Parse the command line once at import time; `args` is module-level state.
args = parser.parse_args()
# `tied` is simply the negation of the `not_tied` option.
args.tied = not args.not_tied

# global variables
global_timeit_dict = OrderedDict()
global_token_count = 0
# Starts as a no-op object.
# NOTE(review): presumably replaced by a real writer elsewhere -- confirm.
event_writer = util.NoOp()
epoch = 0
train_step = 0

# Rank bookkeeping. Note that `max_rank` holds the world size, not the
# highest rank index.
local_rank = args.local_rank
global_rank = util.get_global_rank()
max_rank = util.get_world_size()


class FileLogger:
    def __init__(self, output_dir: str, global_rank: int, local_rank: int):
        self.output_dir = output_dir
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)
        self.logger = FileLogger.get_logger(output_dir,
                                            global_rank=global_rank,
                                            local_rank=local_rank)

    def exception(self, *args_, **kwargs):
        return self.logger.exception(*args_, **kwargs)

    @staticmethod
def test_optimize():
    """Time repeated SGD steps of a DDP-wrapped ``SimpleNet`` and log bandwidth.

    Runs ``args.iters`` optimisation steps on an identity-matrix input, timing
    each one after ``torch.cuda.synchronize()``. The first measurement is
    dropped before the min/median/mean are logged. The function then logs the
    change in the network byte counters and two bandwidth figures derived
    from them and from the estimated gradient size.
    """
    global log

    # Byte counters before any communication, for the final delta.
    recv_bytes, transmit_bytes = util.network_bytes()

    device = 'cuda'
    dim = 2 ** 12  # multiple of 8, about 67MB matrix in fp32

    model = SimpleNet(args.num_layers, dim).to(device)
    if fp16:
        model = model.half()
    bytes_per_number = 2 if fp16 else 4

    # Size estimate: one dim x dim matrix per layer.
    gradient_size = args.num_layers * (dim * dim) * bytes_per_number
    size_mb = gradient_size / 1e6

    log('initializing process group')
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=util.get_world_size())

    log('calling DDP')
    model = DistributedDataParallel(model,
                                    device_ids=[args.local_rank],
                                    output_device=args.local_rank)
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    x = torch.eye(dim).to(device)
    if fp16:
        x = x.half()
    time_list = []

    # force initialization of NCCL
    dist.all_reduce(torch.ones(()).cuda())
    dist.barrier()

    def squared(t):
        return t * t

    log("Start timing")
    start_time0 = start_time = time.perf_counter()
    for step in range(args.iters):
        optimizer.zero_grad()

        output = model(x)
        loss = squared(output - x).sum()
        loss.backward()
        optimizer.step()

        # Wait for queued GPU work before reading the clock.
        torch.cuda.synchronize()
        elapsed_time_sec = time.perf_counter() - start_time
        start_time = time.perf_counter()

        elapsed_time_ms = elapsed_time_sec * 1000
        time_list.append(elapsed_time_ms)
        rate = size_mb / elapsed_time_sec

        log('%03d/%d added %d MBs in %.1f ms: %.2f MB/second %.1f' % (
            step, args.iters, size_mb, elapsed_time_ms, rate, loss))

    del time_list[0]  # first measurement is off because of syncing
    min_time = np.min(time_list)
    median = np.median(time_list)
    log(f"min: {min_time:8.2f}, median: {median:8.2f}, mean: {np.mean(time_list):8.2f}")

    dist.barrier()
    elapsed_time = time.perf_counter() - start_time0
    recv_bytes1, transmit_bytes1 = util.network_bytes()
    log(f"Received {(recv_bytes1 - recv_bytes) / 1e9:.1f}, transmitted {(transmit_bytes1 - transmit_bytes) / 1e9:.1f} "
        f"in {elapsed_time:.1f} seconds")
    log(f"predicted {gradient_size * args.iters / 1e9:.1f}")

    log(f"average observed bw: {(recv_bytes1 - recv_bytes) * 8 / elapsed_time / 1e9:.1f} Gbps")

    # Mean step time in seconds, then gradient bits per second in Gbps.
    mean_step_sec = np.mean(time_list) / 1000
    effective_bw_gbps = gradient_size / mean_step_sec * 8 / 1e9

    log(f"average effective bw: {effective_bw_gbps} Gbps")
Example #7
0
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(f"Distributed: success ({args.local_rank}/{dist.get_world_size()})")

    if args.load_state_fn:
        g.state = load_state(args.load_state_fn)
        g.logger.info(f"Restoring training from {args.load_state_fn}")
    else:
        g.logger.info("creating new model")
        g.state = TrainState(args)

        g.state.model = MemTransformerLM(g.ntokens, args.n_layer, args.n_head, args.d_model,
                                         args.d_head, args.d_inner, args.dropout, args.dropatt,
                                         tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
                                         tie_projs=g.tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
                                         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=g.cutoffs,
                                         same_length=args.same_length, attn_type=args.attn_type,
                                         clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
        if args.checkpoint:
            util.restore_from_checkpoint(g.state.model, checkpoint_fn=args.checkpoint)
        else:
            g.state.model.apply(weights_init)
            g.state.model.word_emb.apply(
                weights_init)  # ensure embedding init is not overridden by out_layer in case of weight sharing
        g.state.model.to(g.device)
        optimizer_setup(g.state)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if not g.args.load_state_fn:
        if args.scheduler == 'cosine':
            # Divide by 1e6 for numerical stability.
            g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_tokens // 1e6,
                                                                     eta_min=args.eta_min)
        elif args.scheduler == 'finder':
            g.state.scheduler: LRFinder = LRFinder(optimizer, args.max_tokens, init_value=args.lr / 1e3)
        else:
            assert args.scheduler == 'constant'
            g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)

    g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.
def test_allreduce():
    """Time repeated asynchronous all-reduces of ``args.num_layers`` matrices.

    Each of ``args.iters`` iterations issues one ``dist.all_reduce`` per
    tensor with ``async_op=True`` and then calls ``torch.cuda.synchronize()``
    before reading the clock. The first measurement is dropped before the
    min/median/mean are logged. The function then logs the change in the
    network byte counters and the bandwidth derived from them.
    """
    global log

    # Byte counters before any communication, for the final delta.
    recv_bytes, transmit_bytes = util.network_bytes()

    device = 'cuda'
    dim = 2 ** 12  # multiple of 8, about 67MB matrix in fp32

    bytes_per_number = 2 if fp16 else 4

    # Size estimate: one dim x dim matrix per layer.
    gradient_size = args.num_layers * (dim * dim) * bytes_per_number
    size_mb = gradient_size / 1e6

    log('initializing process group')
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=util.get_world_size())

    # Build the tensors in stages: create all, move all, then cast all.
    xs = [torch.ones((dim, dim)) for _ in range(args.num_layers)]
    xs = [tensor.to(device) for tensor in xs]
    if fp16:
        xs = [tensor.half() for tensor in xs]
    time_list = []

    # force initialization of NCCL
    dist.all_reduce(torch.ones(()).cuda())
    dist.barrier()

    log("Start timing")
    start_time0 = start_time = time.perf_counter()
    for step in range(args.iters):
        # Issue every all-reduce without waiting on the returned handles.
        for tensor in xs:
            dist.all_reduce(tensor, async_op=True)

        # Wait for queued GPU work before reading the clock.
        torch.cuda.synchronize()
        elapsed_time_sec = time.perf_counter() - start_time
        start_time = time.perf_counter()

        elapsed_time_ms = elapsed_time_sec * 1000
        time_list.append(elapsed_time_ms)
        rate = size_mb / elapsed_time_sec

        # could do barrier, but didn't have effect on timing
        # dist.barrier()
        log('%03d/%d added %d MBs in %.1f ms: %.2f MB/second %.1f' % (
            step, args.iters, size_mb, elapsed_time_ms, rate, xs[0][0, 0]))

    del time_list[0]   # first measurement is off because of syncing
    min_time = np.min(time_list)
    median = np.median(time_list)
    log(f"min: {min_time:8.2f}, median: {median:8.2f}, mean: {np.mean(time_list):8.2f}")

    dist.barrier()
    elapsed_time = time.perf_counter() - start_time0
    recv_bytes1, transmit_bytes1 = util.network_bytes()
    log(f"Received {(recv_bytes1-recv_bytes)/1e9:.1f}, transmitted {(transmit_bytes1-transmit_bytes)/1e9:.1f} in {elapsed_time:.1f} seconds")
    log(f"predicted {gradient_size*args.iters/1e9:.1f}")

    log(f"average bw: {(recv_bytes1-recv_bytes)*8/elapsed_time/1e9:.1f} Gbps")