Example #1
def save_checkpoint(self, filename, extra_state):
    """Save all training state in a checkpoint file."""
    if distributed_utils.is_master(self.args):  # only save one checkpoint
        extra_state['train_meters'] = self.meters
        checkpoint_utils.save_state(
            filename, self.args, self.get_model().state_dict(), None,
            self.optimizer, self.lr_scheduler, self.get_num_updates(),
            self._optim_history, extra_state,
        )
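Example #1 above and most of the later snippets gate their side effects on an `is_master` check that is not defined in any of these examples. A minimal sketch, assuming the common convention that distributed rank 0 is the master:

def is_master(args):
    # Sketch only: treat distributed rank 0 as the master process.
    # Assumes the distributed init code has set `args.distributed_rank`.
    return getattr(args, 'distributed_rank', 0) == 0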
Example #2
def evaluate(args, model, va_loader):
    """Evaluate on validation data."""
    # Keep non-master processes waiting here
    if not is_master(args):
        accuracy = torch.zeros([1]).cuda()
        torch.distributed.barrier()

    # Only the master performs the evaluation
    if is_master(args):
        num_correct, num_example = 0, 0
        num_tp, num_fp, num_tn, num_fn = 0, 0, 0, 0
        model.eval()
        with torch.no_grad():
            for sent, seg_id, label in va_loader.get_iter(shuffle=False):
                _, ret_dict = model(sent, seg_id=seg_id, cls_target=label)
                cls_corr = ret_dict["cls_corr"]
                num_correct += cls_corr
                num_example += len(sent)
                tp, fp, tn, fn = confusion_matrix(ret_dict["cls_pred"], label)
                num_tp = num_tp + tp
                num_fp = num_fp + fp
                num_tn = num_tn + tn
                num_fn = num_fn + fn

        model.train()

        if args.dataset in ["CoLA"]:
            accuracy = _compute_metric_based_on_keys("corr", num_tp.item(),
                                                     num_fp.item(),
                                                     num_tn.item(),
                                                     num_fn.item())
            accuracy = torch.FloatTensor([accuracy]).cuda()
        else:
            accuracy = num_correct / num_example

        if args.distributed:
            torch.distributed.barrier()

    # sync accuracy
    if args.distributed:
        torch.distributed.all_reduce(accuracy,
                                     op=torch.distributed.ReduceOp.SUM)

    return accuracy.item()
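The `confusion_matrix` helper used here (and again in Example #9) is not shown. A plausible sketch, assuming binary 0/1 labels and tensor counts so the accumulated values can be all-reduced and read with `.item()`:

def confusion_matrix(pred, target):
    # Sketch only: binary confusion counts returned as tensors so they can
    # be accumulated across batches and summed across workers.
    pred, target = pred.view(-1), target.view(-1)
    tp = ((pred == 1) & (target == 1)).sum()
    fp = ((pred == 1) & (target == 0)).sum()
    tn = ((pred == 0) & (target == 0)).sum()
    fn = ((pred == 0) & (target == 1)).sum()
    return tp, fp, tn, fn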
Example #3
def setup_special_ids(args, tokenizer):
    """Set up the id of special tokens."""
    special_symbols_mapping = collections.OrderedDict([("<unk>", "unk_id"),
                                                       ("<s>", "bos_id"),
                                                       ("</s>", "eos_id"),
                                                       ("<cls>", "cls_id"),
                                                       ("<sep>", "sep_id"),
                                                       ("<pad>", "pad_id"),
                                                       ("<mask>", "mask_id"),
                                                       ("<eod>", "eod_id"),
                                                       ("<eop>", "eop_id")])
    args.vocab_size = tokenizer.get_vocab_size()
    if is_master(args):
        print("Set vocab_size: {}.".format(args.vocab_size))
    for sym, sym_id_str in special_symbols_mapping.items():
        try:
            sym_id = tokenizer.get_token_id(sym)
            setattr(args, sym_id_str, sym_id)
            if is_master(args):
                print("Set {} to {}.".format(sym_id_str, sym_id))
        except KeyError:
            if is_master(args):
                print("Skip {}: not found in tokenizer's vocab.".format(sym))
Example #4
def predict(args, model, loader, out_path, rev_label_dict):
    """Make prediction and write to file. This should only be called by master."""
    # Only master perform prediction
    if is_master(args):
        model.eval()
        with open(out_path, "w") as fo:
            with torch.no_grad():
                for sent, seg_id, label in loader.get_iter(shuffle=False):
                    _, ret_dict = model(sent, seg_id=seg_id, cls_target=label)
                    cls_pred = ret_dict["cls_pred"]
                    for i in range(cls_pred.size(0)):
                        pred_label = rev_label_dict[cls_pred[i].item()]
                        fo.write("{}\n".format(pred_label))

        model.train()
Example #5
def main(args):
    args = options.set_default_args(args)

    if args.ddp_backend == 'apex':
        from apex.parallel import DistributedDataParallel as DDP
    else:
        from torch.nn.parallel import DistributedDataParallel as DDP

    ############################################################################
    # Random seed
    ############################################################################
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    ############################################################################
    # Experiment & Logging
    ############################################################################
    if is_master(args):
        if args.resume:
            # rank-0 device creates experiment dir and log to the file
            logging = get_logger(os.path.join(args.expname, 'log.txt'),
                                 log_=not args.debug)
        else:
            # rank-0 device creates experiment dir and log to the file
            logging = create_exp_dir(args.expname, debug=args.debug)
    else:
        # other devices only log to console (print) but not the file
        logging = get_logger(log_path=None, log_=False)

    args.model_path = os.path.join(args.expname, 'model.pt')
    args.var_path = os.path.join(args.expname, 'var.pt')

    ############################################################################
    # Load data
    ############################################################################
    logging('Loading data..')
    tr_data, va_data, te_data = options.load_data(args)

    train_step = 0
    best_eval_ll = -float('inf')
    if args.resume:
        logging('Resuming from {}...'.format(args.resume))
        model, opt = torch.load(args.model_path, map_location='cpu')
        model = model.to(args.device)
        for state in opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(args.device)
        best_eval_ll, train_step = torch.load(args.var_path)
    else:
        logging('Building model..')
        if args.model_name in ['srnn', 'srnn_zforce', 'srnn_hier']:
            model = eval(args.model_name).Model(args.n_mix,
                                                args.d_data,
                                                args.d_emb,
                                                args.d_mlp,
                                                args.d_rnn,
                                                args.d_lat,
                                                dropout=args.dropout,
                                                n_layer=args.n_layer)
        elif args.model_name in ['rnn', 'rnn_hier']:
            model = eval(args.model_name).Model(args.n_mix,
                                                args.d_data,
                                                args.d_emb,
                                                args.d_rnn,
                                                dropout=args.dropout,
                                                n_layer=args.n_layer)
        else:
            raise ValueError('unsupported model type {}'.format(
                args.model_name))

        model = model.to(args.device)

        # create new optimizer
        opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    if not args.test_only:
        # criterion params and model params
        crit_params, model_params = [], []
        for n, p in model.named_parameters():
            if 'crit' in n:
                crit_params.append(p)
            else:
                model_params.append(p)

        ############################################################################
        # Distributed Data Parallel
        ############################################################################
        if args.distributed:
            if args.ddp_backend == 'apex':
                torch.cuda.set_device(args.distributed_rank)
                para_model = DDP(model)
            else:
                para_model = DDP(model,
                                 device_ids=[args.device_id],
                                 output_device=args.device_id)
        else:
            para_model = model

        ############################################################################
        # Log args
        ############################################################################
        args.n_crit_param = sum([p.nelement() for p in crit_params])
        args.n_model_param = sum([p.nelement() for p in model_params])
        args.n_param = args.n_crit_param + args.n_model_param
        if is_master(args):
            logging('=' * 100)
            for k, v in args.__dict__.items():
                logging('  - {} : {}'.format(k, v))
            logging('=' * 100)

        ############################################################################
        # Training
        ############################################################################
        # linear KLD weight annealing
        kld_weight = min(1., args.init_kld + train_step * args.kld_incr)

        loss_sum = torch.Tensor([0]).to(args.device)
        kld_sum = torch.Tensor([0]).to(args.device)
        nll_sum = torch.Tensor([0]).to(args.device)
        gnorm_sum = 0
        t = timeit.default_timer()
        for epoch in range(args.num_epochs):
            model.train()
            # make sure all data iterators use the same seed to shuffle data
            if args.distributed:
                np.random.seed(args.seed + epoch)

            # initialize the hidden state
            if args.pass_h:
                hidden = model.init_hidden(args.batch_size)
            else:
                hidden = None

            for x, y, mask in tr_data.get_masked_iter(shuffle=True):
                opt.zero_grad()
                ratio = 1. / torch.sum(mask)
                if args.kld:
                    nll_loss, kld_loss, hidden = para_model(x,
                                                            y,
                                                            mask=mask,
                                                            hidden=hidden)
                    nll_loss = nll_loss.sum() * ratio
                    kld_loss = kld_loss.sum() * ratio
                    train_loss = nll_loss - kld_loss * kld_weight
                    train_loss.backward()

                    total_loss = nll_loss.detach() - kld_loss.detach()
                    kld_sum += -kld_loss.detach()
                    nll_sum += nll_loss.detach()
                else:
                    nll_loss, hidden = para_model(x,
                                                  y,
                                                  mask=mask,
                                                  hidden=hidden)
                    train_loss = nll_loss.sum() * ratio
                    train_loss.backward()

                    total_loss = train_loss.detach()

                if args.clip > 0:
                    gnorm = nn.utils.clip_grad_norm_(model.parameters(),
                                                     args.clip)
                else:
                    gnorm = 0
                    for n, p in model.named_parameters():
                        if p.grad is not None:
                            param_gnorm = p.grad.data.norm(2)
                            gnorm += param_gnorm.item()**2
                    gnorm = gnorm**(1. / 2)

                opt.step()

                gnorm_sum += gnorm
                loss_sum += total_loss
                train_step += 1

                # lr & KLD weight annealing
                kld_weight = min(1., kld_weight + args.kld_incr)
                adjust_lr(opt, train_step, args.max_step, args.lr, args.end_lr)

                # log training
                if train_step % args.log_interval == 0:
                    if args.distributed:
                        dist.reduce(loss_sum, dst=0, op=dist.ReduceOp.SUM)
                        loss_sum = loss_sum.div_(args.distributed_world_size)
                        dist.reduce(nll_sum, dst=0, op=dist.ReduceOp.SUM)
                        nll_sum = nll_sum.div_(args.distributed_world_size)
                        dist.reduce(kld_sum, dst=0, op=dist.ReduceOp.SUM)
                        kld_sum = kld_sum.div_(args.distributed_world_size)

                    if is_master(args):
                        cur_loss = loss_sum.item() / args.log_interval
                        cur_nll = nll_sum.item() / args.log_interval
                        cur_kld = kld_sum.item() / args.log_interval
                        elapsed = (timeit.default_timer() - t) / 3600
                        logging('| total hrs [{:.2f}] | epoch {} step {} ' \
                                '| lr {:8.6f}, klw {:7.5f} | LL {:>9.4f} ' \
                                '| nll_loss {:>7.4f}, kld_loss {:>8.4f} ' \
                                '| gnorm {:.4f}'.format(
                          elapsed, epoch, train_step, opt.param_groups[0]['lr'],
                          kld_weight, -cur_loss, cur_nll, cur_kld,
                          gnorm_sum / args.log_interval))

                    loss_sum = torch.Tensor([0]).to(args.device)
                    kld_sum = torch.Tensor([0]).to(args.device)
                    nll_sum = torch.Tensor([0]).to(args.device)
                    gnorm_sum = 0

                # validation
                if train_step % args.eval_interval == 0:
                    eval_ll = evaluate(va_data, model, args)
                    if is_master(args):
                        logging('-' * 120)
                        logging('Eval [{}] at step: {} | valid LL: {:>8.4f}'.
                                format(train_step // args.eval_interval,
                                       train_step, eval_ll))
                        if eval_ll > best_eval_ll:
                            best_eval_ll = eval_ll
                            if not args.debug:
                                logging('Save checkpoint. ' \
                                        'Best valid LL {:>9.4f}'.format(eval_ll))
                                torch.save([model, opt], args.model_path)
                                torch.save([best_eval_ll, train_step],
                                           args.var_path)
                        logging('-' * 120)

                # Reach maximum training step
                if train_step == args.max_step:
                    break
            if train_step == args.max_step:
                break

    eval_ll = evaluate(va_data, model, args)
    if is_master(args):
        logging('-' * 120)
        logging('Eval [{}] | step: {}, LL: {:>8.4f}'.format(
            train_step // args.eval_interval, train_step, eval_ll))
        logging('-' * 120)

    # final evaluation on the test set
    test_ll = evaluate(te_data, model, args)
    if is_master(args):
        logging('Test -- LL: {:>8.4f}'.format(test_ll))
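Example #5 calls `adjust_lr(opt, train_step, args.max_step, args.lr, args.end_lr)`, which is not defined in the snippet. A sketch under the assumption that it linearly decays the learning rate from `lr` to `end_lr` over `max_step` updates:

def adjust_lr(opt, train_step, max_step, lr, end_lr):
    # Sketch only: linear decay from `lr` to `end_lr` over `max_step` steps,
    # then hold at `end_lr`.
    progress = min(train_step, max_step) / max_step
    new_lr = lr + (end_lr - lr) * progress
    for param_group in opt.param_groups:
        param_group['lr'] = new_lr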
Example #6
def main(args, init_distributed=False):
    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)

    #  set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    print(args, flush=True)

    # Setup task, e.g., translation, language modeling, etc.
    task = None
    if args.task == 'bert':
        task = tasks.LanguageModelingTask.setup_task(args)
    elif args.task == 'mnist':
        task = tasks.MNISTTask.setup_task(args)
    assert task is not None

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model
    model = task.build_model(args)

    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build controller
    controller = Controller(args, task, model)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator

    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, controller)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf

    lr = controller.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()

    while (lr > args.min_lr and (epoch_itr.epoch < max_epoch or
                                 (epoch_itr.epoch == max_epoch
                                  and epoch_itr._next_epoch_itr is not None))
           and controller.get_num_updates() < max_update):
        # train for one epoch
        train(args, controller, task, epoch_itr)

        # debug
        valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = controller.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, controller, epoch_itr,
                                             valid_losses[0])

        reload_dataset = ':' in getattr(args, 'data', '')
        # sharded data: get train iterator for next epoch
        epoch_itr = controller.get_train_iterator(epoch_itr.epoch,
                                                  load_dataset=reload_dataset)

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #7
def stat(args, data):
    lengths = [len(x[0]) for x in data]
    if is_master(args):
        print("Number of sent: {}".format(len(data)))
        print("Sent length: mean {}, std {}, max {}".format(
            np.mean(lengths), np.std(lengths), np.max(lengths)))
Example #8
def convert_examples_to_tensors(args, examples, label_dict, tokenizer,
                                output_file):
    """Encode and cache raw data into pytorch format."""
    if not is_master(args) and args.distributed:
        torch.distributed.barrier()

    if not os.path.exists(output_file) or args.overwrite_data:
        sents, labels, seg_ids = [], [], []
        for (ex_index, example) in enumerate(examples):
            example_len = 0
            tokens_a = tokenizer.convert_text_to_ids(example.text_a)
            example_len += len(tokens_a)
            tokens_b = None
            if example.text_b:
                tokens_b = tokenizer.convert_text_to_ids(example.text_b)
                example_len += len(tokens_b)

            if tokens_b:
                # Modifies `tokens_a` and `tokens_b` in place so that the total
                # length is less than the specified length.
                # Account for two [SEP] & one [CLS] with "- 3"
                _truncate_seq_pair(tokens_a, tokens_b, args.max_length - 3)
            else:
                # Account for one [SEP] & one [CLS] with "- 2"
                if len(tokens_a) > args.max_length - 2:
                    tokens_a = tokens_a[:args.max_length - 2]

            input_ids = []
            segment_ids = []
            if tokens_b is not None:
                input_ids = ([args.cls_id] + tokens_a + [args.sep_id] +
                             tokens_b + [args.sep_id])
                segment_ids = ([args.seg_id_cls] + [args.seg_id_a] *
                               (len(tokens_a) + 1) + [args.seg_id_b] *
                               (len(tokens_b) + 1))
            else:
                input_ids = [args.cls_id] + tokens_a + [args.sep_id]
                segment_ids = ([args.seg_id_cls] + [args.seg_id_a] *
                               (len(tokens_a) + 1))

            # Label
            if label_dict is not None:
                label_id = label_dict[example.label]
            else:
                label_id = example.label

            input_ids = torch.LongTensor(input_ids)
            segment_ids = torch.LongTensor(segment_ids)

            sents.append(input_ids)
            seg_ids.append(segment_ids)
            labels.append(label_id)

        data = list(zip(sents, seg_ids, labels))

        torch.save(data, output_file)
    else:
        data = torch.load(output_file)

    if is_master(args) and args.distributed:
        torch.distributed.barrier()

    stat(args, data)

    return data
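`_truncate_seq_pair` above follows the usual BERT-style pairwise truncation; a sketch, assuming the longer sequence is trimmed from the end one token at a time until the pair fits:

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    # Sketch of the standard BERT-style truncation: repeatedly drop the last
    # token of whichever sequence is currently longer.
    while len(tokens_a) + len(tokens_b) > max_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()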
Example #9
def main(args):
    """Main training function."""
    torch.cuda.set_device(args.device_id)
    if args.distributed:
        args.distributed_rank = args.device_id
        distributed_init(args)
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
    options.setup_device(args)

    ############################################################################
    # Experiment & Logging
    ############################################################################
    if is_master(args):
        if args.resume:
            # rank-0 device creates experiment dir and log to the file
            logging = utils.get_logger(os.path.join(args.model_dir, "log.txt"),
                                       log_=not args.debug)
        else:
            # rank-0 device creates experiment dir and log to the file
            logging = utils.create_exp_dir(args.model_dir, debug=args.debug)
    else:
        # other devices only log to console (print) but not the file
        logging = utils.get_logger(log_path=None, log_=False)

    ############################################################################
    # Load data
    ############################################################################
    logging("Loading data..")
    loaded_data, label_dict = data.load_data(args)
    args.num_class = len(label_dict)
    logging("Loading finish")
    tr_data, va_data, te_data = loaded_data
    va_loader = data.BucketIterator(va_data, args.valid_bsz, args.pad_id,
                                    args.seg_id_pad, args.device,
                                    args.max_length)
    te_loader = data.BucketIterator(te_data, args.test_bsz, args.pad_id,
                                    args.seg_id_pad, args.device,
                                    args.max_length)

    args.model_path = os.path.join(args.model_dir, "model.pt")
    args.var_path = os.path.join(args.model_dir, "var.pt")
    args.config_path = os.path.join(args.model_dir, "net_config.json")
    train_step = 0
    best_accuracy = -float("inf")

    # create model
    if args.resume:
        logging("Resuming from {}...".format(args.model_dir))
        net_config = modeling.ModelConfig.init_from_json(
            args.config_path, args)
        model = modeling.FunnelTFM(net_config, args)
        model_param, optimizer = torch.load(args.model_path,
                                            map_location="cpu")
        logging(model.load_state_dict(model_param, strict=False))
        model = model.to(args.device)
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(args.device)
        best_accuracy, train_step = torch.load(args.var_path)
    else:
        # create new model
        if args.init_ckpt:
            logging("Init from ckpt {}".format(args.init_ckpt))
            net_config = modeling.ModelConfig.init_from_json(
                args.init_ckpt_config, args)
            model = modeling.FunnelTFM(net_config, args)
            logging(model.load_state_dict(torch.load(args.init_ckpt),
                                          strict=False))
        else:
            logging("init model")
            net_config = modeling.ModelConfig.init_from_args(args)
            model = modeling.FunnelTFM(net_config, args)
        net_config.to_json(args.config_path)
        model = model.to(args.device)

    # create new optimizer
    if args.fp16:
        from apex.optimizers import FusedAdam
        import apex.amp as amp
        optimizer = FusedAdam(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
        amp_model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.amp_opt)
    else:
        try:
            from apex.optimizers import FusedAdam
            optimizer = FusedAdam(model.parameters(),
                                  lr=args.lr,
                                  betas=(0.9, 0.99),
                                  eps=1e-6,
                                  weight_decay=args.weight_decay)
        except ImportError:
            logging("Apex not found; falling back to torch.optim.AdamW")
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=args.lr,
                                          betas=(0.9, 0.99),
                                          eps=1e-6,
                                          weight_decay=args.weight_decay)
        amp_model = model

    if args.distributed:
        if args.ddp_backend == "apex":
            from apex.parallel import DistributedDataParallel as DDP
            para_model = DDP(amp_model)
        else:
            from torch.nn.parallel import DistributedDataParallel as DDP
            para_model = DDP(amp_model,
                             device_ids=[args.device_id],
                             find_unused_parameters=True)
    else:
        para_model = amp_model

    ############################################################################
    # Log args
    ############################################################################
    logging("=" * 100)
    for k, v in args.__dict__.items():
        logging("  - {} : {}".format(k, v))
    logging("=" * 100)

    ############################################################################
    # Training
    ############################################################################
    if not args.test_only:
        tr_loader = data.BucketIterator(tr_data, args.train_bsz, args.pad_id,
                                        args.seg_id_pad, args.device,
                                        args.max_length)

        if args.distributed:
            num_data = len(tr_data) // args.distributed_world_size
        else:
            num_data = len(tr_data)
        num_tr_batch = (num_data + args.train_bsz - 1) // args.train_bsz
        args.train_steps = num_tr_batch * args.epochs
        args.warmup_steps = int(args.train_steps * args.warmup_prop)

        num_example = torch.Tensor([0]).to(args.device)
        num_correct = torch.Tensor([0]).to(args.device)

        if args.dataset in ["CoLA"]:
            num_tp = torch.Tensor([0]).to(args.device)
            num_fp = torch.Tensor([0]).to(args.device)
            num_tn = torch.Tensor([0]).to(args.device)
            num_fn = torch.Tensor([0]).to(args.device)

        for epoch in range(args.epochs):
            #### One epoch
            for i, (sent, seg_id, label) in enumerate(
                    tr_loader.get_iter(epoch, distributed=args.distributed)):
                optimizer.zero_grad()
                _, ret_dict = para_model(sent, seg_id=seg_id, cls_target=label)
                cls_loss = ret_dict["cls_loss"]
                cls_corr = ret_dict["cls_corr"]

                if args.fp16:
                    with amp.scale_loss(cls_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    cls_loss.backward()
                num_correct += cls_corr.detach()
                num_example += len(sent)
                if args.dataset in ["CoLA"]:
                    tp, fp, tn, fn = confusion_matrix(ret_dict["cls_pred"],
                                                      label)
                    num_tp = num_tp + tp
                    num_fp = num_fp + fp
                    num_tn = num_tn + tn
                    num_fn = num_fn + fn

                if args.clip > 0:
                    if args.fp16:
                        gnorm = torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.clip)
                    else:
                        gnorm = torch.nn.utils.clip_grad_norm_(
                            model.parameters(), args.clip)
                else:
                    gnorm = 0
                    for p in model.parameters():
                        if p.grad is not None:
                            param_gnorm = p.grad.data.norm(2)
                            gnorm += param_gnorm.item()**2
                    gnorm = gnorm**(1. / 2)
                train_step += 1
                adjust_lr(args, train_step, optimizer)
                optimizer.step()

                ##### training stat
                if (i + 1) % (num_tr_batch // args.n_log_epoch) == 0:
                    if args.distributed:
                        torch.distributed.all_reduce(
                            num_correct, op=torch.distributed.ReduceOp.SUM)
                        torch.distributed.all_reduce(
                            num_example, op=torch.distributed.ReduceOp.SUM)
                        if args.dataset in ["CoLA"]:
                            torch.distributed.all_reduce(
                                num_tp, op=torch.distributed.ReduceOp.SUM)
                            torch.distributed.all_reduce(
                                num_fp, op=torch.distributed.ReduceOp.SUM)
                            torch.distributed.all_reduce(
                                num_tn, op=torch.distributed.ReduceOp.SUM)
                            torch.distributed.all_reduce(
                                num_fn, op=torch.distributed.ReduceOp.SUM)

                    if is_master(args):
                        if args.dataset in ["CoLA"]:
                            corref = _compute_metric_based_on_keys(
                                "corr", num_tp.item(), num_fp.item(),
                                num_tn.item(), num_fn.item())
                            logging(
                                "[{:>02d}/{:>08d}] Train | corref {:.4f} | gnorm {:.2f} "
                                "| lr {:.6f}".format(
                                    epoch, train_step, corref, gnorm,
                                    optimizer.param_groups[0]["lr"]))
                        else:
                            accuracy = num_correct.item() / num_example.item()
                            logging(
                                "[{:>02d}/{:>08d}] Train | accu {:.4f} | gnorm {:.2f} "
                                "| lr {:.6f}".format(
                                    epoch, train_step, accuracy, gnorm,
                                    optimizer.param_groups[0]["lr"]))
                    num_example.zero_()
                    num_correct.zero_()
                    if args.dataset in ["CoLA"]:
                        num_tp.zero_()
                        num_fp.zero_()
                        num_tn.zero_()
                        num_fn.zero_()

                ##### validation
                if train_step % (args.train_steps // 10) == 0:
                    accuracy = evaluate(args, model, va_loader)
                    if is_master(args):
                        if accuracy > best_accuracy:
                            best_accuracy = accuracy
                            torch.save([model.state_dict(), optimizer],
                                       args.model_path)
                            torch.save([best_accuracy, train_step],
                                       args.var_path)
                        logging(
                            "[{}] Valid | curr accu {:.4f} | best accu {:.4f}".
                            format(train_step // (args.train_steps // 10),
                                   accuracy, best_accuracy))

    ##### make prediction
    if is_master(args) and args.write_prediction:
        rev_label_dict = dict((v, k) for k, v in label_dict.items())
        model.load_state_dict(torch.load(args.model_path,
                                         map_location="cpu")[0],
                              strict=False)
        model = model.to(args.device)
        predict(args, model, te_loader,
                os.path.join(args.model_dir, "test_results.txt"),
                rev_label_dict)
        predict(args, model, va_loader,
                os.path.join(args.model_dir, "valid_results.txt"),
                rev_label_dict)
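Examples #2 and #9 score CoLA with a "corr" metric computed from the confusion counts. CoLA is conventionally evaluated with the Matthews correlation coefficient, so a plausible sketch of `_compute_metric_based_on_keys` for that key is:

import math

def _compute_metric_based_on_keys(key, tp, fp, tn, fn):
    # Sketch only: for the "corr" key, return the Matthews correlation
    # coefficient computed from the binary confusion counts.
    if key == "corr":
        denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        return ((tp * tn) - (fp * fn)) / denom if denom > 0 else 0.0
    raise ValueError("Unsupported metric key: {}".format(key))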
Example #10
def save_checkpoint(args, controller, epoch_itr, val_loss):
    import distributed_utils, meters

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        best_function = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if args.no_save or not distributed_utils.is_master(args):
        return

    def is_better(a, b):
        return a >= b if args.maximize_best_checkpoint_metric else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = controller.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0)
    checkpoint_conds['checkpoint_{}_{}.pt'.format(
        epoch, updates)] = (not end_of_epoch and args.save_interval_updates > 0
                            and updates % args.save_interval_updates == 0)
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None
        and (not hasattr(save_checkpoint, 'best')
             or is_better(val_loss, save_checkpoint.best)))
    checkpoint_conds['checkpoint_last.pt'] = not args.no_last_checkpoints

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        controller.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            shutil.copyfile(checkpoints[0], cp)

        write_timer.stop()
        print(
            '| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'
            .format(checkpoints[0], epoch, updates, write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            args.save_dir,
            pattern=r'checkpoint_\d+_(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            args.save_dir,
            pattern=r'checkpoint(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
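Example #10 relies on a `checkpoint_paths` helper that is not shown in the snippet; a sketch, assuming it returns the matching checkpoint files sorted so the most recent come first:

import os
import re

def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    # Sketch only: list files in `path` whose names match `pattern`, sorted
    # in descending order of the captured integer (newest first).
    pt_regexp = re.compile(pattern)
    entries = []
    for fname in os.listdir(path):
        m = pt_regexp.fullmatch(fname)
        if m is not None:
            entries.append((int(m.group(1)), os.path.join(path, fname)))
    return [p for _, p in sorted(entries, reverse=True)]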