Example 1
import time

import numpy as np
import torch


def run_batches(model, opt, lr_scheduler, loader, training, epoch_fraction,
                args):
    if not training and epoch_fraction != 1:
        raise ValueError("Must do full epochs for val")
    if epoch_fraction > 1 or epoch_fraction <= 0:
        msg = "Invalid epoch_fraction {}.".format(epoch_fraction)
        msg += " Should satisfy 0 < epoch_fraction <= 1"
        raise ValueError(msg)

    model.train(training)
    losses = []
    accs = []

    client_download = None
    client_upload = None
    start_time = time.time()  # so the first per-batch timing print is meaningful
    if training:
        num_clients = loader.dataset.num_clients
        client_download = torch.zeros(num_clients)
        client_upload = torch.zeros(num_clients)
        spe = steps_per_epoch(args.local_batch_size, loader.dataset,
                              args.num_workers)
        for i, batch in enumerate(loader):
            # only carry out an epoch_fraction portion of the epoch
            if i > spe * epoch_fraction:
                break

            lr_scheduler.step()

            if lr_scheduler.get_last_lr()[0] == 0:
                # hack to get the starting LR right for fedavg
                print("HACK STEP")
                opt.step()

            if args.local_batch_size == -1:
                expected_num_clients = args.num_workers
                if torch.unique(batch[0]).numel() < expected_num_clients:
                    # skip if there weren't enough clients left
                    msg = "SKIPPING BATCH: NOT ENOUGH CLIENTS ({} < {})"
                    print(
                        msg.format(
                            torch.unique(batch[0]).numel(),
                            expected_num_clients))
                    continue
            else:
                expected_numel = args.num_workers * args.local_batch_size
                if batch[0].numel() < expected_numel:
                    # skip incomplete batches
                    msg = "SKIPPING BATCH: NOT ENOUGH DATA ({} < {})"
                    print(msg.format(batch[0].numel(), expected_numel))
                    continue

            loss, acc, download, upload = model(batch)
            if np.any(np.isnan(loss)):
                print(f"LOSS OF {np.mean(loss)} IS NAN, TERMINATING TRAINING")
                return np.nan, np.nan, np.nan, np.nan

            client_download += download
            client_upload += upload

            opt.step()
            #model.zero_grad()
            losses.extend(loss)
            accs.extend(acc)
            if args.dataset_name == "EMNIST":
                lr = lr_scheduler.get_last_lr()[0]
                print(
                    "LR: {:0.5f}, Loss: {:0.5f}, Acc: {:0.5f}, Time: {:0.2f}".
                    format(lr,
                           loss.mean().item(),
                           acc.mean().item(),
                           time.time() - start_time))
                start_time = time.time()
            if args.do_test:
                break
    else:
        for batch in loader:
            if batch[0].numel() < args.valid_batch_size:
                print("SKIPPING VAL BATCH: TOO SMALL")
                continue
            loss, acc = model(batch)
            losses.extend(loss)
            accs.extend(acc)
            if args.do_test:
                break

    return np.mean(losses), np.mean(accs), client_download, client_upload
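
Both branches of run_batches lean on a steps_per_epoch helper that the snippet does not show. Below is a minimal sketch of what it plausibly computes, matching the skip logic above (local_batch_size == -1 meaning one client per worker, otherwise num_workers * local_batch_size examples per step); this is an assumption about the surrounding repo, not its actual code:

def steps_per_epoch(local_batch_size, dataset, num_workers):
    # Sketch only: infer batches per epoch from how the loop above
    # validates batch sizes.
    if local_batch_size == -1:
        # each step draws one full client per worker
        return dataset.num_clients // num_workers
    # each step draws local_batch_size examples from each worker
    return len(dataset) // (num_workers * local_batch_size)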
Example 2
    opt = optim.SGD(param_groups, lr=1)

    model = FedModel(model, compute_loss_train, args, compute_loss_val)
    opt = FedOptimizer(opt, args)

    # set up learning rate scheduler
    # original cifar10_fast repo uses [0, 5, 24] and [0, 0.4, 0]
    if args.lr_scale is None:
        args.lr_scale = 0.4
    lr_schedule = PiecewiseLinear([0, args.pivot_epoch, args.num_epochs],
                                  [0, args.lr_scale, 0])

    # grad_reduction only controls how gradients from different
    # workers are combined
    # so the lr is multiplied by num_workers for both mean and median
    spe = steps_per_epoch(args.local_batch_size, train_loader.dataset,
                          args.num_workers)
    # LambdaLR multiplies the base lr (1 here) by the schedule's value
    lr_scheduler = LambdaLR(opt, lr_lambda=lambda step: lr_schedule(step / spe))

    # set up output
    log_dir = make_logdir(args)
    if args.use_tensorboard:
        writer = SummaryWriter(log_dir=log_dir)
    else:
        writer = None
    print('Finished initializing in {:.2f} seconds'.format(timer()))

    # and do the training
    train(model,
          opt,
          lr_scheduler,
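The PiecewiseLinear schedule comes from the cifar10_fast repo mentioned in the comment above. A minimal re-implementation with the usual semantics (linear interpolation between knots, flat outside the range), included as a sketch rather than the repo's exact code:

import numpy as np

class PiecewiseLinear:
    def __init__(self, knots, vals):
        self.knots = knots
        self.vals = vals

    def __call__(self, t):
        # np.interp clamps to the end values outside [knots[0], knots[-1]]
        return np.interp([t], self.knots, self.vals)[0]

With the values from the comment above, PiecewiseLinear([0, 5, 24], [0, 0.4, 0])(2.5) returns 0.2: the lr ramps up linearly until pivot_epoch, then decays linearly to zero.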
Example 3
import numpy as np
import torch


def run_batches(model,
                opt,
                lr_scheduler,
                loader,
                args,
                timer,
                training,
                epoch=None,
                epoch_fraction=1,
                logger=None,
                writer=None):
    if not training and epoch_fraction != 1:
        raise ValueError("Must do full epochs for val")
    if epoch_fraction > 1 or epoch_fraction <= 0:
        msg = "Invalid epoch_fraction {}.".format(epoch_fraction)
        msg += " Should satisfy 0 < epoch_fraction <= 1"
        raise ValueError(msg)

    model.train(training)
    client_download = torch.zeros(loader.dataset.num_clients)
    client_upload = torch.zeros(loader.dataset.num_clients)
    spe = steps_per_epoch(args.local_batch_size, loader.dataset,
                          args.num_workers)

    if training:
        epoch_idxs = epoch * spe
        losses = []
        for batch_idx, batch in enumerate(loader):
            if batch_idx > 2 and args.do_test and batch_idx < spe - 10:
                print("skipping ", batch_idx)
                continue
            # only carry out an epoch_fraction portion of the epoch
            if batch_idx > spe * epoch_fraction:
                break
            lr_scheduler.step()
            if lr_scheduler.get_last_lr()[0] == 0:
                # hack to get the starting LR right for fedavg
                opt.step()

            if args.local_batch_size == -1:
                expected_num_clients = args.num_workers
                if torch.unique(batch[0]).numel() < expected_num_clients:
                    # skip if there weren't enough clients left
                    print("SKIPPING BATCH: NOT ENOUGH CLIENTS")
                    continue
            else:
                expected_numel = args.num_workers * args.local_batch_size
                if batch[0].numel() < expected_numel:
                    # skip incomplete batches
                    print("SKIPPING BATCH: NOT ENOUGH DATA")
                    continue

            loss, download, upload = model(batch)

            client_download += download
            client_upload += upload

            opt.step()
            loss = np.mean(loss)
            losses.append(loss)
            train_time = timer()
            download_mb = download.sum().item() / (1024 * 1024)
            upload_mb = upload.sum().item() / (1024 * 1024)
            batch_stats = {
                'train_time': train_time,
                'train_loss': loss,
                'total_time': timer.total_time,
                'down (MiB)': round(download_mb),
                'up (MiB)': round(upload_mb),
            }
            lr = lr_scheduler.get_last_lr()[0]

            # writer and logger default to None, so guard before using them
            if writer is not None:
                writer.add_scalar('training/loss', loss,
                                  batch_idx + epoch_idxs)
                writer.add_scalar('Lr', lr, batch_idx + epoch_idxs)
                writer.add_scalar('Time/train', train_time,
                                  batch_idx + epoch_idxs)
            summary = union({
                'batch_idx': batch_idx + 1 + epoch_idxs,
                'lr': lr
            }, batch_stats)
            if logger is not None:
                logger.append(summary)
        return np.mean(losses), client_download, client_upload

    else:
        nlls, accs, ppls = [], [], []
        for batch_idx, batch in enumerate(loader):
            if batch_idx > 5 and args.do_test and batch_idx < spe - 5:
                print("skipping ", batch_idx)
                continue
            nll, acc = model(batch)
            nll = np.mean(nll)
            acc = np.mean(acc)
            nlls.append(nll)
            accs.append(acc)
        return np.mean(nlls), np.mean(accs), np.exp(np.mean(nlls))
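
Example 3 also uses a union helper to build the log summary. Assuming it is the standard dict-merge utility from the same cifar10_fast lineage (an assumption, since the snippet does not define it):

def union(*dicts):
    # merge dicts left to right; later keys overwrite earlier ones
    return {k: v for d in dicts for k, v in d.items()}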
Example 4
def train():
    args = parse_args(default_lr=4e-2)

    print(args)

    timer = Timer()
    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    if "gpt2" in args.model_checkpoint:
        tokenizer_class = GPT2Tokenizer
        model_class = GPT2DoubleHeadsModel
    else:
        tokenizer_class = OpenAIGPTTokenizer
        model_class = OpenAIGPTDoubleHeadsModel

    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    if args.do_finetune and not args.do_test:
        args.model_checkpoint = args.finetune_path
    model = model_class.from_pretrained(args.model_checkpoint)

    args.len_tokenizer = len(tokenizer)

    # Do logging now before we overwrite model
    log_dir = make_logdir(args)
    writer = SummaryWriter(log_dir=log_dir)
    tokenizer.save_pretrained(log_dir)
    getattr(model, 'module',
            model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)
    # HAVE TO USE SGD FOR FED
    optimizer = SGD(model.parameters(), lr=1)

    logger.info('Finished in {:.2f} seconds'.format(timer()))
    logger.info("Prepare datasets")
    loaders = get_data_loaders(args, tokenizer)
    train_loader, val_loader = loaders

    logger.info('Finished in {:.2f} seconds'.format(timer()))
    logger.info("Initializing everything")
    model = FedModel(model, compute_loss_train, args, compute_loss_val)
    optimizer = FedOptimizer(optimizer, args)
    spe = steps_per_epoch(args.local_batch_size, train_loader.dataset,
                          args.num_workers)
    print("Steps per epoch", spe)
    lr_schedule = PiecewiseLinear([0, args.num_epochs * spe],
                                  [args.lr_scale, 0.0])
    # PiecewiseLinear is already a step -> lr callable, so pass it directly
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_schedule)
    if args.do_finetune:
        test_gpt2(model,
                  val_loader,
                  args,
                  logger=TableLogger(),
                  timer=timer,
                  writer=writer)
    else:
        train_gpt2(model,
                   optimizer,
                   scheduler,
                   train_loader,
                   val_loader,
                   args,
                   log_dir,
                   writer=writer,
                   logger=TableLogger(),
                   timer=timer)
    model.finalize()
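
Examples 2 and 4 both build the inner SGD optimizer with lr=1. That is deliberate: LambdaLR multiplies the optimizer's base lr by whatever the lambda returns, so with a base lr of 1 the schedule's output is the absolute learning rate. A self-contained illustration with made-up schedule numbers:

import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=1)  # base lr of 1, as in the examples
# toy linear decay from 0.04 to 0 over 100 steps
sched = LambdaLR(opt, lr_lambda=lambda step: 0.04 * max(0.0, 1 - step / 100))
sched.step()
print(sched.get_last_lr())  # [0.0396]: the lambda's value, scaled by base lr 1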
Example 5
    data_dir = '/Users/xusenyin/git-store/dnd/dataset-for-dialogue'
    model_dir = '/Users/xusenyin/experiments/dialogue-model'
    path_json = model_dir + '/params.json'
    path_src_vocab = data_dir + '/words.txt'
    path_tgt_vocab = data_dir + '/tags.txt'
    path_src = data_dir + '/dev/sentences.txt'
    path_tgt = data_dir + '/dev/labels.txt'

    params = Params(path_json)
    params.update(data_dir + '/dataset_params.json')
    params.eval_size = params.dev_size
    params.buffer_size = params.train_size  # buffer size for shuffling

    inputs = dialogue_input_fn(
        'train', path_src, path_tgt, path_src_vocab, path_tgt_vocab, params)

    src = inputs['src']
    tgt_input = inputs['tgt_in']
    tgt_output = inputs['tgt_out']
    num_tgt_tokens = inputs['num_tgt_tokens']
    res = tf.reduce_max(num_tgt_tokens, axis=0)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        sess.run(inputs['iterator_init_op'])

        for _ in range(steps_per_epoch(2000, 128)):
            print(sess.run([tf.shape(res), res]))
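
steps_per_epoch in this last example takes (dataset size, batch size) rather than the (local_batch_size, dataset, num_workers) signature of the earlier snippets, so it belongs to a different codebase. A plausible minimal version, assumed rather than taken from that repo:

def steps_per_epoch(num_examples, batch_size):
    # ceil-divide: number of batches needed to cover the dataset once
    return (num_examples + batch_size - 1) // batch_size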